Example #1
    def create_train_op(self,
                        loss,
                        optimizer,
                        update_ops=None,
                        train_outputs=None):
        """Create the train_op of from the loss obtained from model_train_fn.

    Args:
      loss: The loss we compute within model_train_fn.
      optimizer: An instance of `tf.train.Optimizer`.
      update_ops: List of update ops to execute alongside the training op.
      train_outputs: (Optional) A dict with additional tensors the training
        model generates.

    Returns:
      train_op: Op for the training step.
    """
        summarize_gradients = self._summarize_gradients
        if self.is_device_tpu:
            # TPUs don't currently support summaries, so we override the
            # user-provided summarize_gradients option with False.
            if self._summarize_gradients:
                logging.info('We cannot use summarize_gradients on TPUs.')
            summarize_gradients = False
        return contrib_training.create_train_op(
            loss,
            optimizer,
            summarize_gradients=summarize_gradients,
            update_ops=update_ops)
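
Note: for context, a minimal hedged sketch of how contrib_training.create_train_op (tf.contrib.training.create_train_op in TF 1.x) is typically driven outside a wrapper like the one above; the toy data and loss below are illustrative and not part of the example.

import tensorflow as tf
from tensorflow.contrib import training as contrib_training

# Toy regression problem; create_train_op returns a tensor that evaluates the
# total loss and, via control dependencies, applies one gradient update.
x = tf.constant([[1.0], [2.0], [3.0]])
y = tf.constant([[2.0], [4.0], [6.0]])
predictions = tf.layers.dense(x, 1)
loss = tf.losses.mean_squared_error(labels=y, predictions=predictions)

train_op = contrib_training.create_train_op(
    loss, tf.train.GradientDescentOptimizer(0.1), summarize_gradients=False)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(100):
        loss_value = sess.run(train_op)  # one training step; returns that step's loss
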
Example #2
def Encode(source, ckpt_prefix, hparams):
  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
  tf.reset_default_graph()

  g = tf.Graph()
  session = tf.Session(graph=g)

  with g.as_default(), session.as_default():
    A = nx.adjacency_matrix(source, weight=None)

    x = tf.one_hot(
        list(source.nodes()), source.number_of_nodes(), dtype=tf.float64)
    y = tf.convert_to_tensor(A.todense(), dtype=tf.float64)

    layer = tf.layers.dense(x, hparams.embedding_size, use_bias=False)
    for _ in range(hparams.num_dnn_layers):
      layer = tf.layers.dense(
          layer, hparams.embedding_size * 4, activation=tf.nn.tanh)
    logits = tf.layers.dense(
        layer, source.number_of_nodes(), activation=tf.nn.tanh)

    loss = AdjMatrixLoss(logits, y)

    train_op = contrib_training.create_train_op(
        loss,
        tf.train.AdamOptimizer(hparams.learning_rate),
        summarize_gradients=False)

    session.run(tf.global_variables_initializer())

    for _ in range(hparams.train_num_epochs):
      session.run(train_op)

    tf.train.Saver(tf.trainable_variables()).save(session, ckpt_prefix)
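
Note: a hedged usage sketch for the Encode() snippet above. AdjMatrixLoss is assumed to come from the same project, and the hparams fields simply mirror the attributes the function reads; the values are illustrative.

import networkx as nx
import tensorflow as tf

hparams = tf.contrib.training.HParams(
    embedding_size=16,
    num_dnn_layers=2,
    learning_rate=0.01,
    train_num_epochs=100)

source_graph = nx.karate_club_graph()  # any graph whose node ids are 0..N-1
Encode(source_graph, '/tmp/graph_encoder.ckpt', hparams)
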
Example #3
    def MakeEncoder(self, source, ckpt_prefix, hparams):
        #tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
        tf.reset_default_graph()

        g = tf.Graph()
        session = tf.Session(graph=g)

        with g.as_default(), session.as_default():
            ##################### Definitions
            ######### Reshape the data
            A = nx.adjacency_matrix(
                source, weight='weight')  # weighted adjacency matrix; element (i, j) is the edge from i to j

            x = tf.one_hot(list(source.nodes()),
                           source.number_of_nodes(),
                           dtype=tf.float64)
            y = tf.convert_to_tensor(A.todense(), dtype=tf.float64)

            ########### Define the model
            layer = tf.layers.dense(x, hparams.embedding_size, use_bias=False)

            for _ in range(hparams.num_dnn_layers):
                layer = tf.layers.dense(layer,
                                        hparams.embedding_size * 4,
                                        activation=tf.nn.tanh)

            logits = tf.layers.dense(layer,
                                     source.number_of_nodes(),
                                     activation=tf.nn.tanh)

            ############# Define the loss
            loss = self.AdjMatrixLoss(logits, y)

            ############# Define backpropagation (training op)
            train_op = contrib_training.create_train_op(
                loss,
                tf.train.AdamOptimizer(hparams.learning_rate),
                summarize_gradients=False)

            ###################### Run
            session.run(tf.global_variables_initializer())

            for _ in range(hparams.train_num_epochs):
                session.run(train_op)

            ##################### Save
            tf.train.Saver(tf.trainable_variables()).save(session, ckpt_prefix)
Example #4
    def create_train_op(
            self,
            loss,
            optimizer,
            update_ops=None,
            train_outputs=None,
            filter_trainables_fn=None):  # pylint: disable=line-too-long
        """Create the train_op of from the loss obtained from model_train_fn.

    Args:
      loss: The loss we compute within model_train_fn.
      optimizer: An instance of `tf.train.Optimizer`.
      update_ops: List of update ops to execute alongside the training op.
      train_outputs: (Optional) A dict with additional tensors the training
        model generates.
      filter_trainables_fn: (Optional) A function that takes a trainable
        TensorFlow variable and returns whether it should be updated or not.
        By default, all trainable variables are updated.

    Returns:
      train_op: Op for the training step.
    """
        summarize_gradients = self._summarize_gradients
        if self.is_device_tpu:
            # TPUs don't currently support summaries, so we override the
            # user-provided summarize_gradients option with False.
            if self._summarize_gradients:
                logging.info('We cannot use summarize_gradients on TPUs.')
            summarize_gradients = False
        variables_to_train = None
        if filter_trainables_fn is not None:
            logging.info('Filtering trainable variables')
            variables_to_train = [
                var for var in tf.trainable_variables()
                if filter_trainables_fn(var)
            ]
            logging.info('Only updating the following trainables:')
            for var in variables_to_train:
                logging.info('  %s', var.name)
        return contrib_training.create_train_op(
            loss,
            optimizer,
            summarize_gradients=summarize_gradients,
            update_ops=update_ops,
            variables_to_train=variables_to_train)
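
Note: a hedged sketch of a filter_trainables_fn as accepted by the method above; the scope name and the model instance are hypothetical.

# Hypothetical predicate: freeze everything under a 'feature_extractor' scope
# and keep updating the remaining trainable variables.
def keep_head_only(var):
    return not var.op.name.startswith('feature_extractor')

train_op = model.create_train_op(  # `model` is a hypothetical instance of the class above
    loss, optimizer, filter_trainables_fn=keep_head_only)
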
Example #5
    def create_train_op(self,
                        loss,
                        optimizer,
                        update_ops=None,
                        train_outputs=None):
        """Create meta-training op.

    MAMLModel has a configurable var_scope used to select which variables to
    train on. Note that MAMLInnerLoopGradientDescent also has such a parameter
    to decide which variables to update in the *inner* loop. If you don't want
    to update a set of variables in both the inner and outer loop, you'll need
    to configure var_scope for both MAMLModel *and*
    MAMLInnerLoopGradientDescent.

    Args:
      loss: The loss we compute within model_train_fn.
      optimizer: An instance of `tf.train.Optimizer`.
      update_ops: List of update ops to execute alongside the training op.
      train_outputs: (Optional) A dict with additional tensors the training
        model generates.

    Returns:
      train_op: Op for the training step.
    """
        vars_to_train = tf.trainable_variables()
        if self._var_scope is not None:
            vars_to_train = [
                v for v in vars_to_train
                if v.op.name.startswith(self._var_scope)
            ]
        summarize_gradients = self._summarize_gradients
        if self.is_device_tpu:
            # TPUs don't currently support summaries, so we override the
            # user-provided summarize_gradients option with False.
            if self._summarize_gradients:
                logging.info('We cannot use summarize_gradients on TPUs.')
            summarize_gradients = False
        return contrib_training.create_train_op(
            loss,
            optimizer,
            variables_to_train=vars_to_train,
            summarize_gradients=summarize_gradients,
            update_ops=update_ops)
Example #6
def set_train_op(model_config):
  """Sets the train op for a single weights update."""
  with tf.name_scope('train-op-creation'):
    if model_config.hparams.opt == 'adm':
      optimizer = tf.train.AdamOptimizer(
          learning_rate=model_config.hparams.lr, beta1=0.9, beta2=0.999)
    elif model_config.hparams.opt == 'sgd':
      optimizer = tf.train.GradientDescentOptimizer(
          learning_rate=model_config.hparams.lr)
    elif model_config.hparams.opt == 'mtm':
      optimizer = tf.train.MomentumOptimizer(
          learning_rate=model_config.hparams.lr, momentum=0.9)
    # In TPU training this wraps the optimizer in a CrossShardOptimizer for
    #   you.
    if model_config.hparams.sync == 't':
      assert model_config.hparams.gpuc > 0
      assert model_config.hparams.vbs > 0
      optimizer = tf.train.SyncReplicasOptimizer(
          optimizer,
          replicas_to_aggregate=model_config.hparams.vbs,
          total_num_replicas=model_config.hparams.gpuc)
    else:
      optimizer = model_config.wrap_optimizer(optimizer)
    variables_to_train = tf.compat.v1.trainable_variables()
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    if model_config.hparams.ob == 'f':
      variables_to_train = contrib_framework.filter_variables(
          variables_to_train, exclude_patterns=['explicit_embedding_cnn'])
      update_ops = contrib_framework.filter_variables(
          update_ops, exclude_patterns=['explicit_embedding_cnn'])

    model_config.train_op = contrib_training.create_train_op(
        model_config.loss,
        optimizer=optimizer,
        update_ops=update_ops,
        variables_to_train=variables_to_train,
        # transform_grads_fn=opt_util.clip_by_global_norm,
        summarize_gradients=False,
        colocate_gradients_with_ops=False)
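
Note: a hedged sketch of the hparams switches that set_train_op() branches on. The field names ('opt', 'lr', 'sync', 'gpuc', 'vbs', 'ob') come from the code above, the values are illustrative, and model_config is assumed to be the project's config object with loss and wrap_optimizer already set up.

# 'adm' = Adam, 'sgd' = gradient descent, 'mtm' = momentum.
model_config.hparams.opt = 'adm'
model_config.hparams.lr = 1e-3
model_config.hparams.sync = 'f'  # 't' would wrap the optimizer in SyncReplicasOptimizer
model_config.hparams.ob = 't'    # 'f' would exclude 'explicit_embedding_cnn' variables

set_train_op(model_config)       # stores the resulting op on model_config.train_op
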
Example #7
def train_ddpg(dataset,
               policy,
               actor_optimizer=None,
               critic_optimizer=None,
               pack_transition_fn=None,
               ddpg_graph_fn=None,
               log_dir=None,
               master='local',
               task=0,
               training_steps=None,
               max_training_steps=100000,
               reuse=False,
               init_checkpoint=None,
               update_target_every_n_steps=50,
               log_every_n_steps=None,
               save_checkpoint_steps=500,
               save_summaries_steps=500):
  """Self-contained learning loop for offline Q-learning.

  Code inspired by OpenAI Baselines' deepq.build_train. This function is
  compatible with discrete Q-learning graphs, continuous Q learning graphs, and
  SARSA.

  Args:
    dataset: tf.data.Dataset providing transitions.
    policy: Instance of TFDQNPolicy class that provides functor for building the
      critic function.
    actor_optimizer: Optional instance of an optimizer for the actor network.
      If not specified, creates an AdamOptimizer using the default constructor.
    critic_optimizer: Optional instance of an optimizer for the critic network.
      If not specified, creates an AdamOptimizer using the default constructor.
    pack_transition_fn: Optional function that performs additional processing
      of the transition. This is a convenience method for ad-hoc manipulation of
      transition data passed to the learning function after parsing.
    ddpg_graph_fn: Function used to construct training objectives w.r.t. critic
      outputs.
    log_dir: Where to save model checkpoints and tensorboard summaries.
    master: Optional address of master worker. Specify this when doing
      distributed training.
    task: Optional worker task for distributed training. Defaults to solo master
      task on a single machine.
    training_steps: Optional number of steps to run training before terminating
      early. max_training_steps remains unchanged; training will terminate
      after max_training_steps whether or not training_steps is specified.
    max_training_steps: Maximum number of training iterations.
    reuse: If True, reuse existing variables for all declared variables by this
      function.
    init_checkpoint: Optional checkpoint to restore prior to training. If not
      provided, variables are initialized using global_variables_initializer().
    update_target_every_n_steps: How many global steps (training) between
      copying the Q network weights (scope='q_func') to target network
      (scope='target_q_func').
    log_every_n_steps: How many global steps between logging loss tensors.
    save_checkpoint_steps: How many global steps between saving TF variables
      to a checkpoint file.
    save_summaries_steps: How many global steps between saving TF summaries.

  Returns:
    A (global_step, done) tuple: the current `global_step` reached after
    training for training_steps (or `max_training_steps` if that limit was
    hit), and a bool indicating whether `max_training_steps` was reached.

  """
  data_iterator = dataset.make_one_shot_iterator()

  transition = data_iterator.get_next()
  if pack_transition_fn:
    transition = pack_transition_fn(transition)

  if actor_optimizer is None:
    actor_optimizer = tf.train.AdamOptimizer()
  if critic_optimizer is None:
    critic_optimizer = tf.train.AdamOptimizer()

  a_func = policy.get_a_func(is_training=True, reuse=reuse)
  q_func = policy.get_q_func(is_training=True, reuse=reuse)
  actor_loss, critic_loss, all_summaries = ddpg_graph_fn(
      a_func, q_func, transition)

  a_func_vars = contrib_framework.get_trainable_variables(scope='a_func')
  q_func_vars = contrib_framework.get_trainable_variables(scope='q_func')
  target_q_func_vars = contrib_framework.get_trainable_variables(
      scope='target_q_func')

  # with tf.variable_scope('ddpg', use_resource=True):
  global_step = tf.train.get_or_create_global_step()

  # CRITIC OPTIMIZATION
  # Only optimize q_func and update its batchnorm params.
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='q_func')
  critic_train_op = contrib_training.create_train_op(
      critic_loss,
      critic_optimizer,
      global_step=global_step,
      update_ops=update_ops,
      summarize_gradients=True,
      variables_to_train=q_func_vars,
  )

  # ACTOR OPTIMIZATION
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='a_func')
  actor_train_op = contrib_training.create_train_op(
      actor_loss,
      actor_optimizer,
      global_step=None,
      summarize_gradients=True,
      variables_to_train=a_func_vars,
  )
  # Combine losses to train both actor and critic simultaneously.
  train_op = critic_train_op + actor_train_op

  chief_hooks = []
  hooks = []
  # Save summaries periodically.
  if save_summaries_steps is not None:
    chief_hooks.append(tf.train.SummarySaverHook(
        save_steps=save_summaries_steps,
        output_dir=log_dir, summary_op=all_summaries))

  # Stop after training_steps
  if max_training_steps:
    hooks.append(tf.train.StopAtStepHook(last_step=max_training_steps))

  # Report if loss tensor is NaN.
  hooks.append(tf.train.NanTensorHook(actor_loss))
  hooks.append(tf.train.NanTensorHook(critic_loss))

  if log_every_n_steps is not None:
    tensor_dict = {
        'global_step': global_step,
        'actor loss': actor_loss,
        'critic_loss': critic_loss
    }
    chief_hooks.append(
        tf.train.LoggingTensorHook(tensor_dict, every_n_iter=log_every_n_steps))

    # Measure how fast we are training per sec and save to summary.
    chief_hooks.append(tf.train.StepCounterHook(
        every_n_steps=log_every_n_steps, output_dir=log_dir))

  # If a target network exists, periodically copy the Q network weights into it
  # (frozen target network). We hack this by abusing a LoggingTensorHook.
  if target_q_func_vars and update_target_every_n_steps is not None:
    update_target_expr = []
    for var, var_t in zip(sorted(q_func_vars, key=lambda v: v.name),
                          sorted(target_q_func_vars, key=lambda v: v.name)):
      update_target_expr.append(var_t.assign(var))
    update_target_expr = tf.group(*update_target_expr)

    with tf.control_dependencies([update_target_expr]):
      update_target = tf.constant(0)
    chief_hooks.append(
        tf.train.LoggingTensorHook({'update_target': update_target},
                                   every_n_iter=update_target_every_n_steps))

  # Save checkpoints periodically, save all of them.
  saver = tf.train.Saver(max_to_keep=None)
  chief_hooks.append(tf.train.CheckpointSaverHook(
      log_dir, save_steps=save_checkpoint_steps, saver=saver,
      checkpoint_basename='model.ckpt'))

  # Save our experiment params to checkpoint dir.
  chief_hooks.append(gin.tf.GinConfigSaverHook(log_dir, summarize_config=True))

  session_config = tf.ConfigProto(log_device_placement=True)

  init_fn = None
  if init_checkpoint:
    assign_fn = contrib_framework.assign_from_checkpoint_fn(
        init_checkpoint, contrib_framework.get_model_variables())
    init_fn = lambda _, sess: assign_fn(sess)
  scaffold = tf.train.Scaffold(saver=saver, init_fn=init_fn)
  with tf.train.MonitoredTrainingSession(
      master=master,
      is_chief=(task == 0),
      config=session_config,
      checkpoint_dir=log_dir,
      scaffold=scaffold,
      hooks=hooks,
      chief_only_hooks=chief_hooks) as sess:
    np_step = 0
    while not sess.should_stop():
      np_step, _ = sess.run([global_step, train_op])
      if training_steps and np_step % training_steps == 0:
        break
    done = np_step >= max_training_steps
  return np_step, done
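
Note: one detail of the loop above worth spelling out is that create_train_op returns a scalar loss tensor carrying a control dependency on the gradient update, so summing the critic and actor ops yields a single tensor that applies both updates in one session.run. A minimal hedged sketch (names as in the function body above):

# Evaluating the sum forces both loss identities, and therefore both
# apply_gradients ops, to run; the value is critic loss + actor loss.
combined_train_op = critic_train_op + actor_train_op
step_loss = sess.run(combined_train_op)
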
Example #8
def Score(source, target, ckpt_prefix, hparams):
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
    tf.reset_default_graph()

    g = tf.Graph()
    session = tf.Session(graph=g)

    with g.as_default(), session.as_default():
        A = nx.adjacency_matrix(target, weight=None)

        x = tf.one_hot(list(target.nodes()),
                       target.number_of_nodes(),
                       dtype=tf.float64)
        y = tf.convert_to_tensor(A.todense(), dtype=tf.float64)

        with tf.variable_scope('attention'):
            attention = tf.layers.dense(x,
                                        source.number_of_nodes(),
                                        use_bias=False)
            source_node_prob = tf.nn.softmax(attention)

        layer = tf.layers.dense(source_node_prob,
                                hparams.embedding_size,
                                use_bias=False)
        for _ in range(hparams.num_dnn_layers):
            layer = tf.layers.dense(layer,
                                    hparams.embedding_size * 4,
                                    activation=tf.nn.tanh)
        logits = tf.layers.dense(layer,
                                 source.number_of_nodes(),
                                 activation=tf.nn.tanh)

        with tf.variable_scope('attention_reverse'):
            attention_reverse = tf.layers.dense(logits,
                                                target.number_of_nodes())
            target_neighbors_pred = tf.nn.sigmoid(attention_reverse)
            target_neighbors_prob = ProbFromCounts(target_neighbors_pred)

        loss = AdjMatrixLoss(attention_reverse, y)

        if hparams.get('node_label_loss_coefficient', None):
            label_loss = NodeLabelLoss(source, source_node_prob, target,
                                       hparams.num_node_labels)
            label_loss += NeighborNodesLabelLoss(target_neighbors_prob, target,
                                                 hparams.num_node_labels)
            loss += label_loss * hparams.node_label_loss_coefficient

        if hparams.get('incident_label_loss_coefficient', None):
            edge_loss = EdgeLabelLoss(source, source_node_prob, target,
                                      hparams.num_edge_labels)
            edge_loss += NeighborEdgesLabelsLoss(target_neighbors_prob, target,
                                                 hparams.num_edge_labels)
            loss += edge_loss * hparams.incident_label_loss_coefficient

        vars_to_restore = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                            scope='(?!attention)')
        vars_to_train = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          scope='attention')

        train_op = contrib_training.create_train_op(
            loss,
            tf.train.AdamOptimizer(hparams.learning_rate),
            variables_to_train=vars_to_train,
            summarize_gradients=False)

        session.run(tf.global_variables_initializer())

        tf.train.Saver(vars_to_restore).restore(session, ckpt_prefix)

        losses = []

        for _ in range(hparams.score_num_epochs):
            losses.append(session.run([train_op, loss])[1])

    return losses[-hparams.score_window:]
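
Note: a hedged usage sketch for Score(). The checkpoint prefix is assumed to point at weights saved by the Encode() example earlier, the graphs are placeholders, and the hparams fields mirror the attributes the function reads (illustrative values, label losses disabled).

hparams = tf.contrib.training.HParams(
    embedding_size=16,
    num_dnn_layers=2,
    learning_rate=0.01,
    score_num_epochs=50,
    score_window=5,
    node_label_loss_coefficient=0.0,      # falsy: skip the node-label losses
    incident_label_loss_coefficient=0.0)  # falsy: skip the edge-label losses

# Returns the last `score_window` per-epoch losses of the attention layers.
window_losses = Score(source_graph, target_graph, '/tmp/graph_encoder.ckpt', hparams)
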
Example #9
def define_estimator(mode, features, labels, model_fn, config, params):
    """Add documentation... More information at tf.Estimator class.
  Assumptions:
    features: a dict containing rawfeatures and profeatures
      both: Nb x hf x wf x 3, tf.float32, in [0,1]
    labels: a dict containing rawfeatures and profeatures
      both: Nb x hf x wf, tf.int32, in [0,Nc-1]
  Args:
    features: First item returned by input_fn passed to train, evaluate, and predict.
    labels: Second item returned by input_fn passed to train, evaluate, and predict.
    mode: one of tf.estimator.ModeKeys.
    config: a tf.estimator.RunConfig object...
    params: a tf.train.HParams object...
    ...
  """

    assert mode in _ALLOWED_MODES, (
        'mode should be TRAIN, EVAL or PREDICT from tf.estimator.ModeKeys.')
    assert params.name_feature_extractor in {
        'resnet_v1_50', 'resnet_v1_101'
    }, ('params must have name_feature_extractor attribute in resnet_v1_{50,101}.'
        )
    if params.name_feature_extractor == 'resnet_v1_101':
        raise NotImplementedError(
            'Use of resnet_v1_101 as base feature extractor is not yet implemented.'
        )

    # unpack features
    rawimages = features.get('rawimages')
    rawimagespaths = features.get('rawimagespaths')
    proimages = features['proimages']
    prolabels = labels if labels else None

    ## build a fully convolutional model for semantic segmentation
    # predictions refer to the training class ids
    # for plotting of results (inference) or assessment, predictions should be transformed
    #   using `{inference, evaluation}_problem_def`s
    _, _, predictions = model_fn(mode, proimages, prolabels, config, params)

    # TODO(panos): assert that proimages and predictions have same spatial size

    if mode == tf.estimator.ModeKeys.TRAIN:

        # global step
        global_step = tf.train.get_or_create_global_step()

        # losses
        with tf.variable_scope('losses'):
            losses = define_losses(mode, predictions, prolabels, config,
                                   params)

        # exponential moving averages
        # creates variables in checkpoint with name: 'emas/' + <variable_name> +
        #   {'ExponentialMovingAverage,Momentum}
        # ex.: for 'classifier/logits/Conv/biases' it saves also
        #          'emas/classifier/logits/Conv/biases/ExponentialMovingAverage'
        #      and 'emas/classifier/logits/Conv/biases/Momentum'
        # create_train_op guarantees to run the GraphKeys.UPDATE_OPS collection
        #   before total_loss in every step, but gives no guarantee about
        #   running after any other op; since the EMAs need to be applied after
        #   the gradients, this code may need checking.
        if params.ema_decay > 0:
            with tf.variable_scope('exponential_moving_averages'):
                #for mv in slim.get_model_variables():
                #  print('slim.model_vars:', mv.op.name)
                ema = tf.train.ExponentialMovingAverage(
                    params.ema_decay,
                    num_updates=global_step,
                    zero_debias=True)
                variables_to_ema = []
                for mv in tf.model_variables():
                    if 'BatchNorm/moving' not in mv.name:
                        variables_to_ema.append(mv)
                print(
                    f"\nFound {len(tf.model_variables())} variables, saving exponential "
                    f"moving averages for {len(variables_to_ema)} of them.\n")
                maintain_ema_op = ema.apply(var_list=variables_to_ema)
                tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, maintain_ema_op)

        # create training operation
        with tf.variable_scope('train_ops'):

            # optimizer
            optimizer = define_optimizer(global_step, params)

            # training op
            train_op = create_train_op(
                losses['total'],
                optimizer,
                global_step=global_step,
                # update_ops=tf.get_collection(tf.GraphKeys.UPDATE_OPS),
                summarize_gradients=False,
                # transform_grads_fn=,
                # gradient_multipliers=gradient_multipliers,
                check_numerics=False,
            )

        # TODO: maybe parameterize it
        training_hooks = [
            _RunMetadataHook(params.log_dir,
                             every_n_iter=max(params.num_training_steps // 50,
                                              params.save_checkpoints_steps))
        ]

        # the following block was added for distributed debugging
        if params.distribute:
            tower_context = tf.contrib.distribute.get_tower_context()
            assert tower_context
            print(
                f"Tower {tower_context.tower_id}: _RunMetadataHook is not supported "
                "yet for distributed training.")
            training_hooks = []

        replace_initializers(config, params)

        summaries_data = {
            'features': features,
            'labels': labels,
            'predictions': predictions,
            'losses': losses,
            'learning_rate': optimizer._learning_rate
        }  #pylint: disable=protected-access

        scaffold = _define_scaffold(mode, config, params, summaries_data)
        estimator_spec = tf.estimator.EstimatorSpec(
            mode,
            predictions=predictions,
            loss=losses['total'],
            train_op=train_op,
            training_hooks=training_hooks,
            scaffold=scaffold)

    if mode == tf.estimator.ModeKeys.EVAL:
        with tf.variable_scope('losses'):
            losses = define_losses(mode, predictions, prolabels, config,
                                   params)

        # returns (variable, update_op)
        # TF internal error/problem: _streaming_confusion_matrix internally casts
        # labels and predictions to int64, and since we feed a dictionary, tensors are
        # passed by reference leading them to change type, thus we send an identity
        # confusion_matrix = metrics_impl._streaming_confusion_matrix(  # pylint: disable=protected-access
        #     tf.identity(prolabels),
        #     tf.identity(predictions['decisions']),
        #     params.output_Nclasses)
        # l1_probs, decs = itemgetter('l1_probabilities', 'decisions')(predictions)
        # create a new dict with the supported keys only
        predictions = _map_predictions_to_new_cids(
            predictions, params.training_cids2evaluation_cids)
        if params.replace_voids:
            predictions = _replace_voids(predictions, params)
        # TODO(panos): confusion matrix expects prolabels and predictions to have the same shape
        #   this may not be the case when preserve_aspect_ratio is set and this will give an error
        if hasattr(params, 'preserve_aspect_ratio'):
            if params.preserve_aspect_ratio:
                raise NotImplementedError(
                    'evaluation with preserving aspect ratio is not implemented.'
                )
        predictions = _resize_predictions(predictions,
                                          tf.shape(labels['prolabels'])[1:3],
                                          params)
        tcids2ecids = _replacevoids(params.training_cids2evaluation_cids)
        confusion_matrix = metrics_impl._streaming_confusion_matrix(  # pylint: disable=protected-access
            labels['prolabels'],
            predictions['decisions'],
            # +1 due to convention of starting counting at 0
            max(tcids2ecids) + 1)

        # dict of metrics keyed by name with values tuples of (metric_tensor, update_op)
        # TODO: add more semantic segmentation metrics
        eval_metric_ops = {
            'confusion_matrix':
            (tf.to_int32(confusion_matrix[0]), confusion_matrix[1])
        }

        scaffold = _define_scaffold(mode, config, params)
        estimator_spec = tf.estimator.EstimatorSpec(
            mode,
            predictions=predictions,
            loss=losses['total'],
            eval_metric_ops=eval_metric_ops,
            scaffold=scaffold)

    if mode == tf.estimator.ModeKeys.PREDICT:
        # create a new dict with the supported keys only
        l1_probs, l2_vehicle_probs, l2_human_probs, decs = itemgetter(
            'l1_probabilities', 'l2_vehicle_probabilities',
            'l2_human_probabilities', 'decisions')(predictions)
        predictions = {
            'l1_probabilities': l1_probs,
            'l2_vehicle_probabilities': l2_vehicle_probs,
            'l2_human_probabilities': l2_human_probs,
            'decisions': decs
        }
        # workaround for connecting input pipeline outputs to system output
        # TODO(panos): maybe from a system perspective makes more sense to have mapping and
        #   resizing in the system_factory
        # since these are functions of the system and not the network/estimator
        # new size defaults to provided values
        # if at least one is None then new size is the arbitrary size of rawimage in each step
        new_size = (params.height_system, params.width_system)
        is_arbitrary = not all(new_size)
        if is_arbitrary:
            if rawimages is not None:
                predictions['rawimages'] = rawimages
            if rawimagespaths is not None:
                predictions['rawimagespaths'] = rawimagespaths
            new_size = tf.shape(predictions['rawimages'])[1:3]
        predictions = _resize_predictions(predictions, new_size, params)
        tf.logging.warn(
            'Mapping of predictions to new cids is not implemented for now.')
        # predictions = _map_predictions_to_new_cids(predictions, params.training_cids2inference_cids)
        if params.replace_voids:
            predictions = _replace_voids(predictions, params)

        scaffold = _define_scaffold(mode, config, params)
        estimator_spec = tf.estimator.EstimatorSpec(mode,
                                                    predictions=predictions,
                                                    scaffold=scaffold)

    return estimator_spec
Example #10
    def model_fn(features, labels, mode):
      """Creates the prediction, loss, and train ops.

      Args:
        features: A dictionary of tensors keyed by the feature name.
        labels: A dictionary of label tensors keyed by the label key.
        mode: The execution mode, as defined in tf.contrib.learn.ModeKeys.

      Returns:
        EstimatorSpec with the mode, prediction, loss, train_op and
        output_alternatives a dictionary specifying the output for a
        servo request during serving.
      """
      # 1. Construct input to RNN
      sequence_feature_map = {
          k: features[input_fn.SEQUENCE_KEY_PREFIX + k]
          for k in hparams.sequence_features
      }
      sequence_length = tf.squeeze(
          features[input_fn.CONTEXT_KEY_PREFIX + 'sequenceLength'],
          axis=1,
          name='sq_seq_len')
      tf.summary.scalar('sequence_length', tf.reduce_mean(sequence_length))
      diff_delta_time, obs_values, indicator = construct_input(
          sequence_feature_map, hparams.categorical_values,
          hparams.categorical_seq_feature, hparams.feature_value, mode,
          hparams.normalize, hparams.momentum, hparams.min_value,
          hparams.max_value, hparams.input_keep_prob)

      seq_mask = tf.expand_dims(
          tf.sequence_mask(sequence_length, dtype=tf.float32), axis=2)
      logits, weights = construct_logits(
          diff_delta_time,
          obs_values,
          indicator,
          sequence_length,
          seq_mask,
          hparams,
          reuse=False)

      all_attribution_dict = {}
      if mode == tf.estimator.ModeKeys.TRAIN:
        if hparams.sequence_prediction:
          assert not hparams.use_rnn_attention
          # If we train a sequence_prediction we repeat the labels over time.
          label_tensor = labels[hparams.label_key]
          labels[hparams.label_key] = tf.tile(
              tf.expand_dims(label_tensor, 2),
              multiples=[1, tf.shape(logits)[1], 1])
          if hparams.volatility_loss_factor > 0.0:
            volatility = tf.reduce_sum(
                tf.square(seq_mask *
                          compute_prediction_diff_attribution(logits)))
            tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                                 volatility * hparams.volatility_loss_factor)
        elif not hparams.use_rnn_attention:
          logits = rnn_common.select_last_activations(
              logits, tf.to_int32(sequence_length))
      else:
        if hparams.sequence_prediction:
          last_logits = rnn_common.select_last_activations(
              logits, tf.to_int32(sequence_length))
        else:
          last_logits = logits
        if mode == tf.estimator.ModeKeys.PREDICT:
          delta_time = sequence_feature_map['deltaTime']
          all_attributions = {}
          if hparams.include_gradients_attribution:
            all_attributions['gradient_last'] = compute_gradient_attribution(
                last_logits, obs_values, indicator)
          if hparams.include_gradients_sum_time_attribution:
            assert not hparams.use_rnn_attention
            all_attributions['gradient_sum'] = compute_gradient_attribution(
                _predictions_for_gradients(
                    logits, seq_mask, delta_time,
                    hparams.attribution_max_delta_time, averaged=False),
                obs_values, indicator)
          if hparams.include_gradients_avg_time_attribution:
            assert not hparams.use_rnn_attention
            all_attributions['gradient_avg'] = compute_gradient_attribution(
                _predictions_for_gradients(
                    logits, seq_mask, delta_time,
                    hparams.attribution_max_delta_time, averaged=True),
                obs_values, indicator)
          if hparams.include_path_integrated_gradients_attribution:
            all_attributions['integrated_gradient'] = (
                compute_path_integrated_gradient_attribution(
                    obs_values, indicator, diff_delta_time, delta_time,
                    sequence_length, seq_mask, hparams))
          if hparams.use_rnn_attention:
            all_attributions['rnn_attention'] = weights
          if hparams.include_diff_sequence_prediction_attribution:
            all_attributions['diff_sequence'] = (
                compute_prediction_diff_attribution(logits))

          all_attribution_dict = {}
          for attribution_name, attribution in all_attributions.items():
            attribution_dict = convert_attribution(
                attribution,
                sequence_feature_map,
                seq_mask,
                delta_time,
                hparams.attribution_threshold,
                hparams.attribution_max_delta_time,
                prefix=attribution_name + '-')
            all_attribution_dict.update(attribution_dict)
          if hparams.include_sequence_prediction:
            # Add the predictions at each time step to the attention dictionary.
            attribution_indices = tf.where(seq_mask > 0.5)
            all_attribution_dict['predictions'] = tf.sparse.expand_dims(
                tf.SparseTensor(
                    indices=attribution_indices,
                    values=tf.gather_nd(
                        tf.sigmoid(logits), attribution_indices),
                    dense_shape=tf.to_int64(tf.shape(delta_time))),
                axis=1)
        # At test/inference time we only make a single prediction even if we did
        # sequence_prediction during training.
        logits = last_logits
        seq_mask = None

      probabilities = tf.sigmoid(logits)
      classes = probabilities > 0.5
      predictions = {
          PredictionKeys.LOGITS: logits,
          PredictionKeys.PROBABILITIES: probabilities,
          PredictionKeys.CLASSES: classes
      }
      # Calculate the loss for TRAIN and EVAL, but not PREDICT.
      if mode == tf.estimator.ModeKeys.PREDICT:
        loss = None
      else:
        loss = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=labels[hparams.label_key],
            logits=predictions[PredictionKeys.LOGITS])
        if hparams.sequence_prediction:
          loss *= seq_mask
        loss = tf.reduce_mean(loss)
        regularization_losses = tf.losses.get_regularization_losses()
        if regularization_losses:
          tf.summary.scalar('loss/prior_regularization', loss)
          regularization_loss = tf.add_n(regularization_losses)
          tf.summary.scalar('loss/regularization_loss', regularization_loss)
          loss += regularization_loss
        tf.summary.scalar('loss', loss)

      train_op = None
      if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.AdamOptimizer(
            learning_rate=hparams.learning_rate, beta1=0.9, beta2=0.999,
            epsilon=1e-8)
        optimizer = contrib_estimator.clip_gradients_by_norm(optimizer, 6.0)
        train_op = contrib_training.create_train_op(
            total_loss=loss, optimizer=optimizer, summarize_gradients=False)
      if mode != tf.estimator.ModeKeys.TRAIN:
        for k, v in all_attribution_dict.items():
          if not isinstance(v, tf.SparseTensor):
            raise ValueError('Expect attributions to be in SparseTensor, '
                             'getting %s for feature %s' %
                             (v.__class__.__name__, k))
          predictions['attention_attribution,%s,indices' % k] = v.indices
          predictions['attention_attribution,%s,values' % k] = v.values
          predictions['attention_attribution,%s,shape' % k] = v.dense_shape

      eval_metric_ops = {}
      if mode == tf.estimator.ModeKeys.EVAL:
        auc = tf.metrics.auc
        prob_k = PredictionKeys.PROBABILITIES
        class_k = PredictionKeys.CLASSES
        m = 'careful_interpolation'
        metric_fn_dict = {
            'auc-roc':
                lambda l, p: auc(l, p[prob_k], curve='ROC', summation_method=m),
            'auc-pr':
                lambda l, p: auc(l, p[prob_k], curve='PR', summation_method=m),
            'accuracy':
                lambda l, p: tf.metrics.accuracy(l, p[class_k]),
        }
        for (k, f) in metric_fn_dict.items():
          eval_metric_ops[k] = f(label_tensor, predictions)
      # Define the output for serving.
      export_outputs = {}
      if mode == tf.estimator.ModeKeys.PREDICT:
        export_outputs = {
            'mortality': tf.estimator.export.PredictOutput(predictions)
        }

      return tf.estimator.EstimatorSpec(
          mode=mode,
          predictions=predictions,
          loss=loss,
          train_op=train_op,
          eval_metric_ops=eval_metric_ops,
          export_outputs=export_outputs)
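
Note: for completeness, a hedged sketch of how a model_fn like the one above is typically handed to tf.estimator; the input functions and model directory are placeholders.

estimator = tf.estimator.Estimator(
    model_fn=model_fn,           # the closure defined above
    model_dir='/tmp/model_dir')  # placeholder checkpoint/summary directory

estimator.train(input_fn=train_input_fn, steps=1000)  # train_input_fn assumed
metrics = estimator.evaluate(input_fn=eval_input_fn)  # eval_input_fn assumed
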
Example #11
    def Measure(self, source, target, ckpt_prefix, hparams):
        #tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
        tf.reset_default_graph()

        g = tf.Graph()
        session = tf.Session(graph=g)

        with g.as_default(), session.as_default():
            ########################## Define the model
            A = nx.adjacency_matrix(target, weight=None)

            x = tf.one_hot(list(target.nodes()),
                           target.number_of_nodes(),
                           dtype=tf.float64)
            y = tf.convert_to_tensor(A.todense(), dtype=tf.float64)

            with tf.variable_scope('attention'):
                attention = tf.layers.dense(x,
                                            source.number_of_nodes(),
                                            use_bias=False)
                source_node_prob = tf.nn.softmax(attention)

            layer = tf.layers.dense(source_node_prob,
                                    hparams.embedding_size,
                                    use_bias=False)
            for _ in range(hparams.num_dnn_layers):
                layer = tf.layers.dense(layer,
                                        hparams.embedding_size * 4,
                                        activation=tf.nn.tanh)
            logits = tf.layers.dense(layer,
                                     source.number_of_nodes(),
                                     activation=tf.nn.tanh)

            with tf.variable_scope('attention_reverse'):
                attention_reverse = tf.layers.dense(logits,
                                                    target.number_of_nodes())
                # target_neighbors_pred = tf.nn.sigmoid(attention_reverse)
                # target_neighbors_prob = ProbFromCounts(target_neighbors_pred)  # used for the label-loss computation

            ########################### Define the loss
            loss = self.AdjMatrixLoss(attention_reverse, y)

            ########################### Define training
            ## restore all parameters except attention
            vars_to_restore = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                                scope='(?!attention)')
            ## train the attention / attention_reverse parameters
            vars_to_train = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                              scope='attention')

            train_op = contrib_training.create_train_op(
                loss,
                tf.train.AdamOptimizer(hparams.learning_rate),
                variables_to_train=vars_to_train,
                summarize_gradients=False)

            ############################ Run
            session.run(tf.global_variables_initializer())

            tf.train.Saver(vars_to_restore).restore(session, ckpt_prefix)
            losses = []

            for _ in range(hparams.score_num_epochs):
                losses.append(session.run([train_op, loss])[1])

            return losses[-hparams.score_window:]
Example #12
def define_estimator(mode, features, labels, model_fn, config, params):
    """
  Assumptions:
    features: a dict containing rawfeatures and profeatures
      both: Nb x hf x wf x 3, tf.float32, in [0,1]
    labels: a dict containing rawfeatures and profeatures
      both: Nb x hf x wf, tf.int32, in [0,Nc-1]
  Args:
    features: First item returned by input_fn passed to train, evaluate, and predict.
    labels: Second item returned by input_fn passed to train, evaluate, and predict.
    mode: one of tf.estimator.ModeKeys.
  """

    assert mode in _ALLOWED_MODES, (
        'mode should be TRAIN, EVAL or PREDICT from tf.estimator.ModeKeys.')
    assert params.name_feature_extractor in {
        'resnet_v1_50', 'resnet_v1_101'
    }, ('params must have name_feature_extractor attribute in resnet_v1_{50,101}.'
        )
    if params.name_feature_extractor == 'resnet_v1_101':
        raise NotImplementedError(
            'Use of resnet_v1_101 as base feature extractor is not yet implemented.'
        )

    # unpack features
    rawimages = features['rawimages']
    proimages = features['proimages']
    # TODO: fix this temporary workaround for labels
    # rawlabels = labels['rawlabels'] if mode != tf.estimator.ModeKeys.PREDICT else None
    prolabels = labels[
        'prolabels'] if mode != tf.estimator.ModeKeys.PREDICT else None

    print('debug:rawimages:', rawimages)
    print('debug:proimages:', proimages)
    print('debug:prolabels:', prolabels)

    ## build a fully convolutional model for semantic segmentation
    _, _, predictions = model_fn(mode, proimages, prolabels, config, params)

    # print('debug: predictions:', predictions)
    ## create training ops and exponential moving averages
    if mode == tf.estimator.ModeKeys.TRAIN:

        # global step
        global_step = tf.train.get_or_create_global_step()

        # losses
        with tf.variable_scope('losses'):
            losses = define_losses(mode, config, params, predictions,
                                   prolabels)

        # exponential moving averages
        # creates variables in checkpoint with name: 'emas/' + <variable_name> +
        #   {'ExponentialMovingAverage,Momentum}
        # ex.: for 'classifier/logits/Conv/biases' it saves also
        #          'emas/classifier/logits/Conv/biases/ExponentialMovingAverage'
        #      and 'emas/classifier/logits/Conv/biases/Momentum'
        # create_train_op guarantees to run the GraphKeys.UPDATE_OPS collection
        #   before total_loss in every step, but gives no guarantee about
        #   running after any other op; since the EMAs need to be applied after
        #   the gradients, this code may need checking.
        if params.ema_decay > 0:
            with tf.name_scope('exponential_moving_averages'):
                #for mv in slim.get_model_variables():
                #  print('slim.model_vars:', mv.op.name)
                ema = tf.train.ExponentialMovingAverage(
                    params.ema_decay,
                    num_updates=global_step,
                    zero_debias=True)
                maintain_ema_op = ema.apply(var_list=tf.model_variables())
                tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, maintain_ema_op)

        # create training operation
        with tf.variable_scope('train_ops'):
            learning_rate = tf.train.piecewise_constant(
                global_step, params.lr_boundaries, params.lr_values)
            # optimizer
            if params.optimizer == 'SGDM':
                optimizer = tf.train.MomentumOptimizer(
                    learning_rate,
                    params.momentum,
                    use_nesterov=params.use_nesterov)
            elif params.optimizer == 'SGD':
                optimizer = tf.train.GradientDescentOptimizer(learning_rate)
            # training op
            train_op = create_train_op(
                losses['total'],
                optimizer,
                global_step=global_step,
                #update_ops=tf.get_collection(tf.GraphKeys.UPDATE_OPS),
                # summarize_gradients=True,
                # #clip_gradient_norm=params.clip_grad_norm,
                # #gradient_multipliers=gradient_multipliers,
                check_numerics=False,
            )

        # TODO: maybe parameterize it
        training_hooks = [
            _RunMetadataHook(params.log_dir,
                             every_n_iter=max(params.num_training_steps // 50,
                                              params.save_checkpoints_steps))
        ]

        summaries_data = {
            'features': features,
            'labels': labels,
            'predictions': predictions,
            'losses': losses,
            'learning_rate': learning_rate
        }

    # flatten and concatenate decisions
    if mode in [tf.estimator.ModeKeys.EVAL, tf.estimator.ModeKeys.PREDICT]:
        # don't forget to change confusion matrix outputs
        # C: 28, M: 66, G: 44, E: 71
        flatten_decs = _flatten_all_decs(predictions['decisions'])
        # flatten_decs = _flatten_for_cityscapes_val(predictions['decisions'])
        # flatten_decs = _flatten_for_mapillary_val(predictions['decisions'])
        # flatten_decs = _flatten_for_cityscapes_extended_val(predictions['decisions'])
        # flatten_decs = _flatten_for_gtsdb_val(predictions['decisions'])

    if mode == tf.estimator.ModeKeys.EVAL:
        with tf.variable_scope('losses'):
            losses = define_losses(mode, config, params, predictions,
                                   prolabels)

        # returns (variable, update_op)
        # TF internal error/problem: _streaming_confusion_matrix internally casts
        # labels and predictions to int64, and since we feed a dictionary, tensors are
        # passed by reference leading them to change type, thus we send an identity
        # confusion_matrix = metrics_impl._streaming_confusion_matrix(  # pylint: disable=protected-access
        #     tf.identity(prolabels),
        #     tf.identity(predictions['decisions']),
        #     params.training_Nclasses)
        confusion_matrix = metrics_impl._streaming_confusion_matrix(  # pylint: disable=protected-access
            prolabels, flatten_decs, 44)

        # dict of metrics keyed by name with values tuples of (metric_tensor, update_op)
        # TODO: add more semantic segmentation metrics
        eval_metric_ops = {
            'confusion_matrix':
            (tf.to_int32(confusion_matrix[0]), confusion_matrix[1])
        }

    ## create EstimatorSpec according to mode
    # unpack predictions
    if mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL]:
        predictions = None
    else:
        # redefine predictions according to estimator requirements
        predictions = {
            'logits': predictions['logits'][0][0],
            'probabilities': predictions['probabilities'][0][0],
            # 'decisions': predictions['decisions'][0],
            'decisions': flatten_decs,
        }

    if mode == tf.estimator.ModeKeys.TRAIN:
        scaffold = _define_scaffold(mode, config, params, summaries_data)
        estimator_spec = tf.estimator.EstimatorSpec(
            mode,
            predictions=predictions,
            loss=losses['total'],
            train_op=train_op,
            training_hooks=training_hooks,
            scaffold=scaffold)
    elif mode == tf.estimator.ModeKeys.EVAL:
        scaffold = _define_scaffold(mode, config, params)
        estimator_spec = tf.estimator.EstimatorSpec(
            mode,
            predictions=predictions,
            loss=losses['total'],
            eval_metric_ops=eval_metric_ops,
            scaffold=scaffold)
    elif mode == tf.estimator.ModeKeys.PREDICT:
        scaffold = _define_scaffold(mode, config, params)
        # workaround for connecting input pipeline outputs to system output
        # TODO: make it more clear
        predictions['rawimages'] = rawimages
        predictions['rawimagespaths'] = features['rawimagespaths']
        # the expected predictions.keys() in this point is:
        # dict_keys(['logits', 'probabilities', 'decisions', 'rawimages', 'rawimagespaths'])
        estimator_spec = tf.estimator.EstimatorSpec(mode,
                                                    predictions=predictions,
                                                    scaffold=scaffold)

    return estimator_spec
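
Note: a hedged illustration of the piecewise-constant learning-rate schedule built in the train_ops block above; global_step is the tensor created there and the boundary/value numbers are made up.

# With lr_boundaries=[30000, 60000] and lr_values=[1e-2, 1e-3, 1e-4]:
#   step <= 30000          -> 1e-2
#   30000 < step <= 60000  -> 1e-3
#   step  > 60000          -> 1e-4
learning_rate = tf.train.piecewise_constant(
    global_step, [30000, 60000], [1e-2, 1e-3, 1e-4])
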
Example #13
def create_alternating_train_op(losses, optimizer, global_step, params):
    """
    This op control the alternating of training ops

    Every `switch_period`, another train_op will be used with its respective trainable variables

    :param losses: Dictionary of losses in the network
    :param optimizer: optimizer to supply to `create_train_op`
    :param global_step: global step tensor
    :param params:
    :return:
    """
    def clip_function(gv_list):
        if params.gradient_clip_norm > 0.0:
            grad_list, var_list = zip(*gv_list)

            def clipping(x):
                return clip_by_value(x,
                                     clip_value_min=-params.gradient_clip_norm,
                                     clip_value_max=params.gradient_clip_norm)

            grad_list = map(clipping, grad_list)
            # grad_list, _ = clip_by_global_norm(list(grad_list), clip_norm=params.gradient_clip_norm)
            return zip(grad_list, var_list)
        else:
            return gv_list

    # We only alternate the training op when it is enabled AND there are actually two domains
    switch_train_op = params.switch_train_op and len(params.tfrecords_list) > 1

    if switch_train_op:
        switch_period = params.switch_period

        variables_sem_seg = tf.trainable_variables()
        variables_dom_class = tf.trainable_variables('domain_classifier')
        for var in variables_dom_class:
            variables_sem_seg.remove(var)

        log.debug(
            'We are switching the train ops between Sem Seg %i tensors and Dom Class %i tensors'
            % (len(variables_sem_seg), len(variables_dom_class)))
        condition = tf.greater_equal(tf.mod(global_step, 2 * switch_period),
                                     switch_period)
        train_op = tf.cond(condition,
                           true_fn=lambda: create_train_op(
                               losses['total'],
                               optimizer,
                               variables_to_train=variables_sem_seg,
                               global_step=global_step,
                               check_numerics=False,
                               transform_grads_fn=clip_function),
                           false_fn=lambda: create_train_op(
                               losses['domain'],
                               optimizer,
                               variables_to_train=variables_dom_class,
                               global_step=global_step,
                               check_numerics=False,
                               transform_grads_fn=clip_function))
        tf.summary.scalar('Switch_condition',
                          tf.cast(condition, tf.int16),
                          family='optimizer')
    else:
        train_op = create_train_op(
            losses['total'],
            optimizer,
            global_step=global_step,
            check_numerics=False,
        )
    return train_op
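
Note: a hedged illustration of the alternation schedule that the tf.cond above implements; switch_period=1000 is an arbitrary example.

import numpy as np

switch_period = 1000  # arbitrary example value
steps = np.array([0, 999, 1000, 1999, 2000])
train_sem_seg = (steps % (2 * switch_period)) >= switch_period
# -> [False False  True  True False]: the domain classifier is trained for the
#    first switch_period steps, then the semantic-segmentation variables, and so on.
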
Example #14
    def resnet_model_fn(features, labels, mode, params):
        """Returns the model function."""
        global_step = tf.train.get_global_step()

        feature = features['feature']
        labels = labels['label']
        one_hot_labels = model_utils.get_label(labels,
                                               params,
                                               bird_num_classes,
                                               batch_size=params['batch_size'])

        def get_logits():
            """Return the logits."""
            end_points, aux_logits = None, None
            if FLAGS.model_type == 'resnet':
                avg_pool = model.resnet_v1_model(feature, labels, mode, params)
            else:
                assert False
            name = 'final_dense_dst'
            with tf.variable_scope('target_CLS'):
                logits = tf.layers.dense(
                    inputs=avg_pool,
                    units=bird_num_classes,
                    kernel_initializer=tf.random_normal_initializer(
                        stddev=.01),
                    name=name)
                if end_points is not None:
                    aux_pool = end_points['AuxLogits_Pool']
                    aux_logits = tf.layers.dense(
                        inputs=aux_pool,
                        units=bird_num_classes,
                        kernel_initializer=tf.random_normal_initializer(
                            stddev=.001),
                        name='Aux{}'.format(name))
            return logits, aux_logits, end_points

        logits, _, _ = get_logits()
        logits = tf.cast(logits, tf.float32)

        if FLAGS.model_type == 'resnet':
            dst_loss = tf.losses.softmax_cross_entropy(
                logits=logits,
                weights=1.,
                onehot_labels=one_hot_labels,
                label_smoothing=params['label_smoothing'])
            dst_l2_loss = FLAGS.weight_decay * tf.add_n([
                tf.nn.l2_loss(v) for v in tf.trainable_variables()
                if 'batch_normalization' not in v.name
            ])
            loss = dst_loss + dst_l2_loss

        train_op = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            cur_finetune_step = tf.train.get_global_step()
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                if FLAGS.model_type == 'resnet':
                    finetune_learning_rate = rampcosine()
                else:
                    finetune_learning_rate = rampcosine()
                if FLAGS.optimizer == 'momentum':
                    optimizer = tf.train.MomentumOptimizer(
                        learning_rate=finetune_learning_rate,
                        momentum=params['momentum'],
                        use_nesterov=True)
                elif FLAGS.optimizer == 'RMS':
                    optimizer = tf.train.RMSPropOptimizer(
                        finetune_learning_rate,
                        RMSPROP_DECAY,
                        momentum=RMSPROP_MOMENTUM,
                        epsilon=RMSPROP_EPSILON)
                elif FLAGS.optimizer == 'adam':
                    optimizer = tf.train.AdamOptimizer(finetune_learning_rate)

                optimizer = tf.train.SyncReplicasOptimizer(
                    optimizer,
                    replicas_to_aggregate=FLAGS.sync_replicas,
                    total_num_replicas=run_config.num_worker_replicas)
                train_op = contrib_training.create_train_op(loss, optimizer)
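                # Note: the op returned by create_train_op here is immediately
                # overwritten by optimizer.minimize() below, so only the latter
                # is actually used as the train_op.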
                with tf.variable_scope('finetune'):
                    train_op = optimizer.minimize(loss, cur_finetune_step)
                if FLAGS.moving_average:
                    ema = tf.train.ExponentialMovingAverage(
                        decay=MOVING_AVERAGE_DECAY, num_updates=global_step)
                    variables_to_average = (tf.trainable_variables() +
                                            tf.moving_average_variables())
                    with tf.control_dependencies([train_op]):
                        with tf.name_scope('moving_average'):
                            train_op = ema.apply(variables_to_average)
        else:
            train_op = None

        batch_size = params['batch_size']  # pylint: disable=unused-variable
        eval_metrics = None
        if mode == tf.estimator.ModeKeys.EVAL:
            eval_metrics = model_utils.metric_fn(labels, logits)

        if mode == tf.estimator.ModeKeys.TRAIN:
            with tf.control_dependencies([train_op]):
                tf.summary.scalar('classifier/finetune_loss', loss)
                tf.summary.scalar('classifier/finetune_lr',
                                  finetune_learning_rate)
        else:
            train_op = None

        return tf.estimator.EstimatorSpec(
            mode=mode,
            loss=loss,
            train_op=train_op,
            eval_metric_ops=eval_metrics,
        )