Code Example #1
    def inference(self, dev_data):
        self.eval()
        batch_size = min(len(dev_data), 16)

        outputs, output_spans = [], []
        for batch_start_id in tqdm(range(0, len(dev_data), batch_size)):
            mini_batch = dev_data[batch_start_id: batch_start_id + batch_size]
            _, text_masks = ops.pad_batch([exp.text_ids for exp in mini_batch], bu.pad_id)
            encoder_input_ids = ops.pad_batch([exp.ptr_input_ids for exp in mini_batch], bu.pad_id)
            # [batch_size, 2, encoder_seq_len]
            output, span_extract_output = self.forward(encoder_input_ids, text_masks)
            outputs.append(output)
            encoder_seq_len = span_extract_output.size(2)
            # [batch_size, encoder_seq_len]
            start_logit = span_extract_output[:, 0, :]
            end_logit = span_extract_output[:, 1, :]
            # [batch_size, encoder_seq_len, encoder_seq_len]
            span_logit = start_logit.unsqueeze(2) + end_logit.unsqueeze(1)
            valid_span_pos = ops.ones_var_cuda([len(span_logit), encoder_seq_len, encoder_seq_len]).triu()
            span_logit = span_logit - (1 - valid_span_pos) * ops.HUGE_INT

            for i in range(len(mini_batch)):
                span_pos = span_logit[i].argmax()
                start = int(span_pos / encoder_seq_len)
                end = int(span_pos % encoder_seq_len)
                output_spans.append((start, end))
        return torch.cat(outputs), output_spans
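
The span decoding above adds every start logit to every end logit and suppresses invalid (start > end) pairs with an upper-triangular mask. Below is a minimal, self-contained sketch of the same trick in plain PyTorch; torch.ones and the 1e8 constant stand in for the repository's ops.ones_var_cuda and ops.HUGE_INT helpers (an assumption about their behavior), and divmod recovers the (start, end) pair from the flattened argmax.

import torch

batch_size, seq_len = 2, 5
start_logit = torch.randn(batch_size, seq_len)  # [batch_size, seq_len]
end_logit = torch.randn(batch_size, seq_len)    # [batch_size, seq_len]

# span_logit[b, s, e] = start_logit[b, s] + end_logit[b, e]
span_logit = start_logit.unsqueeze(2) + end_logit.unsqueeze(1)

# Keep only spans with start <= end by pushing the lower triangle far below zero.
valid_span_pos = torch.ones(batch_size, seq_len, seq_len).triu()
span_logit = span_logit - (1 - valid_span_pos) * 1e8

for i in range(batch_size):
    span_pos = int(span_logit[i].argmax())  # index into the flattened [seq_len, seq_len] grid
    start, end = divmod(span_pos, seq_len)
    assert start <= end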
Code Example #2
def get_decoder_input_ids():
    if self.training:
        if self.model_id in [BRIDGE]:
            X = [exp.program_singleton_field_input_ids for exp in mini_batch]
        else:
            X = [exp.program_input_ids for exp in mini_batch]
        return ops.pad_batch(X, self.mdl.out_vocab.pad_id)
    else:
        return None
Code Example #3
    def inference(self, dev_data):
        self.eval()
        batch_size = 32

        output_spans = []
        for batch_start_id in tqdm(range(0, len(dev_data), batch_size)):
            mini_batch = dev_data[batch_start_id:batch_start_id + batch_size]
            _, text_masks = ops.pad_batch([exp.text_ids for exp in mini_batch],
                                          bu.pad_id)
            encoder_input_ids = ops.pad_batch(
                [exp.ptr_input_ids for exp in mini_batch], bu.pad_id)
            # [batch_size, 2, encoder_seq_len]
            output = self.forward(encoder_input_ids, text_masks)
            encoder_seq_len = output.size(2)
            # [batch_size, encoder_seq_len]
            start_logit = output[:, 0, :]
            end_logit = output[:, 1, :]
            # [batch_size, encoder_seq_len, encoder_seq_len]
            span_logit = start_logit.unsqueeze(2) + end_logit.unsqueeze(1)
            # use len(mini_batch) rather than batch_size: the last mini-batch may be smaller
            valid_span_pos = ops.ones_var_cuda(
                [len(mini_batch), encoder_seq_len, encoder_seq_len]).triu()
            span_logit = span_logit - (1 - valid_span_pos) * ops.HUGE_INT

            for i in range(len(mini_batch)):
                span_pos = span_logit[i].argmax()
                start = int(span_pos / encoder_seq_len)
                end = int(span_pos % encoder_seq_len)
                output_spans.append((start, end))
                # print(start, end)
                # print(mini_batch[i].text)
                # confusion_span_size = end - start + 1
                # if start > 0 and confusion_span_size < 5:
                #     print(mini_batch[i].text_tokens[start-1:end])
                #     print()

        return output_spans
Code Example #4
def train(train_data, dev_data):
    # Model
    model_dir = get_model_dir(args)
    if not os.path.exists(model_dir):
        os.mkdir(model_dir)

    trans_checker = TranslatabilityChecker(args)
    trans_checker.cuda()
    ops.initialize_module(trans_checker, 'xavier')

    wandb.init(project='translatability-prediction', name=get_wandb_tag(args))
    wandb.watch(trans_checker)

    # Hyperparameters
    batch_size = 16
    num_peek_epochs = 1

    # Loss function
    # -100 is a dummy padding value since all output spans will be of length 2
    loss_fun = MaskedCrossEntropyLoss(-100)

    # Optimizer
    optimizer = optim.Adam(
        [{'params': [p for n, p in trans_checker.named_parameters()
                     if 'trans_parameters' not in n and p.requires_grad]},
         {'params': [p for n, p in trans_checker.named_parameters()
                     if 'trans_parameters' in n and p.requires_grad],
          'lr': args.bert_finetune_rate}],
        lr=args.learning_rate)
    lr_scheduler = lrs.LinearScheduler(
        optimizer, [args.warmup_init_lr, args.warmup_init_ft_lr],
        [args.num_warmup_steps, args.num_warmup_steps], args.num_steps)

    best_dev_metrics = 0
    for epoch_id in range(args.num_epochs):
        random.shuffle(train_data)
        trans_checker.train()
        optimizer.zero_grad()

        epoch_losses = []

        for i in tqdm(range(0, len(train_data), batch_size)):
            wandb.log({
                'learning_rate/{}'.format(args.dataset_name):
                optimizer.param_groups[0]['lr']
            })
            wandb.log({
                'fine_tuning_rate/{}'.format(args.dataset_name):
                optimizer.param_groups[1]['lr']
            })
            mini_batch = train_data[i:i + batch_size]
            _, text_masks = ops.pad_batch([exp.text_ids for exp in mini_batch],
                                          bu.pad_id)
            encoder_input_ids = ops.pad_batch(
                [exp.ptr_input_ids for exp in mini_batch], bu.pad_id)
            target_span_ids, _ = ops.pad_batch(
                [exp.span_ids for exp in mini_batch], bu.pad_id)
            output = trans_checker(encoder_input_ids, text_masks)
            loss = loss_fun(output, target_span_ids)
            loss.backward()
            epoch_losses.append(float(loss))

            if args.grad_norm > 0:
                nn.utils.clip_grad_norm_(trans_checker.parameters(),
                                         args.grad_norm)
            lr_scheduler.step()
            optimizer.step()
            optimizer.zero_grad()

        if epoch_id % num_peek_epochs == 0:
            stdout_msg = 'Epoch {}: average training loss = {}'.format(
                epoch_id, np.mean(epoch_losses))
            print(stdout_msg)
            wandb.log({
                'cross_entropy_loss/{}'.format(args.dataset_name):
                np.mean(epoch_losses)
            })
            pred_spans = trans_checker.inference(dev_data)
            target_spans = [exp.span_ids for exp in dev_data]
            trans_acc = translatablity_eval(pred_spans, target_spans)
            print('Dev translatability accuracy = {}'.format(trans_acc))
            if trans_acc > best_dev_metrics:
                model_path = os.path.join(model_dir, 'model-best.tar')
                trans_checker.save_checkpoint(optimizer, lr_scheduler,
                                              model_path)
                best_dev_metrics = trans_acc

            span_acc, prec, recall, f1 = span_eval(pred_spans, target_spans)
            print('Dev span accuracy = {}'.format(span_acc))
            print('Dev span precision = {}'.format(prec))
            print('Dev span recall = {}'.format(recall))
            print('Dev span F1 = {}'.format(f1))
            wandb.log({
                'translatability_accuracy/{}'.format(args.dataset_name):
                trans_acc
            })
            wandb.log({'span_accuracy/{}'.format(args.dataset_name): span_acc})
            wandb.log({'span_f1/{}'.format(args.dataset_name): f1})
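
MaskedCrossEntropyLoss is used above as a black box. The following is a minimal sketch of what such a loss could look like, assuming the model output has shape [batch_size, 2, encoder_seq_len] (start and end logits) and the targets have shape [batch_size, 2], with entries equal to the pad id ignored. This is an illustration under those assumptions, not the repository's implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskedCrossEntropyLoss(nn.Module):
    # Hypothetical sketch: cross-entropy over each logit row, ignoring targets equal to pad_id.
    def __init__(self, pad_id):
        super().__init__()
        self.pad_id = pad_id

    def forward(self, logits, targets):
        # logits:  [batch_size, num_rows, seq_len]  (num_rows == 2: start and end)
        # targets: [batch_size, num_rows], with pad_id marking entries to ignore
        batch_size, num_rows, seq_len = logits.size()
        return F.cross_entropy(logits.reshape(batch_size * num_rows, seq_len),
                               targets.reshape(batch_size * num_rows),
                               ignore_index=self.pad_id)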
Code Example #5
    def format_batch(self, mini_batch):
        def get_decoder_input_ids():
            if self.training:
                if self.model_id in [BRIDGE]:
                    X = [
                        exp.program_singleton_field_input_ids
                        for exp in mini_batch
                    ]
                else:
                    X = [exp.program_input_ids for exp in mini_batch]
                return ops.pad_batch(X, self.mdl.out_vocab.pad_id)
            else:
                return None

        def get_encoder_attn_mask(table_names, table_masks):
            schema_pos = [
                schema_graph.get_schema_pos(table_name)
                for table_name in table_names
            ]
            encoder_attn_mask = [1 for _ in range(exp.num_text_tokens)]
            # asterisk marker
            encoder_attn_mask.append(1)
            is_selected_table = False
            for j in range(1, len(table_masks)):
                if j in schema_pos:
                    encoder_attn_mask.append(1)
                    is_selected_table = True
                elif table_masks[j] == 1:
                    # mask current table
                    encoder_attn_mask.append(0)
                    is_selected_table = False
                else:
                    if is_selected_table:
                        encoder_attn_mask.append(1)
                    else:
                        encoder_attn_mask.append(0)
            return encoder_attn_mask

        super().format_batch(mini_batch)
        encoder_input_ids = ops.pad_batch([exp.text_ids for exp in mini_batch],
                                          self.mdl.in_vocab.pad_id)
        decoder_input_ids = get_decoder_input_ids()

        table_samples = []

        if self.model_id == SEQ2SEQ:
            return encoder_input_ids, decoder_input_ids
        elif self.model_id in [BRIDGE]:
            encoder_ptr_input_ids, encoder_ptr_value_ids, decoder_ptr_value_ids = [], [], []
            primary_key_ids, foreign_key_ids, field_type_ids, table_masks, table_positions, table_field_scopes, \
                field_table_pos, transformer_output_value_masks, schema_memory_masks = [], [], [], [], [], [], [], [], []
            for exp in mini_batch:
                schema_graph = self.schema_graphs.get_schema(exp.db_id)
                # exp.pretty_print(example_id=0,
                #                  schema=schema_graph,
                #                  de_vectorize_ptr=vec.de_vectorize_ptr,
                #                  de_vectorize_field_ptr=vec.de_vectorize_field_ptr,
                #                  rev_vocab=self.out_vocab,
                #                  post_process=self.output_post_process,
                #                  use_table_aware_te=(self.model_id in [BRIDGE]))
                # import pdb
                # pdb.set_trace()
                if self.training:
                    # Compute schema layout
                    if exp.gt_table_names_list:
                        gt_tables = set([
                            schema_graph.get_table_id(t_name)
                            for t_name in exp.gt_table_names
                        ])
                    else:
                        gt_table_names = [
                            token for token, t in zip(
                                exp.program_singleton_field_tokens,
                                exp.program_singleton_field_token_types)
                            if t == 0
                        ]
                        gt_tables = set([
                            schema_graph.get_table_id(t_name)
                            for t_name in gt_table_names
                        ])
                    # [Hack] Baseball database has a complex schema which does not fit the input size of BERT. We select
                    # the ground truth tables and randomly add a few other tables for training.
                    if schema_graph.name.startswith('baseball'):
                        tables = list(gt_tables)
                        tables += random.sample(
                            [
                                i for i in range(schema_graph.num_tables)
                                if i not in gt_tables
                            ],
                            k=min(random.randint(1, 7),
                                  schema_graph.num_tables - len(gt_tables)))
                    else:
                        tables = list(range(schema_graph.num_tables))
                    if self.args.table_shuffling:
                        table_to_drop = random.choice(tables)
                        if table_to_drop not in gt_tables:
                            if random.uniform(0, 1) < 0.3:
                                tables = [
                                    x for x in tables if x != table_to_drop
                                ]
                        table_po, field_po = schema_graph.get_schema_perceived_order(
                            tables,
                            random_table_order=True,
                            random_field_order=self.args.random_field_order)
                    else:
                        table_po, field_po = schema_graph.get_schema_perceived_order(
                            tables,
                            random_table_order=False,
                            random_field_order=self.args.random_field_order)

                    # Schema feature extraction
                    question_encoding = exp.text if self.args.use_picklist else None
                    schema_features, matched_values = schema_graph.get_serialization(
                        self.tu,
                        flatten_features=True,
                        table_po=table_po,
                        field_po=field_po,
                        use_typed_field_markers=self.args.use_typed_field_markers,
                        use_graph_encoding=self.args.use_graph_encoding,
                        question_encoding=question_encoding,
                        top_k_matches=self.args.top_k_picklist_matches,
                        num_values_per_field=self.args.num_values_per_field,
                        no_anchor_text=self.args.no_anchor_text,
                        verbose=False)
                    ptr_input_tokens, ptr_input_values, num_excluded_tables, num_excluded_fields = \
                        get_table_aware_transformer_encoder_inputs(
                            exp.text_ptr_values, exp.text_tokens, schema_features, self.tu)
                    assert (len(ptr_input_tokens) <= self.tu.tokenizer.max_len)
                    if num_excluded_fields > 0:
                        print('Warning: training input truncated')
                    num_included_nodes = schema_graph.get_num_perceived_nodes(tables) + 1 \
                                         - num_excluded_tables - num_excluded_fields
                    encoder_ptr_input_ids.append(
                        self.tu.tokenizer.convert_tokens_to_ids(
                            ptr_input_tokens))
                    if self.args.read_picklist:
                        exp.transformer_output_value_mask, value_features, value_tokens = \
                            get_transformer_output_value_mask(ptr_input_tokens, matched_values, self.tu)
                        transformer_output_value_masks.append(
                            exp.transformer_output_value_mask)
                    primary_key_ids.append(
                        schema_graph.get_primary_key_ids(
                            num_included_nodes, table_po, field_po))
                    foreign_key_ids.append(
                        schema_graph.get_foreign_key_ids(
                            num_included_nodes, table_po, field_po))
                    field_type_ids.append(
                        schema_graph.get_field_type_ids(
                            num_included_nodes, table_po, field_po))
                    table_masks.append(
                        schema_graph.get_table_masks(num_included_nodes,
                                                     table_po, field_po))

                    # Value copy feature extraction
                    if self.args.read_picklist:
                        constant_memory_features = exp.text_tokens + value_features
                        constant_memory = exp.text_ptr_values + value_tokens
                        exp.text_ptr_values = constant_memory
                    else:
                        constant_memory_features = exp.text_tokens
                    constant_ptr_value_ids, constant_unique_input_ids = vec.vectorize_ptr_in(
                        constant_memory_features, self.out_vocab)
                    encoder_ptr_value_ids.append(constant_ptr_value_ids + [
                        self.out_vocab.size + len(constant_memory_features) + x
                        for x in range(num_included_nodes)
                    ])
                    program_field_ptr_value_ids = \
                        vec.vectorize_field_ptr_out(exp.program_singleton_field_tokens,
                                                    exp.program_singleton_field_token_types,
                                                    self.out_vocab, constant_unique_input_ids,
                                                    max_memory_size=len(constant_memory_features),
                                                    schema=schema_graph,
                                                    num_included_nodes=num_included_nodes)
                    decoder_ptr_value_ids.append(program_field_ptr_value_ids)
                else:
                    encoder_ptr_input_ids = [
                        exp.ptr_input_ids for exp in mini_batch
                    ]
                    encoder_ptr_value_ids = [
                        exp.ptr_value_ids for exp in mini_batch
                    ]
                    decoder_ptr_value_ids = [exp.program_text_and_field_ptr_value_ids for exp in mini_batch] \
                        if self.training else None
                    primary_key_ids = [
                        exp.primary_key_ids for exp in mini_batch
                    ]
                    foreign_key_ids = [
                        exp.foreign_key_ids for exp in mini_batch
                    ]
                    field_type_ids = [exp.field_type_ids for exp in mini_batch]
                    table_masks = [exp.table_masks for exp in mini_batch]
                    # TODO: here we assume that all nodes in the schema graph are included
                    table_pos, table_field_scope = schema_graph.get_table_scopes(
                        schema_graph.num_nodes)
                    table_positions.append(table_pos)
                    table_field_scopes.append(table_field_scope)
                    if self.args.read_picklist:
                        transformer_output_value_masks.append(
                            exp.transformer_output_value_mask)

            encoder_ptr_input_ids = ops.pad_batch(encoder_ptr_input_ids,
                                                  self.mdl.in_vocab.pad_id)
            encoder_ptr_value_ids = ops.pad_batch(encoder_ptr_value_ids,
                                                  self.mdl.in_vocab.pad_id)
            schema_memory_masks = ops.pad_batch(schema_memory_masks, pad_id=0) \
                if (self.args.use_pred_tables and not self.training) else (None, None)
            decoder_ptr_value_ids = ops.pad_batch(decoder_ptr_value_ids, self.mdl.out_vocab.pad_id) \
                if self.training else None
            primary_key_ids = ops.pad_batch(primary_key_ids,
                                            self.mdl.in_vocab.pad_id)
            foreign_key_ids = ops.pad_batch(foreign_key_ids,
                                            self.mdl.in_vocab.pad_id)
            field_type_ids = ops.pad_batch(field_type_ids,
                                           self.mdl.in_vocab.pad_id)
            table_masks = ops.pad_batch(table_masks, pad_id=0)
            transformer_output_value_masks = ops.pad_batch(transformer_output_value_masks, pad_id=0, dtype=torch.uint8) \
                if self.args.read_picklist else (None, None)
            if not self.training:
                table_positions = ops.pad_batch(table_positions, pad_id=-1) \
                    if self.args.process_sql_in_execution_order else (None, None)
                table_field_scopes = ops.pad_batch_2D(table_field_scopes, pad_id=0) \
                    if self.args.process_sql_in_execution_order else (None, None)
            graphs = None
            return encoder_input_ids, decoder_input_ids, encoder_ptr_input_ids, encoder_ptr_value_ids, \
                   decoder_ptr_value_ids, transformer_output_value_masks, schema_memory_masks, graphs, \
                   (primary_key_ids, foreign_key_ids, field_type_ids, table_masks, table_positions,
                    table_field_scopes, field_table_pos), table_samples
        elif self.model_id in [SEQ2SEQ_PG]:
            encoder_ptr_input_ids = [exp.ptr_input_ids for exp in mini_batch]
            encoder_ptr_value_ids = [exp.ptr_value_ids for exp in mini_batch]
            decoder_ptr_value_ids = [
                exp.program_text_ptr_value_ids for exp in mini_batch
            ]
            encoder_ptr_input_ids = ops.pad_batch(encoder_ptr_input_ids,
                                                  self.mdl.in_vocab.pad_id)
            encoder_ptr_value_ids = ops.pad_batch(encoder_ptr_value_ids,
                                                  self.mdl.in_vocab.pad_id)
            decoder_ptr_value_ids = ops.pad_batch(decoder_ptr_value_ids,
                                                  self.mdl.out_vocab.pad_id)
            return encoder_input_ids, decoder_input_ids, encoder_ptr_input_ids, encoder_ptr_value_ids, \
                   decoder_ptr_value_ids
        else:
            raise NotImplementedError
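
format_batch relies throughout on ops.pad_batch, which is not shown here. Below is a plausible minimal sketch, assuming the helper pads variable-length id lists to the batch maximum and returns the padded tensor together with a padding mask; the exact return convention (and whether tensors are placed on the GPU) is an assumption for illustration.

import torch

def pad_batch(batch, pad_id):
    # Hedged sketch, not the repository's implementation.
    # batch: list of variable-length lists of token ids.
    max_len = max(len(seq) for seq in batch)
    padded = torch.full((len(batch), max_len), pad_id, dtype=torch.long)
    masks = torch.ones(len(batch), max_len, dtype=torch.bool)  # True at padding positions (assumed convention)
    for i, seq in enumerate(batch):
        padded[i, :len(seq)] = torch.tensor(seq, dtype=torch.long)
        masks[i, :len(seq)] = False
    return padded, masks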
Code Example #6
def train(train_data, dev_data):
    # Model
    model_dir = get_model_dir(args)
    if not os.path.exists(model_dir):
        os.mkdir(model_dir)

    trans_checker = TranslatabilityChecker(args)
    trans_checker.cuda()
    ops.initialize_module(trans_checker, 'xavier')

    # Hyperparameters
    batch_size = min(len(train_data), 12)
    num_peek_epochs = 1

    # Loss function
    loss_fun = nn.BCELoss()
    span_extract_pad_id = -100
    span_extract_loss_fun = MaskedCrossEntropyLoss(span_extract_pad_id)

    # Optimizer
    optimizer = optim.Adam(
        [{'params': [p for n, p in trans_checker.named_parameters() if
                     'trans_parameters' not in n and p.requires_grad]},
         {'params': [p for n, p in trans_checker.named_parameters() if 'trans_parameters' in n and p.requires_grad],
          'lr': args.bert_finetune_rate}],
        lr=args.learning_rate)
    lr_scheduler = lrs.LinearScheduler(
        optimizer, [args.warmup_init_lr, args.warmup_init_ft_lr], [args.num_warmup_steps, args.num_warmup_steps],
        args.num_steps)

    best_dev_metrics = 0
    for epoch_id in range(args.num_epochs):
        random.shuffle(train_data)
        trans_checker.train()
        optimizer.zero_grad()

        epoch_losses = []

        for i in tqdm(range(0, len(train_data), batch_size)):
            mini_batch = train_data[i: i + batch_size]
            _, text_masks = ops.pad_batch([exp.text_ids for exp in mini_batch], bu.pad_id)
            encoder_input_ids = ops.pad_batch([exp.ptr_input_ids for exp in mini_batch], bu.pad_id)
            target_ids = ops.int_var_cuda([1 if exp.span_ids[0] == 0 else 0 for exp in mini_batch])
            target_span_ids, _ = ops.pad_batch([exp.span_ids for exp in mini_batch], bu.pad_id)
            target_span_ids = target_span_ids * (1 - target_ids.unsqueeze(1)) + \
                              target_ids.unsqueeze(1).expand_as(target_span_ids) * span_extract_pad_id
            output, span_extract_output = trans_checker(encoder_input_ids, text_masks)
            loss = loss_fun(output, target_ids.unsqueeze(1).float())
            span_extract_loss = span_extract_loss_fun(span_extract_output, target_span_ids)
            loss += span_extract_loss
            loss.backward()
            epoch_losses.append(float(loss))

            if args.grad_norm > 0:
                nn.utils.clip_grad_norm_(trans_checker.parameters(), args.grad_norm)
            lr_scheduler.step()
            optimizer.step()
            optimizer.zero_grad()

        with torch.no_grad():
            if epoch_id % num_peek_epochs == 0:
                stdout_msg = 'Epoch {}: average training loss = {}'.format(epoch_id, np.mean(epoch_losses))
                print(stdout_msg)
                pred_trans, pred_spans = trans_checker.inference(dev_data)
                targets = [1 if exp.span_ids[0] == 0 else 0 for exp in dev_data]
                target_spans = [exp.span_ids for exp in dev_data]
                trans_acc = translatablity_eval(pred_trans, targets)
                print('Dev translatability accuracy = {}'.format(trans_acc))
                if trans_acc > best_dev_metrics:
                    model_path = os.path.join(model_dir, 'model-best.tar')
                    trans_checker.save_checkpoint(optimizer, lr_scheduler, model_path)
                    best_dev_metrics = trans_acc
                span_acc, prec, recall, f1 = span_eval(pred_spans, target_spans)
                print('Dev span accuracy = {}'.format(span_acc))
                print('Dev span precision = {}'.format(prec))
                print('Dev span recall = {}'.format(recall))
                print('Dev span F1 = {}'.format(f1))
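
The target_span_ids arithmetic above deserves a closer look: for examples whose gold span starts at position 0 (target_ids == 1), the span targets are overwritten with span_extract_pad_id, so MaskedCrossEntropyLoss skips them and only the BCE translatability loss applies to those examples. A small worked illustration with made-up values:

import torch

span_extract_pad_id = -100
target_span_ids = torch.tensor([[3, 5],   # example with a real confusion span
                                [0, 0]])  # example whose span starts at 0
target_ids = torch.tensor([0, 1])         # 1 <=> exp.span_ids[0] == 0

masked = target_span_ids * (1 - target_ids.unsqueeze(1)) + \
         target_ids.unsqueeze(1).expand_as(target_span_ids) * span_extract_pad_id
# masked == tensor([[   3,    5],
#                   [-100, -100]])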