def var_avg(self, global_step=None):
    """Register a moving-average op over trainable variables in UPDATE_OPS.

    Reads ``config['solver']['model_average']``; does nothing unless the
    ``enable`` flag is set. The EMA apply-op is added to
    ``tf.GraphKeys.UPDATE_OPS`` so it runs with the other update ops.
    """
    avg_conf = self.config['solver']['model_average']
    # Guard clause: model averaging is opt-in via config.
    if not avg_conf['enable']:
        return
    ema = self.get_var_avg_ema(avg_conf['var_avg_decay'], global_step)
    trainables = tf.trainable_variables()
    tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, ema.apply(trainables))
    utils.log_vars('Avg Trainable Vars', trainables)
def get_train_op(self, loss, global_step=None):
    """Build the training op.

    Applies gradients for ``loss``, registers the (optional) model-average
    op, then groups all UPDATE_OPS to run after the gradient step.
    """
    apply_gradient_op = self.get_apply_gradients_op(loss, global_step)
    # Model average: must be registered before UPDATE_OPS is collected.
    self.var_avg(global_step)
    # All update ops (incl. model average) depend on the gradient step.
    with tf.control_dependencies([apply_gradient_op]):
        pending_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        train_op = tf.group(*pending_updates)
    utils.log_vars('moving vars', tf.moving_average_variables())
    return train_op
def get_train_op(self, loss, multitask, global_step=None):
    """Build the training op, with optional quantization-aware training.

    If ``config['solver']['quantization']['enable']`` is set, rewrites the
    graph for quantized training before gradients are applied. The returned
    op runs all UPDATE_OPS after the gradient step.
    """
    # Quantization-aware training (rewrites the graph in place).
    quant_conf = self.config['solver']['quantization']
    if quant_conf['enable']:
        quant_delay = quant_conf['quant_delay']
        logging.info('Quantization training with {} delay'.format(quant_delay))
        tf.contrib.quantize.create_training_graph(quant_delay=quant_delay)
    apply_gradient_op = self.get_apply_gradients_op(loss, multitask, global_step)
    # Model average: must be registered before UPDATE_OPS is collected.
    self.var_avg(global_step)
    # All update ops (incl. model average) depend on the gradient step.
    with tf.control_dependencies([apply_gradient_op]):
        pending_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        train_op = tf.group(*pending_updates)
    utils.log_vars('moving vars', tf.moving_average_variables())
    return train_op
def _model_fn(features, labels, mode, params):
    """Estimator model_fn: builds the graph for INFER / TRAIN / EVAL.

    Closure over ``self`` and ``model_class`` from the enclosing scope.
    Returns a ``tf.estimator.EstimatorSpec`` for the given ``mode``.
    """
    del params  # unused; hyperparameters come from self.config instead
    is_train = mode == utils.TRAIN
    # Supports both dict output and legacy single logits output.
    model_outputs = model_class(features, training=is_train)
    if isinstance(model_outputs, dict):
        logits = model_outputs['logits']
        # Everything except 'logits' is forwarded as extra prediction outputs.
        extra_outputs = dict(model_outputs)
        extra_outputs.pop('logits')
    else:
        logits = model_outputs
        extra_outputs = None
    # Optional attention/alignment weights exposed by some models.
    alignment = model_class.alphas if hasattr(model_class, 'alphas') else None
    if mode == utils.INFER:
        softmax = tf.nn.softmax(logits, name='softmax_output')
        predictions = self.get_infer_predictions(
            features, softmax, alpha=alignment, extra_outputs=extra_outputs)
        return tf.estimator.EstimatorSpec(  #pylint: disable=no-member
            mode=mode,
            predictions=predictions,
            scaffold=self.get_scaffold(mode),
            export_outputs={
                'predictions': tf.estimator.export.PredictOutput(predictions)  #pylint: disable=no-member
            })
    # Soft labels (e.g. for distillation) are optional in the input features.
    if 'soft_labels' in features.keys():
        soft_labels = features['soft_labels']
    else:
        soft_labels = None
    loss = self.get_loss_fn()(
        labels=labels,
        logits=logits,
        soft_labels=soft_labels,
        name='x_loss',
    )
    if mode == utils.TRAIN:  #pylint: disable=no-else-return
        if self.config['solver']['adversarial']['enable']:
            # FGSM-style adversarial training: perturb inputs along the
            # sign of the loss gradient, then mix the two losses.
            x = features['inputs']  #pylint: disable=invalid-name
            grad, = tf.gradients(loss, x)
            # NOTE(review): config key 'adv_epslion' looks misspelled
            # ('epsilon') — must match the actual config schema; verify.
            x_adv = x + self.config['solver']['adversarial'][
                'adv_epslion'] * tf.sign(grad)
            # Treat the perturbation as a constant w.r.t. gradients.
            x_adv = tf.stop_gradient(x_adv)
            # NOTE(review): key 'texts' is filled from features['text'] —
            # singular vs plural mismatch; confirm against the input schema.
            features_adv = {'inputs': x_adv, 'texts': features['text']}
            logits_adv = model_class(features_adv)
            loss_adv = self.get_loss_fn()(
                labels=labels,
                logits=logits_adv,
                soft_labels=soft_labels,
                name='x_adv_loss',
            )
            adv_alpha = self.config['solver']['adversarial'][
                'adv_alpha']
            # Convex combination of clean and adversarial losses.
            loss_all = (1 - adv_alpha) * loss + adv_alpha * loss_adv
        else:
            loss_all = loss
        # L2 loss
        loss_all += self.l2_loss()
        train_op = self.get_train_op(loss_all)
        train_hooks = self.get_train_hooks(labels, logits, alpha=alignment)
        utils.log_vars('Global Vars', tf.global_variables())
        return tf.estimator.EstimatorSpec(  #pylint: disable=no-member
            mode=mode,
            loss=loss_all,
            train_op=train_op,
            training_chief_hooks=train_hooks,
            training_hooks=None,
            scaffold=None,
        )
    else:  # eval
        # Eval uses the plain loss (no adversarial or L2 terms).
        loss_all = loss
        eval_hooks, eval_metrics_ops = self.get_eval_hooks(
            labels, logits)
        return tf.estimator.EstimatorSpec(  #pylint: disable=no-member
            mode=mode,
            loss=loss_all,
            eval_metric_ops=eval_metrics_ops,
            evaluation_hooks=eval_hooks,
            scaffold=self.get_scaffold(mode),
        )