def test(self):
        """Log the average sliced-score-matching test loss for a range of checkpoints.

        For every ``ckpt`` in ``range(config.test.begin_ckpt,
        config.test.end_ckpt + 1, 5000)`` the weights stored in
        ``<args.log_path>/checkpoint_{ckpt}.pth`` are loaded into the score
        network: the EMA weights from ``states[-1]`` when ``config.model.ema``
        is set, otherwise ``states[0]``. ``anneal_sliced_score_estimation_vr``
        is then averaged over every batch of the test split and logged.

        A checkpoint for which the test loader yields no batch is reported with
        a warning instead of raising ``ZeroDivisionError``. This happens, for
        example, when the test split is smaller than ``config.test.batch_size``,
        because ``drop_last=True``.
        """
        score = get_model(self.config)
        score = torch.nn.DataParallel(score)

        sigmas = get_sigmas(self.config)

        # Only the test split is evaluated here.
        _, test_dataset = get_dataset(self.args, self.config)
        test_dataloader = DataLoader(
            test_dataset,
            batch_size=self.config.test.batch_size,
            shuffle=True,
            num_workers=self.config.data.num_workers,
            drop_last=True,
        )

        verbose = False  # set True to log the loss of every test batch
        for ckpt in tqdm.tqdm(
            range(self.config.test.begin_ckpt, self.config.test.end_ckpt + 1, 5000),
            desc="processing ckpt:",
        ):
            states = torch.load(
                os.path.join(self.args.log_path, f"checkpoint_{ckpt}.pth"),
                map_location=self.config.device,
            )

            if self.config.model.ema:
                ema_helper = EMAHelper(mu=self.config.model.ema_rate)
                ema_helper.register(score)
                ema_helper.load_state_dict(states[-1])
                ema_helper.ema(score)
            else:
                score.load_state_dict(states[0])

            score.eval()

            # NOTE(review): no torch.no_grad() here. Presumably the sliced score
            # estimator needs autograd through `score`; confirm before adding one.
            step = 0
            total_loss = 0.0
            for x, _ in test_dataloader:
                step += 1

                x = x.to(self.config.device)
                x = data_transform(self.config, x)

                test_loss = anneal_sliced_score_estimation_vr(
                    score, x, sigmas, None, self.config.training.anneal_power
                ).item()
                if verbose:
                    logging.info(
                        "step: {}, test_loss: {}".format(step, test_loss)
                    )

                total_loss += test_loss

            if step == 0:
                # Avoid ZeroDivisionError when the loader produced nothing.
                logging.warning(
                    "ckpt: {}, test loader produced no batches; skipping".format(ckpt)
                )
                continue

            mean_loss = total_loss / step
            logging.info("ckpt: {}, average test loss: {}".format(ckpt, mean_loss))
Example #2
0
    def sample(self):
        """Load diffusion-model weights and dispatch to the requested sampler.

        If ``args.use_pretrained`` is false, weights are read from
        ``<args.log_path>/ckpt.pdl``, or from ``ckpt_{config.sampling.ckpt_id}.pdl``
        when a ckpt id is configured:
          * Entries whose key contains ``"$model_"`` are loaded into the
            ``paddle.DataParallel``-wrapped model.
          * If ``config.model.ema`` is set, entries whose key contains
            ``"$ema_"`` are loaded into an ``EMAHelper``, which is then applied
            to the model.

        Otherwise the ``ema_{name}`` checkpoint returned by ``get_ckpt_path`` is
        loaded for CIFAR10 or LSUN.

        Raises:
            ValueError: if pretrained weights are requested for an unsupported dataset.
            NotImplementedError: if none of ``args.fid``, ``args.interpolation``
                or ``args.sequence`` is set.
        """
        model = Model(self.config)

        if not self.args.use_pretrained:
            if getattr(self.config.sampling, "ckpt_id", None) is None:
                states = paddle.load(
                    os.path.join(self.args.log_path, "ckpt.pdl"))
            else:
                states = paddle.load(
                    os.path.join(self.args.log_path,
                                 f"ckpt_{self.config.sampling.ckpt_id}.pdl"))
            model = paddle.DataParallel(model)
            # Keep only the "$model_" entries and strip everything up to that
            # marker to recover the parameter names.
            model.set_state_dict({
                k.split("$model_")[-1]: v
                for k, v in states.items() if "$model_" in k
            })

            if self.config.model.ema:
                ema_helper = EMAHelper(mu=self.config.model.ema_rate)
                ema_helper.register(model)
                ema_helper.set_state_dict({
                    k.split("$ema_")[-1]: v
                    for k, v in states.items() if "$ema_" in k
                })
                # Presumably copies the EMA shadow weights into `model`.
                ema_helper.ema(model)
            else:
                ema_helper = None
        else:
            # This used the pretrained DDPM model, see https://github.com/pesser/pytorch_diffusion
            if self.config.data.dataset == "CIFAR10":
                name = "cifar10"
            elif self.config.data.dataset == "LSUN":
                name = f"lsun_{self.config.data.category}"
            else:
                raise ValueError(
                    f"No pretrained checkpoint for dataset {self.config.data.dataset!r}")
            ckpt = get_ckpt_path(f"ema_{name}")
            print("Loading checkpoint {}".format(ckpt))
            model.set_state_dict(paddle.load(ckpt))
            model = paddle.DataParallel(model)

        model.eval()

        if self.args.fid:
            self.sample_fid(model)
        elif self.args.interpolation:
            self.sample_interpolation(model)
        elif self.args.sequence:
            self.sample_sequence(model)
        else:
            raise NotImplementedError("Sample procedeure not defined")
Example #3
0
    def _load_states(self, score):
        if self.config.sampling.ckpt_id is None:
            path = os.path.join(self.args.log_path, 'checkpoint.pth')
        else:
            path = os.path.join(self.args.log_path, f'checkpoint_{self.config.sampling.ckpt_id}.pth')
        states = torch.load(path, map_location=self.args.device)

        # score.load_state_dict(states[0], strict=True)

        if self.config.model.ema:
            ema_helper = EMAHelper(mu=self.config.model.ema_rate)
            ema_helper.register(score)
            ema_helper.load_state_dict(states[-1])
            ema_helper.ema(score)
        else:
            score.load_state_dict(states[0])

        del states

        return score
Example #4
0
    def sample(self):
        """Load the density-ratio estimator ``D`` and score model ``S``, then sample.

        ``D`` is restored from entry ``'D'`` of
        ``<args.log_path>/ckpt_DRE_{args.sigma_sq}_{args.tau}.pth``.

        ``S`` is restored from ``ckpt.pth``, or from
        ``ckpt_{config.sampling.ckpt_id}.pth`` when a ckpt id is configured.
        EMA weights are applied when ``config.model.ema`` is set.

        Both networks are put in eval mode before dispatching on ``args.fid`` /
        ``args.interpolation`` / ``args.inpainting`` / ``args.sbp``.

        Raises:
            NotImplementedError: if no sampling mode flag is set.
        """
        D = DensityRatioEstNet(self.config.model.ngf_d, self.config.data.image_size, self.config.data.channels).to(self.device)
        # map_location lets a checkpoint saved on another device (e.g. a
        # different GPU) be loaded here, matching how S's checkpoint is loaded.
        D.load_state_dict(torch.load(
            os.path.join(self.args.log_path, f"ckpt_DRE_{self.args.sigma_sq}_{self.args.tau}.pth"),
            map_location=self.device,
        )['D'])
        # D is only used for inference below; put it in eval mode like S.
        D.eval()
        S = Model(self.config)

        if getattr(self.config.sampling, "ckpt_id", None) is None:
            states = torch.load(
                os.path.join(self.args.log_path, "ckpt.pth"),
                map_location=self.config.device,
            )
        else:
            states = torch.load(
                os.path.join(
                    self.args.log_path, f"ckpt_{self.config.sampling.ckpt_id}.pth"
                ),
                map_location=self.config.device,
            )
        S = S.to(self.device)
        S = torch.nn.DataParallel(S)
        S.load_state_dict(states[0], strict=True)

        if self.config.model.ema:
            ema_helper = EMAHelper(mu=self.config.model.ema_rate)
            ema_helper.register(S)
            ema_helper.load_state_dict(states[-1])
            ema_helper.ema(S)
        else:
            ema_helper = None

        S.eval()

        if self.args.fid:
            self.sample_fid(S, D)
        elif self.args.interpolation:
            self.sample_interpolation(S)
        elif self.args.inpainting:
            self.sample_inpainting(S)
        elif self.args.sbp:
            self.sample_sbp(S, D)
        else:
            raise NotImplementedError("Sample procedeure not defined")
Example #5
0
    def train_s(self):
        """Train the noise-prediction network ``S`` with antithetic timestep sampling.

        Training images are shifted by ``-0.5``. For each batch,
        ``noise_estimation_loss`` is minimised for timesteps drawn in antithetic
        pairs ``(t, num_timesteps - t - 1)``. Gradients are clipped to
        ``config.optim.grad_clip`` when that option is present and not None.

        Checkpoints are written to ``ckpt_{step}.pth`` and ``ckpt.pth`` at
        step 1 and every ``config.training.snapshot_freq`` steps. Each holds
        ``[S, optimizer, epoch, step]``, plus the EMA state when enabled.
        ``args.resume_training`` restores from ``ckpt.pth``.
        """
        dataset, _ = get_dataset(self.args, self.config)
        train_loader = data.DataLoader(
            dataset,
            batch_size=self.config.training.batch_size,
            shuffle=True,
            num_workers=self.config.data.num_workers,
        )
        S = Model(self.config)

        S = S.to(self.device)
        S = torch.nn.DataParallel(S)

        optimizer = get_optimizer(self.config, S.parameters())

        if self.config.model.ema:
            ema_helper = EMAHelper(mu=self.config.model.ema_rate)
            ema_helper.register(S)
        else:
            ema_helper = None

        start_epoch, step = 0, 0
        if self.args.resume_training:
            # map_location: load onto this run's device regardless of where
            # the checkpoint was saved.
            states = torch.load(
                os.path.join(self.args.log_path, "ckpt.pth"),
                map_location=self.device,
            )
            S.load_state_dict(states[0])

            # Allow resuming with a different optimizer eps than the checkpoint's.
            states[1]["param_groups"][0]["eps"] = self.config.optim.eps
            optimizer.load_state_dict(states[1])
            start_epoch = states[2]
            step = states[3]
            if self.config.model.ema:
                ema_helper.load_state_dict(states[4])

        # Gradient clipping is optional. Skip it when grad_clip is absent or
        # None, rather than silently swallowing every exception from the clip
        # call.
        grad_clip = getattr(self.config.optim, "grad_clip", None)

        S.train()

        for epoch in range(start_epoch, self.config.training.n_epochs):
            for x, _ in train_loader:
                n = x.size(0)
                step += 1

                x = x.to(self.device)
                x = x - 0.5
                e = torch.randn_like(x)

                # antithetic sampling: pair each t with num_timesteps - t - 1
                t = torch.randint(
                    low=0, high=self.num_timesteps, size=(n // 2 + 1,)
                ).to(self.device)
                t = torch.cat([t, self.num_timesteps - t - 1], dim=0)[:n]
                loss = noise_estimation_loss(S, x, t.long(), e, self.sigma)

                if not step % 100:
                    logging.info(f"step: {step}, loss: {loss.item()}")

                optimizer.zero_grad()
                loss.backward()

                if grad_clip is not None:
                    torch.nn.utils.clip_grad_norm_(S.parameters(), grad_clip)
                optimizer.step()

                if self.config.model.ema:
                    ema_helper.update(S)

                if step % self.config.training.snapshot_freq == 0 or step == 1:
                    states = [
                        S.state_dict(),
                        optimizer.state_dict(),
                        epoch,
                        step,
                    ]
                    if self.config.model.ema:
                        states.append(ema_helper.state_dict())

                    torch.save(
                        states,
                        os.path.join(self.args.log_path, "ckpt_{}.pth".format(step)),
                    )
                    torch.save(states, os.path.join(self.args.log_path, "ckpt.pth"))
    def train(self):
        """Train the conditional PixelCNN, evaluate it each epoch, save and sample.

        Each input batch is corrupted twice:
          * a "noisy" copy with Gaussian noise of std ``config.noise``;
          * a "clean" copy with std ``config.clean_noise``.
        The two copies are concatenated along dim 2. The model output for the
        second half is scored against the clean copy with ``loss_op``.

        Both the logged training and test losses are divided by
        ``batch_size * prod(obs) * ln 2``. That is bits/dim if ``loss_op``
        returns a summed NLL in nats (an assumption — confirm against the loss
        helpers).

        Per-epoch schedule:
          * After every epoch, the test loss is computed with the EMA copy of
            the model when ``config.model.ema`` is set.
          * Checkpoints are written every ``config.save_interval`` epochs.
          * Every 10 epochs, a 25-image sample grid is written to
            ``<args.log>/images``.

        ``args.resume_training`` restores model, optimizer, scheduler, epoch
        and EMA state from ``<args.log>/pixelcnn_ckpts/checkpoint.pth``.

        Raises:
            Exception: if ``config.dataset`` matches none of MNIST, CIFAR10, celeba.
        """
        obs = (1, 28, 28) if 'MNIST' in self.config.dataset else (3, 32, 32)
        train_loader, test_loader = dataset.get_dataset(self.config)
        model = PixelCNN(self.config)
        model = model.to(self.config.device)
        model = torch.nn.DataParallel(model)

        rescaling_inv = lambda x: .5 * x + .5

        if 'MNIST' in self.config.dataset:
            loss_op = mix_logistic_loss_1d
            sample_fn = sample_from_discretized_mix_logistic_1d
        elif 'CIFAR10' in self.config.dataset or 'celeba' in self.config.dataset:
            loss_op = mix_logistic_loss
            sample_fn = sample_from_discretized_mix_logistic
        else:
            # Braces are doubled so str.format() emits them literally. The
            # single-brace form raised KeyError instead of this Exception.
            raise Exception(
                '{} dataset not in {{mnist, cifar10, celeba}}'.format(
                    self.config.dataset))
        clamp = False

        def sample_op(x, net=model):
            # Draw a sample from `net` (default: the training model) given `x`.
            return sample_fn(x, partial(net, sample=True),
                             self.config.nr_logistic_mix, clamp=clamp)

        if self.config.model.ema:
            ema_helper = EMAHelper(mu=self.config.model.ema_rate)
            ema_helper.register(model)
        else:
            ema_helper = None

        optimizer = optim.Adam(model.parameters(), lr=self.config.lr)
        scheduler = lr_scheduler.StepLR(optimizer,
                                        step_size=1,
                                        gamma=self.config.lr_decay)

        ckpt_path = os.path.join(self.args.log, 'pixelcnn_ckpts')
        if not os.path.exists(ckpt_path):
            os.makedirs(ckpt_path)

        start_epoch = 0
        if self.args.resume_training:
            state_dict = torch.load(os.path.join(ckpt_path, 'checkpoint.pth'),
                                    map_location=self.config.device)
            model.load_state_dict(state_dict[0])
            optimizer.load_state_dict(state_dict[1])
            scheduler.load_state_dict(state_dict[2])
            if len(state_dict) > 3:
                # Checkpoints are written at the end of an epoch, so continue
                # with the following one (previously this value was discarded).
                start_epoch = state_dict[3] + 1
                if self.config.model.ema and len(state_dict) > 4:
                    # Previously referenced the undefined name `states`.
                    ema_helper.load_state_dict(state_dict[4])
            print('model parameters loaded')

        tb_path = os.path.join(self.args.log, 'tensorboard')
        if os.path.exists(tb_path):
            shutil.rmtree(tb_path)

        os.makedirs(tb_path)
        tb_logger = SummaryWriter(log_dir=tb_path)

        def debug_sample(net, noisy_image):
            # Fill the second half pixel by pixel, conditioned on noisy_image,
            # sampling from `net` itself (the EMA copy when EMA is enabled).
            net.eval()
            with torch.no_grad():
                data = torch.cat([noisy_image, noisy_image], dim=2)
                for i in range(obs[1], obs[1] * 2):
                    for j in range(obs[2]):
                        out_sample = sample_op(data, net)
                        data[:, :, i, j] = out_sample.data[:, :, i, j]
                return data

        print('starting training', flush=True)
        writes = 0
        for epoch in range(start_epoch, self.config.max_epochs):
            train_loss = 0.
            model.train()
            for batch_idx, (batch, _) in enumerate(train_loader):
                batch = batch.to(self.config.device, non_blocking=True)

                # input: [-1, 1]
                ## add noise to the entire image
                noisy_input = batch + torch.randn_like(
                    batch) * self.config.noise
                clean_input = batch + torch.randn_like(
                    batch) * self.config.clean_noise  # add very small noise
                batch = torch.cat([noisy_input, clean_input], dim=2)
                output = model(batch)[:, :, batch.shape[-1]:, :]
                loss = loss_op(clean_input, output)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if self.config.model.ema:
                    ema_helper.update(model)

                train_loss += loss.item()
                if (batch_idx + 1) % self.config.print_every == 0:
                    deno = self.config.print_every * self.config.batch_size * np.prod(
                        obs) * np.log(2.)
                    train_loss = train_loss / deno
                    print('epoch: {}, batch: {}, loss : {:.4f}'.format(
                        epoch, batch_idx, train_loss),
                          flush=True)
                    tb_logger.add_scalar('loss',
                                         train_loss,
                                         global_step=writes)
                    train_loss = 0.
                    writes += 1
                # decrease learning rate
                scheduler.step()

            if self.config.model.ema:
                test_model = ema_helper.ema_copy(model)
            else:
                test_model = model

            test_model.eval()
            test_loss = 0.
            n_test_batches = 0
            with torch.no_grad():
                for input_var, _ in test_loader:
                    input_var = input_var.to(self.config.device, non_blocking=True)

                    noisy_input_var = input_var + torch.randn_like(
                        input_var) * self.config.noise
                    clean_input_var = input_var + torch.randn_like(
                        input_var) * self.config.clean_noise  #* 0.02

                    input_var = torch.cat([noisy_input_var, clean_input_var],
                                          dim=2)
                    output = test_model(input_var)[:, :,
                                                   input_var.shape[-1]:, :]
                    loss = loss_op(clean_input_var, output)
                    test_loss += loss.item()
                    n_test_batches += 1
                    del loss, output

                if n_test_batches > 0:
                    # Normalise by the number of batches seen (the last index
                    # `batch_idx` was one too few).
                    deno = n_test_batches * self.config.batch_size * np.prod(
                        obs) * np.log(2.)
                    test_loss = test_loss / deno
                    print('epoch: %s, test loss : %s' % (epoch, test_loss),
                          flush=True)
                    tb_logger.add_scalar('test_loss',
                                         test_loss,
                                         global_step=writes)

            if (epoch + 1) % self.config.save_interval == 0:
                state_dict = [
                    model.state_dict(),
                    optimizer.state_dict(),
                    scheduler.state_dict(),
                    epoch,
                ]
                if self.config.model.ema:
                    state_dict.append(ema_helper.state_dict())

                if (epoch + 1) % (self.config.save_interval * 2) == 0:
                    torch.save(
                        state_dict,
                        os.path.join(ckpt_path, f'ckpt_epoch_{epoch}.pth'))
                torch.save(state_dict, os.path.join(ckpt_path,
                                                    'checkpoint.pth'))

            # Sampling needs a test batch to condition on.
            if epoch % 10 == 0 and n_test_batches > 0:
                print('sampling...', flush=True)
                sample_t = debug_sample(test_model, noisy_input_var[:25])
                sample_t = torch.cat([clean_input_var[:25], sample_t],
                                     dim=2)  #add original sample
                if self.config.with_logit is True:
                    sample_t = sigmoid_transform(sample_t)
                else:
                    sample_t = rescaling_inv(sample_t)

                if not os.path.exists(os.path.join(self.args.log, 'images')):
                    os.makedirs(os.path.join(self.args.log, 'images'))
                utils.save_image(sample_t,
                                 os.path.join(self.args.log, 'images',
                                              f'sample_epoch_{epoch}.png'),
                                 nrow=5,
                                 padding=0)
            if self.config.model.ema:
                del test_model

        tb_logger.close()
Example #7
0
    def fast_fid(self):
        """Compute the FID of Langevin samples for a range of checkpoints.

        With ``config.fast_fid.ensemble`` set, this delegates to
        ``fast_ensemble_fid`` instead. Ensembling is refused for EMA models.

        Otherwise, for every ``ckpt`` in ``range(config.fast_fid.begin_ckpt,
        config.fast_fid.end_ckpt + 1, 5000)``:
          * ``checkpoint_{ckpt}.pth`` is loaded, with EMA weights when enabled.
          * ``num_samples // batch_size`` batches are generated with
            ``anneal_Langevin_dynamics``.
          * The samples are written as ``sample_{k}.png`` (one unique index per
            image) into ``<args.image_folder>/ckpt_{ckpt}``.
          * The FID is computed against the dataset statistics.

        The ``{ckpt: fid}`` mapping is pickled to
        ``<args.image_folder>/fids.pickle``.

        Raises:
            RuntimeError: if ensembling is requested for a model with EMA.
        """
        ### Test the fids of ensembled checkpoints.
        ### Shouldn't be used for models with ema
        if self.config.fast_fid.ensemble:
            if self.config.model.ema:
                raise RuntimeError("Cannot apply ensembling to models with EMA.")
            self.fast_ensemble_fid()
            return

        from evaluation.fid_score import get_fid, get_fid_stats_path
        import pickle
        score = get_model(self.config)
        score = torch.nn.DataParallel(score)

        sigmas_th = get_sigmas(self.config)
        sigmas = sigmas_th.cpu().numpy()

        batch_size = self.config.fast_fid.batch_size
        fids = {}
        for ckpt in tqdm.tqdm(range(self.config.fast_fid.begin_ckpt, self.config.fast_fid.end_ckpt + 1, 5000),
                              desc="processing ckpt"):
            states = torch.load(os.path.join(self.args.log_path, f'checkpoint_{ckpt}.pth'),
                                map_location=self.config.device)

            if self.config.model.ema:
                ema_helper = EMAHelper(mu=self.config.model.ema_rate)
                ema_helper.register(score)
                ema_helper.load_state_dict(states[-1])
                ema_helper.ema(score)
            else:
                score.load_state_dict(states[0])

            score.eval()

            num_iters = self.config.fast_fid.num_samples // batch_size
            output_path = os.path.join(self.args.image_folder, 'ckpt_{}'.format(ckpt))
            os.makedirs(output_path, exist_ok=True)
            for i in range(num_iters):
                init_samples = torch.rand(batch_size, self.config.data.channels,
                                          self.config.data.image_size, self.config.data.image_size,
                                          device=self.config.device)
                init_samples = data_transform(self.config, init_samples)

                all_samples = anneal_Langevin_dynamics(init_samples, score, sigmas,
                                                       self.config.fast_fid.n_steps_each,
                                                       self.config.fast_fid.step_lr,
                                                       verbose=self.config.fast_fid.verbose,
                                                       denoise=self.config.sampling.denoise)

                final_samples = all_samples[-1]
                for k, sample in enumerate(final_samples):
                    sample = sample.view(self.config.data.channels,
                                         self.config.data.image_size,
                                         self.config.data.image_size)

                    sample = inverse_data_transform(self.config, sample)

                    # Offset by the batch index so every batch writes new files.
                    # Previously each batch overwrote sample_0..sample_{B-1}, so
                    # FID only ever saw one batch of images.
                    sample_idx = i * batch_size + k
                    save_image(sample, os.path.join(output_path, 'sample_{}.png'.format(sample_idx)))

            stat_path = get_fid_stats_path(self.args, self.config, download=True)
            fid = get_fid(stat_path, output_path)
            fids[ckpt] = fid
            print("ckpt: {}, fid: {}".format(ckpt, fid))

        with open(os.path.join(self.args.image_folder, 'fids.pickle'), 'wb') as handle:
            pickle.dump(fids, handle, protocol=pickle.HIGHEST_PROTOCOL)
Example #8
0
    def train(self):
        """Train the score network with annealed denoising score matching.

        Each step, ``anneal_dsm_score_estimation`` is minimised on a training
        batch and the loss is logged.

        Every 100 steps, one test batch is evaluated without gradients, using
        the EMA copy when ``config.model.ema`` is set. Per-sigma test losses
        are logged when ``config.training.log_all_sigmas`` is set.

        Every ``config.training.snapshot_freq`` steps:
          * ``[score, optimizer, epoch, step(, ema)]`` is saved to
            ``checkpoint_{step}.pth`` and ``checkpoint.pth``;
          * if ``config.training.snapshot_sampling`` is set, a 6x6 grid of
            Langevin samples is written to ``args.log_sample_path``.

        ``args.resume_training`` restores from ``checkpoint.pth``.

        Returns:
            0 once ``step`` reaches ``config.training.n_iters``. ``None`` if the
            epoch budget runs out first.
        """
        dataset, test_dataset = get_dataset(self.args, self.config)
        dataloader = DataLoader(dataset, batch_size=self.config.training.batch_size, shuffle=True,
                                num_workers=self.config.data.num_workers)
        test_loader = DataLoader(test_dataset, batch_size=self.config.training.batch_size, shuffle=True,
                                 num_workers=self.config.data.num_workers, drop_last=True)
        test_iter = iter(test_loader)
        self.config.input_dim = self.config.data.image_size ** 2 * self.config.data.channels

        tb_logger = self.config.tb_logger

        score = get_model(self.config)

        score = torch.nn.DataParallel(score)
        optimizer = get_optimizer(self.config, score.parameters())

        start_epoch = 0
        step = 0

        if self.config.model.ema:
            ema_helper = EMAHelper(mu=self.config.model.ema_rate)
            ema_helper.register(score)

        if self.args.resume_training:
            # map_location: load onto the configured device regardless of
            # where the checkpoint was saved (as test/fast_fid already do).
            states = torch.load(os.path.join(self.args.log_path, 'checkpoint.pth'),
                                map_location=self.config.device)
            score.load_state_dict(states[0])
            ### Make sure we can resume with different eps
            states[1]['param_groups'][0]['eps'] = self.config.optim.eps
            optimizer.load_state_dict(states[1])
            start_epoch = states[2]
            step = states[3]
            if self.config.model.ema:
                ema_helper.load_state_dict(states[4])

        sigmas = get_sigmas(self.config)

        if self.config.training.log_all_sigmas:
            ### Commented out training time logging to save time.
            test_loss_per_sigma = [None for _ in range(len(sigmas))]

            def hook(loss, labels):
                # for i in range(len(sigmas)):
                #     if torch.any(labels == i):
                #         test_loss_per_sigma[i] = torch.mean(loss[labels == i])
                pass

            def tb_hook():
                # for i in range(len(sigmas)):
                #     if test_loss_per_sigma[i] is not None:
                #         tb_logger.add_scalar('test_loss_sigma_{}'.format(i), test_loss_per_sigma[i],
                #                              global_step=step)
                pass

            def test_hook(loss, labels):
                # Record the mean test loss of every noise level present in the batch.
                for i in range(len(sigmas)):
                    if torch.any(labels == i):
                        test_loss_per_sigma[i] = torch.mean(loss[labels == i])

            def test_tb_hook():
                # Reads the enclosing `step` at call time.
                for i in range(len(sigmas)):
                    if test_loss_per_sigma[i] is not None:
                        tb_logger.add_scalar('test_loss_sigma_{}'.format(i), test_loss_per_sigma[i],
                                             global_step=step)

        else:
            hook = test_hook = None

            def tb_hook():
                pass

            def test_tb_hook():
                pass

        for epoch in range(start_epoch, self.config.training.n_epochs):
            for i, (X, y) in enumerate(dataloader):
                score.train()
                step += 1

                X = X.to(self.config.device)
                X = data_transform(self.config, X)

                loss = anneal_dsm_score_estimation(score, X, sigmas, None,
                                                   self.config.training.anneal_power,
                                                   hook)
                tb_logger.add_scalar('loss', loss, global_step=step)
                tb_hook()

                logging.info("step: {}, loss: {}".format(step, loss.item()))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if self.config.model.ema:
                    ema_helper.update(score)

                if step >= self.config.training.n_iters:
                    return 0

                if step % 100 == 0:
                    if self.config.model.ema:
                        test_score = ema_helper.ema_copy(score)
                    else:
                        test_score = score

                    test_score.eval()
                    # Cycle through the test set, restarting when exhausted.
                    try:
                        test_X, test_y = next(test_iter)
                    except StopIteration:
                        test_iter = iter(test_loader)
                        test_X, test_y = next(test_iter)

                    test_X = test_X.to(self.config.device)
                    test_X = data_transform(self.config, test_X)

                    with torch.no_grad():
                        test_dsm_loss = anneal_dsm_score_estimation(test_score, test_X, sigmas, None,
                                                                    self.config.training.anneal_power,
                                                                    hook=test_hook)
                        tb_logger.add_scalar('test_loss', test_dsm_loss, global_step=step)
                        test_tb_hook()
                        logging.info("step: {}, test_loss: {}".format(step, test_dsm_loss.item()))

                        del test_score

                if step % self.config.training.snapshot_freq == 0:
                    states = [
                        score.state_dict(),
                        optimizer.state_dict(),
                        epoch,
                        step,
                    ]
                    if self.config.model.ema:
                        states.append(ema_helper.state_dict())

                    torch.save(states, os.path.join(self.args.log_path, 'checkpoint_{}.pth'.format(step)))
                    torch.save(states, os.path.join(self.args.log_path, 'checkpoint.pth'))

                    if self.config.training.snapshot_sampling:
                        if self.config.model.ema:
                            test_score = ema_helper.ema_copy(score)
                        else:
                            test_score = score

                        test_score.eval()

                        ## Different part from NeurIPS 2019.
                        ## Random state will be affected because of sampling during training time.
                        init_samples = torch.rand(36, self.config.data.channels,
                                                  self.config.data.image_size, self.config.data.image_size,
                                                  device=self.config.device)
                        init_samples = data_transform(self.config, init_samples)

                        all_samples = anneal_Langevin_dynamics(init_samples, test_score, sigmas.cpu().numpy(),
                                                               self.config.sampling.n_steps_each,
                                                               self.config.sampling.step_lr,
                                                               final_only=True, verbose=True,
                                                               denoise=self.config.sampling.denoise)

                        sample = all_samples[-1].view(all_samples[-1].shape[0], self.config.data.channels,
                                                      self.config.data.image_size,
                                                      self.config.data.image_size)

                        sample = inverse_data_transform(self.config, sample)

                        image_grid = make_grid(sample, 6)
                        save_image(image_grid,
                                   os.path.join(self.args.log_sample_path, 'image_grid_{}.png'.format(step)))
                        torch.save(sample, os.path.join(self.args.log_sample_path, 'samples_{}.pth'.format(step)))

                        del test_score
                        del all_samples
Example #9
0
    def sample(self):
        if self.config.sampling.ckpt_id is None:
            states = torch.load(os.path.join(self.args.log_path, 'checkpoint.pth'), map_location=self.config.device)
        else:
            states = torch.load(os.path.join(self.args.log_path, f'checkpoint_{self.config.sampling.ckpt_id}.pth'),
                                map_location=self.config.device)

        score = get_model(self.config)
        score = torch.nn.DataParallel(score)

        score.load_state_dict(states[0], strict=True)

        if self.config.model.ema:
            ema_helper = EMAHelper(mu=self.config.model.ema_rate)
            ema_helper.register(score)
            ema_helper.load_state_dict(states[-1])
            ema_helper.ema(score)

        sigmas_th = get_sigmas(self.config)
        sigmas = sigmas_th.cpu().numpy()

        dataset, _ = get_dataset(self.args, self.config)
        dataloader = DataLoader(dataset, batch_size=self.config.sampling.batch_size, shuffle=True,
                                num_workers=4)

        score.eval()

        if not self.config.sampling.fid:
            if self.config.sampling.inpainting:
                data_iter = iter(dataloader)
                refer_images, _ = next(data_iter)
                refer_images = refer_images.to(self.config.device)
                width = int(np.sqrt(self.config.sampling.batch_size))
                init_samples = torch.rand(width, width, self.config.data.channels,
                                          self.config.data.image_size,
                                          self.config.data.image_size,
                                          device=self.config.device)
                init_samples = data_transform(self.config, init_samples)
                all_samples = anneal_Langevin_dynamics_inpainting(init_samples, refer_images[:width, ...], score,
                                                                  sigmas,
                                                                  self.config.data.image_size,
                                                                  self.config.sampling.n_steps_each,
                                                                  self.config.sampling.step_lr)

                torch.save(refer_images[:width, ...], os.path.join(self.args.image_folder, 'refer_image.pth'))
                refer_images = refer_images[:width, None, ...].expand(-1, width, -1, -1, -1).reshape(-1,
                                                                                                     *refer_images.shape[
                                                                                                      1:])
                save_image(refer_images, os.path.join(self.args.image_folder, 'refer_image.png'), nrow=width)

                if not self.config.sampling.final_only:
                    for i, sample in enumerate(tqdm.tqdm(all_samples)):
                        sample = sample.view(self.config.sampling.batch_size, self.config.data.channels,
                                             self.config.data.image_size,
                                             self.config.data.image_size)

                        sample = inverse_data_transform(self.config, sample)

                        image_grid = make_grid(sample, int(np.sqrt(self.config.sampling.batch_size)))
                        save_image(image_grid, os.path.join(self.args.image_folder, 'image_grid_{}.png'.format(i)))
                        torch.save(sample, os.path.join(self.args.image_folder, 'completion_{}.pth'.format(i)))
                else:
                    sample = all_samples[-1].view(self.config.sampling.batch_size, self.config.data.channels,
                                                  self.config.data.image_size,
                                                  self.config.data.image_size)

                    sample = inverse_data_transform(self.config, sample)

                    image_grid = make_grid(sample, int(np.sqrt(self.config.sampling.batch_size)))
                    save_image(image_grid, os.path.join(self.args.image_folder,
                                                        'image_grid_{}.png'.format(self.config.sampling.ckpt_id)))
                    torch.save(sample, os.path.join(self.args.image_folder,
                                                    'completion_{}.pth'.format(self.config.sampling.ckpt_id)))

            elif self.config.sampling.interpolation:
                if self.config.sampling.data_init:
                    data_iter = iter(dataloader)
                    samples, _ = next(data_iter)
                    samples = samples.to(self.config.device)
                    samples = data_transform(self.config, samples)
                    init_samples = samples + sigmas_th[0] * torch.randn_like(samples)

                else:
                    init_samples = torch.rand(self.config.sampling.batch_size, self.config.data.channels,
                                              self.config.data.image_size, self.config.data.image_size,
                                              device=self.config.device)
                    init_samples = data_transform(self.config, init_samples)

                all_samples = anneal_Langevin_dynamics_interpolation(init_samples, score, sigmas,
                                                                     self.config.sampling.n_interpolations,
                                                                     self.config.sampling.n_steps_each,
                                                                     self.config.sampling.step_lr, verbose=True,
                                                                     final_only=self.config.sampling.final_only)

                if not self.config.sampling.final_only:
                    for i, sample in tqdm.tqdm(enumerate(all_samples), total=len(all_samples),
                                               desc="saving image samples"):
                        sample = sample.view(sample.shape[0], self.config.data.channels,
                                             self.config.data.image_size,
                                             self.config.data.image_size)

                        sample = inverse_data_transform(self.config, sample)

                        image_grid = make_grid(sample, nrow=self.config.sampling.n_interpolations)
                        save_image(image_grid, os.path.join(self.args.image_folder, 'image_grid_{}.png'.format(i)))
                        torch.save(sample, os.path.join(self.args.image_folder, 'samples_{}.pth'.format(i)))
                else:
                    sample = all_samples[-1].view(all_samples[-1].shape[0], self.config.data.channels,
                                                  self.config.data.image_size,
                                                  self.config.data.image_size)

                    sample = inverse_data_transform(self.config, sample)

                    image_grid = make_grid(sample, self.config.sampling.n_interpolations)
                    save_image(image_grid, os.path.join(self.args.image_folder,
                                                        'image_grid_{}.png'.format(self.config.sampling.ckpt_id)))
                    torch.save(sample, os.path.join(self.args.image_folder,
                                                    'samples_{}.pth'.format(self.config.sampling.ckpt_id)))

            else:
                if self.config.sampling.data_init:
                    data_iter = iter(dataloader)
                    samples, _ = next(data_iter)
                    samples = samples.to(self.config.device)
                    samples = data_transform(self.config, samples)
                    init_samples = samples + sigmas_th[0] * torch.randn_like(samples)

                else:
                    init_samples = torch.rand(self.config.sampling.batch_size, self.config.data.channels,
                                              self.config.data.image_size, self.config.data.image_size,
                                              device=self.config.device)
                    init_samples = data_transform(self.config, init_samples)

                all_samples = anneal_Langevin_dynamics(init_samples, score, sigmas,
                                                       self.config.sampling.n_steps_each,
                                                       self.config.sampling.step_lr, verbose=True,
                                                       final_only=self.config.sampling.final_only,
                                                       denoise=self.config.sampling.denoise)

                if not self.config.sampling.final_only:
                    for i, sample in tqdm.tqdm(enumerate(all_samples), total=len(all_samples),
                                               desc="saving image samples"):
                        sample = sample.view(sample.shape[0], self.config.data.channels,
                                             self.config.data.image_size,
                                             self.config.data.image_size)

                        sample = inverse_data_transform(self.config, sample)

                        image_grid = make_grid(sample, int(np.sqrt(self.config.sampling.batch_size)))
                        save_image(image_grid, os.path.join(self.args.image_folder, 'image_grid_{}.png'.format(i)))
                        torch.save(sample, os.path.join(self.args.image_folder, 'samples_{}.pth'.format(i)))
                else:
                    sample = all_samples[-1].view(all_samples[-1].shape[0], self.config.data.channels,
                                                  self.config.data.image_size,
                                                  self.config.data.image_size)

                    sample = inverse_data_transform(self.config, sample)

                    image_grid = make_grid(sample, int(np.sqrt(self.config.sampling.batch_size)))
                    save_image(image_grid, os.path.join(self.args.image_folder,
                                                        'image_grid_{}.png'.format(self.config.sampling.ckpt_id)))
                    torch.save(sample, os.path.join(self.args.image_folder,
                                                    'samples_{}.pth'.format(self.config.sampling.ckpt_id)))

        else:
            total_n_samples = self.config.sampling.num_samples4fid
            n_rounds = total_n_samples // self.config.sampling.batch_size
            if self.config.sampling.data_init:
                dataloader = DataLoader(dataset, batch_size=self.config.sampling.batch_size, shuffle=True,
                                        num_workers=4)
                data_iter = iter(dataloader)

            img_id = 0
            for _ in tqdm.tqdm(range(n_rounds), desc='Generating image samples for FID/inception score evaluation'):
                if self.config.sampling.data_init:
                    try:
                        samples, _ = next(data_iter)
                    except StopIteration:
                        data_iter = iter(dataloader)
                        samples, _ = next(data_iter)
                    samples = samples.to(self.config.device)
                    samples = data_transform(self.config, samples)
                    samples = samples + sigmas_th[0] * torch.randn_like(samples)
                else:
                    samples = torch.rand(self.config.sampling.batch_size, self.config.data.channels,
                                         self.config.data.image_size,
                                         self.config.data.image_size, device=self.config.device)
                    samples = data_transform(self.config, samples)

                all_samples = anneal_Langevin_dynamics(samples, score, sigmas,
                                                       self.config.sampling.n_steps_each,
                                                       self.config.sampling.step_lr, verbose=False,
                                                       denoise=self.config.sampling.denoise)

                samples = all_samples[-1]
                for img in samples:
                    img = inverse_data_transform(self.config, img)

                    save_image(img, os.path.join(self.args.image_folder, 'image_{}.png'.format(img_id)))
                    img_id += 1
Example #10
0
    def train(self):
        """Train the model on the training split from ``get_dataset`` using Paddle.

        Wraps ``Model(config)`` in ``paddle.DataParallel``. If
        ``config.model.ema`` is set, it also tracks an EMA copy of the weights.
        If ``args.resume_training`` is set, it restores model/optimizer/EMA
        state, epoch and step from ``<log_path>/ckpt.pdl``.

        Each step does the following:
        - draws Gaussian noise and antithetically paired timesteps;
        - evaluates ``loss_registry[config.model.type]``;
        - logs the loss to ``config.vdl_logger`` and ``logging``;
        - takes one optimizer step.

        Every ``training.snapshot_freq`` steps (and at step 1), the full state
        is written to both ``ckpt_<step>.pdl`` and ``ckpt.pdl`` in ``log_path``.
        Keys are namespaced with ``$model_`` / ``$optimizer_`` / ``$ema_``
        prefixes, so one flat dict holds every component.
        """
        args, config = self.args, self.config
        vdl_logger = self.config.vdl_logger
        dataset, test_dataset = get_dataset(args, config)
        train_loader = data.DataLoader(
            dataset,
            batch_size=config.training.batch_size,
            shuffle=True,
            num_workers=config.data.num_workers,
            use_shared_memory=False,
        )
        model = paddle.DataParallel(Model(config))

        optimizer = get_optimizer(self.config, model.parameters())

        if self.config.model.ema:
            ema_helper = EMAHelper(mu=self.config.model.ema_rate)
            ema_helper.register(model)
        else:
            ema_helper = None

        def _unprefixed(states, prefix):
            # Pick the entries saved under `prefix` and strip the prefix off.
            return {
                k.split(prefix)[-1]: v
                for k, v in states.items() if prefix in k
            }

        start_epoch, step = 0, 0
        if self.args.resume_training:
            states = paddle.load(os.path.join(self.args.log_path, "ckpt.pdl"))
            model.set_state_dict(_unprefixed(states, "$model_"))
            optimizer.set_state_dict(_unprefixed(states, "$optimizer_"))
            # Re-apply the configured eps after restoring optimizer state.
            # NOTE(review): presumably the restored state does not carry it;
            # confirm.
            optimizer._epsilon = self.config.optim.eps
            start_epoch = states["$epoch"]
            step = states["$step"]
            if self.config.model.ema:
                ema_helper.set_state_dict(_unprefixed(states, "$ema_"))

        for epoch in range(start_epoch, self.config.training.n_epochs):
            data_start = time.time()
            data_time = 0
            for i, (x, _) in enumerate(train_loader):
                n = x.shape[0]
                data_time += time.time() - data_start
                model.train()
                step += 1

                x = data_transform(self.config, x)
                e = paddle.randn(x.shape)
                b = self.betas

                # Antithetic sampling: pair every timestep t with
                # num_timesteps - 1 - t, then trim back to the batch size.
                t = paddle.randint(low=0,
                                   high=self.num_timesteps,
                                   shape=(n // 2 + 1, ))
                t = paddle.concat([t, self.num_timesteps - t - 1], 0)[:n]
                loss = loss_registry[config.model.type](model, x, t, e, b)

                # `.item()` works for both 0-D and 1-element tensors, unlike
                # `.numpy()[0]`, and gives the logger a plain Python float.
                loss_value = loss.item()
                vdl_logger.add_scalar("loss", loss_value, step=step)

                logging.info(
                    f"step: {step}, loss: {loss_value}, data time: {data_time / (i+1)}"
                )

                optimizer.clear_grad()
                loss.backward()
                optimizer.step()

                if self.config.model.ema:
                    ema_helper.update(model)

                if step % self.config.training.snapshot_freq == 0 or step == 1:
                    states = {
                        "$model_" + k: v
                        for k, v in model.state_dict().items()
                    }
                    states.update({
                        "$optimizer_" + k: v
                        for k, v in optimizer.state_dict().items()
                    })
                    states["$epoch"] = epoch
                    states["$step"] = step
                    if self.config.model.ema:
                        states.update({
                            "$ema_" + k: v
                            for k, v in ema_helper.state_dict().items()
                        })

                    # Keep a numbered snapshot plus a rolling "latest"
                    # checkpoint, which the resume branch above reads.
                    paddle.save(
                        states,
                        os.path.join(self.args.log_path,
                                     "ckpt_{}.pdl".format(step)),
                    )
                    paddle.save(states,
                                os.path.join(self.args.log_path, "ckpt.pdl"))

                data_start = time.time()
    def calculate_fid(self):
        """Sample from a range of checkpoints and score each one with FID.

        Processes every ``checkpoint_<ckpt>.pth`` in ``args.log_path`` for
        ckpt in ``[fast_fid.begin_ckpt, fast_fid.end_ckpt]`` (stride 5000).
        For each checkpoint it:
        - loads the (optionally EMA) weights;
        - draws ``fast_fid.num_samples // fast_fid.batch_size`` batches with
          annealed Langevin dynamics;
        - writes every sample to ``<image_folder>/ckpt_<ckpt>/sample_<n>.png``;
        - computes the FID of all generated samples against the precomputed
          statistics in ``fid_stats_cifar10_train.npz``.

        The results are pickled to ``<image_folder>/fids.pickle`` as a
        ``{ckpt: fid_value}`` dict.

        Raises:
            ValueError: if ``fast_fid.num_samples`` is smaller than
                ``fast_fid.batch_size``, since no samples would be generated.
        """
        import pickle

        import fid
        import tensorflow as tf

        batch_size = self.config.fast_fid.batch_size
        num_iters = self.config.fast_fid.num_samples // batch_size
        if num_iters == 0:
            raise ValueError(
                "fast_fid.num_samples must be >= fast_fid.batch_size"
            )

        stats_path = "fid_stats_cifar10_train.npz"  # training set statistics
        inception_path = fid.check_or_download_inception(
            "./tmp/"
        )  # download inception network

        # Neither the reference statistics nor the Inception graph depend on
        # the checkpoint. Load and build them once; re-importing the graph
        # every iteration would keep adding duplicate nodes to the TF graph.
        with np.load(stats_path) as f:
            mu_real, sigma_real = f["mu"][:], f["sigma"][:]
        fid.create_inception_graph(
            inception_path
        )  # load the graph into the current TF graph

        score = get_model(self.config)
        score = torch.nn.DataParallel(score)

        sigmas_th = get_sigmas(self.config)
        sigmas = sigmas_th.cpu().numpy()

        channels = self.config.data.channels
        image_size = self.config.data.image_size

        fids = {}
        for ckpt in tqdm.tqdm(
            range(
                self.config.fast_fid.begin_ckpt, self.config.fast_fid.end_ckpt + 1, 5000
            ),
            desc="processing ckpt",
        ):
            states = torch.load(
                os.path.join(self.args.log_path, f"checkpoint_{ckpt}.pth"),
                map_location=self.config.device,
            )

            if self.config.model.ema:
                ema_helper = EMAHelper(mu=self.config.model.ema_rate)
                ema_helper.register(score)
                ema_helper.load_state_dict(states[-1])
                ema_helper.ema(score)
            else:
                score.load_state_dict(states[0])

            score.eval()

            output_path = os.path.join(self.args.image_folder, "ckpt_{}".format(ckpt))
            os.makedirs(output_path, exist_ok=True)

            generated = []  # every batch of images, so FID uses all samples
            for i in range(num_iters):
                init_samples = torch.rand(
                    batch_size,
                    channels,
                    image_size,
                    image_size,
                    device=self.config.device,
                )
                init_samples = data_transform(self.config, init_samples)

                all_samples = anneal_Langevin_dynamics(
                    init_samples,
                    score,
                    sigmas,
                    self.config.fast_fid.n_steps_each,
                    self.config.fast_fid.step_lr,
                    verbose=self.config.fast_fid.verbose,
                )

                final_samples = all_samples[-1].view(
                    -1, channels, image_size, image_size
                )
                final_samples = inverse_data_transform(self.config, final_samples)

                for j, sample in enumerate(final_samples):
                    # Use a global index so later batches do not overwrite
                    # the files written by earlier ones.
                    sample_idx = i * batch_size + j
                    save_image(
                        sample,
                        os.path.join(output_path, "sample_{}.png".format(sample_idx)),
                    )
                generated.append(final_samples.detach().cpu())

            # The FID code expects NHWC images with values in [0, 255]. These
            # are the same tensors handed to save_image, which treats its
            # input as [0, 1] images. NOTE(review): assumes
            # inverse_data_transform returns values in that range; clamp
            # just in case.
            images = torch.cat(generated, dim=0).clamp(0.0, 1.0) * 255.0
            images = images.permute(0, 2, 3, 1).numpy()

            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                mu_gen, sigma_gen = fid.calculate_activation_statistics(
                    images, sess, batch_size=100
                )

            fid_value = fid.calculate_frechet_distance(
                mu_gen, sigma_gen, mu_real, sigma_real
            )
            fids[ckpt] = fid_value
            logging.info("ckpt: {}, FID: {}".format(ckpt, fid_value))

        with open(os.path.join(self.args.image_folder, "fids.pickle"), "wb") as handle:
            pickle.dump(fids, handle, protocol=pickle.HIGHEST_PROTOCOL)