Пример #1
0
def main():
    """Separate two on-screen speakers from a synthetic mixture and save the audio.

    Builds the lip-reading, facial-attribute, U-Net and vocal-attribute
    networks from the parsed test options, mixes the two speakers' audio with
    an off-screen track weighted by ``opt.noise_weight``, separates the
    mixture window by window, averages the overlapping predictions and writes
    the original, mixed and separated signals as WAV files into a
    sub-directory of ``opt.output_dir_root``.

    Raises:
        RuntimeError: if ``opt.reliable_face`` is set and no face is detected
            in any of the sampled frames of one of the videos.
    """
    # load test arguments
    opt = TestOptions().parse()
    opt.device = torch.device("cuda")

    # Network Builders
    builder = ModelBuilder()
    net_lipreading = builder.build_lipreadingnet(
        config_path=opt.lipreading_config_path,
        weights=opt.weights_lipreadingnet,
        extract_feats=opt.lipreading_extract_feature)
    # if identity feature dim is not 512, for resnet reduce dimension to this
    # feature dim
    opt.with_fc = opt.identity_feature_dim != 512
    net_facial_attributes = builder.build_facial(
        pool_type=opt.visual_pool,
        fc_out=opt.identity_feature_dim,
        with_fc=opt.with_fc,
        weights=opt.weights_facial)
    net_unet = builder.build_unet(
        ngf=opt.unet_ngf,
        input_nc=opt.unet_input_nc,
        output_nc=opt.unet_output_nc,
        audioVisual_feature_dim=opt.audioVisual_feature_dim,
        identity_feature_dim=opt.identity_feature_dim,
        weights=opt.weights_unet)
    net_vocal_attributes = builder.build_vocal(
        pool_type=opt.audio_pool,
        input_channel=2,
        with_fc=opt.with_fc,
        fc_out=opt.identity_feature_dim,
        weights=opt.weights_vocal)

    nets = (net_lipreading, net_facial_attributes, net_unet,
            net_vocal_attributes)
    print(nets)

    # construct our audio-visual model
    model = AudioVisualModel(nets, opt)
    model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
    model.to(opt.device)
    model.eval()
    mtcnn = MTCNN(keep_all=True, device=opt.device)

    lipreading_preprocessing_func = get_preprocessing_pipelines()['test']
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    vision_transform_list = [transforms.ToTensor()]
    if opt.normalization:
        vision_transform_list.append(normalize)
    vision_transform = transforms.Compose(vision_transform_list)

    # load data
    mouthroi_1 = load_mouthroi(opt.mouthroi1_path)
    mouthroi_2 = load_mouthroi(opt.mouthroi2_path)

    _, audio1 = wavfile.read(opt.audio1_path)
    _, audio2 = wavfile.read(opt.audio2_path)
    _, audio_offscreen = wavfile.read(opt.offscreen_audio_path)
    # NOTE(review): dividing by 32768 presumably assumes 16-bit PCM input --
    # confirm against the data preparation scripts.
    audio1 = audio1 / 32768
    audio2 = audio2 / 32768
    audio_offscreen = audio_offscreen / 32768

    # make sure the audios are of the same length and then mix them
    audio_length = min(len(audio1), len(audio2), len(audio_offscreen))
    audio1 = clip_audio(audio1[:audio_length])
    audio2 = clip_audio(audio2[:audio_length])
    audio_offscreen = clip_audio(audio_offscreen[:audio_length])
    audio_mix = (audio1 + audio2 + audio_offscreen * opt.noise_weight) / 3

    if opt.reliable_face:
        # Sample 10 frames per video and keep the one whose first detected
        # face has the highest score.
        best_score_1 = 0
        best_score_2 = 0
        best_frame_1 = None
        best_frame_2 = None
        for _ in range(10):
            frame_1 = load_frame(opt.video1_path)
            frame_2 = load_frame(opt.video2_path)
            # The detector is expected to report a None score when it finds
            # no face; such frames are skipped instead of being compared.
            _, scores = mtcnn.detect(frame_1)
            if scores[0] is not None and scores[0] > best_score_1:
                best_score_1 = scores[0]
                best_frame_1 = frame_1
            _, scores = mtcnn.detect(frame_2)
            if scores[0] is not None and scores[0] > best_score_2:
                best_score_2 = scores[0]
                best_frame_2 = frame_2
        if best_frame_1 is None or best_frame_2 is None:
            raise RuntimeError(
                'no face detected in any sampled frame; '
                'cannot select a reliable identity frame')
        frames_1 = vision_transform(best_frame_1).squeeze().unsqueeze(0)
        frames_2 = vision_transform(best_frame_2).squeeze().unsqueeze(0)
    else:
        frame_1_list = []
        frame_2_list = []
        for _ in range(opt.number_of_identity_frames):
            # keep the original interleaved loading order of the two videos
            frame_1 = load_frame(opt.video1_path)
            frame_2 = load_frame(opt.video2_path)
            frame_1_list.append(vision_transform(frame_1))
            frame_2_list.append(vision_transform(frame_2))
        frames_1 = torch.stack(frame_1_list).squeeze().unsqueeze(0)
        frames_2 = torch.stack(frame_2_list).squeeze().unsqueeze(0)

    # perform separation over the whole audio using a sliding window approach
    overlap_count = np.zeros((audio_length))
    sep_audio1 = np.zeros((audio_length))
    sep_audio2 = np.zeros((audio_length))

    samples_per_window = int(opt.audio_length * opt.audio_sampling_rate)
    hop_samples = int(opt.hop_length * opt.audio_sampling_rate)

    # Collect (audio start, audio end, first mouth-ROI frame) per window.
    # NOTE(review): 25 is presumably the mouth-ROI frame rate -- confirm.
    windows = []
    sliding_window_start = 0
    while sliding_window_start + samples_per_window < audio_length:
        frame_index_start = int(
            round(sliding_window_start / opt.audio_sampling_rate * 25))
        windows.append((sliding_window_start,
                        sliding_window_start + samples_per_window,
                        frame_index_start))
        sliding_window_start += hop_samples

    # deal with the last segment: a window aligned with the end of the audio.
    # The frame index is clamped at 0 so that audio of (about) one window
    # length does not yield a negative start and an empty mouth-ROI slice.
    last_frame_index_start = max(
        0,
        int(round((len(audio1) - samples_per_window) /
                  opt.audio_sampling_rate * 25)) - 1)
    windows.append((max(0, audio_length - samples_per_window), audio_length,
                    last_frame_index_start))

    for window_start, window_end, frame_index_start in windows:
        # get audio segments
        segment1_audio = audio1[window_start:window_end]
        segment2_audio = audio2[window_start:window_end]
        segment_offscreen = audio_offscreen[window_start:window_end]

        if opt.audio_normalization:
            normalizer1, segment1_audio = audio_normalize(segment1_audio)
            normalizer2, segment2_audio = audio_normalize(segment2_audio)
            # the off-screen track is normalised in every window, including
            # the last one, so all windows are mixed the same way
            _, segment_offscreen = audio_normalize(segment_offscreen)
        else:
            normalizer1 = 1
            normalizer2 = 1

        # get audio spectrograms
        audio_segment = (segment1_audio + segment2_audio +
                         segment_offscreen * opt.noise_weight) / 3
        audio_mix_spec = generate_spectrogram_complex(
            audio_segment, opt.window_size, opt.hop_size, opt.n_fft)
        audio_spec_1 = generate_spectrogram_complex(
            segment1_audio, opt.window_size, opt.hop_size, opt.n_fft)
        audio_spec_2 = generate_spectrogram_complex(
            segment2_audio, opt.window_size, opt.hop_size, opt.n_fft)

        # get and transform mouthrois
        frame_index_end = frame_index_start + opt.num_frames
        segment1_mouthroi = lipreading_preprocessing_func(
            mouthroi_1[frame_index_start:frame_index_end, :, :])
        segment2_mouthroi = lipreading_preprocessing_func(
            mouthroi_2[frame_index_start:frame_index_end, :, :])

        data = {
            'audio_spec_mix1':
            torch.FloatTensor(audio_mix_spec).unsqueeze(0),
            'mouthroi_A1':
            torch.FloatTensor(segment1_mouthroi).unsqueeze(0).unsqueeze(0),
            'mouthroi_B':
            torch.FloatTensor(segment2_mouthroi).unsqueeze(0).unsqueeze(0),
            'audio_spec_A1':
            torch.FloatTensor(audio_spec_1).unsqueeze(0),
            'audio_spec_B':
            torch.FloatTensor(audio_spec_2).unsqueeze(0),
            'frame_A':
            frames_1,
            'frame_B':
            frames_2,
            # don't care below: duplicates of the first speaker's inputs
            'mouthroi_A2':
            torch.FloatTensor(segment1_mouthroi).unsqueeze(0).unsqueeze(0),
            'audio_spec_A2':
            torch.FloatTensor(audio_spec_1).unsqueeze(0),
            'audio_spec_mix2':
            torch.FloatTensor(audio_mix_spec).unsqueeze(0),
        }

        # inference only: do not build an autograd graph
        with torch.no_grad():
            outputs = model(data)
        reconstructed_signal_1, reconstructed_signal_2 = get_separated_audio(
            outputs, data, opt)
        # undo the per-window normalisation before accumulating
        reconstructed_signal_1 = reconstructed_signal_1 * normalizer1
        reconstructed_signal_2 = reconstructed_signal_2 * normalizer2
        sep_audio1[window_start:window_end] = (
            sep_audio1[window_start:window_end] + reconstructed_signal_1)
        sep_audio2[window_start:window_end] = (
            sep_audio2[window_start:window_end] + reconstructed_signal_2)
        # update overlap count
        overlap_count[window_start:window_end] += 1

    # divide the aggregated predicted audio by the overlap count
    avged_sep_audio1 = clip_audio(np.divide(sep_audio1, overlap_count))
    avged_sep_audio2 = clip_audio(np.divide(sep_audio2, overlap_count))

    # output original and separated audios
    parts1 = opt.video1_path.split('/')
    parts2 = opt.video2_path.split('/')
    # name = <grandparent dir>_<parent dir>_<file name minus its last 4 chars>
    video1_name = parts1[-3] + '_' + parts1[-2] + '_' + parts1[-1][:-4]
    video2_name = parts2[-3] + '_' + parts2[-2] + '_' + parts2[-1][:-4]

    output_dir = os.path.join(opt.output_dir_root,
                              video1_name + 'VS' + video2_name)
    # also creates missing parents and tolerates an existing directory
    os.makedirs(output_dir, exist_ok=True)

    # NOTE(review): librosa.output was removed in librosa 0.8.0; this needs
    # librosa < 0.8 or a migration to another WAV writer -- confirm the
    # pinned version.
    outputs_to_write = (
        ('audio1.wav', audio1),
        ('audio2.wav', audio2),
        ('audio_offscreen.wav', audio_offscreen),
        ('audio_mixed.wav', audio_mix),
        ('audio1_separated.wav', avged_sep_audio1),
        ('audio2_separated.wav', avged_sep_audio2),
    )
    for file_name, signal in outputs_to_write:
        librosa.output.write_wav(os.path.join(output_dir, file_name), signal,
                                 opt.audio_sampling_rate)
Пример #2
0
def main():
    """Separate each speaker's voice from a real video's audio track.

    Builds the networks from the parsed options, reads the mixture at
    ``opt.audio_path`` once, and for every speaker index loads that speaker's
    mouth ROIs (``speakerN.npz``) and face track (``speakerN.mp4``), separates
    the audio window by window, averages the overlapping predictions and
    writes ``speakerN.wav`` into ``opt.output_dir_root``.

    Raises:
        RuntimeError: if ``opt.reliable_face`` is set and no face is detected
            in any of the sampled frames of a speaker's face track.
    """
    # load test arguments
    opt = TestRealOptions().parse()
    opt.device = torch.device("cuda")

    # Network Builders
    builder = ModelBuilder()
    net_lipreading = builder.build_lipreadingnet(
        config_path=opt.lipreading_config_path,
        weights=opt.weights_lipreadingnet,
        extract_feats=opt.lipreading_extract_feature)
    # if identity feature dim is not 512, for resnet reduce dimension to this
    # feature dim
    opt.with_fc = opt.identity_feature_dim != 512
    net_facial_attributes = builder.build_facial(
        pool_type=opt.visual_pool,
        fc_out=opt.identity_feature_dim,
        with_fc=opt.with_fc,
        weights=opt.weights_facial)
    net_unet = builder.build_unet(
        ngf=opt.unet_ngf,
        input_nc=opt.unet_input_nc,
        output_nc=opt.unet_output_nc,
        audioVisual_feature_dim=opt.audioVisual_feature_dim,
        identity_feature_dim=opt.identity_feature_dim,
        weights=opt.weights_unet)
    net_vocal_attributes = builder.build_vocal(
        pool_type=opt.audio_pool,
        input_channel=2,
        with_fc=opt.with_fc,
        fc_out=opt.identity_feature_dim,
        weights=opt.weights_vocal)

    nets = (net_lipreading, net_facial_attributes, net_unet,
            net_vocal_attributes)
    print(nets)

    # construct our audio-visual model
    model = AudioVisualModel(nets, opt)
    model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
    model.to(opt.device)
    model.eval()
    mtcnn = MTCNN(keep_all=True, device=opt.device)

    lipreading_preprocessing_func = get_preprocessing_pipelines()['test']
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
    vision_transform_list = [transforms.ToTensor()]
    if opt.normalization:
        vision_transform_list.append(normalize)
    vision_transform = transforms.Compose(vision_transform_list)

    # The mixture is the same for every speaker, so load and scale it once.
    sr, audio = wavfile.read(opt.audio_path)
    print("sampling rate of audio: ", sr)
    if audio.ndim == 2:
        audio = np.mean(audio, axis=1)  # convert to mono if stereo
    # NOTE(review): dividing by 32768 presumably assumes 16-bit PCM input --
    # confirm against the expected audio format.
    audio = audio / 32768
    audio = audio / 2.0
    audio = clip_audio(audio)
    audio_length = len(audio)

    samples_per_window = int(opt.audio_length * opt.audio_sampling_rate)
    hop_samples = int(opt.hop_length * opt.audio_sampling_rate)

    # also creates missing parents and tolerates an existing directory
    os.makedirs(opt.output_dir_root, exist_ok=True)

    for speaker_index in range(opt.number_of_speakers):
        speaker_name = 'speaker' + str(speaker_index + 1)
        mouthroi_path = os.path.join(opt.mouthroi_root, speaker_name + '.npz')
        facetrack_path = os.path.join(opt.facetrack_root,
                                      speaker_name + '.mp4')

        # load data
        mouthroi = load_mouthroi(mouthroi_path)

        if opt.reliable_face:
            # Sample 10 frames and keep the one whose first detected face has
            # the highest score.
            best_score = 0
            best_frame = None
            for _ in range(10):
                frame = load_frame(facetrack_path)
                _, scores = mtcnn.detect(frame)
                # The detector is expected to report a None score when it
                # finds no face; such frames are skipped.
                if scores[0] is not None and scores[0] > best_score:
                    best_score = scores[0]
                    best_frame = frame
            if best_frame is None:
                raise RuntimeError(
                    'no face detected in any sampled frame of ' +
                    facetrack_path)
            frames = vision_transform(best_frame).squeeze().unsqueeze(
                0).cuda()
        else:
            frame_list = []
            for _ in range(opt.number_of_identity_frames):
                frame_list.append(vision_transform(load_frame(facetrack_path)))
            # must be bound to ``frames``: that is the name passed to
            # get_separated_audio below
            frames = torch.stack(frame_list).squeeze().unsqueeze(0).cuda()

        # perform separation over the whole audio using a sliding window
        # approach
        overlap_count = np.zeros((audio_length))
        sep_audio = np.zeros((audio_length))

        # Collect (audio start, audio end, mouth-ROI frames) per window.
        # NOTE(review): 25 is presumably the mouth-ROI frame rate -- confirm.
        windows = []
        sliding_window_start = 0
        while sliding_window_start + samples_per_window < audio_length:
            frame_index_start = int(
                round(sliding_window_start / opt.audio_sampling_rate * 25))
            windows.append(
                (sliding_window_start,
                 sliding_window_start + samples_per_window,
                 mouthroi[frame_index_start:(frame_index_start +
                                             opt.num_frames), :, :]))
            sliding_window_start += hop_samples

        # deal with the last segment: a window aligned with the end of the
        # audio, paired with the last num_frames mouth ROIs
        windows.append((max(0, audio_length - samples_per_window),
                        audio_length, mouthroi[-opt.num_frames:, :, :]))

        for window_start, window_end, segment_mouthroi in windows:
            # get audio spectrogram
            segment_audio = audio[window_start:window_end]
            if opt.audio_normalization:
                normalizer, segment_audio = audio_normalize(
                    segment_audio, desired_rms=0.07)
            else:
                normalizer = 1

            audio_spec = generate_spectrogram_complex(
                segment_audio, opt.window_size, opt.hop_size, opt.n_fft)
            audio_spec = torch.FloatTensor(audio_spec).unsqueeze(0).cuda()

            # transform mouthroi
            segment_mouthroi = lipreading_preprocessing_func(segment_mouthroi)
            segment_mouthroi = torch.FloatTensor(segment_mouthroi).unsqueeze(
                0).unsqueeze(0).cuda()

            # inference only: do not build an autograd graph
            with torch.no_grad():
                reconstructed_signal = get_separated_audio(
                    net_lipreading, net_facial_attributes, net_unet,
                    audio_spec, segment_mouthroi, frames, opt)
            # undo the per-window normalisation before accumulating
            reconstructed_signal = reconstructed_signal * normalizer
            sep_audio[window_start:window_end] = (
                sep_audio[window_start:window_end] + reconstructed_signal)

            # update overlap count
            overlap_count[window_start:window_end] += 1

        # divide the aggregated predicted audio by the overlap count
        avged_sep_audio = clip_audio(np.divide(sep_audio, overlap_count))

        # output separated audio
        # NOTE(review): librosa.output was removed in librosa 0.8.0; this
        # needs librosa < 0.8 or a migration to another WAV writer -- confirm
        # the pinned version.
        librosa.output.write_wav(
            os.path.join(opt.output_dir_root, speaker_name + '.wav'),
            avged_sep_audio, opt.audio_sampling_rate)