def test_should_update_bounds(self, frequency, start_step, current_step,
                              should_update):
  self.assertEqual(
      train_utils.should_update_bounds(
          activation_bound_update_freq=frequency,
          activation_bound_start_step=start_step,
          step=current_step), should_update)
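
# For reference, a minimal sketch of the semantics the test above implies for
# train_utils.should_update_bounds. This is an illustrative assumption, not
# the library's actual implementation: bounds are never updated before
# activation_bound_start_step (or at all if it is negative), and afterwards
# they are updated every activation_bound_update_freq steps, with a
# non-positive frequency meaning "update only once, at the start step".
def _should_update_bounds_sketch(activation_bound_update_freq,
                                 activation_bound_start_step, step):
  if activation_bound_start_step < 0:
    return False
  steps_since_start = step - activation_bound_start_step
  if steps_since_start < 0:
    return False
  if activation_bound_update_freq <= 0:
    return steps_since_start == 0
  return steps_since_start % activation_bound_update_freq == 0
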
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  tf.enable_v2_behavior()
  # make sure tf does not allocate gpu memory
  tf.config.experimental.set_visible_devices([], 'GPU')

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(FLAGS.model_dir)

  image_size = 224

  batch_size = FLAGS.batch_size
  if batch_size % jax.device_count() > 0:
    raise ValueError('Batch size must be divisible by the number of devices')
  local_batch_size = batch_size // jax.host_count()
  device_batch_size = batch_size // jax.device_count()

  platform = jax.local_devices()[0].platform

  dynamic_scale = None
  if FLAGS.half_precision:
    if platform == 'tpu':
      model_dtype = jnp.bfloat16
      input_dtype = tf.bfloat16
    else:
      model_dtype = jnp.float16
      input_dtype = tf.float16
      dynamic_scale = optim.DynamicScale()
  else:
    model_dtype = jnp.float32
    input_dtype = tf.float32

  train_iter = imagenet_train_utils.create_input_iter(
      local_batch_size,
      FLAGS.data_dir,
      image_size,
      input_dtype,
      train=True,
      cache=FLAGS.cache)
  eval_iter = imagenet_train_utils.create_input_iter(
      local_batch_size,
      FLAGS.data_dir,
      image_size,
      input_dtype,
      train=False,
      cache=FLAGS.cache)

  # Create the hyperparameter object
  if FLAGS.hparams_config_dict:
    # In this case, there are multiple training configs defined in the config
    # dict, so we pull out the one this training run should use.
    if 'configs' in FLAGS.hparams_config_dict:
      hparams_config_dict = FLAGS.hparams_config_dict.configs[FLAGS.config_idx]
    else:
      hparams_config_dict = FLAGS.hparams_config_dict
    hparams = os_hparams_utils.load_hparams_from_config_dict(
        hparams_config.TrainingHParams, models.ResNet.HParams,
        hparams_config_dict)
  else:
    raise ValueError('Please provide a base config dict.')

  os_hparams_utils.write_hparams_to_file_with_host_id_check(
      hparams, FLAGS.model_dir)

  # get num_epochs from hparams instead of FLAGS
  num_epochs = hparams.lr_scheduler.num_epochs

  steps_per_epoch = input_pipeline.TRAIN_IMAGES // batch_size
  steps_per_eval = input_pipeline.EVAL_IMAGES // batch_size
  steps_per_checkpoint = steps_per_epoch * 10
  num_steps = steps_per_epoch * num_epochs

  # Estimate compute / memory costs
  if jax.host_id() == 0 and FLAGS.estimate_compute_and_memory_cost:
    estimate_compute_and_memory_cost(
        image_size=image_size, model_dir=FLAGS.model_dir, hparams=hparams)
    logging.info('Writing training HLO and estimating compute/memory costs.')

  rng = random.PRNGKey(hparams.seed)

  model, variables = imagenet_train_utils.create_model(
      rng,
      device_batch_size,
      image_size,
      model_dtype,
      hparams=hparams.model_hparams,
      train=True,
      is_teacher=hparams.is_teacher)

  # pylint: disable=g-long-lambda
  if hparams.teacher_model == 'resnet50-8bit':
    teacher_config = w8a8auto_paper_config()
    teacher_hparams = os_hparams_utils.load_hparams_from_config_dict(
        hparams_config.TrainingHParams, models.ResNet.HParams, teacher_config)
    # The teacher model does not need to be trainable.
    teacher_model, _ = imagenet_train_utils.create_model(
        rng,
        device_batch_size,
        image_size,
        model_dtype,
        hparams=teacher_hparams.model_hparams,
        train=False,
        is_teacher=True)

    # Directory where checkpoints are saved; restores the best checkpoint.
    ckpt_model_dir = FLAGS.resnet508b_ckpt_path
    state_load = checkpoints.restore_checkpoint(ckpt_model_dir, None)
    teacher_variables = {'params': state_load['optimizer']['target']}
    teacher_variables.update(state_load['model_state'])

    # Bundle the teacher into a dictionary for easier argument passing.
    teacher = {
        'model': lambda var, img, labels: jax.nn.softmax(
            teacher_model.apply(var, img)),
        'variables': teacher_variables,
    }
  elif hparams.teacher_model == 'labels':
    teacher = {
        'model': lambda var, img, labels: common_utils.onehot(
            labels, num_classes=1000),
        'variables': {},  # no variables are needed in this case
    }
  else:
    raise ValueError('The specified teacher model is not supported.')

  model_state, params = variables.pop('params')
  if hparams.optimizer == 'sgd':
    optimizer = optim.Momentum(
        beta=hparams.momentum, nesterov=True).create(params)
  elif hparams.optimizer == 'adam':
    optimizer = optim.Adam(
        beta1=hparams.adam.beta1, beta2=hparams.adam.beta2).create(params)
  else:
    raise ValueError('Optimizer type is not supported.')

  state = imagenet_train_utils.TrainState(
      step=0,
      optimizer=optimizer,
      model_state=model_state,
      dynamic_scale=dynamic_scale)
  del params, model_state  # do not keep a copy of the initial model

  state = restore_checkpoint(state)
  step_offset = int(state.step)  # step_offset > 0 if restarting from checkpoint
  state = jax_utils.replicate(state)

  base_learning_rate = hparams.base_learning_rate * batch_size / 256.
  learning_rate_fn = create_learning_rate_fn(base_learning_rate,
                                             steps_per_epoch,
                                             hparams.lr_scheduler, batch_size)

  p_train_step = jax.pmap(
      functools.partial(
          imagenet_train_utils.train_step,
          model,
          learning_rate_fn=learning_rate_fn,
          teacher=teacher),
      axis_name='batch',
      static_broadcasted_argnums=(2, 3, 4))
  p_eval_step = jax.pmap(
      functools.partial(imagenet_train_utils.eval_step, model),
      axis_name='batch',
      static_broadcasted_argnums=(2,))

  epoch_metrics = []
  state_dict_summary_all = []
  state_dict_keys = _get_state_dict_keys_from_flags()
  t_loop_start = time.time()
  last_log_step = 0
  for step, batch in zip(range(step_offset, num_steps), train_iter):
    if (hparams.early_stop_steps >= 0 and
        step > hparams.early_stop_steps * steps_per_epoch):
      break
    # Decide whether activation bounds should be updated at this step; the
    # resulting bool is passed to p_train_step as a static argument.
    update_bounds = train_utils.should_update_bounds(
        hparams.activation_bound_update_freq,
        hparams.activation_bound_start_step, step)
    # Decide whether weights should be quantized, based on
    # hparams.weight_quant_start_step and the current epoch.
    quantize_weights = train_utils.should_quantize_weights(
        hparams.weight_quant_start_step, step // steps_per_epoch)
    state, metrics = p_train_step(state, batch, hparams, update_bounds,
                                  quantize_weights)

    state_dict_summary = summary_utils.get_state_dict_summary(
        state.model_state, state_dict_keys)
    state_dict_summary_all.append(state_dict_summary)

    epoch_metrics.append(metrics)

    def should_log(step):
      epoch_no = step // steps_per_epoch
      step_in_epoch = step - epoch_no * steps_per_epoch
      do_log = False
      do_log = do_log or (step + 1 == num_steps)  # log at the end
      end_of_train = step / num_steps > 0.9
      do_log = do_log or ((step_in_epoch % (steps_per_epoch // 4) == 0) and
                          not end_of_train)
      do_log = do_log or ((step_in_epoch % (steps_per_epoch // 16) == 0) and
                          end_of_train)
      return do_log

    if should_log(step):
      epoch = step // steps_per_epoch
      epoch_metrics = common_utils.get_metrics(epoch_metrics)
      summary = jax.tree_map(lambda x: x.mean(), epoch_metrics)
      logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                   summary['loss'], summary['accuracy'] * 100)
      steps_per_sec = (step - last_log_step) / (time.time() - t_loop_start)
      last_log_step = step
      t_loop_start = time.time()

      # Write to TensorBoard
      state_dict_summary_all = common_utils.get_metrics(state_dict_summary_all)
      if jax.host_id() == 0:
        for key, vals in epoch_metrics.items():
          tag = 'train_%s' % key
          for i, val in enumerate(vals):
            summary_writer.scalar(tag, val, step - len(vals) + i + 1)
        summary_writer.scalar('steps per second', steps_per_sec, step)
        if FLAGS.write_summary:
          summary_utils.write_state_dict_summaries_to_tb(
              state_dict_summary_all, summary_writer,
              FLAGS.state_dict_summary_freq, step)

      state_dict_summary_all = []
      epoch_metrics = []
      eval_metrics = []

      # sync batch statistics across replicas
      state = imagenet_train_utils.sync_batch_stats(state)
      for _ in range(steps_per_eval):
        eval_batch = next(eval_iter)
        metrics = p_eval_step(state, eval_batch, quantize_weights)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
      logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                   summary['loss'], summary['accuracy'] * 100)
      if jax.host_id() == 0:
        for key, val in eval_metrics.items():
          tag = 'eval_%s' % key
          summary_writer.scalar(tag, val.mean(), step)
        summary_writer.flush()

    if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
      state = imagenet_train_utils.sync_batch_stats(state)
      save_checkpoint(state)

  # Wait until computations are done before exiting
  jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
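
# The 'teacher' dict above packages a soft-target function together with its
# variables so that the train step can treat a real teacher network and plain
# one-hot labels uniformly. Below is a minimal sketch of how a distillation
# loss might consume it -- a hypothetical helper for illustration only; the
# actual loss computation lives inside imagenet_train_utils.train_step:
def _distillation_loss_sketch(teacher, student_logits, images, labels):
  # Soft targets: softmax probabilities from the teacher network, or a
  # one-hot encoding of the labels when hparams.teacher_model == 'labels'.
  soft_targets = teacher['model'](teacher['variables'], images, labels)
  # Cross-entropy between the teacher's targets and the student's
  # predictions, averaged over the batch.
  log_probs = jax.nn.log_softmax(student_logits)
  return -jnp.mean(jnp.sum(soft_targets * log_probs, axis=-1))
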
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  tf.enable_v2_behavior()
  # make sure tf does not allocate gpu memory
  tf.config.experimental.set_visible_devices([], 'GPU')

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(FLAGS.model_dir)

  rng = random.PRNGKey(0)

  image_size = 224

  batch_size = FLAGS.batch_size
  if batch_size % jax.device_count() > 0:
    raise ValueError('Batch size must be divisible by the number of devices')
  local_batch_size = batch_size // jax.host_count()
  device_batch_size = batch_size // jax.device_count()

  platform = jax.local_devices()[0].platform

  dynamic_scale = None
  if FLAGS.half_precision:
    if platform == 'tpu':
      model_dtype = jnp.bfloat16
      input_dtype = tf.bfloat16
    else:
      model_dtype = jnp.float16
      input_dtype = tf.float16
      dynamic_scale = optim.DynamicScale()
  else:
    model_dtype = jnp.float32
    input_dtype = tf.float32

  train_iter = imagenet_train_utils.create_input_iter(
      local_batch_size,
      FLAGS.data_dir,
      image_size,
      input_dtype,
      train=True,
      cache=FLAGS.cache)
  eval_iter = imagenet_train_utils.create_input_iter(
      local_batch_size,
      FLAGS.data_dir,
      image_size,
      input_dtype,
      train=False,
      cache=FLAGS.cache)

  # Create the hyperparameter object
  if FLAGS.hparams_config_dict:
    # In this case, there are multiple training configs defined in the config
    # dict, so we pull out the one this training run should use.
    if 'configs' in FLAGS.hparams_config_dict:
      hparams_config_dict = FLAGS.hparams_config_dict.configs[FLAGS.config_idx]
    else:
      hparams_config_dict = FLAGS.hparams_config_dict
    hparams = os_hparams_utils.load_hparams_from_config_dict(
        hparams_config.TrainingHParams, models.ResNet.HParams,
        hparams_config_dict)
  else:
    raise ValueError('Please provide a base config dict.')

  os_hparams_utils.write_hparams_to_file_with_host_id_check(
      hparams, FLAGS.model_dir)

  # get num_epochs from hparams instead of FLAGS
  num_epochs = hparams.lr_scheduler.num_epochs

  steps_per_epoch = input_pipeline.TRAIN_IMAGES // batch_size
  steps_per_eval = input_pipeline.EVAL_IMAGES // batch_size
  steps_per_checkpoint = steps_per_epoch * 10
  num_steps = steps_per_epoch * num_epochs

  # Estimate compute / memory costs
  if jax.host_id() == 0:
    estimate_compute_and_memory_cost(
        image_size=image_size, model_dir=FLAGS.model_dir, hparams=hparams)
    logging.info('Writing training HLO and estimating compute/memory costs.')

  model, variables = imagenet_train_utils.create_model(
      rng,
      device_batch_size,
      image_size,
      model_dtype,
      hparams=hparams.model_hparams,
      train=True)
  model_state, params = variables.pop('params')
  if hparams.optimizer == 'sgd':
    optimizer = optim.Momentum(
        beta=hparams.momentum, nesterov=True).create(params)
  elif hparams.optimizer == 'adam':
    optimizer = optim.Adam(
        beta1=hparams.adam.beta1, beta2=hparams.adam.beta2).create(params)
  else:
    raise ValueError('Optimizer type is not supported.')

  state = imagenet_train_utils.TrainState(
      step=0,
      optimizer=optimizer,
      model_state=model_state,
      dynamic_scale=dynamic_scale)
  del params, model_state  # do not keep a copy of the initial model

  state = restore_checkpoint(state)
  step_offset = int(state.step)  # step_offset > 0 if restarting from checkpoint
  state = jax_utils.replicate(state)

  base_learning_rate = hparams.base_learning_rate * batch_size / 256.
  learning_rate_fn = create_learning_rate_fn(base_learning_rate,
                                             steps_per_epoch,
                                             hparams.lr_scheduler)

  p_train_step = jax.pmap(
      functools.partial(
          imagenet_train_utils.train_step,
          model,
          learning_rate_fn=learning_rate_fn),
      axis_name='batch',
      static_broadcasted_argnums=(2, 3))
  p_eval_step = jax.pmap(
      functools.partial(imagenet_train_utils.eval_step, model),
      axis_name='batch')

  epoch_metrics = []
  state_dict_summary_all = []
  state_dict_keys = _get_state_dict_keys_from_flags()
  t_loop_start = time.time()
  for step, batch in zip(range(step_offset, num_steps), train_iter):
    if hparams.early_stop_steps >= 0 and step > hparams.early_stop_steps:
      break
    update_bounds = train_utils.should_update_bounds(
        hparams.activation_bound_update_freq,
        hparams.activation_bound_start_step, step)
    state, metrics = p_train_step(state, batch, hparams, update_bounds)

    state_dict_summary = summary_utils.get_state_dict_summary(
        state.model_state, state_dict_keys)
    state_dict_summary_all.append(state_dict_summary)

    epoch_metrics.append(metrics)
    if (step + 1) % steps_per_epoch == 0:
      epoch = step // steps_per_epoch
      epoch_metrics = common_utils.get_metrics(epoch_metrics)
      summary = jax.tree_map(lambda x: x.mean(), epoch_metrics)
      logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                   summary['loss'], summary['accuracy'] * 100)
      steps_per_sec = steps_per_epoch / (time.time() - t_loop_start)
      t_loop_start = time.time()

      # Write to TensorBoard
      state_dict_summary_all = common_utils.get_metrics(state_dict_summary_all)
      if jax.host_id() == 0:
        for key, vals in epoch_metrics.items():
          tag = 'train_%s' % key
          for i, val in enumerate(vals):
            summary_writer.scalar(tag, val, step - len(vals) + i + 1)
        summary_writer.scalar('steps per second', steps_per_sec, step)

        summary_utils.write_state_dict_summaries_to_tb(
            state_dict_summary_all, summary_writer,
            FLAGS.state_dict_summary_freq, step)

      state_dict_summary_all = []
      epoch_metrics = []
      eval_metrics = []

      # sync batch statistics across replicas
      state = imagenet_train_utils.sync_batch_stats(state)
      for _ in range(steps_per_eval):
        eval_batch = next(eval_iter)
        metrics = p_eval_step(state, eval_batch)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
      logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                   summary['loss'], summary['accuracy'] * 100)
      if jax.host_id() == 0:
        for key, val in eval_metrics.items():
          tag = 'eval_%s' % key
          summary_writer.scalar(tag, val.mean(), step)
        summary_writer.flush()

    if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
      state = imagenet_train_utils.sync_batch_stats(state)
      save_checkpoint(state)

  # Wait until computations are done before exiting
  jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
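
# create_learning_rate_fn maps a training step to a learning rate and is
# closed over by the pmapped train step. Below is a minimal sketch of such a
# schedule, assuming linear warmup followed by cosine decay -- illustrative
# only; the actual schedule is configured via hparams.lr_scheduler:
def _create_learning_rate_fn_sketch(base_learning_rate, steps_per_epoch,
                                    warmup_epochs, num_epochs):
  warmup_steps = warmup_epochs * steps_per_epoch
  total_steps = num_epochs * steps_per_epoch

  def learning_rate_fn(step):
    # Ramp linearly to the base rate, then decay along a half cosine.
    warmup = base_learning_rate * jnp.minimum(1., step / warmup_steps)
    decay_frac = jnp.maximum(0., step - warmup_steps) / (
        total_steps - warmup_steps)
    cosine = 0.5 * (1. + jnp.cos(jnp.pi * decay_frac))
    return jnp.where(step < warmup_steps, warmup, base_learning_rate * cosine)

  return learning_rate_fn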