def test_1_small_table():
    """Run generate_seg (row-wise fill) on table_a and table_b with seq_len=128.

    The two results are passed through check_segs, the segment ids are stacked
    into a LongTensor and handed to generate_mask. The function ends in an
    ipdb breakpoint so the locals can be inspected by hand; it is therefore
    meant for interactive use, not for unattended runs.
    """
    args = get_args()
    args.seq_len = 128

    # Each generate_seg call is unpacked as a (tokens, seg) pair.
    (tokens_0, seg_0), (tokens_1, seg_1) = (
        generate_seg(args, table, row_wise_fill=True)
        for table in (table_a, table_b)
    )
    seg_rows = [seg_0, seg_1]
    token_rows = [tokens_0, tokens_1]

    seg = torch.LongTensor(seg_rows)
    check_segs(zip(seg_rows, token_rows))
    mask = generate_mask(seg)  # only kept so it can be examined in the debugger

    import ipdb
    ipdb.set_trace()
def test_3_too_much_empty_values():
    """Run generate_seg on the two empty-value tables with a short seq_len of 16.

    Both tables are processed with row_wise_fill=True, validated via
    check_segs, and their segment ids are stacked and passed to generate_mask.
    The function stops in an ipdb breakpoint at the end for manual inspection,
    so it is intended for interactive use only.
    """
    args = get_args()
    args.seq_len = 16

    # Each generate_seg call is unpacked as a (tokens, seg) pair.
    (tokens_0, seg_0), (tokens_1, seg_1) = (
        generate_seg(args, table, row_wise_fill=True)
        for table in (table_with_empty_values_1, table_with_empty_values_2)
    )
    seg_rows = [seg_0, seg_1]
    token_rows = [tokens_0, tokens_1]

    seg = torch.LongTensor(seg_rows)
    check_segs(zip(seg_rows, token_rows))
    mask = generate_mask(seg)  # only kept so it can be examined in the debugger

    import ipdb
    ipdb.set_trace()
def test_7_additional_ban():
    """Build a mask with additional_ban=2 from table_a and table_b.

    row_wise_fill is switched off on ``args`` and forwarded to generate_seg;
    seq_len is set to 128. The results go through check_segs, and the stacked
    segment ids are passed to generate_mask together with additional_ban=2.
    The function ends in an ipdb breakpoint for manual inspection, so it is
    intended for interactive use only.
    """
    args = get_args()
    args.row_wise_fill = False
    args.seq_len = 128

    # Each generate_seg call is unpacked as a (tokens, seg) pair.
    (tokens_0, seg_0), (tokens_1, seg_1) = (
        generate_seg(args, table, row_wise_fill=args.row_wise_fill)
        for table in (table_a, table_b)
    )
    seg_rows = [seg_0, seg_1]
    token_rows = [tokens_0, tokens_1]

    seg = torch.LongTensor(seg_rows)
    check_segs(zip(seg_rows, token_rows))
    mask = generate_mask(seg, additional_ban=2)  # inspected in the debugger

    import ipdb
    ipdb.set_trace()
def fn(batch):
    """Turn a batch of samples into tensors plus the untouched table ids.

    Each sample is indexed as
    ``(<tab-col-id>, <cls-name>, <micro-table-in-cols>)``.

    Relies on an ``args`` object from the enclosing/module scope: its
    ``labels_map`` translates class names to label values, and it is passed
    on to ``generate_seg``.

    Returns:
        A 4-tuple ``(src, tgt, seg, raw_tab_ids)`` where ``src``, ``tgt`` and
        ``seg`` are ``torch.LongTensor`` objects built from the tokens, the
        mapped labels and the segment ids respectively, and ``raw_tab_ids`` is
        the list of first elements of the samples.
    """
    # Pull the class names out first, then map them through labels_map.
    class_names = [sample[1] for sample in batch]
    label_ids = [args.labels_map[name] for name in class_names]

    raw_tab_ids = [sample[0] for sample in batch]
    tables = [sample[2] for sample in batch]

    # generate_seg yields one (tokens, seg) pair per table; transpose the
    # list of pairs into a tuple of tokens and a tuple of segs.
    tokens, segs = zip(*[generate_seg(args, table) for table in tables])

    src = torch.LongTensor(tokens)
    tgt = torch.LongTensor(label_ids)
    seg = torch.LongTensor(segs)
    return src, tgt, seg, raw_tab_ids
def test_2_bigger_table():
    """Build a mask for every table decoded from the sample spider file.

    The tables come from ``test_decode_spider_file`` applied to
    'demos/samples/sample_file_type0-1.tb'; each one is run through
    generate_seg with row_wise_fill=True and default args. The function ends
    in an ipdb breakpoint for manual inspection, so it is intended for
    interactive use only.
    """
    from col_spec_yh.store_utils import test_decode_spider_file

    tab_file = 'demos/samples/sample_file_type0-1.tb'
    tab_cols_list = test_decode_spider_file(tab_file)
    args = get_args()

    # Only the segment ids are needed here; the tokens are discarded.
    per_table_results = (
        generate_seg(args, tab_col, row_wise_fill=True)
        for tab_col in tab_cols_list
    )
    seg_list = [seg_ids for _, seg_ids in per_table_results]

    seg = torch.LongTensor(seg_list)
    # Original author's note: mask.shape is torch.Size([10, 1, 64, 64]) for
    # this sample file (not verified here).
    mask = generate_mask(seg)

    import ipdb
    ipdb.set_trace()
# Smoke-run script: push a batch of two mini tables through TabEncoder and
# stop in an ipdb session so the result can be inspected by hand.
import torch

from demos.samples.sample_mini_tables import table_a, table_b
from utils import get_args, load_or_initialize_parameters
from col_spec_yh.model import TabEncoder
from col_spec_yh.encode_utils import generate_seg
from col_spec_yh.encode_utils import generate_mask

args = get_args()

# Model: build the encoder, then load or initialise its parameters.
ta_encoder = TabEncoder(args)
load_or_initialize_parameters(args, ta_encoder)

# Data: table_a and table_b together form a batch of two distinct samples.
tokens_0, seg_0 = generate_seg(args, table_a, row_wise_fill=True)
tokens_1, seg_1 = generate_seg(args, table_b, row_wise_fill=True)
src = torch.LongTensor([tokens_0, tokens_1])
seg = torch.LongTensor([seg_0, seg_1])

# No explicit mask is passed to the encoder; presumably it derives one from
# `seg` internally (TODO confirm against TabEncoder).
_ = ta_encoder(src, seg)

# Deliberate breakpoint for interactive inspection of the encoder output.
import ipdb
ipdb.set_trace()