import os
import pickle
import random

import numpy as np


def main():
    global dataset
    global db

    db = TrackDatabase(os.path.join(DATASET_FOLDER, 'dataset.hdf5'))
    dataset = Dataset(db, 'dataset')

    total_tracks = len(db.get_all_track_ids())
    tracks_loaded = dataset.load_tracks(track_filter)

    print("Loaded {}/{} tracks, found {:.1f}k segments".format(
        tracks_loaded, total_tracks, len(dataset.segments) / 1000))

    # filtered_stats holds module-level counters of why tracks were rejected
    # while they were being loaded.
    for key, value in filtered_stats.items():
        if value != 0:
            print("  {} filtered {}".format(key, value))
    print()

    labels = sorted(set(dataset.tracks_by_label.keys()))
    dataset.labels = labels

    show_tracks_breakdown()
    print()
    show_segments_breakdown()
    print()
    show_cameras_breakdown()
    print()

    print("Splitting data set into train / validation")
    if USE_PREVIOUS_SPLIT:
        split = get_bin_split('template.dat')
        datasets = split_dataset_days(split)
    else:
        datasets = split_dataset_days()

    with open(os.path.join(DATASET_FOLDER, 'datasets.dat'), 'wb') as f:
        pickle.dump(datasets, f)
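
# For reference: track_filter and filtered_stats are defined elsewhere in this
# module. Below is a minimal, hypothetical sketch of that pair, based only on
# how main() uses them. The names are prefixed with _sketch_ so they do not
# shadow the real definitions, and the signature, keys and threshold are
# illustrative assumptions, not the module's actual rules.
_sketch_filtered_stats = {'confidence': 0, 'date': 0, 'tags': 0}

def _sketch_track_filter(clip_meta, track_meta):
    """Return True to skip a track, recording the reason for the summary report."""
    if track_meta.get('tag') is None:            # untagged tracks are unusable
        _sketch_filtered_stats['tags'] += 1
        return True
    if track_meta.get('confidence', 1.0) < 0.8:  # illustrative confidence threshold
        _sketch_filtered_stats['confidence'] += 1
        return True
    return False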
def split_dataset_days(prefill_bins=None):
    """
    Randomly selects tracks to be used as the train, validation, and test sets.
    :param prefill_bins: if given, these bins are reused for the validation and test sets
    :return: tuple containing train, validation, and test datasets.

    This method assigns tracks to 'label-camera-day' bins and splits the bins
    across the datasets, so all tracks from one camera on one day land in the
    same set.
    """

    # group the bins by label
    bins_by_label = {}
    for label in dataset.labels:
        bins_by_label[label] = []

    for bin_id, tracks in dataset.tracks_by_bin.items():
        label = dataset.track_by_id[tracks[0].track_id].label
        bins_by_label[label].append(bin_id)

    train = Dataset(db, 'train')
    validation = Dataset(db, 'validation')
    test = Dataset(db, 'test')

    train.labels = dataset.labels.copy()
    validation.labels = dataset.labels.copy()
    test.labels = dataset.labels.copy()

    # check the bin distribution and derive the cap used to spot 'heavy' bins
    bin_segments = []
    for bin_id, tracks in dataset.tracks_by_bin.items():
        count = sum(len(track.segments) for track in tracks)
        bin_segments.append((count, bin_id))
    bin_segments.sort()

    counts = [count for count, bin_id in bin_segments]
    bin_segment_mean = np.mean(counts)
    bin_segment_std = np.std(counts)
    max_bin_segments = bin_segment_mean + bin_segment_std * CAP_BIN_WEIGHT

    print()
    print("Bin segment mean:{:.1f} std:{:.1f} auto max segments:{:.1f}".format(
        bin_segment_mean, bin_segment_std, max_bin_segments))
    print()

    used_bins = {}
    for label in dataset.labels:
        used_bins[label] = []

    if prefill_bins is not None:
        print("Reusing bins from previous split:")
        for label in dataset.labels:
            available_bins = set(bins_by_label[label])
            if label not in prefill_bins:
                continue
            for sample in prefill_bins[label]:
                # this happens if we have banned/deleted the clip, but it was previously used.
                if sample not in dataset.tracks_by_bin:
                    continue
                # this happens if we changed what a 'heavy' bin is.
                if is_heavy_bin(sample, max_bin_segments):
                    continue

                validation.add_tracks(dataset.tracks_by_bin[sample])
                test.add_tracks(dataset.tracks_by_bin[sample])

                validation.filter_segments(TEST_MIN_MASS, ignore_labels=['false-positive'])
                test.filter_segments(TEST_MIN_MASS, ignore_labels=['false-positive'])

                available_bins.remove(sample)
                used_bins[label].append(sample)

            # any bins the previous split did not use go straight to training
            for bin_id in available_bins:
                train.add_tracks(dataset.tracks_by_bin[bin_id])
                train.filter_segments(TRAIN_MIN_MASS, ignore_labels=['false-positive'])

    # assign bins to the test and validation sets
    # if we prefilled bins from a previous split we are simply filling in the gaps here.
    for label in dataset.labels:

        available_bins = set(bins_by_label[label])

        # heavy bins are bins with an unusually high number of examples on a single day.
        # We exclude these from the test/validation sets, as they would be
        # sub-sampled down anyway and there is no need to waste that much data.
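        # Worked example with illustrative numbers (CAP_BIN_WEIGHT is defined
        # elsewhere in this module): if bin_segment_mean is 20, bin_segment_std
        # is 30 and CAP_BIN_WEIGHT is 1.5, then max_bin_segments is
        # 20 + 30 * 1.5 = 65, and any label-camera-day bin holding more than
        # 65 segments is treated as heavy.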
        heavy_bins = set()
        for bin_id in available_bins:
            if is_heavy_bin(bin_id, max_bin_segments):
                heavy_bins.add(bin_id)

        available_bins -= heavy_bins
        available_bins -= set(used_bins[label])

        # print bin statistics
        print("{}: normal {} heavy {} pre-filled {}".format(
            label, len(available_bins), len(heavy_bins), len(used_bins[label])))

        required_samples = TEST_SET_COUNT * LABEL_WEIGHTS.get(label, 1.0)
        required_bins = TEST_SET_BINS * LABEL_WEIGHTS.get(label, 1.0)
        # make sure there is some diversity
        required_bins = max(4, required_bins)

        # assign bins to the test and validation sets at random until we have
        # enough segments and enough bins; the remaining bins go to training.
        while len(available_bins) > 0 and \
                (validation.get_class_segments_count(label) < required_samples or
                 len(used_bins[label]) < required_bins):

            # random.sample no longer accepts a set, so convert first
            sample = random.sample(list(available_bins), 1)[0]

            validation.add_tracks(dataset.tracks_by_bin[sample])
            test.add_tracks(dataset.tracks_by_bin[sample])

            validation.filter_segments(TEST_MIN_MASS, ignore_labels=['false-positive'])
            test.filter_segments(TEST_MIN_MASS, ignore_labels=['false-positive'])

            available_bins.remove(sample)
            used_bins[label].append(sample)

            if prefill_bins is not None:
                print(" - required additional sample ", sample)

        # heavy bins are still fine for training
        available_bins.update(heavy_bins)

        for bin_id in available_bins:
            train.add_tracks(dataset.tracks_by_bin[bin_id])
            train.filter_segments(TRAIN_MIN_MASS, ignore_labels=['false-positive'])

    print("Segments per class:")
    print("-" * 90)
    print("{:<20} {:<20} {:<20} {:<20}".format("Class", "Train", "Validation", "Test"))
    print("-" * 90)

    # if a single day contributes a lot of segments, reduce their weight so we
    # don't over-train on that specific example.
    train.balance_bins(max_bin_segments)
    validation.balance_bins(max_bin_segments)

    # balance out the classes
    train.balance_weights(weight_modifiers=LABEL_WEIGHTS)
    validation.balance_weights(weight_modifiers=LABEL_WEIGHTS)
    test.balance_resample(weight_modifiers=LABEL_WEIGHTS, required_samples=TEST_SET_COUNT)

    # display the dataset summary; each cell shows the four values returned by
    # Dataset.get_counts() for that label.
    for label in dataset.labels:
        print("{:<20} {:<20} {:<20} {:<20}".format(
            label,
            "{}/{}/{}/{:.1f}".format(*train.get_counts(label)),
            "{}/{}/{}/{:.1f}".format(*validation.get_counts(label)),
            "{}/{}/{}/{:.1f}".format(*test.get_counts(label)),
        ))
    print()

    return train, validation, test
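
# For reference: is_heavy_bin() is defined elsewhere in this module. Below is a
# minimal sketch of its assumed behaviour, matching how max_bin_segments is
# derived in split_dataset_days(); it is prefixed with _sketch_ so it does not
# shadow the real helper, and is a sketch rather than the actual implementation.
def _sketch_is_heavy_bin(bin_id, max_bin_segments):
    """A bin is 'heavy' when its tracks carry more segments than the cap."""
    bin_segments = sum(len(track.segments) for track in dataset.tracks_by_bin[bin_id])
    return bin_segments > max_bin_segments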