def webvision_subcls(train=True, cls=50, per_cls=-1):
    """Load (image path, label) pairs for the first ``cls`` WebVision classes.

    Reads a file list from ``<root>/webvision/info``:
    ``train_filelist_google.txt`` when ``train`` is true, otherwise
    ``val_filelist.txt``. Each line is expected to be
    ``<relative image path> <integer label>``.

    Args:
        train: read the training list if true, the validation list otherwise.
        cls: number of leading classes to keep. Reading stops at the first
            entry whose label is ``>= cls``. This relies on the file list
            being sorted by label.
        per_cls: currently ignored; kept for interface compatibility.
            NOTE(review): presumably meant as a per-class sample cap -- confirm.

    Returns:
        ``(xs, ys)``: an ``llist`` of absolute image paths and a numpy array
        of integer labels, both in file order.
    """
    dataset_root = os.path.abspath(os.path.join(root, 'webvision'))
    if train:
        list_name = 'train_filelist_google.txt'
        image_dir = dataset_root
    else:
        list_name = 'val_filelist.txt'
        # Validation images live in a dedicated sub-directory.
        image_dir = os.path.join(dataset_root, 'val_images_256')

    xs = []
    ys = []
    with open(os.path.join(dataset_root, 'info', list_name),
              'r',
              encoding='utf-8') as r:
        for line in r:
            # Split only on the last whitespace run, so file names that
            # contain spaces stay intact.
            parts = line.rsplit(maxsplit=1)
            if not parts:
                continue  # tolerate blank lines
            x, y = parts
            y = int(y)
            # The list is assumed to be sorted by label, so the first label
            # outside [0, cls) ends the requested classes.
            if y >= cls:
                break
            xs.append(os.path.join(image_dir, x))
            ys.append(y)

    return llist(xs), np.array(ys)
# Example #2
# 0
def test_llist():
    """Check that ``llist`` accepts numpy arrays as indices."""
    import numpy as np
    from thexp.base_classes import llist

    values = llist([1, 2, 3, 4])

    # A 1-D integer array selects several items at once.
    assert values[np.array([0, 1])] == [1, 2]

    # A 0-d array acts like a plain integer index.
    assert values[np.array(1)] == 2
# Example #3
# 0
def all_memory_cached(device_id: Union[List[int], int] = None, process=False):
    """Return per-GPU memory usage parsed from the ``nvidia-smi`` table.

    Parsing works as follows:

    - The first 7 output lines are skipped.
    - Each later line whose index satisfies ``i % 3 == 2`` is treated as a
      GPU row.
    - From each row, the first group of the second ``"<a> / <b>"`` match is
      collected. This is presumably the used value of the Memory-Usage
      column, e.g. ``"1234MiB"``. It depends on nvidia-smi's table layout,
      so verify it for your driver version.
    - Parsing stops once ``device_count()`` rows have been read.

    Args:
        device_id: optional index, or list of indices as supported by
            ``llist`` indexing, selecting which GPUs to return.
        process: if true, reduce each entry to the list of digit runs it
            contains (e.g. ``"1234MiB"`` -> ``["1234"]``).

    Returns:
        An ``llist`` of strings, or a single string when ``device_id`` is an
        int. With ``process=True``, each string becomes a list of digit
        strings.
    """
    from thexp.base_classes import llist
    # Raw string: '/' needs no escaping in a regex, and '\/' in a normal
    # string literal is an invalid escape sequence.
    match_mem = re.compile(r'([0-9]+[a-zA-Z]+) / ([0-9]+[a-zA-Z]+)')
    # Read everything inside the context manager so the pipe is closed and
    # the child process is waited for.
    with subprocess.Popen(['nvidia-smi'], stdout=subprocess.PIPE) as proc:
        lines = proc.stdout.readlines()

    count = device_count()
    res = llist()
    for i, line in enumerate(lines):
        if i < 7:
            continue  # presumably the nvidia-smi table header

        if i % 3 == 2:
            line = line.decode('utf-8').strip()
            res_ = re.findall(match_mem, line)
            res.append(res_[1][0])
            count -= 1
        if count == 0:
            break

    if device_id is not None:
        res = res[device_id]

    if process:
        # TODO: to be verified
        match_num = re.compile(r'([0-9]+)')
        if isinstance(res, str):
            # A single int device_id yields one string, not a sequence.
            res = re.findall(match_num, res)
        else:
            for i, mem in enumerate(res):
                res[i] = re.findall(match_num, mem)

    return res
# Example #4
# 0
def train_val_split(target, val_size=10000, train_size=None):
    """Randomly split the indices of ``target`` into train and validation sets.

    Only ``len(target)`` is used; the elements themselves are never read.
    The permutation comes from numpy's global RNG (``np.random.shuffle``),
    so seed it beforehand for a reproducible split.

    Args:
        target: any sized collection (list, array, ...) to split.
        val_size: number of indices placed in the validation set.
        train_size: number of indices placed in the training set. If
            ``None``, every index not used for validation goes to training.

    Returns:
        ``(train_idx, val_idx)``: two disjoint integer numpy arrays.

    Raises:
        AssertionError: if ``train_size`` is given and
            ``val_size + train_size`` exceeds ``len(target)``.
    """
    import numpy as np
    size = len(target)
    idx = np.arange(size)
    np.random.shuffle(idx)

    if train_size is not None:
        # An exact fit (val_size + train_size == size) is a valid split.
        assert size >= val_size + train_size, (
            "train_size + val_size should be at most {}, but got {}".format(
                size, train_size + val_size))
        return idx[val_size:val_size + train_size], idx[:val_size]

    return idx[val_size:], idx[:val_size]
# Example #5
# 0
def cifar10(train=True):
    """Load CIFAR-10 as PIL images plus integer labels.

    Args:
        train: forwarded to ``CIFAR10(train=...)`` to choose the split.

    Returns:
        ``(xs, ys)``: an ``llist`` of ``PIL.Image`` objects built from
        ``dataset.data``, and a 1-D numpy array of ``int`` labels.
    """
    dataset = CIFAR10(root=root, train=train)
    xs = llist(Image.fromarray(i) for i in dataset.data)
    # np.array() does not consume generators. Passing one directly yields a
    # 0-d object array wrapping the generator, so build a list first.
    ys = np.array([int(i) for i in dataset.targets])

    return xs, ys
# Example #6
# 0
def fashionmnist(train=True):
    """Load Fashion-MNIST as grayscale PIL images plus integer labels.

    Args:
        train: forwarded to ``FashionMNIST(train=...)`` to choose the split.

    Returns:
        ``(xs, ys)``: an ``llist`` of mode-``'L'`` PIL images and an
        ``llist`` of ``int`` labels.
    """
    dataset = FashionMNIST(root=root, train=train)

    def to_image(sample):
        # Each sample is converted with .numpy() before PIL sees it.
        return Image.fromarray(sample.numpy(), mode='L')

    images = llist(map(to_image, dataset.data))
    labels = llist(map(int, dataset.targets))
    return images, labels
# Example #7
# 0
def svhn(train=True):
    """Load SVHN as PIL images plus integer labels.

    Args:
        train: selects ``split='train'`` when true, ``split='test'`` otherwise.

    Returns:
        ``(xs, ys)``: an ``llist`` of PIL images and an ``llist`` of ``int``
        labels.
    """
    split_name = 'train' if train else 'test'
    dataset = SVHN(root=root, split=split_name)

    def to_image(sample):
        # Axes are permuted (1, 2, 0): presumably channel-first samples are
        # moved to channel-last layout for PIL -- confirm against SVHN data.
        return Image.fromarray(np.transpose(sample, (1, 2, 0)))

    images = llist(to_image(sample) for sample in dataset.data)
    labels = llist(int(label) for label in dataset.labels)
    return images, labels