コード例 #1
0
ファイル: test_transforms.py プロジェクト: ZJCV/Non-local
def test_transforms():
    """Smoke-test the eval and train transform pipelines on a random image.

    Builds each pipeline via ``build_transform`` (cfg comes from module
    scope), runs it on a synthetic HWC uint8 image, and prints the output
    shape and wall-clock processing time.
    """
    # Use randint for a valid uint8 image. The original
    # randn(...).astype(np.uint8) wrapped negative floats modulo 256,
    # producing values that don't resemble pixel data.
    img = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)

    def _time_transform(transform):
        # Apply the transform once and report output shape + elapsed time.
        start = time.time()
        res = transform(img)
        end = time.time()
        print(res.shape)
        print('process time: {}'.format(end - start))

    _time_transform(build_transform(cfg, is_train=False))
    _time_transform(build_transform(cfg, is_train=True))
コード例 #2
0
def test_jester_rgbdiff():
    """Check the Jester RGB-diff dataset yields clips shaped (3, 15, 224, 224).

    Loads the TSN/Jester config, builds the train transform and dataset,
    fetches one sample, and asserts its tensor shape.
    """
    cfg.merge_from_file('configs/tsn_r50_jester_rgbdiff_224x3_seg.yaml')

    transform = build_transform(cfg, is_train=True)
    dataset = build_dataset(cfg, transform=transform, is_train=True)
    # Use indexing rather than calling __getitem__ directly.
    image, target = dataset[20]
    print(image.shape)
    print(target)

    assert image.shape == (3, 15, 224, 224)
コード例 #3
0
def test_hmdb51_rgb():
    """Check the HMDB51 RGB dataset yields clips shaped (3, 8, 224, 224).

    Loads the TSN/HMDB51 config, overrides the number of clips to 8,
    builds the train transform and dataset, fetches one sample, and
    asserts its tensor shape.
    """
    cfg.merge_from_file('configs/tsn_r50_hmdb51_rgb_224x3_seg.yaml')
    cfg.DATASETS.NUM_CLIPS = 8

    transform = build_transform(cfg, is_train=True)
    dataset = build_dataset(cfg, transform=transform, is_train=True)
    # Use indexing rather than calling __getitem__ directly.
    image, target = dataset[20]
    print(image.shape)
    print(target)

    assert image.shape == (3, 8, 224, 224)
コード例 #4
0
def main():
    """Demo the short-cycle batch sampler: print the first few index batches.

    Builds the training transform/dataset, wraps a sequential sampler in
    ``ShortCycleBatchSampler``, then prints a handful of batches and the
    sampler's total length.
    """
    is_train = True
    transform = build_transform(cfg, is_train=is_train)
    dataset = build_dataset(cfg, transform=transform, is_train=is_train)

    # The multigrid sampler needs the default spatial size before wrapping.
    cfg.SAMPLER.MULTIGRID.DEFAULT_S = cfg.TRANSFORM.TRAIN.TRAIN_CROP_SIZE
    base_sampler = SequentialSampler(dataset)
    sampler = ShortCycleBatchSampler(
        base_sampler, cfg.DATALOADER.TRAIN_BATCH_SIZE, False, cfg)

    print('batch_size:', cfg.DATALOADER.TRAIN_BATCH_SIZE)

    # Inspect only the first five batches (break once batch_idx exceeds 3).
    for batch_idx, batch_indices in enumerate(sampler):
        print(batch_indices)
        print(len(batch_indices))
        if batch_idx > 3:
            break

    print(len(sampler))
コード例 #5
0
    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode): configs. Details can be found in
                tsn/config/defaults.py
            gpu_id (Optional[int]): GPU id.
        """
        # Pick the compute device: the GPU path resolves the local rank
        # for (possibly distributed) placement; otherwise use the default.
        if cfg.NUM_GPUS > 0:
            device = get_device(local_rank=get_local_rank())
        else:
            device = get_device()

        self.cfg = cfg
        self.device = device

        # Build the video model in inference mode plus the matching
        # evaluation-time transform pipeline.
        model = build_model(cfg, device)
        model.eval()
        self.model = model
        self.transform = build_transform(cfg, is_train=False)
コード例 #6
0
ファイル: webcam_demo.py プロジェクト: ZJCV/X3D
def main():
    """Entry point for the webcam demo: set up model, camera, and worker threads.

    Loads the test config, seeds RNGs, builds the model and eval transform,
    opens the input video, then runs display (`show_results`) and inference
    (`inference`) daemon threads until the display thread exits or the user
    presses Ctrl-C. State is shared with the workers via module globals.
    """
    global frame_queue, camera, frame, results, threshold, sample_length, \
        data, test_transform, model, device, average_size, label, result_queue, \
        frame_interval

    args = parse_test_args()
    cfg = load_test_config(args)
    average_size = 1
    threshold = 0.5

    # Seed for reproducibility; benchmark=True favors throughput over
    # strict determinism in cuDNN.
    np.random.seed(cfg.RNG_SEED)
    torch.manual_seed(cfg.RNG_SEED)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

    device = get_device(local_rank=get_local_rank())
    model = build_model(cfg, device)
    model.eval()
    camera = cv2.VideoCapture(cfg.VISUALIZATION.INPUT_VIDEO)

    # Label file format: "<index> <name>" per line; keep only the name.
    with open(cfg.VISUALIZATION.LABEL_FILE_PATH, 'r') as f:
        label = [line.strip().split(' ')[1] for line in f]

    # prepare test pipeline from non-camera pipeline
    test_transform = build_transform(cfg, is_train=False)
    sample_length = cfg.DATASETS.CLIP_LEN * cfg.DATASETS.NUM_CLIPS * cfg.DATASETS.FRAME_INTERVAL
    frame_interval = cfg.DATASETS.FRAME_INTERVAL

    assert sample_length > 0

    try:
        frame_queue = deque(maxlen=sample_length)
        result_queue = deque(maxlen=1)
        pw = Thread(target=show_results, args=(), daemon=True)
        pr = Thread(target=inference, args=(), daemon=True)
        pw.start()
        pr.start()
        # Block until the display thread finishes. This replaces the
        # original `while True: if not pw.is_alive(): exit(0)` loop,
        # which busy-waited at 100% CPU; join() sleeps instead. When it
        # returns, main() ends and the daemon inference thread dies with
        # the process, matching the old exit(0) behavior.
        pw.join()
    except KeyboardInterrupt:
        pass