def main():
    """Separate every visible speaker's voice from a single recording.

    Reads options from ``TestRealOptions``, builds the lip-reading,
    facial-attribute, U-Net and vocal-attribute networks, then for each
    speaker ``k`` in ``1..opt.number_of_speakers``:

    * loads mouth ROIs from ``<opt.mouthroi_root>/speaker{k}.npz`` and
      identity frame(s) from ``<opt.facetrack_root>/speaker{k}.mp4``;
    * runs the separator over ``opt.audio_path`` with a sliding window of
      ``opt.audio_length`` seconds advancing by ``opt.hop_length`` seconds,
      plus one extra window aligned to the end of the audio;
    * averages the overlapping window estimates and writes
      ``<opt.output_dir_root>/speaker{k}.wav``.

    Raises:
        ValueError: if ``opt.hop_length`` rounds down to zero samples (the
            sliding window would otherwise never advance).
    """
    #load test arguments
    opt = TestRealOptions().parse()
    opt.device = torch.device("cuda")

    # Network Builders
    builder = ModelBuilder()
    net_lipreading = builder.build_lipreadingnet(
        config_path=opt.lipreading_config_path,
        weights=opt.weights_lipreadingnet,
        extract_feats=opt.lipreading_extract_feature)
    #if identity feature dim is not 512, for resnet reduce dimension to this feature dim
    opt.with_fc = opt.identity_feature_dim != 512
    net_facial_attributes = builder.build_facial(
        pool_type=opt.visual_pool,
        fc_out=opt.identity_feature_dim,
        with_fc=opt.with_fc,
        weights=opt.weights_facial)
    net_unet = builder.build_unet(
        ngf=opt.unet_ngf,
        input_nc=opt.unet_input_nc,
        output_nc=opt.unet_output_nc,
        audioVisual_feature_dim=opt.audioVisual_feature_dim,
        identity_feature_dim=opt.identity_feature_dim,
        weights=opt.weights_unet)
    net_vocal_attributes = builder.build_vocal(
        pool_type=opt.audio_pool,
        input_channel=2,
        with_fc=opt.with_fc,
        fc_out=opt.identity_feature_dim,
        weights=opt.weights_vocal)
    nets = (net_lipreading, net_facial_attributes, net_unet, net_vocal_attributes)
    print(nets)

    # construct our audio-visual model
    model = AudioVisualModel(nets, opt)
    model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
    model.to(opt.device)
    model.eval()

    mtcnn = MTCNN(keep_all=True, device=opt.device)

    lipreading_preprocessing_func = get_preprocessing_pipelines()['test']
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
    vision_transform_list = [transforms.ToTensor()]
    if opt.normalization:
        vision_transform_list.append(normalize)
    vision_transform = transforms.Compose(vision_transform_list)

    def select_reliable_frame(video_path, num_trials=10):
        """Sample ``num_trials`` frames and return the most confident face.

        Each sampled frame goes through ``mtcnn.detect``; the frame whose
        first detection score is highest is kept. If no trial produces a
        score, the last sampled frame is returned so the caller always gets
        a frame.
        """
        best_frame = None
        best_score = 0
        frame = None
        for _ in range(num_trials):
            frame = load_frame(video_path)
            _, scores = mtcnn.detect(frame)
            # NOTE(review): facenet-pytorch reports "no face" as None/[None]
            # scores; skip those instead of comparing None with a number.
            if scores is None or scores[0] is None:
                continue
            if scores[0] > best_score:
                # Raise the bar so later frames must beat this one.
                best_score = scores[0]
                best_frame = frame
        return best_frame if best_frame is not None else frame

    def separate_window(segment_audio, segment_mouthroi, frames):
        """Separate one audio window and return the de-normalized estimate."""
        if opt.audio_normalization:
            normalizer, segment_audio = audio_normalize(segment_audio, desired_rms=0.07)
        else:
            normalizer = 1
        audio_spec = generate_spectrogram_complex(
            segment_audio, opt.window_size, opt.hop_size, opt.n_fft)
        audio_spec = torch.FloatTensor(audio_spec).unsqueeze(0).cuda()
        segment_mouthroi = lipreading_preprocessing_func(segment_mouthroi)
        segment_mouthroi = torch.FloatTensor(segment_mouthroi).unsqueeze(0).unsqueeze(0).cuda()
        # Pure inference: no autograd graph needed.
        with torch.no_grad():
            reconstructed_signal = get_separated_audio(
                net_lipreading, net_facial_attributes, net_unet,
                audio_spec, segment_mouthroi, frames, opt)
        # Undo the per-window loudness normalization.
        return reconstructed_signal * normalizer

    # The recording is shared by every speaker, so load and preprocess it once.
    sr, audio = wavfile.read(opt.audio_path)
    print("sampling rate of audio: ", sr)
    # NOTE(review): `sr` is only printed; windows are sized with
    # opt.audio_sampling_rate, so the file is assumed to already be at that rate.
    if len(audio.shape) == 2:
        audio = np.mean(audio, axis=1)  #convert to mono if stereo
    # NOTE(review): assumes 16-bit PCM input; the extra /2 presumably leaves
    # headroom for the model — confirm.
    audio = audio / 32768
    audio = audio / 2.0
    audio = clip_audio(audio)
    audio_length = len(audio)

    samples_per_window = int(opt.audio_length * opt.audio_sampling_rate)
    hop_samples = int(opt.hop_length * opt.audio_sampling_rate)
    if hop_samples <= 0:
        raise ValueError(
            "opt.hop_length * opt.audio_sampling_rate must be at least one sample")

    os.makedirs(opt.output_dir_root, exist_ok=True)

    for speaker_index in range(opt.number_of_speakers):
        speaker_name = 'speaker' + str(speaker_index + 1)
        mouthroi_path = os.path.join(opt.mouthroi_root, speaker_name + '.npz')
        facetrack_path = os.path.join(opt.facetrack_root, speaker_name + '.mp4')

        #load data
        mouthroi = load_mouthroi(mouthroi_path)

        if opt.reliable_face:
            best_frame = select_reliable_frame(facetrack_path)
            frames = vision_transform(best_frame).squeeze().unsqueeze(0).cuda()
        else:
            frame_list = [vision_transform(load_frame(facetrack_path))
                          for _ in range(opt.number_of_identity_frames)]
            frames = torch.stack(frame_list).squeeze().unsqueeze(0).cuda()

        #perform separation over the whole audio using a sliding window approach
        overlap_count = np.zeros(audio_length)
        sep_audio = np.zeros(audio_length)
        sliding_window_start = 0
        while sliding_window_start + samples_per_window < audio_length:
            sliding_window_end = sliding_window_start + samples_per_window
            # Convert the window start (seconds) to a mouth-ROI index at the
            # hard-coded rate of 25 frames per second.
            frame_index_start = int(round(sliding_window_start / opt.audio_sampling_rate * 25))
            segment_mouthroi = mouthroi[frame_index_start:(frame_index_start + opt.num_frames), :, :]
            reconstructed_signal = separate_window(
                audio[sliding_window_start:sliding_window_end], segment_mouthroi, frames)
            sep_audio[sliding_window_start:sliding_window_end] += reconstructed_signal
            #update overlap count
            overlap_count[sliding_window_start:sliding_window_end] += 1
            sliding_window_start += hop_samples

        #deal with the last segment: one window aligned to the end of the audio
        reconstructed_signal = separate_window(
            audio[-samples_per_window:], mouthroi[-opt.num_frames:, :, :], frames)
        sep_audio[-samples_per_window:] += reconstructed_signal
        overlap_count[-samples_per_window:] += 1

        #divide the aggregated predicted audio by the overlap count
        avged_sep_audio = clip_audio(np.divide(sep_audio, overlap_count))

        #output separated audios
        # NOTE(review): librosa.output was removed in librosa 0.8; this needs
        # librosa<0.8 (or a switch to soundfile.write).
        librosa.output.write_wav(
            os.path.join(opt.output_dir_root, speaker_name + '.wav'),
            avged_sep_audio, opt.audio_sampling_rate)
def main():
    """Separate a synthetic two-speaker mixture with added off-screen noise.

    Reads options from ``TestOptions`` and builds the networks wrapped in
    ``AudioVisualModel``. The two speaker tracks (``opt.audio1_path``,
    ``opt.audio2_path``) and the off-screen track are trimmed to a common
    length and mixed as ``(a1 + a2 + offscreen * opt.noise_weight) / 3``.
    The model runs over the mixture with a sliding window of
    ``opt.audio_length`` seconds advancing by ``opt.hop_length`` seconds,
    plus one window aligned to the end. Overlapping estimates are averaged.
    The inputs, the mixture and both separated signals are written to
    ``<opt.output_dir_root>/<video1_name>VS<video2_name>/``.

    Raises:
        ValueError: if ``opt.hop_length`` rounds down to zero samples (the
            sliding window would otherwise never advance).
    """
    #load test arguments
    opt = TestOptions().parse()
    opt.device = torch.device("cuda")

    # Network Builders
    builder = ModelBuilder()
    net_lipreading = builder.build_lipreadingnet(
        config_path=opt.lipreading_config_path,
        weights=opt.weights_lipreadingnet,
        extract_feats=opt.lipreading_extract_feature)
    #if identity feature dim is not 512, for resnet reduce dimension to this feature dim
    opt.with_fc = opt.identity_feature_dim != 512
    net_facial_attributes = builder.build_facial(
        pool_type=opt.visual_pool,
        fc_out=opt.identity_feature_dim,
        with_fc=opt.with_fc,
        weights=opt.weights_facial)
    net_unet = builder.build_unet(
        ngf=opt.unet_ngf,
        input_nc=opt.unet_input_nc,
        output_nc=opt.unet_output_nc,
        audioVisual_feature_dim=opt.audioVisual_feature_dim,
        identity_feature_dim=opt.identity_feature_dim,
        weights=opt.weights_unet)
    net_vocal_attributes = builder.build_vocal(
        pool_type=opt.audio_pool,
        input_channel=2,
        with_fc=opt.with_fc,
        fc_out=opt.identity_feature_dim,
        weights=opt.weights_vocal)
    nets = (net_lipreading, net_facial_attributes, net_unet, net_vocal_attributes)
    print(nets)

    # construct our audio-visual model
    model = AudioVisualModel(nets, opt)
    model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
    model.to(opt.device)
    model.eval()

    mtcnn = MTCNN(keep_all=True, device=opt.device)

    lipreading_preprocessing_func = get_preprocessing_pipelines()['test']
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    vision_transform_list = [transforms.ToTensor()]
    if opt.normalization:
        vision_transform_list.append(normalize)
    vision_transform = transforms.Compose(vision_transform_list)

    def select_reliable_frame(video_path, num_trials=10):
        """Sample ``num_trials`` frames and return the most confident face.

        Each sampled frame goes through ``mtcnn.detect``; the frame whose
        first detection score is highest is kept. If no trial produces a
        score, the last sampled frame is returned so the caller always gets
        a frame.
        """
        best_frame = None
        best_score = 0
        frame = None
        for _ in range(num_trials):
            frame = load_frame(video_path)
            _, scores = mtcnn.detect(frame)
            # NOTE(review): facenet-pytorch reports "no face" as None/[None]
            # scores; skip those instead of comparing None with a number.
            if scores is None or scores[0] is None:
                continue
            if scores[0] > best_score:
                # Raise the bar so later frames must beat this one.
                best_score = scores[0]
                best_frame = frame
        return best_frame if best_frame is not None else frame

    # load data
    mouthroi_1 = load_mouthroi(opt.mouthroi1_path)
    mouthroi_2 = load_mouthroi(opt.mouthroi2_path)
    _, audio1 = wavfile.read(opt.audio1_path)
    _, audio2 = wavfile.read(opt.audio2_path)
    _, audio_offscreen = wavfile.read(opt.offscreen_audio_path)
    # NOTE(review): assumes 16-bit PCM input — confirm.
    audio1 = audio1 / 32768
    audio2 = audio2 / 32768
    audio_offscreen = audio_offscreen / 32768

    #make sure the three audios are of the same length and then mix them
    audio_length = min(len(audio1), len(audio2), len(audio_offscreen))
    audio1 = clip_audio(audio1[:audio_length])
    audio2 = clip_audio(audio2[:audio_length])
    audio_offscreen = clip_audio(audio_offscreen[:audio_length])
    audio_mix = (audio1 + audio2 + audio_offscreen * opt.noise_weight) / 3

    if opt.reliable_face:
        best_frame_1 = select_reliable_frame(opt.video1_path)
        best_frame_2 = select_reliable_frame(opt.video2_path)
        frames_1 = vision_transform(best_frame_1).squeeze().unsqueeze(0)
        frames_2 = vision_transform(best_frame_2).squeeze().unsqueeze(0)
    else:
        frame_1_list = []
        frame_2_list = []
        for _ in range(opt.number_of_identity_frames):
            frame_1_list.append(vision_transform(load_frame(opt.video1_path)))
            frame_2_list.append(vision_transform(load_frame(opt.video2_path)))
        frames_1 = torch.stack(frame_1_list).squeeze().unsqueeze(0)
        frames_2 = torch.stack(frame_2_list).squeeze().unsqueeze(0)

    def separate_window(segment1_audio, segment2_audio, segment_offscreen,
                        frame_index_start):
        """Separate one window and return both speaker estimates, de-normalized.

        Every window, including the final end-aligned one, goes through the
        same normalization and mixing. Their estimates are therefore
        comparable when averaged.
        """
        if opt.audio_normalization:
            normalizer1, segment1_audio = audio_normalize(segment1_audio)
            normalizer2, segment2_audio = audio_normalize(segment2_audio)
            _, segment_offscreen = audio_normalize(segment_offscreen)
        else:
            normalizer1 = 1
            normalizer2 = 1
        audio_segment = (segment1_audio + segment2_audio
                         + segment_offscreen * opt.noise_weight) / 3
        audio_mix_spec = generate_spectrogram_complex(
            audio_segment, opt.window_size, opt.hop_size, opt.n_fft)
        audio_spec_1 = generate_spectrogram_complex(
            segment1_audio, opt.window_size, opt.hop_size, opt.n_fft)
        audio_spec_2 = generate_spectrogram_complex(
            segment2_audio, opt.window_size, opt.hop_size, opt.n_fft)

        #get mouthroi and transform it
        roi_end = frame_index_start + opt.num_frames
        segment1_mouthroi = lipreading_preprocessing_func(
            mouthroi_1[frame_index_start:roi_end, :, :])
        segment2_mouthroi = lipreading_preprocessing_func(
            mouthroi_2[frame_index_start:roi_end, :, :])

        data = {
            'audio_spec_mix1': torch.FloatTensor(audio_mix_spec).unsqueeze(0),
            'mouthroi_A1': torch.FloatTensor(segment1_mouthroi).unsqueeze(0).unsqueeze(0),
            'mouthroi_B': torch.FloatTensor(segment2_mouthroi).unsqueeze(0).unsqueeze(0),
            'audio_spec_A1': torch.FloatTensor(audio_spec_1).unsqueeze(0),
            'audio_spec_B': torch.FloatTensor(audio_spec_2).unsqueeze(0),
            'frame_A': frames_1,
            'frame_B': frames_2,
            #don't care below: placeholders for the remaining model inputs
            'mouthroi_A2': torch.FloatTensor(segment1_mouthroi).unsqueeze(0).unsqueeze(0),
            'audio_spec_A2': torch.FloatTensor(audio_spec_1).unsqueeze(0),
            'audio_spec_mix2': torch.FloatTensor(audio_mix_spec).unsqueeze(0),
        }
        # Pure inference: no autograd graph needed.
        with torch.no_grad():
            outputs = model(data)
            reconstructed_signal_1, reconstructed_signal_2 = get_separated_audio(
                outputs, data, opt)
        # Undo the per-window loudness normalization.
        return (reconstructed_signal_1 * normalizer1,
                reconstructed_signal_2 * normalizer2)

    #perform separation over the whole audio using a sliding window approach
    overlap_count = np.zeros(audio_length)
    sep_audio1 = np.zeros(audio_length)
    sep_audio2 = np.zeros(audio_length)
    samples_per_window = int(opt.audio_length * opt.audio_sampling_rate)
    hop_samples = int(opt.hop_length * opt.audio_sampling_rate)
    if hop_samples <= 0:
        raise ValueError(
            "opt.hop_length * opt.audio_sampling_rate must be at least one sample")

    sliding_window_start = 0
    while sliding_window_start + samples_per_window < audio_length:
        sliding_window_end = sliding_window_start + samples_per_window
        # Convert the window start (seconds) to a mouth-ROI index at the
        # hard-coded rate of 25 frames per second.
        frame_index_start = int(round(sliding_window_start / opt.audio_sampling_rate * 25))
        reconstructed_signal_1, reconstructed_signal_2 = separate_window(
            audio1[sliding_window_start:sliding_window_end],
            audio2[sliding_window_start:sliding_window_end],
            audio_offscreen[sliding_window_start:sliding_window_end],
            frame_index_start)
        sep_audio1[sliding_window_start:sliding_window_end] += reconstructed_signal_1
        sep_audio2[sliding_window_start:sliding_window_end] += reconstructed_signal_2
        #update overlap count
        overlap_count[sliding_window_start:sliding_window_end] += 1
        sliding_window_start += hop_samples

    #deal with the last segment: one window aligned to the end of the audio
    # presumably the -1 guards against rounding past the last ROI frame; clamp
    # at 0 so short clips do not produce a negative (wrapping) slice start.
    frame_index_start = max(
        0,
        int(round((len(audio1) - samples_per_window) / opt.audio_sampling_rate * 25)) - 1)
    reconstructed_signal_1, reconstructed_signal_2 = separate_window(
        audio1[-samples_per_window:],
        audio2[-samples_per_window:],
        audio_offscreen[-samples_per_window:],
        frame_index_start)
    sep_audio1[-samples_per_window:] += reconstructed_signal_1
    sep_audio2[-samples_per_window:] += reconstructed_signal_2
    overlap_count[-samples_per_window:] += 1

    #divide the aggregated predicted audio by the overlap count
    avged_sep_audio1 = clip_audio(np.divide(sep_audio1, overlap_count))
    avged_sep_audio2 = clip_audio(np.divide(sep_audio2, overlap_count))

    #output original and separated audios
    parts1 = opt.video1_path.split('/')
    parts2 = opt.video2_path.split('/')
    video1_name = parts1[-3] + '_' + parts1[-2] + '_' + parts1[-1][:-4]
    video2_name = parts2[-3] + '_' + parts2[-2] + '_' + parts2[-1][:-4]
    output_dir = os.path.join(opt.output_dir_root, video1_name + 'VS' + video2_name)
    os.makedirs(output_dir, exist_ok=True)
    # NOTE(review): librosa.output was removed in librosa 0.8; this needs
    # librosa<0.8 (or a switch to soundfile.write).
    outputs_to_write = (
        ('audio1.wav', audio1),
        ('audio2.wav', audio2),
        ('audio_offscreen.wav', audio_offscreen),
        ('audio_mixed.wav', audio_mix),
        ('audio1_separated.wav', avged_sep_audio1),
        ('audio2_separated.wav', avged_sep_audio2),
    )
    for file_name, waveform in outputs_to_write:
        librosa.output.write_wav(os.path.join(output_dir, file_name),
                                 waveform, opt.audio_sampling_rate)
data_loader_val = CreateDataLoader(opt) dataset_val = data_loader_val.load_data() dataset_size_val = len(data_loader_val) print('#validation images = %d' % dataset_size_val) opt.mode = 'train' #set it back if opt.tensorboard: from tensorboardX import SummaryWriter writer = SummaryWriter(comment=opt.name) else: writer = None # Network Builders builder = ModelBuilder() net_lipreading = builder.build_lipreadingnet( config_path=opt.lipreading_config_path, weights=opt.weights_lipreadingnet, extract_feats=opt.lipreading_extract_feature) #if identity feature dim is not 512, for resnet reduce dimension to this feature dim if opt.identity_feature_dim != 512: opt.with_fc = True else: opt.with_fc = False net_facial_attribtes = builder.build_facial( pool_type=opt.visual_pool, fc_out = opt.identity_feature_dim, with_fc=opt.with_fc, weights=opt.weights_facial) net_unet = builder.build_unet( ngf=opt.unet_ngf, input_nc=opt.unet_input_nc, output_nc=opt.unet_output_nc,