def pred(config, mode='cifar10'):
    """Sample a 5x5 grid of images from a PixelCNN and save it to ``images/sample.png``.

    Pixels are generated autoregressively: the model is re-run once per
    spatial position ``(i, j)`` and only the freshly sampled pixel at that
    position is written back into the canvas.

    Args:
        config: Namespace-like object; this function reads ``nr_resnet``,
            ``nr_filters``, ``nr_logistic_mix`` and ``load_params`` (a
            checkpoint path, or a falsy value to sample from random weights).
        mode: Dataset the model was trained on. ``'cifar10'`` and ``'faces'``
            produce 3x32x32 samples; ``'mnist'`` produces 1x28x28 samples
            (same dispatch as ``train``).

    Raises:
        ValueError: if ``mode`` is not a supported dataset name.
    """
    import os

    if mode in ('cifar10', 'faces'):
        obs = (3, 32, 32)
        sample_op = lambda x: sample_from_discretized_mix_logistic(x, config.nr_logistic_mix)
    elif mode == 'mnist':
        obs = (1, 28, 28)
        sample_op = lambda x: sample_from_discretized_mix_logistic_1d(x, config.nr_logistic_mix)
    else:
        # Previously an unsupported mode surfaced later as a NameError on `obs`.
        raise ValueError('unsupported mode: {!r}'.format(mode))

    sample_batch_size = 25
    # Affine map .5 * x + .5; presumably samples lie in [-1, 1] and
    # save_image expects [0, 1].
    rescaling_inv = lambda x: .5 * x + .5

    model = PixelCNN(nr_resnet=config.nr_resnet, nr_filters=config.nr_filters,
                     input_channels=obs[0], nr_logistic_mix=config.nr_logistic_mix).cuda()
    if config.load_params:
        load_part_of_model(model, config.load_params)
        print('model parameters loaded')

    def sample(model):
        """Autoregressively fill a zero canvas of ``sample_batch_size`` images."""
        model.train(False)
        data = torch.zeros(sample_batch_size, obs[0], obs[1], obs[2]).cuda()
        # One no_grad context for the whole sampling pass instead of one per pixel.
        with torch.no_grad():
            for i in range(obs[1]):
                for j in range(obs[2]):
                    out = model(data, sample=True)
                    out_sample = sample_op(out)
                    # Keep only the newly generated pixel; earlier pixels stay fixed.
                    data[:, :, i, j] = out_sample[:, :, i, j]
        return data

    print('sampling...')
    sample_t = rescaling_inv(sample(model))
    # The output directory was never created here (only `train` made it).
    os.makedirs('images', exist_ok=True)
    save_image(sample_t, 'images/sample.png', nrow=5, padding=0)
def main():
    """Build an R3Net_prior network and its SGD optimizer, then call ``train``.

    Configuration comes from module-level names defined outside this block:
    ``args`` (dict of hyper-parameters and flags), ``ckpt_path``,
    ``exp_name``, ``log_path`` and ``device_id``.

    Parameter groups: every parameter whose name ends in ``'bias'`` gets
    ``2 * args['lr']`` and no explicit weight decay (SGD default). All other
    parameters get ``args['lr']`` and ``args['weight_decay']``.

    If ``args['snapshot']`` is non-empty, network and optimizer state are
    restored from ``<ckpt_path>/<exp_name>/<snapshot>.pth`` and
    ``<snapshot>_optim.pth``. The learning rates are then reset to the
    configured values. If ``args['pretrain']`` is non-empty, weights are then
    loaded through ``load_part_of_model``. ``str(args)`` is written to
    ``log_path`` (overwriting it) before training starts.
    """
    net = R3Net_prior(motion=args['motion'], se_layer=args['se_layer'],
                      attention=args['attention'], pre_attention=args['pre_attention'],
                      isTriplet=args['isTriplet'], basic_model=args['basic_model'],
                      sta=args['sta'], naive_fuse=args['naive']).cuda().train()

    # fix_parameters(net.named_parameters())  # NOTE: parameter freezing is disabled

    optimizer = optim.SGD([
        {
            'params': [param for name, param in net.named_parameters() if name.endswith('bias')],
            'lr': 2 * args['lr'],
        },
        {
            'params': [param for name, param in net.named_parameters() if not name.endswith('bias')],
            'lr': args['lr'],
            'weight_decay': args['weight_decay'],
        },
    ], momentum=args['momentum'])

    if len(args['snapshot']) > 0:
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth')))
        optimizer.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '_optim.pth')))
        # The restored optimizer state carries the old learning rates; override them.
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    if len(args['pretrain']) > 0:
        print('pretrain model from ' + args['pretrain'])
        net = load_part_of_model(net, args['pretrain'], device_id=device_id)

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # `with` closes the log handle deterministically (it previously leaked).
    with open(log_path, 'w') as log_file:
        log_file.write(str(args) + '\n\n')
    train(net, optimizer)
def train(config, mode='cifar10'):
    """Train a PixelCNN on ``mode``'s dataset with periodic checkpoints and samples.

    Each epoch runs one pass over the training loader (Adam, lr decayed by
    ``config.lr_decay`` every epoch via StepLR) and one pass over the test
    loader. It prints the summed loss divided by the number of input values
    seen. Every ``config.save_interval`` epochs it saves
    ``models/<name>_<epoch>.pth`` and a 5x5 sample grid
    ``images/<name>_<epoch>.png``.

    Args:
        config: Namespace-like object; reads ``lr``, ``lr_decay``,
            ``nr_resnet``, ``nr_filters``, ``nr_logistic_mix``,
            ``batch_size``, ``max_epochs``, ``save_interval`` and
            ``load_params`` (checkpoint path or falsy).
        mode: ``'cifar10'``/``'faces'`` (3x32x32) or ``'mnist'`` (1x28x28).

    Raises:
        ValueError: if ``mode`` is not a supported dataset name.
    """
    # No ':' in the name: it is invalid in Windows file names, and the
    # checkpoint path loaded under __main__ uses the 'pcnn_lr_...' form.
    model_name = 'pcnn_lr_{:.5f}_nr-resnet{}_nr-filters{}'.format(
        config.lr, config.nr_resnet, config.nr_filters)
    # Create each directory independently. A shared try/except used to skip
    # 'images' whenever 'models' already existed.
    os.makedirs('models', exist_ok=True)
    os.makedirs('images', exist_ok=True)

    seed = np.random.randint(0, 10000)
    print("Random Seed: ", seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn.benchmark = True

    # Validate the mode before the (potentially slow) data loading.
    if mode == 'cifar10' or mode == 'faces':
        obs = (3, 32, 32)
        loss_op = lambda real, fake: discretized_mix_logistic_loss(real, fake, config.nr_logistic_mix)
        sample_op = lambda x: sample_from_discretized_mix_logistic(x, config.nr_logistic_mix)
    elif mode == 'mnist':
        obs = (1, 28, 28)
        loss_op = lambda real, fake: discretized_mix_logistic_loss_1d(real, fake, config.nr_logistic_mix)
        sample_op = lambda x: sample_from_discretized_mix_logistic_1d(x, config.nr_logistic_mix)
    else:
        raise ValueError('unsupported mode: {!r}'.format(mode))

    _, train_loader, _, test_loader, _ = load_data(mode=mode, batch_size=config.batch_size)

    sample_batch_size = 25
    # Affine map .5 * x + .5; presumably samples lie in [-1, 1] and
    # save_image expects [0, 1].
    rescaling_inv = lambda x: .5 * x + .5

    model = PixelCNN(nr_resnet=config.nr_resnet, nr_filters=config.nr_filters,
                     input_channels=obs[0], nr_logistic_mix=config.nr_logistic_mix).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=config.lr_decay)

    if config.load_params:
        load_part_of_model(model, config.load_params)
        print('model parameters loaded')

    def sample(model):
        """Autoregressively fill a zero canvas of ``sample_batch_size`` images."""
        model.train(False)
        data = torch.zeros(sample_batch_size, obs[0], obs[1], obs[2]).cuda()
        with torch.no_grad(), tqdm(total=obs[1] * obs[2]) as pbar:
            for i in range(obs[1]):
                for j in range(obs[2]):
                    out = model(data, sample=True)
                    out_sample = sample_op(out)
                    # Keep only the newly generated pixel; earlier pixels stay fixed.
                    data[:, :, i, j] = out_sample[:, :, i, j]
                    pbar.update(1)
        return data

    print('starting training')
    for epoch in range(config.max_epochs):
        model.train()
        torch.cuda.synchronize()
        train_loss = 0.
        # Count the input values actually seen. The old `batch_idx * batch_size`
        # denominator dropped one batch, ignored a short final batch, and
        # divided by zero for a single-batch loader.
        n_train = 0
        time_ = time.time()
        with tqdm(total=len(train_loader)) as pbar:
            for data, _ in train_loader:
                # Inputs need no gradient; only model parameters are optimized.
                data = data.cuda()
                output = model(data)
                loss = loss_op(data, output)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                train_loss += loss.item()
                n_train += data.numel()
                pbar.update(1)
        print('train loss : %s' % (train_loss / max(n_train, 1)), end='\t')

        # decrease learning rate
        scheduler.step()

        model.eval()
        test_loss = 0.
        n_test = 0
        # no_grad: evaluation must not build autograd graphs.
        with torch.no_grad(), tqdm(total=len(test_loader)) as pbar:
            for data, _ in test_loader:
                data = data.cuda()
                output = model(data)
                test_loss += loss_op(data, output).item()
                n_test += data.numel()
                pbar.update(1)
        print('test loss : {:.4f}, time : {:.4f}'.format(
            test_loss / max(n_test, 1), time.time() - time_))
        torch.cuda.synchronize()

        if (epoch + 1) % config.save_interval == 0:
            torch.save(model.state_dict(), 'models/{}_{}.pth'.format(model_name, epoch))
            print('sampling...')
            sample_t = rescaling_inv(sample(model))
            save_image(sample_t, 'images/{}_{}.png'.format(model_name, epoch), nrow=5, padding=0)
return x_out if __name__ == '__main__': # img = torch.zeros(8, 3, 32, 32).float().uniform_(-1, 1).cuda() # # img = torch.zeros(8, 3, 32, 32).float().cuda() # model = PixelCNN(nr_resnet=3, nr_filters=100, input_channels=img.size(1)).cuda() # out = model(img) # # loss = discretized_mix_logistic_loss(img, out) # print('loss : %s' % loss.item()) img = torch.zeros(1, 3, 32, 32).float().cuda() model = PixelCNN(nr_resnet=5, nr_filters=160, input_channels=img.size(1), nr_logistic_mix=10).cuda() load_part_of_model(model, 'models/pcnn_lr_0.00020_nr-resnet5_nr-filters160_58.pth') sample_op = lambda x: sample_from_discretized_mix_logistic(x, 10) from tqdm import tqdm with tqdm(total=32*32) as pbar: for i in range(32): for j in range(32): with torch.no_grad(): data_v = img out = model(data_v, sample=True) out_sample = sample_op(out) img[:, :, i, j] = out_sample.data[:, :, i, j] pbar.update(1) from torchvision.utils import save_image save_image(img.data.cpu(), '1.jpg', nrow=1)