def train_places_low_shot(
    k_values: List[int],
    sample_inds: List[int],
    output_dir: str,
    layername: str,
    cfg: AttrDict,
):
    """
    Train and test low-shot SVMs for every low-shot sample in parallel, then
    aggregate the statistics across samples.

    One child process is started per entry of ``sample_inds``. Each process runs
    ``train_sample_places_low_shot`` with that ``sample_num``. The parent waits
    for all of them before aggregating.

    Args:
        k_values (List[int]): low-shot k values to evaluate.
        sample_inds (List[int]): indices of the independent low-shot samples.
        output_dir (str): directory passed to the SVM trainer and to every worker
            (presumably where extracted features live and results are written).
        layername (str): name of the feature layer being evaluated.
        cfg (AttrDict): full config; ``cfg["SVM"]`` configures the trainer.

    Returns:
        Whatever ``SVMLowShotTrainer.aggregate_stats(k_values, sample_inds)``
        returns.

    Raises:
        RuntimeError: if any worker process exits with a non-zero exit code.
    """
    low_shot_trainer = SVMLowShotTrainer(
        cfg["SVM"], layer=layername, output_dir=output_dir
    )
    # One independent process per low-shot sample.
    running_tasks = [
        mp.Process(
            target=train_sample_places_low_shot,
            args=(
                low_shot_trainer,
                k_values,
                sample_inds,
                sample_num,
                output_dir,
                layername,
                cfg,
            ),
        )
        for sample_num in sample_inds
    ]
    for running_task in running_tasks:
        running_task.start()
    for running_task in running_tasks:
        running_task.join()

    # Process.join() does not propagate failures from the child. Without this
    # check a crashed worker would go unnoticed, and the stats below would be
    # aggregated from incomplete per-sample results.
    failed_samples = [
        sample_num
        for sample_num, running_task in zip(sample_inds, running_tasks)
        if running_task.exitcode != 0
    ]
    if failed_samples:
        raise RuntimeError(
            f"Low-shot training failed for layer {layername}, "
            f"samples: {failed_samples}"
        )

    results = low_shot_trainer.aggregate_stats(k_values, sample_inds)
    logging.info(f"All Done for layer: {layername}")
    return results
def train_voc07_low_shot(
    k_values: List[int],
    sample_inds: List[int],
    output_dir: str,
    layername: str,
    cfg: AttrDict,
):
    """
    Train and test low-shot SVMs sequentially for every (sample, k) pair, then
    aggregate the statistics across samples.

    The train and test features for ``layername`` are merged from
    ``output_dir``. Low-shot target subsets are then generated for the
    training split only, and every model is evaluated on the full test split.

    Args:
        k_values (List[int]): low-shot k values to evaluate.
        sample_inds (List[int]): indices of the independent low-shot samples.
        output_dir (str): directory the features are merged from and the
            per-sample/per-k target files are loaded from.
        layername (str): name of the feature layer being evaluated.
        cfg (AttrDict): full config; ``cfg["SVM"]`` configures the trainer and
            ``cfg["SVM"]["low_shot"]["dataset_name"]`` selects how the low-shot
            samples are generated.

    Returns:
        Whatever ``SVMLowShotTrainer.aggregate_stats(k_values, sample_inds)``
        returns (the original comment describes it as mean/min/max/std stats
        per k-value).
    """
    dataset_name = cfg["SVM"]["low_shot"]["dataset_name"]
    low_shot_trainer = SVMLowShotTrainer(
        cfg["SVM"], layer=layername, output_dir=output_dir
    )
    train_data = merge_features(output_dir, "train", layername)
    train_features, train_targets = train_data["features"], train_data["targets"]
    test_data = merge_features(output_dir, "test", layername)
    test_features, test_targets = test_data["features"], test_data["targets"]

    # Low-shot samples are created only for training; testing always uses the
    # full test split.
    generate_low_shot_samples(
        dataset_name, train_targets, k_values, sample_inds, output_dir, layername
    )

    # Train and test the low-shot SVM for every sample and k-value.
    for sample_num in sample_inds:
        for low_shot_kvalue in k_values:
            # Presumably written by generate_low_shot_samples above (TODO
            # confirm). It is kept under a separate name so the full training
            # targets are not overwritten by the low-shot subset.
            low_shot_targets = load_file(
                f"{output_dir}/{layername}_sample{sample_num}_k{low_shot_kvalue}.npy"
            )
            low_shot_trainer.train(
                train_features, low_shot_targets, sample_num, low_shot_kvalue
            )
            low_shot_trainer.test(
                test_features, test_targets, sample_num, low_shot_kvalue
            )

    # Aggregate the stats across all independent samples for each k-value.
    results = low_shot_trainer.aggregate_stats(k_values, sample_inds)
    logging.info("All Done!")
    return results