def __init__(self, config):
    """Set up the bidirectional translation trainer: one generator and one
    patch discriminator per direction (a->b, b->a), joint Adam optimizers,
    the loss criteria, GAN target labels, and the reported loss names."""
    super(ImGANTrainer, self).__init__()

    adam_betas = (config['beta1'], 0.999)

    # Networks: a -> b and b -> a translation, plus a discriminator for each.
    self.netG_ab = Generator(config)
    self.netG_ba = Generator(config)
    self.netD_ab = PatchD(config['input_nc'], config['ndf'])
    self.netD_ba = PatchD(config['input_nc'], config['ndf'])

    # Joint optimizers: one over both generators, one over both discriminators.
    gen_params = itertools.chain(self.netG_ab.parameters(),
                                 self.netG_ba.parameters())
    dis_params = itertools.chain(self.netD_ab.parameters(),
                                 self.netD_ba.parameters())
    self.optimizer_g = torch.optim.Adam(gen_params, lr=config['lr'],
                                        betas=adam_betas)
    self.optimizer_d = torch.optim.Adam(dis_params, lr=config['lr'],
                                        betas=adam_betas)

    # criterion (attribute names preserved — referenced elsewhere)
    self.criteritionGAN = nn.BCELoss()
    self.criteritioL1 = nn.L1Loss()
    self.criteritiommd = MMDLoss()

    # labels used as GAN targets
    self.real_label = 1.
    self.fake_label = 0.

    # losses reported during training
    self.loss_names = [
        'loss_D', 'loss_G', 'loss_cycle_aba', 'loss_cycle_bab', 'loss_mmd'
    ]
def __init__(self, config):
    """Build the inpainting trainer: generator plus local and global
    discriminators, their Adam optimizers, and (when CUDA is enabled)
    placement of all three networks on the first configured device."""
    super(Trainer, self).__init__()
    self.config = config
    self.use_cuda = self.config['cuda']
    self.device_ids = self.config['gpu_ids']

    beta_pair = (self.config['beta1'], self.config['beta2'])

    # Networks.
    self.netG = Generator(self.config['netG'], self.use_cuda, self.device_ids)
    self.localD = LocalDis(self.config['netD'], self.use_cuda, self.device_ids)
    self.globalD = GlobalDis(self.config['netD'], self.use_cuda, self.device_ids)

    # Generator optimizer.
    self.optimizer_g = torch.optim.Adam(self.netG.parameters(),
                                        lr=self.config['lr'],
                                        betas=beta_pair)
    # One optimizer over both discriminators.
    d_params = list(self.localD.parameters()) + list(self.globalD.parameters())
    self.optimizer_d = torch.optim.Adam(d_params,
                                        lr=self.config['lr'],
                                        betas=beta_pair)

    # Move everything onto the first configured device.
    if self.use_cuda:
        first_device = self.device_ids[0]
        self.netG.to(first_device)
        self.localD.to(first_device)
        self.globalD.to(first_device)
def main(args):
    """Generate ``args.n`` fake videos in batches with a pretrained Generator
    and save them under ``args.gen_path``.

    Fix: the batch count now uses ceiling division; the original
    ``args.n // batch_size + 1`` ran one extra, empty batch (bs == 0)
    whenever ``args.n`` was an exact multiple of ``batch_size``.
    """
    device = torch.device("cuda:0")
    G = Generator().to(device)
    G = nn.DataParallel(G)
    G.load_state_dict(torch.load(args.model_path))

    with torch.no_grad():
        G.eval()
        batch_size = args.batch_size
        # Ceil division: exactly enough batches to cover args.n samples.
        n_epoch = (args.n + batch_size - 1) // batch_size
        for epoch in tqdm(range(n_epoch)):
            # Last batch may be smaller than batch_size.
            bs = min(batch_size, args.n - epoch * batch_size)
            za = torch.randn(bs, args.d_za, 1, 1, 1).to(device)  # appearance code
            zm = torch.randn(bs, args.d_zm, 1, 1, 1).to(device)  # motion code
            vid_fake = G(za, zm)
            vid_fake = vid_fake.transpose(2, 1)  # bs x 16 x 3 x 64 x 64
            # Min-max normalize into [0, 1] before saving.
            vid_fake = ((vid_fake - vid_fake.min()) /
                        (vid_fake.max() - vid_fake.min())).data

            # save into videos
            save_videos(args.gen_path, vid_fake, epoch, bs)
    return
def __init__(self, opts, nc_in=5, nc_out=3, d_s_args=None, d_t_args=None):
    """Build the video-inpainting model: one generator and (optionally) a
    spatial and a temporal SN patch discriminator.

    Args:
        opts: configuration mapping ('nf', 'norm', 'bias', optional
            'conv_by', 'conv_type', 'use_refine', 'use_skip_connection',
            'spatial_discriminator', 'temporal_discriminator').
        nc_in / nc_out: generator input / output channel counts.
        d_s_args / d_t_args: overrides for the spatial / temporal
            discriminator settings. ``None`` (the new default) means "no
            overrides" — the original ``={}`` mutable defaults were shared
            between calls, a classic Python pitfall.
    """
    super().__init__()
    # Treat None as "no overrides" (avoids mutable default arguments).
    d_s_args = d_s_args or {}
    d_t_args = d_t_args or {}

    # Discriminator defaults, overridden by any provided values.
    self.d_t_args = {"nf": 32, "use_sigmoid": True, "norm": 'SN'}
    self.d_t_args.update(d_t_args)
    self.d_s_args = {"nf": 32, "use_sigmoid": True, "norm": 'SN'}
    self.d_s_args.update(d_s_args)

    nf = opts['nf']
    norm = opts['norm']
    use_bias = opts['bias']
    # warning: if 2d convolution is used in generator, settings (e.g. stride,
    # kernal_size, padding) on the temporal axis will be discarded
    self.conv_by = opts.get('conv_by', '3d')
    self.conv_type = opts.get('conv_type', 'gated')
    self.use_refine = opts.get('use_refine', False)
    use_skip_connection = opts.get('use_skip_connection', False)
    self.opts = opts

    ######################
    # Convolution layers #
    ######################
    self.generator = Generator(
        nc_in, nc_out, nf, use_bias, norm, self.conv_by, self.conv_type,
        use_refine=self.use_refine,
        use_skip_connection=use_skip_connection)

    #################
    # Discriminator #
    #################
    # Discriminators default to enabled when the key is absent.
    if opts.get('spatial_discriminator', True):
        self.spatial_discriminator = SNTemporalPatchGANDiscriminator(
            nc_in=5, conv_type='2d', **self.d_s_args)
    if opts.get('temporal_discriminator', True):
        self.temporal_discriminator = SNTemporalPatchGANDiscriminator(
            nc_in=5, **self.d_t_args)
def main(args):
    """Demo: for each expression class, generate videos conditioned on one
    test batch of MUG images and a shared noise vector; log each class's
    output to TensorBoard and save it to disk."""
    # write into tensorboard
    log_path = os.path.join('demos', args.dataset + '/log')
    vid_path = os.path.join('demos', args.dataset + '/vids')
    os.makedirs(log_path, exist_ok=True)
    os.makedirs(vid_path, exist_ok=True)
    writer = SummaryWriter(log_path)

    device = torch.device("cuda:0")
    G = Generator(args.dim_z, args.dim_a, args.nclasses, args.ch).to(device)
    G = nn.DataParallel(G)
    G.load_state_dict(torch.load(args.model_path))

    transform = torchvision.transforms.Compose([
        transforms.Resize((args.img_size, args.img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])
    dataset = MUG_test(args.data_path, transform=transform)
    dataloader = torch.utils.data.DataLoader(dataset=dataset,
                                             batch_size=args.batch_size,
                                             num_workers=args.num_workers,
                                             shuffle=False,
                                             pin_memory=True)

    with torch.no_grad():
        G.eval()
        # Only the first batch of test images is used for the demo.
        img = next(iter(dataloader))
        bs = img.size(0)
        nclasses = args.nclasses
        # One shared noise vector across all classes, so only the class
        # label varies between the generated clips.
        z = torch.randn(bs, args.dim_z).to(device)
        for i in range(nclasses):
            # One-hot label for class i.
            y = torch.zeros(bs, nclasses).to(device)
            y[:, i] = 1.0
            vid_gen = G(img, z, y)
            vid_gen = vid_gen.transpose(2, 1)
            # Min-max normalize into [0, 1] for logging/saving.
            vid_gen = ((vid_gen - vid_gen.min()) /
                       (vid_gen.max() - vid_gen.min())).data
            writer.add_video(tag='vid_cat_%d' % i, vid_tensor=vid_gen)
            writer.flush()

            # save videos
            print('==> saving videos')
            save_videos(vid_path, vid_gen, bs, i)
def main():
    """Demo: sample every (appearance, motion) latent pair via a Cartesian
    product, generate videos for all pairs, log them to TensorBoard and save
    them to disk.

    Fix: directory creation now uses ``exist_ok=True`` per path — the
    original guard (``not exists(log) and not exists(vid)``) skipped
    creating BOTH directories whenever only one of them already existed.
    """
    args = cfg.parse_args()

    # write into tensorboard
    log_path = os.path.join(args.demo_path, args.demo_name + '/log')
    vid_path = os.path.join(args.demo_path, args.demo_name + '/vids')
    os.makedirs(log_path, exist_ok=True)
    os.makedirs(vid_path, exist_ok=True)
    writer = SummaryWriter(log_path)

    device = torch.device("cuda:0")
    G = Generator().to(device)
    G = nn.DataParallel(G)
    G.load_state_dict(torch.load(args.model_path))

    with torch.no_grad():
        G.eval()
        za = torch.randn(args.n_za_test, args.d_za, 1, 1, 1).to(device)
        zm = torch.randn(args.n_zm_test, args.d_zm, 1, 1, 1).to(device)
        n_za = za.size(0)
        n_zm = zm.size(0)
        # Cartesian product: pair every appearance code with every motion code.
        za = za.unsqueeze(1).repeat(1, n_zm, 1, 1, 1, 1).contiguous().view(
            n_za * n_zm, -1, 1, 1, 1)
        zm = zm.repeat(n_za, 1, 1, 1, 1)

        vid_fake = G(za, zm)
        vid_fake = vid_fake.transpose(2, 1)  # bs x 16 x 3 x 64 x 64
        # Min-max normalize into [0, 1] for logging/saving.
        vid_fake = ((vid_fake - vid_fake.min()) /
                    (vid_fake.max() - vid_fake.min())).data
        writer.add_video(tag='generated_videos', global_step=1,
                         vid_tensor=vid_fake)
        writer.flush()

        # save into videos
        print('==> saving videos...')
        save_videos(vid_path, vid_fake, n_za, n_zm)
    return
def loadGenerator(args):
    """Load a trained inpainting Generator from the latest (or requested)
    checkpoint and wrap it in DataParallel when CUDA is enabled.

    Fix: ``.cuda()`` is now called only when ``config['cuda']`` is set — the
    original called it unconditionally, which crashes on CPU-only machines.

    Returns:
        The generator with weights restored (DataParallel-wrapped on CUDA).
    """
    config = get_config(args.g_config)

    # CUDA configuration
    cuda = config['cuda']
    device_ids = config['gpu_ids']
    if cuda:
        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
            str(i) for i in device_ids)
        device_ids = list(range(len(device_ids)))
        config['gpu_ids'] = device_ids
        cudnn.benchmark = True

    # Set checkpoint path
    if not args.checkpoint_path:
        checkpoint_path = os.path.join(
            'checkpoints', config['dataset_name'],
            config['mask_type'] + '_' + config['expname'])
    else:
        checkpoint_path = args.checkpoint_path

    # Define the trainer
    netG = Generator(config['netG'], cuda, device_ids)
    if cuda:
        netG = netG.cuda()

    # Resume weight
    last_model_name = get_model_list(checkpoint_path, "gen",
                                     iteration=args.iter)
    # NOTE(review): assumes checkpoint filenames end with an 8-digit
    # iteration number plus a 3-character extension — confirm.
    model_iteration = int(last_model_name[-11:-3])
    netG.load_state_dict(torch.load(last_model_name))
    print("Configuration: {}".format(config))
    print("Resume from {} at iteration {}".format(checkpoint_path,
                                                  model_iteration))

    if cuda:
        netG = nn.parallel.DataParallel(netG, device_ids=device_ids)

    return netG
def main():
    """Demo: generate videos of increasing length (16..48 frames, step 4) by
    growing the motion code's temporal dimension; log each length to
    TensorBoard and save it to disk.

    Fix: directory creation now uses ``exist_ok=True`` per path — the
    original guard (``not exists(log) and not exists(vid)``) skipped
    creating BOTH directories whenever only one of them already existed.
    """
    args = cfg.parse_args()

    # write into tensorboard
    log_path = os.path.join(args.demo_path, args.demo_name + '/log')
    vid_path = os.path.join(args.demo_path, args.demo_name + '/vids')
    os.makedirs(log_path, exist_ok=True)
    os.makedirs(vid_path, exist_ok=True)
    writer = SummaryWriter(log_path)

    device = torch.device("cuda:0")
    G = Generator().to(device)
    G = nn.DataParallel(G)
    G.load_state_dict(torch.load(args.model_path))

    with torch.no_grad():
        G.eval()
        za = torch.randn(args.n_za_test, args.d_za, 1, 1, 1).to(device)  # appearance

        # generating frames from [16, 20, 24, 28, 32, 36, 40, 44, 48]
        for i in range(9):
            # Motion code with i+1 temporal steps -> 16 + i*4 output frames.
            zm = torch.randn(args.n_zm_test, args.d_zm,
                             (i + 1), 1, 1).to(device)  # 16+i*4
            vid_fake = G(za, zm)
            vid_fake = vid_fake.transpose(2, 1)
            # Min-max normalize into [0, 1] for logging/saving.
            vid_fake = ((vid_fake - vid_fake.min()) /
                        (vid_fake.max() - vid_fake.min())).data
            writer.add_video(tag='generated_videos_%dframes' % (16 + i * 4),
                             global_step=1, vid_tensor=vid_fake)
            writer.flush()

            print('saving videos')
            save_videos(vid_path, vid_fake, args.n_za_test, (16 + i * 4))
    return
def main():
    """Train the conditional video GAN (generator + video/image
    discriminators) on the MUG dataset, with periodic validation, testing,
    and checkpointing."""
    args = cfg.parse_args()
    torch.cuda.manual_seed(args.random_seed)
    print(args)

    # create logging folder
    log_path = os.path.join(args.save_path, args.exp_name + '/log')
    model_path = os.path.join(args.save_path, args.exp_name + '/models')
    os.makedirs(log_path, exist_ok=True)
    os.makedirs(model_path, exist_ok=True)
    writer = SummaryWriter(log_path)  # tensorboard

    # load model
    print('==> loading models')
    device = torch.device("cuda:0")
    G = Generator(args.dim_z, args.dim_a, args.nclasses, args.ch).to(device)
    VD = VideoDiscriminator(args.nclasses, args.ch).to(device)
    ID = ImageDiscriminator(args.ch).to(device)
    G = nn.DataParallel(G)
    VD = nn.DataParallel(VD)
    ID = nn.DataParallel(ID)

    # optimizer (betas fixed at the common GAN setting (0.5, 0.999))
    optimizer_G = torch.optim.Adam(G.parameters(), args.g_lr, (0.5, 0.999))
    optimizer_VD = torch.optim.Adam(VD.parameters(), args.d_lr, (0.5, 0.999))
    optimizer_ID = torch.optim.Adam(ID.parameters(), args.d_lr, (0.5, 0.999))

    # loss
    criterion_gan = nn.BCEWithLogitsLoss().to(device)
    criterion_l1 = nn.L1Loss().to(device)

    # prepare dataset
    print('==> preparing dataset')
    # Video-clip transforms for train/val; plain image transforms for test.
    transform = torchvision.transforms.Compose([
        transforms_vid.ClipResize((args.img_size, args.img_size)),
        transforms_vid.ClipToTensor(),
        transforms_vid.ClipNormalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])
    transform_test = torchvision.transforms.Compose([
        transforms.Resize((args.img_size, args.img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    if args.dataset == 'mug':
        dataset_train = MUG('train', args.data_path, transform=transform)
        dataset_val = MUG('val', args.data_path, transform=transform)
        dataset_test = MUG_test(args.data_path, transform=transform_test)
    else:
        raise NotImplementedError

    dataloader_train = torch.utils.data.DataLoader(
        dataset=dataset_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=True,
        pin_memory=True,
        drop_last=True)
    dataloader_val = torch.utils.data.DataLoader(dataset=dataset_val,
                                                 batch_size=args.batch_size,
                                                 num_workers=args.num_workers,
                                                 shuffle=False,
                                                 pin_memory=True)
    dataloader_test = torch.utils.data.DataLoader(
        dataset=dataset_test,
        batch_size=args.batch_size_test,
        num_workers=args.num_workers,
        shuffle=False,
        pin_memory=True)

    print('==> start training')
    for epoch in range(args.max_epoch):
        train(args, epoch, G, VD, ID, optimizer_G, optimizer_VD, optimizer_ID,
              criterion_gan, criterion_l1, dataloader_train, writer, device)
        # Periodic validation/testing and checkpointing (epoch 0 included).
        if epoch % args.val_freq == 0:
            val(args, epoch, G, criterion_l1, dataloader_val, device, writer)
            test(args, epoch, G, dataloader_test, device, writer)
        if epoch % args.save_freq == 0:
            torch.save(G.state_dict(),
                       os.path.join(model_path, 'G_%d.pth' % (epoch)))
            torch.save(VD.state_dict(),
                       os.path.join(model_path, 'VD_%d.pth' % (epoch)))
            torch.save(ID.state_dict(),
                       os.path.join(model_path, 'ID_%d.pth' % (epoch)))
    return
def main():
    """Inpaint a single image — either with a user-supplied mask or a random
    bbox mask — using a trained generator, then save the result (and the
    offset flow, when requested).

    Fix: the "not an image file" error for ``args.image`` now actually
    formats the filename; the original ``"...".format`` never called the
    method, so the message printed a bound-method repr instead.
    """
    args = parser.parse_args()
    config = get_config(args.config)

    # CUDA configuration
    cuda = config['cuda']
    device_ids = config['gpu_ids']
    if cuda:
        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
            str(i) for i in device_ids)
        device_ids = list(range(len(device_ids)))
        config['gpu_ids'] = device_ids
        cudnn.benchmark = True
    print("Arguments: {}".format(args))

    # Set random seed
    if args.seed is None:
        args.seed = random.randint(1, 10000)
    print("Random seed: {}".format(args.seed))
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if cuda:
        torch.cuda.manual_seed_all(args.seed)

    print("Configuration: {}".format(config))

    try:  # for unexpected error logging
        with torch.no_grad():  # enter no grad context
            if is_image_file(args.image):
                if args.mask and is_image_file(args.mask):
                    # Test a single masked image with a given mask
                    x = default_loader(args.image)
                    mask = default_loader(args.mask)
                    x = transforms.Resize(config['image_shape'][:-1])(x)
                    x = transforms.CenterCrop(config['image_shape'][:-1])(x)
                    mask = transforms.Resize(config['image_shape'][:-1])(mask)
                    mask = transforms.CenterCrop(
                        config['image_shape'][:-1])(mask)
                    x = transforms.ToTensor()(x)
                    # Use only the first mask channel, kept as a 1xHxW tensor.
                    mask = transforms.ToTensor()(mask)[0].unsqueeze(dim=0)
                    x = normalize(x)
                    # Zero out the masked region of the input.
                    x = x * (1. - mask)
                    x = x.unsqueeze(dim=0)
                    mask = mask.unsqueeze(dim=0)
                elif args.mask:
                    raise TypeError(
                        "{} is not an image file.".format(args.mask))
                else:
                    # Test a single ground-truth image with a random mask
                    ground_truth = default_loader(args.image)
                    ground_truth = transforms.Resize(
                        config['image_shape'][:-1])(ground_truth)
                    ground_truth = transforms.CenterCrop(
                        config['image_shape'][:-1])(ground_truth)
                    ground_truth = transforms.ToTensor()(ground_truth)
                    ground_truth = normalize(ground_truth)
                    ground_truth = ground_truth.unsqueeze(dim=0)
                    bboxes = random_bbox(
                        config, batch_size=ground_truth.size(0))
                    x, mask = mask_image(ground_truth, bboxes, config)

                # Set checkpoint path
                if not args.checkpoint_path:
                    checkpoint_path = os.path.join(
                        'checkpoints', config['dataset_name'],
                        config['mask_type'] + '_' + config['expname'])
                else:
                    checkpoint_path = args.checkpoint_path

                # Define the trainer
                netG = Generator(config['netG'], cuda, device_ids)
                # Resume weight
                last_model_name = get_model_list(
                    checkpoint_path, "gen", iteration=args.iter)
                netG.load_state_dict(torch.load(last_model_name))
                # NOTE(review): assumes checkpoint names end with an 8-digit
                # iteration plus a 3-char extension — confirm.
                model_iteration = int(last_model_name[-11:-3])
                print("Resume from {} at iteration {}".format(
                    checkpoint_path, model_iteration))

                if cuda:
                    netG = nn.parallel.DataParallel(
                        netG, device_ids=device_ids)
                    x = x.cuda()
                    mask = mask.cuda()

                # Inference
                x1, x2, offset_flow = netG(x, mask)
                # Keep known pixels from the input; fill holes from x2.
                inpainted_result = x2 * mask + x * (1. - mask)

                vutils.save_image(inpainted_result, args.output, padding=0,
                                  normalize=True)
                print("Saved the inpainted result to {}".format(args.output))
                if args.flow:
                    vutils.save_image(offset_flow, args.flow, padding=0,
                                      normalize=True)
                    print("Saved offset flow to {}".format(args.flow))
            else:
                raise TypeError(
                    "{} is not an image file.".format(args.image))
        # exit no grad context
    except Exception as e:  # for unexpected error logging
        print("Error: {}".format(e))
        raise e
def main():
    """Inpaint logo regions of an image: chunk the image, inpaint each masked
    chunk with a trained generator, reassemble the chunks, and save the
    result.

    Fixes:
      * the "not an image file" error now actually formats the filename
        (the original ``"...".format`` never called the method);
      * ``nn.parallel.DataParallel`` is applied ONCE before the chunk loop —
        the original re-wrapped ``netG`` on every masked chunk, nesting the
        wrapper deeper with each iteration.
    """
    args = parser.parse_args()
    config = get_config(args.config)

    # CUDA configuration
    cuda = config['cuda']
    device_ids = config['gpu_ids']
    if cuda:
        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
            str(i) for i in device_ids)
        device_ids = list(range(len(device_ids)))
        config['gpu_ids'] = device_ids
        cudnn.benchmark = True

    # Set random seed
    if args.seed is None:
        args.seed = random.randint(1, 10000)
    print("Random seed: {}".format(args.seed))
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if cuda:
        torch.cuda.manual_seed_all(args.seed)

    chunker = ImageChunker(config['image_shape'][0],
                           config['image_shape'][1], args.overlap)

    try:  # for unexpected error logging
        with torch.no_grad():  # enter no grad context
            if not is_image_file(args.image):
                raise TypeError(
                    "{} is not an image file.".format(args.image))

            print("Loading image...")
            imgs, masks = [], []
            img_ori = default_loader(args.image)
            img_w, img_h = img_ori.size

            # Load mask txt file (bbox annotations next to the image)
            fname = args.image.replace('.jpg', '.txt')
            bboxes, _ = load_bbox_txt(fname, img_w, img_h)
            mask_ori = create_mask(bboxes, img_w, img_h)

            # Split image and mask into model-sized chunks.
            chunked_images = chunker.dimension_preprocess(
                np.array(deepcopy(img_ori)))
            chunked_masks = chunker.dimension_preprocess(
                np.array(deepcopy(mask_ori)))
            for (x, msk) in zip(chunked_images, chunked_masks):
                x = transforms.ToTensor()(x)
                mask = transforms.ToTensor()(msk)[0].unsqueeze(dim=0)
                # x = normalize(x)
                x = x * (1. - mask)  # zero out the masked region
                x = x.unsqueeze(dim=0)
                mask = mask.unsqueeze(dim=0)
                imgs.append(x)
                masks.append(mask)

            # Set checkpoint path
            if not args.checkpoint_path:
                checkpoint_path = os.path.join(
                    'checkpoints', config['dataset_name'],
                    config['mask_type'] + '_' + config['expname'])
            else:
                checkpoint_path = args.checkpoint_path

            # Define the trainer
            netG = Generator(config['netG'], cuda, device_ids)
            # Resume weight
            last_model_name = get_model_list(checkpoint_path, "gen",
                                             iteration=args.iter)
            netG.load_state_dict(torch.load(last_model_name))
            # NOTE(review): assumes checkpoint names end with an 8-digit
            # iteration plus a 3-char extension — confirm.
            model_iteration = int(last_model_name[-11:-3])
            print("Resume from {} at iteration {}".format(
                checkpoint_path, model_iteration))
            # Wrap once, before the loop.
            if cuda:
                netG = nn.parallel.DataParallel(netG, device_ids=device_ids)

            pred_imgs = []
            for (x, mask) in zip(imgs, masks):
                if torch.max(mask) == 1:
                    # Chunk contains masked pixels -> run inference.
                    if cuda:
                        x = x.cuda()
                        mask = mask.cuda()

                    # Inference
                    x1, x2, offset_flow = netG(x, mask)
                    inpainted_result = x2 * mask + x * (1. - mask)
                    inpainted_result = inpainted_result.squeeze(
                        dim=0).permute(1, 2, 0).cpu()
                    pred_imgs.append(inpainted_result.numpy())
                else:
                    # Nothing to inpaint: keep the chunk as-is.
                    pred_imgs.append(
                        x.squeeze(dim=0).permute(1, 2, 0).numpy())

            pred_imgs = np.asarray(pred_imgs, dtype=np.float32)
            reconstructed_image = chunker.dimension_postprocess(
                pred_imgs, np.array(img_ori))
            # plt.imshow(reconstructed_image); plt.show()
            reconstructed_image = torch.tensor(
                reconstructed_image).permute(2, 0, 1).unsqueeze(dim=0)
            vutils.save_image(reconstructed_image, args.output, padding=0,
                              normalize=True)
            print("Saved the inpainted result to {}".format(args.output))

            if args.flow:
                vutils.save_image(offset_flow, args.flow, padding=0,
                                  normalize=True)
                print("Saved offset flow to {}".format(args.flow))
        # exit no grad context
    except Exception as e:  # for unexpected error logging
        print("Error: {}".format(e))
        raise e
def main(args):
    """Sample image batches from a trained InfoGAN generator and write the
    images plus the latent-code metadata used to produce them."""
    # Set random seed for reproducibility
    seed = args.seed
    if (seed is None):
        seed = random.randint(1, 10000)  # use if you want new results
    print("Random Seed: ", seed)
    random.seed(seed)
    torch.manual_seed(seed)

    # directories
    saveloc = os.path.join(args.saveloc, args.expname)
    modelpath = os.path.join(args.modelpath, args.modelname)
    if (not os.path.exists(saveloc)):
        os.makedirs(saveloc)

    num_batches = 1  # no. of image batches to generate
    batch_size = 200  # no. of images to generate
    nc = 1  # Number of channels in the training images. For color images this is 3
    nz = 62  # Size of z latent vector (i.e. size of generator input)
    ndc = 10  # latent categorical code
    ncc = 3  # continuous categorical code
    ngf = 64
    fixed_exp = False  # when True, use controlled (fixed) latent codes

    # Number of GPUs available. Use 0 for CPU mode.
    ngpu = 1
    if (ngpu > 0):
        torch.cuda.set_device(0)

    # load model weights
    netG = Generator(ndc + ncc + nz, nc, ngf)
    print('********* Generator **********\n', netG)
    netG.load_state_dict(torch.load(modelpath))
    netG.eval()
    if (ngpu > 0):
        # assign to GPU
        netG = netG.cuda()

    print("Starting Testing Loop...")
    if (fixed_exp):
        # Fixed latent codes shared by every generated batch.
        z_rand = torch.randn((batch_size, nz, 1, 1))
        z_disc = torch.LongTensor(np.random.randint(ndc, size=(batch_size, 1)))
        z_cont = torch.rand((batch_size, ncc, 1, 1)) * 2 - 1
        # multiple digits plot — assorted latent-sweep experiments kept for
        # reference:
        # z_cont2 = torch.tensor(np.tile(np.linspace(-1, 1, 20).reshape(1, -1), reps=(10, 1))).view(batch_size, -1, 1, 1)
        # z_cont1 = torch.rand((batch_size, 1, 1, 1)) * 2 - 1
        # pdb.set_trace()
        # z_disc = torch.LongTensor(np.tile(np.arange(0, 10).reshape(-1,1), reps=[1, batch_size // 10]))
        # z_disc = torch.LongTensor(np.repeat(np.arange(0, 10), repeats=batch_size // 10)).reshape(-1,1)
        # z_cont = torch.tensor(np.tile(np.linspace(-1, 1, 7).reshape(1,-1), reps=[10, 1]))
        # z_disc = 3 * torch.ones((batch_size, 1), dtype=torch.long)
        # z_cont2 = torch.linspace(-6, 6, batch_size).view(batch_size, -1, 1, 1)
        # z_cont1 = torch.rand((batch_size, 1, 1, 1)) * 2 - 1
        # z_cont12 = torch.rand((batch_size, 2, 1, 1)) * 2 - 1
        # z_cont3 = torch.linspace(-5, 5, batch_size).view(batch_size, -1, 1, 1)
        # z_cont4 = torch.rand((batch_size, 1, 1, 1)) * 2 - 1
        # z_cont2 = torch.linspace(-4, 4, batch_size).view(batch_size, -1, 1, 1)
        # z_cont1 = torch.rand((batch_size, 1, 1, 1)) * 2 - 1
        # z_cont2 = torch.tensor(np.tile(np.linspace(-2.5, 2.5, 20).reshape(1, -1), reps=(10, 1))).view(batch_size, -1, 1, 1)
        # z_cont3 = torch.rand((batch_size, 1, 1, 1)) * 2 - 1
        # z_cont3 = torch.rand((batch_size, 1, 1, 1)) * 2 - 1
        # torch.linspace(-2, 2, batch_size).view(batch_size, -1, 1, 1)
        # z_cont = torch.cat([
        #     z_cont1.type(torch.float32),
        #     z_cont2.type(torch.float32),
        #     z_cont3.type(torch.float32)],
        #     z_cont4.type(torch.float32)],
        #     dim=1
        # )

    for iters in range(num_batches):
        # fake batch
        if (fixed_exp):
            noise, idx = controlled_noise_sample(
                batch_size, ndc,
                z_random=z_rand,
                # nz=nz,
                z_categorical=z_disc,
                # num_discrete = 1,
                z_continuous=z_cont
                # num_continuous=ncc,
            )
        else:
            noise, idx = noise_sample(1, ndc, ncc, nz, 1)
        if (ngpu > 0):
            noise = noise.cuda()
        # NOTE(review): this forward pass is redundant — the result is
        # recomputed (and used) inside the no_grad block below.
        fake = netG(noise)
        # Check how the generator is doing by saving G's output on fixed_noise
        with torch.no_grad():
            fake = netG(noise).detach()
        vutils.save_image(
            fake,
            os.path.join(saveloc, str(iters) + '.jpg'),
            nrow=20,
            normalize=True,
            range=(0.0, 1.0)
        )
        # Record the latent codes behind each generated image.
        with open(os.path.join(saveloc, 'metadata.txt'), 'a') as f:
            for lineno in range(batch_size):
                if (batch_size == 1):
                    f.write('C1: {:1.0f}, '.format(idx.item()))
                else:
                    f.write('C1: {:1.0f}, '.format(idx[lineno].item()))
                # Continuous codes live after the nz+ndc prefix of `noise` —
                # presumably [z | one-hot categorical | continuous]; confirm.
                for i, item in enumerate(noise[lineno, nz + ndc:].squeeze()):
                    f.write('C' + str(2 + i) + ': {:1.4f}, '.format(item.item()))
                f.write('\n')
        print('Generated file {}'.format(iters))
def __init__(self, opts, nc_in=5, nc_out=3, d_s_args=None, d_t_args=None,
             losses=None):
    """Build the flow/TSM video-inpainting model: generator, optional
    spatial/temporal SN patch discriminators, a frozen VGG16 feature
    extractor, and the configured loss modules.

    Args:
        opts: configuration mapping ('nf', 'norm', 'bias', 'flow_tsm',
            optional 'conv_by', 'conv_type', 'use_refine',
            'use_skip_connection', 'backbone', 'spatial_discriminator',
            'temporal_discriminator').
        nc_in / nc_out: generator input / output channel counts.
        d_s_args / d_t_args: overrides for the spatial / temporal
            discriminator settings. ``None`` (the new default) means "no
            overrides" — the original ``={}`` mutable defaults were shared
            between calls.
        losses: mapping of loss nickname -> weight; entries with weight > 0
            are instantiated. ``None`` means no losses (the original
            crashed on ``None.items()``).
    """
    super().__init__()
    # Treat None as "no overrides" (avoids mutable default arguments).
    d_s_args = d_s_args or {}
    d_t_args = d_t_args or {}

    # Discriminator defaults, overridden by any provided values.
    self.d_t_args = {"nf": 32, "use_sigmoid": True, "norm": 'SN'}
    self.d_t_args.update(d_t_args)
    self.d_s_args = {"nf": 32, "use_sigmoid": True, "norm": 'SN'}
    self.d_s_args.update(d_s_args)

    nf = opts['nf']
    norm = opts['norm']
    use_bias = opts['bias']
    # warning: if 2d convolution is used in generator, settings (e.g. stride,
    # kernal_size, padding) on the temporal axis will be discarded
    self.conv_by = opts.get('conv_by', '3d')
    self.conv_type = opts.get('conv_type', 'gated')
    self.flow_tsm = opts['flow_tsm']
    self.use_refine = opts.get('use_refine', False)
    use_skip_connection = opts.get('use_skip_connection', False)
    self.backbone = opts.get('backbone', 'unet')
    self.opts = opts

    ######################
    # Convolution layers #
    ######################
    self.generator = Generator(
        self.backbone, nc_in, nc_out, nf, use_bias, norm, self.conv_by,
        self.conv_type, use_refine=self.use_refine,
        use_skip_connection=use_skip_connection,
        use_flow_tsm=self.flow_tsm)

    #################
    # Discriminator #
    #################
    # Discriminators default to enabled when the key is absent.
    if opts.get('spatial_discriminator', True):
        self.spatial_discriminator = SNTemporalPatchGANDiscriminator(
            nc_in=5, conv_type='2d', **self.d_s_args)
        self.advloss = AdversarialLoss()
    if opts.get('temporal_discriminator', True):
        self.temporal_discriminator = SNTemporalPatchGANDiscriminator(
            nc_in=5, **self.d_t_args)
        self.advloss = AdversarialLoss()

    #######
    # Vgg #
    #######
    self.vgg = Vgg16(requires_grad=False)

    ########
    # Loss #
    ########
    self.losses = losses
    # Guard: losses may legitimately be None/empty.
    if losses:
        for key, value in losses.items():
            if value > 0:
                setattr(self, key, loss_nickname_to_module[key]())
def train_distributed(config, logger, writer, checkpoint_path):
    """DDP WGAN-GP training loop for the inpainting generator plus local and
    global discriminators: alternates discriminator/generator updates on the
    n_critic schedule, logs scalars, and periodically saves visualizations
    and checkpoints (rank 0 only)."""
    dist.init_process_group(
        backend='nccl',
        # backend='gloo',
        init_method='env://'
    )

    # Find out what GPU on this compute node.
    #
    local_rank = torch.distributed.get_rank()

    # this is the total # of GPUs across all nodes
    # if using 2 nodes with 4 GPUs each, world size is 8
    #
    world_size = torch.distributed.get_world_size()
    print("### global rank of curr node: {} of {}".format(local_rank, world_size))

    # For multiprocessing distributed, DistributedDataParallel constructor
    # should always set the single device scope, otherwise,
    # DistributedDataParallel will use all available devices.
    #
    # print("local_rank: ", local_rank)
    # dist.barrier()
    torch.cuda.set_device(local_rank)

    # Define the trainer
    print("Creating models on device: ", local_rank)
    # NOTE(review): the four locals below are unused in this function —
    # presumably kept for parity with a sibling training routine; confirm
    # before removing.
    input_dim = config['netG']['input_dim']
    cnum = config['netG']['ngf']
    use_cuda = True
    gated = config['netG']['gated']

    # Models
    #
    netG = Generator(config['netG'], use_cuda=True, device=local_rank).cuda()
    netG = torch.nn.parallel.DistributedDataParallel(
        netG,
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=True
    )

    localD = LocalDis(config['netD'], use_cuda=True, device_id=local_rank).cuda()
    localD = torch.nn.parallel.DistributedDataParallel(
        localD,
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=True
    )

    globalD = GlobalDis(config['netD'], use_cuda=True, device_id=local_rank).cuda()
    globalD = torch.nn.parallel.DistributedDataParallel(
        globalD,
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=True
    )

    if local_rank == 0:
        logger.info("\n{}".format(netG))
        logger.info("\n{}".format(localD))
        logger.info("\n{}".format(globalD))

    # Optimizers
    #
    optimizer_g = torch.optim.Adam(
        netG.parameters(),
        lr=config['lr'],
        betas=(config['beta1'], config['beta2'])
    )
    # One optimizer over both discriminators.
    d_params = list(localD.parameters()) + list(globalD.parameters())
    optimizer_d = torch.optim.Adam(
        d_params,
        lr=config['lr'],
        betas=(config['beta1'], config['beta2'])
    )

    # Data
    #
    sampler = None
    train_dataset = Dataset(
        data_path=config['train_data_path'],
        with_subfolder=config['data_with_subfolder'],
        image_shape=config['image_shape'],
        random_crop=config['random_crop']
    )
    # NOTE(review): `rank` is left commented out, so DistributedSampler
    # falls back to the process-group rank — confirm that is intended.
    sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        # num_replicas=torch.cuda.device_count(),
        num_replicas=len(config['gpu_ids']),
        # rank = local_rank
    )
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=config['batch_size'],
        shuffle=(sampler is None),
        num_workers=config['num_workers'],
        pin_memory=True,
        sampler=sampler,
        drop_last=True
    )

    # Get the resume iteration to restart training
    #
    # start_iteration = trainer.resume(config['resume']) if config['resume'] else 1
    start_iteration = 1
    print("\n\nStarting epoch: ", start_iteration)

    iterable_train_loader = iter(train_loader)

    if local_rank == 0:
        time_count = time.time()

    epochs = config['niter'] + 1
    pbar = tqdm(range(start_iteration, epochs), dynamic_ncols=True, smoothing=0.01)

    for iteration in pbar:
        # One "iteration" of the pbar is treated as an epoch for the sampler.
        sampler.set_epoch(iteration)
        try:
            ground_truth = next(iterable_train_loader)
        except StopIteration:
            # Loader exhausted: restart it for the next pass over the data.
            iterable_train_loader = iter(train_loader)
            ground_truth = next(iterable_train_loader)

        # Prepare the inputs
        bboxes = random_bbox(config, batch_size=ground_truth.size(0))
        x, mask = mask_image(ground_truth, bboxes, config)

        # Move to proper device.
        #
        # bboxes = bboxes.cuda(local_rank)
        x = x.cuda(local_rank)
        mask = mask.cuda(local_rank)
        ground_truth = ground_truth.cuda(local_rank)

        ###### Forward pass ######
        # Generator is updated only every n_critic-th iteration.
        compute_g_loss = iteration % config['n_critic'] == 0
        # losses, inpainted_result, offset_flow = forward(config, x, bboxes, mask, ground_truth,
        #                                                 localD=localD, globalD=globalD,
        #                                                 coarse_gen=coarse_generator, fine_gen=fine_generator,
        #                                                 local_rank=local_rank, compute_loss_g=compute_g_loss)
        losses, inpainted_result, offset_flow = forward(config, x, bboxes, mask, ground_truth,
                                                        netG=netG, localD=localD, globalD=globalD,
                                                        local_rank=local_rank, compute_loss_g=compute_g_loss)

        # Scalars from different devices are gathered into vectors
        #
        for k in losses.keys():
            if not losses[k].dim() == 0:
                losses[k] = torch.mean(losses[k])

        ###### Backward pass ######
        # Update D
        if not compute_g_loss:
            optimizer_d.zero_grad()
            losses['d'] = losses['wgan_d'] + losses['wgan_gp'] * config['wgan_gp_lambda']
            losses['d'].backward()
            optimizer_d.step()

        # Update G
        if compute_g_loss:
            optimizer_g.zero_grad()
            losses['g'] = losses['ae'] * config['ae_loss_alpha']
            losses['g'] += losses['l1'] * config['l1_loss_alpha']
            losses['g'] += losses['wgan_g'] * config['gan_loss_alpha']
            losses['g'].backward()
            optimizer_g.step()

        # Set tqdm description
        #
        if local_rank == 0:
            log_losses = ['l1', 'ae', 'wgan_g', 'wgan_d', 'wgan_gp', 'g', 'd']
            message = ' '
            for k in log_losses:
                v = losses.get(k, 0.)
                writer.add_scalar(k, v, iteration)
                message += '%s: %.4f ' % (k, v)
            pbar.set_description(
                (
                    f" {message}"
                )
            )

        if local_rank == 0:
            # Periodically save a grid of input / result / flow images.
            if iteration % (config['viz_iter']) == 0:
                viz_max_out = config['viz_max_out']
                if x.size(0) > viz_max_out:
                    viz_images = torch.stack([x[:viz_max_out],
                                              inpainted_result[:viz_max_out],
                                              offset_flow[:viz_max_out]],
                                             dim=1)
                else:
                    viz_images = torch.stack([x, inpainted_result, offset_flow],
                                             dim=1)
                viz_images = viz_images.view(-1, *list(x.size())[1:])
                vutils.save_image(viz_images,
                                  '%s/niter_%08d.png' % (checkpoint_path, iteration),
                                  nrow=3 * 4,
                                  normalize=True)

            # Save the model
            if iteration % config['snapshot_save_iter'] == 0:
                save_model(
                    netG, globalD, localD,
                    optimizer_g, optimizer_d, checkpoint_path, iteration
                )
def train_distributed_v2(config, logger, writer, checkpoint_path):
    """DDP training loop v2: gated inpainting generator vs. a single SN patch
    discriminator trained with the GAN hinge loss plus coarse/fine L1
    reconstruction; rank 0 handles logging, visualization, and
    checkpointing."""
    dist.init_process_group(
        backend='nccl',
        # backend='gloo',
        init_method='env://'
    )

    # Find out what GPU on this compute node.
    #
    local_rank = torch.distributed.get_rank()

    # this is the total # of GPUs across all nodes
    # if using 2 nodes with 4 GPUs each, world size is 8
    #
    world_size = torch.distributed.get_world_size()
    print("### global rank of curr node: {} of {}".format(local_rank, world_size))

    # For multiprocessing distributed, DistributedDataParallel constructor
    # should always set the single device scope, otherwise,
    # DistributedDataParallel will use all available devices.
    #
    # print("local_rank: ", local_rank)
    # dist.barrier()
    torch.cuda.set_device(local_rank)

    print("Creating models on device: ", local_rank)

    # Various definitions for models, etc.
    # NOTE(review): input_dim/cnum/use_cuda/gated are unused below —
    # presumably kept for parity with train_distributed; confirm.
    input_dim = config['netG']['input_dim']
    cnum = config['netG']['ngf']
    use_cuda = True
    gated = config['netG']['gated']
    batch_size = config['batch_size']

    # L1 loss used on outputs from course and fine networks in generator.
    #
    loss_l1 = nn.L1Loss(reduction='mean').cuda()

    # Models
    #
    netG = Generator(config['netG'], use_cuda=True, device=local_rank).cuda()
    netG = torch.nn.parallel.DistributedDataParallel(
        netG,
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=True
    )

    patchD = PatchDis(config['netD'], use_cuda=True, device=local_rank).cuda()
    patchD = torch.nn.parallel.DistributedDataParallel(
        patchD,
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=True
    )

    if local_rank == 0:
        logger.info("\n{}".format(netG))
        logger.info("\n{}".format(patchD))

    # Optimizers
    #
    optimizer_g = torch.optim.Adam(
        netG.parameters(),
        lr=config['lr'],
        betas=(config['beta1'], config['beta2'])
    )
    optimizer_d = torch.optim.Adam(
        patchD.parameters(),
        lr=config['lr'],
        betas=(config['beta1'], config['beta2'])
    )

    # NOTE(review): duplicate of the model logging above.
    if local_rank == 0:
        logger.info("\n{}".format(netG))
        logger.info("\n{}".format(patchD))

    # Data
    #
    sampler = None
    train_dataset = Dataset(
        data_path=config['train_data_path'],
        with_subfolder=config['data_with_subfolder'],
        image_shape=config['image_shape'],
        random_crop=config['random_crop']
    )
    sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        # num_replicas=torch.cuda.device_count(),
        num_replicas=len(config['gpu_ids']),
        # rank = local_rank
    )
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=(sampler is None),
        num_workers=config['num_workers'],
        pin_memory=True,
        sampler=sampler,
        drop_last=True
    )

    # Running scalar values (plain floats) reported in the tqdm bar.
    losses = {
        'coarse': 0.0,
        'fine': 0.0,
        'ae': 0.0,
        'g_loss': 0.0,
        'd_loss': 0.0
    }

    # Get the resume iteration to restart training
    ### TODO:
    ###  - allow resuming from checkpoint.
    ###
    # start_iteration = trainer.resume(config['resume']) if config['resume'] else 1
    start_iteration = 1
    print("\n\nStarting epoch: ", start_iteration)

    iterable_train_loader = iter(train_loader)

    if local_rank == 0:
        time_count = time.time()

    epochs = config['niter'] + 1
    pbar = tqdm(range(start_iteration, epochs), dynamic_ncols=True, smoothing=0.01)

    for iteration in pbar:
        sampler.set_epoch(iteration)
        try:
            ground_truth = next(iterable_train_loader)
        except StopIteration:
            # Loader exhausted: restart it for the next pass over the data.
            iterable_train_loader = iter(train_loader)
            ground_truth = next(iterable_train_loader)

        ground_truth = ground_truth.cuda(local_rank)
        # Random free-form mask for this batch.
        mask_ff = random_ff_mask(config['random_ff_settings'],
                                 batch_size=batch_size).cuda(local_rank)

        # netG.zero_grad()
        imgs_incomplete = ground_truth * (1. - mask_ff)  # just background
        x1, x2, offset_flow = netG(imgs_incomplete, mask_ff)
        # Keep known pixels; fill holes from the fine output x2.
        imgs_complete = (x2 * mask_ff) + imgs_incomplete

        # Losses
        #
        coarse_loss = config['l1_loss_alpha'] * loss_l1(ground_truth, x1)
        fine_loss = config['l1_loss_alpha'] * loss_l1(ground_truth, x2)
        ae_loss = coarse_loss + fine_loss

        losses['coarse'] = coarse_loss.item()
        losses['fine'] = fine_loss.item()
        losses['ae'] = ae_loss.item()

        # Discriminate
        #
        batch_pos_neg = torch.cat([ground_truth, imgs_complete], dim=0)  # [N3HW]

        # Add in mask and repeat for ground truth and generated completion.
        # Will be split later to produce "real" and "fake" patch features in discriminator
        # for use with hinge loss.
        #
        batch_pos_neg = torch.cat([batch_pos_neg, mask_ff.repeat(2, 1, 1, 1)], dim=1)  # [N4HW]

        # patchD.zero_grad()
        pos_neg = patchD(batch_pos_neg)

        # Losses
        #
        pos, neg = torch.chunk(pos_neg, 2, dim=0)
        g_loss, d_loss = gan_hinge_loss(pos, neg)
        g_loss += ae_loss

        losses['g_loss'] = g_loss.item()
        losses['d_loss'] = d_loss.item()

        # Generator is updated only every n_critic-th iteration.
        compute_g_loss = iteration % config['n_critic'] == 0

        #
        # Optimize
        #

        #
        if not compute_g_loss:
            optimizer_d.zero_grad()
            d_loss.backward(retain_graph=True)
            optimizer_d.step()

            # Re-run the discriminator after its update so later logging
            # reflects the post-step losses.
            pos_neg = patchD(batch_pos_neg)
            pos, neg = torch.chunk(pos_neg, 2, dim=0)
            g_loss, d_loss = gan_hinge_loss(pos, neg)
            g_loss += ae_loss

        #
        if compute_g_loss:
            optimizer_g.zero_grad()
            g_loss.backward()
            optimizer_g.step()

        # print("ae_loss: ", ae_loss, " g_loss: ", g_loss, " d_loss: ", d_loss)

        # Set tqdm description
        #
        if local_rank == 0:
            message = ' '
            for k in losses:
                # v = losses.get(k, 0.)
                v = losses[k]
                # writer.add_scalar(k, v, iteration)
                message += '%s: %.4f ' % (k, v)
            pbar.set_description(
                (
                    f" {message}"
                )
            )

        # Save output from current iteration.
        #
        if local_rank == 0:
            if iteration % (config['viz_iter']) == 0:
                viz_max_out = config['viz_max_out']
                if ground_truth.size(0) > viz_max_out:
                    viz_images = torch.stack(
                        [ground_truth[:viz_max_out], imgs_incomplete[:viz_max_out],
                         imgs_complete[:viz_max_out], offset_flow[:viz_max_out]],
                        dim=1
                    )
                else:
                    viz_images = torch.stack(
                        [ground_truth, imgs_incomplete, imgs_complete, offset_flow],
                        dim=1
                    )
                viz_images = viz_images.view(-1, *list(ground_truth.size())[1:])
                vutils.save_image(viz_images,
                                  '%s/niter_%08d.png' % (checkpoint_path, iteration),
                                  nrow=2 * 4,
                                  normalize=True)

            # Save the model
            if iteration % config['snapshot_save_iter'] == 0:
                save_model_v2(netG, patchD, optimizer_g, optimizer_d,
                              checkpoint_path, iteration)
cudnn.benchmark = True # print("Arguments: {}".format(args)) print("Use cuda: {}, use gpu_ids: {}".format(cuda, device_ids)) # Set random seed if args.seed is None: args.seed = random.randint(1, 10000) print("Random seed: {}".format(args.seed)) random.seed(args.seed) torch.manual_seed(args.seed) if cuda: torch.cuda.manual_seed_all(args.seed) # print("Configuration: {}".format(config)) # Define the trainer netG = Generator(config['netG'], cuda, device_ids) # Resume weight # if cuda: # netG.cuda() last_model_name = get_model_list(args.checkpoint_path, "gen", iteration=args.iter) last_model_name = args.which_model # last_model_name = args.which_model # print("loading model from here --------------> {}".format(last_model_name)) # if not cuda: # netG.load_state_dict(torch.load(last_model_name, map_location='cpu')) # else:
def main():
    """Inpaint detected logo regions in the images/videos under ``args.image``.

    For each frame, bounding boxes are read from a per-file ``.txt`` in
    ``args.output`` (detector output); a mask is built from the boxes for the
    current frame, the frame is chunked to the model's input size, inpainted by
    the generator, reassembled, and written back to ``args.output`` as an image
    or a re-encoded video.
    """
    args = parser.parse_args()
    config = get_config(args.config)

    # CUDA configuration
    cuda = config['cuda']
    device_ids = config['gpu_ids']
    if cuda:
        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(i) for i in device_ids)
        device_ids = list(range(len(device_ids)))
        config['gpu_ids'] = device_ids
        cudnn.benchmark = True

    # Set random seed
    if args.seed is None:
        args.seed = random.randint(1, 10000)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if cuda:
        torch.cuda.manual_seed_all(args.seed)

    t0 = time.time()
    dataset = datasets.LoadImages(args.image)
    chunker = ImageChunker(config['image_shape'][0],
                           config['image_shape'][1],
                           args.overlap)

    try:  # for unexpected error logging
        with torch.no_grad():  # enter no grad context
            # Set checkpoint path
            if not args.checkpoint_path:
                checkpoint_path = os.path.join(
                    'checkpoints', config['dataset_name'],
                    config['mask_type'] + '_' + config['expname'])
            else:
                checkpoint_path = args.checkpoint_path
            last_model_name = get_model_list(checkpoint_path, "gen",
                                             iteration=args.iter)

            # FIX: build and load the generator ONCE. The original re-created
            # the Generator and re-read the checkpoint for every frame, and
            # re-wrapped netG in DataParallel for every chunk, nesting
            # DataParallel wrappers.
            netG = Generator(config['netG'], cuda, device_ids)
            netG.load_state_dict(torch.load(last_model_name))
            model_iteration = int(last_model_name[-11:-3])
            if cuda:
                netG = nn.parallel.DataParallel(netG, device_ids=device_ids)

            prev_fname = ''
            vid_writer = None
            frame = 0
            for fpath, img_ori, vid_cap in dataset:
                imgs, masks = [], []
                if prev_fname == fpath:
                    frame += 1  # increase frame number if still on the same file
                else:
                    frame = 0  # start frame number

                _, img_h, img_w = img_ori.shape  # loader yields CHW arrays
                txtfile = pathlib.Path(fpath).with_suffix('.txt')
                # Load mask txt file
                txtfile = os.path.join(args.output, str(txtfile).split('/')[-1])
                # FIX: default to "no detections" when the txt file is absent
                # (the original referenced bboxes/bframes unconditionally and
                # raised NameError on the first frame without a txt file).
                bboxes, bframes = [], []
                if os.path.exists(txtfile):
                    bboxes, bframes = load_bbox_txt(txtfile, img_w, img_h)
                    assert len(bboxes) == len(bframes)
                idx = [ii for ii, val in enumerate(bframes) if val == frame]
                bndbxs = [bboxes[ii] for ii in idx]

                img_ori = np.moveaxis(img_ori, 0, -1)  # CHW -> HWC
                if len(bndbxs) > 0:  # if any logo detected
                    mask_ori = create_mask(bndbxs, img_w, img_h)
                    chunked_images = chunker.dimension_preprocess(
                        np.array(deepcopy(img_ori)))
                    chunked_masks = chunker.dimension_preprocess(
                        np.array(deepcopy(mask_ori)))
                    for (x, msk) in zip(chunked_images, chunked_masks):
                        x = transforms.ToTensor()(x)
                        mask = transforms.ToTensor()(msk)[0].unsqueeze(dim=0)
                        # x = normalize(x)
                        x = x * (1. - mask)  # zero out the region to inpaint
                        imgs.append(x.unsqueeze(dim=0))
                        masks.append(mask.unsqueeze(dim=0))

                    pred_imgs = []
                    for (x, mask) in zip(imgs, masks):
                        if torch.max(mask) == 1:  # chunk overlaps a detection
                            if cuda:
                                x = x.cuda()
                                mask = mask.cuda()
                            # Inference
                            x1, x2, offset_flow = netG(x, mask)
                            inpainted_result = x2 * mask + x * (1. - mask)
                            inpainted_result = inpainted_result.squeeze(
                                dim=0).permute(1, 2, 0).cpu()
                            pred_imgs.append(inpainted_result.numpy())
                        else:
                            # untouched chunk: pass it through unchanged
                            pred_imgs.append(
                                x.squeeze(dim=0).permute(1, 2, 0).numpy())
                    pred_imgs = np.asarray(pred_imgs, dtype=np.float32)
                    reconstructed_image = chunker.dimension_postprocess(
                        pred_imgs, np.array(img_ori))
                    # BGR to RGB, and rescaling
                    reconstructed_image = np.uint8(
                        reconstructed_image[:, :, ::-1] * 255)
                else:  # no logo detected
                    reconstructed_image = img_ori[:, :, ::-1]

                # Save results (image with detections)
                outname = fpath.split('/')[-1]
                outname = outname.split('.')[0] + '-inp.' + outname.split('.')[-1]
                outpath = os.path.join(args.output, outname)
                if dataset.mode == 'images':
                    cv2.imwrite(outpath, reconstructed_image)
                    print("Saved the inpainted image to {}".format(outpath))
                else:
                    if fpath != prev_fname:  # new video
                        if isinstance(vid_writer, cv2.VideoWriter):
                            vid_writer.release()  # release previous video writer
                            print("Saved the inpainted video to {}".format(outpath))
                        fps = vid_cap.get(cv2.CAP_PROP_FPS)
                        w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                        h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                        vid_writer = cv2.VideoWriter(
                            outpath, cv2.VideoWriter_fourcc(*args.fourcc),
                            fps, (w, h))
                    vid_writer.write(reconstructed_image)
                prev_fname = fpath
        # exit no grad context
    except Exception as err:  # for unexpected error logging
        print("Error: {}".format(err))
    print('Inpainting: (%.3fs)' % (time.time() - t0))
def generate(img, img_mask_path, model_path):
    """Inpaint ``img`` with the trained generator and return a numpy image.

    When ``img_mask_path`` names an image file it is used as the hole mask;
    when it is falsy a random bbox mask is sampled over the ground truth.
    Generator weights are resumed from ``model_path`` (or the default
    checkpoint directory derived from the config).
    """
    with torch.no_grad():  # enter no grad context
        crop_shape = config['image_shape'][:-1]
        if img_mask_path and is_image_file(img_mask_path):
            # Test a single masked image with a given mask
            image = Image.fromarray(img)
            hole = default_loader(img_mask_path)
            image = transforms.CenterCrop(crop_shape)(
                transforms.Resize(crop_shape)(image))
            hole = transforms.CenterCrop(crop_shape)(
                transforms.Resize(crop_shape)(hole))
            mask = transforms.ToTensor()(hole)[0].unsqueeze(dim=0)
            x = normalize(transforms.ToTensor()(image))
            x = (x * (1. - mask)).unsqueeze(dim=0)
            mask = mask.unsqueeze(dim=0)
        elif img_mask_path:
            raise TypeError("{} is not an image file.".format(img_mask_path))
        else:
            # Test a single ground-truth image with a random mask
            ground_truth = transforms.Resize(crop_shape)(img)
            ground_truth = transforms.CenterCrop(crop_shape)(ground_truth)
            ground_truth = normalize(
                transforms.ToTensor()(ground_truth)).unsqueeze(dim=0)
            bboxes = random_bbox(config, batch_size=ground_truth.size(0))
            x, mask = mask_image(ground_truth, bboxes, config)

        # Set checkpoint path
        checkpoint_path = model_path if model_path else os.path.join(
            'checkpoints', config['dataset_name'],
            config['mask_type'] + '_' + config['expname'])

        # Define the trainer and resume weights
        netG = Generator(config['netG'], cuda, device_ids)
        last_model_name = get_model_list(checkpoint_path, "gen", iteration=0)
        # map_location=None is the torch.load default, so this collapses the
        # original cuda / cpu branches into one call.
        netG.load_state_dict(torch.load(
            last_model_name, map_location=None if cuda else 'cpu'))
        model_iteration = int(last_model_name[-11:-3])
        print("Resume from {} at iteration {}".format(checkpoint_path,
                                                      model_iteration))

        if cuda:
            netG = nn.parallel.DataParallel(netG, device_ids=device_ids)
            x = x.cuda()
            mask = mask.cuda()

        # Inference
        x1, x2, offset_flow = netG(x, mask)
        inpainted_result = x2 * mask + x * (1. - mask)
        return from_torch_img_to_numpy(inpainted_result, 'output.png',
                                       padding=0, normalize=True)
def main():
    """Train the video GAN (generator vs. video & image discriminators).

    Builds models and Adam optimizers from CLI args, trains for
    ``args.max_epoch`` epochs on the UVA dataset, periodically visualizes
    fixed latent samples to tensorboard and snapshots model weights.
    """
    args = cfg.parse_args()
    torch.cuda.manual_seed(args.random_seed)
    print(args)

    # create logging folder
    # FIX: the original only ran os.makedirs when BOTH paths were missing, so
    # if exactly one already existed the other was never created and later
    # SummaryWriter / torch.save calls failed. exist_ok makes each idempotent.
    log_path = os.path.join(args.save_path, args.exp_name + '/log')
    model_path = os.path.join(args.save_path, args.exp_name + '/models')
    os.makedirs(log_path, exist_ok=True)
    os.makedirs(model_path, exist_ok=True)
    writer = SummaryWriter(log_path)  # tensorboard

    # load model
    device = torch.device("cuda:0")
    G = Generator(args.d_za, args.d_zm, args.ch_g, args.g_mode,
                  args.use_attention).to(device)
    VD = VideoDiscriminator(args.ch_d).to(device)
    ID = ImageDiscriminator(args.ch_d).to(device)
    G = nn.DataParallel(G)
    VD = nn.DataParallel(VD)
    ID = nn.DataParallel(ID)

    # optimizer
    optimizer_G = torch.optim.Adam(G.parameters(), args.g_lr, (0.5, 0.999))
    optimizer_VD = torch.optim.Adam(VD.parameters(), args.d_lr, (0.5, 0.999))
    optimizer_ID = torch.optim.Adam(ID.parameters(), args.d_lr, (0.5, 0.999))

    # loss
    criterion = nn.BCEWithLogitsLoss().to(device)

    # prepare dataset
    print('==> preparing dataset')
    transform = torchvision.transforms.Compose([
        transforms_vid.ClipResize((args.img_size, args.img_size)),
        transforms_vid.ClipToTensor(),
        transforms_vid.ClipNormalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])
    dataset = UVA(args.data_path, transform=transform)
    dataloader = torch.utils.data.DataLoader(dataset=dataset,
                                             batch_size=args.batch_size,
                                             num_workers=args.num_workers,
                                             shuffle=True,
                                             pin_memory=True,
                                             drop_last=True)

    # fixed latents for validation visualizations
    fixed_za = torch.randn(args.n_za_test, args.d_za, 1, 1, 1).to(device)
    fixed_zm = torch.randn(args.n_zm_test, args.d_zm, 1, 1, 1).to(device)

    print('==> start training')
    for epoch in range(args.max_epoch):
        train(args, epoch, G, VD, ID, optimizer_G, optimizer_VD, optimizer_ID,
              criterion, dataloader, writer, device)
        if epoch % args.val_freq == 0:
            vis(epoch, G, fixed_za, fixed_zm, device, writer)
        if epoch % args.save_freq == 0:
            torch.save(G.state_dict(),
                       os.path.join(model_path, 'G_%d.pth' % (epoch)))
            torch.save(VD.state_dict(),
                       os.path.join(model_path, 'VD_%d.pth' % (epoch)))
            torch.save(ID.state_dict(),
                       os.path.join(model_path, 'ID_%d.pth' % (epoch)))
    return