def main():
    init_logging()
    config = load_config()
    build_config = config.build
    db = TrackDatabase(os.path.join(config.tracks_folder, "dataset.hdf5"))
    dataset = Dataset(db, "dataset", config)
    tracks_loaded, total_tracks = dataset.load_tracks()
    print(
        "Loaded {}/{} tracks, found {:.1f}k segments".format(
            tracks_loaded, total_tracks, len(dataset.segments) / 1000
        )
    )
    for key, value in dataset.filtered_stats.items():
        if value != 0:
            print(" {} filtered {}".format(key, value))
    print()
    show_tracks_breakdown(dataset)
    print()
    show_segments_breakdown(dataset)
    print()
    show_cameras_breakdown(dataset)
    print()

    print("Splitting data set into train / validation")
    datasets = split_dataset_by_cameras(db, dataset, build_config)
    # if build_config.use_previous_split:
    #     split = get_previous_validation_bins(build_config.previous_split)
    #     datasets = split_dataset(db, dataset, build_config, split)
    # else:
    #     datasets = split_dataset(db, dataset, build_config)
    pickle.dump(datasets, open(dataset_db_path(config), "wb"))
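# A minimal sketch of reading the pickled output back, assuming split_dataset_by_cameras()
# returns a (train, validation, test) tuple of Dataset objects as the training code expects.
# The helper name load_split_datasets is hypothetical and only here for illustration.
def load_split_datasets(config):
    # dataset_db_path() resolves the same pickle file that main() above writes.
    with open(dataset_db_path(config), "rb") as f:
        return pickle.load(f)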
def train_model(run_name, conf, hyper_params):
    """Trains a model with the given hyper parameters."""
    run_name = os.path.join("train", run_name)

    # A little bit of a pain: the model needs to know how many classes to classify during
    # initialisation, but we don't load the dataset until after that, so we load it here
    # just to count the number of labels...
    datasets_filename = dataset_db_path(conf)
    with open(datasets_filename, "rb") as f:
        dsets = pickle.load(f)
    labels = dsets[0].labels

    model = ModelCRNN_LQ(labels=len(labels), train_config=conf.train, **hyper_params)
    model.import_dataset(datasets_filename)

    # display the data set summary
    print("Training on labels", labels)
    print()
    print(
        "{:<20} {:<20} {:<20} {:<20} (segments/tracks/bins/weight)".format(
            "label", "train", "validation", "test"
        )
    )
    for label in labels:
        print(
            "{:<20} {:<20} {:<20} {:<20}".format(
                label,
                "{}/{}/{}/{:.1f}".format(*model.datasets.train.get_counts(label)),
                "{}/{}/{}/{:.1f}".format(*model.datasets.validation.get_counts(label)),
                "{}/{}/{}/{:.1f}".format(*model.datasets.test.get_counts(label)),
            )
        )
    print()
    for dataset in dsets:
        print(dataset.labels)

    print("Training started")
    print("---------------------")
    print("Hyper parameters")
    print("---------------------")
    print(model.hyperparams_string)
    print()
    print("Found {0:.1f}K training examples".format(model.rows / 1000))
    print()

    model.train_model(
        epochs=conf.train.epochs,
        run_name=run_name + " " + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"),
    )
    model.save()
    model.close()

    # This shouldn't be necessary, but unfortunately model.close() isn't cleaning up everything.
    # I think it's because everything gets added to the default graph.
    tf.reset_default_graph()

    return model
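# A minimal usage sketch, assuming a Config loaded the same way as in main() above; the run
# name and hyper parameter values are hypothetical examples, not defaults from the project.
def run_training_example():
    conf = load_config()
    hyper_params = {"batch_size": 16, "learning_rate": 1e-4}  # illustrative values only
    return train_model("lq-baseline", conf, hyper_params)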
def save_eval_model(args):
    config = Config.load_from_file()
    datasets_filename = dataset_db_path(config)
    with open(datasets_filename, "rb") as f:
        dsets = pickle.load(f)
    labels = ["hedgehog", "false-positive", "possum", "rodent", "bird"]

    # this needs to be the same class as the source model
    model = ModelCRNN_HQ(
        labels=len(labels),
        train_config=config.train,
        training=True,
        tflite=True,
        **config.train.hyper_params,
    )
    model.saver = tf.compat.v1.train.Saver(max_to_keep=1000)
    model.restore_params(os.path.join(args.model_dir, args.model_name))
    model.save(os.path.join(args.model_dir, "eval-model"))
    model.setup_summary_writers("convert")
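# A small sketch of calling save_eval_model() directly. The attribute names model_dir and
# model_name match what the function reads above, but the paths are placeholders and the
# helper name convert_example is hypothetical.
import argparse

def convert_example():
    args = argparse.Namespace(model_dir="models/train", model_name="checkpoint-final")
    save_eval_model(args)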
def main():
    init_logging()
    args = parse_args()
    config = load_config(args.config_file)
    db = TrackDatabase(os.path.join(config.tracks_folder, "dataset.hdf5"))
    dataset = Dataset(
        db, "dataset", config, consecutive_segments=args.consecutive_segments
    )
    tracks_loaded, total_tracks = dataset.load_tracks(before_date=args.date)
    print(
        "Loaded {}/{} tracks, found {:.1f}k segments".format(
            tracks_loaded, total_tracks, len(dataset.segments) / 1000
        )
    )
    for key, value in dataset.filtered_stats.items():
        if value != 0:
            print(" {} filtered {}".format(key, value))
    print()
    show_tracks_breakdown(dataset)
    print()
    show_segments_breakdown(dataset)
    print()
    show_important_frames_breakdown(dataset)
    print()
    show_cameras_breakdown(dataset)
    print()

    print("Splitting data set into train / validation")
    datasets = split_dataset_by_cameras(db, dataset, config, args)
    if args.date is None:
        args.date = datetime.datetime.now(pytz.utc) - datetime.timedelta(days=7)
    test = test_dataset(db, config, args.date)
    datasets = (*datasets, test)

    print_counts(dataset, *datasets)
    print_cameras(*datasets)

    pickle.dump(datasets, open(dataset_db_path(config), "wb"))
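# A hedged sketch of a parse_args() that supplies the attributes main() reads
# (config_file, consecutive_segments, date). The actual flag names, defaults, and date
# format used by the project may differ; parse_args_sketch is a hypothetical name to
# avoid implying this is the real CLI.
import argparse

def parse_args_sketch():
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config-file", help="Path to the config file")
    parser.add_argument(
        "--consecutive-segments",
        action="store_true",
        help="Use consecutive segments when building the dataset",
    )
    parser.add_argument(
        "-d",
        "--date",
        type=lambda s: datetime.datetime.strptime(s, "%Y-%m-%d"),
        help="Only load tracks recorded before this date (YYYY-MM-DD)",
    )
    return parser.parse_args()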