def test_main(batch_size=25):
    """Load the ``Net`` weights stored at ``model_file`` and evaluate them on the test set.

    Args:
        batch_size: Batch size passed to ``get_test_loader`` (default 25,
            the value that was previously hard-coded).
    """
    model = Net()
    # Deserialize onto the CPU: the freshly built model's parameters are on
    # the CPU at this point, and this keeps a checkpoint saved from a GPU
    # loadable on a CPU-only machine. The model is moved to ``device`` below.
    model.load_state_dict(torch.load(model_file, map_location='cpu'))
    model.to(device)
    # Put layers such as dropout / batch-norm into inference mode before
    # evaluating; harmless if ``evaluate`` also does this.
    model.eval()

    test_loader = get_test_loader(batch_size)

    print('=========')
    print('Test set:')
    # Gradient tracking is not needed for evaluation.
    with torch.no_grad():
        evaluate(model, test_loader)
def test_main(batch_size=25):
    """Load ``Net`` weights from ``model_file`` and report the "Simple" test-set evaluation.

    Args:
        batch_size: Batch size passed to ``get_test_loader`` (default 25,
            the value that was previously hard-coded).
    """
    print('Reading', model_file)
    model = Net()
    # Deserialize onto the CPU: the freshly built model's parameters are on
    # the CPU at this point, and this keeps a checkpoint saved from a GPU
    # loadable on a CPU-only machine. The model is moved to ``device`` below.
    model.load_state_dict(torch.load(model_file, map_location='cpu'))
    model.to(device)
    # Put layers such as dropout / batch-norm into inference mode before
    # evaluating; harmless if ``evaluate`` also does this.
    model.eval()

    test_loader = get_test_loader(batch_size)

    print('=========')
    print('Simple:')
    # Gradient tracking is not needed for evaluation.
    with torch.no_grad():
        evaluate(model, test_loader)
def main():
    """Train ``Net`` with PyTorch Lightning, then run Lightning's test loop.

    Every loader uses a batch size of 25. Training runs for 50 epochs with the
    ``'ddp'`` accelerator on all GPUs (``gpus=-1``), and the wall-clock time
    spent inside ``fit`` is printed afterwards.
    """
    net = Net()

    bs = 25
    fit_loaders = (get_train_loader(bs), get_validation_loader(bs))

    # A Horovod variant was also tried here:
    #   pl.Trainer(gpus=1, max_epochs=50, accelerator='horovod', checkpoint_callback=False)
    trainer = pl.Trainer(gpus=-1, max_epochs=50, accelerator='ddp')

    began = datetime.now()
    trainer.fit(net, *fit_loaders)  # positional: train loader, then validation loader
    elapsed = datetime.now() - began
    print('Total training time: {}.'.format(elapsed))

    # The weights are not written to ``model_file`` here; an explicit
    # ``torch.save(net.state_dict(), model_file)`` step is disabled.

    # No model is passed to ``test``, so Lightning chooses which weights to
    # test (presumably the best checkpoint from ``fit``; confirm against the
    # installed Lightning version).
    trainer.test(test_dataloaders=get_test_loader(bs))
Example #4
0
def test_main(batch_size=25):
    """Evaluate the pretrained and fine-tuned ``PretrainedNet`` checkpoints on the test set.

    The weights in ``model_file`` are reported under "Pretrained:" and those
    in ``model_file_ft`` under "Finetuned:", both on the same test loader.

    Args:
        batch_size: Batch size passed to ``get_test_loader`` (default 25,
            the value that was previously hard-coded).
    """
    def load(path):
        # Build a fresh network and deserialize onto the CPU, where the new
        # model's parameters live; this also keeps GPU-saved checkpoints
        # loadable on CPU-only machines. Then move it to ``device`` and switch
        # dropout / batch-norm style layers to inference mode.
        net = PretrainedNet()
        net.load_state_dict(torch.load(path, map_location='cpu'))
        net.to(device)
        net.eval()
        return net

    def report(title, net, loader):
        # Print a banner, then evaluate without gradient tracking.
        print('=========')
        print(title)
        with torch.no_grad():
            evaluate(net, loader)

    model = load(model_file)
    test_loader = get_test_loader(batch_size)
    report('Pretrained:', model, test_loader)

    model = load(model_file_ft)
    report('Finetuned:', model, test_loader)