|
|
- from torch.distributions import Normal as Normal_t
- from distributions import Normal as Normal_j
- import numpy as np
-
- import torch
- import jittor as jt
-
-
# Reference (PyTorch) and candidate (jittor) Normal distributions over a
# (10, 4) batch, both built from all-ones parameters. For torch these are
# loc and scale; the jittor constructor presumably takes the same pair in
# the same order -- confirm against the local `distributions` module.
dis_torch = Normal_t(
    loc=torch.ones(10, 4, dtype=torch.float32),
    scale=torch.ones(10, 4, dtype=torch.float32),
)
dis_jittor = Normal_j(
    jt.ones((10, 4), dtype=jt.float32),
    jt.ones((10, 4), dtype=jt.float32),
)
-
def test_acc(a, b, atol=1e-3):
    """Check that two array-likes agree elementwise within an absolute tolerance.

    Both inputs are converted with ``np.array`` (so torch tensors and jittor
    Vars are accepted as long as they convert to NumPy), then compared.

    Args:
        a: first array-like.
        b: second array-like.
        atol: maximum allowed absolute difference for every element.

    Prints ``"pass"`` when the arrays match.

    Raises:
        ValueError: if the shapes differ or any element differs by more than
            ``atol`` (NaNs never match).
    """
    a = np.array(a)
    b = np.array(b)
    # Compare per element on |a - b|: a signed sum lets negative errors pass
    # and lets errors of opposite sign cancel out. Require equal shapes so
    # broadcasting cannot hide a shape mismatch.
    if a.shape == b.shape and np.allclose(a, b, rtol=0.0, atol=atol):
        print("pass")
    else:
        raise ValueError("not match")


# This helper takes data arguments, not fixtures; stop pytest from
# collecting it as a test just because its name starts with "test_".
test_acc.__test__ = False
-
def test_sample(a, b, *, n_samples=10000, n_sigma=5.0):
    """Check that two distributions yield samples with the same elementwise mean.

    Draws ``n_samples`` samples from each of ``a`` and ``b`` and compares the
    elementwise sample means. The means are Monte Carlo estimates, so the
    tolerance is ``n_sigma`` standard errors of their difference (estimated
    from the samples themselves) rather than a fixed constant.

    Args:
        a: distribution sampled as ``a.sample((10, 4))``.
        b: distribution sampled as ``b.sample()``.
        n_samples: number of draws taken from each distribution.
        n_sigma: number of standard errors of disagreement tolerated per
            element.

    Prints ``"pass"`` when every element agrees.

    Raises:
        ValueError: if the mean shapes differ or any element's means differ
            by more than ``n_sigma`` standard errors.
    """
    a_record, b_record = [], []
    # NOTE(review): `a` is asked for an explicit (10, 4) sample shape while
    # `b` is sampled with no argument; presumably this matches the jittor vs.
    # torch Normal APIs used in __main__ -- confirm both return (10, 4).
    for _ in range(n_samples):
        a_record.append(np.array(a.sample((10, 4))))
        b_record.append(np.array(b.sample()))

    a_samples = np.array(a_record)
    b_samples = np.array(b_record)
    a_mean = a_samples.mean(0)
    b_mean = b_samples.mean(0)
    if a_mean.shape != b_mean.shape:
        raise ValueError("not match")

    # Standard error of the difference of the two sample means, assuming the
    # two sample streams are independent. Compare |diff| per element: a
    # signed sum would accept any negative bias and let errors cancel.
    std_err = np.sqrt((a_samples.var(0) + b_samples.var(0)) / n_samples)
    if np.all(np.abs(a_mean - b_mean) <= n_sigma * std_err):
        print("pass")
    else:
        raise ValueError("not match")


# This helper takes data arguments, not fixtures; stop pytest from
# collecting it as a test just because its name starts with "test_".
test_sample.__test__ = False
-
-
-
if __name__ == '__main__':
    # log_prob: evaluate both distributions at x = 1, passed as a (10, 1)
    # input against the (10, 4) parameters; jittor is evaluated first.
    test_acc(
        dis_jittor.log_prob(jt.ones((10, 1), dtype=jt.float32)),
        dis_torch.log_prob(torch.ones((10, 1), dtype=torch.float32)),
    )

    # sample(): compare Monte Carlo means of repeated draws.
    test_sample(dis_jittor, dis_torch)

    # entropy: compare the closed-form entropies (jittor evaluated first).
    test_acc(dis_jittor.entropy(), dis_torch.entropy())
|