Example #1
def _decode_input_tensor_to_features_dict(feature_map, hparams):
  """Convert the interactive input format (see above) to a dictionary.

  Args:
    feature_map: a dictionary with keys `problem_choice` and `inputs` containing
      Tensors.
    hparams: model hyperparameters

  Returns:
    a features dictionary, as expected by the decoder.
  """
  inputs = tf.convert_to_tensor(feature_map["inputs"])
  input_is_image = False

  def input_fn(problem_choice, x=inputs):  # pylint: disable=missing-docstring
    p_hparams = hparams.problems[problem_choice]
    # Add a third empty dimension.
    x = tf.expand_dims(x, axis=[2])
    x = tf.to_int32(x)
    return (tf.constant(p_hparams.input_space_id), tf.constant(
        p_hparams.target_space_id), x)

  input_space_id, target_space_id, x = input_fn_builder.cond_on_index(
      input_fn, feature_map["problem_choice"], len(hparams.problems) - 1)

  features = {}
  features["problem_choice"] = feature_map["problem_choice"]
  features["input_space_id"] = input_space_id
  features["target_space_id"] = target_space_id
  features["decode_length"] = (
      IMAGE_DECODE_LENGTH if input_is_image else tf.shape(x)[1] + 50)
  features["inputs"] = x
  return features
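
A minimal sketch (assuming TensorFlow 1.x and a 2-D [batch, length] batch of token ids, as the example above expects) of the shape transform applied inside input_fn: the inputs gain a trailing empty dimension, are cast to int32, and decode_length becomes the input length plus a 50-token margin.

import tensorflow as tf

toy_inputs = tf.constant([[17, 3, 42, 0, 0]])           # [batch=1, length=5]
expanded = tf.to_int32(tf.expand_dims(toy_inputs, 2))   # -> shape [1, 5, 1]
decode_length = tf.shape(expanded)[1] + 50              # 5 + 50 = 55

with tf.Session() as sess:
  print(sess.run(tf.shape(expanded)))  # [1 5 1]
  print(sess.run(decode_length))       # 55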
Example #2
def _decode_input_tensor_to_features_dict(feature_map, hparams):
  """Convert the interactive input format (see above) to a dictionary.

  Args:
    feature_map: a dictionary with keys `problem_choice` and `inputs` containing
      Tensors.
    hparams: model hyperparameters

  Returns:
    a features dictionary, as expected by the decoder.
  """
  inputs = tf.convert_to_tensor(feature_map["inputs"])
  input_is_image = False

  def input_fn(problem_choice, x=inputs):  # pylint: disable=missing-docstring
    p_hparams = hparams.problems[problem_choice]
    # Add a third empty dimension.
    x = tf.expand_dims(x, axis=[2])
    x = tf.to_int32(x)
    return (tf.constant(p_hparams.input_space_id), tf.constant(
        p_hparams.target_space_id), x)

  input_space_id, target_space_id, x = input_fn_builder.cond_on_index(
      input_fn, feature_map["problem_choice"], len(hparams.problems) - 1)

  features = {}
  features["problem_choice"] = feature_map["problem_choice"]
  features["input_space_id"] = input_space_id
  features["target_space_id"] = target_space_id
  features["decode_length"] = (
      IMAGE_DECODE_LENGTH if input_is_image else tf.shape(x)[1] + 50)
  features["inputs"] = x
  return features
Example #3
  def testCondOnIndex(self):
    """Smoke tests of cond_on_index()."""

    z = tf.constant(1., dtype=tf.float32)
    def f(n):
      return {
          "a": z * n,
          "b": z * n * n
      }

    index = tf.placeholder(shape=[], dtype=tf.int32)
    out = input_fn_builder.cond_on_index(f, index, 3, 0)

    with self.test_session() as sess:
      # Check dispatching to the correct branch
      result = sess.run(out, feed_dict={
          index: 2
      })

      self.assertAllClose(result["a"], 2.)
      self.assertAllClose(result["b"], 4.)

      result = sess.run(out, feed_dict={
          index: 3
      })

      self.assertAllClose(result["a"], 3.)
      self.assertAllClose(result["b"], 9.)
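
The implementation of input_fn_builder.cond_on_index is not shown in these examples, and the argument order differs between the versions used above. A minimal sketch consistent with the test above (signature assumed here to be (fn, index_tensor, max_idx, cur_idx)) dispatches on the runtime index with a chain of tf.cond branches; tf.cond carries the dict structure returned by fn through each branch.

import tensorflow as tf

def cond_on_index(fn, index_tensor, max_idx, cur_idx=0):
  """Call fn(i) for the branch i selected at graph run time by index_tensor."""
  if cur_idx == max_idx:
    return fn(cur_idx)
  return tf.cond(
      tf.equal(index_tensor, cur_idx),
      lambda: fn(cur_idx),
      lambda: cond_on_index(fn, index_tensor, max_idx, cur_idx + 1))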
Example #4
def _interactive_input_tensor_to_features_dict(feature_map, hparams):
    """Convert the interactive input format (see above) to a dictionary.

  Args:
    feature_map: a dictionary with keys `problem_choice` and `inputs` containing
      Tensors.
    hparams: model hyperparameters

  Returns:
    a features dictionary, as expected by the decoder.
  """
    inputs = tf.convert_to_tensor(feature_map["inputs"])
    input_is_image = len(inputs.get_shape()) >= 3

    def input_fn(problem_choice, x=inputs):  # pylint: disable=missing-docstring
        if input_is_image:
            x = tf.image.resize_images(x, [299, 299])
            x = tf.reshape(x, [1, 299, 299, -1])
            x = tf.to_int32(x)
        else:
            # Remove the batch dimension.
            num_samples = x[0]
            length = x[2]
            x = tf.slice(x, [3], tf.to_int32([length]))
            x = tf.reshape(x, [1, -1, 1, 1])
            # Transform into a batch of size num_samples to get that many random
            # decodes.
            x = tf.tile(x, tf.to_int32([num_samples, 1, 1, 1]))

        p_hparams = hparams.problems[problem_choice]
        return (tf.constant(p_hparams.input_space_id),
                tf.constant(p_hparams.target_space_id), x)

    input_space_id, target_space_id, x = input_fn_builder.cond_on_index(
        input_fn, feature_map["problem_choice"],
        len(hparams.problems) - 1)

    features = {}
    features["problem_choice"] = tf.convert_to_tensor(
        feature_map["problem_choice"])
    features["input_space_id"] = input_space_id
    features["target_space_id"] = target_space_id
    features["decode_length"] = (IMAGE_DECODE_LENGTH
                                 if input_is_image else inputs[1])
    features["inputs"] = x
    return features
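
The non-image branch of input_fn above appears to assume the interactive input arrives as a flat int vector packed as [num_samples, decode_length, length, token_1, ..., token_length] (x[0] is num_samples, inputs[1] feeds decode_length, x[2] is the token count, and the payload starts at index 3). A minimal TensorFlow 1.x sketch of that unpacking, with hypothetical values:

import tensorflow as tf

packed = tf.constant([2, 20, 3, 7, 8, 9])  # num_samples=2, decode_length=20, length=3, tokens=[7, 8, 9]
num_samples, length = packed[0], packed[2]
tokens = tf.slice(packed, [3], tf.to_int32([length]))  # -> [7, 8, 9]
x = tf.reshape(tokens, [1, -1, 1, 1])                   # -> shape [1, 3, 1, 1]
x = tf.tile(x, tf.to_int32([num_samples, 1, 1, 1]))     # -> shape [2, 3, 1, 1]

with tf.Session() as sess:
  print(sess.run(tf.shape(x)))  # [2 3 1 1]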
Example #5
def _interactive_input_tensor_to_features_dict(feature_map, hparams):
  """Convert the interactive input format (see above) to a dictionary.

  Args:
    feature_map: a dictionary with keys `problem_choice` and `inputs` containing
      Tensors.
    hparams: model hyperparameters

  Returns:
    a features dictionary, as expected by the decoder.
  """
  inputs = tf.convert_to_tensor(feature_map["inputs"])
  input_is_image = len(inputs.get_shape()) >= 3

  def input_fn(problem_choice, x=inputs):  # pylint: disable=missing-docstring
    if input_is_image:
      x = tf.image.resize_images(x, [299, 299])
      x = tf.reshape(x, [1, 299, 299, -1])
      x = tf.to_int32(x)
    else:
      # Remove the batch dimension.
      num_samples = x[0]
      length = x[2]
      x = tf.slice(x, [3], tf.to_int32([length]))
      x = tf.reshape(x, [1, -1, 1, 1])
      # Transform into a batch of size num_samples to get that many random
      # decodes.
      x = tf.tile(x, tf.to_int32([num_samples, 1, 1, 1]))

    p_hparams = hparams.problems[problem_choice]
    return (tf.constant(p_hparams.input_space_id), tf.constant(
        p_hparams.target_space_id), x)

  input_space_id, target_space_id, x = input_fn_builder.cond_on_index(
      input_fn, feature_map["problem_choice"], len(hparams.problems) - 1)

  features = {}
  features["problem_choice"] = tf.convert_to_tensor(
      feature_map["problem_choice"])
  features["input_space_id"] = input_space_id
  features["target_space_id"] = target_space_id
  features["decode_length"] = (
      IMAGE_DECODE_LENGTH if input_is_image else inputs[1])
  features["inputs"] = x
  return features
Example #6
    def testCondOnIndex(self):
        """Smoke tests of cond_on_index()."""

        z = tf.constant(1., dtype=tf.float32)

        def f(n):
            return {"a": z * n, "b": z * n * n}

        index = tf.placeholder(shape=[], dtype=tf.int32)
        out = input_fn_builder.cond_on_index(f, index, 3, 0)

        with self.test_session() as sess:
            # Check dispatching to the correct branch
            result = sess.run(out, feed_dict={index: 2})

            self.assertAllClose(result["a"], 2.)
            self.assertAllClose(result["b"], 4.)

            result = sess.run(out, feed_dict={index: 3})

            self.assertAllClose(result["a"], 3.)
            self.assertAllClose(result["b"], 9.)
Example #7
def model_fn(model,
             features,
             mode,
             hparams,
             problem_names,
             train_steps=100000,
             worker_id=0,
             worker_replicas=1,
             eval_run_autoregressive=False,
             decode_hparams=None):
    """Builds the model for all modes.

  * TRAIN: Constructs loss and train_op
  * EVAL: Constructs the loss and eval metrics
  * PREDICT: Constructs the predictions

  Args:
    model: str, name of model.
    features: dict<feature name, Tensor>. Expected to have keys
      {inputs, targets, problem_choice}.
    mode: tf.estimator.ModeKeys.
    hparams: model HParams.
    problem_names: list of str, names of the problems.
    train_steps: int, total number of training steps. Used to compute learning
      rate decay.
    worker_id: int, id of this worker.
    worker_replicas: int, number of workers.
    eval_run_autoregressive: bool, whether to run evaluation autoregressively.
    decode_hparams: HParams for decode settings. Used when mode == PREDICT.

  Returns:
    tf.estimator.EstimatorSpec
  """
    assert len(problem_names) == len(hparams.problem_instances)
    decode_hp = decode_hparams

    # TODO(rsepassi): This still depends on FLAGS. Rm eventually.
    dp = devices.data_parallelism()

    tf.get_variable_scope().set_initializer(_get_variable_initializer(hparams))
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    # Add input statistics for incoming features.
    with tf.name_scope("input_stats"):
        for (k, v) in six.iteritems(features):
            if isinstance(v, tf.Tensor) and v.get_shape().ndims > 1:
                tf.summary.scalar("%s_batch" % k, tf.shape(v)[0] // dp.n)
                tf.summary.scalar("%s_length" % k, tf.shape(v)[1])
                nonpadding = tf.to_float(tf.not_equal(v, 0))
                nonpadding_tokens = tf.reduce_sum(nonpadding)
                if k == "targets":
                    targets_nonpadding_tokens = nonpadding_tokens
                tf.summary.scalar("%s_nonpadding_tokens" % k,
                                  nonpadding_tokens)
                tf.summary.scalar("%s_nonpadding_fraction" % k,
                                  tf.reduce_mean(nonpadding))

    # Get multi-problem logits and loss based on features["problem_choice"].
    loss_variable_names = []

    def nth_model(n):
        """Build the model for the n-th problem, plus some added variables."""
        model_class = registry.model(model)(
            hparams, mode, hparams.problems[n], n, dp,
            devices.ps_devices(all_workers=True))
        if mode == tf.estimator.ModeKeys.PREDICT:
            return model_class.infer(
                features,
                beam_size=decode_hp.beam_size,
                top_beams=(decode_hp.beam_size
                           if decode_hp.return_beams else 1),
                last_position_only=decode_hp.use_last_position_only,
                alpha=decode_hp.alpha,
                decode_length=decode_hp.extra_length)
        # In distributed mode, we build graph for problem=0 and problem=worker_id.
        skipping_is_on = hparams.problem_choice == "distributed" and is_training
        problem_worker_id = worker_id % len(hparams.problems)
        skip_this_one = n != 0 and n % worker_replicas != problem_worker_id
        # On worker 0 also build graph for problems <= 1.
        # TODO(lukaszkaiser): why is this hack needed for variables init? Repair.
        skip_this_one = skip_this_one and (worker_id != 0 or n > 1)
        if eval_run_autoregressive and mode == tf.estimator.ModeKeys.EVAL:
            sharded_logits, losses_dict = model_class.eval_autoregressive(
                features)
        else:
            sharded_logits, losses_dict = model_class.model_fn(
                features, skip=(skipping_is_on and skip_this_one))
        with tf.variable_scope("losses_avg"):
            total_loss, ops = 0.0, []
            for loss_key, loss_value in six.iteritems(losses_dict):
                loss_name = "problem_%d/%s_loss" % (n, loss_key)
                loss_moving_avg = tf.get_variable(loss_name,
                                                  initializer=100.0,
                                                  trainable=False)
                loss_variable_names.append(loss_name)
                ops.append(
                    loss_moving_avg.assign(loss_moving_avg * 0.9 +
                                           loss_value * 0.1))
                total_loss += loss_value
            try:  # Total loss avg might be reused or not, we try both.
                with tf.variable_scope(tf.get_variable_scope(), reuse=True):
                    # Total loss was already constructed on input.
                    loss_moving_avg = tf.get_variable("problem_%d/total_loss" %
                                                      n)
            except ValueError:
                loss_moving_avg = tf.get_variable("problem_%d/total_loss" % n,
                                                  initializer=100.0,
                                                  trainable=False)
            ops.append(
                loss_moving_avg.assign(loss_moving_avg * 0.9 +
                                       total_loss * 0.1))
        with tf.variable_scope("train_stats"):  # Count steps for this problem.
            problem_steps = tf.get_variable("problem_%d_steps" % n,
                                            initializer=0,
                                            trainable=False)
            ops.append(problem_steps.assign_add(1))
        with tf.control_dependencies(ops):  # Make sure the ops run.
            # Ensure the loss is a scalar here.
            total_loss = tf.reshape(total_loss, [],
                                    name="total_loss_control_id")
        return [total_loss, tf.concat(sharded_logits, 0)]

    model_output = input_fn_builder.cond_on_index(
        nth_model,
        index_tensor=features["problem_choice"],
        max_idx=len(hparams.problems) - 1)

    if mode == tf.estimator.ModeKeys.PREDICT:
        # If beam searching, model_output will be a dict with keys "outputs" and
        # "scores".
        if isinstance(model_output, dict):
            outputs = model_output["outputs"]
            scores = model_output["scores"]
        else:
            outputs = model_output
            scores = None

        batched_problem_choice = (features["problem_choice"] * tf.ones(
            (tf.shape(features["inputs"])[0], ), dtype=tf.int32))
        predictions = {
            "outputs": outputs,
            "scores": scores,
            "inputs": features.get("inputs", None),
            "targets": features.get("infer_targets", None),
            "problem_choice": batched_problem_choice,
        }
        _del_dict_nones(predictions)

        export_out = {"outputs": predictions["outputs"]}
        if "scores" in predictions:
            export_out["scores"] = predictions["scores"]

        return tf.estimator.EstimatorSpec(
            mode,
            predictions=predictions,
            export_outputs={
                "output": tf.estimator.export.PredictOutput(export_out)
            })

    total_loss, logits = model_output

    if mode == tf.estimator.ModeKeys.EVAL:
        eval_metrics_fns = metrics.create_evaluation_metrics(
            zip(problem_names, hparams.problem_instances), hparams)

        eval_metrics = {}
        for metric_name, metric_fn in six.iteritems(eval_metrics_fns):
            eval_metrics[metric_name] = metric_fn(logits, features)

        return tf.estimator.EstimatorSpec(mode,
                                          predictions={"predictions": logits},
                                          eval_metric_ops=eval_metrics,
                                          loss=total_loss)

    assert mode == tf.estimator.ModeKeys.TRAIN

    # Set learning rate
    learning_rate = hparams.learning_rate * learning_rate_decay(
        hparams,
        num_worker_replicas=worker_replicas,
        num_train_steps=train_steps)
    learning_rate /= math.sqrt(float(worker_replicas))

    # Get global step
    global_step = tf.train.get_or_create_global_step()

    # Some training statistics.
    with tf.name_scope("training_stats"):
        tf.summary.scalar("learning_rate", learning_rate)
        for n in xrange(len(hparams.problems)):
            names_and_vars = []
            with tf.variable_scope("losses_avg", reuse=True):
                total_loss_var = tf.get_variable("problem_%d/total_loss" % n)
                names_and_vars.append(("total_loss", total_loss_var))
            with tf.variable_scope("losses_avg", reuse=True):
                for loss_name in loss_variable_names:
                    if loss_name.startswith("problem_%d/" % n):
                        loss_var = tf.get_variable(loss_name)
                        loss_suffix = loss_name[loss_name.index("/") + 1:]
                        names_and_vars.append((loss_suffix, loss_var))
            for (loss_name, loss_var) in names_and_vars:
                tf.summary.scalar("loss_avg_%d/%s" % (n, loss_name), loss_var)
            with tf.variable_scope("train_stats", reuse=True):
                nth_steps = tf.get_variable("problem_%d_steps" % n,
                                            dtype=tf.int32)
            tf.summary.scalar(
                "problem_%d_frequency" % n,
                tf.to_float(nth_steps) / (tf.to_float(global_step) + 1.0))

    # Add weight decay and noise.
    total_size, weight_decay_loss = 0, 0.0
    all_weights = {v.name: v for v in tf.trainable_variables()}
    for v_name in sorted(list(all_weights)):
        v = all_weights[v_name]
        v_size = int(np.prod(np.array(v.shape.as_list())))
        total_size += v_size
        if hparams.weight_decay > 0.0 and len(v.shape.as_list()) > 1:
            # Add weight regularization if set and the weight is not a bias (dim>1).
            with tf.device(v._ref().device):  # pylint: disable=protected-access
                v_loss = tf.nn.l2_loss(v) / v_size
            weight_decay_loss += v_loss
        is_body = len(v_name) > 5 and v_name[:5] == "body/"
        if hparams.weight_noise > 0.0 and is_body:
            # Add weight noise if set in hparams.
            with tf.device(v._ref().device):  # pylint: disable=protected-access
                scale = learning_rate * 0.001
                noise = tf.truncated_normal(
                    v.shape) * hparams.weight_noise * scale
                noise_op = v.assign_add(noise)
            with tf.control_dependencies([noise_op]):
                total_loss = tf.identity(total_loss)
    if hparams.weight_decay > 0.0:
        total_loss += weight_decay_loss * hparams.weight_decay

    # The new data reader occasionally emits very small batches, which
    # cause the examples in those batches to be grossly overweighted.
    # We decrease the loss proportionally to the ratio of the size of this
    # batch to the size of the largest training batch ever.
    # TODO(noam): to be more sophisticated, we could keep separate
    # maxima based on problem choice.
    max_nonpadding_var = tf.get_variable("max_nonpadding",
                                         shape=[],
                                         initializer=tf.ones_initializer(),
                                         trainable=False)
    max_nonpadding = tf.maximum(max_nonpadding_var, targets_nonpadding_tokens)
    with tf.control_dependencies(
        [tf.assign(max_nonpadding_var, max_nonpadding)]):
        small_batch_multiplier = targets_nonpadding_tokens / max_nonpadding
    tf.summary.scalar("small_batch_multiplier", small_batch_multiplier)
    total_loss *= small_batch_multiplier

    # Log variable sizes
    _log_variable_sizes(tf.trainable_variables(), "Trainable Variables")
    diet_vars = [
        v for v in tf.global_variables() if v.dtype == dtypes.float16_ref
    ]
    _log_variable_sizes(diet_vars, "Diet Variables")

    # Optimize
    total_loss = tf.identity(total_loss, name="total_loss")
    opt = ConditionalOptimizer(hparams.optimizer, learning_rate, hparams)
    opt_summaries = ["learning_rate", "loss"]
    if hparams.summarize_grads:
        opt_summaries.extend(["gradients", "gradient_norm"])
    tf.logging.info("Computing gradients for global model_fn.")
    train_op = tf.contrib.layers.optimize_loss(
        name="training",
        loss=total_loss,
        global_step=global_step,
        learning_rate=learning_rate,
        clip_gradients=hparams.clip_grad_norm or None,
        gradient_noise_scale=hparams.grad_noise_scale or None,
        optimizer=opt,
        summaries=opt_summaries,
        colocate_gradients_with_ops=True)

    # Remove summaries that will fail to run because they are in conditionals.
    # TODO(cwhipkey): Test with this code removed, later in 2017.
    summaries = tf.get_collection_ref(tf.GraphKeys.SUMMARIES)
    for i in reversed(range(len(summaries))):
        if summaries[i].name.startswith("cond_"):
            del summaries[i]

    tf.logging.info("Global model_fn finished.")
    return tf.estimator.EstimatorSpec(
        mode,
        predictions={"problem_choice": features["problem_choice"]},
        loss=total_loss,
        train_op=train_op)
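
The per-problem loss bookkeeping inside nth_model above maintains an exponential moving average of each loss in a non-trainable variable. A standalone TensorFlow 1.x sketch of that update rule (decay 0.9 and initial value 100.0, as in the code above), with a hypothetical loss value:

import tensorflow as tf

loss_value = tf.constant(2.5)
loss_moving_avg = tf.get_variable(
    "example_loss_avg", initializer=100.0, trainable=False)
update_op = loss_moving_avg.assign(loss_moving_avg * 0.9 + loss_value * 0.1)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(update_op))  # 100.0 * 0.9 + 2.5 * 0.1 = 90.25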
Example #8
    def model_fn(features, labels, mode, params):
        """Creates the prediction, loss, and train ops.

    Args:
      features: A dictionary of tensors keyed by the feature name.
      labels: A tensor representing the labels.
      mode: The execution mode, as defined in tf.estimator.ModeKeys.
      params: model HParams.

    Returns:
      An EstimatorSpec.
    """
        hparams = params
        # Deep-copy the model hparams between modes to eliminate
        # side-effects caused by abuse of the linked problem_hparams
        # objects which are used to share modality objects between
        # problems.  We do not want to share the modality objects between
        # modes, since the modality objects may decide to do something
        # mode-specific.  A better fix would be to stop abusing the
        # hparams in this way and instead use a separate dictionary to
        # share the modality objects between problems.  This dictionary
        # could be created once per mode and passed to the constructor of
        # t2t_model.
        my_hp = copy.deepcopy(hparams)

        def initializer():
            if hparams.initializer == "orthogonal":
                return tf.orthogonal_initializer(gain=hparams.initializer_gain)
            elif hparams.initializer == "uniform":
                max_val = 0.1 * hparams.initializer_gain
                return tf.random_uniform_initializer(-max_val, max_val)
            elif hparams.initializer == "normal_unit_scaling":
                return init_ops.variance_scaling_initializer(
                    hparams.initializer_gain,
                    mode="fan_avg",
                    distribution="normal")
            elif hparams.initializer == "uniform_unit_scaling":
                return init_ops.variance_scaling_initializer(
                    hparams.initializer_gain,
                    mode="fan_avg",
                    distribution="uniform")
            else:
                raise ValueError("Unrecognized initializer: %s" %
                                 hparams.initializer)

        def learning_rate_decay():
            """Inverse-decay learning rate until warmup_steps, then decay."""
            warmup_steps = tf.to_float(hparams.learning_rate_warmup_steps *
                                       FLAGS.worker_replicas)
            step = tf.to_float(tf.contrib.framework.get_global_step())
            if hparams.learning_rate_decay_scheme == "noam":
                return 5000.0 * hparams.hidden_size**-0.5 * tf.minimum(
                    (step + 1) * warmup_steps**-1.5, (step + 1)**-0.5)
            elif hparams.learning_rate_decay_scheme == "exp100k":
                return 0.94**(step // 100000)
            elif hparams.learning_rate_decay_scheme == "cosine":
                cycle_steps = hparams.learning_rate_cosine_cycle_steps
                return 0.5 * (1 + tf.cos(np.pi *
                                         (step % cycle_steps) / cycle_steps))
            elif hparams.learning_rate_decay_scheme == "cyclelinear10x":
                # Cycle the rate linearly by 10x every warmup_steps, up and down.
                cycle_steps = hparams.learning_rate_warmup_steps
                cycle_position = step % (2 * cycle_steps)
                cycle_position = tf.to_float(  # Normalize to the interval [-1, 1].
                    cycle_position - cycle_steps) / float(cycle_steps)
                cycle_position = 1.0 - tf.abs(
                    cycle_position)  # 0 to 1 and back to 0.
                return (cycle_position +
                        0.1) * 3.0  # 10x difference each cycle (0.3-3).

            inv_base = tf.exp(tf.log(0.01) / warmup_steps)
            inv_decay = inv_base**(warmup_steps - step)
            if hparams.learning_rate_decay_scheme == "sqrt":
                decay = _sqrt_decay(step - warmup_steps)
            elif hparams.learning_rate_decay_scheme == "exp10k":
                decay = _exp_decay_after(
                    step - warmup_steps, 0.9995,
                    FLAGS.train_steps - warmup_steps - 10000)
            elif hparams.learning_rate_decay_scheme == "exp50k":
                decay = _exp_decay_after(
                    step - warmup_steps, 0.99995,
                    FLAGS.train_steps - warmup_steps - 50000)
            elif hparams.learning_rate_decay_scheme == "exp500k":
                decay = _exp_decay_after(
                    step - warmup_steps, 0.9999955,
                    FLAGS.train_steps - warmup_steps - 500000)
            elif hparams.learning_rate_decay_scheme == "none":
                decay = tf.constant(1.0)
            else:
                raise ValueError(
                    "Unrecognized learning rate decay scheme: %s" %
                    hparams.learning_rate_decay_scheme)
            return tf.cond(step < warmup_steps,
                           lambda: inv_decay,
                           lambda: decay,
                           name="learning_rate_decay_warump_cond")

        if labels is not None:
            features["targets"] = labels

        dp = devices.data_parallelism()

        tf.get_variable_scope().set_initializer(initializer())
        is_training = mode == tf.estimator.ModeKeys.TRAIN

        # Add input statistics for incoming features.
        with tf.name_scope("input_stats"):
            for (k, v) in six.iteritems(features):
                if isinstance(v, tf.Tensor) and v.get_shape().ndims > 1:
                    tf.summary.scalar("%s_batch" % k, tf.shape(v)[0] // dp.n)
                    tf.summary.scalar("%s_length" % k, tf.shape(v)[1])
                    nonpadding = tf.to_float(tf.not_equal(v, 0))
                    nonpadding_tokens = tf.reduce_sum(nonpadding)
                    if k == "targets":
                        targets_nonpadding_tokens = nonpadding_tokens
                    tf.summary.scalar("%s_nonpadding_tokens" % k,
                                      nonpadding_tokens)
                    tf.summary.scalar("%s_nonpadding_fraction" % k,
                                      tf.reduce_mean(nonpadding))

            if is_training:
                # The new data reader occasionally emits very small batches, which
                # cause the examples in those batches to be grossly overweighted.
                # We decrease the loss proportionally to the ratio of the size of this
                # batch to the size of the largest training batch ever.
                # TODO(noam): to be more sophisticated, we could keep separate
                # maxima based on problem choice.
                max_nonpadding_var = tf.get_variable(
                    "max_nonpadding",
                    shape=[],
                    initializer=tf.ones_initializer(),
                    trainable=False)
                max_nonpadding = tf.maximum(max_nonpadding_var,
                                            targets_nonpadding_tokens)
                with tf.control_dependencies(
                    [tf.assign(max_nonpadding_var, max_nonpadding)]):
                    small_batch_multiplier = targets_nonpadding_tokens / max_nonpadding
                tf.summary.scalar("small_batch_multiplier",
                                  small_batch_multiplier)

        # Get multi-problem logits and loss based on features["problem_choice"].
        loss_variable_names = []

        def nth_model(n):
            """Build the model for the n-th problem, plus some added variables."""
            model_class = registry.model(model)(
                my_hp, mode, my_hp.problems[n], n, dp,
                devices.ps_devices(all_workers=True))
            if mode == tf.estimator.ModeKeys.PREDICT:
                return model_class.infer(
                    features,
                    beam_size=FLAGS.decode_beam_size,
                    top_beams=(FLAGS.decode_beam_size
                               if FLAGS.decode_return_beams else 1),
                    last_position_only=FLAGS.decode_use_last_position_only,
                    alpha=FLAGS.decode_alpha,
                    decode_length=FLAGS.decode_extra_length)
            # In distributed mode, we build graph for problem=0 and problem=worker_id.
            skipping_is_on = my_hp.problem_choice == "distributed" and is_training
            problem_worker_id = FLAGS.worker_id % len(my_hp.problems)
            skip_this_one = n != 0 and n % FLAGS.worker_replicas != problem_worker_id
            # On worker 0 also build graph for problems <= 1.
            # TODO(lukaszkaiser): why is this hack needed for variables init? Repair.
            skip_this_one = skip_this_one and (FLAGS.worker_id != 0 or n > 1)
            if (FLAGS.eval_run_autoregressive
                    and mode == tf.estimator.ModeKeys.EVAL):
                sharded_logits, losses_dict = model_class.eval_autoregressive(
                    features)
            else:
                sharded_logits, losses_dict = model_class.model_fn(
                    features, skip=(skipping_is_on and skip_this_one))
            with tf.variable_scope("losses_avg"):
                total_loss, ops = 0.0, []
                for loss_key, loss_value in six.iteritems(losses_dict):
                    loss_name = "problem_%d/%s_loss" % (n, loss_key)
                    loss_moving_avg = tf.get_variable(loss_name,
                                                      initializer=100.0,
                                                      trainable=False)
                    loss_variable_names.append(loss_name)
                    ops.append(
                        loss_moving_avg.assign(loss_moving_avg * 0.9 +
                                               loss_value * 0.1))
                    total_loss += loss_value
                try:  # Total loss avg might be reused or not, we try both.
                    with tf.variable_scope(tf.get_variable_scope(),
                                           reuse=True):
                        # Total loss was already constructed on input.
                        loss_moving_avg = tf.get_variable(
                            "problem_%d/total_loss" % n)
                except ValueError:
                    loss_moving_avg = tf.get_variable("problem_%d/total_loss" %
                                                      n,
                                                      initializer=100.0,
                                                      trainable=False)
                ops.append(
                    loss_moving_avg.assign(loss_moving_avg * 0.9 +
                                           total_loss * 0.1))
            with tf.variable_scope(
                    "train_stats"):  # Count steps for this problem.
                problem_steps = tf.get_variable("problem_%d_steps" % n,
                                                initializer=0,
                                                trainable=False)
                ops.append(problem_steps.assign_add(1))
            with tf.control_dependencies(ops):  # Make sure the ops run.
                # Ensure the loss is a scalar here.
                total_loss = tf.reshape(total_loss, [],
                                        name="total_loss_control_id")
            return [total_loss
                    ] + sharded_logits  # Need to flatten for cond later.

        result_list = input_fn_builder.cond_on_index(
            nth_model, features["problem_choice"], 0,
            len(my_hp.problems) - 1)

        if mode == tf.estimator.ModeKeys.PREDICT:
            # Beam search in a sequence model returns both decodes with the key
            # "outputs" and scores with the key "scores". If the returned value is a
            # dict, we expect it to have the key "outputs" (an int32 tensor) and
            # "scores" (a float tensor). This is useful if we want to return scores
            # from estimator.predict.
            if not isinstance(result_list, dict):
                predictions = {"outputs": result_list}
            else:
                predictions = {
                    "outputs": result_list["outputs"],
                    "scores": result_list["scores"]
                }

            if "inputs" in features:
                predictions["inputs"] = features["inputs"]
            if "infer_targets" in features:
                predictions["targets"] = features["infer_targets"]
            predictions["problem_choice"] = (
                features["problem_choice"] * tf.ones(
                    (tf.shape(features["inputs"])[0], ), dtype=tf.int32))

            return tf.estimator.EstimatorSpec(mode, predictions=predictions)

        sharded_logits, total_loss = result_list[1:], result_list[0]
        if mode == tf.estimator.ModeKeys.EVAL:
            # For evaluation, return the logits layer as our predictions.
            logits = tf.concat(sharded_logits, 0)

            eval_metrics_fns = metrics.create_evaluation_metrics(
                zip(FLAGS.problems.split("-"), hparams.problem_instances),
                hparams)
            _check_autotune_metrics(eval_metrics_fns)

            eval_metrics = {}
            for metric_name, metric_fn in six.iteritems(eval_metrics_fns):
                eval_metrics[metric_name] = metric_fn(
                    logits, labels, features["problem_choice"])

            return tf.estimator.EstimatorSpec(
                mode,
                predictions={"predictions": logits},
                eval_metric_ops=eval_metrics,
                loss=total_loss)

        assert mode == tf.estimator.ModeKeys.TRAIN

        # Some training statistics.
        with tf.name_scope("training_stats"):
            learning_rate = my_hp.learning_rate * learning_rate_decay()
            learning_rate /= math.sqrt(float(FLAGS.worker_replicas))
            tf.summary.scalar("learning_rate", learning_rate)
            global_step = tf.to_float(tf.contrib.framework.get_global_step())
            for n in xrange(len(my_hp.problems)):
                names_and_vars = []
                with tf.variable_scope("losses_avg", reuse=True):
                    total_loss_var = tf.get_variable("problem_%d/total_loss" %
                                                     n)
                    names_and_vars.append(("total_loss", total_loss_var))
                with tf.variable_scope("losses_avg", reuse=True):
                    for loss_name in loss_variable_names:
                        if loss_name.startswith("problem_%d/" % n):
                            loss_var = tf.get_variable(loss_name)
                            loss_suffix = loss_name[loss_name.index("/") + 1:]
                            names_and_vars.append((loss_suffix, loss_var))
                for (loss_name, loss_var) in names_and_vars:
                    tf.summary.scalar("loss_avg_%d/%s" % (n, loss_name),
                                      loss_var)
                with tf.variable_scope("train_stats", reuse=True):
                    nth_steps = tf.get_variable("problem_%d_steps" % n,
                                                dtype=tf.int32)
                tf.summary.scalar("problem_%d_frequency" % n,
                                  tf.to_float(nth_steps) / (global_step + 1.0))

        # Log trainable weights and add decay.
        total_size, weight_decay_loss = 0, 0.0
        all_weights = {v.name: v for v in tf.trainable_variables()}
        for v_name in sorted(list(all_weights)):
            v = all_weights[v_name]
            v_size = int(np.prod(np.array(v.shape.as_list())))
            total_size += v_size
            if my_hp.weight_decay > 0.0 and len(v.shape.as_list()) > 1:
                # Add weight regularization if set and the weight is not a bias (dim>1).
                with tf.device(v._ref().device):  # pylint: disable=protected-access
                    v_loss = tf.nn.l2_loss(v) / v_size
                weight_decay_loss += v_loss
            is_body = len(v_name) > 5 and v_name[:5] == "body/"
            if my_hp.weight_noise > 0.0 and is_body:
                # Add weight noise if set in my_hp.
                with tf.device(v._ref().device):  # pylint: disable=protected-access
                    scale = learning_rate * 0.001
                    noise = tf.truncated_normal(
                        v.shape) * my_hp.weight_noise * scale
                    noise_op = v.assign_add(noise)
                with tf.control_dependencies([noise_op]):
                    total_loss = tf.identity(total_loss)
        if my_hp.weight_decay > 0.0:
            total_loss += weight_decay_loss * my_hp.weight_decay
        if is_training:
            total_loss *= small_batch_multiplier
        total_loss = tf.identity(total_loss, name="total_loss")
        log_variable_sizes(tf.trainable_variables(), "Trainable Variables")
        diet_vars = [
            v for v in tf.global_variables() if v.dtype == dtypes.float16_ref
        ]
        log_variable_sizes(diet_vars, "Diet Variables")
        # Define the train_op for the TRAIN mode.
        opt = _ConditionalOptimizer(my_hp.optimizer, learning_rate, my_hp)
        tf.logging.info("Computing gradients for global model_fn.")
        opt_summaries = ["learning_rate", "loss"]
        if hparams.summarize_grads:
            opt_summaries.extend(["gradients", "gradient_norm"])
        train_op = tf.contrib.layers.optimize_loss(
            name="training",
            loss=total_loss,
            global_step=tf.train.get_global_step(),
            learning_rate=learning_rate,
            clip_gradients=my_hp.clip_grad_norm or None,
            gradient_noise_scale=hparams.grad_noise_scale or None,
            optimizer=opt,
            summaries=opt_summaries,
            colocate_gradients_with_ops=True)

        # Remove summaries that will fail to run because they are in conditionals.
        # TODO(cwhipkey): Test with this code removed, later in 2017.
        summaries = tf.get_collection_ref(tf.GraphKeys.SUMMARIES)
        for i in range(len(summaries) - 1, -1, -1):
            if summaries[i].name.startswith("cond_"):
                del summaries[i]

        tf.logging.info("Global model_fn finished.")
        return tf.estimator.EstimatorSpec(
            mode,
            predictions={"problem_choice": features["problem_choice"]},
            loss=total_loss,
            train_op=train_op)
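
The small-batch correction in the training branches above tracks the largest number of non-padding target tokens seen so far in a non-trainable variable and scales the loss by the ratio of the current batch to that maximum. A standalone TensorFlow 1.x sketch of that mechanism, with hypothetical token counts:

import tensorflow as tf

targets_nonpadding_tokens = tf.placeholder(tf.float32, shape=[])
max_nonpadding_var = tf.get_variable(
    "example_max_nonpadding", shape=[],
    initializer=tf.ones_initializer(), trainable=False)
max_nonpadding = tf.maximum(max_nonpadding_var, targets_nonpadding_tokens)
with tf.control_dependencies([tf.assign(max_nonpadding_var, max_nonpadding)]):
  small_batch_multiplier = targets_nonpadding_tokens / max_nonpadding

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(small_batch_multiplier,
                 {targets_nonpadding_tokens: 64.}))  # 1.0: largest batch so far
  print(sess.run(small_batch_multiplier,
                 {targets_nonpadding_tokens: 16.}))  # 0.25: small batch is down-weighted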
Example #9
  def model_fn(features, targets, mode):
    """Creates the prediction, loss, and train ops.

    Args:
      features: A dictionary of tensors keyed by the feature name.
      targets: A tensor representing the labels (targets).
      mode: The execution mode, as defined in tf.contrib.learn.ModeKeys.

    Returns:
      A tuple consisting of the prediction, loss, and train_op.
    """
    # Deep-copy the model hparams between modes to eliminate
    # side-effects caused by abuse of the linked problem_hparams
    # objects which are used to share modality objects between
    # problems.  We do not want to share the modality objects between
    # modes, since the modality objects may decide to do something
    # mode-specific.  A better fix would be to stop abusing the
    # hparams in this way and instead use a separate dictionary to
    # share the modality objects between problems.  This dictionary
    # could be created once per mode and passed to the constructor of
    # t2t_model.
    my_hp = copy.deepcopy(hparams)
    if mode == tf.contrib.learn.ModeKeys.INFER:
      if FLAGS.decode_interactive:
        features = _interactive_input_tensor_to_features_dict(features, my_hp)
      elif FLAGS.decode_from_file:
        features = _decode_input_tensor_to_features_dict(features, my_hp)
    # A dictionary containing:
    #  - problem_choice: A Tensor containing an integer indicating which problem
    #                    was selected for this run.
    #  - predictions: A Tensor containing the model's output predictions.
    run_info = dict()
    run_info["problem_choice"] = features["problem_choice"]

    if targets is not None:
      features["targets"] = targets

    dp = devices.data_parallelism()

    tf.get_variable_scope().set_initializer(initializer())
    is_training = mode == tf.contrib.learn.ModeKeys.TRAIN

    # Add input statistics for incoming features.
    with tf.name_scope("input_stats"):
      for (k, v) in six.iteritems(features):
        if isinstance(v, tf.Tensor) and v.get_shape().ndims > 1:
          tf.summary.scalar("%s_batch" % k, tf.shape(v)[0] // dp.n)
          tf.summary.scalar("%s_length" % k, tf.shape(v)[1])
          nonpadding = tf.to_float(tf.not_equal(v, 0))
          nonpadding_tokens = tf.reduce_sum(nonpadding)
          if k == "targets":
            targets_nonpadding_tokens = nonpadding_tokens
          tf.summary.scalar("%s_nonpadding_tokens" % k, nonpadding_tokens)
          tf.summary.scalar("%s_nonpadding_fraction" % k,
                            tf.reduce_mean(nonpadding))

      # The new data reader occasionally emits very small batches, which
      # cause the examples in those batches to be grossly overweighted.
      # We decrease the loss proportionally to the ratio of the size of this
      # batch to the size of the largest training batch ever.
      # TODO(noam): to be more sophisticated, we could keep separate
      # maxima based on problem choice.
      max_nonpadding_var = tf.get_variable(
          "max_nonpadding", shape=[],
          initializer=tf.ones_initializer(), trainable=False)
      max_nonpadding = tf.maximum(max_nonpadding_var, targets_nonpadding_tokens)
      if is_training:
        with tf.control_dependencies(
            [tf.assign(max_nonpadding_var, max_nonpadding)]):
          small_batch_multiplier = targets_nonpadding_tokens / max_nonpadding
        tf.summary.scalar("small_batch_multiplier", small_batch_multiplier)

    # Get multi-problem logits and loss based on features["problem_choice"].
    loss_variable_names = []
    def nth_model(n):
      """Build the model for the n-th problem, plus some added variables."""
      model_class = registry.model(model)(
          my_hp,
          mode,
          my_hp.problems[n],
          n,
          dp,
          devices.ps_devices(all_workers=True))
      if mode == tf.contrib.learn.ModeKeys.INFER:
        return model_class.infer(
            features,
            beam_size=FLAGS.decode_beam_size,
            top_beams=(FLAGS.decode_beam_size
                       if FLAGS.decode_return_beams else 1),
            last_position_only=FLAGS.decode_use_last_position_only,
            alpha=FLAGS.decode_alpha,
            decode_length=FLAGS.decode_extra_length)
      # In distributed mode, we build graph for problem=0 and problem=worker_id.
      skipping_is_on = my_hp.problem_choice == "distributed" and is_training
      problem_worker_id = FLAGS.worker_id % len(my_hp.problems)
      skip_this_one = n != 0 and n % FLAGS.worker_replicas != problem_worker_id
      # On worker 0 also build graph for problems <= 1.
      # TODO(lukaszkaiser): why is this hack needed for variables init? Repair.
      skip_this_one = skip_this_one and (FLAGS.worker_id != 0 or n > 1)
      if (FLAGS.eval_run_autoregressive and
          mode == tf.contrib.learn.ModeKeys.EVAL):
        sharded_logits, losses_dict = model_class.eval_autoregressive(features)
      else:
        sharded_logits, losses_dict = model_class.model_fn(
            features, skip=(skipping_is_on and skip_this_one))
      with tf.variable_scope("losses_avg"):
        total_loss, ops = 0.0, []
        for loss_key, loss_value in six.iteritems(losses_dict):
          loss_name = "problem_%d/%s_loss" % (n, loss_key)
          loss_moving_avg = tf.get_variable(
              loss_name, initializer=100.0, trainable=False)
          loss_variable_names.append(loss_name)
          ops.append(
              loss_moving_avg.assign(loss_moving_avg * 0.9 + loss_value * 0.1))
          total_loss += loss_value
        try:  # Total loss avg might be reused or not, we try both.
          with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            # Total loss was already constructed on input.
            loss_moving_avg = tf.get_variable("problem_%d/total_loss" % n)
        except ValueError:
          loss_moving_avg = tf.get_variable("problem_%d/total_loss" % n,
                                            initializer=100.0, trainable=False)
        ops.append(
            loss_moving_avg.assign(loss_moving_avg * 0.9 + total_loss * 0.1))
      with tf.variable_scope("train_stats"):  # Count steps for this problem.
        problem_steps = tf.get_variable(
            "problem_%d_steps" % n, initializer=0, trainable=False)
        ops.append(problem_steps.assign_add(1))
      with tf.control_dependencies(ops):  # Make sure the ops run.
        # Ensure the loss is a scalar here.
        total_loss = tf.reshape(total_loss, [], name="total_loss_control_id")
      return [total_loss] + sharded_logits  # Need to flatten for cond later.

    result_list = input_fn_builder.cond_on_index(nth_model,
                                                 features["problem_choice"], 0,
                                                 len(my_hp.problems) - 1)

    if mode == tf.contrib.learn.ModeKeys.INFER:
      # Beam search in a sequence model returns both decodes with the key
      # "outputs" and scores with the key "scores". If the returned value is a
      # dict, we expect it to have the key "outputs" (an int32 tensor) and
      # "scores" (a float tensor). This is useful if we want to return scores
      # from estimator.predict.
      if not isinstance(result_list, dict):
        ret = {"outputs": result_list}, None, None
      else:
        ret = {
            "outputs": result_list["outputs"],
            "scores": result_list["scores"]
        }, None, None
      if "inputs" in features:
        ret[0]["inputs"] = features["inputs"]
      if "infer_targets" in features:
        ret[0]["targets"] = features["infer_targets"]
      return ret

    sharded_logits, total_loss = result_list[1:], result_list[0]
    if mode == tf.contrib.learn.ModeKeys.EVAL:
      logits = tf.concat(sharded_logits, 0)
      # For evaluation, return the logits layer as our predictions.
      run_info["predictions"] = logits
      train_op = None
      return run_info, total_loss, None

    assert mode == tf.contrib.learn.ModeKeys.TRAIN

    # Some training statistics.
    with tf.name_scope("training_stats"):
      learning_rate = my_hp.learning_rate * learning_rate_decay()
      learning_rate /= math.sqrt(float(FLAGS.worker_replicas))
      tf.summary.scalar("learning_rate", learning_rate)
      global_step = tf.to_float(tf.contrib.framework.get_global_step())
      for n in xrange(len(my_hp.problems)):
        names_and_vars = []
        with tf.variable_scope("losses_avg", reuse=True):
          total_loss_var = tf.get_variable("problem_%d/total_loss" % n)
          names_and_vars.append(("total_loss", total_loss_var))
        with tf.variable_scope("losses_avg", reuse=True):
          for loss_name in loss_variable_names:
            if loss_name.startswith("problem_%d/" % n):
              loss_var = tf.get_variable(loss_name)
              loss_suffix = loss_name[loss_name.index("/") + 1:]
              names_and_vars.append((loss_suffix, loss_var))
        for (loss_name, loss_var) in names_and_vars:
          tf.summary.scalar("loss_avg_%d/%s" % (n, loss_name), loss_var)
        with tf.variable_scope("train_stats", reuse=True):
          nth_steps = tf.get_variable("problem_%d_steps" % n, dtype=tf.int32)
        tf.summary.scalar("problem_%d_frequency" % n,
                          tf.to_float(nth_steps) / (global_step + 1.0))

    # Log trainable weights and add decay.
    total_size, weight_decay_loss = 0, 0.0
    all_weights = {v.name: v for v in tf.trainable_variables()}
    for v_name in sorted(list(all_weights)):
      v = all_weights[v_name]
      v_size = int(np.prod(np.array(v.shape.as_list())))
      total_size += v_size
      if my_hp.weight_decay > 0.0 and len(v.shape.as_list()) > 1:
        # Add weight regularization if set and the weight is not a bias (dim>1).
        with tf.device(v._ref().device):  # pylint: disable=protected-access
          v_loss = tf.nn.l2_loss(v) / v_size
        weight_decay_loss += v_loss
      is_body = len(v_name) > 5 and v_name[:5] == "body/"
      if my_hp.weight_noise > 0.0 and is_body:
        # Add weight noise if set in my_hp.
        with tf.device(v._ref().device):  # pylint: disable=protected-access
          scale = learning_rate * 0.001
          noise = tf.truncated_normal(v.shape) * my_hp.weight_noise * scale
          noise_op = v.assign_add(noise)
        with tf.control_dependencies([noise_op]):
          total_loss = tf.identity(total_loss)
    if my_hp.weight_decay > 0.0:
      total_loss += weight_decay_loss * my_hp.weight_decay
    if is_training:
      total_loss *= small_batch_multiplier
    total_loss = tf.identity(total_loss, name="total_loss")
    log_variable_sizes(tf.trainable_variables(), "Trainable Variables")
    diet_vars = [v for v in tf.global_variables() if hasattr(v, "optimizer")]
    log_variable_sizes(diet_vars, "Diet Variables")
    # Define the train_op for the TRAIN mode.
    opt = _ConditionalOptimizer(my_hp.optimizer, learning_rate, my_hp)
    tf.logging.info("Computing gradients for global model_fn.")
    opt_summaries = ["learning_rate", "loss"]
    if hparams.summarize_grads:
      opt_summaries.extend(["gradients", "gradient_norm"])
    train_op = tf.contrib.layers.optimize_loss(
        name="training",
        loss=total_loss,
        global_step=tf.train.get_global_step(),
        learning_rate=learning_rate,
        clip_gradients=my_hp.clip_grad_norm or None,
        gradient_noise_scale=hparams.grad_noise_scale or None,
        optimizer=opt,
        summaries=opt_summaries,
        colocate_gradients_with_ops=True)

    # Remove summaries that will fail to run because they are in conditionals.
    # TODO(cwhipkey): Test with this code removed, later in 2017.
    summaries = tf.get_collection_ref(tf.GraphKeys.SUMMARIES)
    for i in range(len(summaries) - 1, -1, -1):
      if summaries[i].name.startswith("cond_"):
        del summaries[i]

    tf.logging.info("Global model_fn finished.")
    return run_info, total_loss, train_op
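
The "noam" branch of learning_rate_decay above combines an inverse-square-root warmup with a step**-0.5 decay, scaled by 5000 * hidden_size**-0.5. A standalone TensorFlow 1.x sketch of just that schedule (the hidden_size and warmup_steps values here are arbitrary examples):

import tensorflow as tf

def noam_decay(step, hidden_size=512, warmup_steps=4000.0):
  step = tf.to_float(step)
  return 5000.0 * hidden_size**-0.5 * tf.minimum(
      (step + 1) * warmup_steps**-1.5, (step + 1)**-0.5)

with tf.Session() as sess:
  print(sess.run(noam_decay(0)))     # small rate at the start of warmup
  print(sess.run(noam_decay(4000)))  # roughly the peak; decays as step**-0.5 afterwards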
Example #10
def model_fn(model,
             features,
             mode,
             hparams,
             problem_names,
             train_steps=100000,
             worker_id=0,
             worker_replicas=1,
             eval_run_autoregressive=False,
             decode_hparams=None):
    """Builds the model for all modes.

  * TRAIN: Constructs loss and train_op
  * EVAL: Constructs the loss and eval metrics
  * PREDICT: Constructs the predictions

  Args:
    model: str, name of model.
    features: dict<feature name, Tensor>. Expected to have keys
      {inputs, targets, problem_choice}.
    mode: tf.estimator.ModeKeys.
    hparams: model HParams.
    problem_names: list of str, names of the problems.
    train_steps: int, total number of training steps. Used to compute learning
      rate decay.
    worker_id: int, id of this worker.
    worker_replicas: int, number of workers.
    eval_run_autoregressive: bool, whether to run evaluation autoregressively.
    decode_hparams: HParams for decode settings. Used when mode == PREDICT.

  Returns:
    tf.estimator.EstimatorSpec
  """
    assert len(problem_names) == len(hparams.problem_instances)
    decode_hp = decode_hparams

    # TODO(rsepassi): This still depends on FLAGS. Rm eventually.
    dp = devices.data_parallelism()

    tf.get_variable_scope().set_initializer(_get_variable_initializer(hparams))
    # Set the variable initializer function.
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    # Add input statistics for incoming features.
    with tf.name_scope("input_stats"):
        for (k, v) in six.iteritems(features):
            if isinstance(v, tf.Tensor) and v.get_shape().ndims > 1:
                tf.summary.scalar("%s_batch" % k, tf.shape(v)[0] // dp.n)
                tf.summary.scalar("%s_length" % k, tf.shape(v)[1])
                nonpadding = tf.to_float(tf.not_equal(v, 0))
                nonpadding_tokens = tf.reduce_sum(
                    nonpadding)  # number of non-zero tokens
                if k == "targets":
                    targets_nonpadding_tokens = nonpadding_tokens
                tf.summary.scalar("%s_nonpadding_tokens" % k,
                                  nonpadding_tokens)
                tf.summary.scalar("%s_nonpadding_fraction" % k,
                                  tf.reduce_mean(nonpadding))

    # Get multi-problem logits and loss based on features["problem_choice"].
    loss_variable_names = []

    def nth_model(n):
        """Build the model for the n-th problem, plus some added variables."""
        model_class = registry.model(model)(
            hparams,
            mode,
            hparams.problems[n],
            n,
            dp,
            devices.ps_devices(all_workers=True),
            decode_hparams=decode_hparams
        )  # initialize transformer model class: hparams, modalities
        if mode == tf.estimator.ModeKeys.PREDICT:
            return model_class.infer(features,
                                     beam_size=decode_hp.beam_size,
                                     top_beams=(decode_hp.beam_size if
                                                decode_hp.return_beams else 1),
                                     alpha=decode_hp.alpha,
                                     decode_length=decode_hp.extra_length)
        # In distributed mode, we build graph for problem=0 and problem=worker_id.
        skipping_is_on = hparams.problem_choice == "distributed" and is_training
        problem_worker_id = worker_id % len(hparams.problems)
        skip_this_one = n != 0 and n % worker_replicas != problem_worker_id
        # On worker 0 also build graph for problems <= 1.
        # TODO(lukaszkaiser): why is this hack needed for variables init? Repair.
        skip_this_one = skip_this_one and (worker_id != 0 or n > 1)
        mrt_samples = getattr(hparams, 'mrt_samples', None)
        if eval_run_autoregressive and mode == tf.estimator.ModeKeys.EVAL:  # evaluation mode
            sharded_logits, losses_dict = model_class.eval_autoregressive(
                features)
        else:  # training mode
            if hparams.rl:
                # Generate sample data; it will be sharded automatically. Samples have shape [batch, time, 1, 1].
                if model_class._num_datashards == 1:  # work on single GPU cards, fast sample
                    print("###Work on Single GPU card, Use Fast Decode.###")
                    train_beam = getattr(hparams, 'train_beam', None)
                    if mrt_samples:
                        samples, _ = model_class._fast_decode(
                            features,
                            decode_length=50,
                            beam_size=mrt_samples,
                            top_beams=mrt_samples)
                        inputs = tf.squeeze(tf.squeeze(features["inputs"],
                                                       axis=-1),
                                            axis=-1)
                        targets = tf.squeeze(tf.squeeze(features["targets"],
                                                        axis=-1),
                                             axis=-1)
                        batch_size = tf.shape(inputs)[0]
                        inputs_len = tf.shape(inputs)[1]
                        targets_len = tf.shape(targets)[1]
                        inputs_tile = tf.tile(inputs, [1, mrt_samples])
                        targets_tile = tf.tile(targets, [1, mrt_samples])
                        inputs_reshape = tf.reshape(
                            inputs_tile,
                            [batch_size * mrt_samples, inputs_len])
                        targets_reshape = tf.reshape(
                            targets_tile,
                            [batch_size * mrt_samples, targets_len])
                        inputs_feed = tf.expand_dims(tf.expand_dims(
                            inputs_reshape, axis=-1),
                                                     axis=-1)
                        targets_feed = tf.expand_dims(tf.expand_dims(
                            targets_reshape, axis=-1),
                                                      axis=-1)
                        features["inputs"] = inputs_feed
                        features["targets"] = targets_feed
                    elif train_beam and train_beam != 1:  # beam search with beam size hparams.train_beam, returning the top-1 sample
                        samples, _ = model_class._fast_decode(
                            features,
                            decode_length=50,
                            beam_size=hparams.train_beam)
                    else:
                        targets_beam = getattr(hparams, 'targets_beam', None)
                        if targets_beam:
                            targets_samples, _ = model_class._fast_decode(
                                features,
                                decode_length=50,
                                beam_size=4,
                                sampling_method='argmax')
                            targets_samples = tf.reshape(
                                targets_samples, [
                                    tf.shape(targets_samples)[0],
                                    tf.shape(targets_samples)[1], 1, 1
                                ])
                            features["targets"] = targets_samples
                        samples, _ = model_class._fast_decode(features,
                                                              decode_length=50)
                    samples = tf.expand_dims(samples, axis=-1)
                    samples = tf.expand_dims(
                        samples, axis=-1
                    )  # add two trailing dimensions so samples match the 4-D modality layout
                else:  # running on multiple GPU cards, only slow sampling is supported
                    print("###Work on Multi GPU cards, Use Slow Decode.###")
                    samples, _, _ = model_class._slow_greedy_infer(
                        features,
                        decode_length=50)  # default decode_length = 50
                samples = tf.stop_gradient(samples)
                # calculate the BLEU score using metric_fn
                # train_metric_fn = "approx_bleu_train_score"
                train_metric_fn = metrics.METRICS_FNS[
                    metrics.Metrics.APPROX_BLEU_TRAIN]
                labels = features.get("targets", None)
                samples.set_shape([None, None, 1, 1])
                # hparams.delta_reward = True for delta reward; False for total reward
                metric_value = train_metric_fn(
                    samples, labels, delat_reward=hparams.delta_reward)
                metric_value = tf.stop_gradient(
                    metric_value)  # be strict about not back-propagating through the reward
                metric_value.set_shape([None, None, 1, 1])
                # According to metrics.py, tf.metrics.mean assures correct aggregation.
                # metric_value is total_reward: scalar
                features["samples"] = samples
                features["values"] = metric_value
                # del samples
                # del labels
            sharded_logits, losses_dict = model_class.model_fn(
                features,
                skip=(skipping_is_on and skip_this_one),
                mrt=mrt_samples)
            # if hparams.rl:
            #     training_loss = losses_dict["training"] * metric_value  # losses_dict["training"]: [batch, timesteps]
            #     training_loss_sum = tf.reduce_sum(training_loss)  # sum the training_loss
            #     losses_dict["training"] = training_loss_sum  # log_prob * r (current r is total_reward)
        with tf.variable_scope("losses_avg"):
            total_loss, ops = 0.0, []
            for loss_key, loss_value in six.iteritems(losses_dict):
                if hparams.rl:
                    baseline_loss_weight = getattr(hparams,
                                                   'baseline_loss_weight', 1.0)
                    training_loss_weight = getattr(hparams,
                                                   'training_loss_weight', 1.0)
                    mle_training_loss_weight = getattr(
                        hparams, 'mle_training_loss_weight', 0.3)
                    if loss_key == "training":
                        loss_value = loss_value * training_loss_weight
                    elif loss_key == "training_baseline":
                        loss_value = loss_value * baseline_loss_weight
                    elif loss_key == "mle_training":
                        loss_value = loss_value * mle_training_loss_weight
                loss_name = "problem_%d/%s_loss" % (n, loss_key)
                loss_moving_avg = tf.get_variable(loss_name,
                                                  initializer=100.0,
                                                  trainable=False)
                loss_variable_names.append(loss_name)
                ops.append(
                    loss_moving_avg.assign(loss_moving_avg * 0.9 +
                                           loss_value * 0.1))
                total_loss += loss_value
            try:  # Total loss avg might be reused or not, we try both.
                with tf.variable_scope(tf.get_variable_scope(), reuse=True):
                    # Total loss was already constructed on input.
                    loss_moving_avg = tf.get_variable("problem_%d/total_loss" %
                                                      n)
            except ValueError:
                loss_moving_avg = tf.get_variable("problem_%d/total_loss" % n,
                                                  initializer=100.0,
                                                  trainable=False)
            ops.append(
                loss_moving_avg.assign(loss_moving_avg * 0.9 +
                                       total_loss * 0.1))
        with tf.variable_scope("train_stats"):  # Count steps for this problem.
            problem_steps = tf.get_variable("problem_%d_steps" % n,
                                            initializer=0,
                                            trainable=False)
            ops.append(problem_steps.assign_add(1))
        with tf.control_dependencies(ops):  # Make sure the ops run.
            # Ensure the loss is a scalar here.
            total_loss = tf.reshape(total_loss, [],
                                    name="total_loss_control_id")
        return [total_loss, tf.concat(sharded_logits, 0)]

    model_output = input_fn_builder.cond_on_index(
        nth_model,
        index_tensor=features["problem_choice"],
        max_idx=len(hparams.problems) - 1)  # returns total_loss and sharded_logits

    if mode == tf.estimator.ModeKeys.PREDICT:
        # If beam searching, model_output will be a dict with keys "outputs" and
        # "scores".
        if isinstance(model_output, dict):  # beam search
            outputs = model_output["outputs"]
            scores = model_output["scores"]
        else:
            outputs = model_output
            scores = None

        batched_problem_choice = (features["problem_choice"] * tf.ones(
            (tf.shape(features["inputs"])[0], ), dtype=tf.int32))
        predictions = {
            "outputs": outputs,
            "scores": scores,
            "inputs": features.get("inputs", None),
            "targets": features.get("infer_targets", None),
            "problem_choice": batched_problem_choice,
        }
        _del_dict_nones(predictions)  # drop None entries from predictions

        export_out = {"outputs": predictions["outputs"]}
        if "scores" in predictions:
            export_out["scores"] = predictions["scores"]

        return tf.estimator.EstimatorSpec(
            mode,
            predictions=predictions,
            export_outputs={
                "output": tf.estimator.export.PredictOutput(export_out)
            })

    total_loss, logits = model_output

    if mode == tf.estimator.ModeKeys.EVAL:
        eval_metrics_fns = metrics.create_evaluation_metrics(
            hparams.problem_instances, hparams)

        eval_metrics = {}
        for metric_name, metric_fn in six.iteritems(eval_metrics_fns):
            eval_metrics[metric_name] = metric_fn(logits, features)

        return tf.estimator.EstimatorSpec(mode,
                                          predictions={"predictions": logits},
                                          eval_metric_ops=eval_metrics,
                                          loss=total_loss)

    assert mode == tf.estimator.ModeKeys.TRAIN

    # Set learning rate
    learning_rate = hparams.learning_rate * optimize.learning_rate_decay(
        hparams,
        num_worker_replicas=worker_replicas,
        num_train_steps=train_steps)
    learning_rate /= math.sqrt(float(worker_replicas))

    # Get global step
    global_step = tf.train.get_or_create_global_step()

    # Some training statistics.
    with tf.name_scope("training_stats"):
        tf.summary.scalar("learning_rate", learning_rate)
        for n in xrange(len(hparams.problems)):
            names_and_vars = []
            with tf.variable_scope("losses_avg", reuse=True):
                total_loss_var = tf.get_variable("problem_%d/total_loss" % n)
                names_and_vars.append(("total_loss", total_loss_var))
            with tf.variable_scope("losses_avg", reuse=True):
                for loss_name in loss_variable_names:
                    if loss_name.startswith("problem_%d/" % n):
                        loss_var = tf.get_variable(loss_name)
                        loss_suffix = loss_name[loss_name.index("/") + 1:]
                        names_and_vars.append((loss_suffix, loss_var))
            for (loss_name, loss_var) in names_and_vars:
                tf.summary.scalar("loss_avg_%d/%s" % (n, loss_name), loss_var)
            with tf.variable_scope("train_stats", reuse=True):
                nth_steps = tf.get_variable("problem_%d_steps" % n,
                                            dtype=tf.int32)
            tf.summary.scalar(
                "problem_%d_frequency" % n,
                tf.to_float(nth_steps) / (tf.to_float(global_step) + 1.0))

    # Add weight decay and noise.
    total_size, weight_decay_loss = 0, 0.0
    all_weights = {v.name: v for v in tf.trainable_variables()}
    for v_name in sorted(list(all_weights)):
        v = all_weights[v_name]
        v_size = int(np.prod(np.array(v.shape.as_list())))
        total_size += v_size
        if hparams.weight_decay > 0.0 and len(v.shape.as_list()) > 1:
            # Add weight regularization if set and the weight is not a bias (dim>1).
            with tf.device(v._ref().device):  # pylint: disable=protected-access
                v_loss = tf.nn.l2_loss(v) / v_size
            weight_decay_loss += v_loss
        is_body = len(v_name) > 5 and v_name[:5] == "body/"
        if hparams.weight_noise > 0.0 and is_body:
            # Add weight noise if set in hparams.
            with tf.device(v._ref().device):  # pylint: disable=protected-access
                scale = learning_rate * 0.001
                noise = tf.truncated_normal(
                    v.shape) * hparams.weight_noise * scale
                noise_op = v.assign_add(noise)
            with tf.control_dependencies([noise_op]):
                total_loss = tf.identity(total_loss)
    if hparams.weight_decay > 0.0:
        total_loss += weight_decay_loss * hparams.weight_decay

    # The new data reader occasionally emits very small batches, which
    # cause the examples in those batches to be grossly overweighted.
    # We decrease the loss proportionally to the ratio of the size of this
    # batch to the size of the largest training batch ever.
    # TODO(noam): to be more sophisticated, we could keep separate
    # maxima based on problem choice.
    max_nonpadding_var = tf.get_variable("max_nonpadding",
                                         shape=[],
                                         initializer=tf.ones_initializer(),
                                         trainable=False)
    max_nonpadding = tf.maximum(max_nonpadding_var, targets_nonpadding_tokens)
    with tf.control_dependencies(
        [tf.assign(max_nonpadding_var, max_nonpadding)]):
        small_batch_multiplier = targets_nonpadding_tokens / max_nonpadding
    tf.summary.scalar("small_batch_multiplier", small_batch_multiplier)
    total_loss *= small_batch_multiplier

    # Log variable sizes
    _log_variable_sizes(tf.trainable_variables(), "Trainable Variables")
    diet_vars = [
        v for v in tf.global_variables() if v.dtype == dtypes.float16_ref
    ]
    _log_variable_sizes(diet_vars, "Diet Variables")

    # Optimize
    train_op = optimize.optimize(total_loss, learning_rate, hparams)

    # Remove summaries that will fail to run because they are in conditionals.
    # TODO(cwhipkey): Test with this code removed, later in 2017.
    summaries = tf.get_collection_ref(tf.GraphKeys.SUMMARIES)
    for i in reversed(range(len(summaries))):
        if summaries[i].name.startswith("cond_"):
            del summaries[i]

    tf.logging.info("Global model_fn finished.")
    return tf.estimator.EstimatorSpec(
        mode,
        predictions={"problem_choice": features["problem_choice"]},
        loss=total_loss,
        train_op=train_op)
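
The MRT branch above tiles every source and target sentence mrt_samples times so that each of the beam hypotheses returned by _fast_decode lines up with a copy of its original example. A minimal NumPy sketch of that tile-and-reshape step (the array values and mrt_samples=3 are illustrative assumptions, not taken from the code above):

import numpy as np

# Illustrative batch of 2 sequences of length 4; values are arbitrary.
inputs = np.array([[1, 2, 3, 4],
                   [5, 6, 7, 8]])
mrt_samples = 3  # hypothetical number of sampled hypotheses per sentence

batch_size, inputs_len = inputs.shape
# Repeat along the length axis, then fold back: row i of `inputs` becomes
# rows i*mrt_samples .. (i+1)*mrt_samples - 1 of the result.
inputs_tile = np.tile(inputs, [1, mrt_samples])                              # [2, 12]
inputs_reshape = inputs_tile.reshape(batch_size * mrt_samples, inputs_len)   # [6, 4]
print(inputs_reshape)
# [[1 2 3 4]
#  [1 2 3 4]
#  [1 2 3 4]
#  [5 6 7 8]
#  [5 6 7 8]
#  [5 6 7 8]]

# The graph code then restores the 4-D modality layout with two expand_dims calls:
inputs_feed = inputs_reshape[:, :, None, None]                               # [6, 4, 1, 1]
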
Exemple #11
0
def model_fn(model,
             features,
             mode,
             hparams,
             problem_names,
             train_steps=100000,
             worker_id=0,
             worker_replicas=1,
             eval_run_autoregressive=False,
             decode_hparams=None):
  """Builds the model for all modes.

  * TRAIN: Constructs loss and train_op
  * EVAL: Constructs the loss and eval metrics
  * PREDICT: Constructs the predictions

  Args:
    model: str, name of model.
    features: dict<feature name, Tensor>. Expected to have keys
      {inputs, targets, problem_choice}.
    mode: tf.estimator.ModeKeys.
    hparams: model HParams.
    problem_names: list of str, names of the problems.
    train_steps: int, total number of training steps. Used to compute learning
      rate decay.
    worker_id: int, id of this worker.
    worker_replicas: int, number of workers.
    eval_run_autoregressive: bool, whether to run evaluation autoregressively.
    decode_hparams: HParams for decode settings. Used when mode == PREDICT.

  Returns:
    tf.estimator.EstimatorSpec
  """
  assert len(problem_names) == len(hparams.problem_instances)
  decode_hp = decode_hparams

  # TODO(rsepassi): This still depends on FLAGS. Rm eventually.
  dp = devices.data_parallelism(hparams)

  tf.get_variable_scope().set_initializer(_get_variable_initializer(hparams))
  is_training = mode == tf.estimator.ModeKeys.TRAIN

  # Add input statistics for incoming features.
  with tf.name_scope("input_stats"):
    for (k, v) in six.iteritems(features):
      if isinstance(v, tf.Tensor) and v.get_shape().ndims > 1:
        tf.summary.scalar("%s_batch" % k, tf.shape(v)[0] // dp.n)
        tf.summary.scalar("%s_length" % k, tf.shape(v)[1])
        nonpadding = tf.to_float(tf.not_equal(v, 0))
        nonpadding_tokens = tf.reduce_sum(nonpadding)
        if k == "targets":
          targets_nonpadding_tokens = nonpadding_tokens
        tf.summary.scalar("%s_nonpadding_tokens" % k, nonpadding_tokens)
        tf.summary.scalar("%s_nonpadding_fraction" % k,
                          tf.reduce_mean(nonpadding))

  # Get multi-problem logits and loss based on features["problem_choice"].
  loss_variable_names = []

  def nth_model(n):
    """Build the model for the n-th problem, plus some added variables."""
    model_class = registry.model(model)(
        hparams,
        mode,
        hparams.problems[n],
        n,
        dp,
        devices.ps_devices(all_workers=True),
        decode_hparams=decode_hparams)
    if mode == tf.estimator.ModeKeys.PREDICT:
      return model_class.infer(
          features,
          beam_size=decode_hp.beam_size,
          top_beams=(decode_hp.beam_size if decode_hp.return_beams else 1),
          alpha=decode_hp.alpha,
          decode_length=decode_hp.extra_length)
    # In distributed mode, we build graph for problem=0 and problem=worker_id.
    skipping_is_on = hparams.problem_choice == "distributed" and is_training
    problem_worker_id = worker_id % len(hparams.problems)
    skip_this_one = n != 0 and n % worker_replicas != problem_worker_id
    # On worker 0 also build graph for problems <= 1.
    # TODO(lukaszkaiser): why is this hack needed for variables init? Repair.
    skip_this_one = skip_this_one and (worker_id != 0 or n > 1)
    if eval_run_autoregressive and mode == tf.estimator.ModeKeys.EVAL:
      logits, losses_dict = model_class.eval_autoregressive(features)
    else:
      logits, losses_dict = model_class(
          features, skip=(skipping_is_on and skip_this_one))
    with tf.variable_scope("losses_avg"):
      total_loss, ops = 0.0, []
      for loss_key, loss_value in six.iteritems(losses_dict):
        loss_name = "problem_%d/%s_loss" % (n, loss_key)
        loss_moving_avg = tf.get_variable(
            loss_name, initializer=100.0, trainable=False)
        loss_variable_names.append(loss_name)
        ops.append(
            loss_moving_avg.assign(loss_moving_avg * 0.9 + loss_value * 0.1))
        total_loss += loss_value
      try:  # Total loss avg might be reused or not, we try both.
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
          # Total loss was already constructed on input.
          loss_moving_avg = tf.get_variable("problem_%d/total_loss" % n)
      except ValueError:
        loss_moving_avg = tf.get_variable(
            "problem_%d/total_loss" % n, initializer=100.0, trainable=False)
      ops.append(
          loss_moving_avg.assign(loss_moving_avg * 0.9 + total_loss * 0.1))
    with tf.variable_scope("train_stats"):  # Count steps for this problem.
      problem_steps = tf.get_variable(
          "problem_%d_steps" % n, initializer=0, trainable=False)
      ops.append(problem_steps.assign_add(1))
    with tf.control_dependencies(ops):  # Make sure the ops run.
      # Ensure the loss is a scalar here.
      total_loss = tf.reshape(total_loss, [], name="total_loss_control_id")
    return [total_loss, logits]

  model_output = input_fn_builder.cond_on_index(
      nth_model,
      index_tensor=features["problem_choice"],
      max_idx=len(hparams.problems) - 1)

  if mode == tf.estimator.ModeKeys.PREDICT:
    # If beam searching, model_output will be a dict with keys "outputs" and
    # "scores".
    if isinstance(model_output, dict):
      outputs = model_output["outputs"]
      scores = model_output["scores"]
    else:
      outputs = model_output
      scores = None

    batched_problem_choice = (
        features["problem_choice"] * tf.ones(
            (tf.shape(features["inputs"])[0],), dtype=tf.int32))
    predictions = {
        "outputs": outputs,
        "scores": scores,
        "inputs": features.get("inputs", None),
        "targets": features.get("infer_targets", None),
        "problem_choice": batched_problem_choice,
    }
    _del_dict_nones(predictions)

    export_out = {"outputs": predictions["outputs"]}
    if "scores" in predictions:
      export_out["scores"] = predictions["scores"]

    return tf.estimator.EstimatorSpec(
        mode,
        predictions=predictions,
        export_outputs={
            "output": tf.estimator.export.PredictOutput(export_out)
        })

  total_loss, logits = model_output

  if mode == tf.estimator.ModeKeys.EVAL:
    eval_metrics_fns = metrics.create_evaluation_metrics(
        hparams.problem_instances, hparams)

    eval_metrics = {}
    for metric_name, metric_fn in six.iteritems(eval_metrics_fns):
      eval_metrics[metric_name] = metric_fn(logits, features)

    return tf.estimator.EstimatorSpec(
        mode,
        predictions={"predictions": logits},
        eval_metric_ops=eval_metrics,
        loss=total_loss)

  assert mode == tf.estimator.ModeKeys.TRAIN

  # Set learning rate
  learning_rate = hparams.learning_rate * optimize.learning_rate_decay(
      hparams, num_worker_replicas=worker_replicas, num_train_steps=train_steps)
  learning_rate /= math.sqrt(float(worker_replicas))

  # Get global step
  global_step = tf.train.get_or_create_global_step()

  # Some training statistics.
  with tf.name_scope("training_stats"):
    tf.summary.scalar("learning_rate", learning_rate)
    for n in xrange(len(hparams.problems)):
      names_and_vars = []
      with tf.variable_scope("losses_avg", reuse=True):
        total_loss_var = tf.get_variable("problem_%d/total_loss" % n)
        names_and_vars.append(("total_loss", total_loss_var))
      with tf.variable_scope("losses_avg", reuse=True):
        for loss_name in loss_variable_names:
          if loss_name.startswith("problem_%d/" % n):
            loss_var = tf.get_variable(loss_name)
            loss_suffix = loss_name[loss_name.index("/") + 1:]
            names_and_vars.append((loss_suffix, loss_var))
      for (loss_name, loss_var) in names_and_vars:
        tf.summary.scalar("loss_avg_%d/%s" % (n, loss_name), loss_var)
      with tf.variable_scope("train_stats", reuse=True):
        nth_steps = tf.get_variable("problem_%d_steps" % n, dtype=tf.int32)
      tf.summary.scalar("problem_%d_frequency" % n,
                        tf.to_float(nth_steps) /
                        (tf.to_float(global_step) + 1.0))

  # Add weight decay and noise.
  total_size, weight_decay_loss = 0, 0.0
  all_weights = {v.name: v for v in tf.trainable_variables()}
  for v_name in sorted(list(all_weights)):
    v = all_weights[v_name]
    v_size = int(np.prod(np.array(v.shape.as_list())))
    total_size += v_size
    if hparams.weight_decay > 0.0 and len(v.shape.as_list()) > 1:
      # Add weight regularization if set and the weight is not a bias (dim>1).
      with tf.device(v._ref().device):  # pylint: disable=protected-access
        v_loss = tf.nn.l2_loss(v) / v_size
      weight_decay_loss += v_loss
    is_body = len(v_name) > 5 and v_name[:5] == "body/"
    if hparams.weight_noise > 0.0 and is_body:
      # Add weight noise if set in hparams.
      with tf.device(v._ref().device):  # pylint: disable=protected-access
        scale = learning_rate * 0.001
        noise = tf.truncated_normal(v.shape) * hparams.weight_noise * scale
        noise_op = v.assign_add(noise)
      with tf.control_dependencies([noise_op]):
        total_loss = tf.identity(total_loss)
  if hparams.weight_decay > 0.0:
    total_loss += weight_decay_loss * hparams.weight_decay

  # The new data reader occasionally emits very small batches, which
  # cause the examples in those batches to be grossly overweighted.
  # We decrease the loss proportionally to the ratio of the size of this
  # batch to the size of the largest training batch ever.
  # TODO(noam): to be more sophisticated, we could keep separate
  # maxima based on problem choice.
  max_nonpadding_var = tf.get_variable(
      "max_nonpadding",
      shape=[],
      initializer=tf.ones_initializer(),
      trainable=False)
  max_nonpadding = tf.maximum(max_nonpadding_var, targets_nonpadding_tokens)
  with tf.control_dependencies([tf.assign(max_nonpadding_var, max_nonpadding)]):
    small_batch_multiplier = targets_nonpadding_tokens / max_nonpadding
  tf.summary.scalar("small_batch_multiplier", small_batch_multiplier)
  total_loss *= small_batch_multiplier

  # Log variable sizes
  _log_variable_sizes(tf.trainable_variables(), "Trainable Variables")
  diet_vars = [
      v for v in tf.global_variables() if v.dtype == dtypes.float16_ref
  ]
  _log_variable_sizes(diet_vars, "Diet Variables")

  # Optimize
  train_op = optimize.optimize(total_loss, learning_rate, hparams)

  # Remove summaries that will fail to run because they are in conditionals.
  # TODO(cwhipkey): Test with this code removed, later in 2017.
  summaries = tf.get_collection_ref(tf.GraphKeys.SUMMARIES)
  for i in reversed(range(len(summaries))):
    if summaries[i].name.startswith("cond_"):
      del summaries[i]

  tf.logging.info("Global model_fn finished.")
  return tf.estimator.EstimatorSpec(
      mode,
      predictions={"problem_choice": features["problem_choice"]},
      loss=total_loss,
      train_op=train_op)
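
Both versions of model_fn end with the same small-batch correction: the training loss is scaled by the ratio of this batch's non-padding target tokens to the largest batch seen so far, so that the occasional tiny batch is not grossly overweighted. A standalone sketch of that bookkeeping in plain Python (the token counts and losses below are made-up numbers for illustration only):

# Hypothetical per-step (nonpadding_target_tokens, raw_loss) pairs.
steps = [(4096.0, 2.4), (4100.0, 2.3), (512.0, 2.6), (4090.0, 2.2)]

max_nonpadding = 1.0  # mirrors the tf.ones_initializer() starting value
for tokens, loss in steps:
    max_nonpadding = max(max_nonpadding, tokens)       # running maximum, like the assign above
    small_batch_multiplier = tokens / max_nonpadding   # ~1.0 for full batches, <1 for small ones
    scaled_loss = loss * small_batch_multiplier
    print("tokens=%6.0f  multiplier=%.3f  scaled_loss=%.3f"
          % (tokens, small_batch_multiplier, scaled_loss))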