def split_dataset(db, dataset, build_config, prefill_dataset=None):
    """
    Randomly selects tracks to be used as the train, validation, and test sets.

    Tracks are grouped into 'label-camera-day' bins (``dataset.tracks_by_bin``)
    and whole bins are assigned to a split: bins chosen for evaluation are
    handed to both the validation and the test dataset, and every other bin
    (heavy ones included) is added to train.

    :param db: passed to ``Dataset`` when creating the three splits.
    :param dataset: source dataset; its ``labels`` and ``tracks_by_bin`` are read.
    :param build_config: supplies ``cap_bin_weight``,
        ``max_validation_set_track_duration``, ``test_set_count`` and
        ``test_set_bins``.
    :param prefill_dataset: if given, passed to ``prefill_bins`` to seed the
        validation and test sets before the remaining quota is filled randomly.
    :return: tuple containing train, validation, and test datasets.
    """

    # pick out groups to use for the various sets; used_bins holds, per label,
    # the bins already claimed for validation/test
    bins_by_label = {}
    used_bins = {}
    for label in dataset.labels:
        bins_by_label[label] = []
        used_bins[label] = []

    # a bin is labelled by its first track; counts holds the segments per bin
    counts = []
    for bin_id, tracks in dataset.tracks_by_bin.items():
        bins_by_label[tracks[0].label].append(bin_id)
        counts.append(sum(len(track.segments) for track in tracks))

    train = Dataset(db, "train")
    validation = Dataset(db, "validation")
    test = Dataset(db, "test")

    # per-bin segment cap: mean + cap_bin_weight standard deviations
    bin_segment_mean = np.mean(counts)
    bin_segment_std = np.std(counts)
    max_bin_segments = bin_segment_mean + bin_segment_std * build_config.cap_bin_weight

    print_bin_segment_stats(bin_segment_mean, bin_segment_std, max_bin_segments)

    max_track_duration = build_config.max_validation_set_track_duration
    if prefill_dataset is not None:
        # NOTE(review): presumably records the reused bins in used_bins so they
        # are excluded from the random assignment below — confirm in prefill_bins.
        prefill_bins(
            dataset,
            [validation, test],
            prefill_dataset,
            used_bins,
            max_bin_segments,
            max_track_duration,
        )

    required_samples = build_config.test_set_count
    required_bins = max(MIN_BINS, build_config.test_set_bins)

    # assign bins to test and validation sets
    # if we previously added bins from another dataset we are simply filling in the gaps here.
    for label in dataset.labels:
        available_bins = set(bins_by_label[label]) - set(used_bins[label])

        normal_bins, heavy_bins = dataset.split_heavy_bins(
            available_bins, max_bin_segments, max_track_duration
        )

        print_bin_stats(label, normal_bins, heavy_bins, used_bins)

        # NOTE(review): assumes add_random_samples removes the bins it assigns
        # to validation/test from normal_bins; otherwise those bins would also
        # be added to train below — confirm.
        add_random_samples(
            dataset,
            [validation, test],
            normal_bins,
            used_bins[label],
            label,
            required_samples,
            required_bins,
        )

        # everything not used for evaluation (heavy bins included) trains
        normal_bins.extend(heavy_bins)
        for bin_id in normal_bins:
            train.add_tracks(dataset.tracks_by_bin[bin_id])

    # if we have lots of segments on a single day, reduce the weight
    # so we don't overtrain on this specific example.
    train.balance_bins(max_bin_segments)
    validation.balance_bins(max_bin_segments)
    # balance out the classes
    train.balance_weights()
    validation.balance_weights()

    test.balance_resample(required_samples=build_config.test_set_count)

    print_segments(dataset, train, validation, test)

    return train, validation, test
# Example #2
# 0
def split_dataset_days(prefill_bins=None):
    """
    Randomly selects tracks to be used as the train, validation, and test sets.

    Tracks are grouped into 'label-camera-day' bins and whole bins are assigned
    to a split: bins chosen for evaluation are added to BOTH the validation and
    the test set, and every remaining bin (heavy bins included) goes to train.
    Reads the module-level ``dataset`` and ``db`` globals.

    :param prefill_bins: optional mapping of label -> bin ids from a previous
        split. Those bins are reused for validation/test when they still exist,
        are not heavy and still belong to that label; the remaining quota is
        then filled randomly.
    :return: tuple containing train, validation, and test datasets.
    """

    # pick out groups to use for the various sets; a bin is labelled by the
    # label of its first track
    bins_by_label = {label: [] for label in dataset.labels}
    for bin_id, tracks in dataset.tracks_by_bin.items():
        label = dataset.track_by_id[tracks[0].track_id].label
        bins_by_label[label].append(bin_id)

    train = Dataset(db, 'train')
    validation = Dataset(db, 'validation')
    test = Dataset(db, 'test')

    train.labels = dataset.labels.copy()
    validation.labels = dataset.labels.copy()
    test.labels = dataset.labels.copy()

    # check bins distribution (mean/std are order independent, no sort needed)
    counts = [
        sum(len(track.segments) for track in tracks)
        for tracks in dataset.tracks_by_bin.values()
    ]
    bin_segment_mean = np.mean(counts)
    bin_segment_std = np.std(counts)
    # cap used for is_heavy_bin and balance_bins: mean + CAP_BIN_WEIGHT std devs
    max_bin_segments = bin_segment_mean + bin_segment_std * CAP_BIN_WEIGHT

    print()
    print("Bin segment mean:{:.1f} std:{:.1f} auto max segments:{:.1f}".format(
        bin_segment_mean, bin_segment_std, max_bin_segments))
    print()

    # bins already assigned to validation/test, per label
    used_bins = {label: [] for label in dataset.labels}

    if prefill_bins is not None:
        print("Reusing bins from previous split:")
        for label in dataset.labels:
            if label not in prefill_bins:
                continue
            unclaimed = set(bins_by_label[label])
            for sample in prefill_bins[label]:
                # this happens if we have banned/deleted the clip, but it was previously used.
                if sample not in dataset.tracks_by_bin:
                    continue
                # this happens if we changed what a 'heavy' bin is.
                if is_heavy_bin(sample, max_bin_segments):
                    continue
                # the bin now belongs to a different label, or is listed twice.
                if sample not in unclaimed:
                    continue

                validation.add_tracks(dataset.tracks_by_bin[sample])
                test.add_tracks(dataset.tracks_by_bin[sample])
                validation.filter_segments(TEST_MIN_MASS,
                                           ignore_labels=['false-positive'])
                test.filter_segments(TEST_MIN_MASS,
                                     ignore_labels=['false-positive'])

                unclaimed.remove(sample)
                used_bins[label].append(sample)
            # Unclaimed bins are deliberately NOT added to train here: the loop
            # below may still pick some of them for validation/test, and it
            # adds every remaining bin to train itself.

    # assign bins to test and validation sets
    # if we previously added bins from another dataset we are simply filling in the gaps here.
    for label in dataset.labels:

        available_bins = set(bins_by_label[label])

        # heavy bins are bins with an unusually high number of examples on a day.  We exclude these from the test/validation
        # set as they will be subfiltered down and there is no need to waste that much data.
        heavy_bins = {
            bin_id for bin_id in available_bins
            if is_heavy_bin(bin_id, max_bin_segments)
        }

        available_bins -= heavy_bins
        available_bins -= set(used_bins[label])

        # print bin statistics
        print("{}: normal {} heavy {} pre-filled {}".format(
            label, len(available_bins), len(heavy_bins),
            len(used_bins[label])))

        label_weight = LABEL_WEIGHTS.get(label, 1.0)
        required_samples = TEST_SET_COUNT * label_weight
        # make sure there is some diversity: at least 4 bins per label
        required_bins = max(4, TEST_SET_BINS * label_weight)

        # we assign bins to the test and validation sets randomly until we have enough segments + bins
        # the remaining bins can be used for training
        while available_bins and (
                validation.get_class_segments_count(label) < required_samples
                or len(used_bins[label]) < required_bins):

            # random.sample() rejects sets on Python 3.11+; sorting gives a
            # sequence and reproducible picks under a fixed random seed.
            sample = random.choice(sorted(available_bins))

            validation.add_tracks(dataset.tracks_by_bin[sample])
            test.add_tracks(dataset.tracks_by_bin[sample])

            validation.filter_segments(TEST_MIN_MASS,
                                       ignore_labels=['false-positive'])
            test.filter_segments(TEST_MIN_MASS,
                                 ignore_labels=['false-positive'])

            available_bins.remove(sample)
            used_bins[label].append(sample)

            if prefill_bins is not None:
                print(" - required added additional sample ", sample)

        # every bin not used for validation/test (heavy bins included) trains
        available_bins.update(heavy_bins)

        for bin_id in available_bins:
            train.add_tracks(dataset.tracks_by_bin[bin_id])
            train.filter_segments(TRAIN_MIN_MASS,
                                  ignore_labels=['false-positive'])

    # if we have lots of segments on a single day, reduce the weight so we don't overtrain on this specific
    # example.
    train.balance_bins(max_bin_segments)
    validation.balance_bins(max_bin_segments)

    # balance out the classes
    train.balance_weights(weight_modifiers=LABEL_WEIGHTS)
    validation.balance_weights(weight_modifiers=LABEL_WEIGHTS)
    test.balance_resample(weight_modifiers=LABEL_WEIGHTS,
                          required_samples=TEST_SET_COUNT)

    # display the dataset summary; header printed directly above the rows and
    # using the same column widths so the table lines up
    print("Segments per class:")
    print("-" * 90)
    print("{:<20} {:<20} {:<20} {:<20}".format("Class", "Train", "Validation",
                                               "Test"))
    print("-" * 90)
    for label in dataset.labels:
        print("{:<20} {:<20} {:<20} {:<20}".format(
            label,
            "{}/{}/{}/{:.1f}".format(*train.get_counts(label)),
            "{}/{}/{}/{:.1f}".format(*validation.get_counts(label)),
            "{}/{}/{}/{:.1f}".format(*test.get_counts(label)),
        ))
    print()

    return train, validation, test