def create_model(config):
  """Create a pre-training model, starting from a checkpoint if given.

  Args:
    config: Configuration object; reads `config.model` (model hyperparams,
      including `hidden_size` and `num_attention_heads`),
      `config.init_checkpoint` (path or empty/None), `config.max_seq_length`,
      and `config.max_predictions_per_seq`.

  Returns:
    An `nn.Model` wrapping `modeling.BertForPreTraining`, with parameters
    either restored from `config.init_checkpoint` or freshly initialized.
  """
  model_kwargs = dict(config=config.model)
  model_def = modeling.BertForPreTraining.partial(**model_kwargs)
  if config.init_checkpoint:
    # Restore weights from the checkpoint, keeping the masked-LM head
    # since it is needed for the pre-training objective.
    initial_params = import_weights.load_params(
        init_checkpoint=config.init_checkpoint,
        hidden_size=config.model.hidden_size,
        num_attention_heads=config.model.num_attention_heads,
        keep_masked_lm_head=True)
  else:
    with nn.stochastic(jax.random.PRNGKey(0)):
      # Shape-only init: inputs are (input_ids, input_mask, segment_ids,
      # masked_lm_positions) -- presumably; confirm against the model's
      # apply signature.
      _, initial_params = model_def.init_by_shape(
          jax.random.PRNGKey(0),
          [((1, config.max_seq_length), jnp.int32),
           ((1, config.max_seq_length), jnp.int32),
           ((1, config.max_seq_length), jnp.int32),
           ((1, config.max_predictions_per_seq), jnp.int32)],
          deterministic=True)

      # Counter shared by all fixup_for_tpu calls. Replaces the original
      # mutable-default-argument hack (`i=[0]`) with an explicit closure
      # variable; each ShapeDtypeStruct leaf still gets a distinct key.
      leaf_count = 0

      def fixup_for_tpu(x):
        """HACK to fix incorrect param initialization on TPU."""
        nonlocal leaf_count
        if isinstance(x, jax.ShapeDtypeStruct):
          # Advance the counter for every abstract leaf, matching the
          # original key sequence exactly (it incremented before branching).
          leaf_count += 1
          if len(x.shape) == 2:
            # 2-D params (embeddings/kernels handled elsewhere) are zeroed.
            return jnp.zeros(x.shape, x.dtype)
          else:
            # Re-draw the leaf with the default kernel initializer and a
            # per-leaf PRNG key.
            return nn.linear.default_kernel_init(
                jax.random.PRNGKey(leaf_count), x.shape, x.dtype)
        else:
          # Concrete arrays pass through untouched.
          return x

      initial_params = jax.tree_map(fixup_for_tpu, initial_params)
  model = nn.Model(model_def, initial_params)
  return model
def create_model(config, num_classes=2):
  """Create a model, starting with a pre-trained checkpoint."""
  classifier_def = modeling.BertForSequenceClassification.partial(
      config=config.model,
      n_classes=num_classes,
  )
  if config.init_checkpoint:
    # Restore pre-trained weights, adding a classification head sized
    # for `num_classes`.
    params = import_weights.load_params(
        init_checkpoint=config.init_checkpoint,
        hidden_size=config.model.hidden_size,
        num_attention_heads=config.model.num_attention_heads,
        num_classes=num_classes)
  else:
    # No checkpoint: initialize from scratch using shape-only inference.
    token_input = ((1, config.max_seq_length), jnp.int32)
    with nn.stochastic(jax.random.PRNGKey(0)):
      _, params = classifier_def.init_by_shape(
          jax.random.PRNGKey(0),
          [token_input, token_input, token_input, ((1, 1), jnp.int32)],
          deterministic=True)
  return nn.Model(classifier_def, params)
def import_pretrained_params(config):
  """Load pre-trained parameters described by `config` from a checkpoint."""
  model_cfg = config.model
  return import_weights.load_params(
      init_checkpoint=config.init_checkpoint,
      hidden_size=model_cfg.hidden_size,
      num_attention_heads=model_cfg.num_attention_heads,
      num_classes=config.num_classes)