Example #1
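# Example #1: Trainer construction for a landmark-driven talking-face GAN.
# It builds VG_net (generator) and GL_Discriminator, BCE/L1/MSE losses,
# optional multi-GPU DataParallel placement, Adam optimizers, and the LRW
# data loader.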
class Trainer():
    def __init__(self, config):

        self.generator = VG_net()
        self.discriminator = GL_Discriminator()
        self.bce_loss_fn = nn.BCELoss()
        self.l1_loss_fn = nn.L1Loss()
        self.mse_loss_fn = nn.MSELoss()
        self.config = config
        # real/fake target labels for the adversarial BCE loss
        self.ones = Variable(torch.ones(config.batch_size),
                             requires_grad=False)
        self.zeros = Variable(torch.zeros(config.batch_size),
                              requires_grad=False)

        if config.cuda:
            device_ids = [int(i) for i in config.device_ids.split(',')]
            # self.encoder = nn.DataParallel(self.encoder.cuda(device=config.cuda1), device_ids=device_ids)
            self.generator = nn.DataParallel(self.generator,
                                             device_ids=device_ids).cuda()
            self.discriminator = nn.DataParallel(self.discriminator,
                                                 device_ids=device_ids).cuda()

            self.bce_loss_fn = self.bce_loss_fn.cuda(device=config.cuda1)
            self.mse_loss_fn = self.mse_loss_fn.cuda(device=config.cuda1)
            self.l1_loss_fn = self.l1_loss_fn.cuda(device=config.cuda1)
            self.ones = self.ones.cuda(device=config.cuda1)
            self.zeros = self.zeros.cuda(device=config.cuda1)


# #########single GPU#######################

#         if config.cuda:
#             device_ids = [int(i) for i in config.device_ids.split(',')]
#             self.generator     = self.generator.cuda(device=config.cuda1)
#             self.encoder = self.encoder.cuda(device=config.cuda1)
#             self.mse_loss_fn   = self.mse_loss_fn.cuda(device=config.cuda1)
#             self.l1_loss_fn =  nn.L1Loss().cuda(device=config.cuda1)
        initialize_weights(self.generator)
        self.start_epoch = 0
        if config.load_model:
            self.start_epoch = config.start_epoch
            self.load(config.pretrained_dir, config.pretrained_epoch)
        print('-----------')

        self.opt_g = torch.optim.Adam(self.generator.parameters(),
                                      lr=config.lr,
                                      betas=(config.beta1, config.beta2))
        self.opt_d = torch.optim.Adam(self.discriminator.parameters(),
                                      lr=config.lr,
                                      betas=(config.beta1, config.beta2))
        self.dataset = LRWdataset1D_lstm_gt(config.dataset_dir,
                                            train=config.is_train)

        self.data_loader = DataLoader(self.dataset,
                                      batch_size=config.batch_size,
                                      num_workers=config.num_thread,
                                      shuffle=True,
                                      drop_last=True)
        data_iter = iter(self.data_loader)
        next(data_iter)
Example #2
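# Example #2: the full Trainer class: __init__ (same setup as Example #1),
# fit() with alternating discriminator/generator updates and a mouth-weighted
# landmark loss, and load() for restoring pretrained checkpoints.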
class Trainer():
    def __init__(self, config):

        self.generator = VG_net()
        self.discriminator = GL_Discriminator()
        self.bce_loss_fn = nn.BCELoss()
        self.l1_loss_fn = nn.L1Loss()
        self.mse_loss_fn = nn.MSELoss()
        self.config = config
        self.ones = Variable(torch.ones(config.batch_size), requires_grad=False)
        self.zeros = Variable(torch.zeros(config.batch_size), requires_grad=False)

        if config.cuda:
            device_ids = [int(i) for i in config.device_ids.split(',')]
            self.generator = nn.DataParallel(self.generator, device_ids=device_ids).cuda()
            self.discriminator = nn.DataParallel(self.discriminator, device_ids=device_ids).cuda()

            self.bce_loss_fn = self.bce_loss_fn.cuda(device=config.cuda1)
            self.mse_loss_fn = self.mse_loss_fn.cuda(device=config.cuda1)
            self.l1_loss_fn = self.l1_loss_fn.cuda(device=config.cuda1)
            self.ones = self.ones.cuda(device=config.cuda1)
            self.zeros = self.zeros.cuda(device=config.cuda1)
        initialize_weights(self.generator)
        self.start_epoch = 0
        if config.load_model:
            self.start_epoch = config.start_epoch
            self.load(config.pretrained_dir, config.pretrained_epoch)
        print('-----------')

        self.opt_g = torch.optim.Adam(self.generator.parameters(),
                                      lr=config.lr, betas=(config.beta1, config.beta2))
        self.opt_d = torch.optim.Adam(self.discriminator.parameters(),
                                      lr=config.lr, betas=(config.beta1, config.beta2))
        self.dataset = LRWdataset1D_lstm_gt(config.dataset_dir, train=config.is_train)
        
        self.data_loader = DataLoader(self.dataset,
                                      batch_size=config.batch_size,
                                      num_workers=config.num_thread,
                                      shuffle=True, drop_last=True)
        data_iter = iter(self.data_loader)
        next(data_iter)

    def fit(self):

        config = self.config

        # "musk" weights the flattened 136-D landmark vector (68 points x 2);
        # entries 97, 99, ..., 135 (20 coordinates falling on the mouth region
        # of the standard 68-point layout) are upweighted 100x so the landmark
        # losses focus on the lips.
        musk = torch.FloatTensor(1, 136)
        musk[:, :] = 1.0
        for jj in range(0, 40, 2):
            musk[:, 97 + jj] = 100.0
        musk = musk.unsqueeze(0)

        num_steps_per_epoch = len(self.data_loader)
        cc = 0
        t0 = time.time()

        for epoch in range(self.start_epoch, config.max_epochs):
            for step, (example_img, example_landmark, right_img, right_landmark) in enumerate(self.data_loader):
                t1 = time.time()

                if config.cuda:
                    example_img = Variable(example_img.float()).cuda(device=config.cuda1)
                    example_landmark = Variable(example_landmark.float()).cuda(device=config.cuda1)
                    right_img = Variable(right_img.float()).cuda(device=config.cuda1)
                    right_landmark = Variable(right_landmark.float()).cuda(device=config.cuda1)
                    musk = Variable(musk.float()).cuda(device=config.cuda1)
                else:
                    example_img = Variable(example_img.float())
                    example_landmark = Variable(example_landmark.float())
                    right_img = Variable(right_img.float())
                    right_landmark = Variable(right_landmark.float())
                    musk = Variable(musk.float())
                # ---- Discriminator update ----
                for p in self.discriminator.parameters():
                    p.requires_grad = True

                fake_im, _, _, _ = self.generator(example_img, right_landmark, example_landmark)
                D_real, d_r_lmark = self.discriminator(right_img, example_landmark)
                loss_real = self.bce_loss_fn(D_real, self.ones)

                loss_lmark_r = self.mse_loss_fn(d_r_lmark * musk, right_landmark * musk)
                D_fake, d_f_lmark = self.discriminator(fake_im.detach(), example_landmark)
                loss_fake = self.bce_loss_fn(D_fake, self.zeros)
                loss_lmark_f = self.mse_loss_fn(d_f_lmark * musk, right_landmark * musk)

                loss_disc = loss_real + loss_fake + loss_lmark_f + loss_lmark_r

                loss_disc.backward()
                self.opt_d.step()
                self._reset_gradients()

                # ---- Generator update (discriminator frozen) ----
                for p in self.discriminator.parameters():
                    p.requires_grad = False

                fake_im, att, colors, _ = self.generator(example_img, right_landmark, example_landmark)
                D_fake, d_f_lmark = self.discriminator(fake_im.detach(), example_landmark)

                loss_lmark = self.mse_loss_fn(d_f_lmark * musk, right_landmark * musk)
                loss_gen = self.bce_loss_fn(D_fake, self.ones)

                # pixel loss weighted by the (detached) attention map
                loss_musk = att.detach() + 0.5
                diff = torch.abs(fake_im - right_img) * loss_musk
                loss_pix = torch.mean(diff)
                loss = 10 * loss_pix + loss_gen + loss_lmark
                loss.backward() 
                self.opt_g.step()
                self._reset_gradients()
                t2 = time.time()

                if (step + 1) % 10 == 0 or (step + 1) == num_steps_per_epoch:
                    print("[{}/{}][{}/{}], loss_disc: {:.8f}, loss_lmark_f: {:.8f}, loss_lmark_r: {:.8f}, loss_gen: {:.8f}, loss_pix: {:.8f}, loss_lmark: {:.8f}, data time: {:.4f}, model time: {:.4f} seconds"
                          .format(epoch + 1, config.max_epochs,
                                  step + 1, num_steps_per_epoch,
                                  loss_disc.item(), loss_lmark_f.item(),
                                  loss_lmark_r.item(), loss_gen.item(),
                                  loss_pix.item(), loss_lmark.item(),
                                  t1 - t0, t2 - t1))

                # periodically dump sample images and a generator checkpoint
                # (about 20 times per epoch)
                if step % max(1, num_steps_per_epoch // 20) == 0:
                    atts_store = att.data.contiguous().view(config.batch_size * 16, 1, 128, 128)
                    colors_store = colors.data.contiguous().view(config.batch_size * 16, 3, 128, 128)
                    fake_store = fake_im.data.contiguous().view(config.batch_size * 16, 3, 128, 128)
                    torchvision.utils.save_image(atts_store,
                        "{}att_{}.png".format(config.sample_dir, cc), normalize=True)
                    torchvision.utils.save_image(colors_store,
                        "{}color_{}.png".format(config.sample_dir, cc), normalize=True)
                    torchvision.utils.save_image(fake_store,
                        "{}imf_fake_{}.png".format(config.sample_dir, cc), normalize=True)
                    real_store = right_img.data.contiguous().view(config.batch_size * 16, 3, 128, 128)
                    torchvision.utils.save_image(real_store,
                        "{}img_real_{}.png".format(config.sample_dir, cc), normalize=True)
                    cc += 1
                    torch.save(self.generator.state_dict(),
                               "{}/vg_net_{}.pth".format(config.model_dir, cc))
                 
                t0 = time.time()

    def load(self, directory, epoch):
        gen_path = os.path.join(directory, 'gen_{}.pth'.format(epoch))

        self.generator.load_state_dict(torch.load(gen_path))

        dis_path = os.path.join(directory, 'dis_{}.pth'.format(epoch))

        self.discriminator.load_state_dict(torch.load(dis_path))


    def _reset_gradients(self):
        self.generator.zero_grad()
        self.discriminator.zero_grad()
Example #3
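# Example #3: single-clip demo. Extracts MFCCs from an input audio file,
# predicts PCA-space landmarks with AT_net, renders frames with VG_net, and
# assembles the frames and audio into a result video.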
def test():
    os.environ["CUDA_VISIBLE_DEVICES"] = config.device_ids
    if os.path.exists('../temp'):
        shutil.rmtree('../temp')
    os.mkdir('../temp')
    os.mkdir('../temp/img')
    os.mkdir('../temp/motion')
    os.mkdir('../temp/attention')
    # 6-component PCA basis and mean of the 68x2 landmark space
    pca = torch.FloatTensor(np.load('../basics/U_lrw1.npy')[:, :6]).cuda()
    mean = torch.FloatTensor(np.load('../basics/mean_lrw1.npy')).cuda()
    decoder = VG_net()
    encoder = AT_net()
    if config.cuda:
        encoder = encoder.cuda()
        decoder = decoder.cuda()
    state_dict2 = multi2single(config.vg_model, 1)

    # state_dict2 = torch.load(config.video_model, map_location=lambda storage, loc: storage)
    decoder.load_state_dict(state_dict2)

    state_dict = multi2single(config.at_model, 1)
    encoder.load_state_dict(state_dict)

    encoder.eval()
    decoder.eval()
    test_file = config.in_file

    example_image, example_landmark = generator_demo_example_lips(config.person)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])
    example_image = cv2.cvtColor(example_image, cv2.COLOR_BGR2RGB)
    example_image = transform(example_image)

    example_landmark = example_landmark.reshape((1, example_landmark.shape[0] * example_landmark.shape[1]))

    if config.cuda:
        example_image = Variable(example_image.view(1, 3, 128, 128)).cuda()
        example_landmark = Variable(torch.FloatTensor(example_landmark.astype(float))).cuda()
    else:
        example_image = Variable(example_image.view(1, 3, 128, 128))
        example_landmark = Variable(torch.FloatTensor(example_landmark.astype(float)))
    # project the example landmark into the 6-D PCA space used by AT_net
    example_landmark = example_landmark * 5.0
    example_landmark = example_landmark - mean.expand_as(example_landmark)
    example_landmark = torch.mm(example_landmark, pca)
    # load speech and extract MFCC features, padding 1920 samples (0.12 s at
    # 16 kHz) of silence on both ends
    speech, sr = librosa.load(test_file, sr=16000)
    speech = np.insert(speech, 0, np.zeros(1920))
    speech = np.append(speech, np.zeros(1920))
    mfcc = python_speech_features.mfcc(speech, 16000, winstep=0.01)

    sound, _ = librosa.load(test_file, sr=44100)

    print('=======================================')
    print('Start to generate images')
    t = time.time()
    ind = 3
    with torch.no_grad(): 
        fake_lmark = []
        input_mfcc = []
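        # slide a 28-frame (0.28 s) MFCC window over the audio, advancing
        # 4 frames (0.04 s, one video frame at 25 fps) per step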
        while ind <= int(mfcc.shape[0] / 4) - 4:
            t_mfcc = mfcc[(ind - 3) * 4:(ind + 4) * 4, 1:]
            t_mfcc = torch.FloatTensor(t_mfcc).cuda()
            input_mfcc.append(t_mfcc)
            ind += 1
        input_mfcc = torch.stack(input_mfcc, dim=0)
        input_mfcc = input_mfcc.unsqueeze(0)
        fake_lmark = encoder(example_landmark, input_mfcc)
        fake_lmark = fake_lmark.view(fake_lmark.size(0) * fake_lmark.size(1), 6)
        # map the example landmark back from PCA space to the 136-D landmark space
        example_landmark = torch.mm(example_landmark, pca.t())
        example_landmark = example_landmark + mean.expand_as(example_landmark)
        # scale up PCA components 1-5 before projecting back to landmark space
        fake_lmark[:, 1:6] *= 2 * torch.FloatTensor(np.array([1.1, 1.2, 1.3, 1.4, 1.5])).cuda()
        fake_lmark = torch.mm(fake_lmark, pca.t())
        fake_lmark = fake_lmark + mean.expand_as(fake_lmark)
    

        fake_lmark = fake_lmark.unsqueeze(0) 

        fake_ims, atts, ms, _ = decoder(example_image, fake_lmark, example_landmark)

        for indx in range(fake_ims.size(1)):
            fake_im = fake_ims[:, indx]
            fake_store = fake_im.permute(0, 2, 3, 1).data.cpu().numpy()[0]
            scipy.misc.imsave("{}/{:05d}.png".format(os.path.join('../', 'temp', 'img'), indx), fake_store)
            m = ms[:, indx]
            att = atts[:, indx]
            m = m.permute(0, 2, 3, 1).data.cpu().numpy()[0]
            att = att.data.cpu().numpy()[0, 0]

            scipy.misc.imsave("{}/{:05d}.png".format(os.path.join('../', 'temp', 'motion'), indx), m)
            scipy.misc.imsave("{}/{:05d}.png".format(os.path.join('../', 'temp', 'attention'), indx), att)

        print('In total, generated {:d} images in {:.3f} seconds'.format(fake_ims.size(1), time.time() - t))
            
        fake_lmark = fake_lmark.data.cpu().numpy()
        np.save(os.path.join(config.sample_dir, 'obama_fake.npy'), fake_lmark)
        fake_lmark = np.reshape(fake_lmark, (fake_lmark.shape[1], 68, 2))
        utils.write_video_wpts_wsound(fake_lmark, sound, 44100, config.sample_dir, 'fake', [-1.0, 1.0], [-1.0, 1.0])
        video_name = os.path.join(config.sample_dir , 'results.mp4')
        utils.image_to_video(os.path.join('../', 'temp', 'img'), video_name )
        utils.add_audio(video_name, config.in_file)
        print('The generated video is: {}'.format(os.path.join(config.sample_dir, 'results.mov')))
Example #4
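# Example #4: batch evaluation. Runs a pretrained VG_net over the LRW test
# split and dumps fake/real frames plus attention and color maps.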
def test():
    os.environ["CUDA_VISIBLE_DEVICES"] = config.device_ids
    generator = VG_net()
    if config.cuda:
        generator = generator.cuda()

    state_dict = multi2single(config.model_name, 1)
    generator.load_state_dict(state_dict)
    print('load pretrained [{}]'.format(config.model_name))
    generator.eval()
    dataset = LRWdataset(config.dataset_dir, train='test')
    data_loader = DataLoader(dataset,
                             batch_size=config.batch_size,
                             num_workers=config.num_thread,
                             shuffle=True,
                             drop_last=True)
    data_iter = iter(data_loader)
    next(data_iter)  # pull and discard one batch

    if not os.path.exists(config.sample_dir):
        os.mkdir(config.sample_dir)
    if not os.path.exists(os.path.join(config.sample_dir, 'fake')):
        os.mkdir(os.path.join(config.sample_dir, 'fake'))
    if not os.path.exists(os.path.join(config.sample_dir, 'real')):
        os.mkdir(os.path.join(config.sample_dir, 'real'))
    if not os.path.exists(os.path.join(config.sample_dir, 'color')):
        os.mkdir(os.path.join(config.sample_dir, 'color'))
    if not os.path.exists(os.path.join(config.sample_dir, 'single')):
        os.mkdir(os.path.join(config.sample_dir, 'single'))
    for step, (example_img, example_landmark, real_im,
               landmarks) in enumerate(data_loader):
        with torch.no_grad():
            if step == 43:
                break
            if config.cuda:
                real_im = Variable(real_im.float()).cuda()
                example_img = Variable(example_img.float()).cuda()
                landmarks = Variable(landmarks.float()).cuda()
                example_landmark = Variable(example_landmark.float()).cuda()
            fake_im, atts, colors, lmark_att = generator(
                example_img, landmarks, example_landmark)

            fake_store = fake_im.data.contiguous().view(
                config.batch_size * 16, 3, 128, 128)
            real_store = real_im.data.contiguous().view(
                config.batch_size * 16, 3, 128, 128)
            atts_store = atts.data.contiguous().view(config.batch_size * 16, 1,
                                                     128, 128)
            colors_store = colors.data.contiguous().view(
                config.batch_size * 16, 3, 128, 128)
            lmark_att = bilinear_interpolate_torch_gridsample(lmark_att)
            lmark_att_store = lmark_att.data.contiguous().view(
                config.batch_size * 16, 1, 128, 128)
            torchvision.utils.save_image(atts_store,
                                         "{}color/att_{}.png".format(
                                             config.sample_dir, step),
                                         normalize=False,
                                         range=[0, 1])
            torchvision.utils.save_image(lmark_att_store,
                                         "{}color/lmark_att_{}.png".format(
                                             config.sample_dir, step),
                                         normalize=False,
                                         range=[0, 1])
            torchvision.utils.save_image(colors_store,
                                         "{}color/color_{}.png".format(
                                             config.sample_dir, step),
                                         normalize=True)
            torchvision.utils.save_image(fake_store,
                                         "{}fake/fake_{}.png".format(
                                             config.sample_dir, step),
                                         normalize=True)

            torchvision.utils.save_image(real_store,
                                         "{}real/real_{}.png".format(
                                             config.sample_dir, step),
                                         normalize=True)

            fake_store = fake_im.data.cpu().permute(0, 1, 3, 4, 2).view(
                config.batch_size * 16, 128, 128, 3).numpy()
            real_store = real_im.data.cpu().permute(0, 1, 3, 4, 2).view(
                config.batch_size * 16, 128, 128, 3).numpy()

            print(step)
            print(fake_store.shape)
            for inx in range(config.batch_size * 16):
                scipy.misc.imsave(
                    "{}single/fake_{}.png".format(
                        config.sample_dir,
                        step * config.batch_size * 16 + inx), fake_store[inx])
                scipy.misc.imsave(
                    "{}single/real_{}.png".format(
                        config.sample_dir,
                        step * config.batch_size * 16 + inx), real_store[inx])
Example #5
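# Example #5: news-to-video pipeline. Takes text (or scrapes it from a news
# URL), synthesizes speech with a TTS service, predicts landmarks from the
# audio with AT_net, renders a landmark motion video via ffmpeg, and muxes in
# the synthesized audio.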
def test():
    os.environ["CUDA_VISIBLE_DEVICES"] = config.device_ids

    result_dir = 'temp/' + config.in_file
    motion_dir = result_dir + '/motion/'

    os.mkdir(result_dir)
    os.mkdir(motion_dir)

    # 6-component PCA basis and mean of the landmark space
    pca = torch.FloatTensor(np.load('basics/pca.npy')[:, :6])
    mean = torch.FloatTensor(np.load('basics/mean.npy'))
    decoder = VG_net()
    encoder = AT_net()
    if config.cuda:
        # keep the models and the PCA basis on the same device as the CUDA
        # inputs created below
        encoder = encoder.cuda()
        decoder = decoder.cuda()
        pca = pca.cuda()
        mean = mean.cuda()

    state_dict2 = multi2single(config.vg_model, 1)

    decoder.load_state_dict(state_dict2)

    state_dict = multi2single(config.at_model, 1)
    encoder.load_state_dict(state_dict)

    encoder.eval()
    decoder.eval()
    test_file = result_dir + "/" + config.in_file + ".wav"
    test_file_old = result_dir + "/old_" + config.in_file + ".wav"
    if config.text_tts == "" and config.news_url != "":
        parse_news_content = get_info(config.news_url)['news_content']
    else:
        parse_news_content = config.text_tts
    tts = TTS(config.name_tts,
              "wav",
              "000000-0000-0000-0000-00000000",
              config.lang_tts,
              emotion="neutral",
              speed=1)
    # synthesize speech; the request is truncated to the first 1999 characters
    tts.generate(parse_news_content[:1999])
    if config.shift == 1:
        tts.save(test_file_old)
        audio_shift(test_file_old, test_file)
    else:
        tts.save(test_file)

    example_image, example_landmark = generator_demo_example_lips(
        config.person)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])
    example_image = cv2.cvtColor(example_image, cv2.COLOR_BGR2RGB)
    example_image = transform(example_image)

    example_landmark = example_landmark.reshape(
        (1, example_landmark.shape[0] * example_landmark.shape[1]))

    if config.cuda:
        example_image = Variable(example_image.view(1, 3, 128, 128)).cuda()
        example_landmark = Variable(
            torch.FloatTensor(example_landmark.astype(float))).cuda()
    else:
        example_image = Variable(example_image.view(1, 3, 128, 128))
        example_landmark = Variable(
            torch.FloatTensor(example_landmark.astype(float)))
    example_landmark = example_landmark * 5.0
    example_landmark = example_landmark - mean.expand_as(example_landmark)
    example_landmark = torch.mm(example_landmark, pca)
    speech, sr = librosa.load(test_file, sr=16000)
    speech = np.insert(speech, 0, np.zeros(1920))
    speech = np.append(speech, np.zeros(1920))
    mfcc = python_speech_features.mfcc(speech, 16000, winstep=0.01)

    sound, _ = librosa.load(test_file, sr=44100)

    print('=======================================')
    print('Generate images')
    t = time.time()
    ind = 3
    with torch.no_grad():
        fake_lmark = []
        input_mfcc = []
        while ind <= int(mfcc.shape[0] / 4) - 4:
            t_mfcc = mfcc[(ind - 3) * 4:(ind + 4) * 4, 1:]
            t_mfcc = torch.FloatTensor(t_mfcc)
            if config.cuda:
                t_mfcc = t_mfcc.cuda()
            input_mfcc.append(t_mfcc)
            ind += 1
        input_mfcc = torch.stack(input_mfcc, dim=0)
        input_mfcc = input_mfcc.unsqueeze(0)
        fake_lmark = encoder(example_landmark, input_mfcc)
        fake_lmark = fake_lmark.view(
            fake_lmark.size(0) * fake_lmark.size(1), 6)
        example_landmark = torch.mm(example_landmark, pca.t())
        example_landmark = example_landmark + mean.expand_as(example_landmark)
        fake_lmark[:, 1:6] *= 2 * torch.FloatTensor(
            np.array([1.1, 1.2, 1.3, 1.4, 1.5])).to(fake_lmark.device)
        fake_lmark = torch.mm(fake_lmark, pca.t())
        fake_lmark = fake_lmark + mean.expand_as(fake_lmark)

        fake_lmark = fake_lmark.unsqueeze(0)

        fake_lmark = fake_lmark.data.cpu().numpy()

        file_mark = result_dir + "/" + config.in_file + ".npy"
        file_mp4 = result_dir + "/" + config.in_file  # + ".mp4"
        np.save(file_mark, fake_lmark)
        mark_paint.mark_video(fake_lmark, motion_dir)

        cmd = 'ffmpeg -framerate 25 -i ' + motion_dir + '%d.png  -filter:v scale=512:-1 -c:v libx264 -pix_fmt yuv420p ' + file_mp4 + '.mp4'
        subprocess.call(cmd, shell=True)
        print('video done')

        cmd = 'ffmpeg -i ' + file_mp4 + '.mp4 -i ' + test_file + ' -c:v copy -c:a aac -strict experimental ' + file_mp4 + '_result.mp4'
        subprocess.call(cmd, shell=True)
        print('video+audio done')

        return file_mark
Example #6
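# Example #6: batch generation on VoxCeleb2. Reads (video, audio, start, end)
# rows from a CSV, extracts an example frame and landmarks per clip, predicts
# landmarks for the selected audio segment, renders frames with VG_net, and
# writes one video per clip.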
def test():
    data_root = '/home/cxu-serve/p1/common/voxceleb2/unzip/test_video/'
    #data_root = '/home/cxu-serve/p1/common/lrs3/lrs3_v0.4/'
    audios = []
    videos = []
    start_ids = []
    end_ids = []
    with open(
            '/home/cxu-serve/p1/common/degree/degree_store/vox/new_extra_data.csv',
            'r') as csvfile:
        reader = csv.reader(csvfile)
        for row in reader:
            #print (row)
            audios.append(row[1])
            videos.append(row[0])
            start_ids.append(int(row[2]))
            end_ids.append(int(row[3]))
            #audios.append(os.path.join(data_root, 'test',tmp[1], tmp[2] + '.wav'))
            #videos.append(os.path.join(data_root, 'test', tmp[1], tmp[2] + '_crop.mp4'))
    os.environ["CUDA_VISIBLE_DEVICES"] = config.device_ids
    if os.path.exists('../temp'):
        shutil.rmtree('../temp')
    os.mkdir('../temp')
    pca = torch.FloatTensor(np.load('../basics/U_lrw1.npy')[:, :6]).cuda()
    mean = torch.FloatTensor(np.load('../basics/mean_lrw1.npy')).cuda()
    decoder = VG_net()
    encoder = AT_net()
    if config.cuda:
        encoder = encoder.cuda()
        decoder = decoder.cuda()
    state_dict2 = multi2single(config.vg_model, 1)

    # state_dict2 = torch.load(config.video_model, map_location=lambda storage, loc: storage)
    decoder.load_state_dict(state_dict2)

    state_dict = multi2single(config.at_model, 1)
    encoder.load_state_dict(state_dict)

    encoder.eval()
    decoder.eval()

    # vox
    # # get file paths
    # path = '/home/cxu-serve/p1/common/experiment/vox_good'
    # files = os.listdir(path)
    # data_root = '/home/cxu-serve/p1/common/voxceleb2/unzip/'
    # audios = []
    # videos = []
    # for f in files:
    #     if f[:7] == 'id00817' :
    #         audios.append(os.path.join(data_root, 'test_audio', f[:7], f[8:-14], '{}.wav'.format(f.split('_')[-2])))
    #         videos.append(os.path.join(data_root, 'test_video', f[:7], f[8:-14], '{}_aligned.mp4'.format(f.split('_')[-2])))

    # for i in range(len(audios)):
    #     audio_file = audios[i]
    #     video_file = videos[i]

    #     test_file = audio_file
    #     image_path = video_file
    #     video_name = image_path.split('/')[-3] + '__'  + image_path.split('/')[-2] +'__' + image_path.split('/')[-1][:-4]

    # get file paths
    #path = '/home/cxu-serve/p1/common/other/lrs_good2'
    #files = os.listdir(path)

    #for f in files:
    #    print (f)
    #    if f[:4] !='test':
    #        continue
    #    tmp = f.split('_')#

    # if f[:7] == 'id00817' :

    for i in range(len(audios)):
        try:
            audio_file = audios[i]
            video_file = videos[i]

            test_file = audio_file
            image_path = video_file
            video_name = video_file.split('/')[-1][:-4] + '_' + str(
                start_ids[i])
            #image_path = os.path.join('../image', video_name + '.jpg')
            #print (video_name, image_path)
            #cap = cv2.VideoCapture(video_file)
            #imgs = []
            #count = 0
            #while(cap.isOpened()):
            #    count += 1
            #    ret, frame = cap.read()
            #    if count != 33:
            #        continue
            #   else:
            #        cv2.imwrite(image_path, frame)
            #    try:
            #        example_image, example_landmark = generator_demo_example_lips(image_path)
            #    except:
            #        continue
            #    break

            example_image, example_landmark = generator_demo_example_lips(
                image_path)

            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
            ])
            example_image = cv2.cvtColor(example_image, cv2.COLOR_BGR2RGB)
            example_image = transform(example_image)

            example_landmark = example_landmark.reshape(
                (1, example_landmark.shape[0] * example_landmark.shape[1]))

            if config.cuda:
                example_image = Variable(example_image.view(1, 3, 128,
                                                            128)).cuda()
                example_landmark = Variable(
                    torch.FloatTensor(example_landmark.astype(float))).cuda()
            else:
                example_image = Variable(example_image.view(1, 3, 128, 128))
                example_landmark = Variable(
                    torch.FloatTensor(example_landmark.astype(float)))
            # Load speech and extract features
            example_landmark = example_landmark * 5.0
            example_landmark = example_landmark - mean.expand_as(
                example_landmark)
            example_landmark = torch.mm(example_landmark, pca)
            speech, sr = librosa.load(audio_file, sr=16000)
            speech = np.insert(speech, 0, np.zeros(1920))
            speech = np.append(speech, np.zeros(1920))
            mfcc = python_speech_features.mfcc(speech, 16000, winstep=0.01)

            # print (mfcc.shape)

            #sound, _ = librosa.load(test_file, sr=44100)

            print('=======================================')
            print('Start to generate images')
            t = time.time()
            ind = 3
            with torch.no_grad():
                fake_lmark = []
                input_mfcc = []

                while ind <= int(mfcc.shape[0] / 4) - 4:
                    t_mfcc = mfcc[(ind - 3) * 4:(ind + 4) * 4, 1:]
                    t_mfcc = torch.FloatTensor(t_mfcc).cuda()
                    if ind >= start_ids[i] and ind < end_ids[i]:
                        input_mfcc.append(t_mfcc)
                    ind += 1

                input_mfcc = torch.stack(input_mfcc, dim=0)
                input_mfcc = input_mfcc.unsqueeze(0)
                print(input_mfcc.shape)
                fake_lmark = encoder(example_landmark, input_mfcc)
                fake_lmark = fake_lmark.view(
                    fake_lmark.size(0) * fake_lmark.size(1), 6)
                example_landmark = torch.mm(example_landmark, pca.t())
                example_landmark = example_landmark + mean.expand_as(
                    example_landmark)
                fake_lmark[:, 1:6] *= 2 * torch.FloatTensor(
                    np.array([1.1, 1.2, 1.3, 1.4, 1.5])).cuda()
                fake_lmark = torch.mm(fake_lmark, pca.t())
                fake_lmark = fake_lmark + mean.expand_as(fake_lmark)

                fake_lmark = fake_lmark.unsqueeze(0)

                fake_ims, _, _, _ = decoder(example_image, fake_lmark,
                                            example_landmark)
                os.system('rm ../temp/*')
                for indx in range(fake_ims.size(1)):
                    fake_im = fake_ims[:, indx]
                    fake_store = fake_im.permute(0, 2, 3,
                                                 1).data.cpu().numpy()[0]
                    scipy.misc.imsave(
                        "{}/{:05d}.png".format(os.path.join('../', 'temp'),
                                               indx), fake_store)
                print(time.time() - t)
                fake_lmark = fake_lmark.data.cpu().numpy()
                # os.system('rm ../results/*')
                # np.save( os.path.join( config.sample_dir,  'obama_fake.npy'), fake_lmark)
                # fake_lmark = np.reshape(fake_lmark, (fake_lmark.shape[1], 68, 2))
                # utils.write_video_wpts_wsound(fake_lmark, sound, 44100, config.sample_dir, 'fake', [-1.0, 1.0], [-1.0, 1.0])
                video_name = os.path.join(config.sample_dir, video_name)
                # ffmpeg.input('../temp/*.png', pattern_type='glob', framerate=25).output(video_name).run()

                utils.image_to_video(os.path.join('../', 'temp'),
                                     video_name + '.mp4')
                utils.add_audio(video_name + '.mp4', audio_file)
                print('The generated video is: {}'.format(video_name + '.mov'))
        except Exception:
            # skip clips that fail (e.g. face/landmark extraction errors)
            continue