Esempio n. 1
0
    def train(self, train_dataset, test_data, test_dataset, output_dir):
        """Train through every resolution from ``self.res_idx`` onwards.

        For each resolution the learning rate, batch size and phase length
        are read from ``self.cfg`` at index ``self.res_idx``.  A phase runs
        ``2 * (phase_length // batch_size)`` steps: ``alpha`` ramps linearly
        from 0 towards 1 over the first half and stays at 1.0 for the
        second half.  ``self.res_idx`` is incremented after each phase and
        ``self.train_step`` after each step.

        Args:
            train_dataset: Dataset handed to ``get_dataloader``.
            test_data: Indexable; items 0 and 1 are passed to
                ``self.save_sample`` when sample images are dumped.
            test_dataset: Forwarded to ``perform_train_step`` as
                ``valid_ds``.
            output_dir: Root directory for the loss tracker, the
                ``images`` dumps and the ``checkpoints`` files.
        """
        loss_tracker = LossTracker(output_dir)

        while self.res_idx < len(self.cfg['resolutions']):
            # Per-phase settings, looked up by the current resolution index.
            res = self.cfg['resolutions'][self.res_idx]
            phase_lr = self.cfg['learning_rates'][self.res_idx]
            self.set_optimizers_lr(phase_lr)
            batch_size = self.cfg['batch_sizes'][self.res_idx]
            phase_length = self.cfg['phase_lengths'][self.res_idx]
            fade_steps = phase_length // batch_size

            loader = get_dataloader(train_dataset,
                                    batch_size,
                                    resize=res,
                                    device=self.device)
            batches = EndlessDataloader(loader)

            pbar = tqdm(range(2 * fade_steps))
            for step in pbar:
                # Fade-in during the first half (alpha < 1), then hold at 1.
                alpha = step / fade_steps if step < fade_steps else 1.0
                real_batch = batches.next()

                should_log = step % 10 == 0
                should_score = step % 100 == 0
                self.perform_train_step(real_batch,
                                        loss_tracker,
                                        log=should_log,
                                        calc_scores=should_score,
                                        valid_ds=test_dataset,
                                        final_resolution_idx=self.res_idx,
                                        alpha=alpha)

                self.train_step += 1
                progress_tag = f"gs-{self.train_step}_res-{self.res_idx}={res}x{res}_alpha-{alpha:.2f}"
                pbar.set_description(progress_tag)

                # Periodic loss plot + sample image dump.
                if self.train_step % self.cfg['dump_imgs_freq'] == 0:
                    loss_tracker.plot()
                    sample_name = f"{progress_tag}.jpg"
                    sample_path = os.path.join(output_dir, 'images',
                                               sample_name)
                    self.save_sample(sample_path,
                                     test_data[0],
                                     test_data[1],
                                     final_resolution_idx=self.res_idx,
                                     alpha=alpha)

                # Periodic checkpoint, named after the current progress tag.
                if self.train_step % self.cfg['checkpoint_freq'] == 0:
                    ckpt_name = f"ckpt_{progress_tag}.pt"
                    ckpt_path = os.path.join(output_dir, 'checkpoints',
                                             ckpt_name)
                    self.save_train_state(ckpt_path)

            self.res_idx += 1

        final_path = os.path.join(output_dir, 'checkpoints', 'ckpt_final.pt')
        self.save_train_state(final_path)
Esempio n. 2
0
    def train(self, train_dataset, test_data, output_dir):
        """Train for ``cfg['epochs']`` epochs at a single fixed setting.

        After every epoch the loss tracker is plotted, a sample image is
        written to ``<output_dir>/images/epoch-<n>.jpg`` and the training
        state is saved to ``<output_dir>/last_ckp.pth`` (the same path
        each epoch).

        Args:
            train_dataset: Dataset handed to ``get_dataloader``.
            test_data: Indexable; items 0 and 1 are passed to
                ``self.save_sample``.
            output_dir: Root directory for tracker output, images and the
                checkpoint file.
        """
        loader = get_dataloader(train_dataset,
                                self.cfg['batch_size'],
                                resize=None,
                                device=self.device)
        loss_tracker = LossTracker(output_dir)
        self.set_optimizers_lr(self.cfg['lr'])

        num_epochs = self.cfg['epochs']
        for epoch in range(num_epochs):
            for real_batch in tqdm(loader):
                self.perform_train_step(real_batch, loss_tracker)

            # End-of-epoch reporting: loss curves, sample image, checkpoint.
            loss_tracker.plot()
            image_name = f"epoch-{epoch}.jpg"
            image_path = os.path.join(output_dir, 'images', image_name)
            self.save_sample(image_path, test_data[0], test_data[1])

            ckpt_path = os.path.join(output_dir, "last_ckp.pth")
            self.save_train_state(ckpt_path)
Esempio n. 3
0
    def train(self, train_dataset, test_data, output_dir):
        """Train through every resolution listed in ``cfg['resolutions']``.

        Each resolution runs ``2 * (phase_length // batch_size)`` steps.
        ``alpha`` ramps linearly from 0 towards 1 during the first half of
        the phase and stays at 1.0 during the second half.  The
        discriminator is updated on every step; the generator is updated
        on every ``cfg['n_critic']``-th step.  A checkpoint is saved at the
        end of each resolution.

        Args:
            train_dataset: Dataset handed to ``get_dataloader``.
            test_data: Forwarded unchanged to ``self.save_sample``.
            output_dir: Root directory for tracker output, sample dumps and
                the ``checkpoints`` files.
        """
        loss_tracker = LossTracker(output_dir)
        step_count = 0

        for res_idx, res in enumerate(self.cfg['resolutions']):
            self.set_optimizers_lr(self.cfg['learning_rates'][res_idx])
            phase_length = self.cfg['phase_lengths'][res_idx]
            fade_steps = phase_length // self.cfg['batch_sizes'][res_idx]

            loader = get_dataloader(train_dataset,
                                    self.cfg['batch_sizes'][res_idx],
                                    resize=res,
                                    device=self.device)
            batches = EndlessDataloader(loader)

            pbar = tqdm(range(2 * fade_steps))
            for step in pbar:
                # Below 1 in the first half of the phase, exactly 1 afterwards.
                alpha = step / fade_steps if step < fade_steps else 1.0
                pbar.set_description(f"gs-{step_count}_res-{res_idx}={res}x{res}_alpha-{alpha:.3f}")
                real_batch = batches.next()

                # Discriminator update (every step).
                self.D_optimizer.zero_grad()
                d_loss = self.get_D_loss(real_batch, res_idx, alpha)
                d_loss.backward()
                self.D_optimizer.step()
                loss_tracker.update({'loss_d': d_loss})

                # Generator update (once per n_critic discriminator steps).
                if (step + 1) % self.cfg['n_critic'] == 0:
                    self.G_optimizer.zero_grad()
                    g_loss = self.get_G_loss(real_batch, res_idx, alpha)
                    g_loss.backward()
                    self.G_optimizer.step()
                    loss_tracker.update({'loss_g': g_loss})

                step_count += 1
                if step_count % self.cfg['dump_imgs_freq'] == 0:
                    self.save_sample(step_count, loss_tracker, test_data,
                                     output_dir, res_idx, alpha)

            ckpt_name = f"ckpt_res-{res_idx}={res}x{res}-end.pt"
            self.save_train_state(
                os.path.join(output_dir, 'checkpoints', ckpt_name))
Esempio n. 4
0
def train(folding_id, inliner_classes, ic):
    """Train generator, encoder and both discriminators for one fold.

    Configuration comes from the project defaults merged with
    ``configs/mnist.yaml``.  Every batch performs four updates in order:

    1. ``D`` learns to separate real images from generated ones.
    2. ``G`` learns to make ``D`` label generated images as real.
    3. ``ZD`` learns to separate normally distributed latents (label 1)
       from latents produced by ``E`` (label 0).
    4. ``E`` and ``G`` are updated jointly on a reconstruction loss
       (weight 2.0) plus the loss of fooling ``ZD`` (weight 1.0).

    After each epoch a reconstruction image, a sample grid generated from
    a fixed latent batch, and the tracker plot are written to
    ``results_<folding_id>_<classes>``.  When training finishes the state
    dicts of ``G`` and ``E`` are saved under ``models/``.

    Args:
        folding_id: Fold identifier; forwarded to ``make_datasets`` and
            used in output folder and model file names.
        inliner_classes: Iterable of class identifiers; forwarded to
            ``make_datasets`` and joined into the output folder name.
        ic: Integer used (with ``folding_id``) in the saved model names.

    Raises:
        RuntimeError: If the data loader yields no batches in an epoch.
    """
    cfg = get_cfg_defaults()
    cfg.merge_from_file('configs/mnist.yaml')
    cfg.freeze()
    logger = logging.getLogger("logger")

    # cfg.freeze() was called above, so these values are read once here
    # instead of on every batch.
    zsize = cfg.MODEL.LATENT_SIZE
    channels = cfg.MODEL.INPUT_IMAGE_CHANNELS
    image_size = cfg.MODEL.INPUT_IMAGE_SIZE
    batch_size = cfg.TRAIN.BATCH_SIZE
    cross_batch = cfg.MODEL.Z_DISCRIMINATOR_CROSS_BATCH
    epoch_count = cfg.TRAIN.EPOCH_COUNT

    output_folder = os.path.join('results_' + str(folding_id) + "_" +
                                 "_".join([str(x) for x in inliner_classes]))
    os.makedirs(output_folder, exist_ok=True)
    # Created up front so a permission problem surfaces before training.
    os.makedirs('models', exist_ok=True)

    train_set, _, _ = make_datasets(cfg, folding_id, inliner_classes)

    logger.info("Train set size: %d", len(train_set))

    G = Generator(zsize, channels=channels)
    G.weight_init(mean=0, std=0.02)

    D = Discriminator(channels=channels)
    D.weight_init(mean=0, std=0.02)

    E = Encoder(zsize, channels=channels)
    E.weight_init(mean=0, std=0.02)

    if cross_batch:
        ZD = ZDiscriminator_mergebatch(zsize, batch_size)
    else:
        ZD = ZDiscriminator(zsize, batch_size)
    ZD.weight_init(mean=0, std=0.02)

    lr = cfg.TRAIN.BASE_LEARNING_RATE

    G_optimizer = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
    D_optimizer = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
    GE_optimizer = optim.Adam(list(E.parameters()) + list(G.parameters()),
                              lr=lr,
                              betas=(0.5, 0.999))
    ZD_optimizer = optim.Adam(ZD.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizers = (G_optimizer, D_optimizer, GE_optimizer, ZD_optimizer)

    BCE_loss = nn.BCELoss()
    # Fixed latent batch so the per-epoch sample grids are comparable.
    sample = torch.randn(64, zsize).view(-1, zsize, 1, 1)

    tracker = LossTracker(output_folder=output_folder)

    for epoch in range(epoch_count):
        G.train()
        D.train()
        E.train()
        ZD.train()

        epoch_start_time = time.time()

        data_loader = make_dataloader(train_set, batch_size,
                                      torch.cuda.current_device())
        train_set.shuffle()

        # Step decay: divide every learning rate by 4 on each 30th epoch.
        if (epoch + 1) % 30 == 0:
            for optimizer in optimizers:
                for group in optimizer.param_groups:
                    group['lr'] /= 4
            print("learning rate change!")

        # Reset so an empty loader is detected instead of reusing stale
        # tensors (or hitting a NameError on the first epoch).
        x = x_d = None

        for _, x in data_loader:
            x = x.view(-1, channels, image_size, image_size)
            n = x.shape[0]

            y_real_ = torch.ones(n)
            y_fake_ = torch.zeros(n)

            # The cross-batch ZD variant is given a single label per batch.
            z_label_count = 1 if cross_batch else n
            y_real_z = torch.ones(z_label_count)
            y_fake_z = torch.zeros(z_label_count)

            # ---- 1. discriminator D: real vs. generated images --------

            D.zero_grad()

            D_result = D(x).squeeze()
            D_real_loss = BCE_loss(D_result, y_real_)

            z = torch.randn((n, zsize)).view(-1, zsize, 1, 1)

            # detach(): this step must not push gradients into G.
            x_fake = G(z).detach()
            D_result = D(x_fake).squeeze()
            D_fake_loss = BCE_loss(D_result, y_fake_)

            D_train_loss = D_real_loss + D_fake_loss
            D_train_loss.backward()

            D_optimizer.step()

            tracker.update(dict(D=D_train_loss))

            # ---- 2. generator G: make D output "real" -----------------

            G.zero_grad()

            z = torch.randn((n, zsize)).view(-1, zsize, 1, 1)

            x_fake = G(z)
            D_result = D(x_fake).squeeze()

            G_train_loss = BCE_loss(D_result, y_real_)

            G_train_loss.backward()
            G_optimizer.step()

            tracker.update(dict(G=G_train_loss))

            # ---- 3. latent discriminator ZD: prior vs. encoded --------

            ZD.zero_grad()

            z = torch.randn((n, zsize)).view(-1, zsize)

            ZD_result = ZD(z).squeeze()
            ZD_real_loss = BCE_loss(ZD_result, y_real_z)

            # detach(): this step must not push gradients into E.
            z = E(x).squeeze().detach()

            ZD_result = ZD(z).squeeze()
            ZD_fake_loss = BCE_loss(ZD_result, y_fake_z)

            ZD_train_loss = ZD_real_loss + ZD_fake_loss
            ZD_train_loss.backward()

            ZD_optimizer.step()

            tracker.update(dict(ZD=ZD_train_loss))

            # ---- 4. encoder + generator: reconstruction + fool ZD -----

            E.zero_grad()
            G.zero_grad()

            z = E(x)
            x_d = G(z)

            ZD_result = ZD(z.squeeze()).squeeze()

            E_train_loss = BCE_loss(ZD_result, y_real_z) * 1.0

            Recon_loss = F.binary_cross_entropy(x_d, x.detach()) * 2.0

            (Recon_loss + E_train_loss).backward()

            GE_optimizer.step()

            tracker.update(dict(GE=Recon_loss, E=E_train_loss))

        if x_d is None:
            raise RuntimeError(
                "data loader produced no batches in epoch %d" % epoch)

        # Last batch of the epoch next to its reconstruction; detached
        # because the image dump needs no autograd graph.
        comparison = torch.cat([x, x_d.detach()])
        save_image(comparison.cpu(),
                   os.path.join(output_folder,
                                'reconstruction_' + str(epoch) + '.png'),
                   nrow=x.shape[0])

        epoch_end_time = time.time()
        per_epoch_ptime = epoch_end_time - epoch_start_time

        # str(tracker) is taken here, before register_means() runs below.
        logger.info('[%d/%d] - ptime: %.2f, %s', epoch + 1, epoch_count,
                    per_epoch_ptime, str(tracker))

        tracker.register_means(epoch)
        tracker.plot()

        with torch.no_grad():
            resultsample = G(sample).cpu()
            save_image(
                resultsample.view(64, channels, image_size, image_size),
                os.path.join(output_folder, 'sample_' + str(epoch) + '.png'))

    logger.info("Training finish!... save training results")

    # Re-checked in case the directory was removed during the run.
    os.makedirs("models", exist_ok=True)

    print("Training finish!... save training results")
    torch.save(G.state_dict(), "models/Gmodel_%d_%d.pkl" % (folding_id, ic))
    torch.save(E.state_dict(), "models/Emodel_%d_%d.pkl" % (folding_id, ic))