Example #1
    def test(self, ts, steps=0, **kwargs):
        """Method that evaluates on some data.  There are 2 modes this can run in, `feed_dict` and `dataset`

        In `feed_dict` mode, the model cycles the test data batch-wise and feeds each batch in with a `feed_dict`.
        In `dataset` mode, the data is still passed in to this method, but it is not passed in a `feed_dict` and is
        mostly superfluous since the features are grafted right onto the graph.  However, we do use it for supplying
        the ground truth, ids and text, so it is essential that the caller does not shuffle the data.
        :param ts: The test set
        :param steps: (`int`) The number of steps (batches) to run
        :param conll_output: (`str`) An optional CoNLL output file (keyword argument)
        :param txts: A list of text data associated with the encoded batch (keyword argument)
        :return: The metrics
        """
        SET_TRAIN_FLAG(False)

        total_correct = total_sum = 0
        gold_spans = []
        pred_spans = []

        self.cm = ConfusionMatrix(self.idx2classlabel)

        handle = None
        if kwargs.get("conll_output") is not None and kwargs.get(
                'txts') is not None:
            handle = open(kwargs.get("conll_output"), "w")

        try:
            pg = create_progress_bar(steps)
            metrics = {}
            for (features, y), batch in pg(
                    zip_longest(ts, kwargs.get('batches', []), fillvalue={})):
                correct, count, golds, guesses = self.process_batch(
                    features,
                    y,
                    handle=handle,
                    txts=kwargs.get("txts"),
                    ids=batch.get("ids"))
                total_correct += correct
                total_sum += count
                gold_spans.extend(golds)
                pred_spans.extend(guesses)

            total_acc = total_correct / float(total_sum)
            # Only show the fscore if requested
            metrics['tagging_f1'] = span_f1(gold_spans, pred_spans)
            metrics['tagging_acc'] = total_acc
            metrics.update({
                f"classification_{k}": v
                for k, v in self.cm.get_all_metrics().items()
            })
            if self.verbose:
                conll_metrics = per_entity_f1(gold_spans, pred_spans)
                conll_metrics['acc'] = total_acc * 100
                conll_metrics['tokens'] = total_sum
                logger.info(conlleval_output(conll_metrics))
        finally:
            if handle is not None:
                handle.close()

        return metrics
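A hedged usage sketch of the method above follows; `trainer`, `test_set`, `raw_texts`, and `batch_info` are placeholder names, not identifiers from the original source.

# Hedged usage sketch for the `test` method above (all names are placeholders).
num_steps = len(test_set)                  # drives the progress bar
metrics = trainer.test(
    test_set,
    steps=num_steps,
    conll_output='eval.conll',             # optional; only written when `txts` is also supplied
    txts=raw_texts,                        # raw text aligned with the (unshuffled) batches
    batches=batch_info,                    # optional per-batch dicts carrying "ids"
)
print(metrics['tagging_f1'], metrics['tagging_acc'])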
Example #2
    def train(self, ts, reporting_fns):
        """Train by looping over the steps

        For a `tf.dataset`-backed `fit_func`, we are using the previously wired `dataset`s
        in the model (and `dataset` is `True`).  For `feed_dict`, we convert the ts samples
        to `feed_dict`s and hand them in one-by-one

        :param ts: The training set
        :param reporting_fns: A list of reporting hooks
        :param dataset: (`bool`) Are we using `tf.dataset`s
        :return: Metrics
        """
        SET_TRAIN_FLAG(True)
        epoch_loss = tf.Variable(0.0)
        epoch_div = tf.Variable(0, dtype=tf.int32)
        nstep_loss = tf.Variable(0.0)
        nstep_div = tf.Variable(0, dtype=tf.int32)
        self.nstep_start = time.perf_counter()
        start = time.perf_counter()

        @tf.function
        def _train_step(features, y):
            """Replicated training step."""

            loss = self.optimizer.update(self.model, features, y)
            toks = self._num_toks(features['tgt_len'])
            report_loss = loss * tf.cast(toks, tf.float32)
            return report_loss, toks

        with autograph_options({
                "function_optimization": False,
                "layout_optimizer": False
        }):
            for features, y in ts:
                features['dst'] = y[:, :-1]
                step_report_loss, step_toks = _train_step(features, y)
                epoch_loss.assign_add(step_report_loss)
                nstep_loss.assign_add(step_report_loss)
                epoch_div.assign_add(step_toks)
                nstep_div.assign_add(step_toks)

                step = self.optimizer.global_step.numpy() + 1
                if step % self.nsteps == 0:
                    metrics = self.calc_metrics(nstep_loss.numpy(),
                                                nstep_div.numpy())
                    self.report(step, metrics, self.nstep_start, 'Train',
                                'STEP', reporting_fns, self.nsteps)
                    nstep_loss.assign(0.0)
                    nstep_div.assign(0)
                    self.nstep_start = time.perf_counter()

        epoch_loss = epoch_loss.numpy()
        epoch_div = epoch_div.numpy()
        metrics = self.calc_metrics(epoch_loss, epoch_div)
        self.train_epochs += 1
        self.report(self.train_epochs, metrics, start, 'Train', 'EPOCH',
                    reporting_fns)
        return metrics
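The heart of the loop above is a `tf.function`-compiled step whose token-weighted loss is accumulated into `tf.Variable` counters. A self-contained toy version of that accumulate-and-average pattern, using only plain TensorFlow (the loss and token count below are stand-ins, not the library's `optimizer.update` or `_num_toks`):

import tensorflow as tf

epoch_loss = tf.Variable(0.0)
epoch_div = tf.Variable(0, dtype=tf.int32)

@tf.function
def toy_step(x, y):
    loss = tf.reduce_mean(tf.square(x - y))       # stand-in for optimizer.update(...)
    toks = tf.shape(y)[0]                         # stand-in for _num_toks(...)
    return loss * tf.cast(toks, tf.float32), toks

for _ in range(10):
    x, y = tf.random.normal([4, 8]), tf.random.normal([4, 8])
    step_loss, step_toks = toy_step(x, y)
    epoch_loss.assign_add(step_loss)
    epoch_div.assign_add(step_toks)

avg_token_loss = epoch_loss.numpy() / float(epoch_div.numpy())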
Example #3
    def _train(self, loader, steps=0, **kwargs):
        """Train an epoch of data using either the input loader or using `tf.dataset`

        In non-`tf.dataset` mode, we cycle the loader data feed, and pull a batch and feed it to the feed dict
        When we use `tf.dataset`s under the hood, this function simply uses the loader to know how many steps
        to train.  We do use a `feed_dict` for passing the `TRAIN_FLAG` in either case

        :param loader: A data feed
        :param kwargs: See below

        :Keyword Arguments:
         * *dataset* (`bool`) Set to `True` if using `tf.dataset`s, defaults to `True`
         * *reporting_fns* (`list`) A list of reporting hooks to use

        :return: Metrics
        """
        SET_TRAIN_FLAG(True)
        reporting_fns = kwargs.get('reporting_fns', [])
        pg = create_progress_bar(steps)
        epoch_loss = tf.Variable(0.0)
        epoch_div = tf.Variable(0, dtype=tf.int32)
        nstep_loss = tf.Variable(0.0)
        nstep_div = tf.Variable(0, dtype=tf.int32)
        self.nstep_start = time.perf_counter()

        @tf.function
        def _train_step(inputs):
            features, y = inputs
            loss = self.optimizer.update(self.model, features, y)
            batchsz = get_shape_as_list(y)[0]
            report_loss = loss * batchsz
            return report_loss, batchsz

        with autograph_options({
                "function_optimization": False,
                "layout_optimizer": False
        }):
            for inputs in pg(loader):
                step_report_loss, step_batchsz = _train_step(inputs)
                epoch_loss.assign_add(step_report_loss)
                nstep_loss.assign_add(step_report_loss)
                epoch_div.assign_add(step_batchsz)
                nstep_div.assign_add(step_batchsz)

                step = self.optimizer.global_step.numpy() + 1
                if step % self.nsteps == 0:
                    metrics = self.calc_metrics(nstep_loss.numpy(),
                                                nstep_div.numpy())
                    self.report(step, metrics, self.nstep_start, 'Train',
                                'STEP', reporting_fns, self.nsteps)
                    nstep_loss.assign(0.0)
                    nstep_div.assign(0)
                    self.nstep_start = time.perf_counter()

        epoch_loss = epoch_loss.numpy()
        epoch_div = epoch_div.numpy()
        metrics = self.calc_metrics(epoch_loss, epoch_div)
        return metrics
Example #4
    def get_input(self, training=False):
        SET_TRAIN_FLAG(training)
        dataset = tf.data.Dataset.from_tensor_slices((self.x, self.y))
        dataset = dataset.shuffle(buffer_size=SHUF_BUF_SZ)
        dataset = dataset.batch(50)
        dataset = dataset.map(lambda x, y: ({
            'word': x,
            'lengths': count_nonzero(x, axis=1)
        }, y))
        dataset = dataset.prefetch(NUM_PREFETCH)
        return dataset
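The snippet above relies on the surrounding library for `SHUF_BUF_SZ`, `NUM_PREFETCH`, and `count_nonzero`; a self-contained equivalent with fake data and plain TensorFlow stand-ins (constant values are illustrative) looks like this:

import numpy as np
import tensorflow as tf

SHUF_BUF_SZ = 5000       # illustrative
NUM_PREFETCH = 2         # illustrative
x = np.random.randint(1, 100, size=(500, 40)).astype('int32')    # fake word ids
y = np.random.randint(0, 2, size=(500,)).astype('int32')         # fake labels

dataset = tf.data.Dataset.from_tensor_slices((x, y))
dataset = dataset.shuffle(buffer_size=SHUF_BUF_SZ)
dataset = dataset.batch(50)
dataset = dataset.map(lambda x, y: ({'word': x,
                                     'lengths': tf.math.count_nonzero(x, axis=1)}, y))
dataset = dataset.prefetch(NUM_PREFETCH)

for features, labels in dataset.take(1):
    print(features['word'].shape, features['lengths'].shape, labels.shape)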
Example #5
def test_windowed_ra():
    num_heads = 4
    d_model = 64
    rpr_k = 1
    batchsize = 2
    nctx = 256
    d_k = d_model // num_heads

    with tf.device("/cpu:0"):
        old = SeqScaledDotProductRelativeAttention(pdrop=0.)
        new = SeqScaledWindowedRelativeAttention(pdrop=0.)

        rpr_key_emb = tf.keras.layers.Embedding(2 * rpr_k + 1, d_k)
        rpr_value_emb = tf.keras.layers.Embedding(2 * rpr_k + 1, d_k)

        Q = tf.random.normal([batchsize, num_heads, nctx, d_k])
        K = tf.random.normal([batchsize, num_heads, nctx, d_k])
        V = tf.random.normal([batchsize, num_heads, nctx, d_k])
        lengths = tf.random.uniform([
            batchsize,
        ], 0, nctx, dtype=tf.int32)
        seq_mask = tf.sequence_mask(lengths, maxlen=nctx, dtype=tf.float32)
        in_mask = tf.expand_dims(tf.expand_dims(seq_mask, 1), 1)
        out_mask = tf.expand_dims(tf.expand_dims(seq_mask, 1), -1)

        # manually create a ra_mask to prevent attention beyond rpr_k
        ones = tf.ones([nctx, nctx])
        ra_mask = tf.linalg.band_part(ones, rpr_k, rpr_k)
        mask = in_mask * tf.expand_dims(tf.expand_dims(ra_mask, 0), 0)
        rpr_key_old, rpr_value_old = make_rpr(rpr_key_emb, rpr_value_emb,
                                              rpr_k, nctx)
        SET_TRAIN_FLAG(False)
        out_old = old((Q, K, V, rpr_key_old, rpr_value_old, mask))
        out_old = masked_fill(out_old, tf.equal(out_mask, 0), 1)
        print(out_old.shape)

        # using the windowed relative attention with the original sequence mask
        rpr_key_new, rpr_value_new = unfold_rpr(rpr_key_emb, rpr_value_emb,
                                                rpr_k)
        out_new = new((Q, K, V, rpr_key_new, rpr_value_new, in_mask))
        out_new = masked_fill(out_new, tf.equal(out_mask, 0), 1)
        print(out_new.shape)
        if get_version(tf) < 2:
            with tf.compat.v1.Session() as sess:
                out_old, out_new = sess.run([out_old, out_new])
        else:
            out_old, out_new = out_old.numpy(), out_new.numpy()

        assert np.allclose(out_old, out_new, atol=1e-6)
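The `ra_mask` built with `tf.linalg.band_part` above is what restricts attention to a window of `rpr_k` positions on either side of the diagonal; a tiny standalone illustration:

import tensorflow as tf

nctx, rpr_k = 5, 1
ra_mask = tf.linalg.band_part(tf.ones([nctx, nctx]), rpr_k, rpr_k)
print(ra_mask.numpy())
# [[1. 1. 0. 0. 0.]
#  [1. 1. 1. 0. 0.]
#  [0. 1. 1. 1. 0.]
#  [0. 0. 1. 1. 1.]
#  [0. 0. 0. 1. 1.]]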
Example #6
    def test(self, vs, reporting_fns, phase='Valid', **kwargs):
        """Run an epoch of testing over the dataset

        If we are using a `tf.dataset`-based `fit_func`, we will just
        cycle the number of steps and let the `dataset` yield new batches.

        If we are using `feed_dict`s, we convert each batch from the `DataFeed`
        and pass that into TF as the `feed_dict`

        :param vs: A validation set
        :param reporting_fns: Reporting hooks
        :param phase: The phase of evaluation (`Test`, `Valid`)
        :param dataset: (`bool`) Are we using `tf.dataset`s
        :return: Metrics
        """
        SET_TRAIN_FLAG(False)
        if phase == 'Test':
            return self._evaluate(vs, reporting_fns, **kwargs)

        self.valid_epochs += 1

        total_loss = 0
        total_toks = 0
        preds = []
        golds = []

        start = time.perf_counter()
        for features, tgt in vs:
            features['dst'] = tgt[:, :-1]
            top_preds = self.model.predict(features, beam=1,
                                           make_input=False)[0]
            loss_value = self.loss(self.model, features, tgt).numpy()
            toks = tf.cast(self._num_toks(features['tgt_len']),
                           tf.float32).numpy()
            total_loss += loss_value * toks
            total_toks += toks
            preds.extend(
                convert_seq2seq_preds(top_preds[:, 0, :], self.tgt_rlut))
            golds.extend(
                convert_seq2seq_golds(tgt, features['tgt_len'], self.tgt_rlut))

        metrics = self.calc_metrics(total_loss, total_toks)
        metrics['bleu'] = bleu(preds, golds, self.bleu_n_grams)[0]
        self.report(self.valid_epochs, metrics, start, phase, 'EPOCH',
                    reporting_fns)
        return metrics
Example #7
    def _test(self, loader, steps=0, **kwargs):
        """Test an epoch of data using either the input loader or using `tf.dataset`

        In non-`tf.dataset` mode, we cycle the loader data feed, and pull a batch and feed it to the feed dict
        When we use `tf.dataset`s under the hood, this function simply uses the loader to know how many steps
        to evaluate.

        :param loader: A data feed
        :param kwargs: See below

        :Keyword Arguments:
          * *dataset* (`bool`) Set to `True` if using `tf.dataset`s, defaults to `True`
          * *reporting_fns* (`list`) A list of reporting hooks to use
          * *verbose* (`dict`) A dictionary containing `console` boolean and `file` name if on

        :return: Metrics
        """

        metrics = [LAS(), UAS(), LCM(), UCM()]

        pg = create_progress_bar(steps)

        SET_TRAIN_FLAG(False)
        for features, y in pg(loader):
            heads_gold, labels_gold = y
            greedy_heads_pred, greedy_labels_pred = self.model.decode(features)
            B, T = get_shape_as_list(greedy_labels_pred)[:2]
            labels_gold_trimmed = labels_gold[:, :T].numpy()
            heads_gold_trimmed = heads_gold[:, :T].numpy()

            for i in range(B):
                for m in metrics:
                    if self.punct_eval is False:
                        labels_gold_trimmed[i] = masked_fill(
                            labels_gold_trimmed[i],
                            labels_gold_trimmed[i] == self.model.punct,
                            Offsets.PAD)
                    m.add(greedy_heads_pred[i], heads_gold_trimmed[i],
                          greedy_labels_pred[i], labels_gold_trimmed[i])

        metrics = {m.name: m.score for m in metrics}
        return metrics
Example #8
    def test(self, vs, reporting_fns, phase):
        """Run an epoch of testing over the dataset

        If we are using a `tf.dataset`-based `fit_func`, we will just
        cycle the number of steps and let the `dataset` yield new batches.

        If we are using `feed_dict`s, we convert each batch from the `DataFeed`
        and pass that into TF as the `feed_dict`

        :param vs: A validation set
        :param reporting_fns: Reporting hooks
        :param phase: The phase of evaluation (`Test`, `Valid`)
        :param dataset: (`bool`) Are we using `tf.dataset`s
        :return: Metrics
        """
        total_loss = 0.0
        total_toks = 0
        epochs = 0
        if phase == 'Valid':
            self.valid_epochs += 1
            epochs = self.valid_epochs
        SET_TRAIN_FLAG(False)

        start = time.perf_counter()
        h = None
        for features, y in vs:
            if self.model.requires_state:
                loss_value, h = loss_with_state(self.model, h, features, y)
            else:
                loss_value = loss_without_state(self.model, features, y)
            loss_value = loss_value.numpy()
            toks = self._num_toks(y)
            total_loss += loss_value * tf.cast(toks, tf.float32).numpy()
            total_toks += toks.numpy()

        metrics = self.calc_metrics(total_loss, total_toks)
        self.report(epochs, metrics, start, phase, 'EPOCH', reporting_fns)
        return metrics
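Note how each per-batch mean loss is scaled by its token count before the final division; a tiny numeric illustration of that token-weighted averaging:

batch_losses = [2.0, 4.0]      # mean loss per batch (illustrative)
batch_toks = [100, 300]        # tokens per batch (illustrative)
total_loss = sum(l * t for l, t in zip(batch_losses, batch_toks))
avg_loss = total_loss / sum(batch_toks)
print(avg_loss)                # 3.5, not the naive mean of 3.0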
Example #9
    def train(self, ts, reporting_fns, steps=0, dataset=True):
        """Train by looping over the steps

        For a `tf.dataset`-backed `fit_func`, we are using the previously wired `dataset`s
        in the model (and `dataset` is `True`).  For `feed_dict`, we convert the ts samples
        to `feed_dict`s and hand them in one-by-one

        :param ts: The training set
        :param reporting_fns: A list of reporting hooks
        :param dataset: (`bool`) Are we using `tf.dataset`s
        :return: Metrics
        """
        strategy = self.strategy

        def _replicated_train_step_no_state(inputs):

            features, y = inputs
            per_replica_loss = self.optimizer.update(self.model, features, y)
            per_replica_toks = self._num_toks(y)
            per_replica_report_loss = per_replica_loss * tf.cast(
                per_replica_toks, tf.float32)
            return per_replica_report_loss, per_replica_toks

        def _replicated_train_step_with_state(inputs, hidden):
            features, y = inputs
            per_replica_loss, new_hidden = self.optimizer.update_with_hidden(
                self.model, hidden, features, y)
            per_replica_toks = self._num_toks(y)
            per_replica_report_loss = per_replica_loss * tf.cast(
                per_replica_toks, tf.float32)
            return new_hidden, per_replica_report_loss, per_replica_toks

        with strategy.scope():
            train_iter = iter(ts)
            SET_TRAIN_FLAG(True)
            epoch_loss = tf.Variable(0.0)
            epoch_div = tf.Variable(0, dtype=tf.int32)
            nstep_loss = tf.Variable(0.0)
            nstep_div = tf.Variable(0, dtype=tf.int32)
            self.nstep_start = time.time()
            start = time.time()

            @tf.function
            def _distributed_train_no_state(inputs):
                per_replica_loss, per_replica_toks = strategy.experimental_run_v2(
                    _replicated_train_step_no_state, args=(inputs, ))
                return strategy.reduce(tf.distribute.ReduceOp.SUM,
                                       per_replica_loss,
                                       axis=None), strategy.reduce(
                                           tf.distribute.ReduceOp.SUM,
                                           per_replica_toks,
                                           axis=None)

            @tf.function
            def _distributed_train_with_state(inputs, hidden):

                h, per_replica_loss, per_replica_toks = strategy.experimental_run_v2(
                    _replicated_train_step_with_state, args=(
                        inputs,
                        hidden,
                    ))
                step_loss = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                            per_replica_loss,
                                            axis=None)
                step_toks = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                            per_replica_toks,
                                            axis=None)
                return h, step_loss, step_toks

            h = None
            for i in range(steps):

                inputs = next(train_iter)
                if self.model.requires_state:
                    h, step_loss, step_toks = _distributed_train_with_state(
                        inputs, h)
                else:
                    step_loss, step_toks = _distributed_train_no_state(inputs)
                epoch_loss.assign_add(step_loss)
                nstep_loss.assign_add(step_loss)
                epoch_div.assign_add(step_toks)
                nstep_div.assign_add(step_toks)
                step = self.optimizer.global_step.numpy() + 1
                if step % self.nsteps == 0:
                    metrics = self.calc_metrics(nstep_loss.numpy(),
                                                nstep_div.numpy())
                    self.report(step, metrics, self.nstep_start, 'Train',
                                'STEP', reporting_fns, self.nsteps)
                    nstep_loss.assign(0.0)
                    nstep_div.assign(0)
                    self.nstep_start = time.time()

            epoch_loss = epoch_loss.numpy()
            epoch_div = epoch_div.numpy()
            metrics = self.calc_metrics(epoch_loss, epoch_div)
            self.train_epochs += 1
            self.report(self.train_epochs, metrics, start, 'Train', 'EPOCH',
                        reporting_fns)
            return metrics
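`strategy.experimental_run_v2` is the TF 2.0/2.1-era name for what later releases expose as `strategy.run`. A minimal self-contained sketch of the same replicate-and-reduce pattern on a toy computation (using the newer name):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.from_tensor_slices(tf.ones([8, 4])).batch(4)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

@tf.function
def distributed_step(x):
    def replica_step(x):
        return tf.reduce_sum(x)               # per-replica partial sum
    per_replica = strategy.run(replica_step, args=(x,))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)

total = 0.0
for batch in dist_dataset:
    total += distributed_step(batch).numpy()
print(total)    # 32.0: per-replica sums are reduced back to a single scalar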
Example #10
def fit_eager_distributed(model_params, ts, vs, es=None, **kwargs):
    """
    Train a language model using TensorFlow with `tf.dataset`.  This
    is the default behavior for training.

    :param model_params: The model (or parameters to create the model) to train
    :param ts: A training data set
    :param vs: A validation data set
    :param es: A test data set, can be None
    :param kwargs:
        See below

    :Keyword Arguments:
        * *do_early_stopping* (``bool``) --
          Stop after evaluation data is no longer improving.  Defaults to True
        * *verbose* (`dict`) A dictionary containing `console` boolean and `file` name if on
        * *epochs* (``int``) -- how many epochs.  Defaults to 5
        * *outfile* -- Model output file, defaults to classifier-model.pyth
        * *patience* --
           How many epochs with no improvement on the evaluation data before we give up
        * *reporting* --
           Callbacks which may be used on reporting updates
        * *nsteps* (`int`) -- If we should report every n-steps, this should be passed
        * *ema_decay* (`float`) -- If we are doing an exponential moving average, what decay to use
        * *clip* (`int`) -- If we are doing gradient clipping, what value to use
        * *optim* (`str`) -- The name of the optimizer we are using
        * *lr* (`float`) -- The learning rate we are using
        * *mom* (`float`) -- If we are using SGD, what value to use for momentum
        * *beta1* (`float`) -- Adam-specific hyper-param, defaults to `0.9`
        * *beta2* (`float`) -- Adam-specific hyper-param, defaults to `0.999`
        * *epsilon* (`float`) -- Adam-specific hyper-param, defaults to `1e-8`

    :return: None
    """

    epochs = int(kwargs.get('epochs', 5))
    patience = int(kwargs.get('patience', epochs))

    model_file = get_model_file('lm', 'tf', kwargs.get('basedir'))

    do_early_stopping = bool(kwargs.get('do_early_stopping', True))

    best_metric = 0
    if do_early_stopping:
        early_stopping_metric = kwargs.get('early_stopping_metric', 'avg_loss')
        early_stopping_cmp, best_metric = get_metric_cmp(
            early_stopping_metric, kwargs.get('early_stopping_cmp'))
        patience = kwargs.get('patience', epochs)
        print('Doing early stopping on [%s] with patience [%d]' %
              (early_stopping_metric, patience))

    reporting_fns = listify(kwargs.get('reporting', []))
    print('reporting', reporting_fns)

    batchsz = kwargs['batchsz']
    test_batchsz = kwargs.get('test_batchsz', batchsz)
    tgt_key = model_params.get('tgt_key')

    train_dataset = tf.data.Dataset.from_tensor_slices(to_tensors(ts))
    train_dataset = train_dataset.shuffle(buffer_size=SHUF_BUF_SZ)
    train_dataset = train_dataset.batch(batchsz, drop_remainder=True)
    train_dataset = train_dataset.prefetch(NUM_PREFETCH)

    valid_dataset = tf.data.Dataset.from_tensor_slices(to_tensors(vs))
    valid_dataset = valid_dataset.batch(batchsz, drop_remainder=True)
    valid_dataset = valid_dataset.prefetch(NUM_PREFETCH)

    trainer = LanguageModelTrainerDistributedTf(model_params, **kwargs)
    train_dataset = trainer.distribute(train_dataset)
    valid_dataset = trainer.distribute(valid_dataset)

    last_improved = 0
    SET_TRAIN_FLAG(True)

    for epoch in range(epochs):

        trainer.train(train_dataset, reporting_fns, steps=len(ts))
        test_metrics = trainer.test(valid_dataset,
                                    reporting_fns,
                                    phase='Valid',
                                    steps=len(vs))

        if do_early_stopping is False:
            trainer.checkpoint()
            trainer.model.save(model_file)

        elif early_stopping_cmp(test_metrics[early_stopping_metric],
                                best_metric):
            last_improved = epoch
            best_metric = test_metrics[early_stopping_metric]
            print('New best %.3f' % best_metric)
            trainer.checkpoint()
            trainer.model.save(model_file)

        elif (epoch - last_improved) > patience:
            print('Stopping due to persistent failures to improve')
            break

    if do_early_stopping is True:
        print('Best performance on %s: %.3f at epoch %d' %
              (early_stopping_metric, best_metric, last_improved))

    if es is not None:
        print('Reloading best checkpoint')
        trainer.recover_last_checkpoint()
        trainer.strategy = tf.distribute.OneDeviceStrategy('/device:GPU:0')
        test_dataset = tf.data.Dataset.from_tensor_slices(to_tensors(es))
        test_dataset = test_dataset.batch(test_batchsz, drop_remainder=False)
        test_dataset = test_dataset.prefetch(NUM_PREFETCH)
        test_dataset = trainer.distribute(test_dataset)
        trainer.test(test_dataset, reporting_fns, phase='Test', steps=len(es))
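A hedged sketch of calling `fit_eager_distributed` follows; the data feeds and the `model_params` dict are placeholders rather than values from the original source, and `batchsz` must be supplied because the function reads `kwargs['batchsz']` directly.

model_params = {'model_type': 'default', 'tgt_key': 'word'}     # illustrative only
fit_eager_distributed(
    model_params,
    train_feed,                       # training data feed (placeholder name)
    valid_feed,                       # validation data feed (placeholder name)
    es=test_feed,                     # optional test data feed
    batchsz=32,                       # required: read via kwargs['batchsz']
    epochs=5,
    do_early_stopping=True,
    early_stopping_metric='avg_loss',
    reporting=[],
)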
Example #11
    def test(self, vs, reporting_fns, phase, steps=0):
        """Run an epoch of testing over the dataset

        If we are using a `tf.dataset`-based `fit_func`, we will just
        cycle the number of steps and let the `dataset` yield new batches.

        If we are using `feed_dict`s, we convert each batch from the `DataFeed`
        and pass that into TF as the `feed_dict`

        :param vs: A validation set
        :param reporting_fns: Reporting hooks
        :param phase: The phase of evaluation (`Test`, `Valid`)
        :param dataset: (`bool`) Are we using `tf.dataset`s
        :return: Metrics
        """
        strategy = self.strategy

        def _replicated_test_step_no_state(inputs):
            features, y = inputs
            per_replica_loss = loss_without_state(self.model, features, y)
            per_replica_toks = self._num_toks(y)
            per_replica_report_loss = per_replica_loss * tf.cast(
                per_replica_toks, tf.float32)
            return per_replica_report_loss, per_replica_toks

        def _replicated_test_step_with_state(inputs, hidden):
            features, y = inputs
            per_replica_loss, new_hidden = loss_with_state(
                self.model, hidden, features, y)
            per_replica_toks = self._num_toks(y)
            per_replica_report_loss = per_replica_loss * tf.cast(
                per_replica_toks, tf.float32)
            return new_hidden, per_replica_report_loss, per_replica_toks

        with strategy.scope():
            SET_TRAIN_FLAG(False)
            test_iter = iter(vs)
            epoch_loss = tf.Variable(0.0)
            epoch_div = tf.Variable(0, dtype=tf.int32)
            self.nstep_start = time.time()
            start = time.time()

            @tf.function
            def _distributed_test_no_state(inputs):
                per_replica_loss, per_replica_toks = strategy.experimental_run_v2(
                    _replicated_test_step_no_state, args=(inputs, ))
                return strategy.reduce(tf.distribute.ReduceOp.SUM,
                                       per_replica_loss,
                                       axis=None), strategy.reduce(
                                           tf.distribute.ReduceOp.SUM,
                                           per_replica_toks,
                                           axis=None)

            @tf.function
            def _distributed_test_with_state(inputs, hidden):

                h, per_replica_loss, per_replica_toks = strategy.experimental_run_v2(
                    _replicated_test_step_with_state, args=(
                        inputs,
                        hidden,
                    ))
                step_loss = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                            per_replica_loss,
                                            axis=None)
                step_toks = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                            per_replica_toks,
                                            axis=None)
                return h, step_loss, step_toks

            epochs = 0
            if phase == 'Valid':
                self.valid_epochs += 1
                epochs = self.valid_epochs

            h = None
            for i in range(steps):
                inputs = next(test_iter)
                if self.model.requires_state:
                    h, per_replica_loss, per_replica_toks = _distributed_test_with_state(
                        inputs, h)
                else:
                    per_replica_loss, per_replica_toks = _distributed_test_no_state(
                        inputs)
                epoch_loss.assign_add(per_replica_loss)
                epoch_div.assign_add(per_replica_toks)
            metrics = self.calc_metrics(epoch_loss.numpy(), epoch_div.numpy())
            self.report(epochs, metrics, start, phase, 'EPOCH', reporting_fns)
            return metrics
Example #12
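    # NOTE: this snippet begins mid-way through its loss function; judging from the
    # EagerOptimizer(loss, ...) and optimizer.update_with_hidden(model, h, x, y) calls
    # below, the enclosing signature is presumably `def loss(model, h, x, y):`.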
    bt_x_v = tf.nn.log_softmax(tf.reshape(logits, [-1, vsz]), axis=-1)
    one_hots = tf.one_hot(targets, vsz)
    example_loss = -tf.reduce_sum(one_hots * bt_x_v, axis=-1)
    loss = tf.reduce_mean(example_loss)
    return loss, h


optimizer = EagerOptimizer(loss, optim="adam", lr=args.lr)
for epoch in range(args.epochs):

    loss_accum = 0.
    step = 0
    start = time.time()
    h = None

    SET_TRAIN_FLAG(True)

    for x, y in train_input_fn():
        # Optimize the model
        loss_value, h = optimizer.update_with_hidden(model, h, x, y)
        loss_accum += loss_value
        step += 1
    print('training time {}'.format(time.time() - start))

    mean_loss = loss_accum / step
    print('Training Loss {}, Perplexity {}'.format(mean_loss,
                                                   np.exp(mean_loss)))

    step = 0
    loss_accum = 0
    SET_TRAIN_FLAG(False)
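Perplexity here is just the exponential of the mean token-level loss accumulated over the epoch; a tiny numeric check of that relationship:

import numpy as np

losses = [4.2, 3.9, 3.7]                 # per-step loss values (illustrative)
mean_loss = sum(losses) / len(losses)
print(mean_loss, np.exp(mean_loss))      # ~3.93, perplexity ~51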
Example #13
def main():
    parser = ArgumentParser()
    parser.add_argument("--basedir", type=str)
    parser.add_argument("--train_dir",
                        type=str,
                        required=True,
                        help='Training directory')
    parser.add_argument("--valid_dir",
                        type=str,
                        required=True,
                        help='Validation directory')
    parser.add_argument(
        "--train_md",
        type=str,
        help="Training metadata YAML, defaults to `{train_dir}/md.yml`")
    parser.add_argument(
        "--valid_md",
        type=str,
        help="Validation metadata YAML, defaults to `{valid_dir}/md.yml`")
    parser.add_argument("--dataset_key",
                        default="tlm",
                        help="dataset key for basedir")
    parser.add_argument(
        "--embed_type",
        type=str,
        default='default',
        choices=["default", "positional", "learned-positional"],
        help="register label of the embeddings")
    parser.add_argument("--d_model",
                        type=int,
                        default=512,
                        help="Model dimension (and embedding dsz)")
    parser.add_argument("--d_ff", type=int, default=2048, help="FFN dimension")
    parser.add_argument(
        "--d_k",
        type=int,
        default=None,
        help="Dimension per head.  Use if num_heads=1 to reduce dims")
    parser.add_argument("--num_heads",
                        type=int,
                        default=8,
                        help="Number of heads")
    parser.add_argument("--num_layers",
                        type=int,
                        default=8,
                        help="Number of layers")
    parser.add_argument("--num_train_workers",
                        type=int,
                        default=4,
                        help="Number train workers")
    parser.add_argument("--distribute",
                        type=str,
                        default="mirror",
                        choices=["mirror", "tpu", "nccl"])
    parser.add_argument("--tpu_ep",
                        type=str,
                        help="The TPU endpoint if using `distribute=tpu`")
    parser.add_argument("--nctx",
                        type=int,
                        default=256,
                        help="Max input length")
    parser.add_argument("--file_type",
                        default='tfrecord',
                        choices=['json', 'tfrecord'],
                        help="Glob pattern for data")
    parser.add_argument("--batch_size",
                        type=int,
                        default=256,
                        help="Batch Size")
    parser.add_argument("--subword_model_file",
                        type=str,
                        help="The BPE model file",
                        required=False)
    parser.add_argument("--subword_vocab_file",
                        type=str,
                        help="The BPE subword vocab",
                        required=True)
    parser.add_argument("--subword_type",
                        type=str,
                        choices=["bpe", "wordpiece"],
                        default="bpe")
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout")
    parser.add_argument("--ffn_pdrop",
                        type=float,
                        default=0.0,
                        help="Dropout in the dense stack")
    parser.add_argument("--layer_drop",
                        type=float,
                        default=0.0,
                        help="LayerDrop to apply")
    parser.add_argument("--optim",
                        default="adamw",
                        type=str,
                        help="Optimizer to use (defaults to adamw)")
    parser.add_argument("--lr",
                        type=float,
                        default=4.0e-4,
                        help="Learning rate")
    parser.add_argument("--clip",
                        type=float,
                        default=1.0,
                        help="Clipping gradient norm")
    parser.add_argument("--weight_decay",
                        type=float,
                        default=1.0e-2,
                        help="Weight decay")
    parser.add_argument("--epochs",
                        type=int,
                        default=32,
                        help="Num training epochs")
    parser.add_argument(
        "--restart",
        type=str2bool,
        help="Option allows you to restart from a previous checkpoint")
    parser.add_argument("--warmup_steps",
                        type=int,
                        default=10000,
                        help="Num warmup steps")
    parser.add_argument("--causal",
                        type=str2bool,
                        default=False,
                        help="Use CLM (causal) instead of MLM")
    parser.add_argument("--mlp",
                        type=str2bool,
                        default=False,
                        help="Use Gated MLP")
    parser.add_argument("--saves_per_epoch",
                        type=int,
                        default=10,
                        help="The number of checkpoints to save per epoch")
    parser.add_argument(
        '--rpr_k',
        help="Relative attention positional sizes; pass 0 if you don't want relative attention",
        type=int,
        default=[8],
        nargs='+')
    parser.add_argument(
        '--rpr_value_on',
        type=str2bool,
        default=True,
        help="In relative attention, whether to add the positional correction to the values "
        "in addition to the correction to the attention matrix")
    parser.add_argument('--windowed_ra',
                        type=str2bool,
                        default=False,
                        help="whether prevent attention beyond rpr_k")
    parser.add_argument("--strategy",
                        help="Training strategy, defaults to `mirror`",
                        choices=["mirror"])
    parser.add_argument("--npz",
                        help="Should we write out NPZ files?",
                        type=str2bool,
                        default=False)
    parser.add_argument("--tb",
                        help="Turn on tensorboard?",
                        type=str2bool,
                        default=False)
    parser.add_argument(
        "--convert_only",
        help="Should we just convert this file to NPZ and exit?",
        type=str2bool,
        default=False)
    parser.add_argument("--extra_tokens",
                        help="What extra tokens should we use",
                        nargs="+",
                        default=["[CLS]", "[MASK]"])
    args = parser.parse_args()
    SET_TRAIN_FLAG(True)

    if args.convert_only:
        args.restart = True

    if args.basedir is None:
        args.basedir = f'lm-{args.dataset_key}-bpe-{os.getpid()}'
    logging.basicConfig(level=logging.INFO)
    logger.info(f"Writing results to {args.basedir}")

    if args.tb:
        logdir = f"{args.basedir}/scalars/{os.getpid()}"
        file_writer = tf.summary.create_file_writer(logdir + "/metrics")
        file_writer.set_as_default()
        logger.info(f"Set up tensorboard logdir {logdir}")

    strategy = create_distribute_strategy(args.distribute, args.tpu_ep)
    num_replicas = strategy.num_replicas_in_sync
    logger.info(f"Using {num_replicas} replicas in this job.")
    Vec1D = BPEVectorizer1D if args.subword_type == 'bpe' else WordpieceVectorizer1D
    vectorizer = Vec1D(model_file=args.subword_model_file,
                       vocab_file=args.subword_vocab_file,
                       mxlen=args.nctx,
                       extra_tokens=args.extra_tokens)

    vocab = {'x': vectorizer.vocab}
    preproc_data = baseline.embeddings.load_embeddings(
        'x',
        dsz=args.d_model,
        known_vocab=vocab['x'],
        preserve_vocab_indices=True,
        embed_type=args.embed_type)
    vocabs = preproc_data['vocab']

    train_md = args.train_md if args.train_md else os.path.join(
        args.train_dir, 'md.yml')
    num_train_samples = get_num_samples(train_md)
    valid_md = args.valid_md if args.valid_md else os.path.join(
        args.valid_dir, 'md.yml')
    num_valid_samples = get_num_samples(valid_md)

    is_curriculum = isinstance(num_train_samples, Mapping)

    def dataset_train_fn(input_context):
        global_batchsz = args.batch_size
        base_batchsz = input_context.get_per_replica_batch_size(global_batchsz)
        ds = None
        if is_curriculum:
            for sub in num_train_samples.keys():
                train_curr_dir = os.path.join(args.train_dir, str(sub))
                batchsz_scale_factor = args.nctx // sub
                this_batchsz = base_batchsz * batchsz_scale_factor
                curr_ds = get_dataset(train_curr_dir,
                                      args.file_type,
                                      args.num_train_workers,
                                      causal=args.causal).batch(
                                          this_batchsz, drop_remainder=True)
                if ds is None:
                    ds = curr_ds
                else:
                    ds = ds.concatenate(curr_ds)
        else:
            ds = get_dataset(args.train_dir,
                             args.file_type,
                             args.num_train_workers,
                             causal=args.causal).batch(base_batchsz)
        return ds.shard(input_context.num_input_pipelines,
                        input_context.input_pipeline_id)

    train_loader = strategy.experimental_distribute_datasets_from_function(
        dataset_train_fn)

    def dataset_test_fn(input_context):
        global_batchsz = args.batch_size
        base_batchsz = input_context.get_per_replica_batch_size(global_batchsz)
        ds = None
        if is_curriculum:
            for sub in num_valid_samples.keys():
                valid_curr_dir = os.path.join(args.valid_dir, str(sub))
                batchsz_scale_factor = args.nctx // sub
                this_batchsz = base_batchsz * batchsz_scale_factor
                curr_ds = get_dataset(valid_curr_dir,
                                      args.file_type,
                                      args.num_train_workers,
                                      causal=args.causal).batch(
                                          this_batchsz, drop_remainder=True)
                if ds is None:
                    ds = curr_ds
                else:
                    ds = ds.concatenate(curr_ds)
        else:
            ds = get_dataset(args.valid_dir,
                             args.file_type,
                             args.num_train_workers,
                             shuffle=False,
                             causal=args.causal).batch(base_batchsz)

        return ds.shard(input_context.num_input_pipelines,
                        input_context.input_pipeline_id)

    valid_loader = strategy.experimental_distribute_datasets_from_function(
        dataset_test_fn)

    os.makedirs(args.basedir, exist_ok=True)
    # We want to make sure to save our input vocab into the basedir for reuse later
    write_json(vocabs, os.path.join(args.basedir, 'vocabs.json'))
    embeddings = {'x': preproc_data['embeddings']}
    logger.info("Loaded embeddings")

    logger.info("Loaded datasets")
    logger.info("Using embedding type [%s]", args.embed_type)
    model = create_model(args, embeddings)
    if isinstance(model, GatedMLPLanguageModel) and is_curriculum:
        raise Exception(
            "Variable tensor lengths not currently supported for gMLP")
    logger.info("Loaded model and loss")

    if is_curriculum:
        steps_per_epoch = 0
        steps_per_valid_epoch = 0
        for k, v in num_train_samples.items():
            steps_per_epoch += int(num_train_samples[k] // (args.batch_size *
                                                            (args.nctx / k)))
        for k, v in num_valid_samples.items():
            steps_per_valid_epoch += int(num_valid_samples[k] //
                                         (args.batch_size * (args.nctx / k)))

    else:
        steps_per_epoch = num_train_samples // args.batch_size
        steps_per_valid_epoch = num_valid_samples // args.batch_size
    update_on = steps_per_epoch // args.saves_per_epoch
    report_on = max(10, update_on) // 10
    logger.info(
        f"Steps per epoch: {steps_per_epoch}. Saving checkpoint every {update_on} steps."
    )

    lr_decay = CosineDecaySchedulerTensorFlow(steps_per_epoch * args.epochs,
                                              lr=args.lr)
    linear_warmup = WarmupLinearSchedulerTensorFlow(args.warmup_steps,
                                                    lr=args.lr)
    lr_sched = CompositeLRSchedulerTensorFlow(linear_warmup, lr_decay)
    optimizer = EagerOptimizer(loss_function,
                               optim=args.optim,
                               lr_function=lr_sched,
                               weight_decay=args.weight_decay,
                               clip=args.clip,
                               lr=args.lr)
    checkpoint = tf.train.Checkpoint(optimizer=optimizer.optimizer,
                                     model=model)
    checkpoint_manager = tf.train.CheckpointManager(checkpoint,
                                                    directory=args.basedir,
                                                    max_to_keep=5)

    start_epoch = 0
    if args.restart:
        # The global step gets automatically updated here
        # so we dont have to worry about our LR regimen
        checkpoint.restore(checkpoint_manager.latest_checkpoint)
        current_step = optimizer.global_step
        start_epoch = current_step // steps_per_epoch

    def _replicated_train_step(inputs):
        """This runs on a single replica"""
        x, y = inputs
        per_replica_loss = optimizer.update(model, {'x': x}, y, num_replicas)
        return per_replica_loss

    @tf.function
    def _distributed_train_step(inputs: Tuple[tf.Tensor, tf.Tensor]):
        """Runs across multiple replicas and aggregates the results.

        :param inputs:
        :return:
        """
        per_replica_loss = strategy.run(_replicated_train_step,
                                        args=(inputs, ))
        return strategy.reduce(tf.distribute.ReduceOp.SUM,
                               per_replica_loss,
                               axis=None)

    def _replicated_test_step(inputs):
        """This runs on a single replica"""
        x, y = inputs
        per_replica_loss = loss_function(model, {'x': x}, y) / num_replicas
        return per_replica_loss

    @tf.function
    def _distributed_test_step(inputs: Tuple[tf.Tensor, tf.Tensor]):
        """Runs across multiple replicas and aggregates the results.

        :param inputs:
        :return:
        """
        per_replica_loss = strategy.run(_replicated_test_step, args=(inputs, ))
        return strategy.reduce(tf.distribute.ReduceOp.SUM,
                               per_replica_loss,
                               axis=None)

    timer = Timer()
    with strategy.scope():

        for epoch in range(start_epoch, args.epochs):
            timer.start()
            SET_TRAIN_FLAG(True)
            logger.info('Starting epoch %d', epoch + 1)
            avg_loss = Average('average_train_loss')
            metrics = {}
            train_iter = iter(train_loader)
            for i in range(steps_per_epoch):

                try:
                    loss = _distributed_train_step(next(train_iter))
                    avg_loss.update(loss.numpy().item())
                    tf.summary.scalar("train_loss",
                                      data=loss,
                                      step=optimizer.global_step)
                except Exception as e:
                    logger.error(
                        f"Exception at training step {i+1}/{steps_per_epoch}. Skipping"
                    )
                    pass
                if args.convert_only:
                    logger.warning(
                        "Convert only flag specified.  Stopping after one step"
                    )
                    steps = optimizer.global_step.numpy()
                    npz_checkpoint = os.path.join(
                        args.basedir, f'checkpoint-step-{steps}.npz')
                    save_tlm_npz(model, npz_checkpoint)
                    return

                steps = optimizer.global_step.numpy()
                if (steps + 1) % report_on == 0:
                    logger.info(avg_loss)
                if (steps + 1) % update_on == 0:
                    elapsed = timer.elapsed(True)
                    logger.info('elapsed time this epoch %d min', elapsed)
                    logger.info('elapsed step time %f steps/min', i / elapsed)
                    checkpoint_manager.save()
                    if args.npz:

                        npz_checkpoint = os.path.join(
                            args.basedir, f'checkpoint-step-{steps}.npz')
                        save_tlm_npz(model, npz_checkpoint)

            # How much time elapsed in minutes
            train_token_loss = avg_loss.avg
            # This is the average training token-level loss across all machines
            # This is the token-level training perplexity
            train_token_ppl = math.exp(train_token_loss)
            metrics['train_elapsed_min'] = timer.elapsed(True)
            metrics['average_train_loss'] = train_token_loss
            metrics['train_ppl'] = train_token_ppl
            metrics['lr'] = float(
                lr_sched(tf.cast(optimizer.global_step,
                                 tf.float32)).numpy().item())

            avg_valid_loss = Average('average_valid_loss')
            timer.start()
            SET_TRAIN_FLAG(False)
            valid_iter = iter(valid_loader)
            for i in range(steps_per_valid_epoch):
                try:
                    valid_loss = _distributed_test_step(next(valid_iter))
                    tf.summary.scalar('valid_loss',
                                      data=valid_loss,
                                      step=optimizer.global_step)
                    avg_valid_loss.update(valid_loss.numpy().item())
                except Exception as e:
                    logger.error(
                        f"Exception at validation step {i+1}/{steps_per_valid_epoch}. Skipping"
                    )
                    pass

            valid_token_loss = avg_valid_loss.avg
            valid_token_ppl = math.exp(valid_token_loss)
            metrics['valid_elapsed_min'] = timer.elapsed(True)
            metrics['average_valid_loss'] = valid_token_loss
            metrics['average_valid_word_ppl'] = valid_token_ppl
            logger.info(json.dumps(metrics, indent=4))
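In the curriculum branch above, shorter sub-lengths get a proportionally larger batch (scaled by `nctx // sub`), so the steps-per-epoch arithmetic divides each bucket's sample count by that scaled batch size. A small worked example with illustrative numbers:

batch_size, nctx = 256, 256
num_train_samples = {64: 1_000_000, 128: 500_000, 256: 250_000}   # samples per sub-length

steps_per_epoch = 0
for sub_len, n in num_train_samples.items():
    # shorter sequences get a larger effective batch, hence fewer steps
    steps_per_epoch += int(n // (batch_size * (nctx / sub_len)))
print(steps_per_epoch)    # 976 + 976 + 976 = 2928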
Example #14
    def train(self, ts, reporting_fns, steps=0):
        """Train by looping over the steps

        For a `tf.dataset`-backed `fit_func`, we are using the previously wired `dataset`s
        in the model (and `dataset` is `True`).  For `feed_dict`, we convert the ts samples
        to `feed_dict`s and hand them in one-by-one

        :param ts: The training set
        :param reporting_fns: A list of reporting hooks
        :param dataset: (`bool`) Are we using `tf.dataset`s
        :return: Metrics
        """
        strategy = self.strategy

        #num_replicas = strategy.num_replicas_in_sync
        def _replicated_train_step(inputs):
            features, y = inputs
            per_replica_loss = self.optimizer.update(self.model, features, y)
            per_replica_toks = self._num_toks(features['tgt_len'])
            per_replica_report_loss = per_replica_loss * tf.cast(
                per_replica_toks, tf.float32)
            return per_replica_report_loss, per_replica_toks

        with strategy.scope():
            SET_TRAIN_FLAG(True)
            epoch_loss = tf.Variable(0.0)
            epoch_div = tf.Variable(0, dtype=tf.int32)
            nstep_loss = tf.Variable(0.0)
            nstep_div = tf.Variable(0, dtype=tf.int32)
            self.nstep_start = time.time()
            start = time.time()

            @tf.function
            def _distributed_train_step(inputs):
                per_replica_loss, per_replica_toks = strategy.experimental_run_v2(
                    _replicated_train_step, args=(inputs, ))
                total_step_loss = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                                  per_replica_loss,
                                                  axis=None)
                total_toks = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                             per_replica_toks,
                                             axis=None)
                return total_step_loss, total_toks

            with autograph_options({
                    "function_optimization": False,
                    "layout_optimizer": False
            }):
                train_iter = iter(ts)
                for i in range(steps):
                    features, y = next(train_iter)
                    step_report_loss, step_toks = _distributed_train_step(
                        (features, y))
                    epoch_loss.assign_add(step_report_loss)
                    nstep_loss.assign_add(step_report_loss)
                    epoch_div.assign_add(step_toks)
                    nstep_div.assign_add(step_toks)

                    step = self.optimizer.global_step.numpy().item() + 1
                    if step % self.nsteps == 0:
                        metrics = self.calc_metrics(nstep_loss.numpy().item(),
                                                    nstep_div.numpy().item())
                        self.report(step, metrics, self.nstep_start, 'Train',
                                    'STEP', reporting_fns, self.nsteps)
                        nstep_loss.assign(0.0)
                        nstep_div.assign(0)
                        self.nstep_start = time.time()

                epoch_loss = epoch_loss.numpy()
                epoch_div = epoch_div.numpy()
                metrics = self.calc_metrics(epoch_loss, epoch_div)
                self.train_epochs += 1
                self.report(self.train_epochs, metrics, start, 'Train',
                            'EPOCH', reporting_fns)
                return metrics
Example #15
    def test(self, vs, reporting_fns, steps=0, phase='Valid', **kwargs):
        """Run an epoch of testing over the dataset

        If we are using a `tf.dataset`-based `fit_func`, we will just
        cycle the number of steps and let the `dataset` yield new batches.

        If we are using `feed_dict`s, we convert each batch from the `DataFeed`
        and pass that into TF as the `feed_dict`

        :param vs: A validation set
        :param reporting_fns: Reporting hooks
        :param phase: The phase of evaluation (`Test`, `Valid`)
        :param dataset: (`bool`) Are we using `tf.dataset`s
        :return: Metrics
        """
        def _replicated_valid_step(inputs):
            features, tgt = inputs
            top_preds = self.model.predict(features, beam=1, make_input=False)
            per_replica_loss = loss(self.model, features, tgt)
            per_replica_toks = self._num_toks(features['tgt_len'])
            per_replica_report_loss = per_replica_loss * tf.cast(
                per_replica_toks, tf.float32)
            return per_replica_report_loss, per_replica_toks, top_preds

        if phase == 'Test':
            SET_TRAIN_FLAG(False)
            return self._evaluate(vs, reporting_fns, **kwargs)

        strategy = self.strategy
        num_replicas = strategy.num_replicas_in_sync

        with strategy.scope():

            SET_TRAIN_FLAG(False)
            self.valid_epochs += 1

            total_loss = tf.Variable(0.0)
            total_toks = tf.Variable(0, dtype=tf.int32)
            preds = []
            golds = []

            start = time.time()

            test_iter = iter(vs)

            for i in range(steps):
                features, tgt = next(test_iter)
                inputs = (features, tgt)
                per_replica_loss, per_replica_toks, _ = strategy.experimental_run_v2(
                    _replicated_valid_step, args=(inputs, ))
                total_loss.assign_add(
                    strategy.reduce(tf.distribute.ReduceOp.SUM,
                                    per_replica_loss,
                                    axis=None))
                total_toks.assign_add(
                    strategy.reduce(tf.distribute.ReduceOp.SUM,
                                    per_replica_toks,
                                    axis=None))
                # Not sure a good way to get top preds merged yet

            metrics = self.calc_metrics(total_loss.numpy(), total_toks.numpy())
            self.report(self.valid_epochs, metrics, start, phase, 'EPOCH',
                        reporting_fns)
            return metrics
Example #16
def train():
    parser = ArgumentParser()
    parser.add_argument("--basedir", type=str)
    parser.add_argument("--train_dir",
                        type=str,
                        required=True,
                        help='Training directory')
    parser.add_argument("--valid_dir",
                        type=str,
                        required=True,
                        help='Validation directory')
    parser.add_argument(
        "--train_md",
        type=str,
        help="Training metadata YAML, defaults to `{train_dir}/md.yml`")
    parser.add_argument(
        "--valid_md",
        type=str,
        help="Validation metadata YAML, defaults to `{valid_dir}/md.yml`")
    parser.add_argument("--dataset_key",
                        default="tlm",
                        help="dataset key for basedir")
    parser.add_argument(
        "--embed_type",
        type=str,
        default='default',
        choices=["default", "positional", "learned-positional"],
        help="register label of the embeddings")

    parser.add_argument("--gen_d_model",
                        type=int,
                        default=256,
                        help="Model dimension (and embedding dsz)")
    parser.add_argument("--gen_d_ff",
                        type=int,
                        default=1024,
                        help="FFN dimension")
    parser.add_argument(
        "--gen_d_k",
        type=int,
        default=None,
        help="Dimension per head.  Use if num_heads=1 to reduce dims")
    parser.add_argument("--gen_num_heads",
                        type=int,
                        default=8,
                        help="Number of heads")
    parser.add_argument("--gen_num_layers",
                        type=int,
                        default=8,
                        help="Number of layers")
    parser.add_argument(
        '--gen_rpr_k',
        help="Relative attention positional sizes; pass 0 if you don't want relative attention",
        type=int,
        default=[8],
        nargs='+')
    parser.add_argument('--windowed_ra',
                        type=str2bool,
                        default=False,
                        help="whether prevent attention beyond rpr_k")
    parser.add_argument("--gen_loss_scale",
                        type=float,
                        default=50.0,
                        help="Scaling for loss function")
    parser.add_argument("--gen_dropout",
                        type=float,
                        default=0.1,
                        help="Dropout")

    parser.add_argument(
        '--discrim_rpr_k',
        help='Relative attention positional sizes; pass 0 if you do not want relative attention',
        type=int,
        default=[8],
        nargs='+')

    parser.add_argument("--discrim_d_model",
                        type=int,
                        default=512,
                        help="Model dimension (and embedding dsz)")
    parser.add_argument("--discrim_d_ff",
                        type=int,
                        default=2048,
                        help="FFN dimension")
    parser.add_argument(
        "--discrim_d_k",
        type=int,
        default=None,
        help="Dimension per head.  Use if num_heads=1 to reduce dims")
    parser.add_argument("--discrim_num_heads",
                        type=int,
                        default=8,
                        help="Number of heads")
    parser.add_argument("--discrim_num_layers",
                        type=int,
                        default=8,
                        help="Number of layers")
    parser.add_argument("--discrim_dropout",
                        type=float,
                        default=0.1,
                        help="Dropout")

    parser.add_argument("--num_train_workers",
                        type=int,
                        default=4,
                        help="Number train workers")
    parser.add_argument("--distribute",
                        type=str,
                        default="mirror",
                        choices=["mirror", "tpu", "nccl"])
    parser.add_argument("--tpu_ep",
                        type=str,
                        help="The TPU endpoint if using `distribute=tpu`")
    parser.add_argument("--nctx",
                        type=int,
                        default=256,
                        help="Max input length")
    parser.add_argument("--file_type",
                        default='tfrecord',
                        choices=['json', 'tfrecord'],
                        help="Glob pattern for data")
    parser.add_argument("--batch_size",
                        type=int,
                        default=256,
                        help="Batch Size")
    parser.add_argument("--subword_model_file",
                        type=str,
                        help="The BPE model file",
                        required=True)
    parser.add_argument("--subword_vocab_file",
                        type=str,
                        help="The BPE subword vocab",
                        required=True)
    parser.add_argument("--optim",
                        default="adam",
                        type=str,
                        help="Optimizer to use (defaults to adam)")
    parser.add_argument("--lr",
                        type=float,
                        default=4.0e-4,
                        help="Learning rate")
    parser.add_argument("--clip",
                        type=float,
                        default=1.0,
                        help="Clipping gradient norm")
    parser.add_argument("--weight_decay",
                        type=float,
                        default=1.0e-2,
                        help="Weight decay")
    parser.add_argument("--epochs",
                        type=int,
                        default=32,
                        help="Num training epochs")
    parser.add_argument(
        "--restart",
        type=str2bool,
        help="Option allows you to restart from a previous checkpoint")
    parser.add_argument("--warmup_steps",
                        type=int,
                        default=10000,
                        help="Num warmup steps")
    parser.add_argument("--causal",
                        type=str2bool,
                        default=False,
                        help="Use CLM (causal) instead of MLM")
    parser.add_argument("--saves_per_epoch",
                        type=int,
                        default=10,
                        help="The number of checkpoints to save per epoch")
    parser.add_argument("--strategy",
                        help="Training strategy, defaults to `mirror`",
                        choices=["mirror"])
    parser.add_argument("--npz",
                        help="Should we write out NPZ files?",
                        type=str2bool,
                        default=False)
    parser.add_argument("--tb",
                        help="Turn on tensorboard?",
                        type=str2bool,
                        default=False)
    parser.add_argument(
        "--convert_only",
        help="Should we just convert this file to NPZ and exit?",
        type=str2bool,
        default=False)
    args = parser.parse_args()
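    # A hypothetical invocation (script name and paths are placeholders, not
    # taken from the original example):
    #
    #   python pretrain_discrim.py \
    #       --train_dir /data/tlm/train --valid_dir /data/tlm/valid \
    #       --subword_model_file /data/bpe/codes.30k \
    #       --subword_vocab_file /data/bpe/vocab.30k \
    #       --distribute mirror --batch_size 256 --epochs 32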
    SET_TRAIN_FLAG(True)

    if args.convert_only:
        args.restart = True
        args.npz = True

    if args.basedir is None:
        args.basedir = f'discrim-{args.dataset_key}-bpe-{os.getpid()}'
    logging.basicConfig(level=logging.INFO)
    logger.info(f"Writing results to {args.basedir}")

    if args.tb:
        logdir = f"logs/scalars/{os.getpid()}"
        file_writer = tf.summary.create_file_writer(logdir + "/metrics")
        file_writer.set_as_default()
        logger.info(f"Set up tensorboard logdir {logdir}")

    strategy = create_distribute_strategy(args.distribute, args.tpu_ep)
    num_replicas = strategy.num_replicas_in_sync
    logger.info(f"Using {num_replicas} replicas in this job.")
    vectorizer = BPEVectorizer1D(model_file=args.subword_model_file,
                                 vocab_file=args.subword_vocab_file,
                                 mxlen=args.nctx)
    vocab = {'x': vectorizer.vocab}
    gen_preproc_data = baseline.embeddings.load_embeddings(
        'x',
        dsz=args.gen_d_model,
        known_vocab=vocab['x'],
        preserve_vocab_indices=True,
        embed_type=args.embed_type)

    vocabs = gen_preproc_data['vocab']

    discrim_preproc_data = baseline.embeddings.load_embeddings(
        'x',
        dsz=args.discrim_d_model,
        known_vocab=vocab['x'],
        preserve_vocab_indices=True,
        embed_type=args.embed_type)

    def dataset_train_fn(input_context):
        batch_size = input_context.get_per_replica_batch_size(args.batch_size)
        ds = get_dataset(args.train_dir, args.file_type,
                         args.num_train_workers).batch(batch_size)
        return ds.shard(input_context.num_input_pipelines,
                        input_context.input_pipeline_id)

    train_loader = strategy.experimental_distribute_datasets_from_function(
        dataset_train_fn)

    def dataset_test_fn(input_context):
        batch_size = input_context.get_per_replica_batch_size(args.batch_size)
        ds = get_dataset(args.valid_dir,
                         args.file_type,
                         args.num_train_workers,
                         shuffle=False).batch(batch_size)
        return ds.shard(input_context.num_input_pipelines,
                        input_context.input_pipeline_id)

    valid_loader = strategy.experimental_distribute_datasets_from_function(
        dataset_test_fn)
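    # Both loaders are built per input pipeline: get_per_replica_batch_size
    # splits the global batch size across replicas, and shard() by
    # input_pipeline_id keeps each pipeline reading a disjoint slice of the data.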

    train_md = args.train_md if args.train_md else os.path.join(
        args.train_dir, 'md.yml')
    num_train_samples = get_num_samples(train_md)
    valid_md = args.valid_md if args.valid_md else os.path.join(
        args.valid_dir, 'md.yml')
    num_valid_samples = get_num_samples(valid_md)
    os.makedirs(args.basedir, exist_ok=True)
    # We want to make sure to save our input vocab into the basedir for reuse later
    write_json(vocabs, os.path.join(args.basedir, 'vocabs.json'))
    gen_embeddings = {'x': gen_preproc_data['embeddings']}
    discrim_embeddings = {'x': discrim_preproc_data['embeddings']}
    logger.info("Loaded embeddings")

    logger.info("Loaded datasets")
    logger.info("Using embedding type [%s]", args.embed_type)
    if len(args.gen_rpr_k) == 0 or args.gen_rpr_k[0] < 1:
        gen_rpr_k = None
    elif len(args.gen_rpr_k) == 1:
        gen_rpr_k = args.gen_rpr_k[0]
    else:
        gen_rpr_k = args.gen_rpr_k

    if len(args.discrim_rpr_k) == 0 or args.discrim_rpr_k[0] < 1:
        discrim_rpr_k = None
    elif len(args.discrim_rpr_k) == 1:
        discrim_rpr_k = args.discrim_rpr_k[0]
    else:
        discrim_rpr_k = args.discrim_rpr_k
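    # For illustration (hypothetical flag values): `--gen_rpr_k 0` disables
    # relative attention (rpr_k=None), `--gen_rpr_k 8` uses a single window
    # size (rpr_k=8), and `--gen_rpr_k 4 8` passes the list through unchanged
    # (rpr_k=[4, 8]); the discriminator flags behave the same way.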

    gen_model = TransformerMaskedLanguageModel.create(
        gen_embeddings,
        hsz=args.gen_d_model,
        d_ff=args.gen_d_ff,
        tie_weights=True,
        dropout=args.gen_dropout,
        gpu=False,
        num_heads=args.gen_num_heads,
        layers=args.gen_num_layers,
        rpr_k=gen_rpr_k,
        d_k=args.gen_d_k,
        windowed_ra=args.windowed_ra,
        src_keys=['x'],
        tgt_key='x')

    discrim_model = TransformerDiscriminator(discrim_embeddings,
                                             d_model=args.discrim_d_model,
                                             d_ff=args.discrim_d_ff,
                                             dropout=args.discrim_dropout,
                                             num_heads=args.discrim_num_heads,
                                             layers=args.discrim_num_layers,
                                             rpr_k=discrim_rpr_k,
                                             d_k=args.discrim_d_k)

    logger.info("Loaded model and loss")
    steps_per_epoch = num_train_samples // args.batch_size
    steps_per_valid_epoch = num_valid_samples // args.batch_size
    update_on = steps_per_epoch // args.saves_per_epoch
    report_on = max(10, update_on) // 10
    logger.info(
        f"Steps per epoch: {steps_per_epoch}. Saving checkpoint every {update_on} steps."
    )
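    # Worked example with hypothetical sizes: 1,024,000 training samples and a
    # batch size of 256 give steps_per_epoch = 4000; with saves_per_epoch = 10,
    # a checkpoint is written every update_on = 400 steps and running averages
    # are logged every report_on = max(10, 400) // 10 = 40 steps.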

    lr_decay = CosineDecaySchedulerTensorFlow(steps_per_epoch * args.epochs,
                                              lr=args.lr)
    linear_warmup = WarmupLinearSchedulerTensorFlow(args.warmup_steps,
                                                    lr=args.lr)
    lr_sched = CompositeLRSchedulerTensorFlow(linear_warmup, lr_decay)
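    # The composite schedule ramps the learning rate up linearly for
    # `warmup_steps` steps and then follows a cosine decay over the full
    # training horizon (steps_per_epoch * epochs steps).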

    mask_value = vocabs.get("[MASK]", vocabs.get("<MASK>", -1))
    if mask_value == -1:
        logger.error("We could not find a suitable masking token in the vocab")
        return

    optimizer, clip = create_keras_optimizer(**vars(args))

    discrim_checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                             model=discrim_model)
    discrim_checkpoint_manager = tf.train.CheckpointManager(
        discrim_checkpoint,
        directory=os.path.join(args.basedir, 'discrim'),
        max_to_keep=5)

    gen_checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                         model=gen_model)
    gen_checkpoint_manager = tf.train.CheckpointManager(gen_checkpoint,
                                                        directory=os.path.join(
                                                            args.basedir,
                                                            'gen'),
                                                        max_to_keep=5)
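    # Each checkpoint pairs the shared optimizer with one of the two models,
    # and each manager keeps only the 5 most recent checkpoints.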

    if args.restart:
        # The global step gets automatically updated here
        # so we don't have to worry about our LR regimen
        gen_checkpoint.restore(gen_checkpoint_manager.latest_checkpoint)
        discrim_checkpoint.restore(
            discrim_checkpoint_manager.latest_checkpoint)

    def _replicated_train_step(inputs):
        """This runs on a single replica"""
        noised_x, labels = inputs
        with tf.GradientTape() as tape:
            gen_loss_step, discrim_loss_step, acc = gen_vs_discrim(
                noised_x, labels, gen_model, discrim_model, mask_value)
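            # Dividing by num_replicas means the SUM reduction across replicas
            # in _distributed_train_step reports the mean combined loss.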
            loss_value = (args.gen_loss_scale * gen_loss_step +
                          discrim_loss_step) / num_replicas

        grads = tape.gradient(
            loss_value,
            gen_model.trainable_variables + discrim_model.trainable_variables)
        grads, _ = tf.clip_by_global_norm(grads, clip)
        optimizer.apply_gradients(
            zip(
                grads, gen_model.trainable_variables +
                discrim_model.trainable_variables))

        return loss_value, gen_loss_step, discrim_loss_step, acc

    @tf.function
    def _distributed_train_step(inputs: Tuple[tf.Tensor, tf.Tensor]):
        """Runs across multiple replicas and aggregates the results.

        :param inputs:
        :return:
        """
        loss, gen_loss, discrim_loss, acc = strategy.run(
            _replicated_train_step, args=(inputs, ))
        sum_loss = strategy.reduce(tf.distribute.ReduceOp.SUM, loss, axis=None)
        sum_gen_loss = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                       gen_loss,
                                       axis=None)
        sum_discrim_loss = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                           discrim_loss,
                                           axis=None)
        sum_acc = strategy.reduce(tf.distribute.ReduceOp.SUM, acc, axis=None)
        return sum_loss, sum_gen_loss, sum_discrim_loss, sum_acc

    def _replicated_test_step(inputs):
        """This runs on a single replica"""
        noised_x, labels = inputs
        gen_loss_step, discrim_loss_step, acc = gen_vs_discrim(
            noised_x, labels, gen_model, discrim_model, mask_value)
        loss_value = (args.gen_loss_scale * gen_loss_step +
                      discrim_loss_step) / num_replicas
        return loss_value, gen_loss_step, discrim_loss_step, acc

    @tf.function
    def _distributed_test_step(inputs: Tuple[tf.Tensor, tf.Tensor]):
        """Runs across multiple replicas and aggregates the results.

        :param inputs:
        :return:
        """
        loss, gen_loss, discrim_loss, acc = strategy.run(_replicated_test_step,
                                                         args=(inputs, ))
        sum_loss = strategy.reduce(tf.distribute.ReduceOp.SUM, loss, axis=None)
        sum_gen_loss = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                       gen_loss,
                                       axis=None)
        sum_discrim_loss = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                           discrim_loss,
                                           axis=None)
        sum_acc = strategy.reduce(tf.distribute.ReduceOp.SUM, acc, axis=None)
        return sum_loss, sum_gen_loss, sum_discrim_loss, sum_acc

    # This is the training loop
    start_epoch = 0
    timer = Timer()
    with strategy.scope():

        for epoch in range(start_epoch, args.epochs):
            SET_TRAIN_FLAG(True)
            logger.info('Starting epoch %d', epoch + 1)
            avg_loss = Average('average_train_loss')
            avg_gen_loss = Average('average_gen_loss')
            avg_discrim_loss = Average('average_discrim_loss')
            avg_acc = Average('average_train_acc')

            metrics = {}
            timer.start()
            train_iter = iter(train_loader)
            for i in range(steps_per_epoch):
                loss, gen_loss, discrim_loss, acc = _distributed_train_step(
                    next(train_iter))
                avg_loss.update(loss.numpy().item())
                avg_gen_loss.update(gen_loss.numpy().item())
                avg_discrim_loss.update(discrim_loss.numpy().item())
                avg_acc.update(acc.numpy().item())

                tf.summary.scalar("train_loss",
                                  data=loss,
                                  step=optimizer.iterations)
                tf.summary.scalar("train_gen_loss",
                                  data=gen_loss,
                                  step=optimizer.iterations)
                tf.summary.scalar("train_discrim_loss",
                                  data=discrim_loss,
                                  step=optimizer.iterations)
                tf.summary.scalar("train_acc",
                                  data=acc,
                                  step=optimizer.iterations)

                if args.convert_only:
                    logger.warning(
                        "Convert only flag specified.  Stopping after one step"
                    )
                    steps = optimizer.iterations.numpy()
                    npz_checkpoint = os.path.join(args.basedir,
                                                  f'discrim-step-{steps}.npz')
                    save_tlm_npz(discrim_model, npz_checkpoint)
                    npz_checkpoint = os.path.join(args.basedir,
                                                  f'gen-step-{steps}.npz')
                    save_tlm_npz(gen_model, npz_checkpoint)
                    return

                if (i + 1) % report_on == 0:
                    logging.info(avg_loss)
                    logging.info(avg_gen_loss)
                    logging.info(avg_discrim_loss)
                    logging.info(avg_acc)
                if (i + 1) % update_on == 0:
                    elapsed = timer.elapsed(True)
                    logging.info('elapsed time this epoch %d min', elapsed)
                    logging.info('training rate %f steps/min', i / elapsed)
                    gen_checkpoint_manager.save()
                    discrim_checkpoint_manager.save()

                    if args.npz:
                        steps = optimizer.iterations.numpy()
                        npz_checkpoint = os.path.join(
                            args.basedir, f'discrim-step-{steps}.npz')
                        save_tlm_npz(discrim_model, npz_checkpoint)
                        npz_checkpoint = os.path.join(args.basedir,
                                                      f'gen-step-{steps}.npz')
                        save_tlm_npz(gen_model, npz_checkpoint)

            # Average training losses and accuracy for this epoch, aggregated
            # across all replicas
            metrics['train_elapsed_min'] = timer.elapsed(True)
            metrics['average_train_loss'] = avg_loss.avg
            metrics['average_gen_loss'] = avg_gen_loss.avg
            metrics['average_discrim_loss'] = avg_discrim_loss.avg
            metrics['average_train_acc'] = avg_acc.avg
            metrics['lr'] = float(
                lr_sched(tf.cast(optimizer.iterations,
                                 tf.float32)).numpy().item())

            avg_valid_loss = Average('average_valid_loss')
            avg_valid_gen_loss = Average('average_valid_gen_loss')
            avg_valid_discrim_loss = Average('average_valid_discrim_loss')
            avg_valid_acc = Average('average_valid_acc')

            timer.start()
            SET_TRAIN_FLAG(False)
            valid_iter = iter(valid_loader)
            for i in range(steps_per_valid_epoch):
                valid_loss, valid_gen_loss, valid_discrim_loss, valid_acc = _distributed_test_step(
                    next(valid_iter))
                tf.summary.scalar('valid_loss',
                                  data=valid_loss,
                                  step=optimizer.iterations)
                tf.summary.scalar('valid_gen_loss',
                                  data=valid_gen_loss,
                                  step=optimizer.iterations)
                tf.summary.scalar('valid_discrim_loss',
                                  data=valid_discrim_loss,
                                  step=optimizer.iterations)
                tf.summary.scalar('valid_acc',
                                  data=valid_acc,
                                  step=optimizer.iterations)
                avg_valid_loss.update(valid_loss.numpy().item())
                avg_valid_gen_loss.update(valid_gen_loss.numpy().item())
                avg_valid_discrim_loss.update(
                    valid_discrim_loss.numpy().item())
                avg_valid_acc.update(valid_acc.numpy().item())

            metrics['valid_elapsed_min'] = timer.elapsed(True)
            metrics['average_valid_loss'] = avg_valid_loss.avg
            metrics['average_valid_gen_loss'] = avg_valid_gen_loss.avg
            metrics['average_valid_discrim_loss'] = avg_valid_discrim_loss.avg
            metrics['average_valid_acc'] = avg_valid_acc.avg
            logger.info(json.dumps(metrics, indent=4))
Exemple #17
    def train(self, ts, reporting_fns):
        """Train by looping over the steps

        For a `tf.dataset`-backed `fit_func`, we are using the previously wired `dataset`s
        in the model (and `dataset` is `True`).  For `feed_dict`, we convert the ts samples
        to `feed_dict`s and hand them in one-by-one

        :param ts: The training set
        :param reporting_fns: A list of reporting hooks
        :param dataset: (`bool`) Are we using `tf.dataset`s
        :return: Metrics
        """

        SET_TRAIN_FLAG(True)
        epoch_loss = tf.Variable(0.0)
        epoch_div = tf.Variable(0, dtype=tf.int32)
        nstep_loss = tf.Variable(0.0)
        nstep_div = tf.Variable(0, dtype=tf.int32)
        self.nstep_start = time.perf_counter()
        start = time.perf_counter()

        def _train_step_no_state(inputs):
            """Replicated training step."""

            features, y = inputs
            loss = self.optimizer.update(self.model, features, y)
            toks = self._num_toks(y)
            report_loss = loss * tf.cast(toks, tf.float32)
            return report_loss, toks

        def _train_step_with_state(inputs, hidden):
            """Replicated training step."""

            features, y = inputs
            loss, hidden = self.optimizer.update_with_hidden(
                self.model, hidden, features, y)
            toks = self._num_toks(y)
            report_loss = loss * tf.cast(toks, tf.float32)
            return hidden, report_loss, toks

        if get_version(tf) >= 2:
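            # Wrap the step functions in tf.function so they run as compiled
            # graphs on TF 2.x.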
            _train_step_with_state = tf.function(_train_step_with_state)
            _train_step_no_state = tf.function(_train_step_no_state)

        h = None
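        # Hidden state starts empty and is threaded through successive steps
        # for models that require it.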
        for inputs in ts:
            if self.model.requires_state:
                h, step_report_loss, step_toks = _train_step_with_state(
                    inputs, h)
            else:
                step_report_loss, step_toks = _train_step_no_state(inputs)

            epoch_loss.assign_add(step_report_loss)
            nstep_loss.assign_add(step_report_loss)
            epoch_div.assign_add(step_toks)
            nstep_div.assign_add(step_toks)

            step = self.optimizer.global_step.numpy() + 1
            if step % self.nsteps == 0:
                metrics = self.calc_metrics(nstep_loss.numpy(),
                                            nstep_div.numpy())
                self.report(step, metrics, self.nstep_start, 'Train', 'STEP',
                            reporting_fns, self.nsteps)
                nstep_loss.assign(0.0)
                nstep_div.assign(0)
                self.nstep_start = time.perf_counter()

        epoch_loss = epoch_loss.numpy()
        epoch_div = epoch_div.numpy()
        metrics = self.calc_metrics(epoch_loss, epoch_div)
        self.train_epochs += 1
        self.report(self.train_epochs, metrics, start, 'Train', 'EPOCH',
                    reporting_fns)
        return metrics