def compute_mask_loss(self, occ_mask, warped_image, tgt_image):
    """ Compute losses on the generated occlusion mask.

    Args:
        occ_mask (tensor): Generated occlusion masks.
        warped_image (tensor): Warped image using the flow map.
        tgt_image (tensor): Target image for the warped image.
    Returns:
        (tensor): Loss for the mask.
    """
    loss_mask = dg.to_variable(np.zeros((1, )).astype("float32"))
    if occ_mask is not None:
        dummy0 = L.zeros_like(occ_mask)
        dummy1 = L.ones_like(occ_mask)

        # Compute the confidence map based on the L1 distance between the
        # warped and the ground-truth image.
        img_diff = L.reduce_sum(L.abs(warped_image - tgt_image), 1,
                                keep_dim=True)
        conf = L.clip(1 - img_diff, 0, 1)

        # Force the mask value to be small if the warped image is similar
        # to the ground truth, and vice versa.
        loss_mask = self.criterionMasked(occ_mask, dummy0, conf)
        loss_mask += self.criterionMasked(occ_mask, dummy1, 1 - conf)
    return loss_mask
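# The confidence map above drives the mask supervision: where the warped
# image already matches the target, `conf` is near 1 and the mask is pushed
# toward 0 (trust the warp); where it differs, `1 - conf` dominates and the
# mask is pushed toward 1 (hallucinate instead). A minimal NumPy sketch of
# that weighting, independent of the Paddle code above (values illustrative):
import numpy as np

warped = np.array([[[[0.9, 0.1], [0.5, 0.5]]]])  # hypothetical 1x1x2x2 image
target = np.array([[[[0.9, 0.9], [0.5, 0.0]]]])

img_diff = np.abs(warped - target).sum(axis=1, keepdims=True)
conf = np.clip(1.0 - img_diff, 0.0, 1.0)
print(conf)      # close to 1 where the warp is trusted, lower where it is not
print(1 - conf)  # weight pushing the occlusion mask toward 1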
def greedy_search_infilling(model,
                            q_ids,
                            q_sids,
                            sos_id,
                            eos_id,
                            attn_id,
                            max_encode_len=640,
                            max_decode_len=100,
                            tgt_type_id=3):
    model.eval()
    _, logits, info = model(q_ids, q_sids)
    gen_ids = L.argmax(logits, -1)
    d_batch, d_seqlen = q_ids.shape
    seqlen = L.reduce_sum(L.cast(q_ids != 0, 'int64'), 1, keep_dim=True)
    has_stopped = np.zeros([d_batch], dtype=bool)  # np.bool is deprecated
    gen_seq_len = np.zeros([d_batch], dtype=np.int64)
    output_ids = []

    past_cache = info['caches']

    cls_ids = L.ones([d_batch], dtype='int64') * sos_id
    attn_ids = L.ones([d_batch], dtype='int64') * attn_id
    ids = L.stack([cls_ids, attn_ids], -1)
    for step in range(max_decode_len):
        bias = gen_bias(q_ids, ids, step)
        pos_ids = D.to_variable(
            np.tile(np.array([[step, step + 1]], dtype=np.int64),
                    [d_batch, 1]))
        pos_ids += seqlen
        _, logits, info = model(ids,
                                L.ones_like(ids) * tgt_type_id,
                                pos_ids=pos_ids,
                                attn_bias=bias,
                                past_cache=past_cache)
        gen_ids = L.argmax(logits, -1)

        past_cached_k, past_cached_v = past_cache
        cached_k, cached_v = info['caches']
        cached_k = [
            L.concat([pk, k[:, :1, :]], 1)
            for pk, k in zip(past_cached_k, cached_k)
        ]  # concat cached
        cached_v = [
            L.concat([pv, v[:, :1, :]], 1)
            for pv, v in zip(past_cached_v, cached_v)
        ]
        past_cache = (cached_k, cached_v)

        gen_ids = gen_ids[:, 1]
        ids = L.stack([gen_ids, attn_ids], 1)

        gen_ids = gen_ids.numpy()
        has_stopped |= (gen_ids == eos_id).astype(bool)
        gen_seq_len += (1 - has_stopped.astype(np.int64))
        output_ids.append(gen_ids.tolist())
        if has_stopped.all():
            break
    output_ids = np.array(output_ids).transpose([1, 0])
    return output_ids
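# The stopping bookkeeping above is plain NumPy: a sticky per-sequence
# boolean, plus a length counter that freezes once a sequence has emitted
# eos_id. A runnable sketch with fake decoder outputs (ids are arbitrary):
import numpy as np

has_stopped = np.zeros([3], dtype=bool)
gen_seq_len = np.zeros([3], dtype=np.int64)
eos_id = 2
for step_ids in ([5, 2, 7], [2, 9, 7], [1, 1, 2]):
    step_ids = np.array(step_ids)
    has_stopped |= (step_ids == eos_id)          # once stopped, stays stopped
    gen_seq_len += 1 - has_stopped.astype(np.int64)
print(has_stopped, gen_seq_len)  # [ True  True  True] [1 0 2]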
def compute_mask_losses(self, occ_mask, fake_image, warped_image, tgt_label,
                        tgt_image, fg_mask, ref_fg_mask, body_mask_diff):
    """ Compute losses on the generated occlusion masks.

    Args:
        occ_mask (tensor or list of tensors): Generated occlusion masks.
        fake_image (tensor): Generated image.
        warped_image (tensor or list of tensors): Warped images using the
            flow maps.
        tgt_label (tensor): Target label map.
        tgt_image (tensor): Target image for the warped image.
        fg_mask (tensor): Foreground mask for the target image.
        ref_fg_mask (tensor): Foreground mask for the reference image.
        body_mask_diff (tensor): Difference between warped body part map and
            target body part map. Used for pose dataset only.
    Returns:
        (tensor): Loss for the masks.
    """
    loss_mask = dg.to_variable(np.zeros((1, )).astype("float32"))
    if isinstance(occ_mask, list):
        # Compute occlusion mask losses for both warping
        # reference -> target and previous -> target.
        for i in range(len(occ_mask)):
            loss_mask += self.compute_mask_loss(occ_mask[i],
                                                warped_image[i], tgt_image)
    else:
        # Compute loss for warping either reference or previous images.
        loss_mask += self.compute_mask_loss(occ_mask, warped_image,
                                            tgt_image)

    if self.warp_ref:
        ref_occ_mask = occ_mask[0]
        dummy0 = L.zeros_like(ref_occ_mask)
        dummy1 = L.ones_like(ref_occ_mask)

        if self.for_pose_dataset:
            # Enforce output to use more of the warped reference image for
            # the face region.
            face_mask = L.unsqueeze(get_face_mask(tgt_label[:, 2]), [1])
            face_mask = L.pool2d(face_mask,
                                 pool_size=15,
                                 pool_type='avg',
                                 pool_stride=1,
                                 pool_padding=7)
            loss_mask += self.criterionMasked(ref_occ_mask, dummy0,
                                              face_mask)
            loss_mask += self.criterionMasked(fake_image, warped_image[0],
                                              face_mask)

            # Enforce output to use more of the hallucinated image for
            # discrepancy regions of body part masks between the warped
            # reference and the target image.
            loss_mask += self.criterionMasked(ref_occ_mask, dummy1,
                                              body_mask_diff)

        if self.has_fg:
            # Enforce output to use more of the hallucinated image for
            # discrepancy regions of foreground masks between the reference
            # and the target image.
            fg_mask_diff = ((ref_fg_mask - fg_mask) > 0).astype("float32")
            loss_mask += self.criterionMasked(ref_occ_mask, dummy1,
                                              fg_mask_diff)
    return loss_mask
def gen_bias(encoder_inputs, decoder_inputs, step):
    decoder_bsz, decoder_seqlen = decoder_inputs.shape[:2]
    attn_bias = L.reshape(
        L.range(0, decoder_seqlen, 1, dtype='float32') + 1, [1, -1, 1])
    decoder_bias = L.cast(
        (L.matmul(attn_bias, 1. / attn_bias, transpose_y=True) >= 1.),
        'float32')  # [1, decoder_seqlen, decoder_seqlen]
    encoder_bias = L.unsqueeze(
        L.cast(L.ones_like(encoder_inputs), 'float32'),
        [1])  # [bsz, 1, encoder_seqlen]
    encoder_bias = L.expand(
        encoder_bias,
        [1, decoder_seqlen, 1])  # [bsz, decoder_seqlen, encoder_seqlen]
    decoder_bias = L.expand(
        decoder_bias,
        [decoder_bsz, 1, 1])  # [bsz, decoder_seqlen, decoder_seqlen]
    if step > 0:
        bias = L.concat([
            encoder_bias,
            L.ones([decoder_bsz, decoder_seqlen, step], 'float32'),
            decoder_bias
        ], -1)
    else:
        bias = L.concat([encoder_bias, decoder_bias], -1)
    return bias
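# The causal part of the bias comes from an outer-product trick: with
# a = [1, 2, ..., n], a_i / a_j >= 1 exactly when i >= j, so
# matmul(a, 1/a^T) >= 1 yields a lower-triangular matrix of ones.
# A runnable NumPy sketch of the same construction:
import numpy as np

n = 4
a = np.arange(1, n + 1, dtype=np.float32).reshape(n, 1)
causal = (a @ (1.0 / a).T >= 1.0).astype(np.float32)
print(causal)
# [[1. 0. 0. 0.]
#  [1. 1. 0. 0.]
#  [1. 1. 1. 0.]
#  [1. 1. 1. 1.]]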
def sag_pool(gw, feature, ratio, graph_id, dataset, name, activation=L.tanh):
    """Implementation of self-attention graph pooling (SAGPool).

    This is an implementation of the paper SELF-ATTENTION GRAPH POOLING
    (https://arxiv.org/pdf/1904.08082.pdf).

    Args:
        gw: Graph wrapper object.
        feature: A tensor with shape (num_nodes, feature_size).
        ratio: The pooling ratio of nodes we want to select.
        graph_id: The graphs that the nodes belong to.
        dataset: To differentiate the FRANKENSTEIN dataset from other
            datasets.
        name: The name of the SAGPool layer.
        activation: The activation function.

    Returns:
        new_feature: A tensor with shape (num_nodes, feature_size); the
            unselected nodes' features are masked with zeros.
        ratio_length: The number of selected nodes of each graph.
    """
    if dataset == "FRANKENSTEIN":
        gcn_ = gcn
    else:
        gcn_ = norm_gcn

    score = gcn_(gw=gw,
                 feature=feature,
                 hidden_size=1,
                 activation=None,
                 norm=gw.node_feat["norm"],
                 name=name)
    score = L.squeeze(score, axes=[])
    perm, ratio_length = topk_pool(gw, score, graph_id, ratio)

    # Build a 0/1 node mask by scattering ones into the selected positions.
    mask = L.zeros_like(score)
    mask = L.cast(mask, dtype="float32")
    updates = L.ones_like(perm)
    updates = L.cast(updates, dtype="float32")
    mask = L.scatter(mask, perm, updates)
    new_feature = L.elementwise_mul(feature, mask, axis=0)
    temp_score = activation(score)
    new_feature = L.elementwise_mul(new_feature, temp_score, axis=0)
    return new_feature, ratio_length
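# The pooling keeps the tensor shape fixed: instead of gathering the top-k
# rows, it scatters ones into a zero mask at the selected indices and
# multiplies, so unselected nodes are zeroed in place. A runnable NumPy
# sketch of that masking step (shapes and indices are illustrative only):
import numpy as np

feature = np.arange(12, dtype=np.float32).reshape(4, 3)   # 4 nodes, dim 3
score = np.array([0.9, -0.2, 0.5, 0.1], dtype=np.float32)
perm = np.array([0, 2])  # indices picked by a hypothetical top-k pool

mask = np.zeros_like(score)
mask[perm] = 1.0  # scatter ones at the selected nodes

new_feature = feature * mask[:, None] * np.tanh(score)[:, None]
print(new_feature)  # rows 1 and 3 are zeroed; rows 0 and 2 are scaled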
def forward(self, x, y):
    """Forward network"""
    if self.bias_x:
        x = layers.concat((x, layers.ones_like(x[:, :, :1])), axis=-1)
    if self.bias_y:
        # Take the bias column from y itself (the original used x, which
        # only works when x and y happen to share the same shape).
        y = layers.concat((y, layers.ones_like(y[:, :, :1])), axis=-1)
    # x.shape = (b, m, i)
    b = x.shape[0]
    # self.weight.shape = (o, i, j)
    o = self.weight.shape[0]
    x = layers.expand(layers.unsqueeze(x, axes=[1]),
                      expand_times=(1, o, 1, 1))
    weight = layers.expand(layers.unsqueeze(self.weight, axes=[0]),
                           expand_times=(b, 1, 1, 1))
    y = layers.expand(layers.unsqueeze(y, axes=[1]),
                      expand_times=(1, o, 1, 1))
    # s.shape = (b, o, m, n), that is, [batch_size, n_out, seq_len, seq_len]
    s = layers.matmul(layers.matmul(x, weight),
                      layers.transpose(y, perm=(0, 1, 3, 2)))
    # Remove dim 1 if n_out == 1.
    if s.shape[1] == 1:
        s = layers.squeeze(s, axes=[1])
    return s
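# This is the standard biaffine score s[b, o, m, n] = x_m^T W_o y_n, with
# optional bias columns appended so W can also capture linear and constant
# terms. A runnable NumPy sketch under assumed shapes (batch 2, seq_len 3,
# dim 4, n_out 1); all names below are illustrative only:
import numpy as np

b, m, i, o = 2, 3, 4, 1
x = np.random.rand(b, m, i).astype(np.float32)
y = np.random.rand(b, m, i).astype(np.float32)
W = np.random.rand(o, i + 1, i + 1).astype(np.float32)  # +1 for bias columns

ones = np.ones((b, m, 1), dtype=np.float32)
xb = np.concatenate([x, ones], axis=-1)  # bias_x
yb = np.concatenate([y, ones], axis=-1)  # bias_y

# einsum form of the double matmul above: s[b, o, m, n] = xb_m W_o yb_n^T
s = np.einsum('bmi,oij,bnj->bomn', xb, W, yb)
print(s.shape)  # (2, 1, 3, 3)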
def neighbor_aggregator(self, sent_repr):
    # norm = L.clamp(L.reshape(L.cast(self.graph_wrapper.indegree(),
    #                                 dtype="float32"), [-1, 1]), min=1.)
    norm = L.ones_like(sent_repr)

    def send_func(src, dst, edge):
        return src["h"]

    msg = self.graph_wrapper.send(send_func, nfeat_list=[("h", norm)])
    norm = self.graph_wrapper.recv(msg, "sum")
    norm = L.reduce_mean(norm, -1, keep_dim=True)
    norm = L.clamp(norm, min=1.0)
    return gcn(self.graph_wrapper,
               sent_repr,
               self.hidden_size,
               activation="relu",
               name="gcn") / norm
def forward(self, *items):
    """Forward network"""
    if self.training and self.p > 0:
        masks = [
            layers.uniform_random(shape=x.shape[:2], min=0, max=1) >= self.p
            for x in items
        ]
        masks = [layers.cast(x, 'float32') for x in masks]
        total = layers.elementwise_add(*masks)
        scale = len(items) / layers.elementwise_max(
            total, layers.ones_like(total))
        masks = [mask * scale for mask in masks]
        items = [
            item * layers.unsqueeze(mask, axes=[-1])
            for item, mask in zip(items, masks)
        ]
    return items
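# This is "independent dropout" as used in biaffine parsers: each input
# stream is dropped per token independently, and the surviving streams are
# rescaled so the expected sum stays constant (if one of two streams is
# dropped, the other is doubled). A runnable NumPy sketch with two assumed
# streams:
import numpy as np

p = 0.5
word = np.ones((1, 5, 4))   # (batch, seq, dim)
feat = np.ones((1, 5, 4))

m_word = (np.random.rand(1, 5) >= p).astype(np.float32)
m_feat = (np.random.rand(1, 5) >= p).astype(np.float32)

total = m_word + m_feat
scale = 2.0 / np.maximum(total, 1.0)  # preserve E[word + feat]

word = word * (m_word * scale)[..., None]
feat = feat * (m_feat * scale)[..., None]
print(word[0, :, 0], feat[0, :, 0])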
def ernie_send(src_feat, dst_feat, edge_feat):
    """doc"""
    cls = L.fill_constant_batch_size_like(src_feat["term_ids"], [-1, 1, 1],
                                          "int64", 1)
    src_ids = L.concat([cls, src_feat["term_ids"]], 1)
    dst_ids = dst_feat["term_ids"]

    sent_ids = L.concat([L.zeros_like(src_ids), L.ones_like(dst_ids)], 1)
    term_ids = L.concat([src_ids, dst_ids], 1)

    term_ids.stop_gradient = True
    sent_ids.stop_gradient = True
    ernie = ErnieModel(term_ids, sent_ids, config=self.config.ernie_config)
    feature = ernie.get_pooled_output()
    return feature
def ernie_send(src_feat, dst_feat, edge_feat):
    """doc"""

    def build_position_ids(term_ids):
        input_mask = L.cast(term_ids > 0, "int64")
        position_ids = L.cumsum(input_mask, axis=1) - 1
        return position_ids

    # input_ids
    cls = L.fill_constant_batch_size_like(src_feat["term_ids"], [-1, 1],
                                          "int64", self.config.cls_id)
    src_ids = L.concat([cls, src_feat["term_ids"]], 1)
    dst_ids = dst_feat["term_ids"]

    # sent_ids
    sent_ids = L.concat([L.zeros_like(src_ids), L.ones_like(dst_ids)], 1)
    term_ids = L.concat([src_ids, dst_ids], 1)

    # position_ids
    position_ids = build_position_ids(term_ids)
    ernie_model = ErnieModel(self.config.ernie_config, "")
    feature, _ = ernie_model(term_ids, sent_ids, position_ids)
    return feature
def train(self):
    place = fluid.CUDAPlace(0) if self.use_gpu else fluid.CPUPlace()
    with fluid.dygraph.guard(place):
        self.genA2B.train()
        self.genB2A.train()
        self.disGA.train()
        self.disGB.train()
        self.disLA.train()
        self.disLB.train()

        if self.resume:
            files_list = os.listdir(self.model_path)
            if len(files_list) > 0:
                files = []
                print("exist files")
                for i in files_list:
                    file_ = os.path.splitext(i)[1]
                    files.append(file_)
                # Check the collected extensions; the original tested
                # files_list (full file names), which could never match.
                if ".pdparams" in files or ".pdopt" in files:
                    print("exist model")
                    # fluid.load_dygraph returns (param_dict, opt_dict),
                    # so the tuples must be unpacked.
                    genA2B_para, _ = fluid.load_dygraph(self.model_path +
                                                        'g_A2B')
                    genB2A_para, _ = fluid.load_dygraph(self.model_path +
                                                        'g_B2A')
                    disGA_para, _ = fluid.load_dygraph(self.model_path +
                                                       'd_GA')
                    disGB_para, _ = fluid.load_dygraph(self.model_path +
                                                       'd_GB')
                    disLA_para, _ = fluid.load_dygraph(self.model_path +
                                                       'd_LA')
                    disLB_para, _ = fluid.load_dygraph(self.model_path +
                                                       'd_LB')
                    _, G_opt = fluid.load_dygraph(self.model_path + 'G_op')
                    _, D_opt = fluid.load_dygraph(self.model_path + 'D_op')

                    self.genA2B.load_dict(genA2B_para)
                    self.genB2A.load_dict(genB2A_para)
                    self.disGA.load_dict(disGA_para)
                    self.disGB.load_dict(disGB_para)
                    self.disLA.load_dict(disLA_para)
                    self.disLB.load_dict(disLB_para)
                    self.G_optim.set_dict(G_opt)
                    self.D_optim.set_dict(D_opt)
                    print(" [*] Load SUCCESS")
                else:
                    print(" No Model!")
            else:
                print("No Files")

        # training loop
        print('training start !')
        start_iter = 1
        for step in range(start_iter, self.iteration + 1):
            trainA_iter = iter(self.trainA_loader())
            real_A = next(trainA_iter)
            real_A = paddle.fluid.dygraph.to_variable(np.array(real_A))
            real_A = real_A / 255.0

            trainB_iter = iter(self.trainB_loader())
            real_B = next(trainB_iter)
            real_B = paddle.fluid.dygraph.to_variable(np.array(real_B))
            real_B = real_B / 255.0

            # Update D
            self.D_optim.clear_gradients()

            fake_A2B, _, _ = self.genA2B(real_A)
            fake_B2A, _, _ = self.genB2A(real_B)

            real_GA_logit, real_GA_cam_logit, _ = self.disGA(real_A)
            real_LA_logit, real_LA_cam_logit, _ = self.disLA(real_A)
            real_GB_logit, real_GB_cam_logit, _ = self.disGB(real_B)
            real_LB_logit, real_LB_cam_logit, _ = self.disLB(real_B)

            fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
            fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
            fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
            fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

            # Real logits are pushed toward 1, fake logits toward 0.
            D_ad_loss_GA = self.MSE_loss(
                real_GA_logit,
                fluid.dygraph.to_variable(ones_like(real_GA_logit))
            ) + self.MSE_loss(
                fake_GA_logit,
                fluid.dygraph.to_variable(zeros_like(fake_GA_logit)))
            D_ad_cam_loss_GA = self.MSE_loss(
                real_GA_cam_logit,
                fluid.dygraph.to_variable(ones_like(real_GA_cam_logit))
            ) + self.MSE_loss(
                fake_GA_cam_logit,
                fluid.dygraph.to_variable(zeros_like(fake_GA_cam_logit)))
            D_ad_loss_LA = self.MSE_loss(
                real_LA_logit,
                fluid.dygraph.to_variable(ones_like(real_LA_logit))
            ) + self.MSE_loss(
                fake_LA_logit,
                fluid.dygraph.to_variable(zeros_like(fake_LA_logit)))
            D_ad_cam_loss_LA = self.MSE_loss(
                real_LA_cam_logit,
                fluid.dygraph.to_variable(ones_like(real_LA_cam_logit))
            ) + self.MSE_loss(
                fake_LA_cam_logit,
                fluid.dygraph.to_variable(zeros_like(fake_LA_cam_logit)))
            D_ad_loss_GB = self.MSE_loss(
                real_GB_logit,
                fluid.dygraph.to_variable(ones_like(real_GB_logit))
            ) + self.MSE_loss(
                fake_GB_logit,
                fluid.dygraph.to_variable(zeros_like(fake_GB_logit)))
            D_ad_cam_loss_GB = self.MSE_loss(
                real_GB_cam_logit,
                fluid.dygraph.to_variable(ones_like(real_GB_cam_logit))
            ) + self.MSE_loss(
                fake_GB_cam_logit,
                fluid.dygraph.to_variable(zeros_like(fake_GB_cam_logit)))
            D_ad_loss_LB = self.MSE_loss(
                real_LB_logit,
                fluid.dygraph.to_variable(ones_like(real_LB_logit))
            ) + self.MSE_loss(
                fake_LB_logit,
                fluid.dygraph.to_variable(zeros_like(fake_LB_logit)))
            D_ad_cam_loss_LB = self.MSE_loss(
                real_LB_cam_logit,
                fluid.dygraph.to_variable(ones_like(real_LB_cam_logit))
            ) + self.MSE_loss(
                fake_LB_cam_logit,
                fluid.dygraph.to_variable(zeros_like(fake_LB_cam_logit)))

            D_loss_A = self.adv_weight * (D_ad_loss_GA + D_ad_cam_loss_GA +
                                          D_ad_loss_LA + D_ad_cam_loss_LA)
            D_loss_B = self.adv_weight * (D_ad_loss_GB + D_ad_cam_loss_GB +
                                          D_ad_loss_LB + D_ad_cam_loss_LB)

            Discriminator_loss = D_loss_A + D_loss_B
            Discriminator_loss.backward()
            self.D_optim.minimize(Discriminator_loss)

            # Update G
            self.G_optim.clear_gradients()

            fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(real_A)
            fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(real_B)

            fake_A2B2A, _, _ = self.genB2A(fake_A2B)
            fake_B2A2B, _, _ = self.genA2B(fake_B2A)

            fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A)
            fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B)

            fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
            fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
            fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
            fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

            G_ad_loss_GA = self.MSE_loss(
                fake_GA_logit,
                fluid.dygraph.to_variable(ones_like(fake_GA_logit)))
            G_ad_cam_loss_GA = self.MSE_loss(
                fake_GA_cam_logit,
                fluid.dygraph.to_variable(ones_like(fake_GA_cam_logit)))
            G_ad_loss_LA = self.MSE_loss(
                fake_LA_logit,
                fluid.dygraph.to_variable(ones_like(fake_LA_logit)))
            G_ad_cam_loss_LA = self.MSE_loss(
                fake_LA_cam_logit,
                fluid.dygraph.to_variable(ones_like(fake_LA_cam_logit)))
            G_ad_loss_GB = self.MSE_loss(
                fake_GB_logit,
                fluid.dygraph.to_variable(ones_like(fake_GB_logit)))
            G_ad_cam_loss_GB = self.MSE_loss(
                fake_GB_cam_logit,
                fluid.dygraph.to_variable(ones_like(fake_GB_cam_logit)))
            G_ad_loss_LB = self.MSE_loss(
                fake_LB_logit,
                fluid.dygraph.to_variable(ones_like(fake_LB_logit)))
            G_ad_cam_loss_LB = self.MSE_loss(
                fake_LB_cam_logit,
                fluid.dygraph.to_variable(ones_like(fake_LB_cam_logit)))

            G_recon_loss_A = self.L1_loss(fake_A2B2A, real_A)
            G_recon_loss_B = self.L1_loss(fake_B2A2B, real_B)

            G_identity_loss_A = self.L1_loss(fake_A2A, real_A)
            G_identity_loss_B = self.L1_loss(fake_B2B, real_B)

            G_cam_loss_A = self.BCE_loss(
                fake_B2A_cam_logit,
                fluid.dygraph.to_variable(ones_like(fake_B2A_cam_logit))
            ) + self.BCE_loss(
                fake_A2A_cam_logit,
                fluid.dygraph.to_variable(zeros_like(fake_A2A_cam_logit)))
            G_cam_loss_B = self.BCE_loss(
                fake_A2B_cam_logit,
                fluid.dygraph.to_variable(ones_like(fake_A2B_cam_logit))
            ) + self.BCE_loss(
                fake_B2B_cam_logit,
                fluid.dygraph.to_variable(zeros_like(fake_B2B_cam_logit)))

            G_loss_A = self.adv_weight * (
                G_ad_loss_GA + G_ad_cam_loss_GA + G_ad_loss_LA +
                G_ad_cam_loss_LA
            ) + self.cycle_weight * G_recon_loss_A + \
                self.identity_weight * G_identity_loss_A + \
                self.cam_weight * G_cam_loss_A
            G_loss_B = self.adv_weight * (
                G_ad_loss_GB + G_ad_cam_loss_GB + G_ad_loss_LB +
                G_ad_cam_loss_LB
            ) + self.cycle_weight * G_recon_loss_B + \
                self.identity_weight * G_identity_loss_B + \
                self.cam_weight * G_cam_loss_B

            Generator_loss = G_loss_A + G_loss_B
            Generator_loss.backward()
            self.G_optim.minimize(Generator_loss)

            # clip parameter of AdaILN and ILN, applied after optimizer step
            clip_rho(self.genA2B, vmin=0, vmax=1)
            clip_rho(self.genB2A, vmin=0, vmax=1)

            if step % 50 == 0:
                print("[%5d/%5d] d_loss: %.8f, g_loss: %.8f" %
                      (step, self.iteration, Discriminator_loss,
                       Generator_loss))

            if step % self.print_freq == 0:
                print("print img!")
                train_sample_num = 5
                test_sample_num = 5
                A2B = np.zeros((self.img_size * 7, 0, 3))
                B2A = np.zeros((self.img_size * 7, 0, 3))

                self.genA2B.eval(), self.genB2A.eval(), self.disGA.eval(
                ), self.disGB.eval(), self.disLA.eval(), self.disLB.eval()
                for _ in range(train_sample_num):
                    trainA_iter = iter(self.trainA_loader())
                    real_A = next(trainA_iter)
                    real_A = paddle.fluid.dygraph.to_variable(
                        np.array(real_A))
                    real_A = real_A / 255.0

                    trainB_iter = iter(self.trainB_loader())
                    real_B = next(trainB_iter)
                    real_B = paddle.fluid.dygraph.to_variable(
                        np.array(real_B))
                    real_B = real_B / 255.0

                    fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
                    fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

                    fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
                    fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

                    fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
                    fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

                    A2B = np.concatenate(
                        (A2B,
                         np.concatenate(
                             (RGB2BGR(tensor2numpy(denorm(real_A[0]))),
                              cam(tensor2numpy(fake_A2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
                              cam(tensor2numpy(fake_A2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
                              cam(tensor2numpy(fake_A2B2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))),
                             0)), 1)
                    B2A = np.concatenate(
                        (B2A,
                         np.concatenate(
                             (RGB2BGR(tensor2numpy(denorm(real_B[0]))),
                              cam(tensor2numpy(fake_B2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
                              cam(tensor2numpy(fake_B2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
                              cam(tensor2numpy(fake_B2A2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))),
                             0)), 1)

                for _ in range(test_sample_num):
                    testA_iter = iter(self.testA_loader())
                    real_A = next(testA_iter)
                    real_A = paddle.fluid.dygraph.to_variable(
                        np.array(real_A))
                    real_A = real_A / 255.0

                    testB_iter = iter(self.testB_loader())
                    real_B = next(testB_iter)
                    real_B = paddle.fluid.dygraph.to_variable(
                        np.array(real_B))
                    real_B = real_B / 255.0

                    fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
                    fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

                    fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
                    fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

                    fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
                    fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

                    A2B = np.concatenate(
                        (A2B,
                         np.concatenate(
                             (RGB2BGR(tensor2numpy(denorm(real_A[0]))),
                              cam(tensor2numpy(fake_A2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
                              cam(tensor2numpy(fake_A2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
                              cam(tensor2numpy(fake_A2B2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))),
                             0)), 1)
                    B2A = np.concatenate(
                        (B2A,
                         np.concatenate(
                             (RGB2BGR(tensor2numpy(denorm(real_B[0]))),
                              cam(tensor2numpy(fake_B2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
                              cam(tensor2numpy(fake_B2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
                              cam(tensor2numpy(fake_B2A2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))),
                             0)), 1)

                cv2.imwrite(
                    os.path.join(self.result_dir, 'A2B_%07d.png' % step),
                    A2B * 255.0)
                cv2.imwrite(
                    os.path.join(self.result_dir, 'B2A_%07d.png' % step),
                    B2A * 255.0)

                # Switch back to training mode; the original left the
                # networks in eval mode after sampling.
                self.genA2B.train(), self.genB2A.train(), self.disGA.train(
                ), self.disGB.train(), self.disLA.train(), self.disLB.train()

            if step % self.save_freq == 0:
                fluid.save_dygraph(self.genA2B.state_dict(),
                                   self.model_path + 'g_A2B')
                fluid.save_dygraph(self.genB2A.state_dict(),
                                   self.model_path + 'g_B2A')
                fluid.save_dygraph(self.disGA.state_dict(),
                                   self.model_path + 'd_GA')
                fluid.save_dygraph(self.disGB.state_dict(),
                                   self.model_path + 'd_GB')
                fluid.save_dygraph(self.disLA.state_dict(),
                                   self.model_path + 'd_LA')
                fluid.save_dygraph(self.disLB.state_dict(),
                                   self.model_path + 'd_LB')
                # Save each optimizer state once, under the names the resume
                # branch loads from (the original wrote G_optim/D_optim
                # repeatedly under the model file names).
                fluid.save_dygraph(self.G_optim.state_dict(),
                                   self.model_path + 'G_op')
                fluid.save_dygraph(self.D_optim.state_dict(),
                                   self.model_path + 'D_op')
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as L


def gen_data():
    return {
        "x": np.random.randint(1, 5, size=[8, 10]).astype('float32'),
        "y": np.random.randint(1, 5, size=[10]).astype('float32'),
    }


x = fluid.layers.data(name="x", shape=[8, 10], dtype='float32')
y = fluid.layers.data(name="y", shape=[10], dtype='float32')
mm = L.sqrt(L.reduce_sum(L.elementwise_mul(x, x), dim=0))
kk = L.ones_like(y)
z = fluid.layers.elementwise_div(x, mm, axis=1)  # z = x / mm

place = fluid.CPUPlace()
exe = fluid.Executor(place)
z_value = exe.run(feed=gen_data(), fetch_list=[z.name])
print(z_value)
from __future__ import print_function

import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers

slot = fluid.data('slot', [-1, 1], dtype='int64', lod_level=1)
ones = layers.ones_like(slot)
float_ones = layers.cast(ones, dtype='float32')
# Sum-pooling a sequence of ones yields the length of each sequence.
value = layers.sequence_pool(float_ones, pool_type='sum')

feed_list = {
    'slot': fluid.create_lod_tensor(
        np.array([[0], [1], [2], [3], [4]], dtype='int64'), [[3, 2]],
        fluid.CPUPlace())
}
fetch_list = [value]
exe = fluid.Executor(fluid.CPUPlace())
result = exe.run(fluid.default_main_program(),
                 feed=feed_list,
                 fetch_list=fetch_list)
print('sequence length:', result)
def get_mask(seq, padding_idx=0):
    pix = layers.unsqueeze(layers.ones_like(seq) * padding_idx, axes=2)
    mask = layers.cast(
        layers.greater_than(layers.unsqueeze(seq, axes=2), pix), 'float32')
    return mask
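# get_mask builds a (batch, seq_len, 1) float mask that is 1 for tokens whose
# id is greater than padding_idx and 0 for padding. An equivalent runnable
# NumPy sketch (ids are illustrative; 0 is the pad id):
import numpy as np

seq = np.array([[5, 9, 0, 0], [3, 0, 0, 0]])
mask = (seq > 0).astype(np.float32)[..., None]
print(mask.squeeze(-1))
# [[1. 1. 0. 0.]
#  [1. 0. 0. 0.]]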
def position_id(x, r=0):
    pid = layers.arange(0, x.shape[1], dtype="int32")
    pid = layers.unsqueeze(pid, 0)
    r = layers.cast(layers.ones_like(x), dtype="int32") * r
    return layers.cast(layers.abs(layers.elementwise_sub(pid, r)),
                       dtype='int64')
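# position_id returns |i - r| for every position i, i.e. the absolute
# distance of each token from a reference position r (r=0 recovers ordinary
# position ids). Runnable NumPy sketch:
import numpy as np

seq_len, r = 6, 2
pid = np.abs(np.arange(seq_len) - r).astype(np.int64)
print(pid)  # [2 1 0 1 2 3]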
def beam_search_infilling(model,
                          q_ids,
                          q_sids,
                          sos_id,
                          eos_id,
                          attn_id,
                          max_encode_len=640,
                          max_decode_len=100,
                          beam_width=5,
                          tgt_type_id=3,
                          length_penalty=1.0):
    model.eval()
    _, __, info = model(q_ids, q_sids)
    d_batch, d_seqlen = q_ids.shape

    state = BeamSearchState(
        log_probs=L.zeros([d_batch, beam_width], 'float32'),
        lengths=L.zeros([d_batch, beam_width], 'int64'),
        finished=L.zeros([d_batch, beam_width], 'int64'))
    outputs = []

    def reorder_(t, parent_id):
        """reorder cache according to parent beam id"""
        gather_idx = L.where(parent_id != -1)[:, 0] * beam_width + L.reshape(
            parent_id, [-1])
        t = L.gather(t, gather_idx)
        return t

    def tile_(t, times):
        _shapes = list(t.shape[1:])
        ret = L.reshape(
            L.expand(L.unsqueeze(t, [1]), [1, times] + [1] * len(_shapes)),
            [-1] + _shapes)
        return ret

    cached_k, cached_v = info['caches']
    cached_k = [tile_(k, beam_width) for k in cached_k]
    cached_v = [tile_(v, beam_width) for v in cached_v]
    past_cache = (cached_k, cached_v)

    q_ids = tile_(q_ids, beam_width)
    seqlen = L.reduce_sum(L.cast(q_ids != 0, 'int64'), 1, keep_dim=True)

    cls_ids = L.ones([d_batch * beam_width], dtype='int64') * sos_id
    attn_ids = L.ones([d_batch * beam_width], dtype='int64') * attn_id  # SOS
    ids = L.stack([cls_ids, attn_ids], -1)
    for step in range(max_decode_len):
        bias = gen_bias(q_ids, ids, step)
        pos_ids = D.to_variable(
            np.tile(np.array([[step, step + 1]], dtype=np.int64),
                    [d_batch * beam_width, 1]))
        pos_ids += seqlen
        _, logits, info = model(ids,
                                L.ones_like(ids) * tgt_type_id,
                                pos_ids=pos_ids,
                                attn_bias=bias,
                                past_cache=past_cache)

        output, state = beam_search_step(state,
                                         logits[:, 1],
                                         eos_id=eos_id,
                                         beam_width=beam_width,
                                         is_first_step=(step == 0),
                                         length_penalty=length_penalty)
        outputs.append(output)

        past_cached_k, past_cached_v = past_cache
        cached_k, cached_v = info['caches']
        cached_k = [
            reorder_(L.concat([pk, k[:, :1, :]], 1), output.beam_parent_ids)
            for pk, k in zip(past_cached_k, cached_k)
        ]  # concat cached
        cached_v = [
            reorder_(L.concat([pv, v[:, :1, :]], 1), output.beam_parent_ids)
            for pv, v in zip(past_cached_v, cached_v)
        ]
        past_cache = (cached_k, cached_v)

        pred_ids_flatten = L.reshape(output.predicted_ids,
                                     [d_batch * beam_width])
        ids = L.stack([pred_ids_flatten, attn_ids], 1)

        if state.finished.numpy().all():
            break

    final_ids = L.stack([o.predicted_ids for o in outputs], 0)
    final_parent_ids = L.stack([o.beam_parent_ids for o in outputs], 0)
    final_ids = L.gather_tree(final_ids,
                              final_parent_ids)[:, :, 0]  # pick best beam
    final_ids = L.transpose(L.reshape(final_ids, [-1, d_batch * 1]), [1, 0])
    return final_ids
def train(self):
    self.genA2B.train(), self.genB2A.train(), self.disGA.train(
    ), self.disGB.train(), self.disLA.train(), self.disLB.train()

    start_iter = 1
    # TODO: resuming training has not been investigated yet.
    # if self.resume:
    #     # glob returns the file paths matching *.pt
    #     model_list = glob(os.path.join(self.result_dir, self.dataset, 'model', '*.pt'))
    #     if not len(model_list) == 0:
    #         model_list.sort()
    #         start_iter = int(model_list[-1].split('_')[-1].split('.')[0])
    #         self.load(os.path.join(self.result_dir, self.dataset, 'model'), start_iter)
    #         print(" [*] Load SUCCESS")
    #         if self.decay_flag and start_iter > (self.iteration // 2):
    #             self.G_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2)) * (start_iter - self.iteration // 2)
    #             self.D_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2)) * (start_iter - self.iteration // 2)

    # training loop
    print('training start !')
    start_time = time.time()
    for step in range(start_iter, self.iteration + 1):
        # TODO decay
        # if self.decay_flag and step > (self.iteration // 2):
        #     self.G_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2))
        #     self.D_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2))

        # The original mixed `next(it)` and `next(it)[0]` here; this assumes
        # the loader is a callable returning an iterator of (image, label).
        try:
            real_A, _ = next(trainA_iter)
        except:
            trainA_iter = self.trainA_loader()
            real_A, _ = next(trainA_iter)

        try:
            real_B, _ = next(trainB_iter)
        except:
            trainB_iter = self.trainB_loader()
            real_B, _ = next(trainB_iter)

        # Update D
        self.D_optim.clear_gradients()

        fake_A2B, _, _ = self.genA2B(real_A)
        fake_B2A, _, _ = self.genB2A(real_B)

        real_GA_logit, real_GA_cam_logit, _ = self.disGA(real_A)
        real_LA_logit, real_LA_cam_logit, _ = self.disLA(real_A)
        real_GB_logit, real_GB_cam_logit, _ = self.disGB(real_B)
        real_LB_logit, real_LB_cam_logit, _ = self.disLB(real_B)

        fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
        fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
        fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
        fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

        D_ad_loss_GA = self.MSE_loss(
            real_GA_logit, layers.ones_like(real_GA_logit)) + self.MSE_loss(
                fake_GA_logit, layers.zeros_like(fake_GA_logit))
        D_ad_cam_loss_GA = self.MSE_loss(
            real_GA_cam_logit,
            layers.ones_like(real_GA_cam_logit)) + self.MSE_loss(
                fake_GA_cam_logit, layers.zeros_like(fake_GA_cam_logit))
        D_ad_loss_LA = self.MSE_loss(
            real_LA_logit, layers.ones_like(real_LA_logit)) + self.MSE_loss(
                fake_LA_logit, layers.zeros_like(fake_LA_logit))
        D_ad_cam_loss_LA = self.MSE_loss(
            real_LA_cam_logit,
            layers.ones_like(real_LA_cam_logit)) + self.MSE_loss(
                fake_LA_cam_logit, layers.zeros_like(fake_LA_cam_logit))
        D_ad_loss_GB = self.MSE_loss(
            real_GB_logit, layers.ones_like(real_GB_logit)) + self.MSE_loss(
                fake_GB_logit, layers.zeros_like(fake_GB_logit))
        D_ad_cam_loss_GB = self.MSE_loss(
            real_GB_cam_logit,
            layers.ones_like(real_GB_cam_logit)) + self.MSE_loss(
                fake_GB_cam_logit, layers.zeros_like(fake_GB_cam_logit))
        D_ad_loss_LB = self.MSE_loss(
            real_LB_logit, layers.ones_like(real_LB_logit)) + self.MSE_loss(
                fake_LB_logit, layers.zeros_like(fake_LB_logit))
        D_ad_cam_loss_LB = self.MSE_loss(
            real_LB_cam_logit,
            layers.ones_like(real_LB_cam_logit)) + self.MSE_loss(
                fake_LB_cam_logit, layers.zeros_like(fake_LB_cam_logit))

        D_loss_A = self.adv_weight * (D_ad_loss_GA + D_ad_cam_loss_GA +
                                      D_ad_loss_LA + D_ad_cam_loss_LA)
        D_loss_B = self.adv_weight * (D_ad_loss_GB + D_ad_cam_loss_GB +
                                      D_ad_loss_LB + D_ad_cam_loss_LB)

        Discriminator_loss = D_loss_A + D_loss_B
        Discriminator_loss.backward()
        self.D_optim.minimize(Discriminator_loss)

        # Update G
        self.G_optim.clear_gradients()

        fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(real_A)
        fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(real_B)

        fake_A2B2A, _, _ = self.genB2A(fake_A2B)
        fake_B2A2B, _, _ = self.genA2B(fake_B2A)

        fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A)
        fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B)

        fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
        fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
        fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
        fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

        G_ad_loss_GA = self.MSE_loss(fake_GA_logit,
                                     layers.ones_like(fake_GA_logit))
        G_ad_cam_loss_GA = self.MSE_loss(fake_GA_cam_logit,
                                         layers.ones_like(fake_GA_cam_logit))
        G_ad_loss_LA = self.MSE_loss(fake_LA_logit,
                                     layers.ones_like(fake_LA_logit))
        G_ad_cam_loss_LA = self.MSE_loss(fake_LA_cam_logit,
                                         layers.ones_like(fake_LA_cam_logit))
        G_ad_loss_GB = self.MSE_loss(fake_GB_logit,
                                     layers.ones_like(fake_GB_logit))
        G_ad_cam_loss_GB = self.MSE_loss(fake_GB_cam_logit,
                                         layers.ones_like(fake_GB_cam_logit))
        G_ad_loss_LB = self.MSE_loss(fake_LB_logit,
                                     layers.ones_like(fake_LB_logit))
        G_ad_cam_loss_LB = self.MSE_loss(fake_LB_cam_logit,
                                         layers.ones_like(fake_LB_cam_logit))

        G_recon_loss_A = self.L1_loss(fake_A2B2A, real_A)
        G_recon_loss_B = self.L1_loss(fake_B2A2B, real_B)

        G_identity_loss_A = self.L1_loss(fake_A2A, real_A)
        G_identity_loss_B = self.L1_loss(fake_B2B, real_B)

        G_cam_loss_A = self.BCE_loss(
            fake_B2A_cam_logit,
            layers.ones_like(fake_B2A_cam_logit)) + self.BCE_loss(
                fake_A2A_cam_logit, layers.zeros_like(fake_A2A_cam_logit))
        G_cam_loss_B = self.BCE_loss(
            fake_A2B_cam_logit,
            layers.ones_like(fake_A2B_cam_logit)) + self.BCE_loss(
                fake_B2B_cam_logit, layers.zeros_like(fake_B2B_cam_logit))

        G_loss_A = self.adv_weight * (
            G_ad_loss_GA + G_ad_cam_loss_GA + G_ad_loss_LA + G_ad_cam_loss_LA
        ) + self.cycle_weight * G_recon_loss_A + \
            self.identity_weight * G_identity_loss_A + \
            self.cam_weight * G_cam_loss_A
        G_loss_B = self.adv_weight * (
            G_ad_loss_GB + G_ad_cam_loss_GB + G_ad_loss_LB + G_ad_cam_loss_LB
        ) + self.cycle_weight * G_recon_loss_B + \
            self.identity_weight * G_identity_loss_B + \
            self.cam_weight * G_cam_loss_B

        Generator_loss = G_loss_A + G_loss_B
        Generator_loss.backward()
        self.G_optim.minimize(Generator_loss)

        # clip parameter of AdaILN and ILN, applied after optimizer step
        self.genA2B.apply(self.Rho_clipper)
        self.genB2A.apply(self.Rho_clipper)

        print("[%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" %
              (step, self.iteration, time.time() - start_time,
               Discriminator_loss, Generator_loss))

        if step % self.print_freq == 0:
            train_sample_num = 5
            test_sample_num = 5
            A2B = np.zeros((self.img_size * 7, 0, 3))
            B2A = np.zeros((self.img_size * 7, 0, 3))

            self.genA2B.eval(), self.genB2A.eval(), self.disGA.eval(
            ), self.disGB.eval(), self.disLA.eval(), self.disLB.eval()
            for _ in range(train_sample_num):
                try:
                    real_A, _ = next(trainA_iter)
                except:
                    trainA_iter = iter(self.trainA_loader)
                    real_A, _ = next(trainA_iter)

                try:
                    real_B, _ = next(trainB_iter)
                except:
                    trainB_iter = iter(self.trainB_loader)
                    real_B, _ = next(trainB_iter)

                fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
                fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

                fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
                fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

                fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
                fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

                A2B = np.concatenate(
                    (A2B,
                     np.concatenate(
                         (RGB2BGR(tensor2numpy(denorm(real_A[0]))),
                          cam(tensor2numpy(fake_A2A_heatmap[0]),
                              self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
                          cam(tensor2numpy(fake_A2B_heatmap[0]),
                              self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
                          cam(tensor2numpy(fake_A2B2A_heatmap[0]),
                              self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))),
                         0)), 1)
                B2A = np.concatenate(
                    (B2A,
                     np.concatenate(
                         (RGB2BGR(tensor2numpy(denorm(real_B[0]))),
                          cam(tensor2numpy(fake_B2B_heatmap[0]),
                              self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
                          cam(tensor2numpy(fake_B2A_heatmap[0]),
                              self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
                          cam(tensor2numpy(fake_B2A2B_heatmap[0]),
                              self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))),
                         0)), 1)

            for _ in range(test_sample_num):
                try:
                    real_A, _ = next(testA_iter)
                except:
                    testA_iter = iter(self.testA_loader)
                    real_A, _ = next(testA_iter)

                try:
                    real_B, _ = next(testB_iter)
                except:
                    testB_iter = iter(self.testB_loader)
                    real_B, _ = next(testB_iter)

                fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
                fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

                fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
                fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

                fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
                fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

                A2B = np.concatenate(
                    (A2B,
                     np.concatenate(
                         (RGB2BGR(tensor2numpy(denorm(real_A[0]))),
                          cam(tensor2numpy(fake_A2A_heatmap[0]),
                              self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
                          cam(tensor2numpy(fake_A2B_heatmap[0]),
                              self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
                          cam(tensor2numpy(fake_A2B2A_heatmap[0]),
                              self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))),
                         0)), 1)
                B2A = np.concatenate(
                    (B2A,
                     np.concatenate(
                         (RGB2BGR(tensor2numpy(denorm(real_B[0]))),
                          cam(tensor2numpy(fake_B2B_heatmap[0]),
                              self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
                          cam(tensor2numpy(fake_B2A_heatmap[0]),
                              self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
                          cam(tensor2numpy(fake_B2A2B_heatmap[0]),
                              self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))),
                         0)), 1)

            cv2.imwrite(
                os.path.join(self.result_dir, self.dataset, 'img',
                             'A2B_%07d.png' % step), A2B * 255.0)
            cv2.imwrite(
                os.path.join(self.result_dir, self.dataset, 'img',
                             'B2A_%07d.png' % step), B2A * 255.0)

            self.genA2B.train(), self.genB2A.train(), self.disGA.train(
            ), self.disGB.train(), self.disLA.train(), self.disLB.train()

        if step % self.save_freq == 0:
            self.save(os.path.join(self.result_dir, self.dataset, 'model'),
                      step)

        if step % 1000 == 0:
            params = {}
            params['genA2B'] = self.genA2B.state_dict()
            params['genB2A'] = self.genB2A.state_dict()
            params['disGA'] = self.disGA.state_dict()
            params['disGB'] = self.disGB.state_dict()
            params['disLA'] = self.disLA.state_dict()
            params['disLB'] = self.disLB.state_dict()
            # NOTE: save_dygraph expects a flat state dict; saving a nested
            # dict like this may not load back as-is.
            fluid.save_dygraph(
                params,
                os.path.join(self.result_dir,
                             self.dataset + '_params_latest'))
def train(self):
    self.genA2B.train(), self.genB2A.train(), self.disGA.train(
    ), self.disGB.train(), self.disLA.train(), self.disLB.train()

    start_iter = 1
    if self.resume:
        model_list = glob(
            os.path.join(self.result_dir, self.dataset, 'model',
                         '*.pdparams'))
        if not len(model_list) == 0:
            model_list.sort()
            start_iter = int(model_list[-1].split('_')[-1].split('.')[0])
            self.load(os.path.join(self.result_dir, self.dataset, 'model'),
                      start_iter)
            print(" [*] Load SUCCESS", start_iter)
            if self.decay_flag and start_iter > (self.iteration // 2):
                self.G_optim.set_lr(self.G_optim.current_step_lr() -
                                    (self.lr / (self.iteration // 2)) *
                                    (start_iter - self.iteration // 2))
                self.D_optim.set_lr(self.D_optim.current_step_lr() -
                                    (self.lr / (self.iteration // 2)) *
                                    (start_iter - self.iteration // 2))

    # training loop
    print('training start !')
    start_time = time.time()
    for step in tqdm(range(start_iter, self.iteration + 1)):
        if self.decay_flag and step > (self.iteration // 2):
            self.G_optim.set_lr(self.G_optim.current_step_lr() -
                                (self.lr / (self.iteration // 2)))
            self.D_optim.set_lr(self.D_optim.current_step_lr() -
                                (self.lr / (self.iteration // 2)))
        d_lr = self.D_optim.current_step_lr()
        g_lr = self.G_optim.current_step_lr()

        try:
            real_A, _ = next(trainA_iter)
        except:
            trainA_iter = iter(self.trainA_loader)
            real_A, _ = next(trainA_iter)

        try:
            real_B, _ = next(trainB_iter)
        except:
            trainB_iter = iter(self.trainB_loader)
            real_B, _ = next(trainB_iter)

        # Update D
        if 1:
            self.D_optim.clear_gradients()

            fake_A2B, _, _ = self.genA2B(real_A)
            fake_B2A, _, _ = self.genB2A(real_B)

            # to 1
            real_GA_logit, real_GA_cam_logit, _ = self.disGA(real_A)
            real_LA_logit, real_LA_cam_logit, _ = self.disLA(real_A)
            real_GB_logit, real_GB_cam_logit, _ = self.disGB(real_B)
            real_LB_logit, real_LB_cam_logit, _ = self.disLB(real_B)

            # to 0
            fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
            fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
            fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
            fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

            # GA
            D_ad_loss_GA_1 = self.MSE_loss(real_GA_logit,
                                           L.ones_like(real_GA_logit))
            D_ad_loss_GA_0 = self.MSE_loss(fake_GA_logit,
                                           L.zeros_like(fake_GA_logit))
            D_ad_loss_GA = D_ad_loss_GA_1 + D_ad_loss_GA_0

            D_ad_cam_loss_GA_1 = self.MSE_loss(
                real_GA_cam_logit, L.ones_like(real_GA_cam_logit))
            D_ad_cam_loss_GA_0 = self.MSE_loss(
                fake_GA_cam_logit, L.zeros_like(fake_GA_cam_logit))
            D_ad_cam_loss_GA = D_ad_cam_loss_GA_1 + D_ad_cam_loss_GA_0

            # LA
            D_ad_loss_LA = self.MSE_loss(
                real_LA_logit, L.ones_like(real_LA_logit)) + self.MSE_loss(
                    fake_LA_logit, L.zeros_like(fake_LA_logit))
            D_ad_cam_loss_LA = self.MSE_loss(
                real_LA_cam_logit,
                L.ones_like(real_LA_cam_logit)) + self.MSE_loss(
                    fake_LA_cam_logit, L.zeros_like(fake_LA_cam_logit))
            # GB
            D_ad_loss_GB = self.MSE_loss(
                real_GB_logit, L.ones_like(real_GB_logit)) + self.MSE_loss(
                    fake_GB_logit, L.zeros_like(fake_GB_logit))
            D_ad_cam_loss_GB = self.MSE_loss(
                real_GB_cam_logit,
                L.ones_like(real_GB_cam_logit)) + self.MSE_loss(
                    fake_GB_cam_logit, L.zeros_like(fake_GB_cam_logit))
            # LB
            D_ad_loss_LB = self.MSE_loss(
                real_LB_logit, L.ones_like(real_LB_logit)) + self.MSE_loss(
                    fake_LB_logit, L.zeros_like(fake_LB_logit))
            D_ad_cam_loss_LB = self.MSE_loss(
                real_LB_cam_logit,
                L.ones_like(real_LB_cam_logit)) + self.MSE_loss(
                    fake_LB_cam_logit, L.zeros_like(fake_LB_cam_logit))

            # GA and LA
            D_loss_A = self.adv_weight * (D_ad_loss_GA + D_ad_cam_loss_GA +
                                          D_ad_loss_LA + D_ad_cam_loss_LA)
            # GB and LB
            D_loss_B = self.adv_weight * (D_ad_loss_GB + D_ad_cam_loss_GB +
                                          D_ad_loss_LB + D_ad_cam_loss_LB)

            Discriminator_loss = D_loss_A + D_loss_B
            Discriminator_loss.backward()
            self.D_optim.minimize(Discriminator_loss)
        else:
            Discriminator_loss = 0

        # Update G
        if 1:
            self.G_optim.clear_gradients()

            # run twice for the gradient computation
            fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(real_A)
            fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(real_B)

            # cycle
            fake_A2B2A, _, _ = self.genB2A(fake_A2B)
            fake_B2A2B, _, _ = self.genA2B(fake_B2A)

            # NOTICE!
            fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A)
            fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B)

            # to 1, generate
            fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
            fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
            fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
            fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

            G_ad_loss_GA = self.MSE_loss(fake_GA_logit,
                                         L.ones_like(fake_GA_logit))
            G_ad_cam_loss_GA = self.MSE_loss(fake_GA_cam_logit,
                                             L.ones_like(fake_GA_cam_logit))
            G_ad_loss_LA = self.MSE_loss(fake_LA_logit,
                                         L.ones_like(fake_LA_logit))
            G_ad_cam_loss_LA = self.MSE_loss(fake_LA_cam_logit,
                                             L.ones_like(fake_LA_cam_logit))
            G_ad_loss_GB = self.MSE_loss(fake_GB_logit,
                                         L.ones_like(fake_GB_logit))
            G_ad_cam_loss_GB = self.MSE_loss(fake_GB_cam_logit,
                                             L.ones_like(fake_GB_cam_logit))
            G_ad_loss_LB = self.MSE_loss(fake_LB_logit,
                                         L.ones_like(fake_LB_logit))
            G_ad_cam_loss_LB = self.MSE_loss(fake_LB_cam_logit,
                                             L.ones_like(fake_LB_cam_logit))

            G_recon_loss_A = self.L1_loss(fake_A2B2A, real_A)
            G_recon_loss_B = self.L1_loss(fake_B2A2B, real_B)

            G_identity_loss_A = self.L1_loss(fake_A2A, real_A)
            G_identity_loss_B = self.L1_loss(fake_B2B, real_B)

            G_cam_loss_A = self.BCE_loss(
                fake_B2A_cam_logit,
                L.ones_like(fake_B2A_cam_logit)) + self.BCE_loss(
                    fake_A2A_cam_logit, L.zeros_like(fake_A2A_cam_logit))
            G_cam_loss_B = self.BCE_loss(
                fake_A2B_cam_logit,
                L.ones_like(fake_A2B_cam_logit)) + self.BCE_loss(
                    fake_B2B_cam_logit, L.zeros_like(fake_B2B_cam_logit))

            G_loss_A = self.adv_weight * (
                G_ad_loss_GA + G_ad_cam_loss_GA + G_ad_loss_LA +
                G_ad_cam_loss_LA
            ) + self.cycle_weight * G_recon_loss_A + \
                self.identity_weight * G_identity_loss_A + \
                self.cam_weight * G_cam_loss_A
            G_loss_B = self.adv_weight * (
                G_ad_loss_GB + G_ad_cam_loss_GB + G_ad_loss_LB +
                G_ad_cam_loss_LB
            ) + self.cycle_weight * G_recon_loss_B + \
                self.identity_weight * G_identity_loss_B + \
                self.cam_weight * G_cam_loss_B

            Generator_loss = G_loss_A + G_loss_B
            Generator_loss.backward()
            self.G_optim.minimize(Generator_loss)
        else:
            Generator_loss = 0

        print(
            "[%5d/%5d] time: %4.4f d_lr: %.8f g_lr: %.8f d_loss: %.8f, g_loss: %.8f"
            % (step, self.iteration, time.time() - start_time, d_lr, g_lr,
               Discriminator_loss, Generator_loss))

        if step % self.print_freq == 0:
            train_sample_num = 5
            test_sample_num = 5
            A2B = np.zeros((self.img_size * 7, 0, 3))
            B2A = np.zeros((self.img_size * 7, 0, 3))

            self.genA2B.eval(), self.genB2A.eval(), self.disGA.eval(
            ), self.disGB.eval(), self.disLA.eval(), self.disLB.eval()
            for _ in range(train_sample_num):
                try:
                    real_A, _ = next(trainA_iter)
                except:
                    trainA_iter = iter(self.trainA_loader)
                    real_A, _ = next(trainA_iter)

                try:
                    real_B, _ = next(trainB_iter)
                except:
                    trainB_iter = iter(self.trainB_loader)
                    real_B, _ = next(trainB_iter)
                # real_A, real_B = to_variable(real_A), to_variable(real_B)

                fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
                fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

                fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
                fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

                fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
                fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

                A2B = np.concatenate(
                    (A2B,
                     np.concatenate(
                         ((tensor2numpy(denorm(real_A[0]))),
                          cam(tensor2numpy(fake_A2A_heatmap[0]),
                              self.img_size),
                          (tensor2numpy(denorm(fake_A2A[0]))),
                          cam(tensor2numpy(fake_A2B_heatmap[0]),
                              self.img_size),
                          (tensor2numpy(denorm(fake_A2B[0]))),
                          cam(tensor2numpy(fake_A2B2A_heatmap[0]),
                              self.img_size),
                          (tensor2numpy(denorm(fake_A2B2A[0])))), 0)), 1)
                B2A = np.concatenate(
                    (B2A,
                     np.concatenate(
                         ((tensor2numpy(denorm(real_B[0]))),
                          cam(tensor2numpy(fake_B2B_heatmap[0]),
                              self.img_size),
                          (tensor2numpy(denorm(fake_B2B[0]))),
                          cam(tensor2numpy(fake_B2A_heatmap[0]),
                              self.img_size),
                          (tensor2numpy(denorm(fake_B2A[0]))),
                          cam(tensor2numpy(fake_B2A2B_heatmap[0]),
                              self.img_size),
                          (tensor2numpy(denorm(fake_B2A2B[0])))), 0)), 1)

            for _ in range(test_sample_num):
                try:
                    real_A, _ = next(testA_iter)
                except:
                    testA_iter = iter(self.testA_loader)
                    real_A, _ = next(testA_iter)

                try:
                    real_B, _ = next(testB_iter)
                except:
                    testB_iter = iter(self.testB_loader)
                    real_B, _ = next(testB_iter)
                # real_A, real_B = to_variable(real_A), to_variable(real_B)

                fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
                fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

                fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
                fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

                fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
                fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

                A2B = np.concatenate(
                    (A2B,
                     np.concatenate(
                         ((tensor2numpy(denorm(real_A[0]))),
                          cam(tensor2numpy(fake_A2A_heatmap[0]),
                              self.img_size),
                          (tensor2numpy(denorm(fake_A2A[0]))),
                          cam(tensor2numpy(fake_A2B_heatmap[0]),
                              self.img_size),
                          (tensor2numpy(denorm(fake_A2B[0]))),
                          cam(tensor2numpy(fake_A2B2A_heatmap[0]),
                              self.img_size),
                          (tensor2numpy(denorm(fake_A2B2A[0])))), 0)), 1)
                B2A = np.concatenate(
                    (B2A,
                     np.concatenate(
                         ((tensor2numpy(denorm(real_B[0]))),
                          cam(tensor2numpy(fake_B2B_heatmap[0]),
                              self.img_size),
                          (tensor2numpy(denorm(fake_B2B[0]))),
                          cam(tensor2numpy(fake_B2A_heatmap[0]),
                              self.img_size),
                          (tensor2numpy(denorm(fake_B2A[0]))),
                          cam(tensor2numpy(fake_B2A2B_heatmap[0]),
                              self.img_size),
                          (tensor2numpy(denorm(fake_B2A2B[0])))), 0)), 1)

            cv2.imwrite(
                os.path.join(self.result_dir, self.dataset, 'img',
                             'A2B_%07d.png' % step), A2B * 255.0)
            cv2.imwrite(
                os.path.join(self.result_dir, self.dataset, 'img',
                             'B2A_%07d.png' % step), B2A * 255.0)

            self.genA2B.train(), self.genB2A.train(), self.disGA.train(
            ), self.disGB.train(), self.disLA.train(), self.disLB.train()

        if step % self.save_freq == 0:
            self.save(os.path.join(self.result_dir, self.dataset, 'model'),
                      step)
def train(self):
    self.genA2B.train(), self.genB2A.train(), self.disGA.train(
    ), self.disGB.train(), self.disLA.train(), self.disLB.train()

    start_iter = 1
    if self.resume:
        model_list = os.listdir(
            os.path.join(self.result_dir, self.dataset, 'model'))
        if not len(model_list) == 0:
            model_list.sort()
            last_iter = int(model_list[-1])  # don't shadow the builtin iter()
            print("[*]load %d" % (last_iter))
            self.load(os.path.join(self.result_dir, self.dataset, 'model'),
                      last_iter)
            print("[*] Load SUCCESS")
            # Resume after the loaded checkpoint; the original left
            # start_iter at 1 even after a successful load.
            start_iter = last_iter + 1

    # training loop
    print('training start !')
    start_time = time.time()
    for step in range(start_iter, self.iteration + 1):
        real_A = next(self.trainA_loader)
        real_B = next(self.trainB_loader)
        real_A = np.array([real_A[0].reshape(3, 256, 256)]).astype("float32")
        real_B = np.array([real_B[0].reshape(3, 256, 256)]).astype("float32")
        real_A = to_variable(real_A)
        real_B = to_variable(real_B)

        # Update D
        fake_A2B, _, _ = self.genA2B(real_A)
        fake_B2A, _, _ = self.genB2A(real_B)

        real_GA_logit, real_GA_cam_logit, _ = self.disGA(real_A)
        real_LA_logit, real_LA_cam_logit, _ = self.disLA(real_A)
        real_GB_logit, real_GB_cam_logit, _ = self.disGB(real_B)
        real_LB_logit, real_LB_cam_logit, _ = self.disLB(real_B)

        fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
        fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
        fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
        fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

        D_ad_loss_GA = self.MSE_loss(
            real_GA_logit, ones_like(real_GA_logit)) + self.MSE_loss(
                fake_GA_logit, zeros_like(fake_GA_logit))
        D_ad_cam_loss_GA = self.MSE_loss(
            real_GA_cam_logit, ones_like(real_GA_cam_logit)) + self.MSE_loss(
                fake_GA_cam_logit, zeros_like(fake_GA_cam_logit))
        D_ad_loss_LA = self.MSE_loss(
            real_LA_logit, ones_like(real_LA_logit)) + self.MSE_loss(
                fake_LA_logit, zeros_like(fake_LA_logit))
        D_ad_cam_loss_LA = self.MSE_loss(
            real_LA_cam_logit, ones_like(real_LA_cam_logit)) + self.MSE_loss(
                fake_LA_cam_logit, zeros_like(fake_LA_cam_logit))
        D_ad_loss_GB = self.MSE_loss(
            real_GB_logit, ones_like(real_GB_logit)) + self.MSE_loss(
                fake_GB_logit, zeros_like(fake_GB_logit))
        D_ad_cam_loss_GB = self.MSE_loss(
            real_GB_cam_logit, ones_like(real_GB_cam_logit)) + self.MSE_loss(
                fake_GB_cam_logit, zeros_like(fake_GB_cam_logit))
        D_ad_loss_LB = self.MSE_loss(
            real_LB_logit, ones_like(real_LB_logit)) + self.MSE_loss(
                fake_LB_logit, zeros_like(fake_LB_logit))
        D_ad_cam_loss_LB = self.MSE_loss(
            real_LB_cam_logit, ones_like(real_LB_cam_logit)) + self.MSE_loss(
                fake_LB_cam_logit, zeros_like(fake_LB_cam_logit))

        D_loss_A = self.adv_weight * (D_ad_loss_GA + D_ad_cam_loss_GA +
                                      D_ad_loss_LA + D_ad_cam_loss_LA)
        D_loss_B = self.adv_weight * (D_ad_loss_GB + D_ad_cam_loss_GB +
                                      D_ad_loss_LB + D_ad_cam_loss_LB)

        Discriminator_loss = D_loss_A + D_loss_B
        Discriminator_loss.backward()
        self.D_optim.minimize(Discriminator_loss)
        self.genB2A.clear_gradients()
        self.genA2B.clear_gradients()
        self.disGA.clear_gradients()
        self.disLA.clear_gradients()
        self.disGB.clear_gradients()
        self.disLB.clear_gradients()
        self.D_optim.clear_gradients()

        # Update G
        fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(real_A)
        fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(real_B)

        fake_A2B2A, _, _ = self.genB2A(fake_A2B)
        fake_B2A2B, _, _ = self.genA2B(fake_B2A)

        fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A)
        fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B)

        fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
        fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
        fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
        fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

        G_ad_loss_GA = self.MSE_loss(fake_GA_logit,
                                     ones_like(fake_GA_logit))
        G_ad_cam_loss_GA = self.MSE_loss(fake_GA_cam_logit,
                                         ones_like(fake_GA_cam_logit))
        G_ad_loss_LA = self.MSE_loss(fake_LA_logit,
                                     ones_like(fake_LA_logit))
        G_ad_cam_loss_LA = self.MSE_loss(fake_LA_cam_logit,
                                         ones_like(fake_LA_cam_logit))
        G_ad_loss_GB = self.MSE_loss(fake_GB_logit,
                                     ones_like(fake_GB_logit))
        G_ad_cam_loss_GB = self.MSE_loss(fake_GB_cam_logit,
                                         ones_like(fake_GB_cam_logit))
        G_ad_loss_LB = self.MSE_loss(fake_LB_logit,
                                     ones_like(fake_LB_logit))
        G_ad_cam_loss_LB = self.MSE_loss(fake_LB_cam_logit,
                                         ones_like(fake_LB_cam_logit))

        G_recon_loss_A = self.L1_loss(fake_A2B2A, real_A)
        G_recon_loss_B = self.L1_loss(fake_B2A2B, real_B)

        G_identity_loss_A = self.L1_loss(fake_A2A, real_A)
        G_identity_loss_B = self.L1_loss(fake_B2B, real_B)

        G_cam_loss_A = self.BCE_loss(
            fake_B2A_cam_logit, ones_like(fake_B2A_cam_logit)) + self.BCE_loss(
                fake_A2A_cam_logit, zeros_like(fake_A2A_cam_logit))
        G_cam_loss_B = self.BCE_loss(
            fake_A2B_cam_logit, ones_like(fake_A2B_cam_logit)) + self.BCE_loss(
                fake_B2B_cam_logit, zeros_like(fake_B2B_cam_logit))

        G_loss_A = self.adv_weight * (
            G_ad_loss_GA + G_ad_cam_loss_GA + G_ad_loss_LA + G_ad_cam_loss_LA
        ) + self.cycle_weight * G_recon_loss_A + \
            self.identity_weight * G_identity_loss_A + \
            self.cam_weight * G_cam_loss_A
        G_loss_B = self.adv_weight * (
            G_ad_loss_GB + G_ad_cam_loss_GB + G_ad_loss_LB + G_ad_cam_loss_LB
        ) + self.cycle_weight * G_recon_loss_B + \
            self.identity_weight * G_identity_loss_B + \
            self.cam_weight * G_cam_loss_B

        Generator_loss = G_loss_A + G_loss_B
        Generator_loss.backward()
        self.G_optim.minimize(Generator_loss)
        self.genB2A.clear_gradients()
        self.genA2B.clear_gradients()
        self.disGA.clear_gradients()
        self.disLA.clear_gradients()
        self.disGB.clear_gradients()
        self.disLB.clear_gradients()
        self.G_optim.clear_gradients()

        self.Rho_clipper(self.genA2B)
        self.Rho_clipper(self.genB2A)

        print("[%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" %
              (step, self.iteration, time.time() - start_time,
               Discriminator_loss, Generator_loss))

        if step % self.print_freq == 0:
            train_sample_num = 5
            test_sample_num = 5
            A2B = np.zeros((self.img_size * 7, 0, 3))
            B2A = np.zeros((self.img_size * 7, 0, 3))

            self.genA2B.eval(), self.genB2A.eval(), self.disGA.eval(
            ), self.disGB.eval(), self.disLA.eval(), self.disLB.eval()
            for _ in range(train_sample_num):
                real_A = next(self.trainA_loader)
                real_B = next(self.trainB_loader)
                real_A = np.array([real_A[0].reshape(3, 256, 256)
                                   ]).astype("float32")
                real_B = np.array([real_B[0].reshape(3, 256, 256)
                                   ]).astype("float32")
                real_A = to_variable(real_A)
                real_B = to_variable(real_B)

                fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
                fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

                fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
                fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

                fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
                fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

                A2B = np.concatenate(
                    (A2B,
                     np.concatenate(
                         (RGB2BGR(tensor2numpy(denorm(real_A[0]))),
                          cam(tensor2numpy(fake_A2A_heatmap[0]),
                              self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
                          cam(tensor2numpy(fake_A2B_heatmap[0]),
                              self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
                          cam(tensor2numpy(fake_A2B2A_heatmap[0]),
                              self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))),
                         0)), 1)
                B2A = np.concatenate(
                    (B2A,
                     np.concatenate(
                         (RGB2BGR(tensor2numpy(denorm(real_B[0]))),
                          cam(tensor2numpy(fake_B2B_heatmap[0]),
                              self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
                          cam(tensor2numpy(fake_B2A_heatmap[0]),
                              self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
                          cam(tensor2numpy(fake_B2A2B_heatmap[0]),
                              self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))),
                         0)), 1)

            for _ in range(test_sample_num):
                real_A = next(self.testA_loader())
                real_B = next(self.testB_loader())
                real_A = np.array([real_A[0].reshape(3, 256, 256)
                                   ]).astype("float32")
                real_B = np.array([real_B[0].reshape(3, 256, 256)
                                   ]).astype("float32")
                real_A = to_variable(real_A)
                real_B = to_variable(real_B)

                fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
                fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

                fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
                fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

                fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
                fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

                A2B = np.concatenate(
                    (A2B,
                     np.concatenate(
                         (RGB2BGR(tensor2numpy(denorm(real_A[0]))),
                          cam(tensor2numpy(fake_A2A_heatmap[0]),
                              self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
                          cam(tensor2numpy(fake_A2B_heatmap[0]),
                              self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
                          cam(tensor2numpy(fake_A2B2A_heatmap[0]),
                              self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))),
                         0)), 1)
                B2A = np.concatenate(
                    (B2A,
                     np.concatenate(
                         (RGB2BGR(tensor2numpy(denorm(real_B[0]))),
                          cam(tensor2numpy(fake_B2B_heatmap[0]),
                              self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
                          cam(tensor2numpy(fake_B2A_heatmap[0]),
                              self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
                          cam(tensor2numpy(fake_B2A2B_heatmap[0]),
                              self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))),
                         0)), 1)

            cv2.imwrite(
                os.path.join(self.result_dir, self.dataset, 'img',
                             'A2B_%07d.png' % step), A2B * 255.0)
            cv2.imwrite(
                os.path.join(self.result_dir, self.dataset, 'img',
                             'B2A_%07d.png' % step), B2A * 255.0)

            self.genA2B.train(), self.genB2A.train(), self.disGA.train(
            ), self.disGB.train(), self.disLA.train(), self.disLB.train()

        if step % self.save_freq == 0:
            self.save(os.path.join(self.result_dir, self.dataset, 'model'),
                      step)

        if step % 1000 == 0:
            fluid.save_dygraph(
                self.genA2B.state_dict(),
                os.path.join(self.result_dir,
                             self.dataset + "/latest/new/genA2B"))
            fluid.save_dygraph(
                self.genB2A.state_dict(),
                os.path.join(self.result_dir,
                             self.dataset + "/latest/new/genB2A"))
            fluid.save_dygraph(
                self.disGA.state_dict(),
                os.path.join(self.result_dir,
                             self.dataset + "/latest/new/disGA"))
            fluid.save_dygraph(
                self.disGB.state_dict(),
                os.path.join(self.result_dir,
                             self.dataset + "/latest/new/disGB"))
            fluid.save_dygraph(
                self.disLA.state_dict(),
                os.path.join(self.result_dir,
                             self.dataset + "/latest/new/disLA"))
            fluid.save_dygraph(
                self.disLB.state_dict(),
                os.path.join(self.result_dir,
                             self.dataset + "/latest/new/disLB"))
            fluid.save_dygraph(
                self.D_optim.state_dict(),
                os.path.join(self.result_dir,
                             self.dataset + "/latest/new/D_optim"))
            fluid.save_dygraph(
                self.G_optim.state_dict(),
                os.path.join(self.result_dir,
                             self.dataset + "/latest/new/G_optim"))
            # The original repeated two generator saves under the optimizer
            # paths here (an apparent copy-paste slip); they are dropped.
def train(self):
    # One VisualDL writer is enough; both scalars go to the same logdir.
    loss_writer = LogWriter(logdir="./log/UGATIT/train")

    self.start_iter = 1
    if self.resume:
        self.load(os.path.join(self.result_dir, self.dataset, 'model'),
                  self.start_iter_arg, True)

    # training loop
    print('training start !')
    start_time = time.time()

    # Create the data iterators up front; an exhausted iterator is rebuilt
    # on StopIteration instead of relying on a bare except to also catch
    # the NameError of a never-created iterator.
    trainA_iter = self.trainA_loader()
    trainB_iter = self.trainB_loader()
    testA_iter = self.testA_loader()
    testB_iter = self.testB_loader()

    for step in range(self.start_iter, self.iteration + 1):
        self.genA2B.train(), self.genB2A.train(), self.disGA.train(), \
            self.disGB.train(), self.disLA.train(), self.disLB.train()

        try:
            real_A, _ = next(trainA_iter)[0]
        except StopIteration:
            trainA_iter = self.trainA_loader()
            real_A, _ = next(trainA_iter)[0]

        try:
            real_B, _ = next(trainB_iter)[0]
        except StopIteration:
            trainB_iter = self.trainB_loader()
            real_B, _ = next(trainB_iter)[0]

        # Update D
        self.D_optim.clear_gradients()

        fake_A2B, _, _ = self.genA2B(real_A)
        fake_B2A, _, _ = self.genB2A(real_B)

        real_GA_logit, real_GA_cam_logit, _ = self.disGA(real_A)
        real_LA_logit, real_LA_cam_logit, _ = self.disLA(real_A)
        real_GB_logit, real_GB_cam_logit, _ = self.disGB(real_B)
        real_LB_logit, real_LB_cam_logit, _ = self.disLB(real_B)

        fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
        fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
        fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
        fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

        D_ad_loss_GA = self.MSE_loss(
            real_GA_logit, layers.ones_like(real_GA_logit)) + self.MSE_loss(
                fake_GA_logit, layers.zeros_like(fake_GA_logit))
        D_ad_cam_loss_GA = self.MSE_loss(
            real_GA_cam_logit,
            layers.ones_like(real_GA_cam_logit)) + self.MSE_loss(
                fake_GA_cam_logit, layers.zeros_like(fake_GA_cam_logit))
        D_ad_loss_LA = self.MSE_loss(
            real_LA_logit, layers.ones_like(real_LA_logit)) + self.MSE_loss(
                fake_LA_logit, layers.zeros_like(fake_LA_logit))
        D_ad_cam_loss_LA = self.MSE_loss(
            real_LA_cam_logit,
            layers.ones_like(real_LA_cam_logit)) + self.MSE_loss(
                fake_LA_cam_logit, layers.zeros_like(fake_LA_cam_logit))
        D_ad_loss_GB = self.MSE_loss(
            real_GB_logit, layers.ones_like(real_GB_logit)) + self.MSE_loss(
                fake_GB_logit, layers.zeros_like(fake_GB_logit))
        D_ad_cam_loss_GB = self.MSE_loss(
            real_GB_cam_logit,
            layers.ones_like(real_GB_cam_logit)) + self.MSE_loss(
                fake_GB_cam_logit, layers.zeros_like(fake_GB_cam_logit))
        D_ad_loss_LB = self.MSE_loss(
            real_LB_logit, layers.ones_like(real_LB_logit)) + self.MSE_loss(
                fake_LB_logit, layers.zeros_like(fake_LB_logit))
        D_ad_cam_loss_LB = self.MSE_loss(
            real_LB_cam_logit,
            layers.ones_like(real_LB_cam_logit)) + self.MSE_loss(
                fake_LB_cam_logit, layers.zeros_like(fake_LB_cam_logit))

        D_loss_A = self.adv_weight * (D_ad_loss_GA + D_ad_cam_loss_GA +
                                      D_ad_loss_LA + D_ad_cam_loss_LA)
        D_loss_B = self.adv_weight * (D_ad_loss_GB + D_ad_cam_loss_GB +
                                      D_ad_loss_LB + D_ad_cam_loss_LB)

        Discriminator_loss = D_loss_A + D_loss_B
        Discriminator_loss.backward()
        self.D_optim.minimize(Discriminator_loss)

        # Update G
        self.G_optim.clear_gradients()

        fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(real_A)
        fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(real_B)

        fake_A2B2A, _, _ = self.genB2A(fake_A2B)
        fake_B2A2B, _, _ = self.genA2B(fake_B2A)

        fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A)
        fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B)

        fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
        fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
        fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
        fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

        G_ad_loss_GA = self.MSE_loss(fake_GA_logit,
                                     layers.ones_like(fake_GA_logit))
        G_ad_cam_loss_GA = self.MSE_loss(fake_GA_cam_logit,
                                         layers.ones_like(fake_GA_cam_logit))
        G_ad_loss_LA = self.MSE_loss(fake_LA_logit,
                                     layers.ones_like(fake_LA_logit))
        G_ad_cam_loss_LA = self.MSE_loss(fake_LA_cam_logit,
                                         layers.ones_like(fake_LA_cam_logit))
        G_ad_loss_GB = self.MSE_loss(fake_GB_logit,
                                     layers.ones_like(fake_GB_logit))
        G_ad_cam_loss_GB = self.MSE_loss(fake_GB_cam_logit,
                                         layers.ones_like(fake_GB_cam_logit))
        G_ad_loss_LB = self.MSE_loss(fake_LB_logit,
                                     layers.ones_like(fake_LB_logit))
        G_ad_cam_loss_LB = self.MSE_loss(fake_LB_cam_logit,
                                         layers.ones_like(fake_LB_cam_logit))

        G_recon_loss_A = self.L1_loss(fake_A2B2A, real_A)
        G_recon_loss_B = self.L1_loss(fake_B2A2B, real_B)

        G_identity_loss_A = self.L1_loss(fake_A2A, real_A)
        G_identity_loss_B = self.L1_loss(fake_B2B, real_B)

        G_cam_loss_A = self.BCELoss(
            fake_B2A_cam_logit,
            layers.ones_like(fake_B2A_cam_logit)) + self.BCELoss(
                fake_A2A_cam_logit, layers.zeros_like(fake_A2A_cam_logit))
        G_cam_loss_B = self.BCELoss(
            fake_A2B_cam_logit,
            layers.ones_like(fake_A2B_cam_logit)) + self.BCELoss(
                fake_B2B_cam_logit, layers.zeros_like(fake_B2B_cam_logit))

        G_loss_A = self.adv_weight * (
            G_ad_loss_GA + G_ad_cam_loss_GA + G_ad_loss_LA + G_ad_cam_loss_LA
        ) + self.cycle_weight * G_recon_loss_A \
          + self.identity_weight * G_identity_loss_A \
          + self.cam_weight * G_cam_loss_A
        G_loss_B = self.adv_weight * (
            G_ad_loss_GB + G_ad_cam_loss_GB + G_ad_loss_LB + G_ad_cam_loss_LB
        ) + self.cycle_weight * G_recon_loss_B \
          + self.identity_weight * G_identity_loss_B \
          + self.cam_weight * G_cam_loss_B

        Generator_loss = G_loss_A + G_loss_B
        Generator_loss.backward()
        self.G_optim.minimize(Generator_loss)

        # clip parameter of AdaILN and ILN, applied after optimizer step
        clip_rho(self.genA2B)
        clip_rho(self.genB2A)

        # Log plain floats; the writer expects scalars, not dygraph variables.
        d_loss_val = float(Discriminator_loss.numpy())
        g_loss_val = float(Generator_loss.numpy())
        loss_writer.add_scalar(tag="d_loss", step=step, value=d_loss_val)
        loss_writer.add_scalar(tag="g_loss", step=step, value=g_loss_val)
        print("[%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" %
              (step, self.iteration, time.time() - start_time,
               d_loss_val, g_loss_val))

        if step % self.print_freq == 0:
            train_sample_num = 5
            test_sample_num = 5
            A2B = np.zeros((self.img_size * 7, 0, 3))
            B2A = np.zeros((self.img_size * 7, 0, 3))

            self.genA2B.eval(), self.genB2A.eval(), self.disGA.eval(), \
                self.disGB.eval(), self.disLA.eval(), self.disLB.eval()

            # Sample grids from the training set (the original code drew
            # these from the test iterators, which contradicts the name).
            for _ in range(train_sample_num):
                try:
                    real_A, _ = next(trainA_iter)[0]
                except StopIteration:
                    trainA_iter = self.trainA_loader()
                    real_A, _ = next(trainA_iter)[0]
                try:
                    real_B, _ = next(trainB_iter)[0]
                except StopIteration:
                    trainB_iter = self.trainB_loader()
                    real_B, _ = next(trainB_iter)[0]

                fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
                fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

                fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
                fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

                fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
                fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

                A2B = np.concatenate(
                    (A2B,
                     np.concatenate(
                         (RGB2BGR(tensor2numpy(denorm(real_A[0]))),
                          cam(tensor2numpy(fake_A2A_heatmap[0]), self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
                          cam(tensor2numpy(fake_A2B_heatmap[0]), self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
                          cam(tensor2numpy(fake_A2B2A_heatmap[0]), self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))), 0)), 1)
                B2A = np.concatenate(
                    (B2A,
                     np.concatenate(
                         (RGB2BGR(tensor2numpy(denorm(real_B[0]))),
                          cam(tensor2numpy(fake_B2B_heatmap[0]), self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
                          cam(tensor2numpy(fake_B2A_heatmap[0]), self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
                          cam(tensor2numpy(fake_B2A2B_heatmap[0]), self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))), 0)), 1)

            # Sample grids from the test set.
            for _ in range(test_sample_num):
                try:
                    real_A, _ = next(testA_iter)[0]
                except StopIteration:
                    testA_iter = self.testA_loader()
                    real_A, _ = next(testA_iter)[0]
                try:
                    real_B, _ = next(testB_iter)[0]
                except StopIteration:
                    testB_iter = self.testB_loader()
                    real_B, _ = next(testB_iter)[0]

                fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
                fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

                fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
                fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

                fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
                fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

                A2B = np.concatenate(
                    (A2B,
                     np.concatenate(
                         (RGB2BGR(tensor2numpy(denorm(real_A[0]))),
                          cam(tensor2numpy(fake_A2A_heatmap[0]), self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
                          cam(tensor2numpy(fake_A2B_heatmap[0]), self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
                          cam(tensor2numpy(fake_A2B2A_heatmap[0]), self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))), 0)), 1)
                B2A = np.concatenate(
                    (B2A,
                     np.concatenate(
                         (RGB2BGR(tensor2numpy(denorm(real_B[0]))),
                          cam(tensor2numpy(fake_B2B_heatmap[0]), self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
                          cam(tensor2numpy(fake_B2A_heatmap[0]), self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
                          cam(tensor2numpy(fake_B2A2B_heatmap[0]), self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))), 0)), 1)

            cv2.imwrite(
                os.path.join(self.result_dir, self.dataset, 'img',
                             'A2B_%07d.png' % step), A2B * 255.0)
            cv2.imwrite(
                os.path.join(self.result_dir, self.dataset, 'img',
                             'B2A_%07d.png' % step), B2A * 255.0)

            self.genA2B.train(), self.genB2A.train(), self.disGA.train(), \
                self.disGB.train(), self.disLA.train(), self.disLB.train()

        if step in [8000, 9000, 10000]:
            self.save(os.path.join(self.result_dir, self.dataset, 'model'),
                      step)
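`clip_rho` is called after every optimizer step above but is not defined in this excerpt. Below is a minimal sketch of what it is assumed to do, following the UGATIT constraint that the learnable blending weight rho of each ILN/AdaILN layer stays in [0, 1]; the parameter-name filter (`'rho' in name`) is an assumption about how this port registers that parameter.

import paddle.fluid.layers as layers


def clip_rho(net, vmin=0.0, vmax=1.0):
    """Clamp every (Ada)ILN rho parameter of `net` into [vmin, vmax]."""
    for name, param in net.named_parameters():
        if 'rho' in name:
            # layers.clip returns a new tensor; write it back in place.
            param.set_value(layers.clip(param, vmin, vmax))

Clipping after the optimizer step keeps rho a valid interpolation coefficient between instance and layer normalization without constraining the gradient itself.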