コード例 #1
0
def load_train_config(args):
    """Build the frozen, module-level training config from parsed CLI args.

    The optional config file and the ``opts`` override list are merged into
    ``cfg`` first. Then the individual CLI options are applied; for the
    integer options a value of ``-1`` means "keep what the config says".
    When more than one GPU is configured, the learning rate is multiplied by
    the GPU count. Finally the config is frozen and ``cfg.OUTPUT_DIR`` is
    created if it does not exist yet.

    Args:
        args: parsed argparse namespace exposing ``config_file``, ``opts``,
            ``log_step``, ``save_step``, ``eval_step``, ``resume``,
            ``use_tensorboard``, ``gpus``, ``nodes`` and ``nr``.

    Returns:
        The (now frozen) module-level ``cfg``.
    """
    if args.config_file:
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)

    # Step-interval overrides under cfg.TRAIN; -1 leaves the config value alone.
    for arg_name, train_key in (('log_step', 'LOG_STEP'),
                                ('save_step', 'SAVE_STEP'),
                                ('eval_step', 'EVAL_STEP')):
        requested = getattr(args, arg_name)
        if requested != -1:
            setattr(cfg.TRAIN, train_key, requested)

    # Boolean flags can only switch these settings one way.
    if args.resume:
        cfg.TRAIN.RESUME = True
    if not args.use_tensorboard:
        cfg.TRAIN.USE_TENSORBOARD = False

    # Top-level distributed settings; -1 again means "not given on the CLI".
    for arg_name, cfg_key in (('gpus', 'NUM_GPUS'),
                              ('nodes', 'NUM_NODES'),
                              ('nr', 'RANK_ID')):
        requested = getattr(args, arg_name)
        if requested != -1:
            setattr(cfg, cfg_key, requested)

    # Scale the learning rate linearly with the number of GPUs.
    if cfg.NUM_GPUS > 1:
        cfg.OPTIMIZER.LR *= cfg.NUM_GPUS

    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    return cfg
コード例 #2
0
ファイル: test_backbone.py プロジェクト: ZJCV/TSM
def test_shufflenet():
    """The ShuffleNetV2 backbone maps a (1, 3, 224, 224) input to (1, 1024, 7, 7)."""
    config_path = 'configs/tsn_sfv2_ucf101_rgb_raw_dense_1x16x4.yaml'
    cfg.merge_from_file(config_path)
    backbone = build_shufflenet_v2(cfg)

    image_batch = torch.randn((1, 3, 224, 224))
    feature_map = backbone(image_batch)
    print(feature_map.shape)

    expected_shape = (1, 1024, 7, 7)
    assert feature_map.shape == expected_shape
コード例 #3
0
ファイル: test_backbone.py プロジェクト: ZJCV/TSM
def test_resnet50():
    """The ResNet-50 backbone maps a (1, 3, 224, 224) input to (1, 2048, 7, 7)."""
    config_path = 'configs/tsn_r50_ucf101_rgb_raw_dense_1x16x4.yaml'
    cfg.merge_from_file(config_path)
    backbone = build_resnet50(cfg)

    image_batch = torch.randn((1, 3, 224, 224))
    feature_map = backbone(image_batch)
    print(feature_map.shape)

    expected_shape = (1, 2048, 7, 7)
    assert feature_map.shape == expected_shape
コード例 #4
0
def test_jester_rgbdiff():
    """Sample 20 of the Jester RGB-diff training set has shape (3, 15, 224, 224)."""
    cfg.merge_from_file('configs/tsn_r50_jester_rgbdiff_224x3_seg.yaml')

    train_transform = build_transform(cfg, is_train=True)
    train_set = build_dataset(cfg, transform=train_transform, is_train=True)

    clip, label = train_set[20]
    print(clip.shape)
    print(label)

    # NOTE(review): 15 frames presumably come from the configured segment /
    # RGB-diff layout — confirm against the YAML if this ever changes.
    assert clip.shape == (3, 15, 224, 224)
コード例 #5
0
def test_hmdb51_rgb():
    """With NUM_CLIPS forced to 8, sample 20 of HMDB51 RGB has shape (3, 8, 224, 224).

    NOTE(review): this merges into and edits the module-level ``cfg``
    (including ``DATASETS.NUM_CLIPS``); those changes stay visible to any code
    that uses the same ``cfg`` later in the process.
    """
    num_clips = 8
    cfg.merge_from_file('configs/tsn_r50_hmdb51_rgb_224x3_seg.yaml')
    cfg.DATASETS.NUM_CLIPS = num_clips

    train_set = build_dataset(cfg,
                              transform=build_transform(cfg, is_train=True),
                              is_train=True)
    clip, label = train_set[20]
    print(clip.shape)
    print(label)

    assert clip.shape == (3, num_clips, 224, 224)
コード例 #6
0
ファイル: test_slowfast.py プロジェクト: ZJCV/SlowFast
def test_slowonly():
    """Slow-only ResNet3D-50 maps a (1, 3, 4, 224, 224) clip to (1, 2048, 4, 7, 7).

    NOTE(review): this freezes the module-level ``cfg``; any later attempt in
    the same process to modify it may raise — confirm test isolation.
    """
    cfg.merge_from_file('configs/slowonly_r3d50_ucf101_rgb_224x4_dense.yaml')
    cfg.freeze()

    net = resnet3d_50_slowonly(cfg)
    print(net)

    clip = torch.randn(1, 3, 4, 224, 224)
    features = net(clip)
    print(features.shape)

    assert features.shape == (1, 2048, 4, 7, 7)
コード例 #7
0
def test_i3d():
    """Every I3D ResNet3D-50 variant maps a (1, 3, 32, 224, 224) clip to (1, 2048, 4, 7, 7).

    Each variant is built from its own copy of the module-level ``cfg`` as it
    was before this test merged anything. Otherwise keys set by one YAML
    (e.g. a non-local option) that the next YAML does not mention explicitly
    would leak into the next variant's model.
    """
    config_files = (
        'configs/i3d-3x1-nl_r3d50_ucf101_rgb_224x32_dense.yaml',
        'configs/i3d-3x1_r3d50_ucf101_rgb_224x32_dense.yaml',
        'configs/i3d-3x3-nl_r3d50_ucf101_rgb_224x32_dense.yaml',
        'configs/i3d-3x3_r3d50_ucf101_rgb_224x32_dense.yaml',
    )
    data = torch.randn(1, 3, 32, 224, 224)
    # Snapshot once so every variant starts from the same baseline.
    base_cfg = cfg.clone()

    for config_file in config_files:
        variant_cfg = base_cfg.clone()
        variant_cfg.merge_from_file(config_file)
        model = resnet3d_50(variant_cfg)
        outputs = model(data)
        print(outputs.shape)
        # Name the failing variant in the assertion message.
        assert outputs.shape == (1, 2048, 4, 7, 7), config_file
コード例 #8
0
def load_config(args):
    """Build the frozen, module-level config from parsed CLI args.

    Merges the optional config file and the ``opts`` override list into
    ``cfg``. When ``args.gpus`` is greater than one, three values are
    multiplied by that GPU count:

    - the learning rate;
    - the weight decay;
    - the cosine-annealing minimal learning rate.

    The config is then frozen and ``cfg.OUTPUT.DIR`` is created if missing.

    Args:
        args: parsed argparse namespace exposing ``config_file``, ``opts``
            and ``gpus``.

    Returns:
        The (now frozen) module-level ``cfg``.
    """
    if args.config_file:
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)

    gpu_count = args.gpus
    if gpu_count > 1:
        optimizer_cfg = cfg.OPTIMIZER
        optimizer_cfg.LR *= gpu_count
        # NOTE(review): scaling weight decay with the GPU count is unusual
        # (many setups scale only the LR) — confirm this is intended.
        optimizer_cfg.WEIGHT_DECAY *= gpu_count
        cfg.LR_SCHEDULER.COSINE_ANNEALING_LR.MINIMAL_LR *= gpu_count

    cfg.freeze()

    out_dir = cfg.OUTPUT.DIR
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    return cfg
コード例 #9
0
ファイル: test_slowfast.py プロジェクト: ZJCV/SlowFast
def test_slowfast():
    """SlowFast ResNet3D-50 returns two feature maps for a (1, 3, 32, 224, 224) clip.

    The first output must be (1, 2048, 4, 7, 7) and the second
    (1, 256, 32, 7, 7).

    NOTE(review): this freezes the module-level ``cfg``; later modifications
    in the same process may raise — confirm test isolation.
    """
    cfg.merge_from_file('configs/slowfast_r3d50_ucf101_rgb_224x32_dense.yaml')
    cfg.freeze()

    net = resnet3d_50_slowfast(cfg)
    print(net)

    clip = torch.randn(1, 3, 32, 224, 224)
    outputs = net(clip)
    print(len(outputs))

    # Temporal sizes 4 vs 32 suggest the slow and fast pathways respectively.
    slow_feat = outputs[0]
    print(slow_feat.shape)
    fast_feat = outputs[1]
    print(fast_feat.shape)

    assert len(outputs) == 2
    assert slow_feat.shape == (1, 2048, 4, 7, 7)
    assert fast_feat.shape == (1, 256, 32, 7, 7)
コード例 #10
0
def load_test_config(args):
    """Build the frozen, module-level evaluation config from parsed CLI args.

    Merges ``args.config_file`` into ``cfg`` and sets three values from the
    CLI:

    - ``cfg.MODEL.PRETRAINED`` from ``args.pretrained``;
    - ``cfg.OUTPUT_DIR`` from ``args.output``;
    - the distributed settings from ``args.gpus``, ``args.nodes`` and
      ``args.nr``, where ``-1`` means "keep the config value".

    The config is then frozen and its output directory is created if missing.

    Args:
        args: parsed argparse namespace exposing ``config_file``,
            ``pretrained``, ``output``, ``gpus``, ``nodes`` and ``nr``.

    Returns:
        The (now frozen) module-level ``cfg``.

    Raises:
        ValueError: if ``args.config_file`` or ``args.pretrained`` is not an
            existing file.
    """
    if not os.path.isfile(args.config_file) or not os.path.isfile(args.pretrained):
        # Message: "a config file and a pretrained model path are required".
        raise ValueError('需要输入配置文件和预训练模型路径')

    cfg.merge_from_file(args.config_file)
    cfg.MODEL.PRETRAINED = args.pretrained
    cfg.OUTPUT_DIR = args.output

    # Use the same keys as load_train_config (NUM_NODES / RANK_ID). Writing
    # cfg.NODES / cfg.RANK would silently add new keys and leave the ones
    # the rest of the code uses untouched.
    if args.gpus != -1:
        cfg.NUM_GPUS = args.gpus
    if args.nodes != -1:
        cfg.NUM_NODES = args.nodes
    if args.nr != -1:
        cfg.RANK_ID = args.nr
    cfg.freeze()

    if not os.path.exists(cfg.OUTPUT_DIR):
        os.makedirs(cfg.OUTPUT_DIR)

    return cfg
コード例 #11
0
ファイル: compute_flops.py プロジェクト: ZJCV/TSM
def main(data_shape, config_file, mobile_name):
    """Print FLOPs, parameter size and mean forward time of a recognizer.

    Args:
        data_shape: shape of the random input tensor fed to the model.
        config_file: config file merged into the module-level ``cfg``.
        mobile_name: label printed at the top of the report.
    """
    cfg.merge_from_file(config_file)

    device = get_device(local_rank=get_local_rank())
    model = build_recognizer(cfg, device)
    # Benchmark in inference mode, matching the sibling benchmark script.
    model.eval()
    data = torch.randn(data_shape).to(device=device, non_blocking=True)

    GFlops, params_size = compute_num_flops(model, data)
    print(f'{mobile_name} ' + '*' * 10)
    print(f'device: {device}')
    print(f'GFlops: {GFlops}')
    print(f'Params Size: {params_size}')

    # CUDA kernels run asynchronously, so the clock is only read after a
    # synchronize; otherwise only launch overhead would be measured.
    is_cuda = torch.device(device).type == 'cuda'

    total_time = 0.0
    num = 100
    with torch.no_grad():  # no autograd graph is needed for timing
        for _ in range(num):
            data = torch.randn(data_shape).to(device=device, non_blocking=True)
            if is_cuda:
                torch.cuda.synchronize(device)
            start = time.perf_counter()
            model(data)
            if is_cuda:
                torch.cuda.synchronize(device)
            # Accumulate so the printed value is the mean over all iterations.
            total_time += time.perf_counter() - start
    print(f'one process need {total_time / num}')
コード例 #12
0
ファイル: test.py プロジェクト: ZJCV/SlowFast
def main():
    """Command-line entry point for evaluating a trained model.

    Takes the config file and pretrained weights as positional arguments.
    Any trailing tokens are passed to ``cfg.merge_from_list`` as config
    overrides. The function then:

    - freezes the merged config and creates its output directory;
    - logs the environment and configuration;
    - calls ``test(cfg)``.

    Raises:
        ValueError: if the config file or the pretrained weights file does
            not exist.
    """
    parser = argparse.ArgumentParser(description='TSN Test With PyTorch')
    parser.add_argument("config_file",
                        default="",
                        metavar="CONFIG_FILE",
                        help="path to config file",
                        type=str)
    parser.add_argument('pretrained',
                        default="",
                        metavar='PRETRAINED_FILE',
                        help="path to pretrained model",
                        type=str)
    parser.add_argument('--output', default="./outputs/test", type=str)
    parser.add_argument("opts",
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)
    args = parser.parse_args()

    config_path = args.config_file
    weights_path = args.pretrained
    if not (os.path.isfile(config_path) and os.path.isfile(weights_path)):
        # Message: "a config file and a pretrained model path are required".
        raise ValueError('需要输入配置文件和预训练模型路径')

    cfg.merge_from_file(config_path)
    cfg.merge_from_list(args.opts)
    cfg.MODEL.PRETRAINED = weights_path
    cfg.OUTPUT.DIR = args.output
    cfg.freeze()

    output_dir = cfg.OUTPUT.DIR
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    logger = setup_logger("TSN", save_dir=output_dir)
    logger.info(args)

    logger.info("Environment info:\n" + collect_env_info())
    logger.info(f"Loaded configuration file {config_path}")
    if config_path:
        # Log the raw YAML text in addition to the merged config below.
        with open(config_path, "r") as cf:
            logger.info("\n" + cf.read())
    logger.info(f"Running with config:\n{cfg}")

    test(cfg)
コード例 #13
0
def main(data_shape, config_file, mobile_name):
    """Print FLOPs, parameter size and per-iteration timings of a recognizer.

    Two averages over ``num`` iterations are printed:

    - the wall time per iteration of the whole loop;
    - the time per call of host-to-device copy plus forward pass.

    Args:
        data_shape: shape of the random input tensor fed to the model.
        config_file: config file merged into the module-level ``cfg``.
        mobile_name: label printed at the top of the report.
    """
    cfg.merge_from_file(config_file)

    device = get_device(local_rank=get_local_rank())
    model = build_recognizer(cfg, device)
    model.eval()
    data = torch.randn(data_shape).to(device=device, non_blocking=True)

    GFlops, params_size = compute_num_flops(model, data)
    print(f'{mobile_name} ' + '*' * 10)
    print(f'device: {device}')
    print(f'GFlops: {GFlops}')
    print(f'Params Size: {params_size}')

    # CUDA work is asynchronous: synchronize before reading the clock so the
    # timings include the queued kernels, not just their launch.
    is_cuda = torch.device(device).type == 'cuda'

    data = torch.randn(data_shape)
    t1 = 0.0
    num = 100
    with torch.no_grad():  # inference only; skip building autograd graphs
        if is_cuda:
            torch.cuda.synchronize(device)  # drain work left by the FLOP count
        begin = time.perf_counter()
        for _ in range(num):
            start = time.perf_counter()
            model(data.to(device=device, non_blocking=True))
            if is_cuda:
                torch.cuda.synchronize(device)
            t1 += time.perf_counter() - start
        t2 = time.perf_counter() - begin
    print(f'one process need {t2 / num}, model compute need: {t1 / num}')
コード例 #14
0
ファイル: train.py プロジェクト: ZJCV/TRN
def main():
    """Command-line entry point for (optionally distributed) training.

    The function runs these steps in order:

    - parse the CLI;
    - merge the config file and trailing ``opts`` overrides into the
      module-level ``cfg`` and freeze it;
    - create the output directory and log the environment and configuration;
    - spawn one ``train`` process per GPU via ``mp.spawn``.

    ``MASTER_ADDR`` / ``MASTER_PORT`` default to ``localhost`` / ``14028``,
    but values already present in the environment are kept. That allows a
    shared rendezvous address when ``--nodes`` is greater than one.
    """
    parser = argparse.ArgumentParser(description='TSN Training With PyTorch')
    parser.add_argument("--config_file",
                        default="",
                        metavar="FILE",
                        help="path to config file",
                        type=str)
    parser.add_argument('--log_step',
                        default=10,
                        type=int,
                        help='Print logs every log_step')
    parser.add_argument('--save_step',
                        default=2500,
                        type=int,
                        help='Save checkpoint every save_step')
    parser.add_argument('--stop_save', default=False, action='store_true')
    parser.add_argument(
        '--eval_step',
        default=2500,
        type=int,
        help='Evaluate dataset every eval_step, disabled when eval_step < 0')
    parser.add_argument('--stop_eval', default=False, action='store_true')
    parser.add_argument('--resume',
                        default=False,
                        action='store_true',
                        help='Resume training')
    parser.add_argument('--use_tensorboard', default=1, type=int)

    parser.add_argument('-n',
                        '--nodes',
                        default=1,
                        type=int,
                        metavar='N',
                        help='number of machines (default: 1)')
    parser.add_argument('-g',
                        '--gpus',
                        default=1,
                        type=int,
                        help='number of gpus per node')
    parser.add_argument('-nr',
                        '--nr',
                        default=0,
                        type=int,
                        help='ranking within the nodes')

    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()
    if args.config_file:
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    if not os.path.exists(cfg.OUTPUT.DIR):
        os.makedirs(cfg.OUTPUT.DIR)
    logger = setup_logger("TSN", save_dir=cfg.OUTPUT.DIR)
    logger.info(args)

    logger.info("Environment info:\n" + collect_env_info())
    logger.info("Loaded configuration file {}".format(args.config_file))
    if args.config_file:
        with open(args.config_file, "r") as cf:
            config_str = "\n" + cf.read()
            logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    # Total number of processes across all nodes.
    args.world_size = args.gpus * args.nodes
    # Keep a rendezvous address/port supplied by the environment (every node
    # must agree on it when --nodes > 1); otherwise use single-machine
    # defaults. NOTE(review): presumably train() initializes torch.distributed
    # from these variables — confirm.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '14028')
    mp.spawn(train, nprocs=args.gpus, args=(args, cfg))