Code Example #1
File: data_loader.py Project: guxd/deep-code-search
    # The snippet is truncated above; the DataLoader construction is
    # reconstructed here, assuming `train_set` is the training dataset.
    train_data_loader = torch.utils.data.DataLoader(dataset=train_set,
                                                    shuffle=False,
                                                    num_workers=1)
    vocab_api = load_dict(input_dir + 'vocab.apiseq.json')
    vocab_name = load_dict(input_dir + 'vocab.name.json')
    vocab_tokens = load_dict(input_dir + 'vocab.tokens.json')
    vocab_desc = load_dict(input_dir + 'vocab.desc.json')

    print('============ Train Data ================')
    k = 0
    for batch in train_data_loader:
        batch = tuple([t.numpy() for t in batch])  # tensors -> numpy arrays for decoding
        name, name_len, apiseq, api_len, tokens, tok_len, good_desc, good_desc_len, bad_desc, bad_desc_len = batch
        k += 1
        if k > 20: break  # inspect only the first 20 batches
        print('-------------------------------')
        print(indexes2sent(name, vocab_name))
        print(indexes2sent(apiseq, vocab_api))
        print(indexes2sent(tokens, vocab_tokens))
        print(indexes2sent(good_desc, vocab_desc))

    print('\n\n============ Valid Data ================')
    k = 0
    for batch in valid_data_loader:
        batch = tuple([t.numpy() for t in batch])
        name, name_len, apiseq, api_len, tokens, tok_len, good_desc, good_desc_len, bad_desc, bad_desc_len = batch
        k += 1
        if k > 20: break  # same 20-batch cap as above
        print('-------------------------------')
        print(indexes2sent(name, vocab_name))
        print(indexes2sent(apiseq, vocab_api))
        print(indexes2sent(tokens, vocab_tokens))
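
Both snippets on this page call the helpers `load_dict` and `indexes2sent` without showing their definitions. A minimal sketch consistent with how they are used here might look as follows; the JSON vocabulary layout (word -> index) and the `PAD_ID` constant are assumptions, not the project's actual code:

import json
import numpy as np

PAD_ID = 0  # assumption: index 0 is the padding token

def load_dict(filename):
    """Load a word -> index vocabulary stored as JSON."""
    with open(filename) as f:
        return json.load(f)

def indexes2sent(indexes, vocab):
    """Decode a 1-D or 2-D array of word indexes back into space-joined
    strings, skipping padding. Returns the sentence(s) and their lengths."""
    ivocab = {index: word for word, index in vocab.items()}
    def decode(row):
        # assumption: fall back to an <unk> token for out-of-vocabulary indexes
        words = [ivocab.get(int(i), '<unk>') for i in row if int(i) != PAD_ID]
        return ' '.join(words), len(words)
    rows = np.atleast_2d(indexes)
    sents, lens = zip(*(decode(row) for row in rows))
    return (sents[0], lens[0]) if len(sents) == 1 else (list(sents), list(lens))

With batch_size=1, a helper like this returns a single string per field, which is what the print calls above display.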
Code Example #2
    train_data_loader = torch.utils.data.DataLoader(dataset=train_set,
                                                    batch_size=5,
                                                    shuffle=False,
                                                    drop_last=False,
                                                    num_workers=1)
    print('number of batches:\n', len(train_data_loader))
    '''
    use_set = CodeSearchDataset(input_dir, 'use.tokens.h5', 30)
    use_data_loader = torch.utils.data.DataLoader(dataset=use_set, batch_size=1, shuffle=False, num_workers=1)
    #print(len(use_data_loader))
    vocab_tokens = load_dict(input_dir+'vocab.tokens.json')
    vocab_desc = load_dict(input_dir+'vocab.desc.json')
    '''
    vocab_desc = load_dict(input_dir + 'vocab.desc.json')
    print('============ Train Data ================')
    k = 0
    for epo in range(0, 3):
        for batch in train_data_loader:
            print("batch[1].size(): ", batch[1].size())
            #batch = tuple([t.numpy() for t in batch])
            # `device` is assumed to be defined earlier in the file, e.g.:
            # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            cfg_init_input, cfg_adjmat, cfg_node_mask, good_desc, good_desc_len, bad_desc, bad_desc_len = [
                tensor.to(device) for tensor in batch
            ]
            print(cfg_adjmat.dtype)
            #print(batch)
            k += 1
            #if k>0: break
            print('-------------------------------')
            print(indexes2sent(good_desc, vocab_desc))
            #print(indexes2sent(good_desc, vocab_desc))
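
One thing to note in Example #2: unlike Example #1, the numpy conversion is commented out and the batch is moved to `device`, so `good_desc` is still a torch tensor, possibly on the GPU. If `indexes2sent` expects a numpy array, as in the sketch above, the call would need a conversion first:

print(indexes2sent(good_desc.cpu().numpy(), vocab_desc))  # assumes indexes2sent expects numpy input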