def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
    # init input tensors
    input_ids = jnp.zeros(input_shape, dtype="i4")
    attention_mask = jnp.ones_like(input_ids)
    position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)

    params_rng, dropout_rng = jax.random.split(rng)
    rngs = {"params": params_rng, "dropout": dropout_rng}

    random_params = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"]

    if params is not None:
        random_params = flatten_dict(unfreeze(random_params))
        params = flatten_dict(unfreeze(params))
        for missing_key in self._missing_keys:
            params[missing_key] = random_params[missing_key]
        self._missing_keys = set()
        return freeze(unflatten_dict(params))
    else:
        return random_params
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
    # init input tensors
    pixel_values = jnp.zeros(input_shape, dtype=self.dtype)

    params_rng, dropout_rng = jax.random.split(rng)
    dropout_rng, droppath_rng = jax.random.split(dropout_rng)
    rngs = {"params": params_rng, "dropout": dropout_rng, "droppath": droppath_rng}

    random_params = self.module.init(rngs, pixel_values, return_dict=False)["params"]

    if params is not None:
        random_params = flatten_dict(unfreeze(random_params))
        params = flatten_dict(unfreeze(params))
        for missing_key in self._missing_keys:
            params[missing_key] = random_params[missing_key]
        self._missing_keys = set()
        return freeze(unflatten_dict(params))
    else:
        return random_params
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
    # init input tensors
    input_ids = jnp.zeros(input_shape, dtype="i4")
    token_type_ids = jnp.zeros_like(input_ids)
    attention_mask = jnp.ones_like(input_ids)
    head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))

    params_rng, dropout_rng = jax.random.split(rng)
    rngs = {"params": params_rng, "dropout": dropout_rng}

    random_params = self.module.init(rngs, input_ids, attention_mask, token_type_ids, head_mask, return_dict=False)["params"]

    if params is not None:
        random_params = flatten_dict(unfreeze(random_params))
        params = flatten_dict(unfreeze(params))
        for missing_key in self._missing_keys:
            params[missing_key] = random_params[missing_key]
        self._missing_keys = set()
        return freeze(unflatten_dict(params))
    else:
        return random_params
def set_partitions(in_dict):
    rules = _get_partition_rules()
    replace = _replacement_rules(rules)
    initd = {k: _unmatched for k in flatten_dict(in_dict)}
    result = {k: replace(k, v) for k, v in initd.items()}
    assert _unmatched not in result.values(), "Incomplete partition spec."
    return freeze(unflatten_dict(result))
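# Context sketch (not part of the original source): `set_partitions` assumes helpers
# `_get_partition_rules`, `_replacement_rules`, and a sentinel `_unmatched`. The
# hypothetical versions below illustrate that contract — regex rules over "/"-joined
# flattened parameter paths mapping to PartitionSpecs — and may differ from the real helpers.
import re

from jax.experimental import PartitionSpec as P

_unmatched = object()


def _get_partition_rules():
    # None means "replicate"; P(...) names mesh axes, here a hypothetical "mp" axis.
    return [
        (r"embedding", P("mp", None)),  # shard embedding rows across the "mp" axis
        (r".*", None),                  # replicate everything else
    ]


def _replacement_rules(rules):
    def replace(key, val):
        path = "/".join(key)
        for pattern, spec in rules:
            if re.search(pattern, path):
                return spec
        return _unmatched  # left for the assert in set_partitions to catch
    return replace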
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
    # init input tensors
    input_ids = jnp.zeros(input_shape, dtype="i4")
    attention_mask = jnp.ones_like(input_ids)
    batch_size, sequence_length = input_ids.shape
    position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

    params_rng, dropout_rng = jax.random.split(rng)
    rngs = {"params": params_rng, "dropout": dropout_rng}

    module_init_outputs = self.module.init(
        rngs,
        input_ids,
        attention_mask,
        position_ids,
        return_dict=False,
    )
    random_params = module_init_outputs["params"]

    if params is not None:
        random_params = flatten_dict(unfreeze(random_params))
        params = flatten_dict(unfreeze(params))
        for missing_key in self._missing_keys:
            params[missing_key] = random_params[missing_key]
        self._missing_keys = set()
        return freeze(unflatten_dict(params))
    else:
        return random_params
def params(self, params: Union[Dict, FrozenDict]):
    if isinstance(params, FrozenDict):
        params = unfreeze(params)
    param_keys = set(flatten_dict(params).keys())
    if len(self.required_params - param_keys) > 0:
        raise ValueError(
            "Some parameters are missing. Make sure that `params` include the following "
            f"parameters {self.required_params - param_keys}"
        )
    self._params = freeze(params)
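# Usage sketch (illustrative, not from the original source): `model` stands for any
# pretrained Flax wrapper exposing the `params` setter above together with a
# `required_params` key set. Assigning a complete tree succeeds and is re-frozen;
# dropping a required key raises the ValueError from the setter.
full_params = unfreeze(model.params)
model.params = full_params  # accepted, stored as a FrozenDict

missing = next(iter(model.required_params))
partial = unflatten_dict({k: v for k, v in flatten_dict(full_params).items() if k != missing})
try:
    model.params = partial
except ValueError as err:
    print(err)  # "Some parameters are missing. ..."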
def sparse_init(loss_fn,
                flax_module,
                params,
                hps,
                input_shape,
                output_shape,
                rng_key,
                metrics_logger=None,
                log_every=10):
    """Implements SparseInit initializer.

    Args:
      loss_fn: Loss function.
      flax_module: Flax nn.Module class.
      params: The dict of model parameters.
      hps: HParam object. Required hparams are meta_learning_rate,
        meta_batch_size, meta_steps, and epsilon.
      input_shape: Must agree with batch[0].shape[1:].
      output_shape: Must agree with batch[1].shape[1:].
      rng_key: jax.PRNGKey, used to seed all randomness.
      metrics_logger: Instance of utils.MetricsLogger
      log_every: Print meta loss every k steps.

    Returns:
      A Flax model with sparse initialization.
    """
    del flax_module, loss_fn, input_shape, output_shape, rng_key, metrics_logger, log_every
    params = unfreeze(params)
    activation_functions = hps.activation_function
    num_hidden_layers = len(hps.hid_sizes)
    if isinstance(hps.activation_function, str):
        activation_functions = [hps.activation_function] * num_hidden_layers
    for i, key in enumerate(params):
        num_units, num_weights = params[key]['kernel'].shape
        mask = np.zeros((num_units, num_weights), dtype=bool)
        for k in range(num_units):
            if num_weights >= hps.non_zero_connection_weights:
                sample = np.random.choice(num_weights, hps.non_zero_connection_weights, replace=False)
            else:
                sample = np.random.choice(num_weights, hps.non_zero_connection_weights)
            mask[k, sample] = True
        params[key]['kernel'] = params[key]['kernel'].at[~mask].set(0.0)
        if i < num_hidden_layers and activation_functions[i] == 'tanh':
            params[key]['bias'] = params[key]['bias'].at[:].set(0.5)
        else:
            params[key]['bias'] = params[key]['bias'].at[:].set(0.0)
    return frozen_dict.freeze(params)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
    encoder_input_shape, decoder_input_shape = input_shape

    # init input tensors
    input_ids = jnp.zeros(encoder_input_shape, dtype="i4")
    attention_mask = jnp.ones_like(input_ids)
    decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4")
    decoder_attention_mask = jnp.ones_like(decoder_input_ids)

    batch_size, sequence_length = input_ids.shape
    position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

    decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape
    if not decoder_batch_size == batch_size:
        raise ValueError(
            f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder and {decoder_batch_size} for decoder."
        )
    decoder_position_ids = jnp.broadcast_to(
        jnp.arange(decoder_sequence_length)[None, :], (decoder_batch_size, decoder_sequence_length)
    )

    params_rng, dropout_rng = jax.random.split(rng)
    rngs = {"params": params_rng, "dropout": dropout_rng}

    random_params = self.module.init(
        rngs,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        position_ids,
        decoder_position_ids,
    )["params"]

    if params is not None:
        random_params = flatten_dict(unfreeze(random_params))
        params = flatten_dict(unfreeze(params))
        for missing_key in self._missing_keys:
            params[missing_key] = random_params[missing_key]
        self._missing_keys = set()
        return freeze(unflatten_dict(params))
    else:
        return random_params
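# Usage sketch (illustrative): for the encoder-decoder variant above, `input_shape` is a
# pair of (encoder_shape, decoder_shape) with matching batch sizes; `model` is a
# placeholder for a pretrained Flax seq2seq wrapper that uses this method.
rng = jax.random.PRNGKey(0)
params = model.init_weights(rng, input_shape=((2, 16), (2, 8)))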
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Set the verbosity to info of the Transformers logger (on main process only):
    logger.info(f"Training/evaluation parameters {training_args}")

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(
            data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False
        )

        if "validation" not in dataset.keys():
            dataset["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
            )
            dataset["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
            )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        extension = data_args.train_file.split(".")[-1]
        if extension == "txt":
            extension = "text"
        dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.
    # Load pretrained config and tokenizer
    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
        )
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
        )
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if training_args.do_train:
        column_names = dataset["train"].column_names
    else:
        column_names = dataset["validation"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
    tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")

    def tokenize_function(examples):
        with CaptureLogger(tok_logger) as cl:
            output = tokenizer(examples[text_column_name])
        # clm input could be much much longer than block_size
        if "Token indices sequence length is longer than the" in cl.out:
            tok_logger.warning(
                "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
            )
        return output

    tokenized_datasets = dataset.map(
        tokenize_function,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    if data_args.block_size is None:
        block_size = tokenizer.model_max_length
        if block_size > config.max_position_embeddings:
            logger.warning(
                f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
                "Picking 1024 instead. You can change that default value by passing --block_size xxx."
            )
            block_size = 1024
    else:
        if data_args.block_size > tokenizer.model_max_length:
            logger.warning(
                f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model "
                f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
            )
        block_size = min(data_args.block_size, tokenizer.model_max_length)

    # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
    def group_texts(examples):
        # Concatenate all texts.
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder; we could add padding if the model supported it instead of this drop. You can
        # customize this part to your needs.
        if total_length >= block_size:
            total_length = (total_length // block_size) * block_size
        # Split by chunks of max_len.
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result

    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
    # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
    # to preprocess.
    #
    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
    lm_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    if training_args.do_train:
        if "train" not in tokenized_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = lm_datasets["train"]
        if data_args.max_train_samples is not None:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))

    if training_args.do_eval:
        if "validation" not in tokenized_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = lm_datasets["validation"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some packages are not installed: {ie}"
            )
    else:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable."
        )

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng = jax.random.split(rng)

    # Store some constants
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
    steps_per_epoch = len(train_dataset) // train_batch_size
    total_train_steps = steps_per_epoch * num_epochs

    # TODO: weights should be initialized in pjitted fun, this won't work for REALLY large models
    # TODO: when loading from pre-trained model we need to make sure the vocab is divisible by num_partitions
    # GPT2's vocab is odd, we need to resize it for fine-tuning
    model = FlaxAutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
    )

    # Create learning rate schedule
    linear_decay_lr_schedule_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        training_args.num_train_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    optimizer = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
    )

    def get_initial_state(params):
        state = optimizer.init(params)
        return tuple(state), params

    # Get PartitionSpec for model params
    param_spec = set_partitions(unfreeze(model.params))

    # Get the PyTree for opt_state, we don't actually initialize the opt_state yet.
    params_shapes = jax.tree_map(lambda x: x.shape, model.params)
    state_shapes = jax.eval_shape(get_initial_state, params_shapes)

    # get PartitionSpec for opt_state, this is very specific to adamw
    # TODO: optax returns different state for different optimizers, how can we handle this generically ?
    # or maybe we don't since in our examples we just use adamw or adafactor
    def get_opt_spec(x):
        if isinstance(x, dict):
            return param_spec
        return None

    opt_state_spec, param_spec = jax.tree_map(
        get_opt_spec, state_shapes, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState))
    )

    # pjit the get_initial_state function to shard params and init
    # optimizer state in sharded way
    p_get_initial_state = pjit(
        get_initial_state,
        in_axis_resources=None,
        out_axis_resources=(opt_state_spec, param_spec),
    )

    # hack: move the initial params to CPU to free up device memory
    # TODO: allow loading weights on CPU in pre-trained model
    model.params = jax.tree_map(lambda x: np.asarray(x), model.params)

    # mesh definition
    mesh_devices = np.array(jax.devices()).reshape(1, jax.local_device_count())

    # actually initialize the opt_state
    with mesh(mesh_devices, ("dp", "mp")):
        opt_state, params = p_get_initial_state(freeze(model.params))

    # cross-entropy with z loss
    def loss_fn(logits, labels, z_loss=0):
        shift_logits = logits[..., :-1, :]
        shift_labels = labels[..., 1:]

        shift_labels = onehot(shift_labels, shift_logits.shape[-1])

        shift_logits = shift_logits - jax.lax.stop_gradient(shift_logits.max(axis=-1, keepdims=True))
        log_z = jnp.log(jnp.sum(jnp.exp(shift_logits), axis=-1, keepdims=True))
        log_softmax = shift_logits - log_z
        loss = -jnp.sum(shift_labels * log_softmax, axis=-1)

        loss += (1e-4 * jnp.square(log_z.squeeze(-1))) * z_loss

        return loss.mean()

    # Define gradient update step fn
    # TODO: try to use TrainState instead of passing params and opt_state individually
    def train_step(params, opt_state, dropout_rng, batch, step):
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

        def compute_loss(params):
            labels = batch.pop("labels")
            logits = model(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
            loss = loss_fn(logits, labels, z_loss=1.0)
            return loss

        grad_fn = jax.value_and_grad(compute_loss)
        loss, grads = grad_fn(params)

        updates, new_opt_state = optimizer.update(grads, opt_state, params)
        new_params = optax.apply_updates(params, updates)

        metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(step)}
        return new_params, tuple(new_opt_state), new_dropout_rng, metrics, step + 1

    # Define eval fn
    def eval_step(input_ids, labels, params):
        logits = model(input_ids=input_ids, params=params, train=False)[0]
        loss = loss_fn(logits, labels)
        # metrics
        return {"loss": loss}

    p_train_step = pjit(
        train_step,
        in_axis_resources=(param_spec, opt_state_spec, None, None, None),
        out_axis_resources=(param_spec, opt_state_spec, None, None, None),
        donate_argnums=(0, 1),
    )

    p_eval_step = pjit(
        eval_step,
        in_axis_resources=(None, None, param_spec),
        out_axis_resources=None,
    )

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {num_epochs}")
    logger.info(f"  Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel & distributed) = {train_batch_size}")
    logger.info(f"  Total optimization steps = {total_train_steps}")

    train_time = 0
    train_metrics = []
    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
    global_step = 0
    # we are not doing 2D parallelism (yet!), this just does model parallelism
    with mesh(mesh_devices, ("dp", "mp")):
        for _ in epochs:
            # ======================== Training ================================
            train_start = time.time()

            # Create sampling rng
            rng, input_rng = jax.random.split(rng)

            # Generate an epoch by shuffling sampling indices from the train dataset
            train_metrics = []
            train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
            steps_per_epoch = len(train_dataset) // train_batch_size

            # train
            for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
                batch = next(train_loader)
                params, opt_state, dropout_rng, train_metric, global_step = p_train_step(
                    params,
                    opt_state,
                    dropout_rng,
                    batch,
                    global_step,
                )
                train_metrics.append(train_metric)

                cur_step = global_step

                if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                    # Save metrics
                    train_time += time.time() - train_start
                    if has_tensorboard and jax.process_index() == 0:
                        write_train_metric(summary_writer, train_metrics, train_time, cur_step)

                    epochs.write(
                        f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
                    )

                    train_metrics = []

                if cur_step % training_args.eval_steps == 0 and cur_step > 0:
                    # ======================== Evaluating ==============================
                    eval_metrics = []
                    eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
                    eval_steps = len(eval_dataset) // eval_batch_size
                    for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
                        batch = next(eval_loader)
                        metrics = p_eval_step(batch["input_ids"], batch["labels"], params)
                        eval_metrics.append(metrics)

                    # normalize eval metrics
                    eval_metrics = stack_forest(eval_metrics)
                    eval_metrics = jax.tree_map(jnp.mean, eval_metrics)

                    try:
                        eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
                    except OverflowError:
                        eval_metrics["perplexity"] = float("inf")

                    logger.info(
                        f"Step... ({cur_step} | Eval loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})"
                    )

                if cur_step % training_args.save_steps == 0 and cur_step > 0:
                    # save checkpoint after each epoch and push checkpoint to the hub
                    if jax.process_index() == 0:
                        params = jax.device_get(params)
                        model.save_pretrained(
                            training_args.output_dir,
                            params=params,
                            push_to_hub=training_args.push_to_hub,
                            commit_message=f"Saving weights and logs of step {cur_step}",
                        )
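# The training loop above calls `data_loader`, which is not shown in this excerpt.
# A minimal, assumed-compatible sketch: shuffle with the provided PRNG key, drop the
# last incomplete batch, and yield dict batches of jnp arrays. The real helper may differ.
def data_loader(rng, dataset, batch_size, shuffle=False):
    n = len(dataset)
    order = np.asarray(jax.random.permutation(rng, n)) if shuffle else np.arange(n)
    for start in range(0, (n // batch_size) * batch_size, batch_size):
        idx = order[start : start + batch_size].tolist()
        batch = dataset[idx]  # a datasets.Dataset returns a dict of lists here
        yield {k: jnp.array(v) for k, v in batch.items()}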
def main_opt(N, l, i0, nn_arq, act_fun, n_epochs, lr, w_decay, rho_g):
    start_time = time.time()

    str_nn_arq = ''
    for item in nn_arq:
        str_nn_arq = str_nn_arq + '_{}'.format(item)

    f_job = 'nn_arq{}_N_{}_i0_{}_l_{}_batch'.format(str_nn_arq, N, i0, l)
    f_out = '{}/out_opt_{}.txt'.format(r_dir, f_job)
    f_w_nn = '{}/W_{}.npy'.format(r_dir, f_job)
    file_results = '{}/data_nh3_{}.npy'.format(r_dir, f_job)

    # --------------------------------------
    # Data
    n_atoms = 4
    batch_size = 768  # 1024 # 768 # 512 # 256 # 128 # 64 # 32

    Dtr, Dt = load_data(file_results, N, l)
    Xtr, gXtr, gXctr, ytr = Dtr
    Xt, gXt, gXct, yt = Dt
    print(Xtr.shape, gXtr.shape, gXctr.shape, ytr.shape)

    # --------------------------------
    # BATCHES
    n_complete_batches, leftover = divmod(N, batch_size)
    n_batches = n_complete_batches + bool(leftover)

    def data_stream():
        rng = onpr.RandomState(0)
        while True:
            perm = rng.permutation(N)
            for i in range(n_batches):
                batch_idx = perm[i * batch_size:(i + 1) * batch_size]
                yield Xtr[batch_idx], gXtr[batch_idx], gXctr[batch_idx], ytr[batch_idx]

    batches = data_stream()

    # --------------------------------
    f = open(f_out, 'a+')
    print('-----------------------------------', file=f)
    print('Starting time', file=f)
    print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M"), file=f)
    print('-----------------------------------', file=f)
    print(f_out, file=f)
    print('N = {}, n_atoms = {}, data_random = {}, NN_random = {}'.format(N, n_atoms, l, i0), file=f)
    print(nn_arq, file=f)
    print('lr = {}, w decay = {}'.format(lr, w_decay), file=f)
    print('Activation function = {}'.format(act_fun), file=f)
    print('N Epoch = {}'.format(n_epochs), file=f)
    print('rho G = {}'.format(rho_g), file=f)
    print('-----------------------------------', file=f)
    f.close()

    # --------------------------------------
    # initialize NN
    nn_arq.append(3)
    tuple_nn_arq = tuple(nn_arq)
    nn_model = NN_adiab(n_atoms, tuple_nn_arq)

    def get_init_NN_params(key):
        x = Xtr[0, :]
        x = x[None, :]
        # x = jnp.ones((1, Xtr.shape[1]))
        variables = nn_model.init(key, x)
        return variables

    # Initialize parameters
    rng = random.PRNGKey(i0)
    rng, subkey = jax.random.split(rng)
    params = get_init_NN_params(subkey)

    f = open(f_out, 'a+')
    if os.path.isfile(f_w_nn):
        print('Reading NN parameters from prev calculation!', file=f)
        print('-----------------------', file=f)
        nn_dic = jnp.load(f_w_nn, allow_pickle=True)
        params = unfreeze(params)
        params['params'] = nn_dic.item()['params']
        params = freeze(params)
    f.close()

    init_params = params

    # --------------------------------------
    # Phys functions
    @jit
    def nn_adiab(params, x):
        y_ad_pred = nn_model.apply(params, x)
        return y_ad_pred

    @jit
    def jac_nn_adiab(params, x):
        g_y_pred = jacrev(nn_adiab, argnums=1)(params, x[None, :])
        return jnp.reshape(g_y_pred, (2, g_y_pred.shape[-1]))

    '''
    # WRONG
    @jit
    def f_nac_coup_i(gH_diab, eigvect_):  # for a single cartesian dimension
        temp = jnp.dot(gH_diab, eigvect_[:, 0])
        return jnp.vdot(eigvect_[:, 1], temp)

    @jit
    def f_nac_coup(params, x):
        eigval_, eigvect_ = f_adiab(params, x)
        gy_diab = jac_nn_diab(params, x)
        gy_diab = jnp.reshape(gy_diab.T, (12, 2, 2))
        g_coup = vmap(f_nac_coup_i, (0, None))(gy_diab, eigvect_)
        return g_coup
    '''

    # --------------------------------------
    # Validation loss functions
    @jit
    def f_validation(params):
        y_pred = nn_adiab(params, Xt)
        diff_y = y_pred - yt
        z = jnp.linalg.norm(diff_y)
        return z

    @jit
    def f_jac_validation(params):
        gX_pred = vmap(jac_nn_adiab, (None, 0))(params, Xt)
        diff_y = gX_pred - gXt
        z = jnp.linalg.norm(diff_y)
        return z

    '''
    @jit
    def f_nac_validation(params):
        g_nac_coup = vmap(f_nac_coup, (None, 0))(params, Xt)
        diff_y = g_nac_coup - gXct
        z = jnp.linalg.norm(diff_y)
        return z
    '''
    # --------------------------------------
    # training loss functions
    @jit
    def f_loss_ad_energy(params, batch):
        X_inputs, _, _, y_true = batch
        y_pred = nn_adiab(params, X_inputs)
        diff_y = y_pred - y_true  # Ha2cm*
        loss = jnp.linalg.norm(diff_y)
        return loss

    @jit
    def f_loss_jac(params, batch):
        X_inputs, gX_inputs, _, y_true = batch
        gX_pred = vmap(jac_nn_adiab, (None, 0))(params, X_inputs)
        diff_g_X = gX_pred - gX_inputs
        return jnp.linalg.norm(diff_g_X)

    '''
    @jit
    def f_loss_nac(params, batch):
        X_inputs, _, gXc_inputs, y_true = batch
        g_nac_coup = vmap(f_nac_coup, (None, 0))(params, x)
        diff_y = g_nac_coup - gXc_inputs
        z = jnp.linalg.norm(diff_y)
        return z
    '''

    # ------
    @jit
    def f_loss(params, batch):
        loss_ad_energy = f_loss_ad_energy(params, batch)
        # loss_jac_energy = f_loss_jac(params, batch)
        loss = loss_ad_energy  # + rho_g * loss_jac_energy
        return loss

    # --------------------------------------
    # Optimization and Training

    # Perform a single training step.
    @jit
    def train_step(optimizer, batch):  # , learning_rate_fn, model
        grad_fn = jax.value_and_grad(f_loss)
        loss, grad = grad_fn(optimizer.target, batch)
        optimizer = optimizer.apply_gradient(grad)  # , {"learning_rate": lr}
        return optimizer, loss

    optimizer = optim.Adam(learning_rate=lr, weight_decay=w_decay).create(init_params)
    optimizer = jax.device_put(optimizer)

    loss0 = 1E16
    loss0_tot = 1E16
    itercount = itertools.count()
    f_params = init_params
    for epoch in range(n_epochs):
        for _ in range(n_batches):
            optimizer, loss = train_step(optimizer, next(batches))

        params = optimizer.target
        loss_tot = f_validation(params)

        if epoch % 10 == 0:
            f = open(f_out, 'a+')
            print(epoch, loss, loss_tot, file=f)
            f.close()

        if loss < loss0:
            loss0 = loss
            f = open(f_out, 'a+')
            print(epoch, loss, loss_tot, file=f)
            f.close()

        if loss_tot < loss0_tot:
            loss0_tot = loss_tot
            f_params = params
            dict_output = serialization.to_state_dict(params)
            jnp.save(f_w_nn, dict_output)  # unfreeze()

    f = open(f_out, 'a+')
    print('---------------------------------', file=f)
    print('Training time = %.6f seconds ---' % ((time.time() - start_time)), file=f)
    print('---------------------------------', file=f)
    f.close()

    # --------------------------------------
    # Prediction
    f = open(f_out, 'a+')
    print('Prediction of the entire data set', file=f)
    print('N = {}, n_atoms = {}, random = {}'.format(N, n_atoms, i0), file=f)
    print('NN : {}'.format(nn_arq), file=f)
    print('lr = {}, w decay = {}, rho G = {}'.format(lr, w_decay, rho_g), file=f)
    print('Activation function = {}'.format(act_fun), file=f)
    print('Total points = {}'.format(yt.shape[0]), file=f)

    y_pred = nn_adiab(f_params, Xt)
    gX_pred = vmap(jac_nn_adiab, (None, 0))(f_params, Xt)

    diff_y = y_pred - yt
    rmse_Ha = jnp.linalg.norm(diff_y)
    rmse_cm = jnp.linalg.norm(Ha2cm * diff_y)
    mae_Ha = jnp.linalg.norm(diff_y, ord=1)
    mae_cm = jnp.linalg.norm(Ha2cm * diff_y, ord=1)

    print('RMSE = {} [Ha]'.format(rmse_Ha), file=f)
    print('RMSE(tr) = {} [cm-1]'.format(loss0), file=f)
    print('RMSE = {} [cm-1]'.format(rmse_cm), file=f)
    print('MAE = {} [Ha]'.format(mae_Ha), file=f)
    print('MAE = {} [cm-1]'.format(mae_cm), file=f)

    Dpred = jnp.column_stack((Xt, y_pred))
    data_dic = {
        'Dtr': Dtr,
        'Dpred': Dpred,
        'gXpred': gX_pred,
        'loss_tr': loss0,
        'error_full': rmse_cm,
        'N': N,
        'l': l,
        'i0': i0,
        'rho_g': rho_g,
    }
    jnp.save(file_results, data_dic)

    print('---------------------------------', file=f)
    print('Total time = %.6f seconds ---' % ((time.time() - start_time)), file=f)
    print('---------------------------------', file=f)
    f.close()
def main_opt(N, l, i0, nn_arq, act_fun, n_epochs, lr, w_decay, rho_g):
    start_time = time.time()

    str_nn_arq = ''
    for item in nn_arq:
        str_nn_arq = str_nn_arq + '_{}'.format(item)

    f_job = 'nn_arq{}_N_{}_i0_{}_l_{}_batch'.format(str_nn_arq, N, i0, l)
    f_out = '{}/out_opt_{}.txt'.format(r_dir, f_job)
    f_w_nn = '{}/W_{}.npy'.format(r_dir, f_job)
    file_results = '{}/data_nh3_{}.npy'.format(r_dir, f_job)

    # --------------------------------------
    # Data
    n_atoms = 4
    batch_size = 768  # 1024 # 768 # 512 # 256 # 128 # 64 # 32

    Dtr, Dval, Dt = load_data(file_results, N, l)
    Xtr, gXtr, gXctr, ytr = Dtr
    Xval, gXval, gXcval, yval = Dval
    Xt, gXt, gXct, yt = Dt
    print(Xtr.shape, gXtr.shape, gXctr.shape, ytr.shape)

    # --------------------------------
    # BATCHES
    n_complete_batches, leftover = divmod(N, batch_size)
    n_batches = n_complete_batches + bool(leftover)

    def data_stream():
        rng = onpr.RandomState(0)
        while True:
            perm = rng.permutation(N)
            for i in range(n_batches):
                batch_idx = perm[i * batch_size:(i + 1) * batch_size]
                yield Xtr[batch_idx], gXtr[batch_idx], gXctr[batch_idx], ytr[batch_idx]

    batches = data_stream()

    # --------------------------------
    f = open(f_out, 'a+')
    print('-----------------------------------', file=f)
    print('Starting time', file=f)
    print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M"), file=f)
    print('-----------------------------------', file=f)
    print(f_out, file=f)
    print('N = {}, n_atoms = {}, data_random = {}, NN_random = {}'.format(N, n_atoms, l, i0), file=f)
    print(nn_arq, file=f)
    print('lr = {}, w decay = {}'.format(lr, w_decay), file=f)
    print('Activation function = {}'.format(act_fun), file=f)
    print('N Epoch = {}'.format(n_epochs), file=f)
    print('rho G = {}'.format(rho_g), file=f)
    print('-----------------------------------', file=f)
    f.close()

    # --------------------------------------
    # initialize NN
    nn_arq.append(3)
    tuple_nn_arq = tuple(nn_arq)
    nn_model = NN_adiab(n_atoms, tuple_nn_arq)

    def get_init_NN_params(key):
        x = Xtr[0, :]
        x = x[None, :]
        # x = jnp.ones((1, Xtr.shape[1]))
        variables = nn_model.init(key, x)
        return variables

    # Initialize parameters
    rng = random.PRNGKey(i0)
    rng, subkey = jax.random.split(rng)
    params = get_init_NN_params(subkey)

    f = open(f_out, 'a+')
    if os.path.isfile(f_w_nn):
        print('Reading NN parameters from prev calculation!', file=f)
        print('-----------------------', file=f)
        nn_dic = jnp.load(f_w_nn, allow_pickle=True)
        params = unfreeze(params)
        params['params'] = nn_dic.item()['params']
        params = freeze(params)
        # print(params)
    f.close()

    init_params = params

    # --------------------------------------
    # Phys functions
    @jit
    def nn_adiab(params, x):
        y_ad_pred = nn_model.apply(params, x)
        return y_ad_pred

    @jit
    def jac_nn_adiab(params, x):
        g_y_pred = jacrev(nn_adiab, argnums=1)(params, x[None, :])
        return jnp.reshape(g_y_pred, (2, g_y_pred.shape[-1]))

    # --------------------------------------
    # training loss functions
    @jit
    def f_loss_ad_energy(params, batch):
        X_inputs, _, _, y_true = batch
        y_pred = nn_adiab(params, X_inputs)
        diff_y = y_pred - y_true  # Ha2cm*
        return jnp.linalg.norm(diff_y, axis=0)

    @jit
    def f_loss_jac(params, batch):
        X_inputs, gX_inputs, _, y_true = batch
        gX_pred = vmap(jac_nn_adiab, (None, 0))(params, X_inputs)
        diff_g_X = gX_pred - gX_inputs
        # jnp.linalg.norm(diff_g_X, axis=0)
        diff_g_X0 = diff_g_X[:, 0, :]
        diff_g_X1 = diff_g_X[:, 1, :]
        l0 = jnp.linalg.norm(diff_g_X0)
        l1 = jnp.linalg.norm(diff_g_X1)
        return jnp.stack([l0, l1])

    # ------
    @jit
    def f_loss(params, rho_g, batch):
        rho_g = jnp.exp(rho_g)
        loss_ad_energy = f_loss_ad_energy(params, batch)
        loss_jac_energy = f_loss_jac(params, batch)
        loss = jnp.vdot(jnp.ones_like(loss_ad_energy), loss_ad_energy) + jnp.vdot(rho_g, loss_jac_energy)
        return loss

    # --------------------------------------
    # Optimization and Training

    # Perform a single training step.
    @jit
    def train_step(optimizer, rho_g, batch):  # , learning_rate_fn, model
        grad_fn = jax.value_and_grad(f_loss)
        loss, grad = grad_fn(optimizer.target, rho_g, batch)
        optimizer = optimizer.apply_gradient(grad)  # , {"learning_rate": lr}
        return optimizer, (loss, grad)

    # @jit
    def train(rho_g, nn_params):
        optimizer = optim.Adam(learning_rate=lr, weight_decay=w_decay).create(nn_params)
        optimizer = jax.device_put(optimizer)

        train_loss = []
        loss0 = 1E16
        loss0_tot = 1E16
        itercount = itertools.count()
        f_params = init_params
        for epoch in range(n_epochs):
            for _ in range(n_batches):
                optimizer, loss_and_grad = train_step(optimizer, rho_g, next(batches))
                loss, grad = loss_and_grad
                # f = open(f_out, 'a+')
                # print(i, loss, file=f)
                # f.close()
                train_loss.append(loss)

        # params = optimizer.target
        # loss_tot = f_validation(params)

        nn_params = optimizer.target
        return nn_params, loss_and_grad, train_loss

    @jit
    def val_step(optimizer, nn_params):  # , learning_rate_fn, model
        rho_g_prev = optimizer.target
        nn_params, loss_and_grad_train, train_loss_iter = train(rho_g_prev, nn_params)
        loss_train, grad_loss_train = loss_and_grad_train

        grad_fn_val = jax.value_and_grad(f_loss, argnums=1)
        loss_val, grad_val = grad_fn_val(nn_params, optimizer.target, Dval)
        optimizer = optimizer.apply_gradient(grad_val)  # , {"learning_rate": lr}
        return optimizer, nn_params, (loss_val, loss_train, train_loss_iter), (grad_loss_train, grad_val)

    # Initialize rho_G
    rng = random.PRNGKey(0)
    rng, subkey = jax.random.split(rng)
    rho_G0 = random.uniform(subkey, shape=(2,), minval=5E-4, maxval=0.025)
    rho_G0 = jnp.log(rho_G0)
    print('Initial lambdas', rho_G0)
    init_G = rho_G0

    #
    optimizer_out = optim.Adam(learning_rate=2E-4, weight_decay=0.).create(init_G)
    optimizer_out = jax.device_put(optimizer_out)

    f_params = init_params
    for i in range(50000):
        start_va_time = time.time()
        optimizer_out, f_params, loss_all, grad_all = val_step(optimizer_out, f_params)
        rho_g = optimizer_out.target
        loss_val, loss_train, train_loss_iter = loss_all
        grad_loss_train, grad_val = grad_all

        loss0_tot = f_loss(f_params, rho_g, Dt)

        dict_output = serialization.to_state_dict(f_params)
        jnp.save(f_w_nn, dict_output)  # unfreeze()

        f = open(f_out, 'a+')
        # print(i, rho_g, loss0, loss0_tot, (time.time() - start_va_time), file=f)
        print(i, loss_val, loss_train, (time.time() - start_va_time), file=f)
        print(jnp.exp(rho_g), file=f)
        print(grad_val, file=f)
        # print(train_loss_iter, file=f)
        # print(grad_val, file=f)
        # print(grad_loss_train, file=f)
        f.close()

    # --------------------------------------
    # Prediction
    f = open(f_out, 'a+')
    print('Prediction of the entire data set', file=f)
    print('N = {}, n_atoms = {}, random = {}'.format(N, n_atoms, i0), file=f)
    print('NN : {}'.format(nn_arq), file=f)
    print('lr = {}, w decay = {}, rho G = {}'.format(lr, w_decay, rho_g), file=f)
    print('Activation function = {}'.format(act_fun), file=f)
    print('Total points = {}'.format(yt.shape[0]), file=f)

    y_pred = nn_adiab(f_params, Xt)
    gX_pred = vmap(jac_nn_adiab, (None, 0))(f_params, Xt)

    diff_y = y_pred - yt
    rmse_Ha = jnp.linalg.norm(diff_y)
    rmse_cm = jnp.linalg.norm(Ha2cm * diff_y)
    mae_Ha = jnp.linalg.norm(diff_y, ord=1)
    mae_cm = jnp.linalg.norm(Ha2cm * diff_y, ord=1)

    print('RMSE = {} [Ha]'.format(rmse_Ha), file=f)
    print('RMSE(tr) = {} [cm-1]'.format(loss_train), file=f)
    print('RMSE = {} [cm-1]'.format(rmse_cm), file=f)
    print('MAE = {} [Ha]'.format(mae_Ha), file=f)
    print('MAE = {} [cm-1]'.format(mae_cm), file=f)

    Dpred = jnp.column_stack((Xt, y_pred))
    data_dic = {
        'Dtr': Dtr,
        'Dpred': Dpred,
        'gXpred': gX_pred,
        'loss_tr': loss_train,
        'error_full': rmse_cm,
        'N': N,
        'l': l,
        'i0': i0,
        'rho_g': rho_g,
    }
    jnp.save(file_results, data_dic)

    print('---------------------------------', file=f)
    print('Total time = %.6f seconds ---' % ((time.time() - start_time)), file=f)
    print('---------------------------------', file=f)
    f.close()
X = sm.add_constant(X)
glm_binom = sm.GLM(y, X, family=sm.families.Binomial())
results = glm_binom.fit()
mu = jnp.array(results.params)

model = LogisticRegressor()
init_key, key = split(key)
variables = model.init(init_key, X)
output = model.apply(variables, X)

learning_rate = 1e-3
optimizer = optax.adam(learning_rate)

# Warm-start the Flax model with the GLM coefficients.
variables = unfreeze(variables)
variables['params']['Dense_0']['kernel'] = mu.reshape((-1, 1))
variables = freeze(variables)

alpha = 1.
nfeatures = tree_map(lambda x: x.shape[0], variables)
loglikelihood_fn, logprior_fn = make_fns_for_posterior(model.apply, alpha)

lambda_best, avg_lower_bounds = ffvb.vb_gauss_chol(
    key,
    loglikelihood_fn,
    logprior_fn,
    (X, y),
    optimizer,
    variables,
    lower_triangular=None,
    num_samples=20,
    window_size=10,
    niters=150,
    eps=0.1,
)
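# `LogisticRegressor` is used above but not defined in this excerpt. A hypothetical
# minimal definition consistent with the parameter path
# variables['params']['Dense_0']['kernel'] and with the intercept already included
# via sm.add_constant (hence no bias term); the real class may differ.
import flax.linen as nn


class LogisticRegressor(nn.Module):
    @nn.compact
    def __call__(self, x):
        # Single Dense layer -> logit, squashed to a probability.
        return nn.sigmoid(nn.Dense(1, use_bias=False)(x))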