def forward(self, g, n_feat, e_feat):
    """Update node representations through message passing.

    Parameters
    ----------
    g : DGLGraph
        Input DGLGraph for molecule(s)
    n_feat : tensor of dtype float32 and shape (B1, D1)
        Node features. B1 for the number of nodes and D1 for
        the node feature size.
    e_feat : tensor of dtype float32 and shape (B2, D2)
        Edge features. B2 for the number of edges and D2 for
        the edge feature size.

    Returns
    -------
    out : tensor of dtype float32 and shape (B1, H1)
        Updated node representations, where H1 is the hidden size.
    """
    out = F.relu(self.lin0(n_feat))  # (B1, H1)
    h = out.unsqueeze(0)             # (1, B1, H1)
    c = torch.zeros_like(h)

    for i in range(self.num_step_message_passing):
        # aggregate messages from neighbors, conditioned on edge features
        m = F.relu(self.conv(g, out, e_feat))  # (B1, H1)
        # gate the node update with an LSTM or a GRU, treating the
        # messages as a length-1 sequence over all nodes
        if self.lstm_as_gate:
            out, (h, c) = self.lstm(m.unsqueeze(0), (h, c))
        else:
            out, h = self.gru(m.unsqueeze(0), h)
        out = out.squeeze(0)

    return out
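# Hedged usage sketch for the message-passing forward above. `MPNNModel`
# and its constructor arguments are illustrative assumptions, not this
# repo's actual class; only the forward(g, n_feat, e_feat) contract is
# taken from the code.
def _mpnn_forward_smoke_test():
    import dgl
    import torch
    g = dgl.graph(([0, 1, 1], [1, 0, 2]))    # toy molecule: 3 atoms, 3 bonds
    n_feat = torch.randn(g.num_nodes(), 16)  # (B1, D1) node features
    e_feat = torch.randn(g.num_edges(), 4)   # (B2, D2) edge features
    model = MPNNModel(node_input_dim=16, edge_input_dim=4)  # hypothetical ctor
    out = model(g, n_feat, e_feat)           # (B1, H1) node representations
    assert out.shape[0] == g.num_nodes()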
def forward(self, x, heatmap=None):
    """
    x: (batch, c, x_dim, y_dim)
    """
    # tile the cached coordinate channels across the batch
    coords = self.coords.repeat(x.size(0), 1, 1, 1)

    if self.with_boundary and heatmap is not None:
        # the last heatmap channel marks the boundary; keep coordinate
        # values only where the boundary response exceeds a small threshold
        boundary_channel = torch.clamp(heatmap[:, -1:, :, :], 0.0, 1.0)
        zero_tensor = torch.zeros_like(self.x_coords)
        xx_boundary_channel = torch.where(
            boundary_channel > 0.05, self.x_coords,
            zero_tensor).to(zero_tensor.device)
        yy_boundary_channel = torch.where(
            boundary_channel > 0.05, self.y_coords,
            zero_tensor).to(zero_tensor.device)
        coords = torch.cat(
            [coords, xx_boundary_channel, yy_boundary_channel], dim=1)

    # append the coordinate channels to the input feature map
    x_and_coords = torch.cat([x, coords], dim=1)
    return x_and_coords
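# For reference, a minimal sketch of how a cached coordinate buffer like
# `self.coords` can be built, following the CoordConv recipe (an x and a
# y channel normalized to [-1, 1]). The helper name and exact
# normalization are assumptions, not this repo's code.
def make_coords(x_dim, y_dim):
    import torch
    # x channel: every position holds its normalized row index
    xx = torch.arange(x_dim, dtype=torch.float32).view(1, 1, x_dim, 1)
    xx = xx.expand(1, 1, x_dim, y_dim) / max(x_dim - 1, 1) * 2 - 1
    # y channel: every position holds its normalized column index
    yy = torch.arange(y_dim, dtype=torch.float32).view(1, 1, 1, y_dim)
    yy = yy.expand(1, 1, x_dim, y_dim) / max(y_dim - 1, 1) * 2 - 1
    return torch.cat([xx, yy], dim=1)  # (1, 2, x_dim, y_dim)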
def train(self):
    if not self.is_parallel:
        writer = LogWriter(logdir=self.result_dir + "/log/")

    self.genA2B.train()
    self.genB2A.train()
    self.disGA.train()
    self.disGB.train()
    self.disLA.train()
    self.disLB.train()

    start_iter = 1
    if self.resume:
        print(self.result_dir, self.dataset,
              os.path.join(self.result_dir, self.dataset, 'model', '*.pt'))
        model_list = glob(
            os.path.join(self.result_dir, self.dataset, 'model', '*.pt'))
        print("resuming, model_list", model_list)
        if len(model_list) != 0:
            model_list.sort()
            start_iter = int(model_list[-1].split('_')[-1].split('.')[0])
            print("resuming, start_iter", start_iter)
            self.load(
                os.path.join(self.result_dir, self.dataset, 'model'),
                start_iter)
            print(" [*] Load SUCCESS")
            if self.decay_flag and start_iter > (self.iteration // 2):
                decay = (self.lr / (self.iteration // 2)) * (
                    start_iter - self.iteration // 2)
                self.G_optim._learning_rate -= decay
                self.D_optim._learning_rate -= decay

    # training loop
    print('training start!')
    start_time = time.time()
    for step in range(start_iter, self.iteration + 1):
        # linearly decay the learning rate over the second half of training
        if self.decay_flag and step > (self.iteration // 2):
            self.G_optim._learning_rate -= self.lr / (self.iteration // 2)
            self.D_optim._learning_rate -= self.lr / (self.iteration // 2)

        # draw the next batch, restarting the iterator when it is
        # exhausted (NameError covers the first pass, before the
        # iterator exists)
        try:
            real_A, _ = next(trainA_iter)
        except (NameError, StopIteration):
            trainA_iter = iter(self.trainA_loader)
            real_A, _ = next(trainA_iter)
        try:
            real_B, _ = next(trainB_iter)
        except (NameError, StopIteration):
            trainB_iter = iter(self.trainB_loader)
            real_B, _ = next(trainB_iter)

        # Paddle's DataLoader wraps each batch in a list, so unwrap it
        real_A = real_A[0]
        real_B = real_B[0]

        # Update D
        if hasattr(self.D_optim, "_optimizer"):  # support meta optimizer
            self.D_optim._optimizer.clear_gradients()
        else:
            self.D_optim.clear_gradients()

        fake_A2B, _, _ = self.genA2B(real_A)
        fake_B2A, _, _ = self.genB2A(real_B)

        real_GA_logit, real_GA_cam_logit, _ = self.disGA(real_A)
        real_LA_logit, real_LA_cam_logit, _ = self.disLA(real_A)
        real_GB_logit, real_GB_cam_logit, _ = self.disGB(real_B)
        real_LB_logit, real_LB_cam_logit, _ = self.disLB(real_B)

        fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
        fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
        fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
        fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

        # LSGAN discriminator losses: real logits -> 1, fake logits -> 0
        D_ad_loss_GA = self.MSE_loss(
            real_GA_logit,
            torch.ones_like(real_GA_logit).to(self.device)) + self.MSE_loss(
                fake_GA_logit,
                torch.zeros_like(fake_GA_logit).to(self.device))
        D_ad_cam_loss_GA = self.MSE_loss(
            real_GA_cam_logit,
            torch.ones_like(real_GA_cam_logit).to(
                self.device)) + self.MSE_loss(
                    fake_GA_cam_logit,
                    torch.zeros_like(fake_GA_cam_logit).to(self.device))
        D_ad_loss_LA = self.MSE_loss(
            real_LA_logit,
            torch.ones_like(real_LA_logit).to(self.device)) + self.MSE_loss(
                fake_LA_logit,
                torch.zeros_like(fake_LA_logit).to(self.device))
        D_ad_cam_loss_LA = self.MSE_loss(
            real_LA_cam_logit,
            torch.ones_like(real_LA_cam_logit).to(
                self.device)) + self.MSE_loss(
                    fake_LA_cam_logit,
                    torch.zeros_like(fake_LA_cam_logit).to(self.device))
        D_ad_loss_GB = self.MSE_loss(
            real_GB_logit,
            torch.ones_like(real_GB_logit).to(self.device)) + self.MSE_loss(
                fake_GB_logit,
                torch.zeros_like(fake_GB_logit).to(self.device))
        D_ad_cam_loss_GB = self.MSE_loss(
            real_GB_cam_logit,
            torch.ones_like(real_GB_cam_logit).to(
                self.device)) + self.MSE_loss(
                    fake_GB_cam_logit,
                    torch.zeros_like(fake_GB_cam_logit).to(self.device))
        D_ad_loss_LB = self.MSE_loss(
            real_LB_logit,
            torch.ones_like(real_LB_logit).to(self.device)) + self.MSE_loss(
                fake_LB_logit,
                torch.zeros_like(fake_LB_logit).to(self.device))
        D_ad_cam_loss_LB = self.MSE_loss(
            real_LB_cam_logit,
            torch.ones_like(real_LB_cam_logit).to(
                self.device)) + self.MSE_loss(
                    fake_LB_cam_logit,
                    torch.zeros_like(fake_LB_cam_logit).to(self.device))

        D_loss_A = self.adv_weight * (
            D_ad_loss_GA + D_ad_cam_loss_GA + D_ad_loss_LA +
            D_ad_cam_loss_LA) / self.n_gpu
        D_loss_B = self.adv_weight * (
            D_ad_loss_GB + D_ad_cam_loss_GB + D_ad_loss_LB +
            D_ad_cam_loss_LB) / self.n_gpu

        Discriminator_loss = D_loss_A + D_loss_B
        Discriminator_loss.backward()
        if self.is_parallel:
            self.disGA.apply_collective_grads()
            self.disGB.apply_collective_grads()
            self.disLA.apply_collective_grads()
            self.disLB.apply_collective_grads()
            self.genA2B.apply_collective_grads()
            self.genB2A.apply_collective_grads()
        self.D_optim.minimize(Discriminator_loss)

        # Update G
        if hasattr(self.G_optim, "_optimizer"):  # support meta optimizer
            self.G_optim._optimizer.clear_gradients()
        else:
            self.G_optim.clear_gradients()

        fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(real_A)
        fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(real_B)

        fake_A2B2A, _, _ = self.genB2A(fake_A2B)
        fake_B2A2B, _, _ = self.genA2B(fake_B2A)

        fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A)
        fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B)

        fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
        fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
        fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
        fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

        # LSGAN generator losses: fake logits -> 1
        G_ad_loss_GA = self.MSE_loss(
            fake_GA_logit, torch.ones_like(fake_GA_logit).to(self.device))
        G_ad_cam_loss_GA = self.MSE_loss(
            fake_GA_cam_logit,
            torch.ones_like(fake_GA_cam_logit).to(self.device))
        G_ad_loss_LA = self.MSE_loss(
            fake_LA_logit, torch.ones_like(fake_LA_logit).to(self.device))
        G_ad_cam_loss_LA = self.MSE_loss(
            fake_LA_cam_logit,
            torch.ones_like(fake_LA_cam_logit).to(self.device))
        G_ad_loss_GB = self.MSE_loss(
            fake_GB_logit, torch.ones_like(fake_GB_logit).to(self.device))
        G_ad_cam_loss_GB = self.MSE_loss(
            fake_GB_cam_logit,
            torch.ones_like(fake_GB_cam_logit).to(self.device))
        G_ad_loss_LB = self.MSE_loss(
            fake_LB_logit, torch.ones_like(fake_LB_logit).to(self.device))
        G_ad_cam_loss_LB = self.MSE_loss(
            fake_LB_cam_logit,
            torch.ones_like(fake_LB_cam_logit).to(self.device))

        # cycle-consistency and identity losses
        G_recon_loss_A = self.L1_loss(fake_A2B2A, real_A)
        G_recon_loss_B = self.L1_loss(fake_B2A2B, real_B)

        G_identity_loss_A = self.L1_loss(fake_A2A, real_A)
        G_identity_loss_B = self.L1_loss(fake_B2B, real_B)

        # CAM losses: translated samples -> 1, identity-mapped samples -> 0
        G_cam_loss_A = self.BCE_loss(
            fake_B2A_cam_logit,
            torch.ones_like(fake_B2A_cam_logit).to(
                self.device)) + self.BCE_loss(
                    fake_A2A_cam_logit,
                    torch.zeros_like(fake_A2A_cam_logit).to(self.device))
        G_cam_loss_B = self.BCE_loss(
            fake_A2B_cam_logit,
            torch.ones_like(fake_A2B_cam_logit).to(
                self.device)) + self.BCE_loss(
                    fake_B2B_cam_logit,
                    torch.zeros_like(fake_B2B_cam_logit).to(self.device))

        G_loss_A = (self.adv_weight *
                    (G_ad_loss_GA + G_ad_cam_loss_GA + G_ad_loss_LA +
                     G_ad_cam_loss_LA) +
                    self.cycle_weight * G_recon_loss_A +
                    self.identity_weight * G_identity_loss_A +
                    self.cam_weight * G_cam_loss_A) / self.n_gpu
        G_loss_B = (self.adv_weight *
                    (G_ad_loss_GB + G_ad_cam_loss_GB + G_ad_loss_LB +
                     G_ad_cam_loss_LB) +
                    self.cycle_weight * G_recon_loss_B +
                    self.identity_weight * G_identity_loss_B +
                    self.cam_weight * G_cam_loss_B) / self.n_gpu

        Generator_loss = G_loss_A + G_loss_B
        Generator_loss.backward()
        if self.is_parallel:
            self.disGA.apply_collective_grads()
            self.disGB.apply_collective_grads()
            self.disLA.apply_collective_grads()
            self.disLB.apply_collective_grads()
            self.genA2B.apply_collective_grads()
            self.genB2A.apply_collective_grads()
        self.G_optim.minimize(Generator_loss)

        # clip parameter of AdaILN and ILN, applied after optimizer step
        self.Rho_clipper(self.genA2B)
        self.Rho_clipper(self.genB2A)

        if not self.is_parallel:
            writer.add_scalar(tag="G/G_loss_A", step=step,
                              value=G_loss_A.numpy())
            writer.add_scalar(tag="G/G_loss_B", step=step,
                              value=G_loss_B.numpy())
            writer.add_scalar(tag="D/D_loss_A", step=step,
                              value=D_loss_A.numpy())
            writer.add_scalar(tag="D/D_loss_B", step=step,
                              value=D_loss_B.numpy())
            writer.add_scalar(tag="D/Discriminator_loss", step=step,
                              value=Discriminator_loss.numpy())
            writer.add_scalar(tag="G/Generator_loss", step=step,
                              value=Generator_loss.numpy())
            if step % 10 == 9:
                writer.add_image(
                    "fake_A2B",
                    (porch.Tensor(fake_A2B[0] * 255)).clamp_(0, 255)
                    .numpy().transpose([1, 2, 0]).astype(np.uint8), step)
                writer.add_image(
                    "fake_B2A",
                    (porch.Tensor(fake_B2A[0] * 255)).clamp_(0, 255)
                    .numpy().transpose([1, 2, 0]).astype(np.uint8), step)
                writer.add_image(
                    "fake_A2B2A",
                    (porch.Tensor(fake_A2B2A[0] * 255)).clamp_(0, 255)
                    .numpy().transpose([1, 2, 0]).astype(np.uint8), step)
                writer.add_image(
                    "fake_B2A2B",
                    (porch.Tensor(fake_B2A2B[0] * 255)).clamp_(0, 255)
                    .numpy().transpose([1, 2, 0]).astype(np.uint8), step)

        print("[%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" %
              (step, self.iteration, time.time() - start_time,
               Discriminator_loss, Generator_loss))

        if step % self.print_freq == 0:
            train_sample_num = 5
            test_sample_num = 5
            A2B = np.zeros((self.img_size * 7, 0, 3))
            B2A = np.zeros((self.img_size * 7, 0, 3))

            self.genA2B.eval()
            self.genB2A.eval()
            self.disGA.eval()
            self.disGB.eval()
            self.disLA.eval()
            self.disLB.eval()
            for _ in range(train_sample_num):
                try:
                    real_A, _ = next(trainA_iter)
                except (NameError, StopIteration):
                    trainA_iter = iter(self.trainA_loader)
                    real_A, _ = next(trainA_iter)
                try:
                    real_B, _ = next(trainB_iter)
                except (NameError, StopIteration):
                    trainB_iter = iter(self.trainB_loader)
                    real_B, _ = next(trainB_iter)
                real_A, real_B = real_A[0], real_B[0]

                fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
                fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

                fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
                fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

                fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
                fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

                # each column stacks: input, identity CAM, identity output,
                # translation CAM, translation, cycle CAM, cycle output
                A2B = np.concatenate(
                    (A2B,
                     np.concatenate(
                         (RGB2BGR(tensor2numpy(denorm(real_A[0]))),
                          cam(tensor2numpy(fake_A2A_heatmap[0]),
                              self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
                          cam(tensor2numpy(fake_A2B_heatmap[0]),
                              self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
                          cam(tensor2numpy(fake_A2B2A_heatmap[0]),
                              self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))),
                         0)), 1)
                B2A = np.concatenate(
                    (B2A,
                     np.concatenate(
                         (RGB2BGR(tensor2numpy(denorm(real_B[0]))),
                          cam(tensor2numpy(fake_B2B_heatmap[0]),
                              self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
                          cam(tensor2numpy(fake_B2A_heatmap[0]),
                              self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
                          cam(tensor2numpy(fake_B2A2B_heatmap[0]),
                              self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))),
                         0)), 1)

            for _ in range(test_sample_num):
                try:
                    real_A, _ = next(testA_iter)
                except (NameError, StopIteration):
                    testA_iter = iter(self.testA_loader)
                    real_A, _ = next(testA_iter)
                try:
                    real_B, _ = next(testB_iter)
                except (NameError, StopIteration):
                    testB_iter = iter(self.testB_loader)
                    real_B, _ = next(testB_iter)
                real_A, real_B = real_A[0], real_B[0]

                fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
                fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

                fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
                fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

                fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
                fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

                A2B = np.concatenate(
                    (A2B,
                     np.concatenate(
                         (RGB2BGR(tensor2numpy(denorm(real_A[0]))),
                          cam(tensor2numpy(fake_A2A_heatmap[0]),
                              self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
                          cam(tensor2numpy(fake_A2B_heatmap[0]),
                              self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
                          cam(tensor2numpy(fake_A2B2A_heatmap[0]),
                              self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))),
                         0)), 1)
                B2A = np.concatenate(
                    (B2A,
                     np.concatenate(
                         (RGB2BGR(tensor2numpy(denorm(real_B[0]))),
                          cam(tensor2numpy(fake_B2B_heatmap[0]),
                              self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
                          cam(tensor2numpy(fake_B2A_heatmap[0]),
                              self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
                          cam(tensor2numpy(fake_B2A2B_heatmap[0]),
                              self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))),
                         0)), 1)

            if (not self.is_parallel
                ) or fluid.dygraph.parallel.Env().local_rank == 0:
                cv2.imwrite(
                    os.path.join(self.result_dir, self.dataset, 'img',
                                 'A2B_%07d.png' % step), A2B * 255.0)
                cv2.imwrite(
                    os.path.join(self.result_dir, self.dataset, 'img',
                                 'B2A_%07d.png' % step), B2A * 255.0)

            self.genA2B.train()
            self.genB2A.train()
            self.disGA.train()
            self.disGB.train()
            self.disLA.train()
            self.disLB.train()

        if step % self.save_freq == 0:
            if (not self.is_parallel
                ) or fluid.dygraph.parallel.Env().local_rank == 0:
                self.save(
                    os.path.join(self.result_dir, self.dataset, 'model'),
                    step)

        if step % 1000 == 0:
            params = {}
            params['genA2B'] = self.genA2B.state_dict()
            params['genB2A'] = self.genB2A.state_dict()
            params['disGA'] = self.disGA.state_dict()
            params['disGB'] = self.disGB.state_dict()
            params['disLA'] = self.disLA.state_dict()
            params['disLB'] = self.disLB.state_dict()
            if (not self.is_parallel
                ) or fluid.dygraph.parallel.Env().local_rank == 0:
                torch.save(
                    params,
                    os.path.join(self.result_dir,
                                 self.dataset + '_params_latest.pt'))
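# The try/except iterator-restart pattern used throughout train() can be
# read as an "infinite stream" over a finite DataLoader; a sketch of the
# equivalent helper (illustrative, not part of the repo):
def infinite_batches(loader):
    """Yield batches forever, restarting the loader when exhausted."""
    while True:
        for batch in loader:
            yield batch

# usage sketch:
#   trainA_stream = infinite_batches(self.trainA_loader)
#   real_A, _ = next(trainA_stream)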
def truncate(x, thres=0.1):
    """Zero out heatmap values that fall below `thres`."""
    return porch.where(x < thres, porch.zeros_like(x), x)
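# Behavior check for truncate(), written against plain PyTorch on the
# assumption that porch mirrors the torch.where / torch.zeros_like API
# (as the call above suggests): entries below `thres` are zeroed, the
# rest pass through unchanged.
def _truncate_torch_demo():
    import torch
    x = torch.tensor([0.05, 0.1, 0.4])
    out = torch.where(x < 0.1, torch.zeros_like(x), x)
    print(out)  # tensor([0.0000, 0.1000, 0.4000])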