Example #1
0
 def build_model(self):
     """Create the generators, the discriminators and their Adam optimizers.

     Two generators (``g12``, ``g21``) and four discriminators (``d1``,
     ``d2`` and the second pair ``d1_gc``, ``d2_gc``) are built.  No separate
     gc generators are created.  Three optimizers are set up: one over both
     generators, one over ``d1``/``d2`` and one over ``d1_gc``/``d2_gc``.
     Every network is moved to the GPU when CUDA is available.
     """
     # Keep this construction order: random weight initialisation depends on it.
     self.g12 = G12(self.config, conv_dim=self.g_conv_dim)
     self.g21 = G21(self.config, conv_dim=self.g_conv_dim)
     self.d1 = D1(conv_dim=self.d_conv_dim)
     self.d2 = D2(conv_dim=self.d_conv_dim)
     self.d1_gc = D1(conv_dim=self.d_conv_dim)
     self.d2_gc = D2(conv_dim=self.d_conv_dim)

     def make_adam(*nets):
         # One Adam instance over the concatenated parameters of ``nets``;
         # each optimizer gets its own freshly built betas list.
         flat = [param for net in nets for param in net.parameters()]
         return optim.Adam(flat, self.lr, [self.beta1, self.beta2])

     self.g_optimizer = make_adam(self.g12, self.g21)
     self.d_optimizer = make_adam(self.d1, self.d2)
     self.d_gc_optimizer = make_adam(self.d1_gc, self.d2_gc)

     if torch.cuda.is_available():
         for net in (self.g12, self.g21, self.d1, self.d2,
                     self.d1_gc, self.d2_gc):
             net.cuda()
Example #2
0
    def test(self, svhn_test_loader, mnist_test_loader, checkpoint_step=40000,
             save_mnist_to_svhn=False):
        """Translate one test batch with saved generators and save the result.

        Loads the ``g12``/``g21`` checkpoints written at ``checkpoint_step``
        from ``self.model_path``.  It takes the first batch of each test
        loader and runs ``g12`` on the MNIST batch and ``g21`` on the SVHN
        batch.  It then writes an image merging the SVHN inputs with their
        ``g21`` translations to ``self.sample_path``.

        Args:
            svhn_test_loader: iterable yielding SVHN test batches.
            mnist_test_loader: iterable yielding MNIST test batches.
            checkpoint_step: step number in the checkpoint file names
                (default 40000, the value previously hard-coded).
            save_mnist_to_svhn: when True, also save the MNIST inputs merged
                with their ``g12`` translations ('sample-<index>-m-s.png').
        """
        # Position inside each loader batch that is fed to the generators; it
        # is also the number used in the output file names.
        # NOTE(review): for loaders yielding (images, labels) pairs, position 1
        # is the labels, not the images -- confirm the test loaders' layout.
        index = 1
        # Builtin next() works on any iterator; the ``.next()`` method is a
        # Python 2 leftover that recent DataLoader iterators do not provide.
        fixed_svhn = self.to_var(next(iter(svhn_test_loader))[index])
        fixed_mnist = self.to_var(next(iter(mnist_test_loader))[index])

        use_cuda = torch.cuda.is_available()
        # map_location allows GPU-saved checkpoints to load on a CPU-only host.
        map_location = None if use_cuda else 'cpu'
        g12_path = os.path.join(self.model_path, 'g12-%d.pkl' % checkpoint_step)
        g21_path = os.path.join(self.model_path, 'g21-%d.pkl' % checkpoint_step)
        self.g12 = G12(1, conv_dim=64)
        self.g21 = G21(1, conv_dim=64)
        self.g12.load_state_dict(torch.load(g12_path, map_location=map_location))
        self.g21.load_state_dict(torch.load(g21_path, map_location=map_location))
        if use_cuda:
            self.g12.cuda()
            self.g21.cuda()

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            fake_svhn = self.g12(fixed_mnist)
            fake_mnist = self.g21(fixed_svhn)

        mnist, fake_mnist = self.to_data(fixed_mnist), self.to_data(fake_mnist)
        svhn, fake_svhn = self.to_data(fixed_svhn), self.to_data(fake_svhn)

        outputs = [('s-m', svhn, fake_mnist)]
        if save_mnist_to_svhn:
            outputs.append(('m-s', mnist, fake_svhn))
        for tag, sources, translated in outputs:
            merged = self.merge_images(sources, translated)
            path = os.path.join(self.sample_path,
                                'sample-%d-%s.png' % (index, tag))
            # NOTE(review): scipy.misc.imsave was removed in SciPy 1.2; this
            # call requires an older SciPy or a replacement image writer.
            scipy.misc.imsave(path, merged)
            print('saved %s' % path)
Example #3
0
    def build_model(self):
        """Create the four generators, three discriminators and two optimizers.

        Generators ``g12``, ``g21``, ``g32`` and ``g23`` share
        ``self.g_conv_dim``.  Discriminators ``d3``, ``d2`` and ``d1`` share
        ``self.d_conv_dim`` and ``self.use_labels``.  One Adam optimizer
        covers all generator parameters, another all discriminator
        parameters.  Every network is moved to the GPU when CUDA is available.
        """
        def adam_over(nets):
            # Gather the parameters of ``nets`` in order into one list.
            params = []
            for net in nets:
                params.extend(net.parameters())
            return optim.Adam(params, self.lr, [self.beta1, self.beta2])

        self.g12 = G12(conv_dim=self.g_conv_dim)
        self.g21 = G21(conv_dim=self.g_conv_dim)
        self.g32 = G32(conv_dim=self.g_conv_dim)
        self.g23 = G23(conv_dim=self.g_conv_dim)
        self.d3 = D3(conv_dim=self.d_conv_dim, use_labels=self.use_labels)
        self.d2 = D2(conv_dim=self.d_conv_dim, use_labels=self.use_labels)
        self.d1 = D1(conv_dim=self.d_conv_dim, use_labels=self.use_labels)

        generators = (self.g12, self.g21, self.g32, self.g23)
        discriminators = (self.d3, self.d2, self.d1)
        self.g_optimizer = adam_over(generators)
        self.d_optimizer = adam_over(discriminators)

        if torch.cuda.is_available():
            for net in generators + discriminators:
                net.cuda()
Example #4
0
    def build_model(self):
        """Create both generators, both discriminators and their optimizers.

        ``g12``/``g21`` use ``self.g_conv_dim``.  ``d1``/``d2`` use
        ``self.d_conv_dim`` and ``self.use_labels``.  When CUDA is available
        all four networks are moved to the GPU and a line is written to
        ``self.log_file``.
        """
        self.g12 = G12(conv_dim=self.g_conv_dim)
        self.g21 = G21(conv_dim=self.g_conv_dim)
        self.d1 = D1(conv_dim=self.d_conv_dim, use_labels=self.use_labels)
        self.d2 = D2(conv_dim=self.d_conv_dim, use_labels=self.use_labels)

        # Each optimizer receives one concatenated parameter list.
        gen_params = [*self.g12.parameters(), *self.g21.parameters()]
        disc_params = [*self.d1.parameters(), *self.d2.parameters()]
        self.g_optimizer = optim.Adam(gen_params, self.lr,
                                      [self.beta1, self.beta2])
        self.d_optimizer = optim.Adam(disc_params, self.lr,
                                      [self.beta1, self.beta2])

        if not torch.cuda.is_available():
            return
        for net in (self.g12, self.g21, self.d1, self.d2):
            net.cuda()
        self.log_file.write("Cuda is available!\n")
        self.log_file.flush()
Example #5
0
def _render_samples(config, Gv2r, Gr2v, real_batch, virtual_batch):
    """Translate the fixed sample batches and merge each with its input.

    Returns ``(r_merged, v_merged)``:
    - ``r_merged`` pairs ``real_batch`` with ``Gr2v(real_batch)``.
    - ``v_merged`` pairs ``virtual_batch`` with ``Gv2r(virtual_batch)``.
    Both are built by ``merge_images`` from arrays scaled by 255.
    """
    # Sampling needs no gradients; skipping the autograd graph saves memory.
    with torch.no_grad():
        virtual_as_real = Gv2r(virtual_batch)
        real_as_virtual = Gr2v(real_batch)

    def to_np(tensor):
        return tensor.cpu().numpy() * 255

    r_merged = merge_images(config, to_np(real_batch), to_np(real_as_virtual))
    v_merged = merge_images(config, to_np(virtual_batch),
                            to_np(virtual_as_real))
    return r_merged, v_merged


def train(config):
    """Train the virtual<->real generator/discriminator pairs.

    Uses least-squares adversarial losses plus an L2 cycle-reconstruction
    term for both r->v->r and v->r->v cycles.  Sample images are logged to
    TensorBoard every ``config.sample_step`` steps.  Every
    ``config.save_step`` steps, checkpoints and sample PNGs are written.

    Args:
        config: object providing model_path, sample_path, real_path,
            virtual_path, batch_size, num_workers, g_conv_dim, d_conv_dim,
            seg_model_path, lr, beta1, beta2, train_epochs, D_up_step,
            sample_step and save_step.

    Data and the segmentation model are moved to CUDA unconditionally, so
    a GPU is required.
    """
    # Create output directories if they do not exist yet.
    os.makedirs(config.model_path, exist_ok=True)
    os.makedirs(config.sample_path, exist_ok=True)

    # load data
    Real_Dataset = License_Real(config.real_path, 'real_train')
    Real_Dataloader = DataLoader(Real_Dataset,
                                 batch_size=config.batch_size,
                                 shuffle=True,
                                 pin_memory=True,
                                 num_workers=config.num_workers)

    Virtual_Dataset = License_Virtual(config.virtual_path,
                                      'angle0_without_night')
    Virtual_Dataloader = DataLoader(Virtual_Dataset,
                                    batch_size=config.batch_size,
                                    shuffle=True,
                                    pin_memory=True,
                                    num_workers=config.num_workers)

    # model and optim
    Gv2r = G12(conv_dim=config.g_conv_dim)
    Gr2v = G21(conv_dim=config.g_conv_dim)
    Dv = D1(conv_dim=config.d_conv_dim)
    Dr = D2(conv_dim=config.d_conv_dim)
    Gv2r.train()
    Gr2v.train()
    Dv.train()
    Dr.train()

    # Loaded for the segmentation loss, which is currently disabled (see the
    # commented-out lines in the generator step); it is not used below.
    seg_model = BiSeNet(66, 8, 'resnet101')
    seg_model = torch.nn.DataParallel(seg_model).cuda()
    seg_model.module.load_state_dict(torch.load(config.seg_model_path), True)
    seg_model.eval()
    print('seg model loaded!')

    g_params = list(Gv2r.parameters()) + list(Gr2v.parameters())
    d_params = list(Dv.parameters()) + list(Dr.parameters())

    g_optimizer = torch.optim.Adam(g_params, config.lr,
                                   [config.beta1, config.beta2])
    d_optimizer = torch.optim.Adam(d_params, config.lr,
                                   [config.beta1, config.beta2])

    if torch.cuda.is_available():
        Gv2r.cuda()
        Gr2v.cuda()
        Dv.cuda()
        Dr.cuda()

    # Fixed batches used for every sample image.  Builtin next() replaces the
    # Python 2 style ``.next()``; the iterators are dropped afterwards so their
    # worker processes are not kept alive for the whole run.
    Real_iter = iter(Real_Dataloader)
    Virtual_iter = iter(Virtual_Dataloader)
    Real_sample_batch = next(Real_iter)
    Virtual_sample_batch, _, _ = next(Virtual_iter)
    del Real_iter, Virtual_iter

    Real_sample_batch = Real_sample_batch.cuda()
    Virtual_sample_batch = Virtual_sample_batch.cuda()

    step = 0
    # Kept for the disabled segmentation loss.
    criterion = torch.nn.CrossEntropyLoss(ignore_index=255)

    # NOTE(review): ``exp_name`` is not defined in this function; it must be a
    # module-level global, otherwise this raises NameError -- confirm.
    tb_logger = SummaryWriter('./logs/{}'.format(exp_name))
    try:
        for _epoch in range(config.train_epochs):
            for r_batch_data, (v_batch_data, v_batch_seg, _v_batch_pos) in zip(
                    Real_Dataloader, Virtual_Dataloader):
                r_batch_data = r_batch_data.cuda()
                v_batch_data = v_batch_data.cuda()
                # Only needed by the disabled segmentation loss.
                v_batch_seg = torch.squeeze(v_batch_seg.cuda(), 1)

                update_d = step % config.D_up_step == 0

                # ============ train D ============ #
                # Real images: push D outputs toward 1.
                g_optimizer.zero_grad()
                d_optimizer.zero_grad()
                dr_loss = torch.mean((Dr(r_batch_data) - 1) ** 2)
                dv_loss = torch.mean((Dv(v_batch_data) - 1) ** 2)
                d_real_loss = dr_loss + dv_loss
                if update_d:
                    d_real_loss.backward()
                    d_optimizer.step()

                # Fake images: push D outputs toward 0.  Fakes are generated
                # without a graph: only D is updated here, and any generator
                # gradients would be cleared before the G step anyway.
                g_optimizer.zero_grad()
                d_optimizer.zero_grad()
                with torch.no_grad():
                    fake_v = Gr2v(r_batch_data)
                dv_loss = torch.mean(Dv(fake_v) ** 2)
                with torch.no_grad():
                    fake_r = Gv2r(v_batch_data)
                dr_loss = torch.mean(Dr(fake_r) ** 2)
                d_fake_loss = dv_loss + dr_loss
                if update_d:
                    d_fake_loss.backward()
                    d_optimizer.step()

                # ============ train G ============ #
                # r -> v -> r cycle
                g_optimizer.zero_grad()
                d_optimizer.zero_grad()
                fake_v = Gr2v(r_batch_data)
                out = Dv(fake_v)
                reconst_r = Gv2r(fake_v)
                # seg_pred, _ = seg_model(F.interpolate(reconst_r, size=(50, 160)))
                # r_seg_loss = criterion(seg_pred, r_batch_seg.type(torch.long))
                g_loss = torch.mean((out - 1) ** 2) + torch.mean(
                    (r_batch_data - reconst_r) ** 2)  # + 1.0*r_seg_loss
                g_loss.backward()
                g_optimizer.step()

                # v -> r -> v cycle
                g_optimizer.zero_grad()
                d_optimizer.zero_grad()
                fake_r = Gv2r(v_batch_data)
                out = Dr(fake_r)
                reconst_v = Gr2v(fake_r)
                # seg_pred, _ = seg_model(F.interpolate(reconst_v, size=(50, 160), mode='bilinear'))
                v_seg_loss = 0  # criterion(seg_pred, v_batch_seg.type(torch.long))
                g_loss = torch.mean((out - 1) ** 2) + torch.mean(
                    (v_batch_data - reconst_v) ** 2)  # + 0.01*v_seg_loss
                g_loss.backward()
                g_optimizer.step()

                # Log plain floats.  dv_loss/dr_loss hold the fake-image
                # discriminator terms computed above.
                d_real_val = d_real_loss.item()
                d_fake_val = d_fake_loss.item()
                g_val = g_loss.item()
                dv_val = dv_loss.item()
                dr_val = dr_loss.item()
                tb_logger.add_scalar('d_real_loss', d_real_val, step)
                tb_logger.add_scalar('d_fake_loss', d_fake_val, step)
                tb_logger.add_scalar('g_loss', g_val, step)
                tb_logger.add_scalar('dv_loss', dv_val, step)
                tb_logger.add_scalar('dr_loss', dr_val, step)
                tb_logger.add_scalar('v_seg_loss', v_seg_loss, step)
                print(
                    'step:{}, d_real_loss:{}, d_fake_loss:{}, g_loss:{}, dv_loss:{}, dr_loss:{}, v_seg_loss:{}'
                    .format(step, d_real_val, d_fake_val, g_val, dv_val,
                            dr_val, v_seg_loss))

                # Samples rendered at this step (None when not a sample step).
                r_sample = v_sample = None
                if (step + 1) % config.sample_step == 0:
                    r_sample, v_sample = _render_samples(
                        config, Gv2r, Gr2v, Real_sample_batch,
                        Virtual_sample_batch)
                    # Convert copies so the arrays kept for saving stay intact.
                    x1 = vutils.make_grid(transform2tensor(r_sample.copy()),
                                          normalize=True,
                                          scale_each=True)
                    x2 = vutils.make_grid(transform2tensor(v_sample.copy()),
                                          normalize=True,
                                          scale_each=True)
                    tb_logger.add_image('r_Imgs', x1, step + 1)
                    tb_logger.add_image('v_Imgs', x2, step + 1)

                if (step + 1) % config.save_step == 0:
                    for net_name, net in (('Gv2r', Gv2r), ('Gr2v', Gr2v),
                                          ('Dr', Dr), ('Dv', Dv)):
                        ckpt_path = os.path.join(
                            config.model_path,
                            '{}-{}-{}.pkl'.format(net_name, exp_name, step + 1))
                        torch.save(net.state_dict(), ckpt_path)

                    if r_sample is None:
                        # save_step is not aligned with sample_step here:
                        # render fresh samples rather than failing with a
                        # NameError or saving stale images.
                        r_sample, v_sample = _render_samples(
                            config, Gv2r, Gr2v, Real_sample_batch,
                            Virtual_sample_batch)
                    tag = str(step + 1).zfill(5)
                    cv2.imwrite(
                        os.path.join(config.sample_path,
                                     'r_sample_{}.png'.format(tag)),
                        cv2.cvtColor(r_sample, cv2.COLOR_BGR2RGB))
                    cv2.imwrite(
                        os.path.join(config.sample_path,
                                     'v_sample_{}.png'.format(tag)),
                        cv2.cvtColor(v_sample, cv2.COLOR_BGR2RGB))

                step += 1
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        tb_logger.close()