def problem_setup(
    net=LinearNet,
    *,
    num_epochs=50,
    batch_size=128,
    num_train_per_node=1024,
    num_test_per_node=128,
    lr=0.01,
):
    """Initialise Bluefog and build the data, model and optimizer for the linear problem.

    Args:
        net: Model class (or factory) called as ``net(input_dim, output_dim)``;
            the resulting model must expose ``num_parameters``.
        num_epochs: Number of training epochs; returned unchanged so the
            caller's training loop can use it.
        batch_size: Batch size shared by the train and test dataloaders.
        num_train_per_node: Number of training samples requested from
            ``LinearProblemBuilder.get_dataset`` on this rank.
        num_test_per_node: Number of test samples requested from
            ``LinearProblemBuilder.get_dataset`` on this rank.
        lr: Base learning rate; the Adam optimizer is built with
            ``lr * bf.size()``.

    Returns:
        Tuple ``(problem_builder, train_dataloader, test_dataloader, model,
        optimizer, num_epochs)``.

    Raises:
        ValueError: If ``num_train_per_node * bf.size()`` is smaller than
            ``model.num_parameters``, i.e. there are fewer total training
            samples than parameters and the system is underdetermined.
    """
    bf.init()
    world_size = bf.size()

    # Setup problem
    problem_builder = LinearProblemBuilder()
    train_dataset = problem_builder.get_dataset(num_train_per_node)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size)
    test_dataset = problem_builder.get_dataset(num_test_per_node)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

    # Setup model
    model = net(problem_builder.input_dim, problem_builder.output_dim)
    # Explicit check rather than ``assert`` so the validation still runs
    # under ``python -O`` now that the sample counts are caller-supplied.
    if num_train_per_node * world_size < model.num_parameters:
        raise ValueError(
            "The number of samples is too small making it an underdetermined system."
        )

    # Setup optimizer.
    # NOTE(review): the learning rate is scaled by the number of ranks,
    # presumably a linear-scaling heuristic -- confirm this is intended for Adam.
    optimizer = optim.Adam(model.parameters(), lr=lr * world_size)

    # Make every rank start from rank 0's parameters and optimizer state.
    bf.broadcast_parameters(model.state_dict(), root_rank=0)
    bf.broadcast_optimizer_state(optimizer, root_rank=0)
    return (
        problem_builder,
        train_dataloader,
        test_dataloader,
        model,
        optimizer,
        num_epochs,
    )
L = len(train_sampler)//args.batch_size elif args.method == "DiffAVRG": L = len(train_sampler) elif args.method == "DiffAVRG_B": L = len(train_sampler)//args.batch_size # The network and optimizer if args.method == "ExactDiff": model = NN_model().cuda() bf.broadcast_parameters(model.state_dict(), root_rank=0) optimizer = ExactDiff(model.parameters(), lr=lr, L=L, communication_type=args.comm) elif args.method == "ATC_SGD": model = NN_model().cuda() bf.broadcast_parameters(model.state_dict(), root_rank=0) optimizer = optim.SGD(model.parameters(), lr=lr) bf.broadcast_optimizer_state(optimizer, root_rank=0) optimizer = bf.DistributedAdaptThenCombineOptimizer(optimizer, model, communication_type=bf.CommunicationType.allreduce if args.comm=="allreduce" else bf.CommunicationType.neighbor_allreduce) else: model_0 = NN_model().cuda() model_i = copy.deepcopy(model_0) bf.broadcast_parameters(model_0.state_dict(), root_rank=0) bf.broadcast_parameters(model_i.state_dict(), root_rank=0) optimizer_0 = DiffAVRG_0(model_0.parameters()) optimizer_i = DiffAVRG(model_i.parameters(), lr=lr, L=L, communication_type=args.comm) n_epoch = args.n_epoch # the number of epochs loss_fn = nn.CrossEntropyLoss() res_list = []