def create_experiment(output_dir, data_dir, model_name, train_steps,
                      eval_steps):
  """Assemble a `tf.contrib.learn.Experiment` from flags and arguments.

  Hyperparameters are created from `FLAGS.hparams_set`; the evaluation
  metrics are built for the problems named in `FLAGS.problems`
  (a "-"-separated string) paired with `hparams.problem_instances`.

  Args:
    output_dir: forwarded to `create_experiment_components`.
    data_dir: forwarded to `create_hparams` and
      `create_experiment_components`.
    model_name: forwarded to `create_experiment_components`.
    train_steps: passed to the Experiment as `train_steps`.
    eval_steps: passed to the Experiment as `eval_steps`.

  Returns:
    The constructed `tf.contrib.learn.Experiment`.

  Raises:
    ValueError: if `FLAGS.autotune` exists and is set, and `FLAGS.objective`
      is not among the evaluation metrics.
  """
  hparams = create_hparams(FLAGS.hparams_set, data_dir)
  estimator, input_fns = create_experiment_components(
      hparams=hparams,
      output_dir=output_dir,
      data_dir=data_dir,
      model_name=model_name)

  problem_names = FLAGS.problems.split("-")
  eval_metrics = metrics.create_evaluation_metrics(
      zip(problem_names, hparams.problem_instances))

  autotuning = hasattr(FLAGS, "autotune") and FLAGS.autotune
  if autotuning and FLAGS.objective not in eval_metrics:
    raise ValueError("Tuning objective %s not among evaluation metrics %s" %
                     (FLAGS.objective, eval_metrics.keys()))

  # The same debugger hook instance is shared by training and evaluation,
  # but each side gets its own list object.
  debug_hooks = [debug.LocalCLIDebugHook()] if FLAGS.tfdbg else []
  train_monitors = list(debug_hooks)
  eval_hooks = list(debug_hooks)

  mode_keys = tf.contrib.learn.ModeKeys
  return tf.contrib.learn.Experiment(
      estimator=estimator,
      train_input_fn=input_fns[mode_keys.TRAIN],
      eval_input_fn=input_fns[mode_keys.EVAL],
      eval_metrics=eval_metrics,
      train_steps=train_steps,
      eval_steps=eval_steps,
      min_eval_frequency=FLAGS.local_eval_frequency,
      train_monitors=train_monitors,
      eval_hooks=eval_hooks)
def estimator_spec_eval(self, features, logits, labels, loss, problem,
                        hparams, use_tpu=False):
  """Build the EVAL-mode estimator spec.

  Args:
    features: passed as the second argument to every (non-TPU) metric fn.
    logits: model output; used for metrics and, off TPU, as predictions.
    labels: only forwarded to the TPU metrics function.
    loss: loss tensor placed in the returned spec.
    problem: problem whose evaluation metrics are created.
    hparams: hyperparameters forwarded to the metric factories.
    use_tpu: when true, return a `TPUEstimatorSpec` instead.

  Returns:
    A `tf.contrib.tpu.TPUEstimatorSpec` if `use_tpu`, otherwise a
    `tf.estimator.EstimatorSpec` with `{"predictions": logits}`.
  """
  eval_mode = tf.estimator.ModeKeys.EVAL

  if use_tpu:
    tpu_metrics_fn = _create_tpu_eval_metrics_fn(problem, hparams)
    _remove_summaries()
    return tf.contrib.tpu.TPUEstimatorSpec(
        eval_mode,
        eval_metrics=(tpu_metrics_fn, [logits, labels]),
        loss=loss)

  metric_fns = metrics.create_evaluation_metrics([problem], hparams)
  metric_ops = {
      name: fn(logits, features) for name, fn in six.iteritems(metric_fns)
  }
  return tf.estimator.EstimatorSpec(
      eval_mode,
      predictions={"predictions": logits},
      eval_metric_ops=metric_ops,
      loss=loss)
def estimator_spec_eval(self, features, logits, labels, loss):
  """Build the EVAL-mode estimator spec for the first problem instance.

  The problem is taken from `self.hparams.problem_instances[0]`; whether a
  TPU spec is produced is decided by `common_layers.is_on_tpu()`.

  Args:
    features: passed as the second argument to every (non-TPU) metric fn.
    logits: model output; used for metrics and, off TPU, as predictions.
    labels: only forwarded to the TPU metrics function.
    loss: loss tensor placed in the returned spec.

  Returns:
    A `tf.contrib.tpu.TPUEstimatorSpec` on TPU, otherwise a
    `tf.estimator.EstimatorSpec` with `{"predictions": logits}`.

  Raises:
    NotImplementedError: if `self.hparams` has no `problem_instances`.
  """
  hparams = self.hparams
  if not hasattr(hparams, "problem_instances"):
    raise NotImplementedError(_no_problem_err("estimator_spec_eval"))
  problem = hparams.problem_instances[0]

  eval_mode = tf.estimator.ModeKeys.EVAL

  if common_layers.is_on_tpu():
    tpu_metrics_fn = _create_tpu_eval_metrics_fn(problem, hparams)
    _remove_summaries()
    return tf.contrib.tpu.TPUEstimatorSpec(
        eval_mode,
        eval_metrics=(tpu_metrics_fn, [logits, labels]),
        loss=loss)

  metric_fns = metrics.create_evaluation_metrics([problem], hparams)
  metric_ops = {
      name: fn(logits, features) for name, fn in six.iteritems(metric_fns)
  }
  return tf.estimator.EstimatorSpec(
      eval_mode,
      predictions={"predictions": logits},
      eval_metric_ops=metric_ops,
      loss=loss)
def estimator_spec_eval(self, features, logits, labels, loss, losses_dict):
  """Build the EVAL-mode estimator spec, accepting dict logits on TPU.

  Args:
    features: passed as the second argument to every (non-TPU) metric fn.
    logits: model output. On TPU a dict is updated in place with a
      "labels" entry and handed to the metrics function as-is.
    labels: forwarded to the TPU metrics function.
    loss: loss tensor placed in the returned spec.
    losses_dict: not used by this method.

  Returns:
    A `tf.contrib.tpu.TPUEstimatorSpec` on TPU, otherwise a
    `tf.estimator.EstimatorSpec` with `{"predictions": logits}`.

  Raises:
    NotImplementedError: if `self.hparams` has no `problem_instances`.
  """
  hparams = self.hparams
  if not hasattr(hparams, "problem_instances"):
    raise NotImplementedError(_no_problem_err("estimator_spec_eval"))
  problem = hparams.problem_instances[0]

  eval_mode = tf.estimator.ModeKeys.EVAL

  if common_layers.is_on_tpu():
    tpu_metrics_fn = _create_tpu_eval_metrics_fn(problem, hparams)
    _remove_summaries()
    if isinstance(logits, dict):
      # For TPU, logits dict will be passed as keyword arguments to
      # eval_metrics_fn. Here we add the labels to those arguments.
      logits.update({"labels": labels})
      metric_inputs = logits
    else:
      metric_inputs = [logits, labels]
    return tf.contrib.tpu.TPUEstimatorSpec(
        eval_mode,
        eval_metrics=(tpu_metrics_fn, metric_inputs),
        loss=loss)

  metric_fns = metrics.create_evaluation_metrics([problem], hparams)
  metric_ops = {
      name: fn(logits, features) for name, fn in six.iteritems(metric_fns)
  }
  return tf.estimator.EstimatorSpec(
      eval_mode,
      predictions={"predictions": logits},
      eval_metric_ops=metric_ops,
      loss=loss)
def estimator_spec_eval(self, features, logits, labels, loss, losses_dict):
  """Build the EVAL-mode estimator spec, accepting dict logits everywhere.

  Args:
    features: passed as the second argument to every (non-TPU) metric fn.
    logits: a single output or a dict of outputs. On TPU a dict is updated
      in place with a "labels" entry; off TPU a dict is indexed by the
      middle component of each metric name and returned as predictions.
    labels: forwarded to the TPU metrics function.
    loss: loss tensor placed in the returned spec.
    losses_dict: not used by this method.

  Returns:
    A `tf.contrib.tpu.TPUEstimatorSpec` on TPU, otherwise a
    `tf.estimator.EstimatorSpec`.

  Raises:
    NotImplementedError: if `self.hparams` has no `problem_instances`.
  """
  hparams = self.hparams
  if not hasattr(hparams, "problem_instances"):
    raise NotImplementedError(_no_problem_err("estimator_spec_eval"))
  problem = hparams.problem_instances[0]

  eval_mode = tf.estimator.ModeKeys.EVAL
  logits_is_dict = isinstance(logits, dict)

  if common_layers.is_on_tpu():
    _remove_summaries()
    tpu_metrics_fn = _create_tpu_eval_metrics_fn(problem, hparams)
    if logits_is_dict:
      # For TPU, logits dict will be passed as keyword arguments to
      # eval_metrics_fn. Here we add the labels to those arguments.
      logits.update({"labels": labels})
      metric_inputs = logits
    else:
      metric_inputs = [logits, labels]
    return tf.contrib.tpu.TPUEstimatorSpec(
        eval_mode,
        eval_metrics=(tpu_metrics_fn, metric_inputs),
        loss=loss)

  metric_fns = metrics.create_evaluation_metrics([problem], hparams)
  metric_ops = {}
  for name, fn in six.iteritems(metric_fns):
    # the key is located in the center of metric_name: "metrics-%s/%s/%s"
    metric_logits = logits[name.split("/")[1]] if logits_is_dict else logits
    metric_ops[name] = fn(metric_logits, features)

  predictions = logits if logits_is_dict else {"predictions": logits}
  return tf.estimator.EstimatorSpec(
      eval_mode,
      predictions=predictions,
      eval_metric_ops=metric_ops,
      loss=loss)
def estimator_spec_eval(self, features, logits, labels, loss, restore_hook,
                        use_tpu):
  """Build the EVAL-mode estimator spec, with multiproblem support.

  Rank-3 logits get two extra axes inserted at positions 2 and 3 before
  metrics are computed. If the problem has a `task_list` attribute, metrics
  are created for those tasks instead of for the problem itself.

  Args:
    features: feature dict; off TPU it is passed to each metric fn together
      with `features["targets"]`.
    logits: model output tensor.
    labels: labels tensor, forwarded to the TPU metric function.
    loss: loss tensor placed in the returned spec.
    restore_hook: hook installed as the only evaluation hook.
    use_tpu: when true, return a `tpu_estimator.TPUEstimatorSpec`.

  Returns:
    A `tpu_estimator.TPUEstimatorSpec` if `use_tpu`, otherwise a
    `tf.estimator.EstimatorSpec`.
  """
  hparams = self.hparams
  problem = hparams.problem

  if logits.get_shape().ndims == 3:
    logits = tf.expand_dims(tf.expand_dims(logits, 2), 3)

  # Support for multiproblem
  task_list = problem.task_list if hasattr(problem, "task_list") else [problem]
  eval_metrics_fns = metrics.create_evaluation_metrics(task_list, hparams)

  eval_mode = tf.estimator.ModeKeys.EVAL

  if use_tpu:

    def tpu_metric_fn(tf_logits, labels):
      """Compute every non-blacklisted metric under a CPU device scope."""
      with tf.device("cpu:0"), mtf.utils.outside_all_rewrites():
        return {
            name: fn(tf_logits, None, tf.identity(labels))
            for name, fn in six.iteritems(eval_metrics_fns)
            if name.split("/")[-1] not in t2t_model.TPU_METRIC_BLACKLIST
        }

    return tpu_estimator.TPUEstimatorSpec(
        eval_mode,
        evaluation_hooks=[restore_hook],
        loss=loss,
        eval_metrics=(tpu_metric_fn, [logits, labels]))

  predictions = {"predictions": logits}
  metric_ops = {
      name: fn(logits, features, features["targets"])
      for name, fn in six.iteritems(eval_metrics_fns)
  }
  return tf.estimator.EstimatorSpec(
      eval_mode,
      predictions=predictions,
      eval_metric_ops=metric_ops,
      evaluation_hooks=[restore_hook],
      loss=loss)
def estimator_spec_eval(self, features, logits, labels, loss, losses_dict):
  """Build the EVAL-mode estimator spec using the problem's own metrics.

  `problem.eval_metrics()` may return either a list, in which case the
  standard metric functions from `metrics.create_evaluation_metrics` are
  applied to the logits, or a mapping of metric key to function, in which
  case each function is called on the predictions and must return a pair.
  A pair whose second element is a `tf.Tensor` is reduced with
  `tf.metrics.mean`; any other pair is used directly as the metric op.

  Args:
    features: feature dict forwarded to the metric functions.
    logits: structure of logits tensors.
    labels: labels forwarded to problem-defined metric functions.
    loss: loss tensor placed in the returned spec.
    losses_dict: ignored.

  Returns:
    A `tf.estimator.EstimatorSpec` for EVAL mode.

  Raises:
    NotImplementedError: if `common_layers.is_xla_compiled()` is true.
  """
  del losses_dict  # Unused.
  hparams = self.hparams
  problem = hparams.problem

  if common_layers.is_xla_compiled():
    raise NotImplementedError("TPU usage is not supported")

  def _argmax_ids(tensor):
    """Argmax over the last axis, then drop axes 2 and 3."""
    return tf.squeeze(tf.argmax(tensor, axis=-1), axis=[2, 3])

  outputs = tf.contrib.framework.nest.map_structure(_argmax_ids, logits)

  predictions = (
      problem.compute_predictions(outputs, features, hparams, decode=False)
      if hasattr(problem, "compute_predictions") else outputs)

  problem_metrics = problem.eval_metrics()
  if isinstance(problem_metrics, list):
    # Replace each metric function by its result, in place.
    metric_ops = metrics.create_evaluation_metrics([problem], hparams)
    for name in list(metric_ops):
      metric_ops[name] = metric_ops[name](logits, features,
                                          features["targets"])
  else:
    metric_ops = {}
    for key, fn in problem_metrics.items():
      name = "metrics-%s/%s" % (problem.name, key)
      first, second = fn(predictions, labels, features)
      if isinstance(second, tf.Tensor):
        # (scores, weights) pair.
        metric_ops[name] = tf.metrics.mean(first, second)
      else:
        # (value, update_op) pair.
        metric_ops[name] = (first, second)

  return tf.estimator.EstimatorSpec(
      tf.estimator.ModeKeys.EVAL, eval_metric_ops=metric_ops, loss=loss)
def _gpu_estimator_spec_eval(self, features, logits, labels, loss,
                             losses_dict):
  """Build the EVAL-mode estimator spec for the non-TPU path.

  Metrics whose name contains "rouge" or "bleu" are left out.

  Args:
    features: feature dict; each kept metric fn receives it together with
      `features["targets"]`.
    logits: model output; also returned as the "predictions" entry.
    labels: not used by this method.
    loss: loss tensor placed in the returned spec.
    losses_dict: not used by this method.

  Returns:
    A `tf.estimator.EstimatorSpec` for EVAL mode.

  Raises:
    NotImplementedError: if `self.hparams` has no `problem` attribute.
  """
  hparams = self.hparams
  if not hasattr(hparams, "problem"):
    raise NotImplementedError(
        "hparams is missing attribute `problem`. NasSeq2Seq must "
        "be used with a problem.")

  # TPU is not supported.
  metric_fns = metrics.create_evaluation_metrics([hparams.problem], hparams)
  skipped_markers = ("rouge", "bleu")
  metric_ops = {
      name: fn(logits, features, features["targets"])
      for name, fn in six.iteritems(metric_fns)
      if not any(marker in name for marker in skipped_markers)
  }

  return tf.estimator.EstimatorSpec(
      tf.estimator.ModeKeys.EVAL,
      predictions={"predictions": logits},
      eval_metric_ops=metric_ops,
      loss=loss)
def estimator_spec_eval(self, features, logits, labels, loss, restore_hook,
                        use_tpu):
  """Build the EVAL-mode estimator spec for `self.hparams.problem`.

  Rank-3 logits get two extra axes inserted at positions 2 and 3 before
  metrics are computed.

  Args:
    features: feature dict; off TPU it is passed to each metric fn together
      with `features["targets"]`.
    logits: model output tensor.
    labels: labels tensor, forwarded to the TPU metric function.
    loss: loss tensor placed in the returned spec.
    restore_hook: hook installed as the only evaluation hook.
    use_tpu: when true, return a `tpu_estimator.TPUEstimatorSpec`.

  Returns:
    A `tpu_estimator.TPUEstimatorSpec` if `use_tpu`, otherwise a
    `tf.estimator.EstimatorSpec`.
  """
  hparams = self.hparams
  problem = hparams.problem

  if logits.get_shape().ndims == 3:
    logits = tf.expand_dims(tf.expand_dims(logits, 2), 3)

  eval_metrics_fns = metrics.create_evaluation_metrics([problem], hparams)
  eval_mode = tf.estimator.ModeKeys.EVAL

  if use_tpu:

    def tpu_metric_fn(tf_logits, labels):
      """Compute every non-blacklisted metric under a CPU device scope."""
      with tf.device("cpu:0"), mtf.utils.outside_all_rewrites():
        return {
            name: fn(tf_logits, None, tf.identity(labels))
            for name, fn in six.iteritems(eval_metrics_fns)
            if name.split("/")[-1] not in t2t_model.TPU_METRIC_BLACKLIST
        }

    return tpu_estimator.TPUEstimatorSpec(
        eval_mode,
        evaluation_hooks=[restore_hook],
        loss=loss,
        eval_metrics=(tpu_metric_fn, [logits, labels]))

  predictions = {"predictions": logits}
  metric_ops = {
      name: fn(logits, features, features["targets"])
      for name, fn in six.iteritems(eval_metrics_fns)
  }
  return tf.estimator.EstimatorSpec(
      eval_mode,
      predictions=predictions,
      eval_metric_ops=metric_ops,
      evaluation_hooks=[restore_hook],
      loss=loss)
def model_fn(model,
             features,
             mode,
             hparams,
             problem_names,
             train_steps=100000,
             worker_id=0,
             worker_replicas=1,
             eval_run_autoregressive=False,
             decode_hparams=None):
  """Builds the model for all modes.

  * TRAIN: Constructs loss and train_op
  * EVAL: Constructs the loss and eval metrics
  * PREDICT: Constructs the predictions

  One sub-graph per problem is built through `input_fn_builder.cond_on_index`
  and selected by `features["problem_choice"]`. Moving averages of every loss
  and a per-problem step counter are kept in non-trainable variables.

  Args:
    model: str, name of model.
    features: dict<feature name, Tensor>. Expected to have keys
      {inputs, targets, problem_choice}.
    mode: tf.estimator.ModeKeys.
    hparams: model HParams.
    problem_names: list of str, names of the problems.
    train_steps: int, total number of training steps. Used to compute learning
      rate decay.
    worker_id: int, id of this worker.
    worker_replicas: int, number of workers.
    eval_run_autoregressive: bool, whether to run evaluation autoregressively.
    decode_hparams: HParams for decode settings. Used when mode == PREDICT.

  Returns:
    tf.estimator.EstimatorSpec
  """
  assert len(problem_names) == len(hparams.problem_instances)
  decode_hp = decode_hparams

  # TODO(rsepassi): This still depends on FLAGS. Rm eventually.
  dp = devices.data_parallelism()

  tf.get_variable_scope().set_initializer(_get_variable_initializer(hparams))
  is_training = mode == tf.estimator.ModeKeys.TRAIN

  # Add input statistics for incoming features.
  with tf.name_scope("input_stats"):
    for (k, v) in six.iteritems(features):
      if isinstance(v, tf.Tensor) and v.get_shape().ndims > 1:
        tf.summary.scalar("%s_batch" % k, tf.shape(v)[0] // dp.n)
        tf.summary.scalar("%s_length" % k, tf.shape(v)[1])
        # Entries equal to 0 are counted as padding.
        nonpadding = tf.to_float(tf.not_equal(v, 0))
        nonpadding_tokens = tf.reduce_sum(nonpadding)
        if k == "targets":
          # Kept for the small-batch loss scaling in TRAIN mode below. It is
          # only bound when "targets" is a Tensor of rank > 1.
          targets_nonpadding_tokens = nonpadding_tokens
        tf.summary.scalar("%s_nonpadding_tokens" % k, nonpadding_tokens)
        tf.summary.scalar("%s_nonpadding_fraction" % k,
                          tf.reduce_mean(nonpadding))

  # Get multi-problem logits and loss based on features["problem_choice"].
  # Filled in by nth_model as a side effect; read again for the training
  # summaries further down.
  loss_variable_names = []

  def nth_model(n):
    """Build the model for the n-th problem, plus some added variables.

    In PREDICT mode this returns whatever `infer` returns. Otherwise it
    returns `[total_loss, logits]` where the logits are the sharded logits
    concatenated along axis 0.
    """
    model_class = registry.model(model)(
        hparams,
        mode,
        hparams.problems[n],
        n,
        dp,
        devices.ps_devices(all_workers=True))
    if mode == tf.estimator.ModeKeys.PREDICT:
      return model_class.infer(
          features,
          beam_size=decode_hp.beam_size,
          top_beams=(decode_hp.beam_size if decode_hp.return_beams else 1),
          last_position_only=decode_hp.use_last_position_only,
          alpha=decode_hp.alpha,
          decode_length=decode_hp.extra_length)
    # In distributed mode, we build graph for problem=0 and problem=worker_id.
    skipping_is_on = hparams.problem_choice == "distributed" and is_training
    problem_worker_id = worker_id % len(hparams.problems)
    skip_this_one = n != 0 and n % worker_replicas != problem_worker_id
    # On worker 0 also build graph for problems <= 1.
    # TODO(lukaszkaiser): why is this hack needed for variables init? Repair.
    skip_this_one = skip_this_one and (worker_id != 0 or n > 1)
    if eval_run_autoregressive and mode == tf.estimator.ModeKeys.EVAL:
      sharded_logits, losses_dict = model_class.eval_autoregressive(features)
    else:
      sharded_logits, losses_dict = model_class.model_fn(
          features, skip=(skipping_is_on and skip_this_one))
    with tf.variable_scope("losses_avg"):
      total_loss, ops = 0.0, []
      for loss_key, loss_value in six.iteritems(losses_dict):
        loss_name = "problem_%d/%s_loss" % (n, loss_key)
        # Moving average (0.9 old / 0.1 new), initialized to 100.0.
        loss_moving_avg = tf.get_variable(
            loss_name, initializer=100.0, trainable=False)
        loss_variable_names.append(loss_name)
        ops.append(
            loss_moving_avg.assign(loss_moving_avg * 0.9 + loss_value * 0.1))
        total_loss += loss_value
      try:
        # Total loss avg might be reused or not, we try both.
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
          # Total loss was already constructed on input.
          loss_moving_avg = tf.get_variable("problem_%d/total_loss" % n)
      except ValueError:
        loss_moving_avg = tf.get_variable(
            "problem_%d/total_loss" % n, initializer=100.0, trainable=False)
      ops.append(
          loss_moving_avg.assign(loss_moving_avg * 0.9 + total_loss * 0.1))
    with tf.variable_scope("train_stats"):  # Count steps for this problem.
      problem_steps = tf.get_variable(
          "problem_%d_steps" % n, initializer=0, trainable=False)
      ops.append(problem_steps.assign_add(1))
    with tf.control_dependencies(ops):  # Make sure the ops run.
      # Ensure the loss is a scalar here.
      total_loss = tf.reshape(total_loss, [], name="total_loss_control_id")
    return [total_loss, tf.concat(sharded_logits, 0)]

  model_output = input_fn_builder.cond_on_index(
      nth_model,
      index_tensor=features["problem_choice"],
      max_idx=len(hparams.problems) - 1)

  if mode == tf.estimator.ModeKeys.PREDICT:
    # If beam searching, model_output will be a dict with keys "outputs" and
    # "scores".
    if isinstance(model_output, dict):
      outputs = model_output["outputs"]
      scores = model_output["scores"]
    else:
      outputs = model_output
      scores = None

    # Repeat problem_choice once per row of features["inputs"]; note that
    # "inputs" is indexed unconditionally here.
    batched_problem_choice = (features["problem_choice"] * tf.ones(
        (tf.shape(features["inputs"])[0],), dtype=tf.int32))
    predictions = {
        "outputs": outputs,
        "scores": scores,
        "inputs": features.get("inputs", None),
        "targets": features.get("infer_targets", None),
        "problem_choice": batched_problem_choice,
    }
    # presumably removes the entries whose value is None -- helper is not
    # visible here; the `"scores" in predictions` test below relies on it.
    _del_dict_nones(predictions)

    export_out = {"outputs": predictions["outputs"]}
    if "scores" in predictions:
      export_out["scores"] = predictions["scores"]

    return tf.estimator.EstimatorSpec(
        mode,
        predictions=predictions,
        export_outputs={
            "output": tf.estimator.export.PredictOutput(export_out)
        })

  total_loss, logits = model_output

  if mode == tf.estimator.ModeKeys.EVAL:
    eval_metrics_fns = metrics.create_evaluation_metrics(
        zip(problem_names, hparams.problem_instances), hparams)

    eval_metrics = {}
    for metric_name, metric_fn in six.iteritems(eval_metrics_fns):
      eval_metrics[metric_name] = metric_fn(logits, features)

    return tf.estimator.EstimatorSpec(
        mode,
        predictions={"predictions": logits},
        eval_metric_ops=eval_metrics,
        loss=total_loss)

  assert mode == tf.estimator.ModeKeys.TRAIN

  # Set learning rate
  learning_rate = hparams.learning_rate * learning_rate_decay(
      hparams, num_worker_replicas=worker_replicas, num_train_steps=train_steps)
  learning_rate /= math.sqrt(float(worker_replicas))

  # Get global step
  global_step = tf.train.get_or_create_global_step()

  # Some training statistics.
  with tf.name_scope("training_stats"):
    tf.summary.scalar("learning_rate", learning_rate)
    for n in xrange(len(hparams.problems)):
      names_and_vars = []
      # Re-open the scopes used in nth_model with reuse=True to fetch the
      # moving-average variables created there.
      with tf.variable_scope("losses_avg", reuse=True):
        total_loss_var = tf.get_variable("problem_%d/total_loss" % n)
        names_and_vars.append(("total_loss", total_loss_var))
      with tf.variable_scope("losses_avg", reuse=True):
        for loss_name in loss_variable_names:
          if loss_name.startswith("problem_%d/" % n):
            loss_var = tf.get_variable(loss_name)
            # Strip the leading "problem_<n>/" part.
            loss_suffix = loss_name[loss_name.index("/") + 1:]
            names_and_vars.append((loss_suffix, loss_var))
      for (loss_name, loss_var) in names_and_vars:
        tf.summary.scalar("loss_avg_%d/%s" % (n, loss_name), loss_var)
      with tf.variable_scope("train_stats", reuse=True):
        nth_steps = tf.get_variable("problem_%d_steps" % n, dtype=tf.int32)
      tf.summary.scalar("problem_%d_frequency" % n,
                        tf.to_float(nth_steps) /
                        (tf.to_float(global_step) + 1.0))

  # Add weight decay and noise.
  total_size, weight_decay_loss = 0, 0.0
  all_weights = {v.name: v for v in tf.trainable_variables()}
  # Iterate in sorted name order so the ops are added in a stable order.
  for v_name in sorted(list(all_weights)):
    v = all_weights[v_name]
    v_size = int(np.prod(np.array(v.shape.as_list())))
    total_size += v_size
    if hparams.weight_decay > 0.0 and len(v.shape.as_list()) > 1:
      # Add weight regularization if set and the weight is not a bias (dim>1).
      with tf.device(v._ref().device):  # pylint: disable=protected-access
        v_loss = tf.nn.l2_loss(v) / v_size
      weight_decay_loss += v_loss
    is_body = len(v_name) > 5 and v_name[:5] == "body/"
    if hparams.weight_noise > 0.0 and is_body:
      # Add weight noise if set in hparams.
      with tf.device(v._ref().device):  # pylint: disable=protected-access
        scale = learning_rate * 0.001
        noise = tf.truncated_normal(v.shape) * hparams.weight_noise * scale
        noise_op = v.assign_add(noise)
      # Tie the noise update to the loss so it runs whenever the loss does.
      with tf.control_dependencies([noise_op]):
        total_loss = tf.identity(total_loss)
  if hparams.weight_decay > 0.0:
    total_loss += weight_decay_loss * hparams.weight_decay

  # The new data reader occasionally emits very small batches, which
  # cause the examples in those batches to be grossly overweighted.
  # We decrease the loss proportionally to the ratio of the size of this
  # batch to the size of the largest training batch ever.
  # TODO(noam): to be more sophisticated, we could keep separate
  # maxima based on problem choice.
  max_nonpadding_var = tf.get_variable(
      "max_nonpadding",
      shape=[],
      initializer=tf.ones_initializer(),
      trainable=False)
  max_nonpadding = tf.maximum(max_nonpadding_var, targets_nonpadding_tokens)
  with tf.control_dependencies([tf.assign(max_nonpadding_var, max_nonpadding)]):
    small_batch_multiplier = targets_nonpadding_tokens / max_nonpadding
  tf.summary.scalar("small_batch_multiplier", small_batch_multiplier)
  total_loss *= small_batch_multiplier

  # Log variable sizes
  _log_variable_sizes(tf.trainable_variables(), "Trainable Variables")
  diet_vars = [
      v for v in tf.global_variables() if v.dtype == dtypes.float16_ref
  ]
  _log_variable_sizes(diet_vars, "Diet Variables")

  # Optimize
  total_loss = tf.identity(total_loss, name="total_loss")
  opt = ConditionalOptimizer(hparams.optimizer, learning_rate, hparams)
  opt_summaries = ["learning_rate", "loss"]
  if hparams.summarize_grads:
    opt_summaries.extend(["gradients", "gradient_norm"])
  tf.logging.info("Computing gradients for global model_fn.")
  # `x or None` turns falsy hparams values (e.g. 0) into "disabled".
  train_op = tf.contrib.layers.optimize_loss(
      name="training",
      loss=total_loss,
      global_step=global_step,
      learning_rate=learning_rate,
      clip_gradients=hparams.clip_grad_norm or None,
      gradient_noise_scale=hparams.grad_noise_scale or None,
      optimizer=opt,
      summaries=opt_summaries,
      colocate_gradients_with_ops=True)

  # Remove summaries that will fail to run because they are in conditionals.
  # TODO(cwhipkey): Test with this code removed, later in 2017.
  summaries = tf.get_collection_ref(tf.GraphKeys.SUMMARIES)
  # Walk backwards so deleting an entry does not shift unvisited indices.
  for i in reversed(range(len(summaries))):
    if summaries[i].name.startswith("cond_"):
      del summaries[i]

  tf.logging.info("Global model_fn finished.")
  return tf.estimator.EstimatorSpec(
      mode,
      predictions={"problem_choice": features["problem_choice"]},
      loss=total_loss,
      train_op=train_op)
def model_fn(features, labels, mode, params):
  """Creates the prediction, loss, and train ops.

  One sub-graph per problem is built through `input_fn_builder.cond_on_index`
  and selected by `features["problem_choice"]`. Several settings are read
  from the global `FLAGS` object rather than from the arguments.

  NOTE(review): `model` (used in `nth_model`) is not a parameter of this
  function; it must come from an enclosing or module scope that is not
  visible here -- confirm.

  Args:
    features: A dictionary of tensors keyed by the feature name. When `labels`
      is not None it is stored into this dict under "targets" (in place).
    labels: A tensor representing the labels.
    mode: The execution mode, as defined in tf.estimator.ModeKeys.
    params: model HParams.

  Returns:
    An EstimatorSpec.
  """
  hparams = params
  # Deep-copy the model hparams between modes to eliminate
  # side-effects caused by abuse of the linked problem_hparams
  # objects which are used to share modality objects between
  # problems.  We do not want to share the modality objects between
  # modes, since the modality objects may decide to do something
  # mode-specific.  A better fix would be to stop abusing the
  # hparams in this way and instead use a separate dictionary to
  # share the modality objects between problems.  This dictionary
  # could be created once per mode and passed to the constructor of
  # t2t_model.
  my_hp = copy.deepcopy(hparams)

  def initializer():
    """Return the variable initializer selected by `hparams.initializer`."""
    if hparams.initializer == "orthogonal":
      return tf.orthogonal_initializer(gain=hparams.initializer_gain)
    elif hparams.initializer == "uniform":
      max_val = 0.1 * hparams.initializer_gain
      return tf.random_uniform_initializer(-max_val, max_val)
    elif hparams.initializer == "normal_unit_scaling":
      return init_ops.variance_scaling_initializer(
          hparams.initializer_gain, mode="fan_avg", distribution="normal")
    elif hparams.initializer == "uniform_unit_scaling":
      return init_ops.variance_scaling_initializer(
          hparams.initializer_gain, mode="fan_avg", distribution="uniform")
    else:
      raise ValueError("Unrecognized initializer: %s" % hparams.initializer)

  def learning_rate_decay():
    """Inverse-decay learning rate until warmup_steps, then decay.

    The "noam", "exp100k", "cosine" and "cyclelinear10x" schemes return
    their own schedule directly and skip the warmup/decay switch at the end.
    """
    warmup_steps = tf.to_float(
        hparams.learning_rate_warmup_steps * FLAGS.worker_replicas)
    step = tf.to_float(tf.contrib.framework.get_global_step())
    if hparams.learning_rate_decay_scheme == "noam":
      return 5000.0 * hparams.hidden_size**-0.5 * tf.minimum(
          (step + 1) * warmup_steps**-1.5, (step + 1)**-0.5)
    elif hparams.learning_rate_decay_scheme == "exp100k":
      return 0.94**(step // 100000)
    elif hparams.learning_rate_decay_scheme == "cosine":
      cycle_steps = hparams.learning_rate_cosine_cycle_steps
      return 0.5 * (1 + tf.cos(np.pi * (step % cycle_steps) / cycle_steps))
    elif hparams.learning_rate_decay_scheme == "cyclelinear10x":
      # Cycle the rate linearly by 10x every warmup_steps, up and down.
      # Note: this uses the raw hparams value, not the worker-scaled
      # `warmup_steps` tensor above.
      cycle_steps = hparams.learning_rate_warmup_steps
      cycle_position = step % (2 * cycle_steps)
      cycle_position = tf.to_float(  # Normalize to the interval [-1, 1].
          cycle_position - cycle_steps) / float(cycle_steps)
      cycle_position = 1.0 - tf.abs(cycle_position)  # 0 to 1 and back to 0.
      return (cycle_position + 0.1) * 3.0  # 10x difference each cycle (0.3-3).

    # Warmup factor: 0.01 at step 0, growing to 1.0 at step == warmup_steps.
    inv_base = tf.exp(tf.log(0.01) / warmup_steps)
    inv_decay = inv_base**(warmup_steps - step)
    if hparams.learning_rate_decay_scheme == "sqrt":
      decay = _sqrt_decay(step - warmup_steps)
    elif hparams.learning_rate_decay_scheme == "exp10k":
      decay = _exp_decay_after(step - warmup_steps, 0.9995,
                               FLAGS.train_steps - warmup_steps - 10000)
    elif hparams.learning_rate_decay_scheme == "exp50k":
      decay = _exp_decay_after(step - warmup_steps, 0.99995,
                               FLAGS.train_steps - warmup_steps - 50000)
    elif hparams.learning_rate_decay_scheme == "exp500k":
      decay = _exp_decay_after(step - warmup_steps, 0.9999955,
                               FLAGS.train_steps - warmup_steps - 500000)
    elif hparams.learning_rate_decay_scheme == "none":
      decay = tf.constant(1.0)
    else:
      raise ValueError("Unrecognized learning rate decay scheme: %s" %
                       hparams.learning_rate_decay_scheme)
    # NOTE(review): "warump" in the op name below is a misspelling of
    # "warmup"; it is part of the graph's op names, so it is left untouched.
    return tf.cond(
        step < warmup_steps,
        lambda: inv_decay,
        lambda: decay,
        name="learning_rate_decay_warump_cond")

  if labels is not None:
    features["targets"] = labels

  dp = devices.data_parallelism()

  tf.get_variable_scope().set_initializer(initializer())
  is_training = mode == tf.estimator.ModeKeys.TRAIN

  # Add input statistics for incoming features.
  with tf.name_scope("input_stats"):
    for (k, v) in six.iteritems(features):
      if isinstance(v, tf.Tensor) and v.get_shape().ndims > 1:
        tf.summary.scalar("%s_batch" % k, tf.shape(v)[0] // dp.n)
        tf.summary.scalar("%s_length" % k, tf.shape(v)[1])
        # Entries equal to 0 are counted as padding.
        nonpadding = tf.to_float(tf.not_equal(v, 0))
        nonpadding_tokens = tf.reduce_sum(nonpadding)
        if k == "targets":
          # Only bound when "targets" is a Tensor of rank > 1; the training
          # branch below reads it unconditionally.
          targets_nonpadding_tokens = nonpadding_tokens
        tf.summary.scalar("%s_nonpadding_tokens" % k, nonpadding_tokens)
        tf.summary.scalar("%s_nonpadding_fraction" % k,
                          tf.reduce_mean(nonpadding))
    # NOTE(review): nesting of this branch under the "input_stats" name scope
    # is inferred from the flattened source -- confirm against the original.
    if is_training:
      # The new data reader occasionally emits very small batches, which
      # cause the examples in those batches to be grossly overweighted.
      # We decrease the loss proportionally to the ratio of the size of this
      # batch to the size of the largest training batch ever.
      # TODO(noam): to be more sophisticated, we could keep separate
      # maxima based on problem choice.
      max_nonpadding_var = tf.get_variable(
          "max_nonpadding",
          shape=[],
          initializer=tf.ones_initializer(),
          trainable=False)
      max_nonpadding = tf.maximum(max_nonpadding_var,
                                  targets_nonpadding_tokens)
      with tf.control_dependencies(
          [tf.assign(max_nonpadding_var, max_nonpadding)]):
        small_batch_multiplier = targets_nonpadding_tokens / max_nonpadding
      tf.summary.scalar("small_batch_multiplier", small_batch_multiplier)

  # Get multi-problem logits and loss based on features["problem_choice"].
  # Filled in by nth_model as a side effect; read again for the training
  # summaries further down.
  loss_variable_names = []

  def nth_model(n):
    """Build the model for the n-th problem, plus some added variables.

    In PREDICT mode this returns whatever `infer` returns. Otherwise it
    returns a flat list: the scalar total loss followed by the sharded
    logits.
    """
    model_class = registry.model(model)(
        my_hp,
        mode,
        my_hp.problems[n],
        n,
        dp,
        devices.ps_devices(all_workers=True))
    if mode == tf.estimator.ModeKeys.PREDICT:
      return model_class.infer(
          features,
          beam_size=FLAGS.decode_beam_size,
          top_beams=(FLAGS.decode_beam_size
                     if FLAGS.decode_return_beams else 1),
          last_position_only=FLAGS.decode_use_last_position_only,
          alpha=FLAGS.decode_alpha,
          decode_length=FLAGS.decode_extra_length)
    # In distributed mode, we build graph for problem=0 and problem=worker_id.
    skipping_is_on = my_hp.problem_choice == "distributed" and is_training
    problem_worker_id = FLAGS.worker_id % len(my_hp.problems)
    skip_this_one = n != 0 and n % FLAGS.worker_replicas != problem_worker_id
    # On worker 0 also build graph for problems <= 1.
    # TODO(lukaszkaiser): why is this hack needed for variables init? Repair.
    skip_this_one = skip_this_one and (FLAGS.worker_id != 0 or n > 1)
    if (FLAGS.eval_run_autoregressive and
        mode == tf.estimator.ModeKeys.EVAL):
      sharded_logits, losses_dict = model_class.eval_autoregressive(features)
    else:
      sharded_logits, losses_dict = model_class.model_fn(
          features, skip=(skipping_is_on and skip_this_one))
    with tf.variable_scope("losses_avg"):
      total_loss, ops = 0.0, []
      for loss_key, loss_value in six.iteritems(losses_dict):
        loss_name = "problem_%d/%s_loss" % (n, loss_key)
        # Moving average (0.9 old / 0.1 new), initialized to 100.0.
        loss_moving_avg = tf.get_variable(
            loss_name, initializer=100.0, trainable=False)
        loss_variable_names.append(loss_name)
        ops.append(
            loss_moving_avg.assign(loss_moving_avg * 0.9 + loss_value * 0.1))
        total_loss += loss_value
      try:
        # Total loss avg might be reused or not, we try both.
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
          # Total loss was already constructed on input.
          loss_moving_avg = tf.get_variable("problem_%d/total_loss" % n)
      except ValueError:
        loss_moving_avg = tf.get_variable(
            "problem_%d/total_loss" % n, initializer=100.0, trainable=False)
      ops.append(
          loss_moving_avg.assign(loss_moving_avg * 0.9 + total_loss * 0.1))
    with tf.variable_scope("train_stats"):  # Count steps for this problem.
      problem_steps = tf.get_variable(
          "problem_%d_steps" % n, initializer=0, trainable=False)
      ops.append(problem_steps.assign_add(1))
    with tf.control_dependencies(ops):  # Make sure the ops run.
      # Ensure the loss is a scalar here.
      total_loss = tf.reshape(total_loss, [], name="total_loss_control_id")
    return [total_loss] + sharded_logits  # Need to flatten for cond later.

  result_list = input_fn_builder.cond_on_index(
      nth_model, features["problem_choice"], 0, len(my_hp.problems) - 1)

  if mode == tf.estimator.ModeKeys.PREDICT:
    # Beam search in sequence model returns both decodes with the key
    # "outputs" and scores with the key "scores". If return list is a dict, we
    # expect that it will have keys "outputs", a tensor of int32 and scores, a
    # tensor of floats. This is useful if we want to return scores from
    # estimator.predict
    if not isinstance(result_list, dict):
      predictions = {"outputs": result_list}
    else:
      predictions = {
          "outputs": result_list["outputs"],
          "scores": result_list["scores"]
      }

    if "inputs" in features:
      predictions["inputs"] = features["inputs"]
    if "infer_targets" in features:
      predictions["targets"] = features["infer_targets"]
    # Repeat problem_choice once per row of features["inputs"]; note that
    # "inputs" is indexed unconditionally here, unlike the guarded copy above.
    predictions["problem_choice"] = (features["problem_choice"] * tf.ones(
        (tf.shape(features["inputs"])[0],), dtype=tf.int32))

    return tf.estimator.EstimatorSpec(mode, predictions=predictions)

  # Undo the flattening done at the end of nth_model.
  sharded_logits, total_loss = result_list[1:], result_list[0]
  if mode == tf.estimator.ModeKeys.EVAL:
    # For evaluation, return the logits layer as our predictions.
    logits = tf.concat(sharded_logits, 0)

    eval_metrics_fns = metrics.create_evaluation_metrics(
        zip(FLAGS.problems.split("-"), hparams.problem_instances), hparams)
    _check_autotune_metrics(eval_metrics_fns)

    eval_metrics = {}
    for metric_name, metric_fn in six.iteritems(eval_metrics_fns):
      eval_metrics[metric_name] = metric_fn(
          logits, labels, features["problem_choice"])

    return tf.estimator.EstimatorSpec(
        mode,
        predictions={"predictions": logits},
        eval_metric_ops=eval_metrics,
        loss=total_loss)

  assert mode == tf.estimator.ModeKeys.TRAIN

  # Some training statistics.
  with tf.name_scope("training_stats"):
    learning_rate = my_hp.learning_rate * learning_rate_decay()
    learning_rate /= math.sqrt(float(FLAGS.worker_replicas))
    tf.summary.scalar("learning_rate", learning_rate)
    global_step = tf.to_float(tf.contrib.framework.get_global_step())
    for n in xrange(len(my_hp.problems)):
      names_and_vars = []
      # Re-open the scopes used in nth_model with reuse=True to fetch the
      # moving-average variables created there.
      with tf.variable_scope("losses_avg", reuse=True):
        total_loss_var = tf.get_variable("problem_%d/total_loss" % n)
        names_and_vars.append(("total_loss", total_loss_var))
      with tf.variable_scope("losses_avg", reuse=True):
        for loss_name in loss_variable_names:
          if loss_name.startswith("problem_%d/" % n):
            loss_var = tf.get_variable(loss_name)
            # Strip the leading "problem_<n>/" part.
            loss_suffix = loss_name[loss_name.index("/") + 1:]
            names_and_vars.append((loss_suffix, loss_var))
      for (loss_name, loss_var) in names_and_vars:
        tf.summary.scalar("loss_avg_%d/%s" % (n, loss_name), loss_var)
      with tf.variable_scope("train_stats", reuse=True):
        nth_steps = tf.get_variable("problem_%d_steps" % n, dtype=tf.int32)
      tf.summary.scalar("problem_%d_frequency" % n,
                        tf.to_float(nth_steps) / (global_step + 1.0))

  # Log trainable weights and add decay.
  total_size, weight_decay_loss = 0, 0.0
  all_weights = {v.name: v for v in tf.trainable_variables()}
  # Iterate in sorted name order so the ops are added in a stable order.
  for v_name in sorted(list(all_weights)):
    v = all_weights[v_name]
    v_size = int(np.prod(np.array(v.shape.as_list())))
    total_size += v_size
    if my_hp.weight_decay > 0.0 and len(v.shape.as_list()) > 1:
      # Add weight regularization if set and the weight is not a bias (dim>1).
      with tf.device(v._ref().device):  # pylint: disable=protected-access
        v_loss = tf.nn.l2_loss(v) / v_size
      weight_decay_loss += v_loss
    is_body = len(v_name) > 5 and v_name[:5] == "body/"
    if my_hp.weight_noise > 0.0 and is_body:
      # Add weight noise if set in my_hp.
      with tf.device(v._ref().device):  # pylint: disable=protected-access
        scale = learning_rate * 0.001
        noise = tf.truncated_normal(v.shape) * my_hp.weight_noise * scale
        noise_op = v.assign_add(noise)
      # Tie the noise update to the loss so it runs whenever the loss does.
      with tf.control_dependencies([noise_op]):
        total_loss = tf.identity(total_loss)
  if my_hp.weight_decay > 0.0:
    total_loss += weight_decay_loss * my_hp.weight_decay
  # Always true at this point because of the TRAIN assert above.
  if is_training:
    total_loss *= small_batch_multiplier
  total_loss = tf.identity(total_loss, name="total_loss")
  log_variable_sizes(tf.trainable_variables(), "Trainable Variables")
  diet_vars = [
      v for v in tf.global_variables() if v.dtype == dtypes.float16_ref
  ]
  # NOTE(review): "Varaibles" is a misspelling in this runtime log title.
  log_variable_sizes(diet_vars, "Diet Varaibles")
  # Define the train_op for the TRAIN mode.
  opt = _ConditionalOptimizer(my_hp.optimizer, learning_rate, my_hp)
  tf.logging.info("Computing gradients for global model_fn.")
  opt_summaries = ["learning_rate", "loss"]
  if hparams.summarize_grads:
    opt_summaries.extend(["gradients", "gradient_norm"])
  # `x or None` turns falsy hparams values (e.g. 0) into "disabled".
  train_op = tf.contrib.layers.optimize_loss(
      name="training",
      loss=total_loss,
      global_step=tf.train.get_global_step(),
      learning_rate=learning_rate,
      clip_gradients=my_hp.clip_grad_norm or None,
      gradient_noise_scale=hparams.grad_noise_scale or None,
      optimizer=opt,
      summaries=opt_summaries,
      colocate_gradients_with_ops=True)

  # Remove summaries that will fail to run because they are in conditionals.
  # TODO(cwhipkey): Test with this code removed, later in 2017.
  summaries = tf.get_collection_ref(tf.GraphKeys.SUMMARIES)
  # Walk backwards so deleting an entry does not shift unvisited indices.
  for i in range(len(summaries) - 1, -1, -1):
    if summaries[i].name.startswith("cond_"):
      del summaries[i]

  tf.logging.info("Global model_fn finished.")
  return tf.estimator.EstimatorSpec(
      mode,
      predictions={"problem_choice": features["problem_choice"]},
      loss=total_loss,
      train_op=train_op)
def model_fn(model,
             features,
             mode,
             hparams,
             problem_names,
             train_steps=100000,
             worker_id=0,
             worker_replicas=1,
             eval_run_autoregressive=False,
             decode_hparams=None):
    """Builds the model for all modes, with optional RL/MRT sampling.

    * TRAIN: Constructs loss and train_op
    * EVAL: Constructs the loss and eval metrics
    * PREDICT: Constructs the predictions

    When `hparams.rl` is set, the non-autoregressive branch first decodes
    samples from the model, scores them against features["targets"] with the
    APPROX_BLEU_TRAIN metric, and stores the results in features["samples"]
    and features["values"] before the model is called. Some branches also
    overwrite features["inputs"] / features["targets"], so the `features`
    dict passed in is mutated.

    Args:
      model: str, name of model.
      features: dict<feature name, Tensor>. Expected to have keys
        {inputs, targets, problem_choice}.
      mode: tf.estimator.ModeKeys.
      hparams: model HParams. `rl` is read directly (and `delta_reward` when
        `rl` is set); `mrt_samples`, `train_beam`, `targets_beam`,
        `baseline_loss_weight`, `training_loss_weight` and
        `mle_training_loss_weight` are optional and read with getattr.
      problem_names: list of str, names of the problems.
      train_steps: int, total number of training steps. Used to compute
        learning rate decay.
      worker_id: int, id of this worker.
      worker_replicas: int, number of workers.
      eval_run_autoregressive: bool, whether to run evaluation
        autoregressively.
      decode_hparams: HParams for decode settings. Used when mode == PREDICT.

    Returns:
      tf.estimator.EstimatorSpec

    Raises:
      ValueError: in TRAIN mode, if `features` holds no "targets" Tensor of
        known rank > 1 (it is needed for the small-batch loss correction).
    """
    assert len(problem_names) == len(hparams.problem_instances)
    decode_hp = decode_hparams

    # TODO(rsepassi): This still depends on FLAGS. Rm eventually.
    dp = devices.data_parallelism()

    # Set the default initializer for every variable created below.
    tf.get_variable_scope().set_initializer(_get_variable_initializer(hparams))
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    # Add input statistics for incoming features.
    # Stays None unless a rank > 1 "targets" Tensor is seen; checked in TRAIN.
    targets_nonpadding_tokens = None
    with tf.name_scope("input_stats"):
        for (k, v) in six.iteritems(features):
            # `ndims` is None for a Tensor of unknown rank. Comparing None to
            # an int raises TypeError on Python 3 (and is False on Python 2),
            # so such tensors are skipped explicitly.
            ndims = v.get_shape().ndims if isinstance(v, tf.Tensor) else None
            if ndims is not None and ndims > 1:
                tf.summary.scalar("%s_batch" % k, tf.shape(v)[0] // dp.n)
                tf.summary.scalar("%s_length" % k, tf.shape(v)[1])
                nonpadding = tf.to_float(tf.not_equal(v, 0))
                # Number of non-zero tokens.
                nonpadding_tokens = tf.reduce_sum(nonpadding)
                if k == "targets":
                    targets_nonpadding_tokens = nonpadding_tokens
                tf.summary.scalar("%s_nonpadding_tokens" % k,
                                  nonpadding_tokens)
                tf.summary.scalar("%s_nonpadding_fraction" % k,
                                  tf.reduce_mean(nonpadding))

    # Get multi-problem logits and loss based on features["problem_choice"].
    loss_variable_names = []

    def nth_model(n):
        """Build the model for the n-th problem, plus some added variables."""
        # Instantiate the registered model class with hparams and modalities.
        model_class = registry.model(model)(
            hparams,
            mode,
            hparams.problems[n],
            n,
            dp,
            devices.ps_devices(all_workers=True),
            decode_hparams=decode_hparams)
        if mode == tf.estimator.ModeKeys.PREDICT:
            return model_class.infer(
                features,
                beam_size=decode_hp.beam_size,
                top_beams=(decode_hp.beam_size
                           if decode_hp.return_beams else 1),
                alpha=decode_hp.alpha,
                decode_length=decode_hp.extra_length)
        # In distributed mode, we build graph for problem=0 and
        # problem=worker_id.
        skipping_is_on = (hparams.problem_choice == "distributed" and
                          is_training)
        problem_worker_id = worker_id % len(hparams.problems)
        skip_this_one = n != 0 and n % worker_replicas != problem_worker_id
        # On worker 0 also build graph for problems <= 1.
        # TODO(lukaszkaiser): why is this hack needed for variables init?
        # Repair.
        skip_this_one = skip_this_one and (worker_id != 0 or n > 1)
        mrt_samples = getattr(hparams, 'mrt_samples', None)
        if eval_run_autoregressive and mode == tf.estimator.ModeKeys.EVAL:
            # Autoregressive evaluation.
            sharded_logits, losses_dict = model_class.eval_autoregressive(
                features)
        else:
            # Training (and non-autoregressive evaluation).
            if hparams.rl:
                # Generate sample data; the code below expects `samples` to
                # end up with shape [batch, time, 1, 1].
                if model_class._num_datashards == 1:
                    # Single data shard (single GPU): use the fast decoder.
                    print("###Work on Single GPU card, Use Fast Decode.###")
                    train_beam = getattr(hparams, 'train_beam', None)
                    if mrt_samples:
                        samples, _ = model_class._fast_decode(
                            features,
                            decode_length=50,
                            beam_size=mrt_samples,
                            top_beams=mrt_samples)
                        # Repeat every input/target row `mrt_samples` times
                        # so that each sample has a matching source and
                        # reference, then restore the two trailing unit dims.
                        inputs = tf.squeeze(
                            tf.squeeze(features["inputs"], axis=-1), axis=-1)
                        targets = tf.squeeze(
                            tf.squeeze(features["targets"], axis=-1), axis=-1)
                        batch_size = tf.shape(inputs)[0]
                        inputs_len = tf.shape(inputs)[1]
                        targets_len = tf.shape(targets)[1]
                        inputs_tile = tf.tile(inputs, [1, mrt_samples])
                        targets_tile = tf.tile(targets, [1, mrt_samples])
                        inputs_reshape = tf.reshape(
                            inputs_tile,
                            [batch_size * mrt_samples, inputs_len])
                        targets_reshape = tf.reshape(
                            targets_tile,
                            [batch_size * mrt_samples, targets_len])
                        inputs_feed = tf.expand_dims(
                            tf.expand_dims(inputs_reshape, axis=-1), axis=-1)
                        targets_feed = tf.expand_dims(
                            tf.expand_dims(targets_reshape, axis=-1), axis=-1)
                        features["inputs"] = inputs_feed
                        features["targets"] = targets_feed
                    elif train_beam and train_beam != 1:
                        # Beam search with hparams.train_beam size and return
                        # the top 1 sample.
                        samples, _ = model_class._fast_decode(
                            features,
                            decode_length=50,
                            beam_size=hparams.train_beam)
                    else:
                        targets_beam = getattr(hparams, 'targets_beam', None)
                        if targets_beam:
                            # Replace the reference targets with a decoded
                            # hypothesis, reshaped to [batch, time, 1, 1].
                            targets_samples, _ = model_class._fast_decode(
                                features,
                                decode_length=50,
                                beam_size=4,
                                sampling_method='argmax')
                            targets_samples = tf.reshape(
                                targets_samples, [
                                    tf.shape(targets_samples)[0],
                                    tf.shape(targets_samples)[1], 1, 1
                                ])
                            features["targets"] = targets_samples
                        samples, _ = model_class._fast_decode(
                            features, decode_length=50)
                    # Add two additional dimensions to make it compatible
                    # with the [batch, time, 1, 1] layout used below.
                    samples = tf.expand_dims(samples, axis=-1)
                    samples = tf.expand_dims(samples, axis=-1)
                else:
                    # Multiple data shards (multi GPU): only the slow greedy
                    # decoder is supported.
                    print("###Work on Multi GPU cards, Use Slow Decode.###")
                    # Default decode_length = 50.
                    samples, _, _ = model_class._slow_greedy_infer(
                        features, decode_length=50)
                samples = tf.stop_gradient(samples)
                # Calculate the BLEU-based reward with the training metric
                # function ("approx_bleu_train_score").
                train_metric_fn = metrics.METRICS_FNS[
                    metrics.Metrics.APPROX_BLEU_TRAIN]
                labels = features.get("targets", None)
                samples.set_shape([None, None, 1, 1])
                # hparams.delta_reward = True for delta reward; False for
                # total reward.
                # NOTE(review): the keyword is spelled `delat_reward`; it is
                # kept because it must match the metric function's signature,
                # which is defined elsewhere -- confirm before renaming.
                metric_value = train_metric_fn(
                    samples, labels, delat_reward=hparams.delta_reward)
                # Be strict: no gradient may flow through the reward.
                metric_value = tf.stop_gradient(metric_value)
                metric_value.set_shape([None, None, 1, 1])
                # According to metrics.py, the tf.metrics.mean function
                # assures correct aggregation.
                features["samples"] = samples
                features["values"] = metric_value
            sharded_logits, losses_dict = model_class.model_fn(
                features,
                skip=(skipping_is_on and skip_this_one),
                mrt=mrt_samples)

        # Per-loss weights applied in RL mode; looked up once rather than on
        # every loop iteration. Loss keys not listed here are left unscaled.
        rl_loss_weights = {}
        if hparams.rl:
            rl_loss_weights = {
                "training":
                    getattr(hparams, 'training_loss_weight', 1.0),
                "training_baseline":
                    getattr(hparams, 'baseline_loss_weight', 1.0),
                "mle_training":
                    getattr(hparams, 'mle_training_loss_weight', 0.3),
            }

        with tf.variable_scope("losses_avg"):
            total_loss, ops = 0.0, []
            for loss_key, loss_value in six.iteritems(losses_dict):
                if loss_key in rl_loss_weights:
                    loss_value = loss_value * rl_loss_weights[loss_key]
                loss_name = "problem_%d/%s_loss" % (n, loss_key)
                loss_moving_avg = tf.get_variable(
                    loss_name, initializer=100.0, trainable=False)
                loss_variable_names.append(loss_name)
                ops.append(
                    loss_moving_avg.assign(loss_moving_avg * 0.9 +
                                           loss_value * 0.1))
                total_loss += loss_value
            try:
                # Total loss avg might be reused or not, we try both.
                with tf.variable_scope(tf.get_variable_scope(), reuse=True):
                    # Total loss was already constructed on input.
                    loss_moving_avg = tf.get_variable(
                        "problem_%d/total_loss" % n)
            except ValueError:
                loss_moving_avg = tf.get_variable(
                    "problem_%d/total_loss" % n,
                    initializer=100.0,
                    trainable=False)
            ops.append(
                loss_moving_avg.assign(loss_moving_avg * 0.9 +
                                       total_loss * 0.1))
        with tf.variable_scope("train_stats"):
            # Count steps for this problem.
            problem_steps = tf.get_variable(
                "problem_%d_steps" % n, initializer=0, trainable=False)
            ops.append(problem_steps.assign_add(1))
        with tf.control_dependencies(ops):
            # Make sure the ops run.
            # Ensure the loss is a scalar here.
            total_loss = tf.reshape(
                total_loss, [], name="total_loss_control_id")
        return [total_loss, tf.concat(sharded_logits, 0)]

    # Yields [total_loss, logits] for TRAIN/EVAL, or the output of
    # `model_class.infer` for PREDICT.
    model_output = input_fn_builder.cond_on_index(
        nth_model,
        index_tensor=features["problem_choice"],
        max_idx=len(hparams.problems) - 1)

    if mode == tf.estimator.ModeKeys.PREDICT:
        # If beam searching, model_output will be a dict with keys "outputs"
        # and "scores".
        if isinstance(model_output, dict):
            outputs = model_output["outputs"]
            scores = model_output["scores"]
        else:
            outputs = model_output
            scores = None

        batched_problem_choice = (
            features["problem_choice"] * tf.ones(
                (tf.shape(features["inputs"])[0],), dtype=tf.int32))
        predictions = {
            "outputs": outputs,
            "scores": scores,
            "inputs": features.get("inputs", None),
            "targets": features.get("infer_targets", None),
            "problem_choice": batched_problem_choice,
        }
        # Drop the entries whose value is None.
        _del_dict_nones(predictions)

        export_out = {"outputs": predictions["outputs"]}
        if "scores" in predictions:
            export_out["scores"] = predictions["scores"]

        return tf.estimator.EstimatorSpec(
            mode,
            predictions=predictions,
            export_outputs={
                "output": tf.estimator.export.PredictOutput(export_out)
            })

    total_loss, logits = model_output

    if mode == tf.estimator.ModeKeys.EVAL:
        eval_metrics_fns = metrics.create_evaluation_metrics(
            hparams.problem_instances, hparams)

        eval_metrics = {}
        for metric_name, metric_fn in six.iteritems(eval_metrics_fns):
            eval_metrics[metric_name] = metric_fn(logits, features)

        return tf.estimator.EstimatorSpec(
            mode,
            predictions={"predictions": logits},
            eval_metric_ops=eval_metrics,
            loss=total_loss)

    assert mode == tf.estimator.ModeKeys.TRAIN

    # Fail with a clear message instead of an UnboundLocalError further down.
    if targets_nonpadding_tokens is None:
        raise ValueError(
            "TRAIN mode requires features[\"targets\"] to be a Tensor of "
            "known rank > 1.")

    # Set learning rate
    learning_rate = hparams.learning_rate * optimize.learning_rate_decay(
        hparams,
        num_worker_replicas=worker_replicas,
        num_train_steps=train_steps)
    learning_rate /= math.sqrt(float(worker_replicas))

    # Get global step
    global_step = tf.train.get_or_create_global_step()

    # Some training statistics.
    with tf.name_scope("training_stats"):
        tf.summary.scalar("learning_rate", learning_rate)
        for n in xrange(len(hparams.problems)):
            names_and_vars = []
            with tf.variable_scope("losses_avg", reuse=True):
                total_loss_var = tf.get_variable("problem_%d/total_loss" % n)
            names_and_vars.append(("total_loss", total_loss_var))
            with tf.variable_scope("losses_avg", reuse=True):
                for loss_name in loss_variable_names:
                    if loss_name.startswith("problem_%d/" % n):
                        loss_var = tf.get_variable(loss_name)
                        loss_suffix = loss_name[loss_name.index("/") + 1:]
                        names_and_vars.append((loss_suffix, loss_var))
            for (loss_name, loss_var) in names_and_vars:
                tf.summary.scalar("loss_avg_%d/%s" % (n, loss_name), loss_var)
            with tf.variable_scope("train_stats", reuse=True):
                nth_steps = tf.get_variable(
                    "problem_%d_steps" % n, dtype=tf.int32)
            tf.summary.scalar(
                "problem_%d_frequency" % n,
                tf.to_float(nth_steps) / (tf.to_float(global_step) + 1.0))

    # Add weight decay and noise.
    weight_decay_loss = 0.0
    all_weights = {v.name: v for v in tf.trainable_variables()}
    # Visit variables in name order so graph construction is deterministic.
    for v_name in sorted(all_weights):
        v = all_weights[v_name]
        v_size = int(np.prod(np.array(v.shape.as_list())))
        if hparams.weight_decay > 0.0 and len(v.shape.as_list()) > 1:
            # Add weight regularization if set and the weight is not a bias
            # (dim>1).
            with tf.device(v._ref().device):  # pylint: disable=protected-access
                v_loss = tf.nn.l2_loss(v) / v_size
            weight_decay_loss += v_loss
        is_body = len(v_name) > 5 and v_name[:5] == "body/"
        if hparams.weight_noise > 0.0 and is_body:
            # Add weight noise if set in hparams.
            with tf.device(v._ref().device):  # pylint: disable=protected-access
                scale = learning_rate * 0.001
                noise = (tf.truncated_normal(v.shape) *
                         hparams.weight_noise * scale)
                noise_op = v.assign_add(noise)
            # Tie the noise update to the loss so that it actually runs.
            with tf.control_dependencies([noise_op]):
                total_loss = tf.identity(total_loss)
    if hparams.weight_decay > 0.0:
        total_loss += weight_decay_loss * hparams.weight_decay

    # The new data reader occasionally emits very small batches, which
    # cause the examples in those batches to be grossly overweighted.
    # We decrease the loss proportionally to the ratio of the size of this
    # batch to the size of the largest training batch ever.
    # TODO(noam): to be more sophisticated, we could keep separate
    # maxima based on problem choice.
    max_nonpadding_var = tf.get_variable(
        "max_nonpadding",
        shape=[],
        initializer=tf.ones_initializer(),
        trainable=False)
    max_nonpadding = tf.maximum(max_nonpadding_var,
                                targets_nonpadding_tokens)
    with tf.control_dependencies(
        [tf.assign(max_nonpadding_var, max_nonpadding)]):
        small_batch_multiplier = targets_nonpadding_tokens / max_nonpadding
    tf.summary.scalar("small_batch_multiplier", small_batch_multiplier)
    total_loss *= small_batch_multiplier

    # Log variable sizes
    _log_variable_sizes(tf.trainable_variables(), "Trainable Variables")
    diet_vars = [
        v for v in tf.global_variables() if v.dtype == dtypes.float16_ref
    ]
    _log_variable_sizes(diet_vars, "Diet Variables")

    # Optimize
    train_op = optimize.optimize(total_loss, learning_rate, hparams)

    # Remove summaries that will fail to run because they are in conditionals.
    # TODO(cwhipkey): Test with this code removed, later in 2017.
    summaries = tf.get_collection_ref(tf.GraphKeys.SUMMARIES)
    # Walk backwards so deleting an entry does not shift pending indices.
    for i in reversed(range(len(summaries))):
        if summaries[i].name.startswith("cond_"):
            del summaries[i]

    tf.logging.info("Global model_fn finished.")
    return tf.estimator.EstimatorSpec(
        mode,
        predictions={"problem_choice": features["problem_choice"]},
        loss=total_loss,
        train_op=train_op)
def model_fn(features, labels, mode, params, config):
    """Build the Estimator spec for the first problem of `hp`.

    `hp`, `model_name` and `use_tpu` are read from the enclosing scope. Only
    EVAL and TRAIN are handled; any other mode trips the assertion below.

    Args:
      features: dict of input Tensors. "targets" is always read; "inputs" is
        read only when the problem defines an input modality.
      labels: passed to the target modality's `top` and `loss`.
      mode: tf.estimator.ModeKeys.
      params: unused.
      config: unused.

    Returns:
      A tf.contrib.tpu.TPUEstimatorSpec when `use_tpu` is set, otherwise a
      tf.estimator.EstimatorSpec.
    """
    del params, config  # Unused.
    create_dummy_vars()

    hparams = copy.deepcopy(hp)
    problem_hp = hparams.problems[0]
    raw_features = features

    # Build the model object and look up its modalities. Autoregressive
    # models come without an input modality.
    model_class = registry.model(model_name)(hparams, mode, problem_hp)
    input_modality = problem_hp.input_modality.get("inputs")
    target_modality = problem_hp.target_modality

    # Run the bottom of each modality over the raw features.
    model_inputs = {}
    if input_modality is not None:
        with tf.variable_scope(input_modality.name):
            model_inputs["inputs"] = input_modality.bottom(features["inputs"])
    with tf.variable_scope(target_modality.name):
        model_inputs["targets"] = target_modality.targets_bottom(
            features["targets"])
    model_inputs.update({
        "problem_choice": tf.constant(0),
        "input_space_id": tf.constant(problem_hp.input_space_id),
        "target_space_id": tf.constant(problem_hp.target_space_id),
    })

    # Model body followed by the target modality's top.
    with tf.variable_scope("body"):
        body_output = model_class.model_fn_body(model_inputs)
    with tf.variable_scope(target_modality.name):
        logits = target_modality.top(body_output, labels)

    # On TPU, pin an unknown length dimension to max_length.
    if use_tpu:
        static_shape = logits.get_shape().as_list()
        if static_shape[1] is None:
            static_shape[1] = hparams.max_length
            logits.set_shape(static_shape)

    # Loss is the numerator over the denominator, the latter floored at 1.
    loss_num, loss_den = target_modality.loss(logits, labels)
    loss = loss_num / tf.maximum(1.0, loss_den)

    if mode == tf.estimator.ModeKeys.EVAL:
        problem = hp.problem_instances[0]

        if use_tpu:
            tpu_metrics_fn = create_eval_metrics_fn(problem)
            _remove_summaries()
            return tf.contrib.tpu.TPUEstimatorSpec(
                mode,
                eval_metrics=(tpu_metrics_fn,
                              [logits, raw_features["targets"]]),
                loss=loss)

        metric_fns = metrics.create_evaluation_metrics([problem], hparams)
        eval_metric_ops = {
            name: fn(logits, features)
            for name, fn in six.iteritems(metric_fns)
        }
        return tf.estimator.EstimatorSpec(
            mode,
            predictions={"predictions": logits},
            eval_metric_ops=eval_metric_ops,
            loss=loss)

    assert mode == tf.estimator.ModeKeys.TRAIN

    # Decayed learning rate and the (optionally cross-shard) optimizer.
    learning_rate = (
        hparams.learning_rate * optimize.learning_rate_decay(hparams))
    optimizer = optimize.ConditionalOptimizer(hparams.optimizer,
                                              learning_rate, hparams)
    if use_tpu:
        optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

    # Compute, optionally clip, and apply the gradients.
    grads_and_vars = optimizer.compute_gradients(loss,
                                                 tf.trainable_variables())
    if hparams.clip_grad_norm:
        grads_and_vars = _clip_gradients_by_norm(grads_and_vars,
                                                 hparams.clip_grad_norm)
    apply_op = optimizer.apply_gradients(
        grads_and_vars, global_step=tf.train.get_or_create_global_step())

    # The returned train_op is the loss tensor, gated on the update.
    with tf.control_dependencies([apply_op]):
        train_op = tf.identity(loss)

    _remove_summaries()
    if use_tpu:
        spec_cls = tf.contrib.tpu.TPUEstimatorSpec
    else:
        spec_cls = tf.estimator.EstimatorSpec
    return spec_cls(mode, loss=loss, train_op=train_op)
def model_fn(model,
             features,
             mode,
             hparams,
             problem_names,
             train_steps=100000,
             worker_id=0,
             worker_replicas=1,
             eval_run_autoregressive=False,
             decode_hparams=None):
    """Builds the model for all modes.

    * TRAIN: Constructs loss and train_op
    * EVAL: Constructs the loss and eval metrics
    * PREDICT: Constructs the predictions

    Args:
      model: str, name of model.
      features: dict<feature name, Tensor>. Expected to have keys
        {inputs, targets, problem_choice}.
      mode: tf.estimator.ModeKeys.
      hparams: model HParams.
      problem_names: list of str, names of the problems.
      train_steps: int, total number of training steps. Used to compute
        learning rate decay.
      worker_id: int, id of this worker.
      worker_replicas: int, number of workers.
      eval_run_autoregressive: bool, whether to run evaluation
        autoregressively.
      decode_hparams: HParams for decode settings. Used when mode == PREDICT.

    Returns:
      tf.estimator.EstimatorSpec

    Raises:
      ValueError: in TRAIN mode, if `features` holds no "targets" Tensor of
        known rank > 1 (it is needed for the small-batch loss correction).
    """
    assert len(problem_names) == len(hparams.problem_instances)
    decode_hp = decode_hparams

    # TODO(rsepassi): This still depends on FLAGS. Rm eventually.
    dp = devices.data_parallelism(hparams)

    tf.get_variable_scope().set_initializer(_get_variable_initializer(hparams))
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    # Add input statistics for incoming features.
    # Stays None unless a rank > 1 "targets" Tensor is seen; checked in TRAIN.
    targets_nonpadding_tokens = None
    with tf.name_scope("input_stats"):
        for (k, v) in six.iteritems(features):
            # `ndims` is None for a Tensor of unknown rank. Comparing None to
            # an int raises TypeError on Python 3 (and is False on Python 2),
            # so such tensors are skipped explicitly.
            ndims = v.get_shape().ndims if isinstance(v, tf.Tensor) else None
            if ndims is not None and ndims > 1:
                tf.summary.scalar("%s_batch" % k, tf.shape(v)[0] // dp.n)
                tf.summary.scalar("%s_length" % k, tf.shape(v)[1])
                nonpadding = tf.to_float(tf.not_equal(v, 0))
                nonpadding_tokens = tf.reduce_sum(nonpadding)
                if k == "targets":
                    targets_nonpadding_tokens = nonpadding_tokens
                tf.summary.scalar("%s_nonpadding_tokens" % k,
                                  nonpadding_tokens)
                tf.summary.scalar("%s_nonpadding_fraction" % k,
                                  tf.reduce_mean(nonpadding))

    # Get multi-problem logits and loss based on features["problem_choice"].
    loss_variable_names = []

    def nth_model(n):
        """Build the model for the n-th problem, plus some added variables."""
        model_class = registry.model(model)(
            hparams,
            mode,
            hparams.problems[n],
            n,
            dp,
            devices.ps_devices(all_workers=True),
            decode_hparams=decode_hparams)
        if mode == tf.estimator.ModeKeys.PREDICT:
            return model_class.infer(
                features,
                beam_size=decode_hp.beam_size,
                top_beams=(decode_hp.beam_size
                           if decode_hp.return_beams else 1),
                alpha=decode_hp.alpha,
                decode_length=decode_hp.extra_length)
        # In distributed mode, we build graph for problem=0 and
        # problem=worker_id.
        skipping_is_on = (hparams.problem_choice == "distributed" and
                          is_training)
        problem_worker_id = worker_id % len(hparams.problems)
        skip_this_one = n != 0 and n % worker_replicas != problem_worker_id
        # On worker 0 also build graph for problems <= 1.
        # TODO(lukaszkaiser): why is this hack needed for variables init?
        # Repair.
        skip_this_one = skip_this_one and (worker_id != 0 or n > 1)
        if eval_run_autoregressive and mode == tf.estimator.ModeKeys.EVAL:
            logits, losses_dict = model_class.eval_autoregressive(features)
        else:
            logits, losses_dict = model_class(
                features, skip=(skipping_is_on and skip_this_one))
        with tf.variable_scope("losses_avg"):
            total_loss, ops = 0.0, []
            for loss_key, loss_value in six.iteritems(losses_dict):
                loss_name = "problem_%d/%s_loss" % (n, loss_key)
                loss_moving_avg = tf.get_variable(
                    loss_name, initializer=100.0, trainable=False)
                loss_variable_names.append(loss_name)
                ops.append(
                    loss_moving_avg.assign(loss_moving_avg * 0.9 +
                                           loss_value * 0.1))
                total_loss += loss_value
            try:
                # Total loss avg might be reused or not, we try both.
                with tf.variable_scope(tf.get_variable_scope(), reuse=True):
                    # Total loss was already constructed on input.
                    loss_moving_avg = tf.get_variable(
                        "problem_%d/total_loss" % n)
            except ValueError:
                loss_moving_avg = tf.get_variable(
                    "problem_%d/total_loss" % n,
                    initializer=100.0,
                    trainable=False)
            ops.append(
                loss_moving_avg.assign(loss_moving_avg * 0.9 +
                                       total_loss * 0.1))
        with tf.variable_scope("train_stats"):
            # Count steps for this problem.
            problem_steps = tf.get_variable(
                "problem_%d_steps" % n, initializer=0, trainable=False)
            ops.append(problem_steps.assign_add(1))
        with tf.control_dependencies(ops):
            # Make sure the ops run.
            # Ensure the loss is a scalar here.
            total_loss = tf.reshape(
                total_loss, [], name="total_loss_control_id")
        return [total_loss, logits]

    # Yields [total_loss, logits] for TRAIN/EVAL, or the output of
    # `model_class.infer` for PREDICT.
    model_output = input_fn_builder.cond_on_index(
        nth_model,
        index_tensor=features["problem_choice"],
        max_idx=len(hparams.problems) - 1)

    if mode == tf.estimator.ModeKeys.PREDICT:
        # If beam searching, model_output will be a dict with keys "outputs"
        # and "scores".
        if isinstance(model_output, dict):
            outputs = model_output["outputs"]
            scores = model_output["scores"]
        else:
            outputs = model_output
            scores = None

        batched_problem_choice = (
            features["problem_choice"] * tf.ones(
                (tf.shape(features["inputs"])[0],), dtype=tf.int32))
        predictions = {
            "outputs": outputs,
            "scores": scores,
            "inputs": features.get("inputs", None),
            "targets": features.get("infer_targets", None),
            "problem_choice": batched_problem_choice,
        }
        # Drop the entries whose value is None.
        _del_dict_nones(predictions)

        export_out = {"outputs": predictions["outputs"]}
        if "scores" in predictions:
            export_out["scores"] = predictions["scores"]

        return tf.estimator.EstimatorSpec(
            mode,
            predictions=predictions,
            export_outputs={
                "output": tf.estimator.export.PredictOutput(export_out)
            })

    total_loss, logits = model_output

    if mode == tf.estimator.ModeKeys.EVAL:
        eval_metrics_fns = metrics.create_evaluation_metrics(
            hparams.problem_instances, hparams)

        eval_metrics = {}
        for metric_name, metric_fn in six.iteritems(eval_metrics_fns):
            eval_metrics[metric_name] = metric_fn(logits, features)

        return tf.estimator.EstimatorSpec(
            mode,
            predictions={"predictions": logits},
            eval_metric_ops=eval_metrics,
            loss=total_loss)

    assert mode == tf.estimator.ModeKeys.TRAIN

    # Fail with a clear message instead of an UnboundLocalError further down.
    if targets_nonpadding_tokens is None:
        raise ValueError(
            "TRAIN mode requires features[\"targets\"] to be a Tensor of "
            "known rank > 1.")

    # Set learning rate
    learning_rate = hparams.learning_rate * optimize.learning_rate_decay(
        hparams,
        num_worker_replicas=worker_replicas,
        num_train_steps=train_steps)
    learning_rate /= math.sqrt(float(worker_replicas))

    # Get global step
    global_step = tf.train.get_or_create_global_step()

    # Some training statistics.
    with tf.name_scope("training_stats"):
        tf.summary.scalar("learning_rate", learning_rate)
        for n in xrange(len(hparams.problems)):
            names_and_vars = []
            with tf.variable_scope("losses_avg", reuse=True):
                total_loss_var = tf.get_variable("problem_%d/total_loss" % n)
            names_and_vars.append(("total_loss", total_loss_var))
            with tf.variable_scope("losses_avg", reuse=True):
                for loss_name in loss_variable_names:
                    if loss_name.startswith("problem_%d/" % n):
                        loss_var = tf.get_variable(loss_name)
                        loss_suffix = loss_name[loss_name.index("/") + 1:]
                        names_and_vars.append((loss_suffix, loss_var))
            for (loss_name, loss_var) in names_and_vars:
                tf.summary.scalar("loss_avg_%d/%s" % (n, loss_name), loss_var)
            with tf.variable_scope("train_stats", reuse=True):
                nth_steps = tf.get_variable(
                    "problem_%d_steps" % n, dtype=tf.int32)
            tf.summary.scalar(
                "problem_%d_frequency" % n,
                tf.to_float(nth_steps) / (tf.to_float(global_step) + 1.0))

    # Add weight decay and noise.
    weight_decay_loss = 0.0
    all_weights = {v.name: v for v in tf.trainable_variables()}
    # Visit variables in name order so graph construction is deterministic.
    for v_name in sorted(all_weights):
        v = all_weights[v_name]
        v_size = int(np.prod(np.array(v.shape.as_list())))
        if hparams.weight_decay > 0.0 and len(v.shape.as_list()) > 1:
            # Add weight regularization if set and the weight is not a bias
            # (dim>1).
            with tf.device(v._ref().device):  # pylint: disable=protected-access
                v_loss = tf.nn.l2_loss(v) / v_size
            weight_decay_loss += v_loss
        is_body = len(v_name) > 5 and v_name[:5] == "body/"
        if hparams.weight_noise > 0.0 and is_body:
            # Add weight noise if set in hparams.
            with tf.device(v._ref().device):  # pylint: disable=protected-access
                scale = learning_rate * 0.001
                noise = (tf.truncated_normal(v.shape) *
                         hparams.weight_noise * scale)
                noise_op = v.assign_add(noise)
            # Tie the noise update to the loss so that it actually runs.
            with tf.control_dependencies([noise_op]):
                total_loss = tf.identity(total_loss)
    if hparams.weight_decay > 0.0:
        total_loss += weight_decay_loss * hparams.weight_decay

    # The new data reader occasionally emits very small batches, which
    # cause the examples in those batches to be grossly overweighted.
    # We decrease the loss proportionally to the ratio of the size of this
    # batch to the size of the largest training batch ever.
    # TODO(noam): to be more sophisticated, we could keep separate
    # maxima based on problem choice.
    max_nonpadding_var = tf.get_variable(
        "max_nonpadding",
        shape=[],
        initializer=tf.ones_initializer(),
        trainable=False)
    max_nonpadding = tf.maximum(max_nonpadding_var,
                                targets_nonpadding_tokens)
    with tf.control_dependencies(
        [tf.assign(max_nonpadding_var, max_nonpadding)]):
        small_batch_multiplier = targets_nonpadding_tokens / max_nonpadding
    tf.summary.scalar("small_batch_multiplier", small_batch_multiplier)
    total_loss *= small_batch_multiplier

    # Log variable sizes
    _log_variable_sizes(tf.trainable_variables(), "Trainable Variables")
    diet_vars = [
        v for v in tf.global_variables() if v.dtype == dtypes.float16_ref
    ]
    _log_variable_sizes(diet_vars, "Diet Variables")

    # Optimize
    train_op = optimize.optimize(total_loss, learning_rate, hparams)

    # Remove summaries that will fail to run because they are in conditionals.
    # TODO(cwhipkey): Test with this code removed, later in 2017.
    summaries = tf.get_collection_ref(tf.GraphKeys.SUMMARIES)
    # Walk backwards so deleting an entry does not shift pending indices.
    for i in reversed(range(len(summaries))):
        if summaries[i].name.startswith("cond_"):
            del summaries[i]

    tf.logging.info("Global model_fn finished.")
    return tf.estimator.EstimatorSpec(
        mode,
        predictions={"problem_choice": features["problem_choice"]},
        loss=total_loss,
        train_op=train_op)