# ----- Example #1 -----
def test_1_small_table():
    """Smoke-test seg/mask generation for two small tables (row-wise fill)."""
    args = get_args()
    args.seq_len = 128
    tokens_0, seg_0 = generate_seg(args, table_a, row_wise_fill=True)
    tokens_1, seg_1 = generate_seg(args, table_b, row_wise_fill=True)
    seg = torch.LongTensor([seg_0, seg_1])
    check_segs(zip([seg_0, seg_1], [tokens_0, tokens_1]))
    mask = generate_mask(seg)
    # Minimal sanity checks (replaces a leftover ipdb breakpoint).
    assert seg.shape[0] == 2  # batched the two tables
    assert mask is not None
# ----- Example #2 -----
def test_3_too_much_empty_values():
    """Smoke-test seg/mask generation when tables contain many empty values.

    Uses a short seq_len (16) so empty-value handling is exercised under
    tight truncation.
    """
    args = get_args()
    args.seq_len = 16
    tokens_0, seg_0 = generate_seg(args,
                                   table_with_empty_values_1,
                                   row_wise_fill=True)
    tokens_1, seg_1 = generate_seg(args,
                                   table_with_empty_values_2,
                                   row_wise_fill=True)
    seg = torch.LongTensor([seg_0, seg_1])
    check_segs(zip([seg_0, seg_1], [tokens_0, tokens_1]))
    mask = generate_mask(seg)
    # Minimal sanity checks (replaces a leftover ipdb breakpoint).
    assert seg.shape[0] == 2  # batched the two tables
    assert mask is not None
# ----- Example #3 -----
def test_7_additional_ban():
    """Smoke-test generate_mask with additional_ban, column-wise fill."""
    args = get_args()
    args.row_wise_fill = False
    args.seq_len = 128
    tokens_0, seg_0 = generate_seg(args,
                                   table_a,
                                   row_wise_fill=args.row_wise_fill)
    tokens_1, seg_1 = generate_seg(args,
                                   table_b,
                                   row_wise_fill=args.row_wise_fill)
    seg = torch.LongTensor([seg_0, seg_1])
    check_segs(zip([seg_0, seg_1], [tokens_0, tokens_1]))
    mask = generate_mask(seg, additional_ban=2)
    # Minimal sanity checks (replaces a leftover ipdb breakpoint).
    assert seg.shape[0] == 2  # batched the two tables
    assert mask is not None
# ----- Example #4 -----
 def fn(batch):
     """Collate a batch for the table encoder.

     Each element of *batch* is a tuple
     ``(<tab-col-id>, <cls-name>, <micro-table-in-cols>)``.

     Returns ``(src, tgt, seg, raw_tab_ids)`` where ``src``, ``tgt`` and
     ``seg`` are LongTensors and ``raw_tab_ids`` is a plain list.
     """
     raw_tab_ids = [item[0] for item in batch]
     # Map class names to integer labels via the (outer-scope) args.labels_map.
     labels = [args.labels_map[item[1]] for item in batch]
     tab_cols = [item[2] for item in batch]
     # generate_seg yields (tokens, seg) pairs; transpose into two sequences.
     tokens, segs = zip(*(generate_seg(args, cols) for cols in tab_cols))
     src = torch.LongTensor(tokens)
     tgt = torch.LongTensor(labels)
     seg = torch.LongTensor(segs)
     return src, tgt, seg, raw_tab_ids
# ----- Example #5 -----
def test_2_bigger_table():
    """Smoke-test seg/mask generation over tables decoded from a sample file."""
    from col_spec_yh.store_utils import test_decode_spider_file
    tab_file = 'demos/samples/sample_file_type0-1.tb'
    tab_cols_list = test_decode_spider_file(tab_file)

    args = get_args()
    seg_list = []
    for tab_col in tab_cols_list:
        _, seg = generate_seg(args, tab_col, row_wise_fill=True)
        seg_list.append(seg)
    seg = torch.LongTensor(seg_list)
    mask = generate_mask(seg)  # mask.shape: torch.Size([10, 1, 64, 64])
    # Minimal sanity checks (replaces a leftover ipdb breakpoint).
    assert seg.shape[0] == len(tab_cols_list)  # one seg row per decoded table
    assert mask is not None
# ----- Example #6 -----
import torch
from demos.samples.sample_mini_tables import table_a, table_b
from utils import get_args, load_or_initialize_parameters
from col_spec_yh.model import TabEncoder
from col_spec_yh.encode_utils import generate_seg
from col_spec_yh.encode_utils import generate_mask

args = get_args()

# Model: build the table encoder and load (or randomly initialize) weights.
ta_encoder = TabEncoder(args)
load_or_initialize_parameters(args, ta_encoder)

# Data: encode table_a twice to form a batch of two identical tables.
# NOTE(review): table_b is imported above but unused — confirm whether the
# second entry was meant to encode table_b instead of table_a.
tokens_0, seg_0 = generate_seg(args, table_a, row_wise_fill=True)
tokens_1, seg_1 = generate_seg(args, table_a, row_wise_fill=True)
src = torch.LongTensor([tokens_0, tokens_1])
seg = torch.LongTensor([seg_0, seg_1])

# Forward pass through the encoder (leftover ipdb breakpoint and
# commented-out generate_mask_crosswise call removed).
encoded = ta_encoder(src, seg)