You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 

51 lines
1.5 KB

from torch.distributions import Normal as Normal_t
from distributions import Normal as Normal_j
import numpy as np
import torch
import jittor as jt
# Reference (PyTorch) and candidate (Jittor) Normal distributions. Both are
# built from all-ones float32 parameters of shape (10, 4) for the first and
# second constructor arguments (loc and scale for the PyTorch one).
dis_torch = Normal_t(
    torch.ones((10, 4), dtype=torch.float32),
    torch.ones((10, 4), dtype=torch.float32),
)
dis_jittor = Normal_j(
    jt.ones((10, 4), dtype=jt.float32),
    jt.ones((10, 4), dtype=jt.float32),
)
def test_acc(a, b, atol=1e-3):
    """Check that two array-likes agree element-wise within ``atol``.

    Both inputs are converted with ``np.asarray``, so framework tensors are
    accepted as long as they support NumPy conversion.

    Prints ``"pass"`` on success.

    Args:
        a, b: array-likes to compare; they must have identical shapes.
        atol: maximum allowed absolute difference per element.

    Raises:
        ValueError: if the shapes differ or any element differs by more
            than ``atol`` (NaNs never match).
    """
    a = np.asarray(a)
    b = np.asarray(b)
    # Require identical shapes: broadcasting would otherwise let e.g. a
    # (10, 1) result be "compared" against a (10, 4) one without complaint.
    if a.shape != b.shape:
        raise ValueError(f"not match: shape {a.shape} vs {b.shape}")
    # Compare the absolute error of every element. Summing *signed*
    # differences lets errors of opposite sign cancel and lets any result
    # smaller than the reference pass unconditionally.
    if np.allclose(a, b, rtol=0.0, atol=atol):
        print("pass")
    else:
        raise ValueError("not match")


# This is a helper assertion, not a pytest test; stop pytest from collecting
# it (it would fail looking for fixtures named ``a`` and ``b``).
test_acc.__test__ = False
def test_sample(a, b, n_samples=10000, atol=0.1):
    """Check that two distributions produce samples with matching means.

    Draws ``n_samples`` samples from each distribution, averages them
    element-wise, and requires every element of the two sample means to
    agree within ``atol``.

    Prints ``"pass"`` on success.

    Args:
        a: object whose ``sample((10, 4))`` result converts with ``np.asarray``.
        b: object whose ``sample()`` result converts with ``np.asarray``.
        n_samples: number of draws per distribution (must be >= 1).
        atol: maximum allowed absolute difference between the sample means.
            For unit-scale distributions such as N(1, 1), 10000 draws give
            each mean a standard error of 0.01. The difference of two such
            means then has a standard deviation of about 0.014, so 0.1 is
            about 7 standard deviations. Tighten or loosen it for other
            scales or sample counts.

    Raises:
        ValueError: if ``n_samples`` < 1, the sample-mean shapes differ, or
            any element of the means differs by more than ``atol``.
    """
    if n_samples < 1:
        raise ValueError("n_samples must be >= 1")
    a_record, b_record = [], []
    # Sample n_samples times from each distribution and take the mean.
    for _ in range(n_samples):
        # NOTE(review): the asymmetric calls are kept from the original;
        # they presumably reflect differing ``sample`` signatures between
        # the two implementations - confirm.
        a_record.append(np.asarray(a.sample((10, 4))))
        b_record.append(np.asarray(b.sample()))
    a_mean = np.mean(a_record, axis=0)
    b_mean = np.mean(b_record, axis=0)
    # Without this check, mismatched sample shapes would be silently broadcast.
    if a_mean.shape != b_mean.shape:
        raise ValueError(
            f"not match: sample mean shapes {a_mean.shape} vs {b_mean.shape}"
        )
    # Element-wise absolute check. A signed sum against 1e-3 is a coin flip
    # for correct implementations and always passes for biased-low ones.
    if np.allclose(a_mean, b_mean, rtol=0.0, atol=atol):
        print("pass")
    else:
        raise ValueError("not match")


# This is a helper assertion, not a pytest test; stop pytest from collecting
# it (it would fail looking for fixtures named ``a`` and ``b``).
test_sample.__test__ = False
if __name__ == '__main__':
    # Compare the Jittor Normal against the PyTorch reference in three ways.

    # 1) log_prob evaluated at x = 1, passed as a (10, 1) column.
    #    Presumably it is broadcast against the (10, 4) parameters.
    x_jittor = jt.ones((10, 1), dtype=jt.float32)
    x_torch = torch.ones((10, 1), dtype=torch.float32)
    test_acc(dis_jittor.log_prob(x_jittor), dis_torch.log_prob(x_torch))

    # 2) Sampling.
    test_sample(dis_jittor, dis_torch)

    # 3) Entropy.
    test_acc(dis_jittor.entropy(), dis_torch.entropy())