def get_data_history(self, idx, t_sep, fillings, model_args=None, label=None):
    """Build the model-input dict for a list of per-group path tensors.

    Unlike ``get_data`` this variant applies no group-count or length limits:
    the grouped tensor is left unpadded and every per-group tensor is padded
    to a fixed length of ``20 + 2`` (the ``+ 2`` matches the ``add_eos`` /
    ``add_sos`` calls).

    Args:
        idx: Unused here; kept for interface compatibility.
        t_sep: Sequence of per-group tensors that ``torch.cat`` can join along
            dim 0.  # presumably each is (n_commands, 14) -- confirm with caller
        fillings: Per-group filling values, parallel to ``t_sep``. Only read
            when ``"filling"`` is requested. May be ``None`` or shorter than
            ``t_sep``; missing entries default to ``0``.
        model_args: Names of the entries to produce. Defaults to
            ``self.model_args``. A ``"_grouped"`` suffix selects the
            concatenated tensor instead of the per-group ones.
        label: Value stored under ``"label"`` when requested.

    Returns:
        dict mapping each requested name to its value.
    """
    if model_args is None:
        model_args = self.model_args

    # Fixed per-group length before the two extra SOS/EOS slots.
    seq_len = 20

    t_grouped = [
        SVGTensor.from_data(torch.cat(t_sep, dim=0), PAD_VAL=self.PAD_VAL).add_eos().add_sos()
    ]
    t_normal = [
        SVGTensor.from_data(t, PAD_VAL=self.PAD_VAL).add_eos().add_sos().pad(seq_len=seq_len + 2)
        for t in t_sep
    ]

    res = {}
    for arg in set(model_args):
        if "_grouped" in arg:
            arg_ = arg.split("_grouped")[0]
            t_list = t_grouped
        else:
            arg_ = arg
            t_list = t_normal

        if arg_ == "tensor":
            res[arg] = t_list
        elif arg_ == "commands":
            res[arg] = torch.stack([t.cmds() for t in t_list])
        elif arg_ == "args_rel":
            res[arg] = torch.stack([t.get_relative_args() for t in t_list])
        elif arg_ == "args":
            res[arg] = torch.stack([t.args() for t in t_list])

    if "filling" in model_args:
        # The raw tensors in ``t_sep`` carry no ``filling`` attribute, so the
        # values come from ``fillings``. An element that does expose
        # ``filling`` still takes precedence, which keeps the earlier
        # behavior for such inputs.
        fill_values = list(fillings) if fillings is not None else []
        fill_values += [0] * (len(t_sep) - len(fill_values))
        res["filling"] = torch.stack([
            torch.tensor(getattr(t, "filling", f))
            for t, f in zip(t_sep, fill_values)
        ]).unsqueeze(-1)

    if "label" in model_args:
        res["label"] = label

    return res
def get_data(self, idx, t_sep, fillings, model_args=None, label=None):
    """Build the fixed-size model-input dict for a list of per-group tensors.

    The sample is rejected (``None`` is returned) when it has more than
    ``self.MAX_NUM_GROUPS`` groups or when any group has more than
    ``self.MAX_SEQ_LEN`` commands. Otherwise the group list is padded with
    empty ``(0, 14)`` tensors up to ``MAX_NUM_GROUPS``, the concatenated
    tensor is padded to ``MAX_TOTAL_LEN + 2`` and each group to
    ``MAX_SEQ_LEN + 2`` (the ``+ 2`` matches the ``add_eos`` / ``add_sos``
    calls).

    Note:
        ``t_sep`` is padded in place, so the caller's list grows to
        ``MAX_NUM_GROUPS`` entries, even when the function later returns
        ``None`` because one group is too long.

    Args:
        idx: Unused here; kept for interface compatibility.
        t_sep: List of per-group tensors that ``torch.cat`` can join along
            dim 0.  # presumably each is (n_commands, 14) -- confirm with caller
        fillings: Per-group filling values, parallel to the original
            ``t_sep``. Only read when ``"filling"`` is requested. May be
            ``None`` or shorter than the padded ``t_sep``; missing entries
            (including the padding groups) default to ``0``.
        model_args: Names of the entries to produce. Defaults to
            ``self.model_args``. A ``"_grouped"`` suffix selects the
            concatenated tensor instead of the per-group ones.
        label: Value stored under ``"label"`` when requested.

    Returns:
        dict mapping each requested name to its value, or ``None`` when the
        sample exceeds the configured limits.
    """
    if model_args is None:
        model_args = self.model_args

    if len(t_sep) > self.MAX_NUM_GROUPS:
        return None

    # Non-negative thanks to the guard above.
    pad_len = self.MAX_NUM_GROUPS - len(t_sep)
    t_sep.extend([torch.empty(0, 14)] * pad_len)

    t_grouped = [
        SVGTensor.from_data(torch.cat(t_sep, dim=0), PAD_VAL=self.PAD_VAL)
        .add_eos()
        .add_sos()
        .pad(seq_len=self.MAX_TOTAL_LEN + 2)
    ]

    t_normal = []
    for t in t_sep:
        s = SVGTensor.from_data(t, PAD_VAL=self.PAD_VAL)
        # Length is checked before SOS/EOS are added.
        if len(s.commands) > self.MAX_SEQ_LEN:
            return None
        t_normal.append(s.add_eos().add_sos().pad(seq_len=self.MAX_SEQ_LEN + 2))

    res = {}
    for arg in set(model_args):
        if "_grouped" in arg:
            arg_ = arg.split("_grouped")[0]
            t_list = t_grouped
        else:
            arg_ = arg
            t_list = t_normal

        if arg_ == "tensor":
            res[arg] = t_list
        elif arg_ == "commands":
            res[arg] = torch.stack([t.cmds() for t in t_list])
        elif arg_ == "args_rel":
            res[arg] = torch.stack([t.get_relative_args() for t in t_list])
        elif arg_ == "args":
            res[arg] = torch.stack([t.args() for t in t_list])

    if "filling" in model_args:
        # The raw tensors in ``t_sep`` (and the empty padding tensors) carry
        # no ``filling`` attribute, so the values come from ``fillings``,
        # zero-padded to the padded group count. An element that does expose
        # ``filling`` still takes precedence, which keeps the earlier
        # behavior for such inputs.
        fill_values = list(fillings) if fillings is not None else []
        fill_values += [0] * (len(t_sep) - len(fill_values))
        res["filling"] = torch.stack([
            torch.tensor(getattr(t, "filling", f))
            for t, f in zip(t_sep, fill_values)
        ]).unsqueeze(-1)

    if "label" in model_args:
        res["label"] = label

    return res
def visualize(self, model, output, train_vars, step, epoch, summary_writer, visualization_dir):
    """Log side-by-side reconstruction images for the stored training inputs.

    For every entry of ``train_vars.x_inputs_train`` the model greedily
    samples a prediction, which is drawn next to the ground truth taken from
    ``data["tensor_grouped"][0]`` and written to ``summary_writer`` under
    ``reconstructions_train/{i}``. Entries whose prediction cannot be turned
    into an SVG are skipped.

    Args:
        model: Network to sample from. Either a wrapper exposing the real
            network as ``.module`` or the network itself.
        output: Unused here; kept for interface compatibility.
        train_vars: Object exposing ``x_inputs_train``, an iterable of data
            dicts holding the keys in ``self.model_args`` plus
            ``"tensor_grouped"``.
        step: Global step passed to ``summary_writer.add_image``.
        epoch: Unused here; kept for interface compatibility.
        summary_writer: Writer exposing ``add_image(tag, image, step)``.
        visualization_dir: Unused here; kept for interface compatibility.
    """
    device = next(model.parameters()).device
    # Use the wrapped network when there is one, otherwise the model itself,
    # so unwrapped models work too.
    sampler = getattr(model, "module", model)

    # Reconstruction
    for i, data in enumerate(train_vars.x_inputs_train):
        model_args = batchify((data[key] for key in self.model_args), device)
        commands_y, args_y = sampler.greedy_sample(*model_args)
        tensor_pred = SVGTensor.from_cmd_args(commands_y[0].cpu(), args_y[0].cpu())

        try:
            svg_path_sample = (
                SVG.from_tensor(tensor_pred.data, viewbox=Bbox(256), allow_empty=True)
                .normalize()
                .split_paths()
                .set_color("random")
            )
        except Exception:
            # Best-effort: a prediction that cannot be rendered is skipped.
            # ``Exception`` (not a bare ``except``) lets KeyboardInterrupt
            # and SystemExit propagate.
            continue

        tensor_target = data["tensor_grouped"][0].copy().drop_sos().unpad()
        svg_path_gt = (
            SVG.from_tensor(tensor_target.data, viewbox=Bbox(256))
            .normalize()
            .split_paths()
            .set_color("random")
        )

        img = make_grid([svg_path_sample, svg_path_gt]).draw(
            do_display=False, return_png=True, fill=False, with_points=False
        )
        summary_writer.add_image(f"reconstructions_train/{i}", TF.to_tensor(img), step)
def decode(z):
    """Greedy-decode the latent ``z`` into an SVG.

    Calls ``greedy_sample(z=z)`` on ``model``, a name taken from the
    surrounding scope, and converts element ``0`` of the sampled commands and
    arguments into an ``SVG`` with a ``Bbox(256)`` viewbox.
    # presumably element 0 is the first item of the batch -- confirm
    """
    sampled_cmds, sampled_args = model.greedy_sample(z=z)

    # Move the selected sample to the CPU before building the tensor wrapper.
    first_cmds = sampled_cmds[0].cpu()
    first_args = sampled_args[0].cpu()
    predicted = SVGTensor.from_cmd_args(first_cmds, first_args)

    return SVG.from_tensor(predicted.data, viewbox=Bbox(256))