Example #1

import math
import torch

# NoamOpt, AnnealingOpt, LossCompute, and the other project-specific helpers used
# below are assumed to be importable from elsewhere in the codebase.

def init_optimizer(args, model, opt_type="noam"):
    dim_input = args.dim_input
    warmup = args.warmup
    lr = args.lr

    if opt_type == "noam":
        opt = NoamOpt(dim_input,
                      args.k_lr,
                      warmup,
                      torch.optim.Adam(model.parameters(),
                                       betas=(0.9, 0.98),
                                       eps=1e-9),
                      min_lr=args.min_lr)
    elif opt_type == "sgd":
        opt = AnnealingOpt(
            lr, args.lr_anneal,
            torch.optim.SGD(model.parameters(),
                            lr=lr,
                            momentum=args.momentum,
                            nesterov=True))
    else:
        opt = None
        print("Optimizer is not defined")

    return opt
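
# NoamOpt above is project code that is not shown in this example. For reference
# only, here is a minimal sketch of the standard "Noam" learning-rate schedule from
# "Attention Is All You Need" (lr = factor * d_model**-0.5 * min(step**-0.5,
# step * warmup**-1.5)), extended with the min_lr floor that init_optimizer passes
# in. The real NoamOpt in this codebase may differ in its details.
class NoamSchedule:
    """Wraps an optimizer and rescales its learning rate on every update step."""

    def __init__(self, model_size, factor, warmup, optimizer, min_lr=0.0):
        self.model_size = model_size
        self.factor = factor
        self.warmup = warmup
        self.optimizer = optimizer
        self.min_lr = min_lr
        self._step = 0

    def rate(self, step=None):
        step = step or self._step
        lr = self.factor * self.model_size ** -0.5 * min(step ** -0.5,
                                                         step * self.warmup ** -1.5)
        return max(lr, self.min_lr)

    def step(self):
        # Update the learning rate of every parameter group, then take an optimizer step.
        self._step += 1
        for group in self.optimizer.param_groups:
            group['lr'] = self.rate()
        self.optimizer.step()

    def zero_grad(self):
        self.optimizer.zero_grad()

# Example (hypothetical values): NoamSchedule(512, 1.0, 4000,
#     torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9),
#     min_lr=1e-5)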

# Standalone training-loop snippet: it assumes MyIterator, batch_size_fn, rebatch,
# run_epoch, LossCompute, NoamOpt, the data splits (train, val) and the globals
# BATCH_SIZE, N_EPOCH, SAVE_PATH, pad_idx, criterion, device and model are defined
# elsewhere.
train_iter = MyIterator(train,
                        batch_size=BATCH_SIZE,
                        device=device,
                        repeat=False,
                        sort_key=lambda x: (len(x.src), len(x.trg)),
                        batch_size_fn=batch_size_fn,
                        train=True)
valid_iter = MyIterator(val,
                        batch_size=BATCH_SIZE,
                        device=device,
                        repeat=False,
                        sort_key=lambda x: (len(x.src), len(x.trg)),
                        batch_size_fn=batch_size_fn,
                        train=False)

model_opt = NoamOpt(
    model.src_embed[0].d_model, 1, 2000,
    torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98),
                     eps=1e-9))

for epoch in range(N_EPOCH):
    # Training pass: the optimizer is driven through LossCompute (model_opt).
    model.train()
    run_epoch((rebatch(pad_idx, b) for b in train_iter), model,
              LossCompute(model.generator, criterion, model_opt))

    # Validation pass: no optimizer, only the loss.
    model.eval()
    loss = run_epoch((rebatch(pad_idx, b) for b in valid_iter), model,
                     LossCompute(model.generator, criterion, None))

torch.save(model.state_dict(), SAVE_PATH)
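
# train() below builds LabelSmoothing(size=..., padding_idx=..., smoothing=0.1)
# criteria. That class is project code not shown in this example; as a hedged
# reference only, this is a minimal KL-divergence formulation of label smoothing in
# the style popularized by the Annotated Transformer. The project's actual
# implementation may differ (e.g. its vocab objects expose a .len() method).
import torch
import torch.nn as nn


class LabelSmoothingSketch(nn.Module):
    """Builds a smoothed target distribution and compares it with log-probabilities."""

    def __init__(self, size, padding_idx, smoothing=0.1):
        super().__init__()
        self.criterion = nn.KLDivLoss(reduction='sum')
        self.size = size                  # target vocabulary size
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing

    def forward(self, log_probs, target):
        # log_probs: (batch, size) log-probabilities; target: (batch,) gold indices
        true_dist = torch.full_like(log_probs, self.smoothing / (self.size - 2))
        true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        true_dist[:, self.padding_idx] = 0
        pad_rows = torch.nonzero(target == self.padding_idx, as_tuple=False)
        if pad_rows.numel() > 0:
            true_dist.index_fill_(0, pad_rows.squeeze(1), 0.0)
        return self.criterion(log_probs, true_dist)
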
def train(args):

    set_seed(args.seed)

    device = torch.device('cuda' if torch.cuda.is_available() and args.gpu else 'cpu')

    batch_size = args.batch_size
    max_length = args.max_length
    mtl = args.mtl

    # Default learning rate; use the value from the command line when one is provided
    learning_rate = 0.0005
    if args.learning_rate:
        learning_rate = args.learning_rate

    if len(args.train_source) != len(args.train_target):
        print("Error: the number of train source files does not match the number of train target files")
        return

    if len(args.dev_source) != len(args.dev_target):
        print("Error: the number of dev source files does not match the number of dev target files")
        return

    if not args.tie_embeddings:
        print("Building Encoder vocabulary")
        source_vocabs = build_vocab(args.train_source, args.src_vocab, save_dir=args.save_dir)
        print("Building Decoder vocabulary")
        target_vocabs = build_vocab(args.train_target, args.tgt_vocab, mtl=mtl, name="tgt", save_dir=args.save_dir)
    else:
        print("Building Share vocabulary")
        source_vocabs = build_vocab(args.train_source + args.train_target, args.src_vocab, name="tied", save_dir=args.save_dir)
        if mtl:
            target_vocabs = [source_vocabs[0] for _ in range(len(args.train_target))]
        else:
            target_vocabs = source_vocabs
    print("Number of source vocabularies:", len(source_vocabs))
    print("Number of target vocabularies:", len(target_vocabs))

    save_params(args, args.save_dir + "args.json")

    # source_vocabs, target_vocabs = build_vocab(args.train_source, args.train_target, mtl=mtl)

    print("Building training set and dataloaders")
    train_loaders = build_dataset(args.train_source, args.train_target, batch_size, \
            source_vocabs=source_vocabs, target_vocabs=target_vocabs, shuffle=True, mtl=mtl, max_length=max_length)
    for train_loader in train_loaders:
        print(f'Train - {len(train_loader):d} batches with size: {batch_size:d}')

    print("Building dev set and dataloaders")
    dev_loaders = build_dataset(args.dev_source, args.dev_target, batch_size, \
            source_vocabs=source_vocabs, target_vocabs=target_vocabs, mtl=mtl, max_length=max_length)
    for dev_loader in dev_loaders:
        print(f'Dev - {len(dev_loader):d} batches with size: {batch_size:d}')

    if args.model is not None:
        print("Loading the encoder from an external model...")
        multitask_model = load_model(args, source_vocabs, target_vocabs, device, max_length)
    else:
        print("Building model")
        multitask_model = build_model(args, source_vocabs, target_vocabs, device, max_length)

    print(f'The Transformer has {count_parameters(multitask_model):,} trainable parameters')
    print(f'The Encoder has {count_parameters(multitask_model.encoder):,} trainable parameters')
    for index, decoder in enumerate(multitask_model.decoders):
        print(f'The Decoder {index+1} has {count_parameters(decoder):,} trainable parameters')


    # Defining CrossEntropyLoss as default
    #criterion = nn.CrossEntropyLoss(ignore_index = constants.PAD_IDX)
    criterions = [LabelSmoothing(size=target_vocab.len(), padding_idx=constants.PAD_IDX, smoothing=0.1) \
                                        for target_vocab in target_vocabs]

    # Default optimizer
    optimizer = torch.optim.Adam(multitask_model.parameters(), lr=learning_rate, betas=(0.9, 0.98), eps=1e-9)
    model_opts = [NoamOpt(args.hidden_size, args.warmup_steps, optimizer) for _ in target_vocabs]
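    # Note: each task gets its own NoamOpt wrapper, but all of them share the single
    # Adam optimizer instance created above.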

    task_id = 0
    print_loss_total = 0  # Reset every print_every

    n_tasks = len(train_loaders)
    best_valid_loss = [0.0 for _ in range(n_tasks)]  # actually tracks the best accuracy per task

    if not args.translate:
        print("Start training...")
        patience = 30
        if args.patience:
            patience = args.patience

        if n_tasks > 1:
            print("Patience is not taken into account in multitask learning")

        for _iter in range(1, args.steps + 1):

            train_loss = train_step(multitask_model, train_loaders[task_id], \
                       LossCompute(criterions[task_id], model_opts[task_id]), device, task_id = task_id)

            print_loss_total += train_loss

            if _iter % args.print_every == 0:
                print_loss_avg = print_loss_total / args.print_every
                print_loss_total = 0
                print(f'Task: {task_id:d} | Step: {_iter:d} | Train Loss: {print_loss_avg:.3f} | Train PPL: {math.exp(print_loss_avg):7.3f}')


            if _iter % args.eval_steps == 0:
                print("Evaluating...")
                accuracies = run_evaluation(multitask_model, source_vocabs[0], target_vocabs, device, args.beam_size, args.eval, args.eval_ref, max_length)
                accuracy = round(accuracies[task_id], 3)
                valid_loss = evaluate(multitask_model, dev_loaders[task_id], LossCompute(criterions[task_id], None), \
                                device, task_id=task_id)
                print(f'Task: {task_id:d} | Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f} | Acc. {accuracy:.3f}')
                if accuracy > best_valid_loss[task_id]:
                    print(f'The accuracy increased from {best_valid_loss[task_id]:.3f} to {accuracy:.3f} for task {task_id}... saving checkpoint')
                    patience = args.patience if args.patience else 30
                    best_valid_loss[task_id] = accuracy
                    torch.save(multitask_model.state_dict(), args.save_dir + 'model.pt')
                    print("Saved model.pt")
                else:
                    if n_tasks == 1:
                        if patience == 0:
                            break
                        else:
                            patience -= 1

                if n_tasks > 1:
                    print("Changing to the next task ...")
                    task_id = (0 if task_id == n_tasks - 1 else task_id + 1)

    try:
        multitask_model.load_state_dict(torch.load(args.save_dir + 'model.pt'))
    except FileNotFoundError:
        print(f'There is no model in {args.save_dir}')
        return

    print("Evaluating and testing")
    run_translate(multitask_model, source_vocabs[0], target_vocabs, args.save_dir, device, args.beam_size, args.eval, max_length=max_length)
    run_translate(multitask_model, source_vocabs[0], target_vocabs, args.save_dir, device, args.beam_size, args.test, max_length=max_length)
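

# train() pulls its configuration from an argparse Namespace. The real parser
# presumably lives elsewhere in the repository; the sketch below only mirrors the
# attribute names read above, and every type and default here is a guess, not the
# project's actual interface.
import argparse


def parse_args():
    parser = argparse.ArgumentParser(description="Multitask Transformer training")
    parser.add_argument("--train_source", nargs="+", default=[])
    parser.add_argument("--train_target", nargs="+", default=[])
    parser.add_argument("--dev_source", nargs="+", default=[])
    parser.add_argument("--dev_target", nargs="+", default=[])
    parser.add_argument("--save_dir", default="checkpoints/")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--max_length", type=int, default=180)
    parser.add_argument("--learning_rate", type=float, default=None)
    parser.add_argument("--steps", type=int, default=100000)
    parser.add_argument("--print_every", type=int, default=100)
    parser.add_argument("--eval_steps", type=int, default=1000)
    parser.add_argument("--patience", type=int, default=None)
    parser.add_argument("--beam_size", type=int, default=5)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--gpu", action="store_true")
    parser.add_argument("--mtl", action="store_true")
    parser.add_argument("--tie_embeddings", action="store_true")
    parser.add_argument("--translate", action="store_true")
    # ... plus src_vocab, tgt_vocab, hidden_size, warmup_steps, model, eval,
    # eval_ref and test, whose exact types are not visible in this excerpt.
    return parser.parse_args()


if __name__ == "__main__":
    train(parse_args())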