개발

pytorch Tensor 비교하기

then-go 2023. 10. 8. 21:33

pytorch에서 tensor를 사용하다보면 A tensor와 B tensor가 같은지를 확인하고 싶을때가 있다.

이럴때는 아래처럼 하면된다.

 

1.원소 마다 비교하기

torch.eq() 혹은 == 를 사용하면 된다.

import torch
a = torch.tensor([1,2,3,4,5])
b = torch.tensor([1,2,3,4,5])

print(a == b)
print(torch.eq(a,b))

이렇게 각 원소마다 true/false 값을 내보내줌.

 

2. tensor 전체 비교

이때는 torch.equal()을 사용하면 됨.

import torch
a = torch.tensor([1,2,3,4,5])
b = torch.tensor([1,2,3,4,6])
c = torch.tensor([1,2,3,4,5])

print(a == b)
print(torch.eq(a,b))
print(torch.equal(a,b))
print(torch.equal(a,c))