def test(self, ts, steps=0, **kwargs):
    """Method that evaluates on some data.

    There are 2 modes this can run in, `feed_dict` and `dataset`.

    In `feed_dict` mode, the model cycles the test data batch-wise and feeds
    each batch in with a `feed_dict`. In `dataset` mode, the data is still
    passed to this method, but it is not fed via a `feed_dict` and is mostly
    superfluous, since the features are grafted right onto the graph. However,
    we do use it for supplying the ground truth, ids and text, so it is
    essential that the caller does not shuffle the data.

    :param ts: The test set
    :param steps: (`int`) The number of steps in an epoch
    :param kwargs: See below

    :Keyword Arguments:
        * *conll_output* (`str`) An optional CoNLL output file
        * *txts* A list of text data associated with the encoded batch
        * *batches* A list of raw batch dicts supplying `ids`

    :return: The metrics
    """
    SET_TRAIN_FLAG(False)
    total_correct = total_sum = 0
    gold_spans = []
    pred_spans = []
    self.cm = ConfusionMatrix(self.idx2classlabel)
    handle = None
    if kwargs.get("conll_output") is not None and kwargs.get('txts') is not None:
        handle = open(kwargs.get("conll_output"), "w")
    try:
        pg = create_progress_bar(steps)
        metrics = {}
        for (features, y), batch in pg(zip_longest(ts, kwargs.get('batches', []), fillvalue={})):
            correct, count, golds, guesses = self.process_batch(
                features, y, handle=handle, txts=kwargs.get("txts"), ids=batch.get("ids")
            )
            total_correct += correct
            total_sum += count
            gold_spans.extend(golds)
            pred_spans.extend(guesses)

        total_acc = total_correct / float(total_sum)
        metrics['tagging_f1'] = span_f1(gold_spans, pred_spans)
        metrics['tagging_acc'] = total_acc
        metrics.update({f"classification_{k}": v for k, v in self.cm.get_all_metrics().items()})
        if self.verbose:
            conll_metrics = per_entity_f1(gold_spans, pred_spans)
            conll_metrics['acc'] = total_acc * 100
            conll_metrics['tokens'] = total_sum
            logger.info(conlleval_output(conll_metrics))
    finally:
        if handle is not None:
            handle.close()
    return metrics
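
# The span-level F1 above compares exact span matches between gold and
# predicted spans. A minimal sketch of that standard definition -- not
# necessarily the project's `span_f1` implementation -- assuming spans are
# hashable (type, start, end) tuples:
def _span_f1_sketch(gold_spans, pred_spans):
    """Exact-match span F1: 2PR / (P + R) over span tuples."""
    gold, pred = set(gold_spans), set(pred_spans)
    correct = len(gold & pred)
    p = correct / len(pred) if pred else 0.0
    r = correct / len(gold) if gold else 0.0
    return 2 * p * r / (p + r) if (p + r) else 0.0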
def train(self, ts, reporting_fns):
    """Train by looping over the steps

    For a `tf.dataset`-backed `fit_func`, we are using the previously wired
    `dataset`s in the model. For `feed_dict`, we convert the ts samples to
    `feed_dict`s and hand them in one-by-one

    :param ts: The training set
    :param reporting_fns: A list of reporting hooks
    :return: Metrics
    """
    SET_TRAIN_FLAG(True)
    epoch_loss = tf.Variable(0.0)
    epoch_div = tf.Variable(0, dtype=tf.int32)
    nstep_loss = tf.Variable(0.0)
    nstep_div = tf.Variable(0, dtype=tf.int32)
    self.nstep_start = time.perf_counter()
    start = time.perf_counter()

    @tf.function
    def _train_step(features, y):
        """Replicated training step."""
        loss = self.optimizer.update(self.model, features, y)
        toks = self._num_toks(features['tgt_len'])
        report_loss = loss * tf.cast(toks, tf.float32)
        return report_loss, toks

    with autograph_options({"function_optimization": False, "layout_optimizer": False}):
        for features, y in ts:
            # Decoder input is the gold target with the last token dropped (teacher forcing)
            features['dst'] = y[:, :-1]
            step_report_loss, step_toks = _train_step(features, y)
            epoch_loss.assign_add(step_report_loss)
            nstep_loss.assign_add(step_report_loss)
            epoch_div.assign_add(step_toks)
            nstep_div.assign_add(step_toks)
            step = self.optimizer.global_step.numpy() + 1
            if step % self.nsteps == 0:
                metrics = self.calc_metrics(nstep_loss.numpy(), nstep_div.numpy())
                self.report(step, metrics, self.nstep_start, 'Train', 'STEP', reporting_fns, self.nsteps)
                nstep_loss.assign(0.0)
                nstep_div.assign(0)
                self.nstep_start = time.perf_counter()

    epoch_loss = epoch_loss.numpy()
    epoch_div = epoch_div.numpy()
    metrics = self.calc_metrics(epoch_loss, epoch_div)
    self.train_epochs += 1
    self.report(self.train_epochs, metrics, start, 'Train', 'EPOCH', reporting_fns)
    return metrics
def _train(self, loader, steps=0, **kwargs):
    """Train an epoch of data using either the input loader or `tf.dataset`

    In non-`tf.dataset` mode, we cycle the loader data feed, pull a batch and
    feed it via the feed dict. When we use `tf.dataset`s under the hood, this
    function simply uses the loader to know how many steps to train. We do use
    a `feed_dict` for passing the `TRAIN_FLAG` in either case

    :param loader: A data feed
    :param steps: (`int`) The number of steps in an epoch
    :param kwargs: See below

    :Keyword Arguments:
        * *dataset* (`bool`) Set to `True` if using `tf.dataset`s, defaults to `True`
        * *reporting_fns* (`list`) A list of reporting hooks to use

    :return: Metrics
    """
    SET_TRAIN_FLAG(True)
    reporting_fns = kwargs.get('reporting_fns', [])
    pg = create_progress_bar(steps)
    epoch_loss = tf.Variable(0.0)
    epoch_div = tf.Variable(0, dtype=tf.int32)
    nstep_loss = tf.Variable(0.0)
    nstep_div = tf.Variable(0, dtype=tf.int32)
    self.nstep_start = time.perf_counter()

    @tf.function
    def _train_step(inputs):
        features, y = inputs
        loss = self.optimizer.update(self.model, features, y)
        batchsz = get_shape_as_list(y)[0]
        report_loss = loss * batchsz
        return report_loss, batchsz

    with autograph_options({"function_optimization": False, "layout_optimizer": False}):
        for inputs in pg(loader):
            step_report_loss, step_batchsz = _train_step(inputs)
            epoch_loss.assign_add(step_report_loss)
            nstep_loss.assign_add(step_report_loss)
            epoch_div.assign_add(step_batchsz)
            nstep_div.assign_add(step_batchsz)
            step = self.optimizer.global_step.numpy() + 1
            if step % self.nsteps == 0:
                metrics = self.calc_metrics(nstep_loss.numpy(), nstep_div.numpy())
                self.report(step, metrics, self.nstep_start, 'Train', 'STEP', reporting_fns, self.nsteps)
                nstep_loss.assign(0.0)
                nstep_div.assign(0)
                self.nstep_start = time.perf_counter()

    epoch_loss = epoch_loss.numpy()
    epoch_div = epoch_div.numpy()
    metrics = self.calc_metrics(epoch_loss, epoch_div)
    return metrics
def get_input(self, training=False):
    SET_TRAIN_FLAG(training)
    dataset = tf.data.Dataset.from_tensor_slices((self.x, self.y))
    dataset = dataset.shuffle(buffer_size=SHUF_BUF_SZ)
    dataset = dataset.batch(50)
    dataset = dataset.map(lambda x, y: ({'word': x, 'lengths': count_nonzero(x, axis=1)}, y))
    dataset = dataset.prefetch(NUM_PREFETCH)
    return dataset
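
# `get_input` yields (features, y) batches, so eager code can iterate the
# pipeline directly. A self-contained sketch of the same pattern with dummy
# data (the buffer and prefetch sizes here are stand-ins, not the project's
# SHUF_BUF_SZ/NUM_PREFETCH values):
import tensorflow as tf

def _demo_input(x, y, batchsz=50):
    ds = tf.data.Dataset.from_tensor_slices((x, y))
    ds = ds.shuffle(buffer_size=100).batch(batchsz)
    ds = ds.map(lambda x, y: ({'word': x, 'lengths': tf.math.count_nonzero(x, axis=1)}, y))
    return ds.prefetch(2)

for features, y in _demo_input(tf.zeros([200, 10], tf.int32), tf.zeros([200], tf.int32)):
    pass  # features['word'] is [B, T]; features['lengths'] counts non-pad tokens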
def test_windowed_ra():
    num_heads = 4
    d_model = 64
    rpr_k = 1
    batchsize = 2
    nctx = 256
    d_k = d_model // num_heads
    with tf.device("/cpu:0"):
        old = SeqScaledDotProductRelativeAttention(pdrop=0.)
        new = SeqScaledWindowedRelativeAttention(pdrop=0.)
        rpr_key_emb = tf.keras.layers.Embedding(2 * rpr_k + 1, d_k)
        rpr_value_emb = tf.keras.layers.Embedding(2 * rpr_k + 1, d_k)
        Q = tf.random.normal([batchsize, num_heads, nctx, d_k])
        K = tf.random.normal([batchsize, num_heads, nctx, d_k])
        V = tf.random.normal([batchsize, num_heads, nctx, d_k])
        lengths = tf.random.uniform([batchsize], 0, nctx, dtype=tf.int32)
        seq_mask = tf.sequence_mask(lengths, maxlen=nctx, dtype=tf.float32)
        in_mask = tf.expand_dims(tf.expand_dims(seq_mask, 1), 1)
        out_mask = tf.expand_dims(tf.expand_dims(seq_mask, 1), -1)
        # manually create a ra_mask to prevent attention beyond rpr_k
        ones = tf.ones([nctx, nctx])
        ra_mask = tf.linalg.band_part(ones, rpr_k, rpr_k)
        mask = in_mask * tf.expand_dims(tf.expand_dims(ra_mask, 0), 0)
        rpr_key_old, rpr_value_old = make_rpr(rpr_key_emb, rpr_value_emb, rpr_k, nctx)
        SET_TRAIN_FLAG(False)
        out_old = old((Q, K, V, rpr_key_old, rpr_value_old, mask))
        out_old = masked_fill(out_old, tf.equal(out_mask, 0), 1)
        print(out_old.shape)
        # using the windowed relative attention with the original sequence mask
        rpr_key_new, rpr_value_new = unfold_rpr(rpr_key_emb, rpr_value_emb, rpr_k)
        out_new = new((Q, K, V, rpr_key_new, rpr_value_new, in_mask))
        out_new = masked_fill(out_new, tf.equal(out_mask, 0), 1)
        print(out_new.shape)
    if get_version(tf) < 2:
        with tf.compat.v1.Session() as sess:
            out_old, out_new = sess.run([out_old, out_new])
    else:
        out_old, out_new = out_old.numpy(), out_new.numpy()
    assert np.allclose(out_old, out_new, atol=1e-6)
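
# The ra_mask above relies on tf.linalg.band_part keeping only a diagonal
# band of a matrix. A tiny self-contained illustration with rpr_k=1 on 4x4:
import tensorflow as tf

band = tf.linalg.band_part(tf.ones([4, 4]), 1, 1)
# band.numpy() ->
# [[1., 1., 0., 0.],
#  [1., 1., 1., 0.],
#  [0., 1., 1., 1.],
#  [0., 0., 1., 1.]]
# i.e., position i may only attend to positions within rpr_k of i.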
def test(self, vs, reporting_fns, phase='Valid', **kwargs):
    """Run an epoch of testing over the dataset

    If we are using a `tf.dataset`-based `fit_func`, we will just cycle the
    number of steps and let the `dataset` yield new batches.

    If we are using `feed_dict`s, we convert each batch from the `DataFeed`
    and pass that into TF as the `feed_dict`

    :param vs: A validation set
    :param reporting_fns: Reporting hooks
    :param phase: The phase of evaluation (`Test`, `Valid`)
    :return: Metrics
    """
    SET_TRAIN_FLAG(False)
    if phase == 'Test':
        return self._evaluate(vs, reporting_fns, **kwargs)

    self.valid_epochs += 1
    total_loss = 0
    total_toks = 0
    preds = []
    golds = []
    start = time.perf_counter()

    for features, tgt in vs:
        features['dst'] = tgt[:, :-1]
        top_preds = self.model.predict(features, beam=1, make_input=False)[0]
        loss_value = self.loss(self.model, features, tgt).numpy()
        toks = tf.cast(self._num_toks(features['tgt_len']), tf.float32).numpy()
        total_loss += loss_value * toks
        total_toks += toks
        preds.extend(convert_seq2seq_preds(top_preds[:, 0, :], self.tgt_rlut))
        golds.extend(convert_seq2seq_golds(tgt, features['tgt_len'], self.tgt_rlut))

    metrics = self.calc_metrics(total_loss, total_toks)
    metrics['bleu'] = bleu(preds, golds, self.bleu_n_grams)[0]
    self.report(self.valid_epochs, metrics, start, phase, 'EPOCH', reporting_fns)
    return metrics
def _test(self, loader, steps=0, **kwargs):
    """Test an epoch of data using either the input loader or `tf.dataset`

    In non-`tf.dataset` mode, we cycle the loader data feed, pull a batch and
    feed it via the feed dict. When we use `tf.dataset`s under the hood, this
    function simply uses the loader to know how many steps to evaluate.

    :param loader: A data feed
    :param steps: (`int`) The number of steps in an epoch
    :param kwargs: See below

    :Keyword Arguments:
        * *dataset* (`bool`) Set to `True` if using `tf.dataset`s, defaults to `True`
        * *reporting_fns* (`list`) A list of reporting hooks to use
        * *verbose* (`dict`) A dictionary containing `console` boolean and `file` name if on

    :return: Metrics
    """
    metrics = [LAS(), UAS(), LCM(), UCM()]
    pg = create_progress_bar(steps)
    SET_TRAIN_FLAG(False)
    for features, y in pg(loader):
        heads_gold, labels_gold = y
        greedy_heads_pred, greedy_labels_pred = self.model.decode(features)
        B, T = get_shape_as_list(greedy_labels_pred)[:2]
        labels_gold_trimmed = labels_gold[:, :T].numpy()
        heads_gold_trimmed = heads_gold[:, :T].numpy()

        for i in range(B):
            # Optionally mask out punctuation labels before scoring
            if self.punct_eval is False:
                labels_gold_trimmed[i] = masked_fill(
                    labels_gold_trimmed[i],
                    labels_gold_trimmed[i] == self.model.punct,
                    Offsets.PAD,
                )
            for m in metrics:
                m.add(greedy_heads_pred[i], heads_gold_trimmed[i],
                      greedy_labels_pred[i], labels_gold_trimmed[i])

    metrics = {m.name: m.score for m in metrics}
    return metrics
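
# LAS/UAS above are standard attachment scores: UAS is the fraction of tokens
# whose predicted head is correct; LAS additionally requires the correct arc
# label. An illustrative sketch of those standard definitions (not the
# project's metric classes, and ignoring padding):
import numpy as np

def _uas_las_sketch(pred_heads, gold_heads, pred_labels, gold_labels):
    head_ok = np.asarray(pred_heads) == np.asarray(gold_heads)
    label_ok = np.asarray(pred_labels) == np.asarray(gold_labels)
    return head_ok.mean(), (head_ok & label_ok).mean()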
def test(self, vs, reporting_fns, phase):
    """Run an epoch of testing over the dataset

    If we are using a `tf.dataset`-based `fit_func`, we will just cycle the
    number of steps and let the `dataset` yield new batches.

    If we are using `feed_dict`s, we convert each batch from the `DataFeed`
    and pass that into TF as the `feed_dict`

    :param vs: A validation set
    :param reporting_fns: Reporting hooks
    :param phase: The phase of evaluation (`Test`, `Valid`)
    :return: Metrics
    """
    total_loss = 0.0
    total_toks = 0
    epochs = 0
    if phase == 'Valid':
        self.valid_epochs += 1
        epochs = self.valid_epochs

    SET_TRAIN_FLAG(False)
    start = time.perf_counter()
    h = None
    for features, y in vs:
        if self.model.requires_state:
            loss_value, h = loss_with_state(self.model, h, features, y)
        else:
            loss_value = loss_without_state(self.model, features, y)
        loss_value = loss_value.numpy()
        toks = self._num_toks(y)
        total_loss += loss_value * tf.cast(toks, tf.float32).numpy()
        total_toks += toks.numpy()

    metrics = self.calc_metrics(total_loss, total_toks)
    self.report(epochs, metrics, start, phase, 'EPOCH', reporting_fns)
    return metrics
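
# These trainers accumulate a token-weighted loss and divide by the token
# count, so `calc_metrics` plausibly reports the average loss and its exponent
# as perplexity. A minimal sketch under that assumption (not the actual
# method):
import math

def _calc_metrics_sketch(total_loss, total_toks):
    avg_loss = total_loss / float(total_toks)
    return {'avg_loss': avg_loss, 'perplexity': math.exp(avg_loss)}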
def train(self, ts, reporting_fns, steps=0, dataset=True):
    """Train by looping over the steps

    For a `tf.dataset`-backed `fit_func`, we are using the previously wired
    `dataset`s in the model (and `dataset` is `True`). For `feed_dict`, we
    convert the ts samples to `feed_dict`s and hand them in one-by-one

    :param ts: The training set
    :param reporting_fns: A list of reporting hooks
    :param steps: (`int`) The number of steps in an epoch
    :param dataset: (`bool`) Are we using `tf.dataset`s
    :return: Metrics
    """
    strategy = self.strategy

    def _replicated_train_step_no_state(inputs):
        features, y = inputs
        per_replica_loss = self.optimizer.update(self.model, features, y)
        per_replica_toks = self._num_toks(y)
        per_replica_report_loss = per_replica_loss * tf.cast(per_replica_toks, tf.float32)
        return per_replica_report_loss, per_replica_toks

    def _replicated_train_step_with_state(inputs, hidden):
        features, y = inputs
        per_replica_loss, new_hidden = self.optimizer.update_with_hidden(self.model, hidden, features, y)
        per_replica_toks = self._num_toks(y)
        per_replica_report_loss = per_replica_loss * tf.cast(per_replica_toks, tf.float32)
        return new_hidden, per_replica_report_loss, per_replica_toks

    with strategy.scope():
        train_iter = iter(ts)
        SET_TRAIN_FLAG(True)
        epoch_loss = tf.Variable(0.0)
        epoch_div = tf.Variable(0, dtype=tf.int32)
        nstep_loss = tf.Variable(0.0)
        nstep_div = tf.Variable(0, dtype=tf.int32)
        self.nstep_start = time.time()
        start = time.time()

        @tf.function
        def _distributed_train_no_state(inputs):
            per_replica_loss, per_replica_toks = strategy.experimental_run_v2(
                _replicated_train_step_no_state, args=(inputs,)
            )
            return (
                strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None),
                strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_toks, axis=None),
            )

        @tf.function
        def _distributed_train_with_state(inputs, hidden):
            h, per_replica_loss, per_replica_toks = strategy.experimental_run_v2(
                _replicated_train_step_with_state, args=(inputs, hidden)
            )
            step_loss = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)
            step_toks = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_toks, axis=None)
            return h, step_loss, step_toks

        h = None
        for i in range(steps):
            inputs = next(train_iter)
            if self.model.requires_state:
                h, step_loss, step_toks = _distributed_train_with_state(inputs, h)
            else:
                step_loss, step_toks = _distributed_train_no_state(inputs)
            epoch_loss.assign_add(step_loss)
            nstep_loss.assign_add(step_loss)
            epoch_div.assign_add(step_toks)
            nstep_div.assign_add(step_toks)
            step = self.optimizer.global_step.numpy() + 1
            if step % self.nsteps == 0:
                metrics = self.calc_metrics(nstep_loss.numpy(), nstep_div.numpy())
                self.report(step, metrics, self.nstep_start, 'Train', 'STEP', reporting_fns, self.nsteps)
                nstep_loss.assign(0.0)
                nstep_div.assign(0)
                self.nstep_start = time.time()

        epoch_loss = epoch_loss.numpy()
        epoch_div = epoch_div.numpy()
        metrics = self.calc_metrics(epoch_loss, epoch_div)
        self.train_epochs += 1
        self.report(self.train_epochs, metrics, start, 'Train', 'EPOCH', reporting_fns)
        return metrics
def fit_eager_distributed(model_params, ts, vs, es=None, **kwargs):
    """
    Train a language model using TensorFlow with `tf.dataset`. This is the
    default behavior for training.

    :param model_params: The model (or parameters to create the model) to train
    :param ts: A training data set
    :param vs: A validation data set
    :param es: A test data set, can be None
    :param kwargs: See below

    :Keyword Arguments:
        * *do_early_stopping* (``bool``) -- Stop after evaluation data is no longer improving. Defaults to True
        * *verbose* (`dict`) A dictionary containing `console` boolean and `file` name if on
        * *epochs* (``int``) -- how many epochs. Defaults to 5
        * *outfile* -- Model output file
        * *patience* -- How many epochs where evaluation is no longer improving before we give up
        * *reporting* -- Callbacks which may be used on reporting updates
        * *nsteps* (`int`) -- If we should report every n-steps, this should be passed
        * *ema_decay* (`float`) -- If we are doing an exponential moving average, what decay to use
        * *clip* (`int`) -- If we are doing gradient clipping, what value to use
        * *optim* (`str`) -- The name of the optimizer we are using
        * *lr* (`float`) -- The learning rate we are using
        * *mom* (`float`) -- If we are using SGD, what value to use for momentum
        * *beta1* (`float`) -- Adam-specific hyper-param, defaults to `0.9`
        * *beta2* (`float`) -- Adam-specific hyper-param, defaults to `0.999`
        * *epsilon* (`float`) -- Adam-specific hyper-param, defaults to `1e-8`

    :return: None
    """
    epochs = int(kwargs.get('epochs', 5))
    patience = int(kwargs.get('patience', epochs))
    model_file = get_model_file('lm', 'tf', kwargs.get('basedir'))
    do_early_stopping = bool(kwargs.get('do_early_stopping', True))

    best_metric = 0
    if do_early_stopping:
        early_stopping_metric = kwargs.get('early_stopping_metric', 'avg_loss')
        early_stopping_cmp, best_metric = get_metric_cmp(early_stopping_metric, kwargs.get('early_stopping_cmp'))
        patience = kwargs.get('patience', epochs)
        print('Doing early stopping on [%s] with patience [%d]' % (early_stopping_metric, patience))

    reporting_fns = listify(kwargs.get('reporting', []))
    print('reporting', reporting_fns)

    batchsz = kwargs['batchsz']
    test_batchsz = kwargs.get('test_batchsz', batchsz)
    tgt_key = model_params.get('tgt_key')

    train_dataset = tf.data.Dataset.from_tensor_slices(to_tensors(ts))
    train_dataset = train_dataset.shuffle(buffer_size=SHUF_BUF_SZ)
    train_dataset = train_dataset.batch(batchsz, drop_remainder=True)
    train_dataset = train_dataset.prefetch(NUM_PREFETCH)

    valid_dataset = tf.data.Dataset.from_tensor_slices(to_tensors(vs))
    valid_dataset = valid_dataset.batch(batchsz, drop_remainder=True)
    valid_dataset = valid_dataset.prefetch(NUM_PREFETCH)

    trainer = LanguageModelTrainerDistributedTf(model_params, **kwargs)
    train_dataset = trainer.distribute(train_dataset)
    valid_dataset = trainer.distribute(valid_dataset)

    last_improved = 0
    SET_TRAIN_FLAG(True)

    for epoch in range(epochs):
        trainer.train(train_dataset, reporting_fns, steps=len(ts))
        test_metrics = trainer.test(valid_dataset, reporting_fns, phase='Valid', steps=len(vs))

        if do_early_stopping is False:
            trainer.checkpoint()
            trainer.model.save(model_file)
        elif early_stopping_cmp(test_metrics[early_stopping_metric], best_metric):
            last_improved = epoch
            best_metric = test_metrics[early_stopping_metric]
            print('New best %.3f' % best_metric)
            trainer.checkpoint()
            trainer.model.save(model_file)
        elif (epoch - last_improved) > patience:
            print('Stopping due to persistent failures to improve')
            break

    if do_early_stopping:
        print('Best performance on %s: %.3f at epoch %d' % (early_stopping_metric, best_metric, last_improved))

    if es is not None:
        print('Reloading best checkpoint')
        trainer.recover_last_checkpoint()
        trainer.strategy = tf.distribute.OneDeviceStrategy('/device:GPU:0')
        test_dataset = tf.data.Dataset.from_tensor_slices(to_tensors(es))
        test_dataset = test_dataset.batch(test_batchsz, drop_remainder=False)
        test_dataset = test_dataset.prefetch(NUM_PREFETCH)
        test_dataset = trainer.distribute(test_dataset)
        trainer.test(test_dataset, reporting_fns, phase='Test', steps=len(es))
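
# The early-stopping loop above only needs two things from `get_metric_cmp`:
# a comparator that answers "is the new value better?" and an initial best
# value. A minimal sketch of those assumed semantics (loss-like metrics
# improve downward, everything else upward); the real helper may differ:
def _get_metric_cmp_sketch(metric, user_cmp=None):
    if metric in ('avg_loss', 'loss', 'perplexity'):
        return (lambda new, best: new < best), float('inf')
    return (lambda new, best: new > best), 0.0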
def test(self, vs, reporting_fns, phase, steps=0):
    """Run an epoch of testing over the dataset

    If we are using a `tf.dataset`-based `fit_func`, we will just cycle the
    number of steps and let the `dataset` yield new batches.

    If we are using `feed_dict`s, we convert each batch from the `DataFeed`
    and pass that into TF as the `feed_dict`

    :param vs: A validation set
    :param reporting_fns: Reporting hooks
    :param phase: The phase of evaluation (`Test`, `Valid`)
    :param steps: (`int`) The number of steps in an epoch
    :return: Metrics
    """
    strategy = self.strategy

    def _replicated_test_step_no_state(inputs):
        features, y = inputs
        per_replica_loss = loss_without_state(self.model, features, y)
        per_replica_toks = self._num_toks(y)
        per_replica_report_loss = per_replica_loss * tf.cast(per_replica_toks, tf.float32)
        return per_replica_report_loss, per_replica_toks

    def _replicated_test_step_with_state(inputs, hidden):
        features, y = inputs
        per_replica_loss, new_hidden = loss_with_state(self.model, hidden, features, y)
        per_replica_toks = self._num_toks(y)
        per_replica_report_loss = per_replica_loss * tf.cast(per_replica_toks, tf.float32)
        return new_hidden, per_replica_report_loss, per_replica_toks

    with strategy.scope():
        SET_TRAIN_FLAG(False)
        test_iter = iter(vs)
        epoch_loss = tf.Variable(0.0)
        epoch_div = tf.Variable(0, dtype=tf.int32)
        self.nstep_start = time.time()
        start = time.time()

        @tf.function
        def _distributed_test_no_state(inputs):
            per_replica_loss, per_replica_toks = strategy.experimental_run_v2(
                _replicated_test_step_no_state, args=(inputs,)
            )
            return (
                strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None),
                strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_toks, axis=None),
            )

        @tf.function
        def _distributed_test_with_state(inputs, hidden):
            h, per_replica_loss, per_replica_toks = strategy.experimental_run_v2(
                _replicated_test_step_with_state, args=(inputs, hidden)
            )
            step_loss = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)
            step_toks = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_toks, axis=None)
            return h, step_loss, step_toks

        epochs = 0
        if phase == 'Valid':
            self.valid_epochs += 1
            epochs = self.valid_epochs

        h = None
        for i in range(steps):
            inputs = next(test_iter)
            if self.model.requires_state:
                h, per_replica_loss, per_replica_toks = _distributed_test_with_state(inputs, h)
            else:
                per_replica_loss, per_replica_toks = _distributed_test_no_state(inputs)
            epoch_loss.assign_add(per_replica_loss)
            epoch_div.assign_add(per_replica_toks)

        metrics = self.calc_metrics(epoch_loss.numpy(), epoch_div.numpy())
        self.report(epochs, metrics, start, phase, 'EPOCH', reporting_fns)
        return metrics
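
# Why each replica returns loss * toks rather than the raw loss: summing
# token-weighted losses and dividing by the summed token count gives the true
# token-level average across replicas with uneven batch sizes. A tiny check:
losses, toks = [2.0, 1.0], [10, 30]                  # per-replica values
weighted = sum(l * t for l, t in zip(losses, toks))  # 2*10 + 1*30 = 50
assert weighted / sum(toks) == 1.25                  # not (2 + 1) / 2 = 1.5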
    # (tail of the LM loss function; its `def` header precedes this fragment)
    bt_x_v = tf.nn.log_softmax(tf.reshape(logits, [-1, vsz]), axis=-1)
    one_hots = tf.one_hot(targets, vsz)
    example_loss = -tf.reduce_sum(one_hots * bt_x_v, axis=-1)
    loss = tf.reduce_mean(example_loss)
    return loss, h


optimizer = EagerOptimizer(loss, optim="adam", lr=args.lr)

for epoch in range(args.epochs):
    loss_accum = 0.
    step = 0
    start = time.time()
    h = None
    SET_TRAIN_FLAG(True)
    for x, y in train_input_fn():
        # Optimize the model
        loss_value, h = optimizer.update_with_hidden(model, h, x, y)
        loss_accum += loss_value
        step += 1
    print('training time {}'.format(time.time() - start))
    mean_loss = loss_accum / step
    print('Training Loss {}, Perplexity {}'.format(mean_loss, np.exp(mean_loss)))
    step = 0
    loss_accum = 0
    SET_TRAIN_FLAG(False)
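
# The log-softmax/one-hot reduction above is the textbook negative
# log-likelihood; with integer targets it matches TensorFlow's built-in sparse
# cross entropy. A self-contained, behavior-preserving check:
import tensorflow as tf

logits = tf.random.normal([6, 11])  # [B*T, vsz] stand-in
targets = tf.random.uniform([6], 0, 11, dtype=tf.int32)
manual = -tf.reduce_sum(tf.one_hot(targets, 11) * tf.nn.log_softmax(logits, axis=-1), axis=-1)
builtin = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=targets, logits=logits)
tf.debugging.assert_near(manual, builtin)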
def main():
    parser = ArgumentParser()
    parser.add_argument("--basedir", type=str)
    parser.add_argument("--train_dir", type=str, required=True, help='Training directory')
    parser.add_argument("--valid_dir", type=str, required=True, help='Validation directory')
    parser.add_argument("--train_md", type=str, help="Training metadata YAML, defaults to `{train_dir}/md.yml`")
    parser.add_argument("--valid_md", type=str, help="Validation metadata YAML, defaults to `{valid_dir}/md.yml`")
    parser.add_argument("--dataset_key", default="tlm", help="dataset key for basedir")
    parser.add_argument("--embed_type", type=str, default='default',
                        choices=["default", "positional", "learned-positional"],
                        help="register label of the embeddings")
    parser.add_argument("--d_model", type=int, default=512, help="Model dimension (and embedding dsz)")
    parser.add_argument("--d_ff", type=int, default=2048, help="FFN dimension")
    parser.add_argument("--d_k", type=int, default=None, help="Dimension per head. Use if num_heads=1 to reduce dims")
    parser.add_argument("--num_heads", type=int, default=8, help="Number of heads")
    parser.add_argument("--num_layers", type=int, default=8, help="Number of layers")
    parser.add_argument("--num_train_workers", type=int, default=4, help="Number of train workers")
    parser.add_argument("--distribute", type=str, default="mirror", choices=["mirror", "tpu", "nccl"])
    parser.add_argument("--tpu_ep", type=str, help="The TPU endpoint if using `distribute=tpu`")
    parser.add_argument("--nctx", type=int, default=256, help="Max input length")
    parser.add_argument("--file_type", default='tfrecord', choices=['json', 'tfrecord'], help="Data file type")
    parser.add_argument("--batch_size", type=int, default=256, help="Batch Size")
    parser.add_argument("--subword_model_file", type=str, help="The BPE model file", required=False)
    parser.add_argument("--subword_vocab_file", type=str, help="The BPE subword vocab", required=True)
    parser.add_argument("--subword_type", type=str, choices=["bpe", "wordpiece"], default="bpe")
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout")
    parser.add_argument("--ffn_pdrop", type=float, default=0.0, help="Dropout in the dense stack")
    parser.add_argument("--layer_drop", type=float, default=0.0, help="LayerDrop to apply")
    parser.add_argument("--optim", default="adamw", type=str, help="Optimizer to use (defaults to adamw)")
    parser.add_argument("--lr", type=float, default=4.0e-4, help="Learning rate")
    parser.add_argument("--clip", type=float, default=1.0, help="Clipping gradient norm")
    parser.add_argument("--weight_decay", type=float, default=1.0e-2, help="Weight decay")
    parser.add_argument("--epochs", type=int, default=32, help="Num training epochs")
    parser.add_argument("--restart", type=str2bool, help="Option allows you to restart from a previous checkpoint")
    parser.add_argument("--warmup_steps", type=int, default=10000, help="Num warmup steps")
    parser.add_argument("--causal", type=str2bool, default=False, help="Use CLM (causal) instead of MLM")
    parser.add_argument("--mlp", type=str2bool, default=False, help="Use Gated MLP")
    parser.add_argument("--saves_per_epoch", type=int, default=10, help="The number of checkpoints to save per epoch")
    parser.add_argument("--rpr_k", type=int, default=[8], nargs='+',
                        help="Relative attention positional sizes, pass 0 if you don't want relative attention")
    parser.add_argument("--rpr_value_on", type=str2bool, default=True,
                        help="In relative attention, whether to add the positional correction to values in addition "
                             "to the correction to the attention matrix")
    parser.add_argument("--windowed_ra", type=str2bool, default=False,
                        help="Whether to prevent attention beyond rpr_k")
    parser.add_argument("--strategy", help="Training strategy, defaults to `mirror`", choices=["mirror"])
    parser.add_argument("--npz", help="Should we write out NPZ files?", type=str2bool, default=False)
    parser.add_argument("--tb", help="Turn on tensorboard?", type=str2bool, default=False)
    parser.add_argument("--convert_only", help="Should we just convert this file to NPZ and exit?",
                        type=str2bool, default=False)
    parser.add_argument("--extra_tokens", help="What extra tokens should we use", nargs="+",
                        default=["[CLS]", "[MASK]"])
    args = parser.parse_args()
    SET_TRAIN_FLAG(True)

    if args.convert_only:
        args.restart = True

    if args.basedir is None:
        args.basedir = f'lm-{args.dataset_key}-bpe-{os.getpid()}'
    logging.basicConfig(level=logging.INFO)
    logger.info(f"Writing results to {args.basedir}")

    if args.tb:
        logdir = f"{args.basedir}/scalars/{os.getpid()}"
        file_writer = tf.summary.create_file_writer(logdir + "/metrics")
        file_writer.set_as_default()
        logger.info(f"Set up tensorboard logdir {logdir}")

    strategy = create_distribute_strategy(args.distribute, args.tpu_ep)
    num_replicas = strategy.num_replicas_in_sync
    logger.info(f"Using {num_replicas} replicas in this job.")
    Vec1D = BPEVectorizer1D if args.subword_type == 'bpe' else WordpieceVectorizer1D
    vectorizer = Vec1D(model_file=args.subword_model_file, vocab_file=args.subword_vocab_file,
                       mxlen=args.nctx, extra_tokens=args.extra_tokens)
    vocab = {'x': vectorizer.vocab}
    preproc_data = baseline.embeddings.load_embeddings('x', dsz=args.d_model, known_vocab=vocab['x'],
                                                       preserve_vocab_indices=True, embed_type=args.embed_type)
    vocabs = preproc_data['vocab']

    train_md = args.train_md if args.train_md else os.path.join(args.train_dir, 'md.yml')
    num_train_samples = get_num_samples(train_md)
    valid_md = args.valid_md if args.valid_md else os.path.join(args.valid_dir, 'md.yml')
    num_valid_samples = get_num_samples(valid_md)
    is_curriculum = isinstance(num_train_samples, Mapping)

    def dataset_train_fn(input_context):
        global_batchsz = args.batch_size
        base_batchsz = input_context.get_per_replica_batch_size(global_batchsz)
        ds = None
        if is_curriculum:
            for sub in num_train_samples.keys():
                train_curr_dir = os.path.join(args.train_dir, str(sub))
                # Shorter sequences allow proportionally bigger batches
                batchsz_scale_factor = args.nctx // sub
                this_batchsz = base_batchsz * batchsz_scale_factor
                curr_ds = get_dataset(train_curr_dir, args.file_type, args.num_train_workers,
                                      causal=args.causal).batch(this_batchsz, drop_remainder=True)
                ds = curr_ds if ds is None else ds.concatenate(curr_ds)
        else:
            ds = get_dataset(args.train_dir, args.file_type, args.num_train_workers,
                             causal=args.causal).batch(base_batchsz)
        return ds.shard(input_context.num_input_pipelines, input_context.input_pipeline_id)

    train_loader = strategy.experimental_distribute_datasets_from_function(dataset_train_fn)

    def dataset_test_fn(input_context):
        global_batchsz = args.batch_size
        base_batchsz = input_context.get_per_replica_batch_size(global_batchsz)
        ds = None
        if is_curriculum:
            for sub in num_valid_samples.keys():
                valid_curr_dir = os.path.join(args.valid_dir, str(sub))
                batchsz_scale_factor = args.nctx // sub
                this_batchsz = base_batchsz * batchsz_scale_factor
                curr_ds = get_dataset(valid_curr_dir, args.file_type, args.num_train_workers,
                                      causal=args.causal).batch(this_batchsz, drop_remainder=True)
                ds = curr_ds if ds is None else ds.concatenate(curr_ds)
        else:
            ds = get_dataset(args.valid_dir, args.file_type, args.num_train_workers,
                             shuffle=False, causal=args.causal).batch(base_batchsz)
        return ds.shard(input_context.num_input_pipelines, input_context.input_pipeline_id)

    valid_loader = strategy.experimental_distribute_datasets_from_function(dataset_test_fn)

    os.makedirs(args.basedir, exist_ok=True)
    # We want to make sure to save our input vocab into the basedir for reuse later
    write_json(vocabs, os.path.join(args.basedir, 'vocabs.json'))
    embeddings = {'x': preproc_data['embeddings']}
    logger.info("Loaded embeddings")
    logger.info("Loaded datasets")
    logger.info("Using embedding type [%s]", args.embed_type)

    model = create_model(args, embeddings)
    if isinstance(model, GatedMLPLanguageModel) and is_curriculum:
        raise Exception("Variable tensor lengths not currently supported for gMLP")
    logger.info("Loaded model and loss")

    if is_curriculum:
        steps_per_epoch = 0
        steps_per_valid_epoch = 0
        for k, v in num_train_samples.items():
            steps_per_epoch += int(num_train_samples[k] // (args.batch_size * (args.nctx / k)))
        for k, v in num_valid_samples.items():
            steps_per_valid_epoch += int(num_valid_samples[k] // (args.batch_size * (args.nctx / k)))
    else:
        steps_per_epoch = num_train_samples // args.batch_size
        steps_per_valid_epoch = num_valid_samples // args.batch_size

    update_on = steps_per_epoch // args.saves_per_epoch
    report_on = max(10, update_on) // 10
    logger.info(f"Steps per epoch: {steps_per_epoch}. Saving checkpoint every {update_on} steps.")
    lr_decay = CosineDecaySchedulerTensorFlow(steps_per_epoch * args.epochs, lr=args.lr)
    linear_warmup = WarmupLinearSchedulerTensorFlow(args.warmup_steps, lr=args.lr)
    lr_sched = CompositeLRSchedulerTensorFlow(linear_warmup, lr_decay)
    optimizer = EagerOptimizer(loss_function, optim=args.optim, lr_function=lr_sched,
                               weight_decay=args.weight_decay, clip=args.clip, lr=args.lr)
    checkpoint = tf.train.Checkpoint(optimizer=optimizer.optimizer, model=model)
    checkpoint_manager = tf.train.CheckpointManager(checkpoint, directory=args.basedir, max_to_keep=5)

    start_epoch = 0
    if args.restart:
        # The global step gets automatically updated here
        # so we dont have to worry about our LR regimen
        checkpoint.restore(checkpoint_manager.latest_checkpoint)
        current_step = optimizer.global_step
        start_epoch = current_step // steps_per_epoch

    def _replicated_train_step(inputs):
        """This runs on a single replica"""
        x, y = inputs
        per_replica_loss = optimizer.update(model, {'x': x}, y, num_replicas)
        return per_replica_loss

    @tf.function
    def _distributed_train_step(inputs: Tuple[tf.Tensor, tf.Tensor]):
        """Runs across multiple replicas and aggregates the results.

        :param inputs:
        :return:
        """
        per_replica_loss = strategy.run(_replicated_train_step, args=(inputs,))
        return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)

    def _replicated_test_step(inputs):
        """This runs on a single replica"""
        x, y = inputs
        per_replica_loss = loss_function(model, {'x': x}, y) / num_replicas
        return per_replica_loss

    @tf.function
    def _distributed_test_step(inputs: Tuple[tf.Tensor, tf.Tensor]):
        """Runs across multiple replicas and aggregates the results.

        :param inputs:
        :return:
        """
        per_replica_loss = strategy.run(_replicated_test_step, args=(inputs,))
        return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)

    timer = Timer()
    with strategy.scope():
        for epoch in range(start_epoch, args.epochs):
            timer.start()
            SET_TRAIN_FLAG(True)
            logger.info('Starting epoch %d', epoch + 1)
            avg_loss = Average('average_train_loss')
            metrics = {}
            train_iter = iter(train_loader)
            for i in range(steps_per_epoch):
                try:
                    loss = _distributed_train_step(next(train_iter))
                    avg_loss.update(loss.numpy().item())
                    tf.summary.scalar("train_loss", data=loss, step=optimizer.global_step)
                except Exception as e:
                    logger.error(f"Exception at training step {i+1}/{steps_per_epoch}. Skipping")

                if args.convert_only:
                    logger.warning("Convert only flag specified. Stopping after one step")
                    steps = optimizer.global_step.numpy()
                    npz_checkpoint = os.path.join(args.basedir, f'checkpoint-step-{steps}.npz')
                    save_tlm_npz(model, npz_checkpoint)
                    return

                steps = optimizer.global_step.numpy()
                if (steps + 1) % report_on == 0:
                    logger.info(avg_loss)
                if (steps + 1) % update_on == 0:
                    elapsed = timer.elapsed(True)
                    logger.info('elapsed time this epoch %d min', elapsed)
                    logger.info('throughput %f steps/min', i / elapsed)
                    checkpoint_manager.save()
                    if args.npz:
                        npz_checkpoint = os.path.join(args.basedir, f'checkpoint-step-{steps}.npz')
                        save_tlm_npz(model, npz_checkpoint)

            # This is the average training token-level loss across all machines
            train_token_loss = avg_loss.avg
            # This is the token-level training perplexity
            train_token_ppl = math.exp(train_token_loss)
            # How much time elapsed in minutes
            metrics['train_elapsed_min'] = timer.elapsed(True)
            metrics['average_train_loss'] = train_token_loss
            metrics['train_ppl'] = train_token_ppl
            metrics['lr'] = float(lr_sched(tf.cast(optimizer.global_step, tf.float32)).numpy().item())

            avg_valid_loss = Average('average_valid_loss')
            timer.start()
            SET_TRAIN_FLAG(False)
            valid_iter = iter(valid_loader)
            for i in range(steps_per_valid_epoch):
                try:
                    valid_loss = _distributed_test_step(next(valid_iter))
                    tf.summary.scalar('valid_loss', data=valid_loss, step=optimizer.global_step)
                    avg_valid_loss.update(valid_loss.numpy().item())
                except Exception as e:
                    logger.error(f"Exception at validation step {i+1}/{steps_per_valid_epoch}. Skipping")

            valid_token_loss = avg_valid_loss.avg
            valid_token_ppl = math.exp(valid_token_loss)
            metrics['valid_elapsed_min'] = timer.elapsed(True)
            metrics['average_valid_loss'] = valid_token_loss
            metrics['average_valid_word_ppl'] = valid_token_ppl
            logger.info(json.dumps(metrics, indent=4))
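
# The composite scheduler above pairs linear warmup with cosine decay. A
# self-contained sketch of that assumed shape (the library classes may differ
# in detail, e.g. how the two pieces are stitched at the boundary):
import math

def _lr_sketch(step, lr=4.0e-4, warmup_steps=10000, total_steps=100000):
    if step < warmup_steps:
        return lr * step / warmup_steps                 # linear warmup
    frac = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return lr * 0.5 * (1.0 + math.cos(math.pi * frac))  # cosine decay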
def train(self, ts, reporting_fns, steps=0):
    """Train by looping over the steps

    For a `tf.dataset`-backed `fit_func`, we are using the previously wired
    `dataset`s in the model. For `feed_dict`, we convert the ts samples to
    `feed_dict`s and hand them in one-by-one

    :param ts: The training set
    :param reporting_fns: A list of reporting hooks
    :param steps: (`int`) The number of steps in an epoch
    :return: Metrics
    """
    strategy = self.strategy
    #num_replicas = strategy.num_replicas_in_sync

    def _replicated_train_step(inputs):
        features, y = inputs
        per_replica_loss = self.optimizer.update(self.model, features, y)
        per_replica_toks = self._num_toks(features['tgt_len'])
        per_replica_report_loss = per_replica_loss * tf.cast(per_replica_toks, tf.float32)
        return per_replica_report_loss, per_replica_toks

    with strategy.scope():
        SET_TRAIN_FLAG(True)
        epoch_loss = tf.Variable(0.0)
        epoch_div = tf.Variable(0, dtype=tf.int32)
        nstep_loss = tf.Variable(0.0)
        nstep_div = tf.Variable(0, dtype=tf.int32)
        self.nstep_start = time.time()
        start = time.time()

        @tf.function
        def _distributed_train_step(inputs):
            per_replica_loss, per_replica_toks = strategy.experimental_run_v2(
                _replicated_train_step, args=(inputs,)
            )
            total_step_loss = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)
            total_toks = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_toks, axis=None)
            return total_step_loss, total_toks

        with autograph_options({"function_optimization": False, "layout_optimizer": False}):
            train_iter = iter(ts)
            for i in range(steps):
                features, y = next(train_iter)
                step_report_loss, step_toks = _distributed_train_step((features, y))
                epoch_loss.assign_add(step_report_loss)
                nstep_loss.assign_add(step_report_loss)
                epoch_div.assign_add(step_toks)
                nstep_div.assign_add(step_toks)
                step = self.optimizer.global_step.numpy().item() + 1
                if step % self.nsteps == 0:
                    metrics = self.calc_metrics(nstep_loss.numpy().item(), nstep_div.numpy().item())
                    self.report(step, metrics, self.nstep_start, 'Train', 'STEP', reporting_fns, self.nsteps)
                    nstep_loss.assign(0.0)
                    nstep_div.assign(0)
                    self.nstep_start = time.time()

            epoch_loss = epoch_loss.numpy()
            epoch_div = epoch_div.numpy()
            metrics = self.calc_metrics(epoch_loss, epoch_div)
            self.train_epochs += 1
            self.report(self.train_epochs, metrics, start, 'Train', 'EPOCH', reporting_fns)
            return metrics
def test(self, vs, reporting_fns, steps=0, phase='Valid', **kwargs):
    """Run an epoch of testing over the dataset

    If we are using a `tf.dataset`-based `fit_func`, we will just cycle the
    number of steps and let the `dataset` yield new batches.

    If we are using `feed_dict`s, we convert each batch from the `DataFeed`
    and pass that into TF as the `feed_dict`

    :param vs: A validation set
    :param reporting_fns: Reporting hooks
    :param steps: (`int`) The number of steps in an epoch
    :param phase: The phase of evaluation (`Test`, `Valid`)
    :return: Metrics
    """

    def _replicated_valid_step(inputs):
        features, tgt = inputs
        top_preds = self.model.predict(features, beam=1, make_input=False)
        per_replica_loss = loss(self.model, features, tgt)
        per_replica_toks = self._num_toks(features['tgt_len'])
        per_replica_report_loss = per_replica_loss * tf.cast(per_replica_toks, tf.float32)
        return per_replica_report_loss, per_replica_toks, top_preds

    if phase == 'Test':
        SET_TRAIN_FLAG(False)
        return self._evaluate(vs, reporting_fns, **kwargs)

    strategy = self.strategy
    num_replicas = strategy.num_replicas_in_sync
    with strategy.scope():
        SET_TRAIN_FLAG(False)
        self.valid_epochs += 1
        total_loss = tf.Variable(0.0)
        total_toks = tf.Variable(0, dtype=tf.int32)
        preds = []
        golds = []
        start = time.time()
        test_iter = iter(vs)
        for i in range(steps):
            features, tgt = next(test_iter)
            inputs = (features, tgt)
            per_replica_loss, per_replica_toks, _ = strategy.experimental_run_v2(
                _replicated_valid_step, args=(inputs,)
            )
            total_loss.assign_add(strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None))
            total_toks.assign_add(strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_toks, axis=None))
            # Not sure a good way to get top preds merged yet

        metrics = self.calc_metrics(total_loss.numpy(), total_toks.numpy())
        self.report(self.valid_epochs, metrics, start, phase, 'EPOCH', reporting_fns)
        return metrics
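
# Re: the "merge top preds" comment above -- one plausible approach (an
# untested sketch, not the project's solution) is to unwrap the per-replica
# value into its local components before converting, e.g.:
# for local_preds in strategy.experimental_local_results(top_preds):
#     preds.extend(convert_seq2seq_preds(local_preds.numpy()[:, 0, :], self.tgt_rlut))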
def train():
    parser = ArgumentParser()
    parser.add_argument("--basedir", type=str)
    parser.add_argument("--train_dir", type=str, required=True, help='Training directory')
    parser.add_argument("--valid_dir", type=str, required=True, help='Validation directory')
    parser.add_argument("--train_md", type=str, help="Training metadata YAML, defaults to `{train_dir}/md.yml`")
    parser.add_argument("--valid_md", type=str, help="Validation metadata YAML, defaults to `{valid_dir}/md.yml`")
    parser.add_argument("--dataset_key", default="tlm", help="dataset key for basedir")
    parser.add_argument("--embed_type", type=str, default='default',
                        choices=["default", "positional", "learned-positional"],
                        help="register label of the embeddings")
    parser.add_argument("--gen_d_model", type=int, default=256, help="Model dimension (and embedding dsz)")
    parser.add_argument("--gen_d_ff", type=int, default=1024, help="FFN dimension")
    parser.add_argument("--gen_d_k", type=int, default=None,
                        help="Dimension per head. Use if num_heads=1 to reduce dims")
    parser.add_argument("--gen_num_heads", type=int, default=8, help="Number of heads")
    parser.add_argument("--gen_num_layers", type=int, default=8, help="Number of layers")
    parser.add_argument("--gen_rpr_k", type=int, default=[8], nargs='+',
                        help="Relative attention positional sizes, pass 0 if you don't want relative attention")
    parser.add_argument("--windowed_ra", type=str2bool, default=False,
                        help="Whether to prevent attention beyond rpr_k")
    parser.add_argument("--gen_loss_scale", type=float, default=50.0, help="Scaling for loss function")
    parser.add_argument("--gen_dropout", type=float, default=0.1, help="Dropout")
    parser.add_argument("--discrim_rpr_k", type=int, default=[8], nargs='+',
                        help="Relative attention positional sizes, pass 0 if you don't want relative attention")
    parser.add_argument("--discrim_d_model", type=int, default=512, help="Model dimension (and embedding dsz)")
    parser.add_argument("--discrim_d_ff", type=int, default=2048, help="FFN dimension")
    parser.add_argument("--discrim_d_k", type=int, default=None,
                        help="Dimension per head. Use if num_heads=1 to reduce dims")
    parser.add_argument("--discrim_num_heads", type=int, default=8, help="Number of heads")
    parser.add_argument("--discrim_num_layers", type=int, default=8, help="Number of layers")
    parser.add_argument("--discrim_dropout", type=float, default=0.1, help="Dropout")
    parser.add_argument("--num_train_workers", type=int, default=4, help="Number of train workers")
    parser.add_argument("--distribute", type=str, default="mirror", choices=["mirror", "tpu", "nccl"])
    parser.add_argument("--tpu_ep", type=str, help="The TPU endpoint if using `distribute=tpu`")
    parser.add_argument("--nctx", type=int, default=256, help="Max input length")
    parser.add_argument("--file_type", default='tfrecord', choices=['json', 'tfrecord'], help="Data file type")
    parser.add_argument("--batch_size", type=int, default=256, help="Batch Size")
    parser.add_argument("--subword_model_file", type=str, help="The BPE model file", required=True)
    parser.add_argument("--subword_vocab_file", type=str, help="The BPE subword vocab", required=True)
    parser.add_argument("--optim", default="adam", type=str, help="Optimizer to use (defaults to adam)")
    parser.add_argument("--lr", type=float, default=4.0e-4, help="Learning rate")
    parser.add_argument("--clip", type=float, default=1.0, help="Clipping gradient norm")
    parser.add_argument("--weight_decay", type=float, default=1.0e-2, help="Weight decay")
    parser.add_argument("--epochs", type=int, default=32, help="Num training epochs")
    parser.add_argument("--restart", type=str2bool, help="Option allows you to restart from a previous checkpoint")
    parser.add_argument("--warmup_steps", type=int, default=10000, help="Num warmup steps")
    parser.add_argument("--causal", type=str2bool, default=False, help="Use CLM (causal) instead of MLM")
    parser.add_argument("--saves_per_epoch", type=int, default=10, help="The number of checkpoints to save per epoch")
    parser.add_argument("--strategy", help="Training strategy, defaults to `mirror`", choices=["mirror"])
    parser.add_argument("--npz", help="Should we write out NPZ files?", type=str2bool, default=False)
    parser.add_argument("--tb", help="Turn on tensorboard?", type=str2bool, default=False)
    parser.add_argument("--convert_only", help="Should we just convert this file to NPZ and exit?",
                        type=str2bool, default=False)
    args = parser.parse_args()
    SET_TRAIN_FLAG(True)

    if args.convert_only:
        args.restart = True
        args.npz = True

    if args.basedir is None:
        args.basedir = f'discrim-{args.dataset_key}-bpe-{os.getpid()}'
    logging.basicConfig(level=logging.INFO)
    logger.info(f"Writing results to {args.basedir}")

    if args.tb:
        logdir = f"logs/scalars/{os.getpid()}"
        file_writer = tf.summary.create_file_writer(logdir + "/metrics")
        file_writer.set_as_default()
        logger.info(f"Set up tensorboard logdir {logdir}")

    strategy = create_distribute_strategy(args.distribute, args.tpu_ep)
    num_replicas = strategy.num_replicas_in_sync
    logger.info(f"Using {num_replicas} replicas in this job.")
    vectorizer = BPEVectorizer1D(model_file=args.subword_model_file, vocab_file=args.subword_vocab_file,
                                 mxlen=args.nctx)
    vocab = {'x': vectorizer.vocab}
    gen_preproc_data = baseline.embeddings.load_embeddings('x', dsz=args.gen_d_model, known_vocab=vocab['x'],
                                                           preserve_vocab_indices=True, embed_type=args.embed_type)
    vocabs = gen_preproc_data['vocab']
    discrim_preproc_data = baseline.embeddings.load_embeddings('x', dsz=args.discrim_d_model, known_vocab=vocab['x'],
                                                               preserve_vocab_indices=True,
                                                               embed_type=args.embed_type)

    def dataset_train_fn(input_context):
        batch_size = input_context.get_per_replica_batch_size(args.batch_size)
        ds = get_dataset(args.train_dir, args.file_type, args.num_train_workers).batch(batch_size)
        return ds.shard(input_context.num_input_pipelines, input_context.input_pipeline_id)

    train_loader = strategy.experimental_distribute_datasets_from_function(dataset_train_fn)

    def dataset_test_fn(input_context):
        batch_size = input_context.get_per_replica_batch_size(args.batch_size)
        ds = get_dataset(args.valid_dir, args.file_type, args.num_train_workers, shuffle=False).batch(batch_size)
        return ds.shard(input_context.num_input_pipelines, input_context.input_pipeline_id)

    valid_loader = strategy.experimental_distribute_datasets_from_function(dataset_test_fn)

    train_md = args.train_md if args.train_md else os.path.join(args.train_dir, 'md.yml')
    num_train_samples = get_num_samples(train_md)
    valid_md = args.valid_md if args.valid_md else os.path.join(args.valid_dir, 'md.yml')
    num_valid_samples = get_num_samples(valid_md)

    os.makedirs(args.basedir, exist_ok=True)
    # We want to make sure to save our input vocab into the basedir for reuse later
    write_json(vocabs, os.path.join(args.basedir, 'vocabs.json'))
    gen_embeddings = {'x': gen_preproc_data['embeddings']}
    discrim_embeddings = {'x': discrim_preproc_data['embeddings']}
    logger.info("Loaded embeddings")
    logger.info("Loaded datasets")
    logger.info("Using embedding type [%s]", args.embed_type)

    if len(args.gen_rpr_k) == 0 or args.gen_rpr_k[0] < 1:
        gen_rpr_k = None
    elif len(args.gen_rpr_k) == 1:
        gen_rpr_k = args.gen_rpr_k[0]
    else:
        gen_rpr_k = args.gen_rpr_k

    if len(args.discrim_rpr_k) == 0 or args.discrim_rpr_k[0] < 1:
        discrim_rpr_k = None
    elif len(args.discrim_rpr_k) == 1:
        discrim_rpr_k = args.discrim_rpr_k[0]
    else:
        discrim_rpr_k = args.discrim_rpr_k

    gen_model = TransformerMaskedLanguageModel.create(
        gen_embeddings,
        hsz=args.gen_d_model,
        d_ff=args.gen_d_ff,
        tie_weights=True,
        dropout=args.gen_dropout,
        gpu=False,
        num_heads=args.gen_num_heads,
        layers=args.gen_num_layers,
        rpr_k=gen_rpr_k,
        d_k=args.gen_d_k,
        windowed_ra=args.windowed_ra,
        src_keys=['x'],
        tgt_key='x',
    )
    discrim_model = TransformerDiscriminator(
        discrim_embeddings,
        d_model=args.discrim_d_model,
        d_ff=args.discrim_d_ff,
        dropout=args.discrim_dropout,
        num_heads=args.discrim_num_heads,
        layers=args.discrim_num_layers,
        rpr_k=discrim_rpr_k,
        d_k=args.discrim_d_k,
    )
    logger.info("Loaded model and loss")

    steps_per_epoch = num_train_samples // args.batch_size
    steps_per_valid_epoch = num_valid_samples // args.batch_size
    update_on = steps_per_epoch // args.saves_per_epoch
    report_on = max(10, update_on) // 10
    logger.info(f"Steps per epoch: {steps_per_epoch}. Saving checkpoint every {update_on} steps.")
    lr_decay = CosineDecaySchedulerTensorFlow(steps_per_epoch * args.epochs, lr=args.lr)
    linear_warmup = WarmupLinearSchedulerTensorFlow(args.warmup_steps, lr=args.lr)
    lr_sched = CompositeLRSchedulerTensorFlow(linear_warmup, lr_decay)

    mask_value = vocabs.get("[MASK]", vocabs.get("<MASK>", -1))
    if mask_value == -1:
        logger.error("We could not find a suitable masking token in the vocab")
        return

    optimizer, clip = create_keras_optimizer(**vars(args))
    discrim_checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=discrim_model)
    discrim_checkpoint_manager = tf.train.CheckpointManager(discrim_checkpoint,
                                                            directory=os.path.join(args.basedir, 'discrim'),
                                                            max_to_keep=5)
    gen_checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=gen_model)
    gen_checkpoint_manager = tf.train.CheckpointManager(gen_checkpoint,
                                                        directory=os.path.join(args.basedir, 'gen'),
                                                        max_to_keep=5)

    if args.restart:
        # The global step gets automatically updated here
        # so we dont have to worry about our LR regimen
        gen_checkpoint.restore(gen_checkpoint_manager.latest_checkpoint)
        discrim_checkpoint.restore(discrim_checkpoint_manager.latest_checkpoint)

    def _replicated_train_step(inputs):
        """This runs on a single replica"""
        noised_x, labels = inputs
        with tf.GradientTape() as tape:
            gen_loss_step, discrim_loss_step, acc = gen_vs_discrim(
                noised_x, labels, gen_model, discrim_model, mask_value)
            loss_value = (args.gen_loss_scale * gen_loss_step + discrim_loss_step) / num_replicas
        grads = tape.gradient(loss_value, gen_model.trainable_variables + discrim_model.trainable_variables)
        grads, _ = tf.clip_by_global_norm(grads, clip)
        optimizer.apply_gradients(zip(grads, gen_model.trainable_variables + discrim_model.trainable_variables))
        return loss_value, gen_loss_step, discrim_loss_step, acc

    @tf.function
    def _distributed_train_step(inputs: Tuple[tf.Tensor, tf.Tensor]):
        """Runs across multiple replicas and aggregates the results.

        :param inputs:
        :return:
        """
        loss, gen_loss, discrim_loss, acc = strategy.run(_replicated_train_step, args=(inputs,))
        sum_loss = strategy.reduce(tf.distribute.ReduceOp.SUM, loss, axis=None)
        sum_gen_loss = strategy.reduce(tf.distribute.ReduceOp.SUM, gen_loss, axis=None)
        sum_discrim_loss = strategy.reduce(tf.distribute.ReduceOp.SUM, discrim_loss, axis=None)
        sum_acc = strategy.reduce(tf.distribute.ReduceOp.SUM, acc, axis=None)
        return sum_loss, sum_gen_loss, sum_discrim_loss, sum_acc

    def _replicated_test_step(inputs):
        """This runs on a single replica"""
        noised_x, labels = inputs
        gen_loss_step, discrim_loss_step, acc = gen_vs_discrim(
            noised_x, labels, gen_model, discrim_model, mask_value)
        loss_value = (args.gen_loss_scale * gen_loss_step + discrim_loss_step) / num_replicas
        return loss_value, gen_loss_step, discrim_loss_step, acc

    @tf.function
    def _distributed_test_step(inputs: Tuple[tf.Tensor, tf.Tensor]):
        """Runs across multiple replicas and aggregates the results.

        :param inputs:
        :return:
        """
        loss, gen_loss, discrim_loss, acc = strategy.run(_replicated_test_step, args=(inputs,))
        sum_loss = strategy.reduce(tf.distribute.ReduceOp.SUM, loss, axis=None)
        sum_gen_loss = strategy.reduce(tf.distribute.ReduceOp.SUM, gen_loss, axis=None)
        sum_discrim_loss = strategy.reduce(tf.distribute.ReduceOp.SUM, discrim_loss, axis=None)
        sum_acc = strategy.reduce(tf.distribute.ReduceOp.SUM, acc, axis=None)
        return sum_loss, sum_gen_loss, sum_discrim_loss, sum_acc

    # This is the training loop
    start_epoch = 0
    timer = Timer()
    with strategy.scope():
        for epoch in range(start_epoch, args.epochs):
            SET_TRAIN_FLAG(True)
            logger.info('Starting epoch %d', epoch + 1)
            avg_loss = Average('average_train_loss')
            avg_gen_loss = Average('average_gen_loss')
            avg_discrim_loss = Average('average_discrim_loss')
            avg_acc = Average('average_train_acc')
            metrics = {}
            timer.start()
            train_iter = iter(train_loader)
            for i in range(steps_per_epoch):
                loss, gen_loss, discrim_loss, acc = _distributed_train_step(next(train_iter))
                avg_loss.update(loss.numpy().item())
                avg_gen_loss.update(gen_loss.numpy().item())
                avg_discrim_loss.update(discrim_loss.numpy().item())
                avg_acc.update(acc.numpy().item())
                tf.summary.scalar("train_loss", data=loss, step=optimizer.iterations)
                tf.summary.scalar("train_gen_loss", data=gen_loss, step=optimizer.iterations)
                tf.summary.scalar("train_discrim_loss", data=discrim_loss, step=optimizer.iterations)
                tf.summary.scalar("train_acc", data=acc, step=optimizer.iterations)

                if args.convert_only:
                    logger.warning("Convert only flag specified. Stopping after one step")
                    steps = optimizer.iterations.numpy()
                    npz_checkpoint = os.path.join(args.basedir, f'discrim-step-{steps}.npz')
                    save_tlm_npz(discrim_model, npz_checkpoint)
                    npz_checkpoint = os.path.join(args.basedir, f'gen-step-{steps}.npz')
                    save_tlm_npz(gen_model, npz_checkpoint)
                    return

                if (i + 1) % report_on == 0:
                    logger.info(avg_loss)
                    logger.info(avg_gen_loss)
                    logger.info(avg_discrim_loss)
                    logger.info(avg_acc)
                if (i + 1) % update_on == 0:
                    elapsed = timer.elapsed(True)
                    logger.info('elapsed time this epoch %d min', elapsed)
                    logger.info('throughput %f steps/min', i / elapsed)
                    gen_checkpoint_manager.save()
                    discrim_checkpoint_manager.save()
                    if args.npz:
                        steps = optimizer.iterations.numpy()
                        npz_checkpoint = os.path.join(args.basedir, f'discrim-step-{steps}.npz')
                        save_tlm_npz(discrim_model, npz_checkpoint)
                        npz_checkpoint = os.path.join(args.basedir, f'gen-step-{steps}.npz')
                        save_tlm_npz(gen_model, npz_checkpoint)

            metrics['train_elapsed_min'] = timer.elapsed(True)
            metrics['average_train_loss'] = avg_loss.avg
            metrics['average_gen_loss'] = avg_gen_loss.avg
            metrics['average_discrim_loss'] = avg_discrim_loss.avg
            metrics['average_train_acc'] = avg_acc.avg
            # Keras optimizers track their step on `iterations` (there is no `global_step`)
            metrics['lr'] = float(lr_sched(tf.cast(optimizer.iterations, tf.float32)).numpy().item())

            avg_valid_loss = Average('average_valid_loss')
            avg_valid_gen_loss = Average('average_valid_gen_loss')
            avg_valid_discrim_loss = Average('average_valid_discrim_loss')
            avg_valid_acc = Average('average_valid_acc')
            timer.start()
            SET_TRAIN_FLAG(False)
            valid_iter = iter(valid_loader)
            for i in range(steps_per_valid_epoch):
                valid_loss, valid_gen_loss, valid_discrim_loss, valid_acc = _distributed_test_step(next(valid_iter))
                tf.summary.scalar('valid_loss', data=valid_loss, step=optimizer.iterations)
                tf.summary.scalar('valid_gen_loss', data=valid_gen_loss, step=optimizer.iterations)
                tf.summary.scalar('valid_discrim_loss', data=valid_discrim_loss, step=optimizer.iterations)
                tf.summary.scalar('valid_acc', data=valid_acc, step=optimizer.iterations)
                avg_valid_loss.update(valid_loss.numpy().item())
                avg_valid_gen_loss.update(valid_gen_loss.numpy().item())
                avg_valid_discrim_loss.update(valid_discrim_loss.numpy().item())
                avg_valid_acc.update(valid_acc.numpy().item())

            metrics['valid_elapsed_min'] = timer.elapsed(True)
            metrics['average_valid_loss'] = avg_valid_loss.avg
            metrics['average_valid_gen_loss'] = avg_valid_gen_loss.avg
            metrics['average_valid_discrim_loss'] = avg_valid_discrim_loss.avg
            metrics['average_valid_acc'] = avg_valid_acc.avg
            logger.info(json.dumps(metrics, indent=4))
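
# `gen_vs_discrim` is not shown here; schematically, an ELECTRA-style step
# scores masked positions with the generator, fills them with generator
# predictions, and asks the discriminator which tokens were replaced. A
# heavily simplified sketch of that objective (assumed model call signatures
# and output shapes; not the project's implementation):
import tensorflow as tf

def _gen_vs_discrim_sketch(noised_x, labels, gen_model, discrim_model, mask_value):
    masked = tf.equal(noised_x, mask_value)                        # [B, T]
    gen_logits = gen_model({'x': noised_x})                        # [B, T, V] (assumed)
    gen_loss = tf.reduce_mean(tf.boolean_mask(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=gen_logits), masked))
    sampled = tf.cast(tf.argmax(gen_logits, axis=-1), labels.dtype)
    filled = tf.where(masked, sampled, labels)                     # corrupted sequence
    replaced = tf.cast(tf.not_equal(filled, labels), tf.float32)   # 1 where the generator was wrong
    discrim_logits = discrim_model({'x': filled})                  # [B, T] (assumed)
    discrim_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        labels=replaced, logits=discrim_logits))
    acc = tf.reduce_mean(tf.cast(tf.equal(tf.cast(discrim_logits > 0, tf.float32), replaced), tf.float32))
    return gen_loss, discrim_loss, acc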
def train(self, ts, reporting_fns):
    """Train by looping over the steps

    For a `tf.dataset`-backed `fit_func`, we are using the previously wired
    `dataset`s in the model. For `feed_dict`, we convert the ts samples to
    `feed_dict`s and hand them in one-by-one

    :param ts: The training set
    :param reporting_fns: A list of reporting hooks
    :return: Metrics
    """
    SET_TRAIN_FLAG(True)
    epoch_loss = tf.Variable(0.0)
    epoch_div = tf.Variable(0, dtype=tf.int32)
    nstep_loss = tf.Variable(0.0)
    nstep_div = tf.Variable(0, dtype=tf.int32)
    self.nstep_start = time.perf_counter()
    start = time.perf_counter()

    def _train_step_no_state(inputs):
        """Replicated training step."""
        features, y = inputs
        loss = self.optimizer.update(self.model, features, y)
        toks = self._num_toks(y)
        report_loss = loss * tf.cast(toks, tf.float32)
        return report_loss, toks

    def _train_step_with_state(inputs, hidden):
        """Replicated training step."""
        features, y = inputs
        loss, hidden = self.optimizer.update_with_hidden(self.model, hidden, features, y)
        toks = self._num_toks(y)
        report_loss = loss * tf.cast(toks, tf.float32)
        return hidden, report_loss, toks

    if get_version(tf) >= 2:
        _train_step_with_state = tf.function(_train_step_with_state)
        _train_step_no_state = tf.function(_train_step_no_state)

    h = None
    for inputs in ts:
        if self.model.requires_state:
            h, step_report_loss, step_toks = _train_step_with_state(inputs, h)
        else:
            step_report_loss, step_toks = _train_step_no_state(inputs)
        epoch_loss.assign_add(step_report_loss)
        nstep_loss.assign_add(step_report_loss)
        epoch_div.assign_add(step_toks)
        nstep_div.assign_add(step_toks)
        step = self.optimizer.global_step.numpy() + 1
        if step % self.nsteps == 0:
            metrics = self.calc_metrics(nstep_loss.numpy(), nstep_div.numpy())
            self.report(step, metrics, self.nstep_start, 'Train', 'STEP', reporting_fns, self.nsteps)
            nstep_loss.assign(0.0)
            nstep_div.assign(0)
            self.nstep_start = time.perf_counter()

    epoch_loss = epoch_loss.numpy()
    epoch_div = epoch_div.numpy()
    metrics = self.calc_metrics(epoch_loss, epoch_div)
    self.train_epochs += 1
    self.report(self.train_epochs, metrics, start, 'Train', 'EPOCH', reporting_fns)
    return metrics