def get_support_set_softmax(self, logits, class_ids):
    """Softmax-normalize logits jointly over each support set.

    Images that share a class id form one support set. The normalization runs
    over axis 1 of every image in that set, independently for each entry of
    the last axis.

    softmax(x) = exp(x) / reduce_sum(exp(x), axis)

    Args:
      logits: [N_k, H*W, Q] dimensional tensor.
      class_ids: [N_k] tensor giving the support-set-id of each image.

    Returns:
      Tensor shaped like `logits`, softmax-ed over the support set.
    """
    # Both segment reductions below need the same number of segments.
    num_segments = tf.reduce_max(class_ids) + 1

    # Subtract the per-set maximum before exponentiating so exp() cannot
    # overflow.
    max_logit = tf.reduce_max(logits, axis=1, keepdims=True)
    max_logit = tf.math.unsorted_segment_max(max_logit, class_ids,
                                             num_segments)
    max_logit = tf.gather(max_logit, class_ids)
    logits_reduc = logits - max_logit

    exp_x = tf.exp(logits_reduc)
    sum_exp_x = tf.reduce_sum(exp_x, axis=1, keepdims=True)
    sum_exp_x = tf.math.unsorted_segment_sum(sum_exp_x, class_ids,
                                             num_segments)
    # `tf.math.log` exists in both TF1 and TF2; the bare `tf.log` alias does
    # not exist in TF2.
    log_sum_exp_x = tf.gather(tf.math.log(sum_exp_x), class_ids)

    # Normalize in log space, then map back to probabilities.
    norm_logits = logits_reduc - log_sum_exp_x
    return tf.exp(norm_logits)
Exemple #2
0
def randomly_crop_points(mesh_inputs,
                         view_indices_2d_inputs,
                         x_random_crop_size,
                         y_random_crop_size,
                         epsilon=1e-5):
  """Randomly crops points around a randomly chosen center point.

  Both dictionaries are modified in place: every entry is replaced by its
  masked version that keeps only the points inside the crop window.

  Args:
    mesh_inputs: A dictionary containing input mesh (point) tensors. Every
      entry is masked along its first dimension.
    view_indices_2d_inputs: A dictionary containing input point to view
      correspondence tensors. Every entry is masked along its second
      dimension.
    x_random_crop_size: Size of the random crop in x dimension. If None, random
      crop will not take place on x dimension.
    y_random_crop_size: Size of the random crop in y dimension. If None, random
      crop will not take place on y dimension.
    epsilon: Epsilon (a very small value) used to add as a small margin to
      thresholds.

  Returns:
    None. If both crop sizes are None, the inputs are left untouched.
  """
  if x_random_crop_size is None and y_random_crop_size is None:
    return

  points = mesh_inputs[standard_fields.InputDataFields.point_positions]
  num_points = tf.shape(points)[0]
  # At least one dimension is cropped past the guard above, so a random center
  # point is always needed.
  random_index = tf.random.uniform([],
                                   minval=0,
                                   maxval=num_points,
                                   dtype=tf.int32)
  center_x = points[random_index, 0]
  center_y = points[random_index, 1]

  # Default window: the full extent of the point cloud (i.e. no cropping).
  points_x = points[:, 0]
  points_y = points[:, 1]
  min_x = tf.reduce_min(points_x) - epsilon
  max_x = tf.reduce_max(points_x) + epsilon
  min_y = tf.reduce_min(points_y) - epsilon
  max_y = tf.reduce_max(points_y) + epsilon

  # Narrow the window around the center on each dimension that is cropped.
  if x_random_crop_size is not None:
    min_x = center_x - x_random_crop_size / 2.0 - epsilon
    max_x = center_x + x_random_crop_size / 2.0 + epsilon

  if y_random_crop_size is not None:
    min_y = center_y - y_random_crop_size / 2.0 - epsilon
    max_y = center_y + y_random_crop_size / 2.0 + epsilon

  x_mask = tf.logical_and(tf.greater(points_x, min_x), tf.less(points_x, max_x))
  y_mask = tf.logical_and(tf.greater(points_y, min_y), tf.less(points_y, max_y))
  points_mask = tf.logical_and(x_mask, y_mask)

  for key in sorted(mesh_inputs):
    mesh_inputs[key] = tf.boolean_mask(mesh_inputs[key], points_mask)

  # The point dimension is axis 1 here: move it to the front, mask, and move
  # it back.
  for key in sorted(view_indices_2d_inputs):
    view_indices_2d_inputs[key] = tf.transpose(
        tf.boolean_mask(
            tf.transpose(view_indices_2d_inputs[key], [1, 0, 2]), points_mask),
        [1, 0, 2])
    def _build_train_op(self):
        """Builds the training op with priority-weighted Huber loss.

        The per-sample Huber loss between the target Q-value and the Q-value
        of the replayed action is written back (as `sqrt(loss + 1e-10)`) as
        the new priority of the sampled transitions. The loss itself is scaled
        by `1 / sqrt(priority + 1e-10)`, normalized by the largest such weight
        in the batch.

        Returns:
          A tuple `(train_op, weighted_loss)`: the op performing one step of
          training and the per-sample weighted loss tensor.
        """
        replay_action_one_hot = tf.one_hot(self._replay.actions,
                                           self.num_actions,
                                           1.,
                                           0.,
                                           name='action_one_hot')
        # `axis` replaces the deprecated `reduction_indices` alias.
        replay_chosen_q = tf.reduce_sum(self._replay_qs *
                                        replay_action_one_hot,
                                        axis=1,
                                        name='replay_chosen_q')

        # The target is treated as a constant during back-propagation.
        target = tf.stop_gradient(self._build_target_q_op())
        loss = tf.losses.huber_loss(target,
                                    replay_chosen_q,
                                    reduction=tf.losses.Reduction.NONE)

        update_priorities_op = self._replay.tf_set_priority(
            self._replay.indices, tf.sqrt(loss + 1e-10))

        # Weight each sample by the inverse square root of its priority; the
        # small constant guards against division by zero.
        target_priorities = self._replay.tf_get_priority(self._replay.indices)
        target_priorities = tf.math.add(target_priorities, 1e-10)
        target_priorities = 1.0 / tf.sqrt(target_priorities)
        target_priorities /= tf.reduce_max(target_priorities)

        weighted_loss = target_priorities * loss

        # Make sure the priorities are refreshed whenever a train step runs.
        with tf.control_dependencies([update_priorities_op]):
            return self.optimizer.minimize(
                tf.reduce_mean(weighted_loss)), weighted_loss
Exemple #4
0
    def test_estimated_entropy(self, assume_reparametrization):
        """Compares sampled entropy estimates of a Normal to analytic values.

        Checks that the estimated entropy and the gradient of
        `est_entropy_for_gradient` w.r.t. `scale` are within 5e-2 of the
        analytic entropy and its gradient. When `assume_reparametrization` is
        false, it also checks that differentiating `est_entropy` directly
        gives a gradient whose largest absolute entry is below 5e-2.

        Args:
          assume_reparametrization: Forwarded to
            `dist_utils.estimated_entropy`.
        """
        # Logging calls pass lazy %-style arguments so formatting is deferred
        # to the logging framework.
        logging.info("assume_reparametrization=%s", assume_reparametrization)
        num_samples = 1000000
        seed_stream = tfp.distributions.SeedStream(
            seed=1, salt='test_estimated_entropy')
        batch_shape = (2, )
        loc = tf.random.normal(shape=batch_shape, seed=seed_stream())
        scale = tf.abs(tf.random.normal(shape=batch_shape, seed=seed_stream()))

        # Persistent because several gradients are taken from the same tape.
        with tf.GradientTape(persistent=True) as tape:
            tape.watch(scale)
            dist = tfp.distributions.Normal(loc=loc, scale=scale)
            analytic_entropy = dist.entropy()
            est_entropy, est_entropy_for_gradient = dist_utils.estimated_entropy(
                dist=dist,
                seed=seed_stream(),
                assume_reparametrization=assume_reparametrization,
                num_samples=num_samples)

        analytic_grad = tape.gradient(analytic_entropy, scale)
        est_grad = tape.gradient(est_entropy_for_gradient, scale)
        logging.info("scale=%s", scale)
        logging.info("analytic_entropy=%s", analytic_entropy)
        logging.info("estimated_entropy=%s", est_entropy)
        self.assertArrayAlmostEqual(analytic_entropy, est_entropy, 5e-2)

        logging.info("analytic_entropy_grad=%s", analytic_grad)
        logging.info("estimated_entropy_grad=%s", est_grad)
        self.assertArrayAlmostEqual(analytic_grad, est_grad, 5e-2)
        if not assume_reparametrization:
            est_grad_wrong = tape.gradient(est_entropy, scale)
            logging.info("estimated_entropy_grad_wrong=%s", est_grad_wrong)
            self.assertLess(tf.reduce_max(tf.abs(est_grad_wrong)), 5e-2)
Exemple #5
0
  def _build_train_op(self):
    """Builds the training op for Rainbow.

    The per-sample cross-entropy between the target distribution and the
    logits of the replayed action is used to refresh the replay priorities
    and, weighted by the inverse square root of the stored priorities, as the
    training loss.

    Returns:
      A tuple `(train_op, weighted_loss)`: the op performing one step of
      training and the per-sample weighted loss tensor.
    """
    target_distribution = tf.stop_gradient(self._build_target_distribution())

    # Pair every batch row index with its replayed action, giving a
    # batch_size x 2 index tensor for gather_nd.
    batch_rows = tf.expand_dims(tf.range(tf.shape(self._replay_logits)[0]), 1)
    action_column = tf.expand_dims(self._replay.actions, 1)
    row_action_pairs = tf.concat([batch_rows, action_column], 1)
    # Logits of the action that was actually taken, one row per batch element.
    chosen_action_logits = tf.gather_nd(self._replay_logits, row_action_pairs)

    loss = tf.nn.softmax_cross_entropy_with_logits(
        labels=target_distribution,
        logits=chosen_action_logits)

    optimizer = tf.train.AdamOptimizer(
        learning_rate=self.learning_rate,
        epsilon=self.optimizer_epsilon)

    sampled_indices = self._replay.indices
    update_priorities_op = self._replay.tf_set_priority(
        sampled_indices, tf.sqrt(loss + 1e-10))

    # Sample weights: 1 / sqrt(priority), scaled so the largest weight is 1.
    stored_priorities = self._replay.tf_get_priority(self._replay.indices)
    inverse_sqrt = 1.0 / tf.sqrt(tf.math.add(stored_priorities, 1e-10))
    sample_weights = inverse_sqrt / tf.reduce_max(inverse_sqrt)

    weighted_loss = sample_weights * loss

    # Running the train op also refreshes the priorities.
    with tf.control_dependencies([update_priorities_op]):
      train_op = optimizer.minimize(tf.reduce_mean(weighted_loss))
      return train_op, weighted_loss
def compute_target_optimal_q(reward, gamma, next_actions, next_q_values,
                             next_states, terminals):
    """Builds an op used as a target for the Q-value.

    This algorithm corresponds to the method "OT" in
    Ie et al. https://arxiv.org/abs/1905.12767.

    Every duplicate-free slate of candidate documents is enumerated, its
    score-weighted Q-value is computed, and the best slate value is used as
    the bootstrap target.

    Args:
      reward: [batch_size] tensor, the immediate reward.
      gamma: float, discount factor with the usual RL meaning.
      next_actions: [batch_size, slate_size] tensor, the next slate.
      next_q_values: [batch_size, num_of_documents] tensor, the q values of the
        documents in the next step.
      next_states: [batch_size, 1 + num_of_documents] tensor, the features for
        the user and the documents in the next step.
      terminals: [batch_size] tensor, indicating if this is a terminal step.

    Returns:
      [batch_size] tensor, the target q values.
    """
    scores, score_no_click = _get_unnormalized_scores(next_states)

    # Enumerate every ordered combination of candidate documents.
    slate_size = next_actions.get_shape().as_list()[1]
    num_candidates = next_q_values.get_shape().as_list()[1]
    candidate_ids = list(range(num_candidates))
    grid = tf.meshgrid(*([candidate_ids] * slate_size))
    all_slates = tf.reshape(tf.stack(grid, axis=-1), shape=(-1, slate_size))

    def _has_no_repeats(slate):
        """True when every document in `slate` appears exactly once."""
        distinct_docs = tf.unique(slate)[0]
        return tf.equal(tf.size(input=slate), tf.size(input=distinct_docs))

    # Drop slates that pick the same document more than once.
    unique_mask = tf.map_fn(_has_no_repeats, all_slates, dtype=tf.bool)
    # [num_of_slates, slate_size]
    slates = tf.boolean_mask(tensor=all_slates, mask=unique_mask)

    # [batch_size, num_of_slates, slate_size]
    next_q_values_slate = tf.gather(next_q_values, slates, axis=1)
    # [batch_size, num_of_slates, slate_size]
    scores_slate = tf.gather(scores, slates, axis=1)

    # [batch_size, num_of_slates]
    # NOTE(review): if `score_no_click` is a [batch_size] vector with distinct
    # entries, tiling it and reshaping to [batch_size, -1] would interleave
    # values across rows -- confirm against `_get_unnormalized_scores`.
    batch_size = next_states.get_shape().as_list()[0]
    num_slates = tf.shape(input=slates)[:1]
    score_no_click_slate = tf.reshape(
        tf.tile(score_no_click, num_slates), [batch_size, -1])

    # [batch_size, num_of_slates]
    weighted_q_sum = tf.reduce_sum(
        input_tensor=next_q_values_slate * scores_slate, axis=2)
    total_score = (tf.reduce_sum(input_tensor=scores_slate, axis=2) +
                   score_no_click_slate)
    next_q_target_slate = weighted_q_sum / total_score

    next_q_target_max = tf.reduce_max(input_tensor=next_q_target_slate, axis=1)

    # Terminal transitions do not bootstrap from the next state.
    not_terminal = 1. - tf.cast(terminals, tf.float32)
    return reward + gamma * next_q_target_max * not_terminal
Exemple #7
0
def _variable_summaries(var):
    """Emit mean, stddev, max, min and histogram summaries for `var`.

    All summaries are created under the 'summaries' name scope (for
    TensorBoard visualization).
    """
    with tf.name_scope('summaries'):
        mean = tf.reduce_mean(var)
        tf.summary.scalar('mean', mean)
        with tf.name_scope('stddev'):
            squared_deviation = tf.square(var - mean)
            stddev = tf.sqrt(tf.reduce_mean(squared_deviation))
        tf.summary.scalar('stddev', stddev)
        for tag, reducer in (('max', tf.reduce_max), ('min', tf.reduce_min)):
            tf.summary.scalar(tag, reducer(var))
        tf.summary.histogram('histogram', var)
    def _build_train_op(self):
        """Builds a training op.

        The reward is reduced by `alpha` times the gap between the best and
        the chosen target-network Q-value of the current state. The target is
        that augmented reward plus the discounted maximum target-network
        Q-value of the next state (zeroed for terminal transitions).

        Returns:
          train_op: An op performing one step of training from replay data.
        """
        replay_next_target_value = tf.reduce_max(
            self._replay_next_target_net_outputs.q_values, 1)
        replay_target_value = tf.reduce_max(
            self._replay_target_net_outputs.q_values, 1)

        replay_action_one_hot = tf.one_hot(self._replay.actions,
                                           self.num_actions,
                                           1.,
                                           0.,
                                           name='action_one_hot')
        replay_chosen_q = tf.reduce_sum(self._replay_net_outputs.q_values *
                                        replay_action_one_hot,
                                        axis=1,
                                        name='replay_chosen_q')
        # Distinct op name: this is the target network's value for the chosen
        # action, not the online network's.
        replay_target_chosen_q = tf.reduce_sum(
            self._replay_target_net_outputs.q_values * replay_action_one_hot,
            axis=1,
            name='replay_target_chosen_q')

        augmented_rewards = self._replay.rewards - self.alpha * (
            replay_target_value - replay_target_chosen_q)

        target = (augmented_rewards +
                  self.cumulative_gamma * replay_next_target_value *
                  (1. - tf.cast(self._replay.terminals, tf.float32)))
        # The target is treated as a constant during back-propagation.
        target = tf.stop_gradient(target)

        loss = tf.losses.huber_loss(target,
                                    replay_chosen_q,
                                    reduction=tf.losses.Reduction.NONE)
        if self.summary_writer is not None:
            with tf.variable_scope('Losses'):
                tf.summary.scalar('HuberLoss', tf.reduce_mean(loss))
        return self.optimizer.minimize(tf.reduce_mean(loss))
def summarize_stats(stats):
  """Summarize a dictionary of variables.

  For every entry a mean, max, min and std scalar summary plus a histogram
  summary are created.

  Args:
    stats: a dictionary of {name: tensor} to compute stats over.
  """
  for name, stat in stats.items():
    mean = tf.reduce_mean(stat)
    tf.summary.scalar('mean_%s' % name, mean)
    tf.summary.scalar('max_%s' % name, tf.reduce_max(stat))
    tf.summary.scalar('min_%s' % name, tf.reduce_min(stat))
    # Use the centered form of the variance: E[x^2] - E[x]^2 can become
    # negative through floating-point cancellation, which makes sqrt return
    # NaN. The small constant keeps the sqrt away from exactly zero.
    variance = tf.reduce_mean(tf.square(stat - mean))
    std = tf.sqrt(variance + 1e-10)
    tf.summary.scalar('std_%s' % name, std)
    tf.summary.histogram(name, stat)
  def _get_dist(self, query_queries, query_values, support_keys, support_values,
                labels):
    """Get distances between queries and query-aligned prototypes.

    Returns the squared difference between the query values and each class
    prototype, summed over axes 2-4 and divided by the product of the third-
    and second-to-last dimensions of `query_values`.
    """
    # attended: [N_support, n_query, h_query, w_query, C]
    attended = self._attend(query_queries, support_keys, support_values,
                            labels)
    # Summing attended values per label gives one prototype per class.
    # prototypes: [N_classes, n_query, h_query, w_query, C]
    num_classes = tf.reduce_max(labels) + 1
    prototypes = tf.math.unsorted_segment_sum(attended, labels, num_classes)

    # (scaled) Euclidean distance
    query_shape = tf.shape(query_values)
    squared_diff = tf.square(query_values[tf.newaxis, Ellipsis] - prototypes)
    scale = tf.cast(query_shape[-3] * query_shape[-2], squared_diff.dtype)

    return tf.reduce_sum(squared_diff, [2, 3, 4]) / scale
    def _build_target_q_op(self):
        """Build an op to be used as a target for the Q-value.

        Returns:
          target_q_op: An op calculating the target Q-value.
        """
        # Max over the actions dimension of the next-state Q-values after
        # adding `next_legal_actions`.
        # NOTE(review): presumably `next_legal_actions` is an additive mask
        # that pushes illegal actions out of the max -- confirm.
        next_q_with_mask = self._replay_next_qt + self._replay.next_legal_actions
        replay_next_qt_max = tf.reduce_max(next_q_with_mask, 1)
        # Sample Bellman update:
        #   Q_t = R_t + \gamma^N * Q'_t+1
        # where Q'_t+1 is \max_a Q(S_t+1, a), or 0 if the transition is
        # terminal, and N is the update horizon (by default, N=1).
        rewards = self._replay.rewards
        discounted_next_q = self.cumulative_gamma * replay_next_qt_max
        not_terminal = 1. - tf.cast(self._replay.terminals, tf.float32)
        return rewards + discounted_next_q * not_terminal
def _box_classification_loss_unbatched(inputs_1, outputs_1, is_intermediate,
                                       is_balanced, mine_hard_negatives,
                                       hard_negative_score_threshold):
    """Loss function for input and outputs of batch size 1.

    Args:
      inputs_1: Dictionary of input tensors for a single example.
      outputs_1: Dictionary of network output tensors for a single example.
      is_intermediate: If True, the intermediate semantic voxel logits are used
        instead of the final ones.
      is_balanced: If True, balanced per-instance weights are passed to the
        loss.
      mine_hard_negatives: If True, voxels labelled 0 whose class-0 softmax
        score is below the threshold are appended a second time.
      hard_negative_score_threshold: Class-0 score below which a voxel
        labelled 0 counts as a hard negative.

    Returns:
      The value returned by `classification_loss_fn`.

    Raises:
      ValueError: If the number of classes is not statically known.
    """
    valid_mask = _get_voxels_valid_mask(inputs_1=inputs_1)
    result_fields = standard_fields.DetectionResultFields
    logits_key = (result_fields.intermediate_object_semantic_voxels
                  if is_intermediate else
                  result_fields.object_semantic_voxels)
    logits = outputs_1[logits_key]

    num_classes = logits.get_shape().as_list()[-1]
    if num_classes is None:
        raise ValueError('Number of classes is unknown.')

    # Flatten to one row per voxel and keep only the valid voxels.
    logits = tf.boolean_mask(tf.reshape(logits, [-1, num_classes]), valid_mask)
    flat_labels = tf.reshape(
        inputs_1[standard_fields.InputDataFields.object_class_voxels], [-1, 1])
    labels = tf.boolean_mask(flat_labels, valid_mask)

    # Instance ids are only needed for hard-negative mining and balancing.
    if mine_hard_negatives or is_balanced:
        flat_instances = tf.reshape(
            inputs_1[standard_fields.InputDataFields.object_instance_id_voxels],
            [-1])
        instances = tf.boolean_mask(flat_instances, valid_mask)

    if mine_hard_negatives:
        class0_scores = tf.reshape(tf.nn.softmax(logits)[:, 0], [-1])
        is_low_score = tf.less(class0_scores, hard_negative_score_threshold)
        is_class0 = tf.equal(tf.reshape(labels, [-1]), 0)
        hard_negative_mask = tf.logical_and(is_low_score, is_class0)

        hard_negative_labels = tf.boolean_mask(labels, hard_negative_mask)
        hard_negative_logits = tf.boolean_mask(logits, hard_negative_mask)
        # All hard negatives share one fresh instance id (current max + 1).
        fresh_instance_ids = tf.ones_like(instances) * (
            tf.reduce_max(instances) + 1)
        hard_negative_instances = tf.boolean_mask(fresh_instance_ids,
                                                  hard_negative_mask)

        logits = tf.concat([logits, hard_negative_logits], axis=0)
        instances = tf.concat([instances, hard_negative_instances], axis=0)
        labels = tf.concat([labels, hard_negative_labels], axis=0)

    loss_kwargs = {}
    if is_balanced:
        loss_kwargs['weights'] = (
            loss_utils.get_balanced_loss_weights_multiclass(
                labels=tf.expand_dims(instances, axis=1)))
    return classification_loss_fn(logits=logits, labels=labels, **loss_kwargs)
    def fwd_fn(query_queries_fwd, query_values_fwd, support_keys_fwd,
               support_values_fwd, labels_fwd):
      """CrossTransformer forward, using a while loop to save memory.

      Queries are processed one at a time; the distances for each query are
      appended as a new column of the result.
      """
      # NOTE(review): `labels`, `zero_dim` and `self` are taken from the
      # enclosing scope, while the loop body uses the `labels_fwd` argument --
      # confirm that `labels` and `labels_fwd` are meant to be the same tensor.
      num_classes = tf.reduce_max(labels) + 1
      empty_dist = tf.zeros([num_classes, zero_dim],
                            dtype=query_queries_fwd.dtype)

      def has_more_queries(idx, _):
        return idx < tf.shape(query_queries_fwd)[0]

      def loop_body(idx, dist):
        next_idx = idx + 1
        # Slicing (rather than indexing) keeps the leading query dimension.
        dist_new = self._get_dist(query_queries_fwd[idx:next_idx],
                                  query_values_fwd[idx:next_idx],
                                  support_keys_fwd, support_values_fwd,
                                  labels_fwd)
        return (next_idx, tf.concat([dist, dist_new], axis=1))

      _, res = tf.while_loop(
          has_more_queries,
          loop_body,
          (0, empty_dist),
          parallel_iterations=1)
      return res
Exemple #14
0
 def assertArrayAlmostEqual(self, x, y, eps):
     """Assert that the largest absolute difference of x and y is below eps."""
     max_abs_diff = tf.reduce_max(tf.abs(x - y))
     self.assertLess(max_abs_diff, eps)
def classification_loss_using_mask_iou_func(embeddings,
                                            logits,
                                            instance_ids,
                                            class_labels,
                                            num_samples,
                                            valid_mask=None,
                                            max_instance_id=None,
                                            similarity_strategy='dotproduct',
                                            is_balanced=True):
    """Classification loss using mask iou.

    Samples are drawn once for the whole batch; the unbatched loss is then
    evaluated per example and the results are averaged.

    Args:
      embeddings: A tf.float32 tensor of size [batch_size, n, f].
      logits: A tf.float32 tensor of size [batch_size, n, num_classes]. It is
        assumed that background is class 0.
      instance_ids: A tf.int32 tensor of size [batch_size, n].
      class_labels: A tf.int32 tensor of size [batch_size, n]. It is assumed
        that the background voxels are assigned to class 0.
      num_samples: An int determining the number of samples.
      valid_mask: A tf.bool tensor of size [batch_size, n] that is True when an
        element is valid and False if it needs to be ignored. By default the
        value is None which means it is not applied.
      max_instance_id: If set, instance ids larger than that value will be
        ignored. If not set, it will be computed from instance_ids tensor.
      similarity_strategy: Defines the method for computing similarity between
        embedding vectors. Possible values are 'dotproduct' and 'distance'.
      is_balanced: If True, the per-voxel losses are re-weighted to have equal
        total weight for foreground vs. background voxels.

    Returns:
      A tf.float32 scalar loss tensor.

    Raises:
      ValueError: If the batch size is not statically known.
    """
    batch_size = embeddings.get_shape().as_list()[0]
    if batch_size is None:
        raise ValueError('Unknown batch size at graph construction time.')
    if max_instance_id is None:
        max_instance_id = tf.reduce_max(instance_ids)
    class_labels = tf.reshape(class_labels, [batch_size, -1, 1])

    sampled = sampling_utils.balanced_sample(features=embeddings,
                                             instance_ids=instance_ids,
                                             num_samples=num_samples,
                                             valid_mask=valid_mask,
                                             max_instance_id=max_instance_id)
    sampled_embeddings, sampled_instance_ids, sampled_indices = sampled

    per_example_losses = []
    for b in range(batch_size):
        example_embeddings = embeddings[b, :, :]
        example_instance_ids = instance_ids[b, :]
        picked = sampled_indices[b, :]

        # Labels and logits of the sampled elements are looked up in the
        # unmasked tensors of this example.
        picked_class_labels = tf.gather(class_labels[b, :, :], picked)
        picked_logits = tf.gather(logits[b, :], picked)

        if valid_mask is not None:
            example_mask = valid_mask[b]
            example_embeddings = tf.boolean_mask(example_embeddings,
                                                 example_mask)
            example_instance_ids = tf.boolean_mask(example_instance_ids,
                                                   example_mask)

        per_example_losses.append(
            classification_loss_using_mask_iou_func_unbatched(
                embeddings=example_embeddings,
                instance_ids=example_instance_ids,
                sampled_embeddings=sampled_embeddings[b, :, :],
                sampled_instance_ids=sampled_instance_ids[b, :],
                sampled_class_labels=picked_class_labels,
                sampled_logits=picked_logits,
                similarity_strategy=similarity_strategy,
                is_balanced=is_balanced))
    return tf.math.reduce_mean(tf.stack(per_example_losses))
def train(hparams, num_epoch, tuning):
    """Train, validate and test the ResNet classifier with TensorBoard logs.

    Each epoch runs one pass over the training set followed by one pass over
    the validation set. After the last epoch the model is evaluated on the
    test set.

    Args:
      hparams: Mapping providing 'HP_BS' (training batch size) and 'HP_LR'
        (Adam learning rate).
      num_epoch: Number of epochs to run.
      tuning: If truthy, checkpoint restore/save and the final activation-map
        visualization are skipped.

    Returns:
      The accuracy measured on the test set.
    """
    log_dir = './results/'
    test_batch_size = 8
    # Load dataset
    training_set, valid_set = make_dataset(BATCH_SIZE=hparams['HP_BS'],
                                           file_name='train_tf_record',
                                           split=True)
    test_set = make_dataset(BATCH_SIZE=test_batch_size,
                            file_name='test_tf_record',
                            split=False)
    class_names = ['NRDR', 'RDR']

    # Model
    model = ResNet()

    # set optimizer
    optimizer = tf.keras.optimizers.Adam(learning_rate=hparams['HP_LR'])
    # set metrics
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    valid_accuracy = tf.keras.metrics.Accuracy()
    valid_con_mat = ConfusionMatrix(num_class=2)
    test_accuracy = tf.keras.metrics.Accuracy()
    test_con_mat = ConfusionMatrix(num_class=2)

    # Save Checkpoint
    if not tuning:
        ckpt = tf.train.Checkpoint(step=tf.Variable(1),
                                   optimizer=optimizer,
                                   net=model)
        manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=5)

    # Set up summary writers
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    tb_log_dir = log_dir + current_time + '/train'
    summary_writer = tf.summary.create_file_writer(tb_log_dir)

    # Restore Checkpoint
    if not tuning:
        ckpt.restore(manager.latest_checkpoint)
        if manager.latest_checkpoint:
            logging.info('Restored from {}'.format(manager.latest_checkpoint))
        else:
            logging.info('Initializing from scratch.')

    @tf.function
    def train_step(train_img, train_label):
        # Optimize the model
        _, grads = grad(model, train_img, train_label)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        # Accuracy is measured with the freshly updated weights.
        train_pred, _ = model(train_img)
        train_label = tf.expand_dims(train_label, axis=1)
        train_accuracy.update_state(train_label, train_pred)

    for epoch in range(num_epoch):

        begin = time()

        # Training loop
        for train_img, train_label, _ in training_set:
            train_img = data_augmentation(train_img)
            train_step(train_img, train_label)

        with summary_writer.as_default():
            tf.summary.scalar('Train Accuracy',
                              train_accuracy.result(),
                              step=epoch)

        for valid_img, valid_label, _ in valid_set:
            valid_img = tf.cast(valid_img, tf.float32)
            valid_img = valid_img / 255.0
            valid_pred, _ = model(valid_img, training=False)
            valid_pred = tf.cast(tf.argmax(valid_pred, axis=1), dtype=tf.int64)
            valid_con_mat.update_state(valid_label, valid_pred)
            valid_accuracy.update_state(valid_label, valid_pred)

        # Log the confusion matrix as an image summary
        cm_valid = valid_con_mat.result()
        figure = plot_confusion_matrix(cm_valid, class_names=class_names)
        cm_valid_image = plot_to_image(figure)

        with summary_writer.as_default():
            tf.summary.scalar('Valid Accuracy',
                              valid_accuracy.result(),
                              step=epoch)
            tf.summary.image('Valid ConfusionMatrix',
                             cm_valid_image,
                             step=epoch)

        end = time()
        logging.info(
            "Epoch {:d} Training Accuracy: {:.3%} Validation Accuracy: {:.3%} Time:{:.5}s"
            .format(epoch + 1, train_accuracy.result(),
                    valid_accuracy.result(), (end - begin)))
        # Metrics accumulate across calls, so clear them for the next epoch.
        train_accuracy.reset_states()
        valid_accuracy.reset_states()
        valid_con_mat.reset_states()
        if not tuning:
            if int(ckpt.step) % 5 == 0:
                save_path = manager.save()
                logging.info('Saved checkpoint for epoch {}: {}'.format(
                    int(ckpt.step), save_path))
            ckpt.step.assign_add(1)

    for test_img, test_label, _ in test_set:
        test_img = tf.cast(test_img, tf.float32)
        test_img = test_img / 255.0
        test_pred, _ = model(test_img, training=False)
        test_pred = tf.cast(tf.argmax(test_pred, axis=1), dtype=tf.int64)
        test_accuracy.update_state(test_label, test_pred)
        test_con_mat.update_state(test_label, test_pred)

    # The test summaries are written at the step of the last epoch. The loop
    # variable `epoch` is unbound when num_epoch == 0, so derive the step
    # explicitly instead of relying on it.
    final_step = max(num_epoch - 1, 0)

    cm_test = test_con_mat.result()
    # Log the confusion matrix as an image summary
    figure = plot_confusion_matrix(cm_test, class_names=class_names)
    cm_test_image = plot_to_image(figure)
    with summary_writer.as_default():
        tf.summary.scalar('Test Accuracy',
                          test_accuracy.result(),
                          step=final_step)
        tf.summary.image('Test ConfusionMatrix',
                         cm_test_image,
                         step=final_step)

    logging.info("Trained finished. Final Accuracy in test set: {:.3%}".format(
        test_accuracy.result()))

    # Visualization: gradient-weighted activation map for the first image of
    # the first test batch only.
    if not tuning:
        for vis_img, vis_label, vis_name in test_set:
            vis_label = vis_label[0]
            vis_name = vis_name[0]
            vis_img = tf.cast(vis_img[0], tf.float32)
            vis_img = tf.expand_dims(vis_img, axis=0)
            vis_img = vis_img / 255.0
            # Only the forward pass needs to be recorded; the gradient is
            # taken after leaving the tape context so the backward pass is
            # not recorded as well.
            with tf.GradientTape() as tape:
                vis_pred, conv_output = model(vis_img, training=False)
                pred_label = tf.argmax(vis_pred, axis=-1)
                vis_pred = tf.reduce_max(vis_pred, axis=-1)
            grad_1 = tape.gradient(vis_pred, conv_output)
            weight = tf.reduce_mean(grad_1, axis=[1, 2]) / grad_1.shape[1]
            act_map0 = tf.nn.relu(
                tf.reduce_sum(weight * conv_output, axis=-1))
            act_map0 = tf.squeeze(tf.image.resize(tf.expand_dims(act_map0,
                                                                 axis=-1),
                                                  (256, 256),
                                                  antialias=True),
                                  axis=-1)
            plot_map(vis_img, act_map0, vis_pred, pred_label, vis_label,
                     vis_name)
            break

    return test_accuracy.result()
Exemple #17
0
def npair_loss_func(embeddings,
                    instance_ids,
                    num_samples,
                    valid_mask=None,
                    max_instance_id=None,
                    similarity_strategy='dotproduct',
                    loss_strategy='softmax'):
  """N-pair metric learning loss for learning feature embeddings.

  Samples are drawn once for the whole batch. Each example contributes its
  n-pair loss, or 0.0 when all of its sampled instance ids are identical, and
  the per-example values are averaged.

  Args:
    embeddings: A tf.float32 tensor of size [batch_size, n, f].
    instance_ids: A tf.int32 tensor of size [batch_size, n].
    num_samples: An int determining the number of samples.
    valid_mask: A tf.bool tensor of size [batch_size, n] that is True when an
      element is valid and False if it needs to be ignored. By default the value
      is None which means it is not applied.
    max_instance_id: If set, instance ids larger than that value will be
      ignored. If not set, it will be computed from instance_ids tensor.
    similarity_strategy: Defines the method for computing similarity between
      embedding vectors. Possible values are 'dotproduct' and 'distance'.
    loss_strategy: Defines the type of loss including 'softmax' or 'sigmoid'.

  Returns:
    A tf.float32 scalar loss tensor.

  Raises:
    ValueError: If the batch size is not statically known.
  """
  batch_size = embeddings.get_shape().as_list()[0]
  if batch_size is None:
    raise ValueError('Unknown batch size at graph construction time.')
  if max_instance_id is None:
    max_instance_id = tf.reduce_max(instance_ids)

  sampled = sampling_utils.balanced_sample(
      features=embeddings,
      instance_ids=instance_ids,
      num_samples=num_samples,
      valid_mask=valid_mask,
      max_instance_id=max_instance_id)
  sampled_embeddings, sampled_instance_ids, _ = sampled

  per_example_losses = []
  for b in range(batch_size):
    example_ids = sampled_instance_ids[b, :]
    example_embeddings = sampled_embeddings[b, :, :]
    smallest_id = tf.math.reduce_min(example_ids)
    largest_id = tf.math.reduce_max(example_ids)
    example_target = tf.one_hot(
        example_ids, depth=(max_instance_id + 1), dtype=tf.float32)

    # The closure is consumed by tf.cond within this same iteration, so
    # capturing the loop variables is safe.
    # pylint: disable=cell-var-from-loop
    def npair_loss_i():
      return metric_learning_losses.npair_loss(
          embedding=example_embeddings,
          target=example_target,
          similarity_strategy=similarity_strategy,
          loss_strategy=loss_strategy)
    # pylint: enable=cell-var-from-loop

    # The n-pair loss is only evaluated when the sampled ids differ.
    has_distinct_ids = largest_id > smallest_id
    example_loss = tf.cond(
        has_distinct_ids, npair_loss_i,
        lambda: tf.constant(0.0, dtype=tf.float32))
    per_example_losses.append(example_loss)
  return tf.math.reduce_mean(per_example_losses)