def eval_epoch(model, data_loader, epoch, config): metrics = { "accuracy": Mean(), "entropy": Mean(), } with torch.no_grad(): model.eval() for images_x, targets_x in tqdm(data_loader, desc="epoch {}/{}, eval".format( epoch, config.epochs)): images_x, targets_x = images_x.to(DEVICE), targets_x.to(DEVICE) probs_x = model(images_x) metrics["entropy"].update(entropy(probs_x).data.cpu().numpy()) metrics["accuracy"].update( (probs_x.argmax(-1) == targets_x).float().data.cpu().numpy()) writer = SummaryWriter(os.path.join(config.experiment_path, "eval")) with torch.no_grad(): for k in metrics: writer.add_scalar(k, metrics[k].compute_and_reset(), global_step=epoch) writer.add_image( "images_x", torchvision.utils.make_grid(images_x, nrow=compute_nrow(images_x), normalize=True), global_step=epoch, ) writer.flush() writer.close()
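
# `entropy` is not defined in this file; a minimal sketch, assuming the model
# returns normalized class probabilities of shape (batch, num_classes):
def entropy(probs, eps=1e-8):
    # per-sample Shannon entropy; eps guards against log(0)
    return -(probs * torch.log(probs + eps)).sum(-1)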
def eval_epoch(model, data_loader, epoch, config): metrics = { "loss": Mean(), "accuracy": Mean(), } with torch.no_grad(): model.eval() for images, targets in tqdm( data_loader, desc="epoch {}/{}, eval".format(epoch, config.epochs) ): images, targets = images.to(DEVICE), targets.to(DEVICE) logits = model(images) loss = F.cross_entropy(input=logits, target=targets, reduction="none") metrics["loss"].update(loss.data.cpu().numpy()) metrics["accuracy"].update((logits.argmax(-1) == targets).float().data.cpu().numpy()) writer = SummaryWriter(os.path.join(config.experiment_path, "eval")) with torch.no_grad(): for k in metrics: writer.add_scalar(k, metrics[k].compute_and_reset(), global_step=epoch) writer.add_image( "images", torchvision.utils.make_grid(images, nrow=compute_nrow(images), normalize=True), global_step=epoch, ) writer.flush() writer.close()
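
# The metric helpers (`Mean`, `Last`, `Concat`) are assumed to come from a
# shared utility module; a minimal sketch of the streaming-mean variant,
# assuming updates arrive as tensors or numpy arrays of per-sample values:
class Mean:
    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, values):
        # accept either tensors (possibly on GPU) or numpy-compatible values
        if torch.is_tensor(values):
            values = values.detach().cpu().numpy()
        values = np.asarray(values, dtype=np.float64).reshape(-1)
        self.sum += values.sum()
        self.count += values.size

    def compute_and_reset(self):
        mean = self.sum / max(self.count, 1)
        self.sum, self.count = 0.0, 0
        return mean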
def train_epoch(model, data_loader, optimizer, scheduler, epoch, config): metrics = { "images": Last(), "loss": Mean(), "lr": Last(), } # loop over batches ################################################################################################ model.train() for images, meta, targets in tqdm( data_loader, desc="fold {}, epoch {}/{}, train".format(config.fold, epoch, config.train.epochs), ): images, meta, targets = ( images.to(DEVICE), {k: meta[k].to(DEVICE) for k in meta}, targets.to(DEVICE), ) # images, targets = mix_up(images, targets, alpha=1.) logits = model(images, meta) loss = compute_loss(input=logits, target=targets, config=config) metrics["images"].update(images.data.cpu()) metrics["loss"].update(loss.data.cpu()) metrics["lr"].update(np.squeeze(scheduler.get_last_lr())) optimizer.zero_grad() loss.mean().backward() optimizer.step() scheduler.step() # compute metrics ################################################################################################## with torch.no_grad(): metrics = {k: metrics[k].compute_and_reset() for k in metrics} writer = SummaryWriter(os.path.join(config.experiment_path, "train")) writer.add_image( "images", torchvision.utils.make_grid( metrics["images"], nrow=compute_nrow(metrics["images"]), normalize=True ), global_step=epoch, ) writer.add_scalar("loss", metrics["loss"], global_step=epoch) writer.add_scalar("lr", metrics["lr"], global_step=epoch) writer.flush() writer.close()
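
# `Last` is assumed to keep only the most recent update, which matches how the
# "images" and "lr" metrics are used above; a minimal sketch:
class Last:
    def __init__(self):
        self.value = None

    def update(self, value):
        self.value = value

    def compute_and_reset(self):
        value, self.value = self.value, None
        return value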
def eval_epoch(model, data_loader, epoch, config): metrics = { "teacher/accuracy": Mean(), "teacher/entropy": Mean(), "student/accuracy": Mean(), "student/entropy": Mean(), } with torch.no_grad(): model.eval() for x_image, x_target in tqdm(data_loader, desc="epoch {}/{}, eval".format( epoch, config.epochs)): x_image, x_target = x_image.to(DEVICE), x_target.to(DEVICE) probs_teacher = model.teacher(x_image) probs_student = model.student(x_image) metrics["teacher/entropy"].update( entropy(probs_teacher).data.cpu().numpy()) metrics["student/entropy"].update( entropy(probs_student).data.cpu().numpy()) metrics["teacher/accuracy"].update( (probs_teacher.argmax(-1) == x_target ).float().data.cpu().numpy()) metrics["student/accuracy"].update( (probs_student.argmax(-1) == x_target ).float().data.cpu().numpy()) writer = SummaryWriter(os.path.join(config.experiment_path, "eval")) with torch.no_grad(): for k in metrics: writer.add_scalar(k, metrics[k].compute_and_reset(), global_step=epoch) writer.add_image( "x_image", torchvision.utils.make_grid(denormalize(x_image), nrow=compute_nrow(x_image), normalize=True), global_step=epoch, ) writer.flush() writer.close()
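
# `denormalize` is assumed to invert the input normalization before images are
# rendered; a sketch for inputs normalized with T.Normalize([0.5], [0.5]), the
# mean/std used elsewhere in this repo:
def denormalize(images, mean=0.5, std=0.5):
    return images * std + mean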
def train_epoch(model, data_loader, optimizer, scheduler, epoch, config): metrics = { "loss": Mean(), "lr": Last(), } model.train() for images, targets in tqdm(data_loader, desc="epoch {}/{}, train".format( epoch, config.epochs)): images, targets = images.to(DEVICE), targets.to(DEVICE) logits = model(images) loss = F.cross_entropy(input=logits, target=targets, reduction="none") metrics["loss"].update(loss.data.cpu().numpy()) metrics["lr"].update(np.squeeze(scheduler.get_lr())) optimizer.zero_grad() loss.mean().backward() optimizer.step() scheduler.step() writer = SummaryWriter(os.path.join(config.experiment_path, "train")) with torch.no_grad(): for k in metrics: writer.add_scalar(k, metrics[k].compute_and_reset(), global_step=epoch) writer.add_image( "images", torchvision.utils.make_grid(denormalize(images), nrow=compute_nrow(images), normalize=True), global_step=epoch, ) writer.add_histogram("params", flatten_weights(model.parameters()), global_step=epoch) writer.flush() writer.close()
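
# `flatten_weights` is assumed to concatenate all parameters into a single 1-D
# tensor for the histogram summary; a minimal sketch:
def flatten_weights(parameters):
    return torch.cat([p.data.reshape(-1) for p in parameters])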
def train_epoch(model, data_loader, optimizer, scheduler, epoch, config): metrics = { "loss/x": Mean(), "loss/u": Mean(), "weight/u": Last(), "lr": Last(), } model.train() for (images_x, targets_x), (images_u_0, images_u_1) in tqdm( data_loader, desc="epoch {}/{}, train".format(epoch, config.epochs)): # prepare data ################################################################################################# images_x, targets_x, images_u_0, images_u_1 = ( images_x.to(DEVICE), targets_x.to(DEVICE), images_u_0.to(DEVICE), images_u_1.to(DEVICE), ) targets_x = one_hot(targets_x, NUM_CLASSES) # mix-match #################################################################################################### with torch.no_grad(): (images_x, targets_x), (images_u, targets_u) = mix_match(x=(images_x, targets_x), u=(images_u_0, images_u_1), model=model, config=config) probs_x, probs_u = model(torch.cat([images_x, images_u])).split( [images_x.size(0), images_u.size(0)]) # x ############################################################################################################ loss_x = compute_loss_x(input=probs_x, target=targets_x) metrics["loss/x"].update(loss_x.data.cpu().numpy()) # u ############################################################################################################ loss_u = compute_loss_u(input=probs_u, target=targets_u) metrics["loss/u"].update(loss_u.data.cpu().numpy()) # opt step ##################################################################################################### metrics["lr"].update(np.squeeze(scheduler.get_last_lr())) weight_u = config.train.mix_match.weight_u * min( (epoch - 1) / config.epochs_warmup, 1.0) metrics["weight/u"].update(weight_u) optimizer.zero_grad() (loss_x.mean() + weight_u * loss_u.mean()).backward() optimizer.step() scheduler.step() if epoch % config.log_interval != 0: return writer = SummaryWriter(os.path.join(config.experiment_path, "train")) with torch.no_grad(): for k in metrics: writer.add_scalar(k, metrics[k].compute_and_reset(), global_step=epoch) writer.add_image( "images_x", torchvision.utils.make_grid(images_x, nrow=compute_nrow(images_x), normalize=True), global_step=epoch, ) writer.add_image( "images_u", torchvision.utils.make_grid(images_u, nrow=compute_nrow(images_u), normalize=True), global_step=epoch, ) writer.flush() writer.close()
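
# `one_hot` is assumed to produce float one-hot targets compatible with the
# soft targets returned by mix_match; a minimal sketch:
def one_hot(targets, num_classes):
    return F.one_hot(targets, num_classes).float()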
def eval_epoch(model, data_loader, epoch, config): metrics = { "loss": Mean(), } with torch.no_grad(): model.eval() for (text, text_mask), (audio, audio_mask) in tqdm( data_loader, desc="epoch {}/{}, eval".format(epoch, config.train.epochs)): text, audio, text_mask, audio_mask = [ x.to(DEVICE) for x in [text, audio, text_mask, audio_mask] ] output, pre_output, target, target_mask, weight = model( text, text_mask, audio, audio_mask) loss = masked_mse(output, target, target_mask) + masked_mse( pre_output, target, target_mask) metrics["loss"].update(loss.data.cpu()) writer = SummaryWriter(os.path.join(config.experiment_path, "eval")) with torch.no_grad(): gl_true = griffin_lim(target, model.spectra) gl_pred = griffin_lim(output, model.spectra) output, pre_output, target, weight = [ x.unsqueeze(1) for x in [output, pre_output, target, weight] ] nrow = compute_nrow(target) for k in metrics: writer.add_scalar(k, metrics[k].compute_and_reset(), global_step=epoch) writer.add_image( "target", torchvision.utils.make_grid(target, nrow=nrow, normalize=True), global_step=epoch, ) writer.add_image( "output", torchvision.utils.make_grid(output, nrow=nrow, normalize=True), global_step=epoch, ) writer.add_image( "pre_output", torchvision.utils.make_grid(pre_output, nrow=nrow, normalize=True), global_step=epoch, ) writer.add_image( "weight", torchvision.utils.make_grid(weight, nrow=nrow, normalize=True), global_step=epoch, ) for i in tqdm(range(min(text.size(0), 4)), desc="writing audio"): writer.add_audio("audio/{}".format(i), audio[i], sample_rate=config.sample_rate, global_step=epoch) writer.add_audio( "griffin-lim-true/{}".format(i), gl_true[i], sample_rate=config.sample_rate, global_step=epoch, ) writer.add_audio( "griffin-lim-pred/{}".format(i), gl_pred[i], sample_rate=config.sample_rate, global_step=epoch, ) writer.flush() writer.close()
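
# `masked_mse` is not defined here; a sketch, assuming the mask is 1 for valid
# frames, 0 for padding, and broadcasts against the spectrogram tensors:
def masked_mse(input, target, mask):
    mask = mask.float()
    se = (input - target) ** 2 * mask
    return se.sum() / mask.sum().clamp(min=1)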
def eval_epoch(model, data_loader, epoch, config):
    writer = SummaryWriter(os.path.join(config.experiment_path, "eval"))
    metrics = {
        "loss": Mean(),
        "iou": IoU(),
    }

    with torch.no_grad():
        model.eval()
        for images, targets in tqdm(
            data_loader, desc="epoch {}/{}, eval".format(epoch, config.epochs)
        ):
            images, targets = images.to(DEVICE), targets.to(DEVICE)
            targets = image_one_hot(targets, NUM_CLASSES)

            logits = model(images)
            loss = compute_loss(input=logits, target=targets)

            metrics["loss"].update(loss.data.cpu().numpy())
            metrics["iou"].update(input=logits.argmax(1), target=targets.argmax(1))

    with torch.no_grad():
        mask_true = F.interpolate(draw_masks(targets.argmax(1, keepdim=True)), scale_factor=1)
        mask_pred = F.interpolate(draw_masks(logits.argmax(1, keepdim=True)), scale_factor=1)

        for k in metrics:
            writer.add_scalar(k, metrics[k].compute_and_reset(), global_step=epoch)
        writer.add_image(
            "images",
            torchvision.utils.make_grid(
                denormalize(images), nrow=compute_nrow(images), normalize=True
            ),
            global_step=epoch,
        )
        writer.add_image(
            "mask_true",
            torchvision.utils.make_grid(mask_true, nrow=compute_nrow(mask_true), normalize=True),
            global_step=epoch,
        )
        writer.add_image(
            "mask_pred",
            torchvision.utils.make_grid(mask_pred, nrow=compute_nrow(mask_pred), normalize=True),
            global_step=epoch,
        )
        writer.add_image(
            "images_true",
            torchvision.utils.make_grid(
                denormalize(images) + mask_true, nrow=compute_nrow(images), normalize=True
            ),
            global_step=epoch,
        )
        writer.add_image(
            "images_pred",
            torchvision.utils.make_grid(
                denormalize(images) + mask_pred, nrow=compute_nrow(images), normalize=True
            ),
            global_step=epoch,
        )

    writer.flush()
    writer.close()
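
# `image_one_hot` is assumed to expand (B, H, W) class-index masks into
# (B, C, H, W) one-hot maps; a minimal sketch:
def image_one_hot(targets, num_classes):
    return F.one_hot(targets, num_classes).permute(0, 3, 1, 2).float()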
def train_epoch(model, data_loader, optimizer, scheduler, epoch, config): metrics = { "x_loss": Mean(), "u_loss": Mean(), "u_loss_mask": Mean(), "lr": Last(), } model.train() for (x_w_images, x_targets), (u_w_images, u_s_images) in tqdm( data_loader, desc="epoch {}/{}, train".format(epoch, config.epochs) ): x_w_images, x_targets, u_w_images, u_s_images = ( x_w_images.to(DEVICE), x_targets.to(DEVICE), u_w_images.to(DEVICE), u_s_images.to(DEVICE), ) x_w_logits, u_w_logits, u_s_logits = model( torch.cat([x_w_images, u_w_images, u_s_images], 0) ).split([x_w_images.size(0), u_w_images.size(0), u_s_images.size(0)]) # x ############################################################################################################ x_loss = F.cross_entropy(input=x_w_logits, target=x_targets, reduction="none") metrics["x_loss"].update(x_loss.data.cpu().numpy()) # u ############################################################################################################ u_loss_mask, u_targets = F.softmax(u_w_logits.detach(), 1).max(1) u_loss_mask = (u_loss_mask >= config.train.tau).float() u_loss = u_loss_mask * F.cross_entropy( input=u_s_logits, target=u_targets, reduction="none" ) metrics["u_loss"].update(u_loss.data.cpu().numpy()) metrics["u_loss_mask"].update(u_loss_mask.data.cpu().numpy()) # opt step ##################################################################################################### metrics["lr"].update(np.squeeze(scheduler.get_lr())) optimizer.zero_grad() (x_loss.mean() + config.train.u_weight * u_loss.mean()).backward() optimizer.step() scheduler.step() if epoch % config.log_interval != 0: return writer = SummaryWriter(os.path.join(config.experiment_path, "train")) with torch.no_grad(): for k in metrics: writer.add_scalar(k, metrics[k].compute_and_reset(), global_step=epoch) writer.add_image( "x_w_images", torchvision.utils.make_grid( denormalize(x_w_images), nrow=compute_nrow(x_w_images), normalize=True ), global_step=epoch, ) writer.add_image( "u_w_images", torchvision.utils.make_grid( denormalize(u_w_images), nrow=compute_nrow(u_w_images), normalize=True ), global_step=epoch, ) writer.add_image( "u_s_images", torchvision.utils.make_grid( denormalize(u_s_images), nrow=compute_nrow(u_s_images), normalize=True ), global_step=epoch, ) writer.flush() writer.close()
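
# `compute_nrow` is assumed to pick a grid width for make_grid; a sketch that
# lays batches out roughly square:
def compute_nrow(images):
    return math.ceil(math.sqrt(images.size(0)))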
def main(config_path, **kwargs):
    config = load_config(config_path, **kwargs)

    gen = Gen(
        image_size=config.image_size,
        base_channels=config.gen.base_channels,
        max_channels=config.gen.max_channels,
        z_channels=config.noise_size,
    ).to(DEVICE)
    dsc = Dsc(
        image_size=config.image_size,
        base_channels=config.dsc.base_channels,
        max_channels=config.dsc.max_channels,
        batch_std=config.dsc.batch_std,
    ).to(DEVICE)
    gen_ema = copy.deepcopy(gen)
    ema = ModuleEMA(gen_ema, config.gen.ema)
    pl_ema = torch.zeros([], device=DEVICE)

    opt_gen = build_optimizer(gen.parameters(), config)
    opt_dsc = build_optimizer(dsc.parameters(), config)

    z_dist = ZDist(config.noise_size, DEVICE)
    z_fixed = z_dist(8 ** 2, truncation=1)

    if os.path.exists(os.path.join(config.experiment_path, "checkpoint.pth")):
        state = torch.load(os.path.join(config.experiment_path, "checkpoint.pth"))
        dsc.load_state_dict(state["dsc"])
        gen.load_state_dict(state["gen"])
        gen_ema.load_state_dict(state["gen_ema"])
        opt_gen.load_state_dict(state["opt_gen"])
        opt_dsc.load_state_dict(state["opt_dsc"])
        pl_ema.copy_(state["pl_ema"])
        z_fixed.copy_(state["z_fixed"])
        print("restored from checkpoint")

    dataset = build_dataset(config)
    print("dataset size: {}".format(len(dataset)))
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True,
        drop_last=True,
    )
    data_loader = ChunkedDataLoader(data_loader, config.batches_in_epoch)

    dsc_compute_loss, gen_compute_loss = build_loss(config)

    writer = SummaryWriter(config.experiment_path)
    for epoch in range(1, config.num_epochs + 1):
        metrics = {
            "dsc/loss": Mean(),
            "gen/loss": Mean(),
        }
        dsc_logits = Concat()
        dsc_targets = Concat()

        gen.train()
        dsc.train()
        gen_ema.train()
        for batch_i, real in enumerate(
            tqdm(
                data_loader,
                desc="{}/{}".format(epoch, config.num_epochs),
                disable=config.debug,
                smoothing=0.1,
            ),
            1,
        ):
            real = real.to(DEVICE)

            # generator: train
            with zero_grad_and_step(opt_gen):
                if config.debug:
                    print("gen")

                fake, _ = gen(z_dist(config.batch_size), z_dist(config.batch_size))
                assert fake.size() == real.size(), \
                    "fake size {} does not match real size {}".format(fake.size(), real.size())

                # gen fake
                logits = dsc(fake)
                loss = gen_compute_loss(logits, True)
                loss.mean().backward()
                metrics["gen/loss"].update(loss.detach())

            # generator: regularize
            if batch_i % config.gen.reg_interval == 0:
                with zero_grad_and_step(opt_gen):
                    # path length regularization
                    fake, w = gen(z_dist(config.batch_size), z_dist(config.batch_size))
                    validate_shape(w, (None, config.batch_size, config.noise_size))
                    pl_noise = torch.randn_like(fake) / math.sqrt(fake.size(2) * fake.size(3))
                    (pl_grads,) = torch.autograd.grad(
                        outputs=[(fake * pl_noise).sum()],
                        inputs=[w],
                        create_graph=True,
                        only_inputs=True,
                    )
                    pl_lengths = pl_grads.square().sum(2).mean(0).sqrt()
                    pl_mean = pl_ema.lerp(pl_lengths.mean(), config.gen.pl_decay)
                    pl_ema.copy_(pl_mean.detach())
                    pl_penalty = (pl_lengths - pl_mean).square()
                    loss_pl = pl_penalty * config.gen.pl_weight * config.gen.reg_interval
                    loss_pl.mean().backward()

            # generator: update moving average
            ema.update(gen)

            # discriminator: train
            with zero_grad_and_step(opt_dsc):
                if config.debug:
                    print("dsc")

                with torch.no_grad():
                    fake, _ = gen(z_dist(config.batch_size), z_dist(config.batch_size))
                assert fake.size() == real.size(), \
                    "fake size {} does not match real size {}".format(fake.size(), real.size())

                # dsc real
                logits = dsc(real)
                loss = dsc_compute_loss(logits, True)
                loss.mean().backward()
                metrics["dsc/loss"].update(loss.detach())
                dsc_logits.update(logits.detach())
                dsc_targets.update(torch.ones_like(logits))

                # dsc fake
                logits = dsc(fake.detach())
                loss = dsc_compute_loss(logits, False)
                loss.mean().backward()
                metrics["dsc/loss"].update(loss.detach())
                dsc_logits.update(logits.detach())
                dsc_targets.update(torch.zeros_like(logits))

            # discriminator: regularize
            if batch_i % config.dsc.reg_interval == 0:
                with zero_grad_and_step(opt_dsc):
                    # R1 regularization
                    real = real.detach().requires_grad_(True)
                    logits = dsc(real)
                    (r1_grads,) = torch.autograd.grad(
                        outputs=[logits.sum()],
                        inputs=[real],
                        create_graph=True,
                        only_inputs=True,
                    )
                    r1_penalty = r1_grads.square().sum([1, 2, 3])
                    loss_r1 = r1_penalty * (config.dsc.r1_gamma * 0.5) * config.dsc.reg_interval
                    loss_r1.mean().backward()

        dsc.eval()
        gen.eval()
        gen_ema.eval()
        with torch.no_grad(), log_duration("visualization took {:.2f} seconds"):
            infer = Infer(gen)
            infer_ema = Infer(gen_ema)

            real = infer.postprocess(real)
            fake = infer(z_fixed)
            fake_ema = infer_ema(z_fixed)
            fake_ema_mix, fake_ema_mix_nrow = visualize_style_mixing(
                infer_ema, z_fixed[0 : 8 * 2 : 2], z_fixed[1 : 8 * 2 : 2]
            )
            fake_ema_noise, fake_ema_noise_nrow = stack_images(
                [
                    fake_ema[:8],
                    visualize_noise(infer_ema, z_fixed[:8], 128),
                ]
            )

            dsc_logits = dsc_logits.compute_and_reset().data.cpu().numpy()
            dsc_targets = dsc_targets.compute_and_reset().data.cpu().numpy()
            metrics = {k: metrics[k].compute_and_reset() for k in metrics}
            metrics["dsc/ap"] = precision_recall_auc(input=dsc_logits, target=dsc_targets)

            for k in metrics:
                writer.add_scalar(k, metrics[k], global_step=epoch)
            writer.add_figure(
                "dsc/pr_curve",
                plot_pr_curve(input=dsc_logits, target=dsc_targets),
                global_step=epoch,
            )
            writer.add_image(
                "real",
                torchvision.utils.make_grid(real, nrow=compute_nrow(real)),
                global_step=epoch,
            )
            writer.add_image(
                "fake",
                torchvision.utils.make_grid(fake, nrow=compute_nrow(fake)),
                global_step=epoch,
            )
            writer.add_image(
                "fake_ema",
                torchvision.utils.make_grid(fake_ema, nrow=compute_nrow(fake_ema)),
                global_step=epoch,
            )
            writer.add_image(
                "fake_ema_mix",
                torchvision.utils.make_grid(fake_ema_mix, nrow=fake_ema_mix_nrow),
                global_step=epoch,
            )
            writer.add_image(
                "fake_ema_noise",
                torchvision.utils.make_grid(fake_ema_noise, nrow=fake_ema_noise_nrow * 2),
                global_step=epoch,
            )

        torch.save(
            {
                "gen": gen.state_dict(),
                "gen_ema": gen_ema.state_dict(),
                "dsc": dsc.state_dict(),
                "opt_gen": opt_gen.state_dict(),
                "opt_dsc": opt_dsc.state_dict(),
                "pl_ema": pl_ema,
                "z_fixed": z_fixed,
            },
            os.path.join(config.experiment_path, "checkpoint.pth"),
        )

    writer.flush()
    writer.close()
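
# `zero_grad_and_step` is assumed to be a context manager that zeroes gradients
# on entry and applies the optimizer step on exit, so each `with` block above
# corresponds to exactly one update; a minimal sketch:
from contextlib import contextmanager

@contextmanager
def zero_grad_and_step(optimizer):
    optimizer.zero_grad()
    yield
    optimizer.step()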
def eval_epoch(model, data_loader, epoch, config, suffix=""): metrics = { "images": Concat(), "targets": Concat(), "logits": Concat(), "loss": Concat(), } # loop over batches ################################################################################################ model.eval() with torch.no_grad(): for images, meta, targets in tqdm( data_loader, desc="fold {}, epoch {}/{}, eval".format(config.fold, epoch, config.train.epochs), ): images, meta, targets = ( images.to(DEVICE), {k: meta[k].to(DEVICE) for k in meta}, targets.to(DEVICE), ) logits = model(images, meta) loss = compute_loss(input=logits, target=targets, config=config) metrics["images"].update(images.data.cpu()) metrics["targets"].update(targets.data.cpu()) metrics["logits"].update(logits.data.cpu()) metrics["loss"].update(loss.data.cpu()) # compute metrics ################################################################################################## with torch.no_grad(): metrics = {k: metrics[k].compute_and_reset() for k in metrics} metrics.update(compute_metric(input=metrics["logits"], target=metrics["targets"])) images_hard_pos = topk_hardest( metrics["images"], metrics["loss"], metrics["targets"] > 0.5, topk=config.eval.batch_size, ) images_hard_neg = topk_hardest( metrics["images"], metrics["loss"], metrics["targets"] <= 0.5, topk=config.eval.batch_size, ) roc_curve = plot_roc_curve(input=metrics["logits"], target=metrics["targets"]) metrics["loss"] = metrics["loss"].mean() writer = SummaryWriter(os.path.join(config.experiment_path, "eval", suffix)) writer.add_image( "images/hard/pos", torchvision.utils.make_grid( images_hard_pos, nrow=compute_nrow(images_hard_pos), normalize=True ), global_step=epoch, ) writer.add_image( "images/hard/neg", torchvision.utils.make_grid( images_hard_neg, nrow=compute_nrow(images_hard_neg), normalize=True ), global_step=epoch, ) writer.add_scalar("loss", metrics["loss"], global_step=epoch) writer.add_scalar("roc_auc", metrics["roc_auc"], global_step=epoch) writer.add_figure("roc_curve", roc_curve, global_step=epoch) writer.flush() writer.close() return metrics["roc_auc"]
def train_epoch(
    model, data_loader, opt_teacher, opt_student, sched_teacher, sched_student, epoch, config
):
    metrics = {
        "teacher/loss": Mean(),
        "teacher/grad_norm": Mean(),
        "teacher/lr": Last(),
        "student/loss": Mean(),
        "student/grad_norm": Mean(),
        "student/lr": Last(),
    }

    model.train()
    for (x_image, x_target), (u_image,) in tqdm(
        data_loader, desc="epoch {}/{}, train".format(epoch, config.epochs)
    ):
        x_image, x_target, u_image = (
            x_image.to(DEVICE),
            x_target.to(DEVICE),
            u_image.to(DEVICE),
        )

        with higher.innerloop_ctx(model.student, opt_student) as (h_model_student, h_opt_student):
            # student ##############################################################################
            # the student learns to match the teacher's predictions on unlabeled data
            loss_student = cross_entropy(
                input=h_model_student(u_image), target=model.teacher(u_image)
            ).mean()

            metrics["student/loss"].update(loss_student.data.cpu().numpy())
            metrics["student/lr"].update(np.squeeze(sched_student.get_last_lr()))

            def grad_callback(grads):
                metrics["student/grad_norm"].update(grad_norm(grads).data.cpu().numpy())
                return grads

            h_opt_student.step(loss_student, grad_callback=grad_callback)
            sched_student.step()

            # teacher ##############################################################################
            # the teacher is trained on labeled data both directly and through the
            # student's differentiable update
            loss_teacher = (
                cross_entropy(
                    input=model.teacher(x_image), target=one_hot(x_target, NUM_CLASSES)
                ).mean()
                + cross_entropy(
                    input=h_model_student(x_image), target=one_hot(x_target, NUM_CLASSES)
                ).mean()
            )

            metrics["teacher/loss"].update(loss_teacher.data.cpu().numpy())
            metrics["teacher/lr"].update(np.squeeze(sched_teacher.get_last_lr()))

            opt_teacher.zero_grad()
            loss_teacher.backward()
            opt_teacher.step()
            metrics["teacher/grad_norm"].update(
                grad_norm(p.grad for p in model.teacher.parameters()).data.cpu().numpy()
            )
            sched_teacher.step()

            # copy student weights #################################################################
            with torch.no_grad():
                for p, p_prime in zip(model.student.parameters(), h_model_student.parameters()):
                    p.copy_(p_prime)

    if epoch % config.log_interval != 0:
        return

    writer = SummaryWriter(os.path.join(config.experiment_path, "train"))
    with torch.no_grad():
        for k in metrics:
            writer.add_scalar(k, metrics[k].compute_and_reset(), global_step=epoch)
        writer.add_image(
            "x_image",
            torchvision.utils.make_grid(
                denormalize(x_image), nrow=compute_nrow(x_image), normalize=True
            ),
            global_step=epoch,
        )
        writer.add_image(
            "u_image",
            torchvision.utils.make_grid(
                denormalize(u_image), nrow=compute_nrow(u_image), normalize=True
            ),
            global_step=epoch,
        )

    writer.flush()
    writer.close()
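
# `grad_norm` is assumed to compute the global L2 norm over an iterable of
# gradient tensors; a minimal sketch:
def grad_norm(grads):
    return torch.cat([g.reshape(-1) for g in grads if g is not None]).norm(2)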
def main(config_path, **kwargs):
    config = load_config(config_path, **kwargs)

    gen = Gen(
        image_size=config.image_size,
        image_channels=3,
        base_channels=config.gen.base_channels,
        z_channels=config.noise_size,
    ).to(DEVICE)
    dsc = Dsc(
        image_size=config.image_size,
        image_channels=3,
        base_channels=config.dsc.base_channels,
    ).to(DEVICE)
    # gen_ema = gen
    # gen_ema = copy.deepcopy(gen)
    # ema = EMA(gen_ema, 0.99)
    gen.train()
    dsc.train()
    # gen_ema.train()

    opt_gen = build_optimizer(gen.parameters(), config)
    opt_dsc = build_optimizer(dsc.parameters(), config)

    transform = T.Compose(
        [
            T.Resize(config.image_size),
            T.RandomCrop(config.image_size),
            T.ToTensor(),
            T.Normalize([0.5], [0.5]),
        ]
    )
    # dataset = torchvision.datasets.MNIST(
    #     "./data/mnist", train=True, transform=transform, download=True
    # )
    # dataset = torchvision.datasets.CelebA(
    #     "./data/celeba", split="all", transform=transform, download=True
    # )
    dataset = ImageFolderDataset("./data/wikiart/resized/landscape", transform=transform)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True,
        drop_last=True,
    )

    noise_dist = torch.distributions.Normal(0, 1)
    dsc_compute_loss, gen_compute_loss = build_loss(config)

    writer = SummaryWriter("./log")
    for epoch in range(1, config.num_epochs + 1):
        metrics = {
            "dsc/loss": Mean(),
            "gen/loss": Mean(),
        }
        dsc_logits = Concat()
        dsc_targets = Concat()

        for batch_index, real in enumerate(
            tqdm(data_loader, desc="{}/{}".format(epoch, config.num_epochs), disable=config.debug)
        ):
            real = real.to(DEVICE)

            # train discriminator
            with zero_grad_and_step(opt_dsc):
                if config.debug:
                    print("dsc")

                noise = noise_dist.sample((real.size(0), config.noise_size)).to(DEVICE)
                with torch.no_grad():
                    fake = gen(noise)
                assert fake.size() == real.size(), \
                    "fake size {} does not match real size {}".format(fake.size(), real.size())

                # dsc real
                logits = dsc(real)
                loss = dsc_compute_loss(logits, True)
                loss.mean().backward()
                metrics["dsc/loss"].update(loss.detach())
                dsc_logits.update(logits.detach())
                dsc_targets.update(torch.ones_like(logits))

                # dsc fake
                logits = dsc(fake.detach())
                loss = dsc_compute_loss(logits, False)
                loss.mean().backward()
                metrics["dsc/loss"].update(loss.detach())
                dsc_logits.update(logits.detach())
                dsc_targets.update(torch.zeros_like(logits))

                # lazy R1 regularization: applied every 8th batch, with the penalty
                # scaled by the interval to keep its expected magnitude unchanged
                if (batch_index + 1) % 8 == 0:
                    r1_gamma = 10
                    real = real.detach().requires_grad_(True)
                    logits = dsc(real)
                    (r1_grads,) = torch.autograd.grad(
                        outputs=[logits.sum()],
                        inputs=[real],
                        create_graph=True,
                        only_inputs=True,
                    )
                    r1_penalty = r1_grads.square().sum([1, 2, 3])
                    loss_r1 = r1_penalty * (r1_gamma / 2) * 8
                    loss_r1.mean().backward()

            if config.dsc.weight_clip is not None:
                clip_parameters(dsc, config.dsc.weight_clip)

            if (batch_index + 1) % config.dsc.num_steps != 0:
                continue

            # train generator
            with zero_grad_and_step(opt_gen):
                if config.debug:
                    print("gen")

                noise = noise_dist.sample((real.size(0), config.noise_size)).to(DEVICE)
                fake = gen(noise)
                assert fake.size() == real.size(), \
                    "fake size {} does not match real size {}".format(fake.size(), real.size())

                # gen fake
                logits = dsc(fake)
                loss = gen_compute_loss(logits, True)
                loss.mean().backward()
                metrics["gen/loss"].update(loss.detach())

            # update moving average
            # ema.update(gen)

        with torch.no_grad():
            real, fake = [(x[: 4 ** 2] * 0.5 + 0.5).clamp(0, 1) for x in [real, fake]]

            dsc_logits = dsc_logits.compute_and_reset().data.cpu().numpy()
            dsc_targets = dsc_targets.compute_and_reset().data.cpu().numpy()
            metrics = {k: metrics[k].compute_and_reset() for k in metrics}
            metrics["ap"] = precision_recall_auc(input=dsc_logits, target=dsc_targets)

            for k in metrics:
                writer.add_scalar(k, metrics[k], global_step=epoch)
            writer.add_figure(
                "dsc/pr_curve",
                plot_pr_curve(input=dsc_logits, target=dsc_targets),
                global_step=epoch,
            )
            writer.add_image(
                "real",
                torchvision.utils.make_grid(real, nrow=compute_nrow(real)),
                global_step=epoch,
            )
            writer.add_image(
                "fake",
                torchvision.utils.make_grid(fake, nrow=compute_nrow(fake)),
                global_step=epoch,
            )
            # writer.add_image(
            #     "fake_ema",
            #     torchvision.utils.make_grid(fake, nrow=compute_nrow(fake)),
            #     global_step=epoch,
            # )

    writer.flush()
    writer.close()