def test_fit_on_two_layer_perceptron_is_correct_when_used_with_a_loss(self):
        """MAML.fit with an MSE loss should yield the hand-computed kernels.

        Wraps the two-layer perceptron from ``create_2l_perceptron`` in a
        ``MAML`` instance using ``tf.keras.losses.MSE``, fits it on a single
        sample (x = 1, y = 0) and checks that every layer kernel of the
        returned model equals -3.
        """
        # NOTE(review): the expected -3 values presumably follow from the
        # initial weights set by create_2l_perceptron and the default
        # ``alpha`` of MAML.fit -- TODO confirm against those definitions.
        # Given
        model = create_2l_perceptron()

        maml = MAML(model, loss=tf.keras.losses.MSE)

        data_x = np.ones((1,))
        data_y = np.zeros((1, 1))

        expected_weights = [
            -3 * np.ones((1, 2)),
            -3 * np.ones((2, 1)),
        ]

        # When
        eval_model = maml.fit(data_x, data_y)
        weight_set = [layer.kernel for layer in eval_model.layers]

        # Then
        # Check the layer count first so a model with fewer layers than
        # expected cannot pass the per-layer loop vacuously.
        self.assertEqual(len(weight_set), len(expected_weights))
        for weights, expected in zip(weight_set, expected_weights):
            # assert_allclose checks shapes, reports mismatching entries on
            # failure, and tolerates float32 rounding in computed kernels.
            np.testing.assert_allclose(weights, expected, rtol=1e-6)

    def test_fit_on_two_layer_perceptron_is_correct_when_passed_a_value_for_alpha(self):
        """MAML.fit with an explicit ``alpha`` should yield the expected kernels.

        Wraps the two-layer perceptron from ``create_2l_perceptron`` in a
        ``MAML`` instance whose loss simply returns the prediction (the
        targets are ignored), fits it on a single sample (x = 1, y = 0) with
        ``alpha=0.5``, and checks that every layer kernel of the returned
        model equals 0.5.
        """
        # NOTE(review): the expected 0.5 values presumably follow from the
        # initial weights set by create_2l_perceptron combined with
        # alpha=0.5 -- TODO confirm against those definitions.
        # Given
        model = create_2l_perceptron()

        maml = MAML(model, loss=lambda y, p: p)

        data_x = np.ones((1,))
        data_y = np.zeros((1, 1))

        expected_weights = [
            0.5 * np.ones((1, 2)),
            0.5 * np.ones((2, 1)),
        ]

        # When
        eval_model = maml.fit(data_x, data_y, alpha=0.5)
        weight_set = [layer.kernel for layer in eval_model.layers]

        # Then
        # Check the layer count first so a model with fewer layers than
        # expected cannot pass the per-layer loop vacuously.
        self.assertEqual(len(weight_set), len(expected_weights))
        for weights, expected in zip(weight_set, expected_weights):
            # assert_allclose checks shapes, reports mismatching entries on
            # failure, and tolerates float32 rounding in computed kernels.
            np.testing.assert_allclose(weights, expected, rtol=1e-6)