def maybe_save(self, epoch, valid_loss=None):
    """Checkpoint-saving entry point.

    Wraps `_save` with the `keep_checkpoint` bookkeeping: skips epochs
    that are not on the save schedule, tracks the best checkpoint seen
    so far by validation loss, and evicts the oldest checkpoint once
    the rolling queue is full.
    """
    if self.keep_checkpoint == 0:
        return
    if epoch % self.save_checkpoint_epochs != 0:
        return

    chkpt, chkpt_name = self._save(epoch)

    # New best validation loss: copy this checkpoint under a "best" name
    # and drop the previous best copy.
    is_new_best = (valid_loss is not None
                   and valid_loss < self.minimum_valid_loss)
    if is_new_best:
        self.minimum_valid_loss = valid_loss
        if self.best_checkpoint is not None:
            self._rm_checkpoint(self.best_checkpoint)
        self.best_checkpoint = "%s.best_lr_%5.4f_ppl_%5.2f" % (
            chkpt_name, self.optim.learning_rate, valid_loss)
        logger.info("Saving checkpoint %s with lowest ppl" %
                    (self.best_checkpoint))
        copyfile(chkpt_name, self.best_checkpoint)

    # Rolling window of the most recent checkpoints.
    if self.keep_checkpoint > 0:
        if len(self.checkpoint_queue) == self.checkpoint_queue.maxlen:
            self._rm_checkpoint(self.checkpoint_queue.popleft())
        self.checkpoint_queue.append(chkpt_name)
def report_bleu(tgt_path, out_path):
    """Run multi-bleu.perl against `tgt_path`, feeding `out_path` on stdin,
    and log the resulting BLEU report line.

    Fix: the output file is now closed deterministically via a context
    manager (it was previously left open until garbage collection).
    """
    with codecs.open(out_path, 'r', 'utf-8') as output_file:
        # NOTE(review): shell=True with an interpolated path is unsafe if
        # tgt_path can contain shell metacharacters; kept as-is assuming
        # paths come from trusted command-line options.
        res = subprocess.check_output(
            "perl tools/multi-bleu.perl %s" % (tgt_path),
            stdin=output_file,
            shell=True).decode("utf-8")

    msg = res.strip()
    logger.info(msg)
# --- Example 3 (scraped-snippet separator) ---
def _load_fields(dataset, opt, checkpoint):
    """Load torchtext fields from a checkpoint vocab (when resuming) or
    from the preprocessed `<data>.vocab.pt`, keeping only the fields that
    actually appear on the dataset's examples."""
    if checkpoint is None:
        fields = load_fields_from_vocab(torch.load(opt.data + '.vocab.pt'))
    else:
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        fields = load_fields_from_vocab(checkpoint['vocab'])

    # Restrict to the fields present on the first example.
    example_attrs = dataset.examples[0].__dict__
    fields = {name: field for name, field in fields.items()
              if name in example_attrs}

    logger.info(' * vocabulary size == %d' % len(fields['conversation'].vocab))
    return fields
# --- Example 4 (scraped-snippet separator) ---
def build_model(model_opt, fields, gpu, checkpoint):
    """Construct the dialog model (HRED / TDCM / TDACM) and initialize it.

    Args:
        model_opt: architecture options (model type, init scheme,
            embedding sharing, data path prefix).
        fields: torchtext fields; `fields["conversation"].vocab` supplies
            the vocabulary used for embedding initialization.
        gpu (bool): place the model on CUDA when True.
        checkpoint: checkpoint dict to restore weights from, or None to
            initialize from scratch.

    Returns:
        The built model, moved to the selected device.
    """
    logger.info('Building model...')
    if model_opt.model == "HRED":
        model = build_HRED(model_opt, fields)
    elif model_opt.model == "TDCM":
        model = build_TDCM(model_opt, fields)
    elif model_opt.model == "TDACM":
        model = build_TDACM(model_opt, fields)
    else:
        raise NotImplementedError

    device = torch.device("cuda" if gpu else "cpu")

    # Tie encoder and decoder embedding weights when requested.
    if model_opt.share_embeddings:
        model.decoder.embedding.weight = model.encoder.embedding.weight

    # Load the model states from checkpoint or initialize them.
    if checkpoint is not None:
        model.load_state_dict(checkpoint['model'])
    else:
        # Order matters: uniform init runs first; glorot (if enabled) then
        # overwrites every parameter with more than one dimension.
        if model_opt.param_init != 0.0:
            for p in model.parameters():
                p.data.uniform_(-model_opt.param_init, model_opt.param_init)
        if model_opt.param_init_glorot:
            for p in model.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)

        # Topic key/value matrices start orthogonal for topic-aware models.
        if model_opt.model == "TDACM" or model_opt.model == "TDCM":
            nn.init.orthogonal_(model.encoder.topic_keys)
            nn.init.orthogonal_(model.encoder.topic_values)

        # Seed encoder embeddings from word2vec, caching the result as
        # `<data>.embedding.pt` so the slow binary load happens only once.
        vocab = fields["conversation"].vocab
        embedding_pt_path = model_opt.data + ".embedding.pt"
        if not os.path.exists(embedding_pt_path):
            from gensim.models import KeyedVectors
            w2v = KeyedVectors.load_word2vec_format(
                "data/GoogleNews-vectors-negative300.bin", binary=True)
            for token in vocab.itos:
                if token in w2v:
                    model.encoder.embedding.weight.data[
                        vocab.stoi[token]] = torch.from_numpy(w2v[token])
            torch.save(model.encoder.embedding.weight.data, embedding_pt_path)
        else:
            model.encoder.embedding.weight.data.copy_(
                torch.load(embedding_pt_path))

    model.to(device)
    logger.info(model)
    return model
    def _save(self, epoch):
        """Serialize the model, vocab, options and optimizer state to
        `<base_path>_step_<epoch>.pt`.

        Returns:
            (checkpoint dict, checkpoint file path)
        """
        checkpoint = {
            'model': self.model.state_dict(),
            'vocab': mtdg.data.save_fields_to_vocab(self.fields),
            'opt': self.opt,
            'epoch': epoch,
            'optim': self.optim
        }

        checkpoint_path = '%s_step_%d.pt' % (self.base_path, epoch)
        logger.info("Saving checkpoint %s_step_%d.pt" %
                    (self.base_path, epoch))
        torch.save(checkpoint, checkpoint_path)
        return checkpoint, checkpoint_path
# --- Example 6 (scraped-snippet separator) ---
def training_opt_postprocessing(opt):
    """Post-process training options: warn when a CUDA device is present
    but unused, pin the GPU, and seed all RNGs for reproducible runs.

    Returns `opt` unchanged.
    """
    if not opt.gpuid:
        if torch.cuda.is_available():
            logger.info("WARNING: You have a CUDA device, should run with -gpuid")
        return opt

    torch.cuda.set_device(opt.gpuid[0])
    if opt.seed > 0:
        # torchtext's shuffled iterator draws from `random`; seeding keeps
        # dataset order identical across (multi-)GPU runs.
        random.seed(opt.seed)
        # Identical parameter initialization on every GPU.
        torch.manual_seed(opt.seed)
        torch.cuda.manual_seed(opt.seed)
    return opt
    def output(self, epoch, cur_iter, num_iters, learning_rate, start):
        """Log a one-line training progress summary.

        Args:
           epoch (int): current epoch number
           cur_iter (int): current batch index within the epoch
           num_iters (int): total number of batches in the epoch
           learning_rate (float): current optimizer learning rate
           start (float): wall-clock time at which the step started
        """
        elapsed = self.elapsed_time()
        msg = ("Epoch %2d, [%4d/%4d]; acc: %6.2f; ppl: %6.2f; xent: %6.2f; " +
               "lr: %7.5f; %3.0f / %3.0f tok/s; %6.0f sec") % (
                   epoch, cur_iter, num_iters, self.accuracy(), self.ppl(),
                   self.xent(), learning_rate,
                   # +1e-5 avoids division by zero on the first call.
                   self.n_src_words / (elapsed + 1e-5),
                   self.n_words / (elapsed + 1e-5),
                   time.time() - start)
        logger.info(msg)
        sys.stdout.flush()
def distinct_ratios(lines):
    """Compute (Distinct-1, Distinct-2) over whitespace-tokenized lines.

    Distinct-N = (number of unique N-grams) / (total token count).
    Returns (0.0, 0.0) for empty input instead of dividing by zero.
    """
    distinct_1 = set()
    distinct_2 = set()
    total_words = 0
    for line in lines:
        words = line.split()
        distinct_1.update(words)
        # All adjacent bigrams: there are len(words) - 1 of them.
        # (Fixes an off-by-one -- range(len(words) - 2) -- that silently
        # dropped the last bigram of every line.)
        distinct_2.update(
            words[i] + words[i + 1] for i in range(len(words) - 1))
        total_words += len(words)
    if total_words == 0:
        return 0.0, 0.0
    return (len(distinct_1) * 1.0 / total_words,
            len(distinct_2) * 1.0 / total_words)


def report_distinct(tgt_path):
    """Compute and log Distinct-1 / Distinct-2 diversity metrics for the
    responses in `tgt_path`."""
    lines = codecs.open(tgt_path, "r", "utf-8").readlines()
    d1, d2 = distinct_ratios(lines)
    msg = ("Distinct-1 = %.4f, Distinct-2 = %.4f" % (d1, d2))
    logger.info(msg)
# --- Example 9 (scraped-snippet separator) ---
def build_iterator(corpus_type, fields, opt, is_train=True, is_topic=False):
    """Load a serialized dataset ("train"/"valid"/"test") and wrap it in
    an OrderedIterator.

    Topic datasets get their own field set whose text/target vocabs are
    shared with the conversation vocabulary.
    """
    if is_topic:
        pt_file = opt.data + '.' + corpus_type + '.topic'
        dataset = torch.load(pt_file)
        topic_fields = mtdg.topic_dataset.get_fields()
        shared_vocab = fields["conversation"].vocab
        topic_fields["text"].vocab = shared_vocab
        topic_fields["target"].vocab = shared_vocab
        dataset.fields = topic_fields
    else:
        pt_file = opt.data + '.' + corpus_type + '.pt'
        dataset = torch.load(pt_file)
        dataset.fields = fields

    logger.info('Loading %s dataset from %s, number of examples: %d' %
                (corpus_type, pt_file, len(dataset)))
    device = torch.device("cuda" if opt.gpuid else "cpu")
    return OrderedIterator(dataset=dataset,
                           batch_size=opt.batch_size,
                           device=device,
                           train=is_train,
                           sort=False,
                           repeat=False)
def report_embedding(tgt_path, out_path, embedding_path):
    """Log embedding-based response-quality metrics (embedding average,
    greedy matching, extrema) between references and generated outputs."""
    w2v = KeyedVectors.load_word2vec_format(embedding_path, binary=True)

    # Each metric returns (mean, confidence interval, stddev)-style triple;
    # log them all with their respective templates.
    metrics = (
        ("Embedding Average Score: %f +/- %f ( %f )", average),
        ("Greedy Matching Score: %f +/- %f ( %f )", greedy_match),
        ("Extrema Score: %f +/- %f ( %f )", extrema_score),
    )
    for template, metric_fn in metrics:
        r = metric_fn(tgt_path, out_path, w2v)
        msg = (template % (r[0], r[1], r[2]))
        logger.info(msg)
def report_ground_truth(opt):
    """Evaluate a trained model on ground-truth data and log perplexity,
    cross-entropy and/or accuracy according to the report flags on `opt`."""
    assert opt.ckpt is not None and opt.data is not None

    # 1. Model & fields: a dummy parser supplies default model options so
    # the checkpoint can fill in the rest.
    dummy_parser = argparse.ArgumentParser(description='train.py')
    opts.model_opts(dummy_parser)
    dummy_opt = dummy_parser.parse_known_args([])[0]
    fields, model, model_opt = load_test_model(opt, dummy_opt.__dict__)

    # 2. Dataset & iterator: either a preprocessed .pt file or a raw
    # DailyDialog text file.
    if ".pt" in opt.data:
        dataset = torch.load(opt.data)
        dataset.fields = fields
    else:
        conversations = mtdg.data.read_dailydialog_file(opt.data)
        dataset = mtdg.data.Dataset(conversations, fields)

    device = "cuda" if use_gpu(opt) else "cpu"
    data_iter = mtdg.data.OrderedIterator(dataset=dataset,
                                          batch_size=opt.batch_size,
                                          device=device,
                                          train=False,
                                          sort=False,
                                          repeat=False,
                                          shuffle=False)

    # 3. Trainer (evaluation only, so no optimizer).
    trainer = build_trainer(model_opt, model, fields, optim=None,
                            device=device)

    # 4. Run on the test data and log whichever statistics were requested.
    test_stats = trainer.valid(data_iter)
    requested = ((opt.report_ppl, "Perplexity: %g", test_stats.ppl),
                 (opt.report_xent, "Xent: %g", test_stats.xent),
                 (opt.report_accuracy, "Accuracy: %g", test_stats.accuracy))
    for wanted, template, stat_fn in requested:
        if wanted:
            msg = (template % stat_fn())
            logger.info(msg)
    def train_topic(self,
                    train_iter,
                    valid_iter,
                    epoch,
                    criterion=None,
                    optimizer=None,
                    test_iter=None):
        """Pretrain the topic components for one epoch.

        Runs a supervised pass over `train_iter` (topic prediction loss),
        then, if `valid_iter` is given, computes the validation loss with
        the model in eval mode under `torch.no_grad()`.

        Fix: `loss.backward()` was missing before `optimizer.step()`, so
        gradients were always zero and the pretraining never learned
        anything. (The criterion cannot be doing backward internally: the
        validation loop calls it under `no_grad`, where backward would
        fail.)
        """
        logger.info(
            '---------------------------------------------------------------------------------'
        )
        logger.info('Start training topics...')

        train_loss = 0
        # training
        for i, batch in enumerate(train_iter):
            # 1. Input: sentences arrive as (seq_len, batch); transpose to
            # (batch, seq_len).
            input, length = batch.text
            target = batch.target
            input_sentences = input.t().contiguous()
            batch_size = len(batch)

            # 2. Clear gradients from the previous step.
            self.model.zero_grad()

            # 3. Model forward pass.
            # NOTE(review): the validation loop below also passes `length`
            # to the encoder -- confirm whether it is needed here too.
            topic_aware_representation, encoder_hidden = self.model.encoder(
                input_sentences)
            scores = self.model.predictor(topic_aware_representation)

            # 4. Loss
            loss = criterion(scores, target, batch_size)
            train_loss += loss.item()

            # 5. Backprop + optimize (backward() was previously missing).
            loss.backward()
            optimizer.step()

            # 6. Logging every 2000 batches (`-1 % 2000` == 1999).
            if i % 2000 == -1 % 2000:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, i * batch_size, len(train_iter.dataset),
                    100. * i / len(train_iter), loss.item()))
        print('Train loss of topic: %g' %
              (train_loss / len(train_iter.dataset)))

        # validation: eval mode, no gradients, restore train mode after.
        valid_loss = None
        if valid_iter is not None:
            self.model.eval()
            valid_loss = 0
            with torch.no_grad():
                for idx, batch in enumerate(valid_iter):
                    input, length = batch.text
                    target = batch.target
                    input_sentences = input.t().contiguous()
                    batch_size = len(batch)

                    topic_aware_representation, encoder_hidden = self.model.encoder(
                        input_sentences, length)
                    scores = self.model.predictor(topic_aware_representation)

                    loss = criterion(scores, target, batch_size, train=False)
                    valid_loss += loss.item()

            self.model.train()
            valid_loss = valid_loss / len(valid_iter.dataset)
            print('Validation loss of topic: %g' % (valid_loss))
        rouge_results = r.convert_and_evaluate()
        results_dict = r.output_to_dict(rouge_results)
        return results_dict
    finally:
        pass
        if os.path.isdir(tmp_dir):
            shutil.rmtree(tmp_dir)


def rouge_results_to_str(results_dict):
    """Format ROUGE F-scores (1/2/3/L/SU4) from `results_dict` as one
    human-readable line, with scores scaled to percentages."""
    score_keys = ("rouge_1_f_score", "rouge_2_f_score", "rouge_3_f_score",
                  "rouge_l_f_score", "rouge_su*_f_score")
    scores = "/".join("{:.2f}".format(results_dict[key] * 100)
                      for key in score_keys)
    return ">> ROUGE(1/2/3/L/SU4): " + scores


if __name__ == "__main__":
    # CLI entry point: compute ROUGE between a candidate file and a
    # reference file and log the formatted result.
    init_logger('test_rouge.log')
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', type=str, default="candidate.txt",
                        help='candidate file')
    parser.add_argument('-r', type=str, default="reference.txt",
                        help='reference file')
    args = parser.parse_args()
    # "STDIN" (any case) means: read candidates from standard input.
    if args.c.upper() == "STDIN":
        args.c = sys.stdin
    logger.info(rouge_results_to_str(test_rouge(args.c, args.r)))
def report_rouge(tgt_path, out_path):
    """Compute ROUGE between generated output and reference target files
    and log the formatted scores.

    Fix: both file handles are now closed deterministically via context
    managers (they were previously left open until garbage collection).
    """
    with codecs.open(tgt_path, "r", "utf-8") as tgt_file, \
            codecs.open(out_path, "r", "utf-8") as out_file:
        results_dict = test_rouge(out_file, tgt_file)
    msg = rouge_results_to_str(results_dict)
    logger.info(msg)
    def train_topic_v2(self,
                       train_iter,
                       valid_iter,
                       train_epochs,
                       valid_epochs,
                       criterion=None,
                       optimizer=None,
                       test_iter=None):
        """
        Pretrain the topics.

        Multi-epoch variant of `train_topic`. NOTE(review): the training
        loop appears unfinished/broken -- see the FIXME comments below;
        the validation loop and checkpointing look complete.
        """

        logger.info('Start training topics...')

        epoch = 0
        for epoch in range(train_epochs):

            train_loss = 0
            # training
            for i, batch in enumerate(train_iter):

                # 1. Input
                input_sentences, target_sentences = batch.conversation
                input_length, target_length = batch.length
                input_turns = batch.turn
                max_turn_length = input_turns.data.max()

                self.model.zero_grad()

                # 3. Model
                topic_hidden, encoder_hidden = self.model.encoder(
                    input_sentences)
                # Offsets of each conversation's first turn in the flattened
                # sentence batch (exclusive cumulative sum of turn counts).
                start = torch.cumsum(
                    torch.cat(
                        (input_turns.data.new(1).zero_(), input_turns[:-1])),
                    0)

                # FIXME(review): `topic_aware_representation` is undefined
                # here (the encoder returned `topic_hidden`); this line
                # raises NameError as written.
                scores = self.model.decoder.softmax(
                    self.model.decoder.out(topic_aware_representation))

                # 4. Loss
                # FIXME(review): `target` and `batch_size` are undefined in
                # this loop (compare the validation loop below, which
                # derives both from the batch).
                loss = criterion(scores, target, batch_size)
                train_loss += loss.item()

                # 5. optimize
                # FIXME(review): no loss.backward() before optimizer.step(),
                # so gradients are never computed.
                optimizer.step()

                # 6. Logging every 300 batches (`-1 % 300` == 299).
                if i % 300 == -1 % 300:
                    print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.
                          format(epoch, i * batch_size,
                                 len(train_iter.dataset),
                                 100. * i / len(train_iter), loss.item()))
            print('Train loss of topic: %g' %
                  (train_loss / len(train_iter.dataset)))

            # validation: eval mode, no gradients, restore train mode after.
            self.model.eval()
            valid_loss = 0
            with torch.no_grad():
                for idx, batch in enumerate(valid_iter):
                    input, length = batch.text
                    target = batch.target
                    input_sentences = input.t().contiguous()
                    batch_size = len(batch)

                    topic_aware_representation, encoder_hidden = self.model.encoder(
                        input_sentences, length)
                    scores = self.model.predictor(encoder_hidden.squeeze(0))

                    loss = criterion(scores, target, batch_size, train=False)
                    valid_loss += loss.item()

            self.model.train()
            valid_loss = valid_loss / len(valid_iter.dataset)
            print('Validation loss of topic: %g' % (valid_loss))

            # drop checkpoints
            self._maybe_save(epoch, valid_loss)
# --- Example 16 (scraped-snippet separator) ---
def main(opt):
    """End-to-end training driver: load (or resume) a model, build the
    data iterators, optimizer, checkpoint saver and trainer, then run the
    training loop."""
    opt = training_opt_postprocessing(opt)

    # Resume from a previous run if requested; otherwise start fresh.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
        start_epoch = checkpoint["epoch"] + 1
    else:
        checkpoint = None
        model_opt = opt
        start_epoch = 0

    # Fields come from the preprocess phase; the valid split suffices to
    # discover which fields exist on the examples.
    train_dataset = torch.load(opt.data + '.valid.pt')
    fields = _load_fields(train_dataset, opt, checkpoint)

    train_iter = build_iterator("train", fields, opt)
    valid_iter = build_iterator("valid", fields, opt, is_train=False)

    # Model, optimizer, checkpoint saver.
    model = build_model(model_opt, fields, use_gpu(opt), checkpoint)
    optim = build_optim(model, opt, checkpoint)
    model_saver = build_model_saver(model, opt, fields, optim)

    # Trainer (visdom logging disabled).
    device = "cuda" if use_gpu(opt) else "cpu"
    trainer = build_trainer(opt,
                            model,
                            fields,
                            optim,
                            device=device,
                            model_saver=model_saver,
                            vis_logger=None)

    # Topic-aware models additionally get topic iterators, a topic loss
    # and a dedicated SGD optimizer for topic pretraining.
    if opt.model in ("TDCM", "TDACM"):
        train_topic_iter = build_iterator("train", fields, opt, is_topic=True)
        valid_topic_iter = build_iterator("valid",
                                          fields,
                                          opt,
                                          is_train=False,
                                          is_topic=True)
        topic_optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
        topic_criterion = mtdg.TopicLossCompute(
            fields["conversation"].vocab).to(device)
        trainer.train(train_iter,
                      valid_iter, [start_epoch, opt.epochs],
                      1,
                      train_topic_iter=train_topic_iter,
                      valid_topic_iter=valid_topic_iter,
                      topic_criterion=topic_criterion,
                      topic_optimizer=topic_optimizer)
    else:
        trainer.train(train_iter, valid_iter, [start_epoch, opt.epochs], 1)
    def train(self,
              train_iter,
              valid_iter,
              train_epochs,
              valid_epochs,
              report_func=None,
              train_topic_iter=None,
              valid_topic_iter=None,
              topic_criterion=None,
              topic_optimizer=None):
        """
        The main training loops.

        Iterates epochs from train_epochs[0] through train_epochs[1]
        inclusive; each epoch trains over `train_iter`, validates over
        `valid_iter`, optionally plots to visdom, and drops a checkpoint
        via self._maybe_save(epoch, valid_ppl).

        Args:
            train_iter / valid_iter: conversation batch iterators.
            train_epochs (list): [start_epoch, last_epoch], inclusive.
            valid_epochs: currently unused -- validation runs every epoch.
            report_func: unused; kept for interface compatibility.
            train_topic_iter, valid_topic_iter, topic_criterion,
                topic_optimizer: topic-pretraining hooks (the call is
                currently disabled; see comment in the loop).

        Returns:
            The (freshly reset) total Statistics accumulator.
        """
        logger.info('Start training...')
        total_stats = mtdg.utils.Statistics()
        report_stats = mtdg.utils.Statistics()

        epoch = train_epochs[0]
        while epoch <= train_epochs[1]:

            # Topic pretraining hook (disabled); re-enable with:
            # self.train_topic(train_topic_iter, valid_topic_iter, epoch, topic_criterion, topic_optimizer)

            # training
            logger.info('Train conversations...')
            for i, batch in enumerate(train_iter):
                input_sentences, target_sentences = batch.conversation
                input_length, target_length = batch.length
                input_turns = batch.turn
                max_turn_length = input_turns.data.max()

                report_stats.n_src_words += input_length.sum().item()

                # Clear gradients on the model before the forward pass.
                self.model.zero_grad()

                scores = self.model(input_sentences, input_length, input_turns,
                                    target_sentences)

                # NOTE(review): assumes sharded_compute_loss performs the
                # backward pass internally (OpenNMT-style) -- no explicit
                # loss.backward() precedes optim.step(). Confirm.
                batch_stats = self.loss_function.sharded_compute_loss(
                    scores, target_sentences, self.shard_size)

                self.optim.step()
                total_stats.update(batch_stats)
                report_stats.update(batch_stats)

                # Report every 500 batches (`-1 % 500` == 499).
                if i % 500 == -1 % 500:
                    report_stats.output(epoch, i + 1, len(train_iter),
                                        self.optim.learning_rate,
                                        total_stats.start_time)
                    # Visdom Logger
                    if self.vis_logger is not None:
                        self.vis_logger.line(
                            X=torch.Tensor([i + epoch * len(train_iter)]),
                            Y=torch.Tensor([report_stats.ppl()]),
                            win="Training_perplexity",
                            opts={"title": "Training Perplexity"},
                            update=None if epoch == 0 and i == 9 else "append")
                        self.vis_logger.line(
                            X=torch.Tensor([i + epoch * len(train_iter)]),
                            Y=torch.Tensor([report_stats.xent()]),
                            win="Training_xent",
                            opts={"title": "Training Xent"},
                            update=None if epoch == 0 and i == 9 else "append")
                        self.vis_logger.line(
                            X=torch.Tensor([i + epoch * len(train_iter)]),
                            Y=torch.Tensor([report_stats.accuracy()]),
                            win="Training_accuracy",
                            opts={"title": "Training Accuracy"},
                            update=None if epoch == 0 and i == 9 else "append")
                    # Reset the per-report window accumulator.
                    report_stats = mtdg.utils.Statistics()

            print('Train xent: %g' % total_stats.xent())
            print('Train perplexity: %g' % total_stats.ppl())
            print('Train accuracy: %g' % total_stats.accuracy())

            # validation (runs every epoch; valid_epochs is not consulted)
            valid_stats = self.valid(valid_iter)
            print('Validation xent: %g' % valid_stats.xent())
            print('Validation perplexity: %g' % valid_stats.ppl())
            print('Validation accuracy: %g' % valid_stats.accuracy())

            # Epoch-level visdom plots (train + valid per chart).
            # NOTE(review): `i` here is the last batch index leaked from
            # the training loop above -- raises NameError if train_iter
            # was empty.
            if self.vis_logger is not None:
                self.vis_logger.line(
                    X=torch.Tensor([[epoch, epoch]]),
                    Y=torch.Tensor([[total_stats.ppl(),
                                     valid_stats.ppl()]]),
                    win="ppl",
                    opts={"title": "Perplexity"},
                    update=None if epoch == 0 and i == 0 else "append")
                self.vis_logger.line(
                    X=torch.Tensor([[epoch, epoch]]),
                    Y=torch.Tensor([[total_stats.xent(),
                                     valid_stats.xent()]]),
                    win="xent",
                    opts={"title": "Xent"},
                    update=None if epoch == 0 and i == 0 else "append")
                self.vis_logger.line(
                    X=torch.Tensor([[epoch, epoch]]),
                    Y=torch.Tensor(
                        [[total_stats.accuracy(),
                          valid_stats.accuracy()]]),
                    win="accuracy",
                    opts={"title": "Accuracy"},
                    update=None if epoch == 0 and i == 0 else "append")

            # Reset the per-epoch accumulator.
            total_stats = mtdg.utils.Statistics()
            # drop checkpoints
            self._maybe_save(epoch, valid_stats.ppl())
            epoch += 1

        return total_stats