def create_optimizer(model, learning_rate, weight_decay, layers=None):
  """Instantiates Adam multi-optimizer."""
  if layers is None:
    assert isinstance(learning_rate, float) and isinstance(weight_decay, float), (
        'Specify float values for learning rate and weight decay!')
    optimizer_def = optim.Adam(learning_rate=learning_rate,
                               weight_decay=weight_decay)
    optimizer = optimizer_def.create(model)
  else:
    assert len(learning_rate) == len(weight_decay) == len(layers), (
        'Number of specified learning rates, weight decays, and layers must '
        'be equal!')
    optimizers = []
    for lr, wd, layer in zip(learning_rate, weight_decay, layers):
      if lr > 0:
        opt = optim.Adam(learning_rate=lr, weight_decay=wd)
        filter_fn = functools.partial(path_inclusion_filter_fn, layer=layer)
        traversal = optim.ModelParamTraversal(filter_fn)
        traversal_opt = (traversal, opt)
        optimizers.append(traversal_opt)
    optimizer_def = optim.MultiOptimizer(*optimizers)
    optimizer = optimizer_def.create(model)
  return optimizer

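# NOTE: `path_inclusion_filter_fn` is referenced above but not defined in this
# snippet. A minimal sketch, assuming it simply selects parameters whose path
# contains the given layer name (hypothetical helper, illustration only):
def path_inclusion_filter_fn(path, param, layer):
  """Returns True for parameters whose path mentions `layer`."""
  del param  # Filter on the path only.
  return layer in path
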
def create_optimizer(model, learning_rate=1e-4):
  """Create optimizer used for training model.

  MultiOptimizer is used to apply the Adam optimizer with weight decay to all
  parameters except layer_norm and bias, and the Adam optimizer without weight
  decay to the layer_norm and bias params.

  Args:
    model: JAX model to add optimizer to.
    learning_rate: base learning rate used for initializing optimizer.

  Returns:
    optimizer: model with Adam optimizer to be used for training.
  """
  weight_decay_def = optim.Adam(
      learning_rate=learning_rate, eps=1e-6, weight_decay=0.01)
  no_decay_def = optim.Adam(
      learning_rate=learning_rate, eps=1e-6, weight_decay=0.0)

  def filter_weight_decay(key, _):
    return 'layer_norm' not in key and 'bias' not in key

  def filter_other(key, _):
    return 'layer_norm' in key or 'bias' in key

  weight_decay_traversal = optim.ModelParamTraversal(filter_weight_decay)
  no_decay_traversal = optim.ModelParamTraversal(filter_other)
  optimizer_def = optim.MultiOptimizer(
      (weight_decay_traversal, weight_decay_def),
      (no_decay_traversal, no_decay_def))
  optimizer = optimizer_def.create(model)
  optimizer = optimizer.replicate()
  del model
  return optimizer

def create_optimizer(config, params):
  if config.optimizer == "adam":
    optimizer_cls = optim.Adam
  elif config.optimizer == "lamb":
    optimizer_cls = optim.LAMB
  else:
    raise ValueError(f"Unsupported value for optimizer: {config.optimizer}")
  common_kwargs = dict(
      learning_rate=config.learning_rate,
      beta1=config.adam_beta1,
      beta2=config.adam_beta2,
      eps=config.adam_epsilon,
  )
  optimizer_decay_def = optimizer_cls(
      weight_decay=config.weight_decay, **common_kwargs)
  optimizer_no_decay_def = optimizer_cls(weight_decay=0.0, **common_kwargs)

  def exclude_from_decay(path, _):
    return "bias" in path or "layer_norm" in path or "layernorm" in path

  decay = optim.ModelParamTraversal(
      lambda *args: not exclude_from_decay(*args))
  no_decay = optim.ModelParamTraversal(exclude_from_decay)
  optimizer_def = optim.MultiOptimizer((decay, optimizer_decay_def),
                                       (no_decay, optimizer_no_decay_def))
  optimizer = optimizer_def.create(params)
  return optimizer

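# Illustrative usage of the function above (not from the original source).
# Assumes an ml_collections-style config; the field values are hypothetical.
import ml_collections

config = ml_collections.ConfigDict()
config.optimizer = "adam"
config.learning_rate = 1e-4
config.adam_beta1 = 0.9
config.adam_beta2 = 0.999
config.adam_epsilon = 1e-6
config.weight_decay = 0.01
# `params` would come from model.init(...); then:
# optimizer = create_optimizer(config, params)
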
def test_multi_optimizer(self):
  params = {'a': 0., 'b': 0.}
  opt_a = optim.GradientDescent(learning_rate=1.)
  opt_b = optim.GradientDescent(learning_rate=10.)
  t_a = traverse_util.t_identity['a']
  t_b = traverse_util.t_identity['b']
  optimizer_def = optim.MultiOptimizer((t_a, opt_a), (t_b, opt_b))
  state = optimizer_def.init_state(params)
  expected_hyper_params = [
      _GradientDescentHyperParams(1.),
      _GradientDescentHyperParams(10.)
  ]
  self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
  expected_state = [optim.OptimizerState(0, [()])] * 2
  self.assertEqual(state, expected_state)
  grads = {'a': -1., 'b': -2.}
  new_params, new_state = optimizer_def.apply_gradient(
      optimizer_def.hyper_params, params, state, grads)
  expected_params = {'a': 1., 'b': 20.}
  expected_state = [optim.OptimizerState(1, [()])] * 2
  self.assertEqual(new_state, expected_state)
  self.assertEqual(new_params, expected_params)
  # Override learning_rate (applies to both sub-optimizers).
  hp = optimizer_def.update_hyper_params(learning_rate=2.)
  new_params, new_state = optimizer_def.apply_gradient(
      hp, params, state, grads)
  expected_params = {'a': 2., 'b': 4.}
  self.assertEqual(new_params, expected_params)

def test_multi_optimizer_multiple_matches(self):
  params = {'a': {'x': 0., 'y': 0.}, 'b': {'y': 0, 'z': 0.}}
  opt_a = optim.GradientDescent(learning_rate=1.)
  opt_b = optim.GradientDescent(learning_rate=10.)
  t_a = optim.ModelParamTraversal(
      lambda path, _: path.endswith('/x') or path.endswith('/y'))
  t_b = optim.ModelParamTraversal(
      lambda path, value: value.dtype == jnp.int32 or path.endswith('/z'))
  optimizer_def = optim.MultiOptimizer((t_a, opt_a), (t_b, opt_b))
  with self.assertRaisesRegex(
      ValueError, r"Multiple optimizers match.*'y': \[0, 1\]"):
    jax.jit(optimizer_def.init_state)(params)

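# The test above documents that each parameter may be owned by at most one
# sub-optimizer: '/b/y' matches both t_a (path ends with '/y') and t_b (int32
# dtype), so init_state raises. One illustrative fix (not from the original
# source) is to make the second filter exclude the first filter's paths:
t_b_disjoint = optim.ModelParamTraversal(
    lambda path, value: (value.dtype == jnp.int32 or path.endswith('/z'))
    and not (path.endswith('/x') or path.endswith('/y')))
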
def create_optimizer(model, model_kwargs, learning_rate=1e-4):
  """Create optimizer used for training model.

  MultiOptimizer is used to apply the Adam/LAMB optimizer with weight decay to
  all parameters except layer_norm and bias, and the Adam/LAMB optimizer
  without weight decay to the layer_norm and bias params.

  Args:
    model: JAX model to add optimizer to.
    model_kwargs: BERT model config parameter dictionary.
    learning_rate: base learning rate used for initializing optimizer.

  Returns:
    optimizer: model with Adam/LAMB optimizer to be used for training.
  """
  if FLAGS.use_lamb:
    weight_decay_def = bert_lamb.BertLAMB(
        learning_rate=learning_rate,
        beta1=FLAGS.lamb_beta_1,
        beta2=FLAGS.lamb_beta_2,
        eps=10**FLAGS.log_epsilon,
        weight_decay=FLAGS.lamb_weight_decay,
        num_layers=model_kwargs['num_layers'])
    no_decay_def = bert_lamb.BertLAMB(
        learning_rate=learning_rate,
        beta1=FLAGS.lamb_beta_1,
        beta2=FLAGS.lamb_beta_2,
        eps=10**FLAGS.log_epsilon,
        weight_decay=0.0,
        num_layers=model_kwargs['num_layers'])
  else:
    weight_decay_def = optim.Adam(
        learning_rate=learning_rate,
        eps=1e-6,
        weight_decay=FLAGS.lamb_weight_decay)
    no_decay_def = optim.Adam(
        learning_rate=learning_rate, eps=1e-6, weight_decay=0.0)

  def filter_weight_decay(key, _):
    return ('layer_norm' not in key and 'bias' not in key and
            'layernorm' not in key)

  def filter_other(key, _):
    return 'layer_norm' in key or 'bias' in key or 'layernorm' in key

  weight_decay_traversal = optim.ModelParamTraversal(filter_weight_decay)
  no_decay_traversal = optim.ModelParamTraversal(filter_other)
  optimizer_def = optim.MultiOptimizer(
      (weight_decay_traversal, weight_decay_def),
      (no_decay_traversal, no_decay_def))
  optimizer = optimizer_def.create(model)
  optimizer = jax_utils.replicate(optimizer)
  del model
  return optimizer

def create_optimizer(config, model):
  common_kwargs = dict(
      learning_rate=config.learning_rate,
      beta1=0.9,
      beta2=0.999,
      eps=1e-6,
  )
  optimizer_decay_def = optim.Adam(weight_decay=0.01, **common_kwargs)
  optimizer_no_decay_def = optim.Adam(weight_decay=0.0, **common_kwargs)
  decay = optim.ModelParamTraversal(lambda path, _: 'bias' not in path)
  no_decay = optim.ModelParamTraversal(lambda path, _: 'bias' in path)
  optimizer_def = optim.MultiOptimizer((decay, optimizer_decay_def),
                                       (no_decay, optimizer_no_decay_def))
  optimizer = optimizer_def.create(model)
  return optimizer

def create_optimizer(config, model, initial_params):
  """Create an optimizer for a model initialized from a pre-trained checkpoint."""
  common_kwargs = dict(
      learning_rate=config.learning_rate,
      beta1=0.9,
      beta2=0.999,
      eps=1e-6,
  )
  optimizer_decay_def = optim.Adam(weight_decay=0.01, **common_kwargs)
  optimizer_no_decay_def = optim.Adam(weight_decay=0.0, **common_kwargs)
  decay = optim.ModelParamTraversal(lambda path, _: 'bias' not in path)
  no_decay = optim.ModelParamTraversal(lambda path, _: 'bias' in path)
  optimizer_def = optim.MultiOptimizer((decay, optimizer_decay_def),
                                       (no_decay, optimizer_no_decay_def))
  # TODO(marcvanzee): MultiOptimizer triggers double XLA compilation on TPU so
  # we use Adam here, but we should investigate why this happens.
  optimizer_def = optim.Adam(learning_rate=config.learning_rate)
  optimizer = optimizer_def.create(model)
  optimizer = optimizer.replicate()
  del model  # Don't keep a copy of the initial model.
  return optimizer

def create_optimizer(config, model):
  if config.optimizer == 'adam':
    optimizer_cls = optim.Adam
  elif config.optimizer == 'lamb':
    optimizer_cls = optim.LAMB
  else:
    raise ValueError(f'Unsupported value for optimizer: {config.optimizer}')
  common_kwargs = dict(
      learning_rate=config.learning_rate,
      beta1=0.9,
      beta2=0.999,
      eps=1e-6,
  )
  optimizer_decay_def = optimizer_cls(weight_decay=0.01, **common_kwargs)
  optimizer_no_decay_def = optimizer_cls(weight_decay=0.0, **common_kwargs)
  decay = optim.ModelParamTraversal(lambda path, _: 'bias' not in path)
  no_decay = optim.ModelParamTraversal(lambda path, _: 'bias' in path)
  optimizer_def = optim.MultiOptimizer((decay, optimizer_decay_def),
                                       (no_decay, optimizer_no_decay_def))
  optimizer = optimizer_def.create(model)
  return optimizer

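# Illustrative training step using the MultiOptimizer returned above (not from
# the original source; `compute_loss` is a hypothetical loss helper). With the
# pre-Linen flax.optim API, the wrapped model lives in `optimizer.target`:
def train_step(optimizer, batch):
  def loss_fn(model):
    return compute_loss(model, batch)  # hypothetical helper
  loss, grad = jax.value_and_grad(loss_fn)(optimizer.target)
  optimizer = optimizer.apply_gradient(grad)  # updates both param groups
  return optimizer, loss
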
def main(argv):
  del argv
  # BEGIN GOOGLE-INTERNAL
  xm.setup_work_unit()
  # END GOOGLE-INTERNAL
  tf.enable_v2_behavior()

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(FLAGS.output_dir)
    # Write summaries in background thread to avoid blocking on device sync.
    summary_thread = thread.ThreadPoolExecutor(1, 'summary')
  if FLAGS.infeed:
    # Infeed is currently synchronous, so do it in a background thread too.
    infeed_pool = thread.ThreadPoolExecutor(jax.local_device_count(), 'infeed')

  rng = random.PRNGKey(0)

  image_size = 224

  batch_size = FLAGS.batch_size
  if batch_size is None:
    batch_size = min(128 * jax.device_count(), 32768)
  eval_batch_size = 128 * jax.device_count()
  local_batch_size = batch_size // jax.host_count()
  local_eval_batch_size = eval_batch_size // jax.host_count()
  device_batch_size = batch_size // jax.device_count()
  device_eval_batch_size = eval_batch_size // jax.device_count()
  device_last_eval_batch_size = (
      input_pipeline.EVAL_IMAGES % eval_batch_size) // jax.device_count()

  model_dtype = jnp.bfloat16 if FLAGS.bfloat16 else jnp.float32
  input_dtype = tf.bfloat16 if FLAGS.bfloat16 else tf.float32
  if FLAGS.transpose_images:
    train_input_shape = (224, 224, 3, device_batch_size)
    eval_input_shapes = [
        (224, 224, 3, bs)
        for bs in (device_eval_batch_size, device_last_eval_batch_size)
    ]
  else:
    train_input_shape = (device_batch_size, 224, 224, 3)
    eval_input_shapes = [
        (bs, 224, 224, 3)
        for bs in (device_eval_batch_size, device_last_eval_batch_size)
    ]

  num_epochs = FLAGS.num_epochs
  steps_per_epoch = input_pipeline.TRAIN_IMAGES / batch_size
  logging.info('steps_per_epoch: %f', steps_per_epoch)
  steps_per_eval = int(np.ceil(input_pipeline.EVAL_IMAGES / eval_batch_size))
  logging.info('steps_per_eval: %d', steps_per_eval)

  base_learning_rate = FLAGS.learning_rate * batch_size / 256.
  beta = FLAGS.momentum
  weight_decay = FLAGS.weight_decay

  logging.info('creating and initializing model and optimizer')
  model, state = create_model(rng, device_batch_size, image_size, model_dtype)
  state = jax_utils.replicate(state)
  if FLAGS.lars:
    weight_opt_def = optim.LARS(
        base_learning_rate, beta, weight_decay=weight_decay)
    other_opt_def = optim.Momentum(
        base_learning_rate, beta, weight_decay=0, nesterov=False)
    learning_rate_fn = polynomial_learning_rate_fn(batch_size, steps_per_epoch,
                                                   num_epochs)
  else:
    weight_opt_def = optim.Momentum(
        base_learning_rate, beta, weight_decay=weight_decay, nesterov=True)
    other_opt_def = optim.Momentum(
        base_learning_rate, beta, weight_decay=0, nesterov=True)
    learning_rate_fn = piecewise_learning_rate_fn(base_learning_rate,
                                                  steps_per_epoch, num_epochs)

  def filter_weights(key, _):
    return 'bias' not in key and 'scale' not in key

  def filter_other(key, _):
    return 'bias' in key or 'scale' in key

  weight_traversal = optim.ModelParamTraversal(filter_weights)
  other_traversal = optim.ModelParamTraversal(filter_other)
  optimizer_def = optim.MultiOptimizer((weight_traversal, weight_opt_def),
                                       (other_traversal, other_opt_def))
  optimizer = optimizer_def.create(model)
  optimizer = optimizer.replicate()
  del model  # do not keep a copy of the initial model

  p_train_step = jax.pmap(
      partial(train_step, learning_rate_fn=learning_rate_fn),
      axis_name='batch')
  p_eval_step = jax.pmap(eval_step, axis_name='batch')

  def device_train_loop_cond(args):
    _, _, _, _, step, epoch = args
    return step // steps_per_epoch == epoch

  def device_train_loop_body(args):
    optimizer, state, metrics, token, step, epoch = args
    (images, labels), token = lax.infeed(
        token,
        shape=(jax.ShapedArray(train_input_shape, model_dtype),
               jax.ShapedArray((device_batch_size,), jnp.int32)))
    batch = {'image': images, 'label': labels}
    optimizer, state, metrics = train_step(optimizer, state, batch, metrics,
                                           learning_rate_fn)
    step += 1
    return optimizer, state, metrics, token, step, epoch

  def device_train_loop(optimizer, state, metrics, step, epoch):
    token = lax.create_token(step)
    optimizer, state, metrics, _, step, _ = lax.while_loop(
        device_train_loop_cond, device_train_loop_body,
        (optimizer, state, metrics, token, step, epoch))
    return optimizer, state, metrics, step

  p_train_epoch = jax.pmap(device_train_loop, axis_name='batch')

  if FLAGS.precompile:
    logging.info('precompiling step/epoch functions')
    if FLAGS.infeed:
      # The device training loop condition will immediately be false.
      p_train_epoch(optimizer, state, empty_metrics(), jax_utils.replicate(0),
                    jax_utils.replicate(1))
    else:
      batch = {
          'image': jnp.zeros((jax.local_device_count(),) + train_input_shape,
                             model_dtype),
          'label': jnp.zeros((jax.local_device_count(),) +
                             (device_batch_size,), jnp.int32)
      }
      p_train_step(optimizer, state, batch, empty_metrics())
    for dbs, eis in zip([device_eval_batch_size, device_last_eval_batch_size],
                        eval_input_shapes):
      batch = {
          'image': jnp.zeros((jax.local_device_count(),) + eis, model_dtype),
          'label': jnp.zeros((jax.local_device_count(),) + (dbs,), jnp.int32)
      }
      p_eval_step(optimizer.target, state, batch, empty_metrics())
    allreduce_metrics(empty_metrics())
    pmean = functools.partial(jax.lax.pmean, axis_name='batch')
    jax.pmap(pmean, axis_name='batch')(state)

  logging.info('constructing datasets')
  # pylint: disable=g-complex-comprehension
  train_ds, eval_ds = [
      input_pipeline.load_split(
          local_batch_size if train else local_eval_batch_size,
          image_size=image_size,
          dtype=input_dtype,
          train=train,
          transpose_images=FLAGS.transpose_images)
      for train in (True, False)
  ]
  # pylint: enable=g-complex-comprehension
  logging.info('constructing dataset iterators')
  train_iter = iter(train_ds)
  eval_iter = iter(eval_ds)

  logging.info('beginning training')
  host_step, device_step = 0, jax_utils.replicate(0)
  for epoch in range(num_epochs):
    device_epoch = jax_utils.replicate(epoch)
    metrics = empty_metrics()
    if FLAGS.infeed:
      optimizer, state, metrics, device_step = p_train_epoch(
          optimizer, state, metrics, device_step, device_epoch)
    while int(host_step // steps_per_epoch) == epoch:
      batch = jax.tree_map(lambda x: x._numpy(), next(train_iter))  # pylint: disable=protected-access
      if FLAGS.infeed:
        for i, device in enumerate(jax.local_devices()):
          images, labels = batch['image'][i], batch['label'][i]
          assert images.shape == train_input_shape and labels.dtype == jnp.int32
          infeed_pool.submit(
              partial(device.transfer_to_infeed, (images, labels)))
      else:
        optimizer, state, metrics = p_train_step(optimizer, state, batch,
                                                 metrics)
      host_step += 1

    if FLAGS.train_metrics:
      metrics = allreduce_metrics(metrics)
      if jax.host_id() == 0:
        summary_thread.submit(
            partial(write_summary, summary_writer, metrics, 'train',
                    epoch + 1))

    if not FLAGS.distributed_batchnorm:  # otherwise it's already synced
      pmean = functools.partial(jax.lax.pmean, axis_name='batch')
      state = jax.pmap(pmean, axis_name='batch')(state)

    metrics = empty_metrics()
    for _ in range(steps_per_eval):
      batch = jax.tree_map(lambda x: x._numpy(), next(eval_iter))  # pylint: disable=protected-access
      metrics = p_eval_step(optimizer.target, state, batch, metrics)
    metrics = allreduce_metrics(metrics)
    if jax.host_id() == 0:
      summary_thread.submit(
          partial(write_summary, summary_writer, metrics, 'eval', epoch + 1))
    # TODO(deveci): do something like this from the summary thread:
    # if summary['accuracy'] > TARGET_ACCURACY:
    #   break

  if jax.host_id() == 0:
    summary_thread.shutdown()
  # Wait until computations are done before exiting.
  jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()

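# `piecewise_learning_rate_fn` and `polynomial_learning_rate_fn` are defined
# elsewhere in this codebase. Purely as an illustration (assumed step-decay
# boundaries, not the actual schedule), a piecewise schedule could look like:
def piecewise_learning_rate_fn(base_learning_rate, steps_per_epoch, num_epochs):
  del num_epochs  # This sketch hard-codes 30/60/80-epoch decay boundaries.
  boundaries = jnp.array([30, 60, 80]) * steps_per_epoch
  decays = jnp.array([1., 0.1, 0.01, 0.001])
  def step_fn(step):
    # Index into `decays` by counting how many boundaries have been passed.
    return base_learning_rate * decays[jnp.sum(step >= boundaries)]
  return step_fn
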
def main(_):
  tf.enable_v2_behavior()

  tf.random.set_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  random.seed(FLAGS.seed)

  if not gfile.isdir(FLAGS.save_dir):
    gfile.mkdir(FLAGS.save_dir)

  hparam_str_dict = dict(seed=FLAGS.seed, lr=FLAGS.lr)
  # Get hyperparameters.
  if FLAGS.xm_parameters:
    for key, value in json.loads(FLAGS.xm_parameters).items():
      if key not in hparam_str_dict:
        hparam_str_dict[key] = value
  hparam_str = ','.join(['%s=%s' % (shorten(k), str(hparam_str_dict[k]))
                         for k in sorted(hparam_str_dict.keys())])

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))

  batch_size = FLAGS.per_device_batch_size * n_devices
  io_shape = (FLAGS.per_device_batch_size, FLAGS.num_strings_per_task,
              FLAGS.max_characters)
  program_shape = (FLAGS.per_device_batch_size, FLAGS.num_partial_programs,
                   FLAGS.max_program_length)
  split_io_shape = (FLAGS.per_device_batch_size, FLAGS.num_strings_per_task,
                    FLAGS.num_partial_programs, FLAGS.max_characters)

  # Setup DSL
  # ---------------------------------------------------------------------------

  # Build token tables.
  id_char_table = {i + 1: char for (i, char) in enumerate(dsl.CHARACTER)}
  char_id_table = {char: id for id, char in id_char_table.items()}
  id_token_table, token_id_table = dsl_tokens.build_token_tables()
  io_vocab_size = len(char_id_table) + 1  # For padding.
  program_vocab_size = len(token_id_table) + 1

  bos_token = token_id_table[dsl.BOS]
  eos_token = token_id_table[dsl.EOS]

  # Parse io and program token sequences (for eval).
  def decode_io(inputs, outputs):
    """Decode io examples tokens."""

    def decode_str(s):
      """Decode string tokens."""
      return ''.join([id_char_table[c_id] for c_id in s if c_id > 0])

    inps, outs = [], []
    for inp, out in zip(inputs, outputs):
      inps.append(decode_str(inp))
      outs.append(decode_str(out))
    return inps, outs

  def decode_program(program):
    """Decode program tokens."""
    # Concatenate all partial programs.
    full_program = []
    for p in program:
      full_program.extend(p[:np.argmax(p == eos_token)].astype(np.int32))
    full_program = np.concatenate([full_program, [eos_token]], axis=0)

    try:
      return dsl.decode_program(full_program, id_token_table)
    except:  # pylint: disable=bare-except
      return None  # Program does not compile.

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  if not FLAGS.dataset_filepattern:
    raise ValueError('Must specify filepattern to dataset.')

  # Training dataset.
  dataset = input_pipeline.create_dataset_from_tf_record(
      FLAGS.dataset_filepattern, token_id_table, char_id_table,
      num_partial_programs=FLAGS.num_partial_programs)
  dataset = dataset.padded_batch(
      batch_size,
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:],
                     split_io_shape[1:]),
      drop_remainder=True)
  # Split evaluation and training.
  eval_ds = dataset.take(FLAGS.num_eval_steps)
  # Decrease batch of predict dataset to handle beam search.
  predict_ds = eval_ds.unbatch().padded_batch(
      int(np.ceil(batch_size / 10)),
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:],
                     split_io_shape[1:]))
  train_ds = dataset.skip(FLAGS.num_eval_steps).repeat().prefetch(5)
  train_iter = train_ds.as_numpy_iterator()

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = base_models.TransformerConfig(
      vocab_size=io_vocab_size,
      output_vocab_size=program_vocab_size,
      shift=True,
      emb_dim=FLAGS.embedding_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.embedding_dim,
      mlp_dim=FLAGS.hidden_dim,
      max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
      deterministic=False,
      decode=False,
      bos_token=bos_token)
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(
      shift=False, deterministic=True, decode=not FLAGS.slow_decode)

  rng = jax.random.PRNGKey(FLAGS.seed)
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, init_rng = jax.random.split(rng)

  m = models.DecomposeExpandingLayerTransformer(
      config=eval_config,
      num_partial_programs=FLAGS.num_partial_programs,
      use_expanding_layer=FLAGS.use_expanding_layer)
  initial_variables = jax.jit(m.init)(
      init_rng,
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(program_shape, jnp.float32))

  adam_opt_def = optim.Adam(
      FLAGS.lr,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = adam_opt_def.create(initial_variables['params'])

  del initial_variables  # Don't keep a copy of the initial model.

  start_step = 0
  if FLAGS.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(
        os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)
    logging.info('Found model checkpointed at step %d.', start_step)
    if start_step > 0:
      start_step += 1

  # Build Pretraining Model and Optimizer (if specified)
  # ---------------------------------------------------------------------------
  pretrain_optimizer = None  # Optimizer used for pretraining.
  split_target = None  # Split pretrained model on partial programs.
  if start_step < FLAGS.num_pretrain_steps:
    # Load in pretraining optimizer.
    def filter_fn(path, value):
      del value
      if FLAGS.freeze_encoder and path.startswith('/encoder'):
        return False
      if FLAGS.freeze_decoder and path.startswith('/decoder'):
        return False
      return True

    trainable_weights = optim.ModelParamTraversal(filter_fn)
    pretrain_opt_def = optim.MultiOptimizer((trainable_weights, adam_opt_def))
    pretrain_optimizer = pretrain_opt_def.create(optimizer.target)

    if FLAGS.pretrain_checkpoint_format:
      pretrain_exprs = FLAGS.max_expressions // FLAGS.num_partial_programs
      checkpoint_dir = FLAGS.pretrain_checkpoint_format.format(pretrain_exprs)

      if gfile.isdir(checkpoint_dir):
        # Use the pretrained parameters if no training has occurred yet.
        if start_step == 0:
          restore_paths = []
          if FLAGS.restore_encoder:
            restore_paths.append('target/encoder')
          if FLAGS.restore_decoder:
            restore_paths.append('target/decoder')

          pretrain_optimizer = restore_selected_paths(
              pretrain_optimizer,
              checkpoint_dir=checkpoint_dir,
              restore_paths=restore_paths)
          logging.info('Found model pretrained at %s.', checkpoint_dir)

        if FLAGS.match_split_encoding:
          split_model = models.DecomposeExpandingLayerTransformer(
              config=eval_config,
              num_partial_programs=1,
              use_expanding_layer=False)
          split_program_shape = (FLAGS.per_device_batch_size, 1,
                                 FLAGS.max_program_length)
          split_initial_variables = jax.jit(split_model.init)(
              init_rng,
              jnp.ones(io_shape, jnp.float32),
              jnp.ones(io_shape, jnp.float32),
              jnp.ones(split_program_shape, jnp.float32))
          split_optimizer = adam_opt_def.create(
              split_initial_variables['params'])
          split_optimizer = checkpoints.restore_checkpoint(
              checkpoint_dir, split_optimizer)
          split_target = split_optimizer.target
      else:
        logging.warn('Could not find model at %s.', checkpoint_dir)

    if FLAGS.match_split_encoding and (split_target is None):
      raise RuntimeError('We could not load the pretrained checkpoint, '
                         'which is needed to match split embeddings.')

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=FLAGS.lr)
  p_pretrain_step = jax.pmap(
      functools.partial(
          pretrain_step,
          num_partial_programs=FLAGS.num_partial_programs,
          learning_rate_fn=learning_rate_fn,
          config=train_config,
          use_expanding_layer=FLAGS.use_expanding_layer,
          split_params=split_target),
      axis_name='batch')
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          num_partial_programs=FLAGS.num_partial_programs,
          learning_rate_fn=learning_rate_fn,
          config=train_config,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch')
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step,
          num_partial_programs=FLAGS.num_partial_programs,
          eos_token=eos_token,
          config=eval_config,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          num_partial_programs=FLAGS.num_partial_programs,
          max_decode_len=FLAGS.max_program_length,
          config=predict_config,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step,
          num_partial_programs=FLAGS.num_partial_programs,
          max_decode_len=FLAGS.max_program_length,
          eos_token=eos_token,
          config=predict_config,
          slow_decode=FLAGS.slow_decode,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch',
      static_broadcasted_argnums=(4,))
  p_split_pred_step = jax.pmap(
      functools.partial(
          predict_step,
          num_partial_programs=FLAGS.num_partial_programs,
          max_decode_len=FLAGS.max_program_length,
          eos_token=eos_token,
          config=predict_config,
          slow_decode=FLAGS.slow_decode,
          use_expanding_layer=False,
          use_split_encoding=True,
          split_params=split_target),
      axis_name='batch',
      static_broadcasted_argnums=(4,))

  # Main Train Loop
  # ---------------------------------------------------------------------------
  train_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  # Replicate optimizer.
  if pretrain_optimizer:
    pretrain_optimizer = jax_utils.replicate(pretrain_optimizer)
  optimizer = jax_utils.replicate(optimizer)

  metrics_all = []
  tick = time.time()
  for step in range(start_step, FLAGS.num_train_steps):
    inputs, outputs, programs, split_outputs = (
        common_utils.shard(next(train_iter)))

    if step < FLAGS.num_pretrain_steps:
      pretrain_optimizer, metrics, train_rngs = p_pretrain_step(
          pretrain_optimizer, inputs, outputs, programs,
          split_outputs=split_outputs,
          pretrain_rng=train_rngs)
    else:
      optimizer, metrics, train_rngs = p_train_step(
          optimizer, inputs, outputs, programs, train_rng=train_rngs)
    metrics_all.append(metrics)
    is_last_pretrain_step = step == FLAGS.num_pretrain_steps - 1
    is_last_step = step == FLAGS.num_train_steps - 1

    if is_last_pretrain_step:
      optimizer = maybe_copy_model_from_pretraining(
          optimizer, pretrain_optimizer, step, adam_opt_def)

    # Save a Checkpoint
    if (step % FLAGS.checkpoint_freq == 0 and step > 0) or is_last_step:
      optimizer = maybe_copy_model_from_pretraining(
          optimizer, pretrain_optimizer, step, adam_opt_def)
      if jax.host_id() == 0:
        # Save unreplicated optimizer + model state.
        checkpoints.save_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
            jax_utils.unreplicate(optimizer),
            step)

    # Periodic metric handling.
    if not step or (step % FLAGS.log_freq != 0 and not is_last_step
                    and not is_last_pretrain_step):
      continue

    optimizer = maybe_copy_model_from_pretraining(
        optimizer, pretrain_optimizer, step, adam_opt_def)

    logging.info('Gathering training metrics.')
    # Training Metrics
    metrics_all = common_utils.get_metrics(metrics_all)
    lr = metrics_all.pop('learning_rate').mean()
    metrics_sums = jax.tree_map(jnp.sum, metrics_all)
    denominator = metrics_sums.pop('denominator')
    summary = jax.tree_map(
        lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
        metrics_sums)
    summary['learning_rate'] = lr
    # Calculate (clipped) perplexity after averaging log-perplexities:
    summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)

    if jax.host_id() == 0:
      logging.info('Train in step: %d, loss: %.4f', step, summary['loss'])
      tock = time.time()
      steps_per_sec = FLAGS.log_freq / (tock - tick)
      tick = tock
      summary_writer.scalar('train/steps per second', steps_per_sec, step)
      for key, val in summary.items():
        summary_writer.scalar('train/' + key, val, step)
      summary_writer.flush()
    # Reset metric accumulation for next evaluation cycle.
    metrics_all = []

    # Evaluation Metrics
    logging.info('Gathering evaluation metrics.')
    t_evaluation_start = time.time()
    eval_summary = evaluate(
        p_eval_step=p_eval_step,
        target=optimizer.target,
        eval_ds=eval_ds)
    if jax.host_id() == 0:
      logging.info('Evaluation time: %.4f s step %d, loss: %.4f.',
                   time.time() - t_evaluation_start, step,
                   eval_summary['loss'])
      for key, val in eval_summary.items():
        summary_writer.scalar('eval/' + key, val, step)
      summary_writer.flush()

    # Beam search metrics.
    logging.info('Gathering beam search metrics.')
    for beam_size in [1, 10, 12, 24, 48, 96]:
      t_inference_start = time.time()

      pred_acc, message = predict_and_compute_score(
          p_pred_step=p_pred_step,
          p_init_cache=p_init_cache,
          target=optimizer.target,
          predict_ds=predict_ds,
          decode_io=decode_io,
          decode_program=decode_program,
          beam_size=beam_size,
          num_partial_programs=FLAGS.num_partial_programs,
          use_best_first_search=FLAGS.best_first_search,
          slow_decode=FLAGS.slow_decode)

      # Write to tensorboard.
      if jax.host_id() == 0:
        slow_or_fast = 'slow' if FLAGS.slow_decode else 'fast'
        logging.info(
            'Prediction time, %s (beam %d): %.4f s, step %d, score %.4f',
            slow_or_fast, beam_size,
            time.time() - t_inference_start, step, pred_acc)
        beam_search_or_bfs = ('bfs' if FLAGS.best_first_search
                              else 'beam-search')
        summary_writer.scalar(
            'predict-{}/score-{}-{}'.format(slow_or_fast, beam_search_or_bfs,
                                            beam_size),
            pred_acc, step)
        summary_writer.text('samples-{}'.format(beam_size),
                            '\n------\n'.join(message), step)
        summary_writer.flush()

      if step < FLAGS.num_pretrain_steps and FLAGS.match_split_encoding:
        pred_acc, message = predict_and_compute_score(
            p_pred_step=p_split_pred_step,
            p_init_cache=p_init_cache,
            target=optimizer.target,
            predict_ds=predict_ds,
            decode_io=decode_io,
            decode_program=decode_program,
            beam_size=beam_size,
            num_partial_programs=FLAGS.num_partial_programs,
            use_best_first_search=FLAGS.best_first_search,
            slow_decode=FLAGS.slow_decode)

        # Write to tensorboard.
        if jax.host_id() == 0:
          slow_or_fast = 'slow' if FLAGS.slow_decode else 'fast'
          beam_search_or_bfs = ('bfs' if FLAGS.best_first_search
                                else 'beam-search')
          summary_writer.scalar(
              'predict-split-{}/score-{}-{}'.format(
                  slow_or_fast, beam_search_or_bfs, beam_size),
              pred_acc, step)
          summary_writer.text('samples-split-{}'.format(beam_size),
                              '\n------\n'.join(message), step)
          summary_writer.flush()

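# `maybe_copy_model_from_pretraining` is not shown in this snippet. A plausible
# sketch, assuming it rebuilds the main optimizer around the pretraining
# parameters while pretraining is still in progress (illustration only; the
# actual helper may differ):
def maybe_copy_model_from_pretraining(optimizer, pretrain_optimizer, step,
                                      adam_opt_def):
  if pretrain_optimizer is not None and step < FLAGS.num_pretrain_steps:
    params = jax_utils.unreplicate(pretrain_optimizer).target
    optimizer = jax_utils.replicate(adam_opt_def.create(params))
  return optimizer
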
def main(argv):
  del argv
  # BEGIN GOOGLE-INTERNAL
  xm.setup_work_unit()
  # END GOOGLE-INTERNAL
  tf.enable_v2_behavior()
  init_mllogger()
  mllogger.event('cache_clear')
  mllogger.start('init_start')
  mllogger.event('submission_org', 'Google')
  mllogger.event('submission_platform',
                 'TPUv3-{}'.format(jax.device_count()))
  mllogger.event('submission_division', 'closed')
  mllogger.event('submission_status', 'research')
  mllogger.event('submission_benchmark', 'resnet')
  mllogger.event('train_samples', input_pipeline.TRAIN_IMAGES)
  mllogger.event('eval_samples', input_pipeline.EVAL_IMAGES)

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(FLAGS.output_dir)
    # Write summaries in background thread to avoid blocking on device sync.
    summary_thread = thread.ThreadPoolExecutor(1, 'summary')
  # Infeed is currently synchronous, so do it in a background thread too.
  infeed_pool = thread.ThreadPoolExecutor(jax.local_device_count(), 'infeed')

  if FLAGS.seed is not None:
    seed = FLAGS.seed
  else:
    seed = np.uint32(time.time() if jax.host_id() == 0 else 0)
    seed = per_host_sum_pmap(seed)
  mllogger.event('seed', int(seed))
  key = random.PRNGKey(seed)

  batch_size = FLAGS.batch_size
  if batch_size == -1:
    if jax.device_count() > 4096:
      batch_size = 65536
    else:
      batch_size = min(128 * jax.device_count(), 32768)
  mllogger.event('global_batch_size', batch_size)
  eval_batch_size = min(input_pipeline.EVAL_IMAGES, 256 * jax.device_count())
  device_batch_size = batch_size // jax.device_count()
  device_eval_batch_size = int(math.ceil(eval_batch_size / jax.device_count()))

  model_dtype = jnp.bfloat16 if FLAGS.bfloat16 else jnp.float32
  input_dtype = tf.bfloat16 if FLAGS.bfloat16 else tf.float32

  num_epochs = FLAGS.num_epochs
  if num_epochs is None:
    if batch_size < 32768:
      num_epochs = 56
    elif batch_size < 65536:
      num_epochs = 64
    else:
      num_epochs = 92

  steps_per_epoch = input_pipeline.TRAIN_IMAGES / batch_size
  # Match TF submission behavior (round steps per loop up).
  steps_per_loop = int(math.ceil(steps_per_epoch * FLAGS.epochs_per_loop))
  # Also apply the rounding of loops up to the next step to the "epochs" used
  # in the LR schedule.
  steps_per_epoch *= steps_per_loop / (steps_per_epoch *
                                       FLAGS.epochs_per_loop)
  steps_per_eval = int(math.ceil(input_pipeline.EVAL_IMAGES / eval_batch_size))

  base_learning_rate = FLAGS.learning_rate * batch_size / 256.
  beta = FLAGS.momentum
  if beta is None:
    if batch_size < 32768:
      beta = 0.9
    elif batch_size < 65536:
      beta = 0.929
    else:
      beta = 0.9537213777059405
  weight_decay = FLAGS.weight_decay
  if weight_decay is None:
    weight_decay = 2e-4 if batch_size < 32768 else 1e-4

  space_to_depth = FLAGS.space_to_depth
  if space_to_depth is None:
    space_to_depth = device_batch_size <= 8

  image_format = FLAGS.image_format
  if image_format is None:
    if space_to_depth and device_batch_size <= 8:
      image_format = 'HWNC'
    else:
      image_format = 'HWCN'

  image_size = input_pipeline.IMAGE_SIZE
  if space_to_depth:
    train_input_shape = (device_batch_size, image_size // 2,
                         image_size // 2, 12)
    eval_input_shape = (device_eval_batch_size, image_size // 2,
                        image_size // 2, 12)
  else:
    train_input_shape = (device_batch_size, image_size, image_size, 3)
    eval_input_shape = (device_eval_batch_size, image_size, image_size, 3)
  if image_format == 'HWCN':
    train_input_shape = tuple(train_input_shape[i] for i in [1, 2, 3, 0])
    eval_input_shape = tuple(eval_input_shape[i] for i in [1, 2, 3, 0])
  elif image_format == 'HWNC':
    train_input_shape = tuple(train_input_shape[i] for i in [1, 2, 0, 3])
    eval_input_shape = tuple(eval_input_shape[i] for i in [1, 2, 0, 3])

  model, state = create_model(key, device_batch_size, image_size, model_dtype,
                              space_to_depth)

  if FLAGS.lars:
    mllogger.event('opt_name', 'lars')
    mllogger.event('lars_opt_weight_decay', weight_decay)
    mllogger.event('lars_opt_momentum', beta)
    mllogger.event('lars_epsilon', 0)
    weight_opt_def = optim.LARS(
        base_learning_rate, beta, weight_decay=weight_decay)
    other_opt_def = optim.Momentum(
        base_learning_rate, beta, weight_decay=0, nesterov=False)
    learning_rate_fn = polynomial_learning_rate_fn(batch_size, steps_per_epoch,
                                                   num_epochs)
  else:
    mllogger.event('opt_name', 'sgd')
    mllogger.event('sgd_opt_momentum', beta)
    weight_opt_def = optim.Momentum(
        base_learning_rate, beta, weight_decay=weight_decay, nesterov=True)
    other_opt_def = optim.Momentum(
        base_learning_rate, beta, weight_decay=0, nesterov=True)
    learning_rate_fn = piecewise_learning_rate_fn(base_learning_rate,
                                                  steps_per_epoch, num_epochs)

  def filter_weights(key, _):
    return 'bias' not in key and 'scale' not in key

  def filter_other(key, _):
    return 'bias' in key or 'scale' in key

  weight_traversal = optim.ModelParamTraversal(filter_weights)
  other_traversal = optim.ModelParamTraversal(filter_other)
  optimizer_def = optim.MultiOptimizer((weight_traversal, weight_opt_def),
                                       (other_traversal, other_opt_def))
  optimizer = optimizer_def.create(model)
  del model  # do not keep a copy of the initial model
  optimizer = broadcast(optimizer)
  state = broadcast(state)
  empty_metrics = broadcast({'samples': 0, 'loss': 0., 'accuracy': 0})

  p_allreduce_metrics = jax.pmap(allreduce_metrics, axis_name='batch')

  p_sync_batchnorm_stats = jax.pmap(sync_batchnorm_stats, axis_name='batch')

  def host_loop_train_step(optimizer, state, metrics):
    token = lax.create_token(optimizer.state[0].step)
    batch, token = lax.infeed(
        token,
        shape=(jax.ShapedArray(train_input_shape, model_dtype),
               jax.ShapedArray((device_batch_size,), jnp.int32)))
    optimizer, state, metrics = train_step(optimizer, state, batch, metrics,
                                           learning_rate_fn, image_format,
                                           space_to_depth)
    return optimizer, state, metrics

  p_host_loop_train_step = jax.pmap(
      host_loop_train_step, axis_name='batch', in_axes=(None, 0, 0))

  def host_loop_eval_step(model, state, metrics):
    token = lax.create_token(metrics['samples'])
    batch, token = lax.infeed(
        token,
        shape=(jax.ShapedArray(eval_input_shape, model_dtype),
               jax.ShapedArray((device_eval_batch_size,), jnp.int32)))
    metrics = eval_step(model, state, batch, metrics, image_format,
                        space_to_depth)
    return metrics

  p_host_loop_eval_step = jax.pmap(
      host_loop_eval_step, axis_name='batch', in_axes=(None, None, 0))

  def device_train_loop_cond(args):
    _, _, _, _, step, loop = args
    return step // steps_per_loop == loop

  def device_train_loop_body(args):
    optimizer, state, metrics, token, step, loop = args
    batch, token = lax.infeed(
        token,
        shape=(jax.ShapedArray(train_input_shape, model_dtype),
               jax.ShapedArray((device_batch_size,), jnp.int32)))
    optimizer, state, metrics = train_step(optimizer, state, batch, metrics,
                                           learning_rate_fn, image_format,
                                           space_to_depth)
    step += 1
    return optimizer, state, metrics, token, step, loop

  def device_train_loop(optimizer, state, metrics, step, loop):
    token = lax.create_token(step)
    optimizer, state, metrics, _, step, _ = lax.while_loop(
        device_train_loop_cond, device_train_loop_body,
        (optimizer, state, metrics, token, step, loop))
    state = sync_batchnorm_stats(state)
    metrics = allreduce_metrics(metrics)
    return optimizer, state, metrics, step

  p_train_loop = jax.pmap(
      device_train_loop,
      axis_name='batch',
      in_axes=(None, None, 0, None, None))

  # BEGIN GOOGLE-INTERNAL
  def maybe_start_xprof(seconds):
    if jax.host_id() == 0 and FLAGS.xprof:
      xprof = xprof_session.XprofSession()
      xprof.start_session('REDACTED', True, 2)

      def sleep_and_end_xprof():
        time.sleep(seconds)
        logging.info(
            'Xprof URL: %s',
            xprof.end_session_and_get_url(
                tag='flax resnet, {} devices, batch {} per device'.format(
                    jax.device_count(), device_batch_size)))

      thread.ThreadPoolExecutor(1, 'xprof').submit(sleep_and_end_xprof)
  # END GOOGLE-INTERNAL

  if FLAGS.precompile:
    logging.info('precompiling step/loop functions')
    if FLAGS.device_loop:
      # The device training loop condition will immediately be false.
      p_train_loop(unbroadcast(optimizer), unbroadcast(state), empty_metrics,
                   jnp.array(0, dtype=jnp.int32), 1)
    else:
      for device in jax.local_devices():
        images = np.zeros(train_input_shape, model_dtype)
        labels = np.zeros((device_batch_size,), np.int32)
        infeed_pool.submit(
            partial(device.transfer_to_infeed, (images, labels)))
      p_host_loop_train_step(unbroadcast(optimizer), state, empty_metrics)
      p_sync_batchnorm_stats(state)
    for device in jax.local_devices():
      images = np.zeros(eval_input_shape, model_dtype)
      labels = np.zeros((device_eval_batch_size,), np.int32)
      infeed_pool.submit(partial(device.transfer_to_infeed, (images, labels)))
    p_host_loop_eval_step(unbroadcast(optimizer.target), unbroadcast(state),
                          empty_metrics)
    p_allreduce_metrics(empty_metrics)['accuracy'].block_until_ready()
    logging.info('finished precompiling')
    # BEGIN GOOGLE-INTERNAL
    maybe_start_xprof(20)
    # END GOOGLE-INTERNAL

  if not FLAGS.fake_data:
    logging.info('constructing datasets')
    # pylint: disable=g-complex-comprehension
    train_ds, eval_ds = [
        input_pipeline.load_split(
            device_batch_size if train else device_eval_batch_size,
            dtype=input_dtype,
            train=train,
            image_format=image_format,
            space_to_depth=space_to_depth,
            cache_uncompressed=jax.device_count() > 64)
        for train in (True, False)
    ]
    logging.info('constructing dataset iterators')
    train_iter = iter(train_ds)
    eval_iter = iter(eval_ds)

  local_devices = jax.local_devices()
  host_step, device_step = 0, broadcast(0)
  mllogger.end('init_stop')
  mllogger.start('run_start')
  mllogger.start(
      'block_start',
      metadata={'first_epoch_num': 1, 'epoch_count': FLAGS.epochs_per_loop})
  for loop in range(int(math.ceil(num_epochs / FLAGS.epochs_per_loop)) + 2):
    # BEGIN GOOGLE-INTERNAL
    if loop == 10:
      maybe_start_xprof(1)
    # END GOOGLE-INTERNAL
    metrics = empty_metrics
    if FLAGS.device_loop:
      optimizer, state, metrics, device_step = p_train_loop(
          unbroadcast(optimizer), unbroadcast(state), metrics,
          unbroadcast(device_step), loop)
    while int(host_step // steps_per_loop) == loop:
      if not FLAGS.device_loop:
        optimizer, state, metrics = p_host_loop_train_step(
            unbroadcast(optimizer), state, metrics)
      # pylint: disable=protected-access
      while infeed_pool._work_queue.qsize() > 100:
        time.sleep(0.01)
      for device in local_devices:
        if FLAGS.fake_data:
          images = np.zeros(train_input_shape, model_dtype)
          labels = np.zeros((device_batch_size,), np.int32)
        else:
          # pylint: disable=protected-access
          images, labels = jax.tree_map(lambda x: x._numpy(),
                                        next(train_iter))
        assert images.shape == train_input_shape and labels.dtype == jnp.int32
        infeed_pool.submit(
            partial(device.transfer_to_infeed, (images, labels)))
      host_step += 1
    epoch = (loop + 1) * FLAGS.epochs_per_loop
    if FLAGS.train_metrics:
      if not FLAGS.device_loop:
        metrics = p_allreduce_metrics(metrics)
      if jax.host_id() == 0:
        summary_thread.submit(
            partial(write_summary, summary_writer, metrics, 'train', epoch))
    if not FLAGS.device_loop:
      state = p_sync_batchnorm_stats(state)
    metrics = empty_metrics
    for _ in range(steps_per_eval):
      metrics = p_host_loop_eval_step(unbroadcast(optimizer.target),
                                      unbroadcast(state), metrics)
      for device in local_devices:
        if FLAGS.fake_data:
          images = np.zeros(eval_input_shape, model_dtype)
          labels = np.zeros((device_eval_batch_size,), np.int32)
        else:
          # pylint: disable=protected-access
          images, labels = jax.tree_map(lambda x: x._numpy(),
                                        next(eval_iter))
        assert images.shape == eval_input_shape and labels.dtype == jnp.int32, \
            'images.shape={}'.format(images.shape)
        infeed_pool.submit(
            partial(device.transfer_to_infeed, (images, labels)))
    metrics = p_allreduce_metrics(metrics)
    if jax.host_id() == 0:
      summary_thread.submit(
          partial(write_summary, summary_writer, metrics, 'eval', epoch))
  # Wait until computations are done before exiting.
  p_allreduce_metrics(metrics)['accuracy'].block_until_ready()
  if jax.host_id() == 0:
    summary_thread.shutdown()
    if not DONE:
      mllogger.end('run_stop', metadata={'status': 'aborted'})

def get_optimizer(hps):
  """Constructs the optimizer from the given HParams."""
  if 'weight_decay' in hps.opt_hparams:
    weight_decay = hps.opt_hparams['weight_decay']
  else:
    weight_decay = 0

  if hps.optimizer == 'sgd':
    return optimizers.GradientDescent(learning_rate=None)
  elif hps.optimizer == 'nesterov':
    return optimizers.Momentum(
        learning_rate=None,
        beta=hps.opt_hparams['momentum'],
        nesterov=True,
        weight_decay=weight_decay)
  elif hps.optimizer == 'momentum':
    return optimizers.Momentum(
        learning_rate=None,
        beta=hps.opt_hparams['momentum'],
        nesterov=False,
        weight_decay=weight_decay)
  elif hps.optimizer == 'lamb':
    assert hps.l2_decay_factor is None or weight_decay == 0.0
    return optimizers.LAMB(
        learning_rate=None,
        beta1=hps.opt_hparams['beta1'],
        beta2=hps.opt_hparams['beta2'],
        eps=hps.opt_hparams['epsilon'],
        weight_decay=weight_decay)
  elif hps.optimizer == 'adam':
    assert hps.l2_decay_factor is None or weight_decay == 0.0
    return optimizers.Adam(
        learning_rate=None,
        beta1=hps.opt_hparams['beta1'],
        beta2=hps.opt_hparams['beta2'],
        eps=hps.opt_hparams['epsilon'],
        weight_decay=weight_decay)
  elif hps.optimizer == 'lars':
    assert hps.l2_decay_factor is None or weight_decay == 0.0
    return optimizers.LARS(
        learning_rate=None,
        beta=hps.opt_hparams['beta'],
        weight_decay=weight_decay)
  elif hps.optimizer == 'mlperf_lars_resnet':
    assert hps.l2_decay_factor is None or weight_decay == 0.0
    weight_opt_def = optimizers.LARS(
        learning_rate=None,
        beta=hps.opt_hparams['beta'],
        weight_decay=weight_decay)
    other_opt_def = optimizers.Momentum(
        learning_rate=None,
        beta=hps.opt_hparams['beta'],
        weight_decay=0,
        nesterov=False)

    def filter_weights(key, _):
      return 'bias' not in key and 'scale' not in key

    def filter_other(key, _):
      return 'bias' in key or 'scale' in key

    weight_traversal = optimizers.ModelParamTraversal(filter_weights)
    other_traversal = optimizers.ModelParamTraversal(filter_other)
    return optimizers.MultiOptimizer((weight_traversal, weight_opt_def),
                                     (other_traversal, other_opt_def))
  elif hps.optimizer == 'mlperf_lamb':
    assert hps.l2_decay_factor is None or weight_decay == 0.0
    weight_opt_def = optimizers.LAMB(
        learning_rate=None,
        beta1=hps.opt_hparams['beta1'],
        beta2=hps.opt_hparams['beta2'],
        eps=hps.opt_hparams['epsilon'],
        weight_decay=hps.opt_hparams['lamb_weight_decay'])
    other_opt_def = optimizers.Adam(
        learning_rate=None,
        beta1=hps.opt_hparams['beta1'],
        beta2=hps.opt_hparams['beta2'],
        eps=hps.opt_hparams['epsilon'],
        weight_decay=hps.opt_hparams['adam_weight_decay'])

    def filter_weights(key, _):
      return 'bias' not in key and 'scale' not in key

    def filter_other(key, _):
      return 'bias' in key or 'scale' in key

    weight_traversal = optimizers.ModelParamTraversal(filter_weights)
    other_traversal = optimizers.ModelParamTraversal(filter_other)
    return optimizers.MultiOptimizer((weight_traversal, weight_opt_def),
                                     (other_traversal, other_opt_def))
  else:
    raise NotImplementedError('Optimizer {} not implemented'.format(
        hps.optimizer))

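# Illustrative call of `get_optimizer` (not from the original source). The
# HParams fields shown are assumptions inferred from the branches above;
# ml_collections.ConfigDict stands in for the actual HParams type:
import ml_collections

hps = ml_collections.ConfigDict(dict(
    optimizer='mlperf_lars_resnet',
    l2_decay_factor=None,
    opt_hparams={'beta': 0.9, 'weight_decay': 1e-4},
))
opt_def = get_optimizer(hps)  # learning_rate is set later via hyper-param overrides
# optimizer = opt_def.create(model)  # `model` as elsewhere in this file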