Beispiel #1
0
def make_model(src_vocab,
               tgt_vocab,
               N=6,
               d_model=512,
               d_ff=2048,
               h=8,
               dropout=0.1):
    "Helper: Construct a model from hyperparameters."
    c = copy.deepcopy
    # The dimension of multi-head model is the same as the embedding. Is it must?
    attn = MultiHeadedAttention(h, d_model)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout)
    model = EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N),
        nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
        nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
        Generator(d_model, tgt_vocab))

    # This was important from their code.
    # Initialize parameters with Glorot / fan_avg.
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    return model
Beispiel #2
0
def train_cleitcs(s_dataloaders, t_dataloaders, val_dataloader, test_dataloader, metric_name, seed, **kwargs):
    """

    :param s_dataloaders:
    :param t_dataloaders:
    :param kwargs:
    :return:
    """
    s_train_dataloader = s_dataloaders
    t_train_dataloader = t_dataloaders

    autoencoder = AE(input_dim=kwargs['input_dim'],
                     latent_dim=kwargs['latent_dim'],
                     hidden_dims=kwargs['encoder_hidden_dims'],
                     dop=kwargs['dop']).to(kwargs['device'])
    # get reference encoder
    aux_ae = deepcopy(autoencoder)

    aux_ae.encoder.load_state_dict(torch.load(os.path.join('./model_save/ae', f'ft_encoder_{seed}.pt')))
    print('reference encoder loaded')
    reference_encoder = aux_ae.encoder

    # construct transmitter
    transmitter = MLP(input_dim=kwargs['latent_dim'],
                      output_dim=kwargs['latent_dim'],
                      hidden_dims=[kwargs['latent_dim']]).to(kwargs['device'])

    encoder = autoencoder.encoder
    target_decoder = MoMLP(input_dim=kwargs['latent_dim'],
                           output_dim=kwargs['output_dim'],
                           hidden_dims=kwargs['regressor_hidden_dims'],
                           out_fn=torch.nn.Sigmoid).to(kwargs['device'])

    target_regressor = EncoderDecoder(encoder=encoder,
                                      decoder=target_decoder).to(kwargs['device'])

    train_history = defaultdict(list)
    # ae_eval_train_history = defaultdict(list)
    val_history = defaultdict(list)
    s_target_regression_eval_train_history = defaultdict(list)
    t_target_regression_eval_train_history = defaultdict(list)
    target_regression_eval_val_history = defaultdict(list)
    target_regression_eval_test_history = defaultdict(list)
    cleit_params = [
        target_regressor.parameters(),
        transmitter.parameters()
    ]
    model_optimizer = torch.optim.AdamW(chain(*cleit_params), lr=kwargs['lr'])
    for epoch in range(int(kwargs['train_num_epochs'])):
        if epoch % 50 == 0:
            print(f'Coral training epoch {epoch}')
        for step, s_batch in enumerate(s_train_dataloader):
            t_batch = next(iter(t_train_dataloader))
            train_history = cleit_train_step(model=target_regressor,
                                             transmitter=transmitter,
                                             reference_encoder=reference_encoder,
                                             s_batch=s_batch,
                                             t_batch=t_batch,
                                             device=kwargs['device'],
                                             optimizer=model_optimizer,
                                             alpha=kwargs['alpha'],
                                             history=train_history)
        s_target_regression_eval_train_history = evaluate_target_regression_epoch(regressor=target_regressor,
                                                                                  dataloader=s_train_dataloader,
                                                                                  device=kwargs['device'],
                                                                                  history=s_target_regression_eval_train_history)

        t_target_regression_eval_train_history = evaluate_target_regression_epoch(regressor=target_regressor,
                                                                                  dataloader=t_train_dataloader,
                                                                                  device=kwargs['device'],
                                                                                  history=t_target_regression_eval_train_history)
        target_regression_eval_val_history = evaluate_target_regression_epoch(regressor=target_regressor,
                                                                              dataloader=val_dataloader,
                                                                              device=kwargs['device'],
                                                                              history=target_regression_eval_val_history)
        target_regression_eval_test_history = evaluate_target_regression_epoch(regressor=target_regressor,
                                                                               dataloader=test_dataloader,
                                                                               device=kwargs['device'],
                                                                               history=target_regression_eval_test_history)

        save_flag, stop_flag = model_save_check(history=target_regression_eval_val_history,
                                                metric_name=metric_name,
                                                tolerance_count=50)
        if save_flag:
            torch.save(target_regressor.state_dict(), os.path.join(kwargs['model_save_folder'], f'cleitcs_regressor_{seed}.pt'))
        if stop_flag:
            break
    target_regressor.load_state_dict(
        torch.load(os.path.join(kwargs['model_save_folder'], f'cleitcs_regressor_{seed}.pt')))

    # evaluate_target_regression_epoch(regressor=target_regressor,
    #                                  dataloader=val_dataloader,
    #                                  device=kwargs['device'],
    #                                  history=None,
    #                                  seed=seed,
    #                                  output_folder=kwargs['model_save_folder'])
    evaluate_target_regression_epoch(regressor=target_regressor,
                                     dataloader=test_dataloader,
                                     device=kwargs['device'],
                                     history=None,
                                     seed=seed,
                                     output_folder=kwargs['model_save_folder'])

    return target_regressor, (
        train_history, s_target_regression_eval_train_history, t_target_regression_eval_train_history,
        target_regression_eval_val_history, target_regression_eval_test_history)
import torch.nn.functional as F
import torch.optim as optim
from DataSet import *
import os
from refinement import RefineNet

device = 0
ed_epoch = 100
refine_epoch = 100
final_epoch = 100
batch_size = 16

RF = RefineNet().double().cuda(device)
ED = EncoderDecoder().double().cuda(device)

opt_ED = optim.SGD(ED.parameters(), lr=1e-5, momentum=0.9)
opt_RF = optim.SGD(RF.parameters(), lr=5e-2, momentum=0.9)

a_path = '/home/zhuyuanjin/data/Human_Matting/alpha'
img_path = '/home/zhuyuanjin/data/Human_Matting/image_matting'

ed_pretrained = '/home/zhuyuanjin/data/Human_Matting/models/ed_pretrained_Matting'
rf_pretrained = '/home/zhuyuanjin/data/Human_Matting/models/rf_pretrained'

final_param = '/home/zhuyuanjin/data/Human_Matting/models/final_param'

dataset = MattingDataSet(
    a_path=a_path,
    img_path=img_path,
)
dataloader = DataLoader(dataset,
Beispiel #4
0
import torch.optim as optim
from DataSet import *
import os
from refinement import RefineNet


device = 0
ed_epoch = 2
refine_epoch = 10
final_epoch = 100
batch_size = 16

RF = RefineNet().double().cuda(device)
ED = EncoderDecoder().double().cuda(device)

opt_ED = optim.Adam(ED.parameters(), lr=1e-4)
opt_RF = optim.Adam(RF.parameters(), lr=1e-4)

fg_path = '/data1/zhuyuanji01/Adobe_Deep_Image_Matting_Dataset/dim431/fg/'
bg_path = '/data1/zhuyuanji01/cocostuff/dataset/images/train2017/'
a_path = '/data1/zhuyuanji01/Adobe_Deep_Image_Matting_Dataset/dim431/alpha/'
merged_path = '/data1/zhuyuanji01/Adobe_Deep_Image_Matting_Dataset/dim431/merged/'

dataset = MattingDataSet(fg_path=fg_path, bg_path=bg_path, a_path=a_path, merged_path=merged_path)
dataloader = DataLoader(dataset, num_workers=10 , batch_size=batch_size, shuffle=True)

if __name__ == '__main__':

    print('Beginning to PreTrain the Encoder Decoder')

    for epoch in range(ed_epoch):
Beispiel #5
0
def fine_tune_encoder(train_dataloader, val_dataloader, seed, test_dataloader=None,
                      metric_name='cpearsonr',
                      normalize_flag=False, **kwargs):
    autoencoder = AE(input_dim=kwargs['input_dim'],
                     latent_dim=kwargs['latent_dim'],
                     hidden_dims=kwargs['encoder_hidden_dims'],
                     dop=kwargs['dop']).to(kwargs['device'])
    encoder = autoencoder.encoder

    target_decoder = MoMLP(input_dim=kwargs['latent_dim'],
                           output_dim=kwargs['output_dim'],
                           hidden_dims=kwargs['regressor_hidden_dims'],
                           out_fn=torch.nn.Sigmoid).to(kwargs['device'])

    target_regressor = EncoderDecoder(encoder=encoder,
                                      decoder=target_decoder,
                                      normalize_flag=normalize_flag).to(kwargs['device'])

    target_regression_train_history = defaultdict(list)
    target_regression_eval_train_history = defaultdict(list)
    target_regression_eval_val_history = defaultdict(list)
    target_regression_eval_test_history = defaultdict(list)

    target_regression_optimizer = torch.optim.AdamW(target_regressor.parameters(), lr=kwargs['lr'])

    for epoch in range(kwargs['train_num_epochs']):
        if epoch % 10 == 0:
            print(f'MLP fine-tuning epoch {epoch}')
        for step, batch in enumerate(train_dataloader):
            target_regression_train_history = regression_train_step(model=target_regressor,
                                                                    batch=batch,
                                                                    device=kwargs['device'],
                                                                    optimizer=target_regression_optimizer,
                                                                    history=target_regression_train_history)
        target_regression_eval_train_history = evaluate_target_regression_epoch(regressor=target_regressor,
                                                                                dataloader=train_dataloader,
                                                                                device=kwargs['device'],
                                                                                history=target_regression_eval_train_history)
        target_regression_eval_val_history = evaluate_target_regression_epoch(regressor=target_regressor,
                                                                              dataloader=val_dataloader,
                                                                              device=kwargs['device'],
                                                                              history=target_regression_eval_val_history)

        if test_dataloader is not None:
            target_regression_eval_test_history = evaluate_target_regression_epoch(regressor=target_regressor,
                                                                                   dataloader=test_dataloader,
                                                                                   device=kwargs['device'],
                                                                                   history=target_regression_eval_test_history)
        save_flag, stop_flag = model_save_check(history=target_regression_eval_val_history,
                                                metric_name=metric_name,
                                                tolerance_count=50)
        if save_flag or epoch == 0:
            torch.save(target_regressor.state_dict(),
                       os.path.join(kwargs['model_save_folder'], f'target_regressor_{seed}.pt'))
            torch.save(target_regressor.encoder.state_dict(),
                       os.path.join(kwargs['model_save_folder'], f'ft_encoder_{seed}.pt'))
        if stop_flag:
            break

    target_regressor.load_state_dict(
        torch.load(os.path.join(kwargs['model_save_folder'], f'target_regressor_{seed}.pt')))

    evaluate_target_regression_epoch(regressor=target_regressor,
                                     dataloader=val_dataloader,
                                     device=kwargs['device'],
                                     history=None,
                                     seed=seed,
                                     cv_flag=True,
                                     output_folder=kwargs['model_save_folder'])
    if test_dataloader is not None:
        evaluate_target_regression_epoch(regressor=target_regressor,
                                         dataloader=test_dataloader,
                                         device=kwargs['device'],
                                         history=None,
                                         seed=seed,
                                         output_folder=kwargs['model_save_folder'])


    return target_regressor, (target_regression_train_history, target_regression_eval_train_history,
                              target_regression_eval_val_history, target_regression_eval_test_history)
Beispiel #6
0
def train_dann(s_dataloaders, t_dataloaders, val_dataloader, test_dataloader,
               metric_name, seed, **kwargs):
    """

    :param s_dataloaders:
    :param t_dataloaders:
    :param kwargs:
    :return:
    """
    s_train_dataloader = s_dataloaders
    t_train_dataloader = t_dataloaders

    autoencoder = AE(input_dim=kwargs['input_dim'],
                     latent_dim=kwargs['latent_dim'],
                     hidden_dims=kwargs['encoder_hidden_dims'],
                     dop=kwargs['dop']).to(kwargs['device'])
    encoder = autoencoder.encoder

    target_decoder = MoMLP(input_dim=kwargs['latent_dim'],
                           output_dim=kwargs['output_dim'],
                           hidden_dims=kwargs['regressor_hidden_dims'],
                           out_fn=torch.nn.Sigmoid).to(kwargs['device'])

    target_regressor = EncoderDecoder(
        encoder=encoder, decoder=target_decoder).to(kwargs['device'])

    classifier = MLP(input_dim=kwargs['latent_dim'],
                     output_dim=1,
                     hidden_dims=kwargs['classifier_hidden_dims'],
                     dop=kwargs['dop'],
                     out_fn=torch.nn.Sigmoid).to(kwargs['device'])

    confounder_classifier = EncoderDecoder(encoder=autoencoder.encoder,
                                           decoder=classifier).to(
                                               kwargs['device'])

    train_history = defaultdict(list)
    s_target_regression_eval_train_history = defaultdict(list)
    t_target_regression_eval_train_history = defaultdict(list)
    target_regression_eval_val_history = defaultdict(list)
    target_regression_eval_test_history = defaultdict(list)

    confounded_loss = nn.BCEWithLogitsLoss()
    dann_params = [
        target_regressor.parameters(),
        confounder_classifier.decoder.parameters()
    ]
    dann_optimizer = torch.optim.AdamW(chain(*dann_params), lr=kwargs['lr'])

    # start alternative training
    for epoch in range(int(kwargs['train_num_epochs'])):
        if epoch % 50 == 0:
            print(f'DANN training epoch {epoch}')
        # start autoencoder training epoch
        for step, s_batch in enumerate(s_train_dataloader):
            t_batch = next(iter(t_train_dataloader))
            train_history = dann_train_step(classifier=confounder_classifier,
                                            model=target_regressor,
                                            s_batch=s_batch,
                                            t_batch=t_batch,
                                            loss_fn=confounded_loss,
                                            alpha=kwargs['alpha'],
                                            device=kwargs['device'],
                                            optimizer=dann_optimizer,
                                            history=train_history,
                                            scheduler=None)

        s_target_regression_eval_train_history = evaluate_target_regression_epoch(
            regressor=target_regressor,
            dataloader=s_train_dataloader,
            device=kwargs['device'],
            history=s_target_regression_eval_train_history)

        t_target_regression_eval_train_history = evaluate_target_regression_epoch(
            regressor=target_regressor,
            dataloader=t_train_dataloader,
            device=kwargs['device'],
            history=t_target_regression_eval_train_history)
        target_regression_eval_val_history = evaluate_target_regression_epoch(
            regressor=target_regressor,
            dataloader=val_dataloader,
            device=kwargs['device'],
            history=target_regression_eval_val_history)
        target_regression_eval_test_history = evaluate_target_regression_epoch(
            regressor=target_regressor,
            dataloader=test_dataloader,
            device=kwargs['device'],
            history=target_regression_eval_test_history)

        save_flag, stop_flag = model_save_check(
            history=target_regression_eval_val_history,
            metric_name=metric_name,
            tolerance_count=50)
        if save_flag:
            torch.save(
                target_regressor.state_dict(),
                os.path.join(kwargs['model_save_folder'],
                             f'dann_regressor_{seed}.pt'))
        if stop_flag:
            break
    target_regressor.load_state_dict(
        torch.load(
            os.path.join(kwargs['model_save_folder'],
                         f'dann_regressor_{seed}.pt')))

    # evaluate_target_regression_epoch(regressor=target_regressor,
    #                                  dataloader=val_dataloader,
    #                                  device=kwargs['device'],
    #                                  history=None,
    #                                  seed=seed,
    #                                  output_folder=kwargs['model_save_folder'])
    evaluate_target_regression_epoch(regressor=target_regressor,
                                     dataloader=test_dataloader,
                                     device=kwargs['device'],
                                     history=None,
                                     seed=seed,
                                     output_folder=kwargs['model_save_folder'])

    return target_regressor, (train_history,
                              s_target_regression_eval_train_history,
                              t_target_regression_eval_train_history,
                              target_regression_eval_val_history,
                              target_regression_eval_test_history)
Beispiel #7
0
def train_adda(s_dataloaders, t_dataloaders, val_dataloader, test_dataloader,
               metric_name, seed, **kwargs):
    """

    :param s_dataloaders:
    :param t_dataloaders:
    :param kwargs:
    :return:
    """
    s_train_dataloader = s_dataloaders
    t_train_dataloader = t_dataloaders

    autoencoder = AE(input_dim=kwargs['input_dim'],
                     latent_dim=kwargs['latent_dim'],
                     hidden_dims=kwargs['encoder_hidden_dims'],
                     dop=kwargs['dop']).to(kwargs['device'])
    encoder = autoencoder.encoder

    target_decoder = MoMLP(input_dim=kwargs['latent_dim'],
                           output_dim=kwargs['output_dim'],
                           hidden_dims=kwargs['regressor_hidden_dims'],
                           out_fn=torch.nn.Sigmoid).to(kwargs['device'])

    target_regressor = EncoderDecoder(
        encoder=encoder, decoder=target_decoder).to(kwargs['device'])

    confounding_classifier = MLP(input_dim=kwargs['latent_dim'],
                                 output_dim=1,
                                 hidden_dims=kwargs['classifier_hidden_dims'],
                                 dop=kwargs['dop']).to(kwargs['device'])

    critic_train_history = defaultdict(list)
    gen_train_history = defaultdict(list)
    s_target_regression_eval_train_history = defaultdict(list)
    t_target_regression_eval_train_history = defaultdict(list)
    target_regression_eval_val_history = defaultdict(list)
    target_regression_eval_test_history = defaultdict(list)

    model_optimizer = torch.optim.AdamW(target_regressor.parameters(),
                                        lr=kwargs['lr'])
    classifier_optimizer = torch.optim.RMSprop(
        confounding_classifier.parameters(), lr=kwargs['lr'])
    for epoch in range(int(kwargs['train_num_epochs'])):
        if epoch % 50 == 0:
            print(f'ADDA training epoch {epoch}')
        for step, s_batch in enumerate(s_train_dataloader):
            t_batch = next(iter(t_train_dataloader))
            critic_train_history = critic_train_step(
                critic=confounding_classifier,
                model=target_regressor,
                s_batch=s_batch,
                t_batch=t_batch,
                device=kwargs['device'],
                optimizer=classifier_optimizer,
                history=critic_train_history,
                # clip=0.1,
                gp=10.0)
            if (step + 1) % 5 == 0:
                gen_train_history = gan_gen_train_step(
                    critic=confounding_classifier,
                    model=target_regressor,
                    s_batch=s_batch,
                    t_batch=t_batch,
                    device=kwargs['device'],
                    optimizer=model_optimizer,
                    alpha=1.0,
                    history=gen_train_history)
        s_target_regression_eval_train_history = evaluate_target_regression_epoch(
            regressor=target_regressor,
            dataloader=s_train_dataloader,
            device=kwargs['device'],
            history=s_target_regression_eval_train_history)

        t_target_regression_eval_train_history = evaluate_target_regression_epoch(
            regressor=target_regressor,
            dataloader=t_train_dataloader,
            device=kwargs['device'],
            history=t_target_regression_eval_train_history)
        target_regression_eval_val_history = evaluate_target_regression_epoch(
            regressor=target_regressor,
            dataloader=val_dataloader,
            device=kwargs['device'],
            history=target_regression_eval_val_history)
        target_regression_eval_test_history = evaluate_target_regression_epoch(
            regressor=target_regressor,
            dataloader=test_dataloader,
            device=kwargs['device'],
            history=target_regression_eval_test_history)

        save_flag, stop_flag = model_save_check(
            history=target_regression_eval_val_history,
            metric_name=metric_name,
            tolerance_count=50)
        if save_flag:
            torch.save(
                target_regressor.state_dict(),
                os.path.join(kwargs['model_save_folder'],
                             f'adda_regressor_{seed}.pt'))
        if stop_flag:
            break

    target_regressor.load_state_dict(
        torch.load(
            os.path.join(kwargs['model_save_folder'],
                         f'adda_regressor_{seed}.pt')))

    # evaluate_target_regression_epoch(regressor=target_regressor,
    #                                  dataloader=val_dataloader,
    #                                  device=kwargs['device'],
    #                                  history=None,
    #                                  seed=seed,
    #                                  output_folder=kwargs['model_save_folder'])
    evaluate_target_regression_epoch(regressor=target_regressor,
                                     dataloader=test_dataloader,
                                     device=kwargs['device'],
                                     history=None,
                                     seed=seed,
                                     output_folder=kwargs['model_save_folder'])

    return target_regressor, (critic_train_history, gen_train_history,
                              s_target_regression_eval_train_history,
                              t_target_regression_eval_train_history,
                              target_regression_eval_val_history,
                              target_regression_eval_test_history)