Example #1
def test_dirichlet_shape(self):
    dist = Dirichlet(torch.Tensor([[0.6, 0.3], [1.6, 1.3], [2.6, 2.3]]))
    self.assertEqual(dist._batch_shape, torch.Size((3,)))
    self.assertEqual(dist._event_shape, torch.Size((2,)))
    self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
    self.assertEqual(dist.sample((5, 4)).size(), torch.Size((5, 4, 3, 2)))
    self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3,)))
    self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
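The tensor_sample_1 / tensor_sample_2 fixtures are defined elsewhere in the test harness. A minimal sketch of stand-ins that would reproduce the asserted behavior (the values here are assumptions; only the shapes matter):

import torch
from torch.distributions import Dirichlet

# Hypothetical fixtures: shape (3, 2) matches the batch shape (3,) plus
# event shape (2,), while a trailing dimension of 3 conflicts with the
# event size of 2 and triggers a ValueError.
tensor_sample_1 = torch.full((3, 2), 0.5)     # rows on the simplex; log_prob -> shape (3,)
tensor_sample_2 = torch.full((3, 2, 3), 1/3)  # wrong event size; log_prob -> ValueError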
Example #2
    def select_action(self, obs):
        concentration, value = self.forward(obs)

        m = Dirichlet(concentration)

        action = m.sample()
        self.saved_actions.append(SavedAction(m.log_prob(action), value))
        return list(action.cpu().numpy())
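SavedAction is not defined in this snippet; in the classic PyTorch actor-critic example it is a plain namedtuple, assumed here:

from collections import namedtuple

# Assumed definition, mirroring the pattern in PyTorch's actor-critic example
SavedAction = namedtuple('SavedAction', ['log_prob', 'value'])

Note that m.sample() is not reparameterized, so the saved log_prob is what carries the policy gradient in a REINFORCE-style update.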
Example #3
File: test_dmm.py Project: rloganiv/dmm
def test_forward(self):
    source_dirichlet = Dirichlet(torch.ones(10))
    batch_size = 12
    observed_data = source_dirichlet.sample((batch_size,))
    # Check that forward function returns
    _, membership_probs = self.dmm(observed_data)
    # Ensure membership probabilities sum to one
    for prob in membership_probs.sum(dim=1):
        self.assertAlmostEqual(prob.item(), 1, places=5)
Example #4
def test_dirichlet_log_prob(self):
    num_samples = 10
    alpha = torch.exp(torch.randn(5))
    dist = Dirichlet(alpha)
    x = dist.sample((num_samples,))
    actual_log_prob = dist.log_prob(x)
    for i in range(num_samples):
        expected_log_prob = scipy.stats.dirichlet.logpdf(x[i].numpy(), alpha.numpy())
        self.assertAlmostEqual(actual_log_prob[i].item(), expected_log_prob, places=3)
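For reference, the density both libraries compute is log p(x) = log Gamma(sum alpha_i) - sum log Gamma(alpha_i) + sum (alpha_i - 1) log x_i. A minimal sketch of the same quantity written directly with torch.lgamma (the function name is illustrative):

import torch

def dirichlet_logpdf(x, alpha):
    # log Gamma(sum a) - sum log Gamma(a_i) + sum (a_i - 1) * log x_i
    return (torch.lgamma(alpha.sum())
            - torch.lgamma(alpha).sum()
            + ((alpha - 1) * x.log()).sum())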
Example #5
File: test_dmm.py Project: rloganiv/dmm
def test_backward(self):
    source_dirichlet = Dirichlet(torch.ones(10))
    batch_size = 12
    observed_data = source_dirichlet.sample((batch_size,))
    # Obtain loss
    nll, _ = self.dmm(observed_data)
    loss = nll.sum()
    # Check that gradients are populated for all params
    loss.backward()
    for param in self.dmm.parameters():
        self.assertIsNotNone(param.grad)
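The original comment claims the gradient is "non-zero", but assertIsNotNone only verifies that a gradient was populated. A stricter (hypothetical) variant of the loop would also check magnitude:

for param in self.dmm.parameters():
    self.assertIsNotNone(param.grad)
    self.assertGreater(param.grad.abs().sum().item(), 0.0)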
Example #6
File: utils.py Project: tonyduan/rs4a
def sample_l1_sphere(device, shape):
    '''Sample uniformly from the unit l1 sphere, i.e. the cross polytope.
    Inputs:
        device: 'cpu' | 'cuda' | other torch devices
        shape: a pair (batchsize, dim)
    Outputs:
        matrix of shape `shape` such that each row is a sample.
    '''
    batchsize, dim = shape
    dirdist = Dirichlet(concentration=torch.ones(dim, device=device))
    noises = dirdist.sample([batchsize])
    signs = torch.sign(torch.rand_like(noises) - 0.5)
    return noises * signs
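A quick sanity check, assuming sample_l1_sphere from the snippet above is in scope: every returned row should have unit L1 norm, since the Dirichlet(1, ..., 1) draw is uniform on the simplex and the random signs only move it between orthants.

import torch

samples = sample_l1_sphere('cpu', (4, 3))
print(samples.abs().sum(dim=1))  # each entry should be ~1.0 up to float error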
Example #7
    def backward(ctx, grad_output):
        grad_input = None

        if not ctx.train:
            raise RuntimeError('Running backward on shake when train is False')

        if ctx.needs_input_grad[0]:
            dist = Dirichlet(th.full((ctx.xsh[ctx.dim],), ctx.concentration))
            beta = dist.sample(sample_shape=th.Size([ctx.xsh[ctx.batchdim]]))
            beta = beta.to(th.device(ctx.dev))
            sh = [1 for _ in range(len(ctx.xsh))]
            sh[ctx.batchdim], sh[ctx.dim] = ctx.xsh[ctx.batchdim], ctx.xsh[ctx.dim]
            beta = beta.view(*sh)
            grad_output = grad_output.unsqueeze(ctx.dim).expand(*ctx.xsh)

            grad_input = grad_output * beta

        return grad_input, None, None, None, None, None
Example #8
    def forward(ctx, x, dim, batchdim, concentration, train):
        ctx.dim = dim
        ctx.batchdim = batchdim
        ctx.concentration = concentration
        xsh, ctx.xsh = [x.shape]*2
        ctx.dev = x.device
        ctx.train = train  # store the actual flag so backward can detect eval-mode calls

        if train:
            # Randomly sample from Dirichlet distribution
            dist = Dirichlet(th.full((xsh[dim],), concentration))
            alpha = dist.sample(sample_shape=th.Size([xsh[batchdim]]))
            alpha = alpha.to(th.device(x.device))
            sh = [1 for _ in range(len(xsh))]
            sh[batchdim], sh[dim] = xsh[batchdim], xsh[dim]
            alpha = alpha.view(*sh)

            y = (x * alpha).sum(dim)
        else:
            y = x.mean(dim)

        return y
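Examples #7 and #8 read like the backward and forward halves of one shake-style torch.autograd.Function: the forward pass mixes branches with Dirichlet weights, and the backward pass re-samples independent weights for the gradient. A self-contained sketch of the same pattern with a simplified signature (the class name and reduced argument list are assumptions, not the project's API):

import torch as th
from torch.distributions import Dirichlet

class Shake(th.autograd.Function):
    @staticmethod
    def forward(ctx, x, concentration):
        # Mix the branch dimension (dim 1) with random convex weights
        ctx.save_for_backward(x)
        ctx.concentration = concentration
        alpha = Dirichlet(th.full((x.shape[1],), concentration)).sample((x.shape[0],))
        return (x * alpha.view(x.shape[0], x.shape[1], *[1] * (x.dim() - 2))).sum(1)

    @staticmethod
    def backward(ctx, grad_output):
        # Re-sample independent weights for the gradient, as in shake-shake
        x, = ctx.saved_tensors
        beta = Dirichlet(th.full((x.shape[1],), ctx.concentration)).sample((x.shape[0],))
        beta = beta.view(x.shape[0], x.shape[1], *[1] * (x.dim() - 2))
        return grad_output.unsqueeze(1) * beta, None

x = th.randn(8, 4, 16, requires_grad=True)  # (batch, branch, features)
Shake.apply(x, 1.0).sum().backward()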
Example #9
    def make(self,
             mode: int,
             num_clients: int,
             show_plots: bool = False,
             **kwargs) -> None:

        if os.path.exists(self.root_dir / "client_data"):
            shutil.rmtree(self.root_dir / "client_data")
        client_data_path = Path(self.root_dir / "client_data")
        client_data_path.mkdir()

        if not isinstance(self.test_data.targets, torch.Tensor):
            self.test_data.targets = torch.tensor(self.test_data.targets)
        test_data = [self.test_data[j] for j in range(len(self.test_data))]
        torch.save(test_data, client_data_path / "test_data.pth")

        if mode == 0:  # IID
            # Shuffle data
            data_ids = torch.randperm(self.num_train_data, dtype=torch.int32)
            num_data_per_client = self.num_train_data // num_clients

            if not isinstance(self.train_data.targets, torch.Tensor):
                self.train_data.targets = torch.tensor(self.train_data.targets)

            pbar = tqdm(range(num_clients), desc=f"{self.dataset_name} IID: ")
            for i in pbar:
                client_path = Path(client_data_path / str(i))
                client_path.mkdir()

                # TODO: Make this parallel for large number of clients & large datasets (Maybe not required)
                train_data = [
                    self.train_data[j]
                    for j in data_ids[i * num_data_per_client:(i + 1) *
                                      num_data_per_client]
                ]

                pbar.set_postfix({'# data / Client': num_data_per_client})

                if show_plots:
                    self._plot(train_data,
                               title=f"Client {i+1} Data Distribution")

                # Split data equally and send to the client
                torch.save(train_data, client_data_path / str(i) / "data.pth")
        elif mode == 1:  # Non-IID Balanced
            num_data_per_client = self.num_train_data // num_clients
            class_sampler = Dirichlet(
                torch.empty(self.num_classes).fill_(kwargs.get('dir_alpha')))
            if not isinstance(self.train_data.targets, torch.Tensor):
                self.train_data.targets = torch.tensor(self.train_data.targets)

            assigned_ids = []
            pbar = tqdm(range(num_clients),
                        desc=f"{self.dataset_name} Non-IID Balanced: ")
            for i in pbar:

                client_path = Path(client_data_path / str(i))
                client_path.mkdir()
                # Compute class prior probabilities for each client
                # Share of the jth class for the ith client (always sums to 1)
                p_ij = class_sampler.sample()
                weights = torch.zeros(self.num_train_data)
                for c_id in range(self.num_classes):
                    weights[self.train_data.targets == c_id] = p_ij[c_id]
                # Zero out previously assigned data so they are not sampled again
                weights[assigned_ids] = 0.0

                # Sample each data point uniformly without replacement based on
                # the sampling probability assigned based on its class
                data_ids = torch.multinomial(weights,
                                             num_data_per_client,
                                             replacement=False)

                train_data = [self.train_data[j] for j in data_ids]
                # print(f"Client {i} has {len(train_data)} data points.")
                pbar.set_postfix({'# data / Client': len(train_data)})

                assigned_ids += data_ids.tolist()

                torch.save(train_data, client_data_path / str(i) / "data.pth")

                if show_plots:
                    self._plot(train_data,
                               title=f"Client {i+1} Data Distribution")
        elif mode == 2:  # Non-IID Unbalanced
            num_data_per_client = self.num_train_data // num_clients
            num_data_per_class = self.num_train_data / (self.num_classes *
                                                        num_clients)
            class_sampler = Dirichlet(
                torch.empty(self.num_classes).fill_(kwargs.get('dir_alpha')))

            assigned_ids = []
            pbar = tqdm(range(num_clients),
                        desc=f"{self.dataset_name} Non-IID Unbalanced: ")

            if not isinstance(self.train_data.targets, torch.Tensor):
                self.train_data.targets = torch.tensor(self.train_data.targets)

            for i in pbar:
                train_data = []
                client_path = Path(client_data_path / str(i))
                client_path.mkdir()
                # Compute class prior probabilities for each client
                # Share of the jth class for the ith client (always sums to 1)
                p_ij = class_sampler.sample()
                c_sampler = Categorical(p_ij)
                data_sampler = LogNormal(
                    torch.tensor(num_data_per_class).log(),
                    kwargs.get('lognorm_std'))

                while True:
                    num_data_left = num_data_per_client - len(train_data)
                    c = c_sampler.sample()
                    num_data_c = int(data_sampler.sample())
                    data_ids = torch.nonzero(
                        self.train_data.targets == c.item()).flatten()
                    # data_ids = [x for x in data_ids if x not in assigned_ids]  # Remove duplicated ids
                    num_data_c = min(num_data_c, data_ids.shape[0])
                    if num_data_c >= num_data_left:
                        train_data += [
                            self.train_data[j]
                            for j in data_ids[:num_data_left]
                        ]
                        break
                    else:
                        train_data += [
                            self.train_data[j] for j in data_ids[:num_data_c]
                        ]
                        assigned_ids += data_ids[:num_data_c].tolist()

                pbar.set_postfix({'# data / Client': len(train_data)})
                torch.save(train_data, client_data_path / str(i) / "data.pth")
                if show_plots:
                    self._plot(train_data,
                               title=f"Client {i+1} Data Distribution")

        else:
            raise ValueError("Unknown mode. Mode must be {0,1}")