Example #1
def run():

    # load dataset
    src_data_loader = get_data_loader(params.src_dataset)
    tgt_data_loader = get_data_loader(params.tgt_dataset)

    # load models
    src_encoder = init_model(net=LeNetEncoder(),
                             restore=params.src_encoder_restore)
    tgt_encoder = init_model(net=LeNetEncoder(),
                             restore=params.tgt_encoder_restore)
    critic = init_model(Discriminator(input_dims=params.d_input_dims,
                                      hidden_dims=params.d_hidden_dims,
                                      output_dims=params.d_output_dims),
                        restore=params.d_model_restore)

    # Adapt target encoder by GAN
    print("=== Training encoder for target domain ===")
    print(">>> Target Encoder <<<")
    im, _ = next(iter(tgt_data_loader))
    summary(tgt_encoder, input_size=im[0].size())
    print(">>> Critic <<<")
    print(critic)

    # init weights of target encoder with those of source encoder
    if not tgt_encoder.restored:
        tgt_encoder.load_state_dict(src_encoder.state_dict())

    # Train target
    if not (tgt_encoder.restored and critic.restored
            and params.tgt_model_trained):
        tgt_encoder = train_tgt(src_encoder, tgt_encoder, critic,
                                src_data_loader, tgt_data_loader)
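
All of the snippets on this page rely on an init_model helper that is never shown. The following is a rough, hypothetical sketch of what such a helper typically does in ADDA-style code (restore weights from a snapshot if one exists, flag the model as restored, move it to the GPU); the actual implementations in these repositories may differ:

# Hypothetical sketch of the init_model helper used above (names and
# behaviour are assumptions, not taken from any of these repositories).
import os

import torch

def init_model(net, restore=None):
    """Restore weights from a snapshot if it exists and flag the model."""
    net.restored = False
    if restore is not None and os.path.exists(restore):
        # load previously saved parameters and mark the model as restored
        net.load_state_dict(torch.load(restore))
        net.restored = True
    # move the model to the GPU when one is available
    if torch.cuda.is_available():
        net = net.cuda()
    return net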
Example #2
def office():
    init_random_seed(params.manual_seed)

    # load dataset
    src_data_loader = get_data_loader(params.src_dataset)
    src_data_loader_eval = get_data_loader(params.src_dataset, train=False)
    tgt_data_loader = get_data_loader(params.tgt_dataset)
    tgt_data_loader_eval = get_data_loader(params.tgt_dataset, train=False)

    # load models
    src_encoder = init_model(net=LeNetEncoder(),
                             restore=params.src_encoder_restore)
    src_classifier = init_model(net=LeNetClassifier(),
                                restore=params.src_classifier_restore)
    tgt_encoder = init_model(net=LeNetEncoder(),
                             restore=params.tgt_encoder_restore)
    critic = init_model(Discriminator(input_dims=params.d_input_dims,
                                      hidden_dims=params.d_hidden_dims,
                                      output_dims=params.d_output_dims),
                        restore=params.d_model_restore)

    # train source model
    if not (src_encoder.restored and src_classifier.restored and
            params.src_model_trained):
        src_encoder, src_classifier = train_src(
            src_encoder, src_classifier, src_data_loader)

    # eval source model
    # print("=== Evaluating classifier for source domain ===")
    # eval_src(src_encoder, src_classifier, src_data_loader_eval)

    # train target encoder by GAN

    # init weights of target encoder with those of source encoder
    if not tgt_encoder.restored:
        tgt_encoder.load_state_dict(src_encoder.state_dict())

    if not (tgt_encoder.restored and critic.restored and
            params.tgt_model_trained):
        tgt_encoder = train_tgt(src_encoder, tgt_encoder, critic,
                                src_data_loader, tgt_data_loader)

    # eval target encoder on test set of target dataset
    print(">>> domain adaption <<<")
    acc = eval_tgt(tgt_encoder, src_classifier, tgt_data_loader_eval)
    return acc
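
Example #2 returns the accuracy computed by eval_tgt, whose body is not shown either. A minimal sketch, assuming the usual pattern of scoring the source classifier on features produced by the encoder passed in, might look like this; the helper's internals are assumptions:

# Hypothetical sketch of the eval_tgt helper returned from above; the
# real implementation may report additional metrics such as the loss.
import torch

def eval_tgt(encoder, classifier, data_loader):
    """Return classification accuracy of classifier on encoder features."""
    encoder.eval()
    classifier.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in data_loader:
            if torch.cuda.is_available():
                images, labels = images.cuda(), labels.cuda()
            preds = classifier(encoder(images))
            correct += (preds.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    acc = correct / total
    print("Avg Accuracy = {:.2%}".format(acc))
    return acc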
Example #3
def experiments(exp):

    #print(exp, case, affine, num_epochs)

    # init random seed
    #params.d_learning_rate = lr_d
    #params.c_learning_rate = lr_c
    init_random_seed(params.manual_seed)

    # load dataset
    src_dataset, tgt_dataset = exp.split('_')
    src_data_loader = get_data_loader(src_dataset)
    src_data_loader_eval = get_data_loader(src_dataset, train=False)

    tgt_data_loader = get_data_loader(tgt_dataset)
    tgt_data_loader_eval = get_data_loader(tgt_dataset, train=False)

    # load models
    src_encoder = init_model(net=LeNetEncoder(),
                             restore=params.src_encoder_restore,
                             exp=exp)
    src_classifier = init_model(net=LeNetClassifier(),
                                restore=params.src_classifier_restore,
                                exp=exp)
    tgt_encoder = init_model(net=LeNetEncoder(),
                             restore=params.tgt_encoder_restore,
                             exp=exp)
    critic = init_model(Discriminator(input_dims=params.d_input_dims,
                                      hidden_dims=params.d_hidden_dims,
                                      output_dims=params.d_output_dims),
                        exp=exp,
                        restore=params.d_model_restore)

    # train source model
    print("=== Training classifier for source domain ===")
    print(">>> Source Encoder <<<")
    print(src_encoder)
    print(">>> Source Classifier <<<")
    print(src_classifier)

    if not (src_encoder.restored and src_classifier.restored
            and params.src_model_trained):
        src_encoder, src_classifier = train_src(exp, src_encoder,
                                                src_classifier,
                                                src_data_loader,
                                                src_data_loader_eval)

    # eval source model
    print("=== Evaluating classifier for source domain ===")
    evaluation(src_encoder, src_classifier, src_data_loader_eval)

    # train target encoder by GAN
    print("=== Training encoder for target domain ===")
    print(">>> Target Encoder <<<")
    print(tgt_encoder)
    print(">>> Critic <<<")
    print(critic)

    # init weights of target encoder with those of source encoder
    if not tgt_encoder.restored:
        tgt_encoder.load_state_dict(src_encoder.state_dict())

    if not (tgt_encoder.restored and critic.restored
            and params.tgt_model_trained):
        tgt_encoder = train_tgt(exp, src_encoder, tgt_encoder, critic,
                                src_classifier, src_data_loader,
                                tgt_data_loader, tgt_data_loader_eval)

    # eval target encoder on test set of target dataset
    print("=== Evaluating classifier for encoded target domain ===")
    print(">>> source only <<<")
    evaluation(src_encoder, src_classifier, tgt_data_loader_eval)
    print(">>> domain adaption <<<")
    evaluation(tgt_encoder, src_classifier, tgt_data_loader_eval)
Example #4
            and params.src_model_trained):
        src_encoder, src_classifier = train_src(src_encoder, src_classifier,
                                                src_data_loader)

    # eval source model
    # print("=== Evaluating classifier for source domain ===")
    # eval_src(src_encoder, src_classifier, src_data_loader_eval)

    # train target encoder by GAN
    print("=== Training encoder for target domain ===")
    print(">>> Target Encoder <<<")
    print(tgt_encoder)
    print(">>> Critic <<<")
    print(critic)

    # init weights of target encoder with those of source encoder
    if not tgt_encoder.restored:
        tgt_encoder.load_state_dict(src_encoder.state_dict())

    if not (tgt_encoder.restored and critic.restored
            and params.tgt_model_trained):
        tgt_encoder = train_tgt(src_encoder, tgt_encoder, critic,
                                src_data_loader, tgt_data_loader)

    # eval target encoder on test set of target dataset
    print("=== Evaluating classifier for encoded target domain ===")
    # print(">>> source only <<<")
    # eval_tgt(src_encoder, src_classifier, tgt_data_loader_eval)
    print(">>> domain adaption <<<")
    eval_tgt(tgt_encoder, src_classifier, tgt_data_loader_eval)
Example #5
    if not (src_encoder.restored and src_classifier.restored
            and params.src_model_trained):
        src_encoder, src_classifier = train_src(src_encoder, src_classifier,
                                                src_data_loader, params)

    # eval source model
    print("=== Evaluating classifier for source domain ===")
    eval(src_encoder, src_classifier, src_data_loader)
    print("=== Evaluating classifier for target domain ===")
    eval(src_encoder, src_classifier, tgt_data_loader)

    # train target encoder by GAN
    print("=== Training encoder for target domain ===")

    # init weights of target encoder with those of source encoder
    if not tgt_encoder.restored:
        tgt_encoder.load_state_dict(src_encoder.state_dict())

    if not (tgt_encoder.restored and critic.restored
            and params.tgt_model_trained):
        tgt_encoder = train_tgt(src_encoder, src_classifier, tgt_encoder,
                                critic, src_data_loader, tgt_data_loader,
                                params)

    # eval target encoder on test set of target dataset
    print("=== Evaluating classifier for encoded target domain ===")
    print(">>> source only <<<")
    eval(src_encoder, src_classifier, tgt_data_loader)
    print(">>> domain adaption <<<")
    eval(tgt_encoder, src_classifier, tgt_data_loader)
Example #6
    print("=== Training encoder for target domain ===")
    print(">>> Target Encoder <<<")
    print(tgt_encoder)
    print(">>> Critic <<<")
    print(critic)
    print(">>> Generator <<<")
    print(generator)
    print(">>> Discriminator <<<")
    print(discriminator)

    # init weights of target encoder with those of source encoder
    if not tgt_encoder.restored:
        tgt_encoder.load_state_dict(src_encoder.state_dict())

    if not (tgt_encoder.restored and critic.restored
            and cfg.tgt_model_trained):
        tgt_encoder, tgt_classifier = train_tgt(
            src_encoder, tgt_encoder, critic, src_data_loader, tgt_data_loader,
            src_classifier, tgt_classifier, tgt_data_loader_eval, generator,
            discriminator, Saver, logger)

    # eval target encoder on test set of target dataset
    print("=== Evaluating classifier for encoded target domain ===")
    print(">>> source only <<<")
    eval_func(src_encoder, src_classifier, tgt_data_loader_eval)
    print(">>> domain adaption <<<")
    tgt_classifier = init_model(net=SythnetClassifier(nf=cfg.d_input_dims,
                                                      ncls=cfg.ncls),
                                restore=cfg.tgt_classifier_restore)
    eval_func(tgt_encoder, tgt_classifier, tgt_data_loader_eval)
Example #7
    # Train target encoder by GAN
    print("=== Training encoder for target domain ===")
    print(">>> Target Encoder <<<")
    print(tgt_encoder)
    print(">>> Discriminator <<<")
    print(discriminator)

    # init weights of target encoder with those of source encoder
    if not tgt_encoder.restored:
        print(
            "[main.py] INFO | No trained target encoder found, initialising target encoder with trained source encoder weights.."
        )
        tgt_encoder.load_state_dict(src_encoder.state_dict())

    if not (tgt_encoder.restored and discriminator.restored
            and params.tgt_model_trained):
        print(
            "[main.py] INFO | No trained target encoder found, beginning adversarial training.."
        )
        tgt_encoder = train_tgt(src_encoder, tgt_encoder, discriminator,
                                src_data_loader, tgt_data_loader,
                                src_classifier, tgt_data_loader_eval)
        # src_data_loader, tgt_data_loader_eval,src_classifier,tgt_data_loader_eval)

    # Eval target encoder on test set of target dataset
    print("=== Evaluating classifier for encoded target domain ===")
    print(">>> source only <<<")
    _ = eval_src(src_encoder, src_classifier, tgt_data_loader_eval)
    print(">>> domain adaption <<<")
    _ = eval_src(tgt_encoder, src_classifier, tgt_data_loader_eval)
Example #8
            src_encoder,
            src_classifier,
            src_data_loader, dataset_name="EMOTION")

    # eval source model
    print("=== Evaluating classifier for source domain ===")
    eval_src(src_encoder, src_classifier, src_data_loader_eval)

    # train target encoder by GAN
    print("=== Training encoder for target domain ===")
    print(">>> Target Encoder <<<")
    print(tgt_encoder)
    print(">>> Critic <<<")
    print(critic)

    # init weights of target encoder with those of source encoder
    if not tgt_encoder.restored:
        tgt_encoder.load_state_dict(src_encoder.state_dict())

    if not (tgt_encoder.restored and critic.restored and
            params.tgt_model_trained):
        tgt_encoder = train_tgt(src_encoder, tgt_encoder, critic,
                                src_data_loader, tgt_data_loader, dataset_name='CONFLICT')

    # eval target encoder on test set of target dataset
    print("=== Evaluating classifier for encoded target domain ===")
    print(">>> source only <<<")
    eval_tgt(src_encoder, src_classifier, tgt_data_loader_eval)
    print(">>> domain adaption <<<")
    eval_tgt(tgt_encoder, src_classifier, tgt_data_loader_eval)
Example #9
    # train target encoder by GAN
    print("=== Training encoder for target domain ===")
    print(">>> Target Encoder <<<")
    print(tgt_encoder)
    print(">>> Critic <<<")
    print(critic)

    # init weights of target encoder with those of source encoder
    if not tgt_encoder.restored:
        tgt_encoder.load_state_dict(src_encoder.state_dict())

    if not (tgt_encoder.restored and critic.restored
            and params.tgt_model_trained):
        tgt_encoder = train_tgt(src_encoder,
                                tgt_encoder,
                                critic,
                                src_data_loader,
                                tgt_data_loader,
                                dataset_name=params.tgt_dataset)

    tgt_encoder, tgt_classifier = train_tgt_classifier(tgt_encoder,
                                                       tgt_classifier,
                                                       tgt_data_loader)

    # eval target encoder on test set of target dataset
    print("=== Evaluating classifier for encoded target domain ===")
    print(">>> source only <<<")
    eval_tgt(src_encoder, src_classifier, tgt_data_loader_eval)
    print(">>> domain adaption <<<")
    eval_tgt(tgt_encoder, tgt_classifier, tgt_data_loader_eval)
    print(">>> enhanced domain adaptation<<<")
    eval_tgt_with_probe(tgt_encoder, critic, src_classifier, tgt_classifier,
Example #10
def main():
    args = get_arguments()

    # init random seed
    init_random_seed(manual_seed)

    src_data_loader, src_data_loader_eval, tgt_data_loader, tgt_data_loader_eval = get_dataset(args)

    # argument setting
    print("=== Argument Setting ===")
    print("src: " + args.src)
    print("tgt: " + args.tgt)
    print("patience: " + str(args.patience))
    print("num_epochs_pre: " + str(args.num_epochs_pre))
    print("eval_step_pre: " + str(args.eval_step_pre))
    print("save_step_pre: " + str(args.save_step_pre))
    print("num_epochs: " + str(args.num_epochs))
    print("src encoder lr: " + str(args.lr))
    print("tgt encoder lr: " + str(args.t_lr))
    print("critic lr: " + str(args.c_lr))
    print("batch_size: " + str(args.batch_size))

    # load models
    src_encoder_restore = "snapshots/src-encoder-adda-{}.pt".format(args.src)
    src_classifier_restore = "snapshots/src-classifier-adda-{}.pt".format(args.src)
    tgt_encoder_restore = "snapshots/tgt-encoder-adda-{}.pt".format(args.src)
    d_model_restore = "snapshots/critic-adda-{}.pt".format(args.src)
    src_encoder = init_model(BERTEncoder(),
                             restore=src_encoder_restore)
    src_classifier = init_model(BERTClassifier(),
                                restore=src_classifier_restore)
    tgt_encoder = init_model(BERTEncoder(),
                             restore=tgt_encoder_restore)
    critic = init_model(Discriminator(),
                        restore=d_model_restore)

    # no, fine-tune BERT
    # if not args.enc_train:
    #     for param in src_encoder.parameters():
    #         param.requires_grad = False

    if torch.cuda.device_count() > 1:
        print('Let\'s use {} GPUs!'.format(torch.cuda.device_count()))
        src_encoder = nn.DataParallel(src_encoder)
        src_classifier = nn.DataParallel(src_classifier)
        tgt_encoder = nn.DataParallel(tgt_encoder)
        critic = nn.DataParallel(critic)

    # train source model
    print("=== Training classifier for source domain ===")
    src_encoder, src_classifier = train_src(
        args, src_encoder, src_classifier, src_data_loader, src_data_loader_eval)

    # eval source model
    print("=== Evaluating classifier for source domain ===")
    eval_src(src_encoder, src_classifier, src_data_loader_eval)

    # train target encoder by GAN
    print("=== Training encoder for target domain ===")
    if not (tgt_encoder.module.restored and critic.module.restored and
            tgt_model_trained):
        tgt_encoder = train_tgt(args, src_encoder, tgt_encoder, critic,
                                src_data_loader, tgt_data_loader)

    # eval target encoder on test set of target dataset
    print("Evaluate tgt test data on src encoder: {}".format(args.tgt))
    eval_tgt(src_encoder, src_classifier, tgt_data_loader_eval)
    print("Evaluate tgt test data on tgt encoder: {}".format(args.tgt))
    eval_tgt(tgt_encoder, src_classifier, tgt_data_loader_eval)
Example #11
        tgt_encoder.load_state_dict(src_encoder.state_dict())

        # freeze target encoder params
        for params in tgt_encoder.parameters():
            params.requires_grad = False
        if torch.cuda.device_count() > 1:
            for params in tgt_encoder.module.encoder.embeddings.parameters():
                params.requires_grad = True
        else:
            for params in tgt_encoder.encoder.embeddings.parameters():
                params.requires_grad = True

        # train target encoder by GAN
        print("=== Training encoder for target domain ===")
        tgt_encoder = train_tgt(args, src_encoder, tgt_encoder, critic,
                                src_classifier, src_data_loader,
                                tgt_data_loader)

        # eval target encoder on lambda0.1 set of target dataset
        print("=== Evaluating classifier for encoded target domain ===")
        print(">>> source only <<<")
        src_tgt = eval_tgt(src_encoder, src_classifier, tgt_data_loader)
        print(">>> domain adaption <<<")
        tgt_tgt = eval_tgt(tgt_encoder, src_classifier, tgt_data_loader)

        worksheet.write(fold_index + 1, 0, fold_index + 1)
        worksheet.write(fold_index + 1, 1, src_src)
        worksheet.write(fold_index + 1, 2, src_tgt)
        worksheet.write(fold_index + 1, 3, tgt_tgt)
        src_src_stack.append(src_src)
        src_tgt_stack.append(src_tgt)
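
Every example above delegates the adversarial adaptation step to a train_tgt function whose body is not shown. The sketch below illustrates the ADDA-style loop those calls presumably wrap: the critic learns to distinguish source features from target features, while the target encoder (initialised from the source encoder) learns to fool it. This is an illustrative sketch under assumed optimiser settings and a two-class critic, not the implementation from any repository above:

# Minimal ADDA-style adversarial loop: a frozen source encoder, a trainable
# target encoder and a binary domain critic. Hyperparameters and helper
# names are assumptions made for illustration.
import torch
import torch.nn as nn
import torch.optim as optim

def train_tgt(src_encoder, tgt_encoder, critic,
              src_data_loader, tgt_data_loader, num_epochs=100, lr=1e-4):
    tgt_encoder.train()
    critic.train()
    criterion = nn.CrossEntropyLoss()
    opt_tgt = optim.Adam(tgt_encoder.parameters(), lr=lr)
    opt_critic = optim.Adam(critic.parameters(), lr=lr)

    for epoch in range(num_epochs):
        for (src_images, _), (tgt_images, _) in zip(src_data_loader,
                                                    tgt_data_loader):
            if torch.cuda.is_available():
                src_images, tgt_images = src_images.cuda(), tgt_images.cuda()

            # 1) train the critic to tell source features from target features
            opt_critic.zero_grad()
            feat_src = src_encoder(src_images).detach()  # source encoder frozen
            feat_tgt = tgt_encoder(tgt_images).detach()
            feat_concat = torch.cat((feat_src, feat_tgt), dim=0)
            domain_pred = critic(feat_concat)
            domain_label = torch.cat((torch.ones(feat_src.size(0)),
                                      torch.zeros(feat_tgt.size(0)))).long()
            if torch.cuda.is_available():
                domain_label = domain_label.cuda()
            loss_critic = criterion(domain_pred, domain_label)
            loss_critic.backward()
            opt_critic.step()

            # 2) train the target encoder to fool the critic (flipped labels)
            opt_tgt.zero_grad()
            feat_tgt = tgt_encoder(tgt_images)
            domain_pred = critic(feat_tgt)
            flipped_label = torch.ones(feat_tgt.size(0)).long()
            if torch.cuda.is_available():
                flipped_label = flipped_label.cuda()
            loss_tgt = criterion(domain_pred, flipped_label)
            loss_tgt.backward()
            opt_tgt.step()

    return tgt_encoder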