Ejemplo n.º 1
0
import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import datasets
from torchvision.transforms import ToTensor

from codecarbon import EmissionsTracker

# Emissions tracking is started here, before the datasets below are created.
tracker = EmissionsTracker(project_name="pytorch-mnist-multigpu")
tracker.start()
# References:
# https://medium.com/@nutanbhogendrasharma/pytorch-convolutional-neural-network-with-mnist-dataset-4e8a4265e118
# https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html#create-model-and-dataparallel
# https://gist.github.com/MattKleinsmith/5226a94bad5dd12ed0b871aed98cb123

# Parameters and DataLoaders

batch_size = 30
# Fraction of the training set used to compute ``n_valid`` below.
valid_ratio = 0.15

# MNIST training split, stored under the current directory; downloaded on
# demand because ``download=True``.
train_data = datasets.MNIST(
    root=".",
    train=True,
    transform=ToTensor(),
    download=True,
)
# MNIST test split read from the same root (no ``download`` flag is passed).
test_data = datasets.MNIST(root=".", train=False, transform=ToTensor())

# Number of training samples matching ``valid_ratio``; ``int()`` truncates.
n_valid = int(len(train_data) * valid_ratio)
Ejemplo n.º 2
0
    def __init__(self, codecarbon_tracker):
        """Store the tracker that this callback flushes.

        Args:
            codecarbon_tracker: Tracker object; its ``flush()`` method is
                called by ``on_epoch_end``.
        """
        # NOTE(review): the base-class initializer is not invoked here; the
        # parent class is not visible from this excerpt, so confirm whether
        # ``super().__init__()`` is required.
        self.codecarbon_tracker = codecarbon_tracker

    def on_epoch_end(self, epoch, logs=None):
        """Flush the stored tracker once an epoch has finished.

        Args:
            epoch: Epoch index supplied by the caller (unused).
            logs: Optional metrics mapping supplied by the caller (unused).
        """
        tracker = self.codecarbon_tracker
        tracker.flush()


mnist = tf.keras.datasets.mnist

# Load MNIST and divide the raw pixel values by 255.0.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Small fully-connected classifier: flatten -> dense(128, relu) -> dropout
# -> dense(10). The last layer has no activation, so it outputs logits.
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10),
])

# ``from_logits=True`` matches the activation-free output layer above.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])

tracker = EmissionsTracker()
tracker.start()
try:
    # The callback flushes the tracker at the end of every epoch.
    codecarbon_cb = CodeCarbonCallBack(tracker)
    model.fit(x_train, y_train, epochs=4, callbacks=[codecarbon_cb])
finally:
    # Always stop the tracker, even when training raises, so that it is not
    # left running.
    emissions: float = tracker.stop()
print(f"Emissions: {emissions} kg")
Ejemplo n.º 3
0
        )
        time.sleep(delay)


if __name__ == "__main__":
    # Script entry point: choose a logging backend for the emissions data
    # (Google Cloud Logging or a local log file), then track one training run.

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--google-project",
        help="Specify the name of the Google project (for Cloud Logging)",
    )
    args = parser.parse_args()

    log_name = "code_carbone"
    if args.google_project:
        # Send emissions records to Google Cloud Logging.
        client = google.cloud.logging.Client(project=args.google_project)
        external_logger = GoogleCloudLoggerOutput(client.logger(log_name))
    else:
        # Fall back to a standard-library logger writing to "<log_name>.log".
        file_logger = logging.getLogger(log_name)
        channel = logging.FileHandler(log_name + ".log")
        file_logger.addHandler(channel)
        file_logger.setLevel(logging.INFO)
        external_logger = LoggerOutput(file_logger, logging.INFO)

    tracker = EmissionsTracker(save_to_logger=True,
                               logging_logger=external_logger)
    tracker.start()
    try:
        train_model(epochs=1)  # Each epoch lasts 30 seconds
    finally:
        # Always stop the tracker, even when training raises, so that it is
        # not left running.
        emissions: float = tracker.stop()
    print(f"Emissions: {emissions} kg")
Ejemplo n.º 4
0
def main(
    _run,
    storage_dir,
    hyper_params_dir,
    sed_hyper_params_name,
    crnn_dirs,
    crnn_checkpoints,
    device,
    data_provider,
    dataset_name,
    ground_truth_filepath,
    save_scores,
    save_detections,
    max_segment_length,
    segment_overlap,
    weak_pseudo_labeling,
    boundary_pseudo_labeling,
    strong_pseudo_labeling,
    pseudo_widening,
    pseudo_labeled_dataset_name,
):
    """Run CRNN ensemble inference on one or more datasets.

    For every entry of ``dataset_name`` this runs ``tagging`` and then,
    depending on the available ground truth and the requested options,
    ``boundaries_detection`` and ``sound_event_detection``. Non-empty result
    dicts are dumped as JSON into ``storage_dir`` and a pseudo-labelled
    version of the dataset is written into a deep copy of the database. The
    run is measured with a codecarbon ``EmissionsTracker`` that writes to
    ``storage_dir``.

    Args:
        _run: Run object handed to ``print_config``.
        storage_dir: Output directory; converted to ``Path`` after the
            tracker has been created.
        hyper_params_dir: Forwarded to the tagging/detection helpers; an
            ``inference`` sub-directory holding a symlink to ``storage_dir``
            is created inside it.
        sed_hyper_params_name: A name or a list/tuple of names, used in the
            SED result file names and score/detection sub-directories.
        crnn_dirs: Storage directories the CRNNs are loaded from.
        crnn_checkpoints: A single checkpoint name (``str``, reused for
            every directory) or a list of names.
        device: Forwarded to the tagging/detection helpers.
        data_provider: Configuration passed to ``DataProvider.from_config``.
        dataset_name: A dataset name or a list of names.
        ground_truth_filepath: ``None``, a single path, or a list with one
            entry per dataset.
        save_scores: If truthy, per-name score directories are passed to
            ``sound_event_detection``.
        save_detections: If truthy, per-name detection directories are
            passed to ``sound_event_detection``.
        max_segment_length: ``None`` to process each audio as one segment;
            otherwise used together with ``segment_overlap`` to build
            segment timestamps (presumably counted in frames, since it is
            multiplied by the frame shift -- TODO confirm).
        segment_overlap: See ``max_segment_length``.
        weak_pseudo_labeling: Scalar or per-dataset list.
        boundary_pseudo_labeling: Scalar or per-dataset list.
        strong_pseudo_labeling: Scalar or per-dataset list.
        pseudo_widening: Forwarded to the detection helpers.
        pseudo_labeled_dataset_name: Name (or list of names) under which the
            pseudo-labelled datasets are stored in the database copy.
    """
    print()
    print('##### Inference #####')
    print()
    print_config(_run)
    print(storage_dir)
    emissions_tracker = EmissionsTracker(output_dir=storage_dir,
                                         on_csv_write="update",
                                         log_level='error')
    emissions_tracker.start()
    storage_dir = Path(storage_dir)

    # Evaluation settings handed to the helpers below.
    boundary_collar_based_params = {
        'onset_collar': .5,
        'offset_collar': .5,
        'offset_collar_rate': .0,
    }
    collar_based_params = {
        'onset_collar': .2,
        'offset_collar': .2,
        'offset_collar_rate': .2,
    }
    psds_scenario_1 = {
        'dtc_threshold': 0.7,
        'gtc_threshold': 0.7,
        'cttc_threshold': None,
        'alpha_ct': .0,
        'alpha_st': 1.,
    }
    psds_scenario_2 = {
        'dtc_threshold': 0.1,
        'gtc_threshold': 0.1,
        'cttc_threshold': 0.3,
        'alpha_ct': .5,
        'alpha_st': 1.,
    }

    # A single checkpoint name is reused for every CRNN directory.
    if not isinstance(crnn_checkpoints, list):
        assert isinstance(crnn_checkpoints, str), crnn_checkpoints
        crnn_checkpoints = len(crnn_dirs) * [crnn_checkpoints]
    crnns = [
        CRNN.from_storage_dir(storage_dir=crnn_dir,
                              config_name='1/config.json',
                              checkpoint_name=crnn_checkpoint)
        for crnn_dir, crnn_checkpoint in zip(crnn_dirs, crnn_checkpoints)
    ]
    # Report the parameter count of the whole ensemble and of its 2-D CNNs.
    print('Params',
          sum([p.numel() for crnn in crnns for p in crnn.parameters()]))
    print(
        'CNN2d Params',
        sum([
            p.numel() for crnn in crnns for p in crnn.cnn.cnn_2d.parameters()
        ]))
    data_provider = DataProvider.from_config(data_provider)
    data_provider.test_transform.label_encoder.initialize_labels()
    # Turn the index -> label mapping into a list ordered by index.
    event_classes = data_provider.test_transform.label_encoder.inverse_label_mapping
    event_classes = [event_classes[i] for i in range(len(event_classes))]
    # STFT shift divided by the sample rate (presumably seconds per frame).
    frame_shift = data_provider.test_transform.stft.shift
    frame_shift /= data_provider.audio_reader.target_sample_rate

    # Normalize all per-dataset options to lists of equal length.
    if not isinstance(dataset_name, list):
        dataset_name = [dataset_name]
    if ground_truth_filepath is None:
        ground_truth_filepath = len(dataset_name) * [ground_truth_filepath]
    elif isinstance(ground_truth_filepath, (str, Path)):
        ground_truth_filepath = [ground_truth_filepath]
    assert len(ground_truth_filepath) == len(dataset_name)
    if not isinstance(weak_pseudo_labeling, list):
        weak_pseudo_labeling = len(dataset_name) * [weak_pseudo_labeling]
    assert len(weak_pseudo_labeling) == len(dataset_name)
    if not isinstance(boundary_pseudo_labeling, list):
        boundary_pseudo_labeling = len(dataset_name) * [
            boundary_pseudo_labeling
        ]
    assert len(boundary_pseudo_labeling) == len(dataset_name)
    if not isinstance(strong_pseudo_labeling, list):
        strong_pseudo_labeling = len(dataset_name) * [strong_pseudo_labeling]
    assert len(strong_pseudo_labeling) == len(dataset_name)
    if not isinstance(pseudo_labeled_dataset_name, list):
        pseudo_labeled_dataset_name = [pseudo_labeled_dataset_name]
    assert len(pseudo_labeled_dataset_name) == len(dataset_name)

    # Work on a copy so that pseudo labels do not modify the provider's data.
    database = deepcopy(data_provider.db.data)
    for i in range(len(dataset_name)):
        print()
        print(dataset_name[i])

        # For the two known dataset names, derive a default ground-truth file
        # from the location of the first example's audio file (four directory
        # levels up) when none was given.
        if dataset_name[i] == 'eval_public' and not ground_truth_filepath[i]:
            database_root = Path(
                data_provider.get_raw('eval_public')[0]
                ['audio_path']).parent.parent.parent.parent
            ground_truth_filepath[
                i] = database_root / 'metadata' / 'eval' / 'public.tsv'
        elif dataset_name[i] == 'validation' and not ground_truth_filepath[i]:
            database_root = Path(
                data_provider.get_raw('validation')[0]
                ['audio_path']).parent.parent.parent.parent
            ground_truth_filepath[
                i] = database_root / 'metadata' / 'validation' / 'validation.tsv'

        dataset = data_provider.get_dataset(dataset_name[i])
        # example_id -> audio length for every example of the dataset.
        audio_durations = {
            example['example_id']: example['audio_length']
            for example in data_provider.db.get_dataset(dataset_name[i])
        }

        score_storage_dir = storage_dir / 'scores' / dataset_name[i]
        detection_storage_dir = storage_dir / 'detections' / dataset_name[i]

        # Tagging timestamps: one [0, duration] interval per audio, or
        # segment starts spaced by (max_segment_length - segment_overlap)
        # frames with the audio duration appended as the final boundary.
        if max_segment_length is None:
            timestamps = {
                audio_id: np.array([0., audio_durations[audio_id]])
                for audio_id in audio_durations
            }
        else:
            timestamps = {}
            for audio_id in audio_durations:
                ts = np.arange(0, audio_durations[audio_id],
                               (max_segment_length - segment_overlap) *
                               frame_shift)
                timestamps[audio_id] = np.concatenate(
                    (ts, [audio_durations[audio_id]]))
        tags, tagging_scores, tagging_results = tagging(
            crnns,
            dataset,
            device,
            timestamps,
            event_classes,
            hyper_params_dir,
            ground_truth_filepath[i],
            audio_durations,
            [psds_scenario_1, psds_scenario_2],
            max_segment_length=max_segment_length,
            segment_overlap=segment_overlap,
        )
        if tagging_results:
            dump_json(tagging_results,
                      storage_dir / f'tagging_results_{dataset_name[i]}.json')

        # From here on ``timestamps`` is a fixed grid of 100000 multiples of
        # the frame shift, rounded to 6 decimals, shared by all audios.
        timestamps = np.round(np.arange(0, 100000) * frame_shift, decimals=6)
        # Boundary detection is only run when it can be evaluated or when
        # boundary pseudo labels were requested for this dataset.
        if ground_truth_filepath[i] is not None or boundary_pseudo_labeling[i]:
            boundaries, boundaries_detection_results = boundaries_detection(
                crnns,
                dataset,
                device,
                timestamps,
                event_classes,
                tags,
                hyper_params_dir,
                ground_truth_filepath[i],
                boundary_collar_based_params,
                max_segment_length=max_segment_length,
                segment_overlap=segment_overlap,
                pseudo_widening=pseudo_widening,
            )
            if boundaries_detection_results:
                dump_json(
                    boundaries_detection_results, storage_dir /
                    f'boundaries_detection_results_{dataset_name[i]}.json')
        else:
            boundaries = {}
        # Loop-invariant normalization; it is a no-op after the first pass.
        if not isinstance(sed_hyper_params_name, (list, tuple)):
            sed_hyper_params_name = [sed_hyper_params_name]
        # Sound event detection is run when it can be evaluated, when strong
        # pseudo labels are requested, or when scores/detections are saved.
        if (ground_truth_filepath[i] is not None
            ) or strong_pseudo_labeling[i] or save_scores or save_detections:
            events, sed_results = sound_event_detection(
                crnns,
                dataset,
                device,
                timestamps,
                event_classes,
                tags,
                hyper_params_dir,
                sed_hyper_params_name,
                ground_truth_filepath[i],
                audio_durations,
                collar_based_params,
                [psds_scenario_1, psds_scenario_2],
                max_segment_length=max_segment_length,
                segment_overlap=segment_overlap,
                pseudo_widening=pseudo_widening,
                score_storage_dir=[
                    score_storage_dir / name for name in sed_hyper_params_name
                ] if save_scores else None,
                detection_storage_dir=[
                    detection_storage_dir / name
                    for name in sed_hyper_params_name
                ] if save_detections else None,
            )
            # One result dict per hyper-parameter set name.
            for j, sed_results_j in enumerate(sed_results):
                if sed_results_j:
                    dump_json(
                        sed_results_j, storage_dir /
                        f'sed_{sed_hyper_params_name[j]}_results_{dataset_name[i]}.json'
                    )
        else:
            # Placeholder so that ``events[0]`` below is still valid.
            events = [{}]
        # Store the pseudo-labelled dataset in the database copy; only the
        # events of the first hyper-parameter set are used.
        database['datasets'][
            pseudo_labeled_dataset_name[i]] = base.pseudo_label(
                database['datasets'][dataset_name[i]],
                event_classes,
                weak_pseudo_labeling[i],
                boundary_pseudo_labeling[i],
                strong_pseudo_labeling[i],
                tags,
                boundaries,
                events[0],
            )

    # The modified database is only written when some pseudo labeling was
    # requested for at least one dataset.
    if any(weak_pseudo_labeling) or any(boundary_pseudo_labeling) or any(
            strong_pseudo_labeling):
        dump_json(
            database,
            storage_dir / Path(data_provider.json_path).name,
            create_path=True,
            indent=4,
            ensure_ascii=False,
        )
    # Link this run's output directory from <hyper_params_dir>/inference.
    # NOTE(review): Path.symlink_to raises FileExistsError if the link
    # already exists, and the tracker is not stopped if anything above raises.
    inference_dir = Path(hyper_params_dir) / 'inference'
    os.makedirs(str(inference_dir), exist_ok=True)
    (inference_dir / storage_dir.name).symlink_to(storage_dir)
    emissions_tracker.stop()
    print(storage_dir)
Ejemplo n.º 5
0
def main(_run, storage_dir, debug, crnn_dirs, crnn_checkpoints, data_provider,
         validation_set_name, validation_ground_truth_filepath, eval_set_name,
         eval_ground_truth_filepath, boundaries_filter_lengths,
         tune_detection_scenario_1, detection_window_lengths_scenario_1,
         detection_window_shift_scenario_1,
         detection_medfilt_lengths_scenario_1, tune_detection_scenario_2,
         detection_window_lengths_scenario_2,
         detection_window_shift_scenario_2,
         detection_medfilt_lengths_scenario_2, device):
    """Tune tagging, boundary-detection and SED hyper-parameters.

    The function loads the CRNN ensemble, resolves the validation ground
    truth, tunes tagging thresholds, derives tags from them, tunes boundary
    detection and, if enabled, tunes sound event detection for scenario 1
    and/or scenario 2, dumping the selected hyper-parameters as JSON into
    ``storage_dir``. Afterwards ``storage_dir`` is symlinked into every CRNN
    directory and, if ``eval_set_name`` is set, ``evaluation.run`` is
    launched once per tuned scenario. The tuning part is measured with a
    codecarbon ``EmissionsTracker`` writing to ``storage_dir``; the tracker
    is stopped before the evaluation runs start.

    Args:
        _run: Run object handed to ``print_config``.
        storage_dir: Output directory; converted to ``Path`` after the
            tracker has been created.
        debug: Forwarded to ``evaluation.run`` via ``config_updates``.
        crnn_dirs: Storage directories the CRNNs are loaded from.
        crnn_checkpoints: A single checkpoint name (``str``, reused for
            every directory) or a list of names.
        data_provider: Configuration passed to ``DataProvider.from_config``.
        validation_set_name: Dataset used for tuning.
        validation_ground_truth_filepath: Ground-truth file; derived
            automatically for ``'validation'`` and ``'eval_public'`` when
            falsy. Must exist.
        eval_set_name: If truthy, dataset name(s) passed to
            ``evaluation.run``.
        eval_ground_truth_filepath: Forwarded to ``evaluation.run``.
        boundaries_filter_lengths: Passed as ``stepfilt_lengths`` to the
            boundary-detection tuning.
        tune_detection_scenario_1: Enable SED tuning for scenario 1.
        detection_window_lengths_scenario_1: ``window_lengths`` for
            scenario 1.
        detection_window_shift_scenario_1: ``window_shift`` for scenario 1.
        detection_medfilt_lengths_scenario_1: ``medfilt_lengths`` for
            scenario 1.
        tune_detection_scenario_2: Enable SED tuning for scenario 2.
        detection_window_lengths_scenario_2: ``window_lengths`` for
            scenario 2.
        detection_window_shift_scenario_2: ``window_shift`` for scenario 2.
        detection_medfilt_lengths_scenario_2: ``medfilt_lengths`` for
            scenario 2.
        device: Forwarded to the tuning helpers.
    """
    print()
    print('##### Tuning #####')
    print()
    print_config(_run)
    print(storage_dir)
    emissions_tracker = EmissionsTracker(output_dir=storage_dir,
                                         on_csv_write="update",
                                         log_level='error')
    emissions_tracker.start()
    storage_dir = Path(storage_dir)

    # Metric settings handed to the tuning helpers below.
    boundaries_collar_based_params = {
        'onset_collar': .5,
        'offset_collar': .5,
        'offset_collar_rate': .0,
        'min_precision': .8,
    }
    collar_based_params = {
        'onset_collar': .2,
        'offset_collar': .2,
        'offset_collar_rate': .2,
    }
    psds_scenario_1 = {
        'dtc_threshold': 0.7,
        'gtc_threshold': 0.7,
        'cttc_threshold': None,
        'alpha_ct': .0,
        'alpha_st': 1.,
    }
    psds_scenario_2 = {
        'dtc_threshold': 0.1,
        'gtc_threshold': 0.1,
        'cttc_threshold': 0.3,
        'alpha_ct': .5,
        'alpha_st': 1.,
    }

    # A single checkpoint name is reused for every CRNN directory.
    if not isinstance(crnn_checkpoints, list):
        assert isinstance(crnn_checkpoints, str), crnn_checkpoints
        crnn_checkpoints = len(crnn_dirs) * [crnn_checkpoints]
    crnns = [
        weak_label.CRNN.from_storage_dir(storage_dir=crnn_dir,
                                         config_name='1/config.json',
                                         checkpoint_name=crnn_checkpoint)
        for crnn_dir, crnn_checkpoint in zip(crnn_dirs, crnn_checkpoints)
    ]
    data_provider = DataProvider.from_config(data_provider)
    data_provider.test_transform.label_encoder.initialize_labels()
    # Turn the index -> label mapping into a list ordered by index.
    event_classes = data_provider.test_transform.label_encoder.inverse_label_mapping
    event_classes = [event_classes[i] for i in range(len(event_classes))]
    # STFT shift divided by the sample rate (presumably seconds per frame).
    frame_shift = data_provider.test_transform.stft.shift
    frame_shift /= data_provider.audio_reader.target_sample_rate

    # For the two known dataset names, derive a default ground-truth file
    # from the location of the first example's audio file (four directory
    # levels up) when none was given.
    if validation_set_name == 'validation' and not validation_ground_truth_filepath:
        database_root = Path(
            data_provider.get_raw('validation')[0]
            ['audio_path']).parent.parent.parent.parent
        validation_ground_truth_filepath = database_root / 'metadata' / 'validation' / 'validation.tsv'
    elif validation_set_name == 'eval_public' and not validation_ground_truth_filepath:
        database_root = Path(
            data_provider.get_raw('eval_public')[0]
            ['audio_path']).parent.parent.parent.parent
        validation_ground_truth_filepath = database_root / 'metadata' / 'eval' / 'public.tsv'
    # Tuning needs an existing ground-truth file.
    assert isinstance(
        validation_ground_truth_filepath,
        (str, Path)) and Path(validation_ground_truth_filepath).exists(
        ), validation_ground_truth_filepath

    dataset = data_provider.get_dataset(validation_set_name)
    # example_id -> audio length for every example of the validation set.
    audio_durations = {
        example['example_id']: example['audio_length']
        for example in data_provider.db.get_dataset(validation_set_name)
    }

    # --- Tagging: one [0, duration] interval per audio. ---
    timestamps = {
        audio_id: np.array([0., audio_durations[audio_id]])
        for audio_id in audio_durations
    }
    metrics = {
        'f':
        partial(base.f_tag,
                ground_truth=validation_ground_truth_filepath,
                num_jobs=8)
    }
    leaderboard = weak_label.crnn.tune_tagging(crnns,
                                               dataset,
                                               device,
                                               timestamps,
                                               event_classes,
                                               metrics,
                                               storage_dir=storage_dir)
    # The leaderboard entry for metric 'f' is unpacked as
    # (ignored, hyper-parameters, scores).
    _, hyper_params, tagging_scores = leaderboard['f']
    tagging_thresholds = np.array([
        hyper_params[event_class]['threshold'] for event_class in event_classes
    ])
    # Per-audio boolean decisions: score above the tuned per-class threshold.
    tags = {
        audio_id:
        tagging_scores[audio_id][event_classes].to_numpy() > tagging_thresholds
        for audio_id in tagging_scores
    }

    # --- Boundary detection on a fixed grid of 10000 frame-shift steps. ---
    boundaries_ground_truth = base.boundaries_from_events(
        validation_ground_truth_filepath)
    timestamps = np.arange(0, 10000) * frame_shift
    metrics = {
        'f':
        partial(
            base.f_collar,
            ground_truth=boundaries_ground_truth,
            return_onset_offset_bias=True,
            num_jobs=8,
            **boundaries_collar_based_params,
        ),
    }
    weak_label.crnn.tune_boundary_detection(
        crnns,
        dataset,
        device,
        timestamps,
        event_classes,
        tags,
        metrics,
        tag_masking=True,
        stepfilt_lengths=boundaries_filter_lengths,
        storage_dir=storage_dir)

    # --- Sound event detection, scenario 1: tuned for both the collar-based
    # F-score ('f') and the PSD AUC ('auc'). ---
    if tune_detection_scenario_1:
        metrics = {
            'f':
            partial(
                base.f_collar,
                ground_truth=validation_ground_truth_filepath,
                return_onset_offset_bias=True,
                num_jobs=8,
                **collar_based_params,
            ),
            'auc':
            partial(
                base.psd_auc,
                ground_truth=validation_ground_truth_filepath,
                audio_durations=audio_durations,
                num_jobs=8,
                **psds_scenario_1,
            ),
        }
        # NOTE(review): the meaning of the '?' tag-masking value is defined
        # by tune_sound_event_detection and is not visible here.
        leaderboard = weak_label.crnn.tune_sound_event_detection(
            crnns,
            dataset,
            device,
            timestamps,
            event_classes,
            tags,
            metrics,
            tag_masking={
                'f': True,
                'auc': '?'
            },
            window_lengths=detection_window_lengths_scenario_1,
            window_shift=detection_window_shift_scenario_1,
            medfilt_lengths=detection_medfilt_lengths_scenario_1,
        )
        dump_json(leaderboard['f'][1],
                  storage_dir / f'sed_hyper_params_f.json')
        # Add F-score-optimal thresholds to the AUC-tuned hyper-parameters
        # before saving them; only ``thresholds`` is used from the result.
        f, p, r, thresholds, _ = collar_based.best_fscore(
            scores=leaderboard['auc'][2],
            ground_truth=validation_ground_truth_filepath,
            **collar_based_params,
            num_jobs=8)
        for event_class in thresholds:
            leaderboard['auc'][1][event_class]['threshold'] = thresholds[
                event_class]
        dump_json(leaderboard['auc'][1],
                  storage_dir / 'sed_hyper_params_psds1.json')
    # --- Sound event detection, scenario 2: tuned for the PSD AUC only,
    # without tag masking. ---
    if tune_detection_scenario_2:
        metrics = {
            'auc':
            partial(
                base.psd_auc,
                ground_truth=validation_ground_truth_filepath,
                audio_durations=audio_durations,
                num_jobs=8,
                **psds_scenario_2,
            )
        }
        leaderboard = weak_label.crnn.tune_sound_event_detection(
            crnns,
            dataset,
            device,
            timestamps,
            event_classes,
            tags,
            metrics,
            tag_masking=False,
            window_lengths=detection_window_lengths_scenario_2,
            window_shift=detection_window_shift_scenario_2,
            medfilt_lengths=detection_medfilt_lengths_scenario_2,
        )
        # Same threshold selection as in scenario 1.
        f, p, r, thresholds, _ = collar_based.best_fscore(
            scores=leaderboard['auc'][2],
            ground_truth=validation_ground_truth_filepath,
            **collar_based_params,
            num_jobs=8)
        for event_class in thresholds:
            leaderboard['auc'][1][event_class]['threshold'] = thresholds[
                event_class]
        dump_json(leaderboard['auc'][1],
                  storage_dir / 'sed_hyper_params_psds2.json')
    # Link the tuning results from <crnn_dir>/hyper_params of every model.
    # NOTE(review): Path.symlink_to raises FileExistsError if the link
    # already exists, and the tracker is not stopped if anything above raises.
    for crnn_dir in crnn_dirs:
        tuning_dir = Path(crnn_dir) / 'hyper_params'
        os.makedirs(str(tuning_dir), exist_ok=True)
        (tuning_dir / storage_dir.name).symlink_to(storage_dir)
    emissions_tracker.stop()
    print(storage_dir)

    # Optionally evaluate the tuned hyper-parameters; these runs happen after
    # the tracker above has been stopped.
    if eval_set_name:
        if tune_detection_scenario_1:
            evaluation.run(config_updates={
                'debug': debug,
                'hyper_params_dir': str(storage_dir),
                'dataset_name': eval_set_name,
                'ground_truth_filepath': eval_ground_truth_filepath,
            }, )
        if tune_detection_scenario_2:
            evaluation.run(config_updates={
                'debug': debug,
                'hyper_params_dir': str(storage_dir),
                'dataset_name': eval_set_name,
                'ground_truth_filepath': eval_ground_truth_filepath,
                'sed_hyper_params_name': 'psds2',
            }, )
def _print_emission_equivalents(emissions):
    """Print a CO2 estimate with its car-distance and TV-time equivalents."""
    print(f" * Estimated Co2 emissions: {emissions} kg")
    # NOTE(review): the helper is named "..._car_miles_..." while the label
    # says "kms"; confirm which unit the helper actually returns.
    car_kms = get_emission_car_miles_equivalent(emissions)
    print(f" * Equivalent in distance travelled by avg car: {car_kms} kms")
    tv_time = get_emission_tv_time_equivalent(emissions)
    print(
        f" * Equivalent in amount of time watching a 32-inch LCD flat screen TV: {tv_time}"
    )


def main():
    """Main function that creates and executes the BIDS App singularity command.

    When carbon-footprint tracking is requested, the run is wrapped in a
    codecarbon ``EmissionsTracker`` and a summary is printed afterwards. The
    extrapolation to 100 subjects is skipped when no participant label was
    given, so that the exit code of the run is still returned instead of
    failing on a division by zero.

    Returns
    -------
    exit_code : {0, 1}
        An exit code given to `sys.exit()` that can be:

            * '0' in case of successful completion

            * '1' in case of an error
    """
    # Create and parse arguments
    parser = get_singularity_wrapper_parser()
    args = parser.parse_args()

    # Create the singularity run command
    cmd = create_singularity_cmd(args)

    # Create and start the carbon footprint tracker
    if args.track_carbon_footprint:
        logging.getLogger(
            "codecarbon").disabled = True  # Comment this line for debug
        # NOTE(review): the project name ends in "-docker" although this is
        # the singularity wrapper; confirm whether that is intended.
        tracker = EmissionsTracker(
            project_name=f"MIALSRTK{__version__}-docker",
            output_dir=str(Path(args.bids_dir) / "code"),
            measure_power_secs=15,
        )
        tracker.start()

    # Execute the singularity run command
    try:
        print(f'... cmd: {cmd}')
        run(cmd)
        exit_code = 0
    except Exception as e:
        print('Failed')
        print(e)
        exit_code = 1

    if args.track_carbon_footprint:
        emissions: float = tracker.stop()
        # Treat a missing participant list like an empty one.
        n_subjects = len(args.participant_label or [])
        print("############################################################")
        print(f"CARBON FOOTPRINT OF {n_subjects} SUBJECT(S) PROCESSED")
        print("############################################################")
        _print_emission_equivalents(emissions)
        if n_subjects > 0:
            # Linear extrapolation of the measured emissions to 100 subjects.
            print(
                "############################################################")
            print("PREDICTED CARBON FOOTPRINT OF 100 SUBJECTS PROCESSED")
            print(
                "############################################################")
            pred_emissions = 100 * emissions / n_subjects
            _print_emission_equivalents(pred_emissions)
        print("############################################################")
        print(
            "Results can be visualized with the codecarbon visualization tool using following command:\n"
        )
        print(
            f'\t$ carbonboard --filepath="{args.bids_dir}/code/emissions.csv" --port=9999\n'
        )

    return exit_code
Ejemplo n.º 7
0
def cli_main():
    """Parse the command line and run every configured experiment.

    For each experiment configuration the parser is extended with the
    experiment- and model-specific arguments and the command line is
    re-parsed. The experiment is then run for ``experiment.n_trials`` trials
    (data setup, model setup, training, testing, saving of results) while a
    codecarbon ``EmissionsTracker`` measures it. The tracker is stopped in a
    ``finally`` clause so that it is not left running when a trial raises.
    """

    # ------------
    # Parse args
    # ------------
    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument(f"--{EXPERIMENT_NAME}", choices=experiment_names)
    parser.add_argument(f"--{MODEL_NAME}", choices=model_names)
    parser.add_argument(f"--{EXPERIMENTS_CONFIG}", type=str)
    # Project wide parameters
    for parameter in parameters.values():
        parser.add_argument(
            f"--{parameter.name}",
            default=parameter.default,
            type=parameter.type_,
            choices=parameter.choices,
        )
    args, unknown_args = parser.parse_known_args()

    experiments_config = get_experiments_config(args)

    # ------------
    # Generate experiments
    # ------------
    for experiment_idx, experiment_config in enumerate(experiments_config):

        _experiment_config = copy.deepcopy(experiment_config)
        # Add support for experiment specific arguments
        base_model = get_experiment_base_model(_experiment_config)
        parser = base_model.add_model_specific_args(parser)
        # Reparse with experiment specific arguments
        args, unknown_args = parser.parse_known_args()
        _experiments_config = get_experiments_config(args)

        _experiment_config = _experiments_config[experiment_idx]
        # Add support for model specific arguments
        model = get_model(_experiment_config)
        parser = model.add_model_specific_args(parser)
        # Reparse with model specific arguments
        args, unknown_args = parser.parse_known_args()
        _experiments_config = get_experiments_config(args)

        experiment = Experiment(_experiments_config[experiment_idx])
        print(f"--- Starting Experiment {clean_dict(experiment.__dict__)}")

        # NOTE(review): the API token is hard-coded in source; consider
        # loading it from the environment or a configuration file instead.
        impact_tracker = EmissionsTracker(
            project_name="sggm",
            output_dir="./impact_logs/",
            co2_signal_api_token="06297ab81ba8d269",
        )
        impact_tracker.start()
        try:
            # In the case of a shift split experiment, override number of
            # trials to correspond to the splits.
            if is_shifted_split(experiment.experiment_name):
                if experiment.n_trials is not None:
                    warn(
                        f"[WARNING] The number of trials specified, n_trials={experiment.n_trials},"
                        +
                        f" is overriden for experiment {experiment.experiment_name} (shift split)"
                    )
                experiment.n_trials = experiment.datamodule.dims

            for n_t in range(experiment.n_trials):

                # Use a different, reproducible seed for every trial.
                if isinstance(experiment.seed, int):
                    seed = experiment.seed + n_t
                    pl.seed_everything(seed)

                # ------------
                # data
                # ------------
                datamodule = experiment.datamodule
                if is_shifted_split(experiment.experiment_name):
                    datamodule.setup(dim_idx=n_t,
                                     stage=STAGE_SETUP_SHIFTED_SPLIT)
                else:
                    datamodule.setup()

                # ------------
                # model
                # ------------
                model = experiment.model
                if isinstance(model, VariationalRegressor):
                    model.setup_pig(datamodule)
                if isinstance(model, V3AE):
                    model.save_datamodule(datamodule)
                    model.set_prior_parameters(
                        datamodule,
                        prior_α=experiment.prior_α,
                        prior_β=experiment.prior_β,
                    )

                # ------------
                # training
                # ------------
                if getattr(experiment, SPLIT_TRAINING, None):
                    (
                        experiment,
                        model,
                        datamodule,
                        trainer,
                    ) = split_mean_uncertainty_training(
                        experiment, model, datamodule)
                else:
                    trainer = experiment.trainer
                    trainer.fit(model, datamodule)

                # ------------
                # testing
                # ------------
                results = trainer.test()

                # ------------
                # saving
                # ------------
                torch.save(results, f"{trainer.logger.log_dir}/results.pkl")
                torch.save(
                    get_misc_save_dict(experiment, model),
                    f"{trainer.logger.log_dir}/misc.pkl",
                )
        finally:
            # Always stop the tracker, even when a trial raises, so that it
            # is not left running.
            emissions = impact_tracker.stop()
        print(f"Emissions: {emissions} kg")
Ejemplo n.º 8
0
    else:
        args.experiment_name = [args.experiment_name]

    return args


if __name__ == "__main__":
    # baseline model name and dataset
    # if not specified run for all + all

    # Setup dm and extract full datasets to feed in models.
    args = parse_args()

    impact_tracker = EmissionsTracker(
        project_name="sggm_baseline",
        output_dir="./impact_logs/",
        co2_signal_api_token="06297ab81ba8d269",
    )
    impact_tracker.start()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device

    for model in args.model:
        for experiment_name in args.experiment_name:
            print(
                f"==================== Training model {model} on dataset {experiment_name} ===================="
            )

            # Load data
            dm = datamodules[experiment_name](args.batch_size,
Ejemplo n.º 9
0
def main(
    _run,
    storage_dir,
    strong_label_crnn_hyper_params_dir,
    sed_hyper_params_name,
    strong_label_crnn_dirs,
    strong_label_crnn_checkpoints,
    weak_label_crnn_hyper_params_dir,
    weak_label_crnn_dirs,
    weak_label_crnn_checkpoints,
    device,
    data_provider,
    dataset_name,
    ground_truth_filepath,
    save_scores,
    save_detections,
    max_segment_length,
    segment_overlap,
    strong_pseudo_labeling,
    pseudo_widening,
    pseudo_labelled_dataset_name,
):
    """Run tagging and sound event detection on one or more datasets.

    For each dataset the weak-label CRNN ensemble produces tags, which are
    attached to every example as ``tag_condition`` and then passed, together
    with the strong-label CRNN ensemble, to ``sound_event_detection``.
    Non-empty result dicts are dumped as JSON into ``storage_dir``. Where
    pseudo labeling is requested, the detected events are written as a TSV
    file and stored in a copy of the database under a new dataset name.
    Emissions are tracked for the whole run and the tracker is always
    stopped, even when inference fails.

    Args:
        _run: Run object, only passed to ``print_config``
            (presumably a sacred ``Run`` -- confirm against the caller).
        storage_dir: Output directory for the emissions log, results,
            scores, detections and pseudo-label files.
        strong_label_crnn_hyper_params_dir: Forwarded to
            ``sound_event_detection``; an ``inference`` sub-directory
            holding a symlink to ``storage_dir`` is created inside it.
        sed_hyper_params_name: One name or a list/tuple of names; used for
            result file names and score/detection sub-directories.
        strong_label_crnn_dirs: Storage dirs of the strong-label CRNNs.
        strong_label_crnn_checkpoints: One checkpoint name (applied to
            every strong-label CRNN) or a list with one name per model.
        weak_label_crnn_hyper_params_dir: Forwarded to ``tagging``.
        weak_label_crnn_dirs: Storage dirs of the weak-label CRNNs.
        weak_label_crnn_checkpoints: One checkpoint name (applied to every
            weak-label CRNN) or a list with one name per model.
        device: Forwarded to ``tagging`` and ``sound_event_detection``.
        data_provider: Config passed to ``DESEDProvider.from_config``.
        dataset_name: One dataset name or a list of names.
        ground_truth_filepath: ``None``, a single path, or one entry per
            dataset. Missing entries for ``'eval_public'`` and
            ``'validation'`` are derived from the audio file location.
        save_scores: If truthy, score storage dirs are passed on.
        save_detections: If truthy, detection storage dirs are passed on.
        max_segment_length: ``None`` to process whole recordings, else the
            segment length (presumably in STFT frames, since it is scaled
            by the frame shift -- confirm).
        segment_overlap: Overlap between segments, same unit as
            ``max_segment_length``.
        strong_pseudo_labeling: One value (applied to all datasets) or one
            per dataset; truthy entries enable pseudo labeling.
        pseudo_widening: Forwarded to ``sound_event_detection``.
        pseudo_labelled_dataset_name: One name or a list with one name per
            dataset under which pseudo-labeled data is stored.

    Raises:
        AssertionError: If a checkpoint argument is neither a list nor a
            string, or if per-dataset arguments differ in length from
            ``dataset_name``.
    """
    print()
    print('##### Inference #####')
    print()
    print_config(_run)
    print(storage_dir)
    emissions_tracker = EmissionsTracker(output_dir=storage_dir,
                                         on_csv_write="update",
                                         log_level='error')
    emissions_tracker.start()
    # Everything after start() is guarded so that the tracker is stopped
    # even when inference raises.
    try:
        storage_dir = Path(storage_dir)

        collar_based_params = {
            'onset_collar': .2,
            'offset_collar': .2,
            'offset_collar_rate': .2,
        }
        psds_scenario_1 = {
            'dtc_threshold': 0.7,
            'gtc_threshold': 0.7,
            'cttc_threshold': None,
            'alpha_ct': .0,
            'alpha_st': 1.,
        }
        psds_scenario_2 = {
            'dtc_threshold': 0.1,
            'gtc_threshold': 0.1,
            'cttc_threshold': 0.3,
            'alpha_ct': .5,
            'alpha_st': 1.,
        }

        # ------------
        # load models
        # ------------
        if not isinstance(weak_label_crnn_checkpoints, list):
            assert isinstance(weak_label_crnn_checkpoints,
                              str), weak_label_crnn_checkpoints
            # A single checkpoint name applies to every weak-label model.
            weak_label_crnn_checkpoints = len(weak_label_crnn_dirs) * [
                weak_label_crnn_checkpoints
            ]
        weak_label_crnns = [
            weak_label.CRNN.from_storage_dir(storage_dir=crnn_dir,
                                             config_name='1/config.json',
                                             checkpoint_name=crnn_checkpoint)
            for crnn_dir, crnn_checkpoint in zip(weak_label_crnn_dirs,
                                                 weak_label_crnn_checkpoints)
        ]
        print(
            'Weak Label CRNN Params',
            sum(p.numel() for crnn in weak_label_crnns
                for p in crnn.parameters()))
        print(
            'Weak Label CNN2d Params',
            sum(p.numel() for crnn in weak_label_crnns
                for p in crnn.cnn.cnn_2d.parameters()))

        if not isinstance(strong_label_crnn_checkpoints, list):
            assert isinstance(strong_label_crnn_checkpoints,
                              str), strong_label_crnn_checkpoints
            # A single checkpoint name applies to every strong-label model.
            strong_label_crnn_checkpoints = len(strong_label_crnn_dirs) * [
                strong_label_crnn_checkpoints
            ]
        strong_label_crnns = [
            strong_label.CRNN.from_storage_dir(
                storage_dir=crnn_dir,
                config_name='1/config.json',
                checkpoint_name=crnn_checkpoint)
            for crnn_dir, crnn_checkpoint in zip(
                strong_label_crnn_dirs, strong_label_crnn_checkpoints)
        ]
        print(
            'Strong Label CRNN Params',
            sum(p.numel() for crnn in strong_label_crnns
                for p in crnn.parameters()))
        print(
            'Strong Label CNN2d Params',
            sum(p.numel() for crnn in strong_label_crnns
                for p in crnn.cnn.cnn_2d.parameters()))

        # ------------
        # data
        # ------------
        data_provider = DESEDProvider.from_config(data_provider)
        label_encoder = data_provider.test_transform.label_encoder
        label_encoder.initialize_labels()
        inverse_label_mapping = label_encoder.inverse_label_mapping
        event_classes = [
            inverse_label_mapping[idx]
            for idx in range(len(inverse_label_mapping))
        ]
        # STFT shift divided by the sample rate: duration of one frame.
        frame_shift = data_provider.test_transform.stft.shift
        frame_shift /= data_provider.audio_reader.target_sample_rate

        # ------------
        # normalize per-dataset arguments to lists
        # ------------
        if not isinstance(dataset_name, list):
            dataset_name = [dataset_name]
        if ground_truth_filepath is None:
            ground_truth_filepath = len(dataset_name) * [None]
        elif isinstance(ground_truth_filepath, (str, Path)):
            ground_truth_filepath = [ground_truth_filepath]
        else:
            # Work on a copy: defaults are filled in below and must not
            # leak into the caller's list.
            ground_truth_filepath = list(ground_truth_filepath)
        assert len(ground_truth_filepath) == len(dataset_name)
        if not isinstance(strong_pseudo_labeling, list):
            strong_pseudo_labeling = len(dataset_name) * [
                strong_pseudo_labeling
            ]
        assert len(strong_pseudo_labeling) == len(dataset_name)
        if not isinstance(pseudo_labelled_dataset_name, list):
            pseudo_labelled_dataset_name = [pseudo_labelled_dataset_name]
        assert len(pseudo_labelled_dataset_name) == len(dataset_name)
        # Does not depend on the dataset, so normalize once up front.
        if not isinstance(sed_hyper_params_name, (list, tuple)):
            sed_hyper_params_name = [sed_hyper_params_name]

        # Ground-truth files used when none is given, relative to
        # <database root>/metadata.
        default_ground_truth = {
            'eval_public': ('eval', 'public.tsv'),
            'validation': ('validation', 'validation.tsv'),
        }

        database = deepcopy(data_provider.db.data)
        for i, name in enumerate(dataset_name):
            print()
            print(name)
            if not ground_truth_filepath[i] and name in default_ground_truth:
                # The database root is four levels above the audio files.
                database_root = Path(
                    data_provider.get_raw(name)[0]['audio_path']
                ).parent.parent.parent.parent
                subdir, filename = default_ground_truth[name]
                ground_truth_filepath[i] = (
                    database_root / 'metadata' / subdir / filename)

            dataset = data_provider.get_dataset(name)
            audio_durations = {
                example['example_id']: example['audio_length']
                for example in data_provider.db.get_dataset(name)
            }

            score_storage_dir = storage_dir / 'scores' / name
            detection_storage_dir = storage_dir / 'detections' / name

            # ------------
            # tagging
            # ------------
            if max_segment_length is None:
                # One segment spanning the whole recording.
                tagging_timestamps = {
                    audio_id: np.array([0., duration])
                    for audio_id, duration in audio_durations.items()
                }
            else:
                tagging_timestamps = {}
                for audio_id, duration in audio_durations.items():
                    ts = np.arange(
                        (2 + max_segment_length) * frame_shift,
                        duration,
                        (max_segment_length - segment_overlap) * frame_shift)
                    # Boundaries are shifted back by half the overlap.
                    tagging_timestamps[audio_id] = np.concatenate(
                        ([0.], ts - segment_overlap / 2 * frame_shift,
                         [duration]))
                dataset = dataset.map(
                    partial(segment_batch,
                            max_length=max_segment_length,
                            overlap=segment_overlap)).unbatch()
            tags, _, _ = tagging(
                weak_label_crnns,
                dataset,
                device,
                tagging_timestamps,
                event_classes,
                weak_label_crnn_hyper_params_dir,
                None,
                None,
            )

            # `tags` is bound as a default so the mapping keeps using this
            # iteration's tags even if it is evaluated later.
            def add_tag_condition(example, tags=tags):
                example["tag_condition"] = np.array(
                    [tags[example_id] for example_id in example["example_id"]])
                return example

            dataset = dataset.map(add_tag_condition)

            # ------------
            # sound event detection
            # ------------
            # Rebuilt per dataset because the callee is not visible here
            # and might modify the array in place.
            sed_timestamps = np.round(np.arange(0, 100000) * frame_shift,
                                      decimals=6)
            events, sed_results = sound_event_detection(
                strong_label_crnns,
                dataset,
                device,
                sed_timestamps,
                event_classes,
                tags,
                strong_label_crnn_hyper_params_dir,
                sed_hyper_params_name,
                ground_truth_filepath[i],
                audio_durations,
                collar_based_params,
                [psds_scenario_1, psds_scenario_2],
                max_segment_length=max_segment_length,
                segment_overlap=segment_overlap,
                pseudo_widening=pseudo_widening,
                score_storage_dir=[
                    score_storage_dir / hp_name
                    for hp_name in sed_hyper_params_name
                ] if save_scores else None,
                detection_storage_dir=[
                    detection_storage_dir / hp_name
                    for hp_name in sed_hyper_params_name
                ] if save_detections else None,
            )
            for j, sed_results_j in enumerate(sed_results):
                if sed_results_j:
                    dump_json(
                        sed_results_j, storage_dir /
                        f'sed_{sed_hyper_params_name[j]}_results_{name}.json'
                    )

            # ------------
            # pseudo labeling
            # ------------
            if strong_pseudo_labeling[i]:
                # Only the events of the first hyper-parameter set are used.
                database['datasets'][
                    pseudo_labelled_dataset_name[i]] = base.pseudo_label(
                        database['datasets'][name],
                        event_classes,
                        False,
                        False,
                        strong_pseudo_labeling[i],
                        None,
                        None,
                        events[0],
                    )
                tsv_path = storage_dir / f'{name}_pseudo_labeled.tsv'
                with tsv_path.open('w') as fid:
                    fid.write('filename\tonset\toffset\tevent_label\n')
                    for key, event_list in events[0].items():
                        if len(event_list) == 0:
                            # Keep files without events as an empty row.
                            fid.write(f'{key}.wav\t\t\t\n')
                        for t_on, t_off, event_label in event_list:
                            fid.write(
                                f'{key}.wav\t{t_on}\t{t_off}\t{event_label}\n')

        if any(strong_pseudo_labeling):
            dump_json(
                database,
                storage_dir / Path(data_provider.json_path).name,
                create_path=True,
                indent=4,
                ensure_ascii=False,
            )

        # Link this run from the hyper-parameter dir. An existing symlink
        # is replaced so a re-run into the same storage dir does not fail
        # with FileExistsError after all the work is done.
        inference_dir = Path(strong_label_crnn_hyper_params_dir) / 'inference'
        os.makedirs(str(inference_dir), exist_ok=True)
        link_path = inference_dir / storage_dir.name
        if link_path.is_symlink():
            link_path.unlink()
        link_path.symlink_to(storage_dir)
    finally:
        emissions_tracker.stop()
    print(storage_dir)