Example 1
def calculate_emissions(function, textfile):

    tracker = EmissionsTracker()
    tracker.start()
    print(function, "starting")

    function(textfile)
    emissions: float = tracker.stop()
    print(function, textfile, " : ")
    print(f"Emissions: {emissions} kg \n\n")
Example 2
def carbon(request):
    try:
        from codecarbon import EmissionsTracker
    except ImportError:
        yield True  # we do nothing
        return

    tracker = EmissionsTracker("lkpy-tests", 5)
    tracker.start()
    try:
        yield True
    finally:
        emissions = tracker.stop()
        _log.info('test suite used %.3f kgCO2eq', emissions)
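The generator above follows the pytest fixture protocol (yield once, clean up in a finally block); a sketch of how it would typically be registered in conftest.py, assuming pytest (the session scope and autouse flag are assumptions, not taken from the original):

import pytest

@pytest.fixture(scope="session", autouse=True)
def carbon():
    # Body as in the example above: start an EmissionsTracker when codecarbon is
    # installed, yield to the test session, then stop the tracker and log the result.
    yield True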
Example 3
def main():
    mnist = tf.keras.datasets.mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0

    model = KerasClassifier(build_fn=build_model, epochs=1)
    param_grid = dict(batch_size=list(range(32, 256 + 32, 32)))
    grid = GridSearchCV(estimator=model, param_grid=param_grid)

    tracker = EmissionsTracker(project_name="mnist_grid_search")
    tracker.start()
    grid_result = grid.fit(x_train, y_train)
    emissions = tracker.stop()

    print(
        f"Best Accuracy : {grid_result.best_score_} using {grid_result.best_params_}"
    )
    print(f"Emissions : {emissions} kg CO₂")
Example 4
def main():
    mnist = tf.keras.datasets.mnist

    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0

    tuner = RandomSearchTuner(
        build_model,
        objective="val_accuracy",
        directory="random_search_results",
        project_name="codecarbon",
        max_trials=3,
    )

    tracker = EmissionsTracker(project_name="mnist_random_search")
    tracker.start()
    tuner.search(x_train, y_train, epochs=10, validation_data=(x_test, y_test))
    emissions = tracker.stop()

    print(f"Emissions : {emissions} kg CO₂")
Example 5
class CodeCarbonCallBack(tf.keras.callbacks.Callback):
    def __init__(self, codecarbon_tracker):
        super().__init__()
        self.codecarbon_tracker = codecarbon_tracker

    def on_epoch_end(self, epoch, logs=None):
        self.codecarbon_tracker.flush()


mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10),
])

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])

tracker = EmissionsTracker()
tracker.start()
codecarbon_cb = CodeCarbonCallBack(tracker)
model.fit(x_train, y_train, epochs=4, callbacks=[codecarbon_cb])
emissions: float = tracker.stop()
print(f"Emissions: {emissions} kg")
Example 6
        optimizer.step()

    cnn.eval()

    # Measure validation accuracy at each epoch
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in loaders["val"]:
            data = images.to(device)
            test_output = cnn(data)  # forward pass on the batch moved to the device
            pred_y = torch.max(test_output, 1)[1].data.squeeze().cpu()
            correct += float((pred_y == labels).sum().item())
            total += float(labels.size(0))
        print(f"\nValidation Accuracy: {correct / total:.3f}")

# Measure final test accuracy
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in loaders["test"]:
        data = images.to(device)
        test_output = cnn(data)  # forward pass on the batch moved to the device
        pred_y = torch.max(test_output, 1)[1].data.squeeze().cpu()
        correct += float((pred_y == labels).sum().item())
        total += float(labels.size(0))
    print(f"\nFinal test Accuracy: {correct / total:.3f}")
    tracker.flush()

tracker.stop()
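The same measurement can be written more compactly when EmissionsTracker is used as a context manager, which recent codecarbon releases support; a sketch (the project name is illustrative):

import torch
from codecarbon import EmissionsTracker

with EmissionsTracker(project_name="cnn-evaluation"):
    # stop() is called automatically when the block exits.
    with torch.no_grad():
        pass  # run the evaluation loops shown above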
Example 7
def main(
    _run,
    storage_dir,
    hyper_params_dir,
    sed_hyper_params_name,
    crnn_dirs,
    crnn_checkpoints,
    device,
    data_provider,
    dataset_name,
    ground_truth_filepath,
    save_scores,
    save_detections,
    max_segment_length,
    segment_overlap,
    weak_pseudo_labeling,
    boundary_pseudo_labeling,
    strong_pseudo_labeling,
    pseudo_widening,
    pseudo_labeled_dataset_name,
):
    print()
    print('##### Inference #####')
    print()
    print_config(_run)
    print(storage_dir)
    emissions_tracker = EmissionsTracker(output_dir=storage_dir,
                                         on_csv_write="update",
                                         log_level='error')
    emissions_tracker.start()
    storage_dir = Path(storage_dir)

    boundary_collar_based_params = {
        'onset_collar': .5,
        'offset_collar': .5,
        'offset_collar_rate': .0,
    }
    collar_based_params = {
        'onset_collar': .2,
        'offset_collar': .2,
        'offset_collar_rate': .2,
    }
    psds_scenario_1 = {
        'dtc_threshold': 0.7,
        'gtc_threshold': 0.7,
        'cttc_threshold': None,
        'alpha_ct': .0,
        'alpha_st': 1.,
    }
    psds_scenario_2 = {
        'dtc_threshold': 0.1,
        'gtc_threshold': 0.1,
        'cttc_threshold': 0.3,
        'alpha_ct': .5,
        'alpha_st': 1.,
    }

    if not isinstance(crnn_checkpoints, list):
        assert isinstance(crnn_checkpoints, str), crnn_checkpoints
        crnn_checkpoints = len(crnn_dirs) * [crnn_checkpoints]
    crnns = [
        CRNN.from_storage_dir(storage_dir=crnn_dir,
                              config_name='1/config.json',
                              checkpoint_name=crnn_checkpoint)
        for crnn_dir, crnn_checkpoint in zip(crnn_dirs, crnn_checkpoints)
    ]
    print('Params',
          sum([p.numel() for crnn in crnns for p in crnn.parameters()]))
    print(
        'CNN2d Params',
        sum([
            p.numel() for crnn in crnns for p in crnn.cnn.cnn_2d.parameters()
        ]))
    data_provider = DataProvider.from_config(data_provider)
    data_provider.test_transform.label_encoder.initialize_labels()
    event_classes = data_provider.test_transform.label_encoder.inverse_label_mapping
    event_classes = [event_classes[i] for i in range(len(event_classes))]
    frame_shift = data_provider.test_transform.stft.shift
    frame_shift /= data_provider.audio_reader.target_sample_rate

    if not isinstance(dataset_name, list):
        dataset_name = [dataset_name]
    if ground_truth_filepath is None:
        ground_truth_filepath = len(dataset_name) * [ground_truth_filepath]
    elif isinstance(ground_truth_filepath, (str, Path)):
        ground_truth_filepath = [ground_truth_filepath]
    assert len(ground_truth_filepath) == len(dataset_name)
    if not isinstance(weak_pseudo_labeling, list):
        weak_pseudo_labeling = len(dataset_name) * [weak_pseudo_labeling]
    assert len(weak_pseudo_labeling) == len(dataset_name)
    if not isinstance(boundary_pseudo_labeling, list):
        boundary_pseudo_labeling = len(dataset_name) * [
            boundary_pseudo_labeling
        ]
    assert len(boundary_pseudo_labeling) == len(dataset_name)
    if not isinstance(strong_pseudo_labeling, list):
        strong_pseudo_labeling = len(dataset_name) * [strong_pseudo_labeling]
    assert len(strong_pseudo_labeling) == len(dataset_name)
    if not isinstance(pseudo_labeled_dataset_name, list):
        pseudo_labeled_dataset_name = [pseudo_labeled_dataset_name]
    assert len(pseudo_labeled_dataset_name) == len(dataset_name)

    database = deepcopy(data_provider.db.data)
    for i in range(len(dataset_name)):
        print()
        print(dataset_name[i])

        if dataset_name[i] == 'eval_public' and not ground_truth_filepath[i]:
            database_root = Path(
                data_provider.get_raw('eval_public')[0]
                ['audio_path']).parent.parent.parent.parent
            ground_truth_filepath[
                i] = database_root / 'metadata' / 'eval' / 'public.tsv'
        elif dataset_name[i] == 'validation' and not ground_truth_filepath[i]:
            database_root = Path(
                data_provider.get_raw('validation')[0]
                ['audio_path']).parent.parent.parent.parent
            ground_truth_filepath[
                i] = database_root / 'metadata' / 'validation' / 'validation.tsv'

        dataset = data_provider.get_dataset(dataset_name[i])
        audio_durations = {
            example['example_id']: example['audio_length']
            for example in data_provider.db.get_dataset(dataset_name[i])
        }

        score_storage_dir = storage_dir / 'scores' / dataset_name[i]
        detection_storage_dir = storage_dir / 'detections' / dataset_name[i]

        if max_segment_length is None:
            timestamps = {
                audio_id: np.array([0., audio_durations[audio_id]])
                for audio_id in audio_durations
            }
        else:
            timestamps = {}
            for audio_id in audio_durations:
                ts = np.arange(0, audio_durations[audio_id],
                               (max_segment_length - segment_overlap) *
                               frame_shift)
                timestamps[audio_id] = np.concatenate(
                    (ts, [audio_durations[audio_id]]))
        tags, tagging_scores, tagging_results = tagging(
            crnns,
            dataset,
            device,
            timestamps,
            event_classes,
            hyper_params_dir,
            ground_truth_filepath[i],
            audio_durations,
            [psds_scenario_1, psds_scenario_2],
            max_segment_length=max_segment_length,
            segment_overlap=segment_overlap,
        )
        if tagging_results:
            dump_json(tagging_results,
                      storage_dir / f'tagging_results_{dataset_name[i]}.json')

        timestamps = np.round(np.arange(0, 100000) * frame_shift, decimals=6)
        if ground_truth_filepath[i] is not None or boundary_pseudo_labeling[i]:
            boundaries, boundaries_detection_results = boundaries_detection(
                crnns,
                dataset,
                device,
                timestamps,
                event_classes,
                tags,
                hyper_params_dir,
                ground_truth_filepath[i],
                boundary_collar_based_params,
                max_segment_length=max_segment_length,
                segment_overlap=segment_overlap,
                pseudo_widening=pseudo_widening,
            )
            if boundaries_detection_results:
                dump_json(
                    boundaries_detection_results, storage_dir /
                    f'boundaries_detection_results_{dataset_name[i]}.json')
        else:
            boundaries = {}
        if not isinstance(sed_hyper_params_name, (list, tuple)):
            sed_hyper_params_name = [sed_hyper_params_name]
        if (ground_truth_filepath[i] is not None
            ) or strong_pseudo_labeling[i] or save_scores or save_detections:
            events, sed_results = sound_event_detection(
                crnns,
                dataset,
                device,
                timestamps,
                event_classes,
                tags,
                hyper_params_dir,
                sed_hyper_params_name,
                ground_truth_filepath[i],
                audio_durations,
                collar_based_params,
                [psds_scenario_1, psds_scenario_2],
                max_segment_length=max_segment_length,
                segment_overlap=segment_overlap,
                pseudo_widening=pseudo_widening,
                score_storage_dir=[
                    score_storage_dir / name for name in sed_hyper_params_name
                ] if save_scores else None,
                detection_storage_dir=[
                    detection_storage_dir / name
                    for name in sed_hyper_params_name
                ] if save_detections else None,
            )
            for j, sed_results_j in enumerate(sed_results):
                if sed_results_j:
                    dump_json(
                        sed_results_j, storage_dir /
                        f'sed_{sed_hyper_params_name[j]}_results_{dataset_name[i]}.json'
                    )
        else:
            events = [{}]
        database['datasets'][
            pseudo_labeled_dataset_name[i]] = base.pseudo_label(
                database['datasets'][dataset_name[i]],
                event_classes,
                weak_pseudo_labeling[i],
                boundary_pseudo_labeling[i],
                strong_pseudo_labeling[i],
                tags,
                boundaries,
                events[0],
            )

    if any(weak_pseudo_labeling) or any(boundary_pseudo_labeling) or any(
            strong_pseudo_labeling):
        dump_json(
            database,
            storage_dir / Path(data_provider.json_path).name,
            create_path=True,
            indent=4,
            ensure_ascii=False,
        )
    inference_dir = Path(hyper_params_dir) / 'inference'
    os.makedirs(str(inference_dir), exist_ok=True)
    (inference_dir / storage_dir.name).symlink_to(storage_dir)
    emissions_tracker.stop()
    print(storage_dir)
Example 8
def main(_run, storage_dir, debug, crnn_dirs, crnn_checkpoints, data_provider,
         validation_set_name, validation_ground_truth_filepath, eval_set_name,
         eval_ground_truth_filepath, boundaries_filter_lengths,
         tune_detection_scenario_1, detection_window_lengths_scenario_1,
         detection_window_shift_scenario_1,
         detection_medfilt_lengths_scenario_1, tune_detection_scenario_2,
         detection_window_lengths_scenario_2,
         detection_window_shift_scenario_2,
         detection_medfilt_lengths_scenario_2, device):
    print()
    print('##### Tuning #####')
    print()
    print_config(_run)
    print(storage_dir)
    emissions_tracker = EmissionsTracker(output_dir=storage_dir,
                                         on_csv_write="update",
                                         log_level='error')
    emissions_tracker.start()
    storage_dir = Path(storage_dir)

    boundaries_collar_based_params = {
        'onset_collar': .5,
        'offset_collar': .5,
        'offset_collar_rate': .0,
        'min_precision': .8,
    }
    collar_based_params = {
        'onset_collar': .2,
        'offset_collar': .2,
        'offset_collar_rate': .2,
    }
    psds_scenario_1 = {
        'dtc_threshold': 0.7,
        'gtc_threshold': 0.7,
        'cttc_threshold': None,
        'alpha_ct': .0,
        'alpha_st': 1.,
    }
    psds_scenario_2 = {
        'dtc_threshold': 0.1,
        'gtc_threshold': 0.1,
        'cttc_threshold': 0.3,
        'alpha_ct': .5,
        'alpha_st': 1.,
    }

    if not isinstance(crnn_checkpoints, list):
        assert isinstance(crnn_checkpoints, str), crnn_checkpoints
        crnn_checkpoints = len(crnn_dirs) * [crnn_checkpoints]
    crnns = [
        weak_label.CRNN.from_storage_dir(storage_dir=crnn_dir,
                                         config_name='1/config.json',
                                         checkpoint_name=crnn_checkpoint)
        for crnn_dir, crnn_checkpoint in zip(crnn_dirs, crnn_checkpoints)
    ]
    data_provider = DataProvider.from_config(data_provider)
    data_provider.test_transform.label_encoder.initialize_labels()
    event_classes = data_provider.test_transform.label_encoder.inverse_label_mapping
    event_classes = [event_classes[i] for i in range(len(event_classes))]
    frame_shift = data_provider.test_transform.stft.shift
    frame_shift /= data_provider.audio_reader.target_sample_rate

    if validation_set_name == 'validation' and not validation_ground_truth_filepath:
        database_root = Path(
            data_provider.get_raw('validation')[0]
            ['audio_path']).parent.parent.parent.parent
        validation_ground_truth_filepath = database_root / 'metadata' / 'validation' / 'validation.tsv'
    elif validation_set_name == 'eval_public' and not validation_ground_truth_filepath:
        database_root = Path(
            data_provider.get_raw('eval_public')[0]
            ['audio_path']).parent.parent.parent.parent
        validation_ground_truth_filepath = database_root / 'metadata' / 'eval' / 'public.tsv'
    assert isinstance(
        validation_ground_truth_filepath,
        (str, Path)) and Path(validation_ground_truth_filepath).exists(
        ), validation_ground_truth_filepath

    dataset = data_provider.get_dataset(validation_set_name)
    audio_durations = {
        example['example_id']: example['audio_length']
        for example in data_provider.db.get_dataset(validation_set_name)
    }

    timestamps = {
        audio_id: np.array([0., audio_durations[audio_id]])
        for audio_id in audio_durations
    }
    metrics = {
        'f':
        partial(base.f_tag,
                ground_truth=validation_ground_truth_filepath,
                num_jobs=8)
    }
    leaderboard = weak_label.crnn.tune_tagging(crnns,
                                               dataset,
                                               device,
                                               timestamps,
                                               event_classes,
                                               metrics,
                                               storage_dir=storage_dir)
    _, hyper_params, tagging_scores = leaderboard['f']
    tagging_thresholds = np.array([
        hyper_params[event_class]['threshold'] for event_class in event_classes
    ])
    tags = {
        audio_id:
        tagging_scores[audio_id][event_classes].to_numpy() > tagging_thresholds
        for audio_id in tagging_scores
    }

    boundaries_ground_truth = base.boundaries_from_events(
        validation_ground_truth_filepath)
    timestamps = np.arange(0, 10000) * frame_shift
    metrics = {
        'f':
        partial(
            base.f_collar,
            ground_truth=boundaries_ground_truth,
            return_onset_offset_bias=True,
            num_jobs=8,
            **boundaries_collar_based_params,
        ),
    }
    weak_label.crnn.tune_boundary_detection(
        crnns,
        dataset,
        device,
        timestamps,
        event_classes,
        tags,
        metrics,
        tag_masking=True,
        stepfilt_lengths=boundaries_filter_lengths,
        storage_dir=storage_dir)

    if tune_detection_scenario_1:
        metrics = {
            'f':
            partial(
                base.f_collar,
                ground_truth=validation_ground_truth_filepath,
                return_onset_offset_bias=True,
                num_jobs=8,
                **collar_based_params,
            ),
            'auc':
            partial(
                base.psd_auc,
                ground_truth=validation_ground_truth_filepath,
                audio_durations=audio_durations,
                num_jobs=8,
                **psds_scenario_1,
            ),
        }
        leaderboard = weak_label.crnn.tune_sound_event_detection(
            crnns,
            dataset,
            device,
            timestamps,
            event_classes,
            tags,
            metrics,
            tag_masking={
                'f': True,
                'auc': '?'
            },
            window_lengths=detection_window_lengths_scenario_1,
            window_shift=detection_window_shift_scenario_1,
            medfilt_lengths=detection_medfilt_lengths_scenario_1,
        )
        dump_json(leaderboard['f'][1],
                  storage_dir / 'sed_hyper_params_f.json')
        f, p, r, thresholds, _ = collar_based.best_fscore(
            scores=leaderboard['auc'][2],
            ground_truth=validation_ground_truth_filepath,
            **collar_based_params,
            num_jobs=8)
        for event_class in thresholds:
            leaderboard['auc'][1][event_class]['threshold'] = thresholds[
                event_class]
        dump_json(leaderboard['auc'][1],
                  storage_dir / 'sed_hyper_params_psds1.json')
    if tune_detection_scenario_2:
        metrics = {
            'auc':
            partial(
                base.psd_auc,
                ground_truth=validation_ground_truth_filepath,
                audio_durations=audio_durations,
                num_jobs=8,
                **psds_scenario_2,
            )
        }
        leaderboard = weak_label.crnn.tune_sound_event_detection(
            crnns,
            dataset,
            device,
            timestamps,
            event_classes,
            tags,
            metrics,
            tag_masking=False,
            window_lengths=detection_window_lengths_scenario_2,
            window_shift=detection_window_shift_scenario_2,
            medfilt_lengths=detection_medfilt_lengths_scenario_2,
        )
        f, p, r, thresholds, _ = collar_based.best_fscore(
            scores=leaderboard['auc'][2],
            ground_truth=validation_ground_truth_filepath,
            **collar_based_params,
            num_jobs=8)
        for event_class in thresholds:
            leaderboard['auc'][1][event_class]['threshold'] = thresholds[
                event_class]
        dump_json(leaderboard['auc'][1],
                  storage_dir / 'sed_hyper_params_psds2.json')
    for crnn_dir in crnn_dirs:
        tuning_dir = Path(crnn_dir) / 'hyper_params'
        os.makedirs(str(tuning_dir), exist_ok=True)
        (tuning_dir / storage_dir.name).symlink_to(storage_dir)
    emissions_tracker.stop()
    print(storage_dir)

    if eval_set_name:
        if tune_detection_scenario_1:
            evaluation.run(config_updates={
                'debug': debug,
                'hyper_params_dir': str(storage_dir),
                'dataset_name': eval_set_name,
                'ground_truth_filepath': eval_ground_truth_filepath,
            }, )
        if tune_detection_scenario_2:
            evaluation.run(config_updates={
                'debug': debug,
                'hyper_params_dir': str(storage_dir),
                'dataset_name': eval_set_name,
                'ground_truth_filepath': eval_ground_truth_filepath,
                'sed_hyper_params_name': 'psds2',
            }, )
Example 9
def main():
    """Main function that creates and executes the BIDS App singularity command.

    Returns
    -------
    exit_code : {0, 1}
        An exit code given to `sys.exit()` that can be:

            * '0' in case of successful completion

            * '1' in case of an error
    """
    # Create and parse arguments
    parser = get_singularity_wrapper_parser()
    args = parser.parse_args()

    # Create the singularity run command
    cmd = create_singularity_cmd(args)

    # Create and start the carbon footprint tracker
    if args.track_carbon_footprint:
        logging.getLogger(
            "codecarbon").disabled = True  # Comment this line for debug
        tracker = EmissionsTracker(
            project_name=f"MIALSRTK{__version__}-docker",
            output_dir=str(Path(args.bids_dir) / "code"),
            measure_power_secs=15,
        )
        tracker.start()

    # Execute the singularity run command
    try:
        print(f'... cmd: {cmd}')
        run(cmd)
        exit_code = 0
    except Exception as e:
        print('Failed')
        print(e)
        exit_code = 1

    if args.track_carbon_footprint:
        emissions: float = tracker.stop()
        print("############################################################")
        print(
            f"CARBON FOOTPRINT OF {len(args.participant_label)} SUBJECT(S) PROCESSED"
        )
        print("############################################################")
        print(f" * Estimated Co2 emissions: {emissions} kg")
        car_kms = get_emission_car_miles_equivalent(emissions)
        print(f" * Equivalent in distance travelled by avg car: {car_kms} kms")
        tv_time = get_emission_tv_time_equivalent(emissions)
        print(
            f" * Equivalent in amount of time watching a 32-inch LCD flat screen TV: {tv_time}"
        )
        print("############################################################")
        print(f"PREDICTED CARBON FOOTPRINT OF 100 SUBJECTS PROCESSED")
        print("############################################################")
        pred_emissions = 100 * emissions / len(args.participant_label)
        print(f" * Estimated Co2 emissions: {pred_emissions} kg")
        car_kms = get_emission_car_miles_equivalent(pred_emissions)
        print(f" * Equivalent in distance travelled by avg car: {car_kms} kms")
        tv_time = get_emission_tv_time_equivalent(pred_emissions)
        print(
            f" * Equivalent in amount of time watching a 32-inch LCD flat screen TV: {tv_time}"
        )
        print("############################################################")
        print(
            "Results can be visualized with the codecarbon visualization tool using following command:\n"
        )
        print(
            f'\t$ carbonboard --filepath="{args.bids_dir}/code/emissions.csv" --port=9999\n'
        )

    return exit_code
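get_emission_car_miles_equivalent and get_emission_tv_time_equivalent are the wrapper's own helpers; a rough sketch of such conversions, with purely illustrative factors that are not the project's actual constants:

def get_emission_car_miles_equivalent(emissions_kg):
    # Illustrative assumption: roughly 0.25 kg CO2eq per km for an average passenger car.
    return round(emissions_kg / 0.25, 2)


def get_emission_tv_time_equivalent(emissions_kg):
    # Illustrative assumption: roughly 0.1 kg CO2eq per hour of watching a 32-inch LCD TV.
    return f"{emissions_kg / 0.1:.1f} h"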
Example 10
def cli_main():

    # ------------
    # Parse args
    # ------------
    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument(f"--{EXPERIMENT_NAME}", choices=experiment_names)
    parser.add_argument(f"--{MODEL_NAME}", choices=model_names)
    parser.add_argument(f"--{EXPERIMENTS_CONFIG}", type=str)
    # Project wide parameters
    for parameter in parameters.values():
        parser.add_argument(
            f"--{parameter.name}",
            default=parameter.default,
            type=parameter.type_,
            choices=parameter.choices,
        )
    args, unknown_args = parser.parse_known_args()

    experiments_config = get_experiments_config(args)

    # ------------
    # Generate experiments
    # ------------
    for experiment_idx, experiment_config in enumerate(experiments_config):

        _experiments_config = copy.deepcopy(experiments_config)
        _experiment_config = copy.deepcopy(experiment_config)
        # Add support for experiment specific arguments
        base_model = get_experiment_base_model(_experiment_config)
        parser = base_model.add_model_specific_args(parser)
        # Reparse with experiment specific arguments
        args, unknown_args = parser.parse_known_args()
        _experiments_config = get_experiments_config(args)

        _experiment_config = _experiments_config[experiment_idx]
        # Add support for model specific arguments
        model = get_model(_experiment_config)
        parser = model.add_model_specific_args(parser)
        # Reparse with model specific arguments
        args, unknown_args = parser.parse_known_args()
        _experiments_config = get_experiments_config(args)

        experiment = Experiment(_experiments_config[experiment_idx])
        print(f"--- Starting Experiment {clean_dict(experiment.__dict__)}")

        impact_tracker = EmissionsTracker(
            project_name="sggm",
            output_dir="./impact_logs/",
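            # co2_signal_api_token is hard-coded below; reading it from an environment variable avoids committing a secret.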
            co2_signal_api_token="06297ab81ba8d269",
        )
        impact_tracker.start()

        # In the case of a shift split experiment, override number of trials to correspond to the splits.
        if is_shifted_split(experiment.experiment_name):
            if experiment.n_trials is not None:
                warn(
                    f"[WARNING] The number of trials specified, n_trials={experiment.n_trials},"
                    +
                    f" is overriden for experiment {experiment.experiment_name} (shift split)"
                )
            experiment.n_trials = experiment.datamodule.dims

        for n_t in range(experiment.n_trials):

            if isinstance(experiment.seed, int):
                seed = experiment.seed + n_t
                pl.seed_everything(seed)

            # ------------
            # data
            # ------------
            datamodule = experiment.datamodule
            datamodule.setup(
                dim_idx=n_t,
                stage=STAGE_SETUP_SHIFTED_SPLIT) if is_shifted_split(
                    experiment.experiment_name) else datamodule.setup()

            # ------------
            # model
            # ------------
            model = experiment.model
            if isinstance(model, VariationalRegressor):
                model.setup_pig(datamodule)
            if isinstance(model, V3AE):
                model.save_datamodule(datamodule)
                model.set_prior_parameters(
                    datamodule,
                    prior_α=experiment.prior_α,
                    prior_β=experiment.prior_β,
                )

            # ------------
            # training
            # ------------
            if getattr(experiment, SPLIT_TRAINING, None):
                (
                    experiment,
                    model,
                    datamodule,
                    trainer,
                ) = split_mean_uncertainty_training(experiment, model,
                                                    datamodule)
            else:
                trainer = experiment.trainer
                trainer.fit(model, datamodule)

            # ------------
            # testing
            # ------------
            results = trainer.test()

            # ------------
            # saving
            # ------------
            torch.save(results, f"{trainer.logger.log_dir}/results.pkl")
            torch.save(
                get_misc_save_dict(experiment, model),
                f"{trainer.logger.log_dir}/misc.pkl",
            )

        emissions = impact_tracker.stop()
        print(f"Emissions: {emissions} kg")
Example 11
                    print("encountered error:", e)
                    logpx, rmse = np.nan, np.nan
                T.end()
                log_score.append(logpx)
                rmse_score.append(rmse)

            log_score = np.array(log_score)
            rmse_score = np.array(rmse_score)

            # Save results
            result_folder = f"{pathlib.Path(__file__).parent.absolute()}/results/"
            if result_folder not in os.listdir(
                    f"{pathlib.Path(__file__).parent.absolute()}/"):
                os.makedirs(result_folder, exist_ok=True)
            np.savez(
                result_folder + experiment_name + "_" + model,
                log_score=log_score,
                rmse_score=rmse_score,
                timings=np.array(T.timings),
            )

            # Print the results
            print("log(px): {0:.3f} +- {1:.3f}".format(log_score.mean(),
                                                       log_score.std()))
            print("rmse:    {0:.3f} +- {1:.3f}".format(rmse_score.mean(),
                                                       rmse_score.std()))
            T.res()

    emissions = impact_tracker.stop()
    print(f"Emissions: {emissions} kg")
Example 12
def main(
    _run,
    storage_dir,
    strong_label_crnn_hyper_params_dir,
    sed_hyper_params_name,
    strong_label_crnn_dirs,
    strong_label_crnn_checkpoints,
    weak_label_crnn_hyper_params_dir,
    weak_label_crnn_dirs,
    weak_label_crnn_checkpoints,
    device,
    data_provider,
    dataset_name,
    ground_truth_filepath,
    save_scores,
    save_detections,
    max_segment_length,
    segment_overlap,
    strong_pseudo_labeling,
    pseudo_widening,
    pseudo_labelled_dataset_name,
):
    print()
    print('##### Inference #####')
    print()
    print_config(_run)
    print(storage_dir)
    emissions_tracker = EmissionsTracker(output_dir=storage_dir,
                                         on_csv_write="update",
                                         log_level='error')
    emissions_tracker.start()
    storage_dir = Path(storage_dir)

    collar_based_params = {
        'onset_collar': .2,
        'offset_collar': .2,
        'offset_collar_rate': .2,
    }
    psds_scenario_1 = {
        'dtc_threshold': 0.7,
        'gtc_threshold': 0.7,
        'cttc_threshold': None,
        'alpha_ct': .0,
        'alpha_st': 1.,
    }
    psds_scenario_2 = {
        'dtc_threshold': 0.1,
        'gtc_threshold': 0.1,
        'cttc_threshold': 0.3,
        'alpha_ct': .5,
        'alpha_st': 1.,
    }

    if not isinstance(weak_label_crnn_checkpoints, list):
        assert isinstance(weak_label_crnn_checkpoints,
                          str), weak_label_crnn_checkpoints
        weak_label_crnn_checkpoints = len(weak_label_crnn_dirs) * [
            weak_label_crnn_checkpoints
        ]
    weak_label_crnns = [
        weak_label.CRNN.from_storage_dir(storage_dir=crnn_dir,
                                         config_name='1/config.json',
                                         checkpoint_name=crnn_checkpoint)
        for crnn_dir, crnn_checkpoint in zip(weak_label_crnn_dirs,
                                             weak_label_crnn_checkpoints)
    ]
    print(
        'Weak Label CRNN Params',
        sum([
            p.numel() for crnn in weak_label_crnns for p in crnn.parameters()
        ]))
    print(
        'Weak Label CNN2d Params',
        sum([
            p.numel() for crnn in weak_label_crnns
            for p in crnn.cnn.cnn_2d.parameters()
        ]))
    if not isinstance(strong_label_crnn_checkpoints, list):
        assert isinstance(strong_label_crnn_checkpoints,
                          str), strong_label_crnn_checkpoints
        strong_label_crnn_checkpoints = len(strong_label_crnn_dirs) * [
            strong_label_crnn_checkpoints
        ]
    strong_label_crnns = [
        strong_label.CRNN.from_storage_dir(storage_dir=crnn_dir,
                                           config_name='1/config.json',
                                           checkpoint_name=crnn_checkpoint)
        for crnn_dir, crnn_checkpoint in zip(strong_label_crnn_dirs,
                                             strong_label_crnn_checkpoints)
    ]
    print(
        'Strong Label CRNN Params',
        sum([
            p.numel() for crnn in strong_label_crnns
            for p in crnn.parameters()
        ]))
    print(
        'Strong Label CNN2d Params',
        sum([
            p.numel() for crnn in strong_label_crnns
            for p in crnn.cnn.cnn_2d.parameters()
        ]))
    data_provider = DESEDProvider.from_config(data_provider)
    data_provider.test_transform.label_encoder.initialize_labels()
    event_classes = data_provider.test_transform.label_encoder.inverse_label_mapping
    event_classes = [event_classes[i] for i in range(len(event_classes))]
    frame_shift = data_provider.test_transform.stft.shift
    frame_shift /= data_provider.audio_reader.target_sample_rate

    if not isinstance(dataset_name, list):
        dataset_name = [dataset_name]
    if ground_truth_filepath is None:
        ground_truth_filepath = len(dataset_name) * [ground_truth_filepath]
    elif isinstance(ground_truth_filepath, (str, Path)):
        ground_truth_filepath = [ground_truth_filepath]
    assert len(ground_truth_filepath) == len(dataset_name)
    if not isinstance(strong_pseudo_labeling, list):
        strong_pseudo_labeling = len(dataset_name) * [strong_pseudo_labeling]
    assert len(strong_pseudo_labeling) == len(dataset_name)
    if not isinstance(pseudo_labelled_dataset_name, list):
        pseudo_labelled_dataset_name = [pseudo_labelled_dataset_name]
    assert len(pseudo_labelled_dataset_name) == len(dataset_name)

    database = deepcopy(data_provider.db.data)
    for i in range(len(dataset_name)):
        print()
        print(dataset_name[i])
        if dataset_name[i] == 'eval_public' and not ground_truth_filepath[i]:
            database_root = Path(
                data_provider.get_raw('eval_public')[0]
                ['audio_path']).parent.parent.parent.parent
            ground_truth_filepath[
                i] = database_root / 'metadata' / 'eval' / 'public.tsv'
        elif dataset_name[i] == 'validation' and not ground_truth_filepath[i]:
            database_root = Path(
                data_provider.get_raw('validation')[0]
                ['audio_path']).parent.parent.parent.parent
            ground_truth_filepath[
                i] = database_root / 'metadata' / 'validation' / 'validation.tsv'

        dataset = data_provider.get_dataset(dataset_name[i])
        audio_durations = {
            example['example_id']: example['audio_length']
            for example in data_provider.db.get_dataset(dataset_name[i])
        }

        score_storage_dir = storage_dir / 'scores' / dataset_name[i]
        detection_storage_dir = storage_dir / 'detections' / dataset_name[i]

        if max_segment_length is None:
            timestamps = {
                audio_id: np.array([0., audio_durations[audio_id]])
                for audio_id in audio_durations
            }
        else:
            timestamps = {}
            for audio_id in audio_durations:
                ts = np.arange(
                    (2 + max_segment_length) * frame_shift,
                    audio_durations[audio_id],
                    (max_segment_length - segment_overlap) * frame_shift)
                timestamps[audio_id] = np.concatenate(
                    ([0.], ts - segment_overlap / 2 * frame_shift,
                     [audio_durations[audio_id]]))
        if max_segment_length is not None:
            dataset = dataset.map(
                partial(segment_batch,
                        max_length=max_segment_length,
                        overlap=segment_overlap)).unbatch()
        tags, tagging_scores, _ = tagging(
            weak_label_crnns,
            dataset,
            device,
            timestamps,
            event_classes,
            weak_label_crnn_hyper_params_dir,
            None,
            None,
        )

        def add_tag_condition(example):
            example["tag_condition"] = np.array(
                [tags[example_id] for example_id in example["example_id"]])
            return example

        dataset = dataset.map(add_tag_condition)

        timestamps = np.round(np.arange(0, 100000) * frame_shift, decimals=6)
        if not isinstance(sed_hyper_params_name, (list, tuple)):
            sed_hyper_params_name = [sed_hyper_params_name]
        events, sed_results = sound_event_detection(
            strong_label_crnns,
            dataset,
            device,
            timestamps,
            event_classes,
            tags,
            strong_label_crnn_hyper_params_dir,
            sed_hyper_params_name,
            ground_truth_filepath[i],
            audio_durations,
            collar_based_params,
            [psds_scenario_1, psds_scenario_2],
            max_segment_length=max_segment_length,
            segment_overlap=segment_overlap,
            pseudo_widening=pseudo_widening,
            score_storage_dir=[
                score_storage_dir / name for name in sed_hyper_params_name
            ] if save_scores else None,
            detection_storage_dir=[
                detection_storage_dir / name for name in sed_hyper_params_name
            ] if save_detections else None,
        )
        for j, sed_results_j in enumerate(sed_results):
            if sed_results_j:
                dump_json(
                    sed_results_j, storage_dir /
                    f'sed_{sed_hyper_params_name[j]}_results_{dataset_name[i]}.json'
                )
        if strong_pseudo_labeling[i]:
            database['datasets'][
                pseudo_labelled_dataset_name[i]] = base.pseudo_label(
                    database['datasets'][dataset_name[i]],
                    event_classes,
                    False,
                    False,
                    strong_pseudo_labeling[i],
                    None,
                    None,
                    events[0],
                )
            with (storage_dir /
                  f'{dataset_name[i]}_pseudo_labeled.tsv').open('w') as fid:
                fid.write('filename\tonset\toffset\tevent_label\n')
                for key, event_list in events[0].items():
                    if len(event_list) == 0:
                        fid.write(f'{key}.wav\t\t\t\n')
                    for t_on, t_off, event_label in event_list:
                        fid.write(
                            f'{key}.wav\t{t_on}\t{t_off}\t{event_label}\n')

    if any(strong_pseudo_labeling):
        dump_json(
            database,
            storage_dir / Path(data_provider.json_path).name,
            create_path=True,
            indent=4,
            ensure_ascii=False,
        )
    inference_dir = Path(strong_label_crnn_hyper_params_dir) / 'inference'
    os.makedirs(str(inference_dir), exist_ok=True)
    (inference_dir / storage_dir.name).symlink_to(storage_dir)
    emissions_tracker.stop()
    print(storage_dir)