|
|
@ -1,51 +0,0 @@ |
|
|
|
from torch.distributions import Normal as Normal_t |
|
|
|
from jittor.distributions import Normal as Normal_j |
|
|
|
import numpy as np |
|
|
|
|
|
|
|
import torch |
|
|
|
import jittor as jt |
|
|
|
|
|
|
|
|
|
|
|
# dis_torch = Normal_t(torch.ones(10, 4).float(), torch.ones(10, 4).float()) |
|
|
|
dis_torch = Normal_t(torch.ones((10, 4), dtype=torch.float32), torch.ones((10, 4), dtype=torch.float32)) |
|
|
|
dis_jittor = Normal_j(jt.ones((10, 4), dtype=jt.float32), jt.ones((10, 4), dtype=jt.float32)) |
|
|
|
|
|
|
|
def test_acc(a, b): |
|
|
|
# if isinstance(a, jt.Var): |
|
|
|
a = np.array(a) |
|
|
|
b = np.array(b) |
|
|
|
if np.sum(a - b) < 1e-3: |
|
|
|
print("pass") |
|
|
|
else: |
|
|
|
raise ValueError("not match") |
|
|
|
|
|
|
|
def test_sample(a, b): |
|
|
|
a_record, b_record = [], [] |
|
|
|
# sample 10000次,取均值 |
|
|
|
for i in range(10000): |
|
|
|
a_record.append(np.array(a.sample((10,4)))) |
|
|
|
b_record.append(np.array(b.sample())) |
|
|
|
|
|
|
|
a_mean = np.array(a_record).mean(0) |
|
|
|
b_mean = np.array(b_record).mean(0) |
|
|
|
|
|
|
|
if np.sum(a_mean - b_mean) < 1e-3: |
|
|
|
print("pass") |
|
|
|
else: |
|
|
|
raise ValueError("not match") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
# test log_prob |
|
|
|
log_prob_jittor = dis_jittor.log_prob(jt.ones((10, 1), dtype=jt.float32)) |
|
|
|
log_prob_normal = dis_torch.log_prob(torch.ones((10, 1), dtype=torch.float32)) |
|
|
|
test_acc(log_prob_jittor, log_prob_normal) |
|
|
|
|
|
|
|
# test sample() |
|
|
|
test_sample(dis_jittor, dis_torch) |
|
|
|
|
|
|
|
# test entropy |
|
|
|
entropy_jittor = dis_jittor.entropy() |
|
|
|
entropy_normal = dis_torch.entropy() |
|
|
|
test_acc(entropy_jittor, entropy_normal) |