def sample(self):
    model = Model(self.config)

    if not self.args.use_pretrained:
        # Load a checkpoint produced by train(): either the rolling "ckpt.pdl"
        # or a specific snapshot selected via sampling.ckpt_id.
        if getattr(self.config.sampling, "ckpt_id", None) is None:
            states = paddle.load(
                os.path.join(self.args.log_path, "ckpt.pdl"))
        else:
            states = paddle.load(
                os.path.join(self.args.log_path,
                             f"ckpt_{self.config.sampling.ckpt_id}.pdl"))
        model = paddle.DataParallel(model)
        # Checkpoint keys are namespaced with "$model_" / "$ema_" prefixes;
        # strip the prefix before restoring.
        model.set_state_dict({
            k.split("$model_")[-1]: v
            for k, v in states.items() if "$model_" in k
        })

        if self.config.model.ema:
            ema_helper = EMAHelper(mu=self.config.model.ema_rate)
            ema_helper.register(model)
            ema_helper.set_state_dict({
                k.split("$ema_")[-1]: v
                for k, v in states.items() if "$ema_" in k
            })
            # Replace the model weights with their EMA counterparts before sampling.
            ema_helper.ema(model)
        else:
            ema_helper = None
    else:
        # This uses the pretrained DDPM model, see https://github.com/pesser/pytorch_diffusion
        if self.config.data.dataset == "CIFAR10":
            name = "cifar10"
        elif self.config.data.dataset == "LSUN":
            name = f"lsun_{self.config.data.category}"
        else:
            raise ValueError(
                f"No pretrained checkpoint for dataset {self.config.data.dataset}")
        ckpt = get_ckpt_path(f"ema_{name}")
        print("Loading checkpoint {}".format(ckpt))
        model.set_state_dict(paddle.load(ckpt))
        model = paddle.DataParallel(model)

    model.eval()

    if self.args.fid:
        self.sample_fid(model)
    elif self.args.interpolation:
        self.sample_interpolation(model)
    elif self.args.sequence:
        self.sample_sequence(model)
    else:
        raise NotImplementedError("Sample procedure not defined")
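
# The methods above and below rely on an EMAHelper with register / update /
# ema / state_dict / set_state_dict. The class below is only a minimal
# illustrative sketch of that interface (the name EMAHelperSketch and every
# implementation detail are assumptions, not the project's actual EMAHelper).
class EMAHelperSketch:
    def __init__(self, mu=0.999):
        self.mu = mu      # decay rate of the moving average
        self.shadow = {}  # parameter name -> EMA copy of that parameter

    def register(self, module):
        # Snapshot every trainable parameter as the initial EMA value.
        for name, param in module.named_parameters():
            if not param.stop_gradient:
                self.shadow[name] = param.clone().detach()

    def update(self, module):
        # shadow = (1 - mu) * param + mu * shadow, called after each optimizer step.
        for name, param in module.named_parameters():
            if not param.stop_gradient:
                new_avg = (1.0 - self.mu) * param + self.mu * self.shadow[name]
                self.shadow[name] = new_avg.detach()

    def ema(self, module):
        # Copy the EMA weights back into the module (used before sampling).
        for name, param in module.named_parameters():
            if not param.stop_gradient:
                param.set_value(self.shadow[name])

    def state_dict(self):
        return self.shadow

    def set_state_dict(self, state_dict):
        self.shadow = state_dict
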
def train(self):
    args, config = self.args, self.config
    vdl_logger = self.config.vdl_logger
    dataset, test_dataset = get_dataset(args, config)
    train_loader = data.DataLoader(
        dataset,
        batch_size=config.training.batch_size,
        shuffle=True,
        num_workers=config.data.num_workers,
        use_shared_memory=False,
    )
    model = Model(config)
    model = paddle.DataParallel(model)

    optimizer = get_optimizer(self.config, model.parameters())

    if self.config.model.ema:
        ema_helper = EMAHelper(mu=self.config.model.ema_rate)
        ema_helper.register(model)
    else:
        ema_helper = None

    start_epoch, step = 0, 0
    if self.args.resume_training:
        # Restore model, optimizer, EMA state and the training position from the
        # latest checkpoint; keys are namespaced with "$model_" / "$optimizer_" / "$ema_".
        states = paddle.load(os.path.join(self.args.log_path, "ckpt.pdl"))
        model.set_state_dict({
            k.split("$model_")[-1]: v
            for k, v in states.items() if "$model_" in k
        })
        optimizer.set_state_dict({
            k.split("$optimizer_")[-1]: v
            for k, v in states.items() if "$optimizer_" in k
        })
        optimizer._epsilon = self.config.optim.eps
        start_epoch = states["$epoch"]
        step = states["$step"]
        if self.config.model.ema:
            ema_helper.set_state_dict({
                k.split("$ema_")[-1]: v
                for k, v in states.items() if "$ema_" in k
            })

    for epoch in range(start_epoch, self.config.training.n_epochs):
        data_start = time.time()
        data_time = 0
        for i, (x, y) in enumerate(train_loader):
            n = x.shape[0]
            data_time += time.time() - data_start
            model.train()
            step += 1

            x = data_transform(self.config, x)
            e = paddle.randn(x.shape)
            b = self.betas

            # Antithetic sampling: draw half the timesteps uniformly and pair
            # each t with (T - 1 - t) to reduce the variance of the loss estimate.
            t = paddle.randint(
                low=0, high=self.num_timesteps, shape=(n // 2 + 1, ))
            t = paddle.concat([t, self.num_timesteps - t - 1], 0)[:n]
            loss = loss_registry[config.model.type](model, x, t, e, b)

            vdl_logger.add_scalar("loss", loss, step=step)

            logging.info(
                f"step: {step}, loss: {loss.numpy()[0]}, data time: {data_time / (i+1)}"
            )

            optimizer.clear_grad()
            loss.backward()
            optimizer.step()

            # Keep an exponential moving average of the weights in step with training.
            if self.config.model.ema:
                ema_helper.update(model)

            if step % self.config.training.snapshot_freq == 0 or step == 1:
                # Save a namespaced snapshot both under its step number and as
                # the rolling "ckpt.pdl" used for resuming and sampling.
                states = dict(
                    **{"$model_" + k: v
                       for k, v in model.state_dict().items()},
                    **{"$optimizer_" + k: v
                       for k, v in optimizer.state_dict().items()},
                    **{"$epoch": epoch},
                    **{"$step": step},
                )
                if self.config.model.ema:
                    states.update({
                        "$ema_" + k: v
                        for k, v in ema_helper.state_dict().items()
                    })

                paddle.save(
                    states,
                    os.path.join(self.args.log_path,
                                 "ckpt_{}.pdl".format(step)),
                )
                paddle.save(states,
                            os.path.join(self.args.log_path, "ckpt.pdl"))

            data_start = time.time()
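
# train() dispatches the objective through loss_registry[config.model.type].
# The function below is only a minimal sketch of the standard simple
# noise-prediction loss such a registry entry could implement, assuming the
# surrounding file's paddle import; the name noise_estimation_loss_sketch and
# the exact reduction are assumptions, not the registry's actual implementation.
def noise_estimation_loss_sketch(model, x0, t, e, b):
    # alpha_bar_t for each sampled timestep, broadcast over (N, C, H, W).
    a = paddle.cumprod(1.0 - b, dim=0).index_select(t).reshape([-1, 1, 1, 1])
    # Forward-diffuse x0 to timestep t with the pre-drawn noise e.
    x = x0 * a.sqrt() + e * (1.0 - a).sqrt()
    # The model predicts the noise; train with an MSE against the true noise.
    output = model(x, t.astype("float32"))
    return ((e - output) ** 2).sum(axis=[1, 2, 3]).mean()
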