Ejemplo n.º 1
0
    def batchify(self, data_filepaths: List[str], ctx: mx.context.Context):
        """
        Combine the datapoints at data_filepaths into a single Batch.

        Loads each datapoint, concatenates all graphs into one large
        disconnected graph (block-diagonal adjacency matrices per edge
        type), 1-hot encodes node names, and builds a 1-hot label vector
        over all nodes marking the indices listed in each datapoint's
        label (shifted to match positions in the combined graph).

        :param data_filepaths: paths of the datapoints to batch together
        :param ctx: MXNet context the batch's NDArrays are placed on
        :return: Batch(InputClass(...), one_hot_labels)
        """
        data = [self.data_encoder.load_datapoint(i) for i in data_filepaths]

        # Get the size of each graph
        batch_sizes = nd.array([len(dp.node_names) for dp in data],
                               dtype='int32',
                               ctx=ctx)

        combined_node_types = tuple(
            itertools.chain.from_iterable(dp.node_types for dp in data))
        node_types = tuple_of_tuples_to_padded_array(combined_node_types, ctx)
        combined_node_names = tuple(
            itertools.chain.from_iterable(dp.node_names for dp in data))

        # 1-hot encode node names.  Sentinel flag "names" are encoded as
        # an empty name with the matching marker kwarg set; real names are
        # encoded directly.  Only the kwargs differ between the cases.
        node_names = []
        for name in combined_node_names:
            extra_kwargs = {}
            if name == self.data_encoder.internal_node_flag:
                name, extra_kwargs = '', {'mark_as_internal': True}
            elif name == self.data_encoder.fill_in_flag:
                name, extra_kwargs = '', {'mark_as_special': True}
            node_names.append(
                self.data_encoder.name_to_1_hot(
                    name,
                    embedding_size=self.data_encoder.max_name_encoding_length,
                    **extra_kwargs))
        node_names = nd.array(np.stack(node_names), dtype='float32', ctx=ctx)

        # Combine all the adjacency matrices into one big, disconnected graph
        edges = OrderedDict()
        for edge_type in self.data_encoder.all_edge_types:
            adj_mat = sp.sparse.block_diag(
                [dp.edges[edge_type] for dp in data]).tocsr()
            adj_mat = nd.sparse.csr_matrix(
                (adj_mat.data, adj_mat.indices, adj_mat.indptr),
                shape=adj_mat.shape,
                dtype='float32',
                ctx=ctx)
            edges[edge_type] = adj_mat

        # 1-hot whether a variable should have been indicated or not
        length = 0
        labels = []
        # Relabel the labels to match the indices in the batchified graph
        for dp in data:
            labels += [i + length for i in dp.label]
            length += len(dp.node_types)
        labels = nd.array(labels, dtype='int32', ctx=ctx)
        one_hot_labels = nd.zeros(length, dtype='float32', ctx=ctx)
        one_hot_labels[labels] = 1

        data = self.InputClass(edges, node_types, node_names, batch_sizes, ctx)
        return Batch(data, one_hot_labels)
Ejemplo n.º 2
0
    def batchify(self, data_filepaths: List[str], ctx: mx.context.Context):
        '''
        Returns combined graphs and labels.
        Labels are a (PaddedArray, Tuple[str]) tuple.  The PaddedArray is size (batch x max_name_length) containing integers
        The integer values correspond to the integers in this model's data encoder's all_node_name_subtokens dict
        (i.e. rows in the name_embedding matrix)

        :param data_filepaths: paths of the datapoints to batch together
        :param ctx: MXNet context the batch's NDArrays are placed on
        :return: Batch(InputClass(...), [labels PaddedArray, real_names tuple])
        '''
        data = [self.data_encoder.load_datapoint(i) for i in data_filepaths]

        # Get the size of each graph
        batch_sizes = nd.array([len(dp.node_names) for dp in data],
                               dtype='int32',
                               ctx=ctx)

        combined_node_types = tuple(
            itertools.chain.from_iterable(dp.node_types for dp in data))
        node_types = tuple_of_tuples_to_padded_array(combined_node_types, ctx)
        combined_node_names = tuple(
            itertools.chain.from_iterable(dp.node_names for dp in data))
        # Indices (in the combined graph) of the nodes to be named
        target_locations = [
            i for i, name in enumerate(combined_node_names)
            if name == self.data_encoder.name_me_flag
        ]

        # 1-hot encode node names.  Sentinel flag "names" are encoded as
        # an empty name with the matching marker kwarg set; real names are
        # encoded directly.  Only the kwargs differ between the cases.
        node_names = []
        for name in combined_node_names:
            extra_kwargs = {}
            if name == self.data_encoder.internal_node_flag:
                name, extra_kwargs = '', {'mark_as_internal': True}
            elif name == self.data_encoder.name_me_flag:
                name, extra_kwargs = '', {'mark_as_special': True}
            node_names.append(
                self.data_encoder.name_to_1_hot(
                    name,
                    embedding_size=self.data_encoder.max_name_encoding_length,
                    **extra_kwargs))
        node_names = nd.array(np.stack(node_names), dtype='float32', ctx=ctx)

        # Combine all the adjacency matrices into one big, disconnected graph
        edges = OrderedDict()
        for edge_type in self.data_encoder.all_edge_types:
            adj_mat = sp.sparse.block_diag(
                [dp.edges[edge_type] for dp in data]).tocsr()
            adj_mat = nd.sparse.csr_matrix(
                (adj_mat.data, adj_mat.indices, adj_mat.indptr),
                shape=adj_mat.shape,
                dtype='float32',
                ctx=ctx)
            edges[edge_type] = adj_mat

        # Combine the (encoded) real names of variables-to-be-named.
        # One label sequence per datapoint (NOT flattened): the original
        # itertools.chain([...]) call here chained a single iterable, which
        # is a no-op, so a plain per-datapoint tuple is equivalent.
        combined_labels = tuple(dp.label for dp in data)
        labels = tuple_of_tuples_to_padded_array(
            combined_labels, ctx, pad_amount=self.max_name_length)
        # Combine the (actual) real names of variables-to-be-named
        real_names = tuple(dp.real_variable_name for dp in data)

        data = self.InputClass(edges,
                               node_types,
                               node_names,
                               batch_sizes,
                               ctx,
                               target_locations=target_locations)
        return Batch(data, [labels, real_names])