def main():
    args = create_argparser().parse_args()
    pprint({k: v for k, v in args.__dict__.items()})

    dist_util.setup_dist()
    logger.configure()

    logger.log("creating model...")
    pprint(args_to_dict(args, sr_model_and_diffusion_defaults().keys()))
    model, diffusion = sr_create_model_and_diffusion(
        **args_to_dict(args, sr_model_and_diffusion_defaults().keys())
    )
    # skips
    # load_tolerant(model, args.model_path)
    model.load_state_dict(
        dist_util.load_state_dict(args.model_path, map_location="cpu")
    )
    model.to(dist_util.dev())
    if args.use_fp16:
        model.convert_to_fp16()
    model.eval()

    logger.log("loading data...")
    data = load_data_for_worker(args.base_samples, args.batch_size, args.class_cond)

    logger.log("creating samples...")
    all_images = []
    while len(all_images) * args.batch_size < args.num_samples:
        model_kwargs = next(data)
        model_kwargs = {k: v.to(dist_util.dev()) for k, v in model_kwargs.items()}
        sample = diffusion.p_sample_loop(
            model,
            (args.batch_size, 3, args.large_size, args.large_size),
            clip_denoised=args.clip_denoised,
            model_kwargs=model_kwargs,
        )
        sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8)
        sample = sample.permute(0, 2, 3, 1)
        sample = sample.contiguous()

        all_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
        dist.all_gather(all_samples, sample)  # gather not supported with NCCL
        for sample in all_samples:
            all_images.append(sample.cpu().numpy())
        logger.log(f"created {len(all_images) * args.batch_size} samples")

    arr = np.concatenate(all_images, axis=0)
    arr = arr[: args.num_samples]
    if dist.get_rank() == 0:
        shape_str = "x".join([str(x) for x in arr.shape])
        out_path = os.path.join(logger.get_dir(), f"samples_{shape_str}.npz")
        logger.log(f"saving to {out_path}")
        np.savez(out_path, arr)

    dist.barrier()
    logger.log("sampling complete")
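# load_data_for_worker is referenced above but not defined in this file.  A minimal
# sketch of what it is assumed to do (the exact upstream implementation may differ):
# read the base-resolution samples from the .npz produced by a sampling run, shard
# them across ranks, and yield batches of "low_res" conditioning images scaled to
# [-1, 1] in NCHW layout.  bf is blobfile, as used elsewhere in these scripts.
def load_data_for_worker(base_samples, batch_size, class_cond):
    with bf.BlobFile(base_samples, "rb") as f:
        obj = np.load(f)
        image_arr = obj["arr_0"]
        label_arr = obj["arr_1"] if class_cond else None
    rank, num_ranks = dist.get_rank(), dist.get_world_size()
    buffer, label_buffer = [], []
    while True:
        # Each rank walks its own strided slice of the base samples, forever.
        for i in range(rank, len(image_arr), num_ranks):
            buffer.append(image_arr[i])
            if class_cond:
                label_buffer.append(label_arr[i])
            if len(buffer) == batch_size:
                batch = th.from_numpy(np.stack(buffer)).float()
                batch = batch / 127.5 - 1.0        # uint8 [0, 255] -> float [-1, 1]
                batch = batch.permute(0, 3, 1, 2)  # NHWC -> NCHW
                res = dict(low_res=batch)
                if class_cond:
                    res["y"] = th.from_numpy(np.stack(label_buffer))
                yield res
                buffer, label_buffer = [], []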
def main():
    args = create_argparser().parse_args()

    dist_util.setup_dist()
    logger.configure()

    logger.log("creating model and diffusion...")
    model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )
    model.load_state_dict(
        dist_util.load_state_dict(args.model_path, map_location="cpu")
    )
    model.to(dist_util.dev())
    model.eval()

    logger.log("creating data loader...")
    data = load_data(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        image_size=args.image_size,
        class_cond=args.class_cond,
        deterministic=True,
    )

    logger.log("evaluating...")
    run_bpd_evaluation(model, diffusion, data, args.num_samples, args.clip_denoised)
def run_bpd_evaluation(model, diffusion, data, num_samples, clip_denoised):
    all_bpd = []
    all_metrics = {"vb": [], "mse": [], "xstart_mse": []}
    num_complete = 0
    while num_complete < num_samples:
        batch, model_kwargs = next(data)
        batch = batch.to(dist_util.dev())
        model_kwargs = {k: v.to(dist_util.dev()) for k, v in model_kwargs.items()}
        minibatch_metrics = diffusion.calc_bpd_loop(
            model, batch, clip_denoised=clip_denoised, model_kwargs=model_kwargs
        )

        for key, term_list in all_metrics.items():
            terms = minibatch_metrics[key].mean(dim=0) / dist.get_world_size()
            dist.all_reduce(terms)
            term_list.append(terms.detach().cpu().numpy())

        total_bpd = minibatch_metrics["total_bpd"]
        total_bpd = total_bpd.mean() / dist.get_world_size()
        dist.all_reduce(total_bpd)
        all_bpd.append(total_bpd.item())
        num_complete += dist.get_world_size() * batch.shape[0]

        logger.log(f"done {num_complete} samples: bpd={np.mean(all_bpd)}")

    if dist.get_rank() == 0:
        for name, terms in all_metrics.items():
            out_path = os.path.join(logger.get_dir(), f"{name}_terms.npz")
            logger.log(f"saving {name} terms to {out_path}")
            np.savez(out_path, np.mean(np.stack(terms), axis=0))

    dist.barrier()
    logger.log("evaluation complete")
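# A small usage sketch (not part of the script above): np.savez with a single
# positional argument stores the array under the key "arr_0", so the per-timestep
# terms written by run_bpd_evaluation can be read back like this.  load_term_file
# is a hypothetical helper name, not something defined in this codebase.
def load_term_file(path):
    terms = np.load(path)["arr_0"]  # per-timestep values averaged over the evaluation set
    print(f"{path}: {terms.shape[0]} timesteps, sum={terms.sum():.4f}")
    return terms

# Example: vb_terms = load_term_file("vb_terms.npz")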
def main():
    args = create_argparser().parse_args()

    dist_util.setup_dist()
    logger.configure()

    logger.log("creating model...")
    model, diffusion = sr_create_model_and_diffusion(
        **args_to_dict(args, sr_model_and_diffusion_defaults().keys())
    )
    model.to(dist_util.dev())
    schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)

    logger.log("creating data loader...")
    data = load_superres_data(
        args.data_dir,
        args.batch_size,
        large_size=args.large_size,
        small_size=args.small_size,
        class_cond=args.class_cond,
    )

    logger.log("training...")
    TrainLoop(
        model=model,
        diffusion=diffusion,
        data=data,
        batch_size=args.batch_size,
        microbatch=args.microbatch,
        lr=args.lr,
        ema_rate=args.ema_rate,
        log_interval=args.log_interval,
        save_interval=args.save_interval,
        resume_checkpoint=args.resume_checkpoint,
        use_fp16=args.use_fp16,
        fp16_scale_growth=args.fp16_scale_growth,
        schedule_sampler=schedule_sampler,
        weight_decay=args.weight_decay,
        lr_anneal_steps=args.lr_anneal_steps,
    ).run_loop()
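# load_superres_data is referenced above but not defined here.  A minimal sketch of
# what it is assumed to do (the exact upstream implementation may differ): wrap the
# ordinary image loader and attach an area-downsampled copy of each batch as the
# "low_res" conditioning input expected by the super-resolution model.
def load_superres_data(data_dir, batch_size, large_size, small_size, class_cond=False):
    data = load_data(
        data_dir=data_dir,
        batch_size=batch_size,
        image_size=large_size,
        class_cond=class_cond,
    )
    for large_batch, model_kwargs in data:
        # Downsample the high-resolution batch to provide the low-res condition.
        model_kwargs["low_res"] = F.interpolate(large_batch, small_size, mode="area")
        yield large_batch, model_kwargs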
def main():
    args = create_argparser().parse_args()

    dist_util.setup_dist()
    logger.configure()

    logger.log("creating model and diffusion...")
    model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )
    model.load_state_dict(
        dist_util.load_state_dict(args.model_path, map_location="cpu")
    )
    model.to(dist_util.dev())
    if args.use_fp16:
        model.convert_to_fp16()
    model.eval()

    logger.log("sampling...")
    all_images = []
    all_labels = []
    while len(all_images) * args.batch_size < args.num_samples:
        model_kwargs = {}
        if args.class_cond:
            classes = th.randint(
                low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev()
            )
            model_kwargs["y"] = classes
        sample_fn = (
            diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop
        )
        sample = sample_fn(
            model,
            (args.batch_size, 3, args.image_size, args.image_size),
            clip_denoised=args.clip_denoised,
            model_kwargs=model_kwargs,
        )
        sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8)
        sample = sample.permute(0, 2, 3, 1)
        sample = sample.contiguous()

        gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered_samples, sample)  # gather not supported with NCCL
        all_images.extend([sample.cpu().numpy() for sample in gathered_samples])
        if args.class_cond:
            gathered_labels = [
                th.zeros_like(classes) for _ in range(dist.get_world_size())
            ]
            dist.all_gather(gathered_labels, classes)
            all_labels.extend([labels.cpu().numpy() for labels in gathered_labels])
        logger.log(f"created {len(all_images) * args.batch_size} samples")

    arr = np.concatenate(all_images, axis=0)
    arr = arr[: args.num_samples]
    if args.class_cond:
        label_arr = np.concatenate(all_labels, axis=0)
        label_arr = label_arr[: args.num_samples]
    if dist.get_rank() == 0:
        shape_str = "x".join([str(x) for x in arr.shape])
        out_path = os.path.join(logger.get_dir(), f"samples_{shape_str}.npz")
        logger.log(f"saving to {out_path}")
        if args.class_cond:
            np.savez(out_path, arr, label_arr)
        else:
            np.savez(out_path, arr)

    dist.barrier()
    logger.log("sampling complete")
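# A small usage sketch (not part of the script above): the saved .npz holds the
# uint8 NHWC samples under "arr_0" and, when class-conditional, the labels under
# "arr_1".  preview_samples is a hypothetical helper name; PIL is only used here
# for illustration and is not imported by the sampling script itself.
def preview_samples(npz_path, out_png="preview.png"):
    from PIL import Image

    obj = np.load(npz_path)
    images = obj["arr_0"]                     # (N, H, W, 3) uint8
    labels = obj["arr_1"] if "arr_1" in obj.files else None
    Image.fromarray(images[0]).save(out_png)  # write the first sample to disk
    print(f"{images.shape[0]} samples", "" if labels is None else f"labels: {labels[:8]}")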
def main(**kwargs):
    args = create_argparser().parse_args()
    kw = {k: v for k, v in kwargs.items() if k in args}
    args.__dict__.update(**kw)

    dist_util.setup_dist()
    logger.configure()

    logger.log("creating model and diffusion...")
    pprint(args_to_dict(args, model_and_diffusion_defaults().keys()))
    model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )
    print(f"loading state dict {args.model_path} -> model")
    model.load_state_dict(
        dist_util.load_state_dict(args.model_path, map_location="cpu")
    )
    model.to(dist_util.dev())
    if args.use_fp16:
        model.convert_to_fp16()
    model.eval()

    logger.log("loading classifier...")
    classifier = create_classifier(**args_to_dict(args, classifier_defaults().keys()))
    print(f"loading state dict {args.classifier_path}")
    classifier.load_state_dict(
        dist_util.load_state_dict(args.classifier_path, map_location="cpu")
    )
    classifier.to(dist_util.dev())
    if args.classifier_use_fp16:
        classifier.convert_to_fp16()
    classifier.eval()

    def cond_fn(x, t, y=None):
        assert y is not None
        with th.enable_grad():
            # print(f" cond_fn(x,t,y) t {t}")
            # t: (batch_size,), time sampling runs 999 -> 0
            x_in = x.detach().requires_grad_(True)
            logits = classifier(x_in, t)
            # print(f" .. x:{x_in.shape}, t:{t.shape}, logits:{logits.shape}, y: {y.shape}")
            # x: (batch_size, channels, image_size, image_size), t: (batch_size,),
            # logits: (batch_size, num_classes), y: (batch_size,)
            log_probs = F.log_softmax(logits, dim=-1)
            selected = log_probs[range(len(logits)), y.view(-1)]
            # print(f" .. selected, log_softmax(logits)[range(), y] {selected}")
            # selected: (batch_size,) floats
            cond = th.autograd.grad(selected.sum(), x_in)[0] * args.classifier_scale
            # print(f" .. cond: {tuple(cond.shape)}, args.classifier_scale {args.classifier_scale}")
            # cond: (batch_size, 3, image_size, image_size), args.classifier_scale e.g. 0.5
            return cond

    def logt(x):
        # Compact tensor/scalar formatter used by the commented-out debug prints.
        out = ""
        if isinstance(x, th.Tensor):
            out = f" {tuple(x.shape)}"
            if x.ndim == 1:
                out += f"{x}"
        elif isinstance(x, (int, float)):
            out = f" {x}"
        return out

    def model_fn(x, t, y=None):
        assert y is not None
        print(f"timestep {t.tolist()} conditional y {y.tolist()}")
        # print(f"model_fn, x {logt(x)}, t {logt(t)} y {logt(y)}")
        return model(x, t, y if args.class_cond else None)

    logger.log("sampling...")
    all_images = []
    all_labels = []
    while len(all_images) * args.batch_size < args.num_samples:
        model_kwargs = {}
        classes = th.randint(
            low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev()
        )
        model_kwargs["y"] = classes
        if args.use_ddim:
            print("sample_fn = diffusion.ddim_sample_loop: args.use_ddim")
        else:
            print("sample_fn = diffusion.p_sample_loop: not args.use_ddim")
        sample_fn = (
            diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop
        )
        # print(f"sample_fn args.batch_size {args.batch_size}, args.image_size {args.image_size} "
        #       f"args.clip_denoised {args.clip_denoised} model_kwargs {model_kwargs}")
        # model_kwargs['y']: class conditioner, e.g.
        # [53, 37, 609, 498, 679, 38, 242, 705, 253, 822, 721, 762, 64, 42, 337, 483]
        sample = sample_fn(
            model_fn,
            (args.batch_size, 3, args.image_size, args.image_size),
            clip_denoised=args.clip_denoised,
            model_kwargs=model_kwargs,
            cond_fn=cond_fn,
            device=dist_util.dev(),
        )
        sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8)
        sample = sample.permute(0, 2, 3, 1)
        sample = sample.contiguous()

        gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered_samples, sample)  # gather not supported with NCCL
        all_images.extend([sample.cpu().numpy() for sample in gathered_samples])
        gathered_labels = [th.zeros_like(classes) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered_labels, classes)
        all_labels.extend([labels.cpu().numpy() for labels in gathered_labels])
        logger.log(f"created {len(all_images) * args.batch_size} samples")

    arr = np.concatenate(all_images, axis=0)
    arr = arr[: args.num_samples]
    label_arr = np.concatenate(all_labels, axis=0)
    label_arr = label_arr[: args.num_samples]
    if dist.get_rank() == 0:
        shape_str = "x".join([str(x) for x in arr.shape])
        out_path = os.path.join(logger.get_dir(), f"samples_{shape_str}.npz")
        logger.log(f"saving to {out_path}")
        np.savez(out_path, arr, label_arr)

    dist.barrier()
    logger.log("sampling complete")
def main():
    args = create_argparser().parse_args()

    dist_util.setup_dist()
    logger.configure()

    logger.log("creating model and diffusion...")
    model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )
    model.load_state_dict(
        dist_util.load_state_dict(args.model_path, map_location="cpu")
    )
    model.to(dist_util.dev())
    if args.use_fp16:
        model.convert_to_fp16()
    model.eval()

    logger.log("loading classifier...")
    classifier = create_classifier(**args_to_dict(args, classifier_defaults().keys()))
    classifier.load_state_dict(
        dist_util.load_state_dict(args.classifier_path, map_location="cpu")
    )
    classifier.to(dist_util.dev())
    if args.classifier_use_fp16:
        classifier.convert_to_fp16()
    classifier.eval()

    def cond_fn(x, t, y=None):
        assert y is not None
        with th.enable_grad():
            x_in = x.detach().requires_grad_(True)
            logits = classifier(x_in, t)
            log_probs = F.log_softmax(logits, dim=-1)
            selected = log_probs[range(len(logits)), y.view(-1)]
            return th.autograd.grad(selected.sum(), x_in)[0] * args.classifier_scale

    def model_fn(x, t, y=None):
        assert y is not None
        return model(x, t, y if args.class_cond else None)

    logger.log("sampling...")
    all_images = []
    all_labels = []
    while len(all_images) * args.batch_size < args.num_samples:
        model_kwargs = {}
        classes = th.randint(
            low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev()
        )
        model_kwargs["y"] = classes
        sample_fn = (
            diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop
        )
        sample = sample_fn(
            model_fn,
            (args.batch_size, 3, args.image_size, args.image_size),
            clip_denoised=args.clip_denoised,
            model_kwargs=model_kwargs,
            cond_fn=cond_fn,
            device=dist_util.dev(),
        )
        sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8)
        sample = sample.permute(0, 2, 3, 1)
        sample = sample.contiguous()

        gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered_samples, sample)  # gather not supported with NCCL
        all_images.extend([sample.cpu().numpy() for sample in gathered_samples])
        gathered_labels = [th.zeros_like(classes) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered_labels, classes)
        all_labels.extend([labels.cpu().numpy() for labels in gathered_labels])
        logger.log(f"created {len(all_images) * args.batch_size} samples")

    arr = np.concatenate(all_images, axis=0)
    arr = arr[: args.num_samples]
    label_arr = np.concatenate(all_labels, axis=0)
    label_arr = label_arr[: args.num_samples]
    if dist.get_rank() == 0:
        shape_str = "x".join([str(x) for x in arr.shape])
        out_path = os.path.join(logger.get_dir(), f"samples_{shape_str}.npz")
        logger.log(f"saving to {out_path}")
        np.savez(out_path, arr, label_arr)

    dist.barrier()
    logger.log("sampling complete")
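# cond_fn above returns classifier_scale * grad_x log p(y | x_t).  Inside
# p_sample_loop that gradient is used to shift the reverse-process mean; a minimal
# standalone sketch of the shift, assuming the usual guided-diffusion-style update
# new_mean = mean + variance * gradient (the actual sampler may differ in detail,
# and DDIM sampling applies the gradient differently):
def guided_mean(p_mean_var, gradient):
    # p_mean_var: dict with "mean" and "variance" tensors from the diffusion model;
    # gradient:   output of cond_fn, same shape as the image batch.
    return p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()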
def main():
    args = create_argparser().parse_args()

    dist_util.setup_dist()
    logger.configure()

    logger.log("creating model and diffusion...")
    model, diffusion = create_classifier_and_diffusion(
        **args_to_dict(args, classifier_and_diffusion_defaults().keys())
    )
    model.to(dist_util.dev())
    if args.noised:
        schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)

    resume_step = 0
    if args.resume_checkpoint:
        resume_step = parse_resume_step_from_filename(args.resume_checkpoint)
        if dist.get_rank() == 0:
            logger.log(
                f"loading model from checkpoint: {args.resume_checkpoint}... at {resume_step} step"
            )
            model.load_state_dict(
                dist_util.load_state_dict(args.resume_checkpoint, map_location=dist_util.dev())
            )

    # Needed for creating correct EMAs and fp16 parameters.
    dist_util.sync_params(model.parameters())

    mp_trainer = MixedPrecisionTrainer(
        model=model, use_fp16=args.classifier_use_fp16, initial_lg_loss_scale=16.0
    )

    model = DDP(
        model,
        device_ids=[dist_util.dev()],
        output_device=dist_util.dev(),
        broadcast_buffers=False,
        bucket_cap_mb=128,
        find_unused_parameters=False,
    )

    logger.log("creating data loader...")
    data = load_data(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        image_size=args.image_size,
        class_cond=True,
        random_crop=True,
    )
    if args.val_data_dir:
        val_data = load_data(
            data_dir=args.val_data_dir,
            batch_size=args.batch_size,
            image_size=args.image_size,
            class_cond=True,
        )
    else:
        val_data = None

    logger.log("creating optimizer...")
    opt = AdamW(mp_trainer.master_params, lr=args.lr, weight_decay=args.weight_decay)
    if args.resume_checkpoint:
        opt_checkpoint = bf.join(bf.dirname(args.resume_checkpoint), f"opt{resume_step:06}.pt")
        logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
        opt.load_state_dict(
            dist_util.load_state_dict(opt_checkpoint, map_location=dist_util.dev())
        )

    logger.log("training classifier model...")

    def forward_backward_log(data_loader, prefix="train"):
        batch, extra = next(data_loader)
        labels = extra["y"].to(dist_util.dev())
        batch = batch.to(dist_util.dev())
        # Noisy images
        if args.noised:
            t, _ = schedule_sampler.sample(batch.shape[0], dist_util.dev())
            batch = diffusion.q_sample(batch, t)
        else:
            t = th.zeros(batch.shape[0], dtype=th.long, device=dist_util.dev())

        for i, (sub_batch, sub_labels, sub_t) in enumerate(
            split_microbatches(args.microbatch, batch, labels, t)
        ):
            logits = model(sub_batch, timesteps=sub_t)
            loss = F.cross_entropy(logits, sub_labels, reduction="none")

            losses = {}
            losses[f"{prefix}_loss"] = loss.detach()
            losses[f"{prefix}_acc@1"] = compute_top_k(logits, sub_labels, k=1, reduction="none")
            losses[f"{prefix}_acc@5"] = compute_top_k(logits, sub_labels, k=5, reduction="none")
            log_loss_dict(diffusion, sub_t, losses)
            del losses

            loss = loss.mean()
            if loss.requires_grad:
                if i == 0:
                    mp_trainer.zero_grad()
                mp_trainer.backward(loss * len(sub_batch) / len(batch))

    for step in range(args.iterations - resume_step):
        logger.logkv("step", step + resume_step)
        logger.logkv(
            "samples",
            (step + resume_step + 1) * args.batch_size * dist.get_world_size(),
        )
        if args.anneal_lr:
            set_annealed_lr(opt, args.lr, (step + resume_step) / args.iterations)
        forward_backward_log(data)
        mp_trainer.optimize(opt)
        if val_data is not None and not step % args.eval_interval:
            with th.no_grad():
                with model.no_sync():
                    model.eval()
                    forward_backward_log(val_data, prefix="val")
                    model.train()
        if not step % args.log_interval:
            logger.dumpkvs()
        if (
            step
            and dist.get_rank() == 0
            and not (step + resume_step) % args.save_interval
        ):
            logger.log("saving model...")
            save_model(mp_trainer, opt, step + resume_step)

    if dist.get_rank() == 0:
        logger.log("saving model...")
        save_model(mp_trainer, opt, step + resume_step)
    dist.barrier()
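# split_microbatches and compute_top_k are referenced by forward_backward_log but not
# defined here.  Minimal sketches of what they are assumed to do (the exact upstream
# helpers may differ): slice a batch into microbatches, and count how often the true
# label appears in the top-k logits.
def split_microbatches(microbatch, *args):
    bs = len(args[0])
    if microbatch == -1 or microbatch >= bs:
        yield tuple(args)
    else:
        for i in range(0, bs, microbatch):
            yield tuple(x[i : i + microbatch] if x is not None else None for x in args)


def compute_top_k(logits, labels, k, reduction="mean"):
    _, top_ks = th.topk(logits, k, dim=-1)
    hits = (top_ks == labels[:, None]).float().sum(dim=-1)  # 1.0 if label is in the top-k
    return hits.mean().item() if reduction == "mean" else hits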