Beispiel #1
0
    def test_average_reweighting_by_pred_confidence(self):
        """Test average reweighting by pred confidence (zeros in ind head).

        Both slices have all-zero indicator heads and prediction heads of equal
        magnitude (5). The combined representation is therefore expected to be
        the plain average of the two transforms, (4 + 2) / 2 == 3. Flipping the
        sign of one prediction head must not change that result.
        """
        batch_size = 4
        h_dim = 20
        num_classes = 2
        expected = torch.ones(batch_size, h_dim) * 3

        def make_outputs(a_pred_scale):
            # Build fresh tensors on every call so the two forward passes below
            # share no inputs; only slice "a"'s prediction head differs.
            return {
                "task_slice:a_ind_head": torch.zeros(batch_size, 2),
                "task_slice:a_pred_transform":
                torch.ones(batch_size, h_dim) * 4,
                "task_slice:a_pred_head":
                torch.ones(batch_size, num_classes) * a_pred_scale,
                "task_slice:b_ind_head": torch.zeros(batch_size, 2),
                "task_slice:b_pred_transform":
                torch.ones(batch_size, h_dim) * 2,
                "task_slice:b_pred_head":
                torch.ones(batch_size, num_classes) * 5,
            }

        combiner_module = SliceCombinerModule()
        combined_rep = combiner_module(make_outputs(5))
        # The combined representation is a floating-point reweighting of the
        # transforms, so compare with a tolerance instead of exact equality.
        self.assertTrue(torch.allclose(combined_rep, expected))

        # Changing sign of prediction shouldn't change confidence
        combiner_module = SliceCombinerModule()
        combined_rep = combiner_module(make_outputs(-5))
        self.assertTrue(torch.allclose(combined_rep, expected))
Beispiel #2
0
    def test_forward_shape(self):
        """Test that the reweight representation shape matches expected feature size.

        Covers a common case, a degenerate case (batch_size=1, h_dim=1), and
        num_classes=1. The last one must raise NotImplementedError.
        """

        def make_outputs(batch_size, h_dim, num_classes):
            # Random values in [0, 1); only the shapes matter for this test.
            return {
                # Always 2-dim binary outputs
                "task_slice:base_ind_head": torch.rand(batch_size, 2),
                "task_slice:base_pred_transform": torch.rand(batch_size, h_dim),
                "task_slice:base_pred_head": torch.rand(batch_size, num_classes),
            }

        # Common case, then edge case (some ones); both use binary pred heads.
        for batch_size, h_dim in [(4, 20), (1, 1)]:
            combiner_module = SliceCombinerModule()
            combined_rep = combiner_module(
                make_outputs(batch_size, h_dim, num_classes=2))
            self.assertEqual(tuple(combined_rep.shape), (batch_size, h_dim))

        # num_classes = 1 is rejected by the combiner.
        combiner_module = SliceCombinerModule()
        with self.assertRaisesRegex(NotImplementedError,
                                    "requires output shape"):
            combiner_module(make_outputs(1, 1, num_classes=1))
Beispiel #3
0
    def test_temperature(self):
        """Test temperature parameter for attention weights.

        Slices "a" and "b" receive opposite-sign indicator and prediction
        logits (+10 vs -10) plus small Gaussian noise.

        - With a very large temperature, the attention should be smooth, so the
          combined representation stays close to the plain average of the
          transforms (3).
        - With a vanishingly small temperature, the noise should tip the
          attention towards one slice, so every element should land on one of
          the two original transform values (4 or 2).
        """
        batch_size = 4
        h_dim = 20
        num_classes = 2

        # Add noise to each set of inputs to attention weights
        epsilon = 1e-5

        def noisy(value, *shape):
            # Constant tensor of `value` plus N(0, epsilon) noise.
            return torch.ones(*shape) * value + torch.randn(*shape) * epsilon

        outputs = {
            "task_slice:a_ind_head": noisy(10.0, batch_size, 2),
            "task_slice:a_pred_transform": noisy(4, batch_size, h_dim),
            "task_slice:a_pred_head": noisy(10.0, batch_size, num_classes),
            "task_slice:b_ind_head": noisy(-10.0, batch_size, 2),
            "task_slice:b_pred_transform": noisy(2, batch_size, h_dim),
            "task_slice:b_pred_head": noisy(-10.0, batch_size, num_classes),
        }

        # With a larger temperature, attention is smoother and the reweighted
        # rep stays close to the true average, despite the noise.
        combiner_module = SliceCombinerModule(temperature=1e5)
        combined_rep = combiner_module(outputs)
        self.assertTrue(
            torch.allclose(combined_rep,
                           torch.ones(batch_size, h_dim) * 3))

        # With an (impractically) small temperature, attention peaks towards
        # whichever slice the noise happens to favour.
        combiner_module = SliceCombinerModule(temperature=1e-15)
        combined_rep = combiner_module(outputs)

        # Every element should match one of the original transforms: 2 or 4.
        isclose_two = torch.isclose(combined_rep,
                                    torch.ones(batch_size, h_dim) * 2,
                                    atol=1e-4)
        isclose_four = torch.isclose(combined_rep,
                                     torch.ones(batch_size, h_dim) * 4,
                                     atol=1e-4)
        num_matching_original = (isclose_two | isclose_four).sum().item()
        self.assertEqual(num_matching_original, batch_size * h_dim)
Beispiel #4
0
    def test_many_slices(self):
        """Combine 100 synthetic slices whose votes split evenly two ways.

        Even-numbered slices emit logits of +20 and a transform filled with 4s.
        Odd-numbered slices emit logits of -20 and a transform filled with 2s.
        Because the two groups are balanced, the combined representation is
        expected to equal their average, 3.
        """
        batch_size, h_dim, num_classes = 4, 20, 2

        outputs = {}
        for slice_idx in range(100):
            is_even = slice_idx % 2 == 0
            logit = 20.0 if is_even else -20.0
            transform_value = 4 if is_even else 2

            outputs[f"task_slice:{slice_idx}_ind_head"] = (
                torch.ones(batch_size, 2) * logit)
            outputs[f"task_slice:{slice_idx}_pred_transform"] = (
                torch.ones(batch_size, h_dim) * transform_value)
            outputs[f"task_slice:{slice_idx}_pred_head"] = (
                torch.ones(batch_size, num_classes) * logit)

        expected = torch.full((batch_size, h_dim), 3.0)
        combined_rep = SliceCombinerModule()(outputs)
        self.assertTrue(torch.allclose(combined_rep, expected))
Beispiel #5
0
    def test_combiner_multiclass(self):
        """Test that the combiner rejects prediction heads with more than 2 classes."""
        batch_size = 4
        h_dim = 20
        num_classes = 10

        def make_pred_head():
            # "Logits" uniform in [-5, 5); for each example, one class chosen
            # at random is set to the maximum score, 10.0.
            # torch.randint's upper bound is exclusive, so every index stays in
            # [0, num_classes). The previous random.randint(0, num_classes)
            # includes num_classes, which made scatter_ fail with an
            # out-of-bounds index before the combiner was ever called.
            max_score_indexes = torch.randint(0, num_classes, (batch_size, 1))
            pred_head = torch.empty(batch_size, num_classes).uniform_(-5, 5)
            return pred_head.scatter_(1, max_score_indexes, 10.0)

        outputs = {
            "task_slice:a_ind_head": torch.ones(batch_size, 2) * -10.0,
            "task_slice:a_pred_transform": torch.ones(batch_size, h_dim) * 4,
            "task_slice:a_pred_head": make_pred_head(),
            "task_slice:b_ind_head": torch.ones(batch_size, 2) * 10.0,
            "task_slice:b_pred_transform": torch.ones(batch_size, h_dim) * 2,
            "task_slice:b_pred_head": make_pred_head(),
        }

        combiner_module = SliceCombinerModule()
        with self.assertRaisesRegex(NotImplementedError,
                                    "more than 2 classes"):
            combiner_module(outputs)
Beispiel #6
0
    def test_average_reweighting_by_ind(self):
        """Indicator heads alone should produce an even split when pred heads are zero.

        Slice "a" has indicator logits of +10 and slice "b" has -10. Both
        prediction heads are all zeros. The combined representation is expected
        to be the midpoint (3) of the two transforms (4 and 2).
        """
        batch_size, h_dim, num_classes = 4, 20, 2

        slice_a = {
            "task_slice:a_ind_head": torch.ones(batch_size, 2) * 10.0,
            "task_slice:a_pred_transform": torch.ones(batch_size, h_dim) * 4,
            "task_slice:a_pred_head": torch.zeros(batch_size, num_classes),
        }
        slice_b = {
            "task_slice:b_ind_head": torch.ones(batch_size, 2) * -10.0,
            "task_slice:b_pred_transform": torch.ones(batch_size, h_dim) * 2,
            "task_slice:b_pred_head": torch.zeros(batch_size, num_classes),
        }
        outputs = {**slice_a, **slice_b}

        expected = torch.full((batch_size, h_dim), 3.0)
        combined_rep = SliceCombinerModule()(outputs)
        self.assertTrue(torch.allclose(combined_rep, expected))