def create_train_state(rng, config: ml_collections.ConfigDict,
                       model, image_size):
  """Create initial training state."""
  dynamic_scale = None
  platform = jax.local_devices()[0].platform
  if config.half_precision and platform == 'gpu':
    dynamic_scale = optim.DynamicScale()
  params, model_state = initialized(rng, image_size, model)
  optimizer = optim.Momentum(
      beta=config.momentum, nesterov=True).create(params)
  state = TrainState(
      step=0,
      optimizer=optimizer,
      model_state=model_state,
      dynamic_scale=dynamic_scale)
  return state
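# The `initialized` helper above is defined elsewhere; a minimal sketch,
# assuming a Flax module carrying batch-norm state: run model.init once on a
# dummy NHWC batch, then split the variables into params and the remaining
# collections (the model state).
import jax.numpy as jnp

def initialized(key, image_size, model):
  input_shape = (1, image_size, image_size, 3)
  variables = model.init({'params': key}, jnp.ones(input_shape, model.dtype))
  model_state, params = variables.pop('params')
  return params, model_state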
def test_dynamic_scale(self):
  def loss_fn(p):
    return jnp.asarray(p, jnp.float16) ** 2

  p = jnp.array(1., jnp.float32)
  dyn_scale = optim.DynamicScale(growth_interval=2)
  step = jax.jit(lambda ds, p: ds.value_and_grad(loss_fn)(p))
  inf = float('inf')
  nan = float('nan')
  expected_values = [
      (False, nan, 32768.0, inf),
      (False, 1.0, 16384.0, inf),
      (True, 1.0, 16384.0, 2.0),
      (True, 1.0, 16384.0, 2.0),
      (True, 1.0, 32768.0, 2.0),
      (False, 1.0, 16384.0, inf),
  ]
  for expected in expected_values:
    dyn_scale, is_fin, loss, grad = step(dyn_scale, p)
    values = onp.array((is_fin, loss, dyn_scale.scale, grad))
    onp.testing.assert_allclose(values, expected)
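# A sketch of how DynamicScale is typically threaded through a training step
# with the legacy flax.optim API exercised by the test above (the real
# train_step functions live elsewhere): value_and_grad rescales the grads back
# to their true magnitude, and the update is applied only when every gradient
# entry is finite, so a shrunken scale can simply retry the step.
import functools
import jax
import jax.numpy as jnp

def half_precision_step(state, loss_fn):
  dyn_scale = state.dynamic_scale
  dyn_scale, is_fin, loss, grad = dyn_scale.value_and_grad(loss_fn)(
      state.optimizer.target)
  new_optimizer = state.optimizer.apply_gradient(grad)
  # On overflow, keep the previous optimizer state (params included).
  new_optimizer = jax.tree_map(
      functools.partial(jnp.where, is_fin), new_optimizer, state.optimizer)
  return state.replace(optimizer=new_optimizer, dynamic_scale=dyn_scale), loss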
def create_train_state(rng, config: ml_collections.ConfigDict,
                       model, image_size, learning_rate_fn):
  """Create initial training state."""
  dynamic_scale = None
  platform = jax.local_devices()[0].platform
  if config.half_precision and platform == 'gpu':
    dynamic_scale = optim.DynamicScale()
  params, batch_stats = initialized(rng, image_size, model)
  tx = optax.sgd(
      learning_rate=learning_rate_fn,
      momentum=config.momentum,
      nesterov=True,
  )
  state = TrainState.create(
      apply_fn=model.apply,
      params=params,
      tx=tx,
      batch_stats=batch_stats,
      dynamic_scale=dynamic_scale)
  return state
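# A minimal, self-contained sketch of the optax transform built above;
# optax.sgd accepts either a float or a schedule callable as learning_rate
# (the cosine schedule here is illustrative).
import jax
import jax.numpy as jnp
import optax

schedule = optax.cosine_decay_schedule(init_value=0.1, decay_steps=1000)
tx = optax.sgd(learning_rate=schedule, momentum=0.9, nesterov=True)

params = {'w': jnp.ones((3,))}
opt_state = tx.init(params)
grads = jax.grad(lambda p: jnp.sum(p['w'] ** 2))(params)
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)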
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  tf.enable_v2_behavior()
  # Make sure tf does not allocate gpu memory.
  tf.config.experimental.set_visible_devices([], 'GPU')

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(FLAGS.model_dir)

  rng = random.PRNGKey(0)

  image_size = 224

  batch_size = FLAGS.batch_size
  if batch_size % jax.device_count() > 0:
    raise ValueError('Batch size must be divisible by the number of devices')
  local_batch_size = batch_size // jax.host_count()
  device_batch_size = batch_size // jax.device_count()

  platform = jax.local_devices()[0].platform

  dynamic_scale = None
  if FLAGS.half_precision:
    if platform == 'tpu':
      model_dtype = jnp.bfloat16
      input_dtype = tf.bfloat16
    else:
      model_dtype = jnp.float16
      input_dtype = tf.float16
      dynamic_scale = optim.DynamicScale()
  else:
    model_dtype = jnp.float32
    input_dtype = tf.float32

  train_iter = create_input_iter(
      local_batch_size, image_size, input_dtype, train=True, cache=FLAGS.cache)
  eval_iter = create_input_iter(
      local_batch_size, image_size, input_dtype, train=False, cache=FLAGS.cache)

  num_epochs = FLAGS.num_epochs
  steps_per_epoch = input_pipeline.TRAIN_IMAGES // batch_size
  steps_per_eval = input_pipeline.EVAL_IMAGES // batch_size
  steps_per_checkpoint = steps_per_epoch * 10
  num_steps = steps_per_epoch * num_epochs

  base_learning_rate = FLAGS.learning_rate * batch_size / 256.

  model, model_state = create_model(
      rng, device_batch_size, image_size, model_dtype)
  optimizer = optim.Momentum(beta=FLAGS.momentum, nesterov=True).create(model)
  state = TrainState(
      step=0,
      optimizer=optimizer,
      model_state=model_state,
      dynamic_scale=dynamic_scale)
  del model, model_state  # Do not keep a copy of the initial model.

  state = restore_checkpoint(state)
  step_offset = int(state.step)  # step_offset > 0 if restarting from checkpoint
  state = jax_utils.replicate(state)

  learning_rate_fn = create_learning_rate_fn(
      base_learning_rate, steps_per_epoch, num_epochs)

  p_train_step = jax.pmap(
      functools.partial(train_step, learning_rate_fn=learning_rate_fn),
      axis_name='batch')
  p_eval_step = jax.pmap(eval_step, axis_name='batch')

  epoch_metrics = []
  t_loop_start = time.time()
  for step, batch in zip(range(step_offset, num_steps), train_iter):
    state, metrics = p_train_step(state, batch)
    epoch_metrics.append(metrics)
    if (step + 1) % steps_per_epoch == 0:
      epoch = step // steps_per_epoch
      epoch_metrics = common_utils.get_metrics(epoch_metrics)
      summary = jax.tree_map(lambda x: x.mean(), epoch_metrics)
      logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f',
                   epoch, summary['loss'], summary['accuracy'] * 100)
      steps_per_sec = steps_per_epoch / (time.time() - t_loop_start)
      t_loop_start = time.time()
      if jax.host_id() == 0:
        for key, vals in epoch_metrics.items():
          tag = 'train_%s' % key
          for i, val in enumerate(vals):
            summary_writer.scalar(tag, val, step - len(vals) + i + 1)
        summary_writer.scalar('steps per second', steps_per_sec, step)
      epoch_metrics = []

      eval_metrics = []
      # Sync batch statistics across replicas.
      state = sync_batch_stats(state)
      for _ in range(steps_per_eval):
        eval_batch = next(eval_iter)
        metrics = p_eval_step(state, eval_batch)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
      logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f',
                   epoch, summary['loss'], summary['accuracy'] * 100)
      if jax.host_id() == 0:
        for key, val in eval_metrics.items():
          tag = 'eval_%s' % key
          summary_writer.scalar(tag, val.mean(), step)
        summary_writer.flush()

    if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
      state = sync_batch_stats(state)
      save_checkpoint(state)

  # Wait until computations are done before exiting.
  jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
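# sync_batch_stats is defined elsewhere; a minimal sketch of the
# cross-replica averaging it relies on, using lax.pmean under pmap (the
# per-device array here is an illustrative stand-in for the replicated
# batch statistics).
import jax
import jax.numpy as jnp
from jax import lax

cross_replica_mean = jax.pmap(lambda x: lax.pmean(x, 'batch'),
                              axis_name='batch')

per_device_stats = jnp.arange(jax.local_device_count(), dtype=jnp.float32)
synced = cross_replica_mean(per_device_stats)  # every replica holds the mean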
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  tf.enable_v2_behavior()
  # Make sure tf does not allocate gpu memory.
  tf.config.experimental.set_visible_devices([], 'GPU')

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(FLAGS.model_dir)

  image_size = 224

  batch_size = FLAGS.batch_size
  if batch_size % jax.device_count() > 0:
    raise ValueError('Batch size must be divisible by the number of devices')
  local_batch_size = batch_size // jax.host_count()
  device_batch_size = batch_size // jax.device_count()

  platform = jax.local_devices()[0].platform

  dynamic_scale = None
  if FLAGS.half_precision:
    if platform == 'tpu':
      model_dtype = jnp.bfloat16
      input_dtype = tf.bfloat16
    else:
      model_dtype = jnp.float16
      input_dtype = tf.float16
      dynamic_scale = optim.DynamicScale()
  else:
    model_dtype = jnp.float32
    input_dtype = tf.float32

  train_iter = imagenet_train_utils.create_input_iter(
      local_batch_size,
      FLAGS.data_dir,
      image_size,
      input_dtype,
      train=True,
      cache=FLAGS.cache)
  eval_iter = imagenet_train_utils.create_input_iter(
      local_batch_size,
      FLAGS.data_dir,
      image_size,
      input_dtype,
      train=False,
      cache=FLAGS.cache)

  # Create the hyperparameter object.
  if FLAGS.hparams_config_dict:
    # In this case, there are multiple training configs defined in the config
    # dict, so we pull out the one this training run should use.
    if 'configs' in FLAGS.hparams_config_dict:
      hparams_config_dict = FLAGS.hparams_config_dict.configs[FLAGS.config_idx]
    else:
      hparams_config_dict = FLAGS.hparams_config_dict
    hparams = os_hparams_utils.load_hparams_from_config_dict(
        hparams_config.TrainingHParams, models.ResNet.HParams,
        hparams_config_dict)
  else:
    raise ValueError('Please provide a base config dict.')

  os_hparams_utils.write_hparams_to_file_with_host_id_check(
      hparams, FLAGS.model_dir)

  # Get num_epochs from hparams instead of FLAGS.
  num_epochs = hparams.lr_scheduler.num_epochs
  steps_per_epoch = input_pipeline.TRAIN_IMAGES // batch_size
  steps_per_eval = input_pipeline.EVAL_IMAGES // batch_size
  steps_per_checkpoint = steps_per_epoch * 10
  num_steps = steps_per_epoch * num_epochs

  # Estimate compute / memory costs.
  if jax.host_id() == 0 and FLAGS.estimate_compute_and_memory_cost:
    estimate_compute_and_memory_cost(
        image_size=image_size, model_dir=FLAGS.model_dir, hparams=hparams)
    logging.info('Writing training HLO and estimating compute/memory costs.')

  rng = random.PRNGKey(hparams.seed)
  model, variables = imagenet_train_utils.create_model(
      rng,
      device_batch_size,
      image_size,
      model_dtype,
      hparams=hparams.model_hparams,
      train=True,
      is_teacher=hparams.is_teacher)

  # pylint: disable=g-long-lambda
  if hparams.teacher_model == 'resnet50-8bit':
    teacher_config = w8a8auto_paper_config()
    teacher_hparams = os_hparams_utils.load_hparams_from_config_dict(
        hparams_config.TrainingHParams, models.ResNet.HParams, teacher_config)
    # The teacher model does not need to be trainable.
    teacher_model, _ = imagenet_train_utils.create_model(
        rng,
        device_batch_size,
        image_size,
        model_dtype,
        hparams=teacher_hparams.model_hparams,
        train=False,
        is_teacher=True)
    # Directory where checkpoints are saved; restores the best checkpoint.
    ckpt_model_dir = FLAGS.resnet508b_ckpt_path
    state_load = checkpoints.restore_checkpoint(ckpt_model_dir, None)
    teacher_variables = {'params': state_load['optimizer']['target']}
    teacher_variables.update(state_load['model_state'])
    # Create a dictionary for better argument passing.
    teacher = {
        'model': lambda var, img, labels: jax.nn.softmax(
            teacher_model.apply(var, img)),
        'variables': teacher_variables,
    }
  elif hparams.teacher_model == 'labels':
    teacher = {
        'model': lambda var, img, labels: common_utils.onehot(
            labels, num_classes=1000),
        'variables': {},  # No variables are needed in this case.
    }
  else:
    raise ValueError('The specified teacher model is not supported.')

  model_state, params = variables.pop('params')
  if hparams.optimizer == 'sgd':
    optimizer = optim.Momentum(
        beta=hparams.momentum, nesterov=True).create(params)
  elif hparams.optimizer == 'adam':
    optimizer = optim.Adam(
        beta1=hparams.adam.beta1, beta2=hparams.adam.beta2).create(params)
  else:
    raise ValueError('Optimizer type is not supported.')

  state = imagenet_train_utils.TrainState(
      step=0,
      optimizer=optimizer,
      model_state=model_state,
      dynamic_scale=dynamic_scale)
  del params, model_state  # Do not keep a copy of the initial model.

  state = restore_checkpoint(state)
  step_offset = int(state.step)  # step_offset > 0 if restarting from checkpoint
  state = jax_utils.replicate(state)

  base_learning_rate = hparams.base_learning_rate * batch_size / 256.
  learning_rate_fn = create_learning_rate_fn(
      base_learning_rate, steps_per_epoch, hparams.lr_scheduler, batch_size)

  p_train_step = jax.pmap(
      functools.partial(
          imagenet_train_utils.train_step,
          model,
          learning_rate_fn=learning_rate_fn,
          teacher=teacher),
      axis_name='batch',
      static_broadcasted_argnums=(2, 3, 4))
  p_eval_step = jax.pmap(
      functools.partial(imagenet_train_utils.eval_step, model),
      axis_name='batch',
      static_broadcasted_argnums=(2,))

  epoch_metrics = []
  state_dict_summary_all = []
  state_dict_keys = _get_state_dict_keys_from_flags()
  t_loop_start = time.time()
  last_log_step = 0
  for step, batch in zip(range(step_offset, num_steps), train_iter):
    if (hparams.early_stop_steps >= 0 and
        step > hparams.early_stop_steps * steps_per_epoch):
      break
    update_bounds = train_utils.should_update_bounds(
        hparams.activation_bound_update_freq,
        hparams.activation_bound_start_step, step)
    # Compute whether weights should be quantized this step, based on
    # hparams.weight_quant_start_step, and pass the resulting bool to
    # p_train_step.
    quantize_weights = train_utils.should_quantize_weights(
        hparams.weight_quant_start_step, step // steps_per_epoch)
    state, metrics = p_train_step(state, batch, hparams, update_bounds,
                                  quantize_weights)

    state_dict_summary = summary_utils.get_state_dict_summary(
        state.model_state, state_dict_keys)
    state_dict_summary_all.append(state_dict_summary)

    epoch_metrics.append(metrics)

    def should_log(step):
      epoch_no = step // steps_per_epoch
      step_in_epoch = step - epoch_no * steps_per_epoch
      do_log = False
      do_log = do_log or (step + 1 == num_steps)  # Log at the end.
      end_of_train = step / num_steps > 0.9
      do_log = do_log or ((step_in_epoch % (steps_per_epoch // 4) == 0) and
                          not end_of_train)
      do_log = do_log or ((step_in_epoch % (steps_per_epoch // 16) == 0) and
                          end_of_train)
      return do_log

    if should_log(step):
      epoch = step // steps_per_epoch
      epoch_metrics = common_utils.get_metrics(epoch_metrics)
      summary = jax.tree_map(lambda x: x.mean(), epoch_metrics)
      logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f',
                   epoch, summary['loss'], summary['accuracy'] * 100)
      steps_per_sec = (step - last_log_step) / (time.time() - t_loop_start)
      last_log_step = step
      t_loop_start = time.time()

      # Write to TensorBoard.
      state_dict_summary_all = common_utils.get_metrics(state_dict_summary_all)
      if jax.host_id() == 0:
        for key, vals in epoch_metrics.items():
          tag = 'train_%s' % key
          for i, val in enumerate(vals):
            summary_writer.scalar(tag, val, step - len(vals) + i + 1)
        summary_writer.scalar('steps per second', steps_per_sec, step)

        if FLAGS.write_summary:
          summary_utils.write_state_dict_summaries_to_tb(
              state_dict_summary_all, summary_writer,
              FLAGS.state_dict_summary_freq, step)

      state_dict_summary_all = []
      epoch_metrics = []
      eval_metrics = []

      # Sync batch statistics across replicas.
      state = imagenet_train_utils.sync_batch_stats(state)
      for _ in range(steps_per_eval):
        eval_batch = next(eval_iter)
        metrics = p_eval_step(state, eval_batch, quantize_weights)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
      logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f',
                   epoch, summary['loss'], summary['accuracy'] * 100)
      if jax.host_id() == 0:
        for key, val in eval_metrics.items():
          tag = 'eval_%s' % key
          summary_writer.scalar(tag, val.mean(), step)
        summary_writer.flush()

    if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
      state = imagenet_train_utils.sync_batch_stats(state)
      save_checkpoint(state)

  # Wait until computations are done before exiting.
  jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
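# The 'resnet50-8bit' and 'labels' teacher branches above produce the
# distillation target in two different ways; a minimal sketch of the two
# signals and a cross-entropy against them. `teacher_logits` and
# `student_logits` are illustrative placeholders, and jax.nn.one_hot stands in
# for common_utils.onehot.
import jax
import jax.numpy as jnp

labels = jnp.array([3, 1])
teacher_logits = jnp.zeros((2, 1000))
student_logits = jnp.zeros((2, 1000))

soft_targets = jax.nn.softmax(teacher_logits)            # 'resnet50-8bit' branch
hard_targets = jax.nn.one_hot(labels, num_classes=1000)  # 'labels' branch

def distill_loss(logits, targets):
  # Cross-entropy against (possibly soft) targets.
  log_probs = jax.nn.log_softmax(logits)
  return -jnp.mean(jnp.sum(targets * log_probs, axis=-1))

loss = distill_loss(student_logits, soft_targets)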
def main(unused_argv):
  rng = random.PRNGKey(20200823)
  # Shift the numpy random seed by host_id() to shuffle data loaded by
  # different hosts.
  np.random.seed(20201473 + jax.host_id())

  if FLAGS.config is not None:
    utils.update_flags(FLAGS)
  if FLAGS.batch_size % jax.device_count() != 0:
    raise ValueError("Batch size must be divisible by the number of devices.")
  if FLAGS.train_dir is None:
    raise ValueError("train_dir must be set. None set now.")
  if FLAGS.data_dir is None:
    raise ValueError("data_dir must be set. None set now.")

  dataset = datasets.get_dataset("train", FLAGS)
  test_dataset = datasets.get_dataset("test", FLAGS)

  rng, key = random.split(rng)
  model, variables = models.get_model(key, dataset.peek(), FLAGS)
  dynamic_scale = optim.DynamicScale()
  optimizer = flax.optim.Adam(FLAGS.lr_init).create(variables)
  state = utils.TrainState(optimizer=optimizer, dynamic_scale=dynamic_scale)
  del optimizer, variables

  learning_rate_fn = functools.partial(
      utils.learning_rate_decay,
      lr_init=FLAGS.lr_init,
      lr_final=FLAGS.lr_final,
      max_steps=FLAGS.max_steps,
      lr_delay_steps=FLAGS.lr_delay_steps,
      lr_delay_mult=FLAGS.lr_delay_mult)

  train_pstep = jax.pmap(
      functools.partial(train_step, model),
      axis_name="batch",
      in_axes=(0, 0, 0, None),
      donate_argnums=(2,))

  def render_fn(variables, key_0, key_1, rays):
    return jax.lax.all_gather(
        model.apply(variables, key_0, key_1, rays, FLAGS.randomized),
        axis_name="batch")

  render_pfn = jax.pmap(
      render_fn,
      in_axes=(None, None, None, 0),  # Only distribute the data input.
      donate_argnums=(3,),
      axis_name="batch",
  )

  # Compiling to the CPU because it's faster and more accurate.
  ssim_fn = jax.jit(
      functools.partial(utils.compute_ssim, max_val=1.), backend="cpu")

  if not utils.isdir(FLAGS.train_dir):
    utils.makedirs(FLAGS.train_dir)
  state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
  # Resume training at the step of the last checkpoint.
  init_step = state.optimizer.state.step + 1
  state = flax.jax_utils.replicate(state)

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(FLAGS.train_dir)

  # Prefetch_buffer_size = 3 x batch_size.
  pdataset = flax.jax_utils.prefetch_to_device(dataset, 3)
  n_local_devices = jax.local_device_count()
  rng = rng + jax.host_id()  # Make random seed separate across hosts.
  keys = random.split(rng, n_local_devices)  # For pmapping RNG keys.
  gc.disable()  # Disable automatic garbage collection for efficiency.
  stats_trace = []
  reset_timer = True
  for step, batch in zip(range(init_step, FLAGS.max_steps + 1), pdataset):
    if reset_timer:
      t_loop_start = time.time()
      reset_timer = False
    lr = learning_rate_fn(step)
    state, stats, keys = train_pstep(keys, state, batch, lr)
    if jax.host_id() == 0:
      stats_trace.append(stats)
    if step % FLAGS.gc_every == 0:
      gc.collect()

    # Log training summaries. This is put behind a host_id check because in
    # multi-host evaluation, all hosts need to run inference even though we
    # only use host 0 to record results.
    if jax.host_id() == 0:
      if step % FLAGS.print_every == 0:
        summary_writer.scalar("train_loss", stats.loss[0], step)
        summary_writer.scalar("train_psnr", stats.psnr[0], step)
        summary_writer.scalar("train_loss_coarse", stats.loss_c[0], step)
        summary_writer.scalar("train_psnr_coarse", stats.psnr_c[0], step)
        summary_writer.scalar("weight_l2", stats.weight_l2[0], step)
        avg_loss = np.mean(np.concatenate([s.loss for s in stats_trace]))
        avg_psnr = np.mean(np.concatenate([s.psnr for s in stats_trace]))
        stats_trace = []
        summary_writer.scalar("train_avg_loss", avg_loss, step)
        summary_writer.scalar("train_avg_psnr", avg_psnr, step)
        summary_writer.scalar("learning_rate", lr, step)
        steps_per_sec = FLAGS.print_every / (time.time() - t_loop_start)
        reset_timer = True
        rays_per_sec = FLAGS.batch_size * steps_per_sec
        summary_writer.scalar("train_steps_per_sec", steps_per_sec, step)
        summary_writer.scalar("train_rays_per_sec", rays_per_sec, step)
        precision = int(np.ceil(np.log10(FLAGS.max_steps))) + 1
        print(("{:" + "{:d}".format(precision) + "d}").format(step) +
              f"/{FLAGS.max_steps:d}: " +
              f"i_loss={stats.loss[0]:0.4f}, " +
              f"avg_loss={avg_loss:0.4f}, " +
              f"weight_l2={stats.weight_l2[0]:0.2e}, " +
              f"lr={lr:0.2e}, " +
              f"{rays_per_sec:0.0f} rays/sec")
      if step % FLAGS.save_every == 0:
        state_to_save = jax.device_get(jax.tree_map(lambda x: x[0], state))
        checkpoints.save_checkpoint(
            FLAGS.train_dir, state_to_save, int(step), keep=100)

    # Test-set evaluation.
    if FLAGS.render_every > 0 and step % FLAGS.render_every == 0:
      # We reuse the same random number generator from the optimization step
      # here on purpose so that the visualization matches what happened in
      # training.
      t_eval_start = time.time()
      eval_variables = jax.device_get(
          jax.tree_map(lambda x: x[0], state)).optimizer.target
      test_case = next(test_dataset)
      pred_color, pred_disp, pred_acc = utils.render_image(
          functools.partial(render_pfn, eval_variables),
          test_case["rays"],
          keys[0],
          FLAGS.dataset == "llff",
          chunk=FLAGS.chunk)

      # Log eval summaries on host 0.
      if jax.host_id() == 0:
        psnr = utils.compute_psnr(
            ((pred_color - test_case["pixels"])**2).mean())
        ssim = ssim_fn(pred_color, test_case["pixels"])
        eval_time = time.time() - t_eval_start
        num_rays = jnp.prod(jnp.array(test_case["rays"].directions.shape[:-1]))
        rays_per_sec = num_rays / eval_time
        summary_writer.scalar("test_rays_per_sec", rays_per_sec, step)
        print(f"Eval {step}: {eval_time:0.3f}s., {rays_per_sec:0.0f} rays/sec")
        summary_writer.scalar("test_psnr", psnr, step)
        summary_writer.scalar("test_ssim", ssim, step)
        summary_writer.image("test_pred_color", pred_color, step)
        summary_writer.image("test_pred_disp", pred_disp, step)
        summary_writer.image("test_pred_acc", pred_acc, step)
        summary_writer.image("test_target", test_case["pixels"], step)

  if FLAGS.max_steps % FLAGS.save_every != 0:
    state = jax.device_get(jax.tree_map(lambda x: x[0], state))
    checkpoints.save_checkpoint(
        FLAGS.train_dir, state, int(FLAGS.max_steps), keep=100)
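# utils.learning_rate_decay is defined elsewhere; a sketch, assuming a
# log-linear interpolation between lr_init and lr_final plus an optional
# warmup governed by lr_delay_steps/lr_delay_mult, matching the parameters
# partially applied above.
import numpy as np

def learning_rate_decay(step, lr_init, lr_final, max_steps,
                        lr_delay_steps=0, lr_delay_mult=1.0):
  if lr_delay_steps > 0:
    # Ramp the rate from lr_delay_mult * lr up to the full schedule.
    delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
        0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1))
  else:
    delay_rate = 1.0
  t = np.clip(step / max_steps, 0, 1)
  log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
  return delay_rate * log_lerp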
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  tf.enable_v2_behavior()
  # Make sure tf does not allocate gpu memory.
  tf.config.experimental.set_visible_devices([], 'GPU')

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(FLAGS.model_dir)

  rng = random.PRNGKey(0)

  image_size = 224

  batch_size = FLAGS.batch_size
  if batch_size % jax.device_count() > 0:
    raise ValueError('Batch size must be divisible by the number of devices')
  local_batch_size = batch_size // jax.host_count()
  device_batch_size = batch_size // jax.device_count()

  platform = jax.local_devices()[0].platform

  dynamic_scale = None
  if FLAGS.half_precision:
    if platform == 'tpu':
      model_dtype = jnp.bfloat16
      input_dtype = tf.bfloat16
    else:
      model_dtype = jnp.float16
      input_dtype = tf.float16
      dynamic_scale = optim.DynamicScale()
  else:
    model_dtype = jnp.float32
    input_dtype = tf.float32

  train_iter = imagenet_train_utils.create_input_iter(
      local_batch_size,
      FLAGS.data_dir,
      image_size,
      input_dtype,
      train=True,
      cache=FLAGS.cache)
  eval_iter = imagenet_train_utils.create_input_iter(
      local_batch_size,
      FLAGS.data_dir,
      image_size,
      input_dtype,
      train=False,
      cache=FLAGS.cache)

  # Create the hyperparameter object.
  if FLAGS.hparams_config_dict:
    # In this case, there are multiple training configs defined in the config
    # dict, so we pull out the one this training run should use.
    if 'configs' in FLAGS.hparams_config_dict:
      hparams_config_dict = FLAGS.hparams_config_dict.configs[FLAGS.config_idx]
    else:
      hparams_config_dict = FLAGS.hparams_config_dict
    hparams = os_hparams_utils.load_hparams_from_config_dict(
        hparams_config.TrainingHParams, models.ResNet.HParams,
        hparams_config_dict)
  else:
    raise ValueError('Please provide a base config dict.')

  os_hparams_utils.write_hparams_to_file_with_host_id_check(
      hparams, FLAGS.model_dir)

  # Get num_epochs from hparams instead of FLAGS.
  num_epochs = hparams.lr_scheduler.num_epochs
  steps_per_epoch = input_pipeline.TRAIN_IMAGES // batch_size
  steps_per_eval = input_pipeline.EVAL_IMAGES // batch_size
  steps_per_checkpoint = steps_per_epoch * 10
  num_steps = steps_per_epoch * num_epochs

  # Estimate compute / memory costs.
  if jax.host_id() == 0:
    estimate_compute_and_memory_cost(
        image_size=image_size, model_dir=FLAGS.model_dir, hparams=hparams)
    logging.info('Writing training HLO and estimating compute/memory costs.')

  model, variables = imagenet_train_utils.create_model(
      rng,
      device_batch_size,
      image_size,
      model_dtype,
      hparams=hparams.model_hparams,
      train=True)
  model_state, params = variables.pop('params')
  if hparams.optimizer == 'sgd':
    optimizer = optim.Momentum(
        beta=hparams.momentum, nesterov=True).create(params)
  elif hparams.optimizer == 'adam':
    optimizer = optim.Adam(
        beta1=hparams.adam.beta1, beta2=hparams.adam.beta2).create(params)
  else:
    raise ValueError('Optimizer type is not supported.')

  state = imagenet_train_utils.TrainState(
      step=0,
      optimizer=optimizer,
      model_state=model_state,
      dynamic_scale=dynamic_scale)
  del params, model_state  # Do not keep a copy of the initial model.

  state = restore_checkpoint(state)
  step_offset = int(state.step)  # step_offset > 0 if restarting from checkpoint
  state = jax_utils.replicate(state)

  base_learning_rate = hparams.base_learning_rate * batch_size / 256.
  learning_rate_fn = create_learning_rate_fn(
      base_learning_rate, steps_per_epoch, hparams.lr_scheduler)

  p_train_step = jax.pmap(
      functools.partial(
          imagenet_train_utils.train_step,
          model,
          learning_rate_fn=learning_rate_fn),
      axis_name='batch',
      static_broadcasted_argnums=(2, 3))
  p_eval_step = jax.pmap(
      functools.partial(imagenet_train_utils.eval_step, model),
      axis_name='batch')

  epoch_metrics = []
  state_dict_summary_all = []
  state_dict_keys = _get_state_dict_keys_from_flags()
  t_loop_start = time.time()
  for step, batch in zip(range(step_offset, num_steps), train_iter):
    if hparams.early_stop_steps >= 0 and step > hparams.early_stop_steps:
      break
    update_bounds = train_utils.should_update_bounds(
        hparams.activation_bound_update_freq,
        hparams.activation_bound_start_step, step)
    state, metrics = p_train_step(state, batch, hparams, update_bounds)

    state_dict_summary = summary_utils.get_state_dict_summary(
        state.model_state, state_dict_keys)
    state_dict_summary_all.append(state_dict_summary)

    epoch_metrics.append(metrics)
    if (step + 1) % steps_per_epoch == 0:
      epoch = step // steps_per_epoch
      epoch_metrics = common_utils.get_metrics(epoch_metrics)
      summary = jax.tree_map(lambda x: x.mean(), epoch_metrics)
      logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f',
                   epoch, summary['loss'], summary['accuracy'] * 100)
      steps_per_sec = steps_per_epoch / (time.time() - t_loop_start)
      t_loop_start = time.time()

      # Write to TensorBoard.
      state_dict_summary_all = common_utils.get_metrics(state_dict_summary_all)
      if jax.host_id() == 0:
        for key, vals in epoch_metrics.items():
          tag = 'train_%s' % key
          for i, val in enumerate(vals):
            summary_writer.scalar(tag, val, step - len(vals) + i + 1)
        summary_writer.scalar('steps per second', steps_per_sec, step)
        summary_utils.write_state_dict_summaries_to_tb(
            state_dict_summary_all, summary_writer,
            FLAGS.state_dict_summary_freq, step)

      state_dict_summary_all = []
      epoch_metrics = []
      eval_metrics = []

      # Sync batch statistics across replicas.
      state = imagenet_train_utils.sync_batch_stats(state)
      for _ in range(steps_per_eval):
        eval_batch = next(eval_iter)
        metrics = p_eval_step(state, eval_batch)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
      logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f',
                   epoch, summary['loss'], summary['accuracy'] * 100)
      if jax.host_id() == 0:
        for key, val in eval_metrics.items():
          tag = 'eval_%s' % key
          summary_writer.scalar(tag, val.mean(), step)
        summary_writer.flush()

    if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
      state = imagenet_train_utils.sync_batch_stats(state)
      save_checkpoint(state)

  # Wait until computations are done before exiting.
  jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
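# create_learning_rate_fn is defined elsewhere and is driven here by
# hparams.lr_scheduler, whose internals are not shown; purely as an
# assumption-labeled sketch, cosine decay with a linear warmup is a common
# choice for this ImageNet recipe.
import jax.numpy as jnp

def make_cosine_lr_fn(base_lr, steps_per_epoch, num_epochs, warmup_epochs=5):
  warmup_steps = warmup_epochs * steps_per_epoch
  decay_epochs = num_epochs - warmup_epochs
  def lr_fn(step):
    warmup = base_lr * step / warmup_steps
    epoch = step / steps_per_epoch - warmup_epochs
    cosine = base_lr * 0.5 * (1. + jnp.cos(jnp.pi * epoch / decay_epochs))
    return jnp.where(step < warmup_steps, warmup, cosine)
  return lr_fn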
def train_and_evaluate(model_dir: str, batch_size: int, num_epochs: int,
                       learning_rate: float, momentum: float, cache: bool,
                       half_precision: bool, num_train_steps: int = -1,
                       num_eval_steps: int = -1):
  """Runs the model training and evaluation loop.

  Args:
    model_dir: Directory where the checkpoints and tensorboard summaries
      should be written to.
    batch_size: Batch size of the input.
    num_epochs: Number of epochs to cycle through before stopping.
    learning_rate: The learning rate for batch size 256. The effective
      learning rate is scaled linearly with the batch size.
    momentum: Momentum value for the momentum optimizer.
    cache: Determines whether the dataset should be cached.
    half_precision: Determines whether bfloat16/float16 should be used
      instead of float32.
    num_train_steps: Number of training steps to be executed in a single
      epoch. Default = -1 signifies using the entire TRAIN split.
    num_eval_steps: Number of evaluation steps to be executed in a single
      epoch. Default = -1 signifies using the entire VALIDATION split.
  """
  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(model_dir)

  rng = random.PRNGKey(0)

  image_size = 224

  if batch_size % jax.device_count() > 0:
    raise ValueError('Batch size must be divisible by the number of devices')
  local_batch_size = batch_size // jax.host_count()
  device_batch_size = batch_size // jax.device_count()

  platform = jax.local_devices()[0].platform

  dynamic_scale = None
  if half_precision:
    if platform == 'tpu':
      model_dtype = jnp.bfloat16
      input_dtype = tf.bfloat16
    else:
      model_dtype = jnp.float16
      input_dtype = tf.float16
      dynamic_scale = optim.DynamicScale()
  else:
    model_dtype = jnp.float32
    input_dtype = tf.float32

  dataset_builder = tfds.builder('imagenet2012:5.*.*')
  train_iter = create_input_iter(
      dataset_builder, local_batch_size, image_size, input_dtype,
      train=True, cache=cache)
  eval_iter = create_input_iter(
      dataset_builder, local_batch_size, image_size, input_dtype,
      train=False, cache=cache)

  if num_train_steps == -1:
    steps_per_epoch = (
        dataset_builder.info.splits['train'].num_examples // batch_size)
  else:
    steps_per_epoch = num_train_steps

  if num_eval_steps == -1:
    steps_per_eval = (
        dataset_builder.info.splits['validation'].num_examples // batch_size)
  else:
    steps_per_eval = num_eval_steps

  steps_per_checkpoint = steps_per_epoch * 10
  num_steps = steps_per_epoch * num_epochs

  base_learning_rate = learning_rate * batch_size / 256.

  model, model_state = create_model(
      rng, device_batch_size, image_size, model_dtype)
  optimizer = optim.Momentum(beta=momentum, nesterov=True).create(model)
  state = TrainState(
      step=0,
      optimizer=optimizer,
      model_state=model_state,
      dynamic_scale=dynamic_scale)
  del model, model_state  # Do not keep a copy of the initial model.

  state = restore_checkpoint(model_dir, state)
  step_offset = int(state.step)  # step_offset > 0 if restarting from checkpoint
  state = jax_utils.replicate(state)

  learning_rate_fn = create_learning_rate_fn(
      base_learning_rate, steps_per_epoch, num_epochs)

  p_train_step = jax.pmap(
      functools.partial(train_step, learning_rate_fn=learning_rate_fn),
      axis_name='batch')
  p_eval_step = jax.pmap(eval_step, axis_name='batch')

  epoch_metrics = []
  t_loop_start = time.time()
  for step, batch in zip(range(step_offset, num_steps), train_iter):
    state, metrics = p_train_step(state, batch)
    epoch_metrics.append(metrics)
    if (step + 1) % steps_per_epoch == 0:
      epoch = step // steps_per_epoch
      epoch_metrics = common_utils.get_metrics(epoch_metrics)
      summary = jax.tree_map(lambda x: x.mean(), epoch_metrics)
      logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f',
                   epoch, summary['loss'], summary['accuracy'] * 100)
      steps_per_sec = steps_per_epoch / (time.time() - t_loop_start)
      t_loop_start = time.time()
      if jax.host_id() == 0:
        for key, vals in epoch_metrics.items():
          tag = 'train_%s' % key
          for i, val in enumerate(vals):
            summary_writer.scalar(tag, val, step - len(vals) + i + 1)
        summary_writer.scalar('steps per second', steps_per_sec, step)
      epoch_metrics = []

      eval_metrics = []
      # Sync batch statistics across replicas.
      state = sync_batch_stats(state)
      for _ in range(steps_per_eval):
        eval_batch = next(eval_iter)
        metrics = p_eval_step(state, eval_batch)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
      logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f',
                   epoch, summary['loss'], summary['accuracy'] * 100)
      if jax.host_id() == 0:
        for key, val in eval_metrics.items():
          tag = 'eval_%s' % key
          summary_writer.scalar(tag, val.mean(), step)
        summary_writer.flush()

    if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
      state = sync_batch_stats(state)
      save_checkpoint(model_dir, state)

  # Wait until computations are done before exiting.
  jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
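# The save_checkpoint/restore_checkpoint helpers used above are defined
# elsewhere; a minimal sketch, assuming the flax.training.checkpoints API:
# only host 0 writes, after pulling one replica of the replicated state back
# to the host.
import jax
from flax.training import checkpoints

def restore_checkpoint(model_dir, state):
  return checkpoints.restore_checkpoint(model_dir, state)

def save_checkpoint(model_dir, state):
  if jax.host_id() == 0:
    state = jax.device_get(jax.tree_map(lambda x: x[0], state))
    checkpoints.save_checkpoint(model_dir, state, int(state.step), keep=3)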
def train_and_evaluate(config: ml_collections.ConfigDict, model_dir: str):
  """Execute model training and evaluation loop.

  Args:
    config: Hyperparameter configuration for training and evaluation.
    model_dir: Directory where the tensorboard summaries are written to.
  """
  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(model_dir)

  rng = random.PRNGKey(0)

  image_size = 224

  if config.batch_size % jax.device_count() > 0:
    raise ValueError('Batch size must be divisible by the number of devices')
  local_batch_size = config.batch_size // jax.host_count()

  platform = jax.local_devices()[0].platform

  dynamic_scale = None
  if config.half_precision:
    if platform == 'tpu':
      input_dtype = tf.bfloat16
    else:
      input_dtype = tf.float16
      dynamic_scale = optim.DynamicScale()
  else:
    input_dtype = tf.float32

  dataset_builder = tfds.builder('imagenet2012:5.*.*')
  train_iter = create_input_iter(
      dataset_builder, local_batch_size, image_size, input_dtype,
      train=True, cache=config.cache)
  eval_iter = create_input_iter(
      dataset_builder, local_batch_size, image_size, input_dtype,
      train=False, cache=config.cache)

  steps_per_epoch = (
      dataset_builder.info.splits['train'].num_examples // config.batch_size)

  if config.num_train_steps == -1:
    num_steps = steps_per_epoch * config.num_epochs
  else:
    num_steps = config.num_train_steps

  if config.steps_per_eval == -1:
    # Note: the eval step count must come from the validation split, not the
    # train split.
    num_validation_examples = (
        dataset_builder.info.splits['validation'].num_examples)
    steps_per_eval = num_validation_examples // config.batch_size
  else:
    steps_per_eval = config.steps_per_eval

  steps_per_checkpoint = steps_per_epoch * 10

  base_learning_rate = config.learning_rate * config.batch_size / 256.

  variables = initialized(rng, image_size, config.half_precision)
  optimizer = optim.Momentum(
      beta=config.momentum, nesterov=True).create(variables['params'])
  state = TrainState(
      step=0,
      optimizer=optimizer,
      batch_stats=variables['batch_stats'],
      dynamic_scale=dynamic_scale)
  state = restore_checkpoint(state, model_dir)
  # step_offset > 0 if restarting from checkpoint.
  step_offset = int(state.step)
  state = jax_utils.replicate(state)

  learning_rate_fn = create_learning_rate_fn(
      base_learning_rate, steps_per_epoch, config.num_epochs)

  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          half_precision=config.half_precision,
          learning_rate_fn=learning_rate_fn),
      axis_name='batch')
  # half_precision is a Python scalar, so it is bound via partial rather than
  # passed as a pmapped argument (mirroring p_train_step above).
  p_eval_step = jax.pmap(
      functools.partial(eval_step, half_precision=config.half_precision),
      axis_name='batch')

  epoch_metrics = []
  t_loop_start = time.time()
  for step, batch in zip(range(step_offset, num_steps), train_iter):
    state, metrics = p_train_step(state, batch)
    epoch_metrics.append(metrics)
    if (step + 1) % steps_per_epoch == 0:
      epoch = step // steps_per_epoch
      epoch_metrics = common_utils.get_metrics(epoch_metrics)
      summary = jax.tree_map(lambda x: x.mean(), epoch_metrics)
      logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f',
                   epoch, summary['loss'], summary['accuracy'] * 100)
      steps_per_sec = steps_per_epoch / (time.time() - t_loop_start)
      t_loop_start = time.time()
      if jax.host_id() == 0:
        for key, vals in epoch_metrics.items():
          tag = 'train_%s' % key
          for i, val in enumerate(vals):
            summary_writer.scalar(tag, val, step - len(vals) + i + 1)
        summary_writer.scalar('steps per second', steps_per_sec, step)
      epoch_metrics = []

      eval_metrics = []
      # Sync batch statistics across replicas.
      state = sync_batch_stats(state)
      for _ in range(steps_per_eval):
        eval_batch = next(eval_iter)
        metrics = p_eval_step(state, eval_batch)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
      logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f',
                   epoch, summary['loss'], summary['accuracy'] * 100)
      if jax.host_id() == 0:
        for key, val in eval_metrics.items():
          tag = 'eval_%s' % key
          summary_writer.scalar(tag, val.mean(), step)
        summary_writer.flush()

    if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
      state = sync_batch_stats(state)
      save_checkpoint(state, model_dir)

  # Wait until computations are done before exiting.
  jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
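# A hypothetical invocation sketch: build an ml_collections.ConfigDict with
# the fields train_and_evaluate reads above and launch a run. All values here
# are illustrative, not recommended settings.
import ml_collections

config = ml_collections.ConfigDict()
config.batch_size = 512
config.num_epochs = 90
config.num_train_steps = -1
config.steps_per_eval = -1
config.learning_rate = 0.1
config.momentum = 0.9
config.cache = False
config.half_precision = False

train_and_evaluate(config, model_dir='/tmp/imagenet_run')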