    def test_shape_on_random_data(self):
        set_seed(42)

        bs = 3
        src_len = 5
        tgt_len = 7

        encoder_config = transformers.BertConfig(
            hidden_size=11,
            intermediate_size=44,
            vocab_size=17,
            num_hidden_layers=1,
            num_attention_heads=1,
        )
        encoder = transformers.BertModel(encoder_config)

        # decoder accepts vocabulary of schema vocab + pointer embeddings
        decoder_config = transformers.BertConfig(
            hidden_size=11,
            intermediate_size=44,
            vocab_size=23,
            is_decoder=True,
            num_hidden_layers=1,
            num_attention_heads=1,
        )
        decoder = transformers.BertModel(decoder_config)

        # logits are projected into schema vocab and combined with pointer scores
        max_pointer = src_len + 3
        model = EncoderDecoderWPointerModel(encoder=encoder,
                                            decoder=decoder,
                                            max_src_len=max_pointer)

        x_enc = torch.randint(0, encoder_config.vocab_size, size=(bs, src_len))
        x_dec = torch.randint(0, decoder_config.vocab_size, size=(bs, tgt_len))

        out = model(input_ids=x_enc, decoder_input_ids=x_dec)

        # different encoders return different numbers of outputs,
        # e.g. BERT returns two, but DistilBERT returns only one
        self.assertGreaterEqual(len(out), 4)

        schema_vocab = decoder_config.vocab_size - max_pointer

        combined_logits = out[0]
        expected_shape = (bs, tgt_len, schema_vocab + src_len)
        self.assertEqual(combined_logits.shape, expected_shape)

        decoder_hidden = out[1]
        expected_shape = (bs, tgt_len, decoder_config.hidden_size)
        self.assertEqual(decoder_hidden.shape, expected_shape)

        decoder_pooled_output = out[2]
        expected_shape = (bs, decoder_config.hidden_size)
        self.assertEqual(decoder_pooled_output.shape, expected_shape)

        encoder_hidden = out[3]
        expected_shape = (bs, src_len, encoder_config.hidden_size)
        self.assertEqual(encoder_hidden.shape, expected_shape)
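
# --- Illustrative sketch, not part of the scraped tests. One plausible way the
# combined logits asserted above could be produced: project decoder states into
# the schema vocabulary and concatenate dot-product pointer scores over the
# source positions. Function and argument names are assumptions; the real
# EncoderDecoderWPointerModel may compute the pointer scores differently.
import torch


def combine_schema_and_pointer_logits(decoder_hidden, encoder_hidden, vocab_proj,
                                      pointer_mask=None):
    # decoder_hidden: (bs, tgt_len, hidden), encoder_hidden: (bs, src_len, hidden)
    # vocab_proj: e.g. nn.Linear(hidden, schema_vocab_size)
    schema_logits = vocab_proj(decoder_hidden)  # (bs, tgt_len, schema_vocab)
    pointer_scores = decoder_hidden @ encoder_hidden.transpose(1, 2)  # (bs, tgt_len, src_len)
    if pointer_mask is not None:
        # forbid pointing at special or padding source tokens
        pointer_scores = pointer_scores.masked_fill(pointer_mask.unsqueeze(1) == 0, -1e4)
    return torch.cat([schema_logits, pointer_scores], dim=-1)  # (..., schema_vocab + src_len)
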
    def setUp(self):
        utils.set_seed(3)
        src_tokenizer = transformers.AutoTokenizer.from_pretrained(
            "bert-base-cased")
        vocab = {
            "[",
            "]",
            "IN:",
            "SL:",
            "GET_DIRECTIONS",
            "DESTINATION",
            "DATE_TIME_DEPARTURE",
            "GET_ESTIMATED_ARRIVAL",
        }
        self.schema_tokenizer = TopSchemaTokenizer(vocab, src_tokenizer)

        self.model = EncoderDecoderWPointerModel.from_parameters(
            layers=2,
            hidden=32,
            heads=2,
            src_vocab_size=src_tokenizer.vocab_size,
            tgt_vocab_size=self.schema_tokenizer.vocab_size,
            max_src_len=17,
            dropout=0.1,
        )

        source_texts = [
            "Directions to Lowell",
            "Get directions to Mountain View",
        ]
        target_texts = [
            "[IN:GET_DIRECTIONS Directions to [SL:DESTINATION Lowell]]",
            "[IN:GET_DIRECTIONS Get directions to [SL:DESTINATION Mountain View]]",
        ]

        pairs = [
            self.schema_tokenizer.encode_pair(t, s)
            for t, s in zip(target_texts, source_texts)
        ]

        self.dataset = PointerDataset.from_pair_items(pairs)
        self.dataset.torchify()

        collator = Seq2SeqDataCollator(
            pad_id=self.schema_tokenizer.pad_token_id)
        dataloader = torch.utils.data.DataLoader(
            self.dataset, batch_size=2, collate_fn=collator.collate_batch)

        self.test_batch = next(iter(dataloader))

        self.module = PointerModule(
            model=self.model,
            schema_tokenizer=self.schema_tokenizer,
            train_dataset=self.dataset,
            valid_dataset=self.dataset,
            lr=1e-3,
        )
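
# Assumed usage sketch (not in the original setUp): since PointerModule is a
# pytorch-lightning module constructed with its train and validation datasets,
# a fast_dev_run Trainer should exercise one training and one validation step
# of the fixture above. The exact Trainer flags are assumptions about the
# pytorch-lightning version this project uses.
import pytorch_lightning as pl


def smoke_test_pointer_module(module):
    trainer = pl.Trainer(fast_dev_run=True, logger=False)
    trainer.fit(module)
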
    def test_smoothing(self):
        # only checks that it does not fail
        set_seed(98)

        v_size = 43
        eps = 0.1

        preds = torch.randn(size=(7, 19, v_size)).view(-1, v_size)
        labels = torch.randint(v_size, size=(7, 19)).view(-1)
        mask = torch.ones_like(labels)

        ce2 = LabelSmoothedCrossEntropy(eps=eps)(preds, labels, mask)
    def test_no_smoothing(self):
        set_seed(98)

        v_size = 43

        preds = torch.randn(size=(7, 19, v_size)).view(-1, v_size)
        labels = torch.randint(v_size, size=(7, 19)).view(-1)
        mask = torch.ones_like(labels)

        ce1 = F.cross_entropy(preds, labels)
        ce2 = LabelSmoothedCrossEntropy(eps=0)(preds, labels, mask)

        self.assertTrue(torch.allclose(ce1, ce2))
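
# Minimal sketch of the loss being tested (an assumed formulation; the project's
# LabelSmoothedCrossEntropy may handle masking and reduction differently). With
# eps=0 the smoothing term vanishes and the loss reduces to standard
# cross-entropy, which is exactly what test_no_smoothing asserts.
import torch.nn.functional as F


def label_smoothed_ce(preds, labels, mask, eps=0.1):
    # preds: (N, vocab) logits, labels: (N,), mask: (N,) with 1 = keep position
    log_probs = F.log_softmax(preds, dim=-1)
    nll = -log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
    smooth = -log_probs.mean(dim=-1)  # expected NLL under a uniform target distribution
    loss = (1.0 - eps) * nll + eps * smooth
    mask = mask.float()
    return (loss * mask).sum() / mask.sum()
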
    def test_shape_on_real_data_batched(self):
        set_seed(42)
        src_vocab_size = 17
        tgt_vocab_size = 23
        max_position = 7

        encoder_config = transformers.BertConfig(
            hidden_size=11,
            intermediate_size=44,
            vocab_size=src_vocab_size,
            num_hidden_layers=1,
            num_attention_heads=1,
        )
        encoder = transformers.BertModel(encoder_config)

        decoder_config = transformers.BertConfig(
            hidden_size=11,
            intermediate_size=44,
            vocab_size=tgt_vocab_size + max_position,
            is_decoder=True,
            num_hidden_layers=1,
            num_attention_heads=1,
        )
        decoder = transformers.BertModel(decoder_config)

        model = EncoderDecoderWPointerModel(encoder=encoder,
                                            decoder=decoder,
                                            max_src_len=max_position)

        # similar to real data
        src_seq = torch.LongTensor([[1, 6, 12, 15, 2, 0, 0],
                                    [1, 6, 12, 15, 5, 3, 2]])
        tgt_seq = torch.LongTensor([
            [8, 6, 4, 10, 11, 8, 5, 1, 12, 7, 7, 0, 0],
            [8, 6, 4, 10, 11, 8, 5, 1, 12, 13, 14, 7, 7],
        ])
        mask = torch.FloatTensor([[0, 1, 1, 1, 0, 0, 0],
                                  [0, 1, 1, 1, 1, 1, 0]])

        combined_logits = model(input_ids=src_seq,
                                decoder_input_ids=tgt_seq,
                                pointer_mask=mask)[0]

        expected_shape = (2, tgt_seq.shape[1],
                          tgt_vocab_size + src_seq.shape[1])
        self.assertEqual(combined_logits.shape, expected_shape)
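
# Assumed helper sketch: the pointer_mask above keeps only real source tokens
# and zeroes out special/padding positions. With special_ids=(0, 1, 2) (the pad,
# [CLS]-like and [SEP]-like ids in this toy data) it reproduces the mask used in
# the test; the real project builds this mask inside its tokenizer/dataset code.
import torch


def make_pointer_mask(src_seq, special_ids=(0, 1, 2)):
    mask = torch.ones_like(src_seq, dtype=torch.float)
    for special_id in special_ids:
        mask = mask * (src_seq != special_id).float()
    return mask
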
    def test_getitem(self):
        utils.set_seed(29)

        d_mask = 2
        src_tensors = [
            torch.tensor([2, 5, 4, 4, 2]),
            torch.tensor([1, 8, 2, 8, 5, 4, 2, 2, 5, 7]),
            torch.tensor([4, 6, 4, 2, 1, 2]),
        ]
        tgt_tensors = [
            torch.tensor([6, 7, 8, 7, 2, 2, 4, 8, 5]),
            torch.tensor([5, 2, 2, 8, 7, 3, 5, 4, 2, 2, 1]),
            torch.tensor([8, 2, 2, 3, 5, 2, 2, 2, 3, 8, 4, 6, 7, 8]),
        ]
        tgt_masks = [(tgt_tensors[i] == d_mask).type(torch.FloatTensor)
                     for i in range(3)]

        expected_decoder_input_ids = [
            torch.tensor([6, 7, 8, 7, 2, 2, 4, 8]),
            torch.tensor([5, 2, 2, 8, 7, 3, 5, 4, 2, 2]),
            torch.tensor([8, 2, 2, 3, 5, 2, 2, 2, 3, 8, 4, 6, 7]),
        ]
        expected_decoder_pointer_mask = [
            torch.tensor([0, 0, 0, 0, 1, 1, 0, 0]),
            torch.tensor([0, 1, 1, 0, 0, 0, 0, 0, 1, 1]),
            torch.tensor([0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0]),
        ]

        dataset = nsp.data.PointerDataset(src_tensors,
                                          tgt_tensors,
                                          target_pointer_masks=tgt_masks)

        item: InputDataClass = dataset[0]

        self.assertIsInstance(item, InputDataClass)
        self.assertIsInstance(item.input_ids, torch.LongTensor)
        self.assertIsInstance(item.decoder_input_ids, torch.LongTensor)
        self.assertIsInstance(item.labels, torch.LongTensor)

        self.assertTrue(torch.all(item.input_ids == src_tensors[0]))
        self.assertTrue(
            torch.all(item.decoder_input_ids == expected_decoder_input_ids[0]))
        self.assertTrue(
            torch.all(
                item.decoder_pointer_mask == expected_decoder_pointer_mask[0]))
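
# Sketch of the __getitem__ contract these assertions encode (an assumption
# about nsp.data.PointerDataset, not its actual code): the decoder input and
# pointer mask drop the last target position, and labels are presumably the
# target shifted left by one, although the test above only checks label types.
def getitem_sketch(src, tgt, tgt_pointer_mask):
    return {
        "input_ids": src,
        "decoder_input_ids": tgt[:-1],
        "decoder_pointer_mask": tgt_pointer_mask[:-1],
        "labels": tgt[1:],
    }
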
    def test_shape_on_real_data(self):
        set_seed(42)
        src_vocab_size = 17
        tgt_vocab_size = 23
        max_position = 5

        encoder_config = transformers.BertConfig(
            hidden_size=11,
            intermediate_size=44,
            vocab_size=src_vocab_size,
            num_hidden_layers=1,
            num_attention_heads=1,
        )
        encoder = transformers.BertModel(encoder_config)

        decoder_config = transformers.BertConfig(
            hidden_size=11,
            intermediate_size=44,
            vocab_size=tgt_vocab_size + max_position,
            is_decoder=True,
            num_hidden_layers=1,
            num_attention_heads=1,
        )
        decoder = transformers.BertModel(decoder_config)

        model = EncoderDecoderWPointerModel(encoder=encoder,
                                            decoder=decoder,
                                            max_src_len=max_position)

        # similar to real data
        # e.g. '[CLS] Directions to Lowell [SEP]'
        src_seq = torch.LongTensor([[1, 6, 12, 15, 2]])
        # e.g. '[IN:GET_DIRECTIONS Directions to [SL:DESTINATION Lowell]]'
        tgt_seq = torch.LongTensor([[8, 6, 4, 10, 11, 8, 5, 1, 12, 7, 7]])
        mask = torch.FloatTensor([[0, 1, 1, 1, 0]])

        combined_logits = model(input_ids=src_seq,
                                decoder_input_ids=tgt_seq,
                                pointer_mask=mask)[0]

        expected_shape = (1, tgt_seq.shape[1],
                          tgt_vocab_size + src_seq.shape[1])
        self.assertEqual(combined_logits.shape, expected_shape)
    def test_getitem_inference(self):
        utils.set_seed(29)

        src_tensors = [
            torch.randint(0,
                          10,
                          size=(random.randint(5, 13), ),
                          dtype=torch.int64) for _ in range(10)
        ]

        dataset = nsp.data.PointerDataset(src_tensors)

        item = dataset[0]

        self.assertIsInstance(item, InputDataClass)
        self.assertIsInstance(item.input_ids, torch.LongTensor)

        self.assertIsNone(item.decoder_input_ids)
        self.assertIsNone(item.labels)
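
# Assumed sketch of utils.set_seed: the usual pattern of seeding Python's
# random module, NumPy, and PyTorch together so the random tensors above are
# reproducible. The project's helper may also seed CUDA or other sources.
import random

import numpy as np
import torch


def set_seed_sketch(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
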
def main(args):
    utils.set_seed(args.seed)

    if os.path.exists(args.output_dir):
        raise ValueError(f"output_dir {args.output_dir} already exists")

    wandb_logger = pl.loggers.WandbLogger(project=args.wandb_project,
                                          tags=args.tags)
    wandb_logger.log_hyperparams(args)

    logger.info(
        f"Starting finetuning with args: \n{pprint.pformat(vars(args))}")

    logger.info("Loading tokenizers")
    schema_tokenizer = load_tokenizer(args.model_dir, args.data_dir)

    logger.info("Loading data")
    train_dataset, eval_dataset = load_data(
        path=path_join(args.data_dir, "data.pkl"),
        new_data_amount=args.new_data_amount,
        old_data_amount=args.old_data_amount,
        old_data_sampling_method=args.old_data_sampling_method,
        wandb_logger=wandb_logger,
    )
    train_args = cli_utils.load_saved_args(
        path_join(args.model_dir, "args.toml"))

    # NOTE: do not log metrics as hyperparameters
    wandb_logger.log_hyperparams(
        {"pretrain_" + k: v
         for k, v in train_args.items() if k != "metrics"})

    wandb_logger.log_hyperparams({"num_total_data": len(train_dataset)})

    logger.info("Loading model")
    model = load_model(
        model_dir=args.model_dir,
        dropout=args.dropout,
        move_norm=args.move_norm,
        move_norm_p=args.move_norm_p,
        label_smoothing=args.label_smoothing,
        weight_consolidation=args.weight_consolidation,
    )

    logger.info("Preparing for training")

    lightning_module = load_lightning_module(
        checkpoint_path=train_args["pl_checkpoint_path"],
        model=model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        schema_tokenizer=schema_tokenizer,
        args=args,
        wandb_logger=wandb_logger,
    )

    # override some of the parameters saved in the Trainer
    checkpoint_path = modify_checkpoint_for_retraining(
        train_args["pl_checkpoint_path"],
        args.output_dir,
        args.lr,
        args.weight_decay,
        args.no_opt_state,
        lightning_module,
    )
    trainer = load_trainer(checkpoint_path, args, wandb_logger)

    # get evaluation metrics of the initial model
    pretrain_metrics = train_args["metrics"]
    _first_step_metrics = {
        "epoch": -1,
        "global_step": -1,
        **pretrain_metrics["means"],
        **pretrain_metrics["stdevs"],
    }
    wandb_logger.log_metrics(_first_step_metrics, step=-1)

    wandb_logger.watch(lightning_module,
                       log="all",
                       log_freq=lightning_module.log_every)

    # --- FIT

    # call .test to load optimizer state
    with open(os.devnull, "w") as f, contextlib.redirect_stdout(f):
        trainer.test(lightning_module,
                     lightning_module.val_dataloader(subset_size=0.01))

    cli_utils.check_config(lightning_module, trainer, args, strict=True)
    if model.config.move_norm is not None:
        assert torch.allclose(model.get_move_norm(),
                              torch.zeros(1, device=model.device))

    trainer.fit(lightning_module)

    cli_utils.check_config(lightning_module, trainer, args, strict=True)

    with open(path_join(args.output_dir, "args.toml"), "w") as f:
        args_dict = {"version": nsp.SAVE_FORMAT_VERSION, **vars(args)}
        toml.dump(args_dict, f)

    logger.info("Training finished!")

    best_model_checkpoint = trainer.checkpoint_callback.last_checkpoint_path
    if args.average_checkpoints:
        average_checkpoints(model,
                            args,
                            save_to=os.path.dirname(best_model_checkpoint))

    final_metrics, description = cli_utils.evaluate_model(
        best_model_checkpoint,
        schema_tokenizer,
        eval_dataset,
        prefix="eval",
        max_len=train_args.get("max_tgt_len", 68),  # 68 is max_tgt_len for TOP
    )

    logger.info(description)
    wandb_logger.log_metrics({
        **final_metrics["means"],
        **final_metrics["stdevs"]
    })

    # Compute RI and RD
    class_weights = eval_dataset.get_class_frequencies(schema_tokenizer)
    class_weights = {
        f"cls/eval_{cls}_tree_path_f1": p
        for cls, p in class_weights.items()
    }

    finetuning_metrics = cli_utils.evaluate_finetuning_procedure(
        pretrain_metrics, final_metrics, class_weights)
    wandb_logger.log_metrics(finetuning_metrics)

    # Compute RI and RD with very small outliers stuff
    finetuning_metrics0 = cli_utils.evaluate_finetuning_procedure(
        pretrain_metrics, final_metrics, class_weights, sigma=0.1)
    finetuning_metrics0 = {
        k + "_0.1": v
        for k, v in finetuning_metrics0.items()
    }
    wandb_logger.log_metrics(finetuning_metrics0)

    wandb_logger.close()

    if args.clean_output:
        shutil.rmtree(args.output_dir)
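
# Assumed sketch of checkpoint averaging as invoked above: average model
# parameters element-wise across several saved state dicts before the final
# evaluation. The real average_checkpoints helper has a different signature
# (it receives the model and args) and discovers the checkpoint files itself.
import torch


def average_state_dicts(checkpoint_paths):
    avg = None
    for path in checkpoint_paths:
        state_dict = torch.load(path, map_location="cpu")
        if avg is None:
            avg = {k: v.clone().float() for k, v in state_dict.items()}
        else:
            for k, v in state_dict.items():
                avg[k] += v.float()
    return {k: v / len(checkpoint_paths) for k, v in avg.items()}
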
def main(args):
    utils.set_seed(args.seed)

    if os.path.exists(args.output_dir):
        raise ValueError(f"output_dir {args.output_dir} already exists")

    # File structure:
    # that's text\tthat 's text\t[IN:UNSUPPORTED that 's text]
    train_path = path_join(args.data, "train.tsv")
    train_data = pd.read_table(train_path, names=["text", "tokens", "schema"])
    full_train_data_size = len(train_data)  # used to check the train/finetune split
    finetune_data, finetune_path = None, None

    schema_vocab = reduce(set.union,
                          map(utils.get_vocab_top_schema, train_data.schema))

    if args.split_amount is not None:
        # finetune part is not used by train script, but used by retrain script
        logger.info("Splitting the training dataset")
        train_data, finetune_data = train_finetune_split(
            train_data, schema_vocab, args.split_amount, args.split_class)

        os.makedirs(args.output_dir)

        finetune_path = path_join(args.output_dir, "finetune.tsv")
        logger.info(f"Saving the finetune_data to {finetune_path}")
        finetune_data.to_csv(finetune_path,
                             sep="\t",
                             index=False,
                             header=False)

        train_path = path_join(args.output_dir, "train.tsv")
        logger.info(f"Saving the modified training set to {train_path}")
        train_data.to_csv(train_path, sep="\t", index=False, header=False)

    logger.info("Getting schema vocabulary")

    if args.split_amount is not None:
        finetune_schema_vocab = reduce(
            set.union, map(utils.get_vocab_top_schema, finetune_data.schema))
        vocab_delta = finetune_schema_vocab - schema_vocab
        if len(vocab_delta) > 0:
            logger.warning(
                "Finetuning subset contains vocabulary elements not present in the training subset"
            )
            logger.warning(f"New elements: {', '.join(vocab_delta)}")

    logger.info(f"Schema vocabulary size: {len(schema_vocab)}")

    logger.info("Building tokenizers")
    text_tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.text_tokenizer, use_fast=True)
    schema_tokenizer = nsp.TopSchemaTokenizer(schema_vocab, text_tokenizer)

    logger.info("Tokenizing train dataset")
    train_dataset = nsp.data.make_dataset(train_path, schema_tokenizer)

    logger.info("Tokenizing validation and test datasets")
    valid_dataset = nsp.data.make_dataset(path_join(args.data, "eval.tsv"),
                                          schema_tokenizer)
    test_dataset = nsp.data.make_dataset(path_join(args.data, "test.tsv"),
                                         schema_tokenizer)

    finetune_dataset = None
    if args.split_amount is not None:
        logger.info("Tokenizing finetune set")
        finetune_dataset = nsp.data.make_dataset(finetune_path,
                                                 schema_tokenizer)

        logger.info(f"Original train set size: {full_train_data_size}")
        logger.info(f"Reduced  train set size: {len(train_dataset)}")
        logger.info(f"Finetune       set size: {len(finetune_dataset)}")

        train_finetune_data_size = len(train_dataset) + len(finetune_dataset)
        if train_finetune_data_size != full_train_data_size:
            raise RuntimeError(
                f"{train_finetune_data_size} != {full_train_data_size}")

    logger.info(f"Saving config, data and tokenizer to {args.output_dir}")
    os.makedirs(args.output_dir, exist_ok=True)

    with open(path_join(args.output_dir, "args.toml"), "w") as f:
        args_dict = {"version": nsp.SAVE_FORMAT_VERSION, **vars(args)}
        toml.dump(args_dict, f)

    # text tokenizer is saved along with schema_tokenizer
    model_type = None
    if not os.path.exists(args.text_tokenizer):
        model_type = utils.get_model_type(args.text_tokenizer)

    schema_tokenizer.save(path_join(args.output_dir, "tokenizer"),
                          encoder_model_type=model_type)

    data_state = {
        "train_dataset": train_dataset,
        "valid_dataset": valid_dataset,
        "test_dataset": test_dataset,
        "finetune_dataset": finetune_dataset,
        "version": nsp.SAVE_FORMAT_VERSION,
    }

    torch.save(data_state, path_join(args.output_dir, "data.pkl"))
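
# Assumed sketch of utils.get_vocab_top_schema: collect the structural tokens of
# a TOP-format schema string (brackets, the "IN:"/"SL:" prefixes, and the intent
# and slot names), matching the kind of vocabulary set built in the tests above.
# The real helper may tokenize the schema differently.
def get_vocab_top_schema_sketch(schema_str):
    vocab = set()
    for token in schema_str.replace("[", " [ ").replace("]", " ] ").split():
        if token in ("[", "]"):
            vocab.add(token)
        elif token.startswith("IN:") or token.startswith("SL:"):
            vocab.add(token[:3])  # "IN:" or "SL:"
            vocab.add(token[3:])  # intent or slot name
    return vocab
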
    def test_collate_batch_shapes(self):
        utils.set_seed(29)

        bs = 3
        e_pad = 0
        d_pad = 1
        d_mask = 2
        src_tensors = [
            torch.tensor([2, 5, 4, 4, 2]),
            torch.tensor([1, 8, 2, 8, 5, 4, 2, 2, 5, 7]),
            torch.tensor([4, 6, 4, 2, 1, 2]),
        ]
        tgt_tensors = [
            torch.tensor([6, 7, 8, 7, 2, 2, 4, 8, 5]),
            torch.tensor([5, 2, 2, 8, 7, 3, 5, 4, 2, 2, 1]),
            torch.tensor([8, 2, 2, 3, 5, 2, 2, 2, 3, 8, 4, 6, 7, 8]),
        ]
        tgt_masks = [(tgt_tensors[i] == d_mask).type(torch.FloatTensor)
                     for i in range(3)]

        expected_input_ids = torch.tensor([
            [2, 5, 4, 4, 2, 0, 0, 0, 0, 0],
            [1, 8, 2, 8, 5, 4, 2, 2, 5, 7],
            [4, 6, 4, 2, 1, 2, 0, 0, 0, 0],
        ])
        expected_decoder_input_ids = torch.tensor([
            [6, 7, 8, 7, 2, 2, 4, 8, 5, 1, 1, 1, 1, 1],
            [5, 2, 2, 8, 7, 3, 5, 4, 2, 2, 1, 1, 1, 1],
            [8, 2, 2, 3, 5, 2, 2, 2, 3, 8, 4, 6, 7, 8],
        ])
        expected_decoder_pointer_mask = torch.tensor([
            [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
            [0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        ])

        assert d_mask in tgt_tensors[0]

        examples = [
            InputDataClass(
                input_ids=src_tensors[i],
                decoder_input_ids=tgt_tensors[i],
                decoder_pointer_mask=tgt_masks[i],
                labels=tgt_tensors[i],
            ) for i in range(bs)
        ]

        collator = nsp.data.Seq2SeqDataCollator(e_pad, d_pad)

        batch = collator.collate_batch(examples)

        self.assertEqual(batch["input_ids"].shape, (bs, 10))
        self.assertIsInstance(batch["input_ids"], torch.LongTensor)
        self.assertEqual(batch["input_ids"][0, -1], e_pad)
        self.assertTrue(torch.all(batch["input_ids"] == expected_input_ids))

        self.assertEqual(batch["decoder_input_ids"].shape, (bs, 14))
        self.assertIsInstance(batch["decoder_input_ids"], torch.LongTensor)
        self.assertEqual(batch["decoder_input_ids"][0, -1], d_pad)
        self.assertTrue(
            torch.all(
                batch["decoder_input_ids"] == expected_decoder_input_ids))

        self.assertEqual(batch["labels"].shape, (bs, 14))
        self.assertIsInstance(batch["labels"], torch.LongTensor)

        self.assertEqual(batch["decoder_pointer_mask"].shape, (bs, 14))
        self.assertIsInstance(batch["decoder_pointer_mask"], torch.FloatTensor)
        _mask = batch["decoder_pointer_mask"]
        self.assertTrue(((_mask == 0) | (_mask == 1)).all())
        self.assertTrue(
            torch.all(batch["decoder_pointer_mask"] ==
                      expected_decoder_pointer_mask))
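
# Assumed sketch of the padding behaviour checked above (the real
# nsp.data.Seq2SeqDataCollator likely does more, e.g. builds attention masks):
# encoder inputs are padded with the encoder pad id, decoder inputs with the
# decoder pad id, and the pointer mask with zeros, each to the longest sequence
# in the batch. Labels are padded with the decoder pad id here for simplicity;
# the real collator may use an ignore index instead.
import torch


def pad_stack(tensors, pad_value):
    max_len = max(t.shape[0] for t in tensors)
    padded = torch.full((len(tensors), max_len), pad_value, dtype=tensors[0].dtype)
    for i, t in enumerate(tensors):
        padded[i, : t.shape[0]] = t
    return padded


def collate_sketch(examples, enc_pad_id, dec_pad_id):
    return {
        "input_ids": pad_stack([e.input_ids for e in examples], enc_pad_id),
        "decoder_input_ids": pad_stack([e.decoder_input_ids for e in examples], dec_pad_id),
        "labels": pad_stack([e.labels for e in examples], dec_pad_id),
        "decoder_pointer_mask": pad_stack([e.decoder_pointer_mask for e in examples], 0),
    }
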
def main(args):
    utils.set_seed(args.seed)

    wandb_logger = pl.loggers.WandbLogger(project=args.wandb_project,
                                          tags=args.tags)
    wandb_logger.log_hyperparams(args)

    logger.info(f"Starting training with args: \n{pprint.pformat(vars(args))}")

    if os.path.exists(args.output_dir):
        raise ValueError(f"output_dir {args.output_dir} already exists")

    logger.info("Loading tokenizers")
    schema_tokenizer = nsp.TopSchemaTokenizer.load(
        path_join(args.data_dir, "tokenizer"))

    logger.info("Loading data")
    datasets = torch.load(path_join(args.data_dir, "data.pkl"))
    train_dataset: nsp.PointerDataset = datasets["train_dataset"]
    eval_dataset: nsp.PointerDataset = datasets["valid_dataset"]

    wandb_logger.log_hyperparams({"num_data": len(train_dataset)})

    max_src_len, max_tgt_len = train_dataset.get_max_len()

    try:
        preprocess_args = cli_utils.load_saved_args(
            path_join(args.data_dir, "args.toml"))
        wandb.config.update(
            {"preprocess_" + k: v
             for k, v in preprocess_args.items()})

    except FileNotFoundError:
        preprocess_args = None

    logger.info("Creating a model")
    model = make_model(schema_tokenizer, max_src_len, args, preprocess_args)

    logger.info("Preparing for training")

    lightning_module = make_lightning_module(model, schema_tokenizer,
                                             train_dataset, eval_dataset,
                                             max_tgt_len, args, wandb_logger)
    trainer = make_trainer(args, wandb_logger)

    # --- FIT
    cli_utils.check_config(lightning_module, trainer, args)

    trainer.fit(lightning_module)

    if args.track_grad_square:
        lightning_module.model.register_weight_consolidation_buffer()

    logger.info("Training finished!")

    # top_k == 1 --> the last checkpoint is the best model
    assert trainer.checkpoint_callback.save_top_k == 1
    logger.info(f"Loading and evaluating the best model")

    final_metrics, description = cli_utils.evaluate_model(
        trainer.checkpoint_callback.last_checkpoint_path,
        schema_tokenizer,
        eval_dataset,
        prefix="eval",
        max_len=max_tgt_len,
    )

    with open(path_join(args.output_dir, "args.toml"), "w") as f:
        args_dict = {
            "version": nsp.SAVE_FORMAT_VERSION,
            "pl_checkpoint_path":
            trainer.checkpoint_callback.last_checkpoint_path,
            "metrics": final_metrics,
            "max_src_len": max_src_len,
            "max_tgt_len": max_tgt_len,
            **vars(args),
        }
        toml.dump(args_dict, f)

    logger.info(description)
    wandb_logger.log_metrics({
        **final_metrics["means"],
        **final_metrics["stdevs"]
    })
    wandb_logger.close()