Example #1
def flip_randomly_points_and_normals_motions(points, normals, motions,
                                             is_training):
    """Flip points and normals against x or/and y axis.

  Args:
    points: A tf.float32 tensor of size [N, 3] containing points.
    normals: A tf.float32 tensor of size [N, 3] containing normals or None.
    motions: A tf.float32 tensor of size [N, 3] containing motion vectors or
      None.
    is_training: True if in training stage. Random flipping only takes place
      during training.

  Returns:
    flipped_points: Flipped points. A tf.float32 tensor of size [N, 3].
    flipped_normals: Flipped normals. A tf.float32 tensor of size [N, 3], or
      None if the input normals is None.
    flipped_motions: Flipped motion vectors. A tf.float32 tensor of size
      [N, 3], or None if the input motions is None.
  """
    if is_training:
        x_cond = tf.greater(
            tf.random.uniform([], minval=0.0, maxval=1.0, dtype=tf.float32),
            0.5)
        x_rotate = tf.cond(x_cond, lambda: tf.constant(1.0, dtype=tf.float32),
                           lambda: tf.constant(-1.0, dtype=tf.float32))
        y_cond = tf.greater(
            tf.random.uniform([], minval=0.0, maxval=1.0, dtype=tf.float32),
            0.5)
        y_rotate = tf.cond(y_cond, lambda: tf.constant(1.0, dtype=tf.float32),
                           lambda: tf.constant(-1.0, dtype=tf.float32))
        (points, normals,
         motions) = flip_points_and_normals_motions(points=points,
                                                    normals=normals,
                                                    motions=motions,
                                                    x_rotate=x_rotate,
                                                    y_rotate=y_rotate)
    return points, normals, motions
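
A minimal standalone sketch of the flip logic above, for intuition. It assumes TF1-style graph APIs (via tf.compat.v1) and approximates the unshown helper flip_points_and_normals_motions with an elementwise sign multiply on the x and y coordinates, which is one plausible implementation:

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

def random_flip_points(points):
    """Randomly mirrors [N, 3] points across the x and/or y axis."""
    x_sign = tf.cond(tf.greater(tf.random.uniform([]), 0.5),
                     lambda: tf.constant(1.0), lambda: tf.constant(-1.0))
    y_sign = tf.cond(tf.greater(tf.random.uniform([]), 0.5),
                     lambda: tf.constant(1.0), lambda: tf.constant(-1.0))
    # Multiply the x and y coordinates by +/-1; z is left unchanged.
    return points * tf.stack([x_sign, y_sign, 1.0])

points = tf.random.uniform([8, 3])
with tf.Session() as sess:
    print(sess.run(random_flip_points(points)))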
Example #2
def select_slate_greedy(slate_size, s_no_click, s, q):
    """Selects the slate using the adaptive greedy algorithm.

  This algorithm corresponds to the method "GS" in
  Ie et al. https://arxiv.org/abs/1905.12767.

  Args:
    slate_size: int, the size of the recommendation slate.
    s_no_click: float tensor, the score for not clicking any document.
    s: [num_of_documents] tensor, the scores for clicking documents.
    q: [num_of_documents] tensor, the predicted q values for documents.

  Returns:
    [slate_size] tensor, the selected slate.
  """
    def argmax(v, mask):
        return tf.argmax((v - tf.reduce_min(v) + 1) * mask, axis=0)

    numerator = tf.constant(0.)
    denominator = tf.constant(0.) + s_no_click
    mask = tf.ones(tf.shape(q)[0])

    def set_element(v, i, x):
        mask = tf.one_hot(i, tf.shape(v)[0])
        v_new = tf.ones_like(v) * x
        return tf.where(tf.equal(mask, 1), v_new, v)

    for _ in range(slate_size):
        k = argmax((numerator + s * q) / (denominator + s), mask)
        mask = set_element(mask, k, 0)
        numerator = numerator + tf.gather(s * q, k)
        denominator = denominator + tf.gather(s, k)

    output_slate = tf.where(tf.equal(mask, 0))
    return output_slate
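
A hypothetical toy call (values are illustrative, not from the source): three candidate documents and a slate of size two, run eagerly under TF2.

import tensorflow as tf  # assumes TF2 eager execution

s_no_click = tf.constant(1.0)        # score of the "no click" option
s = tf.constant([0.5, 2.0, 1.0])     # click scores per document
q = tf.constant([1.0, 0.2, 0.9])     # predicted Q-values per document
slate = select_slate_greedy(slate_size=2, s_no_click=s_no_click, s=s, q=q)
print(slate.numpy())  # indices of the selected documents, shape [2, 1];
                      # here: [[0], [2]]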
Example #3
 def get_metric_dictionary(self):
   metrics_dict = {}
   class_recall_list = []  # used for calculating mean pixel accuracy.
   class_iou_list = []     # used for calculating mean iou.
   for c in self.class_range:
     tp = self.true_positive_metrics[c].result()
     fp = self.false_positive_metrics[c].result()
     fn = self.false_negative_metrics[c].result()
     class_recall = tp / (tp + fn)
     class_precision = tf.where(
         tf.greater(tp + fn, 0.0), _safe_div(tp, (tp + fp)),
         tf.constant(np.nan))
     class_iou = tf.where(
         tf.greater(tp + fn, 0.0), tp / (tp + fn + fp), tf.constant(np.nan))
     class_recall_list.append(class_recall)
     class_iou_list.append(class_iou)
     class_name = _get_class_name(class_id=c, label_map=self.label_map)
     metrics_dict[self.eval_prefix +
                  '_recall/{}'.format(class_name)] = class_recall
     metrics_dict[self.eval_prefix +
                  '_precision/{}'.format(class_name)] = class_precision
     metrics_dict[self.eval_prefix + '_iou/{}'.format(class_name)] = class_iou
   mean_pixel_accuracy = _non_nan_mean(class_recall_list)
   mean_iou = _non_nan_mean(class_iou_list)
   metrics_dict[self.eval_prefix +
                '_avg/mean_pixel_accuracy'] = mean_pixel_accuracy
   metrics_dict[self.eval_prefix + '_avg/mean_iou'] = mean_iou
   return metrics_dict
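
_safe_div is not shown in this snippet. A plausible minimal version, assuming it should return 0 wherever the denominator is 0 (the tf.where above then substitutes NaN only for classes absent from the ground truth):

import tensorflow as tf

def _safe_div(numerator, denominator):
    """Elementwise division that yields 0 wherever the denominator is 0."""
    return tf.math.divide_no_nan(numerator, denominator)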
Example #4
def state_rewards(states,
                  actions,
                  rewards,
                  next_states,
                  contexts,
                  weight_index=None,
                  state_indices=None,
                  weight_vector=1.0,
                  offset_vector=0.0,
                  summarize=False):
  """Returns the rewards that are linear mapping of next_states.

  Args:
    states: A [batch_size, num_state_dims] Tensor representing a batch
        of states.
    actions: A [batch_size, num_action_dims] Tensor representing a batch
      of actions.
    rewards: A [batch_size] Tensor representing a batch of rewards.
    next_states: A [batch_size, num_state_dims] Tensor representing a batch
      of next states.
    contexts: A list of [batch_size, num_context_dims] Tensors representing
      a batch of contexts.
    weight_index: (integer) Index into the contexts list that specifies the
      weighting.
    state_indices: (a list or Numpy integer array) Indices of state dimensions
      to be mapped.
    weight_vector: (a number, list, or Numpy array) The weighting vector,
      broadcastable to `next_states`.
    offset_vector: (a number, list, or Numpy array) The offset vector.
    summarize: (boolean) enable summary ops.

  Returns:
    A new tf.float32 [batch_size] rewards Tensor, and
      tf.float32 [batch_size] discounts tensor.
  """
  del states, actions, rewards  # unused args
  stats = {}
  record_tensor(next_states, state_indices, stats)
  next_states = index_states(next_states, state_indices)
  weight = tf.constant(
      weight_vector, dtype=next_states.dtype, shape=next_states[0].shape)
  weights = tf.expand_dims(weight, 0)
  offset = tf.constant(
      offset_vector, dtype=next_states.dtype, shape=next_states[0].shape)
  offsets = tf.expand_dims(offset, 0)
  if weight_index is not None:
    weights *= contexts[weight_index]
  rewards = tf.to_float(tf.reduce_sum(weights * (next_states+offsets), axis=1))
  if summarize:
    with tf.name_scope('RewardFn/'):
      summarize_stats(stats)
  return rewards, tf.ones_like(rewards)
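
The core computation above in isolation: the rewards are a linear map of next_states. A minimal eager sketch with made-up values:

import tensorflow as tf

next_states = tf.constant([[1.0, 2.0], [3.0, 4.0]])  # [batch, state_dims]
weights = tf.constant([[1.0, 0.5]])                   # broadcast over batch
offsets = tf.constant([[0.0, -1.0]])
rewards = tf.reduce_sum(weights * (next_states + offsets), axis=1)
print(rewards.numpy())  # [1.0*1 + 0.5*1, 1.0*3 + 0.5*3] -> [1.5, 4.5]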
Example #5
 def __init__(self, ckpt_dir, save_epoch_freq=1, max_to_keep=3):
     self._ckpt_saved_epoch = tf.Variable(initial_value=tf.constant(
         -1, dtype=tf.dtypes.int64),
                                          name='ckpt_saved_epoch')
     self.ckpt_dir = ckpt_dir
     self.max_to_keep = max_to_keep
     self.save_epoch_freq = save_epoch_freq
Example #6
def score_documents_tf(user_obs,
                       doc_obs,
                       no_click_mass=1.0,
                       is_mnl=False,
                       min_normalizer=-1.0):
    """Computes unnormalized scores given both user and document observations.

  This implements both the multinomial proportional model and the multinomial
  logit model given some parameters. We also assume scores are based on inner
  products of user_obs and doc_obs.

  Args:
    user_obs: An instance of AbstractUserState.
    doc_obs: A numpy array that represents the observation of all documents in
      the candidate set.
    no_click_mass: A float indicating the mass given to the no-click option.
    is_mnl: whether to use a multinomial logit model instead of a multinomial
      proportional model.
    min_normalizer: A float (<= 0) used to offset the scores to be positive when
      using multinomial proportional model.

  Returns:
    A float tensor that stores unnormalized scores of documents and a float
      tensor that represents the score for the action of picking no document.
  """
    user_obs = tf.reshape(user_obs, [1, -1])
    scores = tf.reduce_sum(input_tensor=tf.multiply(user_obs, doc_obs), axis=1)
    all_scores = tf.concat([scores, tf.constant([no_click_mass])], axis=0)
    if is_mnl:
        all_scores = tf.nn.softmax(all_scores)
    else:
        all_scores = all_scores - min_normalizer
    return all_scores[:-1], all_scores[-1]
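
A hypothetical eager call with toy values: a 2-d user embedding scored against three documents. With the default multinomial proportional model, scores are shifted by -min_normalizer to keep them positive:

import tensorflow as tf

user_obs = tf.constant([1.0, 0.5])
doc_obs = tf.constant([[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]])
doc_scores, no_click_score = score_documents_tf(user_obs, doc_obs)
# Inner products [1.0, 1.0, 1.5] plus no_click_mass 1.0, all shifted by
# -min_normalizer (+1.0): [2.0, 2.0, 2.5] and 2.0.
print(doc_scores.numpy(), no_click_score.numpy())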
Example #7
    def _build_select_slate_op(self):
        p_no_click = self._prob_no_click_ph
        p = self._doc_affinity_scores_ph
        q = self._net_outputs.q_values[0]
        with tf.name_scope('select_slate'):
            self._output_slate = self._select_slate_fn(self._slate_size,
                                                       p_no_click, p, q)

        self._output_slate = tf.Print(
            self._output_slate,
            [tf.constant('cp 1'), self._output_slate, p, q],
            summarize=10000)
        self._output_slate = tf.reshape(self._output_slate,
                                        (self._slate_size, ))

        self._action_counts = tf.get_variable(
            'action_counts',
            shape=[self._num_candidates],
            initializer=tf.zeros_initializer())
        output_slate = tf.reshape(self._output_slate, [-1])
        output_one_hot = tf.one_hot(output_slate, self._num_candidates)
        update_ops = []
        for i in range(self._slate_size):
            update_ops.append(
                tf.assign_add(self._action_counts, output_one_hot[i]))
        self._select_action_update_op = tf.group(*update_ops)
Example #8
    def __call__(self, example_string):
        """Processes a single example string.

    Extracts and processes the image, and ignores the label. We assume that the
    image has three channels.

    Args:
      example_string: str, an Example protocol buffer.

    Returns:
      image_rescaled: the image, resized to `image_size x image_size` and
      rescaled to [-1, 1]. Note that Gaussian data augmentation may cause values
      to go beyond this range.
    """
        image_decoded = read_example_and_parse_image(example_string)['image']
        image_resized = tf.image.resize_images(
            image_decoded, [self.image_size, self.image_size],
            method=tf.image.ResizeMethod.BILINEAR,
            align_corners=True)
        image_resized = tf.cast(image_resized, tf.float32)
        image = 2 * (image_resized / 255.0 - 0.5)  # Rescale to [-1, 1].

        if self.data_augmentation is not None:
            if self.data_augmentation.enable_gaussian_noise:
                image = image + tf.random_normal(tf.shape(
                    image)) * self.data_augmentation.gaussian_noise_std

            if self.data_augmentation.enable_jitter:
                j = self.data_augmentation.jitter_amount
                paddings = tf.constant([[j, j], [j, j], [0, 0]])
                image = tf.pad(image, paddings, 'REFLECT')
                image = tf.image.random_crop(
                    image, [self.image_size, self.image_size, 3])

        return image
Example #9
def _box_rotation_regression_loss(loss_type, is_balanced,
                                  input_boxes_rotation_matrix,
                                  input_boxes_instance_id,
                                  output_boxes_rotation_matrix, delta):
  """Computes regression loss on object rotations."""

  def fn():
    """Loss function for when number of input and output boxes is positive."""
    if is_balanced:
      weights = loss_utils.get_balanced_loss_weights_multiclass(
          labels=input_boxes_instance_id)
    else:
      weights = tf.ones([tf.shape(input_boxes_instance_id)[0], 1],
                        dtype=tf.float32)
    gt_rotation_matrix = tf.reshape(input_boxes_rotation_matrix, [-1, 9])
    predicted_rotation_matrix = tf.reshape(output_boxes_rotation_matrix,
                                           [-1, 9])
    if loss_type == 'huber':
      loss_fn = tf.keras.losses.Huber(
          delta=delta, reduction=tf.keras.losses.Reduction.NONE)
    elif loss_type == 'absolute_difference':
      loss_fn = tf.keras.losses.MeanAbsoluteError(
          reduction=tf.keras.losses.Reduction.NONE)
    else:
      raise ValueError('Unknown loss type %s.' % loss_type)
    rotation_losses = loss_fn(
        y_true=gt_rotation_matrix, y_pred=predicted_rotation_matrix)
    return tf.reduce_mean(rotation_losses * tf.reshape(weights, [-1]))

  cond_input = tf.greater(tf.shape(input_boxes_rotation_matrix)[0], 0)
  cond_output = tf.greater(tf.shape(output_boxes_rotation_matrix)[0], 0)
  cond = tf.logical_and(cond_input, cond_output)
  return tf.cond(cond, fn, lambda: tf.constant(0.0, dtype=tf.float32))
Example #10
def _box_center_distance_loss_on_voxel_tensors_unbatched(
    inputs_1, outputs_1, loss_type, delta, is_balanced, is_intermediate):
  """Computes huber loss on predicted object centers for each voxel."""
  inputs_1, outputs_1, valid_mask = _get_voxels_valid_inputs_outputs(
      inputs_1=inputs_1, outputs_1=outputs_1)

  def loss_fn_unbatched():
    """Loss function."""
    if is_intermediate:
      output_boxes_center = outputs_1[standard_fields.DetectionResultFields
                                      .intermediate_object_center_voxels]
    else:
      output_boxes_center = outputs_1[
          standard_fields.DetectionResultFields.object_center_voxels]
    return _box_center_distance_loss(
        loss_type=loss_type,
        is_balanced=is_balanced,
        input_boxes_center=inputs_1[
            standard_fields.InputDataFields.object_center_voxels],
        input_boxes_instance_id=inputs_1[
            standard_fields.InputDataFields.object_instance_id_voxels],
        output_boxes_center=output_boxes_center,
        delta=delta)

  return tf.cond(
      tf.reduce_any(valid_mask),
      loss_fn_unbatched, lambda: tf.constant(0.0, dtype=tf.float32))
Example #11
def per_voxel_point_sample_segment_func(data, segment_ids, num_segments,
                                        num_samples_per_voxel):
    """Samples features from the points within each voxel.

  Args:
    data: A tf.float32 tensor of size [N, F].
    segment_ids: A tf.int32 tensor of size [N].
    num_segments: Number of segments.
    num_samples_per_voxel: Number of features to sample per voxel. If a voxel
      contains fewer points than this, the sampled features are padded with 0s.

  Returns:
    A tf.float32 tensor of size [num_segments, num_samples_per_voxel, F]
    containing the sampled (zero-padded where needed) per-voxel point features.
  """
    num_channels = data.get_shape().as_list()[1]
    if num_channels is None:
        raise ValueError('num_channels is None.')
    n = tf.shape(segment_ids)[0]

    def _body_fn(i, indices_range, indices):
        """Computes the indices of the i-th point feature in each segment."""
        indices_i = tf.math.unsorted_segment_max(data=indices_range,
                                                 segment_ids=segment_ids,
                                                 num_segments=num_segments)
        indices_i_positive_mask = tf.greater(indices_i, 0)
        indices_i_positive = tf.boolean_mask(indices_i,
                                             indices_i_positive_mask)
        boolean_mask = tf.scatter_nd(indices=tf.cast(tf.expand_dims(
            indices_i_positive - 1, axis=1),
                                                     dtype=tf.int64),
                                     updates=tf.ones_like(indices_i_positive,
                                                          dtype=tf.int32),
                                     shape=(n, ))
        indices_range *= (1 - boolean_mask)
        indices_i *= tf.cast(indices_i_positive_mask, dtype=tf.int32)
        indices_i = tf.pad(tf.expand_dims(indices_i, axis=1),
                           paddings=[[0, 0],
                                     [i, num_samples_per_voxel - i - 1]])
        indices += indices_i
        i = i + 1
        return i, indices_range, indices

    cond = lambda i, indices_range, indices: i < num_samples_per_voxel

    (_, _, indices) = tf.while_loop(
        cond=cond,
        body=_body_fn,
        loop_vars=(tf.constant(0, dtype=tf.int32), tf.range(n) + 1,
                   tf.zeros([num_segments, num_samples_per_voxel],
                            dtype=tf.int32)))

    data = tf.pad(data, paddings=[[1, 0], [0, 0]])
    voxel_features = tf.gather(data, tf.reshape(indices, [-1]))
    return tf.reshape(voxel_features,
                      [num_segments, num_samples_per_voxel, num_channels])
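
A toy eager check with made-up values: four 1-channel points in two voxels, sampling up to three points per voxel. Voxel 1 holds a single point, so two of its slots come back zero-padded through the leading pad row:

import tensorflow as tf

data = tf.constant([[1.0], [2.0], [3.0], [4.0]])       # [N=4, F=1]
segment_ids = tf.constant([0, 0, 1, 0], dtype=tf.int32)
out = per_voxel_point_sample_segment_func(
    data, segment_ids, num_segments=2, num_samples_per_voxel=3)
# Voxel 0 collects [4., 2., 1.] (highest point index first); voxel 1 gets
# [3., 0., 0.].
print(out.numpy())  # shape [2, 3, 1]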
Example #12
    def __init__(self,
                 env_spec,
                 callbacks=None,
                 model_class=FullyConvModel,
                 optimizer=tf.train.AdamOptimizer,
                 learning_rate=0.0001,
                 discount=0.99,
                 trajectory_length=16,
                 batch_size=32,
                 max_grads_norm=100,
                 policy_factor=1,
                 entropy_factor=0.0001,
                 value_factor=0.5):
        self.callbacks = callbacks
        self.discount = discount
        self.policy_factor = policy_factor
        self.entropy_factor = entropy_factor
        self.value_factor = value_factor

        self.input_observations = {
            name: Input(shape=spec.shape, name='input_{}'.format(name))
            for name, spec in env_spec.observation_spec.items()
        }
        self.input_actions = {
            name: Input(shape=(),
                        name='input_arg_{}_value'.format(name),
                        dtype='int32')
            for name in env_spec.action_spec
        }
        self.input_returns = Input(shape=(), name='input_returns')

        self.function_args_mask = tf.constant(
            env_spec.action_spec['function_id'].args_mask,
            dtype=tf.float32,
            name='function_args_mask')

        self.model = model_class(self.input_observations, env_spec)

        self.loss = self.build_loss()

        self.optimizer = optimizer(learning_rate=learning_rate)
        grads, vars = zip(*self.optimizer.compute_gradients(self.loss))
        grads_norm = tf.global_norm(grads)
        if max_grads_norm > 0:
            grads, _ = tf.clip_by_global_norm(grads, max_grads_norm,
                                              grads_norm)
        self.train_op = self.optimizer.apply_gradients(
            zip(grads, vars), global_step=tf.train.get_or_create_global_step())

        self.history = History(trajectory_length, batch_size, env_spec)

        tf.summary.scalar('learning_rate', learning_rate)
        tf.summary.scalar('total_loss', self.loss, family='losses')
        tf.summary.scalar('grads_norm', grads_norm)
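
The gradient-clipping step above in isolation (a small sketch assuming TF1-style APIs): gradients are rescaled so their global norm does not exceed the cap, and the precomputed norm is passed as use_norm in the constructor above to avoid recomputing it.

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

grads = [tf.constant([3.0, 4.0])]        # global norm = 5.0
clipped, norm = tf.clip_by_global_norm(grads, clip_norm=2.5)
with tf.Session() as sess:
    print(sess.run([clipped[0], norm]))  # [1.5, 2.0] and 5.0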
Example #13
def reset_rewards(states,
                  actions,
                  rewards,
                  next_states,
                  contexts,
                  reset_index=0,
                  reset_state=None,
                  reset_reward_function=None,
                  include_forward_rewards=True,
                  include_reset_rewards=True):
  """Returns the rewards for a forward/reset agent.

  Args:
    states: A [batch_size, num_state_dims] Tensor representing a batch
        of states.
    actions: A [batch_size, num_action_dims] Tensor representing a batch
      of actions.
    rewards: A [batch_size] Tensor representing a batch of rewards.
    next_states: A [batch_size, num_state_dims] Tensor representing a batch
      of next states.
    contexts: A list of [batch_size, num_context_dims] Tensors representing
      a batch of contexts.
    reset_index: (integer) The context list index that specifies reset.
    reset_state: Reset state.
    reset_reward_function: Reward function for reset step.
    include_forward_rewards: Include the rewards from the forward pass.
    include_reset_rewards: Include the rewards from the reset pass.

  Returns:
    A new tf.float32 [batch_size] rewards Tensor, and
      tf.float32 [batch_size] discounts tensor.
  """
  reset_state = tf.constant(
      reset_state, dtype=next_states.dtype, shape=next_states.shape)
  reset_states = tf.expand_dims(reset_state, 0)

  def true_fn():
    if include_reset_rewards:
      return reset_reward_function(states, actions, rewards, next_states,
                                   [reset_states] + contexts[1:])
    else:
      return tf.zeros_like(rewards), tf.ones_like(rewards)

  def false_fn():
    if include_forward_rewards:
      return plain_rewards(states, actions, rewards, next_states, contexts)
    else:
      return tf.zeros_like(rewards), tf.ones_like(rewards)

  rewards, discounts = tf.cond(
      tf.cast(contexts[reset_index][0, 0], dtype=tf.bool), true_fn, false_fn)
  return rewards, discounts
Example #14
def box_corner_distance_loss_on_object_tensors(inputs,
                                               outputs,
                                               loss_type,
                                               delta=1.0,
                                               is_balanced=False):
    """Computes regression loss on object corner locations using object tensors.

  Args:
    inputs: A dictionary of tf.Tensors with our input data.
    outputs: A dictionary of tf.Tensors with the network output.
    loss_type: Loss type.
    delta: float, the point where the Huber loss function changes from
      quadratic to linear.
    is_balanced: If True, the per-voxel losses are re-weighted to have equal
      total weight for each object instance.

  Returns:
    localization_loss: A tf.float32 scalar corresponding to localization loss.
  """
    def fn(inputs_1, outputs_1):
        return _box_corner_distance_loss_on_object_tensors(
            inputs=inputs_1,
            outputs=outputs_1,
            loss_type=loss_type,
            delta=delta,
            is_balanced=is_balanced)

    batch_size = len(inputs[standard_fields.InputDataFields.objects_length])
    losses = []
    for b in range(batch_size):
        inputs_1 = batch_utils.get_batch_size_1_input_objects(inputs=inputs,
                                                              b=b)
        outputs_1 = batch_utils.get_batch_size_1_output_objects(
            outputs=outputs, b=b)
        cond_input = tf.greater(
            tf.shape(
                inputs_1[standard_fields.InputDataFields.objects_length])[0],
            0)
        cond_output = tf.greater(
            tf.shape(outputs_1[
                standard_fields.DetectionResultFields.objects_length])[0], 0)
        cond = tf.logical_and(cond_input, cond_output)
        # pylint: disable=cell-var-from-loop
        loss = tf.cond(cond,
                       lambda: fn(inputs_1=inputs_1, outputs_1=outputs_1),
                       lambda: tf.constant(0.0, dtype=tf.float32))
        # pylint: enable=cell-var-from-loop
        losses.append(loss)
    return tf.reduce_mean(tf.stack(losses))
Example #15
def index_states(states, indices):
  """Return indexed states.

  Args:
    states: A [batch_size, num_state_dims] Tensor representing a batch
        of states.
    indices: (a list or Numpy integer array) Indices of state dimensions
      to be mapped.
  Returns:
    A [batch_size, num_indices] Tensor representing the batch of indexed states.
  """
  if indices is None:
    return states
  indices = tf.constant(indices, dtype=tf.int32)
  return tf.gather(states, indices=indices, axis=1)
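
A toy eager example with hypothetical values: keep state dimensions 0 and 2 of each batch row.

import tensorflow as tf

states = tf.constant([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])
print(index_states(states, [0, 2]).numpy())  # [[0., 2.], [3., 5.]]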
Example #16
def _box_corner_distance_loss_on_object_tensors(inputs, outputs, loss_type,
                                                delta, is_balanced):
    """Computes huber loss on object corner locations."""
    valid_mask_class = tf.greater(
        tf.reshape(inputs[standard_fields.InputDataFields.objects_class],
                   [-1]), 0)
    valid_mask_instance = tf.greater(
        tf.reshape(inputs[standard_fields.InputDataFields.objects_instance_id],
                   [-1]), 0)
    valid_mask = tf.logical_and(valid_mask_class, valid_mask_instance)

    def fn():
        for field in standard_fields.get_input_object_fields():
            if field in inputs:
                inputs[field] = tf.boolean_mask(inputs[field], valid_mask)
        for field in standard_fields.get_output_object_fields():
            if field in outputs:
                outputs[field] = tf.boolean_mask(outputs[field], valid_mask)
        return _box_corner_distance_loss(
            loss_type=loss_type,
            is_balanced=is_balanced,
            input_boxes_length=inputs[
                standard_fields.InputDataFields.objects_length],
            input_boxes_height=inputs[
                standard_fields.InputDataFields.objects_height],
            input_boxes_width=inputs[
                standard_fields.InputDataFields.objects_width],
            input_boxes_center=inputs[
                standard_fields.InputDataFields.objects_center],
            input_boxes_rotation_matrix=inputs[
                standard_fields.InputDataFields.objects_rotation_matrix],
            input_boxes_instance_id=inputs[
                standard_fields.InputDataFields.objects_instance_id],
            output_boxes_length=outputs[
                standard_fields.DetectionResultFields.objects_length],
            output_boxes_height=outputs[
                standard_fields.DetectionResultFields.objects_height],
            output_boxes_width=outputs[
                standard_fields.DetectionResultFields.objects_width],
            output_boxes_center=outputs[
                standard_fields.DetectionResultFields.objects_center],
            output_boxes_rotation_matrix=outputs[
                standard_fields.DetectionResultFields.objects_rotation_matrix],
            delta=delta)

    return tf.cond(tf.reduce_any(valid_mask), fn,
                   lambda: tf.constant(0.0, dtype=tf.float32))
Example #17
    def _build_train_op(self):
        """Builds a training op.

    Returns:
      An op performing one step of training from replay data.
    """
        # click_indicator: [B, S]
        # q_values: [B, A]
        # actions: [B, S]
        # slate_q_values: [B, S]
        # replay_click_q: [B]
        click_indicator = self._replay.rewards[:, :,
                                               self._click_response_index]
        slate_q_values = tf.compat.v1.batch_gather(
            self._replay_net_outputs.q_values,
            tf.cast(self._replay.actions, dtype=tf.int32))
        # Only get the Q from the clicked document.
        replay_click_q = tf.reduce_sum(input_tensor=slate_q_values *
                                       click_indicator,
                                       axis=1,
                                       name='replay_click_q')

        target = tf.stop_gradient(self._build_target_q_op())

        clicked = tf.reduce_sum(input_tensor=click_indicator, axis=1)
        clicked_indices = tf.squeeze(tf.compat.v1.where(tf.equal(clicked, 1)),
                                     axis=1)
        # clicked_indices is a vector and tf.gather selects the batch dimension.
        q_clicked = tf.gather(replay_click_q, clicked_indices)
        target_clicked = tf.gather(target, clicked_indices)

        def get_train_op():
            loss = tf.reduce_mean(input_tensor=tf.square(q_clicked -
                                                         target_clicked))
            if self.summary_writer is not None:
                with tf.compat.v1.variable_scope('Losses'):
                    tf.compat.v1.summary.scalar('Loss', loss)

            return loss

        loss = tf.cond(pred=tf.greater(tf.reduce_sum(input_tensor=clicked), 0),
                       true_fn=get_train_op,
                       false_fn=lambda: tf.constant(0.),
                       name='')

        return self.optimizer.minimize(loss)
Example #18
def process_example(example_string, image_size, data_augmentation=None):
    """Processes a single example string.

  Extracts and processes the image, and ignores the label. We assume that the
  image has three channels.

  Args:
    example_string: str, an Example protocol buffer.
    image_size: int, desired image size. The extracted image will be resized to
      `[image_size, image_size]`.
    data_augmentation: A DataAugmentation object with parameters for perturbing
      the images.

  Returns:
    image_rescaled: the image, resized to `image_size x image_size` and rescaled
      to [-1, 1]. Note that Gaussian data augmentation may cause values to
      go beyond this range.
  """
    image_string = tf.parse_single_example(example_string,
                                           features={
                                               'image':
                                               tf.FixedLenFeature(
                                                   [], dtype=tf.string),
                                               'label':
                                               tf.FixedLenFeature([], tf.int64)
                                           })['image']
    image_decoded = tf.image.decode_jpeg(image_string, channels=3)
    image_resized = tf.image.resize_images(
        image_decoded, [image_size, image_size],
        method=tf.image.ResizeMethod.BILINEAR,
        align_corners=True)
    image = 2 * (image_resized / 255.0 - 0.5)  # Rescale to [-1, 1].

    if data_augmentation is not None:
        if data_augmentation.enable_gaussian_noise:
            image = image + tf.random_normal(
                tf.shape(image)) * data_augmentation.gaussian_noise_std

        if data_augmentation.enable_jitter:
            j = data_augmentation.jitter_amount
            paddings = tf.constant([[j, j], [j, j], [0, 0]])
            image = tf.pad(image, paddings, 'REFLECT')
            image = tf.image.random_crop(image, [image_size, image_size, 3])

    return image
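
A hypothetical end-to-end check (assumes TF1-style APIs, e.g. import tensorflow.compat.v1 as tf): build a tiny Example proto holding a JPEG-encoded image and push it through the pipeline above.

import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

rgb = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
with tf.Session() as sess:
    jpeg = sess.run(tf.image.encode_jpeg(tf.constant(rgb)))
example = tf.train.Example(features=tf.train.Features(feature={
    'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[jpeg])),
    'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[0])),
}))
image_op = process_example(tf.constant(example.SerializeToString()),
                           image_size=84)
with tf.Session() as sess:
    image = sess.run(image_op)
print(image.shape, image.min(), image.max())  # (84, 84, 3), within [-1, 1]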
Example #19
def _box_size_regression_loss_on_voxel_tensors_unbatched(
        inputs_1, outputs_1, loss_type, delta, is_balanced, is_intermediate):
    """Computes regression loss on predicted object size for each voxel."""
    inputs_1, outputs_1, valid_mask = _get_voxels_valid_inputs_outputs(
        inputs_1=inputs_1, outputs_1=outputs_1)

    def loss_fn_unbatched():
        """Loss function."""
        if is_intermediate:
            output_boxes_length = outputs_1[
                standard_fields.DetectionResultFields.
                intermediate_object_length_voxels]
            output_boxes_height = outputs_1[
                standard_fields.DetectionResultFields.
                intermediate_object_height_voxels]
            output_boxes_width = outputs_1[
                standard_fields.DetectionResultFields.
                intermediate_object_width_voxels]
        else:
            output_boxes_length = outputs_1[
                standard_fields.DetectionResultFields.object_length_voxels]
            output_boxes_height = outputs_1[
                standard_fields.DetectionResultFields.object_height_voxels]
            output_boxes_width = outputs_1[
                standard_fields.DetectionResultFields.object_width_voxels]
        return _box_size_regression_loss(
            loss_type=loss_type,
            is_balanced=is_balanced,
            input_boxes_length=inputs_1[
                standard_fields.InputDataFields.object_length_voxels],
            input_boxes_height=inputs_1[
                standard_fields.InputDataFields.object_height_voxels],
            input_boxes_width=inputs_1[
                standard_fields.InputDataFields.object_width_voxels],
            input_boxes_instance_id=inputs_1[
                standard_fields.InputDataFields.object_instance_id_voxels],
            output_boxes_length=output_boxes_length,
            output_boxes_height=output_boxes_height,
            output_boxes_width=output_boxes_width,
            delta=delta)

    return tf.cond(tf.reduce_any(valid_mask), loss_fn_unbatched,
                   lambda: tf.constant(0.0, dtype=tf.float32))
Example #20
def _box_size_regression_loss(loss_type, is_balanced, input_boxes_length,
                              input_boxes_height, input_boxes_width,
                              input_boxes_instance_id, output_boxes_length,
                              output_boxes_height, output_boxes_width, delta):
  """Computes regression loss on object sizes."""

  def fn():
    """Loss function for when number of input and output boxes is positive."""
    if is_balanced:
      weights = loss_utils.get_balanced_loss_weights_multiclass(
          labels=input_boxes_instance_id)
    else:
      weights = tf.ones([tf.shape(input_boxes_instance_id)[0], 1],
                        dtype=tf.float32)
    gt_length = tf.reshape(input_boxes_length, [-1, 1])
    gt_height = tf.reshape(input_boxes_height, [-1, 1])
    gt_width = tf.reshape(input_boxes_width, [-1, 1])
    predicted_length = tf.reshape(output_boxes_length, [-1, 1])
    predicted_height = tf.reshape(output_boxes_height, [-1, 1])
    predicted_width = tf.reshape(output_boxes_width, [-1, 1])
    predicted_length /= gt_length
    predicted_height /= gt_height
    predicted_width /= gt_width
    predicted_size = tf.concat(
        [predicted_length, predicted_height, predicted_width], axis=1)
    gt_size = tf.ones_like(predicted_size)
    if loss_type == 'huber':
      loss_fn = tf.keras.losses.Huber(
          delta=delta, reduction=tf.keras.losses.Reduction.NONE)
    elif loss_type == 'absolute_difference':
      loss_fn = tf.keras.losses.MeanAbsoluteError(
          reduction=tf.keras.losses.Reduction.NONE)
    else:
      raise ValueError('Unknown loss type %s.' % loss_type)
    size_losses = loss_fn(y_true=gt_size, y_pred=predicted_size)
    return tf.reduce_mean(size_losses * tf.reshape(weights, [-1]))

  cond_input = tf.greater(tf.shape(input_boxes_length)[0], 0)
  cond_output = tf.greater(tf.shape(output_boxes_length)[0], 0)
  cond = tf.logical_and(cond_input, cond_output)
  return tf.cond(cond, fn, lambda: tf.constant(0.0, dtype=tf.float32))
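
For intuition about the size parameterization above: predictions are divided by the ground truth, so a perfect prediction maps to 1.0 in every dimension and the Huber/L1 loss is measured on that ratio. A toy eager check with made-up numbers:

import tensorflow as tf

gt_length = tf.constant([[2.0]])
predicted_length = tf.constant([[2.5]])
ratio = predicted_length / gt_length   # 1.25
print(tf.abs(ratio - 1.0).numpy())     # [[0.25]] -- the L1 residual vs 1.0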
Example #21
def _voxel_hard_negative_classification_loss_unbatched(inputs_1, outputs_1,
                                                       is_intermediate, gamma):
    """Loss function for input and outputs of batch size 1."""
    inputs_1, outputs_1 = _get_voxels_valid_inputs_outputs(inputs_1=inputs_1,
                                                           outputs_1=outputs_1)
    if is_intermediate:
        logits = outputs_1[standard_fields.DetectionResultFields.
                           intermediate_object_semantic_voxels]
    else:
        logits = outputs_1[
            standard_fields.DetectionResultFields.object_semantic_voxels]
    labels = tf.reshape(
        inputs_1[standard_fields.InputDataFields.object_class_voxels], [-1])
    background_mask = tf.equal(labels, 0)
    num_background_points = tf.reduce_sum(
        tf.cast(background_mask, dtype=tf.int32))

    def loss_fn():
        """Loss function."""
        num_classes = logits.get_shape().as_list()[-1]
        if num_classes is None:
            raise ValueError('Number of classes is unknown.')
        masked_logits = tf.boolean_mask(logits, background_mask)
        masked_weights = tf.pow(
            1.0 - tf.reshape(tf.nn.softmax(masked_logits)[:, 0], [-1, 1]),
            gamma)
        num_points = tf.shape(masked_logits)[0]
        masked_weights = masked_weights * tf.cast(
            num_points, dtype=tf.float32) / tf.reduce_sum(masked_weights)
        masked_labels_one_hot = tf.one_hot(indices=tf.boolean_mask(
            labels, background_mask),
                                           depth=num_classes)
        loss = classification_loss_fn(logits=masked_logits,
                                      labels=masked_labels_one_hot,
                                      weights=masked_weights)
        return loss

    cond = tf.logical_and(tf.greater(num_background_points, 0),
                          tf.greater(tf.shape(labels)[0], 0))
    return tf.cond(cond, loss_fn, lambda: tf.constant(0.0, dtype=tf.float32))
Example #22
def train_q(dataset,
            policy,
            optimizer=None,
            pack_transition_fn=None,
            q_graph_fn=None,
            log_dir=None,
            master='',
            task=0,
            training_steps=None,
            max_training_steps=100000,
            reuse=False,
            init_checkpoint=None,
            update_target_every_n_steps=50,
            log_every_n_steps=None,
            save_checkpoint_steps=500,
            save_summaries_steps=500):
    """Self-contained learning loop for offline Q-learning.

  Code inspired by OpenAI Baselines' deepq.build_train. This function is
  compatible with discrete Q-learning graphs, continuous Q learning graphs, and
  SARSA.

  Args:
    dataset: tf.data.Dataset providing transitions.
    policy: Instance of TFDQNPolicy class that provides functor for building the
      critic function.
    optimizer: Optional instance of an optimizer. If not specified, creates an
      AdamOptimizer using the default constructor.
    pack_transition_fn: Optional function that performs additional processing
      of the transition. This is a convenience method for ad-hoc manipulation of
      transition data passed to the learning function after parsing.
    q_graph_fn: Function used to construct training objectives w.r.t. critic
      outputs.
    log_dir: Where to save model checkpoints and tensorboard summaries.
    master: Optional address of master worker. Specify this when doing
      distributed training.
    task: Optional worker task for distributed training. Defaults to solo master
      task on a single machine.
    training_steps: Optional number of steps to run training before
      terminating early. `max_training_steps` still applies: training will
      terminate after `max_training_steps` whether or not `training_steps` is
      specified.
    max_training_steps: Maximum number of training iterations.
    reuse: If True, reuse existing variables for all declared variables by this
      function.
    init_checkpoint: Optional checkpoint to restore prior to training. If not
      provided, variables are initialized using global_variables_initializer().
    update_target_every_n_steps: How many global steps (training) between
      copying the Q network weights (scope='q_func') to target network
      (scope='target_q_func').
    log_every_n_steps: How many global steps between logging loss tensors.
    save_checkpoint_steps: How many global steps between saving TF variables
      to a checkpoint file.
    save_summaries_steps: How many global steps between saving TF summaries.

  Returns:
    A tuple (global_step, done): the current `global_step` reached after
    training, and a bool that is True once `global_step` has reached
    `max_training_steps`.

  Raises:
    ValueError: If a batch of transitions is empty or its zeroth element is
      empty when it is supposed to be of length batch_size.
  """
    data_iterator = dataset.make_one_shot_iterator()

    transition = data_iterator.get_next()
    if pack_transition_fn:
        transition = pack_transition_fn(transition)

    if optimizer is None:
        optimizer = tf.train.AdamOptimizer()

    q_func = policy.get_q_func(is_training=True, reuse=reuse)
    loss, all_summaries = q_graph_fn(q_func, transition)

    q_func_vars = contrib_framework.get_trainable_variables(scope='q_func')
    target_q_func_vars = contrib_framework.get_trainable_variables(
        scope='target_q_func')
    global_step = tf.train.get_or_create_global_step()

    # Only optimize q_func and update its batchnorm params.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='q_func')
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss,
                                      global_step=global_step,
                                      var_list=q_func_vars)

    chief_hooks = []
    hooks = []
    # Save summaries periodically.
    if save_summaries_steps is not None:
        chief_hooks.append(
            tf.train.SummarySaverHook(save_steps=save_summaries_steps,
                                      output_dir=log_dir,
                                      summary_op=all_summaries))

    # Stop after training_steps
    if max_training_steps:
        hooks.append(tf.train.StopAtStepHook(last_step=max_training_steps))

    # Report if loss tensor is NaN.
    hooks.append(tf.train.NanTensorHook(loss))

    if log_every_n_steps is not None:
        tensor_dict = {'global_step': global_step, 'train loss': loss}
        chief_hooks.append(
            tf.train.LoggingTensorHook(tensor_dict,
                                       every_n_iter=log_every_n_steps))

        # Measure how fast we are training per sec and save to summary.
        chief_hooks.append(
            tf.train.StepCounterHook(every_n_steps=log_every_n_steps,
                                     output_dir=log_dir))

    # If a target network exists, periodically copy the Q network weights
    # (scope='q_func') into it (frozen target network). We hack this by
    # abusing a LoggingTensorHook.
    if target_q_func_vars and update_target_every_n_steps is not None:
        update_target_expr = []
        for var, var_t in zip(sorted(q_func_vars, key=lambda v: v.name),
                              sorted(target_q_func_vars,
                                     key=lambda v: v.name)):
            update_target_expr.append(var_t.assign(var))
        update_target_expr = tf.group(*update_target_expr)

        with tf.control_dependencies([update_target_expr]):
            update_target = tf.constant(0)
        chief_hooks.append(
            tf.train.LoggingTensorHook(
                {'update_target': update_target},
                every_n_iter=update_target_every_n_steps))

    # Save checkpoints periodically, save all of them.
    saver = tf.train.Saver(max_to_keep=None)
    chief_hooks.append(
        tf.train.CheckpointSaverHook(log_dir,
                                     save_steps=save_checkpoint_steps,
                                     saver=saver,
                                     checkpoint_basename='model.ckpt'))

    # Save our experiment params to checkpoint dir.
    chief_hooks.append(
        gin.tf.GinConfigSaverHook(log_dir, summarize_config=True))

    session_config = tf.ConfigProto(log_device_placement=False)

    init_fn = None
    if init_checkpoint:
        assign_fn = contrib_framework.assign_from_checkpoint_fn(
            init_checkpoint, contrib_framework.get_model_variables())
        init_fn = lambda _, sess: assign_fn(sess)
    scaffold = tf.train.Scaffold(saver=saver, init_fn=init_fn)
    with tf.train.MonitoredTrainingSession(
            master=master,
            is_chief=(task == 0),
            config=session_config,
            checkpoint_dir=log_dir,
            scaffold=scaffold,
            hooks=hooks,
            chief_only_hooks=chief_hooks) as sess:
        np_step = 0
        while not sess.should_stop():
            np_step, _ = sess.run([global_step, train_op])
            if training_steps and np_step % training_steps == 0:
                break
        done = np_step >= max_training_steps
    return np_step, done
Example #23
 def filter_fn(e):
   return tf.math.reduce_all(
       tf.math.not_equal(e['tfds_id'], tf.constant(valid_tfds_ids_np)))
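
A hypothetical use (dataset contents are illustrative): valid_tfds_ids_np would be a 1-D numpy array of tfds_id byte strings to exclude, and each element e a feature dict with a scalar 'tfds_id', so filtering keeps only examples whose id matches none of them.

import numpy as np
import tensorflow as tf  # assumes TF2 eager execution

valid_tfds_ids_np = np.array([b'id_1'])
ds = tf.data.Dataset.from_tensor_slices({'tfds_id': [b'id_0', b'id_1']})
for e in ds.filter(filter_fn):
    print(e['tfds_id'].numpy())  # only b'id_0' survives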
Example #24
def _box_classification_using_center_distance_loss_unbatched(
        inputs_1, outputs_1, is_intermediate, is_balanced,
        max_positive_normalized_distance):
    """Loss function for input and outputs of batch size 1."""
    inputs_1, outputs_1 = _get_voxels_valid_inputs_outputs(inputs_1=inputs_1,
                                                           outputs_1=outputs_1)
    if is_intermediate:
        output_object_centers = outputs_1[standard_fields.DetectionResultFields
                                          .intermediate_object_center_voxels]
        output_object_length = outputs_1[standard_fields.DetectionResultFields.
                                         intermediate_object_length_voxels]
        output_object_height = outputs_1[standard_fields.DetectionResultFields.
                                         intermediate_object_height_voxels]
        output_object_width = outputs_1[standard_fields.DetectionResultFields.
                                        intermediate_object_width_voxels]
        output_object_rotation_matrix = outputs_1[
            standard_fields.DetectionResultFields.
            intermediate_object_rotation_matrix_voxels]
        logits = outputs_1[standard_fields.DetectionResultFields.
                           intermediate_object_semantic_voxels]
    else:
        output_object_centers = outputs_1[
            standard_fields.DetectionResultFields.object_center_voxels]
        output_object_length = outputs_1[
            standard_fields.DetectionResultFields.object_length_voxels]
        output_object_height = outputs_1[
            standard_fields.DetectionResultFields.object_height_voxels]
        output_object_width = outputs_1[
            standard_fields.DetectionResultFields.object_width_voxels]
        output_object_rotation_matrix = outputs_1[
            standard_fields.DetectionResultFields.
            object_rotation_matrix_voxels]
        logits = outputs_1[
            standard_fields.DetectionResultFields.object_semantic_voxels]
    normalized_center_distance = loss_utils.get_normalized_corner_distances(
        predicted_boxes_center=output_object_centers,
        predicted_boxes_length=output_object_length,
        predicted_boxes_height=output_object_height,
        predicted_boxes_width=output_object_width,
        predicted_boxes_rotation_matrix=output_object_rotation_matrix,
        gt_boxes_center=inputs_1[
            standard_fields.InputDataFields.object_center_voxels],
        gt_boxes_length=inputs_1[
            standard_fields.InputDataFields.object_length_voxels],
        gt_boxes_height=inputs_1[
            standard_fields.InputDataFields.object_height_voxels],
        gt_boxes_width=inputs_1[
            standard_fields.InputDataFields.object_width_voxels],
        gt_boxes_rotation_matrix=inputs_1[
            standard_fields.InputDataFields.object_rotation_matrix_voxels])
    labels = tf.reshape(
        inputs_1[standard_fields.InputDataFields.object_class_voxels], [-1])
    instances = tf.reshape(
        inputs_1[standard_fields.InputDataFields.object_instance_id_voxels],
        [-1])
    params = {}
    if is_balanced:
        weights = loss_utils.get_balanced_loss_weights_multiclass(
            labels=tf.expand_dims(instances, axis=1))
        params['weights'] = weights

    def loss_fn():
        """Loss function."""
        num_classes = logits.get_shape().as_list()[-1]
        if num_classes is None:
            raise ValueError('Number of classes is unknown.')
        labels_one_hot = tf.one_hot(indices=(labels - 1),
                                    depth=(num_classes - 1))
        inverse_distance_coef = tf.maximum(
            tf.minimum(
                1.0 -
                normalized_center_distance / max_positive_normalized_distance,
                1.0), 0.0)
        labels_one_hot = tf.reshape(inverse_distance_coef,
                                    [-1, 1]) * labels_one_hot
        background_label = 1.0 - tf.math.reduce_sum(
            labels_one_hot, axis=1, keepdims=True)
        labels_one_hot = tf.concat([background_label, labels_one_hot], axis=1)
        loss = classification_loss_fn(logits=logits,
                                      labels=labels_one_hot,
                                      **params)
        return loss

    return tf.cond(tf.greater(tf.shape(labels)[0], 0), loss_fn,
                   lambda: tf.constant(0.0, dtype=tf.float32))
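
The soft-label weighting above in isolation: the target weight decays linearly from 1 to 0 as the normalized center distance grows, clamped to [0, 1]. A toy eager check:

import tensorflow as tf

d = tf.constant([0.0, 0.5, 1.0, 2.0])  # normalized center distances
max_d = 1.0                            # max_positive_normalized_distance
coef = tf.maximum(tf.minimum(1.0 - d / max_d, 1.0), 0.0)
print(coef.numpy())                    # [1.  0.5 0.  0. ]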
Example #25
def npair_loss_func(embeddings,
                    instance_ids,
                    num_samples,
                    valid_mask=None,
                    max_instance_id=None,
                    similarity_strategy='dotproduct',
                    loss_strategy='softmax'):
  """N-pair metric learning loss for learning feature embeddings.

  Args:
    embeddings: A tf.float32 tensor of size [batch_size, n, f].
    instance_ids: A tf.int32 tensor of size [batch_size, n].
    num_samples: An int determining the number of samples.
    valid_mask: A tf.bool tensor of size [batch_size, n] that is True when an
      element is valid and False if it needs to be ignored. By default the value
      is None which means it is not applied.
    max_instance_id: If set, instance ids larger than that value will be
      ignored. If not set, it will be computed from instance_ids tensor.
    similarity_strategy: Defines the method for computing similarity between
      embedding vectors. Possible values are 'dotproduct' and 'distance'.
    loss_strategy: Defines the type of loss including 'softmax' or 'sigmoid'.

  Returns:
    A tf.float32 scalar loss tensor.
  """
  batch_size = embeddings.get_shape().as_list()[0]
  if batch_size is None:
    raise ValueError('Unknown batch size at graph construction time.')
  if max_instance_id is None:
    max_instance_id = tf.reduce_max(instance_ids)
  sampled_embeddings, sampled_instance_ids, _ = sampling_utils.balanced_sample(
      features=embeddings,
      instance_ids=instance_ids,
      num_samples=num_samples,
      valid_mask=valid_mask,
      max_instance_id=max_instance_id)
  losses = []
  for i in range(batch_size):
    sampled_instance_ids_i = sampled_instance_ids[i, :]
    sampled_embeddings_i = sampled_embeddings[i, :, :]
    min_ids_i = tf.math.reduce_min(sampled_instance_ids_i)
    max_ids_i = tf.math.reduce_max(sampled_instance_ids_i)
    target_i = tf.one_hot(
        sampled_instance_ids_i,
        depth=(max_instance_id + 1),
        dtype=tf.float32)

    # pylint: disable=cell-var-from-loop
    def npair_loss_i():
      return metric_learning_losses.npair_loss(
          embedding=sampled_embeddings_i,
          target=target_i,
          similarity_strategy=similarity_strategy,
          loss_strategy=loss_strategy)
    # pylint: enable=cell-var-from-loop

    loss_i = tf.cond(
        max_ids_i > min_ids_i, npair_loss_i,
        lambda: tf.constant(0.0, dtype=tf.float32))
    losses.append(loss_i)
  return tf.math.reduce_mean(losses)
Example #26
    def my_model_fn(features, labels, mode, params=None, config=None):
        """Estimator model function.

    Args:
      features: input features dictionary
      labels: ignored
      mode: a tf.estimator.ModeKeys
      params: a dict of parameters passed by the Estimator; when running on
        TPU it contains the TPU `context` object
      config: a RunConfig (unused)

    Returns:
      a TPUEstimatorSpec
    """
        del labels, config
        global_step = tf.train.get_global_step()
        if use_tpu:
            ctx = params["context"]
            num_hosts = ctx.num_hosts
            host_placement_fn = ctx.tpu_host_placement_function
            device_list = [
                host_placement_fn(host_id=t) for t in range(num_hosts)
            ]
            # TODO(ylc): Better estimation of replica cache size?
            replica_cache_size = 300 * 1000000  # 300M per replica
            # Worker 0 caches all the TPU binaries.
            worker0_mem = replica_cache_size * ctx.num_replicas
            devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
            var_placer = mtf.utils.BalancedVariablePlacer(
                device_list, devices_memory_usage)
            mesh_devices = [""] * mesh_shape.size
            physical_shape = list(
                params["context"].device_assignment.topology.mesh_shape)
            logical_to_physical = _logical_to_physical(physical_shape,
                                                       mesh_shape)
            mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                mesh_shape,
                layout_rules,
                mesh_devices,
                ctx.device_assignment,
                logical_to_physical=logical_to_physical)
        else:
            var_placer = None
            mesh_devices = [""] * mesh_shape.size
            mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
                mesh_shape, layout_rules, mesh_devices)

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh", var_placer)

        def _import_feature(key, allow_missing=False):
            """Import a feature from the features dictionary into a mtf.Tensor.

      Args:
        key: a string
        allow_missing: a boolean

      Returns:
        a mtf.Tensor with dtype int32 and shape [batch_dim, length_dim]
      """
            outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
            batch_dim = mtf.Dimension("batch", batch_size // outer_batch_size)
            length_dim = mtf.Dimension("length", sequence_length)

            mtf_shape = mtf.Shape([outer_batch_dim, batch_dim, length_dim])
            if key not in features:
                if allow_missing:
                    return None
                else:
                    raise ValueError("feature not found %s - features %s = " %
                                     (key, features))
            tf.logging.info("Import feature %s: %s" % (key, features[key]))

            x = tf.to_int32(features[key])
            x = tf.reshape(
                x, [outer_batch_size, batch_size // outer_batch_size, -1])

            if not use_tpu:
                x = tf.Print(x, [x],
                             "import feature %s" % key,
                             summarize=1000,
                             first_n=1)
            return mtf.import_fully_replicated(mesh, x, mtf_shape, name=key)

        if mode == tf.estimator.ModeKeys.PREDICT:
            inputs = _import_feature("inputs")
            inputs = mtf.reshape(
                inputs,
                mtf.Shape([
                    mtf.Dimension("batch", batch_size),
                    mtf.Dimension("length", sequence_length)
                ]))
            if isinstance(transformer_model, transformer.Unitransformer):
                mtf_samples = transformer_model.sample_autoregressive(
                    inputs, variable_dtype=get_variable_dtype())
            elif isinstance(transformer_model, transformer.Bitransformer):
                mtf_samples = transformer_model.decode(
                    inputs, variable_dtype=get_variable_dtype())
            else:
                raise ValueError("unrecognized class")
            mtf_samples = mtf.anonymize(mtf_samples)
            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)
            outputs = lowering.export_to_tf_tensor(mtf_samples)
            predictions = {"outputs": outputs}
            return tpu_estimator.TPUEstimatorSpec(
                mode=tf.estimator.ModeKeys.PREDICT,
                predictions=predictions,
                prediction_hooks=[mtf.MtfRestoreHook(lowering)])

        targets = _import_feature("targets")
        anon_targets = mtf.anonymize(targets)
        if model_type == "lm":
            _, length_dim = targets.shape
            inputs = mtf.shift(targets, offset=1, dim=length_dim, wrap=False)
        else:
            inputs = _import_feature("inputs")

        if mode == tf.estimator.ModeKeys.EVAL:
            if isinstance(transformer_model, transformer.Unitransformer):
                mtf_samples = transformer_model.sample_autoregressive(
                    inputs, variable_dtype=get_variable_dtype())
            elif isinstance(transformer_model, transformer.Bitransformer):
                mtf_samples = transformer_model.decode(
                    inputs, variable_dtype=get_variable_dtype())
            else:
                raise ValueError("unrecognized class")
            mtf_samples = mtf.anonymize(mtf_samples)
            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)
            outputs = lowering.export_to_tf_tensor(mtf_samples)
            labels = lowering.export_to_tf_tensor(anon_targets)
            restore_hook = mtf.MtfRestoreHook(lowering)

            # metric_names becomes locally scoped if we simply assign
            # ["padded_neg_log_perplexity"] to it conditioned on if it's None.
            local_metric_names = metric_names or ["token_accuracy"]

            def metric_fn(labels, outputs):
                return get_metric_fns(local_metric_names, labels, outputs)

            eval_metrics = (metric_fn, [labels, outputs])
            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                # Unfortunately TPUEstimatorSpec requires us to provide a value for
                # loss when in EVAL mode. Since we are sampling or decoding from the
                # model, we don't have a loss to report.
                loss=tf.constant(0.),
                evaluation_hooks=[restore_hook],
                eval_metrics=eval_metrics)

        if isinstance(transformer_model, transformer.Unitransformer):
            position_kwargs = dict(
                sequence_id=_import_feature("targets_segmentation", True),
                position=_import_feature("targets_position", True),
            )
        elif isinstance(transformer_model, transformer.Bitransformer):
            position_kwargs = dict(
                encoder_sequence_id=_import_feature("inputs_segmentation",
                                                    True),
                decoder_sequence_id=_import_feature("targets_segmentation",
                                                    True),
                encoder_position=_import_feature("inputs_position", True),
                decoder_position=_import_feature("targets_position", True),
            )
        else:
            raise ValueError("unrecognized class")

        logits, loss = transformer_model.call_simple(
            inputs=inputs,
            targets=targets,
            compute_loss=True,
            mode=mode,
            variable_dtype=get_variable_dtype(),
            **position_kwargs)

        if use_tpu and logits is not None:
            logits = mtf.anonymize(logits)

        # TRAIN mode
        if mode == tf.estimator.ModeKeys.TRAIN:
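            # Differentiate the loss with respect to every trainable variable
            # in the mtf graph, then build Adafactor update ops from the
            # resulting gradients.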
            var_grads = mtf.gradients(
                [loss], [v.outputs[0] for v in graph.trainable_variables])
            optimizer = mtf.optimize.AdafactorOptimizer(
                learning_rate=learning_rate)
            update_ops = optimizer.apply_grads(var_grads,
                                               graph.trainable_variables)

        lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
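        # Lowering (above) converts the Mesh TensorFlow graph into plain
        # TensorFlow ops, laying tensor slices out across devices according
        # to mesh_impl.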

        tf_loss = lowering.export_to_tf_tensor(loss)
        tf_loss = tf.to_float(tf_loss)
        if not use_tpu:
            tf_loss = tf.Print(tf_loss,
                               [tf_loss, tf.train.get_global_step()],
                               "tf_loss, step")

        if mode == tf.estimator.ModeKeys.TRAIN:
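            # Convert each lowered mtf update op to its tf counterpart and
            # bundle them, together with a global-step increment, into a
            # single train_op.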
            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            train_op = tf.group(tf_update_ops)

        with mtf.utils.outside_all_rewrites():
            # Copy master variables to slices. Must be called first.
            restore_hook = mtf.MtfRestoreHook(lowering)
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=checkpoints_to_keep,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                model_dir,
                save_steps=save_steps,
                saver=saver,
                listeners=[saver_listener])
            gin_config_saver_hook = gin.tf.GinConfigSaverHook(
                model_dir, summarize_config=True)

            if mode == tf.estimator.ModeKeys.TRAIN:
                if use_tpu:
                    return tpu_estimator.TPUEstimatorSpec(
                        mode=tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        training_hooks=[
                            restore_hook,
                            saver_hook,
                            gin_config_saver_hook,
                        ])
                else:
                    return tf.estimator.EstimatorSpec(
                        tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        training_chief_hooks=[
                            restore_hook,
                            saver_hook,
                            gin_config_saver_hook,
                        ])
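
The comment in the EVAL branch above refers to a general Python scoping rule:
an assignment to a name anywhere in a function makes that name local for the
whole function body. A minimal standalone sketch of the pitfall (hypothetical
names, not part of the source):

def build_eval_spec(metric_names=None):
    def metric_fn():
        # Assigning to `metric_names` here makes it local to `metric_fn`, so
        # the read on the right-hand side raises UnboundLocalError even
        # though the enclosing scope defines the name.
        metric_names = metric_names or ["token_accuracy"]
        return metric_names
    return metric_fn()

# build_eval_spec()  # would raise UnboundLocalError
# Binding the default to a fresh name, as `local_metric_names` does above,
# sidesteps the rebinding entirely.
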
Example No. 27
0
    def __call__(self, example_string):
        """Processes a single example string.

    Extracts and processes the image, and ignores the label. We assume that the
    image has three channels.

    Args:
      example_string: str, an Example protocol buffer.

    Returns:
      image_rescaled: the image, resized to `image_size x image_size` and
      rescaled to [-1, 1]. Note that Gaussian data augmentation may cause values
      to go beyond this range.
    """
        image_string = tf.parse_single_example(
            example_string,
            features={
                'image': tf.FixedLenFeature([], dtype=tf.string),
                'label': tf.FixedLenFeature([], tf.int64)
            })['image']
        image_decoded = tf.image.decode_image(image_string, channels=3)
        image_decoded.set_shape([None, None, 3])
        image_resized = tf.image.resize_images(
            image_decoded, [self.image_size, self.image_size],
            method=tf.image.ResizeMethod.BILINEAR,
            align_corners=True)
        image = tf.cast(image_resized, tf.float32)

        if self.data_augmentation is not None:
            if self.data_augmentation.enable_random_brightness:
                delta = self.data_augmentation.random_brightness_delta
                image = tf.image.random_brightness(image, delta)

            if self.data_augmentation.enable_random_saturation:
                delta = self.data_augmentation.random_saturation_delta
                image = tf.image.random_saturation(image, 1 - delta, 1 + delta)

            if self.data_augmentation.enable_random_contrast:
                delta = self.data_augmentation.random_contrast_delta
                image = tf.image.random_contrast(image, 1 - delta, 1 + delta)

            if self.data_augmentation.enable_random_hue:
                delta = self.data_augmentation.random_hue_delta
                image = tf.image.random_hue(image, delta)

            if self.data_augmentation.enable_random_flip:
                image = tf.image.random_flip_left_right(image)

        image = 2 * (image / 255.0 - 0.5)  # Rescale to [-1, 1].

        if self.data_augmentation is not None:
            if self.data_augmentation.enable_gaussian_noise:
                image = image + (tf.random_normal(tf.shape(image)) *
                                 self.data_augmentation.gaussian_noise_std)

            if self.data_augmentation.enable_jitter:
                j = self.data_augmentation.jitter_amount
                paddings = tf.constant([[j, j], [j, j], [0, 0]])
                image = tf.pad(image, paddings, 'REFLECT')
                image = tf.image.random_crop(
                    image, [self.image_size, self.image_size, 3])

        return image
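
A minimal usage sketch for the processor above, assuming it has been
instantiated as `processor` and that `images.tfrecord` is a hypothetical file
of serialized Example protos with 'image' and 'label' features:

import tensorflow.compat.v1 as tf

# Each record is parsed, decoded, resized, optionally augmented, and
# rescaled to [-1, 1] by the __call__ method above.
dataset = tf.data.TFRecordDataset(['images.tfrecord'])
images = dataset.map(processor).batch(32)
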
def _box_corner_distance_loss(
    loss_type, is_balanced, input_boxes_length, input_boxes_height,
    input_boxes_width, input_boxes_center, input_boxes_rotation_matrix,
    input_boxes_instance_id, output_boxes_length, output_boxes_height,
    output_boxes_width, output_boxes_center, output_boxes_rotation_matrix,
    delta):
  """Computes regression loss on object corner locations."""

  def fn():
    """Loss function for when number of input and output boxes is positive."""
    if is_balanced:
      weights = loss_utils.get_balanced_loss_weights_multiclass(
          labels=input_boxes_instance_id)
    else:
      weights = tf.ones([tf.shape(input_boxes_instance_id)[0], 1],
                        dtype=tf.float32)
    normalized_box_size = 5.0
    predicted_boxes_length = output_boxes_length
    predicted_boxes_height = output_boxes_height
    predicted_boxes_width = output_boxes_width
    predicted_boxes_center = output_boxes_center
    predicted_boxes_rotation_matrix = output_boxes_rotation_matrix
    gt_boxes_length = input_boxes_length
    gt_boxes_height = input_boxes_height
    gt_boxes_width = input_boxes_width
    gt_boxes_center = input_boxes_center
    gt_boxes_rotation_matrix = input_boxes_rotation_matrix
    if loss_type in ['normalized_huber', 'normalized_euclidean']:
      predicted_boxes_length /= (gt_boxes_length / normalized_box_size)
      predicted_boxes_height /= (gt_boxes_height / normalized_box_size)
      predicted_boxes_width /= (gt_boxes_width / normalized_box_size)
      gt_boxes_length = tf.ones_like(
          gt_boxes_length, dtype=tf.float32) * normalized_box_size
      gt_boxes_height = tf.ones_like(
          gt_boxes_height, dtype=tf.float32) * normalized_box_size
      gt_boxes_width = tf.ones_like(
          gt_boxes_width, dtype=tf.float32) * normalized_box_size
    gt_box_corners = box_utils.get_box_corners_3d(
        boxes_length=gt_boxes_length,
        boxes_height=gt_boxes_height,
        boxes_width=gt_boxes_width,
        boxes_rotation_matrix=gt_boxes_rotation_matrix,
        boxes_center=gt_boxes_center)
    predicted_box_corners = box_utils.get_box_corners_3d(
        boxes_length=predicted_boxes_length,
        boxes_height=predicted_boxes_height,
        boxes_width=predicted_boxes_width,
        boxes_rotation_matrix=predicted_boxes_rotation_matrix,
        boxes_center=predicted_boxes_center)
    corner_weights = tf.tile(weights, [1, 8])
    if loss_type in ['huber', 'normalized_huber']:
      loss_fn = tf.keras.losses.Huber(
          delta=delta, reduction=tf.keras.losses.Reduction.NONE)
    elif loss_type in ['normalized_absolute_difference', 'absolute_difference']:
      loss_fn = tf.keras.losses.MeanAbsoluteError(
          reduction=tf.keras.losses.Reduction.NONE)
    else:
      raise ValueError('Unknown loss type %s.' % loss_type)
    box_corner_losses = loss_fn(
        y_true=tf.reshape(gt_box_corners, [-1, 3]),
        y_pred=tf.reshape(predicted_box_corners, [-1, 3]))
    return tf.reduce_mean(box_corner_losses * tf.reshape(corner_weights, [-1]))

  cond_input = tf.greater(tf.shape(input_boxes_length)[0], 0)
  cond_output = tf.greater(tf.shape(output_boxes_length)[0], 0)
  cond = tf.logical_and(cond_input, cond_output)
  return tf.cond(cond, fn, lambda: tf.constant(0.0, dtype=tf.float32))
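
In the 'normalized' loss variants above, each box is rescaled so that its
ground-truth dimensions map to `normalized_box_size`, which makes the corner
loss track relative rather than absolute size error. A minimal numeric sketch
of that rescaling (hypothetical values, not part of the source):

import tensorflow as tf

normalized_box_size = 5.0
gt_length = tf.constant([2.0])         # ground-truth box length
predicted_length = tf.constant([2.5])  # prediction with a 25% relative error

# The ground truth maps to exactly `normalized_box_size`, while the
# prediction keeps its 25% relative error (6.25 vs. 5.0) regardless of the
# absolute box size.
predicted_scaled = predicted_length / (gt_length / normalized_box_size)
gt_scaled = tf.ones_like(gt_length) * normalized_box_size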