コード例 #1
0
ファイル: train_cleitcs.py プロジェクト: he-org/CLEIT
def train_cleitcs(s_dataloaders, t_dataloaders, val_dataloader, test_dataloader, metric_name, seed, **kwargs):
    """Train a CLEIT-CS regressor guided by a pre-trained reference encoder.

    A freshly initialised ``AE`` supplies the encoder of ``target_regressor``
    (that encoder plus a ``MoMLP`` decoder built with ``out_fn=torch.nn.Sigmoid``).
    A deep copy of the same autoencoder has its encoder weights loaded from
    ``./model_save/ae/ft_encoder_{seed}.pt`` and is handed to ``cleit_train_step`` as
    ``reference_encoder``; its parameters are NOT given to the optimizer.  A
    ``latent_dim -> latent_dim`` ``MLP`` "transmitter" is optimised jointly with the
    regressor.

    Every source batch is paired with one target batch per training step.  After each
    epoch the regressor is evaluated on the source/target training loaders and on the
    validation and test loaders; ``model_save_check`` on the validation history decides
    when to checkpoint (``cleitcs_regressor_{seed}.pt`` in
    ``kwargs['model_save_folder']``) and when to stop early (tolerance 50).  The best
    checkpoint is reloaded at the end and evaluated once more on the test loader.

    :param s_dataloaders: source-domain training loader (a single iterable of batches,
        despite the plural name).
    :param t_dataloaders: target-domain training loader; batches are drawn from it in
        sequence, restarting from the beginning whenever it is exhausted.
    :param val_dataloader: validation loader used for model selection.
    :param test_dataloader: test loader, evaluated every epoch and after reloading.
    :param metric_name: validation metric monitored by ``model_save_check``.
    :param seed: run seed; selects the reference-encoder file and names the checkpoint.
    :param kwargs: must provide ``input_dim``, ``latent_dim``, ``encoder_hidden_dims``,
        ``dop``, ``device``, ``output_dim``, ``regressor_hidden_dims``, ``lr``,
        ``train_num_epochs``, ``alpha`` and ``model_save_folder``.
    :return: ``(target_regressor, (train_history, s_eval_train_history,
        t_eval_train_history, eval_val_history, eval_test_history))``.
    """
    s_train_dataloader = s_dataloaders
    t_train_dataloader = t_dataloaders

    autoencoder = AE(input_dim=kwargs['input_dim'],
                     latent_dim=kwargs['latent_dim'],
                     hidden_dims=kwargs['encoder_hidden_dims'],
                     dop=kwargs['dop']).to(kwargs['device'])
    # Reference encoder: same architecture, weights taken from the pre-trained checkpoint.
    aux_ae = deepcopy(autoencoder)

    aux_ae.encoder.load_state_dict(torch.load(os.path.join('./model_save/ae', f'ft_encoder_{seed}.pt')))
    print('reference encoder loaded')
    reference_encoder = aux_ae.encoder

    # Transmitter maps latent codes to the same latent dimensionality.
    transmitter = MLP(input_dim=kwargs['latent_dim'],
                      output_dim=kwargs['latent_dim'],
                      hidden_dims=[kwargs['latent_dim']]).to(kwargs['device'])

    encoder = autoencoder.encoder
    target_decoder = MoMLP(input_dim=kwargs['latent_dim'],
                           output_dim=kwargs['output_dim'],
                           hidden_dims=kwargs['regressor_hidden_dims'],
                           out_fn=torch.nn.Sigmoid).to(kwargs['device'])

    target_regressor = EncoderDecoder(encoder=encoder,
                                      decoder=target_decoder).to(kwargs['device'])

    train_history = defaultdict(list)
    s_target_regression_eval_train_history = defaultdict(list)
    t_target_regression_eval_train_history = defaultdict(list)
    target_regression_eval_val_history = defaultdict(list)
    target_regression_eval_test_history = defaultdict(list)

    # The reference encoder is deliberately left out: only the regressor and the
    # transmitter are updated by this optimizer.
    cleit_params = [
        target_regressor.parameters(),
        transmitter.parameters()
    ]
    model_optimizer = torch.optim.AdamW(chain(*cleit_params), lr=kwargs['lr'])

    checkpoint_name = f'cleitcs_regressor_{seed}.pt'

    # One persistent iterator over the target loader, restarted only when exhausted.
    # Calling ``next(iter(loader))`` every step (the previous approach) always returns the
    # first batch of a brand-new pass -- the *same* batch when shuffling is off -- and,
    # for a torch DataLoader with workers, re-spawns the workers on every step.
    t_train_iter = iter(t_train_dataloader)
    for epoch in range(int(kwargs['train_num_epochs'])):
        if epoch % 50 == 0:
            print(f'CLEIT-CS training epoch {epoch}')
        for s_batch in s_train_dataloader:
            try:
                t_batch = next(t_train_iter)
            except StopIteration:
                t_train_iter = iter(t_train_dataloader)
                t_batch = next(t_train_iter)
            train_history = cleit_train_step(model=target_regressor,
                                             transmitter=transmitter,
                                             reference_encoder=reference_encoder,
                                             s_batch=s_batch,
                                             t_batch=t_batch,
                                             device=kwargs['device'],
                                             optimizer=model_optimizer,
                                             alpha=kwargs['alpha'],
                                             history=train_history)
        s_target_regression_eval_train_history = evaluate_target_regression_epoch(regressor=target_regressor,
                                                                                  dataloader=s_train_dataloader,
                                                                                  device=kwargs['device'],
                                                                                  history=s_target_regression_eval_train_history)

        t_target_regression_eval_train_history = evaluate_target_regression_epoch(regressor=target_regressor,
                                                                                  dataloader=t_train_dataloader,
                                                                                  device=kwargs['device'],
                                                                                  history=t_target_regression_eval_train_history)
        target_regression_eval_val_history = evaluate_target_regression_epoch(regressor=target_regressor,
                                                                              dataloader=val_dataloader,
                                                                              device=kwargs['device'],
                                                                              history=target_regression_eval_val_history)
        target_regression_eval_test_history = evaluate_target_regression_epoch(regressor=target_regressor,
                                                                               dataloader=test_dataloader,
                                                                               device=kwargs['device'],
                                                                               history=target_regression_eval_test_history)

        save_flag, stop_flag = model_save_check(history=target_regression_eval_val_history,
                                                metric_name=metric_name,
                                                tolerance_count=50)
        # Always checkpoint the first epoch so the reload below never hits a missing
        # file (or a stale one left by an earlier run).
        if save_flag or epoch == 0:
            torch.save(target_regressor.state_dict(), os.path.join(kwargs['model_save_folder'], checkpoint_name))
        if stop_flag:
            break
    target_regressor.load_state_dict(
        torch.load(os.path.join(kwargs['model_save_folder'], checkpoint_name)))

    # Final pass over the test split with the best weights; seed/output_folder are
    # forwarded so the helper can write its per-run output.
    evaluate_target_regression_epoch(regressor=target_regressor,
                                     dataloader=test_dataloader,
                                     device=kwargs['device'],
                                     history=None,
                                     seed=seed,
                                     output_folder=kwargs['model_save_folder'])

    return target_regressor, (
        train_history, s_target_regression_eval_train_history, t_target_regression_eval_train_history,
        target_regression_eval_val_history, target_regression_eval_test_history)
コード例 #2
0
ファイル: fine_tuning.py プロジェクト: XieResearchGroup/CLEIT
def fine_tune_encoder(encoder,
                      train_dataloader,
                      val_dataloader,
                      seed,
                      task_save_folder,
                      test_dataloader=None,
                      metric_name='cpearsonr',
                      normalize_flag=False,
                      **kwargs):
    """Fine-tune a pre-trained encoder under a new MoMLP head with gradual unfreezing.

    Training starts with only the decoder's parameters in the optimizer.  Every time
    ``model_save_check`` raises its stop flag, the best checkpoint is reloaded, the
    parameters of one more ``Linear`` module of the encoder are added to the trainable
    set (taken in reverse ``encoder.modules()`` order), the learning rate is multiplied
    by ``kwargs['decay_coefficient']`` and a fresh AdamW optimizer is built over
    *everything* unfrozen so far.  Training ends when the stop flag fires with no
    ``Linear`` module left to unfreeze, or after ``train_num_epochs`` epochs.

    :param encoder: pre-trained encoder module to fine-tune.
    :param train_dataloader: training loader; one optimisation step per batch.
    :param val_dataloader: validation loader used for model selection.
    :param seed: run seed, used in checkpoint file names.
    :param task_save_folder: folder for ``target_regressor_{seed}.pt`` and
        ``ft_encoder_{seed}.pt``.
    :param test_dataloader: optional test loader; skipped entirely when ``None``.
    :param metric_name: validation metric monitored by ``model_save_check``.
    :param normalize_flag: forwarded to ``EncoderDecoder``.
    :param kwargs: must provide ``latent_dim``, ``output_dim``,
        ``regressor_hidden_dims``, ``device``, ``lr``, ``train_num_epochs``,
        ``decay_coefficient`` and ``model_save_folder``.
    :return: ``(target_regressor, (train_history, eval_train_history,
        eval_val_history, eval_test_history))``.
    """
    target_decoder = MoMLP(input_dim=kwargs['latent_dim'],
                           output_dim=kwargs['output_dim'],
                           hidden_dims=kwargs['regressor_hidden_dims'],
                           out_fn=torch.nn.Sigmoid).to(kwargs['device'])

    target_regressor = EncoderDecoder(encoder=encoder,
                                      decoder=target_decoder,
                                      normalize_flag=normalize_flag).to(
                                          kwargs['device'])

    target_regression_train_history = defaultdict(list)
    target_regression_eval_train_history = defaultdict(list)
    target_regression_eval_val_history = defaultdict(list)
    target_regression_eval_test_history = defaultdict(list)

    regressor_path = os.path.join(task_save_folder, f'target_regressor_{seed}.pt')
    encoder_path = os.path.join(task_save_folder, f'ft_encoder_{seed}.pt')

    # Positions (in ``encoder.modules()`` order) of the encoder's Linear layers.  They
    # are popped from the end, so unfreezing proceeds from the last layer backwards.
    # The module list is materialised once rather than twice per index.
    encoder_modules = list(encoder.modules())
    encoder_module_indices = [
        i for i, module in enumerate(encoder_modules)
        if str(module).startswith('Linear')
    ]

    reset_count = 1
    lr = kwargs['lr']

    # Store parameter *lists*, not ``.parameters()`` generators: the optimizer consumes
    # its iterable, so generators would be empty when the optimizer is rebuilt after an
    # unfreeze, silently dropping the decoder and earlier layers from training.
    target_regression_params = [list(target_regressor.decoder.parameters())]
    target_regression_optimizer = torch.optim.AdamW(
        chain(*target_regression_params), lr=lr)

    for epoch in range(int(kwargs['train_num_epochs'])):
        if epoch % 50 == 0:
            print(f'Fine tuning epoch {epoch}')
        for batch in train_dataloader:
            target_regression_train_history = regression_train_step(
                model=target_regressor,
                batch=batch,
                device=kwargs['device'],
                optimizer=target_regression_optimizer,
                history=target_regression_train_history)
        target_regression_eval_train_history = evaluate_target_regression_epoch(
            regressor=target_regressor,
            dataloader=train_dataloader,
            device=kwargs['device'],
            history=target_regression_eval_train_history)
        target_regression_eval_val_history = evaluate_target_regression_epoch(
            regressor=target_regressor,
            dataloader=val_dataloader,
            device=kwargs['device'],
            history=target_regression_eval_val_history)

        if test_dataloader is not None:
            target_regression_eval_test_history = evaluate_target_regression_epoch(
                regressor=target_regressor,
                dataloader=test_dataloader,
                device=kwargs['device'],
                history=target_regression_eval_test_history)
        save_flag, stop_flag = model_save_check(
            history=target_regression_eval_val_history,
            metric_name=metric_name,
            tolerance_count=10,
            reset_count=reset_count)
        # Always checkpoint the first epoch so the reloads below always find a file
        # produced by this run.
        if save_flag or epoch == 0:
            torch.save(target_regressor.state_dict(), regressor_path)
            torch.save(target_regressor.encoder.state_dict(), encoder_path)

        if stop_flag:
            if not encoder_module_indices:
                # Every Linear layer is already trainable; nothing left to unfreeze.
                break
            ind = encoder_module_indices.pop()
            print(f'Unfreezing {epoch}')
            # Resume from the best weights before widening the trainable set.
            target_regressor.load_state_dict(torch.load(regressor_path))

            target_regression_params.append(
                list(list(target_regressor.encoder.modules())[ind].parameters()))
            lr = lr * kwargs['decay_coefficient']
            target_regression_optimizer = torch.optim.AdamW(
                chain(*target_regression_params), lr=lr)
            reset_count += 1

    target_regressor.load_state_dict(torch.load(regressor_path))

    # NOTE(review): evaluation output goes to kwargs['model_save_folder'] while the
    # checkpoints live in task_save_folder -- confirm this split is intended.
    evaluate_target_regression_epoch(regressor=target_regressor,
                                     dataloader=val_dataloader,
                                     device=kwargs['device'],
                                     history=None,
                                     seed=seed,
                                     cv_flag=True,
                                     output_folder=kwargs['model_save_folder'])
    if test_dataloader is not None:
        evaluate_target_regression_epoch(regressor=target_regressor,
                                         dataloader=test_dataloader,
                                         device=kwargs['device'],
                                         history=None,
                                         seed=seed,
                                         output_folder=kwargs['model_save_folder'])

    return target_regressor, (target_regression_train_history,
                              target_regression_eval_train_history,
                              target_regression_eval_val_history,
                              target_regression_eval_test_history)
コード例 #3
0
def train_dann(s_dataloaders, t_dataloaders, val_dataloader, test_dataloader,
               metric_name, seed, **kwargs):
    """Train a regressor with DANN-style domain-adversarial training.

    The encoder of a freshly initialised ``AE`` is shared by two heads: a ``MoMLP``
    regression decoder (``target_regressor``) and a one-output ``MLP`` domain
    classifier (``confounder_classifier``) ending in a Sigmoid.  Each source batch is
    paired with one target batch and handed to ``dann_train_step`` together with the
    binary cross-entropy domain loss and ``kwargs['alpha']``.  After every epoch the
    regressor is evaluated on the source/target training, validation and test loaders;
    ``model_save_check`` on the validation history drives checkpointing
    (``dann_regressor_{seed}.pt`` in ``kwargs['model_save_folder']``) and early stopping
    (tolerance 50).  The best checkpoint is reloaded and evaluated once more on the
    test loader.

    :param s_dataloaders: source-domain training loader (a single iterable of batches).
    :param t_dataloaders: target-domain training loader; batches are drawn from it in
        sequence, restarting from the beginning whenever it is exhausted.
    :param val_dataloader: validation loader used for model selection.
    :param test_dataloader: test loader, evaluated every epoch and after reloading.
    :param metric_name: validation metric monitored by ``model_save_check``.
    :param seed: run seed, used in the checkpoint file name.
    :param kwargs: must provide ``input_dim``, ``latent_dim``, ``encoder_hidden_dims``,
        ``dop``, ``device``, ``output_dim``, ``regressor_hidden_dims``,
        ``classifier_hidden_dims``, ``lr``, ``train_num_epochs``, ``alpha`` and
        ``model_save_folder``.
    :return: ``(target_regressor, (train_history, s_eval_train_history,
        t_eval_train_history, eval_val_history, eval_test_history))``.
    """
    s_train_dataloader = s_dataloaders
    t_train_dataloader = t_dataloaders

    autoencoder = AE(input_dim=kwargs['input_dim'],
                     latent_dim=kwargs['latent_dim'],
                     hidden_dims=kwargs['encoder_hidden_dims'],
                     dop=kwargs['dop']).to(kwargs['device'])
    encoder = autoencoder.encoder

    target_decoder = MoMLP(input_dim=kwargs['latent_dim'],
                           output_dim=kwargs['output_dim'],
                           hidden_dims=kwargs['regressor_hidden_dims'],
                           out_fn=torch.nn.Sigmoid).to(kwargs['device'])

    target_regressor = EncoderDecoder(
        encoder=encoder, decoder=target_decoder).to(kwargs['device'])

    classifier = MLP(input_dim=kwargs['latent_dim'],
                     output_dim=1,
                     hidden_dims=kwargs['classifier_hidden_dims'],
                     dop=kwargs['dop'],
                     out_fn=torch.nn.Sigmoid).to(kwargs['device'])

    # Shares the regressor's encoder, so the domain head sees the same features.
    confounder_classifier = EncoderDecoder(encoder=autoencoder.encoder,
                                           decoder=classifier).to(
                                               kwargs['device'])

    train_history = defaultdict(list)
    s_target_regression_eval_train_history = defaultdict(list)
    t_target_regression_eval_train_history = defaultdict(list)
    target_regression_eval_val_history = defaultdict(list)
    target_regression_eval_test_history = defaultdict(list)

    # The classifier already ends in a Sigmoid, so its outputs are probabilities.
    # ``BCEWithLogitsLoss`` (used previously) applies a second sigmoid, squashing every
    # prediction into (0.5, 0.73); ``BCELoss`` matches a probability output.
    confounded_loss = nn.BCELoss()
    # The encoder is reached through target_regressor; only the classifier's own
    # decoder is added so shared parameters are not listed twice.
    dann_params = [
        target_regressor.parameters(),
        confounder_classifier.decoder.parameters()
    ]
    dann_optimizer = torch.optim.AdamW(chain(*dann_params), lr=kwargs['lr'])

    checkpoint_name = f'dann_regressor_{seed}.pt'

    # One persistent iterator over the target loader, restarted only when exhausted.
    # Calling ``next(iter(loader))`` every step always returns the first batch of a new
    # pass -- the *same* batch when shuffling is off.
    t_train_iter = iter(t_train_dataloader)
    for epoch in range(int(kwargs['train_num_epochs'])):
        if epoch % 50 == 0:
            print(f'DANN training epoch {epoch}')
        # Joint regression + adversarial domain-classification steps.
        for s_batch in s_train_dataloader:
            try:
                t_batch = next(t_train_iter)
            except StopIteration:
                t_train_iter = iter(t_train_dataloader)
                t_batch = next(t_train_iter)
            train_history = dann_train_step(classifier=confounder_classifier,
                                            model=target_regressor,
                                            s_batch=s_batch,
                                            t_batch=t_batch,
                                            loss_fn=confounded_loss,
                                            alpha=kwargs['alpha'],
                                            device=kwargs['device'],
                                            optimizer=dann_optimizer,
                                            history=train_history,
                                            scheduler=None)

        s_target_regression_eval_train_history = evaluate_target_regression_epoch(
            regressor=target_regressor,
            dataloader=s_train_dataloader,
            device=kwargs['device'],
            history=s_target_regression_eval_train_history)

        t_target_regression_eval_train_history = evaluate_target_regression_epoch(
            regressor=target_regressor,
            dataloader=t_train_dataloader,
            device=kwargs['device'],
            history=t_target_regression_eval_train_history)
        target_regression_eval_val_history = evaluate_target_regression_epoch(
            regressor=target_regressor,
            dataloader=val_dataloader,
            device=kwargs['device'],
            history=target_regression_eval_val_history)
        target_regression_eval_test_history = evaluate_target_regression_epoch(
            regressor=target_regressor,
            dataloader=test_dataloader,
            device=kwargs['device'],
            history=target_regression_eval_test_history)

        save_flag, stop_flag = model_save_check(
            history=target_regression_eval_val_history,
            metric_name=metric_name,
            tolerance_count=50)
        # Always checkpoint the first epoch so the reload below never hits a missing
        # (or stale) file.
        if save_flag or epoch == 0:
            torch.save(
                target_regressor.state_dict(),
                os.path.join(kwargs['model_save_folder'], checkpoint_name))
        if stop_flag:
            break
    target_regressor.load_state_dict(
        torch.load(
            os.path.join(kwargs['model_save_folder'], checkpoint_name)))

    # Final pass over the test split with the best weights.
    evaluate_target_regression_epoch(regressor=target_regressor,
                                     dataloader=test_dataloader,
                                     device=kwargs['device'],
                                     history=None,
                                     seed=seed,
                                     output_folder=kwargs['model_save_folder'])

    return target_regressor, (train_history,
                              s_target_regression_eval_train_history,
                              t_target_regression_eval_train_history,
                              target_regression_eval_val_history,
                              target_regression_eval_test_history)
コード例 #4
0
ファイル: mlp_main.py プロジェクト: he-org/CLEIT
def fine_tune_encoder(train_dataloader, val_dataloader, seed, test_dataloader=None,
                      metric_name='cpearsonr',
                      normalize_flag=False, **kwargs):
    """Train a freshly initialised encoder + MoMLP regressor end to end (MLP baseline).

    All regressor parameters are optimised with AdamW from the first epoch.  After each
    epoch the model is evaluated on the training and validation loaders (and the test
    loader when one is given); ``model_save_check`` on the validation history decides
    when to checkpoint and when to stop early (tolerance 50).  The first epoch is always
    checkpointed.  The best checkpoint is reloaded at the end and evaluated once more,
    passing ``seed``/``output_folder`` to the evaluation helper.

    :param train_dataloader: training loader; one optimisation step per batch.
    :param val_dataloader: validation loader used for model selection.
    :param seed: run seed, used in checkpoint file names.
    :param test_dataloader: optional test loader; skipped entirely when ``None``.
    :param metric_name: validation metric monitored by ``model_save_check``.
    :param normalize_flag: forwarded to ``EncoderDecoder``.
    :param kwargs: must provide ``input_dim``, ``latent_dim``, ``encoder_hidden_dims``,
        ``dop``, ``device``, ``output_dim``, ``regressor_hidden_dims``, ``lr``,
        ``train_num_epochs`` and ``model_save_folder``.
    :return: ``(target_regressor, (train_history, eval_train_history,
        eval_val_history, eval_test_history))``.
    """
    autoencoder = AE(input_dim=kwargs['input_dim'],
                     latent_dim=kwargs['latent_dim'],
                     hidden_dims=kwargs['encoder_hidden_dims'],
                     dop=kwargs['dop'])
    device = kwargs['device']
    autoencoder = autoencoder.to(device)

    head = MoMLP(input_dim=kwargs['latent_dim'],
                 output_dim=kwargs['output_dim'],
                 hidden_dims=kwargs['regressor_hidden_dims'],
                 out_fn=torch.nn.Sigmoid).to(device)
    target_regressor = EncoderDecoder(encoder=autoencoder.encoder,
                                      decoder=head,
                                      normalize_flag=normalize_flag).to(device)

    train_hist, eval_train_hist, eval_val_hist, eval_test_hist = (
        defaultdict(list) for _ in range(4))

    optimizer = torch.optim.AdamW(target_regressor.parameters(), lr=kwargs['lr'])

    for epoch in range(kwargs['train_num_epochs']):
        if not epoch % 10:
            print(f'MLP fine-tuning epoch {epoch}')

        for batch in train_dataloader:
            train_hist = regression_train_step(model=target_regressor,
                                               batch=batch,
                                               device=device,
                                               optimizer=optimizer,
                                               history=train_hist)

        eval_train_hist = evaluate_target_regression_epoch(regressor=target_regressor,
                                                           dataloader=train_dataloader,
                                                           device=device,
                                                           history=eval_train_hist)
        eval_val_hist = evaluate_target_regression_epoch(regressor=target_regressor,
                                                         dataloader=val_dataloader,
                                                         device=device,
                                                         history=eval_val_hist)
        if test_dataloader is not None:
            eval_test_hist = evaluate_target_regression_epoch(regressor=target_regressor,
                                                              dataloader=test_dataloader,
                                                              device=device,
                                                              history=eval_test_hist)

        save_flag, stop_flag = model_save_check(history=eval_val_hist,
                                                metric_name=metric_name,
                                                tolerance_count=50)
        # The first epoch is always written so a checkpoint exists for the final reload.
        if epoch == 0 or save_flag:
            save_folder = kwargs['model_save_folder']
            torch.save(target_regressor.state_dict(),
                       os.path.join(save_folder, f'target_regressor_{seed}.pt'))
            torch.save(target_regressor.encoder.state_dict(),
                       os.path.join(save_folder, f'ft_encoder_{seed}.pt'))
        if stop_flag:
            break

    best_path = os.path.join(kwargs['model_save_folder'], f'target_regressor_{seed}.pt')
    target_regressor.load_state_dict(torch.load(best_path))

    evaluate_target_regression_epoch(regressor=target_regressor,
                                     dataloader=val_dataloader,
                                     device=device,
                                     history=None,
                                     seed=seed,
                                     cv_flag=True,
                                     output_folder=kwargs['model_save_folder'])
    if test_dataloader is not None:
        evaluate_target_regression_epoch(regressor=target_regressor,
                                         dataloader=test_dataloader,
                                         device=device,
                                         history=None,
                                         seed=seed,
                                         output_folder=kwargs['model_save_folder'])

    return target_regressor, (train_hist, eval_train_hist, eval_val_hist, eval_test_hist)
コード例 #5
0
def train_adda(s_dataloaders, t_dataloaders, val_dataloader, test_dataloader,
               metric_name, seed, **kwargs):
    """Train a regressor with ADDA-style adversarial (critic/generator) training.

    A freshly initialised ``AE`` encoder plus a ``MoMLP`` decoder form
    ``target_regressor`` (optimised with AdamW).  A separate one-output ``MLP``
    critic is optimised with RMSprop.  For every source batch (paired with one target
    batch) ``critic_train_step`` is run with ``gp=10.0`` (presumably a gradient-penalty
    weight -- confirm in the helper); after every 5 critic steps ``gan_gen_train_step``
    updates the regressor with ``alpha=1.0``.  After each epoch the regressor is
    evaluated on the source/target training, validation and test loaders;
    ``model_save_check`` on the validation history drives checkpointing
    (``adda_regressor_{seed}.pt`` in ``kwargs['model_save_folder']``) and early stopping
    (tolerance 50).  The best checkpoint is reloaded and evaluated on the test loader.

    :param s_dataloaders: source-domain training loader (a single iterable of batches).
    :param t_dataloaders: target-domain training loader; batches are drawn from it in
        sequence, restarting from the beginning whenever it is exhausted.
    :param val_dataloader: validation loader used for model selection.
    :param test_dataloader: test loader, evaluated every epoch and after reloading.
    :param metric_name: validation metric monitored by ``model_save_check``.
    :param seed: run seed, used in the checkpoint file name.
    :param kwargs: must provide ``input_dim``, ``latent_dim``, ``encoder_hidden_dims``,
        ``dop``, ``device``, ``output_dim``, ``regressor_hidden_dims``,
        ``classifier_hidden_dims``, ``lr``, ``train_num_epochs`` and
        ``model_save_folder``.
    :return: ``(target_regressor, (critic_train_history, gen_train_history,
        s_eval_train_history, t_eval_train_history, eval_val_history,
        eval_test_history))``.
    """
    s_train_dataloader = s_dataloaders
    t_train_dataloader = t_dataloaders

    autoencoder = AE(input_dim=kwargs['input_dim'],
                     latent_dim=kwargs['latent_dim'],
                     hidden_dims=kwargs['encoder_hidden_dims'],
                     dop=kwargs['dop']).to(kwargs['device'])
    encoder = autoencoder.encoder

    target_decoder = MoMLP(input_dim=kwargs['latent_dim'],
                           output_dim=kwargs['output_dim'],
                           hidden_dims=kwargs['regressor_hidden_dims'],
                           out_fn=torch.nn.Sigmoid).to(kwargs['device'])

    target_regressor = EncoderDecoder(
        encoder=encoder, decoder=target_decoder).to(kwargs['device'])

    # Critic head on the latent space; built without an out_fn.
    confounding_classifier = MLP(input_dim=kwargs['latent_dim'],
                                 output_dim=1,
                                 hidden_dims=kwargs['classifier_hidden_dims'],
                                 dop=kwargs['dop']).to(kwargs['device'])

    critic_train_history = defaultdict(list)
    gen_train_history = defaultdict(list)
    s_target_regression_eval_train_history = defaultdict(list)
    t_target_regression_eval_train_history = defaultdict(list)
    target_regression_eval_val_history = defaultdict(list)
    target_regression_eval_test_history = defaultdict(list)

    model_optimizer = torch.optim.AdamW(target_regressor.parameters(),
                                        lr=kwargs['lr'])
    classifier_optimizer = torch.optim.RMSprop(
        confounding_classifier.parameters(), lr=kwargs['lr'])

    checkpoint_name = f'adda_regressor_{seed}.pt'

    # Critic steps counted across epochs, so the 1-generator-per-5-critic cadence does
    # not reset every epoch.  With a per-epoch counter, a source loader of fewer than 5
    # batches would never update the regressor at all.
    critic_steps = 0
    # One persistent iterator over the target loader, restarted only when exhausted.
    # Calling ``next(iter(loader))`` every step always returns the first batch of a new
    # pass -- the *same* batch when shuffling is off.
    t_train_iter = iter(t_train_dataloader)
    for epoch in range(int(kwargs['train_num_epochs'])):
        if epoch % 50 == 0:
            print(f'ADDA training epoch {epoch}')
        for s_batch in s_train_dataloader:
            try:
                t_batch = next(t_train_iter)
            except StopIteration:
                t_train_iter = iter(t_train_dataloader)
                t_batch = next(t_train_iter)
            # Weight clipping (clip=0.1) was tried previously and is disabled here.
            critic_train_history = critic_train_step(
                critic=confounding_classifier,
                model=target_regressor,
                s_batch=s_batch,
                t_batch=t_batch,
                device=kwargs['device'],
                optimizer=classifier_optimizer,
                history=critic_train_history,
                gp=10.0)
            critic_steps += 1
            if critic_steps % 5 == 0:
                gen_train_history = gan_gen_train_step(
                    critic=confounding_classifier,
                    model=target_regressor,
                    s_batch=s_batch,
                    t_batch=t_batch,
                    device=kwargs['device'],
                    optimizer=model_optimizer,
                    alpha=1.0,
                    history=gen_train_history)
        s_target_regression_eval_train_history = evaluate_target_regression_epoch(
            regressor=target_regressor,
            dataloader=s_train_dataloader,
            device=kwargs['device'],
            history=s_target_regression_eval_train_history)

        t_target_regression_eval_train_history = evaluate_target_regression_epoch(
            regressor=target_regressor,
            dataloader=t_train_dataloader,
            device=kwargs['device'],
            history=t_target_regression_eval_train_history)
        target_regression_eval_val_history = evaluate_target_regression_epoch(
            regressor=target_regressor,
            dataloader=val_dataloader,
            device=kwargs['device'],
            history=target_regression_eval_val_history)
        target_regression_eval_test_history = evaluate_target_regression_epoch(
            regressor=target_regressor,
            dataloader=test_dataloader,
            device=kwargs['device'],
            history=target_regression_eval_test_history)

        save_flag, stop_flag = model_save_check(
            history=target_regression_eval_val_history,
            metric_name=metric_name,
            tolerance_count=50)
        # Always checkpoint the first epoch so the reload below never hits a missing
        # (or stale) file.
        if save_flag or epoch == 0:
            torch.save(
                target_regressor.state_dict(),
                os.path.join(kwargs['model_save_folder'], checkpoint_name))
        if stop_flag:
            break

    target_regressor.load_state_dict(
        torch.load(
            os.path.join(kwargs['model_save_folder'], checkpoint_name)))

    # Final pass over the test split with the best weights.
    evaluate_target_regression_epoch(regressor=target_regressor,
                                     dataloader=test_dataloader,
                                     device=kwargs['device'],
                                     history=None,
                                     seed=seed,
                                     output_folder=kwargs['model_save_folder'])

    return target_regressor, (critic_train_history, gen_train_history,
                              s_target_regression_eval_train_history,
                              t_target_regression_eval_train_history,
                              target_regression_eval_val_history,
                              target_regression_eval_test_history)