def main():
    """Train a recognizer from a config file plus command-line overrides.

    Steps, in order: parse the CLI arguments, load the config, apply the CLI
    overrides (work dir, resume checkpoint, GPU count), initialise the
    distributed environment and the root logger, optionally call
    ``set_random_seed``, build the model and the train/val datasets, and hand
    everything to ``train_network``.

    Raises:
        ValueError: If an entry of ``cfg.workflow`` names a mode other than
            ``'train'`` or ``'val'``.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # update configs according to CLI args
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    cfg.gpus = args.gpus
    if cfg.checkpoint_config is not None:
        # save mmaction version in checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            mmact_version=__version__, config=cfg.text)

    # init distributed env first, since logger depends on the dist info.
    distributed = args.launcher != 'none'
    if distributed:
        init_dist(args.launcher, **cfg.dist_params)

    # init logger before other steps
    logger = get_root_logger(cfg.log_level)
    logger.info('Distributed training: {}'.format(distributed))

    # set random seeds
    if args.seed is not None:
        logger.info('Set random seed to {}'.format(args.seed))
        set_random_seed(args.seed)

    model = build_recognizer(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)

    train_dataset = get_trimmed_dataset(cfg.data.train)
    val_dataset = get_trimmed_dataset(cfg.data.val)

    # One dataset per workflow stage, in workflow order. The mode is checked
    # with an explicit raise (not ``assert``) so that a typo in the config is
    # still rejected under ``python -O`` instead of silently selecting the
    # validation dataset.
    datasets_by_mode = {'train': train_dataset, 'val': val_dataset}
    datasets = []
    for flow in cfg.workflow:
        mode = flow[0]
        if mode not in datasets_by_mode:
            raise ValueError(
                "Unsupported workflow mode {!r}; expected 'train' or "
                "'val'".format(mode))
        datasets.append(datasets_by_mode[mode])

    train_network(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=args.validate,
        logger=logger)
def main():
    """Train a recognizer from a config file plus command-line overrides.

    Steps, in order: parse the CLI arguments, load the config, apply the CLI
    overrides (work dir, checkpoints to resume/load from, 2D pretrained
    weights for the backbone, videos per GPU, data directory), reduce the
    per-GPU batch by ``cfg.model.masked_num`` when that is set, initialise
    the distributed environment and the root logger, optionally call
    ``set_random_seed``, build the training dataset and the model, optionally
    convert the model's batch-norm layers to ``SyncBatchNorm``, and hand
    everything to ``train_network``.

    Raises:
        ValueError: If ``--num_videos`` is not positive, or if
            ``cfg.data.videos_per_gpu`` is not strictly greater than
            ``cfg.model.masked_num``.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # update configs according to CLI args
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    if is_valid(args.resume_from):
        cfg.resume_from = args.resume_from
    if is_valid(args.load_from):
        cfg.load_from = args.load_from
    if is_valid(args.load2d_from):
        cfg.model.backbone.pretrained = args.load2d_from
        cfg.model.backbone.pretrained2d = True
    if args.num_videos is not None:
        # Explicit raise instead of ``assert``: this validates user input and
        # must still run under ``python -O``.
        if args.num_videos <= 0:
            raise ValueError(
                'num_videos must be positive, got {}'.format(args.num_videos))
        cfg.data.videos_per_gpu = args.num_videos
    if cfg.checkpoint_config is not None:
        # save mmaction version in checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            mmact_version=__version__, config=cfg.text)
    if args.data_dir is not None:
        cfg = update_data_paths(cfg, args.data_dir)

    # When the model declares a positive ``masked_num``, the per-GPU batch is
    # reduced by that amount; it must stay strictly positive afterwards.
    masked_num = getattr(cfg.model, 'masked_num', None)
    if masked_num is not None and masked_num > 0:
        if cfg.data.videos_per_gpu <= masked_num:
            raise ValueError(
                'videos_per_gpu ({}) must be greater than masked_num '
                '({})'.format(cfg.data.videos_per_gpu, masked_num))
        cfg.data.videos_per_gpu -= masked_num

    # init distributed env first, since logger depends on the dist info.
    distributed = args.launcher != 'none'
    if distributed:
        init_dist(args.launcher, **cfg.dist_params)

    # init logger before other steps
    logger = get_root_logger(cfg.log_level)
    logger.info('Distributed training: {}'.format(distributed))

    # set random seeds
    if args.seed is not None:
        logger.info('Set random seed to {}'.format(args.seed))
        set_random_seed(args.seed)

    train_dataset = get_trimmed_dataset(cfg.data.train)

    # 'num_batches_tracked' is always passed in ``ignores``; entries from the
    # CLI are appended after it.
    ignores = ['num_batches_tracked']
    if args.ignores:
        ignores.extend(args.ignores)

    model = build_recognizer(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)

    if args.sync_bn:
        logger.info('Enabled SyncBatchNorm')
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    train_network(
        model,
        train_dataset,
        cfg,
        distributed=distributed,
        validate=args.validate,
        logger=logger,
        ignores=tuple(ignores))