Ejemplo n.º 1
0
def do_the_pre_training(source_domain_training_data, settings):
    """Performs the pre-training.

    This function creates the source ASC model and the label classifier,
    puts both in training mode, creates one optimizer for both of them,
    creates the source-domain validation data loader, and calls
    :func:`the pre-training process <processes.pre_training>`. The
    pre-trained model and classifier states are then saved to disk.

    This function is not used if pre-trained models are
    used for the method.

    :param source_domain_training_data: The source domain training data.
    :type source_domain_training_data: torch.utils.data.DataLoader
    :param settings: The settings to be used.
    :type settings: dict
    """
    source_model = models.get_asc_model(settings)
    classifier = models.get_label_classifier(settings)

    source_model = source_model.train()
    classifier = classifier.train()

    # One optimizer is requested for both modules (passed together as a list).
    optimizer_model_source = models.get_optimizer(
        'optimizer_source_asc', [source_model, classifier], settings
    )

    data_settings = settings['data']
    source_devices = data_settings['source_domain_device']

    with printing.InformAboutProcess(
            'Creating validation data loader for device: {} '.format(
                ', '.join(source_devices)),
    ):
        source_domain_validation_data = get_data_loader(
            for_devices=source_devices,
            split='validation', shuffle=True, drop_last=True,
            batch_size=data_settings['batch_size'],
            data_path=data_settings['data_path'],
            workers=data_settings['workers']
        )

    printing.print_msg('Starting pre-training process', start='\n\n-- ')
    model, classifier = pre_training(
        nb_epochs=settings['pre_training']['nb_epochs'],
        training_data=source_domain_training_data,
        validation_data=source_domain_validation_data,
        model=source_model, classifier=classifier,
        optimizer=optimizer_model_source,
        patience=settings['pre_training']['patience'],
        device=settings['general_settings']['device']
    )

    # Persist both pre-trained modules under the configured model directory.
    models_settings = settings['models']
    modules_functions.save_model_state(
        models_settings['base_dir_name'],
        models_settings['asc_model']['source_model_f_name'],
        model
    )

    modules_functions.save_model_state(
        models_settings['base_dir_name'],
        models_settings['label_classifier']['f_name'],
        classifier
    )
Ejemplo n.º 2
0
def _get_source_training_data_loader(settings):
    """Creates the training-split data loader for the source-domain device(s).

    The loader is shuffled and drops the last incomplete batch.

    :param settings: The settings to be used.
    :type settings: dict
    :return: The data loader returned by ``get_data_loader``.
    """
    data_settings = settings['data']
    devices = data_settings['source_domain_device']
    message = 'Creating training data loader for device: {} '.format(
        ', '.join(devices))

    with printing.InformAboutProcess(message):
        loader = get_data_loader(
            for_devices=devices,
            split='training',
            shuffle=True,
            drop_last=True,
            batch_size=data_settings['batch_size'],
            data_path=data_settings['data_path'],
            workers=data_settings['workers'],
        )
    return loader
Ejemplo n.º 3
0
def _get_models_and_data(settings, is_testing):
    """Retrieves the models and the data to be used.

    This function loads the stored source model, label classifier and
    (adapted) target model, moves them to the configured device and puts
    them in evaluation mode. It then creates the source- and target-domain
    data loaders for the validation split (shuffled) when ``is_testing`` is
    false, or for the test split (not shuffled) when it is true.

    :param settings: The settings to be used.
    :type settings: dict
    :param is_testing: Are we doing testing?
    :type is_testing: bool
    :return: The source model, the target model, the classifier, the\
             source-domain data loader, and the target-domain data loader.
    :rtype: torch.nn.Module, torch.nn.Module, torch.nn.Module, \
            torch.utils.data.DataLoader, torch.utils.data.DataLoader
    """
    device = settings['general_settings']['device']
    models_settings = settings['models']
    data_settings = settings['data']

    def _restore(f_name, module):
        # Load the stored state into ``module``, move it to the device and
        # switch it to evaluation mode.
        return modules_functions.load_model_state(
            models_settings['base_dir_name'], f_name, module
        ).to(device).eval()

    # Modules are created first (same order as always), then restored.
    source_model = models.get_asc_model(settings)
    classifier = models.get_label_classifier(settings)
    target_model = models.get_asc_model(settings)

    source_model = _restore(
        models_settings['asc_model']['source_model_f_name'], source_model)
    classifier = _restore(
        models_settings['label_classifier']['f_name'], classifier)
    target_model = _restore(
        models_settings['asc_model']['target_model_f_name'], target_model)

    split = 'test' if is_testing else 'validation'

    def _eval_loader(devices):
        # Creates the evaluation data loader for the given device list; the
        # progress message names the split that is actually loaded.
        with printing.InformAboutProcess(
                'Creating {} data loader for device: {} '.format(
                    split, ', '.join(devices))):
            return get_data_loader(
                for_devices=devices,
                split=split,
                shuffle=not is_testing,
                # NOTE(review): drop_last=True presumably makes the loader
                # skip the final partial batch, so some examples of each
                # split (including the test split) are not evaluated. Kept
                # in case the evaluation code relies on full batches --
                # confirm.
                drop_last=True,
                batch_size=data_settings['batch_size'],
                data_path=data_settings['data_path'],
                workers=data_settings['workers'])

    s_d_v_data = _eval_loader(data_settings['source_domain_device'])
    t_d_v_data = _eval_loader(data_settings['target_domain_device'])

    return source_model, target_model, classifier, s_d_v_data, t_d_v_data
Ejemplo n.º 4
0
def do_testing(settings):
    """Performs the testing.

    This function loads the source model, the adapted target model and the
    label classifier, creates the test-split data loaders for the source and
    target domains, and calls
    :func:`the evaluation process <processes.evaluation>` for every
    (model, domain) combination.

    The results are printed on the stdout. When
    ``settings['aux_settings']['confusion_matrices']['print_them']`` is
    true, the returned predictions are also used to create confusion-matrix
    figures.

    :param settings: The settings to be used.
    :type settings: dict
    """
    source_m, target_m, classifier, data_loader_s, data_loader_t = _get_models_and_data(
        settings, is_testing=True)

    cm_settings = settings['aux_settings']['confusion_matrices']

    # Arguments shared by every evaluation run; only the model and the
    # evaluation data change between runs.
    common_kwargs = {
        'classifier': classifier,
        'device': settings['general_settings']['device'],
        'return_predictions': cm_settings['print_them'],
    }

    scene_labels = [
        'airport', 'bus', 'metro', 'metro_station', 'park', 'public_square',
        'shopping_mall', 'street_pedestrian', 'street_traffic', 'tram'
    ]

    # (model tag, data tag, model, data loader). The order matches the
    # positional order of the predictions passed to
    # printing.print_confusion_matrices: non-adapted/adapted on source,
    # then non-adapted/adapted on target.
    runs = (
        ('source', 'source', source_m, data_loader_s),
        ('target', 'source', target_m, data_loader_s),
        ('source', 'target', source_m, data_loader_t),
        ('target', 'target', target_m, data_loader_t),
    )

    predictions = []
    for model_tag, data_tag, model, eval_data in runs:
        printing.print_msg(_info_msg.format('testing', model_tag, data_tag),
                           start='\n\n-- ')
        predictions.append(
            evaluation(model=model, eval_data=eval_data, **common_kwargs))

    if cm_settings['print_them']:
        with printing.InformAboutProcess('Creating confusion matrices figures',
                                         start='\n\n--'):
            printing.print_confusion_matrices(
                *predictions, scene_labels, cm_settings)

    printing.print_msg('', start='\n')
Ejemplo n.º 5
0
def testing_process():
    """The testing process.

    Builds a ``MaDConv`` model configured from the command-line ``args``
    (latent layers, CNN channels, dropout given as a percentage, residual
    flag) and loads its weights from a file whose name encodes that same
    configuration. It then runs ``_testing_process`` on every item produced
    by the testing data feeder and prints the median SDR/SIR/SAR (NaNs
    ignored) and the total time. Finally, the per-item metrics are pickled
    to ``metrics_paths``.
    """
    # Check what device we'll be using
    device = 'cuda' if not debug and cuda.is_available() else 'cpu'
    torch.backends.cudnn.benchmark = True

    # Inform about the device and time and date
    printing.print_intro_messages(device)
    printing.print_msg('Starting testing process. '
                       'Debug mode: {}'.format(debug))

    latent_n = args.layers
    channels = args.channels
    # ``args.do`` is divided by 100 to obtain the dropout rate.
    dropout_pct = args.do
    res_tag = 'res' if args.residual else ''

    print("Using " + str(latent_n) + " latent layers between encoder and decoder.", flush=True)
    print("Using " + str(channels) + " cnn channels.", flush=True)
    print("Using " + str(dropout_pct / 100) + " dropout.", flush=True)
    print(("Using residual connections" if args.residual else "Not using residual connections"))

    # Output location name: configuration tags followed by the dataset
    # directory name with its first 7 characters removed.
    output_path = (str(latent_n) + str(channels) + str(dropout_pct) + res_tag
                   + _dataset_parent_dir[7:])
    print("Output path: " + output_path)

    # Weights file name encodes the model configuration; the epoch count in
    # the name is hard-coded to 90.
    pt_path = ("./outputs/states/mad.pt" + str(latent_n) + "latents"
               + str(channels) + "features" + str(dropout_pct / 100)
               + "dropout" + res_tag + str(90) + "epochs")
    print("Weights path: " + pt_path, flush=True)

    # Set up MaD TwinNet
    with printing.InformAboutProcess('Setting up MaD TwinNet'):
        mad = MaDConv(
            cnn_channels=channels,
            inner_kernel_size=1,
            inner_padding=0,
            cnn_dropout=dropout_pct / 100,
            original_input_dim=hyper_parameters['original_input_dim'],
            context_length=hyper_parameters['context_length'],
            latent_n=latent_n,
            residual=args.residual
        )

    with printing.InformAboutProcess('Loading states'):
        # map_location lets weights saved from a CUDA run be loaded when this
        # run uses the CPU (debug mode or no GPU available).
        mad.load_state_dict(load(pt_path, map_location=device))
        mad = mad.to(device).eval()

    with printing.InformAboutProcess('Initializing data feeder'):
        testing_it = data_feeder.data_feeder_testing(
            window_size=hyper_parameters['window_size'],
            fft_size=hyper_parameters['fft_size'],
            hop_size=hyper_parameters['hop_size'],
            seq_length=hyper_parameters['seq_length'],
            context_length=hyper_parameters['context_length'],
            batch_size=1, debug=debug)

    p_testing = partial(
        _testing_process, mad=mad, device=device,
        seq_length=hyper_parameters['seq_length'],
        context_length=hyper_parameters['context_length'],
        window_size=hyper_parameters['window_size'],
        # Hard-coded; originally training_constants['batch_size'].
        batch_size=4,
        hop_size=hyper_parameters['hop_size'], outputPath=output_path)

    printing.print_msg('Testing starts', end='\n\n')

    results = [p_testing(data, index)
               for index, data in enumerate(testing_it())]
    # Each result is unpacked as (sdr, sir, sar, time); transpose the list of
    # per-item results into one tuple per metric.
    sdr, sir, sar, total_time = zip(*results)

    total_time = sum(total_time)

    printing.print_msg('Testing finished', start='\n-- ', end='\n\n')
    # ``i[0]`` takes the first entry of each per-item metric; presumably the
    # row for the separated source of interest -- confirm.
    printing.print_msg(testing_output_string_all.format(
        sdr=np.median([ii for i in sdr for ii in i[0] if not np.isnan(ii)]),
        sir=np.median([ii for i in sir for ii in i[0] if not np.isnan(ii)]),
        sar=np.median([ii for i in sar for ii in i[0] if not np.isnan(ii)]),
        t=total_time), end='\n\n')

    with printing.InformAboutProcess('Saving results... '):
        with open(metrics_paths['sdr'], 'wb') as f:
            pickle.dump(sdr, f, protocol=2)
        with open(metrics_paths['sir'], 'wb') as f:
            pickle.dump(sir, f, protocol=2)
        with open(metrics_paths['sar'], 'wb') as f:
            pickle.dump(sar, f, protocol=2)

    # Repeat the configuration summary after the results.
    print("Using " + str(latent_n) + " latent layers between encoder and decoder.", flush=True)
    print("Using " + str(channels) + " cnn channels.", flush=True)
    print("Using " + str(dropout_pct / 100) + " dropout.", flush=True)

    printing.print_msg('That\'s all folks!')
Ejemplo n.º 6
0
def testing_process():
    """The testing process.

    Builds a ``MaD`` model and loads its weights from
    ``output_states_path['mad']``. It then runs ``_testing_process`` on every
    item produced by the testing data feeder and prints the median SDR and
    SIR (NaNs ignored) together with the total time. Finally, the per-item
    metrics are pickled to ``metrics_paths``.
    """
    # Check what device we'll be using
    device = 'cuda' if not debug and cuda.is_available() else 'cpu'

    # Inform about the device and time and date
    printing.print_intro_messages(device)
    printing.print_msg('Starting testing process. '
                       'Debug mode: {}'.format(debug))

    # Set up MaD TwinNet
    with printing.InformAboutProcess('Setting up MaD TwinNet'):
        mad = MaD(rnn_enc_input_dim=hyper_parameters['reduced_dim'],
                  rnn_dec_input_dim=hyper_parameters['rnn_enc_output_dim'],
                  original_input_dim=hyper_parameters['original_input_dim'],
                  context_length=hyper_parameters['context_length'])

    with printing.InformAboutProcess('Loading states'):
        # map_location lets weights saved from a CUDA run be loaded when this
        # run uses the CPU (debug mode or no GPU available).
        mad.load_state_dict(load(output_states_path['mad'], map_location=device))
        mad = mad.to(device).eval()

    with printing.InformAboutProcess('Initializing data feeder'):
        testing_it = data_feeder.data_feeder_testing(
            window_size=hyper_parameters['window_size'],
            fft_size=hyper_parameters['fft_size'],
            hop_size=hyper_parameters['hop_size'],
            seq_length=hyper_parameters['seq_length'],
            context_length=hyper_parameters['context_length'],
            batch_size=1,
            debug=debug)

    p_testing = partial(_testing_process,
                        mad=mad,
                        device=device,
                        seq_length=hyper_parameters['seq_length'],
                        context_length=hyper_parameters['context_length'],
                        window_size=hyper_parameters['window_size'],
                        batch_size=training_constants['batch_size'],
                        hop_size=hyper_parameters['hop_size'])

    printing.print_msg('Testing starts', end='\n\n')

    results = [p_testing(data, index)
               for index, data in enumerate(testing_it())]
    # Each result is unpacked as (sdr, sir, time); transpose the list of
    # per-item results into one tuple per metric.
    sdr, sir, total_time = zip(*results)

    total_time = sum(total_time)

    printing.print_msg('Testing finished', start='\n-- ', end='\n\n')
    # ``i[0]`` takes the first entry of each per-item metric; presumably the
    # row for the separated source of interest -- confirm.
    printing.print_msg(testing_output_string_all.format(
        sdr=np.median([ii for i in sdr for ii in i[0] if not np.isnan(ii)]),
        sir=np.median([ii for i in sir for ii in i[0] if not np.isnan(ii)]),
        t=total_time),
                       end='\n\n')

    with printing.InformAboutProcess('Saving results... '):
        with open(metrics_paths['sdr'], 'wb') as f:
            pickle.dump(sdr, f, protocol=2)
        with open(metrics_paths['sir'], 'wb') as f:
            pickle.dump(sir, f, protocol=2)

    printing.print_msg('That\'s all folks!')
Ejemplo n.º 7
0
def training_process():
    """The training process.

    Builds a ``MaDTwinNet`` on the selected device, an Adam optimizer over
    its parameters, and the training data feeder. It then runs
    ``training_constants['epochs']`` epochs of ``_one_epoch``. Afterwards,
    only the state of the ``mad`` sub-module is saved, to
    ``output_states_path['mad']``.
    """
    # Check what device we'll be using
    device = 'cuda' if not debug and cuda.is_available() else 'cpu'

    # Inform about the device and time and date
    printing.print_intro_messages(device)
    printing.print_msg(
        'Starting training process. Debug mode: {}'.format(debug))

    # Set up MaD TwinNet
    with printing.InformAboutProcess('Setting up MaD TwinNet'):
        mad_twin_net = MaDTwinNet(
            rnn_enc_input_dim=hyper_parameters['reduced_dim'],
            rnn_dec_input_dim=hyper_parameters['rnn_enc_output_dim'],
            original_input_dim=hyper_parameters['original_input_dim'],
            context_length=hyper_parameters['context_length']).to(device)

    # Get the optimizer
    with printing.InformAboutProcess('Setting up optimizer'):
        optimizer = optim.Adam(mad_twin_net.parameters(),
                               lr=hyper_parameters['learning_rate'])

    # Create the data feeder
    with printing.InformAboutProcess('Initializing data feeder'):
        epoch_it = data_feeder.data_feeder_training(
            window_size=hyper_parameters['window_size'],
            fft_size=hyper_parameters['fft_size'],
            hop_size=hyper_parameters['hop_size'],
            seq_length=hyper_parameters['seq_length'],
            context_length=hyper_parameters['context_length'],
            batch_size=training_constants['batch_size'],
            files_per_pass=training_constants['files_per_pass'],
            debug=debug)

    # Inform about the future
    printing.print_msg('Training starts', end='\n\n')

    # Auxiliary function for aesthetics: binds everything except the epoch
    # index, which changes per call.
    one_epoch = partial(_one_epoch,
                        module=mad_twin_net,
                        epoch_it=epoch_it,
                        solver=optimizer,
                        separation_loss=kl,
                        twin_reg_loss=l2_loss,
                        reg_fnn_masker=sparsity_penalty,
                        reg_fnn_dec=l2_reg_squared,
                        device=device,
                        lambda_l_twin=hyper_parameters['lambda_l_twin'],
                        lambda_1=hyper_parameters['lambda_1'],
                        lambda_2=hyper_parameters['lambda_2'],
                        max_grad_norm=hyper_parameters['max_grad_norm'])

    # Training: run for side effects only; return values are not needed.
    for epoch_index in range(training_constants['epochs']):
        one_epoch(epoch_index=epoch_index)

    # Inform about the past
    printing.print_msg('Training done.', start='\n-- ')

    # Save the model (only the MaD part of the twin network)
    with printing.InformAboutProcess('Saving model'):
        save(mad_twin_net.mad.state_dict(), output_states_path['mad'])

    # Say goodbye!
    printing.print_msg('That\'s all folks!')
Ejemplo n.º 8
0
def do_adaptation(source_domain_training_data, settings):
    """Performs the adaptation.

    Creates the target-domain training data loader, restores the pre-trained
    source model and label classifier, and initialises the target model from
    the *source* model's stored weights. A fresh domain discriminator and the
    two optimizers are then created, and
    :func:`the adaptation process <processes.adaptation>` is called. The
    adapted target model is saved to disk and returned in evaluation mode.

    :param source_domain_training_data: The source domain data.
    :type source_domain_training_data: torch.utils.data.DataLoader
    :param settings: The settings to be used.
    :type settings: dict
    :return: The adapted model.
    :rtype: torch.nn.Module
    """
    data_settings = settings['data']
    target_devices = data_settings['target_domain_device']
    loader_message = 'Creating training data loader for device: {} '.format(
        ', '.join(target_devices))

    with printing.InformAboutProcess(loader_message):
        target_domain_training_data = get_data_loader(
            for_devices=target_devices,
            split='training',
            shuffle=True,
            drop_last=True,
            batch_size=data_settings['batch_size'],
            data_path=data_settings['data_path'],
            workers=data_settings['workers'])

    def restore(f_name, module):
        # Load the stored state ``f_name`` into ``module`` and move it to
        # the configured device.
        loaded = modules_functions.load_model_state(
            settings['models']['base_dir_name'], f_name, module)
        return loaded.to(settings['general_settings']['device'])

    source_model = models.get_asc_model(settings)
    classifier = models.get_label_classifier(settings)

    source_model = restore(
        settings['models']['asc_model']['source_model_f_name'], source_model)
    classifier = restore(
        settings['models']['label_classifier']['f_name'], classifier)

    # The target model starts from the pre-trained *source* weights.
    target_model = models.get_asc_model(settings)
    target_model = restore(
        settings['models']['asc_model']['source_model_f_name'], target_model)

    discriminator = models.get_domain_classifier(settings)

    # Source model and classifier stay frozen in eval mode; only the target
    # model and the discriminator are trained.
    source_model, classifier = source_model.eval(), classifier.eval()
    target_model, discriminator = target_model.train(), discriminator.train()

    optimizer_model_target = models.get_optimizer(
        'optimizer_target_asc', target_model, settings)
    optimizer_discriminator = models.get_optimizer(
        'optimizer_discriminator', discriminator, settings)

    printing.print_msg('Starting adaptation process.', start='\n\n-- ')

    adaptation_settings = settings['adaptation']
    target_model = adaptation(
        epochs=adaptation_settings['nb_epochs'],
        source_model=source_model,
        target_model=target_model,
        classifier=classifier,
        discriminator=discriminator,
        source_data=source_domain_training_data,
        target_data=target_domain_training_data,
        optimizer_target=optimizer_model_target,
        optimizer_discriminator=optimizer_discriminator,
        device=settings['general_settings']['device'],
        labels_loss_w=adaptation_settings['labels_loss_w'],
        first_iter=adaptation_settings['first_iter'],
        n_critic=adaptation_settings['n_critic'])

    modules_functions.save_model_state(
        settings['models']['base_dir_name'],
        settings['models']['asc_model']['target_model_f_name'], target_model)

    return target_model.eval()