def test_dirichlet_shape(self):
    dist = Dirichlet(torch.Tensor([[0.6, 0.3], [1.6, 1.3], [2.6, 2.3]]))
    self.assertEqual(dist._batch_shape, torch.Size((3,)))
    self.assertEqual(dist._event_shape, torch.Size((2,)))
    self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
    self.assertEqual(dist.sample((5, 4)).size(), torch.Size((5, 4, 3, 2)))
    self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3,)))
    self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
def select_action(self, obs):
    concentration, value = self.forward(obs)
    m = Dirichlet(concentration)
    action = m.sample()
    self.saved_actions.append(SavedAction(m.log_prob(action), value))
    return list(action.cpu().numpy())
def test_forward(self):
    source_dirichlet = Dirichlet(torch.ones(10))
    batch_size = 12
    observed_data = source_dirichlet.sample((batch_size,))
    # Check that the forward pass runs and returns membership probabilities
    _, membership_probs = self.dmm(observed_data)
    # Ensure membership probabilities sum to one
    for prob in membership_probs.sum(dim=1):
        self.assertAlmostEqual(prob.item(), 1, places=5)
def test_dirichlet_log_prob(self):
    num_samples = 10
    alpha = torch.exp(torch.randn(5))
    dist = Dirichlet(alpha)
    x = dist.sample((num_samples,))
    actual_log_prob = dist.log_prob(x)
    for i in range(num_samples):
        expected_log_prob = scipy.stats.dirichlet.logpdf(x[i].numpy(), alpha.numpy())
        self.assertAlmostEqual(actual_log_prob[i], expected_log_prob, places=3)
def test_backward(self):
    source_dirichlet = Dirichlet(torch.ones(10))
    batch_size = 12
    observed_data = source_dirichlet.sample((batch_size,))
    # Obtain loss
    nll, _ = self.dmm(observed_data)
    loss = nll.sum()
    # Check that gradients are populated for all parameters
    loss.backward()
    for param in self.dmm.parameters():
        self.assertIsNotNone(param.grad)
def sample_l1_sphere(device, shape):
    '''Sample uniformly from the unit l1 sphere, i.e. the cross polytope.

    Inputs:
        device: 'cpu' | 'cuda' | other torch devices
        shape: a pair (batchsize, dim)
    Outputs:
        matrix of shape `shape` such that each row is a sample.
    '''
    batchsize, dim = shape
    dirdist = Dirichlet(concentration=torch.ones(dim, device=device))
    noises = dirdist.sample([batchsize])
    signs = torch.sign(torch.rand_like(noises) - 0.5)
    return noises * signs
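# A minimal usage sketch (not part of the original snippet), assuming
# `import torch` and `from torch.distributions import Dirichlet` as required
# by sample_l1_sphere above. Each row lies on the unit l1 sphere, so its
# absolute values sum to 1.
import torch

points = sample_l1_sphere('cpu', (4, 3))
assert points.shape == (4, 3)
assert torch.allclose(points.abs().sum(dim=1), torch.ones(4))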
def backward(ctx, grad_output):
    grad_input = None
    if not ctx.train:
        raise RuntimeError('Running backward on shake when train is False')
    if ctx.needs_input_grad[0]:
        dist = Dirichlet(th.full((ctx.xsh[ctx.dim],), ctx.concentration))
        beta = dist.sample(sample_shape=th.Size([ctx.xsh[ctx.batchdim]]))
        beta = beta.to(th.device(ctx.dev))
        sh = [1 for _ in range(len(ctx.xsh))]
        sh[ctx.batchdim], sh[ctx.dim] = ctx.xsh[ctx.batchdim], ctx.xsh[ctx.dim]
        beta = beta.view(*sh)
        grad_output = grad_output.unsqueeze(ctx.dim).expand(*ctx.xsh)
        grad_input = grad_output * beta
    return grad_input, None, None, None, None, None
def forward(ctx, x, dim, batchdim, concentration, train):
    ctx.dim = dim
    ctx.batchdim = batchdim
    ctx.concentration = concentration
    xsh, ctx.xsh = [x.shape] * 2
    ctx.dev = x.device
    ctx.train = train
    if train:
        # Randomly sample per-sample mixing weights from a Dirichlet distribution
        dist = Dirichlet(th.full((xsh[dim],), concentration))
        alpha = dist.sample(sample_shape=th.Size([xsh[batchdim]]))
        alpha = alpha.to(th.device(x.device))
        sh = [1 for _ in range(len(xsh))]
        sh[batchdim], sh[dim] = xsh[batchdim], xsh[dim]
        alpha = alpha.view(*sh)
        y = (x * alpha).sum(dim)
    else:
        y = x.mean(dim)
    return y
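# A minimal sketch (an assumption, not shown in the source) of how the
# forward/backward pair above is typically packaged: as staticmethods of a
# torch.autograd.Function subclass invoked via .apply. The class name ShakeFn
# is hypothetical; it reuses the two module-level functions defined above and
# assumes the same imports they rely on.
import torch as th
from torch.distributions import Dirichlet

class ShakeFn(th.autograd.Function):
    forward = staticmethod(forward)    # Dirichlet-weighted sum over `dim` when training
    backward = staticmethod(backward)  # re-samples fresh Dirichlet weights for the gradient

x = th.randn(8, 2, 32, requires_grad=True)   # (batch, branch, features)
y = ShakeFn.apply(x, 1, 0, 1.0, True)        # dim=1, batchdim=0, concentration=1.0, train=True
y.sum().backward()                           # populates x.grad using freshly sampled weights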
def make(self, mode: int, num_clients: int, show_plots: bool = False, **kwargs) -> None:
    if os.path.exists(self.root_dir / "client_data"):
        shutil.rmtree(self.root_dir / "client_data")
    client_data_path = Path(self.root_dir / "client_data")
    client_data_path.mkdir()

    if not isinstance(self.test_data.targets, torch.Tensor):
        self.test_data.targets = torch.tensor(self.test_data.targets)
    test_data = [self.test_data[j] for j in range(len(self.test_data))]
    torch.save(test_data, client_data_path / "test_data.pth")

    if mode == 0:  # IID
        # Shuffle data
        data_ids = torch.randperm(self.num_train_data, dtype=torch.int32)
        num_data_per_client = self.num_train_data // num_clients
        if not isinstance(self.train_data.targets, torch.Tensor):
            self.train_data.targets = torch.tensor(self.train_data.targets)
        pbar = tqdm(range(num_clients), desc=f"{self.dataset_name} IID: ")
        for i in pbar:
            client_path = Path(client_data_path / str(i))
            client_path.mkdir()
            # TODO: Make this parallel for large number of clients & large datasets (Maybe not required)
            train_data = [
                self.train_data[j]
                for j in data_ids[i * num_data_per_client:(i + 1) * num_data_per_client]
            ]
            pbar.set_postfix({'# data / Client': num_data_per_client})
            if show_plots:
                self._plot(train_data, title=f"Client {i+1} Data Distribution")
            # Split data equally and send to the client
            torch.save(train_data, client_data_path / str(i) / "data.pth")

    elif mode == 1:  # Non-IID balanced
        num_data_per_client = self.num_train_data // num_clients
        class_sampler = Dirichlet(
            torch.empty(self.num_classes).fill_(kwargs.get('dir_alpha')))
        if not isinstance(self.train_data.targets, torch.Tensor):
            self.train_data.targets = torch.tensor(self.train_data.targets)
        assigned_ids = []
        pbar = tqdm(range(num_clients), desc=f"{self.dataset_name} Non-IID Balanced: ")
        for i in pbar:
            client_path = Path(client_data_path / str(i))
            client_path.mkdir()
            # Compute class prior probabilities for each client:
            # p_ij is the share of the jth class for the ith client (always sums to 1)
            p_ij = class_sampler.sample()
            weights = torch.zeros(self.num_train_data)
            for c_id in range(self.num_classes):
                weights[self.train_data.targets == c_id] = p_ij[c_id]
            # Zero out previously assigned data so they are not sampled again
            weights[assigned_ids] = 0.0
            # Sample data points without replacement, each weighted by the
            # probability assigned to its class
            data_ids = torch.multinomial(weights, num_data_per_client, replacement=False)
            train_data = [self.train_data[j] for j in data_ids]
            pbar.set_postfix({'# data / Client': len(train_data)})
            assigned_ids += data_ids.tolist()
            torch.save(train_data, client_data_path / str(i) / "data.pth")
            if show_plots:
                self._plot(train_data, title=f"Client {i+1} Data Distribution")

    elif mode == 2:  # Non-IID unbalanced
        num_data_per_client = self.num_train_data // num_clients
        num_data_per_class = self.num_train_data / (self.num_classes * num_clients)
        class_sampler = Dirichlet(
            torch.empty(self.num_classes).fill_(kwargs.get('dir_alpha')))
        assigned_ids = []
        pbar = tqdm(range(num_clients), desc=f"{self.dataset_name} Non-IID Unbalanced: ")
        if not isinstance(self.train_data.targets, torch.Tensor):
            self.train_data.targets = torch.tensor(self.train_data.targets)
        for i in pbar:
            train_data = []
            client_path = Path(client_data_path / str(i))
            client_path.mkdir()
            # Compute class prior probabilities for each client:
            # p_ij is the share of the jth class for the ith client (always sums to 1)
            p_ij = class_sampler.sample()
            c_sampler = Categorical(p_ij)
            data_sampler = LogNormal(
                torch.tensor(num_data_per_class).log(), kwargs.get('lognorm_std'))
            while True:
                num_data_left = num_data_per_client - len(train_data)
                c = c_sampler.sample()
                num_data_c = int(data_sampler.sample())
                data_ids = torch.nonzero(
                    self.train_data.targets == c.item()).flatten()
                # data_ids = [x for x in data_ids if x not in assigned_ids]  # Remove duplicated ids
                num_data_c = min(num_data_c, data_ids.shape[0])
                if num_data_c >= num_data_left:
                    train_data += [
                        self.train_data[j] for j in data_ids[:num_data_left]
                    ]
                    break
                else:
                    train_data += [
                        self.train_data[j] for j in data_ids[:num_data_c]
                    ]
                    assigned_ids += data_ids[:num_data_c].tolist()
            pbar.set_postfix({'# data / Client': len(train_data)})
            torch.save(train_data, client_data_path / str(i) / "data.pth")
            if show_plots:
                self._plot(train_data, title=f"Client {i+1} Data Distribution")
    else:
        raise ValueError("Unknown mode. Mode must be one of {0, 1, 2}.")
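# A self-contained sketch (not from the source) of the mode=1 partition idea
# used above: draw per-client class proportions from a Dirichlet prior, then
# sample data indices without replacement, weighting each point by the share
# of its class. The labels and alpha below are purely illustrative.
import torch
from torch.distributions import Dirichlet

targets = torch.randint(0, 10, (1000,))            # fake labels for 10 classes
p_ij = Dirichlet(torch.full((10,), 0.5)).sample()  # class shares for one client
weights = p_ij[targets]                            # per-sample weight = its class share
client_ids = torch.multinomial(weights, 100, replacement=False)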