Example #1
0
 def var_avg(self, global_step=None):
   """Maintain a moving average of all trainable variables, if enabled.

   Reads solver.model_average from the config. When 'enable' is set,
   builds an EMA (via self.get_var_avg_ema, presumably
   tf.train.ExponentialMovingAverage — confirm) with the configured
   'var_avg_decay' and adds its apply op to the UPDATE_OPS collection so
   it executes alongside the training step.

   Args:
     global_step: optional global-step tensor forwarded to
       get_var_avg_ema (typically used for decay warm-up).
   """
   model_avg_conf = self.config['solver']['model_average']
   var_avg_model = model_avg_conf['enable']
   if var_avg_model:
     var_avg_decay = model_avg_conf['var_avg_decay']
     variable_averages = self.get_var_avg_ema(var_avg_decay, global_step)
     # Track every trainable variable; the returned op updates the shadows.
     apply_op = variable_averages.apply(tf.trainable_variables())
     # Registering in UPDATE_OPS lets get_train_op group it after the
     # gradient application.
     tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, apply_op)
     utils.log_vars('Avg Trainable Vars', tf.trainable_variables())
Example #2
0
  def get_train_op(self, loss, global_step=None):
    """Build and return the training op.

    Applies gradients for `loss`, registers the (optional) model-average
    update, then groups every op in the UPDATE_OPS collection behind the
    gradient step so they all run after each parameter update.
    """
    grad_step = self.get_apply_gradients_op(loss, global_step)

    # Register the variable-average update op (no-op unless enabled),
    # so it is picked up from UPDATE_OPS below.
    self.var_avg(global_step)

    # Every collected update op must wait for the gradient application.
    with tf.control_dependencies([grad_step]):
      pending_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      grouped_train_op = tf.group(*pending_updates)

    utils.log_vars('moving vars', tf.moving_average_variables())
    return grouped_train_op
Example #3
0
  def get_train_op(self, loss, multitask, global_step=None):
    """Build and return the training op.

    Optionally rewrites the graph for quantization-aware training, applies
    gradients for `loss`, registers the (optional) model-average update,
    and groups all UPDATE_OPS to run after the gradient step.
    """
    # Quantization-aware training: rewrite the graph before gradients.
    quant_cfg = self.config['solver']['quantization']
    if quant_cfg['enable']:
      delay = quant_cfg['quant_delay']
      logging.info('Quantization training with {} delay'.format(delay))
      tf.contrib.quantize.create_training_graph(quant_delay=delay)

    grad_step = self.get_apply_gradients_op(loss, multitask, global_step)

    # Register the variable-average update op (no-op unless enabled),
    # so it is picked up from UPDATE_OPS below.
    self.var_avg(global_step)

    # Collected update ops run only once gradients have been applied.
    with tf.control_dependencies([grad_step]):
      grouped_train_op = tf.group(
          *tf.get_collection(tf.GraphKeys.UPDATE_OPS))

    utils.log_vars('moving vars', tf.moving_average_variables())
    return grouped_train_op
Example #4
0
        def _model_fn(features, labels, mode, params):
            """Estimator model_fn: builds the spec for train/eval/infer.

            Args:
                features: dict of input tensors. Reads 'inputs', optionally
                    'soft_labels' and (for adversarial training) 'text'.
                labels: label tensor(s) passed to the loss and metrics.
                mode: one of utils.TRAIN / utils.INFER / (else) eval.
                params: unused; required by the Estimator API.

            Returns:
                A tf.estimator.EstimatorSpec for the given mode.
            """
            del params
            is_train = mode == utils.TRAIN

            # Supports both dict output and legacy single logits output.
            model_outputs = model_class(features, training=is_train)
            if isinstance(model_outputs, dict):
                logits = model_outputs['logits']
                # Everything except 'logits' is forwarded as extra outputs
                # to inference predictions.
                extra_outputs = dict(model_outputs)
                extra_outputs.pop('logits')
            else:
                logits = model_outputs
                extra_outputs = None

            # Optional alignment weights exposed by the model (e.g. from an
            # attention layer) — presumably used for visualization/hooks.
            alignment = model_class.alphas if hasattr(model_class,
                                                      'alphas') else None

            if mode == utils.INFER:
                softmax = tf.nn.softmax(logits, name='softmax_output')
                predictions = self.get_infer_predictions(
                    features,
                    softmax,
                    alpha=alignment,
                    extra_outputs=extra_outputs)
                return tf.estimator.EstimatorSpec( #pylint: disable=no-member
                    mode=mode,
                    predictions=predictions,
                    scaffold=self.get_scaffold(mode),
                    export_outputs={
                        'predictions': tf.estimator.export.PredictOutput(predictions) #pylint: disable=no-member
                    })

            # Soft labels (e.g. for distillation) are optional.
            if 'soft_labels' in features.keys():
                soft_labels = features['soft_labels']
            else:
                soft_labels = None

            loss = self.get_loss_fn()(
                labels=labels,
                logits=logits,
                soft_labels=soft_labels,
                name='x_loss',
            )

            if mode == utils.TRAIN:  #pylint: disable=no-else-return
                if self.config['solver']['adversarial']['enable']:
                    # FGSM-style adversarial training: perturb the inputs
                    # along the sign of the loss gradient.
                    x = features['inputs']  #pylint: disable=invalid-name
                    grad, = tf.gradients(loss, x)
                    # NOTE: the config key is spelled 'adv_epslion' (sic);
                    # it must stay that way to match existing configs.
                    x_adv = x + self.config['solver']['adversarial'][
                        'adv_epslion'] * tf.sign(grad)
                    # Treat the perturbation as a constant: no gradients
                    # flow back through the perturbation itself.
                    x_adv = tf.stop_gradient(x_adv)
                    # NOTE(review): reads features['text'] but stores it
                    # under key 'texts' — confirm the model expects exactly
                    # these keys.
                    features_adv = {'inputs': x_adv, 'texts': features['text']}
                    # NOTE(review): unlike the main forward pass above, this
                    # call omits training=is_train — confirm intended.
                    logits_adv = model_class(features_adv)

                    loss_adv = self.get_loss_fn()(
                        labels=labels,
                        logits=logits_adv,
                        soft_labels=soft_labels,
                        name='x_adv_loss',
                    )
                    adv_alpha = self.config['solver']['adversarial'][
                        'adv_alpha']
                    # Convex blend of clean and adversarial losses.
                    loss_all = (1 - adv_alpha) * loss + adv_alpha * loss_adv
                else:
                    loss_all = loss

                # L2 loss
                loss_all += self.l2_loss()

                train_op = self.get_train_op(loss_all)
                train_hooks = self.get_train_hooks(labels,
                                                   logits,
                                                   alpha=alignment)

                utils.log_vars('Global Vars', tf.global_variables())
                return tf.estimator.EstimatorSpec(  #pylint: disable=no-member
                    mode=mode,
                    loss=loss_all,
                    train_op=train_op,
                    training_chief_hooks=train_hooks,
                    training_hooks=None,
                    scaffold=None,
                )
            else:  # eval
                loss_all = loss
                eval_hooks, eval_metrics_ops = self.get_eval_hooks(
                    labels, logits)
                return tf.estimator.EstimatorSpec(  #pylint: disable=no-member
                    mode=mode,
                    loss=loss_all,
                    eval_metric_ops=eval_metrics_ops,
                    evaluation_hooks=eval_hooks,
                    scaffold=self.get_scaffold(mode),
                )