Code Example #1
    def init_weights(self,
                     rng: jax.random.PRNGKey,
                     input_shape: Tuple,
                     params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)
        position_ids = jnp.broadcast_to(
            jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        random_params = self.module.init(rngs,
                                         input_ids,
                                         attention_mask,
                                         position_ids,
                                         return_dict=False)["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params
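
This init_weights pattern recurs throughout the examples below: build dummy inputs, split the RNG, call module.init, and merge the random weights into any keys a loaded checkpoint is missing. A minimal usage sketch follows; the model class, checkpoint name, and input shape are illustrative assumptions, not taken from the example above.

# Hedged usage sketch: re-run init_weights against existing weights so that any
# missing keys are filled with fresh random values. The model, checkpoint name,
# and shape are placeholders for illustration only.
import jax
from transformers import FlaxBertModel

model = FlaxBertModel.from_pretrained("bert-base-uncased")
merged_params = model.init_weights(
    jax.random.PRNGKey(0),      # split internally into "params" and "dropout" rngs
    input_shape=(1, 128),       # (batch_size, sequence_length) for the dummy inputs
    params=model.params,        # keys listed in model._missing_keys get random init
)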
Code Example #2
    def init_weights(self,
                     rng: jax.random.PRNGKey,
                     input_shape: Tuple,
                     params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        pixel_values = jnp.zeros(input_shape, dtype=self.dtype)

        params_rng, dropout_rng = jax.random.split(rng)
        dropout_rng, droppath_rng = jax.random.split(dropout_rng)
        rngs = {
            "params": params_rng,
            "dropout": dropout_rng,
            "droppath": droppath_rng
        }

        random_params = self.module.init(rngs, pixel_values,
                                         return_dict=False)["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params
Code Example #3
    def init_weights(self,
                     rng: jax.random.PRNGKey,
                     input_shape: Tuple,
                     params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        token_type_ids = jnp.zeros_like(input_ids)
        attention_mask = jnp.ones_like(input_ids)
        head_mask = jnp.ones(
            (self.config.num_hidden_layers, self.config.num_attention_heads))

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        random_params = self.module.init(rngs,
                                         input_ids,
                                         attention_mask,
                                         token_type_ids,
                                         head_mask,
                                         return_dict=False)["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params
Code Example #4
def set_partitions(in_dict):
    rules = _get_partition_rules()
    replace = _replacement_rules(rules)
    initd = {k: _unmatched for k in flatten_dict(in_dict)}
    result = {k: replace(k, v) for k, v in initd.items()}
    assert _unmatched not in result.values(), "Incomplete partition spec."
    return freeze(unflatten_dict(result))
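
set_partitions above leans on two helpers, _get_partition_rules and _replacement_rules, that the snippet does not reproduce. The sketch below shows one plausible shape for them; the regex patterns, the "mp" axis name, and the use of jax.sharding.PartitionSpec are assumptions for illustration, not the original project's rules.

# Hedged sketch of the helpers set_partitions expects; the patterns and sharding
# axes below are illustrative only.
import re
from jax.sharding import PartitionSpec as P

_unmatched = object()

def _get_partition_rules():
    # Pairs of (regex over the joined parameter path, PartitionSpec or None).
    return [
        (r"embedding", P("mp", None)),
        (r"(attn|mlp)/.*kernel", P(None, "mp")),
        (r"(bias|scale)", None),
    ]

def _replacement_rules(rules):
    def replace(key, val):
        path = "/".join(key)
        for pattern, spec in rules:
            if re.search(pattern, path):
                return spec
        return val  # stays _unmatched and trips the assert in set_partitions
    return replace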
Code Example #5
    def init_weights(self,
                     rng: jax.random.PRNGKey,
                     input_shape: Tuple,
                     params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)

        batch_size, sequence_length = input_ids.shape
        position_ids = jnp.broadcast_to(
            jnp.arange(sequence_length)[None, :],
            (batch_size, sequence_length))

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        module_init_outputs = self.module.init(
            rngs,
            input_ids,
            attention_mask,
            position_ids,
            return_dict=False,
        )

        random_params = module_init_outputs["params"]
        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params
Code Example #6
    def params(self, params: Union[Dict, FrozenDict]):
        if isinstance(params, FrozenDict):
            params = unfreeze(params)
        param_keys = set(flatten_dict(params).keys())
        if len(self.required_params - param_keys) > 0:
            raise ValueError(
                "Some parameters are missing. Make sure that `params` include the following "
                f"parameters {self.required_params - param_keys}")
        self._params = freeze(params)
Code Example #7
def sparse_init(loss_fn,
                flax_module,
                params,
                hps,
                input_shape,
                output_shape,
                rng_key,
                metrics_logger=None,
                log_every=10):
    """Implements SparseInit initializer.

  Args:
    loss_fn: Loss function.
    flax_module: Flax nn.Module class.
    params: The dict of model parameters.
    hps: HParam object. Required hparams are meta_learning_rate,
      meta_batch_size, meta_steps, and epsilon.
    input_shape: Must agree with batch[0].shape[1:].
    output_shape: Must agree with batch[1].shape[1:].
    rng_key: jax.PRNGKey, used to seed all randomness.
    metrics_logger: Instance of utils.MetricsLogger
    log_every: Print meta loss every k steps.

  Returns:
    A Flax model with sparse initialization.
  """

    del flax_module, loss_fn, input_shape, output_shape, rng_key, metrics_logger, log_every

    params = unfreeze(params)
    activation_functions = hps.activation_function
    num_hidden_layers = len(hps.hid_sizes)
    if isinstance(hps.activation_function, str):
        activation_functions = [hps.activation_function] * num_hidden_layers
    for i, key in enumerate(params):
        num_units, num_weights = params[key]['kernel'].shape
        mask = np.zeros((num_units, num_weights), dtype=bool)
        for k in range(num_units):
            if num_weights >= hps.non_zero_connection_weights:
                sample = np.random.choice(num_weights,
                                          hps.non_zero_connection_weights,
                                          replace=False)
            else:
                sample = np.random.choice(num_weights,
                                          hps.non_zero_connection_weights)
            mask[k, sample] = True
        params[key]['kernel'] = params[key]['kernel'].at[~mask].set(0.0)
        if i < num_hidden_layers and activation_functions[i] == 'tanh':
            params[key]['bias'] = params[key]['bias'].at[:].set(0.5)
        else:
            params[key]['bias'] = params[key]['bias'].at[:].set(0.0)
    return frozen_dict.freeze(params)
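
sparse_init above only reads three fields from hps and rewrites the 'kernel' and 'bias' leaves of a two-level parameter dict, so it can be exercised on a toy tree. The hparams container and the tiny MLP layout below are made-up placeholders, not the original project's HParams object.

# Hedged usage sketch for sparse_init on a toy parameter tree; the hparams
# object and layer shapes are invented for illustration.
import types
import jax.numpy as jnp
from flax.core import frozen_dict

hps = types.SimpleNamespace(
    activation_function='tanh',       # a str (broadcast) or one entry per hidden layer
    hid_sizes=[8, 8],                 # two hidden layers
    non_zero_connection_weights=3,    # connections kept per unit
)
toy_params = frozen_dict.freeze({
    'Dense_0': {'kernel': jnp.ones((4, 8)), 'bias': jnp.zeros(8)},
    'Dense_1': {'kernel': jnp.ones((8, 8)), 'bias': jnp.zeros(8)},
    'Dense_2': {'kernel': jnp.ones((8, 1)), 'bias': jnp.zeros(1)},
})
sparse_params = sparse_init(loss_fn=None, flax_module=None, params=toy_params,
                            hps=hps, input_shape=None, output_shape=None,
                            rng_key=None)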
Code Example #8
    def init_weights(self,
                     rng: jax.random.PRNGKey,
                     input_shape: Tuple,
                     params: FrozenDict = None) -> FrozenDict:
        encoder_input_shape, decoder_input_shape = input_shape

        # init input tensors
        input_ids = jnp.zeros(encoder_input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)
        decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4")
        decoder_attention_mask = jnp.ones_like(decoder_input_ids)

        batch_size, sequence_length = input_ids.shape
        position_ids = jnp.broadcast_to(
            jnp.arange(sequence_length)[None, :],
            (batch_size, sequence_length))

        decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape
        if not decoder_batch_size == batch_size:
            raise ValueError(
                f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder and {decoder_batch_size} for decoder."
            )
        decoder_position_ids = jnp.broadcast_to(
            jnp.arange(decoder_sequence_length)[None, :],
            (decoder_batch_size, decoder_sequence_length))

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        random_params = self.module.init(
            rngs,
            input_ids,
            attention_mask,
            decoder_input_ids,
            decoder_attention_mask,
            position_ids,
            decoder_position_ids,
        )["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params
Code Example #9
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty."
            "Use --overwrite_output_dir to overcome."
        )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Set the verbosity to info of the Transformers logger (on main process only):
    logger.info(f"Training/evaluation parameters {training_args}")

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(
            data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False
        )

        if "validation" not in dataset.keys():
            dataset["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
            )
            dataset["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
            )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        extension = data_args.train_file.split(".")[-1]
        if extension == "txt":
            extension = "text"
        dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained config and tokenizer
    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
        )
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
        )
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script."
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if training_args.do_train:
        column_names = dataset["train"].column_names
    else:
        column_names = dataset["validation"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
    tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")

    def tokenize_function(examples):
        with CaptureLogger(tok_logger) as cl:
            output = tokenizer(examples[text_column_name])
        # clm input could be much much longer than block_size
        if "Token indices sequence length is longer than the" in cl.out:
            tok_logger.warning(
                "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
            )
        return output

    tokenized_datasets = dataset.map(
        tokenize_function,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    if data_args.block_size is None:
        block_size = tokenizer.model_max_length
        if block_size > config.max_position_embeddings:
            logger.warning(
                f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
                "Picking 1024 instead. You can change that default value by passing --block_size xxx."
            )
            block_size = 1024
    else:
        if data_args.block_size > tokenizer.model_max_length:
            logger.warning(
                f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model"
                f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
            )
        block_size = min(data_args.block_size, tokenizer.model_max_length)

    # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
    def group_texts(examples):
        # Concatenate all texts.
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
        # customize this part to your needs.
        if total_length >= block_size:
            total_length = (total_length // block_size) * block_size
        # Split by chunks of max_len.
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result

    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
    # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
    # to preprocess.
    #
    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map

    lm_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    if training_args.do_train:
        if "train" not in tokenized_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = lm_datasets["train"]
        if data_args.max_train_samples is not None:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))

    if training_args.do_eval:
        if "validation" not in tokenized_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = lm_datasets["validation"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )
    else:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable."
        )

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng = jax.random.split(rng)

    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
    steps_per_epoch = len(train_dataset) // train_batch_size
    total_train_steps = steps_per_epoch * num_epochs

    # TODO: weights should be initialized in pjitted fun, this won't work for REALLY large models
    # TODO: when loading from pre-trained model we need to make sure the vocab is divisible by num_partitions
    # GPT2's vocab is odd, we need to resize it for fine-tuning
    model = FlaxAutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
    )

    # Create learning rate schedule
    linear_decay_lr_schedule_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        training_args.num_train_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    optimizer = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
    )

    def get_initial_state(params):
        state = optimizer.init(params)
        return tuple(state), params

    # Get PartitionSpec for model params
    param_spec = set_partitions(unfreeze(model.params))

    # Get the PyTree for opt_state, we don't actually initialize the opt_state yet.
    params_shapes = jax.tree_map(lambda x: x.shape, model.params)
    state_shapes = jax.eval_shape(get_initial_state, params_shapes)

    # get PartitionSpec for opt_state, this is very specific to adamw
    # TODO: optax returns different state for different optimizers, how can we handle this generically ?
    # or maybe we don't since in our examples we just use adamw or adafactor
    def get_opt_spec(x):
        if isinstance(x, dict):
            return param_spec
        return None

    opt_state_spec, param_spec = jax.tree_map(
        get_opt_spec, state_shapes, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState))
    )

    # pjit the get_initial_state function to shard params and init
    # optimizer state in sharded way
    p_get_initial_state = pjit(
        get_initial_state,
        in_axis_resources=None,
        out_axis_resources=(opt_state_spec, param_spec),
    )

    # hack: move the initial params to CPU to free up device memory
    # TODO: allow loading weights on CPU in pre-trained model
    model.params = jax.tree_map(lambda x: np.asarray(x), model.params)

    # mesh definition
    mesh_devices = np.array(jax.devices()).reshape(1, jax.local_device_count())

    # actually initialize the opt_state
    with mesh(mesh_devices, ("dp", "mp")):
        opt_state, params = p_get_initial_state(freeze(model.params))

    # cross-entropy with z loss
    def loss_fn(logits, labels, z_loss=0):
        shift_logits = logits[..., :-1, :]
        shift_labels = labels[..., 1:]

        shift_labels = onehot(shift_labels, shift_logits.shape[-1])

        shift_logits = shift_logits - jax.lax.stop_gradient(shift_logits.max(axis=-1, keepdims=True))
        log_z = jnp.log(jnp.sum(jnp.exp(shift_logits), axis=-1, keepdims=True))
        log_softmax = shift_logits - log_z
        loss = -jnp.sum(shift_labels * log_softmax, axis=-1)

        loss += (1e-4 * jnp.square(log_z.squeeze(-1))) * z_loss

        return loss.mean()

    # Define gradient update step fn
    # TODO: try to use TrainState instead of passing params and opt_state individually
    def train_step(params, opt_state, dropout_rng, batch, step):
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

        def compute_loss(params):
            labels = batch.pop("labels")
            logits = model(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
            loss = loss_fn(logits, labels, z_loss=1.0)
            return loss

        grad_fn = jax.value_and_grad(compute_loss)
        loss, grads = grad_fn(params)

        updates, new_opt_state = optimizer.update(grads, opt_state, params)
        new_params = optax.apply_updates(params, updates)

        metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(step)}
        return new_params, tuple(new_opt_state), new_dropout_rng, metrics, step + 1

    # Define eval fn
    def eval_step(input_ids, labels, params):
        logits = model(input_ids=input_ids, params=params, train=False)[0]
        loss = loss_fn(logits, labels)
        # metrics
        return {"loss": loss}

    p_train_step = pjit(
        train_step,
        in_axis_resources=(param_spec, opt_state_spec, None, None, None),
        out_axis_resources=(param_spec, opt_state_spec, None, None, None),
        donate_argnums=(0, 1),
    )

    p_eval_step = pjit(
        eval_step,
        in_axis_resources=(None, None, param_spec),
        out_axis_resources=None,
    )

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {num_epochs}")
    logger.info(f"  Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel & distributed) = {train_batch_size}")
    logger.info(f"  Total optimization steps = {total_train_steps}")

    train_time = 0
    train_metrics = []
    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
    global_step = 0
    # we are not doing 2D parallelism (yet!), this just does model parallelism
    with mesh(mesh_devices, ("dp", "mp")):
        for _ in epochs:
            # ======================== Training ================================
            train_start = time.time()

            # Create sampling rng
            rng, input_rng = jax.random.split(rng)

            # Generate an epoch by shuffling sampling indices from the train dataset
            train_metrics = []
            train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
            steps_per_epoch = len(train_dataset) // train_batch_size

            # train
            for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
                batch = next(train_loader)
                params, opt_state, dropout_rng, train_metric, global_step = p_train_step(
                    params,
                    opt_state,
                    dropout_rng,
                    batch,
                    global_step,
                )
                train_metrics.append(train_metric)

                cur_step = global_step

                if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                    # Save metrics
                    train_time += time.time() - train_start
                    if has_tensorboard and jax.process_index() == 0:
                        write_train_metric(summary_writer, train_metrics, train_time, cur_step)

                    epochs.write(
                        f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
                    )

                    train_metrics = []

                if cur_step % training_args.eval_steps == 0 and cur_step > 0:
                    # ======================== Evaluating ==============================
                    eval_metrics = []
                    eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
                    eval_steps = len(eval_dataset) // eval_batch_size

                    for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
                        batch = next(eval_loader)
                        metrics = p_eval_step(batch["input_ids"], batch["labels"], params)
                        eval_metrics.append(metrics)

                    # normalize eval metrics
                    eval_metrics = stack_forest(eval_metrics)
                    eval_metrics = jax.tree_map(jnp.mean, eval_metrics)

                    try:
                        eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
                    except OverflowError:
                        eval_metrics["perplexity"] = float("inf")

                    logger.info(
                        f"Step... ({cur_step} | Eval loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']}"
                    )

                if cur_step % training_args.save_steps == 0 and cur_step > 0:
                    # save checkpoint after each epoch and push checkpoint to the hub
                    if jax.process_index() == 0:
                        params = jax.device_get(params)
                        model.save_pretrained(
                            training_args.output_dir,
                            params=params,
                            push_to_hub=training_args.push_to_hub,
                            commit_message=f"Saving weights and logs of step {cur_step}",
                        )
Code Example #10
def main_opt(N, l, i0, nn_arq, act_fun, n_epochs, lr, w_decay, rho_g):

    start_time = time.time()

    str_nn_arq = ''
    for item in nn_arq:
        str_nn_arq = str_nn_arq + '_{}'.format(item)

    f_job = 'nn_arq{}_N_{}_i0_{}_l_{}_batch'.format(str_nn_arq, N, i0, l)
    f_out = '{}/out_opt_{}.txt'.format(r_dir, f_job)
    f_w_nn = '{}/W_{}.npy'.format(r_dir, f_job)
    file_results = '{}/data_nh3_{}.npy'.format(r_dir, f_job)

    #     --------------------------------------
    #     Data
    n_atoms = 4
    batch_size = 768  #1024#768#512#256#128#64#32
    Dtr, Dt = load_data(file_results, N, l)
    Xtr, gXtr, gXctr, ytr = Dtr
    Xt, gXt, gXct, yt = Dt
    print(Xtr.shape, gXtr.shape, gXctr.shape, ytr.shape)
    # --------------------------------
    #     BATCHES

    n_complete_batches, leftover = divmod(N, batch_size)
    n_batches = n_complete_batches + bool(leftover)

    def data_stream():
        rng = onpr.RandomState(0)
        while True:
            perm = rng.permutation(N)
            for i in range(n_batches):
                batch_idx = perm[i * batch_size:(i + 1) * batch_size]
                yield Xtr[batch_idx], gXtr[batch_idx], gXctr[batch_idx], ytr[
                    batch_idx]

    batches = data_stream()
    # --------------------------------

    f = open(f_out, 'a+')
    print('-----------------------------------', file=f)
    print('Starting time', file=f)
    print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M"), file=f)
    print('-----------------------------------', file=f)
    print(f_out, file=f)
    print('N = {}, n_atoms = {}, data_random = {}, NN_random = {}'.format(
        N, n_atoms, l, i0),
          file=f)
    print(nn_arq, file=f)
    print('lr = {}, w decay = {}'.format(lr, w_decay), file=f)
    print('Activation function = {}'.format(act_fun), file=f)
    print('N Epoch = {}'.format(n_epochs), file=f)
    print('rho G = {}'.format(rho_g), file=f)
    print('-----------------------------------', file=f)
    f.close()

    #     --------------------------------------
    #     initialize NN

    nn_arq.append(3)
    tuple_nn_arq = tuple(nn_arq)
    nn_model = NN_adiab(n_atoms, tuple_nn_arq)

    def get_init_NN_params(key):
        x = Xtr[0, :]
        x = x[None, :]  #         x = jnp.ones((1,Xtr.shape[1]))
        variables = nn_model.init(key, x)
        return variables

#     Initialize parameters

    rng = random.PRNGKey(i0)
    rng, subkey = jax.random.split(rng)
    params = get_init_NN_params(subkey)

    f = open(f_out, 'a+')
    if os.path.isfile(f_w_nn):
        print('Reading NN parameters from prev calculation!', file=f)
        print('-----------------------', file=f)

        nn_dic = jnp.load(f_w_nn, allow_pickle=True)
        params = unfreeze(params)
        params['params'] = nn_dic.item()['params']
        params = freeze(params)
    f.close()
    init_params = params

    #     --------------------------------------
    #     Phys functions

    @jit
    def nn_adiab(params, x):
        y_ad_pred = nn_model.apply(params, x)
        return y_ad_pred

    @jit
    def jac_nn_adiab(params, x):
        g_y_pred = jacrev(nn_adiab, argnums=1)(params, x[None, :])
        return jnp.reshape(g_y_pred, (2, g_y_pred.shape[-1]))

    '''
#     WRONG
    @jit
    def f_nac_coup_i(gH_diab,eigvect_): #for a single cartesian dimension
        temp = jnp.dot(gH_diab,eigvect_[:,0])
        return jnp.vdot(eigvect_[:,1],temp)
    @jit
    def f_nac_coup(params,x):
        eigval_, eigvect_ = f_adiab(params,x)
        gy_diab = jac_nn_diab(params,x)
        gy_diab = jnp.reshape(gy_diab.T,(12,2,2))
        g_coup = vmap(f_nac_coup_i,(0,None))(gy_diab,eigvect_)
        return g_coup
    '''

    #     --------------------------------------
    #     Validation loss functions

    @jit
    def f_validation(params):
        y_pred = nn_adiab(params, Xt)
        diff_y = y_pred - yt
        z = jnp.linalg.norm(diff_y)
        return z

    @jit
    def f_jac_validation(params):
        gX_pred = vmap(jac_nn_adiab, (None, 0))(params, Xt)
        diff_y = gX_pred - gXt
        z = jnp.linalg.norm(diff_y)
        return z

    '''
    @jit
    def f_nac_validation(params):
        g_nac_coup = vmap(f_nac_coup,(None,0))(params,Xt)
        diff_y = g_nac_coup - gXct
        z = jnp.linalg.norm(diff_y)
        return z 
    '''
    #     --------------------------------------
    #    training loss functions
    @jit
    def f_loss_ad_energy(params, batch):
        X_inputs, _, _, y_true = batch
        y_pred = nn_adiab(params, X_inputs)
        diff_y = y_pred - y_true  #Ha2cm*
        loss = jnp.linalg.norm(diff_y)
        return loss

    @jit
    def f_loss_jac(params, batch):
        X_inputs, gX_inputs, _, y_true = batch
        gX_pred = vmap(jac_nn_adiab, (None, 0))(params, X_inputs)
        diff_g_X = gX_pred - gX_inputs
        return jnp.linalg.norm(diff_g_X)

    '''    
    @jit
    def f_loss_nac(params,batch):
        X_inputs, _,gXc_inputs,y_true = batch
        g_nac_coup = vmap(f_nac_coup,(None,0))(params,x)
        diff_y = g_nac_coup - gXc_inputs
        z = jnp.linalg.norm(diff_y)
        return z 
    '''
    #     ------
    @jit
    def f_loss(params, batch):
        loss_ad_energy = f_loss_ad_energy(params, batch)
        #         loss_jac_energy = f_loss_jac(params,batch)
        loss = loss_ad_energy  #+ rho_g*loss_jac_energy
        return loss


#     --------------------------------------
#     Optimization  and Training

#     Perform a single training step.

    @jit
    def train_step(optimizer, batch):  #, learning_rate_fn, model
        grad_fn = jax.value_and_grad(f_loss)
        loss, grad = grad_fn(optimizer.target, batch)
        optimizer = optimizer.apply_gradient(grad)  #, {"learning_rate": lr}
        return optimizer, loss

    optimizer = optim.Adam(learning_rate=lr,
                           weight_decay=w_decay).create(init_params)
    optimizer = jax.device_put(optimizer)

    loss0 = 1E16
    loss0_tot = 1E16
    itercount = itertools.count()
    f_params = init_params
    for epoch in range(n_epochs):
        for _ in range(n_batches):
            optimizer, loss = train_step(optimizer, next(batches))

        params = optimizer.target
        loss_tot = f_validation(params)

        if epoch % 10 == 0:
            f = open(f_out, 'a+')
            print(epoch, loss, loss_tot, file=f)
            f.close()

        if loss < loss0:
            loss0 = loss
            f = open(f_out, 'a+')
            print(epoch, loss, loss_tot, file=f)
            f.close()

        if loss_tot < loss0_tot:
            loss0_tot = loss_tot
            f_params = params
            dict_output = serialization.to_state_dict(params)
            jnp.save(f_w_nn, dict_output)  #unfreeze()

    f = open(f_out, 'a+')
    print('---------------------------------', file=f)
    print('Training time =  %.6f seconds ---' % ((time.time() - start_time)),
          file=f)
    print('---------------------------------', file=f)
    f.close()

    #     --------------------------------------
    #     Prediction
    f = open(f_out, 'a+')
    print('Prediction of the entire data set', file=f)
    print('N = {}, n_atoms = {}, random = {}'.format(N, n_atoms, i0), file=f)
    print('NN : {}'.format(nn_arq), file=f)
    print('lr = {}, w decay = {}, rho G = {}'.format(lr, w_decay, rho_g),
          file=f)
    print('Activation function = {}'.format(act_fun), file=f)
    print('Total points  = {}'.format(yt.shape[0]), file=f)

    y_pred = nn_adiab(f_params, Xt)
    gX_pred = vmap(jac_nn_adiab, (None, 0))(f_params, Xt)

    diff_y = y_pred - yt
    rmse_Ha = jnp.linalg.norm(diff_y)
    rmse_cm = jnp.linalg.norm(Ha2cm * diff_y)
    mae_Ha = jnp.linalg.norm(diff_y, ord=1)
    mae_cm = jnp.linalg.norm(Ha2cm * diff_y, ord=1)

    print('RMSE = {} [Ha]'.format(rmse_Ha), file=f)
    print('RMSE(tr) = {} [cm-1]'.format(loss0), file=f)
    print('RMSE = {} [cm-1]'.format(rmse_cm), file=f)
    print('MAE = {} [Ha]'.format(mae_Ha), file=f)
    print('MAE = {} [cm-1]'.format(mae_cm), file=f)

    Dpred = jnp.column_stack((Xt, y_pred))
    data_dic = {
        'Dtr': Dtr,
        'Dpred': Dpred,
        'gXpred': gX_pred,
        'loss_tr': loss0,
        'error_full': rmse_cm,
        'N': N,
        'l': l,
        'i0': i0,
        'rho_g': rho_g
    }

    jnp.save(file_results, data_dic)

    print('---------------------------------', file=f)
    print('Total time =  %.6f seconds ---' % ((time.time() - start_time)),
          file=f)
    print('---------------------------------', file=f)
    f.close()
Code Example #11
def main_opt(N, l, i0, nn_arq, act_fun, n_epochs, lr, w_decay, rho_g):

    start_time = time.time()

    str_nn_arq = ''
    for item in nn_arq:
        str_nn_arq = str_nn_arq + '_{}'.format(item)

    f_job = 'nn_arq{}_N_{}_i0_{}_l_{}_batch'.format(str_nn_arq, N, i0, l)
    f_out = '{}/out_opt_{}.txt'.format(r_dir, f_job)
    f_w_nn = '{}/W_{}.npy'.format(r_dir, f_job)
    file_results = '{}/data_nh3_{}.npy'.format(r_dir, f_job)

    #     --------------------------------------
    #     Data
    n_atoms = 4
    batch_size = 768  #1024#768#512#256#128#64#32
    Dtr, Dval, Dt = load_data(file_results, N, l)
    Xtr, gXtr, gXctr, ytr = Dtr
    Xval, gXval, gXcval, yval = Dval
    Xt, gXt, gXct, yt = Dt
    print(Xtr.shape, gXtr.shape, gXctr.shape, ytr.shape)
    # --------------------------------
    #     BATCHES

    n_complete_batches, leftover = divmod(N, batch_size)
    n_batches = n_complete_batches + bool(leftover)

    def data_stream():
        rng = onpr.RandomState(0)
        while True:
            perm = rng.permutation(N)
            for i in range(n_batches):
                batch_idx = perm[i * batch_size:(i + 1) * batch_size]
                yield Xtr[batch_idx], gXtr[batch_idx], gXctr[batch_idx], ytr[
                    batch_idx]

    batches = data_stream()
    # --------------------------------

    f = open(f_out, 'a+')
    print('-----------------------------------', file=f)
    print('Starting time', file=f)
    print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M"), file=f)
    print('-----------------------------------', file=f)
    print(f_out, file=f)
    print('N = {}, n_atoms = {}, data_random = {}, NN_random = {}'.format(
        N, n_atoms, l, i0),
          file=f)
    print(nn_arq, file=f)
    print('lr = {}, w decay = {}'.format(lr, w_decay), file=f)
    print('Activation function = {}'.format(act_fun), file=f)
    print('N Epoch = {}'.format(n_epochs), file=f)
    print('rho G = {}'.format(rho_g), file=f)
    print('-----------------------------------', file=f)
    f.close()

    #     --------------------------------------
    #     initialize NN

    nn_arq.append(3)
    tuple_nn_arq = tuple(nn_arq)
    nn_model = NN_adiab(n_atoms, tuple_nn_arq)

    def get_init_NN_params(key):
        x = Xtr[0, :]
        x = x[None, :]  #         x = jnp.ones((1,Xtr.shape[1]))
        variables = nn_model.init(key, x)
        return variables

#     Initialize parameters

    rng = random.PRNGKey(i0)
    rng, subkey = jax.random.split(rng)
    params = get_init_NN_params(subkey)

    f = open(f_out, 'a+')
    if os.path.isfile(f_w_nn):
        print('Reading NN parameters from prev calculation!', file=f)
        print('-----------------------', file=f)

        nn_dic = jnp.load(f_w_nn, allow_pickle=True)
        params = unfreeze(params)
        params['params'] = nn_dic.item()['params']
        params = freeze(params)
#         print(params)

    f.close()
    init_params = params

    #     --------------------------------------
    #     Phys functions

    @jit
    def nn_adiab(params, x):
        y_ad_pred = nn_model.apply(params, x)
        return y_ad_pred

    @jit
    def jac_nn_adiab(params, x):
        g_y_pred = jacrev(nn_adiab, argnums=1)(params, x[None, :])
        return jnp.reshape(g_y_pred, (2, g_y_pred.shape[-1]))

#     --------------------------------------
#    training loss functions

    @jit
    def f_loss_ad_energy(params, batch):
        X_inputs, _, _, y_true = batch
        y_pred = nn_adiab(params, X_inputs)
        diff_y = y_pred - y_true  #Ha2cm*
        return jnp.linalg.norm(diff_y, axis=0)

    @jit
    def f_loss_jac(params, batch):
        X_inputs, gX_inputs, _, y_true = batch
        gX_pred = vmap(jac_nn_adiab, (None, 0))(params, X_inputs)
        diff_g_X = gX_pred - gX_inputs
        # jnp.linalg.norm(diff_g_X,axis=0)

        diff_g_X0 = diff_g_X[:, 0, :]
        diff_g_X1 = diff_g_X[:, 1, :]
        l0 = jnp.linalg.norm(diff_g_X0)
        l1 = jnp.linalg.norm(diff_g_X1)
        return jnp.stack([l0, l1])

#     ------

    @jit
    def f_loss(params, rho_g, batch):
        rho_g = jnp.exp(rho_g)
        loss_ad_energy = f_loss_ad_energy(params, batch)
        loss_jac_energy = f_loss_jac(params, batch)
        loss = jnp.vdot(jnp.ones_like(loss_ad_energy),
                        loss_ad_energy) + jnp.vdot(rho_g, loss_jac_energy)
        return loss
#     --------------------------------------
#     Optimization  and Training

#     Perform a single training step.

    @jit
    def train_step(optimizer, rho_g, batch):  #, learning_rate_fn, model
        grad_fn = jax.value_and_grad(f_loss)
        loss, grad = grad_fn(optimizer.target, rho_g, batch)
        optimizer = optimizer.apply_gradient(grad)  #, {"learning_rate": lr}
        return optimizer, (loss, grad)

#     @jit

    def train(rho_g, nn_params):
        optimizer = optim.Adam(learning_rate=lr,
                               weight_decay=w_decay).create(nn_params)
        optimizer = jax.device_put(optimizer)

        train_loss = []
        loss0 = 1E16
        loss0_tot = 1E16
        itercount = itertools.count()
        f_params = init_params
        for epoch in range(n_epochs):
            for _ in range(n_batches):
                optimizer, loss_and_grad = train_step(optimizer, rho_g,
                                                      next(batches))
                loss, grad = loss_and_grad

#             f = open(f_out,'a+')
#             print(i,loss,file=f)
#             f.close()

            train_loss.append(loss)
#             params = optimizer.target
#             loss_tot = f_validation(params)

        nn_params = optimizer.target

        return nn_params, loss_and_grad, train_loss

    @jit
    def val_step(optimizer, nn_params):  #, learning_rate_fn, model

        rho_g_prev = optimizer.target
        nn_params, loss_and_grad_train, train_loss_iter = train(
            rho_g_prev, nn_params)
        loss_train, grad_loss_train = loss_and_grad_train

        grad_fn_val = jax.value_and_grad(f_loss, argnums=1)
        loss_val, grad_val = grad_fn_val(nn_params, optimizer.target, Dval)
        optimizer = optimizer.apply_gradient(
            grad_val)  #, {"learning_rate": lr}
        return optimizer, nn_params, (loss_val, loss_train,
                                      train_loss_iter), (grad_loss_train,
                                                         grad_val)

#     Initialize rho_G

    rng = random.PRNGKey(0)
    rng, subkey = jax.random.split(rng)

    rho_G0 = random.uniform(subkey, shape=(2, ), minval=5E-4, maxval=0.025)
    rho_G0 = jnp.log(rho_G0)
    print('Initial lambdas', rho_G0)
    init_G = rho_G0  #

    optimizer_out = optim.Adam(learning_rate=2E-4,
                               weight_decay=0.).create(init_G)
    optimizer_out = jax.device_put(optimizer_out)

    f_params = init_params

    for i in range(50000):
        start_va_time = time.time()
        optimizer_out, f_params, loss_all, grad_all = val_step(
            optimizer_out, f_params)

        rho_g = optimizer_out.target
        loss_val, loss_train, train_loss_iter = loss_all
        grad_loss_train, grad_val = grad_all

        loss0_tot = f_loss(f_params, rho_g, Dt)

        dict_output = serialization.to_state_dict(f_params)
        jnp.save(f_w_nn, dict_output)  #unfreeze()

        f = open(f_out, 'a+')
        #         print(i,rho_g, loss0, loss0_tot, (time.time() - start_va_time),file=f)
        print(i, loss_val, loss_train, (time.time() - start_va_time), file=f)
        print(jnp.exp(rho_g), file=f)
        print(grad_val, file=f)
        #         print(train_loss_iter ,file=f)
        #         print(grad_val,file=f)
        #         print(grad_loss_train,file=f)
        f.close()


#     --------------------------------------
#     Prediction
    f = open(f_out, 'a+')
    print('Prediction of the entire data set', file=f)
    print('N = {}, n_atoms = {}, random = {}'.format(N, n_atoms, i0), file=f)
    print('NN : {}'.format(nn_arq), file=f)
    print('lr = {}, w decay = {}, rho G = {}'.format(lr, w_decay, rho_g),
          file=f)
    print('Activation function = {}'.format(act_fun), file=f)
    print('Total points  = {}'.format(yt.shape[0]), file=f)

    y_pred = nn_adiab(f_params, Xt)
    gX_pred = vmap(jac_nn_adiab, (None, 0))(f_params, Xt)

    diff_y = y_pred - yt
    rmse_Ha = jnp.linalg.norm(diff_y)
    rmse_cm = jnp.linalg.norm(Ha2cm * diff_y)
    mae_Ha = jnp.linalg.norm(diff_y, ord=1)
    mae_cm = jnp.linalg.norm(Ha2cm * diff_y, ord=1)

    print('RMSE = {} [Ha]'.format(rmse_Ha), file=f)
    print('RMSE(tr) = {} [cm-1]'.format(loss_train), file=f)  # last training loss (loss0 is not defined in this scope)
    print('RMSE = {} [cm-1]'.format(rmse_cm), file=f)
    print('MAE = {} [Ha]'.format(mae_Ha), file=f)
    print('MAE = {} [cm-1]'.format(mae_cm), file=f)

    Dpred = jnp.column_stack((Xt, y_pred))
    data_dic = {
        'Dtr': Dtr,
        'Dpred': Dpred,
        'gXpred': gX_pred,
        'loss_tr': loss_train,
        'error_full': rmse_cm,
        'N': N,
        'l': l,
        'i0': i0,
        'rho_g': rho_g
    }

    jnp.save(file_results, data_dic)

    print('---------------------------------', file=f)
    print('Total time =  %.6f seconds ---' % ((time.time() - start_time)),
          file=f)
    print('---------------------------------', file=f)
    f.close()
Code Example #12
    X = sm.add_constant(X)
    glm_binom = sm.GLM(y, X, family=sm.families.Binomial())
    results = glm_binom.fit()
    mu = jnp.array(results.params)

    model = LogisticRegressor()
    init_key, key = split(key)
    variables = model.init(init_key, X)
    output = model.apply(variables, X)

    learning_rate = 1e-3
    optimizer = optax.adam(learning_rate)

    variables = unfreeze(variables)
    variables['params']['Dense_0']['kernel'] = mu.reshape((-1, 1))
    variables = freeze(variables)

    alpha = 1.
    nfeatures = tree_map(lambda x: x.shape[0], variables)
    loglikelihood_fn, logprior_fn = make_fns_for_posterior(model.apply, alpha)

    lambda_best, avg_lower_bounds = ffvb.vb_gauss_chol(key,
                                                       loglikelihood_fn,
                                                       logprior_fn, (X, y),
                                                       optimizer,
                                                       variables,
                                                       lower_triangular=None,
                                                       num_samples=20,
                                                       window_size=10,
                                                       niters=150,
                                                       eps=0.1)