Code example #1
File: train.py Project: SaiSakethAluru/SeqGen
def train(args):
    if args.device == 'cuda' and torch.cuda.is_available():
        device = torch.device('cuda')
        print("using gpu: ", torch.cuda.get_device_name(torch.cuda.current_device()))
        
    else:
        device = torch.device('cpu')
        print('using cpu')
    
    if args.dataset_name == 'pubmed':
        LABEL_LIST = PUBMED_LABEL_LIST
    elif args.dataset_name == 'nicta':
        LABEL_LIST = NICTA_LABEL_LIST
    elif args.dataset_name == 'csabstract':
        LABEL_LIST = CSABSTRACT_LABEL_LIST

    train_x,train_labels = load_data(args.train_data, args.max_par_len,LABEL_LIST)
    dev_x,dev_labels = load_data(args.dev_data, args.max_par_len,LABEL_LIST)
    test_x,test_labels = load_data(args.test_data, args.max_par_len,LABEL_LIST)

    tokenizer = AutoTokenizer.from_pretrained(args.bert_model)
    train_x = tokenize_and_pad(train_x,tokenizer,args.max_par_len,args.max_seq_len, LABEL_LIST)  ## N, par_len, seq_len
    dev_x = tokenize_and_pad(dev_x,tokenizer,args.max_par_len, args.max_seq_len, LABEL_LIST)
    test_x = tokenize_and_pad(test_x,tokenizer, args.max_par_len, args.max_seq_len, LABEL_LIST)

    training_params = {
        "batch_size": args.batch_size,
        "shuffle": True,
        "drop_last": False
        }
    dev_params = {
        "batch_size": args.batch_size,
        "shuffle": False,
        "drop_last": False
        }
    test_params = {
        "batch_size": args.batch_size,
        "shuffle": False,
        "drop_last": False
        }

    print('train.py train_x.shape:',train_x.shape,'train_labels.shape',train_labels.shape)
    training_generator = return_dataloader(inputs=train_x, labels=train_labels, params=training_params)
    dev_generator = return_dataloader(inputs=dev_x, labels=dev_labels, params=dev_params)
    test_generator = return_dataloader(inputs=test_x, labels=test_labels, params=test_params)   

    src_pad_idx = 0
    trg_pad_idx = 0
    model = Transformer(
        label_list=LABEL_LIST,
        src_pad_idx=src_pad_idx,
        trg_pad_idx=trg_pad_idx,
        embed_size=args.embed_size,
        num_layers=args.num_layers,   ## debug
        forward_expansion=args.forward_expansion,
        heads=len(LABEL_LIST),
        dropout=0.1,
        device=device,
        max_par_len=args.max_par_len,
        max_seq_len=args.max_seq_len,
        bert_model=args.bert_model
    )
    model = model.to(device).float()
    
    criterion = nn.CrossEntropyLoss(ignore_index=trg_pad_idx)
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)
    
    epoch_losses = []
    best_loss = float('inf')
    for epoch in range(args.num_epochs):
        model.train()
        print(f"----------------[Epoch {epoch} / {args.num_epochs}]-----------------------")

        losses = []
        for batch_idx,batch in tqdm(enumerate(training_generator)):
            inp_data,target = batch
            inp_data = inp_data.to(device)
            target = target.to(device)

            ## For CRF
            optimizer.zero_grad()

            loss = -model(inp_data.long(), target[:,1:], training=True)       ## training=True: the model returns the CRF log-likelihood, so negate it for the loss


            losses.append(loss.item())

            loss.backward()

            optimizer.step()
            
        mean_loss = sum(losses)/len(losses)

        print(f"Mean loss for epoch {epoch} is {mean_loss}")
        # Validation
        model.eval()
        val_targets = []
        val_preds = []
        for batch_idx,batch in tqdm(enumerate(dev_generator)):
            inp_data,target = batch
            inp_data = inp_data.to(device)
            target = target.to(device)
            with torch.no_grad():
                output = model(inp_data, target[:,:-1], training=False)      ## with training=False the model returns decoded labels rather than logits

            flattened_target = target[:,1:].to('cpu').flatten()
            output = convert_crf_output_to_tensor(output,args.max_par_len)
            flattened_preds = output.to('cpu').flatten()
            for target_i,pred_i in zip(flattened_target,flattened_preds):
                if target_i != 0:
                    val_targets.append(target_i)
                    val_preds.append(pred_i)

        f1 = f1_score(val_targets,val_preds,average='micro')
        
        print(f'------Micro F1 score on dev set: {f1}------')

        # No separate validation loss is computed (the CRF decode returns labels,
        # not logits), so checkpointing is keyed on the mean training loss.
        if mean_loss < best_loss:
            print(f"mean training loss improved on previous best of {best_loss}")
            best_loss = mean_loss
            if args.save_model:
                dir_name = f"seed_{args.seed}_parlen_{args.max_par_len}_seqlen_{args.max_seq_len}_lr_{args.lr}.pt"
                output_path = os.path.join(args.save_path,dir_name)
                if not os.path.exists(args.save_path):
                    os.makedirs(args.save_path)
                print(f"Saving model to path {output_path}")
                torch.save(model,output_path)

        # Testing
        if epoch % args.test_interval == 0:
            model.eval()
            test_targets = []
            test_preds = []
            for batch_idx, batch in tqdm(enumerate(test_generator)):
                inp_data,target = batch
                inp_data = inp_data.to(device)
                target = target.to(device)
                with torch.no_grad():
                    output = model(inp_data,target[:,:-1],training=False)
                    
                flattened_target = target[:,1:].to('cpu').flatten()
                output = convert_crf_output_to_tensor(output,args.max_par_len)
                flattened_preds = output.to('cpu').flatten()
                for target_i,pred_i in zip(flattened_target,flattened_preds):
                    if target_i!=0:
                        test_targets.append(target_i)
                        test_preds.append(pred_i)
            
            f1 = f1_score(test_targets,test_preds,average='micro')
            print(f"------Micro F1 score on test set: {f1}------")
Code example #2
File: train.py Project: SaiSakethAluru/BERT-SeqGen
def train(args):
    if args.device == 'cuda' and torch.cuda.is_available():
        device = torch.device('cuda')
        print("using gpu: ", torch.cuda.get_device_name(torch.cuda.current_device()))
        
    else:
        device = torch.device('cpu')
        print('using cpu')
    
    if args.dataset_name == 'pubmed':
        LABEL_LIST = PUBMED_LABEL_LIST
    elif args.dataset_name == 'nicta':
        LABEL_LIST = NICTA_LABEL_LIST
    elif args.dataset_name == 'csabstract':
        LABEL_LIST = CSABSTRACT_LABEL_LIST

    train_x,train_labels = load_data(args.train_data, args.max_par_len,LABEL_LIST)
    dev_x,dev_labels = load_data(args.dev_data, args.max_par_len,LABEL_LIST)
    test_x,test_labels = load_data(args.test_data, args.max_par_len,LABEL_LIST)

    tokenizer = AutoTokenizer.from_pretrained(args.bert_model)
    train_x = tokenize_and_pad(train_x,tokenizer,args.max_par_len,args.max_seq_len, LABEL_LIST)  ## N, par_len, seq_len
    dev_x = tokenize_and_pad(dev_x,tokenizer,args.max_par_len, args.max_seq_len, LABEL_LIST)
    test_x = tokenize_and_pad(test_x,tokenizer, args.max_par_len, args.max_seq_len, LABEL_LIST)

    # print('train_x[0]',train_x[0])
    # print('train_x[0].shape',train_x[0].shape)
    # quit()
    training_params = {
        "batch_size": args.batch_size,
        "shuffle": True,
        "drop_last": False
        }
    dev_params = {
        "batch_size": args.batch_size,
        "shuffle": False,
        "drop_last": False
        }
    test_params = {
        "batch_size": args.batch_size,
        "shuffle": False,
        "drop_last": False
        }

    print('train.py train_x.shape:',train_x.shape,'train_labels.shape',train_labels.shape)
    training_generator = return_dataloader(inputs=train_x, labels=train_labels, params=training_params)
    dev_generator = return_dataloader(inputs=dev_x, labels=dev_labels, params=dev_params)
    test_generator = return_dataloader(inputs=test_x, labels=test_labels, params=test_params)   

    src_pad_idx = 0
    trg_pad_idx = 0
    model = Transformer(
        label_list=LABEL_LIST,
        src_pad_idx=src_pad_idx,
        trg_pad_idx=trg_pad_idx,
        embed_size=args.embed_size,
        num_layers=args.num_layers,   ## debug
        forward_expansion=args.forward_expansion,
        heads=len(LABEL_LIST),
        dropout=0.1,
        device=device,
        max_par_len=args.max_par_len,
        max_seq_len=args.max_seq_len,
        bert_model=args.bert_model
    )
    model = model.to(device).float()
    # for param in model.parameters():
    #     try:
    #         torch.nn.init.xavier_uniform_(param)
    #     except:
    #         continue
    
    criterion = nn.CrossEntropyLoss(ignore_index=trg_pad_idx)
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)
    # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    #     optimizer, factor=0.1, patience=10, verbose=True
    # )
    
    epoch_losses = []
    best_val_loss = float('inf')
    for epoch in range(args.num_epochs):
        model.train()
        print(f"----------------[Epoch {epoch} / {args.num_epochs}]-----------------------")

        losses = []
        for batch_idx,batch in tqdm(enumerate(training_generator)):
            # print('batch',batch)
            # print('type of batch',type(batch))
            inp_data,target = batch
            # print('inp_data',inp_data)
            # print('type(inp_data)',type(inp_data))
            # print('target',target)
            # print('type(target)',type(target))
            # print('target.shape',target.shape)
            inp_data = inp_data.to(device)
            # print('inp_data.shape',inp_data.shape)
            target = target.to(device)
            # assert False

            ## For generation
            # output = model(inp_data.long(),target[:,:-1], training=True)       ## N,par_len, label_size
            
            ## For CRF
            optimizer.zero_grad()

            # output = model(inp_data.long(),target[:,1:], training=True)       ## N,par_len, label_size
            loss = -model(inp_data.long(), target[:,1:], training=True)       ## training=True: the model returns the CRF log-likelihood, so negate it for the loss


            # output = model(inp_data,target[:,:-1])

            # print('model net',make_dot(output))
            # print(make_dot(output))
            # make_arch = make_dot(output)
            # Source(make_arch).render('graph.png')
            # assert False
            ## output - N,par_len, num_labels --> N*par_len, num_labels
            # output = output.reshape(-1,output.shape[2])
            ## target -
            # target = target[:,1:].reshape(-1)

            # print('output.shape',output.shape)
            # print('target.shape',target.shape)
            # print(f'{epoch} model params', list(model.parameters())[-1])
            # print('len params',len(list(model.parameters())))
            # print('trainable params: ',len(list(filter(lambda p: p.requires_grad, model.parameters()))))

            # loss = criterion(output,target)
            # loss.retain_grad()
            losses.append(loss.item())

            # print(f'{epoch} loss grads before', list(loss.grad)[-1])
            loss.backward()
            # print(f'{epoch} loss grads after', loss.grad)
            # print('model params')
            # count = 0
            # for p in model.parameters():
            #     if p.grad is not None:
            #         print(p.grad,p.grad.norm())
            #         count +=1 
            # print(f'non none grads are {count}')
            # torch.nn.utils.clip_grad_norm_(model.parameters(),max_norm=1)

            optimizer.step()
            # break #NOTE: break is there only for quick checking. Remove this for actual training.
            
        mean_loss = sum(losses)/len(losses)
        # scheduler.step(mean_loss)

        print(f"Mean loss for epoch {epoch} is {mean_loss}")
        # Validation
        model.eval()
        # val_losses = []
        val_targets = []
        val_preds = []
        for batch_idx,batch in tqdm(enumerate(dev_generator)):
            inp_data,target = batch
            inp_data = inp_data.to(device)
            target = target.to(device)
            with torch.no_grad():
                output = model(inp_data, target[:,:-1], training=False)      ## with training=False the model returns decoded labels rather than logits
                # reshaped_output = output.reshape(-1,output.shape[2])
                # reshaped_target = target[:,1:].reshape(-1)
                # loss = criterion(reshaped_output,reshaped_target).item()

            # val_losses.append(loss)
            flattened_target = target[:,1:].to('cpu').flatten()
            # print(output)
            output = convert_crf_output_to_tensor(output,args.max_par_len)
            # flattened_preds = torch.softmax(output,dim=-1).argmax(dim=-1).to('cpu').flatten()
            flattened_preds = output.to('cpu').flatten()
            for target_i,pred_i in zip(flattened_target,flattened_preds):
                if target_i != 0:
                    val_targets.append(target_i)
                    val_preds.append(pred_i)
            # val_targets.append(target[:,1:].to('cpu').flatten())
            # output = torch.softmax(output,dim=-1).argmax(dim=-1)
            # val_preds.append(output.to('cpu').flatten())
            # break #NOTE: break is there only for quick checking. Remove this for actual training.

        # loss = sum(val_losses) / len(val_losses)
        # print(f"Validation loss at epoch {epoch} is {loss}")
        # val_targets = torch.cat(val_targets,dim=0)
        # val_preds = torch.cat(val_preds,dim=0)
        f1 = f1_score(val_targets,val_preds,average='micro')
        
        print(f'------Micro F1 score on dev set: {f1}------')

        # if loss < best_val_loss:
        #     print(f"val loss less than previous best val loss of {best_val_loss}")
        #     best_val_loss = loss
        #     if args.save_model:
        #         dir_name = f"seed_{args.seed}_parlen_{args.max_par_len}_seqlen_{args.max_seq_len}_lr_{args.lr}.pt"
        #         output_path = os.path.join(args.save_path,dir_name)
        #         if not os.path.exists(args.save_path):
        #             os.makedirs(args.save_path)
        #         print(f"Saving model to path {output_path}")
        #         torch.save(model,output_path)

        # Testing
        if epoch % args.test_interval == 0:
            model.eval()
            test_targets = []
            test_preds = []
            for batch_idx, batch in tqdm(enumerate(test_generator)):
                inp_data,target = batch
                inp_data = inp_data.to(device)
                target = target.to(device)
                with torch.no_grad():
                    output = model(inp_data,target[:,:-1],training=False)
                    
                # output = torch.softmax(output,dim=-1).argmax(dim=-1)
                flattened_target = target[:,1:].to('cpu').flatten()
                output = convert_crf_output_to_tensor(output,args.max_par_len)
                flattened_preds = output.to('cpu').flatten()
                for target_i,pred_i in zip(flattened_target,flattened_preds):
                    if target_i!=0:
                        test_targets.append(target_i)
                        test_preds.append(pred_i)
                # test_targets.append(target[:,1:].to('cpu').flatten())
                # test_preds.append(output.to('cpu').flatten())
                # break  #NOTE: break is there only for quick checking. Remove this for actual training. 
            
            # test_targets = torch.cat(test_targets,dim=0)
            # test_preds = torch.cat(test_preds,dim=0)
            # f1 = f1_score(target[:,1:].to('cpu').flatten(),output.to('cpu').flatten(),average='macro')
            f1 = f1_score(test_targets,test_preds,average='micro')
            print(f"------Micro F1 score on test set: {f1}------")

    ## Generate attention heat maps for a few training examples.
    # See src/word_level_labelatt.py for how the word-level attention scores are computed and stored,
    # and src/selfatt.py for the sentence-level attention scores.
    att_x = train_x[:10,:,:].to(device)
    att_y = train_labels[:10,:].to(device)[:,:-1] 
    model(att_x,att_y,training=False,att_heat_map=True)    
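
return_dataloader is another helper that is not shown here. Assuming the tokenized inputs and the label arrays are already tensors (their shapes are printed above), a minimal sketch could simply wrap them in a TensorDataset; this is an assumption, not the project's actual code:

import torch
from torch.utils.data import DataLoader, TensorDataset

def return_dataloader(inputs, labels, params):
    # params carries batch_size / shuffle / drop_last, exactly as built in train()
    dataset = TensorDataset(inputs, labels)
    return DataLoader(dataset, **params)
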
Code example #3
    model.to(device)

    optimizer = configure_optimizer(model.named_parameters(), args.lr)
    steps_per_epoch = len(train_dataset) / (args.batch_size *
                                            args.accumulation)
    scheduler = configure_scheduler(
        optimizer,
        training_steps=(args.epochs * steps_per_epoch),
        warmup=args.warmup * steps_per_epoch,
    )
    criterion = torch.nn.CrossEntropyLoss(reduction="none")

    best_val_loss = 1000
    generated_molecules = defaultdict()
    for epoch in range(args.epochs):
        model.train()
        epoch_loss = 0
        for batch in tqdm(train_loader, total=len(train_loader), ncols=80):
            optimizer.zero_grad()

            smiles = batch["smiles"].to(device)
            log_p = batch["logP"].to(device)
            mask = batch["mask"].to(device)
            target_sequence_length = batch["seq_len"].to(device)

            output = model(
                output_ids=smiles,
                target_sequence_length=target_sequence_length,
                log_p=log_p,
            )
            loss = torch.mean(
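
The listing breaks off inside the loss computation, but the pieces already set up, CrossEntropyLoss(reduction="none") plus a per-token mask, point to the usual masked sequence loss. A standalone sketch of that pattern (the function and argument names here are assumptions, not the project's code):

import torch

criterion = torch.nn.CrossEntropyLoss(reduction="none")

def masked_token_loss(logits, targets, mask):
    # logits: (batch, seq_len, vocab); targets and mask: (batch, seq_len)
    mask = mask.float()
    per_token = criterion(logits.transpose(1, 2), targets)  # (batch, seq_len)
    # average only over the unmasked (non-padding) positions
    return (per_token * mask).sum() / mask.sum()
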
Code example #4
class ModelOperator:
    def __init__(self, args):

        # set up output directory
        self.output_dir = os.path.join(args.experiment_dir, args.run_name)
        if not os.path.exists(args.experiment_dir):
            os.mkdir(args.experiment_dir)
        if not os.path.exists(self.output_dir):
            os.mkdir(self.output_dir)
        if not os.path.exists(os.path.join(args.experiment_dir,"runs/")):
            os.mkdir(os.path.join(args.experiment_dir,"runs/"))

        # initialize tensorboard writer
        self.runs_dir = os.path.join(args.experiment_dir,"runs/",args.run_name)
        self.writer = SummaryWriter(self.runs_dir)

        # initialize global steps
        self.train_gs = 0
        self.val_gs = 0

        # initialize model config
        self.config = ModelConfig(args)

        # check if there is a model to load
        if args.old_model_dir is not None:
            self.use_old_model = True
            self.load_dir = args.old_model_dir
            self.config.load_from_file(
                os.path.join(self.load_dir, "config.json"))

            # create vocab
            self.vocab = Vocab()
            self.vocab.load_from_dict(os.path.join(self.load_dir, "vocab.json"))
            self.update_vocab = False
            self.config.min_count=1
        else:
            self.use_old_model = False

            self.vocab = None
            self.update_vocab = True

        # create data sets
        self.dataset_filename = args.dataset_filename

        # train
        self.train_dataset = DialogueDataset(
            os.path.join(self.dataset_filename, "train.csv"),
            self.config.history_len,
            self.config.response_len,
            self.vocab,
            self.update_vocab)
        self.data_loader_train = torch.utils.data.DataLoader(
            self.train_dataset, self.config.train_batch_size, shuffle=True)
        self.config.train_len = len(self.train_dataset)

        self.vocab = self.train_dataset.vocab

        # eval
        self.val_dataset = DialogueDataset(
            os.path.join(self.dataset_filename, "val.csv"),
            self.config.history_len,
            self.config.response_len,
            self.vocab,
            self.update_vocab)
        self.data_loader_val = torch.utils.data.DataLoader(
            self.val_dataset, self.config.val_batch_size, shuffle=True)
        self.config.val_len = len(self.val_dataset)

        # update, and save vocab
        self.vocab = self.val_dataset.vocab
        self.train_dataset.vocab = self.vocab
        if (self.config.min_count > 1):
            self.config.old_vocab_size = len(self.vocab)
            self.vocab.prune_vocab(self.config.min_count)
        self.vocab.save_to_dict(os.path.join(self.output_dir, "vocab.json"))
        self.vocab_size = len(self.vocab)
        self.config.vocab_size = self.vocab_size

        # print and save the config file
        self.config.print_config(self.writer)
        self.config.save_config(os.path.join(self.output_dir, "config.json"))

        # set device
        self.device = torch.device('cuda')

        # create model
        self.model = Transformer(
            self.config.vocab_size,
            self.config.vocab_size,
            self.config.history_len,
            self.config.response_len,
            d_word_vec=self.config.embedding_dim,
            d_model=self.config.model_dim,
            d_inner=self.config.inner_dim,
            n_layers=self.config.num_layers,
            n_head=self.config.num_heads,
            d_k=self.config.dim_k,
            d_v=self.config.dim_v,
            dropout=self.config.dropout
        ).to(self.device)

        # create optimizer
        self.optimizer = torch.optim.Adam(
            filter(lambda x: x.requires_grad, self.model.parameters()),
            betas=(0.9, 0.98), eps=1e-09)

        # load old model, optimizer if there is one
        if self.use_old_model:
            self.model, self.optimizer = load_checkpoint(
                os.path.join(self.load_dir, "model.bin"),
                self.model, self.optimizer, self.device)


        # create a sceduled optimizer object
        self.optimizer = ScheduledOptim(
            self.optimizer, self.config.model_dim, self.config.warmup_steps)

        #self.optimizer.optimizer.to(torch.device('cpu'))


    def train(self, num_epochs):
        metrics = {"best_epoch":0, "lowest_loss":99999999999999}

        # output an example
        #self.output_example(0)

        for epoch in range(num_epochs):
           # self.writer.add_graph(self.model)
            #self.writer.add_embedding(
            #    self.model.encoder.src_word_emb.weight, global_step=epoch)

            epoch_metrics = dict()

            # train
            epoch_metrics["train"] = self.execute_phase(epoch, "train")
            # save metrics
            metrics["epoch_{}".format(epoch)] = epoch_metrics
            with open(os.path.join(self.output_dir, "metrics.json"), "w") as f:
                json.dump(metrics, f, indent=4)

            # validate
            epoch_metrics["val"] = self.execute_phase(epoch, "val")
            # save metrics
            metrics["epoch_{}".format(epoch)] = epoch_metrics
            with open(os.path.join(self.output_dir, "metrics.json"), "w") as f:
                json.dump(metrics, f, indent=4)

            # save checkpoint
            #TODO: fix this b
            #if epoch_metrics["val"]["loss"] < metrics["lowest_loss"]:
            #if epoch_metrics["train"]["loss"] < metrics["lowest_loss"]:
            if epoch % 100 == 0:
                self.save_checkpoint(os.path.join(self.output_dir, "model_{}.bin".format(epoch)))
                #metrics["lowest_loss"] = epoch_metrics["train"]["loss"]
                #metrics["best_epoch"] = epoch

            # record metrics to tensorboard
            self.writer.add_scalar("training loss total",
                epoch_metrics["train"]["loss"], global_step=epoch)
            self.writer.add_scalar("val loss total",
                epoch_metrics["val"]["loss"], global_step=epoch)

            self.writer.add_scalar("training perplexity",
                epoch_metrics["train"]["perplexity"], global_step=epoch)
            self.writer.add_scalar("val perplexity",
                epoch_metrics["val"]["perplexity"], global_step=epoch)

            self.writer.add_scalar("training time",
                epoch_metrics["train"]["time_taken"], global_step=epoch)
            self.writer.add_scalar("val time",
                epoch_metrics["val"]["time_taken"], global_step=epoch)

            self.writer.add_scalar("train_bleu_1",
                epoch_metrics["train"]["bleu_1"], global_step=epoch)
            self.writer.add_scalar("val_bleu_1",
                epoch_metrics["val"]["bleu_1"], global_step=epoch)
            self.writer.add_scalar("train_bleu_2",
                epoch_metrics["train"]["bleu_2"], global_step=epoch)
            self.writer.add_scalar("val_bleu_2",
                epoch_metrics["val"]["bleu_2"], global_step=epoch)

            # output an example
            #self.output_example(epoch+1)

        self.writer.close()

    def execute_phase(self, epoch, phase):
        if phase == "train":
            self.model.train()
            dataloader = self.data_loader_train
            batch_size = self.config.train_batch_size
            train = True
        else:
            self.model.eval()
            dataloader = self.data_loader_val
            batch_size = self.config.val_batch_size
            train = False

        start = time.perf_counter()  # time.clock() was removed in Python 3.8
        phase_metrics = dict()
        epoch_loss = list()
        epoch_bleu_1 = list()
        epoch_bleu_2 = list()
        average_epoch_loss = None
        n_word_total = 0
        n_correct = 0
        n_word_correct = 0
        for i, batch in enumerate(tqdm(dataloader,
                          mininterval=2, desc=phase, leave=False)):
            # prepare data
            src_seq, src_pos, src_seg, tgt_seq, tgt_pos = map(
                lambda x: x.to(self.device), batch)

            gold = tgt_seq[:, 1:]

            # forward
            if train:
                self.optimizer.zero_grad()
            pred = self.model(src_seq, src_pos, src_seg, tgt_seq, tgt_pos)

            # get loss
            loss, n_correct = cal_performance(pred, gold,
                smoothing=self.config.label_smoothing)
            #average_loss = float(loss)/self.config.val_batch_size
            average_loss = float(loss)
            epoch_loss.append(average_loss)
            average_epoch_loss = np.mean(epoch_loss)

            if train:
                self.writer.add_scalar("train_loss",
                    average_loss, global_step=i + epoch * len(dataloader))
                # backward
                loss.backward()

                # update parameters
                self.optimizer.step_and_update_lr()

            # get_bleu
            output = torch.argmax(pred.view(-1, self.config.response_len-1, self.vocab_size), dim=2)
            epoch_bleu_1.append(bleu(gold, output, 1))
            epoch_bleu_2.append(bleu(gold, output, 2))

            # get_accuracy
            non_pad_mask = gold.ne(src.transformer.Constants.PAD)
            n_word = non_pad_mask.sum().item()
            n_word_total += n_word
            n_word_correct += n_correct


        phase_metrics["loss"] = average_epoch_loss
        phase_metrics["token_accuracy"] = n_correct / n_word_total

        perplexity = np.exp(average_epoch_loss)
        phase_metrics["perplexity"] = perplexity

        phase_metrics["bleu_1"] = np.mean(epoch_bleu_1)
        phase_metrics["bleu_2"] = np.mean(epoch_bleu_2)

        phase_metrics["time_taken"] = time.clock() - start
        string = ' {} loss: {:.3f} '.format(phase, average_epoch_loss)
        print(string, end='\n')
        return phase_metrics

    def save_checkpoint(self, filename):
        state = {
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.optimizer.state_dict()
        }
        torch.save(state, filename)

    def output_example(self, epoch):
        random_index = random.randrange(len(self.val_dataset))  # randint's upper bound is inclusive and could index past the end
        example = self.val_dataset[random_index]

        # prepare data
        src_seq, src_pos, src_seg, tgt_seq, tgt_pos = map(
            lambda x: torch.from_numpy(x).to(self.device).unsqueeze(0), example)

        # drop the leading start token so gold aligns with the predicted tokens
        gold = tgt_seq[:, 1:]

        # forward
        pred = self.model(src_seq, src_pos, src_seg, tgt_seq, tgt_pos)
        output = torch.argmax(pred, dim=1)

        # get history text
        string = "history: "

        seg = -1
        for i, idx in enumerate(src_seg.squeeze()):
            if seg != idx.item():
                string+="\n"
                seg=idx.item()
            token = self.vocab.id2token[src_seq.squeeze()[i].item()]
            if token != '<blank>':
                string += "{} ".format(token)

        # get target text
        string += "\nTarget:\n"

        for idx in tgt_seq.squeeze():
            token = self.vocab.id2token[idx.item()]
            string += "{} ".format(token)

        # get prediction
        string += "\n\nPrediction:\n"

        for idx in output:
            token = self.vocab.id2token[idx.item()]
            string += "{} ".format(token)

        # print
        print("\n------------------------\n")
        print(string)
        print("\n------------------------\n")

        # add result to tensorboard
        self.writer.add_text("example_output", string, global_step=epoch)
        self.writer.add_histogram("example_vocab_ranking", pred, global_step=epoch)
        self.writer.add_histogram("example_vocab_choice", output,global_step=epoch)