Example #1
def get_sample_names(split, split_num):
    """
    Args:
        split (str): split ("train" or "test")
        split_num (int): split plan number
    Return:
        list of strings: a list of sample names of all samples in
        the given split plan.
    """
    assert isinstance(split, str), TypeError
    assert split in ["train", "test"], ValueError("Invalid split name")
    assert isinstance(split_num, int), TypeError
    assert split_num in (1, 2, 3), ValueError("Invalid split number")

    split_file = "{}list{:02d}.txt".format(split, split_num)
    split_path = os.path.join(metadata_dir_path, split_file)

    # download metadata if there is no local cache
    if not os.path.exists(split_path):
        split_src = os.path.join(DOWNLOAD_SERVER_PREFIX, DOWNLOAD_SRC_DIR,
                                 metadata_dir, split_file)
        download(split_src, split_path)

    results = []
    with open(split_path, "r") as fin:
        for _line in fin:
            _text = _line.split('\n')[0]  # remove \n
            _path = _text.split(' ')[0]  # remove class id
            _file = _path.split('/')[1]  # remove directory
            _name = _file.split('.')[0]  # remove file extension
            results.append(_name)

    return results
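A minimal usage sketch, assuming the module-level metadata_dir_path and download constants are configured as in Example #6:

# hypothetical call: collect the training sample names of split plan 1
train_names = get_sample_names("train", 1)
print(len(train_names), train_names[:3])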
Example #2
    def __init__(self,
                 root,
                 train,
                 split=1,
                 transform=None,
                 target_transform=None):
        root = os.path.expanduser(root)

        super(HMDB51, self).__init__(root=root,
                                     transform=transform,
                                     target_transform=target_transform)
        # -------------------- #
        #   load datapoints    #
        # -------------------- #

        # assemble paths
        if not (os.path.exists(CACHE_DIR) and os.path.isdir(CACHE_DIR)):
            os.makedirs(CACHE_DIR, exist_ok=True)
        if train:
            datapoint_file_name = "hmdb51_training_split{}.pkl".format(split)
        else:
            datapoint_file_name = "hmdb51_testing_split{}.pkl".format(split)
        datapoint_file_path = os.path.join(CACHE_DIR, datapoint_file_name)
        # download when missing
        if not os.path.exists(datapoint_file_path):
            print("downloading HMDB51 datapoints...")
            download(src=os.path.join(DOWNLOAD_SERVER_PREFIX, DOWNLOAD_SRC_DIR,
                                      datapoint_file_name),
                     dst=datapoint_file_path,
                     backend="rsync")
        # real load
        with open(datapoint_file_path, "rb") as fin:
            self.datapoints = pickle.load(fin)
            assert isinstance(self.datapoints, list), TypeError
            assert isinstance(self.datapoints[0], DataPoint), TypeError
        # replace dataset root
        for dp in self.datapoints:
            dp.root = root
            dp._path = dp.path  # re-derive the cached path from the new root

        # ------------------ #
        #  load class_to_idx #
        # ------------------ #
        # download labels
        label_file = "hmdb51_labels.txt"
        label_path = os.path.join(CACHE_DIR, label_file)
        if not os.path.exists(label_path):
            print("downloading HMDB51 label_path...")
            label_src = os.path.join(DOWNLOAD_SERVER_PREFIX, DOWNLOAD_SRC_DIR,
                                     label_file)
            download(label_src, label_path, backend="rsync")
        # build class label to class id mapping (a dictionary)
        self.class_to_idx = collections.OrderedDict()
        with open(label_path, "r") as fin:
            for _line in fin:
                text = _line.split('\n')[0]
                text = text.split(' ')
                self.class_to_idx[text[1]] = int(text[0]) - 1
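A minimal usage sketch (the root path is a placeholder, not part of the source):

# hypothetical usage: load the HMDB51 training set of split plan 1
dataset = HMDB51(root="~/datasets/hmdb51", train=True, split=1)
print(len(dataset.datapoints), len(dataset.class_to_idx))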
Example #3
def test_video2ndarray_mp4():
    vpath = os.path.join(DIR_PATH, "test.mp4")
    if not os.path.exists(vpath):
        mp4_src = os.path.join(DOWNLOAD_SERVER_PREFIX,
                               DOWNLOAD_SRC_DIR,
                               "test.mp4")
        download(mp4_src, vpath)

    # read video to varray
    varray = backend.video2ndarray(vpath, cin="BGR", cout="RGB")
    print(varray.shape)
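A hedged follow-up check one could append, assuming the varray uses the common (frame, height, width, channel) layout; the source only prints the shape and does not confirm this:

# expect a 4-D array with 3 color channels (assumption, not confirmed above)
assert varray.ndim == 4 and varray.shape[-1] == 3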
Example #4
def test_ndarray2video_mp4():
    vpath = os.path.join(DIR_PATH, "test.mp4")
    if not os.path.exists(vpath):
        mp4_src = os.path.join(DOWNLOAD_SERVER_PREFIX, DOWNLOAD_SRC_DIR,
                               "test.mp4")
        download(mp4_src, vpath)

    # read video to varray
    varray = backend.video2ndarray(vpath, cin="BGR", cout="RGB")
    # dump varray to video
    vpath = os.path.join(DIR_PATH, "ndarray2video.mp4")
    backend.ndarray2video(varray, vpath)
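A hedged round-trip check one might add after the dump, comparing shapes only:

# re-read the dumped video and compare shapes
varray2 = backend.video2ndarray(vpath, cin="BGR", cout="RGB")
print(varray.shape, varray2.shape)  # pixel values may differ under lossy codecs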
Example #5
def test_ndarray2frames_mp4():
    vpath = os.path.join(DIR_PATH, "test.mp4")
    if not os.path.exists(vpath):
        mp4_src = os.path.join(DOWNLOAD_SERVER_PREFIX, DOWNLOAD_SRC_DIR,
                               "test.mp4")
        download(mp4_src, vpath)

    # read video to varray
    varray = backend.video2ndarray(vpath, cin="BGR", cout="RGB")
    # dump varray to frames
    fdir = os.path.join(DIR_PATH, "ndarray2frames.frames.mp4.d")
    ret, f_n = backend.ndarray2frames(varray, fdir, cin="RGB", cout="BGR")
    print('Dumping frames from varray finished, {} frames'.format(f_n))
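A sanity check one might append, assuming ndarray2frames writes exactly one image file per frame into fdir:

# the dump directory should contain f_n frame images
assert len(os.listdir(fdir)) == f_n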
Example #6
if not os.path.exists(metadata_dir_path):
    os.makedirs(metadata_dir_path, exist_ok=True)
else:
    assert os.path.isdir(metadata_dir_path)

# download metadata if there is no local cache
label_file = "classInd.txt"
label_path = os.path.join(metadata_dir_path, label_file)
if not os.path.exists(label_path):
    label_src = os.path.join(
        DOWNLOAD_SERVER_PREFIX,
        DOWNLOAD_SRC_DIR,
        metadata_dir,
        label_file
    )
    download(label_src, label_path)

# build class label to class id mapping (a dictionary)
class_to_idx = collections.OrderedDict()
with open(label_path, "r") as fin:
    for _line in fin:
        text = _line.split('\n')[0]
        text = text.split(' ')
        class_to_idx[text[1]] = int(text[0]) - 1


if __name__ == "__main__":
    print(class_to_idx)

    cls2idx_path = os.path.join(DIR_PATH, "ucf101_class_to_idx.pkl")
    with open(cls2idx_path, "wb") as fout:
        pickle.dump(class_to_idx, fout)
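A matching consumer sketch (hypothetical, not part of the source) that mirrors how Example #8 reloads this cache:

# reload the cached mapping to verify the dump round-trips
with open(cls2idx_path, "rb") as fin:
    assert pickle.load(fin) == class_to_idx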
Example #7
import glob
import os

from torchstream.utils.download import download

DOWNLOAD_SERVER_PREFIX = ("[email protected]:"
                          "/home/eecs/zhen/video-acc/download")
DOWNLOAD_SRC_DIR = "tools/datasets/metadata/hmdb51"

FILE_PATH = os.path.realpath(__file__)
DIR_PATH = os.path.dirname(FILE_PATH)

# download split annotations
split_annot_dir = "testTrainMulti_7030_splits"
split_dir_path = os.path.join(DIR_PATH, split_annot_dir)
if not os.path.exists(split_dir_path):
    split_dir_src = os.path.join(DOWNLOAD_SERVER_PREFIX, DOWNLOAD_SRC_DIR,
                                 split_annot_dir)
    download(split_dir_src, split_dir_path)


def test_sample_names(split_num):
    """
    Args:
        split_num (int): split plan number
    Return:
        list of strings: a list of sample names of all testing samples in
        the given split plan.
    """
    assert isinstance(split_num, int), TypeError
    assert split_num in (1, 2, 3), ValueError("Invalid split number")

    glob_str = os.path.join(split_dir_path,
                            "*_test_split{}.txt".format(split_num))

    results = []
    for split_file in glob.glob(glob_str):
        with open(split_file, "r") as fin:
            for _line in fin:
                fields = _line.split()
                if len(fields) < 2:
                    continue
                # in HMDB51 split annotations, tag "2" marks testing samples
                if int(fields[1]) == 2:
                    # strip the file extension to get the sample name
                    results.append(fields[0].split('.')[0])
    return results
Example #8
    def __init__(self,
                 root,
                 train,
                 split=1,
                 class_to_idx=None,
                 ext="avi",
                 transform=None,
                 target_transform=None):
        root = os.path.expanduser(root)

        super(UCF101, self).__init__(root=root,
                                     transform=transform,
                                     target_transform=target_transform)
        # -------------------- #
        #   load datapoints    #
        # -------------------- #

        # assemble paths
        if not (os.path.exists(CACHE_DIR) and os.path.isdir(CACHE_DIR)):
            os.makedirs(CACHE_DIR, exist_ok=True)
        if train:
            datapoint_filename = "ucf101_training_split{}.pkl".format(split)
        else:
            datapoint_filename = "ucf101_testing_split{}.pkl".format(split)
        datapoint_filepath = os.path.join(CACHE_DIR, datapoint_filename)
        # download when missing
        if not os.path.exists(datapoint_filepath):
            print("downloading UCF101 datapoints...")
            download(src=os.path.join(DOWNLOAD_SERVER_PREFIX, DOWNLOAD_SRC_DIR,
                                      datapoint_filename),
                     dst=datapoint_filepath,
                     backend="rsync")
        # real load
        with open(datapoint_filepath, "rb") as fin:
            self.datapoints = pickle.load(fin)
            assert isinstance(self.datapoints, list), TypeError
            assert isinstance(self.datapoints[0], DataPoint), TypeError
        # replace dataset root
        for dp in self.datapoints:
            dp.root = root
            dp.ext = ext
            dp._path = dp.path  # re-derive the cached path from the new root

        # ------------------ #
        #  load class_to_idx #
        # ------------------ #
        if class_to_idx is not None:
            self.class_to_idx = class_to_idx
        else:
            class_to_idx_filename = "ucf101_class_to_idx.pkl"
            class_to_idx_filepath = os.path.join(CACHE_DIR,
                                                 class_to_idx_filename)
            # download when missing
            if not os.path.exists(class_to_idx_filepath):
                print("downloading UCF101 class_to_idx...")
                download(src=os.path.join(DOWNLOAD_SERVER_PREFIX,
                                          DOWNLOAD_SRC_DIR,
                                          class_to_idx_filename),
                         dst=class_to_idx_filepath,
                         backend="rsync")
            # load class_to_idx
            with open(class_to_idx_filepath, "rb") as fin:
                self.class_to_idx = pickle.load(fin)
        # sanity check
        # print(self.class_to_idx)
        assert isinstance(self.class_to_idx, dict), TypeError
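A minimal usage sketch (the root path is a placeholder; when class_to_idx is omitted, the cached mapping is downloaded as shown above):

# hypothetical usage: load the UCF101 testing set of split plan 1
dataset = UCF101(root="~/datasets/ucf101", train=False, split=1)
print(len(dataset.datapoints), len(dataset.class_to_idx))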