def inference(self, dev_data):
    self.eval()
    batch_size = min(len(dev_data), 16)
    outputs, output_spans = [], []
    for batch_start_id in tqdm(range(0, len(dev_data), batch_size)):
        mini_batch = dev_data[batch_start_id: batch_start_id + batch_size]
        _, text_masks = ops.pad_batch([exp.text_ids for exp in mini_batch], bu.pad_id)
        encoder_input_ids = ops.pad_batch([exp.ptr_input_ids for exp in mini_batch], bu.pad_id)
        # [batch_size, 2, encoder_seq_len]
        output, span_extract_output = self.forward(encoder_input_ids, text_masks)
        outputs.append(output)
        encoder_seq_len = span_extract_output.size(2)
        # [batch_size, encoder_seq_len]
        start_logit = span_extract_output[:, 0, :]
        end_logit = span_extract_output[:, 1, :]
        # [batch_size, encoder_seq_len, encoder_seq_len]
        span_logit = start_logit.unsqueeze(2) + end_logit.unsqueeze(1)
        valid_span_pos = ops.ones_var_cuda(
            [len(span_logit), encoder_seq_len, encoder_seq_len]).triu()
        span_logit = span_logit - (1 - valid_span_pos) * ops.HUGE_INT
        for i in range(len(mini_batch)):
            span_pos = span_logit[i].argmax()
            start = int(span_pos / encoder_seq_len)
            end = int(span_pos % encoder_seq_len)
            output_spans.append((start, end))
    return torch.cat(outputs), output_spans
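# The upper-triangular mask above is what keeps the span argmax from choosing a
# span whose end precedes its start. A minimal standalone sketch of the same
# trick in plain PyTorch (toy tensors and a generic large constant standing in
# for ops.ones_var_cuda and ops.HUGE_INT; not part of the original code):
import torch

encoder_seq_len = 5
start_logit = torch.randn(1, encoder_seq_len)
end_logit = torch.randn(1, encoder_seq_len)

# span_logit[b, s, e] = start_logit[b, s] + end_logit[b, e]
span_logit = start_logit.unsqueeze(2) + end_logit.unsqueeze(1)

# Keep the upper triangle (start <= end); push everything else far below any real logit.
valid_span_pos = torch.ones(1, encoder_seq_len, encoder_seq_len).triu()
span_logit = span_logit - (1 - valid_span_pos) * 1e31

# The flat argmax index maps back to (start, end) exactly as in the loop above:
# integer division gives the row, the remainder gives the column.
span_pos = int(span_logit[0].argmax())
start, end = divmod(span_pos, encoder_seq_len)
assert start <= end  # guaranteed by the mask
print(start, end)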
def inference(self, dev_data):
    self.eval()
    batch_size = 32
    output_spans = []
    for batch_start_id in tqdm(range(0, len(dev_data), batch_size)):
        mini_batch = dev_data[batch_start_id:batch_start_id + batch_size]
        _, text_masks = ops.pad_batch([exp.text_ids for exp in mini_batch], bu.pad_id)
        encoder_input_ids = ops.pad_batch(
            [exp.ptr_input_ids for exp in mini_batch], bu.pad_id)
        # [batch_size, 2, encoder_seq_len]
        output = self.forward(encoder_input_ids, text_masks)
        encoder_seq_len = output.size(2)
        # [batch_size, encoder_seq_len]
        start_logit = output[:, 0, :]
        end_logit = output[:, 1, :]
        # [batch_size, encoder_seq_len, encoder_seq_len]
        span_logit = start_logit.unsqueeze(2) + end_logit.unsqueeze(1)
        # Size the mask by the actual mini-batch: the last batch may be smaller
        # than batch_size, and a fixed-size mask would not broadcast.
        valid_span_pos = ops.ones_var_cuda(
            [len(span_logit), encoder_seq_len, encoder_seq_len]).triu()
        span_logit = span_logit - (1 - valid_span_pos) * ops.HUGE_INT
        for i in range(len(mini_batch)):
            span_pos = span_logit[i].argmax()
            start = int(span_pos / encoder_seq_len)
            end = int(span_pos % encoder_seq_len)
            output_spans.append((start, end))
    return output_spans
def train(train_data, dev_data):
    # Model
    model_dir = get_model_dir(args)
    if not os.path.exists(model_dir):
        os.mkdir(model_dir)
    trans_checker = TranslatabilityChecker(args)
    trans_checker.cuda()
    ops.initialize_module(trans_checker, 'xavier')
    wandb.init(project='translatability-prediction', name=get_wandb_tag(args))
    wandb.watch(trans_checker)

    # Hyperparameters
    batch_size = 16
    num_peek_epochs = 1

    # Loss function
    # -100 is a dummy pad id: every span target has exactly two entries
    # (start, end), so no padding is ever applied.
    loss_fun = MaskedCrossEntropyLoss(-100)

    # Optimizer: the pretrained encoder ('trans_parameters') is fine-tuned with
    # a separate, smaller learning rate.
    optimizer = optim.Adam(
        [{'params': [p for n, p in trans_checker.named_parameters()
                     if 'trans_parameters' not in n and p.requires_grad]},
         {'params': [p for n, p in trans_checker.named_parameters()
                     if 'trans_parameters' in n and p.requires_grad],
          'lr': args.bert_finetune_rate}],
        lr=args.learning_rate)
    lr_scheduler = lrs.LinearScheduler(
        optimizer,
        [args.warmup_init_lr, args.warmup_init_ft_lr],
        [args.num_warmup_steps, args.num_warmup_steps],
        args.num_steps)

    best_dev_metrics = 0
    for epoch_id in range(args.num_epochs):
        random.shuffle(train_data)
        trans_checker.train()
        optimizer.zero_grad()
        epoch_losses = []
        for i in tqdm(range(0, len(train_data), batch_size)):
            wandb.log({'learning_rate/{}'.format(args.dataset_name):
                       optimizer.param_groups[0]['lr']})
            wandb.log({'fine_tuning_rate/{}'.format(args.dataset_name):
                       optimizer.param_groups[1]['lr']})
            mini_batch = train_data[i:i + batch_size]
            _, text_masks = ops.pad_batch([exp.text_ids for exp in mini_batch], bu.pad_id)
            encoder_input_ids = ops.pad_batch(
                [exp.ptr_input_ids for exp in mini_batch], bu.pad_id)
            target_span_ids, _ = ops.pad_batch(
                [exp.span_ids for exp in mini_batch], bu.pad_id)
            output = trans_checker(encoder_input_ids, text_masks)
            loss = loss_fun(output, target_span_ids)
            loss.backward()
            epoch_losses.append(float(loss))
            if args.grad_norm > 0:
                nn.utils.clip_grad_norm_(trans_checker.parameters(), args.grad_norm)
            lr_scheduler.step()
            optimizer.step()
            optimizer.zero_grad()

        if epoch_id % num_peek_epochs == 0:
            stdout_msg = 'Epoch {}: average training loss = {}'.format(
                epoch_id, np.mean(epoch_losses))
            print(stdout_msg)
            wandb.log({'cross_entropy_loss/{}'.format(args.dataset_name):
                       np.mean(epoch_losses)})
            pred_spans = trans_checker.inference(dev_data)
            target_spans = [exp.span_ids for exp in dev_data]
            trans_acc = translatablity_eval(pred_spans, target_spans)
            print('Dev translatability accuracy = {}'.format(trans_acc))
            if trans_acc > best_dev_metrics:
                model_path = os.path.join(model_dir, 'model-best.tar')
                trans_checker.save_checkpoint(optimizer, lr_scheduler, model_path)
                best_dev_metrics = trans_acc
            span_acc, prec, recall, f1 = span_eval(pred_spans, target_spans)
            print('Dev span accuracy = {}'.format(span_acc))
            print('Dev span precision = {}'.format(prec))
            print('Dev span recall = {}'.format(recall))
            print('Dev span F1 = {}'.format(f1))
            wandb.log({'translatability_accuracy/{}'.format(args.dataset_name): trans_acc})
            wandb.log({'span_accuracy/{}'.format(args.dataset_name): span_acc})
            wandb.log({'span_f1/{}'.format(args.dataset_name): f1})
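# The optimizer above places the pretrained encoder parameters (matched by the
# 'trans_parameters' name filter) in their own parameter group so they can be
# fine-tuned at args.bert_finetune_rate while the rest of the model trains at
# args.learning_rate. A small self-contained sketch of the same pattern with a
# toy model (module names and rates are made up for illustration; not part of
# the original code):
import torch.nn as nn
import torch.optim as optim

toy_model = nn.ModuleDict({
    'trans_parameters': nn.Linear(8, 8),  # stands in for the pretrained encoder
    'span_head': nn.Linear(8, 2),         # stands in for the task-specific head
})
toy_optimizer = optim.Adam(
    [{'params': [p for n, p in toy_model.named_parameters()
                 if 'trans_parameters' not in n and p.requires_grad]},
     {'params': [p for n, p in toy_model.named_parameters()
                 if 'trans_parameters' in n and p.requires_grad],
      'lr': 3e-5}],   # slower fine-tuning rate for the pretrained weights
    lr=1e-3)          # base learning rate for everything else
print([g['lr'] for g in toy_optimizer.param_groups])  # [0.001, 3e-05]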
def format_batch(self, mini_batch):

    def get_decoder_input_ids():
        if self.training:
            if self.model_id in [BRIDGE]:
                X = [exp.program_singleton_field_input_ids for exp in mini_batch]
            else:
                X = [exp.program_input_ids for exp in mini_batch]
            return ops.pad_batch(X, self.mdl.out_vocab.pad_id)
        else:
            return None

    def get_encoder_attn_mask(table_names, table_masks):
        # Allow attention over the question tokens, the asterisk marker and the
        # nodes of the selected tables; mask out all other schema nodes.
        # (A standalone sketch of this bookkeeping follows the function below.)
        schema_pos = [schema_graph.get_schema_pos(table_name)
                      for table_name in table_names]
        encoder_attn_mask = [1 for _ in range(exp.num_text_tokens)]
        # asterisk marker
        encoder_attn_mask.append(1)
        is_selected_table = False
        for j in range(1, len(table_masks)):
            if j in schema_pos:
                encoder_attn_mask.append(1)
                is_selected_table = True
            elif table_masks[j] == 1:
                # mask current table
                encoder_attn_mask.append(0)
                is_selected_table = False
            else:
                if is_selected_table:
                    encoder_attn_mask.append(1)
                else:
                    encoder_attn_mask.append(0)
        return encoder_attn_mask

    super().format_batch(mini_batch)
    encoder_input_ids = ops.pad_batch([exp.text_ids for exp in mini_batch],
                                      self.mdl.in_vocab.pad_id)
    decoder_input_ids = get_decoder_input_ids()
    table_samples = []

    if self.model_id == SEQ2SEQ:
        return encoder_input_ids, decoder_input_ids
    elif self.model_id in [BRIDGE]:
        encoder_ptr_input_ids, encoder_ptr_value_ids, decoder_ptr_value_ids = [], [], []
        primary_key_ids, foreign_key_ids, field_type_ids, table_masks, table_positions, \
            table_field_scopes, field_table_pos, transformer_output_value_masks, \
            schema_memory_masks = [], [], [], [], [], [], [], [], []
        for exp in mini_batch:
            schema_graph = self.schema_graphs.get_schema(exp.db_id)
            if self.training:
                # Compute schema layout
                if exp.gt_table_names_list:
                    gt_tables = set([schema_graph.get_table_id(t_name)
                                     for t_name in exp.gt_table_names])
                else:
                    gt_table_names = [
                        token for token, t in zip(exp.program_singleton_field_tokens,
                                                  exp.program_singleton_field_token_types)
                        if t == 0
                    ]
                    gt_tables = set([schema_graph.get_table_id(t_name)
                                     for t_name in gt_table_names])
                # [Hack] The baseball database has a complex schema which does not
                # fit the input size of BERT. We keep the ground truth tables and
                # randomly add a few other tables for training.
                if schema_graph.name.startswith('baseball'):
                    tables = list(gt_tables)
                    tables += random.sample(
                        [i for i in range(schema_graph.num_tables) if i not in gt_tables],
                        k=min(random.randint(1, 7),
                              schema_graph.num_tables - len(gt_tables)))
                else:
                    tables = list(range(schema_graph.num_tables))
                if self.args.table_shuffling:
                    # Randomly drop a non-ground-truth table with probability 0.3
                    table_to_drop = random.choice(tables)
                    if table_to_drop not in gt_tables:
                        if random.uniform(0, 1) < 0.3:
                            tables = [x for x in tables if x != table_to_drop]
                    table_po, field_po = schema_graph.get_schema_perceived_order(
                        tables, random_table_order=True,
                        random_field_order=self.args.random_field_order)
                else:
                    table_po, field_po = schema_graph.get_schema_perceived_order(
                        tables, random_table_order=False,
                        random_field_order=self.args.random_field_order)
                # Schema feature extraction
                question_encoding = exp.text if self.args.use_picklist else None
                schema_features, matched_values = schema_graph.get_serialization(
                    self.tu, flatten_features=True, table_po=table_po, field_po=field_po,
                    use_typed_field_markers=self.args.use_typed_field_markers,
                    use_graph_encoding=self.args.use_graph_encoding,
                    question_encoding=question_encoding,
                    top_k_matches=self.args.top_k_picklist_matches,
                    num_values_per_field=self.args.num_values_per_field,
                    no_anchor_text=self.args.no_anchor_text,
                    verbose=False)
                ptr_input_tokens, ptr_input_values, num_excluded_tables, num_excluded_fields = \
                    get_table_aware_transformer_encoder_inputs(
                        exp.text_ptr_values, exp.text_tokens, schema_features, self.tu)
                assert (len(ptr_input_tokens) <= self.tu.tokenizer.max_len)
                if num_excluded_fields > 0:
                    print('Warning: training input truncated')
                num_included_nodes = schema_graph.get_num_perceived_nodes(tables) + 1 \
                    - num_excluded_tables - num_excluded_fields
                encoder_ptr_input_ids.append(
                    self.tu.tokenizer.convert_tokens_to_ids(ptr_input_tokens))
                if self.args.read_picklist:
                    exp.transformer_output_value_mask, value_features, value_tokens = \
                        get_transformer_output_value_mask(ptr_input_tokens, matched_values, self.tu)
                    transformer_output_value_masks.append(exp.transformer_output_value_mask)
                primary_key_ids.append(
                    schema_graph.get_primary_key_ids(num_included_nodes, table_po, field_po))
                foreign_key_ids.append(
                    schema_graph.get_foreign_key_ids(num_included_nodes, table_po, field_po))
                field_type_ids.append(
                    schema_graph.get_field_type_ids(num_included_nodes, table_po, field_po))
                table_masks.append(
                    schema_graph.get_table_masks(num_included_nodes, table_po, field_po))
                # Value copy feature extraction
                if self.args.read_picklist:
                    constant_memory_features = exp.text_tokens + value_features
                    constant_memory = exp.text_ptr_values + value_tokens
                    exp.text_ptr_values = constant_memory
                else:
                    constant_memory_features = exp.text_tokens
                constant_ptr_value_ids, constant_unique_input_ids = vec.vectorize_ptr_in(
                    constant_memory_features, self.out_vocab)
                encoder_ptr_value_ids.append(
                    constant_ptr_value_ids +
                    [self.out_vocab.size + len(constant_memory_features) + x
                     for x in range(num_included_nodes)])
                program_field_ptr_value_ids = vec.vectorize_field_ptr_out(
                    exp.program_singleton_field_tokens,
                    exp.program_singleton_field_token_types,
                    self.out_vocab,
                    constant_unique_input_ids,
                    max_memory_size=len(constant_memory_features),
                    schema=schema_graph,
                    num_included_nodes=num_included_nodes)
                decoder_ptr_value_ids.append(program_field_ptr_value_ids)
            else:
                encoder_ptr_input_ids = [exp.ptr_input_ids for exp in mini_batch]
                encoder_ptr_value_ids = [exp.ptr_value_ids for exp in mini_batch]
                decoder_ptr_value_ids = [exp.program_text_and_field_ptr_value_ids
                                         for exp in mini_batch] if self.training else None
                primary_key_ids = [exp.primary_key_ids for exp in mini_batch]
                foreign_key_ids = [exp.foreign_key_ids for exp in mini_batch]
                field_type_ids = [exp.field_type_ids for exp in mini_batch]
                table_masks = [exp.table_masks for exp in mini_batch]
                # TODO: here we assume that all nodes in the schema graph are included
                table_pos, table_field_scope = schema_graph.get_table_scopes(
                    schema_graph.num_nodes)
                table_positions.append(table_pos)
                table_field_scopes.append(table_field_scope)
                if self.args.read_picklist:
                    transformer_output_value_masks.append(exp.transformer_output_value_mask)
        encoder_ptr_input_ids = ops.pad_batch(encoder_ptr_input_ids, self.mdl.in_vocab.pad_id)
        encoder_ptr_value_ids = ops.pad_batch(encoder_ptr_value_ids, self.mdl.in_vocab.pad_id)
        schema_memory_masks = ops.pad_batch(schema_memory_masks, pad_id=0) \
            if (self.args.use_pred_tables and not self.training) else (None, None)
        decoder_ptr_value_ids = ops.pad_batch(decoder_ptr_value_ids, self.mdl.out_vocab.pad_id) \
            if self.training else None
        primary_key_ids = ops.pad_batch(primary_key_ids, self.mdl.in_vocab.pad_id)
        foreign_key_ids = ops.pad_batch(foreign_key_ids, self.mdl.in_vocab.pad_id)
        field_type_ids = ops.pad_batch(field_type_ids, self.mdl.in_vocab.pad_id)
        table_masks = ops.pad_batch(table_masks, pad_id=0)
        transformer_output_value_masks = ops.pad_batch(
            transformer_output_value_masks, pad_id=0, dtype=torch.uint8) \
            if self.args.read_picklist else (None, None)
        if not self.training:
            table_positions = ops.pad_batch(table_positions, pad_id=-1) \
                if self.args.process_sql_in_execution_order else (None, None)
            table_field_scopes = ops.pad_batch_2D(table_field_scopes, pad_id=0) \
                if self.args.process_sql_in_execution_order else (None, None)
        graphs = None
        return encoder_input_ids, decoder_input_ids, encoder_ptr_input_ids, encoder_ptr_value_ids, \
            decoder_ptr_value_ids, transformer_output_value_masks, schema_memory_masks, graphs, \
            (primary_key_ids, foreign_key_ids, field_type_ids, table_masks,
             table_positions, table_field_scopes, field_table_pos), table_samples
    elif self.model_id in [SEQ2SEQ_PG]:
        encoder_ptr_input_ids = [exp.ptr_input_ids for exp in mini_batch]
        encoder_ptr_value_ids = [exp.ptr_value_ids for exp in mini_batch]
        decoder_ptr_value_ids = [exp.program_text_ptr_value_ids for exp in mini_batch]
        encoder_ptr_input_ids = ops.pad_batch(encoder_ptr_input_ids, self.mdl.in_vocab.pad_id)
        encoder_ptr_value_ids = ops.pad_batch(encoder_ptr_value_ids, self.mdl.in_vocab.pad_id)
        decoder_ptr_value_ids = ops.pad_batch(decoder_ptr_value_ids, self.mdl.out_vocab.pad_id)
        return encoder_input_ids, decoder_input_ids, encoder_ptr_input_ids, encoder_ptr_value_ids, \
            decoder_ptr_value_ids
    else:
        raise NotImplementedError
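# The get_encoder_attn_mask helper in format_batch above walks the serialized
# schema and keeps only the question tokens, the asterisk node, and the nodes
# belonging to the selected tables. The following is a minimal, self-contained
# sketch of that bookkeeping (hypothetical helper name and a made-up node
# layout; not part of the original code):
def build_encoder_attn_mask_sketch(num_text_tokens, table_masks, selected_table_pos):
    mask = [1] * num_text_tokens     # question tokens are always attendable
    mask.append(1)                   # asterisk marker
    is_selected_table = False
    for j in range(1, len(table_masks)):
        if j in selected_table_pos:  # a selected table starts here
            mask.append(1)
            is_selected_table = True
        elif table_masks[j] == 1:    # an unselected table starts here
            mask.append(0)
            is_selected_table = False
        else:                        # field node: inherit its table's decision
            mask.append(1 if is_selected_table else 0)
    return mask


# Two question tokens, a schema with two tables (nodes 1 and 4) where only the
# second table (node 4) and its field (node 5) are kept:
print(build_encoder_attn_mask_sketch(2, [0, 1, 0, 0, 1, 0], {4}))
# -> [1, 1, 1, 0, 0, 0, 1, 1]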
def train(train_data, dev_data):
    # Model
    model_dir = get_model_dir(args)
    if not os.path.exists(model_dir):
        os.mkdir(model_dir)
    trans_checker = TranslatabilityChecker(args)
    trans_checker.cuda()
    ops.initialize_module(trans_checker, 'xavier')

    # Hyperparameters
    batch_size = min(len(train_data), 12)
    num_peek_epochs = 1

    # Loss functions: binary cross entropy for the translatability flag and
    # masked cross entropy for the confusion span boundaries
    loss_fun = nn.BCELoss()
    span_extract_pad_id = -100
    span_extract_loss_fun = MaskedCrossEntropyLoss(span_extract_pad_id)

    # Optimizer: the pretrained encoder ('trans_parameters') is fine-tuned with
    # a separate, smaller learning rate.
    optimizer = optim.Adam(
        [{'params': [p for n, p in trans_checker.named_parameters()
                     if 'trans_parameters' not in n and p.requires_grad]},
         {'params': [p for n, p in trans_checker.named_parameters()
                     if 'trans_parameters' in n and p.requires_grad],
          'lr': args.bert_finetune_rate}],
        lr=args.learning_rate)
    lr_scheduler = lrs.LinearScheduler(
        optimizer,
        [args.warmup_init_lr, args.warmup_init_ft_lr],
        [args.num_warmup_steps, args.num_warmup_steps],
        args.num_steps)

    best_dev_metrics = 0
    for epoch_id in range(args.num_epochs):
        random.shuffle(train_data)
        trans_checker.train()
        optimizer.zero_grad()
        epoch_losses = []
        for i in tqdm(range(0, len(train_data), batch_size)):
            mini_batch = train_data[i: i + batch_size]
            _, text_masks = ops.pad_batch([exp.text_ids for exp in mini_batch], bu.pad_id)
            encoder_input_ids = ops.pad_batch(
                [exp.ptr_input_ids for exp in mini_batch], bu.pad_id)
            # A span starting at position 0 marks a translatable question
            target_ids = ops.int_var_cuda(
                [1 if exp.span_ids[0] == 0 else 0 for exp in mini_batch])
            target_span_ids, _ = ops.pad_batch([exp.span_ids for exp in mini_batch], bu.pad_id)
            # Replace the span targets of translatable examples with the pad id
            # so they do not contribute to the span extraction loss
            target_span_ids = target_span_ids * (1 - target_ids.unsqueeze(1)) + \
                target_ids.unsqueeze(1).expand_as(target_span_ids) * span_extract_pad_id
            output, span_extract_output = trans_checker(encoder_input_ids, text_masks)
            loss = loss_fun(output, target_ids.unsqueeze(1).float())
            span_extract_loss = span_extract_loss_fun(span_extract_output, target_span_ids)
            loss += span_extract_loss
            loss.backward()
            epoch_losses.append(float(loss))
            if args.grad_norm > 0:
                nn.utils.clip_grad_norm_(trans_checker.parameters(), args.grad_norm)
            lr_scheduler.step()
            optimizer.step()
            optimizer.zero_grad()

        with torch.no_grad():
            if epoch_id % num_peek_epochs == 0:
                stdout_msg = 'Epoch {}: average training loss = {}'.format(
                    epoch_id, np.mean(epoch_losses))
                print(stdout_msg)
                pred_trans, pred_spans = trans_checker.inference(dev_data)
                targets = [1 if exp.span_ids[0] == 0 else 0 for exp in dev_data]
                target_spans = [exp.span_ids for exp in dev_data]
                trans_acc = translatablity_eval(pred_trans, targets)
                print('Dev translatability accuracy = {}'.format(trans_acc))
                if trans_acc > best_dev_metrics:
                    model_path = os.path.join(model_dir, 'model-best.tar')
                    trans_checker.save_checkpoint(optimizer, lr_scheduler, model_path)
                    best_dev_metrics = trans_acc
                span_acc, prec, recall, f1 = span_eval(pred_spans, target_spans)
                print('Dev span accuracy = {}'.format(span_acc))
                print('Dev span precision = {}'.format(prec))
                print('Dev span recall = {}'.format(recall))
                print('Dev span F1 = {}'.format(f1))
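# The arithmetic above that blanks out the span targets of translatable
# questions is easy to check on toy tensors. A minimal sketch (plain torch
# tensors standing in for ops.int_var_cuda / ops.pad_batch outputs; not part of
# the original code), which also shows why -100 works as an ignore value for a
# cross-entropy-style loss:
import torch
import torch.nn as nn

span_extract_pad_id = -100
target_span_ids = torch.tensor([[0, 0],    # translatable question (span starts at 0)
                                [4, 7]])   # untranslatable: confusion span (4, 7)
target_ids = torch.tensor([1, 0])          # 1 = translatable, 0 = untranslatable

# Same formula as in the training loop: keep real spans, replace the rest.
target_span_ids = target_span_ids * (1 - target_ids.unsqueeze(1)) + \
    target_ids.unsqueeze(1).expand_as(target_span_ids) * span_extract_pad_id
print(target_span_ids)  # [[-100, -100], [4, 7]]

# Assuming MaskedCrossEntropyLoss ignores targets equal to its pad id, this is
# equivalent to nn.CrossEntropyLoss(ignore_index=-100) on flattened logits, so
# the translatable example contributes nothing to the span loss.
span_logits = torch.randn(2, 2, 10)  # [batch, start/end, encoder_seq_len]
loss = nn.CrossEntropyLoss(ignore_index=span_extract_pad_id)(
    span_logits.view(-1, span_logits.size(-1)), target_span_ids.view(-1))
print(float(loss))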