Example #1
    def step(self, player, home_away_race, upgrades, available_act_mask, minimap):
        """Sample actions and compute logp(a|s)"""
        out = self.call(player, available_act_mask, minimap)

        # Gumbel-max sampling
        action_id = categorical_sample(out["action_id"], available_act_mask)

        tf.assert_greater(available_act_mask[:, action_id.numpy().item()], 0.0)

        # Fill out args based on sampled action type
        arg_spatial = []
        arg_nonspatial = []

        logp_a = log_prob(action_id, out["action_id"])

        for arg_type in self.action_spec.functions[action_id.numpy().item()].args:
            if arg_type.name in ["screen", "screen2", "minimap"]:
                location_id = categorical_sample(out["target_location"])
                arg_spatial.append(location_id)
                logp_a += log_prob(location_id, out["target_location"])
            else:
                # non-spatial args
                sample = categorical_sample(out[arg_type.name])
                arg_nonspatial.append(sample)
                logp_a += log_prob(sample, out[arg_type.name])
        # tf.debugging.check_numerics(logp_a, "Bad logp(a|s)")

        return (
            out["value"],
            action_id,
            arg_spatial,
            arg_nonspatial,
            logp_a,
        )
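The snippet above relies on project-specific helpers. Below is a minimal sketch of what `categorical_sample` and `log_prob` might look like, inferred only from the call sites (TF 2.x eager assumed; the real project may differ):

import tensorflow as tf

def categorical_sample(logits, mask=None):
    # Gumbel-max trick: add Gumbel noise to (optionally masked) logits and take the argmax.
    if mask is not None:
        logits = tf.where(mask > 0, logits, tf.fill(tf.shape(logits), logits.dtype.min))
    uniform = tf.random.uniform(tf.shape(logits), minval=1e-8, maxval=1.0)
    gumbel = -tf.math.log(-tf.math.log(uniform))
    return tf.argmax(logits + gumbel, axis=-1, output_type=tf.int32)

def log_prob(sample, logits):
    # log p(sample | logits) of a categorical distribution parameterized by `logits`.
    return -tf.nn.sparse_softmax_cross_entropy_with_logits(labels=sample, logits=logits)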
Example #2
 def get_encoded_inputs(self, *x_list, **kwargs):
     """Runs the reference and candidate images through the feature model phi.
 Returns:
   h_train: [B, N, D]
   h_unlabel: [B, P, D]
   h_test: [B, M, D]
 """
      ext_wts = kwargs.get('ext_wts')
      VAT = kwargs.get('VAT', False)
     config = self.config
      bsize = tf.shape(x_list[0])[0]
     num = [tf.shape(xx)[1] for xx in x_list]
     x_all = concat(x_list, 1)
     x_all = tf.reshape(
         x_all, [-1, config.height, config.width, config.num_channel])
     h_all = self.phi(x_all, ext_wts=ext_wts, VAT=VAT)
     tf.assert_greater(tf.reduce_mean(tf.abs(h_all)), 0.0)
     # h_all_p = self.phi(tf.random_normal(tf.shape(x_all)), ext_wts=ext_wts)
     # h_all = tf.Print(h_all, [tf.reduce_sum(h_all),tf.reduce_sum(h_all - h_all_p)], '\n-----------')
     h_all = tf.reshape(h_all, [bsize, sum(num), -1])
     h_list = tf.split(h_all, num, axis=1)
     return h_list
Example #3
    def __call__(self, in_batch):
        # `in_batch` should be a batch tensor

        if self.pool_size == 0:
            return in_batch

        if self.batch_size is None:
            tf.assert_greater(in_batch.shape[0], 0)
            self.batch_size = in_batch.shape[0]
            self.batch_ids = tf.range(in_batch.shape[0])

        replace_batch_bool = self.sampler.sample(self.batch_size)
        replace_batch_ids = self.batch_ids[replace_batch_bool]
        sample_size = tf.shape(replace_batch_ids)[0]
        if sample_size > 0:  # tensor_scatter_nd_update fails on an empty sample; TODO: replace with static impl for batch size > 1
            sample_pool_ids = tf.random.shuffle(self.pool_ids)[:sample_size]
            batch_buffer = tf.stop_gradient(in_batch[replace_batch_bool])
            pooled_buffer = tf.gather(self.pool, sample_pool_ids)
            new_batch = tf.tensor_scatter_nd_update(
                in_batch, tf.expand_dims(replace_batch_ids, 1), pooled_buffer)
            self.pool = tf.tensor_scatter_nd_update(
                self.pool, tf.expand_dims(sample_pool_ids, 1), batch_buffer)
        else:
            new_batch = in_batch
        return new_batch
Example #4
    def __init__(self,
                 point_cloud: PointCloud,
                 cell_sizes,
                 sample_mode='poisson',
                 name=None):
        #Initialize the attributes.
        self._aabb = point_cloud.get_AABB()
        self._point_clouds = [point_cloud]
        self._cell_sizes = []
        self._neighborhoods = []

        self._dimension = point_cloud._dimension
        self._batch_shape = point_cloud._batch_shape

        #Create the different sampling operations.
        cur_point_cloud = point_cloud
        for sample_iter, cur_cell_sizes in enumerate(cell_sizes):
            cur_cell_sizes = tf.convert_to_tensor(value=cur_cell_sizes,
                                                  dtype=tf.float32)

            # Check if the cell size is defined for all the dimensions.
            # If not, the last cell size value is tiled until all the dimensions
            # have a value.
            cur_num_dims = tf.gather(cur_cell_sizes.shape, 0)
            cur_cell_sizes = tf.cond(
                cur_num_dims < self._dimension, lambda: tf.concat(
                    (cur_cell_sizes,
                     tf.tile(
                         tf.gather(cur_cell_sizes, [
                             tf.rank(cur_cell_sizes) - 1
                         ]), [self._dimension - cur_num_dims])),
                    axis=0), lambda: cur_cell_sizes)
            tf.assert_greater(
                self._dimension + 1,
                cur_num_dims,
                f'Too many dimensions in cell sizes {cur_num_dims} ' + \
                f'instead of max. {self._dimension}')
            # old version, does not run in graph mode
            # if cur_num_dims < self._dimension:
            #   cur_cell_sizes = tf.concat((cur_cell_sizes,
            #                  tf.tile(tf.gather(cur_cell_sizes,
            #                                    [tf.rank(cur_cell_sizes) - 1]),
            #                    [self._dimension - cur_num_dims])),
            #                       axis=0)
            # if cur_num_dims > self._dimension:
            #   raise ValueError(
            #       f'Too many dimensions in cell sizes {cur_num_dims} ' + \
            #       f'instead of max. {self._dimension}')

            self._cell_sizes.append(cur_cell_sizes)

            #Create the sampling operation.
            cur_grid = Grid(cur_point_cloud, cur_cell_sizes, self._aabb)
            cur_neighborhood = Neighborhood(cur_grid, cur_cell_sizes)
            cur_point_cloud, _ = sample(cur_neighborhood, sample_mode)

            self._neighborhoods.append(cur_neighborhood)
            cur_point_cloud.set_batch_shape(self._batch_shape)
            self._point_clouds.append(cur_point_cloud)
Example #5
def get_gaussian_pdf_equal_split_points(k):
    tf.assert_greater(k, 2)
    lp = tf.cast(tf.linspace(0, 1, k + 1), dtype=tf.float32)
    pk = tfd.Normal(0.0, 1.0).quantile(lp[1:-1])
    qk_pre = (pk[:-1] + pk[1:]) / 2
    qk_first = 2 * pk[0] - qk_pre[0]
    qk_last = 2 * pk[-1] - qk_pre[-1]
    qk = tf.concat([[qk_first], qk_pre, [qk_last]], axis=0)
    return qk
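For intuition, a short sketch of the interior split points the function reflects around (assuming TF 2.x eager execution and that `tfd` above refers to `tensorflow_probability.distributions`; `k = 4` is an arbitrary choice):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

k = 4
probs = tf.linspace(0.0, 1.0, k + 1)[1:-1]   # 0.25, 0.5, 0.75
pk = tfd.Normal(0.0, 1.0).quantile(probs)    # interior split points of N(0, 1)
print(pk.numpy())                            # approx [-0.674, 0.0, 0.674]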
Example #6
def build_grid_ds_tf(sorted_keys, num_cells, batch_size, name=None):
    """ Method to build a fast access data structure for point clouds.

  Creates a 2D regular grid over the first two dimensions, saving the first and
  last point index belonging to each cell.

  Args:
    sorted_keys: An `int` `Tensor` of shape `[N]`, the sorted keys.
    num_cells: An `int` `Tensor` of shape `[D]`, the total number of cells
      per dimension.
    batch_size: An `int`.

  Returns:
    An `int` `Tensor` of shape `[batch_size, num_cells[0], num_cells[1], 2]`.

  """
    sorted_keys = tf.cast(tf.convert_to_tensor(value=sorted_keys), tf.int32)
    num_cells = tf.cast(tf.convert_to_tensor(value=num_cells), tf.int32)

    num_keys = tf.shape(sorted_keys)[0]
    num_cells_2D = batch_size * num_cells[0] * num_cells[1]
    tf.assert_greater(
        tf.shape(num_cells)[0], 1, 'Points must have dimensionality >1.')
    cells_per_2D_cell = tf.cond(
        tf.shape(num_cells)[0] > 2, lambda: tf.reduce_prod(num_cells[2:]),
        lambda: 1)
    # condition without graph mode
    # if tf.shape(num_cells)[0] > 2:
    #     cells_per_2D_cell = tf.reduce_prod(num_cells[2:])
    # elif tf.shape(num_cells)[0] == 2:
    #     cells_per_2D_cell = 1

    ds_indices = tf.cast(tf.floor(sorted_keys / cells_per_2D_cell),
                         dtype=tf.int32)
    indices = tf.range(0, num_keys, dtype=tf.int32)
    first_per_cell = tf.math.unsorted_segment_min(indices, ds_indices,
                                                  num_cells_2D)
    last_per_cell = tf.math.unsorted_segment_max(indices + 1, ds_indices,
                                                 num_cells_2D)

    empty_cells = first_per_cell < 0
    first_per_cell = tf.where(empty_cells, tf.zeros_like(first_per_cell),
                              first_per_cell)
    last_per_cell = tf.where(empty_cells, tf.zeros_like(last_per_cell),
                             last_per_cell)
    empty_cells = first_per_cell > num_keys
    first_per_cell = tf.where(empty_cells, tf.zeros_like(first_per_cell),
                              first_per_cell)
    last_per_cell = tf.where(empty_cells, tf.zeros_like(last_per_cell),
                             last_per_cell)

    return tf.stack([
        tf.reshape(first_per_cell, [batch_size, num_cells[0], num_cells[1]]),
        tf.reshape(last_per_cell, [batch_size, num_cells[0], num_cells[1]])
    ],
                    axis=3)
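A small usage sketch (assumptions: TF 2.x eager; the keys are already sorted and encoded so that integer division by the number of cells in dimensions 2 and above yields a flat batch/cell index, as the function expects):

import tensorflow as tf

num_cells = tf.constant([2, 2, 2])          # D = 3 dimensions, 2 cells each
batch_size = 1
sorted_keys = tf.constant([0, 1, 3, 3, 6])  # N = 5 points, keys sorted in ascending order

grid_ds = build_grid_ds_tf(sorted_keys, num_cells, batch_size)
print(grid_ds.shape)  # (1, 2, 2, 2): first and last point index per 2D cell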
Example #7
def look_at(eye, center, world_up):
    """Computes camera viewing matrices.

    Functionality mimes gluLookAt (third_party/GL/glu/include/GLU/glu.h).

    Args:
      eye: 2-D float32 tensor with shape [batch_size, 3] containing the XYZ world
          space position of the camera.
      center: 2-D float32 tensor with shape [batch_size, 3] containing a position
          along the center of the camera's gaze.
      world_up: 2-D float32 tensor with shape [batch_size, 3] specifying the
          world's up direction; the output camera will have no tilt with respect
          to this direction.

    Returns:
      A [batch_size, 4, 4] float tensor containing a right-handed camera
      extrinsics matrix that maps points from world space to points in eye space.
    """
    batch_size = center.shape[0].value
    vector_degeneracy_cutoff = 1e-6
    forward = center - eye
    forward_norm = tf.norm(forward, ord='euclidean', axis=1, keep_dims=True)
    tf.assert_greater(
        forward_norm,
        vector_degeneracy_cutoff,
        message='Camera matrix is degenerate because eye and center are close.'
    )
    forward = tf.divide(forward, forward_norm)

    to_side = tf.cross(forward, world_up)
    to_side_norm = tf.norm(to_side, ord='euclidean', axis=1, keep_dims=True)
    tf.assert_greater(
        to_side_norm,
        vector_degeneracy_cutoff,
        message='Camera matrix is degenerate because up and gaze are close or'
        'because up is degenerate.')
    to_side = tf.divide(to_side, to_side_norm)
    cam_up = tf.cross(to_side, forward)

    w_column = tf.constant(batch_size * [[0., 0., 0., 1.]],
                           dtype=tf.float32)  # [batch_size, 4]
    w_column = tf.reshape(w_column, [batch_size, 4, 1])
    view_rotation = tf.stack(
        [to_side, cam_up, -forward,
         tf.zeros_like(to_side, dtype=tf.float32)],
        axis=1)  # [batch_size, 4, 3] matrix
    view_rotation = tf.concat([view_rotation, w_column],
                              axis=2)  # [batch_size, 4, 4]

    identity_batch = tf.tile(tf.expand_dims(tf.eye(3), 0), [batch_size, 1, 1])
    view_translation = tf.concat([identity_batch, tf.expand_dims(-eye, 2)], 2)
    view_translation = tf.concat(
        [view_translation,
         tf.reshape(w_column, [batch_size, 1, 4])], 1)
    camera_matrices = tf.matmul(view_rotation, view_translation)
    return camera_matrices
Example #8
    def validate_flat_to_one_hot(labels, edges):
        # make sure they are of shape [b, h, w, c]
        tf.debugging.assert_rank(labels, 4, 'label')
        tf.debugging.assert_rank(edges, 4, 'edges')

        # make sure have convincing number of channels
        tf.debugging.assert_shapes([
            (edges, ('b', 'h', 'w', 2)),
            (labels, ('b', 'h', 'w', 'c')),
        ])
        label_channels = tf.shape(labels)[-1]
        tf.assert_greater(label_channels, 1)
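A quick sanity-check sketch, assuming TF 2.x eager execution and that the validator above is reachable as a module-level function (the shapes here are arbitrary):

import tensorflow as tf

labels = tf.zeros([2, 8, 8, 3])   # [b, h, w, c] with c > 1
edges = tf.zeros([2, 8, 8, 2])    # [b, h, w, 2], as asserted above
validate_flat_to_one_hot(labels, edges)  # passes silently when the shapes agree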
Example #9
    def __init__(
        self,
        low=0,
        high=1,
        validate_args=False,
        allow_nan_stats=True,
        dtype=tf.int32,
        float_dtype=tf.float64,
        name="UniformInteger",
    ):
        """Initialise a UniformInteger random variable on `[low, high)`.

        Args:
          low: Integer tensor, lower boundary of the output interval. Must have
           `low <= high`.
          high: Integer tensor, _inclusive_ upper boundary of the output
           interval.  Must have `low <= high`.
          validate_args: Python `bool`, default `False`. When `True` distribution
           parameters are checked for validity despite possibly degrading runtime
           performance. When `False` invalid inputs may silently render incorrect
           outputs.
          allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
           (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
           result is undefined. When `False`, an exception is raised if one or
           more of the statistic's batch members are undefined.
          dtype: the dtype of the output variates
          name: Python `str` name prefixed to Ops created by this class.

        Raises:
          InvalidArgument if `low >= high` and `validate_args=True`.
        """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            self._low = tf.cast(low, name="low", dtype=dtype)
            self._high = tf.cast(high, name="high", dtype=dtype)
            super(UniformInteger, self).__init__(
                dtype=dtype,
                reparameterization_type=reparameterization.
                FULLY_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                name=name,
            )
        self.float_dtype = float_dtype
        if validate_args is True:
            tf.assert_greater(self._high, self._low,
                              "Condition low < high failed")
Example #10
  def _training(self):
    """Perform multiple training iterations of both policy and value baseline.

    Training on the episodes collected in the memory. Reset the memory
    afterwards. Always returns a summary string.

    Returns:
      Summary tensor.
    """
    with tf.name_scope('training'):
      assert_full = tf.assert_equal(self._memory_index, self._config.update_every)
      with tf.control_dependencies([assert_full]):
        data = self._memory.data()
      (observ, action, old_mean, old_logstd, reward), length = data
      with tf.control_dependencies([tf.assert_greater(length, 0)]):
        length = tf.identity(length)
      observ = self._observ_filter.transform(observ)
      reward = self._reward_filter.transform(reward)
      update_summary = self._perform_update_steps(observ, action, old_mean, old_logstd, reward,
                                                  length)
      with tf.control_dependencies([update_summary]):
        penalty_summary = self._adjust_penalty(observ, old_mean, old_logstd, length)
      with tf.control_dependencies([penalty_summary]):
        clear_memory = tf.group(self._memory.clear(), self._memory_index.assign(0))
      with tf.control_dependencies([clear_memory]):
        weight_summary = utility.variable_summaries(tf.trainable_variables(),
                                                    self._config.weight_summaries)
        return tf.summary.merge([update_summary, penalty_summary, weight_summary])
Example #11
 def test_doesnt_raise_when_both_empty(self):
     with self.test_session():
         larry = tf.constant([])
         curly = tf.constant([])
         with tf.control_dependencies([tf.assert_greater(larry, curly)]):
             out = tf.identity(larry)
         out.eval()
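For reference, the same check without the TF 1.x session machinery (a sketch assuming TF 2.x eager execution):

import tensorflow as tf

larry = tf.constant([])
curly = tf.constant([])
tf.debugging.assert_greater(larry, curly)  # vacuously true: both tensors are empty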
Example #12
def pixel_position_aware_loss(y_true, y_pred, from_logits, gamma, ksize):
    y_true, y_pred, sample_weight = validate_input(y_true,
                                                   y_pred,
                                                   weight=None,
                                                   dtype='int32',
                                                   rank=4,
                                                   channel='sparse')

    y_true_1h = tf.one_hot(y_true[..., 0],
                           max(2, y_pred.shape[-1]),
                           dtype=y_pred.dtype)

    min_shape = tf.reduce_min(tf.shape(y_true)[1:3])
    assert_shape = tf.assert_greater(min_shape, ksize - 1)
    with tf.control_dependencies([assert_shape]):
        weight = 1 + gamma * tf.abs(
            tf.nn.avg_pool2d(y_true_1h, ksize=ksize, strides=1, padding='SAME')
            - y_true_1h)
        weight = tf.stop_gradient(weight)

    sample_weight = weight if sample_weight is None else sample_weight * weight

    wce = crossentropy(y_true, y_pred, sample_weight, from_logits)
    wiou = iou(y_true,
               y_pred,
               sample_weight,
               from_logits=from_logits,
               square=False,
               smooth=1.,
               dice=False)

    loss = wce + wiou

    return tf.reduce_mean(loss, axis=-1)
Example #13
def adaptive_pixel_intensity_loss(y_true, y_pred, from_logits):
    y_true, y_pred, sample_weight = validate_input(
        y_true, y_pred, weight=None, dtype='int32', rank=4, channel='sparse')

    y_true_1h = tf.one_hot(y_true[..., 0], max(2, y_pred.shape[-1]), dtype=y_pred.dtype)

    min_shape = tf.reduce_min(tf.shape(y_true)[1:3])
    assert_shape = tf.assert_greater(min_shape, 30)
    with tf.control_dependencies([assert_shape]):
        omega = sum([
            tf.abs(tf.nn.avg_pool2d(y_true_1h, ksize=k, strides=1, padding='SAME') - y_true_1h)
            for k in [3, 15, 31]
        ]) * y_true_1h * .5 + 1.
        omega = tf.math.divide_no_nan(omega, tf.reduce_mean(omega, axis=[1, 2], keepdims=True))
        omega = tf.stop_gradient(omega)

    weight = omega if sample_weight is None else omega * sample_weight

    # Skipped omega normalization from original paper
    ace = crossentropy(y_true, y_pred, weight, from_logits)
    aiou = iou(y_true, y_pred, weight, from_logits=from_logits, square=False, smooth=1., dice=False)
    amae = mae(y_true, y_pred, weight, from_logits=from_logits)

    loss = ace + aiou + amae

    return tf.reduce_mean(loss, axis=-1)
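A standalone sketch of the multi-scale boundary weighting computed above (assumptions: TF 2.x eager; random binary maps stand in for the one-hot labels):

import tensorflow as tf

y_true_1h = tf.cast(tf.random.uniform([1, 64, 64, 2]) > 0.5, tf.float32)
omega = sum(
    tf.abs(tf.nn.avg_pool2d(y_true_1h, ksize=k, strides=1, padding='SAME') - y_true_1h)
    for k in [3, 15, 31]
) * y_true_1h * 0.5 + 1.0
print(omega.shape)  # (1, 64, 64, 2); weights grow near class boundaries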
Example #14
def get_ecdf(
        sample: tf.Tensor,
        weights: Optional[tf.Tensor] = None) -> Tuple[tf.Tensor, tf.Tensor]:
    """
    Get empirical CDF from a weighted 1D sample
    """

    if weights is None:
        weights = tf.ones_like(sample)

    with tf.control_dependencies(
        [tf.assert_equal(tf.shape(sample), tf.shape(weights))]):
        i = tf.contrib.framework.argsort(sample, axis=0)

        x = _T(tf.batch_gather(_T(sample), _T(i)))
        w = _T(tf.batch_gather(_T(weights), _T(i)))

        w_cumsum = tf.cumsum(w, axis=0)

        smallest_wsum = tf.reduce_min(w_cumsum[-1])
        with tf.control_dependencies(
            [tf.assert_greater(smallest_wsum, tf.zeros_like(smallest_wsum))]):
            w_cumsum /= w_cumsum[-1]

    return x, w_cumsum
Example #15
def parse_tfexample(raw_data, features):
  """Read a single TF Example proto and return a subset of its features.
  Args:
    raw_data: A serialized tf.Example proto.
    features: A dictionary of features, mapping string feature names to a tuple
      (dtype, shape). This dictionary should be a subset of
      protein_features.FEATURES (or the dictionary itself for all features).
  Returns:
    A dictionary of features mapping feature names to features. Only the given
    features are returned, all other ones are filtered out.
  """
  feature_map = {
      k: tf.io.FixedLenSequenceFeature(shape=(), dtype=v[0], allow_missing=True)
      for k, v in features.items()
  }
  parsed_features = tf.io.parse_single_example(raw_data, feature_map)

  # Find out what is the number of sequences and the number of alignments.
  num_residues = tf.cast(parsed_features['seq_length'][0], dtype=tf.int32)

  # Reshape the tensors according to the sequence length and num alignments.
  for k, v in parsed_features.items():
    new_shape = shape(feature_name=k, num_residues=num_residues)
    # Make sure the feature we are reshaping is not empty.
    assert_non_empty = tf.assert_greater(
        tf.size(v), 0, name='assert_%s_non_empty' % k,
        message='The feature %s is not set in the tf.Example. Either do not '
        'request the feature or use a tf.Example that has the feature set.' % k)
    with tf.control_dependencies([assert_non_empty]):
      parsed_features[k] = tf.reshape(v, new_shape, name='reshape_%s' % k)

  return parsed_features
Example #16
def assert_positive_integer(value, dtype, name):
    """
    Check whether `value` is a positive scalar (or 0-D tensor).
    If `value` is an instance of a built-in type, it is checked
    directly. Otherwise, it is converted to a `dtype` tensor and checked.

    :param value: The value to be checked.
    :param dtype: The tensor dtype.
    :param name: The name of `value` used in error message.
    :return: The checked value.
    """
    sign_err_msg = name + " must be positive"
    if isinstance(value, (int, float)):
        if value <= 0:
            raise ValueError(sign_err_msg)
        return value
    else:
        try:
            tensor = tf.convert_to_tensor(value, dtype)
        except ValueError:
            raise TypeError(name + ' must be ' + str(dtype))
        _assert_rank_op = tf.assert_rank(tensor,
                                         0,
                                         message=name +
                                         " should be a scalar (0-D Tensor).")
        _assert_positive_op = tf.assert_greater(tensor,
                                                tf.constant(0, dtype),
                                                message=sign_err_msg)
        with tf.control_dependencies([_assert_rank_op, _assert_positive_op]):
            tensor = tf.identity(tensor)
        return tensor
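Usage sketch (TF 2.x eager assumed; the argument names are illustrative):

import tensorflow as tf

n = assert_positive_integer(3, tf.int32, "n_samples")               # returns 3 unchanged
m = assert_positive_integer(tf.constant(5), tf.int32, "n_samples")  # returns a checked 0-D tensor
# assert_positive_integer(-1, tf.int32, "n_samples") raises ValueError("n_samples must be positive")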
Example #17
    def _training(self):
        """Perform one training iterations of both policy and value baseline.

    Training on the episodes collected in the memory. Reset the memory
    afterwards. Always returns a summary string.

    Returns:
      Summary tensor.
    """
        with tf.name_scope('training'):
            assert_full = tf.assert_equal(self._memory_index,
                                          self._config.update_every)
            with tf.control_dependencies([assert_full]):
                data = self._memory.data()
            (observ, option, action, reward, done, nextob,
             option_terminated), length = data
            with tf.control_dependencies([tf.assert_greater(length, 0)]):
                length = tf.identity(length)

            network_summary = self._update_network(
                observ, option, action, reward, tf.cast(done, tf.bool), nextob,
                tf.cast(option_terminated, tf.bool), length)

            with tf.control_dependencies([network_summary]):
                clear_memory = tf.group(self._memory.clear(),
                                        self._memory_index.assign(0))
            with tf.control_dependencies([clear_memory]):
                weight_summary = utility.variable_summaries(
                    tf.trainable_variables(), self._config.weight_summaries)
                return tf.summary.merge([network_summary, weight_summary])
Example #18
 def _inverse_event_shape_tensor(self, output_shape):
   if self.validate_args:
     is_greater_one = tf.assert_greater(
         output_shape[-1], 1, message="Need last dimension greater than 1.")
     output_shape = control_flow_ops.with_dependencies(
         [is_greater_one], output_shape)
   return (output_shape[-1])[..., tf.newaxis]
Example #19
 def output(self) -> tf.Tensor:
     # Pad the sequence with a large negative value, but make sure it has
     # non-zero length.
     length = tf.reduce_sum(self._input_mask)
     with tf.control_dependencies([tf.assert_greater(length, 0.5)]):
         padded_input = self._masked_input + 1e-15 * (1 - self._input_mask)
     return tf.reduce_max(padded_input, axis=1)
Example #20
def gumbel_softmax(inputs: tf.Tensor,
                   temperature: float = 1.0,
                   symmetric: bool = False,
                   axis: int = -1,
                   scope: Optional[str] = None) -> tf.Tensor:
    if inputs.shape[axis].value <= 2:
        raise ValueError(
            'logits must have at least size 3 on axis={}'.format(axis))

    temperature = tf.convert_to_tensor(temperature, dtype=inputs.dtype)
    with tf.variable_scope(scope, 'gumbel_softmax', [inputs]):
        temperature = tf.assert_scalar(temperature)
        assert_op = tf.assert_greater(temperature, 0.0)

        with tf.control_dependencies([assert_op]):
            gumbel = -tf.log(
                -tf.log(tf.random_uniform(inputs.shape, dtype=inputs.dtype)))
            if symmetric:
                gumbel = (gumbel + tf.matrix_transpose(gumbel)) / 2.0
            gumbel_logits = gumbel + inputs
            gumbel_softmax = tf.exp(
                tf.nn.log_softmax(tf.div(gumbel_logits, temperature),
                                  axis=axis))

    return gumbel_softmax
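For comparison, a TF 2.x sketch of the same relaxation using TensorFlow Probability instead of hand-rolled Gumbel noise (an alternative, not the code above, which targets the TF 1.x API):

import tensorflow as tf
import tensorflow_probability as tfp

logits = tf.math.log([[0.2, 0.3, 0.5]])
relaxed = tfp.distributions.RelaxedOneHotCategorical(temperature=0.5, logits=logits)
sample = relaxed.sample()   # differentiable, sums to 1 along the last axis
print(sample.numpy())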
Example #21
 def split_block(self, parents, sizes, is_left, depth, ys, orders,
                 prev_counter, prev_end, prev_node_id, prev_new_parent,
                 prev_left_size, prev_right_size, prev_left_mask, prev_right_mask):
     assert_op = tf.assert_greater(
         tf.shape(sizes), prev_counter, message="size too small")
     with tf.control_dependencies([assert_op]):
         parent, size = parents[prev_counter], sizes[prev_counter]
         node_id = self.new_node(prev_node_id, parent, is_left)
         start = prev_end
         new_parent, left_size, right_size, left_mask, right_mask = tf.cond(
             tf.logical_or(size <= self.min_sample_leaf,
                           depth >= self.max_depth),
             partial(self.add_leaf, ys, node_id, start, size,
                     prev_left_mask, prev_right_mask),
             partial(self.try_split, ys, orders, node_id, start, size,
                     prev_left_mask, prev_right_mask))
         assert_op_a = tf.assert_greater_equal(size, left_size)
         assert_op_b = tf.assert_greater_equal(size, right_size)
         with tf.control_dependencies([assert_op_a, assert_op_b]):
             return (
                 prev_counter+1,
                 start+size,
                 node_id,
                 tf.concat([prev_new_parent, [new_parent]], axis=0),
                 tf.concat([prev_left_size, [left_size]], axis=0),
                 tf.concat([prev_right_size, [right_size]], axis=0),
                 left_mask,
                 right_mask
             )
Example #22
  def _training(self):
    """Perform multiple training iterations of both policy and value baseline.

    Training on the episodes collected in the memory. Reset the memory
    afterwards. Always returns a summary string.

    Returns:
      Summary tensor.
    """
    with tf.name_scope('training'):
      assert_full = tf.assert_equal(
          self._memory_index, self._config.update_every)
      with tf.control_dependencies([assert_full]):
        data = self._memory.data()
      (observ, action, old_mean, old_logstd, reward), length = data
      with tf.control_dependencies([tf.assert_greater(length, 0)]):
        length = tf.identity(length)
      observ = self._observ_filter.transform(observ)
      reward = self._reward_filter.transform(reward)
      policy_summary = self._update_policy(
          observ, action, old_mean, old_logstd, reward, length)
      with tf.control_dependencies([policy_summary]):
        value_summary = self._update_value(observ, reward, length)
      with tf.control_dependencies([value_summary]):
        penalty_summary = self._adjust_penalty(
            observ, old_mean, old_logstd, length)
      with tf.control_dependencies([penalty_summary]):
        clear_memory = tf.group(
            self._memory.clear(), self._memory_index.assign(0))
      with tf.control_dependencies([clear_memory]):
        weight_summary = utility.variable_summaries(
            tf.trainable_variables(), self._config.weight_summaries)
        return tf.summary.merge([
            policy_summary, value_summary, penalty_summary, weight_summary])
Example #23
def assert_positive_int32_scalar(value, name):
    """
    Check whether `value` is a positive integer (or 0-D `tf.int32` tensor).
    If `value` is an instance of a built-in type, it is checked directly.
    Otherwise, it is converted to a `tf.int32` tensor and checked.
    :param value: The value to be checked.
    :param name: The name of `value` used in error message.
    :return: The checked value.
    """
    if isinstance(value, (int, float)):
        if isinstance(value, int) and value > 0:
            return value
        elif isinstance(value, float):
            raise TypeError(name + " must be integer")
        elif value <= 0:
            raise ValueError(name + " must be positive")
    else:
        try:
            tensor = tf.convert_to_tensor(value, tf.int32)
        except (TypeError, ValueError):
            raise TypeError(name + ' must be (convertible to) tf.int32')
        _assert_rank_op = tf.assert_rank(
            tensor, 0,
            message=name + " should be a scalar (0-D Tensor).")
        _assert_positive_op = tf.assert_greater(
            tensor, tf.constant(0, tf.int32),
            message=name + " must be positive")
        with tf.control_dependencies([_assert_rank_op,
                                      _assert_positive_op]):
            tensor = tf.identity(tensor)
        return tensor
Example #24
def _add_weighted_loss_to_collection(losses, weights):
    """Weights `losses` by weights, and adds the weighted losses, normalized by
    the number of joints present, to `tf.GraphKeys.LOSSES`.

    Specifically, the losses are summed across all dimensions (x, y,
    num_joints), producing a scalar loss per batch. That scalar loss then needs
    to be normalized by the number of joints present. This is equivalent to
    sum(weights[:, 0, 0, :]), since `weights` is a [image_dim, image_dim] map
    of either all 1s or all 0s, depending on whether a joint is present or
    not, respectively.

    Args:
        losses: Element-wise losses as calculated by your favourite function.
        weights: Element-wise weights.
    """
    losses = tf.transpose(a=losses, perm=[1, 2, 0, 3])
    weighted_loss = tf.multiply(losses, weights)
    per_batch_loss = tf.reduce_sum(input_tensor=weighted_loss, axis=[0, 1, 3])

    num_joints_present = tf.reduce_sum(input_tensor=weights, axis=1)

    assert_safe_div = tf.assert_greater(num_joints_present, 0.0)
    with tf.control_dependencies(control_inputs=[assert_safe_div]):
        per_batch_loss /= num_joints_present

    total_loss = tf.reduce_mean(input_tensor=per_batch_loss)
    tf.add_to_collection(name=tf.GraphKeys.LOSSES, value=total_loss)
Example #25
 def output(self) -> tf.Tensor:
     # Pad the sequence with a large negative value, but make sure it has
     # non-zero length.
     length = tf.reduce_sum(self._input_mask)
     with tf.control_dependencies([tf.assert_greater(length, 0.5)]):
         padded_input = self._masked_input + 1e-15 * (1 - self._input_mask)
     return tf.reduce_max(padded_input, axis=1)
Example #26
def _crop(image, offset_height, offset_width, crop_height, crop_width):
    origin_shape = tf.shape(image)
    rank_assertion = tf.Assert(tf.equal(tf.rank(image), 3), ['Rank of image must be equal to 3.'])
    with tf.control_dependencies([rank_assertion]):
        cropped_shape = tf.stack([crop_height, crop_width, origin_shape[2]])

    size_assertion = tf.Assert(
        tf.logical_and(
            tf.greater_equal(origin_shape[0], crop_height),
            tf.greater_equal(origin_shape[1], crop_width)),
        ['Crop size greater than the image size.'])

    offsets = tf.to_int32(tf.stack([offset_height, offset_width, 0]))

    with tf.control_dependencies([size_assertion]):
        image = tf.slice(image, offsets, cropped_shape)
    return tf.reshape(image, cropped_shape)
Example #27
 def test_doesnt_raise_when_greater_and_broadcastable_shapes(self):
     with self.test_session():
         small = tf.constant([1], name="small")
         big = tf.constant([3, 2], name="big")
         with tf.control_dependencies([tf.assert_greater(big, small)]):
             out = tf.identity(small)
         out.eval()
Example #28
 def test_doesnt_raise_when_both_empty(self):
     with self.test_session():
         larry = tf.constant([])
         curly = tf.constant([])
         with tf.control_dependencies([tf.assert_greater(larry, curly)]):
             out = tf.identity(larry)
         out.eval()
Example #29
 def test_doesnt_raise_when_greater_and_broadcastable_shapes(self):
     with self.test_session():
         small = tf.constant([1], name="small")
         big = tf.constant([3, 2], name="big")
         with tf.control_dependencies([tf.assert_greater(big, small)]):
             out = tf.identity(small)
         out.eval()
Example #30
 def test_raises_when_equal(self):
     with self.test_session():
         small = tf.constant([1, 2], name="small")
         with tf.control_dependencies([tf.assert_greater(small, small, message="fail")]):
             out = tf.identity(small)
         with self.assertRaisesOpError("fail.*small.*small"):
             out.eval()
Example #31
    def parse_tfexample(raw_data, features):
        feature_map = {
            k: tf.io.FixedLenSequenceFeature(shape=(),
                                             dtype=eval(f'tf.{v[0]}'),
                                             allow_missing=True)
            for k, v in features.items()
        }
        parsed_features = tf.io.parse_single_example(raw_data, feature_map)
        num_residues = tf.cast(parsed_features['seq_length'][0],
                               dtype=tf.int32)

        for k, v in parsed_features.items():
            new_shape = [
                num_residues if s is None else s for s in FEATURES[k][1]
            ]

            assert_non_empty = tf.assert_greater(
                tf.size(v),
                0,
                name=f'assert_{k}_non_empty',
                message=
                f'The feature {k} is not set in the tf.Example. Either do not '
                'request the feature or use a tf.Example that has the feature set.'
            )
            with tf.control_dependencies([assert_non_empty]):
                parsed_features[k] = tf.reshape(v,
                                                new_shape,
                                                name=f'reshape_{k}')

        return parsed_features
Example #32
def fast_guided_filter(lr_x, lr_y, hr_x, r, eps=1e-8, nhwc=False):
    assert lr_x.shape.ndims == 4 and \
           lr_y.shape.ndims == 4 and \
           hr_x.shape.ndims == 4
    
    # data format
    if nhwc:
        lr_x = tf.transpose(lr_x, [0, 3, 1, 2])
        lr_y = tf.transpose(lr_y, [0, 3, 1, 2])
        hr_x = tf.transpose(hr_x, [0, 3, 1, 2])

    # shape check
    lr_x_shape = tf.shape(lr_x)
    lr_y_shape = tf.shape(lr_y)
    hr_x_shape = tf.shape(hr_x)

    assertions = [tf.assert_equal(lr_x_shape[0], lr_y_shape[0]),
                  tf.assert_equal(lr_x_shape[0], hr_x_shape[0]),
                  tf.assert_equal(lr_x_shape[1], hr_x_shape[1]),
                  tf.assert_equal(lr_x_shape[2:], lr_y_shape[2:]),
                  tf.assert_greater(lr_x_shape[2:], 2 * r + 1),
                  tf.Assert(tf.logical_or(tf.equal(lr_x_shape[1], 1),
                                          tf.equal(lr_x_shape[1], lr_y_shape[1])),
                            [lr_x_shape, lr_y_shape])]

    with tf.control_dependencies(assertions):
        lr_x = tf.identity(lr_x)

    # N
    N = box_filter(tf.ones((1, lr_x_shape[1], lr_x_shape[2],
                            lr_x_shape[3]), dtype=lr_x.dtype), r)

    # mean_x
    mean_x = box_filter(lr_x, r) / N
    # mean_y
    mean_y = box_filter(lr_y, r) / N
    # cov_xy
    cov_xy = box_filter(lr_x * lr_y, r) / N - mean_x * mean_y
    # var_x
    var_x  = box_filter(lr_x * lr_x, r) / N - mean_x * mean_x

    # A
    A = cov_xy / (var_x + eps)
    # b
    b = mean_y - A * mean_x

    # mean_A; mean_b
    A    = tf.transpose(A,    [0, 2, 3, 1])
    b    = tf.transpose(b,    [0, 2, 3, 1])
    hr_x = tf.transpose(hr_x, [0, 2, 3, 1])

    mean_A = tf.image.resize_images(A, hr_x_shape[2:])
    mean_b = tf.image.resize_images(b, hr_x_shape[2:])

    output = mean_A * hr_x + mean_b
    if not nhwc:
        output = tf.transpose(output, [0, 3, 1, 2])

    return output
Example #33
 def test_raises_when_greater_but_non_broadcastable_shapes(self):
     with self.test_session():
         small = tf.constant([1, 1, 1], name="small")
         big = tf.constant([3, 2], name="big")
         with self.assertRaisesRegexp(ValueError, "must be"):
             with tf.control_dependencies([tf.assert_greater(big, small)]):
                 out = tf.identity(small)
             out.eval()
Example #34
 def _inverse_event_shape_tensor(self, output_shape):
   if self.validate_args:
     # It is not possible for a negative shape so we need only check <= 1.
     is_greater_one = tf.assert_greater(
         output_shape[-1], 1, message="Need last dimension greater than 1.")
     output_shape = control_flow_ops.with_dependencies(
         [is_greater_one], output_shape)
   return tf.concat([output_shape[:-1], [output_shape[-1] - 1]], axis=0)
Example #35
def ssd_loss(loc_pre: tf.Tensor,
             cls_pre,
             loc_gt,
             cls_gt,
             alpha=1,
             neg_radio=3.):
    """
    loc_pre:(batch, anchor_num, 4)
    cls_pre:(batch, anchor_num, num_classes)
    neg_radio:
    n_class:
    loc_gt: (batch, anchor_num, 4)
    cls_gt:(batch,anchor,num_classes) onehot
    :return:
    """
    # print(loc_pre.shape,cls_pre.shape)
    # print(loc_gt.shape,cls_gt.shape)
    # Since the indicator x_ij exists, the localization loss is computed only over positive samples
    glabel = tf.argmax(cls_gt, axis=-1)  # (batch,anchor)
    pos_mask = glabel > 0  # select positive samples (batch, anchor)
    pos_idx = tf.cast(pos_mask, tf.float32)  #
    n_pos = tf.reduce_sum(pos_idx)
    loc_pre_pos = tf.boolean_mask(loc_pre, pos_mask)  # REW: broadcasting works; the last dimension can be ignored when broadcasting
    loc_gt_pos = tf.boolean_mask(loc_gt,
                                 pos_mask)  # reduces rank (flattens), returns shape (None, 4)
    with tf.name_scope("localization"):
        loc_loss = my_smooth_l1_loss(loc_pre_pos, loc_gt_pos, n_pos, alpha)
        tf.losses.add_loss(loc_loss)
    logits = tf.stop_gradient(cls_pre)  # REW: only used for negative sample selection, so no gradients are computed
    labels = tf.stop_gradient(cls_gt)
    neg_mask = hard_negtives(labels, logits, pos_idx, neg_radio)
    # FIXME: obtain pos and neg separately
    conf_p = tf.boolean_mask(cls_pre, tf.logical_or(pos_mask, neg_mask))
    target = tf.boolean_mask(cls_gt, tf.logical_or(pos_mask, neg_mask))
    # The lines below are also correct:
    # cls_pre_pos = tf.boolean_mask(cls_pre,pos_mask)  # should broadcast
    # (wrong) cls_pre_neg = cls_pre*neg_idx  # cannot broadcast
    # cls_gt_pos = tf.boolean_mask(cls_gt,pos_mask)
    # FIXME: use softmax cross-entropy separately
    with tf.name_scope("cross_entropy"):
        # cross-entropy reduces the tensor by one dimension
        cls_loss = tf.nn.softmax_cross_entropy_with_logits_v2(target, conf_p)
        tf.assert_greater(n_pos, tf.cast(0., tf.float32))
        cls_loss = tf.div(tf.reduce_sum(cls_loss), n_pos, name="conf")
        tf.losses.add_loss(cls_loss)
    return cls_loss, loc_loss
Example #36
 def test_raises_when_less(self):
     with self.test_session():
         small = tf.constant([1, 2], name="small")
         big = tf.constant([3, 4], name="big")
         with tf.control_dependencies([tf.assert_greater(small, big)]):
             out = tf.identity(big)
         with self.assertRaisesOpError("small.*big"):
             out.eval()
Example #37
 def test_raises_when_greater_but_non_broadcastable_shapes(self):
     with self.test_session():
         small = tf.constant([1, 1, 1], name="small")
         big = tf.constant([3, 2], name="big")
         with self.assertRaisesRegexp(ValueError, "must be"):
             with tf.control_dependencies([tf.assert_greater(big, small)]):
                 out = tf.identity(small)
             out.eval()
Example #38
 def _inverse_event_shape_tensor(self, output_shape):
   if self.validate_args:
     # It is not possible for a negative shape so we need only check <= 1.
     is_greater_one = tf.assert_greater(
         output_shape[-1], 1, message="Need last dimension greater than 1.")
     output_shape = control_flow_ops.with_dependencies(
         [is_greater_one], output_shape)
   return tf.concat([output_shape[:-1], [output_shape[-1] - 1]], axis=0)
Example #39
 def test_raises_when_less(self):
     with self.test_session():
         small = tf.constant([1, 2], name="small")
         big = tf.constant([3, 4], name="big")
         with tf.control_dependencies([tf.assert_greater(small, big)]):
             out = tf.identity(big)
         with self.assertRaisesOpError("small.*big"):
             out.eval()
Example #40
 def mixture_kl():
     with tf.control_dependencies([tf.assert_greater(consistency_trust, 0.0),
                                   tf.assert_less(consistency_trust, 1.0)]):
         uniform = tf.constant(1 / num_classes, shape=[num_classes])
         mixed_softmax1 = consistency_trust * softmax1 + (1 - consistency_trust) * uniform
         mixed_softmax2 = consistency_trust * softmax2 + (1 - consistency_trust) * uniform
         costs = tf.reduce_sum(mixed_softmax2 * tf.log(mixed_softmax2 / mixed_softmax1), axis=1)
         costs = costs * kl_cost_multiplier
         return costs
Example #41
 def sample_from_logits(logits):
   with tf.control_dependencies([tf.assert_greater(temperature, 0.0)]):
     logits = tf.identity(logits)
   reshaped_logits = (
       tf.reshape(logits, [-1, tf.shape(logits)[-1]]) / temperature)
   choices = tf.multinomial(reshaped_logits, 1)
   choices = tf.reshape(choices,
                        tf.shape(logits)[:logits.get_shape().ndims - 1])
   return choices
Example #42
  def _maybe_assert_valid_y(self, y):
    if not self.validate_args:
      return y
    is_valid = [
        tf.assert_greater(
            y,
            tf.cast(-1., dtype=y.dtype.base_dtype),
            message="Inverse transformation input must be greater than -1."),
        tf.assert_less(
            y,
            tf.cast(1., dtype=y.dtype.base_dtype),
            message="Inverse transformation input must be less than 1.")
    ]

    return control_flow_ops.with_dependencies(is_valid, y)
Example #43
  def _training(self):
    """Perform multiple training iterations of both policy and value baseline.

    Training on the episodes collected in the memory. Reset the memory
    afterwards. Always returns a summary string.

    Returns:
      Summary tensor.
    """
    with tf.device('/gpu:0' if self._use_gpu else '/cpu:0'):
      with tf.name_scope('training'):
        assert_full = tf.assert_equal(
            self._num_finished_episodes, self._config.update_every)
        with tf.control_dependencies([assert_full]):
          data = self._finished_episodes.data()
        (observ, action, old_policy_params, reward), length = data
        # We set padding frames of the parameters to ones to prevent Gaussians
        # with zero variance. This would result in an infinite KL divergence,
        # which, even if masked out, would result in NaN gradients.
        old_policy_params = tools.nested.map(
            lambda param: self._mask(param, length, 1), old_policy_params)
        with tf.control_dependencies([tf.assert_greater(length, 0)]):
          length = tf.identity(length)
        observ = self._observ_filter.transform(observ)
        reward = self._reward_filter.transform(reward)
        update_summary = self._perform_update_steps(
            observ, action, old_policy_params, reward, length)
        with tf.control_dependencies([update_summary]):
          penalty_summary = self._adjust_penalty(
              observ, old_policy_params, length)
        with tf.control_dependencies([penalty_summary]):
          clear_memory = tf.group(
              self._finished_episodes.clear(),
              self._num_finished_episodes.assign(0))
        with tf.control_dependencies([clear_memory]):
          weight_summary = utility.variable_summaries(
              tf.trainable_variables(), self._config.weight_summaries)
          return tf.summary.merge([
              update_summary, penalty_summary, weight_summary])
Example #44
  def __init__(self,
               initial_distribution,
               transition_distribution,
               observation_distribution,
               num_steps,
               validate_args=False,
               allow_nan_stats=True,
               name="HiddenMarkovModel"):
    """Initialize hidden Markov model.

    Args:
      initial_distribution: A `Categorical`-like instance.
        Determines probability of first hidden state in Markov chain.
        The number of categories must match the number of categories of
        `transition_distribution` as well as both the rightmost batch
        dimension of `transition_distribution` and the rightmost batch
        dimension of `observation_distribution`.
      transition_distribution: A `Categorical`-like instance.
        The rightmost batch dimension indexes the probability distribution
        of each hidden state conditioned on the previous hidden state.
      observation_distribution: A `tfp.distributions.Distribution`-like
        instance.  The rightmost batch dimension indexes the distribution
        of each observation conditioned on the corresponding hidden state.
      num_steps: The number of steps taken in Markov chain. A python `int`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
        Default value: `False`.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
        Default value: `True`.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: "HiddenMarkovModel".

    Raises:
      ValueError: if `num_steps` is not at least 1.
      ValueError: if `initial_distribution` does not have scalar `event_shape`.
      ValueError: if `transition_distribution` does not have scalar
        `event_shape.`
      ValueError: if `transition_distribution` and `observation_distribution`
        are fully defined but don't have matching rightmost dimension.
    """

    parameters = dict(locals())

    # pylint: disable=protected-access
    with tf.name_scope(name=name, values=(
        initial_distribution._graph_parents +
        transition_distribution._graph_parents +
        observation_distribution._graph_parents)) as name:
      self._runtime_assertions = []  # pylint: enable=protected-access

      if num_steps < 1:
        raise ValueError("num_steps ({}) must be at least 1.".format(num_steps))

      self._initial_distribution = initial_distribution
      self._observation_distribution = observation_distribution
      self._transition_distribution = transition_distribution

      if (initial_distribution.event_shape is not None
          and initial_distribution.event_shape.ndims != 0):
        raise ValueError(
            "`initial_distribution` must have scalar `event_dim`s")
      elif validate_args:
        self._runtime_assertions += [
            tf.assert_equal(
                tf.shape(initial_distribution.event_shape_tensor())[0], 0,
                message="`initial_distribution` must have scalar"
                        "`event_dim`s")]

      if (transition_distribution.event_shape is not None
          and transition_distribution.event_shape.ndims != 0):
        raise ValueError(
            "`transition_distribution` must have scalar `event_dim`s")
      elif validate_args:
        self._runtime_assertions += [
            tf.assert_equal(
                tf.shape(transition_distribution.event_shape_tensor())[0], 0,
                message="`transition_distribution` must have scalar"
                        "`event_dim`s")]

      if (transition_distribution.batch_shape is not None
          and transition_distribution.batch_shape.ndims == 0):
        raise ValueError(
            "`transition_distribution` can't have scalar batches")
      elif validate_args:
        self._runtime_assertions += [
            tf.assert_greater(
                tf.size(transition_distribution.batch_shape_tensor()), 0,
                message="`transition_distribution` can't have scalar "
                        "batches")]

      if (observation_distribution.batch_shape is not None
          and observation_distribution.batch_shape.ndims == 0):
        raise ValueError(
            "`observation_distribution` can't have scalar batches")
      elif validate_args:
        self._runtime_assertions += [
            tf.assert_greater(
                tf.size(observation_distribution.batch_shape_tensor()), 0,
                message="`observation_distribution` can't have scalar "
                        "batches")]

      # Infer number of hidden states and check consistency
      # between transitions and observations
      with tf.control_dependencies(self._runtime_assertions):
        self._num_states = ((transition_distribution.batch_shape and
                             transition_distribution.batch_shape[-1]) or
                            transition_distribution.batch_shape_tensor()[-1])

        observation_states = ((observation_distribution.batch_shape and
                               observation_distribution.batch_shape[-1]) or
                              observation_distribution.batch_shape_tensor()[-1])

      if (tf.contrib.framework.is_tensor(self._num_states) or
          tf.contrib.framework.is_tensor(observation_states)):
        if validate_args:
          self._runtime_assertions += [
              tf.assert_equal(
                  self._num_states, observation_states,
                  message="`transition_distribution` and "
                          "`observation_distribution` must agree on "
                          "last dimension of batch size")]
      elif self._num_states != observation_states:
        raise ValueError("`transition_distribution` and "
                         "`observation_distribution` must agree on "
                         "last dimension of batch size")

      self._log_init = _extract_log_probs(self._num_states,
                                          initial_distribution)
      self._log_trans = _extract_log_probs(self._num_states,
                                           transition_distribution)

      self._num_steps = num_steps
      self._num_states = tf.shape(self._log_init)[-1]

      self._underlying_event_rank = (
          self._observation_distribution.event_shape.ndims)

      self.static_event_shape = tf.TensorShape(
          [num_steps]).concatenate(self._observation_distribution.event_shape)

      with tf.control_dependencies(self._runtime_assertions):
        self.static_batch_shape = tf.broadcast_static_shape(
            self._initial_distribution.batch_shape,
            tf.broadcast_static_shape(
                self._transition_distribution.batch_shape[:-1],
                self._observation_distribution.batch_shape[:-1]))

      # pylint: disable=protected-access
      super(HiddenMarkovModel, self).__init__(
          dtype=self._observation_distribution.dtype,
          reparameterization_type=tf.distributions.NOT_REPARAMETERIZED,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          parameters=parameters,
          graph_parents=(
              self._initial_distribution._graph_parents +
              self._transition_distribution._graph_parents +
              self._observation_distribution._graph_parents),
          name=name)
      # pylint: enable=protected-access

      self._parameters = parameters
Example #45
def model_fn(features, labels, mode, params, config):
  """Builds the model function for use in an Estimator.

  Arguments:
    features: The input features for the Estimator.
    labels: The labels, unused here.
    mode: Signifies whether it is train or test or predict.
    params: Some hyperparameters as a dictionary.
    config: The RunConfig, unused here.

  Returns:
    EstimatorSpec: A tf.estimator.EstimatorSpec instance.
  """
  del labels, config

  # Set up the model's learnable parameters.
  logit_concentration = tf.get_variable(
      "logit_concentration",
      shape=[1, params["num_topics"]],
      initializer=tf.constant_initializer(
          _softplus_inverse(params["prior_initial_value"])))
  concentration = _clip_dirichlet_parameters(
      tf.nn.softplus(logit_concentration))

  num_words = features.shape[1]
  topics_words_logits = tf.get_variable(
      "topics_words_logits",
      shape=[params["num_topics"], num_words],
      initializer=tf.glorot_normal_initializer())
  topics_words = tf.nn.softmax(topics_words_logits, axis=-1)

  # Compute expected log-likelihood. First, sample from the variational
  # distribution; second, compute the log-likelihood given the sample.
  lda_variational = make_lda_variational(
      params["activation"],
      params["num_topics"],
      params["layer_sizes"])
  with ed.tape() as variational_tape:
    _ = lda_variational(features)

  with ed.tape() as model_tape:
    with ed.interception(
        make_value_setter(topics=variational_tape["topics_posterior"])):
      posterior_predictive = latent_dirichlet_allocation(concentration,
                                                         topics_words)

  log_likelihood = posterior_predictive.distribution.log_prob(features)
  tf.summary.scalar("log_likelihood", tf.reduce_mean(log_likelihood))

  # Compute the KL-divergence between two Dirichlets analytically.
  # The sampled KL does not work well for "sparse" distributions
  # (see Appendix D of [2]).
  kl = variational_tape["topics_posterior"].distribution.kl_divergence(
      model_tape["topics"].distribution)
  tf.summary.scalar("kl", tf.reduce_mean(kl))

  # Ensure that the KL is non-negative (up to a very small slack).
  # Negative KL can happen due to numerical instability.
  with tf.control_dependencies([tf.assert_greater(kl, -1e-3, message="kl")]):
    kl = tf.identity(kl)

  elbo = log_likelihood - kl
  avg_elbo = tf.reduce_mean(elbo)
  tf.summary.scalar("elbo", avg_elbo)
  loss = -avg_elbo

  # Perform variational inference by minimizing the -ELBO.
  global_step = tf.train.get_or_create_global_step()
  optimizer = tf.train.AdamOptimizer(params["learning_rate"])

  # This implements the "burn-in" for prior parameters (see Appendix D of [2]).
  # For the first prior_burn_in_steps steps they are fixed, and then trained
  # jointly with the other parameters.
  grads_and_vars = optimizer.compute_gradients(loss)
  grads_and_vars_except_prior = [
      x for x in grads_and_vars if x[1] != logit_concentration]

  def train_op_except_prior():
    return optimizer.apply_gradients(
        grads_and_vars_except_prior,
        global_step=global_step)

  def train_op_all():
    return optimizer.apply_gradients(
        grads_and_vars,
        global_step=global_step)

  train_op = tf.cond(
      global_step < params["prior_burn_in_steps"],
      true_fn=train_op_except_prior,
      false_fn=train_op_all)

  # The perplexity is an exponent of the average negative ELBO per word.
  words_per_document = tf.reduce_sum(features, axis=1)
  log_perplexity = -elbo / words_per_document
  tf.summary.scalar("perplexity", tf.exp(tf.reduce_mean(log_perplexity)))
  (log_perplexity_tensor, log_perplexity_update) = tf.metrics.mean(
      log_perplexity)
  perplexity_tensor = tf.exp(log_perplexity_tensor)

  # Obtain the topics summary. Implemented as a py_func for simplicity.
  topics = tf.py_func(
      functools.partial(get_topics_strings, vocabulary=params["vocabulary"]),
      [topics_words, concentration], tf.string, stateful=False)
  tf.summary.text("topics", topics)

  return tf.estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      eval_metric_ops={
          "elbo": tf.metrics.mean(elbo),
          "log_likelihood": tf.metrics.mean(log_likelihood),
          "kl": tf.metrics.mean(kl),
          "perplexity": (perplexity_tensor, log_perplexity_update),
          "topics": (topics, tf.no_op()),
      },
  )
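A minimal usage sketch (not part of the original example): wiring the model function above into a tf.estimator.Estimator. The params keys mirror those read by the code above; the input function, vocabulary, and hyperparameter values are illustrative placeholders.
import tensorflow as tf

def train_lda_sketch(model_fn, train_input_fn, vocab, model_dir):
  # `model_fn` is the estimator model function defined above; `train_input_fn`
  # must yield bag-of-words count matrices of shape [batch_size, num_words].
  estimator = tf.estimator.Estimator(
      model_fn=model_fn,
      model_dir=model_dir,
      params={
          "num_topics": 50,
          "layer_sizes": [300, 300, 300],
          "activation": "relu",
          "prior_initial_value": 0.7,
          "prior_burn_in_steps": 120000,
          "learning_rate": 3e-4,
          "vocabulary": vocab,
      })
  estimator.train(train_input_fn, steps=1000)
  return estimator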
Example #46
0
  def __init__(self,
               power,
               dtype=tf.int32,
               interpolate_nondiscrete=True,
               sample_maximum_iterations=100,
               validate_args=False,
               allow_nan_stats=False,
               name="Zipf"):
    """Initialize a batch of Zipf distributions.

    Args:
      power: `Float` like `Tensor` representing the power parameter. Must be
        strictly greater than `1`.
      dtype: The `dtype` of `Tensor` returned by `sample`.
        Default value: `tf.int32`.
      interpolate_nondiscrete: Python `bool`. When `False`, `log_prob` returns
        `-inf` (and `prob` returns `0`) for non-integer inputs. When `True`,
        `log_prob` evaluates the continuous function `-power log(k) -
        log(zeta(power))` , which matches the Zipf pmf at integer arguments `k`
        (note that this function is not itself a normalized probability
        log-density).
        Default value: `True`.
      sample_maximum_iterations: Maximum number of iterations allowed in
        `sample`. When `validate_args=True`, samples which fail to reach
        convergence within this cap are masked out with `self.dtype.min` or
        `nan`, depending on `self.dtype.is_integer`.
        Default value: `100`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
        Default value: `False`.
      allow_nan_stats: Python `bool`, default `False`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or more
        of the statistic's batch members are undefined.
        Default value: `False`.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: `'Zipf'`.

    Raises:
      TypeError: if `power` is not `float` like.
    """
    parameters = dict(locals())
    with tf.name_scope(name, values=[power]) as name:
      power = tf.convert_to_tensor(
          power,
          name="power",
          dtype=dtype_util.common_dtype([power], preferred_dtype=tf.float32))
      if not power.dtype.is_floating or power.dtype.base_dtype is tf.float16:
        raise TypeError(
            "power.dtype ({}) is not a supported `float` type.".format(
                power.dtype.name))
      runtime_assertions = []
      if validate_args:
        runtime_assertions += [
            tf.assert_greater(power, tf.cast(1., power.dtype))
        ]
      with tf.control_dependencies(runtime_assertions):
        self._power = tf.identity(power, name="power")

    self._interpolate_nondiscrete = interpolate_nondiscrete
    self._sample_maximum_iterations = sample_maximum_iterations
    super(Zipf, self).__init__(
        dtype=dtype,
        reparameterization_type=tf.distributions.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._power],
        name=name)
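A minimal usage sketch, assuming this constructor belongs to the tfp.distributions.Zipf class; the parameter values are illustrative.
import tensorflow as tf
import tensorflow_probability as tfp

zipf = tfp.distributions.Zipf(power=2.5)
samples = zipf.sample(10, seed=42)               # integer draws, dtype tf.int32
log_pmf = zipf.log_prob(tf.constant([1, 2, 3]))  # pmf evaluated at integer k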
Example #47
0
def chol_det(X):
    conditioned = condition(X)
    # det(X) = (product of the Cholesky diagonal)^2 for symmetric positive-definite X.
    diag_prod = tf.reduce_prod(tf.diag_part(tf.cholesky(conditioned)))
    with tf.control_dependencies([tf.assert_greater(diag_prod, tf.zeros_like(diag_prod))]):
        return tf.square(tf.identity(diag_prod))
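Because the product of the Cholesky diagonal entries can overflow or underflow, a log-determinant variant is often preferable. The following is a sketch under the same assumptions (namely that `condition` makes X numerically positive-definite, e.g. by adding diagonal jitter), using log det(X) = 2 * sum(log(diag(chol(X)))).
def chol_log_det(X):
    # Diagonal of the Cholesky factor of the conditioned matrix.
    chol_diag = tf.diag_part(tf.cholesky(condition(X)))
    with tf.control_dependencies([tf.assert_positive(chol_diag)]):
        # Summing logs instead of multiplying keeps the result in a safe range.
        return 2.0 * tf.reduce_sum(tf.log(tf.identity(chol_diag)))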
Example #48
0
  def build_greedy_training(self, state, network_states):
    """Builds a training loop for this component.

    This loop repeatedly evaluates the network and computes the loss, but it
    does not advance using the predictions of the network. Instead, it advances
    using the oracle defined in the underlying transition system. The final
    state will always correspond to the gold annotation.

    Args:
      state: MasterState from the 'AdvanceMaster' op that advances the
        underlying master to this component.
      network_states: NetworkState object containing component TensorArrays.

    Returns:
      (state, cost, correct, total) -- These are TF ops corresponding to
      the final state after unrolling, the total cost, the total number of
      correctly predicted actions, and the total number of actions.
    """
    logging.info('Building component: %s', self.spec.name)
    stride = state.current_batch_size * self.training_beam_size

    cost = tf.constant(0.)
    correct = tf.constant(0)
    total = tf.constant(0)

    # Create the TensorArrays to store activations for downstream/recurrent
    # connections.
    def cond(handle, *_):
      all_final = dragnn_ops.emit_all_final(handle, component=self.name)
      return tf.logical_not(tf.reduce_all(all_final))

    def body(handle, cost, correct, total, *arrays):
      """Runs the network and advances the state by a step."""

      with tf.control_dependencies([handle, cost, correct, total] +
                                   [x.flow for x in arrays]):
        # Get a copy of the network inside this while loop.
        updated_state = MasterState(handle, state.current_batch_size)
        network_tensors = self._feedforward_unit(
            updated_state, arrays, network_states, stride, during_training=True)

        # Every layer is written to a TensorArray, so that it can be backprop'd.
        next_arrays = update_tensor_arrays(network_tensors, arrays)
        with tf.control_dependencies([x.flow for x in next_arrays]):
          with tf.name_scope('compute_loss'):
            # A gold label > -1 indicates that the sentence is still
            # in a valid state. Otherwise, the sentence has ended.
            #
            # We add only the valid sentences to the loss, in the following way:
            #   1. We compute 'valid_ix', the indices in gold that contain
            #      valid oracle actions.
            #   2. We compute the cost function by comparing logits and gold
            #      only for the valid indices.
            gold = dragnn_ops.emit_oracle_labels(handle, component=self.name)
            gold.set_shape([None])
            valid = tf.greater(gold, -1)
            valid_ix = tf.reshape(tf.where(valid), [-1])
            gold = tf.gather(gold, valid_ix)

            logits = self.network.get_logits(network_tensors)
            logits = tf.gather(logits, valid_ix)

            cost += tf.reduce_sum(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=tf.cast(gold, tf.int64), logits=logits))

            if (self.eligible_for_self_norm and
                self.master.hyperparams.self_norm_alpha > 0):
              log_z = tf.reduce_logsumexp(logits, [1])
              cost += (self.master.hyperparams.self_norm_alpha *
                       tf.nn.l2_loss(log_z))

            correct += tf.reduce_sum(
                tf.to_int32(tf.nn.in_top_k(logits, gold, 1)))
            total += tf.size(gold)

        with tf.control_dependencies([cost, correct, total, gold]):
          handle = dragnn_ops.advance_from_oracle(handle, component=self.name)
        return [handle, cost, correct, total] + next_arrays

    with tf.name_scope(self.name + '/train_state'):
      init_arrays = []
      for layer in self.network.layers:
        init_arrays.append(layer.create_array(state.current_batch_size))

    output = tf.while_loop(
        cond,
        body, [state.handle, cost, correct, total] + init_arrays,
        name='train_%s' % self.name)

    # Save completed arrays and return the final state and cost.
    state.handle = output[0]
    correct = output[2]
    total = output[3]
    arrays = output[4:]
    cost = output[1]

    # Store handles to the final output for use in subsequent tasks.
    network_state = network_states[self.name]
    with tf.name_scope(self.name + '/stored_act'):
      for index, layer in enumerate(self.network.layers):
        network_state.activations[layer.name] = network_units.StoredActivations(
            array=arrays[index])

    # Normalize the objective by the total # of steps taken.
    with tf.control_dependencies([tf.assert_greater(total, 0)]):
      cost /= tf.to_float(total)

    # Adds regularization for the hidden weights.
    cost = self.add_regularizer(cost)

    with tf.control_dependencies([x.flow for x in arrays]):
      return tf.identity(state.handle), cost, correct, total
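The masking logic in `compute_loss` above is a reusable pattern; the following standalone sketch uses illustrative names and needs no DRAGNN ops. Gold labels of -1 mark finished sentences and are excluded from both the loss and the accuracy counts.
def masked_xent(gold, logits):
  # `gold` is an int vector of oracle labels (-1 = finished); `logits` is a
  # float matrix of shape [num_steps, num_actions].
  valid_ix = tf.reshape(tf.where(tf.greater(gold, -1)), [-1])
  gold = tf.gather(gold, valid_ix)
  logits = tf.gather(logits, valid_ix)
  cost = tf.reduce_sum(
      tf.nn.sparse_softmax_cross_entropy_with_logits(
          labels=tf.cast(gold, tf.int64), logits=logits))
  correct = tf.reduce_sum(tf.to_int32(tf.nn.in_top_k(logits, gold, 1)))
  return cost, correct, tf.size(gold)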
Example #49
0
  def __init__(self,
               batch_size,
               total_num_examples,
               max_learning_rate=1.,
               preconditioner_decay_rate=0.95,
               burnin=25,
               burnin_max_learning_rate=1e-6,
               use_single_learning_rate=False,
               name=None,
               variable_scope=None):
    default_name = 'VariationalSGD'
    with tf.name_scope(name, default_name, [
        max_learning_rate, preconditioner_decay_rate, batch_size, burnin,
        burnin_max_learning_rate
    ]):
      if variable_scope is None:
        var_scope_name = tf.get_default_graph().unique_name(
            name or default_name)
        with tf.variable_scope(var_scope_name) as scope:
          self._variable_scope = scope
      else:
        self._variable_scope = variable_scope

      self._preconditioner_decay_rate = tf.convert_to_tensor(
          preconditioner_decay_rate, name='preconditioner_decay_rate')
      self._batch_size = tf.convert_to_tensor(batch_size, name='batch_size')
      self._total_num_examples = tf.convert_to_tensor(
          total_num_examples, name='total_num_examples')
      self._burnin = tf.convert_to_tensor(burnin, name='burnin')
      self._burnin_max_learning_rate = tf.convert_to_tensor(
          burnin_max_learning_rate, name='burnin_max_learning_rate')
      self._max_learning_rate = tf.convert_to_tensor(
          max_learning_rate, name='max_learning_rate')
      self._use_single_learning_rate = use_single_learning_rate

      with tf.variable_scope(self._variable_scope):
        self._counter = tf.get_variable(
            'counter', initializer=0, trainable=False)

      self._preconditioner_decay_rate = control_flow_ops.with_dependencies([
          tf.assert_non_negative(
              self._preconditioner_decay_rate,
              message='`preconditioner_decay_rate` must be non-negative'),
          tf.assert_less_equal(
              self._preconditioner_decay_rate,
              1.,
              message='`preconditioner_decay_rate` must be at most 1.'),
      ], self._preconditioner_decay_rate)

      self._batch_size = control_flow_ops.with_dependencies([
          tf.assert_greater(
              self._batch_size,
              0,
              message='`batch_size` must be greater than zero')
      ], self._batch_size)

      self._total_num_examples = control_flow_ops.with_dependencies([
          tf.assert_greater(
              self._total_num_examples,
              0,
              message='`total_num_examples` must be greater than zero')
      ], self._total_num_examples)

      self._burnin = control_flow_ops.with_dependencies([
          tf.assert_non_negative(
              self._burnin, message='`burnin` must be non-negative'),
          tf.assert_integer(
              self._burnin, message='`burnin` must be an integer')
      ], self._burnin)

      self._burnin_max_learning_rate = control_flow_ops.with_dependencies([
          tf.assert_non_negative(
              self._burnin_max_learning_rate,
              message='`burnin_max_learning_rate` must be non-negative')
      ], self._burnin_max_learning_rate)

      self._max_learning_rate = control_flow_ops.with_dependencies([
          tf.assert_non_negative(
              self._max_learning_rate,
              message='`max_learning_rate` must be non-negative')
      ], self._max_learning_rate)

      super(VariationalSGD, self).__init__(
          use_locking=False, name=name or default_name)
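A minimal usage sketch, assuming this constructor belongs to tfp.optimizer.VariationalSGD and that it behaves like a standard tf.train optimizer; the hyperparameter values are illustrative.
import tensorflow_probability as tfp

opt = tfp.optimizer.VariationalSGD(
    batch_size=32,
    total_num_examples=50000,
    max_learning_rate=1e-2,
    burnin=25)
train_op = opt.minimize(loss)  # `loss` is a scalar loss tensor defined elsewhere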
Example #50
0
def model_fn(features, labels, mode, params, config):
  """Build the model function for use in an estimator.

  Arguments:
    features: The input features for the estimator.
    labels: The labels, unused here.
    mode: Signifies whether it is train or test or predict.
    params: Some hyperparameters as a dictionary.
    config: The RunConfig, unused here.
  Returns:
    EstimatorSpec: A tf.estimator.EstimatorSpec instance.
  """
  del labels, config

  encoder = make_encoder(params["activation"],
                         params["num_topics"],
                         params["layer_sizes"])
  decoder, topics_words = make_decoder(params["num_topics"],
                                       features.shape[1])
  prior, prior_variables = make_prior(params["num_topics"],
                                      params["prior_initial_value"])

  topics_prior = prior()
  alpha = topics_prior.concentration

  topics_posterior = encoder(features)
  topics = topics_posterior.sample()
  random_reconstruction = decoder(topics)

  reconstruction = random_reconstruction.log_prob(features)
  tf.summary.scalar("reconstruction", tf.reduce_mean(reconstruction))

  # Compute the KL-divergence between two Dirichlets analytically.
  # The sampled KL does not work well for "sparse" distributions
  # (see Appendix D of [2]).
  kl = tfd.kl_divergence(topics_posterior, topics_prior)
  tf.summary.scalar("kl", tf.reduce_mean(kl))

  # Ensure that the KL is non-negative (up to a very small slack).
  # Negative KL can happen due to numerical instability.
  with tf.control_dependencies([tf.assert_greater(kl, -1e-3, message="kl")]):
    kl = tf.identity(kl)

  elbo = reconstruction - kl
  avg_elbo = tf.reduce_mean(elbo)
  tf.summary.scalar("elbo", avg_elbo)
  loss = -avg_elbo

  # Perform variational inference by minimizing the -ELBO.
  global_step = tf.train.get_or_create_global_step()
  optimizer = tf.train.AdamOptimizer(params["learning_rate"])

  # This implements the "burn-in" for prior parameters (see Appendix D of [2]).
  # For the first prior_burn_in_steps steps they are fixed, and then trained
  # jointly with the other parameters.
  grads_and_vars = optimizer.compute_gradients(loss)
  grads_and_vars_except_prior = [
      x for x in grads_and_vars if x[1] not in prior_variables]

  def train_op_except_prior():
    return optimizer.apply_gradients(
        grads_and_vars_except_prior,
        global_step=global_step)

  def train_op_all():
    return optimizer.apply_gradients(
        grads_and_vars,
        global_step=global_step)

  train_op = tf.cond(
      global_step < params["prior_burn_in_steps"],
      true_fn=train_op_except_prior,
      false_fn=train_op_all)

  # The perplexity is the exponential of the average negative ELBO per word.
  words_per_document = tf.reduce_sum(features, axis=1)
  log_perplexity = -elbo / words_per_document
  tf.summary.scalar("perplexity", tf.exp(tf.reduce_mean(log_perplexity)))
  (log_perplexity_tensor, log_perplexity_update) = tf.metrics.mean(
      log_perplexity)
  perplexity_tensor = tf.exp(log_perplexity_tensor)

  # Obtain the topics summary. Implemented as a py_func for simplicity.
  topics = tf.py_func(
      functools.partial(get_topics_strings, vocabulary=params["vocabulary"]),
      [topics_words, alpha], tf.string, stateful=False)
  tf.summary.text("topics", topics)

  return tf.estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      eval_metric_ops={
          "elbo": tf.metrics.mean(elbo),
          "reconstruction": tf.metrics.mean(reconstruction),
          "kl": tf.metrics.mean(kl),
          "perplexity": (perplexity_tensor, log_perplexity_update),
          "topics": (topics, tf.no_op()),
      },
  )
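`make_prior` is not shown in this snippet. Below is a hypothetical sketch that is consistent with how it is used above (it must return a prior-building callable plus the list of prior variables that are excluded during burn-in) and with the helpers used in the earlier LDA example (_softplus_inverse, _clip_dirichlet_parameters).
def make_prior(num_topics, initial_value):
  # Trainable symmetric Dirichlet concentration, parameterized via softplus
  # so it stays positive.
  logit_concentration = tf.get_variable(
      "logit_concentration",
      shape=[1, num_topics],
      initializer=tf.constant_initializer(_softplus_inverse(initial_value)))

  def prior():
    concentration = _clip_dirichlet_parameters(
        tf.nn.softplus(logit_concentration))
    return tfd.Dirichlet(concentration=concentration, name="topics_prior")

  # The variable list is what `grads_and_vars_except_prior` filters out above.
  return prior, [logit_concentration]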
Example #51
0
  def __init__(self,
               learning_rate,
               preconditioner_decay_rate=0.95,
               data_size=1,
               burnin=25,
               diagonal_bias=1e-8,
               name=None,
               parallel_iterations=10,
               variable_scope=None):
    default_name = 'StochasticGradientLangevinDynamics'
    with tf.name_scope(name, default_name, [
        learning_rate, preconditioner_decay_rate, data_size, burnin,
        diagonal_bias
    ]):
      if tf.executing_eagerly():
        raise NotImplementedError('Eager execution is currently not supported '
                                  'for the SGLD optimizer.')
      if variable_scope is None:
        var_scope_name = tf.get_default_graph().unique_name(
            name or default_name)
        with tf.variable_scope(var_scope_name) as scope:
          self._variable_scope = scope
      else:
        self._variable_scope = variable_scope

      self._preconditioner_decay_rate = tf.convert_to_tensor(
          preconditioner_decay_rate, name='preconditioner_decay_rate')
      self._data_size = tf.convert_to_tensor(
          data_size, name='data_size')
      self._burnin = tf.convert_to_tensor(burnin, name='burnin')
      self._diagonal_bias = tf.convert_to_tensor(
          diagonal_bias, name='diagonal_bias')
      self._learning_rate = tf.convert_to_tensor(
          learning_rate, name='learning_rate')
      self._parallel_iterations = parallel_iterations

      with tf.variable_scope(self._variable_scope):
        self._counter = tf.get_variable(
            'counter', initializer=0, trainable=False)

      self._preconditioner_decay_rate = control_flow_ops.with_dependencies([
          tf.assert_non_negative(
              self._preconditioner_decay_rate,
              message='`preconditioner_decay_rate` must be non-negative'),
          tf.assert_less_equal(
              self._preconditioner_decay_rate,
              1.,
              message='`preconditioner_decay_rate` must be at most 1.'),
      ], self._preconditioner_decay_rate)

      self._data_size = control_flow_ops.with_dependencies([
          tf.assert_greater(
              self._data_size,
              0,
              message='`data_size` must be greater than zero')
      ], self._data_size)

      self._burnin = control_flow_ops.with_dependencies([
          tf.assert_non_negative(
              self._burnin, message='`burnin` must be non-negative'),
          tf.assert_integer(
              self._burnin, message='`burnin` must be an integer')
      ], self._burnin)

      self._diagonal_bias = control_flow_ops.with_dependencies([
          tf.assert_non_negative(
              self._diagonal_bias,
              message='`diagonal_bias` must be non-negative')
      ], self._diagonal_bias)

      super(StochasticGradientLangevinDynamics, self).__init__(
          use_locking=False, name=name or default_name)
Example #52
0
  def __init__(self,
               mean_direction,
               concentration,
               validate_args=False,
               allow_nan_stats=True,
               name='VonMisesFisher'):
    """Creates a new `VonMisesFisher` instance.

    Args:
      mean_direction: Floating-point `Tensor` with shape [B1, ... Bn, D].
        A unit vector indicating the mode of the distribution, or the
        unit-normalized direction of the mean. (This is *not* in general the
        mean of the distribution; the mean is not generally in the support of
        the distribution.) NOTE: `D` is currently restricted to <= 5.
      concentration: Floating-point `Tensor` having batch shape [B1, ... Bn]
        broadcastable with `mean_direction`. The level of concentration of
        samples around the `mean_direction`. `concentration=0` indicates a
        uniform distribution over the unit hypersphere, and `concentration=+inf`
        indicates a `Deterministic` distribution (delta function) at
        `mean_direction`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: For known-bad arguments, i.e. unsupported event dimension.
    """
    parameters = dict(locals())
    with tf.name_scope(name, values=[mean_direction, concentration]) as name:
      dtype = dtype_util.common_dtype([mean_direction, concentration],
                                      tf.float32)
      mean_direction = tf.convert_to_tensor(
          mean_direction, name='mean_direction', dtype=dtype)
      concentration = tf.convert_to_tensor(
          concentration, name='concentration', dtype=dtype)
      assertions = [
          tf.assert_non_negative(
              concentration, message='`concentration` must be non-negative'),
          tf.assert_greater(
              tf.shape(mean_direction)[-1], 1,
              message='`mean_direction` may not have scalar event shape'),
          tf.assert_near(
              1., tf.linalg.norm(mean_direction, axis=-1),
              message='`mean_direction` must be unit-length')
      ] if validate_args else []
      if mean_direction.shape.with_rank_at_least(1)[-1].value is not None:
        if mean_direction.shape.with_rank_at_least(1)[-1].value > 5:
          raise ValueError('vMF ndims > 5 is not currently supported')
      elif validate_args:
        assertions += [tf.assert_less_equal(
            tf.shape(mean_direction)[-1], 5,
            message='vMF ndims > 5 is not currently supported')]
      with tf.control_dependencies(assertions):
        self._mean_direction = tf.identity(mean_direction)
        self._concentration = tf.identity(concentration)
      tf.assert_same_float_dtype([self._mean_direction, self._concentration])
      # mean_direction is always reparameterized.
      # concentration is only for event_dim==3, via an inversion sampler.
      reparameterization_type = (
          reparameterization.FULLY_REPARAMETERIZED
          if mean_direction.shape.with_rank_at_least(1)[-1].value == 3 else
          reparameterization.NOT_REPARAMETERIZED)
      super(VonMisesFisher, self).__init__(
          dtype=self._concentration.dtype,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          reparameterization_type=reparameterization_type,
          parameters=parameters,
          graph_parents=[self._mean_direction, self._concentration],
          name=name)
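A minimal usage sketch, assuming this constructor belongs to tfp.distributions.VonMisesFisher; `mean_direction` must be unit-length and the event dimension at most 5.
import tensorflow as tf
import tensorflow_probability as tfp

vmf = tfp.distributions.VonMisesFisher(
    mean_direction=tf.nn.l2_normalize([[1., 1., 0.]], axis=-1),
    concentration=[10.])
samples = vmf.sample(5)          # shape [5, 1, 3]: points on the unit sphere
log_p = vmf.log_prob(samples)    # shape [5, 1]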