Example #1
    def sample(self):
        model = Model(self.config)

        if not self.args.use_pretrained:
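            # Checkpoints are stored as a single flat dict whose keys are
            # namespaced with "$model_", "$optimizer_", "$ema_", ... prefixes;
            # the prefixes are stripped below when restoring each component.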
            if getattr(self.config.sampling, "ckpt_id", None) is None:
                states = paddle.load(
                    os.path.join(self.args.log_path, "ckpt.pdl"))
            else:
                states = paddle.load(
                    os.path.join(self.args.log_path,
                                 f"ckpt_{self.config.sampling.ckpt_id}.pdl"))
            model = paddle.DataParallel(model)
            model.set_state_dict({
                k.split("$model_")[-1]: v
                for k, v in states.items() if "$model_" in k
            })

            if self.config.model.ema:
                ema_helper = EMAHelper(mu=self.config.model.ema_rate)
                ema_helper.register(model)
                ema_helper.set_state_dict({
                    k.split("$ema_")[-1]: v
                    for k, v in states.items() if "$ema_" in k
                })
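                # copy the EMA-averaged weights into the model before sampling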
                ema_helper.ema(model)
            else:
                ema_helper = None
        else:
            # This uses the pretrained DDPM model, see https://github.com/pesser/pytorch_diffusion
            if self.config.data.dataset == "CIFAR10":
                name = "cifar10"
            elif self.config.data.dataset == "LSUN":
                name = f"lsun_{self.config.data.category}"
            else:
                raise ValueError(f"Unknown dataset: {self.config.data.dataset}")
            ckpt = get_ckpt_path(f"ema_{name}")
            print("Loading checkpoint {}".format(ckpt))
            model.set_state_dict(paddle.load(ckpt))
            model = paddle.DataParallel(model)

        model.eval()

        if self.args.fid:
            self.sample_fid(model)
        elif self.args.interpolation:
            self.sample_interpolation(model)
        elif self.args.sequence:
            self.sample_sequence(model)
        else:
            raise NotImplementedError("Sample procedure not defined")
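
Both examples call an EMAHelper with register / update / ema / state_dict / set_state_dict methods, but its implementation is not shown here. The following is a minimal Paddle sketch of what a helper with that interface typically looks like; the details (shadow-dict storage, unwrapping paddle.DataParallel) are assumptions inferred from the calls above, not the project's actual code:

    import paddle


    class EMAHelper:
        # Keeps a shadow copy of every trainable parameter and updates it as an
        # exponential moving average with decay mu.
        def __init__(self, mu=0.999):
            self.mu = mu
            self.shadow = {}

        def _unwrap(self, module):
            # the examples register the model after wrapping it in DataParallel
            return module._layers if isinstance(module, paddle.DataParallel) else module

        def register(self, module):
            for name, param in self._unwrap(module).named_parameters():
                if not param.stop_gradient:
                    self.shadow[name] = param.clone()

        def update(self, module):
            # shadow <- mu * shadow + (1 - mu) * current parameter
            for name, param in self._unwrap(module).named_parameters():
                if not param.stop_gradient:
                    self.shadow[name] = (
                        self.mu * self.shadow[name] + (1.0 - self.mu) * param
                    ).detach()

        def ema(self, module):
            # copy the averaged weights into the live model (done before sampling)
            for name, param in self._unwrap(module).named_parameters():
                if not param.stop_gradient:
                    paddle.assign(self.shadow[name], param)

        def state_dict(self):
            return self.shadow

        def set_state_dict(self, state_dict):
            self.shadow = state_dict
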
Example #2
    def train(self):
        args, config = self.args, self.config
        vdl_logger = self.config.vdl_logger
        dataset, test_dataset = get_dataset(args, config)
        train_loader = data.DataLoader(
            dataset,
            batch_size=config.training.batch_size,
            shuffle=True,
            num_workers=config.data.num_workers,
            use_shared_memory=False,
        )
        model = Model(config)

        model = paddle.DataParallel(model)

        optimizer = get_optimizer(self.config, model.parameters())

        if self.config.model.ema:
            ema_helper = EMAHelper(mu=self.config.model.ema_rate)
            ema_helper.register(model)
        else:
            ema_helper = None

        start_epoch, step = 0, 0
        if self.args.resume_training:
            states = paddle.load(os.path.join(self.args.log_path, "ckpt.pdl"))
            model.set_state_dict({
                k.split("$model_")[-1]: v
                for k, v in states.items() if "$model_" in k
            })

            optimizer.set_state_dict({
                k.split("$optimizer_")[-1]: v
                for k, v in states.items() if "$optimizer_" in k
            })
            optimizer._epsilon = self.config.optim.eps  # re-apply the configured eps, which is not part of the optimizer state dict
            start_epoch = states["$epoch"]
            step = states["$step"]
            if self.config.model.ema:
                ema_helper.set_state_dict({
                    k.split("$ema_")[-1]: v
                    for k, v in states.items() if "$ema_" in k
                })

        for epoch in range(start_epoch, self.config.training.n_epochs):
            data_start = time.time()
            data_time = 0
            for i, (x, y) in enumerate(train_loader):
                n = x.shape[0]
                data_time += time.time() - data_start
                model.train()
                step += 1

                x = data_transform(self.config, x)
                e = paddle.randn(x.shape)
                b = self.betas

                # antithetic sampling
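                # each sampled t is paired with its mirror (num_timesteps - 1 - t),
                # so the two halves of the batch cover opposite ends of the
                # noise schedule and the loss estimate has lower variance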
                t = paddle.randint(low=0,
                                   high=self.num_timesteps,
                                   shape=(n // 2 + 1, ))
                t = paddle.concat([t, self.num_timesteps - t - 1], 0)[:n]
                loss = loss_registry[config.model.type](model, x, t, e, b)

                vdl_logger.add_scalar("loss", loss, step=step)

                logging.info(
                    f"step: {step}, loss: {loss.numpy()[0]}, data time: {data_time / (i+1)}"
                )

                optimizer.clear_grad()
                loss.backward()
                optimizer.step()

                if self.config.model.ema:
                    ema_helper.update(model)

                if step % self.config.training.snapshot_freq == 0 or step == 1:
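                    # snapshot everything needed for resume_training into one
                    # flat dict, using the same "$" key prefixes read above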
                    states = dict(
                        **{
                            "$model_" + k: v
                            for k, v in model.state_dict().items()
                        },
                        **{
                            "$optimizer_" + k: v
                            for k, v in optimizer.state_dict().items()
                        },
                        **{"$epoch": epoch},
                        **{"$step": step},
                    )
                    if self.config.model.ema:
                        states.update({
                            "$ema_" + k: v
                            for k, v in ema_helper.state_dict().items()
                        })

                    paddle.save(
                        states,
                        os.path.join(self.args.log_path,
                                     "ckpt_{}.pdl".format(step)),
                    )
                    paddle.save(states,
                                os.path.join(self.args.log_path, "ckpt.pdl"))

                data_start = time.time()
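
In Example #2, loss_registry[config.model.type] resolves to the training objective. For the default "simple" model type in the DDPM/DDIM codebases this is a noise-prediction MSE; the sketch below is a Paddle rendering of that loss under that assumption (the function name and the registry wiring are illustrative, not taken from this project):

    import paddle


    def noise_estimation_loss(model, x0, t, e, b):
        # alpha_bar_t = prod_{s<=t} (1 - beta_s), gathered per example in the batch
        a = paddle.cumprod(1.0 - b, dim=0)
        a = paddle.index_select(a, t, axis=0).reshape([-1, 1, 1, 1])
        # forward diffusion: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
        x = x0 * a.sqrt() + e * (1.0 - a).sqrt()
        # the model predicts the injected noise; train with a plain MSE
        output = model(x, t.astype("float32"))
        return ((e - output) ** 2).sum(axis=[1, 2, 3]).mean()


    loss_registry = {"simple": noise_estimation_loss}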