def main():
    """Training entry point.

    Parses CLI arguments, builds config, data loaders, models, optimizers
    and the trainer, optionally resumes from a checkpoint, then runs the
    epoch/iteration loop until ``cfg.max_epoch`` or ``cfg.max_iter``.
    """
    args = parse_args()
    set_affinity(args.local_rank)
    set_random_seed(args.seed, by_rank=True)
    cfg = Config(args.config)

    # Distributed data parallel is used unless --single_gpu was requested.
    if not args.single_gpu:
        cfg.local_rank = args.local_rank
        init_dist(cfg.local_rank)

    # Optional override of the data-loading worker count.
    if args.num_workers is not None:
        cfg.data.num_workers = args.num_workers

    # Log directory for storing training results.
    cfg.date_uid, cfg.logdir = init_logging(args.config, args.logdir)
    make_logging_dir(cfg.logdir)

    # cudnn flags.
    init_cudnn(cfg.cudnn.deterministic, cfg.cudnn.benchmark)

    # Data loaders, networks, optimizers, schedulers and the trainer.
    train_data_loader, val_data_loader = get_train_and_val_dataloader(cfg)
    net_G, net_D, opt_G, opt_D, sch_G, sch_D = \
        get_model_optimizer_and_scheduler(cfg, seed=args.seed)
    trainer = get_trainer(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
                          train_data_loader, val_data_loader)
    current_epoch, current_iteration = trainer.load_checkpoint(
        cfg, args.checkpoint, resume=args.resume)

    # Main training loop.
    for epoch in range(current_epoch, cfg.max_epoch):
        print('Epoch {} ...'.format(epoch))
        if not args.single_gpu:
            # Reshuffle the distributed sampler deterministically per epoch.
            train_data_loader.sampler.set_epoch(current_epoch)
        trainer.start_of_epoch(current_epoch)
        for batch in train_data_loader:
            batch = trainer.start_of_iteration(batch, current_iteration)
            # Alternate discriminator / generator updates.
            for _ in range(cfg.trainer.dis_step):
                trainer.dis_update(batch)
            for _ in range(cfg.trainer.gen_step):
                trainer.gen_update(batch)
            current_iteration += 1
            trainer.end_of_iteration(batch, current_epoch, current_iteration)
            if current_iteration >= cfg.max_iter:
                print('Done with training!!!')
                return
        current_epoch += 1
        trainer.end_of_epoch(batch, current_epoch, current_iteration)
    print('Done with training!!!')
    return
def main():
    """Inference entry point.

    Parses CLI arguments, builds config, test data loader, models and the
    trainer, loads the checkpoint given by ``args.checkpoint``, then runs
    ``trainer.test`` and writes outputs under ``args.output_dir``.
    """
    args = parse_args()
    set_affinity(args.local_rank)
    set_random_seed(args.seed, by_rank=True)
    cfg = Config(args.config)
    # Older configs may not define inference_args; default to None so
    # trainer.test receives an explicit "no extra args" marker.
    if not hasattr(cfg, 'inference_args'):
        cfg.inference_args = None

    # Distributed data parallel is used unless --single_gpu was requested.
    if not args.single_gpu:
        cfg.local_rank = args.local_rank
        init_dist(cfg.local_rank)

    # Optional override of the data-loading worker count.
    if args.num_workers is not None:
        cfg.data.num_workers = args.num_workers

    # Log directory for storing results.
    cfg.date_uid, cfg.logdir = init_logging(args.config, args.logdir)

    # cudnn flags.
    init_cudnn(cfg.cudnn.deterministic, cfg.cudnn.benchmark)

    # Data loader, networks, optimizers, schedulers and the trainer.
    # No training loader is needed for inference, hence None.
    test_data_loader = get_test_dataloader(cfg)
    net_G, net_D, opt_G, opt_D, sch_G, sch_D = \
        get_model_optimizer_and_scheduler(cfg, seed=args.seed)
    trainer = get_trainer(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
                          None, test_data_loader)

    # Load the checkpoint. The checkpoint path must be supplied explicitly;
    # the automatic pretrained-weight download path was removed as dead code.
    trainer.load_checkpoint(cfg, args.checkpoint)

    # Sentinel values mark that we are not inside a training run.
    trainer.current_epoch = -1
    trainer.current_iteration = -1
    trainer.test(test_data_loader, args.output_dir, cfg.inference_args)
def main():
    """Evaluation entry point.

    Parses CLI arguments, builds config, data loaders, models and the
    trainer, then computes and writes metrics for every ``*.pt``
    checkpoint found under ``args.checkpoint_logdir``.
    """
    args = parse_args()
    set_affinity(args.local_rank)
    set_random_seed(args.seed, by_rank=True)
    cfg = Config(args.config)

    # Distributed data parallel is used unless --single_gpu was requested.
    if not args.single_gpu:
        cfg.local_rank = args.local_rank
        init_dist(cfg.local_rank)

    # Optional override of the data-loading worker count.
    if args.num_workers is not None:
        cfg.data.num_workers = args.num_workers

    # Log directory for storing results.
    cfg.date_uid, cfg.logdir = init_logging(args.config, args.logdir)
    make_logging_dir(cfg.logdir)

    # cudnn flags.
    init_cudnn(cfg.cudnn.deterministic, cfg.cudnn.benchmark)

    # Data loaders, networks, optimizers, schedulers and the trainer.
    train_data_loader, val_data_loader = get_train_and_val_dataloader(cfg)
    net_G, net_D, opt_G, opt_D, sch_G, sch_D = \
        get_model_optimizer_and_scheduler(cfg, seed=args.seed)
    trainer = get_trainer(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
                          train_data_loader, val_data_loader)

    # Evaluate checkpoints in lexicographic (filename) order.
    checkpoint_paths = sorted(
        glob.glob('{}/*.pt'.format(args.checkpoint_logdir)))
    for checkpoint_path in checkpoint_paths:
        epoch, iteration = trainer.load_checkpoint(
            cfg, checkpoint_path, resume=True)
        trainer.current_epoch = epoch
        trainer.current_iteration = iteration
        trainer.write_metrics()
    print('Done with evaluation!!!')
    return