Пример #1
0
    def __call__(self, checkpoint_path='./../results/model/FFN_ver2/ffn_ver2_3999.ckpt'):
        """Rebuild the FFN model, restore a checkpoint and print parameter counts.

        Args:
            checkpoint_path: Checkpoint prefix handed to ``saver.restore``; the
                meta graph is read from ``checkpoint_path + '.meta'``. The
                default is the path that used to be hard-coded here, so
                existing ``instance()`` calls behave as before.

        Prints one ``"<name> <count>, "`` entry for each 1-D trainable
        variable and one ``"<name> <shape>=<count>, "`` entry for every other
        rank, followed by the number of trainable variables and the
        comma-grouped total parameter count. Returns ``None``.

        NOTE(review): the model is built with ``self.__model`` *and* a meta
        graph is imported into the same default graph, so
        ``tf.trainable_variables()`` may list both copies -- confirm the
        totals are not double-counted.
        """
        # load all test data (the returned lists are not used in this method)
        provider = DataProvider()
        test_bass_list, test_drums_list, test_other_list, test_vocals_list = provider.load_all_test_data()
        # define model
        tf_mix = tf.placeholder(tf.float32, (1, self.sample_len))  # Batch, Sample
        tf_est_source = self.__model(tf_mix)

        # GPU config
        config = tf.ConfigProto(gpu_options=tf.GPUOptions(
            visible_device_list='0',  # specify GPU number
            allow_growth=True))

        saver = tf.train.import_meta_graph(checkpoint_path + '.meta')
        with tf.Session(config=config) as sess:
            saver.restore(sess, checkpoint_path)

            # Query the collection once; it is used for the loop and the total.
            trainable = tf.trainable_variables()
            total_parameters = 0
            entries = []
            for variable in trainable:
                shape = variable.get_shape()
                # product of all dimensions (1 for a scalar variable)
                variable_parameters = 1
                for dim in shape:
                    variable_parameters *= dim.value
                total_parameters += variable_parameters
                if len(shape) == 1:
                    entries.append("%s %d, " % (variable.name, variable_parameters))
                else:
                    entries.append("%s %s=%d, " % (variable.name, str(shape), variable_parameters))

            # Join once instead of growing a string with += in the loop.
            print("".join(entries))
            print("Total %d variables, %s params" % (len(trainable), "{:,}".format(total_parameters)))
Пример #2
0
        def __call__(self, checkpoint_path='./../results/model/mini-U-Net/mini-U-Net_1999.ckpt'):
            """Separate the test mixtures with a restored model and print BSS-eval metrics.

            Builds ``self.__model`` on a ``(None, self.sample_len)`` placeholder,
            restores the checkpoint, mixes the four test stems of every track
            with ``AudioModule.mixing``, runs the model one packet at a time and
            scores the estimate against the vocals reference with
            ``mir_eval.separation.bss_eval_sources``. The evaluation is posed as
            a two-source problem: (vocals, mix - vocals).

            Args:
                checkpoint_path: Checkpoint prefix handed to ``saver.restore``;
                    the meta graph is read from ``checkpoint_path + '.meta'``.
                    The default is the path that used to be hard-coded here.

            Returns:
                tuple: ``(self.est_audio_list, test_target_list, test_mixed_list)``.

            Side effects:
                Appends to ``self.est_audio_list``, ``self.sdr_list``,
                ``self.sir_list`` and ``self.sar_list``. Prints the estimation
                and evaluation times and the mean/median of the accumulated
                SDR/SIR/SAR lists.

            NOTE(review): ``self.__model`` builds a graph and
            ``import_meta_graph`` imports another into the same default graph;
            confirm that ``saver.restore`` fills the variables that
            ``tf_est_source`` actually reads.
            """
            # load all test data
            provider = DataProvider()
            test_bass_list, test_drums_list, test_other_list, test_vocals_list = provider.load_all_test_data()
            # define model
            tf_mix = tf.placeholder(tf.float32, (None, self.sample_len))  # Batch, Sample
            tf_est_source = self.__model(tf_mix)

            # GPU config
            config = tf.ConfigProto(
                gpu_options=tf.GPUOptions(
                    visible_device_list='0',  # specify GPU number
                    allow_growth=True,
                )
            )

            saver = tf.train.import_meta_graph(checkpoint_path + '.meta')
            with tf.Session(config=config) as sess:
                saver.restore(sess, checkpoint_path)

                test_mixed_list = [
                    AudioModule.mixing(bass, drums, other, vocals)
                    for bass, drums, other, vocals in zip(
                        test_bass_list, test_drums_list, test_other_list, test_vocals_list)
                ]
                test_target_list = test_vocals_list
                tf.keras.backend.set_learning_phase(0)

                # estimate the separated source for every test mixture
                est_start = time.time()
                # Estimates produced by *this* call. self.est_audio_list keeps
                # growing across calls, so zipping it with this call's targets
                # would score stale estimates from an earlier call.
                current_est_list = []
                for mix in test_mixed_list:
                    cutted_mix_array = provider.test_data_split_and_pad(mix, self.sample_len)
                    tmp_est_data_array = np.zeros((len(cutted_mix_array), self.sample_len))
                    for index, mix_packet in enumerate(cutted_mix_array):
                        # feed a single packet as a batch of one
                        est_source = sess.run(tf_est_source, feed_dict={
                            tf_mix: mix_packet.reshape(1, -1)
                        })
                        tmp_est_data_array[index, :] = est_source
                    # concatenate the packets back into one (1, n_samples) signal
                    est_audio = tmp_est_data_array.reshape(1, -1)
                    current_est_list.append(est_audio)
                    self.est_audio_list.append(est_audio)
                est_end = time.time()
                print("excuted time", est_end - est_start)

                evaluate_start = time.time()
                for est, target, mix in zip(current_est_list, test_target_list, test_mixed_list):
                    target = target.reshape(1, -1)
                    mix = mix.reshape(1, -1)
                    length = target.shape[1]
                    # trim the padded estimate and the mix to the reference length
                    est = est[:, :length]
                    mix = mix[:, :length]

                    # row 0: vocals, row 1: residual (mix - vocals)
                    est_array = np.zeros((2, length))
                    est_array[0, :] = est
                    est_array[1, :] = mix - est

                    target_array = np.zeros((2, length))
                    target_array[0, :] = target
                    target_array[1, :] = mix - target

                    sdr, sir, sar, perm = mir_eval.separation.bss_eval_sources(target_array, est_array)
                    self.sdr_list.append(sdr[0])
                    self.sir_list.append(sir[0])
                    self.sar_list.append(sar[0])
                print('sdr mean', np.mean(self.sdr_list))
                print('sir mean', np.mean(self.sir_list))
                print('sar mean', np.mean(self.sar_list))

                print('sdr median', np.median(self.sdr_list))
                print('sir median', np.median(self.sir_list))
                print('sar median', np.median(self.sar_list))

                evaluate_end = time.time()
                print('evaluate time', evaluate_end - evaluate_start)
            return self.est_audio_list, test_target_list, test_mixed_list
        def __call__(self):
                """Train the separation model, validating and plotting once per epoch.

                Loads the training stems through ``DataProvider``, splits each
                stem list into train/validation parts with
                ``split_to_train_valid``, builds the graph with
                ``self.__model`` and runs ``self.epoch_num`` epochs. The bass
                stem is used as both the training and the validation target.

                After each epoch the mean validation loss is appended to
                ``self.valid_loss_list``, which is then plotted. Spectrograms
                and the estimated mask for the first training example are
                also plotted. The session is saved through ``NetSaver``
                whenever ``epoch % 9 == 0``.

                NOTE(review): ``early_stopping`` is created but not used in the
                lines visible here; this method may continue past this chunk,
                so confirm before relying on this summary.
                """
                # load all train data
                provider = DataProvider()
                bass_list, drums_list, other_list, vocals_list = provider.load_all_train_data()
                # split train valid
                train_bass_list,    valid_bass_list = provider.split_to_train_valid(bass_list)
                train_drums_list, valid_drums_list = provider.split_to_train_valid(drums_list)
                train_other_list,   valid_other_list = provider.split_to_train_valid(other_list)
                train_vocals_list,  valid_vocals_list = provider.split_to_train_valid(vocals_list)
                # define model
                tf_lr = tf.placeholder(tf.float32) # learning rate
                tf_mix = tf.placeholder(tf.float32, (None, self.sample_len)) #Batch, Sample
                tf_target = tf.placeholder(tf.float32, (None, self.sample_len)) #Batch,Sample

                # self.__model returns the train op, the loss, and intermediate
                # spectrogram/mask tensors used only for visualization below.
                tf_train_step, tf_loss , tf_target_spec, tf_mag_mix_spec, tf_ori_mix_spec, tf_est_masks, tf_est_spec = self.__model(tf_mix, tf_target, tf_lr)

                # GPU config
                config = tf.ConfigProto(
                        gpu_options=tf.GPUOptions(
                                visible_device_list='0', # specify GPU number
                                allow_growth = True
                        )
                )
                with tf.Session(config = config) as sess:
                        init = tf.global_variables_initializer()
                        sess.run(init)
                        print("Start Training")
                        net_saver = NetSaver(saver_folder_name='UNet_other_sources_bass',  saver_file_name='u_net_bass')
                        # NOTE(review): not referenced in the visible lines -- confirm it is used later.
                        early_stopping = EarlyStopping()
                        for epoch in range(self.epoch_num):
                                sys.stdout.flush()
                                print('epoch:' + str(epoch))
                                start = time.time()

                                # Fresh train/valid example arrays are built from the stem
                                # lists every epoch. NOTE(review): "DataArgument" presumably
                                # means data augmentation -- confirm what it does to the stems.
                                train_data_argument = DataArgument(self.fs, self.sec, self.train_data_num)
                                train_arg_bass_array = train_data_argument(train_bass_list)
                                train_arg_drums_array = train_data_argument(train_drums_list)
                                train_arg_other_array = train_data_argument(train_other_list)
                                train_arg_vocals_array = train_data_argument(train_vocals_list)

                                valid_data_argument = DataArgument(self.fs, self.sec, self.valid_data_num)
                                valid_arg_bass_array = valid_data_argument(valid_bass_list)
                                valid_arg_drums_array = valid_data_argument(valid_drums_list)
                                valid_arg_other_array = valid_data_argument(valid_other_list)
                                valid_arg_vocals_array = valid_data_argument(valid_vocals_list)

                                # int() truncates, so a trailing partial batch is never fed.
                                self.train_iter = int(len(train_arg_bass_array) / self.batch_size)
                                self.valid_iter = int(len(valid_arg_bass_array) / self.batch_size)
                                # mixing
                                train_mixed_array = AudioModule.mixing(
                                                                                    train_arg_bass_array,
                                                                                    train_arg_drums_array,
                                                                                    train_arg_other_array,
                                                                                    train_arg_vocals_array
                                                                            )
                                # the bass stem is the separation target
                                train_target_array = train_arg_bass_array

                                valid_mixed_array = AudioModule.mixing(
                                                                                    valid_arg_bass_array,
                                                                                    valid_arg_drums_array,
                                                                                    valid_arg_other_array,
                                                                                    valid_arg_vocals_array
                                                                            )
                                valid_target_array = valid_arg_bass_array
#
                                # training
                                # Keras learning phase: 1 = training, 0 = inference.
                                tf.keras.backend.set_learning_phase(1)
                                for train_time in range(self.train_iter):
                                    sess.run(tf_train_step, feed_dict = {
                                           tf_mix: train_mixed_array[train_time*self.batch_size:(train_time+1)*self.batch_size, :self.sample_len],
                                           tf_target: train_target_array[train_time*self.batch_size:(train_time+1)*self.batch_size, :self.sample_len],
                                           tf_lr: self.lr_init
                                        }
                                     )

                                # Validation: evaluate only tf_loss per batch (no train op is
                                # run, so the learning rate fed here is just a placeholder value).
                                # If valid_iter is 0 the list stays empty and np.mean yields nan.
                                tmp_valid_loss_list = []
                                tf.keras.backend.set_learning_phase(0)
                                for valid_time in range(self.valid_iter):
                                    valid_loss = sess.run(tf_loss, feed_dict = {
                                               tf_mix: valid_mixed_array[valid_time*self.batch_size:(valid_time+1)*self.batch_size, :self.sample_len],
                                               tf_target: valid_target_array[valid_time*self.batch_size:(valid_time+1)*self.batch_size, :self.sample_len],
                                               tf_lr:  0.
                                            }
                                         )
                                    tmp_valid_loss_list.append(valid_loss)

                                self.valid_loss_list.append(np.mean(tmp_valid_loss_list))

                                # Colour-scale limits for the plot_spec calls below
                                # (presumably dB -- confirm against visualize_spec).
                                vmin = -70
                                vmax = 0
                                # Fetch spectrograms and the mask for the first training example only.
                                target_spec, mag_mix_spec, ori_spec_mix, est_mask, est_spec = sess.run([tf_target_spec, tf_mag_mix_spec , tf_ori_mix_spec, tf_est_masks, tf_est_spec], feed_dict ={
                                    tf_mix: train_mixed_array[0:1, :self.sample_len],
                                    tf_target: train_target_array[0:1, :self.sample_len],
                                    tf_lr: 0.
                                })

                                # Drop the trailing size-1 axis (np.squeeze(axis=-1) raises if it is not 1).
                                est_mask = np.squeeze(est_mask, axis=-1)
                                target_spec = np.squeeze(target_spec, axis=-1)
                                mag_mix_spec = np.squeeze(mag_mix_spec, axis=-1)
                                est_spec = np.squeeze(est_spec, axis=-1)
                                print("original spec mix")
                                visualize_spec.plot_spec(ori_spec_mix[0], self.fs, self.sec, vmax, vmin)
                                print("magnitude spec mix")
                                visualize_spec.plot_log_spec(mag_mix_spec[0], self.fs, self.sec, 10, -10)
                                print("target spec")
                                visualize_spec.plot_spec(target_spec[0], self.fs, self.sec, vmax, vmin)
                                print("est mask")
                                visualize_spec.plot_log_spec(est_mask[0], self.fs, self.sec,  1, 0)
                                print("est spec")
                                visualize_spec.plot_spec(est_spec[0], self.fs, self.sec,  vmax, vmin)

                                visualize_loss.plot_loss(self.valid_loss_list)
                                end = time.time()
                                print(' excute time', end - start)
                                # Saves at epochs 0, 9, 18, ...
                                # NOTE(review): confirm this is the intended checkpoint interval.
                                if epoch%9 ==  0:
                                    net_saver(sess, step=epoch)