def body(self, features):
        # Remove dropout if not training
        hparams = self._hparams
        ps_devices = self._ps_devices
        assert hparams.num_model_shards % len(ps_devices) == 0
        shards_per_device = hparams.num_model_shards // len(ps_devices)
        model_devices = [
            ps_devices[i // shards_per_device]
            for i in range(hparams.num_model_shards)
        ]
        print("model_devices = %s" % model_devices)
        mp = expert_utils.Parallelism(model_devices, reuse=False)
        vocab_size = self._problem_hparams.vocabulary["targets"].vocab_size
        # Squeeze out the channel and height dimensions.
        targets = features["targets_raw"]
        targets = tf.squeeze(targets, 3)
        targets = tf.squeeze(targets, 2)
        shifted_targets = common_layers.shift_right_2d(targets)
        # Bypass the symbol modality and use a different embedding on each shard.
        decoder_input = mp(_embedding, shifted_targets, vocab_size,
                           hparams.hidden_size)
        decoder_self_attention_bias = mp(
            common_attention.attention_bias_lower_triangle,
            tf.shape(targets)[1])
        if "targets_segmentation" in features:
            # "Packed" dataset - keep the examples from seeing each other.
            targets_segmentation = features["targets_segmentation"]
            targets_position = features["targets_position"]
            decoder_self_attention_bias = mp(
                tf.add, decoder_self_attention_bias,
                mp(common_attention.attention_bias_same_segment,
                   targets_segmentation, targets_segmentation))
        else:
            targets_position = None

        if hparams.pos == "timing":
            if targets_position is None:
                decoder_input = mp(common_attention.add_timing_signal_1d,
                                   decoder_input)
            else:
                decoder_input = mp(
                    common_attention.add_timing_signal_1d_given_position,
                    decoder_input, targets_position)

        decoder_input = mp(tf.nn.dropout, decoder_input,
                           1.0 - hparams.layer_prepostprocess_dropout)
        decoder_output, extra_loss = _super_stack(decoder_input,
                                                  decoder_self_attention_bias,
                                                  hparams, mp)
        # Bypass the symbol modality and compute logits directly.
        # We compute a different set of logits on each shard, and sum them.
        logits = mp(tf.layers.dense, decoder_output, vocab_size, name="logits")
        logits = common_layers.all_reduce_ring(logits, mp)
        logits = mp(tf.multiply, logits, mp.n**-0.5)
        # We now have identical logits on all shards.
        # Shard 0 gets returned to the estimator.
        logits_shard_0 = logits[0]
        logits_shard_0 = tf.expand_dims(logits_shard_0, 2)
        logits_shard_0 = tf.expand_dims(logits_shard_0, 3)
        # On each device, we compute the loss for a part of the batch.
        # This is faster than computing the whole loss on one shard.
        mp, logits = common_layers.reduce_by_device(mp, logits, lambda l: l[0])

        def _loss_for_shard(logits, targets, shard):
            if mp.n > 1:
                logits = common_layers.approximate_split(logits, mp.n,
                                                         0)[shard]
                targets = common_layers.approximate_split(targets, mp.n,
                                                          0)[shard]
            return common_layers.padded_cross_entropy(logits, targets,
                                                      hparams.label_smoothing)

        num, denom = mp(_loss_for_shard, logits, targets, range(mp.n))
        # override training loss so that it is not computed externally.
        losses = {"training": tf.add_n(num) / tf.add_n(denom)}
        if extra_loss is not None:
            losses["extra"] = extra_loss
        return logits_shard_0, losses
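
The shard-to-device mapping above is easy to sanity-check in isolation: with shards_per_device = num_model_shards // len(ps_devices), consecutive shards are packed onto the same parameter-server device. A minimal sketch with hypothetical device names and shard counts (not taken from the source):

# Hypothetical values, for illustration of the mapping used in body() above.
ps_devices = ["/job:ps/task:0/GPU:0", "/job:ps/task:1/GPU:0"]
num_model_shards = 4
assert num_model_shards % len(ps_devices) == 0
shards_per_device = num_model_shards // len(ps_devices)
model_devices = [ps_devices[i // shards_per_device]
                 for i in range(num_model_shards)]
# model_devices == ["/job:ps/task:0/GPU:0", "/job:ps/task:0/GPU:0",
#                   "/job:ps/task:1/GPU:0", "/job:ps/task:1/GPU:0"]
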
def data_parallelism(all_workers=False):
    """Over which devices do we split each training batch.

  In old-fashioned async mode, we split the batch over all GPUs on the
  current worker.

  In sync mode, we split the batch over all the parameter server GPUs.

  This function returns an expert_utils.Parallelism object, which can be used
  to build the model.  It is configured in a way that any variables created
  by `tf.get_variable` will be assigned to the parameter servers and shared
  between datashards.

  Args:
    all_workers: whether the devices are all async workers or just this one.

  Returns:
    an expert_utils.Parallelism.
  """
    def _replica_device_setter(worker_device):
        if FLAGS.ps_replicas == 0:
            return worker_device
        return tf.train.replica_device_setter(
            worker_device=worker_device,
            ps_tasks=FLAGS.ps_replicas,
            ps_device=(FLAGS.ps_job + "/GPU:0"
                       if FLAGS.ps_gpu > 0 else FLAGS.ps_job))

    if FLAGS.schedule == "train_and_evaluate":
        assert not FLAGS.sync
        datashard_devices = [
            "gpu:%d" % d for d in _gpu_order(FLAGS.worker_gpu)
        ]
        if FLAGS.locally_shard_to_cpu or FLAGS.worker_gpu < 1:
            datashard_devices += ["cpu:0"]
        caching_devices = None
    elif FLAGS.sync:
        assert FLAGS.ps_replicas > 0
        datashard_devices = [
            _replica_device_setter(d)
            for d in ps_devices(all_workers=all_workers)
        ]
        if FLAGS.ps_gpu > 0 and FLAGS.ps_replicas > 1:
            caching_devices = [
                FLAGS.ps_job + "/task:%d/cpu:0" % d
                for (d, _) in _ps_gpus(all_workers=all_workers)
            ]
        else:
            caching_devices = None
    else:
        # Old-fashioned async mode - compute on the worker.
        if FLAGS.worker_gpu > 1:
            datashard_devices = [
                _replica_device_setter(FLAGS.worker_job + "/GPU:%d" % d)
                for d in _gpu_order(FLAGS.worker_gpu)
            ]
            caching_devices = [FLAGS.worker_job + "/GPU:0"] * FLAGS.worker_gpu
        else:
            datashard_devices = [_replica_device_setter(FLAGS.worker_job)]
            caching_devices = None
    tf.logging.info("datashard_devices: %s", datashard_devices)
    tf.logging.info("caching_devices: %s", caching_devices)
    return eu.Parallelism(datashard_devices,
                          reuse=True,
                          caching_devices=caching_devices,
                          daisy_chain_variables=FLAGS.daisy_chain_variables)
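
The eu.Parallelism returned here is called like mp(...) in the body() examples: the first argument is a function, list arguments of length n are distributed one element per datashard, and scalar arguments are broadcast to every shard. A hedged usage sketch, with a hypothetical input tensor and layer size (assumed behavior, inferred from the calls above):

import tensorflow as tf

dp = data_parallelism()                          # expert_utils.Parallelism over datashards
inputs = tf.zeros([64, 128])                     # hypothetical batch of features
sharded_inputs = tf.split(inputs, dp.n, axis=0)  # one slice per datashard
sharded_hidden = dp(tf.layers.dense, sharded_inputs, 256, name="hidden")
outputs = tf.concat(sharded_hidden, axis=0)      # reassemble the full batch
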
    def body(self, features):
        hparams = self._hparams
        ps_devices = self._ps_devices
        single_device = (len(ps_devices) == 1)
        assert hparams.num_model_shards % len(ps_devices) == 0
        shards_per_device = hparams.num_model_shards // len(ps_devices)
        model_devices = [
            ps_devices[i // shards_per_device]
            for i in range(hparams.num_model_shards)
        ]
        print("model_devices = %s" % model_devices)
        mp = expert_utils.Parallelism(model_devices, reuse=False)
        targets_vocab_size = self._problem_hparams.vocabulary[
            "targets"].vocab_size
        # Squeeze out the channel and height dimensions.
        targets = tf.squeeze(features["targets_raw"], [2, 3])
        targets_embedding_var = mp(
            tf.get_variable,
            "embedding", [[targets_vocab_size, hparams.hidden_size]] * mp.n,
            initializer=tf.random_normal_initializer(
                0.0, hparams.hidden_size**-0.5))
        shifted_targets = common_layers.shift_right_2d(targets)
        # Bypass the symbol modality and use a different embedding on each shard.
        if single_device:
            targets_embedding_var_combined = tf.concat(targets_embedding_var,
                                                       1)
            decoder_input_combined = common_layers.embedding(
                shifted_targets,
                targets_vocab_size,
                hparams.hidden_size * mp.n,
                multiplier=hparams.hidden_size**0.5,
                embedding_var=targets_embedding_var_combined,
            )
            decoder_input = tf.split(decoder_input_combined, mp.n, axis=2)
        else:
            targets_embedding_var_combined = None
            decoder_input = mp(
                common_layers.embedding,
                shifted_targets,
                targets_vocab_size,
                hparams.hidden_size,
                multiplier=hparams.hidden_size**0.5,
                embedding_var=targets_embedding_var,
            )
        decoder_self_attention_bias = mp(
            common_attention.attention_bias_lower_triangle,
            tf.shape(targets)[1])
        if "targets_segmentation" in features:
            # "Packed" dataset - keep the examples from seeing each other.
            targets_segmentation = features["targets_segmentation"]
            targets_position = features["targets_position"]
            decoder_self_attention_bias = mp(
                tf.add, decoder_self_attention_bias,
                mp(common_attention.attention_bias_same_segment,
                   targets_segmentation, targets_segmentation))
            decoder_input = mp(
                common_attention.add_timing_signal_1d_given_position,
                decoder_input, targets_position)
        else:
            targets_position = None
            decoder_self_attention_bias = mp(
                common_attention.attention_bias_lower_triangle,
                tf.shape(targets)[1])
            decoder_input = mp(common_attention.add_timing_signal_1d,
                               decoder_input)

        if self.has_input:
            inputs = tf.squeeze(features["inputs_raw"], [2, 3])
            inputs_vocab_size = self._problem_hparams.vocabulary[
                "inputs"].vocab_size
            # share everything for now
            share_inputs_and_targets_embedding = True
            if share_inputs_and_targets_embedding:
                assert inputs_vocab_size == targets_vocab_size
                inputs_embedding_var = targets_embedding_var
                inputs_embedding_var_combined = targets_embedding_var_combined
            if single_device:
                encoder_input_combined = common_layers.embedding(
                    inputs,
                    inputs_vocab_size,
                    hparams.hidden_size * mp.n,
                    multiplier=hparams.hidden_size**0.5,
                    embedding_var=inputs_embedding_var_combined,
                )
                encoder_input = tf.split(encoder_input_combined, mp.n, axis=2)
            else:
                encoder_input = mp(
                    common_layers.embedding,
                    inputs,
                    inputs_vocab_size,
                    hparams.hidden_size,
                    multiplier=hparams.hidden_size**0.5,
                    embedding_var=inputs_embedding_var,
                )
            if "inputs_segmentation" in features:
                # "Packed" dataset - keep the examples from seeing each other.
                inputs_segmentation = features["inputs_segmentation"]
                inputs_position = features["inputs_position"]
                encoder_self_attention_bias = mp(
                    common_attention.attention_bias_same_segment,
                    inputs_segmentation, inputs_segmentation)
                encoder_decoder_attention_bias = mp(
                    common_attention.attention_bias_same_segment,
                    targets_segmentation, inputs_segmentation)
                encoder_input = mp(
                    common_attention.add_timing_signal_1d_given_position,
                    encoder_input, inputs_position)
            else:
                encoder_padding = tf.to_float(tf.equal(inputs, 0))
                ignore_padding = common_attention.attention_bias_ignore_padding(
                    encoder_padding)
                encoder_self_attention_bias = ignore_padding
                encoder_decoder_attention_bias = ignore_padding
                inputs_position = None
                encoder_input = mp(common_attention.add_timing_signal_1d,
                                   encoder_input)

            # encoder stack here
            with tf.variable_scope("encoder"):
                encoder_input = mp(tf.nn.dropout, encoder_input,
                                   1.0 - hparams.layer_prepostprocess_dropout)
                encoder_output = _layer_stack(mp, encoder_input,
                                              encoder_self_attention_bias,
                                              hparams.encoder_layers, hparams)
        else:
            encoder_decoder_attention_bias = None
            encoder_output = None

        with tf.variable_scope("decoder"):
            decoder_input = mp(tf.nn.dropout, decoder_input,
                               1.0 - hparams.layer_prepostprocess_dropout)
            decoder_output = _layer_stack(
                mp,
                decoder_input,
                decoder_self_attention_bias,
                layers=hparams.decoder_layers,
                hparams=hparams,
                encoder_output=encoder_output,
                encoder_decoder_attention_bias=encoder_decoder_attention_bias)

        # Bypass the symbol modality and compute logits directly.
        # We compute a different set of logits on each shard, and sum them.
        # Share the weights with the target embedding.
        output_var = targets_embedding_var
        output_var_combined = targets_embedding_var_combined
        if single_device:
            decoder_output = tf.concat(decoder_output, 2)
            logits = tf.tensordot(decoder_output, output_var_combined,
                                  [[2], [1]])
            num, denom = common_layers.padded_cross_entropy(
                logits, targets, hparams.label_smoothing)
            training_loss = num / denom
        else:
            logits = mp(tf.tensordot, decoder_output, output_var,
                        [[[2], [1]]] * mp.n)
            logits = expert_utils.all_reduce_ring(logits, mp)
            # On each device, we compute the loss for a part of the batch.
            # This is faster than computing the whole loss on one shard.
            mp, logits = expert_utils.reduce_by_device(mp, logits,
                                                       lambda l: l[0])

            def _loss_for_shard(logits, targets, shard):
                logits = common_layers.approximate_split(logits, mp.n,
                                                         0)[shard]
                targets = common_layers.approximate_split(targets, mp.n,
                                                          0)[shard]
                return common_layers.padded_cross_entropy(
                    logits, targets, hparams.label_smoothing)

            num, denom = mp(_loss_for_shard, logits, targets, range(mp.n))
            training_loss = tf.add_n(num) / tf.add_n(denom)
            logits = logits[0]
        logits = tf.expand_dims(tf.expand_dims(logits, 2), 3)
        # override training loss so that it is not computed externally.
        losses = {"training": training_loss}
        return logits, losses
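
In the multi-device branch above, each shard holds a [vocab_size, hidden_size] slice of the shared embedding and produces partial logits by contracting its hidden slice with tf.tensordot; the partial logits are then summed across shards by all_reduce_ring. A toy-shape sketch of that contraction (all sizes hypothetical):

import tensorflow as tf

decoder_output_shard = tf.zeros([8, 20, 64])     # [batch, length, hidden per shard]
embedding_shard = tf.random_normal([1000, 64])   # [vocab, hidden per shard]
partial_logits = tf.tensordot(decoder_output_shard, embedding_shard, [[2], [1]])
# partial_logits has shape [8, 20, 1000]; summing these partial logits across
# shards (all_reduce_ring above) is equivalent to contracting the full,
# concatenated hidden dimension.
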
Example #4
def data_parallelism(hparams, all_workers=False):
    """Over which devices do we split each training batch.

  In old-fashioned async mode, we split the batch over all GPUs on the
  current worker.

  In sync mode, we split the batch over all the parameter server GPUs.

  This function returns an expert_utils.Parallelism object, which can be used
  to build the model.  It is configured in a way that any variables created
  by `tf.get_variable` will be assigned to the parameter servers and shared
  between datashards.

  Args:
    hparams: model hyperparameters (an HParams object).
    all_workers: whether the devices are all async workers or just this one.

  Returns:
    an expert_utils.Parallelism.
  """
    def _replica_device_setter(worker_device):
        if FLAGS.ps_replicas == 0:
            return worker_device
        return tf.train.replica_device_setter(
            worker_device=worker_device,
            ps_tasks=FLAGS.ps_replicas,
            ps_device=(FLAGS.ps_job + "/GPU:0"
                       if FLAGS.ps_gpu > 0 else FLAGS.ps_job))

    if FLAGS.schedule in ["train_and_evaluate", "continuous_train_and_eval"]:
        assert not FLAGS.sync
        tf.logging.warn(
            "Schedule=%s. Assuming that training is running on a single machine.",
            FLAGS.schedule)
        datashard_devices = [
            "gpu:%d" % d for d in _gpu_order(FLAGS.worker_gpu)
        ]
        if FLAGS.locally_shard_to_cpu or FLAGS.worker_gpu < 1:
            datashard_devices += ["cpu:0"]
        caching_devices = None
    elif FLAGS.sync and FLAGS.ps_replicas > 0:
        # compute on ps
        datashard_devices = [
            _replica_device_setter(d)
            for d in ps_devices(all_workers=all_workers)
        ]
        if FLAGS.ps_gpu > 0 and FLAGS.ps_replicas > 1:
            caching_devices = [
                FLAGS.ps_job + "/task:%d/cpu:0" % d
                for (d, _) in _ps_gpus(all_workers=all_workers)
            ]
        else:
            caching_devices = None
    else:
        # compute on worker - this is either a single-worker setup or asynchronous
        # with parameter servers.
        if FLAGS.worker_gpu > 1:
            datashard_devices = [
                _replica_device_setter(FLAGS.worker_job + "/GPU:%d" % d)
                for d in _gpu_order(FLAGS.worker_gpu)
            ]
            caching_devices = [FLAGS.worker_job + "/GPU:0"] * FLAGS.worker_gpu
        else:
            datashard_devices = [_replica_device_setter(FLAGS.worker_job)]
            caching_devices = None
    tf.logging.info("datashard_devices: %s", datashard_devices)
    tf.logging.info("caching_devices: %s", caching_devices)
    return eu.Parallelism(datashard_devices,
                          caching_devices=caching_devices,
                          daisy_chain_variables=hparams.daisy_chain_variables)
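
_replica_device_setter wraps tf.train.replica_device_setter, which places variables on the parameter-server job (round-robin over ps_tasks) while leaving other ops on the given worker device. A hedged standalone illustration with hypothetical device strings and shapes:

import tensorflow as tf

setter = tf.train.replica_device_setter(
    worker_device="/job:worker/task:0/GPU:0",
    ps_tasks=2,
    ps_device="/job:ps")
with tf.device(setter):
    w = tf.get_variable("w", [1024, 1024])   # variable: assigned to a ps task
    y = tf.matmul(tf.zeros([8, 1024]), w)    # op: stays on the worker GPU
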
Example #5
def data_parallelism(daisy_chain_variables=True,
                     all_workers=False,
                     ps_replicas=0,
                     ps_job="/job:ps",
                     ps_gpu=0,
                     schedule="continuous_train_and_eval",
                     sync=False,
                     worker_gpu=1,
                     worker_replicas=1,
                     worker_id=0,
                     gpu_order="",
                     locally_shard_to_cpu=False,
                     worker_job="/job:localhost",
                     no_data_parallelism=False):
    """See data_parallelism_from_flags."""
    tf.logging.info("schedule=%s" % schedule)
    tf.logging.info("worker_gpu=%s" % worker_gpu)
    tf.logging.info("sync=%s" % sync)

    def _ps_replicas(all_workers=False):
        if all_workers:
            return list(range(ps_replicas))
        # Worker K uses ps replicas K*n through K*n + n - 1,
        # where n = ps_replicas // worker_replicas.
        num_replicas = ps_replicas // worker_replicas
        return [d + worker_id * num_replicas for d in range(num_replicas)]

    def _gpu_order(num_gpus):
        if gpu_order:
            ret = [int(s) for s in gpu_order.split(" ")]
            if len(ret) == num_gpus:
                return ret
        return list(range(num_gpus))

    def _ps_gpus(all_workers=False):
        ps_gpus = []
        for d in _ps_replicas(all_workers=all_workers):
            ps_gpus.extend([(d, gpu) for gpu in _gpu_order(ps_gpu)])
        return ps_gpus

    def ps_devices(all_workers=False):
        """List of ps devices (where to put the experts).

    Args:
      all_workers: whether the list is for all async workers or just this one.

    Returns:
      a list of device names
    """
        if ps_replicas > 0:
            if ps_gpu > 0:
                return [
                    ps_job + "/task:%d/GPU:%d" % (d, gpu)
                    for (d, gpu) in _ps_gpus(all_workers=all_workers)
                ]
            else:
                return [
                    ps_job + "/task:%d" % d
                    for d in _ps_replicas(all_workers=all_workers)
                ]
        else:
            if worker_gpu > 0:
                return ["gpu:%d" % d for d in _gpu_order(worker_gpu)]
            else:
                return [""]

    def _replica_device_setter(worker_device):
        if ps_replicas == 0:
            return worker_device
        return tf.train.replica_device_setter(
            worker_device=worker_device,
            ps_tasks=ps_replicas,
            ps_device=ps_job + "/GPU:0" if ps_gpu > 0 else ps_job)

    is_single_machine = ps_replicas == 0 and worker_replicas == 1

    if no_data_parallelism:
        datashard_devices = [""]
        caching_devices = None
    elif is_single_machine:
        assert not sync
        tf.logging.warn(
            "Schedule=%s. Assuming that training is running on a single machine.",
            schedule)
        datashard_devices = ["gpu:%d" % d for d in _gpu_order(worker_gpu)]
        if locally_shard_to_cpu or worker_gpu < 1:
            datashard_devices += ["cpu:0"]
        caching_devices = None
    elif sync and ps_replicas > 0:
        # compute on ps
        datashard_devices = [
            _replica_device_setter(d)
            for d in ps_devices(all_workers=all_workers)
        ]
        if ps_gpu > 0 and ps_replicas > 1:
            caching_devices = [
                ps_job + "/task:%d/cpu:0" % d
                for (d, _) in _ps_gpus(all_workers=all_workers)
            ]
        else:
            caching_devices = None
    else:
        # compute on worker - this is either a single-worker setup or asynchronous
        # with parameter servers.
        if worker_gpu > 1:
            datashard_devices = [
                _replica_device_setter(worker_job + "/GPU:%d" % d)
                for d in _gpu_order(worker_gpu)
            ]
            caching_devices = None
        else:
            datashard_devices = [_replica_device_setter(worker_job)]
            caching_devices = None
    tf.logging.info("datashard_devices: %s", datashard_devices)
    tf.logging.info("caching_devices: %s", caching_devices)
    tf.logging.info("ps_devices: %s", ps_devices(all_workers=all_workers))
    return eu.Parallelism(datashard_devices,
                          caching_devices=caching_devices,
                          daisy_chain_variables=daisy_chain_variables,
                          ps_devices=ps_devices(all_workers=all_workers))
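
The replica and GPU arithmetic in _ps_replicas and _ps_gpus is easiest to see with concrete numbers. A hedged worked example (all values hypothetical): with ps_replicas=4, worker_replicas=2 and ps_gpu=2, worker 1 is assigned ps replicas [2, 3], so its ps_devices() are the four task/GPU pairs shown below.

# Hypothetical configuration: ps_replicas=4, worker_replicas=2, worker_id=1, ps_gpu=2.
# _ps_replicas(): 4 // 2 = 2 replicas per worker  -> worker 1 gets [2, 3]
# _ps_gpus():     [(2, 0), (2, 1), (3, 0), (3, 1)]
# ps_devices():   ["/job:ps/task:2/GPU:0", "/job:ps/task:2/GPU:1",
#                  "/job:ps/task:3/GPU:0", "/job:ps/task:3/GPU:1"]
dp = data_parallelism(ps_replicas=4, worker_replicas=2, worker_id=1,
                      ps_gpu=2, sync=True, worker_gpu=1)
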
Example #6
    def estimator_model_fn(cls,
                           hparams,
                           features,
                           labels,
                           mode,
                           config=None,
                           params=None,
                           decode_hparams=None,
                           use_tpu=True):
        """Model fn for Estimator.

    Args:
      hparams: HParams, model hyperparameters
      features: dict<str name, Tensor feature>
      labels: Tensor
      mode: tf.estimator.ModeKeys
      config: RunConfig; if passed, should have t2t_device_info dict
      params: dict, may include batch_size
      decode_hparams: HParams, used when mode == PREDICT.
      use_tpu: bool, whether using TPU

    Returns:
      a TPUEstimatorSpec if use_tpu is set, otherwise an EstimatorSpec
    """
        tf.logging.warning(
            "T2TModel.estimator_model_fn implements a subset of "
            "model_builder.model_fn and is currently only used "
            "in tpu_trainer.")
        _create_dummy_vars()
        hparams = copy.deepcopy(hparams)
        hparams.use_tpu = use_tpu
        problem = hparams.problem_instances[0]

        # Instantiate model
        data_parallelism = (eu.Parallelism([""]) if use_tpu else
                            _create_data_parallelism(**config.t2t_device_info))
        model = cls(hparams, mode, data_parallelism=data_parallelism)

        # PREDICT mode
        if mode == tf.estimator.ModeKeys.PREDICT:
            assert not use_tpu
            assert decode_hparams is not None
            return model.estimator_spec_predict(features, decode_hparams)

        # TRAIN and EVAL modes
        logits, losses_dict = model(features)  # pylint: disable=not-callable

        # Set known shapes
        # TODO(rsepassi): Add support for variable lengths and batch sizes
        shape = logits.get_shape().as_list()
        if shape[0] is None:
            shape[0] = _get_batch_size(params, hparams, config)
        if shape[1] is None:
            shape[1] = hparams.max_length
        logits.set_shape(shape)

        # Accumulate losses
        assert "training" in losses_dict
        loss = sum(losses_dict.values())

        # EVAL mode
        if mode == tf.estimator.ModeKeys.EVAL:
            return model.estimator_spec_eval(features,
                                             logits,
                                             labels,
                                             loss,
                                             problem,
                                             hparams,
                                             use_tpu=use_tpu)

        # TRAIN mode
        assert mode == tf.estimator.ModeKeys.TRAIN
        return model.estimator_spec_train(loss, use_tpu=use_tpu)
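
estimator_model_fn takes hparams, decode_hparams and use_tpu in addition to the standard (features, labels, mode, params, config) arguments, so it needs a small adapter before it can be handed to tf.estimator.Estimator. A minimal sketch of one possible wrapper (the helper name and the model/hparams objects are assumptions, not from the source):

def make_model_fn(model_cls, hparams, decode_hparams=None, use_tpu=False):
    """Hypothetical adapter from estimator_model_fn to the Estimator model_fn interface."""
    def model_fn(features, labels, mode, params=None, config=None):
        return model_cls.estimator_model_fn(
            hparams, features, labels, mode,
            config=config, params=params,
            decode_hparams=decode_hparams, use_tpu=use_tpu)
    return model_fn

# Usage sketch (MyT2TModel and my_hparams are hypothetical):
# estimator = tf.estimator.Estimator(
#     model_fn=make_model_fn(MyT2TModel, my_hparams), model_dir="/tmp/model")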