개발
pytorch Tensor 비교하기
then-go
2023. 10. 8. 21:33
pytorch에서 tensor를 사용하다보면 A tensor와 B tensor가 같은지를 확인하고 싶을때가 있다.
이럴때는 아래처럼 하면된다.
1.원소 마다 비교하기
torch.eq() 혹은 == 를 사용하면 된다.
import torch
a = torch.tensor([1,2,3,4,5])
b = torch.tensor([1,2,3,4,5])
print(a == b)
print(torch.eq(a,b))
이렇게 각 원소마다 true/false 값을 내보내줌.
2. tensor 전체 비교
이때는 torch.equal()을 사용하면 됨.
import torch
a = torch.tensor([1,2,3,4,5])
b = torch.tensor([1,2,3,4,6])
c = torch.tensor([1,2,3,4,5])
print(a == b)
print(torch.eq(a,b))
print(torch.equal(a,b))
print(torch.equal(a,c))