예제 #1
0
def plot(x, epoch):
    """Score several stochastic rollouts of the model on one batch.

    Each rollout resets the recurrent state of ``frame_predictor``,
    ``posterior`` and ``prior``.  For steps ``i < opt.n_past`` the ground-truth
    frame ``x[i]`` is encoded, passed through ``posterior`` and fed to
    ``frame_predictor`` (output discarded) to warm up the recurrent state.
    For the remaining steps up to ``opt.n_eval`` a latent is drawn from
    ``prior``, a frame is decoded and fed back as the next input.  Predicted
    and ground-truth frames of that second phase are scored with
    ``utils.eval_seq``.

    Args:
        x: Indexable sequence of frames; ``x[0]`` seeds every rollout.
            Frames are assumed to be torch tensors, since
            ``.data.cpu().numpy()`` is called on them.
        epoch: Unused; kept so existing callers keep working.

    Returns:
        ``(ssim, psnr)``: for each metric, the mean over the rollouts of the
        last entry of the metric averaged over axis 0.
    """
    nsample = 5  # number of independent stochastic rollouts to average over

    mean_ssim, mean_psnr = [], []
    for _ in range(nsample):
        # Every rollout starts from freshly initialised recurrent state.
        frame_predictor.hidden = frame_predictor.init_hidden()
        posterior.hidden = posterior.init_hidden()
        prior.hidden = prior.init_hidden()
        x_in = x[0]

        tempgt_seq, temppred_seq = [], []

        for i in range(1, opt.n_eval):
            h = encoder(x_in)
            # The skip features are only refreshed while conditioning frames
            # are consumed, unless opt.last_frame_skip asks for every step.
            if opt.last_frame_skip or i < opt.n_past:
                h, skip = h
            else:
                h, _ = h
            h = h.detach()
            if i < opt.n_past:
                h_target = encoder(x[i])[0].detach()
                z_t, _, _ = posterior(h_target)
                # Called only to advance the predictor's recurrent state.
                frame_predictor(torch.cat([h, z_t], 1))
                x_in = x[i]
            else:
                z_t, _, _ = prior(h)
                h = frame_predictor(torch.cat([h, z_t], 1)).detach()
                x_in = decoder([h, skip]).detach()

                # Fetch both arrays before appending either one, so a failure
                # (e.g. x shorter than opt.n_eval) cannot leave the two lists
                # passed to eval_seq with different lengths.
                try:
                    pred_np = x_in.data.cpu().numpy()
                    gt_np = x[i].data.cpu().numpy()
                except Exception as err:  # best effort: skip this frame only
                    print("plot: skipping frame %d (%r)" % (i, err))
                else:
                    temppred_seq.append(pred_np)
                    tempgt_seq.append(gt_np)

        _, ssim, psnr = utils.eval_seq(tempgt_seq, temppred_seq)
        mean_ssim.append(np.mean(ssim, axis=0)[-1])
        mean_psnr.append(np.mean(psnr, axis=0)[-1])
    return np.mean(mean_ssim), np.mean(mean_psnr)
    '''
예제 #2
0
def plot_rec(x, epoch, _type):
    """Reconstruct a sequence with the approximate posterior and score it.

    Every step encodes the previous ground-truth frame ``x[i-1]`` and the
    target frame ``x[i]``, draws a latent from ``posterior`` given the target
    features and runs ``frame_predictor``.  For ``i >= opt.n_past`` the
    predictor output is decoded and compared with ``x[i]``.

    Args:
        x: Indexable sequence of frames (assumed torch tensors).
        epoch: Unused here; only the disabled image dump referenced it.
        _type: Unused here; only the disabled image dump referenced it.

    Returns:
        ``(ssim, psnr)`` as returned by ``utils.eval_seq``, each averaged
        over axis 0.
    """
    frame_predictor.hidden = frame_predictor.init_hidden()
    posterior.hidden = posterior.init_hidden()

    # gen_seq is only consumed by an image dump that is currently disabled;
    # it is still accumulated exactly as before.
    gen_seq = [x[0]]
    gt_seq, pred_seq = [], []

    for step in range(1, opt.n_past + opt.n_future):
        enc_prev = encoder(x[step - 1])
        enc_target = encoder(x[step])
        # Refresh the skip features only while conditioning frames are
        # consumed, unless opt.last_frame_skip asks for every step.
        if opt.last_frame_skip or step < opt.n_past:
            feat, skip = enc_prev
        else:
            feat, _ = enc_prev
        target_feat, _ = enc_target
        feat = feat.detach()
        target_feat = target_feat.detach()

        z_t, _, _ = posterior(target_feat)
        predictor_in = torch.cat([feat, z_t], 1)

        if step >= opt.n_past:
            h_pred = frame_predictor(predictor_in)
            x_pred = decoder([h_pred, skip]).detach()
            gen_seq.append(x_pred)
            pred_seq.append(x_pred.data.cpu().numpy())
            gt_seq.append(x[step].data.cpu().numpy())
        else:
            # Output discarded: the call only advances the recurrent state.
            frame_predictor(predictor_in)
            gen_seq.append(x[step])

    _, ssim, psnr = utils.eval_seq(gt_seq, pred_seq)

    return np.mean(ssim, axis=0), np.mean(psnr, axis=0)
예제 #3
0
def make_gifs(x, idx, name):
    """Save one comparison GIF per batch element.

    Three stages:

    1. Approximate-posterior rollout: the posterior mean (second return value
       of ``posterior``) is used as latent; frames for ``i >= opt.n_past`` are
       decoded and fed back.
    2. ``opt.nsample`` stochastic rollouts: the posterior mean is used while
       ``i < opt.n_past``; afterwards the latent is standard-normal noise.
       Each rollout is scored with ``utils.eval_seq``.
    3. For every batch element a GIF is written showing ground truth, the
       posterior rollout, the rollout with the highest mean SSIM and three
       randomly chosen rollouts.

    Args:
        x: Indexable sequence of frames; ``x[t][i]`` is frame ``t`` of batch
            element ``i`` (assumed torch tensors).
        idx: Offset added to the batch index in the output file name.
        name: File-name stem; files go to ``'%s/%s_%d.gif'`` under
            ``opt.log_dir``.

    Returns:
        None.  The function only writes files.
    """
    # --- 1. approximate-posterior rollout -------------------------------
    frame_predictor.hidden = frame_predictor.init_hidden()
    posterior.hidden = posterior.init_hidden()
    posterior_gen = [x[0]]
    x_in = x[0]
    for i in range(1, opt.n_eval):
        h = encoder(x_in)
        h_target = encoder(x[i])[0].detach()
        if opt.last_frame_skip or i < opt.n_past:
            h, skip = h
        else:
            h, _ = h
        h = h.detach()
        _, z_t, _ = posterior(h_target)  # take the mean
        if i < opt.n_past:
            # Output discarded: the call only advances the recurrent state.
            frame_predictor(torch.cat([h, z_t], 1))
            posterior_gen.append(x[i])
            x_in = x[i]
        else:
            h_pred = frame_predictor(torch.cat([h, z_t], 1)).detach()
            x_in = decoder([h_pred, skip]).detach()
            posterior_gen.append(x_in)

    # --- 2. stochastic rollouts -----------------------------------------
    nsample = opt.nsample
    ssim = np.zeros((opt.batch_size, nsample, opt.n_future))
    psnr = np.zeros((opt.batch_size, nsample, opt.n_future))
    progress = progressbar.ProgressBar(max_value=nsample).start()
    all_gen = []
    for s in range(nsample):
        progress.update(s + 1)
        gen_seq = []
        gt_seq = []
        frame_predictor.hidden = frame_predictor.init_hidden()
        posterior.hidden = posterior.init_hidden()
        x_in = x[0]
        all_gen.append([])
        all_gen[s].append(x_in)
        for i in range(1, opt.n_eval):
            h = encoder(x_in)
            if opt.last_frame_skip or i < opt.n_past:
                h, skip = h
            else:
                h, _ = h
            h = h.detach()
            if i < opt.n_past:
                h_target = encoder(x[i])[0].detach()
                _, z_t, _ = posterior(h_target)
                frame_predictor(torch.cat([h, z_t], 1))
                x_in = x[i]
                all_gen[s].append(x_in)
            else:
                # Draw z ~ N(0, I) on the device/dtype of the encoder
                # features instead of the hard-coded current CUDA device, so
                # this also works on CPU or a non-default GPU.
                z_t = h.new_empty((opt.batch_size, opt.z_dim)).normal_()
                h = frame_predictor(torch.cat([h, z_t], 1)).detach()
                x_in = decoder([h, skip]).detach()
                gen_seq.append(x_in.data.cpu().numpy())
                gt_seq.append(x[i].data.cpu().numpy())
                all_gen[s].append(x_in)
        _, ssim[:, s, :], psnr[:, s, :] = utils.eval_seq(gt_seq, gen_seq)

    progress.finish()
    utils.clear_progressbar()

    # --- 3. one GIF per batch element -----------------------------------
    for i in range(opt.batch_size):
        gifs = [[] for _ in range(opt.n_eval)]
        text = [[] for _ in range(opt.n_eval)]
        mean_ssim = np.mean(ssim[i], 1)
        ordered = np.argsort(mean_ssim)
        sidx = ordered[-1]  # rollout with the highest mean SSIM
        rand_sidx = [np.random.randint(nsample) for _ in range(3)]
        for t in range(opt.n_eval):
            # Green border for conditioning frames, red for generated ones.
            color = 'green' if t < opt.n_past else 'red'
            # gt
            gifs[t].append(add_border(x[t][i], 'green'))
            text[t].append('Ground\ntruth')
            # posterior
            gifs[t].append(add_border(posterior_gen[t][i], color))
            text[t].append('Approx.\nposterior')
            # best
            gifs[t].append(add_border(all_gen[sidx][t][i], color))
            text[t].append('Best SSIM')
            # random 3
            for k, r in enumerate(rand_sidx):
                gifs[t].append(add_border(all_gen[r][t][i], color))
                text[t].append('Random\nsample %d' % (k + 1))

        fname = '%s/%s_%d.gif' % (opt.log_dir, name, idx + i)
        utils.save_gif_with_text(fname, gifs, text)
예제 #4
0
def make_gifs(x, idx, name):
    """Sample rollouts from the learned prior, save visualisations, return the best scores.

    Stages:

    1. Approximate-posterior rollout using the second return value of
       ``posterior`` (the mean, per the original comment) as latent.
    2. ``opt.nsample`` rollouts.  While ``i < opt.n_past`` latents come from
       ``posterior`` and ``prior`` is also stepped; afterwards latents come
       from ``prior`` and ``exp(logvar)`` of the prior is recorded.
    3. For each batch element the rollouts with the highest mean SSIM and
       mean PSNR are selected, then a GIF and a PNG are written under
       ``<opt.log_dir>/quality_results/``.

    Args:
        x: Indexable sequence of frames; ``x[t][i]`` is frame ``t`` of batch
            element ``i`` (assumed torch tensors).
        idx: Offset added to the batch index in output file names.
        name: File-name stem.

    Returns:
        ``(best_ssim, best_psnr, best_ssim_var, best_psnr_var)``: the first
        two have shape ``(batch_size, n_future)``, the last two
        ``(batch_size, n_future, z_dim)``.
    """
    # ---- 1. approximate-posterior rollout ------------------------------
    frame_predictor.hidden = frame_predictor.init_hidden()
    posterior.hidden = posterior.init_hidden()
    prev_frame = x[0]
    posterior_gen = [prev_frame]
    for step in range(1, opt.n_eval):
        encoded = encoder(prev_frame)
        target_feat = encoder(x[step])[0].detach()
        if opt.last_frame_skip or step < opt.n_past:
            feat, skip = encoded
        else:
            feat, _ = encoded
        feat = feat.detach()
        _, z_mean, _ = posterior(target_feat)
        predictor_in = torch.cat([feat, z_mean], 1)
        if step < opt.n_past:
            # Output discarded: the call only advances the recurrent state.
            frame_predictor(predictor_in)
            prev_frame = x[step]
        else:
            h_pred = frame_predictor(predictor_in).detach()
            prev_frame = decoder([h_pred, skip]).detach()
        posterior_gen.append(prev_frame)

    # ---- 2. rollouts sampled from the prior ----------------------------
    nsample = opt.nsample
    metric_shape = (opt.batch_size, nsample, opt.n_future)
    ssim = np.zeros(metric_shape)
    psnr = np.zeros(metric_shape)
    var_np = np.zeros(metric_shape + (opt.z_dim,))
    progress = progressbar.ProgressBar(max_value=nsample).start()
    all_gen = []
    for s in range(nsample):
        progress.update(s + 1)
        gen_seq, gt_seq = [], []
        frame_predictor.hidden = frame_predictor.init_hidden()
        posterior.hidden = posterior.init_hidden()
        prior.hidden = prior.init_hidden()
        prev_frame = x[0]
        rollout = [prev_frame]
        all_gen.append(rollout)
        for step in range(1, opt.n_eval):
            encoded = encoder(prev_frame)
            if opt.last_frame_skip or step < opt.n_past:
                feat, skip = encoded
            else:
                feat, _ = encoded
            feat = feat.detach()

            if step < opt.n_past:
                target_feat = encoder(x[step])[0].detach()
                z_t, _, _ = posterior(target_feat)
                # prior and predictor are stepped for their recurrent state.
                prior(feat)
                frame_predictor(torch.cat([feat, z_t], 1))
                prev_frame = x[step]
                rollout.append(prev_frame)
                continue

            z_t, _, logvar = prior(feat)
            feat = frame_predictor(torch.cat([feat, z_t], 1)).detach()
            prev_frame = decoder([feat, skip]).detach()
            gen_seq.append(prev_frame.data.cpu().numpy())
            gt_seq.append(x[step].data.cpu().numpy())
            rollout.append(prev_frame)

            # Variance of the prior at this future step.
            var_np[:, s, step - opt.n_past, :] = \
                torch.exp(logvar).data.cpu().numpy()

        _, ssim[:, s, :], psnr[:, s, :] = utils.eval_seq(gt_seq, gen_seq)

    progress.finish()
    utils.clear_progressbar()

    # ---- 3. pick the best rollouts and write the visualisations --------
    best_ssim = np.zeros((opt.batch_size, opt.n_future))
    best_psnr = np.zeros((opt.batch_size, opt.n_future))
    best_ssim_var = np.zeros((opt.batch_size, opt.n_future, opt.z_dim))
    best_psnr_var = np.zeros((opt.batch_size, opt.n_future, opt.z_dim))

    for b in range(opt.batch_size):
        top_ssim = np.argsort(np.mean(ssim[b], 1))[-1]
        best_ssim[b, :] = ssim[b, top_ssim, :]
        best_ssim_var[b, :, :] = var_np[b, top_ssim, :, :]

        top_psnr = np.argsort(np.mean(psnr[b], 1))[-1]
        best_psnr[b, :] = psnr[b, top_psnr, :]
        best_psnr_var[b, :, :] = var_np[b, top_psnr, :, :]

        rand_sidx = [np.random.randint(nsample) for _ in range(3)]

        # Animated comparison: one list of bordered frames per time step.
        gifs, text = [], []
        for t in range(opt.n_eval):
            # Green border for conditioning frames, red for generated ones.
            color = 'green' if t < opt.n_past else 'red'
            frames_t = [
                add_border(x[t][b], 'green'),
                add_border(posterior_gen[t][b], color),
                add_border(all_gen[top_ssim][t][b], color),
            ]
            labels_t = ['Ground\ntruth', 'Approx.\nposterior', 'Best SSIM']
            for k, sample_id in enumerate(rand_sidx):
                frames_t.append(add_border(all_gen[sample_id][t][b], color))
                labels_t.append('Random\nsample %d' % (k + 1))
            gifs.append(frames_t)
            text.append(labels_t)

        fname = '%s/quality_results/%s_%d.gif' % (opt.log_dir, name, idx + b)
        utils.save_gif_with_text(fname, gifs, text)

        # Static grid: one row per sequence (gt, best SSIM, best PSNR, randoms).
        steps = range(opt.n_eval)
        to_plot = [
            [x[t][b] for t in steps],
            [all_gen[top_ssim][t][b] for t in steps],
            [all_gen[top_psnr][t][b] for t in steps],
        ]
        for sample_id in rand_sidx:
            to_plot.append([all_gen[sample_id][t][b] for t in steps])
        fname = '%s/quality_results/%s_%d.png' % (opt.log_dir, name, idx + b)
        utils.save_tensors_image(fname, to_plot)

    return best_ssim, best_psnr, best_ssim_var, best_psnr_var
예제 #5
0
def make_gifs(x, idx, name):
    """Evaluate ``hsvg_net`` rollouts, save visualisations, return the best scores.

    Stages:

    1. Reconstruction pass: ``hsvg_net.reconstruction`` is called on every
       frame ``x[1:opt.n_eval]``.  While ``i < opt.n_past`` its second return
       value is stored in ``hsvg_net.skips``; afterwards the first return
       value is decoded with ``hsvg_net.decoding``.
    2. ``opt.nsample`` rollouts: conditioning frames are consumed as above,
       then ``hsvg_net.inference()`` generates frames and
       ``exp(cat(hsvg_net.logvars_prior, -1))`` is recorded.
    3. For each batch element the rollouts with the highest mean SSIM and
       mean PSNR are selected, then a GIF and a PNG are written under
       ``<opt.log_dir>/quality_results/``.

    Args:
        x: Indexable sequence of frames; ``x[t][i]`` is frame ``t`` of batch
            element ``i`` (assumed torch tensors).
        idx: Offset added to the batch index in output file names.
        name: File-name stem.

    Returns:
        ``(best_ssim, best_psnr, best_ssim_var, best_psnr_var)``: the first
        two have shape ``(batch_size, n_future)``, the last two
        ``(batch_size, n_future, z_dim)``.
    """
    # ---- 1. reconstruction with the approximate posterior --------------
    posterior_gen = [x[0]]
    hsvg_net.init_states(x[0])
    for step in range(1, opt.n_eval):
        hs_rec, feats, _, _, _ = hsvg_net.reconstruction(x[step])
        if step < opt.n_past:
            hsvg_net.skips = feats
            posterior_gen.append(x[step])
        else:
            posterior_gen.append(hsvg_net.decoding(hs_rec))

    # ---- 2. rollouts sampled from the prior ----------------------------
    nsample = opt.nsample
    metric_shape = (opt.batch_size, nsample, opt.n_future)
    ssim = np.zeros(metric_shape)
    psnr = np.zeros(metric_shape)
    var_np = np.zeros(metric_shape + (opt.z_dim,))
    all_gen = []
    # Ground truth of the predicted range is identical for every rollout.
    gt_seq = [x[t].data.cpu().numpy() for t in range(opt.n_past, opt.n_eval)]

    for s in tqdm(range(nsample)):
        gen_seq = []
        rollout = [x[0]]
        all_gen.append(rollout)

        hsvg_net.init_states(x[0])
        for step in range(1, opt.n_eval):
            if step < opt.n_past:
                _, feats, _, _, _ = hsvg_net.reconstruction(x[step])
                hsvg_net.skips = feats
                rollout.append(x[step])
                continue

            x_pred = hsvg_net.inference()
            gen_seq.append(x_pred.data.cpu().numpy())
            rollout.append(x_pred)

            # Variance of the prior at this future step.
            prior_logvar = torch.cat(hsvg_net.logvars_prior, -1)
            var_np[:, s, step - opt.n_past, :] = \
                torch.exp(prior_logvar).data.cpu().numpy()

        _, ssim[:, s, :], psnr[:, s, :] = utils.eval_seq(gt_seq, gen_seq)

    # ---- 3. pick the best rollouts and write the visualisations --------
    best_ssim = np.zeros((opt.batch_size, opt.n_future))
    best_psnr = np.zeros((opt.batch_size, opt.n_future))
    best_ssim_var = np.zeros((opt.batch_size, opt.n_future, opt.z_dim))
    best_psnr_var = np.zeros((opt.batch_size, opt.n_future, opt.z_dim))

    for b in range(opt.batch_size):
        top_ssim = np.argsort(np.mean(ssim[b], 1))[-1]
        best_ssim[b, :] = ssim[b, top_ssim, :]
        best_ssim_var[b, :, :] = var_np[b, top_ssim, :, :]

        top_psnr = np.argsort(np.mean(psnr[b], 1))[-1]
        best_psnr[b, :] = psnr[b, top_psnr, :]
        best_psnr_var[b, :, :] = var_np[b, top_psnr, :, :]

        rand_sidx = [np.random.randint(nsample) for _ in range(3)]

        # Animated comparison: one list of bordered frames per time step.
        gifs, text = [], []
        for t in range(opt.n_eval):
            # Green border for conditioning frames, red for generated ones.
            color = 'green' if t < opt.n_past else 'red'
            frames_t = [
                add_border(x[t][b], 'green'),
                add_border(posterior_gen[t][b], color),
                add_border(all_gen[top_ssim][t][b], color),
            ]
            labels_t = ['Ground\ntruth', 'Approx.\nposterior', 'Best SSIM']
            for k, sample_id in enumerate(rand_sidx):
                frames_t.append(add_border(all_gen[sample_id][t][b], color))
                labels_t.append('Random\nsample %d' % (k + 1))
            gifs.append(frames_t)
            text.append(labels_t)

        fname = '%s/quality_results/%s_%d.gif' % (opt.log_dir, name, idx + b)
        utils.save_gif_with_text(fname, gifs, text)

        # Static grid: one row per sequence (gt, best SSIM, best PSNR, randoms).
        steps = range(opt.n_eval)
        to_plot = [
            [x[t][b] for t in steps],
            [all_gen[top_ssim][t][b] for t in steps],
            [all_gen[top_psnr][t][b] for t in steps],
        ]
        for sample_id in rand_sidx:
            to_plot.append([all_gen[sample_id][t][b] for t in steps])
        fname = '%s/quality_results/%s_%d.png' % (opt.log_dir, name, idx + b)
        utils.save_tensors_image(fname, to_plot)

    return best_ssim, best_psnr, best_ssim_var, best_psnr_var
예제 #6
0
def make_gifs(x, idx, name):
    """Evaluate the pose/content model, save visualisations, return the best scores.

    Stages:

    1. The first ``opt.n_past`` frames are shuffled, concatenated along
       dim 1 and passed to ``cont_encoder``; its first output (``mu_c``) is
       used as content code and its third output as decoder skip input.
    2. Approximate-posterior rollout: per step a latent is reparameterised
       from ``posterior_pose`` and fed, together with ``mu_c``, to
       ``frame_predictor``; frames for ``i >= opt.n_past`` are decoded.
    3. ``opt.nsample`` rollouts: posterior latents while ``i < opt.n_past``
       (``prior`` is also called), prior latents afterwards, with the decoded
       frame re-encoded by ``pose_encoder`` as the next prior input.
    4. For each batch element the rollouts with the highest mean SSIM and
       mean PSNR are selected, then a GIF and a PNG are written under
       ``<opt.log_dir>/samples/``.

    Args:
        x: Indexable sequence of frames; ``x[t][i]`` is frame ``t`` of batch
            element ``i`` (assumed torch tensors).
        idx: Offset added to the batch index in output file names.
        name: File-name stem.

    Returns:
        ``(best_ssim, best_psnr)``, each of shape ``(batch_size, n_future)``.
    """
    posterior_gen = [x[0]]

    # ---- 1. content code from the shuffled conditioning frames ---------
    xs = [x[t] for t in range(0, opt.n_past)]
    random.shuffle(xs)
    mu_c, _logvar_c, skip = cont_encoder(torch.cat(xs, 1))
    mu_c = mu_c.detach()

    # ---- 2. approximate-posterior rollout ------------------------------
    for step in range(1, opt.n_eval):
        pose_target = pose_encoder(x[step]).detach()
        mu_post, logvar_post = posterior_pose(
            torch.cat([pose_target, mu_c], 1), time_step=step - 1)
        z_post = utils.reparameterize(mu_post, logvar_post)
        predictor_in = torch.cat([z_post, mu_c], 1)
        if step < opt.n_past:
            # Output discarded: the call only advances the predictor.
            frame_predictor(predictor_in, time_step=step - 1)
            posterior_gen.append(x[step])
        else:
            h_pred = frame_predictor(predictor_in,
                                     time_step=step - 1).detach()
            posterior_gen.append(decoder([h_pred, skip]).detach())

    # ---- 3. rollouts sampled from the prior ----------------------------
    nsample = opt.nsample
    metric_shape = (opt.batch_size, nsample, opt.n_future)
    ssim = np.zeros(metric_shape)
    psnr = np.zeros(metric_shape)

    progress = progressbar.ProgressBar(maxval=nsample).start()
    all_gen = []

    # NOTE(review): these encodings are never read afterwards; kept because
    # dropping the forward passes is only safe if pose_encoder has no side
    # effects (e.g. is in eval mode) - confirm before removing.
    hs = [pose_encoder(x[t]).detach() for t in range(0, opt.n_past)]

    for s in range(nsample):
        progress.update(s + 1)
        gen_seq, gt_seq = [], []
        rollout = [x[0]]
        all_gen.append(rollout)

        pose_prev = pose_encoder(x[0]).detach()
        for step in range(1, opt.n_eval):
            # NOTE(review): only used while step < opt.n_past, but computed
            # on every step as in the original.
            pose_target = pose_encoder(x[step]).detach()

            if step < opt.n_past:
                mu_post, logvar_post = posterior_pose(
                    torch.cat([pose_target, mu_c], 1), time_step=step - 1)
                z_post = utils.reparameterize(mu_post, logvar_post)
                # prior and predictor outputs are discarded here.
                prior(torch.cat([pose_prev, mu_c], 1), time_step=step - 1)
                frame_predictor(torch.cat([z_post, mu_c], 1),
                                time_step=step - 1)
                rollout.append(x[step])
                pose_prev = pose_target
                continue

            mu_prior, logvar_prior = prior(torch.cat([pose_prev, mu_c], 1),
                                           time_step=step - 1)
            z_prior = utils.reparameterize(mu_prior, logvar_prior)
            h_pred = frame_predictor(torch.cat([z_prior, mu_c], 1),
                                     time_step=step - 1).detach()
            frame = decoder([h_pred, skip]).detach()
            gen_seq.append(frame.data.cpu().numpy())
            gt_seq.append(x[step].data.cpu().numpy())
            rollout.append(frame)
            # The generated frame becomes the next prior input.
            pose_prev = pose_encoder(frame).detach()

        _, ssim[:, s, :], psnr[:, s, :] = utils.eval_seq(gt_seq, gen_seq)

    progress.finish()
    utils.clear_progressbar()

    # ---- 4. pick the best rollouts and write the visualisations --------
    best_ssim = np.zeros((opt.batch_size, opt.n_future))
    best_psnr = np.zeros((opt.batch_size, opt.n_future))

    for b in range(opt.batch_size):
        top_ssim = np.argsort(np.mean(ssim[b], 1))[-1]
        best_ssim[b, :] = ssim[b, top_ssim, :]

        top_psnr = np.argsort(np.mean(psnr[b], 1))[-1]
        best_psnr[b, :] = psnr[b, top_psnr, :]

        rand_sidx = [np.random.randint(nsample) for _ in range(3)]

        # Animated comparison: one list of bordered frames per time step.
        gifs, text = [], []
        for t in range(opt.n_eval):
            # Green border for conditioning frames, red for generated ones.
            color = 'green' if t < opt.n_past else 'red'
            frames_t = [
                add_border(x[t][b], 'green'),
                add_border(posterior_gen[t][b], color),
                add_border(all_gen[top_ssim][t][b], color),
            ]
            labels_t = ['Ground\ntruth', 'Approx.\nposterior', 'Best SSIM']
            for k, sample_id in enumerate(rand_sidx):
                frames_t.append(add_border(all_gen[sample_id][t][b], color))
                labels_t.append('Random\nsample %d' % (k + 1))
            gifs.append(frames_t)
            text.append(labels_t)

        fname = '%s/samples/%s_%d.gif' % (opt.log_dir, name, idx + b)
        utils.save_gif_with_text(fname, gifs, text)

        # Static grid: one row per sequence (gt, best PSNR, randoms).
        steps = range(opt.n_eval)
        to_plot = [
            [x[t][b] for t in steps],
            [all_gen[top_psnr][t][b] for t in steps],
        ]
        for sample_id in rand_sidx:
            to_plot.append([all_gen[sample_id][t][b] for t in steps])
        fname = '%s/samples/%s_%d.png' % (opt.log_dir, name, idx + b)
        utils.save_tensors_image(fname, to_plot)

    return best_ssim, best_psnr
예제 #7
0
def make_gifs(x, idx, names):
    """Save every sampled rollout and the ground truth as GIFs.

    Stages:

    1. Approximate-posterior rollout using the second return value of
       ``posterior`` (the mean, per the original comment) as latent.
    2. ``opt.nsample`` rollouts: posterior latents while ``i < opt.n_past``
       (``prior`` is also stepped), prior latents afterwards.  Each rollout
       is scored with ``utils.eval_seq``.
    3. For each batch element ``i`` every rollout ``s`` is written to
       ``<opt.log_dir>/<names[i]>_<s>.gif`` and the ground truth to
       ``<opt.log_dir>/<names[i]>_gt.gif``.

    When ``opt.use_action`` is set, ``x`` is unpacked as ``(frames, actions)``
    and ``actions[i - 1].repeat(1, opt.a_dim)`` is appended to every
    ``frame_predictor`` input.

    Args:
        x: Sequence of frames, or ``(frames, actions)`` when
            ``opt.use_action`` is set.  ``frames[t][i]`` is frame ``t`` of
            batch element ``i`` (assumed torch tensors).
        idx: Unused; kept so existing callers keep working.
        names: Per-batch-element file-name stems.

    Returns:
        None.  The function only writes files.
    """
    if opt.use_action:
        actions = x[1]
        x = x[0]
    # Ground truth must be the frame sequence itself.  Taking it before the
    # unpacking above made all_gt == [frames, actions] in action mode, so
    # all_gt[t][i] indexed the wrong structure.  x is never mutated below,
    # so no copy is needed.
    all_gt = x

    def _predictor_input(h, z_t, step):
        """Build the frame_predictor input for `step`, with the action if enabled."""
        parts = [h, z_t]
        if opt.use_action:
            parts.append(actions[step - 1].repeat(1, opt.a_dim))
        return torch.cat(parts, 1)

    def _to_uint8_image(frame):
        """Move the leading axis last, clamp to [0, 1] and scale to uint8."""
        img = frame.cpu().transpose(0, 1).transpose(1, 2).clamp(0, 1).numpy()
        return (img * 255).astype(np.uint8)

    # --- 1. approximate-posterior rollout -------------------------------
    frame_predictor.hidden = frame_predictor.init_hidden()
    posterior.hidden = posterior.init_hidden()
    posterior_gen = [x[0]]
    x_in = x[0]
    for i in range(1, opt.n_eval):
        h = encoder(x_in)
        h_target = encoder(x[i])[0].detach()
        if opt.last_frame_skip or i < opt.n_past:
            h, skip = h
        else:
            h, _ = h
        h = h.detach()
        _, z_t, _ = posterior(h_target)  # take the mean
        if i < opt.n_past:
            # Output discarded: the call only advances the recurrent state.
            frame_predictor(_predictor_input(h, z_t, i))
            posterior_gen.append(x[i])
            x_in = x[i]
        else:
            h_pred = frame_predictor(_predictor_input(h, z_t, i)).detach()
            x_in = decoder([h_pred, skip]).detach()
            posterior_gen.append(x_in)

    # --- 2. rollouts sampled from the prior -----------------------------
    nsample = opt.nsample
    ssim = np.zeros((opt.batch_size, nsample, opt.n_future))
    psnr = np.zeros((opt.batch_size, nsample, opt.n_future))
    progress = progressbar.ProgressBar(maxval=nsample).start()
    all_gen = []
    for s in range(nsample):
        progress.update(s + 1)
        gen_seq = []
        gt_seq = []
        frame_predictor.hidden = frame_predictor.init_hidden()
        posterior.hidden = posterior.init_hidden()
        prior.hidden = prior.init_hidden()
        x_in = x[0]
        all_gen.append([])
        all_gen[s].append(x_in)
        for i in range(1, opt.n_eval):
            h = encoder(x_in)
            if opt.last_frame_skip or i < opt.n_past:
                h, skip = h
            else:
                h, _ = h
            h = h.detach()
            if i < opt.n_past:
                h_target = encoder(x[i])[0].detach()
                z_t, _, _ = posterior(h_target)
                # prior and predictor are stepped for their recurrent state.
                prior(h)
                frame_predictor(_predictor_input(h, z_t, i))
                x_in = x[i]
                all_gen[s].append(x_in)
            else:
                z_t, _, _ = prior(h)
                h = frame_predictor(_predictor_input(h, z_t, i)).detach()
                x_in = decoder([h, skip]).detach()
                gen_seq.append(x_in.data.cpu().numpy())
                gt_seq.append(x[i].data.cpu().numpy())
                all_gen[s].append(x_in)
        _, ssim[:, s, :], psnr[:, s, :] = utils.eval_seq(gt_seq, gen_seq)

    progress.finish()
    utils.clear_progressbar()

    # --- 3. write one GIF per rollout plus the ground truth -------------
    for i in range(opt.batch_size):
        for s in range(nsample):
            imgs = [_to_uint8_image(all_gen[s][t][i])
                    for t in range(opt.n_eval)]
            fname = '%s/%s_%d.gif' % (opt.log_dir, names[i], s)
            utils.save_gif_IROS_2019(fname, imgs)

        # save ground truth
        imgs_gt = [_to_uint8_image(all_gt[t][i]) for t in range(opt.n_eval)]
        fname = '%s/%s_gt.gif' % (opt.log_dir, names[i])
        utils.save_gif_IROS_2019(fname, imgs_gt)