def repack(self, *args):
    """Restore this input's fields from the tuple produced by unpack."""
    edges, node_types, node_names, batch_sizes = args
    for k, adj_mat in zip(self.edges.keys(), edges):
        self.edges[k] = adj_mat
    self.node_types = PaddedArray(*node_types)
    self.node_names = PaddedArray(*node_names)
    self.batch_sizes = batch_sizes
class TestCharCNNInput(unittest.TestCase):
    @given(edges=st.dictionaries(st.characters(),
                                 hpnp.arrays(dtype=np.dtype('float32'),
                                             shape=hpnp.array_shapes()),
                                 dict_class=OrderedDict,
                                 min_size=1),
           node_types=st.builds(
                PaddedArray,
                hpnp.arrays(dtype=np.dtype('float32'),
                            shape=hpnp.array_shapes()),
                hpnp.arrays(dtype=np.dtype('float32'),
                            shape=hpnp.array_shapes())),
           node_names=hpnp.arrays(dtype=np.dtype('float32'),
                                  shape=hpnp.array_shapes()),
           batch_sizes=hpnp.arrays(dtype=np.dtype('float32'),
                                   shape=hpnp.array_shapes()))
    def test_unpack_and_repack_are_inverses(self, edges, node_types,
                                            node_names, batch_sizes):
        inp = CharCNNInput(edges, node_types, node_names, batch_sizes,
                           mx.cpu())
        originp = deepcopy(inp)
        inp.repack(*inp.unpack())
        self.assertEqual(inp.edges.keys(), originp.edges.keys())
        for k in inp.edges.keys():
            np.testing.assert_equal(inp.edges[k], originp.edges[k])
        np.testing.assert_equal(inp.node_names, originp.node_names)
        np.testing.assert_equal(inp.node_types.values,
                                originp.node_types.values)
        np.testing.assert_equal(inp.node_types.value_lengths,
                                originp.node_types.value_lengths)
        np.testing.assert_equal(inp.batch_sizes, originp.batch_sizes)
    def batchify(self, data_filepaths: List[str], ctx: mx.context.Context):
        """
        Returns combined graphs and labels.
        Labels are a (PaddedArray, Tuple[str]) tuple.  The PaddedArray is size (batch x max_name_length) containing integers.
        The integer values correspond to the integers in this model's data encoder's all_node_name_subtokens dict,
           or if the integer value is greater than len(all_node_name_subtokens) it corresponds to which subtoken node
           in the graph represents the right subtoken
        """
        data = [self.data_encoder.load_datapoint(i) for i in data_filepaths]

        # Get the size of each graph
        batch_sizes = nd.array([len(dp.node_names) for dp in data],
                               dtype='int32',
                               ctx=ctx)

        combined_node_types = tuple(
            itertools.chain(*[dp.node_types for dp in data]))
        subtoken_node_type_idx = self.data_encoder.all_node_types[
            self.data_encoder.subtoken_flag]
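        # Indices (within the combined graph) of the subtoken vocab nodes,
        # i.e. the nodes the model's attention can point to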
        graph_vocab_node_locations = np.array([
            i for i, node_type in enumerate(combined_node_types)
            if node_type[0] == subtoken_node_type_idx
        ])
        graph_vocab_node_real_names = [
            dp.graph_vocab_node_real_names for dp in data
        ]
        node_types = tuple_of_tuples_to_padded_array(combined_node_types, ctx)
        combined_node_names = tuple(
            itertools.chain(*[dp.node_names for dp in data]))
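        # Indices of the nodes whose names we're predicting (marked with the
        # name_me flag by the data encoder)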
        target_locations = [
            i for i, name in enumerate(combined_node_names)
            if name == self.data_encoder.name_me_flag
        ]
        # One-hot encode each node name; internal and to-be-named nodes get
        # an empty name with a marker flag instead
        name_enc_len = self.data_encoder.max_name_encoding_length
        node_names = []
        for name in combined_node_names:
            if name == self.data_encoder.internal_node_flag:
                node_names.append(
                    self.data_encoder.name_to_1_hot(
                        '',
                        embedding_size=name_enc_len,
                        mark_as_internal=True))
            elif name == self.data_encoder.name_me_flag:
                node_names.append(
                    self.data_encoder.name_to_1_hot(
                        '',
                        embedding_size=name_enc_len,
                        mark_as_special=True))
            else:
                node_names.append(
                    self.data_encoder.name_to_1_hot(
                        name, embedding_size=name_enc_len))
        node_names = nd.array(np.stack(node_names), dtype='float32', ctx=ctx)

        # Combine all the adjacency matrices into one big, disconnected graph
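        # (For example, block_diag of a 3x3 and a 2x2 adjacency matrix yields
        # a 5x5 matrix with those blocks on the diagonal and no edges between
        # the two graphs.)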
        edges = OrderedDict()
        for edge_type in self.data_encoder.all_edge_types:
            adj_mat = sp.sparse.block_diag(
                [dp.edges[edge_type] for dp in data]).tocsr()
            adj_mat = nd.sparse.csr_matrix(
                (adj_mat.data, adj_mat.indices, adj_mat.indptr),
                shape=adj_mat.shape,
                dtype='float32',
                ctx=ctx)
            edges[edge_type] = adj_mat

        # Build the label arrays for the variables we're supposed to be naming
        combined_closed_vocab_labels = [dp.label[0] for dp in data]
        # Vocab labels are integers referring to indices in the model's data
        # encoder's all_node_name_subtokens
        vocab_labels = tuple_of_tuples_to_padded_array(
            combined_closed_vocab_labels, ctx, pad_amount=self.max_name_length)
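        # For example (assuming tuple_of_tuples_to_padded_array 0-pads and
        # records each tuple's length): a label tuple (12, 7) padded to
        # max_name_length 4 becomes values [12, 7, 0, 0] with value_length 2.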
        combined_attn_labels = []
        for dp in data:
            graph_vocab_nodes_in_dp = [
                i for i, node_type in enumerate(dp.node_types)
                if node_type[0] == subtoken_node_type_idx
            ]
            combined_attn_labels.append(
                tuple([
                    graph_vocab_nodes_in_dp.index(i) + 1 if i >= 0 else -1
                    for i in dp.label[1]
                ]))
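        # For example, if graph_vocab_nodes_in_dp == [4, 9, 17] and
        # dp.label[1] == (9, -1), this yields (2, -1): node 9 sits at index 1,
        # +1 to skip the padding value.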
        # Attn labels are integers referring to indices (+1 to avoid confusion
        # with the padding value, which is 0) in the list of attn weights over
        # graph vocab nodes the model will eventually output (or -1 if there's
        # no appropriate node)
        attn_labels = tuple_of_tuples_to_padded_array(
            combined_attn_labels, ctx, pad_amount=self.max_name_length)
        attn_label = attn_labels.values
        subtoken_in_graph = attn_label > 0
        # Shift graph-pointer labels past the closed vocab (-1 because we're
        # done avoiding the padding value)
        attn_label = len(
            self.data_encoder.all_node_name_subtokens) + attn_label - 1
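        # Padding (0) and missing-node (-1) entries get shifted too, but
        # subtoken_in_graph is 0 at those positions, so the nd.where below
        # falls back to the closed-vocab label there.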
        # If the correct subtoken was in the graph, then pointing to it is the
        # correct output (it will always be in the vocab during training)
        joint_label = PaddedArray(values=nd.where(subtoken_in_graph,
                                                  attn_label,
                                                  vocab_labels.values),
                                  value_lengths=vocab_labels.value_lengths)

        # Combine the (actual) real names of variables-to-be-named
        real_names = tuple(dp.real_variable_name for dp in data)

        data = self.InputClass(
            edges,
            node_types,
            node_names,
            batch_sizes,
            ctx,
            target_locations=target_locations,
            graph_vocab_node_locations=graph_vocab_node_locations,
            graph_vocab_node_real_names=graph_vocab_node_real_names)
        return Batch(data, [joint_label, real_names])