def loss_fn(model, state, batch, prng_key):
  """Loss function used for training.

  Runs `model` on `batch['image']` inside flax stateful/stochastic contexts and
  combines the cross-entropy loss with three regularizers: an L2 weight penalty
  (scaled by `l2_reg`), the model's own penalty (scaled by
  `config.std_penalty_mult`) and a prior penalty over kernels, scales and biases
  (scaled by `config.prior_reg`).

  Args:
    model: flax (deprecated nn) model; called on `batch['image']` and expected
      to return `(logits, model_penalty, summaries)`.
    state: collection passed to `flax.deprecated.nn.stateful`.
    batch: dict with at least 'image' and 'label' entries.
    prng_key: PRNG key for `flax.deprecated.nn.stochastic`.

  Returns:
    `(loss, (new_state, logits, prior_penalty, model_penalty, weight_penalty,
    summaries))`.
  """
  with flax.deprecated.nn.stateful(state) as new_state:
    with flax.deprecated.nn.stochastic(prng_key):
      logits, model_penalty, summaries = model(batch['image'])
  ce_loss = train_functions.cross_entropy_loss(logits, batch['label'])

  # l2_reg not used when using prior. The third check previously repeated
  # `kernel_prior`, so a configured scale prior was never detected; `getattr`
  # keeps configs that have no `scale_prior` field working as before.
  no_prior = (config.kernel_prior == 'none' and
              config.bias_prior == 'none' and
              getattr(config, 'scale_prior', 'none') == 'none')
  assert l2_reg == 0 or no_prior, 'Either set priors or l2_reg > 0, not both.'

  # L2 penalty only over parameters with rank > 1 (i.e. skips 1-D params).
  weight_penalty_params = jax.tree_leaves(model.params)
  weight_l2 = sum(
      [jnp.sum(x**2) for x in weight_penalty_params if x.ndim > 1])
  weight_penalty = l2_reg * 0.5 * weight_l2

  kernels = optim.ModelParamTraversal(lambda n, _: n.endswith('kernel'))
  kernel_prior_penalty = sum(map(kernel_prior(), kernels.iterate(model)))

  scales = optim.ModelParamTraversal(
      lambda n, _: n.endswith('scale') or n.endswith('gamma'))
  scale_prior_penalty = sum(map(scale_prior(), scales.iterate(model)))

  biases = optim.ModelParamTraversal(
      lambda n, _: any(map(n.endswith, ['bias', 'rebias', 'beta'])))
  bias_prior_penalty = sum(map(bias_prior(), biases.iterate(model)))

  prior_penalty = kernel_prior_penalty + bias_prior_penalty + scale_prior_penalty
  prior_penalty *= config.prior_reg

  loss = (ce_loss + weight_penalty + config.std_penalty_mult * model_penalty +
          prior_penalty)
  return loss, (new_state, logits, prior_penalty, model_penalty, weight_penalty,
                summaries)
def create_optimizer(config, params):
  """Create a MultiOptimizer that skips weight decay for bias/LayerNorm params.

  Parameters whose path contains "bias", "layer_norm" or "layernorm" get an
  optimizer with weight_decay=0.0; all others get `config.weight_decay`. Both
  share learning rate, betas and epsilon taken from `config`.

  Args:
    config: object with `optimizer` ("adam" or "lamb"), `learning_rate`,
      `adam_beta1`, `adam_beta2`, `adam_epsilon` and `weight_decay`.
    params: parameters (or model) handed to `optimizer_def.create`.

  Returns:
    The created optimizer.

  Raises:
    ValueError: if `config.optimizer` is neither "adam" nor "lamb".
  """
  if config.optimizer == "adam":
    optimizer_cls = optim.Adam
  elif config.optimizer == "lamb":
    optimizer_cls = optim.LAMB
  else:
    # f-string so the offending value actually appears in the message.
    raise ValueError(f"Unsupported value for optimizer: {config.optimizer}")
  common_kwargs = dict(
      learning_rate=config.learning_rate,
      beta1=config.adam_beta1,
      beta2=config.adam_beta2,
      eps=config.adam_epsilon,
  )
  optimizer_decay_def = optimizer_cls(
      weight_decay=config.weight_decay, **common_kwargs)
  optimizer_no_decay_def = optimizer_cls(weight_decay=0.0, **common_kwargs)

  def exclude_from_decay(path, _):
    return "bias" in path or "layer_norm" in path or "layernorm" in path

  decay = optim.ModelParamTraversal(
      lambda *args: not exclude_from_decay(*args))
  no_decay = optim.ModelParamTraversal(exclude_from_decay)
  optimizer_def = optim.MultiOptimizer(
      (decay, optimizer_decay_def), (no_decay, optimizer_no_decay_def))
  optimizer = optimizer_def.create(params)
  return optimizer
def create_optimizer(model, learning_rate=1e-4):
  """Build a replicated Adam MultiOptimizer for `model`.

  Parameters whose path mentions 'layer_norm' or 'bias' are updated by an Adam
  optimizer with weight_decay=0.0; every other parameter is updated by Adam
  with weight_decay=0.01. Both optimizers use eps=1e-6.

  Args:
    model: JAX model to add optimizer to.
    learning_rate: base learning rate used for initializing both optimizers.

  Returns:
    The optimizer wrapping `model`, replicated via `optimizer.replicate()`.
  """

  def _skips_decay(key, _):
    return 'layer_norm' in key or 'bias' in key

  def _takes_decay(key, value):
    return not _skips_decay(key, value)

  adam_kwargs = dict(learning_rate=learning_rate, eps=1e-6)
  decayed_pair = (optim.ModelParamTraversal(_takes_decay),
                  optim.Adam(weight_decay=0.01, **adam_kwargs))
  undecayed_pair = (optim.ModelParamTraversal(_skips_decay),
                    optim.Adam(weight_decay=0.0, **adam_kwargs))

  replicated = optim.MultiOptimizer(
      decayed_pair, undecayed_pair).create(model).replicate()
  del model
  return replicated
def test_multi_optimizer_multiple_matches(self):
  """A param claimed by two traversals makes MultiOptimizer raise."""
  params = {'a': {'x': 0., 'y': 0.}, 'b': {'y': 0, 'z': 0.}}

  def _x_or_y(path, _):
    return path.endswith(('/x', '/y'))

  def _int_or_z(path, value):
    return value.dtype == jnp.int32 or path.endswith('/z')

  slow_sgd = optim.GradientDescent(learning_rate=1.)
  fast_sgd = optim.GradientDescent(learning_rate=10.)
  optimizer_def = optim.MultiOptimizer(
      (optim.ModelParamTraversal(_x_or_y), slow_sgd),
      (optim.ModelParamTraversal(_int_or_z), fast_sgd))

  # '/b/y' is an int leaf, so both traversals (indices 0 and 1) claim it.
  with self.assertRaisesRegex(
      ValueError, r"Multiple optimizers match.*'y': \[0, 1\]"):
    jax.jit(optimizer_def.init_state)(params)
def create_optimizer(model, model_kwargs, learning_rate=1e-4):
  """Wrap `model` in a replicated MultiOptimizer (BertLAMB or Adam).

  Two optimizers of the same family are created: one with weight decay
  `FLAGS.lamb_weight_decay` for ordinary weights, and one with zero weight
  decay for parameters whose path contains 'layer_norm', 'layernorm' or
  'bias'. With `FLAGS.use_lamb` the family is `bert_lamb.BertLAMB` (betas and
  epsilon from flags, `num_layers` from `model_kwargs`); otherwise it is
  `optim.Adam` with eps=1e-6.

  Args:
    model: JAX model to add optimizer to.
    model_kwargs: Bert model config parameter dictionary.
    learning_rate: base learning rate used for initializing optimizer.

  Returns:
    The optimizer wrapping `model`, replicated with `jax_utils.replicate`.
  """
  no_decay_markers = ('layer_norm', 'bias', 'layernorm')

  def filter_other(key, _):
    return any(marker in key for marker in no_decay_markers)

  def filter_weight_decay(key, value):
    return not filter_other(key, value)

  if FLAGS.use_lamb:
    opt_cls = bert_lamb.BertLAMB
    opt_kwargs = dict(
        learning_rate=learning_rate,
        beta1=FLAGS.lamb_beta_1,
        beta2=FLAGS.lamb_beta_2,
        eps=10**FLAGS.log_epsilon,
        num_layers=model_kwargs['num_layers'])
  else:
    opt_cls = optim.Adam
    opt_kwargs = dict(learning_rate=learning_rate, eps=1e-6)

  decayed_def = opt_cls(weight_decay=FLAGS.lamb_weight_decay, **opt_kwargs)
  undecayed_def = opt_cls(weight_decay=0.0, **opt_kwargs)

  multi_def = optim.MultiOptimizer(
      (optim.ModelParamTraversal(filter_weight_decay), decayed_def),
      (optim.ModelParamTraversal(filter_other), undecayed_def))
  optimizer = jax_utils.replicate(multi_def.create(model))
  del model
  return optimizer
def create_optimizer(config, model):
  """Adam MultiOptimizer: weight decay 0.01 except for 'bias' params (0.0).

  Args:
    config: object providing `learning_rate`.
    model: model/params handed to `optimizer_def.create`.

  Returns:
    The created (non-replicated) optimizer.
  """

  def _is_bias(path, _):
    return 'bias' in path

  adam_settings = dict(
      learning_rate=config.learning_rate, beta1=0.9, beta2=0.999, eps=1e-6)
  with_decay = optim.Adam(weight_decay=0.01, **adam_settings)
  without_decay = optim.Adam(weight_decay=0.0, **adam_settings)

  multi_def = optim.MultiOptimizer(
      (optim.ModelParamTraversal(
          lambda path, value: not _is_bias(path, value)), with_decay),
      (optim.ModelParamTraversal(_is_bias), without_decay))
  return multi_def.create(model)
def rescale_layers(flax_module, layer_rescale_factors):
  """Multiply selected model variables by per-key factors.

  Args:
    flax_module: A flax module where params is a nested dictionary.
    layer_rescale_factors: A dictionary mapping flat keys to a multiplicative
      rescale factor. The matching param is changed from x -> a * x for
      rescale factor a. Keys must have the form described in the
      flatten_keys documentation.

  Returns:
    A new flax module with the corresponding params rescaled.

  Raises:
    ValueError: if a key of `layer_rescale_factors` is not among the flattened
      keys of `flax_module.params`.
  """
  known_keys = flatten_dict(flax_module.params).keys()
  logging.info('All keys:')
  for flat_key in known_keys:
    logging.info(flat_key)

  for target_key, factor in layer_rescale_factors.items():
    if target_key not in known_keys:
      raise ValueError('Module does not have key: {}'.format(target_key))
    logging.info('Rescaling %s by factor %f', target_key, factor)
    # Bind loop values as defaults so each closure captures this iteration.
    selector = optim.ModelParamTraversal(
        lambda path, _, wanted=target_key: path == wanted)
    flax_module = selector.update(
        lambda x, scale=factor: x * scale, flax_module)
  return flax_module
def test_param_selection(self):
  """Traversal sees every leaf path but only yields/updates kernel leaves."""
  params = {'x': {'kernel': 1, 'bias': 2, 'y': {'kernel': 3, 'bias': 4}}}
  seen_paths = []

  def kernels_only(name, _):
    seen_paths.append(name)  # record every path offered to the filter
    return 'kernel' in name

  model = nn.Model(None, params)
  traversal = optim.ModelParamTraversal(kernels_only)

  self.assertEqual(list(traversal.iterate(model)), [1, 3])
  self.assertEqual(
      set(seen_paths),
      {'/x/kernel', '/x/bias', '/x/y/kernel', '/x/y/bias'})

  doubled = traversal.update(lambda leaf: leaf + leaf, model)
  expected = nn.Model(
      None, {'x': {'kernel': 2, 'bias': 2, 'y': {'kernel': 6, 'bias': 4}}})
  self.assertEqual(doubled, expected)
def create_optimizer(model, learning_rate, weight_decay, layers=None):
  """Instantiates Adam multi-optimizer.

  Args:
    model: model/params handed to `optimizer_def.create`.
    learning_rate: a float when `layers` is None; otherwise a sequence with one
      learning rate per entry of `layers`. Entries with a learning rate <= 0
      get no optimizer attached.
    weight_decay: a float when `layers` is None; otherwise a sequence aligned
      with `layers`.
    layers: optional sequence of layer identifiers; each is passed as `layer=`
      to `path_inclusion_filter_fn` to select the params its learning rate and
      weight decay apply to.

  Returns:
    The created optimizer.

  Raises:
    ValueError: if the arguments have the wrong types/lengths, or if no layer
      has a positive learning rate.
  """
  if layers is None:
    # Explicit raise instead of `assert` so validation survives `python -O`.
    if not (isinstance(learning_rate, float) and
            isinstance(weight_decay, float)):
      raise ValueError(
          'Specify float values for moded learning rate and weight decay!')
    optimizer_def = optim.Adam(
        learning_rate=learning_rate, weight_decay=weight_decay)
    return optimizer_def.create(model)

  if not len(learning_rate) == len(weight_decay) == len(layers):
    raise ValueError('Number of specified learning rates, weight decays, and '
                     'layers must be equal!')

  optimizers = []
  for lr, wd, layer in zip(learning_rate, weight_decay, layers):
    if lr <= 0:
      continue  # no optimizer is attached to this layer
    opt = optim.Adam(learning_rate=lr, weight_decay=wd)
    filter_fn = functools.partial(path_inclusion_filter_fn, layer=layer)
    optimizers.append((optim.ModelParamTraversal(filter_fn), opt))

  # MultiOptimizer() with no (traversal, optimizer) pairs fails with an opaque
  # unpacking error; report the real cause instead.
  if not optimizers:
    raise ValueError('At least one layer must have a positive learning rate.')
  optimizer_def = optim.MultiOptimizer(*optimizers)
  return optimizer_def.create(model)
def create_optimizer(config, model, initial_params):
  """Create a replicated Adam optimizer for `model`.

  Args:
    config: object providing `learning_rate`.
    model: model wrapped by the optimizer; the local reference is dropped
      after the optimizer is created.
    initial_params: unused; kept for interface compatibility.

  Returns:
    The optimizer, replicated via `optimizer.replicate()`.
  """
  del initial_params  # Unused.
  # TODO(marcvanzee): MultiOptimizer triggers double XLA compilation on TPU so
  # we use Adam here, but we should investigate why this happens. The intended
  # setup is a MultiOptimizer of Adam(beta1=0.9, beta2=0.999, eps=1e-6) with
  # weight_decay=0.01 for non-'bias' params and weight_decay=0.0 for 'bias'
  # params. It was previously built here and then immediately discarded.
  optimizer_def = optim.Adam(learning_rate=config.learning_rate)
  optimizer = optimizer_def.create(model)
  optimizer = optimizer.replicate()
  del model  # don't keep a copy of the initial model
  return optimizer
def create_optimizer(config, model):
  """Create an Adam/LAMB MultiOptimizer without weight decay on 'bias' params.

  Args:
    config: object with `optimizer` ('adam' or 'lamb') and `learning_rate`.
    model: model/params handed to `optimizer_def.create`.

  Returns:
    The created optimizer.

  Raises:
    ValueError: if `config.optimizer` is neither 'adam' nor 'lamb'.
  """
  if config.optimizer == 'adam':
    optimizer_cls = optim.Adam
  elif config.optimizer == 'lamb':
    optimizer_cls = optim.LAMB
  else:
    # f-string so the offending value actually appears in the message.
    raise ValueError(f'Unsupported value for optimizer: {config.optimizer}')
  common_kwargs = dict(
      learning_rate=config.learning_rate,
      beta1=0.9,
      beta2=0.999,
      eps=1e-6,
  )
  optimizer_decay_def = optimizer_cls(weight_decay=0.01, **common_kwargs)
  optimizer_no_decay_def = optimizer_cls(weight_decay=0.0, **common_kwargs)
  decay = optim.ModelParamTraversal(lambda path, _: 'bias' not in path)
  no_decay = optim.ModelParamTraversal(lambda path, _: 'bias' in path)
  optimizer_def = optim.MultiOptimizer(
      (decay, optimizer_decay_def), (no_decay, optimizer_no_decay_def))
  optimizer = optimizer_def.create(model)
  return optimizer
def rescale_layers(params, layer_rescale_factors):
  """Rescales the model variables by given multiplicative factors.

  Args:
    params: a dict of trainable model parameters.
    layer_rescale_factors: A dictionary mapping flat keys to a multiplicative
      rescale factor. The corresponding params will be changed from x -> a * x
      for rescale factor a. The keys of the dictionary must be of the form
      described in the flatten_keys documentation.

  Returns:
    A new params dict with the corresponding params rescaled.
  """
  all_keys = flatten_dict(params).keys()
  logging.info('All keys:')
  for key in all_keys:
    logging.info(key)

  for key, factor in layer_rescale_factors.items():
    if key not in all_keys:
      # Assuming flatten_dict keys share the traversal path format, such a key
      # matches nothing and the rescale is a no-op; make that visible.
      logging.warning('Params do not have key %s; it may not be rescaled.',
                      key)
    logging.info('Rescaling %s by factor %f', key, factor)
    # Bind loop values as defaults so each closure captures this iteration.
    traversal = optim.ModelParamTraversal(
        lambda path, _, wanted=key: path == wanted)
    params = traversal.update(lambda x, scale=factor: x * scale, params)
  return params
def main(argv):
  """Trains a WMT Transformer, optionally with dynamic data selection.

  With `FLAGS.dynamic`, training data comes from bucketed datasets whose
  sampling distribution is periodically re-scored against the auxiliary eval
  dataset (every `FLAGS.resample_freq` steps).
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Make sure tf does not allocate gpu memory.
  tf.config.experimental.set_visible_devices([], 'GPU')

  if FLAGS.jax_backend_target:
    jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
    jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.host_id() == 0:
    tf.io.gfile.makedirs(FLAGS.model_dir)

  if FLAGS.batch_size % n_devices:
    raise ValueError('Batch size must be divisible by the number of devices')

  vocab_path = FLAGS.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')
  tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  if FLAGS.dynamic:
    # Pass the resolved `vocab_path`; the raw flag may be None, in which case
    # the default location computed above must be used, as for the static
    # pipeline.
    (train_ds_mgr, eval_ds, predict_ds,
     encoder) = input_pipeline.get_dynamic_datasets(
         dataset_name=FLAGS.dataset_name,
         eval_dataset_name=FLAGS.eval_dataset_name,
         shard_idx=jax.host_id(),
         shard_count=jax.host_count(),
         data_dir=FLAGS.data_dir,
         vocab_path=vocab_path,
         target_vocab_size=FLAGS.vocab_size,
         batch_size=FLAGS.batch_size,
         max_length=FLAGS.max_target_length,
         max_eval_length=FLAGS.max_eval_target_length,
         paracrawl_size=FLAGS.paracrawl_size,
         is_scores_path=FLAGS.is_scores_path,
         num_buckets=FLAGS.num_data_buckets)
    if FLAGS.static:
      # Fixed, user-supplied sampling weights; disables dynamic resampling.
      weights = np.array([float(w) for w in FLAGS.static.split(',')])
      assert len(weights) == FLAGS.num_data_buckets
      train_ds = train_ds_mgr.sampled_dataset(weights)
      FLAGS.dynamic = False
    else:
      init_dist = np.zeros(FLAGS.num_data_buckets)
      if FLAGS.data_selection_size < FLAGS.num_data_buckets:
        init_dist[range(FLAGS.data_selection_size)] = 1.0
        train_ds = train_ds_mgr.sampled_dataset(init_dist)
      else:
        train_ds = build_split(train_ds_mgr, 1.0)
  else:
    train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
        dataset_name=FLAGS.dataset_name,
        eval_dataset_name=FLAGS.eval_dataset_name,
        shard_idx=jax.host_id(),
        shard_count=jax.host_count(),
        data_dir=FLAGS.data_dir,
        vocab_path=vocab_path,
        target_vocab_size=FLAGS.vocab_size,
        batch_size=FLAGS.batch_size,
        max_length=FLAGS.max_target_length,
        max_eval_length=FLAGS.max_eval_target_length,
        paracrawl_size=FLAGS.paracrawl_size,
        is_scores_path=FLAGS.is_scores_path,
        num_to_keep=FLAGS.data_selection_size,
        pseudo_path=FLAGS.pseudo_path,
        repeat_count=FLAGS.repeat_count,
        newscommentary_size=FLAGS.newscommentary_size)

  # Dynamic resampling scores buckets against `aux_eval_ds`, which only exists
  # when an auxiliary eval dataset is configured. Fail now rather than with a
  # NameError at the first training step.
  if FLAGS.dynamic and not FLAGS.aux_eval_dataset:
    raise ValueError('Dynamic data selection requires --aux_eval_dataset.')

  if FLAGS.aux_eval_dataset:
    aux_datasets = []
    aux_names = FLAGS.aux_eval_dataset.split(',')
    for name in aux_names:
      _, aux_eval_ds, _, _ = input_pipeline.get_wmt_datasets(
          dataset_name=name,
          eval_dataset_name=None,
          shard_idx=jax.host_id(),
          shard_count=jax.host_count(),
          data_dir=FLAGS.data_dir,
          vocab_path=vocab_path,
          target_vocab_size=FLAGS.vocab_size,
          batch_size=FLAGS.batch_size,
          max_length=FLAGS.max_target_length,
          max_eval_length=FLAGS.max_eval_target_length,
          paracrawl_size=FLAGS.paracrawl_size,
          is_scores_path=FLAGS.is_scores_path,
          num_to_keep=FLAGS.data_selection_size,
          pseudo_path=FLAGS.pseudo_path,
          repeat_count=FLAGS.repeat_count,
          newscommentary_size=FLAGS.newscommentary_size)
      aux_datasets.append(aux_eval_ds)
    # NOTE: after this loop `aux_eval_ds` is the last auxiliary dataset; it is
    # the one used for dynamic bucket scoring below.

  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

  def decode_tokens(toks):
    valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
    return encoder.detokenize(valid_toks).numpy().decode('utf-8')

  logging.info('Initializing model, optimizer, and step functions.')

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = models.TransformerConfig(
      vocab_size=vocab_size,
      output_vocab_size=vocab_size,
      share_embeddings=FLAGS.share_embeddings,
      logits_via_embedding=FLAGS.logits_via_embedding,
      dtype=jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32,
      emb_dim=FLAGS.emb_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.qkv_dim,
      mlp_dim=FLAGS.mlp_dim,
      max_len=max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
      dropout_rate=FLAGS.dropout_rate,
      attention_dropout_rate=FLAGS.attention_dropout_rate,
      deterministic=False,
      decode=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(deterministic=True, decode=True)

  start_step = 0
  rng = jax.random.PRNGKey(FLAGS.random_seed)
  rng, init_rng = jax.random.split(rng)
  # It's possible that is supposed to be per device batch size
  input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
  target_shape = (FLAGS.batch_size, FLAGS.max_target_length)

  m = models.Transformer(eval_config)
  initial_variables = jax.jit(m.init)(init_rng,
                                      jnp.ones(input_shape, jnp.float32),
                                      jnp.ones(target_shape, jnp.float32))

  # apply an optimizer to this tree
  optimizer_def = optim.Adam(
      FLAGS.learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])

  # We access model params only from optimizer below via optimizer.target.
  del initial_variables

  if FLAGS.restore_checkpoints:
    logging.info('Restoring checkpoint.')
    # If we have a pretrained model, use that. Else, just continue where
    # we left off.
    model_path = (FLAGS.pretrained_model_dir
                  if FLAGS.pretrained_model_dir else FLAGS.model_dir)
    optimizer = checkpoints.restore_checkpoint(model_path, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)

  if FLAGS.adapter != NONE:
    # Only params whose path contains FLAGS.adapter are optimized.
    adapter = optim.ModelParamTraversal(lambda path, _: FLAGS.adapter in path)
    optimizer = optimizer_def.create(optimizer.target, focus=adapter)

  writer = metric_writers.create_default_writer(
      FLAGS.model_dir, just_logging=jax.process_index() > 0)

  flag_key = [
      k for k in FLAGS.flags_by_module_dict().keys() if 'wmt.par' in k
  ]
  if flag_key:
    flag_key = flag_key[0]
    local_flags = {
        f.name: f.value for f in FLAGS.flags_by_module_dict()[flag_key]
    }
    writer.write_hparams(local_flags)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  if FLAGS.adapter != NONE:
    learning_rate_fn = common.create_learning_rate_scheduler(
        factors='constant',
        base_learning_rate=FLAGS.learning_rate,
        warmup_steps=FLAGS.warmup_steps)
  else:
    learning_rate_fn = common.create_learning_rate_scheduler(
        base_learning_rate=FLAGS.learning_rate,
        warmup_steps=FLAGS.warmup_steps,
        steps_per_cycle=FLAGS.steps_per_cycle,
        init_step=start_step,
        finetune_lr=FLAGS.finetune_lr)

  # compile multidevice versions of train/eval/predict step and cache init fn.
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          config=train_config,
          learning_rate_fn=learning_rate_fn,
          label_smoothing=FLAGS.label_smoothing),
      axis_name='batch',
      donate_argnums=(0,))  # pytype: disable=wrong-arg-types
  p_eval_step = jax.pmap(
      functools.partial(eval_step, config=eval_config), axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          max_decode_len=FLAGS.max_predict_length,
          config=predict_config),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step, config=predict_config, beam_size=FLAGS.beam_size),
      axis_name='batch',
      static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

  p_get_diag_grads = jax.pmap(
      functools.partial(get_diag_grads, config=eval_config), axis_name='batch')

  p_get_bucket_score = jax.pmap(
      functools.partial(get_diag_score, strategy=FLAGS.strategy),
      axis_name='batch')

  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap"d training update for performance.
  dropout_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  logging.info('Starting training loop.')
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=FLAGS.num_train_steps, writer=writer)
  if jax.process_index() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(logdir=FLAGS.model_dir, num_profile_steps=5)
    ]
  train_metrics = []
  total_steps = start_step + FLAGS.num_train_steps
  # Infinite sentinels: no hard-coded ceiling on which eval losses can trigger
  # a checkpoint save.
  best_eval_loss = float('inf')
  curr_eval_loss = float('inf')
  with metric_writers.ensure_flushes(writer):
    for step in range(start_step, total_steps):
      is_last_step = step == total_steps - 1

      if FLAGS.dynamic and ((step - start_step) % FLAGS.resample_freq == 0):
        # Dynamic macro: use gradient alignment to score different ratios
        # of top k vs bottom N-k bins
        if FLAGS.macro:
          train_iter = get_macro_distribution(p_get_diag_grads,
                                              p_get_bucket_score, aux_eval_ds,
                                              train_ds_mgr, optimizer, eval_ds)
        else:
          # Use gradient alignment to score bins
          # take the top k bins and sample uniformly from them.
          raw_distribution = get_new_distribution(p_get_diag_grads,
                                                  p_get_bucket_score,
                                                  aux_eval_ds, train_ds_mgr,
                                                  optimizer, eval_ds)
          logging.info(raw_distribution)
          selected = np.argsort(
              raw_distribution)[::-1][:FLAGS.data_selection_size]
          # One weight per data bucket (was hard-coded to 100 buckets).
          new_distribution = np.zeros(FLAGS.num_data_buckets)
          new_distribution[selected] = 1.0
          logging.info(new_distribution)
          train_ds = train_ds_mgr.sampled_dataset(new_distribution)
          train_iter = iter(train_ds)

      # Shard data to devices and do a training step.
      with jax.profiler.StepTraceAnnotation('train', step_num=step):
        try:
          batch = common_utils.shard(
              jax.tree_map(np.asarray, next(train_iter)))
        except StopIteration:
          is_last_step = True
        else:
          optimizer, metrics = p_train_step(
              optimizer, batch, dropout_rng=dropout_rngs)
          train_metrics.append(metrics)

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step)
      for h in hooks:
        h(step)

      # Periodic metric handling.
      if (step - start_step) % FLAGS.eval_frequency == 0 or is_last_step:
        # `train_metrics` is empty when the data runs out right after a
        # previous summary; get_metrics cannot stack an empty list.
        if train_metrics:
          with report_progress.timed('training_metrics'):
            logging.info('Gathering training metrics.')
            train_metrics = common_utils.get_metrics(train_metrics)
            lr = train_metrics.pop('learning_rate').mean()
            metrics_sums = jax.tree_map(jnp.sum, train_metrics)
            denominator = metrics_sums.pop('denominator')
            summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
            summary['learning_rate'] = lr
            summary = {'train_' + k: v for k, v in summary.items()}
            writer.write_scalars(step, summary)
            train_metrics = []

        with report_progress.timed('eval'):
          eval_results = evaluate(
              p_eval_step=p_eval_step,
              target=optimizer.target,
              eval_ds=eval_ds,
              num_eval_steps=FLAGS.num_eval_steps)
          curr_eval_loss = eval_results['loss']
          writer.write_scalars(
              step, {'eval_' + k: v for k, v in eval_results.items()})

        if FLAGS.aux_eval_dataset:
          # Distinct loop name so `aux_eval_ds` (used for bucket scoring) is
          # not rebound here.
          for aux_i, aux_ds in enumerate(aux_datasets):
            with report_progress.timed('aux_eval'):
              eval_results = evaluate(
                  p_eval_step=p_eval_step,
                  target=optimizer.target,
                  eval_ds=aux_ds,
                  num_eval_steps=FLAGS.num_eval_steps)
              writer.write_scalars(
                  step, {
                      'aux' + str(aux_i) + '_eval_' + k: v
                      for k, v in eval_results.items()
                  })

        if FLAGS.compute_bleu:
          with report_progress.timed('translate_and_bleu'):
            exemplars, bleu_score = translate_and_calculate_bleu(
                p_pred_step=p_pred_step,
                p_init_cache=p_init_cache,
                target=optimizer.target,
                predict_ds=predict_ds,
                decode_tokens=decode_tokens,
                max_predict_length=FLAGS.max_predict_length)
            writer.write_scalars(step, {'bleu': bleu_score})
            writer.write_texts(step, {'samples': exemplars})

      # Save a checkpoint on one host after every checkpoint_freq steps.
      save_checkpoint = ((step - start_step) % FLAGS.checkpoint_freq == 0 or
                         is_last_step)
      if FLAGS.save_checkpoints and save_checkpoint and jax.host_id() == 0:
        if curr_eval_loss < best_eval_loss:  # only save better checkpoints
          best_eval_loss = curr_eval_loss
          with report_progress.timed('checkpoint'):
            checkpoints.save_checkpoint(
                FLAGS.model_dir, jax_utils.unreplicate(optimizer),
                step, keep=FLAGS.chkpts_to_keep, overwrite=True)

      if is_last_step:
        break
def test_only_works_on_models(self):
  """Iterating a traversal over a plain dict raises ValueError."""
  accept_all = optim.ModelParamTraversal(lambda *_: True)
  with self.assertRaises(ValueError):
    list(accept_all.iterate({}))
def main(argv):
  """Runs the MLPerf ResNet training/eval loop with TPU infeed.

  Sets up MLPerf logging, derives batch sizes, epochs and optimizer
  hyperparameters (from flags, or batch-size-dependent defaults), builds a
  MultiOptimizer, then alternates training loops (on-device `while_loop` or
  per-step host loop) with evaluation. Input is fed to devices through
  `device.transfer_to_infeed` from a background thread pool.
  """
  del argv
  # BEGIN GOOGLE-INTERNAL
  xm.setup_work_unit()
  # END GOOGLE-INTERNAL

  tf.enable_v2_behavior()
  init_mllogger()

  mllogger.event('cache_clear')
  mllogger.start('init_start')
  mllogger.event('submission_org', 'Google')
  mllogger.event('submission_platform', 'TPUv3-{}'.format(jax.device_count()))
  mllogger.event('submission_division', 'closed')
  mllogger.event('submission_status', 'research')
  mllogger.event('submission_benchmark', 'resnet')
  mllogger.event('train_samples', input_pipeline.TRAIN_IMAGES)
  mllogger.event('eval_samples', input_pipeline.EVAL_IMAGES)

  # Only host 0 creates the summary writer and the summary thread.
  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(FLAGS.output_dir)
    # Write summaries in background thread to avoid blocking on device sync
    summary_thread = thread.ThreadPoolExecutor(1, 'summary')
  # Infeed is currently synchronous, so do it in a background thread too
  infeed_pool = thread.ThreadPoolExecutor(jax.local_device_count(), 'infeed')

  if FLAGS.seed is not None:
    seed = FLAGS.seed
  else:
    # Host 0 contributes a time-based seed and all other hosts contribute 0;
    # presumably per_host_sum_pmap sums across hosts so every host gets host
    # 0's value (confirm against its definition).
    seed = np.uint32(time.time() if jax.host_id() == 0 else 0)
    seed = per_host_sum_pmap(seed)

  mllogger.event('seed', int(seed))
  key = random.PRNGKey(seed)

  # batch_size == -1 selects a default that depends on the device count.
  batch_size = FLAGS.batch_size
  if batch_size == -1:
    if jax.device_count() > 4096:
      batch_size = 65536
    else:
      batch_size = min(128 * jax.device_count(), 32768)
  mllogger.event('global_batch_size', batch_size)
  eval_batch_size = min(input_pipeline.EVAL_IMAGES, 256 * jax.device_count())
  device_batch_size = batch_size // jax.device_count()
  device_eval_batch_size = int(math.ceil(eval_batch_size / jax.device_count()))

  model_dtype = jnp.bfloat16 if FLAGS.bfloat16 else jnp.float32
  input_dtype = tf.bfloat16 if FLAGS.bfloat16 else tf.float32

  # Default epoch budget grows with the global batch size.
  num_epochs = FLAGS.num_epochs
  if num_epochs is None:
    if batch_size < 32768:
      num_epochs = 56
    elif batch_size < 65536:
      num_epochs = 64
    else:
      num_epochs = 92

  steps_per_epoch = input_pipeline.TRAIN_IMAGES / batch_size
  # match TF submission behavior (round steps per loop up)
  steps_per_loop = int(math.ceil(steps_per_epoch * FLAGS.epochs_per_loop))
  # also apply rounding loop up to next step to "epochs" in LR schedule
  # (after this, steps_per_epoch * FLAGS.epochs_per_loop == steps_per_loop).
  steps_per_epoch *= steps_per_loop / (steps_per_epoch * FLAGS.epochs_per_loop)

  steps_per_eval = int(math.ceil(input_pipeline.EVAL_IMAGES / eval_batch_size))

  # Linear learning-rate scaling relative to a batch size of 256.
  base_learning_rate = FLAGS.learning_rate * batch_size / 256.
  beta = FLAGS.momentum
  if beta is None:
    if batch_size < 32768:
      beta = 0.9
    elif batch_size < 65536:
      beta = 0.929
    else:
      beta = 0.9537213777059405
  weight_decay = FLAGS.weight_decay
  if weight_decay is None:
    weight_decay = 2e-4 if batch_size < 32768 else 1e-4

  space_to_depth = FLAGS.space_to_depth
  if space_to_depth is None:
    space_to_depth = device_batch_size <= 8

  image_format = FLAGS.image_format
  if image_format is None:
    if space_to_depth and device_batch_size <= 8:
      image_format = 'HWNC'
    else:
      image_format = 'HWCN'

  # Per-device input shapes start as (N, H, W, C); space-to-depth halves H
  # and W and turns 3 channels into 12. They are then permuted to the
  # selected layout.
  image_size = input_pipeline.IMAGE_SIZE
  if space_to_depth:
    train_input_shape = (
        device_batch_size, image_size // 2, image_size // 2, 12)
    eval_input_shape = (
        device_eval_batch_size, image_size // 2, image_size // 2, 12)
  else:
    train_input_shape = (device_batch_size, image_size, image_size, 3)
    eval_input_shape = (device_eval_batch_size, image_size, image_size, 3)
  if image_format == 'HWCN':
    train_input_shape = tuple(train_input_shape[i] for i in [1, 2, 3, 0])
    eval_input_shape = tuple(eval_input_shape[i] for i in [1, 2, 3, 0])
  elif image_format == 'HWNC':
    train_input_shape = tuple(train_input_shape[i] for i in [1, 2, 0, 3])
    eval_input_shape = tuple(eval_input_shape[i] for i in [1, 2, 0, 3])

  model, state = create_model(key, device_batch_size, image_size, model_dtype,
                              space_to_depth)

  # Weight params get weight_opt_def (with weight decay); bias/scale params
  # get other_opt_def (weight_decay=0).
  if FLAGS.lars:
    mllogger.event('opt_name', 'lars')
    mllogger.event('lars_opt_weight_decay', weight_decay)
    mllogger.event('lars_opt_momentum', beta)
    mllogger.event('lars_epsilon', 0)
    weight_opt_def = optim.LARS(
        base_learning_rate, beta, weight_decay=weight_decay)
    other_opt_def = optim.Momentum(
        base_learning_rate, beta, weight_decay=0, nesterov=False)
    learning_rate_fn = polynomial_learning_rate_fn(
        batch_size, steps_per_epoch, num_epochs)
  else:
    mllogger.event('opt_name', 'sgd')
    mllogger.event('sgd_opt_momentum', beta)
    weight_opt_def = optim.Momentum(
        base_learning_rate, beta, weight_decay=weight_decay, nesterov=True)
    other_opt_def = optim.Momentum(
        base_learning_rate, beta, weight_decay=0, nesterov=True)
    learning_rate_fn = piecewise_learning_rate_fn(
        base_learning_rate, steps_per_epoch, num_epochs)

  def filter_weights(key, _):
    return 'bias' not in key and 'scale' not in key

  def filter_other(key, _):
    return 'bias' in key or 'scale' in key

  weight_traversal = optim.ModelParamTraversal(filter_weights)
  other_traversal = optim.ModelParamTraversal(filter_other)
  optimizer_def = optim.MultiOptimizer((weight_traversal, weight_opt_def),
                                       (other_traversal, other_opt_def))
  optimizer = optimizer_def.create(model)
  del model  # do not keep a copy of the initial model

  optimizer = broadcast(optimizer)
  state = broadcast(state)
  empty_metrics = broadcast({'samples': 0, 'loss': 0., 'accuracy': 0})

  p_allreduce_metrics = jax.pmap(allreduce_metrics, axis_name='batch')

  p_sync_batchnorm_stats = jax.pmap(sync_batchnorm_stats, axis_name='batch')

  # Host-loop train step: reads one (images, labels) batch from infeed and
  # runs a single train_step.
  def host_loop_train_step(optimizer, state, metrics):
    token = lax.create_token(optimizer.state[0].step)
    batch, token = lax.infeed(token, shape=(jax.ShapedArray(
        train_input_shape, model_dtype), jax.ShapedArray(
            (device_batch_size, ), jnp.int32)))
    optimizer, state, metrics = train_step(optimizer, state, batch, metrics,
                                           learning_rate_fn, image_format,
                                           space_to_depth)
    return optimizer, state, metrics

  p_host_loop_train_step = jax.pmap(
      host_loop_train_step, axis_name='batch', in_axes=(None, 0, 0))

  # Host-loop eval step: reads one eval batch from infeed and accumulates
  # metrics.
  def host_loop_eval_step(model, state, metrics):
    token = lax.create_token(metrics['samples'])
    batch, token = lax.infeed(
        token, shape=(jax.ShapedArray(eval_input_shape, model_dtype),
                      jax.ShapedArray((device_eval_batch_size, ), jnp.int32)))
    metrics = eval_step(
        model, state, batch, metrics, image_format, space_to_depth)
    return metrics

  p_host_loop_eval_step = jax.pmap(
      host_loop_eval_step, axis_name='batch', in_axes=(None, None, 0))

  # On-device loop: keep stepping while the step counter is still inside
  # loop index `loop` (i.e. step // steps_per_loop == loop).
  def device_train_loop_cond(args):
    _, _, _, _, step, loop = args
    return step // steps_per_loop == loop

  def device_train_loop_body(args):
    optimizer, state, metrics, token, step, loop = args
    batch, token = lax.infeed(token, shape=(jax.ShapedArray(
        train_input_shape, model_dtype), jax.ShapedArray(
            (device_batch_size, ), jnp.int32)))
    optimizer, state, metrics = train_step(optimizer, state, batch, metrics,
                                           learning_rate_fn, image_format,
                                           space_to_depth)
    step += 1
    return optimizer, state, metrics, token, step, loop

  def device_train_loop(optimizer, state, metrics, step, loop):
    token = lax.create_token(step)
    optimizer, state, metrics, _, step, _ = lax.while_loop(
        device_train_loop_cond,
        device_train_loop_body,
        (optimizer, state, metrics, token, step, loop))
    # Batch-norm stats and metrics are reduced once per loop, after the
    # while_loop finishes.
    state = sync_batchnorm_stats(state)
    metrics = allreduce_metrics(metrics)
    return optimizer, state, metrics, step

  p_train_loop = jax.pmap(
      device_train_loop, axis_name='batch', in_axes=(None, None, 0, None, None))

  # BEGIN GOOGLE-INTERNAL
  def maybe_start_xprof(seconds):
    if jax.host_id() == 0 and FLAGS.xprof:
      xprof = xprof_session.XprofSession()
      xprof.start_session('REDACTED', True, 2)
      def sleep_and_end_xprof():
        time.sleep(seconds)
        logging.info(
            'Xprof URL: %s',
            xprof.end_session_and_get_url(
                tag='flax resnet, {} devices, batch {} per device'.format(
                    jax.device_count(), device_batch_size)))
      thread.ThreadPoolExecutor(1, 'xprof').submit(sleep_and_end_xprof)
  # END GOOGLE-INTERNAL

  # Optional warm-up: run each pmapped function once on dummy (zero) infeed
  # data so compilation happens before the timed run starts.
  if FLAGS.precompile:
    logging.info('precompiling step/loop functions')
    if FLAGS.device_loop:
      # the device training loop condition will immediately be false
      p_train_loop(unbroadcast(optimizer), unbroadcast(state), empty_metrics,
                   jnp.array(0, dtype=jnp.int32), 1)
    else:
      for device in jax.local_devices():
        images = np.zeros(train_input_shape, model_dtype)
        labels = np.zeros((device_batch_size, ), np.int32)
        infeed_pool.submit(
            partial(device.transfer_to_infeed, (images, labels)))
      p_host_loop_train_step(unbroadcast(optimizer), state, empty_metrics)
      p_sync_batchnorm_stats(state)
    for device in jax.local_devices():
      images = np.zeros(eval_input_shape, model_dtype)
      labels = np.zeros((device_eval_batch_size, ), np.int32)
      infeed_pool.submit(
          partial(device.transfer_to_infeed, (images, labels)))
    p_host_loop_eval_step(
        unbroadcast(optimizer.target), unbroadcast(state), empty_metrics)
    p_allreduce_metrics(empty_metrics)['accuracy'].block_until_ready()
    logging.info('finished precompiling')

  # BEGIN GOOGLE-INTERNAL
  maybe_start_xprof(20)
  # END GOOGLE-INTERNAL
  if not FLAGS.fake_data:
    logging.info('constructing datasets')
    # pylint: disable=g-complex-comprehension
    train_ds, eval_ds = [
        input_pipeline.load_split(
            device_batch_size if train else device_eval_batch_size,
            dtype=input_dtype,
            train=train,
            image_format=image_format,
            space_to_depth=space_to_depth,
            cache_uncompressed=jax.device_count() > 64)
        for train in (True, False)
    ]
    logging.info('constructing dataset iterators')
    train_iter = iter(train_ds)
    eval_iter = iter(eval_ds)

  local_devices = jax.local_devices()
  host_step, device_step = 0, broadcast(0)
  mllogger.end('init_stop')
  mllogger.start('run_start')
  mllogger.start('block_start',
                 metadata={'first_epoch_num': 1,
                           'epoch_count': FLAGS.epochs_per_loop})
  # NOTE(review): runs two loops beyond ceil(num_epochs / epochs_per_loop);
  # the reason for the extra margin is not visible here — confirm.
  for loop in range(int(math.ceil(num_epochs / FLAGS.epochs_per_loop)) + 2):
    # BEGIN GOOGLE-INTERNAL
    if loop == 10: maybe_start_xprof(1)
    # END GOOGLE-INTERNAL
    metrics = empty_metrics
    if FLAGS.device_loop:
      optimizer, state, metrics, device_step = p_train_loop(
          unbroadcast(optimizer), unbroadcast(state), metrics,
          unbroadcast(device_step), loop)
    # The host feeds one batch per device per step for every step that
    # belongs to this loop index.
    while int(host_step // steps_per_loop) == loop:
      if not FLAGS.device_loop:
        optimizer, state, metrics = p_host_loop_train_step(
            unbroadcast(optimizer), state, metrics)
      # Back-pressure: wait while more than 100 infeed transfers are queued.
      # pylint: disable=protected-access
      while infeed_pool._work_queue.qsize() > 100:
        time.sleep(0.01)
      for device in local_devices:
        if FLAGS.fake_data:
          images = np.zeros(train_input_shape, model_dtype)
          labels = np.zeros((device_batch_size, ), np.int32)
        else:
          # pylint: disable=protected-access
          images, labels = jax.tree_map(lambda x: x._numpy(), next(train_iter))
        assert images.shape == train_input_shape and labels.dtype == jnp.int32
        infeed_pool.submit(
            partial(device.transfer_to_infeed, (images, labels)))
      host_step += 1
    epoch = (loop + 1) * FLAGS.epochs_per_loop
    if FLAGS.train_metrics:
      # The device loop already all-reduces its metrics internally.
      if not FLAGS.device_loop:
        metrics = p_allreduce_metrics(metrics)
      if jax.host_id() == 0:
        summary_thread.submit(partial(
            write_summary, summary_writer, metrics, 'train', epoch))
    if not FLAGS.device_loop:
      state = p_sync_batchnorm_stats(state)
    metrics = empty_metrics
    for _ in range(steps_per_eval):
      metrics = p_host_loop_eval_step(
          unbroadcast(optimizer.target), unbroadcast(state), metrics)
      for device in local_devices:
        if FLAGS.fake_data:
          images = np.zeros(eval_input_shape, model_dtype)
          labels = np.zeros((device_eval_batch_size, ), np.int32)
        else:
          # pylint: disable=protected-access
          images, labels = jax.tree_map(lambda x: x._numpy(), next(eval_iter))
        assert images.shape == eval_input_shape and labels.dtype == jnp.int32, \
            'images.shape={}'.format(images.shape)
        infeed_pool.submit(
            partial(device.transfer_to_infeed, (images, labels)))
    metrics = p_allreduce_metrics(metrics)
    if jax.host_id() == 0:
      summary_thread.submit(partial(
          write_summary, summary_writer, metrics, 'eval', epoch))
  # Wait until computations are done before exiting
  p_allreduce_metrics(metrics)['accuracy'].block_until_ready()
  if jax.host_id() == 0:
    summary_thread.shutdown()
    # DONE is a module-level flag defined outside this function; if it is
    # still falsy, the run is reported to MLPerf as aborted.
    if not DONE:
      mllogger.end('run_stop', metadata={'status': 'aborted'})
def compute_is_scores(filename):
  """Compute IS scores for training data.

  Restores a trained Transformer from a checkpoint, runs `eval_for_is_step`
  over every batch of the IS dataset, and writes one CSV row per batch:
  per-example losses go to `<FLAGS.is_save_path>/<filename>.txt` and the
  matching lengths to `<FLAGS.is_save_path>/<filename>-lengths.txt`. Only
  host 0 writes rows.

  Args:
    filename: Base name (without extension) of the two output files.

  Raises:
    ValueError: If `FLAGS.batch_size` is not divisible by the number of local
      devices.
    RuntimeError: If `FLAGS.restore_checkpoints` is false; scoring needs a
      trained model.
  """
  # Make sure tf does not allocate gpu memory.
  tf.config.experimental.set_visible_devices([], 'GPU')

  if FLAGS.jax_backend_target:
    jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
    jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.host_id() == 0:
    tf.io.gfile.makedirs(FLAGS.model_dir)

  if FLAGS.batch_size % n_devices:
    raise ValueError('Batch size must be divisible by the number of devices')

  vocab_path = FLAGS.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')
  tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

  # Load Dataset
  print('Loading data')
  logging.info('Initializing dataset.')
  train_ds, encoder = input_pipeline.get_wmt_is_datasets(
      n_devices=n_devices,
      dataset_name=FLAGS.dataset_name,
      shard_idx=jax.host_id(),
      shard_count=jax.host_count(),
      data_dir=FLAGS.data_dir,
      vocab_path=vocab_path,
      target_vocab_size=FLAGS.vocab_size,
      batch_size=FLAGS.batch_size,
      max_length=FLAGS.max_target_length,
      paracrawl_size=FLAGS.paracrawl_size)
  print('Datasets created')

  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  print('data iterators created')

  logging.info('Initializing model, optimizer, and step functions.')
  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  eval_config = models.TransformerConfig(
      vocab_size=vocab_size,
      output_vocab_size=vocab_size,
      share_embeddings=FLAGS.share_embeddings,
      logits_via_embedding=FLAGS.logits_via_embedding,
      dtype=jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32,
      emb_dim=FLAGS.emb_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.qkv_dim,
      mlp_dim=FLAGS.mlp_dim,
      max_len=max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
      dropout_rate=FLAGS.dropout_rate,
      attention_dropout_rate=FLAGS.attention_dropout_rate,
      deterministic=True,
      decode=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))

  start_step = 0
  _, init_rng = jax.random.split(jax.random.PRNGKey(FLAGS.random_seed))
  # NOTE(review): init shapes use the global FLAGS.batch_size; a per-device
  # batch size may have been intended. Only shapes matter for init here.
  input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
  target_shape = (FLAGS.batch_size, FLAGS.max_target_length)

  m = models.Transformer(eval_config)
  initial_variables = jax.jit(m.init)(init_rng,
                                      jnp.ones(input_shape, jnp.float32),
                                      jnp.ones(target_shape, jnp.float32))

  # Apply an optimizer to this tree. It is only used as a container for the
  # checkpointed parameters; no optimization step is taken while scoring.
  optimizer_def = optim.Adam(
      FLAGS.learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])

  # We access model params only from optimizer below via optimizer.target.
  del initial_variables

  if not FLAGS.restore_checkpoints:
    raise RuntimeError('Must restore checkpoint for IS')

  logging.info('Restoring checkpoint.')
  # If we have a pretrained model, use that. Else, continue where we left off.
  model_path = FLAGS.pretrained_model_dir or FLAGS.model_dir
  # Focus traversal selecting the adapter parameters (paths containing
  # FLAGS.adapter).
  adapter = optim.ModelParamTraversal(lambda path, _: FLAGS.adapter in path)
  # When loading a checkpoint trained with adapters (ie. frozen weights)
  # restoring from the base optimizer fails. We catch this error and create
  # the optimizer with frozen weights.
  try:
    optimizer = checkpoints.restore_checkpoint(model_path, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)
  except ValueError:
    optimizer = optimizer_def.create(optimizer.target, focus=adapter)
    optimizer = checkpoints.restore_checkpoint(model_path, optimizer)
    start_step = int(optimizer.state[0].step)
  logging.info('Restored checkpoint from %s at step %d.', model_path,
               start_step)

  # `create(..., focus=...)` returns an `Optimizer` whose `optimizer_def` is a
  # `MultiOptimizer`; the `Optimizer` object itself never is one, so the def
  # is what has to be checked to tell whether the adapter focus is in place.
  if (FLAGS.adapter != NONE and
      not isinstance(optimizer.optimizer_def, optim.MultiOptimizer)):
    optimizer = optimizer_def.create(optimizer.target, focus=adapter)
  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  p_eval_step = jax.pmap(
      functools.partial(eval_for_is_step, config=eval_config),
      axis_name='batch')

  logging.info('Start scoring loop.')
  t_loop_start = time.time()

  # Eval Metrics
  logging.info('Gathering evaluation metrics.')
  lengths_file = FLAGS.is_save_path + '/' + filename + '-lengths.txt'
  scores_file = FLAGS.is_save_path + '/' + filename + '.txt'
  # NOTE(review): every host opens (and truncates) the same two paths, but
  # only host 0 writes rows; confirm this is only run single-host.
  # Both files are managed by `with` so they are flushed and closed even if
  # scoring fails part-way through.
  with tf.io.gfile.GFile(lengths_file, 'w') as length_fp:
    lengths_writer = csv.writer(length_fp)
    with tf.io.gfile.GFile(scores_file, 'w') as fp:
      writer = csv.writer(fp)

      for batch_idx, eval_batch in enumerate(train_iter):
        eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
        cur_pred_batch_size = eval_batch['inputs'].shape[0]
        # Pad a ragged final batch up to a multiple of the device count so
        # it can be sharded.
        if cur_pred_batch_size % n_devices:
          padded_size = int(
              np.ceil(cur_pred_batch_size / n_devices) * n_devices)
          eval_batch = jax.tree_map(
              lambda x: common.pad_examples(x, padded_size), eval_batch)  # pylint: disable=cell-var-from-loop
        eval_batch = common_utils.shard(eval_batch)
        losses, lengths = p_eval_step(optimizer.target, eval_batch)
        if jax.host_id() == 0:
          # Drop the padded rows. Without padding, this slice keeps every
          # row, so one code path covers both cases.
          losses = common.tohost(losses)[:cur_pred_batch_size]
          lengths = common.tohost(lengths)[:cur_pred_batch_size]
          writer.writerow(losses)
          lengths_writer.writerow(lengths)

        if batch_idx % 500 == 0:
          print('Batch', batch_idx)
          print(time.time() - t_loop_start)
def main(_):
  """Trains and evaluates a DecomposeExpandingLayerTransformer.

  Builds the token tables and tf.data pipelines, initializes (or restores)
  the model and Adam optimizer, optionally sets up a separate pretraining
  optimizer, then runs the training loop. Checkpoints are written every
  FLAGS.checkpoint_freq steps. Every FLAGS.log_freq steps the loop logs
  training metrics, evaluation metrics and beam-search prediction scores.
  """
  tf.enable_v2_behavior()

  tf.random.set_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  random.seed(FLAGS.seed)

  if not gfile.isdir(FLAGS.save_dir):
    gfile.mkdir(FLAGS.save_dir)

  hparam_str_dict = dict(seed=FLAGS.seed, lr=FLAGS.lr)
  # Get hyperparameters. Values from FLAGS.xm_parameters never override the
  # seed/lr entries already present.
  if FLAGS.xm_parameters:
    for key, value in json.loads(FLAGS.xm_parameters).items():
      if key not in hparam_str_dict:
        hparam_str_dict[key] = value

  # Deterministic run identifier (sorted keys). It names both the
  # tensorboard and the checkpoint directories.
  hparam_str = ','.join(['%s=%s' % (shorten(k), str(hparam_str_dict[k]))
                         for k in sorted(hparam_str_dict.keys())])

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  # Only host 0 creates (and later uses) summary_writer.
  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))

  batch_size = FLAGS.per_device_batch_size * n_devices
  io_shape = (FLAGS.per_device_batch_size,
              FLAGS.num_strings_per_task,
              FLAGS.max_characters)
  program_shape = (FLAGS.per_device_batch_size,
                   FLAGS.num_partial_programs,
                   FLAGS.max_program_length)
  split_io_shape = (FLAGS.per_device_batch_size,
                    FLAGS.num_strings_per_task,
                    FLAGS.num_partial_programs,
                    FLAGS.max_characters)

  # Setup DSL
  # ---------------------------------------------------------------------------

  # Build token tables. Character ids start at 1 so that 0 is free for padding.
  id_char_table = {i+1: char for (i, char) in enumerate(dsl.CHARACTER)}
  char_id_table = {char: id for id, char in id_char_table.items()}
  id_token_table, token_id_table = dsl_tokens.build_token_tables()
  io_vocab_size = len(char_id_table) + 1  # For padding.
  program_vocab_size = len(token_id_table) + 1

  bos_token = token_id_table[dsl.BOS]
  eos_token = token_id_table[dsl.EOS]

  # Parse io and program token sequences (for eval).
  def decode_io(inputs, outputs):
    """Decode io examples tokens."""
    def decode_str(s):
      """Decode string tokens."""
      # Ids <= 0 are skipped (padding).
      return ''.join([id_char_table[c_id] for c_id in s if c_id > 0])

    inps, outs = [], []
    for inp, out in zip(inputs, outputs):
      inps.append(decode_str(inp))
      outs.append(decode_str(out))
    return inps, outs

  def decode_program(program):
    """Decode program tokens."""
    # Concatenate all partial programs, truncating each at its first EOS.
    # NOTE(review): np.argmax returns 0 when `p` contains no eos_token, so a
    # partial program without EOS contributes no tokens at all.
    full_program = []
    for p in program:
      full_program.extend(p[:np.argmax(p == eos_token)].astype(np.int32))
    full_program = np.concatenate([full_program, [eos_token]], axis=0)

    try:
      return dsl.decode_program(full_program, id_token_table)
    except:  # pylint: disable=bare-except
      return None  # Program does not compile.

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  if not FLAGS.dataset_filepattern:
    raise ValueError('Must specify filepattern to dataset.')

  # Training dataset.
  dataset = input_pipeline.create_dataset_from_tf_record(
      FLAGS.dataset_filepattern,
      token_id_table,
      char_id_table,
      num_partial_programs=FLAGS.num_partial_programs)
  dataset = dataset.padded_batch(
      batch_size,
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:],
                     split_io_shape[1:]),
      drop_remainder=True)
  # Split evaluation and training: the first FLAGS.num_eval_steps batches are
  # held out for evaluation, and the remainder is repeated for training.
  eval_ds = dataset.take(FLAGS.num_eval_steps)
  # Decrease batch of predict dataset to handle beam search.
  predict_ds = eval_ds.unbatch().padded_batch(
      int(np.ceil(batch_size / 10)),
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:],
                     split_io_shape[1:]))
  train_ds = dataset.skip(FLAGS.num_eval_steps).repeat().prefetch(5)
  train_iter = train_ds.as_numpy_iterator()

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = base_models.TransformerConfig(
      vocab_size=io_vocab_size,
      output_vocab_size=program_vocab_size,
      shift=True,
      emb_dim=FLAGS.embedding_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.embedding_dim,
      mlp_dim=FLAGS.hidden_dim,
      max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
      deterministic=False,
      decode=False,
      bos_token=bos_token)
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(
      shift=False, deterministic=True, decode=not FLAGS.slow_decode)

  # Fold in the host id so each host draws different randomness.
  rng = jax.random.PRNGKey(FLAGS.seed)
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, init_rng = jax.random.split(rng)

  m = models.DecomposeExpandingLayerTransformer(
      config=eval_config, num_partial_programs=FLAGS.num_partial_programs,
      use_expanding_layer=FLAGS.use_expanding_layer)
  initial_variables = jax.jit(m.init)(
      init_rng,
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(program_shape, jnp.float32))

  adam_opt_def = optim.Adam(
      FLAGS.lr,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = adam_opt_def.create(initial_variables['params'])

  del initial_variables  # Don't keep a copy of the initial model.

  start_step = 0
  if FLAGS.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(
        os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)
    logging.info('Found model checkpointed at step %d.', start_step)
    # The checkpointed step has already been trained; resume after it.
    if start_step > 0:
      start_step += 1

  # Build Pretraining Model and Optimizer (if specified)
  # ---------------------------------------------------------------------------
  pretrain_optimizer = None  # Optimizer used for pretraining
  split_target = None  # Split pretrained model on partial programs.
  if start_step < FLAGS.num_pretrain_steps:
    # Load in pretraining optimizer.
    def filter_fn(path, value):
      # Returns False for parameters that must stay frozen during pretraining.
      del value
      if FLAGS.freeze_encoder and path.startswith('/encoder'):
        return False
      if FLAGS.freeze_decoder and path.startswith('/decoder'):
        return False
      return True
    trainable_weights = optim.ModelParamTraversal(filter_fn)
    pretrain_opt_def = optim.MultiOptimizer((trainable_weights, adam_opt_def))
    pretrain_optimizer = pretrain_opt_def.create(optimizer.target)

    if FLAGS.pretrain_checkpoint_format:
      pretrain_exprs = FLAGS.max_expressions // FLAGS.num_partial_programs
      checkpoint_dir = FLAGS.pretrain_checkpoint_format.format(pretrain_exprs)

      if gfile.isdir(checkpoint_dir):
        # Use the pretrained parameters if no training has occurred yet.
        if start_step == 0:
          restore_paths = []
          if FLAGS.restore_encoder:
            restore_paths.append('target/encoder')
          if FLAGS.restore_decoder:
            restore_paths.append('target/decoder')

          pretrain_optimizer = restore_selected_paths(
              pretrain_optimizer,
              checkpoint_dir=checkpoint_dir,
              restore_paths=restore_paths)
          logging.info('Found model pretrained at %s.', checkpoint_dir)

        if FLAGS.match_split_encoding:
          # Load the pretrained checkpoint as a single-partial-program model;
          # its parameters (split_target) are passed to the pretraining and
          # split-prediction steps below.
          split_model = models.DecomposeExpandingLayerTransformer(
              config=eval_config, num_partial_programs=1,
              use_expanding_layer=False)
          split_program_shape = (FLAGS.per_device_batch_size,
                                 1,
                                 FLAGS.max_program_length)
          split_initial_variables = jax.jit(split_model.init)(
              init_rng,
              jnp.ones(io_shape, jnp.float32),
              jnp.ones(io_shape, jnp.float32),
              jnp.ones(split_program_shape, jnp.float32))
          split_optimizer = adam_opt_def.create(
              split_initial_variables['params'])
          split_optimizer = checkpoints.restore_checkpoint(
              checkpoint_dir, split_optimizer)
          split_target = split_optimizer.target
      else:
        logging.warn('Could not find model at %s.', checkpoint_dir)

    if FLAGS.match_split_encoding and (split_target is None):
      raise RuntimeError('We could not load the pretrained checkpoint, '
                         'which is needed to match split embeddings.')

  learning_rate_fn = create_learning_rate_scheduler(base_learning_rate=FLAGS.lr)
  p_pretrain_step = jax.pmap(
      functools.partial(
          pretrain_step,
          num_partial_programs=FLAGS.num_partial_programs,
          learning_rate_fn=learning_rate_fn,
          config=train_config,
          use_expanding_layer=FLAGS.use_expanding_layer,
          split_params=split_target),
      axis_name='batch')
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          num_partial_programs=FLAGS.num_partial_programs,
          learning_rate_fn=learning_rate_fn,
          config=train_config,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch')
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step,
          num_partial_programs=FLAGS.num_partial_programs,
          eos_token=eos_token,
          config=eval_config,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          num_partial_programs=FLAGS.num_partial_programs,
          max_decode_len=FLAGS.max_program_length,
          config=predict_config,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch')
  # Positional argument 4 is broadcast as a static value to all devices.
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step,
          num_partial_programs=FLAGS.num_partial_programs,
          max_decode_len=FLAGS.max_program_length,
          eos_token=eos_token,
          config=predict_config,
          slow_decode=FLAGS.slow_decode,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch',
      static_broadcasted_argnums=(4,))
  p_split_pred_step = jax.pmap(
      functools.partial(
          predict_step,
          num_partial_programs=FLAGS.num_partial_programs,
          max_decode_len=FLAGS.max_program_length,
          eos_token=eos_token,
          config=predict_config,
          slow_decode=FLAGS.slow_decode,
          use_expanding_layer=False,
          use_split_encoding=True,
          split_params=split_target),
      axis_name='batch',
      static_broadcasted_argnums=(4,))

  # Main Train Loop
  # ---------------------------------------------------------------------------
  train_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  # Replicate optimizer.
  if pretrain_optimizer:
    pretrain_optimizer = jax_utils.replicate(pretrain_optimizer)

  optimizer = jax_utils.replicate(optimizer)

  metrics_all = []
  tick = time.time()
  for step in range(start_step, FLAGS.num_train_steps):
    inputs, outputs, programs, split_outputs = (
        common_utils.shard(next(train_iter)))

    # Steps below FLAGS.num_pretrain_steps update pretrain_optimizer; later
    # steps update optimizer.
    if step < FLAGS.num_pretrain_steps:
      pretrain_optimizer, metrics, train_rngs = p_pretrain_step(
          pretrain_optimizer, inputs, outputs, programs,
          split_outputs=split_outputs,
          pretrain_rng=train_rngs)
    else:
      optimizer, metrics, train_rngs = p_train_step(
          optimizer, inputs, outputs, programs,
          train_rng=train_rngs)

    metrics_all.append(metrics)
    is_last_pretrain_step = step == FLAGS.num_pretrain_steps - 1
    is_last_step = step == FLAGS.num_train_steps - 1

    if is_last_pretrain_step:
      optimizer = maybe_copy_model_from_pretraining(
          optimizer, pretrain_optimizer, step, adam_opt_def)

    # Save a Checkpoint
    if (step % FLAGS.checkpoint_freq == 0 and step > 0) or is_last_step:
      optimizer = maybe_copy_model_from_pretraining(
          optimizer, pretrain_optimizer, step, adam_opt_def)
      if jax.host_id() == 0:
        # Save unreplicated optimizer + model state.
        checkpoints.save_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
            jax_utils.unreplicate(optimizer), step)

    # Periodic metric handling: only every FLAGS.log_freq steps, at the last
    # pretraining step and at the last step (never at step 0).
    if not step or (step % FLAGS.log_freq != 0 and not is_last_step and
                    not is_last_pretrain_step):
      continue

    # Evaluation below reads optimizer.target, so bring over the latest
    # pretrained weights first.
    optimizer = maybe_copy_model_from_pretraining(
        optimizer, pretrain_optimizer, step, adam_opt_def)

    logging.info('Gathering training metrics.')
    # Training Metrics
    metrics_all = common_utils.get_metrics(metrics_all)
    lr = metrics_all.pop('learning_rate').mean()
    metrics_sums = jax.tree_map(jnp.sum, metrics_all)
    denominator = metrics_sums.pop('denominator')
    summary = jax.tree_map(
        lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
        metrics_sums)
    summary['learning_rate'] = lr
    # Calculate (clipped) perplexity after averaging log-perplexities:
    summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)

    if jax.host_id() == 0:
      logging.info('Train in step: %d, loss: %.4f', step, summary['loss'])
      tock = time.time()
      # NOTE(review): assumes exactly FLAGS.log_freq steps since the previous
      # log. That is not true for logs triggered by the last (pretrain) step.
      # The interval also includes the previous round's eval and beam-search
      # time, because `tick` is reset before evaluation runs.
      steps_per_sec = FLAGS.log_freq / (tock - tick)
      tick = tock
      summary_writer.scalar('train/steps per second', steps_per_sec, step)
      for key, val in summary.items():
        summary_writer.scalar('train/' + key, val, step)
      summary_writer.flush()
    # Reset metric accumulation for next evaluation cycle.
    metrics_all = []

    # Evaluation Metrics
    logging.info('Gathering evaluation metrics.')
    t_evaluation_start = time.time()

    eval_summary = evaluate(
        p_eval_step=p_eval_step,
        target=optimizer.target,
        eval_ds=eval_ds)
    if jax.host_id() == 0:
      logging.info('Evaluation time: %.4f s step %d, loss: %.4f.',
                   time.time()-t_evaluation_start, step, eval_summary['loss'])
      for key, val in eval_summary.items():
        summary_writer.scalar('eval/' + key, val, step)
      summary_writer.flush()

    # Beam search metrics.
    logging.info('Gathering beam search metrics.')
    for beam_size in [1, 10, 12, 24, 48, 96]:
      t_inference_start = time.time()

      pred_acc, message = predict_and_compute_score(
          p_pred_step=p_pred_step,
          p_init_cache=p_init_cache,
          target=optimizer.target,
          predict_ds=predict_ds,
          decode_io=decode_io,
          decode_program=decode_program,
          beam_size=beam_size,
          num_partial_programs=FLAGS.num_partial_programs,
          use_best_first_search=FLAGS.best_first_search,
          slow_decode=FLAGS.slow_decode)

      # Write to tensorboard.
      if jax.host_id() == 0:
        slow_or_fast = 'slow' if FLAGS.slow_decode else 'fast'
        logging.info(
            'Prediction time, %s (beam %d): %.4f s, step %d, score %.4f',
            slow_or_fast, beam_size, time.time() - t_inference_start, step,
            pred_acc)
        beam_search_or_bfs = 'bfs' if FLAGS.best_first_search else 'beam-search'
        summary_writer.scalar(
            'predict-{}/score-{}-{}'.format(slow_or_fast,
                                            beam_search_or_bfs,
                                            beam_size),
            pred_acc, step)
        summary_writer.text('samples-{}'.format(beam_size),
                            '\n------\n'.join(message), step)
        summary_writer.flush()

      # While still pretraining, also score the split-encoding predictor that
      # uses the separately loaded split_target parameters.
      if step < FLAGS.num_pretrain_steps and FLAGS.match_split_encoding:
        pred_acc, message = predict_and_compute_score(
            p_pred_step=p_split_pred_step,
            p_init_cache=p_init_cache,
            target=optimizer.target,
            predict_ds=predict_ds,
            decode_io=decode_io,
            decode_program=decode_program,
            beam_size=beam_size,
            num_partial_programs=FLAGS.num_partial_programs,
            use_best_first_search=FLAGS.best_first_search,
            slow_decode=FLAGS.slow_decode)

        # Write to tensorboard.
        if jax.host_id() == 0:
          slow_or_fast = 'slow' if FLAGS.slow_decode else 'fast'
          beam_search_or_bfs = ('bfs' if FLAGS.best_first_search
                                else 'beam-search')
          summary_writer.scalar(
              'predict-split-{}/score-{}-{}'.format(slow_or_fast,
                                                    beam_search_or_bfs,
                                                    beam_size),
              pred_acc, step)
          summary_writer.text('samples-split-{}'.format(beam_size),
                              '\n------\n'.join(message), step)
          summary_writer.flush()
def main(argv):
  """Trains and evaluates the image model built by `create_model`.

  Training runs either one host-dispatched `p_train_step` per batch, or, with
  FLAGS.infeed, a per-epoch on-device `lax.while_loop` fed through infeed.
  After every epoch, optional train metrics and eval metrics are written as
  summaries from a background thread on host 0.
  """
  del argv
  # BEGIN GOOGLE-INTERNAL
  xm.setup_work_unit()
  # END GOOGLE-INTERNAL

  tf.enable_v2_behavior()
  # Only host 0 creates summary_writer / summary_thread; every later use is
  # guarded by the same host check.
  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(FLAGS.output_dir)
    # Write summaries in background thread to avoid blocking on device sync
    summary_thread = thread.ThreadPoolExecutor(1, 'summary')
  if FLAGS.infeed:
    # Infeed is currently synchronous, so do it in a background thread too
    # NOTE(review): infeed_pool is never shut down.
    infeed_pool = thread.ThreadPoolExecutor(jax.local_device_count(), 'infeed')

  rng = random.PRNGKey(0)

  image_size = 224

  batch_size = FLAGS.batch_size
  if batch_size is None:
    batch_size = min(128 * jax.device_count(), 32768)
  eval_batch_size = 128 * jax.device_count()
  local_batch_size = batch_size // jax.host_count()
  local_eval_batch_size = eval_batch_size // jax.host_count()
  device_batch_size = batch_size // jax.device_count()
  device_eval_batch_size = eval_batch_size // jax.device_count()
  # Per-device size of the final, partial eval batch. This is 0 when
  # EVAL_IMAGES is an exact multiple of eval_batch_size.
  device_last_eval_batch_size = (input_pipeline.EVAL_IMAGES %
                                 eval_batch_size) // jax.device_count()

  model_dtype = jnp.bfloat16 if FLAGS.bfloat16 else jnp.float32
  input_dtype = tf.bfloat16 if FLAGS.bfloat16 else tf.float32
  # With transpose_images the batch dimension is last instead of first.
  if FLAGS.transpose_images:
    train_input_shape = (224, 224, 3, device_batch_size)
    eval_input_shapes = [(224, 224, 3, bs)
                         for bs in (device_eval_batch_size,
                                    device_last_eval_batch_size)]
  else:
    train_input_shape = (device_batch_size, 224, 224, 3)
    eval_input_shapes = [(bs, 224, 224, 3)
                         for bs in (device_eval_batch_size,
                                    device_last_eval_batch_size)]

  num_epochs = FLAGS.num_epochs
  # True division: steps_per_epoch may be fractional. Epoch boundaries are
  # derived from step // steps_per_epoch on both host and device.
  steps_per_epoch = input_pipeline.TRAIN_IMAGES / batch_size
  logging.info('steps_per_epoch: %f', steps_per_epoch)
  steps_per_eval = int(np.ceil(input_pipeline.EVAL_IMAGES / eval_batch_size))
  logging.info('steps_per_eval: %d', steps_per_eval)

  # Linear learning-rate scaling relative to a batch size of 256.
  base_learning_rate = FLAGS.learning_rate * batch_size / 256.
  beta = FLAGS.momentum
  weight_decay = FLAGS.weight_decay

  logging.info('creating and initializing model and optimizer')
  model, state = create_model(rng, device_batch_size, image_size, model_dtype)
  state = jax_utils.replicate(state)
  # Weight decay is applied only to non-bias/non-scale parameters; the
  # "other" optimizer always uses weight_decay=0.
  if FLAGS.lars:
    weight_opt_def = optim.LARS(base_learning_rate, beta,
                                weight_decay=weight_decay)
    other_opt_def = optim.Momentum(base_learning_rate, beta,
                                   weight_decay=0, nesterov=False)
    learning_rate_fn = polynomial_learning_rate_fn(batch_size, steps_per_epoch,
                                                   num_epochs)
  else:
    weight_opt_def = optim.Momentum(base_learning_rate, beta,
                                    weight_decay=weight_decay, nesterov=True)
    other_opt_def = optim.Momentum(base_learning_rate, beta,
                                   weight_decay=0, nesterov=True)
    learning_rate_fn = piecewise_learning_rate_fn(base_learning_rate,
                                                  steps_per_epoch, num_epochs)
  def filter_weights(key, _):
    return 'bias' not in key and 'scale' not in key
  def filter_other(key, _):
    return 'bias' in key or 'scale' in key
  weight_traversal = optim.ModelParamTraversal(filter_weights)
  other_traversal = optim.ModelParamTraversal(filter_other)
  optimizer_def = optim.MultiOptimizer((weight_traversal, weight_opt_def),
                                       (other_traversal, other_opt_def))
  optimizer = optimizer_def.create(model)
  optimizer = optimizer.replicate()
  del model  # do not keep a copy of the initial model

  p_train_step = jax.pmap(partial(train_step, learning_rate_fn=learning_rate_fn),
                          axis_name='batch')
  p_eval_step = jax.pmap(eval_step, axis_name='batch')

  # On-device epoch loop (infeed mode): keep stepping while the step counter
  # is still inside `epoch`.
  def device_train_loop_cond(args):
    _, _, _, _, step, epoch = args
    return step // steps_per_epoch == epoch
  def device_train_loop_body(args):
    optimizer, state, metrics, token, step, epoch = args
    # Blocks until the host pushes one (images, labels) pair via
    # transfer_to_infeed in the epoch loop below.
    (images, labels), token = lax.infeed(
        token,
        shape=(jax.ShapedArray(train_input_shape, model_dtype),
               jax.ShapedArray((device_batch_size,), jnp.int32)))
    batch = {'image': images, 'label': labels}
    optimizer, state, metrics = train_step(optimizer, state, batch, metrics,
                                           learning_rate_fn)
    step += 1
    return optimizer, state, metrics, token, step, epoch
  def device_train_loop(optimizer, state, metrics, step, epoch):
    token = lax.create_token(step)
    optimizer, state, metrics, _, step, _ = lax.while_loop(
        device_train_loop_cond,
        device_train_loop_body,
        (optimizer, state, metrics, token, step, epoch))
    return optimizer, state, metrics, step
  p_train_epoch = jax.pmap(device_train_loop, axis_name='batch')

  if FLAGS.precompile:
    logging.info('precompiling step/epoch functions')
    if FLAGS.infeed:
      # the device training loop condition will immediately be false
      p_train_epoch(optimizer, state, empty_metrics(),
                    jax_utils.replicate(0), jax_utils.replicate(1))
    else:
      batch = {
          'image': jnp.zeros((jax.local_device_count(),) + train_input_shape,
                             model_dtype),
          'label': jnp.zeros((jax.local_device_count(),) + (device_batch_size,),
                             jnp.int32)
      }
      p_train_step(optimizer, state, batch, empty_metrics())
    # Compile eval for both the full and the final partial eval batch shapes.
    for dbs, eis in zip([device_eval_batch_size, device_last_eval_batch_size],
                        eval_input_shapes):
      batch = {
          'image': jnp.zeros((jax.local_device_count(),) + eis, model_dtype),
          'label': jnp.zeros((jax.local_device_count(),) + (dbs,), jnp.int32)
      }
      p_eval_step(optimizer.target, state, batch, empty_metrics())
    allreduce_metrics(empty_metrics())
    pmean = functools.partial(jax.lax.pmean, axis_name='batch')
    jax.pmap(pmean, axis_name='batch')(state)

  logging.info('constructing datasets')
  # pylint: disable=g-complex-comprehension
  train_ds, eval_ds = [
      input_pipeline.load_split(
          local_batch_size if train else local_eval_batch_size,
          image_size=image_size,
          dtype=input_dtype,
          train=train,
          transpose_images=FLAGS.transpose_images) for train in (True, False)
  ]
  # pylint: enable=g-complex-comprehension
  logging.info('constructing dataset iterators')
  train_iter = iter(train_ds)
  eval_iter = iter(eval_ds)

  logging.info('beginning training')
  host_step, device_step = 0, jax_utils.replicate(0)
  for epoch in range(num_epochs):
    device_epoch = jax_utils.replicate(epoch)
    metrics = empty_metrics()
    if FLAGS.infeed:
      # Dispatched before the host loop below starts pushing this epoch's
      # batches through infeed.
      optimizer, state, metrics, device_step = p_train_epoch(
          optimizer, state, metrics, device_step, device_epoch)
    # The host loop advances its own step counter with the same epoch rule
    # as device_train_loop_cond.
    while int(host_step // steps_per_epoch) == epoch:
      batch = jax.tree_map(lambda x: x._numpy(), next(train_iter))  # pylint: disable=protected-access
      if FLAGS.infeed:
        for i, device in enumerate(jax.local_devices()):
          images, labels = batch['image'][i], batch['label'][i]
          assert images.shape == train_input_shape and labels.dtype == jnp.int32
          infeed_pool.submit(partial(device.transfer_to_infeed, (images, labels)))
      else:
        optimizer, state, metrics = p_train_step(
            optimizer, state, batch, metrics)
      host_step += 1
    if FLAGS.train_metrics:
      metrics = allreduce_metrics(metrics)
      if jax.host_id() == 0:
        summary_thread.submit(partial(
            write_summary, summary_writer, metrics, 'train', epoch + 1))
    if not FLAGS.distributed_batchnorm:  # otherwise it's already synced
      pmean = functools.partial(jax.lax.pmean, axis_name='batch')
      state = jax.pmap(pmean, axis_name='batch')(state)
    metrics = empty_metrics()
    for _ in range(steps_per_eval):
      batch = jax.tree_map(lambda x: x._numpy(), next(eval_iter))  # pylint: disable=protected-access
      metrics = p_eval_step(optimizer.target, state, batch, metrics)
    metrics = allreduce_metrics(metrics)
    if jax.host_id() == 0:
      summary_thread.submit(partial(
          write_summary, summary_writer, metrics, 'eval', epoch + 1))
    # TODO(deveci): do something like this from the summary thread:
    # if summary['accuracy'] > TARGET_ACCURACY:
    #   break
  if jax.host_id() == 0:
    summary_thread.shutdown()
  # Wait until computations are done before exiting
  jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
def meta_optimize_scales(loss_fn,
                         fprop,
                         normalized_params,
                         norms,
                         hps,
                         input_shape,
                         output_shape,
                         rng_key,
                         metrics_logger=None,
                         log_every=10):
  """Implements MetaInit initializer.

  Runs `hps.meta_steps` steps of SGD with momentum on the parameter norms.
  Each step uses the sign of the cross-device-averaged gradient of
  `meta_loss`, evaluated on a fresh batch of standard-normal inputs and
  uniformly random one-hot targets.

  Args:
    loss_fn: Loss function, called as loss_fn(outputs, one_hot_targets).
    fprop: Forward pass of the model, called as
      fprop({'params': params}, inputs, train=True) -> outputs.
    normalized_params: Pytree of model parameters. We assume that all non-bias
      terms have norm 1, and all bias terms are all 0's.
    norms: The initial guess of the learned norms, this is the starting point
      of meta_init.
    hps: HParam object. Required hparams are meta_learning_rate,
      meta_momentum, meta_batch_size, meta_steps, and epsilon.
    input_shape: Per-example input shape; must agree with batch[0].shape[1:].
    output_shape: Per-example output shape; must agree with
      batch[1].shape[1:]. Only output_shape[-1] (the number of classes) is
      used.
    rng_key: jax.PRNGKey, used to seed all randomness.
    metrics_logger: Optional utils.MetricsLogger object; receives the meta
      loss whenever it is logged.
    log_every: Log the meta loss every k steps (and always on the last step).

  Returns:
    learned_norms: The unreplicated norms after optimizing the meta_init loss.
    training_curve: List with one entry per meta step. Each entry holds the
      meta loss computed on every local device (device index on axis 0).

  Raises:
    ValueError: If hps.meta_batch_size is not divisible by jax.device_count().
  """
  num_outputs = output_shape[-1]
  if hps.meta_batch_size % jax.device_count() != 0:
    raise ValueError('meta_bs: {}, n_devices: {}'.format(
        hps.meta_batch_size, jax.device_count()))

  def get_batch(rng_key):
    """Return a fake batch of data, sharded over the local devices."""
    per_device_batch_size = hps.meta_batch_size // jax.device_count()
    meta_input_shape = (
        jax.local_device_count(),
        per_device_batch_size,
    ) + input_shape
    input_key, target_key = jax.random.split(rng_key)

    inputs = jax.random.normal(input_key, meta_input_shape)
    targets = jax.random.randint(target_key, (
        jax.local_device_count(),
        per_device_batch_size,
    ), 0, num_outputs)
    # One-hot encode the integer class labels.
    targets = jnp.eye(num_outputs)[targets]
    return (inputs, targets)

  # We will only optimize the scalars for model parameters with rank >=2.
  non_bias_and_scalar_keys = _get_non_bias_params(normalized_params)
  if jax.process_index() == 0:
    logging.info('MetaInit will optimize the following parameters:')
    for key in non_bias_and_scalar_keys:
      logging.info(key)
  traversal = optimizers.ModelParamTraversal(
      lambda path, _: path in non_bias_and_scalar_keys)
  # Non-bias, non-scalar norms.
  # NOTE(review): an identity `update` appears to return the whole `norms`
  # tree, so every norm becomes a meta-parameter; confirm whether restricting
  # the optimized set to the traversal's keys was intended.
  meta_params = traversal.update(lambda x: x, norms)

  meta_opt_init_fn, meta_opt_update_fn = optax.sgd(
      learning_rate=hps.meta_learning_rate,
      momentum=hps.meta_momentum)
  meta_optimizer_state = meta_opt_init_fn(meta_params)
  meta_optimizer_state = jax_utils.replicate(meta_optimizer_state)
  meta_params = jax_utils.replicate(meta_params)

  # Make a closure over the static variables (model, normalized_params, hps).
  @functools.partial(jax.pmap, axis_name='batch')
  def update(meta_params, optimizer_state, inputs, targets):
    """Update step."""
    def params_to_loss(params):
      return loss_fn(fprop({'params': params}, inputs, train=True), targets)
    def _meta_loss(params):
      return meta_loss(params_to_loss, params, normalized_params, hps.epsilon)
    grad_fn = jax.value_and_grad(_meta_loss, has_aux=False)
    loss, grads = grad_fn(meta_params)
    grads = model_utils.cross_device_avg(grads)
    # Signed gradient step: only the direction of each gradient entry is used.
    grads = jax.tree_util.tree_map(jnp.sign, grads)
    meta_updates, new_meta_optimizer_state = meta_opt_update_fn(
        grads, optimizer_state, params=meta_params)
    new_meta_params = optax.apply_updates(meta_params, meta_updates)
    return new_meta_params, new_meta_optimizer_state, loss

  training_curve = []
  start = time.perf_counter()
  for i in range(hps.meta_steps):
    # A distinct, reproducible fake batch for every step.
    batch_rng = jax.random.fold_in(rng_key, i)
    inputs, targets = get_batch(batch_rng)

    meta_params, meta_optimizer_state, loss_value = update(
        meta_params, meta_optimizer_state, inputs, targets)
    training_curve.append(loss_value)
    if (jax.process_index() == 0 and
        (i % log_every == 0 or (i + 1) == hps.meta_steps)):
      end = time.perf_counter()
      logging.info('Cumulative time (seconds): %d', end-start)
      logging.info('meta_init step %d, loss: %f', i, float(loss_value[0]))
      if metrics_logger is not None:
        metrics_logger.append_scalar_metrics({
            'global_step': i,
            'meta_loss': float(loss_value[0])
        })

  # Create a new model with the learned init.
  learned_norms = jax_utils.unreplicate(meta_params)
  return learned_norms, training_curve
def _check_no_double_weight_decay(hps, weight_decay):
  """Rejects configs that set both `hps.l2_decay_factor` and weight decay.

  Replaces a bare `assert`, which is stripped under `python -O` and would let
  a conflicting configuration through silently.

  Args:
    hps: HParams-like object; `hps.l2_decay_factor` and `hps.optimizer` are
      read.
    weight_decay: The optimizer weight decay resolved from `hps.opt_hparams`.

  Raises:
    ValueError: If `hps.l2_decay_factor` is not None and `weight_decay` is
      nonzero.
  """
  if hps.l2_decay_factor is not None and weight_decay != 0.0:
    raise ValueError(
        'Set at most one of hps.l2_decay_factor and '
        "hps.opt_hparams['weight_decay'] for optimizer {}; got "
        'l2_decay_factor={} and weight_decay={}.'.format(
            hps.optimizer, hps.l2_decay_factor, weight_decay))


def _multi_optimizer_excluding_bias_and_scale(weight_opt_def, other_opt_def):
  """Routes bias/scale params to `other_opt_def` and the rest to `weight_opt_def`.

  A parameter path containing the substring 'bias' or 'scale' goes to
  `other_opt_def`. Every other path goes to `weight_opt_def`. The two filters
  are exact complements, so each parameter is handled by exactly one optimizer.

  Args:
    weight_opt_def: Optimizer definition for the non-bias, non-scale params.
    other_opt_def: Optimizer definition for the bias and scale params.

  Returns:
    An `optimizers.MultiOptimizer` that combines the two definitions.
  """

  def filter_weights(key, _):
    return 'bias' not in key and 'scale' not in key

  def filter_other(key, _):
    return 'bias' in key or 'scale' in key

  weight_traversal = optimizers.ModelParamTraversal(filter_weights)
  other_traversal = optimizers.ModelParamTraversal(filter_other)
  return optimizers.MultiOptimizer((weight_traversal, weight_opt_def),
                                   (other_traversal, other_opt_def))


def get_optimizer(hps):
  """Constructs the optimizer from the given HParams.

  Every optimizer is built with `learning_rate=None`. The learning rate is
  presumably supplied elsewhere, e.g. by a schedule at update time; confirm
  against the caller.

  Args:
    hps: HParams-like object. Reads `hps.optimizer` (the optimizer name),
      `hps.opt_hparams` (a mapping of optimizer-specific settings; the keys
      needed depend on the optimizer) and `hps.l2_decay_factor`.

  Returns:
    An optimizer definition from the `optimizers` module. For 'mlperf_*'
    names, this is a `MultiOptimizer` that treats bias/scale params
    differently.

  Raises:
    NotImplementedError: If `hps.optimizer` is not a supported name.
    ValueError: If an optimizer that checks it ('lamb', 'adam', 'lars',
      'mlperf_lars_resnet', 'mlperf_lamb') is given both a non-None
      `hps.l2_decay_factor` and a nonzero `weight_decay`.
    KeyError: If a required entry of `hps.opt_hparams` is missing.
  """
  weight_decay = (hps.opt_hparams['weight_decay']
                  if 'weight_decay' in hps.opt_hparams else 0)

  if hps.optimizer == 'sgd':
    # NOTE: GradientDescent is built without weight_decay, so any configured
    # hps.opt_hparams['weight_decay'] is not applied by this optimizer.
    return optimizers.GradientDescent(learning_rate=None)
  elif hps.optimizer == 'nesterov':
    return optimizers.Momentum(learning_rate=None,
                               beta=hps.opt_hparams['momentum'],
                               nesterov=True,
                               weight_decay=weight_decay)
  elif hps.optimizer == 'momentum':
    return optimizers.Momentum(learning_rate=None,
                               beta=hps.opt_hparams['momentum'],
                               nesterov=False,
                               weight_decay=weight_decay)
  elif hps.optimizer == 'lamb':
    _check_no_double_weight_decay(hps, weight_decay)
    return optimizers.LAMB(learning_rate=None,
                           beta1=hps.opt_hparams['beta1'],
                           beta2=hps.opt_hparams['beta2'],
                           eps=hps.opt_hparams['epsilon'],
                           weight_decay=weight_decay)
  elif hps.optimizer == 'adam':
    _check_no_double_weight_decay(hps, weight_decay)
    return optimizers.Adam(learning_rate=None,
                           beta1=hps.opt_hparams['beta1'],
                           beta2=hps.opt_hparams['beta2'],
                           eps=hps.opt_hparams['epsilon'],
                           weight_decay=weight_decay)
  elif hps.optimizer == 'lars':
    _check_no_double_weight_decay(hps, weight_decay)
    return optimizers.LARS(learning_rate=None,
                           beta=hps.opt_hparams['beta'],
                           weight_decay=weight_decay)
  elif hps.optimizer == 'mlperf_lars_resnet':
    _check_no_double_weight_decay(hps, weight_decay)
    # LARS with weight decay for the weights; plain momentum without weight
    # decay for the bias and scale params.
    weight_opt_def = optimizers.LARS(learning_rate=None,
                                     beta=hps.opt_hparams['beta'],
                                     weight_decay=weight_decay)
    other_opt_def = optimizers.Momentum(learning_rate=None,
                                        beta=hps.opt_hparams['beta'],
                                        weight_decay=0,
                                        nesterov=False)
    return _multi_optimizer_excluding_bias_and_scale(weight_opt_def,
                                                     other_opt_def)
  elif hps.optimizer == 'mlperf_lamb':
    # The check uses the generic weight_decay, while the two sub-optimizers
    # below read their own 'lamb_weight_decay' / 'adam_weight_decay' entries.
    _check_no_double_weight_decay(hps, weight_decay)
    weight_opt_def = optimizers.LAMB(
        learning_rate=None,
        beta1=hps.opt_hparams['beta1'],
        beta2=hps.opt_hparams['beta2'],
        eps=hps.opt_hparams['epsilon'],
        weight_decay=hps.opt_hparams['lamb_weight_decay'])
    other_opt_def = optimizers.Adam(
        learning_rate=None,
        beta1=hps.opt_hparams['beta1'],
        beta2=hps.opt_hparams['beta2'],
        eps=hps.opt_hparams['epsilon'],
        weight_decay=hps.opt_hparams['adam_weight_decay'])
    return _multi_optimizer_excluding_bias_and_scale(weight_opt_def,
                                                     other_opt_def)
  else:
    raise NotImplementedError('Optimizer {} not implemented'.format(
        hps.optimizer))