from torch.distributions import Normal as Normal_t
|
|
from distributions import Normal as Normal_j
|
|
import numpy as np
|
|
|
|
import torch
|
|
import jittor as jt
|
|
|
|
|
|
# Reference (PyTorch) and candidate (Jittor) Normal distributions, both built
# from all-ones float32 parameter arrays of shape (10, 4).
dis_torch = Normal_t(
    loc=torch.ones((10, 4), dtype=torch.float32),
    scale=torch.ones((10, 4), dtype=torch.float32),
)
dis_jittor = Normal_j(
    jt.ones((10, 4), dtype=jt.float32),
    jt.ones((10, 4), dtype=jt.float32),
)
|
|
|
|
def test_acc(a, b):
    """Check that two array-likes agree element-wise to within 1e-3.

    Both arguments are converted with ``np.array`` before comparison; this
    script passes Jittor and PyTorch results for the same computation.

    Args:
        a: First array-like value.
        b: Second array-like value; must broadcast against ``a``.

    Raises:
        ValueError: If any element of ``a`` differs from ``b`` by 1e-3 or
            more. NaN differences count as a mismatch.
    """
    a = np.array(a)
    b = np.array(b)

    # Compare magnitudes element by element. A signed sum would let positive
    # and negative errors cancel out and would accept any large negative
    # difference.
    if np.all(np.abs(a - b) < 1e-3):
        print("pass")
    else:
        raise ValueError("not match")


# This is a script helper that takes arguments, not a pytest test; opt out of
# pytest collection so the ``test_`` prefix does not cause fixture errors.
test_acc.__test__ = False
|
|
|
|
def test_sample(a, b, n_samples=10000, n_sigma=5.0):
    """Check that two samplers produce the same mean, up to sampling noise.

    Draws ``n_samples`` times from each sampler, averages the draws over the
    sample axis, and compares the two element-wise means. The allowed
    difference per element is ``n_sigma`` empirical standard errors of the
    difference of means, plus an absolute floor of 1e-3.

    Args:
        a: Object whose ``sample((10, 4))`` returns something ``np.array``
            can convert (the Jittor distribution in this script).
        b: Object whose ``sample()`` returns something ``np.array`` can
            convert (the PyTorch distribution in this script).
        n_samples: Number of draws taken from each sampler.
        n_sigma: Width of the acceptance band in standard errors.

    Raises:
        ValueError: If ``n_samples`` is not positive, or if any element of
            the two sample means differs by more than the tolerance.
    """
    if n_samples < 1:
        raise ValueError("n_samples must be positive")

    a_record, b_record = [], []
    # Draw n_samples times from each sampler and average the results.
    for _ in range(n_samples):
        a_record.append(np.array(a.sample((10, 4))))
        b_record.append(np.array(b.sample()))

    a_draws = np.array(a_record)
    b_draws = np.array(b_record)
    a_mean = a_draws.mean(0)
    b_mean = b_draws.mean(0)

    # Standard error of (a_mean - b_mean) for independent draws. A fixed
    # 1e-3 bound is far below the Monte-Carlo noise of a sample mean, so the
    # tolerance has to scale with the observed variance and sample count.
    stderr = np.sqrt((a_draws.var(0) + b_draws.var(0)) / n_samples)
    tolerance = n_sigma * stderr + 1e-3

    # Compare magnitudes per element; a signed sum would let errors cancel
    # and would accept any large negative difference.
    if np.all(np.abs(a_mean - b_mean) <= tolerance):
        print("pass")
    else:
        raise ValueError("not match")


# This is a script helper that takes arguments, not a pytest test; opt out of
# pytest collection so the ``test_`` prefix does not cause fixture errors.
test_sample.__test__ = False
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # log_prob: evaluate both distributions at all-ones points of shape
    # (10, 1) and compare the results.
    test_acc(
        dis_jittor.log_prob(jt.ones((10, 1), dtype=jt.float32)),
        dis_torch.log_prob(torch.ones((10, 1), dtype=torch.float32)),
    )

    # sample(): compare the averaged draws of the two distributions.
    test_sample(dis_jittor, dis_torch)

    # entropy: compare the entropies reported by the two distributions.
    test_acc(dis_jittor.entropy(), dis_torch.entropy())
|