# Example #1
def batchify(data, bsz, cuda=None):
    """Split a token stream into ``bsz`` parallel columns.

    Trailing elements that would not fill a complete row are dropped, so the
    result keeps ``(data.size(0) // bsz) * bsz`` elements of ``data``.

    Args:
        data: tensor whose first dimension is the sequence; presumably a 1-D
            tensor of token ids -- TODO confirm against callers.
        bsz: number of parallel columns (streams) to split ``data`` into.
        cuda: move the result to the GPU when true. When ``None`` (the
            default) the module-level ``args.cuda`` flag is used, matching
            the previous behaviour.

    Returns:
        A contiguous tensor. For 1-D ``data`` its shape is
        ``(data.size(0) // bsz, bsz)``, and column ``j`` holds the ``j``-th
        consecutive slice of the trimmed input.
    """
    if cuda is None:
        cuda = args.cuda
    # Work out how cleanly we can divide the dataset into bsz parts.
    nbatch = data.size(0) // bsz
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * bsz)
    # Evenly divide the data across the bsz batches: view as bsz rows of
    # consecutive tokens, then transpose so each column is one stream.
    data = data.view(bsz, -1).t().contiguous()
    if cuda:
        data = data.cuda()
    return data
# Example #2
def train(data_source):
    """Run one training epoch of the language model over ``data_source``.

    Walks ``data_source`` in chunks of ``args.bptt`` steps. For each chunk it
    performs a forward pass, backpropagation, gradient clipping and one SGD
    update. Every ``args.log_interval`` batches it logs the mean loss and the
    perplexity via ``printhelper.print_log``.

    Relies on module-level globals: ``args``, ``model``, ``corpus``,
    ``epoch``, ``lr``, ``log_file``, ``get_batch``, ``repackage_hidden`` and
    ``printhelper``.

    Args:
        data_source: batchified tensor; presumably of shape
            ``(num_steps, batch_size)`` as produced by ``batchify`` -- TODO
            confirm against callers.
    """
    # Turn on training mode which enables dropout.
    model.train()
    total_loss = 0.0
    start_time = time.time()
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(args.batch_size)
    criterion = nn.CrossEntropyLoss()
    # Plain SGD (no momentum) keeps no per-step state, so building the
    # optimizer once per epoch is equivalent to rebuilding it every batch.
    opt = optim.SGD(model.parameters(), lr=lr)

    for batch, i in enumerate(range(0, data_source.size(0) - 1, args.bptt)):
        data, targets = get_batch(data_source, i)

        # Starting each batch, we detach the hidden state from how it was previously produced.
        # If we didn't, the model would try backpropagating all the way to start of the dataset.
        hidden = repackage_hidden(hidden)

        opt.zero_grad()
        output, hidden = model(data, hidden)
        loss = criterion(output.view(-1, ntokens), targets)
        loss.backward()

        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        # It must run after backward() so that there are gradients to clip.
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        opt.step()

        # Accumulate a plain float so the running total keeps no autograd graph.
        total_loss += loss.item()

        if batch % args.log_interval == 0 and batch > 0:
            cur_loss = total_loss / args.log_interval
            elapsed = time.time() - start_time

            log = '| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | word_loss {:5.2f} | ppl {:8.2f}'.format(
                epoch, batch,
                len(data_source) // args.bptt,
                lr, elapsed * 1000 / args.log_interval, cur_loss,
                math.exp(cur_loss))

            printhelper.print_log(log_file, log)

            total_loss = 0.0
            start_time = time.time()