Example #1
    def test_subtoken_embedding(self):
        fix_seed()
        device = torch.device('cpu')
        h_emb = 5
        token_to_id = {
            'token|name|first': 0,
            'token|second': 1,
            'token|third|name': 2
        }
        g = dgl.DGLGraph()
        g.add_nodes(3, {'token_id': torch.tensor([0, 1, 2])})
        subtoken_embedding = SubTokenNodeEmbedding(token_to_id, {}, h_emb)

        embed_weight = torch.zeros(len(subtoken_embedding.token_to_id), h_emb)
        embed_weight[subtoken_embedding.token_to_id['token'], 0] = 1
        embed_weight[subtoken_embedding.token_to_id['name'], 1] = 1
        embed_weight[subtoken_embedding.token_to_id['first'], 2] = 1
        embed_weight[subtoken_embedding.token_to_id['second'], 3] = 1
        embed_weight[subtoken_embedding.token_to_id['third'], 4] = 1

        subtoken_embedding.subtoken_embedding.weight = torch.nn.Parameter(
            embed_weight, requires_grad=True)

        token_embeds = subtoken_embedding(g)
        true_embeds = torch.tensor(
            [[1, 1, 1, 0, 0], [1, 0, 0, 1, 0], [1, 1, 0, 0, 1]],
            device=device,
            dtype=torch.float)

        self.assertTrue(torch.allclose(true_embeds, token_embeds))
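
# A minimal sketch of the behaviour the test above pins down: the embedding of a
# full token such as 'token|name|first' is the sum of its subtoken embeddings.
# The class internals below (vocabulary construction, attribute names) are
# assumptions for illustration, not the actual SubTokenNodeEmbedding code.
class SubTokenSumSketch(torch.nn.Module):
    def __init__(self, full_token_to_id: dict, h_emb: int):
        super().__init__()
        self.id_to_full_token = {v: k for k, v in full_token_to_id.items()}
        subtokens = sorted({st for token in full_token_to_id for st in token.split('|')})
        self.token_to_id = {st: i for i, st in enumerate(subtokens)}
        self.subtoken_embedding = torch.nn.Embedding(len(self.token_to_id), h_emb)

    def forward(self, g: dgl.DGLGraph) -> torch.Tensor:
        embeds = []
        for token_id in g.ndata['token_id'].tolist():
            # split the full token on '|' and sum the subtoken vectors
            subtoken_ids = torch.tensor(
                [self.token_to_id[st] for st in self.id_to_full_token[token_id].split('|')])
            embeds.append(self.subtoken_embedding(subtoken_ids).sum(dim=0))
        return torch.stack(embeds)
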
    def test_type_specific_tree_lstm_cell(self):
        device = get_device()
        fix_seed()

        for i, (x_size, h_size, number_of_children) in enumerate(
                zip(self.x_sizes, self.h_sizes, self.numbers_of_children)):
            with self.subTest(i=i):
                g = _gen_node_with_children(number_of_children)
                g.ndata['x'] = torch.rand(number_of_children + 1, x_size)
                g.ndata['type_id'] = torch.arange(number_of_children + 1)
                type_relationship = {
                    (0,): [list(range(1, number_of_children // 2))]
                }

                tree_lstm_cell = TypeSpecificTreeLSTMCell(x_size, h_size, type_relationship)

                h_tree_lstm, c_tree_lstm = tree_lstm_cell(g, device)

                tree_lstm_cell_params = tree_lstm_cell.get_params()
                u_f_indices = [
                    tree_lstm_cell.edge_matrix_id.get((0, child), 0) for child in range(1, number_of_children + 1)
                ]
                tree_lstm_cell_params['u_f'] = tree_lstm_cell_params['u_f'][u_f_indices]
                h_calculated, c_calculated = _calculate_nary_tree_lstm_states(g.ndata['x'], **tree_lstm_cell_params)

                self.assertTrue(
                    torch.allclose(h_tree_lstm, h_calculated, atol=ATOL), msg="Unequal hidden state tensors"
                )
                self.assertTrue(
                    torch.allclose(c_tree_lstm, c_calculated, atol=ATOL), msg="Unequal memory state tensors"
                )
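
# _gen_node_with_children is a fixture defined elsewhere in this test module;
# a sketch consistent with how the tests use it (node 0 is the root, edges run
# child -> root, the direction the tree LSTM cells propagate messages in):
def _gen_node_with_children_sketch(number_of_children: int) -> dgl.DGLGraph:
    g = dgl.DGLGraph()
    g.add_nodes(number_of_children + 1)
    children = list(range(1, number_of_children + 1))
    g.add_edges(children, [0] * number_of_children)
    return g
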
    def test_childsum_tree_lstm_batch(self):
        device = get_device()
        fix_seed()

        x_size = 5
        h_size = 5
        numbers_of_children = [7, 7]

        tree_lstm_types = [EdgeChildSumTreeLSTMCell, NodeChildSumTreeLSTMCell]
        for tree_lstm_type in tree_lstm_types:
            with self.subTest(msg=f"test {tree_lstm_type.__name__} tree lstm cell"):
                tree_lstm_cell = tree_lstm_type(x_size, h_size)

                g1 = _gen_node_with_children(numbers_of_children[0])
                g2 = _gen_node_with_children(numbers_of_children[1])
                g1.ndata['x'] = torch.rand(numbers_of_children[0] + 1, x_size)
                g2.ndata['x'] = torch.rand(numbers_of_children[1] + 1, x_size)
                g = dgl.batch([g1, g2])

                h_tree_lstm, c_tree_lstm = tree_lstm_cell(g, device)

                h1_calculated, c1_calculated = _calculate_childsum_tree_lstm_states(
                    g1.ndata['x'], **tree_lstm_cell.get_params()
                )
                h2_calculated, c2_calculated = _calculate_childsum_tree_lstm_states(
                    g2.ndata['x'], **tree_lstm_cell.get_params()
                )
                h_calculated = torch.cat([h1_calculated, h2_calculated], 0)
                c_calculated = torch.cat([c1_calculated, c2_calculated], 0)

                self.assertTrue(
                    torch.allclose(h_tree_lstm, h_calculated, atol=ATOL), msg="Unequal hidden state tensors"
                )
                self.assertTrue(
                    torch.allclose(c_tree_lstm, c_calculated, atol=ATOL), msg="Unequal memory state tensors"
                )
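
# _calculate_childsum_tree_lstm_states re-derives the Child-Sum Tree-LSTM
# equations of Tai et al. (2015). A sketch for the one-root-many-leaves graphs
# used above; the per-gate parameter names are an assumed layout of
# get_params(), not necessarily the repo's exact keyword names.
def _childsum_states_sketch(x, w_i, u_i, b_i, w_f, u_f, b_f, w_o, u_o, b_o, w_u, u_u, b_u):
    # leaves (x[1:]) have no children, so every child-sum term is zero
    i = torch.sigmoid(x[1:] @ w_i.t() + b_i)
    o = torch.sigmoid(x[1:] @ w_o.t() + b_o)
    u = torch.tanh(x[1:] @ w_u.t() + b_u)
    c_leaf = i * u
    h_leaf = o * torch.tanh(c_leaf)
    # the root gates use the sum of children hidden states,
    # but a separate forget gate is computed per child
    h_sum = h_leaf.sum(dim=0)
    i_root = torch.sigmoid(x[0] @ w_i.t() + h_sum @ u_i.t() + b_i)
    o_root = torch.sigmoid(x[0] @ w_o.t() + h_sum @ u_o.t() + b_o)
    u_root = torch.tanh(x[0] @ w_u.t() + h_sum @ u_u.t() + b_u)
    f = torch.sigmoid(x[0] @ w_f.t() + h_leaf @ u_f.t() + b_f)
    c_root = i_root * u_root + (f * c_leaf).sum(dim=0)
    h_root = o_root * torch.tanh(c_root)
    h = torch.cat([h_root.unsqueeze(0), h_leaf], dim=0)
    c = torch.cat([c_root.unsqueeze(0), c_leaf], dim=0)
    return h, c
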
def interactive(path_to_function: str, path_to_model: str):
    fix_seed()
    device = get_device()
    print(f"using {device} device")

    # convert function to dot format
    print(f"prepare ast...")
    create_folder(TMP_FOLDER)
    if not build_ast(path_to_function):
        return
    ast_folder = os.path.join(TMP_FOLDER, 'java', 'asts')
    ast_files = os.listdir(ast_folder)
    if len(ast_files) == 0:
        print("didn't find any functions in the given file")
        return
    if len(ast_files) > 1:
        print(
            "too many functions in the given file; interactive prediction needs exactly one"
        )
        return
    dgl_ast = convert_dot_to_dgl(os.path.join(ast_folder, ast_files[0]))
    ast_desc = pd.read_csv(os.path.join(TMP_FOLDER, 'java', 'description.csv'))
    ast_desc['token'] = ast_desc['token'].fillna('NAN')
    with open(vocab_path, 'rb') as pkl_file:
        vocab = pkl_load(pkl_file)
        token_to_id, type_to_id = vocab['token_to_id'], vocab['type_to_id']
    ast_desc = transform_keys(ast_desc, token_to_id, type_to_id)
    batched_graph, labels, paths = prepare_batch(ast_desc, ['ast_0.dot'],
                                                 lambda: [dgl_ast])
    batched_graph = dgl.batch(
        list(
            map(lambda g: dgl.reverse(g, share_ndata=True),
                dgl.unbatch(batched_graph))))

    # load model
    print("loading model..")
    model, _ = load_model(path_to_model, device)
    criterion = nn.CrossEntropyLoss(
        ignore_index=model.decoder.pad_index).to(device)
    info = LearningInfo()

    print("forward pass...")
    batch_info, prediction = eval_on_batch(model, criterion, batched_graph,
                                           labels, device)

    info.accumulate_info(batch_info)
    id_to_sublabel = {v: k for k, v in model.decoder.label_to_id.items()}
    label = ''
    for cur_sublabel in prediction:
        if cur_sublabel.item() == model.decoder.label_to_id[EOS]:
            break
        label += '|' + id_to_sublabel[cur_sublabel.item()]
    label = label[1:]
    print(f"Predicted function name is\n{label}")
    print(
        f"Calculated metrics with respect to '{labels[0]}' name\n{info.get_state_dict()}"
    )

    def _test_childsum_tree_lstm_cell(self, tree_lstm_type):
        device = get_device()
        fix_seed()
        for i in range(len(self.x_sizes)):
            x_size, h_size, number_of_children = self.x_sizes[i], self.h_sizes[i], self.numbers_of_children[i]
            with self.subTest(i=i):
                h_equal, c_equal = _test_childsum(
                    tree_lstm_type, x_size, h_size, number_of_children, device
                )
                self.assertTrue(
                    h_equal, msg=f"Unequal hidden state tensors for ({x_size}, {h_size}, {number_of_children}) params"
                )
                self.assertTrue(
                    c_equal, msg=f"Unequal memory state tensors for ({x_size}, {h_size}, {number_of_children}) params"
                )
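
# _test_childsum is assumed to mirror the batch test above for a single graph
# and return the two comparison flags; a sketch under that assumption:
def _test_childsum_sketch(tree_lstm_type, x_size, h_size, number_of_children, device):
    tree_lstm_cell = tree_lstm_type(x_size, h_size)
    g = _gen_node_with_children(number_of_children)
    g.ndata['x'] = torch.rand(number_of_children + 1, x_size)
    h_cell, c_cell = tree_lstm_cell(g, device)
    h_ref, c_ref = _calculate_childsum_tree_lstm_states(g.ndata['x'], **tree_lstm_cell.get_params())
    return torch.allclose(h_cell, h_ref, atol=ATOL), torch.allclose(c_cell, c_ref, atol=ATOL)
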
Example #6
def interactive(path_to_function: str, path_to_model: str):
    fix_seed()
    device = get_device()
    print(f"using {device} device")

    # load model
    print("loading model...")
    checkpoint = torch.load(path_to_model, map_location=device)

    model = Tree2Seq(**checkpoint['configuration']).to(device)
    model.load_state_dict(checkpoint['state_dict'])

    token_to_id = model.token_to_id
    type_to_id = model.type_to_id
    label_to_id = model.label_to_id
    id_to_label = {v: k for k, v in label_to_id.items()}

    # convert function to dgl format
    print("convert function to dgl format...")
    create_folder(TMP_FOLDER)
    build_asts(path_to_function, TMP_FOLDER, ASTMINER_PATH, *ASTMINER_PARAMS)
    project_folder = os.path.join(TMP_FOLDER, 'java')
    convert_project(project_folder, token_to_id, type_to_id, label_to_id, True,
                    True, 5, 6, False, True, '|')

    # load function
    graph, labels = load_graphs(os.path.join(project_folder, 'converted.dgl'))
    labels = labels['labels']
    assert len(labels) == 1, f"found {len(labels)} functions instead of 1"
    ast = graph[0].reverse(share_ndata=True)
    ast.ndata['token'] = ast.ndata['token'].to(device)
    ast.ndata['type'] = ast.ndata['type'].to(device)
    labels = labels.t().to(device)
    root_indexes = torch.tensor([0], dtype=torch.long)

    # forward pass
    model.eval()
    with torch.no_grad():
        logits = model(ast, root_indexes, labels, device)
    logits = logits[1:]
    prediction = model.predict(logits).reshape(-1)
    sublabels = [id_to_label[label_id.item()] for label_id in prediction]
    label = '|'.join(takewhile(lambda sl: sl != EOS, sublabels))
    print(f"the predicted label is:\n{label}")
Example #7
def evaluate(params: Dict) -> None:
    fix_seed()
    device = get_device()
    print(f"using {device} device")

    evaluation_set = JavaDataset(params['paths']['evaluate'],
                                 params['batch_size'], True)

    model, _ = load_model(params['paths']['model'], device)

    # define loss function
    criterion = nn.CrossEntropyLoss(
        ignore_index=model.decoder.pad_index).to(device)

    # evaluation loop
    print("ok, let's evaluate it")
    eval_epoch_info = evaluate_dataset(evaluation_set, model, criterion,
                                       device)

    print(eval_epoch_info.get_state_dict())

def evaluate(params: Dict) -> None:
    fix_seed()
    device = get_device()
    print(f"using {device} device")

    checkpoint = torch.load(params['model'], map_location=device)

    print('model initializing...')
    # create model
    model = Tree2Seq(**checkpoint['configuration']).to(device)
    model.load_state_dict(checkpoint['state_dict'])

    evaluation_set = TreeDGLDataset(params['dataset'], params['batch_size'], device, True)

    # define loss function
    criterion = nn.CrossEntropyLoss(ignore_index=model.label_to_id[PAD]).to(device)

    # evaluation loop
    print("ok, let's evaluate it")
    eval_epoch_info = evaluate_on_dataset(evaluation_set, model, criterion)

    print(eval_epoch_info.get_state_dict())
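
# evaluate_on_dataset is assumed to be the usual no-grad loop accumulating
# per-batch metrics; a sketch with assumed helper signatures (eval_on_batch and
# LearningInfo mirror their usage in the interactive script above):
def evaluate_on_dataset_sketch(dataset, model, criterion):
    info = LearningInfo()
    model.eval()
    with torch.no_grad():
        for graph, labels in dataset:
            batch_info, _ = eval_on_batch(model, criterion, graph, labels)
            info.accumulate_info(batch_info)
    return info
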
Example #9
    def test_transformer_encoder_forward_pass(self):
        fix_seed()
        device = torch.device('cpu')

        number_of_children = [3, 5, 128, 256]
        hidden_state = [5, 10, 128, 256]
        n_heads = [1, 2, 16, 32]

        for n_children, h_emb, n_head in zip(number_of_children, hidden_state,
                                             n_heads):
            with self.subTest(
                    f"test transformer encoder with params: {n_children}, {h_emb}, {n_head}"
            ):
                g = generate_node_with_children(n_children)
                x = torch.rand(n_children + 1, h_emb, device=device)
                g.ndata['x'] = x

                my_model = TransformerEncoder(h_emb, h_emb, n_head)
                transformer_layer = torch.nn.TransformerEncoderLayer(
                    h_emb, n_head)
                transformer = torch.nn.TransformerEncoder(transformer_layer, 1)
                my_model.eval()
                transformer.eval()

                my_state_dict = my_model.state_dict()
                state_dict = {
                    layer_name: my_state_dict[f'transformer.{layer_name}']
                    for layer_name in transformer.state_dict()
                }
                transformer.load_state_dict(state_dict)

                my_result = my_model(g)

                transformer_result = torch.empty_like(my_result)
                transformer_result[1:] = my_model.norm(x[1:])
                h_root = transformer(
                    transformer_result[1:].unsqueeze(1)).transpose(0, 1).sum(1)
                transformer_result[0] = my_model.norm(x[0] + h_root)

                self.assertTrue(transformer_result.allclose(my_result))
Example #10
    def test_positional_embedding(self):
        fix_seed()
        device = torch.device('cpu')

        g = generate_tree(3, 3)
        g.ndata['x'] = torch.randn((13, 6), device=device)
        positional_embedding = PositionalEmbedding({}, {}, 6, 3, 2)
        pos_embeds = positional_embedding(g)

        correct_pos_embedding = torch.tensor([[0., 0., 0., 0., 0., 0.],
                                              [1., 0., 0., 0., 0., 0.],
                                              [0., 1., 0., 0., 0., 0.],
                                              [0., 0., 1., 0., 0., 0.],
                                              [1., 0., 0., 1., 0., 0.],
                                              [0., 1., 0., 1., 0., 0.],
                                              [0., 0., 1., 1., 0., 0.],
                                              [1., 0., 0., 0., 1., 0.],
                                              [0., 1., 0., 0., 1., 0.],
                                              [0., 0., 1., 0., 1., 0.],
                                              [1., 0., 0., 0., 0., 1.],
                                              [0., 1., 0., 0., 0., 1.],
                                              [0., 0., 1., 0., 0., 1.]])

        self.assertTrue(torch.allclose(correct_pos_embedding, pos_embeds))
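
# generate_tree is a fixture defined elsewhere; a sketch that reproduces the
# 13-node shape asserted above (a full 3-ary tree of height 3: 1 + 3 + 9 nodes,
# numbered level by level, edges child -> parent):
def generate_tree_sketch(height: int, n_children: int) -> dgl.DGLGraph:
    g = dgl.DGLGraph()
    g.add_nodes(1)
    parents = [0]
    for _ in range(height - 1):
        next_level = []
        for parent in parents:
            first_child = g.number_of_nodes()
            g.add_nodes(n_children)
            children = list(range(first_child, first_child + n_children))
            g.add_edges(children, [parent] * n_children)
            next_level.extend(children)
        parents = next_level
    return g
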
Example #11
def train(params: Dict, logging: str) -> None:
    fix_seed()
    device = get_device()
    print(f"using {device} device")

    training_set = JavaDataset(params['paths']['train'], params['batch_size'],
                               True)
    validation_set = JavaDataset(params['paths']['validate'],
                                 params['batch_size'], True)

    with open(params['paths']['vocabulary'], 'rb') as pkl_file:
        vocabulary = pkl_load(pkl_file)
        token_to_id = vocabulary['token_to_id']
        type_to_id = vocabulary['type_to_id']
        label_to_id = vocabulary['label_to_id']

    print('model initializing...')
    is_resumed = 'resume' in params
    if is_resumed:
        # load model
        model, checkpoint = load_model(params['resume'], device)
        start_batch_id = checkpoint['batch_id'] + 1
        configuration = checkpoint['configuration']
    else:
        # create model
        model_factory = ModelFactory(params['embedding'], params['encoder'],
                                     params['decoder'],
                                     params['hidden_states'], token_to_id,
                                     type_to_id, label_to_id)
        model: Tree2Seq = model_factory.construct_model(device)
        configuration = model_factory.save_configuration()
        start_batch_id = 0

    # create optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=params['lr'],
                                 weight_decay=params['weight_decay'])
    # create scheduler
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=params['scheduler_step_size'],
        gamma=params['scheduler_gamma'])

    # define loss function
    criterion = nn.CrossEntropyLoss(
        ignore_index=model.decoder.pad_index).to(device)

    # init logging class
    logger = None
    if logging == TerminalLogger.name:
        logger = TerminalLogger(params['checkpoints_folder'])
    elif logging == FileLogger.name:
        logger = FileLogger(params, params['logging_folder'],
                            params['checkpoints_folder'])
    elif logging == WandBLogger.name:
        logger_args = ['treeLSTM', params, model, params['checkpoints_folder']]
        if 'resume_wandb_id' in params:
            logger_args.append(params['resume_wandb_id'])
        logger = WandBLogger(*logger_args)

    # train loop
    print("ok, let's train it")
    for epoch in range(params['n_epochs']):
        train_acc_info = LearningInfo()

        if epoch > 0:
            # the resume offset applies only to the first epoch
            start_batch_id = 0
        tqdm_batch_iterator = tqdm(range(start_batch_id, len(training_set)),
                                   total=len(training_set))
        tqdm_batch_iterator.update(start_batch_id)
        tqdm_batch_iterator.refresh()

        # iterate over training set
        for batch_id in tqdm_batch_iterator:
            graph, labels = training_set[batch_id]
            graph.ndata['token_id'] = graph.ndata['token_id'].to(device)
            graph.ndata['type_id'] = graph.ndata['type_id'].to(device)
            batch_info = train_on_batch(model, criterion, optimizer, scheduler,
                                        graph, labels, params, device)
            train_acc_info.accumulate_info(batch_info)
            # log current train process
            if is_current_step_match(batch_id, params['logging_step']):
                logger.log(train_acc_info.get_state_dict(), epoch, batch_id)
                train_acc_info = LearningInfo()
            # validate current model
            if is_current_step_match(
                    batch_id, params['evaluation_step']) and batch_id != 0:
                eval_epoch_info = evaluate_dataset(validation_set, model,
                                                   criterion, device)
                logger.log(eval_epoch_info.get_state_dict(), epoch, batch_id,
                           False)
            # save current model
            if is_current_step_match(
                    batch_id, params['checkpoint_step']) and batch_id != 0:
                logger.save_model(model,
                                  f'epoch_{epoch}_batch_{batch_id}.pt',
                                  configuration,
                                  batch_id=batch_id)

        logger.log(train_acc_info.get_state_dict(), epoch, len(training_set))
        eval_epoch_info = evaluate_dataset(validation_set, model, criterion,
                                           device)
        logger.log(eval_epoch_info.get_state_dict(), epoch, len(training_set),
                   False)

        logger.save_model(model, f'epoch_{epoch}.pt', configuration)
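
# is_current_step_match is assumed to be a plain modulo check (the explicit
# `batch_id != 0` guards above support that reading); a sketch:
def is_current_step_match_sketch(current_step: int, step_size: int) -> bool:
    return step_size > 0 and current_step % step_size == 0
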
Example #12
def main(args: Namespace) -> None:
    fix_seed()
    if args.dataset not in known_datasets:
        raise ValueError(f"Unknown dataset: {args.dataset}")
    dataset_info = known_datasets[args.dataset]()
    dataset_path = os.path.join(DATA_FOLDER, dataset_info.name)
    vocabulary_path = os.path.join(dataset_path, VOCABULARY_NAME)
    create_folder(dataset_path, is_clean=False)

    if args.download:
        download_dataset(dataset_info, dataset_path)

    if args.build_ast:
        if not all(
                os.path.exists(os.path.join(dataset_path, holdout))
                for holdout in dataset_info.holdout_folders
        ):
            raise RuntimeError("download and extract data before building ast")
        if not os.path.exists(ASTMINER_PATH):
            raise RuntimeError(
                f"can't find astminer-cli at {ASTMINER_PATH}")
        build_dataset_asts(dataset_info, dataset_path, ASTMINER_PATH)

    if args.collect_vocabulary:
        train_asts = os.path.join(dataset_path,
                                  f'{dataset_info.holdout_folders[0]}_asts')
        if not os.path.exists(train_asts):
            raise RuntimeError(
                "build training asts before collecting vocabulary")
        collect_vocabulary(train_asts, vocabulary_path, args.n_tokens,
                           args.n_types, args.n_labels, args.split_vocabulary,
                           args.wrap_tokens, args.wrap_labels, '|')

    if args.convert:
        if not os.path.exists(vocabulary_path):
            raise RuntimeError(
                "collect vocabulary before converting data to DGL format")
        with open(vocabulary_path, 'rb') as pkl_file:
            vocab = pickle_load(pkl_file)
        token_to_id = vocab['token_to_id']
        type_to_id = vocab['type_to_id']
        label_to_id = vocab['label_to_id']
        for holdout in dataset_info.holdout_folders:
            ast_folder = os.path.join(dataset_path, f'{holdout}_asts')
            if not os.path.exists(ast_folder):
                raise RuntimeError(
                    f"build asts for {holdout} before converting it to DGL format"
                )
            output_folder = os.path.join(dataset_path,
                                         f'{holdout}_preprocessed')
            create_folder(output_folder)
            convert_holdout(ast_folder, output_folder, args.batch_size,
                            token_to_id, type_to_id, label_to_id,
                            args.tokens_to_leaves, args.split_vocabulary,
                            args.max_token_len, args.max_label_len,
                            args.wrap_tokens, args.wrap_labels, '|', True,
                            args.n_jobs)

    if args.upload:
        if not all(
                os.path.exists(os.path.join(dataset_path, f'{holdout}_preprocessed'))
                for holdout in dataset_info.holdout_folders
        ):
            raise RuntimeError(
                "preprocess data before uploading it to the cloud")
        upload_dataset(dataset_info, dataset_path, VOCABULARY_NAME, args.store,
                       args.tar_suffix)

    preprocessed_paths = [
        os.path.join(dataset_path, f'{holdout}_preprocessed')
        for holdout in dataset_info.holdout_folders
    ]
    if all(os.path.exists(path) for path in preprocessed_paths):
        for holdout, path in zip(dataset_info.holdout_folders,
                                 preprocessed_paths):
            number_of_batches = len(os.listdir(path))
            print(f"There are {number_of_batches} batches in {holdout} data")