Example #1
def eval_epoch(model, data_loader, epoch, config):
    metrics = {
        "accuracy": Mean(),
        "entropy": Mean(),
    }

    with torch.no_grad():
        model.eval()
        for images_x, targets_x in tqdm(data_loader,
                                        desc="epoch {}/{}, eval".format(
                                            epoch, config.epochs)):
            images_x, targets_x = images_x.to(DEVICE), targets_x.to(DEVICE)

            probs_x = model(images_x)

            metrics["entropy"].update(entropy(probs_x).data.cpu().numpy())
            metrics["accuracy"].update(
                (probs_x.argmax(-1) == targets_x).float().data.cpu().numpy())

    writer = SummaryWriter(os.path.join(config.experiment_path, "eval"))
    with torch.no_grad():
        for k in metrics:
            writer.add_scalar(k,
                              metrics[k].compute_and_reset(),
                              global_step=epoch)
        writer.add_image(
            "images_x",
            torchvision.utils.make_grid(images_x,
                                        nrow=compute_nrow(images_x),
                                        normalize=True),
            global_step=epoch,
        )

    writer.flush()
    writer.close()
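The helpers these listings lean on (`Mean`, `entropy`, `compute_nrow`, and friends) are defined elsewhere in the repository and are not shown. As a rough guide only, here is a minimal sketch of the two used above, assuming `Mean` is a running average with a `compute_and_reset()` accessor and `entropy` is per-sample Shannon entropy over a probability tensor (the actual implementations may differ):

import numpy as np
import torch

class Mean:
    """Running mean of every value passed to update()."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0

    def update(self, value):
        if torch.is_tensor(value):
            value = value.detach().cpu().numpy()
        value = np.asarray(value, dtype=np.float64)
        self.sum += value.sum()
        self.count += value.size

    def compute_and_reset(self):
        mean = self.sum / max(self.count, 1)
        self.reset()
        return mean

def entropy(probs, dim=-1, eps=1e-8):
    """Shannon entropy per sample; probs are assumed to already sum to 1."""
    return -(probs * torch.log(probs + eps)).sum(dim)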
Example #2
def eval_epoch(model, data_loader, epoch, config):
    metrics = {
        "loss": Mean(),
        "accuracy": Mean(),
    }

    with torch.no_grad():
        model.eval()
        for images, targets in tqdm(
            data_loader, desc="epoch {}/{}, eval".format(epoch, config.epochs)
        ):
            images, targets = images.to(DEVICE), targets.to(DEVICE)

            logits = model(images)
            loss = F.cross_entropy(input=logits, target=targets, reduction="none")

            metrics["loss"].update(loss.data.cpu().numpy())
            metrics["accuracy"].update((logits.argmax(-1) == targets).float().data.cpu().numpy())

    writer = SummaryWriter(os.path.join(config.experiment_path, "eval"))
    with torch.no_grad():
        for k in metrics:
            writer.add_scalar(k, metrics[k].compute_and_reset(), global_step=epoch)
        writer.add_image(
            "images",
            torchvision.utils.make_grid(images, nrow=compute_nrow(images), normalize=True),
            global_step=epoch,
        )

    writer.flush()
    writer.close()
Example #3
def train_epoch(model, data_loader, optimizer, scheduler, epoch, config):
    metrics = {
        "images": Last(),
        "loss": Mean(),
        "lr": Last(),
    }

    # loop over batches ################################################################################################
    model.train()
    for images, meta, targets in tqdm(
        data_loader,
        desc="fold {}, epoch {}/{}, train".format(config.fold, epoch, config.train.epochs),
    ):
        images, meta, targets = (
            images.to(DEVICE),
            {k: meta[k].to(DEVICE) for k in meta},
            targets.to(DEVICE),
        )
        # images, targets = mix_up(images, targets, alpha=1.)

        logits = model(images, meta)
        loss = compute_loss(input=logits, target=targets, config=config)

        metrics["images"].update(images.data.cpu())
        metrics["loss"].update(loss.data.cpu())
        metrics["lr"].update(np.squeeze(scheduler.get_last_lr()))

        optimizer.zero_grad()
        loss.mean().backward()
        optimizer.step()
        scheduler.step()

    # compute metrics ##################################################################################################
    with torch.no_grad():
        metrics = {k: metrics[k].compute_and_reset() for k in metrics}

        writer = SummaryWriter(os.path.join(config.experiment_path, "train"))
        writer.add_image(
            "images",
            torchvision.utils.make_grid(
                metrics["images"], nrow=compute_nrow(metrics["images"]), normalize=True
            ),
            global_step=epoch,
        )
        writer.add_scalar("loss", metrics["loss"], global_step=epoch)
        writer.add_scalar("lr", metrics["lr"], global_step=epoch)

        writer.flush()
        writer.close()
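`Last` and `compute_nrow` are likewise assumed rather than shown. A plausible sketch: `Last` keeps only the most recent value (useful for the learning rate), and `compute_nrow` picks a grid width so `make_grid` renders a near-square image grid:

import math

class Last:
    """Metric that remembers only the most recently updated value."""
    def __init__(self):
        self.value = None

    def update(self, value):
        self.value = value

    def compute_and_reset(self):
        value, self.value = self.value, None
        return value

def compute_nrow(images):
    """Images per row so the make_grid output is roughly square."""
    return max(1, math.ceil(math.sqrt(images.size(0))))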
Example #4
def eval_epoch(model, data_loader, epoch, config):
    metrics = {
        "teacher/accuracy": Mean(),
        "teacher/entropy": Mean(),
        "student/accuracy": Mean(),
        "student/entropy": Mean(),
    }

    with torch.no_grad():
        model.eval()
        for x_image, x_target in tqdm(data_loader,
                                      desc="epoch {}/{}, eval".format(
                                          epoch, config.epochs)):
            x_image, x_target = x_image.to(DEVICE), x_target.to(DEVICE)

            probs_teacher = model.teacher(x_image)
            probs_student = model.student(x_image)

            metrics["teacher/entropy"].update(
                entropy(probs_teacher).data.cpu().numpy())
            metrics["student/entropy"].update(
                entropy(probs_student).data.cpu().numpy())

            metrics["teacher/accuracy"].update(
                (probs_teacher.argmax(-1) == x_target
                 ).float().data.cpu().numpy())
            metrics["student/accuracy"].update(
                (probs_student.argmax(-1) == x_target
                 ).float().data.cpu().numpy())

    writer = SummaryWriter(os.path.join(config.experiment_path, "eval"))
    with torch.no_grad():
        for k in metrics:
            writer.add_scalar(k,
                              metrics[k].compute_and_reset(),
                              global_step=epoch)
        writer.add_image(
            "x_image",
            torchvision.utils.make_grid(denormalize(x_image),
                                        nrow=compute_nrow(x_image),
                                        normalize=True),
            global_step=epoch,
        )

    writer.flush()
    writer.close()
Example #5
def train_epoch(model, data_loader, optimizer, scheduler, epoch, config):
    metrics = {
        "loss": Mean(),
        "lr": Last(),
    }

    model.train()
    for images, targets in tqdm(data_loader,
                                desc="epoch {}/{}, train".format(
                                    epoch, config.epochs)):
        images, targets = images.to(DEVICE), targets.to(DEVICE)

        logits = model(images)
        loss = F.cross_entropy(input=logits, target=targets, reduction="none")

        metrics["loss"].update(loss.data.cpu().numpy())
        metrics["lr"].update(np.squeeze(scheduler.get_lr()))

        optimizer.zero_grad()
        loss.mean().backward()
        optimizer.step()
        scheduler.step()

    writer = SummaryWriter(os.path.join(config.experiment_path, "train"))
    with torch.no_grad():
        for k in metrics:
            writer.add_scalar(k,
                              metrics[k].compute_and_reset(),
                              global_step=epoch)
        writer.add_image(
            "images",
            torchvision.utils.make_grid(denormalize(images),
                                        nrow=compute_nrow(images),
                                        normalize=True),
            global_step=epoch,
        )
        writer.add_histogram("params",
                             flatten_weights(model.parameters()),
                             global_step=epoch)

    writer.flush()
    writer.close()
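`denormalize` and `flatten_weights` are not shown either. A minimal sketch, assuming the dataset was normalized with T.Normalize([0.5], [0.5]) as in example #13 (the actual mean/std may differ):

import torch

def denormalize(images, mean=0.5, std=0.5):
    """Invert T.Normalize so logged images land back in [0, 1]."""
    return (images * std + mean).clamp(0, 1)

def flatten_weights(parameters):
    """Concatenate all parameters into one 1-D tensor for add_histogram."""
    return torch.cat([p.detach().reshape(-1).cpu() for p in parameters])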
Example #6
def train_epoch(model, data_loader, optimizer, scheduler, epoch, config):
    metrics = {
        "loss/x": Mean(),
        "loss/u": Mean(),
        "weight/u": Last(),
        "lr": Last(),
    }

    model.train()
    for (images_x, targets_x), (images_u_0, images_u_1) in tqdm(
            data_loader,
            desc="epoch {}/{}, train".format(epoch, config.epochs)):
        # prepare data #################################################################################################
        images_x, targets_x, images_u_0, images_u_1 = (
            images_x.to(DEVICE),
            targets_x.to(DEVICE),
            images_u_0.to(DEVICE),
            images_u_1.to(DEVICE),
        )
        targets_x = one_hot(targets_x, NUM_CLASSES)

        # mix-match ####################################################################################################
        with torch.no_grad():
            (images_x,
             targets_x), (images_u,
                          targets_u) = mix_match(x=(images_x, targets_x),
                                                 u=(images_u_0, images_u_1),
                                                 model=model,
                                                 config=config)

        probs_x, probs_u = model(torch.cat([images_x, images_u])).split(
            [images_x.size(0), images_u.size(0)])

        # x ############################################################################################################
        loss_x = compute_loss_x(input=probs_x, target=targets_x)
        metrics["loss/x"].update(loss_x.data.cpu().numpy())

        # u ############################################################################################################
        loss_u = compute_loss_u(input=probs_u, target=targets_u)
        metrics["loss/u"].update(loss_u.data.cpu().numpy())

        # opt step #####################################################################################################
        metrics["lr"].update(np.squeeze(scheduler.get_last_lr()))
        weight_u = config.train.mix_match.weight_u * min(
            (epoch - 1) / config.epochs_warmup, 1.0)
        metrics["weight/u"].update(weight_u)

        optimizer.zero_grad()
        (loss_x.mean() + weight_u * loss_u.mean()).backward()
        optimizer.step()
        scheduler.step()

    if epoch % config.log_interval != 0:
        return

    writer = SummaryWriter(os.path.join(config.experiment_path, "train"))
    with torch.no_grad():
        for k in metrics:
            writer.add_scalar(k,
                              metrics[k].compute_and_reset(),
                              global_step=epoch)
        writer.add_image(
            "images_x",
            torchvision.utils.make_grid(images_x,
                                        nrow=compute_nrow(images_x),
                                        normalize=True),
            global_step=epoch,
        )
        writer.add_image(
            "images_u",
            torchvision.utils.make_grid(images_u,
                                        nrow=compute_nrow(images_u),
                                        normalize=True),
            global_step=epoch,
        )

    writer.flush()
    writer.close()
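`compute_loss_x`, `compute_loss_u`, and `one_hot` presumably follow the MixMatch recipe: cross-entropy against soft targets on the labeled batch and a squared-error consistency term on the unlabeled batch. A sketch under that assumption (note the model here outputs probabilities, not logits; the real code may average rather than sum over classes):

import torch
import torch.nn.functional as F

def one_hot(targets, num_classes):
    """Integer class labels -> float one-hot vectors."""
    return F.one_hot(targets, num_classes).float()

def compute_loss_x(input, target, eps=1e-8):
    """Labeled loss: cross-entropy between soft targets and predicted probs."""
    return -(target * torch.log(input + eps)).sum(-1)

def compute_loss_u(input, target):
    """Unlabeled loss: squared L2 between predicted and guessed distributions."""
    return (input - target).pow(2).sum(-1)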
Example #7
def eval_epoch(model, data_loader, epoch, config):
    metrics = {
        "loss": Mean(),
    }

    with torch.no_grad():
        model.eval()
        for (text, text_mask), (audio, audio_mask) in tqdm(
                data_loader,
                desc="epoch {}/{}, eval".format(epoch, config.train.epochs)):
            text, audio, text_mask, audio_mask = [
                x.to(DEVICE) for x in [text, audio, text_mask, audio_mask]
            ]

            output, pre_output, target, target_mask, weight = model(
                text, text_mask, audio, audio_mask)

            loss = masked_mse(output, target, target_mask) + masked_mse(
                pre_output, target, target_mask)

            metrics["loss"].update(loss.data.cpu())

    writer = SummaryWriter(os.path.join(config.experiment_path, "eval"))
    with torch.no_grad():
        gl_true = griffin_lim(target, model.spectra)
        gl_pred = griffin_lim(output, model.spectra)
        output, pre_output, target, weight = [
            x.unsqueeze(1) for x in [output, pre_output, target, weight]
        ]
        nrow = compute_nrow(target)

        for k in metrics:
            writer.add_scalar(k,
                              metrics[k].compute_and_reset(),
                              global_step=epoch)
        writer.add_image(
            "target",
            torchvision.utils.make_grid(target, nrow=nrow, normalize=True),
            global_step=epoch,
        )
        writer.add_image(
            "output",
            torchvision.utils.make_grid(output, nrow=nrow, normalize=True),
            global_step=epoch,
        )
        writer.add_image(
            "pre_output",
            torchvision.utils.make_grid(pre_output, nrow=nrow, normalize=True),
            global_step=epoch,
        )
        writer.add_image(
            "weight",
            torchvision.utils.make_grid(weight, nrow=nrow, normalize=True),
            global_step=epoch,
        )
        for i in tqdm(range(min(text.size(0), 4)), desc="writing audio"):
            writer.add_audio("audio/{}".format(i),
                             audio[i],
                             sample_rate=config.sample_rate,
                             global_step=epoch)
            writer.add_audio(
                "griffin-lim-true/{}".format(i),
                gl_true[i],
                sample_rate=config.sample_rate,
                global_step=epoch,
            )
            writer.add_audio(
                "griffin-lim-pred/{}".format(i),
                gl_pred[i],
                sample_rate=config.sample_rate,
                global_step=epoch,
            )

    writer.flush()
    writer.close()
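`masked_mse` is assumed to average the squared error over valid (unmasked) spectrogram positions only; a minimal sketch, assuming the mask broadcasts against the output shape:

def masked_mse(input, target, mask):
    """MSE restricted to positions where mask is set."""
    mask = mask.float()
    se = (input - target).pow(2) * mask
    return se.sum() / mask.sum().clamp(min=1)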
Example #8
def eval_epoch(model, data_loader, epoch, config):
    writer = SummaryWriter(os.path.join(config.experiment_path, "eval"))
    metrics = {
        "loss": Mean(),
        "iou": IoU(),
    }

    with torch.no_grad():
        model.eval()
        for images, targets in tqdm(
            data_loader, desc="epoch {}/{}, eval".format(epoch, config.epochs)
        ):
            images, targets = images.to(DEVICE), targets.to(DEVICE)
            targets = image_one_hot(targets, NUM_CLASSES)

            logits = model(images)
            loss = compute_loss(input=logits, target=targets)

            metrics["loss"].update(loss.data.cpu().numpy())
            metrics["iou"].update(input=logits.argmax(1), target=targets.argmax(1))

    with torch.no_grad():
        mask_true = F.interpolate(draw_masks(targets.argmax(1, keepdim=True)), scale_factor=1)
        mask_pred = F.interpolate(draw_masks(logits.argmax(1, keepdim=True)), scale_factor=1)

        for k in metrics:
            writer.add_scalar(k, metrics[k].compute_and_reset(), global_step=epoch)
        writer.add_image(
            "images",
            torchvision.utils.make_grid(
                denormalize(images), nrow=compute_nrow(images), normalize=True
            ),
            global_step=epoch,
        )
        writer.add_image(
            "mask_true",
            torchvision.utils.make_grid(mask_true, nrow=compute_nrow(mask_true), normalize=True),
            global_step=epoch,
        )
        writer.add_image(
            "mask_pred",
            torchvision.utils.make_grid(mask_pred, nrow=compute_nrow(mask_pred), normalize=True),
            global_step=epoch,
        )
        writer.add_image(
            "images_true",
            torchvision.utils.make_grid(
                denormalize(images) + mask_true, nrow=compute_nrow(images), normalize=True
            ),
            global_step=epoch,
        )
        writer.add_image(
            "images_pred",
            torchvision.utils.make_grid(
                denormalize(images) + mask_pred, nrow=compute_nrow(images), normalize=True
            ),
            global_step=epoch,
        )

    writer.flush()
    writer.close()
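The `IoU` metric presumably accumulates per-class intersection and union counts across batches. A sketch of the assumed behavior; the listing's `IoU()` likely binds the global NUM_CLASSES internally, while this version takes it explicitly:

import torch

class IoU:
    """Mean intersection-over-union over classes, accumulated across batches."""
    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.intersection = torch.zeros(num_classes)
        self.union = torch.zeros(num_classes)

    def update(self, input, target):
        # input and target are integer class maps of identical shape
        for c in range(self.num_classes):
            pred, true = input == c, target == c
            self.intersection[c] += (pred & true).sum().float().cpu()
            self.union[c] += (pred | true).sum().float().cpu()

    def compute_and_reset(self):
        iou = (self.intersection / self.union.clamp(min=1)).mean().item()
        self.intersection.zero_()
        self.union.zero_()
        return iou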
Example #9
def train_epoch(model, data_loader, optimizer, scheduler, epoch, config):
    metrics = {
        "x_loss": Mean(),
        "u_loss": Mean(),
        "u_loss_mask": Mean(),
        "lr": Last(),
    }

    model.train()
    for (x_w_images, x_targets), (u_w_images, u_s_images) in tqdm(
        data_loader, desc="epoch {}/{}, train".format(epoch, config.epochs)
    ):
        x_w_images, x_targets, u_w_images, u_s_images = (
            x_w_images.to(DEVICE),
            x_targets.to(DEVICE),
            u_w_images.to(DEVICE),
            u_s_images.to(DEVICE),
        )

        x_w_logits, u_w_logits, u_s_logits = model(
            torch.cat([x_w_images, u_w_images, u_s_images], 0)
        ).split([x_w_images.size(0), u_w_images.size(0), u_s_images.size(0)])

        # x ############################################################################################################
        x_loss = F.cross_entropy(input=x_w_logits, target=x_targets, reduction="none")
        metrics["x_loss"].update(x_loss.data.cpu().numpy())

        # u ############################################################################################################
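        # FixMatch-style pseudo-labeling: max(1) over the weakly augmented
        # logits' softmax yields a confidence score (u_loss_mask, thresholded
        # by tau below) and a hard pseudo-label (u_targets) for the strong view.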
        u_loss_mask, u_targets = F.softmax(u_w_logits.detach(), 1).max(1)
        u_loss_mask = (u_loss_mask >= config.train.tau).float()

        u_loss = u_loss_mask * F.cross_entropy(
            input=u_s_logits, target=u_targets, reduction="none"
        )
        metrics["u_loss"].update(u_loss.data.cpu().numpy())
        metrics["u_loss_mask"].update(u_loss_mask.data.cpu().numpy())

        # opt step #####################################################################################################
        metrics["lr"].update(np.squeeze(scheduler.get_lr()))

        optimizer.zero_grad()
        (x_loss.mean() + config.train.u_weight * u_loss.mean()).backward()
        optimizer.step()
        scheduler.step()

    if epoch % config.log_interval != 0:
        return

    writer = SummaryWriter(os.path.join(config.experiment_path, "train"))
    with torch.no_grad():
        for k in metrics:
            writer.add_scalar(k, metrics[k].compute_and_reset(), global_step=epoch)
        writer.add_image(
            "x_w_images",
            torchvision.utils.make_grid(
                denormalize(x_w_images), nrow=compute_nrow(x_w_images), normalize=True
            ),
            global_step=epoch,
        )
        writer.add_image(
            "u_w_images",
            torchvision.utils.make_grid(
                denormalize(u_w_images), nrow=compute_nrow(u_w_images), normalize=True
            ),
            global_step=epoch,
        )
        writer.add_image(
            "u_s_images",
            torchvision.utils.make_grid(
                denormalize(u_s_images), nrow=compute_nrow(u_s_images), normalize=True
            ),
            global_step=epoch,
        )

    writer.flush()
    writer.close()
Example #10
def main(config_path, **kwargs):
    config = load_config(config_path, **kwargs)

    gen = Gen(
        image_size=config.image_size,
        base_channels=config.gen.base_channels,
        max_channels=config.gen.max_channels,
        z_channels=config.noise_size,
    ).to(DEVICE)
    dsc = Dsc(
        image_size=config.image_size,
        base_channels=config.dsc.base_channels,
        max_channels=config.dsc.max_channels,
        batch_std=config.dsc.batch_std,
    ).to(DEVICE)
    gen_ema = copy.deepcopy(gen)
    ema = ModuleEMA(gen_ema, config.gen.ema)
    pl_ema = torch.zeros([], device=DEVICE)

    opt_gen = build_optimizer(gen.parameters(), config)
    opt_dsc = build_optimizer(dsc.parameters(), config)

    z_dist = ZDist(config.noise_size, DEVICE)
    z_fixed = z_dist(8 ** 2, truncation=1)

    if os.path.exists(os.path.join(config.experiment_path, "checkpoint.pth")):
        state = torch.load(os.path.join(config.experiment_path, "checkpoint.pth"))
        dsc.load_state_dict(state["dsc"])
        gen.load_state_dict(state["gen"])
        gen_ema.load_state_dict(state["gen_ema"])
        opt_gen.load_state_dict(state["opt_gen"])
        opt_dsc.load_state_dict(state["opt_dsc"])
        pl_ema.copy_(state["pl_ema"])
        z_fixed.copy_(state["z_fixed"])
        print("restored from checkpoint")

    dataset = build_dataset(config)
    print("dataset size: {}".format(len(dataset)))
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True,
        drop_last=True,
    )
    data_loader = ChunkedDataLoader(data_loader, config.batches_in_epoch)

    dsc_compute_loss, gen_compute_loss = build_loss(config)

    writer = SummaryWriter(config.experiment_path)
    for epoch in range(1, config.num_epochs + 1):
        metrics = {
            "dsc/loss": Mean(),
            "gen/loss": Mean(),
        }
        dsc_logits = Concat()
        dsc_targets = Concat()

        gen.train()
        dsc.train()
        gen_ema.train()
        for batch_i, real in enumerate(
            tqdm(
                data_loader,
                desc="{}/{}".format(epoch, config.num_epochs),
                disable=config.debug,
                smoothing=0.1,
            ),
            1,
        ):
            real = real.to(DEVICE)

            # generator: train
            with zero_grad_and_step(opt_gen):
                if config.debug:
                    print("gen")
                fake, _ = gen(z_dist(config.batch_size), z_dist(config.batch_size))
                assert (
                    fake.size() == real.size()
                ), "fake size {} does not match real size {}".format(fake.size(), real.size())

                # gen fake
                logits = dsc(fake)
                loss = gen_compute_loss(logits, True)
                loss.mean().backward()
                metrics["gen/loss"].update(loss.detach())

            # generator: regularize
            if batch_i % config.gen.reg_interval == 0:
                with zero_grad_and_step(opt_gen):
                    # path length regularization
                    fake, w = gen(z_dist(config.batch_size), z_dist(config.batch_size))
                    validate_shape(w, (None, config.batch_size, config.noise_size))
                    pl_noise = torch.randn_like(fake) / math.sqrt(fake.size(2) * fake.size(3))
                    (pl_grads,) = torch.autograd.grad(
                        outputs=[(fake * pl_noise).sum()],
                        inputs=[w],
                        create_graph=True,
                        only_inputs=True,
                    )
                    pl_lengths = pl_grads.square().sum(2).mean(0).sqrt()
                    pl_mean = pl_ema.lerp(pl_lengths.mean(), config.gen.pl_decay)
                    pl_ema.copy_(pl_mean.detach())
                    pl_penalty = (pl_lengths - pl_mean).square()
                    loss_pl = pl_penalty * config.gen.pl_weight * config.gen.reg_interval
                    loss_pl.mean().backward()

            # generator: update moving average
            ema.update(gen)

            # discriminator: train
            with zero_grad_and_step(opt_dsc):
                if config.debug:
                    print("dsc")
                with torch.no_grad():
                    fake, _ = gen(z_dist(config.batch_size), z_dist(config.batch_size))
                    assert (
                        fake.size() == real.size()
                    ), "fake size {} does not match real size {}".format(fake.size(), real.size())

                # dsc real
                logits = dsc(real)
                loss = dsc_compute_loss(logits, True)
                loss.mean().backward()
                metrics["dsc/loss"].update(loss.detach())

                dsc_logits.update(logits.detach())
                dsc_targets.update(torch.ones_like(logits))

                # dsc fake
                logits = dsc(fake.detach())
                loss = dsc_compute_loss(logits, False)
                loss.mean().backward()
                metrics["dsc/loss"].update(loss.detach())

                dsc_logits.update(logits.detach())
                dsc_targets.update(torch.zeros_like(logits))

            # discriminator: regularize
            if batch_i % config.dsc.reg_interval == 0:
                with zero_grad_and_step(opt_dsc):
                    # R1 regularization
                    real = real.detach().requires_grad_(True)
                    logits = dsc(real)
                    (r1_grads,) = torch.autograd.grad(
                        outputs=[logits.sum()],
                        inputs=[real],
                        create_graph=True,
                        only_inputs=True,
                    )
                    r1_penalty = r1_grads.square().sum([1, 2, 3])
                    loss_r1 = r1_penalty * (config.dsc.r1_gamma * 0.5) * config.dsc.reg_interval
                    loss_r1.mean().backward()

            # break

        dsc.eval()
        gen.eval()
        gen_ema.eval()
        with torch.no_grad(), log_duration("visualization took {:.2f} seconds"):
            infer = Infer(gen)
            infer_ema = Infer(gen_ema)

            real = infer.postprocess(real)
            fake = infer(z_fixed)
            fake_ema = infer_ema(z_fixed)
            fake_ema_mix, fake_ema_mix_nrow = visualize_style_mixing(
                infer_ema, z_fixed[0 : 8 * 2 : 2], z_fixed[1 : 8 * 2 : 2]
            )

            fake_ema_noise, fake_ema_noise_nrow = stack_images(
                [
                    fake_ema[:8],
                    visualize_noise(infer_ema, z_fixed[:8], 128),
                ]
            )

            dsc_logits = dsc_logits.compute_and_reset().data.cpu().numpy()
            dsc_targets = dsc_targets.compute_and_reset().data.cpu().numpy()

            metrics = {k: metrics[k].compute_and_reset() for k in metrics}
            metrics["dsc/ap"] = precision_recall_auc(input=dsc_logits, target=dsc_targets)
            for k in metrics:
                writer.add_scalar(k, metrics[k], global_step=epoch)
            writer.add_figure(
                "dsc/pr_curve",
                plot_pr_curve(input=dsc_logits, target=dsc_targets),
                global_step=epoch,
            )
            writer.add_image(
                "real",
                torchvision.utils.make_grid(real, nrow=compute_nrow(real)),
                global_step=epoch,
            )
            writer.add_image(
                "fake",
                torchvision.utils.make_grid(fake, nrow=compute_nrow(fake)),
                global_step=epoch,
            )
            writer.add_image(
                "fake_ema",
                torchvision.utils.make_grid(fake_ema, nrow=compute_nrow(fake_ema)),
                global_step=epoch,
            )
            writer.add_image(
                "fake_ema_mix",
                torchvision.utils.make_grid(fake_ema_mix, nrow=fake_ema_mix_nrow),
                global_step=epoch,
            )
            writer.add_image(
                "fake_ema_noise",
                torchvision.utils.make_grid(fake_ema_noise, nrow=fake_ema_noise_nrow * 2),
                global_step=epoch,
            )
            # break
            torch.save(
                {
                    "gen": gen.state_dict(),
                    "gen_ema": gen_ema.state_dict(),
                    "dsc": dsc.state_dict(),
                    "opt_gen": opt_gen.state_dict(),
                    "opt_dsc": opt_dsc.state_dict(),
                    "pl_ema": pl_ema,
                    "z_fixed": z_fixed,
                },
                os.path.join(config.experiment_path, "checkpoint.pth"),
            )
        # break

    writer.flush()
    writer.close()
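`zero_grad_and_step` and `Concat` are also assumed. A plausible reading: the former is a context manager that zeroes gradients before the block (which calls backward) and steps the optimizer after it; the latter concatenates everything passed to update():

import contextlib
import torch

@contextlib.contextmanager
def zero_grad_and_step(optimizer):
    """optimizer.zero_grad() on entry, optimizer.step() on exit."""
    optimizer.zero_grad()
    yield
    optimizer.step()

class Concat:
    """Metric that concatenates all updates along dim 0."""
    def __init__(self):
        self.values = []

    def update(self, value):
        self.values.append(value)

    def compute_and_reset(self):
        value = torch.cat(self.values, 0)
        self.values = []
        return value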
Example #11
def eval_epoch(model, data_loader, epoch, config, suffix=""):
    metrics = {
        "images": Concat(),
        "targets": Concat(),
        "logits": Concat(),
        "loss": Concat(),
    }

    # loop over batches ################################################################################################
    model.eval()
    with torch.no_grad():
        for images, meta, targets in tqdm(
            data_loader,
            desc="fold {}, epoch {}/{}, eval".format(config.fold, epoch, config.train.epochs),
        ):
            images, meta, targets = (
                images.to(DEVICE),
                {k: meta[k].to(DEVICE) for k in meta},
                targets.to(DEVICE),
            )

            logits = model(images, meta)
            loss = compute_loss(input=logits, target=targets, config=config)

            metrics["images"].update(images.data.cpu())
            metrics["targets"].update(targets.data.cpu())
            metrics["logits"].update(logits.data.cpu())
            metrics["loss"].update(loss.data.cpu())

    # compute metrics ##################################################################################################
    with torch.no_grad():
        metrics = {k: metrics[k].compute_and_reset() for k in metrics}
        metrics.update(compute_metric(input=metrics["logits"], target=metrics["targets"]))
        images_hard_pos = topk_hardest(
            metrics["images"],
            metrics["loss"],
            metrics["targets"] > 0.5,
            topk=config.eval.batch_size,
        )
        images_hard_neg = topk_hardest(
            metrics["images"],
            metrics["loss"],
            metrics["targets"] <= 0.5,
            topk=config.eval.batch_size,
        )
        roc_curve = plot_roc_curve(input=metrics["logits"], target=metrics["targets"])
        metrics["loss"] = metrics["loss"].mean()

        writer = SummaryWriter(os.path.join(config.experiment_path, "eval", suffix))
        writer.add_image(
            "images/hard/pos",
            torchvision.utils.make_grid(
                images_hard_pos, nrow=compute_nrow(images_hard_pos), normalize=True
            ),
            global_step=epoch,
        )
        writer.add_image(
            "images/hard/neg",
            torchvision.utils.make_grid(
                images_hard_neg, nrow=compute_nrow(images_hard_neg), normalize=True
            ),
            global_step=epoch,
        )
        writer.add_scalar("loss", metrics["loss"], global_step=epoch)
        writer.add_scalar("roc_auc", metrics["roc_auc"], global_step=epoch)
        writer.add_figure("roc_curve", roc_curve, global_step=epoch)

        writer.flush()
        writer.close()

    return metrics["roc_auc"]
Example #12
def train_epoch(model, data_loader, opt_teacher, opt_student, sched_teacher,
                sched_student, epoch, config):
    metrics = {
        "teacher/loss": Mean(),
        "teacher/grad_norm": Mean(),
        "teacher/lr": Last(),
        "student/loss": Mean(),
        "student/grad_norm": Mean(),
        "student/lr": Last(),
    }

    model.train()
    for (x_image, x_target), (u_image, ) in tqdm(
            data_loader,
            desc="epoch {}/{}, train".format(epoch, config.epochs)):
        x_image, x_target, u_image = x_image.to(DEVICE), x_target.to(
            DEVICE), u_image.to(DEVICE)

        with higher.innerloop_ctx(model.student,
                                  opt_student) as (h_model_student,
                                                   h_opt_student):
            # student ##################################################################################################

            loss_student = cross_entropy(input=h_model_student(u_image),
                                         target=model.teacher(u_image)).mean()
            metrics["student/loss"].update(loss_student.data.cpu().numpy())
            metrics["student/lr"].update(
                np.squeeze(sched_student.get_last_lr()))

            def grad_callback(grads):
                metrics["student/grad_norm"].update(
                    grad_norm(grads).data.cpu().numpy())

                return grads

            h_opt_student.step(loss_student.mean(),
                               grad_callback=grad_callback)
            sched_student.step()

            # teacher ##################################################################################################

            loss_teacher = (
                cross_entropy(input=model.teacher(x_image),
                              target=one_hot(x_target, NUM_CLASSES)).mean() +
                cross_entropy(input=h_model_student(x_image),
                              target=one_hot(x_target, NUM_CLASSES)).mean())

            metrics["teacher/loss"].update(loss_teacher.data.cpu().numpy())
            metrics["teacher/lr"].update(
                np.squeeze(sched_teacher.get_last_lr()))

            opt_teacher.zero_grad()
            loss_teacher.mean().backward()
            opt_teacher.step()
            metrics["teacher/grad_norm"].update(
                grad_norm(
                    p.grad
                    for p in model.teacher.parameters()).data.cpu().numpy())
            sched_teacher.step()

            # copy student weights #####################################################################################

            with torch.no_grad():
                for p, p_prime in zip(model.student.parameters(),
                                      h_model_student.parameters()):
                    p.copy_(p_prime)

    if epoch % config.log_interval != 0:
        return

    writer = SummaryWriter(os.path.join(config.experiment_path, "train"))
    with torch.no_grad():
        for k in metrics:
            writer.add_scalar(k,
                              metrics[k].compute_and_reset(),
                              global_step=epoch)
        writer.add_image(
            "x_image",
            torchvision.utils.make_grid(denormalize(x_image),
                                        nrow=compute_nrow(x_image),
                                        normalize=True),
            global_step=epoch,
        )
        writer.add_image(
            "u_image",
            torchvision.utils.make_grid(denormalize(u_image),
                                        nrow=compute_nrow(u_image),
                                        normalize=True),
            global_step=epoch,
        )

    writer.flush()
    writer.close()
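`grad_norm` is assumed to be the global L2 norm over an iterable of gradient tensors (as in torch.nn.utils.clip_grad_norm_, minus the clipping):

import torch

def grad_norm(grads):
    """Global L2 norm of an iterable of gradient tensors; None grads are skipped."""
    flat = [g.detach().reshape(-1) for g in grads if g is not None]
    return torch.cat(flat).norm(2)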
Example #13
def main(config_path, **kwargs):
    config = load_config(config_path, **kwargs)

    gen = Gen(
        image_size=config.image_size,
        image_channels=3,
        base_channels=config.gen.base_channels,
        z_channels=config.noise_size,
    ).to(DEVICE)
    dsc = Dsc(
        image_size=config.image_size,
        image_channels=3,
        base_channels=config.dsc.base_channels,
    ).to(DEVICE)
    # gen_ema = gen
    # gen_ema = copy.deepcopy(gen)
    # ema = EMA(gen_ema, 0.99)

    gen.train()
    dsc.train()
    # gen_ema.train()

    opt_gen = build_optimizer(gen.parameters(), config)
    opt_dsc = build_optimizer(dsc.parameters(), config)

    transform = T.Compose([
        T.Resize(config.image_size),
        T.RandomCrop(config.image_size),
        T.ToTensor(),
        T.Normalize([0.5], [0.5]),
    ])

    # dataset = torchvision.datasets.MNIST(
    #     "./data/mnist", train=True, transform=transform, download=True
    # )
    # dataset = torchvision.datasets.CelebA(
    #     "./data/celeba", split="all", transform=transform, download=True
    # )
    dataset = ImageFolderDataset("./data/wikiart/resized/landscape",
                                 transform=transform)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True,
        drop_last=True,
    )

    noise_dist = torch.distributions.Normal(0, 1)
    dsc_compute_loss, gen_compute_loss = build_loss(config)

    writer = SummaryWriter("./log")
    for epoch in range(1, config.num_epochs + 1):
        metrics = {
            "dsc/loss": Mean(),
            "gen/loss": Mean(),
        }
        dsc_logits = Concat()
        dsc_targets = Concat()

        for batch_index, real in enumerate(
                tqdm(data_loader,
                     desc="{}/{}".format(epoch, config.num_epochs),
                     disable=config.debug)):
            real = real.to(DEVICE)

            # train discriminator
            with zero_grad_and_step(opt_dsc):
                if config.debug:
                    print("dsc")
                noise = noise_dist.sample(
                    (real.size(0), config.noise_size)).to(DEVICE)
                with torch.no_grad():
                    fake = gen(noise)
                    assert (fake.size() == real.size(
                    )), "fake size {} does not match real size {}".format(
                        fake.size(), real.size())

                # dsc real
                logits = dsc(real)
                loss = dsc_compute_loss(logits, True)
                loss.mean().backward()
                metrics["dsc/loss"].update(loss.detach())

                dsc_logits.update(logits.detach())
                dsc_targets.update(torch.ones_like(logits))

                # dsc fake
                logits = dsc(fake.detach())
                loss = dsc_compute_loss(logits, False)
                loss.mean().backward()
                metrics["dsc/loss"].update(loss.detach())

                dsc_logits.update(logits.detach())
                dsc_targets.update(torch.zeros_like(logits))

                if (batch_index + 1) % 8 == 0:
                    # R1 regularization (lazy: applied every 8th batch, with
                    # the loss scaled by 8 below to compensate)
                    r1_gamma = 10
                    real = real.detach().requires_grad_(True)
                    logits = dsc(real)
                    (r1_grads, ) = torch.autograd.grad(outputs=[logits.sum()],
                                                       inputs=[real],
                                                       create_graph=True,
                                                       only_inputs=True)
                    r1_penalty = r1_grads.square().sum([1, 2, 3])
                    loss_r1 = r1_penalty * (r1_gamma / 2) * 8
                    loss_r1.mean().backward()

            if config.dsc.weight_clip is not None:
                clip_parameters(dsc, config.dsc.weight_clip)

            if (batch_index + 1) % config.dsc.num_steps != 0:
                continue

            # train generator
            with zero_grad_and_step(opt_gen):
                if config.debug:
                    print("gen")
                noise = noise_dist.sample(
                    (real.size(0), config.noise_size)).to(DEVICE)
                fake = gen(noise)
                assert (fake.size() == real.size()
                        ), "fake size {} does not match real size {}".format(
                            fake.size(), real.size())

                # gen fake
                logits = dsc(fake)
                loss = gen_compute_loss(logits, True)
                loss.mean().backward()
                metrics["gen/loss"].update(loss.detach())

                # update moving average
                # ema.update(gen)

        with torch.no_grad():
            real, fake = [(x[:4**2] * 0.5 + 0.5).clamp(0, 1)
                          for x in [real, fake]]

            dsc_logits = dsc_logits.compute_and_reset().data.cpu().numpy()
            dsc_targets = dsc_targets.compute_and_reset().data.cpu().numpy()

            metrics = {k: metrics[k].compute_and_reset() for k in metrics}
            metrics["ap"] = precision_recall_auc(input=dsc_logits,
                                                 target=dsc_targets)
            for k in metrics:
                writer.add_scalar(k, metrics[k], global_step=epoch)
            writer.add_figure(
                "dsc/pr_curve",
                plot_pr_curve(input=dsc_logits, target=dsc_targets),
                global_step=epoch,
            )
            writer.add_image(
                "real",
                torchvision.utils.make_grid(real, nrow=compute_nrow(real)),
                global_step=epoch,
            )
            writer.add_image(
                "fake",
                torchvision.utils.make_grid(fake, nrow=compute_nrow(fake)),
                global_step=epoch,
            )
            # writer.add_image(
            #     "fake_ema",
            #     torchvision.utils.make_grid(fake, nrow=compute_nrow(fake)),
            #     global_step=epoch,
            # )

    writer.flush()
    writer.close()
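`clip_parameters` is presumably WGAN-style weight clipping, clamping every discriminator parameter into [-clip, clip] after each optimizer step; a minimal sketch:

import torch

def clip_parameters(module, clip):
    """Clamp every parameter of module to the range [-clip, clip] in place."""
    with torch.no_grad():
        for p in module.parameters():
            p.clamp_(-clip, clip)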