def update_data(self, binding_dict):
        """Apply ``binding_dict`` to the graph and rebuild ``self.data``.

        After ``self.graph.update_binding(binding_dict)``, node features are
        re-encoded with ``self.attr_encoder.get_embedding`` over
        ``self.graph.get_nodes()``, and edge indices / edge types are re-read
        from ``self.graph.get_edge_info()``. The new ``Data`` object keeps the
        previous ``self.data.y`` and receives a fresh all-zero int64 ``batch``
        vector with one entry per row of the node-feature tensor.

        Args:
            binding_dict: Forwarded unchanged to ``self.graph.update_binding``;
                its schema is defined by that method.
        """
        self.graph.update_binding(binding_dict)

        node_embeddings = self.attr_encoder.get_embedding(
            self.graph.get_nodes())
        raw_edge_index, raw_edge_types = self.graph.get_edge_info()

        # Edge types are turned into a tensor directly; they are not passed
        # through the attribute encoder.
        edge_type_tensor = torch.tensor(raw_edge_types)
        edge_index_tensor = torch.tensor(raw_edge_index)
        node_tensor = torch.tensor(node_embeddings)

        # One zero per node row.
        # NOTE(review): presumably this assigns every node to graph 0 -- confirm.
        zero_batch = torch.zeros(node_tensor.shape[0], dtype=torch.int64)

        # The label is carried over from the Data object being replaced.
        self.data = Data(
            x=node_tensor,
            y=self.data.y,
            edge_index=edge_index_tensor,
            edge_attr=edge_type_tensor,
        )
        self.data.batch = zero_batch
# Example #2
# 0
    def update_data(self, binding_dict):
        """Apply ``binding_dict`` to the graph and rebuild ``self.data``.

        After ``self.graph.update_binding(binding_dict)``, node features are
        re-encoded with ``self.attr_encoder.get_embedding`` over the ``name``
        of every node in ``self.graph.nodes``, and edge indices / edge types
        are re-read from ``self.graph.get_edge_info()``. The new ``Data``
        object keeps both ``y`` and ``batch`` from the previous ``self.data``.

        Args:
            binding_dict: Forwarded unchanged to ``self.graph.update_binding``;
                its schema is defined by that method.
        """
        self.graph.update_binding(binding_dict)

        node_names = [node.name for node in self.graph.nodes]
        node_embeddings = self.attr_encoder.get_embedding(node_names)
        raw_edge_index, raw_edge_types = self.graph.get_edge_info()

        # Edge types are stored as a float tensor; they are not passed through
        # the attribute encoder.
        edge_type_tensor = torch.tensor(raw_edge_types, dtype=torch.float)
        edge_index_tensor = torch.tensor(raw_edge_index)
        node_tensor = torch.tensor(node_embeddings)

        # Grab the old batch vector before self.data is replaced below.
        # NOTE(review): it is reused as-is, so this assumes the node count is
        # unchanged by update_binding -- confirm.
        previous_batch = self.data.batch

        # The label is carried over from the Data object being replaced.
        self.data = Data(
            x=node_tensor,
            y=self.data.y,
            edge_index=edge_index_tensor,
            edge_attr=edge_type_tensor,
        )
        self.data.batch = previous_batch
# Example #3
# 0

if __name__ == "__main__":
    # Manual smoke test: load the scene file, build a graph for the first
    # scene, wrap it in a Data point and construct an Env from it.

    # Locate the data directory one level above this file's directory.
    # The path is built with dirname/join rather than by gluing "../.." onto
    # the file name, which only worked because abspath() collapsed the bogus
    # "<file>.." component; the resolved directory is the same.
    data_dir = os.path.abspath(
        os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "data"))
    root = os.path.abspath(os.path.join(data_dir, "./processed_dataset"))

    config = get_config()
    attr_encoder = Encoder(config)

    # The scene file name comes from the command-line arguments.
    scenes_path = os.path.abspath(
        os.path.join(data_dir,
                     f"./processed_dataset/raw/{cmd_args.scene_file_name}"))
    with open(scenes_path, 'r') as scenes_file:
        scenes = json.load(scenes_file)

    # Construct a mini example from the first scene.
    target_id = 0
    graph = Graph(config, scenes[0], target_id)

    # Node features are the encoder embeddings of the node names; edge
    # attributes are the encoder embeddings of the edge types.
    # NOTE(review): x and edge_index are passed to Data exactly as returned
    # (only edge_attrs is converted to a tensor) -- confirm this is intended.
    x = attr_encoder.get_embedding([node.name for node in graph.nodes])
    edge_index, edge_types = graph.get_edge_info()
    edge_attrs = torch.tensor(attr_encoder.get_embedding(edge_types))
    data_point = Data(x=x,
                      edge_index=edge_index,
                      edge_attr=edge_attrs,
                      y=target_id)

    # Construct an env around the example.
    env = Env(data_point, graph, config, attr_encoder)
if __name__ == "__main__":
    # Manual smoke test: load the scene file, build a graph for the first
    # scene, wrap it in a Data point and construct an Env from it.

    # Locate the data directory two levels above this file's directory.
    # The path is built with dirname/join rather than by gluing "../.." onto
    # the file name, which only worked because abspath() collapsed the bogus
    # "<file>.." component; the resolved directory is the same.
    data_dir = os.path.abspath(
        os.path.join(os.path.dirname(os.path.abspath(__file__)),
                     "..", "..", "data"))
    root = os.path.abspath(os.path.join(data_dir, "./processed_dataset"))

    config = get_config()

    attr_encoder = Encoder(config)

    # The scene file name comes from the command-line arguments.
    scenes_path = os.path.abspath(
        os.path.join(data_dir,
                     f"./processed_dataset/raw/{cmd_args.scene_file_name}"))
    with open(scenes_path, 'r') as scenes_file:
        scenes = json.load(scenes_file)

    # Construct a mini example from the first scene.
    target_id = 0
    graph = Graph(config, scenes[0], target_id)

    # Node features are the encoder embeddings of graph.get_nodes(); edge
    # attributes are the embeddings of the edge types, each looked up under
    # an "edge_"-prefixed key.
    x = attr_encoder.get_embedding(graph.get_nodes())
    edge_index, edge_types = graph.get_edge_info()
    edge_attrs = attr_encoder.get_embedding(
        [f"edge_{tp}" for tp in edge_types])
    # NOTE(review): Data is called positionally; presumably the order is
    # (x, edge_index, edge_attr, y) -- confirm against the Data class in use.
    data_point = Data(torch.tensor(x), torch.tensor(edge_index),
                      torch.tensor(edge_attrs), graph.target_id)

    # Construct an env around the example.
    env = Env(data_point, graph, config, attr_encoder)
    # env.reset(graph)