コード例 #1
0
def problem_setup(
    net=LinearNet,
    *,
    num_epochs=50,
    batch_size=128,
    num_train_per_node=1024,
    num_test_per_node=128,
    lr=0.01
):
    """Initialize bluefog and build the linear test problem, model and optimizer.

    Calls ``bf.init()``, draws train and test datasets from a
    ``LinearProblemBuilder``, instantiates ``net`` and an Adam optimizer, and
    broadcasts the model parameters and optimizer state from rank 0.

    All hyperparameters are keyword-only and default to the values this
    helper previously hard-coded, so existing calls are unaffected.

    Args:
        net: Model class, called as ``net(input_dim, output_dim)``. The
            resulting instance must expose a ``num_parameters`` attribute.
        num_epochs: Epoch count; not used here, only passed back to the caller.
        batch_size: Batch size used for both the train and test data loaders.
        num_train_per_node: Number of samples requested from
            ``problem_builder.get_dataset`` for the training set.
        num_test_per_node: Number of samples requested from
            ``problem_builder.get_dataset`` for the test set.
        lr: Base learning rate; Adam is given ``lr * bf.size()``.

    Returns:
        Tuple ``(problem_builder, train_dataloader, test_dataloader, model,
        optimizer, num_epochs)``.

    Raises:
        ValueError: If ``num_train_per_node * bf.size()`` is smaller than
            ``model.num_parameters``.
    """
    bf.init()

    # Setup problem: train and test sets come from separate get_dataset calls
    # on the same builder.
    problem_builder = LinearProblemBuilder()
    train_dataset = problem_builder.get_dataset(num_train_per_node)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size)
    test_dataset = problem_builder.get_dataset(num_test_per_node)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

    # Setup model.
    model = net(problem_builder.input_dim, problem_builder.output_dim)

    # The training samples summed over all bf.size() ranks must be at least
    # the number of model parameters. This is an explicit check rather than an
    # `assert`, because asserts are stripped when Python runs with -O.
    total_train_samples = num_train_per_node * bf.size()
    if total_train_samples < model.num_parameters:
        raise ValueError(
            "The number of samples is too small making it an underdetermined system."
        )

    # Setup optimizer. The base lr is multiplied by the number of ranks
    # (presumably a linear-scaling choice -- confirm intent).
    optimizer = optim.Adam(model.parameters(), lr=lr * bf.size())
    # Broadcast from rank 0, presumably so every rank starts from identical
    # model weights and optimizer state.
    bf.broadcast_parameters(model.state_dict(), root_rank=0)
    bf.broadcast_optimizer_state(optimizer, root_rank=0)
    return problem_builder, train_dataloader, test_dataloader, model, optimizer, num_epochs
コード例 #2
0
    L = len(train_sampler)//args.batch_size
elif args.method == "DiffAVRG":
    L = len(train_sampler)
elif args.method == "DiffAVRG_B":
    L = len(train_sampler)//args.batch_size

# The network and optimizer
#
# Build the model(s) and optimizer(s) for the method selected by args.method.
# Every branch creates the network with NN_model().cuda() and broadcasts its
# parameters from root_rank=0, presumably so all workers start from the same
# weights.
# NOTE(review): the branches bind different module-level names.
#   - ExactDiff and ATC_SGD bind `model` / `optimizer`.
#   - The fallback `else` binds `model_0` / `model_i` / `optimizer_0` /
#     `optimizer_i`.
# Code after this block presumably dispatches on args.method the same way --
# verify.
if args.method == "ExactDiff":
    model = NN_model().cuda()
    bf.broadcast_parameters(model.state_dict(), root_rank=0)
    # ExactDiff receives args.comm as-is (a raw value). ATC_SGD below instead
    # maps it to a bf.CommunicationType member.
    optimizer = ExactDiff(model.parameters(), lr=lr, L=L, communication_type=args.comm)
elif args.method == "ATC_SGD":
    model = NN_model().cuda()
    bf.broadcast_parameters(model.state_dict(), root_rank=0)
    # Plain SGD whose state is broadcast from rank 0 and then wrapped in
    # bluefog's adapt-then-combine distributed optimizer.
    optimizer = optim.SGD(model.parameters(), lr=lr)
    bf.broadcast_optimizer_state(optimizer, root_rank=0)
    # Only args.comm equal to "allreduce" selects allreduce; every other value
    # (including typos) silently selects neighbor_allreduce.
    optimizer = bf.DistributedAdaptThenCombineOptimizer(optimizer, model,
        communication_type=bf.CommunicationType.allreduce if args.comm=="allreduce" else
                           bf.CommunicationType.neighbor_allreduce)
else:
    # Fallback for every other args.method value. DiffAVRG and DiffAVRG_B land
    # here, and so would any unrecognized method name.
    # model_i is deep-copied from model_0 BEFORE either is broadcast. Both are
    # then broadcast from rank 0, so they presumably end up holding the same
    # rank-0 weights.
    model_0 = NN_model().cuda()
    model_i = copy.deepcopy(model_0)
    bf.broadcast_parameters(model_0.state_dict(), root_rank=0)
    bf.broadcast_parameters(model_i.state_dict(), root_rank=0)
    # optimizer_0 is built without lr/L/communication_type. Only optimizer_i
    # receives them.
    optimizer_0 = DiffAVRG_0(model_0.parameters())
    optimizer_i = DiffAVRG(model_i.parameters(), lr=lr, L=L, communication_type=args.comm)

n_epoch = args.n_epoch  # the number of epochs
loss_fn = nn.CrossEntropyLoss()

# Results container, filled by code not visible here (presumably per-epoch
# metrics -- confirm).
res_list = []