def test_push_to_hub_in_organization(self):
    config = BertConfig(
        vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
    )
    model = FlaxBertModel(config)
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Saving with push_to_hub=True uploads the checkpoint under the "valid_org" organization.
        model.save_pretrained(
            os.path.join(tmp_dir, "test-model-flax-org"),
            push_to_hub=True,
            use_auth_token=self._token,
            organization="valid_org",
        )

        new_model = FlaxBertModel.from_pretrained("valid_org/test-model-flax-org")

        # The round-tripped parameters should match the originals up to numerical noise.
        base_params = flatten_dict(unfreeze(model.params))
        new_params = flatten_dict(unfreeze(new_model.params))

        for key in base_params.keys():
            max_diff = (base_params[key] - new_params[key]).sum().item()
            self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
def get_vision_text_model(self, vision_config, text_config):
    vision_model = FlaxCLIPVisionModel(vision_config)
    text_model = FlaxBertModel(text_config)
    return vision_model, text_model
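# A hedged sketch of how a test might exercise the (vision_model, text_model) pair built by the
# helper above, assuming the VisionTextDualEncoderConfig / FlaxVisionTextDualEncoderModel APIs;
# the method name and its tensor arguments are illustrative, not taken from the original file.
from transformers import FlaxVisionTextDualEncoderModel, VisionTextDualEncoderConfig

def check_dual_encoder_from_configs(self, vision_config, text_config, pixel_values, input_ids, attention_mask):
    # Merge the two sub-configs into one dual-encoder config and run a forward pass.
    config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)
    model = FlaxVisionTextDualEncoderModel(config)
    outputs = model(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask)

    # Both modalities are projected into a shared embedding space of size projection_dim.
    self.assertEqual(outputs["text_embeds"].shape, (input_ids.shape[0], config.projection_dim))
    self.assertEqual(outputs["image_embeds"].shape, (pixel_values.shape[0], config.projection_dim))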
def get_encoder_decoder_model(self, config, decoder_config):
    encoder_model = FlaxBertModel(config)
    decoder_model = FlaxGPT2LMHeadModel(decoder_config)
    return encoder_model, decoder_model
def get_encoder_decoder_model(self, config, decoder_config):
    encoder_model = FlaxBertModel(config)
    decoder_model = FlaxBartForCausalLM(decoder_config)
    return encoder_model, decoder_model
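# A hedged sketch of how either get_encoder_decoder_model helper above is typically used,
# assuming EncoderDecoderConfig.from_encoder_decoder_configs and FlaxEncoderDecoderModel;
# the check method name and its tensor arguments are illustrative, not from the original file.
from transformers import EncoderDecoderConfig, FlaxEncoderDecoderModel

def check_encoder_decoder_model_from_configs(
    self, config, decoder_config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask
):
    # Combine the encoder and decoder configs into a single encoder-decoder config.
    encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
    self.assertTrue(encoder_decoder_config.decoder.is_decoder)

    enc_dec_model = FlaxEncoderDecoderModel(encoder_decoder_config)
    outputs = enc_dec_model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
    )

    # Logits should cover the decoder vocabulary for every decoder position.
    self.assertEqual(
        outputs.logits.shape, (decoder_input_ids.shape[0], decoder_input_ids.shape[1], decoder_config.vocab_size)
    )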