Example no. 1
0
    def __call__(self, module, x, y):
        """
        Record the most recent dendrite activations and the corresponding
        winning masks, retaining at most `self.num_samples` samples.
        (NOTE(review): docstring previously referred to 'max_samples_to_track';
        the code reads `self.num_samples` — confirm the attribute name.)

        :param x: input to an `apply_dendrites` module; a tuple of
                  (y, dendrite_activations)
        :param y: dendrite_output named tuple (with values and indices)
        """

        if not self._tracking:
            return

        # Detach: MetaCL deep-copies the model, and deepcopy is not allowed on
        # non-leaf tensors; detaching guarantees a leaf tensor.
        activations = x[1].detach()
        mask = indices_to_mask(y.indices, shape=x[1].shape, dim=2)

        # The stored tensors were initialized on the CPU, which may differ from
        # the device used during the forward pass — move them before concatenating.
        target_device = mask.device

        # Prepend the newest samples along the batch dimension.
        self.winning_mask = torch.cat(
            (mask, self.winning_mask.to(target_device)), dim=0)
        self.dendrite_activations = torch.cat(
            (activations, self.dendrite_activations.to(target_device)), dim=0)

        # Newest entries were prepended, so truncating from the front keeps
        # only the most recent `num_samples`.
        self.winning_mask = self.winning_mask[:self.num_samples, ...]
        self.dendrite_activations = self.dendrite_activations[:self.num_samples, ...]
    def test_indices_to_mask_dim_2(self):

        # The argmax position along dim=2 of each (batch, row) pair should be
        # the only True entry in the mask.
        expected = torch.tensor([
            [[0, 1, 0, 0],
             [0, 1, 0, 0],
             [0, 0, 0, 1]],
            [[0, 0, 1, 0],
             [0, 0, 0, 1],
             [0, 0, 0, 1]],
        ])
        winner_indices = self.tensor.max(dim=2).indices
        mask = indices_to_mask(winner_indices, self.tensor.shape, dim=2)

        # The mask must be boolean ...
        self.assertTrue(mask.dtype == torch.bool)

        # ... and True exactly at the expected positions.
        self.assertTrue((mask == expected).all())
    def update_duty_cycles(self, indices):
        """
        Update the moving average of winning segments.

        :param indices: indices of winning segments; shape batch_size x num_units
        """

        batch_size = indices.shape[0]

        # One-hot mask of the winning segment for each unit in each batch
        # element; shape batch_size x num_units x num_segments. Summing over
        # the batch dimension yields per-segment win counts.
        mask_shape = (batch_size, self.num_units, self.num_segments)
        win_counts = indices_to_mask(indices, mask_shape, dim=2).sum(dim=0)

        # In-place moving-average update:
        #   duty = (duty * (period - batch_size) + win_counts) / period
        # where the period is capped by the total iterations seen so far.
        self.learning_iterations += batch_size
        period = min(self.duty_cycle_period, self.learning_iterations)
        self.duty_cycles.mul_(period - batch_size)
        self.duty_cycles.add_(win_counts)
        self.duty_cycles.div_(period)