# Example 1
def train(args, labeled, resume_from, ckpt_file):
    """Train a TextSentiment model on AG_NEWS and save a checkpoint.

    Args:
        args: dict with keys "batch_size", "train_epochs", "N_GRAMS",
            "EMBED_DIM" and "EXPT_DIR".
        labeled: unused here; kept for interface compatibility with callers.
        resume_from: checkpoint filename inside args["EXPT_DIR"] to resume
            from, or None to start fresh.
        ckpt_file: filename (inside args["EXPT_DIR"]) for the saved checkpoint.
    """
    batch_size = args["batch_size"]
    lr = 4.0  # aggressive SGD rate, paired with the StepLR decay below

    epochs = args["train_epochs"]

    # exist_ok avoids the race between the isdir() check and mkdir().
    os.makedirs('./.data', exist_ok=True)

    # Published as globals so other routines in this file can reuse the
    # datasets and model dimensions.
    global train_dataset, test_dataset
    train_dataset, test_dataset = text_classification.DATASETS['AG_NEWS'](
        root='./.data', ngrams=args["N_GRAMS"], vocab=None)

    global VOCAB_SIZE, EMBED_DIM, NUN_CLASS
    VOCAB_SIZE = len(train_dataset.get_vocab())
    EMBED_DIM = args["EMBED_DIM"]
    NUN_CLASS = len(train_dataset.get_labels())

    trainloader = DataLoader(train_dataset,
                             batch_size=batch_size,
                             # FIX: reshuffle training data each epoch;
                             # the original used shuffle=False.
                             shuffle=True,
                             collate_fn=generate_batch)
    net = TextSentiment(VOCAB_SIZE, EMBED_DIM, NUN_CLASS).to(device)
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.SGD(net.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)

    if resume_from is not None:
        ckpt = torch.load(os.path.join(args["EXPT_DIR"], resume_from))
        net.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
    else:
        getdatasetstate()

    # Pre-initialize so the accuracy summary below cannot raise NameError
    # when epochs == 0.
    train_acc = 0

    net.train()
    for epoch in tqdm(range(epochs), desc="Training"):
        running_loss = 0.0
        train_acc = 0
        for data in trainloader:
            text, offsets, cls = data
            text, offsets, cls = (text.to(device), offsets.to(device),
                                  cls.to(device))
            outputs = net(text, offsets)
            loss = criterion(outputs, cls)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_acc += (outputs.argmax(1) == cls).sum().item()
            running_loss += loss.item()
        scheduler.step()  # decay the learning rate after every epoch

    print("Finished Training. Saving the model as {}".format(ckpt_file))
    # NOTE: the reported accuracy is from the final epoch only.
    print("Training accuracy: {}".format(
        (train_acc / len(train_dataset) * 100)))
    ckpt = {"model": net.state_dict(), "optimizer": optimizer.state_dict()}
    torch.save(ckpt, os.path.join(args["EXPT_DIR"], ckpt_file))

    return
# Example 2
def main():
    """Train a TextSentiment classifier on AG_NEWS, validate each epoch,
    checkpoint whenever validation loss improves, then report test metrics.
    """
    # FIX: valid PyTorch device strings are "cuda"/"cpu" — "gpu" is not,
    # and model.to("gpu") would raise a RuntimeError.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    train_dataset, test_dataset = get_dataset()
    VOCAB_SIZE = len(train_dataset.get_vocab())
    EMBED_DIM = 32
    NUM_CLASS = len(train_dataset.get_labels())
    model = TextSentiment(VOCAB_SIZE, EMBED_DIM, NUM_CLASS).to(device)
    BATCH_SIZE = 16
    N_EPOCHS = 5
    min_valid_loss = float('inf')

    # Multi-class classification, so cross-entropy is the appropriate loss.
    criterion = torch.nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=4.0)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)

    # Hold out 5% of the training split for validation.
    train_len = int(len(train_dataset) * 0.95)
    sub_train_, sub_valid_ = random_split(
        train_dataset, [train_len, len(train_dataset) - train_len])
    train_loader = DataLoader(sub_train_,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              collate_fn=generate_batch)
    valid_loader = DataLoader(sub_valid_,
                              batch_size=BATCH_SIZE,
                              collate_fn=generate_batch)
    test_loader = DataLoader(test_dataset,
                             batch_size=BATCH_SIZE,
                             collate_fn=generate_batch)

    for epoch in tqdm(range(N_EPOCHS)):

        start_time = time.time()
        train_loss, train_acc = train_fn(dataLoader=train_loader,
                                         model=model,
                                         optimizer=optimizer,
                                         scheduler=scheduler,
                                         criterion=criterion,
                                         device=device)
        valid_loss, valid_acc = evaluate_fn(dataLoader=valid_loader,
                                            model=model,
                                            criterion=criterion,
                                            device=device)

        # divmod yields integer minutes; the original passed a float to %d.
        mins, secs = divmod(int(time.time() - start_time), 60)

        print('Epoch: %d' % (epoch + 1),
              " | time in %d minutes, %d seconds" % (mins, secs))
        print(
            f'\tLoss: {train_loss:.4f}(train)\t|\tAcc: {train_acc * 100:.1f}%(train)'
        )
        print(
            f'\tLoss: {valid_loss:.4f}(valid)\t|\tAcc: {valid_acc * 100:.1f}%(valid)'
        )
        # Checkpoint only when validation loss improves.
        if valid_loss < min_valid_loss:
            torch.save(model.state_dict(),
                       "../weights/text_news{}.pth".format(valid_loss))
            print(min_valid_loss, "--------->>>>>>>>", valid_loss)
            min_valid_loss = valid_loss

    print('Checking the results of test dataset...')
    test_loss, test_acc = evaluate_fn(dataLoader=test_loader,
                                      model=model,
                                      criterion=criterion,
                                      device=device)
    print(
        f'\tLoss: {test_loss:.4f}(test)\t|\tAcc: {test_acc * 100:.1f}%(test)')
# Example 3 (fragment: the enclosing function's header was lost in extraction)
                                 shuffle=True,
                                 collate_fn=collate_batch)

    for epoch in range(1, num_epochs + 1):
        epoch_start_time = time.time()
        train(train_dataloader, model, optimizer, criterion, epoch)
        accu_val = evaluate(valid_dataloader, model)
        scheduler.step()
        print('-' * 59)
        print('| end of epoch {:3d} | time: {:5.2f}s | '
              'valid accuracy {:8.3f} '.format(epoch,
                                               time.time() - epoch_start_time,
                                               accu_val))
        print('-' * 59)

    print('Checking the results of test dataset.')
    accu_test = evaluate(test_dataloader, model)
    print('test accuracy {:8.3f}'.format(accu_test))

    if args.save_model_path:
        print("Saving model to {}".format(args.save_model_path))
        torch.save(model.state_dict(), args.save_model_path)
        ''' this shows how to script the model
        sm = torch.jit.script(model)
        sm.save("model_scipted.pt")
        '''

    if args.dictionary is not None:
        print("Save vocab to {}".format(args.dictionary))
        torch.save(vocab, args.dictionary)