def plot_rec(self, x, actions, epoch):
    """Save an image grid of one-step posterior reconstructions for a batch.

    The first ``n_past`` frames are shown as ground truth, while the frame
    predictor is still stepped on them. Every later frame is the decoder
    output for the frame predictor fed with three things:

    - the encoding of the previous *ground-truth* frame;
    - the posterior latent of the current ground-truth frame;
    - the action at ``i - 1``.

    Args:
        x: Batch of frame sequences, indexed as ``x[:, t]`` for time step
            ``t`` and as ``x[b]`` for sample ``b`` (presumably
            ``(batch, time, C, H, W)`` -- TODO confirm).
        actions: Action sequences, indexed as ``actions[:, t]``. They are
            concatenated along dim 1 with the encoder features and latent.
        epoch: Epoch number, used in the output file name.

    Writes ``<logdir>/gen/rec_<epoch>.png`` via ``svp_utils.save_tensors_image``.
    """
    # Re-initialize the recurrent hidden state (via init_hidden) before this pass.
    self.frame_predictor.hidden = self.frame_predictor.init_hidden()
    self.posterior.hidden = self.posterior.init_hidden()

    # Read every setting from self.params. A module-level ``params`` only exists
    # when the training script runs as __main__, and may differ from the trainer's.
    seq_len = self.params["seq_len"]
    n_past = self.params["n_past"]
    last_frame_skip = self.params["last_frame_skip"]
    logdir = self.params["logdir"]

    # Each encoder output is unpacked below as a (features, skip) pair.
    h_seq = [self.encoder(x[:, t]) for t in range(seq_len)]
    gen_seq = [x[:, 0]]

    for i in range(1, seq_len):
        h_target = h_seq[i][0].detach()
        if last_frame_skip or i < n_past:
            h, skip = h_seq[i - 1]
        else:
            # Reuse the skip connections assigned at the last conditioning
            # step (i == n_past - 1).
            # NOTE(review): if n_past < 2 and last_frame_skip is False, `skip`
            # is never assigned and the decoder call below raises.
            h, _ = h_seq[i - 1]
        h = h.detach()
        z_t, _mu, _logvar = self.posterior(h_target)
        predictor_in = torch.cat([h, z_t, actions[:, i - 1]], 1)
        if i < n_past:
            # Conditioning phase: advance the predictor, but show ground truth.
            self.frame_predictor(predictor_in)
            gen_seq.append(x[:, i])
        else:
            h_pred = self.frame_predictor(predictor_in).detach()
            gen_seq.append(self.decoder([h_pred, skip]).detach())

    # One row per sample, one column per time step. Also bounded by the actual
    # batch size so a short final batch does not index out of range.
    nrow = min(self.params["batch_size"], x.shape[0], 10)
    to_plot = [[frame_t[row] for frame_t in gen_seq] for row in range(nrow)]

    check_dir(logdir, "gen")
    fname = '%s/gen/rec_%d.png' % (logdir, epoch)
    svp_utils.save_tensors_image(fname, to_plot)
def init_vae_model(self):
    """Set up ``<logdir>/vae`` and, unless disabled, restore the best VAE checkpoint.

    Sets ``self.vae_dir`` and passes it to ``check_dir`` together with a
    ``samples`` subdirectory. When ``self.params["noreload"]`` is falsy, it
    loads ``<vae_location>/best.tar`` and restores ``self.model`` and
    ``self.optimizer`` from its ``state_dict`` / ``optimizer`` entries.
    """
    vae_dir = os.path.join(self.params["logdir"], 'vae')
    self.vae_dir = vae_dir
    check_dir(vae_dir, 'samples')

    if self.params["noreload"]:
        return

    # The checkpoint's existence is not checked here; a missing file surfaces
    # as an error from torch.load.
    checkpoint_path = os.path.join(self.params["vae_location"], 'best.tar')
    checkpoint = torch.load(checkpoint_path)
    message = "Reloading model at epoch {}, with eval error {}"
    print(message.format(checkpoint['epoch'], checkpoint['precision']))
    self.model.load_state_dict(checkpoint['state_dict'])
    self.optimizer.load_state_dict(checkpoint['optimizer'])
def init_svg_model(self):
    """Record ``<logdir>/svg`` as ``self.svg_dir`` and prepare its ``samples`` subdirectory via ``check_dir``."""
    svg_dir = os.path.join(self.params["logdir"], 'svg')
    self.svg_dir = svg_dir
    check_dir(svg_dir, 'samples')
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='SVP_LP')
    parser.add_argument('--params', default="params/svg_fp_train_params.yaml", metavar='params',
                        help="Path to file containing parameters for training")
    args = parser.parse_args()

    # Load the training configuration. A parse failure is fatal: everything
    # below depends on `params`, so continuing would only fail later with a
    # NameError.
    with open(args.params, 'r') as stream:
        try:
            params = yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            raise SystemExit("Could not parse parameter file {}: {}".format(args.params, exc)) from exc
    print(params)

    # Check if directories exist (check_dir creates them if not), then keep a
    # copy of the parameters alongside the run's outputs.
    check_dir(params["logdir"])
    out_path = params["logdir"] + "/train_params.json"
    with open(out_path, 'w') as outfile:
        json.dump(params, outfile)

    if params["sample"]:
        check_dir(params["logdir"], 'results')

    # Optional: pin training to a specific GPU.
    # device = torch.device('cuda:1')
    # torch.cuda.set_device(device)

    # Initialize training
    trainer = SVG_FP_TRAINER(params)
    trainer.init_svg_model()
    train = partial(trainer.data_pass, train=True)
    test = partial(trainer.data_pass, train=False)
import json

parser = argparse.ArgumentParser(description='SVP_FP_eval')
parser.add_argument('--params', default="params_svg/svg_fp_eval.yaml", metavar='params',
                    help="Path to file containing parameters for training")
args = parser.parse_args()

# Load the evaluation configuration. A parse failure is fatal, because every
# later statement reads `eval_params`.
with open(args.params, 'r') as stream:
    try:
        eval_params = yaml.safe_load(stream)
    except yaml.YAMLError as exc:
        raise SystemExit("Could not parse parameter file {}: {}".format(args.params, exc)) from exc
print(eval_params)

check_dir(eval_params["logdir"])
torch.manual_seed(eval_params["seed"])

# ---------------- load the models ----------------
# NOTE(security): torch.load unpickles arbitrary Python objects (whole modules
# are stored here), so only load checkpoints from trusted sources.
# Use the first GPU when available, otherwise fall back to CPU so evaluation
# does not fail on machines without CUDA.
map_location = "cuda:0" if torch.cuda.is_available() else "cpu"
tmp = torch.load(eval_params["model_path"], map_location=map_location)
params = tmp["params"]
frame_predictor = tmp['frame_predictor']
posterior = tmp['posterior']
frame_predictor.eval()
posterior.eval()
encoder = tmp['encoder']
decoder = tmp['decoder']
encoder.eval()
decoder.eval()