@ -0,0 +1,120 @@ | |||
# *************************************************************** | |||
# Copyright (c) 2021 Jittor. All Rights Reserved. | |||
# Maintainers: | |||
# Dun Liang <randonlang@gmail.com>. | |||
# | |||
# This file is subject to the terms and conditions defined in | |||
# file 'LICENSE.txt', which is part of this source code package. | |||
# *************************************************************** | |||
import jittor as jt | |||
import math | |||
import numpy as np | |||
# import sys # 导入sys模块 | |||
# sys.setrecursionlimit(30000) | |||
def simple_presum(x): | |||
src = ''' | |||
__inline_static__ | |||
@python.jittor.auto_parallel(1) | |||
void kernel(int n0, int i0, in0_type* x, in0_type* out, int nl) { | |||
out[i0*(nl+1)] = 0; | |||
for (int i=0; i<nl; i++) | |||
out[i0*(nl+1)+i+1] = out[i0*(nl+1)+i] + x[i0*nl+i]; | |||
} | |||
kernel(in0->num/in0->shape[in0->shape.size()-1], 0, in0_p, out0_p, in0->shape[in0->shape.size()-1]); | |||
''' | |||
return jt.code(x.shape[:-1]+(x.shape[-1]+1,), x.dtype, [x], | |||
cpu_src=src, cuda_src=src) | |||
class OneHotCategorical: | |||
def __init__(self, probs=None, logits=None): | |||
assert not (probs is None and logits is None) | |||
if probs is None: | |||
probs = jt.sigmoid(logits) | |||
self.probs = probs | |||
with jt.no_grad(): | |||
self.probs = probs / probs.sum(-1, True) | |||
self.cum_probs = simple_presum(self.probs) | |||
self.cum_probs_l = self.cum_probs[..., :-1] | |||
self.cum_probs_r = self.cum_probs[..., 1:] | |||
if logits is None: | |||
self.logits = jt.log(self.probs) | |||
# self._categorical = Categorical(probs, logits) | |||
def sample(self, sample_shape=[]): | |||
shape = sample_shape + self.probs.shape[:-1] + (1,) | |||
rand = jt.rand(shape) | |||
one_hot = jt.logical_and(self.cum_probs_l < rand, rand <= self.cum_probs_r).float() | |||
return one_hot | |||
def log_prob(self, value): | |||
logits = self.logits | |||
value = jt.broadcast(value, logits.shape) | |||
# import ipdb | |||
# ipdb.set_trace() | |||
# value = jt.array(value) | |||
log_pmf = jt.broadcast(logits, logits.shape) | |||
# log_pmf = jt.array(log_pmf) | |||
indices = jt.argmax(value, dim=-1)[0] | |||
return log_pmf.gather(1, indices.unsqueeze(-1)).squeeze(-1) | |||
def entropy(self): | |||
p_log_p = self.logits * self.probs | |||
return -p_log_p.sum(-1) | |||
class Categorical: | |||
def __init__(self, probs=None, logits=None): | |||
OneHotCategorical.__init__(self, probs, logits) | |||
def sample(self, sample_shape=[]): | |||
shape = sample_shape + self.probs.shape[:-1] + (1,) | |||
rand = jt.rand(shape) | |||
one_hot = jt.logical_and(self.cum_probs_l < rand, rand <= self.cum_probs_r) | |||
index = one_hot.index(one_hot.ndim-1) | |||
return (one_hot * index).sum(-1) | |||
def log_prob(self, x): | |||
return jt.log(self.probs)[0,x] | |||
def entropy(self): | |||
min_real = -(math.pow(2,23)-1) / math.pow(2,22) * math.pow(2,127) | |||
logits = jt.clamp(self.logits,min_v=min_real) | |||
p_log_p = logits * self.probs | |||
return -p_log_p.sum(-1) | |||
class Normal: | |||
@property | |||
def mean(self): | |||
return self.loc | |||
@property | |||
def stddev(self): | |||
return self.scale | |||
@property | |||
def variance(self): | |||
return self.stddev.pow(2) | |||
def __init__(self, loc, scale): | |||
self.loc = loc | |||
self.scale = scale | |||
self.log_scale = jt.log(self.scale) | |||
def log_prob(self, x): | |||
return -((x - self.loc) ** 2) / (2 * self.variance) - self.log_scale - math.log(math.sqrt(2 * math.pi)) | |||
def sample(self, sample_shape=[]): | |||
shape = sample_shape + self.loc.shape | |||
with jt.no_grad(): | |||
eps = jt.randn(shape) | |||
return self.loc + self.scale * eps | |||
def entropy(self): | |||
return self.log_scale + 0.5 * math.log(2 * math.pi * math.e) |
@ -0,0 +1,52 @@ | |||
from torch.distributions import Categorical as Categorical_t | |||
from jittor.distributions import Categorical as Categorical_j | |||
import numpy as np | |||
import torch | |||
import jittor as jt | |||
# dis_torch = Normal_t(torch.ones(10, 4).float(), torch.ones(10, 4).float()) | |||
dis_torch = Categorical_t(torch.ones((10, 4), dtype=torch.float32)) | |||
dis_jittor = Categorical_j(jt.ones((10, 4), dtype=jt.float32)) | |||
def test_acc(a, b): | |||
# if isinstance(a, jt.Var): | |||
a = np.array(a) | |||
b = np.array(b) | |||
if np.sum(a - b) < 1e-2: | |||
print("pass") | |||
else: | |||
raise ValueError("not match") | |||
def test_sample(a, b): | |||
a_record, b_record = [], [] | |||
# sample 10000次,取均值 | |||
for i in range(10000): | |||
a_record.append(np.array(a.sample())) | |||
b_record.append(np.array(b.sample())) | |||
a_mean = np.array(a_record).mean(0) | |||
b_mean = np.array(b_record).mean(0) | |||
if np.sum(a_mean - b_mean) < 1e-2: | |||
print("pass") | |||
else: | |||
raise ValueError("not match") | |||
if __name__ == '__main__': | |||
# test log_prob | |||
log_prob_jittor = dis_jittor.log_prob(jt.ones((10, 1), dtype=jt.float32)) | |||
log_prob_normal = dis_torch.log_prob(torch.ones((10, 1), dtype=torch.float32)) | |||
print(log_prob_jittor, log_prob_normal) | |||
test_acc(log_prob_jittor, log_prob_normal) | |||
# test sample() | |||
test_sample(dis_jittor, dis_torch) | |||
# test entropy | |||
entropy_jittor = dis_jittor.entropy() | |||
entropy_normal = dis_torch.entropy() | |||
test_acc(entropy_jittor, entropy_normal) |
@ -0,0 +1,51 @@ | |||
from torch.distributions import Normal as Normal_t | |||
from distributions import Normal as Normal_j | |||
import numpy as np | |||
import torch | |||
import jittor as jt | |||
# dis_torch = Normal_t(torch.ones(10, 4).float(), torch.ones(10, 4).float()) | |||
dis_torch = Normal_t(torch.ones((10, 4), dtype=torch.float32), torch.ones((10, 4), dtype=torch.float32)) | |||
dis_jittor = Normal_j(jt.ones((10, 4), dtype=jt.float32), jt.ones((10, 4), dtype=jt.float32)) | |||
def test_acc(a, b): | |||
# if isinstance(a, jt.Var): | |||
a = np.array(a) | |||
b = np.array(b) | |||
if np.sum(a - b) < 1e-3: | |||
print("pass") | |||
else: | |||
raise ValueError("not match") | |||
def test_sample(a, b): | |||
a_record, b_record = [], [] | |||
# sample 10000次,取均值 | |||
for i in range(10000): | |||
a_record.append(np.array(a.sample((10,4)))) | |||
b_record.append(np.array(b.sample())) | |||
a_mean = np.array(a_record).mean(0) | |||
b_mean = np.array(b_record).mean(0) | |||
if np.sum(a_mean - b_mean) < 1e-3: | |||
print("pass") | |||
else: | |||
raise ValueError("not match") | |||
if __name__ == '__main__': | |||
# test log_prob | |||
log_prob_jittor = dis_jittor.log_prob(jt.ones((10, 1), dtype=jt.float32)) | |||
log_prob_normal = dis_torch.log_prob(torch.ones((10, 1), dtype=torch.float32)) | |||
test_acc(log_prob_jittor, log_prob_normal) | |||
# test sample() | |||
test_sample(dis_jittor, dis_torch) | |||
# test entropy | |||
entropy_jittor = dis_jittor.entropy() | |||
entropy_normal = dis_torch.entropy() | |||
test_acc(entropy_jittor, entropy_normal) |
@ -0,0 +1,52 @@ | |||
from torch.distributions import OneHotCategorical as Categorical_t | |||
from distributions import OneHotCategorical as Categorical_j | |||
import numpy as np | |||
import torch | |||
import jittor as jt | |||
# dis_torch = Normal_t(torch.ones(10, 4).float(), torch.ones(10, 4).float()) | |||
dis_torch = Categorical_t(torch.ones((10, 4), dtype=torch.float32)) | |||
dis_jittor = Categorical_j(jt.ones((10, 4), dtype=jt.float32)) | |||
def test_acc(a, b): | |||
# if isinstance(a, jt.Var): | |||
a = np.array(a) | |||
b = np.array(b) | |||
if np.sum(a - b) < 1e-3: | |||
print("pass") | |||
else: | |||
raise ValueError("not match") | |||
def test_sample(a, b): | |||
a_record, b_record = [], [] | |||
# sample 10000次,取均值 | |||
for i in range(10000): | |||
a_record.append(np.array(a.sample())) | |||
b_record.append(np.array(b.sample())) | |||
a_mean = np.array(a_record).mean(0) | |||
b_mean = np.array(b_record).mean(0) | |||
if np.sum(a_mean - b_mean) < 1e-2: | |||
print("pass") | |||
else: | |||
raise ValueError("not match") | |||
if __name__ == '__main__': | |||
# test log_prob | |||
log_prob_jittor = dis_jittor.log_prob(jt.ones((10, 4), dtype=jt.float32)) | |||
log_prob_normal = dis_torch.log_prob(torch.ones((10, 4), dtype=torch.float32)) | |||
test_acc(log_prob_jittor, log_prob_normal) | |||
# test sample() | |||
test_sample(dis_jittor, dis_torch) | |||
print(dis_jittor.sample().shape) | |||
print(dis_torch.sample().shape) | |||
# test entropy | |||
entropy_jittor = dis_jittor.entropy() | |||
entropy_normal = dis_torch.entropy() | |||
test_acc(entropy_jittor, entropy_normal) |