コード例 #1
0
    def get_data_history(self,
                         idx,
                         t_sep,
                         fillings,
                         model_args=None,
                         label=None):
        """Build the model-input dict for one sample without group padding.

        Unlike ``get_data``, the number of path groups is kept as given (no
        padding up to a maximum group count), and each individual path is
        padded to 20 commands plus SOS/EOS.

        Args:
            idx: Sample index; not used by this method.
            t_sep: Sequence of per-path data tensors. They are concatenated
                along dim 0 for the grouped tensor and wrapped one by one for
                the per-path tensors. Not modified.
            fillings: Per-path filling values, parallel to ``t_sep``.
            model_args: Keys to produce; defaults to ``self.model_args``. A
                key containing ``"_grouped"`` selects the single concatenated
                tensor instead of the per-path tensors.
            label: Value stored under ``"label"`` when requested.

        Returns:
            dict mapping each requested key to its value.
        """
        res = {}
        if model_args is None:
            model_args = self.model_args

        t_grouped = [
            SVGTensor.from_data(torch.cat(t_sep, dim=0),
                                PAD_VAL=self.PAD_VAL).add_eos().add_sos()
        ]
        # NOTE(review): 20 is a hard-coded per-path length with no length check
        # (get_data checks against self.MAX_SEQ_LEN). A longer path presumably
        # fails inside pad(); confirm whether MAX_SEQ_LEN was intended here.
        t_normal = [
            SVGTensor.from_data(t, PAD_VAL=self.PAD_VAL).add_eos().add_sos().pad(
                seq_len=20 + 2) for t in t_sep
        ]

        for arg in set(model_args):
            if "_grouped" in arg:
                arg_ = arg.split("_grouped")[0]
                t_list = t_grouped
            else:
                arg_ = arg
                t_list = t_normal

            if arg_ == "tensor":
                res[arg] = t_list
            elif arg_ == "commands":
                res[arg] = torch.stack([t.cmds() for t in t_list])
            elif arg_ == "args_rel":
                res[arg] = torch.stack([t.get_relative_args() for t in t_list])
            elif arg_ == "args":
                res[arg] = torch.stack([t.args() for t in t_list])

        if "filling" in model_args:
            # The entries of t_sep are plain tensors with no ``filling``
            # attribute, so the per-path values come from ``fillings``.
            res["filling"] = torch.stack(
                [torch.tensor(f) for f in fillings]).unsqueeze(-1)

        if "label" in model_args:
            res["label"] = label
        return res
コード例 #2
0
    def get_data(self, idx, t_sep, fillings, model_args=None, label=None):
        """Build the padded model-input dict for one sample.

        Pads the per-path tensors up to ``self.MAX_NUM_GROUPS`` groups. It then
        wraps each path, and the concatenation of all paths, in an
        ``SVGTensor`` with SOS/EOS added and padding applied. Finally it
        extracts the fields listed in ``model_args``.

        Args:
            idx: Sample index; not used by this method.
            t_sep: Sequence of per-path data tensors. Padding entries added
                here are ``torch.empty(0, 14)``. Not modified.
            fillings: Per-path filling values, parallel to ``t_sep``. Not
                modified.
            model_args: Keys to produce; defaults to ``self.model_args``. A
                key containing ``"_grouped"`` selects the single concatenated
                tensor instead of the per-path tensors.
            label: Value stored under ``"label"`` when requested.

        Returns:
            dict mapping each requested key to its value, or ``None`` in
            either of these cases:
            - the sample has more than ``self.MAX_NUM_GROUPS`` paths;
            - any path has more than ``self.MAX_SEQ_LEN`` commands.
        """
        res = {}
        if model_args is None:
            model_args = self.model_args
        if len(t_sep) > self.MAX_NUM_GROUPS:
            return None
        # Non-negative because of the guard above.
        pad_len = self.MAX_NUM_GROUPS - len(t_sep)

        # Build a padded copy instead of extending the caller's list in place.
        t_sep = list(t_sep) + [torch.empty(0, 14)] * pad_len

        t_grouped = [
            SVGTensor.from_data(torch.cat(t_sep, dim=0),
                                PAD_VAL=self.PAD_VAL).add_eos().add_sos().pad(
                                    seq_len=self.MAX_TOTAL_LEN + 2)
        ]
        t_normal = []
        for t in t_sep:
            s = SVGTensor.from_data(t, PAD_VAL=self.PAD_VAL)
            if len(s.commands) > self.MAX_SEQ_LEN:
                return None
            t_normal.append(
                s.add_eos().add_sos().pad(seq_len=self.MAX_SEQ_LEN + 2))

        for arg in set(model_args):
            if "_grouped" in arg:
                arg_ = arg.split("_grouped")[0]
                t_list = t_grouped
            else:
                arg_ = arg
                t_list = t_normal

            if arg_ == "tensor":
                res[arg] = t_list
            elif arg_ == "commands":
                res[arg] = torch.stack([t.cmds() for t in t_list])
            elif arg_ == "args_rel":
                res[arg] = torch.stack([t.get_relative_args() for t in t_list])
            elif arg_ == "args":
                res[arg] = torch.stack([t.args() for t in t_list])

        if "filling" in model_args:
            # The entries of t_sep are plain tensors with no ``filling``
            # attribute, so read the values from ``fillings``. Pad with 0 so
            # there is one value per (padded) group.
            # NOTE(review): 0 is assumed to mean "no fill"; confirm.
            padded_fillings = list(fillings) + [0] * pad_len
            res["filling"] = torch.stack(
                [torch.tensor(f) for f in padded_fillings]).unsqueeze(-1)

        if "label" in model_args:
            res["label"] = label
        return res
コード例 #3
0
    def visualize(self, model, output, train_vars, step, epoch, summary_writer,
                  visualization_dir):
        """Log reconstruction images for the training visualization inputs.

        For each sample in ``train_vars.x_inputs_train``, this:
        - greedily samples the model's reconstruction;
        - renders it next to the ground-truth SVG;
        - writes the grid image to ``summary_writer`` under the tag
          ``reconstructions_train/<i>`` at ``step``.

        Samples whose prediction cannot be converted to an SVG are skipped.
        ``output``, ``epoch`` and ``visualization_dir`` are accepted for
        interface compatibility but are not used here.
        """
        device = next(model.parameters()).device

        # Reconstruction
        for i, data in enumerate(train_vars.x_inputs_train):
            model_args = batchify((data[key] for key in self.model_args),
                                  device)
            # NOTE(review): ``model.module`` assumes a wrapped model
            # (e.g. DataParallel); confirm against the training setup.
            commands_y, args_y = model.module.greedy_sample(*model_args)
            tensor_pred = SVGTensor.from_cmd_args(commands_y[0].cpu(),
                                                  args_y[0].cpu())

            try:
                svg_path_sample = SVG.from_tensor(
                    tensor_pred.data, viewbox=Bbox(256),
                    allow_empty=True).normalize().split_paths().set_color(
                        "random")
            except Exception:
                # Presumably a degenerate prediction that does not form a
                # valid SVG; skip it. Catching Exception (not a bare except)
                # lets KeyboardInterrupt / SystemExit still propagate.
                continue

            tensor_target = data["tensor_grouped"][0].copy().drop_sos().unpad()
            svg_path_gt = SVG.from_tensor(
                tensor_target.data, viewbox=Bbox(
                    256)).normalize().split_paths().set_color("random")

            img = make_grid([svg_path_sample,
                             svg_path_gt]).draw(do_display=False,
                                                return_png=True,
                                                fill=False,
                                                with_points=False)
            summary_writer.add_image(f"reconstructions_train/{i}",
                                     TF.to_tensor(img), step)
コード例 #4
0
ファイル: interpolate.py プロジェクト: zbxzc35/deepsvg
def decode(z):
    """Greedily decode latent code ``z`` into an SVG.

    Uses the module-level ``model`` (defined outside this function). Only the
    first element of the sampled batch is converted, with a 256-unit viewbox.
    """
    cmds, arg_vals = model.greedy_sample(z=z)
    first_pred = SVGTensor.from_cmd_args(cmds[0].cpu(), arg_vals[0].cpu())
    return SVG.from_tensor(first_pred.data, viewbox=Bbox(256))