def get_sample_names(split, split_num):
    """
    Args:
        split (str): split ("train" or "test")
        split_num (int): split plan num
    Return:
        list of strings: a list of sample names of all samples of the
        given split plan.
    """
    assert isinstance(split, str), TypeError
    assert split in ["train", "test"], ValueError("Invalid split name")
    assert isinstance(split_num, int), TypeError
    assert split_num in (1, 2, 3), ValueError("Invalid split number")

    split_file = "{}list{:02d}.txt".format(split, split_num)
    split_path = os.path.join(metadata_dir_path, split_file)
    # download metadata if there is no local cache
    if not os.path.exists(split_path):
        split_src = os.path.join(DOWNLOAD_SERVER_PREFIX, DOWNLOAD_SRC_DIR,
                                 metadata_dir, split_file)
        download(split_src, split_path)

    results = []
    with open(split_path, "r") as fin:
        for _line in fin:
            _text = _line.split('\n')[0]  # remove trailing '\n'
            _path = _text.split(' ')[0]   # remove class id
            _file = _path.split('/')[1]   # remove directory
            _name = _file.split('.')[0]   # remove file extension
            results.append(_name)
    return results
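
# A minimal standalone sketch of the per-line parsing performed above,
# using an illustrative UCF101-style split line of the form
# "<class dir>/<file>.avi <class id>" (the sample line below is an
# assumption for demonstration, not taken from the actual split files):
def _parse_split_line(line):
    """Strip newline, class id, directory, and extension from a split line."""
    text = line.split('\n')[0]
    path = text.split(' ')[0]
    fname = path.split('/')[1]
    return fname.split('.')[0]

assert _parse_split_line(
    "ApplyEyeMakeup/v_ApplyEyeMakeup_g08_c01.avi 1\n"
) == "v_ApplyEyeMakeup_g08_c01"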
def __init__(self, root, train, split=1,
             transform=None, target_transform=None):
    root = os.path.expanduser(root)
    super(HMDB51, self).__init__(root=root, transform=transform,
                                 target_transform=target_transform)

    # -------------------- #
    #   load datapoints    #
    # -------------------- #
    # assemble paths
    if not (os.path.exists(CACHE_DIR) and os.path.isdir(CACHE_DIR)):
        os.makedirs(CACHE_DIR, exist_ok=True)
    if train:
        datapoint_file_name = "hmdb51_training_split{}.pkl".format(split)
    else:
        datapoint_file_name = "hmdb51_testing_split{}.pkl".format(split)
    datapoint_file_path = os.path.join(CACHE_DIR, datapoint_file_name)
    # download when missing
    if not os.path.exists(datapoint_file_path):
        print("downloading HMDB51 datapoints...")
        download(src=os.path.join(DOWNLOAD_SERVER_PREFIX,
                                  DOWNLOAD_SRC_DIR,
                                  datapoint_file_name),
                 dst=datapoint_file_path,
                 backend="rsync")
    # real load
    with open(datapoint_file_path, "rb") as fin:
        self.datapoints = pickle.load(fin)
        assert isinstance(self.datapoints, list), TypeError
        assert isinstance(self.datapoints[0], DataPoint), TypeError
    # replace dataset root
    for dp in self.datapoints:
        dp.root = root
        dp._path = dp.path

    # ------------------ #
    #  load class_to_idx #
    # ------------------ #
    # download labels
    label_file = "hmdb51_labels.txt"
    label_path = os.path.join(CACHE_DIR, label_file)
    if not os.path.exists(label_path):
        print("downloading HMDB51 labels...")
        label_src = os.path.join(DOWNLOAD_SERVER_PREFIX,
                                 DOWNLOAD_SRC_DIR,
                                 label_file)
        download(label_src, label_path, backend="rsync")
    # build class label to class id mapping (a dictionary)
    self.class_to_idx = collections.OrderedDict()
    with open(label_path, "r") as fin:
        for _line in fin:
            text = _line.split('\n')[0]
            text = text.split(' ')
            self.class_to_idx[text[1]] = int(text[0]) - 1
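
# A minimal usage sketch, assuming access to the rsync server above;
# the root path is a hypothetical example, not part of the source:
dataset = HMDB51(root="~/datasets/hmdb51", train=True, split=1)
print(len(dataset.datapoints), "training datapoints")
print(list(dataset.class_to_idx.items())[:3])  # first few (label, id) pairs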
def test_video2ndarray_mp4():
    vpath = os.path.join(DIR_PATH, "test.mp4")
    if not os.path.exists(vpath):
        mp4_src = os.path.join(DOWNLOAD_SERVER_PREFIX, DOWNLOAD_SRC_DIR,
                               "test.mp4")
        download(mp4_src, vpath)
    # read video to varray
    varray = backend.video2ndarray(vpath, cin="BGR", cout="RGB")
    print(varray.shape)
def test_ndarray2video_mp4():
    vpath = os.path.join(DIR_PATH, "test.mp4")
    if not os.path.exists(vpath):
        mp4_src = os.path.join(DOWNLOAD_SERVER_PREFIX, DOWNLOAD_SRC_DIR,
                               "test.mp4")
        download(mp4_src, vpath)
    # read video to varray
    varray = backend.video2ndarray(vpath, cin="BGR", cout="RGB")
    # dump varray to video
    vpath = os.path.join(DIR_PATH, "ndarray2video.mp4")
    backend.ndarray2video(varray, vpath)
def test_ndarray2frames_mp4():
    vpath = os.path.join(DIR_PATH, "test.mp4")
    if not os.path.exists(vpath):
        mp4_src = os.path.join(DOWNLOAD_SERVER_PREFIX, DOWNLOAD_SRC_DIR,
                               "test.mp4")
        download(mp4_src, vpath)
    # read video to varray
    varray = backend.video2ndarray(vpath, cin="BGR", cout="RGB")
    # dump varray to frames
    fdir = os.path.join(DIR_PATH, "ndarray2frames.frames.mp4.d")
    ret, f_n = backend.ndarray2frames(varray, fdir, cin="RGB", cout="BGR")
    print('Dumping frames from varray finished, {} frames'.format(f_n))
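
# A hedged round-trip sketch with a synthetic clip; the [T, H, W, C]
# uint8 varray layout is an assumption inferred from the tests above,
# not documented behavior:
import numpy as np

varray = np.random.randint(0, 256, size=(8, 64, 64, 3), dtype=np.uint8)

out_path = os.path.join(DIR_PATH, "synthetic.mp4")
backend.ndarray2video(varray, out_path)

frame_dir = os.path.join(DIR_PATH, "synthetic.frames.d")
ret, f_n = backend.ndarray2frames(varray, frame_dir, cin="RGB", cout="BGR")
assert f_n == varray.shape[0]  # assumes f_n counts frames written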
if not os.path.exists(metadata_dir_path):
    os.makedirs(metadata_dir_path, exist_ok=True)
else:
    assert os.path.isdir(metadata_dir_path)

# download metadata if there is no local cache
label_file = "classInd.txt"
label_path = os.path.join(metadata_dir_path, label_file)
if not os.path.exists(label_path):
    label_src = os.path.join(
        DOWNLOAD_SERVER_PREFIX, DOWNLOAD_SRC_DIR,
        metadata_dir, label_file
    )
    download(label_src, label_path)

# build class label to class id mapping (a dictionary)
class_to_idx = collections.OrderedDict()
with open(label_path, "r") as fin:
    for _line in fin:
        text = _line.split('\n')[0]
        text = text.split(' ')
        class_to_idx[text[1]] = int(text[0]) - 1

if __name__ == "__main__":
    print(class_to_idx)
    cls2idx_path = os.path.join(DIR_PATH, "ucf101_class_to_idx.pkl")
    with open(cls2idx_path, "wb") as fout:
        # cache the mapping so the UCF101 dataset can download/load it
        pickle.dump(class_to_idx, fout)
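
# For reference, a classInd.txt line has the form "<1-based id> <class name>";
# the loop above shifts the id to 0-based. A tiny standalone check of that
# parsing (the sample line is illustrative):
_line = "1 ApplyEyeMakeup\n"
_text = _line.split('\n')[0].split(' ')
assert (_text[1], int(_text[0]) - 1) == ("ApplyEyeMakeup", 0)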
import glob
import os

from torchstream.utils.download import download

DOWNLOAD_SERVER_PREFIX = ("[email protected]:"
                          "/home/eecs/zhen/video-acc/download")
DOWNLOAD_SRC_DIR = "tools/datasets/metadata/hmdb51"

FILE_PATH = os.path.realpath(__file__)
DIR_PATH = os.path.dirname(FILE_PATH)

# download split annotations
split_annot_dir = "testTrainMulti_7030_splits"
split_dir_path = os.path.join(DIR_PATH, split_annot_dir)
if not os.path.exists(split_dir_path):
    split_dir_src = os.path.join(DOWNLOAD_SERVER_PREFIX, DOWNLOAD_SRC_DIR,
                                 split_annot_dir)
    download(split_dir_src, split_dir_path)


def test_sample_names(split_num):
    """
    Args:
        split_num (int): split plan num
    Return:
        list of strings: a list of sample names of all testing samples of
        the given split plan.
    """
    assert isinstance(split_num, int), TypeError
    assert split_num in (1, 2, 3), ValueError("Invalid split number")

    glob_str = os.path.join(split_dir_path,
                            "*_test_split{}.txt".format(split_num))
    results = []
    # each annotation file covers one class; a line looks like
    # "<video file> <tag>", where tag 2 marks a testing clip
    # (1 = training, 0 = unused)
    for split_file in glob.glob(glob_str):
        with open(split_file, "r") as fin:
            for _line in fin:
                _file, _tag = _line.split()
                if _tag == "2":
                    results.append(_file.split('.')[0])
    return results
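
# HMDB51 split annotation lines have the form "<video file> <tag>", where
# tag 1 marks training, 2 testing, and 0 unused clips. A standalone check
# of the filter used above (the sample line is illustrative):
_line = "April_09_brush_hair_u_nm_np1_ba_goo_0.avi 2\n"
_file, _tag = _line.split()
assert _tag == "2"
assert _file.split('.')[0] == "April_09_brush_hair_u_nm_np1_ba_goo_0"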
def __init__(self, root, train, split=1, class_to_idx=None, ext="avi",
             transform=None, target_transform=None):
    root = os.path.expanduser(root)
    super(UCF101, self).__init__(root=root, transform=transform,
                                 target_transform=target_transform)

    # -------------------- #
    #   load datapoints    #
    # -------------------- #
    # assemble paths
    if not (os.path.exists(CACHE_DIR) and os.path.isdir(CACHE_DIR)):
        os.makedirs(CACHE_DIR, exist_ok=True)
    if train:
        datapoint_filename = "ucf101_training_split{}.pkl".format(split)
    else:
        datapoint_filename = "ucf101_testing_split{}.pkl".format(split)
    datapoint_filepath = os.path.join(CACHE_DIR, datapoint_filename)
    # download when missing
    if not os.path.exists(datapoint_filepath):
        print("downloading UCF101 datapoints...")
        download(src=os.path.join(DOWNLOAD_SERVER_PREFIX,
                                  DOWNLOAD_SRC_DIR,
                                  datapoint_filename),
                 dst=datapoint_filepath,
                 backend="rsync")
    # real load
    with open(datapoint_filepath, "rb") as fin:
        self.datapoints = pickle.load(fin)
        assert isinstance(self.datapoints, list), TypeError
        assert isinstance(self.datapoints[0], DataPoint), TypeError
    # replace dataset root
    for dp in self.datapoints:
        dp.root = root
        dp.ext = ext
        dp._path = dp.path

    # ------------------ #
    #  load class_to_idx #
    # ------------------ #
    if class_to_idx is not None:
        self.class_to_idx = class_to_idx
    else:
        class_to_idx_filename = "ucf101_class_to_idx.pkl"
        class_to_idx_filepath = os.path.join(CACHE_DIR,
                                             class_to_idx_filename)
        # download when missing
        if not os.path.exists(class_to_idx_filepath):
            print("downloading UCF101 class_to_idx...")
            download(src=os.path.join(DOWNLOAD_SERVER_PREFIX,
                                      DOWNLOAD_SRC_DIR,
                                      class_to_idx_filename),
                     dst=class_to_idx_filepath,
                     backend="rsync")
        # load class_to_idx
        with open(class_to_idx_filepath, "rb") as fin:
            self.class_to_idx = pickle.load(fin)
        # sanity check
        # print(self.class_to_idx)
        assert isinstance(self.class_to_idx, dict), TypeError
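
# A minimal usage sketch, mirroring the HMDB51 example; the root path is
# hypothetical and downloads require access to the rsync server above:
dataset = UCF101(root="~/datasets/ucf101", train=False, split=1)
print(len(dataset.datapoints), "testing datapoints")
# 0 if the pickled mapping follows classInd.txt ordering (an assumption)
print(dataset.class_to_idx["ApplyEyeMakeup"])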