Beispiel #1
0
    def predict(self):
        '''
        Runs the model on the validation dataset and reports the results.

        Calls SQ.predict_dim when self.args['dim'] is set and SQ.predict_mos
        otherwise, passing self.model, self.ds_val, the validation batch
        size, the device and the number of workers. Afterwards a 'model'
        column holding self.args['name'] is added to self.ds_val.df; the
        dataframe is written to 'NISQA_results.csv' inside
        self.args['output_dir'] (only if an output directory is configured)
        and printed to stdout.
        '''
        # Wrap at most once: this method may be called repeatedly or after
        # the model has already been wrapped elsewhere, and re-wrapping
        # would nest DataParallel modules.
        if self.args['tr_parallel'] and not isinstance(
                self.model, nn.DataParallel):
            self.model = nn.DataParallel(self.model)

        predict_fn = SQ.predict_dim if self.args['dim'] else SQ.predict_mos

        # NOTE(review): the returned arrays are not used here; the predict
        # functions presumably store the predictions in self.ds_val.df,
        # which is what gets saved and printed below -- confirm in SQ.
        predict_fn(self.model,
                   self.ds_val,
                   self.args['tr_bs_val'],
                   self.dev,
                   num_workers=self.args['tr_num_workers'])

        self.ds_val.df['model'] = self.args['name']

        if self.args['output_dir']:
            results_path = os.path.join(self.args['output_dir'],
                                        'NISQA_results.csv')
            self.ds_val.df.to_csv(results_path, index=False)
        print(self.ds_val.df.to_string(index=False))
Beispiel #2
0
    def _loadDatasetsCSVpredict(self):
        '''
        Builds the prediction dataset from a CSV file.

        Reads self.args['csv_file'] (looked up inside self.args['input_dir'])
        into a dataframe and wraps it in a SQ.SpeechQualityDataset that is
        stored as self.ds_val. No MOS column and no condition dataframe are
        passed, and to_memory is disabled.
        '''
        args = self.args
        input_dir = args['input_dir']
        table = pd.read_csv(os.path.join(input_dir, args['csv_file']))

        # creating Datasets ---------------------------------------------------
        dataset_kwargs = dict(
            df_con=None,
            data_dir=input_dir,
            folder_column=args['csv_deg_dir'],
            filename_column=args['csv_deg'],
            mos_column=None,
            seg_length=args['ms_seg_length'],
            max_length=args['ms_max_segments'],
            to_memory=False,
            to_memory_workers=None,
            seg_hop_length=args['ms_seg_hop_length'],
            transform=None,
            ms_n_fft=args['ms_n_fft'],
            ms_hop_length=args['ms_hop_length'],
            ms_win_length=args['ms_win_length'],
            ms_n_mels=args['ms_n_mels'],
            ms_sr=args['ms_sr'],
            ms_fmax=args['ms_fmax'],
            double_ended=args['double_ended'],
            dim=args['dim'],
            filename_column_ref=args['csv_ref'],
        )
        self.ds_val = SQ.SpeechQualityDataset(table, **dataset_kwargs)
Beispiel #3
0
    def _loadDatasetsFile(self):
        '''
        Builds the prediction dataset for a single file.

        Splits self.args['deg'] into its directory and file name, puts the
        file name into a one-row dataframe with the column 'deg' and wraps
        it in a SQ.SpeechQualityDataset that is stored as self.ds_val.
        '''
        args = self.args
        deg_path = args['deg']
        folder = os.path.dirname(deg_path)
        single_file = pd.DataFrame({'deg': [os.path.basename(deg_path)]})

        # creating Datasets ---------------------------------------------------
        dataset_kwargs = dict(
            df_con=None,
            data_dir=folder,
            folder_column=None,
            filename_column='deg',
            mos_column=None,
            seg_length=args['ms_seg_length'],
            max_length=args['ms_max_segments'],
            to_memory=args['tr_ds_to_memory'],
            to_memory_workers=args['tr_ds_to_memory_workers'],
            seg_hop_length=args['ms_seg_hop_length'],
            transform=None,
            ms_n_fft=args['ms_n_fft'],
            ms_hop_length=args['ms_hop_length'],
            ms_win_length=args['ms_win_length'],
            ms_n_mels=args['ms_n_mels'],
            ms_sr=args['ms_sr'],
            ms_fmax=args['ms_fmax'],
            double_ended=args['double_ended'],
            dim=args['dim'],
            filename_column_ref=None,
        )
        self.ds_val = SQ.SpeechQualityDataset(single_file, **dataset_kwargs)
Beispiel #4
0
    def _loadDatasetsFolder(self):
        '''
        Builds the prediction dataset for all wav files of a folder.

        Collects every '*.wav' file directly inside self.args['deg_dir'],
        stores the file names (without directory) in a dataframe with the
        column 'deg', prints the number of files found and wraps the
        dataframe in a SQ.SpeechQualityDataset stored as self.ds_val.
        '''
        data_dir = self.args['deg_dir']
        # glob returns matches in arbitrary, platform dependent order; sort
        # so the rows (and any results derived from them) are reproducible.
        wav_paths = sorted(glob(os.path.join(data_dir, '*.wav')))
        file_names = [os.path.basename(path) for path in wav_paths]
        df_val = pd.DataFrame(file_names, columns=['deg'])

        print('# files: {}'.format(len(df_val)))

        # creating Datasets ---------------------------------------------------
        self.ds_val = SQ.SpeechQualityDataset(
            df_val,
            df_con=None,
            data_dir=data_dir,
            folder_column=None,
            filename_column='deg',
            mos_column=None,
            seg_length=self.args['ms_seg_length'],
            max_length=self.args['ms_max_segments'],
            to_memory=self.args['tr_ds_to_memory'],
            to_memory_workers=self.args['tr_ds_to_memory_workers'],
            seg_hop_length=self.args['ms_seg_hop_length'],
            transform=None,
            ms_n_fft=self.args['ms_n_fft'],
            ms_hop_length=self.args['ms_hop_length'],
            ms_win_length=self.args['ms_win_length'],
            ms_n_mels=self.args['ms_n_mels'],
            ms_sr=self.args['ms_sr'],
            ms_fmax=self.args['ms_fmax'],
            double_ended=self.args['double_ended'],
            dim=self.args['dim'],
            filename_column_ref=None,
        )
Beispiel #5
0
    def _evaluate_mos(self):
        '''
        Evaluates the predictions on the validation dataset.

        Prints the name of the target MOS column, runs SQ.eval_results on
        self.ds_val with a third-order mapping (printing and plotting
        enabled), stores the outcome in self.db_results and self.r and
        prints a one-line summary of three entries of self.r.
        '''
        target_column = self.args['csv_mos_val']
        print(target_column)

        evaluation = SQ.eval_results(
            self.ds_val.df,
            dcon=self.ds_val.df_con,
            target_mos=target_column,
            mapping='third_order',
            do_print=True,
            do_plot=True)
        self.db_results, self.r = evaluation

        summary = 'r_p {:0.2f} rmse {:0.2f} rmse3s {:0.2f}'
        print(summary.format(
            self.r['r_p_mean_con'],
            self.r['rmse_mean_con'],
            self.r['rmse_star_map_mean_con'],
        ))
Beispiel #6
0
    def _loadModel(self):
        '''
        Creates the Pytorch model described by the input arguments.

        If self.args['pretrained_model'] is set, the checkpoint is loaded
        and its stored arguments replace self.args; only a mode dependent
        set of arguments is carried over from the original input arguments.
        Afterwards self.args['dim'] and self.args['double_ended'] are
        derived from self.args['model'], self.model_args is assembled, the
        model is instantiated as self.model and (for pretrained models) the
        checkpoint weights are loaded into it.

        Raises:
            NotImplementedError: if the mode (only checked when a
                pretrained model is used) or the model name is unknown.
        '''
        # Arguments that keep their input value per mode when the remaining
        # arguments are overwritten by the ones stored in the checkpoint.
        train_keys = (
            'input_dir', 'output_dir', 'csv_file', 'csv_con', 'csv_deg',
            'csv_ref', 'csv_deg_dir', 'csv_db_train', 'csv_db_val',
            'csv_mos_train', 'csv_mos_val', 'pretrained_model',
            'tr_epochs', 'tr_early_stop', 'tr_bs', 'tr_bs_val', 'tr_lr',
            'tr_lr_patience', 'tr_num_workers', 'tr_parallel',
            'tr_bias_anchor_db', 'tr_bias_mapping', 'tr_bias_min_r',
            'tr_bias_min_r_noi', 'tr_bias_min_r_col', 'tr_bias_min_r_dis',
            'tr_bias_min_r_loud',
            'tr_verbose',
            'tr_ds_to_memory', 'tr_ds_to_memory_workers', 'ms_max_segments',
        )
        kept_keys_by_mode = {
            'train': train_keys,
            'predict_file': ('deg', 'mode', 'output_dir',
                             'pretrained_model'),
            'predict_dir': ('deg_dir', 'mode', 'output_dir',
                            'pretrained_model'),
            'predict_csv': ('csv_file', 'mode', 'output_dir',
                            'pretrained_model'),
        }

        # if True overwrite input arguments from pretrained model
        if self.args['pretrained_model']:
            pretrained = self.args['pretrained_model']
            # A path containing ':' is used as given, anything else is
            # joined onto the current working directory.
            if ':' in pretrained:
                model_path = os.path.join(pretrained)
            else:
                model_path = os.path.join(os.getcwd(), pretrained)
            # NOTE(review): torch.load unpickles the file -- only load
            # checkpoints from trusted sources.
            checkpoint = torch.load(model_path, map_location=self.dev)

            cli_args = self.args
            mode = cli_args['mode']
            is_known_mode = (mode == 'train' or mode == 'predict_file'
                             or mode == 'predict_dir'
                             or mode == 'predict_csv')
            if not is_known_mode:
                raise NotImplementedError('Mode not available')

            self.args = checkpoint['args']
            for key in kept_keys_by_mode[mode]:
                self.args[key] = cli_args[key]

            if mode == 'predict_csv':
                self.args['input_dir'] = os.getcwd()
                csv_dir = cli_args['csv_dir']
                self.args['csv_deg_dir'] = '' if csv_dir is None else csv_dir
                self.args['csv_deg'] = cli_args['csv_deg']

            # Batch size and worker count are only taken over when given.
            if mode == 'predict_dir' or mode == 'predict_csv':
                if cli_args['bs']:
                    self.args['tr_bs_val'] = cli_args['bs']
                if cli_args['num_workers']:
                    self.args['tr_num_workers'] = cli_args['num_workers']

        model_name = self.args['model']
        self.args['dim'] = True if model_name == 'NISQA_DIM' else False
        self.args['double_ended'] = (True
                                     if model_name == 'NISQA_DE' else False)

        # Load Model
        model_arg_keys = (
            'ms_seg_length', 'ms_n_mels',
            'cnn_model', 'cnn_c_out_1', 'cnn_c_out_2', 'cnn_c_out_3',
            'cnn_kernel_size', 'cnn_dropout', 'cnn_pool_1', 'cnn_pool_2',
            'cnn_pool_3', 'cnn_fc_out_h',
            'td', 'td_sa_d_model', 'td_sa_nhead', 'td_sa_pool_size',
            'td_sa_pos_enc', 'td_sa_num_layers', 'td_sa_h', 'td_sa_dropout',
            'td_lstm_h', 'td_lstm_num_layers', 'td_lstm_dropout',
            'td_lstm_bidirectional',
            'td_2', 'td_2_sa_d_model', 'td_2_sa_nhead', 'td_2_sa_pool_size',
            'td_2_sa_pos_enc', 'td_2_sa_num_layers', 'td_2_sa_h',
            'td_2_sa_dropout', 'td_2_lstm_h', 'td_2_lstm_num_layers',
            'td_2_lstm_dropout', 'td_2_lstm_bidirectional',
            'pool', 'pool_output_size', 'pool_att_h', 'pool_att_dropout',
        )
        self.model_args = {key: self.args[key] for key in model_arg_keys}

        if self.args['double_ended']:
            double_ended_keys = ('de_align', 'de_align_apply',
                                 'de_align_dim', 'de_fuse_dim', 'de_fuse')
            self.model_args.update(
                {key: self.args[key] for key in double_ended_keys})

        if model_name == 'NISQA':
            self.model = SQ.NISQA(**self.model_args)
        elif model_name == 'NISQA_DIM':
            self.model = SQ.NISQA_DIM(**self.model_args)
        elif model_name == 'NISQA_DE':
            self.model = SQ.NISQA_DE(**self.model_args)
        else:
            raise NotImplementedError('Model not available')

        # Load weights if pretrained model is used ------------------------------------
        if self.args['pretrained_model']:
            # strict=True makes load_state_dict raise on any key mismatch,
            # so the returned key lists are not inspected further.
            _missing_keys, _unexpected_keys = self.model.load_state_dict(
                checkpoint['model_state_dict'], strict=True)
            print('Loaded pretrained model from ' +
                  self.args['pretrained_model'])
Beispiel #7
0
    def _loadDatasetsCSV(self):
        '''
        Builds the training and validation datasets from CSV files.

        Reads self.args['csv_file'] from self.args['input_dir'], checks that
        every database named in 'csv_db_train' and 'csv_db_val' occurs in
        the 'db' column and splits the rows accordingly. If 'csv_con' is
        given, that CSV is read and split the same way. The results are
        wrapped in two SQ.SpeechQualityDataset objects stored as
        self.ds_train and self.ds_val, and their lengths are recorded in
        self.runinfos.

        Raises:
            ValueError: if a requested database is missing in the CSV.
        '''
        args = self.args
        input_dir = args['input_dir']
        train_dbs = args['csv_db_train']
        val_dbs = args['csv_db_val']

        dfile = pd.read_csv(os.path.join(input_dir, args['csv_file']))

        requested_dbs = set(train_dbs + val_dbs)
        available_dbs = dfile.db.unique().tolist()
        if not requested_dbs.issubset(available_dbs):
            raise ValueError('Not all dbs found in csv')

        df_train = dfile[dfile.db.isin(train_dbs)].reset_index()
        df_val = dfile[dfile.db.isin(val_dbs)].reset_index()

        # Optional per-condition table, split by the same database lists.
        if args['csv_con'] is None:
            dcon_train = None
            dcon_val = None
        else:
            dcon = pd.read_csv(os.path.join(input_dir, args['csv_con']))
            dcon_train = dcon[dcon.db.isin(train_dbs)].reset_index()
            dcon_val = dcon[dcon.db.isin(val_dbs)].reset_index()

        print('training size: {}, validation size: {}'.format(
            len(df_train), len(df_val)))

        # creating Datasets ---------------------------------------------------
        # Settings that are identical for the training and validation set.
        shared_kwargs = dict(
            data_dir=input_dir,
            folder_column=args['csv_deg_dir'],
            filename_column=args['csv_deg'],
            seg_length=args['ms_seg_length'],
            max_length=args['ms_max_segments'],
            to_memory=args['tr_ds_to_memory'],
            to_memory_workers=args['tr_ds_to_memory_workers'],
            seg_hop_length=args['ms_seg_hop_length'],
            transform=None,
            ms_n_fft=args['ms_n_fft'],
            ms_hop_length=args['ms_hop_length'],
            ms_win_length=args['ms_win_length'],
            ms_n_mels=args['ms_n_mels'],
            ms_sr=args['ms_sr'],
            ms_fmax=args['ms_fmax'],
            double_ended=args['double_ended'],
            dim=args['dim'],
            filename_column_ref=args['csv_ref'],
        )

        self.ds_train = SQ.SpeechQualityDataset(
            df_train,
            df_con=dcon_train,
            mos_column=args['csv_mos_train'],
            **shared_kwargs)

        self.ds_val = SQ.SpeechQualityDataset(
            df_val,
            df_con=dcon_val,
            mos_column=args['csv_mos_val'],
            **shared_kwargs)

        self.runinfos['ds_train_len'] = len(self.ds_train)
        self.runinfos['ds_val_len'] = len(self.ds_val)
Beispiel #8
0
    def _train_dim(self):
        '''
        Trains the speech quality model with five output columns.

        The model output columns are matched, in this order, with the
        dataframe columns 'mos', 'noi', 'dis', 'col' and 'loud'. Each column
        has its own SQ.biasLoss instance; the batch loss is the weighted
        mean of the five losses, with the 'mos' loss weighted by
        self.args['dim_mos_w'].

        After every epoch the bias losses are updated, the training
        predictions are written to the '<column>_pred' columns of
        self.ds_train.df, training and validation set are evaluated with
        SQ.eval_results, the learning rate scheduler and the early stopper
        are stepped and the results are passed to self._saveResults.
        Training ends after self.args['tr_epochs'] epochs or as soon as the
        early stopper signals a stop.
        '''
        # Target columns, in the column order of the model output.
        dims = ('mos', 'noi', 'dis', 'col', 'loud')
        # self.args key holding min_r of the bias loss for each target.
        min_r_keys = ('tr_bias_min_r', 'tr_bias_min_r_noi',
                      'tr_bias_min_r_dis', 'tr_bias_min_r_col',
                      'tr_bias_min_r_loud')
        # Training metrics that are copied into the epoch results.
        train_metrics = ('r_p_mean_con', 'rmse_mean_con',
                         'rmse_star_map_mean_con')

        # Initialize  -------------------------------------------------------------
        # Wrap at most once so an already wrapped model is not nested.
        if self.args['tr_parallel'] and not isinstance(
                self.model, nn.DataParallel):
            self.model = nn.DataParallel(self.model)
        self.model.to(self.dev)

        # Runname and savepath  ---------------------------------------------------
        self.runname = self._makeRunname()

        # Optimizer  -------------------------------------------------------------
        opt = optim.Adam(self.model.parameters(), lr=self.args['tr_lr'])
        # NOTE(review): 'verbose' is deprecated in recent PyTorch releases
        # -- confirm against the PyTorch version in use.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            opt,
            'min',
            verbose=True,
            threshold=0.003,
            patience=self.args['tr_lr_patience'])
        earlyStp = SQ.earlyStopper_dim(self.args['tr_early_stop'])

        # One bias loss per target column, same order as dims.
        bias_losses = [
            SQ.biasLoss(self.ds_train.df.db,
                        anchor_db=self.args['tr_bias_anchor_db'],
                        mapping=self.args['tr_bias_mapping'],
                        min_r=self.args[min_r_key])
            for min_r_key in min_r_keys
        ]

        # Dataloader    -----------------------------------------------------------
        dl_train = DataLoader(self.ds_train,
                              batch_size=self.args['tr_bs'],
                              shuffle=True,
                              drop_last=False,
                              pin_memory=True,
                              num_workers=self.args['tr_num_workers'])

        # Start training loop   ---------------------------------------------------
        print('--> start training')
        for epoch in range(self.args['tr_epochs']):

            tic_epoch = time.time()
            batch_cnt = 0
            loss = 0.0
            y_train = np.concatenate(
                [self.ds_train.df[dim].to_numpy().reshape(-1, 1)
                 for dim in dims],
                axis=1)
            y_train_hat = np.zeros((len(self.ds_train), len(dims)))

            self.model.train()

            # Progress bar (updated manually once per batch)
            show_progress = self.args['tr_verbose'] == 2
            if show_progress:
                pbar = tqdm(
                    total=len(dl_train),
                    ascii=">—",
                    bar_format=
                    '{bar} {percentage:3.0f}%, {n_fmt}/{total_fmt}, {elapsed}<{remaining}{postfix}'
                )

            for xb_spec, yb_mos, (idx, n_wins) in dl_train:

                # Estimate batch ---------------------------------------------------
                xb_spec = xb_spec.to(self.dev)
                yb_mos = yb_mos.to(self.dev)
                n_wins = n_wins.to(self.dev)

                # Forward pass ----------------------------------------------------
                yb_mos_hat = self.model(xb_spec, n_wins)
                y_train_hat[idx, :] = yb_mos_hat.detach().cpu().numpy()

                # Loss ------------------------------------------------------------
                lossb_dims = [
                    bias_loss.get_loss(yb_mos[:, col].view(-1, 1),
                                       yb_mos_hat[:, col].view(-1, 1), idx)
                    for col, bias_loss in enumerate(bias_losses)
                ]

                mos_weight = self.args['dim_mos_w']
                lossb = 1 / (4 + mos_weight) * (
                    mos_weight * lossb_dims[0] +
                    (lossb_dims[1] + lossb_dims[2] + lossb_dims[3] +
                     lossb_dims[4]))

                # Backprop  -------------------------------------------------------
                lossb.backward()
                opt.step()
                opt.zero_grad()

                # Update total loss -----------------------------------------------
                loss += lossb.item()
                batch_cnt += 1

                if show_progress:
                    pbar.set_postfix(loss=lossb.item())
                    pbar.update()

            if show_progress:
                pbar.close()

            loss = loss / batch_cnt

            for col, bias_loss in enumerate(bias_losses):
                bias_loss.update_bias(y_train[:, col].reshape(-1, 1),
                                      y_train_hat[:, col].reshape(-1, 1))

            # Evaluate   -----------------------------------------------------------
            # Store the predictions under the '<column>_pred' names that the
            # evaluation below reads via its 'pred' argument.
            for col, dim in enumerate(dims):
                self.ds_train.df[dim + '_pred'] = y_train_hat[:, col].reshape(
                    -1, 1)

            r_train = {}
            for dim in dims:
                print('--> {}:'.format(dim.upper()))
                _, r_train[dim] = SQ.eval_results(
                    self.ds_train.df,
                    dcon=self.ds_train.df_con,
                    target_mos=dim,
                    target_ci=dim + '_ci',
                    pred=dim + '_pred',
                    mapping='first_order',
                    do_print=True)

            SQ.predict_dim(self.model,
                           self.ds_val,
                           self.args['tr_bs_val'],
                           self.dev,
                           num_workers=self.args['tr_num_workers'])

            db_results_val = {}
            r_val = {}
            for dim in dims:
                print('--> {}:'.format(dim.upper()))
                db_results_val[dim], r_val[dim] = SQ.eval_results(
                    self.ds_val.df,
                    dcon=self.ds_val.df_con,
                    target_mos=dim,
                    target_ci=dim + '_ci',
                    pred=dim + '_pred',
                    mapping='first_order',
                    do_print=True)

            # Collect all metrics of this epoch. Keys of the 'mos' column
            # carry no suffix, all other columns get '_<column>' appended.
            r = {}
            for dim in dims:
                suffix = '' if dim == 'mos' else '_' + dim
                for metric in train_metrics:
                    r['train_' + metric + suffix] = r_train[dim][metric]
            for dim in dims:
                suffix = '' if dim == 'mos' else '_' + dim
                r.update({k + suffix: v for k, v in r_val[dim].items()})

            db_results = {
                'db_results_val_' + dim: db_results_val[dim]
                for dim in dims
            }

            # Scheduler update    ---------------------------------------------
            scheduler.step(loss)
            earl_stp = earlyStp.step(r)

            # Print    --------------------------------------------------------
            ep_runtime = time.time() - tic_epoch
            if self.args['tr_verbose'] > 0:
                print(
                    'ep {} sec {:0.0f} es {} lr {:0.0e} loss {:0.4f} // '
                    'r_p_tr {:0.2f} rmse_tr {:0.2f} rmse3s_tr {:0.2f} // r_p {:0.2f} rmse {:0.2f} rmse3s {:0.2f}  // '
                    'best_r_p {:0.2f} best_rmse {:0.2f},'.format(
                        epoch + 1, ep_runtime, earlyStp.cnt, SQ.get_lr(opt),
                        loss, r['train_r_p_mean_con'],
                        r['train_rmse_mean_con'],
                        r['train_rmse_star_map_mean_con'], r['r_p_mean_con'],
                        r['rmse_mean_con'], r['rmse_star_map_mean_con'],
                        earlyStp.best_r_p, earlyStp.best_rmse))

                r_mean = 1 / 5 * (r['r_p_mean_con'] + r['r_p_mean_con_noi'] +
                                  r['r_p_mean_con_col'] + r['r_p_mean_con_dis']
                                  + r['r_p_mean_con_loud'])

                print('\nAverage r_p {:0.3f}'.format(r_mean))

            # Save results and model  -----------------------------------------
            self._saveResults(self.model, self.model_args, opt, epoch, loss,
                              ep_runtime, r, db_results)

            # Early stopping    -----------------------------------------------
            if earl_stp:
                print('--> Early stopping. best_r_p {:0.2f} best_rmse {:0.2f}'.
                      format(earlyStp.best_r_p, earlyStp.best_rmse))
                return

        # Training done --------------------------------------------------------
        print('--> Training done. best_r_p {:0.2f} best_rmse {:0.2f}'.format(
            earlyStp.best_r_p, earlyStp.best_rmse))
        return
Beispiel #9
0
    def _evaluate_dim(self):
        '''
        Evaluates the validation predictions for MOS and every dimension.

        For each target column (`mos`, `noi`, `dis`, `col`, `loud`) this calls
        `SQ.eval_results` on `self.ds_val.df`, comparing `<target>` with
        `<target>_pred` (confidence column `<target>_ci`) using the
        'first_order' mapping. The first return value is stored on
        `self.db_results_val_<target>`; three entries of the second return
        value (a metrics dict) are printed as a one-line summary.

        All metric dicts are merged into `self.r`. MOS metrics keep the key
        names returned by `SQ.eval_results`; the metrics of every other target
        get the suffix `_<target>`. Finally the mean of the five
        `r_p_mean_con` values is printed.
        '''
        summary_fmt = 'r_p {:0.2f} rmse {:0.2f} rmse3s {:0.2f}'

        # (console header, target column) in evaluation order. The order also
        # fixes the merge order of the metric dicts below.
        targets = (
            ('--> MOS:', 'mos'),
            ('--> NOI:', 'noi'),
            ('--> DIS:', 'dis'),
            ('--> COL:', 'col'),
            ('--> LOUD:', 'loud'),
        )

        merged = {}
        for header, target in targets:
            print(header)
            db_results, metrics = SQ.eval_results(
                self.ds_val.df,
                dcon=self.ds_val.df_con,
                target_mos=target,
                target_ci=target + '_ci',
                pred=target + '_pred',
                mapping='first_order',
                do_print=True,
                do_plot=False)
            setattr(self, 'db_results_val_' + target, db_results)

            print(
                summary_fmt.format(metrics['r_p_mean_con'],
                                   metrics['rmse_mean_con'],
                                   metrics['rmse_star_map_mean_con']))

            # MOS metrics stay unsuffixed (the average below reads the plain
            # 'r_p_mean_con' key for MOS); all other targets are namespaced.
            if target != 'mos':
                metrics = {
                    key + '_' + target: value
                    for key, value in metrics.items()
                }
            merged.update(metrics)

        self.r = merged

        # Explicit left-to-right sum: keeps the floating-point result
        # independent of any summation helper.
        r_mean = 1 / 5 * (merged['r_p_mean_con'] + merged['r_p_mean_con_noi'] +
                          merged['r_p_mean_con_col'] +
                          merged['r_p_mean_con_dis'] +
                          merged['r_p_mean_con_loud'])

        print('\nAverage r_p {:0.3f}'.format(r_mean))
Beispiel #10
0
    def _train_mos(self):
        '''
        Trains the speech quality model on a single MOS target.

        Reads its configuration from `self.args` (keys prefixed `tr_` plus
        `csv_mos_train` / `csv_mos_val`). For every epoch it

        1. runs one pass over `self.ds_train`, optimising the model with Adam
           on the loss returned by `SQ.biasLoss.get_loss`,
        2. updates the bias estimate with the epoch's training predictions,
        3. evaluates training and validation predictions with
           `SQ.eval_results`,
        4. steps the `ReduceLROnPlateau` scheduler on the mean training loss
           and the early stopper on the merged metrics dict,
        5. hands everything to `self._saveResults`.

        Training ends after `tr_epochs` epochs or as soon as the early stopper
        returns a truthy value. On both exits `self._logMetricsFinal` is
        called with `self.results_hist`.

        Side effects visible here: `self.model` may be wrapped in
        `nn.DataParallel` and is moved to `self.dev`; `self.runname` is set;
        the column 'mos_pred' of `self.ds_train.df` is overwritten each epoch.
        NOTE(review): `SQ.predict_mos` presumably writes 'mos_pred' into
        `self.ds_val.df` (its return value is ignored here and that column is
        read right after) - confirm against its implementation.

        Returns:
            None
        '''
        # Initialize  -------------------------------------------------------------
        if self.args['tr_parallel']:
            self.model = nn.DataParallel(self.model)
        self.model.to(self.dev)

        # Runname and savepath  ---------------------------------------------------
        self.runname = self._makeRunname()

        # Optimizer  -------------------------------------------------------------
        opt = optim.Adam(self.model.parameters(), lr=self.args['tr_lr'])
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            opt,
            'min',
            verbose=True,
            threshold=0.003,
            patience=self.args['tr_lr_patience'])
        earlyStp = SQ.earlyStopper(self.args['tr_early_stop'])

        biasLoss = SQ.biasLoss(self.ds_train.df.db,
                               anchor_db=self.args['tr_bias_anchor_db'],
                               mapping=self.args['tr_bias_mapping'],
                               min_r=self.args['tr_bias_min_r'])

        # Dataloader    -----------------------------------------------------------
        dl_train = DataLoader(self.ds_train,
                              batch_size=self.args['tr_bs'],
                              shuffle=True,
                              drop_last=False,
                              pin_memory=True,
                              num_workers=self.args['tr_num_workers'])

        # Start training loop   ---------------------------------------------------
        print('--> start training')
        for epoch in range(self.args['tr_epochs']):

            tic_epoch = time.time()
            batch_cnt = 0
            loss = 0.0
            y_train = self.ds_train.df[
                self.args['csv_mos_train']].to_numpy().reshape(-1)
            # Filled batch by batch at the dataset indices yielded by the
            # loader, so shuffling does not disturb the row order.
            y_train_hat = np.zeros((len(self.ds_train), 1))
            self.model.train()

            # Progress bar (advanced manually with pbar.update() per batch,
            # hence only `total` is given and no iterable).
            if self.args['tr_verbose'] == 2:
                pbar = tqdm(
                    total=len(dl_train),
                    ascii=">—",
                    bar_format=
                    '{bar} {percentage:3.0f}%, {n_fmt}/{total_fmt}, {elapsed}<{remaining}{postfix}'
                )

            for xb_spec, yb_mos, (idx, n_wins) in dl_train:

                # Estimate batch ---------------------------------------------------
                xb_spec = xb_spec.to(self.dev)
                yb_mos = yb_mos.to(self.dev)
                n_wins = n_wins.to(self.dev)

                # Forward pass ----------------------------------------------------
                yb_mos_hat = self.model(xb_spec, n_wins)
                y_train_hat[idx] = yb_mos_hat.detach().cpu().numpy()

                # Loss ------------------------------------------------------------
                lossb = biasLoss.get_loss(yb_mos, yb_mos_hat, idx)

                # Backprop  -------------------------------------------------------
                lossb.backward()
                opt.step()
                opt.zero_grad()

                # Update total loss -----------------------------------------------
                loss += lossb.item()
                batch_cnt += 1

                if self.args['tr_verbose'] == 2:
                    pbar.set_postfix(loss=lossb.item())
                    pbar.update()

            if self.args['tr_verbose'] == 2:
                pbar.close()

            # Mean batch loss of this epoch; also drives the LR scheduler.
            loss = loss / batch_cnt

            biasLoss.update_bias(y_train, y_train_hat)

            # Evaluate   -----------------------------------------------------------
            self.ds_train.df['mos_pred'] = y_train_hat
            db_results_train, r_train = SQ.eval_results(
                self.ds_train.df,
                dcon=self.ds_train.df_con,
                target_mos=self.args['csv_mos_train'],
                target_ci=self.args['csv_mos_train'] + '_ci',
                pred='mos_pred',
                mapping='first_order',
                do_print=True)

            SQ.predict_mos(self.model,
                           self.ds_val,
                           self.args['tr_bs_val'],
                           self.dev,
                           num_workers=self.args['tr_num_workers'])
            db_results, r_val = SQ.eval_results(
                self.ds_val.df,
                dcon=self.ds_val.df_con,
                target_mos=self.args['csv_mos_val'],
                target_ci=self.args['csv_mos_val'] + '_ci',
                pred='mos_pred',
                mapping='first_order',
                do_print=True)

            # Validation metrics keep their key names; the three training
            # metrics reported below are added with a 'train_' prefix.
            r = {
                'train_r_p_mean_con': r_train['r_p_mean_con'],
                'train_rmse_mean_con': r_train['rmse_mean_con'],
                'train_rmse_star_map_mean_con':
                r_train['rmse_star_map_mean_con'],
                **r_val
            }

            # Scheduler update    ---------------------------------------------
            scheduler.step(loss)
            earl_stp = earlyStp.step(r)

            # Print    --------------------------------------------------------
            ep_runtime = time.time() - tic_epoch
            if self.args['tr_verbose'] > 0:
                print(
                    'ep {} sec {:0.0f} es {} lr {:0.0e} loss {:0.4f} // '
                    'r_p_tr {:0.2f} rmse_tr {:0.2f} rmse3s_tr {:0.2f} // r_p {:0.2f} rmse {:0.2f} rmse3s {:0.2f}  // '
                    'best_r_p {:0.2f} best_rmse {:0.2f},'.format(
                        epoch + 1, ep_runtime, earlyStp.cnt, SQ.get_lr(opt),
                        loss, r['train_r_p_mean_con'],
                        r['train_rmse_mean_con'],
                        r['train_rmse_star_map_mean_con'], r['r_p_mean_con'],
                        r['rmse_mean_con'], r['rmse_star_map_mean_con'],
                        earlyStp.best_r_p, earlyStp.best_rmse))

            # Save results and model  -----------------------------------------
            self._saveResults(self.model, self.model_args, opt, epoch, loss,
                              ep_runtime, r, db_results)

            # Early stopping    -----------------------------------------------
            if earl_stp:
                # Log the final metrics on this exit as well; otherwise a run
                # that ends by early stopping would skip them.
                self._logMetricsFinal(self.results_hist)
                print('--> Early stopping. best_r_p {:0.2f} best_rmse {:0.2f}'.
                      format(earlyStp.best_r_p, earlyStp.best_rmse))
                return

        # Training done --------------------------------------------------------
        self._logMetricsFinal(self.results_hist)
        print('--> Training done. best_r_p {:0.2f} best_rmse {:0.2f}'.format(
            earlyStp.best_r_p, earlyStp.best_rmse))
        return