def main(args):
    args.distributed = dist.get_world_size() > 1

    transform = transforms.Compose(
        [
            transforms.Resize(args.size),
            transforms.CenterCrop(args.size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )

    dataset = datasets.ImageFolder(args.path, transform=transform)
    sampler = dist.data_sampler(dataset, shuffle=True, distributed=args.distributed)
    loader = DataLoader(
        dataset, batch_size=128 // args.n_gpu, sampler=sampler, num_workers=2
    )

    model = load_model(args.checkpoint).to(DEVICE)

    if args.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    evaluate(loader, model, args.out_path, args.sample_size)
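# `evaluate` is called above but not defined in this snippet. A minimal sketch,
# assuming the loaded model is a VQ-VAE whose forward returns (reconstruction,
# latent_loss) and that the goal is an original/reconstruction grid; only the
# signature comes from the call site, the body is an assumption.
@torch.no_grad()
def evaluate(loader, model, out_path, sample_size):
    model.eval()

    img, _ = next(iter(loader))
    img = img[:sample_size].to(DEVICE)

    out, _ = model(img)

    utils.save_image(
        torch.cat([img, out], dim=0),
        out_path,
        nrow=sample_size,
        normalize=True,
        range=(-1, 1),
    )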
def main(args):
    device = "cuda"
    args.distributed = dist.get_world_size() > 1

    normMean = [0.5]
    normStd = [0.5]
    normTransform = transforms.Normalize(normMean, normStd)
    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.ToTensor(),
        normTransform,
    ])

    txt_path = 'data/train.txt'  # was 'datd/train.txt' (typo)
    images_path = '/data'
    labels_path = '/data'
    dataset = txtDataset(txt_path, images_path, labels_path, transform=transform)
    sampler = dist.data_sampler(dataset, shuffle=True, distributed=args.distributed)
    batch_size = 128  # total batch across GPUs; assumed, the constant was undefined in this snippet
    loader = DataLoader(dataset,
                        batch_size=batch_size // args.n_gpu,
                        sampler=sampler,
                        num_workers=16)

    model = VQVAE().to(device)

    if args.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = None
    if args.sched == "cycle":
        scheduler = CycleScheduler(
            optimizer,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

    for i in range(args.epoch):
        train(i, loader, model, optimizer, scheduler, device)

        if dist.is_primary():
            torch.save(model.state_dict(),
                       f"checkpoint/vqvae_{str(i + 1).zfill(3)}.pt")
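# `txtDataset` (also used by the GAN script further down) is not defined in
# these snippets. A minimal sketch, assuming train.txt lists one relative file
# name per line and that images and labels are single-channel (matching the
# [0.5]/[0.5] normalization above); the constructor arguments come from the
# call site, everything else is an assumption (assumes
# `from torch.utils.data import Dataset` and `from PIL import Image`).
class txtDataset(Dataset):
    def __init__(self, txt_path, images_path, labels_path, transform=None):
        with open(txt_path) as f:
            self.names = [line.strip() for line in f if line.strip()]
        self.images_path = images_path
        self.labels_path = labels_path
        self.transform = transform

    def __len__(self):
        return len(self.names)

    def __getitem__(self, index):
        name = self.names[index]
        image = Image.open(os.path.join(self.images_path, name)).convert('L')
        label = Image.open(os.path.join(self.labels_path, name)).convert('L')
        if self.transform is not None:
            image = self.transform(image)
            label = self.transform(label)
        return image, label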
def main(args):
    device = "cuda"
    args.distributed = dist.get_world_size() > 1

    # named `transform` (not `transforms`) to avoid shadowing the transform modules
    transform = video_transforms.Compose([
        RandomSelectFrames(16),
        video_transforms.Resize(args.size),
        video_transforms.CenterCrop(args.size),
        volume_transforms.ClipToTensor(),
        tensor_transforms.Normalize(0.5, 0.5),
    ])

    with open('/home/shirakawa/movie/code/iVideoGAN/over16frame_list_training.txt',
              'rb') as f:
        train_file_list = pickle.load(f)
    print(len(train_file_list))

    dataset = MITDataset(train_file_list, transform=transform)
    sampler = dist.data_sampler(dataset, shuffle=True, distributed=args.distributed)
    loader = DataLoader(dataset,
                        batch_size=32 // args.n_gpu,
                        sampler=sampler,
                        num_workers=2)

    model = VQVAE().to(device)

    if args.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = None
    if args.sched == "cycle":
        scheduler = CycleScheduler(
            optimizer,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

    for i in range(args.epoch):
        train(i, loader, model, optimizer, scheduler, device)

        if dist.is_primary():
            torch.save(model.state_dict(),
                       f"checkpoint_vid_v2/vqvae_{str(i + 1).zfill(3)}.pt")
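# `RandomSelectFrames` is not a video_transforms built-in; a minimal sketch,
# assuming it crops a random contiguous clip of `n_frames` frames from the
# input frame sequence (the name and the argument 16 come from the call above,
# the behavior is an assumption; assumes `import random`).
class RandomSelectFrames:
    def __init__(self, n_frames):
        self.n_frames = n_frames

    def __call__(self, frames):
        start = random.randint(0, len(frames) - self.n_frames)
        return frames[start:start + self.n_frames]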
def main(args):
    device = "cuda"
    args.distributed = dist.get_world_size() > 1

    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.CenterCrop(args.size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    dataset = datasets.ImageFolder(args.path, transform=transform)
    sampler = dist.data_sampler(dataset, shuffle=True, distributed=args.distributed)
    loader = DataLoader(dataset,
                        batch_size=128 // args.n_gpu,
                        sampler=sampler,
                        num_workers=2)

    model = VQVAE().to(device)

    if args.load_path:
        state_dict = torch.load(args.load_path, map_location=device)
        model.load_state_dict(state_dict)
        print('successfully loaded model')

    if args.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = None
    if args.sched == "cycle":
        scheduler = CycleScheduler(
            optimizer,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

    for i in range(args.epoch):
        train(i, loader, model, optimizer, scheduler, device)

        if dist.is_primary():
            torch.save(model.state_dict(),
                       f"checkpoint/vqvae_{str(i + 1).zfill(3)}.pt")
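# The `train(i, loader, model, optimizer, scheduler, device)` epoch loop that
# the VQ-VAE entry points above call is not shown. A minimal sketch of the
# usual VQ-VAE-2 objective (MSE reconstruction plus a weighted latent/
# commitment term); the 0.25 weight and the (out, latent_loss) forward
# signature are assumptions.
def train(epoch, loader, model, optimizer, scheduler, device):
    criterion = nn.MSELoss()
    latent_loss_weight = 0.25

    model.train()
    for i, (img, _) in enumerate(loader):
        model.zero_grad()
        img = img.to(device)

        out, latent_loss = model(img)
        recon_loss = criterion(out, img)
        loss = recon_loss + latent_loss_weight * latent_loss.mean()
        loss.backward()

        if scheduler is not None:
            scheduler.step()
        optimizer.step()

        if dist.is_primary() and i % 100 == 0:
            print(f"epoch {epoch}; iter {i}; recon loss {recon_loss.item():.5f}")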
def train(args, dataset, gen, dis, g_ema, device):
    if args.distributed:
        g_module = gen.module
        d_module = dis.module
    else:
        g_module = gen
        d_module = dis

    vgg = VGGFeature("vgg16", [4, 9, 16, 23, 30], use_fc=True).eval().to(device)
    requires_grad(vgg, False)

    g_optim = optim.Adam(gen.parameters(), lr=1e-4, betas=(0, 0.999))
    d_optim = optim.Adam(dis.parameters(), lr=1e-4, betas=(0, 0.999))

    loader = data.DataLoader(
        dataset,
        batch_size=args.batch,
        num_workers=4,
        sampler=dist.data_sampler(dataset, shuffle=True,
                                  distributed=args.distributed),
        drop_last=True,
    )
    loader_iter = sample_data(loader)

    pbar = range(args.start_iter, args.iter)
    if dist.get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True)

    eps = 1e-8

    for i in pbar:
        real, class_id = next(loader_iter)
        real = real.to(device)
        class_id = class_id.to(device)
        masks = make_mask(real.shape[0], device, args.crop_prob)

        features, fcs = vgg(real)
        features = features + fcs[1:]

        # Discriminator step
        requires_grad(dis, True)
        requires_grad(gen, False)

        real_pred = dis(real, class_id)
        z = torch.randn(args.batch, args.dim_z, device=device)
        fake = gen(z, class_id, features, masks)
        fake_pred = dis(fake, class_id)
        d_loss = d_ls_loss(real_pred, fake_pred)

        d_optim.zero_grad()
        d_loss.backward()
        d_optim.step()

        # Generator step: two latents for the diversity loss
        z1 = torch.randn(args.batch, args.dim_z, device=device)
        z2 = torch.randn(args.batch, args.dim_z, device=device)

        requires_grad(gen, True)
        requires_grad(dis, False)

        masks = make_mask(real.shape[0], device, args.crop_prob)

        if args.distributed:
            gen.broadcast_buffers = True
        fake1 = gen(z1, class_id, features, masks)
        if args.distributed:
            gen.broadcast_buffers = False
        fake2 = gen(z2, class_id, features, masks)

        fake_pred = dis(fake1, class_id)
        a_loss = g_ls_loss(None, fake_pred)

        features_fake, fcs_fake = vgg(fake1)
        features_fake = features_fake + fcs_fake[1:]
        r_loss = recon_loss(features_fake, features, masks)
        div_loss = diversity_loss(z1, z2, fake1, fake2, eps)

        g_loss = a_loss + args.rec_weight * r_loss + args.div_weight * div_loss

        g_optim.zero_grad()
        g_loss.backward()
        g_optim.step()

        accumulate(g_ema, g_module)

        if dist.get_rank() == 0:
            pbar.set_description(
                f"d: {d_loss.item():.4f}; g: {a_loss.item():.4f}; "
                f"rec: {r_loss.item():.4f}; div: {div_loss.item():.4f}"
            )

            if i % 100 == 0:
                utils.save_image(
                    fake1,
                    f"sample/{str(i).zfill(6)}.png",
                    nrow=int(args.batch ** 0.5),
                    normalize=True,
                    range=(-1, 1),
                )

            if i % 10000 == 0:
                torch.save(
                    {
                        "args": args,
                        "g_ema": g_ema.state_dict(),
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                    },
                    f"checkpoint/{str(i).zfill(6)}.pt",
                )
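# `d_ls_loss`, `g_ls_loss`, and `accumulate` are referenced above but not
# defined here. Minimal sketches, assuming the least-squares GAN losses the
# names suggest and a standard EMA update; the 0.999 decay default is an
# assumption.
def d_ls_loss(real_pred, fake_pred):
    # LSGAN discriminator loss: push real predictions toward 1, fake toward 0
    return ((real_pred - 1).pow(2).mean() + fake_pred.pow(2).mean()) / 2


def g_ls_loss(_, fake_pred):
    # LSGAN generator loss: push fake predictions toward 1
    return (fake_pred - 1).pow(2).mean()


def accumulate(model_ema, model, decay=0.999):
    # Exponential moving average of the generator weights into g_ema
    par_ema = dict(model_ema.named_parameters())
    par = dict(model.named_parameters())
    for k in par_ema.keys():
        par_ema[k].data.mul_(decay).add_(par[k].data, alpha=1 - decay)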
def main(args):
    device = "cuda"
    args.distributed = dist.get_world_size() > 1

    normMean = [0.5]
    normStd = [0.5]
    normTransform = transforms.Normalize(normMean, normStd)
    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.ToTensor(),
        normTransform,
    ])

    txt_path = './data/train.txt'
    images_path = './data'
    labels_path = './data'
    dataset = txtDataset(txt_path, images_path, labels_path, transform=transform)
    sampler = dist.data_sampler(dataset, shuffle=True, distributed=args.distributed)
    batch_size = 128  # total batch across GPUs; assumed, the constant was undefined in this snippet
    loader = DataLoader(dataset,
                        batch_size=batch_size // args.n_gpu,
                        sampler=sampler,
                        num_workers=16)

    # Initialize generator and discriminator
    DpretrainedPath = './checkpoint/vqvae2GAN_040.pt'
    GpretrainedPath = './checkpoint/vqvae_040.pt'
    discriminator = Discriminator()
    generator = Generator()
    if os.path.exists(DpretrainedPath):
        print('Loading discriminator weights...')
        discriminator.load_state_dict(torch.load(DpretrainedPath)['discriminator'])
        print('done')
    if os.path.exists(GpretrainedPath):
        print('Loading generator weights...')
        generator.load_state_dict(torch.load(GpretrainedPath))
        print('done')
    discriminator = discriminator.to(device)
    generator = generator.to(device)

    if args.distributed:
        discriminator = nn.parallel.DistributedDataParallel(
            discriminator,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )
        generator = nn.parallel.DistributedDataParallel(
            generator,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer_D = optim.Adam(discriminator.parameters(), lr=args.lr)
    optimizer_G = optim.Adam(generator.parameters(), lr=args.lr)
    scheduler_D = None
    scheduler_G = None
    if args.sched == "cycle":
        scheduler_D = CycleScheduler(
            optimizer_D,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )
        scheduler_G = CycleScheduler(
            optimizer_G,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

    n_critic = 5  # assumed: generator-only snapshot interval; not defined in this snippet

    # Resume from epoch 41: the epoch-040 checkpoints are loaded above
    for i in range(41, args.epoch):
        train(i, loader, discriminator, generator, scheduler_D, scheduler_G,
              optimizer_D, optimizer_G, device)

        if dist.is_primary():
            torch.save(
                {
                    'generator': generator.state_dict(),
                    'discriminator': discriminator.state_dict(),
                    'g_optimizer': optimizer_G.state_dict(),
                    'd_optimizer': optimizer_D.state_dict(),
                },
                f'checkpoint/vqvae2GAN_{str(i + 1).zfill(3)}.pt',
            )
            if (i + 1) % n_critic == 0:
                torch.save(generator.state_dict(),
                           f"checkpoint/vqvae_{str(i + 1).zfill(3)}.pt")
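# The `train(...)` epoch used by the GAN script above is not included. A
# minimal sketch of one alternating D/G epoch, assuming the generator is the
# VQ-VAE (forward returns (reconstruction, latent_loss)) and the discriminator
# outputs raw logits; the non-saturating BCE loss is an assumption.
def train(epoch, loader, discriminator, generator, scheduler_D, scheduler_G,
          optimizer_D, optimizer_G, device):
    criterion = nn.BCEWithLogitsLoss()

    for i, (img, _) in enumerate(loader):
        img = img.to(device)
        fake, _ = generator(img)

        # Discriminator step: real -> 1, fake -> 0 (fake detached)
        d_real = discriminator(img)
        d_fake = discriminator(fake.detach())
        d_loss = (criterion(d_real, torch.ones_like(d_real))
                  + criterion(d_fake, torch.zeros_like(d_fake)))
        optimizer_D.zero_grad()
        d_loss.backward()
        if scheduler_D is not None:
            scheduler_D.step()
        optimizer_D.step()

        # Generator step: try to fool the discriminator
        g_pred = discriminator(fake)
        g_loss = criterion(g_pred, torch.ones_like(g_pred))
        optimizer_G.zero_grad()
        g_loss.backward()
        if scheduler_G is not None:
            scheduler_G.step()
        optimizer_G.step()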
def train(cfg, logger):
    # Create save path
    prefix = cfg.DATA.NAME + "-" + cfg.DATA.SOURCE + '2' + cfg.DATA.TARGET
    save_path = os.path.join("results", prefix)
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    suffix = "-".join([
        item for item in [
            "ls%d" % (cfg.LANGEVIN.STEP),
            "llr%.2f" % (cfg.LANGEVIN.LR),
            "lr%.4f" % (cfg.EBM.LR),
            "h%d" % (cfg.EBM.HIDDEN),
            "layer%d" % (cfg.EBM.LAYER),
            "opt%s" % (cfg.EBM.OPT),
        ] if item is not None
    ])
    run_dir = _create_run_dir_local(save_path, suffix)
    _copy_dir(['translation'], run_dir)
    sys.stdout = Logger(os.path.join(run_dir, 'log.txt'))

    ae = load_ae(cfg, logger)
    device = 'cuda'

    transform = transforms.Compose([
        transforms.RandomResizedCrop(2 ** cfg.DATASET.MAX_RESOLUTION_LEVEL,
                                     scale=[0.8, 1.0],
                                     ratio=[0.9, 1.1]),
        transforms.RandomHorizontalFlip(0.5),
        # transforms.Resize(256),
        # transforms.CenterCrop(256),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    data_root = os.path.join(cfg.DATA.ROOT, cfg.DATA.NAME)
    print(data_root)
    source_dataset = ImageFolder(os.path.join(data_root, 'train/' + cfg.DATA.SOURCE),
                                 transform=transform)
    source_sampler = dist.data_sampler(source_dataset, shuffle=True,
                                       distributed=False)
    source_loader = DataLoader(source_dataset,
                               batch_size=cfg.DATA.BATCH,
                               sampler=source_sampler,
                               num_workers=1,
                               drop_last=True)
    target_dataset = ImageFolder(os.path.join(data_root, 'train/' + cfg.DATA.TARGET),
                                 transform=transform)
    target_sampler = dist.data_sampler(target_dataset, shuffle=True,
                                       distributed=False)
    target_loader = DataLoader(target_dataset,
                               batch_size=cfg.DATA.BATCH,
                               sampler=target_sampler,
                               num_workers=1,
                               drop_last=True)
    source_iter = iter(source_loader)
    target_iter = iter(target_loader)

    latent_ebm = LatentEBM(latent_dim=512, n_layer=cfg.EBM.LAYER,
                           n_hidden=cfg.EBM.HIDDEN).cuda()
    latent_ema = LatentEBM(latent_dim=512, n_layer=cfg.EBM.LAYER,
                           n_hidden=cfg.EBM.HIDDEN).cuda()
    ema(latent_ema, latent_ebm, decay=0.)
    latent_optimizer = optim.SGD(latent_ebm.parameters(), lr=cfg.EBM.LR)
    if cfg.EBM.OPT == 'adam':
        latent_optimizer = optim.Adam(latent_ebm.parameters(), lr=cfg.EBM.LR)

    layer_count = cfg.MODEL.LAYER_COUNT
    used_sample = 0
    iterations = -1
    nrow = min(cfg.DATA.BATCH, 2)
    batch_size = cfg.DATA.BATCH
    # generate_recon(cfg=cfg, ae=ae, ebm=latent_ema, run_dir=run_dir, iteration=iterations, device=device)

    ebm_param = sum(p.numel() for p in latent_ebm.parameters())
    ae_param = sum(p.numel() for p in ae.parameters())
    print(ebm_param, ae_param)

    while used_sample < 10000000:
        iterations += 1
        latent_ebm.zero_grad()
        latent_optimizer.zero_grad()

        try:
            # ImageFolder batches are (images, labels) pairs; keep the images only
            source_img = next(source_iter)[0].to(device)
            target_img = next(target_iter)[0].to(device)
        except (OSError, StopIteration):
            source_iter = iter(source_loader)
            target_iter = iter(target_loader)
            source_img = next(source_iter)[0].to(device)
            target_img = next(target_iter)[0].to(device)

        source_latent = encode(ae, source_img, cfg).squeeze()
        target_latent = encode(ae, target_img, cfg).squeeze()

        requires_grad(latent_ebm, False)
        source_latent_q = langvin_sampler(
            latent_ebm,
            source_latent.clone().detach(),
            langevin_steps=cfg.LANGEVIN.STEP,
            lr=cfg.LANGEVIN.LR,
        )

        requires_grad(latent_ebm, True)
        source_energy = latent_ebm(source_latent_q)
        target_energy = latent_ebm(target_latent)
        loss = -(target_energy - source_energy).mean()

        # Bail out if training diverges; was `abs(loss.item() > 10000)`,
        # which takes abs() of a bool instead of the loss
        if abs(loss.item()) > 10000:
            break

        loss.backward()
        latent_optimizer.step()

        ema(latent_ema, latent_ebm, decay=0.999)

        used_sample += batch_size

        # if iterations % 1000 == 0:
        test_image_folder(cfg=cfg, ae=ae, ebm=latent_ema, run_dir=run_dir,
                          iteration=iterations, device=device)
        # test_representatives(cfg=cfg, ae=ae, ebm=latent_ema, run_dir=run_dir, iteration=iterations, device=device)
        torch.save(latent_ebm.state_dict(),
                   f"{run_dir}/ebm_{str(iterations).zfill(6)}.pt")

        if iterations % 100 == 0:
            print(f'Iter: {iterations:06}, Loss: {loss.item():6.3f}')
            latents = langvin_sampler(latent_ema,
                                      source_latent[:nrow].clone().detach(),
                                      langevin_steps=cfg.LANGEVIN.STEP,
                                      lr=cfg.LANGEVIN.LR)
            with torch.no_grad():
                latents = torch.cat((source_latent[:nrow], latents))
                latents = latents.unsqueeze(1).repeat(1, ae.mapping_fl.num_layers, 1)
                out = decode(ae, latents, cfg)
                out = torch.cat((source_img[:nrow], out), dim=0)
                utils.save_image(
                    out,
                    f"{run_dir}/{str(iterations).zfill(6)}.png",
                    nrow=nrow,
                    normalize=True,
                    padding=0,
                    range=(-1, 1),
                )
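# `langvin_sampler` and `ema` are used above without definitions. Minimal
# sketches: Langevin dynamics on the latent codes (gradient steps on the EBM
# energy; the noise term and step scaling are assumptions) and a standard EMA
# weight copy (decay=0. above makes it a plain copy).
def langvin_sampler(model, x, langevin_steps=20, lr=1.0, sigma=0.0):
    x = x.clone().detach().requires_grad_(True)
    for _ in range(langevin_steps):
        energy = model(x).sum()
        grad, = torch.autograd.grad(energy, x)
        x.data.add_(grad, alpha=-0.5 * lr)                 # move downhill in energy
        if sigma > 0:
            x.data.add_(torch.randn_like(x), alpha=sigma)  # optional noise term
    return x.detach()


def ema(model_ema, model, decay=0.999):
    par_ema = dict(model_ema.named_parameters())
    par = dict(model.named_parameters())
    for k in par_ema.keys():
        par_ema[k].data.mul_(decay).add_(par[k].data, alpha=1 - decay)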
def main(args):
    device = "cuda"
    args.distributed = dist.get_world_size() > 1

    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.CenterCrop(args.size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    dataset = OffsetDataset(args.path, transform=transform, offset=args.offset)
    sampler = dist.data_sampler(dataset, shuffle=True, distributed=args.distributed)
    loader = DataLoader(dataset,
                        batch_size=args.bsize // args.n_gpu,
                        sampler=sampler,
                        num_workers=2)

    # Load pre-trained VQVAE
    vqvae = VQVAE().to(device)
    try:
        vqvae.load_state_dict(torch.load(args.ckpt))
    except RuntimeError:
        print("Checkpoint seems to have been trained with DataParallel; "
              "stripping the 'module.' prefix and retrying")
        weights = torch.load(args.ckpt)
        weights = {key.replace('module.', ''): value
                   for key, value in weights.items()}
        vqvae.load_state_dict(weights)

    # Init offset encoder
    model = OffsetNetwork(vqvae).to(device)

    if args.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
            find_unused_parameters=True,
        )

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = None
    if args.sched == "cycle":
        scheduler = CycleScheduler(
            optimizer,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

    for i in range(args.epoch):
        train(i, loader, model, optimizer, scheduler, device)

        if dist.is_primary():
            torch.save(model.state_dict(),
                       f"checkpoint/offset_enc_{str(i + 1).zfill(3)}.pt")
def run(self, args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    transform = [transforms.ToTensor()]
    if args.normalize:
        transform.append(
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]))
    transform = transforms.Compose(transform)

    dataset = datasets.ImageFolder(args.path, transform=transform)
    sampler = dist_fn.data_sampler(dataset, shuffle=True,
                                   distributed=args.distributed)
    loader = DataLoader(dataset,
                        batch_size=args.batch_size // args.n_gpu,
                        sampler=sampler,
                        num_workers=args.num_workers)

    self = self.to(device)
    if args.distributed:
        # Note: the DDP wrapper does not expose custom methods such as
        # train_epoch; reach them through `.module` if needed
        self = nn.parallel.DistributedDataParallel(
            self,
            device_ids=[dist_fn.get_local_rank()],
            output_device=dist_fn.get_local_rank())

    optimizer = args.optimizer(self.parameters(), lr=args.lr)
    scheduler = None
    if args.sched == 'cycle':
        scheduler = CycleScheduler(
            optimizer,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

    start = str(time())
    run_path = os.path.join('runs', start)
    sample_path = os.path.join(run_path, 'sample')
    checkpoint_path = os.path.join(run_path, 'checkpoint')
    os.makedirs(run_path)  # also creates 'runs' if it does not exist yet
    os.mkdir(sample_path)
    os.mkdir(checkpoint_path)

    with Progress() as progress:
        train = progress.add_task(f'epoch 1/{args.epoch}',
                                  total=args.epoch,
                                  columns='epochs')
        steps = progress.add_task('', total=len(dataset) // args.batch_size)
        for epoch in range(args.epoch):
            progress.update(steps, completed=0, refresh=True)
            for recon_loss, latent_loss, avg_mse, lr in self.train_epoch(
                    epoch, loader, optimizer, scheduler, device, sample_path):
                progress.update(
                    steps,
                    description=(f'mse: {recon_loss:.5f}; '
                                 f'latent: {latent_loss:.5f}; '
                                 f'avg mse: {avg_mse:.5f}; '
                                 f'lr: {lr:.5f}'))
                progress.advance(steps)
            if dist_fn.is_primary():
                torch.save(
                    self.state_dict(),
                    os.path.join(checkpoint_path,
                                 f'vqvae_{str(epoch + 1).zfill(3)}.pt'))
            progress.update(train,
                            description=f'epoch {epoch + 1}/{args.epoch}')
            progress.advance(train)
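# How these distributed entry points are typically launched: a sketch, assuming
# the rosinality-style `distributed` helper these scripts import as `dist` /
# `dist_fn`. The flag names beyond n_gpu/size/epoch/lr/sched/path, the port,
# and the exact `launch` signature are assumptions and may differ in your copy.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_gpu", type=int, default=1)
    parser.add_argument("--size", type=int, default=256)
    parser.add_argument("--epoch", type=int, default=560)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--sched", type=str, default="")
    parser.add_argument("path", type=str)
    args = parser.parse_args()

    # spawn one process per GPU; each process calls main(args)
    dist.launch(main, args.n_gpu, 1, 0, "tcp://127.0.0.1:13333", args=(args,))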