def eval():
    ''' Performs test evaluation on the model. '''
    ## Read terminal arguments
    model_name = FLAGS.model_name
    checkpoint_path = FLAGS.checkpoint_path
    output_dir = FLAGS.output_dir
    data_path = FLAGS.data_path
    embedding_path = FLAGS.embedding_path
    assert checkpoint_path is not None, "checkpoint_path is a required argument"
    assert os.path.isfile(checkpoint_path), "Checkpoint does not exist"
    assert os.path.isfile(embedding_path), "Embedding does not exist"
    assert model_name in MODEL_NAMES, "Model name is unknown"

    # Further process terminal arguments
    os.makedirs(output_dir, exist_ok=True)  # create output directory

    # Obtain GloVe word embeddings
    print("Loading GloVe embedding from " + embedding_path)
    glove_emb = data_utils.EmbeddingGlove(embedding_path)

    # Build vocabulary
    vocab = data_utils.Vocabulary()
    vocab.count_glove(glove_emb)
    vocab.build()

    # Obtain SNLI test dataset
    dataset = data_utils.DatasetSnli(data_path)
    dataloader = data_utils.DataLoaderSnli(dataset, vocab)

    # Load network
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device_name)
    print("Device: " + device_name)
    if model_name == 'baseline':
        net = BaselineNet(glove_emb.embedding).to(device)

    # Load checkpoint
    print("Initialising model from " + checkpoint_path)
    state_dict = torch.load(checkpoint_path, map_location=device)
    net.load_state_dict(state_dict)
    print("Network architecture:\n\t{}".format(str(net)))

    # Evaluate SNLI per class
    prem, hyp, label = dataloader.next_batch(len(dataset))
    prem = prem.to(device)
    hyp = hyp.to(device)
    label = label.to(device)
    with torch.no_grad():  # no gradients needed for evaluation
        prediction = net.forward(prem, hyp)
    accuracy_macro, accuracy_micro = accuracy(prediction, label)
    print("Macro accuracy:\t{}\nMicro accuracy:\t{}".format(accuracy_macro, accuracy_micro))
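
# `accuracy` is defined elsewhere in this repo. For reference, a minimal sketch
# of the two-value (macro, micro) variant that eval() above expects --
# hypothetical, under the assumption that `prediction` holds per-class scores
# of shape (batch, n_classes) and `label` holds integer class indices:
def _accuracy_sketch(prediction, label, n_classes=3):
    ''' Hypothetical sketch: returns (macro, micro) accuracy. '''
    pred_class = prediction.argmax(dim=1)
    correct = (pred_class == label).float()
    micro = correct.mean()  # overall fraction of correct predictions
    # Macro: unweighted mean of per-class accuracy (assumes every class occurs)
    macro = torch.stack([correct[label == c].mean() for c in range(n_classes)]).mean()
    return macro, micro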
def infer():
    ''' Performs interactive inference with the model. '''
    ## Read terminal arguments
    model_name = FLAGS.model_name
    checkpoint_path = FLAGS.checkpoint_path
    embedding_path = FLAGS.embedding_path
    assert checkpoint_path is not None, "checkpoint_path is a required argument"
    assert os.path.isfile(checkpoint_path), "Checkpoint does not exist"
    assert os.path.isfile(embedding_path), "Embedding does not exist"
    assert model_name in MODEL_NAMES, "Model name is unknown"

    # Obtain GloVe word embeddings
    print("Loading GloVe embedding from " + embedding_path)
    glove_emb = data_utils.EmbeddingGlove(embedding_path)

    # Build vocabulary
    vocab = data_utils.Vocabulary()
    vocab.count_glove(glove_emb)
    vocab.build()

    # Empty dataloader for converting a sentence pair to a batch
    dataloader = data_utils.DataLoaderSnli([], vocab)

    # Load network
    device_name = 'cpu'
    device = torch.device(device_name)
    print("Device: " + device_name)
    if model_name == 'baseline':
        net = BaselineNet(glove_emb.embedding).to(device)

    # Load checkpoint
    print("Initialising model from " + checkpoint_path)
    state_dict = torch.load(checkpoint_path, map_location=device)
    net.load_state_dict(state_dict)
    print("Network architecture:\n\t{}".format(str(net)))

    # Interactive interface
    CLASS_TEXT = ['neutral', 'contradiction', 'entailment']
    while True:
        # Obtain input
        premise_sent = input("Enter the premise:\n\t")
        hypothesis_sent = input("Enter the hypothesis:\n\t")

        # Predict
        prem, hyp, _ = dataloader.prepare_manual(premise_sent, hypothesis_sent)
        with torch.no_grad():  # no gradients needed for inference
            prediction = net.forward(prem, hyp)[0]
        pred_class = prediction.argmax()
        pred_prob = prediction[pred_class].exp() / prediction.exp().sum() * 100  # softmax probability

        # Print
        print("Inference: {}\nScore: {:.2f}%\n".format(CLASS_TEXT[pred_class], pred_prob))
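
# The manual probability computation in infer() is just a softmax over the raw
# class scores. A self-contained sketch showing the equivalence with
# torch.softmax (the logits below are made up for illustration):
def _softmax_prob_demo():
    prediction = torch.tensor([0.2, 1.5, -0.3])  # hypothetical raw class scores
    pred_class = prediction.argmax()
    manual = prediction[pred_class].exp() / prediction.exp().sum() * 100
    builtin = torch.softmax(prediction, dim=0)[pred_class] * 100
    assert torch.isclose(manual, builtin)
    return builtin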
def train():
    ''' Performs training and evaluation of the model. '''
    start_time = datetime.datetime.now()

    ## Read terminal arguments
    model_name = FLAGS.model_name
    activate_board = FLAGS.activate_board
    checkpoint_path = FLAGS.checkpoint_path
    data_train_path = FLAGS.data_train_path
    data_dev_path = FLAGS.data_dev_path
    embedding_path = FLAGS.embedding_path

    # Hyperparameters
    batch_size = FLAGS.batch_size
    max_steps = FLAGS.max_steps
    learning_rate = FLAGS.learning_rate
    o_dir = FLAGS.output_dir

    # Further process terminal arguments
    if model_name not in MODEL_NAMES:
        raise NotImplementedError
    now = datetime.datetime.now()
    time_stamp = "{:02d}{:02d}{:02d}{:02d}".format(now.day, now.hour, now.minute, now.second)
    output_dir = os.path.join(o_dir, model_name, time_stamp)
    tensorboard_dir = os.path.join(output_dir, 'tensorboard')
    checkpoint_dir = os.path.join(output_dir, 'checkpoints')
    os.makedirs(tensorboard_dir, exist_ok=True)  # create output and tensorboard directory
    os.makedirs(checkpoint_dir, exist_ok=True)   # create checkpoint directory
    if checkpoint_path is not None:
        if not os.path.isfile(checkpoint_path):
            print("Could not find checkpoint: " + checkpoint_path)
            return

    # Standard hyperparams
    weight_decay = .01
    eval_freq = 50
    check_freq = 1000

    # Obtain GloVe word embeddings
    print("Loading GloVe embedding from " + embedding_path)
    glove_emb = data_utils.EmbeddingGlove(embedding_path)

    # Build vocabulary
    vocab = data_utils.Vocabulary()
    vocab.count_glove(glove_emb)
    vocab.build()

    # Obtain SNLI train and dev datasets
    dataset = {}
    dataloader = {}
    for set_name, set_path in [('train', data_train_path), ('dev', data_dev_path)]:
        print("Loading {} data".format(set_name))
        dataset[set_name] = data_utils.DatasetSnli(set_path)
        dataloader[set_name] = data_utils.DataLoaderSnli(dataset[set_name], vocab)

    # Initialise network
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device_name)
    print("Device: " + device_name)
    if model_name == 'baseline':
        net = BaselineNet(glove_emb.embedding).to(device)
    elif model_name == 'unilstm':
        net = UniLstmNet(glove_emb.embedding).to(device)

    # Load checkpoint
    if checkpoint_path is not None:
        print("Initialising model from " + checkpoint_path)
        state_dict = torch.load(checkpoint_path, map_location=device)
        net.load_state_dict(state_dict)
    loss_fn = F.cross_entropy
    print("Network architecture:\n\t{}\nLoss module:\n\t{}".format(str(net), str(loss_fn)))

    # Evaluation vars
    writer = SummaryWriter(log_dir=tensorboard_dir)
    if activate_board:
        # Start tensorboard in a separate terminal
        call('gnome-terminal -- tensorboard --logdir ' + tensorboard_dir, shell=True)
    iteration = 0

    # Training
    optimizer = optim.SGD(net.trainable_params(), lr=learning_rate, weight_decay=weight_decay)
    last_dev_acc = 0
    current_dev_accs = []
    epoch = 0
    train_acc = 0
    train_loss = 0
    gradient_norm = 0
    while True:
        iteration += 1

        # Stopping criterion: maximum number of iterations
        if max_steps is not None:
            if iteration > max_steps:
                print("Training stopped: maximum number of iterations reached")
                break

        # Adapt learning rate at each epoch boundary; early stopping
        if dataloader['train']._epochs_completed > epoch:
            epoch = dataloader['train']._epochs_completed
            print("Epoch {}".format(epoch))
            if current_dev_accs == []:
                current_dev_accs = [0]
            current_dev_acc = np.mean(current_dev_accs)
            if current_dev_acc < last_dev_acc:
                # Dev accuracy decreased: divide the learning rate by 5
                learning_rate /= 5
                if learning_rate < 1e-5:
                    print("Training stopped: learning rate dropped below 1e-5")
                    break
                for g in optimizer.param_groups:
                    g['lr'] = learning_rate
                print("Learning rate dropped to {}".format(learning_rate))
            writer.add_scalar('learning_rate', learning_rate, iteration)
            writer.add_scalar('epoch_dev_acc', current_dev_acc, epoch)
            last_dev_acc = current_dev_acc
            current_dev_accs = []

        # Sample a mini-batch
        prem, hyp, label = dataloader['train'].next_batch(batch_size)
        prem = prem.to(device)
        hyp = hyp.to(device)
        label = label.to(device)

        # Forward propagation
        prediction = net.forward(prem, hyp)
        loss = loss_fn(prediction, label)
        acc = accuracy(prediction, label)
        train_acc += acc.tolist() / eval_freq    # running mean over the eval window
        train_loss += loss.tolist() / eval_freq

        # Backprop
        optimizer.zero_grad()
        loss.backward()

        # Weight update in linear modules
        optimizer.step()
        with torch.no_grad():
            # Accumulate the squared L2 norm of the classifier gradients
            norm = 0
            for params in net.sequential.parameters():
                norm += params.grad.reshape(-1).pow(2).sum()
            gradient_norm += norm.item() / eval_freq

        # Evaluation
        if iteration % eval_freq == 0 or iteration == max_steps:
            prem, hyp, label = dataloader['dev'].next_batch(len(dataset['dev']))
            prem = prem.to(device)
            hyp = hyp.to(device)
            label = label.to(device)
            with torch.no_grad():  # no gradients needed for evaluation
                prediction = net.forward(prem, hyp)
            acc = accuracy(prediction, label)
            test_acc = acc.tolist()
            current_dev_accs.append(test_acc)
            writer.add_scalars('accuracy', {'dev': test_acc}, iteration)
            writer.add_scalars('accuracy', {'train': train_acc}, iteration)
            writer.add_scalar('train_loss', train_loss, iteration)
            writer.add_scalar('gradient_norm', gradient_norm, iteration)
            print("Iteration: {}\t\tDev accuracy: {}\t\tTrain accuracy: {}".format(iteration, test_acc, train_acc))
            train_acc = 0
            train_loss = 0
            gradient_norm = 0

        # Checkpoint
        if iteration % check_freq == 0 or iteration == max_steps:
            print("Saving checkpoint")
            torch.save(net.state_dict(), os.path.join(checkpoint_dir, "model_iter_" + str(iteration) + ".pt"))

    writer.close()
    end_time = datetime.datetime.now()
    print("Done. Start and end time:\n\t{}\n\t{}".format(start_time, end_time))
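
# Note on the gradient logging in train(): the inner loop above accumulates the
# *squared* L2 norm of the classifier gradients. A sketch of the corresponding
# true L2 norm, in case that is the quantity to plot (assumes backward() has
# already populated the .grad fields):
def _grad_l2_norm(parameters):
    with torch.no_grad():
        return torch.sqrt(sum(p.grad.pow(2).sum()
                              for p in parameters if p.grad is not None))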
    embeddings = torch.Tensor(net.encode(batch)).to(device)
    return embeddings


# Set params for SentEval
params_senteval = {'task_path': PATH_TO_DATA, 'usepytorch': False, 'kfold': 10,
                   'batch_size': 512}
params_senteval['classifier'] = {'nhid': 0, 'optim': 'adam', 'batch_size': 64,
                                 'tenacity': 5, 'epoch_size': 4}

# Set up logger
logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)

if __name__ == "__main__":
    # Obtain GloVe word embeddings
    print("Loading GloVe embedding from " + embedding_path)
    glove_emb = data_utils.EmbeddingGlove(embedding_path)

    # Build vocabulary
    vocab = data_utils.Vocabulary()
    vocab.count_glove(glove_emb)
    vocab.build()

    # Empty dataloader for converting sentence list to batch
    dataloader = data_utils.DataLoaderSnli([], vocab)

    # Load network
    if model_name == 'baseline':
        net = BaselineNet(glove_emb.embedding).to(device)

    # Load checkpoint
    print("Initialising model from " + checkpoint_path)
    state_dict = torch.load(checkpoint_path, map_location=device)
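
    # For reference, the usual SentEval driver looks like the sketch below
    # (task list is illustrative; `prepare` and `batcher` are the two callbacks
    # SentEval expects, the latter partially shown above):
    #
    #     se = senteval.engine.SE(params_senteval, batcher, prepare)
    #     results = se.eval(['MR', 'CR', 'SST2', 'TREC', 'SICKEntailment'])
    #     print(results)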