Example #1
 def build_task_parameters(self):
   """Assign to attributes the meta parameters."""
   self.locs = [
       tf.Variable(tf.zeros((self.num_dims)), name='loc_{}'.format(i))
       for i in range(self.num_components)
   ]
   self.log_scales = [
       tf.Variable(tf.zeros((self.num_dims)), name='log_scale_{}'.format(i))
       for i in range(self.num_components)
   ]
Example #2
 def build_meta_parameters(self):
   """Assign to attributes the task parameters."""
   self.meta_loc = tf.Variable(
       self.loc_initializer([self.num_modes, self.num_dims]),
       trainable=self.trainable_loc,
       name='meta_loc')
   self.meta_log_scale = tf.Variable(
       self.log_scale_initializer([self.num_modes, self.num_dims]),
       trainable=self.trainable_scale,
       name='meta_log_scale')
   self.meta_logits = tf.Variable(
       self.logits_initializer([self.num_modes]),
       trainable=self.trainable_logits,
       name='meta_logits')
Example #3
def em_loop(
    num_updates,
    e_step,
    m_step,
    variables,
):
    """Expectation-maximization of objective_fn wrt variables for num_updates."""
    def _body(step, preupdate_vars):
        train_predictions_, responsibilities_ = e_step(preupdate_vars)
        updated_vars = m_step(preupdate_vars, train_predictions_,
                              responsibilities_)
        return step + 1, updated_vars

    def _cond(step, *args):
        del args
        return step < num_updates

    step = tf.Variable(0, trainable=False, name='inner_step_counter')
    loop_vars = (step, variables)
    step, updated_vars = tf.while_loop(cond=_cond,
                                       body=_body,
                                       loop_vars=loop_vars,
                                       swap_memory=True)

    return updated_vars
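
A minimal usage sketch of em_loop (the E-step and M-step callables below and their shapes are hypothetical, not taken from the original source): e_step maps the current variables to predictions and responsibilities, and m_step uses those to produce updated variables with the same structure.

def toy_e_step(variables):
    """Hypothetical E-step: returns (predictions, responsibilities)."""
    loc, log_scale = variables
    predictions = loc                             # stand-in for per-point predictions
    responsibilities = tf.nn.softmax(-log_scale)  # stand-in for soft assignments
    return predictions, responsibilities

def toy_m_step(variables, predictions, responsibilities):
    """Hypothetical M-step: re-estimates the parameters from E-step outputs."""
    loc, log_scale = variables
    new_loc = tf.reduce_mean(predictions) * tf.ones_like(loc)
    new_log_scale = tf.math.log(responsibilities + 1e-8)
    return [new_loc, new_log_scale]

updated_loc, updated_log_scale = em_loop(
    num_updates=10,
    e_step=toy_e_step,
    m_step=toy_m_step,
    variables=[tf.zeros([3]), tf.zeros([3])])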
Example #4
 def __init__(self, ckpt_dir, save_epoch_freq=1, max_to_keep=3):
     self._ckpt_saved_epoch = tf.Variable(initial_value=tf.constant(
         -1, dtype=tf.dtypes.int64),
                                          name='ckpt_saved_epoch')
     self.ckpt_dir = ckpt_dir
     self.max_to_keep = max_to_keep
     self.save_epoch_freq = save_epoch_freq
Example #5
    def _construct_variables():
      """Construct an initialization for task parameters."""

      def _split_mode_params(params):
        return [
            tf.squeeze(p) for p in tf.split(
                params, axis=0, num_or_size_splits=self.num_modes)
        ]

      locs = _split_mode_params(tf.zeros_like(self.meta_loc))
      log_scales = _split_mode_params(tf.zeros_like(self.meta_log_scale))
      logits = tf.zeros_like(self.meta_logits)

      return (
          [tf.Variable(loc, name='loc') for loc in locs],
          [tf.Variable(log_scale, name='log_scale') for log_scale in log_scales],
          tf.Variable(logits, name='logits'),
      )
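
A small illustration of the _split_mode_params pattern above (the shapes are hypothetical): a [num_modes, num_dims] tensor is split along axis 0 into num_modes slices of shape [1, num_dims], each of which is squeezed to a vector of length num_dims.

params = tf.zeros([3, 4])  # illustrative: num_modes=3, num_dims=4
parts = [
    tf.squeeze(p)
    for p in tf.split(params, axis=0, num_or_size_splits=3)
]
# parts is a list of 3 tensors, each of shape [4], one per mode.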
Example #6
 def _build_sync_op(self):
   """Build the sync op."""
   sync_count = tf.Variable(0, trainable=False)
   sync_ops = [tf.assign_add(sync_count, 1)]
   trainables_online = tf.get_collection(
       tf.GraphKeys.TRAINABLE_VARIABLES, scope='Online')
   trainables_target = tf.get_collection(
       tf.GraphKeys.TRAINABLE_VARIABLES, scope='Target')
   for (w_online, w_target) in zip(trainables_online, trainables_target):
     sync_ops.append(w_target.assign(w_online, use_locking=True))
   tf.summary.scalar('Learning/SyncCount', sync_count)
   return sync_ops
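
The returned ops are meant to be run periodically so the 'Target' network tracks the 'Online' network. A minimal TF1 (graph-mode) sketch of that pattern; agent, num_training_steps, and target_update_period are hypothetical names, not from the original source:

num_training_steps, target_update_period = 10000, 500  # illustrative values
sync_op = tf.group(*agent._build_sync_op())            # agent is a hypothetical instance
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for step in range(num_training_steps):
    # ... run the agent's training op here ...
    if step % target_update_period == 0:
      sess.run(sync_op)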
Example #7
def get_fc_vars_copy_ops(fc_weights, fc_bias, make_copies):
  """Gets copies of the classifier layer variables or returns those variables.

  At meta-test time, a copy is created for each of the given Variables, and
  these copies will be used in place of the original ones.

  Args:
    fc_weights: A Variable for the weights of the fc layer.
    fc_bias: A Variable for the bias of the fc layer.
    make_copies: A bool. Whether to copy the given variables. If not, those
      variables themselves are returned.

  Returns:
    fc_weights: A Variable for the weights of the fc layer. Might be the same as
      the input fc_weights or a copy of it.
    fc_bias: Analogously, a Variable for the bias of the fc layer.
    fc_vars_copy_ops: A (possibly empty) list of operations for assigning the
      value of each of fc_weights and fc_bias to a respective copy variable.
  """
  fc_vars_copy_ops = []
  if make_copies:
    with tf.variable_scope('weight_copy'):
      # fc_weights copy
      fc_weights_copy = tf.Variable(
          tf.zeros(fc_weights.shape.as_list()),
          collections=[tf.GraphKeys.LOCAL_VARIABLES])
      fc_weights_copy_op = tf.assign(fc_weights_copy, fc_weights)
      fc_vars_copy_ops.append(fc_weights_copy_op)

      # fc_bias copy
      fc_bias_copy = tf.Variable(
          tf.zeros(fc_bias.shape.as_list()),
          collections=[tf.GraphKeys.LOCAL_VARIABLES])
      fc_bias_copy_op = tf.assign(fc_bias_copy, fc_bias)
      fc_vars_copy_ops.append(fc_bias_copy_op)

      fc_weights = fc_weights_copy
      fc_bias = fc_bias_copy
  return fc_weights, fc_bias, fc_vars_copy_ops
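
A sketch of how the returned copy ops are typically consumed at meta-test time (the classifier shapes and the support_embeddings tensor below are hypothetical): the copies are first filled from the originals via a control dependency, and subsequent inner-loop updates touch only the copies.

support_embeddings = tf.random.normal([25, 64])  # hypothetical support embeddings
fc_weights = tf.get_variable('fc_w', shape=[64, 5])
fc_bias = tf.get_variable('fc_b', shape=[5])

# Swap in copies so the meta-trained variables stay untouched.
fc_weights, fc_bias, fc_copy_ops = get_fc_vars_copy_ops(
    fc_weights, fc_bias, make_copies=True)

with tf.control_dependencies(fc_copy_ops):
  support_logits = tf.matmul(support_embeddings, fc_weights) + fc_bias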
Example #8
 def build(self, *args, **kwargs):
     self._global_step = tf.Variable(initial_value=0,
                                     dtype=tf.int32,
                                     name="global_step",
                                     trainable=False)
     self._ph_op()
     self._graph_op(*args, **kwargs)
     self._predict_op()
     self._vars = tf.trainable_variables()
     self._loss_op()
     self._train_op()
     self._summary_op()
     self._built = True
     tf.logging.log(logging.INFO,
                    "Built model with scope {}".format(self._scope))
Example #9
 def __init__(self, embeddings_config: EmbeddingsConfig, model_config: TransformerSoftmaxModelConfig = gin.REQUIRED):
     super().__init__(embeddings_config)
     self.model_config = model_config
     self.pre_normalization_layer = tf.keras.layers.LayerNormalization(epsilon=1e-6)
     self.pre_dropout_layer = tf.keras.layers.Dropout(rate=self.model_config.pre_dropout_rate)
     self.transformer_layer = StackedTransformerEncodersLayer()
     self.post_hidden_layer = tf.keras.layers.Dense(
         units=self.embeddings_layer.config.embeddings_dimension,
         kernel_initializer=parameters_factory.get_parameters_initializer(),
     )
     self.post_normalization_layer = tf.keras.layers.LayerNormalization(epsilon=1e-6)
     self.projection_bias = tf.Variable(
         initial_value=tf.zeros_initializer()(shape=(tf.shape(self.get_similarity_matrix())[0], )),
         trainable=True,
     )
Example #10
def optimizer_loop(
    num_updates,
    objective_fn,
    update_fn,
    variables,
    first_order,
    clip_grad_norm,
):
    """Optimization of `objective_fn` for `num_updates` of `variables`."""

    # Optimizer specifics.
    init, update, get_params = update_fn()

    def _body(step, preupdate_vars):
        """Optimization loop body."""
        updated_vars = optimizer_update(
            iterate_collection=preupdate_vars,
            iteration_idx=step,
            objective_fn=objective_fn,
            update_fn=update,
            get_params_fn=get_params,
            first_order=first_order,
            clip_grad_norm=clip_grad_norm,
        )

        return step + 1, updated_vars

    def _cond(step, *args):
        """Optimization truncation condition."""
        del args
        return step < num_updates

    step = tf.Variable(0, trainable=False, name='inner_step_counter')
    loop_vars = (step, [init(var) for var in variables])
    step, updated_vars = tf.while_loop(cond=_cond,
                                       body=_body,
                                       loop_vars=loop_vars,
                                       swap_memory=True)

    return [get_params(v) for v in updated_vars]
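
The update_fn argument is expected to return an (init, update, get_params) triple in the style of jax.example_libraries.optimizers; the exact call signature consumed by optimizer_update is not shown in this example, so the plain-SGD sketch below is an assumption rather than the actual interface.

def sgd_update_fn(learning_rate=0.01):
    """Hypothetical update_fn: plain SGD whose state is just the iterate."""

    def init(variable):
        return variable

    def update(step, grad, state):
        del step  # constant learning rate, so the step index is unused
        return state - learning_rate * grad

    def get_params(state):
        return state

    return init, update, get_params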
Example #11
def get_embeddings_vars_copy_ops(embedding_vars_dict, make_copies):
    """Gets copies of the embedding variables or returns those variables.

    This is useful at meta-test time for MAML and the finetuning baseline. In
    particular, at meta-test time, we don't want to make permanent updates to
    the model's variables, but only modifications that persist in the given
    episode. This can be achieved by creating copies of each variable and
    modifying and using these copies instead of the variables themselves.

    Args:
      embedding_vars_dict: A dict mapping each variable name to the corresponding
        Variable.
      make_copies: A bool. Whether to copy the given variables. If not, those
        variables themselves will be returned. Typically, this is True at
        meta-test time and False at meta-training time.

    Returns:
      embedding_vars_keys: A list of variable names.
      embedding_vars: A corresponding list of Variables.
      embedding_vars_copy_ops: A (possibly empty) list of operations, each of
        which assigns the value of one of the provided Variables to a new
        Variable which is its copy.
    """
    embedding_vars_keys = []
    embedding_vars = []
    embedding_vars_copy_ops = []
    for name, var in six.iteritems(embedding_vars_dict):
        embedding_vars_keys.append(name)
        if make_copies:
            with tf.variable_scope('weight_copy'):
                shape = var.shape.as_list()
                var_copy = tf.Variable(
                    tf.zeros(shape),
                    collections=[tf.GraphKeys.LOCAL_VARIABLES])
                var_copy_op = tf.assign(var_copy, var)
                embedding_vars_copy_ops.append(var_copy_op)
            embedding_vars.append(var_copy)
        else:
            embedding_vars.append(var)
    return embedding_vars_keys, embedding_vars, embedding_vars_copy_ops
Example #12
def train(hparams, num_epoch, tuning):

    log_dir = './results/'
    test_batch_size = 8
    # Load dataset
    training_set, valid_set = make_dataset(BATCH_SIZE=hparams['HP_BS'],
                                           file_name='train_tf_record',
                                           split=True)
    test_set = make_dataset(BATCH_SIZE=test_batch_size,
                            file_name='test_tf_record',
                            split=False)
    class_names = ['NRDR', 'RDR']

    # Model
    model = ResNet()

    # set optimizer
    optimizer = tf.keras.optimizers.Adam(learning_rate=hparams['HP_LR'])
    # set metrics
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    valid_accuracy = tf.keras.metrics.Accuracy()
    valid_con_mat = ConfusionMatrix(num_class=2)
    test_accuracy = tf.keras.metrics.Accuracy()
    test_con_mat = ConfusionMatrix(num_class=2)

    # Save Checkpoint
    if not tuning:
        ckpt = tf.train.Checkpoint(step=tf.Variable(1),
                                   optimizer=optimizer,
                                   net=model)
        manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=5)

    # Set up summary writers
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    tb_log_dir = log_dir + current_time + '/train'
    summary_writer = tf.summary.create_file_writer(tb_log_dir)

    # Restore Checkpoint
    if not tuning:
        ckpt.restore(manager.latest_checkpoint)
        if manager.latest_checkpoint:
            logging.info('Restored from {}'.format(manager.latest_checkpoint))
        else:
            logging.info('Initializing from scratch.')

    @tf.function
    def train_step(train_img, train_label):
        # Optimize the model
        loss_value, grads = grad(model, train_img, train_label)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        train_pred, _ = model(train_img)
        train_label = tf.expand_dims(train_label, axis=1)
        train_accuracy.update_state(train_label, train_pred)

    for epoch in range(num_epoch):

        begin = time()

        # Training loop
        for train_img, train_label, train_name in training_set:
            train_img = data_augmentation(train_img)
            train_step(train_img, train_label)

        with summary_writer.as_default():
            tf.summary.scalar('Train Accuracy',
                              train_accuracy.result(),
                              step=epoch)

        for valid_img, valid_label, _ in valid_set:
            valid_img = tf.cast(valid_img, tf.float32)
            valid_img = valid_img / 255.0
            valid_pred, _ = model(valid_img, training=False)
            valid_pred = tf.cast(tf.argmax(valid_pred, axis=1), dtype=tf.int64)
            valid_con_mat.update_state(valid_label, valid_pred)
            valid_accuracy.update_state(valid_label, valid_pred)

        # Log the confusion matrix as an image summary
        cm_valid = valid_con_mat.result()
        figure = plot_confusion_matrix(cm_valid, class_names=class_names)
        cm_valid_image = plot_to_image(figure)

        with summary_writer.as_default():
            tf.summary.scalar('Valid Accuracy',
                              valid_accuracy.result(),
                              step=epoch)
            tf.summary.image('Valid ConfusionMatrix',
                             cm_valid_image,
                             step=epoch)

        end = time()
        logging.info(
            "Epoch {:d} Training Accuracy: {:.3%} Validation Accuracy: {:.3%} Time:{:.5}s"
            .format(epoch + 1, train_accuracy.result(),
                    valid_accuracy.result(), (end - begin)))
        train_accuracy.reset_states()
        valid_accuracy.reset_states()
        valid_con_mat.reset_states()
        if not tuning:
            if int(ckpt.step) % 5 == 0:
                save_path = manager.save()
                logging.info('Saved checkpoint for epoch {}: {}'.format(
                    int(ckpt.step), save_path))
            ckpt.step.assign_add(1)

    for test_img, test_label, _ in test_set:
        test_img = tf.cast(test_img, tf.float32)
        test_img = test_img / 255.0
        test_pred, _ = model(test_img, training=False)
        test_pred = tf.cast(tf.argmax(test_pred, axis=1), dtype=tf.int64)
        test_accuracy.update_state(test_label, test_pred)
        test_con_mat.update_state(test_label, test_pred)

    cm_test = test_con_mat.result()
    # Log the confusion matrix as an image summary
    figure = plot_confusion_matrix(cm_test, class_names=class_names)
    cm_test_image = plot_to_image(figure)
    with summary_writer.as_default():
        tf.summary.scalar('Test Accuracy', test_accuracy.result(), step=epoch)
        tf.summary.image('Test ConfusionMatrix', cm_test_image, step=epoch)

    logging.info("Trained finished. Final Accuracy in test set: {:.3%}".format(
        test_accuracy.result()))

    # Visualization
    if not tuning:
        for vis_img, vis_label, vis_name in test_set:
            vis_label = vis_label[0]
            vis_name = vis_name[0]
            vis_img = tf.cast(vis_img[0], tf.float32)
            vis_img = tf.expand_dims(vis_img, axis=0)
            vis_img = vis_img / 255.0
            with tf.GradientTape() as tape:
                vis_pred, conv_output = model(vis_img, training=False)
                pred_label = tf.argmax(vis_pred, axis=-1)
                vis_pred = tf.reduce_max(vis_pred, axis=-1)
                grad_1 = tape.gradient(vis_pred, conv_output)
                weight = tf.reduce_mean(grad_1, axis=[1, 2]) / grad_1.shape[1]
                act_map0 = tf.nn.relu(
                    tf.reduce_sum(weight * conv_output, axis=-1))
                act_map0 = tf.squeeze(tf.image.resize(tf.expand_dims(act_map0,
                                                                     axis=-1),
                                                      (256, 256),
                                                      antialias=True),
                                      axis=-1)
                plot_map(vis_img, act_map0, vis_pred, pred_label, vis_label,
                         vis_name)
            break

    return test_accuracy.result()
Example #13
  def learn_metric(self, verbose=False):
    """Approximate the bisimulation metric by learning.

    Args:
      verbose: bool, whether to print verbose messages.
    """
    summary_writer = tf.summary.FileWriter(self.base_dir)
    global_step = tf.Variable(0, trainable=False)
    inc_global_step_op = tf.assign_add(global_step, 1)
    bisim_horizon = 0.0
    bisim_horizon_discount_value = 1.0
    if self.use_decayed_learning_rate:
      learning_rate = tf.train.exponential_decay(self.starting_learning_rate,
                                                 global_step,
                                                 self.num_iterations,
                                                 self.learning_rate_decay,
                                                 staircase=self.staircase)
    else:
      learning_rate = self.starting_learning_rate
    tf.summary.scalar('Learning/LearningRate', learning_rate)
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                       epsilon=self.epsilon)
    train_op = self._build_train_op(optimizer)
    sync_op = self._build_sync_op()
    eval_op = tf.stop_gradient(self._build_eval_metric())
    eval_states = []
    # Build the evaluation tensor.
    for state in range(self.num_states):
      row, col = self.inverse_index_states[state]
      # We place the evaluation states at the center of each grid cell.
      eval_states.append([row + 0.5, col + 0.5])
    eval_states = np.array(eval_states, dtype=np.float64)
    normalized_bisim_metric = (
        self.bisim_metric / np.linalg.norm(self.bisim_metric))
    metric_errors = []
    average_metric_errors = []
    normalized_metric_errors = []
    average_normalized_metric_errors = []
    saver = tf.train.Saver(max_to_keep=3)
    with tf.Session() as sess:
      summary_writer.add_graph(graph=tf.get_default_graph())
      sess.run(tf.global_variables_initializer())
      merged_summaries = tf.summary.merge_all()
      for i in range(self.num_iterations):
        sampled_states = np.random.randint(self.num_states,
                                           size=(self.batch_size,))
        sampled_actions = np.random.randint(4,
                                            size=(self.batch_size,))
        if self.add_noise:
          sampled_noise = np.clip(
              np.random.normal(0, 0.1, size=(self.batch_size, 2)),
              -0.3, 0.3)
        sampled_action_names = [self.actions[x] for x in sampled_actions]
        next_states = [self.next_states[a][s]
                       for s, a in zip(sampled_states, sampled_action_names)]
        rewards = np.array([self.rewards[a][s]
                            for s, a in zip(sampled_states,
                                            sampled_action_names)])
        states = np.array(
            [self.inverse_index_states[x] for x in sampled_states])
        next_states = np.array([self.inverse_index_states[x]
                                for x in next_states])
        states = states.astype(np.float64)
        states += 0.5  # Place points in center of grid.
        next_states = next_states.astype(np.float64)
        next_states += 0.5
        if self.add_noise:
          states += sampled_noise
          next_states += sampled_noise

        _, summary = sess.run(
            [train_op, merged_summaries],
            feed_dict={self.s1_ph: states,
                       self.s2_ph: next_states,
                       self.action_ph: sampled_actions,
                       self.rewards_ph: rewards,
                       self.bisim_horizon_ph: bisim_horizon,
                       self.eval_states_ph: eval_states})
        summary_writer.add_summary(summary, i)
        if self.double_period_halfway and i > self.num_iterations / 2.:
          self.target_update_period *= 2
          self.double_period_halfway = False
        if i % self.target_update_period == 0:
          bisim_horizon = 1.0 - bisim_horizon_discount_value
          bisim_horizon_discount_value *= self.bisim_horizon_discount
          sess.run(sync_op)
        # Now compute difference with exact metric.
        self.learned_distance = sess.run(
            eval_op, feed_dict={self.eval_states_ph: eval_states})
        self.learned_distance = np.reshape(self.learned_distance,
                                           (self.num_states, self.num_states))
        metric_difference = np.max(
            abs(self.learned_distance - self.bisim_metric))
        average_metric_difference = np.mean(
            abs(self.learned_distance - self.bisim_metric))
        normalized_learned_distance = (
            self.learned_distance / np.linalg.norm(self.learned_distance))
        normalized_metric_difference = np.max(
            abs(normalized_learned_distance - normalized_bisim_metric))
        average_normalized_metric_difference = np.mean(
            abs(normalized_learned_distance - normalized_bisim_metric))
        error_summary = tf.Summary(value=[
            tf.Summary.Value(tag='Approx/Error',
                             simple_value=metric_difference),
            tf.Summary.Value(tag='Approx/AvgError',
                             simple_value=average_metric_difference),
            tf.Summary.Value(tag='Approx/NormalizedError',
                             simple_value=normalized_metric_difference),
            tf.Summary.Value(tag='Approx/AvgNormalizedError',
                             simple_value=average_normalized_metric_difference),
        ])
        summary_writer.add_summary(error_summary, i)
        sess.run(inc_global_step_op)
        if i % 100 == 0:
          # Collect statistics every 100 steps.
          metric_errors.append(metric_difference)
          average_metric_errors.append(average_metric_difference)
          normalized_metric_errors.append(normalized_metric_difference)
          average_normalized_metric_errors.append(
              average_normalized_metric_difference)
          saver.save(sess, os.path.join(self.base_dir, 'tf_ckpt'),
                     global_step=i)
        if self.debug and i % 100 == 0:
          self.pretty_print_metric(metric_type='learned')
          print('Iteration: {}'.format(i))
          print('Metric difference: {}'.format(metric_difference))
          print('Normalized metric difference: {}'.format(
              normalized_metric_difference))
      if self.add_noise:
        # Finally, if we have noise, we draw a bunch of samples to get estimates
        # of the distances between states.
        sampled_distances = {}
        for _ in range(self.total_final_samples):
          eval_states = []
          for state in range(self.num_states):
            row, col = self.inverse_index_states[state]
            # We place the evaluation states at the center of each grid cell.
            eval_states.append([row + 0.5, col + 0.5])
          eval_states = np.array(eval_states, dtype=np.float64)
          eval_noise = np.clip(
              np.random.normal(0, 0.1, size=(self.num_states, 2)),
              -0.3, 0.3)
          eval_states += eval_noise
          distance_samples = sess.run(
              eval_op, feed_dict={self.eval_states_ph: eval_states})
          distance_samples = np.reshape(distance_samples,
                                        (self.num_states, self.num_states))
          for s1 in range(self.num_states):
            for s2 in range(self.num_states):
              sampled_distances[(tuple(eval_states[s1]),
                                 tuple(eval_states[s2]))] = (
                                     distance_samples[s1, s2])
      else:
        # Otherwise we just use the last evaluation metric.
        sampled_distances = self.learned_distance
    learned_statistics = {
        'num_iterations': self.num_iterations,
        'metric_errors': metric_errors,
        'average_metric_errors': average_metric_errors,
        'normalized_metric_errors': normalized_metric_errors,
        'average_normalized_metric_errors': average_normalized_metric_errors,
        'learned_distances': sampled_distances,
    }
    self.statistics['learned'] = learned_statistics
    if verbose:
      self.pretty_print_metric(metric_type='learned')
Example #14
  def forward_pass(self, data):
    """Computes the test logits of MAML.

    Args:
      data: A `meta_dataset.providers.Episode` containing the data for the
        episode.

    Returns:
      The output logits for the query data in this episode.
    """
    # Have to use one-hot labels since sparse softmax doesn't allow
    # second derivatives.
    support_embeddings_ = self.embedding_fn(
        data.support_images, self.is_training, reuse=tf.AUTO_REUSE)
    support_embeddings = support_embeddings_['embeddings']
    embedding_vars_dict = support_embeddings_['params']

    # TODO(eringrant): Refactor to make use of
    # `functional_backbones.linear_classifier`, which allows Gin-configuration.
    with tf.variable_scope('linear_classifier', reuse=tf.AUTO_REUSE):
      embedding_depth = support_embeddings.shape.as_list()[-1]
      fc_weights = functional_backbones.weight_variable(
          [embedding_depth, self.logit_dim],
          weight_decay=self.classifier_weight_decay)
      fc_bias = functional_backbones.bias_variable([self.logit_dim])

    # A list of variable names, a list of corresponding Variables, and a list
    # of operations (possibly empty) that creates a copy of each Variable.
    (embedding_vars_keys, embedding_vars,
     embedding_vars_copy_ops) = get_embeddings_vars_copy_ops(
         embedding_vars_dict, make_copies=not self.is_training)

    # A Variable for the weights of the fc layer, a Variable for the bias of the
    # fc layer, and a list of operations (possibly empty) that copies them.
    (fc_weights, fc_bias, fc_vars_copy_ops) = get_fc_vars_copy_ops(
        fc_weights, fc_bias, make_copies=not self.is_training)

    fc_vars = [fc_weights, fc_bias]
    num_embedding_vars = len(embedding_vars)
    num_fc_vars = len(fc_vars)

    def _cond(step, *args):
      del args
      num_steps = self.num_update_steps
      if not self.is_training:
        num_steps += self.additional_evaluation_update_steps
      return step < num_steps

    def _body(step, *args):
      """The inner update loop body."""
      updated_embedding_vars = args[0:num_embedding_vars]
      updated_fc_vars = args[num_embedding_vars:num_embedding_vars +
                             num_fc_vars]
      support_embeddings = self.embedding_fn(
          data.support_images,
          self.is_training,
          params=collections.OrderedDict(
              zip(embedding_vars_keys, updated_embedding_vars)),
          reuse=True)['embeddings']

      updated_fc_weights, updated_fc_bias = updated_fc_vars
      support_logits = tf.matmul(support_embeddings,
                                 updated_fc_weights) + updated_fc_bias

      support_logits = support_logits[:, 0:data.way]
      loss = tf.losses.softmax_cross_entropy(data.onehot_support_labels,
                                             support_logits)

      print_op = tf.no_op()
      if self.debug_log:
        print_op = tf.print(['step: ', step, updated_fc_bias[0], 'loss:', loss])

      with tf.control_dependencies([print_op]):
        updated_embedding_vars = gradient_descent_step(
            loss, updated_embedding_vars, self.first_order,
            self.adapt_batch_norm, self.alpha, False)['updated_vars']
        updated_fc_vars = gradient_descent_step(loss, updated_fc_vars,
                                                self.first_order,
                                                self.adapt_batch_norm,
                                                self.alpha,
                                                False)['updated_vars']

        step = step + 1
      return tuple([step] + list(updated_embedding_vars) +
                   list(updated_fc_vars))

    # MAML meta updates using query set examples from an episode.
    if self.zero_fc_layer:
      # To account for variable class sizes, we initialize the output
      # weights to zero. See if truncated normal initialization will help.
      zero_weights_op = tf.assign(fc_weights, tf.zeros_like(fc_weights))
      zero_bias_op = tf.assign(fc_bias, tf.zeros_like(fc_bias))
      fc_vars_init_ops = [zero_weights_op, zero_bias_op]
    else:
      fc_vars_init_ops = fc_vars_copy_ops

    if self.proto_maml_fc_layer_init:
      support_embeddings = self.embedding_fn(
          data.support_images,
          self.is_training,
          params=collections.OrderedDict(
              zip(embedding_vars_keys, embedding_vars)),
          reuse=True)['embeddings']

      prototypes = metric_learners.compute_prototypes(
          support_embeddings, data.onehot_support_labels)
      pmaml_fc_weights = self.proto_maml_fc_weights(
          prototypes, zero_pad_to_max_way=True)
      pmaml_fc_bias = self.proto_maml_fc_bias(
          prototypes, zero_pad_to_max_way=True)
      fc_vars = [pmaml_fc_weights, pmaml_fc_bias]

    # These control dependencies assign the value of each variable to a new copy
    # variable that corresponds to it. This is required at test time for
    # initializing the copies as they are used in place of the original vars.
    with tf.control_dependencies(fc_vars_init_ops + embedding_vars_copy_ops):
      # Make step a local variable as we don't want to save and restore it.
      step = tf.Variable(
          0,
          trainable=False,
          name='inner_step_counter',
          collections=[tf.GraphKeys.LOCAL_VARIABLES])
      loop_vars = [step] + embedding_vars + fc_vars
      step_and_all_updated_vars = tf.while_loop(
          _cond, _body, loop_vars, swap_memory=True)
      step = step_and_all_updated_vars[0]
      all_updated_vars = step_and_all_updated_vars[1:]
      updated_embedding_vars = all_updated_vars[0:num_embedding_vars]
      updated_fc_weights, updated_fc_bias = all_updated_vars[
          num_embedding_vars:num_embedding_vars + num_fc_vars]

    # Forward pass the training images with the updated weights in order to
    # compute the means and variances, to use for the query's batch norm.
    support_set_moments = None
    if not self.transductive_batch_norm:
      support_set_moments = self.embedding_fn(
          data.support_images,
          self.is_training,
          params=collections.OrderedDict(
              zip(embedding_vars_keys, updated_embedding_vars)),
          reuse=True)['moments']

    query_embeddings = self.embedding_fn(
        data.query_images,
        self.is_training,
        params=collections.OrderedDict(
            zip(embedding_vars_keys, updated_embedding_vars)),
        moments=support_set_moments,  # Use support set stats for batch norm.
        reuse=True,
        backprop_through_moments=self.backprop_through_moments)['embeddings']

    query_logits = (tf.matmul(query_embeddings, updated_fc_weights) +
                    updated_fc_bias)[:, 0:data.way]

    return query_logits
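
The inner loop above depends on a gradient_descent_step helper that is not included in this example. A hedged sketch of the interface the calls imply (a dict with an 'updated_vars' key), assuming plain, optionally first-order, gradient descent:

def gradient_descent_step(loss, variables, first_order, adapt_batch_norm,
                          alpha, clip_grads):
  """Hypothetical sketch: one gradient-descent step on variables w.r.t. loss."""
  del adapt_batch_norm, clip_grads  # not modeled in this sketch
  grads = tf.gradients(loss, variables)
  if first_order:
    # Stop gradients so that no second derivatives flow through the update.
    grads = [tf.stop_gradient(g) for g in grads]
  updated_vars = [v - alpha * g for v, g in zip(variables, grads)]
  return {'updated_vars': updated_vars}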
Example #15
def evaluation(model_class=None,
               input_fn=None,
               num_quantitative_examples=1000,
               num_qualitative_examples=50):
    """A function that build the model and eval quali."""

    tensorboard_callback = callback_utils.CustomTensorBoard(
        log_dir=FLAGS.eval_dir,
        batch_update_freq=1,
        split=FLAGS.split,
        num_qualitative_examples=num_qualitative_examples,
        num_steps_per_epoch=FLAGS.num_steps_per_epoch)
    model = model_class()
    checkpoint = tf.train.Checkpoint(model=model,
                                     ckpt_saved_epoch=tf.Variable(
                                         initial_value=-1, dtype=tf.int64))
    val_inputs = input_fn(is_training=False, batch_size=1)
    num_evaluated_epoch = -1

    while True:
        ckpt_path = tf.train.latest_checkpoint(FLAGS.ckpt_dir)
        if ckpt_path:
            ckpt_num_of_epoch = int(ckpt_path.split('/')[-1].split('-')[-1])
            if num_evaluated_epoch == ckpt_num_of_epoch:
                logging.info(
                    'Found old epoch %d ckpt, skip and will check later.',
                    num_evaluated_epoch)
                time.sleep(30)
                continue
            try:
                logging.info('Restoring new checkpoint[epoch:%d] at %s',
                             ckpt_num_of_epoch, ckpt_path)
                checkpoint.restore(ckpt_path)
            except tf.errors.NotFoundError:
                logging.info(
                    'Restoring from checkpoint has failed. Maybe file missing. '
                    'Try again now.')
                time.sleep(3)
                continue
        else:
            logging.info(
                'No checkpoint found at %s, will check again 10 s later..',
                FLAGS.ckpt_dir)
            time.sleep(10)
            continue

        tensorboard_callback.set_epoch_number(ckpt_num_of_epoch)
        logging.info('Start quantitative eval for %d steps...',
                     num_quantitative_examples)
        try:
            # TODO(huangrui): there is still a possibility of a crash due to
            # missing ckpt files.
            model._predict_counter.assign(0)  # pylint: disable=protected-access
            tensorboard_callback.set_model(model)
            tensorboard_callback.on_predict_begin()
            for i, inputs in enumerate(
                    val_inputs.take(num_quantitative_examples), start=1):
                tensorboard_callback.on_predict_batch_begin(batch=i)
                outputs = model(inputs, training=False)
                model._predict_counter.assign_add(1)  # pylint: disable=protected-access
                tensorboard_callback.on_predict_batch_end(batch=i,
                                                          logs={
                                                              'outputs':
                                                              outputs,
                                                              'inputs': inputs
                                                          })
                if i % FLAGS.num_steps_per_log == 0:
                    logging.info('eval progress %d / %d...', i,
                                 num_quantitative_examples)
            tensorboard_callback.on_predict_end()

            num_evaluated_epoch = ckpt_num_of_epoch
            logging.info('Finished eval for epoch %d, sleeping for %d s...',
                         num_evaluated_epoch, 100)
            time.sleep(100)
        except tf.errors.NotFoundError:
            logging.info(
                'Restoring from checkpoint has failed. Maybe file missing. '
                'Try again now.')
            continue