예제 #1
0
def create_data(batch_size, num_elements_min_max):
    """Returns graphs containing the inputs and targets for classification.

    The raw data dicts come from `create_graph_dicts_tf` and the target is
    derived from the sort indices by `create_linked_list_target`; refer to
    those functions for more details.

    Args:
      batch_size: batch size for the `input_graphs`.
      num_elements_min_max: a 2-`tuple` of `int`s which define the [lower, upper)
        range of the number of elements per list.

    Returns:
      inputs: a `graphs.GraphsTuple` which contains the input list as a graph.
      targets: a `graphs.GraphsTuple` which contains the target as a graph;
        its node and edge features are the linked-list target values followed
        by their complement (`1.0 - value`) along axis 1.
      sort_indices: a `graphs.GraphsTuple` which contains the sort indices of
        the list elements as a graph.
      ranks: a `graphs.GraphsTuple` which contains the ranks of the list
        elements as a graph.
    """
    raw_inputs, raw_sort_indices, raw_ranks = create_graph_dicts_tf(
        batch_size, num_elements_min_max)

    # Convert all three dict collections first, then fully connect each one.
    graph_tuples = [
        utils_tf.data_dicts_to_graphs_tuple(raw)
        for raw in (raw_inputs, raw_sort_indices, raw_ranks)
    ]
    inputs, sort_indices, ranks = [
        utils_tf.fully_connect_graph_dynamic(graph) for graph in graph_tuples
    ]

    # Append the complement of every target feature as a second block of
    # columns, for both nodes and edges.
    linked_list = create_linked_list_target(batch_size, sort_indices)
    targets = linked_list._replace(
        nodes=tf.concat((linked_list.nodes, 1.0 - linked_list.nodes), axis=1),
        edges=tf.concat((linked_list.edges, 1.0 - linked_list.edges), axis=1))

    # Example shapes noted by the original author (not checked here):
    # input node[7,1]; target edge[49,2], node[7,2].
    return inputs, targets, sort_indices, ranks
예제 #2
0
    def make_graph(self, event, debug=False):
        """
        Convert the event into an (input, target) pair of graphs_tuples.

        The input graph takes its nodes from ``event['x']`` and its
        connectivity from ``event[self.edge_name]`` (row 0: senders, row 1:
        receivers); its edge features are zeros and its global is the node
        count. The target graph is a one-node, one-edge placeholder whose
        ``edges`` field carries ``event[self.truth_name]`` cast to float32.

        ``debug`` is accepted but not used. Returns a single-element list
        ``[(input_graph, target_graph)]``.
        """
        connectivity = event[self.edge_name]
        node_features = event['x']
        num_nodes = node_features.shape[0]
        num_edges = connectivity.shape[1]
        truth = event[self.truth_name].numpy().astype(np.float32)

        input_datadict = dict(
            n_node=num_nodes,
            n_edge=num_edges,
            nodes=node_features,
            edges=np.zeros((num_edges, 1), dtype=np.float32),
            senders=connectivity[0, :],
            receivers=connectivity[1, :],
            globals=np.array([num_nodes], dtype=np.float32),
        )

        # The target graph is a minimal container for the per-edge truth.
        placeholder_edges = 1
        target_datadict = dict(
            n_node=1,
            n_edge=placeholder_edges,
            nodes=np.zeros((1, 1), dtype=np.float32),
            edges=truth,
            senders=np.zeros((placeholder_edges, ), dtype=np.int32),
            receivers=np.zeros((placeholder_edges, ), dtype=np.int32),
            globals=np.zeros((1, ), dtype=np.float32),
        )

        graph_pair = tuple(
            utils_tf.data_dicts_to_graphs_tuple([datadict])
            for datadict in (input_datadict, target_datadict))
        return [graph_pair]
예제 #3
0
    def make_graph(self, event, debug):
        """Build a reduced (input, target) graph pair from GNN-selected edges.

        The model is applied to the input graph; edges whose score in the last
        output graph exceeds ``self.edge_cut`` are kept, the nodes touched by
        those edges are re-indexed from 0, and the surviving sub-graph is
        packed into two graphs_tuples that differ only in their globals.

        Args:
            event: a 2-sequence whose first item is the input graphs_tuple;
                the second item is ignored.
            debug: accepted for interface compatibility; not used.

        Returns:
            A single-element list ``[(input_graph, target_graph)]``. The input
            global is the number of kept nodes, the target global is
            ``float(self.is_signal)``.

        Raises:
            KeyError: if a kept edge of ``inputs_tr`` references a node that is
                not an endpoint of any kept edge of the output graph.
        """
        inputs_tr, _ = event

        # Apply the GNN model; only the output of the last step is used.
        outputs_tr = self.model(inputs_tr, self.num_mp)
        output_graph = outputs_tr[-1]

        # Keep the edges whose score exceeds ``self.edge_cut``. Flatten instead
        # of squeezing so that a graph with a single edge still yields a 1-D
        # mask rather than a 0-d scalar.
        edge_predict = output_graph.edges.numpy().reshape(-1)
        edge_passed = edge_predict > self.edge_cut
        nodes_sel = np.unique(np.concatenate(
            [output_graph.receivers.numpy()[edge_passed],
             output_graph.senders.numpy()[edge_passed]], axis=0))

        n_nodes = nodes_sel.shape[0]
        # Count in NumPy; the builtin sum() would iterate the mask in Python.
        n_edges = int(np.count_nonzero(edge_passed))
        nodes = inputs_tr.nodes.numpy()[nodes_sel]
        edges = inputs_tr.edges.numpy()[edge_passed]

        # Map original node indices to their position in the reduced node list.
        node_dicts = {val: idx for idx, val in enumerate(nodes_sel)}
        senders = np.array(
            [node_dicts[x] for x in inputs_tr.senders.numpy()[edge_passed]])
        receivers = np.array(
            [node_dicts[x] for x in inputs_tr.receivers.numpy()[edge_passed]])

        input_datadict = {
            "n_node": n_nodes,
            "n_edge": n_edges,
            "nodes": nodes,
            "edges": edges,
            "senders": senders,
            "receivers": receivers,
            "globals": np.array([n_nodes], dtype=np.float32)
        }

        target_datadict = {
            "n_node": n_nodes,
            "n_edge": n_edges,
            "nodes": nodes,
            "edges": edges,
            "senders": senders,
            "receivers": receivers,
            "globals": np.array([float(self.is_signal)], dtype=np.float32)
        }
        input_graph = utils_tf.data_dicts_to_graphs_tuple([input_datadict])
        target_graph = utils_tf.data_dicts_to_graphs_tuple([target_datadict])
        return [(input_graph, target_graph)]
예제 #4
0
파일: graph.py 프로젝트: rkunnawa/root_gnn
    def make_subgraph(mask):
        """Build the (input, target) graph pair for the hits selected by `mask`.

        Only doublets whose two hits are both selected are kept as edges. The
        target graph shares the connectivity, has all-zero node features and
        uses the doublets' `solution` column as edge features. Both graphs are
        padded when `with_pad` is set. Returns a single-element list
        `[(input_graph, target_graph)]`.
        """
        selected_hits = hits[mask]
        hit_id = selected_hits.hit_id.values
        both_ends_selected = (segments.hit_id_in.isin(hit_id)
                              & segments.hit_id_out.isin(hit_id))
        sub_doublets = segments[both_ends_selected]

        # TODO: include all edges, uncomment following lines. <>
        # sub_doublets = segments[segments.hit_id_in.isin(hit_id)]
        # # extend the hits to include the hits used in the sub-doublets.
        # hit_id = hits[mask | hits.hit_id.isin(sub_doublets.hit_id_out.values)].hit_id.values

        n_nodes = hit_id.shape[0]
        n_edges = sub_doublets.shape[0]
        nodes = selected_hits[node_features].values.astype(f_dtype)
        if edge_features is not None:
            edges = sub_doublets[edge_features].values.astype(f_dtype)
        else:
            # No edge features requested: one zero per edge.
            edges = np.zeros((n_edges, 1), dtype=np.float32)

        # Translate hit ids into positions within the selected node list.
        position_of = {hid: pos for pos, hid in enumerate(hit_id)}
        senders = [position_of[hid] for hid in sub_doublets.hit_id_in.values]
        receivers = [position_of[hid] for hid in sub_doublets.hit_id_out.values]
        if verbose:
            print("\t{} nodes and {} edges".format(n_nodes, n_edges))
        senders = np.array(senders)
        receivers = np.array(receivers)

        shared = {
            "n_node": n_nodes,
            'n_edge': n_edges,
            'senders': senders,
            'receivers': receivers,
            'globals': zeros,
        }
        input_datadict = dict(shared, nodes=nodes, edges=edges)
        target_datadict = dict(
            shared,
            nodes=np.zeros((n_nodes, n_node_features), dtype=f_dtype),
            edges=np.expand_dims(
                sub_doublets.solution.values.astype(f_dtype), axis=1),
        )

        input_graph = utils_tf.data_dicts_to_graphs_tuple([input_datadict])
        target_graph = utils_tf.data_dicts_to_graphs_tuple([target_datadict])

        if with_pad:
            input_graph = padding(input_graph, N_MAX_NODES, N_MAX_EDGES)
            target_graph = padding(target_graph, N_MAX_NODES, N_MAX_EDGES)

        return [(input_graph, target_graph)]
예제 #5
0
def make_graph(event, debug=False):
    """Build a fully connected jet graph and its edge-truth graph for one event.

    Every jet is a node with the features (pt, eta, phi, e). Every unordered
    jet pair ``(i, j)`` with ``i < j`` is an edge. An edge is labelled 1.0 when
    both of its jets belong to ``event.reco_triplet_0`` or both belong to
    ``event.reco_triplet_1``, otherwise 0.0.

    Args:
        event: object exposing ``jet_pt``, ``jet_eta``, ``jet_phi``, ``jet_e``
            (one value per jet) and ``reco_triplet_0`` / ``reco_triplet_1``
            (jet indices).
        debug: when true, print intermediate values.

    Returns:
        A single-element list ``[(input_graph, target_graph)]``.
    """
    n_nodes = len(event.jet_pt)
    # One row per jet, one column per variable. Concatenating the 1-D arrays
    # end to end and reshaping to (n_nodes, 4) would mix different jets'
    # values into the same row.
    nodes = np.column_stack(
        (event.jet_pt, event.jet_eta, event.jet_phi, event.jet_e))
    if debug:
        print(np.array(event.jet_pt).shape)
        print(n_nodes)
        print(nodes)

    # edges: all unordered pairs (i, j) with i < j
    all_edges = list(itertools.combinations(range(n_nodes), 2))
    senders = np.array([x[0] for x in all_edges])
    receivers = np.array([x[1] for x in all_edges])
    n_edges = len(all_edges)
    edges = np.zeros((n_edges, 1), dtype=np.float32)

    # Sort each truth pair so that it matches the (i < j) orientation of
    # `all_edges` even when a triplet lists its jets in another order.
    true_edges = {
        tuple(sorted(pair))
        for triplet in (event.reco_triplet_0, event.reco_triplet_1)
        for pair in itertools.combinations(triplet, 2)
    }
    truth_labels = [int(x in true_edges) for x in all_edges]
    if debug:
        print(all_edges)
        print(event.reco_triplet_0)
        print(event.reco_triplet_1)
        print(truth_labels)
    truth_labels = np.array(truth_labels, dtype=np.float32)
    zeros = np.array([0.0], dtype=np.float32)

    input_datadict = {
        "n_node": n_nodes,
        "n_edge": n_edges,
        "nodes": nodes,
        "edges": edges,
        "senders": senders,
        "receivers": receivers,
        "globals": np.array([n_nodes], dtype=np.float32)
    }
    target_datadict = {
        "n_node": n_nodes,
        "n_edge": n_edges,
        "nodes": zeros,
        "edges": truth_labels,
        "senders": senders,
        "receivers": receivers,
        "globals": zeros
    }
    input_graph = utils_tf.data_dicts_to_graphs_tuple([input_datadict])
    target_graph = utils_tf.data_dicts_to_graphs_tuple([target_datadict])
    return [(input_graph, target_graph)]
예제 #6
0
파일: toptagger.py 프로젝트: xju2/root_gnn
def make_graph(event, debug: Optional[bool] = False):
    """Build a fully connected graph pair from up to 200 objects of one event.

    Object ``i`` is skipped when ``event['E_i']`` is below 0.1; otherwise its
    values for the module-level ``features`` are scaled by ``scale`` and become
    one node. Input and target graphs share nodes, zero-valued edges and
    connectivity; the input global is the node count and the target global is
    ``event[solution]``. ``debug`` is not used.

    Returns a single-element list ``[(input_graph, target_graph)]``.
    """
    n_max_nodes = 200
    kept_rows = []
    for slot in range(n_max_nodes):
        if event['E_{}'.format(slot)] < 0.1:
            continue
        columns = ['{}_{}'.format(name, slot) for name in features]
        kept_rows.append(event[columns].values * scale)
    nodes = np.array(kept_rows, dtype=np.float32)
    n_nodes = len(kept_rows)

    # edges 1) fully connected, 2) objects nearby in eta/phi are connected
    # TODO: implement 2). <xju>
    pairs = list(itertools.combinations(range(n_nodes), 2))
    n_edges = len(pairs)
    senders = np.array([first for first, _ in pairs])
    receivers = np.array([second for _, second in pairs])
    edges = np.zeros((n_edges, 1), dtype=np.float32)

    # Everything except the globals is common to both graphs.
    common = {
        "n_node": n_nodes,
        "n_edge": n_edges,
        "nodes": nodes,
        "edges": edges,
        "senders": senders,
        "receivers": receivers,
    }
    input_datadict = dict(
        common, globals=np.array([n_nodes], dtype=np.float32))
    target_datadict = dict(
        common, globals=np.array([event[solution]], dtype=np.float32))

    input_graph = utils_tf.data_dicts_to_graphs_tuple([input_datadict])
    target_graph = utils_tf.data_dicts_to_graphs_tuple([target_datadict])
    return [(input_graph, target_graph)]
예제 #7
0
 def test_unpack_graphs_tuple(self):
     """Checks that nodes and edges are unpacked per graph and zero-padded."""
     four_node_graph = {
         graphs.N_NODE: 4,
         graphs.N_EDGE: 3,
         graphs.NODES: [[0, 0.1], [1, 1.1], [2, 2.1], [3, 3.1]],
         graphs.EDGES: [[1, 1.2], [2, 2.2], [3, 3.2]],
         graphs.SENDERS: [0, 0, 0],
         graphs.RECEIVERS: [1, 2, 3],
     }
     three_node_graph = {
         graphs.N_NODE: 3,
         graphs.N_EDGE: 2,
         graphs.NODES: [[0, 0.1], [1, 1.1], [2, 2.1]],
         graphs.EDGES: [[1, 1.2], [2, 2.2]],
         graphs.SENDERS: [0, 1],
         graphs.RECEIVERS: [1, 2],
     }
     batched = utils_tf.data_dicts_to_graphs_tuple(
         [four_node_graph, three_node_graph])

     nodes, edges = graph_networks.GraphNet._unpack_graphs_tuple(
         graphs_tuple=batched, max_n_node=5, max_n_edge=6)

     # Each graph is padded with zero rows up to max_n_node / max_n_edge.
     expected_nodes = [
         [[0, 0.1], [1, 1.1], [2, 2.1], [3, 3.1], [0, 0]],
         [[0, 0.1], [1, 1.1], [2, 2.1], [0, 0], [0, 0]],
     ]
     expected_edges = [
         [[1, 1.2], [2, 2.2], [3, 3.2], [0, 0], [0, 0], [0, 0]],
         [[1, 1.2], [2, 2.2], [0, 0], [0, 0], [0, 0], [0, 0]],
     ]
     self.assertAllClose(nodes, expected_nodes)
     self.assertAllClose(edges, expected_edges)
예제 #8
0
 def test_output_values(self, broadcaster, expected):
   """Checks the broadcaster output for two batched small graphs."""
   small_graphs = [SMALL_GRAPH_1, SMALL_GRAPH_2]
   batched_graph = utils_tf.data_dicts_to_graphs_tuple(small_graphs)
   actual = broadcaster(batched_graph)
   wanted = np.array(expected, dtype=np.float32)
   self.assertNDArrayNear(wanted, actual, err=1e-4)
예제 #9
0
 def test_field_must_not_be_none(self, none_fields):
   """Checks that applying the model raises when the given fields are None."""

   def _to_none(_):
     return None

   complete_graph = utils_tf.data_dicts_to_graphs_tuple([SMALL_GRAPH_1])
   incomplete_graph = complete_graph.map(_to_none, none_fields)
   model = self._get_model()
   with self.assertRaises(ValueError):
     model(incomplete_graph)
예제 #10
0
 def test_field_must_not_be_none(self, none_field):
   """Checks that applying the model raises when the given field is None."""
   complete_graph = utils_tf.data_dicts_to_graphs_tuple([SMALL_GRAPH_1])
   overrides = {none_field: None}
   incomplete_graph = complete_graph.replace(**overrides)
   model = self._get_model()
   with self.assertRaises(ValueError):
     model(incomplete_graph)
예제 #11
0
 def _get_input_graph(self, none_fields=None):
   """Returns the four small graphs batched, with `none_fields` set to None."""
   fields_to_clear = [] if none_fields is None else none_fields
   small_graphs = [SMALL_GRAPH_1, SMALL_GRAPH_2, SMALL_GRAPH_3, SMALL_GRAPH_4]
   batched_graph = utils_tf.data_dicts_to_graphs_tuple(small_graphs)
   return batched_graph.map(lambda _: None, fields_to_clear)
예제 #12
0
    def test_gns_dataset(self):
        """Compares dataset batches with batches built straight from the dicts.

        The body only runs on win32; on every other platform the test passes
        trivially. For each batch drawn from the zipped (graph, label) dataset
        the matching slice of the graph dict list is converted directly, and
        the two graph batches must be equal. On a mismatch the globals, nodes
        and edges of both batches are printed before the assertion fails.
        """
        if sys.platform != "win32":
            self.assertTrue(True)
            return

        case_name, agent_name = "rte_case5_example", "agent-mip"
        experience_dir = os.path.join("../experience", "data")
        case, collector = load_experience(case_name,
                                          agent_name,
                                          experience_dir,
                                          env_dc=True)
        obses, actions, _, dones = collector.aggregate_data()

        n_batch = 16
        graphs_dict_list = obses_to_lgraphs(obses,
                                            dones,
                                            case,
                                            max_length=10 * n_batch + 1,
                                            n_window=2)
        cgraphs = lgraphs_to_cgraphs(graphs_dict_list)
        labels = is_do_nothing_action(actions, case.env)

        graph_dims = get_graph_feature_dimensions(cgraphs=cgraphs)
        graph_dataset = tf_batched_graph_dataset(cgraphs,
                                                 n_batch=n_batch,
                                                 **graph_dims)
        label_dataset = tf.data.Dataset.from_tensor_slices(labels)
        label_dataset = label_dataset.batch(n_batch)
        dataset = tf.data.Dataset.zip((graph_dataset, label_dataset)).repeat(1)

        compared_fields = ("globals", "nodes", "edges")
        for batch_idx, (graph_batch, _) in enumerate(dataset):
            first = n_batch * batch_idx
            last = n_batch * (batch_idx + 1)
            reference_batch = utils_tf.data_dicts_to_graphs_tuple(
                graphs_dict_list[first:last])

            check = tf.squeeze(equal_graphs(graph_batch,
                                            reference_batch)).numpy()

            pprint("Batch:", batch_idx, check)

            if not check:
                for field in compared_fields:
                    print_matrix(getattr(graph_batch, field))
                    print_matrix(getattr(reference_batch, field))

            self.assertTrue(check)
예제 #13
0
 def test_raises(self, index, error_type, message, use_constant, use_slice):
   """Checks that `get_graph` rejects the index with the expected error."""
   lookup = tf.constant(index) if use_constant else index
   lookup = slice(lookup) if use_slice else lookup
   batched = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
   with self.assertRaisesRegexp(error_type, message):
     utils_tf.get_graph(batched, lookup)
예제 #14
0
 def test_fill_global_state(self, global_size):
   """Checks the static shape of globals added by set_zero_global_features."""
   for graph_dict in self.graphs_dicts_in:
     graph_dict.pop("globals")
   without_globals = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
   n_graphs = self.reference_graph.n_edge.shape[0]
   filled = utils_tf.set_zero_global_features(without_globals, global_size)
   static_shape = filled.globals.get_shape().as_list()
   self.assertAllEqual((n_graphs, global_size), static_shape)
예제 #15
0
 def test_fill_edge_state(self, edge_size):
   """Checks the static shape of edges added by set_zero_edge_features."""
   for graph_dict in self.graphs_dicts_in:
     graph_dict.pop("edges")
   without_edges = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
   n_edges = np.sum(self.reference_graph.n_edge)
   filled = utils_tf.set_zero_edge_features(without_edges, edge_size)
   static_shape = filled.edges.get_shape().as_list()
   self.assertAllEqual((n_edges, edge_size), static_shape)
예제 #16
0
 def setUp(self):
   """Creates `self._graph`: one graph made of uniformly random tensors."""
   super(RunGraphWithNoneInSessionTest, self).setUp()
   # The tensors are created in the same order as the keys are listed.
   random_graph_dict = {
       "senders": tf.random_uniform([10], maxval=10, dtype=tf.int32),
       "receivers": tf.random_uniform([10], maxval=10, dtype=tf.int32),
       "nodes": tf.random_uniform([5, 7]),
       "edges": tf.random_uniform([10, 6]),
       "globals": tf.random_uniform([1, 8]),
   }
   self._graph = utils_tf.data_dicts_to_graphs_tuple([random_graph_dict])
예제 #17
0
 def test_fill_edge_state_with_missing_fields_raises(self):
   """Edge field cannot be filled if receivers or senders are missing."""
   for graph_dict in self.graphs_dicts_in:
     for key in ("receivers", "senders", "edges"):
       graph_dict.pop(key)
   edgeless = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
   with self.assertRaisesRegexp(ValueError, "receivers"):
     utils_tf.set_zero_edge_features(edgeless, edge_size=1)
예제 #18
0
def get_empty_graph(nodeshape, edgeshape, glblshape, senders, receivers):
    """Return a one-graph GraphsTuple whose features are all zeros.

    Args:
        nodeshape: shape of the zero-filled node feature array.
        edgeshape: shape of the zero-filled edge feature array.
        glblshape: shape of the zero-filled global feature array.
        senders: sender indices, passed through unchanged.
        receivers: receiver indices, passed through unchanged.

    Returns:
        The result of `utils_tf.data_dicts_to_graphs_tuple` for that single
        data dict.
    """
    # `np.float` was only an alias of the builtin `float` (float64) and was
    # removed in NumPy 1.24; spell the dtype explicitly so the function keeps
    # working and keeps producing float64 arrays.
    dic = {
        "globals": np.zeros(glblshape, dtype=np.float64),
        "nodes": np.zeros(nodeshape, dtype=np.float64),
        "edges": np.zeros(edgeshape, dtype=np.float64),
        "senders": senders,
        "receivers": receivers
    }
    return utils_tf.data_dicts_to_graphs_tuple([dic])
예제 #19
0
 def setUp(self):
   """Creates `self._graph`: one graph made of uniformly random tensors."""
   super(TestNestToNumpy, self).setUp()
   # The tensors are created in the same order as the keys are listed.
   random_graph_dict = {
       "senders": tf.random.uniform([10], maxval=10, dtype=tf.int32),
       "receivers": tf.random.uniform([10], maxval=10, dtype=tf.int32),
       "nodes": tf.random.uniform([5, 7]),
       "edges": tf.random.uniform([10, 6]),
       "globals": tf.random.uniform([1, 8]),
   }
   self._graph = utils_tf.data_dicts_to_graphs_tuple([random_graph_dict])
예제 #20
0
 def setUp(self):
   """Creates `self._graph`: one graph with constant (zeros/ones) tensors."""
   super(StopGradientsGraphTest, self).setUp()
   constant_graph_dict = {
       "senders": tf.zeros([10], dtype=tf.int32),
       "receivers": tf.zeros([10], dtype=tf.int32),
       "nodes": tf.ones([5, 7]),
       "edges": tf.zeros([10, 6]),
       "globals": tf.zeros([1, 8]),
   }
   self._graph = utils_tf.data_dicts_to_graphs_tuple([constant_graph_dict])
예제 #21
0
 def test_fill_node_state(self, node_size):
   """Checks the static shape of nodes added by set_zero_node_features."""
   for graph_dict in self.graphs_dicts_in:
     # Keep the node count, since the node features themselves are removed.
     graph_dict["n_node"] = graph_dict.pop("nodes").shape[0]
   without_nodes = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
   n_nodes = np.sum(self.reference_graph.n_node)
   filled = utils_tf.set_zero_node_features(without_nodes, node_size)
   static_shape = filled.nodes.get_shape().as_list()
   self.assertAllEqual((n_nodes, node_size), static_shape)
    def load_batch(self, mode, repeat=None):
        """Return batch loaded from this dataset.

        Reads the `<mode>-*.tfrecords` files under `<data_dir>/<mode>`, parses
        every record with the readers of `self.features`, and pulls
        `opts.batch_size` samples from a one-shot iterator.

        Args:
            mode: dataset split; must be a key of `self.dataset_params.sizes`.
                File order and samples are shuffled when `opts.shuffle_data`
                is set and `mode` is not 'test'.
            repeat: forwarded to `dataset.repeat`.

        Returns:
            A dict with one stacked entry per non-graph feature key, plus
            'graph' holding the graphs tuple built from the graph keys of all
            samples in the batch.

        Raises:
            AssertionError: if `mode` is not a supported split.
        """
        params = self.dataset_params
        opts = self.opts
        assert mode in params.sizes, "Mode {} not supported".format(mode)
        other_keys = list(set(self.features.keys()) - set(GRAPH_KEYS))
        # Parsed items are ordered graph keys first, then the other keys.
        item_keys = GRAPH_KEYS + other_keys
        n_graph_keys = len(GRAPH_KEYS)
        shuffle = opts.shuffle_data and mode != 'test'

        data_source_name = mode + '-*.tfrecords'
        data_sources = glob.glob(
            os.path.join(self.data_dir, mode, data_source_name))
        if shuffle:
            np.random.shuffle(data_sources)  # Added to help the shuffle

        # Build dataset provider
        keys_to_features = {}
        for feature in self.features.values():
            keys_to_features.update(feature.get_feature_read())

        def parser_op(record):
            example = tf.parse_single_example(record, keys_to_features)
            return [
                self.features[k].tensors_to_item(example) for k in item_keys
            ]

        dataset = tf.data.TFRecordDataset(data_sources)
        dataset = dataset.map(parser_op)
        dataset = dataset.repeat(repeat)
        if shuffle:
            dataset = dataset.shuffle(buffer_size=5 * opts.batch_size)

        iterator = dataset.make_one_shot_iterator()
        batch_graphs = []
        batch_other = []
        for _ in range(opts.batch_size):
            sample_ = iterator.get_next()
            # Extracting other keys outside of the graph
            batch_other.append({
                k: sample_[i + n_graph_keys]
                for i, k in enumerate(other_keys)
            })
            # Constructing graph using relevant graph keys
            batch_graphs.append(
                {k: sample_[i] for i, k in enumerate(GRAPH_KEYS)})

        # Constructing output sample using known order of the keys
        sample = {
            k: self.features[k].stack([other[k] for other in batch_other])
            for k in other_keys
        }
        sample['graph'] = utils_tf.data_dicts_to_graphs_tuple(batch_graphs)
        return sample
예제 #23
0
 def test_output_values_larger_rank(self, broadcaster, expected):
     """Checks the broadcaster output when every field has an extra axis."""

     def _split_features(value):
         # Keep the leading dimension and split the rest into [2, -1].
         leading = value.get_shape().as_list()[0]
         return tf.reshape(value, [leading, 2, -1])

     batched_graph = utils_tf.data_dicts_to_graphs_tuple(
         [SMALL_GRAPH_1, SMALL_GRAPH_2])
     reshaped_graph = batched_graph.map(_split_features)
     actual = broadcaster(reshaped_graph)
     wanted = np.reshape(np.array(expected, dtype=np.float32),
                         [len(expected), 2, -1])
     self.assertNDArrayNear(wanted, actual, err=1e-4)
예제 #24
0
 def test_fill_edge_state_dynamic(self, edge_size):
   """Checks that edges filled in after n_edge was replaced are all zeros."""
   for graph_dict in self.graphs_dicts_in:
     graph_dict.pop("edges")
   graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
   # Rebuild n_edge as a fresh constant carrying the same shape.
   original_n_edge = graphs_tuple.n_edge
   rebuilt_n_edge = tf.constant(
       original_n_edge, shape=original_n_edge.get_shape())
   graphs_tuple = graphs_tuple._replace(n_edge=rebuilt_n_edge)
   n_edges = np.sum(self.reference_graph.n_edge)
   filled = utils_tf.set_zero_edge_features(graphs_tuple, edge_size)
   wanted = np.zeros((n_edges, edge_size))
   self.assertNDArrayNear(wanted, filled.edges, err=1e-4)
예제 #25
0
 def test_fill_state_user_specified_types(self, dtype):
   """Tests that the features are created with the requested type."""
   for graph_dict in self.graphs_dicts_in:
     for key in ("nodes", "globals", "edges"):
       graph_dict.pop(key)
   filled = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
   filled = utils_tf.set_zero_edge_features(filled, 1, dtype)
   filled = utils_tf.set_zero_node_features(filled, 1, dtype)
   filled = utils_tf.set_zero_global_features(filled, 1, dtype)
   for features in (filled.edges, filled.nodes, filled.globals):
     self.assertEqual(dtype, features.dtype)
def get_batched_graphs(train_set):
    """
    description: lazily converts the inputs of each batch to fully connected
        graphs; the targets are passed through unchanged
    :param train_set: training set containing tuples (batch_input , batch_target)
    :return: generator yielding (input graphs, batch_target) for every batch
    """
    for batch_input, batch_target in train_set:
        graph_dicts = create_graph_dicts(batch_input)
        input_graphs = utils_tf.data_dicts_to_graphs_tuple(graph_dicts)
        yield utils_tf.fully_connect_graph_dynamic(input_graphs), batch_target
예제 #27
0
 def test_fill_state_default_types(self):
   """Tests that the features are created with the correct default type."""
   for graph_dict in self.graphs_dicts_in:
     for key in ("nodes", "globals", "edges"):
       graph_dict.pop(key)
   filled = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
   filled = utils_tf.set_zero_edge_features(filled, edge_size=1)
   filled = utils_tf.set_zero_node_features(filled, node_size=1)
   filled = utils_tf.set_zero_global_features(filled, global_size=1)
   # Without an explicit dtype every filled field is expected to be float32.
   for features in (filled.edges, filled.nodes, filled.globals):
     self.assertEqual(tf.float32, features.dtype)
예제 #28
0
 def test_fill_node_state_dynamic(self, node_size):
   """Checks that nodes filled in after n_node was replaced are all zeros."""
   for graph_dict in self.graphs_dicts_in:
     # Keep the node count, since the node features themselves are removed.
     graph_dict["n_node"] = graph_dict.pop("nodes").shape[0]
   graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
   # Rebuild n_node as a fresh constant carrying the same shape.
   original_n_node = graphs_tuple.n_node
   rebuilt_n_node = tf.constant(
       original_n_node, shape=original_n_node.get_shape())
   graphs_tuple = graphs_tuple._replace(n_node=rebuilt_n_node)
   n_nodes = np.sum(self.reference_graph.n_node)
   filled = utils_tf.set_zero_node_features(graphs_tuple, node_size)
   wanted = np.zeros((n_nodes, node_size))
   self.assertNDArrayNear(wanted, filled.nodes.numpy(), err=1e-4)
예제 #29
0
 def test_fill_global_state_dynamic(self, global_size):
   """Checks that globals filled in with a shape-less n_node are all zeros."""
   for graph_dict in self.graphs_dicts_in:
     graph_dict.pop("globals")
   graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
   # Hide the static length of n_node behind a placeholder of shape [None].
   shapeless_n_node = tf.placeholder_with_default(
       graphs_tuple.n_node, shape=[None])
   graphs_tuple = graphs_tuple._replace(n_node=shapeless_n_node)
   n_graphs = self.reference_graph.n_edge.shape[0]
   filled = utils_tf.set_zero_global_features(graphs_tuple, global_size)
   with self.test_session() as sess:
     actual_globals = sess.run(filled.globals)
   wanted = np.zeros((n_graphs, global_size))
   self.assertNDArrayNear(wanted, actual_globals, err=1e-4)
예제 #30
0
def load_graphs(file_path, train_ratio, keep_features, existence_as_vector=True):
    """
    The function extracts the graphs from the given file, which holds one JSON
    document per line. Blank lines are skipped. Every parsed line is passed
    through process_line and then handed to is_valid_graph, whose return value
    is not used here.
    :param file_path: The path to the graph json file
    :param train_ratio: How much of the data should we use for training. The other part is used for testing.
                        The first int(train_ratio * number_of_graphs) graphs, in file order, form the training set.
    :param keep_features: Whether to keep all features of the graph. It is advised to do so in case of input graphs.
    :param existence_as_vector: Whether to represent existence feature as a vector with the length of 2 in an individual
                                feature vector. It should be False when processing the input graphs.
    :return: Training and testing GraphTuples
    """
    graph_dicts = []
    with open(file_path) as json_file:
        # Iterate the file itself so that only the real end of file ends the
        # loop; a blank line in the middle must not discard later records.
        for raw_line in json_file:
            line = raw_line.strip()
            if not line:
                continue
            json_dict = process_line(
                json.loads(line), keep_features, existence_as_vector)
            is_valid_graph(json_dict)
            graph_dicts.append(json_dict)

    split_at = int(train_ratio * len(graph_dicts))
    graphs_tuple_train = utils_tf.data_dicts_to_graphs_tuple(
        graph_dicts[:split_at])
    graphs_tuple_test = utils_tf.data_dicts_to_graphs_tuple(
        graph_dicts[split_at:])
    return graphs_tuple_train, graphs_tuple_test