# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers:
# Dun Liang <randonlang@gmail.com>.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
import math
import numpy as np
# import sys  # import the sys module
# sys.setrecursionlimit(30000)
def simple_presum(x):
    src = '''
__inline_static__
@python.jittor.auto_parallel(1)
void kernel(int n0, int i0, in0_type* x, in0_type* out, int nl) {
    out[i0*(nl+1)] = 0;
    for (int i=0; i<nl; i++)
        out[i0*(nl+1)+i+1] = out[i0*(nl+1)+i] + x[i0*nl+i];
}
kernel(in0->num/in0->shape[in0->shape.size()-1], 0, in0_p, out0_p, in0->shape[in0->shape.size()-1]);
'''
    return jt.code(x.shape[:-1]+(x.shape[-1]+1,), x.dtype, [x],
        cpu_src=src, cuda_src=src)
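
# Example usage (illustrative sketch, not part of the original file; the values
# below are made up): for an input of shape (..., n), simple_presum returns
# shape (..., n+1) holding the running sums with a leading zero.
#
#   x = jt.array([[0.1, 0.2, 0.3, 0.4]])
#   simple_presum(x)   # ~ [[0.0, 0.1, 0.3, 0.6, 1.0]]
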
class OneHotCategorical:
    def __init__(self, probs=None, logits=None):
        assert not (probs is None and logits is None)
        if probs is None:
            probs = jt.sigmoid(logits)
        with jt.no_grad():
            self.probs = probs / probs.sum(-1, True)
            self.cum_probs = simple_presum(self.probs)
            self.cum_probs_l = self.cum_probs[..., :-1]
            self.cum_probs_r = self.cum_probs[..., 1:]
        # store logits either way, so log_prob/entropy also work when
        # the distribution is constructed from logits
        if logits is None:
            self.logits = jt.log(self.probs)
        else:
            self.logits = logits

    def sample(self, sample_shape=[]):
        shape = sample_shape + self.probs.shape[:-1] + (1,)
        rand = jt.rand(shape)
        # the sampled category is the interval of the cumulative sum that rand falls into
        one_hot = jt.logical_and(self.cum_probs_l < rand, rand <= self.cum_probs_r).float()
        return one_hot

    def log_prob(self, value):
        # assumes 2D (batch, num_classes) one-hot values and logits
        logits = self.logits
        value = jt.broadcast(value, logits.shape)
        log_pmf = jt.broadcast(logits, logits.shape)
        indices = jt.argmax(value, dim=-1)[0]
        return log_pmf.gather(1, indices.unsqueeze(-1)).squeeze(-1)

    def entropy(self):
        p_log_p = self.logits * self.probs
        return -p_log_p.sum(-1)
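
# Example usage (illustrative sketch, not part of the original file; the
# probabilities below are made up):
#
#   probs = jt.array([[0.2, 0.3, 0.5]])
#   dist = OneHotCategorical(probs)
#   y = dist.sample()        # one-hot Var of shape (1, 3)
#   dist.log_prob(y)         # log-probability of the sampled category
#   dist.entropy()           # ~ 1.03 nats for these probabilities
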
class Categorical:
    def __init__(self, probs=None, logits=None):
        OneHotCategorical.__init__(self, probs, logits)

    def sample(self, sample_shape=[]):
        shape = sample_shape + self.probs.shape[:-1] + (1,)
        rand = jt.rand(shape)
        one_hot = jt.logical_and(self.cum_probs_l < rand, rand <= self.cum_probs_r)
        index = one_hot.index(one_hot.ndim-1)
        # collapse the one-hot sample to its class index
        return (one_hot * index).sum(-1)

    def log_prob(self, x):
        # assumes a single distribution stored with a leading batch dim of 1
        return jt.log(self.probs)[0, x]

    def entropy(self):
        # clamp to the most negative finite float32 so that zero-probability
        # classes contribute 0 instead of nan (-inf * 0)
        min_real = -(math.pow(2, 23) - 1) / math.pow(2, 22) * math.pow(2, 127)
        logits = jt.clamp(self.logits, min_v=min_real)
        p_log_p = logits * self.probs
        return -p_log_p.sum(-1)
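
# Example usage (illustrative sketch, not part of the original file):
#
#   c = Categorical(jt.array([[0.2, 0.3, 0.5]]))
#   idx = c.sample()     # class index drawn from {0, 1, 2}
#   c.log_prob(1)        # ~ log(0.3)
#   c.entropy()          # same value as the OneHotCategorical example above
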
class Normal:
    @property
    def mean(self):
        return self.loc

    @property
    def stddev(self):
        return self.scale

    @property
    def variance(self):
        return self.stddev.pow(2)

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale
        self.log_scale = jt.log(self.scale)

    def log_prob(self, x):
        return -((x - self.loc) ** 2) / (2 * self.variance) - self.log_scale - math.log(math.sqrt(2 * math.pi))

    def sample(self, sample_shape=[]):
        shape = sample_shape + self.loc.shape
        with jt.no_grad():
            eps = jt.randn(shape)
            return self.loc + self.scale * eps

    def entropy(self):
        return self.log_scale + 0.5 * math.log(2 * math.pi * math.e)
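
# Example usage (illustrative sketch, not part of the original file):
#
#   n = Normal(jt.array([0.0]), jt.array([1.0]))
#   s = n.sample()       # draws loc + scale * eps with eps ~ N(0, 1), under no_grad
#   n.log_prob(s)        # ~ -0.92 when s is 0, lower further from the mean
#   n.entropy()          # 0.5 * log(2 * pi * e) ~ 1.419 for unit scale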