Browse Source

update

master
huaqimao 3 years ago
parent
commit
64e2fca998
5 changed files with 275 additions and 0 deletions
  1. BIN
      __pycache__/distributions.cpython-38.pyc
  2. +120
    -0
      distributions.py
  3. +52
    -0
      test_categorical.py
  4. +51
    -0
      test_normal.py
  5. +52
    -0
      test_onehot_categorical.py

BIN
__pycache__/distributions.cpython-38.pyc View File


+ 120
- 0
distributions.py View File

@ -0,0 +1,120 @@
# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers:
# Dun Liang <randonlang@gmail.com>.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
import math
import numpy as np
# import sys # 导入sys模块
# sys.setrecursionlimit(30000)
def simple_presum(x):
src = '''
__inline_static__
@python.jittor.auto_parallel(1)
void kernel(int n0, int i0, in0_type* x, in0_type* out, int nl) {
out[i0*(nl+1)] = 0;
for (int i=0; i<nl; i++)
out[i0*(nl+1)+i+1] = out[i0*(nl+1)+i] + x[i0*nl+i];
}
kernel(in0->num/in0->shape[in0->shape.size()-1], 0, in0_p, out0_p, in0->shape[in0->shape.size()-1]);
'''
return jt.code(x.shape[:-1]+(x.shape[-1]+1,), x.dtype, [x],
cpu_src=src, cuda_src=src)
class OneHotCategorical:
def __init__(self, probs=None, logits=None):
assert not (probs is None and logits is None)
if probs is None:
probs = jt.sigmoid(logits)
self.probs = probs
with jt.no_grad():
self.probs = probs / probs.sum(-1, True)
self.cum_probs = simple_presum(self.probs)
self.cum_probs_l = self.cum_probs[..., :-1]
self.cum_probs_r = self.cum_probs[..., 1:]
if logits is None:
self.logits = jt.log(self.probs)
# self._categorical = Categorical(probs, logits)
def sample(self, sample_shape=[]):
shape = sample_shape + self.probs.shape[:-1] + (1,)
rand = jt.rand(shape)
one_hot = jt.logical_and(self.cum_probs_l < rand, rand <= self.cum_probs_r).float()
return one_hot
def log_prob(self, value):
logits = self.logits
value = jt.broadcast(value, logits.shape)
# import ipdb
# ipdb.set_trace()
# value = jt.array(value)
log_pmf = jt.broadcast(logits, logits.shape)
# log_pmf = jt.array(log_pmf)
indices = jt.argmax(value, dim=-1)[0]
return log_pmf.gather(1, indices.unsqueeze(-1)).squeeze(-1)
def entropy(self):
p_log_p = self.logits * self.probs
return -p_log_p.sum(-1)
class Categorical:
def __init__(self, probs=None, logits=None):
OneHotCategorical.__init__(self, probs, logits)
def sample(self, sample_shape=[]):
shape = sample_shape + self.probs.shape[:-1] + (1,)
rand = jt.rand(shape)
one_hot = jt.logical_and(self.cum_probs_l < rand, rand <= self.cum_probs_r)
index = one_hot.index(one_hot.ndim-1)
return (one_hot * index).sum(-1)
def log_prob(self, x):
return jt.log(self.probs)[0,x]
def entropy(self):
min_real = -(math.pow(2,23)-1) / math.pow(2,22) * math.pow(2,127)
logits = jt.clamp(self.logits,min_v=min_real)
p_log_p = logits * self.probs
return -p_log_p.sum(-1)
class Normal:
@property
def mean(self):
return self.loc
@property
def stddev(self):
return self.scale
@property
def variance(self):
return self.stddev.pow(2)
def __init__(self, loc, scale):
self.loc = loc
self.scale = scale
self.log_scale = jt.log(self.scale)
def log_prob(self, x):
return -((x - self.loc) ** 2) / (2 * self.variance) - self.log_scale - math.log(math.sqrt(2 * math.pi))
def sample(self, sample_shape=[]):
shape = sample_shape + self.loc.shape
with jt.no_grad():
eps = jt.randn(shape)
return self.loc + self.scale * eps
def entropy(self):
return self.log_scale + 0.5 * math.log(2 * math.pi * math.e)

+ 52
- 0
test_categorical.py View File

@ -0,0 +1,52 @@
from torch.distributions import Categorical as Categorical_t
from jittor.distributions import Categorical as Categorical_j
import numpy as np
import torch
import jittor as jt
# dis_torch = Normal_t(torch.ones(10, 4).float(), torch.ones(10, 4).float())
dis_torch = Categorical_t(torch.ones((10, 4), dtype=torch.float32))
dis_jittor = Categorical_j(jt.ones((10, 4), dtype=jt.float32))
def test_acc(a, b):
# if isinstance(a, jt.Var):
a = np.array(a)
b = np.array(b)
if np.sum(a - b) < 1e-2:
print("pass")
else:
raise ValueError("not match")
def test_sample(a, b):
a_record, b_record = [], []
# sample 10000次,取均值
for i in range(10000):
a_record.append(np.array(a.sample()))
b_record.append(np.array(b.sample()))
a_mean = np.array(a_record).mean(0)
b_mean = np.array(b_record).mean(0)
if np.sum(a_mean - b_mean) < 1e-2:
print("pass")
else:
raise ValueError("not match")
if __name__ == '__main__':
# test log_prob
log_prob_jittor = dis_jittor.log_prob(jt.ones((10, 1), dtype=jt.float32))
log_prob_normal = dis_torch.log_prob(torch.ones((10, 1), dtype=torch.float32))
print(log_prob_jittor, log_prob_normal)
test_acc(log_prob_jittor, log_prob_normal)
# test sample()
test_sample(dis_jittor, dis_torch)
# test entropy
entropy_jittor = dis_jittor.entropy()
entropy_normal = dis_torch.entropy()
test_acc(entropy_jittor, entropy_normal)

+ 51
- 0
test_normal.py View File

@ -0,0 +1,51 @@
from torch.distributions import Normal as Normal_t
from distributions import Normal as Normal_j
import numpy as np
import torch
import jittor as jt
# dis_torch = Normal_t(torch.ones(10, 4).float(), torch.ones(10, 4).float())
dis_torch = Normal_t(torch.ones((10, 4), dtype=torch.float32), torch.ones((10, 4), dtype=torch.float32))
dis_jittor = Normal_j(jt.ones((10, 4), dtype=jt.float32), jt.ones((10, 4), dtype=jt.float32))
def test_acc(a, b):
# if isinstance(a, jt.Var):
a = np.array(a)
b = np.array(b)
if np.sum(a - b) < 1e-3:
print("pass")
else:
raise ValueError("not match")
def test_sample(a, b):
a_record, b_record = [], []
# sample 10000次,取均值
for i in range(10000):
a_record.append(np.array(a.sample((10,4))))
b_record.append(np.array(b.sample()))
a_mean = np.array(a_record).mean(0)
b_mean = np.array(b_record).mean(0)
if np.sum(a_mean - b_mean) < 1e-3:
print("pass")
else:
raise ValueError("not match")
if __name__ == '__main__':
# test log_prob
log_prob_jittor = dis_jittor.log_prob(jt.ones((10, 1), dtype=jt.float32))
log_prob_normal = dis_torch.log_prob(torch.ones((10, 1), dtype=torch.float32))
test_acc(log_prob_jittor, log_prob_normal)
# test sample()
test_sample(dis_jittor, dis_torch)
# test entropy
entropy_jittor = dis_jittor.entropy()
entropy_normal = dis_torch.entropy()
test_acc(entropy_jittor, entropy_normal)

+ 52
- 0
test_onehot_categorical.py View File

@ -0,0 +1,52 @@
from torch.distributions import OneHotCategorical as Categorical_t
from distributions import OneHotCategorical as Categorical_j
import numpy as np
import torch
import jittor as jt
# dis_torch = Normal_t(torch.ones(10, 4).float(), torch.ones(10, 4).float())
dis_torch = Categorical_t(torch.ones((10, 4), dtype=torch.float32))
dis_jittor = Categorical_j(jt.ones((10, 4), dtype=jt.float32))
def test_acc(a, b):
# if isinstance(a, jt.Var):
a = np.array(a)
b = np.array(b)
if np.sum(a - b) < 1e-3:
print("pass")
else:
raise ValueError("not match")
def test_sample(a, b):
a_record, b_record = [], []
# sample 10000次,取均值
for i in range(10000):
a_record.append(np.array(a.sample()))
b_record.append(np.array(b.sample()))
a_mean = np.array(a_record).mean(0)
b_mean = np.array(b_record).mean(0)
if np.sum(a_mean - b_mean) < 1e-2:
print("pass")
else:
raise ValueError("not match")
if __name__ == '__main__':
# test log_prob
log_prob_jittor = dis_jittor.log_prob(jt.ones((10, 4), dtype=jt.float32))
log_prob_normal = dis_torch.log_prob(torch.ones((10, 4), dtype=torch.float32))
test_acc(log_prob_jittor, log_prob_normal)
# test sample()
test_sample(dis_jittor, dis_torch)
print(dis_jittor.sample().shape)
print(dis_torch.sample().shape)
# test entropy
entropy_jittor = dis_jittor.entropy()
entropy_normal = dis_torch.entropy()
test_acc(entropy_jittor, entropy_normal)

Loading…
Cancel
Save