示例#1
0
 def build_model(self):
     """Build the generators, discriminators and their Adam optimizers.

     Creates generators ``g12``/``g21`` (both constructed with
     ``self.config``), discriminators ``d1``/``d2`` and a second
     discriminator pair ``d1_gc``/``d2_gc``. Moves everything to the GPU
     when CUDA is available, and only then builds:

     * ``g_optimizer``    over the parameters of ``g12`` + ``g21``
     * ``d_optimizer``    over the parameters of ``d1`` + ``d2``
     * ``d_gc_optimizer`` over the parameters of ``d1_gc`` + ``d2_gc``

     Each optimizer uses learning rate ``self.lr`` and betas
     ``(self.beta1, self.beta2)``.
     """
     self.g12 = G12(self.config, conv_dim=self.g_conv_dim)
     self.g21 = G21(self.config, conv_dim=self.g_conv_dim)
     self.d1 = D1(conv_dim=self.d_conv_dim)
     self.d2 = D2(conv_dim=self.d_conv_dim)
     self.d1_gc = D1(conv_dim=self.d_conv_dim)
     self.d2_gc = D2(conv_dim=self.d_conv_dim)

     # Move the networks to the GPU *before* constructing the optimizers, as
     # the torch.optim documentation recommends, so the optimizers are sure
     # to reference the parameters that are actually trained.
     if torch.cuda.is_available():
         for net in (self.g12, self.g21, self.d1, self.d2,
                     self.d1_gc, self.d2_gc):
             net.cuda()

     g_params = list(self.g12.parameters()) + list(self.g21.parameters())
     d_params = list(self.d1.parameters()) + list(self.d2.parameters())
     d_gc_params = (list(self.d1_gc.parameters())
                    + list(self.d2_gc.parameters()))

     self.g_optimizer = optim.Adam(g_params, self.lr, [self.beta1, self.beta2])
     self.d_optimizer = optim.Adam(d_params, self.lr, [self.beta1, self.beta2])
     self.d_gc_optimizer = optim.Adam(d_gc_params, self.lr,
                                      [self.beta1, self.beta2])
    def build_model(self):
        """Build ``d2`` and ``net`` from checkpoints, plus their optimizers.

        ``d2`` is a ``D2`` discriminator (``use_labels=False``) restored from
        ``self.d2_load_path``. ``net`` is a ``skip`` network with one input
        and one output channel, restored from ``self.net_load_path``.
        Checkpoints are deserialized onto the CPU so they also load on
        machines without CUDA. The models are moved to the GPU afterwards
        when one is available.

        Attributes set:

        * ``d_optimizer``        -- Adam over ``d2``, lr ``self.lr_d``,
          betas ``(self.beta1, self.beta2)``
        * ``net_optimizer``      -- Adam over all of ``net``, lr
          ``self.lr_net``, Adam's default betas
        * ``unshared_optimizer`` -- Adam over ``net.unshared_parameters()``,
          lr ``self.lr_net``, betas ``(self.beta1, self.beta2)``
        * ``mse``                -- a ``torch.nn.MSELoss`` instance
        """
        self.d2 = D2(conv_dim=self.d_conv_dim, use_labels=False)
        # map_location='cpu': checkpoints saved from a GPU run would otherwise
        # fail to load when CUDA is unavailable; load_state_dict copies the
        # values into the model's own tensors anyway.
        self.d2.load_state_dict(
            torch.load(self.d2_load_path, map_location='cpu'))

        self.net = skip(num_input_channels=1,
                        num_output_channels=1,
                        num_channels_down=[64, 128],
                        num_channels_up=[64, 128],
                        num_channels_skip=[0, 0],
                        upsample_mode='bilinear',
                        need_sigmoid=True,
                        need_bias=True,
                        pad='reflection',
                        act_fun='LeakyReLU')
        self.net.load_state_dict(
            torch.load(self.net_load_path, map_location='cpu'))

        self.mse = torch.nn.MSELoss()

        # Move to the GPU before constructing the optimizers (recommended by
        # the torch.optim documentation).
        if torch.cuda.is_available():
            self.d2.cuda()
            self.net.cuda()
            self.mse.cuda()

        self.d_optimizer = optim.Adam(list(self.d2.parameters()), self.lr_d,
                                      [self.beta1, self.beta2])
        # NOTE(review): unlike the other optimizers here this one uses Adam's
        # default betas -- confirm that is intentional.
        self.net_optimizer = optim.Adam(list(self.net.parameters()),
                                        lr=self.lr_net)
        # NOTE(review): unshared_parameters() presumably returns a subset of
        # net.parameters(), i.e. parameters also owned by net_optimizer --
        # confirm the two optimizers are not stepped on the same gradients.
        self.unshared_optimizer = optim.Adam(
            list(self.net.unshared_parameters()), self.lr_net,
            [self.beta1, self.beta2])
示例#3
0
    def build_model(self):
        """Build four generators, three discriminators and two optimizers.

        Creates generators ``g12``, ``g21``, ``g32`` and ``g23``. Creates
        discriminators ``d3``, ``d2`` and ``d1``, each built with
        ``use_labels=self.use_labels``. Moves them to the GPU when CUDA is
        available, then creates two Adam optimizers with lr ``self.lr`` and
        betas ``(self.beta1, self.beta2)``:

        * ``g_optimizer`` over all four generators
        * ``d_optimizer`` over all three discriminators
        """
        self.g12 = G12(conv_dim=self.g_conv_dim)
        self.g21 = G21(conv_dim=self.g_conv_dim)
        self.g32 = G32(conv_dim=self.g_conv_dim)
        self.g23 = G23(conv_dim=self.g_conv_dim)
        self.d3 = D3(conv_dim=self.d_conv_dim, use_labels=self.use_labels)
        self.d2 = D2(conv_dim=self.d_conv_dim, use_labels=self.use_labels)
        self.d1 = D1(conv_dim=self.d_conv_dim, use_labels=self.use_labels)

        generators = (self.g12, self.g21, self.g32, self.g23)
        discriminators = (self.d3, self.d2, self.d1)

        # Move to the GPU before constructing the optimizers (recommended by
        # the torch.optim documentation).
        if torch.cuda.is_available():
            for net in generators + discriminators:
                net.cuda()

        # Same parameter order as concatenating each network's list in turn.
        g_params = [p for net in generators for p in net.parameters()]
        d_params = [p for net in discriminators for p in net.parameters()]

        self.g_optimizer = optim.Adam(g_params, self.lr,
                                      [self.beta1, self.beta2])
        self.d_optimizer = optim.Adam(d_params, self.lr,
                                      [self.beta1, self.beta2])
    def build_model(self):
        """Build the ``skip`` generator ``g11``, discriminator ``d2`` and optimizers.

        ``g11`` is a ``skip`` network with one input and one output channel.
        ``d2`` is a ``D2`` discriminator with ``use_labels=False``. When
        ``self.config.continue_training`` is set, both are restored from
        ``self.g11_load_path`` / ``self.d2_load_path``. These loads are
        deserialized on the CPU so they also work without CUDA. The models
        are moved to the GPU when available, then ``g_optimizer`` and
        ``d_optimizer`` (Adam, lr ``self.lr``, betas
        ``(self.beta1, self.beta2)``) are created.
        """
        self.g11 = skip(
            num_input_channels=1, num_output_channels=1,
            num_channels_down=[64, 128],
            num_channels_up=[64, 128],
            num_channels_skip=[0, 0],
            upsample_mode='bilinear',
            need_sigmoid=True, need_bias=True, pad='reflection', act_fun='LeakyReLU')

        self.d2 = D2(conv_dim=self.d_conv_dim, use_labels=False)

        if self.config.continue_training:
            # map_location='cpu' lets GPU-saved checkpoints load without CUDA;
            # load_state_dict copies the values into the model's own tensors.
            self.d2.load_state_dict(
                torch.load(self.d2_load_path, map_location='cpu'))
            self.g11.load_state_dict(
                torch.load(self.g11_load_path, map_location='cpu'))

        # Move to the GPU before constructing the optimizers (recommended by
        # the torch.optim documentation).
        if torch.cuda.is_available():
            self.g11.cuda()
            self.d2.cuda()

        g_params = list(self.g11.parameters())
        d_params = list(self.d2.parameters())

        self.g_optimizer = optim.Adam(g_params, self.lr, [self.beta1, self.beta2])
        self.d_optimizer = optim.Adam(d_params, self.lr, [self.beta1, self.beta2])
示例#5
0
    def build_model(self):
        """Build generators, discriminators, the ``dreid`` network and optimizers.

        ``g12``, ``g21``, ``d1`` and ``d2`` are each initialised with
        ``init_weights(..., init_type='normal')`` right after construction.
        ``dreid`` is a ``DSiamese`` built with
        ``class_count=self.num_classes_market``. All networks are moved to
        the GPU when CUDA is available. Three Adam optimizers are then
        created, each with lr ``self.lr`` and betas
        ``(self.beta1, self.beta2)``:

        * ``g_optimizer``  over ``g12`` + ``g21``
        * ``d_optimizer``  over ``d1`` + ``d2``
        * ``dr_optimizer`` over ``dreid``
        """
        # Construction and initialisation stay interleaved in this order so
        # the random number stream consumed by weight init is unchanged.
        self.g12 = G12(conv_dim=self.g_conv_dim)
        init_weights(self.g12, init_type='normal')
        self.g21 = G21(conv_dim=self.g_conv_dim)
        init_weights(self.g21, init_type='normal')
        self.d1 = D1(conv_dim=self.d_conv_dim, use_labels=self.use_labels)
        init_weights(self.d1, init_type='normal')
        self.d2 = D2(conv_dim=self.d_conv_dim, use_labels=self.use_labels)
        init_weights(self.d2, init_type='normal')
        # NOTE(review): dreid is not passed through init_weights like the
        # other networks -- confirm it should keep its constructor's init.
        self.dreid = DSiamese(class_count=self.num_classes_market)

        # Move to the GPU before constructing the optimizers (recommended by
        # the torch.optim documentation).
        if torch.cuda.is_available():
            for net in (self.g12, self.g21, self.d1, self.d2, self.dreid):
                net.cuda()

        g_params = list(self.g12.parameters()) + list(self.g21.parameters())
        d_params = list(self.d1.parameters()) + list(self.d2.parameters())
        dr_params = list(self.dreid.parameters())

        self.g_optimizer = optim.Adam(g_params, self.lr,
                                      [self.beta1, self.beta2])
        self.d_optimizer = optim.Adam(d_params, self.lr,
                                      [self.beta1, self.beta2])
        self.dr_optimizer = optim.Adam(dr_params, self.lr,
                                       [self.beta1, self.beta2])
示例#6
0
    def build_model(self):
        """Build generator ``g``, discriminators ``d1``/``d2`` and optimizers.

        Moves the networks to the GPU when CUDA is available, then creates:

        * ``gc_optimizer`` -- Adam over ``g``, lr 0.001, betas (0.5, 0.999)
        * ``g_optimizer``  -- Adam over ``g``, lr ``self.lr``,
          betas ``(0.5, self.beta2)``
        * ``d_optimizer``  -- Adam over ``d1`` + ``d2``, lr ``self.lr``,
          betas ``(0.5, self.beta2)``
        """
        self.g = G(conv_dim=self.g_conv_dim)
        self.d1 = D1(conv_dim=self.d_conv_dim)
        self.d2 = D2(conv_dim=self.d_conv_dim)

        # Move to the GPU before constructing the optimizers (recommended by
        # the torch.optim documentation).
        if torch.cuda.is_available():
            self.g.cuda()
            self.d1.cuda()
            self.d2.cuda()

        g_params = list(self.g.parameters())
        d_params = list(self.d1.parameters()) + list(self.d2.parameters())

        # NOTE(review): gc_optimizer and g_optimizer both own g's parameters
        # with separate Adam state, and beta1 is hard-coded to 0.5 instead of
        # self.beta1 -- confirm both are intentional.
        self.gc_optimizer = optim.Adam(g_params, 0.001, [0.5, 0.999])
        self.g_optimizer = optim.Adam(g_params, self.lr, [0.5, self.beta2])
        self.d_optimizer = optim.Adam(d_params, self.lr, [0.5, self.beta2])
    def build_model(self):
        """Build generators ``g11``/``g22``, discriminators and optimizers.

        The discriminators ``d1``/``d2`` are built with ``use_labels=False``.
        Everything is moved to the GPU when CUDA is available. Then
        ``g_optimizer`` (over ``g11`` + ``g22``) and ``d_optimizer`` (over
        ``d1`` + ``d2``) are created with Adam, lr ``self.lr`` and betas
        ``(self.beta1, self.beta2)``.
        """
        self.g11 = G11(conv_dim=self.g_conv_dim)
        self.g22 = G22(conv_dim=self.g_conv_dim)
        self.d1 = D1(conv_dim=self.d_conv_dim, use_labels=False)
        self.d2 = D2(conv_dim=self.d_conv_dim, use_labels=False)

        # Move to the GPU before constructing the optimizers (recommended by
        # the torch.optim documentation).
        if torch.cuda.is_available():
            for net in (self.g11, self.g22, self.d1, self.d2):
                net.cuda()

        g_params = list(self.g11.parameters()) + list(self.g22.parameters())
        d_params = list(self.d1.parameters()) + list(self.d2.parameters())

        self.g_optimizer = optim.Adam(g_params, self.lr, [self.beta1, self.beta2])
        self.d_optimizer = optim.Adam(d_params, self.lr, [self.beta1, self.beta2])
示例#8
0
    def build_model(self):
        """ Builds a generator and a discriminator.

        Creates generators ``g12``/``g21`` and discriminators ``d1``/``d2``
        (built with ``use_labels=self.use_labels``). When CUDA is available,
        moves them to the GPU and writes a note to ``self.log_file``. Then
        creates ``g_optimizer`` and ``d_optimizer`` (Adam, lr ``self.lr``,
        betas ``(self.beta1, self.beta2)``).
        """
        self.g12 = G12(conv_dim=self.g_conv_dim)
        self.g21 = G21(conv_dim=self.g_conv_dim)
        self.d1 = D1(conv_dim=self.d_conv_dim, use_labels=self.use_labels)
        self.d2 = D2(conv_dim=self.d_conv_dim, use_labels=self.use_labels)

        # Move to the GPU before constructing the optimizers (recommended by
        # the torch.optim documentation).
        if torch.cuda.is_available():
            self.g12.cuda()
            self.g21.cuda()
            self.d1.cuda()
            self.d2.cuda()
            self.log_file.write("Cuda is available!\n")
            self.log_file.flush()

        # Concatenate each group's parameter lists so one optimizer updates
        # both networks of the group.
        g_params = list(self.g12.parameters()) + list(self.g21.parameters())
        d_params = list(self.d1.parameters()) + list(self.d2.parameters())

        self.g_optimizer = optim.Adam(g_params, self.lr,
                                      [self.beta1, self.beta2])
        self.d_optimizer = optim.Adam(d_params, self.lr,
                                      [self.beta1, self.beta2])
示例#9
0
    def build_model(self):
        """Build the ``skip`` generator ``g22``, discriminators and optimizers.

        ``g22`` is a three-level ``skip`` network with three input and three
        output channels. ``d1``/``d2`` are built with ``use_labels=False``.
        Everything is moved to the GPU when CUDA is available. Then
        ``g_optimizer`` (over ``g22``) and ``d_optimizer`` (over ``d1`` +
        ``d2``) are created with Adam, lr ``self.lr`` and betas
        ``(self.beta1, self.beta2)``.
        """
        self.g22 = skip(
            num_input_channels=3, num_output_channels=3,
            num_channels_down=[8, 16, 32],
            num_channels_up=[8, 16, 32],
            num_channels_skip=[0, 0, 0],
            upsample_mode='bilinear',
            need_sigmoid=True, need_bias=True, pad='reflection', act_fun='LeakyReLU')
        self.d1 = D1(conv_dim=self.d_conv_dim, use_labels=False)
        self.d2 = D2(conv_dim=self.d_conv_dim, use_labels=False)

        # Move to the GPU before constructing the optimizers (recommended by
        # the torch.optim documentation).
        if torch.cuda.is_available():
            self.g22.cuda()
            self.d1.cuda()
            self.d2.cuda()

        g_params = list(self.g22.parameters())
        d_params = list(self.d1.parameters()) + list(self.d2.parameters())

        self.g_optimizer = optim.Adam(g_params, self.lr, [self.beta1, self.beta2])
        self.d_optimizer = optim.Adam(d_params, self.lr, [self.beta1, self.beta2])
示例#10
0
def _log_sample_grids(config, tb_logger, global_step, Gv2r, Gr2v,
                      real_batch, virtual_batch):
    """Translate the fixed sample batches and log them to TensorBoard.

    Args:
        config: training config, passed through to ``merge_images``.
        tb_logger: writer receiving the ``r_Imgs`` / ``v_Imgs`` images.
        global_step: step value recorded with the images.
        Gv2r: generator applied to the virtual batch.
        Gr2v: generator applied to the real batch.
        real_batch: fixed batch from the real-domain loader.
        virtual_batch: fixed batch from the virtual-domain loader.

    Returns:
        ``(r_sample, v_sample)``: copies of what ``merge_images`` returns for
        the real batch next to ``Gr2v(real_batch)``, and for the virtual
        batch next to ``Gv2r(virtual_batch)``.
    """
    # Inference only: building an autograd graph here would just waste memory.
    with torch.no_grad():
        virtual_translated = Gv2r(virtual_batch)
        real_translated = Gr2v(real_batch)

    real_translated_np = real_translated.cpu().numpy() * 255
    virtual_translated_np = virtual_translated.cpu().numpy() * 255
    real_np = real_batch.detach().cpu().numpy() * 255
    virtual_np = virtual_batch.detach().cpu().numpy() * 255

    r_merged_image = merge_images(config, real_np, real_translated_np)
    v_merged_image = merge_images(config, virtual_np, virtual_translated_np)
    # Untouched copies are kept for writing to disk at the next save step.
    r_sample = r_merged_image.copy()
    v_sample = v_merged_image.copy()

    x1 = vutils.make_grid(transform2tensor(r_merged_image),
                          normalize=True,
                          scale_each=True)
    x2 = vutils.make_grid(transform2tensor(v_merged_image),
                          normalize=True,
                          scale_each=True)
    tb_logger.add_image('r_Imgs', x1, global_step)
    tb_logger.add_image('v_Imgs', x2, global_step)
    return r_sample, v_sample


def _save_checkpoint(config, global_step, Gv2r, Gr2v, Dr, Dv,
                     r_sample, v_sample):
    """Save all four networks and the latest sample images for ``global_step``.

    State dicts are written to ``config.model_path`` as
    ``<Net>-<exp_name>-<step>.pkl``. If a sample step has already run, the
    latest merged sample images are written to ``config.sample_path`` as
    ``r_sample_<step>.png`` / ``v_sample_<step>.png``, with the step
    zero-padded to five digits.
    """
    Gv2r_path = os.path.join(
        config.model_path,
        'Gv2r-{}-{}.pkl'.format(exp_name, global_step))
    Gr2v_path = os.path.join(
        config.model_path,
        'Gr2v-{}-{}.pkl'.format(exp_name, global_step))
    Dr_path = os.path.join(
        config.model_path,
        'Dr-{}-{}.pkl'.format(exp_name, global_step))
    Dv_path = os.path.join(
        config.model_path,
        'Dv-{}-{}.pkl'.format(exp_name, global_step))
    torch.save(Gv2r.state_dict(), Gv2r_path)
    torch.save(Gr2v.state_dict(), Gr2v_path)
    torch.save(Dr.state_dict(), Dr_path)
    torch.save(Dv.state_dict(), Dv_path)

    # There are no sample images until the first sample step has run, and
    # save_step need not be a multiple of sample_step.
    if r_sample is None or v_sample is None:
        return
    step_tag = str(global_step).zfill(5)
    cv2.imwrite(
        os.path.join(config.sample_path, 'r_sample_{}.png'.format(step_tag)),
        cv2.cvtColor(r_sample, cv2.COLOR_BGR2RGB))
    cv2.imwrite(
        os.path.join(config.sample_path, 'v_sample_{}.png'.format(step_tag)),
        cv2.cvtColor(v_sample, cv2.COLOR_BGR2RGB))


def train(config):
    """Train virtual<->real image translators with adversarial and cycle losses.

    Two generators are trained jointly: ``Gv2r`` (virtual -> real) and
    ``Gr2v`` (real -> virtual). Each has a discriminator: ``Dr`` for the real
    domain and ``Dv`` for the virtual domain. Each iteration:

    1. Discriminators: squared-error loss pushing their outputs on real
       batches towards 1 and on translated batches towards 0. Their
       optimizer only steps every ``config.D_up_step`` iterations.
    2. Generators: for each cycle (r->v->r, then v->r->v), a squared-error
       adversarial term plus a mean-squared reconstruction term.

    Scalars are logged to TensorBoard under ``./logs/<exp_name>``. Sample
    grids are logged every ``config.sample_step`` iterations. Checkpoints
    and the latest sample images are written every ``config.save_step``
    iterations.

    Args:
        config: namespace providing the data/model/sample paths, batch size,
            worker count, conv dims, learning rate and betas, epoch count and
            the ``D_up_step`` / ``sample_step`` / ``save_step`` intervals.

    Note:
        Uses a module-level ``exp_name`` defined elsewhere in the file. CUDA
        is required: the segmentation model and the sample batches are moved
        to the GPU unconditionally.
    """
    # Create the output directories (no-op when they already exist).
    os.makedirs(config.model_path, exist_ok=True)
    os.makedirs(config.sample_path, exist_ok=True)

    # ---- data ----
    Real_Dataset = License_Real(config.real_path, 'real_train')
    Real_Dataloader = DataLoader(Real_Dataset,
                                 batch_size=config.batch_size,
                                 shuffle=True,
                                 pin_memory=True,
                                 num_workers=config.num_workers)

    Virtual_Dataset = License_Virtual(config.virtual_path,
                                      'angle0_without_night')
    Virtual_Dataloader = DataLoader(Virtual_Dataset,
                                    batch_size=config.batch_size,
                                    shuffle=True,
                                    pin_memory=True,
                                    num_workers=config.num_workers)

    # ---- models ----
    Gv2r = G12(conv_dim=config.g_conv_dim)  # virtual -> real
    Gr2v = G21(conv_dim=config.g_conv_dim)  # real -> virtual
    Dv = D1(conv_dim=config.d_conv_dim)  # scores virtual-domain images
    Dr = D2(conv_dim=config.d_conv_dim)  # scores real-domain images
    Gv2r.train()
    Gr2v.train()
    Dv.train()
    Dr.train()

    # Segmentation model: only used by the segmentation losses, which are
    # currently commented out below.
    seg_model = BiSeNet(66, 8, 'resnet101')
    seg_model = torch.nn.DataParallel(seg_model).cuda()
    seg_model.module.load_state_dict(torch.load(config.seg_model_path), True)
    seg_model.eval()
    print('seg model loaded!')

    # Move the networks to the GPU before constructing their optimizers, as
    # the torch.optim documentation recommends.
    if torch.cuda.is_available():
        Gv2r.cuda()
        Gr2v.cuda()
        Dv.cuda()
        Dr.cuda()

    g_params = list(Gv2r.parameters()) + list(Gr2v.parameters())
    d_params = list(Dv.parameters()) + list(Dr.parameters())

    g_optimizer = torch.optim.Adam(g_params, config.lr,
                                   [config.beta1, config.beta2])
    d_optimizer = torch.optim.Adam(d_params, config.lr,
                                   [config.beta1, config.beta2])

    # Fixed batches used to visualise progress. Use the next() builtin:
    # ``iterator.next()`` is a Python 2 leftover that recent PyTorch
    # DataLoader iterators no longer provide. The iterators are not kept
    # around, so any worker processes they started can shut down.
    Real_sample_batch = next(iter(Real_Dataloader)).cuda()
    Virtual_sample_batch, _, _ = next(iter(Virtual_Dataloader))
    Virtual_sample_batch = Virtual_sample_batch.cuda()

    step = 0
    # Only needed by the (disabled) segmentation losses.
    criterion = torch.nn.CrossEntropyLoss(ignore_index=255)
    # Latest merged sample images; None until the first sample step.
    r_sample = v_sample = None

    tb_logger = SummaryWriter('./logs/{}'.format(exp_name))
    try:
        for _epoch in range(config.train_epochs):
            # zip() stops at the end of the shorter of the two loaders.
            for r_batch_data, (v_batch_data, v_batch_seg,
                               _v_batch_pos) in zip(Real_Dataloader,
                                                    Virtual_Dataloader):
                r_batch_data = r_batch_data.cuda()
                v_batch_data = v_batch_data.cuda()
                # Only used by the (disabled) segmentation loss.
                v_batch_seg = torch.squeeze(v_batch_seg.cuda(), 1)

                # ============ train D ============ #
                update_d = step % config.D_up_step == 0

                # Real images: push discriminator outputs towards 1.
                g_optimizer.zero_grad()
                d_optimizer.zero_grad()
                dr_loss = torch.mean((Dr(r_batch_data) - 1)**2)
                dv_loss = torch.mean((Dv(v_batch_data) - 1)**2)
                d_real_loss = dr_loss + dv_loss
                if update_d:
                    d_real_loss.backward()
                    d_optimizer.step()

                # Translated images: push discriminator outputs towards 0.
                # The fakes are detached: generator gradients from this loss
                # would be discarded by g_optimizer.zero_grad() below anyway.
                g_optimizer.zero_grad()
                d_optimizer.zero_grad()
                fake_v = Gr2v(r_batch_data).detach()
                dv_loss = torch.mean(Dv(fake_v)**2)
                fake_r = Gv2r(v_batch_data).detach()
                dr_loss = torch.mean(Dr(fake_r)**2)
                d_fake_loss = dv_loss + dr_loss
                if update_d:
                    d_fake_loss.backward()
                    d_optimizer.step()

                # ============ train G ============ #
                # r -> v -> r cycle
                g_optimizer.zero_grad()
                d_optimizer.zero_grad()
                fake_v = Gr2v(r_batch_data)
                out = Dv(fake_v)
                reconst_r = Gv2r(fake_v)
                # seg_pred, _ = seg_model(F.interpolate(reconst_r, size=(50, 160)))
                # r_seg_loss = criterion(seg_pred, r_batch_seg.type(torch.long))
                g_loss = torch.mean((out - 1)**2) + torch.mean(
                    (r_batch_data - reconst_r)**2)  # + 1.0*r_seg_loss
                g_loss.backward()
                g_optimizer.step()

                # v -> r -> v cycle
                g_optimizer.zero_grad()
                d_optimizer.zero_grad()
                fake_r = Gv2r(v_batch_data)
                out = Dr(fake_r)
                reconst_v = Gr2v(fake_r)
                # seg_pred, _ = seg_model(F.interpolate(reconst_v, size=(50, 160), mode='bilinear'))
                v_seg_loss = 0  # criterion(seg_pred, v_batch_seg.type(torch.long))
                g_loss = torch.mean((out - 1)**2) + torch.mean(
                    (v_batch_data - reconst_v)**2)  # + 0.01*v_seg_loss
                g_loss.backward()
                g_optimizer.step()

                # ---- logging ----
                # NOTE: g_loss is the v->r->v cycle loss only, and dv_loss /
                # dr_loss are the discriminator losses on translated images
                # (later assignments overwrite earlier ones).
                scalars = [
                    ('d_real_loss', d_real_loss.item()),
                    ('d_fake_loss', d_fake_loss.item()),
                    ('g_loss', g_loss.item()),
                    ('dv_loss', dv_loss.item()),
                    ('dr_loss', dr_loss.item()),
                    ('v_seg_loss', v_seg_loss),
                ]
                for tag, value in scalars:
                    tb_logger.add_scalar(tag, value, step)
                print(
                    'step:{}, d_real_loss:{}, d_fake_loss:{}, g_loss:{}, dv_loss:{}, dr_loss:{}, v_seg_loss:{}'
                    .format(step, *(value for _, value in scalars)))

                if (step + 1) % config.sample_step == 0:
                    r_sample, v_sample = _log_sample_grids(
                        config, tb_logger, step + 1, Gv2r, Gr2v,
                        Real_sample_batch, Virtual_sample_batch)

                if (step + 1) % config.save_step == 0:
                    _save_checkpoint(config, step + 1, Gv2r, Gr2v, Dr, Dv,
                                     r_sample, v_sample)

                step += 1
    finally:
        # Flush pending events and release the event file.
        tb_logger.close()