def test_main():
    """Load the trained model from ``model_file`` and evaluate it on the test set.

    Runs inference only; no gradients are tracked.
    """
    model = Net()
    # map_location lets the checkpoint load even when it was saved on a
    # different device (e.g. a GPU checkpoint restored on a CPU-only host).
    model.load_state_dict(torch.load(model_file, map_location=device))
    model.to(device)

    test_loader = get_test_loader(25)  # batch size 25

    print('=========')
    print('Test set:')
    # Disable autograd during evaluation to save memory and compute.
    with torch.no_grad():
        evaluate(model, test_loader)
def test_main():
    """Load the trained model from ``model_file`` and evaluate it on the test set.

    This variant's loader also returns a sampler, which is forwarded to
    ``evaluate``. Runs inference only; no gradients are tracked.
    """
    print('Reading', model_file)
    model = Net()
    # map_location lets the checkpoint load even when it was saved on a
    # different device (e.g. a GPU checkpoint restored on a CPU-only host).
    model.load_state_dict(torch.load(model_file, map_location=device))
    model.to(device)

    test_loader, test_sampler = get_test_loader(25)  # batch size 25

    print('=========')
    print('Simple:')
    # Disable autograd during evaluation to save memory and compute.
    with torch.no_grad():
        evaluate(model, test_loader, test_sampler)
def test_main():
    """Horovod-distributed test entry point.

    Rank 0 reads the checkpoint from disk; the weights are then broadcast
    to every worker so all ranks evaluate the same model. Runs inference
    only; no gradients are tracked.
    """
    model = Net()
    # Only rank 0 touches the filesystem; map_location lets the checkpoint
    # load even when it was saved on a different device.
    if hvd.rank() == 0:
        model.load_state_dict(torch.load(model_file, map_location=device))
    model.to(device)
    # Horovod: propagate rank 0's (loaded) weights to all other workers.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)

    test_loader, test_sampler = get_test_loader(25)  # batch size 25

    # Print headers only once, from rank 0, to avoid duplicated output.
    if hvd.rank() == 0:
        print('=========')
        print('Test set:')
    with torch.no_grad():
        evaluate(model, test_loader, test_sampler)
def train_main():
    """Train the model with Horovod-distributed SGD and periodic validation.

    Each epoch trains on the distributed training set, then evaluates on
    the validation set. After all epochs, rank 0 saves the trained weights
    to ``model_file``.
    """
    model = Net().to(device)
    if hvd.rank() == 0:
        print(model)

    # Horovod: broadcast initial parameters from rank 0 so every worker
    # starts from identical weights.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)

    # Horovod: scale learning rate by the number of workers to keep the
    # effective update size comparable to single-process training.
    lr = 0.05
    optimizer = optim.SGD(model.parameters(), lr=lr * hvd.size())

    # Horovod: wrap the optimizer so gradients are averaged across workers.
    optimizer = hvd.DistributedOptimizer(
        optimizer, named_parameters=model.named_parameters())

    criterion = nn.BCELoss()

    batch_size = 25
    train_loader, train_sampler = get_train_loader(batch_size)
    validation_loader, validation_sampler = get_validation_loader(batch_size)

    log = get_tensorboard('simple_hvd')
    epochs = 50

    start_time = datetime.now()
    for epoch in range(1, epochs + 1):
        # Re-seed the distributed sampler each epoch so shuffling differs
        # per epoch while workers stay consistent with each other.
        train_sampler.set_epoch(epoch)
        train(model, train_loader, train_sampler, criterion, optimizer,
              epoch, log)

        # Validation runs without gradient tracking.
        with torch.no_grad():
            if hvd.rank() == 0:
                print('\nValidation for epoch {}:'.format(epoch))
            evaluate(model, validation_loader, validation_sampler,
                     criterion, epoch, log)

    end_time = datetime.now()
    if hvd.rank() == 0:
        print('Total training time: {}.'.format(end_time - start_time))

        # Only rank 0 writes the checkpoint, avoiding concurrent writes
        # to the same file from multiple workers.
        torch.save(model.state_dict(), model_file)
        print('Wrote model to', model_file)