Example #1
def train(epoch, loader, model, optimizer, scheduler, device):
    if dist.is_primary():
        loader = tqdm(loader)

    criterion = nn.MSELoss()

    latent_loss_weight = 0.25
    sample_size = 25

    mse_sum = 0
    mse_n = 0
    lr = optimizer.param_groups[0]["lr"]

    for i, (img, _, _) in enumerate(loader):
        model.zero_grad()

        img = img.to(device)

        out, latent_loss = model(img)
        recon_loss = criterion(out, img)
        latent_loss = latent_loss.mean()
        loss = recon_loss + latent_loss_weight * latent_loss
        loss.backward()

        if scheduler is not None:
            scheduler.step()
        optimizer.step()

        # Only the primary process wraps the loader in tqdm, so only it can
        # update the progress-bar postfix.
        if dist.is_primary():
            loader.set_postfix_str(
                f'Step: {i + 1}: MSE: {recon_loss.item():.5f}; Latent: {latent_loss.item():.3f}; Total: {loss.item():.5f}'
            )
def main(args):
    device = "cuda"

    args.distributed = dist.get_world_size() > 1

    normMean = [0.5]
    normStd = [0.5]

    normTransform = transforms.Normalize(normMean, normStd)
    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.ToTensor(),
        normTransform,
    ])

    txt_path = 'data/train.txt'
    images_path = '/data'
    labels_path = '/data'

    dataset = txtDataset(txt_path,
                         images_path,
                         labels_path,
                         transform=transform)

    sampler = dist.data_sampler(dataset,
                                shuffle=True,
                                distributed=args.distributed)
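    # NOTE: `batch_size` below is a module-level global; its value is not
    # shown in this snippet.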
    loader = DataLoader(dataset,
                        batch_size=batch_size // args.n_gpu,
                        sampler=sampler,
                        num_workers=16)

    model = VQVAE().to(device)

    if args.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = None
    if args.sched == "cycle":
        scheduler = CycleScheduler(
            optimizer,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

    for i in range(args.epoch):
        train(i, loader, model, optimizer, scheduler, device)

        if dist.is_primary():
            torch.save(model.state_dict(),
                       f"checkpoint/vqvae_{str(i + 1).zfill(3)}.pt")
Example #3
def main(args):
    device = "cuda"

    args.distributed = dist.get_world_size() > 1

    transforms = video_transforms.Compose([
        RandomSelectFrames(16),
        video_transforms.Resize(args.size),
        video_transforms.CenterCrop(args.size),
        volume_transforms.ClipToTensor(),
        tensor_transforms.Normalize(0.5, 0.5)
    ])

    # Use a context manager so the file handle is closed after loading.
    with open(
            '/home/shirakawa/movie/code/iVideoGAN/over16frame_list_training.txt',
            'rb') as f:
        train_file_list = pickle.load(f)
    print(len(train_file_list))

    dataset = MITDataset(train_file_list, transform=transforms)
    sampler = dist.data_sampler(dataset,
                                shuffle=True,
                                distributed=args.distributed)
    #loader = DataLoader(
    #    dataset, batch_size=128 // args.n_gpu, sampler=sampler, num_workers=2
    #)
    loader = DataLoader(dataset,
                        batch_size=32 // args.n_gpu,
                        sampler=sampler,
                        num_workers=2)

    model = VQVAE().to(device)

    if args.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = None
    if args.sched == "cycle":
        scheduler = CycleScheduler(
            optimizer,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

    for i in range(args.epoch):
        train(i, loader, model, optimizer, scheduler, device)

        if dist.is_primary():
            torch.save(model.state_dict(),
                       f"checkpoint_vid_v2/vqvae_{str(i + 1).zfill(3)}.pt")
Example #4
def main(args):
    device = "cuda"

    args.distributed = dist.get_world_size() > 1

    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.CenterCrop(args.size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    dataset = datasets.ImageFolder(args.path, transform=transform)
    sampler = dist.data_sampler(dataset,
                                shuffle=True,
                                distributed=args.distributed)
    loader = DataLoader(dataset,
                        batch_size=128 // args.n_gpu,
                        sampler=sampler,
                        num_workers=2)

    model = VQVAE().to(device)

    if args.load_path:
        load_state_dict = torch.load(args.load_path, map_location=device)
        model.load_state_dict(load_state_dict)
        print('successfully loaded model')

    if args.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = None
    if args.sched == "cycle":
        scheduler = CycleScheduler(
            optimizer,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

    for i in range(args.epoch):
        train(i, loader, model, optimizer, scheduler, device)

        if dist.is_primary():
            torch.save(model.state_dict(),
                       f"checkpoint/vqvae_{str(i + 1).zfill(3)}.pt")
Example #5
def evaluate(loader, model, out_path, sample_size):
    if dist.is_primary():
        loader = tqdm(loader)

    model.eval()

    # Grab one batch and move it to the model's device before the forward pass.
    img, _ = next(iter(loader))
    device = next(model.parameters()).device

    sample = img[:sample_size].to(device)

    with torch.no_grad():
        out, _ = model(sample)

    utils.save_image(
        torch.cat([sample, out], 0),
        out_path,
        nrow=sample_size,
        normalize=True,
        range=(-1, 1),
    )
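A minimal usage sketch for this evaluate() function; the output path and sample size here are assumptions, not from the original:

    # Hypothetical call, using a loader and model set up as in the main()
    # examples above.
    evaluate(loader, model, out_path="sample/eval.png", sample_size=8)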
Example #6
def train(epoch, loader, model, optimizer, scheduler, scaler, device):
    if dist.is_primary():
        loader = tqdm(loader)

    criterion = nn.MSELoss()

    latent_loss_weight = 0.25
    sample_size = 25

    mse_sum = 0
    mse_n = 0

    for i, (img, label) in enumerate(loader):
        model.zero_grad()

        img = img.to(device)

        with torch.cuda.amp.autocast(scaler.is_enabled()):
            out, latent_loss = model(img)
            recon_loss = criterion(out, img)
            latent_loss = latent_loss.mean()
            loss = recon_loss + latent_loss_weight * latent_loss
        scaler.scale(loss).backward()

        if scheduler is not None:
            scheduler.step()
        scaler.step(optimizer)
        scaler.update()

        part_mse_sum = recon_loss.item() * img.shape[0]
        part_mse_n = img.shape[0]
        comm = {"mse_sum": part_mse_sum, "mse_n": part_mse_n}
        comm = dist.all_gather(comm)

        for part in comm:
            mse_sum += part["mse_sum"]
            mse_n += part["mse_n"]

        if dist.is_primary():
            lr = optimizer.param_groups[0]["lr"]

            loader.set_description((
                f"epoch: {epoch + 1}; mse: {recon_loss.item():.5f}; "
                f"latent: {latent_loss.item():.3f}; avg mse: {mse_sum / mse_n:.5f}; "
                f"lr: {lr:.5f}"))

            if i % 100 == 0:
                model.eval()

                sample = img[:sample_size]

                with torch.no_grad(), torch.cuda.amp.autocast(
                        scaler.is_enabled()):
                    out, _ = model(sample)

                utils.save_image(
                    torch.cat([sample, out], 0),
                    f"sample/{str(epoch + 1).zfill(5)}_{str(i).zfill(5)}.png",
                    nrow=sample_size,
                    normalize=True,
                    range=(-1, 1),
                )

                model.train()
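This train() variant threads a torch.cuda.amp.GradScaler through the loop. A minimal sketch of how the scaler might be constructed and handed in; the use_amp flag is an assumption:

    import torch

    use_amp = True  # assumed flag; set False to run in full precision
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    # train(epoch, loader, model, optimizer, scheduler, scaler, device)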
Example #7
def train(epoch, loader, model, optimizer, scheduler, device):
    if dist.is_primary():
        loader = tqdm(loader)

    criterion = nn.MSELoss()

    latent_loss_weight = 0.25
    sample_size = 25

    mse_sum = 0
    mse_n = 0

    for i, (img, label) in enumerate(loader):
        model.zero_grad()

        img = img.to(device)

        out, latent_loss = model(img)
        recon_loss = criterion(out, img)
        latent_loss = latent_loss.mean()
        loss = recon_loss + latent_loss_weight * latent_loss
        loss.backward()

        wandb.log({'train loss': loss.item()})

        if scheduler is not None:
            scheduler.step()
        optimizer.step()

        part_mse_sum = recon_loss.item() * img.shape[0]
        part_mse_n = img.shape[0]
        comm = {"mse_sum": part_mse_sum, "mse_n": part_mse_n}
        comm = dist.all_gather(comm)

        for part in comm:
            mse_sum += part["mse_sum"]
            mse_n += part["mse_n"]

        if dist.is_primary():
            lr = optimizer.param_groups[0]["lr"]

            loader.set_description((
                f"epoch: {epoch + 1}; mse: {recon_loss.item():.5f}; "
                f"latent: {latent_loss.item():.3f}; avg mse: {mse_sum / mse_n:.5f}; "
                f"lr: {lr:.5f}"))

            if i % 100 == 0:
                model.eval()

                sample = img[:sample_size]

                with torch.no_grad():
                    out, _ = model(sample)

                # utils.save_image(
                #     torch.cat([sample, out], 0),
                #     f"sample/{str(epoch + 1).zfill(5)}_{str(i).zfill(5)}.png",
                #     nrow=sample_size,
                #     normalize=True,
                #     range=(-1, 1),
                # )

                example_images = [
                    wandb.Image(image, caption=f"{epoch}_{i}") for image in out
                ]
                wandb.log({"Examples": example_images})

                model.train()
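The wandb.log and wandb.Image calls above require a run to be initialized first. A minimal sketch; the project name is an assumption:

    import wandb

    # Must run once before train() is called; the project name is hypothetical.
    wandb.init(project="vqvae2")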
def train(epoch, loader, discriminator, generator, scheduler_D, scheduler_G,
          optimizer_D, optimizer_G, device):
    loader_d = tqdm(loader)
    if (epoch + 1) % n_critic == 0:
        loader_g = tqdm(loader)

    adversarial_loss = nn.BCEWithLogitsLoss()  # sigmoid
    pixelwise_loss = nn.L1Loss()
    gdloss = GDLoss()

    recon_loss_weight = 0.4
    latent_loss_weight = 0.2
    gradient_loss_weight = 0.4
    sample_size = batch_size

    mse_sum = 0
    mse_n = 0
    g_sum = 0
    g_n = 0

    requires_grad(generator, False)
    requires_grad(discriminator, True)

    # ---------------------
    #  Train Discriminator
    # ---------------------
    for i, (img, label, label_path, class_name) in enumerate(loader_d):
        discriminator.zero_grad()

        # torch.autograd.Variable is a no-op in modern PyTorch; plain tensors
        # work as the real/fake targets.
        valid = torch.ones(img.shape[0], 1)
        fake = torch.zeros(img.shape[0], 1)

        img = img.to(device)
        valid = valid.to(device)
        fake = fake.to(device)
        label = label.to(device)

        gdloss.conv_x = gdloss.conv_x.to(device)
        gdloss.conv_y = gdloss.conv_y.to(device)

        vqvae2_out, latent_loss = generator(img)

        real_loss = adversarial_loss(discriminator(label), valid)
        fake_loss = adversarial_loss(discriminator(vqvae2_out), fake)

        d_loss = 0.5 * (real_loss + fake_loss)

        d_loss.backward()

        if scheduler_D is not None:
            scheduler_D.step()
        optimizer_D.step()

        if dist.is_primary():
            lr = optimizer_D.param_groups[0]["lr"]

            loader_d.set_description((
                f"Discriminator epoch: {epoch + 1}; class loss: {d_loss.item():.5f};"
                f"lr: {lr:.5f}"))

    # ---------------------
    #  Train Generator
    # ---------------------
    if (epoch + 1) % n_critic == 0:
        requires_grad(generator, True)
        requires_grad(discriminator, False)
        for i, (img, label, label_path, class_name) in enumerate(loader_g):
            generator.zero_grad()

            valid = torch.ones(img.shape[0], 1)

            img = img.to(device)
            valid = valid.to(device)
            label = label.to(device)

            gdloss.conv_x = gdloss.conv_x.to(device)
            gdloss.conv_y = gdloss.conv_y.to(device)

            vqvae2_out, latent_loss = generator(img)

            recon_loss = pixelwise_loss(vqvae2_out, label)
            gradient_loss = gdloss(vqvae2_out, label)
            gradient_loss = gradient_loss.mean()
            latent_loss = latent_loss.mean()
            g_loss = (
                0.1 * adversarial_loss(discriminator(vqvae2_out), valid)
                + 0.9 * (recon_loss_weight * recon_loss
                         + latent_loss_weight * latent_loss
                         + gradient_loss_weight * gradient_loss))

            g_loss.backward()

            if scheduler_G is not None:
                scheduler_G.step()
            optimizer_G.step()

            part_mse_sum = recon_loss.item() * img.shape[0]
            part_mse_n = img.shape[0]
            comm = {"mse_sum": part_mse_sum, "mse_n": part_mse_n}
            comm = dist.all_gather(comm)

            for part in comm:
                mse_sum += part["mse_sum"]
                mse_n += part["mse_n"]

            part_g_sum = gradient_loss.item() * img.shape[0]
            part_g_n = img.shape[0]
            g_comm = {"g_sum": part_g_sum, "g_n": part_g_n}
            g_comm = dist.all_gather(g_comm)

            for part in g_comm:
                g_sum += part["g_sum"]
                g_n += part["g_n"]

            if dist.is_primary():
                lr = optimizer_G.param_groups[0]["lr"]

                loader_g.set_description((
                    f"Denerator epoch: {(epoch + 1) // n_critic + 1}; mse: {recon_loss.item():.5f}; "
                    f"latent: {latent_loss.item():.3f}; gradient: {g_sum / g_n:.5f}; avg mse: {mse_sum / mse_n:.5f}; "
                    f"lr: {lr:.5f}"))

            if i % 100 == 0:
                generator.eval()

                sample = img[:sample_size]
                label_sample = label[:sample_size]
                sample0 = sample[:, 0, :, :].unsqueeze(dim=1)
                sample1 = sample[:, 1, :, :].unsqueeze(dim=1)
                with torch.no_grad():
                    out, _ = generator(sample)

                utils.save_image(
                    torch.cat([sample0, sample1, label_sample, out], 0),
                    f"sample/{str(epoch + 1).zfill(5)}_{str(i).zfill(5)}.png",
                    nrow=sample_size,
                    normalize=True,
                    range=(-1, 1),
                )

                generator.train()
def main(args):
    device = "cuda"

    args.distributed = dist.get_world_size() > 1

    normMean = [0.5]
    normStd = [0.5]

    normTransform = transforms.Normalize(normMean, normStd)
    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.ToTensor(),
        normTransform,
    ])

    txt_path = './data/train.txt'
    images_path = './data'
    labels_path = './data'

    dataset = txtDataset(txt_path,
                         images_path,
                         labels_path,
                         transform=transform)

    sampler = dist.data_sampler(dataset,
                                shuffle=True,
                                distributed=args.distributed)
    loader = DataLoader(dataset,
                        batch_size=batch_size // args.n_gpu,
                        sampler=sampler,
                        num_workers=16)

    # Initialize generator and discriminator
    DpretrainedPath = './checkpoint/vqvae2GAN_040.pt'
    GpretrainedPath = './checkpoint/vqvae_040.pt'

    discriminator = Discriminator()
    generator = Generator()
    if os.path.exists(DpretrainedPath):
        print('Loading model weights...')
        discriminator.load_state_dict(
            torch.load(DpretrainedPath)['discriminator'])
        print('done')
    if os.path.exists(GpretrainedPath):
        print('Loading model weights...')
        generator.load_state_dict(torch.load(GpretrainedPath))
        print('done')

    discriminator = discriminator.to(device)
    generator = generator.to(device)

    if args.distributed:
        discriminator = nn.parallel.DistributedDataParallel(
            discriminator,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )
        generator = nn.parallel.DistributedDataParallel(
            generator,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer_D = optim.Adam(discriminator.parameters(), lr=args.lr)
    optimizer_G = optim.Adam(generator.parameters(), lr=args.lr)
    scheduler_D = None
    scheduler_G = None
    if args.sched == "cycle":
        scheduler_D = CycleScheduler(
            optimizer_D,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

        scheduler_G = CycleScheduler(
            optimizer_G,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

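    # Resume from epoch 41, matching the *_040 checkpoints loaded above.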
    for i in range(41, args.epoch):
        train(i, loader, discriminator, generator, scheduler_D, scheduler_G,
              optimizer_D, optimizer_G, device)

        if dist.is_primary():
            torch.save(
                {
                    'generator': generator.state_dict(),
                    'discriminator': discriminator.state_dict(),
                    'g_optimizer': optimizer_G.state_dict(),
                    'd_optimizer': optimizer_D.state_dict(),
                },
                f'checkpoint/vqvae2GAN_{str(i + 1).zfill(3)}.pt',
            )
            if (i + 1) % n_critic == 0:
                torch.save(generator.state_dict(),
                           f"checkpoint/vqvae_{str(i + 1).zfill(3)}.pt")
def main(args):
    device = "cuda"

    args.distributed = dist.get_world_size() > 1

    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.CenterCrop(args.size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    dataset = OffsetDataset(args.path, transform=transform, offset=args.offset)
    sampler = dist.data_sampler(dataset,
                                shuffle=True,
                                distributed=args.distributed)
    loader = DataLoader(dataset,
                        batch_size=args.bsize // args.n_gpu,
                        sampler=sampler,
                        num_workers=2)

    # Load pre-trained VQVAE
    vqvae = VQVAE().to(device)
    try:
        vqvae.load_state_dict(torch.load(args.ckpt))
    except RuntimeError:
        print(
            "Seems the checkpoint was trained with data parallel, try loading it that way"
        )
        # Strip the DataParallel 'module.' prefix from every key and retry.
        weights = {
            key.replace('module.', ''): value
            for key, value in torch.load(args.ckpt).items()
        }
        vqvae.load_state_dict(weights)

    # Init offset encoder
    model = OffsetNetwork(vqvae).to(device)

    if args.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
            find_unused_parameters=True)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = None
    if args.sched == "cycle":
        scheduler = CycleScheduler(
            optimizer,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

    for i in range(args.epoch):
        train(i, loader, model, optimizer, scheduler, device)

        if dist.is_primary():
            torch.save(model.state_dict(),
                       f"checkpoint/offset_enc_{str(i + 1).zfill(3)}.pt")
def train(epoch, loader, model, optimizer, scheduler, device):
    if dist.is_primary():
        loader = tqdm(loader)

    criterion = nn.MSELoss()

    latent_loss_weight = 0.25
    sample_size = 25

    mse_sum = 0
    mse_n = 0

    for i, (frames, next_frames) in enumerate(loader):
        model.zero_grad()

        frames = frames.to(device)
        next_frames = next_frames.to(device)

        out, latent_loss = model(frames, next_frames)
        recon_loss = criterion(out, next_frames)
        latent_loss = latent_loss.mean()
        loss = recon_loss + latent_loss_weight * latent_loss
        loss.backward()

        if scheduler is not None:
            scheduler.step()
        optimizer.step()

        part_mse_sum = recon_loss.item() * frames.shape[0]
        part_mse_n = frames.shape[0]
        comm = {"mse_sum": part_mse_sum, "mse_n": part_mse_n}
        comm = dist.all_gather(comm)

        for part in comm:
            mse_sum += part["mse_sum"]
            mse_n += part["mse_n"]

        if dist.is_primary():
            lr = optimizer.param_groups[0]["lr"]

            loader.set_description((
                f"epoch: {epoch + 1}; mse: {recon_loss.item():.5f}; "
                f"latent: {latent_loss.item():.3f}; avg mse: {mse_sum / mse_n:.5f}; "
                f"lr: {lr:.5f}"))

            if i % 100 == 0:
                model.eval()

                sample_frames = frames[:sample_size]
                sample_next = next_frames[:sample_size]

                with torch.no_grad():
                    out, _ = model(sample_frames, sample_next)

                utils.save_image(
                    torch.cat([sample_frames, out], 0),
                    f"offset_sample/{str(epoch + 1).zfill(5)}_{str(i).zfill(5)}.png",
                    nrow=sample_size,
                    normalize=True,
                    range=(-1, 1),
                )

                model.train()
def train(epoch, loader, model, optimizer, scheduler, device):
    # if dist.is_primary():
    #     loader = tqdm(loader)
    loader = tqdm(loader)

    # criterion = nn.MSELoss()
    criterion = nn.L1Loss()
    gdloss = GDLoss()

    recon_loss_weight = 0.4
    latent_loss_weight = 0.2
    gradient_loss_weight = 0.4
    sample_size = batch_size

    mse_sum = 0
    mse_n = 0
    g_sum = 0
    g_n = 0

    for i, (img, label, label_path, class_name) in enumerate(loader):
        model.zero_grad()

        img = img.to(device)
        label = label.to(device)
        gdloss.conv_x = gdloss.conv_x.to(device)
        gdloss.conv_y = gdloss.conv_y.to(device)

        out, latent_loss = model(img)
        recon_loss = criterion(out, label)
        gradient_loss = gdloss(out, label)
        gradient_loss = gradient_loss.mean()
        latent_loss = latent_loss.mean()
        loss = (recon_loss_weight * recon_loss
                + latent_loss_weight * latent_loss
                + gradient_loss_weight * gradient_loss)
        loss.backward()

        if scheduler is not None:
            scheduler.step()
        optimizer.step()

        part_mse_sum = recon_loss.item() * img.shape[0]
        part_mse_n = img.shape[0]
        comm = {"mse_sum": part_mse_sum, "mse_n": part_mse_n}
        comm = dist.all_gather(comm)

        for part in comm:
            mse_sum += part["mse_sum"]
            mse_n += part["mse_n"]

        part_g_sum = gradient_loss.item() * img.shape[0]
        part_g_n = img.shape[0]
        g_comm = {"g_sum": part_g_sum, "g_n": part_g_n}
        g_comm = dist.all_gather(g_comm)

        for part in g_comm:
            g_sum += part["g_sum"]
            g_n += part["g_n"]

        if dist.is_primary():
            lr = optimizer.param_groups[0]["lr"]

            loader.set_description((
                f"epoch: {epoch + 1}; mse: {recon_loss.item():.5f}; "
                f"latent: {latent_loss.item():.3f}; gradient: {g_sum / g_n:.5f}; avg mse: {mse_sum / mse_n:.5f}; "
                f"lr: {lr:.5f}"))

            if i % 100 == 0:
                model.eval()

                sample = img[:sample_size]
                label_sample = label[:sample_size]
                sample0 = sample[:, 0, :, :].unsqueeze(dim=1)
                sample1 = sample[:, 1, :, :].unsqueeze(dim=1)
                with torch.no_grad():
                    out, _ = model(sample)

                utils.save_image(
                    torch.cat([sample0, sample1, label_sample, out], 0),
                    f"sample/{str(epoch + 1).zfill(5)}_{str(i).zfill(5)}.png",
                    nrow=sample_size,
                    normalize=True,
                    range=(-1, 1),
                )

                model.train()
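GDLoss is used here (and in the GAN example above) but never defined in these snippets. A minimal sketch of a gradient-difference loss that is consistent with the conv_x/conv_y attributes the training loops move to the device; the Sobel kernels and L1 reduction are assumptions, not the original implementation:

    import torch
    from torch import nn
    import torch.nn.functional as F

    class GDLoss(nn.Module):
        """Sketch of a gradient-difference loss (assumed implementation)."""

        def __init__(self):
            super().__init__()
            # Sobel-style kernels kept as plain tensor attributes, which is
            # why the training loops move conv_x / conv_y to the device by hand.
            kx = torch.tensor([[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]])
            self.conv_x = kx.view(1, 1, 3, 3)
            self.conv_y = kx.t().contiguous().view(1, 1, 3, 3)

        def forward(self, out, target):
            # Filter each channel independently, then compare gradients L1-wise.
            c = out.shape[1]
            loss = 0.0
            for w in (self.conv_x, self.conv_y):
                w = w.repeat(c, 1, 1, 1)  # depthwise weight of shape (c, 1, 3, 3)
                g_out = F.conv2d(out, w, padding=1, groups=c)
                g_tgt = F.conv2d(target, w, padding=1, groups=c)
                loss = loss + (g_out - g_tgt).abs().mean()
            return loss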
Example #13
    def train_epoch(self, epoch):
        if dist.is_primary():
            loader = tqdm(self.dataloader)
        else:
            loader = self.dataloader

        criterion = nn.MSELoss()

        latent_loss_weight = 0.25
        sample_size = 25

        mse_sum = 0
        mse_n = 0

        for i, img in enumerate(loader):
            self.model.zero_grad()
            img = img.to(self.device)

            outputs = self.model(img)
            out, latent_loss = outputs[:2]
            recon_loss = criterion(out, img)
            latent_loss = latent_loss.mean()
            loss = recon_loss + latent_loss_weight * latent_loss
            if self.args.fp16:
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            self.optimizer.step()
            if self.scheduler is not None:
                self.scheduler.step()

            part_mse_sum = recon_loss.item() * img.shape[0]
            part_mse_n = img.shape[0]
            comm = {"mse_sum": part_mse_sum, "mse_n": part_mse_n}
            comm = dist.all_gather(comm)

            for part in comm:
                mse_sum += part["mse_sum"]
                mse_n += part["mse_n"]

            self.global_step += 1

            if (dist.is_primary()
                    and self.global_step % self.args.logging_steps == 0):
                print("global_step",
                      self.global_step,
                      "mse",
                      "{:.4g}".format(recon_loss.item()),
                      "latent",
                      "{:.4g}".format(latent_loss.item()),
                      "avg_mse",
                      "{:.4g}".format(mse_sum / mse_n),
                      "lr",
                      "{:.4g}".format(self.optimizer.param_groups[0]["lr"]),
                      file=sys.stderr,
                      flush=True)

            if (dist.is_primary()
                    and self.global_step % self.args.save_steps == 0):
                self.save_checkpoint()

            if (dist.is_primary()
                    and self.global_step % self.args.eval_steps == 0):
                self.model.eval()
                sample = img[:sample_size]
                with torch.no_grad():
                    out = self.model(sample)[0]
                utils.save_image(
                    torch.cat([sample, out], 0),
                    f"{self.args.eval_path}/{str(epoch + 1).zfill(5)}_{str(i).zfill(5)}.png",
                    nrow=sample_size,
                    normalize=True,
                    range=(-1, 1),
                )
                self.model.train()
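The amp.scale_loss context above is NVIDIA Apex's mixed-precision API, which expects the model and optimizer to be wrapped once before the training loop. A minimal sketch; the opt_level choice is an assumption:

    from apex import amp  # NVIDIA Apex, assumed installed when args.fp16 is set

    # Wrap once, before training; "O1" (patched mixed precision) is a common choice.
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")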
Example #14
    def run(self, args):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        transform = [transforms.ToTensor()]

        if args.normalize:
            transform.append(
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]))

        transform = transforms.Compose(transform)

        dataset = datasets.ImageFolder(args.path, transform=transform)
        sampler = dist_fn.data_sampler(dataset,
                                       shuffle=True,
                                       distributed=args.distributed)
        loader = DataLoader(dataset,
                            batch_size=args.batch_size // args.n_gpu,
                            sampler=sampler,
                            num_workers=args.num_workers)

        self = self.to(device)

        if args.distributed:
            self = nn.parallel.DistributedDataParallel(
                self,
                device_ids=[dist_fn.get_local_rank()],
                output_device=dist_fn.get_local_rank())

        optimizer = args.optimizer(self.parameters(), lr=args.lr)
        scheduler = None
        if args.sched == 'cycle':
            scheduler = CycleScheduler(
                optimizer,
                args.lr,
                n_iter=len(loader) * args.epoch,
                momentum=None,
                warmup_proportion=0.05,
            )

        start = str(time())
        run_path = os.path.join('runs', start)
        sample_path = os.path.join(run_path, 'sample')
        checkpoint_path = os.path.join(run_path, 'checkpoint')
        os.makedirs(run_path)  # makedirs also covers a missing 'runs' parent
        os.mkdir(sample_path)
        os.mkdir(checkpoint_path)

        with Progress() as progress:
            train = progress.add_task(f'epoch 1/{args.epoch}',
                                      total=args.epoch)
            steps = progress.add_task('',
                                      total=len(dataset) // args.batch_size)

            for epoch in range(args.epoch):
                progress.update(steps, completed=0, refresh=True)

                for recon_loss, latent_loss, avg_mse, lr in self.train_epoch(
                        epoch, loader, optimizer, scheduler, device,
                        sample_path):
                    progress.update(
                        steps,
                        description=
                        f'mse: {recon_loss:.5f}; latent: {latent_loss:.5f}; avg mse: {avg_mse:.5f}; lr: {lr:.5f}'
                    )
                    progress.advance(steps)

                if dist_fn.is_primary():
                    torch.save(
                        self.state_dict(),
                        os.path.join(checkpoint_path,
                                     f'vqvae_{str(epoch + 1).zfill(3)}.pt'))

                progress.update(train,
                                description=f'epoch {epoch + 1}/{args.epoch}')
                progress.advance(train)
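Note that this run() method expects args.optimizer to be an optimizer class rather than a string (see the args.optimizer(self.parameters(), lr=args.lr) call). A minimal sketch of an args namespace that would satisfy it; every name and value below is an assumption:

    import argparse
    from torch import optim

    args = argparse.Namespace(
        path='data/images',    # assumed ImageFolder root
        normalize=True,
        distributed=False,
        batch_size=128,
        n_gpu=1,
        num_workers=2,
        lr=3e-4,
        sched='cycle',
        epoch=10,
        optimizer=optim.Adam,  # passed as a class, as the call above requires
    )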