Example #1
import os
from collections import defaultdict
from copy import deepcopy
from itertools import chain

import torch

# AE, MLP, EncoderDecoder, and cleit_train_step are defined elsewhere in the project.


def train_cleitc(dataloader, seed, **kwargs):
    """Pretrain the CLEIT autoencoder and transmitter against a frozen reference encoder.

    :param dataloader: training data loader
    :param seed: seed identifying the fine-tuned reference encoder checkpoint to load
    :param kwargs: configuration (input_dim, latent_dim, encoder_hidden_dims, dop,
        device, lr, train_num_epochs, retrain_flag, model_save_folder)
    :return: (encoder, (train_history, test_history))
    """
    autoencoder = AE(input_dim=kwargs['input_dim'],
                     latent_dim=kwargs['latent_dim'],
                     hidden_dims=kwargs['encoder_hidden_dims'],
                     dop=kwargs['dop']).to(kwargs['device'])

    # get reference encoder
    aux_ae = deepcopy(autoencoder)

    aux_ae.encoder.load_state_dict(torch.load(os.path.join('./model_save/ae5000', f'ft_encoder_{seed}.pt')))
    print('reference encoder loaded')
    reference_encoder = aux_ae.encoder

    # construct transmitter
    transmitter = MLP(input_dim=kwargs['latent_dim'],
                      output_dim=kwargs['latent_dim'],
                      hidden_dims=[kwargs['latent_dim']]).to(kwargs['device'])

    ae_eval_train_history = defaultdict(list)
    ae_eval_test_history = defaultdict(list)

    if kwargs['retrain_flag']:
        cleit_params = [
            autoencoder.parameters(),
            transmitter.parameters()
        ]
        cleit_optimizer = torch.optim.AdamW(chain(*cleit_params), lr=kwargs['lr'])
        # start autoencoder pretraining
        for epoch in range(int(kwargs['train_num_epochs'])):
            print(f'---- Autoencoder Training Epoch {epoch} ----')  # log every epoch
            for step, batch in enumerate(dataloader):
                ae_eval_train_history = cleit_train_step(ae=autoencoder,
                                                         reference_encoder=reference_encoder,
                                                         transmitter=transmitter,
                                                         batch=batch,
                                                         device=kwargs['device'],
                                                         optimizer=cleit_optimizer,
                                                         history=ae_eval_train_history)
        torch.save(autoencoder.state_dict(), os.path.join(kwargs['model_save_folder'], 'cleit_ae.pt'))
        torch.save(transmitter.state_dict(), os.path.join(kwargs['model_save_folder'], 'transmitter.pt'))
    else:
        try:
            autoencoder.load_state_dict(torch.load(os.path.join(kwargs['model_save_folder'], 'cleit_ae.pt')))
            transmitter.load_state_dict(torch.load(os.path.join(kwargs['model_save_folder'], 'transmitter.pt')))
        except FileNotFoundError:
            raise Exception("No pre-trained encoder")

    encoder = EncoderDecoder(encoder=autoencoder.encoder,
                             decoder=transmitter).to(kwargs['device'])

    return encoder, (ae_eval_train_history, ae_eval_test_history)
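
All configuration is read from kwargs; a minimal usage sketch with hypothetical dimensions and paths (the values below are assumptions, not taken from the source):

config = {
    'input_dim': 5000,            # hypothetical feature dimension
    'latent_dim': 128,
    'encoder_hidden_dims': [512, 256],
    'dop': 0.1,                   # dropout probability
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'lr': 1e-4,
    'train_num_epochs': 500,
    'retrain_flag': True,
    'model_save_folder': './model_save/cleit',
}
encoder, histories = train_cleitc(dataloader, seed=2020, **config)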
Example #2
import os
from collections import defaultdict

import torch

# AE and ae_train_step are defined elsewhere in the project.


def train_ae(dataloader, **kwargs):
    """Pretrain a plain autoencoder and return its encoder.

    :param dataloader: training data loader
    :param kwargs: configuration (input_dim, latent_dim, encoder_hidden_dims, dop,
        device, lr, train_num_epochs, retrain_flag, model_save_folder)
    :return: (encoder, (train_history, test_history))
    """
    autoencoder = AE(input_dim=kwargs['input_dim'],
                     latent_dim=kwargs['latent_dim'],
                     hidden_dims=kwargs['encoder_hidden_dims'],
                     dop=kwargs['dop']).to(kwargs['device'])

    ae_eval_train_history = defaultdict(list)
    ae_eval_test_history = defaultdict(list)

    if kwargs['retrain_flag']:
        ae_optimizer = torch.optim.AdamW(autoencoder.parameters(),
                                         lr=kwargs['lr'])
        # start autoencoder pretraining
        for epoch in range(int(kwargs['train_num_epochs'])):
            if epoch % 50 == 0:
                print(f'---- Autoencoder Training Epoch {epoch} ----')
            for step, batch in enumerate(dataloader):
                ae_eval_train_history = ae_train_step(
                    ae=autoencoder,
                    batch=batch,
                    device=kwargs['device'],
                    optimizer=ae_optimizer,
                    history=ae_eval_train_history)
        torch.save(autoencoder.state_dict(),
                   os.path.join(kwargs['model_save_folder'], 'ae.pt'))
    else:
        try:
            autoencoder.load_state_dict(
                torch.load(os.path.join(kwargs['model_save_folder'], 'ae.pt')))
        except FileNotFoundError:
            raise Exception("No pre-trained encoder")

    return autoencoder.encoder, (ae_eval_train_history, ae_eval_test_history)
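
ae_train_step is defined elsewhere in the project; a minimal sketch of what such a reconstruction step typically looks like, assuming (hypothetically) that the AE's forward pass returns a reconstruction of its input:

def ae_train_step_sketch(ae, batch, device, optimizer, history):
    # Hypothetical sketch: the project's real ae_train_step may unpack
    # batches and compute its loss differently.
    ae.train()
    x = batch[0].to(device)
    recon = ae(x)
    loss = torch.nn.functional.mse_loss(recon, x)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    history['loss'].append(loss.item())
    return history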
Example #3
import os
from collections import defaultdict
from copy import deepcopy
from itertools import chain

import torch

# AE, MLP, EncoderDecoder, critic_train_step, and gan_gen_train_step are
# defined elsewhere in the project.


def train_cleita(dataloader, seed, **kwargs):
    """Adversarial variant of CLEIT pretraining: aligns the transmitted latent
    space with a frozen reference encoder via a confounding WGAN critic."""
    autoencoder = AE(input_dim=kwargs['input_dim'],
                     latent_dim=kwargs['latent_dim'],
                     hidden_dims=kwargs['encoder_hidden_dims'],
                     dop=kwargs['dop']).to(kwargs['device'])

    # get reference encoder
    aux_ae = deepcopy(autoencoder)

    aux_ae.encoder.load_state_dict(
        torch.load(os.path.join('./model_save', f'ft_encoder_{seed}.pt')))
    print('reference encoder loaded')
    reference_encoder = aux_ae.encoder

    # construct transmitter
    transmitter = MLP(input_dim=kwargs['latent_dim'],
                      output_dim=kwargs['latent_dim'],
                      hidden_dims=[kwargs['latent_dim']]).to(kwargs['device'])

    confounding_classifier = MLP(input_dim=kwargs['latent_dim'],
                                 output_dim=1,
                                 hidden_dims=kwargs['classifier_hidden_dims'],
                                 dop=kwargs['dop']).to(kwargs['device'])

    ae_train_history = defaultdict(list)
    ae_val_history = defaultdict(list)
    critic_train_history = defaultdict(list)
    gen_train_history = defaultdict(list)

    if kwargs['retrain_flag']:
        cleit_params = [autoencoder.parameters(), transmitter.parameters()]
        cleit_optimizer = torch.optim.AdamW(chain(*cleit_params),
                                            lr=kwargs['lr'])
        classifier_optimizer = torch.optim.RMSprop(
            confounding_classifier.parameters(), lr=kwargs['lr'])
        for epoch in range(int(kwargs['train_num_epochs'])):
            if epoch % 50 == 0:
                print(f'---- Confounder WGAN Training Epoch {epoch} ----')
            for step, batch in enumerate(dataloader):
                critic_train_history = critic_train_step(
                    critic=confounding_classifier,
                    ae=autoencoder,
                    reference_encoder=reference_encoder,
                    transmitter=transmitter,
                    batch=batch,
                    device=kwargs['device'],
                    optimizer=classifier_optimizer,
                    history=critic_train_history,
                    # clip=0.1,
                    gp=10.0)
                if (step + 1) % 5 == 0:
                    gen_train_history = gan_gen_train_step(
                        critic=confounding_classifier,
                        ae=autoencoder,
                        transmitter=transmitter,
                        batch=batch,
                        device=kwargs['device'],
                        optimizer=cleit_optimizer,
                        alpha=1.0,
                        history=gen_train_history)

        torch.save(autoencoder.state_dict(),
                   os.path.join(kwargs['model_save_folder'], 'cleit_ae.pt'))
        torch.save(transmitter.state_dict(),
                   os.path.join(kwargs['model_save_folder'], 'transmitter.pt'))
    else:
        try:
            autoencoder.load_state_dict(
                torch.load(
                    os.path.join(kwargs['model_save_folder'], 'cleit_ae.pt')))
            transmitter.load_state_dict(
                torch.load(
                    os.path.join(kwargs['model_save_folder'],
                                 'transmitter.pt')))
        except FileNotFoundError:
            raise Exception("No pre-trained encoder")

    encoder = EncoderDecoder(encoder=autoencoder.encoder,
                             decoder=transmitter).to(kwargs['device'])

    return encoder, (ae_train_history, ae_val_history, critic_train_history,
                     gen_train_history)
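
The gp=10.0 passed to critic_train_step matches the standard WGAN-GP penalty weight, and the (step + 1) % 5 == 0 guard runs one generator update per five critic updates, following the usual WGAN-GP recipe (the commented-out clip=0.1 is the older weight-clipping alternative). A minimal sketch of the gradient-penalty term itself; the project's real critic_train_step is defined elsewhere:

def gradient_penalty_sketch(critic, real_codes, fake_codes, device):
    # Interpolate between reference and transmitted latent codes.
    eps = torch.rand(real_codes.size(0), 1, device=device)
    interp = (eps * real_codes + (1 - eps) * fake_codes).requires_grad_(True)
    scores = critic(interp)
    grads = torch.autograd.grad(outputs=scores, inputs=interp,
                                grad_outputs=torch.ones_like(scores),
                                create_graph=True)[0]
    # Penalize deviation of the critic's gradient norm from 1 (scaled by gp).
    return ((grads.norm(2, dim=1) - 1.0) ** 2).mean()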
Example #4
import os

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from tqdm import tqdm

# AE and the configuration variables (val_data_folder, batch_size, K,
# saved_model_name, save_folder_name) come from the surrounding script.

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

test_transform = transforms.Compose([
    transforms.ToTensor(),
])

testset = torchvision.datasets.ImageFolder(val_data_folder,
                                           transform=test_transform)
test_loader = torch.utils.data.DataLoader(testset,
                                          batch_size=batch_size,
                                          shuffle=False,
                                          num_workers=20)

model = AE(K=K).to(device)
model = nn.DataParallel(model, device_ids=[0])
model.load_state_dict(
    torch.load(saved_model_name, map_location={'cuda:1': 'cuda:0'}))

# The per-batch tensors are written into out/, lab/, and hash/ subfolders,
# which must exist before torch.save is called.
for sub in ('out', 'lab', 'hash'):
    os.makedirs(os.path.join(save_folder_name, sub), exist_ok=True)

model.eval()  # set eval mode once rather than every batch
with torch.no_grad(), tqdm(total=len(test_loader), desc="Batches") as pbar:
    for i, (img, labels) in enumerate(test_loader):
        encoded, out, hashed = model(img)
        torch.save(out, os.path.join(save_folder_name, 'out', 'out_{}.pt'.format(i)))
        torch.save(labels, os.path.join(save_folder_name, 'lab', 'lab_{}.pt'.format(i)))
        torch.save(hashed, os.path.join(save_folder_name, 'hash', 'hash_{}.pt'.format(i)))
        pbar.update(1)
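
The per-batch files can later be stitched back together for retrieval evaluation; a short sketch, assuming the same save_folder_name layout as above:

num_batches = len(test_loader)
hashes = torch.cat([torch.load(os.path.join(save_folder_name, 'hash', 'hash_{}.pt'.format(i)),
                               map_location='cpu')
                    for i in range(num_batches)], dim=0)
labels = torch.cat([torch.load(os.path.join(save_folder_name, 'lab', 'lab_{}.pt'.format(i)),
                               map_location='cpu')
                    for i in range(num_batches)], dim=0)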
Example #5
    send_slack_notif("Started with train stage-{}".format(trn_stage_no + 1))

    tensorboard_folder = '../runs/ae_hash/ae_stage{}'.format(trn_stage_no + 1)
    writer = SummaryWriter(tensorboard_folder)

    trainset = torchvision.datasets.ImageFolder(train_data_folder, transform=train_transform)
    valset = torchvision.datasets.ImageFolder(val_data_folder, transform=val_transform)

    train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=20)
    val_loader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=20)

    model = AE(K=K).to(device)
    model = nn.DataParallel(model, device_ids=[0, 1])

    if train_stage["use_weight"]:
        model.load_state_dict(torch.load(saved_model_name + '_stage1.pt'), strict=False)

    # Collect the hashing-layer parameters, which train at a 10x higher learning rate than the pretrained layers
    fast_learning_layers = ['hashed_layer.{}'.format(ii) for ii in [0, 2, 4, 6]]
    fast_learning_layers = ['module.' + s + sb for s in fast_learning_layers for sb in ['.weight', '.bias']]

    params = [p for name, p in model.named_parameters() if name in fast_learning_layers]
    base_params = [p for name, p in model.named_parameters() if name not in fast_learning_layers]
    assert len(params) == len(fast_learning_layers)

    # Initializing losses and adding loss params to optimizer with higher lr
    criterion1 = comp_loss()
    criterion2 = cauchy_loss(K=K, q_lambda=q_lambda)
    params.extend(list(criterion1.parameters()))
    params.extend(list(criterion2.parameters()))
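
    # The snippet ends before the optimizer is built; given the comment above,
    # the hashing layers and loss parameters presumably train at ten times the
    # base rate via parameter groups. A hedged sketch of that construction
    # (base_lr and the optimizer choice are assumptions, not from the source):
    base_lr = 1e-4  # assumption; not shown in the snippet
    optimizer = torch.optim.Adam([
        {'params': base_params},                 # pretrained backbone at base_lr
        {'params': params, 'lr': base_lr * 10},  # hashing layers + loss params
    ], lr=base_lr)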