def run():
    """Adapt a target-domain encoder to the source domain via ADDA.

    Loads the source/target data loaders and the (possibly snapshot-restored)
    encoders and critic, initialises the target encoder from the source
    encoder's weights when no snapshot was restored, and runs adversarial
    training of the target encoder against the critic.

    All configuration comes from the module-level ``params`` object.
    """
    # load dataset
    src_data_loader = get_data_loader(params.src_dataset)
    tgt_data_loader = get_data_loader(params.tgt_dataset)

    # load models
    src_encoder = init_model(net=LeNetEncoder(),
                             restore=params.src_encoder_restore)
    tgt_encoder = init_model(net=LeNetEncoder(),
                             restore=params.tgt_encoder_restore)
    critic = init_model(Discriminator(input_dims=params.d_input_dims,
                                      hidden_dims=params.d_hidden_dims,
                                      output_dims=params.d_output_dims),
                        restore=params.d_model_restore)

    # Adapt target encoder by GAN
    print("=== Training encoder for target domain ===")
    print(">>> Target Encoder <<<")
    # peek one batch only to obtain the per-sample input size for summary()
    im, _ = next(iter(tgt_data_loader))
    summary(tgt_encoder, input_size=im[0].size())
    print(">>> Critic <<<")
    print(critic)

    # init weights of target encoder with those of source encoder
    if not tgt_encoder.restored:
        tgt_encoder.load_state_dict(src_encoder.state_dict())

    # Train target (skipped only when both snapshots were restored AND the
    # params flag says the target model is already trained)
    if not (tgt_encoder.restored and critic.restored
            and params.tgt_model_trained):
        tgt_encoder = train_tgt(src_encoder, tgt_encoder, critic,
                                src_data_loader, tgt_data_loader)
def office():
    """Run the full ADDA pipeline for the dataset pair configured in ``params``.

    Trains (or restores) the source encoder/classifier, adapts the target
    encoder adversarially against the critic, and evaluates the adapted
    encoder with the source classifier on the target test split.

    Returns:
        Accuracy reported by ``eval_tgt`` on the target evaluation loader.
    """
    init_random_seed(params.manual_seed)

    # load dataset
    src_data_loader = get_data_loader(params.src_dataset)
    src_data_loader_eval = get_data_loader(params.src_dataset, train=False)
    tgt_data_loader = get_data_loader(params.tgt_dataset)
    tgt_data_loader_eval = get_data_loader(params.tgt_dataset, train=False)

    # load models
    src_encoder = init_model(net=LeNetEncoder(),
                             restore=params.src_encoder_restore)
    src_classifier = init_model(net=LeNetClassifier(),
                                restore=params.src_classifier_restore)
    tgt_encoder = init_model(net=LeNetEncoder(),
                             restore=params.tgt_encoder_restore)
    critic = init_model(Discriminator(input_dims=params.d_input_dims,
                                      hidden_dims=params.d_hidden_dims,
                                      output_dims=params.d_output_dims),
                        restore=params.d_model_restore)

    # train source model unless a fully-trained snapshot was restored
    if not (src_encoder.restored and src_classifier.restored
            and params.src_model_trained):
        src_encoder, src_classifier = train_src(
            src_encoder, src_classifier, src_data_loader)

    # eval source model
    # print("=== Evaluating classifier for source domain ===")
    # eval_src(src_encoder, src_classifier, src_data_loader_eval)

    # train target encoder by GAN
    # init weights of target encoder with those of source encoder
    if not tgt_encoder.restored:
        tgt_encoder.load_state_dict(src_encoder.state_dict())
    if not (tgt_encoder.restored and critic.restored
            and params.tgt_model_trained):
        tgt_encoder = train_tgt(src_encoder, tgt_encoder, critic,
                                src_data_loader, tgt_data_loader)

    # eval target encoder on test set of target dataset
    print(">>> domain adaption <<<")
    acc = eval_tgt(tgt_encoder, src_classifier, tgt_data_loader_eval)
    return acc
def experiments(exp):
    """Run the full ADDA pipeline for one source-to-target experiment.

    Args:
        exp: Experiment identifier of the form ``"<src>_<tgt>"``; the two
            dataset names are recovered by splitting on ``'_'``. ``exp`` is
            also threaded through ``init_model``/``train_src``/``train_tgt``
            so each experiment keeps separate snapshots.
    """
    # print(exp, case, affine, num_epochs)
    # init random seed
    # params.d_learning_rate = lr_d
    # params.c_learning_rate = lr_c
    init_random_seed(params.manual_seed)

    # load dataset
    src_dataset, tgt_dataset = exp.split('_')
    src_data_loader = get_data_loader(src_dataset)
    src_data_loader_eval = get_data_loader(src_dataset, train=False)
    tgt_data_loader = get_data_loader(tgt_dataset)
    tgt_data_loader_eval = get_data_loader(tgt_dataset, train=False)

    # load models
    src_encoder = init_model(net=LeNetEncoder(),
                             restore=params.src_encoder_restore, exp=exp)
    src_classifier = init_model(net=LeNetClassifier(),
                                restore=params.src_classifier_restore, exp=exp)
    tgt_encoder = init_model(net=LeNetEncoder(),
                             restore=params.tgt_encoder_restore, exp=exp)
    critic = init_model(Discriminator(input_dims=params.d_input_dims,
                                      hidden_dims=params.d_hidden_dims,
                                      output_dims=params.d_output_dims),
                        exp=exp, restore=params.d_model_restore)

    # train source model
    print("=== Training classifier for source domain ===")
    print(">>> Source Encoder <<<")
    print(src_encoder)
    print(">>> Source Classifier <<<")
    print(src_classifier)
    if not (src_encoder.restored and src_classifier.restored
            and params.src_model_trained):
        src_encoder, src_classifier = train_src(exp, src_encoder,
                                                src_classifier,
                                                src_data_loader,
                                                src_data_loader_eval)

    # eval source model
    print("=== Evaluating classifier for source domain ===")
    evaluation(src_encoder, src_classifier, src_data_loader_eval)

    # train target encoder by GAN
    print("=== Training encoder for target domain ===")
    print(">>> Target Encoder <<<")
    print(tgt_encoder)
    print(">>> Critic <<<")
    print(critic)

    # init weights of target encoder with those of source encoder
    if not tgt_encoder.restored:
        tgt_encoder.load_state_dict(src_encoder.state_dict())
    if not (tgt_encoder.restored and critic.restored
            and params.tgt_model_trained):
        tgt_encoder = train_tgt(exp, src_encoder, tgt_encoder, critic,
                                src_classifier, src_data_loader,
                                tgt_data_loader, tgt_data_loader_eval)

    # eval target encoder on test set of target dataset
    print("=== Evaluating classifier for encoded target domain ===")
    print(">>> source only <<<")
    evaluation(src_encoder, src_classifier, tgt_data_loader_eval)
    print(">>> domain adaption <<<")
    evaluation(tgt_encoder, src_classifier, tgt_data_loader_eval)
and params.src_model_trained): src_encoder, src_classifier = train_src(src_encoder, src_classifier, src_data_loader) # eval source model # print("=== Evaluating classifier for source domain ===") # eval_src(src_encoder, src_classifier, src_data_loader_eval) # train target encoder by GAN print("=== Training encoder for target domain ===") print(">>> Target Encoder <<<") print(tgt_encoder) print(">>> Critic <<<") print(critic) # init weights of target encoder with those of source encoder if not tgt_encoder.restored: tgt_encoder.load_state_dict(src_encoder.state_dict()) if not (tgt_encoder.restored and critic.restored and params.tgt_model_trained): tgt_encoder = train_tgt(src_encoder, tgt_encoder, critic, src_data_loader, tgt_data_loader) # eval target encoder on test set of target dataset print("=== Evaluating classifier for encoded target domain ===") # print(">>> source only <<<") # eval_tgt(src_encoder, src_classifier, tgt_data_loader_eval) print(">>> domain adaption <<<") eval_tgt(tgt_encoder, src_classifier, tgt_data_loader_eval)
if not (src_encoder.restored and src_classifier.restored and params.src_model_trained): src_encoder, src_classifier = train_src(src_encoder, src_classifier, src_data_loader, params) # eval source model print("=== Evaluating classifier for source domain ===") eval(src_encoder, src_classifier, src_data_loader) print("=== Evaluating classifier for target domain ===") eval(src_encoder, src_classifier, tgt_data_loader) # train target encoder by GAN print("=== Training encoder for target domain ===") # init weights of target encoder with those of source encoder if not tgt_encoder.restored: tgt_encoder.load_state_dict(src_encoder.state_dict()) if not (tgt_encoder.restored and critic.restored and params.tgt_model_trained): tgt_encoder = train_tgt(src_encoder, src_classifier, tgt_encoder, critic, src_data_loader, tgt_data_loader, params) # eval target encoder on test set of target dataset print("=== Evaluating classifier for encoded target domain ===") print(">>> source only <<<") eval(src_encoder, src_classifier, tgt_data_loader) print(">>> domain adaption <<<") eval(tgt_encoder, src_classifier, tgt_data_loader)
print("=== Training encoder for target domain ===") print(">>> Target Encoder <<<") print(tgt_encoder) print(">>> Critic <<<") print(critic) print(">>> Generator <<<") print(generator) print(">>> Discriminator <<<") print(discriminator) # init weights of target encoder with those of source encoder if not tgt_encoder.restored: tgt_encoder.load_state_dict(src_encoder.state_dict()) if not (tgt_encoder.restored and critic.restored and cfg.tgt_model_trained): tgt_encoder, tgt_classifier = train_tgt( src_encoder, tgt_encoder, critic, src_data_loader, tgt_data_loader, src_classifier, tgt_classifier, tgt_data_loader_eval, generator, discriminator, Saver, logger) # eval target encoder on test set of target dataset print("=== Evaluating classifier for encoded target domain ===") print(">>> source only <<<") eval_func(src_encoder, src_classifier, tgt_data_loader_eval) print(">>> domain adaption <<<") tgt_classifier = init_model(net=SythnetClassifier(nf=cfg.d_input_dims, ncls=cfg.ncls), restore=cfg.tgt_classifier_restore) eval_func(tgt_encoder, tgt_classifier, tgt_data_loader_eval)
# Train target encoder by GAN print("=== Training encoder for target domain ===") print(">>> Target Encoder <<<") print(tgt_encoder) print(">>> Discriminator <<<") print(discriminator) # init weights of target encoder with those of source encoder if not tgt_encoder.restored: print( "[main.py] INFO | No trained target encoder found, initialising target encoder with trained source encoder weights.." ) tgt_encoder.load_state_dict(src_encoder.state_dict()) if not (tgt_encoder.restored and discriminator.restored and params.tgt_model_trained): print( "[main.py] INFO | No trained target encoder found, beginning adverserial training.." ) tgt_encoder = train_tgt(src_encoder, tgt_encoder, discriminator, src_data_loader, tgt_data_loader, src_classifier, tgt_data_loader_eval) # src_data_loader, tgt_data_loader_eval,src_classifier,tgt_data_loader_eval) # Eval target encoder on test set of target dataset print("=== Evaluating classifier for encoded target domain ===") print(">>> source only <<<") _ = eval_src(src_encoder, src_classifier, tgt_data_loader_eval) print(">>> domain adaption <<<") _ = eval_src(tgt_encoder, src_classifier, tgt_data_loader_eval)
src_encoder, src_classifier, src_data_loader, dataset_name="EMOTION") # eval source model print("=== Evaluating classifier for source domain ===") eval_src(src_encoder, src_classifier, src_data_loader_eval) # train target encoder by GAN print("=== Training encoder for target domain ===") print(">>> Target Encoder <<<") print(tgt_encoder) print(">>> Critic <<<") print(critic) # init weights of target encoder with those of source encoder if not tgt_encoder.restored: tgt_encoder.load_state_dict(src_encoder.state_dict()) if not (tgt_encoder.restored and critic.restored and params.tgt_model_trained): tgt_encoder = train_tgt(src_encoder, tgt_encoder, critic, src_data_loader, tgt_data_loader, dataset_name='CONFLICT') # eval target encoder on test set of target dataset print("=== Evaluating classifier for encoded target domain ===") print(">>> source only <<<") eval_tgt(src_encoder, src_classifier, tgt_data_loader_eval) print(">>> domain adaption <<<") eval_tgt(tgt_encoder, src_classifier, tgt_data_loader_eval)
# train target encoder by GAN print("=== Training encoder for target domain ===") print(">>> Target Encoder <<<") print(tgt_encoder) print(">>> Critic <<<") print(critic) # init weights of target encoder with those of source encoder if not tgt_encoder.restored: tgt_encoder.load_state_dict(src_encoder.state_dict()) if not (tgt_encoder.restored and critic.restored and params.tgt_model_trained): tgt_encoder = train_tgt(src_encoder, tgt_encoder, critic, src_data_loader, tgt_data_loader, dataset_name=params.tgt_dataset) tgt_encoder, tgt_classifier = train_tgt_classifier(tgt_encoder, tgt_classifier, tgt_data_loader) # eval target encoder on test set of target dataset print("=== Evaluating classifier for encoded target domain ===") print(">>> source only <<<") eval_tgt(src_encoder, src_classifier, tgt_data_loader_eval) print(">>> domain adaption <<<") eval_tgt(tgt_encoder, tgt_classifier, tgt_data_loader_eval) print(">>> enhanced domain adaptation<<<") eval_tgt_with_probe(tgt_encoder, critic, src_classifier, tgt_classifier,
def main():
    """Entry point: train a source BERT model, adversarially adapt a target
    encoder with ADDA, and evaluate both encoders on the target test set.

    All hyper-parameters come from the command line via ``get_arguments``.
    """
    args = get_arguments()

    # init random seed
    init_random_seed(manual_seed)

    src_data_loader, src_data_loader_eval, tgt_data_loader, \
        tgt_data_loader_eval = get_dataset(args)

    # argument setting
    print("=== Argument Setting ===")
    print("src: " + args.src)
    print("tgt: " + args.tgt)
    print("patience: " + str(args.patience))
    print("num_epochs_pre: " + str(args.num_epochs_pre))
    print("eval_step_pre: " + str(args.eval_step_pre))
    print("save_step_pre: " + str(args.save_step_pre))
    print("num_epochs: " + str(args.num_epochs))
    print("src encoder lr: " + str(args.lr))
    print("tgt encoder lr: " + str(args.t_lr))
    print("critic lr: " + str(args.c_lr))
    print("batch_size: " + str(args.batch_size))

    # load models; snapshot paths are keyed on the source dataset name
    src_encoder_restore = "snapshots/src-encoder-adda-{}.pt".format(args.src)
    src_classifier_restore = "snapshots/src-classifier-adda-{}.pt".format(args.src)
    tgt_encoder_restore = "snapshots/tgt-encoder-adda-{}.pt".format(args.src)
    d_model_restore = "snapshots/critic-adda-{}.pt".format(args.src)
    src_encoder = init_model(BERTEncoder(), restore=src_encoder_restore)
    src_classifier = init_model(BERTClassifier(), restore=src_classifier_restore)
    tgt_encoder = init_model(BERTEncoder(), restore=tgt_encoder_restore)
    critic = init_model(Discriminator(), restore=d_model_restore)

    # no, fine-tune BERT
    # if not args.enc_train:
    #     for param in src_encoder.parameters():
    #         param.requires_grad = False

    if torch.cuda.device_count() > 1:
        print('Let\'s use {} GPUs!'.format(torch.cuda.device_count()))
        src_encoder = nn.DataParallel(src_encoder)
        src_classifier = nn.DataParallel(src_classifier)
        tgt_encoder = nn.DataParallel(tgt_encoder)
        critic = nn.DataParallel(critic)

    # train source model
    print("=== Training classifier for source domain ===")
    src_encoder, src_classifier = train_src(
        args, src_encoder, src_classifier, src_data_loader,
        src_data_loader_eval)

    # eval source model
    print("=== Evaluating classifier for source domain ===")
    eval_src(src_encoder, src_classifier, src_data_loader_eval)

    # train target encoder by GAN
    print("=== Training encoder for target domain ===")
    # NOTE(review): `tgt_model_trained` is referenced without a module prefix —
    # presumably a module-level flag imported elsewhere; confirm it is in
    # scope. Also, `.module` is only present when the models were wrapped in
    # nn.DataParallel above, so this line fails on single-GPU/CPU runs —
    # confirm intended usage.
    if not (tgt_encoder.module.restored and critic.module.restored
            and tgt_model_trained):
        tgt_encoder = train_tgt(args, src_encoder, tgt_encoder, critic,
                                src_data_loader, tgt_data_loader)

    # eval target encoder on test set of target dataset
    print("Evaluate tgt test data on src encoder: {}".format(args.tgt))
    eval_tgt(src_encoder, src_classifier, tgt_data_loader_eval)
    print("Evaluate tgt test data on tgt encoder: {}".format(args.tgt))
    eval_tgt(tgt_encoder, src_classifier, tgt_data_loader_eval)
tgt_encoder.load_state_dict(src_encoder.state_dict()) # freeze target encoder params for params in tgt_encoder.parameters(): params.requires_grad = False if torch.cuda.device_count() > 1: for params in tgt_encoder.module.encoder.embeddings.parameters(): params.requires_grad = True else: for params in tgt_encoder.encoder.embeddings.parameters(): params.requires_grad = True # train target encoder by GAN print("=== Training encoder for target domain ===") tgt_encoder = train_tgt(args, src_encoder, tgt_encoder, critic, src_classifier, src_data_loader, tgt_data_loader) # eval target encoder on lambda0.1 set of target dataset print("=== Evaluating classifier for encoded target domain ===") print(">>> source only <<<") src_tgt = eval_tgt(src_encoder, src_classifier, tgt_data_loader) print(">>> domain adaption <<<") tgt_tgt = eval_tgt(tgt_encoder, src_classifier, tgt_data_loader) worksheet.write(fold_index + 1, 0, fold_index + 1) worksheet.write(fold_index + 1, 1, src_src) worksheet.write(fold_index + 1, 2, src_tgt) worksheet.write(fold_index + 1, 3, tgt_tgt) src_src_stack.append(src_src) src_tgt_stack.append(src_tgt)