Example #1

import math
import os
import time

import torch
import torch.nn as nn

# Project-local helpers are assumed importable from the surrounding package:
# LazyDataset, BucketBatchSampler, sort_batch, get_prev_params,
# make_encoder_dict, Tagger, random_init_weights, count_parameters,
# make_muliti_optim, train_model, evaluate_model, epoch_time, make_loss_plot.


def main(args):

    # use cuda if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # create directory for saving models if it doesn't already exist
    if not os.path.exists(args.save_path):
        os.mkdir(args.save_path)

    # load the saved vocab objects: the source vocab comes from the
    # pretrained NMT data, the target (tag) vocab from the tagging data
    SRC = torch.load(os.path.join(args.nmt_data_path, "src_vocab.pt"))
    TRG = torch.load(os.path.join(args.data_path, "trg_vocab.pt"))

    # gather parameters from the vocabulary
    input_dim = len(SRC.vocab)
    output_dim = len(TRG.vocab)
    pad_idx = SRC.vocab.stoi[SRC.pad_token]
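
    # note: SRC and TRG behave like torchtext Field objects here:
    # field.vocab.stoi maps token -> index and field.pad_token is the
    # padding symbol; the exact pad index depends on how the vocab was
    # built (torchtext's default would be 1)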

    # create the LazyDataset plus a bucketing sampler and DataLoader
    train_path = os.path.join(args.data_path, "train.tsv")
    training_set = LazyDataset(train_path, SRC, TRG, "tagging")

    train_batch_sampler = BucketBatchSampler(train_path, args.batch_size)
    # number of batches comes from the sampler, not the iterator
    num_batches = train_batch_sampler.num_batches

    # build the parameter dict for the DataLoader; batch_sampler is
    # mutually exclusive with batch_size, shuffle, sampler, and drop_last
    # in torch.utils.data.DataLoader, so those keep their defaults and
    # any shuffling is left to the sampler itself
    train_loader_params = {
        # sort_batch reverse sorts the batch for pack_padded_sequence
        "collate_fn": sort_batch,
        "batch_sampler": train_batch_sampler,
        "num_workers": args.num_workers,
        "pin_memory": True,
    }

    train_iterator = torch.utils.data.DataLoader(training_set,
                                                 **train_loader_params)
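
    # with a batch_sampler the DataLoader yields one list of dataset
    # indices per step and passes the corresponding examples to
    # collate_fn as a single batch; sort_batch is assumed to pad and
    # reverse-sort the batch by source length, as
    # nn.utils.rnn.pack_padded_sequence requires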

    # load the pretrained NMT model checkpoint
    prev_state_dict = torch.load(args.pretrained_model,
                                 map_location=torch.device("cpu"))
    enc_dropout = prev_state_dict["dropout"]
    prev_state_dict = prev_state_dict["model_state_dict"]

    # gather parameters except dec_hid_dim since tagger gets this from args
    prev_param_dict = get_prev_params(prev_state_dict)

    new_state_dict = make_encoder_dict(prev_state_dict)
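
    # make_encoder_dict is assumed to keep only the encoder parameters
    # of the NMT checkpoint (keys such as "enc_embedding.weight" and the
    # encoder RNN's *_l0 weights) so the tagger can load them directly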

    # keep only the parameters needed for the chosen representation layer
    if args.repr_layer in ("embedding", "rnn1"):
        # always transfer the embedding layer
        new_dict = {
            "enc_embedding.weight": new_state_dict["enc_embedding.weight"]
        }
        if args.repr_layer == "rnn1":
            # also transfer the first-layer RNN weights and biases
            for k, v in new_state_dict.items():
                if "l0" in k:
                    new_dict[k] = v
        # replace the state dict with the filtered dict
        new_state_dict = new_dict
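
    # "embedding" transfers only the pretrained embedding matrix, while
    # "rnn1" additionally transfers the first encoder RNN layer (every
    # key containing "l0", which also catches the *_l0_reverse weights
    # of a bidirectional encoder); anything not kept here is trained
    # from scratch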

    model = Tagger(
        new_state_dict=new_state_dict,
        input_dim=input_dim,
        emb_dim=prev_param_dict["emb_dim"],
        enc_hid_dim=prev_param_dict["enc_hid_dim"],
        dec_hid_dim=args.hid_dim,
        output_dim=output_dim,
        enc_layers=prev_param_dict["enc_layers"],
        dec_layers=args.n_layers,
        enc_dropout=enc_dropout,
        dec_dropout=args.dropout,
        bidirectional=prev_param_dict["bidirectional"],
        pad_idx=pad_idx,
        repr_layer=args.repr_layer,
    ).to(device)
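
    # encoder-side sizes (emb_dim, enc_hid_dim, enc_layers, bidirectional)
    # must match the pretrained weights being loaded, so they come from
    # the checkpoint via get_prev_params; decoder-side sizes are free
    # hyperparameters taken from the command line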

    # optionally randomly initialize weights
    if args.random_init:
        model.apply(random_init_weights)

    print(model)
    print(f"The model has {count_parameters(model):,} trainable parameters")

    # optionally freeze the pretrained encoder; doing this before the
    # optimizer is built keeps frozen parameters from ever being updated
    if not args.unfreeze_encoder:
        for param in model.encoder.parameters():
            param.requires_grad = False

    optimizer = make_muliti_optim(model.named_parameters(), args.learning_rate)

    SRC_PAD_IDX = pad_idx
    # pad target tags with an index just outside the tag vocabulary
    TRG_PAD_IDX = len(TRG.vocab) + 1
    criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)
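
    # CrossEntropyLoss with ignore_index excludes padded positions from
    # both the loss value and the gradient, so padding never trains the
    # tagger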

    best_valid_loss = float("inf")

    # training
    batch_history = []
    epoch_history = []
    for epoch in range(1, args.epochs + 1):
        start_time = time.time()
        train_loss, batch_loss = train_model(
            model=model,
            iterator=train_iterator,
            task="tagging",
            optimizer=optimizer,
            criterion=criterion,
            clip=args.clip,
            device=device,
            epoch=epoch,
            start_time=start_time,
            save_path=args.save_path,
            pad_indices=(SRC_PAD_IDX, TRG_PAD_IDX),
            dropout=(enc_dropout, args.dropout),
            checkpoint=args.checkpoint,
            repr_layer=args.repr_layer,
            num_batches=num_batches,
        )
        batch_history += batch_loss
        epoch_history.append(train_loss)
        end_time = time.time()

        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        model_filename = os.path.join(args.save_path,
                                      f"model_epoch_{epoch}.pt")
        adam, sparse_adam = optimizer.return_optimizers()
        if not args.only_best:
            torch.save(
                {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "adam_state_dict": adam.state_dict(),
                    "sparse_adam_state_dict": sparse_adam.state_dict(),
                    # validation has not run yet this epoch (and may be
                    # skipped), so record the training loss here
                    "loss": train_loss,
                    "dropout": (enc_dropout, args.dropout),
                    "repr_layer": args.repr_layer,
                },
                model_filename,
            )
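
        # dropout and repr_layer ride along with the weights so a
        # checkpoint carries everything needed to rebuild the model,
        # mirroring how this script reads "dropout" and
        # "model_state_dict" from the NMT checkpoint above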

        # optionally validate
        if not args.skip_validate:
            valid_path = os.path.join(args.data_path, "valid.tsv")
            valid_set = LazyDataset(valid_path, SRC, TRG, "tagging")
            valid_batch_sampler = BucketBatchSampler(valid_path,
                                                     args.batch_size)
            # number of batches comes from the sampler, not the iterator
            valid_num_batches = valid_batch_sampler.num_batches
            # batch_sampler is again mutually exclusive with batch_size,
            # shuffle, and drop_last, so only the remaining options are set
            valid_loader_params = {
                # sort_batch reverse sorts the batch for pack_padded_sequence
                "collate_fn": sort_batch,
                "batch_sampler": valid_batch_sampler,
                "num_workers": args.num_workers,
                "pin_memory": True,
            }

            valid_iterator = torch.utils.data.DataLoader(
                valid_set, **valid_loader_params)

            valid_loss = evaluate_model(
                model,
                valid_iterator,
                num_batches=valid_num_batches,
                optimizer=optimizer,
                criterion=criterion,
                task="tagging",
                device=device,
                pad_indices=(SRC_PAD_IDX, TRG_PAD_IDX),
            )

            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss

                best_filename = os.path.join(args.save_path, "best_model.pt")
                torch.save(
                    {
                        "epoch": epoch,
                        "model_state_dict": model.state_dict(),
                        "adam_state_dict": adam.state_dict(),
                        "sparse_adam_state_dict": sparse_adam.state_dict(),
                        "loss": valid_loss,
                        "dropout": (enc_dropout, args.dropout),
                        "repr_layer": args.repr_layer,
                    },
                    best_filename,
                )

            print(f"Epoch: {epoch:02} | Time: {epoch_mins}m {epoch_secs}s")
            print(
                f"\t Train Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}"
            )
            print(
                f"\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}"
            )

        else:
            print(f"Epoch: {epoch:02} | Time: {epoch_mins}m {epoch_secs}s")
            print(
                f"\t Train Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}"
            )

    if args.loss_plot:
        make_loss_plot(batch_history, args.save_path, args.epochs)
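

# A minimal sketch of a command-line entry point for main(). The flag
# names are inferred from the attributes this function reads off args;
# the defaults and the "full" repr_layer choice are illustrative
# assumptions, not the project's actual values.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="train a tagger on top of a pretrained NMT encoder")
    parser.add_argument("--data-path", required=True)
    parser.add_argument("--nmt-data-path", required=True)
    parser.add_argument("--pretrained-model", required=True)
    parser.add_argument("--save-path", required=True)
    parser.add_argument("--repr-layer", default="full",
                        choices=["embedding", "rnn1", "full"])
    parser.add_argument("--hid-dim", type=int, default=256)
    parser.add_argument("--n-layers", type=int, default=1)
    parser.add_argument("--dropout", type=float, default=0.3)
    parser.add_argument("--learning-rate", type=float, default=1e-3)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--num-workers", type=int, default=0)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--clip", type=float, default=1.0)
    parser.add_argument("--checkpoint", type=int, default=1000)
    parser.add_argument("--random-init", action="store_true")
    parser.add_argument("--unfreeze-encoder", action="store_true")
    parser.add_argument("--skip-validate", action="store_true")
    parser.add_argument("--only-best", action="store_true")
    parser.add_argument("--loss-plot", action="store_true")
    main(parser.parse_args())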