def semi_train(task_name, sed_model_name, at_model_name, augmentation):
    """ Training with semi-supervised learning (Guiding learning).
    Args:
        task_name: string
            the name of the task
        sed_model_name: string
            the name of the PS-model
        at_model_name: string
            the name of the PT-model
        augmentation: bool
            whether to add Gaussian noise to the input of the PT-model
    Return:
    """
    #prepare for training of the PS-model
    LOG.info('config preparation for %s' % sed_model_name)
    train_sed = trainer.trainer(task_name, sed_model_name, False)

    #prepare for training of the PT-model
    LOG.info('config preparation for %s' % at_model_name)
    train_at = trainer.trainer(task_name, at_model_name, False)

    #connect the outputs of the two models to produce a model for end-to-end learning
    creat_model_at = train_at.model_struct.graph()
    creat_model_sed = train_sed.model_struct.graph()
    LEN = train_sed.data_loader.LEN
    DIM = train_sed.data_loader.DIM
    inputs = Input((LEN, DIM))

    #add Gaussian noise to the input of the PT-model
    if augmentation:
        at_inputs = GaussianNoise(0.15)(inputs)
    else:
        at_inputs = inputs

    at_out = creat_model_at(at_inputs, False)
    sed_out = creat_model_sed(inputs, False)
    out = concatenate([at_out, sed_out], axis=-1)
    models = Model(inputs, out)

    #start training (all intermediate files are saved in the PS-model dir)
    LOG.info('------------start training------------')
    train_sed.train(models)

    #copy the final model from the PS-model dir to the PT-model dir
    shutil.copyfile(train_sed.best_model_path, train_at.best_model_path)

    #predict results on the validation set and the test set (the PT-model)
    LOG.info('------------result of %s------------' % at_model_name)
    train_at.save_at_result()  #audio tagging result

    #predict results on the validation set and the test set (the PS-model)
    LOG.info('------------result of %s------------' % sed_model_name)
    train_sed.save_at_result()  #audio tagging result
    train_sed.save_sed_result()  #event detection result
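#A minimal usage sketch (names are placeholders, not from the repo): assuming
#config dirs for task 'task', PS-model 'PS' and PT-model 'PT' have been
#prepared as trainer.trainer expects, a semi-supervised run with Gaussian-noise
#augmentation would look like:
#
#   semi_train('task', 'PS', 'PT', augmentation=True)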
def on_train_end(self, logs={}):
    """ (overwrite) The end of training. """
    best_epoch = self.best_epoch
    best_f1 = self.best_f1
    #report the best performance of the PS-model
    LOG.info('[ best vali f1 : %f at epoch %d ]' % (best_f1, best_epoch))
def on_train_begin(self, logs={}):
    """ (overwrite) The beginning of training. """
    #check extra required attributes
    self.check_attributes()
    LOG.info('init training...')
    LOG.info('metrics : %s %s' % (self.metric, self.ave))
    opt = self.get_opt(self.learning_rate)
    loss = self.get_loss()
    #compile the model with the specific loss function
    self.model.compile(optimizer=opt, loss=loss)
def test_models(task_name, model_name, model_list_path):
    """ Test with a prepared model dir.
        The format of the model dir and model weights must be
        consistent with the required format.
    Args:
        task_name: string
            the name of the task
        model_name: string
            the name of the model
        model_list_path: string
            the path of the file which keeps a list of paths of model weights
    Return:
    """
    def predict(A):
        #binarize predictions with a threshold of 0.5
        A[A >= 0.5] = 1
        A[A < 0.5] = 0
        return A

    if model_list_path is None:
        test(task_name, model_name)
    else:
        with open(model_list_path) as f:
            model_list = f.readlines()
        model_list = [m.rstrip() for m in model_list]

        if len(model_list) == 1:
            LOG.info('ensemble results (just a single model)')
            test(task_name, model_name, model_list[0])
            return

        at_results = {}
        sed_results = {}
        mode = ['vali', 'test']
        for model_path in model_list:
            LOG.info('decode for model : {}'.format(model_path))
            at_preds, sed_preds = test(task_name, model_name, model_path)
            for m in mode:
                if m not in at_results:
                    at_results[m] = predict(at_preds[m])
                    sed_results[m] = sed_preds[m]
                else:
                    at_results[m] += predict(at_preds[m])
                    sed_results[m] += sed_preds[m]

        for m in mode:
            at = copy.deepcopy(at_results[m])
            #vote for boundary detection
            mask = np.reshape(at, [at.shape[0], 1, at.shape[1]])
            mask[mask == 0] = 1
            sed_results[m] /= mask
            #vote for audio tagging
            at_results[m] = at / len(model_list)
            sed_results[m] = [at_results[m], sed_results[m]]

        LOG.info('ensemble results')
        test(task_name, model_name, None, at_results, sed_results)
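#A toy, self-contained sketch (illustration only; the helper name below is
#hypothetical) of the voting scheme in test_models: clip-level predictions are
#binarized and summed across models, and the summed frame-level predictions are
#divided by the per-class vote counts, so each frame score is averaged only
#over the models that tagged that class as present.
def _ensemble_voting_demo():
    import numpy as np
    n_models = 3
    votes = np.array([[2., 0.]])           #[clips, classes]: 2 of 3 models voted class 0
    sed_sum = np.full([1, 4, 2], 1.6)      #[clips, frames, classes]: summed frame scores
    mask = np.reshape(votes.copy(), [1, 1, 2])
    mask[mask == 0] = 1                    #avoid division by zero for classes with no votes
    sed_avg = sed_sum / mask               #class 0 frame scores averaged over its 2 votes
    at_avg = votes / n_models              #audio tagging averaged over all models
    return at_avg, sed_avg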
def supervised_train(task_name, sed_model_name, augmentation):
    """ Training with only weakly-supervised learning.
    Args:
        task_name: string
            the name of the task
        sed_model_name: string
            the name of the model
        augmentation: bool
            whether to add a Gaussian noise layer
    Return:
    """
    LOG.info('config preparation for %s' % sed_model_name)
    #prepare for training
    train_sed = trainer.trainer(task_name, sed_model_name, False)

    #create the model using the model structure prepared in [train_sed]
    creat_model_sed = train_sed.model_struct.graph()
    LEN = train_sed.data_loader.LEN
    DIM = train_sed.data_loader.DIM
    inputs = Input((LEN, DIM))

    #add a Gaussian noise layer
    if augmentation:
        inputs_t = GaussianNoise(0.15)(inputs)
    else:
        inputs_t = inputs

    outs = creat_model_sed(inputs_t, False)
    #the model used for training
    models = Model(inputs, outs)

    LOG.info('------------start training------------')
    train_sed.train(extra_model=models, train_mode='supervised')

    #predict results on the validation set and the test set
    train_sed.save_at_result()  #audio tagging result
    train_sed.save_sed_result()  #event detection result
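#A minimal usage sketch (names are placeholders): a weakly-supervised-only run
#without the Gaussian noise layer would look like:
#
#   supervised_train('task', 'PS', augmentation=False)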
if __name__ == '__main__':
    LOG.info('Disentangled feature')
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('-n', '--task_name', dest='task_name',
                        help='task name')
    parser.add_argument('-s', '--PS_model_name', dest='PS_model_name',
                        help='the name of the PS-model')
    parser.add_argument('-t', '--PT_model_name', dest='PT_model_name',
                        help='the name of the PT-model')
    parser.add_argument('-md', '--mode', dest='mode',
                        help='train or test')
    parser.add_argument('-g', '--augmentation', dest='augmentation',
                        help='select [true or false] : whether to use augmentation (add Gaussian noise)')
    parser.add_argument('-u', '--semi_supervised', dest='semi_supervised',
                        help='select [true or false] : whether to use unlabeled data')
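#Example invocation (assuming this file is the entry script, here called
#main.py; the script name and argument values are placeholders):
#
#   python main.py -n task -s PS -t PT -md train -g true -u true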
def save_sed(self, mode='test', sed_preds={}, is_add=False):
    """
    Args:
        mode: string in ['vali','test']
            the dataset to predict
        sed_preds: dict
            If there is no prediction for the current dataset in
            sed_preds, the prediction will be generated by the model.
            Otherwise the prediction in sed_preds is considered as
            the prediction of the model.
        is_add: bool
            whether to open the result files in append mode
    Return:
        preds_ori: numpy.array
            prediction (probabilities)
    """
    model_path = self.best_model_path
    result_dir = self.result_dir
    model_name = self.model_name
    data_loader = self.data_loader
    f1_utils = self.utils

    result_path = os.path.join(result_dir, model_name + '_sed.txt')
    detail_sed_path = os.path.join(result_dir, model_name + '_detail_sed.txt')
    #path to save the prediction (formatted string)
    preds_csv_path = os.path.join(result_dir, model_name + '_%s_preds.csv' % mode)

    #get clip-level predictions and frame-level predictions
    preds, frame_preds = self.test(mode, 'sed', sed_preds)
    ori_frame_preds = copy.deepcopy(frame_preds)
    outs = []

    #load the file list and the groundtruths
    if mode == 'vali':
        lst, csv = data_loader.get_vali()
    else:
        lst, csv = data_loader.get_test()

    #prepare the file list and the groundtruths for counting scores
    f1_utils.set_vali_csv(lst, csv)

    #get F1 performance (segment-based and event-based)
    segment_based_metrics, event_based_metrics = f1_utils.get_f1(
        preds, frame_preds, mode='sed')

    seg_event = [segment_based_metrics, event_based_metrics]
    seg_event_str = ['segment_based', 'event_based']
    for i, u in enumerate(seg_event):
        re = u.results_class_wise_average_metrics()
        f1 = re['f_measure']['f_measure']
        er = re['error_rate']['error_rate']
        pre = re['f_measure']['precision']
        recall = re['f_measure']['recall']
        dele = re['error_rate']['deletion_rate']
        ins = re['error_rate']['insertion_rate']
        outs += ['[ result sed %s %s macro f1 : %f, er : %f, pre : %f, recall : %f, deletion : %f, insertion : %f ]'
                 % (mode, seg_event_str[i], f1, er, pre, recall, dele, ins)]

    #show results
    for o in outs:
        LOG.info(o)

    #save results
    self.save_str(result_path, outs, is_add)

    #save class-wise performances into a file
    for u in seg_event:
        self.save_str(detail_sed_path, [u.__str__()], is_add)
        is_add = True

    #copy the prediction csv file from the evaluation dir to the result dir
    shutil.copyfile(f1_utils.preds_path, preds_csv_path)

    preds = np.reshape(preds, [preds.shape[0], 1, preds.shape[1]])
    #return frame-level predictions (probabilities) gated by clip-level predictions
    return ori_frame_preds * preds
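#A toy sketch (illustration only; the helper name is hypothetical) of the value
#returned by save_sed: clip-level probabilities reshaped to [clips, 1, classes]
#are broadcast against frame-level probabilities of shape
#[clips, frames, classes], scaling every frame score of a class by its
#clip-level confidence.
def _clip_gating_demo():
    import numpy as np
    frame_preds = np.full([2, 3, 4], 0.8)  #[clips, frames, classes]
    clip_preds = np.full([2, 4], 0.5)      #[clips, classes]
    clip_preds = np.reshape(clip_preds, [2, 1, 4])
    gated = frame_preds * clip_preds       #broadcasts to [2, 3, 4]
    return gated                           #every entry equals 0.8 * 0.5 = 0.4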
def save_at(self, mode='test', at_preds={}, is_add=False):
    """
    Args:
        mode: string in ['vali','test']
            the dataset to predict
        at_preds: dict
            If there is no prediction for the current dataset in
            at_preds, the prediction will be generated by the model.
            Otherwise the prediction in at_preds is considered as
            the prediction of the model.
        is_add: bool
            whether to open the result files in append mode
    Return:
        preds_ori: numpy.array
            prediction (probabilities)
    """
    result_dir = self.result_dir
    model_name = self.model_name
    data_loader = self.data_loader
    f1_utils = self.utils

    result_path = os.path.join(result_dir, model_name + '_at.txt')
    detail_at_path = os.path.join(result_dir, model_name + '_detail_at.txt')

    #load the file list and the groundtruths
    if mode == 'vali':
        lst, csv = data_loader.get_vali()
    elif mode == 'test':
        lst, csv = data_loader.get_test()

    #prepare the file list and the groundtruths for counting scores
    f1_utils.set_vali_csv(lst, csv)

    #get clip-level predictions and the weak labels
    preds, labels = self.test(mode, 'at', at_preds)
    preds_ori = copy.deepcopy(preds)

    #get F1 performance
    f1, precision, recall, cf1, cpre, crecall = f1_utils.get_f1(preds, labels, mode='at')

    outs = []
    #result string to show and save
    outs += ['[ result audio tagging %s f1 : %f, precision : %f, recall : %f ]'
             % (mode, f1, precision, recall)]

    #show results
    for o in outs:
        LOG.info(o)

    label_lst = data_loader.events
    details = []
    for i in range(len(label_lst)):
        line = '%s\tf1: %f\tpre: %f\trecall: %f' % (
            label_lst[i], cf1[i], cpre[i], crecall[i])
        details += [line]

    #save results
    self.save_str(result_path, outs, is_add)
    self.save_str(detail_at_path, details, is_add)

    #return clip-level predictions (probabilities)
    return preds_ori
def on_epoch_end(self, epoch, logs={}):
    """ (overwrite) The end of a training epoch. """
    best_f1 = self.best_f1
    f1_utils = self.f1_utils
    CLASS = self.CLASS
    train_mode = self.train_mode
    early_stop = self.early_stop

    #get the features of the validation data
    vali_data = self.validation_data
    #get the labels of the validation data
    labels = vali_data[1][:, :CLASS]
    #get audio tagging predictions of the model
    preds = self.model.predict(vali_data[0], batch_size=self.batch_size)

    if train_mode == 'semi':
        #get the predictions of the PT-model
        preds_PT = preds[:, :CLASS]
        #get the predictions of the PS-model
        preds_PS = preds[:, CLASS:]
        #compute the F1 score on the validation set for the PT-model
        pt_f1 = self.get_at(preds_PT, labels)
        #compute the F1 score on the validation set for the PS-model
        ps_f1 = self.get_at(preds_PS, labels)
    else:
        #compute the F1 score on the validation set for the PS-model
        ps_f1 = self.get_at(preds, labels)

    #the final performance depends on the PS-model
    logs['f1_val'] = ps_f1
    is_best = 'not_best'

    #preserve the best model during training
    if logs['f1_val'] >= self.best_f1:
        self.best_f1 = logs['f1_val']
        self.best_epoch = epoch
        self.model.save_weights(self.best_model_path)
        is_best = 'best'
        self.wait = 0
    #count the epochs since the PS-model was last improved
    self.wait += 1

    #training stops early if there is no more improvement
    if self.wait > early_stop:
        self.stopped_epoch = epoch
        self.model.stop_training = True

    if train_mode == 'semi':
        LOG.info('[ epoch %d , sed f1 : %f , at f1 : %f ] %s'
                 % (epoch, logs['f1_val'], pt_f1, is_best))
    else:
        LOG.info('[ epoch %d, f1 : %f ] %s' % (epoch, logs['f1_val'], is_best))

    #the learning rate decays every epoch_of_decay epochs
    if epoch > 0 and epoch % self.epoch_of_decay == 0:
        self.learning_rate *= self.decay_rate
        opt = self.get_opt(self.learning_rate)
        LOG.info('[ epoch %d , learning rate decay to %f ]'
                 % (epoch, self.learning_rate))
        loss = self.get_loss()
        #recompile the model with the decreased learning rate
        self.model.compile(optimizer=opt, loss=loss)
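#A toy sketch (illustration only; the helper name and default values are
#hypothetical) of the step decay applied in on_epoch_end: the learning rate is
#multiplied by decay_rate once every epoch_of_decay epochs, so after epoch e it
#equals learning_rate * decay_rate ** (e // epoch_of_decay).
def _lr_decay_demo(learning_rate=0.001, decay_rate=0.5, epoch_of_decay=10, epochs=30):
    for epoch in range(epochs):
        if epoch > 0 and epoch % epoch_of_decay == 0:
            learning_rate *= decay_rate
    return learning_rate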