Example #1
0
    def Image2Sketch_Train(self, rgb_image, sketch_vector, length_sketch,
                           step):
        """Run one optimization step of the joint photo<->sketch VAE.

        Both modalities are encoded to latent Gaussians; each latent is then
        decoded into both modalities (cross- and self-reconstruction).  The
        total loss is the sum of the four reconstruction losses plus a KL
        penalty (weight 0.01) pulling both posteriors towards N(0, I).

        Args:
            rgb_image: batch of photos (first dimension is batch size).
            sketch_vector: stroke-5 sketch tensor, layout (Seq_Len, Batch, 5)
                -- inferred from the cat/permute calls below, confirm.
            length_sketch: per-sample sketch lengths (`+ 1` below accounts
                for the prepended/appended token).
            step: global step counter; every 1000 steps a qualitative redraw
                grid is saved to ./Redraw_Photo2Sketch/.

        Returns:
            (sup_p2s_loss, sup_s2p_loss, short_p2p_loss, short_s2s_loss,
             kl_cost_1, kl_cost_2, loss)
        """
        self.train()
        self.main_optimizer.zero_grad()

        # ---- Encode both modalities into posterior distributions ----
        sketch_encoded_dist = self.Sketch_Encoder(sketch_vector, length_sketch)
        sketch_encoded_z_vector = sketch_encoded_dist.rsample()

        rgb_encoded_dist = self.Image_Encoder(rgb_image)
        rgb_encoded_dist_z_vector = rgb_encoded_dist.rsample()

        # ---- Distribution matching loss: KL(posterior || N(0, I)) ----
        prior_distribution = torch.distributions.Normal(
            torch.zeros_like(sketch_encoded_dist.mean),
            torch.ones_like(sketch_encoded_dist.stddev))
        kl_cost_1 = torch.distributions.kl_divergence(
            sketch_encoded_dist, prior_distribution).sum()
        kl_cost_2 = torch.distributions.kl_divergence(
            rgb_encoded_dist, prior_distribution).sum()

        ##############################################################
        # Cross-modal decoding
        ##############################################################
        # a) Photo -> Sketch: teacher-forced decoding conditioned on the
        #    photo latent; stroke-5 start token is (0, 0, 1, 0, 0).
        start_token = torch.stack(
            [torch.tensor([0, 0, 1, 0, 0])] *
            rgb_image.shape[0]).unsqueeze(0).float().to(device)
        batch_init = torch.cat([start_token, sketch_vector], 0)
        z_stack = torch.stack([rgb_encoded_dist_z_vector] *
                              (self.hp.max_seq_len + 1))
        inputs = torch.cat([batch_init, z_stack], 2)

        photo2sketch_output, _ = self.Sketch_Decoder(
            inputs, rgb_encoded_dist_z_vector, length_sketch + 1)

        # Target sequence ends with the stroke-5 end token (0, 0, 0, 0, 1).
        end_token = torch.stack(
            [torch.tensor([0, 0, 0, 0, 1])] *
            rgb_image.shape[0]).unsqueeze(0).to(device).float()
        batch = torch.cat([sketch_vector, end_token], 0)
        x_target = batch.permute(1, 0,
                                 2)  # -> (Batch, Seq_Len, Feature_dim)

        sup_p2s_loss = sketch_reconstruction_loss(
            photo2sketch_output, x_target)  # photo-to-sketch loss

        # b) Sketch -> Photo
        cross_recons_photo = self.Image_Decoder(sketch_encoded_z_vector)
        sup_s2p_loss = F.mse_loss(rgb_image, cross_recons_photo)

        ##############################################################
        # Self-modal decoding
        ##############################################################
        # a) Photo -> Photo
        self_recons_photo = self.Image_Decoder(rgb_encoded_dist_z_vector)
        short_p2p_loss = F.mse_loss(rgb_image, self_recons_photo)

        # b) Sketch -> Sketch
        start_token = torch.stack(
            [torch.Tensor([0, 0, 1, 0, 0])] *
            rgb_image.shape[0]).unsqueeze(0).to(device).float()
        batch_init = torch.cat([start_token, sketch_vector], 0)
        z_stack = torch.stack([sketch_encoded_z_vector] *
                              (self.hp.max_seq_len + 1))
        inputs = torch.cat([batch_init, z_stack], 2)

        sketch2sketch_output, _ = self.Sketch_Decoder(inputs,
                                                      sketch_encoded_z_vector,
                                                      length_sketch + 1)

        end_token = torch.stack(
            [torch.Tensor([0, 0, 0, 0, 1])] *
            rgb_image.shape[0]).unsqueeze(0).to(device).float()
        batch = torch.cat([sketch_vector, end_token], 0)
        x_target = batch.permute(1, 0,
                                 2)  # -> (Batch, Seq_Len, Feature_dim)

        short_s2s_loss = sketch_reconstruction_loss(
            sketch2sketch_output, x_target)  # sketch-to-sketch loss

        loss = sup_p2s_loss + sup_s2p_loss + short_p2p_loss + short_s2s_loss + 0.01 * (
            kl_cost_1 + kl_cost_2)

        loss.backward()
        # FIX: nn.utils.clip_grad_norm is deprecated (removed in recent
        # torch); use the in-place variant, as the rest of this file does.
        nn.utils.clip_grad_norm_(self.train_params, self.hp.grad_clip)
        self.main_optimizer.step()

        if step % 1000 == 0:
            # ---- Qualitative check: autoregressive greedy generation ----
            # NOTE(review): generation below could run under torch.no_grad()
            # since outputs are only rasterized -- confirm and tighten.
            # a) Photo -> Sketch generation.
            start_token = torch.Tensor([0, 0, 1, 0, 0]).view(-1, 5).to(device)
            start_token = torch.stack([start_token] *
                                      rgb_encoded_dist_z_vector.shape[0],
                                      dim=1)
            state = start_token
            hidden_cell = None

            batch_gen_strokes = []
            for i_seq in range(self.hp.max_seq_len):
                input = torch.cat(
                    [state, rgb_encoded_dist_z_vector.unsqueeze(0)], 2)
                state, hidden_cell = self.Sketch_Decoder(
                    input,
                    rgb_encoded_dist_z_vector,
                    hidden_cell=hidden_cell,
                    isTrain=False,
                    get_deterministic=True)
                batch_gen_strokes.append(state.squeeze(0))
            photo2sketch_gen = torch.stack(batch_gen_strokes, dim=1)

            # b) Sketch -> Sketch generation.
            start_token = torch.Tensor([0, 0, 1, 0, 0]).view(-1, 5).to(device)
            start_token = torch.stack([start_token] *
                                      sketch_encoded_z_vector.shape[0],
                                      dim=1)
            state = start_token
            hidden_cell = None

            batch_gen_strokes = []
            for i_seq in range(self.hp.max_seq_len):
                input = torch.cat(
                    [state, sketch_encoded_z_vector.unsqueeze(0)], 2)
                state, hidden_cell = self.Sketch_Decoder(
                    input,
                    sketch_encoded_z_vector,
                    hidden_cell=hidden_cell,
                    isTrain=False,
                    get_deterministic=True)
                batch_gen_strokes.append(state.squeeze(0))
            sketch2sketch_gen = torch.stack(batch_gen_strokes, dim=1)

            sketch_vector_gt = sketch_vector.permute(1, 0, 2)

            sketch_vector_gt_draw = batch_rasterize_relative(
                sketch_vector_gt).to(device)
            photo2sketch_gen_draw = batch_rasterize_relative(
                photo2sketch_gen).to(device)
            sketch2sketch_gen_draw = batch_rasterize_relative(
                sketch2sketch_gen).to(device)

            # Sketch rasters are inverted (1 - x) so strokes render dark on
            # a light background next to the RGB photos in the grid.
            batch_redraw = []
            for a, b, c, d, e, f in zip(sketch_vector_gt_draw, rgb_image,
                                        photo2sketch_gen_draw,
                                        sketch2sketch_gen_draw,
                                        self_recons_photo, cross_recons_photo):
                batch_redraw.append(
                    torch.cat((1. - a, b, 1. - c, 1. - d, e, f), dim=-1))

            torchvision.utils.save_image(
                torch.stack(batch_redraw),
                './Redraw_Photo2Sketch/redraw_{}.jpg'.format(step),
                nrow=6)

        return sup_p2s_loss, sup_s2p_loss, short_p2p_loss, short_s2s_loss, kl_cost_1, kl_cost_2, loss
Example #2
0
    def pretrain_SketchBranch(self, iteration=100000):
        """Pre-train the sketch VAE branch (encoder + decoder) alone.

        SketchRNN-style loop over a sketch-only dataloader: teacher-forced
        reconstruction with an annealed KL objective, exponential learning
        rate decay, periodic console/visualizer logging, and periodic
        qualitative redraws plus checkpointing of both sub-modules.

        Args:
            iteration: number of optimization steps to run.
        """
        dataloader = get_sketchOnly_dataloader(self.hp)
        self.hp.max_seq_len = self.hp.sketch_rnn_max_seq_len
        self.Sketch_Encoder.train()
        self.Sketch_Decoder.train()
        self.train_sketch_params = list(
            self.Sketch_Encoder.parameters()) + list(
                self.Sketch_Decoder.parameters())
        self.sketch_optimizer = optim.Adam(self.train_sketch_params,
                                           self.hp.learning_rate)
        self.visalizer = Visualizer()

        for step in range(iteration):

            batch, lengths = dataloader.train_batch()

            self.sketch_optimizer.zero_grad()

            # Exponential LR decay towards min_learning_rate; KL weight is
            # annealed from kl_weight_start up towards kl_weight.
            curr_learning_rate = (
                (self.hp.learning_rate - self.hp.min_learning_rate) *
                (self.hp.decay_rate)**step + self.hp.min_learning_rate)
            curr_kl_weight = (self.hp.kl_weight -
                              (self.hp.kl_weight - self.hp.kl_weight_start) *
                              (self.hp.kl_decay_rate)**step)

            post_dist = self.Sketch_Encoder(batch, lengths)

            z_vector = post_dist.rsample()
            # Teacher forcing: prepend stroke-5 start token (0, 0, 1, 0, 0)
            # and tile z along the time axis as decoder conditioning.
            # NOTE(review): start_token is sized by hp.batch_size_sketch_rnn
            # while end_token below uses batch.shape[1]; these must agree --
            # confirm the dataloader always yields full batches.
            start_token = torch.stack(
                [torch.Tensor([0, 0, 1, 0, 0])] *
                self.hp.batch_size_sketch_rnn).unsqueeze(0).to(device)
            batch_init = torch.cat([start_token, batch], 0)
            z_stack = torch.stack([z_vector] *
                                  (self.hp.sketch_rnn_max_seq_len + 1))
            inputs = torch.cat([batch_init, z_stack], 2)

            output, _ = self.Sketch_Decoder(inputs, z_vector, lengths + 1)

            # Target: original sequence plus end token (0, 0, 0, 0, 1).
            end_token = torch.stack([torch.Tensor([0, 0, 0, 0, 1])] *
                                    batch.shape[1]).unsqueeze(0).to(device)
            batch = torch.cat([batch, end_token], 0)
            x_target = batch.permute(1, 0,
                                     2)  # -> (Batch, Seq_Len, Feature_dim)

            #################### Loss Calculation ########################
            recons_loss = sketch_reconstruction_loss(output, x_target)

            prior_distribution = torch.distributions.Normal(
                torch.zeros_like(post_dist.mean),
                torch.ones_like(post_dist.stddev))
            # KL is floored at kl_tolerance (free-bits style) so the weight
            # annealing cannot push it below that threshold.
            kl_cost = torch.max(
                torch.distributions.kl_divergence(post_dist,
                                                  prior_distribution).mean(),
                torch.tensor(self.hp.kl_tolerance).to(device))
            loss = recons_loss + curr_kl_weight * kl_cost

            #################### Update Gradient #########################
            set_learninRate(self.sketch_optimizer, curr_learning_rate)
            loss.backward()
            # FIX: nn.utils.clip_grad_norm is deprecated (removed in recent
            # torch); use the in-place variant.
            nn.utils.clip_grad_norm_(self.train_sketch_params,
                                     self.hp.grad_clip)
            self.sketch_optimizer.step()

            if (step + 1) % 5 == 0:
                print('Step:{} ** KL_Loss:{} '
                      '** Recons_Loss:{} ** Total_loss:{}'.format(
                          step, kl_cost.item(), recons_loss.item(),
                          loss.item()))

                data = {}
                data['Reconstrcution_Loss'] = recons_loss
                data['KL_Loss'] = kl_cost
                data['Total Loss'] = loss
                self.visalizer.plot_scalars(data, step)

            if (step + 1) % self.hp.eval_freq_iter == 0:
                # Periodic qualitative evaluation and checkpoint.
                batch_input, batch_gen_strokes = self.sketch_generation_deterministic(
                    dataloader)

                batch_redraw = batch_rasterize_relative(batch_gen_strokes)

                if batch_input is not None:
                    batch_input_redraw = batch_rasterize_relative(batch_input)
                    batch = []
                    for a, b in zip(batch_input_redraw, batch_redraw):
                        batch.append(torch.cat((a, 1. - b), dim=-1))
                    batch = torch.stack(batch).float()
                else:
                    batch = batch_redraw.float()

                torchvision.utils.save_image(
                    batch,
                    './pretrain_sketch_Viz/deterministic/batch_rceonstruction_'
                    + str(step) + '_.jpg',
                    nrow=round(math.sqrt(len(batch))))

                torch.save(self.Sketch_Encoder.state_dict(),
                           './pretrain_models/Sketch_Encoder.pth')
                torch.save(self.Sketch_Decoder.state_dict(),
                           './pretrain_models/Sketch_Decoder.pth')

                # The generation helper may switch modules to eval();
                # restore training mode either way.
                self.Sketch_Encoder.train()
                self.Sketch_Decoder.train()
Example #3
0
    def pretrain_SketchBranch_ShoeV2(self, iteration=10000):

        self.hp.batchsize = 100
        dataloader_Train, dataloader_Test = get_dataloader(self.hp)

        self.Sketch_Encoder.train()
        self.Sketch_Decoder.train()

        self.train_sketch_params = list(
            self.Sketch_Encoder.parameters()) + list(
                self.Sketch_Decoder.parameters())
        self.sketch_optimizer = optim.Adam(self.train_sketch_params,
                                           self.hp.learning_rate)

        self.visalizer = Visualizer()

        step = 0

        for i_epoch in range(2000):

            for batch_data in dataloader_Train:

                batch = batch_data['relative_fivePoint'].to(device).permute(
                    1, 0, 2).float()  # Seq_Len, Batch, Feature
                lengths = batch_data['sketch_length'].to(
                    device) - 1  # TODO: Relative coord has one less
                step += 1
                # batch, lengths = dataloader.train_batch()

                self.sketch_optimizer.zero_grad()

                curr_learning_rate = (
                    (self.hp.learning_rate - self.hp.min_learning_rate) *
                    (self.hp.decay_rate)**step + self.hp.min_learning_rate)
                curr_kl_weight = (
                    self.hp.kl_weight -
                    (self.hp.kl_weight - self.hp.kl_weight_start) *
                    (self.hp.kl_decay_rate)**step)

                post_dist = self.Sketch_Encoder(batch, lengths)

                z_vector = post_dist.rsample()
                start_token = torch.stack(
                    [torch.Tensor([0, 0, 1, 0, 0])] *
                    batch.shape[1]).unsqueeze(0).to(device)
                batch_init = torch.cat([start_token, batch], 0)
                z_stack = torch.stack([z_vector] * (self.hp.max_seq_len + 1))
                inputs = torch.cat([batch_init, z_stack], 2)

                output, _ = self.Sketch_Decoder(inputs, z_vector, lengths + 1)

                end_token = torch.stack([torch.Tensor([0, 0, 0, 0, 1])] *
                                        batch.shape[1]).unsqueeze(0).to(device)
                batch = torch.cat([batch, end_token], 0)
                x_target = batch.permute(
                    1, 0, 2)  # batch-> Seq_Len, Batch, Feature_dim

                #################### Loss Calculation ########################################
                ##############################################################################
                recons_loss = sketch_reconstruction_loss(output, x_target)

                prior_distribution = torch.distributions.Normal(
                    torch.zeros_like(post_dist.mean),
                    torch.ones_like(post_dist.stddev))
                kl_cost = torch.max(
                    torch.distributions.kl_divergence(
                        post_dist, prior_distribution).mean(),
                    torch.tensor(self.hp.kl_tolerance).to(device))
                loss = recons_loss + curr_kl_weight * kl_cost

                #################### Update Gradient ########################################
                #############################################################################
                set_learninRate(self.sketch_optimizer, curr_learning_rate)
                loss.backward()
                nn.utils.clip_grad_norm(self.train_sketch_params,
                                        self.hp.grad_clip)
                self.sketch_optimizer.step()

                if (step + 1) % 5 == 0:
                    print('Step:{} ** KL_Loss:{} '
                          '** Recons_Loss:{} ** Total_loss:{}'.format(
                              step, kl_cost.item(), recons_loss.item(),
                              loss.item()))
                    data = {}
                    data['Reconstrcution_Loss'] = recons_loss
                    data['KL_Loss'] = kl_cost
                    data['Total Loss'] = loss
                    self.visalizer.plot_scalars(data, step)

                if (step - 1) % 1000 == 0:
                    """ Draw Sketch to Sketch """
                    start_token = torch.Tensor([0, 0, 1, 0,
                                                0]).view(-1, 5).to(device)
                    start_token = torch.stack([start_token] *
                                              z_vector.shape[0],
                                              dim=1)
                    state = start_token
                    hidden_cell = None

                    batch_gen_strokes = []
                    for i_seq in range(self.hp.average_seq_len):
                        input = torch.cat([state, z_vector.unsqueeze(0)], 2)
                        state, hidden_cell = self.Sketch_Decoder(
                            input,
                            z_vector,
                            hidden_cell=hidden_cell,
                            isTrain=False,
                            get_deterministic=True)
                        batch_gen_strokes.append(state.squeeze(0))

                    sketch2sketch_gen = torch.stack(batch_gen_strokes, dim=1)
                    sketch_vector_gt = batch.permute(1, 0, 2)

                    sketch_vector_gt_draw = batch_rasterize_relative(
                        sketch_vector_gt).to(device)
                    sketch2sketch_gen_draw = batch_rasterize_relative(
                        sketch2sketch_gen).to(device)

                    batch_redraw = []
                    for a, b in zip(sketch_vector_gt_draw,
                                    sketch2sketch_gen_draw):
                        batch_redraw.append(torch.cat((a, 1. - b), dim=-1))

                    torchvision.utils.save_image(
                        torch.stack(batch_redraw),
                        './pretrain_sketch_Viz/ShoeV2/redraw_{}.jpg'.format(
                            step),
                        nrow=8)

                    torch.save(self.Sketch_Encoder.state_dict(),
                               './pretrain_models/ShoeV2/Sketch_Encoder.pth')
                    torch.save(self.Sketch_Decoder.state_dict(),
                               './pretrain_models/ShoeV2/Sketch_Decoder.pth')

                    self.Sketch_Encoder.train()
                    self.Sketch_Decoder.train()
    def Image2Sketch_Train(self, rgb_image, sketch_vector, length_sketch, step, sketch_name):
        """Run one optimization step of the photo->sketch translation model.

        Encodes the photo to a latent Gaussian, decodes it back into a sketch
        sequence with teacher forcing, and optimizes reconstruction plus an
        annealed KL penalty.  Also saves ground-truth/generated sketches and
        an attention visualization every step.

        Args:
            rgb_image: batch of photos (first dimension is batch size).
            sketch_vector: stroke-5 sketch tensor, layout (Seq_Len, Batch, 5)
                -- inferred from the cat/permute calls, confirm.
            length_sketch: per-sample sketch lengths.
            step: global step (drives LR decay, KL annealing and logging).
            sketch_name: path-like identifier used for saved visualizations.

        Returns:
            (0, 0, 0) -- losses are printed/plotted rather than returned.
        """
        self.train()
        self.optimizer.zero_grad()

        # Exponential LR decay and KL-weight annealing.
        curr_learning_rate = ((self.hp.learning_rate - self.hp.min_learning_rate) *
                              (self.hp.decay_rate) ** step + self.hp.min_learning_rate)
        curr_kl_weight = (self.hp.kl_weight - (self.hp.kl_weight - self.hp.kl_weight_start) *
                          (self.hp.kl_decay_rate) ** step)

        # ---- Encode the photo ----
        backbone_feature, rgb_encoded_dist = self.Image_Encoder(rgb_image)
        rgb_encoded_dist_z_vector = rgb_encoded_dist.rsample()

        # ---- Distribution matching: KL floored at kl_tolerance ----
        prior_distribution = torch.distributions.Normal(torch.zeros_like(rgb_encoded_dist.mean),
                                                        torch.ones_like(rgb_encoded_dist.stddev))

        kl_cost_rgb = torch.max(torch.distributions.kl_divergence(rgb_encoded_dist, prior_distribution).mean(), torch.tensor(self.hp.kl_tolerance).to(device))

        ##############################################################
        # Cross-modal decoding: photo -> sketch (teacher forced)
        ##############################################################
        photo2sketch_output = self.Sketch_Decoder(backbone_feature, rgb_encoded_dist_z_vector, sketch_vector, length_sketch + 1)

        # Target: sketch followed by the stroke-5 end token (0, 0, 0, 0, 1).
        end_token = torch.stack([torch.tensor([0, 0, 0, 0, 1])] * rgb_image.shape[0]).unsqueeze(0).to(device).float()
        batch = torch.cat([sketch_vector, end_token], 0)
        x_target = batch.permute(1, 0, 2)  # -> (Batch, Seq_Len, Feature_dim)

        sup_p2s_loss = sketch_reconstruction_loss(photo2sketch_output, x_target)  # photo-to-sketch loss

        loss = sup_p2s_loss + curr_kl_weight*kl_cost_rgb

        set_learninRate(self.optimizer, curr_learning_rate)
        loss.backward()
        nn.utils.clip_grad_norm_(self.train_params, self.hp.grad_clip)
        self.optimizer.step()

        print('Step:{} ** sup_p2s_loss:{} ** kl_cost_rgb:{} ** Total_loss:{}'.format(step, sup_p2s_loss,
                                                                               kl_cost_rgb, loss))

        if step%5 == 0:
            data = {}
            data['Reconstrcution_Loss'] = sup_p2s_loss
            data['KL_Loss'] = kl_cost_rgb
            data['Total Loss'] = loss

            self.visualizer.plot_scalars(data, step)

        # NOTE(review): step % 1 == 0 is always true, so the visualization
        # runs on every step -- presumably a debug setting, confirm.
        if step%1 == 0:

            folder_name = os.path.join('./CVPR_SSL/' + '_'.join(sketch_name.split('/')[-1].split('_')[:-1]))
            if not os.path.exists(folder_name):
                os.makedirs(folder_name)

            sketch_vector_gt = sketch_vector.permute(1, 0, 2)

            save_sketch(sketch_vector_gt[0], sketch_name)

            with torch.no_grad():
                photo2sketch_gen, attention_plot  = \
                    self.Sketch_Decoder(backbone_feature, rgb_encoded_dist_z_vector, sketch_vector, length_sketch+1, isTrain=False)

            sketch_vector_gt = sketch_vector.permute(1, 0, 2)

            # FIX: loop variable renamed from `len`, which shadowed the
            # builtin.  Beyond each true length, force channel 4 to 1 and
            # zero channels 2:4 (end-of-sketch state in stroke-5 format).
            for num, stroke_len in enumerate(length_sketch):
                photo2sketch_gen[num, stroke_len:, 4] = 1.0
                photo2sketch_gen[num, stroke_len:, 2:4] = 0.0

            save_sketch_gen(photo2sketch_gen[0], sketch_name)

            sketch_vector_gt_draw = batch_rasterize_relative(sketch_vector_gt)
            photo2sketch_gen_draw = batch_rasterize_relative(photo2sketch_gen)

            # Called for its side effect of rendering the attention maps.
            showAttention(attention_plot, rgb_image, sketch_vector_gt_draw, photo2sketch_gen_draw, sketch_name)

        return 0, 0, 0