def __init__(self,
             bijector,
             distribution,
             seed=None,
             name="transformed_distribution"):
    """
    Build a distribution obtained by transforming `distribution` with `bijector`.

    Args:
        bijector: transformation to apply; must be an instance of
            `nn.probability.bijector.Bijector`.
        distribution: base distribution; must be an instance of `Distribution`.
            Its dtype is used as the dtype of this distribution.
        seed: forwarded unchanged to the base-class constructor.
        name (str): forwarded unchanged to the base-class constructor.
    """
    # Capture the call arguments first, before any other local is created,
    # so `param` contains exactly the constructor arguments.
    param = dict(locals())
    owner = type(self).__name__
    validator.check_value_type('bijector', bijector,
                               [nn.probability.bijector.Bijector], owner)
    validator.check_value_type('distribution', distribution,
                               [Distribution], owner)
    super(TransformedDistribution, self).__init__(
        seed, distribution.dtype, name, param)

    self._bijector = bijector
    self._distribution = distribution
    self._is_linear_transformation = bijector.is_constant_jacobian
    # Reuse the base distribution's parameter bookkeeping as this object's own.
    self.default_parameters = distribution.default_parameters
    self.parameter_names = distribution.parameter_names

    # Helpers and primitives used by the other methods of this class.
    self.exp = exp_generic
    self.log = log_generic
    self.isnan = P.IsNan()
    self.equal_base = P.Equal()
    self.select_base = P.Select()
    self.fill = P.Fill()

    if not hasattr(self.bijector, 'event_shape'):
        # Bijector declares no event shape: only the broadcast shape matters.
        self._batch_event = self.broadcast_shape
        return

    # Adding two zero tensors shaped like the bijector's event shape and the
    # broadcast (batch) shape yields their combined broadcast shape; this is
    # the broadcastability check (presumably raises if incompatible).
    zeros_event = self.fill(self.dtype, self.bijector.event_shape, 0.0)
    zeros_batch = self.fill(self.dtype, self.broadcast_shape, 0.0)
    self._batch_event = (zeros_event + zeros_batch).shape
def __init__(self,
             probs=None,
             seed=None,
             dtype=mstype.int32,
             name="Categorical"):
    """
    Constructor of Categorical.

    Args:
        probs: category probabilities, or None. When given, the last axis is
            treated as the category axis (see the batch-shape update below).
        seed: forwarded to the base-class constructor; `self.seed` is also
            used to seed the Multinomial sampling primitive.
        dtype: dtype of the samples; must be one of mstype's unsigned int,
            int or float types. Default: mstype.int32.
        name (str): name of the distribution. Default: "Categorical".
    """
    param = dict(locals())
    param['param_dict'] = {'probs': probs}
    valid_dtype = mstype.uint_type + mstype.int_type + mstype.float_type
    Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
    super(Categorical, self).__init__(seed, dtype, name, param)

    self._probs = self._add_parameter(probs, 'probs')
    if self.probs is not None:
        check_rank(self.probs)
        check_prob(self.probs)
        # Check the stored value returned by `_add_parameter`, like the two
        # checks above, rather than the raw argument, so every validation
        # sees the same converted value.
        check_sum_equal_one(self.probs)

        # update is_scalar_batch and broadcast_shape:
        # the last axis of `probs` indexes categories, so it is never part of
        # the batch shape and is always dropped from the broadcast shape.
        # The batch is scalar when no axis remains besides the category axis.
        if self.probs.shape[:-1] == ():
            self._is_scalar_batch = True
        self._broadcast_shape = self._broadcast_shape[:-1]

    # Helpers and primitives used by the other methods of this class.
    self.argmax = P.ArgMaxWithValue(axis=-1)
    self.broadcast = broadcast_to
    self.cast = P.Cast()
    self.clip_by_value = C.clip_by_value
    self.concat = P.Concat(-1)
    self.cumsum = P.CumSum()
    self.dtypeop = P.DType()
    self.exp = exp_generic
    self.expand_dim = P.ExpandDims()
    self.fill = P.Fill()
    self.gather = P.GatherNd()
    self.greater = P.Greater()
    self.issubclass = P.IsSubClass()
    self.less = P.Less()
    self.log = log_generic
    self.log_softmax = P.LogSoftmax()
    self.logicor = P.LogicalOr()
    self.logicand = P.LogicalAnd()
    self.multinomial = P.Multinomial(seed=self.seed)
    self.reshape = P.Reshape()
    self.reduce_sum = P.ReduceSum(keep_dims=True)
    self.select = P.Select()
    self.shape = P.Shape()
    self.softmax = P.Softmax()
    self.squeeze = P.Squeeze()
    self.squeeze_first_axis = P.Squeeze(0)
    self.squeeze_last_axis = P.Squeeze(-1)
    self.square = P.Square()
    self.transpose = P.Transpose()
    self.is_nan = P.IsNan()
    self.index_type = mstype.int32
    self.nan = np.nan
def __init__(self,
             bijector,
             distribution,
             seed=None,
             name="transformed_distribution"):
    """
    Build a distribution obtained by transforming `distribution` with `bijector`.

    Args:
        bijector: transformation to apply; must be an instance of
            `nn.probability.bijector.Bijector`.
        distribution: base distribution; must be an instance of `Distribution`
            whose dtype is one of mstype's float types.
        seed: forwarded unchanged to the base-class constructor.
        name (str): forwarded unchanged to the base-class constructor.
    """
    # Capture the call arguments first, before any other local is created,
    # so `param` contains exactly the constructor arguments.
    param = dict(locals())
    owner = type(self).__name__
    validator.check_value_type('bijector', bijector,
                               [nn.probability.bijector.Bijector], owner)
    validator.check_value_type('distribution', distribution,
                               [Distribution], owner)
    validator.check_type_name("dtype", distribution.dtype,
                              mstype.float_type, owner)
    super(TransformedDistribution, self).__init__(
        seed, distribution.dtype, name, param)
    self._bijector = bijector
    self._distribution = distribution

    # Derive this object's attributes from the base distribution and bijector.
    base = self.distribution
    bij = self.bijector
    self._is_linear_transformation = bij.is_constant_jacobian
    self._dtype = base.dtype
    # Scalar batch only when both the distribution and the bijector are.
    self._is_scalar_batch = base.is_scalar_batch and bij.is_scalar_batch
    self._batch_shape = base.batch_shape
    self.default_parameters = base.default_parameters
    self.parameter_names = base.parameter_names
    # Default parameter_type follows the base distribution.
    self.parameter_type = base.parameter_type

    # Helpers and primitives used by the other methods of this class.
    self.exp = exp_generic
    self.log = log_generic
    self.isnan = P.IsNan()
    self.cast_base = P.Cast()
    self.equal_base = P.Equal()
    self.select_base = P.Select()
    self.fill_base = P.Fill()

    # Combine the bijector's batch shape with the distribution's batch shape.
    self._broadcast_shape = self._broadcast_bijector_dist()
def __init__(self):
    """Initialise the cell and attach an IsNan primitive as `self.isnan`."""
    super(Netnan, self).__init__()
    # NaN-check primitive; presumably applied in this cell's construct().
    nan_check = P.IsNan()
    self.isnan = nan_check