def get_summary_writers(workdir):
  current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
  log_dir = workdir + '/log/' + current_time
  train_log_dir = log_dir + '/train'
  eval_log_dir = log_dir + '/eval'
  train_summary_writer = tensorboard.SummaryWriter(train_log_dir)
  eval_summary_writer = tensorboard.SummaryWriter(eval_log_dir)
  return train_summary_writer, eval_summary_writer
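# A minimal usage sketch for the helper above (not part of the original
# sources): it assumes a hypothetical caller that logs per-epoch metric dicts
# to the two writers. Only SummaryWriter.scalar() and flush() are used.
def log_epoch_metrics(workdir, epoch, train_metrics, eval_metrics):
  train_writer, eval_writer = get_summary_writers(workdir)
  for key, val in train_metrics.items():
    train_writer.scalar(key, val, epoch)
  for key, val in eval_metrics.items():
    eval_writer.scalar(key, val, epoch)
  train_writer.flush()
  eval_writer.flush()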
def main(_):
  master = jax.host_id() == 0

  # make sure TF does not allocate gpu memory
  tf.config.experimental.set_visible_devices([], 'GPU')

  # The pool is used to perform misc operations such as logging in async way.
  pool = multiprocessing.pool.ThreadPool()

  # load configs from a config json string
  hparams = FLAGS.config
  logging.info('=========== Hyperparameters ============')
  logging.info(hparams)

  if hparams.get('debug'):
    logging.warning('DEBUG MODE IS ENABLED!')

  # set tensorflow random seed
  tf.random.set_seed(jax.host_id() + hparams.rng_seed)
  experiment_dir = FLAGS.experiment_dir
  logging.info('Experiment directory: %s', experiment_dir)

  summary_writer = None
  if master and hparams.write_summary:
    tensorboard_dir = os.path.join(experiment_dir, 'tb_summaries')
    gfile.makedirs(tensorboard_dir)
    summary_writer = tensorboard.SummaryWriter(tensorboard_dir)

  run(hparams, experiment_dir, summary_writer)
  pool.close()
  pool.join()
def train(train_ds, test_ds):
  """Train MNIST to completion."""
  rng = random.PRNGKey(0)

  batch_size = FLAGS.batch_size
  num_epochs = FLAGS.num_epochs
  model_dir = FLAGS.model_dir

  summary_writer = tensorboard.SummaryWriter(model_dir)

  model = create_model(rng)
  optimizer = create_optimizer(model, FLAGS.learning_rate, FLAGS.momentum)

  input_rng = onp.random.RandomState(0)

  for epoch in range(1, num_epochs + 1):
    optimizer, train_metrics = train_epoch(
        optimizer, train_ds, batch_size, epoch, input_rng)
    loss, accuracy = eval_model(optimizer.target, test_ds)
    logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f',
                 epoch, loss, accuracy * 100)
    summary_writer.scalar('train_loss', train_metrics['loss'], epoch)
    summary_writer.scalar('train_accuracy', train_metrics['accuracy'], epoch)
    summary_writer.scalar('eval_loss', loss, epoch)
    summary_writer.scalar('eval_accuracy', accuracy, epoch)

  return optimizer
def eval_once(run_configuration, checkpoint_path, optimizer=None): """Evaluates a single checkpoint on a single epoch of data.""" config = run_configuration.config run_dir = run_configuration.run_dir adapter = run_configuration.adapter optimizer = optimizer or adapter.create_optimizer(run_configuration) dataset = run_configuration.dataset_info.dataset info = run_configuration.dataset_info.info eval_name = config.eval_name or 'eval' log_dir = os.path.join(run_dir, eval_name) if jax.host_id() == 0: summary_writer = tensorboard.SummaryWriter(log_dir) # Restore checkpoint optimizer = checkpoint_utils.restore_checkpoint(checkpoint_path, optimizer) step = int(optimizer.state.step) # Replicate optimizer. optimizer = flax.jax_utils.replicate(optimizer) eval_step = adapter.make_eval_step() eval_step_parallel = jax.pmap(eval_step, axis_name='batch') # Perform evaluation tick = time.time() metrics_all = [] example = None dataset_iter_raw = iter(dataset) dataset_iter = adapter.preprocess(dataset_iter_raw) for unused_eval_step, example in zip(range(config.eval_steps), dataset_iter): train_inputs = adapter.get_train_inputs(example) metrics, logits, state = eval_step_parallel(optimizer.target, train_inputs) metrics_all.append(metrics) # Write results. metrics_all = common_utils.get_metrics(metrics_all) metrics_sums = jax.tree_map(jnp.sum, metrics_all) denominator = metrics_sums.pop('denominator') summary = jax.tree_map(lambda x: x / denominator, metrics_sums) # pylint: disable=cell-var-from-loop summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4) logging.info('eval @ train step: %d, loss: %.4f', step, summary['loss']) if jax.host_id() == 0: tock = time.time() steps_per_sec = len(metrics_all) / (tock - tick) examples_per_sec = denominator / (tock - tick) summary_writer.scalar('per-second/steps', steps_per_sec, step) summary_writer.scalar('per-second/examples', examples_per_sec, step) for key, val in summary.items(): summary_writer.scalar(key, val, step) adapter.write_summaries(example, logits, summary_writer, info, step, state) summary_writer.flush()
def __init__(self, base_dir, create_agent_fn, create_environment_fn=gym_lib.create_gym_environment, checkpoint_file_prefix='ckpt', logging_file_prefix='log', log_every_n=1, num_iterations=200, training_steps=250000, evaluation_steps=125000, max_steps_per_episode=1000, clip_rewards=False): """Initialize the Runner object in charge of running a full experiment. Args: base_dir: str, the base directory to host all required sub-directories. create_agent_fn: A function that takes as argument an environment, and returns an agent. create_environment_fn: A function which receives a problem name and creates a Gym environment for that problem (e.g. an Atari 2600 game). checkpoint_file_prefix: str, the prefix to use for checkpoint files. logging_file_prefix: str, prefix to use for the log files. log_every_n: int, the frequency for writing logs. num_iterations: int, the iteration number threshold (must be greater than start_iteration). training_steps: int, the number of training steps to perform. evaluation_steps: int, the number of evaluation steps to perform. max_steps_per_episode: int, maximum number of steps after which an episode terminates. clip_rewards: bool, whether to clip rewards in [-1, 1]. This constructor will take the following actions: - Initialize an environment. - Initialize a logger. - Initialize an agent. - Reload from the latest checkpoint, if available, and initialize the Checkpointer object. """ assert base_dir is not None self._logging_file_prefix = logging_file_prefix self._log_every_n = log_every_n self._num_iterations = num_iterations self._training_steps = training_steps self._evaluation_steps = evaluation_steps self._max_steps_per_episode = max_steps_per_episode self._base_dir = base_dir self._clip_rewards = clip_rewards self._create_directories() self._summary_writer = tensorboard.SummaryWriter(base_dir) self._environment = create_environment_fn() self._agent = create_agent_fn(self._environment, summary_writer=self._summary_writer) self._initialize_checkpointer_and_maybe_resume(checkpoint_file_prefix)
def train(module: models.ActorCritic,
          optimizer: flax.optim.base.Optimizer,
          config: ml_collections.ConfigDict,
          model_dir: str):
  """Main training loop.

  Args:
    module: the actor-critic model
    optimizer: optimizer for the actor-critic model
    config: object holding hyperparameters and the training information
    model_dir: path to directory where checkpoints and logging info are stored

  Returns:
    optimizer: the trained optimizer
  """
  game = config.game + 'NoFrameskip-v4'
  simulators = [
      agent.RemoteSimulator(game) for _ in range(config.num_agents)
  ]
  summary_writer = tensorboard.SummaryWriter(model_dir)
  summary_writer.hparams(dict(config))
  loop_steps = config.total_frames // (config.num_agents * config.actor_steps)
  log_frequency = 40
  checkpoint_frequency = 500

  for s in range(loop_steps):
    # Bookkeeping and testing.
    if s % log_frequency == 0:
      score = test_episodes.policy_test(1, module, optimizer.target, game)
      frames = s * config.num_agents * config.actor_steps
      summary_writer.scalar('game_score', score, frames)
      print(f'Step {s}:\nframes seen {frames}\nscore {score}\n\n')
    if s % checkpoint_frequency == 0:
      checkpoints.save_checkpoint(model_dir, optimizer, s)

    # Core training code.
    alpha = 1. - s / loop_steps if config.decaying_lr_and_clip_param else 1.
    all_experiences = get_experience(optimizer.target, module, simulators,
                                     config.actor_steps)
    trajectories = process_experience(all_experiences, config.actor_steps,
                                      config.num_agents, config.gamma,
                                      config.lambda_)
    lr = config.learning_rate * alpha
    clip_param = config.clip_param * alpha
    for e in range(config.num_epochs):
      permutation = onp.random.permutation(
          config.num_agents * config.actor_steps)
      trajectories = tuple(map(lambda x: x[permutation], trajectories))
      optimizer, loss = train_step(module, optimizer, trajectories, clip_param,
                                   config.vf_coeff, config.entropy_coeff, lr,
                                   config.batch_size)

  return optimizer
def train_agent(iterations, modeldir, logdir):
  """Train and convert the model."""
  summary_writer = tensorboard.SummaryWriter(logdir)
  rng = random.PRNGKey(0)
  rng, init_rng = random.split(rng)
  policygradient = PolicyGradient()
  params = policygradient.init(
      init_rng, jnp.ones([1, common.BOARD_SIZE, common.BOARD_SIZE]))['params']
  optimizer = create_optimizer(model_params=params,
                               learning_rate=LEARNING_RATE)

  # Main training loop
  progress_bar = tf.keras.utils.Progbar(iterations)
  for i in range(iterations):
    predict_fn = functools.partial(run_inference, optimizer.target)
    board_log, action_log, result_log = common.play_game(predict_fn)
    rewards = common.compute_rewards(result_log)
    summary_writer.scalar('game_length', len(board_log), i)
    optimizer = train_step(optimizer, board_log, action_log, rewards)
    summary_writer.flush()
    progress_bar.add(1)
  summary_writer.close()

  # Convert to tflite model
  model = PolicyGradient()
  jax_predict_fn = lambda input: model.apply({'params': optimizer.target}, input)
  tf_predict = tf.function(
      jax2tf.convert(jax_predict_fn, enable_xla=False),
      input_signature=[
          tf.TensorSpec(shape=[1, common.BOARD_SIZE, common.BOARD_SIZE],
                        dtype=tf.float32,
                        name='input')
      ],
      autograph=False)
  converter = tf.lite.TFLiteConverter.from_concrete_functions(
      [tf_predict.get_concrete_function()], tf_predict)
  tflite_model = converter.convert()

  # Save the model
  with open(os.path.join(modeldir, 'planestrike.tflite'), 'wb') as f:
    f.write(tflite_model)
  print('TFLite model generated!')
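# A minimal sketch (not part of the original training script) of running the
# saved planestrike.tflite model with the standard TF Lite interpreter API.
# The model path and the all-zeros input board are assumptions for
# illustration; a real caller would pass the current game state.
import numpy as np
import tensorflow as tf

def predict_with_tflite(model_path='/tmp/planestrike.tflite'):
  interpreter = tf.lite.Interpreter(model_path=model_path)
  interpreter.allocate_tensors()
  input_details = interpreter.get_input_details()
  output_details = interpreter.get_output_details()
  # Feed an empty board of the shape the converter recorded above.
  board = np.zeros(input_details[0]['shape'], dtype=np.float32)
  interpreter.set_tensor(input_details[0]['index'], board)
  interpreter.invoke()
  return interpreter.get_tensor(output_details[0]['index'])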
def train_and_evaluate(config, workdir):
  """Execute model training and evaluation loop.

  Args:
    config: Hyperparameter configuration for training and evaluation.
    workdir: Directory where the tensorboard summaries are written to.
  """
  ## get random seed
  rng = jax.random.PRNGKey(0)

  ## Get data
  train_ds, test_ds = get_datasets("cifar10")

  ## Initializing model and inferring dimensions of layers from one example batch
  model = models.ResNet18(num_classes=10)
  init_params = model.init(
      rng, jnp.ones((1, 32, 32, 3)))  # figure this shape out automatically ?
  params = init_params

  solver, solver_param_name = get_solver(
      FLAGS, config, loss_fun, losses=loss_fun)  # losses is not defined yet!
  params, state = solver.init(params)

  ## Path to dump results
  dumpath = create_dumpfile(config, solver_param_name, workdir, "cifar10")
  summary_writer = tensorboard.SummaryWriter(dumpath)
  summary_writer.hparams(dict(config))

  for epoch in range(1, config.num_epochs + 1):
    rng, _ = jax.random.split(rng)
    params, state = train_epoch(config, solver, params, state, train_ds, rng)
    test_loss, test_accuracy = eval_model(params, test_ds)
    train_loss, train_accuracy = eval_model(params, train_ds)
    print("eval epoch: %d, loss: %.4f, accuracy: %.2f" %
          (epoch, test_loss, test_accuracy * 100))
    print("train epoch: %d, train_loss: %.4f, train_accuracy: %.2f" %
          (epoch, train_loss, train_accuracy * 100))
    logging.info("eval epoch: %d, loss: %.4f, accuracy: %.2f",
                 epoch, test_loss, test_accuracy * 100)
    summary_writer.scalar("train_loss", train_loss, epoch)
    summary_writer.scalar("test_loss", test_loss, epoch)
    summary_writer.scalar("train_accuracy", train_accuracy, epoch)
    summary_writer.scalar("test_accuracy", test_accuracy, epoch)

  summary_writer.flush()
def train_and_evaluate(config, workdir):
  """Execute model training and evaluation loop.

  Args:
    config: Hyperparameter configuration for training and evaluation.
    workdir: Directory where the tensorboard summaries are written to.
  """
  train_ds, test_ds = get_datasets("mnist")

  # Get solver
  solver, solver_param_name = get_solver(FLAGS, config, loss_fun, losses)

  rng = jax.random.PRNGKey(0)
  rng, init_rng = jax.random.split(rng)
  init_params = CNN().init(init_rng, jnp.ones([1, 28, 28, 1]))["params"]
  params, state = solver.init(init_params)

  # Full path to dump results
  dumpath = create_dumpfile(config, solver_param_name, workdir, "mnist")
  summary_writer = tensorboard.SummaryWriter(dumpath)
  summary_writer.hparams(dict(config))

  # Run solver.
  for epoch in range(1, config.num_epochs + 1):
    rng, input_rng = jax.random.split(rng)
    params, state, train_metrics = train_epoch(config, solver, params, state,
                                               train_ds, epoch, input_rng)
    test_loss, test_accuracy = eval_model(params, test_ds)
    print("eval epoch: %d, loss: %.4f, accuracy: %.2f" %
          (epoch, test_loss, test_accuracy * 100))
    logging.info("eval epoch: %d, loss: %.4f, accuracy: %.2f",
                 epoch, test_loss, test_accuracy * 100)
    summary_writer.scalar("train_loss", train_metrics["loss"], epoch)
    summary_writer.scalar("train_accuracy", train_metrics["accuracy"], epoch)
    summary_writer.scalar("eval_loss", test_loss, epoch)
    summary_writer.scalar("eval_accuracy", test_accuracy, epoch)

  summary_writer.flush()
def train_and_evaluate(config: ml_collections.ConfigDict,
                       workdir: str) -> train_state.TrainState:
  """Execute model training and evaluation loop.

  Args:
    config: Hyperparameter configuration for training and evaluation.
    workdir: Directory where the tensorboard summaries are written to.

  Returns:
    The train state (which includes the `.params`).
  """
  train_ds, test_ds = get_datasets()
  rng = jax.random.PRNGKey(0)

  summary_writer = tensorboard.SummaryWriter(workdir)
  summary_writer.hparams(dict(config))

  rng, init_rng = jax.random.split(rng)
  cnn = CNN()
  params = cnn.init(init_rng, jnp.ones([1, 28, 28, 1]))['params']
  tx = optax.sgd(config.learning_rate, config.momentum)
  state = train_state.TrainState.create(apply_fn=cnn.apply, params=params,
                                        tx=tx)

  for epoch in range(1, config.num_epochs + 1):
    rng, input_rng = jax.random.split(rng)
    state, train_metrics = train_epoch(state, train_ds, config.batch_size,
                                       epoch, input_rng)
    loss, accuracy = eval_model(state.params, test_ds)
    logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f',
                 epoch, loss, accuracy * 100)
    summary_writer.scalar('train_loss', train_metrics['loss'], epoch)
    summary_writer.scalar('train_accuracy', train_metrics['accuracy'], epoch)
    summary_writer.scalar('eval_loss', loss, epoch)
    summary_writer.scalar('eval_accuracy', accuracy, epoch)

  summary_writer.flush()
  return state
def setUp(self):
  super().setUp()
  self._batch_size = 128  # Note: Tests are run on GPU/TPU.
  self._batch_size_test = 128
  self._shuffle_buffer_size = 1024
  self._rng = jax.random.PRNGKey(42)
  self._input_shape = ((self._batch_size, 28, 28, 1), jnp.float32)
  self._num_classes = 10
  self._num_epochs = 1
  self._learning_rate_fn = lambda _: 0.01
  self._weight_decay = 0.0001
  self._momentum = 0.9
  self._min_loss = jnp.finfo(float).eps
  self._max_loss = 2.0 * math.log(self._num_classes)
  self._dataset_name = 'MNIST'
  self._model_name = 'MNIST_CNN'
  self._summarywriter = tensorboard.SummaryWriter('/tmp/')

  self._dataset = dataset_factory.create_dataset(
      self._dataset_name,
      self._batch_size,
      self._batch_size_test,
      shuffle_buffer_size=self._shuffle_buffer_size)

  self._model, self._state = model_factory.create_model(
      self._model_name,
      self._rng, (self._input_shape,),
      num_classes=self._num_classes)

  self._optimizer = flax.optim.Momentum(
      learning_rate=self._learning_rate_fn(0),
      beta=self._momentum,
      weight_decay=self._weight_decay)
def train_and_evaluate(config: ml_collections.ConfigDict,
                       workdir: str) -> train_state.TrainState:
  """Execute model training and evaluation loop.

  Args:
    config: Hyperparameter configuration for training and evaluation.
    workdir: Directory where the tensorboard summaries are written to.

  Returns:
    The train state (which includes the `.params`).
  """
  train_ds, test_ds = get_datasets()
  rng = jax.random.PRNGKey(0)

  summary_writer = tensorboard.SummaryWriter(workdir)
  summary_writer.hparams(dict(config))

  rng, init_rng = jax.random.split(rng)
  state = create_train_state(init_rng, config)

  for epoch in range(1, config.num_epochs + 1):
    rng, input_rng = jax.random.split(rng)
    state, train_loss, train_accuracy = train_epoch(state, train_ds,
                                                    config.batch_size,
                                                    input_rng)
    _, test_loss, test_accuracy = apply_model(state, test_ds['image'],
                                              test_ds['label'])

    print('epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, '
          'test_loss: %.4f, test_accuracy: %.2f' %
          (epoch, train_loss, train_accuracy * 100, test_loss,
           test_accuracy * 100))

    summary_writer.scalar('train_loss', train_loss, epoch)
    summary_writer.scalar('train_accuracy', train_accuracy, epoch)
    summary_writer.scalar('test_loss', test_loss, epoch)
    summary_writer.scalar('test_accuracy', test_accuracy, epoch)

  summary_writer.flush()
  return state
def train_and_evaluate(model_dir: str, num_epochs: int, batch_size: int,
                       learning_rate: float, momentum: float):
  """Execute model training and evaluation loop.

  Args:
    model_dir: Directory where the tensorboard summaries are written to.
    num_epochs: Number of epochs to cycle through the dataset before stopping.
    batch_size: Batch size of the input.
    learning_rate: Learning rate for the momentum optimizer.
    momentum: Momentum value for the momentum optimizer.

  Returns:
    The trained optimizer.
  """
  train_ds, test_ds = get_datasets()
  rng = random.PRNGKey(0)

  summary_writer = tensorboard.SummaryWriter(model_dir)

  rng, init_rng = random.split(rng)
  model = create_model(init_rng)
  optimizer = create_optimizer(model, learning_rate, momentum)

  for epoch in range(1, num_epochs + 1):
    rng, input_rng = random.split(rng)
    optimizer, train_metrics = train_epoch(
        optimizer, train_ds, batch_size, epoch, input_rng)
    loss, accuracy = eval_model(optimizer.target, test_ds)
    logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f',
                 epoch, loss, accuracy * 100)
    summary_writer.scalar('train_loss', train_metrics['loss'], epoch)
    summary_writer.scalar('train_accuracy', train_metrics['accuracy'], epoch)
    summary_writer.scalar('eval_loss', loss, epoch)
    summary_writer.scalar('eval_accuracy', accuracy, epoch)

  summary_writer.flush()
  return optimizer
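# A minimal invocation sketch for the function above; the '/tmp/mnist'
# directory and the hyperparameter values are placeholders, not values taken
# from the original code.
if __name__ == '__main__':
  trained_optimizer = train_and_evaluate(
      model_dir='/tmp/mnist',
      num_epochs=10,
      batch_size=128,
      learning_rate=0.1,
      momentum=0.9)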
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
  """Execute model training and evaluation loop.

  Args:
    config: Hyperparameter configuration for training and evaluation.
    workdir: Directory where the tensorboard summaries are written to.

  Returns:
    The trained optimizer.
  """
  train_ds, test_ds = get_datasets()
  rng = jax.random.PRNGKey(0)

  summary_writer = tensorboard.SummaryWriter(workdir)
  summary_writer.hparams(dict(config))

  rng, init_rng = jax.random.split(rng)
  params = get_initial_params(init_rng)
  optimizer = create_optimizer(params, config.learning_rate, config.momentum)

  for epoch in range(1, config.num_epochs + 1):
    rng, input_rng = jax.random.split(rng)
    optimizer, train_metrics = train_epoch(optimizer, train_ds,
                                           config.batch_size, epoch, input_rng)
    loss, accuracy = eval_model(optimizer.target, test_ds)
    logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f',
                 epoch, loss, accuracy * 100)
    summary_writer.scalar('train_loss', train_metrics['loss'], epoch)
    summary_writer.scalar('train_accuracy', train_metrics['accuracy'], epoch)
    summary_writer.scalar('eval_loss', loss, epoch)
    summary_writer.scalar('eval_accuracy', accuracy, epoch)

  summary_writer.flush()
  return optimizer
def run_train(run_configuration): """Runs the training workflow.""" config = run_configuration.config run_dir = run_configuration.run_dir adapter = run_configuration.adapter log_dir = os.path.join(run_dir, 'train') checkpoint_path = run_configuration.original_checkpoint_path dataset = run_configuration.dataset_info.dataset info = run_configuration.dataset_info.info random_seed = 0 rng = jax.random.PRNGKey(random_seed) rng = jax.random.fold_in(rng, jax.host_id()) rng, init_rng = jax.random.split(rng) dropout_rngs = jax.random.split(rng, jax.local_device_count()) # Set up optimizer. optimizer = adapter.create_optimizer(run_configuration, rng=init_rng) # Set up train step. train_step = adapter.make_train_step() if jax.host_id() == 0: summary_writer = tensorboard.SummaryWriter(log_dir) # Set up checkpointing. # TODO(dbieber): Set up phoenix. checkpoint_dir = checkpoint_utils.build_checkpoint_dir(run_dir) if checkpoint_path is None: checkpoint_path = checkpoint_utils.latest_checkpoint(checkpoint_dir) optimizer = checkpoint_utils.handle_restart_behavior( checkpoint_path, optimizer, config) start_step = int(optimizer.state.step) num_train_steps = config.train.total_steps # Replicate optimizer. optimizer = flax.jax_utils.replicate(optimizer) # Begin training loop. dataset_iter_raw = iter(dataset) dataset_iter = adapter.preprocess(dataset_iter_raw) summary_freq = config.logging.summary_freq metrics_all = [] tick = time.time() for step, example in zip(range(start_step, num_train_steps), dataset_iter): train_inputs = adapter.get_train_inputs(example) optimizer, metrics, dropout_rngs, logits, state = train_step( optimizer, train_inputs, dropout_rngs) metrics_all.append(metrics) # Save a Checkpoint if ((step % config.logging.save_freq == 0 and step > 0) or step == num_train_steps - 1): if jax.host_id() == 0 and config.logging.save_freq: # Save unreplicated optimizer + model state. checkpoint_utils.save_checkpoint( checkpoint_dir, jax_utils.unreplicate(optimizer), step) # Periodic metric handling. if summary_freq and step % summary_freq == 0 and step > 0: metrics_all = common_utils.get_metrics(metrics_all) lr = metrics_all.pop('learning_rate').mean() metrics_sums = jax.tree_map(jnp.sum, metrics_all) denominator = metrics_sums.pop('denominator') summary = jax.tree_map(lambda x: x / denominator, metrics_sums) # pylint: disable=cell-var-from-loop summary['learning_rate'] = lr # Calculate (clipped) perplexity after averaging log-perplexities: summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4) logging.info('train step: %d, loss: %.4f', step, summary['loss']) if jax.host_id() == 0: tock = time.time() steps_per_sec = summary_freq / (tock - tick) examples_per_sec = denominator / (tock - tick) tick = tock summary_writer.scalar('per-second/steps', steps_per_sec, step) summary_writer.scalar('per-second/examples', examples_per_sec, step) for key, val in summary.items(): summary_writer.scalar(key, val, step) adapter.write_summaries(example, logits, summary_writer, info, step, state) summary_writer.flush() # Reset metric accumulation for next evaluation cycle. metrics_all = []
def main(argv): if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') tf.enable_v2_behavior() # make sure tf does not allocate gpu memory tf.config.experimental.set_visible_devices([], 'GPU') if jax.host_id() == 0: summary_writer = tensorboard.SummaryWriter(FLAGS.model_dir) rng = random.PRNGKey(0) image_size = 224 batch_size = FLAGS.batch_size if batch_size % jax.device_count() > 0: raise ValueError( 'Batch size must be divisible by the number of devices') local_batch_size = batch_size // jax.host_count() device_batch_size = batch_size // jax.device_count() platform = jax.local_devices()[0].platform if FLAGS.half_precision: if platform == 'tpu': model_dtype = jnp.bfloat16 input_dtype = tf.bfloat16 else: model_dtype = jnp.float16 input_dtype = tf.float16 else: model_dtype = jnp.float32 input_dtype = tf.float32 train_iter = create_input_iter(local_batch_size, image_size, input_dtype, train=True, cache=FLAGS.cache) eval_iter = create_input_iter(local_batch_size, image_size, input_dtype, train=False, cache=FLAGS.cache) num_epochs = FLAGS.num_epochs steps_per_epoch = input_pipeline.TRAIN_IMAGES // batch_size steps_per_eval = input_pipeline.EVAL_IMAGES // batch_size steps_per_checkpoint = steps_per_epoch * 10 num_steps = steps_per_epoch * num_epochs base_learning_rate = FLAGS.learning_rate * batch_size / 256. base_learning_rate = base_learning_rate / FLAGS.loss_scaling model, model_state = create_model(rng, device_batch_size, image_size, model_dtype) optimizer = optim.Momentum(beta=FLAGS.momentum, nesterov=True).create(model) state = TrainState(step=0, optimizer=optimizer, model_state=model_state) del model, model_state # do not keep a copy of the initial model state = restore_checkpoint(state) step_offset = int( state.step) # step_offset > 0 if restarting from checkpoint state = jax_utils.replicate(state) learning_rate_fn = create_learning_rate_fn(base_learning_rate, steps_per_epoch, num_epochs) p_train_step = jax.pmap(functools.partial( train_step, learning_rate_fn=learning_rate_fn), axis_name='batch') p_eval_step = jax.pmap(eval_step, axis_name='batch') epoch_metrics = [] t_loop_start = time.time() for step, batch in zip(range(step_offset, num_steps), train_iter): state, metrics = p_train_step(state, batch) epoch_metrics.append(metrics) if (step + 1) % steps_per_epoch == 0: epoch = step // steps_per_epoch epoch_metrics = common_utils.get_metrics(epoch_metrics) summary = jax.tree_map(lambda x: x.mean(), epoch_metrics) logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch, summary['loss'], summary['accuracy'] * 100) steps_per_sec = steps_per_epoch / (time.time() - t_loop_start) t_loop_start = time.time() if jax.host_id() == 0: for key, vals in epoch_metrics.items(): tag = 'train_%s' % key for i, val in enumerate(vals): summary_writer.scalar(tag, val, step - len(vals) + i + 1) summary_writer.scalar('steps per second', steps_per_sec, step) epoch_metrics = [] eval_metrics = [] # sync batch statistics across replicas state = sync_batch_stats(state) for _ in range(steps_per_eval): eval_batch = next(eval_iter) metrics = p_eval_step(state, eval_batch) eval_metrics.append(metrics) eval_metrics = common_utils.get_metrics(eval_metrics) summary = jax.tree_map(lambda x: x.mean(), eval_metrics) logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch, summary['loss'], summary['accuracy'] * 100) if jax.host_id() == 0: for key, val in eval_metrics.items(): tag = 'eval_%s' % key summary_writer.scalar(tag, val.mean(), step) summary_writer.flush() if (step + 1) % 
steps_per_checkpoint == 0 or step + 1 == num_steps: state = sync_batch_stats(state) save_checkpoint(state) # Wait until computations are done before exiting jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
def main(argv): if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') if FLAGS.jax_backend_target: jax.config.FLAGS.jax_xla_backend = "tpu_driver" jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target # This seems to be necessary even when importing TF2? tf.enable_v2_behavior() # Number of local devices for this host. n_devices = jax.local_device_count() if jax.host_id() == 0: train_summary_writer = tensorboard.SummaryWriter( os.path.join(FLAGS.model_dir, 'train')) eval_summary_writer = tensorboard.SummaryWriter( os.path.join(FLAGS.model_dir, 'eval')) if FLAGS.batch_size % n_devices: raise ValueError( 'Batch size must be divisible by the number of devices') vocab_path = FLAGS.vocab_path if vocab_path is None: vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model') # Load Dataset logging.info('Initializing dataset.') train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets( n_devices=n_devices, dataset_name=FLAGS.dataset_name, eval_dataset_name=FLAGS.eval_dataset_name, shard_idx=jax.host_id(), shard_count=jax.host_count(), data_dir=FLAGS.data_dir, vocab_path=vocab_path, batch_size=FLAGS.batch_size, max_length=FLAGS.max_target_length, max_eval_length=FLAGS.max_eval_target_length) train_iter = iter(train_ds) vocab_size = int(encoder.vocab_size()) eos_token = 2 # Default Sentencepiece EOS token. def decode_tokens(toks): valid_toks = toks[:np.argmax(toks == eos_token) + 1].astype(np.int32) return encoder.detokenize(valid_toks).numpy().decode('utf-8') logging.info('Initializing model, optimizer, and step functions.') # Build Model and Optimizer transformer_kwargs = { 'vocab_size': vocab_size, 'output_vocab_size': vocab_size, 'emb_dim': 1024, 'num_heads': 16, 'num_layers': 6, 'qkv_dim': 1024, 'mlp_dim': 4096, 'max_len': max(FLAGS.max_target_length, FLAGS.max_eval_target_length), 'share_embeddings': FLAGS.share_embeddings, 'logits_via_embedding': FLAGS.logits_via_embedding, } start_step = 0 rng = random.PRNGKey(FLAGS.random_seed) rng, init_rng = random.split(rng) input_shape = (FLAGS.batch_size, FLAGS.max_target_length) target_shape = (FLAGS.batch_size, FLAGS.max_target_length) model, cache_def = create_model(init_rng, input_shape, target_shape, transformer_kwargs) optimizer = create_optimizer(model, FLAGS.learning_rate, FLAGS.weight_decay) # We access model only from optimizer below via optimizer.target. del model if FLAGS.restore_checkpoints: # Restore unreplicated optimizer + model state from last checkpoint. optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer) # Grab last step. start_step = int(optimizer.state.step) # Replicate optimizer. optimizer = jax_utils.replicate(optimizer) learning_rate_fn = create_learning_rate_scheduler( base_learning_rate=FLAGS.learning_rate, warmup_steps=FLAGS.warmup_steps) p_train_step = jax.pmap(functools.partial( train_step, learning_rate_fn=learning_rate_fn, label_smoothing=FLAGS.label_smoothing, use_bfloat16=FLAGS.use_bfloat16), axis_name='batch') p_eval_step = jax.pmap(functools.partial( eval_step, label_smoothing=FLAGS.label_smoothing, use_bfloat16=FLAGS.use_bfloat16), axis_name='batch') p_pred_step = jax.pmap( functools.partial(predict_step, use_bfloat16=FLAGS.use_bfloat16), axis_name='batch', static_broadcasted_argnums=(3, 4)) # eos token, max_length are constant # We init the first set of dropout PRNG keys, but update it afterwards inside # the main pmap'd training update for performance. 
dropout_rngs = random.split(rng, n_devices) logging.info('Starting training loop.') metrics_all = [] t_loop_start = time.time() for step, batch in zip(range(start_step, FLAGS.num_train_steps), train_iter): # Shard data to devices and do a training step. batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch)) # pylint: disable=protected-access optimizer, metrics, dropout_rngs = p_train_step( optimizer, batch, dropout_rng=dropout_rngs) metrics_all.append(metrics) # Save a checkpoint on one host after every checkpoint_freq steps. if (FLAGS.save_checkpoints and step % FLAGS.checkpoint_freq == 0 and step > 0 and jax.host_id() == 0): checkpoints.save_checkpoint(FLAGS.model_dir, jax_utils.unreplicate(optimizer), step) # Periodic metric handling. if step % FLAGS.eval_frequency != 0: continue logging.info('Gathering training metrics.') # Training Metrics metrics_all = common_utils.get_metrics(metrics_all) lr = metrics_all.pop('learning_rate').mean() metrics_sums = jax.tree_map(jnp.sum, metrics_all) denominator = metrics_sums.pop('denominator') summary = jax.tree_map(lambda x: x / denominator, metrics_sums) # pylint: disable=cell-var-from-loop summary['learning_rate'] = lr steps_per_eval = FLAGS.eval_frequency if step != 0 else 1 steps_per_sec = steps_per_eval / (time.time() - t_loop_start) t_loop_start = time.time() if jax.host_id() == 0: train_summary_writer.scalar('steps per second', steps_per_sec, step) for key, val in summary.items(): train_summary_writer.scalar(key, val, step) train_summary_writer.flush() metrics_all = [] logging.info('train in step: %d, loss: %.4f', step, summary['loss']) # Eval Metrics logging.info('Gathering evaluation metrics.') t_eval_start = time.time() eval_metrics = [] eval_iter = iter(eval_ds) for _, eval_batch in zip(range(FLAGS.num_eval_steps), eval_iter): eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch) # pylint: disable=protected-access eval_batch = common_utils.shard(eval_batch) metrics = p_eval_step(optimizer.target, eval_batch) eval_metrics.append(metrics) eval_metrics = common_utils.get_metrics(eval_metrics) eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics) eval_denominator = eval_metrics_sums.pop('denominator') eval_summary = jax.tree_map( lambda x: x / eval_denominator, # pylint: disable=cell-var-from-loop eval_metrics_sums) if jax.host_id() == 0: for key, val in eval_summary.items(): eval_summary_writer.scalar(key, val, step) eval_summary_writer.flush() logging.info('eval in step: %d, loss: %.4f', step, eval_summary['loss']) logging.info('eval time: %.4f s step %d', time.time() - t_eval_start, step) # Translation and BLEU Score. logging.info('Translating evaluation dataset.') t_inference_start = time.time() predict_iter = iter(predict_ds) sources, references, predictions = [], [], [] for _, pred_batch in enumerate(predict_iter): pred_batch = jax.tree_map(lambda x: x._numpy(), pred_batch) # pylint: disable=protected-access # Handle final odd-sized batch by padding instead of dropping it. 
cur_pred_batch_size = pred_batch['inputs'].shape[0] if cur_pred_batch_size % n_devices: padded_size = int( np.ceil(cur_pred_batch_size / n_devices) * n_devices) pred_batch = jax.tree_map( lambda x: pad_examples(x, padded_size), pred_batch) # pylint: disable=cell-var-from-loop pred_batch = common_utils.shard(pred_batch) per_device_batchsize = pred_batch['inputs'].shape[1] cache_dtype = jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32 cache = jax_utils.replicate( cache_def.initialize_cache( (per_device_batchsize, FLAGS.max_predict_length), dtype=cache_dtype)) predicted = p_pred_step(pred_batch['inputs'], optimizer.target, cache, eos_token, FLAGS.max_predict_length) predicted = tohost(predicted) inputs = tohost(pred_batch['inputs']) targets = tohost(pred_batch['targets']) # Iterate through non-padding examples of batch. for i, s in enumerate(predicted[:cur_pred_batch_size]): sources.append(decode_tokens(inputs[i])) references.append(decode_tokens(targets[i])) predictions.append(decode_tokens(s)) logging.info('Translation: %d predictions %d references %d sources.', len(predictions), len(references), len(sources)) logging.info('Translation time: %.4f s step %d.', time.time() - t_inference_start, step) # Calculate BLEU score for translated eval corpus against reference. bleu_matches = bleu.bleu_partial(references, predictions) all_bleu_matches = per_host_sum_pmap(bleu_matches) bleu_score = bleu.complete_bleu(*all_bleu_matches) # Save translation samples for tensorboard. exemplars = '' for n in np.random.choice(np.arange(len(predictions)), 8): exemplars += f'{sources[n]}\n\n{references[n]}\n\n{predictions[n]}\n\n' if jax.host_id() == 0: eval_summary_writer.scalar('bleu', bleu_score, step) eval_summary_writer.text('samples', exemplars, step) eval_summary_writer.flush() logging.info('Translation BLEU Score %.4f', bleu_score)
def train(
    model: models.ActorCritic,
    config: ml_collections.ConfigDict,
    model_dir: str):
  """Main training loop.

  Args:
    model: the actor-critic model
    config: object holding hyperparameters and the training information
    model_dir: path to directory where checkpoints and logging info are stored

  Returns:
    the final train state
  """
  game = config.game + 'NoFrameskip-v4'
  simulators = [agent.RemoteSimulator(game) for _ in range(config.num_agents)]
  summary_writer = tensorboard.SummaryWriter(model_dir)
  summary_writer.hparams(dict(config))
  loop_steps = config.total_frames // (config.num_agents * config.actor_steps)
  log_frequency = 40
  checkpoint_frequency = 500
  # train_step does multiple steps per call for better performance
  # compute number of steps per call here to convert between the number of
  # train steps and the inner number of optimizer steps
  iterations_per_step = (config.num_agents * config.actor_steps
                         // config.batch_size)

  initial_params = get_initial_params(jax.random.PRNGKey(0), model)
  state = create_train_state(initial_params, model, config,
                             loop_steps * config.num_epochs *
                             iterations_per_step)
  del initial_params
  state = checkpoints.restore_checkpoint(model_dir, state)
  # number of train iterations done by each train_step
  start_step = int(state.step) // config.num_epochs // iterations_per_step
  logging.info('Start training from step: %s', start_step)

  for step in range(start_step, loop_steps):
    # Bookkeeping and testing.
    if step % log_frequency == 0:
      score = test_episodes.policy_test(1, state.apply_fn, state.params, game)
      frames = step * config.num_agents * config.actor_steps
      summary_writer.scalar('game_score', score, frames)
      logging.info('Step %s:\nframes seen %s\nscore %s\n\n',
                   step, frames, score)

    # Core training code.
    alpha = 1. - step / loop_steps if config.decaying_lr_and_clip_param else 1.
    all_experiences = get_experience(
        state, simulators, config.actor_steps)
    trajectories = process_experience(
        all_experiences, config.actor_steps, config.num_agents, config.gamma,
        config.lambda_)
    clip_param = config.clip_param * alpha
    for _ in range(config.num_epochs):
      permutation = np.random.permutation(
          config.num_agents * config.actor_steps)
      trajectories = tuple(x[permutation] for x in trajectories)
      state, _ = train_step(
          state, trajectories, config.batch_size,
          clip_param=clip_param,
          vf_coeff=config.vf_coeff,
          entropy_coeff=config.entropy_coeff)
    if (step + 1) % checkpoint_frequency == 0:
      checkpoints.save_checkpoint(model_dir, state, step + 1)

  return state
def train(config, workdir): """Runs a training loop. Args: config: Configuration to use. workdir: Working directory for checkpoints and TF summaries. If this contains checkpoint training will be resumed from the latest checkpoint. """ # Create directories for experimental logs tf.io.gfile.makedirs(workdir) sample_dir = os.path.join(workdir, "samples") tf.io.gfile.makedirs(sample_dir) rng = jax.random.PRNGKey(config.seed) tb_dir = os.path.join(workdir, "tensorboard") tf.io.gfile.makedirs(tb_dir) if jax.host_id() == 0: writer = tensorboard.SummaryWriter(tb_dir) # Initialize model. rng, model_rng = jax.random.split(rng) model_name = config.model.name ncsn_def = mutils.get_model(model_name).partial(config=config) rng, run_rng = jax.random.split(rng) # Whether the generative model is conditioned on class labels class_conditional = "conditional" in config.training.loss.lower() with nn.stateful() as init_model_state: with nn.stochastic(run_rng): input_shape = (jax.local_device_count(), config.data.image_size, config.data.image_size, 3) input_list = [(input_shape, jnp.float32), (input_shape[:1], jnp.int32)] if class_conditional: input_list.append(input_list[-1]) _, initial_params = ncsn_def.init_by_shape( model_rng, input_list, train=True) ncsn = nn.Model(ncsn_def, initial_params) optimizer = losses.get_optimizer(config).create(ncsn) state = mutils.State(step=0, optimizer=optimizer, lr=config.optim.lr, model_state=init_model_state, ema_rate=config.model.ema_rate, params_ema=initial_params, rng=rng) # pytype: disable=wrong-keyword-args del ncsn, init_model_state # Do not keep a copy of the initial model. # Create checkpoints directory and the initial checkpoint checkpoint_dir = os.path.join(workdir, "checkpoints") ckpt = utils.Checkpoint( checkpoint_dir, max_to_keep=None) ckpt.restore_or_initialize(state) # Save intermediate checkpoints to resume training automatically checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta") ckpt_meta = utils.Checkpoint( checkpoint_meta_dir, max_to_keep=1) state = ckpt_meta.restore_or_initialize(state) initial_step = int(state.step) rng = state.rng # Build input pipeline. rng, ds_rng = jax.random.split(rng) train_ds, eval_ds, _ = datasets.get_dataset(ds_rng, config) train_iter = iter(train_ds) # pytype: disable=wrong-arg-types eval_iter = iter(eval_ds) # pytype: disable=wrong-arg-types scaler = datasets.get_data_scaler(config) # data normalizer inverse_scaler = datasets.get_data_inverse_scaler(config) # Distribute training. optimize_fn = losses.optimization_manager(config) if config.training.loss.lower() == "ddpm": # Use score matching loss with DDPM-type perturbation. ddpm_params = mutils.get_ddpm_params() train_step = functools.partial(losses.ddpm_loss, ddpm_params=ddpm_params, train=True, optimize_fn=optimize_fn) eval_step = functools.partial(losses.ddpm_loss, ddpm_params=ddpm_params, train=False) else: # Use score matching loss with NCSN-type perturbation. 
sigmas = mutils.get_sigmas(config) # Whether to use a continuous distribution of noise levels continuous = "continuous" in config.training.loss.lower() train_step = functools.partial( losses.ncsn_loss, sigmas=sigmas, class_conditional=class_conditional, continuous=continuous, train=True, optimize_fn=optimize_fn, anneal_power=config.training.anneal_power) eval_step = functools.partial( losses.ncsn_loss, sigmas=sigmas, class_conditional=class_conditional, continuous=continuous, train=False, anneal_power=config.training.anneal_power) p_train_step = jax.pmap(train_step, axis_name="batch") p_eval_step = jax.pmap(eval_step, axis_name="batch") state = flax_utils.replicate(state) num_train_steps = config.training.n_iters logging.info("Starting training loop at step %d.", initial_step) rng = jax.random.fold_in(rng, jax.host_id()) for step in range(initial_step, num_train_steps + 1): # `step` is a Python integer. `state.step` is JAX integer on the GPU/TPU # devices. # Convert data to JAX arrays. Use ._numpy() to avoid copy. batch = jax.tree_map(lambda x: scaler(x._numpy()), next(train_iter)) # pylint: disable=protected-access rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() + 1) next_rng = jnp.asarray(next_rng) loss, state = p_train_step(next_rng, state, batch) loss = flax.jax_utils.unreplicate(loss) # Quick indication that training is happening. logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step) if jax.host_id() == 0 and step % 50 == 0: logging.info("step: %d, training_loss: %.5e", step, loss) writer.scalar("training_loss", loss, step) # Save a temporary checkpoint to resume training after pre-emption. if step % config.training.snapshot_freq_for_preemption == 0 and jax.host_id( ) == 0: saved_state = flax_utils.unreplicate(state) saved_state = saved_state.replace(rng=rng) ckpt_meta.save(saved_state) # Report the loss on an evaluation dataset. if step % 100 == 0: rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() + 1) next_rng = jnp.asarray(next_rng) eval_batch = jax.tree_map(lambda x: scaler(x._numpy()), next(eval_iter)) # pylint: disable=protected-access eval_loss, _ = p_eval_step(next_rng, state, eval_batch) eval_loss = flax.jax_utils.unreplicate(eval_loss) if jax.host_id() == 0: logging.info("step: %d, eval_loss: %.5e", step, eval_loss) writer.scalar("eval_loss", eval_loss, step) # Save a checkpoint periodically and generate samples. if (step + 1) % config.training.snapshot_freq == 0 or step == num_train_steps: # Save the checkpoint. 
if jax.host_id() == 0: saved_state = flax_utils.unreplicate(state) saved_state = saved_state.replace(rng=rng) ckpt.save(saved_state) # Generate and save samples if config.training.snapshot_sampling: rng, sample_rng = jax.random.split(rng) init_shape = tuple(train_ds.element_spec["image"].shape) samples = sampling.get_samples(sample_rng, config, flax_utils.unreplicate(state), init_shape, scaler, inverse_scaler, class_conditional=class_conditional) this_sample_dir = os.path.join( sample_dir, "iter_{}_host_{}".format(step, jax.host_id())) tf.io.gfile.makedirs(this_sample_dir) if config.sampling.final_only: # Do not save intermediate samples sample = samples[-1] image_grid = sample.reshape((-1, *sample.shape[2:])) nrow = int(np.sqrt(image_grid.shape[0])) sample = np.clip(sample * 255, 0, 255).astype(np.uint8) with tf.io.gfile.GFile( os.path.join(this_sample_dir, "sample.np"), "wb") as fout: np.save(fout, sample) with tf.io.gfile.GFile( os.path.join(this_sample_dir, "sample.png"), "wb") as fout: utils.save_image(image_grid, fout, nrow=nrow, padding=2) else: # Save all intermediate samples produced during sampling. for i, sample in enumerate(samples): image_grid = sample.reshape((-1, *sample.shape[2:])) nrow = int(np.sqrt(image_grid.shape[0])) sample = np.clip(sample * 255, 0, 255).astype(np.uint8) with tf.io.gfile.GFile( os.path.join(this_sample_dir, "sample_{}.np".format(i)), "wb") as fout: np.save(fout, sample) with tf.io.gfile.GFile( os.path.join(this_sample_dir, "sample_{}.png".format(i)), "wb") as fout: utils.save_image(image_grid, fout, nrow=nrow, padding=2)
def main(argv): if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') tf.enable_v2_behavior() config = FLAGS.config logging.info('===========Config Dict============') logging.info(config) batch_size = config.batch_size learning_rate = config.learning_rate num_train_steps = config.num_train_steps num_eval_steps = config.num_eval_steps eval_freq = config.eval_frequency random_seed = config.random_seed model_type = config.model_type if jax.host_id() == 0: summary_writer = tensorboard.SummaryWriter( os.path.join(FLAGS.model_dir, 'summary')) else: summary_writer = None if batch_size % jax.device_count() > 0: raise ValueError( 'Batch size must be divisible by the number of devices') logging.info('Training on %s', FLAGS.task_name) if model_type in ['wideresnet', 'resnet', 'simple_cnn']: normalize = True else: # transformer-based models normalize = False (train_ds, eval_ds, test_ds, num_classes, vocab_size, input_shape) = task_registry.TASK_DATA_DICT[FLAGS.task_name]( n_devices=jax.local_device_count(), batch_size=batch_size, normalize=normalize) train_iter = iter(train_ds) model_kwargs = {} flatten_input = True if model_type in ['wideresnet', 'resnet', 'simple_cnn']: model_kwargs.update({ 'num_classes': num_classes, }) flatten_input = False else: # transformer models # we will flatten the input bs, h, w, c = input_shape assert c == 1 input_shape = (bs, h * w * c) model_kwargs.update({ 'vocab_size': vocab_size, 'max_len': input_shape[1], 'classifier': True, 'num_classes': num_classes, }) model_kwargs.update(config.model) rng = random.PRNGKey(random_seed) rng = jax.random.fold_in(rng, jax.host_id()) rng, init_rng = random.split(rng) # We init the first set of dropout PRNG keys, but update it afterwards inside # the main pmap'd training update for performance. dropout_rngs = random.split(rng, jax.local_device_count()) model, state = get_model(init_rng, input_shape, model_type, model_kwargs) optimizer = create_optimizer(model, learning_rate, config.weight_decay) del model # Don't keep a copy of the initial model. start_step = 0 if config.restore_checkpoints: # Restore unreplicated optimizer + model state from last checkpoint. optimizer, state = checkpoints.restore_checkpoint( FLAGS.model_dir, (optimizer, state)) # Grab last step. start_step = int(optimizer.state.step) # Replicate optimizer and state optimizer = jax_utils.replicate(optimizer) state = jax_utils.replicate(state) learning_rate_fn = train_utils.create_learning_rate_scheduler( factors=config.factors, base_learning_rate=learning_rate, warmup_steps=config.warmup, steps_per_cycle=config.get('steps_per_cycle', None), ) p_train_step = jax.pmap(functools.partial( train_step, learning_rate_fn=learning_rate_fn, num_classes=num_classes, grad_clip_norm=config.get('grad_clip_norm', None), flatten_input=flatten_input), axis_name='batch') p_eval_step = jax.pmap( functools.partial(eval_step, num_classes=num_classes, flatten_input=flatten_input), axis_name='batch', ) optimizer, state, step = train_loop(config, dropout_rngs, eval_ds, eval_freq, num_eval_steps, num_train_steps, optimizer, state, p_eval_step, p_train_step, start_step, train_iter, summary_writer) logging.info('Starting testing') logging.info('====================') test(optimizer, state, p_eval_step, step, test_ds, summary_writer, FLAGS.model_dir)
def train_and_evaluate(config, workdir, vocab_filepath): """Runs a training and evaluation loop. Args: config: Model and training configuration. workdir: Working directory for checkpoints and Tensorboard summaries. If this contains a checkpoint, training will be resumed from the latest checkpoint. vocab_filepath: Absolute path to SentencePiece vocab model. Raises: ValueError: If training or eval batch sizes won't fit number of processes and devices, or config is underspecified. """ n_processes = jax.process_count() # Number of processes n_devices = jax.local_device_count() # Number of local devices per process if config.train_batch_size % (n_processes * n_devices) > 0: raise ValueError( "Training batch size must be divisible by the total number of devices, " "but training batch size = %d, while total number of devices = %d " "(%d processes, each with %d devices)" % (config.train_batch_size, n_processes * n_devices, n_processes, n_devices)) if config.eval_batch_size % (n_processes * n_devices) > 0: raise ValueError( "Eval batch size must be divisible by the total number of devices, " "but eval batch size = %d, while total number of devices = %d " "(%d processes, each with %d devices)" % (config.eval_batch_size, n_processes * n_devices, n_processes, n_devices)) per_process_train_batch_size = config.train_batch_size // n_processes per_process_eval_batch_size = config.eval_batch_size // n_processes if jax.process_index() == 0: train_summary_writer = tensorboard.SummaryWriter( os.path.join(workdir, "train")) eval_summary_writer = tensorboard.SummaryWriter( os.path.join(workdir, "eval")) else: train_summary_writer = None eval_summary_writer = None rng = random.PRNGKey(config.seed) rng, init_rng = random.split(rng) ds_info = tfds.builder(config.dataset_name).info num_train_examples = ds_info.splits[tfds.Split.TRAIN].num_examples num_train_steps = int(num_train_examples * config.num_train_epochs // config.train_batch_size) num_warmup_steps = int(config.warmup_proportion * num_train_steps) # Round up evaluation frequency to power of 10. eval_frequency = int( math.ceil(config.eval_proportion * num_train_steps / 10)) * 10 is_regression_task = config.dataset_name == "glue/stsb" num_classes = (1 if is_regression_task else ds_info.features["label"].num_classes) tokenizer = spm.SentencePieceProcessor() tokenizer.Load(vocab_filepath) with config.unlocked(): config.vocab_size = tokenizer.GetPieceSize() frozen_config = ml_collections.FrozenConfigDict(config) model = models.SequenceClassificationModel(config=frozen_config, n_classes=num_classes) params = _init_params(model, init_rng, config) optimizer = _create_adam_optimizer(config.learning_rate, params) # In case current job restarts, ensure that we continue from where we left # off. optimizer = checkpoints.restore_checkpoint(workdir, optimizer) start_step = int(optimizer.state.step) # Otherwise, try to restore optimizer and model state from config checkpoint. if (start_step == 0 and "init_checkpoint_dir" in config and config.init_checkpoint_dir): optimizer = _restore_pretrained_model(optimizer, params, config) # We access model state only from optimizer via optimizer.target. 
del params optimizer = jax_utils.replicate(optimizer) if is_regression_task: compute_stats = functools.partial(_compute_regression_stats, model=model, pad_id=tokenizer.pad_id()) else: compute_stats = functools.partial(_compute_classification_stats, model=model, pad_id=tokenizer.pad_id()) learning_rate_fn = train_utils.create_learning_rate_scheduler( factors="constant * linear_warmup * linear_decay", base_learning_rate=config.learning_rate, warmup_steps=num_warmup_steps, decay_steps=num_train_steps - num_warmup_steps, ) glue_inputs = functools.partial(input_pipeline.glue_inputs, dataset_name=config.dataset_name, max_seq_length=config.max_seq_length, tokenizer=tokenizer) train_ds = glue_inputs(split=tfds.Split.TRAIN, batch_size=per_process_train_batch_size, training=True) train_iter = iter(train_ds) if config.dataset_name == "glue/mnli": # MNLI contains two validation and test datasets. split_suffixes = ["_matched", "_mismatched"] else: split_suffixes = [""] # We init the first set of dropout PRNG keys, but update it afterwards inside # the main pmap'd training update for performance. rngs = random.split(rng, n_devices) loss_and_metrics_fn = functools.partial(_compute_loss_and_metrics, model=model, pad_id=tokenizer.pad_id()) p_train_step = jax.pmap(functools.partial( train_utils.train_step, loss_and_metrics_fn=loss_and_metrics_fn, learning_rate_fn=learning_rate_fn), axis_name="batch") p_eval_step = jax.pmap(functools.partial(train_utils.eval_step, metric_fn=compute_stats), axis_name="batch") eval_metrics_fn = _create_eval_metrics_fn(config.dataset_name, is_regression_task) train_metrics = [] logging.info("Starting training loop.") logging.info("====================") for step in range(start_step, num_train_steps): with jax.profiler.StepTraceAnnotation("train", step_num=step): train_batch = next(train_iter) train_batch = common_utils.shard(train_batch) optimizer, train_step_metrics, rngs = p_train_step(optimizer, train_batch, rng=rngs) train_metrics.append(train_step_metrics) if ((step > 0 and config.save_checkpoints_steps and step % config.save_checkpoints_steps == 0) or step == num_train_steps - 1) and jax.process_index() == 0: # Save un-replicated optimizer and model state. checkpoints.save_checkpoint(workdir, jax_utils.unreplicate(optimizer), step, keep=2) # Periodic metric handling. if step % eval_frequency != 0 and step < num_train_steps - 1: continue logging.info("Gathering training metrics at step: %d", step) train_metrics = common_utils.get_metrics(train_metrics) train_summary = { "loss": jnp.sum(train_metrics["loss"]) / jnp.sum(train_metrics["num_labels"]), "learning_rate": learning_rate_fn(step) } if not is_regression_task: train_summary["accuracy"] = jnp.sum( train_metrics["correct_predictions"]) / jnp.sum( train_metrics["num_labels"]) if jax.process_index() == 0: assert train_summary_writer for key, val in train_summary.items(): train_summary_writer.scalar(key, val, step) train_summary_writer.flush() # Reset metric accumulation for next evaluation cycle. 
train_metrics = [] logging.info("Gathering validation metrics at step: %d", step) for split_suffix in split_suffixes: eval_ds = glue_inputs(split=tfds.Split.VALIDATION + split_suffix, batch_size=per_process_eval_batch_size, training=False) all_stats = [] for _, eval_batch in zip(range(config.max_num_eval_steps), eval_ds): all_stats.append( _evaluate(p_eval_step, optimizer.target, eval_batch, n_devices)) flat_stats = {} for k in all_stats[ 0]: # All batches of output stats are the same size flat_stats[k] = np.concatenate([stat[k] for stat in all_stats], axis=0) eval_summary = eval_metrics_fn(flat_stats) if jax.process_index() == 0: assert eval_summary_writer for key, val in eval_summary.items(): eval_summary_writer.scalar(f"{key}{split_suffix}", val, step) eval_summary_writer.flush()
def train_and_evaluate(config, workdir, vocab_filepath): """Runs a training and evaluation loop. Args: config: Model and training configuration. workdir: Working directory for checkpoints and TensorBoard summaries. If this contains a checkpoint, training will be resumed from the latest checkpoint. vocab_filepath: Absolute path to SentencePiece vocab model. Raises: ValueError: If training or eval batch sizes won't fit number of hosts and devices, or config is underspecified. """ # Update config before config validation. with config.unlocked(): # Numeric floating point type to use for model computations. config.dtype = jnp.float32 train_utils.validate_config(config) if jax.process_index() == 0: train_summary_writer = tensorboard.SummaryWriter( os.path.join(workdir, "train")) eval_summary_writer = tensorboard.SummaryWriter( os.path.join(workdir, "eval")) else: train_summary_writer = None eval_summary_writer = None tokenizer = spm.SentencePieceProcessor() tokenizer.Load(vocab_filepath) tokenizer.SetEncodeExtraOptions("") # Note: [CLS] and [SEP] will be added by the data pipeline, not the tokenizer. with config.unlocked(): config.vocab_size = tokenizer.GetPieceSize() config.pad_id = tokenizer.pad_id() config = ml_collections.FrozenConfigDict(config) model = models.PreTrainingModel(config=config) rng = random.PRNGKey(config.seed) rng, init_rng = random.split(rng) params = _init_params(model, init_rng, config) learning_rate_fn = train_utils.create_learning_rate_scheduler( factors="constant * linear_warmup * linear_decay", base_learning_rate=config.learning_rate, warmup_steps=config.num_warmup_steps, decay_steps=config.num_train_steps - config.num_warmup_steps, ) tx = optax.adamw(learning_rate_fn, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.01) if config.clipped_grad_norm: tx = optax.chain(optax.clip_by_global_norm(config.clipped_grad_norm), tx) # jit state creation to ensure arrays are created on same device as input # (i.e. CPU). state_cpu = jax.jit( functools.partial(FlaxTrainState.create, apply_fn=model.apply, params=params, tx=tx))() # We access model params only via state.params del params if config.num_experts > 1: sharded_match_fn = core_utils.match_fn(r".*expert.*") not_sharded_match_fn = lambda name: not sharded_match_fn(name) else: sharded_match_fn = None not_sharded_match_fn = lambda name: True state, start_step = _restore_state_from_checkpoint(workdir, state_cpu, sharded_match_fn, not_sharded_match_fn, config) train_ds, eval_ds = _init_train_and_eval_ds(tokenizer, config) train_iter = iter(train_ds) # We init the first set of dropout PRNG keys, but update it afterwards inside # the main pmap'd training update for performance. rngs = random.split(rng, jax.local_device_count()) loss_and_metrics_fn = functools.partial( _compute_loss_and_metrics, model=model, is_experts_model=config.num_experts > 1, auxiliary_loss_factor=config.auxiliary_loss_factor, router_z_loss_factor=config.router_z_loss_factor) train_step = functools.partial( train_utils.pmap_train_step, loss_and_metrics_fn=loss_and_metrics_fn, axis_name="batch", sharded_match_fn=sharded_match_fn, gradient_accum_steps=config.gradient_accum_steps) p_train_step = jax.pmap(train_step, axis_name="batch") eval_step = functools.partial(_compute_eval_stats, model=model) p_eval_step = jax.pmap(eval_step, axis_name="batch") seconds = 0. 
train_stats = [] logging.info("Starting training loop.") logging.info("====================") for step in range(start_step, config.num_train_steps): with jax.profiler.StepTraceContext("train", step_num=step): train_batch = next(train_iter) train_batch = common_utils.shard(train_batch) tick = time.time() state, train_step_stats, rngs = p_train_step(state, train_batch, rng=rngs) if config.measure_step_speed: jax.tree_map(lambda opt: opt.block_until_ready(), state) tock = time.time() seconds += tock - tick train_stats.append(train_step_stats) if (step > 0 and config.save_checkpoints_steps and step % config.save_checkpoints_steps == 0): # We allow all hosts to potentially save checkpoints because some model # parameters are sharded across devices. Parameters replicated across # devices (i.e. not sharded) will only be checkpointed by host 0. unreplicated_state = jax.tree_map( np.array, core_utils.tree_unreplicate_by_name(state, not_sharded_match_fn)) checkpoints.save_checkpoint(workdir, unreplicated_state, sharded_match_fn, step, keep=config.checkpoints_to_keep) del unreplicated_state # Only used for checkpointing. # Periodic metric handling. if step % config.eval_frequency != 0 and step > 0: continue logging.info("Gathering training metrics at step: %d", step) train_metrics = train_utils.collect_metrics(train_stats) train_summary = train_utils.compute_pretraining_metrics(train_metrics) train_summary["learning_rate"] = learning_rate_fn(step) if config.measure_step_speed: train_summary["steps_per_sec"] = (step - start_step + 1) / seconds if jax.process_index() == 0: assert train_summary_writer for key, val in train_summary.items(): train_summary_writer.scalar(key, val, step) train_summary_writer.flush() # Reset metric accumulation for next training evaluation cycle. train_stats = [] logging.info("Gathering evaluation metrics at step: %d", step) eval_stats = [] for _, eval_batch in zip(range(config.max_num_eval_steps), eval_ds): eval_batch = common_utils.shard(eval_batch) eval_stats.append(p_eval_step(state.params, eval_batch)) eval_metrics = train_utils.collect_metrics(eval_stats) eval_summary = train_utils.compute_pretraining_metrics( eval_metrics, record_grad_norm=False) if jax.process_index() == 0: assert eval_summary_writer for key, val in eval_summary.items(): eval_summary_writer.scalar(key, val, step) eval_summary_writer.flush()
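The train_and_evaluate setup above builds its optimizer as optax.chain(optax.clip_by_global_norm(...), optax.adamw(...)) and drives it through a flax TrainState. A minimal, self-contained sketch of how such a chained transform consumes a gradient pytree, using toy parameters rather than the real model:

import jax.numpy as jnp
import optax

# Same structure as the transform above: clip the global gradient norm, then AdamW.
tx = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(learning_rate=1e-3, b1=0.9, b2=0.999, eps=1e-6,
                weight_decay=0.01))

params = {"w": jnp.ones((3,)), "b": jnp.zeros(())}        # toy parameters
grads = {"w": jnp.full((3,), 10.0), "b": jnp.array(2.0)}  # deliberately large gradients

opt_state = tx.init(params)
updates, opt_state = tx.update(grads, opt_state, params)  # params are needed for weight decay
new_params = optax.apply_updates(params, updates)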
def main(unused_argv):
  rng = random.PRNGKey(20200823)
  # Shift the numpy random seed by host_id() to shuffle data loaded by different
  # hosts.
  np.random.seed(20201473 + jax.host_id())

  if FLAGS.config is not None:
    utils.update_flags(FLAGS)
  if FLAGS.batch_size % jax.device_count() != 0:
    raise ValueError("Batch size must be divisible by the number of devices.")
  if FLAGS.train_dir is None:
    raise ValueError("train_dir must be set. None set now.")
  if FLAGS.data_dir is None:
    raise ValueError("data_dir must be set. None set now.")

  dataset = datasets.get_dataset("train", FLAGS)
  test_dataset = datasets.get_dataset("test", FLAGS)

  rng, key = random.split(rng)
  model, variables = models.get_model(key, dataset.peek(), FLAGS)
  optimizer = flax.optim.Adam(FLAGS.lr_init).create(variables)
  state = utils.TrainState(optimizer=optimizer)
  del optimizer, variables

  learning_rate_fn = functools.partial(utils.learning_rate_decay,
                                       lr_init=FLAGS.lr_init,
                                       lr_final=FLAGS.lr_final,
                                       max_steps=FLAGS.max_steps,
                                       lr_delay_steps=FLAGS.lr_delay_steps,
                                       lr_delay_mult=FLAGS.lr_delay_mult)

  train_pstep = jax.pmap(functools.partial(train_step, model),
                         axis_name="batch",
                         in_axes=(0, 0, 0, None),
                         donate_argnums=(2,))

  def render_fn(variables, key_0, key_1, rays):
    return jax.lax.all_gather(
        model.apply(variables, key_0, key_1, rays, FLAGS.randomized),
        axis_name="batch")

  render_pfn = jax.pmap(
      render_fn,
      in_axes=(None, None, None, 0),  # Only distribute the data input.
      donate_argnums=(3,),
      axis_name="batch",
  )

  # Compiling to the CPU because it's faster and more accurate.
  ssim_fn = jax.jit(functools.partial(utils.compute_ssim, max_val=1.),
                    backend="cpu")

  if not utils.isdir(FLAGS.train_dir):
    utils.makedirs(FLAGS.train_dir)
  state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
  # Resume training at the step of the last checkpoint.
  init_step = state.optimizer.state.step + 1
  state = flax.jax_utils.replicate(state)

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(FLAGS.train_dir)

  # Prefetch_buffer_size = 3 x batch_size
  pdataset = flax.jax_utils.prefetch_to_device(dataset, 3)
  n_local_devices = jax.local_device_count()
  rng = rng + jax.host_id()  # Make random seed separate across hosts.
  keys = random.split(rng, n_local_devices)  # For pmapping RNG keys.
  gc.disable()  # Disable automatic garbage collection for efficiency.
  stats_trace = []
  reset_timer = True
  for step, batch in zip(range(init_step, FLAGS.max_steps + 1), pdataset):
    if reset_timer:
      t_loop_start = time.time()
      reset_timer = False
    lr = learning_rate_fn(step)
    state, stats, keys = train_pstep(keys, state, batch, lr)
    if jax.host_id() == 0:
      stats_trace.append(stats)
    if step % FLAGS.gc_every == 0:
      gc.collect()

    # Log training summaries. This is put behind a host_id check because in
    # multi-host evaluation, all hosts need to run inference even though we
    # only use host 0 to record results.
if jax.host_id() == 0: if step % FLAGS.print_every == 0: summary_writer.scalar("train_loss", stats.loss[0], step) summary_writer.scalar("train_psnr", stats.psnr[0], step) summary_writer.scalar("train_loss_coarse", stats.loss_c[0], step) summary_writer.scalar("train_psnr_coarse", stats.psnr_c[0], step) summary_writer.scalar("weight_l2", stats.weight_l2[0], step) avg_loss = np.mean( np.concatenate([s.loss for s in stats_trace])) avg_psnr = np.mean( np.concatenate([s.psnr for s in stats_trace])) stats_trace = [] summary_writer.scalar("train_avg_loss", avg_loss, step) summary_writer.scalar("train_avg_psnr", avg_psnr, step) summary_writer.scalar("learning_rate", lr, step) steps_per_sec = FLAGS.print_every / (time.time() - t_loop_start) reset_timer = True rays_per_sec = FLAGS.batch_size * steps_per_sec summary_writer.scalar("train_steps_per_sec", steps_per_sec, step) summary_writer.scalar("train_rays_per_sec", rays_per_sec, step) precision = int(np.ceil(np.log10(FLAGS.max_steps))) + 1 print(("{:" + "{:d}".format(precision) + "d}").format(step) + f"/{FLAGS.max_steps:d}: " + f"i_loss={stats.loss[0]:0.4f}, " + f"avg_loss={avg_loss:0.4f}, " + f"weight_l2={stats.weight_l2[0]:0.2e}, " + f"lr={lr:0.2e}, " + f"{rays_per_sec:0.0f} rays/sec") if step % FLAGS.save_every == 0: state_to_save = jax.device_get( jax.tree_map(lambda x: x[0], state)) checkpoints.save_checkpoint(FLAGS.train_dir, state_to_save, int(step), keep=100) # Test-set evaluation. if FLAGS.render_every > 0 and step % FLAGS.render_every == 0: # We reuse the same random number generator from the optimization step # here on purpose so that the visualization matches what happened in # training. t_eval_start = time.time() eval_variables = jax.device_get(jax.tree_map( lambda x: x[0], state)).optimizer.target test_case = next(test_dataset) pred_color, pred_disp, pred_acc = utils.render_image( functools.partial(render_pfn, eval_variables), test_case["rays"], keys[0], FLAGS.dataset == "llff", chunk=FLAGS.chunk) # Log eval summaries on host 0. if jax.host_id() == 0: psnr = utils.compute_psnr( ((pred_color - test_case["pixels"])**2).mean()) ssim = ssim_fn(pred_color, test_case["pixels"]) eval_time = time.time() - t_eval_start num_rays = jnp.prod( jnp.array(test_case["rays"].directions.shape[:-1])) rays_per_sec = num_rays / eval_time summary_writer.scalar("test_rays_per_sec", rays_per_sec, step) print( f"Eval {step}: {eval_time:0.3f}s., {rays_per_sec:0.0f} rays/sec" ) summary_writer.scalar("test_psnr", psnr, step) summary_writer.scalar("test_ssim", ssim, step) summary_writer.image("test_pred_color", pred_color, step) summary_writer.image("test_pred_disp", pred_disp, step) summary_writer.image("test_pred_acc", pred_acc, step) summary_writer.image("test_target", test_case["pixels"], step) if FLAGS.max_steps % FLAGS.save_every != 0: state = jax.device_get(jax.tree_map(lambda x: x[0], state)) checkpoints.save_checkpoint(FLAGS.train_dir, state, int(FLAGS.max_steps), keep=100)
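The logging blocks in the NeRF-style loops above call utils.compute_psnr on the mean squared error of the rendered image. Assuming the helper implements the standard PSNR formula for images in [0, 1], a minimal equivalent is:

import jax.numpy as jnp

def compute_psnr(mse):
  # PSNR in dB for a signal whose peak value is 1.0.
  return -10.0 * jnp.log10(mse)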
def main(unused_argv): rng = random.PRNGKey(20200823) # Shift the numpy random seed by host_id() to shuffle data loaded by different # hosts. np.random.seed(20201473 + jax.host_id()) if FLAGS.config is not None: utils.update_flags(FLAGS) if FLAGS.batch_size % jax.device_count() != 0: raise ValueError( "Batch size must be divisible by the number of devices.") if FLAGS.train_dir is None: raise ValueError("train_dir must be set. None set now.") if FLAGS.data_dir is None: raise ValueError("data_dir must be set. None set now.") dataset = datasets.get_dataset("train", FLAGS) test_dataset = datasets.get_dataset("test", FLAGS) test_render_fn = jax.pmap( # Note rng_keys are useless in eval mode since there's no randomness. # pylint: disable=g-long-lambda lambda key_0, key_1, model, rays: jax.lax.all_gather( model(key_0, key_1, *rays), axis_name="batch"), in_axes=(None, None, None, 0), # Only distribute the data input. donate_argnums=3, axis_name="batch", ) rng, key = random.split(rng) init_model, init_state = models.get_model(key, dataset.peek(), FLAGS) optimizer_def = optim.Adam(FLAGS.lr_init) optimizer = optimizer_def.create(init_model) state = model_utils.TrainState(step=0, optimizer=optimizer, model_state=init_state) if not utils.isdir(FLAGS.train_dir): utils.makedirs(FLAGS.train_dir) state = checkpoints.restore_checkpoint(FLAGS.train_dir, state) offset = state.step + 1 state = jax_utils.replicate(state) del init_model, init_state if jax.host_id() == 0: summary_writer = tensorboard.SummaryWriter(FLAGS.train_dir) t_loop_start = time.time() learning_rate_fn = functools.partial(utils.learning_rate_decay, lr_init=FLAGS.lr_init, lr_final=FLAGS.lr_final, max_steps=FLAGS.max_steps, lr_delay_steps=FLAGS.lr_delay_steps, lr_delay_mult=FLAGS.lr_delay_mult) ptrain_step = jax.pmap(train_step, axis_name="batch", in_axes=(0, 0, 0, None), donate_argnums=2) # Prefetch_buffer_size = 3 x batch_size pdataset = jax_utils.prefetch_to_device(dataset, 3) n_local_deices = jax.local_device_count() rng = rng + jax.host_id() # Make random seed separate across hosts. keys = random.split(rng, n_local_deices) # For pmapping RNG keys. gc.disable() # Disable automatic garbage collection for efficiency. stats_trace = [] for step, batch in zip(range(offset, FLAGS.max_steps + 1), pdataset): lr = learning_rate_fn(step) state, stats, keys = ptrain_step(keys, state, batch, lr) if jax.host_id() == 0: stats_trace.append(stats[0]) if step % FLAGS.gc_every == 0: gc.collect() # --- Train logs start --- # Put the training time visualization before the host_id check as in # multi-host evaluation, all hosts need to run inference even though we # only use host 0 to record results. if FLAGS.render_every > 0 and step % FLAGS.render_every == 0: # We reuse the same random number generator from the optimization step # here on purpose so that the visualization matches what happened in # training. 
state_to_eval = jax.device_get(jax.tree_map(lambda x: x[0], state)) test_case = next(test_dataset) pred_color, pred_disp, pred_acc = utils.render_image( state_to_eval, test_case["rays"], test_render_fn, keys[0], FLAGS.dataset == "llff", chunk=FLAGS.chunk) if jax.host_id() == 0: psnr = utils.compute_psnr( ((pred_color - test_case["pixels"])**2).mean()) summary_writer.scalar("test_psnr", psnr, step) summary_writer.image("test_pred_color", pred_color, step) summary_writer.image("test_pred_disp", pred_disp, step) summary_writer.image("test_pred_acc", pred_acc, step) summary_writer.image("test_target", test_case["pixels"], step) if jax.host_id() != 0: # Only log via host 0. continue if step % FLAGS.print_every == 0: summary_writer.scalar("train_loss", stats[0].loss[0], step) summary_writer.scalar("train_psnr", stats[0].psnr[0], step) if len(stats) > 1: summary_writer.scalar("train_loss_coarse", stats[1].loss[0], step) summary_writer.scalar("train_psnr_coarse", stats[1].psnr[0], step) avg_loss = np.mean(np.concatenate([s.loss for s in stats_trace])) avg_psnr = np.mean(np.concatenate([s.psnr for s in stats_trace])) stats_trace = [] summary_writer.scalar("train_avg_loss", avg_loss, step) summary_writer.scalar("train_avg_psnr", avg_psnr, step) summary_writer.scalar("learning_rate", lr, step) steps_per_sec = FLAGS.print_every / (time.time() - t_loop_start) t_loop_start = time.time() rays_per_sec = FLAGS.batch_size * steps_per_sec summary_writer.scalar("steps_per_sec", steps_per_sec, step) summary_writer.scalar("rays_per_sec", rays_per_sec, step) precision = int(np.ceil(np.log10(FLAGS.max_steps))) + 1 print(("{:" + "{:d}".format(precision) + "d}").format(step) + f"/{FLAGS.max_steps:d}: " + f"i_loss={stats[0].loss[0]:0.5f}, " + f"avg_loss={avg_loss:0.5f}, " + f"lr={lr:0.2e}, " + f"{rays_per_sec:0.3f} rays/sec") if step % FLAGS.save_every == 0: state_to_save = jax.device_get(jax.tree_map(lambda x: x[0], state)) checkpoints.save_checkpoint(FLAGS.train_dir, state_to_save, state_to_save.step, keep=100) # --- Train logs end --- if FLAGS.max_steps % FLAGS.save_every != 0: state = jax.device_get(jax.tree_map(lambda x: x[0], state)) checkpoints.save_checkpoint(FLAGS.train_dir, state, int(state.step), keep=100)
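Both training scripts above wrap their render function in jax.lax.all_gather inside a pmap, so every device (and therefore host 0) ends up holding all per-device outputs. A tiny illustration of that pattern in isolation:

import jax
import jax.numpy as jnp

n = jax.local_device_count()
x = jnp.arange(n * 2.0).reshape((n, 2))  # one row per local device

gathered = jax.pmap(lambda v: jax.lax.all_gather(v, "batch"),
                    axis_name="batch")(x)
# Each device now carries every device's row, so gathered has shape (n, n, 2).
assert gathered.shape == (n, n, 2)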
def main(argv): if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') # Make sure tf does not allocate gpu memory. tf.config.experimental.set_visible_devices([], 'GPU') batch_size = FLAGS.batch_size learning_rate = FLAGS.learning_rate num_train_steps = FLAGS.num_train_steps eval_freq = FLAGS.eval_frequency random_seed = FLAGS.random_seed if not FLAGS.dev: raise app.UsageError('Please provide path to dev set.') if not FLAGS.train: raise app.UsageError('Please provide path to training set.') if batch_size % jax.device_count() > 0: raise ValueError('Batch size must be divisible by the number of devices') device_batch_size = batch_size // jax.device_count() if jax.process_index() == 0: train_summary_writer = tensorboard.SummaryWriter( os.path.join(FLAGS.model_dir, FLAGS.experiment + '_train')) eval_summary_writer = tensorboard.SummaryWriter( os.path.join(FLAGS.model_dir, FLAGS.experiment + '_eval')) # create the training and development dataset vocabs = input_pipeline.create_vocabs(FLAGS.train) config = models.TransformerConfig( vocab_size=len(vocabs['forms']), output_vocab_size=len(vocabs['xpos']), max_len=FLAGS.max_length) attributes_input = [input_pipeline.CoNLLAttributes.FORM] attributes_target = [input_pipeline.CoNLLAttributes.XPOS] train_ds = input_pipeline.sentence_dataset_dict( FLAGS.train, vocabs, attributes_input, attributes_target, batch_size=batch_size, bucket_size=config.max_len) train_iter = iter(train_ds) eval_ds = input_pipeline.sentence_dataset_dict( FLAGS.dev, vocabs, attributes_input, attributes_target, batch_size=batch_size, bucket_size=config.max_len, repeat=1) model = models.Transformer(config) rng = random.PRNGKey(random_seed) rng, init_rng = random.split(rng) # call a jitted initialization function to get the initial parameter tree @jax.jit def initialize_variables(init_rng): init_batch = jnp.ones((config.max_len, 1), jnp.float32) init_variables = model.init(init_rng, inputs=init_batch, train=False) return init_variables init_variables = initialize_variables(init_rng) optimizer_def = optim.Adam(learning_rate, beta1=0.9, beta2=0.98, eps=1e-9, weight_decay=1e-1) optimizer = optimizer_def.create(init_variables['params']) optimizer = jax_utils.replicate(optimizer) learning_rate_fn = create_learning_rate_scheduler( base_learning_rate=learning_rate) p_train_step = jax.pmap( functools.partial(train_step, model=model, learning_rate_fn=learning_rate_fn), axis_name='batch') def eval_step(params, batch): """Calculate evaluation metrics on a batch.""" inputs, targets = batch['inputs'], batch['targets'] weights = jnp.where(targets > 0, 1.0, 0.0) logits = model.apply({'params': params}, inputs=inputs, train=False) return compute_metrics(logits, targets, weights) p_eval_step = jax.pmap(eval_step, axis_name='batch') # We init the first set of dropout PRNG keys, but update it afterwards inside # the main pmap'd training update for performance. 
dropout_rngs = random.split(rng, jax.local_device_count()) metrics_all = [] tick = time.time() best_dev_score = 0 for step, batch in zip(range(num_train_steps), train_iter): batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch)) # pylint: disable=protected-access optimizer, metrics, dropout_rngs = p_train_step(optimizer, batch, dropout_rng=dropout_rngs) metrics_all.append(metrics) if (step + 1) % eval_freq == 0: metrics_all = common_utils.get_metrics(metrics_all) lr = metrics_all.pop('learning_rate').mean() metrics_sums = jax.tree_map(jnp.sum, metrics_all) denominator = metrics_sums.pop('denominator') summary = jax.tree_map(lambda x: x / denominator, metrics_sums) # pylint: disable=cell-var-from-loop summary['learning_rate'] = lr logging.info('train in step: %d, loss: %.4f', step, summary['loss']) if jax.process_index() == 0: tock = time.time() steps_per_sec = eval_freq / (tock - tick) tick = tock train_summary_writer.scalar('steps per second', steps_per_sec, step) for key, val in summary.items(): train_summary_writer.scalar(key, val, step) train_summary_writer.flush() metrics_all = [] # reset metric accumulation for next evaluation cycle. eval_metrics = [] eval_iter = iter(eval_ds) for eval_batch in eval_iter: eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch) # pylint: disable=protected-access # Handle final odd-sized batch by padding instead of dropping it. cur_pred_batch_size = eval_batch['inputs'].shape[0] if cur_pred_batch_size != batch_size: # pad up to batch size eval_batch = jax.tree_map( lambda x: pad_examples(x, batch_size), eval_batch) eval_batch = common_utils.shard(eval_batch) metrics = p_eval_step(optimizer.target, eval_batch) eval_metrics.append(metrics) eval_metrics = common_utils.get_metrics(eval_metrics) eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics) eval_denominator = eval_metrics_sums.pop('denominator') eval_summary = jax.tree_map( lambda x: x / eval_denominator, # pylint: disable=cell-var-from-loop eval_metrics_sums) logging.info('eval in step: %d, loss: %.4f, accuracy: %.4f', step, eval_summary['loss'], eval_summary['accuracy']) if best_dev_score < eval_summary['accuracy']: best_dev_score = eval_summary['accuracy'] # TODO: save model. eval_summary['best_dev_score'] = best_dev_score logging.info('best development model score %.4f', best_dev_score) if jax.process_index() == 0: for key, val in eval_summary.items(): eval_summary_writer.scalar(key, val, step) eval_summary_writer.flush()
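The eval loop above rounds the final odd-sized batch up to the full batch size with pad_examples before sharding it across devices. The helper itself is not shown; a plausible minimal version, assuming it pads by repeating the last example, is:

import numpy as np

def pad_examples(x, desired_batch_size):
  # Pad the leading (batch) axis up to desired_batch_size by repeating the last row.
  batch_pad = desired_batch_size - x.shape[0]
  tile_dims = (batch_pad,) + (1,) * (x.ndim - 1)
  return np.concatenate([x, np.tile(x[-1:], tile_dims)], axis=0)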
def main(unused_argv): # Hide the GPUs and TPUs from TF so it does not reserve memory on them for # LPIPS computation or dataset loading. tf.config.experimental.set_visible_devices([], "GPU") tf.config.experimental.set_visible_devices([], "TPU") rng = random.PRNGKey(20200823) if FLAGS.config is not None: utils.update_flags(FLAGS) if FLAGS.train_dir is None: raise ValueError("train_dir must be set. None set now.") if FLAGS.data_dir is None: raise ValueError("data_dir must be set. None set now.") dataset = datasets.get_dataset("test", FLAGS) rng, key = random.split(rng) model, init_variables = models.get_model(key, dataset.peek(), FLAGS) optimizer = flax.optim.Adam(FLAGS.lr_init).create(init_variables) state = utils.TrainState(optimizer=optimizer) del optimizer, init_variables lpips_model = tf_hub.load(LPIPS_TFHUB_PATH) # Rendering is forced to be deterministic even if training was randomized, as # this eliminates "speckle" artifacts. def render_fn(variables, key_0, key_1, rays): return jax.lax.all_gather( model.apply(variables, key_0, key_1, rays, False), axis_name="batch") # pmap over only the data input. render_pfn = jax.pmap( render_fn, in_axes=(None, None, None, 0), donate_argnums=3, axis_name="batch", ) # Compiling to the CPU because it's faster and more accurate. ssim_fn = jax.jit( functools.partial(utils.compute_ssim, max_val=1.), backend="cpu") last_step = 0 out_dir = path.join(FLAGS.train_dir, "path_renders" if FLAGS.render_path else "test_preds") if not FLAGS.eval_once: summary_writer = tensorboard.SummaryWriter( path.join(FLAGS.train_dir, "eval")) while True: state = checkpoints.restore_checkpoint(FLAGS.train_dir, state) step = int(state.optimizer.state.step) if step <= last_step: continue if FLAGS.save_output and (not utils.isdir(out_dir)): utils.makedirs(out_dir) psnr_values = [] ssim_values = [] lpips_values = [] if not FLAGS.eval_once: showcase_index = np.random.randint(0, dataset.size) for idx in range(dataset.size): print(f"Evaluating {idx+1}/{dataset.size}") batch = next(dataset) pred_color, pred_disp, pred_acc = utils.render_image( functools.partial(render_pfn, state.optimizer.target), batch["rays"], rng, FLAGS.dataset == "llff", chunk=FLAGS.chunk) if jax.host_id() != 0: # Only record via host 0. 
continue if not FLAGS.eval_once and idx == showcase_index: showcase_color = pred_color showcase_disp = pred_disp showcase_acc = pred_acc if not FLAGS.render_path: showcase_gt = batch["pixels"] if not FLAGS.render_path: psnr = utils.compute_psnr(((pred_color - batch["pixels"])**2).mean()) ssim = ssim_fn(pred_color, batch["pixels"]) lpips = compute_lpips(pred_color, batch["pixels"], lpips_model) print(f"PSNR = {psnr:.4f}, SSIM = {ssim:.4f}") psnr_values.append(float(psnr)) ssim_values.append(float(ssim)) lpips_values.append(float(lpips)) if FLAGS.save_output: utils.save_img(pred_color, path.join(out_dir, "{:03d}.png".format(idx))) utils.save_img(pred_disp[Ellipsis, 0], path.join(out_dir, "disp_{:03d}.png".format(idx))) if (not FLAGS.eval_once) and (jax.host_id() == 0): summary_writer.image("pred_color", showcase_color, step) summary_writer.image("pred_disp", showcase_disp, step) summary_writer.image("pred_acc", showcase_acc, step) if not FLAGS.render_path: summary_writer.scalar("psnr", np.mean(np.array(psnr_values)), step) summary_writer.scalar("ssim", np.mean(np.array(ssim_values)), step) summary_writer.scalar("lpips", np.mean(np.array(lpips_values)), step) summary_writer.image("target", showcase_gt, step) if FLAGS.save_output and (not FLAGS.render_path) and (jax.host_id() == 0): with utils.open_file(path.join(out_dir, f"psnrs_{step}.txt"), "w") as f: f.write(" ".join([str(v) for v in psnr_values])) with utils.open_file(path.join(out_dir, f"ssims_{step}.txt"), "w") as f: f.write(" ".join([str(v) for v in ssim_values])) with utils.open_file(path.join(out_dir, f"lpips_{step}.txt"), "w") as f: f.write(" ".join([str(v) for v in lpips_values])) with utils.open_file(path.join(out_dir, "psnr.txt"), "w") as f: f.write("{}".format(np.mean(np.array(psnr_values)))) with utils.open_file(path.join(out_dir, "ssim.txt"), "w") as f: f.write("{}".format(np.mean(np.array(ssim_values)))) with utils.open_file(path.join(out_dir, "lpips.txt"), "w") as f: f.write("{}".format(np.mean(np.array(lpips_values)))) if FLAGS.eval_once: break if int(step) >= FLAGS.max_steps: break last_step = step
def main(argv): if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') tf.enable_v2_behavior() batch_size = FLAGS.batch_size learning_rate = FLAGS.learning_rate num_train_steps = FLAGS.num_train_steps num_eval_steps = FLAGS.num_eval_steps eval_freq = FLAGS.eval_frequency max_length = FLAGS.max_length random_seed = FLAGS.random_seed if not FLAGS.dev: raise app.UsageError('Please provide path to dev set.') if not FLAGS.train: raise app.UsageError('Please provide path to training set.') parameter_path = os.path.join(FLAGS.model_dir, FLAGS.experiment + '.params') if jax.host_id() == 0: train_summary_writer = tensorboard.SummaryWriter( os.path.join(FLAGS.model_dir, FLAGS.experiment + '_train')) eval_summary_writer = tensorboard.SummaryWriter( os.path.join(FLAGS.model_dir, FLAGS.experiment + '_eval')) if batch_size % jax.device_count() > 0: raise ValueError('Batch size must be divisible by the number of devices') device_batch_size = batch_size // jax.device_count() # create the training and development dataset vocabs = input_pipeline.create_vocabs(FLAGS.train) attributes_input = [input_pipeline.CoNLLAttributes.FORM] attributes_target = [input_pipeline.CoNLLAttributes.XPOS] train_ds = input_pipeline.sentence_dataset_dict( FLAGS.train, vocabs, attributes_input, attributes_target, batch_size=batch_size, bucket_size=max_length) eval_ds = input_pipeline.sentence_dataset_dict( FLAGS.dev, vocabs, attributes_input, attributes_target, batch_size=batch_size, bucket_size=max_length, repeat=1) train_iter = iter(train_ds) bs = device_batch_size * jax.device_count() rng = random.PRNGKey(random_seed) rng, init_rng = random.split(rng) input_shape = (bs, max_length) transformer_kwargs = { 'vocab_size': len(vocabs['forms']), 'output_vocab_size': len(vocabs['xpos']), 'emb_dim': 512, 'num_heads': 8, 'num_layers': 6, 'qkv_dim': 512, 'mlp_dim': 2048, 'max_len': max_length, } model = create_model(init_rng, tuple(input_shape), transformer_kwargs) optimizer = create_optimizer(model, learning_rate) del model # don't keep a copy of the initial model learning_rate_fn = create_learning_rate_scheduler( base_learning_rate=learning_rate) p_train_step = jax.pmap( functools.partial(train_step, learning_rate_fn=learning_rate_fn), axis_name='batch') p_eval_step = jax.pmap(eval_step, axis_name='batch') # We init the first set of dropout PRNG keys, but update it afterwards inside # the main pmap'd training update for performance. 
dropout_rngs = random.split(rng, jax.local_device_count()) metrics_all = [] tick = time.time() best_dev_score = 0 for step, batch in zip(range(num_train_steps), train_iter): batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch)) # pylint: disable=protected-access optimizer, metrics, dropout_rngs = p_train_step( optimizer, batch, dropout_rng=dropout_rngs) metrics_all.append(metrics) if (step + 1) % eval_freq == 0: metrics_all = common_utils.get_metrics(metrics_all) lr = metrics_all.pop('learning_rate').mean() metrics_sums = jax.tree_map(jnp.sum, metrics_all) denominator = metrics_sums.pop('denominator') summary = jax.tree_map(lambda x: x / denominator, metrics_sums) # pylint: disable=cell-var-from-loop summary['learning_rate'] = lr # Calculate (clipped) perplexity after averaging log-perplexities: summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4) logging.info('train in step: %d, loss: %.4f', step, summary['loss']) if jax.host_id() == 0: tock = time.time() steps_per_sec = eval_freq / (tock - tick) tick = tock train_summary_writer.scalar('steps per second', steps_per_sec, step) for key, val in summary.items(): train_summary_writer.scalar(key, val, step) train_summary_writer.flush() # reset metric accumulation for next evaluation cycle. metrics_all = [] eval_metrics = [] eval_iter = iter(eval_ds) if num_eval_steps == -1: num_iter = itertools.repeat(1) else: num_iter = range(num_eval_steps) for _, eval_batch in zip(num_iter, eval_iter): eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch) # pylint: disable=protected-access # Handle final odd-sized batch by padding instead of dropping it. cur_pred_batch_size = eval_batch['inputs'].shape[0] if cur_pred_batch_size != batch_size: logging.info('Uneven batch size %d.', cur_pred_batch_size) eval_batch = jax.tree_map( lambda x: pad_examples(x, batch_size), eval_batch) eval_batch = common_utils.shard(eval_batch) metrics = p_eval_step(optimizer.target, eval_batch) eval_metrics.append(metrics) eval_metrics = common_utils.get_metrics(eval_metrics) eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics) eval_denominator = eval_metrics_sums.pop('denominator') eval_summary = jax.tree_map( lambda x: x / eval_denominator, # pylint: disable=cell-var-from-loop eval_metrics_sums) # Calculate (clipped) perplexity after averaging log-perplexities: eval_summary['perplexity'] = jnp.clip( jnp.exp(eval_summary['loss']), a_max=1.0e4) logging.info('eval in step: %d, loss: %.4f, accuracy: %.4f', step, eval_summary['loss'], eval_summary['accuracy']) if best_dev_score < eval_summary['accuracy']: best_dev_score = eval_summary['accuracy'] # TODO: save model. eval_summary['best_dev_score'] = best_dev_score logging.info('best development model score %.4f', best_dev_score) if jax.host_id() == 0: for key, val in eval_summary.items(): eval_summary_writer.scalar(key, val, step) eval_summary_writer.flush()
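The summaries above are produced by summing per-batch metrics, dividing by the pooled denominator (token count), and exponentiating the averaged log-loss into a clipped perplexity. A small worked example of that reduction with made-up numbers:

import jax
import jax.numpy as jnp

# Two fake batches of summed metrics: total log-loss (nats) and token counts.
metrics = [{"loss": jnp.array(120.0), "denominator": jnp.array(100.0)},
           {"loss": jnp.array(90.0), "denominator": jnp.array(60.0)}]

sums = jax.tree_map(lambda *xs: sum(xs), *metrics)
denominator = sums.pop("denominator")
summary = jax.tree_map(lambda x: x / denominator, sums)  # 210 / 160 = 1.3125 nats/token
summary["perplexity"] = jnp.clip(jnp.exp(summary["loss"]), a_max=1.0e4)  # ~3.7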
def main(_): tf.enable_v2_behavior() tf.random.set_seed(FLAGS.seed) np.random.seed(FLAGS.seed) random.seed(FLAGS.seed) # BOS special attention only makes sense if we are using relative attention # and it's not the baseline. if FLAGS.bos_special_attention and (not FLAGS.use_relative_attention or FLAGS.attention_mask_type == 'baseline'): raise ValueError( "bos_special_attention doesn't work when use_relative_attention={} and " 'attention_mask_type={}'.format(FLAGS.use_relative_attention, FLAGS.attention_mask_type)) if not gfile.isdir(FLAGS.save_dir): gfile.makedirs(FLAGS.save_dir) hparam_str_dict = json.loads(FLAGS.xm_parameters) hparam_str = ','.join([ '%s=%s' % (shorten(k), str(hparam_str_dict[k])) for k in hparam_str_dict.keys() ]) # Number of local devices for this host. n_devices = jax.local_device_count() if jax.host_id() == 0: summary_writer = tensorboard.SummaryWriter( os.path.join(FLAGS.save_dir, 'tb', hparam_str)) batch_size = FLAGS.per_device_batch_size * n_devices io_shape = (FLAGS.per_device_batch_size, FLAGS.num_strings_per_task, FLAGS.max_characters) predict_io_shape = (FLAGS.per_device_batch_size, FLAGS.num_strings_per_task, FLAGS.predict_max_characters) program_shape = (FLAGS.per_device_batch_size, FLAGS.max_program_length) # Setup DSL # --------------------------------------------------------------------------- # Build token tables. id_char_table = {i + 1: char for (i, char) in enumerate(dsl.CHARACTER)} char_id_table = {char: id for id, char in id_char_table.items()} id_token_table, token_id_table = dsl_tokens.build_token_tables() io_vocab_size = len(char_id_table) + 1 # For padding. program_vocab_size = len(token_id_table) + 1 bos_token = token_id_table[dsl.BOS] eos_token = token_id_table[dsl.EOS] # Parse io and program token sequences (for eval). def decode_io(inputs, outputs): """Decode io examples tokens.""" def decode_str(s): """Decode string tokens.""" return ''.join([id_char_table[c_id] for c_id in s if c_id > 0]) inps, outs = [], [] for inp, out in zip(inputs, outputs): inps.append(decode_str(inp)) outs.append(decode_str(out)) return inps, outs def decode_program(program): """Decode program tokens.""" program = program[:np.argmax(program == eos_token) + 1].astype( np.int32) program = program[program != bos_token] try: return dsl.decode_program(program.tolist(), id_token_table) except: # pylint: disable=bare-except return None # Program does not compile. # Load Dataset # --------------------------------------------------------------------------- logging.info('Initializing dataset.') if not FLAGS.dataset_filepattern: raise ValueError('Must specify filepattern to dataset.') # Training dataset. logging.info('Loading dataset from %s', FLAGS.dataset_filepattern) padded_shapes = (io_shape[1:], io_shape[1:], program_shape[1:]) logging.info('padded_shapes: %s', padded_shapes) dataset = input_pipeline.create_dataset_from_tf_record( FLAGS.dataset_filepattern, token_id_table, char_id_table) dataset = dataset.padded_batch(batch_size, padded_shapes=padded_shapes, drop_remainder=True) # Split evaluation and training. eval_ds = dataset.take(FLAGS.num_eval_steps) # Decrease batch of predict dataset to handle beam search. 
predict_padded_shapes = (predict_io_shape[1:], predict_io_shape[1:], program_shape[1:]) logging.info('predict_padded_shapes: %s', predict_padded_shapes) predict_ds = eval_ds.unbatch().padded_batch( int(np.ceil(batch_size / 10)), padded_shapes=predict_padded_shapes) train_ds = dataset.skip(FLAGS.num_eval_steps) if FLAGS.train_set_batches > 0: train_ds = train_ds.take(FLAGS.train_set_batches) train_ds = train_ds.repeat() test_dataset = input_pipeline.create_dataset_from_tf_record( FLAGS.test_dataset_filepattern, token_id_table, char_id_table) test_dataset = test_dataset.padded_batch( batch_size, padded_shapes=predict_padded_shapes, drop_remainder=False) quick_test_dataset = (test_dataset.take( FLAGS.num_quick_test_steps).unbatch().padded_batch( int(np.ceil(batch_size / 10)), padded_shapes=predict_padded_shapes)) final_test_dataset = (test_dataset.take( FLAGS.num_final_test_steps).unbatch().padded_batch( int(np.ceil(batch_size / 10)), padded_shapes=predict_padded_shapes)) # Build Model and Optimizer # --------------------------------------------------------------------------- default_config = base_models.TransformerConfig( vocab_size=io_vocab_size, output_vocab_size=program_vocab_size) base_config = base_models.TransformerConfig( vocab_size=io_vocab_size, output_vocab_size=program_vocab_size, shift=True, emb_dim=FLAGS.embedding_dim, num_heads=FLAGS.num_heads, num_layers=FLAGS.num_layers, qkv_dim=FLAGS.embedding_dim, mlp_dim=FLAGS.hidden_dim, max_len=max(FLAGS.max_characters, FLAGS.max_program_length), dropout_rate=FLAGS.dropout_rate, attention_dropout_rate=FLAGS.attention_dropout_rate, use_relative_attention=FLAGS.use_relative_attention, deterministic=False, decode=False, bos_token=bos_token, num_input_relative_position_buckets=FLAGS.num_position_buckets, max_input_distance=min(FLAGS.max_distance, default_config.max_input_distance), num_output_relative_position_buckets=FLAGS.num_position_buckets, max_output_distance=min(FLAGS.max_distance, default_config.max_output_distance), num_input_cross_output_relative_position_buckets=( FLAGS.num_position_buckets), max_input_cross_output_distance=min( FLAGS.max_distance, default_config.max_input_cross_output_distance), num_program_relative_position_buckets=FLAGS.num_position_buckets, max_program_distance=min(FLAGS.max_distance, default_config.max_program_distance), num_program_cross_embed_relative_position_buckets=( FLAGS.num_position_buckets), max_program_cross_embed_distance=min( FLAGS.max_distance, default_config.max_program_cross_embed_distance), bidirectional_program_attention=FLAGS.bidirectional_program_attention) train_config = models.DecomposeAttentionTransformerConfig( base_config=base_config, attention_mask_type=FLAGS.attention_mask_type, bos_special_attention=FLAGS.bos_special_attention) eval_config = models.DecomposeAttentionTransformerConfig( base_config=base_config.replace(deterministic=True), attention_mask_type=FLAGS.attention_mask_type, bos_special_attention=FLAGS.bos_special_attention) predict_config = models.DecomposeAttentionTransformerConfig( base_config=base_config.replace( shift=False, deterministic=True, decode=not FLAGS.slow_decode, max_len=max(FLAGS.max_characters, FLAGS.max_program_length, FLAGS.predict_max_characters)), attention_mask_type=FLAGS.attention_mask_type, bos_special_attention=FLAGS.bos_special_attention) rng = jax.random.PRNGKey(FLAGS.seed) rng = jax.random.fold_in(rng, jax.host_id()) rng, init_rng = jax.random.split(rng) dropout_rng = jax.random.split(rng, jax.local_device_count()) del rng m = 
models.DecomposeAttentionTransformer(eval_config) initial_variables = jax.jit(m.init)(init_rng, jnp.ones(io_shape, jnp.float32), jnp.ones(io_shape, jnp.float32), jnp.ones(program_shape, jnp.float32)) optimizer_def = optim.Adam(FLAGS.lr, beta1=0.9, beta2=0.98, eps=1e-9, weight_decay=FLAGS.weight_decay) optimizer = optimizer_def.create(initial_variables['params']) del initial_variables # Don't keep a copy of the initial model. start_step = 0 if FLAGS.restore_checkpoints: # Restore unreplicated optimizer + model state from last checkpoint. optimizer = checkpoints.restore_checkpoint( os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer) # Grab last step. start_step = int(optimizer.state.step) logging.info('Found model checkpointed at step %d.', start_step) if FLAGS.finetune_start_step > 0: logging.info( 'Checking that start_step (%s) == finetune_start_step (%s)', start_step, FLAGS.finetune_start_step) assert start_step >= FLAGS.finetune_start_step steps_to_skip = start_step - FLAGS.finetune_start_step else: steps_to_skip = start_step # TODO(kshi): It is likely that this code can lead to the job stalling for # 10+ hours when restarting from a checkpoint that had been trained a long # time, possibly because dataset skipping is slow. logging.info('Skipping %s steps...', steps_to_skip) train_ds = train_ds.skip(steps_to_skip) dummy_p_train_step = jax.pmap( lambda dropout_rng: jax.random.split(dropout_rng)[1]) for _ in range(steps_to_skip): dropout_rng = dummy_p_train_step(dropout_rng) logging.info('Finished skipping steps') logging.info('Host %s has dropout_rng = %s', jax.host_id(), dropout_rng) # Replicate optimizer. optimizer = jax_utils.replicate(optimizer) # TODO(jxihong): Implement fast decoding. assert FLAGS.slow_decode, 'Fast decoding is not implemented yet.' if FLAGS.finetune_start_step <= 0: learning_rate_fn = create_learning_rate_scheduler( base_learning_rate=FLAGS.lr) else: # Constant LR for finetuning. learning_rate_fn = create_learning_rate_scheduler( base_learning_rate=FLAGS.lr, factors='constant') p_train_step = jax.pmap(functools.partial( train_step, learning_rate_fn=learning_rate_fn, config=train_config), axis_name='batch') p_eval_step = jax.pmap(functools.partial(eval_step, eos_token=eos_token, config=eval_config), axis_name='batch') p_init_cache = jax.pmap(functools.partial( initialize_cache, max_decode_len=FLAGS.max_program_length, config=predict_config), axis_name='batch') p_pred_step = jax.pmap(functools.partial( predict_step, eos_token=eos_token, max_decode_len=FLAGS.max_program_length, config=predict_config, slow_decode=FLAGS.slow_decode), axis_name='batch', static_broadcasted_argnums=(4, )) # Main Train Loop # --------------------------------------------------------------------------- logging.info('Starting training!') metrics_all = [] tick = time.time() train_iter = train_ds.as_numpy_iterator() for step in range(start_step, FLAGS.num_train_steps): inputs, outputs, programs = common_utils.shard(next(train_iter)) optimizer, metrics, dropout_rng = p_train_step(optimizer, inputs, outputs, programs, dropout_rng=dropout_rng) metrics_all.append(metrics) is_last_step = step == FLAGS.num_train_steps - 1 # Save a Checkpoint if (step % FLAGS.checkpoint_freq == 0 and step > 0) or is_last_step: if jax.host_id() == 0: # Save unreplicated optimizer + model state. checkpoints.save_checkpoint( os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), jax_utils.unreplicate(optimizer), step) # Periodic metric handling. 
# Training Metrics if (step and step % FLAGS.log_freq == 0) or is_last_step: logging.info('Gathering training metrics.') metrics_all = common_utils.get_metrics(metrics_all) lr = metrics_all.pop('learning_rate').mean() metrics_sums = jax.tree_map(jnp.sum, metrics_all) denominator = metrics_sums.pop('denominator') summary = jax.tree_map( lambda x: x / denominator, # pylint: disable=cell-var-from-loop metrics_sums) summary['learning_rate'] = lr # Calculate (clipped) perplexity after averaging log-perplexities: summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4) if jax.host_id() == 0: logging.info('Train in step: %d, loss: %.4f', step, summary['loss']) tock = time.time() steps_per_sec = FLAGS.log_freq / (tock - tick) tick = tock summary_writer.scalar('train/steps per second', steps_per_sec, step) for key, val in summary.items(): summary_writer.scalar('train/' + key, val, step) summary_writer.flush() # Reset metric accumulation for next evaluation cycle. metrics_all = [] # Evaluation Metrics if (step and step % FLAGS.eval_freq == 0) or is_last_step: logging.info('Gathering evaluation metrics.') t_evaluation_start = time.time() eval_metrics = [] for batches in eval_ds.as_numpy_iterator(): inputs, outputs, programs = common_utils.shard(batches) metrics = p_eval_step(optimizer.target, inputs, outputs, programs) eval_metrics.append(metrics) eval_metrics = common_utils.get_metrics(eval_metrics) eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics) eval_denominator = eval_metrics_sums.pop('denominator') eval_summary = jax.tree_map( lambda x: x / eval_denominator, # pylint: disable=cell-var-from-loop eval_metrics_sums) if jax.host_id() == 0: logging.info('Evaluation time: %.4f s step %d, loss: %.4f.', time.time() - t_evaluation_start, step, eval_summary['loss']) for key, val in eval_summary.items(): summary_writer.scalar('eval/' + key, val, step) summary_writer.flush() # Beam search metrics. if (step and step % FLAGS.predict_freq == 0) or is_last_step: logging.info('Gathering beam search metrics.') test_ds = final_test_dataset if is_last_step else quick_test_dataset for dataset, predict_or_test in [(predict_ds, 'predict'), (test_ds, 'test')]: for beam_size in [1, 10]: t_inference_start = time.time() total_successes = 0 total_denominator = 0 pred_successes = collections.defaultdict(int) pred_denominators = collections.defaultdict(int) ios, targets, predictions, top_of_beams = [], [], [], [] for batches in dataset.as_numpy_iterator(): pred_batch = batches # Handle final odd-sized batch by padding instead of dropping it. cur_pred_batch_size = pred_batch[0].shape[0] if cur_pred_batch_size % n_devices: padded_size = int( np.ceil(cur_pred_batch_size / n_devices) * n_devices) # pylint: disable=cell-var-from-loop pred_batch = jax.tree_map( lambda x: pad_examples(x, padded_size), pred_batch) inputs, outputs, programs = common_utils.shard( pred_batch) cache = (p_init_cache(inputs, outputs, programs) if not FLAGS.slow_decode else None) predicted = p_pred_step(optimizer.target, inputs, outputs, cache, beam_size) predicted = tohost(predicted) inputs, outputs, programs = map( tohost, (inputs, outputs, programs)) for i, beams in enumerate(predicted): inps, outs = decode_io(inputs[i], outputs[i]) p, p_score = eval_predicted( beams, inps, outs, parse_beam_fn=decode_program) # Split by length of program. 
program = programs[i] num_expressions = len( decode_program(program).expressions) pred_denominators[num_expressions] += 1 total_denominator += 1 if p_score >= len(inps): pred_successes[num_expressions] += 1 total_successes += 1 ios.append(' ; '.join(map(str, zip(inps, outs)))) targets.append( decode_program(programs[i]).to_string()) try: predictions.append(p.to_string()) except: # pylint: disable=bare-except predictions.append('Did not compile') logging.info('ios: %s', ios[-1]) logging.info('target: %s', targets[-1]) beams_log = [] for beam in beams: try: beams_log.append( decode_program(beam).to_string()) except: # pylint: disable=bare-except beams_log.append('Did not compile') logging.info('predicted beam: %s', '\n'.join(beams_log)) top_of_beam = [] for index, beam in enumerate(beams[:-5:-1]): try: decoded_program = decode_program( beam).to_string() except: # pylint: disable=bare-except decoded_program = 'Did not compile' top_of_beam.append( 'index: {}, decoded: {}, tokens: {}'. format(index, decoded_program, beam)) top_of_beams.append('\n\n'.join(top_of_beam)) all_total_successes, all_total_denominator = per_host_sum_pmap( jax.tree_map(np.array, (total_successes, total_denominator))) all_pred_successes, all_pred_denominators = per_host_sum_pmap( jax.tree_map(np.array, (pred_successes, pred_denominators))) # Record beam search results as text summaries. message = [] for n in np.random.choice(np.arange(len(predictions)), 8): text = (f'ios: {ios[n]}\n\ntarget: {targets[n]}\n\n' f'predicted: {predictions[n]}\n\n' f'top of beam:\n\n{top_of_beams[n]}\n\n') message.append(text) # Write to tensorboard. if jax.host_id() == 0: accuracy = 100 * all_total_successes / all_total_denominator logging.info( '%s results, step %d, beam size %d: %s / %s = %.2f%% (%.2f s)', predict_or_test, step, beam_size, all_total_successes, all_total_denominator, accuracy, time.time() - t_inference_start) summary_writer.scalar( '{}/beam-size-{}'.format(predict_or_test, beam_size), accuracy, step) for length in sorted(all_pred_successes.keys()): this_length_accuracy = ( 100 * all_pred_successes[length] / all_pred_denominators[length]) logging.info( ' accuracy for length %s: %s / %s = %.2f%%', length, all_pred_successes[length], all_pred_denominators[length], this_length_accuracy) summary_writer.scalar( '{}-by-length/beam-size-{}-length-{}'.format( predict_or_test, beam_size, length), this_length_accuracy, step) summary_writer.text( '{}-samples-beam-{}'.format( predict_or_test, beam_size), '\n------\n'.join(message), step) summary_writer.flush()
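The beam-search accuracies above are combined across hosts with per_host_sum_pmap before being logged on host 0. That helper is not shown here; one way such a cross-host sum can be built, as a sketch rather than the actual implementation, is to run a psum over one device per host:

import collections
import jax
import numpy as np

def per_host_sum_sketch(in_tree):
  # Pick one device per host so each host contributes its value exactly once.
  host_to_devices = collections.defaultdict(list)
  for d in jax.devices():
    host_to_devices[d.process_index].append(d)
  one_per_host = [devs[0] for devs in host_to_devices.values()]
  host_psum = jax.pmap(lambda x: jax.lax.psum(x, "hosts"), "hosts",
                       devices=one_per_host)
  expand = lambda x: np.broadcast_to(x, (1,) + np.shape(x))
  summed = host_psum(jax.tree_map(expand, in_tree))
  return jax.tree_map(lambda x: x[0], summed)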
from rigl.experimental.jax.datasets import dataset_factory
from rigl.experimental.jax.models import model_factory
from rigl.experimental.jax.training import training
from rigl.experimental.jax.utils import utils

experiment_dir = path.join(FLAGS.experiment_dir, str(work_unit_id))
logging.info('Saving experimental results to %s', experiment_dir)

host_count = jax.host_count()
local_device_count = jax.local_device_count()
logging.info('Device count: %d, host count: %d, local device count: %d',
             jax.device_count(), host_count, local_device_count)

if jax.host_id() == 0:
  summary_writer = tensorboard.SummaryWriter(experiment_dir)

dataset = dataset_factory.create_dataset(
    FLAGS.dataset,
    FLAGS.batch_size,
    FLAGS.batch_size_test,
    shuffle_buffer_size=FLAGS.shuffle_buffer_size)

logging.info('Training %s on the %s dataset...', FLAGS.model, FLAGS.dataset)

rng = jax.random.PRNGKey(FLAGS.random_seed)
input_shape = (1,) + dataset.shape
base_model, _ = model_factory.create_model(
    FLAGS.model, rng, ((input_shape, jnp.float32),),
def main(_): tf.enable_v2_behavior() tf.random.set_seed(FLAGS.seed) np.random.seed(FLAGS.seed) random.seed(FLAGS.seed) if not gfile.isdir(FLAGS.save_dir): gfile.mkdir(FLAGS.save_dir) hparam_str_dict = dict(seed=FLAGS.seed, lr=FLAGS.lr) # Get hyperparmaters if FLAGS.xm_parameters: for key, value in json.loads(FLAGS.xm_parameters).items(): if key not in hparam_str_dict: hparam_str_dict[key] = value hparam_str = ','.join([ '%s=%s' % (k, str(hparam_str_dict[k])) for k in sorted(hparam_str_dict.keys()) ]) # Number of local devices for this host. n_devices = jax.local_device_count() if jax.host_id() == 0: summary_writer = tensorboard.SummaryWriter( os.path.join(FLAGS.save_dir, 'tb', hparam_str)) batch_size = FLAGS.per_device_batch_size * n_devices io_shape = (FLAGS.per_device_batch_size, FLAGS.num_strings_per_task, FLAGS.max_characters) program_shape = (FLAGS.per_device_batch_size, FLAGS.max_program_length) # Setup DSL # --------------------------------------------------------------------------- # Build token tables. id_char_table = {i + 1: char for (i, char) in enumerate(dsl.CHARACTER)} char_id_table = {char: id for id, char in id_char_table.items()} id_token_table, token_id_table = dsl_tokens.build_token_tables() io_vocab_size = len(char_id_table) + 1 # For padding. program_vocab_size = len(token_id_table) + 1 bos_token = token_id_table[dsl.BOS] eos_token = token_id_table[dsl.EOS] def decode_io(inputs, outputs): """Decode io examples tokens.""" def decode_str(s): """Decode string tokens.""" return ''.join([id_char_table[c_id] for c_id in s if c_id > 0]) io_string = '' inps, outs = [], [] for inp, out in zip(inputs, outputs): inps.append(decode_str(inp)) outs.append(decode_str(out)) io_string += inps[-1] + ' < ' + outs[-1] + ' > ' return inps, outs, io_string[:-3] # Remove last separator. def decode_program(program): """Decode program tokens.""" program = program[:np.argmax(program == eos_token) + 1].astype( np.int32) try: p = dsl.decode_program(program, id_token_table) return p, p.to_string() except: # pylint: disable=bare-except return None, '' # Program does not compile. # Load Dataset # --------------------------------------------------------------------------- logging.info('Initializing dataset.') if not FLAGS.dataset_filepattern: raise ValueError('Must specify filepattern to dataset.') # Training dataset. dataset = input_pipeline.create_dataset_from_tf_record( FLAGS.dataset_filepattern, token_id_table, char_id_table) dataset = dataset.padded_batch(batch_size, padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:]), drop_remainder=True) # Split evaluation and training. eval_ds = dataset.take(FLAGS.num_eval_steps) # Decrease batch of predict dataset to handle beam search. 
predict_ds = eval_ds.unbatch().padded_batch( int(np.ceil(batch_size / 10)), padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:])) train_ds = dataset.skip(FLAGS.num_eval_steps).repeat() train_iter = train_ds.as_numpy_iterator() # Build Model and Optimizer # --------------------------------------------------------------------------- train_config = models.TransformerConfig( vocab_size=io_vocab_size, output_vocab_size=program_vocab_size, shift=True, emb_dim=FLAGS.embedding_dim, num_heads=FLAGS.num_heads, num_layers=FLAGS.num_layers, qkv_dim=FLAGS.embedding_dim, mlp_dim=FLAGS.hidden_dim, max_len=max(FLAGS.max_characters, FLAGS.max_program_length), use_relative_attention=FLAGS.use_relative_attention, deterministic=False, decode=False, bos_token=bos_token) eval_config = train_config.replace(deterministic=True) predict_config = train_config.replace(shift=False, deterministic=True, decode=True) rng = jax.random.PRNGKey(FLAGS.seed) rng = jax.random.fold_in(rng, jax.host_id()) rng, init_rng = jax.random.split(rng) m = models.ProgramTransformer(eval_config) initial_variables = jax.jit(m.init)(init_rng, jnp.ones(io_shape, jnp.float32), jnp.ones(io_shape, jnp.float32), jnp.ones(program_shape, jnp.float32)) optimizer_def = optim.Adam(FLAGS.lr, beta1=0.9, beta2=0.98, eps=1e-9, weight_decay=FLAGS.weight_decay) optimizer = optimizer_def.create(initial_variables['params']) del initial_variables # Don't keep a copy of the initial model. start_step = 0 if FLAGS.restore_checkpoints: # Restore unreplicated optimizer + model state from last checkpoint. optimizer = checkpoints.restore_checkpoint( os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer) # Grab last step. start_step = int(optimizer.state.step) logging.info('Found model checkpointed at step %d.', start_step) # Replicate optimizer. optimizer = jax_utils.replicate(optimizer) learning_rate_fn = train_lib.create_learning_rate_scheduler( base_learning_rate=FLAGS.lr) p_train_step = jax.pmap(functools.partial( train_lib.train_step, learning_rate_fn=learning_rate_fn, config=train_config), axis_name='batch') p_eval_step = jax.pmap(functools.partial(train_lib.eval_step, config=eval_config), axis_name='batch') p_init_cache = jax.pmap(functools.partial( train_lib.initialize_cache, max_decode_len=FLAGS.max_program_length, config=predict_config), axis_name='batch') p_pred_step = jax.pmap(functools.partial(train_lib.predict_step, config=predict_config), axis_name='batch', static_broadcasted_argnums=(4, 5, 6)) # Main Train Loop # --------------------------------------------------------------------------- train_rngs = jax.random.split(rng, jax.local_device_count()) del rng metrics_all = [] tick = time.time() for step in range(start_step, FLAGS.num_train_steps): inputs, outputs, programs = common_utils.shard(next(train_iter)) optimizer, metrics, train_rngs = p_train_step(optimizer, inputs, outputs, programs, train_rng=train_rngs) metrics_all.append(metrics) # Save a Checkpoint if ((step % FLAGS.checkpoint_freq == 0 and step > 0) or step == FLAGS.num_train_steps - 1): if jax.host_id() == 0: # Save unreplicated optimizer + model state. checkpoints.save_checkpoint( os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), jax_utils.unreplicate(optimizer), step) # Periodic metric handling. 
if not step or step % FLAGS.log_freq != 0: continue logging.info('Gathering training metrics.') # Training Metrics metrics_all = common_utils.get_metrics(metrics_all) lr = metrics_all.pop('learning_rate').mean() metrics_sums = jax.tree_map(jnp.sum, metrics_all) denominator = metrics_sums.pop('denominator') summary = jax.tree_map( lambda x: x / denominator, # pylint: disable=cell-var-from-loop metrics_sums) summary['learning_rate'] = lr # Calculate (clipped) perplexity after averaging log-perplexities: summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4) if jax.host_id() == 0: logging.info('Train in step: %d, loss: %.4f', step, summary['loss']) tock = time.time() steps_per_sec = FLAGS.log_freq / (tock - tick) tick = tock summary_writer.scalar('train/steps per second', steps_per_sec, step) for key, val in summary.items(): summary_writer.scalar('train/' + key, val, step) summary_writer.flush() # Reset metric accumulation for next evaluation cycle. metrics_all = [] # Evaluation Metrics logging.info('Gathering evaluation metrics.') t_evaluation_start = time.time() eval_metrics = [] for batches in eval_ds.as_numpy_iterator(): inputs, outputs, programs = common_utils.shard(batches) metrics = p_eval_step(optimizer.target, inputs, outputs, programs) eval_metrics.append(metrics) eval_metrics = common_utils.get_metrics(eval_metrics) eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics) eval_denominator = eval_metrics_sums.pop('denominator') eval_summary = jax.tree_map( lambda x: x / eval_denominator, # pylint: disable=cell-var-from-loop eval_metrics_sums) if jax.host_id() == 0: logging.info('Evaluation time: %.4f s step %d, loss: %.4f.', time.time() - t_evaluation_start, step, eval_summary['loss']) for key, val in eval_summary.items(): summary_writer.scalar('eval/' + key, val, step) summary_writer.flush() # Beam search metrics. logging.info('Gathering beam search metrics.') for beam_size in [10, 100]: t_inference_start = time.time() pred_acc = 0 pred_denominator = 0 ios, targets, predictions = [], [], [] for batches in predict_ds.as_numpy_iterator(): pred_batch = batches # Handle final odd-sized batch by padding instead of dropping it. cur_pred_batch_size = pred_batch[0].shape[0] if cur_pred_batch_size % n_devices: padded_size = int( np.ceil(cur_pred_batch_size / n_devices) * n_devices) # pylint: disable=cell-var-from-loop pred_batch = jax.tree_map( lambda x: train_lib.pad_examples(x, padded_size), pred_batch) inputs, outputs, programs = common_utils.shard(pred_batch) cache = p_init_cache(inputs, outputs, programs) predicted = p_pred_step(optimizer.target, inputs, outputs, cache, eos_token, programs.shape[-1], beam_size) predicted = train_lib.tohost(predicted) inputs, outputs, programs = map(train_lib.tohost, (inputs, outputs, programs)) pred_denominator += programs.shape[0] for i, beams in enumerate(predicted): inps, outs, io_string = decode_io(inputs[i], outputs[i]) p, p_score = train_lib.eval_predicted( beams, inps, outs, parse_beam_fn=lambda x: decode_program(x)[0]) if p_score >= len(inps): pred_acc += 1 ios.append(io_string) targets.append(decode_program(programs[i])[1]) predictions.append(p.to_string() if p else '') all_pred_acc, all_pred_denominator = train_lib.per_host_sum_pmap( jax.tree_map(np.array, (pred_acc, pred_denominator))) # Record beam search results as text summaries. 
      message = []
      for n in np.random.choice(np.arange(len(predictions)), 8):
        text = (f'ios: {ios[n]}\n\ntarget: {targets[n]}\n\n'
                f'predicted: {predictions[n]}\n\n')
        message.append(text)

      # Write to tensorboard.
      if jax.host_id() == 0:
        logging.info('Prediction time (beam %d): %.4f s step %d, score %.4f.',
                     beam_size, time.time() - t_inference_start, step,
                     all_pred_acc / all_pred_denominator)
        summary_writer.scalar('predict/score-{}'.format(beam_size),
                              all_pred_acc / all_pred_denominator, step)
        summary_writer.text('samples-{}'.format(beam_size),
                            '\n------\n'.join(message), step)
        summary_writer.flush()