Esempio n. 1
0
    def plot_rec(self, x, actions, epoch):
        """Save a grid of posterior reconstructions for the first rows of a batch.

        The first ``n_past`` frames of every row are copied from ``x``.

        Each later frame is built in three steps:

        1. ``frame_predictor`` receives the encoding of the previous
           ground-truth frame, a latent from ``posterior`` (computed from the
           current ground-truth frame's encoding) and the previous action.
        2. The predictor output is passed to ``decoder``.
        3. The decoded frame is appended to the row.

        Args:
            x: frame tensor indexed as ``x[:, t]``; presumably
                (batch, time, ...) — TODO confirm.
            actions: action tensor indexed as ``actions[:, t]``.
            epoch: integer embedded in the output file name.

        Writes ``<logdir>/gen/rec_<epoch>.png`` via ``svp_utils``.
        """
        seq_len = self.params["seq_len"]
        n_past = self.params["n_past"]

        # Reset recurrent state so the plot does not depend on earlier calls.
        self.frame_predictor.hidden = self.frame_predictor.init_hidden()
        self.posterior.hidden = self.posterior.init_hidden()

        gen_seq = [x[:, 0]]
        # Encode every ground-truth frame once; each entry is unpacked as (h, skip).
        h_seq = [self.encoder(x[:, i]) for i in range(seq_len)]
        for i in range(1, seq_len):
            h_target = h_seq[i][0].detach()
            if self.params["last_frame_skip"] or i < n_past:
                h, skip = h_seq[i - 1]
            else:
                # Keep the skip connections from the last conditioning frame.
                # NOTE(review): if n_past <= 1 and last_frame_skip is False,
                # `skip` is never assigned and the decoder call below raises.
                h, _ = h_seq[i - 1]
            h = h.detach()
            z_t, _, _ = self.posterior(h_target)
            h_pred = self.frame_predictor(torch.cat([h, z_t, actions[:, i - 1]], 1))
            if i < n_past:
                # Conditioning phase: the call above only advances the
                # predictor's hidden state; show the ground-truth frame.
                gen_seq.append(x[:, i])
            else:
                gen_seq.append(self.decoder([h_pred.detach(), skip]).detach())

        # Never index past the rows actually present in x.
        nrow = min(self.params["batch_size"], x.shape[0], 10)
        to_plot = [[gen_seq[t][i] for t in range(seq_len)] for i in range(nrow)]

        check_dir(self.params["logdir"], "gen")
        fname = '%s/gen/rec_%d.png' % (self.params["logdir"], epoch)
        svp_utils.save_tensors_image(fname, to_plot)
Esempio n. 2
0
    def plot_samples(self, x, actions, epoch, nsample=5):
        """Save an image grid and a GIF comparing ground truth with rollouts.

        For each of ``nsample`` rollouts, the predictor's hidden state is reset.
        Frames ``1 .. n_past-1`` are fed from ``x``. Later frames are
        predicted from the encoding of the previously generated frame plus
        the previous action, then decoded; the decoded frame becomes the next
        input (autoregressive).

        Args:
            x: frame tensor indexed as ``x[:, t]``; presumably
                (batch, time, ...) — TODO confirm. Must contain at least
                ``n_eval`` time steps.
            actions: action tensor indexed as ``actions[:, t]``.
            epoch: integer embedded in the output file names.
            nsample: number of rollouts to draw (default 5).

        Writes ``<logdir>/gen/sample_<epoch>.png`` and
        ``<logdir>/gen/sample_<epoch>.gif``, creating ``<logdir>/gen`` if needed.
        """
        import os

        n_eval = self.params["n_eval"]
        n_past = self.params["n_past"]

        gt_seq = [x[:, i] for i in range(x.shape[1])]
        gen_seq = [[] for _ in range(nsample)]

        # NOTE(review): no latent is fed to frame_predictor here, so rollouts
        # differ only if the modules themselves are stochastic — confirm intent.
        for s in range(nsample):
            self.frame_predictor.hidden = self.frame_predictor.init_hidden()
            x_in = x[:, 0]
            gen_seq[s].append(x_in)
            for i in range(1, n_eval):
                h, frame_skip = self.encoder(x_in)
                if self.params["last_frame_skip"] or i < n_past:
                    skip = frame_skip
                # NOTE(review): if n_past <= 1 and last_frame_skip is False,
                # `skip` is never assigned and the decoder call below raises.
                h = h.detach()
                h_pred = self.frame_predictor(torch.cat([h, actions[:, i - 1]], 1))
                if i < n_past:
                    # Conditioning phase: the call above only advances the
                    # predictor's hidden state; feed the ground-truth frame.
                    x_in = x[:, i]
                else:
                    x_in = self.decoder([h_pred.detach(), skip]).detach()
                gen_seq[s].append(x_in)

        # Never index past the rows actually present in x.
        nrow = min(self.params["batch_size"], x.shape[0], 10)
        to_plot = []
        gifs = [[] for _ in range(n_eval)]
        for i in range(nrow):
            # Ground-truth row followed by one row per rollout.
            to_plot.append([gt_seq[t][i] for t in range(n_eval)])
            for s in range(nsample):
                to_plot.append([gen_seq[s][t][i] for t in range(n_eval)])
            # Each GIF frame t shows ground truth beside every rollout.
            for t in range(n_eval):
                gifs[t].append([gt_seq[t][i]] + [gen_seq[s][t][i] for s in range(nsample)])

        # Ensure the output directory exists before writing.
        os.makedirs('%s/gen' % self.params["logdir"], exist_ok=True)

        fname = '%s/gen/sample_%d.png' % (self.params["logdir"], epoch)
        svp_utils.save_tensors_image(fname, to_plot)

        fname = '%s/gen/sample_%d.gif' % (self.params["logdir"], epoch)
        svp_utils.save_gif(fname, gifs)