Code Example #1
File: pipeline_factory.py  Project: zge/VQ-VAE-Speech
    def build(configuration, device_configuration, experiments_path,
              experiment_name, results_path):
        data_stream = VCTKFeaturesStream('../data/vctk', configuration,
                                         device_configuration.gpu_ids,
                                         device_configuration.use_cuda)

        if configuration['decoder_type'] == 'deconvolutional':
            vqvae_model = ConvolutionalVQVAE(configuration,
                                             device_configuration.device).to(
                                                 device_configuration.device)
            evaluator = Evaluator(device_configuration.device, vqvae_model,
                                  data_stream, configuration, results_path,
                                  experiment_name)
        else:
            raise NotImplementedError(
                "Decoder type '{}' isn't implemented yet".format(
                    configuration['decoder_type']))

        # Wrap the model for data parallelization, if enabled, before the
        # trainer captures a reference to it
        vqvae_model = nn.DataParallel(
            vqvae_model, device_ids=device_configuration.gpu_ids
        ) if device_configuration.use_data_parallel else vqvae_model

        if configuration['trainer_type'] == 'convolutional':
            trainer = ConvolutionalTrainer(device_configuration.device,
                                           data_stream, configuration,
                                           experiments_path, experiment_name,
                                           model=vqvae_model)
        else:
            raise NotImplementedError(
                "Trainer type '{}' isn't implemented yet".format(
                    configuration['trainer_type']))

        return trainer, evaluator
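
This factory builds the full training pipeline (data stream, model, evaluator, trainer) from a single YAML configuration. A minimal usage sketch follows, assuming build is a static method of PipelineFactory; the import paths and the trainer.train() entry point are assumptions, not the project's confirmed layout:

import yaml

# Hypothetical usage of PipelineFactory.build; module paths are assumptions.
from pipeline_factory import PipelineFactory
from device_configuration import DeviceConfiguration

with open('../configurations/vctk_features.yaml', 'r') as configuration_file:
    configuration = yaml.load(configuration_file, Loader=yaml.FullLoader)
device_configuration = DeviceConfiguration.load_from_configuration(configuration)

trainer, evaluator = PipelineFactory.build(
    configuration, device_configuration,
    experiments_path='../experiments',
    experiment_name='baseline',
    results_path='../results')
trainer.train()  # assumed entry point; not shown in the snippet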
Code Example #2
    def load(experiments_path,
             experiment_name,
             results_path,
             data_path='../data'):
        error_caught = False

        try:
            configuration_file, checkpoint_files = PipelineFactory.load_configuration_and_checkpoints(
                experiments_path, experiment_name)
        except Exception:
            ConsoleLogger.error(
                'Failed to load existing configuration. Building a new model...'
            )
            error_caught = True

        # Load the configuration file (note: configuration_file must have
        # been resolved above, so a missing configuration is not actually
        # recoverable despite the fallback message)
        ConsoleLogger.status('Loading the configuration file')
        configuration = None
        with open(experiments_path + os.sep + configuration_file, 'r') as file:
            configuration = yaml.load(file, Loader=yaml.FullLoader)
        device_configuration = DeviceConfiguration.load_from_configuration(
            configuration)

        if error_caught or len(checkpoint_files) == 0:
            trainer, evaluator = PipelineFactory.build(configuration,
                                                       device_configuration,
                                                       experiments_path,
                                                       experiment_name,
                                                       results_path)
        else:
            latest_checkpoint_file, latest_epoch = CheckpointUtils.search_latest_checkpoint_file(
                checkpoint_files)
            # Update the starting epoch so future training resumes where the
            # checkpoint left off; note that num_epochs is hardcoded here
            configuration['start_epoch'] = latest_epoch
            configuration['num_epochs'] = 60
            # Load the checkpoint file
            checkpoint_path = experiments_path + os.sep + latest_checkpoint_file
            ConsoleLogger.status(
                "Loading the checkpoint file '{}'".format(checkpoint_path))
            checkpoint = torch.load(checkpoint_path,
                                    map_location=device_configuration.device)

            # Load the data stream (note: this absolute dataset path is
            # machine-specific, unlike the relative '../data/vctk' used in
            # build above)
            ConsoleLogger.status('Loading the data stream')
            data_stream = VCTKFeaturesStream('/atlas/u/xuyilun/vctk',
                                             configuration,
                                             device_configuration.gpu_ids,
                                             device_configuration.use_cuda)

            def load_state_dicts(model, checkpoint, model_name,
                                 optimizer_name):
                # Load the state dict from the checkpoint to the model
                model.load_state_dict(checkpoint[model_name])
                # Create an Adam optimizer using the model parameters
                optimizer = optim.Adam(model.parameters())
                # Load the state dict from the checkpoint to the optimizer
                optimizer.load_state_dict(checkpoint[optimizer_name])
                # Map the optimizer memory into the specified device
                for state in optimizer.state.values():
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor):
                            state[k] = v.to(device_configuration.device)
                return model, optimizer

            # If the decoder type is deconvolutional
            if configuration['decoder_type'] == 'deconvolutional':
                # Create the model and map it to the specified device
                vqvae_model = ConvolutionalVQVAE(
                    configuration, device_configuration.device).to(
                        device_configuration.device)
                evaluator = Evaluator(device_configuration.device, vqvae_model,
                                      data_stream, configuration, results_path,
                                      experiment_name)

                # Load the model and optimizer state dicts
                vqvae_model, vqvae_optimizer = load_state_dicts(
                    vqvae_model, checkpoint, 'model', 'optimizer')
            elif configuration['decoder_type'] == 'wavenet':
                vqvae_model = WaveNetVQVAE(configuration,
                                           data_stream.speaker_dic,
                                           device_configuration.device).to(
                                               device_configuration.device)
                evaluator = Evaluator(device_configuration.device, vqvae_model,
                                      data_stream, configuration, results_path,
                                      experiment_name)
                # Load the model and optimizer state dicts
                vqvae_model, vqvae_optimizer = load_state_dicts(
                    vqvae_model, checkpoint, 'model', 'optimizer')
            else:
                raise NotImplementedError(
                    "Decoder type '{}' isn't implemented yet".format(
                        configuration['decoder_type']))

            # Temporary backward compatibility
            if 'trainer_type' not in configuration:
                ConsoleLogger.error(
                    "trainer_type was not found in configuration file. Use 'convolutional' by default."
                )
                configuration['trainer_type'] = 'convolutional'

            # Use data parallelization if needed and available, wrapping the
            # model before the trainer captures a reference to it
            vqvae_model = nn.DataParallel(
                vqvae_model, device_ids=device_configuration.gpu_ids
            ) if device_configuration.use_data_parallel else vqvae_model

            if configuration['trainer_type'] == 'convolutional':
                trainer = ConvolutionalTrainer(
                    device_configuration.device, data_stream, configuration,
                    experiments_path, experiment_name,
                    model=vqvae_model, optimizer=vqvae_optimizer)
            else:
                raise NotImplementedError(
                    "Trainer type '{}' isn't implemented yet".format(
                        configuration['trainer_type']))

        return trainer, evaluator, configuration, device_configuration
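
The inner load_state_dicts helper restores both the model and optimizer state, then explicitly remaps the optimizer's state tensors onto the target device. A self-contained sketch of the same pattern, using a toy nn.Linear model as a stand-in (the toy model is not part of the project):

import torch
import torch.nn as nn
import torch.optim as optim

# Standalone illustration of the load_state_dicts pattern above
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = nn.Linear(4, 2).to(device)
optimizer = optim.Adam(model.parameters())
checkpoint = {'model': model.state_dict(), 'optimizer': optimizer.state_dict()}

model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
# Mirror the device-remapping loop from load_state_dicts
for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)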
Code Example #3
        data_stream = VCTKSpeechStream(configuration,
                                       device_configuration.gpu_ids,
                                       device_configuration.use_cuda)
        data_stream.export_to_features(default_dataset_path, configuration)
        ConsoleLogger.success(
            "VCTK exported to a new features dataset at: '{}'".format(
                default_dataset_path + os.sep +
                configuration['features_path']))
        sys.exit(0)

    if args.evaluate:
        Experiments.load(
            args.experiments_configuration_path).evaluate(evaluation_options)
        ConsoleLogger.success('All evaluation experiments done')
        sys.exit(0)

    if args.compute_dataset_stats:
        configuration = load_configuration(default_configuration_path)
        configuration = update_configuration_from_experiments(
            args.experiments_configuration_path, configuration)
        device_configuration = DeviceConfiguration.load_from_configuration(
            configuration)
        data_stream = VCTKFeaturesStream(default_dataset_path, configuration,
                                         device_configuration.gpu_ids,
                                         device_configuration.use_cuda)
        data_stream.compute_dataset_stats()
        sys.exit(0)

    Experiments.load(args.experiments_configuration_path).train()
    ConsoleLogger.success('All training experiments done')
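
The dispatch above reads its flags from an args object defined earlier in the script. An argparse setup consistent with the attributes it uses is sketched below; the flag spellings and help strings are assumptions:

import argparse

# Hypothetical parser inferred from the attributes used above
# (args.evaluate, args.compute_dataset_stats,
# args.experiments_configuration_path); spellings are assumptions.
parser = argparse.ArgumentParser()
parser.add_argument('--evaluate', action='store_true',
                    help='Evaluate the configured experiments and exit')
parser.add_argument('--compute-dataset-stats', dest='compute_dataset_stats',
                    action='store_true',
                    help='Compute statistics over the features dataset and exit')
parser.add_argument('--experiments-configuration-path',
                    dest='experiments_configuration_path', type=str,
                    help='Path to the experiments configuration file')
args = parser.parse_args()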
Code Example #4
    return raw_audio, trimmed_audio, trimming_indices


if __name__ == "__main__":
    configuration_file_path = '../configurations/vctk_features.yaml'
    set_deterministic_on(1234)

    ConsoleLogger.status(
        'Loading the configuration file {}...'.format(configuration_file_path))
    configuration = None
    with open(configuration_file_path, 'r') as configuration_file:
        configuration = yaml.load(configuration_file, Loader=yaml.FullLoader)
    device_configuration = DeviceConfiguration.load_from_configuration(
        configuration)
    data_stream = VCTKFeaturesStream('../data/vctk', configuration,
                                     device_configuration.gpu_ids,
                                     device_configuration.use_cuda)

    res_type = 'kaiser_fast'
    top_db = 20
    N = 0
    audio_filenames = list()
    original_shifting_times = list()
    sil_duration_gaps = list()
    beginning_trimmed_times = list()
    detected_sil_durations = list()

    with tqdm(data_stream.validation_loader) as bar:
        for features in bar:
            audio_filename = features['wav_filename'][0][0]
            shifting_time = features['shifting_time'].item()
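
The return statement at the top of Code Example #4 and the res_type/top_db parameters suggest a silence-trimming helper built on librosa. A plausible reconstruction is sketched below; the function name, signature, and 16 kHz sample rate are assumptions:

import librosa

# Hypothetical reconstruction of the trimming helper; librosa.effects.trim
# is suggested by the top_db parameter, but this is an assumption.
def trim_silences(wav_path, sr=16000, res_type='kaiser_fast', top_db=20):
    raw_audio, _ = librosa.load(wav_path, sr=sr, res_type=res_type)
    trimmed_audio, trimming_indices = librosa.effects.trim(raw_audio,
                                                           top_db=top_db)
    return raw_audio, trimmed_audio, trimming_indices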