Code example #1
File: main.py  Project: scatterbrain333/undaw
def main():
    """The main entry point for the code.
    """
    arg_parser = argument_parsing.get_argument_parser()
    args = arg_parser.parse_args()

    with open(path.join('settings', '{}.yaml'.format(args.config_file))) as f:
        settings = yaml.safe_load(f)

    printing.print_date_and_time()
    printing.inform_about_device(settings['general_settings']['device'])
    printing.print_msg('', start='')
    printing.print_yaml_settings(settings)

    if settings['process_flow']['do_pre_training'] or \
            settings['process_flow']['do_adaptation']:
        source_domain_training_data = _get_source_training_data_loader(
            settings)
    else:
        source_domain_training_data = None

    if settings['process_flow']['do_pre_training']:
        do_the_pre_training(source_domain_training_data, settings)

    if settings['process_flow']['do_adaptation']:
        do_adaptation(source_domain_training_data, settings)
        del source_domain_training_data
        printing.print_msg('', start='\n')

    if settings['process_flow']['do_evaluation']:
        do_evaluation(settings)

    if settings['process_flow']['do_testing']:
        do_testing(settings)
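
For reference, a minimal sketch of the settings structure that main() reads above. The keys are taken from the code; the values are illustrative placeholders, not the project's actual defaults.

# Hypothetical minimal settings dict mirroring the keys accessed in main();
# values are placeholders only.
settings = {
    'general_settings': {'device': 'cpu'},
    'process_flow': {
        'do_pre_training': True,
        'do_adaptation': True,
        'do_evaluation': True,
        'do_testing': False,
    },
}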
Code example #2
def do_the_pre_training(source_domain_training_data, settings):
    """Performs the pre-training.

    This function creates/loads the model, creates
    the data loaders, creates the optimizers, and
    calls :func:`the pre-training process <processes.pre_training>`.

    This function is not used if pre-trained models
    are used for the method.

    :param source_domain_training_data: The source domain training data.
    :type source_domain_training_data: torch.utils.data.DataLoader
    :param settings: The settings to be used.
    :type settings: dict
    """
    source_model = models.get_asc_model(settings)
    classifier = models.get_label_classifier(settings)

    source_model = source_model.train()
    classifier = classifier.train()

    optimizer_model_source = models.get_optimizer(
        'optimizer_source_asc', [source_model, classifier], settings
    )

    with printing.InformAboutProcess(
            'Creating validation data loader for device: {} '.format(
                ', '.join(settings['data']['source_domain_device'])),
    ):
        source_domain_validation_data = get_data_loader(
            for_devices=settings['data']['source_domain_device'],
            split='validation', shuffle=True, drop_last=True,
            batch_size=settings['data']['batch_size'],
            data_path=settings['data']['data_path'], workers=settings['data']['workers']
        )

    printing.print_msg('Starting pre-training process', start='\n\n-- ')
    model, classifier = pre_training(
        nb_epochs=settings['pre_training']['nb_epochs'],
        training_data=source_domain_training_data,
        validation_data=source_domain_validation_data,
        model=source_model, classifier=classifier,
        optimizer=optimizer_model_source,
        patience=settings['pre_training']['patience'],
        device=settings['general_settings']['device']
    )

    modules_functions.save_model_state(
        settings['models']['base_dir_name'],
        settings['models']['asc_model']['source_model_f_name'],
        model
    )

    modules_functions.save_model_state(
        settings['models']['base_dir_name'],
        settings['models']['label_classifier']['f_name'],
        classifier
    )
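
modules_functions.save_model_state is not shown in this listing. A plausible minimal sketch, assuming it simply serializes the module's state dict to base_dir_name/f_name (the actual undaw helper may differ):

import os

import torch


def save_model_state(base_dir_name, f_name, model):
    """Hypothetical sketch: persist only the state dict, per torch convention."""
    os.makedirs(base_dir_name, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(base_dir_name, f_name))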
Code example #3
File: _evaluation.py  Project: scatterbrain333/undaw
def do_evaluation(settings):
    """Performs the evaluation.

    This function creates/loads the model, creates
    the data loaders, creates the optimizers, and
    calls :func:`the evaluation process <processes.evaluation>`.

    The results are printed on the stdout.

    :param settings: The settings to be used.
    :type settings: dict
    """
    source_m, target_m, classifier, data_loader_s, data_loader_t = _get_models_and_data(
        settings, is_testing=False)

    kwargs_s = {
        'classifier': classifier,
        'eval_data': data_loader_s,
        'device': settings['general_settings']['device']
    }

    kwargs_t = {
        'classifier': classifier,
        'eval_data': data_loader_t,
        'device': settings['general_settings']['device']
    }

    printing.print_msg(_info_msg.format('evaluation', 'source', 'source'),
                       start='\n\n-- ')
    evaluation(model=source_m, **kwargs_s)

    printing.print_msg(_info_msg.format('evaluation', 'target', 'source'),
                       start='\n\n-- ')
    evaluation(model=target_m, **kwargs_s)

    printing.print_msg(_info_msg.format('evaluation', 'source', 'target'),
                       start='\n\n-- ')
    evaluation(model=source_m, **kwargs_t)

    printing.print_msg(_info_msg.format('evaluation', 'target', 'target'),
                       start='\n\n-- ')
    evaluation(model=target_m, **kwargs_t)
    printing.print_msg('', start='\n')
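
The four evaluation calls above walk the 2x2 grid of {source, target} model against {source, target} data. An equivalent loop formulation, for illustration only (same names and order as the code above):

# Equivalent restatement of the four evaluation calls above.
for data_name, kwargs in (('source', kwargs_s), ('target', kwargs_t)):
    for model_name, model in (('source', source_m), ('target', target_m)):
        printing.print_msg(
            _info_msg.format('evaluation', model_name, data_name),
            start='\n\n-- ')
        evaluation(model=model, **kwargs)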
Code example #4
def main():
    arg_parser = arg_parsing.get_argument_parser()
    cmd_args = arg_parser.parse_args()
    input_wav = cmd_args.input_wav
    input_list = cmd_args.input_list

    # Require exactly one of: a wav file (-w) or a list file (-l).
    if (input_wav == '') != (len(input_list) != 0):
        printing.print_msg('Please specify **either** a wav file (with -w) '
                           '**or** give a txt file with file names in each '
                           'line (with -l). ')
        printing.print_msg('Exiting.')
        exit(-1)

    input_list = [Path(input_wav)] if len(input_list) == 0 \
        else _get_file_names_from_file(input_list)

    use_me_process(sources_list=input_list,
                   output_file_names=[[
                       '{}_voice.wav'.format(source.stem),
                       '{}_bg_music.wav'.format(source.stem)
                   ] for source in input_list])
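
The (input_wav == '') != (len(input_list) != 0) test above is an exclusive-or: the program exits when both inputs or neither input is given. A standalone check of that truth table:

# Standalone check of the either/or condition used in main() above.
for wav, lst in (('', []), ('a.wav', []), ('', ['f.txt']), ('a.wav', ['f.txt'])):
    bad = (wav == '') != (len(lst) != 0)
    print(wav or '<no wav>', lst or '<no list>', '-> error' if bad else '-> ok')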
Code example #5
def pre_training(nb_epochs, training_data, validation_data, model, classifier,
                 optimizer, patience, device):
    """The pre-training of the model.

    This function pre-trains the acoustic scene classification
    model, using a typical supervised learning scenario.

    If a pre-trained model is used, this function
    is not used.

    :param nb_epochs: The maximum number of epochs.
    :type nb_epochs: int
    :param training_data: The training data.
    :type training_data: torch.utils.data.DataLoader
    :param validation_data: The validation data.
    :type validation_data: torch.utils.data.DataLoader
    :param model: The model to use.
    :type model: torch.nn.Module
    :param classifier: The classifier to use.
    :type classifier: torch.nn.Module
    :param optimizer: The optimizer for the model.
    :type optimizer: torch.optim.Optimizer
    :param patience: The number of epochs of patience for early stopping.
    :type patience: int
    :param device: The device to use.
    :type device: str
    :return: The optimized model and the optimized classifier.
    :rtype: torch.nn.Module, torch.nn.Module
    """
    best_val_acc = -1
    patience_cntr = 0

    for epoch in range(nb_epochs):
        start_time = time()

        model = model.train()
        classifier = classifier.train()

        model, classifier, tr_loss, tr_acc = _training(
            training_data, model, classifier, optimizer, device)

        va_loss, va_acc = None, None

        if validation_data is not None:
            model = model.eval()
            classifier = classifier.eval()
            va_loss, va_acc = _validation(validation_data, model, classifier, device)

            if best_val_acc < va_acc:
                best_val_acc = va_acc
                patience_cntr = 0
            else:
                patience_cntr += 1

        end_time = time() - start_time

        printing.print_pre_training_results(
            epoch=epoch, training_loss=tr_loss, validation_loss=va_loss,
            training_accuracy=tr_acc, validation_accuracy=va_acc,
            time_elapsed=end_time
        )

        # Early stop once the counter exceeds patience; patience <= 0 disables it.
        if patience_cntr > patience > 0:
            break

    printing.print_msg('', start='', end='\n\n')

    return model, classifier
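
The early-stopping logic above halts once validation accuracy has failed to improve for more than patience consecutive epochs. A standalone illustration with a made-up accuracy sequence:

# Miniature of the patience counter used in pre_training(); the accuracy
# values here are invented for illustration.
patience, best_val_acc, patience_cntr = 2, -1, 0
for epoch, va_acc in enumerate([0.50, 0.62, 0.61, 0.60, 0.59]):
    if best_val_acc < va_acc:
        best_val_acc, patience_cntr = va_acc, 0
    else:
        patience_cntr += 1
    print(epoch, va_acc, 'stop' if patience_cntr > patience > 0 else 'continue')
    if patience_cntr > patience > 0:
        break  # stops at epoch 4: three non-improving epochs with patience=2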
Code example #6
File: _evaluation.py  Project: scatterbrain333/undaw
def do_testing(settings):
    """Performs the testing.

    This function creates/loads the model, creates
    the data loaders, creates the optimizers, and
    calls :func:`the evaluation process <processes.evaluation>`.

    The results are printed on the stdout.

    :param settings: The settings to be used.
    :type settings: dict
    """
    source_m, target_m, classifier, data_loader_s, data_loader_t = _get_models_and_data(
        settings, is_testing=True)

    kwargs_s = {
        'classifier': classifier,
        'eval_data': data_loader_s,
        'device': settings['general_settings']['device'],
        'return_predictions':
            settings['aux_settings']['confusion_matrices']['print_them']
    }

    kwargs_t = {
        'classifier': classifier,
        'eval_data': data_loader_t,
        'device': settings['general_settings']['device'],
        'return_predictions':
            settings['aux_settings']['confusion_matrices']['print_them']
    }

    scene_labels = [
        'airport', 'bus', 'metro', 'metro_station', 'park', 'public_square',
        'shopping_mall', 'street_pedestrian', 'street_traffic', 'tram'
    ]

    printing.print_msg(_info_msg.format('testing', 'source', 'source'),
                       start='\n\n-- ')
    predictions_non_adapted_source = evaluation(model=source_m, **kwargs_s)

    printing.print_msg(_info_msg.format('testing', 'target', 'source'),
                       start='\n\n-- ')
    predictions_adapted_source = evaluation(model=target_m, **kwargs_s)

    printing.print_msg(_info_msg.format('testing', 'source', 'target'),
                       start='\n\n-- ')
    predictions_non_adapted_target = evaluation(model=source_m, **kwargs_t)

    printing.print_msg(_info_msg.format('testing', 'target', 'target'),
                       start='\n\n-- ')
    predictions_adapted_target = evaluation(model=target_m, **kwargs_t)

    if settings['aux_settings']['confusion_matrices']['print_them']:
        with printing.InformAboutProcess('Creating confusion matrices figures',
                                         start='\n\n--'):
            printing.print_confusion_matrices(
                predictions_non_adapted_source, predictions_adapted_source,
                predictions_non_adapted_target, predictions_adapted_target,
                scene_labels, settings['aux_settings']['confusion_matrices'])

    printing.print_msg('', start='\n')
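
printing.print_confusion_matrices is not shown in this listing. Assuming each predictions object boils down to arrays of true and predicted class indices over the ten scene_labels above, a minimal sketch of computing one confusion matrix with scikit-learn (hypothetical shapes; the actual undaw plotting helper may differ):

import numpy as np
from sklearn.metrics import confusion_matrix

# Hypothetical class-index arrays over the ten scene labels listed above.
y_true = np.array([0, 1, 2, 2, 9])
y_pred = np.array([0, 1, 2, 3, 9])
cm = confusion_matrix(y_true, y_pred, labels=list(range(10)))
print(cm.shape)  # (10, 10): rows are true classes, columns are predictions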
Code example #7
File: testing.py  Project: pppyykknen/mad-twinnet
def _testing_process(data, index, mad, device, seq_length,
                     context_length, window_size, batch_size,
                     hop_size, outputPath):
    """The testing process over testing data.

    :param data: The testing data.
    :type data: numpy.ndarray
    :param index: The index of the testing data (used for\
                  calculating scores).
    :type index: int
    :param mad: The MaD system.
    :type mad: torch.nn.Module
    :param device: The device to be used.
    :type device: str
    :param seq_length: The sequence length used.
    :type seq_length: int
    :param context_length: The context length used.
    :type context_length: int
    :param window_size: The window size used.
    :type window_size: int
    :param batch_size: The batch size used.
    :type batch_size: int
    :param hop_size: The hop size used.
    :type hop_size: int
    :param outputPath: The output path for the results.
    :type outputPath: str
    :return: The SDR, SIR, and SAR scores, and the time elapsed for\
             the process.
    :rtype: (numpy.ndarray, numpy.ndarray, numpy.ndarray, float)
    """
    s_time = time.time()

    mix, mix_magnitude, mix_phase, voice_true, bg_true = data

    voice_predicted = np.zeros((
            mix_magnitude.shape[0],
            seq_length - context_length * 2,
            window_size), dtype=np.float32)

    for batch in range(int(mix_magnitude.shape[0] / batch_size)):
        b_start = batch * batch_size
        b_end = (batch + 1) * batch_size

        v_in = from_numpy(
            mix_magnitude[b_start:b_end, :, :]).to(device)

        voice_predicted[b_start:b_end, :, :] = mad(
            v_in.unsqueeze(1)).v_j_filt.cpu().numpy()

    tmp_sdr, tmp_sir, tmp_sar = data_feeder.data_process_results_testing(
        index=index, voice_true=voice_true,
        bg_true=bg_true, voice_predicted=voice_predicted,
        window_size=window_size, mix=mix,
        mix_magnitude=mix_magnitude,
        mix_phase=mix_phase, hop=hop_size,
        context_length=context_length, outputPath=outputPath)

    time_elapsed = time.time() - s_time

    printing.print_msg(testing_output_string_per_example.format(
        e=index,
        sdr=np.median([i for i in tmp_sdr[0] if not np.isnan(i)]),
        sir=np.median([i for i in tmp_sir[0] if not np.isnan(i)]),
        sar=np.median([i for i in tmp_sar[0] if not np.isnan(i)]),
        t=time_elapsed
    ))

    return tmp_sdr, tmp_sir, tmp_sar, time_elapsed
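
The medians above skip NaN entries by filtering in a list comprehension; numpy's np.nanmedian computes the same value directly, e.g.:

import numpy as np

scores = np.array([0.5, np.nan, 0.7, 0.9])
assert np.nanmedian(scores) == np.median([s for s in scores if not np.isnan(s)])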
Code example #8
File: testing.py  Project: pppyykknen/mad-twinnet
def testing_process():
    """The testing process.
    """
    # Check what device we'll be using
    device = 'cuda' if not debug and cuda.is_available() else 'cpu'
    torch.backends.cudnn.benchmark = True
    # Inform about the device and time and date
    printing.print_intro_messages(device)
    printing.print_msg('Starting testing process. '
                       'Debug mode: {}'.format(debug))
    latent_n = args.layers
    lats = args.layers
    do = args.do
    channels = args.channels

    print("Using " + str(lats) + " latent layers between encoder and decoder.", flush=True)
    print("Using " + str(args.channels) + " cnn channels.", flush=True)
    print("Using " + str(args.do/100) + " dropout.", flush=True)
    print(("Using residual connections" if args.residual else "Not using residual connections"))

    outputPath = (str(lats) + str(args.channels) + str(args.do)
                  + ("res" if args.residual else "")
                  + _dataset_parent_dir[7:])
    print("Output path: " + outputPath)
    pt_path = ("./outputs/states/mad.pt" + str(lats) + "latents"
               + str(args.channels) + "features" + str(args.do / 100)
               + "dropout" + ("res" if args.residual else "")
               + str(90) + "epochs")
    print("Weights path: " + pt_path, flush=True)
    # Set up MaD TwinNet
    with printing.InformAboutProcess('Setting up MaD TwinNet'):
        mad = MaDConv(
        mad = MaDConv(
            cnn_channels=args.channels,
            inner_kernel_size=1,
            inner_padding=0,
            cnn_dropout=do/100,
            original_input_dim=hyper_parameters['original_input_dim'],
            context_length=hyper_parameters['context_length'],
            latent_n=latent_n,
            residual=args.residual
        )
    with printing.InformAboutProcess('Loading states'):
        mad.load_state_dict(load(pt_path))
        mad = mad.to(device).eval()

    with printing.InformAboutProcess('Initializing data feeder'):
        testing_it = data_feeder.data_feeder_testing(
            window_size=hyper_parameters['window_size'],
            fft_size=hyper_parameters['fft_size'],
            hop_size=hyper_parameters['hop_size'],
            seq_length=hyper_parameters['seq_length'],
            context_length=hyper_parameters['context_length'],
            batch_size=1, debug=debug)

    p_testing = partial(
        _testing_process, mad=mad, device=device,
        seq_length=hyper_parameters['seq_length'],
        context_length=hyper_parameters['context_length'],
        window_size=hyper_parameters['window_size'],
        batch_size=4,  # training_constants['batch_size']
        hop_size=hyper_parameters['hop_size'], outputPath=outputPath)

    printing.print_msg('Testing starts', end='\n\n')

    sdr, sir, sar, total_time = zip(*[
        p_testing(data, index)
        for index, data in enumerate(testing_it())])

    total_time = sum(total_time)

    printing.print_msg('Testing finished', start='\n-- ', end='\n\n')
    printing.print_msg(testing_output_string_all.format(
        sdr=np.median([ii for i in sdr for ii in i[0] if not np.isnan(ii)]),
        sir=np.median([ii for i in sir for ii in i[0] if not np.isnan(ii)]),
        sar=np.median([ii for i in sar for ii in i[0] if not np.isnan(ii)]),
        t=total_time), end='\n\n')

    with printing.InformAboutProcess('Saving results... '):
        with open(metrics_paths['sdr'], 'wb') as f:
            pickle.dump(sdr, f, protocol=2)
        with open(metrics_paths['sir'], 'wb') as f:
            pickle.dump(sir, f, protocol=2)
        with open(metrics_paths['sar'], 'wb') as f:
            pickle.dump(sar, f, protocol=2)
    print("Using " + str(lats) + " latent layers between encoder and decoder.", flush=True)
    print("Using " + str(args.channels) + " cnn channels.", flush=True)
    print("Using " + str(args.do / 100) + " dropout.", flush=True)


    printing.print_msg('That\'s all folks!')
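
The weights path concatenated above is easier to audit as an f-string; an equivalent sketch with the same fields in the same order:

# f-string form of the pt_path built in testing_process() above.
pt_path = (f"./outputs/states/mad.pt{lats}latents{args.channels}features"
           f"{args.do / 100}dropout{'res' if args.residual else ''}90epochs")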
Code example #9
File: testing.py  Project: xiaozhuo12138/mad-twinnet
def testing_process():
    """The testing process.
    """
    # Check what device we'll be using
    device = 'cuda' if not debug and cuda.is_available() else 'cpu'

    # Inform about the device and time and date
    printing.print_intro_messages(device)
    printing.print_msg('Starting testing process. '
                       'Debug mode: {}'.format(debug))

    # Set up MaD TwinNet
    with printing.InformAboutProcess('Setting up MaD TwinNet'):
        mad = MaD(rnn_enc_input_dim=hyper_parameters['reduced_dim'],
                  rnn_dec_input_dim=hyper_parameters['rnn_enc_output_dim'],
                  original_input_dim=hyper_parameters['original_input_dim'],
                  context_length=hyper_parameters['context_length'])

    with printing.InformAboutProcess('Loading states'):
        mad.load_state_dict(load(output_states_path['mad']))
        mad = mad.to(device).eval()

    with printing.InformAboutProcess('Initializing data feeder'):
        testing_it = data_feeder.data_feeder_testing(
            window_size=hyper_parameters['window_size'],
            fft_size=hyper_parameters['fft_size'],
            hop_size=hyper_parameters['hop_size'],
            seq_length=hyper_parameters['seq_length'],
            context_length=hyper_parameters['context_length'],
            batch_size=1,
            debug=debug)

    p_testing = partial(_testing_process,
                        mad=mad,
                        device=device,
                        seq_length=hyper_parameters['seq_length'],
                        context_length=hyper_parameters['context_length'],
                        window_size=hyper_parameters['window_size'],
                        batch_size=training_constants['batch_size'],
                        hop_size=hyper_parameters['hop_size'])

    printing.print_msg('Testing starts', end='\n\n')

    sdr, sir, total_time = zip(*[
        p_testing(data, index)
        for index, data in enumerate(testing_it())])

    total_time = sum(total_time)

    printing.print_msg('Testing finished', start='\n-- ', end='\n\n')
    printing.print_msg(testing_output_string_all.format(
        sdr=np.median([ii for i in sdr for ii in i[0] if not np.isnan(ii)]),
        sir=np.median([ii for i in sir for ii in i[0] if not np.isnan(ii)]),
        t=total_time),
                       end='\n\n')

    with printing.InformAboutProcess('Saving results... '):
        with open(metrics_paths['sdr'], 'wb') as f:
            pickle.dump(sdr, f, protocol=2)
        with open(metrics_paths['sir'], 'wb') as f:
            pickle.dump(sir, f, protocol=2)

    printing.print_msg('That\'s all folks!')
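
The zip(*[...]) construction above transposes a list of per-example result tuples into one sequence per metric. A standalone miniature of the same idiom:

# Each inner tuple is (sdr, sir, time) for one test example.
per_example = [(0.5, 1.0, 2.0), (0.6, 1.1, 2.5), (0.7, 1.2, 3.0)]
sdr, sir, total_time = zip(*per_example)
print(sdr)              # (0.5, 0.6, 0.7)
print(sum(total_time))  # 7.5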
Code example #10
def use_me_process(sources_list, output_file_names):
    """The usage process.

    :param sources_list: The file names to be used.
    :type sources_list: list[pathlib.Path]
    :param output_file_names: The output file names to be used.
    :type output_file_names: list[list[str]]
    """
    printing.print_msg('Welcome to MaD TwinNet.', end='\n\n')
    if debug:
        printing.print_msg('Cannot proceed in debug mode. '
                           'Please set `debug=False` at the settings '
                           'file.')
        printing.print_msg('Exiting.')
        exit(-1)
    printing.print_msg('Now I will extract the voice and the '
                       'background music from the provided files')

    device = 'cuda' if not debug and torch.cuda.is_available() else 'cpu'

    # MaD setting up
    mad = MaD(rnn_enc_input_dim=hyper_parameters['reduced_dim'],
              rnn_dec_input_dim=hyper_parameters['rnn_enc_output_dim'],
              original_input_dim=hyper_parameters['original_input_dim'],
              context_length=hyper_parameters['context_length'])

    mad.load_state_dict(torch.load(output_states_path['mad']))
    mad = mad.to(device).eval()

    testing_it = data_feeder.data_feeder_testing(
        window_size=hyper_parameters['window_size'],
        fft_size=hyper_parameters['fft_size'],
        hop_size=hyper_parameters['hop_size'],
        seq_length=hyper_parameters['seq_length'],
        context_length=hyper_parameters['context_length'],
        batch_size=1,
        debug=debug,
        sources_list=sources_list)

    printing.print_msg('Let\'s go!', end='\n\n')
    total_time = 0

    for index, data in enumerate(testing_it()):

        s_time = time.time()

        mix, mix_magnitude, mix_phase, voice_true, bg_true = data

        voice_predicted = np.zeros(
            (mix_magnitude.shape[0], hyper_parameters['seq_length'] -
             hyper_parameters['context_length'] * 2,
             hyper_parameters['window_size']),
            dtype=np.float32)

        for batch in range(
                int(mix_magnitude.shape[0] /
                    training_constants['batch_size'])):
            b_start = batch * training_constants['batch_size']
            b_end = (batch + 1) * training_constants['batch_size']

            v_in = torch.from_numpy(
                mix_magnitude[b_start:b_end, :, :]).to(device)

            voice_predicted[b_start:b_end, :, :] = mad(
                v_in).v_j_filt.cpu().numpy()

        data_feeder.data_process_results_testing(
            index=index,
            voice_true=voice_true,
            bg_true=bg_true,
            voice_predicted=voice_predicted,
            window_size=hyper_parameters['window_size'],
            mix=mix,
            mix_magnitude=mix_magnitude,
            mix_phase=mix_phase,
            hop=hyper_parameters['hop_size'],
            context_length=hyper_parameters['context_length'],
            output_file_name=output_file_names[index])

        e_time = time.time()

        printing.print_msg(
            usage_output_string_per_example.format(f=sources_list[index],
                                                   t=e_time - s_time))

        total_time += e_time - s_time

    printing.print_msg('MaDTwinNet finished')
    printing.print_msg(usage_output_string_total.format(t=total_time))
    printing.print_msg('That\'s all folks!')
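
One detail of the batching above: int(mix_magnitude.shape[0] / batch_size) floors the division, so segments beyond the last full batch are never passed through the model and the corresponding pre-zeroed rows of voice_predicted stay at zero. A standalone miniature of that arithmetic:

batch_size = 4
n_segments = 10
full_batches = int(n_segments / batch_size)  # 2
processed = full_batches * batch_size        # 8 of the 10 segments
print(full_batches, processed, n_segments - processed)  # 2 8 2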
Code example #11
File: training.py  Project: xiaozhuo12138/mad-twinnet
def _one_epoch(module, epoch_it, solver, separation_loss, twin_reg_loss,
               reg_fnn_masker, reg_fnn_dec, device, epoch_index, lambda_l_twin,
               lambda_1, lambda_2, max_grad_norm):
    """One training epoch for MaD TwinNet.

    :param module: The module of MaD TwinNet.
    :type module: torch.nn.Module
    :param epoch_it: The data iterator for the epoch.
    :type epoch_it: callable
    :param solver: The optimizer to be used.
    :type solver: torch.optim.Optimizer
    :param separation_loss: The loss function used for\
                            the source separation.
    :type separation_loss: callable
    :param twin_reg_loss: The loss function used for the\
                          TwinNet regularization.
    :type twin_reg_loss: callable
    :param reg_fnn_masker: The weight regularization function\
                           for the FNN of the Masker.
    :type reg_fnn_masker: callable
    :param reg_fnn_dec: The weight regularization function\
                        for the FNN of the Denoiser.
    :type reg_fnn_dec: callable
    :param device: The device to be used.
    :type device: str
    :param epoch_index: The current epoch.
    :type epoch_index: int
    :param lambda_l_twin: The weight for the TwinNet loss.
    :type lambda_l_twin: float
    :param lambda_1: The weight for the `reg_fnn_masker`.
    :type lambda_1: float
    :param lambda_2: The weight for the `reg_fnn_dec`.
    :type lambda_2: float
    :param max_grad_norm: The maximum gradient norm for\
                          gradient norm clipping.
    :type max_grad_norm: float
    """
    def _training_iteration(_m, _data, _device, _solver, _sep_l, _reg_twin,
                            _reg_m, _reg_d, _lambda_l_twin, _lambda_1,
                            _lambda_2, _max_grad_norm):
        """One training iteration for the MaD TwinNet.

        :param _m: The module of MaD TwinNet.
        :type _m: torch.nn.Module
        :param _data: The data
        :type _data: numpy.ndarray
        :param _device: The device to be used.
        :type _device: str
        :param _solver: The optimizer to be used.
        :type _solver: torch.optim.Optimizer
        :param _sep_l: The loss function used for the\
                       source separation.
        :type _sep_l: callable
        :param _reg_twin: The loss function used for the\
                          TwinNet regularization.
        :type _reg_twin: callable
        :param _reg_m: The weight regularization function\
                       for the FNN of the Masker.
        :type _reg_m: callable
        :param _reg_d: The weight regularization function\
                       for the FNN of the Denoiser.
        :type _reg_d: callable
        :param _lambda_l_twin: The weight for the TwinNet loss.
        :type _lambda_l_twin: float
        :param _lambda_1: The weight for the `_reg_m`.
        :type _lambda_1: float
        :param _lambda_2: The weight for the `_reg_d`.
        :type _lambda_2: float
        :param _max_grad_norm: The maximum gradient norm for\
                               gradient norm clipping.
        :type _max_grad_norm: float
        :return: The losses for the iteration.
        :rtype: list[float]
        """
        # Get the data to torch and to the device used
        v_in, v_j = [from_numpy(_d).to(_device) for _d in _data]

        # Forward pass of the module
        output = _m(v_in)

        # Calculate losses
        l_m = _sep_l(output.v_j_filt_prime, v_j)
        l_d = _sep_l(output.v_j_filt, v_j)

        l_tw = _sep_l(output.v_j_filt_prime_twin, v_j).mul(_lambda_l_twin)
        l_twin = _reg_twin(output.affine_output, output.h_dec_twin.detach())

        w_reg_masker = _reg_m(
            _m.mad.masker.fnn.linear_layer.weight).mul(_lambda_1)
        w_reg_denoiser = _reg_d(_m.mad.denoiser.fnn_dec.weight).mul(_lambda_2)

        # Make MaD TwinNet objective
        loss = l_m.add(l_d).add(l_tw).add(l_twin).add(w_reg_masker).add(
            w_reg_denoiser)

        # Clear previous gradients
        _solver.zero_grad()

        # Backward pass
        loss.backward()

        # Gradient norm clipping
        nn.utils.clip_grad_norm_(_m.parameters(),
                                 max_norm=_max_grad_norm,
                                 norm_type=2)

        # Optimize
        _solver.step()

        return [l_m.item(), l_d.item(), l_tw.item(), l_twin.item()]

    # Log starting time
    time_start = time.time()

    # Do iteration over all batches
    iter_results = [
        _training_iteration(module, data, device, solver, separation_loss,
                            twin_reg_loss, reg_fnn_masker, reg_fnn_dec,
                            lambda_l_twin, lambda_1, lambda_2, max_grad_norm)
        for data in epoch_it()
    ]

    # Log ending time
    time_end = time.time()

    # Print to stdout
    printing.print_msg(
        training_output_string.format(
            ep=epoch_index,
            t=time_end - time_start,
            **{
                k: v
                for k, v in
                zip(['l_m', 'l_d', 'l_tw', 'l_twin'],
                    [sum(i) / len(iter_results) for i in zip(*iter_results)])
            }))
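
The objective assembled with the chained .add calls above reads, in operator form (l_tw, w_reg_masker, and w_reg_denoiser already carry their lambda factors):

# Operator form of the MaD TwinNet objective built in _training_iteration().
loss = l_m + l_d + l_tw + l_twin + w_reg_masker + w_reg_denoiser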
Code example #12
File: training.py  Project: xiaozhuo12138/mad-twinnet
def training_process():
    """The training process.
    """
    # Check what device we'll be using
    device = 'cuda' if not debug and cuda.is_available() else 'cpu'

    # Inform about the device and time and date
    printing.print_intro_messages(device)
    printing.print_msg(
        'Starting training process. Debug mode: {}'.format(debug))

    # Set up MaD TwinNet
    with printing.InformAboutProcess('Setting up MaD TwinNet'):
        mad_twin_net = MaDTwinNet(
            rnn_enc_input_dim=hyper_parameters['reduced_dim'],
            rnn_dec_input_dim=hyper_parameters['rnn_enc_output_dim'],
            original_input_dim=hyper_parameters['original_input_dim'],
            context_length=hyper_parameters['context_length']).to(device)

    # Get the optimizer
    with printing.InformAboutProcess('Setting up optimizer'):
        optimizer = optim.Adam(mad_twin_net.parameters(),
                               lr=hyper_parameters['learning_rate'])

    # Create the data feeder
    with printing.InformAboutProcess('Initializing data feeder'):
        epoch_it = data_feeder.data_feeder_training(
            window_size=hyper_parameters['window_size'],
            fft_size=hyper_parameters['fft_size'],
            hop_size=hyper_parameters['hop_size'],
            seq_length=hyper_parameters['seq_length'],
            context_length=hyper_parameters['context_length'],
            batch_size=training_constants['batch_size'],
            files_per_pass=training_constants['files_per_pass'],
            debug=debug)

    # Inform about the future
    printing.print_msg('Training starts', end='\n\n')

    # Auxiliary function for aesthetics
    one_epoch = partial(_one_epoch,
                        module=mad_twin_net,
                        epoch_it=epoch_it,
                        solver=optimizer,
                        separation_loss=kl,
                        twin_reg_loss=l2_loss,
                        reg_fnn_masker=sparsity_penalty,
                        reg_fnn_dec=l2_reg_squared,
                        device=device,
                        lambda_l_twin=hyper_parameters['lambda_l_twin'],
                        lambda_1=hyper_parameters['lambda_1'],
                        lambda_2=hyper_parameters['lambda_2'],
                        max_grad_norm=hyper_parameters['max_grad_norm'])

    # Training
    for e in range(training_constants['epochs']):
        one_epoch(epoch_index=e)

    # Inform about the past
    printing.print_msg('Training done.', start='\n-- ')

    # Save the model
    with printing.InformAboutProcess('Saving model'):
        save(mad_twin_net.mad.state_dict(), output_states_path['mad'])

    # Say goodbye!
    printing.print_msg('That\'s all folks!')
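
functools.partial above freezes every argument of _one_epoch except epoch_index, so the training loop reduces to a single call per epoch. A standalone miniature of the same pattern:

from functools import partial


def run_epoch(model_name, lr, epoch_index):
    # Illustrative stand-in for _one_epoch; the names here are invented.
    print(f"epoch {epoch_index}: training {model_name} at lr={lr}")


one_epoch = partial(run_epoch, model_name='mad_twin_net', lr=1e-4)
for e in range(3):
    one_epoch(epoch_index=e)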
Code example #13
def do_adaptation(source_domain_training_data, settings):
    """Performs the adaptation.

    This function creates/loads the model, creates
    the data loaders, creates the optimizers, and
    calls :func:`the adaptation process <processes.adaptation>`.

    :param source_domain_training_data: The source domain data.
    :type source_domain_training_data: torch.utils.data.DataLoader
    :param settings: The settings to be used.
    :type settings: dict
    :return: The adapted model.
    :rtype: torch.nn.Module
    """
    with printing.InformAboutProcess(
            'Creating training data loader for device: {} '.format(', '.join(
                settings['data']['target_domain_device'])), ):
        target_domain_training_data = get_data_loader(
            for_devices=settings['data']['target_domain_device'],
            split='training',
            shuffle=True,
            drop_last=True,
            batch_size=settings['data']['batch_size'],
            data_path=settings['data']['data_path'],
            workers=settings['data']['workers'])

    source_model = models.get_asc_model(settings)
    classifier = models.get_label_classifier(settings)

    source_model = modules_functions.load_model_state(
        settings['models']['base_dir_name'],
        settings['models']['asc_model']['source_model_f_name'],
        source_model).to(settings['general_settings']['device'])

    classifier = modules_functions.load_model_state(
        settings['models']['base_dir_name'],
        settings['models']['label_classifier']['f_name'],
        classifier).to(settings['general_settings']['device'])

    target_model = models.get_asc_model(settings)

    target_model = modules_functions.load_model_state(
        settings['models']['base_dir_name'],
        settings['models']['asc_model']['source_model_f_name'],
        target_model).to(settings['general_settings']['device'])

    discriminator = models.get_domain_classifier(settings)

    source_model = source_model.eval()
    classifier = classifier.eval()

    target_model = target_model.train()
    discriminator = discriminator.train()

    optimizer_model_target = models.get_optimizer('optimizer_target_asc',
                                                  target_model, settings)

    optimizer_discriminator = models.get_optimizer('optimizer_discriminator',
                                                   discriminator, settings)

    printing.print_msg('Starting adaptation process.', start='\n\n-- ')

    target_model = adaptation(
        epochs=settings['adaptation']['nb_epochs'],
        source_model=source_model,
        target_model=target_model,
        classifier=classifier,
        discriminator=discriminator,
        source_data=source_domain_training_data,
        target_data=target_domain_training_data,
        optimizer_target=optimizer_model_target,
        optimizer_discriminator=optimizer_discriminator,
        device=settings['general_settings']['device'],
        labels_loss_w=settings['adaptation']['labels_loss_w'],
        first_iter=settings['adaptation']['first_iter'],
        n_critic=settings['adaptation']['n_critic'])

    modules_functions.save_model_state(
        settings['models']['base_dir_name'],
        settings['models']['asc_model']['target_model_f_name'], target_model)

    target_model = target_model.eval()

    return target_model
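
The adaptation process itself is not part of this listing. The first_iter and n_critic parameters passed above suggest a critic-style update schedule in which the discriminator is trained more often than the target model; a heavily simplified, hypothetical sketch of one such step (all names below are illustrative, not the actual undaw implementation):

def adaptation_step(step, source_batch, target_batch, target_model,
                    discriminator, optimizer_target, optimizer_discriminator,
                    d_loss_fn, g_loss_fn, n_critic, first_iter):
    """Hypothetical single adaptation step with an n_critic schedule."""
    # Update the discriminator on every step.
    optimizer_discriminator.zero_grad()
    d_loss = d_loss_fn(discriminator, source_batch, target_batch, target_model)
    d_loss.backward()
    optimizer_discriminator.step()

    # Update the target model only every n_critic steps (and during the
    # first_iter warm-up), mirroring WGAN-style training schedules.
    if step < first_iter or step % n_critic == 0:
        optimizer_target.zero_grad()
        g_loss = g_loss_fn(discriminator, target_batch, target_model)
        g_loss.backward()
        optimizer_target.step()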