def bp(self, fake_logit: Variable, prob):
    bs = fake_logit.size()[0]
    self.opt.zero_grad()
    # Reward is the (detached) discriminator logit squashed into [-1, 1].
    reward = torch.tanh(fake_logit.detach())
    # loss = -(torch.mean(torch.log(prob) * reward)).backward()
    # Policy-gradient style update: back-propagate log(prob) with gradient
    # -reward / bs instead of building the scalar loss explicitly.
    torch.log(prob).backward(-reward / bs)
    self.opt.step()
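
# A minimal, self-contained sanity check of the update performed in bp() above
# (illustrative only, not part of the original class). Tensor.backward(grad)
# accumulates the vector-Jacobian product, so passing -reward / bs as the
# gradient of log(prob) matches back-propagating the explicit scalar loss
# -(log(prob) * reward).sum() / bs, i.e. the commented-out formulation.
def _check_bp_equivalence():
    import torch
    from torch.autograd import Variable
    bs = 4
    prob = Variable(torch.rand(bs, 1), requires_grad=True)
    reward = torch.tanh(Variable(torch.randn(bs, 1)))
    # Explicit scalar loss.
    (-(torch.log(prob) * reward).sum() / bs).backward()
    explicit_grad = prob.grad.data.clone()
    prob.grad.data.zero_()
    # Formulation used in bp(): supply the gradient of log(prob) directly.
    torch.log(prob).backward(-reward / bs)
    print((prob.grad.data - explicit_grad).abs().max())  # ~0, both forms agree
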
def forward(self, x: List):
    """
    Forward pass for all architectures.
    :param x: Has a different meaning for each mode of training
    :return:
    """
    if self.mode == 1:
        '''
        Variable length training. This mode runs for one step more than the
        length of the program in order to produce the stop symbol. Note that
        there is no padding as is done in traditional RNNs for variable length
        programs. This is mainly for computational efficiency of the forward
        pass, that is, each batch contains only programs of the same length and
        losses from batches of different time-lengths are combined to compute
        the gradient and update the network. This ensures that every update of
        the network has an equal contribution coming from programs of different
        lengths. Training is done using the script train_synthetic.py.
        '''
        data, input_op, program_len = x

        assert data.size()[0] == program_len + 1, "Incorrect stack size!!"
        batch_size = data.size()[1]
        h = Variable(torch.zeros(1, batch_size, self.hd_sz)).cuda()
        x_f = self.encoder.encode(data[-1, :, 0:1, :, :])
        x_f = x_f.view(1, batch_size, self.in_sz)
        if not self.tf:
            outputs = []
            for timestep in range(0, program_len + 1):
                # x_f is always input to the RNN at every time step,
                # along with the previously predicted label.
                input_op_rnn = self.relu(
                    self.dense_input_op(input_op[:, timestep, :]))
                input_op_rnn = input_op_rnn.view(1, batch_size,
                                                 self.input_op_sz)
                # input_op_rnn = torch.zeros((1, batch_size, self.input_op_sz)).cuda()
                input = torch.cat((self.drop(x_f), input_op_rnn), 2)
                h, _ = self.rnn(input, h)
                hd = self.relu(self.dense_fc_1(self.drop(h[0])))
                output = self.dense_output(self.drop(hd))
                outputs.append(output)
            return torch.stack(outputs)
        else:
            # Remove the stop token from the input to the decoder.
            input_op_rnn = self.relu(
                self.dense_input_op(input_op))[:, :-1, :].permute(1, 0, 2)
            # input_op_rnn = torch.zeros((program_len + 1, batch_size, self.input_op_sz)).cuda()
            x_f = x_f.repeat(program_len + 1, 1, 1)
            input = torch.cat((self.drop(x_f), input_op_rnn), 2)
            output, h = self.rnn(input, h)
            output = self.relu(self.dense_fc_1(self.drop(output)))
            output = self.dense_output(self.drop(output))
            return output

    elif self.mode == 2:
        '''Train variable length RL'''
        # program_len in this case is the maximum number of time steps the RNN runs.
        data, input_op, program_len = x
        batch_size = data.size()[1]
        h = Variable(torch.zeros(1, batch_size, self.hd_sz)).cuda()
        x_f = self.encoder.encode(data[-1, :, 0:1, :, :])
        x_f = x_f.view(1, batch_size, self.in_sz)
        outputs = []
        samples = []
        temp_input_op = input_op[:, 0, :]
        for timestep in range(0, program_len):
            # x_f is the input to the RNN at every time step, along with the
            # previously predicted label.
            input_op_rnn = self.relu(self.dense_input_op(temp_input_op))
            input_op_rnn = input_op_rnn.view(1, batch_size, self.input_op_sz)
            input = torch.cat((x_f, input_op_rnn), 2)
            h, _ = self.rnn(input, h)
            hd = self.relu(self.dense_fc_1(self.drop(h[0])))
            dense_output = self.dense_output(self.drop(hd))
            # Log-probabilities used for the loss.
            output = self.logsoftmax(dense_output)
            outputs.append(output)
            output_probs = self.softmax(dense_output)
            # Sample from the output probabilities in an epsilon-greedy way.
            # Epsilon is gradually reduced to 0 following some schedule.
            if np.random.rand() < self.epsilon:
                # During training: sample stochastically.
                sample = torch.multinomial(output_probs, 1)
            else:
                # During testing: take the greedy action.
                sample = torch.max(output_probs, 1)[1].view(batch_size, 1)

            # Stop gradients from flowing backward through the samples.
            sample = sample.detach()
            samples.append(sample)

            # Create the next input to the RNN from the sampled instructions.
            arr = Variable(
                torch.zeros(batch_size, self.num_draws + 1).scatter_(
                    1, sample.data.cpu(), 1.0)).cuda()
            arr = arr.detach()
            temp_input_op = arr
        return [outputs, samples]
    else:
        assert False, "Incorrect mode!!"
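
# Illustrative sketch (not part of the original model) of how the two modes
# above are typically consumed. `net` is an instance of this class, `labels`
# holds ground-truth op indices of shape (batch, program_len + 1), and `reward`
# is a per-program scalar reward; the loss functions here are plain PyTorch
# ones, not necessarily the ones used elsewhere in this repository.
def _forward_usage_sketch(net, data, input_op, labels, reward, program_len):
    import torch
    import torch.nn.functional as F
    # Mode 1 (supervised): stacked logits of shape (T, B, num_ops) are scored
    # against the ground-truth program token at every time step.
    net.mode = 1
    outputs = net([data, input_op, program_len])
    ce_loss = sum(F.cross_entropy(outputs[t], labels[:, t])
                  for t in range(program_len + 1)) / (program_len + 1)
    # Mode 2 (RL): log-probabilities of the sampled ops, weighted by the reward
    # obtained after executing the sampled programs (REINFORCE-style).
    net.mode = 2
    log_probs, samples = net([data, input_op, program_len])
    picked = torch.cat([lp.gather(1, s) for lp, s in zip(log_probs, samples)], 1)
    rl_loss = -(picked.sum(1) * reward).mean()
    return ce_loss, rl_loss
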
def train_simple_trans():
    opt = TrainOptions().parse()
    data_root = 'data/processed'
    train_params = {'lr': 0.01, 'epoch_milestones': (100, 500)}
    # dataset = DynTexNNFTrainDataset(data_root, 'flame')
    dataset = DynTexFigureTrainDataset(data_root, 'flame')
    dataloader = DataLoader(dataset=dataset,
                            batch_size=opt.batchsize,
                            num_workers=opt.num_workers,
                            shuffle=True)
    nnf_conf = 3
    syner = Synthesiser()
    nnfer = NNFPredictor(out_channel=nnf_conf)
    if torch.cuda.is_available():
        syner = syner.cuda()
        nnfer = nnfer.cuda()
    optimizer_nnfer = Adam(nnfer.parameters(), lr=train_params['lr'])
    table = Table()
    for epoch in range(opt.epoch):
        pbar = tqdm(total=len(dataloader), desc='epoch#{}'.format(epoch))
        pbar.set_postfix({'loss': 'N/A'})
        loss_tot = 0.0
        gamma = epoch / opt.epoch
        for i, (source_t, target_t, source_t1,
                target_t1) in enumerate(dataloader):
            if torch.cuda.is_available():
                source_t = Variable(source_t, requires_grad=True).cuda()
                target_t = Variable(target_t, requires_grad=True).cuda()
                source_t1 = Variable(source_t1, requires_grad=True).cuda()
                target_t1 = Variable(target_t1, requires_grad=True).cuda()
            nnf = nnfer(source_t, target_t)
            if nnf_conf == 3:
                nnf = nnf[:, :2, :, :] * nnf[:, 2:, :, :]  # mask via the confidence
            # --- synthesis ---
            target_predict = syner(source_t, nnf)
            target_t1_predict = syner(source_t1, nnf)
            loss_t = tnf.mse_loss(target_predict, target_t)
            loss_t1 = tnf.mse_loss(target_t1_predict, target_t1)
            loss = loss_t + loss_t1
            optimizer_nnfer.zero_grad()
            loss.backward()
            optimizer_nnfer.step()
            loss_tot += float(loss_t1)
            # --- vis ---
            name = os.path.join(data_root, '../result/', str(epoch),
                                '{}.png'.format(str(i)))
            index = str(epoch) + '({})'.format(i)
            if not os.path.exists('/'.join(name.split('/')[:-1])):
                os.makedirs('/'.join(name.split('/')[:-1]))
            cv2.imwrite(
                name.replace('.png', '_s.png'),
                (source_t.detach().cpu().numpy()[0].transpose(1, 2, 0) *
                 255).astype('int'))
            cv2.imwrite(
                name.replace('.png', '_s1.png'),
                (source_t1.detach().cpu().numpy()[0].transpose(1, 2, 0) *
                 255).astype('int'))
            cv2.imwrite(
                name.replace('.png', '_t.png'),
                (target_t.detach().cpu().numpy()[0].transpose(1, 2, 0) *
                 255).astype('int'))
            cv2.imwrite(
                name.replace('.png', '_p.png'),
                (target_predict.detach().cpu().numpy()[0].transpose(1, 2, 0) *
                 255).astype('int'))
            cv2.imwrite(
                name.replace('.png', '_t1.png'),
                (target_t1.detach().cpu().numpy()[0].transpose(1, 2, 0) *
                 255).astype('int'))
            cv2.imwrite(
                name.replace('.png', '_p1.png'),
                (target_t1_predict.detach().cpu().numpy()[0].transpose(
                    1, 2, 0) * 255).astype('int'))
            # vis in table
            table.add(
                index,
                os.path.abspath(name.replace('.png', '_s.png')).replace(
                    '/mnt/cephfs_hl/lab_ad_idea/maoyiming', ''))
            table.add(
                index,
                os.path.abspath(name.replace('.png', '_s1.png')).replace(
                    '/mnt/cephfs_hl/lab_ad_idea/maoyiming', ''))
            table.add(
                index,
                os.path.abspath(name.replace('.png', '_t.png')).replace(
                    '/mnt/cephfs_hl/lab_ad_idea/maoyiming', ''))
            table.add(
                index,
                os.path.abspath(name.replace('.png', '_t1.png')).replace(
                    '/mnt/cephfs_hl/lab_ad_idea/maoyiming', ''))
            table.add(
                index,
                os.path.abspath(name.replace('.png', '_p.png')).replace(
                    '/mnt/cephfs_hl/lab_ad_idea/maoyiming', ''))
            table.add(
                index,
                os.path.abspath(name.replace('.png', '_p1.png')).replace(
                    '/mnt/cephfs_hl/lab_ad_idea/maoyiming', ''))
            pbar.set_postfix({'loss': str(loss_tot / (i + 1))})
            pbar.update(1)
        table.build_html('data/')
        pbar.close()
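
# The objective of train_simple_trans(), written out as a small standalone
# helper for clarity (illustrative only, not part of the original script): a
# single NNF is estimated from the frame pair at time t and is required both to
# reconstruct target_t from source_t and to still map source_t1 onto target_t1
# at time t+1, which encourages the predicted correspondence to transfer across
# consecutive frames.
def _simple_trans_loss(syner, nnf, source_t, target_t, source_t1, target_t1):
    import torch.nn.functional as tnf
    loss_t = tnf.mse_loss(syner(source_t, nnf), target_t)
    loss_t1 = tnf.mse_loss(syner(source_t1, nnf), target_t1)
    return loss_t + loss_t1
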
def train_complex_trans():
    opt = TrainOptions().parse()
    data_root = 'data/processed'
    train_params = {'lr': 0.001, 'epoch_milestones': (100, 500)}
    dataset = DynTexFigureTransTrainDataset(data_root, 'flame')
    dataloader = DataLoader(dataset=dataset,
                            batch_size=opt.batchsize,
                            num_workers=opt.num_workers,
                            shuffle=True)
    nnf_conf = 3
    syner = Synthesiser()
    nnfer = NNFPredictor(out_channel=nnf_conf)
    flownet = NNFPredictor(out_channel=nnf_conf)
    if torch.cuda.is_available():
        syner = syner.cuda()
        nnfer = nnfer.cuda()
        flownet = flownet.cuda()
    optimizer_nnfer = Adam(nnfer.parameters(), lr=train_params['lr'])
    optimizer_flow = Adam(flownet.parameters(), lr=train_params['lr'] * 0.1)
    scheduler_nnfer = lr_scheduler.MultiStepLR(
        optimizer_nnfer,
        gamma=0.1,
        last_epoch=-1,
        milestones=train_params['epoch_milestones'])
    scheduler_flow = lr_scheduler.MultiStepLR(
        optimizer_flow,
        gamma=0.1,
        last_epoch=-1,
        milestones=train_params['epoch_milestones'])
    table = Table()
    writer = SummaryWriter(log_dir=opt.log_dir)
    for epoch in range(opt.epoch):
        scheduler_flow.step()
        scheduler_nnfer.step()
        pbar = tqdm(total=len(dataloader), desc='epoch#{}'.format(epoch))
        pbar.set_postfix({'loss': 'N/A'})
        loss_tot = 0.0
        for i, (source_t, target_t, source_t1,
                target_t1) in enumerate(dataloader):
            if torch.cuda.is_available():
                source_t = Variable(source_t, requires_grad=True).cuda()
                target_t = Variable(target_t, requires_grad=True).cuda()
                source_t1 = Variable(source_t1, requires_grad=True).cuda()
                target_t1 = Variable(target_t1, requires_grad=True).cuda()
            nnf = nnfer(source_t, target_t)
            flow = flownet(source_t, source_t1)
            # Mask both fields with their confidence channel.
            if nnf_conf == 3:
                nnf = nnf[:, :2, :, :] * nnf[:, 2:, :, :]  # mask via the confidence
                flow = flow[:, :2, :, :] * flow[:, 2:, :, :]
            # --- synthesis ---
            source_t1_predict = syner(source_t, flow)  # flow penalty
            target_flow = syner(flow, nnf)  # predict the target-side flow by warping the flow with the NNF
            target_t1_predict = syner(target_t, target_flow)
            # target_t1_predict = syner(source_t1, nnf)
            loss_t1_f = tnf.mse_loss(source_t1, source_t1_predict)  # flow penalty
            loss_t1 = tnf.mse_loss(target_t1_predict, target_t1)  # total penalty
            loss = loss_t1_f + loss_t1 * 2
            optimizer_flow.zero_grad()
            optimizer_nnfer.zero_grad()
            loss.backward()
            optimizer_nnfer.step()
            optimizer_flow.step()
            loss_tot += float(loss)
            # --- vis ---
            if epoch % 10 == 0 and i % 2 == 0:
                name = os.path.join(data_root, '../result/', str(epoch),
                                    '{}.png'.format(str(i)))
                index = str(epoch) + '({})'.format(i)
                if not os.path.exists('/'.join(name.split('/')[:-1])):
                    os.makedirs('/'.join(name.split('/')[:-1]))
                cv2.imwrite(
                    name.replace('.png', '_s.png'),
                    (source_t.detach().cpu().numpy()[0].transpose(1, 2, 0) *
                     255).astype('int'))
                cv2.imwrite(
                    name.replace('.png', '_s1.png'),
                    (source_t1.detach().cpu().numpy()[0].transpose(1, 2, 0) *
                     255).astype('int'))
                cv2.imwrite(
                    name.replace('.png', '_t.png'),
                    (target_t.detach().cpu().numpy()[0].transpose(1, 2, 0) *
                     255).astype('int'))
                cv2.imwrite(
                    name.replace('.png', '_p.png'),
                    (source_t1_predict.detach().cpu().numpy()[0].transpose(
                        1, 2, 0) * 255).astype('int'))
                cv2.imwrite(
                    name.replace('.png', '_t1.png'),
                    (target_t1.detach().cpu().numpy()[0].transpose(1, 2, 0) *
                     255).astype('int'))
                cv2.imwrite(
                    name.replace('.png', '_p1.png'),
                    (target_t1_predict.detach().cpu().numpy()[0].transpose(
                        1, 2, 0) * 255).astype('int'))
                # vis in table
                table.add(
                    index,
                    os.path.abspath(name.replace('.png', '_s.png')).replace(
                        '/mnt/cephfs_hl/lab_ad_idea/maoyiming', ''))
                table.add(
                    index,
                    os.path.abspath(name.replace('.png', '_s1.png')).replace(
                        '/mnt/cephfs_hl/lab_ad_idea/maoyiming', ''))
                table.add(
                    index,
                    os.path.abspath(name.replace('.png', '_t.png')).replace(
                        '/mnt/cephfs_hl/lab_ad_idea/maoyiming', ''))
                table.add(
                    index,
                    os.path.abspath(name.replace('.png', '_t1.png')).replace(
                        '/mnt/cephfs_hl/lab_ad_idea/maoyiming', ''))
                table.add(
                    index,
                    os.path.abspath(name.replace('.png', '_p.png')).replace(
                        '/mnt/cephfs_hl/lab_ad_idea/maoyiming', ''))
                table.add(
                    index,
                    os.path.abspath(name.replace('.png', '_p1.png')).replace(
                        '/mnt/cephfs_hl/lab_ad_idea/maoyiming', ''))
            pbar.set_postfix({'loss': str(loss_tot / (i + 1))})
            writer.add_scalar('scalars/{}/loss_train'.format(opt.time),
                              float(loss), i + int(epoch * len(dataloader)))
            writer.add_scalar('scalars/{}/lr'.format(opt.time),
                              float(scheduler_nnfer.get_lr()[0]),
                              i + int(epoch * len(dataloader)))
            pbar.update(1)
        table.build_html('data/')
        pbar.close()
    writer.close()
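
# Illustrative entry point for the two trainers above; the original scripts may
# be launched differently (e.g. through a separate driver), so this guard is an
# assumption for local experimentation only.
if __name__ == '__main__':
    # train_simple_trans()   # shared-NNF baseline on consecutive frame pairs
    train_complex_trans()    # variant that additionally predicts and warps flow
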
for epoch in range(opt.n_epochs):
    print('Epoch {}'.format(epoch))
    for i, (batch_A, batch_B) in enumerate(dataloader_train):
        real = Variable(Tensor(batch_A.size(0), 1).fill_(1), requires_grad=False)
        fake = Variable(Tensor(batch_A.size(0), 1).fill_(0), requires_grad=False)
        imgs_real_A = Variable(batch_A.type(Tensor))
        imgs_real_B = Variable(batch_B.type(Tensor))

        # == Discriminator update == #
        optimizer_D.zero_grad()
        imgs_fake = Variable(generator(imgs_real_A.detach()))
        d_loss = gan_loss(discriminator(imgs_real_A, imgs_real_B), real) + gan_loss(
            discriminator(imgs_real_A, imgs_fake), fake)
        d_loss.backward()
        optimizer_D.step()

        # == Generator update == #
        imgs_fake = generator(imgs_real_A)
        optimizer_G.zero_grad()
        g_loss = gan_loss(
            discriminator(imgs_real_A, imgs_fake),
for epoch in range(opt.n_epochs):
    print('Epoch {}'.format(epoch))
    for i, (batch_lr, batch_hr) in enumerate(dataloader_train):
        real = Variable(Tensor(batch_lr.size(0), 1).fill_(1), requires_grad=False)
        fake = Variable(Tensor(batch_lr.size(0), 1).fill_(0), requires_grad=False)
        imgs_real_lr = Variable(batch_lr.type(Tensor))
        imgs_real_hr = Variable(batch_hr.type(Tensor))

        # == Discriminator update == #
        optimizer_D.zero_grad()
        imgs_fake_hr = Variable(generator(imgs_real_lr.detach()))
        d_loss = gan_loss(discriminator(imgs_real_hr), real) + gan_loss(
            discriminator(imgs_fake_hr), fake)
        d_loss.backward()
        optimizer_D.step()

        # == Generator update == #
        imgs_fake_hr = generator(imgs_real_lr)
        optimizer_G.zero_grad()
        # Perceptual (VGG feature) content loss rescaled by 1/12.75 plus the
        # adversarial term weighted by 1e-3, following the SRGAN formulation.
        g_loss = (1 / 12.75) * content_loss(
            vgg(imgs_fake_hr), vgg(imgs_real_hr.detach())) + 1e-3 * gan_loss(
                discriminator(imgs_fake_hr), real)