def main(): #load test arguments opt = TestOptions().parse() opt.device = torch.device("cuda") # Network Builders builder = ModelBuilder() net_visual = builder.build_visual(pool_type=opt.visual_pool, weights=opt.weights_visual) net_unet = builder.build_unet(unet_num_layers=opt.unet_num_layers, ngf=opt.unet_ngf, input_nc=opt.unet_input_nc, output_nc=opt.unet_output_nc, weights=opt.weights_unet) if opt.with_additional_scene_image: opt.number_of_classes = opt.number_of_classes + 1 net_classifier = builder.build_classifier( pool_type=opt.classifier_pool, num_of_classes=opt.number_of_classes, input_channel=opt.unet_output_nc, weights=opt.weights_classifier) nets = (net_visual, net_unet, net_classifier) # construct our audio-visual model model = AudioVisualModel(nets, opt) model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids) model.to(opt.device) model.eval() #load the two audios audio1_path = os.path.join(opt.data_path, 'audio_11025', opt.video1_name + '.wav') audio1, _ = librosa.load(audio1_path, sr=opt.audio_sampling_rate) audio2_path = os.path.join(opt.data_path, 'audio_11025', opt.video2_name + '.wav') audio2, _ = librosa.load(audio2_path, sr=opt.audio_sampling_rate) #make sure the two audios are of the same length and then mix them audio_length = min(len(audio1), len(audio2)) audio1 = clip_audio(audio1[:audio_length]) audio2 = clip_audio(audio2[:audio_length]) audio_mix = (audio1 + audio2) / 2.0 #define the transformation to perform on visual frames vision_transform_list = [ transforms.Resize((224, 224)), transforms.ToTensor() ] if opt.subtract_mean: vision_transform_list.append( transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])) vision_transform = transforms.Compose(vision_transform_list) #load the object regions of the highest confidence score for both videos detectionResult1 = np.load( os.path.join(opt.data_path, 'detection_results', opt.video1_name + '.npy')) detectionResult2 = np.load( os.path.join(opt.data_path, 'detection_results', opt.video2_name + '.npy')) avged_sep_audio1 = np.zeros((audio_length)) avged_sep_audio2 = np.zeros((audio_length)) for i in range(opt.num_of_object_detections_to_use): det_box1 = detectionResult1[np.argmax( detectionResult1[:, 2] ), :] #get the box of the highest confidence score det_box2 = detectionResult2[np.argmax( detectionResult2[:, 2] ), :] #get the box of the highest confidence score detectionResult1[np.argmax(detectionResult1[:, 2]), 2] = 0 # set to 0 after using it detectionResult2[np.argmax(detectionResult2[:, 2]), 2] = 0 # set to 0 after using it frame_path1 = os.path.join(opt.data_path, 'frame', opt.video1_name, "%06d.png" % det_box1[0]) frame_path2 = os.path.join(opt.data_path, 'frame', opt.video2_name, "%06d.png" % det_box2[0]) detection1 = Image.open(frame_path1).convert('RGB').crop( (det_box1[-4], det_box1[-3], det_box1[-2], det_box1[-1])) detection2 = Image.open(frame_path2).convert('RGB').crop( (det_box2[-4], det_box2[-3], det_box2[-2], det_box2[-1])) #perform separation over the whole audio using a sliding window approach overlap_count = np.zeros((audio_length)) sep_audio1 = np.zeros((audio_length)) sep_audio2 = np.zeros((audio_length)) sliding_window_start = 0 data = {} samples_per_window = opt.audio_window while sliding_window_start + samples_per_window < audio_length: sliding_window_end = sliding_window_start + samples_per_window audio_segment = audio_mix[sliding_window_start:sliding_window_end] audio_mix_mags, audio_mix_phases = generate_spectrogram_magphase( audio_segment, opt.stft_frame, 
opt.stft_hop) data['audio_mix_mags'] = torch.FloatTensor( audio_mix_mags).unsqueeze(0) data['audio_mix_phases'] = torch.FloatTensor( audio_mix_phases).unsqueeze(0) data['real_audio_mags'] = data[ 'audio_mix_mags'] #dont' care for testing data['audio_mags'] = data[ 'audio_mix_mags'] #dont' care for testing #separate for video 1 data['visuals'] = vision_transform(detection1).unsqueeze(0) data['labels'] = torch.FloatTensor(np.ones( (1, 1))) #don't care for testing data['vids'] = torch.FloatTensor(np.ones( (1, 1))) #don't care for testing outputs = model.forward(data) reconstructed_signal = get_separated_audio(outputs, data, opt) sep_audio1[sliding_window_start:sliding_window_end] = sep_audio1[ sliding_window_start:sliding_window_end] + reconstructed_signal #separate for video 2 data['visuals'] = vision_transform(detection2).unsqueeze(0) #data['label'] = torch.LongTensor([0]) #don't care for testing outputs = model.forward(data) reconstructed_signal = get_separated_audio(outputs, data, opt) sep_audio2[sliding_window_start:sliding_window_end] = sep_audio2[ sliding_window_start:sliding_window_end] + reconstructed_signal #update overlap count overlap_count[ sliding_window_start:sliding_window_end] = overlap_count[ sliding_window_start:sliding_window_end] + 1 sliding_window_start = sliding_window_start + int( opt.hop_size * opt.audio_sampling_rate) #deal with the last segment audio_segment = audio_mix[-samples_per_window:] audio_mix_mags, audio_mix_phases = generate_spectrogram_magphase( audio_segment, opt.stft_frame, opt.stft_hop) data['audio_mix_mags'] = torch.FloatTensor(audio_mix_mags).unsqueeze(0) data['audio_mix_phases'] = torch.FloatTensor( audio_mix_phases).unsqueeze(0) data['real_audio_mags'] = data[ 'audio_mix_mags'] #dont' care for testing data['audio_mags'] = data['audio_mix_mags'] #dont' care for testing #separate for video 1 data['visuals'] = vision_transform(detection1).unsqueeze(0) data['labels'] = torch.FloatTensor(np.ones( (1, 1))) #don't care for testing data['vids'] = torch.FloatTensor(np.ones( (1, 1))) #don't care for testing outputs = model.forward(data) reconstructed_signal = get_separated_audio(outputs, data, opt) sep_audio1[-samples_per_window:] = sep_audio1[ -samples_per_window:] + reconstructed_signal #separate for video 2 data['visuals'] = vision_transform(detection2).unsqueeze(0) outputs = model.forward(data) reconstructed_signal = get_separated_audio(outputs, data, opt) sep_audio2[-samples_per_window:] = sep_audio2[ -samples_per_window:] + reconstructed_signal #update overlap count overlap_count[ -samples_per_window:] = overlap_count[-samples_per_window:] + 1 #divide the aggregated predicted audio by the overlap count avged_sep_audio1 = avged_sep_audio1 + clip_audio( np.divide(sep_audio1, overlap_count) * 2) avged_sep_audio2 = avged_sep_audio2 + clip_audio( np.divide(sep_audio2, overlap_count) * 2) separation1 = avged_sep_audio1 / opt.num_of_object_detections_to_use separation2 = avged_sep_audio2 / opt.num_of_object_detections_to_use #output original and separated audios output_dir = os.path.join(opt.output_dir_root, opt.video1_name + 'VS' + opt.video2_name) if not os.path.isdir(output_dir): os.mkdir(output_dir) librosa.output.write_wav(os.path.join(output_dir, 'audio1.wav'), audio1, opt.audio_sampling_rate) librosa.output.write_wav(os.path.join(output_dir, 'audio2.wav'), audio2, opt.audio_sampling_rate) librosa.output.write_wav(os.path.join(output_dir, 'audio_mixed.wav'), audio_mix, opt.audio_sampling_rate) librosa.output.write_wav(os.path.join(output_dir, 
'audio1_separated.wav'), separation1, opt.audio_sampling_rate) librosa.output.write_wav(os.path.join(output_dir, 'audio2_separated.wav'), separation2, opt.audio_sampling_rate) #save the two detections detection1.save(os.path.join(output_dir, 'audio1.png')) detection2.save(os.path.join(output_dir, 'audio2.png')) #save the spectrograms & masks if opt.visualize_spectrogram: import matplotlib.pyplot as plt plt.switch_backend('agg') plt.ioff() audio1_mag = generate_spectrogram_magphase(audio1, opt.stft_frame, opt.stft_hop, with_phase=False) audio2_mag = generate_spectrogram_magphase(audio2, opt.stft_frame, opt.stft_hop, with_phase=False) audio_mix_mag = generate_spectrogram_magphase(audio_mix, opt.stft_frame, opt.stft_hop, with_phase=False) separation1_mag = generate_spectrogram_magphase(separation1, opt.stft_frame, opt.stft_hop, with_phase=False) separation2_mag = generate_spectrogram_magphase(separation2, opt.stft_frame, opt.stft_hop, with_phase=False) utils.visualizeSpectrogram(audio1_mag[0, :, :], os.path.join(output_dir, 'audio1_spec.png')) utils.visualizeSpectrogram(audio2_mag[0, :, :], os.path.join(output_dir, 'audio2_spec.png')) utils.visualizeSpectrogram( audio_mix_mag[0, :, :], os.path.join(output_dir, 'audio_mixed_spec.png')) utils.visualizeSpectrogram( separation1_mag[0, :, :], os.path.join(output_dir, 'separation1_spec.png')) utils.visualizeSpectrogram( separation2_mag[0, :, :], os.path.join(output_dir, 'separation2_spec.png'))
def main(): #load test arguments opt = TestOptions().parse() opt.device = torch.device("cuda") # Network Builders builder = ModelBuilder() net_lipreading = builder.build_lipreadingnet( config_path=opt.lipreading_config_path, weights=opt.weights_lipreadingnet, extract_feats=opt.lipreading_extract_feature) #if identity feature dim is not 512, for resnet reduce dimension to this feature dim if opt.identity_feature_dim != 512: opt.with_fc = True else: opt.with_fc = False net_facial_attributes = builder.build_facial( pool_type=opt.visual_pool, fc_out=opt.identity_feature_dim, with_fc=opt.with_fc, weights=opt.weights_facial) net_unet = builder.build_unet( ngf=opt.unet_ngf, input_nc=opt.unet_input_nc, output_nc=opt.unet_output_nc, audioVisual_feature_dim=opt.audioVisual_feature_dim, identity_feature_dim=opt.identity_feature_dim, weights=opt.weights_unet) net_vocal_attributes = builder.build_vocal(pool_type=opt.audio_pool, input_channel=2, with_fc=opt.with_fc, fc_out=opt.identity_feature_dim, weights=opt.weights_vocal) nets = (net_lipreading, net_facial_attributes, net_unet, net_vocal_attributes) print(nets) # construct our audio-visual model model = AudioVisualModel(nets, opt) model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids) model.to(opt.device) model.eval() mtcnn = MTCNN(keep_all=True, device=opt.device) lipreading_preprocessing_func = get_preprocessing_pipelines()['test'] normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) vision_transform_list = [transforms.ToTensor()] if opt.normalization: vision_transform_list.append(normalize) vision_transform = transforms.Compose(vision_transform_list) # load data mouthroi_1 = load_mouthroi(opt.mouthroi1_path) mouthroi_2 = load_mouthroi(opt.mouthroi2_path) _, audio1 = wavfile.read(opt.audio1_path) _, audio2 = wavfile.read(opt.audio2_path) _, audio_offscreen = wavfile.read(opt.offscreen_audio_path) audio1 = audio1 / 32768 audio2 = audio2 / 32768 audio_offscreen = audio_offscreen / 32768 #make sure the two audios are of the same length and then mix them audio_length = min(min(len(audio1), len(audio2)), len(audio_offscreen)) audio1 = clip_audio(audio1[:audio_length]) audio2 = clip_audio(audio2[:audio_length]) audio_offscreen = clip_audio(audio_offscreen[:audio_length]) audio_mix = (audio1 + audio2 + audio_offscreen * opt.noise_weight) / 3 if opt.reliable_face: best_score_1 = 0 best_score_2 = 0 for i in range(10): frame_1 = load_frame(opt.video1_path) frame_2 = load_frame(opt.video2_path) boxes, scores = mtcnn.detect(frame_1) if scores[0] > best_score_1: best_frame_1 = frame_1 boxes, scores = mtcnn.detect(frame_2) if scores[0] > best_score_2: best_frame_2 = frame_2 frames_1 = vision_transform(best_frame_1).squeeze().unsqueeze(0) frames_2 = vision_transform(best_frame_2).squeeze().unsqueeze(0) else: frame_1_list = [] frame_2_list = [] for i in range(opt.number_of_identity_frames): frame_1 = load_frame(opt.video1_path) frame_2 = load_frame(opt.video2_path) frame_1 = vision_transform(frame_1) frame_2 = vision_transform(frame_2) frame_1_list.append(frame_1) frame_2_list.append(frame_2) frames_1 = torch.stack(frame_1_list).squeeze().unsqueeze(0) frames_2 = torch.stack(frame_2_list).squeeze().unsqueeze(0) #perform separation over the whole audio using a sliding window approach overlap_count = np.zeros((audio_length)) sep_audio1 = np.zeros((audio_length)) sep_audio2 = np.zeros((audio_length)) sliding_window_start = 0 data = {} avged_sep_audio1 = np.zeros((audio_length)) avged_sep_audio2 = np.zeros((audio_length)) 
samples_per_window = int(opt.audio_length * opt.audio_sampling_rate) while sliding_window_start + samples_per_window < audio_length: sliding_window_end = sliding_window_start + samples_per_window #get audio spectrogram segment1_audio = audio1[sliding_window_start:sliding_window_end] segment2_audio = audio2[sliding_window_start:sliding_window_end] segment_offscreen = audio_offscreen[ sliding_window_start:sliding_window_end] if opt.audio_normalization: normalizer1, segment1_audio = audio_normalize(segment1_audio) normalizer2, segment2_audio = audio_normalize(segment2_audio) _, segment_offscreen = audio_normalize(segment_offscreen) else: normalizer1 = 1 normalizer2 = 1 audio_segment = (segment1_audio + segment2_audio + segment_offscreen * opt.noise_weight) / 3 audio_mix_spec = generate_spectrogram_complex(audio_segment, opt.window_size, opt.hop_size, opt.n_fft) audio_spec_1 = generate_spectrogram_complex(segment1_audio, opt.window_size, opt.hop_size, opt.n_fft) audio_spec_2 = generate_spectrogram_complex(segment2_audio, opt.window_size, opt.hop_size, opt.n_fft) #get mouthroi frame_index_start = int( round(sliding_window_start / opt.audio_sampling_rate * 25)) segment1_mouthroi = mouthroi_1[frame_index_start:( frame_index_start + opt.num_frames), :, :] segment2_mouthroi = mouthroi_2[frame_index_start:( frame_index_start + opt.num_frames), :, :] #transform mouthrois segment1_mouthroi = lipreading_preprocessing_func(segment1_mouthroi) segment2_mouthroi = lipreading_preprocessing_func(segment2_mouthroi) data['audio_spec_mix1'] = torch.FloatTensor(audio_mix_spec).unsqueeze( 0) data['mouthroi_A1'] = torch.FloatTensor(segment1_mouthroi).unsqueeze( 0).unsqueeze(0) data['mouthroi_B'] = torch.FloatTensor(segment2_mouthroi).unsqueeze( 0).unsqueeze(0) data['audio_spec_A1'] = torch.FloatTensor(audio_spec_1).unsqueeze(0) data['audio_spec_B'] = torch.FloatTensor(audio_spec_2).unsqueeze(0) data['frame_A'] = frames_1 data['frame_B'] = frames_2 #don't care below data['frame_A'] = frames_1 data['mouthroi_A2'] = torch.FloatTensor(segment1_mouthroi).unsqueeze( 0).unsqueeze(0) data['audio_spec_A2'] = torch.FloatTensor(audio_spec_1).unsqueeze(0) data['audio_spec_mix2'] = torch.FloatTensor(audio_mix_spec).unsqueeze( 0) outputs = model.forward(data) reconstructed_signal_1, reconstructed_signal_2 = get_separated_audio( outputs, data, opt) reconstructed_signal_1 = reconstructed_signal_1 * normalizer1 reconstructed_signal_2 = reconstructed_signal_2 * normalizer2 sep_audio1[sliding_window_start:sliding_window_end] = sep_audio1[ sliding_window_start:sliding_window_end] + reconstructed_signal_1 sep_audio2[sliding_window_start:sliding_window_end] = sep_audio2[ sliding_window_start:sliding_window_end] + reconstructed_signal_2 #update overlap count overlap_count[sliding_window_start:sliding_window_end] = overlap_count[ sliding_window_start:sliding_window_end] + 1 sliding_window_start = sliding_window_start + int( opt.hop_length * opt.audio_sampling_rate) #deal with the last segment #get audio spectrogram segment1_audio = audio1[-samples_per_window:] segment2_audio = audio2[-samples_per_window:] segment_offscreen = audio_offscreen[-samples_per_window:] if opt.audio_normalization: normalizer1, segment1_audio = audio_normalize(segment1_audio) normalizer2, segment2_audio = audio_normalize(segment2_audio) else: normalizer1 = 1 normalizer2 = 1 audio_segment = (segment1_audio + segment2_audio + segment_offscreen * opt.noise_weight) / 3 audio_mix_spec = generate_spectrogram_complex(audio_segment, opt.window_size, opt.hop_size, 
opt.n_fft) #get mouthroi frame_index_start = int( round((len(audio1) - samples_per_window) / opt.audio_sampling_rate * 25)) - 1 segment1_mouthroi = mouthroi_1[frame_index_start:(frame_index_start + opt.num_frames), :, :] segment2_mouthroi = mouthroi_2[frame_index_start:(frame_index_start + opt.num_frames), :, :] #transform mouthrois segment1_mouthroi = lipreading_preprocessing_func(segment1_mouthroi) segment2_mouthroi = lipreading_preprocessing_func(segment2_mouthroi) audio_spec_1 = generate_spectrogram_complex(segment1_audio, opt.window_size, opt.hop_size, opt.n_fft) audio_spec_2 = generate_spectrogram_complex(segment2_audio, opt.window_size, opt.hop_size, opt.n_fft) data['audio_spec_mix1'] = torch.FloatTensor(audio_mix_spec).unsqueeze(0) data['mouthroi_A1'] = torch.FloatTensor(segment1_mouthroi).unsqueeze( 0).unsqueeze(0) data['mouthroi_B'] = torch.FloatTensor(segment2_mouthroi).unsqueeze( 0).unsqueeze(0) data['audio_spec_A1'] = torch.FloatTensor(audio_spec_1).unsqueeze(0) data['audio_spec_B'] = torch.FloatTensor(audio_spec_2).unsqueeze(0) data['frame_A'] = frames_1 data['frame_B'] = frames_2 #don't care below data['frame_A'] = frames_1 data['mouthroi_A2'] = torch.FloatTensor(segment1_mouthroi).unsqueeze( 0).unsqueeze(0) data['audio_spec_A2'] = torch.FloatTensor(audio_spec_1).unsqueeze(0) data['audio_spec_mix2'] = torch.FloatTensor(audio_mix_spec).unsqueeze(0) outputs = model.forward(data) reconstructed_signal_1, reconstructed_signal_2 = get_separated_audio( outputs, data, opt) reconstructed_signal_1 = reconstructed_signal_1 * normalizer1 reconstructed_signal_2 = reconstructed_signal_2 * normalizer2 sep_audio1[-samples_per_window:] = sep_audio1[ -samples_per_window:] + reconstructed_signal_1 sep_audio2[-samples_per_window:] = sep_audio2[ -samples_per_window:] + reconstructed_signal_2 #update overlap count overlap_count[ -samples_per_window:] = overlap_count[-samples_per_window:] + 1 #divide the aggregated predicted audio by the overlap count avged_sep_audio1 = avged_sep_audio1 + clip_audio( np.divide(sep_audio1, overlap_count)) avged_sep_audio2 = avged_sep_audio2 + clip_audio( np.divide(sep_audio2, overlap_count)) #output original and separated audios parts1 = opt.video1_path.split('/') parts2 = opt.video2_path.split('/') video1_name = parts1[-3] + '_' + parts1[-2] + '_' + parts1[-1][:-4] video2_name = parts2[-3] + '_' + parts2[-2] + '_' + parts2[-1][:-4] output_dir = os.path.join(opt.output_dir_root, video1_name + 'VS' + video2_name) if not os.path.isdir(output_dir): os.mkdir(output_dir) librosa.output.write_wav(os.path.join(output_dir, 'audio1.wav'), audio1, opt.audio_sampling_rate) librosa.output.write_wav(os.path.join(output_dir, 'audio2.wav'), audio2, opt.audio_sampling_rate) librosa.output.write_wav(os.path.join(output_dir, 'audio_offscreen.wav'), audio_offscreen, opt.audio_sampling_rate) librosa.output.write_wav(os.path.join(output_dir, 'audio_mixed.wav'), audio_mix, opt.audio_sampling_rate) librosa.output.write_wav(os.path.join(output_dir, 'audio1_separated.wav'), avged_sep_audio1, opt.audio_sampling_rate) librosa.output.write_wav(os.path.join(output_dir, 'audio2_separated.wav'), avged_sep_audio2, opt.audio_sampling_rate)
(epoch, total_batches)) torch.save( net_visual.state_dict(), os.path.join('.', opt.checkpoints_dir, opt.name, 'visual_latest.pth')) torch.save( net_unet.state_dict(), os.path.join('.', opt.checkpoints_dir, opt.name, 'unet_latest.pth')) torch.save( net_classifier.state_dict(), os.path.join('.', opt.checkpoints_dir, opt.name, 'classifier_latest.pth')) if (total_batches % opt.validation_freq == 0 and opt.validation_on): model.eval() opt.mode = 'val' print( 'Display validation results at (epoch %d, total_batches %d)' % (epoch, total_batches)) val_err = display_val(model, crit, writer, total_batches, dataset_val, opt) print('end of display \n') model.train() opt.mode = 'main' #save the model that achieves the smallest validation error if val_err < best_err: best_err = val_err print( 'saving the best model (epoch %d, total_batches %d) with validation error %.3f\n' % (epoch, total_batches, val_err))
def test_sepration(opt, nets, output_dir, save_files=False): #load test arguments # altered_visual1 = opt.video1_name # '1oz3h9doX_g_5' # altered_visual2 = opt.video2_name # '2R12lQszz90_4' opt.visualize_spectrogram = True model = AudioVisualModel(nets, opt) #model = torch.nn.DataParallel(model, device_ids=[0]) model.to('cuda') model.eval() #load the two audios audio1_path = os.path.join(opt.data_path, 'solo_audio_resample', opt.video1_ins, opt.video1_name + '.wav') audio1, _ = librosa.load(audio1_path, sr=opt.audio_sampling_rate) audio2_path = os.path.join(opt.data_path, 'solo_audio_resample', opt.video2_ins, opt.video2_name + '.wav') audio2, _ = librosa.load(audio2_path, sr=opt.audio_sampling_rate) audio3_path = os.path.join(opt.data_path, 'solo_audio_resample', opt.video3_ins, opt.video3_name + '.wav') audio3, _ = librosa.load(audio3_path, sr=opt.audio_sampling_rate) #make sure the two audios are of the same length and then mix them audio_length = min(len(audio1), len(audio2), len(audio3)) audio1 = clip_audio(audio1[:audio_length]) audio2 = clip_audio(audio2[:audio_length]) audio3 = clip_audio(audio3[:audio_length]) audio_mix = (audio1 + audio2 + audio3) / 3.0 #define the transformation to perform on visual frames vision_transform_list = [ transforms.Resize((224, 224)), transforms.ToTensor() ] vision_transform_list_for_unet = [ transforms.Resize((256, 256)), transforms.ToTensor() ] if opt.subtract_mean: vision_transform_list.append( transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])) vision_transform_list_for_unet.append( transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])) vision_transform = transforms.Compose(vision_transform_list) vision_transform_for_unet = transforms.Compose( vision_transform_list_for_unet) #load the object regions of the highest confidence score for both videos detectionResult1 = np.load( os.path.join(opt.data_path, 'solo_detect', opt.video1_ins, opt.video1_name + '.npy')) detectionResult2 = np.load( os.path.join(opt.data_path, 'solo_detect', opt.video2_ins, opt.video2_name + '.npy')) detectionResult3 = np.load( os.path.join(opt.data_path, 'solo_detect', opt.video3_ins, opt.video3_name + '.npy')) clip_det_bbs1 = sample_object_detections(detectionResult1) clip_det_bbs2 = sample_object_detections(detectionResult2) clip_det_bbs3 = sample_object_detections(detectionResult3) avged_sep_audio1 = np.zeros((audio_length)) avged_refine_audio1 = np.zeros((audio_length)) avged_sep_audio2 = np.zeros((audio_length)) avged_refine_audio2 = np.zeros((audio_length)) avged_sep_audio3 = np.zeros((audio_length)) avged_refine_audio3 = np.zeros((audio_length)) for i in range(opt.num_of_object_detections_to_use): # 第一个的筛选 if clip_det_bbs1.shape[0] == 1: frame_path1 = os.path.join(opt.data_path, 'solo_extract', opt.video1_ins, opt.video1_name, "%06d.png" % clip_det_bbs1[0, 0]) detection_bbs_filter_1 = clip_det_bbs1[0] elif clip_det_bbs1.shape[0] >= 2: hand_npy = os.path.join(opt.data_path, 'solo_detect_hand', opt.video1_ins, opt.video1_name + '_hand.npy') if os.path.exists(hand_npy): hand_bbs = np.load(hand_npy) else: hand_bbs = np.array([]) if hand_bbs.shape[0] == 0: hand_bb = np.array([]) sign = False print("this npy file {} donot have detected hands".format( os.path.basename(hand_npy))) elif hand_bbs.shape[0] == 1: hand_bb = hand_bbs sign = True elif hand_bbs.shape[ 0] >= 2: # 在检测到的乐器数不止一个的情况下,如果检测到两只手以上,则取计算结果中概率最大的前两个 the_max = np.argmax(hand_bbs[:, 1]) hand_bb1 = hand_bbs[the_max, :] # 取一个概率最大的 hand_bb1 = hand_bb1[np.newaxis, :] 
hand_bbs[the_max, 1] = 0 # 取出后置为0 the_second_max = np.argmax(hand_bbs[:, 1]) # 取一个次大的。 hand_bb2 = hand_bbs[the_second_max, :] hand_bb2 = hand_bb2[np.newaxis, :] hand_bb = np.concatenate((hand_bb1, hand_bb2), axis=0) sign = True detection_bbs_filter_1 = filter_det_bbs(hand_bb, sign, clip_det_bbs1) frame_path1 = os.path.join(opt.data_path, 'solo_extract', opt.video1_ins, opt.video1_name, "%06d.png" % detection_bbs_filter_1[0]) detection1 = Image.open(frame_path1).convert('RGB').crop( (detection_bbs_filter_1[-4], detection_bbs_filter_1[-3], detection_bbs_filter_1[-2], detection_bbs_filter_1[-1])) # 第二个的筛选 if clip_det_bbs2.shape[0] == 1: frame_path2 = os.path.join(opt.data_path, 'solo_extract', opt.video2_ins, opt.video2_name, "%06d.png" % clip_det_bbs2[0, 0]) detection_bbs_filter_2 = clip_det_bbs2[0] elif clip_det_bbs2.shape[0] >= 2: hand_npy = os.path.join(opt.data_path, 'solo_detect_hand', opt.video2_ins, opt.video2_name + '_hand.npy') if os.path.exists(hand_npy): hand_bbs = np.load(hand_npy) else: hand_bbs = np.array([]) if hand_bbs.shape[0] == 0: hand_bb = np.array([]) sign = False print("this npy file {} donot have detected hands".format( os.path.basename(hand_npy))) elif hand_bbs.shape[0] == 1: hand_bb = hand_bbs sign = True elif hand_bbs.shape[ 0] >= 2: # 在检测到的乐器数不止一个的情况下,如果检测到两只手以上,则取计算结果中概率最大的前两个 the_max = np.argmax(hand_bbs[:, 1]) hand_bb1 = hand_bbs[the_max, :] # 取一个概率最大的 hand_bb1 = hand_bb1[np.newaxis, :] hand_bbs[the_max, 1] = 0 # 取出后置为0 the_second_max = np.argmax(hand_bbs[:, 1]) # 取一个次大的。 hand_bb2 = hand_bbs[the_second_max, :] hand_bb2 = hand_bb2[np.newaxis, :] hand_bb = np.concatenate((hand_bb1, hand_bb2), axis=0) sign = True detection_bbs_filter_2 = filter_det_bbs(hand_bb, sign, clip_det_bbs2) frame_path2 = os.path.join(opt.data_path, 'solo_extract', opt.video2_ins, opt.video2_name, "%06d.png" % clip_det_bbs2[0, 0]) detection2 = Image.open(frame_path2).convert('RGB').crop( (detection_bbs_filter_2[-4], detection_bbs_filter_2[-3], detection_bbs_filter_2[-2], detection_bbs_filter_2[-1])) # 第三个的筛选 if clip_det_bbs3.shape[0] == 1: frame_path3 = os.path.join(opt.data_path, 'solo_extract', opt.video3_ins, opt.video3_name, "%06d.png" % clip_det_bbs3[0, 0]) detection_bbs_filter_3 = clip_det_bbs3[0] elif clip_det_bbs3.shape[0] >= 2: hand_npy = os.path.join(opt.data_path, 'solo_detect_hand', opt.video3_ins, opt.video3_name + '_hand.npy') if os.path.exists(hand_npy): hand_bbs = np.load(hand_npy) else: hand_bbs = np.array([]) if hand_bbs.shape[0] == 0: hand_bb = np.array([]) sign = False print("this npy file {} donot have detected hands".format( os.path.basename(hand_npy))) elif hand_bbs.shape[0] == 1: hand_bb = hand_bbs sign = True elif hand_bbs.shape[ 0] >= 2: # 在检测到的乐器数不止一个的情况下,如果检测到两只手以上,则取计算结果中概率最大的前两个 the_max = np.argmax(hand_bbs[:, 1]) hand_bb1 = hand_bbs[the_max, :] # 取一个概率最大的 hand_bb1 = hand_bb1[np.newaxis, :] hand_bbs[the_max, 1] = 0 # 取出后置为0 the_second_max = np.argmax(hand_bbs[:, 1]) # 取一个次大的。 hand_bb2 = hand_bbs[the_second_max, :] hand_bb2 = hand_bb2[np.newaxis, :] hand_bb = np.concatenate((hand_bb1, hand_bb2), axis=0) sign = True detection_bbs_filter_3 = filter_det_bbs(hand_bb, sign, clip_det_bbs3) frame_path3 = os.path.join(opt.data_path, 'solo_extract', opt.video3_ins, opt.video3_name, "%06d.png" % clip_det_bbs3[0, 0]) detection3 = Image.open(frame_path3).convert('RGB').crop( (detection_bbs_filter_3[-4], detection_bbs_filter_3[-3], detection_bbs_filter_3[-2], detection_bbs_filter_3[-1])) #perform separation over the whole audio using a sliding window approach overlap_count = 
np.zeros((audio_length)) sep_audio1 = np.zeros((audio_length)) sep_audio2 = np.zeros((audio_length)) sep_audio3 = np.zeros((audio_length)) refine_sep1 = np.zeros((audio_length)) refine_sep2 = np.zeros((audio_length)) refine_sep3 = np.zeros((audio_length)) sliding_window_start = 0 data = {} samples_per_window = opt.audio_window while sliding_window_start + samples_per_window < audio_length: objects_visuals = [] objects_labels = [] objects_audio_mag = [] objects_audio_phase = [] objects_vids = [] objects_real_audio_mag = [] objects_audio_mix_mag = [] objects_audio_mix_phase = [] objects_visuals_256 = [] sliding_window_end = sliding_window_start + samples_per_window audio_segment = audio_mix[sliding_window_start:sliding_window_end] audio_mix_mags, audio_mix_phases = generate_spectrogram_magphase( audio_segment, opt.stft_frame, opt.stft_hop) ''' 第一份音乐的信息''' objects_audio_mix_mag.append( torch.FloatTensor(audio_mix_mags).unsqueeze(0)) objects_audio_mix_phase.append( torch.FloatTensor(audio_mix_phases).unsqueeze(0)) objects_visuals.append(vision_transform(detection1).unsqueeze(0)) objects_visuals_256.append( vision_transform_for_unet(detection1).unsqueeze(0)) objects_labels.append(torch.FloatTensor(np.ones((1, 1)))) objects_vids.append(torch.FloatTensor(np.ones((1, 1)))) ''' 第二份音乐的信息''' objects_audio_mix_mag.append( torch.FloatTensor(audio_mix_mags).unsqueeze(0)) objects_audio_mix_phase.append( torch.FloatTensor(audio_mix_phases).unsqueeze(0)) objects_visuals.append(vision_transform(detection2).unsqueeze(0)) objects_visuals_256.append( vision_transform_for_unet(detection2).unsqueeze(0)) objects_labels.append(torch.FloatTensor(np.ones((1, 1)))) objects_vids.append(torch.FloatTensor(np.ones((1, 1)))) ''' 第3份音乐的信息''' objects_audio_mix_mag.append( torch.FloatTensor(audio_mix_mags).unsqueeze(0)) objects_audio_mix_phase.append( torch.FloatTensor(audio_mix_phases).unsqueeze(0)) objects_visuals.append(vision_transform(detection3).unsqueeze(0)) objects_visuals_256.append( vision_transform_for_unet(detection3).unsqueeze(0)) objects_labels.append(torch.FloatTensor(np.ones((1, 1)))) objects_vids.append(torch.FloatTensor(np.ones((1, 1)))) data['audio_mix_mags'] = torch.FloatTensor( np.vstack(objects_audio_mix_mag)).cuda() data['audio_mags'] = data['audio_mix_mags'] data['audio_mix_phases'] = torch.FloatTensor( np.vstack(objects_audio_mix_phase)).cuda() data['visuals'] = torch.FloatTensor( np.vstack(objects_visuals)).cuda() data['visuals_256'] = torch.FloatTensor( np.vstack(objects_visuals_256)).cuda() data['labels'] = torch.FloatTensor( np.vstack(objects_labels)).cuda() data['vids'] = torch.FloatTensor(np.vstack(objects_vids)).cuda() outputs = model.forward(data) reconstructed_signal, refine_signal = get_separated_audio( outputs, data, opt) sep_audio1[sliding_window_start:sliding_window_end] = sep_audio1[ sliding_window_start: sliding_window_end] + reconstructed_signal[0] refine_sep1[sliding_window_start:sliding_window_end] = refine_sep1[ sliding_window_start:sliding_window_end] + refine_signal[0] sep_audio2[sliding_window_start:sliding_window_end] = sep_audio2[ sliding_window_start: sliding_window_end] + reconstructed_signal[1] refine_sep2[sliding_window_start:sliding_window_end] = refine_sep2[ sliding_window_start:sliding_window_end] + refine_signal[1] sep_audio3[sliding_window_start:sliding_window_end] = sep_audio3[ sliding_window_start: sliding_window_end] + reconstructed_signal[2] refine_sep3[sliding_window_start:sliding_window_end] = refine_sep3[ sliding_window_start:sliding_window_end] + 
refine_signal[2] #update overlap count overlap_count[ sliding_window_start:sliding_window_end] = overlap_count[ sliding_window_start:sliding_window_end] + 1 sliding_window_start = sliding_window_start + int( opt.hop_size * opt.audio_sampling_rate) # deal with the last segment audio_segment = audio_mix[-samples_per_window:] audio_mix_mags, audio_mix_phases = generate_spectrogram_magphase( audio_segment, opt.stft_frame, opt.stft_hop) objects_visuals = [] objects_labels = [] objects_audio_mag = [] objects_audio_phase = [] objects_vids = [] objects_real_audio_mag = [] objects_audio_mix_mag = [] objects_audio_mix_phase = [] objects_visuals_256 = [] ''' 第一份音乐的信息,应该有两份''' objects_audio_mix_mag.append( torch.FloatTensor(audio_mix_mags).unsqueeze(0)) objects_audio_mix_phase.append( torch.FloatTensor(audio_mix_phases).unsqueeze(0)) objects_visuals.append(vision_transform(detection1).unsqueeze(0)) objects_visuals_256.append( vision_transform_for_unet(detection1).unsqueeze(0)) objects_labels.append(torch.FloatTensor(np.ones((1, 1)))) objects_vids.append(torch.FloatTensor(np.ones((1, 1)))) ''' 第二份音乐的信息''' objects_audio_mix_mag.append( torch.FloatTensor(audio_mix_mags).unsqueeze(0)) objects_audio_mix_phase.append( torch.FloatTensor(audio_mix_phases).unsqueeze(0)) objects_visuals_256.append( vision_transform_for_unet(detection2).unsqueeze(0)) objects_visuals.append(vision_transform(detection2).unsqueeze(0)) objects_labels.append(torch.FloatTensor(np.ones((1, 1)))) objects_vids.append(torch.FloatTensor(np.ones((1, 1)))) ''' 第3份音乐的信息''' objects_audio_mix_mag.append( torch.FloatTensor(audio_mix_mags).unsqueeze(0)) objects_audio_mix_phase.append( torch.FloatTensor(audio_mix_phases).unsqueeze(0)) objects_visuals.append(vision_transform(detection3).unsqueeze(0)) objects_visuals_256.append( vision_transform_for_unet(detection3).unsqueeze(0)) objects_labels.append(torch.FloatTensor(np.ones((1, 1)))) objects_vids.append(torch.FloatTensor(np.ones((1, 1)))) data['audio_mix_mags'] = torch.FloatTensor( np.vstack(objects_audio_mix_mag)).cuda() data['audio_mags'] = data['audio_mix_mags'] data['audio_mix_phases'] = torch.FloatTensor( np.vstack(objects_audio_mix_phase)).cuda() data['visuals'] = torch.FloatTensor(np.vstack(objects_visuals)).cuda() data['labels'] = torch.FloatTensor(np.vstack(objects_labels)).cuda() data['vids'] = torch.FloatTensor(np.vstack(objects_vids)).cuda() data['visuals_256'] = torch.FloatTensor( np.vstack(objects_visuals_256)).cuda() outputs = model.forward(data) reconstructed_signal, refine_signal = get_separated_audio( outputs, data, opt) sep_audio1[-samples_per_window:] = sep_audio1[ -samples_per_window:] + reconstructed_signal[0] refine_sep1[-samples_per_window:] = refine_sep1[ -samples_per_window:] + refine_signal[0] sep_audio2[-samples_per_window:] = sep_audio2[ -samples_per_window:] + reconstructed_signal[1] refine_sep2[-samples_per_window:] = refine_sep2[ -samples_per_window:] + refine_signal[1] sep_audio3[-samples_per_window:] = sep_audio3[ -samples_per_window:] + reconstructed_signal[2] refine_sep3[-samples_per_window:] = refine_sep3[ -samples_per_window:] + refine_signal[2] #update overlap count overlap_count[ -samples_per_window:] = overlap_count[-samples_per_window:] + 1 #divide the aggregated predicted audio by the overlap count avged_sep_audio1 = avged_sep_audio1 + clip_audio( np.divide(sep_audio1, overlap_count) * 2) avged_refine_audio1 = avged_refine_audio1 + clip_audio( np.divide(refine_sep1, overlap_count) * 2) avged_sep_audio2 = avged_sep_audio2 + clip_audio( 
np.divide(sep_audio2, overlap_count) * 2) avged_refine_audio2 = avged_refine_audio2 + clip_audio( np.divide(refine_sep2, overlap_count) * 2) avged_sep_audio3 = avged_sep_audio3 + clip_audio( np.divide(sep_audio3, overlap_count) * 2) avged_refine_audio3 = avged_refine_audio3 + clip_audio( np.divide(refine_sep3, overlap_count) * 2) separation1 = avged_sep_audio1 / opt.num_of_object_detections_to_use separation2 = avged_sep_audio2 / opt.num_of_object_detections_to_use separation3 = avged_sep_audio3 / opt.num_of_object_detections_to_use refine_spearation1 = avged_refine_audio1 / opt.num_of_object_detections_to_use refine_spearation2 = avged_refine_audio2 / opt.num_of_object_detections_to_use refine_spearation3 = avged_refine_audio3 / opt.num_of_object_detections_to_use #output original and separated audios output_dir = os.path.join( output_dir, opt.video1_name + '$_VS_$' + opt.video2_name + '$_VS_$' + opt.video3_name) if not os.path.exists(output_dir): os.makedirs(output_dir) if save_files: librosa.output.write_wav(os.path.join(output_dir, 'audio1.wav'), audio1, opt.audio_sampling_rate) librosa.output.write_wav(os.path.join(output_dir, 'audio2.wav'), audio2, opt.audio_sampling_rate) librosa.output.write_wav(os.path.join(output_dir, 'audio3.wav'), audio3, opt.audio_sampling_rate) librosa.output.write_wav(os.path.join(output_dir, 'audio_mixed.wav'), audio_mix, opt.audio_sampling_rate) librosa.output.write_wav( os.path.join(output_dir, 'audio1_separated.wav'), separation1, opt.audio_sampling_rate) librosa.output.write_wav( os.path.join(output_dir, 'audio2_separated.wav'), separation2, opt.audio_sampling_rate) librosa.output.write_wav( os.path.join(output_dir, 'audio3_separated.wav'), separation3, opt.audio_sampling_rate) librosa.output.write_wav( os.path.join(output_dir, 'audio1_refine_separated.wav'), refine_spearation1, opt.audio_sampling_rate) librosa.output.write_wav( os.path.join(output_dir, 'audio2_refine_separated.wav'), refine_spearation2, opt.audio_sampling_rate) librosa.output.write_wav( os.path.join(output_dir, 'audio3_refine_separated.wav'), refine_spearation3, opt.audio_sampling_rate) c_reference_sources = np.concatenate( (np.expand_dims(audio1, axis=0), np.expand_dims( audio2, axis=0), np.expand_dims(audio3, axis=0)), axis=0) c_estimated_sources = np.concatenate((np.expand_dims( separation1, axis=0), np.expand_dims( separation2, axis=0), np.expand_dims(separation3, axis=0)), axis=0) c_sdr, c_sir, c_sar = getSeparationMetrics(c_reference_sources, c_estimated_sources) r_reference_sources = np.concatenate( (np.expand_dims(audio1, axis=0), np.expand_dims( audio2, axis=0), np.expand_dims(audio3, axis=0)), axis=0) r_estimated_sources = np.concatenate( (np.expand_dims(refine_spearation1, axis=0), np.expand_dims(refine_spearation2, axis=0), np.expand_dims(refine_spearation3, axis=0)), axis=0) r_sdr, r_sir, r_sar = getSeparationMetrics(r_reference_sources, r_estimated_sources) #save the two detections if save_files: detection1.save(os.path.join(output_dir, 'audio1.png')) detection2.save(os.path.join(output_dir, 'audio2.png')) detection3.save(os.path.join(output_dir, 'audio3.png')) #save the spectrograms & masks if opt.visualize_spectrogram: import matplotlib.pyplot as plt plt.switch_backend('agg') plt.ioff() audio1_mag = generate_spectrogram_magphase(audio1, opt.stft_frame, opt.stft_hop, with_phase=False) audio2_mag = generate_spectrogram_magphase(audio2, opt.stft_frame, opt.stft_hop, with_phase=False) audio3_mag = generate_spectrogram_magphase(audio3, opt.stft_frame, opt.stft_hop, 
with_phase=False) audio_mix_mag = generate_spectrogram_magphase(audio_mix, opt.stft_frame, opt.stft_hop, with_phase=False) separation1_mag = generate_spectrogram_magphase(separation1, opt.stft_frame, opt.stft_hop, with_phase=False) separation2_mag = generate_spectrogram_magphase(separation2, opt.stft_frame, opt.stft_hop, with_phase=False) separation3_mag = generate_spectrogram_magphase(separation3, opt.stft_frame, opt.stft_hop, with_phase=False) refine_sep1_mag = generate_spectrogram_magphase(refine_spearation1, opt.stft_frame, opt.stft_hop, with_phase=False) refine_sep2_mag = generate_spectrogram_magphase(refine_spearation2, opt.stft_frame, opt.stft_hop, with_phase=False) refine_sep3_mag = generate_spectrogram_magphase(refine_spearation3, opt.stft_frame, opt.stft_hop, with_phase=False) utils.visualizeSpectrogram(audio1_mag[0, :, :], os.path.join(output_dir, 'audio1_spec.png')) utils.visualizeSpectrogram(audio2_mag[0, :, :], os.path.join(output_dir, 'audio2_spec.png')) utils.visualizeSpectrogram(audio3_mag[0, :, :], os.path.join(output_dir, 'audio3_spec.png')) utils.visualizeSpectrogram( audio_mix_mag[0, :, :], os.path.join(output_dir, 'audio_mixed_spec.png')) utils.visualizeSpectrogram( separation1_mag[0, :, :], os.path.join(output_dir, 'separation1_spec.png')) utils.visualizeSpectrogram( separation2_mag[0, :, :], os.path.join(output_dir, 'separation2_spec.png')) utils.visualizeSpectrogram( separation3_mag[0, :, :], os.path.join(output_dir, 'separation3_spec.png')) utils.visualizeSpectrogram( refine_sep1_mag[0, :, :], os.path.join(output_dir, 'refine1_spec.png')) utils.visualizeSpectrogram( refine_sep2_mag[0, :, :], os.path.join(output_dir, 'refine2_spec.png')) utils.visualizeSpectrogram( refine_sep3_mag[0, :, :], os.path.join(output_dir, 'refine3_spec.png')) return c_sdr, c_sir, c_sar, r_sdr, r_sir, r_sar
def main(): #load test arguments opt = TestRealOptions().parse() opt.device = torch.device("cuda") # Network Builders builder = ModelBuilder() net_lipreading = builder.build_lipreadingnet( config_path=opt.lipreading_config_path, weights=opt.weights_lipreadingnet, extract_feats=opt.lipreading_extract_feature) #if identity feature dim is not 512, for resnet reduce dimension to this feature dim if opt.identity_feature_dim != 512: opt.with_fc = True else: opt.with_fc = False net_facial_attributes = builder.build_facial( pool_type=opt.visual_pool, fc_out = opt.identity_feature_dim, with_fc=opt.with_fc, weights=opt.weights_facial) net_unet = builder.build_unet( ngf=opt.unet_ngf, input_nc=opt.unet_input_nc, output_nc=opt.unet_output_nc, audioVisual_feature_dim=opt.audioVisual_feature_dim, identity_feature_dim=opt.identity_feature_dim, weights=opt.weights_unet) net_vocal_attributes = builder.build_vocal( pool_type=opt.audio_pool, input_channel=2, with_fc=opt.with_fc, fc_out = opt.identity_feature_dim, weights=opt.weights_vocal) nets = (net_lipreading, net_facial_attributes, net_unet, net_vocal_attributes) print(nets) # construct our audio-visual model model = AudioVisualModel(nets, opt) model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids) model.to(opt.device) model.eval() mtcnn = MTCNN(keep_all=True, device=opt.device) lipreading_preprocessing_func = get_preprocessing_pipelines()['test'] normalize = transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ) vision_transform_list = [transforms.ToTensor()] if opt.normalization: vision_transform_list.append(normalize) vision_transform = transforms.Compose(vision_transform_list) for speaker_index in range(opt.number_of_speakers): mouthroi_path = os.path.join(opt.mouthroi_root, 'speaker' + str(speaker_index+1) + '.npz') facetrack_path = os.path.join(opt.facetrack_root, 'speaker' + str(speaker_index+1) + '.mp4') #load data mouthroi = load_mouthroi(mouthroi_path) sr, audio = wavfile.read(opt.audio_path) print("sampling rate of audio: ", sr) if len((audio.shape)) == 2: audio = np.mean(audio, axis=1) #convert to mono if stereo audio = audio / 32768 audio = audio / 2.0 audio = clip_audio(audio) audio_length = len(audio) if opt.reliable_face: best_score = 0 for i in range(10): frame = load_frame(facetrack_path) boxes, scores = mtcnn.detect(frame) if scores[0] > best_score: best_frame = frame frames = vision_transform(best_frame).squeeze().unsqueeze(0).cuda() else: frame_list = [] for i in range(opt.number_of_identity_frames): frame = load_frame(facetrack_path) frame = vision_transform(frame) frame_list.append(frame) frame = torch.stack(frame_list).squeeze().unsqueeze(0).cuda() sep_audio = np.zeros((audio_length)) #perform separation over the whole audio using a sliding window approach sliding_window_start = 0 overlap_count = np.zeros((audio_length)) sep_audio = np.zeros((audio_length)) avged_sep_audio = np.zeros((audio_length)) samples_per_window = int(opt.audio_length * opt.audio_sampling_rate) while sliding_window_start + samples_per_window < audio_length: sliding_window_end = sliding_window_start + samples_per_window #get audio spectrogram segment_audio = audio[sliding_window_start:sliding_window_end] if opt.audio_normalization: normalizer, segment_audio = audio_normalize(segment_audio, desired_rms=0.07) else: normalizer = 1 audio_spec = generate_spectrogram_complex(segment_audio, opt.window_size, opt.hop_size, opt.n_fft) audio_spec = torch.FloatTensor(audio_spec).unsqueeze(0).cuda() #get mouthroi frame_index_start = 
int(round(sliding_window_start / opt.audio_sampling_rate * 25)) segment_mouthroi = mouthroi[frame_index_start:(frame_index_start + opt.num_frames), :, :] segment_mouthroi = lipreading_preprocessing_func(segment_mouthroi) segment_mouthroi = torch.FloatTensor(segment_mouthroi).unsqueeze(0).unsqueeze(0).cuda() reconstructed_signal = get_separated_audio(net_lipreading, net_facial_attributes, net_unet, audio_spec, segment_mouthroi, frames, opt) reconstructed_signal = reconstructed_signal * normalizer sep_audio[sliding_window_start:sliding_window_end] = sep_audio[sliding_window_start:sliding_window_end] + reconstructed_signal #update overlap count overlap_count[sliding_window_start:sliding_window_end] = overlap_count[sliding_window_start:sliding_window_end] + 1 sliding_window_start = sliding_window_start + int(opt.hop_length * opt.audio_sampling_rate) #deal with the last segment segment_audio = audio[-samples_per_window:] if opt.audio_normalization: normalizer, segment_audio = audio_normalize(segment_audio, desired_rms=0.07) else: normalizer = 1 audio_spec = generate_spectrogram_complex(segment_audio, opt.window_size, opt.hop_size, opt.n_fft) audio_spec = torch.FloatTensor(audio_spec).unsqueeze(0).cuda() #get mouthroi frame_index_start = int(round((len(audio) - samples_per_window) / opt.audio_sampling_rate * 25)) - 1 segment_mouthroi = mouthroi[-opt.num_frames:, :, :] segment_mouthroi = lipreading_preprocessing_func(segment_mouthroi) segment_mouthroi = torch.FloatTensor(segment_mouthroi).unsqueeze(0).unsqueeze(0).cuda() reconstructed_signal = get_separated_audio(net_lipreading, net_facial_attributes, net_unet, audio_spec, segment_mouthroi, frames, opt) reconstructed_signal = reconstructed_signal * normalizer sep_audio[-samples_per_window:] = sep_audio[-samples_per_window:] + reconstructed_signal #update overlap count overlap_count[-samples_per_window:] = overlap_count[-samples_per_window:] + 1 #divide the aggregated predicted audio by the overlap count avged_sep_audio = clip_audio(np.divide(sep_audio, overlap_count)) #output separated audios if not os.path.isdir(opt.output_dir_root): os.mkdir(opt.output_dir_root) librosa.output.write_wav(os.path.join(opt.output_dir_root, 'speaker' + str(speaker_index+1) + '.wav'), avged_sep_audio, opt.audio_sampling_rate)
def main(): #load test arguments opt = TestOptions().parse() opt.device = torch.device("cuda") # network builders builder = ModelBuilder() net_visual = builder.build_visual(weights=opt.weights_visual) net_audio = builder.build_audio(ngf=opt.unet_ngf, input_nc=opt.unet_input_nc, output_nc=opt.unet_output_nc, weights=opt.weights_audio) nets = (net_visual, net_audio) # construct our audio-visual model model = AudioVisualModel(nets, opt) model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids) model.to(opt.device) model.eval() #load the audio to perform separation audio, audio_rate = librosa.load(opt.input_audio_path, sr=opt.audio_sampling_rate, mono=False) audio_channel1 = audio[0, :] audio_channel2 = audio[1, :] #define the transformation to perform on visual frames vision_transform_list = [ transforms.Resize((224, 448)), transforms.ToTensor() ] vision_transform_list.append( transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])) vision_transform = transforms.Compose(vision_transform_list) #perform spatialization over the whole audio using a sliding window approach overlap_count = np.zeros( (audio.shape)) #count the number of times a data point is calculated binaural_audio = np.zeros((audio.shape)) #perform spatialization over the whole spectrogram in a siliding-window fashion sliding_window_start = 0 data = {} samples_per_window = int(opt.audio_length * opt.audio_sampling_rate) while sliding_window_start + samples_per_window < audio.shape[-1]: sliding_window_end = sliding_window_start + samples_per_window normalizer, audio_segment = audio_normalize( audio[:, sliding_window_start:sliding_window_end]) audio_segment_channel1 = audio_segment[0, :] audio_segment_channel2 = audio_segment[1, :] audio_segment_mix = audio_segment_channel1 + audio_segment_channel2 data['audio_diff_spec'] = torch.FloatTensor( generate_spectrogram(audio_segment_channel1 - audio_segment_channel2)).unsqueeze( 0) #unsqueeze to add a batch dimension data['audio_mix_spec'] = torch.FloatTensor( generate_spectrogram(audio_segment_channel1 + audio_segment_channel2)).unsqueeze( 0) #unsqueeze to add a batch dimension #get the frame index for current window frame_index = int( round((((sliding_window_start + samples_per_window / 2.0) / audio.shape[-1]) * opt.input_audio_length + 0.05) * 10)) image = Image.open( os.path.join(opt.video_frame_path, str(frame_index).zfill(6) + '.jpg')).convert('RGB') #image = image.transpose(Image.FLIP_LEFT_RIGHT) frame = vision_transform(image).unsqueeze( 0) #unsqueeze to add a batch dimension data['frame'] = frame output = model.forward(data) predicted_spectrogram = output['binaural_spectrogram'][ 0, :, :, :].data[:].cpu().numpy() #ISTFT to convert back to audio reconstructed_stft_diff = predicted_spectrogram[0, :, :] + ( 1j * predicted_spectrogram[1, :, :]) reconstructed_signal_diff = librosa.istft(reconstructed_stft_diff, hop_length=160, win_length=400, center=True, length=samples_per_window) reconstructed_signal_left = (audio_segment_mix + reconstructed_signal_diff) / 2 reconstructed_signal_right = (audio_segment_mix - reconstructed_signal_diff) / 2 reconstructed_binaural = np.concatenate( (np.expand_dims(reconstructed_signal_left, axis=0), np.expand_dims(reconstructed_signal_right, axis=0)), axis=0) * normalizer binaural_audio[:, sliding_window_start: sliding_window_end] = binaural_audio[:, sliding_window_start: sliding_window_end] + reconstructed_binaural overlap_count[:, sliding_window_start: sliding_window_end] = overlap_count[:, sliding_window_start: 
sliding_window_end] + 1 sliding_window_start = sliding_window_start + int( opt.hop_size * opt.audio_sampling_rate) #deal with the last segment normalizer, audio_segment = audio_normalize(audio[:, -samples_per_window:]) audio_segment_channel1 = audio_segment[0, :] audio_segment_channel2 = audio_segment[1, :] data['audio_diff_spec'] = torch.FloatTensor( generate_spectrogram(audio_segment_channel1 - audio_segment_channel2)).unsqueeze( 0) #unsqueeze to add a batch dimension data['audio_mix_spec'] = torch.FloatTensor( generate_spectrogram(audio_segment_channel1 + audio_segment_channel2)).unsqueeze( 0) #unsqueeze to add a batch dimension #get the frame index for last window frame_index = int( round(((opt.input_audio_length - opt.audio_length / 2.0) + 0.05) * 10)) image = Image.open( os.path.join(opt.video_frame_path, str(frame_index).zfill(6) + '.jpg')).convert('RGB') #image = image.transpose(Image.FLIP_LEFT_RIGHT) frame = vision_transform(image).unsqueeze( 0) #unsqueeze to add a batch dimension data['frame'] = frame output = model.forward(data) predicted_spectrogram = output['binaural_spectrogram'][ 0, :, :, :].data[:].cpu().numpy() #ISTFT to convert back to audio reconstructed_stft_diff = predicted_spectrogram[0, :, :] + ( 1j * predicted_spectrogram[1, :, :]) reconstructed_signal_diff = librosa.istft(reconstructed_stft_diff, hop_length=160, win_length=400, center=True, length=samples_per_window) reconstructed_signal_left = (audio_segment_mix + reconstructed_signal_diff) / 2 reconstructed_signal_right = (audio_segment_mix - reconstructed_signal_diff) / 2 reconstructed_binaural = np.concatenate( (np.expand_dims(reconstructed_signal_left, axis=0), np.expand_dims(reconstructed_signal_right, axis=0)), axis=0) * normalizer #add the spatialized audio to reconstructed_binaural binaural_audio[:, -samples_per_window:] = binaural_audio[:, -samples_per_window:] + reconstructed_binaural overlap_count[:, -samples_per_window:] = overlap_count[:, -samples_per_window:] + 1 #divide aggregated predicted audio by their corresponding counts predicted_binaural_audio = np.divide(binaural_audio, overlap_count) #check output directory if not os.path.isdir(opt.output_dir_root): os.mkdir(opt.output_dir_root) mixed_mono = (audio_channel1 + audio_channel2) / 2 librosa.output.write_wav( os.path.join(opt.output_dir_root, 'predicted_binaural.wav'), predicted_binaural_audio, opt.audio_sampling_rate) librosa.output.write_wav( os.path.join(opt.output_dir_root, 'mixed_mono.wav'), mixed_mono, opt.audio_sampling_rate) librosa.output.write_wav( os.path.join(opt.output_dir_root, 'input_binaural.wav'), audio, opt.audio_sampling_rate)
def test_sepration(opt, nets, output_dir, save_files=False): #load test arguments # altered_visual1 = opt.video1_name # '1oz3h9doX_g_5' # altered_visual2 = opt.video2_name # '2R12lQszz90_4' opt.visualize_spectrogram = True model = AudioVisualModel(nets, opt) #model = torch.nn.DataParallel(model, device_ids=[0]) model.to('cuda') model.eval() #load the two audios # audio1_path = os.path.join(opt.data_path, 'solo_audio_resample', opt.video1_ins, opt.video1_name + '.wav') # audio1, _ = librosa.load(audio1_path, sr=opt.audio_sampling_rate) audio2_path = os.path.join(opt.data_path_duet, 'duet_audio_resample', opt.video2_ins, opt.video2_name + '.wav') audio2, _ = librosa.load(audio2_path, sr=opt.audio_sampling_rate) #make sure the two audios are of the same length and then mix them # audio_length = min(len(audio1), len(audio2)) # audio1 = clip_audio(audio1[:audio_length]) audio_length = len(audio2) audio2 = clip_audio(audio2[:audio_length]) audio_mix = audio2 #define the transformation to perform on visual frames vision_transform_list = [ transforms.Resize((224, 224)), transforms.ToTensor() ] vision_transform_list_for_unet = [ transforms.Resize((256, 256)), transforms.ToTensor() ] if opt.subtract_mean: vision_transform_list.append( transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])) vision_transform_list_for_unet.append( transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])) vision_transform = transforms.Compose(vision_transform_list) vision_transform_for_unet = transforms.Compose( vision_transform_list_for_unet) #load the object regions of the highest confidence score for both videos # detectionResult1 = np.load(os.path.join(opt.data_path, 'solo_detect', opt.video1_ins, opt.video1_name + '.npy')) detectionR_npy = os.path.join(opt.data_path_duet, 'duet_detect', opt.video2_ins, opt.video2_name + '.npy') detectionResult2 = np.load(detectionR_npy) # detectionResult3 = np.load(os.path.join(opt.data_path, 'solo_detect', opt.video3_ins, opt.video3_name + '.npy')) # clip_det_bbs1 = sample_object_detections(detectionResult1) clip_2_path = os.path.join( '/data/mashuo/work/study/refine-separation/dataset/music_test', opt.video2_ins, opt.video2_name) frame_name = os.listdir(clip_2_path)[0][:6] frame_name = float(frame_name) clip_det_bbs2 = None sign = False for n in range(detectionResult2.shape[0]): index = detectionResult2[n][0] index = float(index) if index == frame_name: if not sign: clip_det_bbs2 = np.expand_dims(detectionResult2[n, :], axis=0) sign = True else: clip_det_bbs2 = np.concatenate( (clip_det_bbs2, np.expand_dims(detectionResult2[n, :], axis=0)), axis=0) # clip_det_bbs2 = sample_object_detections(detectionResult2) # clip_det_bbs3 = sample_object_detections(detectionResult3) avged_sep_audio1 = np.zeros((audio_length)) avged_refine_audio1 = np.zeros((audio_length)) avged_sep_audio2 = np.zeros((audio_length)) avged_refine_audio2 = np.zeros((audio_length)) avged_sep_audio3 = np.zeros((audio_length)) avged_refine_audio3 = np.zeros((audio_length)) for i in range(opt.num_of_object_detections_to_use): # 第二个的筛选 # det_box2 = clip_det_bbs2[np.argmax(clip_det_bbs2[:, 2]), :] # clip_det_bbs2[np.argmax(clip_det_bbs2[:, 2]),2] = 0 # # det_box3 = clip_det_bbs2[np.argmax(clip_det_bbs2[:, 2]), :] det_box2 = clip_det_bbs2[0, :] det_box3 = clip_det_bbs2[1, :] frame_path2 = os.path.join(opt.data_path_duet, 'duet_extract', opt.video2_ins, opt.video2_name, "%06d.png" % det_box2[0]) frame_2 = Image.open(frame_path2).convert('RGB') # frame = 
Image.open('/data/mashuo/work/study/refine-separation/dataset/music_test/xylophone-acoustic_guitar/0EMNATwzLA4_25/human.png').convert('RGB') detection2 = frame_2.crop( (det_box2[-4], det_box2[-3], det_box2[-2], det_box2[-1])) # detection2 = frame detection3 = frame_2.crop( (det_box3[-4], det_box3[-3], det_box3[-2], det_box3[-1])) #perform separation over the whole audio using a sliding window approach overlap_count = np.zeros((audio_length)) sep_audio1 = np.zeros((audio_length)) sep_audio2 = np.zeros((audio_length)) sep_audio3 = np.zeros((audio_length)) refine_sep1 = np.zeros((audio_length)) refine_sep2 = np.zeros((audio_length)) refine_sep3 = np.zeros((audio_length)) sliding_window_start = 0 data = {} samples_per_window = opt.audio_window while sliding_window_start + samples_per_window < audio_length: objects_visuals = [] objects_labels = [] objects_audio_mag = [] objects_audio_phase = [] objects_vids = [] objects_real_audio_mag = [] objects_audio_mix_mag = [] objects_audio_mix_phase = [] objects_visuals_256 = [] sliding_window_end = sliding_window_start + samples_per_window audio_segment = audio_mix[sliding_window_start:sliding_window_end] audio_mix_mags, audio_mix_phases = generate_spectrogram_magphase( audio_segment, opt.stft_frame, opt.stft_hop) ''' 第二份音乐的信息''' objects_audio_mix_mag.append( torch.FloatTensor(audio_mix_mags).unsqueeze(0)) objects_audio_mix_phase.append( torch.FloatTensor(audio_mix_phases).unsqueeze(0)) objects_visuals.append(vision_transform(detection2).unsqueeze(0)) objects_labels.append(torch.FloatTensor(np.ones((1, 1)))) objects_vids.append(torch.FloatTensor(np.ones((1, 1)))) ''' 第3份音乐的信息''' objects_audio_mix_mag.append( torch.FloatTensor(audio_mix_mags).unsqueeze(0)) objects_audio_mix_phase.append( torch.FloatTensor(audio_mix_phases).unsqueeze(0)) objects_visuals.append(vision_transform(detection3).unsqueeze(0)) objects_labels.append(torch.FloatTensor(np.ones((1, 1)))) objects_vids.append(torch.FloatTensor(np.ones((1, 1)))) data['audio_mix_mags'] = torch.FloatTensor( np.vstack(objects_audio_mix_mag)).cuda() data['audio_mags'] = data['audio_mix_mags'] data['audio_mix_phases'] = torch.FloatTensor( np.vstack(objects_audio_mix_phase)).cuda() data['visuals'] = torch.FloatTensor( np.vstack(objects_visuals)).cuda() data['labels'] = torch.FloatTensor( np.vstack(objects_labels)).cuda() data['vids'] = torch.FloatTensor(np.vstack(objects_vids)).cuda() outputs = model.forward(data) reconstructed_signal, refine_signal = get_separated_audio( outputs, data, opt) sep_audio2[sliding_window_start:sliding_window_end] = sep_audio2[ sliding_window_start: sliding_window_end] + reconstructed_signal[0] refine_sep2[sliding_window_start:sliding_window_end] = refine_sep2[ sliding_window_start:sliding_window_end] + refine_signal[0] sep_audio3[sliding_window_start:sliding_window_end] = sep_audio3[ sliding_window_start: sliding_window_end] + reconstructed_signal[1] refine_sep3[sliding_window_start:sliding_window_end] = refine_sep3[ sliding_window_start:sliding_window_end] + refine_signal[1] #update overlap count overlap_count[ sliding_window_start:sliding_window_end] = overlap_count[ sliding_window_start:sliding_window_end] + 1 sliding_window_start = sliding_window_start + int( opt.hop_size * opt.audio_sampling_rate) # deal with the last segment audio_segment = audio_mix[-samples_per_window:] audio_mix_mags, audio_mix_phases = generate_spectrogram_magphase( audio_segment, opt.stft_frame, opt.stft_hop) objects_visuals = [] objects_labels = [] objects_audio_mag = [] objects_audio_phase = [] 
        objects_vids = []
        objects_real_audio_mag = []
        objects_audio_mix_mag = []
        objects_audio_mix_phase = []
        objects_visuals_256 = []

        #inputs for the second source (first detected object)
        objects_audio_mix_mag.append(
            torch.FloatTensor(audio_mix_mags).unsqueeze(0))
        objects_audio_mix_phase.append(
            torch.FloatTensor(audio_mix_phases).unsqueeze(0))
        objects_visuals.append(vision_transform(detection2).unsqueeze(0))
        objects_labels.append(torch.FloatTensor(np.ones((1, 1))))
        objects_vids.append(torch.FloatTensor(np.ones((1, 1))))

        #inputs for the third source (second detected object)
        objects_audio_mix_mag.append(
            torch.FloatTensor(audio_mix_mags).unsqueeze(0))
        objects_audio_mix_phase.append(
            torch.FloatTensor(audio_mix_phases).unsqueeze(0))
        objects_visuals.append(vision_transform(detection3).unsqueeze(0))
        objects_labels.append(torch.FloatTensor(np.ones((1, 1))))
        objects_vids.append(torch.FloatTensor(np.ones((1, 1))))

        data['audio_mix_mags'] = torch.FloatTensor(
            np.vstack(objects_audio_mix_mag)).cuda()
        data['audio_mags'] = data['audio_mix_mags']  #not used at test time
        data['audio_mix_phases'] = torch.FloatTensor(
            np.vstack(objects_audio_mix_phase)).cuda()
        data['visuals'] = torch.FloatTensor(np.vstack(objects_visuals)).cuda()
        data['labels'] = torch.FloatTensor(np.vstack(objects_labels)).cuda()
        data['vids'] = torch.FloatTensor(np.vstack(objects_vids)).cuda()

        outputs = model.forward(data)
        reconstructed_signal, refine_signal = get_separated_audio(
            outputs, data, opt)
        sep_audio2[-samples_per_window:] = sep_audio2[
            -samples_per_window:] + reconstructed_signal[0]
        refine_sep2[-samples_per_window:] = refine_sep2[
            -samples_per_window:] + refine_signal[0]
        sep_audio3[-samples_per_window:] = sep_audio3[
            -samples_per_window:] + reconstructed_signal[1]
        refine_sep3[-samples_per_window:] = refine_sep3[
            -samples_per_window:] + refine_signal[1]
        #update overlap count
        overlap_count[
            -samples_per_window:] = overlap_count[-samples_per_window:] + 1

        #divide the aggregated predicted audio by the overlap count
        avged_sep_audio2 = avged_sep_audio2 + clip_audio(
            np.divide(sep_audio2, overlap_count) * 2)
        avged_refine_audio2 = avged_refine_audio2 + clip_audio(
            np.divide(refine_sep2, overlap_count) * 2)
        avged_sep_audio3 = avged_sep_audio3 + clip_audio(
            np.divide(sep_audio3, overlap_count) * 2)
        avged_refine_audio3 = avged_refine_audio3 + clip_audio(
            np.divide(refine_sep3, overlap_count) * 2)

    separation2 = avged_sep_audio2 / opt.num_of_object_detections_to_use
    separation3 = avged_sep_audio3 / opt.num_of_object_detections_to_use
    refine_separation2 = avged_refine_audio2 / opt.num_of_object_detections_to_use
    refine_separation3 = avged_refine_audio3 / opt.num_of_object_detections_to_use

    #output original and separated audios
    output_dir = os.path.join(output_dir, opt.video2_name + '**')
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    if save_files:
        #librosa.output.write_wav was removed in librosa 0.8.0; these calls assume an older librosa
        librosa.output.write_wav(os.path.join(output_dir, 'audio_duet.wav'),
                                 audio2, opt.audio_sampling_rate)
        librosa.output.write_wav(os.path.join(output_dir, 'audio_mixed.wav'),
                                 audio_mix, opt.audio_sampling_rate)
        librosa.output.write_wav(
            os.path.join(output_dir, 'audio2_separated.wav'), separation2,
            opt.audio_sampling_rate)
        librosa.output.write_wav(
            os.path.join(output_dir, 'audio3_separated.wav'), separation3,
            opt.audio_sampling_rate)
        librosa.output.write_wav(
            os.path.join(output_dir, 'audio2_refine_separated.wav'),
            refine_separation2, opt.audio_sampling_rate)
        librosa.output.write_wav(
            os.path.join(output_dir, 'audio3_refine_separated.wav'),
            refine_separation3, opt.audio_sampling_rate)
    #compute separation metrics against the duet recording
    c_reference_sources = np.expand_dims(audio2, axis=0)
    c_estimated_sources = np.expand_dims(separation2 + separation3, axis=0)
    c_sdr, c_sir, c_sar = getSeparationMetrics(c_reference_sources,
                                               c_estimated_sources)
    r_reference_sources = np.expand_dims(audio2, axis=0)
    r_estimated_sources = np.expand_dims(
        refine_separation2 + refine_separation3, axis=0)
    r_sdr, r_sir, r_sar = getSeparationMetrics(r_reference_sources,
                                               r_estimated_sources)

    #save the frame and the two detections
    if save_files:
        frame_2.save(os.path.join(output_dir, 'frame_2.png'))
        detection2.save(os.path.join(output_dir, 'det2.png'))
        detection3.save(os.path.join(output_dir, 'det3.png'))

    #save the spectrograms
    if opt.visualize_spectrogram:
        import matplotlib.pyplot as plt
        plt.switch_backend('agg')
        plt.ioff()
        audio2_mag = generate_spectrogram_magphase(audio2, opt.stft_frame,
                                                   opt.stft_hop,
                                                   with_phase=False)
        audio_mix_mag = generate_spectrogram_magphase(audio_mix,
                                                      opt.stft_frame,
                                                      opt.stft_hop,
                                                      with_phase=False)
        separation2_mag = generate_spectrogram_magphase(separation2,
                                                        opt.stft_frame,
                                                        opt.stft_hop,
                                                        with_phase=False)
        separation3_mag = generate_spectrogram_magphase(separation3,
                                                        opt.stft_frame,
                                                        opt.stft_hop,
                                                        with_phase=False)
        refine_sep2_mag = generate_spectrogram_magphase(refine_separation2,
                                                        opt.stft_frame,
                                                        opt.stft_hop,
                                                        with_phase=False)
        refine_sep3_mag = generate_spectrogram_magphase(refine_separation3,
                                                        opt.stft_frame,
                                                        opt.stft_hop,
                                                        with_phase=False)
        # ref_2_3_mag = generate_spectrogram_magphase(refine_separation1 + refine_separation2, opt.stft_frame, opt.stft_hop, with_phase=False)
        utils.visualizeSpectrogram(audio2_mag[0, :, :],
                                   os.path.join(output_dir, 'audio2_spec.png'))
        utils.visualizeSpectrogram(
            audio_mix_mag[0, :, :],
            os.path.join(output_dir, 'audio_mixed_spec.png'))
        utils.visualizeSpectrogram(
            separation2_mag[0, :, :],
            os.path.join(output_dir, 'separation2_spec.png'))
        utils.visualizeSpectrogram(
            separation3_mag[0, :, :],
            os.path.join(output_dir, 'separation3_spec.png'))
        utils.visualizeSpectrogram(
            separation2_mag[0, :, :] + separation3_mag[0, :, :],
            os.path.join(output_dir, 'separation2+3_spec.png'))
        utils.visualizeSpectrogram(
            refine_sep2_mag[0, :, :],
            os.path.join(output_dir, 'refine2_spec.png'))
        utils.visualizeSpectrogram(
            refine_sep3_mag[0, :, :],
            os.path.join(output_dir, 'refine3_spec.png'))
        utils.visualizeSpectrogram(
            refine_sep2_mag[0, :, :] + refine_sep3_mag[0, :, :],
            os.path.join(output_dir, 'ref_2+3_spec.png'))

    return c_sdr, c_sir, c_sar, r_sdr, r_sir, r_sar
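

# --- Reference sketch (not part of the original script) ---
# getSeparationMetrics is defined elsewhere in this repo; the helper below is only an
# assumption of what it computes, wrapping mir_eval.separation.bss_eval_sources, which
# takes (n_sources, n_samples) arrays and returns per-source SDR/SIR/SAR plus a source
# permutation. The `_sketch` name is hypothetical and only kept for illustration.
def getSeparationMetrics_sketch(reference_sources, estimated_sources):
    import mir_eval  #local import so the sketch stays self-contained
    sdr, sir, sar, _ = mir_eval.separation.bss_eval_sources(
        np.asarray(reference_sources), np.asarray(estimated_sources))
    #average over sources to get scalar scores, matching how the results are used above
    return np.mean(sdr), np.mean(sir), np.mean(sar)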
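

# --- Usage sketch (not part of the original script) ---
# Minimal driver showing how test_sepration could be looped over a set of duet clips and
# its SDR/SIR/SAR outputs averaged. It assumes `opt` and `nets` were already constructed
# the same way the rest of this script builds them; `evaluate_duets`, `duet_clips`, and
# the (instrument, clip_name) naming are hypothetical.
def evaluate_duets(opt, nets, duet_clips, output_root):
    scores = []
    for instrument, clip_name in duet_clips:
        opt.video2_ins = instrument
        opt.video2_name = clip_name
        scores.append(test_sepration(opt, nets, output_root, save_files=False))
    #each row is (c_sdr, c_sir, c_sar, r_sdr, r_sir, r_sar); average over clips
    return np.array(scores).mean(axis=0)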