예제 #1
0
    def forward(self, emb, seg):
        """Run the embeddings through every transformer layer under a segment mask.

        Args:
            emb: [batch_size x seq_length x emb_size]
            seg: [batch_size x seq_length]

        Returns:
            hidden: [batch_size x seq_length x hidden_size]
        """
        # Raw mask derived from the segment indicators.
        # Shape: [batch_size x 1 x seq_length x seq_length]; its entries are
        # combinations of (0, 1, 2, 4).
        raw_mask = generate_mask(seg, self.mask_mode)

        # Collapse to a 0/1 float mask (1.0 wherever the raw value is positive),
        # then turn it into a bias: 0 for positive entries, -10000.0 elsewhere.
        # NOTE(review): presumably this bias is added to the attention scores
        # inside each transformer layer -- confirm against the layer code.
        keep = (raw_mask > 0).float()
        mask = (1.0 - keep) * -10000.0

        hidden = emb
        for layer_idx in range(self.layers_num):
            hidden = self.transformer[layer_idx](hidden, mask)

        # TODO: 'crosswise_rel' mask mode is not wired in yet. The sketch was:
        #     mask = self.get_mask_crosswise(seg)   # values in mask: 0,1,2,3
        #     hidden = emb
        #     for i in range(self.layers_num):
        #         hidden = self.transformer[i](hidden, mask,
        #                                      self.relationEmbedding_K,
        #                                      self.relationEmbedding_V)

        return hidden
예제 #2
0
def test_6_compare_row_and_col_wise_fill():
    """Build masks for table_a/table_b with both fill modes for comparison.

    The previous body ran the identical ``row_wise_fill=True`` pipeline twice,
    so nothing was actually compared. The second pass now uses
    ``row_wise_fill=False``.
    NOTE(review): ``row_wise_fill=False`` is assumed to be the column-wise
    fill the test name refers to -- confirm against ``generate_seg``.

    This is an interactive debugging aid: it ends in an ``ipdb`` breakpoint
    with ``seg_row_wise`` / ``mask_row_wise`` and ``seg_col_wise`` /
    ``mask_col_wise`` in scope. It makes no assertions of its own.
    """

    def _build(row_wise_fill):
        # One full pass: fresh args, segment both tables, check, build mask.
        args = get_args()
        args.seq_len = 128
        tokens_0, seg_0 = generate_seg(args, table_a,
                                       row_wise_fill=row_wise_fill)
        tokens_1, seg_1 = generate_seg(args, table_b,
                                       row_wise_fill=row_wise_fill)
        seg = torch.LongTensor([seg_0, seg_1])
        check_segs(zip([seg_0, seg_1], [tokens_0, tokens_1]))
        return seg, generate_mask(seg)

    seg_row_wise, mask_row_wise = _build(True)
    seg_col_wise, mask_col_wise = _build(False)

    # Stop once, after both passes, so the two masks can be compared.
    import ipdb
    ipdb.set_trace()
예제 #3
0
def test_2_bigger_table():
    """Build a mask for every table decoded from a sample file, then debug.

    Interactive debugging aid: ends in an ``ipdb`` breakpoint with ``seg``
    and ``mask`` in scope. It makes no assertions.
    """
    from col_spec_yh.store_utils import test_decode_spider_file

    tab_cols_list = test_decode_spider_file(
        'demos/samples/sample_file_type0-1.tb')

    args = get_args()
    # Only the segment ids are needed; the tokens are discarded.
    seg_list = [
        table_seg
        for _, table_seg in (
            generate_seg(args, tab_col, row_wise_fill=True)
            for tab_col in tab_cols_list
        )
    ]
    seg = torch.LongTensor(seg_list)
    # Original author's note: mask.shape is torch.Size([10, 1, 64, 64]).
    mask = generate_mask(seg)

    import ipdb
    ipdb.set_trace()
예제 #4
0
def test_3_too_much_empty_values():
    """Build a mask for the two empty-value tables at seq_len 16, then debug.

    Interactive debugging aid: ends in an ``ipdb`` breakpoint with
    ``tokens_per_table``, ``segs_per_table``, ``seg`` and ``mask`` in scope.
    It makes no assertions.
    """
    args = get_args()
    args.seq_len = 16

    tokens_per_table, segs_per_table = [], []
    for table in (table_with_empty_values_1, table_with_empty_values_2):
        tokens, table_seg = generate_seg(args, table, row_wise_fill=True)
        tokens_per_table.append(tokens)
        segs_per_table.append(table_seg)

    seg = torch.LongTensor(segs_per_table)
    check_segs(zip(segs_per_table, tokens_per_table))
    mask = generate_mask(seg)

    import ipdb
    ipdb.set_trace()
예제 #5
0
def test_7_additional_ban():
    """Build a mask with ``additional_ban=2`` for table_a/table_b, then debug.

    Segments are generated with ``row_wise_fill`` disabled and a sequence
    length of 128. Interactive debugging aid: ends in an ``ipdb`` breakpoint
    with ``tokens_per_table``, ``segs_per_table``, ``seg`` and ``mask`` in
    scope. It makes no assertions.
    """
    args = get_args()
    args.row_wise_fill = False
    args.seq_len = 128

    tokens_per_table, segs_per_table = [], []
    for table in (table_a, table_b):
        tokens, table_seg = generate_seg(
            args, table, row_wise_fill=args.row_wise_fill)
        tokens_per_table.append(tokens)
        segs_per_table.append(table_seg)

    seg = torch.LongTensor(segs_per_table)
    check_segs(zip(segs_per_table, tokens_per_table))
    mask = generate_mask(seg, additional_ban=2)

    import ipdb
    ipdb.set_trace()