Esempio n. 1
0
    def __init__(self, config):
        """Set up the generator, loss functions, optimizer and training loader.

        Args:
            config: run configuration. The attributes read here are ``cuda``,
                ``load_model``, ``start_epoch``, ``pretrained_dir``,
                ``pretrained_epoch``, ``lr``, ``beta1``, ``beta2``,
                ``batch_size`` and ``num_thread``.
        """
        net = atnet()
        l1_criterion = nn.L1Loss()
        mse_criterion = nn.MSELoss()

        # Move the model and both criteria to the GPU together when requested.
        if config.cuda:
            net = net.cuda()
            mse_criterion = mse_criterion.cuda()
            l1_criterion = l1_criterion.cuda()

        self.generator = net
        self.l1_loss_fn = l1_criterion
        self.mse_loss_fn = mse_criterion
        self.config = config

        # Fresh initialisation happens before any checkpoint is loaded
        # (self.load presumably overwrites these weights -- defined elsewhere).
        initialize_weights(self.generator)

        resume = config.load_model
        self.start_epoch = config.start_epoch if resume else 0
        if resume:
            self.load(config.pretrained_dir, config.pretrained_epoch)

        adam_options = dict(lr=config.lr, betas=(config.beta1, config.beta2))
        self.opt_g = torch.optim.Adam(self.generator.parameters(),
                                      **adam_options)

        self.dataset = Voxceleb_audio_lmark_lstm(
            '/data2/lchen63/voxceleb/txt/', 'train')

        loader_options = {
            'batch_size': config.batch_size,
            'num_workers': config.num_thread,
            'shuffle': True,
            'drop_last': True,
        }
        self.data_loader = DataLoader(self.dataset, **loader_options)
Esempio n. 2
0
    def __init__(self, config):
        """Set up the generator, loss functions, optimizer and training loader.

        The generator class is chosen from ``config.pca`` / ``config.seperate``
        and the dataset from ``config.dataset_name`` (``'vox'`` or ``'grid'``).

        Args:
            config: run configuration. The attributes read here are ``pca``,
                ``seperate``, ``cuda``, ``load_model``, ``start_epoch``,
                ``pretrained_dir``, ``pretrained_epoch``, ``lr``, ``beta1``,
                ``beta2``, ``dataset_name``, ``batch_size`` and ``num_thread``.

        Raises:
            ValueError: if ``config.dataset_name`` is neither ``'vox'`` nor
                ``'grid'``.
        """
        # Select the generator variant; only the chosen class is imported.
        if config.pca:
            if not config.seperate:
                from ATVG import AT_single as atnet
            else:
                from ATVG import AT_single2 as atnet
        else:
            from ATVG import AT_single2_no_pca as atnet

        self.generator = atnet()
        self.l1_loss_fn = nn.L1Loss()
        self.mse_loss_fn = nn.MSELoss()
        self.config = config

        if config.cuda:
            self.generator = self.generator.cuda()
            self.mse_loss_fn = self.mse_loss_fn.cuda()
            self.l1_loss_fn = self.l1_loss_fn.cuda()

        initialize_weights(self.generator)
        self.start_epoch = 0
        if config.load_model:
            self.start_epoch = config.start_epoch
            self.load(config.pretrained_dir, config.pretrained_epoch)

        self.opt_g = torch.optim.Adam(self.generator.parameters(),
                                      lr=config.lr,
                                      betas=(config.beta1, config.beta2))

        if config.dataset_name == 'vox':
            self.dataset = Voxceleb_audio_lmark('/data2/lchen63/voxceleb/txt/',
                                                'train')
        elif config.dataset_name == 'grid':
            if not config.seperate:
                # Imported under its own name on purpose: binding it locally
                # as ``Grid_audio_lmark`` would make that name local to the
                # whole function and the ``else`` branch below would raise
                # UnboundLocalError instead of using the outer-scope class.
                from dataset import Grid_audio_lmark_single_whole
                self.dataset = Grid_audio_lmark_single_whole(
                    '/data/lchen63/grid/zip/txt/', 'train')
            else:
                # Resolved from the enclosing (module) scope.
                self.dataset = Grid_audio_lmark('/data/lchen63/grid/zip/txt/',
                                                'train')
        else:
            raise ValueError(
                "Unsupported dataset_name {!r}; expected 'vox' or 'grid'".format(
                    config.dataset_name))

        self.data_loader = DataLoader(self.dataset,
                                      batch_size=config.batch_size,
                                      num_workers=config.num_thread,
                                      shuffle=True,
                                      drop_last=True)
Esempio n. 3
0
def _read_rgb_frame(video_path, frame_index):
    """Return frame ``frame_index`` (0-based) of a video as an RGB image.

    Frames are decoded sequentially from the start of the file. Returns
    ``None`` when ``frame_index`` is negative or when a read fails before the
    requested frame is reached. The capture handle is always released.
    """
    if frame_index < 0:
        return None
    cap = cv2.VideoCapture(video_path)
    try:
        frame = None
        for _ in range(frame_index + 1):
            ret, frame = cap.read()
            if not ret:
                return None
        return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    finally:
        cap.release()


def test(config):
    """Evaluate a trained generator on at most 100 batches and print metrics.

    For each batch the landmarks are centred (and projected onto the PCA
    basis when ``config.pca`` is set), the generator predicts landmarks from
    audio plus an example landmark, and the MSE loss is logged. Predictions
    and ground truth are mapped back to landmark space, optionally rendered
    to PNG files under ``config.sample_dir/<step>/``, and scored with
    ``utils.mse_metrix`` / ``utils.openrate_metrix``. Mean scores are printed
    at the end.

    Args:
        config: run configuration. The attributes read here are ``cuda``,
            ``dataset_name``, ``is_train``, ``batch_size``, ``num_thread``,
            ``model_dir``, ``log_dir``, ``pca``, ``sample_dir`` and
            ``visualize``.
    """
    is_vox = config.dataset_name == 'vox'

    generator = atnet()
    mse_loss_fn = nn.MSELoss()

    if config.cuda:
        generator = generator.cuda()
        mse_loss_fn = mse_loss_fn.cuda()

    initialize_weights(generator)

    # NOTE(review): the original had a vox/non-vox branch here whose two arms
    # were identical (both use the VoxCeleb list directory); confirm whether
    # the 'grid' dataset was meant to use a different path.
    dataset = audio_lmark('/home/cxu-serve/p1/lchen63/voxceleb/txt/',
                          config.is_train)
    data_loader = DataLoader(dataset,
                             batch_size=config.batch_size,
                             num_workers=config.num_thread,
                             shuffle=False, drop_last=True)

    # NOTE(review): these tensors are moved to the GPU unconditionally, while
    # the batch tensors are only moved when config.cuda is set.
    if is_vox:
        pca = torch.FloatTensor(np.load('./basics/U_front_smooth_vox.npy')[:, :6]).cuda()
        mean = torch.FloatTensor(np.load('./basics/mean_front_smooth_vox.npy')).cuda()
    elif config.dataset_name == 'grid':
        # NOTE(review): nothing below reads these six tensors; the loop uses
        # ``pca`` / ``mean`` only, so the 'grid' path fails on the first batch
        # with an unbound ``mean``. Left as-is because the intended
        # normalisation for grid is not visible here.
        face_pca = torch.FloatTensor(np.load('./basics/U_grid_roni.npy')[:, :6]).cuda()
        face_mean = torch.FloatTensor(np.load('./basics/mean_grid_roni.npy')).cuda()
        lip_pca = torch.FloatTensor(np.load('./basics/U_grid_lip.npy')[:, :6]).cuda()
        lip_mean = torch.FloatTensor(np.load('./basics/mean_grid_lip.npy')).cuda()
        lip_std = torch.FloatTensor(np.load('./basics/std_grid_lip.npy')).cuda()
        face_std = torch.FloatTensor(np.load('./basics/std_grid_roni.npy')).cuda()

    generator.load_state_dict(torch.load(config.model_dir))
    generator.eval()

    logger = Logger(config.log_dir)
    total_mse = []
    total_openrate_mse = []

    std_total_mse = []
    std_total_openrate_mse = []
    for step, data in enumerate(data_loader):
        print(step)
        if step == 100:
            break
        sample_audio = data['audio']
        sample_lmark = data['sample_lmark']
        ex_lmark = data['ex_lmark']

        if is_vox:
            sample_rt = data['sample_rt'].numpy()

        if config.cuda:
            sample_audio = Variable(sample_audio.float()).cuda()
            sample_lmark = sample_lmark.float().cuda()
            ex_lmark = ex_lmark.float().cuda()

        # Centre the landmarks, then optionally project onto the PCA basis.
        sample_lmark_pca = sample_lmark - mean.expand_as(sample_lmark)
        if config.pca:
            sample_lmark_pca = torch.mm(sample_lmark_pca, pca)

        ex_lmark_pca = ex_lmark - mean.expand_as(ex_lmark)
        if config.pca:
            ex_lmark_pca = torch.mm(ex_lmark_pca, pca)

        sample_lmark_pca = Variable(sample_lmark_pca)
        ex_lmark_pca = Variable(ex_lmark_pca)

        fake_lmark = generator(sample_audio, ex_lmark_pca)

        loss = mse_loss_fn(fake_lmark, sample_lmark_pca)
        logger.scalar_summary('loss', loss, step + 1)

        # Map the prediction back to landmark space.
        if config.pca:
            fake_lmark = fake_lmark.view(fake_lmark.size(0), 6)
            fake_lmark = torch.mm(fake_lmark, pca.t())
        else:
            fake_lmark = fake_lmark.view(fake_lmark.size(0), 204)
        fake_lmark = fake_lmark + mean.expand_as(fake_lmark)
        fake_lmark = fake_lmark.data.cpu().numpy()

        # Same inverse mapping for the ground truth. The non-PCA branch used
        # to read ``real_lmark`` before it was assigned; it now starts from
        # the centred ground truth like the PCA branch does.
        if config.pca:
            real_lmark = sample_lmark_pca.view(sample_lmark_pca.size(0), 6)
            real_lmark = torch.mm(real_lmark, pca.t())
        else:
            real_lmark = sample_lmark_pca.view(sample_lmark_pca.size(0), 204)
        # Out-of-place add so the view above never mutates sample_lmark_pca.
        real_lmark = real_lmark + mean.expand_as(real_lmark)
        real_lmark = real_lmark.data.cpu().numpy()

        step_dir = os.path.join(config.sample_dir, str(step))
        os.makedirs(step_dir, exist_ok=True)

        for gg in range(int(config.batch_size)):
            fake_pts = fake_lmark[gg].reshape((68, 3))
            real_pts = real_lmark[gg].reshape((68, 3))

            if is_vox:
                # Convert both landmark sets with the sample's RT via
                # utils.reverse_rt.
                fake_A3 = utils.reverse_rt(fake_pts, sample_rt[gg])
                A3 = utils.reverse_rt(real_pts, sample_rt[gg])

            if config.visualize:
                if is_vox:
                    v_path = os.path.join(
                        '/home/cxu-serve/p1/lchen63/voxceleb/unzip',
                        data['img_path'][gg] + '.mp4')
                    # Decode up to the sample's own frame index; the previous
                    # loop was bounded by the batch dimension and never
                    # released the capture.
                    gt_img = _read_rgb_frame(v_path,
                                             int(data['sample_id'][gg]))
                else:
                    gt_img = cv2.cvtColor(cv2.imread(data['img_path'][gg]),
                                          cv2.COLOR_BGR2RGB)

                if gt_img is None:
                    print('Skipping visualization for step {} sample {}: '
                          'frame could not be read'.format(step, gg))
                else:
                    lmark_name = "{}/fake_{}.png".format(step_dir, gg)
                    canvas = utils.lmark2img(fake_pts, gt_img)
                    canvas.savefig(lmark_name)

                    if is_vox:
                        lmark_name = "{}/fake_rt_{}.png".format(step_dir, gg)
                        canvas = utils.lmark2img(fake_A3.reshape((68, 3)), gt_img)
                        canvas.savefig(lmark_name)

                    lmark_name = "{}/real_{}.png".format(step_dir, gg)
                    canvas = utils.lmark2img(real_pts, gt_img)
                    canvas.savefig(lmark_name)

                    if is_vox:
                        lmark_name = "{}/real_rt_{}.png".format(step_dir, gg)
                        canvas = utils.lmark2img(A3.reshape((68, 3)), gt_img)
                        canvas.savefig(lmark_name)

            if is_vox:
                mse = utils.mse_metrix(A3, fake_A3)
                openrate = utils.openrate_metrix(A3, fake_A3)
                # Samples above either threshold are dropped from every
                # statistic, including the std_* lists below.
                if mse > 50 or openrate > 50:
                    continue
                total_openrate_mse.append(openrate)
                total_mse.append(mse)

            std_mse = utils.mse_metrix(real_pts, fake_pts)
            std_openrate = utils.openrate_metrix(real_pts, fake_pts)
            std_total_openrate_mse.append(std_openrate)
            std_total_mse.append(std_mse)

    if is_vox:
        total_mse = np.asarray(total_mse)
        total_openrate_mse = np.asarray(total_openrate_mse)

    std_total_mse = np.asarray(std_total_mse)
    std_total_openrate_mse = np.asarray(std_total_openrate_mse)

    if is_vox:
        print(total_mse.mean())
        print(total_openrate_mse.mean())
    print(std_total_mse.mean())
    print(std_total_openrate_mse.mean())
Esempio n. 4
0
def test():
    """Run the audio-to-landmark demo on ``config.video`` / ``config.in_file``.

    Reads the module-level ``config``. The video is preprocessed with the
    project helpers (audio extraction, cropping, landmark extraction, RT
    computation), an example landmark is taken from ``config.person``, MFCC
    windows of the input audio are fed through the network, and the predicted
    and ground-truth landmarks are rendered to ``./lmark/*`` and assembled
    into videos under ``config.sample_dir``.

    Side effects: sets ``CUDA_VISIBLE_DEVICES``, recreates ``./tmp`` and
    ``./lmark``, and rewrites ``config.video`` to the cropped file name.

    Raises:
        ValueError: if the audio is too short to form a single MFCC window.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = config.device_ids

    if os.path.exists('./tmp'):
        shutil.rmtree('./tmp')
    os.mkdir('./tmp')

    if config.dataset_name == 'vox':
        pca = torch.FloatTensor(np.load('./basics/U_front_smooth_vox.npy')[:, :6]).cuda()
        mean = torch.FloatTensor(np.load('./basics/mean_front_smooth_vox.npy')).cuda()
    elif config.dataset_name == 'grid':
        # NOTE(review): nothing below reads these six tensors; the rest of the
        # function uses ``pca`` / ``mean`` only, so the 'grid' path fails at
        # the example-landmark projection with an unbound ``mean``. Left
        # as-is because the intended grid normalisation is not visible here.
        face_pca = torch.FloatTensor(np.load('./basics/U_grid_roni.npy')[:, :6]).cuda()
        face_mean = torch.FloatTensor(np.load('./basics/mean_grid_roni.npy')).cuda()
        lip_pca = torch.FloatTensor(np.load('./basics/U_grid_lip.npy')[:, :6]).cuda()
        lip_mean = torch.FloatTensor(np.load('./basics/mean_grid_lip.npy')).cuda()

        lip_std = torch.FloatTensor(np.load('./basics/std_grid_lip.npy')).cuda()
        face_std = torch.FloatTensor(np.load('./basics/std_grid_roni.npy')).cuda()

    # The 25 FPS re-encode step (ffmpeg -r 25) was already disabled in the
    # original; the input video is used at its native frame rate.

    # Extract audio from the video.
    _extract_audio(config.video)

    # Crop the video; later steps work on the '<name>_crop.mp4' file.
    _crop_video(config.video)
    config.video = config.video[:-4] + '_crop.mp4'

    # Generate ground-truth frames and landmarks.
    _video2img2lmark(config.video)

    # Compute RT and the front-view ground truth.
    compute_RT(config.video)

    RTs = np.load(config.video[:-4] + '_sRT.npy')
    gt_front = np.load(config.video[:-4] + '_front.npy')

    # Extract the example landmark from the single reference image.
    get3DLmarks_single_image(config.person)
    example_lmark = compute_RT_single(config.person)

    example_lmark = torch.FloatTensor(example_lmark).view(1, -1).cuda()
    example_lmark_pca = example_lmark - mean.expand_as(example_lmark)
    # NOTE(review): projected unconditionally, even when config.pca is false.
    example_lmark_pca = torch.mm(example_lmark_pca, pca)
    example_lmark_pca = Variable(example_lmark_pca)

    # Start from a clean output tree. './lmark/original' is created up front
    # rather than being re-checked on every frame of the ground-truth loop.
    if os.path.exists('./lmark'):
        shutil.rmtree('./lmark')
    os.mkdir('./lmark')
    for sub_dir in ('real', 'fake', 'fake_rt', 'original'):
        os.mkdir('./lmark/' + sub_dir)

    encoder = atnet()
    if config.cuda:
        encoder = encoder.cuda()

    encoder.load_state_dict(torch.load(config.at_model))
    encoder.eval()
    test_file = config.in_file

    # Load speech at 16 kHz and extract MFCC features (10 ms step).
    speech, _ = librosa.load(test_file, sr=16000)
    mfcc = python_speech_features.mfcc(speech, 16000, winstep=0.01)

    def _export_video(frame_dir, stem):
        # Assemble the PNGs in ``frame_dir`` into '<stem>.mp4' and add audio.
        video_name = os.path.join(config.sample_dir, stem + '.mp4')
        utils.image_to_video(os.path.join('./', frame_dir), video_name)
        utils.add_audio(video_name, config.in_file)
        # The reported path ends in '.mov' while video_name is '.mp4';
        # presumably utils.add_audio writes the .mov -- TODO confirm.
        print('The generated video is: {}'.format(
            os.path.join(config.sample_dir, stem + '.mov')))

    print('=======================================')
    print('Start to generate images')
    with torch.no_grad():
        # One window per output frame: 28 MFCC rows centred on block ``ind``
        # (4 rows per block), dropping the first coefficient.
        last_ind = int(mfcc.shape[0] / 4) - 4
        input_mfcc = [
            torch.FloatTensor(mfcc[(ind - 3) * 4:(ind + 4) * 4, 1:]).cuda()
            for ind in range(3, last_ind + 1)
        ]
        if not input_mfcc:
            raise ValueError(
                'Audio {!r} is too short to build an MFCC window '
                '({} MFCC frames)'.format(test_file, mfcc.shape[0]))
        input_mfcc = torch.stack(input_mfcc, dim=0)

        fake_lmark = encoder(
            input_mfcc, example_lmark_pca.repeat(input_mfcc.shape[0], 1))
        if config.pca:
            fake_lmark = fake_lmark.view(fake_lmark.size(0), 6)
            fake_lmark = torch.mm(fake_lmark, pca.t())
        else:
            fake_lmark = fake_lmark.view(fake_lmark.size(0), 204)
        fake_lmark = fake_lmark + mean.expand_as(fake_lmark)

        fake_lmark = fake_lmark.data.cpu().numpy()

        # Render predictions, with and without the per-frame RT applied.
        # NOTE(review): assumes ./tmp holds a frame and RTs holds an entry for
        # every predicted landmark -- verify against the preprocessing.
        for gg in range(fake_lmark.shape[0]):
            background = cv2.imread('./tmp/%04d.png' % (gg + 1))
            background = cv2.cvtColor(background, cv2.COLOR_BGR2RGB)
            lmark_name = "./lmark/fake/%05d.png" % gg
            canvas = utils.lmark2img(fake_lmark[gg].reshape((68, 3)), background, c='b')
            canvas.savefig(lmark_name)
            A3 = utils.reverse_rt(fake_lmark[gg].reshape((68, 3)), RTs[gg])
            lmark_name = "./lmark/fake_rt/%05d.png" % gg
            canvas = utils.lmark2img(A3.reshape((68, 3)), background, c='b')
            canvas.savefig(lmark_name)

        _export_video('lmark/fake', 'fake')
        _export_video('lmark/fake_rt', 'fake_rt')

        # Render the ground truth: front view and RT-reversed landmarks.
        for gg in range(gt_front.shape[0]):
            background = cv2.imread('./tmp/%04d.png' % (gg + 1))
            background = cv2.cvtColor(background, cv2.COLOR_BGR2RGB)
            lmark_name = "./lmark/real/%05d.png" % gg
            canvas = utils.lmark2img(gt_front[gg].reshape((68, 3)), background, c='b')
            canvas.savefig(lmark_name)

            reversed_lmark = utils.reverse_rt(gt_front[gg], RTs[gg])
            lmark_name = "./lmark/original/%05d.png" % gg
            canvas = utils.lmark2img(reversed_lmark.reshape((68, 3)), background, c='b')
            canvas.savefig(lmark_name)

        _export_video('lmark/real', 'real')
        _export_video('lmark/original', 'orig')
Esempio n. 5
0
def _read_rgb_frame(video_path, frame_index):
    """Return frame ``frame_index`` (0-based) of a video as an RGB image.

    Frames are decoded sequentially from the start of the file. Returns
    ``None`` when ``frame_index`` is negative or when a read fails before the
    requested frame is reached. The capture handle is always released.
    """
    if frame_index < 0:
        return None
    cap = cv2.VideoCapture(video_path)
    try:
        frame = None
        for _ in range(frame_index + 1):
            ret, frame = cap.read()
            if not ret:
                return None
        return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    finally:
        cap.release()


def test(config):
    """Evaluate the lip/face generator on at most 100 batches and print metrics.

    Lip and other-region landmarks are standardised (mean/std) and projected
    onto 6-component PCA bases, the generator predicts both regions from audio
    plus example regions, and the two MSE losses are logged. Predictions and
    ground truth are mapped back to landmark space, concatenated (face first,
    then lip), optionally rendered under ``config.sample_dir/<step>/``, and
    scored with ``utils.mse_metrix`` / ``utils.openrate_metrix``. Mean scores
    are printed at the end.

    Args:
        config: run configuration. The attributes read here are ``cuda``,
            ``is_train``, ``batch_size``, ``num_thread``, ``model_dir``,
            ``log_dir``, ``pca``, ``sample_dir`` and ``visualize``.
    """
    generator = atnet()
    mse_loss_fn = nn.MSELoss()
    # Second dimension used when reshaping the projected sequences below.
    time_length = 32

    if config.cuda:
        generator = generator.cuda()
        mse_loss_fn = mse_loss_fn.cuda()

    initialize_weights(generator)

    dataset = Voxceleb_audio_lmark('/data2/lchen63/voxceleb/txt/',
                                   config.is_train)

    data_loader = DataLoader(dataset,
                             batch_size=config.batch_size,
                             num_workers=config.num_thread,
                             shuffle=False,
                             drop_last=True)

    # NOTE(review): these tensors are moved to the GPU unconditionally, while
    # the batch tensors are only moved when config.cuda is set.
    face_pca = torch.FloatTensor(
        np.load('./basics/U_front_roni.npy')[:, :6]).cuda()
    face_mean = torch.FloatTensor(
        np.load('./basics/mean_front_roni.npy')).cuda()
    lip_pca = torch.FloatTensor(np.load('./basics/U_front.npy')[:, :6]).cuda()
    lip_mean = torch.FloatTensor(np.load('./basics/mean_front.npy')).cuda()
    lip_std = torch.FloatTensor(np.load('./basics/std_front.npy')).cuda()
    face_std = torch.FloatTensor(np.load('./basics/std_front_roni.npy')).cuda()

    generator.load_state_dict(torch.load(config.model_dir))
    generator.eval()

    logger = Logger(config.log_dir)
    total_mse = []
    total_openrate_mse = []

    std_total_mse = []
    std_total_openrate_mse = []
    for step, data in enumerate(data_loader):
        print(step)
        if step == 100:
            break
        sample_audio = data['audio']
        lip_region = data['lip_region']
        other_region = data['other_region']
        ex_other_region = data['ex_other_region']
        ex_lip_region = data['ex_lip_region']

        sample_rt = data['sample_rt'].numpy()

        if config.cuda:
            sample_audio = Variable(sample_audio.float()).cuda()
            lip_region = lip_region.float().cuda()
            other_region = other_region.float().cuda()
            ex_other_region = ex_other_region.float().cuda()
            ex_lip_region = ex_lip_region.float().cuda()
        else:
            sample_audio = Variable(sample_audio.float())
            lip_region = Variable(lip_region.float())
            other_region = Variable(other_region.float())
            ex_other_region = Variable(ex_other_region.float())
            ex_lip_region = ex_lip_region.float()

        # Standardise, flatten the first two dimensions, project onto the lip
        # basis, then restore the (batch, time, 6) layout.
        lip_region_pca = (lip_region -
                          lip_mean.expand_as(lip_region)) / lip_std
        lip_region_pca = lip_region_pca.view(
            lip_region_pca.shape[0] * lip_region_pca.shape[1], -1)
        lip_region_pca = torch.mm(lip_region_pca, lip_pca)
        lip_region_pca = lip_region_pca.view(config.batch_size, time_length, 6)

        # The example regions are projected directly with torch.mm.
        ex_lip_region_pca = (ex_lip_region -
                             lip_mean.expand_as(ex_lip_region)) / lip_std
        ex_lip_region_pca = torch.mm(ex_lip_region_pca, lip_pca)

        ex_other_region_pca = (ex_other_region -
                               face_mean.expand_as(ex_other_region)) / face_std
        ex_other_region_pca = torch.mm(ex_other_region_pca, face_pca)

        other_region_pca = (other_region -
                            face_mean.expand_as(other_region)) / face_std
        other_region_pca = other_region_pca.view(
            other_region_pca.shape[0] * other_region_pca.shape[1], -1)
        other_region_pca = torch.mm(other_region_pca, face_pca)
        other_region_pca = other_region_pca.view(config.batch_size,
                                                 time_length, 6)

        ex_other_region_pca = Variable(ex_other_region_pca)
        lip_region_pca = Variable(lip_region_pca)
        other_region_pca = Variable(other_region_pca)

        fake_lip, fake_face = generator(sample_audio, ex_other_region_pca,
                                        ex_lip_region_pca)

        loss_lip = mse_loss_fn(fake_lip, lip_region_pca)
        loss_face = mse_loss_fn(fake_face, other_region_pca)

        logger.scalar_summary('loss_lip', loss_lip, step + 1)
        logger.scalar_summary('loss_face', loss_face, step + 1)

        # Flatten (batch, time) into rows before mapping back.
        fake_lip = fake_lip.view(fake_lip.size(0) * time_length, 6)
        fake_face = fake_face.view(fake_face.size(0) * time_length, 6)
        lip_region_pca = lip_region_pca.view(
            lip_region_pca.size(0) * time_length, 6)
        other_region_pca = other_region_pca.view(
            other_region_pca.size(0) * time_length, 6)

        # NOTE(review): the inputs above are always PCA-projected to 6
        # columns, so the non-PCA branches below (60 / 144 columns) look
        # unreachable without a shape error -- confirm before relying on
        # config.pca being false.
        if config.pca:
            fake_lip = fake_lip.view(fake_lip.size(0), 6)
            fake_lip = torch.mm(fake_lip, lip_pca.t())
        else:
            fake_lip = fake_lip.view(fake_lip.size(0), 60)
        fake_lip = fake_lip * lip_std + lip_mean.expand_as(fake_lip)
        if config.pca:
            fake_face = fake_face.view(fake_face.size(0), 6)
            fake_face = torch.mm(fake_face, face_pca.t())
        else:
            fake_face = fake_face.view(fake_face.size(0), 144)
        fake_face = fake_face * face_std + face_mean.expand_as(fake_face)

        fake_lmark = torch.cat([fake_face, fake_lip], 1)
        fake_lmark = fake_lmark.data.cpu().numpy()

        if config.pca:
            real_lip = lip_region_pca.view(lip_region_pca.size(0), 6)
            real_lip = torch.mm(real_lip, lip_pca.t())
        else:
            real_lip = lip_region_pca.view(lip_region_pca.size(0), 60)
        real_lip = real_lip * lip_std + lip_mean.expand_as(real_lip)
        if config.pca:
            real_face = other_region_pca.view(other_region_pca.size(0), 6)
            real_face = torch.mm(real_face, face_pca.t())
        else:
            real_face = other_region_pca.view(other_region_pca.size(0), 144)
        real_face = real_face * face_std + face_mean.expand_as(real_face)

        real_lmark = torch.cat([real_face, real_lip], 1)
        real_lmark = real_lmark.data.cpu().numpy()

        step_dir = os.path.join(config.sample_dir, str(step))
        os.makedirs(step_dir, exist_ok=True)

        # NOTE(review): fake_lmark / real_lmark have batch*time rows here but
        # are indexed with the batch index ``gg`` -- confirm this pairing with
        # sample_rt[gg] is intended.
        for gg in range(int(config.batch_size)):
            fake_pts = fake_lmark[gg].reshape((68, 3))
            real_pts = real_lmark[gg].reshape((68, 3))

            # Convert both landmark sets with the sample's RT via
            # utils.reverse_rt.
            fake_A3 = utils.reverse_rt(fake_pts, sample_rt[gg])
            A3 = utils.reverse_rt(real_pts, sample_rt[gg])

            if config.visualize:
                v_path = os.path.join('/data2/lchen63/voxceleb/unzip',
                                      data['img_path'][gg] + '.mp4')
                # Decode up to the sample's own frame index; the previous loop
                # was bounded by the landmark row count and never released the
                # capture.
                gt_img = _read_rgb_frame(v_path, int(data['sample_id'][gg]))

                if gt_img is None:
                    print('Skipping visualization for step {} sample {}: '
                          'frame could not be read'.format(step, gg))
                else:
                    lmark_name = "{}/fake_{}.png".format(step_dir, gg)
                    canvas = utils.lmark2img(fake_pts, gt_img)
                    canvas.savefig(lmark_name)

                    lmark_name = "{}/fake_rt_{}.png".format(step_dir, gg)
                    canvas = utils.lmark2img(fake_A3.reshape((68, 3)), gt_img)
                    canvas.savefig(lmark_name)

                    lmark_name = "{}/real_{}.png".format(step_dir, gg)
                    canvas = utils.lmark2img(real_pts, gt_img)
                    canvas.savefig(lmark_name)

                    lmark_name = "{}/real_rt_{}.png".format(step_dir, gg)
                    canvas = utils.lmark2img(A3.reshape((68, 3)), gt_img)
                    canvas.savefig(lmark_name)

            mse = utils.mse_metrix(A3, fake_A3)
            openrate = utils.openrate_metrix(A3, fake_A3)
            # Samples above either threshold are dropped from every statistic,
            # including the std_* lists below.
            if mse > 50 or openrate > 50:
                continue
            total_openrate_mse.append(openrate)
            total_mse.append(mse)

            std_mse = utils.mse_metrix(real_pts, fake_pts)
            std_openrate = utils.openrate_metrix(real_pts, fake_pts)
            std_total_openrate_mse.append(std_openrate)
            std_total_mse.append(std_mse)

            print('{}/{}: mse: {}:openrate: {}'.format(step, gg, mse,
                                                       openrate))

    total_mse = np.asarray(total_mse)
    total_openrate_mse = np.asarray(total_openrate_mse)

    std_total_mse = np.asarray(std_total_mse)
    std_total_openrate_mse = np.asarray(std_total_openrate_mse)

    print(total_mse.mean())
    print(total_openrate_mse.mean())
    print(std_total_mse.mean())
    print(std_total_openrate_mse.mean())