Example #1
def test():
    os.environ["CUDA_VISIBLE_DEVICES"] = config.device_ids
    generator = VG_net()
    if config.cuda:
        generator = generator.cuda()

    state_dict = multi2single(config.model_name, 1)
    generator.load_state_dict(state_dict)
    print('load pretrained [{}]'.format(config.model_name))
    generator.eval()
    dataset = LRWdataset(config.dataset_dir, train='test')
    data_loader = DataLoader(dataset,
                             batch_size=config.batch_size,
                             num_workers=config.num_thread,
                             shuffle=True,
                             drop_last=True)
    data_iter = iter(data_loader)
    next(data_iter)  # Python 3: use next(iterator), not iterator.next(); the fetched batch is unused

    # config.sample_dir is expected to end with '/': the output paths below are
    # built as "{}fake/...".format(config.sample_dir), not with os.path.join.
    for sub in ('fake', 'real', 'color', 'single'):
        os.makedirs(os.path.join(config.sample_dir, sub), exist_ok=True)
    for step, (example_img, example_landmark, real_im,
               landmarks) in enumerate(data_loader):
        with torch.no_grad():
            if step == 43:  # only dump the first 43 batches
                break
            if config.cuda:
                real_im = Variable(real_im.float()).cuda()
                example_img = Variable(example_img.float()).cuda()
                landmarks = Variable(landmarks.float()).cuda()
                example_landmark = Variable(example_landmark.float()).cuda()
            fake_im, atts, colors, lmark_att = generator(
                example_img, landmarks, example_landmark)

            fake_store = fake_im.data.contiguous().view(
                config.batch_size * 16, 3, 128, 128)
            real_store = real_im.data.contiguous().view(
                config.batch_size * 16, 3, 128, 128)
            atts_store = atts.data.contiguous().view(config.batch_size * 16, 1,
                                                     128, 128)
            colors_store = colors.data.contiguous().view(
                config.batch_size * 16, 3, 128, 128)
            lmark_att = bilinear_interpolate_torch_gridsample(lmark_att)
            lmark_att_store = lmark_att.data.contiguous().view(
                config.batch_size * 16, 1, 128, 128)
            torchvision.utils.save_image(atts_store,
                                         "{}color/att_{}.png".format(
                                             config.sample_dir, step),
                                         normalize=False,
                                         range=[0, 1])
            torchvision.utils.save_image(lmark_att_store,
                                         "{}color/lmark_att_{}.png".format(
                                             config.sample_dir, step),
                                         normalize=False,
                                         range=[0, 1])
            torchvision.utils.save_image(colors_store,
                                         "{}color/color_{}.png".format(
                                             config.sample_dir, step),
                                         normalize=True)
            torchvision.utils.save_image(fake_store,
                                         "{}fake/fake_{}.png".format(
                                             config.sample_dir, step),
                                         normalize=True)

            torchvision.utils.save_image(real_store,
                                         "{}real/real_{}.png".format(
                                             config.sample_dir, step),
                                         normalize=True)

            # permute() returns a non-contiguous tensor, so contiguous() is
            # required before view() (or use reshape()).
            fake_store = fake_im.data.cpu().permute(
                0, 1, 3, 4, 2).contiguous().view(config.batch_size * 16, 128, 128, 3).numpy()
            real_store = real_im.data.cpu().permute(
                0, 1, 3, 4, 2).contiguous().view(config.batch_size * 16, 128, 128, 3).numpy()

            print(step)
            print(fake_store.shape)
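            # Note: scipy.misc.imsave was removed in SciPy 1.3; with newer
            # SciPy, imageio.imwrite is a drop-in replacement here.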
            for inx in range(config.batch_size * 16):
                scipy.misc.imsave(
                    "{}single/fake_{}.png".format(
                        config.sample_dir,
                        step * config.batch_size * 16 + inx), fake_store[inx])
                scipy.misc.imsave(
                    "{}single/real_{}.png".format(
                        config.sample_dir,
                        step * config.batch_size * 16 + inx), real_store[inx])
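
All three examples load checkpoints through multi2single, which is defined elsewhere in the repository. A minimal sketch of what such a helper typically does, assuming the checkpoint was saved from an nn.DataParallel model whose keys carry a "module." prefix (the function name and signature here are illustrative only):

import torch

def multi2single_sketch(model_path, strip_module=1):
    """Hypothetical stand-in for the repo's multi2single() helper."""
    # Load the checkpoint onto the CPU regardless of where it was saved.
    state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
    if not strip_module:
        return state_dict
    # Checkpoints saved from nn.DataParallel prefix every key with "module.";
    # strip it so the weights load into a plain single-GPU module.
    return {key[len('module.'):] if key.startswith('module.') else key: value
            for key, value in state_dict.items()}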
Example #2
def test():
    os.environ["CUDA_VISIBLE_DEVICES"] = config.device_ids
    if os.path.exists('../temp'):
        shutil.rmtree('../temp')
    os.mkdir('../temp')
    os.mkdir('../temp/img')
    os.mkdir('../temp/motion')
    os.mkdir('../temp/attention')
    # Note: the PCA basis, the mean, and the MFCC tensors built below are moved
    # to the GPU unconditionally, so this function effectively requires config.cuda.
    pca = torch.FloatTensor(np.load('../basics/U_lrw1.npy')[:, :6]).cuda()
    mean = torch.FloatTensor(np.load('../basics/mean_lrw1.npy')).cuda()
    decoder = VG_net()
    encoder = AT_net()
    if config.cuda:
        encoder = encoder.cuda()
        decoder = decoder.cuda()
    state_dict2 = multi2single(config.vg_model, 1)

    # state_dict2 = torch.load(config.video_model, map_location=lambda storage, loc: storage)
    decoder.load_state_dict(state_dict2)

    state_dict = multi2single(config.at_model, 1)
    encoder.load_state_dict(state_dict)

    encoder.eval()
    decoder.eval()
    test_file = config.in_file

    example_image, example_landmark = generator_demo_example_lips(config.person)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])
    example_image = cv2.cvtColor(example_image, cv2.COLOR_BGR2RGB)
    example_image = transform(example_image)

    example_landmark = example_landmark.reshape((1, example_landmark.shape[0] * example_landmark.shape[1]))

    if config.cuda:
        example_image = Variable(example_image.view(1, 3, 128, 128)).cuda()
        example_landmark = Variable(torch.FloatTensor(example_landmark.astype(float))).cuda()
    else:
        example_image = Variable(example_image.view(1, 3, 128, 128))
        example_landmark = Variable(torch.FloatTensor(example_landmark.astype(float)))

    # Project the example landmarks into the 6-dim PCA space used by the encoder
    example_landmark = example_landmark * 5.0
    example_landmark = example_landmark - mean.expand_as(example_landmark)
    example_landmark = torch.mm(example_landmark, pca)

    # Load the speech, pad 0.12 s of silence at both ends, and extract MFCCs
    speech, sr = librosa.load(test_file, sr=16000)
    speech = np.insert(speech, 0, np.zeros(1920))
    speech = np.append(speech, np.zeros(1920))
    mfcc = python_speech_features.mfcc(speech, 16000, winstep=0.01)

    sound, _ = librosa.load(test_file, sr=44100)

    print('=======================================')
    print('Start to generate images')
    t = time.time()
    ind = 3
    with torch.no_grad():
        fake_lmark = []
        input_mfcc = []
        # Build one 28-frame MFCC window (~280 ms of audio) per output video frame
        while ind <= int(mfcc.shape[0] / 4) - 4:
            t_mfcc = mfcc[(ind - 3) * 4:(ind + 4) * 4, 1:]
            t_mfcc = torch.FloatTensor(t_mfcc).cuda()
            input_mfcc.append(t_mfcc)
            ind += 1
        input_mfcc = torch.stack(input_mfcc, dim=0)
        input_mfcc = input_mfcc.unsqueeze(0)

        # Audio-to-landmark: predict PCA coefficients, then map back to 68x2 landmarks
        fake_lmark = encoder(example_landmark, input_mfcc)
        fake_lmark = fake_lmark.view(fake_lmark.size(0) * fake_lmark.size(1), 6)
        example_landmark = torch.mm(example_landmark, pca.t())
        example_landmark = example_landmark + mean.expand_as(example_landmark)
        # Amplify PCA components 1-5 of the predicted landmarks
        fake_lmark[:, 1:6] *= 2 * torch.FloatTensor(np.array([1.1, 1.2, 1.3, 1.4, 1.5])).cuda()
        fake_lmark = torch.mm(fake_lmark, pca.t())
        fake_lmark = fake_lmark + mean.expand_as(fake_lmark)

        fake_lmark = fake_lmark.unsqueeze(0)

        # Landmark-to-video: generate the frames, attention maps and motion maps
        fake_ims, atts, ms, _ = decoder(example_image, fake_lmark, example_landmark)

        for indx in range(fake_ims.size(1)):
            fake_im = fake_ims[:, indx]
            fake_store = fake_im.permute(0, 2, 3, 1).data.cpu().numpy()[0]
            scipy.misc.imsave("{}/{:05d}.png".format(os.path.join('../', 'temp', 'img'), indx), fake_store)
            m = ms[:, indx]
            att = atts[:, indx]
            m = m.permute(0, 2, 3, 1).data.cpu().numpy()[0]
            att = att.data.cpu().numpy()[0, 0]

            scipy.misc.imsave("{}/{:05d}.png".format(os.path.join('../', 'temp', 'motion'), indx), m)
            scipy.misc.imsave("{}/{:05d}.png".format(os.path.join('../', 'temp', 'attention'), indx), att)

        print('In total, generated {:d} images in {:.3f} seconds'.format(fake_ims.size(1), time.time() - t))

        fake_lmark = fake_lmark.data.cpu().numpy()
        np.save(os.path.join(config.sample_dir, 'obama_fake.npy'), fake_lmark)
        fake_lmark = np.reshape(fake_lmark, (fake_lmark.shape[1], 68, 2))
        utils.write_video_wpts_wsound(fake_lmark, sound, 44100, config.sample_dir, 'fake', [-1.0, 1.0], [-1.0, 1.0])
        video_name = os.path.join(config.sample_dir, 'results.mp4')
        utils.image_to_video(os.path.join('../', 'temp', 'img'), video_name)
        utils.add_audio(video_name, config.in_file)
        print('The generated video is: {}'.format(os.path.join(config.sample_dir, 'results.mov')))
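
utils.image_to_video and utils.add_audio are repository helpers not shown here. A minimal sketch of the ffmpeg calls they presumably wrap (the 25 fps frame rate and the .mov output name are assumptions, based on the frame numbering above and on the final print statement):

import subprocess

def image_to_video_sketch(image_dir, video_name, fps=25):
    # Encode the zero-padded PNG sequence (00000.png, 00001.png, ...) into an mp4.
    subprocess.check_call([
        'ffmpeg', '-y', '-framerate', str(fps),
        '-i', '{}/%05d.png'.format(image_dir),
        '-c:v', 'libx264', '-pix_fmt', 'yuv420p', video_name,
    ])

def add_audio_sketch(video_name, audio_file):
    # Mux the driving audio onto the silent video; the repo's helper appears to
    # write a .mov next to the input, matching the print statement above.
    subprocess.check_call([
        'ffmpeg', '-y', '-i', video_name, '-i', audio_file,
        '-c:v', 'copy', '-c:a', 'aac', '-shortest',
        video_name.replace('.mp4', '.mov'),
    ])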
Example #3
def test():
    data_root = '/home/cxu-serve/p1/common/voxceleb2/unzip/test_video/'
    #data_root = '/home/cxu-serve/p1/common/lrs3/lrs3_v0.4/'
    audios = []
    videos = []
    start_ids = []
    end_ids = []
    with open(
            '/home/cxu-serve/p1/common/degree/degree_store/vox/new_extra_data.csv',
            'r') as csvfile:
        reader = csv.reader(csvfile)
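        # each CSV row: [video path, audio path, start frame, end frame]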
        for row in reader:
            #print (row)
            audios.append(row[1])
            videos.append(row[0])
            start_ids.append(int(row[2]))
            end_ids.append(int(row[3]))
            #audios.append(os.path.join(data_root, 'test',tmp[1], tmp[2] + '.wav'))
            #videos.append(os.path.join(data_root, 'test', tmp[1], tmp[2] + '_crop.mp4'))
    #print (gg)
    os.environ["CUDA_VISIBLE_DEVICES"] = config.device_ids
    if os.path.exists('../temp'):
        shutil.rmtree('../temp')
    os.mkdir('../temp')
    pca = torch.FloatTensor(np.load('../basics/U_lrw1.npy')[:, :6]).cuda()
    mean = torch.FloatTensor(np.load('../basics/mean_lrw1.npy')).cuda()
    decoder = VG_net()
    encoder = AT_net()
    if config.cuda:
        encoder = encoder.cuda()
        decoder = decoder.cuda()
    state_dict2 = multi2single(config.vg_model, 1)

    # state_dict2 = torch.load(config.video_model, map_location=lambda storage, loc: storage)
    decoder.load_state_dict(state_dict2)

    state_dict = multi2single(config.at_model, 1)
    encoder.load_state_dict(state_dict)

    encoder.eval()
    decoder.eval()

    # vox
    # # get file paths
    # path = '/home/cxu-serve/p1/common/experiment/vox_good'
    # files = os.listdir(path)
    # data_root = '/home/cxu-serve/p1/common/voxceleb2/unzip/'
    # audios = []
    # videos = []
    # for f in files:
    #     if f[:7] == 'id00817' :
    #         audios.append(os.path.join(data_root, 'test_audio', f[:7], f[8:-14], '{}.wav'.format(f.split('_')[-2])))
    #         videos.append(os.path.join(data_root, 'test_video', f[:7], f[8:-14], '{}_aligned.mp4'.format(f.split('_')[-2])))

    # for i in range(len(audios)):
    #     audio_file = audios[i]
    #     video_file = videos[i]

    #     test_file = audio_file
    #     image_path = video_file
    #     video_name = image_path.split('/')[-3] + '__'  + image_path.split('/')[-2] +'__' + image_path.split('/')[-1][:-4]

    # get file paths
    #path = '/home/cxu-serve/p1/common/other/lrs_good2'
    #files = os.listdir(path)

    #for f in files:
    #    print (f)
    #    if f[:4] !='test':
    #        continue
    #    tmp = f.split('_')#

    # if f[:7] == 'id00817' :

    for i in range(len(audios)):
        try:
            audio_file = audios[i]
            video_file = videos[i]

            test_file = audio_file
            image_path = video_file
            video_name = video_file.split('/')[-1][:-4] + '_' + str(
                start_ids[i])
            #image_path = os.path.join('../image', video_name + '.jpg')
            #print (video_name, image_path)
            #cap = cv2.VideoCapture(video_file)
            #imgs = []
            #count = 0
            #while(cap.isOpened()):
            #    count += 1
            #    ret, frame = cap.read()
            #    if count != 33:
            #        continue
            #   else:
            #        cv2.imwrite(image_path, frame)
            #    try:
            #        example_image, example_landmark = generator_demo_example_lips(image_path)
            #    except:
            #        continue
            #    break

            example_image, example_landmark = generator_demo_example_lips(
                image_path)

            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
            ])
            example_image = cv2.cvtColor(example_image, cv2.COLOR_BGR2RGB)
            example_image = transform(example_image)

            example_landmark = example_landmark.reshape(
                (1, example_landmark.shape[0] * example_landmark.shape[1]))

            if config.cuda:
                example_image = Variable(example_image.view(1, 3, 128,
                                                            128)).cuda()
                example_landmark = Variable(
                    torch.FloatTensor(example_landmark.astype(float))).cuda()
            else:
                example_image = Variable(example_image.view(1, 3, 128, 128))
                example_landmark = Variable(
                    torch.FloatTensor(example_landmark.astype(float)))
            # Project the example landmarks into PCA space
            example_landmark = example_landmark * 5.0
            example_landmark = example_landmark - mean.expand_as(
                example_landmark)
            example_landmark = torch.mm(example_landmark, pca)

            # Load the speech, pad it with silence, and extract MFCCs
            speech, sr = librosa.load(audio_file, sr=16000)
            speech = np.insert(speech, 0, np.zeros(1920))
            speech = np.append(speech, np.zeros(1920))
            mfcc = python_speech_features.mfcc(speech, 16000, winstep=0.01)

            # print (mfcc.shape)

            #sound, _ = librosa.load(test_file, sr=44100)

            print('=======================================')
            print('Start to generate images')
            t = time.time()
            ind = 3
            with torch.no_grad():
                fake_lmark = []
                input_mfcc = []

                while ind <= int(mfcc.shape[0] / 4) - 4:
                    t_mfcc = mfcc[(ind - 3) * 4:(ind + 4) * 4, 1:]
                    t_mfcc = torch.FloatTensor(t_mfcc).cuda()
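                    # keep only the windows that fall inside this clip's [start, end) frame range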
                    if ind >= start_ids[i] and ind < end_ids[i]:
                        input_mfcc.append(t_mfcc)
                    ind += 1

                input_mfcc = torch.stack(input_mfcc, dim=0)
                input_mfcc = input_mfcc.unsqueeze(0)
                print(input_mfcc.shape)
                fake_lmark = encoder(example_landmark, input_mfcc)
                fake_lmark = fake_lmark.view(
                    fake_lmark.size(0) * fake_lmark.size(1), 6)
                example_landmark = torch.mm(example_landmark, pca.t())
                example_landmark = example_landmark + mean.expand_as(
                    example_landmark)
                fake_lmark[:, 1:6] *= 2 * torch.FloatTensor(
                    np.array([1.1, 1.2, 1.3, 1.4, 1.5])).cuda()
                fake_lmark = torch.mm(fake_lmark, pca.t())
                fake_lmark = fake_lmark + mean.expand_as(fake_lmark)

                fake_lmark = fake_lmark.unsqueeze(0)

                fake_ims, _, _, _ = decoder(example_image, fake_lmark,
                                            example_landmark)
                os.system('rm ../temp/*')
                for indx in range(fake_ims.size(1)):
                    fake_im = fake_ims[:, indx]
                    fake_store = fake_im.permute(0, 2, 3,
                                                 1).data.cpu().numpy()[0]
                    scipy.misc.imsave(
                        "{}/{:05d}.png".format(os.path.join('../', 'temp'),
                                               indx), fake_store)
                print(time.time() - t)
                fake_lmark = fake_lmark.data.cpu().numpy()
                # os.system('rm ../results/*')
                # np.save( os.path.join( config.sample_dir,  'obama_fake.npy'), fake_lmark)
                # fake_lmark = np.reshape(fake_lmark, (fake_lmark.shape[1], 68, 2))
                # utils.write_video_wpts_wsound(fake_lmark, sound, 44100, config.sample_dir, 'fake', [-1.0, 1.0], [-1.0, 1.0])
                video_name = os.path.join(config.sample_dir, video_name)
                # ffmpeg.input('../temp/*.png', pattern_type='glob', framerate=25).output(video_name).run()

                utils.image_to_video(os.path.join('../', 'temp'),
                                     video_name + '.mp4')
                utils.add_audio(video_name + '.mp4', audio_file)
                print('The generated video is: {}'.format(video_name + '.mov'))
        except Exception:
            # skip clips that fail at any stage (e.g. landmark detection)
            continue
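
In all three examples, the slice mfcc[(ind - 3) * 4:(ind + 4) * 4, 1:] pairs each 25 fps video frame with MFCC features computed at 100 frames per second (winstep=0.01), i.e. 4 MFCC frames per video frame and a 28-frame (~280 ms) context window around frame ind. A small sketch of that indexing; the helper name and defaults are illustrative only:

import numpy as np

def mfcc_window_sketch(mfcc, ind, frames_per_video_frame=4, left=3, right=4):
    # mfcc: (T, 13) array at 100 frames/s; one 25 fps video frame spans 4 MFCC frames.
    # Returns a (28, 12) window covering video frames [ind - 3, ind + 4),
    # with the 0th cepstral coefficient dropped as in the examples above.
    start = (ind - left) * frames_per_video_frame
    stop = (ind + right) * frames_per_video_frame
    return mfcc[start:stop, 1:]

# e.g. for ind = 3 the window covers MFCC rows 0..27, the first ~280 ms of audio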