def get_train_length(wordnet_ids):
    """Return the total number of training examples across the given classes.

    The per-class counts come from ``get_tar_length`` applied to each class's
    train tar file. Progress is reported through an ``IncrementalBar``.

    Args:
        wordnet_ids: sized sequence of wordnet ids, one train tar per id.

    Returns:
        The sum of ``get_tar_length`` over every id's tar path.
    """
    train_mode = get_mode('train')
    total = 0
    logger.info('Getting number of train examples...')
    progress = IncrementalBar(max=len(wordnet_ids))
    for wid in wordnet_ids:
        tar_path = get_tar_path(train_mode, wid)
        total += get_tar_length(tar_path)
        progress.next()
    progress.finish()
    return total
def get_test_records(delete_tar=False):
    """Yield ``(image, filename)`` for each file stored in the test tar.

    Args:
        delete_tar: if truthy, the tar file is removed via ``maybe_delete``
            after every member has been yielded and the archive is closed.
            Nothing is deleted if the generator is not fully consumed.

    Yields:
        ``(image, filename)``. ``image`` is the file object returned by
        ``TarFile.extractfile``; ``filename`` is the member name in the
        archive. ``image`` is closed as soon as the consumer advances (or
        closes) the generator, so read it fully before requesting the next
        item.
    """
    mode = get_mode('test')
    with tarfile.TarFile(get_tar_path(mode)) as tar:
        for member in tar.getmembers():
            image = tar.extractfile(member)
            # extractfile returns None for directories and other non-regular
            # members; there is no image data to yield for those.
            # NOTE(review): if get_tar_length counts such members, the example
            # count used by the caller would not match -- confirm.
            if image is None:
                continue
            try:
                yield image, member.name
            finally:
                # Runs on normal advance and on generator close/abandonment.
                image.close()
    # Delete only once the archive handle is closed; some platforms refuse
    # to remove a file that is still open.
    maybe_delete(mode, None, delete_tar)
def get_train_records(wordnet_ids, delete_tar=False):
    """Yield ``(image, filename, target)`` for every file in each class tar.

    Args:
        wordnet_ids: iterable of wordnet ids; each id has its own train tar.
        delete_tar: if truthy, each class tar is removed via ``maybe_delete``
            after all of its members have been yielded and it is closed.

    Yields:
        ``(image, filename, target)``. ``image`` is the file object from
        ``TarFile.extractfile``, ``filename`` the member name, and ``target``
        the value ``meta.get_wordnet_indices(zero_based=ZERO_BASED)`` maps the
        wordnet id to (presumably a class index whose base depends on
        ``ZERO_BASED`` -- confirm). ``image`` is closed when the generator
        is advanced or closed.

    Raises:
        KeyError: if a wordnet id is missing from the class-index mapping.
    """
    class_indices = meta.get_wordnet_indices(zero_based=ZERO_BASED)
    mode = get_mode('train')
    for wordnet_id in wordnet_ids:
        target = class_indices[wordnet_id]
        with tarfile.TarFile(get_tar_path(mode, wordnet_id)) as tar:
            for member in tar.getmembers():
                image = tar.extractfile(member)
                # extractfile returns None for non-regular members
                # (directories, etc.); skip them instead of yielding None.
                if image is None:
                    continue
                try:
                    yield image, member.name, target
                finally:
                    # Close even if the consumer stops iterating early.
                    image.close()
        # Delete after the archive is closed; some platforms cannot remove
        # an open file.
        maybe_delete(mode, wordnet_id, delete_tar)
def get_val_records(delete_tar=False):
    """Yield ``(image, filename, target)`` for each file in the validation tar.

    The label for a file is looked up from the numeric suffix of its name,
    e.g. ``ILSVRC2012_val_00000001.JPEG`` -> index 1 -> ``labels[0]``.

    Args:
        delete_tar: if truthy, the tar file is removed via ``maybe_delete``
            after every member has been yielded and the archive is closed.

    Yields:
        ``(image, filename, target)`` where ``image`` is the file object from
        ``TarFile.extractfile`` (closed when the generator advances),
        ``filename`` the member name and ``target`` the entry of
        ``load_val_labels(zero_based=ZERO_BASED)`` for that file.

    Raises:
        ValueError: if a member name does not end in ``_<positive int>``
            before its extension.
    """
    import os
    from imagenet.meta import load_val_labels
    mode = get_mode('val')
    labels = load_val_labels(zero_based=ZERO_BASED)
    with tarfile.TarFile(get_tar_path(mode)) as tar:
        for member in tar.getmembers():
            image = tar.extractfile(member)
            # extractfile returns None for non-regular members; they carry
            # no image (and no parsable index), so skip them.
            if image is None:
                continue
            try:
                filename = member.name
                # Strip the extension generically instead of assuming a
                # 5-character ".JPEG"; the number follows the last '_'.
                stem = os.path.splitext(filename)[0]
                index = int(stem.rsplit('_', 1)[-1])
                # Indices are 1-based; 0 would silently wrap to labels[-1].
                if index < 1:
                    raise ValueError(
                        'Invalid validation index in "%s"' % filename)
                target = labels[index - 1]
                yield image, filename, target
            finally:
                image.close()
    # Delete after the archive is closed; some platforms cannot remove an
    # open file.
    maybe_delete(mode, None, delete_tar)
def convert_other(mode, delete_tar, overwrite=False):
    """Convert the 'val' or 'test' tar archive into the output file.

    Args:
        mode: split name, normalised through ``get_mode``; must resolve to
            ``'val'`` or ``'test'``.
        delete_tar: forwarded to the record generator; if truthy the source
            tar is deleted once all records have been consumed.
        overwrite: forwarded to ``check_overwrite``.

    Raises:
        ValueError: if ``mode`` is not 'val' or 'test'. This is raised
            before any overwrite check, tar scan or output file is opened,
            so an invalid mode no longer creates/truncates an output file.
    """
    mode = get_mode(mode)
    # Validate first: previously the output file was opened for writing
    # before the invalid-mode error was raised.
    if mode == 'val':
        include_targets, make_records = True, get_val_records
    elif mode == 'test':
        include_targets, make_records = False, get_test_records
    else:
        raise ValueError('Invalid mode: "%s"' % mode)
    check_overwrite(mode, overwrite)
    n_examples = get_tar_length(get_tar_path(mode))
    with get_file(mode, 'w') as fp:
        write_examples(
            fp, n_examples, include_targets=include_targets,
            records=make_records(delete_tar), shuffle=False)
def maybe_delete(mode, wordnet_id, delete_tar):
    """Delete the tar file for ``(mode, wordnet_id)`` when ``delete_tar`` is truthy.

    Args:
        mode: split passed through to ``get_tar_path``.
        wordnet_id: class id passed to ``get_tar_path`` (``None`` for
            splits stored in a single tar).
        delete_tar: when falsy this function does nothing.
    """
    if not delete_tar:
        return
    path = get_tar_path(mode, wordnet_id)
    logger.info('Removing tar file "%s"' % path)
    os.remove(path)