def main():
    CUDA = False
    if torch.cuda.is_available():
        CUDA = True
        print('cuda available')
        torch.backends.cudnn.benchmark = True
    config = config_process(parser.parse_args())
    print(config)

    with open('pkl/task_1_train.pkl', 'rb') as f:
        task_1_train = pkl.load(f)
    with open('pkl/task_1_test.pkl', 'rb') as g:
        task_1_test = pkl.load(g)

    ###### task 0: seen training data and unseen test data
    examples, labels, class_map = image_load(config['class_file'],
                                             config['image_label'])
    ###### task 0: seen test data
    examples_0, labels_0, class_map_0 = image_load(config['class_file'],
                                                   config['test_seen_classes'])

    datasets = split_byclass(config, examples, labels,
                             np.loadtxt(config['attributes_file']), class_map)
    datasets_0 = split_byclass(config, examples_0, labels_0,
                               np.loadtxt(config['attributes_file']),
                               class_map)
    print('loaded task 0 train: {}, task 1 (as unseen test): {}'.format(
        len(datasets[0][0]), len(datasets[0][1])))
    print('loaded task 0 (seen) test data: {}'.format(len(datasets_0[0][0])))

    test_attr = F.normalize(datasets[0][4])

    best_cfg = config
    best_cfg['n_classes'] = datasets[0][3].size(0)
    best_cfg['n_train_lbl'] = datasets[0][3].size(0)
    best_cfg['n_test_lbl'] = datasets[0][4].size(0)

    task_1_train_set = grab_data(best_cfg, task_1_train, datasets[0][2], True)
    # task_1_test_set = grab_data(best_cfg, task_1_test, datasets[0][2], False)

    base_model = models.__dict__[config['arch']](pretrained=False)
    if config['arch'].startswith('resnet'):
        FE_model = nn.Sequential(*list(base_model.children())[:-1])
    else:
        # only ResNet backbones are handled as the feature extractor here
        raise NotImplementedError('untested architecture: {}'.format(
            config['arch']))

    ###### load the pretrained feature extractor (used when finetune == False)
    print('load pretrained FE_model')
    FE_path = './ckpts/{}_{}_{}_task_id_{}_finetune_{}_{}'.format(
        config['dataset'], config['softmax_method'], config['arch'],
        config['task_id'], config['finetune'], 'checkpoint.pth')

    FE_model.load_state_dict(torch.load(FE_path)['state_dict_FE'])
    # freeze the pretrained feature extractor
    for para in FE_model.parameters():
        para.requires_grad = False

    vae = VAE(encoder_layer_sizes=config['encoder_layer_sizes'],
              latent_size=config['latent_size'],
              decoder_layer_sizes=config['decoder_layer_sizes'],
              num_labels=config['num_labels'])

    vae_path = './ckpts/{}_{}_{}_task_id_{}_finetune_{}_{}'.format(
        config['dataset'], 'vae', config['arch'], 0, config['finetune'],
        'ckpt.pth')
    vae.load_state_dict(torch.load(vae_path))

    print(vae)
    if CUDA:
        FE_model = FE_model.cuda()
        vae = vae.cuda()
    FE_model.eval()
    optimizer = torch.optim.Adam(vae.parameters(), lr=config['lr'])

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                config['step'],
                                                gamma=0.1,
                                                last_epoch=-1)
    criterion = loss_fn
    print('loaded real trainval features and labels')

    for epoch in range(config['epoch']):
        print('\n epoch: %d' % epoch)
        print('...TRAIN...')
        print_learning_rate(optimizer)
        ### train_attr--->test_attr, task_0_train_set---> task_1_train_set
        train(epoch, FE_model, vae, task_1_train_set, optimizer, criterion,
              test_attr, CUDA)
        scheduler.step()

    vae_ckpt_name = './ckpts/{}_{}_{}_{}_task_id_{}_finetune_{}_{}'.format(
        config['dataset'], config['method'], config['softmax_method'],
        config['arch'], config['task_id'], config['finetune'], 'ckpt.pth')
    torch.save(vae.state_dict(), vae_ckpt_name)
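# A minimal, hypothetical sketch of the loss_fn criterion referenced above
# (criterion = loss_fn); it is not defined in this snippet. The signature and
# the MSE reconstruction term are assumptions: a standard VAE objective of
# reconstruction error plus the KL divergence to a unit Gaussian prior.
import torch
import torch.nn.functional as F

def loss_fn(recon_x, x, mu, logvar):
    # reconstruction term (the original may use BCE instead of MSE)
    recon = F.mse_loss(recon_x, x, reduction='sum')
    # KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return (recon + kld) / x.size(0)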
Example #2
if args.cuda:
    torch.cuda.manual_seed(args.seed)

train_loader = DataLoader(get_dataset(dsize=args.size),
                          batch_size=args.batch_size,
                          shuffle=True)
test_loader = DataLoader(get_dataset(root="../data/test", dsize=args.size),
                         batch_size=args.batch_size,
                         shuffle=True)

model = VAE(args)
model.apply(weights_init())
if args.cuda:
    model.cuda()

optimizer = optim.Adam(model.parameters(), lr=1e-4)

if args.rec_loss == "ssim":
    rec_loss = SSIMLoss(method=args.ssim_method)
elif args.rec_loss == "bce":
    rec_loss = nn.BCELoss(size_average=args.size_average)
elif args.rec_loss == "l1":
    rec_loss = nn.L1Loss(size_average=args.size_average)
else:
    raise ValueError('unsupported reconstruction loss: {}'.format(args.rec_loss))
kl_loss = KLLoss(size_average=args.size_average)


def loss_function(x, x_rec, mu, logvar):
    return rec_loss(x, x_rec) + kl_loss(mu, logvar)
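# Hypothetical usage sketch of the loss_function above (not part of the
# original snippet). It assumes model(x) returns (x_rec, mu, logvar) and that
# the loader yields image tensors.
model.train()
for x in train_loader:
    if args.cuda:
        x = x.cuda()
    optimizer.zero_grad()
    x_rec, mu, logvar = model(x)
    loss = loss_function(x, x_rec, mu, logvar)
    loss.backward()
    optimizer.step()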
Example #3
#
# Model
#

device = torch.device('cuda')

epochs = 50
hidden_size = 1024
latent_size = 512

model = VAE(h_dim=hidden_size, z_dim=latent_size)
model = nn.DataParallel(model)
model = model.to(device)

criterion = vae_loss_function
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

#
# Training
#

model.train()
for epoch in range(epochs):
    loss, bce, kld = 0, 0, 0

    max_batches = len(unlabeled_trainloader)
    for idx, (images, camera_index) in enumerate(unlabeled_trainloader):
        images = images.to(device)
        recon_images, mu, logvar = model(images)

        loss, bce, kld = criterion(recon_images, images, mu, logvar)
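# Hypothetical sketch of the vae_loss_function used as the criterion above; its
# real definition is not shown. It must return a (loss, bce, kld) triple to
# match the unpacking in the loop; the BCE reconstruction term is an assumption.
import torch
import torch.nn.functional as F

def vae_loss_function(recon_x, x, mu, logvar):
    # per-batch reconstruction error (assumes recon_x values lie in [0, 1])
    bce = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # KL divergence between the approximate posterior and the unit Gaussian
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + kld, bce, kld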
Example #4
    print('feature data path: ' + path)
    X = io.mmread(path).A.astype('float32')
    args.n, args.d = X.shape

    # X = normalize(X, norm='l2', axis=0)

    vae = VAE(args)
    if args.load:
        path = '{}/model/vae_{}_{}'.format(params['dataset_dir'],
                                           params['dataset'], layer)
        vae.load_state_dict(torch.load(path))

    vae.to(args.device)

    # psm = PSM(torch.from_numpy(R.astype('float32')).to(args.device), args).to(args.device)

    loader = DataLoader(np.arange(args.n), batch_size=1, shuffle=True)
    optimizer = optim.Adam(vae.parameters(), lr=args.lr)

    evaluator = Evaluator({'recall', 'dcg_cut'})
    # variational()
    # maximum()
    # evaluate()
    train()

    if args.save:
        vae.cpu()
        path = '{}/model/vae_{}_{}'.format(params['dataset_dir'],
                                           params['dataset'], layer)
        torch.save(vae.state_dict(), path)
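# Hypothetical sketch of the train() call above (its definition is not shown).
# It assumes vae(x) returns (recon, mu, logvar), that X holds one feature row
# per index yielded by the loader, and that args.epochs exists.
import torch
import torch.nn.functional as F

def train():
    vae.train()
    for epoch in range(args.epochs):
        for idx in loader:
            # the loader yields batches of row indices into the feature matrix X
            x = torch.from_numpy(X[idx.numpy()]).to(args.device)
            optimizer.zero_grad()
            recon, mu, logvar = vae(x)
            kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            loss = F.mse_loss(recon, x, reduction='sum') + kld
            loss.backward()
            optimizer.step()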
Example #5
def main():
    parser = _build_parser()
    args = parser.parse_args()

    logging.basicConfig(format="%(levelname)s: %(message)s",
                        level=logging.DEBUG)

    device = torch.device('cpu')
    if torch.cuda.is_available():
        device = torch.device('cuda')

    model = None
    save_path = args.save
    if args.model == 'vae':
        model = VAE(args.dim)
        logging.info('training VAE with dims: {}'.format(args.dim))
    elif args.model == 'ae':
        model = AE(args.dim)
        logging.info('training AE with dims: {}'.format(args.dim))
    elif args.model == 'hm':
        model = HM(args.color)
    elif args.model == 'gmvae':
        model = GMVAE()
    else:
        logging.critical('model unimplemented: %s' % args.model)
        return

    if not save_path.exists():
        save_path.mkdir(parents=True)

    model = model.float()
    model.to(device)
    optimizer = optim.Adam(model.parameters())

    train_ds, test_ds = build_datasets(args.path)

    losses = []
    for e in range(args.epochs):
        logging.info('epoch: %d of %d' % (e + 1, args.epochs))

        loader = DataLoader(train_ds,
                            batch_size=args.batch_size,
                            shuffle=True,
                            num_workers=args.workers,
                            pin_memory=torch.cuda.is_available())
        total_batches = len(train_ds) // args.batch_size

        log_every = total_batches // 50 + 1
        save_every = 1  # hardcoded for now
        for i, x in enumerate(loader):
            x = x.to(device)
            optimizer.zero_grad()
            output = model(x)
            total_loss = model.loss_function(output)
            if isinstance(total_loss, dict):  # TODO: generalize loss handling
                total_loss = total_loss['loss']

            total_loss.backward()
            optimizer.step()

            if i % log_every == 0:
                model.eval()
                loss = _eval(model, test_ds, device)
                model.train()

                logging.info('[batch %d/%d] ' % (i + 1, total_batches) +
                             model.print_loss(loss))
                # TODO: generalize printing
                # print_params = (i+1, total_batches, loss['loss'], loss['mse'], loss['kld'])
                # logging.info('[batch %d/%d] loss: %f, mse: %f, kld: %f' % print_params)
                # print_params = (i+1, total_batches, loss)
                # logging.info('[batch %d/%d] loss: %f' % print_params)
                losses.append({'iter': i, 'epoch': e, 'loss': loss})

        if e % save_every == 0:
            torch.save(
                {
                    'epoch': e + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss
                }, save_path / ('epoch_%d.pt' % (e + 1)))

    model.eval()
    loss = _eval(model, test_ds, device)
    model.train()

    logging.info('final loss: %s' % loss)
    losses.append({'iter': 0, 'epoch': e + 1, 'loss': loss})

    with open(save_path / 'loss.pk', 'wb') as pkf:
        pickle.dump(losses, pkf)

    torch.save(
        {
            'epoch': args.epochs,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss
        }, save_path / 'final.pt')
    print('done!')
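# Hypothetical sketch of the _eval helper used in the loop above (its real
# definition is not part of this snippet). It assumes model.loss_function
# returns either a scalar tensor or a dict with a 'loss' key, mirroring the
# training loop; the batch size of 64 is arbitrary.
def _eval(model, test_ds, device, batch_size=64):
    loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)
    total, n = 0.0, 0
    with torch.no_grad():
        for x in loader:
            x = x.to(device)
            loss = model.loss_function(model(x))
            if isinstance(loss, dict):
                loss = loss['loss']
            total += loss.item() * x.size(0)
            n += x.size(0)
    return total / max(n, 1)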
Example #6
train_loader, img_size = get_dataloader('./data/celeba_preprocessed',
                                        args.batch_size,
                                        n_train=-1,
                                        train=True)
test_loader, _ = get_dataloader('./data/celeba_preprocessed',
                                args.batch_size,
                                n_train=-1,
                                train=False)

save_img_every_n = 20

latent_dim = 512
lr = 2e-4
model = VAE(latent_dim).to(device)
optimizer = optim.AdamW(model.parameters(), lr=lr, betas=(0.9, 0.999))
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)


# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
    reconstr = F.mse_loss(recon_x, x,
                          reduction='none').sum(3).sum(2).sum(1).mean(0)

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = torch.mean(-0.5 *
                     torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1),
                     dim=0)

    # total objective: reconstruction term plus KL divergence
    return reconstr + KLD
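# Hypothetical usage sketch (not from the original snippet): one epoch with the
# loss above; it assumes model(x) returns (recon_x, mu, logvar) and that the
# loader yields (image, label) batches.
model.train()
for i, (x, _) in enumerate(train_loader):
    x = x.to(device)
    optimizer.zero_grad()
    recon_x, mu, logvar = model(x)
    loss = loss_function(recon_x, x, mu, logvar)
    loss.backward()
    optimizer.step()
scheduler.step()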
Example #7
                        help='load pretrained model (default: False)')

    args = parser.parse_args()

    batch_loader = BatchLoader()
    parameters = Parameters(batch_loader.vocab_size)

    vae = VAE(parameters.vocab_size, parameters.embed_size,
              parameters.latent_size, parameters.decoder_rnn_size,
              parameters.decoder_rnn_num_layers)
    if args.use_trained:
        vae.load_state_dict(t.load('trained_VAE'))
    if args.use_cuda:
        vae = vae.cuda()

    optimizer = Adam(vae.parameters(), args.learning_rate)

    for iteration in range(args.num_iterations):
        '''Train step'''
        input, decoder_input, target = batch_loader.next_batch(
            args.batch_size, 'train', args.use_cuda)
        target = target.view(-1)

        logits, aux_logits, kld = vae(args.dropout, input, decoder_input)

        logits = logits.view(-1, batch_loader.vocab_size)
        cross_entropy = F.cross_entropy(logits, target, size_average=False)

        aux_logits = aux_logits.view(-1, batch_loader.vocab_size)
        aux_cross_entropy = F.cross_entropy(aux_logits,
                                            target,