Example #1
def main(args):
    args.distributed = dist.get_world_size() > 1

    transform = transforms.Compose(
        [
            transforms.Resize(args.size),
            transforms.CenterCrop(args.size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )

    dataset = datasets.ImageFolder(args.path, transform=transform)
    sampler = dist.data_sampler(dataset, shuffle=True, distributed=args.distributed)
    loader = DataLoader(
        dataset, batch_size=128 // args.n_gpu, sampler=sampler, num_workers=2
    )

    model = load_model(args.checkpoint).to(DEVICE)

    if args.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )
    evaluate(loader, model, args.out_path, args.sample_size)
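
The examples on this page all rely on a small `dist` helper module (data_sampler, get_world_size, get_local_rank, get_rank, is_primary) that the snippets never define. Below is a minimal sketch reconstructed purely from how those functions are called above, built on torch.distributed and the standard PyTorch samplers; the LOCAL_RANK convention and the fallback defaults are assumptions, not code from the examples' repositories.

# Hypothetical sketch of the `dist` helper assumed by the examples; reconstructed
# from the call sites, not taken from their source.
import os

from torch import distributed as torch_dist
from torch.utils.data import DistributedSampler, RandomSampler, SequentialSampler


def get_world_size():
    # Treat an uninitialized process group as a single-process run.
    if not torch_dist.is_available() or not torch_dist.is_initialized():
        return 1
    return torch_dist.get_world_size()


def get_rank():
    if not torch_dist.is_available() or not torch_dist.is_initialized():
        return 0
    return torch_dist.get_rank()


def is_primary():
    return get_rank() == 0


def get_local_rank():
    # Assumes one process per GPU with LOCAL_RANK exported by the launcher (e.g. torchrun).
    return int(os.environ.get("LOCAL_RANK", 0))


def data_sampler(dataset, shuffle, distributed):
    # Shard the dataset across processes when running distributed; otherwise
    # fall back to an ordinary shuffled or sequential sampler.
    if distributed:
        return DistributedSampler(dataset, shuffle=shuffle)
    if shuffle:
        return RandomSampler(dataset)
    return SequentialSampler(dataset)
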
Example #2
def main(args):
    device = "cuda"

    args.distributed = dist.get_world_size() > 1

    normMean = [0.5]
    normStd = [0.5]

    normTransform = transforms.Normalize(normMean, normStd)
    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.ToTensor(),
        normTransform,
    ])

    txt_path = 'data/train.txt'
    images_path = '/data'
    labels_path = '/data'

    dataset = txtDataset(txt_path,
                         images_path,
                         labels_path,
                         transform=transform)

    sampler = dist.data_sampler(dataset,
                                shuffle=True,
                                distributed=args.distributed)
    batch_size = 128  # assumed default; `batch_size` is undefined in this snippet
    loader = DataLoader(dataset,
                        batch_size=batch_size // args.n_gpu,
                        sampler=sampler,
                        num_workers=16)

    model = VQVAE().to(device)

    if args.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = None
    if args.sched == "cycle":
        scheduler = CycleScheduler(
            optimizer,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

    for i in range(args.epoch):
        train(i, loader, model, optimizer, scheduler, device)

        if dist.is_primary():
            torch.save(model.state_dict(),
                       f"checkpoint/vqvae_{str(i + 1).zfill(3)}.pt")
Example #3
def main(args):
    device = "cuda"

    args.distributed = dist.get_world_size() > 1

    transforms = video_transforms.Compose([
        RandomSelectFrames(16),
        video_transforms.Resize(args.size),
        video_transforms.CenterCrop(args.size),
        volume_transforms.ClipToTensor(),
        tensor_transforms.Normalize(0.5, 0.5)
    ])

    with open(
            '/home/shirakawa/movie/code/iVideoGAN/over16frame_list_training.txt',
            'rb') as f:
        train_file_list = pickle.load(f)
    print(len(train_file_list))

    dataset = MITDataset(train_file_list, transform=transforms)
    sampler = dist.data_sampler(dataset,
                                shuffle=True,
                                distributed=args.distributed)
    #loader = DataLoader(
    #    dataset, batch_size=128 // args.n_gpu, sampler=sampler, num_workers=2
    #)
    loader = DataLoader(dataset,
                        batch_size=32 // args.n_gpu,
                        sampler=sampler,
                        num_workers=2)

    model = VQVAE().to(device)

    if args.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = None
    if args.sched == "cycle":
        scheduler = CycleScheduler(
            optimizer,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

    for i in range(args.epoch):
        train(i, loader, model, optimizer, scheduler, device)

        if dist.is_primary():
            torch.save(model.state_dict(),
                       f"checkpoint_vid_v2/vqvae_{str(i + 1).zfill(3)}.pt")
Example #4
def main(args):
    device = "cuda"

    args.distributed = dist.get_world_size() > 1

    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.CenterCrop(args.size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    dataset = datasets.ImageFolder(args.path, transform=transform)
    sampler = dist.data_sampler(dataset,
                                shuffle=True,
                                distributed=args.distributed)
    loader = DataLoader(dataset,
                        batch_size=128 // args.n_gpu,
                        sampler=sampler,
                        num_workers=2)

    model = VQVAE().to(device)

    if args.load_path:
        load_state_dict = torch.load(args.load_path, map_location=device)
        model.load_state_dict(load_state_dict)
        print('successfully loaded model')

    if args.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = None
    if args.sched == "cycle":
        scheduler = CycleScheduler(
            optimizer,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

    for i in range(args.epoch):
        train(i, loader, model, optimizer, scheduler, device)

        if dist.is_primary():
            torch.save(model.state_dict(),
                       f"checkpoint/vqvae_{str(i + 1).zfill(3)}.pt")
Example #5
def train(args, dataset, gen, dis, g_ema, device):
    if args.distributed:
        g_module = gen.module
        d_module = dis.module

    else:
        g_module = gen
        d_module = dis

    vgg = VGGFeature("vgg16", [4, 9, 16, 23, 30],
                     use_fc=True).eval().to(device)
    requires_grad(vgg, False)

    g_optim = optim.Adam(gen.parameters(), lr=1e-4, betas=(0, 0.999))
    d_optim = optim.Adam(dis.parameters(), lr=1e-4, betas=(0, 0.999))

    loader = data.DataLoader(
        dataset,
        batch_size=args.batch,
        num_workers=4,
        sampler=dist.data_sampler(dataset,
                                  shuffle=True,
                                  distributed=args.distributed),
        drop_last=True,
    )

    loader_iter = sample_data(loader)

    pbar = range(args.start_iter, args.iter)

    if dist.get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True)

    eps = 1e-8

    for i in pbar:
        real, class_id = next(loader_iter)

        real = real.to(device)
        class_id = class_id.to(device)

        masks = make_mask(real.shape[0], device, args.crop_prob)
        features, fcs = vgg(real)
        features = features + fcs[1:]

        requires_grad(dis, True)
        requires_grad(gen, False)

        real_pred = dis(real, class_id)

        z = torch.randn(args.batch, args.dim_z, device=device)

        fake = gen(z, class_id, features, masks)

        fake_pred = dis(fake, class_id)

        d_loss = d_ls_loss(real_pred, fake_pred)

        d_optim.zero_grad()
        d_loss.backward()
        d_optim.step()

        z1 = torch.randn(args.batch, args.dim_z, device=device)
        z2 = torch.randn(args.batch, args.dim_z, device=device)

        requires_grad(gen, True)
        requires_grad(dis, False)

        masks = make_mask(real.shape[0], device, args.crop_prob)

        if args.distributed:
            gen.broadcast_buffers = True

        fake1 = gen(z1, class_id, features, masks)

        if args.distributed:
            gen.broadcast_buffers = False

        fake2 = gen(z2, class_id, features, masks)

        fake_pred = dis(fake1, class_id)

        a_loss = g_ls_loss(None, fake_pred)

        features_fake, fcs_fake = vgg(fake1)
        features_fake = features_fake + fcs_fake[1:]

        r_loss = recon_loss(features_fake, features, masks)
        div_loss = diversity_loss(z1, z2, fake1, fake2, eps)

        g_loss = a_loss + args.rec_weight * r_loss + args.div_weight * div_loss

        g_optim.zero_grad()
        g_loss.backward()
        g_optim.step()

        accumulate(g_ema, g_module)

        if dist.get_rank() == 0:
            pbar.set_description(
                f"d: {d_loss.item():.4f}; g: {a_loss.item():.4f}; rec: {r_loss.item():.4f}; div: {div_loss.item():.4f}"
            )

            if i % 100 == 0:
                utils.save_image(
                    fake1,
                    f"sample/{str(i).zfill(6)}.png",
                    nrow=int(args.batch**0.5),
                    normalize=True,
                    range=(-1, 1),
                )

            if i % 10000 == 0:
                torch.save(
                    {
                        "args": args,
                        "g_ema": g_ema.state_dict(),
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                    },
                    f"checkpoint/{str(i).zfill(6)}.pt",
                )
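
The training loop above leans on three helpers it never defines: sample_data, requires_grad, and accumulate. Minimal sketches follow, written from the common GAN-training idiom these names suggest; the EMA decay default is an assumption.

def sample_data(loader):
    # Turn a finite DataLoader into an endless stream of batches.
    while True:
        for batch in loader:
            yield batch


def requires_grad(model, flag=True):
    # Freeze or unfreeze every parameter of a module.
    for p in model.parameters():
        p.requires_grad = flag


def accumulate(model_ema, model, decay=0.999):
    # Exponential moving average of weights, used to maintain g_ema above.
    ema_params = dict(model_ema.named_parameters())
    params = dict(model.named_parameters())
    for name in ema_params:
        ema_params[name].data.mul_(decay).add_(params[name].data, alpha=1 - decay)
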
Example #6
def main(args):
    device = "cuda"

    args.distributed = dist.get_world_size() > 1

    normMean = [0.5]
    normStd = [0.5]

    normTransform = transforms.Normalize(normMean, normStd)
    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.ToTensor(),
        normTransform,
    ])

    txt_path = './data/train.txt'
    images_path = './data'
    labels_path = './data'

    dataset = txtDataset(txt_path,
                         images_path,
                         labels_path,
                         transform=transform)

    sampler = dist.data_sampler(dataset,
                                shuffle=True,
                                distributed=args.distributed)
    batch_size = 128  # assumed default; `batch_size` is undefined in this snippet
    loader = DataLoader(dataset,
                        batch_size=batch_size // args.n_gpu,
                        sampler=sampler,
                        num_workers=16)

    # Initialize generator and discriminator
    DpretrainedPath = './checkpoint/vqvae2GAN_040.pt'
    GpretrainedPath = './checkpoint/vqvae_040.pt'

    discriminator = Discriminator()
    generator = Generator()
    if os.path.exists(DpretrainedPath):
        print('Loading model weights...')
        discriminator.load_state_dict(
            torch.load(DpretrainedPath)['discriminator'])
        print('done')
    if os.path.exists(GpretrainedPath):
        print('Loading model weights...')
        generator.load_state_dict(torch.load(GpretrainedPath))
        print('done')

    discriminator = discriminator.to(device)
    generator = generator.to(device)

    if args.distributed:
        discriminator = nn.parallel.DistributedDataParallel(
            discriminator,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    if args.distributed:
        generator = nn.parallel.DistributedDataParallel(
            generator,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer_D = optim.Adam(discriminator.parameters(), lr=args.lr)
    optimizer_G = optim.Adam(generator.parameters(), lr=args.lr)
    scheduler_D = None
    scheduler_G = None
    if args.sched == "cycle":
        scheduler_D = CycleScheduler(
            optimizer_D,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

        scheduler_G = CycleScheduler(
            optimizer_G,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

    for i in range(41, args.epoch):
        train(i, loader, discriminator, generator, scheduler_D, scheduler_G,
              optimizer_D, optimizer_G, device)

        if dist.is_primary():
            torch.save(
                {
                    'generator': generator.state_dict(),
                    'discriminator': discriminator.state_dict(),
                    'g_optimizer': optimizer_G.state_dict(),
                    'd_optimizer': optimizer_D.state_dict(),
                },
                f'checkpoint/vqvae2GAN_{str(i + 1).zfill(3)}.pt',
            )
            n_critic = 5  # assumed interval; `n_critic` is undefined in this snippet
            if (i + 1) % n_critic == 0:
                torch.save(generator.state_dict(),
                           f"checkpoint/vqvae_{str(i + 1).zfill(3)}.pt")
Example #7
def train(cfg, logger):
    # Create save path
    prefix = cfg.DATA.NAME + "-" + cfg.DATA.SOURCE + '2' + cfg.DATA.TARGET
    save_path = os.path.join("results", prefix)
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    suffix = "-".join([
        item for item in [
            "ls%d" % (cfg.LANGEVIN.STEP),
            "llr%.2f" % (cfg.LANGEVIN.LR),
            "lr%.4f" % (cfg.EBM.LR),
            "h%d" % (cfg.EBM.HIDDEN),
            "layer%d" % (cfg.EBM.LAYER),
            "opt%s" % (cfg.EBM.OPT),
        ] if item is not None
    ])
    run_dir = _create_run_dir_local(save_path, suffix)
    _copy_dir(['translation'], run_dir)
    sys.stdout = Logger(os.path.join(run_dir, 'log.txt'))

    ae = load_ae(cfg, logger)

    device = 'cuda'
    transform = transforms.Compose([
        transforms.RandomResizedCrop(2**cfg.DATASET.MAX_RESOLUTION_LEVEL,
                                     scale=[0.8, 1.0],
                                     ratio=[0.9, 1.1]),
        transforms.RandomHorizontalFlip(0.5),
        # transforms.Resize(256),
        # transforms.CenterCrop(256),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    data_root = os.path.join(cfg.DATA.ROOT, cfg.DATA.NAME)
    print(data_root)
    source_dataset = ImageFolder(os.path.join(data_root,
                                              'train/' + cfg.DATA.SOURCE),
                                 transform=transform)
    source_sampler = dist.data_sampler(source_dataset,
                                       shuffle=True,
                                       distributed=False)
    source_loader = DataLoader(source_dataset,
                               batch_size=cfg.DATA.BATCH,
                               sampler=source_sampler,
                               num_workers=1,
                               drop_last=True)
    target_dataset = ImageFolder(os.path.join(data_root,
                                              'train/' + cfg.DATA.TARGET),
                                 transform=transform)
    target_sampler = dist.data_sampler(target_dataset,
                                       shuffle=True,
                                       distributed=False)
    target_loader = DataLoader(target_dataset,
                               batch_size=cfg.DATA.BATCH,
                               sampler=target_sampler,
                               num_workers=1,
                               drop_last=True)
    source_iter = iter(source_loader)
    target_iter = iter(target_loader)

    latent_ebm = LatentEBM(latent_dim=512,
                           n_layer=cfg.EBM.LAYER,
                           n_hidden=cfg.EBM.HIDDEN).cuda()

    latent_ema = LatentEBM(latent_dim=512,
                           n_layer=cfg.EBM.LAYER,
                           n_hidden=cfg.EBM.HIDDEN).cuda()
    ema(latent_ema, latent_ebm, decay=0.)

    latent_optimizer = optim.SGD(latent_ebm.parameters(), lr=cfg.EBM.LR)
    if cfg.EBM.OPT == 'adam':
        latent_optimizer = optim.Adam(latent_ebm.parameters(), lr=cfg.EBM.LR)

    layer_count = cfg.MODEL.LAYER_COUNT
    used_sample = 0
    iterations = -1
    nrow = min(cfg.DATA.BATCH, 2)
    batch_size = cfg.DATA.BATCH

    # generate_recon(cfg=cfg, ae=ae, ebm=latent_ema, run_dir=run_dir, iteration=iterations, device=device)

    ebm_param = sum(p.numel() for p in latent_ebm.parameters())
    ae_param = sum(p.numel() for p in ae.parameters())
    print(ebm_param, ae_param)
    while used_sample < 10000000:
        iterations += 1
        latent_ebm.zero_grad()
        latent_optimizer.zero_grad()

        try:
            source_img, target_img = next(source_iter).to(device), next(
                target_iter).to(device)
        except (OSError, StopIteration):
            source_iter = iter(source_loader)
            target_iter = iter(target_loader)
            source_img, target_img = next(source_iter).to(device), next(
                target_iter).to(device)

        source_latent, target_latent = encode(ae, source_img, cfg), encode(
            ae, target_img, cfg)
        source_latent = source_latent.squeeze()
        target_latent = target_latent.squeeze()

        requires_grad(latent_ebm, False)
        source_latent_q = langvin_sampler(
            latent_ebm,
            source_latent.clone().detach(),
            langevin_steps=cfg.LANGEVIN.STEP,
            lr=cfg.LANGEVIN.LR,
        )

        requires_grad(latent_ebm, True)
        source_energy = latent_ebm(source_latent_q)
        target_energy = latent_ebm(target_latent)
        loss = -(target_energy - source_energy).mean()

        if abs(loss.item()) > 10000:
            break
        loss.backward()
        latent_optimizer.step()

        ema(latent_ema, latent_ebm, decay=0.999)

        used_sample += batch_size
        #
        if iterations % 1000 == 0:
            test_image_folder(cfg=cfg,
                              ae=ae,
                              ebm=latent_ema,
                              run_dir=run_dir,
                              iteration=iterations,
                              device=device)
            # test_representatives(cfg=cfg, ae=ae, ebm=latent_ema, run_dir=run_dir, iteration=iterations, device=device)
            torch.save(latent_ebm.state_dict(),
                       f"{run_dir}/ebm_{str(iterations).zfill(6)}.pt")

        if iterations % 100 == 0:
            print(f'Iter: {iterations:06}, Loss: {loss:6.3f}')

            latents = langvin_sampler(latent_ema,
                                      source_latent[:nrow].clone().detach(),
                                      langevin_steps=cfg.LANGEVIN.STEP,
                                      lr=cfg.LANGEVIN.LR)

            with torch.no_grad():
                latents = torch.cat((source_latent[:nrow], latents))
                latents = latents.unsqueeze(1).repeat(1,
                                                      ae.mapping_fl.num_layers,
                                                      1)

                out = decode(ae, latents, cfg)

                out = torch.cat((source_img[:nrow], out), dim=0)
                utils.save_image(
                    out,
                    f"{run_dir}/{str(iterations).zfill(6)}.png",
                    nrow=nrow,
                    normalize=True,
                    padding=0,
                    range=(-1, 1),
                )
Example #8
def main(args):
    device = "cuda"

    args.distributed = dist.get_world_size() > 1

    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.CenterCrop(args.size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    dataset = OffsetDataset(args.path, transform=transform, offset=args.offset)
    sampler = dist.data_sampler(dataset,
                                shuffle=True,
                                distributed=args.distributed)
    loader = DataLoader(dataset,
                        batch_size=args.bsize // args.n_gpu,
                        sampler=sampler,
                        num_workers=2)

    # Load pre-trained VQVAE
    vqvae = VQVAE().to(device)
    try:
        vqvae.load_state_dict(torch.load(args.ckpt))
    except RuntimeError:
        # State dict keys did not match: the checkpoint was likely saved from a
        # DataParallel/DDP-wrapped model, so strip the 'module.' prefix and retry.
        print(
            "Seems the checkpoint was trained with data parallel, try loading it that way"
        )
        weights = torch.load(args.ckpt)
        renamed_weights = {}
        for key, value in weights.items():
            renamed_weights[key.replace('module.', '')] = value
        vqvae.load_state_dict(renamed_weights)

    # Init offset encoder
    model = OffsetNetwork(vqvae).to(device)

    if args.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
            find_unused_parameters=True)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = None
    if args.sched == "cycle":
        scheduler = CycleScheduler(
            optimizer,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

    for i in range(args.epoch):
        train(i, loader, model, optimizer, scheduler, device)

        if dist.is_primary():
            torch.save(model.state_dict(),
                       f"checkpoint/offset_enc_{str(i + 1).zfill(3)}.pt")
Example #9
    def run(self, args):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        transform = [transforms.ToTensor()]

        if args.normalize:
            transform.append(
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]))

        transform = transforms.Compose(transform)

        dataset = datasets.ImageFolder(args.path, transform=transform)
        sampler = dist_fn.data_sampler(dataset,
                                       shuffle=True,
                                       distributed=args.distributed)
        loader = DataLoader(dataset,
                            batch_size=args.batch_size // args.n_gpu,
                            sampler=sampler,
                            num_workers=args.num_workers)

        self = self.to(device)

        if args.distributed:
            self = nn.parallel.DistributedDataParallel(
                self,
                device_ids=[dist_fn.get_local_rank()],
                output_device=dist_fn.get_local_rank())

        optimizer = args.optimizer(self.parameters(), lr=args.lr)
        scheduler = None
        if args.sched == 'cycle':
            scheduler = CycleScheduler(
                optimizer,
                args.lr,
                n_iter=len(loader) * args.epoch,
                momentum=None,
                warmup_proportion=0.05,
            )

        start = str(time())
        run_path = os.path.join('runs', start)
        sample_path = os.path.join(run_path, 'sample')
        checkpoint_path = os.path.join(run_path, 'checkpoint')
        # makedirs also creates the parent 'runs' directory when it is missing.
        os.makedirs(sample_path)
        os.makedirs(checkpoint_path)

        with Progress() as progress:
            train = progress.add_task(f'epoch 1/{args.epoch}',
                                      total=args.epoch,
                                      columns='epochs')
            steps = progress.add_task('',
                                      total=len(dataset) // args.batch_size)

            for epoch in range(args.epoch):
                progress.update(steps, completed=0, refresh=True)

                for recon_loss, latent_loss, avg_mse, lr in self.train_epoch(
                        epoch, loader, optimizer, scheduler, device,
                        sample_path):
                    progress.update(
                        steps,
                        description=
                        f'mse: {recon_loss:.5f}; latent: {latent_loss:.5f}; avg mse: {avg_mse:.5f}; lr: {lr:.5f}'
                    )
                    progress.advance(steps)

                if dist_fn.is_primary():
                    torch.save(
                        self.state_dict(),
                        os.path.join(checkpoint_path,
                                     f'vqvae_{str(epoch + 1).zfill(3)}.pt'))

                progress.update(train,
                                description=f'epoch {epoch + 1}/{args.epoch}')
                progress.advance(train)