class SeqPrototypicalNetwork(nn.Module):
    def __init__(self, config):
        super(SeqPrototypicalNetwork, self).__init__()
        self.base_path = config['base_path']
        self.early_stopping = config['early_stopping']
        self.lr = config.get('meta_lr', 1e-3)
        self.weight_decay = config.get('meta_weight_decay', 0.0)

        if 'seq' in config['learner_model']:
            self.learner = RNNSequenceModel(config['learner_params'])
        elif 'mlp' in config['learner_model']:
            self.learner = MLPModel(config['learner_params'])
        elif 'bert' in config['learner_model']:
            self.learner = BERTSequenceModel(config['learner_params'])

        self.num_outputs = config['learner_params']['num_outputs']
        self.vectors = config.get('vectors', 'glove')

        if self.vectors == 'elmo':
            self.elmo = Elmo(options_file="https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway_5.5B/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json",
                             weight_file="https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway_5.5B/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5",
                             num_output_representations=1,
                             dropout=0,
                             requires_grad=False)
        elif self.vectors == 'glove':
            self.glove = torchtext.vocab.GloVe(name='840B', dim=300)
        elif self.vectors == 'bert':
            self.bert_tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

        self.loss_fn = {}
        for task in config['learner_params']['num_outputs']:
            self.loss_fn[task] = nn.CrossEntropyLoss(ignore_index=-1)

        if config.get('trained_learner', False):
            self.learner.load_state_dict(torch.load(
                os.path.join(self.base_path, 'saved_models', config['trained_learner'])
            ))
            logger.info('Loaded trained learner model {}'.format(config['trained_learner']))

        self.device = torch.device(config.get('device', 'cpu'))
        self.to(self.device)
        if self.vectors == 'elmo':
            self.elmo.to(self.device)

        self.initialize_optimizer_scheduler()

    def initialize_optimizer_scheduler(self):
        learner_params = [p for p in self.learner.parameters() if p.requires_grad]
        if isinstance(self.learner, BERTSequenceModel):
            self.optimizer = AdamW(learner_params, lr=self.lr, weight_decay=self.weight_decay)
            self.lr_scheduler = get_constant_schedule_with_warmup(self.optimizer, num_warmup_steps=100)
        else:
            self.optimizer = optim.Adam(learner_params, lr=self.lr, weight_decay=self.weight_decay)
            self.lr_scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=500, gamma=0.5)

    def vectorize(self, batch_x, batch_len, batch_y):
        with torch.no_grad():
            if self.vectors == 'elmo':
                char_ids = batch_to_ids(batch_x)
                char_ids = char_ids.to(self.device)
                batch_x = self.elmo(char_ids)['elmo_representations'][0]
            elif self.vectors == 'glove':
                max_batch_len = max(batch_len)
                vec_batch_x = torch.ones((len(batch_x), max_batch_len, 300))
                for i, sent in enumerate(batch_x):
                    sent_emb = self.glove.get_vecs_by_tokens(sent, lower_case_backup=True)
                    vec_batch_x[i, :len(sent_emb)] = sent_emb
                batch_x = vec_batch_x.to(self.device)
            elif self.vectors == 'bert':
                max_batch_len = max(batch_len) + 2
                input_ids = torch.zeros((len(batch_x), max_batch_len)).long()
                for i, sent in enumerate(batch_x):
                    sent_token_ids = self.bert_tokenizer.encode(sent, add_special_tokens=True)
                    input_ids[i, :len(sent_token_ids)] = torch.tensor(sent_token_ids)
                batch_x = input_ids.to(self.device)
        batch_len = torch.tensor(batch_len).to(self.device)
        batch_y = torch.tensor(batch_y).to(self.device)
        return batch_x, batch_len, batch_y

    def forward(self, episodes, updates=1, testing=False):
        query_losses, query_accuracies, query_precisions, query_recalls, query_f1s = [], [], [], [], []
        n_episodes = len(episodes)

        for episode_id, episode in enumerate(episodes):
            # Encode the support set and build one prototype per class
            batch_x, batch_len, batch_y = next(iter(episode.support_loader))
            batch_x, batch_len, batch_y = self.vectorize(batch_x, batch_len, batch_y)

            self.train()

            support_repr, support_label = [], []
            batch_x_repr = self.learner(batch_x, batch_len)
            support_repr.append(batch_x_repr)
            support_label.append(batch_y)

            prototypes = self._build_prototypes(support_repr, support_label, episode.n_classes)

            # Run on query
            query_loss = 0.0
            all_predictions, all_labels = [], []

            for module in self.learner.modules():
                if isinstance(module, nn.Dropout):
                    module.eval()

            for n_batch, (batch_x, batch_len, batch_y) in enumerate(episode.query_loader):
                batch_x, batch_len, batch_y = self.vectorize(batch_x, batch_len, batch_y)
                batch_x_repr = self.learner(batch_x, batch_len)
                output = self._normalized_distances(prototypes, batch_x_repr)
                output = output.view(output.size()[0] * output.size()[1], -1)
                batch_y = batch_y.view(-1)
                loss = self.loss_fn[episode.base_task](output, batch_y)
                query_loss += loss.item()

                if not testing:
                    self.optimizer.zero_grad()
                    loss.backward(retain_graph=True)
                    self.optimizer.step()
                    self.lr_scheduler.step()

                relevant_indices = torch.nonzero(batch_y != -1).view(-1).detach()
                all_predictions.extend(make_prediction(output[relevant_indices]).cpu())
                all_labels.extend(batch_y[relevant_indices].cpu())

            query_loss /= n_batch + 1

            # Calculate metrics
            accuracy, precision, recall, f1_score = utils.calculate_metrics(all_predictions, all_labels,
                                                                            binary=False)

            logger.info('Episode {}/{}, task {} [query set]: Loss = {:.5f}, accuracy = {:.5f}, precision = {:.5f}, '
                        'recall = {:.5f}, F1 score = {:.5f}'.format(episode_id + 1, n_episodes, episode.task_id,
                                                                    query_loss, accuracy, precision, recall,
                                                                    f1_score))

            query_losses.append(query_loss)
            query_accuracies.append(accuracy)
            query_precisions.append(precision)
            query_recalls.append(recall)
            query_f1s.append(f1_score)

        return query_losses, query_accuracies, query_precisions, query_recalls, query_f1s

    def _build_prototypes(self, data_repr, data_label, num_outputs):
        # Average the support representations per class; classes without support
        # tokens keep a zero prototype.
        n_dim = data_repr[0].shape[2]
        data_repr = torch.cat(tuple([x.view(-1, n_dim) for x in data_repr]), dim=0)
        data_label = torch.cat(tuple([y.view(-1) for y in data_label]), dim=0)

        prototypes = torch.zeros((num_outputs, n_dim), device=self.device)

        for c in range(num_outputs):
            idx = torch.nonzero(data_label == c).view(-1)
            if idx.nelement() != 0:
                prototypes[c] = torch.mean(data_repr[idx], dim=0)

        return prototypes

    def _normalized_distances(self, prototypes, q):
        # Negative squared Euclidean distance to each prototype, used as class logits
        d = torch.stack(
            tuple([q.sub(p).pow(2).sum(dim=-1) for p in prototypes]),
            dim=-1
        )
        return d.neg()
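# --- Usage sketch (not from the original repository) ---
# A minimal, self-contained illustration of the prototype / distance arithmetic used by
# SeqPrototypicalNetwork above, on toy tensors. Shapes mirror what the learner emits
# ([batch, seq_len, embed_dim]) and -1 marks padded/ignored tokens, as in the class.
# The names below (toy_repr, toy_labels, prototype_logits_example) are hypothetical.

import torch

def prototype_logits_example(n_classes=3, n_dim=4):
    toy_repr = torch.randn(2, 5, n_dim)                    # [batch, seq_len, embed_dim]
    toy_labels = torch.randint(-1, n_classes, (2, 5))      # -1 = padding label

    flat_repr = toy_repr.view(-1, n_dim)
    flat_labels = toy_labels.view(-1)

    # Class prototype = mean of the support representations belonging to that class
    prototypes = torch.zeros(n_classes, n_dim)
    for c in range(n_classes):
        idx = torch.nonzero(flat_labels == c).view(-1)
        if idx.nelement() != 0:
            prototypes[c] = flat_repr[idx].mean(dim=0)

    # Negative squared Euclidean distance to each prototype acts as the logit,
    # mirroring _normalized_distances above
    logits = torch.stack([flat_repr.sub(p).pow(2).sum(dim=-1) for p in prototypes], dim=-1).neg()
    return prototypes, logits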
class SeqMetaModel(nn.Module):
    def __init__(self, config):
        super(SeqMetaModel, self).__init__()
        self.base_path = config['base_path']
        self.learner_lr = config.get('learner_lr', 1e-3)
        self.output_lr = config.get('output_lr', 0.1)

        if 'seq' in config['learner_model']:
            self.learner = RNNSequenceModel(config['learner_params'])
        elif 'mlp' in config['learner_model']:
            self.learner = MLPModel(config['learner_params'])
        elif 'bert' in config['learner_model']:
            self.learner = BERTSequenceModel(config['learner_params'])

        self.proto_maml = config.get('proto_maml', False)
        self.fomaml = config.get('fomaml', False)
        self.vectors = config.get('vectors', 'glove')

        if self.vectors == 'elmo':
            self.elmo = Elmo(options_file="https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway_5.5B/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json",
                             weight_file="https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway_5.5B/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5",
                             num_output_representations=1,
                             dropout=0,
                             requires_grad=False)
        elif self.vectors == 'glove':
            self.glove = torchtext.vocab.GloVe(name='840B', dim=300)
        elif self.vectors == 'bert':
            self.bert_tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

        self.learner_loss = {}
        for task in config['learner_params']['num_outputs']:
            self.learner_loss[task] = nn.CrossEntropyLoss(ignore_index=-1)

        self.output_layer_weight = None
        self.output_layer_bias = None

        if config.get('trained_learner', False):
            self.learner.load_state_dict(torch.load(
                os.path.join(self.base_path, 'saved_models', config['trained_learner'])
            ))
            logger.info('Loaded trained learner model {}'.format(config['trained_learner']))

        self.device = torch.device(config.get('device', 'cpu'))
        self.to(self.device)

        if self.proto_maml:
            logger.info('Initialization of output layer weights as per prototypical networks turned on')

        params = [p for p in self.learner.parameters() if p.requires_grad]
        self.learner_optimizer = optim.SGD(params, lr=self.learner_lr)

    def vectorize(self, batch_x, batch_len, batch_y):
        with torch.no_grad():
            if self.vectors == 'elmo':
                char_ids = batch_to_ids(batch_x)
                char_ids = char_ids.to(self.device)
                batch_x = self.elmo(char_ids)['elmo_representations'][0]
            elif self.vectors == 'glove':
                max_batch_len = max(batch_len)
                vec_batch_x = torch.ones((len(batch_x), max_batch_len, 300))
                for i, sent in enumerate(batch_x):
                    sent_emb = self.glove.get_vecs_by_tokens(sent, lower_case_backup=True)
                    vec_batch_x[i, :len(sent_emb)] = sent_emb
                batch_x = vec_batch_x.to(self.device)
            elif self.vectors == 'bert':
                max_batch_len = max(batch_len) + 2
                input_ids = torch.zeros((len(batch_x), max_batch_len)).long()
                for i, sent in enumerate(batch_x):
                    sent_token_ids = self.bert_tokenizer.encode(sent, add_special_tokens=True)
                    input_ids[i, :len(sent_token_ids)] = torch.tensor(sent_token_ids)
                batch_x = input_ids.to(self.device)
        batch_len = torch.tensor(batch_len).to(self.device)
        batch_y = torch.tensor(batch_y).to(self.device)
        return batch_x, batch_len, batch_y

    def forward(self, episodes, updates=1, testing=False):
        support_losses = []
        query_losses, query_accuracies, query_precisions, query_recalls, query_f1s = [], [], [], [], []
        n_episodes = len(episodes)

        for episode_id, episode in enumerate(episodes):
            self.initialize_output_layer(episode.n_classes)

            batch_x, batch_len, batch_y = next(iter(episode.support_loader))
            batch_x, batch_len, batch_y = self.vectorize(batch_x, batch_len, batch_y)

            if self.proto_maml:
                output_repr = self.learner(batch_x, batch_len)
                init_weights, init_bias = self._initialize_with_proto_weights(output_repr, batch_y,
                                                                              episode.n_classes)
            else:
                init_weights, init_bias = 0, 0

            with torch.backends.cudnn.flags(enabled=self.fomaml or testing or not isinstance(self.learner, RNNSequenceModel)), \
                    higher.innerloop_ctx(self.learner, self.learner_optimizer,
                                         copy_initial_weights=False,
                                         track_higher_grads=(not self.fomaml and not testing)) as (flearner, diffopt):
                all_predictions, all_labels = [], []
                self.train()
                flearner.train()
                flearner.zero_grad()

                # Inner loop: adapt the functional copy of the learner and the
                # episode-specific output layer on the support set
                for i in range(updates):
                    output = flearner(batch_x, batch_len)
                    output = self.output_layer(output, init_weights, init_bias)
                    output = output.view(output.size()[0] * output.size()[1], -1)
                    batch_y = batch_y.view(-1)
                    loss = self.learner_loss[episode.base_task](output, batch_y)

                    # Update the output layer parameters
                    output_weight_grad, output_bias_grad = torch.autograd.grad(loss,
                                                                               [self.output_layer_weight,
                                                                                self.output_layer_bias],
                                                                               retain_graph=True)
                    self.output_layer_weight = self.output_layer_weight - self.output_lr * output_weight_grad
                    self.output_layer_bias = self.output_layer_bias - self.output_lr * output_bias_grad

                    # Update the shared parameters
                    diffopt.step(loss)

                relevant_indices = torch.nonzero(batch_y != -1).view(-1).detach()
                pred = make_prediction(output[relevant_indices].detach()).cpu()
                all_predictions.extend(pred)
                all_labels.extend(batch_y[relevant_indices].cpu())

                support_loss = loss.item()

                accuracy, precision, recall, f1_score = utils.calculate_metrics(all_predictions, all_labels,
                                                                                binary=False)

                logger.info('Episode {}/{}, task {} [support_set]: Loss = {:.5f}, accuracy = {:.5f}, precision = {:.5f}, '
                            'recall = {:.5f}, F1 score = {:.5f}'.format(episode_id + 1, n_episodes, episode.task_id,
                                                                        support_loss, accuracy, precision, recall,
                                                                        f1_score))

                query_loss = 0.0
                all_predictions, all_labels = [], []

                # Disable dropout
                for module in flearner.modules():
                    if isinstance(module, nn.Dropout):
                        module.eval()

                # Outer loop: evaluate the adapted learner on the query set and
                # compute meta-gradients w.r.t. the initial (shared) parameters
                for n_batch, (batch_x, batch_len, batch_y) in enumerate(episode.query_loader):
                    batch_x, batch_len, batch_y = self.vectorize(batch_x, batch_len, batch_y)
                    output = flearner(batch_x, batch_len)
                    output = self.output_layer(output, init_weights, init_bias)
                    output = output.view(output.size()[0] * output.size()[1], -1)
                    batch_y = batch_y.view(-1)
                    loss = self.learner_loss[episode.base_task](output, batch_y)

                    if not testing:
                        if self.fomaml:
                            meta_grads = torch.autograd.grad(loss,
                                                             [p for p in flearner.parameters() if p.requires_grad],
                                                             retain_graph=self.proto_maml)
                        else:
                            meta_grads = torch.autograd.grad(loss,
                                                             [p for p in flearner.parameters(time=0) if p.requires_grad],
                                                             retain_graph=self.proto_maml)
                        if self.proto_maml:
                            proto_grads = torch.autograd.grad(loss,
                                                              [p for p in self.learner.parameters() if p.requires_grad])
                            meta_grads = [mg + pg for (mg, pg) in zip(meta_grads, proto_grads)]
                    query_loss += loss.item()

                    relevant_indices = torch.nonzero(batch_y != -1).view(-1).detach()
                    pred = make_prediction(output[relevant_indices].detach()).cpu()
                    all_predictions.extend(pred)
                    all_labels.extend(batch_y[relevant_indices].cpu())

                query_loss /= n_batch + 1

            accuracy, precision, recall, f1_score = utils.calculate_metrics(all_predictions, all_labels,
                                                                            binary=False)

            logger.info('Episode {}/{}, task {} [query set]: Loss = {:.5f}, accuracy = {:.5f}, precision = {:.5f}, '
                        'recall = {:.5f}, F1 score = {:.5f}'.format(episode_id + 1, n_episodes, episode.task_id,
                                                                    query_loss, accuracy, precision, recall,
                                                                    f1_score))

            support_losses.append(support_loss)
            query_losses.append(query_loss)
            query_accuracies.append(accuracy)
            query_precisions.append(precision)
            query_recalls.append(recall)
            query_f1s.append(f1_score)

            # Accumulate meta-gradients on the shared learner parameters
            if not testing:
                for param, meta_grad in zip([p for p in self.learner.parameters() if p.requires_grad], meta_grads):
                    if param.grad is not None:
                        param.grad += meta_grad.detach()
                    else:
                        param.grad = meta_grad.detach()

        # Average the accumulated gradients
        if not testing:
            for param in self.learner.parameters():
                if param.requires_grad:
                    param.grad /= len(query_accuracies)

        if testing:
            return support_losses, query_accuracies, query_precisions, query_recalls, query_f1s
        else:
            return query_losses, query_accuracies, query_precisions, query_recalls, query_f1s

    def initialize_output_layer(self, n_classes):
        # Random (uniform) initialization of the episode-specific linear output layer
        if isinstance(self.learner, RNNSequenceModel):
            stdv = 1.0 / math.sqrt(self.learner.hidden_size // 4)
            self.output_layer_weight = -2 * stdv * torch.rand((n_classes, self.learner.hidden_size // 4),
                                                              device=self.device) + stdv
            self.output_layer_bias = -2 * stdv * torch.rand(n_classes, device=self.device) + stdv
        elif isinstance(self.learner, MLPModel) or isinstance(self.learner, BERTSequenceModel):
            stdv = 1.0 / math.sqrt(self.learner.hidden_size)
            self.output_layer_weight = -2 * stdv * torch.rand((n_classes, self.learner.hidden_size),
                                                              device=self.device) + stdv
            self.output_layer_bias = -2 * stdv * torch.rand(n_classes, device=self.device) + stdv
        self.output_layer_weight.requires_grad = True
        self.output_layer_bias.requires_grad = True

    def _initialize_with_proto_weights(self, support_repr, support_label, n_classes):
        # Proto-MAML: seed the output layer from class prototypes so the linear layer
        # reproduces prototypical-network logits (up to a query-dependent constant)
        prototypes = self._build_prototypes(support_repr, support_label, n_classes)
        weight = 2 * prototypes
        bias = -torch.norm(prototypes, dim=1)**2

        self.output_layer_weight = torch.zeros_like(weight, requires_grad=True)
        self.output_layer_bias = torch.zeros_like(bias, requires_grad=True)

        return weight, bias

    def _build_prototypes(self, data_repr, data_label, num_outputs):
        n_dim = data_repr.shape[2]
        data_repr = data_repr.view(-1, n_dim)
        data_label = data_label.view(-1)

        prototypes = torch.zeros((num_outputs, n_dim), device=self.device)

        for c in range(num_outputs):
            idx = torch.nonzero(data_label == c).view(-1)
            if idx.nelement() != 0:
                prototypes[c] = torch.mean(data_repr[idx], dim=0)

        return prototypes

    def output_layer(self, input, weight, bias):
        # Episode-specific output layer; the weight/bias offsets carry the proto-MAML initialization
        return F.linear(input, self.output_layer_weight + weight, self.output_layer_bias + bias)
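# --- Proto-MAML initialization sketch (not from the original repository) ---
# Why _initialize_with_proto_weights uses weight = 2 * prototypes and
# bias = -||prototype||^2: with those values, F.linear(q, weight, bias) equals the
# negative squared Euclidean distance to each prototype up to a query-dependent
# constant (||q||^2) shared across classes, so the softmax over classes is unchanged.
# Toy tensors only; the function name is illustrative.

import torch
import torch.nn.functional as F

def proto_maml_init_check():
    torch.manual_seed(0)
    prototypes = torch.randn(3, 4)                   # [n_classes, n_dim]
    q = torch.randn(5, 4)                            # query token representations

    weight = 2 * prototypes
    bias = -torch.norm(prototypes, dim=1) ** 2
    linear_logits = F.linear(q, weight, bias)        # 2 * q.p - ||p||^2

    neg_sq_dist = -torch.cdist(q, prototypes) ** 2   # -||q - p||^2

    # The two differ only by ||q||^2 per row, so per-row softmaxes coincide
    diff = linear_logits - neg_sq_dist
    assert torch.allclose(diff, diff[:, :1].expand_as(diff), atol=1e-4)
    return linear_logits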
class NearestNeighborClassifier():
    def __init__(self, config):
        self.vectors = config.get('vectors', 'elmo')
        self.device = torch.device(config.get('device', 'cpu'))

        if self.vectors == 'elmo':
            self.elmo = Elmo(options_file="https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway_5.5B/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json",
                             weight_file="https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway_5.5B/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5",
                             num_output_representations=1,
                             dropout=0,
                             requires_grad=False)
            self.elmo.to(self.device)
        elif self.vectors == 'glove':
            self.glove = torchtext.vocab.GloVe(name='840B', dim=300)
        elif self.vectors == 'bert':
            self.bert_tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
            self.bert = BERTSequenceModel(config['learner_params'])
            self.bert.to(self.device)

        logger.info('Nearest neighbor classifier instantiated')

    def vectorize(self, batch_x, batch_len, batch_y):
        with torch.no_grad():
            if self.vectors == 'elmo':
                char_ids = batch_to_ids(batch_x)
                char_ids = char_ids.to(self.device)
                batch_x = self.elmo(char_ids)['elmo_representations'][0]
            elif self.vectors == 'glove':
                max_batch_len = max(batch_len)
                vec_batch_x = torch.ones((len(batch_x), max_batch_len, 300))
                for i, sent in enumerate(batch_x):
                    sent_emb = self.glove.get_vecs_by_tokens(sent, lower_case_backup=True)
                    vec_batch_x[i, :len(sent_emb)] = sent_emb
                batch_x = vec_batch_x.to(self.device)
            elif self.vectors == 'bert':
                max_batch_len = max(batch_len) + 2
                input_ids = torch.zeros((len(batch_x), max_batch_len)).long()
                for i, sent in enumerate(batch_x):
                    sent_token_ids = self.bert_tokenizer.encode(sent, add_special_tokens=True)
                    input_ids[i, :len(sent_token_ids)] = torch.tensor(sent_token_ids)
                batch_x = input_ids.to(self.device)
                batch_x = self.bert(batch_x, max_batch_len)
        batch_len = torch.tensor(batch_len).to(self.device)
        batch_y = torch.tensor(batch_y).to(self.device)
        return batch_x, batch_len, batch_y

    def training(self, train_episodes, val_episodes):
        return 0

    def testing(self, test_episodes, label_map):
        self.bert.load_state_dict(torch.load('MetaLearningForNER/saved_models/SupervisedLearner-stable.h5'))
        map_to_label = {v: k for k, v in label_map.items()}

        # episode_accuracies, episode_precisions, episode_recalls, episode_f1s = [], [], [], []
        all_true_labels = []
        all_predictions = []
        for episode_id, episode in enumerate(tqdm(test_episodes)):
            batch_x, batch_len, batch_y = next(iter(episode.support_loader))
            support_repr, _, support_labels = self.vectorize(batch_x, batch_len, batch_y)
            support_repr = support_repr.reshape(support_repr.shape[0] * support_repr.shape[1], -1)
            support_labels = support_labels.view(-1)
            support_repr = support_repr[support_labels != -1].cpu().numpy()
            support_labels = support_labels[support_labels != -1].cpu().numpy()

            batch_x, batch_len, batch_y = next(iter(episode.query_loader))
            query_repr, _, true_labels = self.vectorize(batch_x, batch_len, batch_y)
            query_bs, query_seqlen = query_repr.shape[0], query_repr.shape[1]
            query_repr = query_repr.reshape(query_repr.shape[0] * query_repr.shape[1], -1)
            true_labels = true_labels.view(-1)
            # query_repr = query_repr[true_labels != -1].cpu().numpy()
            query_repr = query_repr.cpu().numpy()
            # true_labels = true_labels[true_labels != -1].cpu().numpy()
            true_labels = true_labels.cpu().numpy()

            dist = cdist(query_repr, support_repr, metric='cosine')
            nearest_neighbor = np.argmin(dist, axis=1)
            predictions = support_labels[nearest_neighbor]

            true_labels = true_labels.reshape(query_bs, query_seqlen)
            predictions = predictions.reshape(query_bs, query_seqlen)

            seq_true_labels, seq_predictions = [], []
            for i in range(len(true_labels)):
                true_i = true_labels[i]
                pred_i = predictions[i]
                seq_predictions.append([map_to_label[val] for val in pred_i[true_i != -1]])
                seq_true_labels.append([map_to_label[val] for val in true_i[true_i != -1]])

            all_predictions.extend(seq_predictions)
            all_true_labels.extend(seq_true_labels)

            accuracy = accuracy_score(seq_true_labels, seq_predictions)
            precision = precision_score(seq_true_labels, seq_predictions)
            recall = recall_score(seq_true_labels, seq_predictions)
            f1 = f1_score(seq_true_labels, seq_predictions)

            # logger.info('Episode {}/{}, task {} [query set]: Accuracy = {:.5f}, precision = {:.5f}, '
            #             'recall = {:.5f}, F1 score = {:.5f}'.format(episode_id + 1, len(test_episodes), episode.task_id,
            #                                                         accuracy, precision, recall, f1))

            # episode_accuracies.append(accuracy)
            # episode_precisions.append(precision)
            # episode_recalls.append(recall)
            # episode_f1s.append(f1_score)

        accuracy = accuracy_score(all_true_labels, all_predictions)
        precision = precision_score(all_true_labels, all_predictions)
        recall = recall_score(all_true_labels, all_predictions)
        f1 = f1_score(all_true_labels, all_predictions)

        logger.info('Avg meta-testing metrics: Accuracy = {:.5f}, precision = {:.5f}, recall = {:.5f}, '
                    'F1 score = {:.5f}'.format(accuracy, precision, recall, f1))
        return f1
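# --- Nearest-neighbor sketch (not from the original repository) ---
# A self-contained illustration of the 1-NN step inside testing() above: every query
# token receives the label of its closest support token under cosine distance.
# Only numpy/scipy are assumed; the array names are hypothetical.

import numpy as np
from scipy.spatial.distance import cdist

def nearest_neighbor_example():
    rng = np.random.default_rng(0)
    support_repr = rng.normal(size=(20, 8))          # 20 support tokens, 8-dim vectors
    support_labels = rng.integers(0, 3, size=20)     # labels in {0, 1, 2}
    query_repr = rng.normal(size=(7, 8))             # 7 query tokens

    dist = cdist(query_repr, support_repr, metric='cosine')   # [7, 20] distance matrix
    nearest = np.argmin(dist, axis=1)                          # index of closest support token
    return support_labels[nearest]                             # predicted label per query token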
class SeqSupervisedNetwork(nn.Module):
    def __init__(self, config):
        super(SeqSupervisedNetwork, self).__init__()
        self.base_path = config['base_path']
        self.early_stopping = config['early_stopping']
        self.lr = config.get('meta_lr', 1e-3)
        self.weight_decay = config.get('meta_weight_decay', 0.0)

        if 'seq' in config['learner_model']:
            self.learner = RNNSequenceModel(config['learner_params'])
        elif 'mlp' in config['learner_model']:
            self.learner = MLPModel(config['learner_params'])
        elif 'bert' in config['learner_model']:
            self.learner = BERTSequenceModel(config['learner_params'])

        self.dropout = nn.Dropout(config['learner_params']['dropout_ratio'])
        self.classifier = nn.Linear(config['learner_params']['embed_dim'],
                                    config['learner_params']['num_outputs']['ner'])

        self.num_outputs = config['learner_params']['num_outputs']
        self.vectors = config.get('vectors', 'glove')

        if self.vectors == 'elmo':
            self.elmo = Elmo(options_file="https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway_5.5B/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json",
                             weight_file="https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway_5.5B/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5",
                             num_output_representations=1,
                             dropout=0,
                             requires_grad=False)
        elif self.vectors == 'glove':
            self.glove = torchtext.vocab.GloVe(name='840B', dim=300)
        elif self.vectors == 'bert':
            self.bert_tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

        self.loss_fn = {}
        for task in config['learner_params']['num_outputs']:
            self.loss_fn[task] = nn.CrossEntropyLoss(ignore_index=-1)

        if config.get('trained_learner', False):
            self.learner.load_state_dict(torch.load(
                os.path.join(self.base_path, 'saved_models', config['trained_learner'])
            ))
            self.classifier.load_state_dict(torch.load(
                os.path.join(self.base_path, 'saved_models', config['trained_classifier'])
            ))
            logger.info('Loaded trained learner model {}'.format(config['trained_learner']))

        self.device = torch.device(config.get('device', 'cpu'))
        self.to(self.device)
        if self.vectors == 'elmo':
            self.elmo.to(self.device)

        self.initialize_optimizer_scheduler()

    def initialize_optimizer_scheduler(self):
        learner_params = [p for p in self.learner.parameters() if p.requires_grad]
        learner_params += self.dropout.parameters()
        learner_params += self.classifier.parameters()
        if isinstance(self.learner, BERTSequenceModel):
            self.optimizer = AdamW(learner_params, lr=self.lr, weight_decay=self.weight_decay)
            self.lr_scheduler = get_constant_schedule_with_warmup(self.optimizer, num_warmup_steps=100)
        else:
            self.optimizer = optim.Adam(learner_params, lr=self.lr, weight_decay=self.weight_decay)
            self.lr_scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=500, gamma=0.5)

    def vectorize(self, batch_x, batch_len, batch_y):
        with torch.no_grad():
            if self.vectors == 'elmo':
                char_ids = batch_to_ids(batch_x)
                char_ids = char_ids.to(self.device)
                batch_x = self.elmo(char_ids)['elmo_representations'][0]
            elif self.vectors == 'glove':
                max_batch_len = max(batch_len)
                vec_batch_x = torch.ones((len(batch_x), max_batch_len, 300))
                for i, sent in enumerate(batch_x):
                    sent_emb = self.glove.get_vecs_by_tokens(sent, lower_case_backup=True)
                    vec_batch_x[i, :len(sent_emb)] = sent_emb
                batch_x = vec_batch_x.to(self.device)
            elif self.vectors == 'bert':
                max_batch_len = max(batch_len) + 2
                input_ids = torch.zeros((len(batch_x), max_batch_len)).long()
                for i, sent in enumerate(batch_x):
                    sent_token_ids = self.bert_tokenizer.encode(sent, add_special_tokens=True)
                    input_ids[i, :len(sent_token_ids)] = torch.tensor(sent_token_ids)
                batch_x = input_ids.to(self.device)
        batch_len = torch.tensor(batch_len).to(self.device)
        batch_y = torch.tensor(batch_y).to(self.device)
        return batch_x, batch_len, batch_y

    def forward(self, dataloader, tags=None, testing=False, writer=None):
        if not testing:
            self.train()
        else:
            self.eval()

        avg_loss = 0
        all_predictions, all_labels = [], []

        for batch_id, batch in enumerate(dataloader):
            batch_x, batch_len, batch_y = next(iter(batch))
            batch_x, batch_len, batch_y = self.vectorize(batch_x, batch_len, batch_y)

            # Encode, apply dropout, and classify each token
            batch_x_repr = self.learner(batch_x, batch_len)
            output = self.dropout(batch_x_repr)
            output = self.classifier(output)

            batch_size, seq_len = output.shape[0], output.shape[1]

            output = output.view(batch_size * seq_len, -1)
            batch_y = batch_y.view(-1)
            loss = self.loss_fn['ner'](output, batch_y)
            avg_loss += loss.item()

            if not testing:
                self.optimizer.zero_grad()
                loss.backward(retain_graph=True)
                self.optimizer.step()
                self.lr_scheduler.step()

            output = output.view(batch_size, seq_len, -1)
            batch_y = batch_y.view(batch_size, seq_len)

            # Keep only non-padding tokens (label != -1) for metric computation
            predictions, labels = [], []
            for bid in range(batch_size):
                relevant_indices = torch.nonzero(batch_y[bid] != -1).view(-1).detach()
                predictions.append(list(make_prediction(output[bid][relevant_indices]).detach().cpu().numpy()))
                labels.append(list(batch_y[bid][relevant_indices].detach().cpu().numpy()))

            accuracy, precision, recall, f1_score = utils.calculate_seqeval_metrics(predictions, labels,
                                                                                    tags, binary=False)

            if writer is not None:
                writer.add_scalar('Loss/iter', avg_loss / (batch_id + 1), global_step=batch_id + 1)
                writer.add_scalar('F1/iter', f1_score, global_step=batch_id + 1)

            if (batch_id + 1) % 100 == 0:
                logger.info('Batch {}/{}, task {} [supervised]: Loss = {:.5f}, accuracy = {:.5f}, precision = {:.5f}, '
                            'recall = {:.5f}, F1 score = {:.5f}'.format(batch_id + 1, len(dataloader), 'ner',
                                                                        loss.item(), accuracy, precision, recall,
                                                                        f1_score))

            all_predictions.extend(predictions)
            all_labels.extend(labels)

        avg_loss /= len(dataloader)

        # Calculate metrics
        accuracy, precision, recall, f1_score = utils.calculate_seqeval_metrics(all_predictions, all_labels,
                                                                                binary=False)

        return avg_loss, accuracy, precision, recall, f1_score
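# --- Instantiation sketch (not from the original repository) ---
# A hedged example of the config dictionary SeqSupervisedNetwork reads in __init__.
# Only keys referenced above are shown; the concrete values are placeholders, and
# RNNSequenceModel / BERTSequenceModel may require additional learner_params entries
# defined elsewhere in the repository.

example_config = {
    'base_path': '.',                        # saved models are resolved under <base_path>/saved_models
    'early_stopping': 3,
    'meta_lr': 1e-3,
    'meta_weight_decay': 0.0,
    'learner_model': 'bert',                 # one of 'seq', 'mlp', 'bert'
    'learner_params': {
        'embed_dim': 768,
        'dropout_ratio': 0.1,
        'num_outputs': {'ner': 9},           # one output head / loss per task
    },
    'vectors': 'bert',                       # one of 'elmo', 'glove', 'bert'
    'device': 'cpu',
}

# model = SeqSupervisedNetwork(example_config)
# avg_loss, accuracy, precision, recall, f1 = model(train_dataloader, tags=tag_list)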