Example #1
    def compute_mask_loss(self, occ_mask, warped_image, tgt_image):
        """
        Compute losses on the generated occlusion mask.

        Args:
            occ_mask (tensor): Generated occlusion masks.
            warped_image (tensor): Warped image using the flow map.
            tgt_image (tensor): Target image for the warped image.
        Returns:
            (tensor): Loss for the mask.
        """
        loss_mask = dg.to_variable(np.zeros((1, )).astype("float32"))
        if occ_mask is not None:
            dummy0 = L.zeros_like(occ_mask)
            dummy1 = L.ones_like(occ_mask)

            # Compute the confidence map based on the L1 distance between the warped and GT images.
            img_diff = L.reduce_sum(L.abs(warped_image - tgt_image),
                                    1,
                                    keep_dim=True)

            conf = L.clip(1 - img_diff, 0, 1)

            # Force mask value to be small if warped image is similar to GT, and vice versa.
            loss_mask = self.criterionMasked(occ_mask, dummy0, conf)
            loss_mask += self.criterionMasked(occ_mask, dummy1, 1 - conf)

        return loss_mask
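
A minimal numpy sketch of the same confidence-map weighting (hypothetical 1x3x4x4 images, independent of Paddle) shows how pixels where the warp agrees with the target push the mask toward 0, while disagreeing pixels push it toward 1:

import numpy as np

warped = np.random.rand(1, 3, 4, 4).astype("float32")  # hypothetical warped image
target = np.random.rand(1, 3, 4, 4).astype("float32")  # hypothetical target image

# Per-pixel L1 distance summed over channels, clipped to [0, 1].
img_diff = np.abs(warped - target).sum(axis=1, keepdims=True)
conf = np.clip(1.0 - img_diff, 0.0, 1.0)
print(conf.shape)  # (1, 1, 4, 4): high conf weights the mask toward 0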
Example #2
def greedy_search_infilling(model,
                            q_ids,
                            q_sids,
                            sos_id,
                            eos_id,
                            attn_id,
                            max_encode_len=640,
                            max_decode_len=100,
                            tgt_type_id=3):
    model.eval()
    _, logits, info = model(q_ids, q_sids)
    gen_ids = L.argmax(logits, -1)
    d_batch, d_seqlen = q_ids.shape
    seqlen = L.reduce_sum(L.cast(q_ids != 0, 'int64'), 1, keep_dim=True)
    has_stopped = np.zeros([d_batch], dtype=bool)
    gen_seq_len = np.zeros([d_batch], dtype=np.int64)
    output_ids = []

    past_cache = info['caches']

    cls_ids = L.ones([d_batch], dtype='int64') * sos_id
    attn_ids = L.ones([d_batch], dtype='int64') * attn_id
    ids = L.stack([cls_ids, attn_ids], -1)
    for step in range(max_decode_len):
        bias = gen_bias(q_ids, ids, step)
        pos_ids = D.to_variable(
            np.tile(np.array([[step, step + 1]], dtype=np.int64),
                    [d_batch, 1]))
        pos_ids += seqlen
        _, logits, info = model(ids,
                                L.ones_like(ids) * tgt_type_id,
                                pos_ids=pos_ids,
                                attn_bias=bias,
                                past_cache=past_cache)
        gen_ids = L.argmax(logits, -1)

        past_cached_k, past_cached_v = past_cache
        cached_k, cached_v = info['caches']
        cached_k = [
            L.concat([pk, k[:, :1, :]], 1)
            for pk, k in zip(past_cached_k, cached_k)
        ]  # concat cached
        cached_v = [
            L.concat([pv, v[:, :1, :]], 1)
            for pv, v in zip(past_cached_v, cached_v)
        ]
        past_cache = (cached_k, cached_v)

        gen_ids = gen_ids[:, 1]
        ids = L.stack([gen_ids, attn_ids], 1)

        gen_ids = gen_ids.numpy()
        has_stopped |= (gen_ids == eos_id).astype(bool)
        gen_seq_len += (1 - has_stopped.astype(np.int64))
        output_ids.append(gen_ids.tolist())
        if has_stopped.all():
            break
    output_ids = np.array(output_ids).transpose([1, 0])
    return output_ids
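
The early-stopping bookkeeping in the loop above can be traced in isolation; a small numpy sketch with hypothetical generated ids and eos_id = 2:

import numpy as np

eos_id = 2
has_stopped = np.zeros(3, dtype=bool)
gen_seq_len = np.zeros(3, dtype=np.int64)
for gen_ids in (np.array([5, 2, 7]), np.array([2, 9, 2])):
    has_stopped |= (gen_ids == eos_id)              # latch once EOS is emitted
    gen_seq_len += (~has_stopped).astype(np.int64)  # only count live examples
print(has_stopped, gen_seq_len)  # [ True  True  True] [1 0 1]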
Example #3
    def compute_mask_losses(self, occ_mask, fake_image, warped_image,
                            tgt_label, tgt_image, fg_mask, ref_fg_mask,
                            body_mask_diff):
        """
        Compute losses on the generated occlusion masks.

        Args:
            occ_mask (tensor or list of tensors): Generated occlusion masks.
            fake_image (tensor): Generated image.
            warped_image (tensor or list of tensors): Warped images using the flow maps.
            tgt_label (tensor): Target label map.
            tgt_image (tensor): Target image for the warped image.
            fg_mask (tensor): Foreground mask for the target image.
            ref_fg_mask (tensor): Foreground mask for the reference image.
            body_mask_diff (tensor): Difference between warped body part map
                and target body part map. Used for pose dataset only.
        Returns:
            (tensor): Loss for the mask.
        """
        loss_mask = dg.to_variable(np.zeros((1, )).astype("float32"))

        if isinstance(occ_mask, list):
            # Compute occlusion mask losses for both warping reference -> target and previous -> target.
            for i in range(len(occ_mask)):
                loss_mask += self.compute_mask_loss(occ_mask[i],
                                                    warped_image[i], tgt_image)
        else:
            # Compute loss for warping either reference or previous images.
            loss_mask += self.compute_mask_loss(occ_mask, warped_image,
                                                tgt_image)

        if self.warp_ref:
            ref_occ_mask = occ_mask[0]
            dummy0 = L.zeros_like(ref_occ_mask)
            dummy1 = L.ones_like(ref_occ_mask)
            if self.for_pose_dataset:
                # Enforce output to use more warped reference image for face region.
                face_mask = L.unsqueeze(get_face_mask(tgt_label[:, 2]), [1])
                face_mask = L.pool2d(face_mask,
                                     pool_size=15,
                                     pool_type='avg',
                                     pool_stride=1,
                                     pool_padding=7)
                loss_mask += self.criterionMasked(ref_occ_mask, dummy0,
                                                  face_mask)
                loss_mask += self.criterionMasked(fake_image, warped_image[0],
                                                  face_mask)

                # Enforce output to use more hallucinated image for discrepancy
                # regions of body part masks between warped reference and target image.
                loss_mask += self.criterionMasked(ref_occ_mask, dummy1,
                                                  body_mask_diff)

            if self.has_fg:
                # Enforce output to use more hallucinated image for discrepancy regions
                # of foreground masks between reference and target image.
                fg_mask_diff = ((ref_fg_mask - fg_mask) > 0).astype("float32")
                loss_mask += self.criterionMasked(ref_occ_mask, dummy1,
                                                  fg_mask_diff)

        return loss_mask
Example #4
def gen_bias(encoder_inputs, decoder_inputs, step):
    decoder_bsz, decoder_seqlen = decoder_inputs.shape[:2]
    attn_bias = L.reshape(L.range(0, decoder_seqlen, 1, dtype='float32') + 1, [1, -1, 1])
    decoder_bias = L.cast((L.matmul(attn_bias, 1. / attn_bias, transpose_y=True) >= 1.),
                          'float32')  # [1, decoder_seqlen, decoder_seqlen]
    encoder_bias = L.unsqueeze(L.cast(L.ones_like(encoder_inputs), 'float32'), [1])  # [bsz, 1, encoder_seqlen]
    encoder_bias = L.expand(encoder_bias, [1, decoder_seqlen, 1])  # [bsz, decoder_seqlen, encoder_seqlen]
    decoder_bias = L.expand(decoder_bias, [decoder_bsz, 1, 1])  # [bsz, decoder_seqlen, decoder_seqlen]
    if step > 0:
        bias = L.concat([encoder_bias, L.ones([decoder_bsz, decoder_seqlen, step], 'float32'), decoder_bias], -1)
    else:
        bias = L.concat([encoder_bias, decoder_bias], -1)
    return bias
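
As a sanity check, the matmul trick in gen_bias is just a lower-triangular causal mask; a numpy sketch with a hypothetical decoder length of 3:

import numpy as np

steps = np.arange(1, 4, dtype="float32").reshape(-1, 1)  # positions 1..3
# (steps @ (1/steps).T) >= 1 holds exactly where row >= column.
decoder_bias = (steps @ (1.0 / steps).T >= 1.0).astype("float32")
print(decoder_bias)
# [[1. 0. 0.]
#  [1. 1. 0.]
#  [1. 1. 1.]]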
Example #5
def sag_pool(gw, feature, ratio, graph_id, dataset, name, activation=L.tanh):
    """Implementation of self-attention graph pooling (SAGPool)

    This is an implementation of the paper SELF-ATTENTION GRAPH POOLING
    (https://arxiv.org/pdf/1904.08082.pdf)

    Args:
        gw: Graph wrapper object.

        feature: A tensor with shape (num_nodes, feature_size).

        ratio: The pooling ratio of nodes we want to select.

        graph_id: The graphs that the nodes belong to. 

        dataset: Used to distinguish the FRANKENSTEIN dataset from other datasets.

        name: The name of SAGPool layer.
        
        activation: The activation function.

    Returns:
        new_feature: A tensor with shape (num_nodes, feature_size), where the
                     features of unselected nodes are masked to zero.

        ratio_length: The number of selected nodes in each graph.

    """
    if dataset == "FRANKENSTEIN":
        gcn_ = gcn
    else:
        gcn_ = norm_gcn

    score = gcn_(gw=gw,
                 feature=feature,
                 hidden_size=1,
                 activation=None,
                 norm=gw.node_feat["norm"],
                 name=name)
    score = L.squeeze(score, axes=[])
    perm, ratio_length = topk_pool(gw, score, graph_id, ratio)

    mask = L.zeros_like(score)
    mask = L.cast(mask, dtype="float32")
    updates = L.ones_like(perm)
    updates = L.cast(updates, dtype="float32")
    mask = L.scatter(mask, perm, updates)
    new_feature = L.elementwise_mul(feature, mask, axis=0)
    temp_score = activation(score)
    new_feature = L.elementwise_mul(new_feature, temp_score, axis=0)
    return new_feature, ratio_length
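
The scatter-based masking in the middle of sag_pool can be mimicked in plain numpy (hypothetical node count and perm) to confirm that unselected nodes are zeroed:

import numpy as np

num_nodes, feature_size = 5, 3
feature = np.ones((num_nodes, feature_size), dtype="float32")
perm = np.array([0, 3], dtype="int64")  # indices kept by topk_pool (hypothetical)

mask = np.zeros(num_nodes, dtype="float32")
mask[perm] = 1.0                       # same effect as L.scatter(mask, perm, updates)
new_feature = feature * mask[:, None]  # broadcast over the feature axis
print(new_feature[:, 0])               # [1. 0. 0. 1. 0.]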
Example #6
    def forward(self, x, y):
        """Forward network"""
        if self.bias_x:
            x = layers.concat((x, layers.ones_like(x[:, :, :1])), axis=-1)
        if self.bias_y:
            y = layers.concat((y, layers.ones_like(y[:, :, :1])), axis=-1)
        # x.shape=(b, m, i)
        b = x.shape[0]
        # self.weight.shape=(o, i, j)
        o = self.weight.shape[0]
        x = layers.expand(layers.unsqueeze(x, axes=[1]),
                          expand_times=(1, o, 1, 1))
        weight = layers.expand(layers.unsqueeze(self.weight, axes=[0]),
                               expand_times=(b, 1, 1, 1))
        y = layers.expand(layers.unsqueeze(y, axes=[1]),
                          expand_times=(1, o, 1, 1))

        # s.shape=(b, o, m, n), that is, [batch_size, n_out, seq_len, seq_len]
        s = layers.matmul(layers.matmul(x, weight),
                          layers.transpose(y, perm=(0, 1, 3, 2)))
        # remove dim 1 if n_out == 1
        if s.shape[1] == 1:
            s = layers.squeeze(s, axes=[1])
        return s
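
A compact numpy equivalent of this biaffine scoring (hypothetical sizes b, m, n, i, j, o) makes the contraction explicit:

import numpy as np

b, m, n, i, j, o = 2, 4, 4, 5, 5, 1
x = np.random.rand(b, m, i).astype("float32")
y = np.random.rand(b, n, j).astype("float32")
weight = np.random.rand(o, i, j).astype("float32")

# s[b, o, m, n] = sum_ij x[b, m, i] * weight[o, i, j] * y[b, n, j]
s = np.einsum("bmi,oij,bnj->bomn", x, weight, y)
print(s.shape)  # (2, 1, 4, 4)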
Example #7
    def neighbor_aggregator(self, sent_repr):
        # norm = L.clamp(L.reshape(L.cast(self.graph_wrapper.indegree(), dtype="float32"), [-1, 1]), min=1.)
        norm = L.ones_like(sent_repr)

        def send_func(src, dst, edge):
            return src["h"]

        msg = self.graph_wrapper.send(send_func, nfeat_list=[("h", norm)])
        norm = self.graph_wrapper.recv(msg, "sum")
        norm = L.reduce_mean(norm, -1, keep_dim=True)
        norm = L.clamp(norm, min=1.0)

        return gcn(self.graph_wrapper,
                   sent_repr,
                   self.hidden_size,
                   activation="relu",
                   name="gcn") / norm
Example #8
    def forward(self, *items):
        """Forward network"""
        if self.training and self.p > 0:
            masks = [
                layers.uniform_random(shape=x.shape[:2], min=0, max=1) >=
                self.p for x in items
            ]
            masks = [layers.cast(x, 'float32') for x in masks]
            total = layers.elementwise_add(*masks)
            scale = len(items) / layers.elementwise_max(
                total, layers.ones_like(total))
            masks = [mask * scale for mask in masks]
            items = [
                item * layers.unsqueeze(mask, axes=[-1])
                for item, mask in zip(items, masks)
            ]
        return items
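
The rescaling in this shared dropout keeps the expected total mask weight constant; a numpy sketch under hypothetical Bernoulli masks:

import numpy as np

p = 0.5
x_mask = (np.random.rand(2, 6) >= p).astype("float32")
y_mask = (np.random.rand(2, 6) >= p).astype("float32")

total = x_mask + y_mask
scale = 2.0 / np.maximum(total, 1.0)  # len(items) == 2 here
x_mask, y_mask = x_mask * scale, y_mask * scale
# Every position where at least one mask survived now sums to exactly 2.
print(np.unique(x_mask + y_mask))     # subset of [0. 2.]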
Example #9
        def ernie_send(src_feat, dst_feat, edge_feat):
            """doc"""
            cls = L.fill_constant_batch_size_like(src_feat["term_ids"],
                                                  [-1, 1, 1], "int64", 1)
            src_ids = L.concat([cls, src_feat["term_ids"]], 1)
            dst_ids = dst_feat["term_ids"]

            sent_ids = L.concat([L.zeros_like(src_ids),
                                 L.ones_like(dst_ids)], 1)
            term_ids = L.concat([src_ids, dst_ids], 1)

            term_ids.stop_gradient = True
            sent_ids.stop_gradient = True
            ernie = ErnieModel(term_ids,
                               sent_ids,
                               config=self.config.ernie_config)
            feature = ernie.get_pooled_output()
            return feature
Example #10
        def ernie_send(src_feat, dst_feat, edge_feat):
            """doc"""
            def build_position_ids(term_ids):
                input_mask = L.cast(term_ids > 0, "int64")
                position_ids = L.cumsum(input_mask, axis=1) - 1
                return position_ids

            # input_ids
            cls = L.fill_constant_batch_size_like(src_feat["term_ids"],
                                                  [-1, 1], "int64",
                                                  self.config.cls_id)
            src_ids = L.concat([cls, src_feat["term_ids"]], 1)
            dst_ids = dst_feat["term_ids"]

            # sent_ids
            sent_ids = L.concat([L.zeros_like(src_ids),
                                 L.ones_like(dst_ids)], 1)
            term_ids = L.concat([src_ids, dst_ids], 1)

            # position_ids
            position_ids = build_position_ids(term_ids)
            ernie_model = ErnieModel(self.config.ernie_config, "")
            feature, _ = ernie_model(term_ids, sent_ids, position_ids)
            return feature
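
build_position_ids assigns consecutive positions to non-padding tokens only; a numpy sketch with a hypothetical zero-padded row:

import numpy as np

term_ids = np.array([[11, 12, 13, 0, 0]], dtype="int64")
input_mask = (term_ids > 0).astype("int64")
position_ids = np.cumsum(input_mask, axis=1) - 1
print(position_ids)  # [[0 1 2 2 2]]: padding repeats the last real position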
Example #11
    def train(self):

        place = fluid.CUDAPlace(0) if self.use_gpu else fluid.CPUPlace()

        with fluid.dygraph.guard(place):

            self.genA2B.train()
            self.genB2A.train()
            self.disGA.train()
            self.disGB.train()
            self.disLA.train()
            self.disLB.train()

            if self.resume:

                files_list = os.listdir(self.model_path)

                if len(files_list) > 0:

                    files = []
                    print("exist files")

                    for i in files_list:

                        file_ = os.path.splitext(i)[1]
                        files.append(file_)

                    if ".pdparams" in files_list or ".pdopt" in files_list:

                        print("exist model")
                        genA2B_para, _ = fluid.load_dygraph(self.model_path +
                                                            'g_A2B')
                        genB2A_para, _ = fluid.load_dygraph(self.model_path +
                                                            'g_B2A')
                        disGA_para, _ = fluid.load_dygraph(self.model_path +
                                                           'd_GA')
                        disGB_para, _ = fluid.load_dygraph(self.model_path +
                                                           'd_GB')
                        disLA_para, _ = fluid.load_dygraph(self.model_path +
                                                           'd_LA')
                        disLB_para, _ = fluid.load_dygraph(self.model_path +
                                                           'd_LB')
                        _, G_opt = fluid.load_dygraph(self.model_path + 'G_op')
                        _, D_opt = fluid.load_dygraph(self.model_path + 'D_op')

                        self.genA2B.load_dict(genA2B_para)
                        self.genB2A.load_dict(genB2A_para)
                        self.disGA.load_dict(disGA_para)
                        self.disGB.load_dict(disGB_para)
                        self.disLA.load_dict(disLA_para)
                        self.disLB.load_dict(disLB_para)
                        self.G_optim.set_dict(G_opt)
                        self.D_optim.set_dict(D_opt)

                        print(" [*] Load SUCCESS")

                    else:

                        print(" No Model!")

                else:

                    print("No Files")

            # training loop
            print('training start !')
            start_iter = 1
            for step in range(start_iter, self.iteration + 1):

                trainA_iter = iter(self.trainA_loader())
                real_A = next(trainA_iter)
                real_A = paddle.fluid.dygraph.to_variable(np.array(real_A))
                real_A = real_A / 255.0

                trainB_iter = iter(self.trainB_loader())
                real_B = next(trainB_iter)
                real_B = paddle.fluid.dygraph.to_variable(np.array(real_B))
                real_B = real_B / 255.0

                # Update D
                self.D_optim.clear_gradients()

                fake_A2B, _, _ = self.genA2B(real_A)
                fake_B2A, _, _ = self.genB2A(real_B)

                real_GA_logit, real_GA_cam_logit, _ = self.disGA(real_A)
                real_LA_logit, real_LA_cam_logit, _ = self.disLA(real_A)
                real_GB_logit, real_GB_cam_logit, _ = self.disGB(real_B)
                real_LB_logit, real_LB_cam_logit, _ = self.disLB(real_B)

                fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
                fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
                fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
                fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

                D_ad_loss_GA = self.MSE_loss(
                    real_GA_logit,
                    fluid.dygraph.to_variable(
                        ones_like(real_GA_logit))) + self.MSE_loss(
                            fake_GA_logit,
                            fluid.dygraph.to_variable(
                                zeros_like(fake_GA_logit)))
                D_ad_cam_loss_GA = self.MSE_loss(
                    real_GA_cam_logit,
                    fluid.dygraph.to_variable(
                        ones_like(real_GA_cam_logit))) + self.MSE_loss(
                            fake_GA_cam_logit,
                            fluid.dygraph.to_variable(
                                zeros_like(fake_GA_cam_logit)))
                D_ad_loss_LA = self.MSE_loss(
                    real_LA_logit,
                    fluid.dygraph.to_variable(
                        ones_like(real_LA_logit))) + self.MSE_loss(
                            fake_LA_logit,
                            fluid.dygraph.to_variable(
                                zeros_like(fake_LA_logit)))
                D_ad_cam_loss_LA = self.MSE_loss(
                    real_LA_cam_logit,
                    fluid.dygraph.to_variable(
                        ones_like(real_LA_cam_logit))) + self.MSE_loss(
                            fake_LA_cam_logit,
                            fluid.dygraph.to_variable(
                                zeros_like(fake_LA_cam_logit)))
                D_ad_loss_GB = self.MSE_loss(
                    real_GB_logit,
                    fluid.dygraph.to_variable(
                        ones_like(real_GB_logit))) + self.MSE_loss(
                            fake_GB_logit,
                            fluid.dygraph.to_variable(
                                zeros_like(fake_GB_logit)))
                D_ad_cam_loss_GB = self.MSE_loss(
                    real_GB_cam_logit,
                    fluid.dygraph.to_variable(
                        ones_like(real_GB_cam_logit))) + self.MSE_loss(
                            fake_GB_cam_logit,
                            fluid.dygraph.to_variable(
                                zeros_like(fake_GB_cam_logit)))
                D_ad_loss_LB = self.MSE_loss(
                    real_LB_logit,
                    fluid.dygraph.to_variable(
                        ones_like(real_LB_logit))) + self.MSE_loss(
                            fake_LB_logit,
                            fluid.dygraph.to_variable(
                                zeros_like(fake_LB_logit)))
                D_ad_cam_loss_LB = self.MSE_loss(
                    real_LB_cam_logit,
                    fluid.dygraph.to_variable(
                        ones_like(real_LB_cam_logit))) + self.MSE_loss(
                            fake_LB_cam_logit,
                            fluid.dygraph.to_variable(
                                zeros_like(fake_LB_cam_logit)))

                D_loss_A = self.adv_weight * (D_ad_loss_GA + D_ad_cam_loss_GA +
                                              D_ad_loss_LA + D_ad_cam_loss_LA)
                D_loss_B = self.adv_weight * (D_ad_loss_GB + D_ad_cam_loss_GB +
                                              D_ad_loss_LB + D_ad_cam_loss_LB)

                Discriminator_loss = D_loss_A + D_loss_B
                Discriminator_loss.backward()
                self.D_optim.minimize(Discriminator_loss)

                # Update G
                self.G_optim.clear_gradients()

                fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(real_A)
                fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(real_B)

                fake_A2B2A, _, _ = self.genB2A(fake_A2B)
                fake_B2A2B, _, _ = self.genA2B(fake_B2A)

                fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A)
                fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B)

                fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
                fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
                fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
                fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

                G_ad_loss_GA = self.MSE_loss(
                    fake_GA_logit,
                    fluid.dygraph.to_variable(ones_like(fake_GA_logit)))
                G_ad_cam_loss_GA = self.MSE_loss(
                    fake_GA_cam_logit,
                    fluid.dygraph.to_variable(ones_like(fake_GA_cam_logit)))
                G_ad_loss_LA = self.MSE_loss(
                    fake_LA_logit,
                    fluid.dygraph.to_variable(ones_like(fake_LA_logit)))
                G_ad_cam_loss_LA = self.MSE_loss(
                    fake_LA_cam_logit,
                    fluid.dygraph.to_variable(ones_like(fake_LA_cam_logit)))
                G_ad_loss_GB = self.MSE_loss(
                    fake_GB_logit,
                    fluid.dygraph.to_variable(ones_like(fake_GB_logit)))
                G_ad_cam_loss_GB = self.MSE_loss(
                    fake_GB_cam_logit,
                    fluid.dygraph.to_variable(ones_like(fake_GB_cam_logit)))
                G_ad_loss_LB = self.MSE_loss(
                    fake_LB_logit,
                    fluid.dygraph.to_variable(ones_like(fake_LB_logit)))
                G_ad_cam_loss_LB = self.MSE_loss(
                    fake_LB_cam_logit,
                    fluid.dygraph.to_variable(ones_like(fake_LB_cam_logit)))

                G_recon_loss_A = self.L1_loss(fake_A2B2A, real_A)
                G_recon_loss_B = self.L1_loss(fake_B2A2B, real_B)

                G_identity_loss_A = self.L1_loss(fake_A2A, real_A)
                G_identity_loss_B = self.L1_loss(fake_B2B, real_B)

                G_cam_loss_A = self.BCE_loss(
                    fake_B2A_cam_logit,
                    fluid.dygraph.to_variable(
                        ones_like(fake_B2A_cam_logit))) + self.BCE_loss(
                            fake_A2A_cam_logit,
                            fluid.dygraph.to_variable(
                                zeros_like(fake_A2A_cam_logit)))
                G_cam_loss_B = self.BCE_loss(
                    fake_A2B_cam_logit,
                    fluid.dygraph.to_variable(
                        ones_like(fake_A2B_cam_logit))) + self.BCE_loss(
                            fake_B2B_cam_logit,
                            fluid.dygraph.to_variable(
                                zeros_like(fake_B2B_cam_logit)))

                G_loss_A = self.adv_weight * (
                    G_ad_loss_GA + G_ad_cam_loss_GA + G_ad_loss_LA +
                    G_ad_cam_loss_LA
                ) + self.cycle_weight * G_recon_loss_A + self.identity_weight * G_identity_loss_A + self.cam_weight * G_cam_loss_A
                G_loss_B = self.adv_weight * (
                    G_ad_loss_GB + G_ad_cam_loss_GB + G_ad_loss_LB +
                    G_ad_cam_loss_LB
                ) + self.cycle_weight * G_recon_loss_B + self.identity_weight * G_identity_loss_B + self.cam_weight * G_cam_loss_B

                Generator_loss = G_loss_A + G_loss_B
                Generator_loss.backward()
                self.G_optim.minimize(Generator_loss)

                # clip parameter of AdaILN and ILN, applied after optimizer step
                clip_rho(self.genA2B, vmin=0, vmax=1)
                clip_rho(self.genB2A, vmin=0, vmax=1)

                if step % 50 == 0:

                    print("[%5d/%5d] d_loss: %.8f, g_loss: %.8f" %
                          (step, self.iteration, Discriminator_loss,
                           Generator_loss))

                if step % self.print_freq == 0:

                    print("print img!")
                    train_sample_num = 5
                    test_sample_num = 5
                    A2B = np.zeros((self.img_size * 7, 0, 3))
                    B2A = np.zeros((self.img_size * 7, 0, 3))

                    self.genA2B.eval(), self.genB2A.eval(), self.disGA.eval(
                    ), self.disGB.eval(), self.disLA.eval(), self.disLB.eval()

                    for _ in range(train_sample_num):

                        trainA_iter = iter(self.trainA_loader())
                        real_A = next(trainA_iter)
                        real_A = paddle.fluid.dygraph.to_variable(
                            np.array(real_A))
                        real_A = real_A / 255.0

                        trainB_iter = iter(self.trainB_loader())
                        real_B = next(trainB_iter)
                        real_B = paddle.fluid.dygraph.to_variable(
                            np.array(real_B))
                        real_B = real_B / 255.0

                        fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
                        fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

                        fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(
                            fake_A2B)
                        fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(
                            fake_B2A)

                        fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
                        fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

                        A2B = np.concatenate(
                            (A2B,
                             np.concatenate(
                                 (RGB2BGR(tensor2numpy(denorm(real_A[0]))),
                                  cam(tensor2numpy(fake_A2A_heatmap[0]),
                                      self.img_size),
                                  RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
                                  cam(tensor2numpy(fake_A2B_heatmap[0]),
                                      self.img_size),
                                  RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
                                  cam(tensor2numpy(fake_A2B2A_heatmap[0]),
                                      self.img_size),
                                  RGB2BGR(tensor2numpy(denorm(
                                      fake_A2B2A[0])))), 0)), 1)

                        B2A = np.concatenate(
                            (B2A,
                             np.concatenate(
                                 (RGB2BGR(tensor2numpy(denorm(real_B[0]))),
                                  cam(tensor2numpy(fake_B2B_heatmap[0]),
                                      self.img_size),
                                  RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
                                  cam(tensor2numpy(fake_B2A_heatmap[0]),
                                      self.img_size),
                                  RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
                                  cam(tensor2numpy(fake_B2A2B_heatmap[0]),
                                      self.img_size),
                                  RGB2BGR(tensor2numpy(denorm(
                                      fake_B2A2B[0])))), 0)), 1)

                    for _ in range(test_sample_num):

                        testA_iter = iter(self.testA_loader())
                        real_A = next(testA_iter)
                        real_A = paddle.fluid.dygraph.to_variable(
                            np.array(real_A))
                        real_A = real_A / 255.0

                        testB_iter = iter(self.testB_loader())
                        real_B = next(testB_iter)
                        real_B = paddle.fluid.dygraph.to_variable(
                            np.array(real_B))
                        real_B = real_B / 255.0

                        fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
                        fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

                        fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(
                            fake_A2B)
                        fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(
                            fake_B2A)

                        fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
                        fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

                        A2B = np.concatenate(
                            (A2B,
                             np.concatenate(
                                 (RGB2BGR(tensor2numpy(denorm(real_A[0]))),
                                  cam(tensor2numpy(fake_A2A_heatmap[0]),
                                      self.img_size),
                                  RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
                                  cam(tensor2numpy(fake_A2B_heatmap[0]),
                                      self.img_size),
                                  RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
                                  cam(tensor2numpy(fake_A2B2A_heatmap[0]),
                                      self.img_size),
                                  RGB2BGR(tensor2numpy(denorm(
                                      fake_A2B2A[0])))), 0)), 1)

                        B2A = np.concatenate(
                            (B2A,
                             np.concatenate(
                                 (RGB2BGR(tensor2numpy(denorm(real_B[0]))),
                                  cam(tensor2numpy(fake_B2B_heatmap[0]),
                                      self.img_size),
                                  RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
                                  cam(tensor2numpy(fake_B2A_heatmap[0]),
                                      self.img_size),
                                  RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
                                  cam(tensor2numpy(fake_B2A2B_heatmap[0]),
                                      self.img_size),
                                  RGB2BGR(tensor2numpy(denorm(
                                      fake_B2A2B[0])))), 0)), 1)

                    cv2.imwrite(
                        os.path.join(self.result_dir, 'A2B_%07d.png' % step),
                        A2B * 255.0)
                    cv2.imwrite(
                        os.path.join(self.result_dir, 'B2A_%07d.png' % step),
                        B2A * 255.0)

                if step % self.save_freq == 0:

                    fluid.save_dygraph(self.genA2B.state_dict(),
                                       self.model_path + 'g_A2B')
                    fluid.save_dygraph(self.genB2A.state_dict(),
                                       self.model_path + 'g_B2A')
                    fluid.save_dygraph(self.disGA.state_dict(),
                                       self.model_path + 'd_GA')
                    fluid.save_dygraph(self.disGB.state_dict(),
                                       self.model_path + 'd_GB')
                    fluid.save_dygraph(self.disLA.state_dict(),
                                       self.model_path + 'd_LA')
                    fluid.save_dygraph(self.disLB.state_dict(),
                                       self.model_path + 'd_LB')
                    fluid.save_dygraph(self.G_optim.state_dict(),
                                       self.model_path + 'G_op')
                    fluid.save_dygraph(self.D_optim.state_dict(),
                                       self.model_path + 'D_op')
Example #12
import paddle.fluid as fluid
import numpy as np
import paddle.fluid.layers as L
def gen_data():
    return {
        "x": np.random.randint(1, 5, size=[8, 10]).astype('float32'),
        "y": np.random.randint(1, 5, size=[10]).astype('float32'),
    }
x = fluid.layers.data(name="x", shape=[8, 10], dtype='float32')
y = fluid.layers.data(name="y", shape=[10], dtype='float32')
mm = L.sqrt(L.reduce_sum(L.elementwise_mul(x, x), dim=0))
kk = L.ones_like(y)
z = fluid.layers.elementwise_div(x, mm, axis=1)
# z = x / y
place = fluid.CPUPlace()
exe = fluid.Executor(place)
z_value = exe.run(feed=gen_data(), fetch_list=[z.name])
print(z_value)
Example #13
from __future__ import print_function
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers

slot = fluid.data('slot', [-1, 1], dtype='int64', lod_level=1)
ones = layers.ones_like(slot)
float_ones = layers.cast(ones, dtype='float32')
value = layers.sequence_pool(float_ones, pool_type='sum')

feed_list = {
    'slot':
    fluid.create_lod_tensor(np.array([[0], [1], [2], [3], [4]], dtype='int64'),
                            [[3, 2]], fluid.CPUPlace())
}
fetch_list = [value]
exe = fluid.Executor(fluid.CPUPlace())
result = exe.run(fluid.default_main_program(),
                 feed=feed_list,
                 fetch_list=fetch_list)
print('sequence length:', result)
Example #14
def get_mask(seq, padding_idx=0):
    pix = layers.unsqueeze(layers.ones_like(seq) * padding_idx, axes=2)
    mask = layers.cast(layers.greater_than(layers.unsqueeze(seq, axes=2), pix), 'float32')
    return mask
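
Assuming a zero-padded int64 batch, the non-padding mask produced by get_mask behaves like this numpy equivalent:

import numpy as np

seq = np.array([[5, 2, 0, 0],
                [7, 0, 0, 0]], dtype="int64")
mask = (seq > 0).astype("float32")[:, :, None]  # padding_idx = 0
print(mask[:, :, 0])
# [[1. 1. 0. 0.]
#  [1. 0. 0. 0.]]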
Example #15
def position_id(x, r=0):
    pid = layers.arange(0, x.shape[1], dtype="int32")
    pid = layers.unsqueeze(pid, 0)
    r = layers.cast(layers.ones_like(x), dtype="int32") * r
    return layers.cast(layers.abs(layers.elementwise_sub(pid, r)), dtype='int64')
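
position_id returns each token's absolute distance from an anchor position r; a numpy sketch with a hypothetical length-5 sequence and r = 2:

import numpy as np

pid = np.arange(5)
print(np.abs(pid - 2))  # [2 1 0 1 2]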
Example #16
def beam_search_infilling(model,
                          q_ids,
                          q_sids,
                          sos_id,
                          eos_id,
                          attn_id,
                          max_encode_len=640,
                          max_decode_len=100,
                          beam_width=5,
                          tgt_type_id=3,
                          length_penalty=1.0):
    model.eval()
    _, __, info = model(q_ids, q_sids)
    d_batch, d_seqlen = q_ids.shape

    state = BeamSearchState(log_probs=L.zeros([d_batch, beam_width],
                                              'float32'),
                            lengths=L.zeros([d_batch, beam_width], 'int64'),
                            finished=L.zeros([d_batch, beam_width], 'int64'))
    outputs = []

    def reorder_(t, parent_id):
        """reorder cache according to parent beam id"""
        gather_idx = L.where(parent_id != -1)[:, 0] * beam_width + L.reshape(
            parent_id, [-1])
        t = L.gather(t, gather_idx)
        return t

    def tile_(t, times):
        _shapes = list(t.shape[1:])
        ret = L.reshape(
            L.expand(L.unsqueeze(t, [1]), [
                1,
                times,
            ] + [
                1,
            ] * len(_shapes)), [
                -1,
            ] + _shapes)
        return ret

    cached_k, cached_v = info['caches']
    cached_k = [tile_(k, beam_width) for k in cached_k]
    cached_v = [tile_(v, beam_width) for v in cached_v]
    past_cache = (cached_k, cached_v)

    q_ids = tile_(q_ids, beam_width)
    seqlen = L.reduce_sum(L.cast(q_ids != 0, 'int64'), 1, keep_dim=True)

    cls_ids = L.ones([d_batch * beam_width], dtype='int64') * sos_id
    attn_ids = L.ones([d_batch * beam_width], dtype='int64') * attn_id  # SOS
    ids = L.stack([cls_ids, attn_ids], -1)
    for step in range(max_decode_len):
        bias = gen_bias(q_ids, ids, step)
        pos_ids = D.to_variable(
            np.tile(np.array([[step, step + 1]], dtype=np.int64),
                    [d_batch * beam_width, 1]))
        pos_ids += seqlen
        _, logits, info = model(ids,
                                L.ones_like(ids) * tgt_type_id,
                                pos_ids=pos_ids,
                                attn_bias=bias,
                                past_cache=past_cache)

        output, state = beam_search_step(state,
                                         logits[:, 1],
                                         eos_id=eos_id,
                                         beam_width=beam_width,
                                         is_first_step=(step == 0),
                                         length_penalty=length_penalty)
        outputs.append(output)

        past_cached_k, past_cached_v = past_cache
        cached_k, cached_v = info['caches']
        cached_k = [
            reorder_(L.concat([pk, k[:, :1, :]], 1), output.beam_parent_ids)
            for pk, k in zip(past_cached_k, cached_k)
        ]  # concat cached
        cached_v = [
            reorder_(L.concat([pv, v[:, :1, :]], 1), output.beam_parent_ids)
            for pv, v in zip(past_cached_v, cached_v)
        ]
        past_cache = (cached_k, cached_v)

        pred_ids_flatten = L.reshape(output.predicted_ids,
                                     [d_batch * beam_width])
        ids = L.stack([pred_ids_flatten, attn_ids], 1)

        if state.finished.numpy().all():
            break

    final_ids = L.stack([o.predicted_ids for o in outputs], 0)
    final_parent_ids = L.stack([o.beam_parent_ids for o in outputs], 0)
    final_ids = L.gather_tree(final_ids, final_parent_ids)[:, :,
                                                           0]  # pick best beam
    final_ids = L.transpose(L.reshape(final_ids, [-1, d_batch * 1]), [1, 0])
    return final_ids
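
The tile_ helper above repeats every batch row beam_width times before decoding; a numpy sketch of the equivalent operation (hypothetical 2x2 input):

import numpy as np

t = np.array([[1, 2], [3, 4]])
beam_width = 3
tiled = np.repeat(t, beam_width, axis=0)  # each row repeated beam_width times
print(tiled.shape)  # (6, 2)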
Example #17
    def train(self):
        self.genA2B.train(), self.genB2A.train(), self.disGA.train(
        ), self.disGB.train(), self.disLA.train(), self.disLB.train()

        start_iter = 1
        # TODO: resuming training has not been investigated yet
        # if self.resume:
        #     # glob returns the file paths matching xxxx.pt
        #     model_list = glob(os.path.join(self.result_dir, self.dataset, 'model', '*.pt'))
        #     if not len(model_list) == 0:
        #         model_list.sort()
        #         start_iter = int(model_list[-1].split('_')[-1].split('.')[0])
        #         self.load(os.path.join(self.result_dir, self.dataset, 'model'), start_iter)
        #         print(" [*] Load SUCCESS")
        #         if self.decay_flag and start_iter > (self.iteration // 2):
        #             self.G_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2)) * (start_iter - self.iteration // 2)
        #             self.D_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2)) * (start_iter - self.iteration // 2)

        # training loop
        print('training start !')
        start_time = time.time()
        for step in range(start_iter, self.iteration + 1):
            # TODO decay
            # if self.decay_flag and step > (self.iteration // 2):
            #     self.G_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2))
            #     self.D_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2))

            try:
                real_A, _ = next(trainA_iter)
            except:
                trainA_iter = iter(self.trainA_loader)
                real_A, _ = next(trainA_iter)

            try:
                real_B, _ = next(trainB_iter)
            except:
                trainB_iter = iter(self.trainB_loader)
                real_B, _ = next(trainB_iter)

            # Update D
            self.D_optim.clear_gradients()

            fake_A2B, _, _ = self.genA2B(real_A)
            fake_B2A, _, _ = self.genB2A(real_B)

            real_GA_logit, real_GA_cam_logit, _ = self.disGA(real_A)
            real_LA_logit, real_LA_cam_logit, _ = self.disLA(real_A)
            real_GB_logit, real_GB_cam_logit, _ = self.disGB(real_B)
            real_LB_logit, real_LB_cam_logit, _ = self.disLB(real_B)

            fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
            fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
            fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
            fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

            D_ad_loss_GA = self.MSE_loss(
                real_GA_logit,
                layers.ones_like(real_GA_logit)) + self.MSE_loss(
                    fake_GA_logit, layers.zeros_like(fake_GA_logit))
            D_ad_cam_loss_GA = self.MSE_loss(
                real_GA_cam_logit,
                layers.ones_like(real_GA_cam_logit)) + self.MSE_loss(
                    fake_GA_cam_logit, layers.zeros_like(fake_GA_cam_logit))
            D_ad_loss_LA = self.MSE_loss(
                real_LA_logit,
                layers.ones_like(real_LA_logit)) + self.MSE_loss(
                    fake_LA_logit, layers.zeros_like(fake_LA_logit))
            D_ad_cam_loss_LA = self.MSE_loss(
                real_LA_cam_logit,
                layers.ones_like(real_LA_cam_logit)) + self.MSE_loss(
                    fake_LA_cam_logit, layers.zeros_like(fake_LA_cam_logit))
            D_ad_loss_GB = self.MSE_loss(
                real_GB_logit,
                layers.ones_like(real_GB_logit)) + self.MSE_loss(
                    fake_GB_logit, layers.zeros_like(fake_GB_logit))
            D_ad_cam_loss_GB = self.MSE_loss(
                real_GB_cam_logit,
                layers.ones_like(real_GB_cam_logit)) + self.MSE_loss(
                    fake_GB_cam_logit, layers.zeros_like(fake_GB_cam_logit))
            D_ad_loss_LB = self.MSE_loss(
                real_LB_logit,
                layers.ones_like(real_LB_logit)) + self.MSE_loss(
                    fake_LB_logit, layers.zeros_like(fake_LB_logit))
            D_ad_cam_loss_LB = self.MSE_loss(
                real_LB_cam_logit,
                layers.ones_like(real_LB_cam_logit)) + self.MSE_loss(
                    fake_LB_cam_logit, layers.zeros_like(fake_LB_cam_logit))

            D_loss_A = self.adv_weight * (D_ad_loss_GA + D_ad_cam_loss_GA +
                                          D_ad_loss_LA + D_ad_cam_loss_LA)
            D_loss_B = self.adv_weight * (D_ad_loss_GB + D_ad_cam_loss_GB +
                                          D_ad_loss_LB + D_ad_cam_loss_LB)

            Discriminator_loss = D_loss_A + D_loss_B
            Discriminator_loss.backward()
            self.D_optim.minimize(Discriminator_loss)

            # Update G
            self.G_optim.clear_gradients()

            fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(real_A)
            fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(real_B)

            fake_A2B2A, _, _ = self.genB2A(fake_A2B)
            fake_B2A2B, _, _ = self.genA2B(fake_B2A)

            fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A)
            fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B)

            fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
            fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
            fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
            fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

            G_ad_loss_GA = self.MSE_loss(fake_GA_logit,
                                         layers.ones_like(fake_GA_logit))
            G_ad_cam_loss_GA = self.MSE_loss(
                fake_GA_cam_logit, layers.ones_like(fake_GA_cam_logit))
            G_ad_loss_LA = self.MSE_loss(fake_LA_logit,
                                         layers.ones_like(fake_LA_logit))
            G_ad_cam_loss_LA = self.MSE_loss(
                fake_LA_cam_logit, layers.ones_like(fake_LA_cam_logit))
            G_ad_loss_GB = self.MSE_loss(fake_GB_logit,
                                         layers.ones_like(fake_GB_logit))
            G_ad_cam_loss_GB = self.MSE_loss(
                fake_GB_cam_logit, layers.ones_like(fake_GB_cam_logit))
            G_ad_loss_LB = self.MSE_loss(fake_LB_logit,
                                         layers.ones_like(fake_LB_logit))
            G_ad_cam_loss_LB = self.MSE_loss(
                fake_LB_cam_logit, layers.ones_like(fake_LB_cam_logit))

            G_recon_loss_A = self.L1_loss(fake_A2B2A, real_A)
            G_recon_loss_B = self.L1_loss(fake_B2A2B, real_B)

            G_identity_loss_A = self.L1_loss(fake_A2A, real_A)
            G_identity_loss_B = self.L1_loss(fake_B2B, real_B)

            G_cam_loss_A = self.BCE_loss(
                fake_B2A_cam_logit,
                layers.ones_like(fake_B2A_cam_logit),
            ) + self.BCE_loss(fake_A2A_cam_logit,
                              layers.zeros_like(fake_A2A_cam_logit))
            G_cam_loss_B = self.BCE_loss(
                fake_A2B_cam_logit,
                layers.ones_like(fake_A2B_cam_logit)) + self.BCE_loss(
                    fake_B2B_cam_logit, layers.zeros_like(fake_B2B_cam_logit))

            G_loss_A = self.adv_weight * (
                G_ad_loss_GA + G_ad_cam_loss_GA + G_ad_loss_LA +
                G_ad_cam_loss_LA
            ) + self.cycle_weight * G_recon_loss_A + self.identity_weight * G_identity_loss_A + self.cam_weight * G_cam_loss_A
            G_loss_B = self.adv_weight * (
                G_ad_loss_GB + G_ad_cam_loss_GB + G_ad_loss_LB +
                G_ad_cam_loss_LB
            ) + self.cycle_weight * G_recon_loss_B + self.identity_weight * G_identity_loss_B + self.cam_weight * G_cam_loss_B

            Generator_loss = G_loss_A + G_loss_B
            Generator_loss.backward()
            self.G_optim.minimize(Generator_loss)

            # clip parameter of AdaILN and ILN, applied after optimizer step
            self.genA2B.apply(self.Rho_clipper)
            self.genB2A.apply(self.Rho_clipper)

            print("[%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" %
                  (step, self.iteration, time.time() - start_time,
                   Discriminator_loss, Generator_loss))
            if step % self.print_freq == 0:
                train_sample_num = 5
                test_sample_num = 5
                A2B = np.zeros((self.img_size * 7, 0, 3))
                B2A = np.zeros((self.img_size * 7, 0, 3))

                self.genA2B.eval(), self.genB2A.eval(), self.disGA.eval(
                ), self.disGB.eval(), self.disLA.eval(), self.disLB.eval()
                for _ in range(train_sample_num):
                    try:
                        real_A, _ = next(trainA_iter)
                    except:
                        trainA_iter = iter(self.trainA_loader)
                        real_A, _ = next(trainA_iter)

                    try:
                        real_B, _ = next(trainB_iter)
                    except:
                        trainB_iter = iter(self.trainB_loader)
                        real_B, _ = next(trainB_iter)

                    fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
                    fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

                    fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
                    fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

                    fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
                    fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

                    A2B = np.concatenate(
                        (A2B,
                         np.concatenate(
                             (RGB2BGR(tensor2numpy(denorm(real_A[0]))),
                              cam(tensor2numpy(fake_A2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
                              cam(tensor2numpy(fake_A2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
                              cam(tensor2numpy(fake_A2B2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))),
                             0)), 1)

                    B2A = np.concatenate(
                        (B2A,
                         np.concatenate(
                             (RGB2BGR(tensor2numpy(denorm(real_B[0]))),
                              cam(tensor2numpy(fake_B2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
                              cam(tensor2numpy(fake_B2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
                              cam(tensor2numpy(fake_B2A2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))),
                             0)), 1)

                for _ in range(test_sample_num):
                    try:
                        real_A, _ = next(testA_iter)
                    except:
                        testA_iter = iter(self.testA_loader)
                        real_A, _ = next(testA_iter)

                    try:
                        real_B, _ = next(testB_iter)
                    except:
                        testB_iter = iter(self.testB_loader)
                        real_B, _ = next(testB_iter)

                    fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
                    fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

                    fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
                    fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

                    fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
                    fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

                    A2B = np.concatenate(
                        (A2B,
                         np.concatenate(
                             (RGB2BGR(tensor2numpy(denorm(real_A[0]))),
                              cam(tensor2numpy(fake_A2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
                              cam(tensor2numpy(fake_A2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
                              cam(tensor2numpy(fake_A2B2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))),
                             0)), 1)

                    B2A = np.concatenate(
                        (B2A,
                         np.concatenate(
                             (RGB2BGR(tensor2numpy(denorm(real_B[0]))),
                              cam(tensor2numpy(fake_B2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
                              cam(tensor2numpy(fake_B2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
                              cam(tensor2numpy(fake_B2A2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))),
                             0)), 1)

                cv2.imwrite(
                    os.path.join(self.result_dir, self.dataset, 'img',
                                 'A2B_%07d.png' % step), A2B * 255.0)
                cv2.imwrite(
                    os.path.join(self.result_dir, self.dataset, 'img',
                                 'B2A_%07d.png' % step), B2A * 255.0)
                self.genA2B.train(), self.genB2A.train(), self.disGA.train(
                ), self.disGB.train(), self.disLA.train(), self.disLB.train()

            if step % self.save_freq == 0:
                self.save(os.path.join(self.result_dir, self.dataset, 'model'),
                          step)

            if step % 1000 == 0:
                params = {}
                params['genA2B'] = self.genA2B.state_dict()
                params['genB2A'] = self.genB2A.state_dict()
                params['disGA'] = self.disGA.state_dict()
                params['disGB'] = self.disGB.state_dict()
                params['disLA'] = self.disLA.state_dict()
                params['disLB'] = self.disLB.state_dict()
                fluid.save_dygraph(
                    params,
                    os.path.join(self.result_dir,
                                 self.dataset + '_params_latest'))
Example #18
    def train(self):
        self.genA2B.train(), self.genB2A.train(), self.disGA.train(
        ), self.disGB.train(), self.disLA.train(), self.disLB.train()

        start_iter = 1
        if self.resume:
            model_list = glob(
                os.path.join(self.result_dir, self.dataset, 'model',
                             '*.pdparams'))
            if not len(model_list) == 0:
                model_list.sort()
                start_iter = int(model_list[-1].split('_')[-1].split('.')[0])
                self.load(os.path.join(self.result_dir, self.dataset, 'model'),
                          start_iter)
                print(" [*] Load SUCCESS", start_iter)
                if self.decay_flag and start_iter > (self.iteration // 2):
                    self.G_optim.set_lr(self.G_optim.current_step_lr() -
                                        (self.lr / (self.iteration // 2)) *
                                        (start_iter - self.iteration // 2))
                    self.D_optim.set_lr(self.D_optim.current_step_lr() -
                                        (self.lr / (self.iteration // 2)) *
                                        (start_iter - self.iteration // 2))

        # training loop
        print('training start !')
        start_time = time.time()
        for step in tqdm(range(start_iter, self.iteration + 1)):
            if self.decay_flag and step > (self.iteration // 2):
                self.G_optim.set_lr(self.G_optim.current_step_lr() -
                                    (self.lr / (self.iteration // 2)))
                self.D_optim.set_lr(self.D_optim.current_step_lr() -
                                    (self.lr / (self.iteration // 2)))
            d_lr = self.D_optim.current_step_lr()
            g_lr = self.G_optim.current_step_lr()

            try:
                real_A, _ = next(trainA_iter)
            except:
                trainA_iter = iter(self.trainA_loader)
                real_A, _ = next(trainA_iter)

            try:
                real_B, _ = next(trainB_iter)
            except:
                trainB_iter = iter(self.trainB_loader)
                real_B, _ = next(trainB_iter)

            # Update D
            if 1:
                self.D_optim.clear_gradients()

                fake_A2B, _, _ = self.genA2B(real_A)
                fake_B2A, _, _ = self.genB2A(real_B)

                # to 1
                real_GA_logit, real_GA_cam_logit, _ = self.disGA(real_A)
                real_LA_logit, real_LA_cam_logit, _ = self.disLA(real_A)
                real_GB_logit, real_GB_cam_logit, _ = self.disGB(real_B)
                real_LB_logit, real_LB_cam_logit, _ = self.disLB(real_B)

                # to 0
                fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
                fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
                fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
                fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

                # GA
                D_ad_loss_GA_1 = self.MSE_loss(real_GA_logit,
                                               L.ones_like(real_GA_logit))
                D_ad_loss_GA_0 = self.MSE_loss(fake_GA_logit,
                                               L.zeros_like(fake_GA_logit))
                D_ad_loss_GA = D_ad_loss_GA_1 + D_ad_loss_GA_0

                D_ad_cam_loss_GA_1 = self.MSE_loss(
                    real_GA_cam_logit, L.ones_like(real_GA_cam_logit))
                D_ad_cam_loss_GA_0 = self.MSE_loss(
                    fake_GA_cam_logit, L.zeros_like(fake_GA_cam_logit))
                D_ad_cam_loss_GA = D_ad_cam_loss_GA_1 + D_ad_cam_loss_GA_0

                # LA
                D_ad_loss_LA = self.MSE_loss(
                    real_LA_logit, L.ones_like(real_LA_logit)) + self.MSE_loss(
                        fake_LA_logit, L.zeros_like(fake_LA_logit))
                D_ad_cam_loss_LA = self.MSE_loss(
                    real_LA_cam_logit,
                    L.ones_like(real_LA_cam_logit)) + self.MSE_loss(
                        fake_LA_cam_logit, L.zeros_like(fake_LA_cam_logit))

                # GB
                D_ad_loss_GB = self.MSE_loss(
                    real_GB_logit, L.ones_like(real_GB_logit)) + self.MSE_loss(
                        fake_GB_logit, L.zeros_like(fake_GB_logit))
                D_ad_cam_loss_GB = self.MSE_loss(
                    real_GB_cam_logit,
                    L.ones_like(real_GB_cam_logit)) + self.MSE_loss(
                        fake_GB_cam_logit, L.zeros_like(fake_GB_cam_logit))

                # LB
                D_ad_loss_LB = self.MSE_loss(
                    real_LB_logit, L.ones_like(real_LB_logit)) + self.MSE_loss(
                        fake_LB_logit, L.zeros_like(fake_LB_logit))
                D_ad_cam_loss_LB = self.MSE_loss(
                    real_LB_cam_logit,
                    L.ones_like(real_LB_cam_logit)) + self.MSE_loss(
                        fake_LB_cam_logit, L.zeros_like(fake_LB_cam_logit))

                # GA and LA
                D_loss_A = self.adv_weight * (D_ad_loss_GA + D_ad_cam_loss_GA +
                                              D_ad_loss_LA + D_ad_cam_loss_LA)
                # GB and LB
                D_loss_B = self.adv_weight * (D_ad_loss_GB + D_ad_cam_loss_GB +
                                              D_ad_loss_LB + D_ad_cam_loss_LB)

                Discriminator_loss = D_loss_A + D_loss_B
                Discriminator_loss.backward()
                self.D_optim.minimize(Discriminator_loss)

            # Update G
            if True:  # always update the generator
                self.G_optim.clear_gradients()

                # re-run the generator forward passes so the G update builds its own graph
                fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(real_A)
                fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(real_B)

                # cycle
                fake_A2B2A, _, _ = self.genB2A(fake_A2B)
                fake_B2A2B, _, _ = self.genA2B(fake_B2A)

                # identity mapping: a same-domain input should pass through unchanged
                fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A)
                fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B)

                # generator objective: make the discriminators output 1 on fakes
                fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
                fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
                fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
                fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

                G_ad_loss_GA = self.MSE_loss(fake_GA_logit,
                                             L.ones_like(fake_GA_logit))
                G_ad_cam_loss_GA = self.MSE_loss(
                    fake_GA_cam_logit, L.ones_like(fake_GA_cam_logit))
                G_ad_loss_LA = self.MSE_loss(fake_LA_logit,
                                             L.ones_like(fake_LA_logit))
                G_ad_cam_loss_LA = self.MSE_loss(
                    fake_LA_cam_logit, L.ones_like(fake_LA_cam_logit))
                G_ad_loss_GB = self.MSE_loss(fake_GB_logit,
                                             L.ones_like(fake_GB_logit))
                G_ad_cam_loss_GB = self.MSE_loss(
                    fake_GB_cam_logit, L.ones_like(fake_GB_cam_logit))
                G_ad_loss_LB = self.MSE_loss(fake_LB_logit,
                                             L.ones_like(fake_LB_logit))
                G_ad_cam_loss_LB = self.MSE_loss(
                    fake_LB_cam_logit, L.ones_like(fake_LB_cam_logit))

                G_recon_loss_A = self.L1_loss(fake_A2B2A, real_A)
                G_recon_loss_B = self.L1_loss(fake_B2A2B, real_B)

                G_identity_loss_A = self.L1_loss(fake_A2A, real_A)
                G_identity_loss_B = self.L1_loss(fake_B2B, real_B)

                G_cam_loss_A = self.BCE_loss(
                    fake_B2A_cam_logit,
                    L.ones_like(fake_B2A_cam_logit)) + self.BCE_loss(
                        fake_A2A_cam_logit, L.zeros_like(fake_A2A_cam_logit))
                G_cam_loss_B = self.BCE_loss(
                    fake_A2B_cam_logit,
                    L.ones_like(fake_A2B_cam_logit)) + self.BCE_loss(
                        fake_B2B_cam_logit, L.zeros_like(fake_B2B_cam_logit))

                G_loss_A = self.adv_weight * (
                    G_ad_loss_GA + G_ad_cam_loss_GA + G_ad_loss_LA +
                    G_ad_cam_loss_LA
                ) + self.cycle_weight * G_recon_loss_A + self.identity_weight * G_identity_loss_A + self.cam_weight * G_cam_loss_A
                G_loss_B = self.adv_weight * (
                    G_ad_loss_GB + G_ad_cam_loss_GB + G_ad_loss_LB +
                    G_ad_cam_loss_LB
                ) + self.cycle_weight * G_recon_loss_B + self.identity_weight * G_identity_loss_B + self.cam_weight * G_cam_loss_B

                Generator_loss = G_loss_A + G_loss_B

                Generator_loss.backward()
                self.G_optim.minimize(Generator_loss)

            print(
                "[%5d/%5d] time: %4.4f d_lr: %.8f g_lr: %.8f d_loss: %.8f, g_loss: %.8f"
                % (step, self.iteration, time.time() - start_time, d_lr, g_lr,
                   Discriminator_loss, Generator_loss))
            if step % self.print_freq == 0:
                train_sample_num = 5
                test_sample_num = 5
                A2B = np.zeros((self.img_size * 7, 0, 3))
                B2A = np.zeros((self.img_size * 7, 0, 3))

                self.genA2B.eval(), self.genB2A.eval(), self.disGA.eval(
                ), self.disGB.eval(), self.disLA.eval(), self.disLB.eval()
                for _ in range(train_sample_num):
                    try:
                        real_A, _ = next(trainA_iter)
                    except (StopIteration, NameError):
                        trainA_iter = iter(self.trainA_loader)
                        real_A, _ = next(trainA_iter)

                    try:
                        real_B, _ = next(trainB_iter)
                    except (StopIteration, NameError):
                        trainB_iter = iter(self.trainB_loader)
                        real_B, _ = next(trainB_iter)

                    #real_A, real_B = to_variable(real_A), to_variable(real_B)

                    fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
                    fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

                    fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
                    fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

                    fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
                    fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

                    A2B = np.concatenate(
                        (A2B,
                         np.concatenate(
                             ((tensor2numpy(denorm(real_A[0]))),
                              cam(tensor2numpy(fake_A2A_heatmap[0]),
                                  self.img_size),
                              (tensor2numpy(denorm(fake_A2A[0]))),
                              cam(tensor2numpy(fake_A2B_heatmap[0]),
                                  self.img_size),
                              (tensor2numpy(denorm(fake_A2B[0]))),
                              cam(tensor2numpy(fake_A2B2A_heatmap[0]),
                                  self.img_size),
                              (tensor2numpy(denorm(fake_A2B2A[0])))), 0)), 1)

                    B2A = np.concatenate(
                        (B2A,
                         np.concatenate(
                             ((tensor2numpy(denorm(real_B[0]))),
                              cam(tensor2numpy(fake_B2B_heatmap[0]),
                                  self.img_size),
                              (tensor2numpy(denorm(fake_B2B[0]))),
                              cam(tensor2numpy(fake_B2A_heatmap[0]),
                                  self.img_size),
                              (tensor2numpy(denorm(fake_B2A[0]))),
                              cam(tensor2numpy(fake_B2A2B_heatmap[0]),
                                  self.img_size),
                              (tensor2numpy(denorm(fake_B2A2B[0])))), 0)), 1)

                for _ in range(test_sample_num):
                    try:
                        real_A, _ = next(testA_iter)
                    except (StopIteration, NameError):
                        testA_iter = iter(self.testA_loader)
                        real_A, _ = next(testA_iter)

                    try:
                        real_B, _ = next(testB_iter)
                    except (StopIteration, NameError):
                        testB_iter = iter(self.testB_loader)
                        real_B, _ = next(testB_iter)

                    #real_A, real_B = to_variable(real_A), to_variable(real_B)

                    fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
                    fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

                    fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
                    fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

                    fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
                    fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

                    A2B = np.concatenate(
                        (A2B,
                         np.concatenate(
                             ((tensor2numpy(denorm(real_A[0]))),
                              cam(tensor2numpy(fake_A2A_heatmap[0]),
                                  self.img_size),
                              (tensor2numpy(denorm(fake_A2A[0]))),
                              cam(tensor2numpy(fake_A2B_heatmap[0]),
                                  self.img_size),
                              (tensor2numpy(denorm(fake_A2B[0]))),
                              cam(tensor2numpy(fake_A2B2A_heatmap[0]),
                                  self.img_size),
                              (tensor2numpy(denorm(fake_A2B2A[0])))), 0)), 1)

                    B2A = np.concatenate(
                        (B2A,
                         np.concatenate(
                             ((tensor2numpy(denorm(real_B[0]))),
                              cam(tensor2numpy(fake_B2B_heatmap[0]),
                                  self.img_size),
                              (tensor2numpy(denorm(fake_B2B[0]))),
                              cam(tensor2numpy(fake_B2A_heatmap[0]),
                                  self.img_size),
                              (tensor2numpy(denorm(fake_B2A[0]))),
                              cam(tensor2numpy(fake_B2A2B_heatmap[0]),
                                  self.img_size),
                              (tensor2numpy(denorm(fake_B2A2B[0])))), 0)), 1)

                cv2.imwrite(
                    os.path.join(self.result_dir, self.dataset, 'img',
                                 'A2B_%07d.png' % step), A2B * 255.0)
                cv2.imwrite(
                    os.path.join(self.result_dir, self.dataset, 'img',
                                 'B2A_%07d.png' % step), B2A * 255.0)
                self.genA2B.train(), self.genB2A.train(), self.disGA.train(
                ), self.disGB.train(), self.disLA.train(), self.disLB.train()

            if step % self.save_freq == 0:
                self.save(os.path.join(self.result_dir, self.dataset, 'model'),
                          step)
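
The in-loop subtraction above realizes a linear schedule; an equivalent
closed form, useful for sanity-checking the decay-on-resume arithmetic, is
sketched below (`linear_decay_lr` is a hypothetical helper, not part of the
snippet):

    def linear_decay_lr(base_lr, step, total_iters):
        # Constant for the first half of training, then linear to zero;
        # equivalent to subtracting base_lr / (total_iters // 2) once per
        # step after the midpoint.
        half = total_iters // 2
        if step <= half:
            return base_lr
        return base_lr * max(0.0, 1.0 - float(step - half) / half)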
Example No. 19
    def train(self):
        self.genA2B.train(), self.genB2A.train(), self.disGA.train(
        ), self.disGB.train(), self.disLA.train(), self.disLB.train()

        start_iter = 1
        if self.resume:
            model_list = os.listdir(
                os.path.join(self.result_dir, self.dataset, 'model'))
            if model_list:
                # sort numerically so that e.g. '10000' ranks above '9000'
                model_list.sort(key=int)
                start_iter = int(model_list[-1])
                print("[*] load %d" % start_iter)
                self.load(os.path.join(self.result_dir, self.dataset, 'model'),
                          start_iter)
                print("[*] Load SUCCESS")

        # training loop
        print('training start!')
        start_time = time.time()
        for step in range(start_iter, self.iteration + 1):
            real_A = next(self.trainA_loader)
            real_B = next(self.trainB_loader)
            real_A = np.array([
                real_A[0].reshape(3, self.img_size, self.img_size)
            ]).astype("float32")
            real_B = np.array([
                real_B[0].reshape(3, self.img_size, self.img_size)
            ]).astype("float32")
            real_A = to_variable(real_A)
            real_B = to_variable(real_B)
            # Update D

            fake_A2B, _, _ = self.genA2B(real_A)
            fake_B2A, _, _ = self.genB2A(real_B)

            real_GA_logit, real_GA_cam_logit, _ = self.disGA(real_A)
            real_LA_logit, real_LA_cam_logit, _ = self.disLA(real_A)
            real_GB_logit, real_GB_cam_logit, _ = self.disGB(real_B)
            real_LB_logit, real_LB_cam_logit, _ = self.disLB(real_B)

            fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
            fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
            fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
            fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

            D_ad_loss_GA = self.MSE_loss(
                real_GA_logit, ones_like(real_GA_logit)) + self.MSE_loss(
                    fake_GA_logit, zeros_like(fake_GA_logit))
            D_ad_cam_loss_GA = self.MSE_loss(
                real_GA_cam_logit,
                ones_like(real_GA_cam_logit)) + self.MSE_loss(
                    fake_GA_cam_logit, zeros_like(fake_GA_cam_logit))
            D_ad_loss_LA = self.MSE_loss(
                real_LA_logit, ones_like(real_LA_logit)) + self.MSE_loss(
                    fake_LA_logit, zeros_like(fake_LA_logit))
            D_ad_cam_loss_LA = self.MSE_loss(
                real_LA_cam_logit,
                ones_like(real_LA_cam_logit)) + self.MSE_loss(
                    fake_LA_cam_logit, zeros_like(fake_LA_cam_logit))
            D_ad_loss_GB = self.MSE_loss(
                real_GB_logit, ones_like(real_GB_logit)) + self.MSE_loss(
                    fake_GB_logit, zeros_like(fake_GB_logit))
            D_ad_cam_loss_GB = self.MSE_loss(
                real_GB_cam_logit,
                ones_like(real_GB_cam_logit)) + self.MSE_loss(
                    fake_GB_cam_logit, zeros_like(fake_GB_cam_logit))
            D_ad_loss_LB = self.MSE_loss(
                real_LB_logit, ones_like(real_LB_logit)) + self.MSE_loss(
                    fake_LB_logit, zeros_like(fake_LB_logit))
            D_ad_cam_loss_LB = self.MSE_loss(
                real_LB_cam_logit,
                ones_like(real_LB_cam_logit)) + self.MSE_loss(
                    fake_LB_cam_logit, zeros_like(fake_LB_cam_logit))

            D_loss_A = self.adv_weight * (D_ad_loss_GA + D_ad_cam_loss_GA +
                                          D_ad_loss_LA + D_ad_cam_loss_LA)
            D_loss_B = self.adv_weight * (D_ad_loss_GB + D_ad_cam_loss_GB +
                                          D_ad_loss_LB + D_ad_cam_loss_LB)

            Discriminator_loss = D_loss_A + D_loss_B
            Discriminator_loss.backward()
            self.D_optim.minimize(Discriminator_loss)
            self.genB2A.clear_gradients()
            self.genA2B.clear_gradients()
            self.disGA.clear_gradients()
            self.disLA.clear_gradients()
            self.disGB.clear_gradients()
            self.disLB.clear_gradients()
            self.D_optim.clear_gradients()

            # Update G

            fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(real_A)
            fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(real_B)

            fake_A2B2A, _, _ = self.genB2A(fake_A2B)
            fake_B2A2B, _, _ = self.genA2B(fake_B2A)

            fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A)
            fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B)

            fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
            fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
            fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
            fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

            G_ad_loss_GA = self.MSE_loss(fake_GA_logit,
                                         ones_like(fake_GA_logit))
            G_ad_cam_loss_GA = self.MSE_loss(fake_GA_cam_logit,
                                             ones_like(fake_GA_cam_logit))
            G_ad_loss_LA = self.MSE_loss(fake_LA_logit,
                                         ones_like(fake_LA_logit))
            G_ad_cam_loss_LA = self.MSE_loss(fake_LA_cam_logit,
                                             ones_like(fake_LA_cam_logit))
            G_ad_loss_GB = self.MSE_loss(fake_GB_logit,
                                         ones_like(fake_GB_logit))
            G_ad_cam_loss_GB = self.MSE_loss(fake_GB_cam_logit,
                                             ones_like(fake_GB_cam_logit))
            G_ad_loss_LB = self.MSE_loss(fake_LB_logit,
                                         ones_like(fake_LB_logit))
            G_ad_cam_loss_LB = self.MSE_loss(fake_LB_cam_logit,
                                             ones_like(fake_LB_cam_logit))

            G_recon_loss_A = self.L1_loss(fake_A2B2A, real_A)
            G_recon_loss_B = self.L1_loss(fake_B2A2B, real_B)

            G_identity_loss_A = self.L1_loss(fake_A2A, real_A)
            G_identity_loss_B = self.L1_loss(fake_B2B, real_B)

            G_cam_loss_A = self.BCE_loss(
                fake_B2A_cam_logit,
                ones_like(fake_B2A_cam_logit)) + self.BCE_loss(
                    fake_A2A_cam_logit, zeros_like(fake_A2A_cam_logit))
            G_cam_loss_B = self.BCE_loss(
                fake_A2B_cam_logit,
                ones_like(fake_A2B_cam_logit)) + self.BCE_loss(
                    fake_B2B_cam_logit, zeros_like(fake_B2B_cam_logit))

            G_loss_A = self.adv_weight * (
                G_ad_loss_GA + G_ad_cam_loss_GA + G_ad_loss_LA +
                G_ad_cam_loss_LA
            ) + self.cycle_weight * G_recon_loss_A + self.identity_weight * G_identity_loss_A + self.cam_weight * G_cam_loss_A
            G_loss_B = self.adv_weight * (
                G_ad_loss_GB + G_ad_cam_loss_GB + G_ad_loss_LB +
                G_ad_cam_loss_LB
            ) + self.cycle_weight * G_recon_loss_B + self.identity_weight * G_identity_loss_B + self.cam_weight * G_cam_loss_B

            Generator_loss = G_loss_A + G_loss_B
            Generator_loss.backward()
            self.G_optim.minimize(Generator_loss)
            self.genB2A.clear_gradients()
            self.genA2B.clear_gradients()
            self.disGA.clear_gradients()
            self.disLA.clear_gradients()
            self.disGB.clear_gradients()
            self.disLB.clear_gradients()
            self.G_optim.clear_gradients()

            # keep the learnable rho of the AdaILN/ILN layers inside [0, 1]
            self.Rho_clipper(self.genA2B)
            self.Rho_clipper(self.genB2A)

            print("[%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" %
                  (step, self.iteration, time.time() - start_time,
                   Discriminator_loss, Generator_loss))

            if step % self.print_freq == 0:
                train_sample_num = 5
                test_sample_num = 5
                A2B = np.zeros((self.img_size * 7, 0, 3))
                B2A = np.zeros((self.img_size * 7, 0, 3))

                self.genA2B.eval(), self.genB2A.eval(), self.disGA.eval(
                ), self.disGB.eval(), self.disLA.eval(), self.disLB.eval()
                for _ in range(train_sample_num):
                    real_A = next(self.trainA_loader)
                    real_B = next(self.trainB_loader)
                    real_A = np.array([
                        real_A[0].reshape(3, self.img_size, self.img_size)
                    ]).astype("float32")
                    real_B = np.array([
                        real_B[0].reshape(3, self.img_size, self.img_size)
                    ]).astype("float32")
                    real_A = to_variable(real_A)
                    real_B = to_variable(real_B)

                    fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
                    fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

                    fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
                    fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

                    fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
                    fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

                    A2B = np.concatenate(
                        (A2B,
                         np.concatenate(
                             (RGB2BGR(tensor2numpy(denorm(real_A[0]))),
                              cam(tensor2numpy(fake_A2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
                              cam(tensor2numpy(fake_A2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
                              cam(tensor2numpy(fake_A2B2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))),
                             0)), 1)

                    B2A = np.concatenate(
                        (B2A,
                         np.concatenate(
                             (RGB2BGR(tensor2numpy(denorm(real_B[0]))),
                              cam(tensor2numpy(fake_B2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
                              cam(tensor2numpy(fake_B2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
                              cam(tensor2numpy(fake_B2A2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))),
                             0)), 1)

                test_iter_A = self.testA_loader()
                test_iter_B = self.testB_loader()
                for _ in range(test_sample_num):
                    real_A = next(test_iter_A)
                    real_B = next(test_iter_B)
                    real_A = np.array([
                        real_A[0].reshape(3, self.img_size, self.img_size)
                    ]).astype("float32")
                    real_B = np.array([
                        real_B[0].reshape(3, self.img_size, self.img_size)
                    ]).astype("float32")
                    real_A = to_variable(real_A)
                    real_B = to_variable(real_B)

                    fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
                    fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

                    fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
                    fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

                    fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
                    fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

                    A2B = np.concatenate(
                        (A2B,
                         np.concatenate(
                             (RGB2BGR(tensor2numpy(denorm(real_A[0]))),
                              cam(tensor2numpy(fake_A2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
                              cam(tensor2numpy(fake_A2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
                              cam(tensor2numpy(fake_A2B2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))),
                             0)), 1)

                    B2A = np.concatenate(
                        (B2A,
                         np.concatenate(
                             (RGB2BGR(tensor2numpy(denorm(real_B[0]))),
                              cam(tensor2numpy(fake_B2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
                              cam(tensor2numpy(fake_B2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
                              cam(tensor2numpy(fake_B2A2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))),
                             0)), 1)

                cv2.imwrite(
                    os.path.join(self.result_dir, self.dataset, 'img',
                                 'A2B_%07d.png' % step), A2B * 255.0)
                cv2.imwrite(
                    os.path.join(self.result_dir, self.dataset, 'img',
                                 'B2A_%07d.png' % step), B2A * 255.0)
                self.genA2B.train(), self.genB2A.train(), self.disGA.train(
                ), self.disGB.train(), self.disLA.train(), self.disLB.train()
            if step % self.save_freq == 0:
                self.save(os.path.join(self.result_dir, self.dataset, 'model'),
                          step)

            if step % 1000 == 0:
                fluid.save_dygraph(
                    self.genA2B.state_dict(),
                    os.path.join(self.result_dir,
                                 self.dataset + "/latest/new/genA2B"))
                fluid.save_dygraph(
                    self.genB2A.state_dict(),
                    os.path.join(self.result_dir,
                                 self.dataset + "/latest/new/genB2A"))
                fluid.save_dygraph(
                    self.disGA.state_dict(),
                    os.path.join(self.result_dir,
                                 self.dataset + "/latest/new/disGA"))
                fluid.save_dygraph(
                    self.disGB.state_dict(),
                    os.path.join(self.result_dir,
                                 self.dataset + "/latest/new/disGB"))
                fluid.save_dygraph(
                    self.disLA.state_dict(),
                    os.path.join(self.result_dir,
                                 self.dataset + "/latest/new/disLA"))
                fluid.save_dygraph(
                    self.disLB.state_dict(),
                    os.path.join(self.result_dir,
                                 self.dataset + "/latest/new/disLB"))
                fluid.save_dygraph(
                    self.D_optim.state_dict(),
                    os.path.join(self.result_dir,
                                 self.dataset + "/latest/new/D_optim"))
                fluid.save_dygraph(
                    self.G_optim.state_dict(),
                    os.path.join(self.result_dir,
                                 self.dataset + "/latest/new/G_optim"))
Example No. 20
    def train(self):
        loss_writer = LogWriter(logdir="./log/UGATIT/train")
        self.start_iter = 1

        if self.resume:
            self.load(os.path.join(self.result_dir, self.dataset, 'model'),
                      self.start_iter_arg, True)
            # resume from the step after the loaded checkpoint
            self.start_iter = self.start_iter_arg + 1

        # training loop
        print('training start!')
        start_time = time.time()
        for step in range(self.start_iter, self.iteration + 1):
            self.genA2B.train(), self.genB2A.train(), self.disGA.train(
            ), self.disGB.train(), self.disLA.train(), self.disLB.train()

            try:
                real_A, _ = next(trainA_iter)[0]
            except (StopIteration, NameError):
                trainA_iter = self.trainA_loader()
                real_A, _ = next(trainA_iter)[0]

            try:
                real_B, _ = next(trainB_iter)[0]
            except (StopIteration, NameError):
                trainB_iter = self.trainB_loader()
                real_B, _ = next(trainB_iter)[0]

            # Update D
            self.D_optim.clear_gradients()

            fake_A2B, _, _ = self.genA2B(real_A)
            fake_B2A, _, _ = self.genB2A(real_B)

            real_GA_logit, real_GA_cam_logit, _ = self.disGA(real_A)
            real_LA_logit, real_LA_cam_logit, _ = self.disLA(real_A)
            real_GB_logit, real_GB_cam_logit, _ = self.disGB(real_B)
            real_LB_logit, real_LB_cam_logit, _ = self.disLB(real_B)

            fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
            fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
            fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
            fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

            D_ad_loss_GA = self.MSE_loss(
                real_GA_logit,
                layers.ones_like(real_GA_logit)) + self.MSE_loss(
                    fake_GA_logit, layers.zeros_like(fake_GA_logit))
            D_ad_cam_loss_GA = self.MSE_loss(
                real_GA_cam_logit,
                layers.ones_like(real_GA_cam_logit)) + self.MSE_loss(
                    fake_GA_cam_logit, layers.zeros_like(fake_GA_cam_logit))
            D_ad_loss_LA = self.MSE_loss(
                real_LA_logit,
                layers.ones_like(real_LA_logit)) + self.MSE_loss(
                    fake_LA_logit, layers.zeros_like(fake_LA_logit))
            D_ad_cam_loss_LA = self.MSE_loss(
                real_LA_cam_logit,
                layers.ones_like(real_LA_cam_logit)) + self.MSE_loss(
                    fake_LA_cam_logit, layers.zeros_like(fake_LA_cam_logit))
            D_ad_loss_GB = self.MSE_loss(
                real_GB_logit,
                layers.ones_like(real_GB_logit)) + self.MSE_loss(
                    fake_GB_logit, layers.zeros_like(fake_GB_logit))
            D_ad_cam_loss_GB = self.MSE_loss(
                real_GB_cam_logit,
                layers.ones_like(real_GB_cam_logit)) + self.MSE_loss(
                    fake_GB_cam_logit, layers.zeros_like(fake_GB_cam_logit))
            D_ad_loss_LB = self.MSE_loss(
                real_LB_logit,
                layers.ones_like(real_LB_logit)) + self.MSE_loss(
                    fake_LB_logit, layers.zeros_like(fake_LB_logit))
            D_ad_cam_loss_LB = self.MSE_loss(
                real_LB_cam_logit,
                layers.ones_like(real_LB_cam_logit)) + self.MSE_loss(
                    fake_LB_cam_logit, layers.zeros_like(fake_LB_cam_logit))

            D_loss_A = self.adv_weight * (D_ad_loss_GA + D_ad_cam_loss_GA +
                                          D_ad_loss_LA + D_ad_cam_loss_LA)
            D_loss_B = self.adv_weight * (D_ad_loss_GB + D_ad_cam_loss_GB +
                                          D_ad_loss_LB + D_ad_cam_loss_LB)

            Discriminator_loss = D_loss_A + D_loss_B
            Discriminator_loss.backward()
            self.D_optim.minimize(Discriminator_loss)

            # Update G
            self.G_optim.clear_gradients()

            fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(real_A)
            fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(real_B)

            fake_A2B2A, _, _ = self.genB2A(fake_A2B)
            fake_B2A2B, _, _ = self.genA2B(fake_B2A)

            fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A)
            fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B)

            fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
            fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
            fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
            fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

            G_ad_loss_GA = self.MSE_loss(fake_GA_logit,
                                         layers.ones_like(fake_GA_logit))
            G_ad_cam_loss_GA = self.MSE_loss(
                fake_GA_cam_logit, layers.ones_like(fake_GA_cam_logit))
            G_ad_loss_LA = self.MSE_loss(fake_LA_logit,
                                         layers.ones_like(fake_LA_logit))
            G_ad_cam_loss_LA = self.MSE_loss(
                fake_LA_cam_logit, layers.ones_like(fake_LA_cam_logit))
            G_ad_loss_GB = self.MSE_loss(fake_GB_logit,
                                         layers.ones_like(fake_GB_logit))
            G_ad_cam_loss_GB = self.MSE_loss(
                fake_GB_cam_logit, layers.ones_like(fake_GB_cam_logit))
            G_ad_loss_LB = self.MSE_loss(fake_LB_logit,
                                         layers.ones_like(fake_LB_logit))
            G_ad_cam_loss_LB = self.MSE_loss(
                fake_LB_cam_logit, layers.ones_like(fake_LB_cam_logit))

            G_recon_loss_A = self.L1_loss(fake_A2B2A, real_A)
            G_recon_loss_B = self.L1_loss(fake_B2A2B, real_B)

            G_identity_loss_A = self.L1_loss(fake_A2A, real_A)
            G_identity_loss_B = self.L1_loss(fake_B2B, real_B)

            G_cam_loss_A = self.BCELoss(
                fake_B2A_cam_logit,
                layers.ones_like(fake_B2A_cam_logit)) + self.BCELoss(
                    fake_A2A_cam_logit, layers.zeros_like(fake_A2A_cam_logit))
            G_cam_loss_B = self.BCELoss(
                fake_A2B_cam_logit,
                layers.ones_like(fake_A2B_cam_logit)) + self.BCELoss(
                    fake_B2B_cam_logit, layers.zeros_like(fake_B2B_cam_logit))

            G_loss_A = self.adv_weight * (
                G_ad_loss_GA + G_ad_cam_loss_GA + G_ad_loss_LA +
                G_ad_cam_loss_LA
            ) + self.cycle_weight * G_recon_loss_A + self.identity_weight * G_identity_loss_A + self.cam_weight * G_cam_loss_A
            G_loss_B = self.adv_weight * (
                G_ad_loss_GB + G_ad_cam_loss_GB + G_ad_loss_LB +
                G_ad_cam_loss_LB
            ) + self.cycle_weight * G_recon_loss_B + self.identity_weight * G_identity_loss_B + self.cam_weight * G_cam_loss_B

            Generator_loss = G_loss_A + G_loss_B
            Generator_loss.backward()
            self.G_optim.minimize(Generator_loss)

            # clip the rho parameters of AdaILN and ILN to [0, 1] after the optimizer step
            clip_rho(self.genA2B)
            clip_rho(self.genB2A)

            loss_writer.add_scalar(tag="d_loss",
                                   step=step,
                                   value=float(Discriminator_loss.numpy()))
            loss_writer.add_scalar(tag="g_loss",
                                   step=step,
                                   value=float(Generator_loss.numpy()))
            print("[%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" %
                  (step, self.iteration, time.time() - start_time,
                   Discriminator_loss, Generator_loss))
            if step % self.print_freq == 0:
                train_sample_num = 5
                test_sample_num = 5
                A2B = np.zeros((self.img_size * 7, 0, 3))
                B2A = np.zeros((self.img_size * 7, 0, 3))

                self.genA2B.eval(), self.genB2A.eval(), self.disGA.eval(
                ), self.disGB.eval(), self.disLA.eval(), self.disLB.eval()
                for _ in range(train_sample_num):
                    try:
                        real_A, _ = next(trainA_iter)[0]
                    except (StopIteration, NameError):
                        trainA_iter = self.trainA_loader()
                        real_A, _ = next(trainA_iter)[0]

                    try:
                        real_B, _ = next(trainB_iter)[0]
                    except (StopIteration, NameError):
                        trainB_iter = self.trainB_loader()
                        real_B, _ = next(trainB_iter)[0]

                    fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
                    fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

                    fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
                    fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

                    fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
                    fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

                    A2B = np.concatenate(
                        (A2B,
                         np.concatenate(
                             (RGB2BGR(tensor2numpy(denorm(real_A[0]))),
                              cam(tensor2numpy(fake_A2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
                              cam(tensor2numpy(fake_A2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
                              cam(tensor2numpy(fake_A2B2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))),
                             0)), 1)

                    B2A = np.concatenate(
                        (B2A,
                         np.concatenate(
                             (RGB2BGR(tensor2numpy(denorm(real_B[0]))),
                              cam(tensor2numpy(fake_B2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
                              cam(tensor2numpy(fake_B2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
                              cam(tensor2numpy(fake_B2A2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))),
                             0)), 1)

                for _ in range(test_sample_num):
                    try:
                        real_A, _ = next(testA_iter)[0]
                    except (StopIteration, NameError):
                        testA_iter = self.testA_loader()
                        real_A, _ = next(testA_iter)[0]

                    try:
                        real_B, _ = next(testB_iter)[0]
                    except (StopIteration, NameError):
                        testB_iter = self.testB_loader()
                        real_B, _ = next(testB_iter)[0]

                    fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
                    fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

                    fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
                    fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

                    fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
                    fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

                    A2B = np.concatenate(
                        (A2B,
                         np.concatenate(
                             (RGB2BGR(tensor2numpy(denorm(real_A[0]))),
                              cam(tensor2numpy(fake_A2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
                              cam(tensor2numpy(fake_A2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
                              cam(tensor2numpy(fake_A2B2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))),
                             0)), 1)

                    B2A = np.concatenate(
                        (B2A,
                         np.concatenate(
                             (RGB2BGR(tensor2numpy(denorm(real_B[0]))),
                              cam(tensor2numpy(fake_B2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
                              cam(tensor2numpy(fake_B2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
                              cam(tensor2numpy(fake_B2A2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))),
                             0)), 1)

                cv2.imwrite(
                    os.path.join(self.result_dir, self.dataset, 'img',
                                 'A2B_%07d.png' % step), A2B * 255.0)
                cv2.imwrite(
                    os.path.join(self.result_dir, self.dataset, 'img',
                                 'B2A_%07d.png' % step), B2A * 255.0)
                self.genA2B.train(), self.genB2A.train(), self.disGA.train(
                ), self.disGB.train(), self.disLA.train(), self.disLB.train()

            if step in [8000, 9000, 10000]:
                self.save(os.path.join(self.result_dir, self.dataset, 'model'),
                          step)
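
The `clip_rho` helper is not shown in this example; a minimal sketch of what
such a clipper typically does in UGATIT ports, assuming Paddle dygraph layers
whose learnable blending weights are named with 'rho' (the implementation
below is illustrative, not the snippet's own):

    def clip_rho(net, vmin=0.0, vmax=1.0):
        # Clamp every parameter whose name contains 'rho' into [vmin, vmax],
        # keeping the AdaILN/ILN blending weights in a valid range.
        for name, param in net.named_parameters():
            if 'rho' in name:
                param.set_value(np.clip(param.numpy(), vmin, vmax))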