class COPAModel(nn.Module):
    def __init__(self, model='bert', model_size='base', pool_method='avg',
                 emb_size=10, **kwargs):
        super(COPAModel, self).__init__()
        self.pool_method = pool_method
        self.encoder = Encoder(model=model, model_size=model_size,
                               fine_tune=False, cased=False)
        self.rel_type_emb = nn.Embedding(2, emb_size)
        sent_repr_size = get_repr_size(self.encoder.hidden_size,
                                       method=pool_method)
        self.label_net = nn.Sequential(
            nn.Linear(3 * sent_repr_size + emb_size, 1),
            nn.Sigmoid()
        )
        self.training_criterion = nn.BCELoss()

    def get_other_params(self):
        core_encoder_param_names = set()
        for name, param in self.encoder.model.named_parameters():
            if param.requires_grad:
                core_encoder_param_names.add(name)

        other_params = []
        print("\nParams outside core transformer params:\n")
        for name, param in self.named_parameters():
            if param.requires_grad and name not in core_encoder_param_names:
                print(name, param.data.size())
                other_params.append(param)
        print("\n")
        return other_params

    def get_core_params(self):
        return self.encoder.model.parameters()

    def forward(self, batch_data):
        event, event_lens = batch_data.event
        event_emb = self.encoder.get_sentence_repr(
            self.encoder(event.cuda()), event_lens.cuda(),
            method=self.pool_method)

        hyp_event, hyp_event_lens = batch_data.hyp_event
        hyp_event_emb = self.encoder.get_sentence_repr(
            self.encoder(hyp_event.cuda()), hyp_event_lens.cuda(),
            method=self.pool_method)

        rel_type_emb = self.rel_type_emb(batch_data.type_event.cuda())

        pred_label = self.label_net(torch.cat(
            [event_emb, hyp_event_emb, event_emb * hyp_event_emb, rel_type_emb],
            dim=-1))
        pred_label = torch.squeeze(pred_label, dim=-1)

        loss = self.training_criterion(pred_label,
                                       batch_data.label.cuda().float())
        if self.training:
            return loss
        else:
            return loss, pred_label
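# --- Illustrative sketch (not part of the repo) ---
# The concatenation in COPAModel.forward is [event, hypothesis,
# event * hypothesis, relation-type embedding], which is why label_net expects
# 3 * sent_repr_size + emb_size inputs. A standalone dimension check with dummy
# tensors; sent_repr_size=768 and emb_size=10 are assumed values, not read from
# a real Encoder.
import torch
import torch.nn as nn

sent_repr_size, emb_size, batch_size = 768, 10, 4
event_emb = torch.randn(batch_size, sent_repr_size)
hyp_event_emb = torch.randn(batch_size, sent_repr_size)
rel_type_emb = nn.Embedding(2, emb_size)(torch.randint(0, 2, (batch_size,)))

features = torch.cat(
    [event_emb, hyp_event_emb, event_emb * hyp_event_emb, rel_type_emb], dim=-1)
assert features.shape[-1] == 3 * sent_repr_size + emb_size

label_net = nn.Sequential(nn.Linear(3 * sent_repr_size + emb_size, 1), nn.Sigmoid())
pred_label = torch.squeeze(label_net(features), dim=-1)
print(pred_label.shape)  # torch.Size([4])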
def __init__(self, model='bert', model_size='base', just_last_layer=False,
             span_dim=256, pool_method='avg', fine_tune=False, num_spans=1,
             **kwargs):
    super(CorefModel, self).__init__()
    self.pool_method = pool_method
    self.num_spans = num_spans
    self.just_last_layer = just_last_layer
    self.encoder = Encoder(model=model, model_size=model_size,
                           fine_tune=True, cased=True)

    self.span_net = nn.ModuleDict()
    self.span_net['0'] = get_span_module(
        method=pool_method, input_dim=self.encoder.hidden_size,
        use_proj=True, proj_dim=span_dim)
    self.pooled_dim = self.span_net['0'].get_output_dim()

    self.label_net = nn.Sequential(
        nn.Linear(2 * self.pooled_dim, span_dim),
        nn.Tanh(),
        nn.LayerNorm(span_dim),
        nn.Dropout(0.2),
        nn.Linear(span_dim, 1),
        nn.Sigmoid())

    self.training_criterion = nn.BCELoss()
    val_iter, test_iter = data.BucketIterator.splits(
        (val, test), batch_size=eval_batch_size, sort_within_batch=True,
        shuffle=False, repeat=False)

    label_field.build_vocab(train)
    num_labels = len(label_field.vocab.itos)
    return (train_iter, val_iter, test_iter, num_labels)


if __name__ == '__main__':
    from encoders.pretrained_transformers import Encoder

    encoder = Encoder(cased=True)
    path = "/home/shtoshni/Research/hackathon_2019/tasks/srl/data"
    train_iter, val_iter, test_iter, num_labels = SRLDataset.iters(
        path, encoder, train_frac=1.0)

    print("Train size:", len(train_iter.data()))
    print("Val size:", len(val_iter.data()))
    print("Test size:", len(test_iter.data()))

    for batch_data in train_iter:
        print(batch_data.text[0].shape)
        print(batch_data.span1.shape)
        print(batch_data.span2.shape)
        print(batch_data.label.shape)

        text, text_len = batch_data.text
def main():
    hp = parse_args()

    # Setup model directories
    model_name = get_model_name(hp)
    model_path = path.join(hp.model_dir, model_name)
    best_model_path = path.join(model_path, 'best_models')
    if not path.exists(model_path):
        os.makedirs(model_path)
    if not path.exists(best_model_path):
        os.makedirs(best_model_path)

    # Set random seed
    torch.manual_seed(hp.seed)

    # Hacky way of assigning the number of labels.
    encoder = Encoder(model=hp.model, model_size=hp.model_size,
                      fine_tune=hp.fine_tune, cased=True)

    # Load data
    logging.info("Loading data")
    train_iter, val_iter, test_iter, num_labels = SRLDataset.iters(
        hp.data_dir, encoder, batch_size=hp.batch_size,
        eval_batch_size=hp.eval_batch_size, train_frac=hp.train_frac)
    logging.info("Data loaded")

    # Initialize the model
    model = SRLModel(encoder, num_labels=num_labels, **vars(hp)).cuda()
    sys.stdout.flush()

    if not hp.fine_tune:
        optimizer = torch.optim.Adam(model.get_other_params(), lr=hp.lr)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=hp.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', patience=5, factor=0.5, verbose=True)

    steps_done = 0
    max_f1 = 0
    init_num_stuck_evals = 0
    num_steps = (hp.n_epochs * len(train_iter.data())) // hp.real_batch_size
    # Quantize the number of training steps to eval steps
    num_steps = (num_steps // hp.eval_steps) * hp.eval_steps
    logging.info("Total training steps: %d" % num_steps)

    location = path.join(model_path, "model.pt")
    if path.exists(location):
        logging.info("Loading previous checkpoint")
        checkpoint = torch.load(location)
        model.encoder.weighing_params = checkpoint['weighing_params']
        model.span_net.load_state_dict(checkpoint['span_net'])
        model.label_net.load_state_dict(checkpoint['label_net'])
        if hp.fine_tune:
            model.encoder.model.load_state_dict(checkpoint['encoder'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        steps_done = checkpoint['steps_done']
        init_num_stuck_evals = checkpoint['num_stuck_evals']
        max_f1 = checkpoint['max_f1']
        torch.set_rng_state(checkpoint['rng_state'])
        logging.info("Steps done: %d, Max F1: %.3f" % (steps_done, max_f1))

    if not hp.eval:
        train(hp, model, train_iter, val_iter, optimizer, scheduler,
              model_path, best_model_path, init_steps=steps_done,
              max_f1=max_f1, eval_steps=hp.eval_steps, num_steps=num_steps,
              init_num_stuck_evals=init_num_stuck_evals)

    val_f1, test_f1 = final_eval(hp, model, best_model_path, val_iter, test_iter)

    perf_dir = path.join(hp.model_dir, "perf")
    if not path.exists(perf_dir):
        os.makedirs(perf_dir)

    if hp.slurm_id:
        perf_file = path.join(perf_dir, hp.slurm_id + ".txt")
    else:
        perf_file = path.join(model_path, "perf.txt")

    with open(perf_file, "w") as f:
        f.write("%s\n" % (model_path))
        f.write("%s\t%.4f\n" % ("Valid", val_f1))
        f.write("%s\t%.4f\n" % ("Test", test_f1))
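# --- Illustrative sketch (not part of the repo) ---
# Checkpoint-saving counterpart implied by the loading branch above. The actual
# save presumably lives inside train(); the dictionary keys below simply mirror
# the ones read back when resuming, and the function name and argument names
# are hypothetical.
def save_checkpoint(model, optimizer, scheduler, steps_done, num_stuck_evals,
                    max_f1, location, fine_tune=False):
    checkpoint = {
        'weighing_params': model.encoder.weighing_params,
        'span_net': model.span_net.state_dict(),
        'label_net': model.label_net.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'steps_done': steps_done,
        'num_stuck_evals': num_stuck_evals,
        'max_f1': max_f1,
        'rng_state': torch.get_rng_state(),
    }
    if fine_tune:
        checkpoint['encoder'] = model.encoder.model.state_dict()
    torch.save(checkpoint, location)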
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
handler = logging.FileHandler(
    os.path.join(args.model_path, args.model_name + '.log'), 'a')
handler.setLevel(logging.INFO)
handler.setFormatter(formatter)
logger.addHandler(handler)
console = logging.StreamHandler()
console.setLevel(logging.INFO)
console.setFormatter(formatter)
logger.addHandler(console)
logger.propagate = False

# create data sets, tokenizers, and data loaders
encoder = Encoder(args.model_type, args.model_size, args.cased,
                  use_proj=args.use_proj, proj_dim=args.proj_dim)
data_loader_path = os.path.join(args.model_path, args.model_name + '.loader.pt')
if os.path.exists(data_loader_path):
    logger.info('Loading datasets.')
    data_info = torch.load(data_loader_path)
    data_loader = data_info['data_loader']
    ConstituentDataset.label_dict = data_info['label_dict']
    ConstituentDataset.encoder = encoder
else:
    logger.info('Creating datasets.')
    data_set = dict()
    data_loader = dict()
    for split in ['train', 'development', 'test']:
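# --- Illustrative sketch (not part of the repo) ---
# Cache-writing counterpart implied by the torch.load branch above; the real
# code in the else-branch presumably does something equivalent once the data
# loaders and label dictionary have been built.
torch.save({'data_loader': data_loader,
            'label_dict': ConstituentDataset.label_dict},
           data_loader_path)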
        shuffle=True, repeat=False)
    val_iter, test_iter = data.BucketIterator.splits(
        (val, test), batch_size=eval_batch_size, sort_within_batch=True,
        shuffle=False, repeat=False)
    return (train_iter, val_iter, test_iter)


if __name__ == '__main__':
    from encoders.pretrained_transformers import Encoder

    encoder = Encoder(cased=False)
    path = "/home/shtoshni/Research/hackathon_2019/tasks/coref/data"
    train_iter, val_iter, test_iter = CorefDataset.iters(path, encoder,
                                                         train_frac=1.0)

    print("Train size:", len(train_iter.data()))
    print("Val size:", len(val_iter.data()))
    print("Test size:", len(test_iter.data()))

    for batch_data in train_iter:
        print(batch_data.text[0].shape)
        print(batch_data.span1.shape)
        print(batch_data.span2.shape)
        print(batch_data.label.shape)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
handler = logging.FileHandler(
    os.path.join(args.model_path, args.model_name + '.log'), 'a')
handler.setLevel(logging.INFO)
handler.setFormatter(formatter)
logger.addHandler(handler)
console = logging.StreamHandler()
console.setLevel(logging.INFO)
console.setFormatter(formatter)
logger.addHandler(console)
logger.propagate = False

# create data sets, tokenizers, and data loaders
encoder = Encoder(args.model_type, args.model_size, args.cased,
                  use_proj=False, fine_tune=args.fine_tune)
data_loader_path = os.path.join(args.model_path, args.model_name + '.loader.pt')
if os.path.exists(data_loader_path):
    logger.info('Loading datasets.')
    data_info = torch.load(data_loader_path)
    data_loader = data_info['data_loader']
    ### To be removed
    for split in ['train', 'development', 'test']:
        data_loader[split] = DataLoader(data_loader[split].dataset,
                                        args.batch_size,
                                        collate_fn=collate_fn,
                                        shuffle=(split == 'train'))
    ### End to be removed
def collate_fn(data):
    sents, spans, labels = list(zip(*data))
    max_length = max(item.shape[1] for item in sents)
    pad_id = ConstituentDataset.encoder.tokenizer.pad_token_id
    batch_size = len(sents)
    padded_sents = pad_id * torch.ones(batch_size, max_length).long()
    for i, sent in enumerate(sents):
        padded_sents[i, :sent.shape[1]] = sent[0, :]
    spans = torch.cat(spans, dim=0)
    labels = torch.tensor(labels)
    return padded_sents, spans, labels


# unit test
if __name__ == '__main__':
    from torch.utils.data import DataLoader
    from encoders.pretrained_transformers import Encoder

    encoder = Encoder('bert', 'base', True)
    for split in ['train', 'development', 'test']:
        dataset = ConstituentDataset(
            f'tasks/constclass/data/debug/{split}.json', encoder)
        data_loader = DataLoader(dataset, 64, collate_fn=collate_fn)
        for sents, spans, labels in data_loader:
            pass
        print(f'Split "{split}" has passed the unit test '
              f'with {len(dataset)} instances.')

    from IPython import embed
    embed(using=False)
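# --- Illustrative sketch (not part of the repo) ---
# Standalone demonstration of the padding behaviour in collate_fn using plain
# tensors; pad_id=0 is an assumed value, the real one comes from the encoder's
# tokenizer (pad_token_id).
import torch

pad_id = 0
sents = [torch.tensor([[5, 6, 7]]), torch.tensor([[8, 9]])]  # each shaped (1, seq_len)
max_length = max(item.shape[1] for item in sents)
padded_sents = pad_id * torch.ones(len(sents), max_length).long()
for i, sent in enumerate(sents):
    padded_sents[i, :sent.shape[1]] = sent[0, :]
print(padded_sents)  # tensor([[5, 6, 7], [8, 9, 0]])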