def _build_bidi_rnn_fused(self, inputs, sequence_length, hparams, dtype):
  mlperf_log.gnmt_print(key=mlperf_log.MODEL_HP_DROPOUT,
                        value=hparams.dropout)
  if (not np.isclose(hparams.dropout, 0.) and
      self.mode == tf.contrib.learn.ModeKeys.TRAIN):
    inputs = tf.nn.dropout(inputs, keep_prob=1 - hparams.dropout)

  fwd_cell = block_lstm.LSTMBlockFusedCell(hparams.num_units,
                                           hparams.forget_bias,
                                           dtype=dtype)
  fwd_encoder_outputs, (fwd_final_c, fwd_final_h) = fwd_cell(
      inputs, dtype=dtype, sequence_length=sequence_length)

  inputs_r = tf.reverse_sequence(inputs, sequence_length,
                                 batch_axis=1, seq_axis=0)
  bak_cell = block_lstm.LSTMBlockFusedCell(hparams.num_units,
                                           hparams.forget_bias,
                                           dtype=dtype)
  bak_encoder_outputs, (bak_final_c, bak_final_h) = bak_cell(
      inputs_r, dtype=dtype, sequence_length=sequence_length)
  bak_encoder_outputs = tf.reverse_sequence(bak_encoder_outputs,
                                            sequence_length,
                                            batch_axis=1, seq_axis=0)

  bi_encoder_outputs = tf.concat(
      [fwd_encoder_outputs, bak_encoder_outputs], axis=-1)
  fwd_state = tf.nn.rnn_cell.LSTMStateTuple(fwd_final_c, fwd_final_h)
  bak_state = tf.nn.rnn_cell.LSTMStateTuple(bak_final_c, bak_final_h)
  bi_encoder_state = (fwd_state, bak_state)
  # Masks aren't applied to the outputs, but the final states are
  # post-masking.
  return bi_encoder_outputs, bi_encoder_state
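# A minimal sketch (not from the source) of the reverse-run-reverse
# pattern used above: a fused LSTM cell only scans forward in time, so
# the backward direction is emulated by flipping each sequence up to its
# true length, scanning, and flipping the outputs back. Assumes
# time-major [time, batch, depth] inputs, as in _build_bidi_rnn_fused.
import tensorflow as tf

def emulate_backward_scan(fused_cell, inputs, sequence_length, dtype):
  """Runs a forward-only fused cell right-to-left over `inputs`."""
  # Flip only the valid prefix of each sequence; padding stays in place.
  inputs_r = tf.reverse_sequence(inputs, sequence_length,
                                 batch_axis=1, seq_axis=0)
  outputs_r, state = fused_cell(inputs_r, dtype=dtype,
                                sequence_length=sequence_length)
  # Flip the outputs back so they align with the forward direction.
  outputs = tf.reverse_sequence(outputs_r, sequence_length,
                                batch_axis=1, seq_axis=0)
  return outputs, state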
def __init__(self, model, criterion, opt_config,
             print_freq=10,
             save_freq=1000,
             grad_clip=float('inf'),
             batch_first=False,
             save_info=None,
             save_path='.',
             checkpoint_filename='checkpoint%s.pth',
             keep_checkpoints=5,
             math='fp32',
             cuda=True,
             distributed=False,
             verbose=False):
    super(Seq2SeqTrainer, self).__init__()
    self.model = model
    self.criterion = criterion
    self.epoch = 0
    # Avoid a shared mutable default argument for save_info.
    self.save_info = save_info if save_info is not None else {}
    self.save_path = save_path
    self.save_freq = save_freq
    self.save_counter = 0
    self.checkpoint_filename = checkpoint_filename
    self.checkpoint_counter = cycle(range(keep_checkpoints))
    self.opt_config = opt_config
    self.cuda = cuda
    self.distributed = distributed
    self.print_freq = print_freq
    self.batch_first = batch_first
    self.verbose = verbose
    self.loss = None

    if cuda:
        self.model = self.model.cuda()
        self.criterion = self.criterion.cuda()

    if distributed:
        self.model = DDP(self.model)

    if math == 'fp16':
        self.model = self.model.half()
        self.fp_optimizer = Fp16Optimizer(self.model, grad_clip)
        params = self.fp_optimizer.fp32_params
    elif math == 'fp32':
        self.fp_optimizer = Fp32Optimizer(self.model, grad_clip)
        params = self.model.parameters()

    opt_name = opt_config['optimizer']
    lr = opt_config['lr']
    self.optimizer = torch.optim.__dict__[opt_name](params, lr=lr)
    mlperf_log.gnmt_print(key=mlperf_log.OPT_NAME, value=opt_name)
    mlperf_log.gnmt_print(key=mlperf_log.OPT_LR, value=lr)
    # Only Adam-family optimizers define 'betas'/'eps'; guard so that
    # e.g. SGD does not crash on these compliance tags.
    if 'betas' in self.optimizer.defaults:
        mlperf_log.gnmt_print(key=mlperf_log.OPT_HP_ADAM_BETA1,
                              value=self.optimizer.defaults['betas'][0])
        mlperf_log.gnmt_print(key=mlperf_log.OPT_HP_ADAM_BETA2,
                              value=self.optimizer.defaults['betas'][1])
        mlperf_log.gnmt_print(key=mlperf_log.OPT_HP_ADAM_EPSILON,
                              value=self.optimizer.defaults['eps'])
def _length_penalty(sequence_lengths, penalty_factor, dtype):
  """Calculates the length penalty. See https://arxiv.org/abs/1609.08144.

  Returns the length penalty tensor:
  ```
  [(5+sequence_lengths)/6]**penalty_factor
  ```
  where all operations are performed element-wise.

  Args:
    sequence_lengths: `Tensor`, the sequence lengths of each hypothesis.
    penalty_factor: A scalar that weights the length penalty.
    dtype: dtype of result.

  Returns:
    If the penalty is `0`, returns the scalar `1.0`. Otherwise returns
    the length penalty factor, a tensor with the same shape as
    `sequence_lengths`.
  """
  penalty_factor = tf.convert_to_tensor(penalty_factor,
                                        name="penalty_factor",
                                        dtype=dtype)
  penalty_factor.set_shape(())  # penalty should be a scalar.
  static_penalty = tf.contrib.util.constant_value(penalty_factor)
  if static_penalty is not None and static_penalty == 0:
    return 1.0

  length_penalty_const = 5.0
  mlperf_log.gnmt_print(key=mlperf_log.EVAL_HP_LEN_NORM_CONST,
                        value=length_penalty_const)
  return tf.div((length_penalty_const +
                 tf.cast(sequence_lengths, dtype))**penalty_factor,
                (length_penalty_const + 1.)**penalty_factor)
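# A quick worked example (not from the source) of the length penalty
# above, in plain Python. With penalty_factor = 0.6, a hypothesis of
# length 7 is divided by ((5 + 7) / 6) ** 0.6 = 2 ** 0.6 ~= 1.516 before
# beam comparison, so longer hypotheses are not unfairly penalized.
length_penalty_const = 5.0
penalty_factor = 0.6
for seq_len in (1, 7, 20):
  penalty = (((length_penalty_const + seq_len) ** penalty_factor) /
             ((length_penalty_const + 1.) ** penalty_factor))
  print(seq_len, round(penalty, 3))  # 1 -> 1.0, 7 -> 1.516, 20 -> 2.354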
def _compute_loss(self, logits, label_smoothing):
  """Compute optimization loss."""
  mlperf_log.gnmt_print(key=mlperf_log.MODEL_HP_LOSS_SMOOTHING,
                        value=label_smoothing)
  mlperf_log.gnmt_print(key=mlperf_log.MODEL_HP_LOSS_FN,
                        value="Cross Entropy with label smoothing")
  target_output = self.features["target_output"]
  if self.time_major:
    target_output = tf.transpose(target_output)
  max_time = self.get_max_time(target_output)
  self.batch_seq_len = max_time

  crossent = self._softmax_cross_entropy_loss(logits, target_output,
                                              label_smoothing)
  assert crossent.dtype == tf.float32

  target_weights = tf.sequence_mask(
      self.features["target_sequence_length"], max_time,
      dtype=crossent.dtype)
  if self.time_major:
    # [time, batch] if time_major, since the crossent is [time, batch]
    # in this case.
    target_weights = tf.transpose(target_weights)

  loss = tf.reduce_sum(crossent * target_weights) / tf.to_float(
      self.batch_size)
  return loss
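# A minimal sketch (assumes TF 1.x and time-major [time, batch] layout)
# of the masking step in _compute_loss above: per-token cross entropy is
# zeroed past each true target length, then summed and divided by the
# batch size rather than the token count.
import tensorflow as tf

crossent = tf.constant([[2.0, 1.0],
                        [0.5, 1.0],
                        [0.1, 1.0]])          # [max_time=3, batch=2]
target_sequence_length = tf.constant([3, 1])  # true lengths per example
target_weights = tf.sequence_mask(target_sequence_length, 3,
                                  dtype=crossent.dtype)  # [batch, time]
target_weights = tf.transpose(target_weights)            # [time, batch]
loss = tf.reduce_sum(crossent * target_weights) / 2.0    # batch_size = 2
with tf.Session() as sess:
  print(sess.run(loss))  # (2.0 + 0.5 + 0.1 + 1.0) / 2 = 1.8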
def __init__(self, model, beam_size=5, max_seq_len=100, cuda=False,
             len_norm_factor=0.6, len_norm_const=5,
             cov_penalty_factor=0.1):
    self.model = model
    self.cuda = cuda
    self.beam_size = beam_size
    self.max_seq_len = max_seq_len
    self.len_norm_factor = len_norm_factor
    self.len_norm_const = len_norm_const
    self.cov_penalty_factor = cov_penalty_factor
    self.batch_first = self.model.batch_first

    mlperf_log.gnmt_print(key=mlperf_log.EVAL_HP_BEAM_SIZE,
                          value=self.beam_size)
    mlperf_log.gnmt_print(key=mlperf_log.EVAL_HP_MAX_SEQ_LEN,
                          value=self.max_seq_len)
    mlperf_log.gnmt_print(key=mlperf_log.EVAL_HP_LEN_NORM_CONST,
                          value=self.len_norm_const)
    mlperf_log.gnmt_print(key=mlperf_log.EVAL_HP_LEN_NORM_FACTOR,
                          value=self.len_norm_factor)
    mlperf_log.gnmt_print(key=mlperf_log.EVAL_HP_COV_PENALTY_FACTOR,
                          value=self.cov_penalty_factor)
def get_infer_iterator(src_dataset,
                       src_vocab_table,
                       batch_size,
                       eos,
                       src_max_len=None,
                       use_char_encode=False):
  """Get dataset for inference."""
  # Total number of examples in src_dataset
  # (3003 examples + 69 padding examples).
  mlperf_log.gnmt_print(key=mlperf_log.PREPROC_NUM_EVAL_EXAMPLES,
                        value=3003)
  mlperf_log.gnmt_print(key=mlperf_log.PREPROC_TOKENIZE_EVAL)
  if use_char_encode:
    src_eos_id = vocab_utils.EOS_CHAR_ID
  else:
    src_eos_id = tf.cast(src_vocab_table.lookup(tf.constant(eos)), tf.int32)
  src_dataset = src_dataset.map(lambda src: tf.string_split([src]).values)

  if use_char_encode:
    # Convert the word strings to character ids.
    src_dataset = src_dataset.map(
        lambda src: tf.reshape(vocab_utils.tokens_to_bytes(src), [-1]))
  else:
    # Convert the word strings to ids.
    src_dataset = src_dataset.map(
        lambda src: tf.cast(src_vocab_table.lookup(src), tf.int32))

  # Add in the word counts.
  if use_char_encode:
    src_dataset = src_dataset.map(
        lambda src: (src,
                     tf.to_int32(
                         tf.size(src) / vocab_utils.DEFAULT_CHAR_MAXLEN)))
  else:
    src_dataset = src_dataset.map(lambda src: (src, tf.size(src)))

  def batching_func(x):
    return x.padded_batch(
        batch_size,
        # The first entry is the source line rows; this has
        # unknown-length vectors. The last entry is the source row size,
        # which is a scalar.
        padded_shapes=(
            tf.TensorShape([src_max_len]),  # src
            tf.TensorShape([])),  # src_len
        # Pad the source sequences with eos tokens. (Though notice we
        # don't generally need to do this since later on we will be
        # masking out calculations past the true sequence.)
        padding_values=(
            src_eos_id,  # src
            0),  # src_len -- unused
        drop_remainder=True)

  batched_dataset = batching_func(src_dataset)
  batched_dataset = batched_dataset.map(
      lambda src_ids, src_seq_len: (
          {"source": src_ids,
           "source_sequence_length": src_seq_len}))
  return batched_dataset
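# A toy illustration (not from the source) of the padded_batch call in
# get_infer_iterator: variable-length id vectors are padded to
# src_max_len with the eos id, while the scalar lengths need no padding.
import tensorflow as tf

src_eos_id = 2  # hypothetical eos id
ds = tf.data.Dataset.from_generator(lambda: iter([[5, 6], [7, 8, 9]]),
                                    tf.int32, tf.TensorShape([None]))
ds = ds.map(lambda src: (src, tf.size(src)))
ds = ds.padded_batch(2,
                     padded_shapes=(tf.TensorShape([4]),  # src_max_len=4
                                    tf.TensorShape([])),
                     padding_values=(src_eos_id, 0))
batch = ds.make_one_shot_iterator().get_next()
with tf.Session() as sess:
  print(sess.run(batch))  # ([[5, 6, 2, 2], [7, 8, 9, 2]], [2, 3])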
def train_fn(hparams, num_workers):
  """Copy of train function from estimator.py."""
  # TODO: Merge improvements into the original.
  # pylint: disable=protected-access
  hparams.tgt_sos_id, hparams.tgt_eos_id = nmt_estimator._get_tgt_sos_eos_id(
      hparams)
  model_fn = nmt_estimator.make_model_fn(hparams)

  def print_log():
    mlperf_log.gnmt_print(key=mlperf_log.TRAIN_LOOP)
    mlperf_log.gnmt_print(key=mlperf_log.TRAIN_EPOCH, value=0)
    mlperf_log.gnmt_print(key=mlperf_log.INPUT_SIZE,
                          value=hparams.num_examples_per_epoch)

  if hparams.use_tpu_low_level_api:
    runner = create_train_runner(hparams, num_workers)
    mlperf_log.gnmt_print(key=mlperf_log.RUN_START)
    input_fn = DistributedPipeline(hparams, num_workers)
    runner.initialize(input_fn, {})
    runner.build_model(model_fn, {})
    print_log()
    runner.train(0, hparams.num_train_steps)
    return 0.0

  # cluster = tf.contrib.cluster_resolver.TPUClusterResolver(hparams.tpu_name)
  # cluster_spec = cluster.cluster_spec()
  # print('cluster_spec: %s' % cluster_spec)
  # num_workers = cluster_spec.num_tasks('tpu_worker')
  # print('num_workers: %s' % num_workers)
  pipeline = DistributedPipeline(hparams, num_workers)
  print_log()

  if hparams.use_tpu:
    run_config = nmt_estimator._get_tpu_run_config(hparams, True)
    estimator = tf.contrib.tpu.TPUEstimator(
        model_fn=model_fn,
        config=run_config,
        use_tpu=hparams.use_tpu,
        train_batch_size=hparams.batch_size,
        eval_batch_size=hparams.batch_size,
        predict_batch_size=hparams.infer_batch_size)
  else:
    raise ValueError("Distributed input pipeline only supported on TPUs.")

  hooks = [pipeline]
  if hparams.use_async_checkpoint:
    hooks.append(
        async_checkpoint.AsyncCheckpointSaverHook(
            checkpoint_dir=hparams.out_dir,
            save_steps=int(
                hparams.num_examples_per_epoch / hparams.batch_size)))

  estimator.train(
      input_fn=pipeline, max_steps=hparams.num_train_steps, hooks=hooks)
  # Return value is not used.
  return 0.0
def _input_fn(params):
  """Input function."""
  if mode == tf.contrib.learn.ModeKeys.TRAIN:
    src_file = "%s.%s" % (hparams.train_prefix, hparams.src)
    tgt_file = "%s.%s" % (hparams.train_prefix, hparams.tgt)
  else:
    src_file = "%s.%s" % (hparams.test_prefix, hparams.src)
    tgt_file = "%s.%s" % (hparams.test_prefix, hparams.tgt)

  src_vocab_file = hparams.src_vocab_file
  tgt_vocab_file = hparams.tgt_vocab_file
  src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
      src_vocab_file, tgt_vocab_file, hparams.share_vocab)
  src_dataset = tf.data.TextLineDataset(src_file)
  tgt_dataset = tf.data.TextLineDataset(tgt_file)

  if mode == tf.contrib.learn.ModeKeys.TRAIN:
    if "context" in params:
      batch_size = params["batch_size"]
      num_hosts = params["context"].num_hosts
      # TODO(dehao): update to use current_host once available in API.
      current_host = params["context"].current_input_fn_deployment()[1]
    else:
      num_hosts = 1
      current_host = 0
      batch_size = hparams.batch_size
    mlperf_log.gnmt_print(key=mlperf_log.INPUT_BATCH_SIZE, value=batch_size)
    mlperf_log.gnmt_print(key=mlperf_log.TRAIN_HP_MAX_SEQ_LEN,
                          value=hparams.src_max_len)
    return iterator_utils.get_iterator(
        src_dataset,
        tgt_dataset,
        src_vocab_table,
        tgt_vocab_table,
        batch_size=batch_size,
        sos=hparams.sos,
        eos=hparams.eos,
        random_seed=hparams.random_seed,
        num_buckets=hparams.num_buckets,
        src_max_len=hparams.src_max_len,
        tgt_max_len=hparams.tgt_max_len,
        output_buffer_size=None,
        skip_count=None,
        num_shards=num_hosts,
        shard_index=current_host,
        reshuffle_each_iteration=True,
        use_char_encode=hparams.use_char_encode,
        filter_oversized_sequences=True)
  else:
    return iterator_utils.get_infer_iterator(
        src_dataset,
        src_vocab_table,
        batch_size=hparams.infer_batch_size,
        eos=hparams.eos,
        src_max_len=hparams.src_max_len,
        use_char_encode=hparams.use_char_encode)
def __iter__(self):
    mlperf_log.gnmt_print(key=mlperf_log.INPUT_ORDER)
    # deterministically shuffle based on epoch
    g = torch.Generator()
    g.manual_seed(self.epoch)

    # generate permutation
    indices = torch.randperm(self.data_len, generator=g)
    # make indices evenly divisible by (batch_size * world_size)
    indices = indices[:self.num_samples]

    if self.bucket:
        # begin shards
        batches_in_shard = 80
        shard_size = self.global_batch_size * batches_in_shard
        nshards = (self.num_samples + shard_size - 1) // shard_size

        lengths = self.dataset.lengths[indices]

        shards = [indices[i * shard_size:(i + 1) * shard_size]
                  for i in range(nshards)]
        len_shards = [lengths[i * shard_size:(i + 1) * shard_size]
                      for i in range(nshards)]

        # sort by sequence length within each shard
        indices = []
        for len_shard in len_shards:
            _, ind = len_shard.sort()
            indices.append(ind)

        output = tuple(shard[idx] for shard, idx in zip(shards, indices))
        indices = torch.cat(output)

        # global reshuffle of whole batches
        indices = indices.view(-1, self.global_batch_size)
        order = torch.randperm(indices.shape[0], generator=g)
        indices = indices[order, :]
        indices = indices.view(-1)
        # end shards

    assert len(indices) == self.num_samples

    # build indices for each individual worker
    # ranks are getting consecutive batches,
    # default pytorch DistributedSampler assigns strided batches
    # with offset = length / world_size
    indices = indices.view(-1, self.batch_size)
    indices = indices[self.rank::self.world_size].contiguous()
    indices = indices.view(-1)
    indices = indices.tolist()

    assert len(indices) == self.num_samples // self.world_size
    return iter(indices)
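# A toy sketch (not from the source) of the sharded length-bucketing in
# __iter__ above: indices are sorted by sequence length only inside each
# shard, so batches contain similarly sized sentences (less padding
# waste) while the shard order still varies between epochs.
import torch

lengths = torch.tensor([5, 9, 2, 7, 3, 8, 1, 6])
indices = torch.randperm(8, generator=torch.Generator().manual_seed(0))
shard_size = 4
shards = [indices[i:i + shard_size] for i in range(0, 8, shard_size)]
bucketed = []
for shard in shards:
    _, order = lengths[shard].sort()  # sort within the shard only
    bucketed.append(shard[order])
indices = torch.cat(bucketed)
print(lengths[indices])  # lengths ascend locally, not globally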
def get_metric(hparams, predictions, current_step):
  """Run inference and compute metric."""
  predicted_ids = []
  for prediction in predictions:
    predicted_ids.append(prediction["predictions"])

  mlperf_log.gnmt_print(key=mlperf_log.EVAL_SIZE,
                        value=hparams.examples_to_infer)
  if hparams.examples_to_infer < len(predicted_ids):
    predicted_ids = predicted_ids[0:hparams.examples_to_infer]
  translations = _convert_ids_to_strings(hparams.tgt_vocab_file,
                                         predicted_ids)

  trans_file = os.path.join(
      hparams.out_dir, "newstest2014_out_{}.tok.de".format(current_step))
  trans_dir = os.path.dirname(trans_file)
  if not tf.gfile.Exists(trans_dir):
    tf.gfile.MakeDirs(trans_dir)
  tf.logging.info("Writing to file %s" % trans_file)
  with codecs.getwriter("utf-8")(tf.gfile.GFile(trans_file,
                                                mode="wb")) as trans_f:
    trans_f.write("")  # Write empty string to ensure file is created.
    for translation in translations:
      sentence = nmt_utils.get_translation(
          translation,
          tgt_eos=hparams.eos,
          subword_option=hparams.subword_option)
      trans_f.write((sentence + b"\n").decode("utf-8"))

  # Evaluation
  output_dir = os.path.join(hparams.out_dir,
                            "eval_{}".format(hparams.test_year))
  tf.gfile.MakeDirs(output_dir)
  summary_writer = tf.summary.FileWriter(output_dir)

  ref_file = "%s.%s" % (hparams.test_prefix, hparams.tgt)

  metric = "bleu"
  if hparams.use_borg:
    score = evaluation_utils.evaluate(ref_file, trans_file, metric,
                                      hparams.subword_option)
  else:
    score = get_sacrebleu(trans_file, hparams.detokenizer_file,
                          hparams.test_year)

  with tf.Graph().as_default():
    summaries = []
    summaries.append(tf.Summary.Value(tag=metric, simple_value=score))
  tf_summary = tf.Summary(value=list(summaries))
  summary_writer.add_summary(tf_summary, current_step)

  with tf.gfile.Open(os.path.join(output_dir, 'bleu'), 'w') as f:
    f.write('{}\n'.format(score))

  misc_utils.print_out("  %s: %.1f" % (metric, score))
  summary_writer.close()
  return score
def create_train_runner_and_build_graph(hparams, model_fn):
  runner = create_train_runner(hparams)
  mlperf_log.gnmt_print(key=mlperf_log.RUN_START)
  input_fn = make_input_fn(hparams, tf.contrib.learn.ModeKeys.TRAIN)
  params = {
      "batch_size": int(hparams.batch_size / hparams.num_shards),
  }
  runner.initialize(input_fn, params)
  runner.build_model(model_fn, params)
  return runner
def create_train_runner(hparams, num_workers):
  del num_workers  # unused
  steps_per_epoch = int(hparams.num_examples_per_epoch / hparams.batch_size)
  return low_level_runner.TrainLowLevelRunner(iterations=steps_per_epoch,
                                              hparams=hparams,
                                              per_host_v1=True)
def _set_train_or_infer(self, res, hparams):
  """Set up training."""
  if self.mode == tf.contrib.learn.ModeKeys.INFER:
    self.predicted_ids = res[1]

  params = tf.trainable_variables()

  # Gradients and SGD update operation for training the model.
  # Arrange for the embedding vars to appear at the beginning.
  if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
    loss = res[0]
    self.loss = loss
    mlperf_log.gnmt_print(key=mlperf_log.OPT_LR,
                          value=hparams.learning_rate)
    if hparams.lottery_force_learning_rate is not None:
      self.learning_rate = lottery.get_lr_tensor(hparams.values())
    else:
      self.learning_rate = tf.constant(hparams.learning_rate)
      # warm-up
      self.learning_rate = self._get_learning_rate_warmup(hparams)
      # decay
      self.learning_rate = self._get_learning_rate_decay(hparams)

    # Optimizer
    mlperf_log.gnmt_print(key=mlperf_log.OPT_NAME, value=hparams.optimizer)
    if hparams.optimizer == "sgd":
      opt = tf.train.GradientDescentOptimizer(self.learning_rate)
    elif hparams.optimizer == "adam":
      mlperf_log.gnmt_print(key=mlperf_log.OPT_HP_ADAM_BETA1, value=0.9)
      mlperf_log.gnmt_print(key=mlperf_log.OPT_HP_ADAM_BETA2, value=0.999)
      mlperf_log.gnmt_print(key=mlperf_log.OPT_HP_ADAM_EPSILON, value=1e-8)
      opt = tf.train.AdamOptimizer(self.learning_rate)
    else:
      raise ValueError("Unknown optimizer type %s" % hparams.optimizer)

    if hparams.use_tpu:
      opt = tf.contrib.tpu.CrossShardOptimizer(opt)

    # Gradients
    gradients = tf.gradients(
        loss,
        params,
        colocate_gradients_with_ops=hparams.colocate_gradients_with_ops)

    clipped_grads, grad_norm = model_helper.gradient_clip(
        gradients, max_gradient_norm=hparams.max_gradient_norm)
    self.update = opt.apply_gradients(
        zip(clipped_grads, params), global_step=self.global_step)

  # Print trainable variables
  utils.print_out("# Trainable variables")
  utils.print_out("Format: <name>, <shape>, <(soft) device placement>")
  for param in params:
    utils.print_out("  %s, %s, %s" % (param.name, str(param.get_shape()),
                                      param.op.device))
def gnmt_print(*args, **kwargs):
    """
    Wrapper for MLPerf compliance logging calls.
    All arguments but 'sync' are passed to the mlperf_log.gnmt_print
    function. If 'sync' is set to True then the wrapper will synchronize
    all distributed workers. 'sync' should be set to True for all
    compliance tags that require accurate timing (RUN_START, RUN_STOP
    etc.)
    """
    if kwargs.pop('sync', False):
        barrier()
    if get_rank() == 0:
        kwargs['stack_offset'] = 2
        mlperf_log.gnmt_print(*args, **kwargs)
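# Hypothetical usage of the wrapper above: timing-critical tags pass
# sync=True so every worker reaches the barrier before rank 0 logs,
# keeping the logged timestamps comparable across workers.
gnmt_print(key=mlperf_log.RUN_START, sync=True)
gnmt_print(key=mlperf_log.OPT_LR, value=2e-3, sync=False)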
def _build_encoder_cell(self, hparams, num_layers, num_residual_layers,
                        dtype=None):
  """Build a multi-layer RNN cell that can be used by encoder."""
  mlperf_log.gnmt_print(key=mlperf_log.MODEL_HP_DROPOUT,
                        value=hparams.dropout)
  return model_helper.create_rnn_cell(
      unit_type=hparams.unit_type,
      num_units=self.num_units,
      num_layers=num_layers,
      num_residual_layers=num_residual_layers,
      forget_bias=hparams.forget_bias,
      dropout=hparams.dropout,
      mode=self.mode,
      dtype=dtype,
      single_cell_fn=self.single_cell_fn,
      use_block_lstm=hparams.use_block_lstm)
def _compute_loss(self, theta, _, inputs):
  logits = tf.cast(
      tf.matmul(tf.slice(inputs, [0, 0], [512, self.num_units]), theta),
      tf.float32)
  target = tf.cast(
      tf.reshape(tf.slice(inputs, [0, self.num_units], [512, 1]), [-1]),
      tf.int32)
  mlperf_log.gnmt_print(key=mlperf_log.MODEL_HP_LOSS_SMOOTHING,
                        value=self.label_smoothing)
  mlperf_log.gnmt_print(key=mlperf_log.MODEL_HP_LOSS_FN,
                        value="Cross Entropy with label smoothing")
  crossent = tf.losses.softmax_cross_entropy(
      tf.one_hot(target, self.tgt_vocab_size, dtype=logits.dtype),
      logits,
      label_smoothing=self.label_smoothing,
      reduction=tf.losses.Reduction.NONE)
  # Zero out the loss on padding positions. Use tf.equal, not ==, so the
  # comparison is element-wise on the graph tensors.
  crossent = tf.where(tf.equal(target, self.eos_id),
                      tf.zeros_like(crossent), crossent)
  return tf.reshape(crossent, [-1]), []
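# A tiny check (assumes TF 1.x) of how label_smoothing rewrites the
# targets inside the loss above: smoothed = onehot * (1 - eps) + eps / V,
# so the true class never receives a target of exactly 1.0.
import tensorflow as tf

eps, V = 0.1, 4
onehot = tf.one_hot([2], V)              # true class is 2
smoothed = onehot * (1 - eps) + eps / V  # what the loss uses internally
with tf.Session() as sess:
  print(sess.run(smoothed))  # [[0.025, 0.025, 0.925, 0.025]]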
def _build_encoder_cell(self, hparams, num_layers, num_residual_layers,
                        fast_reverse=False, reverse=False):
  """Build a multi-layer RNN cell that can be used by encoder."""
  mlperf_log.gnmt_print(key=mlperf_log.MODEL_HP_DROPOUT,
                        value=hparams.dropout)
  return model_helper.create_rnn_cell(
      unit_type=hparams.unit_type,
      num_units=self.num_units,
      num_layers=num_layers,
      num_residual_layers=num_residual_layers,
      forget_bias=hparams.forget_bias,
      dropout=hparams.dropout,
      mode=self.mode,
      single_cell_fn=self.single_cell_fn,
      global_step=self.global_step,
      fast_reverse=fast_reverse,
      seq_len=self.features["source_sequence_length"] if reverse else None)
def __init__(self, vocab_size, hidden_size=512, num_layers=8, bias=True,
             dropout=0.2, batch_first=False, math='fp32',
             share_embedding=False):
    super(GNMT, self).__init__(batch_first=batch_first)

    mlperf_log.gnmt_print(key=mlperf_log.MODEL_HP_NUM_LAYERS,
                          value=num_layers)
    mlperf_log.gnmt_print(key=mlperf_log.MODEL_HP_HIDDEN_SIZE,
                          value=hidden_size)
    mlperf_log.gnmt_print(key=mlperf_log.MODEL_HP_DROPOUT, value=dropout)

    if share_embedding:
        embedder = nn.Embedding(vocab_size, hidden_size,
                                padding_idx=config.PAD)
    else:
        embedder = None

    self.encoder = ResidualRecurrentEncoder(vocab_size, hidden_size,
                                            num_layers, bias, dropout,
                                            batch_first, embedder)
    self.decoder = ResidualRecurrentDecoder(vocab_size, hidden_size,
                                            num_layers, bias, dropout,
                                            batch_first, math, embedder)
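# A hypothetical construction of the model above (the hyperparameter
# values are made up, not the MLPerf reference configuration):
model = GNMT(vocab_size=32000, hidden_size=1024, num_layers=4,
             dropout=0.2, batch_first=True, share_embedding=True)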
def __init__(self, vocab_size, hidden_size=512, num_layers=8, bias=True,
             dropout=0.2, batch_first=False, math='fp32',
             share_embedding=False):
    super(GNMT, self).__init__(batch_first=batch_first)

    mlperf_log.gnmt_print(key=mlperf_log.MODEL_HP_NUM_LAYERS,
                          value=num_layers)
    mlperf_log.gnmt_print(key=mlperf_log.MODEL_HP_HIDDEN_SIZE,
                          value=hidden_size)
    mlperf_log.gnmt_print(key=mlperf_log.MODEL_HP_DROPOUT, value=dropout)

    if share_embedding:
        embedder = nn.Embedding(vocab_size, hidden_size,
                                padding_idx=config.PAD)
    else:
        embedder = None

    # SSY 1: the encoder (seq2seq/models/encoder.py) uses only nn.LSTM
    # (torch/nn/modules/rnn.py), patched by bf16cut.
    self.encoder = ResidualRecurrentEncoder(vocab_size, hidden_size,
                                            num_layers, bias, dropout,
                                            batch_first, embedder)
    # SSY 2: the decoder (seq2seq/models/decoder.py) uses torch.bmm,
    # nn.LSTM and nn.Linear (torch/nn/modules/linear.py), patched by
    # bf16cut; torch.bmm is mapped in torch/onnx/symbolic.py.
    self.decoder = ResidualRecurrentDecoder(vocab_size, hidden_size,
                                            num_layers, bias, dropout,
                                            batch_first, math, embedder)
def get_optimizer(hparams, learning_rate):
  """Create the optimizer specified by hparams.optimizer."""
  mlperf_log.gnmt_print(key=mlperf_log.OPT_NAME, value=hparams.optimizer)
  if hparams.optimizer == "sgd":
    opt = tf.train.GradientDescentOptimizer(learning_rate)
  elif hparams.optimizer == "adam":
    mlperf_log.gnmt_print(key=mlperf_log.OPT_HP_ADAM_BETA1, value=0.9)
    mlperf_log.gnmt_print(key=mlperf_log.OPT_HP_ADAM_BETA2, value=0.999)
    mlperf_log.gnmt_print(key=mlperf_log.OPT_HP_ADAM_EPSILON, value=1e-8)
    opt = tf.train.AdamOptimizer(learning_rate)
  else:
    raise ValueError("Unknown optimizer type %s" % hparams.optimizer)
  return opt
def build_criterion(vocab_size, padding_idx, smoothing):
    if smoothing == 0.:
        logging.info('building CrossEntropyLoss')
        loss_weight = torch.ones(vocab_size)
        loss_weight[padding_idx] = 0
        criterion = nn.CrossEntropyLoss(weight=loss_weight,
                                        size_average=False)
        mlperf_log.gnmt_print(key=mlperf_log.MODEL_HP_LOSS_FN,
                              value='Cross Entropy')
    else:
        logging.info(f'building SmoothingLoss (smoothing: {smoothing})')
        criterion = LabelSmoothing(padding_idx, smoothing)
        mlperf_log.gnmt_print(key=mlperf_log.MODEL_HP_LOSS_FN,
                              value='Cross Entropy with label smoothing')
        mlperf_log.gnmt_print(key=mlperf_log.MODEL_HP_LOSS_SMOOTHING,
                              value=smoothing)
    return criterion
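# A small sketch (not from the source) of why the padding row of
# loss_weight is zeroed above: CrossEntropyLoss then contributes nothing
# for target positions equal to padding_idx.
import torch
import torch.nn as nn

vocab_size, padding_idx = 5, 0
weight = torch.ones(vocab_size)
weight[padding_idx] = 0
criterion = nn.CrossEntropyLoss(weight=weight, reduction='sum')
logits = torch.randn(3, vocab_size)
targets = torch.tensor([2, 0, 4])   # the middle target is padding
loss = criterion(logits, targets)   # position 1 adds exactly 0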
def train_fn(hparams):
  """Train function."""
  hparams.tgt_sos_id, hparams.tgt_eos_id = _get_tgt_sos_eos_id(hparams)
  model_fn = make_model_fn(hparams)

  mlperf_log.gnmt_print(key=mlperf_log.TRAIN_LOOP)
  mlperf_log.gnmt_print(key=mlperf_log.TRAIN_EPOCH, value=0)
  mlperf_log.gnmt_print(key=mlperf_log.INPUT_SIZE,
                        value=hparams.num_examples_per_epoch)

  if hparams.use_tpu_low_level_api:
    runner = create_train_runner_and_build_graph(hparams, model_fn)
    runner.train(0, hparams.num_train_steps)
    return 0.0

  input_fn = make_input_fn(hparams, tf.contrib.learn.ModeKeys.TRAIN)
  if hparams.use_tpu:
    run_config = _get_tpu_run_config(hparams, True)
    estimator = tf.contrib.tpu.TPUEstimator(
        model_fn=model_fn,
        config=run_config,
        use_tpu=hparams.use_tpu,
        train_batch_size=hparams.batch_size,
        eval_batch_size=hparams.batch_size,
        predict_batch_size=hparams.infer_batch_size)
  else:
    distribution_strategy = get_distribution_strategy(hparams.num_gpus)
    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=hparams.out_dir,
        config=tf.estimator.RunConfig(
            train_distribute=distribution_strategy))

  hooks = []
  if hparams.use_async_checkpoint:
    hooks.append(
        async_checkpoint.AsyncCheckpointSaverHook(
            checkpoint_dir=hparams.out_dir,
            save_steps=int(
                hparams.num_examples_per_epoch / hparams.batch_size)))

  estimator.train(
      input_fn=input_fn, max_steps=hparams.num_train_steps, hooks=hooks)
  # Return value is not used.
  return 0.0
def train_and_eval_fn(hparams, num_workers):
  """Train and evaluation function."""
  # pylint: disable=protected-access
  hparams.tgt_sos_id, hparams.tgt_eos_id = 1, 2
  model_fn = nmt_estimator.make_model_fn(hparams)
  pipeline = DistributedPipeline(hparams, num_workers)

  run_config = nmt_estimator._get_tpu_run_config(hparams)
  estimator = tf.contrib.tpu.TPUEstimator(
      model_fn=model_fn,
      config=run_config,
      use_tpu=hparams.use_tpu,
      train_batch_size=hparams.batch_size,
      eval_batch_size=hparams.batch_size,
      predict_batch_size=hparams.infer_batch_size)

  score = 0.0
  mlperf_log.gnmt_print(key=mlperf_log.TRAIN_LOOP)
  mlperf_log.gnmt_print(key=mlperf_log.EVAL_TARGET,
                        value=hparams.target_bleu)
  for i in range(hparams.max_train_epochs):
    mlperf_log.gnmt_print(key=mlperf_log.TRAIN_EPOCH, value=i)
    tf.logging.info("Start training epoch %d", i)
    mlperf_log.gnmt_print(key=mlperf_log.INPUT_SIZE,
                          value=hparams.num_examples_per_epoch)
    steps_per_epoch = int(hparams.num_examples_per_epoch /
                          hparams.batch_size)
    max_steps = steps_per_epoch * (i + 1)
    estimator.train(input_fn=pipeline, max_steps=max_steps,
                    hooks=[pipeline])
    mlperf_log.gnmt_print(key=mlperf_log.TRAIN_CHECKPOINT,
                          value=("Under " + hparams.out_dir))
    tf.logging.info("End training epoch %d", i)

    mlperf_log.gnmt_print(key=mlperf_log.EVAL_START)
    score = nmt_estimator.get_metric(hparams, estimator)
    tf.logging.info("Score after epoch %d: %f", i, score)
    mlperf_log.gnmt_print(key=mlperf_log.EVAL_ACCURACY, value=score)
    mlperf_log.gnmt_print(key=mlperf_log.EVAL_STOP, value=i)
    if score >= hparams.target_bleu:
      mlperf_log.gnmt_print(mlperf_log.RUN_STOP, {"success": True})
      return score

  mlperf_log.gnmt_print(mlperf_log.RUN_STOP, {"success": False})
  return score
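# A short note (not from the source) on the max_steps arithmetic above:
# Estimator.train(max_steps=N) runs until the *global* step reaches N,
# so passing steps_per_epoch * (i + 1) on each loop pass advances
# training by exactly one epoch. E.g. with steps_per_epoch = 1000,
# epoch 0 trains to step 1000 and epoch 1 resumes there and stops at
# step 2000.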
def main():
    mlperf_log.ROOT_DIR_GNMT = os.path.dirname(os.path.abspath(__file__))
    mlperf_log.LOGGER.propagate = False
    mlperf_log.gnmt_print(key=mlperf_log.RUN_START)

    args = exp.get_arguments(parse_args(), show=True)
    device = exp.get_device()
    chrono = exp.chrono()

    if not args.cudnn:
        torch.backends.cudnn.enabled = False

    # initialize distributed backend
    distributed = args.world_size > 1
    if distributed:
        backend = 'nccl' if args.cuda else 'gloo'
        dist.init_process_group(backend=backend, rank=args.rank,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    # create directory for results
    save_path = os.environ.get('OUTPUT_DIRECTORY')
    if save_path is None:
        save_path = '/tmp'
    if args.save is not None:
        save_path = os.path.join(args.results_dir, args.save)
    os.makedirs(save_path, exist_ok=True)

    # setup logging
    log_filename = f'log_gpu_{args.rank}.log'
    setup_logging(os.path.join(save_path, log_filename))

    if args.cuda:
        torch.cuda.set_device(args.rank)

    # build tokenizer
    tokenizer = Tokenizer(os.path.join(args.dataset_dir,
                                       config.VOCAB_FNAME))

    train_data = ParallelDataset(
        src_fname=os.path.join(args.dataset_dir, config.SRC_TRAIN_FNAME),
        tgt_fname=os.path.join(args.dataset_dir, config.TGT_TRAIN_FNAME),
        tokenizer=tokenizer,
        min_len=args.min_length_train,
        max_len=args.max_length_train,
        sort=False,
        max_size=args.max_size)
    mlperf_log.gnmt_print(key=mlperf_log.PREPROC_NUM_TRAIN_EXAMPLES,
                          value=len(train_data))

    vocab_size = tokenizer.vocab_size
    mlperf_log.gnmt_print(key=mlperf_log.PREPROC_VOCAB_SIZE,
                          value=vocab_size)

    # build GNMT model
    model_config = dict(vocab_size=vocab_size, math=args.math,
                        **literal_eval(args.model_config))
    model = models.GNMT(**model_config)
    logging.info(model)

    batch_first = model.batch_first

    # define loss function (criterion) and optimizer
    criterion = build_criterion(vocab_size, config.PAD, args.smoothing)
    opt_config = literal_eval(args.optimization_config)

    # create trainer
    trainer_options = dict(
        criterion=criterion,
        grad_clip=args.grad_clip,
        save_path=save_path,
        save_freq=args.save_freq,
        save_info={'config': args, 'tokenizer': tokenizer},
        opt_config=opt_config,
        batch_first=batch_first,
        keep_checkpoints=args.keep_checkpoints,
        math=args.math,
        print_freq=args.print_freq,
        cuda=args.cuda,
        distributed=distributed)

    trainer_options['model'] = model
    trainer = trainers.Seq2SeqTrainer(**trainer_options,
                                      number=args.number)

    translator = Translator(model,
                            tokenizer,
                            beam_size=args.beam_size,
                            max_seq_len=args.max_length_val,
                            len_norm_factor=args.len_norm_factor,
                            len_norm_const=args.len_norm_const,
                            cov_penalty_factor=args.cov_penalty_factor,
                            cuda=args.cuda)

    num_parameters = sum([l.nelement() for l in model.parameters()])

    # get data loaders
    train_loader = train_data.get_loader(batch_size=args.batch_size,
                                         batch_first=batch_first,
                                         shuffle=True,
                                         bucket=args.bucketing,
                                         num_workers=args.workers,
                                         drop_last=True,
                                         distributed=distributed)

    mlperf_log.gnmt_print(key=mlperf_log.INPUT_BATCH_SIZE,
                          value=args.batch_size * args.world_size)
    mlperf_log.gnmt_print(key=mlperf_log.INPUT_SIZE,
                          value=train_loader.sampler.num_samples)

    # training loop
    best_loss = float('inf')
    mlperf_log.gnmt_print(key=mlperf_log.TRAIN_LOOP)
    for epoch in range(0, args.repeat):
        with chrono.time('train') as t:
            if distributed:
                train_loader.sampler.set_epoch(epoch)
            trainer.epoch = epoch
            train_loss = trainer.optimize(train_loader)
        exp.log_epoch_loss(train_loss)
        exp.show_eta(epoch, t)

    exp.report()
def _make_distributed_pipeline(hparams, num_hosts):
  """Makes the distributed input pipeline.

  make_distributed_pipeline must be used in the PER_HOST_V1
  configuration.

  Note: we return both the input function and the hook because
  MultiDeviceIterator is not compatible with Estimator / TPUEstimator.

  Args:
    hparams: The hyperparameters to use.
    num_hosts: The number of hosts we're running across.

  Returns:
    A MultiDeviceIterator.
  """
  # TODO: Merge with the original copy in iterator_utils.py.
  # pylint: disable=g-long-lambda,line-too-long
  global_batch_size = hparams.batch_size
  if global_batch_size % num_hosts != 0:
    raise ValueError(
        "global_batch_size (%s) must be a multiple of num_hosts (%s)" %
        (global_batch_size, num_hosts))

  # Optionally choose from `choose_buckets` buckets simultaneously.
  if hparams.choose_buckets:
    window_batch_size = int(global_batch_size / hparams.choose_buckets)
  else:
    window_batch_size = global_batch_size

  per_host_batch_size = global_batch_size / num_hosts

  output_buffer_size = global_batch_size * 100

  with tf.device("/job:worker/replica:0/task:0/device:CPU:0"):
    # From estimator.py
    src_file = "%s.%s" % (hparams.train_prefix, hparams.src)
    tgt_file = "%s.%s" % (hparams.train_prefix, hparams.tgt)
    src_vocab_file = hparams.src_vocab_file
    tgt_vocab_file = hparams.tgt_vocab_file
    src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
        src_vocab_file, tgt_vocab_file, hparams.share_vocab)
    src_dataset = tf.data.TextLineDataset(src_file).prefetch(
        output_buffer_size)
    tgt_dataset = tf.data.TextLineDataset(tgt_file).prefetch(
        output_buffer_size)
    mlperf_log.gnmt_print(key=mlperf_log.INPUT_BATCH_SIZE,
                          value=global_batch_size)
    mlperf_log.gnmt_print(key=mlperf_log.TRAIN_HP_MAX_SEQ_LEN,
                          value=hparams.src_max_len)

    # Define local variables that are parameters in
    # iterator_utils.make_input_fn.
    sos = hparams.sos
    eos = hparams.eos
    random_seed = hparams.random_seed
    num_buckets = hparams.num_buckets
    src_max_len = hparams.src_max_len
    tgt_max_len = hparams.tgt_max_len
    num_parallel_calls = 100  # constant in iterator_utils.py
    skip_count = None  # constant in estimator.py
    reshuffle_each_iteration = True  # constant in estimator.py
    use_char_encode = hparams.use_char_encode
    filter_oversized_sequences = True  # constant in estimator.py

    # From iterator_utils.py
    if use_char_encode:
      src_eos_id = vocab_utils.EOS_CHAR_ID
    else:
      src_eos_id = tf.cast(src_vocab_table.lookup(tf.constant(eos)),
                           tf.int32)
    tgt_sos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(sos)),
                         tf.int32)
    tgt_eos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(eos)),
                         tf.int32)

    src_tgt_dataset = tf.data.Dataset.zip((src_dataset, tgt_dataset))

    mlperf_log.gnmt_print(key=mlperf_log.INPUT_SHARD, value=1)
    if skip_count is not None:
      src_tgt_dataset = src_tgt_dataset.skip(skip_count)

    def map_fn_1(src, tgt):
      src = tf.string_split([src]).values
      tgt = tf.string_split([tgt]).values
      src_size = tf.size(src)
      tgt_size = tf.size(tgt)
      size_ok_bool = tf.logical_and(src_size > 0, tgt_size > 0)
      if filter_oversized_sequences:
        oversized = tf.logical_and(src_size < src_max_len,
                                   tgt_size < tgt_max_len)
        size_ok_bool = tf.logical_and(size_ok_bool, oversized)
      if src_max_len:
        src = src[:src_max_len]
      if tgt_max_len:
        tgt = tgt[:tgt_max_len]
      return (src, tgt, size_ok_bool)

    src_tgt_bool_dataset = src_tgt_dataset.map(
        map_fn_1, num_parallel_calls=num_parallel_calls)
    src_tgt_bool_dataset = src_tgt_bool_dataset.filter(
        lambda src, tgt, filter_bool: filter_bool)

    def map_fn_2(src, tgt, unused_filter_bool):
      if use_char_encode:
        src = tf.reshape(vocab_utils.tokens_to_bytes(src), [-1])
        tgt = tf.cast(tgt_vocab_table.lookup(tgt), tf.int32)
      else:
        src = tf.cast(src_vocab_table.lookup(src), tf.int32)
        tgt = tf.cast(tgt_vocab_table.lookup(tgt), tf.int32)

      # Create a tgt_input prefixed with <sos> and a tgt_output suffixed
      # with <eos>.
      tgt_in = tf.concat(([tgt_sos_id], tgt), 0)
      tgt_out = tf.concat((tgt, [tgt_eos_id]), 0)

      # Add in sequence lengths.
      if use_char_encode:
        src_len = tf.to_int32(
            tf.size(src) / vocab_utils.DEFAULT_CHAR_MAXLEN)
      else:
        src_len = tf.size(src)
      tgt_len = tf.size(tgt_in)
      return src, tgt_in, tgt_out, src_len, tgt_len

    # Convert the word strings to ids. Word strings that are not in the
    # vocab get the lookup table's default_value integer.
    mlperf_log.gnmt_print(key=mlperf_log.PREPROC_TOKENIZE_TRAINING)
    src_tgt_dataset = src_tgt_bool_dataset.map(
        map_fn_2, num_parallel_calls=num_parallel_calls)

    src_tgt_dataset = src_tgt_dataset.prefetch(output_buffer_size)
    src_tgt_dataset = src_tgt_dataset.cache()
    src_tgt_dataset = src_tgt_dataset.shuffle(
        output_buffer_size, random_seed, reshuffle_each_iteration).repeat()

    # Bucket by source sequence length (buckets for lengths 0-9,
    # 10-19, ...)
    def batching_func(x):
      return x.padded_batch(
          window_batch_size,
          # The first three entries are the source and target line rows;
          # these have unknown-length vectors. The last two entries are
          # the source and target row sizes; these are scalars.
          padded_shapes=(
              tf.TensorShape([src_max_len]),  # src
              tf.TensorShape([tgt_max_len]),  # tgt_input
              tf.TensorShape([tgt_max_len]),  # tgt_output
              tf.TensorShape([]),  # src_len
              tf.TensorShape([])),  # tgt_len
          # Pad the source and target sequences with eos tokens. (Though
          # notice we don't generally need to do this since later on we
          # will be masking out calculations past the true sequence.)
          padding_values=(
              src_eos_id,  # src
              tgt_eos_id,  # tgt_input
              tgt_eos_id,  # tgt_output
              0,  # src_len -- unused
              0),  # tgt_len -- unused
          # For TPU, must set drop_remainder to True or batch size will
          # be None.
          drop_remainder=True)

    def key_func(unused_1, unused_2, unused_3, src_len, tgt_len):
      """Calculate bucket_width by maximum source sequence length."""
      # Pairs with length [0, bucket_width) go to bucket 0, length
      # [bucket_width, 2 * bucket_width) go to bucket 1, etc. Pairs with
      # length over ((num_bucket-1) * bucket_width) words all go into
      # the last bucket.
      if src_max_len:
        bucket_width = (src_max_len + num_buckets - 1) // num_buckets
      else:
        bucket_width = 10

      # Bucket sentence pairs by the length of their source sentence and
      # target sentence.
      bucket_id = tf.maximum(src_len // bucket_width,
                             tgt_len // bucket_width)
      return tf.to_int64(tf.minimum(num_buckets, bucket_id))

    def reduce_func(unused_key, windowed_data):
      return batching_func(windowed_data)

    if num_buckets > 1:
      batched_dataset = src_tgt_dataset.apply(
          tf.contrib.data.group_by_window(key_func=key_func,
                                          reduce_func=reduce_func,
                                          window_size=window_batch_size))
    else:
      batched_dataset = batching_func(src_tgt_dataset)

    batched_dataset = batched_dataset.map(
        lambda src, tgt_in, tgt_out, source_size, tgt_in_size: ({
            "source": src,
            "target_input": tgt_in,
            "target_output": tgt_out,
            "source_sequence_length": source_size,
            "target_sequence_length": tgt_in_size
        }))

    re_batched_dataset = batched_dataset.apply(
        tf.contrib.data.unbatch()).batch(
            int(per_host_batch_size), drop_remainder=True)

    output_devices = [
        "/job:worker/replica:0/task:%d/device:CPU:0" % i
        for i in range(num_hosts)
    ]
    options = tf.data.Options()
    options.experimental_numa_aware = True
    options.experimental_filter_fusion = True
    options.experimental_map_and_filter_fusion = True
    re_batched_dataset = re_batched_dataset.with_options(options)

    multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
        dataset=re_batched_dataset,
        devices=output_devices,
        max_buffer_size=10,
        prefetch_buffer_size=10,
        source_device="/job:worker/replica:0/task:0/device:CPU:0")
    return multi_device_iterator
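# A toy illustration (not from the source) of the group_by_window
# bucketing used above: elements are routed to a bucket by key_func and
# batched only with elements from the same bucket.
import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices(
    tf.constant([1, 9, 2, 8, 3, 7], dtype=tf.int64))
bucketed = ds.apply(tf.contrib.data.group_by_window(
    key_func=lambda x: x // 5,  # bucket 0: values < 5, bucket 1: >= 5
    reduce_func=lambda key, window: window.batch(3),
    window_size=3))
batch = bucketed.make_one_shot_iterator().get_next()
with tf.Session() as sess:
  print(sess.run(batch))  # [1, 2, 3] -- the first full bucket-0 window
  print(sess.run(batch))  # [9, 8, 7]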
if __name__ == "__main__":
  nmt_parser = argparse.ArgumentParser()
  add_arguments(nmt_parser)
  FLAGS, unparsed = nmt_parser.parse_known_args()
  mlperf_log.gnmt_print(key=mlperf_log.RUN_START)
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)

  if FLAGS.use_fp16 and FLAGS.use_dist_strategy:
    raise ValueError("use_fp16 and use_dist_strategy aren't compatible")

  # Set up hacky envvars.
  # Hack that affects Defun in attention_wrapper.py
  active_xla_option_nums = np.sum([FLAGS.use_xla, FLAGS.use_autojit_xla,
                                   FLAGS.xla_compile])
  if active_xla_option_nums > 1:
    raise ValueError(
        "Only one of use_xla, xla_compile, use_autojit_xla can be set")

  os.environ["use_xla"] = str(FLAGS.use_xla).lower()
  if FLAGS.use_xla:
    os.environ["use_defun"] = str(True).lower()
  else:
    os.environ["use_defun"] = str(FLAGS.use_defun).lower()
  utils.print_out("use_defun is %s for attention" % os.environ["use_defun"])

  # TODO(jamesqin): retire this config after Cuda9.1
  os.environ["use_fp32_batch_matmul"] = (
      "true" if FLAGS.use_fp32_batch_matmul else "false")
  os.environ["xla_compile"] = "true" if FLAGS.xla_compile else "false"
  os.environ["force_inputs_padding"] = (
      "true" if FLAGS.force_inputs_padding else "false")

  if FLAGS.mode == "train":
    utils.print_out("Running training mode.")
    FLAGS.num_buckets = 5
    default_hparams = create_hparams(FLAGS)
    run_main(FLAGS, default_hparams, estimator.train_fn)
  elif FLAGS.mode == "infer":
    utils.print_out("Running inference mode.")
    # Random
    random_seed = FLAGS.random_seed
    if random_seed is not None and random_seed > 0:
      utils.print_out("# Set random seed to %d" % random_seed)
      random.seed(random_seed)
      np.random.seed(random_seed)
      tf.set_random_seed(random_seed)

    # Model output directory
    output_dir = FLAGS.output_dir
    if output_dir and not tf.gfile.Exists(output_dir):
      utils.print_out("# Creating output directory %s ..." % output_dir)
      tf.gfile.MakeDirs(output_dir)

    # Load hparams.
    default_hparams = create_hparams(FLAGS)
    default_hparams.num_buckets = 1
    # The estimator model_fn is written in a way allowing train hparams
    # to be passed in infer mode.
    hparams = create_or_load_hparams(default_hparams, FLAGS.hparams_path)
    utils.print_out("infer_hparams:")
    utils.print_hparams(hparams)

    # Run evaluation when there's a new checkpoint
    for i, ckpt in enumerate(
        evaluation_utils.get_all_checkpoints(FLAGS.output_dir)):
      tf.logging.info("Starting to evaluate...")
      eval_start = time.time()
      bleu_score = estimator.eval_fn(hparams, ckpt)
      eval_end = time.time()
      utils.print_out("eval time for %d th ckpt: %.2f mins" %
                      (i, (eval_end - eval_start) / 60.),
                      f=sys.stderr)
  else:
    assert FLAGS.mode == "train_and_eval"
    utils.print_out("Running train and eval mode.")
    # Random
    random_seed = FLAGS.random_seed
    if random_seed is not None and random_seed > 0:
      utils.print_out("# Set random seed to %d" % random_seed)
      random.seed(random_seed)
      np.random.seed(random_seed)
      tf.set_random_seed(random_seed)

    # Model output directory
    output_dir = FLAGS.output_dir
    if output_dir and not tf.gfile.Exists(output_dir):
      utils.print_out("# Creating output directory %s ..." % output_dir)
      tf.gfile.MakeDirs(output_dir)

    # Load hparams.
    default_hparams = create_hparams(FLAGS)
    default_hparams.num_buckets = 5
    hparams = create_or_load_hparams(default_hparams, FLAGS.hparams_path)
    utils.print_out("training hparams:")
    utils.print_hparams(hparams)
    with tf.gfile.GFile(os.path.join(output_dir, "train_hparams.txt"),
                        "w") as f:
      f.write(utils.serialize_hparams(hparams) + "\n")

    # The estimator model_fn is written in a way allowing train hparams
    # to be passed in infer mode.
    infer_hparams = tf.contrib.training.HParams(**hparams.values())
    infer_hparams.num_buckets = 1
    utils.print_out("infer_hparams:")
    utils.print_hparams(infer_hparams)
    with tf.gfile.GFile(os.path.join(output_dir, "infer_hparams.txt"),
                        "w") as f:
      f.write(utils.serialize_hparams(infer_hparams) + "\n")

    epochs = 0
    should_stop = epochs >= FLAGS.max_train_epochs

    mlperf_log.gnmt_print(key=mlperf_log.TRAIN_LOOP)
    mlperf_log.gnmt_print(key=mlperf_log.EVAL_TARGET,
                          value=hparams.target_bleu)
    while not should_stop:
      utils.print_out("Starting epoch %d" % epochs)
      mlperf_log.gnmt_print(key=mlperf_log.TRAIN_EPOCH, value=epochs)
      mlperf_log.gnmt_print(
          key=mlperf_log.INPUT_SIZE,
          value=iterator_utils.get_effective_train_epoch_size(hparams))
      mlperf_log.gnmt_print(
          key=mlperf_log.TRAIN_CHECKPOINT,
          value=("Under " + hparams.output_dir))
      try:
        train_start = time.time()
        estimator.train_fn(hparams)
      except tf.errors.OutOfRangeError:
        utils.print_out("training hits OutOfRangeError", f=sys.stderr)

      train_end = time.time()
      utils.print_out("training time for epoch %d: %.2f mins" %
                      (epochs, (train_end - train_start) / 60.),
                      f=sys.stderr)

      # This is probably sub-optimal, doing eval per-epoch
      mlperf_log.gnmt_print(key=mlperf_log.EVAL_START)
      eval_start = time.time()
      bleu_score = estimator.eval_fn(infer_hparams)
      eval_end = time.time()
      utils.print_out("eval time for epoch %d: %.2f mins" %
                      (epochs, (eval_end - eval_start) / 60.),
                      f=sys.stderr)
      mlperf_log.gnmt_print(key=mlperf_log.EVAL_ACCURACY,
                            value={"epoch": epochs, "value": bleu_score})
      mlperf_log.gnmt_print(key=mlperf_log.EVAL_STOP, value=epochs)

      if FLAGS.debug or bleu_score > FLAGS.target_bleu:
        should_stop = True
        utils.print_out(
            "Stop job since target bleu is reached at epoch %d." % epochs,
            f=sys.stderr)
        mlperf_log.gnmt_print(mlperf_log.RUN_STOP, {"success": True})

      if epochs >= FLAGS.max_train_epochs:
        should_stop = True
        utils.print_out("Stop job since max_train_epochs is reached.",
                        f=sys.stderr)
        mlperf_log.gnmt_print(mlperf_log.RUN_STOP, {"success": False})

      epochs += 1
    mlperf_log.gnmt_print(key=mlperf_log.RUN_FINAL)
def gnmt_print(*args, **kwargs):
    barrier()
    if get_rank() == 0:
        kwargs['stack_offset'] = 2
        mlperf_log.gnmt_print(*args, **kwargs)
def main():
    mlperf_log.ROOT_DIR_GNMT = os.path.dirname(os.path.abspath(__file__))
    mlperf_log.LOGGER.propagate = False
    mlperf_log.gnmt_print(key=mlperf_log.RUN_START)

    args = parse_args()
    print(args)

    if not args.cudnn:
        torch.backends.cudnn.enabled = False

    mlperf_log.gnmt_print(key=mlperf_log.RUN_SET_RANDOM_SEED)
    if args.seed:
        torch.manual_seed(args.seed + args.rank)

    # initialize distributed backend
    distributed = args.world_size > 1
    if distributed:
        backend = 'nccl' if args.cuda else 'gloo'
        dist.init_process_group(backend=backend, rank=args.rank,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    # create directory for results
    save_path = os.path.join(args.results_dir, args.save)
    os.makedirs(save_path, exist_ok=True)

    # setup logging
    log_filename = f'log_gpu_{args.rank}.log'
    setup_logging(os.path.join(save_path, log_filename))

    logging.info(f'Saving results to: {save_path}')
    logging.info(f'Run arguments: {args}')

    if args.cuda:
        torch.cuda.set_device(args.rank)

    # build tokenizer
    tokenizer = Tokenizer(os.path.join(args.dataset_dir,
                                       config.VOCAB_FNAME))

    # build datasets
    mlperf_log.gnmt_print(key=mlperf_log.PREPROC_TOKENIZE_TRAINING)
    mlperf_log.gnmt_print(key=mlperf_log.TRAIN_HP_MAX_SEQ_LEN,
                          value=args.max_length_train)
    train_data = ParallelDataset(
        src_fname=os.path.join(args.dataset_dir, config.SRC_TRAIN_FNAME),
        tgt_fname=os.path.join(args.dataset_dir, config.TGT_TRAIN_FNAME),
        tokenizer=tokenizer,
        min_len=args.min_length_train,
        max_len=args.max_length_train,
        sort=False,
        max_size=args.max_size)
    mlperf_log.gnmt_print(key=mlperf_log.PREPROC_NUM_TRAIN_EXAMPLES,
                          value=len(train_data))

    val_data = ParallelDataset(
        src_fname=os.path.join(args.dataset_dir, config.SRC_VAL_FNAME),
        tgt_fname=os.path.join(args.dataset_dir, config.TGT_VAL_FNAME),
        tokenizer=tokenizer,
        min_len=args.min_length_val,
        max_len=args.max_length_val,
        sort=True)

    mlperf_log.gnmt_print(key=mlperf_log.PREPROC_TOKENIZE_EVAL)
    test_data = ParallelDataset(
        src_fname=os.path.join(args.dataset_dir, config.SRC_TEST_FNAME),
        tgt_fname=os.path.join(args.dataset_dir, config.TGT_TEST_FNAME),
        tokenizer=tokenizer,
        min_len=args.min_length_val,
        max_len=args.max_length_val,
        sort=False)
    mlperf_log.gnmt_print(key=mlperf_log.PREPROC_NUM_EVAL_EXAMPLES,
                          value=len(test_data))

    vocab_size = tokenizer.vocab_size
    mlperf_log.gnmt_print(key=mlperf_log.PREPROC_VOCAB_SIZE,
                          value=vocab_size)

    # build GNMT model
    model_config = dict(vocab_size=vocab_size, math=args.math,
                        **literal_eval(args.model_config))
    # SSY: the real model lives in seq2seq/models/gnmt.py
    model = models.GNMT(**model_config)
    logging.info(model)

    batch_first = model.batch_first

    # define loss function (criterion) and optimizer
    criterion = build_criterion(vocab_size, config.PAD, args.smoothing)
    opt_config = literal_eval(args.optimization_config)
    logging.info(f'Training optimizer: {opt_config}')

    # create trainer
    trainer_options = dict(
        criterion=criterion,
        grad_clip=args.grad_clip,
        save_path=save_path,
        save_freq=args.save_freq,
        save_info={'config': args, 'tokenizer': tokenizer},
        opt_config=opt_config,
        batch_first=batch_first,
        keep_checkpoints=args.keep_checkpoints,
        math=args.math,
        print_freq=args.print_freq,
        cuda=args.cuda,
        distributed=distributed)

    trainer_options['model'] = model
    # SSY: only the trainer (seq2seq/train/trainer.py), not the models.
    trainer = trainers.Seq2SeqTrainer(**trainer_options)

    translator = Translator(model,
                            tokenizer,
                            beam_size=args.beam_size,
                            max_seq_len=args.max_length_val,
                            len_norm_factor=args.len_norm_factor,
                            len_norm_const=args.len_norm_const,
                            cov_penalty_factor=args.cov_penalty_factor,
                            cuda=args.cuda)

    num_parameters = sum([l.nelement() for l in model.parameters()])
    logging.info(f'Number of parameters: {num_parameters}')

    # optionally resume from a checkpoint
    if args.resume:
        checkpoint_file = args.resume
        if os.path.isdir(checkpoint_file):
            checkpoint_file = os.path.join(checkpoint_file,
                                           'model_best.pth')
        if os.path.isfile(checkpoint_file):
            trainer.load(checkpoint_file)
        else:
            logging.error(f'No checkpoint found at {args.resume}')

    # get data loaders
    train_loader = train_data.get_loader(batch_size=args.batch_size,
                                         batch_first=batch_first,
                                         shuffle=True,
                                         bucket=args.bucketing,
                                         num_workers=args.workers,
                                         drop_last=True,
                                         distributed=distributed)

    mlperf_log.gnmt_print(key=mlperf_log.INPUT_BATCH_SIZE,
                          value=args.batch_size * args.world_size)
    mlperf_log.gnmt_print(key=mlperf_log.INPUT_SIZE,
                          value=train_loader.sampler.num_samples)

    val_loader = val_data.get_loader(batch_size=args.eval_batch_size,
                                     batch_first=batch_first,
                                     shuffle=False,
                                     num_workers=args.workers,
                                     drop_last=False,
                                     distributed=False)

    test_loader = test_data.get_loader(batch_size=args.eval_batch_size,
                                       batch_first=batch_first,
                                       shuffle=False,
                                       num_workers=0,
                                       drop_last=False,
                                       distributed=False)
    mlperf_log.gnmt_print(key=mlperf_log.EVAL_SIZE,
                          value=len(test_loader.sampler))

    # training loop
    best_loss = float('inf')
    mlperf_log.gnmt_print(key=mlperf_log.TRAIN_LOOP)
    for epoch in range(args.start_epoch, args.epochs):
        mlperf_log.gnmt_print(key=mlperf_log.TRAIN_EPOCH, value=epoch)
        logging.info(f'Starting epoch {epoch}')

        if distributed:
            train_loader.sampler.set_epoch(epoch)

        trainer.epoch = epoch
        train_loss = trainer.optimize(train_loader)

        # evaluate on validation set
        if args.rank == 0 and not args.disable_eval:
            logging.info('Running validation on dev set')
            val_loss = trainer.evaluate(val_loader)

            # remember best prec@1 and save checkpoint
            is_best = val_loss < best_loss
            best_loss = min(val_loss, best_loss)
            mlperf_log.gnmt_print(key=mlperf_log.TRAIN_CHECKPOINT)
            trainer.save(save_all=args.save_all, is_best=is_best)

            logging.info(f'Epoch: {epoch}\t'
                         f'Training Loss {train_loss:.4f}\t'
                         f'Validation Loss {val_loss:.4f}')
        else:
            logging.info(f'Epoch: {epoch}\t'
                         f'Training Loss {train_loss:.4f}')

        if args.cuda:
            break_training = torch.cuda.LongTensor([0])
        else:
            break_training = torch.LongTensor([0])

        if args.rank == 0 and not args.disable_eval:
            logging.info('Running evaluation on test set')
            mlperf_log.gnmt_print(key=mlperf_log.EVAL_START, value=epoch)

            model.eval()
            torch.cuda.empty_cache()

            eval_path = os.path.join(save_path, f'eval_epoch_{epoch}')
            eval_file = open(eval_path, 'w')

            for i, (src, tgt, indices) in enumerate(test_loader):
                src, src_length = src

                if translator.batch_first:
                    batch_size = src.size(0)
                else:
                    batch_size = src.size(1)
                beam_size = args.beam_size

                bos = [translator.insert_target_start] * \
                    (batch_size * beam_size)
                bos = torch.LongTensor(bos)
                if translator.batch_first:
                    bos = bos.view(-1, 1)
                else:
                    bos = bos.view(1, -1)

                src_length = torch.LongTensor(src_length)

                if args.cuda:
                    src = src.cuda()
                    src_length = src_length.cuda()
                    bos = bos.cuda()

                with torch.no_grad():
                    context = translator.model.encode(src, src_length)
                    context = [context, src_length, None]

                    if beam_size == 1:
                        generator = translator.generator.greedy_search
                    else:
                        generator = translator.generator.beam_search
                    preds, lengths, counter = generator(batch_size, bos,
                                                        context)

                preds = preds.cpu()
                lengths = lengths.cpu()

                output = []
                for idx, pred in enumerate(preds):
                    end = lengths[idx] - 1
                    pred = pred[1:end]
                    pred = pred.tolist()
                    out = translator.tok.detokenize(pred)
                    output.append(out)

                # restore original corpus order
                output = [output[indices.index(i)]
                          for i in range(len(output))]
                for line in output:
                    eval_file.write(line)
                    eval_file.write('\n')

            eval_file.close()

            # run moses detokenizer
            detok_path = os.path.join(args.dataset_dir, config.DETOKENIZER)
            detok_eval_path = eval_path + '.detok'

            with open(detok_eval_path, 'w') as detok_eval_file, \
                    open(eval_path, 'r') as eval_file:
                subprocess.run(['perl', f'{detok_path}'],
                               stdin=eval_file,
                               stdout=detok_eval_file,
                               stderr=subprocess.DEVNULL)

            # run sacrebleu
            reference_path = os.path.join(args.dataset_dir,
                                          config.TGT_TEST_TARGET_FNAME)
            sacrebleu = subprocess.run(
                [f'sacrebleu --input {detok_eval_path} {reference_path} '
                 '--score-only -lc --tokenize intl'],
                stdout=subprocess.PIPE, shell=True)
            bleu = float(sacrebleu.stdout.strip())

            logging.info('Finished evaluation on test set')
            logging.info(f'BLEU on test dataset: {bleu}')

            if args.target_bleu:
                if bleu >= args.target_bleu:
                    logging.info('Target accuracy reached')
                    break_training[0] = 1

            torch.cuda.empty_cache()
            mlperf_log.gnmt_print(key=mlperf_log.EVAL_ACCURACY,
                                  value={"epoch": epoch, "value": bleu})
            mlperf_log.gnmt_print(key=mlperf_log.EVAL_TARGET,
                                  value=args.target_bleu)
            mlperf_log.gnmt_print(key=mlperf_log.EVAL_STOP)

        if distributed:
            dist.broadcast(break_training, 0)

        logging.info(f'Finished epoch {epoch}')

        if break_training:
            break

    mlperf_log.gnmt_print(key=mlperf_log.RUN_STOP,
                          value={"success": bool(break_training)})
    mlperf_log.gnmt_print(key=mlperf_log.RUN_FINAL)
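# A tiny sketch (not from the source) of the order-restoring idiom used
# in the evaluation loop above: the loader yields `indices` giving each
# batch item's original corpus position, and output[indices.index(i)]
# puts the translations back into corpus order.
output = ['b-trans', 'c-trans', 'a-trans']
indices = [1, 2, 0]  # original positions of the sorted batch items
restored = [output[indices.index(i)] for i in range(len(output))]
print(restored)  # ['a-trans', 'b-trans', 'c-trans']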
def extend_hparams(hparams):
  """Add new arguments to hparams."""
  # Sanity checks
  if hparams.encoder_type == "bi" and hparams.num_encoder_layers % 2 != 0:
    raise ValueError("For bi, num_encoder_layers %d should be even" %
                     hparams.num_encoder_layers)
  if (hparams.attention_architecture in ["gnmt"] and
      hparams.num_encoder_layers < 2):
    raise ValueError("For gnmt attention architecture, "
                     "num_encoder_layers %d should be >= 2" %
                     hparams.num_encoder_layers)
  if hparams.subword_option and hparams.subword_option not in ["spm", "bpe"]:
    raise ValueError("subword option must be either spm, or bpe")
  if hparams.infer_mode == "beam_search" and hparams.beam_width <= 0:
    raise ValueError("beam_width must be greater than 0 when using "
                     "beam_search decoder.")
  if hparams.infer_mode == "sample" and hparams.sampling_temperature <= 0.0:
    raise ValueError("sampling_temperature must be greater than 0.0 when "
                     "using sample decoder.")

  # Different number of encoder / decoder layers
  assert hparams.num_encoder_layers and hparams.num_decoder_layers
  if hparams.num_encoder_layers != hparams.num_decoder_layers:
    hparams.pass_hidden_state = False
    utils.print_out("Num encoder layer %d is different from num decoder "
                    "layer %d, so set pass_hidden_state to False" %
                    (hparams.num_encoder_layers,
                     hparams.num_decoder_layers))

  # Set residual layers
  num_encoder_residual_layers = 0
  num_decoder_residual_layers = 0
  if hparams.residual:
    if hparams.num_encoder_layers > 1:
      num_encoder_residual_layers = hparams.num_encoder_layers - 1
    if hparams.num_decoder_layers > 1:
      num_decoder_residual_layers = hparams.num_decoder_layers - 1

    if hparams.encoder_type == "gnmt":
      # The first unidirectional layer (after the bi-directional layer)
      # in the GNMT encoder can't have a residual connection because its
      # input is the concatenation of fw_cell and bw_cell's outputs.
      num_encoder_residual_layers = hparams.num_encoder_layers - 2

      # Compatible for GNMT models
      if hparams.num_encoder_layers == hparams.num_decoder_layers:
        num_decoder_residual_layers = num_encoder_residual_layers

  _add_argument(hparams, "num_encoder_residual_layers",
                num_encoder_residual_layers)
  _add_argument(hparams, "num_decoder_residual_layers",
                num_decoder_residual_layers)

  # Language modeling
  if hparams.language_model:
    hparams.attention = ""
    hparams.attention_architecture = ""
    hparams.pass_hidden_state = False
    hparams.share_vocab = True
    hparams.src = hparams.tgt
    utils.print_out("For language modeling, we turn off attention and "
                    "pass_hidden_state; turn on share_vocab; set src to "
                    "tgt.")

  ## Vocab
  # Get vocab file names first
  if hparams.vocab_prefix:
    src_vocab_file = hparams.vocab_prefix + "." + hparams.src
    tgt_vocab_file = hparams.vocab_prefix + "." + hparams.tgt
  else:
    raise ValueError("hparams.vocab_prefix must be provided.")

  # Source vocab
  src_vocab_size, src_vocab_file = vocab_utils.check_vocab(
      src_vocab_file,
      hparams.output_dir,
      check_special_token=hparams.check_special_token,
      sos=hparams.sos,
      eos=hparams.eos,
      unk=vocab_utils.UNK)

  # Target vocab
  if hparams.share_vocab:
    utils.print_out("  using source vocab for target")
    tgt_vocab_file = src_vocab_file
    tgt_vocab_size = src_vocab_size
  else:
    tgt_vocab_size, tgt_vocab_file = vocab_utils.check_vocab(
        tgt_vocab_file,
        hparams.output_dir,
        check_special_token=hparams.check_special_token,
        sos=hparams.sos,
        eos=hparams.eos,
        unk=vocab_utils.UNK)
  mlperf_log.gnmt_print(key=mlperf_log.PREPROC_VOCAB_SIZE,
                        value={"src": src_vocab_size,
                               "tgt": tgt_vocab_size})
  _add_argument(hparams, "src_vocab_size", src_vocab_size)
  _add_argument(hparams, "tgt_vocab_size", tgt_vocab_size)
  _add_argument(hparams, "src_vocab_file", src_vocab_file)
  _add_argument(hparams, "tgt_vocab_file", tgt_vocab_file)

  # Num embedding partitions
  _add_argument(hparams, "num_enc_emb_partitions",
                hparams.num_embeddings_partitions)
  _add_argument(hparams, "num_dec_emb_partitions",
                hparams.num_embeddings_partitions)

  # Pretrained Embeddings
  _add_argument(hparams, "src_embed_file", "")
  _add_argument(hparams, "tgt_embed_file", "")
  if hparams.embed_prefix:
    src_embed_file = hparams.embed_prefix + "." + hparams.src
    tgt_embed_file = hparams.embed_prefix + "." + hparams.tgt

    if tf.gfile.Exists(src_embed_file):
      utils.print_out("  src_embed_file %s exists" % src_embed_file)
      hparams.src_embed_file = src_embed_file
      utils.print_out(
          "For pretrained embeddings, set num_enc_emb_partitions to 1")
      hparams.num_enc_emb_partitions = 1
    else:
      utils.print_out("  src_embed_file %s doesn't exist" % src_embed_file)

    if tf.gfile.Exists(tgt_embed_file):
      utils.print_out("  tgt_embed_file %s exists" % tgt_embed_file)
      hparams.tgt_embed_file = tgt_embed_file
      utils.print_out(
          "For pretrained embeddings, set num_dec_emb_partitions to 1")
      hparams.num_dec_emb_partitions = 1
    else:
      utils.print_out("  tgt_embed_file %s doesn't exist" % tgt_embed_file)

  # Evaluation
  metric = "bleu"
  best_metric_dir = os.path.join(hparams.output_dir, "best_" + metric)
  tf.gfile.MakeDirs(best_metric_dir)
  _add_argument(hparams, "best_" + metric, 0, update=False)
  _add_argument(hparams, "best_" + metric + "_dir", best_metric_dir)

  if hparams.avg_ckpts:
    best_metric_dir = os.path.join(hparams.output_dir,
                                   "avg_best_" + metric)
    tf.gfile.MakeDirs(best_metric_dir)
    _add_argument(hparams, "avg_best_" + metric, 0, update=False)
    _add_argument(hparams, "avg_best_" + metric + "_dir", best_metric_dir)

  return hparams
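# A quick worked note (not from the source) on the residual-layer
# arithmetic above: for a GNMT encoder with num_encoder_layers = 4, the
# bidirectional layer and the first unidirectional layer carry no
# residual connections, so num_encoder_residual_layers = 4 - 2 = 2.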