Example #1
    # global_step = checkpoint['global_step']
    global_step = 0

    for epoch in range(10):
        pbar = tqdm(TextDataLoaderIterator(txt_files, batch_size=16, block_len=64))
        for data_loader in pbar:
            for seq, mask in data_loader:
                seq, mask = seq.to(device), mask.to(device)

                output, *_ = model(seq.masked_fill(mask==0, 0), src_mask=mask)
                # loss = criterion(output[mask==0], seq[mask==0])
                loss = criterion(output.view(-1, 10000), seq.view(-1))  # flatten logits to (tokens, vocab_size=10000)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                lr_scheduler.step(loss)
                global_step += 1

                writer.add_scalar('loss', loss.item(), global_step)
                writer.add_scalar('lr', optimizer.param_groups[0]['lr'], global_step)
                pbar.set_postfix({'loss': loss.item(), 'lr': optimizer.param_groups[0]['lr']})

        torch.save({
            'global_step': global_step,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'lr_scheduler_state_dict': lr_scheduler.state_dict(),
            }, 'models/lm/latest.pth')
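
The commented-out `global_step = checkpoint['global_step']` line at the top of this snippet suggests the run can be resumed from the file saved after each epoch. A minimal sketch of that resume step, assuming the same `model`, `optimizer`, and `lr_scheduler` objects have already been constructed and `device` is set:

import os
import torch

ckpt_path = 'models/lm/latest.pth'
global_step = 0
if os.path.exists(ckpt_path):
    # restore the states written by torch.save(...) at the end of each epoch above
    checkpoint = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])
    global_step = checkpoint['global_step']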
Example #2
    model.eval()

    err = AverageMeter('loss')

    loader = DataLoader(test,
                        pin_memory=True,
                        num_workers=4,
                        batch_size=bptt,
                        drop_last=True)
    progress = tqdm(loader)

    hidden = model.step_init(batch_size)

    with torch.no_grad():
        for inputs, targets in progress:
            inputs = inputs.cuda(non_blocking=True)
            targets = targets.cuda(non_blocking=True)

            output, hidden = model.step_forward(inputs, hidden)

            loss = criterion(output.view(-1, num_labels), targets.view(-1))

            err.update(loss.item())

            progress.set_description('epoch %d %s' % (epoch + 1, err))

    sys.stderr.write('\n')

    torch.save(model.state_dict(), 'exp/lm.bin')
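
`AverageMeter` is a helper this snippet assumes but does not define. A minimal sketch of what it could look like, judging only from how it is used here and in example #4 below (a name, `.update()` with a scalar, string formatting for the progress bar, and a `.summary()` that logs to a TensorBoard writer); this is an assumption, not the original class:

class AverageMeter:
    """Keeps a running average of a scalar metric such as the loss."""

    def __init__(self, name):
        self.name = name
        self.sum = 0.0
        self.count = 0

    def update(self, value, n=1):
        self.sum += value * n
        self.count += n

    @property
    def avg(self):
        return self.sum / max(self.count, 1)

    def __str__(self):
        # used by the progress bar: 'epoch 1 loss 4.1234'
        return '%s %.4f' % (self.name, self.avg)

    def summary(self, writer, step):
        # used in example #4 to log the running average to TensorBoard
        writer.add_scalar(self.name, self.avg, step)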
Example #3
File: train.py Project: mzgubic/autothesis
def train(settings, model_dir):

    # training and sampling
    temperature = 0.5
    how_many = 70
    vocab = generate.get_vocab(args.token, small=args.small)

    # create the vocab, model, (and embedding)
    if args.token == 'word':
        emb = generate.get_embedding('word2vec')
        input_size = emb.vectors.shape[1]
        output_size = emb.vectors.shape[0]
    elif args.token == 'character':
        emb = None
        input_size = vocab.size
        output_size = vocab.size

    model = LanguageModel(args.cell, input_size, args.hidden_size, output_size)

    # create criterion and optimiser
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

    # create the validation set
    n_valid = 10000
    valid_gen = generate.generate('valid', token=args.token, max_len=args.max_len, small=args.small, batch_size=n_valid)
    for valid_batch, valid_labels in valid_gen:
        # one hot encode
        if args.token == 'character':
            valid_batch = generate.one_hot_encode(valid_batch, vocab)
        # or embed
        elif args.token == 'word':
            valid_batch = generate.w2v_encode(valid_batch, emb, vocab)
        valid_batch, valid_labels = torch.Tensor(valid_batch), torch.Tensor(valid_labels).long()
        break

    # how many epochs do we need?
    batches_per_epoch = generate.get_n_batches_in_epoch('train', args.token, args.batch_size, args.max_len, args.small)

    # training settings
    every_n = int(batches_per_epoch/args.n_saves) if not args.debug else 50
    running_loss = 0
    training_losses = []
    valid_losses = []
    t0 = time.time()
 
    # dump the settings
    pickle.dump(settings, open(model_dir/ 'settings.pkl', 'wb'))
    out_stream = model_dir / 'out_stream.txt'

    # run the training loop
    for epoch in range(1, args.n_epochs+1):

        opening = ['', '#'*20, '# Epoch {} (t={:2.2f}h)'.format(epoch, (time.time() - t0)/3600.), '#'*20, '']
        for txt in opening:
            utils.report(txt, out_stream)

        # create the generator for each epoch
        train_gen = generate.generate('train', token=args.token, max_len=args.max_len,
                                      small=args.small, batch_size=args.batch_size)
        for i, (batch, labels) in enumerate(train_gen):

            # one hot encode
            if args.token == 'character':
                batch = generate.one_hot_encode(batch, vocab)
            # or embed
            elif args.token == 'word':
                batch = generate.w2v_encode(batch, emb, vocab)

            # turn into torch tensors
            batch = torch.Tensor(batch)
            labels = torch.Tensor(labels).long()

            # zero the gradients
            optimizer.zero_grad()

            # forward and backward pass and optimisation step
            outputs = model(batch)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # monitor the losses
            running_loss += loss.item()  # .item() keeps a plain float so the autograd graph is not retained
            if i % every_n == (every_n-1):

                # append the training losses
                training_losses.append(float(running_loss/every_n))
                running_loss = 0

                # compute the validation loss without tracking gradients
                with torch.no_grad():
                    valid_outputs = model(valid_batch)
                    valid_losses.append(float(criterion(valid_outputs, valid_labels)))

                # monitor progress
                monitor = ['\n{}/{} done'.format(i+1, batches_per_epoch)]
                monitor.append(generate.compose(model, vocab, emb, 'The Standard Model of', temperature, how_many))
                for m in monitor:
                    utils.report(m, out_stream)
                
                # save the model
                torch.save(model.state_dict(), model_dir/'checkpoints'/'epoch{}_step_{}.pt'.format(epoch, round(i/every_n)))

            if i >= 1000 and args.debug:
                break
    
    # save information
    dt = (time.time() - t0)
    time_txt = '\ntime taken: {:2.2f}h\n'.format(dt/3600.)
    utils.report(time_txt, out_stream)
    utils.report(str(dt/3600.), model_dir/'time.txt')
        
    loss_dict = {'train':training_losses, 'valid':valid_losses, 'time_taken':dt}
    pickle.dump(loss_dict, open(model_dir/ 'losses.pkl', 'wb'))

    # evaluate
    evaluate.plot_losses(model_dir)
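
`evaluate.plot_losses(model_dir)` is called at the end but not shown. A minimal sketch of what it could do, assuming it simply reads the `losses.pkl` written above and plots the two curves with matplotlib; the function body here is an illustration, not the project's actual code:

import pickle
import matplotlib.pyplot as plt

def plot_losses(model_dir):
    # load the dict written at the end of train(): {'train': [...], 'valid': [...], 'time_taken': dt}
    with open(model_dir / 'losses.pkl', 'rb') as f:
        loss_dict = pickle.load(f)

    plt.plot(loss_dict['train'], label='train')
    plt.plot(loss_dict['valid'], label='valid')
    plt.xlabel('checkpoint index')
    plt.ylabel('loss')
    plt.legend()
    plt.savefig(model_dir / 'losses.png')
    plt.close()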
Example #4
    err.summary(writer, epoch)
    grd.summary(writer, epoch)

    err = AverageMeter('Loss/test')

    loader = DataLoaderCuda(test, batch_size=bptt, drop_last=True)

    hidden = model.step_init(batch_size)

    with torch.no_grad():

        for inputs, targets in loader:

            output, hidden = model.step_forward(inputs, hidden)

            loss = criterion(output, targets.view(-1))

            err.update(loss.item())

            loader.set_description('Epoch %d %s' % (epoch, err))

    sys.stderr.write('\n')

    err.summary(writer, epoch)

    writer.flush()

    torch.save(model.state_dict(), writer.log_dir + '/model%d.bin' % epoch)

writer.close()
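
`DataLoaderCuda` in this last example is a project-specific wrapper that is not shown. Judging from example #2 and from the calls above (iteration yielding `(inputs, targets)` and `set_description` for progress output), a minimal sketch could combine a regular `DataLoader`, `tqdm`, and a per-batch `.cuda()` transfer; this is an assumption, not the original implementation:

from torch.utils.data import DataLoader
from tqdm import tqdm

class DataLoaderCuda:
    """Wraps DataLoader with a tqdm progress bar and moves each batch to the GPU."""

    def __init__(self, dataset, **kwargs):
        self.loader = DataLoader(dataset, pin_memory=True, **kwargs)
        self.pbar = None

    def __iter__(self):
        self.pbar = tqdm(self.loader)
        for inputs, targets in self.pbar:
            yield inputs.cuda(non_blocking=True), targets.cuda(non_blocking=True)

    def set_description(self, desc):
        self.pbar.set_description(desc)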