Example #1
0
    def plot_rec(self, x, actions, epoch):
        """Reconstruct a sequence with the learned posterior and save a plot.

        For the first ``n_past`` frames the ground-truth frames are kept;
        afterwards frames are predicted from the frame predictor + decoder.
        The resulting rows (one per batch element, up to 10) are written to
        ``<logdir>/gen/rec_<epoch>.png``.

        Args:
            x: batch of frame sequences, indexed as ``x[:, t]``.
               (assumes shape (batch, seq_len, ...) — TODO confirm)
            actions: per-step action tensors, indexed as ``actions[:, t]``.
            epoch: epoch number used in the output filename.
        """
        self.frame_predictor.hidden = self.frame_predictor.init_hidden()
        self.posterior.hidden = self.posterior.init_hidden()
        gen_seq = []
        gen_seq.append(x[:, 0])
        x_in = x[:, 0]
        # BUG FIX: was `params["seq_len"]` (bare global); every other access in
        # this method uses self.params — use it consistently here too.
        h_seq = [self.encoder(x[:, i]) for i in range(self.params["seq_len"])]
        for i in range(1, self.params["seq_len"]):
            h_target = h_seq[i][0].detach()
            # Keep skip connections from the last conditioning frame unless
            # last_frame_skip is set (then always take them from frame i-1).
            if self.params["last_frame_skip"] or i < self.params["n_past"]:
                h, skip = h_seq[i - 1]
            else:
                h, _ = h_seq[i - 1]
            h = h.detach()
            z_t, mu, logvar = self.posterior(h_target)
            if i < self.params["n_past"]:
                # Conditioning phase: advance the predictor's state but keep
                # the ground-truth frame in the output sequence.
                self.frame_predictor(torch.cat([h, z_t, actions[:, i - 1]], 1))
                gen_seq.append(x[:, i])
            else:
                h = self.frame_predictor(torch.cat([h, z_t, actions[:, i - 1]], 1)).detach()
                x_pred = self.decoder([h, skip]).detach()
                gen_seq.append(x_pred)

        # Arrange up to 10 rows (one per batch element) of the full sequence.
        to_plot = []
        nrow = min(self.params["batch_size"], 10)
        for i in range(nrow):
            row = []
            for t in range(self.params["seq_len"]):
                row.append(gen_seq[t][i])
            to_plot.append(row)
        # BUG FIX: was `params["logdir"]` (bare global) — use self.params.
        check_dir(self.params["logdir"], "gen")
        fname = '%s/gen/rec_%d.png' % (self.params["logdir"], epoch)
        svp_utils.save_tensors_image(fname, to_plot)
Example #2
0
 def init_vae_model(self):
     """Create the VAE log directory and, unless disabled, reload the best
     checkpoint into ``self.model`` / ``self.optimizer``."""
     self.vae_dir = os.path.join(self.params["logdir"], 'vae')
     check_dir(self.vae_dir, 'samples')
     if self.params["noreload"]:
         return
     # Restore model and optimizer state from the best saved checkpoint.
     reload_file = os.path.join(self.params["vae_location"], 'best.tar')
     state = torch.load(reload_file)
     msg = "Reloading model at epoch {}, with eval error {}"
     print(msg.format(state['epoch'], state['precision']))
     self.model.load_state_dict(state['state_dict'])
     self.optimizer.load_state_dict(state['optimizer'])
Example #3
0
 def init_svg_model(self):
     """Create the SVG log directory (with a 'samples' subfolder) under logdir."""
     svg_root = os.path.join(self.params["logdir"], 'svg')
     self.svg_dir = svg_root
     check_dir(svg_root, 'samples')
Example #4
0
if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='SVP_LP')
    parser.add_argument('--params', default="params/svg_fp_train_params.yaml", metavar='params',
                        help="Path to file containing parameters for training")
    args = parser.parse_args()

    # Load the training configuration from the YAML file given on the CLI.
    with open(args.params, 'r') as stream:
        try:
            params = yaml.safe_load(stream)
            print(params)
        except yaml.YAMLError as exc:
            # BUG FIX: original read `print(exc` — missing closing
            # parenthesis, a syntax error.
            print(exc)

    # Check if directories exist; if not, create them.
    check_dir(params["logdir"])
    out_path = params["logdir"] + "/train_params.json"
    with open(out_path, 'w') as outfile:
        json.dump(params, outfile)

    if params["sample"]:
        check_dir(params["logdir"], 'results')
    # CHANGE: explicit GPU pinning left disabled on purpose.
    # device = torch.device('cuda:1')
    # torch.cuda.set_device(device)

    # Initialize training
    trainer = SVG_FP_TRAINER(params)
    trainer.init_svg_model()
    train = partial(trainer.data_pass, train=True)
    test = partial(trainer.data_pass, train=False)
Example #5
0
import json

# ---------------- command line & parameters ----------------
parser = argparse.ArgumentParser(description='SVP_FP_eval')
parser.add_argument('--params',
                    default="params_svg/svg_fp_eval.yaml",
                    metavar='params',
                    help="Path to file containing parameters for training")
args = parser.parse_args()

# Read the evaluation configuration from the YAML file.
with open(args.params, 'r') as stream:
    try:
        eval_params = yaml.safe_load(stream)
        print(eval_params)
    except yaml.YAMLError as exc:
        print(exc)
check_dir(eval_params["logdir"])

torch.manual_seed(eval_params["seed"])

# ---------------- load the models ----------------

tmp = torch.load(eval_params["model_path"], map_location="cuda:0")
params = tmp["params"]
frame_predictor = tmp['frame_predictor']
posterior = tmp['posterior']
encoder = tmp['encoder']
decoder = tmp['decoder']
# Put every network into inference mode.
for _net in (frame_predictor, posterior, encoder, decoder):
    _net.eval()