def _parameter_control_dependencies(self, is_init):
  """Validates `num_steps` and the component distributions.

  Problems detectable from static values or static shapes raise
  `ValueError` immediately. Problems detectable only at runtime are
  returned as assertion ops, and only when `self.validate_args` is set.

  Args:
    is_init: Python `bool`; `True` when called at construction time.
      The `num_steps` checks are gated on
      `is_init != tensor_util.is_ref(self.num_steps)`. They therefore run at
      construction for a non-reference `num_steps` and on later calls for a
      reference-typed one. The dtype checks run only when `is_init` is True.

  Returns:
    assertions: Python `list` of runtime assertion ops (possibly empty).

  Raises:
    ValueError: if `num_steps` is statically non-scalar or less than 1; if
      `initial_distribution` or `transition_distribution` has a non-integer
      dtype; if `initial_distribution` has a statically non-scalar event
      shape; if `observation_distribution` or `transition_distribution` has
      a statically scalar batch shape; or if both distributions have a
      statically known last batch dimension and those dimensions differ.
  """
  assertions = []

  # Check num_steps is a scalar that's at least 1.
  if is_init != tensor_util.is_ref(self.num_steps):
    num_steps = tf.convert_to_tensor(self.num_steps)
    num_steps_ = tf.get_static_value(num_steps)
    if num_steps_ is not None:
      if np.ndim(num_steps_) != 0:
        raise ValueError(
            '`num_steps` must be a scalar but it has rank {}'.format(
                np.ndim(num_steps_)))
      if num_steps_ < 1:
        raise ValueError('`num_steps` must be at least 1.')
    elif self.validate_args:
      message = '`num_steps` must be a scalar'
      # Reuse the tensor converted above rather than converting again.
      assertions.append(
          assert_util.assert_rank_at_most(num_steps, 0, message=message))
      assertions.append(
          assert_util.assert_greater_equal(
              num_steps, 1,
              message='`num_steps` must be at least 1.'))

  # Check that the initial distribution has scalar events over the
  # integers.
  if is_init and not dtype_util.is_integer(self.initial_distribution.dtype):
    raise ValueError(
        '`initial_distribution.dtype` ({}) is not over integers'.format(
            dtype_util.name(self.initial_distribution.dtype)))

  if tensorshape_util.rank(self.initial_distribution.event_shape) is not None:
    if tensorshape_util.rank(self.initial_distribution.event_shape) != 0:
      raise ValueError('`initial_distribution` must have scalar `event_dim`s')
  elif self.validate_args:
    assertions += [
        assert_util.assert_equal(
            ps.size(self.initial_distribution.event_shape_tensor()),
            0,
            message='`initial_distribution` must have scalar `event_dim`s'),
    ]

  # Check that the transition distribution is over the integers.
  if (is_init and
      not dtype_util.is_integer(self.transition_distribution.dtype)):
    raise ValueError(
        '`transition_distribution.dtype` ({}) is not over integers'.format(
            dtype_util.name(self.transition_distribution.dtype)))

  # Check observations have non-scalar batches.
  # The graph version of this assertion is incorporated as
  # a control dependency of the transition/observation
  # compatibility test.
  if tensorshape_util.rank(self.observation_distribution.batch_shape) == 0:
    raise ValueError(
        "`observation_distribution` can't have scalar batches")

  # Check transitions have non-scalar batches.
  # The graph version of this assertion is incorporated as
  # a control dependency of the transition/observation
  # compatibility test.
  if tensorshape_util.rank(self.transition_distribution.batch_shape) == 0:
    raise ValueError(
        "`transition_distribution` can't have scalar batches")

  # Check compatibility of transition distribution and observation
  # distribution. The static comparison is only meaningful when BOTH last
  # batch dimensions are known. An unknown (`None`) dimension would compare
  # unequal to a known one and raise spuriously, so in that case fall
  # through to the runtime assertion instead.
  tdbs = self.transition_distribution.batch_shape
  odbs = self.observation_distribution.batch_shape
  # A known-rank `tdbs` is non-empty here (scalar batches raised above).
  tdbs_last = (tf.compat.dimension_value(tdbs[-1])
               if tensorshape_util.dims(tdbs) is not None else None)
  odbs_last = tf.compat.dimension_value(odbs[-1])
  if tdbs_last is not None and odbs_last is not None:
    if tdbs_last != odbs_last:
      raise ValueError(
          '`transition_distribution` and `observation_distribution` '
          'must agree on last dimension of batch size')
  elif self.validate_args:
    tdbs = self.transition_distribution.batch_shape_tensor()
    odbs = self.observation_distribution.batch_shape_tensor()
    transition_precondition = assert_util.assert_greater(
        ps.size(tdbs), 0,
        message=('`transition_distribution` can\'t have scalar '
                 'batches'))
    observation_precondition = assert_util.assert_greater(
        ps.size(odbs), 0,
        message=('`observation_distribution` can\'t have scalar '
                 'batches'))
    # The non-scalar-batch preconditions must hold before `[-1]` indexing
    # below is meaningful, hence the control dependency.
    with tf.control_dependencies([
        transition_precondition,
        observation_precondition]):
      assertions += [
          assert_util.assert_equal(
              tdbs[-1], odbs[-1],
              message=('`transition_distribution` and '
                       '`observation_distribution` '
                       'must agree on last dimension of batch size'))]

  return assertions
def _parameter_control_dependencies(self, is_init):
  """Validates `axis` and `paddings` and collects runtime assertions.

  Problems detectable from static values or static shapes raise
  `ValueError` immediately. Problems detectable only at runtime are
  returned as assertion ops, and only when `self.validate_args` is set.

  Args:
    is_init: Python `bool`; `True` when called at construction time.
      Single-parameter checks are gated on
      `is_init != tensor_util.is_ref(param)`. They run at construction for a
      non-reference parameter and on later calls for a reference-typed one.
      The joint `axis`/`paddings` length check runs at construction only when
      neither parameter is a reference. Otherwise it runs on later calls, so
      a mutated reference is re-validated.

  Returns:
    assertions: Python `list` of runtime assertion ops (possibly empty).

  Raises:
    ValueError: if, statically, `axis` has rank > 1, contains a non-negative
      or duplicated element; `paddings` is not shaped `[n, 2]` with
      `n >= 1`, or contains a negative element; or `axis` and `paddings`
      have a different number of elements.
  """
  assertions = []
  # Tensor versions of `self.axis` / `self.paddings`, created lazily and
  # shared across the checks below so each is converted at most once.
  axis = None
  paddings = None

  if is_init != tensor_util.is_ref(self.axis):
    # First we check the shape of the axis argument.
    msg = 'Argument `axis` must be scalar or vector.'
    if tensorshape_util.rank(self.axis.shape) is not None:
      if tensorshape_util.rank(self.axis.shape) > 1:
        raise ValueError(msg)
    elif self.validate_args:
      if axis is None:
        axis = tf.convert_to_tensor(self.axis)
      assertions.append(assert_util.assert_rank_at_most(
          axis, 1, message=msg))
    # Next we check the values of the axis argument.
    axis_ = tf.get_static_value(self.axis)
    msg = 'Argument `axis` must be negative.'
    if axis_ is not None:
      if np.any(np.asarray(axis_) > -1):
        raise ValueError(msg)
    elif self.validate_args:
      if axis is None:
        axis = tf.convert_to_tensor(self.axis)
      assertions.append(assert_util.assert_less(axis, 0, message=msg))
    msg = 'Argument `axis` elements must be unique.'
    if axis_ is not None:
      if len(np.array(axis_).reshape(-1)) != len(np.unique(axis_)):
        raise ValueError(msg)
    elif self.validate_args:
      if axis is None:
        axis = tf.convert_to_tensor(self.axis)
      # Elements are unique iff de-duplication does not shrink the vector.
      # `tf.unique` needs a 1-D input and `axis` may be a scalar, so
      # flatten first and compare sizes of the same flattened tensor.
      flat_axis = tf.reshape(axis, [-1])
      assertions.append(assert_util.assert_equal(
          prefer_static.size0(flat_axis),
          prefer_static.size0(tf.unique(flat_axis).y),
          message=msg))

  if is_init != tensor_util.is_ref(self.paddings):
    # First we check the shape of the paddings argument.
    msg = 'Argument `paddings` must be a vector of pairs.'
    if tensorshape_util.is_fully_defined(self.paddings.shape):
      shape = np.int32(self.paddings.shape)
      if len(shape) != 2 or shape[0] < 1 or shape[1] != 2:
        raise ValueError(msg)
    elif self.validate_args:
      if paddings is None:
        paddings = tf.convert_to_tensor(self.paddings)
      # Indexing `shape[0]`/`shape[1]` is only valid once rank 2 holds.
      with tf.control_dependencies([
          assert_util.assert_equal(tf.rank(paddings), 2, message=msg)]):
        shape = tf.shape(paddings)
        assertions.extend([
            assert_util.assert_greater(shape[0], 0, message=msg),
            assert_util.assert_equal(shape[1], 2, message=msg),
        ])
    # Next we check the values of the paddings argument.
    paddings_ = tf.get_static_value(self.paddings)
    msg = 'Argument `paddings` must be non-negative.'
    if paddings_ is not None:
      if np.any(np.asarray(paddings_) < 0):
        raise ValueError(msg)
    elif self.validate_args:
      if paddings is None:
        paddings = tf.convert_to_tensor(self.paddings)
      assertions.append(assert_util.assert_greater(
          paddings, -1, message=msg))

  # Joint check: defer to call time if EITHER parameter is a reference,
  # since a change to either one can break the length agreement.
  if is_init != (tensor_util.is_ref(self.axis) or
                 tensor_util.is_ref(self.paddings)):
    axis_ = tf.get_static_value(self.axis)
    if axis_ is None and axis is None:
      axis = tf.convert_to_tensor(self.axis)
    len_axis = prefer_static.size0(prefer_static.reshape(
        axis if axis_ is None else axis_, shape=-1))

    paddings_ = tf.get_static_value(self.paddings)
    if paddings_ is None and paddings is None:
      paddings = tf.convert_to_tensor(self.paddings)
    len_paddings = prefer_static.size0(
        paddings if paddings_ is None else paddings_)

    msg = ('Arguments `axis` and `paddings` must have the same number '
           'of elements.')
    if (prefer_static.is_numpy(len_axis) and
        prefer_static.is_numpy(len_paddings)):
      if len_axis != len_paddings:
        raise ValueError(msg + ' Saw: {}, {}.'.format(
            self.axis, self.paddings))
    elif self.validate_args:
      assertions.append(assert_util.assert_equal(
          len_axis, len_paddings, message=msg))

  return assertions