Ejemplo n.º 1
0
def balance_dataset_by_repeating(dataset,
                                 num_classes,
                                 target_size,
                                 upsample=True):
    """Return a class-balanced list of indices into ``dataset``.

    Indices are grouped per class with ``get_balanced_sample_indices`` (using
    ``len(dataset)`` as the per-class cap). Every class is then resampled to
    the same count, ``num_samples_per_class``. Positions within a class are
    produced by ``torch.randperm(count) % class_size``:
    - Classes smaller than the count repeat their members.
    - Classes larger than the count keep only their first ``count`` members
      (in shuffled order).

    Args:
        dataset: Dataset whose targets are read via ``get_targets``.
        num_classes: Number of classes; ``target_size`` is split evenly
            across them as ``target_size // num_classes``.
        target_size: Desired overall size of the balanced index list.
        upsample: When true, the per-class count is the larger of the biggest
            class size and the per-class share; otherwise the smaller of the two.

    Returns:
        A flat list of dataset indices, ``num_samples_per_class`` per class,
        grouped in the order the classes are yielded.
    """
    per_class_indices = get_balanced_sample_indices(
        get_targets(dataset), num_classes, len(dataset)).values()

    largest_class = max(len(members) for members in per_class_indices)
    share_per_class = target_size // num_classes
    # With upsample the count is at least the largest class size;
    # without it, the count never exceeds the largest class size.
    choose = max if upsample else min
    num_samples_per_class = choose(largest_class, share_per_class)

    resampled = []
    for members in per_class_indices:
        # Positions cycle through the class, so small classes repeat members.
        positions = (torch.randperm(num_samples_per_class) % len(members)).tolist()
        resampled.extend(members[pos] for pos in positions)

    print(
        f"Resampled dataset ({len(dataset)} samples) to a balanced set of {len(resampled)} samples!"
    )

    return resampled
Ejemplo n.º 2
0
def test_get_balanced_samples():
    """Check that get_balanced_sample_indices picks exactly two distinct
    samples for every one of the 47 classes, each carrying its class label."""
    num_classes = 47
    n_per_class = 2
    labels = torch.randint(0, num_classes, (1000000, ))
    ranges = torch_utils.get_balanced_sample_indices(labels, num_classes,
                                                     n_per_class)

    # With 1M uniformly drawn labels every class is present, so every class
    # must appear in the result; otherwise the per-class loop below could
    # pass vacuously (e.g. on an empty dict).
    assert {int(digit) for digit in ranges} == set(range(num_classes)), \
        "Not every class received samples"

    for digit, samples in ranges.items():
        assert len(samples) == n_per_class, f"Failed for digit class {digit}"
        # The chosen samples within a class must not repeat.
        assert len({int(sample) for sample in samples}) == n_per_class, \
            f"Duplicate samples for digit class {digit}"
        assert bool((labels[samples] == digit).all()), \
            f"Failed for digit class {digit}"
Ejemplo n.º 3
0
def get_experiment_data(
    data_source,
    num_classes,
    initial_samples,
    reduced_dataset,
    samples_per_class,
    validation_set_size,
    balanced_test_set,
    balanced_validation_set,
):
    """Assemble the active/pool/validation/test datasets for an experiment.

    Args:
        data_source: Object exposing ``train_dataset``, ``test_dataset``,
            ``validation_dataset`` (may be ``None``) and the optional
            ``shared_transform``, ``train_transform`` and ``scoring_transform``.
        num_classes: Number of target classes.
        initial_samples: Indices to acquire into the active set first. When
            ``None``, ``samples_per_class`` indices per class are drawn from
            the training set via ``get_balanced_sample_indices``.
        reduced_dataset: If true, shrink the pool, the test set and the
            validation set for quick runs.
        samples_per_class: Per-class count used for the default initial samples.
        validation_set_size: Desired validation set size. A falsy value
            (0 or ``None``) means "use the test set size" when the validation
            set is split off the training data. It means "use the whole
            provided set" when the data source supplies one.
        balanced_test_set: If true, resample the test set to be class-balanced.
        balanced_validation_set: If true, build or shrink the validation set
            with class balancing instead of random selection.

    Returns:
        ExperimentData with the active-learning split, the (possibly
        transformed) train/available/validation/test datasets and the
        initial sample indices.
    """
    train_dataset, test_dataset, validation_dataset = (
        data_source.train_dataset,
        data_source.test_dataset,
        data_source.validation_dataset,
    )

    active_learning_data = ActiveLearningData(train_dataset)
    if initial_samples is None:
        initial_samples = list(
            itertools.chain.from_iterable(
                get_balanced_sample_indices(
                    get_targets(train_dataset),
                    num_classes=num_classes,
                    n_per_digit=samples_per_class).values()))

    # Split off the validation dataset after acquiring the initial samples
    # (presumably acquire() removes them from the pool that the validation
    # split is drawn from below — confirm in ActiveLearningData).
    active_learning_data.acquire(initial_samples)

    if validation_dataset is None:
        print("Acquiring validation set from training set.")
        if not validation_set_size:
            validation_set_size = len(test_dataset)

        if not balanced_validation_set:
            validation_dataset = active_learning_data.extract_dataset(
                validation_set_size)
        else:
            print("Using a balanced validation set")
            validation_dataset = active_learning_data.extract_dataset_from_indices(
                balance_dataset_by_repeating(
                    active_learning_data.available_dataset,
                    num_classes,
                    validation_set_size,
                    upsample=False))
    else:
        # `not` also covers None (as in the branch above); a None size would
        # otherwise fail the `<` comparison below.
        if not validation_set_size:
            print("Using provided validation set.")
            validation_set_size = len(validation_dataset)
        if validation_set_size < len(validation_dataset):
            print("Shrinking provided validation set.")
            if not balanced_validation_set:
                validation_dataset = data.Subset(
                    validation_dataset,
                    torch.randperm(len(validation_dataset))
                    [:validation_set_size].tolist())
            else:
                print("Using a balanced validation set")
                # upsample=False caps each class at
                # validation_set_size // num_classes. With upsampling, every
                # class would be padded to at least the largest class size,
                # so the "shrunk" set would not be smaller than the original.
                validation_dataset = data.Subset(
                    validation_dataset,
                    balance_dataset_by_repeating(validation_dataset,
                                                 num_classes,
                                                 validation_set_size,
                                                 upsample=False),
                )

    if balanced_test_set:
        print("Using a balanced test set")
        print("Distribution of original test set classes:")
        classes = get_target_bins(test_dataset)
        print(classes)

        # Upsampling here is intentional: every class is padded to the
        # largest class size (or len(test_dataset) // num_classes).
        test_dataset = data.Subset(
            test_dataset,
            balance_dataset_by_repeating(test_dataset, num_classes,
                                         len(test_dataset)))

    if reduced_dataset:
        # Extract (and discard) most of the pool, sized against the *full*
        # training set length. Roughly max(len(train) // 20, 5000) points
        # stay in play.
        # NOTE(review): this assumes the initial samples and the validation
        # split taken above are small enough to leave the pool non-empty.
        active_learning_data.extract_dataset(
            len(train_dataset) - max(len(train_dataset) // 20, 5000))
        test_dataset = subrange_dataset.SubrangeDataset(
            test_dataset, 0, max(len(test_dataset) // 10, 5000))
        if validation_dataset:
            validation_dataset = subrange_dataset.SubrangeDataset(
                validation_dataset, 0,
                len(validation_dataset) // 10)
        print("USING REDUCED DATASET!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")

    show_class_frequencies = True
    if show_class_frequencies:
        print("Distribution of training set classes:")
        classes = get_target_bins(train_dataset)
        print(classes)

        print("Distribution of validation set classes:")
        classes = get_target_bins(validation_dataset)
        print(classes)

        print("Distribution of test set classes:")
        classes = get_target_bins(test_dataset)
        print(classes)

        print("Distribution of pool classes:")
        classes = get_target_bins(active_learning_data.available_dataset)
        print(classes)

        print("Distribution of active set classes:")
        classes = get_target_bins(active_learning_data.active_dataset)
        print(classes)

    print("Dataset info:")
    print(f"\t{len(active_learning_data.active_dataset)} active samples")
    print(f"\t{len(active_learning_data.available_dataset)} available samples")
    print(f"\t{len(validation_dataset)} validation samples")
    print(f"\t{len(test_dataset)} test samples")

    # Training data gets train_transform followed by shared_transform.
    if data_source.shared_transform is not None or data_source.train_transform is not None:
        train_dataset = TransformedDataset(
            active_learning_data.active_dataset,
            vision_transformer=compose_transformers(
                [data_source.train_transform, data_source.shared_transform]),
        )
    else:
        train_dataset = active_learning_data.active_dataset

    # Pool data (used for scoring) gets scoring_transform then shared_transform.
    if data_source.shared_transform is not None or data_source.scoring_transform is not None:
        available_dataset = TransformedDataset(
            active_learning_data.available_dataset,
            vision_transformer=compose_transformers(
                [data_source.scoring_transform, data_source.shared_transform]),
        )
    else:
        available_dataset = active_learning_data.available_dataset

    # Evaluation sets only get the shared transform.
    if data_source.shared_transform is not None:
        test_dataset = TransformedDataset(
            test_dataset, vision_transformer=data_source.shared_transform)
        validation_dataset = TransformedDataset(
            validation_dataset,
            vision_transformer=data_source.shared_transform)

    return ExperimentData(
        active_learning_data=active_learning_data,
        train_dataset=train_dataset,
        available_dataset=available_dataset,
        validation_dataset=validation_dataset,
        test_dataset=test_dataset,
        initial_samples=initial_samples,
    )
Ejemplo n.º 4
0
def main():
    """Command-line entry point: train a model on a fixed subset without active learning.

    Parses the arguments and opens the results store (recording the arguments
    in it). Then it builds the experiment data, acquires the training subset
    (random or class-balanced) and hands the data loaders to the dataset's
    ``train_model``.
    """
    # Training settings
    parser = argparse.ArgumentParser(
        description="Pure training loop without AL",
        formatter_class=functools.partial(argparse.ArgumentDefaultsHelpFormatter, width=120),
    )
    parser.add_argument("--batch_size", type=int, default=64, help="input batch size for training")
    parser.add_argument("--scoring_batch_size", type=int, default=256, help="input batch size for scoring")
    parser.add_argument("--test_batch_size", type=int, default=256, help="input batch size for testing")
    parser.add_argument("--validation_set_size", type=int, default=128, help="validation set size")
    parser.add_argument(
        "--early_stopping_patience", type=int, default=1, help="# patience epochs for early stopping per iteration"
    )
    parser.add_argument("--epochs", type=int, default=30, help="number of epochs to train")
    parser.add_argument(
        "--epoch_samples",
        type=int,
        default=5056,
        help="minimum number of training samples per epoch (smaller training sets are resampled to this length)",
    )
    parser.add_argument("--quickquick", action="store_true", default=False, help="uses a very reduced dataset")
    parser.add_argument(
        "--balanced_validation_set",
        action="store_true",
        default=False,
        help="uses a balanced validation set (instead of randomly picked) "
        "when splitting it off the training set or shrinking a provided one",
    )
    parser.add_argument("--num_inference_samples", type=int, default=5, help="number of samples for inference")
    parser.add_argument("--no_cuda", action="store_true", default=False, help="disables CUDA training")
    parser.add_argument(
        "--name", type=str, default="results", help="name for the results file (name of the experiment)"
    )
    parser.add_argument("--seed", type=int, default=1, help="random seed")
    parser.add_argument(
        "--train_dataset_limit",
        type=int,
        default=0,
        help="how much of the training set to use for training after splitting off the validation set (0 for all)",
    )
    parser.add_argument(
        "--balanced_training_set",
        action="store_true",
        default=False,
        help="uses a balanced training set (instead of randomly picked)",
    )
    parser.add_argument(
        "--balanced_test_set",
        action="store_true",
        default=False,
        help="force balances the test set---use with CAUTION!",
    )
    parser.add_argument(
        "--log_interval", type=int, default=10, help="how many batches to wait before logging training status"
    )
    parser.add_argument(
        "--dataset",
        type=DatasetEnum,
        default=DatasetEnum.mnist,
        help=f"dataset to use (options: {[f.name for f in DatasetEnum]})",
    )
    args = parser.parse_args()

    store = laaos.create_file_store(
        args.name,
        suffix="",
        truncate=False,
        type_handlers=(blackhc.laaos.StrEnumHandler(), blackhc.laaos.ToReprHandler()),
    )
    store["args"] = args.__dict__

    print(args.__dict__)

    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    print(f"Using {device} for computations")

    kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}

    dataset: DatasetEnum = args.dataset

    data_source = dataset.get_data_source()

    reduced_train_length = args.train_dataset_limit

    experiment_data = get_experiment_data(
        data_source,
        dataset.num_classes,
        None,  # initial_samples: derived from samples_per_class
        args.quickquick,  # reduced_dataset: previously hard-coded False, ignoring --quickquick
        0,  # samples_per_class: start with an empty active set
        args.validation_set_size,
        args.balanced_test_set,
        args.balanced_validation_set,
    )

    # 0 means "use the whole remaining pool".
    if not reduced_train_length:
        reduced_train_length = len(experiment_data.available_dataset)

    print(f"Training with reduced dataset of {reduced_train_length} data points")
    if not args.balanced_training_set:
        experiment_data.active_learning_data.acquire(
            experiment_data.active_learning_data.get_random_available_indices(reduced_train_length)
        )
    else:
        print("Using a balanced training set.")
        num_samples_per_class = reduced_train_length // dataset.num_classes
        # Indices are computed against the available (pool) dataset.
        experiment_data.active_learning_data.acquire(
            list(
                itertools.chain.from_iterable(
                    torch_utils.get_balanced_sample_indices(
                        get_targets(experiment_data.available_dataset), dataset.num_classes, num_samples_per_class
                    ).values()
                )
            )
        )

    # Small training sets are sampled up to epoch_samples per epoch.
    if len(experiment_data.train_dataset) < args.epoch_samples:
        sampler = RandomFixedLengthSampler(experiment_data.train_dataset, args.epoch_samples)
    else:
        sampler = data.RandomSampler(experiment_data.train_dataset)

    test_loader = torch.utils.data.DataLoader(
        experiment_data.test_dataset, batch_size=args.test_batch_size, shuffle=False, **kwargs
    )
    train_loader = torch.utils.data.DataLoader(
        experiment_data.train_dataset, sampler=sampler, batch_size=args.batch_size, **kwargs
    )

    validation_loader = torch.utils.data.DataLoader(
        experiment_data.validation_dataset, batch_size=args.test_batch_size, shuffle=False, **kwargs
    )

    def desc(name):
        return lambda engine: "%s" % name

    dataset.train_model(
        train_loader=train_loader,
        test_loader=test_loader,
        validation_loader=validation_loader,
        num_inference_samples=args.num_inference_samples,
        max_epochs=args.epochs,
        early_stopping_patience=args.early_stopping_patience,
        desc=desc,
        log_interval=args.log_interval,
        device=device,
        epoch_results_store=store,
    )