Пример #1
0
def _reset_demo_dirs():
    """Remove ``./temp`` and ``./lmark`` and recreate the frame directories."""
    for root in ('./temp', './lmark'):
        if os.path.exists(root):
            shutil.rmtree(root)
    # os.makedirs also creates the parents (./temp, ./lmark), so this works
    # on a first run when ./lmark does not exist yet.
    for sub in ('./temp/img', './lmark/real', './lmark/fake',
                './lmark/fake_rt', './lmark/fake_rt_3d', './lmark/original'):
        os.makedirs(sub, exist_ok=True)


def _make_demo_video(frame_dir, video_basename):
    """Stitch the PNG frames in ``./lmark/<frame_dir>`` into a video with audio.

    The video is written to ``<config.sample_dir>/<video_basename>.mp4`` and
    the audio of ``config.in_file`` is attached with ``utils.add_audio``.
    """
    video_name = os.path.join(config.sample_dir, video_basename + '.mp4')
    utils.image_to_video(os.path.join('./', 'lmark', frame_dir), video_name)
    utils.add_audio(video_name, config.in_file)
    # NOTE(review): the message reports a ``.mov`` path while the file written
    # above is ``.mp4``; presumably utils.add_audio produces the .mov — confirm.
    print('The generated video is: {}'.format(
        os.path.join(config.sample_dir, video_basename + '.mov')))


def test():
    """Generate landmark videos for ``config.in_file`` with the AT-net encoder.

    Steps:
      1. Recreate the scratch directories ``./temp`` and ``./lmark/*``.
      2. Load PCA statistics from ``./basics`` and the encoder weights from
         ``config.at_model``.
      3. Project the example landmark of ``config.person`` into the 6-D PCA
         space and compute MFCC features of the zero-padded input audio.
      4. Run the encoder over sliding MFCC windows, project the predictions
         back to landmarks, and render them three ways: plain, with ``RTs``
         re-applied via ``utils.reverse_rt``, and as a 3-D plot. Each set of
         frames is stitched into a video.
      5. Render the ground-truth landmarks stored next to ``config.rts``.

    All inputs come from the module-level ``config``. Outputs go to
    ``config.sample_dir`` and the ``./lmark`` tree.

    Raises:
        RuntimeError: if the audio is too short to yield a single MFCC window.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = config.device_ids
    _reset_demo_dirs()

    # PCA basis (first 6 components) and mean of the encoder's landmark space.
    pca = torch.FloatTensor(np.load('./basics/U_lrw1.npy')[:, :6]).cuda()
    mean = torch.FloatTensor(np.load('./basics/mean_lrw1.npy')).cuda()
    # Statistics used to round-trip the ground-truth frontal landmarks.
    grid_mean = torch.FloatTensor(np.load('./basics/mean_grid.npy')).cuda()
    grid_std = torch.FloatTensor(np.load('./basics/std_grid.npy')).cuda()
    grid_pca = torch.FloatTensor(np.load('./basics/U_grid.npy')).cuda()

    encoder = AT_net()
    if config.cuda:
        encoder = encoder.cuda()
    encoder.load_state_dict(multi2single(config.at_model, 1))
    encoder.eval()

    test_file = config.in_file

    # Only the example landmark is used; the example image is not needed.
    _, example_landmark = generator_demo_example_lips(config.person)
    # Flatten the example landmark into a single row vector.
    example_landmark = example_landmark.reshape(
        (1, example_landmark.shape[0] * example_landmark.shape[1]))
    example_landmark = Variable(
        torch.FloatTensor(example_landmark.astype(float)))
    if config.cuda:
        example_landmark = example_landmark.cuda()
    # Scale, centre and project the example landmark into the PCA space.
    example_landmark = example_landmark * 5.0
    example_landmark = example_landmark - mean.expand_as(example_landmark)
    example_landmark = torch.mm(example_landmark, pca)

    # MFCCs at 16 kHz with a 10 ms hop. 1920 zero samples (0.12 s) are
    # padded on each side of the speech signal.
    speech, _ = librosa.load(test_file, sr=16000)
    speech = np.insert(speech, 0, np.zeros(1920))
    speech = np.append(speech, np.zeros(1920))
    mfcc = python_speech_features.mfcc(speech, 16000, winstep=0.01)

    print('=======================================')
    print('Start to generate images')
    t = time.time()
    RTs = np.load(config.rts)
    gt_front = np.load(config.rts.replace('sRT', 'front'))
    gt = np.load(config.rts.replace('_sRT', ''))
    with torch.no_grad():
        # One window of 28 MFCC rows per output frame (coefficient 0 dropped).
        windows = [
            torch.FloatTensor(mfcc[(ind - 3) * 4:(ind + 4) * 4, 1:]).cuda()
            for ind in range(3, mfcc.shape[0] // 4 - 3)
        ]
        if not windows:
            raise RuntimeError(
                'Audio file {} is too short to generate any frame'.format(
                    test_file))
        input_mfcc = torch.stack(windows, dim=0).unsqueeze(0)
        fake_lmark = encoder(example_landmark, input_mfcc)
        fake_lmark = fake_lmark.view(
            fake_lmark.size(0) * fake_lmark.size(1), 6)
        # Amplify PCA components 1..5; component 0 is left unchanged.
        fake_lmark[:, 1:6] *= 2 * torch.FloatTensor(
            np.array([1.1, 1.2, 1.3, 1.4, 1.5])).cuda()
        fake_lmark = torch.mm(fake_lmark, pca.t())
        fake_lmark = fake_lmark + mean.expand_as(fake_lmark)

        print(
            'In total, generate {:d} images, cost time: {:03f} seconds'.format(
                fake_lmark.size(0),
                time.time() - t))

        fake_lmark = fake_lmark.data.cpu().numpy()
        np.save(os.path.join(config.sample_dir, 'obama_fake.npy'), fake_lmark)
        # Lift the 68 x 2 prediction to 68 x 3: x/y scaled by 200, z taken
        # from the mean shape.
        fake_lmark2d = np.reshape(fake_lmark, (fake_lmark.shape[0], 68, 2))
        fake_lmark = np.zeros((fake_lmark2d.shape[0], 68, 3))
        fake_lmark[:, :, :-1] = fake_lmark2d * 200
        meanshape = np.load('./basics/mean_grid.npy').reshape(68, 3)
        fake_lmark[:, :, -1] += meanshape[:, -1]

        # Round-trip the ground-truth frontal landmarks through the grid PCA:
        # standardise, project, reconstruct, de-standardise.
        gt_front = torch.FloatTensor(
            gt_front.reshape(gt_front.shape[0], -1)).cuda()
        gt_front = gt_front - grid_mean.expand_as(gt_front)
        gt_front = gt_front / grid_std
        gt_front = torch.mm(torch.mm(gt_front, grid_pca), grid_pca.t())
        gt_front = gt_front * grid_std
        gt_front = gt_front + grid_mean.expand_as(gt_front)
        gt_front = gt_front.cpu().numpy()

        background = cv2.cvtColor(cv2.imread('./image/background.png'),
                                  cv2.COLOR_BGR2RGB)

        for gg in range(fake_lmark.shape[0]):
            # Fresh copy per frame in case lmark2img draws into the image.
            canvas = background.copy()
            lmark = fake_lmark[gg].reshape((68, 3))
            plt = utils.lmark2img(lmark, canvas, c='b')
            plt.savefig("./lmark/fake/%05d.png" % gg)

            A3 = utils.reverse_rt(lmark, RTs[gg])
            plt = utils.lmark2img(A3.reshape((68, 3)), canvas, c='b')
            plt.savefig("./lmark/fake_rt/%05d.png" % gg)

            plt = vis(A3.reshape(68, 3))
            plt.savefig("./lmark/fake_rt_3d/%05d.png" % gg)

        _make_demo_video('fake', 'fake')
        _make_demo_video('fake_rt', 'fake_rt')
        _make_demo_video('fake_rt_3d', 'fake_rt_3d')

        for gg in range(gt_front.shape[0]):
            canvas = background.copy()
            plt = utils.lmark2img(gt_front[gg].reshape((68, 3)), canvas, c='b')
            plt.savefig("./lmark/real/%05d.png" % gg)

            plt = utils.lmark2img(gt[gg].reshape((68, 3)), canvas, c='b')
            plt.savefig("./lmark/original/%05d.png" % gg)

        _make_demo_video('real', 'real')
        _make_demo_video('original', 'orig')
Пример #2
0
def _read_rgb_frame(video_path, frame_index):
    """Return frame ``frame_index`` of ``video_path`` as an RGB array, or None.

    Frames are decoded sequentially from the start of the file. None is
    returned when the read at ``frame_index`` fails, e.g. the index is past
    the end or the file cannot be opened. The capture is always released.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        for t in range(frame_index + 1):
            ret, frame = cap.read()
            if t == frame_index:
                return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) if ret else None
        return None
    finally:
        cap.release()


def test(config, max_steps=100):
    """Evaluate a trained ``atnet`` generator on the ``audio_lmark`` dataset.

    For the first ``max_steps`` batches:
      * predict landmarks from audio and the example landmark;
      * log the model-space MSE loss;
      * optionally (``config.visualize``) save landmark overlays under
        ``<config.sample_dir>/<step>/``;
      * accumulate ``utils.mse_metrix`` / ``utils.openrate_metrix`` scores.

    Mean scores are printed at the end. For 'vox', samples whose
    RT-restored MSE or open-rate error exceeds 50 are excluded from all
    metric lists.

    Args:
        config: run configuration. Reads ``cuda``, ``dataset_name``,
            ``is_train``, ``batch_size``, ``num_thread``, ``model_dir``,
            ``log_dir``, ``pca``, ``sample_dir`` and ``visualize``.
        max_steps: number of batches to evaluate (default 100).

    Raises:
        NotImplementedError: if ``config.dataset_name`` is not 'vox'. Only
            the vox PCA statistics (``pca``/``mean``) are wired up below.
    """
    if config.dataset_name != 'vox':
        raise NotImplementedError(
            "test() only supports dataset_name == 'vox' (got {!r})".format(
                config.dataset_name))

    generator = atnet()
    mse_loss_fn = nn.MSELoss()

    if config.cuda:
        generator = generator.cuda()
        mse_loss_fn = mse_loss_fn.cuda()

    initialize_weights(generator)
    dataset = audio_lmark('/home/cxu-serve/p1/lchen63/voxceleb/txt/',
                          config.is_train)
    data_loader = DataLoader(dataset,
                             batch_size=config.batch_size,
                             num_workers=config.num_thread,
                             shuffle=False, drop_last=True)
    pca = torch.FloatTensor(
        np.load('./basics/U_front_smooth_vox.npy')[:, :6]).cuda()
    mean = torch.FloatTensor(
        np.load('./basics/mean_front_smooth_vox.npy')).cuda()

    generator.load_state_dict(torch.load(config.model_dir))
    generator.eval()

    logger = Logger(config.log_dir)
    total_mse = []
    total_openrate_mse = []
    std_total_mse = []
    std_total_openrate_mse = []

    def to_landmarks(x):
        """Map model-space rows back to flat landmark vectors (numpy)."""
        if config.pca:
            x = torch.mm(x.view(x.size(0), 6), pca.t())
        else:
            # 204 = 68 points x 3 coordinates, matching reshape((68, 3)).
            x = x.view(x.size(0), 204)
        return (x + mean.expand_as(x)).data.cpu().numpy()

    # Pure evaluation: no gradients are needed.
    with torch.no_grad():
        for step, data in enumerate(data_loader):
            print(step)
            if step == max_steps:
                break
            sample_audio = data['audio']
            sample_lmark = data['sample_lmark']
            ex_lmark = data['ex_lmark']
            if config.dataset_name == 'vox':
                sample_rt = data['sample_rt'].numpy()

            if config.cuda:
                sample_audio = Variable(sample_audio.float()).cuda()
                sample_lmark = sample_lmark.float().cuda()
                ex_lmark = ex_lmark.float().cuda()

            sample_lmark_pca = sample_lmark - mean.expand_as(sample_lmark)
            ex_lmark_pca = ex_lmark - mean.expand_as(ex_lmark)
            if config.pca:
                sample_lmark_pca = torch.mm(sample_lmark_pca, pca)
                ex_lmark_pca = torch.mm(ex_lmark_pca, pca)
            sample_lmark_pca = Variable(sample_lmark_pca)
            ex_lmark_pca = Variable(ex_lmark_pca)

            fake_lmark = generator(sample_audio, ex_lmark_pca)
            loss = mse_loss_fn(fake_lmark, sample_lmark_pca)
            logger.scalar_summary('loss', loss, step + 1)

            fake_lmark = to_landmarks(fake_lmark)
            # The ground truth goes through the same back-projection as the
            # prediction, so both are compared in the same space.
            real_lmark = to_landmarks(sample_lmark_pca)

            step_dir = os.path.join(config.sample_dir, str(step))
            if not os.path.exists(step_dir):
                os.mkdir(step_dir)

            for gg in range(int(config.batch_size)):
                fake_68 = fake_lmark[gg].reshape((68, 3))
                real_68 = real_lmark[gg].reshape((68, 3))
                if config.dataset_name == 'vox':
                    # Re-apply the sample's rotation/translation to both.
                    fake_A3 = utils.reverse_rt(fake_68, sample_rt[gg])
                    A3 = utils.reverse_rt(real_68, sample_rt[gg])

                if config.visualize:
                    if config.dataset_name == 'vox':
                        v_path = os.path.join(
                            '/home/cxu-serve/p1/lchen63/voxceleb/unzip',
                            data['img_path'][gg] + '.mp4')
                        gt_img = _read_rgb_frame(
                            v_path, int(data['sample_id'][gg]))
                    else:
                        gt_img = cv2.cvtColor(cv2.imread(data['img_path'][gg]),
                                              cv2.COLOR_BGR2RGB)
                    if gt_img is None:
                        # Never overlay onto a previous sample's frame.
                        print('Could not read the frame for {}; skipping '
                              'visualization'.format(data['img_path'][gg]))
                    else:
                        overlays = [('fake', fake_68)]
                        if config.dataset_name == 'vox':
                            overlays.append(('fake_rt', fake_A3))
                        overlays.append(('real', real_68))
                        if config.dataset_name == 'vox':
                            overlays.append(('real_rt', A3))
                        for prefix, lmark in overlays:
                            plt = utils.lmark2img(lmark.reshape((68, 3)),
                                                  gt_img)
                            plt.savefig("{}/{}_{}.png".format(
                                step_dir, prefix, gg))

                if config.dataset_name == 'vox':
                    mse = utils.mse_metrix(A3, fake_A3)
                    openrate = utils.openrate_metrix(A3, fake_A3)
                    # Outlier filter: the sample is dropped from every metric.
                    if mse > 50 or openrate > 50:
                        continue
                    total_openrate_mse.append(openrate)
                    total_mse.append(mse)

                std_mse = utils.mse_metrix(real_68, fake_68)
                std_openrate = utils.openrate_metrix(real_68, fake_68)
                std_total_openrate_mse.append(std_openrate)
                std_total_mse.append(std_mse)

    if config.dataset_name == 'vox':
        print(np.asarray(total_mse).mean())
        print(np.asarray(total_openrate_mse).mean())
    print(np.asarray(std_total_mse).mean())
    print(np.asarray(std_total_openrate_mse).mean())
Пример #3
0
    def fit(self):
        """Train ``self.generator`` to predict landmarks from audio.

        Normalisation statistics are loaded from ``./basics`` according to
        ``config.dataset_name``. Then, for each epoch in
        ``range(self.start_epoch, config.max_epochs)`` and each batch of
        ``self.data_loader``, one of two modes runs.

        * ``config.seperate``:
            - The lip and "other" regions are standardised with their own
              statistics, and PCA-projected when ``config.pca``.
            - The generator predicts both regions.
            - The loss is the sum of the two MSEs, plus an optional mouth
              open-rate MSE (``config.openrate_loss``).
        * otherwise:
            - The whole landmark is standardised/projected with the
              ``pca_*`` statistics and trained with one MSE loss.
            - Prediction and target are mapped back to landmarks.
            - The first ``int(config.batch_size / 50)`` samples are rendered
              into ``config.sample_dir``.

        After every step the generator weights are saved to
        ``<config.model_dir>/anet2_single.pth``.
        """
        config = self.config
        if config.dataset_name == 'vox':
            face_mean = torch.FloatTensor(
                np.load('./basics/mean_front_roni.npy')).cuda()
            lip_mean = torch.FloatTensor(
                np.load('./basics/mean_front.npy')).cuda()
            face_pca = torch.FloatTensor(
                np.load('./basics/U_front_roni.npy')[:, :6]).cuda()
            lip_pca = torch.FloatTensor(
                np.load('./basics/U_front.npy')[:, :6]).cuda()
            lip_std = torch.FloatTensor(
                np.load('./basics/std_front.npy')).cuda()
            face_std = torch.FloatTensor(
                np.load('./basics/std_front_roni.npy')).cuda()
        elif config.dataset_name == 'grid':
            face_mean = torch.FloatTensor(
                np.load('./basics/mean_grid_roni_norm.npy')).cuda()
            lip_mean = torch.FloatTensor(
                np.load('./basics/mean_grid_lip_norm.npy')).cuda()
            lip_pca = torch.FloatTensor(
                np.load('./basics/U_grid_lip_norm.npy')[:, :6]).cuda()
            face_pca = torch.FloatTensor(
                np.load('./basics/U_grid_roni_norm.npy')[:, :6]).cuda()
            lip_std = torch.FloatTensor(
                np.load('./basics/std_grid_lip_norm.npy')).cuda()
            face_std = torch.FloatTensor(
                np.load('./basics/std_grid_roni_norm.npy')).cuda()

            pca_mean = torch.FloatTensor(
                np.load('./basics/mean_grid_norm.npy')).cuda()
            pca_std = torch.FloatTensor(
                np.load('./basics/std_grid_norm.npy')).cuda()
            pca_pca = torch.FloatTensor(
                np.load('./basics/U_grid_norm.npy')[:, :6]).cuda()
        # NOTE(review): pca_mean / pca_std / pca_pca are only loaded for
        # 'grid', so the non-separate mode currently requires 'grid'.

        def project(x, mean, std, basis):
            """Standardise ``x`` and, when ``config.pca``, project onto ``basis``."""
            x = (x - mean.expand_as(x)) / std
            if config.pca:
                x = torch.mm(x, basis)
            return x

        def unproject(x):
            """Map whole-landmark model outputs back to flat numpy landmarks."""
            if config.pca:
                x = torch.mm(x.view(x.size(0), 6), pca_pca.t())
            else:
                # 204 = 68 points x 3 coordinates, matching reshape((68, 3)).
                x = x.view(x.size(0), 204)
            return (x * pca_std + pca_mean.expand_as(x)).data.cpu().numpy()

        def read_rgb_frame(video_path, frame_index):
            """Return frame ``frame_index`` of ``video_path`` as RGB, or None."""
            cap = cv2.VideoCapture(video_path)
            try:
                for t in range(frame_index + 1):
                    ret, frame = cap.read()
                    if t == frame_index:
                        return (cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                                if ret else None)
                return None
            finally:
                cap.release()

        num_steps_per_epoch = len(self.data_loader)
        t0 = time.time()
        logger = Logger(config.log_dir)

        if config.openrate_loss:
            # Lip-point index pairs whose (x, y) difference measures how far
            # the mouth is open.
            open_pair = [[i + 13, 19 - i] for i in range(3)]

        for epoch in range(self.start_epoch, config.max_epochs):
            for step, data in enumerate(self.data_loader):
                t1 = time.time()
                global_step = epoch * num_steps_per_epoch + step + 1

                sample_audio = data['audio']
                if not config.seperate:
                    target_lmark = data['target_lmark']
                    ex_lmark = data['ex_lmark']
                else:
                    lip_region = data['lip_region']
                    other_region = data['other_region']
                    ex_other_region = data['ex_other_region']
                    ex_lip_region = data['ex_lip_region']

                if config.cuda:
                    sample_audio = Variable(sample_audio.float()).cuda()
                    if not config.seperate:
                        target_lmark = target_lmark.float().cuda()
                        ex_lmark = ex_lmark.float().cuda()
                    else:
                        lip_region = lip_region.float().cuda()
                        other_region = other_region.float().cuda()
                        ex_other_region = ex_other_region.float().cuda()
                        ex_lip_region = ex_lip_region.float().cuda()

                if config.seperate:
                    lip_region_pca = Variable(
                        project(lip_region, lip_mean, lip_std, lip_pca))
                    ex_lip_region_pca = project(ex_lip_region, lip_mean,
                                                lip_std, lip_pca)
                    other_region_pca = Variable(
                        project(other_region, face_mean, face_std, face_pca))
                    ex_other_region_pca = Variable(
                        project(ex_other_region, face_mean, face_std,
                                face_pca))

                    fake_lip, fake_face = self.generator(
                        sample_audio, ex_other_region_pca, ex_lip_region_pca)

                    loss_lip = self.mse_loss_fn(fake_lip, lip_region_pca)
                    loss_face = self.mse_loss_fn(fake_face, other_region_pca)
                    loss = loss_lip + loss_face

                    if config.openrate_loss:
                        # NOTE(review): the slices reach column 59, so this
                        # presumably assumes the un-projected lip vector
                        # (config.pca off) — confirm.
                        real_open = torch.stack([
                            lip_region_pca[:, a * 3:a * 3 + 2] -
                            lip_region_pca[:, b * 3:b * 3 + 2]
                            for a, b in open_pair
                        ], 1)
                        fake_open = torch.stack([
                            fake_lip[:, a * 3:a * 3 + 2] -
                            fake_lip[:, b * 3:b * 3 + 2]
                            for a, b in open_pair
                        ], 1)
                        loss_open = self.mse_loss_fn(fake_open, real_open)
                        loss = loss + loss_open

                    loss.backward()
                    self.opt_g.step()
                    self._reset_gradients()

                    logger.scalar_summary('loss_lip', loss_lip, global_step)
                    logger.scalar_summary('loss_face', loss_face, global_step)
                    if config.openrate_loss:
                        logger.scalar_summary('loss_open', loss_open,
                                              global_step)

                    if step % 100 == 0:
                        print(
                            "[{}/{}:{}/{}]]   loss_face: {:.8f},loss_lip: {:.8f},data time: {:.4f},  model time: {} second"
                            .format(epoch + 1, config.max_epochs, step,
                                    num_steps_per_epoch, loss_face, loss_lip,
                                    t1 - t0,
                                    time.time() - t1))
                    t0 = time.time()
                else:
                    target_pca = Variable(
                        project(target_lmark, pca_mean, pca_std, pca_pca))
                    ex_pca = Variable(
                        project(ex_lmark, pca_mean, pca_std, pca_pca))

                    fake_lmark = self.generator(sample_audio, ex_pca)

                    loss = self.mse_loss_fn(fake_lmark, target_pca)

                    loss.backward()
                    self.opt_g.step()
                    self._reset_gradients()

                    # The open-rate loss is only computed in separate mode,
                    # so only the main loss is logged here.
                    logger.scalar_summary('loss', loss, global_step)

                    if step % 100 == 0:
                        print(
                            "[{}/{}:{}/{}]]   loss: {:.8f}, data time: {:.4f},  model time: {} second"
                            .format(epoch + 1, config.max_epochs, step,
                                    num_steps_per_epoch, loss, t1 - t0,
                                    time.time() - t1))
                    t0 = time.time()

                    # NOTE(review): rendering runs after every step; consider
                    # moving it (and the checkpoint below) to epoch end.
                    with torch.no_grad():
                        fake_np = unproject(fake_lmark)
                        real_np = unproject(target_pca)

                    for gg in range(int(config.batch_size / 50)):
                        if config.dataset_name == 'vox':
                            v_path = os.path.join(
                                '/data2/lchen63/voxceleb/unzip',
                                data['img_path'][gg] + '.mp4')
                            gt_img = read_rgb_frame(
                                v_path, int(data['sample_id'][gg]))
                        else:
                            gt_img = cv2.cvtColor(
                                cv2.imread(data['img_path'][gg]),
                                cv2.COLOR_BGR2RGB)
                        if gt_img is None:
                            # Never overlay onto a previous sample's frame.
                            print('Could not read the frame for {}; skipping '
                                  'visualization'.format(data['img_path'][gg]))
                            continue

                        plt = utils.lmark2img(fake_np[gg].reshape((68, 3)),
                                              gt_img)
                        plt.savefig("{}/fake_{}_{}.png".format(
                            config.sample_dir, epoch, gg))

                        plt = utils.lmark2img(real_np[gg].reshape((68, 3)),
                                              gt_img)
                        plt.savefig("{}/real_{}_{}.png".format(
                            config.sample_dir, epoch, gg))

                torch.save(self.generator.state_dict(),
                           "{}/anet2_single.pth".format(config.model_dir))
Пример #4
0
    def fit(self):
        """Train the audio-to-landmark generator with an MSE loss in PCA space.

        Each batch's sample and example landmarks are mean-centred with
        ``./basics/mean_front_smooth_vox.npy`` and projected onto the first six
        columns of ``./basics/U_front_smooth_vox.npy``. The generator predicts
        the sample's PCA coefficients from the audio and the example
        coefficients; the MSE against the target coefficients is
        back-propagated and one optimizer step is taken per batch.

        After every epoch, predictions and targets of the last batch are
        projected back to landmark space and, for the first
        ``batch_size // 20`` samples, drawn over the matching video frame into
        ``config.sample_dir``. The generator weights are then saved to
        ``{config.model_dir}/anet_single.pth``.
        """
        config = self.config
        pca = torch.FloatTensor(
            np.load('./basics/U_front_smooth_vox.npy')[:, :6]).cuda()
        mean = torch.FloatTensor(
            np.load('./basics/mean_front_smooth_vox.npy')).cuda()
        num_steps_per_epoch = len(self.data_loader)
        t0 = time.time()
        logger = Logger(config.log_dir)

        for epoch in range(self.start_epoch, config.max_epochs):
            for step, data in enumerate(self.data_loader):
                t1 = time.time()
                sample_audio = data['audio']
                sample_lmark = data['sample_lmark']
                ex_lmark = data['ex_lmark']

                if config.cuda:
                    sample_audio = Variable(sample_audio.float()).cuda()
                    sample_lmark = sample_lmark.float().cuda()
                    ex_lmark = ex_lmark.float().cuda()
                else:
                    # NOTE(review): `pca` and `mean` are always moved to the GPU
                    # above, so this CPU branch presumably cannot run as-is.
                    sample_audio = Variable(sample_audio.float())
                    sample_lmark = Variable(sample_lmark.float())
                    ex_lmark = Variable(ex_lmark.float())

                # Mean-centre and project target / reference landmarks to PCA.
                sample_lmark_pca = (sample_lmark -
                                    mean.expand_as(sample_lmark))
                sample_lmark_pca = torch.mm(sample_lmark_pca, pca)

                ex_lmark_pca = (ex_lmark - mean.expand_as(ex_lmark))
                ex_lmark_pca = torch.mm(ex_lmark_pca, pca)

                sample_lmark_pca = Variable(sample_lmark_pca)
                ex_lmark_pca = Variable(ex_lmark_pca)

                fake_lmark = self.generator(sample_audio, ex_lmark_pca)

                loss = self.mse_loss_fn(fake_lmark, sample_lmark_pca)
                loss.backward()
                self.opt_g.step()
                self._reset_gradients()

                logger.scalar_summary('loss', loss,
                                      epoch * num_steps_per_epoch + step + 1)

                if step % 100 == 0:
                    print(
                        "[{}/{}:{}/{}]]   loss: {:.8f},data time: {:.4f},  model time: {} second"
                        .format(epoch + 1, config.max_epochs, step,
                                num_steps_per_epoch, loss, t1 - t0,
                                time.time() - t1))
                t0 = time.time()

            # Map the last batch's PCA coefficients back to landmark space.
            fake_lmark = fake_lmark.view(fake_lmark.size(0), 6)
            fake_lmark = torch.mm(fake_lmark, pca.t())
            fake_lmark = fake_lmark + mean.expand_as(fake_lmark)

            fake_lmark = fake_lmark.data.cpu().numpy()

            real_lmark = sample_lmark_pca.view(sample_lmark_pca.size(0), 6)
            real_lmark = torch.mm(real_lmark, pca.t())
            real_lmark = real_lmark + mean.expand_as(real_lmark)

            real_lmark = real_lmark.data.cpu().numpy()

            # `//` keeps the count an int; with `/` Python 3 yields a float and
            # range() raises TypeError.
            for gg in range(int(config.batch_size) // 20):
                v_path = os.path.join(
                    '/home/cxu-serve/p1/lchen63/voxceleb/unzip',
                    data['img_path'][gg] + '.mp4')
                target = int(data['sample_id'][gg])

                # Decode frames up to and including index `target`. The
                # previous loop was bounded by the batch size, which left
                # gt_img undefined or stale whenever target >= batch size.
                gt_img = None
                cap = cv2.VideoCapture(v_path)
                try:
                    for t in range(target + 1):
                        ret, frame = cap.read()
                        if not ret:
                            break
                        if t == target:
                            gt_img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                finally:
                    cap.release()

                if gt_img is None:
                    print('could not read frame {} of {}; skipping'.format(
                        target, v_path))
                    continue

                lmark_name = "{}/fake_{}_{}.png".format(
                    config.sample_dir, epoch, gg)
                plt = utils.lmark2img(fake_lmark[gg].reshape((68, 3)), gt_img)
                plt.savefig(lmark_name)

                lmark_name = "{}/real_{}_{}.png".format(
                    config.sample_dir, epoch, gg)
                plt = utils.lmark2img(real_lmark[gg].reshape((68, 3)), gt_img)
                plt.savefig(lmark_name)

            torch.save(self.generator.state_dict(),
                       "{}/anet_single.pth".format(config.model_dir))
Пример #5
0
    def fit(self):
        """Train the LSTM generator that predicts lip and face landmark motion.

        Lip and "other" (face) regions are normalised as ``(x - mean) / std``
        with the statistics in ``./basics/`` and projected onto the first six
        PCA components of their respective bases. The generator receives the
        audio plus the PCA codes of the example face and lip regions and returns
        ``(fake_lip, fake_face)``; the loss is ``loss_lip + 0.001 * loss_face``.

        After every epoch the last batch's predictions and targets are mapped
        back to landmark space, concatenated as ``[face, lip]`` along dim 2, and
        for the first ``int(batch_size / 40)`` samples each of the
        ``time_length`` frames is drawn over the corresponding video frame in
        ``config.sample_dir``. Weights are saved to
        ``{config.model_dir}/anet_lstm.pth``.
        """
        time_length = 32
        config = self.config
        face_pca = torch.FloatTensor(np.load('./basics/U_front_roni.npy')[:, :6]).cuda()
        face_mean = torch.FloatTensor(np.load('./basics/mean_front_roni.npy')).cuda()
        lip_pca = torch.FloatTensor(np.load('./basics/U_front.npy')[:, :6]).cuda()
        lip_mean = torch.FloatTensor(np.load('./basics/mean_front.npy')).cuda()

        lip_std = torch.FloatTensor(np.load('./basics/std_front.npy')).cuda()
        face_std = torch.FloatTensor(np.load('./basics/std_front_roni.npy')).cuda()

        num_steps_per_epoch = len(self.data_loader)
        t0 = time.time()
        logger = Logger(config.log_dir)

        for epoch in range(self.start_epoch, config.max_epochs):
            for step, data in enumerate(self.data_loader):
                t1 = time.time()
                sample_audio = data['audio']
                lip_region = data['lip_region']
                other_region = data['other_region']
                ex_other_region = data['ex_other_region']
                ex_lip_region = data['ex_lip_region']

                if config.cuda:
                    sample_audio = Variable(sample_audio.float()).cuda()
                    lip_region = lip_region.float().cuda()
                    other_region = other_region.float().cuda()
                    ex_other_region = ex_other_region.float().cuda()
                    ex_lip_region = ex_lip_region.float().cuda()
                else:
                    sample_audio = Variable(sample_audio.float())
                    lip_region = Variable(lip_region.float())
                    other_region = Variable(other_region.float())
                    ex_other_region = Variable(ex_other_region.float())
                    ex_lip_region = ex_lip_region.float()

                # Use the real batch size so a trailing partial batch reshapes
                # correctly; identical to config.batch_size for full batches.
                cur_batch = lip_region.shape[0]

                # Per-frame lip regions: flatten (batch, time) -> PCA -> restore.
                lip_region_pca = (lip_region - lip_mean.expand_as(lip_region)) / lip_std
                lip_region_pca = lip_region_pca.view(lip_region_pca.shape[0] * lip_region_pca.shape[1], -1)
                lip_region_pca = torch.mm(lip_region_pca, lip_pca)
                lip_region_pca = lip_region_pca.view(cur_batch, time_length, 6)

                # Example (reference) regions are projected directly.
                ex_lip_region_pca = (ex_lip_region - lip_mean.expand_as(ex_lip_region)) / lip_std
                ex_lip_region_pca = torch.mm(ex_lip_region_pca, lip_pca)

                ex_other_region_pca = (ex_other_region - face_mean.expand_as(ex_other_region)) / face_std
                ex_other_region_pca = torch.mm(ex_other_region_pca, face_pca)

                other_region_pca = (other_region - face_mean.expand_as(other_region)) / face_std
                other_region_pca = other_region_pca.view(other_region_pca.shape[0] * other_region_pca.shape[1], -1)
                other_region_pca = torch.mm(other_region_pca, face_pca)
                other_region_pca = other_region_pca.view(cur_batch, time_length, 6)

                ex_other_region_pca = Variable(ex_other_region_pca)
                lip_region_pca = Variable(lip_region_pca)
                other_region_pca = Variable(other_region_pca)

                fake_lip, fake_face = self.generator(sample_audio, ex_other_region_pca, ex_lip_region_pca)

                loss_lip = self.mse_loss_fn(fake_lip, lip_region_pca)
                loss_face = self.mse_loss_fn(fake_face, other_region_pca)
                # Face motion is weighted down heavily relative to the lips.
                loss = loss_lip + 0.001 * loss_face
                loss.backward()
                self.opt_g.step()
                self._reset_gradients()

                logger.scalar_summary('loss_lip', loss_lip, epoch * num_steps_per_epoch + step + 1)
                logger.scalar_summary('loss_face', loss_face, epoch * num_steps_per_epoch + step + 1)

                if step % 100 == 0:
                    print("[{}/{}:{}/{}]]   loss_face: {:.8f},loss_lip: {:.8f},data time: {:.4f},  model time: {} second"
                          .format(epoch + 1, config.max_epochs, step, num_steps_per_epoch,
                                  loss_face, loss_lip, t1 - t0, time.time() - t1))
                t0 = time.time()

            # Undo PCA projection and normalisation for the last batch.
            fake_lip = fake_lip.view(cur_batch * time_length, 6)
            fake_lip = torch.mm(fake_lip, lip_pca.t())
            fake_lip = fake_lip.view(cur_batch, time_length, -1)
            fake_lip = fake_lip * lip_std + lip_mean.expand_as(fake_lip)

            fake_face = fake_face.view(cur_batch * time_length, 6)
            fake_face = torch.mm(fake_face, face_pca.t())
            fake_face = fake_face.view(cur_batch, time_length, -1)
            fake_face = fake_face * face_std + face_mean.expand_as(fake_face)
            fake_lmark = torch.cat([fake_face, fake_lip], 2)
            fake_lmark = fake_lmark.data.cpu().numpy()

            real_lip = lip_region_pca.view(cur_batch * time_length, 6)
            real_lip = torch.mm(real_lip, lip_pca.t())
            real_lip = real_lip.view(cur_batch, time_length, -1)
            real_lip = real_lip * lip_std + lip_mean.expand_as(real_lip)

            real_face = other_region_pca.view(cur_batch * time_length, 6)
            real_face = torch.mm(real_face, face_pca.t())
            real_face = real_face.view(cur_batch, time_length, -1)
            real_face = real_face * face_std + face_mean.expand_as(real_face)

            real_lmark = torch.cat([real_face, real_lip], 2)
            real_lmark = real_lmark.data.cpu().numpy()

            for gg in range(int(config.batch_size / 40)):
                v_path = os.path.join('/data2/lchen63/voxceleb/unzip', data['img_path'][gg] + '.mp4')
                target = int(data['sample_id'][gg])
                cap = cv2.VideoCapture(v_path)
                try:
                    # Advance to frame index `target`; `frame` is None if the
                    # video is shorter than that.
                    frame = None
                    for _ in range(target + 1):
                        ret, frame = cap.read()
                        if not ret:
                            frame = None
                            break

                    for g in range(time_length):
                        if frame is None:
                            # Video ran out of frames; stop drawing this sample
                            # instead of passing None to cv2.cvtColor.
                            break
                        gt_img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

                        lmark_name = "{}/fake_{}_{}_{}.png".format(config.sample_dir, epoch, gg, g)
                        plt = utils.lmark2img(fake_lmark[gg, g].reshape((68, 3)), gt_img)
                        plt.savefig(lmark_name)

                        lmark_name = "{}/real_{}_{}_{}.png".format(config.sample_dir, epoch, gg, g)
                        plt = utils.lmark2img(real_lmark[gg, g].reshape((68, 3)), gt_img)
                        plt.savefig(lmark_name)

                        ret, frame = cap.read()
                        if not ret:
                            frame = None
                finally:
                    # No HighGUI windows are opened here, so destroyAllWindows()
                    # is unnecessary (and raises on headless OpenCV builds).
                    cap.release()

            torch.save(self.generator.state_dict(),
                       "{}/anet_lstm.pth"
                       .format(config.model_dir))
Пример #6
0
def test(config):
    """Evaluate a trained ``atnet`` generator on Voxceleb landmark sequences.

    Loads weights from ``config.model_dir`` and runs the first 100 batches of
    ``Voxceleb_audio_lmark('/data2/lchen63/voxceleb/txt/', config.is_train)``
    through the generator with gradients disabled. PCA-space predictions and
    targets are mapped back to landmark space, and two metrics
    (``utils.mse_metrix`` and ``utils.openrate_metrix``) are accumulated:
    once after ``utils.reverse_rt`` with ``data['sample_rt']``, and once on the
    untransformed landmarks. Pairs whose reverse-rt MSE or open-rate error
    exceeds 50 are excluded from both. The four means are printed at the end.
    When ``config.visualize`` is set, overlays are written to
    ``config.sample_dir/<step>/``.

    Args:
        config: options object. Fields read here: ``cuda``, ``is_train``,
            ``batch_size``, ``num_thread``, ``model_dir``, ``log_dir``,
            ``pca``, ``visualize`` and ``sample_dir``.
    """
    generator = atnet()
    mse_loss_fn = nn.MSELoss()
    time_length = 32

    if config.cuda:
        generator = generator.cuda()
        mse_loss_fn = mse_loss_fn.cuda()

    initialize_weights(generator)

    dataset = Voxceleb_audio_lmark('/data2/lchen63/voxceleb/txt/',
                                   config.is_train)

    data_loader = DataLoader(dataset,
                             batch_size=config.batch_size,
                             num_workers=config.num_thread,
                             shuffle=False,
                             drop_last=True)

    face_pca = torch.FloatTensor(
        np.load('./basics/U_front_roni.npy')[:, :6]).cuda()
    face_mean = torch.FloatTensor(
        np.load('./basics/mean_front_roni.npy')).cuda()
    lip_pca = torch.FloatTensor(np.load('./basics/U_front.npy')[:, :6]).cuda()
    lip_mean = torch.FloatTensor(np.load('./basics/mean_front.npy')).cuda()
    lip_std = torch.FloatTensor(np.load('./basics/std_front.npy')).cuda()
    face_std = torch.FloatTensor(np.load('./basics/std_front_roni.npy')).cuda()

    generator.load_state_dict(torch.load(config.model_dir))
    generator.eval()

    logger = Logger(config.log_dir)
    total_mse = []
    total_openrate_mse = []

    std_total_mse = []
    std_total_openrate_mse = []

    # Evaluation only: skip building autograd graphs for every batch.
    with torch.no_grad():
        for step, data in enumerate(data_loader):
            print(step)
            if step == 100:
                break
            sample_audio = data['audio']
            lip_region = data['lip_region']
            other_region = data['other_region']
            ex_other_region = data['ex_other_region']
            ex_lip_region = data['ex_lip_region']

            sample_rt = data['sample_rt'].numpy()

            if config.cuda:
                sample_audio = Variable(sample_audio.float()).cuda()
                lip_region = lip_region.float().cuda()
                other_region = other_region.float().cuda()
                ex_other_region = ex_other_region.float().cuda()
                ex_lip_region = ex_lip_region.float().cuda()
            else:
                sample_audio = Variable(sample_audio.float())
                lip_region = Variable(lip_region.float())
                other_region = Variable(other_region.float())
                ex_other_region = Variable(ex_other_region.float())
                ex_lip_region = ex_lip_region.float()

            # Normalise and project per-frame lip regions to PCA space.
            lip_region_pca = (lip_region -
                              lip_mean.expand_as(lip_region)) / lip_std
            lip_region_pca = lip_region_pca.view(
                lip_region_pca.shape[0] * lip_region_pca.shape[1], -1)
            lip_region_pca = torch.mm(lip_region_pca, lip_pca)
            lip_region_pca = lip_region_pca.view(config.batch_size, time_length, 6)

            ex_lip_region_pca = (ex_lip_region -
                                 lip_mean.expand_as(ex_lip_region)) / lip_std
            ex_lip_region_pca = torch.mm(ex_lip_region_pca, lip_pca)

            ex_other_region_pca = (ex_other_region -
                                   face_mean.expand_as(ex_other_region)) / face_std
            ex_other_region_pca = torch.mm(ex_other_region_pca, face_pca)

            other_region_pca = (other_region -
                                face_mean.expand_as(other_region)) / face_std
            other_region_pca = other_region_pca.view(
                other_region_pca.shape[0] * other_region_pca.shape[1], -1)
            other_region_pca = torch.mm(other_region_pca, face_pca)
            other_region_pca = other_region_pca.view(config.batch_size,
                                                     time_length, 6)

            ex_other_region_pca = Variable(ex_other_region_pca)
            lip_region_pca = Variable(lip_region_pca)
            other_region_pca = Variable(other_region_pca)

            fake_lip, fake_face = generator(sample_audio, ex_other_region_pca,
                                            ex_lip_region_pca)

            loss_lip = mse_loss_fn(fake_lip, lip_region_pca)
            loss_face = mse_loss_fn(fake_face, other_region_pca)

            logger.scalar_summary('loss_lip', loss_lip, step + 1)
            logger.scalar_summary('loss_face', loss_face, step + 1)

            # Flatten (batch, time) so each row is one frame.
            fake_lip = fake_lip.view(fake_lip.size(0) * time_length, 6)
            fake_face = fake_face.view(fake_face.size(0) * time_length, 6)
            lip_region_pca = lip_region_pca.view(
                lip_region_pca.size(0) * time_length, 6)
            other_region_pca = other_region_pca.view(
                other_region_pca.size(0) * time_length, 6)

            # NOTE(review): inputs are always PCA-projected above, so the
            # `not config.pca` branches below cannot work as written (an
            # (N, 6) tensor cannot be viewed as (N, 60) / (N, 144)); confirm
            # whether non-PCA evaluation is still supported.
            if config.pca:
                fake_lip = fake_lip.view(fake_lip.size(0), 6)
                fake_lip = torch.mm(fake_lip, lip_pca.t())
            else:
                fake_lip = fake_lip.view(fake_lip.size(0), 60)
            fake_lip = fake_lip * lip_std + lip_mean.expand_as(fake_lip)
            if config.pca:
                fake_face = fake_face.view(fake_face.size(0), 6)
                fake_face = torch.mm(fake_face, face_pca.t())
            else:
                fake_face = fake_face.view(fake_face.size(0), 144)
            fake_face = fake_face * face_std + face_mean.expand_as(fake_face)

            fake_lmark = torch.cat([fake_face, fake_lip], 1)
            fake_lmark = fake_lmark.data.cpu().numpy()

            if config.pca:
                real_lip = lip_region_pca.view(lip_region_pca.size(0), 6)
                real_lip = torch.mm(real_lip, lip_pca.t())
            else:
                real_lip = lip_region_pca.view(lip_region_pca.size(0), 60)
            real_lip = real_lip * lip_std + lip_mean.expand_as(real_lip)
            if config.pca:
                real_face = other_region_pca.view(other_region_pca.size(0), 6)
                real_face = torch.mm(real_face, face_pca.t())
            else:
                real_face = other_region_pca.view(other_region_pca.size(0), 144)
            real_face = real_face * face_std + face_mean.expand_as(real_face)

            real_lmark = torch.cat([real_face, real_lip], 1)
            real_lmark = real_lmark.data.cpu().numpy()

            step_dir = os.path.join(config.sample_dir, str(step))
            if not os.path.exists(step_dir):
                os.makedirs(step_dir)

            # NOTE(review): fake_lmark / real_lmark have batch_size * time_length
            # rows (row = sample * time_length + frame), but `gg` below only
            # spans batch_size and is paired with per-sample fields
            # (sample_rt, img_path, sample_id); confirm this pairing is intended.
            for gg in range(int(config.batch_size)):

                # convert lmark to rted landmark
                fake_A3 = utils.reverse_rt(fake_lmark[gg].reshape((68, 3)),
                                           sample_rt[gg])
                A3 = utils.reverse_rt(real_lmark[gg].reshape((68, 3)),
                                      sample_rt[gg])
                if config.visualize:
                    v_path = os.path.join('/data2/lchen63/voxceleb/unzip',
                                          data['img_path'][gg] + '.mp4')
                    target = int(data['sample_id'][gg])

                    # Decode up to frame index `target`; the old bound (number
                    # of rows) was unrelated to the index and could leave
                    # gt_img undefined or stale from a previous sample.
                    gt_img = None
                    cap = cv2.VideoCapture(v_path)
                    try:
                        for t in range(target + 1):
                            ret, frame = cap.read()
                            if not ret:
                                break
                            if t == target:
                                gt_img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    finally:
                        cap.release()

                    if gt_img is None:
                        # Skip only the images; metrics below are still recorded.
                        print('could not read frame {} of {}; skipping visualization'
                              .format(target, v_path))
                    else:
                        plt = utils.lmark2img(fake_lmark[gg].reshape((68, 3)), gt_img)
                        plt.savefig("{}/fake_{}.png".format(step_dir, gg))

                        plt = utils.lmark2img(fake_A3.reshape((68, 3)), gt_img)
                        plt.savefig("{}/fake_rt_{}.png".format(step_dir, gg))

                        plt = utils.lmark2img(real_lmark[gg].reshape((68, 3)), gt_img)
                        plt.savefig("{}/real_{}.png".format(step_dir, gg))

                        plt = utils.lmark2img(A3.reshape((68, 3)), gt_img)
                        plt.savefig("{}/real_rt_{}.png".format(step_dir, gg))

                # Metrics on reverse-rt landmarks; outliers (> 50) are dropped
                # from both metric families.
                mse = utils.mse_metrix(A3, fake_A3)
                openrate = utils.openrate_metrix(A3, fake_A3)
                if mse > 50 or openrate > 50:
                    continue
                total_openrate_mse.append(openrate)
                total_mse.append(mse)

                std_mse = utils.mse_metrix(real_lmark[gg].reshape((68, 3)),
                                           fake_lmark[gg].reshape((68, 3)))
                std_openrate = utils.openrate_metrix(
                    real_lmark[gg].reshape((68, 3)), fake_lmark[gg].reshape(
                        (68, 3)))
                std_total_openrate_mse.append(std_openrate)
                std_total_mse.append(std_mse)

                print('{}/{}: mse: {}:openrate: {}'.format(step, gg, mse,
                                                           openrate))

    total_mse = np.asarray(total_mse)
    total_openrate_mse = np.asarray(total_openrate_mse)

    std_total_mse = np.asarray(std_total_mse)
    std_total_openrate_mse = np.asarray(std_total_openrate_mse)

    print(total_mse.mean())
    print(total_openrate_mse.mean())
    print(std_total_mse.mean())
    print(std_total_openrate_mse.mean())