Exemple #1
0
def split_train_val(dataset: TaskSet,
                    val_split: float = 0.1,
                    seed: int = 1) -> Tuple[TaskSet, TaskSet]:
    """Split a dataset into two disjoint datasets, one for training and one for validation.

    Sample indexes are shuffled with a seeded ``np.random.RandomState``, so a
    given ``seed`` always produces the same partition for a dataset of a given
    length. The first ``int(val_split * len(dataset))`` shuffled indexes go to
    validation, the remaining ones to training.

    :param dataset: A dataset exposing ``__len__``, ``get_raw_samples``, ``trsf``
                    and ``data_type``.
    :param val_split: Fraction to allocate for validation, between [0, 1[.
    :param seed: Seed of the random generator used to shuffle the indexes.
    :return: A tuple of datasets, respectively for train and validation.
    :raises ValueError: If ``val_split`` is not in [0, 1[.
    """
    # Reject out-of-range fractions up front: a negative value would silently
    # swap the roles of the two slices and a value >= 1 would leave the
    # training set empty. The chained comparison also rejects NaN.
    if not 0 <= val_split < 1:
        raise ValueError(
            "val_split must be in [0, 1[, got {!r}.".format(val_split)
        )

    random_state = np.random.RandomState(seed=seed)
    indexes = np.arange(len(dataset))
    random_state.shuffle(indexes)

    # Compute the boundary once so both slices are guaranteed to agree on it.
    nb_val = int(val_split * len(indexes))
    train_indexes = indexes[nb_val:]
    val_indexes = indexes[:nb_val]

    x_train, y_train, t_train = dataset.get_raw_samples(train_indexes)
    train_dataset = TaskSet(x_train,
                            y_train,
                            t_train,
                            trsf=dataset.trsf,
                            data_type=dataset.data_type)

    x_val, y_val, t_val = dataset.get_raw_samples(val_indexes)
    val_dataset = TaskSet(x_val,
                          y_val,
                          t_val,
                          trsf=dataset.trsf,
                          data_type=dataset.data_type)

    return train_dataset, val_dataset
Exemple #2
0
def taskset_subset(taskset: TaskSet, indices: np.ndarray) -> TaskSet:
    """Build a TaskSet restricted to the samples selected by ``indices``.

    The raw ``x``, ``y`` and ``t`` values come from
    ``taskset.get_raw_samples(indices)``; the result is assembled by
    ``replace_taskset_attributes`` (defined elsewhere) from the original
    ``taskset`` and the selected values.
    """
    raw_x, raw_y, raw_t = taskset.get_raw_samples(indices)

    # TODO: Not sure if/how to handle the `bounding_boxes` attribute here.
    # For now it is indexed with the same `indices` when present, and passed
    # through as None otherwise.
    all_boxes = taskset.bounding_boxes
    selected_boxes = None if all_boxes is None else all_boxes[indices]

    return replace_taskset_attributes(
        taskset,
        x=raw_x,
        y=raw_y,
        t=raw_t,
        bounding_boxes=selected_boxes,
    )
Exemple #3
0
def test_get_raw_samples(nb_samples):
    """Check that ``get_raw_samples`` returns the first ``nb_samples`` entries.

    A 10-sample TaskSet is built and queried with ``range(nb_samples)``; each
    returned array must equal the matching leading slice of its input.
    """
    total = 10
    xs = np.ones((total, 2, 2, 3))
    ys = np.ones((total, ))
    ts = np.ones((total, ))

    task_set = TaskSet(xs, ys, ts, None)
    got_x, got_y, got_t = task_set.get_raw_samples(indexes=range(nb_samples))

    # NOTE(review): every fixture value is 1, so these checks cannot tell
    # *which* rows were returned, only that the compared arrays are equal
    # element-wise.
    for source, returned in ((xs, got_x), (ys, got_y), (ts, got_t)):
        assert (source[:nb_samples] == returned).all()