def main():
    """The main entry point for the code.

    Parses the command-line arguments, loads the settings from
    ``settings/<config_file>.yaml``, prints them, and then runs the
    pre-training, adaptation, evaluation, and testing stages that are
    enabled in ``settings['process_flow']``, in that order.
    """
    arg_parser = argument_parsing.get_argument_parser()
    args = arg_parser.parse_args()

    settings_f_name = path.join(
        'settings', '{}.yaml'.format(args.config_file))
    with open(settings_f_name) as f:
        # `yaml.load` without an explicit `Loader` is deprecated since
        # PyYAML 5.1 and raises a TypeError since PyYAML 6.0; it is also
        # unsafe. `safe_load` only builds standard Python types.
        # NOTE(review): confirm no settings file relies on Python-specific
        # YAML tags.
        settings = yaml.safe_load(f)

    printing.print_date_and_time()
    printing.inform_about_device(settings['general_settings']['device'])
    printing.print_msg('', start='')
    printing.print_yaml_settings(settings)

    process_flow = settings['process_flow']

    # The source-domain training loader is used by both the pre-training
    # and the adaptation, so it is created once and shared.
    if process_flow['do_pre_training'] or process_flow['do_adaptation']:
        source_domain_training_data = _get_source_training_data_loader(
            settings)
    else:
        source_domain_training_data = None

    if process_flow['do_pre_training']:
        do_the_pre_training(source_domain_training_data, settings)

    if process_flow['do_adaptation']:
        do_adaptation(source_domain_training_data, settings)

    # Evaluation and testing do not take the source training loader.
    del source_domain_training_data

    printing.print_msg('', start='\n')

    if process_flow['do_evaluation']:
        do_evaluation(settings)

    if process_flow['do_testing']:
        do_testing(settings)
def do_the_pre_training(source_domain_training_data, settings):
    """Performs the pre-training.

    This function creates the acoustic scene classification model and
    the label classifier, creates the optimizer and the source-domain
    validation data loader, and calls :func:`the pre-training process
    <processes.pre_training>`. The resulting model and classifier are
    saved with ``modules_functions.save_model_state``.

    This function is not used if pre-trained models are used for the
    method.

    :param source_domain_training_data: The source domain training data.
    :type source_domain_training_data: torch.utils.data.DataLoader
    :param settings: The settings to be used.
    :type settings: dict
    """
    source_model = models.get_asc_model(settings)
    classifier = models.get_label_classifier(settings)

    source_model = source_model.train()
    classifier = classifier.train()

    # One optimizer is requested for both the model and the classifier.
    optimizer_model_source = models.get_optimizer(
        'optimizer_source_asc', [source_model, classifier], settings
    )

    with printing.InformAboutProcess(
            'Creating validation data loader for device: {} '.format(
                ', '.join(settings['data']['source_domain_device'])),
    ):
        # NOTE(review): `drop_last=True` on the validation split discards
        # the last partial batch from validation; confirm this is intended.
        source_domain_validation_data = get_data_loader(
            for_devices=settings['data']['source_domain_device'],
            split='validation',
            shuffle=True,
            drop_last=True,
            batch_size=settings['data']['batch_size'],
            data_path=settings['data']['data_path'],
            workers=settings['data']['workers']
        )

    printing.print_msg('Starting pre-training process', start='\n\n-- ')

    model, classifier = pre_training(
        nb_epochs=settings['pre_training']['nb_epochs'],
        training_data=source_domain_training_data,
        validation_data=source_domain_validation_data,
        model=source_model,
        classifier=classifier,
        optimizer=optimizer_model_source,
        patience=settings['pre_training']['patience'],
        device=settings['general_settings']['device']
    )

    modules_functions.save_model_state(
        settings['models']['base_dir_name'],
        settings['models']['asc_model']['source_model_f_name'],
        model
    )

    modules_functions.save_model_state(
        settings['models']['base_dir_name'],
        settings['models']['label_classifier']['f_name'],
        classifier
    )
def do_evaluation(settings):
    """Performs the evaluation.

    Gets the source and target models, the label classifier, and the
    source- and target-domain evaluation data loaders (through
    ``_get_models_and_data`` with ``is_testing=False``). It then calls
    :func:`the evaluation process <processes.evaluation>` for every
    combination of model (source/target) and data (source/target). The
    results are printed on the stdout.

    :param settings: The settings to be used.
    :type settings: dict
    """
    (source_m, target_m, classifier,
     data_loader_s, data_loader_t) = _get_models_and_data(
        settings, is_testing=False)

    device = settings['general_settings']['device']
    kwargs_s = {'classifier': classifier,
                'eval_data': data_loader_s,
                'device': device}
    kwargs_t = {'classifier': classifier,
                'eval_data': data_loader_t,
                'device': device}

    # (model label, model, data label, evaluation kwargs), in run order.
    runs = (
        ('source', source_m, 'source', kwargs_s),
        ('target', target_m, 'source', kwargs_s),
        ('source', source_m, 'target', kwargs_t),
        ('target', target_m, 'target', kwargs_t),
    )

    for model_label, model, data_label, run_kwargs in runs:
        printing.print_msg(
            _info_msg.format('evaluation', model_label, data_label),
            start='\n\n-- ')
        evaluation(model=model, **run_kwargs)

    printing.print_msg('', start='\n')
def main():
    """Command-line entry point for applying MaD TwinNet to user files.

    Exactly one input must be given: a single wav file (``-w``) or a txt
    file with one file name per line (``-l``). Otherwise a message is
    printed and the program exits with status -1. For every input
    source, the output file names are ``<stem>_voice.wav`` and
    ``<stem>_bg_music.wav``.
    """
    # `exit()` is a `site` helper meant for the interactive shell;
    # programs should use `sys.exit()`.
    import sys

    arg_parser = arg_parsing.get_argument_parser()
    cmd_args = arg_parser.parse_args()

    input_wav = cmd_args.input_wav
    input_list = cmd_args.input_list

    wav_given = input_wav != ''
    list_given = len(input_list) != 0

    # Both given, or neither given, is an error.
    if wav_given == list_given:
        printing.print_msg('Please specify **either** a wav file (with -w) '
                           '**or** give a txt file with file names in each '
                           'line (with -l). ')
        printing.print_msg('Exiting.')
        sys.exit(-1)

    if list_given:
        sources = _get_file_names_from_file(input_list)
    else:
        sources = [Path(input_wav)]

    use_me_process(
        sources_list=sources,
        output_file_names=[
            ['{}_voice.wav'.format(source.stem),
             '{}_bg_music.wav'.format(source.stem)]
            for source in sources
        ])
def pre_training(nb_epochs, training_data, validation_data, model,
                 classifier, optimizer, patience, device):
    """Supervised pre-training of the acoustic scene classification model.

    The model and classifier are trained for at most ``nb_epochs``
    epochs. When ``validation_data`` is given, the validation accuracy is
    tracked after every epoch. Training stops early once the number of
    consecutive epochs without improvement exceeds ``patience``. A
    ``patience`` of 0 or less disables early stopping. This function is
    not used when a pre-trained model is used.

    NOTE(review): the returned modules are those of the last epoch run,
    not of the epoch with the best validation accuracy.

    :param nb_epochs: The maximum amount of epochs.
    :type nb_epochs: int
    :param training_data: The training data.
    :type training_data: torch.utils.data.DataLoader
    :param validation_data: The validation data, or None to skip
        validation (and therefore early stopping).
    :type validation_data: torch.utils.data.DataLoader | None
    :param model: The model to use.
    :type model: torch.nn.Module
    :param classifier: The classifier to use.
    :type classifier: torch.nn.Module
    :param optimizer: The optimizer for the model.
    :type optimizer: torch.optim.Optimizer
    :param patience: The amount of epochs for validation patience.
    :type patience: int
    :param device: The device to use.
    :type device: str
    :return: The optimized model and the optimized classifier.
    :rtype: torch.nn.Module, torch.nn.Module
    """
    best_val_acc = -1
    epochs_without_improvement = 0

    for epoch in range(nb_epochs):
        epoch_start = time()

        model, classifier = model.train(), classifier.train()
        model, classifier, tr_loss, tr_acc = _training(
            training_data, model, classifier, optimizer, device)

        if validation_data is None:
            va_loss = va_acc = None
        else:
            model, classifier = model.eval(), classifier.eval()
            va_loss, va_acc = _validation(
                validation_data, model, classifier, device)

            if best_val_acc < va_acc:
                best_val_acc, epochs_without_improvement = va_acc, 0
            else:
                epochs_without_improvement += 1

        epoch_duration = time() - epoch_start

        printing.print_pre_training_results(
            epoch=epoch, training_loss=tr_loss,
            validation_loss=va_loss, training_accuracy=tr_acc,
            validation_accuracy=va_acc, time_elapsed=epoch_duration)

        # Early stopping is active only for a positive patience.
        if patience > 0 and epochs_without_improvement > patience:
            break

    printing.print_msg('', start='', end='\n\n')

    return model, classifier
def do_testing(settings):
    """Performs the testing.

    Gets the source and target models, the label classifier, and the
    source- and target-domain testing data loaders (through
    ``_get_models_and_data`` with ``is_testing=True``). It then calls
    :func:`the evaluation process <processes.evaluation>` for every
    combination of model (source/target) and data (source/target). The
    results are printed on the stdout.

    When ``settings['aux_settings']['confusion_matrices']['print_them']``
    is set, predictions are requested from every evaluation and passed,
    with the scene labels, to ``printing.print_confusion_matrices``.

    :param settings: The settings to be used.
    :type settings: dict
    """
    (source_m, target_m, classifier,
     data_loader_s, data_loader_t) = _get_models_and_data(
        settings, is_testing=True)

    cm_settings = settings['aux_settings']['confusion_matrices']
    common_kwargs = {
        'classifier': classifier,
        'device': settings['general_settings']['device'],
        'return_predictions': cm_settings['print_them'],
    }
    kwargs_s = dict(common_kwargs, eval_data=data_loader_s)
    kwargs_t = dict(common_kwargs, eval_data=data_loader_t)

    scene_labels = [
        'airport', 'bus', 'metro', 'metro_station', 'park',
        'public_square', 'shopping_mall', 'street_pedestrian',
        'street_traffic', 'tram'
    ]

    # (model label, model, data label, evaluation kwargs), in run order.
    runs = (
        ('source', source_m, 'source', kwargs_s),
        ('target', target_m, 'source', kwargs_s),
        ('source', source_m, 'target', kwargs_t),
        ('target', target_m, 'target', kwargs_t),
    )

    predictions = []
    for model_label, model, data_label, run_kwargs in runs:
        printing.print_msg(
            _info_msg.format('testing', model_label, data_label),
            start='\n\n-- ')
        predictions.append(evaluation(model=model, **run_kwargs))

    (predictions_non_adapted_source, predictions_adapted_source,
     predictions_non_adapted_target, predictions_adapted_target) = predictions

    if cm_settings['print_them']:
        with printing.InformAboutProcess(
                'Creating confusion matrices figures', start='\n\n--'):
            printing.print_confusion_matrices(
                predictions_non_adapted_source,
                predictions_adapted_source,
                predictions_non_adapted_target,
                predictions_adapted_target,
                scene_labels,
                cm_settings)

    printing.print_msg('', start='\n')
def _testing_process(data, index, mad, device, seq_length,
                     context_length, window_size, batch_size,
                     hop_size, outputPath):
    """The testing process over one testing example.

    :param data: The testing data of one example: the mix, its
        magnitude and phase, and the true voice and background.
    :type data: tuple[numpy.ndarray]
    :param index: The index of the testing data (used for\
                  calculating scores).
    :type index: int
    :param mad: The MaD system.
    :type mad: torch.nn.Module
    :param device: The device to be used.
    :type device: str
    :param seq_length: The sequence length used.
    :type seq_length: int
    :param context_length: The context length used.
    :type context_length: int
    :param window_size: The window size used.
    :type window_size: int
    :param batch_size: The batch size used.
    :type batch_size: int
    :param hop_size: The hop size used.
    :type hop_size: int
    :param outputPath: The output path, forwarded to
        ``data_feeder.data_process_results_testing``.
    :type outputPath: str
    :return: The SDR, SIR, and SAR scores, and the time elapsed for\
             the process.
    :rtype: (numpy.ndarray, numpy.ndarray, numpy.ndarray, float)
    """
    from torch import no_grad

    s_time = time.time()

    mix, mix_magnitude, mix_phase, voice_true, bg_true = data

    nb_sequences = mix_magnitude.shape[0]
    voice_predicted = np.zeros(
        (nb_sequences, seq_length - context_length * 2, window_size),
        dtype=np.float32)

    # Walk over every sequence, including a final partial batch when
    # `nb_sequences` is not a multiple of `batch_size`. Otherwise those
    # trailing sequences would never be predicted and would stay all-zero.
    # Gradients are not needed for inference, so autograd is disabled.
    with no_grad():
        for b_start in range(0, nb_sequences, batch_size):
            b_end = b_start + batch_size
            v_in = from_numpy(
                mix_magnitude[b_start:b_end, :, :]).to(device)
            voice_predicted[b_start:b_end, :, :] = mad(
                v_in.unsqueeze(1)).v_j_filt.cpu().numpy()

    tmp_sdr, tmp_sir, tmp_sar = data_feeder.data_process_results_testing(
        index=index, voice_true=voice_true, bg_true=bg_true,
        voice_predicted=voice_predicted, window_size=window_size,
        mix=mix, mix_magnitude=mix_magnitude, mix_phase=mix_phase,
        hop=hop_size, context_length=context_length,
        outputPath=outputPath)

    time_elapsed = time.time() - s_time

    # Per-example medians over the first row of each score, ignoring NaNs.
    printing.print_msg(testing_output_string_per_example.format(
        e=index,
        sdr=np.median([i for i in tmp_sdr[0] if not np.isnan(i)]),
        sir=np.median([i for i in tmp_sir[0] if not np.isnan(i)]),
        sar=np.median([i for i in tmp_sar[0] if not np.isnan(i)]),
        t=time_elapsed
    ))

    return tmp_sdr, tmp_sir, tmp_sar, time_elapsed
def testing_process():
    """The testing process for the convolutional MaD model (``MaDConv``).

    The model configuration (latent layers, CNN channels, dropout, and
    residual connections) is read from the module-level ``args``. The
    function then:

    - loads the matching weights;
    - runs every testing example through :func:`_testing_process`;
    - prints the median SDR, SIR, and SAR over all examples;
    - pickles the per-example scores to the files in ``metrics_paths``.
    """
    # Check what device we'll be using
    device = 'cuda' if not debug and cuda.is_available() else 'cpu'
    torch.backends.cudnn.benchmark = True

    # Inform about the device and time and date
    printing.print_intro_messages(device)
    printing.print_msg('Starting testing process. '
                       'Debug mode: {}'.format(debug))

    latent_n = args.layers
    channels = args.channels
    do = args.do
    res_tag = "res" if args.residual else ""

    def _print_configuration():
        # Printed both at the start and at the end of the process.
        print("Using " + str(latent_n) +
              " latent layers between encoder and decoder.", flush=True)
        print("Using " + str(channels) + " cnn channels.", flush=True)
        print("Using " + str(do / 100) + " dropout.", flush=True)

    _print_configuration()
    print(("Using residual connections" if args.residual
           else "Not using residual connections"))

    outputPath = (str(latent_n) + str(channels) + str(do) + res_tag +
                  _dataset_parent_dir[7:])
    print("Output path: " + outputPath)

    # The weights file name encodes the model configuration. The trailing
    # "90epochs" is fixed (presumably the epoch count of the saved state).
    pt_path = ("./outputs/states/mad.pt" + str(latent_n) + "latents" +
               str(channels) + "features" + str(do / 100) + "dropout" +
               res_tag + "90epochs")
    print("Weights path: " + pt_path, flush=True)

    # Set up MaD TwinNet
    with printing.InformAboutProcess('Setting up MaD TwinNet'):
        mad = MaDConv(
            cnn_channels=channels,
            inner_kernel_size=1,
            inner_padding=0,
            cnn_dropout=do / 100,
            original_input_dim=hyper_parameters['original_input_dim'],
            context_length=hyper_parameters['context_length'],
            latent_n=latent_n,
            residual=args.residual
        )

    with printing.InformAboutProcess('Loading states'):
        mad.load_state_dict(load(pt_path))
        mad = mad.to(device).eval()

    with printing.InformAboutProcess('Initializing data feeder'):
        testing_it = data_feeder.data_feeder_testing(
            window_size=hyper_parameters['window_size'],
            fft_size=hyper_parameters['fft_size'],
            hop_size=hyper_parameters['hop_size'],
            seq_length=hyper_parameters['seq_length'],
            context_length=hyper_parameters['context_length'],
            batch_size=1,
            debug=debug)

    p_testing = partial(
        _testing_process,
        mad=mad,
        device=device,
        seq_length=hyper_parameters['seq_length'],
        context_length=hyper_parameters['context_length'],
        window_size=hyper_parameters['window_size'],
        # NOTE: fixed at 4 here instead of training_constants['batch_size'].
        batch_size=4,
        hop_size=hyper_parameters['hop_size'],
        outputPath=outputPath)

    printing.print_msg('Testing starts', end='\n\n')

    # Transpose the per-example (sdr, sir, sar, time) results into one
    # tuple per quantity.
    sdr, sir, sar, total_time = zip(*[
        p_testing(data, index)
        for index, data in enumerate(testing_it())
    ])

    total_time = sum(total_time)

    printing.print_msg('Testing finished', start='\n-- ', end='\n\n')

    def _median_score(scores):
        # Median over the first row of every example's scores, ignoring NaNs.
        return np.median([
            value for example in scores for value in example[0]
            if not np.isnan(value)])

    printing.print_msg(testing_output_string_all.format(
        sdr=_median_score(sdr),
        sir=_median_score(sir),
        sar=_median_score(sar),
        t=total_time), end='\n\n')

    with printing.InformAboutProcess('Saving results... '):
        for metric_name, metric_values in (('sdr', sdr),
                                           ('sir', sir),
                                           ('sar', sar)):
            with open(metrics_paths[metric_name], 'wb') as f:
                pickle.dump(metric_values, f, protocol=2)

    _print_configuration()
    printing.print_msg('That\'s all folks!')
def testing_process():
    """The testing process.

    Sets up MaD and loads its states from ``output_states_path['mad']``.
    It then runs every testing example through :func:`_testing_process`,
    prints the median SDR and SIR over all examples, and pickles the
    per-example scores to the files in ``metrics_paths``.
    """
    # Check what device we'll be using
    device = 'cuda' if not debug and cuda.is_available() else 'cpu'

    # Inform about the device and time and date
    printing.print_intro_messages(device)
    printing.print_msg('Starting testing process. '
                       'Debug mode: {}'.format(debug))

    # Set up MaD TwinNet
    with printing.InformAboutProcess('Setting up MaD TwinNet'):
        mad = MaD(rnn_enc_input_dim=hyper_parameters['reduced_dim'],
                  rnn_dec_input_dim=hyper_parameters['rnn_enc_output_dim'],
                  original_input_dim=hyper_parameters['original_input_dim'],
                  context_length=hyper_parameters['context_length'])

    with printing.InformAboutProcess('Loading states'):
        mad.load_state_dict(load(output_states_path['mad']))
        mad = mad.to(device).eval()

    with printing.InformAboutProcess('Initializing data feeder'):
        testing_it = data_feeder.data_feeder_testing(
            window_size=hyper_parameters['window_size'],
            fft_size=hyper_parameters['fft_size'],
            hop_size=hyper_parameters['hop_size'],
            seq_length=hyper_parameters['seq_length'],
            context_length=hyper_parameters['context_length'],
            batch_size=1,
            debug=debug)

    p_testing = partial(_testing_process,
                        mad=mad,
                        device=device,
                        seq_length=hyper_parameters['seq_length'],
                        context_length=hyper_parameters['context_length'],
                        window_size=hyper_parameters['window_size'],
                        batch_size=training_constants['batch_size'],
                        hop_size=hyper_parameters['hop_size'])

    printing.print_msg('Testing starts', end='\n\n')

    # Transpose the per-example results into one tuple per quantity.
    # NOTE(review): this expects this module's `_testing_process` to
    # return exactly (sdr, sir, time_elapsed).
    sdr, sir, total_time = zip(*[
        p_testing(data, index)
        for index, data in enumerate(testing_it())
    ])

    total_time = sum(total_time)

    printing.print_msg('Testing finished', start='\n-- ', end='\n\n')

    printing.print_msg(testing_output_string_all.format(
        sdr=np.median([ii for i in sdr for ii in i[0] if not np.isnan(ii)]),
        sir=np.median([ii for i in sir for ii in i[0] if not np.isnan(ii)]),
        t=total_time), end='\n\n')

    with printing.InformAboutProcess('Saving results... '):
        with open(metrics_paths['sdr'], 'wb') as f:
            pickle.dump(sdr, f, protocol=2)
        with open(metrics_paths['sir'], 'wb') as f:
            pickle.dump(sir, f, protocol=2)

    printing.print_msg('That\'s all folks!')
def use_me_process(sources_list, output_file_names):
    """The usage process.

    Separates the voice and the background music of every file in
    ``sources_list`` with the pre-trained MaD, loaded from
    ``output_states_path['mad']``. The results are handed to
    ``data_feeder.data_process_results_testing`` with the corresponding
    entry of ``output_file_names``. The process refuses to run (exit
    status -1) when ``debug`` is set.

    :param sources_list: The file names to be used.
    :type sources_list: list[pathlib.Path]
    :param output_file_names: The output file names to be used.
    :type output_file_names: list[list[str]]
    """
    # `exit()` is a `site` helper meant for the interactive shell;
    # programs should use `sys.exit()`.
    import sys

    printing.print_msg('Welcome to MaD TwinNet.', end='\n\n')
    if debug:
        printing.print_msg('Cannot proceed in debug mode. '
                           'Please set `debug=False` at the settings '
                           'file.')
        printing.print_msg('Exiting.')
        sys.exit(-1)
    printing.print_msg('Now I will extract the voice and the '
                       'background music from the provided files')

    # `debug` is known to be False past the guard above.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    seq_length = hyper_parameters['seq_length']
    context_length = hyper_parameters['context_length']
    window_size = hyper_parameters['window_size']
    hop_size = hyper_parameters['hop_size']

    # MaD setting up
    mad = MaD(rnn_enc_input_dim=hyper_parameters['reduced_dim'],
              rnn_dec_input_dim=hyper_parameters['rnn_enc_output_dim'],
              original_input_dim=hyper_parameters['original_input_dim'],
              context_length=context_length)

    mad.load_state_dict(torch.load(output_states_path['mad']))
    mad = mad.to(device).eval()

    testing_it = data_feeder.data_feeder_testing(
        window_size=window_size,
        fft_size=hyper_parameters['fft_size'],
        hop_size=hop_size,
        seq_length=seq_length,
        context_length=context_length,
        batch_size=1,
        debug=debug,
        sources_list=sources_list)

    printing.print_msg('Let\'s go!', end='\n\n')

    batch_size = training_constants['batch_size']
    total_time = 0

    for index, data in enumerate(testing_it()):
        s_time = time.time()

        mix, mix_magnitude, mix_phase, voice_true, bg_true = data

        nb_sequences = mix_magnitude.shape[0]
        voice_predicted = np.zeros(
            (nb_sequences, seq_length - context_length * 2, window_size),
            dtype=np.float32)

        # Include the final partial batch so that trailing sequences are
        # not left as zeros. Inference needs no gradients.
        with torch.no_grad():
            for b_start in range(0, nb_sequences, batch_size):
                b_end = b_start + batch_size
                v_in = torch.from_numpy(
                    mix_magnitude[b_start:b_end, :, :]).to(device)
                voice_predicted[b_start:b_end, :, :] = mad(
                    v_in).v_j_filt.cpu().numpy()

        data_feeder.data_process_results_testing(
            index=index,
            voice_true=voice_true,
            bg_true=bg_true,
            voice_predicted=voice_predicted,
            window_size=window_size,
            mix=mix,
            mix_magnitude=mix_magnitude,
            mix_phase=mix_phase,
            hop=hop_size,
            context_length=context_length,
            output_file_name=output_file_names[index])

        e_time = time.time()

        printing.print_msg(
            usage_output_string_per_example.format(f=sources_list[index],
                                                   t=e_time - s_time))

        total_time += e_time - s_time

    printing.print_msg('MaDTwinNet finished')
    printing.print_msg(usage_output_string_total.format(t=total_time))
    printing.print_msg('That\'s all folks!')
def _one_epoch(module, epoch_it, solver, separation_loss, twin_reg_loss,
               reg_fnn_masker, reg_fnn_dec, device, epoch_index,
               lambda_l_twin, lambda_1, lambda_2, max_grad_norm):
    """Runs one training epoch of MaD TwinNet and prints its mean losses.

    For every batch produced by ``epoch_it()`` one optimization step is
    performed. Afterwards, the epoch duration and the losses averaged
    over all batches are printed through ``training_output_string``.

    :param module: The module of MaD TwinNet.
    :type module: torch.nn.Module
    :param epoch_it: The data iterator for the epoch.
    :type epoch_it: callable
    :param solver: The optimizer to be used.
    :type solver: torch.optim.Optimizer
    :param separation_loss: The loss function used for\
                            the source separation.
    :type separation_loss: callable
    :param twin_reg_loss: The loss function used for the\
                          TwinNet regularization.
    :type twin_reg_loss: callable
    :param reg_fnn_masker: The weight regularization function\
                           for the FNN of the Masker.
    :type reg_fnn_masker: callable
    :param reg_fnn_dec: The weight regularization function\
                        for the FNN of the Denoiser.
    :type reg_fnn_dec: callable
    :param device: The device to be used.
    :type device: str
    :param epoch_index: The current epoch.
    :type epoch_index: int
    :param lambda_l_twin: The weight for the TwinNet loss.
    :type lambda_l_twin: float
    :param lambda_1: The weight for the `reg_fnn_masker`.
    :type lambda_1: float
    :param lambda_2: The weight for the `reg_fnn_dec`.
    :type lambda_2: float
    :param max_grad_norm: The maximum gradient norm for\
                          gradient norm clipping.
    :type max_grad_norm: float
    """
    def _train_on_batch(batch):
        """Performs one optimization step on a single batch.

        :param batch: The input and target arrays of the batch. Both are
            converted with ``from_numpy`` and moved to ``device``.
        :type batch: sequence of numpy.ndarray
        :return: The Masker and Denoiser separation losses, the weighted
            TwinNet separation loss, and the TwinNet regularization loss.
        :rtype: list[float]
        """
        v_in, v_j = [from_numpy(array).to(device) for array in batch]

        output = module(v_in)

        loss_masker = separation_loss(output.v_j_filt_prime, v_j)
        loss_denoiser = separation_loss(output.v_j_filt, v_j)
        loss_twin_sep = separation_loss(
            output.v_j_filt_prime_twin, v_j).mul(lambda_l_twin)
        # `h_dec_twin` is detached, so this term sends no gradient
        # through it.
        loss_twin_reg = twin_reg_loss(
            output.affine_output, output.h_dec_twin.detach())

        reg_masker = reg_fnn_masker(
            module.mad.masker.fnn.linear_layer.weight).mul(lambda_1)
        reg_denoiser = reg_fnn_dec(
            module.mad.denoiser.fnn_dec.weight).mul(lambda_2)

        # Summed left to right, in the same order as the terms above.
        objective = (loss_masker + loss_denoiser + loss_twin_sep +
                     loss_twin_reg + reg_masker + reg_denoiser)

        solver.zero_grad()
        objective.backward()
        nn.utils.clip_grad_norm_(module.parameters(),
                                 max_norm=max_grad_norm, norm_type=2)
        solver.step()

        return [loss_masker.item(), loss_denoiser.item(),
                loss_twin_sep.item(), loss_twin_reg.item()]

    epoch_start = time.time()
    batch_losses = [_train_on_batch(batch) for batch in epoch_it()]
    epoch_end = time.time()

    nb_batches = len(batch_losses)
    mean_losses = [sum(column) / nb_batches
                   for column in zip(*batch_losses)]
    loss_names = ['l_m', 'l_d', 'l_tw', 'l_twin']

    printing.print_msg(
        training_output_string.format(
            ep=epoch_index,
            t=epoch_end - epoch_start,
            **dict(zip(loss_names, mean_losses))))
def training_process():
    """The training process.

    Sets up MaD TwinNet, its Adam optimizer, and the training data
    feeder. It then trains for ``training_constants['epochs']`` epochs
    with :func:`_one_epoch` and saves the state of ``mad_twin_net.mad``
    to ``output_states_path['mad']``.
    """
    # Check what device we'll be using
    device = 'cuda' if not debug and cuda.is_available() else 'cpu'

    # Inform about the device and time and date
    printing.print_intro_messages(device)
    printing.print_msg(
        'Starting training process. Debug mode: {}'.format(debug))

    # Set up MaD TwinNet
    with printing.InformAboutProcess('Setting up MaD TwinNet'):
        mad_twin_net = MaDTwinNet(
            rnn_enc_input_dim=hyper_parameters['reduced_dim'],
            rnn_dec_input_dim=hyper_parameters['rnn_enc_output_dim'],
            original_input_dim=hyper_parameters['original_input_dim'],
            context_length=hyper_parameters['context_length']).to(device)

    # Get the optimizer
    with printing.InformAboutProcess('Setting up optimizer'):
        optimizer = optim.Adam(mad_twin_net.parameters(),
                               lr=hyper_parameters['learning_rate'])

    # Create the data feeder
    with printing.InformAboutProcess('Initializing data feeder'):
        epoch_it = data_feeder.data_feeder_training(
            window_size=hyper_parameters['window_size'],
            fft_size=hyper_parameters['fft_size'],
            hop_size=hyper_parameters['hop_size'],
            seq_length=hyper_parameters['seq_length'],
            context_length=hyper_parameters['context_length'],
            batch_size=training_constants['batch_size'],
            files_per_pass=training_constants['files_per_pass'],
            debug=debug)

    # Inform about the future
    printing.print_msg('Training starts', end='\n\n')

    # Bind everything except the epoch index once.
    one_epoch = partial(_one_epoch, module=mad_twin_net,
                        epoch_it=epoch_it, solver=optimizer,
                        separation_loss=kl, twin_reg_loss=l2_loss,
                        reg_fnn_masker=sparsity_penalty,
                        reg_fnn_dec=l2_reg_squared, device=device,
                        lambda_l_twin=hyper_parameters['lambda_l_twin'],
                        lambda_1=hyper_parameters['lambda_1'],
                        lambda_2=hyper_parameters['lambda_2'],
                        max_grad_norm=hyper_parameters['max_grad_norm'])

    # Training: `_one_epoch` is called only for its side effects, so a
    # plain loop is used instead of building a throw-away list.
    for epoch_index in range(training_constants['epochs']):
        one_epoch(epoch_index=epoch_index)

    # Inform about the past
    printing.print_msg('Training done.', start='\n-- ')

    # Save the model
    with printing.InformAboutProcess('Saving model'):
        save(mad_twin_net.mad.state_dict(), output_states_path['mad'])

    # Say goodbye!
    printing.print_msg('That\'s all folks!')
def do_adaptation(source_domain_training_data, settings):
    """Performs the adaptation.

    This function:

    - creates the target-domain training data loader;
    - loads the pre-trained source model and label classifier;
    - initializes the target model from the same source-model state;
    - creates the domain discriminator and the two optimizers;
    - calls :func:`the adaptation process <processes.adaptation>`.

    The adapted target model is saved with
    ``modules_functions.save_model_state``.

    :param source_domain_training_data: The source domain data.
    :type source_domain_training_data: torch.utils.data.DataLoader
    :param settings: The settings to be used.
    :type settings: dict
    :return: The adapted model, switched to evaluation mode.
    :rtype: torch.nn.Module
    """
    data_settings = settings['data']
    with printing.InformAboutProcess(
            'Creating training data loader for device: {} '.format(
                ', '.join(data_settings['target_domain_device'])),
    ):
        target_domain_training_data = get_data_loader(
            for_devices=data_settings['target_domain_device'],
            split='training',
            shuffle=True,
            drop_last=True,
            batch_size=data_settings['batch_size'],
            data_path=data_settings['data_path'],
            workers=data_settings['workers'])

    source_model = models.get_asc_model(settings)
    classifier = models.get_label_classifier(settings)

    base_dir = settings['models']['base_dir_name']
    source_f_name = settings['models']['asc_model']['source_model_f_name']
    device = settings['general_settings']['device']

    source_model = modules_functions.load_model_state(
        base_dir, source_f_name, source_model).to(device)
    classifier = modules_functions.load_model_state(
        base_dir,
        settings['models']['label_classifier']['f_name'],
        classifier).to(device)

    # The target model starts from the same state as the source model.
    target_model = models.get_asc_model(settings)
    target_model = modules_functions.load_model_state(
        base_dir, source_f_name, target_model).to(device)

    discriminator = models.get_domain_classifier(settings)

    # Source model and classifier stay fixed in eval mode; the target
    # model and the discriminator are the ones being trained.
    source_model = source_model.eval()
    classifier = classifier.eval()
    target_model = target_model.train()
    discriminator = discriminator.train()

    optimizer_model_target = models.get_optimizer(
        'optimizer_target_asc', target_model, settings)
    optimizer_discriminator = models.get_optimizer(
        'optimizer_discriminator', discriminator, settings)

    printing.print_msg('Starting adaptation process.', start='\n\n-- ')

    adaptation_settings = settings['adaptation']
    target_model = adaptation(
        epochs=adaptation_settings['nb_epochs'],
        source_model=source_model,
        target_model=target_model,
        classifier=classifier,
        discriminator=discriminator,
        source_data=source_domain_training_data,
        target_data=target_domain_training_data,
        optimizer_target=optimizer_model_target,
        optimizer_discriminator=optimizer_discriminator,
        device=device,
        labels_loss_w=adaptation_settings['labels_loss_w'],
        first_iter=adaptation_settings['first_iter'],
        n_critic=adaptation_settings['n_critic'])

    modules_functions.save_model_state(
        base_dir,
        settings['models']['asc_model']['target_model_f_name'],
        target_model)

    return target_model.eval()