def body(self, features):
  """Builds the model-parallel decoder-only LM and computes its loss.

  The model is split into `hparams.num_model_shards` shards, assigned in
  contiguous groups to the devices in `self._ps_devices`. The symbol modality
  is bypassed: each shard has its own embedding (via `_embedding`) and its own
  dense logits layer. The per-shard logits are summed with a ring all-reduce
  and scaled by `mp.n**-0.5`. The training loss is computed here so it is not
  recomputed externally.

  Args:
    features: dict of tensors. Reads "targets_raw" (dims 2 and 3 are squeezed
      away, so they must be size 1). For packed datasets it also reads
      "targets_segmentation" and "targets_position".

  Returns:
    A tuple `(logits, losses)`:
      logits: shard 0's logits with two singleton dims inserted at axes 2 and 3.
      losses: dict with a "training" entry. It also has an "extra" entry when
        `_super_stack` returns a non-None extra loss.
  """
  hparams = self._hparams
  ps_devices = self._ps_devices
  assert hparams.num_model_shards % len(ps_devices) == 0
  shards_per_device = hparams.num_model_shards // len(ps_devices)
  # Consecutive model shards share a device. Use the builtin `range` (not the
  # Python-2-only `xrange`) so this works on Python 3.
  model_devices = [
      ps_devices[i // shards_per_device]
      for i in range(hparams.num_model_shards)
  ]
  print("model_devices = %s" % model_devices)
  mp = expert_utils.Parallelism(model_devices, reuse=False)

  vocab_size = self._problem_hparams.vocabulary["targets"].vocab_size
  # squeeze out channels, heights
  targets = features["targets_raw"]
  targets = tf.squeeze(targets, 3)
  targets = tf.squeeze(targets, 2)
  shifted_targets = common_layers.shift_right_2d(targets)
  # Bypass the symbol modality and use a different embedding on each shard.
  decoder_input = mp(_embedding, shifted_targets, vocab_size,
                     hparams.hidden_size)
  decoder_self_attention_bias = mp(
      common_attention.attention_bias_lower_triangle, tf.shape(targets)[1])
  if "targets_segmentation" in features:
    # "Packed" dataset - keep the examples from seeing each other.
    targets_segmentation = features["targets_segmentation"]
    targets_position = features["targets_position"]
    decoder_self_attention_bias = mp(
        tf.add, decoder_self_attention_bias,
        mp(common_attention.attention_bias_same_segment,
           targets_segmentation, targets_segmentation))
  else:
    targets_position = None

  if hparams.pos == "timing":
    if targets_position is None:
      decoder_input = mp(common_attention.add_timing_signal_1d, decoder_input)
    else:
      # Packed examples restart positions, so use the explicit positions.
      decoder_input = mp(
          common_attention.add_timing_signal_1d_given_position,
          decoder_input, targets_position)

  # keep_prob = 1 - dropout rate.
  decoder_input = mp(tf.nn.dropout, decoder_input,
                     1.0 - hparams.layer_prepostprocess_dropout)
  decoder_output, extra_loss = _super_stack(
      decoder_input, decoder_self_attention_bias, hparams, mp)

  # Bypass the symbol modality and compute logits directly.
  # We compute a different set of logits on each shard, and sum them.
  logits = mp(tf.layers.dense, decoder_output, vocab_size, name="logits")
  logits = common_layers.all_reduce_ring(logits, mp)
  logits = mp(tf.multiply, logits, mp.n**-0.5)
  # We now have identical logits on all shards.
  # Shard 0 gets returned to the estimator.
  logits_shard_0 = logits[0]
  logits_shard_0 = tf.expand_dims(logits_shard_0, 2)
  logits_shard_0 = tf.expand_dims(logits_shard_0, 3)

  # On each device, we compute the loss for a part of the batch.
  # This is faster than computing the whole loss on one shard.
  mp, logits = common_layers.reduce_by_device(mp, logits, lambda l: l[0])

  def _loss_for_shard(logits, targets, shard):
    # `mp` is read at call time, i.e. the per-device Parallelism above.
    if mp.n > 1:
      logits = common_layers.approximate_split(logits, mp.n, 0)[shard]
      targets = common_layers.approximate_split(targets, mp.n, 0)[shard]
    return common_layers.padded_cross_entropy(logits, targets,
                                              hparams.label_smoothing)

  # Parallelism only distributes `list` arguments element-wise. A bare
  # Python 3 `range` would be broadcast whole to every shard, so pass a list.
  num, denom = mp(_loss_for_shard, logits, targets, list(range(mp.n)))
  # override training loss so that it is not computed externally.
  losses = {"training": tf.add_n(num) / tf.add_n(denom)}
  if extra_loss is not None:
    losses["extra"] = extra_loss
  return logits_shard_0, losses
def data_parallelism(all_workers=False):
  """Decides which devices each training batch is split across.

  Placement depends on FLAGS:
    * schedule "train_and_evaluate": shard over this machine's GPUs, plus
      "cpu:0" when `locally_shard_to_cpu` is set or no worker GPUs exist.
      Sync mode is not allowed here.
    * sync mode: shard over the parameter-server devices. At least one ps
      replica is required.
    * otherwise (old-fashioned async): shard over the current worker's GPUs,
      or the worker job itself when it has at most one GPU.

  When parameter servers exist, each datashard device is wrapped in
  `tf.train.replica_device_setter`, so variables created by
  `tf.get_variable` are placed on the parameter servers and shared between
  datashards.

  Args:
    all_workers: whether the devices are all async workers or just this one.

  Returns:
    an expert_utils.Parallelism (built with reuse=True).
  """

  def _replica_device_setter(worker_device):
    # No parameter servers: keep the device string unchanged.
    if FLAGS.ps_replicas == 0:
      return worker_device
    if FLAGS.ps_gpu > 0:
      ps_target = FLAGS.ps_job + "/GPU:0"
    else:
      ps_target = FLAGS.ps_job
    return tf.train.replica_device_setter(
        worker_device=worker_device,
        ps_tasks=FLAGS.ps_replicas,
        ps_device=ps_target)

  caching_devices = None
  if FLAGS.schedule == "train_and_evaluate":
    assert not FLAGS.sync
    datashard_devices = ["gpu:%d" % gpu for gpu in _gpu_order(FLAGS.worker_gpu)]
    if FLAGS.locally_shard_to_cpu or FLAGS.worker_gpu < 1:
      datashard_devices.append("cpu:0")
  elif FLAGS.sync:
    assert FLAGS.ps_replicas > 0
    datashard_devices = [
        _replica_device_setter(dev)
        for dev in ps_devices(all_workers=all_workers)
    ]
    if FLAGS.ps_gpu > 0 and FLAGS.ps_replicas > 1:
      caching_devices = [
          FLAGS.ps_job + "/task:%d/cpu:0" % task
          for (task, _) in _ps_gpus(all_workers=all_workers)
      ]
  elif FLAGS.worker_gpu > 1:
    # old fashioned async - compute on worker, one shard per GPU.
    datashard_devices = [
        _replica_device_setter(FLAGS.worker_job + "/GPU:%d" % gpu)
        for gpu in _gpu_order(FLAGS.worker_gpu)
    ]
    caching_devices = [FLAGS.worker_job + "/GPU:0"] * FLAGS.worker_gpu
  else:
    # old fashioned async - single device on the worker.
    datashard_devices = [_replica_device_setter(FLAGS.worker_job)]

  tf.logging.info("datashard_devices: %s", datashard_devices)
  tf.logging.info("caching_devices: %s", caching_devices)
  return eu.Parallelism(
      datashard_devices,
      reuse=True,
      caching_devices=caching_devices,
      daisy_chain_variables=FLAGS.daisy_chain_variables)
def body(self, features):
  """Builds the model-parallel (optional) encoder + decoder and its loss.

  The model is split into `hparams.num_model_shards` shards, assigned in
  contiguous groups to the devices in `self._ps_devices`. Each shard owns one
  `[targets_vocab_size, hidden_size]` slice of a shared "embedding" variable.
  That variable is used for:
    * the target embedding,
    * the input embedding, when `self.has_input`,
    * the output projection (tied weights).
  When every shard is on a single device, the slices are concatenated and
  handled as one tensor. Otherwise, per-shard logits are summed with a ring
  all-reduce and the loss is split across devices.

  Args:
    features: dict of tensors. Reads "targets_raw" and, with inputs,
      "inputs_raw". Dims 2 and 3 are squeezed away, so they must be size 1.
      For packed datasets it also reads the "*_segmentation" and
      "*_position" features.

  Returns:
    A tuple `(logits, losses)`:
      logits: logits with two singleton dims inserted at axes 2 and 3.
      losses: `{"training": training_loss}`.
  """
  hparams = self._hparams
  ps_devices = self._ps_devices
  single_device = (len(ps_devices) == 1)
  assert hparams.num_model_shards % len(ps_devices) == 0
  shards_per_device = hparams.num_model_shards // len(ps_devices)
  model_devices = [
      ps_devices[i // shards_per_device]
      for i in range(hparams.num_model_shards)
  ]
  print("model_devices = %s" % model_devices)
  mp = expert_utils.Parallelism(model_devices, reuse=False)
  targets_vocab_size = self._problem_hparams.vocabulary[
      "targets"].vocab_size
  # squeeze out channels, heights
  targets = tf.squeeze(features["targets_raw"], [2, 3])
  # One [vocab, hidden_size] embedding slice per model shard.
  targets_embedding_var = mp(
      tf.get_variable, "embedding",
      [[targets_vocab_size, hparams.hidden_size]] * mp.n,
      initializer=tf.random_normal_initializer(
          0.0, hparams.hidden_size**-0.5))

  shifted_targets = common_layers.shift_right_2d(targets)
  # Bypass the symbol modality and use a different embedding on each shard.
  if single_device:
    # One lookup into the concatenated table, then split back per shard.
    targets_embedding_var_combined = tf.concat(targets_embedding_var, 1)
    decoder_input_combined = common_layers.embedding(
        shifted_targets, targets_vocab_size,
        hparams.hidden_size * mp.n,
        multiplier=hparams.hidden_size**0.5,
        embedding_var=targets_embedding_var_combined,
    )
    decoder_input = tf.split(decoder_input_combined, mp.n, axis=2)
  else:
    targets_embedding_var_combined = None
    decoder_input = mp(
        common_layers.embedding, shifted_targets, targets_vocab_size,
        hparams.hidden_size,
        multiplier=hparams.hidden_size**0.5,
        embedding_var=targets_embedding_var,
    )

  # Causal self-attention mask. It is built once and shared by both the
  # packed and non-packed paths below.
  decoder_self_attention_bias = mp(
      common_attention.attention_bias_lower_triangle, tf.shape(targets)[1])
  if "targets_segmentation" in features:
    # "Packed" dataset - keep the examples from seeing each other.
    targets_segmentation = features["targets_segmentation"]
    targets_position = features["targets_position"]
    decoder_self_attention_bias = mp(
        tf.add, decoder_self_attention_bias,
        mp(common_attention.attention_bias_same_segment,
           targets_segmentation, targets_segmentation))
    decoder_input = mp(
        common_attention.add_timing_signal_1d_given_position,
        decoder_input, targets_position)
  else:
    decoder_input = mp(common_attention.add_timing_signal_1d, decoder_input)

  if self.has_input:
    inputs = tf.squeeze(features["inputs_raw"], [2, 3])
    inputs_vocab_size = self._problem_hparams.vocabulary["inputs"].vocab_size
    # share everything for now
    share_inputs_and_targets_embedding = True
    if share_inputs_and_targets_embedding:
      assert inputs_vocab_size == targets_vocab_size
      inputs_embedding_var = targets_embedding_var
      inputs_embedding_var_combined = targets_embedding_var_combined
    if single_device:
      encoder_input_combined = common_layers.embedding(
          inputs, inputs_vocab_size,
          hparams.hidden_size * mp.n,
          multiplier=hparams.hidden_size**0.5,
          embedding_var=inputs_embedding_var_combined,
      )
      encoder_input = tf.split(encoder_input_combined, mp.n, axis=2)
    else:
      encoder_input = mp(
          common_layers.embedding, inputs, inputs_vocab_size,
          hparams.hidden_size,
          multiplier=hparams.hidden_size**0.5,
          embedding_var=inputs_embedding_var,
      )
    if "inputs_segmentation" in features:
      # "Packed" dataset - keep the examples from seeing each other.
      inputs_segmentation = features["inputs_segmentation"]
      inputs_position = features["inputs_position"]
      encoder_self_attention_bias = mp(
          common_attention.attention_bias_same_segment,
          inputs_segmentation, inputs_segmentation)
      # Packed inputs need packed targets. Read the feature directly so a
      # missing "targets_segmentation" fails with a clear KeyError rather
      # than an unbound local.
      encoder_decoder_attention_bias = mp(
          common_attention.attention_bias_same_segment,
          features["targets_segmentation"], inputs_segmentation)
      encoder_input = mp(
          common_attention.add_timing_signal_1d_given_position,
          encoder_input, inputs_position)
    else:
      # Token id 0 is treated as padding and masked out of attention.
      encoder_padding = tf.to_float(tf.equal(inputs, 0))
      ignore_padding = common_attention.attention_bias_ignore_padding(
          encoder_padding)
      encoder_self_attention_bias = ignore_padding
      encoder_decoder_attention_bias = ignore_padding
      encoder_input = mp(common_attention.add_timing_signal_1d, encoder_input)

    # encoder stack here
    with tf.variable_scope("encoder"):
      encoder_input = mp(tf.nn.dropout, encoder_input,
                         1.0 - hparams.layer_prepostprocess_dropout)
      encoder_output = _layer_stack(mp,
                                    encoder_input,
                                    encoder_self_attention_bias,
                                    hparams.encoder_layers,
                                    hparams)
  else:
    encoder_decoder_attention_bias = None
    encoder_output = None

  with tf.variable_scope("decoder"):
    decoder_input = mp(tf.nn.dropout, decoder_input,
                       1.0 - hparams.layer_prepostprocess_dropout)
    decoder_output = _layer_stack(
        mp,
        decoder_input,
        decoder_self_attention_bias,
        layers=hparams.decoder_layers,
        hparams=hparams,
        encoder_output=encoder_output,
        encoder_decoder_attention_bias=encoder_decoder_attention_bias)

  # Bypass the symbol modality and compute logits directly.
  # We compute a different set of logits on each shard, and sum them.
  # Share the weights with the target embedding.
  output_var = targets_embedding_var
  output_var_combined = targets_embedding_var_combined
  if single_device:
    decoder_output = tf.concat(decoder_output, 2)
    logits = tf.tensordot(decoder_output, output_var_combined, [[2], [1]])
    num, denom = common_layers.padded_cross_entropy(
        logits, targets, hparams.label_smoothing)
    training_loss = num / denom
  else:
    logits = mp(
        tf.tensordot, decoder_output, output_var, [[[2], [1]]] * mp.n)
    logits = expert_utils.all_reduce_ring(logits, mp)
    # On each device, we compute the loss for a part of the batch.
    # This is faster than computing the whole loss on one shard.
    mp, logits = expert_utils.reduce_by_device(mp, logits, lambda l: l[0])

    def _loss_for_shard(logits, targets, shard):
      # `mp` is read at call time, i.e. the per-device Parallelism above.
      logits = common_layers.approximate_split(logits, mp.n, 0)[shard]
      targets = common_layers.approximate_split(targets, mp.n, 0)[shard]
      return common_layers.padded_cross_entropy(
          logits, targets, hparams.label_smoothing)

    # Parallelism only distributes `list` arguments element-wise. A bare
    # Python 3 `range` would be broadcast whole to every shard.
    num, denom = mp(_loss_for_shard, logits, targets, list(range(mp.n)))
    training_loss = tf.add_n(num) / tf.add_n(denom)
    logits = logits[0]
  logits = tf.expand_dims(tf.expand_dims(logits, 2), 3)
  # override training loss so that it is not computed externally.
  losses = {"training": training_loss}
  return logits, losses
def data_parallelism(hparams, all_workers=False):
  """Decides which devices each training batch is split across.

  Placement depends on FLAGS:
    * schedule "train_and_evaluate" or "continuous_train_and_eval": assume a
      single machine and shard over its GPUs, plus "cpu:0" when
      `locally_shard_to_cpu` is set or no worker GPUs exist. Sync mode is
      not allowed here.
    * sync mode with at least one ps replica: compute on the parameter-server
      devices.
    * otherwise: compute on the worker. This is either a single-worker setup
      or async with parameter servers.

  When parameter servers exist, each datashard device is wrapped in
  `tf.train.replica_device_setter`, so variables created by
  `tf.get_variable` are placed on the parameter servers and shared between
  datashards.

  Args:
    hparams: model hyperparameters (an HParams object). Only
      `daisy_chain_variables` is read.
    all_workers: whether the devices are all async workers or just this one.

  Returns:
    an expert_utils.Parallelism.
  """
  local_schedules = ["train_and_evaluate", "continuous_train_and_eval"]

  def _replica_device_setter(worker_device):
    # No parameter servers: keep the device string unchanged.
    if FLAGS.ps_replicas == 0:
      return worker_device
    if FLAGS.ps_gpu > 0:
      ps_target = FLAGS.ps_job + "/GPU:0"
    else:
      ps_target = FLAGS.ps_job
    return tf.train.replica_device_setter(
        worker_device=worker_device,
        ps_tasks=FLAGS.ps_replicas,
        ps_device=ps_target)

  caching_devices = None
  if FLAGS.schedule in local_schedules:
    assert not FLAGS.sync
    tf.logging.warn(
        "Schedule=%s. Assuming that training is running on a single machine.",
        FLAGS.schedule)
    datashard_devices = ["gpu:%d" % gpu for gpu in _gpu_order(FLAGS.worker_gpu)]
    if FLAGS.locally_shard_to_cpu or FLAGS.worker_gpu < 1:
      datashard_devices.append("cpu:0")
  elif FLAGS.sync and FLAGS.ps_replicas > 0:
    # compute on ps
    datashard_devices = [
        _replica_device_setter(dev)
        for dev in ps_devices(all_workers=all_workers)
    ]
    if FLAGS.ps_gpu > 0 and FLAGS.ps_replicas > 1:
      caching_devices = [
          FLAGS.ps_job + "/task:%d/cpu:0" % task
          for (task, _) in _ps_gpus(all_workers=all_workers)
      ]
  elif FLAGS.worker_gpu > 1:
    # compute on worker, one shard per GPU.
    datashard_devices = [
        _replica_device_setter(FLAGS.worker_job + "/GPU:%d" % gpu)
        for gpu in _gpu_order(FLAGS.worker_gpu)
    ]
    caching_devices = [FLAGS.worker_job + "/GPU:0"] * FLAGS.worker_gpu
  else:
    # compute on worker, single device.
    datashard_devices = [_replica_device_setter(FLAGS.worker_job)]

  tf.logging.info("datashard_devices: %s", datashard_devices)
  tf.logging.info("caching_devices: %s", caching_devices)
  return eu.Parallelism(
      datashard_devices,
      caching_devices=caching_devices,
      daisy_chain_variables=hparams.daisy_chain_variables)
def data_parallelism(daisy_chain_variables=True,
                     all_workers=False,
                     ps_replicas=0,
                     ps_job="/job:ps",
                     ps_gpu=0,
                     schedule="continuous_train_and_eval",
                     sync=False,
                     worker_gpu=1,
                     worker_replicas=1,
                     worker_id=0,
                     gpu_order="",
                     locally_shard_to_cpu=False,
                     worker_job="/job:localhost",
                     no_data_parallelism=False):
  """See data_parallelism_from_flags.

  Builds an `eu.Parallelism` describing how each training batch is split:
    * `no_data_parallelism`: one datashard on the default device ("").
    * single machine (no ps replicas, one worker): one shard per local GPU,
      plus "cpu:0" when `locally_shard_to_cpu` is set or `worker_gpu < 1`.
      Sync mode is not allowed here.
    * `sync` with ps replicas: compute on the parameter-server devices.
    * otherwise: compute on the worker (single worker, or async with
      parameter servers).

  Args:
    daisy_chain_variables: forwarded to `eu.Parallelism`.
    all_workers: whether the devices are all async workers or just this one.
    ps_replicas: number of parameter-server tasks (0 = none).
    ps_job: job name of the parameter servers.
    ps_gpu: GPUs per parameter-server task.
    schedule: only used for logging here.
    sync: whether training is synchronous.
    worker_gpu: GPUs per worker.
    worker_replicas: number of workers. The ps replicas are divided evenly
      among workers.
    worker_id: index of this worker.
    gpu_order: optional whitespace-separated GPU ids. It is used only when the
      count matches the number of GPUs requested.
    locally_shard_to_cpu: also shard onto "cpu:0" on a single machine.
    worker_job: job name of the workers.
    no_data_parallelism: disable data parallelism entirely.

  Returns:
    an eu.Parallelism.
  """
  tf.logging.info("schedule=%s", schedule)
  tf.logging.info("worker_gpu=%s", worker_gpu)
  tf.logging.info("sync=%s", sync)

  def _ps_replicas(all_workers=False):
    if all_workers:
      return list(range(ps_replicas))
    # Worker K will be using replicas {0,...n-1} + K*n if we have n replicas.
    num_replicas = ps_replicas // worker_replicas
    return [d + worker_id * num_replicas for d in range(num_replicas)]

  def _gpu_order(num_gpus):
    # split() with no separator tolerates repeated, leading and trailing
    # whitespace. split(" ") would yield '' entries and crash in int().
    if gpu_order:
      ret = [int(s) for s in gpu_order.split()]
      if len(ret) == num_gpus:
        return ret
    return list(range(num_gpus))

  def _ps_gpus(all_workers=False):
    ps_gpus = []
    for d in _ps_replicas(all_workers=all_workers):
      ps_gpus.extend([(d, gpu) for gpu in _gpu_order(ps_gpu)])
    return ps_gpus

  def ps_devices(all_workers=False):
    """List of ps devices (where to put the experts).

    Args:
      all_workers: whether the list is for all async workers or just this one.

    Returns:
      a list of device names
    """
    if ps_replicas > 0:
      if ps_gpu > 0:
        return [
            ps_job + "/task:%d/GPU:%d" % (d, gpu)
            for (d, gpu) in _ps_gpus(all_workers=all_workers)
        ]
      else:
        return [
            ps_job + "/task:%d" % d
            for d in _ps_replicas(all_workers=all_workers)
        ]
    else:
      if worker_gpu > 0:
        return ["gpu:%d" % d for d in _gpu_order(worker_gpu)]
      else:
        return [""]

  def _replica_device_setter(worker_device):
    if ps_replicas == 0:
      return worker_device
    return tf.train.replica_device_setter(
        worker_device=worker_device,
        ps_tasks=ps_replicas,
        ps_device=ps_job + "/GPU:0" if ps_gpu > 0 else ps_job)

  is_single_machine = ps_replicas == 0 and worker_replicas == 1

  caching_devices = None
  if no_data_parallelism:
    datashard_devices = [""]
  elif is_single_machine:
    assert not sync
    tf.logging.warn(
        "Schedule=%s. Assuming that training is running on a single machine.",
        schedule)
    datashard_devices = ["gpu:%d" % d for d in _gpu_order(worker_gpu)]
    if locally_shard_to_cpu or worker_gpu < 1:
      datashard_devices += ["cpu:0"]
  elif sync and ps_replicas > 0:
    # compute on ps
    datashard_devices = [
        _replica_device_setter(d) for d in ps_devices(all_workers=all_workers)
    ]
    if ps_gpu > 0 and ps_replicas > 1:
      caching_devices = [
          ps_job + "/task:%d/cpu:0" % d
          for (d, _) in _ps_gpus(all_workers=all_workers)
      ]
  else:
    # compute on worker - this is either a single-worker setup or asynchronous
    # with parameter servers.
    if worker_gpu > 1:
      datashard_devices = [
          _replica_device_setter(worker_job + "/GPU:%d" % d)
          for d in _gpu_order(worker_gpu)
      ]
    else:
      datashard_devices = [_replica_device_setter(worker_job)]

  # ps_devices() is pure, so compute it once for both logging and the result.
  all_ps_devices = ps_devices(all_workers=all_workers)
  tf.logging.info("datashard_devices: %s", datashard_devices)
  tf.logging.info("caching_devices: %s", caching_devices)
  tf.logging.info("ps_devices: %s", all_ps_devices)
  return eu.Parallelism(
      datashard_devices,
      caching_devices=caching_devices,
      daisy_chain_variables=daisy_chain_variables,
      ps_devices=all_ps_devices)
def estimator_model_fn(cls,
                       hparams,
                       features,
                       labels,
                       mode,
                       config=None,
                       params=None,
                       decode_hparams=None,
                       use_tpu=True):
  """Model fn for Estimator.

  Args:
    hparams: HParams, model hyperparameters
    features: dict<str name, Tensor feature>
    labels: Tensor
    mode: tf.estimator.ModeKeys
    config: RunConfig. Required when `use_tpu` is False, and must then have a
      t2t_device_info dict.
    params: dict, may include batch_size
    decode_hparams: HParams, used when mode == PREDICT.
    use_tpu: bool, whether using TPU

  Returns:
    TPUEstimatorSpec if use tpu else EstimatorSpec

  Raises:
    ValueError: if `use_tpu` is False and `config` is missing or has no
      `t2t_device_info`.
  """
  tf.logging.warning(
      "T2TModel.estimator_model_fn implements a subset of "
      "model_builder.model_fn and is currently only used "
      "in tpu_trainer.")
  _create_dummy_vars()
  # Copy so setting `use_tpu` does not mutate the caller's hparams.
  hparams = copy.deepcopy(hparams)
  hparams.use_tpu = use_tpu

  problem = hparams.problem_instances[0]

  # Instantiate model
  if use_tpu:
    data_parallelism = eu.Parallelism([""])
  else:
    # `config` defaults to None. Fail with a clear message instead of an
    # AttributeError on `None.t2t_device_info`.
    device_info = getattr(config, "t2t_device_info", None)
    if device_info is None:
      raise ValueError(
          "estimator_model_fn requires a RunConfig with t2t_device_info "
          "when use_tpu=False.")
    data_parallelism = _create_data_parallelism(**device_info)
  model = cls(hparams, mode, data_parallelism=data_parallelism)

  # PREDICT mode
  if mode == tf.estimator.ModeKeys.PREDICT:
    assert not use_tpu
    assert decode_hparams is not None
    return model.estimator_spec_predict(features, decode_hparams)

  # TRAIN and EVAL modes
  logits, losses_dict = model(features)  # pylint: disable=not-callable

  # Set known shapes
  # TODO(rsepassi): Add support for variable lengths and batch sizes
  shape = logits.get_shape().as_list()
  if shape[0] is None:
    shape[0] = _get_batch_size(params, hparams, config)
  if shape[1] is None:
    shape[1] = hparams.max_length
  logits.set_shape(shape)

  # Accumulate losses
  assert "training" in losses_dict
  loss = sum(losses_dict.values())

  # EVAL mode
  if mode == tf.estimator.ModeKeys.EVAL:
    return model.estimator_spec_eval(
        features, logits, labels, loss, problem, hparams, use_tpu=use_tpu)

  # TRAIN mode
  assert mode == tf.estimator.ModeKeys.TRAIN
  return model.estimator_spec_train(loss, use_tpu=use_tpu)