def test_swap_is_not_persisted_in_class(self):
        """
        Check that component swaps stay on the wrapped class.

        Builds a wrapped generator class whose encoder layer uses a custom
        feedforward, then verifies that:

        * an instance of the wrapped class reports the custom feedforward;
        * an instance of the plain ``TransformerGeneratorModel`` has different
          swappables and an encoder that subclasses ``TransformerEncoder``;
        * after ``swap_components`` is called on the wrapped class with
          ``TransformerFFN``, a new instance reports ``TransformerFFN``.
        """
        opt = self._opt()
        dictionary = DictionaryAgent(opt)

        CustomFFN = type('CustomFFN', (TransformerFFN, ), {})

        # Assemble the customised class from the innermost component outwards.
        custom_layer = TransformerEncoderLayer.with_components(
            feedforward=CustomFFN)
        custom_encoder = TransformerEncoder.with_components(layer=custom_layer)
        wrapped_class = TransformerGeneratorModel.with_components(
            encoder=custom_encoder)

        custom_model = wrapped_class(opt=opt, dictionary=dictionary)
        custom_model_layer = (
            custom_model.swappables.encoder.swappables.layer)  # type: ignore
        assert custom_model_layer.swappables.feedforward == CustomFFN

        # A model built from the unwrapped class must not see the custom swap.
        plain_model = TransformerGeneratorModel(opt, dictionary)
        assert plain_model.swappables != custom_model.swappables
        plain_encoder = plain_model.swappables.encoder  # type: ignore
        assert issubclass(plain_encoder, TransformerEncoder)

        # Swap the feedforward back on the wrapped class and build again.
        restored_layer = TransformerEncoderLayer.with_components(
            feedforward=TransformerFFN)
        restored_encoder = TransformerEncoder.with_components(
            layer=restored_layer)
        wrapped_class.swap_components(encoder=restored_encoder)

        reswapped_model = wrapped_class(opt=opt, dictionary=dictionary)
        reswapped_layer = (
            reswapped_model.swappables.encoder.swappables.layer)  # type: ignore
        assert reswapped_layer.swappables.feedforward == TransformerFFN
 def build_model(self, states=None):
     """
     Build a generator model with every swappable component given explicitly.

     The encoder layer is given ``MultiHeadAttention`` self-attention and a
     ``TransformerFFN`` feedforward; the decoder layer is given the same plus
     ``MultiHeadAttention`` encoder-attention.

     :param states: not used by this implementation.
     :return: an instance of the assembled class, constructed with
         ``self.opt`` and ``self.dict``.
     """
     encoder_layer_cls = TransformerEncoderLayer.with_components(
         self_attention=MultiHeadAttention,
         feedforward=TransformerFFN,
     )
     encoder_cls = TransformerEncoder.with_components(layer=encoder_layer_cls)

     decoder_layer_cls = TransformerDecoderLayer.with_components(
         encoder_attention=MultiHeadAttention,
         self_attention=MultiHeadAttention,
         feedforward=TransformerFFN,
     )
     decoder_cls = TransformerDecoder.with_components(layer=decoder_layer_cls)

     model_cls = TransformerGeneratorModel.with_components(
         encoder=encoder_cls,
         decoder=decoder_cls,
     )
     return model_cls(opt=self.opt, dictionary=self.dict)
 def test_swap_encoder_attention(self):
     """
     Check that a swapped-in encoder feedforward is used by the forward pass.

     Despite the test's name, the component replaced here is the encoder
     layer's ``feedforward``: a ``TransformerFFN`` subclass whose ``forward``
     is a ``MagicMock``. The mock must be untouched after the model is
     constructed and must have been invoked once the model has been run.
     """
     CustomFFN = type('CustomFFN', (TransformerFFN, ), {})
     CustomFFN.forward = MagicMock()
     wrapped_class = TransformerGeneratorModel.with_components(
         encoder=TransformerEncoder.with_components(
             layer=TransformerEncoderLayer.with_components(
                 feedforward=CustomFFN)))
     opt = self._opt()
     model = wrapped_class(opt=opt, dictionary=DictionaryAgent(opt))
     assert isinstance(model, TransformerGeneratorModel)  # type: ignore

     # The mock's assert_* helpers are methods and must be *called*; a bare
     # attribute access such as ``forward.assert_not_called`` checks nothing.
     # Building the model must not have run the feedforward.
     CustomFFN.forward.assert_not_called()

     try:
         model(torch.zeros(1, 1).long(),
               ys=torch.zeros(1, 1).long())  # type: ignore
     except TypeError:
         # Deliberately tolerated. Presumably raised downstream because the
         # mocked forward returns a MagicMock instead of a tensor -- TODO
         # confirm against the model's forward implementation.
         pass

     # Checked outside ``finally`` so that an unexpected exception from the
     # forward pass surfaces directly instead of alongside a secondary
     # AssertionError.
     CustomFFN.forward.assert_called()