def main():
    CUDA = False
    if torch.cuda.is_available():
        CUDA = True
        print('cuda available')
        torch.backends.cudnn.benchmark = True
    config = config_process(parser.parse_args())
    print(config)

    with open('pkl/task_1_train.pkl', 'rb') as f:
        task_1_train = pkl.load(f)
    with open('pkl/task_1_test.pkl', 'rb') as g:
        task_1_test = pkl.load(g)

    ###### task 0: seen training data and unseen test data
    examples, labels, class_map = image_load(config['class_file'],
                                             config['image_label'])
    ###### task 0: seen test data
    examples_0, labels_0, class_map_0 = image_load(config['class_file'],
                                                   config['test_seen_classes'])

    datasets = split_byclass(config, examples, labels,
                             np.loadtxt(config['attributes_file']), class_map)
    datasets_0 = split_byclass(config, examples_0, labels_0,
                               np.loadtxt(config['attributes_file']),
                               class_map)
    print('loaded task 0 train ({} samples) and task 1 test ({} samples)'.format(
        len(datasets[0][0]), len(datasets[0][1])))
    print('loaded task 0 seen-class test data ({} samples)'.format(
        len(datasets_0[0][0])))

    test_attr = F.normalize(datasets[0][4])

    best_cfg = config
    best_cfg['n_classes'] = datasets[0][3].size(0)
    best_cfg['n_train_lbl'] = datasets[0][3].size(0)
    best_cfg['n_test_lbl'] = datasets[0][4].size(0)

    task_1_train_set = grab_data(best_cfg, task_1_train, datasets[0][2], True)
    # task_1_test_set = grab_data(best_cfg, task_1_test, datasets[0][2], False)

    base_model = models.__dict__[config['arch']](pretrained=False)
    if config['arch'].startswith('resnet'):
        # Drop the final fc layer so the ResNet serves as a pooled feature extractor.
        FE_model = nn.Sequential(*list(base_model.children())[:-1])
    else:
        print('architecture {} is untested'.format(config['arch']))
        raise NotImplementedError

    ###### if finetune == False, load the pretrained feature extractor checkpoint
    print('load pretrained FE_model')
    FE_path = './ckpts/{}_{}_{}_task_id_{}_finetune_{}_{}'.format(
        config['dataset'], config['softmax_method'], config['arch'],
        config['task_id'], config['finetune'], 'checkpoint.pth')

    FE_model.load_state_dict(torch.load(FE_path)['state_dict_FE'])
    for para in FE_model.parameters():  # freeze the feature extractor
        para.requires_grad = False

    vae = VAE(encoder_layer_sizes=config['encoder_layer_sizes'],
              latent_size=config['latent_size'],
              decoder_layer_sizes=config['decoder_layer_sizes'],
              num_labels=config['num_labels'])

    vae_path = './ckpts/{}_{}_{}_task_id_{}_finetune_{}_{}'.format(
        config['dataset'], 'vae', config['arch'], 0, config['finetune'],
        'ckpt.pth')
    vae.load_state_dict(torch.load(vae_path))

    print(vae)
    if CUDA:
        FE_model = FE_model.cuda()
        vae = vae.cuda()
    FE_model.eval()
    optimizer = torch.optim.Adam(vae.parameters(), lr=config['lr'])

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                config['step'],
                                                gamma=0.1,
                                                last_epoch=-1)
    criterion = loss_fn
    print('loaded real trainval features and labels')

    for epoch in range(config['epoch']):
        print('\n epoch: %d' % epoch)
        print('...TRAIN...')
        print_learning_rate(optimizer)
        ### relative to task 0: train_attr -> test_attr, task_0_train_set -> task_1_train_set
        train(epoch, FE_model, vae, task_1_train_set, optimizer, criterion,
              test_attr, CUDA)
        scheduler.step()

    vae_ckpt_name = './ckpts/{}_{}_{}_{}_task_id_{}_finetune_{}_{}'.format(
        config['dataset'], config['method'], config['softmax_method'],
        config['arch'], config['task_id'], config['finetune'], 'ckpt.pth')
    torch.save(vae.state_dict(), vae_ckpt_name)
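
The loop above sets `criterion = loss_fn`, but `loss_fn` itself is not part of this excerpt. A minimal sketch of the standard VAE objective it presumably resembles; the signature and reduction here are assumptions, not the project's actual code:

import torch
import torch.nn.functional as F

def loss_fn(recon_x, x, mu, logvar):
    # Assumed stand-in: reconstruction error plus KL divergence to the
    # unit-Gaussian prior, averaged over the batch.
    bce = F.binary_cross_entropy(recon_x, x, reduction='sum')
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return (bce + kld) / x.size(0)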
Example #2
model = nn.DataParallel(model)  # multi-GPU wrapper; saved keys gain a 'module.' prefix
model = model.to(device)

criterion = vae_loss_function
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

#
# Training
#

model.train()
for epoch in range(epochs):
    loss, bce, kld = 0, 0, 0

    max_batches = len(unlabeled_trainloader)
    for idx, (images, camera_index) in enumerate(unlabeled_trainloader):
        images = images.to(device)
        recon_images, mu, logvar = model(images)

        loss, bce, kld = criterion(recon_images, images, mu, logvar)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if idx % 1000 == 0:
            print('[', epoch, '|', idx, '/', max_batches, ']', 'loss:',
                  loss.item(), 'bce:', bce.item(), 'kld:', kld.item())

    torch.save(model.state_dict(), 'vae-epoch-latest.torch')
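
Because the model was wrapped in nn.DataParallel before saving, every key in 'vae-epoch-latest.torch' carries a 'module.' prefix; loading it back into an unwrapped model means stripping that prefix first. A sketch, where the VAE constructor call is an assumption:

import torch

state = torch.load('vae-epoch-latest.torch', map_location='cpu')
# nn.DataParallel prefixes every parameter name with 'module.'; strip it
# before loading into an unwrapped model.
state = {k.replace('module.', '', 1): v for k, v in state.items()}
model = VAE()  # assumed constructor; mirror however the training script built it
model.load_state_dict(state)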
Example #3
def main():
    parser = _build_parser()
    args = parser.parse_args()

    logging.basicConfig(format="%(levelname)s: %(message)s",
                        level=logging.DEBUG)

    device = torch.device('cpu')
    if torch.cuda.is_available():
        device = torch.device('cuda')

    model = None
    save_path = args.save
    if args.model == 'vae':
        model = VAE(args.dim)
        logging.info('training VAE with dims: {}'.format(args.dim))
    elif args.model == 'ae':
        model = AE(args.dim)
        logging.info('training AE with dims: {}'.format(args.dim))
    elif args.model == 'hm':
        model = HM(args.color)
    elif args.model == 'gmvae':
        model = GMVAE()
    else:
        logging.critical('model unimplemented: %s' % args.model)
        return

    save_path.mkdir(parents=True, exist_ok=True)

    model = model.float()
    model.to(device)
    optimizer = optim.Adam(model.parameters())

    train_ds, test_ds = build_datasets(args.path)

    losses = []
    for e in range(args.epochs):
        logging.info('epoch: %d of %d' % (e + 1, args.epochs))

        loader = DataLoader(train_ds,
                            batch_size=args.batch_size,
                            shuffle=True,
                            num_workers=args.workers,
                            pin_memory=torch.cuda.is_available())
        total_batches = len(train_ds) // args.batch_size

        log_every = total_batches // 50 + 1
        save_every = 1  # hardcoded for now
        for i, x in enumerate(loader):
            x = x.to(device)
            optimizer.zero_grad()
            output = model(x)
            total_loss = model.loss_function(output)
            if isinstance(total_loss, dict):  # TODO: generalize loss handling
                total_loss = total_loss['loss']

            total_loss.backward()
            optimizer.step()

            if i % log_every == 0:
                model.eval()
                loss = _eval(model, test_ds, device)
                model.train()

                logging.info('[batch %d/%d] ' % (i + 1, total_batches) +
                             model.print_loss(loss))
                # TODO: generalize printing
                # print_params = (i+1, total_batches, loss['loss'], loss['mse'], loss['kld'])
                # logging.info('[batch %d/%d] loss: %f, mse: %f, kld: %f' % print_params)
                # print_params = (i+1, total_batches, loss)
                # logging.info('[batch %d/%d] loss: %f' % print_params)
                losses.append({'iter': i, 'epoch': e, 'loss': loss})

        if e % save_every == 0:
            torch.save(
                {
                    'epoch': e + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss
                }, save_path / ('epoch_%d.pt' % (e + 1)))

    model.eval()
    loss = _eval(model, test_ds, device)
    model.train()

    logging.info('final loss: %s' % loss)
    losses.append({'iter': 0, 'epoch': e + 1, 'loss': loss})

    with open(save_path / 'loss.pk', 'wb') as pkf:
        pickle.dump(losses, pkf)

    torch.save(
        {
            'epoch': args.epochs,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss
        }, save_path / 'final.pt')
    print('done!')
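
`_eval` is called throughout but never shown. A plausible minimal version consistent with how its return value is used; the batch size is an assumption, and the real helper may return a dict of loss components rather than a float:

def _eval(model, test_ds, device):
    # Assumed helper: average the model's loss over the test set with
    # gradients disabled.
    loader = DataLoader(test_ds, batch_size=256)
    total, batches = 0.0, 0
    with torch.no_grad():
        for x in loader:
            x = x.to(device)
            loss = model.loss_function(model(x))
            if isinstance(loss, dict):  # mirror the training loop's handling
                loss = loss['loss']
            total += loss.item()
            batches += 1
    return total / max(batches, 1)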
Example #4
    print('feature data path: ' + path)
    X = io.mmread(path).A.astype('float32')
    args.n, args.d = X.shape

    # X = normalize(X, norm='l2', axis=0)

    vae = VAE(args)
    if args.load:
        path = '{}/model/vae_{}_{}'.format(params['dataset_dir'],
                                           params['dataset'], layer)
        vae.load_state_dict(torch.load(path))

    vae.to(args.device)

    # psm = PSM(torch.from_numpy(R.astype('float32')).to(args.device), args).to(args.device)

    loader = DataLoader(np.arange(args.n), batch_size=1, shuffle=True)
    optimizer = optim.Adam(vae.parameters(), lr=args.lr)

    evaluator = Evaluator({'recall', 'dcg_cut'})
    # variational()
    # maximum()
    # evaluate()
    train()

    if args.save:
        vae.cpu()
        path = '{}/model/vae_{}_{}'.format(params['dataset_dir'],
                                           params['dataset'], layer)
        torch.save(vae.state_dict(), path)
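
`train()` is invoked with no arguments and closes over the surrounding scope (vae, loader, optimizer, X, args); note that the DataLoader yields row indices into the dense matrix X, not the rows themselves. A sketch of a step consistent with that pattern, where the VAE forward signature and criterion are assumptions:

def train():
    vae.train()
    for idx in loader:  # each batch is a LongTensor of row indices into X
        x = torch.from_numpy(X[idx.numpy()]).to(args.device)
        optimizer.zero_grad()
        recon, mu, logvar = vae(x)  # assumed forward signature
        loss = loss_fn(recon, x, mu, logvar)  # assumed criterion, e.g. MSE + KLD
        loss.backward()
        optimizer.step()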
Example #5
            test_loss += loss
            test_reconstr += reconstr
            test_kld += kld
            if i == 0:
                n = min(data.size(0), 8)
                comparison = torch.cat([data[:n], recon_batch[:n]])
                writer.add_image('Test/Reconstructions', make_grid(comparison),
                                 epoch)

    test_loss /= len(test_loader)
    test_reconstr /= len(test_loader)
    test_kld /= len(test_loader)

    # After finishing epoch 0, we have seen 1 * len(train_loader) batches
    x_value = len(train_loader) * (epoch + 1)
    writer.add_scalar('Test/Loss_Combined', test_loss, x_value)
    writer.add_scalar('Test/Reconstr', test_reconstr, x_value)
    writer.add_scalar('Test/KLD', test_kld, x_value)


if __name__ == "__main__":
    run_name = f'{args.run_name}_bs={args.batch_size}_beta={args.beta}'
    writer = SummaryWriter(f'./vae_tensorboard_logs/{run_name}')
    checkpoint_dir = f'./vae_checkpoints/{run_name}'
    os.makedirs(checkpoint_dir, exist_ok=True)
    for epoch in trange(args.epochs, leave=True, desc='Epoch'):
        train(epoch, writer)
        test(epoch, writer)
        scheduler.step()
        torch.save(model.state_dict(), f'{checkpoint_dir}/epoch_{epoch}.pt')
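
The run name records args.beta, which suggests the (unshown) criterion weights the KL term as in a beta-VAE, and the test() fragment above accumulates three components (loss, reconstr, kld) from it. A sketch matching both observations; this is an assumption based only on those clues:

import torch
import torch.nn.functional as F

def criterion(recon_x, x, mu, logvar, beta=1.0):
    # Assumed beta-VAE criterion: returns (total, reconstruction, KL) so the
    # caller can log each component; beta scales the KL term.
    reconstr = F.binary_cross_entropy(recon_x, x, reduction='sum')
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return reconstr + beta * kld, reconstr, kld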