Example #1
0
    def build_graph(self, mg, src, dst, ntid, etid, ntypes, etypes):
        """Assemble a DGL heterograph from flat edge lists and type ids.

        A single homogeneous graph is first built from ``src``/``dst``, with
        the per-node and per-edge type ids stored under ``dgl.NTYPE`` and
        ``dgl.ETYPE``. It is then split into a heterograph using ``ntypes``,
        the renamed ``etypes`` and a renamed copy of the metagraph ``mg``.
        Edge type names are passed through ``RENAME_DICT`` (names without an
        entry are kept as-is) so that they can be used as keys to
        ``nn.ModuleDict``.

        Parameters
        ----------
        mg: MultiDiGraph
            Metagraph whose keyed edges ``(src_type, dst_type, etype)`` describe
            the canonical edge types. It is read, not modified.
        src: Numpy array
            Source node id of every edge
        dst: Numpy array
            Destination node id of every edge
        ntid: Numpy array
            Node type id of every node
        etid: Numpy array
            Edge type id of every edge
        ntypes: list
            Node type names
        etypes: list
            Edge type names (before renaming)

        Returns
        -------
        hg: DGLGraph
            The heterograph produced by ``dgl.to_heterogeneous``.
        """
        if self.verbose:
            print('Creating one whole graph ...')
        homo = dgl.graph((src, dst))
        homo.ndata[dgl.NTYPE] = F.tensor(ntid)
        homo.edata[dgl.ETYPE] = F.tensor(etid)
        if self.verbose:
            print('Total #nodes:', homo.number_of_nodes())
            print('Total #edges:', homo.number_of_edges())

        def rename(name):
            # Names without an entry in RENAME_DICT pass through unchanged.
            return RENAME_DICT.get(name, name)

        renamed_etypes = [rename(name) for name in etypes]

        # Rebuild the metagraph with renamed edge keys; the caller's mg is untouched.
        renamed_mg = nx.MultiDiGraph()
        for src_type, dst_type, key in mg.edges(keys=True):
            renamed_mg.add_edge(src_type, dst_type, key=rename(key))

        if self.verbose:
            print('Convert to heterograph ...')
        hg = dgl.to_heterogeneous(homo,
                                  ntypes,
                                  renamed_etypes,
                                  metagraph=renamed_mg)
        if self.verbose:
            print('#Node types:', len(hg.ntypes))
            print('#Canonical edge types:', len(hg.etypes))
            print('#Unique edge type names:', len(set(hg.etypes)))
        return hg
Example #2
0
    def process(self):
        """Process raw data to graph, labels and masks.

        Reads ``{name}_retweet_graph.mat`` (keys ``'graph'`` and ``'label'``),
        ``{name}_node_feature.npy`` and ``{name}_edge_feature.npy`` from
        ``self.raw_path``. It then:

        - builds a bidirected graph;
        - labels every edge incident to a node with a non-zero label;
        - converts the graph into a heterograph with node types ``'v'``/``'u'``
          and edge types ``'forward'``/``'backward'``;
        - splits it with ``self._random_split`` and stores the result in
          ``self.graph``.
        """
        data = sio.loadmat(
            os.path.join(self.raw_path, f'{self.name}_retweet_graph.mat'))

        adj = data['graph'].tocoo()
        # Only the first half of the stored entries is kept; the reverse
        # direction is rebuilt explicitly below.
        # NOTE(review): assumes the two directions are stored as two halves.
        half = len(adj.row) // 2
        row, col = adj.row[:half], adj.col[:half]

        node_feature = np.load(
            os.path.join(self.raw_path, f'{self.name}_node_feature.npy'))

        # Pass num_nodes explicitly so that trailing nodes without edges are
        # kept and the node count matches the node-feature rows.
        graph = dgl.graph((np.concatenate((row, col)),
                           np.concatenate((col, row))),
                          num_nodes=len(node_feature))
        news_labels = data['label'].squeeze()
        num_news = len(news_labels)

        edge_feature = np.load(
            os.path.join(self.raw_path,
                         f'{self.name}_edge_feature.npy'))[:half]

        graph.ndata['feat'] = th.tensor(node_feature)
        # Reverse edges reuse the features of their forward counterparts.
        graph.edata['feat'] = th.tensor(np.tile(edge_feature, (2, 1)))
        pos_news = news_labels.nonzero()[0]

        # Size by the graph's real edge count (2 * half). len(adj.row) is one
        # larger when the stored matrix has an odd number of entries.
        edge_labels = th.zeros(graph.num_edges())
        edge_labels[graph.in_edges(pos_news, form='eid')] = 1
        edge_labels[graph.out_edges(pos_news, form='eid')] = 1
        graph.edata['label'] = edge_labels

        # Node type 0 ('v') for ids < num_news (presumably the news nodes),
        # 1 ('u') otherwise. Edge type 0 ('forward') for the first half of
        # the edges, 1 ('backward') for their reverses.
        ntypes = th.ones(graph.num_nodes(), dtype=th.long)
        etypes = th.ones(graph.num_edges(), dtype=th.long)

        ntypes[graph.nodes() < num_news] = 0
        etypes[:half] = 0

        graph.ndata['_TYPE'] = ntypes
        graph.edata['_TYPE'] = etypes

        hg = dgl.to_heterogeneous(graph, ['v', 'u'], ['forward', 'backward'])
        self._random_split(hg, self.seed, self.train_size, self.val_size)

        self.graph = hg
    def forward(self, **kwargs):
        """Score two choices per example and optionally compute the loss.

        For each of the two choices (second axis of ``input_ids``):

        - the text is encoded with ``self.text_embed_model``;
        - per-example synset graphs are rebuilt and run through the RGCN and
          GAT layers;
        - the token-level WordNet / context / NELL embeddings are concatenated
          to the text encoding, depending on ``use_nell``, ``use_wn`` and
          ``use_context_graph``;
        - the result goes through ``self.self_matching``, and position 0 is
          scored with ``self.qa_kt_outputs``.

        Keyword Args:
            input_ids: indexed as ``[:, num, :]``, presumably
                ``[batch_size, 2, max_seq_length]`` — TODO confirm.
            attention_mask: same layout as ``input_ids``.
            token_type_ids: same layout as ``input_ids``; only read when
                ``config.text_embed_model == "bert"``.
            label_ids: targets for ``nn.CrossEntropyLoss`` over the two
                choices. If omitted, or if its first entry is ``-1``, no loss
                is computed.
            batch_synset_graphs: indices into ``wn_synset_graphs[num]``
                selecting one graph per example.
            wn_synset_graphs: one sequence of DGL heterographs per choice.
            qas_ids: returned unchanged.

        Returns:
            ``(loss, logits, label_ids, qas_ids)``. ``logits`` has one column
            per choice, and ``loss`` is ``None`` when no labels are given.

        Raises:
            ValueError: if ``config.text_embed_model`` is not ``"bert"``,
                ``"roberta"`` or ``"roberta_base"``.
        """
        label_ids = kwargs.get("label_ids")
        input_ids_list = kwargs.get("input_ids")
        attention_mask_list = kwargs.get("attention_mask")
        token_type_ids_list = kwargs.get("token_type_ids")
        batch_synset_graphs_id = kwargs.get("batch_synset_graphs")
        wn_synset_graphs_list = kwargs.get("wn_synset_graphs")

        # Fail early with a clear message instead of an UnboundLocalError on
        # text_output further down.
        text_model = self.config.text_embed_model
        if text_model not in ("bert", "roberta", "roberta_base"):
            raise ValueError(
                "Unsupported text_embed_model: {}".format(text_model))

        choice_score_list = []
        # Two choices per example; the loss below views logits as (-1, 2).
        for num in range(2):
            input_ids = input_ids_list[:, num, :]
            attention_mask = attention_mask_list[:, num, :]
            device = input_ids.device
            wn_synset_graphs = wn_synset_graphs_list[num]
            batch_synset_graphs = [wn_synset_graphs[i] for i in batch_synset_graphs_id]

            # 1) Contextual token encoding.
            if text_model == "bert":
                token_type_ids = token_type_ids_list[:, num, :]
                text_output = self.text_embed_model(
                    input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids,
                    output_attentions=self.config.output_attentions,
                    output_hidden_states=self.config.output_hidden_states
                )[0]
            else:  # "roberta" / "roberta_base" take no token_type_ids
                text_output = self.text_embed_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask
                )[0]

            # Node/edge type names derived from each configured relation.
            relation_list = self.config.relation_list
            inverse_relation_list = ["{}_".format(rel) for rel in relation_list]
            id_type_list = ["wn{}_id".format(rel) for rel in relation_list]
            context_type_list = ["wn{}_context".format(rel) for rel in relation_list]

            batch_context_graphs_list = []
            batch_wn_graphs_list = []
            batch_entity2token_graphs_list = []
            batch_entity2token_graphs_nell_list = []

            # 2) Per-example graph construction.
            for i, g in enumerate(batch_synset_graphs):
                # Every non-padding token must have a matching token_id node.
                assert (len(g.nodes("token_id")) == torch.sum(attention_mask[i, :]))

                context_g, wn_g = self.reconstruct_dgl_graph(g, relation_list, inverse_relation_list,
                                                             id_type_list, context_type_list,
                                                             text_output[i, :, :], device)
                entity2token_graph, entity2token_graph_nell = self.construct_entity2token_graph(
                    i, g, text_output, device)

                batch_entity2token_graphs_list.append(entity2token_graph)
                batch_entity2token_graphs_nell_list.append(entity2token_graph_nell)
                batch_context_graphs_list.append(context_g)
                batch_wn_graphs_list.append(wn_g)

            # 3) Batched RGCN passes over the context graphs and WordNet graphs.
            batch_context_graphs_dgl = dgl.batch(batch_context_graphs_list)
            graph_context_embedding = self.rgcn_context(batch_context_graphs_dgl,
                                                        batch_context_graphs_dgl.ndata['feature'])
            batch_context_graphs_dgl.nodes["wn_concept_context"].data["feature"] = graph_context_embedding[
                "wn_concept_context"]
            batch_context_graphs_list = dgl.unbatch(batch_context_graphs_dgl)

            batch_wn_graphs_dgl = dgl.batch(batch_wn_graphs_list)
            graph_wn_embedding = self.rgcn_wn(batch_wn_graphs_dgl, batch_wn_graphs_dgl.ndata['feature'])
            batch_wn_graphs_dgl.nodes["wn_concept_id"].data["feature"] = graph_wn_embedding["wn_concept_id"]
            batch_wn_graphs_list = dgl.unbatch(batch_wn_graphs_dgl)

            # 4) Per-token knowledge embeddings, zero-padded past each example's length.
            memory_output_new = text_output
            embed_shape = (memory_output_new.shape[0], memory_output_new.shape[1], self.concept_embed_size)
            context_embed_new = torch.zeros(embed_shape, dtype=torch.float32, device=device)
            concept_embed_new = torch.zeros(embed_shape, dtype=torch.float32, device=device)
            nell_embed_new = torch.zeros(embed_shape, dtype=torch.float32, device=device)

            for idx, g_e2t in enumerate(batch_entity2token_graphs_list):
                g_e2t.nodes["wn_concept_id"].data["context_feature"] = \
                    batch_context_graphs_list[idx].nodes["wn_concept_context"].data["feature"]
                g_e2t.nodes["wn_concept_id"].data["id_feature"] = \
                    batch_wn_graphs_list[idx].nodes["wn_concept_id"].data["feature"]
                g_e2t.nodes["token_id"].data["id_feature"] = self.projected_token_text(
                    g_e2t.nodes["token_id"].data["context_feature"])
                g_e2t.nodes["sentinel_id"].data["id_feature"] = torch.zeros_like(
                    g_e2t.nodes["token_id"].data["id_feature"], device=device)
                g_e2t.nodes["sentinel_id"].data["context_feature"] = torch.zeros_like(
                    g_e2t.nodes["token_id"].data["context_feature"], device=device)

                # GAT runs on the homogeneous view; convert back to read token nodes.
                g_e2t_homo = dgl.to_homogeneous(g_e2t, ndata=['id_feature', 'context_feature'])
                g_e2t_homo.ndata['context_feature'] = self.gat_context(g_e2t_homo,
                                                                       g_e2t_homo.ndata['context_feature'])
                g_e2t_homo.ndata['id_feature'] = self.gat_wn(g_e2t_homo, g_e2t_homo.ndata['id_feature'])
                tmp_graph = dgl.to_heterogeneous(g_e2t_homo, g_e2t.ntypes, g_e2t.etypes)

                # Restore token order from the homogeneous node ids.
                n_tokens = tmp_graph.num_nodes("token_id")
                tmp_argsort = torch.argsort(tmp_graph.ndata[dgl.NID]["token_id"] - tmp_graph.num_nodes("sentinel_id"))
                concept_embed_new[idx, :n_tokens, :] = tmp_graph.nodes["token_id"].data[
                    "id_feature"].index_select(0, tmp_argsort)
                context_embed_new[idx, :n_tokens, :] = tmp_graph.nodes["token_id"].data[
                    "context_feature"].index_select(0, tmp_argsort)

                g_e2t_nell = batch_entity2token_graphs_nell_list[idx]
                g_e2t_nell_homo = dgl.to_homogeneous(g_e2t_nell, ndata=['id_feature'])
                g_e2t_nell_homo.ndata['id_feature'] = self.gat_nell(g_e2t_nell_homo,
                                                                    g_e2t_nell_homo.ndata['id_feature'])
                tmp_graph_nell = dgl.to_heterogeneous(g_e2t_nell_homo, g_e2t_nell.ntypes, g_e2t_nell.etypes)
                nell_tmp_argsort = torch.argsort(
                    tmp_graph_nell.ndata[dgl.NID]["token_id"]
                    - tmp_graph_nell.num_nodes("sentinel_id")
                    - tmp_graph_nell.num_nodes("nell_concept_id"))
                nell_embed_new[idx, :tmp_graph_nell.num_nodes("token_id"), :] = \
                    tmp_graph_nell.nodes["token_id"].data["id_feature"].index_select(0, nell_tmp_argsort)

            # 5) Concatenate the enabled knowledge sources along the feature axis.
            if self.use_nell:
                memory_output_new = torch.cat((memory_output_new, nell_embed_new), 2)
            if self.use_context_graph and self.use_wn:
                k_memory = torch.cat((concept_embed_new, context_embed_new), 2)
            elif self.use_wn:
                k_memory = concept_embed_new
            elif self.use_context_graph:
                k_memory = context_embed_new
            if self.use_context_graph or self.use_wn:
                memory_output_new = torch.cat((memory_output_new, k_memory), 2)

            att_output = self.self_matching(memory_output_new,
                                            attention_mask.unsqueeze(2))
            # Output layer: score the representation at position 0.
            choice_score = self.qa_kt_outputs(att_output[:, 0, :])
            choice_score_list.append(choice_score)

        logits = torch.cat(
            [score.unsqueeze(1).squeeze(-1) for score in choice_score_list], dim=1
        )

        # No labels supplied (e.g. inference) or the -1 sentinel: skip the loss.
        if label_ids is not None and label_ids[0] != -1:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, 2), label_ids.view(-1))
        else:
            loss = None

        return loss, logits, label_ids, kwargs.get("qas_ids")
def test_edge_softmax(g, norm_by, idtype):
    """Check heterograph ``edge_softmax`` against its homogeneous equivalent.

    ``edge_softmax`` is run once on ``dgl.to_homogeneous`` of the test
    heterograph and once directly on the heterograph. The per-relation
    scores, and the gradients of ``sum(plays) + sum(follows)`` with respect to
    the edge inputs, must match.

    The ``g`` argument is ignored: it is immediately replaced by
    ``create_test_heterograph(idtype)``.
    """
    print("params", norm_by, idtype)

    g = create_test_heterograph(idtype)

    x1 = F.randn((g.num_edges('plays'), feat_size))
    x2 = F.randn((g.num_edges('follows'), feat_size))
    x3 = F.randn((g.num_edges('develops'), feat_size))
    x4 = F.randn((g.num_edges('wishes'), feat_size))

    g['plays'].edata['eid'] = x1
    g['follows'].edata['eid'] = x2
    g['develops'].edata['eid'] = x3
    g['wishes'].edata['eid'] = x4

    #################################################################
    #  edge_softmax() on homogeneous graph
    #################################################################

    with F.record_grad():
        hm_g = dgl.to_homogeneous(g)
        # Edges of hm_g are assumed to be ordered by g's canonical edge types:
        # develops, follows, plays, wishes (same order as grad_edata_ht below).
        hm_x = F.cat((x3, x2, x1, x4), 0)
        hm_e = F.attach_grad(F.clone(hm_x))
        score_hm = edge_softmax(hm_g, hm_e, norm_by=norm_by)
        hm_g.edata['score'] = score_hm
        ht_g = dgl.to_heterogeneous(hm_g, g.ntypes, g.etypes)
        r1 = ht_g.edata['score'][('user', 'plays', 'game')]
        r2 = ht_g.edata['score'][('user', 'follows', 'user')]
        r3 = ht_g.edata['score'][('developer', 'develops', 'game')]
        r4 = ht_g.edata['score'][('user', 'wishes', 'game')]
        F.backward(F.reduce_sum(r1) + F.reduce_sum(r2))
        grad_edata_hm = F.grad(hm_e)

    #################################################################
    #  edge_softmax() on heterogeneous graph
    #################################################################

    e1 = F.attach_grad(F.clone(x1))
    e2 = F.attach_grad(F.clone(x2))
    e3 = F.attach_grad(F.clone(x3))
    e4 = F.attach_grad(F.clone(x4))
    e = {
        ('user', 'follows', 'user'): e2,
        ('user', 'plays', 'game'): e1,
        ('user', 'wishes', 'game'): e4,
        ('developer', 'develops', 'game'): e3
    }
    with F.record_grad():
        score = edge_softmax(g, e, norm_by=norm_by)
        r5 = score[('user', 'plays', 'game')]
        r6 = score[('user', 'follows', 'user')]
        r7 = score[('developer', 'develops', 'game')]
        r8 = score[('user', 'wishes', 'game')]
        F.backward(F.reduce_sum(r5) + F.reduce_sum(r6))
        # Same edge-type order as hm_x so the gradients line up element-wise.
        grad_edata_ht = F.cat((F.grad(e3), F.grad(e2), F.grad(e1), F.grad(e4)),
                              0)
        # correctness check
        assert F.allclose(r1, r5)
        assert F.allclose(r2, r6)
        assert F.allclose(r3, r7)
        assert F.allclose(r4, r8)
        assert F.allclose(grad_edata_hm, grad_edata_ht)
Example #5
0
    def _extract_features(self, graphs, ai2d_ann, image, layers):
        """
        Extracts features from the original AI2D annotation and adds them to the
        AI2D-RST graphs.

        The grouping graph is always processed first, because it provides the
        layout information for all diagram elements. The layer(s) selected by
        ``layers`` then copy node features from it.

        Parameters:
            graphs: A dictionary of NetworkX graphs for AI2D-RST annotation,
                    read under the keys 'grouping', 'connectivity' and
                    'discourse'.
            ai2d_ann: A dictionary containing the original AI2D annotation.
            image: An image of the diagram from the original AI2D dataset.
            layers: A string defining annotation layers to include in the
                    updated graphs: "grouping+connectivity", "connectivity"
                    or "discourse". Any other value only updates the
                    grouping graph.

        Returns:
            The ``graphs`` dictionary with updated graphs. For the discourse
            layer, the 'discourse' entry is replaced by a DGL heterograph.

            When ``layers == 'discourse'`` and ``self._return_nx`` is set, the
            discourse NetworkX graph is returned instead. None is returned
            when ``graphs['grouping']`` has no ``nodes`` attribute
            (e.g. it is None).
        """
        graph = graphs['grouping']

        # Check that a graph exists
        try:
            nodes = graph.nodes(data=True)
        except AttributeError:
            return None

        # Image height, width and pixel count are passed to the layout parser
        h, w = image.shape[:2]
        n_pix = h * w

        # Parse layout features for every node and encode its type. The
        # parser does not read the graph, so the updates can be applied in one
        # call after the loop instead of once per node.
        node_features = {}
        for node, features in nodes:
            node_type = features['kind']

            layout_feats = self._parse_ai2d_layout(
                ai2d_ann,  # annotation
                h,  # image height
                w,  # image width
                n_pix,  # n of pixels
                node_type,  # elem type
                node  # node id
            )

            node_features[node] = {
                'features': layout_feats,
                'kind': self.node_dict['grouping'][node_type]
            }

        nx.set_node_attributes(graph, node_features)

        # Features of grouping nodes are computed from their children, which
        # requires a directed tree rooted at the image constant I0.
        group_tree = nx.dfs_tree(graph, source="I0")

        # Grouping nodes and image constants. Only used for membership tests,
        # so a set is used.
        group_kinds = [
            self.node_dict['grouping']['imageConsts'],
            self.node_dict['grouping']['group']
        ]
        groups = {
            n for n, attr in graph.nodes(data=True) if attr['kind'] in group_kinds
        }

        for n, attr in graph.nodes(data=True):
            if n not in groups:
                continue

            # dfs_predecessors maps each descendant of n to its parent. Groups
            # are excluded because each group is processed independently.
            descendants = [
                d for d in nx.dfs_predecessors(group_tree, n) if d not in groups
            ]

            n_subgraph = graph.subgraph(descendants)
            n_feats = [ad['features'] for _, ad in n_subgraph.nodes(data=True)]

            # Nothing to aggregate: keep the layout features parsed above
            # rather than failing on an empty array.
            if not n_feats:
                continue

            stacked_feats = np.array(n_feats)

            # Average centre point (columns 0-1), summed area (column 2),
            # average solidity (column 3).
            layout_feats = np.concatenate([
                [np.average(stacked_feats[:, 0])],
                [np.average(stacked_feats[:, 1])],
                [np.sum(stacked_feats[:, 2])],
                [np.average(stacked_feats[:, 3])],
            ], axis=0)

            nx.set_node_attributes(graph, {
                n: {
                    'features': layout_feats,
                    'kind': attr['kind']
                }
            })

        # Edge types are not defined for the grouping layer in the JSON
        # annotation, so every grouping edge gets the kind 'grouping'.
        edge_features = {
            (src, dst): {'kind': 'grouping'} for src, dst in graph.edges()
        }
        nx.set_edge_attributes(graph, edge_features)

        # Encode edge features
        self._encode_edges(graph, self.edge_dict['grouping'])

        graphs['grouping'] = graph

        if layers == "grouping+connectivity":

            # If a connectivity graph exists, merge it with the grouping
            # graph; attributes from the grouping graph take precedence.
            if graphs['connectivity'] is not None:
                graph = nx.compose(graphs['connectivity'], graphs['grouping'])

            # Encode edge type information using numerical labels
            self._encode_edges(graph, self.edge_dict['connectivity'])

            graphs['grouping'] = graph

        # The connectivity layer alone is more complex, as the children of
        # grouping nodes need to be copied over to the connectivity graph.
        if layers == 'connectivity' and graphs['connectivity'] is not None:

            conn_graph = graphs['connectivity']
            group_graph = graphs['grouping']

            grouping_nodes = [
                n for n, attr_dict in conn_graph.nodes(data=True)
                if attr_dict['kind'] == 'group'
            ]

            if grouping_nodes:
                group_tree = nx.dfs_tree(group_graph, source="I0")

                # Resolve grouping nodes by adding their children to the
                # connectivity graph
                for gn in grouping_nodes:
                    self._resolve_grouping_node(gn, group_tree, group_graph,
                                                conn_graph)

            # Copy node features from the grouping graph for every node now
            # in the connectivity graph.
            n_subgraph = group_graph.subgraph(conn_graph.nodes)
            conn_graph.add_nodes_from(n_subgraph.nodes(data=True))

            # Encode edge type information using numerical labels
            self._encode_edges(conn_graph, self.edge_dict['connectivity'])

            graphs['connectivity'] = conn_graph

        if layers == 'discourse':

            group_graph = graphs['grouping']
            rst_graph = graphs['discourse']

            # Reverse node type dictionary for the grouping layer
            rev_group_dict = {
                int(v.item()): k
                for k, v in self.node_dict['grouping'].items()
            }

            # Re-encode grouping node types back to their names so that they
            # do not clash with those defined for the discourse graph.
            upd_node_types = {
                k: rev_group_dict[int(v['kind'].item())]
                for k, v in group_graph.nodes(data=True)
            }
            nx.set_node_attributes(group_graph, upd_node_types, 'kind')

            # Overwrite the discourse nodes with their grouping-graph features
            subgraph = group_graph.subgraph(rst_graph.nodes)
            rst_graph.add_nodes_from(subgraph.nodes(data=True))

            # Iterate over a copy because the loop adds and updates nodes.
            # Split nodes ('copy_of') preserve the tree structure when an
            # element participates in multiple RST relations.
            for n, attr_dict in rst_graph.copy().nodes(data=True):

                if 'group' in attr_dict['kind']:

                    # Rebuilt per group, as _resolve_grouping_node may modify
                    # the graphs.
                    group_tree = nx.dfs_tree(group_graph, source="I0")

                    self._resolve_grouping_node(n, group_tree, group_graph,
                                                rst_graph)

                if 'copy_of' in attr_dict:

                    # Identifier and kind of the split node in the AI2D layout
                    n_orig_id = attr_dict['copy_of']
                    n_orig_kind = attr_dict['kind']

                    layout_feats = self._parse_ai2d_layout(
                        ai2d_ann, h, w, n_pix, n_orig_kind, n_orig_id)

                    nx.set_node_attributes(rst_graph, {
                        n: {
                            'features': layout_feats,
                            'kind': n_orig_kind
                        }
                    })

                if 'relation' in attr_dict['kind']:

                    # Integer label of the RST relation, one-hot encoded
                    rst_int_label = self.node_dict['relations'][
                        attr_dict['rel_name']]
                    rst_label = self._rst_binarizer.transform(rst_int_label)

                    if self._smooth_labels:

                        # Cast into float and smooth the labels by 0.1
                        rst_label = np.asarray(rst_label, dtype=np.float64)
                        rst_label *= (1 - 0.1)
                        rst_label += (0.1 / rst_label.shape[1])

                    nx.set_node_attributes(
                        rst_graph, {n: {'features': rst_label.flatten()}})

            if self._return_nx:
                return rst_graph

            # Integer node identifiers are required for building a heterograph
            rst_graph = nx.convert_node_labels_to_integers(rst_graph,
                                                           first_label=0)

            # Map node and edge kinds to integer type ids. np.unique with
            # return_inverse returns, per element, its index in the sorted
            # unique kinds; unlike np.nditer, it also handles empty inputs.
            nodes = np.asarray([
                attr['kind'] for _, attr in rst_graph.nodes(data=True)
            ]).flatten()
            ntypes, node_ixs = np.unique(nodes, return_inverse=True)
            node_ixs = node_ixs.astype(np.int64).ravel()

            edges = np.asarray([
                attr['kind'] for _, _, attr in rst_graph.edges(data=True)
            ]).flatten()
            etypes, edge_ixs = np.unique(edges, return_inverse=True)
            edge_ixs = edge_ixs.astype(np.int64).ravel()

            g = dgl.from_networkx(rst_graph)

            g.ndata[dgl.NTYPE] = torch.LongTensor(node_ixs)
            g.edata[dgl.ETYPE] = torch.LongTensor(edge_ixs)

            hg = dgl.to_heterogeneous(g, ntypes, etypes)

            # Copy features per node type, using the original node ids
            # recorded under dgl.NID.
            for ntype in hg.ntypes:
                rst_node_ids = hg.nodes[ntype].data[dgl.NID].tolist()

                features = np.vstack([
                    rst_graph.nodes[node_id]['features']
                    for node_id in rst_node_ids
                ])

                hg.nodes[ntype].data['features'] = torch.from_numpy(features)

            graphs['discourse'] = hg

        return graphs