Example 1
    def _build_select_slate_op(self):
        p_no_click = self._prob_no_click_ph
        p = self._doc_affinity_scores_ph
        q = self._net_outputs.q_values[0]
        with tf.name_scope('select_slate'):
            self._output_slate = self._select_slate_fn(self._slate_size,
                                                       p_no_click, p, q)

        # Debug print: emit the chosen slate together with the document scores
        # and Q-values ('cp 1' is just a checkpoint tag).
        self._output_slate = tf.Print(
            self._output_slate,
            [tf.constant('cp 1'), self._output_slate, p, q],
            summarize=10000)
        self._output_slate = tf.reshape(self._output_slate,
                                        (self._slate_size, ))

        # Track how often each candidate document has been placed on a slate.
        self._action_counts = tf.get_variable(
            'action_counts',
            shape=[self._num_candidates],
            initializer=tf.zeros_initializer())
        output_slate = tf.reshape(self._output_slate, [-1])
        output_one_hot = tf.one_hot(output_slate, self._num_candidates)
        update_ops = []
        for i in range(self._slate_size):
            update_ops.append(
                tf.assign_add(self._action_counts, output_one_hot[i]))
        self._select_action_update_op = tf.group(*update_ops)
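
A minimal, self-contained sketch of the action-count bookkeeping above, assuming TF1-style graph mode (via tf.compat.v1); slate_size and num_candidates are stand-ins for the agent's attributes. Summing the one-hot rows once is equivalent to the per-row assign_add loop.

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

slate_size, num_candidates = 3, 10
output_slate = tf.placeholder(tf.int32, shape=(slate_size,))
action_counts = tf.get_variable(
    'action_counts', shape=[num_candidates],
    initializer=tf.zeros_initializer())
# [slate_size, num_candidates]: one row per selected document.
output_one_hot = tf.one_hot(output_slate, num_candidates)
# Add 1 to every selected document's counter in a single update.
update_op = tf.assign_add(action_counts, tf.reduce_sum(output_one_hot, axis=0))

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(update_op, feed_dict={output_slate: [2, 5, 5]})
  print(sess.run(action_counts))  # index 2 -> 1.0, index 5 -> 2.0
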
Example 2
 def _get_finetune_op(self,
                      data,
                      embedding_vars_keys,
                      embedding_vars,
                      vars_to_finetune,
                      support_embeddings=None):
   """Returns the operation for performing a finetuning step."""
   if support_embeddings is None:
     support_embeddings = self.embedding_fn(
         data.support_images,
         self.is_training,
         params=collections.OrderedDict(
             zip(embedding_vars_keys, embedding_vars)),
         reuse=True)['embeddings']
   logits = self._fc_layer(support_embeddings)[:, 0:data.way]
   finetune_loss = self.compute_loss(
       onehot_labels=data.onehot_support_labels,
       predictions=logits,
   )
   # Perform one step of finetuning.
   if self.finetune_with_adam:
     finetune_op = self.finetune_opt.minimize(
         finetune_loss, var_list=vars_to_finetune)
   else:
     # Apply vanilla gradient descent instead of Adam.
     update_ops = gradient_descent_step(finetune_loss, vars_to_finetune, True,
                                        False, self.finetune_lr)['update_ops']
     finetune_op = tf.group(*update_ops)
   return logits, finetune_loss, finetune_op
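
A hedged sketch of the non-Adam branch: one vanilla gradient-descent step over an explicit variable list. The sgd_step helper here is hypothetical; gradient_descent_step in the original is assumed to return comparable update ops under its 'update_ops' key.

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

def sgd_step(loss, variables, learning_rate):
  # One step of vanilla gradient descent: v <- v - lr * dL/dv.
  grads = tf.gradients(loss, variables)
  updates = [tf.assign_sub(v, learning_rate * g)
             for v, g in zip(variables, grads) if g is not None]
  return tf.group(*updates)

w = tf.get_variable('w', initializer=2.0)
loss = tf.square(w - 1.0)
step = sgd_step(loss, [w], learning_rate=0.1)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(step)
  print(sess.run(w))  # 1.8: moved toward the minimum at 1.0
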
Example 3
  def update_state(self, inputs, outputs):
    """Function that updates the metric state at each example.

    Args:
      inputs: A dictionary containing input tensors.
      outputs: A dictionary containing output tensors.

    Returns:
      Update op.
    """
    # Prepare logits, labels, weights, and valid-point counts.
    logits = outputs[
        standard_fields.DetectionResultFields.object_semantic_points]
    labels = inputs[standard_fields.InputDataFields.object_class_points]
    weights = inputs[standard_fields.InputDataFields.point_loss_weights]
    num_valid_points = inputs[standard_fields.InputDataFields.num_valid_points]
    if len(logits.get_shape().as_list()) == 3:
      # Batched input: slice away each example's padded points, then flatten
      # the batch into one set of valid points.
      batch_size = logits.get_shape().as_list()[0]
      logits_list = []
      labels_list = []
      weights_list = []
      for i in range(batch_size):
        num_valid_points_i = num_valid_points[i]
        logits_list.append(logits[i, 0:num_valid_points_i, :])
        labels_list.append(labels[i, 0:num_valid_points_i, :])
        weights_list.append(weights[i, 0:num_valid_points_i, :])
      logits = tf.concat(logits_list, axis=0)
      labels = tf.concat(labels_list, axis=0)
      weights = tf.concat(weights_list, axis=0)
    if self.num_classes is None:
      num_classes = logits.get_shape().as_list()[-1]
    else:
      num_classes = self.num_classes
      if num_classes != logits.get_shape().as_list()[-1]:
        raise ValueError('num_classes does not match the logits dimension.')

    # Use the resolved number of classes (inferred from logits when not set).
    class_labels, class_predictions = _get_class_labels_and_predictions(
        labels=labels,
        logits=logits,
        num_classes=num_classes,
        multi_label=self.multi_label)

    update_ops = []
    for c in self.class_range:
      update_op_tp_c = self.true_positive_metrics[c].update_state(
          y_true=class_labels[c],
          y_pred=class_predictions[c],
          sample_weight=weights)
      update_ops.append(update_op_tp_c)
      update_op_fp_c = self.false_positive_metrics[c].update_state(
          y_true=class_labels[c],
          y_pred=class_predictions[c],
          sample_weight=weights)
      update_ops.append(update_op_fp_c)
      update_op_fn_c = self.false_negative_metrics[c].update_state(
          y_true=class_labels[c],
          y_pred=class_predictions[c],
          sample_weight=weights)
      update_ops.append(update_op_fn_c)
    return tf.group(*update_ops)
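
A self-contained sketch of the padding-removal pattern used above, assuming a statically known batch size: each example is sliced to its valid length and the pieces are concatenated into one flat point set.

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

# [batch, max_points, channels]; trailing rows are padding.
logits = tf.constant([[[1., 2.], [3., 4.], [0., 0.]],
                      [[5., 6.], [0., 0.], [0., 0.]]])
num_valid_points = tf.constant([2, 1])
batch_size = logits.get_shape().as_list()[0]

logits_list = []
for i in range(batch_size):
  logits_list.append(logits[i, 0:num_valid_points[i], :])
flat_logits = tf.concat(logits_list, axis=0)  # [3, 2], padding dropped

with tf.Session() as sess:
  print(sess.run(flat_logits))  # [[1. 2.] [3. 4.] [5. 6.]]
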
Example 4
 def _do_data_dependent_init():
     """Returns ops for the data-dependent init of g and maybe b_fc."""
     w_fc_normalized = tf.nn.l2_normalize(w_fc.read_value(), [0])
     output_init = tf.matmul(embeddings, w_fc_normalized)
     mean_init, var_init = tf.nn.moments(output_init, [0])
     # Data-dependent init values.
     g_init_value = 1. / tf.sqrt(var_init + 1e-10)
     ops = [tf.assign(g, g_init_value)]
     if not cosine_classifier:
         # Also initialize a bias in a data-dependent way.
         b_fc_init_value = -mean_init * g_init_value
         ops.append(tf.assign(b_fc, b_fc_init_value))
     # Mark that the data-dependent initialization is done to prevent it from
     # happening again in the future.
     ops.append(tf.assign(data_dependent_init_done, 1))
     return tf.group(*ops)
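
A self-contained sketch of the data-dependent init above, in the spirit of weight normalization (Salimans & Kingma, 2016): g is set to the reciprocal of the pre-activation standard deviation and the bias to -mean * g, so initial outputs are roughly zero-mean, unit-variance. Shapes and names here are illustrative.

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

embeddings = tf.random.normal([32, 8])  # a batch of features
w_fc = tf.get_variable('w_fc', shape=[8, 4],
                       initializer=tf.random_normal_initializer(stddev=0.05))
g = tf.get_variable('g', shape=[4], initializer=tf.ones_initializer())
b_fc = tf.get_variable('b_fc', shape=[4], initializer=tf.zeros_initializer())

w_fc_normalized = tf.nn.l2_normalize(w_fc.read_value(), [0])
output_init = tf.matmul(embeddings, w_fc_normalized)
mean_init, var_init = tf.nn.moments(output_init, [0])
g_init_value = 1. / tf.sqrt(var_init + 1e-10)
init_op = tf.group(tf.assign(g, g_init_value),
                   tf.assign(b_fc, -mean_init * g_init_value))

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(init_op)  # g * output + b is now ~zero-mean, unit-variance per unit
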
Example 5
def train_q(dataset,
            policy,
            optimizer=None,
            pack_transition_fn=None,
            q_graph_fn=None,
            log_dir=None,
            master='',
            task=0,
            training_steps=None,
            max_training_steps=100000,
            reuse=False,
            init_checkpoint=None,
            update_target_every_n_steps=50,
            log_every_n_steps=None,
            save_checkpoint_steps=500,
            save_summaries_steps=500):
    """Self-contained learning loop for offline Q-learning.

  Code inspired by OpenAI Baselines' deepq.build_train. This function is
  compatible with discrete Q-learning graphs, continuous Q-learning graphs,
  and SARSA.

  Args:
    dataset: tf.data.Dataset providing transitions.
    policy: Instance of TFDQNPolicy class that provides functor for building the
      critic function.
    optimizer: Optional instance of an optimizer. If not specified, creates an
      AdamOptimizer using the default constructor.
    pack_transition_fn: Optional function that performs additional processing
      of the transition. This is a convenience method for ad-hoc manipulation of
      transition data passed to the learning function after parsing.
    q_graph_fn: Function used to construct training objectives w.r.t. critic
      outputs.
    log_dir: Where to save model checkpoints and tensorboard summaries.
    master: Optional address of master worker. Specify this when doing
      distributed training.
    task: Optional worker task for distributed training. Defaults to solo master
      task on a single machine.
    training_steps: Optional number of steps to run training before terminating
      early. max_training_steps still applies - training will terminate after
      max_training_steps whether or not training_steps is specified.
    max_training_steps: maximum number of training iters.
    reuse: If True, reuse existing variables for all declared variables by this
      function.
    init_checkpoint: Optional checkpoint to restore prior to training. If not
      provided, variables are initialized using global_variables_initializer().
    update_target_every_n_steps: How many global steps (training) between
      copying the Q network weights (scope='q_func') to target network
      (scope='target_q_func').
    log_every_n_steps: How many global steps between logging loss tensors.
    save_checkpoint_steps: How many global steps between saving TF variables
      to a checkpoint file.
    save_summaries_steps: How many global steps between saving TF summaries.

  Returns:
    Tuple (np_step, done): the current `global_step` reached when training
    stopped, and whether `global_step` reached `max_training_steps`.

  Raises:
    ValueError: If a batch of transitions is empty or its zeroth element is
      empty when it is expected to have length batch_size.
  """
    data_iterator = dataset.make_one_shot_iterator()

    transition = data_iterator.get_next()
    if pack_transition_fn:
        transition = pack_transition_fn(transition)

    if optimizer is None:
        optimizer = tf.train.AdamOptimizer()

    q_func = policy.get_q_func(is_training=True, reuse=reuse)
    loss, all_summaries = q_graph_fn(q_func, transition)

    q_func_vars = contrib_framework.get_trainable_variables(scope='q_func')
    target_q_func_vars = contrib_framework.get_trainable_variables(
        scope='target_q_func')
    global_step = tf.train.get_or_create_global_step()

    # Only optimize q_func and update its batchnorm params.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='q_func')
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss,
                                      global_step=global_step,
                                      var_list=q_func_vars)

    chief_hooks = []
    hooks = []
    # Save summaries periodically.
    if save_summaries_steps is not None:
        chief_hooks.append(
            tf.train.SummarySaverHook(save_steps=save_summaries_steps,
                                      output_dir=log_dir,
                                      summary_op=all_summaries))

    # Stop after max_training_steps.
    if max_training_steps:
        hooks.append(tf.train.StopAtStepHook(last_step=max_training_steps))

    # Report if loss tensor is NaN.
    hooks.append(tf.train.NanTensorHook(loss))

    if log_every_n_steps is not None:
        tensor_dict = {'global_step': global_step, 'train loss': loss}
        chief_hooks.append(
            tf.train.LoggingTensorHook(tensor_dict,
                                       every_n_iter=log_every_n_steps))

        # Measure how fast we are training per sec and save to summary.
        chief_hooks.append(
            tf.train.StepCounterHook(every_n_steps=log_every_n_steps,
                                     output_dir=log_dir))

    # If a target network exists, periodically copy the Q network weights
    # (scope='q_func') into the target network (scope='target_q_func'). We
    # piggyback on a LoggingTensorHook: logging `update_target` every N steps
    # forces the copy ops it depends on to run.
    if target_q_func_vars and update_target_every_n_steps is not None:
        update_target_expr = []
        for var, var_t in zip(sorted(q_func_vars, key=lambda v: v.name),
                              sorted(target_q_func_vars,
                                     key=lambda v: v.name)):
            update_target_expr.append(var_t.assign(var))
        update_target_expr = tf.group(*update_target_expr)

        with tf.control_dependencies([update_target_expr]):
            update_target = tf.constant(0)
        chief_hooks.append(
            tf.train.LoggingTensorHook(
                {'update_target': update_target},
                every_n_iter=update_target_every_n_steps))

    # Save checkpoints periodically, save all of them.
    saver = tf.train.Saver(max_to_keep=None)
    chief_hooks.append(
        tf.train.CheckpointSaverHook(log_dir,
                                     save_steps=save_checkpoint_steps,
                                     saver=saver,
                                     checkpoint_basename='model.ckpt'))

    # Save our experiment params to checkpoint dir.
    chief_hooks.append(
        gin.tf.GinConfigSaverHook(log_dir, summarize_config=True))

    session_config = tf.ConfigProto(log_device_placement=False)

    init_fn = None
    if init_checkpoint:
        assign_fn = contrib_framework.assign_from_checkpoint_fn(
            init_checkpoint, contrib_framework.get_model_variables())
        init_fn = lambda _, sess: assign_fn(sess)
    scaffold = tf.train.Scaffold(saver=saver, init_fn=init_fn)
    with tf.train.MonitoredTrainingSession(
            master=master,
            is_chief=(task == 0),
            config=session_config,
            checkpoint_dir=log_dir,
            scaffold=scaffold,
            hooks=hooks,
            chief_only_hooks=chief_hooks) as sess:
        np_step = 0
        while not sess.should_stop():
            np_step, _ = sess.run([global_step, train_op])
            if training_steps and np_step % training_steps == 0:
                break
        done = np_step >= max_training_steps
    return np_step, done
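
A minimal sketch of the target-network update wired up above: trainable variables under 'q_func' are matched by sorted name with those under 'target_q_func' and copied across. The scope names mirror the ones train_q expects.

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

with tf.variable_scope('q_func'):
  w = tf.get_variable('w', initializer=1.0)
with tf.variable_scope('target_q_func'):
  w_target = tf.get_variable('w', initializer=0.0)

q_func_vars = tf.trainable_variables(scope='q_func')
target_q_func_vars = tf.trainable_variables(scope='target_q_func')
update_target_expr = [
    var_t.assign(var)
    for var, var_t in zip(sorted(q_func_vars, key=lambda v: v.name),
                          sorted(target_q_func_vars, key=lambda v: v.name))
]
update_target = tf.group(*update_target_expr)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(update_target)
  print(sess.run(w_target))  # 1.0: weights copied from q_func
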
Example 6
  def compute_logits(self, data):
    """Computes the class logits for the episode.

    Args:
      data: A `meta_dataset.providers.Episode`.

    Returns:
      The query set logits as a [num_query_images, way] matrix.

    Raises:
      ValueError: Distance must be one of l2 or cosine.
    """
    # ------------------------ Finetuning -------------------------------
    # Possibly make copies of embedding variables, if they will get modified.
    # This is for making temporary-only updates to the embedding network
    # which will not persist after the end of the episode.
    make_copies = self.finetune_all_layers

    # TODO(eringrant): Reduce the number of times the embedding function graph
    # is built with the same input.
    support_embeddings_params_moments = self.embedding_fn(
        data.support_images, self.is_training)
    support_embeddings = support_embeddings_params_moments['embeddings']
    support_embeddings_var_dict = support_embeddings_params_moments['params']

    (embedding_vars_keys, embedding_vars,
     embedding_vars_copy_ops) = get_embeddings_vars_copy_ops(
         support_embeddings_var_dict, make_copies)
    embedding_vars_copy_op = tf.group(*embedding_vars_copy_ops)

    # Compute the initial training loss (only for printing purposes). This
    # line is also needed for adding the fc variables to the graph so that the
    # tf.trainable_variables() call below detects them.
    logits = self._fc_layer(support_embeddings)[:, 0:data.way]
    finetune_loss = self.compute_loss(
        onehot_labels=data.onehot_support_labels,
        predictions=logits,
    )

    # Decide which variables to finetune.
    fc_vars, vars_to_finetune = [], []
    for var in tf.trainable_variables():
      if 'fc_finetune' in var.name:
        fc_vars.append(var)
        vars_to_finetune.append(var)
    if self.finetune_all_layers:
      vars_to_finetune.extend(embedding_vars)
    logging.info('Finetuning will optimize variables: %s', vars_to_finetune)

    for i in range(self.num_finetune_steps):
      if i == 0:
        # Randomly initialize the fc layer.
        fc_reset = tf.variables_initializer(var_list=fc_vars)
        # Adam-related variables are created when minimize() is called.
        # We create an unused op here to put all Adam variables under
        # the 'adam_opt' name scope and create a reset op to reinitialize
        # these variables before the first finetune step.
        adam_reset = tf.no_op()
        if self.finetune_with_adam:
          with tf.variable_scope('adam_opt'):
            unused_op = self.finetune_opt.minimize(
                finetune_loss, var_list=vars_to_finetune)
          adam_reset = tf.variables_initializer(self.finetune_opt.variables())
        with tf.control_dependencies(
            [fc_reset, adam_reset, finetune_loss, embedding_vars_copy_op] +
            vars_to_finetune):
          print_op = tf.no_op()
          if self.debug_log:
            print_op = tf.print([
                'step: %d' % i, vars_to_finetune[0][0, 0], 'loss:',
                finetune_loss
            ])

          with tf.control_dependencies([print_op]):
            # Get the operation for finetuning.
            # (The logits and loss are returned just for printing).
            logits, finetune_loss, finetune_op = self._get_finetune_op(
                data, embedding_vars_keys, embedding_vars, vars_to_finetune,
                support_embeddings if not self.finetune_all_layers else None)

            if self.debug_log:
              # Test logits are computed only for printing logs.
              query_embeddings = self.embedding_fn(
                  data.query_images,
                  self.is_training,
                  params=collections.OrderedDict(
                      zip(embedding_vars_keys, embedding_vars)),
                  reuse=True)['embeddings']
              query_logits = (self._fc_layer(query_embeddings)[:, 0:data.way])

      else:
        with tf.control_dependencies([finetune_op, finetune_loss] +
                                     vars_to_finetune):
          print_op = tf.no_op()
          if self.debug_log:
            print_op = tf.print([
                'step: %d' % i,
                vars_to_finetune[0][0, 0],
                'loss:',
                finetune_loss,
                'accuracy:',
                self.compute_accuracy(
                    labels=data.onehot_support_labels, predictions=logits),
                'query accuracy:',
                self.compute_accuracy(
                    labels=data.onehot_query_labels, predictions=query_logits),
            ])

          with tf.control_dependencies([print_op]):
            # Get the operation for finetuning.
            # (The logits and loss are returned just for printing).
            logits, finetune_loss, finetune_op = self._get_finetune_op(
                data, embedding_vars_keys, embedding_vars, vars_to_finetune,
                support_embeddings if not self.finetune_all_layers else None)

            if self.debug_log:
              # Test logits are computed only for printing logs.
              query_embeddings = self.embedding_fn(
                  data.query_images,
                  self.is_training,
                  params=collections.OrderedDict(
                      zip(embedding_vars_keys, embedding_vars)),
                  reuse=True)['embeddings']
              query_logits = (self._fc_layer(query_embeddings)[:, 0:data.way])

    # Finetuning is now over, compute the query performance using the updated
    # fc layer, and possibly the updated embedding network.
    with tf.control_dependencies([finetune_op] + vars_to_finetune):
      query_embeddings = self.embedding_fn(
          data.query_images,
          self.is_training,
          params=collections.OrderedDict(
              zip(embedding_vars_keys, embedding_vars)),
          reuse=True)['embeddings']
      query_logits = self._fc_layer(query_embeddings)[:, 0:data.way]

      if self.debug_log:
        # The train logits are computed only for printing.
        support_embeddings = self.embedding_fn(
            data.support_images,
            self.is_training,
            params=collections.OrderedDict(
                zip(embedding_vars_keys, embedding_vars)),
            reuse=True)['embeddings']
        logits = self._fc_layer(support_embeddings)[:, 0:data.way]

      print_op = tf.no_op()
      if self.debug_log:
        print_op = tf.print([
            'accuracy:',
            self.compute_accuracy(
                labels=data.onehot_support_labels, predictions=logits),
            'query accuracy:',
            self.compute_accuracy(
                labels=data.onehot_query_labels, predictions=query_logits),
        ])
      with tf.control_dependencies([print_op]):
        query_logits = self._fc_layer(query_embeddings)[:, 0:data.way]

    return query_logits
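
The loop above relies on tf.control_dependencies to serialize graph-mode side effects: ops created inside the context only run after the listed ops. A minimal sketch of that sequencing:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

v = tf.get_variable('v', initializer=0.0)
first_step = tf.assign_add(v, 1.0)
with tf.control_dependencies([first_step]):
  # Created inside the context, so this read happens after first_step.
  second_step = tf.assign_add(v, v.read_value())

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(second_step))  # 2.0: v goes 0 -> 1 -> 2
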
Example 7
    def my_model_fn(features, labels, mode, params=None, config=None):
        """Estimator model function.

    Args:
      features: input features dictionary
      labels: ignored
      mode: a tf.estimator.ModeKeys
      params: parameter dictionary passed by the Estimator; on TPU this
        includes the TPU "context"
      config: a tf.estimator.RunConfig (unused)

    Returns:
      a tpu_estimator.TPUEstimatorSpec (on TPU) or a tf.estimator.EstimatorSpec
    """
        del labels, config
        global_step = tf.train.get_global_step()
        if use_tpu:
            ctx = params["context"]
            num_hosts = ctx.num_hosts
            host_placement_fn = ctx.tpu_host_placement_function
            device_list = [
                host_placement_fn(host_id=t) for t in range(num_hosts)
            ]
            # TODO(ylc): Better estimation of replica cache size?
            replica_cache_size = 300 * 1000000  # 300M per replica
            # Worker 0 caches all the TPU binaries.
            worker0_mem = replica_cache_size * ctx.num_replicas
            devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
            var_placer = mtf.utils.BalancedVariablePlacer(
                device_list, devices_memory_usage)
            mesh_devices = [""] * mesh_shape.size
            physical_shape = list(
                params["context"].device_assignment.topology.mesh_shape)
            logical_to_physical = _logical_to_physical(physical_shape,
                                                       mesh_shape)
            mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                mesh_shape,
                layout_rules,
                mesh_devices,
                ctx.device_assignment,
                logical_to_physical=logical_to_physical)
        else:
            var_placer = None
            mesh_devices = [""] * mesh_shape.size
            mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
                mesh_shape, layout_rules, mesh_devices)

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh", var_placer)

        def _import_feature(key, allow_missing=False):
            """Import a feature from the features dictionary into a mtf.Tensor.

      Args:
        key: a string
        allow_missing: a boolean

      Returns:
        an mtf.Tensor with dtype int32 and shape
        [outer_batch_dim, batch_dim, length_dim], or None if the feature is
        missing and allow_missing is True
      """
            outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
            batch_dim = mtf.Dimension("batch", batch_size // outer_batch_size)
            length_dim = mtf.Dimension("length", sequence_length)

            mtf_shape = mtf.Shape([outer_batch_dim, batch_dim, length_dim])
            if key not in features:
                if allow_missing:
                    return None
                else:
                    raise ValueError("feature not found %s - features %s = " %
                                     (key, features))
            tf.logging.info("Import feature %s: %s" % (key, features[key]))

            x = tf.to_int32(features[key])
            x = tf.reshape(
                x, [outer_batch_size, batch_size // outer_batch_size, -1])

            if not use_tpu:
                x = tf.Print(x, [x],
                             "import feature %s" % key,
                             summarize=1000,
                             first_n=1)
            return mtf.import_fully_replicated(mesh, x, mtf_shape, name=key)

        if mode == tf.estimator.ModeKeys.PREDICT:
            inputs = _import_feature("inputs")
            inputs = mtf.reshape(
                inputs,
                mtf.Shape([
                    mtf.Dimension("batch", batch_size),
                    mtf.Dimension("length", sequence_length)
                ]))
            if isinstance(transformer_model, transformer.Unitransformer):
                mtf_samples = transformer_model.sample_autoregressive(
                    inputs, variable_dtype=get_variable_dtype())
            elif isinstance(transformer_model, transformer.Bitransformer):
                mtf_samples = transformer_model.decode(
                    inputs, variable_dtype=get_variable_dtype())
            else:
                raise ValueError("unrecognized class")
            mtf_samples = mtf.anonymize(mtf_samples)
            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)
            outputs = lowering.export_to_tf_tensor(mtf_samples)
            predictions = {"outputs": outputs}
            return tpu_estimator.TPUEstimatorSpec(
                mode=tf.estimator.ModeKeys.PREDICT,
                predictions=predictions,
                prediction_hooks=[mtf.MtfRestoreHook(lowering)])

        targets = _import_feature("targets")
        anon_targets = mtf.anonymize(targets)
        if model_type == "lm":
            _, length_dim = targets.shape
            inputs = mtf.shift(targets, offset=1, dim=length_dim, wrap=False)
        else:
            inputs = _import_feature("inputs")

        if mode == tf.estimator.ModeKeys.EVAL:
            if isinstance(transformer_model, transformer.Unitransformer):
                mtf_samples = transformer_model.sample_autoregressive(
                    inputs, variable_dtype=get_variable_dtype())
            elif isinstance(transformer_model, transformer.Bitransformer):
                mtf_samples = transformer_model.decode(
                    inputs, variable_dtype=get_variable_dtype())
            else:
                raise ValueError("unrecognized class")
            mtf_samples = mtf.anonymize(mtf_samples)
            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)
            outputs = lowering.export_to_tf_tensor(mtf_samples)
            labels = lowering.export_to_tf_tensor(anon_targets)
            restore_hook = mtf.MtfRestoreHook(lowering)

            # Assigning a default to metric_names directly would make it a
            # local variable, so use a separately named local for the
            # defaulted list.
            local_metric_names = metric_names or ["token_accuracy"]

            def metric_fn(labels, outputs):
                return get_metric_fns(local_metric_names, labels, outputs)

            eval_metrics = (metric_fn, [labels, outputs])
            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                # Unfortunately TPUEstimatorSpec requires us to provide a value for
                # loss when in EVAL mode. Since we are sampling or decoding from the
                # model, we don't have a loss to report.
                loss=tf.constant(0.),
                evaluation_hooks=[restore_hook],
                eval_metrics=eval_metrics)

        if isinstance(transformer_model, transformer.Unitransformer):
            position_kwargs = dict(
                sequence_id=_import_feature("targets_segmentation", True),
                position=_import_feature("targets_position", True),
            )
        elif isinstance(transformer_model, transformer.Bitransformer):
            position_kwargs = dict(
                encoder_sequence_id=_import_feature("inputs_segmentation",
                                                    True),
                decoder_sequence_id=_import_feature("targets_segmentation",
                                                    True),
                encoder_position=_import_feature("inputs_position", True),
                decoder_position=_import_feature("targets_position", True),
            )
        else:
            raise ValueError("unrecognized class")

        logits, loss = transformer_model.call_simple(
            inputs=inputs,
            targets=targets,
            compute_loss=True,
            mode=mode,
            variable_dtype=get_variable_dtype(),
            **position_kwargs)

        if use_tpu and logits is not None:
            logits = mtf.anonymize(logits)

        # TRAIN mode
        if mode == tf.estimator.ModeKeys.TRAIN:
            var_grads = mtf.gradients(
                [loss], [v.outputs[0] for v in graph.trainable_variables])
            optimizer = mtf.optimize.AdafactorOptimizer(
                learning_rate=learning_rate)
            update_ops = optimizer.apply_grads(var_grads,
                                               graph.trainable_variables)

        lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)

        tf_loss = lowering.export_to_tf_tensor(loss)
        tf_loss = tf.to_float(tf_loss)
        if not use_tpu:
            tf_loss = tf.Print(tf_loss,
                               [tf_loss, tf.train.get_global_step()],
                               "step, tf_loss")

        if mode == tf.estimator.ModeKeys.TRAIN:
            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            train_op = tf.group(tf_update_ops)

        with mtf.utils.outside_all_rewrites():
            # Copy master variables to slices. Must be called first.
            restore_hook = mtf.MtfRestoreHook(lowering)
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=checkpoints_to_keep,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                model_dir,
                save_steps=save_steps,
                saver=saver,
                listeners=[saver_listener])
            gin_config_saver_hook = gin.tf.GinConfigSaverHook(
                model_dir, summarize_config=True)

            if mode == tf.estimator.ModeKeys.TRAIN:
                if use_tpu:
                    return tpu_estimator.TPUEstimatorSpec(
                        mode=tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        training_hooks=[
                            restore_hook,
                            saver_hook,
                            gin_config_saver_hook,
                        ])
                else:
                    return tf.estimator.EstimatorSpec(
                        tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        training_chief_hooks=[
                            restore_hook,
                            saver_hook,
                            gin_config_saver_hook,
                        ])
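
In the "lm" branch above, mtf.shift builds the inputs by shifting the targets one position along the length dimension without wrapping, so position t only sees tokens before t. A plain-TF sketch of the same transformation:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

targets = tf.constant([[11, 12, 13, 14]])  # [batch, length]
# Shift right by one along length, padding with 0 (no wrap-around).
inputs = tf.pad(targets, [[0, 0], [1, 0]])[:, :-1]

with tf.Session() as sess:
  print(sess.run(inputs))  # [[ 0 11 12 13]]
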
Example 8
def bn(x,
       params=None,
       moments=None,
       backprop_through_moments=True,
       use_ema=False,
       is_training=True,
       ema_epsilon=.9):
  """Batch normalization.

  The usage should be as follows: If x is the support images, moments should be
  None so that they are computed from the support set examples. On the other
  hand, if x is the query images, the moments argument should be used in order
  to pass in the mean and var that were computed from the support set.

  Args:
    x: inputs.
    params: None or a dict containing the values of the offset and scale params.
    moments: None or a dict containing the values of the mean and var to use for
      batch normalization.
    backprop_through_moments: Whether to allow gradients to flow through the
      given support set moments. Only applies to non-transductive batch norm.
    use_ema: apply moving averages of batch norm statistics, or update them,
      depending on whether we are training or testing.  Note that passing
      moments will override this setting, and result in neither updating nor
      using EMA statistics.  This is important to make sure that episodic
      learners don't update EMA statistics a second time when processing
      queries.
    is_training: if use_ema=True, this determines whether to apply the moving
      averages, or update them.
    ema_epsilon: if updating moving averages, use this value for the
      exponential moving averages.

  Returns:
    output: The result of applying batch normalization to the input.
    params: The updated params.
    moments: The updated moments.
  """
  params_keys, params_vars, moments_keys, moments_vars = [], [], [], []

  with tf.variable_scope('batch_norm'):
    scope_name = tf.get_variable_scope().name

    if use_ema:
      ema_shape = [1, 1, 1, x.get_shape().as_list()[-1]]
      mean_ema = tf.get_variable(
          'mean_ema',
          shape=ema_shape,
          initializer=tf.initializers.zeros(),
          trainable=False)
      var_ema = tf.get_variable(
          'var_ema',
          shape=ema_shape,
          initializer=tf.initializers.ones(),
          trainable=False)

    if moments is not None:
      if backprop_through_moments:
        mean = moments[scope_name + '/mean']
        var = moments[scope_name + '/var']
      else:
        # This variant does not yield good results.
        mean = tf.stop_gradient(moments[scope_name + '/mean'])
        var = tf.stop_gradient(moments[scope_name + '/var'])
    elif use_ema and not is_training:
      mean = mean_ema
      var = var_ema
    else:
      # If not provided, compute the mean and var of the current batch.

      replica_ctx = tf.distribute.get_replica_context()
      if replica_ctx:
        # from third_party/tensorflow/python/keras/layers/normalization_v2.py
        axes = list(range(len(x.shape) - 1))
        local_sum = tf.reduce_sum(x, axis=axes, keepdims=True)
        local_squared_sum = tf.reduce_sum(
            tf.square(x), axis=axes, keepdims=True)
        batch_size = tf.cast(tf.shape(x)[0], tf.float32)
        x_sum, x_squared_sum, global_batch_size = (
            replica_ctx.all_reduce('sum',
                                   [local_sum, local_squared_sum, batch_size]))

        axes_vals = [(tf.shape(x))[i] for i in range(1, len(axes))]
        multiplier = tf.cast(tf.reduce_prod(axes_vals), tf.float32)
        multiplier = multiplier * global_batch_size

        mean = x_sum / multiplier
        x_squared_mean = x_squared_sum / multiplier
        # var = E(x^2) - E(x)^2
        var = x_squared_mean - tf.square(mean)
      else:
        mean, var = tf.nn.moments(
            x, axes=list(range(len(x.shape) - 1)), keep_dims=True)

    # Only update ema's if training and we computed the moments in the current
    # call.  Note: at test time for episodic learners, ema's may be passed
    # from the support set to the query set, even if it's not really needed.
    if use_ema and is_training and moments is None:
      replica_ctx = tf.distribute.get_replica_context()
      mean_upd = tf.assign(mean_ema,
                           mean_ema * ema_epsilon + mean * (1.0 - ema_epsilon))
      var_upd = tf.assign(var_ema,
                          var_ema * ema_epsilon + var * (1.0 - ema_epsilon))
      updates = tf.group([mean_upd, var_upd])
      if replica_ctx:
        tf.add_to_collection(
            tf.GraphKeys.UPDATE_OPS,
            tf.cond(
                tf.equal(replica_ctx.replica_id_in_sync_group, 0),
                lambda: updates, tf.no_op))
      else:
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, updates)

    moments_keys += [scope_name + '/mean']
    moments_vars += [mean]
    moments_keys += [scope_name + '/var']
    moments_vars += [var]

    if params is None:
      offset = tf.get_variable(
          'offset',
          shape=mean.get_shape().as_list(),
          initializer=tf.initializers.zeros())
      scale = tf.get_variable(
          'scale',
          shape=var.get_shape().as_list(),
          initializer=tf.initializers.ones())
    else:
      offset = params[scope_name + '/offset']
      scale = params[scope_name + '/scale']

    params_keys += [scope_name + '/offset']
    params_vars += [offset]
    params_keys += [scope_name + '/scale']
    params_vars += [scale]

    output = tf.nn.batch_normalization(x, mean, var, offset, scale, 0.00001)

    params = collections.OrderedDict(zip(params_keys, params_vars))
    moments = collections.OrderedDict(zip(moments_keys, moments_vars))

    return output, params, moments
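
A self-contained check of the moment computation in the cross-replica branch above: mean and variance are recovered from a plain sum and a squared sum via var = E(x^2) - E(x)^2, which matches tf.nn.moments when there is a single replica.

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

x = tf.random.normal([16, 4, 4, 8])
axes = list(range(len(x.shape) - 1))  # all axes except channels
count = tf.cast(tf.reduce_prod(tf.shape(x)[:-1]), tf.float32)
x_sum = tf.reduce_sum(x, axis=axes, keepdims=True)
x_squared_sum = tf.reduce_sum(tf.square(x), axis=axes, keepdims=True)
mean = x_sum / count
var = x_squared_sum / count - tf.square(mean)  # var = E(x^2) - E(x)^2
ref_mean, ref_var = tf.nn.moments(x, axes=axes, keep_dims=True)

with tf.Session() as sess:
  m, v, rm, rv = sess.run([mean, var, ref_mean, ref_var])
  # m ~= rm and v ~= rv up to floating-point error.
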
Example 9
    def my_model_fn(features, labels, mode, params=None, config=None):
        """Estimator model function.

    Args:
      features: input features dictionary
      labels: ignored
      mode: a tf.estimator.ModeKeys
      params: parameter dictionary passed by the Estimator; on TPU this
        includes the TPU "context"
      config: a tf.estimator.RunConfig (unused)

    Returns:
      a tpu_estimator.TPUEstimatorSpec (on TPU) or a tf.estimator.EstimatorSpec
    """
        del labels, config
        global_step = tf.train.get_global_step()
        if use_tpu:
            ctx = params["context"]
            num_hosts = ctx.num_hosts
            host_placement_fn = ctx.tpu_host_placement_function
            device_list = [
                host_placement_fn(host_id=t) for t in range(num_hosts)
            ]
            # TODO(ylc): Better estimation of replica cache size?
            replica_cache_size = 300 * 1000000  # 300M per replica
            # Worker 0 caches all the TPU binaries.
            worker0_mem = replica_cache_size * ctx.num_replicas
            devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
            var_placer = mtf.utils.BalancedVariablePlacer(
                device_list, devices_memory_usage)
            mesh_devices = [""] * mesh_shape.size
            physical_shape = list(
                params["context"].device_assignment.topology.mesh_shape)
            logical_to_physical = _logical_to_physical(physical_shape,
                                                       mesh_shape)
            mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                mesh_shape,
                layout_rules,
                mesh_devices,
                ctx.device_assignment,
                logical_to_physical=logical_to_physical)
        else:
            var_placer = None
            mesh_devices = [""] * mesh_shape.size
            mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
                mesh_shape, layout_rules, mesh_devices)

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh", var_placer)

        outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
        batch_dim = mtf.Dimension("batch", batch_size // outer_batch_size)
        length_dim = mtf.Dimension("length", sequence_length)
        feature_shape = mtf.Shape([outer_batch_dim, batch_dim, length_dim])

        mtf_features = {}
        for key, x in features.items():
            x = tf.to_int32(x)
            x = tf.reshape(x, [
                outer_batch_size, batch_size // outer_batch_size,
                sequence_length
            ])
            if not use_tpu:
                x = tf.Print(x, [x],
                             "import feature %s" % key,
                             summarize=1000,
                             first_n=1)
            mtf_features[key] = mtf.import_fully_replicated(mesh,
                                                            x,
                                                            feature_shape,
                                                            name=key)

        if mode == tf.estimator.ModeKeys.PREDICT:
            inputs = mtf_features["inputs"]
            inputs = mtf.reshape(
                inputs,
                mtf.Shape([
                    mtf.Dimension("batch", batch_size),
                    mtf.Dimension("length", sequence_length)
                ]))
            if isinstance(transformer_model, transformer.Unitransformer):
                mtf_samples = transformer_model.sample_autoregressive(
                    inputs, variable_dtype=get_variable_dtype())
            elif isinstance(
                    transformer_model,
                (transformer.Bitransformer, transformer.StudentTeacher)):
                mtf_samples = transformer_model.decode(
                    inputs, variable_dtype=get_variable_dtype())
            else:
                raise ValueError("unrecognized class")
            mtf_samples = mtf.anonymize(mtf_samples)
            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)
            outputs = lowering.export_to_tf_tensor(mtf_samples)
            predictions = {"outputs": outputs}
            return tpu_estimator.TPUEstimatorSpec(
                mode=tf.estimator.ModeKeys.PREDICT,
                predictions=predictions,
                prediction_hooks=[mtf.MtfRestoreHook(lowering)])

        elif mode == tf.estimator.ModeKeys.EVAL:
            raise NotImplementedError("We don't expect to use mode == eval.")

        else:
            assert mode == tf.estimator.ModeKeys.TRAIN
            num_microbatches = serialize_num_microbatches(
                batch_dim, length_dim, mesh_shape, layout_rules)

            def model_fn(mtf_features):
                """The kind of function we need for mtf.serialize_training_step.

        Args:
          mtf_features: a dictionary
        Returns:
          a dictionary
        """
                targets = mtf_features["targets"]
                if model_type == "lm":
                    _, _, length_dim = targets.shape
                    inputs = mtf.shift(targets,
                                       offset=1,
                                       dim=length_dim,
                                       wrap=False)
                else:
                    inputs = mtf_features["inputs"]

                if isinstance(transformer_model, transformer.Unitransformer):
                    position_kwargs = dict(
                        sequence_id=mtf_features.get("targets_segmentation",
                                                     None),
                        position=mtf_features.get("targets_position", None),
                    )
                elif isinstance(transformer_model, transformer.Bitransformer
                                ) or model_type == "bi_student_teacher":
                    position_kwargs = dict(
                        encoder_sequence_id=mtf_features.get(
                            "inputs_segmentation", None),
                        decoder_sequence_id=mtf_features.get(
                            "targets_segmentation", None),
                        encoder_position=mtf_features.get(
                            "inputs_position", None),
                        decoder_position=mtf_features.get(
                            "targets_position", None),
                    )
                else:
                    raise ValueError("unrecognized class")

                logits, loss = transformer_model.call_simple(
                    inputs=inputs,
                    targets=targets,
                    compute_loss=True,
                    mode=mode,
                    variable_dtype=get_variable_dtype(),
                    **position_kwargs)
                if num_microbatches > 1:
                    loss /= float(num_microbatches)
                del logits
                return {"loss": loss}

            if num_microbatches > 1:
                var_grads, loss_dict = mtf.serialize_training_step(
                    mtf_features, model_fn, batch_dim, num_microbatches)
            else:
                loss_dict = model_fn(mtf_features)
                var_grads = mtf.gradients(
                    [loss_dict["loss"]],
                    [v.outputs[0] for v in graph.trainable_variables])

            loss = loss_dict["loss"]

            if callable(learning_rate_schedule):
                # the following happens on CPU since TPU can't handle summaries.
                with mtf.utils.outside_all_rewrites():
                    learning_rate = learning_rate_schedule(
                        step=tf.train.get_global_step())
                    tf.summary.scalar("learning_rate", learning_rate)
            else:
                learning_rate = learning_rate_schedule

            update_ops = optimizer(learning_rate=learning_rate).apply_grads(
                var_grads, graph.trainable_variables)

            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)

            tf_loss = lowering.export_to_tf_tensor(loss)
            tf_loss = tf.to_float(tf_loss)
            if not use_tpu:
                tf_loss = tf.Print(
                    tf_loss, [tf_loss, tf.train.get_global_step()],
                    "step, tf_loss")

            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            train_op = tf.group(tf_update_ops)

            if hasattr(transformer_model, "initialize"):
                with mtf.utils.outside_all_rewrites():
                    transformer_model.initialize()

            with mtf.utils.outside_all_rewrites():
                # Copy master variables to slices. Must be called first.
                restore_hook = mtf.MtfRestoreHook(lowering)
                saver = tf.train.Saver(tf.global_variables(),
                                       sharded=True,
                                       max_to_keep=keep_checkpoint_max,
                                       keep_checkpoint_every_n_hours=2,
                                       defer_build=False,
                                       save_relative_paths=True)
                tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
                saver_listener = mtf.MtfCheckpointSaverListener(lowering)
                saver_hook = tf.train.CheckpointSaverHook(
                    model_dir,
                    save_steps=save_checkpoints_steps,
                    saver=saver,
                    listeners=[saver_listener])
                gin_config_saver_hook = gin.tf.GinConfigSaverHook(
                    model_dir, summarize_config=True)

                if use_tpu:
                    if tpu_summaries:
                        tf.summary.scalar("loss", tf_loss)
                        host_call = mtf.utils.create_host_call(model_dir)
                        mtf.utils.remove_summaries()
                    else:
                        host_call = None
                    return tpu_estimator.TPUEstimatorSpec(
                        mode=tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        host_call=host_call,
                        training_hooks=[
                            restore_hook,
                            saver_hook,
                            gin_config_saver_hook,
                        ])
                else:
                    return tf.estimator.EstimatorSpec(
                        tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        training_chief_hooks=[
                            restore_hook,
                            saver_hook,
                            gin_config_saver_hook,
                        ])
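
A sketch of the learning_rate_schedule handling above, with a hypothetical warmup schedule (warmup_schedule and its parameters are illustrative, not from the original): a callable schedule is evaluated at the global step and logged as a summary, while a plain number is used as-is.

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

global_step = tf.train.get_or_create_global_step()

def warmup_schedule(step, base_rate=1e-3, warmup_steps=1000):
  # Hypothetical schedule, used only for this sketch: linear warmup, then flat.
  frac = tf.minimum(tf.cast(step, tf.float32) / warmup_steps, 1.0)
  return base_rate * frac

learning_rate_schedule = warmup_schedule  # could also be a float such as 1e-3
if callable(learning_rate_schedule):
  learning_rate = learning_rate_schedule(step=global_step)
  tf.summary.scalar("learning_rate", learning_rate)
else:
  learning_rate = learning_rate_schedule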