Code example #1
    def init_weights(self,
                     rng: jax.random.PRNGKey,
                     input_shape: Tuple,
                     params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)
        position_ids = jnp.broadcast_to(
            jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        random_params = self.module.init(rngs,
                                         input_ids,
                                         attention_mask,
                                         position_ids,
                                         return_dict=False)["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params
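The merge above hinges on a flatten/unflatten round trip: both trees become flat dicts keyed by tuple paths, missing entries are copied over from the fresh random initialization, and the result is refrozen. A minimal self-contained sketch of that pattern, with toy trees standing in for real model parameters:

import jax.numpy as jnp
from flax.core.frozen_dict import freeze, unfreeze
from flax.traverse_util import flatten_dict, unflatten_dict

random_params = freeze({"encoder": {"kernel": jnp.ones((2, 2))},
                        "head": {"bias": jnp.zeros((2,))}})
params = freeze({"encoder": {"kernel": jnp.zeros((2, 2))}})  # "head" is missing

flat_random = flatten_dict(unfreeze(random_params))
flat_params = flatten_dict(unfreeze(params))
for missing_key in set(flat_random) - set(flat_params):  # {("head", "bias")}
    flat_params[missing_key] = flat_random[missing_key]  # fill from random init
merged = freeze(unflatten_dict(flat_params))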
Code example #2
    def test_save_load_to_base(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = FLAX_MODEL_MAPPING[config.__class__]

        for model_class in self.all_model_classes:
            if model_class == base_class:
                continue

            model = model_class(config)
            base_params_from_head = flatten_dict(
                unfreeze(model.params[model.base_model_prefix]))

            # check that all base model weights are loaded correctly
            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                base_model = base_class.from_pretrained(tmpdirname)

                base_params = flatten_dict(unfreeze(base_model.params))

                for key in base_params_from_head.keys():
                    max_diff = (base_params[key] -
                                base_params_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff,
                                         1e-3,
                                         msg=f"{key} not identical")
Code example #3
    def init_weights(self,
                     rng: jax.random.PRNGKey,
                     input_shape: Tuple,
                     params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        pixel_values = jnp.zeros(input_shape, dtype=self.dtype)

        params_rng, dropout_rng = jax.random.split(rng)
        dropout_rng, droppath_rng = jax.random.split(dropout_rng)
        rngs = {
            "params": params_rng,
            "dropout": dropout_rng,
            "droppath": droppath_rng
        }

        random_params = self.module.init(rngs, pixel_values,
                                         return_dict=False)["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params
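The only difference from the text-model variant in code example #1 is the extra "droppath" stream (used for stochastic depth). The key-splitting in isolation, as a runnable sketch:

import jax

rng = jax.random.PRNGKey(0)
params_rng, dropout_rng = jax.random.split(rng)
dropout_rng, droppath_rng = jax.random.split(dropout_rng)
# three independent streams derived from one seed
rngs = {"params": params_rng, "dropout": dropout_rng, "droppath": droppath_rng}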
Code example #4
    def test_save_load_to_base_pt(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = FLAX_MODEL_MAPPING[config.__class__]

        for model_class in self.all_model_classes:
            if model_class == base_class:
                continue

            model = model_class(config)
            base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))

            # convert Flax model to PyTorch model
            pt_model_class = getattr(transformers, model_class.__name__[4:])  # Skip the "Flax" at the beginning
            pt_model = pt_model_class(config).eval()
            pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)

            # check that all base model weights are loaded correctly
            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_model.save_pretrained(tmpdirname)
                base_model = base_class.from_pretrained(tmpdirname, from_pt=True)

                base_params = flatten_dict(unfreeze(base_model.params))

                for key in base_params_from_head.keys():
                    max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
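The `__name__[4:]` lookup relies on the naming convention that each Flax class name is the corresponding PyTorch class name with a "Flax" prefix. A trivial illustration (hypothetical class name):

flax_name = "FlaxBertModel"
pt_name = flax_name[4:]  # "BertModel", the attribute looked up on the transformers module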
Code example #5
    def init_weights(self,
                     rng: jax.random.PRNGKey,
                     input_shape: Tuple,
                     params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)

        batch_size, sequence_length = input_ids.shape
        position_ids = jnp.broadcast_to(
            jnp.arange(sequence_length)[None, :],
            (batch_size, sequence_length))

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        module_init_outputs = self.module.init(
            rngs,
            input_ids,
            attention_mask,
            position_ids,
            return_dict=False,
        )

        random_params = module_init_outputs["params"]
        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params
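The position-id construction above, shown standalone; it is equivalent to the jnp.atleast_2d form in code example #1:

import jax.numpy as jnp

batch_size, sequence_length = 2, 5
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :],
                                (batch_size, sequence_length))
print(position_ids)  # [[0 1 2 3 4], [0 1 2 3 4]]: one [0..seq_len) row per batch entry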
Code example #6
    def init_weights(self,
                     rng: jax.random.PRNGKey,
                     input_shape: Tuple,
                     params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        token_type_ids = jnp.zeros_like(input_ids)
        attention_mask = jnp.ones_like(input_ids)
        head_mask = jnp.ones(
            (self.config.num_hidden_layers, self.config.num_attention_heads))

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        random_params = self.module.init(rngs,
                                         input_ids,
                                         attention_mask,
                                         token_type_ids,
                                         head_mask,
                                         return_dict=False)["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params
Code example #7
    def test_push_to_hub_in_organization(self):
        config = BertConfig(vocab_size=99,
                            hidden_size=32,
                            num_hidden_layers=5,
                            num_attention_heads=4,
                            intermediate_size=37)
        model = FlaxBertModel(config)
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(
                os.path.join(tmp_dir, "test-model-flax-org"),
                push_to_hub=True,
                use_auth_token=self._token,
                organization="valid_org",
            )

            new_model = FlaxBertModel.from_pretrained(
                "valid_org/test-model-flax-org")

            base_params = flatten_dict(unfreeze(model.params))
            new_params = flatten_dict(unfreeze(new_model.params))

            for key in base_params.keys():
                max_diff = (base_params[key] - new_params[key]).sum().item()
                self.assertLessEqual(max_diff,
                                     1e-3,
                                     msg=f"{key} not identical")
Code example #8
    def init_weights(self,
                     rng: jax.random.PRNGKey,
                     input_shape: Tuple,
                     params: FrozenDict = None) -> FrozenDict:
        encoder_input_shape, decoder_input_shape = input_shape

        # init input tensors
        input_ids = jnp.zeros(encoder_input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)
        decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4")
        decoder_attention_mask = jnp.ones_like(decoder_input_ids)

        batch_size, sequence_length = input_ids.shape
        position_ids = jnp.broadcast_to(
            jnp.arange(sequence_length)[None, :],
            (batch_size, sequence_length))

        decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape
        if decoder_batch_size != batch_size:
            raise ValueError(
                f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder and {decoder_batch_size} for decoder."
            )
        decoder_position_ids = jnp.broadcast_to(
            jnp.arange(decoder_sequence_length)[None, :],
            (decoder_batch_size, decoder_sequence_length))

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        random_params = self.module.init(
            rngs,
            input_ids,
            attention_mask,
            decoder_input_ids,
            decoder_attention_mask,
            position_ids,
            decoder_position_ids,
        )["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params
Code example #9
    def __init__(
        self,
        config: PretrainedConfig,
        module: nn.Module,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
    ):
        if config is None:
            raise ValueError("config cannot be None")

        if module is None:
            raise ValueError("module cannot be None")

        # These are private, to be exposed as typed properties on derived classes.
        self._config = config
        self._module = module

        # These are public as their type is generic to every derived class.
        self.key = PRNGKey(seed)
        self.dtype = dtype

        # randomly initialized parameters
        random_params = self.init(self.key, input_shape)

        # save required_params as set
        self._required_params = set(
            flatten_dict(unfreeze(random_params)).keys())
        self.params = random_params
Code example #10
    def params(self, params: Union[Dict, FrozenDict]):
        if isinstance(params, FrozenDict):
            params = unfreeze(params)
        param_keys = set(flatten_dict(params).keys())
        if len(self.required_params - param_keys) > 0:
            raise ValueError(
                "Some parameters are missing. Make sure that `params` include the following "
                f"parameters {self.required_params - param_keys}")
        self._params = freeze(params)
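The setter validates by set difference on flattened key paths. A sketch of just that check, with a toy tree:

from flax.traverse_util import flatten_dict

required_params = {("encoder", "kernel"), ("head", "bias")}
params = {"encoder": {"kernel": 0.0}}  # "head" is absent
missing = required_params - set(flatten_dict(params).keys())
print(missing)  # {('head', 'bias')} -> the setter above raises ValueError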
Code example #11
    def __init__(
        self,
        config: PretrainedConfig,
        module: nn.Module,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
    ):
        if config is None:
            raise ValueError("config cannot be None")

        if module is None:
            raise ValueError("module cannot be None")

        # These are private, to be exposed as typed properties on derived classes.
        self._config = config
        self._module = module

        # These are public as their type is generic to every derived class.
        self.key = PRNGKey(seed)
        self.dtype = dtype
        self.input_shape = input_shape

        # To check if the model was initialized automatically.
        self._is_initialized = _do_init

        if _do_init:
            # randomly initialized parameters
            random_params = self.init_weights(self.key, input_shape)
            params_shape_tree = jax.eval_shape(lambda params: params,
                                               random_params)
        else:
            init_fn = partial(self.init_weights, input_shape=input_shape)
            params_shape_tree = jax.eval_shape(init_fn, self.key)

            logger.info(
                "Model weights are not initialized as `_do_init` is set to `False`. "
                f"Make sure to call `{self.__class__.__name__}.init_weights` manually to initialize the weights."
            )

        # get the shape of the parameters
        self._params_shape_tree = params_shape_tree

        # save required_params as set
        self._required_params = set(
            flatten_dict(unfreeze(params_shape_tree)).keys())

        # initialize the parameters
        if _do_init:
            self.params = random_params
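`jax.eval_shape` is what makes `_do_init=False` cheap: it traces the init function abstractly and returns a pytree of `jax.ShapeDtypeStruct` leaves without ever materializing the weights. A minimal sketch with a toy init function (not the model above):

from functools import partial

import jax
import jax.numpy as jnp

def init_weights(rng, input_shape):
    return {"dense": {"kernel": jnp.zeros(input_shape)}}

init_fn = partial(init_weights, input_shape=(1, 16))
params_shape_tree = jax.eval_shape(init_fn, jax.random.PRNGKey(0))
print(params_shape_tree)
# {'dense': {'kernel': ShapeDtypeStruct(shape=(1, 16), dtype=float32)}}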
Code example #12
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
    # convert pytorch tensor to numpy
    pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}

    random_flax_state_dict = flatten_dict(unfreeze(flax_model.params))
    flax_state_dict = {}

    remove_base_model_prefix = (
        flax_model.base_model_prefix
        not in flax_model.params) and (flax_model.base_model_prefix in set(
            [k.split(".")[0] for k in pt_state_dict.keys()]))
    add_base_model_prefix = (
        flax_model.base_model_prefix
        in flax_model.params) and (flax_model.base_model_prefix not in set(
            [k.split(".")[0] for k in pt_state_dict.keys()]))

    # Need to change some parameter names to match Flax names so that we don't have to fork any layer
    for pt_key, pt_tensor in pt_state_dict.items():

        pt_tuple_key = tuple(pt_key.split("."))

        has_base_model_prefix = pt_tuple_key[0] == flax_model.base_model_prefix
        require_base_model_prefix = (flax_model.base_model_prefix,
                                     ) + pt_tuple_key in random_flax_state_dict

        if remove_base_model_prefix and has_base_model_prefix:
            pt_tuple_key = pt_tuple_key[1:]
        elif add_base_model_prefix and require_base_model_prefix:
            pt_tuple_key = (flax_model.base_model_prefix, ) + pt_tuple_key

        if pt_tuple_key[
                -1] == "weight" and pt_tuple_key not in random_flax_state_dict:
            pt_tuple_key = pt_tuple_key[:-1] + ("kernel", )
            pt_tensor = pt_tensor.T
        elif pt_tuple_key[-1] == "gamma":
            pt_tuple_key = pt_tuple_key[:-1] + ("weight", )
        elif pt_tuple_key[-1] == "beta":
            pt_tuple_key = pt_tuple_key[:-1] + ("bias", )

        if pt_tuple_key in random_flax_state_dict:
            if random_flax_state_dict[pt_tuple_key].shape != pt_tensor.shape:
                raise ValueError(
                    f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
                    f"{random_flax_state_dict[pt_tuple_key].shape}, but is {pt_tensor.shape}."
                )

        # also add unexpected weights so that a warning is thrown
        flax_state_dict[pt_tuple_key] = pt_tensor

    return unflatten_dict(flax_state_dict)
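The two rename rules doing most of the work are `weight` -> `kernel` (with a transpose, because a PyTorch `nn.Linear` stores its matrix as (out, in) while a Flax `Dense` stores (in, out)) and the LayerNorm aliases `gamma`/`beta` -> `weight`/`bias`. A simplified sketch of just those rules; it deliberately skips the prefix handling and the membership check that guards the transpose in the real function:

import numpy as np

pt_state_dict = {"encoder.dense.weight": np.ones((4, 8)),
                 "encoder.norm.gamma": np.ones((8,))}
flax_state = {}
for pt_key, tensor in pt_state_dict.items():
    key = tuple(pt_key.split("."))
    if key[-1] == "weight":
        key, tensor = key[:-1] + ("kernel",), tensor.T  # (out, in) -> (in, out)
    elif key[-1] == "gamma":
        key = key[:-1] + ("weight",)
    elif key[-1] == "beta":
        key = key[:-1] + ("bias",)
    flax_state[key] = tensor
print({k: v.shape for k, v in flax_state.items()})
# {('encoder', 'dense', 'kernel'): (8, 4), ('encoder', 'norm', 'weight'): (8,)}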
Code example #13
    def params(self, params: Union[Dict, FrozenDict]):
        # don't set params if the model is not initialized
        if not self._is_initialized:
            raise ValueError(
                "`params` cannot be set from model when the model is created with `_do_init=False`. "
                "You store the params outside of the model.")

        if isinstance(params, FrozenDict):
            params = unfreeze(params)
        param_keys = set(flatten_dict(params).keys())
        if len(self.required_params - param_keys) > 0:
            raise ValueError(
                "Some parameters are missing. Make sure that `params` include the following "
                f"parameters {self.required_params - param_keys}")
        self._params = params
Code example #14
    def init_cache(self, batch_size, max_length):
        r"""
        Args:
            batch_size (`int`):
                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
            max_length (`int`):
                maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
                cache.
        """
        # init input variables to retrieve cache
        input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        attention_mask = jnp.ones_like(input_ids, dtype="i4")
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        init_variables = self.module.init(
            jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
        )
        return unfreeze(init_variables["cache"])
Code example #15
    def init_cache(self, batch_size, max_length, encoder_outputs):
        r"""
        Args:
            batch_size (`int`):
                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
            max_length (`int`):
                maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
                cache.
            encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray))]`):
                `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
                `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*)
                is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
                cross-attention of the decoder.
        """
        # init input variables to retrieve cache
        decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        decoder_attention_mask = jnp.ones_like(decoder_input_ids)
        decoder_position_ids = jnp.broadcast_to(
            jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]),
            decoder_input_ids.shape)

        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask,
                             decoder_position_ids, **kwargs):
            decoder_module = module._get_decoder_module()
            return decoder_module(
                input_ids=decoder_input_ids,
                attention_mask=decoder_attention_mask,
                position_ids=decoder_position_ids,
                **kwargs,
            )

        init_variables = self.module.init(
            jax.random.PRNGKey(0),
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_position_ids=decoder_position_ids,
            encoder_hidden_states=encoder_outputs[0],
            init_cache=True,
            method=_decoder_forward,  # we only need to call the decoder to init the cache
        )
        return unfreeze(init_variables["cache"])
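The `method=` argument lets `init` run an arbitrary function that receives the bound module as its first argument, so only the sub-modules that function touches get initialized; that is why only the decoder's cache is created above. A toy, runnable illustration of the same mechanism (hypothetical module, not the encoder-decoder model):

import jax
import jax.numpy as jnp
import flax.linen as nn

class TwoPart(nn.Module):
    def setup(self):
        self.encoder = nn.Dense(4)
        self.decoder = nn.Dense(2)

    def __call__(self, x):
        return self.decoder(self.encoder(x))

def _decoder_forward(module, hidden):
    return module.decoder(hidden)

m = TwoPart()
# only the decoder parameters are created
variables = m.init(jax.random.PRNGKey(0), jnp.ones((1, 4)), method=_decoder_forward)
print(list(variables["params"].keys()))  # ['decoder']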
Code example #16
    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        past_key_values: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if encoder_hidden_states is not None and encoder_attention_mask is None:
            batch_size, sequence_length = encoder_hidden_states.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        # prepare encoder inputs
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)
        if position_ids is None:
            batch_size, sequence_length = input_ids.shape
            position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :],
                (batch_size, sequence_length))

        # Handle any PRNG if needed
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        inputs = {"params": params or self.params}

        # If past_key_values are passed, the cache is already initialized and the private flag init_cache
        # has to be passed down to ensure the cache is used. The cache must also be marked as mutable so
        # that it can be changed by the FlaxXGLMAttention module.
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        outputs = self.module.apply(
            inputs,
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            position_ids=jnp.array(position_ids, dtype="i4"),
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            mutable=mutable,
        )

        # add updated cache to model output
        if past_key_values is not None and return_dict:
            outputs, past_key_values = outputs
            outputs["past_key_values"] = unfreeze(past_key_values["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs, past_key_values = outputs
            outputs = outputs[:1] + (unfreeze(
                past_key_values["cache"]), ) + outputs[1:]

        return outputs
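When a collection is listed in `mutable`, `apply` returns an `(outputs, mutated_variables)` pair rather than outputs alone, which is why the code above unpacks `outputs, past_key_values = outputs` whenever a cache is in play. A toy module demonstrating the same contract:

import jax
import jax.numpy as jnp
import flax.linen as nn

class Counter(nn.Module):
    @nn.compact
    def __call__(self, x):
        count = self.variable("cache", "count", lambda: jnp.zeros((), jnp.int32))
        count.value = count.value + 1
        return x

m = Counter()
variables = m.init(jax.random.PRNGKey(0), jnp.ones(3))  # creates count, bumps it to 1
y, mutated = m.apply(variables, jnp.ones(3), mutable=["cache"])
print(mutated["cache"]["count"])  # 2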
Code example #17
def main_opt(N, l, i0, nn_arq, act_fun, n_epochs, lr, w_decay, rho_g):

    start_time = time.time()

    str_nn_arq = ''
    for item in nn_arq:
        str_nn_arq = str_nn_arq + '_{}'.format(item)

    f_job = 'nn_arq{}_N_{}_i0_{}_l_{}_batch'.format(str_nn_arq, N, i0, l)
    f_out = '{}/out_opt_{}.txt'.format(r_dir, f_job)
    f_w_nn = '{}/W_{}.npy'.format(r_dir, f_job)
    file_results = '{}/data_nh3_{}.npy'.format(r_dir, f_job)

    #     --------------------------------------
    #     Data
    n_atoms = 4
    batch_size = 768  # other sizes tried: 1024, 512, 256, 128, 64, 32
    Dtr, Dval, Dt = load_data(file_results, N, l)
    Xtr, gXtr, gXctr, ytr = Dtr
    Xval, gXval, gXcval, yval = Dval
    Xt, gXt, gXct, yt = Dt
    print(Xtr.shape, gXtr.shape, gXctr.shape, ytr.shape)
    # --------------------------------
    #     BATCHES

    n_complete_batches, leftover = divmod(N, batch_size)
    n_batches = n_complete_batches + bool(leftover)

    def data_stream():
        rng = onpr.RandomState(0)
        while True:
            perm = rng.permutation(N)
            for i in range(n_batches):
                batch_idx = perm[i * batch_size:(i + 1) * batch_size]
                yield Xtr[batch_idx], gXtr[batch_idx], gXctr[batch_idx], ytr[
                    batch_idx]

    batches = data_stream()
    # --------------------------------

    f = open(f_out, 'a+')
    print('-----------------------------------', file=f)
    print('Starting time', file=f)
    print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M"), file=f)
    print('-----------------------------------', file=f)
    print(f_out, file=f)
    print('N = {}, n_atoms = {}, data_random = {}, NN_random = {}'.format(
        N, n_atoms, l, i0),
          file=f)
    print(nn_arq, file=f)
    print('lr = {}, w decay = {}'.format(lr, w_decay), file=f)
    print('Activation function = {}'.format(act_fun), file=f)
    print('N Epoch = {}'.format(n_epochs), file=f)
    print('rho G = {}'.format(rho_g), file=f)
    print('-----------------------------------', file=f)
    f.close()

    #     --------------------------------------
    #     initialize NN

    nn_arq.append(3)
    tuple_nn_arq = tuple(nn_arq)
    nn_model = NN_adiab(n_atoms, tuple_nn_arq)

    def get_init_NN_params(key):
        x = Xtr[0, :]
        x = x[None, :]  #         x = jnp.ones((1,Xtr.shape[1]))
        variables = nn_model.init(key, x)
        return variables

#     Initialize parameters

    rng = random.PRNGKey(i0)
    rng, subkey = jax.random.split(rng)
    params = get_init_NN_params(subkey)

    f = open(f_out, 'a+')
    if os.path.isfile(f_w_nn):
        print('Reading NN parameters from prev calculation!', file=f)
        print('-----------------------', file=f)

        nn_dic = jnp.load(f_w_nn, allow_pickle=True)
        params = unfreeze(params)
        params['params'] = nn_dic.item()['params']
        params = freeze(params)
#         print(params)

    f.close()
    init_params = params

    #     --------------------------------------
    #     Phys functions

    @jit
    def nn_adiab(params, x):
        y_ad_pred = nn_model.apply(params, x)
        return y_ad_pred

    @jit
    def jac_nn_adiab(params, x):
        g_y_pred = jacrev(nn_adiab, argnums=1)(params, x[None, :])
        return jnp.reshape(g_y_pred, (2, g_y_pred.shape[-1]))

#     --------------------------------------
#    training loss functions

    @jit
    def f_loss_ad_energy(params, batch):
        X_inputs, _, _, y_true = batch
        y_pred = nn_adiab(params, X_inputs)
        diff_y = y_pred - y_true  #Ha2cm*
        return jnp.linalg.norm(diff_y, axis=0)

    @jit
    def f_loss_jac(params, batch):
        X_inputs, gX_inputs, _, y_true = batch
        gX_pred = vmap(jac_nn_adiab, (None, 0))(params, X_inputs)
        diff_g_X = gX_pred - gX_inputs
        # jnp.linalg.norm(diff_g_X,axis=0)

        diff_g_X0 = diff_g_X[:, 0, :]
        diff_g_X1 = diff_g_X[:, 1, :]
        l0 = jnp.linalg.norm(diff_g_X0)
        l1 = jnp.linalg.norm(diff_g_X1)
        return jnp.stack([l0, l1])

#     ------

    @jit
    def f_loss(params, rho_g, batch):
        rho_g = jnp.exp(rho_g)
        loss_ad_energy = f_loss_ad_energy(params, batch)
        loss_jac_energy = f_loss_jac(params, batch)
        loss = jnp.vdot(jnp.ones_like(loss_ad_energy),
                        loss_ad_energy) + jnp.vdot(rho_g, loss_jac_energy)
        return loss
#     --------------------------------------
#     Optimization  and Training

#     Perform a single training step.

    @jit
    def train_step(optimizer, rho_g, batch):  #, learning_rate_fn, model
        grad_fn = jax.value_and_grad(f_loss)
        loss, grad = grad_fn(optimizer.target, rho_g, batch)
        optimizer = optimizer.apply_gradient(grad)  #, {"learning_rate": lr}
        return optimizer, (loss, grad)

#     @jit

    def train(rho_g, nn_params):
        optimizer = optim.Adam(learning_rate=lr,
                               weight_decay=w_decay).create(nn_params)
        optimizer = jax.device_put(optimizer)

        train_loss = []
        loss0 = 1E16
        loss0_tot = 1E16
        itercount = itertools.count()
        f_params = init_params
        for epoch in range(n_epochs):
            for _ in range(n_batches):
                optimizer, loss_and_grad = train_step(optimizer, rho_g,
                                                      next(batches))
                loss, grad = loss_and_grad

#             f = open(f_out,'a+')
#             print(i,loss,file=f)
#             f.close()

            train_loss.append(loss)
#             params = optimizer.target
#             loss_tot = f_validation(params)

        nn_params = optimizer.target

        return nn_params, loss_and_grad, train_loss

    @jit
    def val_step(optimizer, nn_params):  #, learning_rate_fn, model

        rho_g_prev = optimizer.target
        nn_params, loss_and_grad_train, train_loss_iter = train(
            rho_g_prev, nn_params)
        loss_train, grad_loss_train = loss_and_grad_train

        grad_fn_val = jax.value_and_grad(f_loss, argnums=1)
        loss_val, grad_val = grad_fn_val(nn_params, optimizer.target, Dval)
        optimizer = optimizer.apply_gradient(
            grad_val)  #, {"learning_rate": lr}
        return optimizer, nn_params, (loss_val, loss_train,
                                      train_loss_iter), (grad_loss_train,
                                                         grad_val)

#     Initialize rho_G

    rng = random.PRNGKey(0)
    rng, subkey = jax.random.split(rng)

    rho_G0 = random.uniform(subkey, shape=(2, ), minval=5E-4, maxval=0.025)
    rho_G0 = jnp.log(rho_G0)
    print('Initial lambdas', rho_G0)
    init_G = rho_G0  #

    optimizer_out = optim.Adam(learning_rate=2E-4,
                               weight_decay=0.).create(init_G)
    optimizer_out = jax.device_put(optimizer_out)

    f_params = init_params

    for i in range(50000):
        start_va_time = time.time()
        optimizer_out, f_params, loss_all, grad_all = val_step(
            optimizer_out, f_params)

        rho_g = optimizer_out.target
        loss_val, loss_train, train_loss_iter = loss_all
        grad_loss_train, grad_val = grad_all

        loss0_tot = f_loss(f_params, rho_g, Dt)

        dict_output = serialization.to_state_dict(f_params)
        jnp.save(f_w_nn, dict_output)  #unfreeze()

        f = open(f_out, 'a+')
        #         print(i,rho_g, loss0, loss0_tot, (time.time() - start_va_time),file=f)
        print(i, loss_val, loss_train, (time.time() - start_va_time), file=f)
        print(jnp.exp(rho_g), file=f)
        print(grad_val, file=f)
        #         print(train_loss_iter ,file=f)
        #         print(grad_val,file=f)
        #         print(grad_loss_train,file=f)
        f.close()


#     --------------------------------------
#     Prediction
    f = open(f_out, 'a+')
    print('Prediction of the entire data set', file=f)
    print('N = {}, n_atoms = {}, random = {}'.format(N, n_atoms, i0), file=f)
    print('NN : {}'.format(nn_arq), file=f)
    print('lr = {}, w decay = {}, rho G = {}'.format(lr, w_decay, rho_g),
          file=f)
    print('Activation function = {}'.format(act_fun), file=f)
    print('Total points  = {}'.format(yt.shape[0]), file=f)

    y_pred = nn_adiab(f_params, Xt)
    gX_pred = vmap(jac_nn_adiab, (None, 0))(f_params, Xt)

    diff_y = y_pred - yt
    rmse_Ha = jnp.linalg.norm(diff_y)
    rmse_cm = jnp.linalg.norm(Ha2cm * diff_y)
    mae_Ha = jnp.linalg.norm(diff_y, ord=1)
    mae_cm = jnp.linalg.norm(Ha2cm * diff_y, ord=1)

    print('RMSE = {} [Ha]'.format(rmse_Ha), file=f)
    print('RMSE(tr) = {} [cm-1]'.format(loss_train), file=f)
    print('RMSE = {} [cm-1]'.format(rmse_cm), file=f)
    print('MAE = {} [Ha]'.format(mae_Ha), file=f)
    print('MAE = {} [cm-1]'.format(mae_cm), file=f)

    Dpred = jnp.column_stack((Xt, y_pred))
    data_dic = {
        'Dtr': Dtr,
        'Dpred': Dpred,
        'gXpred': gX_pred,
        'loss_tr': loss0,
        'error_full': rmse_cm,
        'N': N,
        'l': l,
        'i0': i0,
        'rho_g': rho_g
    }

    jnp.save(file_results, data_dic)

    print('---------------------------------', file=f)
    print('Total time =  %.6f seconds ---' % ((time.time() - start_time)),
          file=f)
    print('---------------------------------', file=f)
    f.close()
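The script persists weights with `flax.serialization.to_state_dict` plus `jnp.save`, and restores them by splicing the saved "params" subtree back into a freshly initialized tree (the `if os.path.isfile(f_w_nn)` branch above). The round trip in isolation, with a toy tree and a hypothetical file name:

import jax.numpy as jnp
from flax import serialization
from flax.core.frozen_dict import freeze, unfreeze

params = freeze({"params": {"dense": {"kernel": jnp.ones((2, 2))}}})
jnp.save("W_demo.npy", serialization.to_state_dict(params))  # pickled dict inside a .npy

nn_dic = jnp.load("W_demo.npy", allow_pickle=True)
params = unfreeze(params)
params["params"] = nn_dic.item()["params"]  # splice the saved subtree back in
params = freeze(params)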
Code example #18
    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Returns:

        Example::

            >>> from transformers import FlaxEncoderDecoderModel, BertTokenizer
            >>> import jax.numpy as jnp

            >>> # initialize a bert2gpt2 from pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized
            >>> model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-cased', 'gpt2')

            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

            >>> text = "My friends are cool but they eat too many carbs."
            >>> input_ids = tokenizer.encode(text, max_length=1024, return_tensors='np')
            >>> encoder_outputs = model.encode(input_ids)

            >>> decoder_start_token_id = model.config.decoder.bos_token_id
            >>> decoder_input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

            >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
            >>> logits = outputs.logits

        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        encoder_hidden_states = encoder_outputs[0]
        if encoder_attention_mask is None:
            batch_size, sequence_length = encoder_hidden_states.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        batch_size, sequence_length = decoder_input_ids.shape
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones((batch_size, sequence_length))

        if decoder_position_ids is None:
            if past_key_values is not None:
                raise ValueError(
                    "Make sure to provide `decoder_position_ids` when passing `past_key_values`."
                )

            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :],
                (batch_size, sequence_length))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # If past_key_values are passed, the cache is already initialized and the private flag init_cache
        # has to be passed down to ensure the cache is used. The cache must also be marked as mutable so
        # that it can be changed by the FlaxBartAttention module.
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask,
                             decoder_position_ids, encoder_hidden_states,
                             **kwargs):

            projection_module = module._get_projection_module()
            decoder_module = module._get_decoder_module()

            # optionally project encoder_hidden_states
            if projection_module is not None:
                encoder_hidden_states = projection_module(
                    encoder_hidden_states)

            return decoder_module(
                decoder_input_ids,
                decoder_attention_mask,
                decoder_position_ids,
                encoder_hidden_states,
                **kwargs,
            )

        outputs = self.module.apply(
            inputs,
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask,
                                             dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=jnp.array(encoder_attention_mask,
                                             dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            mutable=mutable,
            method=_decoder_forward,
        )

        # add updated cache to model output
        if past_key_values is not None and return_dict:
            outputs, past = outputs
            outputs["past_key_values"] = unfreeze(past["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs, past = outputs
            outputs = outputs[:1] + (unfreeze(past["cache"]), ) + outputs[1:]

        return outputs
Code example #19
def main_opt(N, l, i0, nn_arq, act_fun, n_epochs, lr, w_decay, rho_g):

    start_time = time.time()

    str_nn_arq = ''
    for item in nn_arq:
        str_nn_arq = str_nn_arq + '_{}'.format(item)

    f_job = 'nn_arq{}_N_{}_i0_{}_l_{}_batch'.format(str_nn_arq, N, i0, l)
    f_out = '{}/out_opt_{}.txt'.format(r_dir, f_job)
    f_w_nn = '{}/W_{}.npy'.format(r_dir, f_job)
    file_results = '{}/data_nh3_{}.npy'.format(r_dir, f_job)

    #     --------------------------------------
    #     Data
    n_atoms = 4
    batch_size = 768  # other sizes tried: 1024, 512, 256, 128, 64, 32
    Dtr, Dt = load_data(file_results, N, l)
    Xtr, gXtr, gXctr, ytr = Dtr
    Xt, gXt, gXct, yt = Dt
    print(Xtr.shape, gXtr.shape, gXctr.shape, ytr.shape)
    # --------------------------------
    #     BATCHES

    n_complete_batches, leftover = divmod(N, batch_size)
    n_batches = n_complete_batches + bool(leftover)

    def data_stream():
        rng = onpr.RandomState(0)
        while True:
            perm = rng.permutation(N)
            for i in range(n_batches):
                batch_idx = perm[i * batch_size:(i + 1) * batch_size]
                yield Xtr[batch_idx], gXtr[batch_idx], gXctr[batch_idx], ytr[
                    batch_idx]

    batches = data_stream()
    # --------------------------------

    f = open(f_out, 'a+')
    print('-----------------------------------', file=f)
    print('Starting time', file=f)
    print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M"), file=f)
    print('-----------------------------------', file=f)
    print(f_out, file=f)
    print('N = {}, n_atoms = {}, data_random = {}, NN_random = {}'.format(
        N, n_atoms, l, i0),
          file=f)
    print(nn_arq, file=f)
    print('lr = {}, w decay = {}'.format(lr, w_decay), file=f)
    print('Activation function = {}'.format(act_fun), file=f)
    print('N Epoch = {}'.format(n_epochs), file=f)
    print('rho G = {}'.format(rho_g), file=f)
    print('-----------------------------------', file=f)
    f.close()

    #     --------------------------------------
    #     initialize NN

    nn_arq.append(3)
    tuple_nn_arq = tuple(nn_arq)
    nn_model = NN_adiab(n_atoms, tuple_nn_arq)

    def get_init_NN_params(key):
        x = Xtr[0, :]
        x = x[None, :]  #         x = jnp.ones((1,Xtr.shape[1]))
        variables = nn_model.init(key, x)
        return variables

#     Initialize parameters

    rng = random.PRNGKey(i0)
    rng, subkey = jax.random.split(rng)
    params = get_init_NN_params(subkey)

    f = open(f_out, 'a+')
    if os.path.isfile(f_w_nn):
        print('Reading NN parameters from prev calculation!', file=f)
        print('-----------------------', file=f)

        nn_dic = jnp.load(f_w_nn, allow_pickle=True)
        params = unfreeze(params)
        params['params'] = nn_dic.item()['params']
        params = freeze(params)
    f.close()
    init_params = params

    #     --------------------------------------
    #     Phys functions

    @jit
    def nn_adiab(params, x):
        y_ad_pred = nn_model.apply(params, x)
        return y_ad_pred

    @jit
    def jac_nn_adiab(params, x):
        g_y_pred = jacrev(nn_adiab, argnums=1)(params, x[None, :])
        return jnp.reshape(g_y_pred, (2, g_y_pred.shape[-1]))

    '''
#     WRONG
    @jit
    def f_nac_coup_i(gH_diab,eigvect_): #for a single cartesian dimension
        temp = jnp.dot(gH_diab,eigvect_[:,0])
        return jnp.vdot(eigvect_[:,1],temp)
    @jit
    def f_nac_coup(params,x):
        eigval_, eigvect_ = f_adiab(params,x)
        gy_diab = jac_nn_diab(params,x)
        gy_diab = jnp.reshape(gy_diab.T,(12,2,2))
        g_coup = vmap(f_nac_coup_i,(0,None))(gy_diab,eigvect_)
        return g_coup
    '''

    #     --------------------------------------
    #     Validation loss functions

    @jit
    def f_validation(params):
        y_pred = nn_adiab(params, Xt)
        diff_y = y_pred - yt
        z = jnp.linalg.norm(diff_y)
        return z

    @jit
    def f_jac_validation(params):
        gX_pred = vmap(jac_nn_adiab, (None, 0))(params, Xt)
        diff_y = gX_pred - gXt
        z = jnp.linalg.norm(diff_y)
        return z

    '''
    @jit
    def f_nac_validation(params):
        g_nac_coup = vmap(f_nac_coup,(None,0))(params,Xt)
        diff_y = g_nac_coup - gXct
        z = jnp.linalg.norm(diff_y)
        return z 
    '''
    #     --------------------------------------
    #    training loss functions
    @jit
    def f_loss_ad_energy(params, batch):
        X_inputs, _, _, y_true = batch
        y_pred = nn_adiab(params, X_inputs)
        diff_y = y_pred - y_true  #Ha2cm*
        loss = jnp.linalg.norm(diff_y)
        return loss

    @jit
    def f_loss_jac(params, batch):
        X_inputs, gX_inputs, _, y_true = batch
        gX_pred = vmap(jac_nn_adiab, (None, 0))(params, X_inputs)
        diff_g_X = gX_pred - gX_inputs
        return jnp.linalg.norm(diff_g_X)

    '''    
    @jit
    def f_loss_nac(params,batch):
        X_inputs, _,gXc_inputs,y_true = batch
        g_nac_coup = vmap(f_nac_coup,(None,0))(params,x)
        diff_y = g_nac_coup - gXc_inputs
        z = jnp.linalg.norm(diff_y)
        return z 
    '''
    #     ------
    @jit
    def f_loss(params, batch):
        loss_ad_energy = f_loss_ad_energy(params, batch)
        #         loss_jac_energy = f_loss_jac(params,batch)
        loss = loss_ad_energy  #+ rho_g*loss_jac_energy
        return loss


#     --------------------------------------
#     Optimization  and Training

#     Perform a single training step.

    @jit
    def train_step(optimizer, batch):  #, learning_rate_fn, model
        grad_fn = jax.value_and_grad(f_loss)
        loss, grad = grad_fn(optimizer.target, batch)
        optimizer = optimizer.apply_gradient(grad)  #, {"learning_rate": lr}
        return optimizer, loss

    optimizer = optim.Adam(learning_rate=lr,
                           weight_decay=w_decay).create(init_params)
    optimizer = jax.device_put(optimizer)

    loss0 = 1E16
    loss0_tot = 1E16
    itercount = itertools.count()
    f_params = init_params
    for epoch in range(n_epochs):
        for _ in range(n_batches):
            optimizer, loss = train_step(optimizer, next(batches))

        params = optimizer.target
        loss_tot = f_validation(params)

        if epoch % 10 == 0:
            f = open(f_out, 'a+')
            print(epoch, loss, loss_tot, file=f)
            f.close()

        if loss < loss0:
            loss0 = loss
            f = open(f_out, 'a+')
            print(epoch, loss, loss_tot, file=f)
            f.close()

        if loss_tot < loss0_tot:
            loss0_tot = loss_tot
            f_params = params
            dict_output = serialization.to_state_dict(params)
            jnp.save(f_w_nn, dict_output)  #unfreeze()

    f = open(f_out, 'a+')
    print('---------------------------------', file=f)
    print('Training time =  %.6f seconds ---' % ((time.time() - start_time)),
          file=f)
    print('---------------------------------', file=f)
    f.close()

    #     --------------------------------------
    #     Prediction
    f = open(f_out, 'a+')
    print('Prediction of the entire data set', file=f)
    print('N = {}, n_atoms = {}, random = {}'.format(N, n_atoms, i0), file=f)
    print('NN : {}'.format(nn_arq), file=f)
    print('lr = {}, w decay = {}, rho G = {}'.format(lr, w_decay, rho_g),
          file=f)
    print('Activation function = {}'.format(act_fun), file=f)
    print('Total points  = {}'.format(yt.shape[0]), file=f)

    y_pred = nn_adiab(f_params, Xt)
    gX_pred = vmap(jac_nn_adiab, (None, 0))(f_params, Xt)

    diff_y = y_pred - yt
    rmse_Ha = jnp.linalg.norm(diff_y)
    rmse_cm = jnp.linalg.norm(Ha2cm * diff_y)
    mae_Ha = jnp.linalg.norm(diff_y, ord=1)
    mae_cm = jnp.linalg.norm(Ha2cm * diff_y, ord=1)

    print('RMSE = {} [Ha]'.format(rmse_Ha), file=f)
    print('RMSE(tr) = {} [cm-1]'.format(loss0), file=f)
    print('RMSE = {} [cm-1]'.format(rmse_cm), file=f)
    print('MAE = {} [Ha]'.format(mae_Ha), file=f)
    print('MAE = {} [cm-1]'.format(mae_cm), file=f)

    Dpred = jnp.column_stack((Xt, y_pred))
    data_dic = {
        'Dtr': Dtr,
        'Dpred': Dpred,
        'gXpred': gX_pred,
        'loss_tr': loss0,
        'error_full': rmse_cm,
        'N': N,
        'l': l,
        'i0': i0,
        'rho_g': rho_g
    }

    jnp.save(file_results, data_dic)

    print('---------------------------------', file=f)
    print('Total time =  %.6f seconds ---' % ((time.time() - start_time)),
          file=f)
    print('---------------------------------', file=f)
    f.close()
Code example #20
    def from_pretrained(cls,
                        pretrained_model_name_or_path: Union[str, os.PathLike],
                        dtype: jnp.dtype = jnp.float32,
                        *model_args,
                        **kwargs):
        r"""
        Instantiate a pretrained flax model from a pre-trained model configuration.

        The warning `Weights from XXX not initialized from pretrained model` means that the weights of XXX do not come
        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
        task.

        The warning `Weights from XXX not used in YYY` means that the layer XXX is not used by YYY, therefore those
        weights are discarded.

        Parameters:
            pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
                Can be either:

                    - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
                      Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
                      a user or organization name, like ``dbmdz/bert-base-german-cased``.
                    - A path to a `directory` containing model weights saved using
                      :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
                    - A path or url to a `pt index checkpoint file` (e.g., ``./tf_model/model.ckpt.index``). In this
                      case, ``from_pt`` should be set to :obj:`True`.
            model_args (sequence of positional arguments, `optional`):
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
            config (:obj:`Union[PretrainedConfig, str, os.PathLike]`, `optional`):
                Can be either:

                    - an instance of a class derived from :class:`~transformers.PretrainedConfig`,
                    - a string or path valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`.

                Configuration for the model to use instead of an automatically loaded configuration. Configuration can
                be automatically loaded when:

                    - The model is a model provided by the library (loaded with the `model id` string of a pretrained
                      model).
                    - The model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded
                      by supplying the save directory.
                    - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
                      configuration JSON file named `config.json` is found in the directory.
            cache_dir (:obj:`Union[str, os.PathLike]`, `optional`):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            from_pt (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Load the model weights from a PyTorch checkpoint save file (see docstring of
                ``pretrained_model_name_or_path`` argument).
            force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (:obj:`Dict[str, str]`, `optional`):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to only look at local files (i.e., do not try to download the model).
            revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
                identifier allowed by git.
            kwargs (remaining dictionary of keyword arguments, `optional`):
                Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
                :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
                automatically loaded:

                    - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
                      underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
                      already been done)
                    - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
                      initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
                      ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
                      with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
                      attribute will be passed to the underlying model's ``__init__`` function.

        Examples::

            >>> from transformers import BertConfig, FlaxBertModel
            >>> # Download model and configuration from huggingface.co and cache.
            >>> model = FlaxBertModel.from_pretrained('bert-base-cased')
            >>> # Model was saved using `save_pretrained('./test/saved_model/')` (for example purposes, not runnable).
            >>> model = FlaxBertModel.from_pretrained('./test/saved_model/')
            >>> # Loading from a PyTorch checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
            >>> config = BertConfig.from_json_file('./pt_model/config.json')
            >>> model = FlaxBertModel.from_pretrained('./pt_model/pytorch_model.bin', from_pt=True, config=config)
        """
        config = kwargs.pop("config", None)
        cache_dir = kwargs.pop("cache_dir", None)
        from_pt = kwargs.pop("from_pt", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        from_pipeline = kwargs.pop("_from_pipeline", None)
        from_auto_class = kwargs.pop("_from_auto", False)

        user_agent = {
            "file_type": "model",
            "framework": "flax",
            "from_auto_class": from_auto_class
        }
        if from_pipeline is not None:
            user_agent["using_pipeline"] = from_pipeline

        if is_offline_mode() and not local_files_only:
            logger.info("Offline mode: forcing local_files_only=True")
            local_files_only = True

        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                *model_args,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                _from_auto=from_auto_class,
                _from_pipeline=from_pipeline,
                **kwargs,
            )
        else:
            model_kwargs = kwargs

        # Add the dtype to model_kwargs
        model_kwargs["dtype"] = dtype

        # Load model
        if pretrained_model_name_or_path is not None:
            if os.path.isdir(pretrained_model_name_or_path):
                if from_pt and os.path.isfile(
                        os.path.join(pretrained_model_name_or_path,
                                     WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path,
                                                WEIGHTS_NAME)
                elif os.path.isfile(
                        os.path.join(pretrained_model_name_or_path,
                                     FLAX_WEIGHTS_NAME)):
                    # Load from a Flax checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path,
                                                FLAX_WEIGHTS_NAME)
                else:
                    raise EnvironmentError(
                        f"Error: no file named {[FLAX_WEIGHTS_NAME, WEIGHTS_NAME]} found in directory "
                        f"{pretrained_model_name_or_path}, or `from_pt` is set to False."
                    )
            elif os.path.isfile(
                    pretrained_model_name_or_path) or is_remote_url(
                        pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            else:
                archive_file = hf_bucket_url(
                    pretrained_model_name_or_path,
                    filename=WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME,
                    revision=revision,
                )

            # redirect to the cache, if necessary
            try:
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                    user_agent=user_agent,
                )
            except EnvironmentError as err:
                logger.error(err)
                msg = (
                    f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                    f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                    f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named {WEIGHTS_NAME}.\n\n"
                )
                raise EnvironmentError(msg)

            if resolved_archive_file == archive_file:
                logger.info(f"loading weights file {archive_file}")
            else:
                logger.info(
                    f"loading weights file {archive_file} from cache at {resolved_archive_file}"
                )
        else:
            resolved_archive_file = None

        # init random models
        model = cls(config, *model_args, **model_kwargs)

        if from_pt:
            state = load_pytorch_checkpoint_in_flax_state_dict(
                model, resolved_archive_file)
        else:
            with open(resolved_archive_file, "rb") as state_f:
                try:
                    state = from_bytes(cls, state_f.read())
                except UnpicklingError:
                    raise EnvironmentError(
                        f"Unable to convert {archive_file} to Flax deserializable object. "
                    )

        # if the model is a base model, keep only the sub-dict under base_model_prefix
        if cls.base_model_prefix not in dict(
                model.params) and cls.base_model_prefix in state:
            state = state[cls.base_model_prefix]

        # flatten dicts
        state = flatten_dict(state)
        random_state = flatten_dict(unfreeze(model.params))

        missing_keys = model.required_params - set(state.keys())
        unexpected_keys = set(state.keys()) - model.required_params

        # add missing keys as random parameters
        for missing_key in missing_keys:
            state[missing_key] = random_state[missing_key]

        # remove unexpected keys to not be saved again
        for unexpected_key in unexpected_keys:
            del state[unexpected_key]

        if len(unexpected_keys) > 0:
            logger.warning(
                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
                f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
                f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
                f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
                f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
                f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
            )
        else:
            logger.info(
                f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n"
            )

        if len(missing_keys) > 0:
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
                f"and are newly initialized: {missing_keys}\n"
                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        else:
            logger.info(
                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
                f"If your task is similar to the task the model of the checkpoint was trained on, "
                f"you can already use {model.__class__.__name__} for predictions without further training."
            )

        # set correct parameters
        model.params = unflatten_dict(state)
        return model
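The merge logic above is easy to exercise in isolation. A minimal sketch, with toy dicts standing in for real checkpoint/model params (all names here are hypothetical):

from flax.traverse_util import flatten_dict, unflatten_dict

# pretend checkpoint params vs. freshly initialized params
checkpoint = flatten_dict({"encoder": {"kernel": 1.0}})
random_init = flatten_dict({"encoder": {"kernel": 9.9}, "head": {"kernel": 0.5}})

missing = set(random_init) - set(checkpoint)     # {('head', 'kernel')}
unexpected = set(checkpoint) - set(random_init)  # set()

for key in missing:
    checkpoint[key] = random_init[key]  # keep the random init for new heads
for key in unexpected:
    del checkpoint[key]                 # drop weights the model has no slot for

params = unflatten_dict(checkpoint)
# {'encoder': {'kernel': 1.0}, 'head': {'kernel': 0.5}}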
Code example #21
    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        params: dict = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        dropout_rng: PRNGKey = None,
        deterministic: bool = True,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        if position_ids is None:
            position_ids = (attention_mask.cumsum(axis=1) * attention_mask) - 1

        # Handle any PRNG if needed
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        inputs = {"params": params or self.params}

        # If past_key_values are passed, the cache has already been initialized, and a private flag,
        # init_cache, has to be passed down to ensure the cache is used. The cache must also be marked
        # as mutable so that it can be updated by the FlaxOPTAttention module.
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        outputs = self.module.apply(
            inputs,
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            position_ids=jnp.array(position_ids, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
            rngs=rngs,
            mutable=mutable,
        )

        # add updated cache to model output
        if past_key_values is not None and return_dict:
            outputs, past_key_values = outputs
            outputs["past_key_values"] = unfreeze(past_key_values["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs, past_key_values = outputs
            outputs = outputs[:1] + (unfreeze(
                past_key_values["cache"]), ) + outputs[1:]

        return outputs
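The position_ids fallback above derives positions from the attention mask, so any left-padding collapses to position -1. A quick, self-contained check of that one-liner (toy mask, nothing model-specific):

import jax.numpy as jnp

attention_mask = jnp.array([[0, 0, 1, 1, 1]])
position_ids = (attention_mask.cumsum(axis=1) * attention_mask) - 1
print(position_ids)  # [[-1 -1  0  1  2]]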
Code example #22
File: attention_simple.py  Project: voicedm/flax
        attn = Attn(attn_module=self.attn_module,
                    qkv_features=qkv_features // self.num_heads,
                    out_features=out_features)

        # evaluate multi-headed-attention.
        y = attn(inputs_q, inputs_kv, bias)
        return y.mean(axis=-2)


# run it.

if __name__ == '__main__':

    inputs = jnp.ones((8, 97, 256))
    rngs = {'params': random.PRNGKey(0), 'dropout': random.PRNGKey(1)}
    model = MultiHeadDotProductAttention(
        broadcast_dropout=False,
        qkv_features=256,
        out_features=256,
        attn_module=functools.partial(SoftmaxAttnWDropout, rate=0.1),
        num_heads=8,
        batch_axes=(0, ),
    )

    y, params = model.init_with_output(rngs, inputs, inputs)

    print('input shape: ', inputs.shape)
    print('parameter shapes:')
    pprint(jax.tree_map(jnp.shape, unfreeze(params)))
    print('output shape: ', y.shape)
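The jax.tree_map(jnp.shape, ...) pattern used above works on any Flax parameter tree. A minimal sketch with a single Dense layer (unrelated to the attention module above):

import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.core import unfreeze

layer = nn.Dense(features=4)
params = layer.init(jax.random.PRNGKey(0), jnp.ones((1, 3)))
print(jax.tree_map(jnp.shape, unfreeze(params)))
# {'params': {'kernel': (3, 4), 'bias': (4,)}}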
Code example #23
def block_hessians(params,
                   loss_fn,
                   param_partition_fn,
                   batches_gen,
                   rng_key,
                   num_lanczos_steps=20):
    """Computes the loss hessian with respect to subsets of model parameters.

  Subsets are determined by the param_partition_fn, which maps the flattened
  model parameters to a dict mapping keys to subsets of the flattened model
  parameters. For example, if the flattened param tree is

  {('a', 'b'): 1.0,
  ('a', 'c'): 1.0,
  ('d', 'b'): 2.0}

  and we partition on the outer key (see partition_tree.outer_key), then
  the output is
  {'a': {('a', 'b'): 1.0, ('a', 'c'): 1.0},
   'd': {('d', 'b'): 2.0}}

  Args:
    params: Replicated pytree of model parameters.
    loss_fn: non-pmapped function with API
       loss_fn(unreplicated_params, unreplicated_batch) -> scalar loss.
    param_partition_fn: Maps a flattened pytree to a partitioned dict
       as described above.
    batches_gen: Should yield replicated batches so we can call
      jax.pmap(loss_fn)(params, batch).
    rng_key: Unreplicated jax PRNGKey.
    num_lanczos_steps: How many lanczos iterates to do.

  Returns:
    A dictionary of results, with a key for each partition of
      the model parameters (as determined by param_partition_fn). The key
      maps to another dict with the following key value pairs:
      max_eig_hess -> The hessian max eigenvalue with respect to the sub params.
      tridiag_hess -> The tridiagonal matrix output by lanczos.
      param_names -> The flattened parameter names in this partition.
  """
    unrep_params = flax.jax_utils.unreplicate(params)
    flat_dict = flax.traverse_util.flatten_dict(unrep_params)

    sub_param_groups = param_partition_fn(flat_dict)

    sub_results = {}

    # Note: this closure stores a copy of the unreplicated parameters in the
    # function definition, which may be costly when len(sub_param_groups) is
    # large. It is unclear how JAX garbage collection handles this.
    def sub_loss(sub_params, batch_rng):
        new_dict = flat_dict.copy()
        for tup in sub_params:
            new_dict[tup] = sub_params[tup]

        new_params = flax.traverse_util.unflatten_dict(new_dict)
        return loss_fn(new_params, batch_rng)

    for key in sub_param_groups:

        logging.info('Block Hessian eval on %s', key)
        sub_params = unfreeze(jax_utils.replicate(sub_param_groups[key]))

        hvp_fn, _, n_params = hessian_computation.get_hvp_fn(sub_loss,
                                                             sub_params,
                                                             batches_gen,
                                                             use_pmap=True)

        # functools.partial avoids the [cell-var-from-loop] lint error and any
        # Python closure weirdness from defining functions in a loop.
        hvp_cl = functools.partial(hvp_fn, unfreeze(sub_params))
        row = {}

        row['tridiag_hess'], _ = lanczos.lanczos_np(hvp_cl,
                                                    n_params,
                                                    num_lanczos_steps,
                                                    0,
                                                    rng_key,
                                                    verbose=True)
        evs = np.linalg.eigvalsh(row['tridiag_hess'])
        row['max_eig_hess'] = np.max(evs)
        row['param_names'] = [list(sub_param_groups[key].keys())]

        # The flattened keys are tuples, which don't serialize, so the keys are
        # converted to strings before saving this dict.
        sub_results[str(key)] = row

    return sub_results
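The param_partition_fn contract from the docstring is straightforward to satisfy. A minimal sketch of an outer-key partitioner (a hypothetical helper mirroring the partition_tree.outer_key example above):

from collections import defaultdict

def outer_key_partition(flat_params):
    # Group the tuple-keyed flattened params by the first element of each key.
    groups = defaultdict(dict)
    for key, value in flat_params.items():
        groups[key[0]][key] = value
    return dict(groups)

flat = {('a', 'b'): 1.0, ('a', 'c'): 1.0, ('d', 'b'): 2.0}
print(outer_key_partition(flat))
# {'a': {('a', 'b'): 1.0, ('a', 'c'): 1.0}, 'd': {('d', 'b'): 2.0}}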
Code example #24
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty."
            "Use --overwrite_output_dir to overcome."
        )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Set the verbosity to info of the Transformers logger (on main process only):
    logger.info(f"Training/evaluation parameters {training_args}")

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(
            data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False
        )

        if "validation" not in dataset.keys():
            dataset["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
            )
            dataset["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
            )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        extension = data_args.train_file.split(".")[-1]
        if extension == "txt":
            extension = "text"
        dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained config and tokenizer
    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
        )
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
        )
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script."
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if training_args.do_train:
        column_names = dataset["train"].column_names
    else:
        column_names = dataset["validation"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    # Since this will be pickled (to avoid a _LazyModule error in the datasets Hasher),
    # force loading the logger before tokenize_function is defined.
    tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")

    def tokenize_function(examples):
        with CaptureLogger(tok_logger) as cl:
            output = tokenizer(examples[text_column_name])
        # CLM input could be much longer than block_size
        if "Token indices sequence length is longer than the" in cl.out:
            tok_logger.warning(
                "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
            )
        return output

    tokenized_datasets = dataset.map(
        tokenize_function,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    if data_args.block_size is None:
        block_size = tokenizer.model_max_length
        if block_size > config.max_position_embeddings:
            logger.warning(
                f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
                "Picking 1024 instead. You can change that default value by passing --block_size xxx."
            )
            block_size = 1024
    else:
        if data_args.block_size > tokenizer.model_max_length:
            logger.warning(
                f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model"
                f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
            )
        block_size = min(data_args.block_size, tokenizer.model_max_length)

    # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
    def group_texts(examples):
        # Concatenate all texts.
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
        # customize this part to your needs.
        if total_length >= block_size:
            total_length = (total_length // block_size) * block_size
        # Split by chunks of max_len.
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result

    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
    # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
    # to preprocess.
    #
    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map

    lm_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    if training_args.do_train:
        if "train" not in tokenized_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = lm_datasets["train"]
        if data_args.max_train_samples is not None:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))

    if training_args.do_eval:
        if "validation" not in tokenized_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = lm_datasets["validation"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )
    elif not has_tensorboard:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run `pip install tensorboard` to enable."
        )

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng = jax.random.split(rng)

    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
    steps_per_epoch = len(train_dataset) // train_batch_size
    total_train_steps = steps_per_epoch * num_epochs

    # TODO: weights should be initialized in pjitted fun, this won't work for REALLY large models
    # TODO: when loading from pre-trained model we need to make sure the vocab is divisible by num_partitions
    # GPT2's vocab is odd, we need to resize it for fine-tuning
    model = FlaxAutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
    )

    # Create learning rate schedule
    linear_decay_lr_schedule_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        training_args.num_train_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    optimizer = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
    )

    def get_initial_state(params):
        state = optimizer.init(params)
        return tuple(state), params

    # Get PartitionSpec for model params
    param_spec = set_partitions(unfreeze(model.params))

    # Get the PyTree for opt_state, we don't actually initialize the opt_state yet.
    params_shapes = jax.tree_map(lambda x: x.shape, model.params)
    state_shapes = jax.eval_shape(get_initial_state, params_shapes)

    # get PartitionSpec for opt_state, this is very specific to adamw
    # TODO: optax returns a different state for different optimizers; how can we handle this generically?
    # Or maybe we don't, since in our examples we just use adamw or adafactor.
    def get_opt_spec(x):
        if isinstance(x, dict):
            return param_spec
        return None

    opt_state_spec, param_spec = jax.tree_map(
        get_opt_spec, state_shapes, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState))
    )

    # pjit the get_initial_state function to shard params and init
    # optimizer state in sharded way
    p_get_initial_state = pjit(
        get_initial_state,
        in_axis_resources=None,
        out_axis_resources=(opt_state_spec, param_spec),
    )

    # hack: move the initial params to CPU to free up device memory
    # TODO: allow loading weights on CPU in pre-trained model
    model.params = jax.tree_map(lambda x: np.asarray(x), model.params)

    # mesh definition
    mesh_devices = np.array(jax.devices()).reshape(1, jax.local_device_count())

    # actually initialize the opt_state
    with mesh(mesh_devices, ("dp", "mp")):
        opt_state, params = p_get_initial_state(freeze(model.params))

    # cross-entropy with z loss
    def loss_fn(logits, labels, z_loss=0):
        shift_logits = logits[..., :-1, :]
        shift_labels = labels[..., 1:]

        shift_labels = onehot(shift_labels, shift_logits.shape[-1])

        shift_logits = shift_logits - jax.lax.stop_gradient(shift_logits.max(axis=-1, keepdims=True))
        log_z = jnp.log(jnp.sum(jnp.exp(shift_logits), axis=-1, keepdims=True))
        log_softmax = shift_logits - log_z
        loss = -jnp.sum(shift_labels * log_softmax, axis=-1)

        loss += (1e-4 * jnp.square(log_z.squeeze(-1))) * z_loss

        return loss.mean()
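
    # The z_loss term above penalizes the squared log-partition value log(Z),
    # nudging the softmax normalizer toward 1 for numerical stability; this is
    # the auxiliary z-loss used in the mesh-tensorflow / PaLM codebases.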

    # Define gradient update step fn
    # TODO: try to use TrainState instead of passing params and opt_state individually
    def train_step(params, opt_state, dropout_rng, batch, step):
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

        def compute_loss(params):
            labels = batch.pop("labels")
            logits = model(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
            loss = loss_fn(logits, labels, z_loss=1.0)
            return loss

        grad_fn = jax.value_and_grad(compute_loss)
        loss, grads = grad_fn(params)

        updates, new_opt_state = optimizer.update(grads, opt_state, params)
        new_params = optax.apply_updates(params, updates)

        metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(step)}
        return new_params, tuple(new_opt_state), new_dropout_rng, metrics, step + 1

    # Define eval fn
    def eval_step(input_ids, labels, params):
        logits = model(input_ids=input_ids, params=params, train=False)[0]
        loss = loss_fn(logits, labels)
        # metrics
        return {"loss": loss}

    p_train_step = pjit(
        train_step,
        in_axis_resources=(param_spec, opt_state_spec, None, None, None),
        out_axis_resources=(param_spec, opt_state_spec, None, None, None),
        donate_argnums=(0, 1),
    )

    p_eval_step = pjit(
        eval_step,
        in_axis_resources=(None, None, param_spec),
        out_axis_resources=None,
    )

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {num_epochs}")
    logger.info(f"  Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel & distributed) = {train_batch_size}")
    logger.info(f"  Total optimization steps = {total_train_steps}")

    train_time = 0
    train_metrics = []
    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
    global_step = 0
    # we are not doing 2D parallelism (yet!), this just does model parallelism
    with mesh(mesh_devices, ("dp", "mp")):
        for _ in epochs:
            # ======================== Training ================================
            train_start = time.time()

            # Create sampling rng
            rng, input_rng = jax.random.split(rng)

            # Generate an epoch by shuffling sampling indices from the train dataset
            train_metrics = []
            train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
            steps_per_epoch = len(train_dataset) // train_batch_size

            # train
            for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
                batch = next(train_loader)
                params, opt_state, dropout_rng, train_metric, global_step = p_train_step(
                    params,
                    opt_state,
                    dropout_rng,
                    batch,
                    global_step,
                )
                train_metrics.append(train_metric)

                cur_step = global_step

                if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                    # Save metrics
                    train_time += time.time() - train_start
                    if has_tensorboard and jax.process_index() == 0:
                        write_train_metric(summary_writer, train_metrics, train_time, cur_step)

                    epochs.write(
                        f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
                    )

                    train_metrics = []

                if cur_step % training_args.eval_steps == 0 and cur_step > 0:
                    # ======================== Evaluating ==============================
                    eval_metrics = []
                    eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
                    eval_steps = len(eval_dataset) // eval_batch_size

                    for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
                        batch = next(eval_loader)
                        metrics = p_eval_step(batch["input_ids"], batch["labels"], params)
                        eval_metrics.append(metrics)

                    # normalize eval metrics
                    eval_metrics = stack_forest(eval_metrics)
                    eval_metrics = jax.tree_map(jnp.mean, eval_metrics)

                    try:
                        eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
                    except OverflowError:
                        eval_metrics["perplexity"] = float("inf")

                    logger.info(
                        f"Step... ({cur_step} | Eval loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']}"
                    )

                if cur_step % training_args.save_steps == 0 and cur_step > 0:
                    # save a checkpoint every save_steps and push it to the hub
                    if jax.process_index() == 0:
                        params = jax.device_get(params)
                        model.save_pretrained(
                            training_args.output_dir,
                            params=params,
                            push_to_hub=training_args.push_to_hub,
                            commit_message=f"Saving weights and logs of step {cur_step}",
                        )
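The script assumes a create_learning_rate_fn helper that is not shown here. A plausible sketch, assuming the linear warmup plus linear decay that the name linear_decay_lr_schedule_fn suggests, built from documented optax schedules:

import optax


def create_learning_rate_fn(train_ds_size, train_batch_size, num_train_epochs,
                            num_warmup_steps, learning_rate):
    # Linear warmup to `learning_rate`, then linear decay to 0 over the
    # remaining training steps.
    steps_per_epoch = train_ds_size // train_batch_size
    num_train_steps = steps_per_epoch * num_train_epochs
    warmup_fn = optax.linear_schedule(
        init_value=0.0, end_value=learning_rate,
        transition_steps=num_warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=learning_rate, end_value=0.0,
        transition_steps=num_train_steps - num_warmup_steps)
    return optax.join_schedules(
        schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])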
Code example #25
    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Returns:

        Example:

        ```python
        >>> from transformers import ViTFeatureExtractor, FlaxVisionEncoderDecoderModel
        >>> import jax.numpy as jnp
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")

        >>> # initialize a vit-gpt2 from pretrained ViT and GPT2 models. Note that the cross-attention layers will be randomly initialized
        >>> model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
        ...     "google/vit-base-patch16-224-in21k", "gpt2"
        ... )

        >>> pixel_values = feature_extractor(images=image, return_tensors="np").pixel_values
        >>> encoder_outputs = model.encode(pixel_values)

        >>> decoder_start_token_id = model.config.decoder.bos_token_id
        >>> decoder_input_ids = jnp.ones((pixel_values.shape[0], 1), dtype="i4") * decoder_start_token_id

        >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
        >>> logits = outputs.logits
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        encoder_hidden_states = encoder_outputs[0]

        batch_size, sequence_length = encoder_hidden_states.shape[:2]
        encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        batch_size, sequence_length = decoder_input_ids.shape
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones((batch_size, sequence_length))

        if decoder_position_ids is None:
            if past_key_values is not None:
                raise ValueError(
                    "Make sure to provide `decoder_position_ids` when passing `past_key_values`."
                )

            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :],
                (batch_size, sequence_length))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # If past_key_values are passed, the cache has already been initialized, and a private flag,
        # init_cache, has to be passed down to ensure the cache is used. The cache must also be marked
        # as mutable so that it can be updated by the FlaxBartAttention module.
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask,
                             decoder_position_ids, encoder_hidden_states,
                             **kwargs):

            projection_module = module._get_projection_module()
            decoder_module = module._get_decoder_module()

            # optionally project encoder_hidden_states
            if projection_module is not None:
                encoder_hidden_states = projection_module(
                    encoder_hidden_states)

            return decoder_module(
                decoder_input_ids,
                decoder_attention_mask,
                decoder_position_ids,
                encoder_hidden_states,
                **kwargs,
            )

        outputs = self.module.apply(
            inputs,
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask,
                                             dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=jnp.array(encoder_attention_mask,
                                             dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            mutable=mutable,
            method=_decoder_forward,
        )

        # add updated cache to model output
        if past_key_values is not None and return_dict:
            outputs, past = outputs
            outputs["past_key_values"] = unfreeze(past["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs, past = outputs
            outputs = outputs[:1] + (unfreeze(past["cache"]), ) + outputs[1:]

        return outputs
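The mutable=["cache"] mechanics used above are plain Flax: when a collection is marked mutable, apply returns a second value holding the updated variables. A minimal sketch with a toy module (unrelated to the decoder) showing the same pattern:

import jax
import jax.numpy as jnp
import flax.linen as nn


class Counter(nn.Module):
    @nn.compact
    def __call__(self, x):
        # A non-parameter variable in the "cache" collection, analogous to the
        # decoder's key/value cache.
        count = self.variable("cache", "count", lambda: jnp.zeros((), jnp.int32))
        count.value = count.value + 1
        return x


module = Counter()
variables = module.init(jax.random.PRNGKey(0), jnp.ones(1))  # runs __call__ once
y, mutated = module.apply(variables, jnp.ones(1), mutable=["cache"])
print(mutated["cache"]["count"])  # 2 (incremented once at init, once at apply)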
Code example #26
    def from_pretrained(cls,
                        pretrained_model_name_or_path: Union[str, os.PathLike],
                        dtype: jnp.dtype = jnp.float32,
                        *model_args,
                        **kwargs):
        r"""
        Instantiate a pretrained flax model from a pre-trained model configuration.

        The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
        task.

        The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
        weights are discarded.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                Can be either:

                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                      Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
                      user or organization name, like `dbmdz/bert-base-german-cased`.
                    - A path to a *directory* containing model weights saved using
                      [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
                    - A path or url to a *PyTorch checkpoint file* (e.g., `./pt_model/pytorch_model.bin`). In this
                      case, `from_pt` should be set to `True`.
            dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
                The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
                `jax.numpy.bfloat16` (on TPUs).

                This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
                specified all the computation will be performed with the given `dtype`.

                **Note that this only specifies the dtype of the computation and does not influence the dtype of model
                parameters.**

                If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
                [`~FlaxPreTrainedModel.to_bf16`].
            model_args (sequence of positional arguments, *optional*):
                All remaining positional arguments will be passed to the underlying model's `__init__` method.
            config (`Union[PretrainedConfig, str, os.PathLike]`, *optional*):
                Can be either:

                    - an instance of a class derived from [`PretrainedConfig`],
                    - a string or path valid as input to [`~PretrainedConfig.from_pretrained`].

                Configuration for the model to use instead of an automatically loaded configuration. Configuration can
                be automatically loaded when:

                    - The model is a model provided by the library (loaded with the *model id* string of a pretrained
                      model).
                    - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
                      save directory.
                    - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
                      configuration JSON file named *config.json* is found in the directory.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            from_pt (`bool`, *optional*, defaults to `False`):
                Load the model weights from a PyTorch checkpoint save file (see docstring of
                `pretrained_model_name_or_path` argument).
            ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
                Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
                as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
                checkpoint with 3 labels).
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether or not to only look at local files (i.e., do not try to download the model).
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
                `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
                automatically loaded:

                    - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
                      underlying model's `__init__` method (we assume all relevant updates to the configuration have
                      already been done)
                    - If a configuration is not provided, `kwargs` will be first passed to the configuration class
                      initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
                      corresponds to a configuration attribute will be used to override said attribute with the
                      supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
                      will be passed to the underlying model's `__init__` function.

        Examples:

        ```python
        >>> from transformers import BertConfig, FlaxBertModel

        >>> # Download model and configuration from huggingface.co and cache.
        >>> model = FlaxBertModel.from_pretrained("bert-base-cased")
        >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
        >>> model = FlaxBertModel.from_pretrained("./test/saved_model/")
        >>> # Loading from a PyTorch checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
        >>> config = BertConfig.from_json_file("./pt_model/config.json")
        >>> model = FlaxBertModel.from_pretrained("./pt_model/pytorch_model.bin", from_pt=True, config=config)
        ```"""
        config = kwargs.pop("config", None)
        cache_dir = kwargs.pop("cache_dir", None)
        from_pt = kwargs.pop("from_pt", False)
        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        from_pipeline = kwargs.pop("_from_pipeline", None)
        from_auto_class = kwargs.pop("_from_auto", False)
        _do_init = kwargs.pop("_do_init", True)

        user_agent = {
            "file_type": "model",
            "framework": "flax",
            "from_auto_class": from_auto_class
        }
        if from_pipeline is not None:
            user_agent["using_pipeline"] = from_pipeline

        if is_offline_mode() and not local_files_only:
            logger.info("Offline mode: forcing local_files_only=True")
            local_files_only = True

        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                _from_auto=from_auto_class,
                _from_pipeline=from_pipeline,
                **kwargs,
            )
        else:
            model_kwargs = kwargs

        # Add the dtype to model_kwargs
        model_kwargs["dtype"] = dtype

        # Load model
        if pretrained_model_name_or_path is not None:
            if os.path.isdir(pretrained_model_name_or_path):
                if from_pt and os.path.isfile(
                        os.path.join(pretrained_model_name_or_path,
                                     WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path,
                                                WEIGHTS_NAME)
                elif os.path.isfile(
                        os.path.join(pretrained_model_name_or_path,
                                     FLAX_WEIGHTS_NAME)):
                    # Load from a Flax checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path,
                                                FLAX_WEIGHTS_NAME)
                # At this stage we don't have a weight file so we will raise an error.
                elif os.path.isfile(
                        os.path.join(pretrained_model_name_or_path,
                                     WEIGHTS_NAME)):
                    raise EnvironmentError(
                        f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
                        "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those "
                        "weights.")
                else:
                    raise EnvironmentError(
                        f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
                        f"{pretrained_model_name_or_path}.")
            elif os.path.isfile(
                    pretrained_model_name_or_path) or is_remote_url(
                        pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            else:
                filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME
                archive_file = hf_bucket_url(
                    pretrained_model_name_or_path,
                    filename=filename,
                    revision=revision,
                )

            # redirect to the cache, if necessary
            try:
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                    user_agent=user_agent,
                )

            except RepositoryNotFoundError:
                raise EnvironmentError(
                    f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
                    "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
                    "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
                    "login` and pass `use_auth_token=True`.")
            except RevisionNotFoundError:
                raise EnvironmentError(
                    f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
                    "this model name. Check the model page at "
                    f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
                )
            except EntryNotFoundError:
                if filename == FLAX_WEIGHTS_NAME:
                    has_file_kwargs = {
                        "revision": revision,
                        "proxies": proxies,
                        "use_auth_token": use_auth_token
                    }
                    if has_file(pretrained_model_name_or_path, WEIGHTS_NAME,
                                **has_file_kwargs):
                        raise EnvironmentError(
                            f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME} "
                            "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from "
                            "those weights.")
                    else:
                        raise EnvironmentError(
                            f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME} "
                            f"or {WEIGHTS_NAME}.")
                else:
                    raise EnvironmentError(
                        f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
                    )
            except HTTPError as err:
                raise EnvironmentError(
                    f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
                    f"{err}")
            except ValueError:
                raise EnvironmentError(
                    f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in the cached "
                    f"files and it looks like {pretrained_model_name_or_path} is not the path to a directory "
                    f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\n"
                    "Checkout your internet connection or see how to run the library in offline mode at "
                    "'https://huggingface.co/docs/transformers/installation#offline-mode'."
                )
            except EnvironmentError:
                raise EnvironmentError(
                    f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
                    "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
                    f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
                    f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
                )

            if resolved_archive_file == archive_file:
                logger.info(f"loading weights file {archive_file}")
            else:
                logger.info(
                    f"loading weights file {archive_file} from cache at {resolved_archive_file}"
                )
        else:
            resolved_archive_file = None

        # init random models
        model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)

        if from_pt:
            state = load_pytorch_checkpoint_in_flax_state_dict(
                model, resolved_archive_file)
        else:
            with open(resolved_archive_file, "rb") as state_f:
                try:
                    state = from_bytes(cls, state_f.read())
                except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
                    try:
                        with open(resolved_archive_file) as f:
                            if f.read().startswith("version"):
                                raise OSError(
                                    "You seem to have cloned a repository without having git-lfs installed. Please install "
                                    "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
                                    "you cloned.")
                            else:
                                raise ValueError from e
                    except (UnicodeDecodeError, ValueError):
                        raise EnvironmentError(
                            f"Unable to convert {archive_file} to Flax deserializable object. "
                        )
            # make sure all arrays are stored as jnp.arrays
            # NOTE: this works around a bug that will be fixed in Flax >= v0.3.4:
            # https://github.com/google/flax/issues/1261
            if _do_init:
                state = jax.tree_util.tree_map(jnp.array, state)
            else:
                # keep the params on CPU if we don't want to initialize
                state = jax.tree_util.tree_map(
                    lambda x: jax.device_put(x,
                                             jax.devices("cpu")[0]), state)

        # if the model is a base model, keep only the sub-dict under base_model_prefix
        if cls.base_model_prefix not in dict(
                model.params_shape_tree) and cls.base_model_prefix in state:
            state = state[cls.base_model_prefix]

        # if the model is a head model and we are loading weights from a base model,
        # we nest the loaded state under base_model_prefix
        if cls.base_model_prefix in dict(
                model.params_shape_tree
        ) and cls.base_model_prefix not in state:
            state = {cls.base_model_prefix: state}

        # flatten dicts
        state = flatten_dict(state)

        random_state = flatten_dict(
            unfreeze(model.params if _do_init else model.params_shape_tree))

        missing_keys = model.required_params - set(state.keys())
        unexpected_keys = set(state.keys()) - model.required_params

        if missing_keys and not _do_init:
            logger.warning(
                f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
                f"Make sure to call model.init_weights to initialize the missing weights."
            )
            cls._missing_keys = missing_keys

        # Mismatched keys contain tuples (key, shape1, shape2) of weights in the checkpoint
        # whose shape does not match the corresponding weight in the model.
        mismatched_keys = []
        for key in state.keys():
            if key in random_state and state[key].shape != random_state[
                    key].shape:
                if ignore_mismatched_sizes:
                    mismatched_keys.append(
                        (key, state[key].shape, random_state[key].shape))
                    state[key] = random_state[key]
                else:
                    raise ValueError(
                        f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
                        f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. "
                        "Using `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this "
                        "model.")

        # add missing keys as random parameters if we are initializing
        if missing_keys and _do_init:
            for missing_key in missing_keys:
                state[missing_key] = random_state[missing_key]

        # remove unexpected keys to not be saved again
        for unexpected_key in unexpected_keys:
            del state[unexpected_key]

        if len(unexpected_keys) > 0:
            logger.warning(
                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
                f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
                f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
                f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
                f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
                f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
            )
        else:
            logger.info(
                f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n"
            )

        if len(missing_keys) > 0:
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
                f"and are newly initialized: {missing_keys}\n"
                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        elif len(mismatched_keys) == 0:
            logger.info(
                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
                f"If your task is similar to the task the model of the checkpoint was trained on, "
                f"you can already use {model.__class__.__name__} for predictions without further training."
            )
        if len(mismatched_keys) > 0:
            mismatched_warning = "\n".join([
                f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                for key, shape1, shape2 in mismatched_keys
            ])
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
                f"and are newly initialized because the shapes did not match:\n{mismatched_warning}\n"
                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )

        # dictionary mapping parameter keys to their dtypes
        param_dtypes = jax.tree_util.tree_map(lambda x: x.dtype, state)
        # extract keys of parameters not in jnp.float32
        fp16_params = [
            k for k in param_dtypes if param_dtypes[k] == jnp.float16
        ]
        bf16_params = [
            k for k in param_dtypes if param_dtypes[k] == jnp.bfloat16
        ]

        # raise a warning if any of the parameters are not in jnp.float32
        if len(fp16_params) > 0:
            logger.warning(
                f"Some of the weights of {model.__class__.__name__} were initialized in float16 precision from "
                f"the model checkpoint at {pretrained_model_name_or_path}:\n{fp16_params}\n"
                "You should probably UPCAST the model weights to float32 if this was not intended. "
                "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
            )

        if len(bf16_params) > 0:
            logger.warning(
                f"Some of the weights of {model.__class__.__name__} were initialized in bfloat16 precision from "
                f"the model checkpoint at {pretrained_model_name_or_path}:\n{bf16_params}\n"
                "You should probably UPCAST the model weights to float32 if this was not intended. "
                "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
            )

        if _do_init:
            # set correct parameters
            model.params = unflatten_dict(state)
            return model
        else:
            return model, unflatten_dict(state)
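
As a quick illustration of the two return conventions above, here is a minimal usage sketch, assuming the Hugging Face transformers Flax API; the checkpoint name and input shape are illustrative:

import jax

from transformers import FlaxBertModel

# eager loading (_do_init=True, the default): parameters are materialized as
# jnp arrays and attached to the model
model = FlaxBertModel.from_pretrained("bert-base-uncased")

# lazy loading: parameters are kept on CPU and returned separately
model, params = FlaxBertModel.from_pretrained("bert-base-uncased", _do_init=False)

# any keys missing from the checkpoint can be filled in afterwards
params = model.init_weights(jax.random.PRNGKey(0), (1, 128), params=params)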
Code Example #27
    # Imports needed by this excerpt (presumably module-level in the original):
    import jax.numpy as jnp
    import optax
    import statsmodels.api as sm
    from flax.core.frozen_dict import freeze, unfreeze
    from jax.random import split
    from jax.tree_util import tree_map

    # `data` and `key` (a data matrix and a PRNG key) are assumed to be in scope;
    # `LogisticRegressor`, `make_fns_for_posterior` and `ffvb` come from the
    # surrounding project and are not defined in this excerpt.
    y = data[:, 0]
    X = data[:, 1:]  # assumption: the remaining columns are the features

    # fit a maximum-likelihood logistic regression as a warm start
    X = sm.add_constant(X)
    glm_binom = sm.GLM(y, X, family=sm.families.Binomial())
    results = glm_binom.fit()
    mu = jnp.array(results.params)

    model = LogisticRegressor()
    init_key, key = split(key)
    variables = model.init(init_key, X)
    output = model.apply(variables, X)

    learning_rate = 1e-3
    optimizer = optax.adam(learning_rate)

    # initialize the Flax dense kernel with the MLE coefficients
    variables = unfreeze(variables)
    variables['params']['Dense_0']['kernel'] = mu.reshape((-1, 1))
    variables = freeze(variables)

    alpha = 1.  # prior precision
    nfeatures = tree_map(lambda x: x.shape[0], variables)  # leading dim per leaf
    loglikelihood_fn, logprior_fn = make_fns_for_posterior(model.apply, alpha)

    # run Gaussian variational Bayes (Cholesky parameterization, per the name)
    lambda_best, avg_lower_bounds = ffvb.vb_gauss_chol(
        key,
        loglikelihood_fn,
        logprior_fn,
        (X, y),
        optimizer,
        variables,
        lower_triangular=None,
        num_samples=20,
        window_size=10,
    )  # the original excerpt is truncated here; any remaining kwargs are omitted
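
The helper make_fns_for_posterior is not shown in the excerpt. A plausible reconstruction, assuming a Bernoulli likelihood and an isotropic Gaussian prior with precision alpha (a sketch, not the original implementation):

import jax.numpy as jnp
from jax import nn
from jax.tree_util import tree_leaves

def make_fns_for_posterior(apply_fn, alpha):
    # hypothetical reconstruction; log densities are up to additive constants
    def loglikelihood_fn(params, X, y):
        logits = apply_fn(params, X).squeeze(-1)
        return jnp.sum(y * nn.log_sigmoid(logits) +
                       (1 - y) * nn.log_sigmoid(-logits))

    def logprior_fn(params):
        # isotropic Gaussian prior with precision alpha
        return -0.5 * alpha * sum(jnp.sum(leaf ** 2)
                                  for leaf in tree_leaves(params))

    return loglikelihood_fn, logprior_fn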
Code Example #28
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        params: dict = None,
        past_key_values: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        batch_size, sequence_length = input_ids.shape

        if position_ids is None:
            if past_key_values is not None:
                raise ValueError(
                    "Make sure to provide `position_ids` when passing `past_key_values`."
                )

            position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :],
                (batch_size, sequence_length))

        if attention_mask is None:
            attention_mask = jnp.ones((batch_size, sequence_length))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # If `past_key_values` are passed, the cache is already initialized and a
        # private flag, `init_cache`, has to be passed down to ensure the cache is
        # used. The cache must also be marked as mutable so that it can be updated
        # by the FlaxGPTJAttention module.
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        outputs = self.module.apply(
            inputs,
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            not train,  # deterministic
            False,  # init_cache
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
            mutable=mutable,
        )

        # add updated cache to model output
        if past_key_values is not None and return_dict:
            outputs, past_key_values = outputs
            outputs["past_key_values"] = unfreeze(past_key_values["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs, past_key_values = outputs
            outputs = outputs[:1] + (unfreeze(
                past_key_values["cache"]), ) + outputs[1:]

        return outputs
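
A minimal usage sketch of the cached path above, assuming FlaxGPTJForCausalLM from transformers; the checkpoint name and shapes are illustrative. The attention mask covers the full pre-allocated cache length, as in the GPT-2 variant below:

import jax.numpy as jnp

from transformers import FlaxGPTJForCausalLM

model = FlaxGPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")

batch_size, prompt_len, max_len = 1, 8, 16
input_ids = jnp.ones((batch_size, prompt_len), dtype="i4")
attention_mask = jnp.ones((batch_size, max_len), dtype="i4")
position_ids = jnp.arange(prompt_len)[None, :]

# pre-allocate the auto-regressive cache, then run one cached forward pass
past = model.init_cache(batch_size, max_len)
outputs = model(input_ids, attention_mask=attention_mask,
                position_ids=position_ids, past_key_values=past)
next_cache = outputs["past_key_values"]  # updated cache for the next step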
Code Example #29
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        params: dict = None,
        past_key_values: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        batch_size, sequence_length = input_ids.shape

        if position_ids is None:
            if past_key_values is not None and input_ids.shape[-1] == 1:
                # If `past_key_values` are passed and `input_ids` holds a single
                # new token, we are in cached auto-regressive generation, so
                # `position_ids` must be set to the current cache position.
                cache_shift = flatten_dict(
                    unfreeze(past_key_values))[self._attn_layer_name +
                                               ("cache_index", )]
                position_ids = jnp.broadcast_to(
                    jnp.arange(self.config.max_position_embeddings)[None, :],
                    (batch_size, self.config.max_position_embeddings),
                )
                position_ids = lax.dynamic_slice(position_ids,
                                                 (0, cache_shift),
                                                 (batch_size, 1))
            else:
                position_ids = jnp.broadcast_to(
                    jnp.arange(sequence_length)[None, :],
                    (batch_size, sequence_length))

        if attention_mask is None:
            # if past_key_values are passed we need to create an attention_mask of the same length as `cache_length`
            if past_key_values is not None:
                cache_length = flatten_dict(
                    unfreeze(past_key_values))[self._attn_layer_name +
                                               ("cached_key", )].shape[1]
            else:
                cache_length = sequence_length

            # Note that usually one would have to put 0's in the attention_mask for
            # x > input_ids.shape[-1] and x < cache_length. But since GPT2 uses a
            # causal mask, those positions are masked anyway. Thus we can create a
            # single static attention_mask here, which is more efficient for
            # compilation.
            attention_mask = jnp.ones((batch_size, cache_length))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # If `past_key_values` are passed, the cache is already initialized and a
        # private flag, `init_cache`, has to be passed down to ensure the cache is
        # used. The cache must also be marked as mutable so that it can be updated
        # by the FlaxGPT2Attention module.
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        outputs = self.module.apply(
            inputs,
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            not train,  # deterministic
            False,  # init_cache
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
            mutable=mutable,
        )

        # add updated cache to model output
        if past_key_values is not None and return_dict:
            outputs, past_key_values = outputs
            outputs["past_key_values"] = unfreeze(past_key_values["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs, past_key_values = outputs
            outputs = outputs[:1] + (unfreeze(
                past_key_values["cache"]), ) + outputs[1:]

        return outputs
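
The lax.dynamic_slice trick used for position_ids above can be seen in isolation in the following sketch (values are illustrative): building a static iota once and slicing out the current column keeps all shapes static, which is what makes the cached step cheap to compile.

import jax.numpy as jnp
from jax import lax

batch_size, max_position_embeddings, cache_index = 2, 16, 5
position_ids = jnp.broadcast_to(
    jnp.arange(max_position_embeddings)[None, :],
    (batch_size, max_position_embeddings))
step_ids = lax.dynamic_slice(position_ids, (0, cache_index), (batch_size, 1))
# step_ids == [[5], [5]]: the position of the single new token in each row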