def save_state(state, output_dir, keep=False):
  """Save State and optionally gin config."""
  # TODO(gilmer, lukaszkaiser): figure out how to use cloudpickle in python3.
  # Currently the code throws an error when run in python3.
  if sys.version_info[0] < 3:
    pkl_module = cloudpickle
  else:
    pkl_module = pickle
  params_file = os.path.join(output_dir, "model.pkl")
  params = jax.unreplicate(state.params)
  with gfile.GFile(params_file, "wb") as f:
    pkl_module.dump((params, state.step, state.history), f)
  if keep:
    params_file = os.path.join(output_dir, "model_{}.pkl".format(state.step))
    with gfile.GFile(params_file, "wb") as f:
      pkl_module.dump((params, state.step, state.history), f)
  log("Model saved to %s" % params_file, stdout=False)
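# The train() loop below reads checkpoints back via restore_state(output_dir),
# which is not shown here. A minimal sketch of a loader matching the
# (params, step, history) pickle layout written above follows; the names
# gfile.exists and trax_history.History are assumptions about the surrounding
# module, not part of the original code.
def restore_state(output_dir):
  """Sketch: restore a State written by save_state, or a fresh one."""
  params_file = os.path.join(output_dir, "model.pkl")
  if not gfile.exists(params_file):
    # No checkpoint yet: start from scratch with empty params and history.
    return State(params=None, step=None, history=trax_history.History())
  with gfile.GFile(params_file, "rb") as f:
    (params, step, history) = pickle.load(f)
  log("Model loaded from %s" % params_file, stdout=False)
  return State(params=params, step=step, history=history)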
def train(output_dir,
          model=gin.REQUIRED,
          loss_fun=loss,
          inputs=trax_inputs.inputs,
          optimizer=trax_opt.adam,
          lr_schedule=lr.MultifactorSchedule,
          train_steps=1000,
          eval_steps=10,
          eval_frequency=100,
          num_devices=None,
          random_seed=None,
          run_debug_step=False):
  """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train as a callable returning 2 callables, an
      init_fun and apply_fun.
    loss_fun: callable with signature: params, trax.inputs.Inputs, model, rng
      -> loss.
    inputs: callable returning trax.inputs.Inputs.
    optimizer: The optimizer as a callable taking a learning_rate callable and
      returning 2 callables, opt_init and opt_update.
    lr_schedule: A learning rate schedule as a function that takes history and
      returns a function from step to learning rate (a float).
    train_steps: int, total number of training steps.
    eval_steps: int, num of steps per evaluation. If None or 0, eval disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval disabled.
    num_devices: how many devices to use (if None, default, use all available).
    random_seed: the random seed to use; time/os dependent if None (default).
    run_debug_step: bool, if True, will run the model and loss without @jit for
      one step.

  Returns:
    trax.State
  """
  num_devices = num_devices or jax.lib.xla_bridge.device_count()
  rng = get_random_number_generator_and_set_seed(random_seed)
  gfile.makedirs(output_dir)

  # Create summary writers and history.
  train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
  eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

  inputs = inputs()

  # Setup optimizer and model.
  state = restore_state(output_dir)
  history = state.history
  lr_fun = lr_schedule(history)
  opt_init, _ = optimizer(lr_fun)
  model_init, model_predict = model()

  # Setup state.
  step = state.step or 0
  rng, init_key = jax_random.split(rng)
  params_initializer = \
      lambda: model_init(init_key, [-1] + list(inputs.input_shape))[1]
  params = state.params or params_initializer()
  opt_state = opt_init(params)
  if num_devices > 1:  # TODO(lukaszkaiser): use everywhere when pmap is stable.
    opt_state = jax.replicate(opt_state)

  # jit model_predict and update so they're fast.
  jit_model_predict = backend.jit(model_predict)  # for evaluation
  jit_update_fun = _jit_update_fun(model_predict, loss_fun, optimizer, lr_fun,
                                   num_devices)

  print()
  train_stream = inputs.train_stream()
  epoch_steps = [train_steps]  # Only training if eval_frequency is 0 or None.
  if eval_frequency:
    epoch_steps = itertools.chain([1,  # first epoch only 1 step
                                   eval_frequency - 1],
                                  itertools.repeat(eval_frequency))
  step_log(step, "Starting training using %d devices" % num_devices)

  # Non-compiled debug step helps find problems in models easier.
  if run_debug_step:
    debug_loss = loss_fun(params, next(train_stream), model_predict, rng)
    step_log(step, "Debug step loss %.8f" % debug_loss)

  for epoch, epoch_steps in epochs(train_steps, epoch_steps):
    # Log separator
    print()

    # Timer
    start_time = time.time()

    for _ in range(epoch_steps):
      # Train
      next_train_batch = next(train_stream)
      if num_devices > 1:  # TODO(lukaszkaiser): use everywhere when possible.
        next_train_batch = reshape_by_device(next_train_batch, num_devices)
      rng, subrng = jax_random.split(rng)
      opt_state = jit_update_fun(step, opt_state, next_train_batch, subrng)
      step += 1

      # LR log
      if step == 1 or step % 10 == 0:
        train_sw.scalar("training/learning rate", lr_fun(step), step=step)

    # Timer
    epoch_time = time.time() - start_time
    step_log(step, "Ran %d train steps in %0.2f secs" %
             (epoch_steps, epoch_time))
    if epoch_steps > 1:
      train_sw.scalar("training/steps per second",
                      epoch_steps / epoch_time, step=step)

    # Evaluate
    if num_devices > 1:  # TODO(lukaszkaiser): remove branch when possible.
      params = trax_opt.get_params(jax.unreplicate(opt_state))
    else:
      params = trax_opt.get_params(opt_state)
    evaluate_train_and_eval(
        step=step,
        inputs=inputs,
        predict_fun=functools.partial(jit_model_predict, params),
        eval_steps=eval_steps,
        rng=rng,
        train_sw=train_sw,
        eval_sw=eval_sw,
        history=history)

    # Save state
    save_state(State(params=params, step=step, history=history), output_dir)

    # Save Gin config
    # Gin only tracks the used parameters, so we save it after the first epoch.
    if epoch == 1:
      save_gin(output_dir, train_sw)

    # Update learning rate with new history
    old_lr_fun = lr_fun
    lr_fun = lr_schedule(history)
    if lr_fun != old_lr_fun:  # For performance, only jit if there is a change.
      jit_update_fun = _jit_update_fun(model_predict, loss_fun, optimizer,
                                       lr_fun, num_devices)

    # Flush summary writers
    train_sw.writer.flush()
    eval_sw.writer.flush()

  step_log(step, "Training done")
  return State(params=params, step=step, history=history)
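# reshape_by_device, used in the multi-device branch of the training loop
# above, is not defined in this file. A minimal sketch of what it might do,
# assuming batches are pytrees of NumPy arrays whose leading dimension is the
# global batch size; the helper body below is an assumption, not the original
# implementation.
import numpy as onp  # plain NumPy, kept distinct from any jax.numpy alias

def reshape_by_device(train_batch, num_devices):
  """Sketch: split the leading batch axis into [num_devices, batch/num_devices]."""
  def _reshape(x):
    batch_size = x.shape[0]
    if batch_size % num_devices != 0:
      raise ValueError("Batch size %d is not divisible by %d devices."
                       % (batch_size, num_devices))
    # pmap-ed update functions then receive one shard per device.
    return onp.reshape(
        x, [num_devices, batch_size // num_devices] + list(x.shape[1:]))
  return jax.tree_util.tree_map(_reshape, train_batch)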
batches = data_stream()

@partial(pmap, axis_name='batch')
def spmd_update(params, batch):
  grads = grad(loss)(params, batch)
  # We compute the total gradients, summing across the device-mapped axis,
  # using the `lax.psum` SPMD primitive, which does a fast all-reduce-sum.
  grads = [(lax.psum(dw, 'batch'), lax.psum(db, 'batch')) for dw, db in grads]
  return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]

# We replicate parameters out across devices. (Check the implementation of
# `replicate`; analogous to `device_put`, it's a simple wrapper around `pmap`.)
params = replicate(init_random_params(param_scale, layer_sizes))

for epoch in range(num_epochs):
  start_time = time.time()
  for _ in range(num_batches):
    params = spmd_update(params, next(batches))
  epoch_time = time.time() - start_time

  # We evaluate using the jitted `accuracy` function (not using pmap) by
  # grabbing just one of the replicated parameter values.
  train_acc = accuracy(unreplicate(params), (train_images, train_labels))
  test_acc = accuracy(unreplicate(params), (test_images, test_labels))
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))
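# `replicate` and `unreplicate` used above are small pytree helpers whose
# implementation is not shown in this snippet (the comment describes replicate
# as a thin wrapper around pmap). The broadcast-based version below is a
# simpler, functionally equivalent sketch and an assumption, not the original
# code: add or strip a leading device axis so pmap sees one parameter copy per
# device.
import jax
import numpy as onp  # plain NumPy, kept distinct from any jax.numpy alias

def replicate(tree, num_devices=None):
  """Sketch: copy every leaf num_devices times along a new leading axis."""
  num_devices = num_devices or jax.local_device_count()
  return jax.tree_util.tree_map(
      lambda x: onp.broadcast_to(x, (num_devices,) + onp.shape(x)), tree)

def unreplicate(tree):
  """Sketch: keep the first replica of every leaf (replicas stay identical)."""
  return jax.tree_util.tree_map(lambda x: x[0], tree)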