def main():
    """Training entry point.

    Builds the config, (optionally) distributed setup, data loaders,
    networks, optimizers, schedulers and trainer, resumes from
    ``args.checkpoint`` and then alternates ``cfg.trainer.dis_step``
    discriminator updates with ``cfg.trainer.gen_step`` generator updates
    per batch. Stops after ``cfg.max_epoch`` epochs or as soon as
    ``cfg.max_iter`` iterations have been run.
    """
    args = parse_args()
    set_affinity(args.local_rank)
    set_random_seed(args.seed, by_rank=True)
    cfg = Config(args.config)

    # Distributed data parallel is only set up when not in single-GPU mode.
    distributed = not args.single_gpu
    if distributed:
        cfg.local_rank = args.local_rank
        init_dist(cfg.local_rank)

    # A command-line worker count takes precedence over the config value.
    if args.num_workers is not None:
        cfg.data.num_workers = args.num_workers

    # Log directory for training results.
    cfg.date_uid, cfg.logdir = init_logging(args.config, args.logdir)
    make_logging_dir(cfg.logdir)

    init_cudnn(cfg.cudnn.deterministic, cfg.cudnn.benchmark)

    # Data, networks, optimizers, schedulers and the trainer that owns them.
    train_loader, val_loader = get_train_and_val_dataloader(cfg)
    net_G, net_D, opt_G, opt_D, sch_G, sch_D = \
        get_model_optimizer_and_scheduler(cfg, seed=args.seed)
    trainer = get_trainer(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
                          train_loader, val_loader)
    start_epoch, iteration = trainer.load_checkpoint(
        cfg, args.checkpoint, resume=args.resume)

    for epoch in range(start_epoch, cfg.max_epoch):
        print('Epoch {} ...'.format(epoch))
        if distributed:
            # Tell the sampler which epoch this is (presumably a
            # distributed sampler that varies its shuffling per epoch).
            train_loader.sampler.set_epoch(epoch)
        trainer.start_of_epoch(epoch)

        for batch in train_loader:
            batch = trainer.start_of_iteration(batch, iteration)
            for _ in range(cfg.trainer.dis_step):
                trainer.dis_update(batch)
            for _ in range(cfg.trainer.gen_step):
                trainer.gen_update(batch)

            iteration += 1
            trainer.end_of_iteration(batch, epoch, iteration)
            if iteration >= cfg.max_iter:
                print('Done with training!!!')
                return

        # end_of_epoch is handed the already-incremented epoch counter.
        trainer.end_of_epoch(batch, epoch + 1, iteration)

    print('Done with training!!!')
def main():
    """Inference entry point.

    Builds the test data loader, networks and trainer, loads the weights
    from ``args.checkpoint`` and runs ``trainer.test`` with output written
    to ``args.output_dir``.
    """
    args = parse_args()
    set_affinity(args.local_rank)
    set_random_seed(args.seed, by_rank=True)
    cfg = Config(args.config)
    # Configs without an ``inference_args`` entry get an explicit None so it
    # can be passed to trainer.test() unconditionally.
    if not hasattr(cfg, 'inference_args'):
        cfg.inference_args = None

    if not args.single_gpu:
        # Distributed data parallel only outside single-GPU mode.
        cfg.local_rank = args.local_rank
        init_dist(cfg.local_rank)

    # A command-line worker count takes precedence over the config value.
    if args.num_workers is not None:
        cfg.data.num_workers = args.num_workers

    cfg.date_uid, cfg.logdir = init_logging(args.config, args.logdir)
    init_cudnn(cfg.cudnn.deterministic, cfg.cudnn.benchmark)

    test_loader = get_test_dataloader(cfg)
    net_G, net_D, opt_G, opt_D, sch_G, sch_D = \
        get_model_optimizer_and_scheduler(cfg, seed=args.seed)
    # No training loader is passed; only the test loader is needed here.
    trainer = get_trainer(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
                          None, test_loader)

    # NOTE: automatic download of pretrained weights (when no checkpoint is
    # given) is disabled in this version; args.checkpoint is used as-is.
    trainer.load_checkpoint(cfg, args.checkpoint)

    # Counters are set to -1 before testing (presumably a "not training"
    # sentinel understood by the trainer).
    trainer.current_epoch = -1
    trainer.current_iteration = -1
    trainer.test(test_loader, args.output_dir, cfg.inference_args)
def main():
    """Evaluation entry point.

    Builds the data loaders, networks and trainer, then, for every ``*.pt``
    file in ``args.checkpoint_logdir`` (in sorted filename order), loads the
    checkpoint and writes its metrics.
    """
    args = parse_args()
    set_affinity(args.local_rank)
    set_random_seed(args.seed, by_rank=True)
    cfg = Config(args.config)

    # Distributed data parallel only outside single-GPU mode.
    if not args.single_gpu:
        cfg.local_rank = args.local_rank
        init_dist(cfg.local_rank)

    # A command-line worker count takes precedence over the config value.
    if args.num_workers is not None:
        cfg.data.num_workers = args.num_workers

    cfg.date_uid, cfg.logdir = init_logging(args.config, args.logdir)
    make_logging_dir(cfg.logdir)

    init_cudnn(cfg.cudnn.deterministic, cfg.cudnn.benchmark)

    train_loader, val_loader = get_train_and_val_dataloader(cfg)
    net_G, net_D, opt_G, opt_D, sch_G, sch_D = \
        get_model_optimizer_and_scheduler(cfg, seed=args.seed)
    trainer = get_trainer(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
                          train_loader, val_loader)

    # glob() order is arbitrary, so sort to evaluate in filename order.
    pattern = '{}/*.pt'.format(args.checkpoint_logdir)
    for ckpt_path in sorted(glob.glob(pattern)):
        epoch, iteration = trainer.load_checkpoint(cfg, ckpt_path,
                                                   resume=True)
        trainer.current_epoch = epoch
        trainer.current_iteration = iteration
        trainer.write_metrics()

    print('Done with evaluation!!!')
def _init_single_image_model(self, load_weights=True): r"""Load single image model, if any.""" if self.single_image_model is None and \ hasattr(self.gen_cfg, 'single_image_model'): print('Using single image model...') single_image_cfg = Config(self.gen_cfg.single_image_model.config) # Init model. net_G, net_D, opt_G, opt_D, sch_G, sch_D = \ get_model_optimizer_and_scheduler(single_image_cfg) # Init trainer and load checkpoint. trainer = get_trainer(single_image_cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D, None, None) if load_weights: print('Loading single image model checkpoint') single_image_ckpt = self.gen_cfg.single_image_model.checkpoint trainer.load_checkpoint(single_image_cfg, single_image_ckpt) print('Loaded single image model checkpoint') self.single_image_model = net_G.module self.single_image_model_z = None
# encoder x_en2 = self.layer1(x) x_en2 = torch.cat([x_en2, x_d02], dim=1) x_en4 = self.layer2(x_en2) x_en4 = torch.cat([x_en4, x_d04], dim=1) x_en8 = self.layer3(x_en4) x_en8 = torch.cat([x_en8, x_d08], dim=1) x_en16 = self.layer4(x_en8) x_en16 = torch.cat([x_en16, x_d16], dim=1) # decoder x_de8 = self.layer5(x_en16, x_en8) # x_de8 = torch.cat([x_de8, x_en8], dim=1) x_de4 = self.layer6(x_de8, x_en4) # x_de4 = torch.cat([x_de4, x_en4], dim=1) x_de2 = self.layer7(x_de4, x_en2) # x_de2 = torch.cat([x_de2, x_en2], dim=1) out = self.layer8(x_de2, xi_yj) out = self.outlayer(out) return out if __name__ == "__main__": from imaginaire.config import Config cfg = Config("/configs/projects/cagan/LipMPV/base_dis2_gen1.yaml") gen = Generator(cfg.gen, cfg.data) batch = torch.randn((8, 9, 256, 192)) y = gen(batch) print(y.shape)
""" batch_size = images.size(0) features = self.model(images) # outputs = self.classifier(features.view(batch_size, -1)) # return outputs, features, images return features class Discriminator(nn.Module): def __init__(self, gen_cfg, data_cfg): super(Discriminator, self).__init__() self.dis = ResDiscriminator(image_channels=6) self.sigmoid = nn.Sigmoid() def forward(self, x): out = self.dis(x) out = self.sigmoid(out) return out if __name__ == "__main__": from imaginaire.config import Config cfg = Config( "D:/workspace/develop/imaginaire/configs/projects/cagan/LipMPV/base.yaml" ) dis = Discriminator(cfg.dis, cfg.data) batch = torch.randn((8, 6, 256, 192)) features, images = dis(batch) print(features.shape)
def main():
    r"""
    Build lmdb for training/testing.

    Usage:
        python scripts/build_lmdb.py \
            --config configs/data_image.yaml \
            --data_root /mnt/bigdata01/datasets/test_image \
            --output_root /mnt/bigdata01/datasets/test_image/lmdb_0/ \
            --overwrite

    One LMDB is written per entry of ``cfg.data.data_types`` under
    ``args.output_root``, followed by ``all_filenames.json`` and
    ``metadata.json``. Returns early (None) when ``args.output_root``
    already exists and ``--overwrite`` was not given.

    Returns:
        ``(all_filenames, extensions)`` only when ``args.output_root`` is
        falsy, otherwise None. NOTE(review): for an empty ``output_root``
        the ``os.makedirs`` call below raises first, so that branch does
        not appear reachable -- confirm intent.
    """
    args = parse_args()
    cfg = Config(args.config)

    # Check if output file already exists.
    if os.path.exists(args.output_root):
        if args.overwrite:
            print('Deleting existing output LMDB.')
            shutil.rmtree(args.output_root)
        else:
            print('Output root LMDB already exists. Use --overwrite. ' +
                  'Exiting...')
            return

    all_filenames, extensions = \
        create_metadata(data_root=args.data_root, cfg=cfg,
                        paired=args.paired, input_list=args.input_list)
    required_data_types = cfg.data.data_types

    # Build LMDB.
    os.makedirs(args.output_root)
    for data_type in required_data_types:
        data_size = 0
        print('Data type:', data_type)
        filepaths, keys = [], []
        print('>> Building file list.')

        # Get appropriate list of files.
        if args.paired:
            filenames = all_filenames
        else:
            filenames = all_filenames[data_type]

        for sequence in tqdm(filenames):
            # Iterate over a snapshot: missing files may be removed from
            # filenames[sequence] inside the loop. A shallow copy suffices.
            for filename in list(filenames[sequence]):
                filepath = construct_file_path(args.data_root, data_type,
                                               sequence, filename,
                                               extensions[data_type])
                key = '%s/%s' % (sequence, filename)
                filesize = check_and_add(filepath, key, filepaths, keys,
                                         remove_missing=args.remove_missing)

                # -1 marks a missing file; it must not be added to the
                # reserved size.
                if filesize == -1:
                    # Remove file from list, if missing.
                    if args.paired and args.remove_missing:
                        print('Removing %s from list' % (filename))
                        filenames[sequence].remove(filename)
                    continue
                data_size += filesize

        # Remove empty sequences (iterate over a snapshot of the keys).
        if args.paired and args.remove_missing:
            for sequence in list(all_filenames):
                if not all_filenames[sequence]:
                    all_filenames.pop(sequence)

        # Allocate size: (1 + metadata_factor) * total bytes, at least 1 GB.
        # The floor is an int so the map size is never a float (max() with
        # 1e9 used to return a float for small datasets).
        data_size = max(int((1 + args.metadata_factor) * data_size), 10 ** 9)
        print('Reserved size: %s, %dGB' % (data_type, data_size // 10 ** 9))

        # Write LMDB to file.
        output_filepath = os.path.join(args.output_root, data_type)
        build_lmdb(filepaths, keys, output_filepath, data_size, args.large)

    # Output list of all filenames.
    if args.output_root:
        with open(os.path.join(args.output_root, 'all_filenames.json'),
                  'w') as fout:
            json.dump(all_filenames, fout, indent=4)

        # Output metadata.
        with open(os.path.join(args.output_root, 'metadata.json'),
                  'w') as fout:
            json.dump(extensions, fout, indent=4)
    else:
        return all_filenames, extensions
def main():
    r"""Build one LMDB per data type from the files under ``args.data_root``.

    Writes ``<output_root>/<data_type>`` LMDBs plus ``all_filenames.json``
    and ``metadata.json``. Returns early (None) when ``args.output_root``
    already exists and ``--overwrite`` was not given.

    Returns:
        ``(all_filenames, extensions)`` only when ``args.output_root`` is
        falsy, otherwise None. NOTE(review): for an empty ``output_root``
        the ``os.makedirs`` call below raises first, so that branch does
        not appear reachable -- confirm intent.
    """
    args = parse_args()
    cfg = Config(args.config)

    # Check if output file already exists.
    if os.path.exists(args.output_root):
        if args.overwrite:
            print('Deleting existing output LMDB.')
            shutil.rmtree(args.output_root)
        else:
            print('Output root LMDB already exists. Use --overwrite. ' +
                  'Exiting...')
            return

    # For unpaired data, all_filenames maps data type -> sequence -> files:
    #   "images_content" -> {"class01": ["image01.jpg", ...], ...}
    #   "images_style"   -> {"class01": ["image01.jpg", ...], ...}
    # For paired data it maps sequence -> files directly.
    all_filenames, extensions = \
        create_metadata(data_root=args.data_root, cfg=cfg,
                        paired=args.paired, input_list=args.input_list)
    required_data_types = cfg.data.data_types

    # Build LMDB.
    os.makedirs(args.output_root)
    # e.g. required_data_types = ['images_content', 'images_style']
    for data_type in required_data_types:
        data_size = 0
        print('Data type:', data_type)
        filepaths, keys = [], []
        print('>> Building file list.')

        # Get appropriate list of files.
        if args.paired:
            filenames = all_filenames
        else:
            filenames = all_filenames[data_type]

        # Collect (filepath, key) pairs for every file of every sequence.
        for sequence in tqdm(filenames):
            # Iterate over a snapshot: missing files may be removed from
            # filenames[sequence] inside the loop. A shallow copy suffices.
            for filename in list(filenames[sequence]):
                filepath = construct_file_path(args.data_root, data_type,
                                               sequence, filename,
                                               extensions[data_type])
                # NOTE(review): os.path.join uses the platform separator,
                # so on Windows keys contain '\\' rather than the '/' of the
                # former '%s/%s' form -- confirm readers build keys alike.
                key = os.path.join(sequence, filename)
                filesize = check_and_add(filepath, key, filepaths, keys,
                                         remove_missing=args.remove_missing)

                # -1 marks a missing file; it must not be added to the
                # reserved size.
                if filesize == -1:
                    # Remove file from list, if missing.
                    if args.paired and args.remove_missing:
                        print('Removing %s from list' % (filename))
                        filenames[sequence].remove(filename)
                    continue
                data_size += filesize

        # Remove empty sequences (iterate over a snapshot of the keys).
        if args.paired and args.remove_missing:
            for sequence in list(all_filenames):
                if not all_filenames[sequence]:
                    all_filenames.pop(sequence)

        # Allocate size: (1 + metadata_factor) * total bytes, at least 1 GB.
        # The floor is an int so the map size is never a float (max() with
        # 1e9 used to return a float for small datasets).
        data_size = max(int((1 + args.metadata_factor) * data_size), 10 ** 9)
        print('Reserved size: %s, %dGB' % (data_type, data_size // 10 ** 9))

        # Write LMDB to file.
        output_filepath = os.path.join(args.output_root, data_type)
        build_lmdb(filepaths, keys, output_filepath, data_size, args.large)

    # Output list of all filenames.
    if args.output_root:
        with open(os.path.join(args.output_root, 'all_filenames.json'),
                  'w') as fout:
            json.dump(all_filenames, fout, indent=4)

        # Output metadata.
        with open(os.path.join(args.output_root, 'metadata.json'),
                  'w') as fout:
            json.dump(extensions, fout, indent=4)
    else:
        return all_filenames, extensions