Example #1
# NOTE: the import layout below follows the legacy open-mmlab/mmaction
# codebase and is an assumption; adjust the paths to match your checkout.
import torch
from mmcv import Config

from mmaction import __version__
from mmaction.apis import (get_root_logger, init_dist, set_random_seed,
                           train_network)
from mmaction.datasets import get_trimmed_dataset
from mmaction.models import build_recognizer


def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    # update configs according to CLI args
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    cfg.gpus = args.gpus
    if cfg.checkpoint_config is not None:
        # save mmaction version in checkpoints as meta data
        cfg.checkpoint_config.meta = dict(mmact_version=__version__,
                                          config=cfg.text)

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # init logger before other steps
    logger = get_root_logger(cfg.log_level)
    logger.info('Distributed training: {}'.format(distributed))

    # set random seeds
    if args.seed is not None:
        logger.info('Set random seed to {}'.format(args.seed))
        set_random_seed(args.seed)

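    # build the recognizer network from cfg.model with train/test settings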
    model = build_recognizer(cfg.model,
                             train_cfg=cfg.train_cfg,
                             test_cfg=cfg.test_cfg)

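    # build the train/val splits and order them to match cfg.workflow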
    train_dataset = get_trimmed_dataset(cfg.data.train)
    val_dataset = get_trimmed_dataset(cfg.data.val)
    datasets = []
    for flow in cfg.workflow:
        assert flow[0] in ['train', 'val']
        if flow[0] == 'train':
            datasets.append(train_dataset)
        else:
            datasets.append(val_dataset)
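
    # hand the model and the workflow-ordered datasets to the training loop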
    train_network(model,
                  datasets,
                  cfg,
                  distributed=distributed,
                  validate=args.validate,
                  logger=logger)
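
For reference, here is a parse_args() consistent with the flags this example reads. The argument names mirror the attribute accesses in main(); the help strings, defaults, and launcher choices are assumptions, not the original definition.

import argparse


def parse_args():
    parser = argparse.ArgumentParser(description='Train an action recognizer')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work_dir', help='dir to save logs and checkpoints')
    parser.add_argument('--resume_from', help='checkpoint file to resume from')
    parser.add_argument('--validate', action='store_true',
                        help='evaluate checkpoints during training')
    parser.add_argument('--gpus', type=int, default=1,
                        help='number of GPUs to use')
    parser.add_argument('--seed', type=int, default=None, help='random seed')
    parser.add_argument('--launcher', default='none',
                        choices=['none', 'pytorch', 'slurm', 'mpi'],
                        help='job launcher')
    return parser.parse_args()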
Example #2
# NOTE: same assumed imports as Example #1; is_valid() and update_data_paths()
# are fork-specific helpers, and their import location below is a guess.
import torch
from mmcv import Config

from mmaction import __version__
from mmaction.apis import (get_root_logger, init_dist, set_random_seed,
                           train_network)
from mmaction.datasets import get_trimmed_dataset
from mmaction.models import build_recognizer
from mmaction.utils import is_valid, update_data_paths  # assumed location


def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # update configs according to CLI args
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir

    if is_valid(args.resume_from):
        cfg.resume_from = args.resume_from

    if is_valid(args.load_from):
        cfg.load_from = args.load_from

    if is_valid(args.load2d_from):
        cfg.model.backbone.pretrained = args.load2d_from
        cfg.model.backbone.pretrained2d = True

    if args.num_videos is not None:
        assert args.num_videos > 0
        cfg.data.videos_per_gpu = args.num_videos

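    # save mmaction version in checkpoints as meta data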
    if cfg.checkpoint_config is not None:
        cfg.checkpoint_config.meta = dict(mmact_version=__version__,
                                          config=cfg.text)

    if args.data_dir is not None:
        cfg = update_data_paths(cfg, args.data_dir)

    masked_num = getattr(cfg.model, 'masked_num', None)
    if masked_num is not None and masked_num > 0:
        # shrink videos_per_gpu to account for the masked samples
        assert cfg.data.videos_per_gpu > masked_num
        cfg.data.videos_per_gpu -= masked_num

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # init logger before other steps
    logger = get_root_logger(cfg.log_level)
    logger.info('Distributed training: {}'.format(distributed))

    # set random seeds
    if args.seed is not None:
        logger.info('Set random seed to {}'.format(args.seed))
        set_random_seed(args.seed)

    train_dataset = get_trimmed_dataset(cfg.data.train)
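    # parameter names forwarded to train_network as 'ignores'; presumably
    # skipped when loading weights ('num_batches_tracked' is a BatchNorm buffer)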
    ignores = ['num_batches_tracked']
    if args.ignores is not None and len(args.ignores) > 0:
        ignores += args.ignores

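    # build the recognizer; --sync_bn swaps BatchNorm for SyncBatchNorm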
    model = build_recognizer(cfg.model,
                             train_cfg=cfg.train_cfg,
                             test_cfg=cfg.test_cfg)
    if args.sync_bn:
        logger.info('Enabled SyncBatchNorm')
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

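    # unlike Example #1, a single train dataset is passed, not a workflow list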
    train_network(model,
                  train_dataset,
                  cfg,
                  distributed=distributed,
                  validate=args.validate,
                  logger=logger,
                  ignores=tuple(ignores))
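
Example #2 reads several extra flags beyond Example #1's CLI (and drops --gpus). A sketch of just the additional arguments; the names come from the attribute accesses above, everything else is assumed:

    parser.add_argument('--load_from', help='checkpoint to load weights from')
    parser.add_argument('--load2d_from',
                        help='2D backbone weights for the pretrained2d path')
    parser.add_argument('--num_videos', type=int, default=None,
                        help='override cfg.data.videos_per_gpu')
    parser.add_argument('--data_dir', default=None,
                        help='root dir passed to update_data_paths()')
    parser.add_argument('--ignores', nargs='+', default=None,
                        help='extra parameter names to skip when loading')
    parser.add_argument('--sync_bn', action='store_true',
                        help='convert BatchNorm to SyncBatchNorm')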