Ejemplo n.º 1
0
def main():
    """Separate a two-video audio mixture guided by detected object crops.

    All settings come from ``TestOptions().parse()``. The audio tracks of
    ``opt.video1_name`` and ``opt.video2_name`` are loaded from
    ``<opt.data_path>/audio_11025``, trimmed to a common length, passed through
    ``clip_audio`` and averaged into a mixture. For each of the
    ``opt.num_of_object_detections_to_use`` highest-scoring detections of each
    video (score = column 2 of the ``detection_results`` arrays), the mixture
    is separated with a sliding window of ``opt.audio_window`` samples and a
    hop of ``opt.hop_size`` seconds. Overlapping window outputs are averaged,
    then the per-detection results are averaged again. The inputs, the two
    separations and the last pair of detection crops are written to
    ``<opt.output_dir_root>/<video1_name>VS<video2_name>``. Spectrogram images
    are added when ``opt.visualize_spectrogram`` is set.

    Raises:
        ValueError: if the window width, the window hop or the number of
            detections to use is not positive. (Previously a zero hop made the
            sliding loop spin forever, and a zero detection count failed later
            with a NameError.)
    """
    # load test arguments
    opt = TestOptions().parse()
    opt.device = torch.device("cuda")

    samples_per_window = opt.audio_window
    hop_samples = int(opt.hop_size * opt.audio_sampling_rate)
    num_detections = opt.num_of_object_detections_to_use
    if samples_per_window <= 0 or hop_samples <= 0:
        raise ValueError(
            'audio_window (%r samples) and hop_size (%r s -> %d samples) must '
            'both be positive' % (opt.audio_window, opt.hop_size, hop_samples))
    if num_detections < 1:
        raise ValueError('num_of_object_detections_to_use must be >= 1, got %r'
                         % num_detections)

    # Network Builders
    builder = ModelBuilder()
    net_visual = builder.build_visual(pool_type=opt.visual_pool,
                                      weights=opt.weights_visual)
    net_unet = builder.build_unet(unet_num_layers=opt.unet_num_layers,
                                  ngf=opt.unet_ngf,
                                  input_nc=opt.unet_input_nc,
                                  output_nc=opt.unet_output_nc,
                                  weights=opt.weights_unet)
    if opt.with_additional_scene_image:
        opt.number_of_classes = opt.number_of_classes + 1
    net_classifier = builder.build_classifier(
        pool_type=opt.classifier_pool,
        num_of_classes=opt.number_of_classes,
        input_channel=opt.unet_output_nc,
        weights=opt.weights_classifier)
    nets = (net_visual, net_unet, net_classifier)

    # construct our audio-visual model
    model = AudioVisualModel(nets, opt)
    model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
    model.to(opt.device)
    model.eval()

    def load_audio(video_name):
        """Load ``<data_path>/audio_11025/<video_name>.wav`` at the test rate."""
        path = os.path.join(opt.data_path, 'audio_11025', video_name + '.wav')
        audio, _ = librosa.load(path, sr=opt.audio_sampling_rate)
        return audio

    audio1 = load_audio(opt.video1_name)
    audio2 = load_audio(opt.video2_name)

    # make sure the two audios are of the same length and then mix them
    audio_length = min(len(audio1), len(audio2))
    audio1 = clip_audio(audio1[:audio_length])
    audio2 = clip_audio(audio2[:audio_length])
    audio_mix = (audio1 + audio2) / 2.0

    # define the transformation to perform on visual frames
    vision_transform_list = [
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ]
    if opt.subtract_mean:
        vision_transform_list.append(
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]))
    vision_transform = transforms.Compose(vision_transform_list)

    def load_detections(video_name):
        """Load the saved detection array for ``video_name``."""
        return np.load(
            os.path.join(opt.data_path, 'detection_results',
                         video_name + '.npy'))

    detection_results1 = load_detections(opt.video1_name)
    detection_results2 = load_detections(opt.video2_name)

    def pop_best_crop(detection_results, video_name):
        """Crop the highest-scoring unused detection and mark it as used.

        Column 2 holds the score, column 0 the frame index and the last four
        columns the crop box. The chosen row's score is zeroed in place, so
        the next call picks the next-best detection.
        """
        best = np.argmax(detection_results[:, 2])
        box = detection_results[best, :]
        detection_results[best, 2] = 0  # set to 0 after using it
        frame_path = os.path.join(opt.data_path, 'frame', video_name,
                                  "%06d.png" % box[0])
        with Image.open(frame_path) as frame:
            return frame.convert('RGB').crop(
                (box[-4], box[-3], box[-2], box[-1]))

    def make_window_inputs(audio_segment):
        """Build the model input dict for one mixture window (no visuals yet)."""
        mags, phases = generate_spectrogram_magphase(audio_segment,
                                                     opt.stft_frame,
                                                     opt.stft_hop)
        mix_mags = torch.FloatTensor(mags).unsqueeze(0)
        return {
            'audio_mix_mags': mix_mags,
            'audio_mix_phases': torch.FloatTensor(phases).unsqueeze(0),
            # placeholders: not needed for testing
            'real_audio_mags': mix_mags,
            'audio_mags': mix_mags,
            'labels': torch.FloatTensor(np.ones((1, 1))),
            'vids': torch.FloatTensor(np.ones((1, 1))),
        }

    def separate_window(data, visual):
        """Run the model on one window conditioned on one visual crop."""
        data['visuals'] = visual
        outputs = model.forward(data)
        return get_separated_audio(outputs, data, opt)

    # Sliding windows: regular hops while the window fits strictly inside the
    # audio, then one final window flush with the end of the audio.
    window_slices = [
        slice(start, start + samples_per_window)
        for start in range(0, audio_length - samples_per_window, hop_samples)
    ]
    window_slices.append(slice(-samples_per_window, None))

    avged_sep_audio1 = np.zeros(audio_length)
    avged_sep_audio2 = np.zeros(audio_length)
    # inference only: no autograd graph is needed
    with torch.no_grad():
        for _ in range(num_detections):
            detection1 = pop_best_crop(detection_results1, opt.video1_name)
            detection2 = pop_best_crop(detection_results2, opt.video2_name)
            # the crops are fixed for this pass, so transform them only once
            visual1 = vision_transform(detection1).unsqueeze(0)
            visual2 = vision_transform(detection2).unsqueeze(0)

            overlap_count = np.zeros(audio_length)
            sep_audio1 = np.zeros(audio_length)
            sep_audio2 = np.zeros(audio_length)
            for window in window_slices:
                data = make_window_inputs(audio_mix[window])
                sep_audio1[window] += separate_window(data, visual1)
                sep_audio2[window] += separate_window(data, visual2)
                overlap_count[window] += 1

            # Average the overlapping windows. Flooring the count at 1 turns a
            # 0/0 (NaN) in any uncovered gap into silence; covered samples are
            # unaffected.
            overlap_count = np.maximum(overlap_count, 1)
            avged_sep_audio1 += clip_audio(
                np.divide(sep_audio1, overlap_count) * 2)
            avged_sep_audio2 += clip_audio(
                np.divide(sep_audio2, overlap_count) * 2)

    separation1 = avged_sep_audio1 / num_detections
    separation2 = avged_sep_audio2 / num_detections

    # output original and separated audios
    output_dir = os.path.join(opt.output_dir_root,
                              opt.video1_name + 'VS' + opt.video2_name)
    os.makedirs(output_dir, exist_ok=True)
    # NOTE(review): librosa.output was removed in librosa 0.8; this requires
    # librosa<0.8 (or a switch to soundfile.write).
    for file_name, signal in (('audio1.wav', audio1),
                              ('audio2.wav', audio2),
                              ('audio_mixed.wav', audio_mix),
                              ('audio1_separated.wav', separation1),
                              ('audio2_separated.wav', separation2)):
        librosa.output.write_wav(os.path.join(output_dir, file_name), signal,
                                 opt.audio_sampling_rate)
    # save the two (last used) detections
    detection1.save(os.path.join(output_dir, 'audio1.png'))
    detection2.save(os.path.join(output_dir, 'audio2.png'))

    # save the spectrograms
    if opt.visualize_spectrogram:
        import matplotlib.pyplot as plt
        plt.switch_backend('agg')
        plt.ioff()
        for file_name, signal in (('audio1_spec.png', audio1),
                                  ('audio2_spec.png', audio2),
                                  ('audio_mixed_spec.png', audio_mix),
                                  ('separation1_spec.png', separation1),
                                  ('separation2_spec.png', separation2)):
            mag = generate_spectrogram_magphase(signal,
                                                opt.stft_frame,
                                                opt.stft_hop,
                                                with_phase=False)
            utils.visualizeSpectrogram(mag[0, :, :],
                                       os.path.join(output_dir, file_name))
Ejemplo n.º 2
0
def main():
    """Separate two on-screen speakers from a noisy mixture and save the results.

    All settings come from ``TestOptions().parse()``. The waveforms at
    ``opt.audio1_path``, ``opt.audio2_path`` and ``opt.offscreen_audio_path``
    are read, divided by 32768, trimmed to a common length and mixed as
    ``(audio1 + audio2 + noise_weight * offscreen) / 3``. The model is
    conditioned on the mouth ROIs at ``opt.mouthroi1_path`` /
    ``opt.mouthroi2_path`` and on identity frames sampled from
    ``opt.video1_path`` / ``opt.video2_path``. With ``opt.reliable_face`` this
    is the sampled frame whose first MTCNN face scores highest out of 10
    tries; otherwise it is ``opt.number_of_identity_frames`` stacked frames.
    Separation uses a sliding window of ``opt.audio_length`` seconds with a
    hop of ``opt.hop_length`` seconds, averaging overlapping outputs. The
    inputs and both separations are written as WAV files under
    ``opt.output_dir_root``, in a directory named after the two video paths.

    Raises:
        ValueError: if the window length or the window hop is not positive.
        RuntimeError: if ``opt.reliable_face`` is set and MTCNN reports no
            face in any sampled frame of one of the videos.
    """
    # load test arguments
    opt = TestOptions().parse()
    opt.device = torch.device("cuda")

    samples_per_window = int(opt.audio_length * opt.audio_sampling_rate)
    window_hop = int(opt.hop_length * opt.audio_sampling_rate)
    if samples_per_window <= 0 or window_hop <= 0:
        raise ValueError(
            'audio_length (%r s) and hop_length (%r s) must both map to a '
            'positive number of samples' % (opt.audio_length, opt.hop_length))

    # Network Builders
    builder = ModelBuilder()
    net_lipreading = builder.build_lipreadingnet(
        config_path=opt.lipreading_config_path,
        weights=opt.weights_lipreadingnet,
        extract_feats=opt.lipreading_extract_feature)
    # if identity feature dim is not 512, for resnet reduce dimension to this feature dim
    opt.with_fc = opt.identity_feature_dim != 512
    net_facial_attributes = builder.build_facial(
        pool_type=opt.visual_pool,
        fc_out=opt.identity_feature_dim,
        with_fc=opt.with_fc,
        weights=opt.weights_facial)
    net_unet = builder.build_unet(
        ngf=opt.unet_ngf,
        input_nc=opt.unet_input_nc,
        output_nc=opt.unet_output_nc,
        audioVisual_feature_dim=opt.audioVisual_feature_dim,
        identity_feature_dim=opt.identity_feature_dim,
        weights=opt.weights_unet)
    net_vocal_attributes = builder.build_vocal(pool_type=opt.audio_pool,
                                               input_channel=2,
                                               with_fc=opt.with_fc,
                                               fc_out=opt.identity_feature_dim,
                                               weights=opt.weights_vocal)

    nets = (net_lipreading, net_facial_attributes, net_unet,
            net_vocal_attributes)
    print(nets)

    # construct our audio-visual model
    model = AudioVisualModel(nets, opt)
    model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
    model.to(opt.device)
    model.eval()
    mtcnn = MTCNN(keep_all=True, device=opt.device)

    lipreading_preprocessing_func = get_preprocessing_pipelines()['test']
    vision_transform_list = [transforms.ToTensor()]
    if opt.normalization:
        vision_transform_list.append(
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]))
    vision_transform = transforms.Compose(vision_transform_list)

    # load data
    mouthroi_1 = load_mouthroi(opt.mouthroi1_path)
    mouthroi_2 = load_mouthroi(opt.mouthroi2_path)

    # dividing by 32768 assumes 16-bit PCM samples -- TODO confirm input format
    _, audio1 = wavfile.read(opt.audio1_path)
    _, audio2 = wavfile.read(opt.audio2_path)
    _, audio_offscreen = wavfile.read(opt.offscreen_audio_path)
    audio1 = audio1 / 32768
    audio2 = audio2 / 32768
    audio_offscreen = audio_offscreen / 32768

    # make sure the audios are of the same length and then mix them
    audio_length = min(len(audio1), len(audio2), len(audio_offscreen))
    audio1 = clip_audio(audio1[:audio_length])
    audio2 = clip_audio(audio2[:audio_length])
    audio_offscreen = clip_audio(audio_offscreen[:audio_length])
    audio_mix = (audio1 + audio2 + audio_offscreen * opt.noise_weight) / 3

    def first_face_score(frame):
        """Return the score of the first face MTCNN reports, or None if none."""
        _, scores = mtcnn.detect(frame)
        # NOTE(review): MTCNN appears to report "no face" as None scores;
        # guard so that case is skipped rather than compared with a number.
        if scores is None or len(scores) == 0 or scores[0] is None:
            return None
        return scores[0]

    if opt.reliable_face:
        # Sample 10 frames per video and keep the one whose first face scores
        # highest. (The original never updated the best scores, so it kept
        # the last frame with any positive score.)
        best_score_1 = best_score_2 = 0
        best_frame_1 = best_frame_2 = None
        for _ in range(10):
            frame_1 = load_frame(opt.video1_path)
            frame_2 = load_frame(opt.video2_path)
            score_1 = first_face_score(frame_1)
            if score_1 is not None and score_1 > best_score_1:
                best_score_1, best_frame_1 = score_1, frame_1
            score_2 = first_face_score(frame_2)
            if score_2 is not None and score_2 > best_score_2:
                best_score_2, best_frame_2 = score_2, frame_2
        if best_frame_1 is None or best_frame_2 is None:
            raise RuntimeError(
                'no face detected in any sampled frame of %s'
                % (opt.video1_path if best_frame_1 is None else opt.video2_path))
        frames_1 = vision_transform(best_frame_1).squeeze().unsqueeze(0)
        frames_2 = vision_transform(best_frame_2).squeeze().unsqueeze(0)
    else:
        frame_1_list = []
        frame_2_list = []
        for _ in range(opt.number_of_identity_frames):
            frame_1 = load_frame(opt.video1_path)
            frame_2 = load_frame(opt.video2_path)
            frame_1_list.append(vision_transform(frame_1))
            frame_2_list.append(vision_transform(frame_2))
        frames_1 = torch.stack(frame_1_list).squeeze().unsqueeze(0)
        frames_2 = torch.stack(frame_2_list).squeeze().unsqueeze(0)

    def separate_window(window, frame_index_start):
        """Separate one window of the mixture.

        ``window`` is a slice into the audio arrays, and ``frame_index_start``
        is the first mouth-ROI frame of the window. Returns the two
        reconstructed signals, rescaled by the normalizers of their segments.
        Every window (including the last one) normalizes all three segments
        when ``opt.audio_normalization`` is set; the original skipped the
        off-screen segment in the last window.
        """
        segment1_audio = audio1[window]
        segment2_audio = audio2[window]
        segment_offscreen = audio_offscreen[window]
        if opt.audio_normalization:
            normalizer1, segment1_audio = audio_normalize(segment1_audio)
            normalizer2, segment2_audio = audio_normalize(segment2_audio)
            _, segment_offscreen = audio_normalize(segment_offscreen)
        else:
            normalizer1 = 1
            normalizer2 = 1

        # get audio spectrograms
        audio_segment = (segment1_audio + segment2_audio +
                         segment_offscreen * opt.noise_weight) / 3
        audio_mix_spec = generate_spectrogram_complex(audio_segment,
                                                      opt.window_size,
                                                      opt.hop_size, opt.n_fft)
        audio_spec_1 = generate_spectrogram_complex(segment1_audio,
                                                    opt.window_size,
                                                    opt.hop_size, opt.n_fft)
        audio_spec_2 = generate_spectrogram_complex(segment2_audio,
                                                    opt.window_size,
                                                    opt.hop_size, opt.n_fft)

        # get and transform the mouth ROIs
        frame_index_end = frame_index_start + opt.num_frames
        segment1_mouthroi = lipreading_preprocessing_func(
            mouthroi_1[frame_index_start:frame_index_end, :, :])
        segment2_mouthroi = lipreading_preprocessing_func(
            mouthroi_2[frame_index_start:frame_index_end, :, :])

        data = {
            'audio_spec_mix1':
                torch.FloatTensor(audio_mix_spec).unsqueeze(0),
            'mouthroi_A1':
                torch.FloatTensor(segment1_mouthroi).unsqueeze(0).unsqueeze(0),
            'mouthroi_B':
                torch.FloatTensor(segment2_mouthroi).unsqueeze(0).unsqueeze(0),
            'audio_spec_A1': torch.FloatTensor(audio_spec_1).unsqueeze(0),
            'audio_spec_B': torch.FloatTensor(audio_spec_2).unsqueeze(0),
            'frame_A': frames_1,
            'frame_B': frames_2,
            # don't care below (placeholders, not needed for testing)
            'mouthroi_A2':
                torch.FloatTensor(segment1_mouthroi).unsqueeze(0).unsqueeze(0),
            'audio_spec_A2': torch.FloatTensor(audio_spec_1).unsqueeze(0),
            'audio_spec_mix2':
                torch.FloatTensor(audio_mix_spec).unsqueeze(0),
        }

        outputs = model.forward(data)
        reconstructed_signal_1, reconstructed_signal_2 = get_separated_audio(
            outputs, data, opt)
        return (reconstructed_signal_1 * normalizer1,
                reconstructed_signal_2 * normalizer2)

    # Sliding windows: regular hops while the window fits strictly inside the
    # audio, then one final window flush with the end. The mouth-ROI index
    # assumes 25 fps ROIs -- TODO confirm.
    windows = [
        (slice(start, start + samples_per_window),
         int(round(start / opt.audio_sampling_rate * 25)))
        for start in range(0, audio_length - samples_per_window, window_hop)
    ]
    last_frame_index = int(
        round((len(audio1) - samples_per_window) / opt.audio_sampling_rate *
              25)) - 1
    # clamp: an index of -1 would select an (almost) empty ROI slice
    windows.append((slice(-samples_per_window, None), max(last_frame_index,
                                                          0)))

    # perform separation over the whole audio using a sliding window approach
    overlap_count = np.zeros(audio_length)
    sep_audio1 = np.zeros(audio_length)
    sep_audio2 = np.zeros(audio_length)
    # inference only: no autograd graph is needed
    with torch.no_grad():
        for window, frame_index_start in windows:
            reconstructed_signal_1, reconstructed_signal_2 = separate_window(
                window, frame_index_start)
            sep_audio1[window] += reconstructed_signal_1
            sep_audio2[window] += reconstructed_signal_2
            overlap_count[window] += 1

    # Divide the aggregated predicted audio by the overlap count. Flooring the
    # count at 1 turns a 0/0 (NaN) in any uncovered gap into silence.
    overlap_count = np.maximum(overlap_count, 1)
    avged_sep_audio1 = np.zeros(audio_length) + clip_audio(
        np.divide(sep_audio1, overlap_count))
    avged_sep_audio2 = np.zeros(audio_length) + clip_audio(
        np.divide(sep_audio2, overlap_count))

    def clip_name(video_path):
        """Join the last three path components (extension dropped) with '_'."""
        parts = video_path.split('/')
        return parts[-3] + '_' + parts[-2] + '_' + parts[-1][:-4]

    # output original and separated audios
    output_dir = os.path.join(
        opt.output_dir_root,
        clip_name(opt.video1_path) + 'VS' + clip_name(opt.video2_path))
    os.makedirs(output_dir, exist_ok=True)
    # NOTE(review): librosa.output was removed in librosa 0.8; this requires
    # librosa<0.8 (or a switch to soundfile.write).
    for file_name, signal in (('audio1.wav', audio1),
                              ('audio2.wav', audio2),
                              ('audio_offscreen.wav', audio_offscreen),
                              ('audio_mixed.wav', audio_mix),
                              ('audio1_separated.wav', avged_sep_audio1),
                              ('audio2_separated.wav', avged_sep_audio2)):
        librosa.output.write_wav(os.path.join(output_dir, file_name), signal,
                                 opt.audio_sampling_rate)
Ejemplo n.º 3
0
                  (epoch, total_batches))
            torch.save(
                net_visual.state_dict(),
                os.path.join('.', opt.checkpoints_dir, opt.name,
                             'visual_latest.pth'))
            torch.save(
                net_unet.state_dict(),
                os.path.join('.', opt.checkpoints_dir, opt.name,
                             'unet_latest.pth'))
            torch.save(
                net_classifier.state_dict(),
                os.path.join('.', opt.checkpoints_dir, opt.name,
                             'classifier_latest.pth'))

        if (total_batches % opt.validation_freq == 0 and opt.validation_on):
            model.eval()
            opt.mode = 'val'
            print(
                'Display validation results at (epoch %d, total_batches %d)' %
                (epoch, total_batches))
            val_err = display_val(model, crit, writer, total_batches,
                                  dataset_val, opt)
            print('end of display \n')
            model.train()
            opt.mode = 'main'
            #save the model that achieves the smallest validation error
            if val_err < best_err:
                best_err = val_err
                print(
                    'saving the best model (epoch %d, total_batches %d) with validation error %.3f\n'
                    % (epoch, total_batches, val_err))
Ejemplo n.º 4
0
def _select_detection_crop(opt, video_ins, video_name, clip_det_bbs):
    """Choose one object detection for a solo video and return its crop.

    With exactly one candidate box, that box is used. With two or more, the
    (up to) two most confident hand detections are loaded from
    ``<data_path>/solo_detect_hand/<ins>/<name>_hand.npy`` and handed to
    ``filter_det_bbs`` together with the candidates; its result is used.

    Args:
        opt: options namespace; only ``opt.data_path`` is read.
        video_ins: instrument sub-directory of the video.
        video_name: video identifier (file stem).
        clip_det_bbs: candidate boxes from ``sample_object_detections``. The
            first entry of a row is the frame index (used in the frame file
            name) and the last four entries are the (left, upper, right,
            lower) crop box.

    Returns:
        The cropped RGB ``PIL.Image`` of the selected detection.

    Raises:
        ValueError: if ``clip_det_bbs`` contains no detections.
    """
    num_candidates = clip_det_bbs.shape[0]
    if num_candidates == 0:
        raise ValueError(
            "no object detections found for video {}".format(video_name))

    if num_candidates == 1:
        chosen_bbs = clip_det_bbs[0]
    else:
        hand_npy = os.path.join(opt.data_path, 'solo_detect_hand', video_ins,
                                video_name + '_hand.npy')
        if os.path.exists(hand_npy):
            hand_bbs = np.load(hand_npy)
        else:
            hand_bbs = np.array([])

        if hand_bbs.shape[0] == 0:
            hand_bb = np.array([])
            sign = False
            print("this npy file {} donot have detected hands".format(
                os.path.basename(hand_npy)))
        elif hand_bbs.shape[0] == 1:
            hand_bb = hand_bbs
            sign = True
        else:
            # Several hands detected: keep the two with the highest confidence
            # (column 1). Scores are copied so the selected rows keep their
            # real confidence and the top hand cannot be chosen twice.
            scores = np.array(hand_bbs[:, 1], dtype=float)
            first = int(np.argmax(scores))
            scores[first] = -np.inf
            second = int(np.argmax(scores))
            hand_bb = hand_bbs[[first, second], :]
            sign = True
        chosen_bbs = filter_det_bbs(hand_bb, sign, clip_det_bbs)

    # The frame is taken from the *chosen* box so crop and frame always match.
    frame_path = os.path.join(opt.data_path, 'solo_extract', video_ins,
                              video_name, "%06d.png" % chosen_bbs[0])
    with Image.open(frame_path) as frame:
        return frame.convert('RGB').crop(
            (chosen_bbs[-4], chosen_bbs[-3], chosen_bbs[-2], chosen_bbs[-1]))


def _separate_window(model, audio_segment, visuals, visuals_256, opt):
    """Run the model on one window of the mixture for every object.

    Every object receives the same mixture spectrogram; ``visuals`` and
    ``visuals_256`` hold one pre-transformed ``(1, C, H, W)`` tensor per
    object, in the same order.

    Returns:
        ``(reconstructed_signal, refine_signal)`` as returned by
        ``get_separated_audio``; index ``k`` belongs to object ``k``.
    """
    audio_mix_mags, audio_mix_phases = generate_spectrogram_magphase(
        audio_segment, opt.stft_frame, opt.stft_hop)
    num_objects = len(visuals)
    mix_mag = torch.FloatTensor(audio_mix_mags).unsqueeze(0)
    mix_phase = torch.FloatTensor(audio_mix_phases).unsqueeze(0)

    data = {
        'audio_mix_mags': torch.cat([mix_mag] * num_objects).cuda(),
        'audio_mix_phases': torch.cat([mix_phase] * num_objects).cuda(),
        'visuals': torch.cat(visuals).cuda(),
        'visuals_256': torch.cat(visuals_256).cuda(),
        'labels': torch.ones(num_objects, 1).cuda(),
        'vids': torch.ones(num_objects, 1).cuda(),
    }
    # The mixture is also passed as 'audio_mags' (presumably because no
    # per-source spectrograms exist at test time).
    data['audio_mags'] = data['audio_mix_mags']

    with torch.no_grad():  # inference only: do not build autograd graphs
        outputs = model.forward(data)
    return get_separated_audio(outputs, data, opt)


def test_sepration(opt, nets, output_dir, save_files=False):
    """Separate a synthetic three-source mixture and score the result.

    The solo recordings ``opt.video{1,2,3}_name`` (under
    ``<data_path>/solo_audio_resample/<ins>/``) are trimmed to a common
    length and averaged into a mixture. For each of
    ``opt.num_of_object_detections_to_use`` rounds, one detection crop per
    video is chosen, the mixture is separated with a sliding window, and the
    per-round estimates are averaged.

    Args:
        opt: test options. ``opt.visualize_spectrogram`` is forced to True.
        nets: networks passed to ``AudioVisualModel``.
        output_dir: parent directory; results are written to
            ``<output_dir>/<name1>$_VS_$<name2>$_VS_$<name3>`` (created if
            missing). Spectrogram PNGs are always written there.
        save_files: also write the source, mixture and separated wavs and
            the detection crops of the last round.

    Returns:
        ``(c_sdr, c_sir, c_sar, r_sdr, r_sir, r_sar)`` from
        ``getSeparationMetrics`` for the direct separation outputs (``c_``)
        and the refined outputs (``r_``).
    """
    opt.visualize_spectrogram = True
    video_ins = (opt.video1_ins, opt.video2_ins, opt.video3_ins)
    video_names = (opt.video1_name, opt.video2_name, opt.video3_name)
    num_sources = len(video_names)

    model = AudioVisualModel(nets, opt)
    model.to('cuda')
    model.eval()

    # load the solo audios, trim them to a common length and mix them
    audios = []
    for ins, name in zip(video_ins, video_names):
        audio_path = os.path.join(opt.data_path, 'solo_audio_resample', ins,
                                  name + '.wav')
        audio, _ = librosa.load(audio_path, sr=opt.audio_sampling_rate)
        audios.append(audio)
    audio_length = min(len(audio) for audio in audios)
    audios = [clip_audio(audio[:audio_length]) for audio in audios]
    audio_mix = sum(audios) / float(num_sources)

    # transformations for the visual net (224x224) and the UNet (256x256)
    vision_transform_list = [
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ]
    vision_transform_list_for_unet = [
        transforms.Resize((256, 256)),
        transforms.ToTensor()
    ]
    if opt.subtract_mean:
        vision_transform_list.append(
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]))
        vision_transform_list_for_unet.append(
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]))
    vision_transform = transforms.Compose(vision_transform_list)
    vision_transform_for_unet = transforms.Compose(
        vision_transform_list_for_unet)

    # candidate object boxes for each video
    clip_det_bbs = [
        sample_object_detections(
            np.load(
                os.path.join(opt.data_path, 'solo_detect', ins,
                             name + '.npy')))
        for ins, name in zip(video_ins, video_names)
    ]

    # Sliding windows: regular hops while a full window fits strictly inside
    # the audio, then one last window aligned to the end of the audio.
    samples_per_window = opt.audio_window
    hop = int(opt.hop_size * opt.audio_sampling_rate)
    windows = [
        slice(start, start + samples_per_window)
        for start in range(0, audio_length - samples_per_window, hop)
    ]
    windows.append(slice(-samples_per_window, None))

    num_rounds = opt.num_of_object_detections_to_use
    avged_sep = np.zeros((num_sources, audio_length))
    avged_refine = np.zeros((num_sources, audio_length))
    detections = []
    for _ in range(num_rounds):
        detections = [
            _select_detection_crop(opt, ins, name, bbs)
            for ins, name, bbs in zip(video_ins, video_names, clip_det_bbs)
        ]
        # the visual inputs are identical for every window: transform once
        visuals = [vision_transform(d).unsqueeze(0) for d in detections]
        visuals_256 = [
            vision_transform_for_unet(d).unsqueeze(0) for d in detections
        ]

        overlap_count = np.zeros(audio_length)
        sep_audio = np.zeros((num_sources, audio_length))
        refine_sep = np.zeros((num_sources, audio_length))
        for window in windows:
            reconstructed_signal, refine_signal = _separate_window(
                model, audio_mix[window], visuals, visuals_256, opt)
            for k in range(num_sources):
                sep_audio[k, window] += reconstructed_signal[k]
                refine_sep[k, window] += refine_signal[k]
            overlap_count[window] += 1

        # Divide the aggregated predictions by the overlap count.
        # NOTE(review): estimates are scaled by 2 while the mixture averages
        # three sources -- confirm the intended gain.
        for k in range(num_sources):
            avged_sep[k] += clip_audio(
                np.divide(sep_audio[k], overlap_count) * 2)
            avged_refine[k] += clip_audio(
                np.divide(refine_sep[k], overlap_count) * 2)

    separations = avged_sep / num_rounds
    refine_separations = avged_refine / num_rounds

    # output original and separated audios
    output_dir = os.path.join(output_dir, '$_VS_$'.join(video_names))
    os.makedirs(output_dir, exist_ok=True)

    if save_files:
        librosa.output.write_wav(os.path.join(output_dir, 'audio_mixed.wav'),
                                 audio_mix, opt.audio_sampling_rate)
        for k in range(num_sources):
            idx = k + 1
            librosa.output.write_wav(
                os.path.join(output_dir, 'audio%d.wav' % idx), audios[k],
                opt.audio_sampling_rate)
            librosa.output.write_wav(
                os.path.join(output_dir, 'audio%d_separated.wav' % idx),
                separations[k], opt.audio_sampling_rate)
            librosa.output.write_wav(
                os.path.join(output_dir, 'audio%d_refine_separated.wav' % idx),
                refine_separations[k], opt.audio_sampling_rate)

    reference_sources = np.stack(audios)
    c_sdr, c_sir, c_sar = getSeparationMetrics(reference_sources, separations)
    r_sdr, r_sir, r_sar = getSeparationMetrics(reference_sources,
                                               refine_separations)

    # save the detections used in the last round
    if save_files:
        for k, detection in enumerate(detections):
            detection.save(os.path.join(output_dir, 'audio%d.png' % (k + 1)))

    # save the spectrograms
    if opt.visualize_spectrogram:
        import matplotlib.pyplot as plt
        plt.switch_backend('agg')
        plt.ioff()

        def save_spectrogram(signal, filename):
            mag = generate_spectrogram_magphase(signal,
                                                opt.stft_frame,
                                                opt.stft_hop,
                                                with_phase=False)
            utils.visualizeSpectrogram(mag[0, :, :],
                                       os.path.join(output_dir, filename))

        save_spectrogram(audio_mix, 'audio_mixed_spec.png')
        for k in range(num_sources):
            idx = k + 1
            save_spectrogram(audios[k], 'audio%d_spec.png' % idx)
            save_spectrogram(separations[k], 'separation%d_spec.png' % idx)
            save_spectrogram(refine_separations[k], 'refine%d_spec.png' % idx)
    return c_sdr, c_sir, c_sar, r_sdr, r_sir, r_sar
Ejemplo n.º 5
0
def _load_identity_frames(opt, facetrack_path, mtcnn, vision_transform):
    """Sample face frame(s) from a face track and return them as a CUDA tensor.

    With ``opt.reliable_face`` set, 10 frames are sampled via ``load_frame``.
    The one whose first face-detection score from ``mtcnn.detect`` is
    highest is kept. If no sampled frame has a detected face, the last
    sampled frame is used and a message is printed. Otherwise,
    ``opt.number_of_identity_frames`` frames are sampled and stacked.

    Returns:
        The transformed frame(s), squeezed and given a leading batch
        dimension, on the GPU.
    """
    if opt.reliable_face:
        best_score = 0
        best_frame = None
        frame = None
        for _ in range(10):
            frame = load_frame(facetrack_path)
            _, scores = mtcnn.detect(frame)
            # the detector may report no face (None scores / None entry)
            if scores is None or len(scores) == 0 or scores[0] is None:
                continue
            if scores[0] > best_score:
                best_score = scores[0]
                best_frame = frame
        if best_frame is None:
            print("no face detected in {}; using the last sampled frame".format(
                facetrack_path))
            best_frame = frame
        return vision_transform(best_frame).squeeze().unsqueeze(0).cuda()

    frame_list = [
        vision_transform(load_frame(facetrack_path))
        for _ in range(opt.number_of_identity_frames)
    ]
    return torch.stack(frame_list).squeeze().unsqueeze(0).cuda()


def _separate_real_segment(opt, net_lipreading, net_facial_attributes,
                           net_unet, segment_audio, segment_mouthroi, frames,
                           lipreading_preprocessing_func):
    """Separate one speaker from one audio window.

    The window is optionally RMS-normalized (``opt.audio_normalization``,
    target rms 0.07), turned into a complex spectrogram, and passed with
    the preprocessed mouth ROIs and identity frames to
    ``get_separated_audio``. The result is rescaled by the normalizer
    returned from ``audio_normalize`` (1 when normalization is off).
    """
    if opt.audio_normalization:
        normalizer, segment_audio = audio_normalize(segment_audio,
                                                    desired_rms=0.07)
    else:
        normalizer = 1

    audio_spec = generate_spectrogram_complex(segment_audio, opt.window_size,
                                              opt.hop_size, opt.n_fft)
    audio_spec = torch.FloatTensor(audio_spec).unsqueeze(0).cuda()

    segment_mouthroi = lipreading_preprocessing_func(segment_mouthroi)
    segment_mouthroi = torch.FloatTensor(segment_mouthroi).unsqueeze(
        0).unsqueeze(0).cuda()

    with torch.no_grad():  # inference only: do not build autograd graphs
        reconstructed_signal = get_separated_audio(net_lipreading,
                                                   net_facial_attributes,
                                                   net_unet, audio_spec,
                                                   segment_mouthroi, frames,
                                                   opt)
    return reconstructed_signal * normalizer


def main():
    """Separate each on-screen speaker's voice from a real recording.

    For speaker ``k`` (1-based), reads ``speaker<k>.npz`` (mouth ROIs) from
    ``opt.mouthroi_root`` and ``speaker<k>.mp4`` (face track) from
    ``opt.facetrack_root``. It separates that speaker from the recording at
    ``opt.audio_path`` with a sliding window and writes ``speaker<k>.wav``
    to ``opt.output_dir_root`` (created if missing).
    """
    # load test arguments
    opt = TestRealOptions().parse()
    opt.device = torch.device("cuda")

    # Network Builders
    builder = ModelBuilder()
    net_lipreading = builder.build_lipreadingnet(
        config_path=opt.lipreading_config_path,
        weights=opt.weights_lipreadingnet,
        extract_feats=opt.lipreading_extract_feature)
    # if identity feature dim is not 512, for resnet reduce dimension to this
    # feature dim
    opt.with_fc = opt.identity_feature_dim != 512
    net_facial_attributes = builder.build_facial(
        pool_type=opt.visual_pool,
        fc_out=opt.identity_feature_dim,
        with_fc=opt.with_fc,
        weights=opt.weights_facial)
    net_unet = builder.build_unet(
        ngf=opt.unet_ngf,
        input_nc=opt.unet_input_nc,
        output_nc=opt.unet_output_nc,
        audioVisual_feature_dim=opt.audioVisual_feature_dim,
        identity_feature_dim=opt.identity_feature_dim,
        weights=opt.weights_unet)
    net_vocal_attributes = builder.build_vocal(
        pool_type=opt.audio_pool,
        input_channel=2,
        with_fc=opt.with_fc,
        fc_out=opt.identity_feature_dim,
        weights=opt.weights_vocal)

    nets = (net_lipreading, net_facial_attributes, net_unet,
            net_vocal_attributes)
    print(nets)

    # construct our audio-visual model
    model = AudioVisualModel(nets, opt)
    model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
    model.to(opt.device)
    model.eval()
    mtcnn = MTCNN(keep_all=True, device=opt.device)

    lipreading_preprocessing_func = get_preprocessing_pipelines()['test']
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    vision_transform_list = [transforms.ToTensor()]
    if opt.normalization:
        vision_transform_list.append(normalize)
    vision_transform = transforms.Compose(vision_transform_list)

    # The mixture is the same for every speaker: read and normalize it once.
    sr, audio = wavfile.read(opt.audio_path)
    print("sampling rate of audio: ", sr)
    if len(audio.shape) == 2:
        audio = np.mean(audio, axis=1)  # convert to mono if stereo
    audio = audio / 32768  # NOTE(review): assumes 16-bit PCM samples
    audio = audio / 2.0
    audio = clip_audio(audio)
    audio_length = len(audio)

    samples_per_window = int(opt.audio_length * opt.audio_sampling_rate)
    hop = int(opt.hop_length * opt.audio_sampling_rate)
    os.makedirs(opt.output_dir_root, exist_ok=True)

    for speaker_index in range(opt.number_of_speakers):
        speaker_id = 'speaker' + str(speaker_index + 1)
        mouthroi = load_mouthroi(
            os.path.join(opt.mouthroi_root, speaker_id + '.npz'))
        facetrack_path = os.path.join(opt.facetrack_root, speaker_id + '.mp4')
        frames = _load_identity_frames(opt, facetrack_path, mtcnn,
                                       vision_transform)

        # sliding-window separation over the whole recording
        overlap_count = np.zeros(audio_length)
        sep_audio = np.zeros(audio_length)
        for start in range(0, audio_length - samples_per_window, hop):
            end = start + samples_per_window
            # mouth-ROI index of this window; the factor 25 converts seconds
            # to ROI frames (presumably 25 fps video -- confirm)
            frame_index_start = int(
                round(start / opt.audio_sampling_rate * 25))
            segment_mouthroi = mouthroi[frame_index_start:(
                frame_index_start + opt.num_frames), :, :]
            sep_audio[start:end] += _separate_real_segment(
                opt, net_lipreading, net_facial_attributes, net_unet,
                audio[start:end], segment_mouthroi, frames,
                lipreading_preprocessing_func)
            overlap_count[start:end] += 1

        # last window, aligned to the end of the recording
        sep_audio[-samples_per_window:] += _separate_real_segment(
            opt, net_lipreading, net_facial_attributes, net_unet,
            audio[-samples_per_window:], mouthroi[-opt.num_frames:, :, :],
            frames, lipreading_preprocessing_func)
        overlap_count[-samples_per_window:] += 1

        # divide the aggregated predicted audio by the overlap count
        avged_sep_audio = clip_audio(np.divide(sep_audio, overlap_count))

        # output separated audio
        librosa.output.write_wav(
            os.path.join(opt.output_dir_root, speaker_id + '.wav'),
            avged_sep_audio, opt.audio_sampling_rate)
Ejemplo n.º 6
0
def main():
    """Predict a binaural version of a two-channel recording and write it to disk.

    The audio at ``opt.input_audio_path`` is processed in windows of
    ``opt.audio_length`` seconds advanced by ``opt.hop_size`` seconds, plus one
    final window aligned to the end of the clip. For each window the model
    receives the spectrograms of (ch1 - ch2) and (ch1 + ch2) of the normalized
    window and one video frame. It predicts a two-channel (real, imaginary)
    spectrogram, which is inverted to the left/right difference signal.
    Overlapping window predictions are averaged.

    Writes ``predicted_binaural.wav``, ``mixed_mono.wav`` and
    ``input_binaural.wav`` into ``opt.output_dir_root`` (created if missing).

    Raises:
        ValueError: if the hop is shorter than one sample while more than one
            window is needed (the sliding window could never advance).
    """
    #load test arguments
    opt = TestOptions().parse()
    opt.device = torch.device("cuda")

    # network builders
    builder = ModelBuilder()
    net_visual = builder.build_visual(weights=opt.weights_visual)
    net_audio = builder.build_audio(ngf=opt.unet_ngf,
                                    input_nc=opt.unet_input_nc,
                                    output_nc=opt.unet_output_nc,
                                    weights=opt.weights_audio)
    nets = (net_visual, net_audio)

    # construct our audio-visual model
    model = AudioVisualModel(nets, opt)
    model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
    model.to(opt.device)
    model.eval()

    #load the audio to spatialize; rows 0 and 1 are used as the two channels
    # (assumes a stereo file - a mono file would not have row 1)
    audio, _ = librosa.load(opt.input_audio_path,
                            sr=opt.audio_sampling_rate,
                            mono=False)
    audio_channel1 = audio[0, :]
    audio_channel2 = audio[1, :]

    #define the transformation to perform on visual frames
    vision_transform = transforms.Compose([
        transforms.Resize((224, 448)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    total_samples = audio.shape[-1]
    samples_per_window = int(opt.audio_length * opt.audio_sampling_rate)
    hop_samples = int(opt.hop_size * opt.audio_sampling_rate)
    if hop_samples <= 0 and samples_per_window < total_samples:
        # the sliding-window loop below would never advance
        raise ValueError(
            'opt.hop_size * opt.audio_sampling_rate must be at least one sample')

    def _predict_window(window_audio, frame_index):
        """Return the de-normalized (2, L) binaural prediction for one window.

        ``window_audio`` is the (2, L) slice of the input; ``frame_index``
        selects the conditioning image ``<frame_index:06d>.jpg``.
        """
        normalizer, audio_segment = audio_normalize(window_audio)
        channel1 = audio_segment[0, :]
        channel2 = audio_segment[1, :]
        # The mix must come from THIS window's normalized channels: it is both
        # a model input and the base of the left/right reconstruction below.
        segment_mix = channel1 + channel2

        data = {
            #unsqueeze to add a batch dimension
            'audio_diff_spec': torch.FloatTensor(
                generate_spectrogram(channel1 - channel2)).unsqueeze(0),
            'audio_mix_spec': torch.FloatTensor(
                generate_spectrogram(segment_mix)).unsqueeze(0),
        }
        image = Image.open(
            os.path.join(opt.video_frame_path,
                         str(frame_index).zfill(6) + '.jpg')).convert('RGB')
        data['frame'] = vision_transform(image).unsqueeze(0)

        # inference only: no autograd graph is needed
        with torch.no_grad():
            output = model.forward(data)
        predicted_spectrogram = output['binaural_spectrogram'][
            0, :, :, :].data[:].cpu().numpy()

        #ISTFT to convert back to audio; channel 0/1 are the real/imag parts
        reconstructed_stft_diff = predicted_spectrogram[0, :, :] + (
            1j * predicted_spectrogram[1, :, :])
        reconstructed_signal_diff = librosa.istft(
            reconstructed_stft_diff,
            hop_length=160,
            win_length=400,
            center=True,
            length=window_audio.shape[-1])
        # mix = L + R and diff = L - R  =>  L = (mix + diff) / 2, R = (mix - diff) / 2
        left = (segment_mix + reconstructed_signal_diff) / 2
        right = (segment_mix - reconstructed_signal_diff) / 2
        return np.stack((left, right), axis=0) * normalizer

    #count the number of times each sample is predicted
    overlap_count = np.zeros(audio.shape)
    binaural_audio = np.zeros(audio.shape)

    #perform spatialization over the whole audio in a sliding-window fashion
    sliding_window_start = 0
    while sliding_window_start + samples_per_window < total_samples:
        sliding_window_end = sliding_window_start + samples_per_window
        # frame index derived from the window-centre time; the "* 10"
        # presumably means frames were extracted at 10 fps - confirm
        frame_index = int(
            round((((sliding_window_start + samples_per_window / 2.0) /
                    total_samples) * opt.input_audio_length + 0.05) * 10))
        window = slice(sliding_window_start, sliding_window_end)
        binaural_audio[:, window] = binaural_audio[:, window] + \
            _predict_window(audio[:, window], frame_index)
        overlap_count[:, window] = overlap_count[:, window] + 1
        sliding_window_start = sliding_window_start + hop_samples

    #deal with the last segment, aligned to the end of the clip
    last_frame_index = int(
        round(((opt.input_audio_length - opt.audio_length / 2.0) + 0.05) * 10))
    binaural_audio[:, -samples_per_window:] = \
        binaural_audio[:, -samples_per_window:] + \
        _predict_window(audio[:, -samples_per_window:], last_frame_index)
    overlap_count[:, -samples_per_window:] = \
        overlap_count[:, -samples_per_window:] + 1

    #divide aggregated predicted audio by their corresponding counts
    predicted_binaural_audio = np.divide(binaural_audio, overlap_count)

    #check output directory
    if not os.path.isdir(opt.output_dir_root):
        os.mkdir(opt.output_dir_root)

    mixed_mono = (audio_channel1 + audio_channel2) / 2
    librosa.output.write_wav(
        os.path.join(opt.output_dir_root, 'predicted_binaural.wav'),
        predicted_binaural_audio, opt.audio_sampling_rate)
    librosa.output.write_wav(
        os.path.join(opt.output_dir_root, 'mixed_mono.wav'), mixed_mono,
        opt.audio_sampling_rate)
    librosa.output.write_wav(
        os.path.join(opt.output_dir_root, 'input_binaural.wav'), audio,
        opt.audio_sampling_rate)
Ejemplo n.º 7
0
def test_sepration(opt, nets, output_dir, save_files=False):
    """Separate a duet clip into its two detected sources and score the result.

    The duet audio named by ``opt.video2_ins`` / ``opt.video2_name`` is its own
    mixture. For every sliding window it is fed to the model as a batch of two,
    one row conditioned on each of the first two object detections of a
    reference frame. Window outputs (coarse and refined) are overlap-averaged
    into full-length tracks.

    Args:
        opt: options namespace (data paths, STFT/window settings,
            ``num_of_object_detections_to_use``, ``subtract_mean`` ...).
            ``opt.visualize_spectrogram`` is forced to True.
        nets: networks passed straight to ``AudioVisualModel``.
        output_dir: root directory; results go into
            ``<output_dir>/<video2_name>**``.
        save_files: when True, also write the wav files and detection crops.

    Returns:
        ``(c_sdr, c_sir, c_sar, r_sdr, r_sir, r_sar)`` from
        ``getSeparationMetrics``, comparing the duet audio against the sum of the
        two coarse (``c_``) and the two refined (``r_``) separations.

    Raises:
        ValueError: if the reference frame has fewer than two detections, or
            the hop is shorter than one sample while more than one window is
            needed.
    """
    opt.visualize_spectrogram = True

    model = AudioVisualModel(nets, opt)
    model.to('cuda')
    model.eval()

    #load the duet audio; it is used directly as the mixture to separate
    audio2_path = os.path.join(opt.data_path_duet, 'duet_audio_resample',
                               opt.video2_ins, opt.video2_name + '.wav')
    audio2, _ = librosa.load(audio2_path, sr=opt.audio_sampling_rate)
    audio_length = len(audio2)
    audio2 = clip_audio(audio2)
    audio_mix = audio2

    #define the transformation to perform on the detection crops
    vision_transform_list = [
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ]
    if opt.subtract_mean:
        vision_transform_list.append(
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]))
    vision_transform = transforms.Compose(vision_transform_list)

    #select the detections that belong to one reference frame of the clip
    detectionR_npy = os.path.join(opt.data_path_duet, 'duet_detect',
                                  opt.video2_ins, opt.video2_name + '.npy')
    detectionResult2 = np.load(detectionR_npy)
    clip_2_path = os.path.join(
        '/data/mashuo/work/study/refine-separation/dataset/music_test',
        opt.video2_ins, opt.video2_name)
    # os.listdir order is arbitrary; sort so every run picks the same frame
    frame_name = float(sorted(os.listdir(clip_2_path))[0][:6])
    # column 0 of a detection row is compared against the frame number
    frame_ids = np.array([float(row[0]) for row in detectionResult2])
    clip_det_bbs2 = detectionResult2[frame_ids == frame_name]
    if clip_det_bbs2.shape[0] < 2:
        raise ValueError(
            'need at least two detections for frame {:06d} in {}, found {}'.
            format(int(frame_name), detectionR_npy, clip_det_bbs2.shape[0]))

    # The first two detections of the frame are the two sources; the last four
    # entries of a row are used as the PIL crop box.
    det_box2 = clip_det_bbs2[0, :]
    det_box3 = clip_det_bbs2[1, :]
    frame_path2 = os.path.join(opt.data_path_duet, 'duet_extract',
                               opt.video2_ins, opt.video2_name,
                               "%06d.png" % det_box2[0])
    frame_2 = Image.open(frame_path2).convert('RGB')
    detection2 = frame_2.crop(tuple(det_box2[-4:]))
    detection3 = frame_2.crop(tuple(det_box3[-4:]))

    # conditioning inputs shared by every window: one batch row per detection
    visuals = torch.FloatTensor(
        np.vstack([vision_transform(detection2).unsqueeze(0),
                   vision_transform(detection3).unsqueeze(0)]))
    ones = np.ones((2, 1))

    def _separate_window(audio_segment):
        """Separate one mixture window for both detections.

        Returns get_separated_audio's (coarse, refined) pair; index 0 of each
        belongs to detection 2 and index 1 to detection 3.
        """
        mags, phases = generate_spectrogram_magphase(audio_segment,
                                                     opt.stft_frame,
                                                     opt.stft_hop)
        mag = torch.FloatTensor(mags).unsqueeze(0)
        phase = torch.FloatTensor(phases).unsqueeze(0)
        data = {}
        # the same mixture spectrogram is paired with each detection
        data['audio_mix_mags'] = torch.FloatTensor(np.vstack([mag, mag])).cuda()
        data['audio_mags'] = data['audio_mix_mags']
        data['audio_mix_phases'] = torch.FloatTensor(
            np.vstack([phase, phase])).cuda()
        data['visuals'] = visuals.cuda()
        data['labels'] = torch.FloatTensor(ones).cuda()
        data['vids'] = torch.FloatTensor(ones).cuda()
        # inference only: no autograd graph is needed
        with torch.no_grad():
            outputs = model.forward(data)
            return get_separated_audio(outputs, data, opt)

    samples_per_window = opt.audio_window
    hop_samples = int(opt.hop_size * opt.audio_sampling_rate)
    if hop_samples <= 0 and samples_per_window < audio_length:
        # the sliding-window loop below would never advance
        raise ValueError(
            'opt.hop_size * opt.audio_sampling_rate must be at least one sample')

    # window positions: regular hops, then one window aligned to the end
    windows = []
    sliding_window_start = 0
    while sliding_window_start + samples_per_window < audio_length:
        windows.append(
            slice(sliding_window_start,
                  sliding_window_start + samples_per_window))
        sliding_window_start += hop_samples
    windows.append(slice(-samples_per_window, None))

    avged_sep_audio2 = np.zeros(audio_length)
    avged_refine_audio2 = np.zeros(audio_length)
    avged_sep_audio3 = np.zeros(audio_length)
    avged_refine_audio3 = np.zeros(audio_length)

    for _ in range(opt.num_of_object_detections_to_use):
        # NOTE(review): every pass uses the same two detections (rows 0 and 1),
        # so passes repeat the same computation - confirm whether different
        # boxes per pass were intended.
        overlap_count = np.zeros(audio_length)
        sep_audio2 = np.zeros(audio_length)
        sep_audio3 = np.zeros(audio_length)
        refine_sep2 = np.zeros(audio_length)
        refine_sep3 = np.zeros(audio_length)

        for window in windows:
            reconstructed_signal, refine_signal = _separate_window(
                audio_mix[window])
            sep_audio2[window] = sep_audio2[window] + reconstructed_signal[0]
            refine_sep2[window] = refine_sep2[window] + refine_signal[0]
            sep_audio3[window] = sep_audio3[window] + reconstructed_signal[1]
            refine_sep3[window] = refine_sep3[window] + refine_signal[1]
            #update overlap count
            overlap_count[window] = overlap_count[window] + 1

        #divide the aggregated predicted audio by the overlap count
        # NOTE(review): the reason for the "* 2" gain is not visible here - confirm
        avged_sep_audio2 = avged_sep_audio2 + clip_audio(
            np.divide(sep_audio2, overlap_count) * 2)
        avged_refine_audio2 = avged_refine_audio2 + clip_audio(
            np.divide(refine_sep2, overlap_count) * 2)
        avged_sep_audio3 = avged_sep_audio3 + clip_audio(
            np.divide(sep_audio3, overlap_count) * 2)
        avged_refine_audio3 = avged_refine_audio3 + clip_audio(
            np.divide(refine_sep3, overlap_count) * 2)

    separation2 = avged_sep_audio2 / opt.num_of_object_detections_to_use
    separation3 = avged_sep_audio3 / opt.num_of_object_detections_to_use
    refine_spearation2 = avged_refine_audio2 / opt.num_of_object_detections_to_use
    refine_spearation3 = avged_refine_audio3 / opt.num_of_object_detections_to_use

    #output original and separated audios
    output_dir = os.path.join(output_dir, opt.video2_name + '**')
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    if save_files:
        for signal, file_name in (
            (audio2, 'audio_duet.wav'),
            (audio_mix, 'audio_mixed.wav'),
            (separation2, 'audio2_separated.wav'),
            (separation3, 'audio3_separated.wav'),
            (refine_spearation2, 'audio2_refine_separated.wav'),
            (refine_spearation3, 'audio3_refine_separated.wav'),
        ):
            librosa.output.write_wav(os.path.join(output_dir, file_name),
                                     signal, opt.audio_sampling_rate)

    # the duet itself is the reference; the two separations should sum to it
    c_reference_sources = np.expand_dims(audio2, axis=0)
    c_estimated_sources = np.expand_dims((separation2 + separation3), axis=0)
    c_sdr, c_sir, c_sar = getSeparationMetrics(c_reference_sources,
                                               c_estimated_sources)

    r_reference_sources = np.expand_dims(audio2, axis=0)
    r_estimated_sources = np.expand_dims(
        (refine_spearation2 + refine_spearation3), axis=0)
    r_sdr, r_sir, r_sar = getSeparationMetrics(r_reference_sources,
                                               r_estimated_sources)

    #save the frame and the two detections
    if save_files:
        frame_2.save(os.path.join(output_dir, 'frame_2.png'))
        detection2.save(os.path.join(output_dir, 'det2.png'))
        detection3.save(os.path.join(output_dir, 'det3.png'))

    #save the spectrograms
    if opt.visualize_spectrogram:
        import matplotlib.pyplot as plt
        plt.switch_backend('agg')
        plt.ioff()

        def _mag(signal):
            """Magnitude spectrogram of ``signal`` (first slice only)."""
            return generate_spectrogram_magphase(signal,
                                                 opt.stft_frame,
                                                 opt.stft_hop,
                                                 with_phase=False)[0, :, :]

        audio2_mag = _mag(audio2)
        audio_mix_mag = _mag(audio_mix)
        separation2_mag = _mag(separation2)
        separation3_mag = _mag(separation3)
        refine_sep2_mag = _mag(refine_spearation2)
        refine_sep3_mag = _mag(refine_spearation3)
        spectrograms = [
            (audio2_mag, 'audio2_spec.png'),
            (audio_mix_mag, 'audio_mixed_spec.png'),
            (separation2_mag, 'separation2_spec.png'),
            (separation3_mag, 'separation3_spec.png'),
            (separation2_mag + separation3_mag, 'separation2+3_spec.png'),
            (refine_sep2_mag, 'refine2_spec.png'),
            (refine_sep3_mag, 'refine3_spec.png'),
            (refine_sep2_mag + refine_sep3_mag, 'ref_2+3_spec.png'),
        ]
        for spec, file_name in spectrograms:
            utils.visualizeSpectrogram(spec,
                                       os.path.join(output_dir, file_name))
    return c_sdr, c_sir, c_sar, r_sdr, r_sir, r_sar