Example #1
0
def valid(args, epoch, loader, model, show):
    """Evaluate a segmentation model, reporting pixel accuracy and mIoU.

    The model runs in eval mode with autograd disabled. Every 10th batch,
    rank 0 renders the first image of the batch with ``show`` and saves the
    result under ``sample/``. Per-image intersection/union counts from
    ``intersection_union`` are accumulated and, after every batch, summed
    across processes with ``all_gather`` so rank 0 can display a running mIoU.

    Args:
        args: Options object; only ``args.n_class`` is read.
        epoch: Zero-based epoch index, used (as ``epoch + 1``) in the sample
            file names.
        loader: Iterable of ``(img, annot)`` batches.
        model: Called as ``_, out = model(img)``; class scores are reduced
            with ``out.max(1)``.
        show: Callable ``show(img, annot, pred)`` returning an object with a
            ``save(path)`` method.

    Note:
        Pixel accuracy uses this process's batches only, whereas mIoU is
        aggregated over all processes. ``all_gather`` runs every batch,
        presumably as a collective, so every rank should see the same number
        of batches — TODO confirm. ``cudnn.benchmark`` is switched off and
        not restored.
    """
    torch.backends.cudnn.benchmark = False

    model.eval()

    if get_rank() == 0:
        pbar = tqdm(loader, dynamic_ncols=True)
    else:
        pbar = loader

    intersect_sum = None
    union_sum = None
    correct_sum = 0.0
    total_sum = 0.0

    # Validation needs no gradients; not building the autograd graph saves
    # activation memory on every batch.
    with torch.no_grad():
        for i, (img, annot) in enumerate(pbar):
            img = img.to('cuda')
            annot = annot.to('cuda')
            _, out = model(img)
            _, pred = out.max(1)

            if get_rank() == 0 and i % 10 == 0:
                result = show(img[0], annot[0], pred[0])
                result.save(f'sample/{str(epoch + 1).zfill(3)}-{str(i).zfill(4)}.png')

            # Zero out predictions on pixels whose annotation is 0, so only
            # labeled pixels are scored.
            pred = (annot > 0) * pred
            correct = (pred > 0) * (pred == annot)
            correct_sum += correct.sum().float().item()
            # Keep the running total a Python float (like correct_sum) rather
            # than an ever-growing device tensor.
            total_sum += (annot > 0).sum().float().item()

            for g, p, c in zip(annot, pred, correct):
                intersect, union = intersection_union(g, p, c, args.n_class)

                if intersect_sum is None:
                    intersect_sum = intersect
                else:
                    intersect_sum += intersect

                if union_sum is None:
                    union_sum = union
                else:
                    union_sum += union

            all_intersect = sum(all_gather(intersect_sum.to('cpu')))
            all_union = sum(all_gather(union_sum.to('cpu')))

            if get_rank() == 0:
                iou = all_intersect / (all_union + 1e-10)
                m_iou = iou.mean().item()
                # 0/0 (no labeled pixel seen yet) reports nan instead of
                # raising ZeroDivisionError on Python floats.
                acc = correct_sum / total_sum if total_sum else float('nan')

                pbar.set_description(
                    f'acc: {acc:.5f}; mIoU: {m_iou:.5f}'
                )
Example #2
0
    def train_epoch(self, epoch, loader, optimizer, scheduler, device,
                    sample_path):
        """Train for one epoch, yielding running statistics after every batch.

        Each batch is handed to ``self.train_step`` together with an MSE
        criterion and a latent loss weight of 0.25; it returns the
        reconstruction and latent losses. Per-batch MSE sums are all-gathered
        through ``dist_fn`` so the running average covers every process. Every
        100 batches, up to 25 inputs and their reconstructions are saved as one
        image grid in ``sample_path``.

        Args:
            epoch: Zero-based epoch index, used (as ``epoch + 1``) in sample
                file names.
            loader: Iterable of ``(img, label)`` batches; ``label`` is unused.
            optimizer: Optimizer passed to ``train_step``; its first parameter
                group's learning rate is reported.
            scheduler: Passed through to ``train_step``.
            device: Device passed to ``train_step`` and used for sampling.
            sample_path: Directory the sample grids are written to.

        Yields:
            ``(recon_loss, latent_loss, avg_mse, lr)`` as Python numbers.
        """
        criterion = nn.MSELoss()

        latent_loss_weight = 0.25
        sample_size = 25

        mse_sum = 0
        mse_n = 0

        for i, (img, _label) in enumerate(loader):
            recon_loss, latent_loss = self.train_step(img, optimizer,
                                                      criterion, device,
                                                      scheduler,
                                                      latent_loss_weight)

            n_batch = img.shape[0]
            comm = {'mse_sum': recon_loss.item() * n_batch, 'mse_n': n_batch}
            for part in dist_fn.all_gather(comm):
                mse_sum += part['mse_sum']
                mse_n += part['mse_n']

            lr = optimizer.param_groups[0]['lr']

            if i % 100 == 0:
                # NOTE(review): this runs on every process, so in a
                # multi-process run each rank writes the same file — consider
                # restricting it to the primary rank.
                self.eval()

                # ``img`` is the batch exactly as the loader produced it
                # (``train_step`` receives ``device`` and presumably moves its
                # own copy), so move the sample here to match the model's
                # device; a no-op if it is already there.
                sample = img[:sample_size].to(device)

                with torch.no_grad():
                    out, _ = self(sample)

                utils.save_image(
                    torch.cat([sample, out], 0),
                    os.path.join(
                        sample_path,
                        f'{str(epoch + 1).zfill(5)}_{str(i).zfill(5)}.png'),
                    nrow=sample_size,
                    normalize=True,
                    value_range=(-1, 1),
                )

                self.train()

            yield recon_loss.item(), latent_loss.item(), mse_sum / mse_n, lr
Example #3
0
def accumulate_predictions(predictions):
    """Gather per-process prediction dicts and merge them on rank 0.

    Args:
        predictions: Mapping from sample id to prediction produced by this
            process. Combined across processes, the ids are expected to be
            exactly ``0..N-1``; a message is printed otherwise.

    Returns:
        On rank 0, the merged predictions as a list ordered by id (an empty
        list when no process produced any prediction). On every other rank,
        ``None``.
    """
    # all_gather is called before the rank check, presumably because it is a
    # collective that every process must join.
    all_predictions = all_gather(predictions)

    if get_rank() != 0:
        return

    merged = {}
    for part in all_predictions:
        merged.update(part)

    ids = sorted(merged)

    # Without this guard ``ids[-1]`` raises IndexError on an empty result.
    if not ids:
        return []

    # Ids should be 0..N-1; a gap means some samples are missing.
    if len(ids) != ids[-1] + 1:
        print('Evaluation results is not contiguous')

    return [merged[i] for i in ids]
Example #4
0
def train(epoch, loader, model, optimizer, scheduler, scaler, device):
    """Train for one epoch with optional automatic mixed precision.

    ``model(img)`` must return ``(reconstruction, latent_loss)``. The loss is
    the MSE between reconstruction and input plus 0.25 times the mean latent
    loss; autocast is enabled exactly when ``scaler.is_enabled()``. Per-batch
    MSE sums are all-gathered so the running average covers every process.
    On the primary process the progress bar shows the current and average
    losses and the learning rate, and every 100 batches up to 25 inputs and
    their reconstructions are saved as one grid under ``sample/``.

    Args:
        epoch: Zero-based epoch index (shown and used in file names as
            ``epoch + 1``).
        loader: Iterable of ``(img, label)`` batches; ``label`` is unused.
        model: Model being trained.
        optimizer: Optimizer stepped through ``scaler``.
        scheduler: Optional object with ``step()``, stepped every batch.
        scaler: Object providing ``scale``/``step``/``update``/``is_enabled``
            (e.g. a ``GradScaler``).
        device: Device each batch is moved to.
    """
    if dist.is_primary():
        loader = tqdm(loader)

    criterion = nn.MSELoss()

    latent_loss_weight = 0.25
    sample_size = 25

    mse_sum = 0
    mse_n = 0

    for i, (img, _label) in enumerate(loader):
        model.zero_grad()

        img = img.to(device)

        with torch.cuda.amp.autocast(scaler.is_enabled()):
            out, latent_loss = model(img)
            recon_loss = criterion(out, img)
            latent_loss = latent_loss.mean()
            loss = recon_loss + latent_loss_weight * latent_loss
        scaler.scale(loss).backward()

        # NOTE(review): the scheduler is stepped before the optimizer. This is
        # kept as-is since a custom scheduler may set this step's lr first —
        # confirm.
        if scheduler is not None:
            scheduler.step()
        scaler.step(optimizer)
        scaler.update()

        n_batch = img.shape[0]
        comm = {"mse_sum": recon_loss.item() * n_batch, "mse_n": n_batch}
        for part in dist.all_gather(comm):
            mse_sum += part["mse_sum"]
            mse_n += part["mse_n"]

        if dist.is_primary():
            lr = optimizer.param_groups[0]["lr"]

            loader.set_description((
                f"epoch: {epoch + 1}; mse: {recon_loss.item():.5f}; "
                f"latent: {latent_loss.item():.3f}; avg mse: {mse_sum / mse_n:.5f}; "
                f"lr: {lr:.5f}"))

            if i % 100 == 0:
                model.eval()

                sample = img[:sample_size]

                with torch.no_grad(), torch.cuda.amp.autocast(
                        scaler.is_enabled()):
                    out, _ = model(sample)

                # ``value_range`` replaces the ``range`` keyword, which
                # torchvision deprecated and later removed.
                utils.save_image(
                    torch.cat([sample, out], 0),
                    f"sample/{str(epoch + 1).zfill(5)}_{str(i).zfill(5)}.png",
                    nrow=sample_size,
                    normalize=True,
                    value_range=(-1, 1),
                )

                model.train()
Example #5
0
def train(epoch, loader, model, optimizer, scheduler, device):
    """Train for one epoch, logging to Weights & Biases.

    ``model(img)`` must return ``(reconstruction, latent_loss)``. The loss is
    the MSE between reconstruction and input plus 0.25 times the mean latent
    loss, and its value is sent to ``wandb.log`` after every backward pass.
    The scheduler (if any) is stepped before the optimizer. Per-batch MSE sums
    are all-gathered so the running average covers every process. On the
    primary process the progress bar shows the current and average losses
    and the learning rate, and every 100 batches the reconstructions of up to
    25 inputs are logged to wandb as images.

    Args:
        epoch: Zero-based epoch index.
        loader: Iterable of ``(img, label)`` batches; ``label`` is unused.
        model: Model being trained.
        optimizer: Optimizer stepped every batch.
        scheduler: Optional object with ``step()``, stepped every batch.
        device: Device each batch is moved to.
    """
    if dist.is_primary():
        loader = tqdm(loader)

    criterion = nn.MSELoss()

    latent_loss_weight = 0.25
    sample_size = 25

    mse_sum = 0
    mse_n = 0

    for i, (img, _label) in enumerate(loader):
        model.zero_grad()

        img = img.to(device)

        out, latent_loss = model(img)
        recon_loss = criterion(out, img)
        latent_loss = latent_loss.mean()
        loss = recon_loss + latent_loss_weight * latent_loss
        loss.backward()

        # NOTE(review): unlike the image logging below, this runs on every
        # process — confirm wandb is initialised on all ranks.
        wandb.log({'train loss': loss.item()})

        if scheduler is not None:
            scheduler.step()
        optimizer.step()

        n_batch = img.shape[0]
        comm = {"mse_sum": recon_loss.item() * n_batch, "mse_n": n_batch}
        for part in dist.all_gather(comm):
            mse_sum += part["mse_sum"]
            mse_n += part["mse_n"]

        if dist.is_primary():
            lr = optimizer.param_groups[0]["lr"]

            loader.set_description((
                f"epoch: {epoch + 1}; mse: {recon_loss.item():.5f}; "
                f"latent: {latent_loss.item():.3f}; avg mse: {mse_sum / mse_n:.5f}; "
                f"lr: {lr:.5f}"))

            if i % 100 == 0:
                model.eval()

                sample = img[:sample_size]

                with torch.no_grad():
                    out, _ = model(sample)

                example_images = [
                    wandb.Image(image, caption=f"{epoch}_{i}") for image in out
                ]
                wandb.log({"Examples": example_images})

                model.train()
def train(epoch, loader, discriminator, generator, scheduler_D, scheduler_G,
          optimizer_D, optimizer_G, device):
    """Run one epoch of adversarial training.

    The discriminator is trained every epoch. The generator is trained only
    when ``epoch + 1`` is a multiple of the module-level ``n_critic``. Each
    phase iterates over the whole ``loader``.

    Discriminator phase (generator parameters frozen): BCE-with-logits pushes
    ``discriminator(label)`` towards 1 and the discriminator's score of the
    generator output towards 0. The loss is the mean of the two terms.

    Generator phase (discriminator parameters frozen): the loss is
    ``0.1 * adversarial + 0.9 * (0.4 * L1 + 0.2 * latent + 0.4 * gradient)``.
    The adversarial term pushes the discriminator's score of the generated
    image towards 1. Running averages of the L1 loss (labelled "mse" in the
    progress text) and the gradient loss are all-gathered across processes.
    Every 100 generator batches, the primary process saves a grid of the
    first two input channels, the labels and the generator outputs under
    ``sample/``.

    Args:
        epoch: Zero-based epoch index.
        loader: Iterable of ``(img, label, label_path, class_name)`` batches.
        discriminator: Model scoring images; returns logits.
        generator: Model called as ``output, latent_loss = generator(img)``.
        scheduler_D, scheduler_G: Optional objects with ``step()``, each
            stepped before its optimizer.
        optimizer_D, optimizer_G: Optimizers for the two models.
        device: Device batches are moved to.

    Relies on the module-level globals ``n_critic`` and ``batch_size``.
    """
    train_generator = (epoch + 1) % n_critic == 0

    loader_d = tqdm(loader)
    if train_generator:
        loader_g = tqdm(loader)

    adversarial_loss = nn.BCEWithLogitsLoss()  # applies the sigmoid internally
    pixelwise_loss = nn.L1Loss()
    gdloss = GDLoss()
    # Move the gradient-loss kernels once rather than on every batch.
    gdloss.conv_x = gdloss.conv_x.to(device)
    gdloss.conv_y = gdloss.conv_y.to(device)

    recon_loss_weight = 0.4
    latent_loss_weight = 0.2
    gradient_loss_weight = 0.4
    sample_size = batch_size

    mse_sum = 0
    mse_n = 0
    g_sum = 0
    g_n = 0

    requires_grad(generator, False)
    requires_grad(discriminator, True)

    # ---------------------
    #  Train Discriminator
    # ---------------------
    for img, label, _label_path, _class_name in loader_d:
        discriminator.zero_grad()

        img = img.to(device)
        label = label.to(device)
        # BCE targets: 1 for real labels, 0 for generated images.
        valid = torch.ones(img.shape[0], 1, device=device)
        fake = torch.zeros(img.shape[0], 1, device=device)

        # The generator is frozen in this phase, so its forward pass needs no
        # autograd graph.
        with torch.no_grad():
            vqvae2_out, _ = generator(img)

        real_loss = adversarial_loss(discriminator(label), valid)
        fake_loss = adversarial_loss(discriminator(vqvae2_out), fake)

        d_loss = 0.5 * (real_loss + fake_loss)

        d_loss.backward()

        if scheduler_D is not None:
            scheduler_D.step()
        optimizer_D.step()

        if dist.is_primary():
            lr = optimizer_D.param_groups[0]["lr"]

            loader_d.set_description((
                f"Discriminator epoch: {epoch + 1}; class loss: {d_loss.item():.5f};"
                f"lr: {lr:.5f}"))

    # ---------------------
    #  Train Generator
    # ---------------------
    if not train_generator:
        return

    requires_grad(generator, True)
    requires_grad(discriminator, False)
    for i, (img, label, _label_path, _class_name) in enumerate(loader_g):
        generator.zero_grad()

        img = img.to(device)
        label = label.to(device)
        # The generator is rewarded when the discriminator scores its output
        # as real.
        valid = torch.ones(img.shape[0], 1, device=device)

        vqvae2_out, latent_loss = generator(img)

        recon_loss = pixelwise_loss(vqvae2_out, label)
        gradient_loss = gdloss(vqvae2_out, label).mean()
        latent_loss = latent_loss.mean()
        content_loss = (recon_loss_weight * recon_loss
                        + latent_loss_weight * latent_loss
                        + gradient_loss_weight * gradient_loss)
        g_loss = (0.1 * adversarial_loss(discriminator(vqvae2_out), valid)
                  + 0.9 * content_loss)

        g_loss.backward()

        if scheduler_G is not None:
            scheduler_G.step()
        optimizer_G.step()

        n_batch = img.shape[0]
        comm = {"mse_sum": recon_loss.item() * n_batch, "mse_n": n_batch}
        for part in dist.all_gather(comm):
            mse_sum += part["mse_sum"]
            mse_n += part["mse_n"]

        g_comm = {"g_sum": gradient_loss.item() * n_batch, "g_n": n_batch}
        for part in dist.all_gather(g_comm):
            g_sum += part["g_sum"]
            g_n += part["g_n"]

        if dist.is_primary():
            lr = optimizer_G.param_groups[0]["lr"]

            # The generator trains when (epoch + 1) is a multiple of n_critic,
            # so (epoch + 1) // n_critic is already its 1-based epoch number.
            loader_g.set_description((
                f"Generator epoch: {(epoch + 1) // n_critic}; mse: {recon_loss.item():.5f}; "
                f"latent: {latent_loss.item():.3f}; gradient: {g_sum / g_n:.5f}; avg mse: {mse_sum / mse_n:.5f}; "
                f"lr: {lr:.5f}"))

            # Only the primary process writes samples, so ranks do not race
            # on the same file.
            if i % 100 == 0:
                generator.eval()

                sample = img[:sample_size]
                label_sample = label[:sample_size]
                # First two input channels, each kept as a 1-channel image.
                sample0 = sample[:, 0, :, :].unsqueeze(dim=1)
                sample1 = sample[:, 1, :, :].unsqueeze(dim=1)
                with torch.no_grad():
                    out, _ = generator(sample)

                # ``value_range`` replaces the ``range`` keyword, which
                # torchvision deprecated and later removed.
                utils.save_image(
                    torch.cat([sample0, sample1, label_sample, out], 0),
                    f"sample/{str(epoch + 1).zfill(5)}_{str(i).zfill(5)}.png",
                    nrow=sample_size,
                    normalize=True,
                    value_range=(-1, 1),
                )

                generator.train()
def train(epoch, loader, model, optimizer, scheduler, device):
    """Train a next-frame model for one epoch.

    ``model(frames, next_frames)`` must return ``(prediction, latent_loss)``.
    The loss is the MSE between the prediction and ``next_frames`` plus 0.25
    times the mean latent loss. The scheduler (if any) is stepped before the
    optimizer. Per-batch MSE sums are all-gathered so the running average
    covers every process. On the primary process the progress bar shows the
    current and average losses and the learning rate. Every 100 batches, up
    to 25 input frames and their predictions are saved as one grid under
    ``offset_sample/``.

    Args:
        epoch: Zero-based epoch index (shown and used in file names as
            ``epoch + 1``).
        loader: Iterable of ``(frames, next_frames)`` batches.
        model: Model being trained.
        optimizer: Optimizer stepped every batch.
        scheduler: Optional object with ``step()``, stepped every batch.
        device: Device each batch is moved to.
    """
    if dist.is_primary():
        loader = tqdm(loader)

    criterion = nn.MSELoss()

    latent_loss_weight = 0.25
    sample_size = 25

    mse_sum = 0
    mse_n = 0

    for i, (frames, next_frames) in enumerate(loader):
        model.zero_grad()

        frames = frames.to(device)
        next_frames = next_frames.to(device)

        out, latent_loss = model(frames, next_frames)
        recon_loss = criterion(out, next_frames)
        latent_loss = latent_loss.mean()
        loss = recon_loss + latent_loss_weight * latent_loss
        loss.backward()

        if scheduler is not None:
            scheduler.step()
        optimizer.step()

        n_batch = frames.shape[0]
        comm = {"mse_sum": recon_loss.item() * n_batch, "mse_n": n_batch}
        for part in dist.all_gather(comm):
            mse_sum += part["mse_sum"]
            mse_n += part["mse_n"]

        if dist.is_primary():
            lr = optimizer.param_groups[0]["lr"]

            loader.set_description((
                f"epoch: {epoch + 1}; mse: {recon_loss.item():.5f}; "
                f"latent: {latent_loss.item():.3f}; avg mse: {mse_sum / mse_n:.5f}; "
                f"lr: {lr:.5f}"))

            if i % 100 == 0:
                model.eval()

                sample_frames = frames[:sample_size]
                sample_next = next_frames[:sample_size]

                with torch.no_grad():
                    out, _ = model(sample_frames, sample_next)

                # ``value_range`` replaces the ``range`` keyword, which
                # torchvision deprecated and later removed.
                utils.save_image(
                    torch.cat([sample_frames, out], 0),
                    f"offset_sample/{str(epoch + 1).zfill(5)}_{str(i).zfill(5)}.png",
                    nrow=sample_size,
                    normalize=True,
                    value_range=(-1, 1),
                )

                model.train()
def train(epoch, loader, model, optimizer, scheduler, device):
    """Train for one epoch on ``(img, label)`` pairs with three weighted losses.

    ``model(img)`` must return ``(output, latent_loss)``. The loss is
    ``0.4 * L1(output, label) + 0.2 * mean(latent_loss) + 0.4 * gradient``,
    where ``gradient`` is the mean of ``GDLoss()(output, label)``. The
    scheduler (if any) is stepped before the optimizer. Running averages of
    the L1 and gradient losses are all-gathered across processes. On the
    primary process the progress bar shows the losses and learning rate.
    Every 100 batches, a grid of the first two input channels, the labels and
    the model outputs is saved under ``sample/``.

    Args:
        epoch: Zero-based epoch index (shown and used in file names as
            ``epoch + 1``).
        loader: Iterable of ``(img, label, label_path, class_name)`` batches.
        model: Model being trained.
        optimizer: Optimizer stepped every batch.
        scheduler: Optional object with ``step()``, stepped every batch.
        device: Device each batch is moved to.

    Relies on the module-level global ``batch_size``.
    """
    # The progress bar wraps the loader on every process, not only the primary.
    loader = tqdm(loader)

    criterion = nn.L1Loss()
    gdloss = GDLoss()
    # Move the gradient-loss kernels once rather than on every batch.
    gdloss.conv_x = gdloss.conv_x.to(device)
    gdloss.conv_y = gdloss.conv_y.to(device)

    recon_loss_weight = 0.4
    latent_loss_weight = 0.2
    gradient_loss_weight = 0.4
    sample_size = batch_size

    # "mse" accumulators/labels hold the L1 loss, since ``criterion`` is L1Loss.
    mse_sum = 0
    mse_n = 0
    g_sum = 0
    g_n = 0

    for i, (img, label, _label_path, _class_name) in enumerate(loader):
        model.zero_grad()

        img = img.to(device)
        label = label.to(device)

        out, latent_loss = model(img)
        recon_loss = criterion(out, label)
        gradient_loss = gdloss(out, label).mean()
        latent_loss = latent_loss.mean()
        loss = recon_loss_weight * recon_loss + latent_loss_weight * latent_loss + gradient_loss_weight * gradient_loss
        loss.backward()

        if scheduler is not None:
            scheduler.step()
        optimizer.step()

        n_batch = img.shape[0]
        comm = {"mse_sum": recon_loss.item() * n_batch, "mse_n": n_batch}
        for part in dist.all_gather(comm):
            mse_sum += part["mse_sum"]
            mse_n += part["mse_n"]

        g_comm = {"g_sum": gradient_loss.item() * n_batch, "g_n": n_batch}
        for part in dist.all_gather(g_comm):
            g_sum += part["g_sum"]
            g_n += part["g_n"]

        if dist.is_primary():
            lr = optimizer.param_groups[0]["lr"]

            loader.set_description((
                f"epoch: {epoch + 1}; mse: {recon_loss.item():.5f}; "
                f"latent: {latent_loss.item():.3f}; gradient: {g_sum / g_n:.5f}; avg mse: {mse_sum / mse_n:.5f}; "
                f"lr: {lr:.5f}"))

            if i % 100 == 0:
                model.eval()

                sample = img[:sample_size]
                label_sample = label[:sample_size]
                # First two input channels, each kept as a 1-channel image.
                sample0 = sample[:, 0, :, :].unsqueeze(dim=1)
                sample1 = sample[:, 1, :, :].unsqueeze(dim=1)
                with torch.no_grad():
                    out, _ = model(sample)

                # ``value_range`` replaces the ``range`` keyword, which
                # torchvision deprecated and later removed.
                utils.save_image(
                    torch.cat([sample0, sample1, label_sample, out], 0),
                    f"sample/{str(epoch + 1).zfill(5)}_{str(i).zfill(5)}.png",
                    nrow=sample_size,
                    normalize=True,
                    value_range=(-1, 1),
                )

                model.train()
    def train_epoch(self, epoch):
        """Train ``self.model`` for one pass over ``self.dataloader``.

        The loss is the MSE between the model's first output and the input,
        plus 0.25 times the mean of its second output (latent loss). When
        ``self.args.fp16`` is set, the backward pass goes through
        ``amp.scale_loss``. The optimizer is stepped first, then the scheduler
        (if any). Per-batch MSE sums are all-gathered so the running average
        covers every process.

        The following happens on the primary process only:
        - every ``logging_steps`` global steps, losses and lr are printed to
          stderr;
        - every ``save_steps``, ``self.save_checkpoint()`` is called;
        - every ``eval_steps``, up to 25 inputs and their reconstructions are
          written to ``self.args.eval_path``.

        Args:
            epoch: Zero-based epoch index, used (as ``epoch + 1``) in sample
                file names.
        """
        if dist.is_primary():
            loader = tqdm(self.dataloader)
        else:
            loader = self.dataloader

        criterion = nn.MSELoss()

        latent_loss_weight = 0.25
        sample_size = 25

        mse_sum = 0
        mse_n = 0

        for i, img in enumerate(loader):
            self.model.zero_grad()
            img = img.to(self.device)

            outputs = self.model(img)
            # Only the first two outputs (reconstruction, latent loss) are used.
            out, latent_loss = outputs[:2]
            recon_loss = criterion(out, img)
            latent_loss = latent_loss.mean()
            loss = recon_loss + latent_loss_weight * latent_loss
            if self.args.fp16:
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    # ``backward`` — the former ``backword`` misspelling raised
                    # AttributeError on every fp16 step.
                    scaled_loss.backward()
            else:
                loss.backward()

            self.optimizer.step()
            if self.scheduler is not None:
                self.scheduler.step()

            n_batch = img.shape[0]
            comm = {"mse_sum": recon_loss.item() * n_batch, "mse_n": n_batch}
            for part in dist.all_gather(comm):
                mse_sum += part["mse_sum"]
                mse_n += part["mse_n"]

            self.global_step += 1
            is_primary = dist.is_primary()

            if is_primary and self.global_step % self.args.logging_steps == 0:
                print("global_step",
                      self.global_step,
                      "mse",
                      "{:.4g}".format(recon_loss.item()),
                      "latent",
                      "{:.4g}".format(latent_loss.item()),
                      "avg_mse",
                      "{:.4g}".format(mse_sum / mse_n),
                      "lr",
                      "{:.4g}".format(self.optimizer.param_groups[0]["lr"]),
                      file=sys.stderr,
                      flush=True)

            if is_primary and self.global_step % self.args.save_steps == 0:
                self.save_checkpoint()

            if is_primary and self.global_step % self.args.eval_steps == 0:
                self.model.eval()
                sample = img[:sample_size]
                with torch.no_grad():
                    out = self.model(sample)[0]
                # ``value_range`` replaces the ``range`` keyword, which
                # torchvision deprecated and later removed.
                utils.save_image(
                    torch.cat([sample, out], 0),
                    f"{self.args.eval_path}/{str(epoch + 1).zfill(5)}_{str(i).zfill(5)}.png",
                    nrow=sample_size,
                    normalize=True,
                    value_range=(-1, 1),
                )
                self.model.train()