Exemplo n.º 1
0
    def train(self, data_loader):
        """Run one epoch of adversarial refinement training.

        For every ``(trace, image)`` batch, ``self.trace2image`` produces a
        coarse image from the trace and ``self.G`` refines it. ``self.D`` is
        updated to label real images with ``self.real_label`` and refined
        ones with ``self.fake_label``. Then ``self.G`` is updated so that
        ``self.D`` labels its output as real. The epoch-mean G/D losses and
        the last batch's D(x) / D(G(z)) are printed. The last batch's target,
        coarse (``tr2im``) and refined (``refine``) images are saved under
        ``<gan_dir>/image/train/``.

        Args:
            data_loader: Iterable of ``(trace, image)`` batches that supports
                ``len``. NOTE(review): assumes at least one batch; the summary
                and the saved images use values from the last batch.
        """
        print('Training...')
        # Anomaly detection makes autograd report the op that produced a bad
        # gradient, at a noticeable speed cost (debugging aid).
        with torch.autograd.set_detect_anomaly(True):
            self.epoch += 1
            self.G.train()
            self.D.train()
            record_G = utils.Record()
            record_D = utils.Record()
            start_time = time.time()
            progress = progressbar.ProgressBar(maxval=len(data_loader)).start()
            for i, (trace, image) in enumerate(data_loader):
                progress.update(i + 1)
                trace = trace.cuda()
                image = image.cuda()

                # ---- update D: real images -> real_label, refined -> fake_label
                self.D.zero_grad()
                real_output = self.D(image)
                err_D_real = self.loss(real_output, self.real_label)
                D_x = real_output.data.mean()
                fake_input, *_ = self.trace2image(trace)
                fake_refine = self.G(fake_input)
                # detach so the D loss does not backpropagate into G
                fake_output = self.D(fake_refine.detach())
                err_D_fake = self.loss(fake_output, self.fake_label)
                D_G_z = fake_output.data.mean()

                err_D = err_D_fake + err_D_real
                err_D.backward()
                self.optimizerD.step()

                # ---- update G: make D classify the refined images as real
                self.G.zero_grad()
                fake_output = self.D(fake_refine)
                err_G = self.loss(fake_output, self.real_label)

                err_G.backward()
                self.optimizerG.step()

                record_D.add(err_D.item())
                record_G.add(err_G.item())
            progress.finish()
            utils.clear_progressbar()
            print('----------------------------------------')
            print('Epoch: %d' % self.epoch)
            print('Costs time: %.2f s' % (time.time() - start_time))
            print('Loss of G: %f' % (record_G.mean()))
            print('Loss of D: %f' % (record_D.mean()))
            print('D(x): %f, D(G(z)): %f' % (D_x, D_G_z))
            print('----------------------------------------')
            # Save the last batch: the target, the coarse trace->image output
            # fed to G, and G's refinement. Training images go to image/train
            # so the test pass (which writes image/test) does not overwrite
            # them for the same epoch.
            gan_dir = self.args['gan_dir']
            utils.save_image(image.data, ('%s/image/train/target_%03d.jpg' %
                                          (gan_dir, self.epoch)))
            utils.save_image(fake_input.data,
                             ('%s/image/train/tr2im_%03d.jpg' %
                              (gan_dir, self.epoch)))
            utils.save_image(fake_refine.data,
                             ('%s/image/train/refine_%03d.jpg' %
                              (gan_dir, self.epoch)))
Exemplo n.º 2
0
    def train(self, data_loader):
        """Run one training epoch of the trace/image VAE.

        ``self.TraceEncoder`` and ``self.ImageEncoder`` embed their inputs.
        A latent is drawn from each embedding with ``utils.reparameterize``,
        and both latents are decoded by the shared ``self.Decoder``. The
        optimized loss is ``l1(trace->image, image) + l1(image->image, image)
        + beta * kld``. Epoch-mean losses are printed, and the last batch's
        target and reconstructions are saved under ``<vae_dir>/image/train/``.

        Args:
            data_loader: Iterable of ``(trace, image)`` batches that supports
                ``len``. NOTE(review): assumes at least one batch, because the
                saved images come from the last batch.
        """
        print('Training...')
        # Debugging aid: autograd reports the op that produced a bad gradient.
        with torch.autograd.set_detect_anomaly(True):
            self.epoch += 1
            self.set_train()
            record_trace = utils.Record()
            record_image = utils.Record()
            record_kld = utils.Record()
            start_time = time.time()
            progress = progressbar.ProgressBar(maxval=len(data_loader)).start()
            for i, (trace, image) in enumerate(data_loader):
                progress.update(i + 1)
                trace = trace.cuda()
                image = image.cuda()
                self.zero_grad()
                trace_embed = self.TraceEncoder(trace)
                image_embed = self.ImageEncoder(image)
                # NOTE(review): the same tensor is used as both mu and logvar.
                # This looks like the embedding was meant to be split into two
                # halves/heads. Left unchanged; confirm the intent.
                trace_mu, trace_logvar = trace_embed, trace_embed
                image_mu, image_logvar = image_embed, image_embed
                trace_z = utils.reparameterize(trace_mu, trace_logvar)
                image_z = utils.reparameterize(image_mu, image_logvar)
                trace2image, trace_inter = self.Decoder(trace_z)
                image2image, image_inter = self.Decoder(image_z)

                err_trace = self.l1(trace2image, image)
                err_image = self.l1(image2image, image)
                # The intermediate-feature loss
                # self.l2(trace_inter, image_inter) is currently disabled.
                err_kld = self.kld(image_mu, image_logvar, trace_mu,
                                   trace_logvar)

                (err_trace + err_image +
                 self.args['beta'] * err_kld).backward()

                self.optimizer.step()

                # Record Python floats. Storing the loss tensors would keep GPU
                # memory (and autograd graph nodes) alive for the whole epoch.
                record_trace.add(err_trace.item())
                record_image.add(err_image.item())
                record_kld.add(err_kld.item())
            progress.finish()
            utils.clear_progressbar()
            print('----------------------------------------')
            print('Epoch: %d' % self.epoch)
            print('Costs time: %.2fs' % (time.time() - start_time))
            print('Loss of Trace to Image: %f' % (record_trace.mean()))
            print('Loss of Image to Image: %f' % (record_image.mean()))
            print('Loss of KL-Divergence: %f' % (record_kld.mean()))
            print('----------------------------------------')
            utils.save_image(image.data, ('%s/image/train/target_%03d.jpg' %
                                          (self.args['vae_dir'], self.epoch)))
            utils.save_image(trace2image.data,
                             ('%s/image/train/tr2im_%03d.jpg' %
                              (self.args['vae_dir'], self.epoch)))
            utils.save_image(image2image.data,
                             ('%s/image/train/im2im_%03d.jpg' %
                              (self.args['vae_dir'], self.epoch)))
Exemplo n.º 3
0
def make_gifs(x, idx, name, frame_predictor, posterior, prior, encoder,
              decoder):
    """Draw ``opt.nsample`` rollouts for a batch and plot them vs. ground truth.

    Each rollout first resets the hidden state of ``frame_predictor``,
    ``posterior`` and ``prior``. It then feeds the first ``opt.n_past``
    ground-truth frames, with latents from the posterior. Finally it
    generates frames up to ``opt.n_eval`` with latents from the prior. The
    ground truth and all rollouts are mapped through ``inv_scaler`` and
    passed to ``plot_rainfall`` once per batch element.

    Args:
        x: Per-timestep batches indexed ``x[t]``. Each ``x[t].cpu().numpy()``
            (after ``inv_scaler``) must fit a buffer of shape
            ``(opt.batch_size, 1, opt.image_width, opt.image_width)``.
        idx: Not used by this function.
        name: Plot-name prefix; ``"_sample<j>"`` is appended per batch element.
        frame_predictor, posterior, prior, encoder, decoder: Model parts.
    """
    nsample = opt.nsample
    progress = progressbar.ProgressBar(max_value=nsample).start()
    all_gen = []
    for s in range(nsample):
        progress.update(s + 1)
        frame_predictor.hidden = frame_predictor.init_hidden()
        posterior.hidden = posterior.init_hidden()
        prior.hidden = prior.init_hidden()
        x_in = x[0]
        rollout = [x_in]
        all_gen.append(rollout)
        for i in range(1, opt.n_eval):
            h = encoder(x_in)
            # Skip connections are refreshed during conditioning, or at every
            # step when opt.last_frame_skip is set.
            if opt.last_frame_skip or i < opt.n_past:
                h, skip = h
            else:
                h, _ = h
            h = h.detach()
            if i < opt.n_past:
                # Conditioning step: take the latent from the posterior on the
                # true frame. The prior's output is discarded (presumably it
                # is called only to advance its hidden state).
                h_target = encoder(x[i])[0].detach()
                z_t, _, _ = posterior(h_target)
                prior(h)
                frame_predictor(torch.cat([h, z_t], 1))
                x_in = x[i]
            else:
                z_t, _, _ = prior(h)
                h = frame_predictor(torch.cat([h, z_t], 1)).detach()
                x_in = decoder([h, skip]).detach()
            rollout.append(x_in)

    # Finish and clear the bar before printing the summary. Clearing it
    # afterwards could overwrite that line. NOTE(review): this depends on how
    # utils.clear_progressbar moves the cursor.
    progress.finish()
    utils.clear_progressbar()

    # prep np.array to be plotted
    TRU = np.zeros(
        [opt.n_eval, opt.batch_size, 1, opt.image_width, opt.image_width])
    GEN = np.zeros([
        nsample, opt.n_eval, opt.batch_size, 1, opt.image_width,
        opt.image_width
    ])
    for i in range(opt.n_eval):
        TRU[i, :, :, :, :] = inv_scaler(x[i].cpu().numpy())
        for k in range(nsample):
            GEN[k, i, :, :, :, :] = inv_scaler(all_gen[k][i].cpu().numpy())
    # plot
    print(" ground truth max:", np.max(TRU), " gen max:", np.max(GEN))
    for j in range(opt.batch_size):
        plot_rainfall(TRU[:, j, 0, :, :], GEN[:, :, j, 0, :, :], opt.log_dir,
                      name + "_sample" + str(j), nsample)
Exemplo n.º 4
0
def run():
    """Create the CSV exports in ``CSV_DIRECTORY``.

    Ensures the directory exists, counts the XML dumps, and then runs the
    three export steps announced by the messages below.
    ``clear_progressbar()`` is called after the bikes step and after the
    positions/movements step.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(CSV_DIRECTORY, exist_ok=True)
    number_of_samples = count_xml_dumps()
    print('Creating stations.csv…')
    create_stations()
    print('Creating bikes.csv…')
    create_bikes(number_of_samples)
    clear_progressbar()
    print('Creating bike_positions.csv and bike_movements.csv…')
    create_bike_positions_and_movement(number_of_samples)
    clear_progressbar()
Exemplo n.º 5
0
    def test(self, data_loader):
        """Evaluate G and D on *data_loader* without updating any weights.

        Computes the same D loss (real vs. refined) and G loss (refined
        images labelled as real) as training. This runs under
        ``torch.no_grad()`` with both networks in eval mode. It prints the
        epoch-mean losses and the last batch's D(x) / D(G(z)). It saves the
        last batch's target, coarse (``tr2im``) and refined (``refine``)
        images under ``<gan_dir>/image/test/``.

        Args:
            data_loader: Iterable of ``(trace, image)`` batches that supports
                ``len``. NOTE(review): assumes at least one batch.
        """
        print('Testing...')
        with torch.no_grad():
            self.G.eval()
            self.D.eval()
            record_G = utils.Record()
            record_D = utils.Record()
            start_time = time.time()
            progress = progressbar.ProgressBar(maxval=len(data_loader)).start()
            for i, (trace, image) in enumerate(data_loader):
                progress.update(i + 1)
                trace = trace.cuda()
                image = image.cuda()

                real_output = self.D(image)
                err_D_real = self.loss(real_output, self.real_label)
                D_x = real_output.mean()

                fake_input, *_ = self.trace2image(trace)
                fake_refine = self.G(fake_input)
                # A single D pass on the refined images serves both losses.
                # Under no_grad with D in eval mode, a second pass would give
                # the same output (assuming D has no stochastic layers active
                # in eval).
                fake_output = self.D(fake_refine)
                err_D_fake = self.loss(fake_output, self.fake_label)
                D_G_z = fake_output.mean()

                err_D = err_D_fake + err_D_real
                err_G = self.loss(fake_output, self.real_label)

                record_D.add(err_D.item())
                record_G.add(err_G.item())
            progress.finish()
            utils.clear_progressbar()
            print('----------------------------------------')
            print('Test at Epoch %d' % self.epoch)
            print('Costs time: %.2f s' % (time.time() - start_time))
            print('Loss of G: %f' % (record_G.mean()))
            print('Loss of D: %f' % (record_D.mean()))
            print('D(x): %f, D(G(z)): %f' % (D_x, D_G_z))
            print('----------------------------------------')
            # Save the last batch: the target, the coarse trace->image output
            # fed to G, and G's refinement.
            gan_dir = self.args['gan_dir']
            utils.save_image(image.data, ('%s/image/test/target_%03d.jpg' %
                                          (gan_dir, self.epoch)))
            utils.save_image(fake_input.data,
                             ('%s/image/test/tr2im_%03d.jpg' %
                              (gan_dir, self.epoch)))
            utils.save_image(fake_refine.data,
                             ('%s/image/test/refine_%03d.jpg' %
                              (gan_dir, self.epoch)))
Exemplo n.º 6
0
def render_hourly(session):
    """Write one Graphviz ``.dot`` file per ``STEP`` window of bike movements.

    For each window ``[date, date + STEP)`` from ``START_DATE`` up to
    ``END_DATE``, this fetches the ``BIKE_MOVED`` relationships whose
    ``timestamp_start`` lies in the window. It builds a ``MultiDiGraph``
    (station names as nodes, one edge per movement, red/thicker edges for
    transporter moves) and writes it to ``OUTPUT_DIRECTORY``.

    Args:
        session: Object whose ``run(query, params)`` returns an iterable of
            records with ``'a'``, ``'r'`` and ``'b'`` entries
            (presumably a Neo4j session).
    """
    # Cypher parameters use the ``$name`` form; the legacy ``{name}`` form
    # was deprecated in Neo4j 3.0 and removed in 4.0.
    query = """
            MATCH (a:Station)-[r:BIKE_MOVED]->(b:Station)
            WHERE $start <= r.timestamp_start < $end
            RETURN a, r, b"""
    # Raw string: the label keeps a literal backslash before ':' (same value
    # as the former '%H\:%M', without the invalid-escape warning).
    # NOTE(review): presumably this escapes ':' for Graphviz; confirm.
    time_format = r'%H\:%M'
    total_steps = (END_DATE - START_DATE) / STEP
    date = START_DATE
    i = 0
    while date < END_DATE:
        i += 1
        print_progressbar(i / total_steps)
        graph = nx.MultiDiGraph()
        start = date
        end = date + STEP
        result = session.run(query, {
            'start': start.timestamp(),
            'end': end.timestamp()
        })
        for record in result:
            rel = record['r']
            station_a = record['a']['name'].replace('/', ' /\n')
            station_b = record['b']['name'].replace('/', ' /\n')
            start_time = datetime.fromtimestamp(
                rel['timestamp_start']).strftime(time_format)
            end_time = datetime.fromtimestamp(
                rel['timestamp_end']).strftime(time_format)
            label = f'{start_time} -\n{end_time}'
            is_transporter = rel['transporter']
            color = 'red' if is_transporter else '#aaaaaa'
            penwidth = 2 if is_transporter else 1
            graph.add_edge(station_a,
                           station_b,
                           label=label,
                           color=color,
                           penwidth=penwidth)
        filename = f"{start.strftime('%Y-%m-%d_%H_%M')} - {end.strftime('%Y-%m-%d_%H_%M')}.dot"
        write_dot(graph, os.path.join(OUTPUT_DIRECTORY, filename))
        date = end
    clear_progressbar()
Exemplo n.º 7
0
    # One training epoch followed by qualitative plots on a test batch.
    # NOTE(review): this is the body of an enclosing epoch loop whose header
    # is not visible here. `epoch`, `opt`, `train`, the batch generators and
    # the model parts all come from the outer scope.
    # Only encoder/decoder are switched to train mode here; presumably
    # frame_predictor/posterior/prior are switched back inside train() or
    # the enclosing code. Confirm.
    encoder.train()
    decoder.train()
    epoch_mse = 0
    epoch_kld = 0

    progress = progressbar.ProgressBar(max_value=opt.epoch_size).start()
    for i in range(opt.epoch_size):
        progress.update(i + 1)
        x = next(training_batch_generator)
        # train frame_predictor
        # train(x) returns this batch's (mse, kld) losses, accumulated below
        mse, kld = train(x)
        epoch_mse += mse
        epoch_kld += kld

    progress.finish()
    utils.clear_progressbar()

    # Mean per-batch losses for the epoch; the last number is
    # epoch * epoch_size * batch_size (sequences seen before this epoch if
    # `epoch` is 0-based).
    print('[%02d] mse loss: %.5f | kld loss: %.5f (%d)' %
          (epoch, epoch_mse / opt.epoch_size, epoch_kld / opt.epoch_size,
           epoch * opt.epoch_size * opt.batch_size))

    # plot some stuff
    # NOTE: encoder/decoder deliberately stay in train mode (their .eval()
    # calls are commented out); only the recurrent parts are switched.
    frame_predictor.eval()
    #encoder.eval()
    #decoder.eval()
    posterior.eval()
    prior.eval()

    x = next(testing_batch_generator)
    plot(x, epoch)
    plot_rec(x, epoch)
Exemplo n.º 8
0
def make_gifs(x, idx, name):
    """Write qualitative sample GIFs for one test batch (fixed N(0, I) prior).

    First builds an "approximate posterior" sequence. Conditioning frames
    are copied from ``x``; later frames are decoded using the posterior's
    second output (its mean, per the inline comment) for the true frame.
    Then draws ``opt.nsample`` rollouts. After ``opt.n_past`` conditioning
    steps, the latent of each rollout is sampled from a standard normal.
    Each rollout is scored with ``utils.eval_seq``. For every batch element,
    one GIF is written with the ground truth, the posterior sequence, the
    best-SSIM rollout and three randomly chosen rollouts.

    Args:
        x: Per-timestep batches indexed ``x[t][i]`` (time ``t``, batch
            element ``i``) with at least ``opt.n_eval`` steps.
            NOTE(review): the exact tensor shape is not visible here.
        idx: Offset added to the batch index in the output file name.
        name: File-name prefix; files are ``<log_dir>/<name>_<idx+i>.gif``.

    Uses module-level ``opt``, ``encoder``, ``decoder``, ``frame_predictor``,
    ``posterior``, ``add_border`` and ``utils``. Requires CUDA
    (``torch.cuda.FloatTensor``).
    """
    # get approx posterior sample
    frame_predictor.hidden = frame_predictor.init_hidden()
    posterior.hidden = posterior.init_hidden()
    posterior_gen = []
    posterior_gen.append(x[0])
    x_in = x[0]
    for i in range(1, opt.n_eval):
        h = encoder(x_in)
        h_target = encoder(x[i])[0].detach()
        # Skip connections are refreshed during conditioning, or at every
        # step when opt.last_frame_skip is set.
        if opt.last_frame_skip or i < opt.n_past:
            h, skip = h
        else:
            h, _ = h
        h = h.detach()
        _, z_t, _ = posterior(h_target)  # take the mean
        if i < opt.n_past:
            frame_predictor(torch.cat([h, z_t], 1))
            posterior_gen.append(x[i])
            x_in = x[i]
        else:
            h_pred = frame_predictor(torch.cat([h, z_t], 1)).detach()
            x_in = decoder([h_pred, skip]).detach()
            posterior_gen.append(x_in)

    nsample = opt.nsample
    # ssim[i, s, :] / psnr[i, s, :]: per-frame scores of rollout s for batch
    # element i over the generated frames. psnr is filled but not used below.
    ssim = np.zeros((opt.batch_size, nsample, opt.n_future))
    psnr = np.zeros((opt.batch_size, nsample, opt.n_future))
    progress = progressbar.ProgressBar(max_value=nsample).start()
    all_gen = []
    for s in range(nsample):
        progress.update(s + 1)
        gen_seq = []
        gt_seq = []
        # reset the modules' hidden state so every rollout starts fresh
        frame_predictor.hidden = frame_predictor.init_hidden()
        posterior.hidden = posterior.init_hidden()
        x_in = x[0]
        all_gen.append([])
        all_gen[s].append(x_in)
        for i in range(1, opt.n_eval):
            h = encoder(x_in)
            if opt.last_frame_skip or i < opt.n_past:
                h, skip = h
            else:
                h, _ = h
            h = h.detach()
            if i < opt.n_past:
                h_target = encoder(x[i])[0].detach()
                _, z_t, _ = posterior(h_target)
            else:
                # fixed prior: z_t ~ N(0, I)
                z_t = torch.cuda.FloatTensor(opt.batch_size,
                                             opt.z_dim).normal_()
            if i < opt.n_past:
                frame_predictor(torch.cat([h, z_t], 1))
                x_in = x[i]
                all_gen[s].append(x_in)
            else:
                h = frame_predictor(torch.cat([h, z_t], 1)).detach()
                x_in = decoder([h, skip]).detach()
                gen_seq.append(x_in.data.cpu().numpy())
                gt_seq.append(x[i].data.cpu().numpy())
                all_gen[s].append(x_in)
        _, ssim[:, s, :], psnr[:, s, :] = utils.eval_seq(gt_seq, gen_seq)

    progress.finish()
    utils.clear_progressbar()

    ###### ssim ######
    for i in range(opt.batch_size):
        gifs = [[] for t in range(opt.n_eval)]
        text = [[] for t in range(opt.n_eval)]
        mean_ssim = np.mean(ssim[i], 1)
        # rollout indices sorted by mean SSIM, ascending (best is last)
        ordered = np.argsort(mean_ssim)
        # three rollouts drawn independently (duplicates possible)
        rand_sidx = [np.random.randint(nsample) for s in range(3)]
        for t in range(opt.n_eval):
            # gt
            gifs[t].append(add_border(x[t][i], 'green'))
            text[t].append('Ground\ntruth')
            #posterior
            # green border for conditioning frames, red for generated ones
            if t < opt.n_past:
                color = 'green'
            else:
                color = 'red'
            gifs[t].append(add_border(posterior_gen[t][i], color))
            text[t].append('Approx.\nposterior')
            # best
            if t < opt.n_past:
                color = 'green'
            else:
                color = 'red'
            sidx = ordered[-1]
            gifs[t].append(add_border(all_gen[sidx][t][i], color))
            text[t].append('Best SSIM')
            # random 3
            # (reuses `color` chosen above for this time step)
            for s in range(len(rand_sidx)):
                gifs[t].append(add_border(all_gen[rand_sidx[s]][t][i], color))
                text[t].append('Random\nsample %d' % (s + 1))

        fname = '%s/%s_%d.gif' % (opt.log_dir, name, idx + i)
        utils.save_gif_with_text(fname, gifs, text)
Exemplo n.º 9
0
def make_gifs(x, idx, name):
    """Write sample GIFs/PNGs for one test batch and return best-sample scores.

    First builds an "approximate posterior" sequence. Conditioning frames
    are copied from ``x``; later frames are decoded using the posterior's
    second output (its mean, per the inline comment) for the true frame.
    Then draws ``opt.nsample`` rollouts in which, after ``opt.n_past``
    conditioning steps, latents come from the learned ``prior``. For each
    rollout this records SSIM/PSNR (via ``utils.eval_seq``) and the prior
    variance ``exp(logvar)`` at every generated step. For every batch
    element it writes two files:

    * ``<log_dir>/quality_results/<name>_<idx+i>.gif``: ground truth,
      posterior sequence, best-SSIM rollout and three random rollouts.
    * ``<log_dir>/quality_results/<name>_<idx+i>.png``: ground truth,
      best-SSIM rollout, best-PSNR rollout and the same three random rollouts.

    Args:
        x: Per-timestep batches indexed ``x[t][i]`` (time ``t``, batch
            element ``i``) with at least ``opt.n_eval`` steps.
            NOTE(review): the exact tensor shape is not visible here.
        idx: Offset added to the batch index in output file names.
        name: File-name prefix.

    Returns:
        ``(best_ssim, best_psnr, best_ssim_var, best_psnr_var)``:

        * ``best_ssim``: shape ``(batch_size, n_future)``; per-frame SSIM of
          the rollout with the highest mean SSIM for each batch element.
        * ``best_psnr``: the same shape; per-frame PSNR of the rollout with
          the highest mean PSNR.
        * ``best_ssim_var`` / ``best_psnr_var``: shape
          ``(batch_size, n_future, z_dim)``; prior variance at each generated
          step of those two rollouts.

    NOTE(review): indexing assumes ``opt.n_future == opt.n_eval - opt.n_past``.
    """
    # get approx posterior sample
    frame_predictor.hidden = frame_predictor.init_hidden()
    posterior.hidden = posterior.init_hidden()
    posterior_gen = []
    posterior_gen.append(x[0])
    x_in = x[0]
    for i in range(1, opt.n_eval):
        h = encoder(x_in)
        h_target = encoder(x[i])[0].detach()
        # Skip connections are refreshed during conditioning, or at every
        # step when opt.last_frame_skip is set.
        if opt.last_frame_skip or i < opt.n_past:
            h, skip = h
        else:
            h, _ = h
        h = h.detach()
        _, z_t, _= posterior(h_target) # take the mean
        if i < opt.n_past:
            frame_predictor(torch.cat([h, z_t], 1))
            posterior_gen.append(x[i])
            x_in = x[i]
        else:
            h_pred = frame_predictor(torch.cat([h, z_t], 1)).detach()
            x_in = decoder([h_pred, skip]).detach()
            posterior_gen.append(x_in)


    nsample = opt.nsample
    # per-frame scores of rollout s for batch element i: [i, s, :]
    ssim = np.zeros((opt.batch_size, nsample, opt.n_future))
    psnr = np.zeros((opt.batch_size, nsample, opt.n_future))
    # prior variance exp(logvar) at each generated step: [i, s, k, :]
    var_np = np.zeros((opt.batch_size, nsample, opt.n_future, opt.z_dim))
    progress = progressbar.ProgressBar(max_value=nsample).start()
    all_gen = []
    for s in range(nsample):
        progress.update(s+1)
        gen_seq = []
        gt_seq = []
        # reset the modules' hidden state so every rollout starts fresh
        frame_predictor.hidden = frame_predictor.init_hidden()
        posterior.hidden = posterior.init_hidden()
        prior.hidden = prior.init_hidden()
        x_in = x[0]
        all_gen.append([])
        all_gen[s].append(x_in)
        for i in range(1, opt.n_eval):
            h = encoder(x_in)
            if opt.last_frame_skip or i < opt.n_past:
                h, skip = h
            else:
                h, _ = h
            h = h.detach()
            if i < opt.n_past:
                # conditioning: posterior latent on the true frame; prior's
                # output is discarded (presumably called to advance its state)
                h_target = encoder(x[i])[0].detach()
                z_t, _, _ = posterior(h_target)
                prior(h)
                frame_predictor(torch.cat([h, z_t], 1))
                x_in = x[i]
                all_gen[s].append(x_in)
            else:
                z_t, mu, logvar = prior(h)
                h = frame_predictor(torch.cat([h, z_t], 1)).detach()
                x_in = decoder([h, skip]).detach()
                gen_seq.append(x_in.data.cpu().numpy())
                gt_seq.append(x[i].data.cpu().numpy())
                all_gen[s].append(x_in)

                var = torch.exp(logvar)  # BxC
                var_np[:,s, i - opt.n_past,:] = var.data.cpu().numpy()

        _, ssim[:, s, :], psnr[:, s, :] = utils.eval_seq(gt_seq, gen_seq)

    progress.finish()
    utils.clear_progressbar()

    best_ssim = np.zeros((opt.batch_size, opt.n_future))
    best_psnr = np.zeros((opt.batch_size, opt.n_future))

    best_ssim_var = np.zeros((opt.batch_size, opt.n_future, opt.z_dim))
    best_psnr_var = np.zeros((opt.batch_size, opt.n_future, opt.z_dim))
    ###### ssim ######
    for i in range(opt.batch_size):
        gifs = [ [] for t in range(opt.n_eval) ]
        text = [ [] for t in range(opt.n_eval) ]

        # rollout indices sorted by mean score, ascending (best is last)
        mean_ssim = np.mean(ssim[i], 1)
        ordered = np.argsort(mean_ssim)
        best_ssim[i,:] = ssim[i,ordered[-1],:]
        # best ssim var
        best_ssim_var[i,:,:] = var_np[i, ordered[-1], :, :]

        mean_psnr = np.mean(psnr[i], 1)
        ordered_p = np.argsort(mean_psnr)
        best_psnr[i,:] = psnr[i, ordered_p[-1],:]
        # best psnr var
        best_psnr_var[i,:,:] = var_np[i, ordered_p[-1], :, :]

        # three rollouts drawn independently (duplicates possible)
        rand_sidx = [np.random.randint(nsample) for s in range(3)]
        for t in range(opt.n_eval):
            # gt
            gifs[t].append(add_border(x[t][i], 'green'))
            text[t].append('Ground\ntruth')
            #posterior
            # green border for conditioning frames, red for generated ones
            if t < opt.n_past:
                color = 'green'
            else:
                color = 'red'
            gifs[t].append(add_border(posterior_gen[t][i], color))
            text[t].append('Approx.\nposterior')
            # best
            if t < opt.n_past:
                color = 'green'
            else:
                color = 'red'
            sidx = ordered[-1]
            gifs[t].append(add_border(all_gen[sidx][t][i], color))
            text[t].append('Best SSIM')
            # random 3
            # (reuses `color` chosen above for this time step)
            for s in range(len(rand_sidx)):
                gifs[t].append(add_border(all_gen[rand_sidx[s]][t][i], color))
                text[t].append('Random\nsample %d' % (s+1))

        #fname = '%s/%s_%d.gif' % (opt.log_dir, name, idx+i)
        fname = '%s/quality_results/%s_%d.gif' % (opt.log_dir, name, idx+i)
        utils.save_gif_with_text(fname, gifs, text)

        # -- generate samples
        # rows: ground truth, best SSIM, best PSNR, then the random rollouts
        to_plot = []
        gts = []
        best_s = []
        best_p = []
        rand_samples = [[] for s in range(len(rand_sidx))]
        for t in range(opt.n_eval):
            # gt
            gts.append(x[t][i])
            best_s.append(all_gen[ordered[-1]][t][i])
            best_p.append(all_gen[ordered_p[-1]][t][i])

            # sample
            for s in range(len(rand_sidx)):
                rand_samples[s].append(all_gen[rand_sidx[s]][t][i])

        to_plot.append(gts)
        to_plot.append(best_s)
        to_plot.append(best_p)
        for s in range(len(rand_sidx)):
            to_plot.append(rand_samples[s])
        fname = '%s/quality_results/%s_%d.png' % (opt.log_dir, name, idx+i)
        utils.save_tensors_image(fname, to_plot)

    return best_ssim, best_psnr, best_ssim_var, best_psnr_var
Exemplo n.º 10
0
# --------- training loop ------------------------------------
# Each epoch: opt.epoch_size training batches, qualitative plots on one test
# batch, a loss summary, and a checkpoint of the whole `lstm` module plus
# `opt` (overwriting <log_dir>/model.pth every epoch).
for epoch in range(opt.niter):
    lstm.train()
    epoch_loss = 0
    progress = progressbar.ProgressBar(max_value=opt.epoch_size).start()
    for i in range(opt.epoch_size):
        progress.update(i + 1)
        x = next(training_batch_generator)

        # train lstm; train(x) returns this batch's loss
        epoch_loss += train(x)

    progress.finish()
    utils.clear_progressbar()

    # plot some stuff in eval mode
    lstm.eval()
    x = next(testing_batch_generator)
    plot_gen(x, epoch)
    plot_rec(x, epoch)

    print('[%02d] mse loss: %.6f (%d)' % (epoch, epoch_loss/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size))

    # save the model
    torch.save({
        'lstm': lstm,
        'opt': opt},
        '%s/model.pth' % opt.log_dir)
Exemplo n.º 11
0
def make_gifs(x, idx, name):
    """Write sample GIFs/PNGs for one batch (content/pose model) and return scores.

    A content latent ``mu_c`` (and the decoder ``skip`` features) is computed
    once by ``cont_encoder`` from the first ``opt.n_past`` frames. The
    frames are shuffled and concatenated along dim 1. The "approximate
    posterior" sequence then samples pose latents from ``posterior_pose``
    on each true frame. Next, ``opt.nsample`` rollouts sample pose latents
    from ``prior`` after ``opt.n_past`` conditioning steps, each new prior
    input being the pose encoding of the previously generated frame.
    Rollouts are scored with ``utils.eval_seq``. For every batch element
    this writes two files:

    * ``<log_dir>/samples/<name>_<idx+i>.gif``: ground truth, posterior
      sequence, best-SSIM rollout and three random rollouts.
    * ``<log_dir>/samples/<name>_<idx+i>.png``: ground truth, best-PSNR
      rollout and the same three random rollouts.

    Args:
        x: Per-timestep batches indexed ``x[t][i]`` with at least
            ``opt.n_eval`` steps. NOTE(review): shape not visible here.
        idx: Offset added to the batch index in output file names.
        name: File-name prefix.

    Returns:
        ``(best_ssim, best_psnr)``, each of shape ``(batch_size, n_future)``:
        the per-frame scores of the rollout with the highest mean SSIM
        (resp. PSNR) for each batch element.

    NOTE(review): no hidden-state reset is visible here; the recurrent
    modules are driven with an explicit ``time_step=i - 1`` argument.
    Confirm how they handle state between rollouts.
    """
    # get approx posterior sample
    posterior_gen = []
    posterior_gen.append(x[0])
    # NOTE(review): x_in is assigned in this first pass but never read here
    x_in = x[0]
    # ------------ calculate the content posterior
    xs = []
    for i in range(0, opt.n_past):
        xs.append(x[i])
    #if True:
    # conditioning frames are shuffled before being concatenated (in place)
    random.shuffle(xs)
    #xc = torch.cat(xs, 1)
    mu_c, logvar_c, skip = cont_encoder(torch.cat(xs, 1))
    mu_c = mu_c.detach()

    for i in range(1, opt.n_eval):
        h_target = pose_encoder(x[i]).detach()
        # pose latent is sampled (reparameterized), not the posterior mean
        mu_t_p, logvar_t_p = posterior_pose(torch.cat([h_target, mu_c], 1),
                                            time_step=i - 1)
        z_t_p = utils.reparameterize(mu_t_p, logvar_t_p)
        if i < opt.n_past:
            frame_predictor(torch.cat([z_t_p, mu_c], 1), time_step=i - 1)
            posterior_gen.append(x[i])
            x_in = x[i]
        else:
            h_pred = frame_predictor(torch.cat([z_t_p, mu_c], 1),
                                     time_step=i - 1).detach()
            x_in = decoder([h_pred, skip]).detach()
            posterior_gen.append(x_in)

    nsample = opt.nsample
    # per-frame scores of rollout s for batch element i: [i, s, :]
    ssim = np.zeros((opt.batch_size, nsample, opt.n_future))
    psnr = np.zeros((opt.batch_size, nsample, opt.n_future))

    #ccm_pred = np.zeros((opt.batch_size, nsample, opt.n_eval-1))
    #ccm_gt = np.zeros((opt.batch_size, opt.n_eval-1))

    # NOTE(review): uses the legacy `maxval=` keyword, while other code here
    # uses `max_value=`; which one works depends on the progressbar version.
    progress = progressbar.ProgressBar(maxval=nsample).start()
    all_gen = []
    '''for i in range(1, opt.n_eval):
        out_gt = discriminator(torch.cat([x[0], x[i]],dim=1))
        ccm_i_gt = out_gt.mean().data.cpu().numpy()
        print('time step %d, mean out gt: %.4f'%(i,ccm_i_gt))
        ccm_gt[:,i-1] = out_gt.squeeze().data.cpu().numpy()'''

    # NOTE(review): `hs` is built but never used in this function.
    hs = []
    for i in range(0, opt.n_past):
        hs.append(pose_encoder(x[i]).detach())

    for s in range(nsample):
        progress.update(s + 1)
        gen_seq = []
        gt_seq = []
        x_in = x[0]
        all_gen.append([])
        all_gen[s].append(x_in)

        # h: pose encoding of the most recent (true or generated) frame
        h = pose_encoder(x[0]).detach()
        for i in range(1, opt.n_eval):
            h_target = pose_encoder(x[i]).detach()

            if i < opt.n_past:
                # conditioning: posterior pose latent; prior output discarded
                mu_t_p, logvar_t_p = posterior_pose(torch.cat([h_target, mu_c],
                                                              1),
                                                    time_step=i - 1)
                z_t_p = utils.reparameterize(mu_t_p, logvar_t_p)
                prior(torch.cat([h, mu_c], 1), time_step=i - 1)
                frame_predictor(torch.cat([z_t_p, mu_c], 1), time_step=i - 1)
                x_in = x[i]
                all_gen[s].append(x_in)
                h = h_target
            else:
                # generation: prior pose latent from the previous frame's pose
                mu_t_pp, logvar_t_pp = prior(torch.cat([h, mu_c], 1),
                                             time_step=i - 1)
                z_t = utils.reparameterize(mu_t_pp, logvar_t_pp)
                h_pred = frame_predictor(torch.cat([z_t, mu_c], 1),
                                         time_step=i - 1).detach()
                x_in = decoder([h_pred, skip]).detach()
                gen_seq.append(x_in.data.cpu().numpy())
                gt_seq.append(x[i].data.cpu().numpy())
                all_gen[s].append(x_in)
                h = pose_encoder(x_in).detach()

            #out_pred = discriminator(torch.cat([x[0],x_in],dim=1))
            #ccm_i_pred = out_pred.mean().data.cpu().numpy()
            #print('time step %d, mean out pred: %.4f'%(i,ccm_i_pred))
            #ccm_pred[:, s, i-1] = out_pred.squeeze().data.cpu().numpy()

        _, ssim[:, s, :], psnr[:, s, :] = utils.eval_seq(gt_seq, gen_seq)

    progress.finish()
    utils.clear_progressbar()

    best_ssim = np.zeros((opt.batch_size, opt.n_future))
    best_psnr = np.zeros((opt.batch_size, opt.n_future))
    ###### ssim ######
    for i in range(opt.batch_size):
        gifs = [[] for t in range(opt.n_eval)]
        text = [[] for t in range(opt.n_eval)]

        # rollout indices sorted by mean score, ascending (best is last)
        mean_ssim = np.mean(ssim[i], 1)
        ordered = np.argsort(mean_ssim)
        best_ssim[i, :] = ssim[i, ordered[-1], :]

        mean_psnr = np.mean(psnr[i], 1)
        ordered_p = np.argsort(mean_psnr)
        best_psnr[i, :] = psnr[i, ordered_p[-1], :]

        # three rollouts drawn independently (duplicates possible)
        rand_sidx = [np.random.randint(nsample) for s in range(3)]

        # -- generate gifs
        for t in range(opt.n_eval):
            # gt
            gifs[t].append(add_border(x[t][i], 'green'))
            text[t].append('Ground\ntruth')
            #posterior
            # green border for conditioning frames, red for generated ones
            if t < opt.n_past:
                color = 'green'
            else:
                color = 'red'
            gifs[t].append(add_border(posterior_gen[t][i], color))
            text[t].append('Approx.\nposterior')
            # best
            if t < opt.n_past:
                color = 'green'
            else:
                color = 'red'
            sidx = ordered[-1]
            gifs[t].append(add_border(all_gen[sidx][t][i], color))
            text[t].append('Best SSIM')
            # random 3
            # (reuses `color` chosen above for this time step)
            for s in range(len(rand_sidx)):
                gifs[t].append(add_border(all_gen[rand_sidx[s]][t][i], color))
                text[t].append('Random\nsample %d' % (s + 1))

        fname = '%s/samples/%s_%d.gif' % (opt.log_dir, name, idx + i)
        utils.save_gif_with_text(fname, gifs, text)

        # -- generate samples
        # rows: ground truth, best PSNR, then the random rollouts
        to_plot = []
        gts = []
        best_p = []
        rand_samples = [[] for s in range(len(rand_sidx))]
        for t in range(opt.n_eval):
            # gt
            gts.append(x[t][i])
            best_p.append(all_gen[ordered_p[-1]][t][i])

            # sample
            for s in range(len(rand_sidx)):
                rand_samples[s].append(all_gen[rand_sidx[s]][t][i])

        to_plot.append(gts)
        to_plot.append(best_p)
        for s in range(len(rand_sidx)):
            to_plot.append(rand_samples[s])
        fname = '%s/samples/%s_%d.png' % (opt.log_dir, name, idx + i)
        utils.save_tensors_image(fname, to_plot)

    return best_ssim, best_psnr  #, ccm_pred, ccm_gt
Exemplo n.º 12
0
def make_gifs(x, idx, names):
    """Save every sampled rollout, plus the ground truth, as GIFs for one batch.

    First runs an approximate-posterior pass, using the posterior's second
    output (its mean, per the inline comment). Then draws ``opt.nsample``
    rollouts in which, after ``opt.n_past`` conditioning steps, latents come
    from the learned ``prior``. When ``opt.use_action`` is set, the action
    of the previous step, tiled ``opt.a_dim`` times along dim 1, is
    appended to every frame-predictor input. For each batch element ``i``
    this writes:

    * ``<log_dir>/<names[i]>_<s>.gif`` for every rollout ``s``;
    * ``<log_dir>/<names[i]>_gt.gif`` for the ground truth.

    Args:
        x: Per-timestep batches indexed ``x[t][i]``. With
            ``opt.use_action`` it is instead a pair ``(frames, actions)``
            with ``frames`` indexed like that and ``actions[t]`` the action
            batch for step ``t``.
        idx: Not used by this function.
        names: Per-batch-element file-name stems.

    NOTE(review): ``posterior_gen`` and the SSIM/PSNR arrays are computed but
    not used for the saved output in this variant.
    """
    if opt.use_action:
        actions = x[1]
        x = x[0]
    # Ground-truth frames. This must be taken *after* unpacking: with
    # opt.use_action the raw `x` is the (frames, actions) pair, and indexing
    # it by time step would pick the wrong objects (or raise IndexError).
    all_gt = x

    def _predictor_input(h, z_t, i):
        # Frame-predictor input for step i: features + latent, plus the
        # previous step's action tiled a_dim times when actions are used.
        if opt.use_action:
            return torch.cat([h, z_t, actions[i - 1].repeat(1, opt.a_dim)], 1)
        return torch.cat([h, z_t], 1)

    def _to_uint8(frame):
        # Move channels last, clamp to [0, 1] and scale to uint8 for GIFs.
        # Assumes a (C, H, W) tensor; TODO confirm.
        img = frame.cpu().transpose(0, 1).transpose(1, 2).clamp(0, 1).numpy()
        return (img * 255).astype(np.uint8)

    # get approx posterior sample
    frame_predictor.hidden = frame_predictor.init_hidden()
    posterior.hidden = posterior.init_hidden()
    posterior_gen = [x[0]]
    x_in = x[0]
    for i in range(1, opt.n_eval):
        h = encoder(x_in)
        h_target = encoder(x[i])[0].detach()
        # Skip connections are refreshed during conditioning, or at every
        # step when opt.last_frame_skip is set.
        if opt.last_frame_skip or i < opt.n_past:
            h, skip = h
        else:
            h, _ = h
        h = h.detach()
        _, z_t, _ = posterior(h_target)  # take the mean
        if i < opt.n_past:
            frame_predictor(_predictor_input(h, z_t, i))
            posterior_gen.append(x[i])
            x_in = x[i]
        else:
            h_pred = frame_predictor(_predictor_input(h, z_t, i)).detach()
            x_in = decoder([h_pred, skip]).detach()
            posterior_gen.append(x_in)

    nsample = opt.nsample
    ssim = np.zeros((opt.batch_size, nsample, opt.n_future))
    psnr = np.zeros((opt.batch_size, nsample, opt.n_future))
    progress = progressbar.ProgressBar(maxval=nsample).start()
    all_gen = []
    for s in range(nsample):
        progress.update(s + 1)
        gen_seq = []
        gt_seq = []
        # reset the modules' hidden state so every rollout starts fresh
        frame_predictor.hidden = frame_predictor.init_hidden()
        posterior.hidden = posterior.init_hidden()
        prior.hidden = prior.init_hidden()
        x_in = x[0]
        all_gen.append([])
        all_gen[s].append(x_in)
        for i in range(1, opt.n_eval):
            h = encoder(x_in)
            if opt.last_frame_skip or i < opt.n_past:
                h, skip = h
            else:
                h, _ = h
            h = h.detach()
            if i < opt.n_past:
                # conditioning: posterior latent on the true frame; the
                # prior's output is discarded (presumably advances its state)
                h_target = encoder(x[i])[0].detach()
                z_t, _, _ = posterior(h_target)
                prior(h)
                frame_predictor(_predictor_input(h, z_t, i))
                x_in = x[i]
                all_gen[s].append(x_in)
            else:
                z_t, _, _ = prior(h)
                h = frame_predictor(_predictor_input(h, z_t, i)).detach()
                x_in = decoder([h, skip]).detach()
                gen_seq.append(x_in.data.cpu().numpy())
                gt_seq.append(x[i].data.cpu().numpy())
                all_gen[s].append(x_in)
        _, ssim[:, s, :], psnr[:, s, :] = utils.eval_seq(gt_seq, gen_seq)

    progress.finish()
    utils.clear_progressbar()

    for i in range(opt.batch_size):
        for s in range(nsample):
            imgs = [_to_uint8(all_gen[s][t][i]) for t in range(opt.n_eval)]
            fname = '%s/%s_%d.gif' % (opt.log_dir, names[i], s)
            utils.save_gif_IROS_2019(fname, imgs)

        # save ground truth
        imgs_gt = [_to_uint8(all_gt[t][i]) for t in range(opt.n_eval)]
        fname = '%s/%s_gt.gif' % (opt.log_dir, names[i])
        utils.save_gif_IROS_2019(fname, imgs_gt)