def train_epoch(data, model, optimizer, batch_size, device):
    """ Trains a single epoch of the given model. """
    torch.cuda.empty_cache()
    model.train()
    log_timer = LogTimer(100)
    for batch_ind, sents in enumerate(batches(data, batch_size)):
        optimizer.zero_grad()  # A separate model.zero_grad() would be redundant here.
        torch.cuda.empty_cache()  # Aggressive per-batch cache release: trades speed for memory headroom.
        out, loss, y = step(model, sents, device)
        torch.cuda.empty_cache()
        loss.backward()
        # Clip gradients to a max norm of 0.2 to keep updates stable.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.2)
        try:
            optimizer.step()
        except Exception as e:
            # Log the failure with a traceback but keep training on later batches.
            logging.fatal(e, exc_info=True)
        if log_timer() or batch_ind == 0:
            # Calculate perplexity: 2 ** (average bits per gold token).
            prob = out.exp()[
                torch.arange(0, y.data.shape[0], dtype=torch.int64), y.data]
            perplexity = 2 ** prob.log2().neg().mean().item()
            print("\tBatch %d, loss %.3f, perplexity %.2f", batch_ind,
                  loss.item(), perplexity)
            logging.info("\tBatch %d, loss %.3f, perplexity %.2f", batch_ind,
                         loss.item(), perplexity)
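
All three variants lean on helpers the snippets do not show: LogTimer, batches, and step. Below is a minimal sketch of plausible implementations, assuming data is a (sentences, seq_len) tensor of token ids and the model returns raw logits; the shapes, the teacher-forcing shift, and the LogTimer interval semantics are guesses rather than the original code (the y.data access above hints the original may instead work with PackedSequence targets).

import time

import torch
import torch.nn.functional as F


class LogTimer:
    """Rate limiter: calling it returns True at most once per `interval` seconds."""

    def __init__(self, interval):
        self.interval = interval
        self.last = float("-inf")

    def __call__(self):
        now = time.monotonic()
        if now - self.last >= self.interval:
            self.last = now
            return True
        return False


def batches(data, batch_size):
    """Yield successive slices of `batch_size` sentences from `data`."""
    for start in range(0, len(data), batch_size):
        yield data[start:start + batch_size]


def step(model, sents, device):
    """One forward pass: returns flattened log-probs, the loss, and the targets."""
    x = sents[:, :-1].to(device)                 # inputs: all tokens but the last
    y = sents[:, 1:].reshape(-1).to(device)      # targets: tokens shifted by one
    logits = model(x)                            # assumed shape (batch, seq, vocab)
    out = F.log_softmax(logits.reshape(-1, logits.shape[-1]), dim=-1)
    loss = F.nll_loss(out, y)
    return out, loss, y

With these in place, the first variant runs as train_epoch(data, model, optimizer, batch_size, device) and logs at most once per LogTimer interval.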
Example #2
def train_epoch(data, model, optimizer, args, device):
    """ Trains a single epoch of the given model. """
    model.train()
    log_timer = LogTimer(5)
    for batch_ind, sents in enumerate(batches(data, args.batch_size)):
        model.zero_grad()
        out, loss, y = step(model, sents, device)
        loss.backward()
        optimizer.step()
        if log_timer() or batch_ind == 0:
            # Calculate perplexity.
            prob = out.exp()[
                torch.arange(0, y.data.shape[0], dtype=torch.int64), y.data]
            perplexity = 2 ** prob.log2().neg().mean().item()
            logging.info("\tBatch %d, loss %.3f, perplexity %.2f", batch_ind,
                         loss.item(), perplexity)
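
The logging branch is identical in every variant: it gathers the probability the model assigned to each gold token, then exponentiates the mean negative log2-probability. Factored into a standalone helper for clarity (a sketch, assuming out holds (num_tokens, vocab) log-probabilities):

import torch


def batch_perplexity(out, y):
    """Base-2 perplexity of one batch.

    `out` is assumed to be (num_tokens, vocab) log-probabilities and
    `y` the (num_tokens,) tensor of gold token ids.
    """
    # Probability the model assigned to each gold token.
    prob = out.exp()[torch.arange(y.shape[0], dtype=torch.int64), y]
    # 2 ** (average bits per token).
    return 2 ** prob.log2().neg().mean().item()

Since the base cancels (2 raised to the mean -log2 p equals e raised to the mean -ln p), this matches loss.exp() whenever loss is the mean NLL over the same tokens, which is a cheaper way to get the same number.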
Example #3
def train_epoch(data, model, optimizer, args, device, id2word2):
    """ Trains a single epoch of the given model. """
    model.train()
    log_timer = LogTimer(5)
    for batch_ind, batched_uts in enumerate(batches(data, args.batch_size)):
        model.zero_grad()
        hidden_init = model.init_hidden(len(batched_uts)).to(device)

        # Debug aid: print the first utterance of the batch as words.
        for index in batched_uts[0]:
            print(id2word2[index.item()], end=' ')
        print()  # Terminate the utterance line.
        out, loss, y, h_n = step(model, hidden_init, batched_uts, device)
        loss.backward()
        optimizer.step()
        if log_timer() or batch_ind == 0:
            # Calculate perplexity.
            prob = out.exp()[
                torch.arange(0, y.data.shape[0], dtype=torch.int64), y.data]
            perplexity = 2 ** prob.log2().neg().mean().item()
            logging.info("\tBatch %d, loss %.3f, perplexity %.2f",
                         batch_ind, loss.item(), perplexity)
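
This variant re-initializes the hidden state for every batch and threads it through a four-value step. A sketch of what that step could look like, assuming the model's forward accepts (input, hidden) and returns (logits, hidden) the way torch.nn.GRU/LSTM wrappers usually do; it is written to be consistent with the call site above, not taken from the original:

import torch.nn.functional as F


def step(model, hidden, batched_uts, device):
    """Forward pass for the stateful variant; also returns the final hidden state."""
    x = batched_uts[:, :-1].to(device)               # inputs
    y = batched_uts[:, 1:].reshape(-1).to(device)    # shifted targets, flattened
    logits, h_n = model(x, hidden)                   # model threads the hidden state
    out = F.log_softmax(logits.reshape(-1, logits.shape[-1]), dim=-1)
    loss = F.nll_loss(out, y)
    return out, loss, y, h_n

Note that the loop above discards h_n and rebuilds hidden_init each batch, so no state persists across batches; carrying h_n forward (detached via h_n.detach()) would turn this into truncated backpropagation through time.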