    def test_encode(self):
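        """Encoding a DataPoint in place should preserve its adjacency matrices, node types,
        node names, subtoken nodes, and both halves of its label."""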
        de = VarNamingGSCVocabDataEncoder(self.task.graphs_and_instances,
                                          excluded_edge_types=frozenset(),
                                          instance_to_datapoints_kwargs=dict(),
                                          max_name_encoding_length=self.max_name_encoding_length)
        for graph, instances in self.task.graphs_and_instances:
            VarNamingGSCVocab.fix_up_edges(graph, instances, frozenset())
            VarNamingGSCVocab.extra_graph_processing(graph, instances, de)
            for instance in tqdm(instances):
                dporig = VarNamingGSCVocab.instance_to_datapoint(graph, instance, de, max_nodes_per_graph=50)
                dp = deepcopy(dporig)
                de.encode(dp)
                self.assertCountEqual(list(all_edge_types) + [de.subtoken_edge_type, de.subtoken_reverse_edge_type],
                                      dp.edges.keys())
                self.assertEqual(list(dp.edges.keys()), sorted(list(de.all_edge_types)),
                                 "Not all adjacency matrices were created")
                for edge_type, adj_mat in dp.edges.items():
                    np.testing.assert_equal(adj_mat.todense(),
                                            dporig.subgraph.get_adjacency_matrix(edge_type).todense())
                    self.assertIsInstance(adj_mat, sp.sparse.coo_matrix,
                                          "Encoding produces adjacency matrix of wrong type")

                self.assertEqual(len(dporig.node_types), len(dp.node_types),
                                 "Type for some node got lost during encoding")
                self.assertEqual([len(i) for i in dporig.node_types], [len(i) for i in dp.node_types],
                                 "Some type for some node got lost during encoding")
                for i in range(len(dp.node_types)):
                    for j in range(len(dp.node_types[i])):
                        self.assertEqual(dp.node_types[i][j], de.all_node_types[dporig.node_types[i][j]],
                                         "Some node type got encoded wrong")

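                # Subtoken nodes should keep their original identifiers as node names after encoding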
                orig_subtoken_nodes = [i for i, data in dporig.subgraph.nodes if data['type'] == de.subtoken_flag]
                dp_subtoken_nodes = [i for i in range(len(dp.node_types)) if
                                     dp.node_types[i] == (de.all_node_types[de.subtoken_flag],)]
                self.assertEqual(len(orig_subtoken_nodes), len(dp_subtoken_nodes), "Some subtoken nodes got lost")
                for i in dp_subtoken_nodes:
                    self.assertEqual(dp.node_names[i], dporig.subgraph[i]['identifier'],
                                     "Some subtoken node got the wrong name")

                self.assertEqual(tuple(dporig.node_names), dp.node_names, "Some node names got lost")

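                # The encoded label is a pair: vocab subtoken ids and attention pointers into the
                # graph (-1 when the subtoken has no corresponding node)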
                self.assertEqual(len(dp.label[0]), len(dp.label[1]), "Vocab and Attn labels should be the same length")
                self.assertEqual(len(dporig.label), len(dp.label[0]),
                                 "Some vocab label got lost")
                for i in range(len(dp.label[0])):
                    self.assertEqual(dp.label[0][i], de.all_node_name_subtokens[dporig.label[i]],
                                     "Some vocab label got encoded wrong")
                self.assertEqual(len(dporig.label), len(dp.label[1]),
                                 "Some attn label got list")
                for i, sbtk in enumerate(dporig.label):
                    if dp.label[1][i] == -1:
                        self.assertNotIn(sbtk, dp.node_names)
                    else:
                        self.assertEqual(sbtk, dporig.subgraph[dp.label[1][i]]['identifier'],
                                         "An attn label is indicating the wrong node")
                        self.assertEqual(sbtk, dp.node_names[dp.label[1][i]],
                                         "An attn label is indicating the wrong node")
    def test_preprocess_task_for_model_no_subtoken_edges(self):
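        """Preprocessing with add_edges=False should produce datapoints with no SUBTOKEN_USE edges,
        and re-encoding with the saved DataEncoder should reproduce the same datapoints."""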
        task = VarNamingTask.from_gml_files(self.test_gml_files)
        task_filepath = os.path.join(self.output_dataset_dir, 'VarNamingTask.pkl')
        task.save(task_filepath)
        VarNamingGSCVocab.preprocess_task(task=task,
                                          output_dir=self.output_dataset_dir,
                                          n_jobs=30,
                                          data_encoder='new',
                                          data_encoder_kwargs=dict(
                                              max_name_encoding_length=self.max_name_encoding_length,
                                              add_edges=False),
                                          instance_to_datapoints_kwargs=dict(max_nodes_per_graph=100))
        self.assertNotIn('jobs.txt', os.listdir(self.output_dataset_dir),
                         "The jobs.txt file from process_graph_to_datapoints_with_xargs didn't get deleted")
        self.assertTrue(all(len(i) > 10 for i in os.listdir(self.output_dataset_dir)),
                        "Hacky check that intermediate pickled job files got deleted")
        reencoding_dir = os.path.join(self.output_dataset_dir, 're-encoding')
        os.mkdir(reencoding_dir)
        data_encoder = VarNamingGSCVocab.DataEncoder.load(
            os.path.join(self.output_dataset_dir, 'VarNamingGSCVocabDataEncoder.pkl'))
        self.assertCountEqual(data_encoder.all_edge_types,
                              list(all_edge_types) + ['reverse_{}'.format(i) for i in all_edge_types],
                              "DataEncoder found weird edge types")
        VarNamingGSCVocab.preprocess_task(task=task,
                                          output_dir=reencoding_dir,
                                          n_jobs=30,
                                          data_encoder=data_encoder)
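        # Compare the datapoints from the fresh encoding with those produced by the reloaded encoder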
        orig_datapoints = []
        for file in os.listdir(self.output_dataset_dir):
            if file not in ['VarNamingGSCVocabDataEncoder.pkl', 'VarNamingTask.pkl', 're-encoding']:
                with open(os.path.join(self.output_dataset_dir, file), 'rb') as f:
                    dp = pickle.load(f)
                    self.assertNotIn('SUBTOKEN_USE', dp.edges.keys())
                    self.assertNotIn('reverse_SUBTOKEN_USE', dp.edges.keys())
                    self.assertCountEqual(dp.edges.keys(),
                                          list(all_edge_types) + ['reverse_{}'.format(i) for i in all_edge_types],
                                          'We lost some edge types')
                    orig_datapoints.append(
                        (dp.real_variable_name, dp.origin_file, dp.encoder_hash, dp.edges.keys()))

        reencoded_datapoints = []
        for file in os.listdir(reencoding_dir):
            with open(os.path.join(reencoding_dir, file), 'rb') as f:
                dp = pickle.load(f)
                self.assertNotIn('SUBTOKEN_USE', dp.edges.keys())
                self.assertNotIn('reverse_SUBTOKEN_USE', dp.edges.keys())
                reencoded_datapoints.append(
                    (dp.real_variable_name, dp.origin_file, dp.encoder_hash, dp.edges.keys()))
        self.assertEqual(len(orig_datapoints), len(reencoded_datapoints))
        self.assertCountEqual(orig_datapoints, reencoded_datapoints)
    def test_preprocess_task_existing_encoding_basic_functionality(self):
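        """Preprocess with a fresh DataEncoder, reload it from disk, preprocess again with the
        loaded encoder, and confirm a DataEncoder of the wrong type is rejected."""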
        VarNamingGSCVocab.preprocess_task(
            self.task,
            output_dir=self.output_dataset_dir,
            n_jobs=30,
            data_encoder='new',
            data_encoder_kwargs=dict(
                max_name_encoding_length=self.max_name_encoding_length),
            instance_to_datapoints_kwargs=dict(max_nodes_per_graph=20))
        de = VarNamingGSCVocabDataEncoder.load(
            os.path.join(
                self.output_dataset_dir,
                '{}.pkl'.format(VarNamingGSCVocabDataEncoder.__name__)))
        VarNamingGSCVocab.preprocess_task(
            self.task,
            output_dir=self.output_dataset_dir,
            n_jobs=30,
            data_encoder=de,
            data_encoder_kwargs=dict(
                excluded_edge_types=syntax_only_excluded_edge_types,
                max_name_encoding_length=self.max_name_encoding_length))
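        # A DataEncoder of the wrong type should trip preprocess_task's assertion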
        with self.assertRaises(AssertionError):
            de = BaseDataEncoder(dict(), frozenset())
            VarNamingGSCVocab.preprocess_task(
                self.task,
                output_dir=self.output_dataset_dir,
                n_jobs=30,
                data_encoder=de,
                data_encoder_kwargs=dict(
                    excluded_edge_types=syntax_only_excluded_edge_types,
                    max_name_encoding_length=self.max_name_encoding_length))
    def test_batchify_and_unbatchify_are_inverses(self):
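        """Batchify preprocessed datapoints, run the model, and check that unbatchify recovers
        one (prediction, label) pair per datapoint with consistent shapes and labels."""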
        VarNamingGSCVocab.preprocess_task(
            self.task,
            output_dir=self.output_dataset_dir,
            n_jobs=30,
            data_encoder='new',
            data_encoder_kwargs=dict(
                max_name_encoding_length=self.max_name_encoding_length),
            instance_to_datapoints_kwargs=dict(max_nodes_per_graph=100))
        with open(
                os.path.join(
                    self.output_dataset_dir,
                    '{}.pkl'.format(VarNamingGSCVocab.DataEncoder.__name__)),
                'rb') as f:
            de = pickle.load(f)
        model = VarNamingGSCVocabGGNN(data_encoder=de,
                                      hidden_size=17,
                                      type_emb_size=5,
                                      name_emb_size=7,
                                      n_msg_pass_iters=1,
                                      max_name_length=8)
        model.collect_params().initialize('Xavier', ctx=mx.cpu())
        datapoints = [
            os.path.join(self.output_dataset_dir, i)
            for i in os.listdir(self.output_dataset_dir)
            if 'Encoder.pkl' not in i
        ]
        batch_size = 64
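        # Walk the datapoints in batches and check the batchify -> model -> unbatchify round trip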
        for b in tqdm(range(int(math.ceil(len(datapoints) / batch_size)))):
            batchdpspaths = datapoints[batch_size * b:batch_size * (b + 1)]
            batchdps = [de.load_datapoint(b) for b in batchdpspaths]
            batchified = model.batchify(batchdpspaths, ctx=mx.cpu())
            self.assertTrue(
                batchified.data.graph_vocab_node_locations is not None)
            self.assertEqual(
                len(batchified.data.graph_vocab_node_locations),
                sum(batchified.data.node_types.values[:, 0] ==
                    de.all_node_types[de.subtoken_flag]))
            self.assertEqual(type(batchified.label[0]), PaddedArray)
            self.assertEqual(batchified.label[0].values.shape,
                             (len(batchified.data.batch_sizes), 8))
            for dp, b_label in zip(batchdps, batchified.label[0].values):
                vocab_label, attn_label = dp.label
                real_variable_name = de.name_to_subtokens(
                    dp.real_variable_name)
                subtoken_nodes_this_dp = [
                    i for i in range(len(dp.node_types))
                    if dp.node_types[i] == (
                        de.all_node_types[de.subtoken_flag], )
                ]
                for i in range(len(attn_label)):
                    if attn_label[i] != -1:
                        self.assertEqual(
                            b_label[i].asscalar(),
                            subtoken_nodes_this_dp.index(attn_label[i]) +
                            len(de.all_node_name_subtokens),
                            "Batch label for subtoken node is off")
                    else:
                        self.assertLess(b_label[i].asscalar(),
                                        len(de.all_node_name_subtokens),
                                        "Batch label for vocab word is off")
                        self.assertEqual(
                            de.rev_all_node_name_subtokens[
                                b_label[i].asscalar()], real_variable_name[i])

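            # Forward pass: output should be (batch, max_name_length, at least the subtoken vocab size)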
            model_output = model(batchified.data)
            self.assertEqual(len(model_output.shape), 3,
                             "model_output is the wrong size")
            self.assertEqual(model_output.shape[0],
                             len(batchified.data.batch_sizes),
                             "model_output has wrong batch dimension")
            self.assertEqual(model_output.shape[1], model.max_name_length,
                             "model_output is outputting wrong length names")
            self.assertGreaterEqual(model_output.shape[2],
                                    len(de.all_node_name_subtokens),
                                    "model_output's output dimension is off")
            graph_vocab_nodes_per_batch_element = []
            length = 0
            for l in batchified.data.batch_sizes.asnumpy():
                graph_vocab_nodes_this_element = [
                    loc for loc in batchified.data.graph_vocab_node_locations
                    if length <= loc < length + l
                ]
                graph_vocab_nodes_per_batch_element.append(
                    len(graph_vocab_nodes_this_element))
                length += l
            graph_vocab_nodes_per_batch_element = mx.nd.array(
                graph_vocab_nodes_per_batch_element,
                dtype='float32',
                ctx=mx.cpu())
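            # Mask everything beyond the subtoken vocab plus this element's graph vocab slots;
            # the remaining probabilities should sum to 1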
            masked_model_output = mx.nd.SequenceMask(
                model_output.exp().swapaxes(1, 2),
                use_sequence_length=True,
                sequence_length=len(de.all_node_name_subtokens) +
                graph_vocab_nodes_per_batch_element,
                axis=1)
            self.assertAlmostEqual(
                (masked_model_output.sum(axis=1) - 1).sum().asscalar(), 0, 3,
                "Probabilities aren't summing to 1")
            unbatchified = model.unbatchify(batchified, model_output)
            self.assertEqual(len(batchdps), len(unbatchified),
                             "We lost some datapoints somewhere")
            self.assertEqual(sum(len(dp.node_names) for dp in batchdps),
                             sum(batchified.data.batch_sizes).asscalar())
            self.assertEqual(sum(len(dp.node_types) for dp in batchdps),
                             sum(batchified.data.batch_sizes).asscalar())
            self.assertEqual(
                len(batchified.data.target_locations),
                sum([dp.node_names.count('__NAME_ME!__') for dp in batchdps]),
                "Some target location went missing")
            for adj_mat in batchified.data.edges.values():
                self.assertEqual(adj_mat.shape,
                                 (sum(len(dp.node_names) for dp in batchdps),
                                  sum(len(dp.node_names) for dp in batchdps)),
                                 "Batchified adjacency matrix is wrong size")
            for i, (dp, (prediction,
                         label)) in enumerate(zip(batchdps, unbatchified)):
                for p in prediction:
                    self.assertIn(
                        p, de.all_node_name_subtokens.keys(),
                        "Some word in the prediction wasn't in the model's vocab (normally that's the point, but this is the training set)"
                    )
                self.assertEqual(
                    len(dp.node_types), len(dp.node_names),
                    "node_types and node_names arrays are different lengths")
                self.assertEqual(
                    len(dp.node_types), batchified.data.batch_sizes[i],
                    "batch_sizes doesn't match datapoint's array size")
                self.assertEqual(de.name_to_subtokens(dp.real_variable_name),
                                 label, "Something got labeled wrong")
    def test_instance_to_datapoint(self):
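        """extra_graph_processing should add the expected subtoken nodes and edges, and
        instance_to_datapoint should mask the right nodes and produce consistent DataPoints."""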
        for excluded_edge_types in [
                syntax_only_excluded_edge_types,
                frozenset()
        ]:
            de = VarNamingGSCVocab.DataEncoder(
                self.task.graphs_and_instances,
                excluded_edge_types=excluded_edge_types,
                instance_to_datapoints_kwargs=dict(),
                max_name_encoding_length=self.max_name_encoding_length)
            for graph, instances in tqdm(self.task.graphs_and_instances):
                VarNamingGSCVocab.fix_up_edges(graph, instances,
                                               excluded_edge_types)
                VarNamingGSCVocab.extra_graph_processing(graph, instances, de)
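                # extra_graph_processing should have added one subtoken node per distinct name subtoken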
                node_names = []
                for _, data in graph.nodes_that_represent_variables:
                    node_names += de.name_to_subtokens(data['identifier'])
                node_names = set(node_names)
                subtoken_nodes = [
                    i for i, data in graph.nodes
                    if data['type'] == de.subtoken_flag
                ]
                self.assertCountEqual(
                    node_names,
                    set([graph[i]['identifier'] for i in subtoken_nodes]),
                    "There isn't a subtoken node for each word in the graph")
                for node in subtoken_nodes:
                    self.assertFalse(
                        graph.is_variable_node(node),
                        "Subtoken node got flagged as a variable node")
                    self.assertEqual(graph[node]['type'], de.subtoken_flag,
                                     "Subtoken node got the wrong type")
                for node, data in graph.nodes:
                    if graph.is_variable_node(node):
                        node_names = de.name_to_subtokens(data['identifier'])
                        subtoken_nodes = graph.successors(
                            node, of_type=frozenset([de.subtoken_edge_type]))
                        back_subtoken_nodes = graph.predecessors(
                            node,
                            of_type=frozenset(
                                ['reverse_' + de.subtoken_edge_type]))
                        self.assertCountEqual(
                            subtoken_nodes, back_subtoken_nodes,
                            "Forward and reverse subtoken edges don't connect to the same subtoken nodes"
                        )
                        self.assertCountEqual(
                            set(node_names), [
                                graph.nodes[d]['identifier']
                                for d in subtoken_nodes
                            ],
                            "Node wasn't connected to all the right subtoken nodes"
                        )
                for instance in instances:
                    dp = VarNamingGSCVocab.instance_to_datapoint(
                        graph, instance, de, max_nodes_per_graph=100)
                    self.assertEqual(type(dp), VarNamingGSCVocabDataPoint)
                    self.assertEqual(len(dp.subgraph.nodes),
                                     len(dp.node_types))
                    self.assertEqual(len(dp.subgraph.nodes),
                                     len(dp.node_names))

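                    # Exactly the variable nodes listed in the instance should be masked with the name-me flag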
                    name_me_nodes = [
                        i for i in dp.subgraph.nodes_that_represent_variables
                        if i[1]['identifier'] == de.name_me_flag
                    ]
                    self.assertTrue(
                        all(
                            dp.subgraph.is_variable_node(i[0])
                            for i in name_me_nodes),
                        "Some non-variable got masked")
                    self.assertEqual(
                        len([i[0] for i in name_me_nodes]), len(instance[1]),
                        "Wrong number of variables got their names masked")
                    self.assertEqual(
                        1, len(set([i[1]['text'] for i in name_me_nodes])),
                        "Not all name-masked nodes contain the same name")
                    self.assertTrue(
                        all([
                            i[1]['text'] == dp.real_variable_name
                            for i in name_me_nodes
                        ]), "Some nodes have the wrong name")

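                    # Edge types that would give away the name must not touch the masked nodes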
                    for node, _ in name_me_nodes:
                        for et in too_useful_edge_types:
                            self.assertNotIn(et, [
                                e[3]['type']
                                for e in dp.subgraph.all_adjacent_edges(node)
                            ])

                    for i, (name, types) in enumerate(
                            zip(dp.node_names, dp.node_types)):
                        self.assertEqual(type(name), str)
                        self.assertGreater(len(name), 0)
                        self.assertEqual(type(types), list)
                        self.assertGreaterEqual(len(types), 1)
                        if dp.subgraph.is_variable_node(i):
                            self.assertCountEqual(
                                set(
                                    re.split(r'[,.]',
                                             dp.subgraph[i]['reference'])),
                                types)
                            self.assertEqual(name,
                                             dp.subgraph[i]['identifier'])
                        else:
                            if types == [de.subtoken_flag]:
                                self.assertEqual(dp.subgraph[i]['identifier'],
                                                 name)
                            else:
                                self.assertEqual(name, de.internal_node_flag)
                                self.assertEqual(len(types), 1)

                    self.assertEqual(
                        dp.label,
                        de.name_to_subtokens(name_me_nodes[0][1]['text']),
                        "Label is wrong")

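                    # Encoding in place should keep the core syntactic edge types, and the
                    # datapoint should be saveable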
                    de.encode(dp)
                    self.assertIn('AST', dp.edges.keys())
                    self.assertIn('NEXT_TOKEN', dp.edges.keys())
                    de.save_datapoint(dp, self.output_dataset_dir)
    def test_preprocess_task_type_check_basic_functionality(self):
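        """Passing something that isn't a proper VarNamingTask instance should trip an assertion."""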
        task = Task
        with self.assertRaises(AssertionError):
            VarNamingGSCVocab.preprocess_task(task)