# Imports needed to run this snippet (TensorFlow Probability's internal
# helper modules).
import numpy as np
import tensorflow.compat.v2 as tf

from tensorflow_probability.python.internal import assert_util
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal import tensor_util
from tensorflow_probability.python.internal import tensorshape_util


def _parameter_control_dependencies(self, is_init):
    assertions = []

    # Check num_steps is a scalar that's at least 1.
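    # (TFP idiom: `is_init != is_ref(param)` means concrete parameters are
    # validated once at construction time, while ref-like parameters such as
    # `tf.Variable`s are re-validated on each call.)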
    if is_init != tensor_util.is_ref(self.num_steps):
      num_steps = tf.convert_to_tensor(self.num_steps)
      num_steps_ = tf.get_static_value(num_steps)
      if num_steps_ is not None:
        if np.ndim(num_steps_) != 0:
          raise ValueError(
              '`num_steps` must be a scalar but it has rank {}'.format(
                  np.ndim(num_steps_)))
        if num_steps_ < 1:
          raise ValueError('`num_steps` must be at least 1.')
      elif self.validate_args:
        message = '`num_steps` must be a scalar'
        assertions.append(
            assert_util.assert_rank_at_most(num_steps, 0, message=message))
        assertions.append(
            assert_util.assert_greater_equal(
                num_steps, 1,
                message='`num_steps` must be at least 1.'))

    # Check that the initial distribution has scalar events over the
    # integers.
    if is_init and not dtype_util.is_integer(self.initial_distribution.dtype):
      raise ValueError(
          '`initial_distribution.dtype` ({}) is not over integers'.format(
              dtype_util.name(self.initial_distribution.dtype)))

    if tensorshape_util.rank(self.initial_distribution.event_shape) is not None:
      if tensorshape_util.rank(self.initial_distribution.event_shape) != 0:
        raise ValueError('`initial_distribution` must have scalar `event_dim`s')
    elif self.validate_args:
      assertions += [
          assert_util.assert_equal(
              ps.size(self.initial_distribution.event_shape_tensor()),
              0,
              message='`initial_distribution` must have scalar `event_dim`s'),
      ]

    # Check that the transition distribution is over the integers.
    if (is_init and
        not dtype_util.is_integer(self.transition_distribution.dtype)):
      raise ValueError(
          '`transition_distribution.dtype` ({}) is not over integers'.format(
              dtype_util.name(self.transition_distribution.dtype)))

    # Check observations have non-scalar batches.
    # The graph version of this assertion is incorporated as
    # a control dependency of the transition/observation
    # compatibility test.
    if tensorshape_util.rank(self.observation_distribution.batch_shape) == 0:
      raise ValueError(
          "`observation_distribution` can't have scalar batches")

    # Check transitions have non-scalar batches.
    # The graph version of this assertion is incorporated as
    # a control dependency of the transition/observation
    # compatibility test.
    if tensorshape_util.rank(self.transition_distribution.batch_shape) == 0:
      raise ValueError(
          "`transition_distribution` can't have scalar batches")

    # Check compatibility of transition distribution and observation
    # distribution.
    tdbs = self.transition_distribution.batch_shape
    odbs = self.observation_distribution.batch_shape
    if (tensorshape_util.dims(tdbs) is not None and
        tf.compat.dimension_value(odbs[-1]) is not None):
      if (tf.compat.dimension_value(tdbs[-1]) !=
          tf.compat.dimension_value(odbs[-1])):
        raise ValueError(
            '`transition_distribution` and `observation_distribution` '
            'must agree on last dimension of batch size')
    elif self.validate_args:
      tdbs = self.transition_distribution.batch_shape_tensor()
      odbs = self.observation_distribution.batch_shape_tensor()
      transition_precondition = assert_util.assert_greater(
          ps.size(tdbs), 0,
          message=('`transition_distribution` can\'t have scalar '
                   'batches'))
      observation_precondition = assert_util.assert_greater(
          ps.size(odbs), 0,
          message=('`observation_distribution` can\'t have scalar '
                   'batches'))
      with tf.control_dependencies([
          transition_precondition,
          observation_precondition]):
        assertions += [
            assert_util.assert_equal(
                tdbs[-1],
                odbs[-1],
                message=('`transition_distribution` and '
                         '`observation_distribution` '
                         'must agree on last dimension of batch size'))]

    return assertions
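
A minimal sketch of how the assertions returned above are typically consumed. The class below is a hypothetical stub (the name `DemoDistribution` and the `log_prob` call site are assumptions, not TFP's actual wiring): graph-mode checks come back as a list of assert ops, and attaching them as control dependencies forces them to run before the payload computation.

import tensorflow.compat.v2 as tf


class DemoDistribution:
  """Hypothetical stub showing the control-dependency pattern."""

  def _parameter_control_dependencies(self, is_init):
    del is_init  # A real implementation returns assert ops, as above.
    return []

  def _log_prob(self, value):
    return tf.zeros_like(value)  # Stand-in for the real computation.

  def log_prob(self, value):
    # Attach the runtime checks so they execute before `_log_prob`.
    assertions = self._parameter_control_dependencies(is_init=False)
    with tf.control_dependencies(assertions):
      return self._log_prob(value)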
Example 2
# Imports needed to run this snippet (TFP-internal helpers again, with
# `prefer_static` used unaliased this time).
import numpy as np
import tensorflow.compat.v2 as tf

from tensorflow_probability.python.internal import assert_util
from tensorflow_probability.python.internal import prefer_static
from tensorflow_probability.python.internal import tensor_util
from tensorflow_probability.python.internal import tensorshape_util


# Method excerpt below keeps its original class-level indentation.
  def _parameter_control_dependencies(self, is_init):
    assertions = []

    axis = None
    paddings = None
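    # `axis`/`paddings` are converted to tensors lazily, at most once, and
    # only when a graph-mode assertion actually needs the concrete value.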

    if is_init != tensor_util.is_ref(self.axis):
      # First we check the shape of the axis argument.
      msg = 'Argument `axis` must be scalar or vector.'
      if tensorshape_util.rank(self.axis.shape) is not None:
        if tensorshape_util.rank(self.axis.shape) > 1:
          raise ValueError(msg)
      elif self.validate_args:
        if axis is None: axis = tf.convert_to_tensor(self.axis)
        assertions.append(assert_util.assert_rank_at_most(
            axis, 1, message=msg))
      # Next we check the values of the axis argument.
      axis_ = tf.get_static_value(self.axis)
      msg = 'Argument `axis` must be negative.'
      if axis_ is not None:
        if np.any(axis_ > -1):
          raise ValueError(msg)
      elif self.validate_args:
        if axis is None: axis = tf.convert_to_tensor(self.axis)
        assertions.append(assert_util.assert_less(axis, 0, message=msg))
      msg = 'Argument `axis` elements must be unique.'
      if axis_ is not None:
        if len(np.array(axis_).reshape(-1)) != len(np.unique(axis_)):
          raise ValueError(msg)
      elif self.validate_args:
        if axis is None: axis = tf.convert_to_tensor(self.axis)
        assertions.append(assert_util.assert_equal(
            prefer_static.size0(axis),
            prefer_static.size0(prefer_static.setdiff1d(axis)),
            message=msg))

    if is_init != tensor_util.is_ref(self.paddings):
      # First we check the shape of the paddings argument.
      msg = 'Argument `paddings` must be a vector of pairs.'
      if tensorshape_util.is_fully_defined(self.paddings.shape):
        shape = np.int32(self.paddings.shape)
        if len(shape) != 2 or shape[0] < 1 or shape[1] != 2:
          raise ValueError(msg)
      elif self.validate_args:
        if paddings is None: paddings = tf.convert_to_tensor(self.paddings)
        with tf.control_dependencies([
            assert_util.assert_equal(tf.rank(paddings), 2, message=msg)]):
          shape = tf.shape(paddings)
          assertions.extend([
              assert_util.assert_greater(shape[0], 0, message=msg),
              assert_util.assert_equal(shape[1], 2, message=msg),
          ])
      # Next we check the values of the paddings argument.
      paddings_ = tf.get_static_value(self.paddings)
      msg = 'Argument `paddings` must be non-negative.'
      if paddings_ is not None:
        if np.any(paddings_ < 0):
          raise ValueError(msg)
      elif self.validate_args:
        if paddings is None: paddings = tf.convert_to_tensor(self.paddings)
        assertions.append(assert_util.assert_greater(
            paddings, -1, message=msg))

    if is_init != (tensor_util.is_ref(self.axis) and
                   tensor_util.is_ref(self.paddings)):
      axis_ = tf.get_static_value(self.axis)
      if axis_ is None and axis is None:
        axis = tf.convert_to_tensor(self.axis)
      len_axis = prefer_static.size0(prefer_static.reshape(
          axis if axis_ is None else axis_, shape=-1))

      paddings_ = tf.get_static_value(self.paddings)
      if paddings_ is None and paddings is None:
        paddings = tf.convert_to_tensor(self.paddings)
      len_paddings = prefer_static.size0(
          paddings if paddings_ is None else paddings_)

      msg = ('Arguments `axis` and `paddings` must have the same number '
             'of elements.')
      if (prefer_static.is_numpy(len_axis) and
          prefer_static.is_numpy(len_paddings)):
        if len_axis != len_paddings:
          raise ValueError(msg + ' Saw: {}, {}.'.format(
              self.axis, self.paddings))
      elif self.validate_args:
        assertions.append(assert_util.assert_equal(
            len_axis, len_paddings, message=msg))

    return assertions
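
To make the static branches above concrete, here is a standalone NumPy re-statement of the `axis`/`paddings` checks. The helper name `check_axis_paddings` is hypothetical and for illustration only; it mirrors the error messages of the static paths, while the graph-mode paths express the same conditions as assert ops.

import numpy as np


def check_axis_paddings(axis, paddings):
  # `axis` must be a scalar or vector of unique, negative integers.
  axis = np.asarray(axis)
  if axis.ndim > 1:
    raise ValueError('Argument `axis` must be scalar or vector.')
  axis = axis.reshape(-1)
  if np.any(axis > -1):
    raise ValueError('Argument `axis` must be negative.')
  if len(axis) != len(np.unique(axis)):
    raise ValueError('Argument `axis` elements must be unique.')
  # `paddings` must be an [N, 2] matrix of non-negative pad amounts.
  paddings = np.asarray(paddings)
  if paddings.ndim != 2 or paddings.shape[0] < 1 or paddings.shape[1] != 2:
    raise ValueError('Argument `paddings` must be a vector of pairs.')
  if np.any(paddings < 0):
    raise ValueError('Argument `paddings` must be non-negative.')
  # Each padded axis needs exactly one (before, after) pair.
  if len(axis) != paddings.shape[0]:
    raise ValueError('Arguments `axis` and `paddings` must have the same '
                     'number of elements.')


check_axis_paddings(axis=-1, paddings=[[1, 1]])   # Passes.
# check_axis_paddings(axis=1, paddings=[[1, 1]])  # Raises: `axis` must be negative.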