def some_trainable():
    """Check ``add_param_group`` when not every parameter is flagged trainable.

    Creates six random column tensors, sets ``requires_grad = True`` on all
    of them except the first, wraps them in a ``ZeroRedundancyOptimizer``
    built on ``SGD``, and asserts that the wrapper goes from one parameter
    group to two after ``add_param_group``; the wrapper's ``optim``
    attribute must report two groups as well.
    """
    tensors = [torch.rand(rows, 1) for rows in (100, 3, 5, 2, 6, 4)]

    # The first tensor is deliberately left as created; only the remaining
    # ones are flagged trainable. Per the original note, this is meant to
    # enforce size-based partitioning.
    for tensor in tensors[1:]:
        tensor.requires_grad = True

    zero_opt = ZeroRedundancyOptimizer(tensors, optimizer_class=SGD, lr=0.1)
    assert len(zero_opt.param_groups) == 1

    extra_group = {"params": [torch.rand(3, 1)]}
    zero_opt.add_param_group(extra_group)

    assert len(zero_opt.param_groups) == 2
    assert len(zero_opt.optim.param_groups) == 2
def all_trainable():
    """Check ``add_param_group`` when every parameter is trainable.

    The size pattern ``[9, 7, 5, 3]`` is repeated ``self.world_size`` times
    and the final entry is dropped before the tensors are created. The
    group added afterwards holds a single ``(3, 1)`` tensor, i.e. the same
    number of elements as the dropped entry. The test then asserts that the
    parameters held by ``o.optim`` add up to exactly ``sum([9, 7, 5, 3])``
    elements.

    NOTE(review): ``self`` is not a parameter; it is read from an enclosing
    scope that is not visible here. ``optim`` is presumably this rank's
    local optimizer -- confirm against ZeroRedundancyOptimizer.
    """
    base_sizes = [9, 7, 5, 3]
    repeated_sizes = base_sizes * self.world_size

    # Drop the trailing size; the group added below supplies a 3-element
    # tensor in its place.
    tensors = [torch.rand(rows, 1) for rows in repeated_sizes[:-1]]

    # Flag every tensor as trainable; per the original note, this is meant
    # to enforce size-based partitioning.
    for tensor in tensors:
        tensor.requires_grad = True

    zero_opt = ZeroRedundancyOptimizer(tensors, optimizer_class=SGD, lr=0.1)
    assert len(zero_opt.param_groups) == 1

    extra_group = {"params": [torch.rand(3, 1)]}
    zero_opt.add_param_group(extra_group)
    assert len(zero_opt.param_groups) == 2

    # Verify that the added group landed in the partition that brings the
    # element count held by ``optim`` up to one full copy of the pattern.
    held_numel = 0
    for group in zero_opt.optim.param_groups:
        for tensor in group["params"]:
            held_numel += tensor.numel()
    assert held_numel == sum(base_sizes)

    assert len(zero_opt.optim.param_groups) == 2