def test_train_one_step(self):
  """Tests training loop over one step."""
  iterator = self._dataset.get_train()
  batch = next(iterator)
  state = jax_utils.replicate(self._state)
  optimizer = jax_utils.replicate(self._optimizer.create(self._model))
  self._rng, step_key = jax.random.split(self._rng)
  batch = training._shard_batch(batch)
  sharded_keys = common_utils.shard_prng_key(step_key)
  p_train_step = jax.pmap(
      functools.partial(training.train_step,
                        learning_rate_fn=self._learning_rate_fn),
      axis_name='batch')
  _, _, loss, gradient_norm = p_train_step(optimizer, batch, sharded_keys,
                                           state)
  loss = jnp.mean(loss)
  gradient_norm = jax_utils.unreplicate(gradient_norm)
  with self.subTest(name='test_loss_range'):
    self.assertBetween(loss, self._min_loss, self._max_loss)
  with self.subTest(name='test_gradient_norm'):
    self.assertGreaterEqual(gradient_norm, 0)
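# Added sketch (not part of the test above): common_utils.shard_prng_key splits
# one PRNG key into jax.local_device_count() per-device keys, so each pmap
# replica draws independent randomness. The shape check assumes the classic
# uint32 key format.
import jax
from flax.training import common_utils

key = jax.random.PRNGKey(0)
sharded = common_utils.shard_prng_key(key)
assert sharded.shape == (jax.local_device_count(), 2)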
def train_for_one_epoch(
    dataset_source: dataset_source_lib.DatasetSource,
    optimizer: flax.optim.Optimizer,
    state: flax.nn.Collection,
    prng_key: jnp.ndarray,
    pmapped_train_step: _TrainStep,
    pmapped_update_ema: Optional[_EMAUpdateStep],
    moving_averages: Optional[efficientnet_optim.ExponentialMovingAverage],
    summary_writer: tensorboard.SummaryWriter
) -> Tuple[flax.optim.Optimizer, flax.nn.Collection,
           Optional[efficientnet_optim.ExponentialMovingAverage]]:
  """Trains the model for one epoch.

  Args:
    dataset_source: Container for the training dataset.
    optimizer: The optimizer targeting the model to train.
    state: Current state associated with the model (contains the batch norm
      MA).
    prng_key: A PRNG key to use for stochasticity (e.g. for sampling an
      eventual dropout mask). Is not used for shuffling the dataset.
    pmapped_train_step: A pmapped version of the `train_step` function (see its
      documentation for more details).
    pmapped_update_ema: Function to update the parameter moving average. Can be
      None if we don't use EMA.
    moving_averages: Parameters moving average if used.
    summary_writer: A Tensorboard SummaryWriter to use to log metrics.

  Returns:
    The updated optimizer (with the associated updated model), state and
    parameter moving averages.
  """
  start_time = time.time()
  cnt = 0
  train_metrics = []
  for batch in dataset_source.get_train(use_augmentations=True):
    # Generate a PRNG key that will be rolled into the batch.
    step_key = jax.random.fold_in(prng_key, optimizer.state.step[0])
    # Load and shard the TF batch.
    batch = tensorflow_to_numpy(batch)
    batch = shard_batch(batch)
    # Shard the step PRNG key.
    sharded_keys = common_utils.shard_prng_key(step_key)
    optimizer, state, metrics, lr = pmapped_train_step(optimizer, state, batch,
                                                       sharded_keys)
    cnt += 1
    if moving_averages is not None:
      moving_averages = pmapped_update_ema(optimizer, state, moving_averages)
    train_metrics.append(metrics)
  train_metrics = common_utils.get_metrics(train_metrics)
  # Get training epoch summary for logging.
  train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)
  train_summary['learning_rate'] = lr[0]
  current_step = int(optimizer.state.step[0])
  info = 'Whole training step done in {} ({} steps)'.format(
      time.time() - start_time, cnt)
  logging.info(info)
  for metric_name, metric_value in train_summary.items():
    summary_writer.scalar(metric_name, metric_value, current_step)
  summary_writer.flush()
  return optimizer, state, moving_averages
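# Added sketch of the per-step key derivation used above: jax.random.fold_in
# deterministically mixes the step counter into the epoch key, so a restarted
# job replays the same per-step randomness without splitting a key in the loop.
import jax

epoch_key = jax.random.PRNGKey(0)
key_step_3 = jax.random.fold_in(epoch_key, 3)  # identical on every run
key_step_4 = jax.random.fold_in(epoch_key, 4)  # statistically independent of key_step_3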
def initial_state(self):
  return TrainState(
      history=self,
      rng=common_utils.shard_prng_key(
          jax.random.PRNGKey(np.random.randint(2**16))),
      step=None,
      metrics=None,
  )
def update_preconditioner(config, optimizer, p_update_grad_vars, rng, state,
                          train_iter):
  """Computes preconditioner state using samples from dataloader."""
  # TODO(basv): support multiple hosts.
  values = jax.tree_map(jnp.zeros_like, optimizer.target)
  eps = config.precon_est_eps
  n_batches = config.precon_est_batches
  for _ in range(n_batches):
    rng, est_key = jax.random.split(rng)
    batch = next(train_iter)
    batch = input_pipeline.load_and_shard_tf_batch(config, batch)
    if not config.debug_run:
      # Shard the step PRNG key.
      sharded_keys = common_utils.shard_prng_key(est_key)
    else:
      sharded_keys = est_key
    values = p_update_grad_vars(optimizer, state, batch, sharded_keys, values)
  stds = jax.tree_map(
      lambda v: jnp.sqrt(eps + (1 / n_batches) * jnp.mean(v)), values)
  std_min = jnp.min(jnp.asarray(jax.tree_leaves(stds)))
  # TODO(basv): verify preconditioner estimate.
  new_precon = jax.tree_map(lambda s, x: jnp.ones_like(x) * (s / std_min),
                            stds, optimizer.target)

  def convert_momentum(new_precon, state):
    """Converts momenta to new preconditioner."""
    if config.weight_norm == 'learned':
      state = state.direction_state
    old_precon = state.preconditioner
    momentum = state.momentum
    m_c = jnp.power(old_precon, -.5) * momentum
    m = jnp.power(new_precon, .5) * m_c
    return m

  # TODO(basv): verify momentum convert.
  new_momentum = jax.tree_map(convert_momentum, new_precon,
                              optimizer.state.param_states)
  # TODO(basv): verify this is replaced correctly, check replicated.
  optimizer = replace_param_state(
      config, optimizer, preconditioner=new_precon, momentum=new_momentum)
  return optimizer, rng
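# Added numeric sketch of convert_momentum above (scalar values are
# illustrative): momenta leave the old preconditioned space via
# old_precon**-0.5 and enter the new one via new_precon**0.5, i.e.
# m_new = new_precon**0.5 * old_precon**-0.5 * m_old.
import jax.numpy as jnp

old_precon, new_precon, m_old = jnp.array(4.0), jnp.array(1.0), jnp.array(2.0)
m_c = jnp.power(old_precon, -0.5) * m_old  # 1.0, preconditioner-free momentum
m_new = jnp.power(new_precon, 0.5) * m_c   # 1.0, momentum under the new preconditioner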
def train_for_one_epoch(dataset_source, optimizer, state, prng_key,
                        pmapped_train_step, summary_writer):
  """Trains the model for one epoch.

  Args:
    dataset_source: Container for the training dataset.
    optimizer: The optimizer targeting the model to train.
    state: Current state associated with the model (contains the batch norm
      MA).
    prng_key: A PRNG key to use for stochasticity (e.g. for sampling an
      eventual dropout mask). Is not used for shuffling the dataset.
    pmapped_train_step: A pmapped version of the `train_step` function (see its
      documentation for more details).
    summary_writer: A Tensorboard SummaryWriter to use to log metrics.

  Returns:
    The updated optimizer (with the associated updated model), state and
    PRNG key.
  """
  train_metrics = []
  for batch in dataset_source.get_train(use_augmentations=True):
    # Generate a PRNG key that will be rolled into the batch.
    step_key, prng_key = jax.random.split(prng_key)
    # Load and shard the TF batch.
    batch = tensorflow_to_numpy(batch)
    batch = shard_batch(batch)
    # Shard the step PRNG key.
    sharded_keys = common_utils.shard_prng_key(step_key)
    optimizer, state, metrics, lr = pmapped_train_step(optimizer, state, batch,
                                                       sharded_keys)
    train_metrics.append(metrics)
  train_metrics = common_utils.get_metrics(train_metrics)
  # Get training epoch summary for logging.
  train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)
  train_summary['learning_rate'] = lr[0]
  current_step = int(optimizer.state.step[0])
  for metric_name, metric_value in train_summary.items():
    summary_writer.scalar(metric_name, metric_value, current_step)
  summary_writer.flush()
  return optimizer, state, prng_key
async def dalle(self, ctx: commands.Context, *, prompt: str):
  # Use the prompt supplied by the user (the original snippet ignored it in
  # favor of two hard-coded example prompts).
  prompts = [prompt]
  tokenized_prompts = processor(prompts)
  tokenized_prompt = replicate(tokenized_prompts)
  # Seed a PRNG key for generation (the key was undefined in the original
  # snippet; this initialization is an assumption).
  key = jax.random.PRNGKey(random.randrange(2**32))
  # generate images
  images = []
  for i in trange(max(n_predictions // jax.device_count(), 1)):
    # get a new key
    key, subkey = jax.random.split(key)
    # generate images
    encoded_images = p_generate(
        tokenized_prompt,
        shard_prng_key(subkey),
        params,
        gen_top_k,
        gen_top_p,
        temperature,
        cond_scale,
    )
    # remove BOS
    encoded_images = encoded_images.sequences[..., 1:]
    # decode images
    decoded_images = p_decode(encoded_images, vqgan_params)
    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
    for decoded_img in decoded_images:
      img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
      images.append(img)
      # display(img)
      filename = f"{random.randrange(100, 999)}@{datetime.now()}"
      print(f"Saving picture '{filename}'")
      # PIL images are not file objects, so write them with Image.save rather
      # than shutil.copyfileobj.
      with open(Path(self.cache_dir, filename), 'wb') as image_file:
        img.save(image_file, format='PNG')
def train():
  """Train model."""
  batch_size = FLAGS.batch_size
  n_devices = jax.device_count()
  if jax.host_count() > 1:
    raise ValueError('PixelCNN++ example should not be run on more than 1 host'
                     ' (for now)')
  if batch_size % n_devices > 0:
    raise ValueError('Batch size must be divisible by the number of devices')

  train_summary_writer, eval_summary_writer = get_summary_writers()

  # Load dataset
  data_source = input_pipeline.DataSource(
      train_batch_size=batch_size, eval_batch_size=batch_size)
  train_ds = data_source.train_ds
  eval_ds = data_source.eval_ds

  # Create dataset batch iterators
  train_iter = iter(train_ds)
  eval_iter = iter(eval_ds)

  # Compute steps per epoch and nb of eval steps
  steps_per_epoch = data_source.TRAIN_IMAGES // batch_size
  steps_per_eval = data_source.EVAL_IMAGES // batch_size
  steps_per_checkpoint = steps_per_epoch * 10
  num_steps = steps_per_epoch * FLAGS.num_epochs

  # Create the model using data-dependent initialization. Don't shard the init
  # batch.
  assert FLAGS.init_batch_size <= batch_size
  init_batch = next(train_iter)['image']._numpy()[:FLAGS.init_batch_size]

  rng = random.PRNGKey(FLAGS.rng)
  rng, init_rng = random.split(rng)
  rng, dropout_rng = random.split(rng)

  initial_variables = model().init({
      'params': init_rng,
      'dropout': dropout_rng
  }, init_batch)['params']
  optimizer_def = optim.Adam(beta1=0.95, beta2=0.9995)
  optimizer = optimizer_def.create(initial_variables)

  # restore_checkpoint falls back to `initial_variables` as the EMA when no
  # checkpoint is found; reassigning `ema = initial_variables` afterwards (as
  # the original code did) would clobber a restored EMA, so it is not done.
  optimizer, ema = restore_checkpoint(optimizer, initial_variables)
  step_offset = int(optimizer.state.step)
  optimizer, ema = jax_utils.replicate((optimizer, ema))

  # Learning rate schedule
  learning_rate_fn = lambda step: FLAGS.learning_rate * FLAGS.lr_decay**step

  # pmap the train and eval functions
  p_train_step = jax.pmap(
      partial(train_step, learning_rate_fn), axis_name='batch')
  p_eval_step = jax.pmap(eval_step, axis_name='batch')

  # Gather metrics
  train_metrics = []
  for step, batch in zip(range(step_offset, num_steps), train_iter):
    # Load and shard the TF batch
    batch = load_and_shard_tf_batch(batch)

    # Generate a PRNG key that will be rolled into the batch.
    rng, step_rng = random.split(rng)
    sharded_rngs = common_utils.shard_prng_key(step_rng)

    # Train step
    optimizer, ema, metrics = p_train_step(optimizer, ema, batch, sharded_rngs)
    train_metrics.append(metrics)

    if (step + 1) % steps_per_epoch == 0:
      epoch = step // steps_per_epoch  # We've finished an epoch
      train_metrics = common_utils.get_metrics(train_metrics)
      # Get training epoch summary for logging
      train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)
      # Send stats to Tensorboard
      for key, vals in train_metrics.items():
        for i, val in enumerate(vals):
          train_summary_writer.scalar(key, val, step - len(vals) + i + 1)
      # Reset train metrics
      train_metrics = []

      # Evaluation
      eval_metrics = []
      for _ in range(steps_per_eval):
        eval_batch = next(eval_iter)
        # Load and shard the TF batch
        eval_batch = load_and_shard_tf_batch(eval_batch)
        # Step
        metrics = p_eval_step(ema, eval_batch)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      # Get eval epoch summary for logging
      eval_summary = jax.tree_map(lambda x: x.mean(), eval_metrics)

      # Log epoch summary
      logging.info('Epoch %d: TRAIN loss=%.6f, EVAL loss=%.6f', epoch,
                   train_summary['loss'], eval_summary['loss'])

      eval_summary_writer.scalar('loss', eval_summary['loss'], step)
      train_summary_writer.flush()
      eval_summary_writer.flush()

    if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
      save_checkpoint(optimizer, ema, step)
def main():
    args = parse_args()

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    # Setup logging: we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
    # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
    # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the
    # sentences in columns called 'sentence1' and 'sentence2' if such columns exist, or the first two columns not
    # named 'label' if at least two columns are provided.
    # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this
    # single column. You can easily tweak this behavior (see below).
    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if args.task_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset("glue", args.task_name)
    else:
        # Loading the dataset from a local csv or json file.
        data_files = {}
        if args.train_file is not None:
            data_files["train"] = args.train_file
        if args.validation_file is not None:
            data_files["validation"] = args.validation_file
        extension = (args.train_file if args.train_file is not None else
                     args.validation_file).split(".")[-1]
        raw_datasets = load_dataset(extension, data_files=data_files)
    # See more about loading any type of standard or custom dataset at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Labels
    if args.task_name is not None:
        is_regression = args.task_name == "stsb"
        if not is_regression:
            label_list = raw_datasets["train"].features["label"].names
            num_labels = len(label_list)
        else:
            num_labels = 1
    else:
        # Trying to have good defaults here, don't hesitate to tweak to your needs.
        is_regression = raw_datasets["train"].features["label"].dtype in [
            "float32", "float64"
        ]
        if is_regression:
            num_labels = 1
        else:
            # A useful fast method:
            # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique
            label_list = raw_datasets["train"].unique("label")
            label_list.sort()  # Let's sort it for determinism
            num_labels = len(label_list)

    # Load pretrained model and tokenizer
    config = AutoConfig.from_pretrained(args.model_name_or_path,
                                        num_labels=num_labels,
                                        finetuning_task=args.task_name)
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
    model = FlaxAutoModelForSequenceClassification.from_pretrained(
        args.model_name_or_path, config=config)

    # Preprocessing the datasets
    if args.task_name is not None:
        sentence1_key, sentence2_key = task_to_keys[args.task_name]
    else:
        # Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
        non_label_column_names = [
            name for name in raw_datasets["train"].column_names
            if name != "label"
        ]
        if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
            sentence1_key, sentence2_key = "sentence1", "sentence2"
        else:
            if len(non_label_column_names) >= 2:
                sentence1_key, sentence2_key = non_label_column_names[:2]
            else:
                sentence1_key, sentence2_key = non_label_column_names[0], None

    # Some models have set the order of the labels to use, so let's make sure we do use it.
    label_to_id = None
    if (model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
            and args.task_name is not None and not is_regression):
        # Some have all caps in their config, some don't.
        label_name_to_id = {
            k.lower(): v for k, v in model.config.label2id.items()
        }
        if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
            logger.info(
                f"The configuration of the model provided the following label correspondence: {label_name_to_id}. "
                "Using it!")
            label_to_id = {
                i: label_name_to_id[label_list[i]] for i in range(num_labels)
            }
        else:
            # Pass one formatted string so the whole message is actually logged
            # (the original passed two positional arguments to logger.warning,
            # which silently dropped the second).
            logger.warning(
                "Your model seems to have been trained with labels, but they don't match the dataset: "
                f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
                "\nIgnoring the model labels as a result.")
    elif args.task_name is None:
        label_to_id = {v: i for i, v in enumerate(label_list)}

    def preprocess_function(examples):
        # Tokenize the texts
        texts = ((examples[sentence1_key], ) if sentence2_key is None else
                 (examples[sentence1_key], examples[sentence2_key]))
        result = tokenizer(*texts,
                           padding="max_length",
                           max_length=args.max_length,
                           truncation=True)

        if "label" in examples:
            if label_to_id is not None:
                # Map labels to IDs (not necessary for GLUE tasks)
                result["labels"] = [label_to_id[l] for l in examples["label"]]
            else:
                # In all cases, rename the column to labels because the model will expect that.
                result["labels"] = examples["label"]
        return result

    processed_datasets = raw_datasets.map(
        preprocess_function,
        batched=True,
        remove_columns=raw_datasets["train"].column_names)

    train_dataset = processed_datasets["train"]
    eval_dataset = processed_datasets["validation_matched" if args.task_name ==
                                      "mnli" else "validation"]

    # Log a few random samples from the training set:
    for index in random.sample(range(len(train_dataset)), 3):
        logger.info(
            f"Sample {index} of the training set: {train_dataset[index]}.")

    # Define a summary writer
    summary_writer = tensorboard.SummaryWriter(args.output_dir)
    summary_writer.hparams(vars(args))

    def write_metric(train_metrics, eval_metrics, train_time, step):
        summary_writer.scalar("train_time", train_time, step)

        train_metrics = get_metrics(train_metrics)
        for key, vals in train_metrics.items():
            tag = f"train_{key}"
            for i, val in enumerate(vals):
                summary_writer.scalar(tag, val, step - len(vals) + i + 1)

        for metric_name, value in eval_metrics.items():
            summary_writer.scalar(f"eval_{metric_name}", value, step)

    num_epochs = int(args.num_train_epochs)
    rng = jax.random.PRNGKey(args.seed)

    train_batch_size = args.per_device_train_batch_size * jax.local_device_count()
    eval_batch_size = args.per_device_eval_batch_size * jax.local_device_count()

    learning_rate_fn = create_learning_rate_fn(len(train_dataset),
                                               train_batch_size,
                                               args.num_train_epochs,
                                               args.num_warmup_steps,
                                               args.learning_rate)

    state = create_train_state(model,
                               learning_rate_fn,
                               is_regression,
                               num_labels=num_labels)

    # define step functions
    def train_step(
            state: train_state.TrainState, batch: Dict[str, Array],
            dropout_rng: PRNGKey) -> Tuple[train_state.TrainState, float]:
        """Trains model with an optimizer (both in `state`) on `batch`, returning a pair `(new_state, loss)`."""
        targets = batch.pop("labels")

        def loss_fn(params):
            logits = state.apply_fn(**batch,
                                    params=params,
                                    dropout_rng=dropout_rng,
                                    train=True)[0]
            loss = state.loss_fn(logits, targets)
            return loss, logits

        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (loss, logits), grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)
        metrics = jax.lax.pmean(
            {
                "loss": loss,
                "learning_rate": learning_rate_fn(state.step)
            },
            axis_name="batch")
        return new_state, metrics

    p_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0, ))

    def eval_step(state, batch):
        logits = state.apply_fn(**batch, params=state.params, train=False)[0]
        return state.logits_fn(logits)

    p_eval_step = jax.pmap(eval_step, axis_name="batch")

    if args.task_name is not None:
        metric = load_metric("glue", args.task_name)
    else:
        metric = load_metric("accuracy")

    logger.info(f"===== Starting training ({num_epochs} epochs) =====")
    train_time = 0
    for epoch in range(1, num_epochs + 1):
        logger.info(f"Epoch {epoch}")
        logger.info("  Training...")

        # make sure weights are replicated on each device
        state = replicate(state)

        train_start = time.time()
        train_metrics = []
        rng, input_rng, dropout_rng = jax.random.split(rng, 3)

        # train
        for batch in glue_train_data_collator(input_rng, train_dataset,
                                              train_batch_size):
            dropout_rngs = shard_prng_key(dropout_rng)
            state, metrics = p_train_step(state, batch, dropout_rngs)
            train_metrics.append(metrics)
        train_time += time.time() - train_start
        logger.info(f"  Done! Training metrics: {unreplicate(metrics)}")

        logger.info("  Evaluating...")
        rng, input_rng = jax.random.split(rng)

        # evaluate
        for batch in glue_eval_data_collator(eval_dataset, eval_batch_size):
            labels = batch.pop("labels")
            predictions = p_eval_step(state, batch)
            metric.add_batch(predictions=chain(*predictions),
                             references=chain(*labels))

        # evaluate also on leftover examples (not divisible by batch_size)
        num_leftover_samples = len(eval_dataset) % eval_batch_size

        # make sure leftover batch is evaluated on one device
        if num_leftover_samples > 0 and jax.process_index() == 0:
            # put weights on single device
            state = unreplicate(state)

            # take leftover samples
            batch = eval_dataset[-num_leftover_samples:]
            batch = {k: jnp.array(v) for k, v in batch.items()}

            labels = batch.pop("labels")
            predictions = eval_step(state, batch)
            metric.add_batch(predictions=predictions, references=labels)

        eval_metric = metric.compute()
        logger.info(f"  Done! Eval metrics: {eval_metric}")

        cur_step = epoch * (len(train_dataset) // train_batch_size)
        write_metric(train_metrics, eval_metric, train_time, cur_step)

    # save last checkpoint
    if jax.process_index() == 0:
        params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
        model.save_pretrained(args.output_dir, params=params)
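# Added worked example of the leftover-batch handling above (sizes are
# illustrative): pmap only consumes batches that fill every device, so the
# trailing len(eval_dataset) % eval_batch_size examples are scored once,
# unreplicated, on a single device via the plain eval_step.
eval_size, eval_batch_size = 1043, 256
num_full_batches = eval_size // eval_batch_size      # 4 batches via p_eval_step
num_leftover_samples = eval_size % eval_batch_size   # 19 samples via eval_step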
def train(pcnn_module,
          model_dir,
          batch_size,
          init_batch_size,
          num_epochs,
          learning_rate,
          decay_rate,
          run_seed=0):
  """Train model."""
  if jax.host_count() > 1:
    raise ValueError('PixelCNN++ example should not be run on more than 1 host'
                     ' (for now)')

  current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
  log_dir = model_dir + '/log/' + current_time
  train_log_dir = log_dir + '/train'
  eval_log_dir = log_dir + '/eval'
  train_summary_writer = tensorboard.SummaryWriter(train_log_dir)
  eval_summary_writer = tensorboard.SummaryWriter(eval_log_dir)

  rng = random.PRNGKey(run_seed)

  if batch_size % jax.device_count() > 0:
    raise ValueError('Batch size must be divisible by the number of devices')

  # Load dataset
  data_source = input_pipeline.DataSource(
      train_batch_size=batch_size, eval_batch_size=batch_size)
  train_ds = data_source.train_ds
  eval_ds = data_source.eval_ds

  # Create dataset batch iterators
  train_iter = iter(train_ds)
  eval_iter = iter(eval_ds)

  # Compute steps per epoch and nb of eval steps
  steps_per_epoch = data_source.TRAIN_IMAGES // batch_size
  steps_per_eval = data_source.EVAL_IMAGES // batch_size
  steps_per_checkpoint = steps_per_epoch * 10
  num_steps = steps_per_epoch * num_epochs
  base_learning_rate = learning_rate

  # Create the model using data-dependent initialization. Don't shard the init
  # batch.
  assert init_batch_size <= batch_size
  init_batch = next(train_iter)['image']._numpy()[:init_batch_size]
  model = create_model(rng, init_batch, pcnn_module)
  ema = model.params
  optimizer = create_optimizer(model, base_learning_rate)
  del model  # don't keep a copy of the initial model

  optimizer, ema = restore_checkpoint(optimizer, ema)
  step_offset = int(optimizer.state.step)
  optimizer, ema = jax_utils.replicate((optimizer, ema))

  # Learning rate schedule
  learning_rate_fn = lambda step: base_learning_rate * decay_rate**step

  # pmap the train and eval functions
  p_train_step = jax.pmap(
      functools.partial(train_step, learning_rate_fn=learning_rate_fn),
      axis_name='batch')
  p_eval_step = jax.pmap(eval_step, axis_name='batch')

  # Gather metrics
  train_metrics = []
  for step, batch in zip(range(step_offset, num_steps), train_iter):
    # Generate a PRNG key that will be rolled into the batch
    rng, step_key = jax.random.split(rng)
    # Load and shard the TF batch
    batch = load_and_shard_tf_batch(batch)
    # Shard the step PRNG key
    sharded_keys = common_utils.shard_prng_key(step_key)
    # Train step
    optimizer, ema, metrics = p_train_step(optimizer, ema, batch, sharded_keys)
    train_metrics.append(metrics)

    if (step + 1) % steps_per_epoch == 0:
      epoch = step // steps_per_epoch  # We've finished an epoch
      train_metrics = common_utils.get_metrics(train_metrics)
      # Get training epoch summary for logging
      train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)
      # Send stats to Tensorboard
      for key, vals in train_metrics.items():
        for i, val in enumerate(vals):
          train_summary_writer.scalar(key, val, step - len(vals) + i + 1)
      # Reset train metrics
      train_metrics = []

      # Evaluation
      model_ema = optimizer.target.replace(params=ema)
      eval_metrics = []
      for _ in range(steps_per_eval):
        eval_batch = next(eval_iter)
        # Load and shard the TF batch
        eval_batch = load_and_shard_tf_batch(eval_batch)
        # Step
        metrics = p_eval_step(model_ema, eval_batch)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      # Get eval epoch summary for logging
      eval_summary = jax.tree_map(lambda x: x.mean(), eval_metrics)

      # Log epoch summary
      logging.info('Epoch %d: TRAIN loss=%.6f, EVAL loss=%.6f', epoch,
                   train_summary['loss'], eval_summary['loss'])

      eval_summary_writer.scalar('loss', eval_summary['loss'], step)
      train_summary_writer.flush()
      eval_summary_writer.flush()

    if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
      save_checkpoint(optimizer, ema)
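# Added sketch of the per-step exponential schedule above (numbers are
# illustrative, not the example's defaults): with decay_rate just below 1 the
# learning rate decays smoothly on every optimizer step.
base_learning_rate, decay_rate = 1e-3, 0.999995
learning_rate_at = lambda step: base_learning_rate * decay_rate**step
half_life_steps = 138629  # ~ ln(2) / -ln(0.999995): steps until the lr halves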
def initial_state(self):
  return TrainState(
      rng=common_utils.shard_prng_key(jax.random.PRNGKey(0)),
      step=None,
      metrics=None,
      history=self)
def train(model_def,
          model_dir,
          batch_size,
          num_epochs,
          learning_rate,
          sgd_momentum,
          make_lr_fun=None,
          l2_reg=0.0005,
          run_seed=0):
  """Train model."""
  if jax.host_count() > 1:
    raise ValueError('CIFAR-10 example should not be run on '
                     'more than 1 host (for now)')

  if make_lr_fun is None:
    # No learning rate function provided
    # Default to stepped LR schedule for CIFAR-10 and Wide ResNet
    def make_lr_fun(base_lr, steps_per_epoch):  # pylint: disable=function-redefined
      return lr_schedule.create_stepped_learning_rate_schedule(
          base_lr, steps_per_epoch, [[60, 0.2], [120, 0.04], [160, 0.008]])

  summary_writer = tensorboard.SummaryWriter(model_dir)

  rng = random.PRNGKey(run_seed)

  if batch_size % jax.device_count() > 0:
    raise ValueError('Batch size must be divisible by the number of devices')
  device_batch_size = batch_size // jax.device_count()

  # Load dataset
  data_source = input_pipeline.CIFAR10DataSource(
      train_batch_size=batch_size, eval_batch_size=batch_size)
  train_ds = data_source.train_ds
  eval_ds = data_source.eval_ds

  # Compute steps per epoch and nb of eval steps
  steps_per_epoch = data_source.TRAIN_IMAGES // batch_size
  steps_per_eval = data_source.EVAL_IMAGES // batch_size
  num_steps = steps_per_epoch * num_epochs
  base_learning_rate = learning_rate

  # Create the model
  image_size = 32
  model, state = create_model(rng, device_batch_size, image_size, model_def)
  state = jax_utils.replicate(state)
  optimizer = create_optimizer(model, base_learning_rate, sgd_momentum)
  del model  # don't keep a copy of the initial model

  # Learning rate schedule
  learning_rate_fn = make_lr_fun(base_learning_rate, steps_per_epoch)

  # pmap the train and eval functions
  p_train_step = jax.pmap(
      functools.partial(train_step, learning_rate_fn=learning_rate_fn,
                        l2_reg=l2_reg),
      axis_name='batch')
  p_eval_step = jax.pmap(eval_step, axis_name='batch')

  # Create dataset batch iterators
  train_iter = iter(train_ds)
  eval_iter = iter(eval_ds)

  # Gather metrics
  train_metrics = []
  epoch = 1
  for step, batch in zip(range(num_steps), train_iter):
    # Generate a PRNG key that will be rolled into the batch
    rng, step_key = jax.random.split(rng)
    # Load and shard the TF batch
    batch = load_and_shard_tf_batch(batch)
    # Shard the step PRNG key
    sharded_keys = common_utils.shard_prng_key(step_key)
    # Train step
    optimizer, state, metrics = p_train_step(optimizer, state, batch,
                                             sharded_keys)
    train_metrics.append(metrics)

    if (step + 1) % steps_per_epoch == 0:
      # We've finished an epoch
      train_metrics = common_utils.get_metrics(train_metrics)
      # Get training epoch summary for logging
      train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)
      # Send stats to Tensorboard
      for key, vals in train_metrics.items():
        tag = 'train_%s' % key
        for i, val in enumerate(vals):
          summary_writer.scalar(tag, val, step - len(vals) + i + 1)
      # Reset train metrics
      train_metrics = []

      # Evaluation
      eval_metrics = []
      for _ in range(steps_per_eval):
        eval_batch = next(eval_iter)
        # Load and shard the TF batch
        eval_batch = load_and_shard_tf_batch(eval_batch)
        # Step
        metrics = p_eval_step(optimizer.target, state, eval_batch)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      # Get eval epoch summary for logging
      eval_summary = jax.tree_map(lambda x: x.mean(), eval_metrics)

      # Log epoch summary
      logging.info(
          'Epoch %d: TRAIN loss=%.6f, err=%.2f, EVAL loss=%.6f, err=%.2f',
          epoch, train_summary['loss'], train_summary['error_rate'] * 100.0,
          eval_summary['loss'], eval_summary['error_rate'] * 100.0)

      summary_writer.scalar('eval_loss', eval_summary['loss'], epoch)
      summary_writer.scalar('eval_error_rate', eval_summary['error_rate'],
                            epoch)
      summary_writer.flush()
      epoch += 1
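# Added pure-Python sketch of the stepped schedule configured above: the base
# learning rate is scaled by 0.2 from epoch 60, 0.04 from epoch 120, and 0.008
# from epoch 160, mirroring the [[60, 0.2], [120, 0.04], [160, 0.008]] spec
# (the exact flax lr_schedule semantics are an assumption here).
def stepped_lr(base_lr, epoch):
  scale = 1.0
  for boundary, mult in [(60, 0.2), (120, 0.04), (160, 0.008)]:
    if epoch >= boundary:
      scale = mult
  return base_lr * scale

assert stepped_lr(0.1, 59) == 0.1
assert stepped_lr(0.1, 60) == 0.1 * 0.2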
def main(_):
  if FLAGS.jax_backend_target:
    logging.info("Using JAX backend target %s", FLAGS.jax_backend_target)
    jax_config.update("jax_xla_backend", "tpu_driver")
    jax_config.update("jax_backend_target", FLAGS.jax_backend_target)

  logging.info("JAX host: %d / %d", jax.host_id(), jax.host_count())
  logging.info("JAX local devices: %r", jax.local_devices())

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(FLAGS.model_dir)
    # summary_writer.hparams(dict(FLAGS.config))

  rng = random.PRNGKey(FLAGS.seed)
  rng, init_rng_coarse, init_rng_fine = random.split(rng, 3)
  n_devices = jax.device_count()

  ### Load dataset and data values
  if FLAGS.config.dataset_type == "blender":
    images, poses, render_poses, hwf, counts = load_blender.load_data(
        FLAGS.data_dir,
        half_res=FLAGS.config.half_res,
        testskip=FLAGS.config.testskip,
    )
    logging.info("Loaded blender, total images: %d", images.shape[0])
    near = 2.0
    far = 6.0
    if FLAGS.config.white_bkgd:
      images = images[..., :3] * images[..., -1:] + (1.0 - images[..., -1:])
    else:
      images = images[..., :3]
  elif FLAGS.config.dataset_type == "deepvoxels":
    images, poses, render_poses, hwf, counts = load_deepvoxels.load_dv_data(
        FLAGS.data_dir,
        scene=FLAGS.config.shape,
        testskip=FLAGS.config.testskip,
    )
    hemi_R = np.mean(np.linalg.norm(poses[:, :3, -1], axis=-1))
    near = hemi_R - 1.0
    far = hemi_R + 1.0
    logging.info(
        "Loaded deepvoxels (%s), total images: %d",
        FLAGS.config.shape,
        images.shape[0],
    )
  else:
    raise ValueError(f"Dataset '{FLAGS.config.dataset_type}' is not available.")

  img_h, img_w, focal = hwf
  logging.info("Images splits: %s", counts)
  logging.info("Render poses: %s", render_poses.shape)
  logging.info("Image height: %d, image width: %d, focal: %.5f", img_h, img_w,
               focal)

  train_imgs, val_imgs, test_imgs, *_ = np.split(images, np.cumsum(counts))
  train_poses, val_poses, test_poses, *_ = np.split(poses, np.cumsum(counts))

  if FLAGS.config.render_factor > 0:  # render downsampled for speed
    r_img_h = img_h // FLAGS.config.render_factor
    r_img_w = img_w // FLAGS.config.render_factor
    r_focal = focal / FLAGS.config.render_factor
    r_hwf = r_img_h, r_img_w, r_focal
  else:
    r_hwf = hwf

  to_np = lambda x, h=img_h, w=img_w: np.reshape(x, [h, w, -1]).astype(
      np.float32)
  psnr_fn = lambda x: -10.0 * np.log(x) / np.log(10.0)

  ### Pre-compute rays
  @functools.partial(jax.jit, static_argnums=(0,))
  def prep_rays(hwf, c2w, c2w_sc=None):
    if c2w_sc is not None:
      c2w_sc = c2w_sc[:3, :4]
    return prepare_rays(None, hwf, FLAGS.config, near, far, c2w[:3, :4],
                        c2w_sc)

  rays_render = lax.map(lambda x: prep_rays(r_hwf, x), render_poses)
  render_shape = [-1, n_devices, r_hwf[1], rays_render.shape[-1]]
  rays_render = jnp.reshape(rays_render, render_shape)
  logging.info("Render rays shape: %s", rays_render.shape)

  if FLAGS.config.use_viewdirs:
    rays_render_vdirs = lax.map(
        lambda x: prep_rays(r_hwf, x, render_poses[0]),
        render_poses).reshape(render_shape)

  if FLAGS.config.batching:
    train_rays = lax.map(lambda pose: prep_rays(hwf, pose), train_poses)
    train_rays = jnp.reshape(train_rays, [-1, train_rays.shape[-1]])
    train_imgs = jnp.reshape(train_imgs, [-1, 3])
    logging.info("Batched rays shape: %s", train_rays.shape)
    val_rays = lax.map(lambda pose: prep_rays(hwf, pose), val_poses)

  test_rays = lax.map(lambda pose: prep_rays(r_hwf, pose), test_poses)
  test_rays = jnp.reshape(test_rays, render_shape)

  ### Init model parameters and optimizer
  input_pts_shape = (FLAGS.config.num_rand, FLAGS.config.num_samples, 3)
  input_views_shape = (FLAGS.config.num_rand, 3)
  model_coarse, params_coarse = initialized(
      init_rng_coarse, input_pts_shape, input_views_shape, FLAGS.config.model)

  optimizer = optim.Adam()
  state = TrainState(
      step=0,
      optimizer_coarse=optimizer.create(params_coarse),
      optimizer_fine=None)
  model_fn = (model_coarse.apply, None)

  if FLAGS.config.num_importance > 0:
    input_pts_shape = (
        FLAGS.config.num_rand,
        FLAGS.config.num_importance + FLAGS.config.num_samples,
        3,
    )
    model_fine, params_fine = initialized(
        init_rng_fine, input_pts_shape, input_views_shape,
        FLAGS.config.model_fine)
    state = state.replace(optimizer_fine=optimizer.create(params_fine))
    model_fn = (model_coarse.apply, model_fine.apply)

  state = checkpoints.restore_checkpoint(FLAGS.model_dir, state)
  start_step = int(state.step)
  state = jax_utils.replicate(state)

  ### Build 'pmapped' functions for distributed training
  learning_rate_fn = create_learning_rate_scheduler(
      factors=FLAGS.config.lr_schedule,
      base_learning_rate=FLAGS.config.learning_rate,
      decay_factor=FLAGS.config.decay_factor,
      steps_per_decay=FLAGS.config.lr_decay * 1000,
  )
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          model_fn,
          FLAGS.config,
          learning_rate_fn,
          (hwf, near, far),
      ),
      axis_name="batch",
      donate_argnums=(0,),
  )
  p_eval_step = jax.pmap(
      functools.partial(eval_step, model_fn, FLAGS.config),
      axis_name="batch",
  )

  t = time.time()
  train_metrics = []
  for step in range(start_step, FLAGS.config.num_steps + 1):
    rng, sample_rng, step_rng, test_rng = random.split(rng, 4)
    sharded_rngs = common_utils.shard_prng_key(step_rng)
    coords = None
    if FLAGS.config.batching:
      select_idx = random.randint(
          sample_rng,
          [n_devices * FLAGS.config.num_rand],
          minval=0,
          maxval=train_rays.shape[0],
      )
      inputs = train_rays[select_idx, ...]
      inputs = jnp.reshape(inputs, [n_devices, FLAGS.config.num_rand, -1])
      target = train_imgs[select_idx, ...]
      target = jnp.reshape(target, [n_devices, FLAGS.config.num_rand, 3])
    else:
      img_idx = random.randint(sample_rng, [n_devices], minval=0,
                               maxval=counts[0])
      inputs = train_poses[img_idx, ...]  # [n_devices, 4, 4]
      target = train_imgs[img_idx, ...]  # [n_devices, img_h, img_w, 3]
      if step < FLAGS.config.precrop_iters:
        dH = int(img_h // 2 * FLAGS.config.precrop_frac)
        dW = int(img_w // 2 * FLAGS.config.precrop_frac)
        coords = jnp.meshgrid(
            jnp.arange(img_h // 2 - dH, img_h // 2 + dH),
            jnp.arange(img_w // 2 - dW, img_w // 2 + dW),
            indexing="ij",
        )
        coords = jax_utils.replicate(
            jnp.stack(coords, axis=-1).reshape([-1, 2]))

    state, metrics, coarse_res, fine_res = p_train_step(
        state, (inputs, target), coords, rng=sharded_rngs)
    train_metrics.append(metrics)

    ### Write summaries to TB
    if step % FLAGS.config.i_print == 0 and step > 0:
      # `t` is reset at the end of every loop iteration, so this measures
      # seconds per step (the original called it steps_per_sec and tagged the
      # scalar "steps per second", contradicting the "%.3f s/step" log format).
      sec_per_step = time.time() - t
      train_metrics = common_utils.get_metrics(train_metrics)
      train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)
      if jax.host_id() == 0:
        logging.info(
            "Step: %6d, %.3f s/step, loss %.5f, psnr %6.3f",
            step,
            sec_per_step,
            train_summary["loss"],
            train_summary["psnr"],
        )
        for key, val in train_summary.items():
          summary_writer.scalar(f"train/{key}", val, step)
        summary_writer.scalar("seconds per step", sec_per_step, step)
        summary_writer.histogram("raw_c", np.array(coarse_res["raw"]), step)
        if FLAGS.config.num_importance > 0:
          summary_writer.histogram("raw_f", np.array(fine_res["raw"]), step)
      train_metrics = []

    ### Eval a random validation image and plot it in TB
    if step % FLAGS.config.i_img == 0:
      val_idx = random.randint(test_rng, [1], minval=0, maxval=counts[1])
      if FLAGS.config.batching:
        inputs = val_rays[tuple(val_idx)].reshape(render_shape)
      else:
        inputs = prep_rays(hwf, val_poses[tuple(val_idx)])
        inputs = jnp.reshape(inputs, render_shape)
      target = val_imgs[tuple(val_idx)]
      preds, preds_c, z_std = lax.map(lambda x: p_eval_step(state, x), inputs)
      rgb = to_np(preds["rgb"])
      loss = np.mean((rgb - target)**2)
      summary_writer.scalar("val/loss", loss, step)
      summary_writer.scalar("val/psnr", psnr_fn(loss), step)
      rgb = 255 * np.clip(rgb, 0, 1)
      summary_writer.image("val/rgb", rgb.astype(np.uint8), step)
      summary_writer.image("val/target", target, step)
      summary_writer.image("val/disp", to_np(preds["disp"]), step)
      summary_writer.image("val/acc", to_np(preds["acc"]), step)
      if FLAGS.config.num_importance > 0:
        rgb = 255 * np.clip(to_np(preds_c["rgb"]), 0, 1)
        summary_writer.image("val/rgb_c", rgb.astype(np.uint8), step)
        summary_writer.image("val/disp_c", to_np(preds_c["disp"]), step)
        summary_writer.image("val/z_std", to_np(z_std), step)

    ### Render a video with test poses
    if step % FLAGS.config.i_video == 0 and step > 0:
      logging.info("Rendering video at step %d", step)
      t = time.time()
      preds, *_ = lax.map(lambda x: p_eval_step(state, x), rays_render)
      gen_video(preds["rgb"], "rgb", r_hwf, step)
      gen_video(preds["disp"] / jnp.max(preds["disp"]), "disp", r_hwf, step,
                ch=1)
      if FLAGS.config.use_viewdirs:
        preds = lax.map(lambda x: p_eval_step(state, x)[0]["rgb"],
                        rays_render_vdirs)
        gen_video(preds, "rgb_still", r_hwf, step)
      logging.info("Video rendering done in %ds", time.time() - t)

    ### Save images in the test set
    if step % FLAGS.config.i_testset == 0 and step > 0:
      logging.info("Rendering test set at step %d", step)
      preds = lax.map(lambda x: p_eval_step(state, x)[0]["rgb"], test_rays)
      save_test_imgs(preds, r_hwf, step)
      if FLAGS.config.render_factor == 0:
        loss = np.mean((preds.reshape(test_imgs.shape) - test_imgs)**2.0)
        summary_writer.scalar("test/loss", loss, step)
        summary_writer.scalar("test/psnr", psnr_fn(loss), step)

    ### Save ckpt
    if step % FLAGS.config.i_weights == 0 and step > 0:
      if jax.host_id() == 0:
        checkpoints.save_checkpoint(
            FLAGS.model_dir,
            jax_utils.unreplicate(state),
            step,
            keep=5,
        )
    t = time.time()
def train(config, model_def, device_batch_size, eval_ds, num_steps,
          steps_per_epoch, steps_per_eval, train_ds, image_size, data_source,
          workdir):
  """Train model."""
  make_lr_fn = schedulers.get_make_lr_fn(config)
  make_temp_fn = schedulers.get_make_temp_fn(config)
  make_step_size_fn = schedulers.get_make_step_size_fn(config)
  if jax.host_count() > 1:
    raise ValueError('CIFAR10 example should not be run on '
                     'more than 1 host due to preconditioner updating.')

  initial_step = 0  # TODO(basv): load from checkpoint.

  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.host_id() > 0)

  # Write config to the summary files. This makes the hyperparameters available
  # in TensorBoard and makes comparison of runs in TensorBoard easier.
  # with writer.summary_writer.as_default():
  writer.write_hparams(dict(config))

  rng = random.PRNGKey(config.seed)
  rng, opt_rng, init_key, sampler_rng = jax.random.split(rng, 4)

  base_learning_rate = config.learning_rate

  # Create the model.
  model, state = create_model(rng, device_batch_size, image_size, model_def)
  parameter_overview.log_parameter_overview(model.params)
  state = jax_utils.replicate(state)

  train_size = data_source.TRAIN_IMAGES
  with flax.deprecated.nn.stochastic(init_key):
    optimizer = create_optimizer(config, model, base_learning_rate, train_size,
                                 sampler_rng)
  del model  # Don't keep a copy of the initial model.

  # Learning rate schedule
  learning_rate_fn = make_lr_fn(base_learning_rate, steps_per_epoch)
  temperature_fn = make_temp_fn(config.base_temp, steps_per_epoch)
  step_size_fn = make_step_size_fn(steps_per_epoch)

  p_eval_step, _, p_train_step, p_update_grad_vars = make_step_functions(
      config, config.l2_reg, learning_rate_fn, train_size, temperature_fn,
      step_size_fn)

  # Create dataset batch iterators.
  train_iter = iter(train_ds)
  eval_iter = iter(eval_ds)

  # Gather metrics.
  train_metrics = []
  epoch = 0

  # Ensemble.
  ensemble = []
  ensemble_logits = []
  ensemble_labels = []
  ensemble_probs = []

  def ensemble_add_step(step):
    if config.lr_schedule == 'cosine':
      # Add if learning rate jumps up again in the next step.
      increase = step_size_fn(step) < step_size_fn(step + 1) - 1e-8
      _, temp_end = ast.literal_eval(config.temp_ramp)
      past_burn_in = step >= steps_per_epoch * temp_end
      return increase and past_burn_in
    elif config.lr_schedule == 'constant':
      if (step + 1) % steps_per_epoch == 0:
        return True
    return False

  logging.info('Starting training loop at step %d.', initial_step)
  for step in range(initial_step, num_steps):
    if config.optimizer in ['sym_euler'] and step % steps_per_epoch == 0:
      optimizer, rng = update_preconditioner(config, optimizer,
                                             p_update_grad_vars, rng, state,
                                             train_iter)
    # Generate a PRNG key that will be rolled into the batch
    step_key = jax.random.fold_in(rng, step)
    opt_step_rng = jax.random.fold_in(opt_rng, step)

    # Load and shard the TF batch
    batch = next(train_iter)
    batch = input_pipeline.load_and_shard_tf_batch(config, batch)

    if not config.debug_run:
      # Shard the step PRNG key.
      # Don't shard the optimizer rng, as it should be equal among all machines.
      sharded_keys = common_utils.shard_prng_key(step_key)
    else:
      sharded_keys = step_key

    # Train step
    optimizer, state, metrics = p_train_step(optimizer, state, batch,
                                             sharded_keys, opt_step_rng)
    train_metrics.append(metrics)

    # Quick indication that training is happening.
    logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step)

    if step == initial_step:
      initial_train_metrics = get_metrics(config, train_metrics)
      train_summary = jax.tree_map(lambda x: x.mean(), initial_train_metrics)
      train_summary = {'train_' + k: v for k, v in train_summary.items()}
      logging.log(logging.INFO, 'initial metrics = %s',
                  str(train_summary.items()))

    if (step + 1) % steps_per_epoch == 0:  # We've finished an epoch
      # Save model params/state.
      train_metrics = get_metrics(config, train_metrics)
      # Get training epoch summary for logging
      train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)
      train_summary = {'train_' + k: v for k, v in train_summary.items()}
      writer.write_scalars(epoch, train_summary)
      # Reset train metrics
      train_metrics = []

      # Evaluation
      if config.do_eval:
        eval_metrics = []
        eval_logits = []
        eval_labels = []
        for _ in range(steps_per_eval):
          eval_batch = next(eval_iter)
          # Load and shard the TF batch
          eval_batch = input_pipeline.load_and_shard_tf_batch(
              config, eval_batch)
          # Step
          logits, labels, metrics = p_eval_step(optimizer.target, state,
                                                eval_batch)
          eval_metrics.append(metrics)
          eval_logits.append(logits)
          eval_labels.append(labels)
        eval_metrics = get_metrics(config, eval_metrics)
        # Get eval epoch summary for logging
        eval_summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
        eval_summary = {'eval_' + k: v for k, v in eval_summary.items()}
        writer.write_scalars(epoch, eval_summary)

        if config.algorithm == 'sgmcmc' and ensemble_add_step(step):
          ensemble.append(
              (serialization.to_state_dict(optimizer.target), state))

        if (config.algorithm == 'sgmcmc' and ensemble_add_step(step) and
            len(ensemble) >= 1):
          # Gather predictions for this ensemble sample.
          eval_logits = jnp.concatenate(eval_logits, axis=0)
          eval_probs = jax.nn.softmax(eval_logits, axis=-1)
          eval_labels = jnp.concatenate(eval_labels, axis=0)

          # Ensure that labels are consistent between predict runs.
          if ensemble_labels:
            assert jnp.allclose(
                eval_labels,
                ensemble_labels[0]), 'Labels unordered between eval runs.'

          ensemble_logits.append(eval_logits)
          ensemble_probs.append(eval_probs)
          ensemble_labels.append(eval_labels)

          # Compute ensemble predictions over last config.ensemble_size samples.
          ensemble_last_probs = jnp.mean(
              jnp.array(ensemble_probs[-config.ensemble_size:]), axis=0)
          ensemble_metrics = train_functions.compute_metrics_probs(
              ensemble_last_probs, ensemble_labels[0])
          ensemble_summary = jax.tree_map(lambda x: x.mean(), ensemble_metrics)
          ensemble_summary = {
              'ens_' + k: v for k, v in ensemble_summary.items()
          }
          ensemble_summary['ensemble_size'] = min(config.ensemble_size,
                                                  len(ensemble_probs))
          writer.write_scalars(epoch, ensemble_summary)

      epoch += 1

  return ensemble, optimizer
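# Added sketch of the ensemble_add_step rule above for a cyclical schedule
# (cSG-MCMC style): a sample is taken when the step size is about to jump back
# up, i.e. on the last step of each cycle. The cosine schedule below is an
# assumption for illustration, not the repo's step_size_fn.
import math

def cyclical_step_size(step, cycle_len=100):
  return 0.5 * (1 + math.cos(math.pi * (step % cycle_len) / cycle_len))

snapshot_steps = [
    s for s in range(300)
    if cyclical_step_size(s) < cyclical_step_size(s + 1) - 1e-8
]
assert snapshot_steps == [99, 199, 299]  # the end of each cycle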
def train_and_evaluate(config: ml_collections.ConfigDict, resume: str):
  """Execute model training and evaluation loop.

  Args:
    config: Hyperparameter configuration for training and evaluation.
    resume: Resume from checkpoints at the specified dir if set (TODO: support
      a specific checkpoint file/step).
  """
  rng = random.PRNGKey(42)

  if config.batch_size % jax.device_count() > 0:
    raise ValueError('Batch size must be divisible by the number of devices')
  local_batch_size = config.batch_size // jax.host_count()

  config.eval_batch_size = config.eval_batch_size or config.batch_size
  if config.eval_batch_size % jax.device_count() > 0:
    raise ValueError(
        'Validation batch size must be divisible by the number of devices')
  local_eval_batch_size = config.eval_batch_size // jax.host_count()

  platform = jax.local_devices()[0].platform
  half_prec = config.half_precision
  if half_prec:
    if platform == 'tpu':
      model_dtype = jnp.bfloat16
    else:
      model_dtype = jnp.float16
  else:
    model_dtype = jnp.float32

  rng, model_create_rng = random.split(rng)
  model, variables = create_model(
      config.model,
      dtype=model_dtype,
      drop_rate=config.drop_rate,
      drop_path_rate=config.drop_path_rate,
      rng=model_create_rng)
  image_size = config.image_size or model.default_cfg['input_size'][-1]

  dataset_builder = tfds.builder(config.dataset, data_dir=config.data_dir)
  train_iter = create_input_iter(
      dataset_builder,
      local_batch_size,
      train=True,
      image_size=image_size,
      augment_name=config.autoaugment,
      randaug_magnitude=config.randaug_magnitude,
      randaug_num_layers=config.randaug_num_layers,
      half_precision=half_prec,
      cache=config.cache)
  eval_iter = create_input_iter(
      dataset_builder,
      local_eval_batch_size,
      train=False,
      image_size=image_size,
      half_precision=half_prec,
      cache=config.cache)

  steps_per_epoch = dataset_builder.info.splits[
      'train'].num_examples // config.batch_size
  if config.num_train_steps == -1:
    num_steps = steps_per_epoch * config.num_epochs
  else:
    num_steps = config.num_train_steps
  if config.steps_per_eval == -1:
    num_validation_examples = dataset_builder.info.splits[
        'validation'].num_examples
    steps_per_eval = num_validation_examples // config.eval_batch_size
  else:
    steps_per_eval = config.steps_per_eval
  steps_per_checkpoint = steps_per_epoch * 1

  base_lr = config.lr * config.batch_size / 256.
  lr_fn = create_lr_schedule_epochs(
      base_lr,
      config.lr_schedule,
      steps_per_epoch=steps_per_epoch,
      total_epochs=config.num_epochs,
      decay_rate=config.lr_decay_rate,
      decay_epochs=config.lr_decay_epochs,
      warmup_epochs=config.lr_warmup_epochs,
      min_lr=config.lr_minimum)

  state = create_train_state(config, variables, lr_fn)
  if resume:
    state = restore_checkpoint(state, resume)
  # step_offset > 0 if restarting from checkpoint
  step_offset = int(state.step)
  state = flax.jax_utils.replicate(state)

  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          model.apply,
          lr_fn=lr_fn,
          label_smoothing=config.label_smoothing,
          weight_decay=config.weight_decay),
      axis_name='batch')
  p_eval_step = jax.pmap(functools.partial(eval_step, model.apply),
                         axis_name='batch')
  p_eval_step_ema = None
  if config.ema_decay != 0.:
    p_eval_step_ema = jax.pmap(functools.partial(eval_step_ema, model.apply),
                               axis_name='batch')

  if jax.host_id() == 0:
    if resume and step_offset > 0:
      output_dir = resume
    else:
      output_base = config.output_base_dir if config.output_base_dir else './output'
      exp_name = '-'.join(
          [datetime.now().strftime("%Y%m%d-%H%M%S"), config.model])
      output_dir = get_outdir(output_base, exp_name)
    summary_writer = tensorboard.SummaryWriter(output_dir)
    summary_writer.hparams(dict(config))

  epoch_metrics = []
  t_loop_start = time.time()
  num_samples = 0
  for step, batch in zip(range(step_offset, num_steps), train_iter):
    step_p1 = step + 1
    rng, step_rng = random.split(rng)
    sharded_rng = common_utils.shard_prng_key(step_rng)
    num_samples += config.batch_size

    state, metrics = p_train_step(state, batch, dropout_rng=sharded_rng)
    epoch_metrics.append(metrics)

    if step_p1 % steps_per_epoch == 0:
      epoch = step // steps_per_epoch
      epoch_metrics = common_utils.get_metrics(epoch_metrics)
      summary = jax.tree_map(lambda x: x.mean(), epoch_metrics)
      samples_per_sec = num_samples / (time.time() - t_loop_start)
      logging.info(
          'train epoch: %d, loss: %.4f, img/sec %.2f, top1: %.2f, top5: %.3f',
          epoch, summary['loss'], samples_per_sec, summary['top1'],
          summary['top5'])
      if jax.host_id() == 0:
        for key, vals in epoch_metrics.items():
          tag = 'train_%s' % key
          for i, val in enumerate(vals):
            summary_writer.scalar(tag, val, step_p1 - len(vals) + i)
        summary_writer.scalar('samples per second', samples_per_sec, step)
      epoch_metrics = []

      state = sync_batch_stats(state)  # sync batch statistics across replicas
      eval_metrics = []
      for step_eval in range(steps_per_eval):
        eval_batch = next(eval_iter)
        metrics = p_eval_step(state, eval_batch)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
      logging.info('eval epoch: %d, loss: %.4f, top1: %.2f, top5: %.3f',
                   epoch, summary['loss'], summary['top1'], summary['top5'])

      if p_eval_step_ema is not None:
        # NOTE running both ema and non-ema eval while improving this script
        eval_metrics = []
        for step_eval in range(steps_per_eval):
          eval_batch = next(eval_iter)
          metrics = p_eval_step_ema(state, eval_batch)
          eval_metrics.append(metrics)
        eval_metrics = common_utils.get_metrics(eval_metrics)
        summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
        logging.info('eval epoch ema: %d, loss: %.4f, top1: %.2f, top5: %.3f',
                     epoch, summary['loss'], summary['top1'], summary['top5'])

      if jax.host_id() == 0:
        for key, val in eval_metrics.items():
          tag = 'eval_%s' % key
          summary_writer.scalar(tag, val.mean(), step)
        summary_writer.flush()

      t_loop_start = time.time()
      num_samples = 0
    elif step_p1 % 100 == 0:
      summary = jax.tree_map(lambda x: x.mean(),
                             common_utils.get_metrics(epoch_metrics))
      samples_per_sec = num_samples / (time.time() - t_loop_start)
      logging.info('train steps: %d, loss: %.4f, img/sec: %.2f', step_p1,
                   summary['loss'], samples_per_sec)

    if step_p1 % steps_per_checkpoint == 0 or step_p1 == num_steps:
      state = sync_batch_stats(state)
      save_checkpoint(state, output_dir)

  # Wait until computations are done before exiting
  jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
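# Added worked example of the linear scaling rule used above
# (base_lr = config.lr * config.batch_size / 256); numbers are illustrative.
lr, batch_size = 0.1, 1024
base_lr = lr * batch_size / 256.  # -> 0.4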
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
  """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.
  """
  tf.io.gfile.makedirs(workdir)

  batch_size = config.batch_size
  n_devices = jax.device_count()
  if jax.host_count() > 1:
    raise ValueError('PixelCNN++ example should not be run on more than 1 host'
                     ' (for now)')
  if batch_size % n_devices > 0:
    raise ValueError('Batch size must be divisible by the number of devices')

  train_summary_writer, eval_summary_writer = get_summary_writers(workdir)

  # Load dataset
  data_source = input_pipeline.DataSource(config)
  train_ds = data_source.train_ds
  eval_ds = data_source.eval_ds
  steps_per_epoch = data_source.ds_info.splits[
      'train'].num_examples // config.batch_size
  # Create dataset batch iterators
  train_iter = iter(train_ds)
  num_train_steps = train_ds.cardinality().numpy()
  steps_per_checkpoint = 1000

  # Create the model using data-dependent initialization. Don't shard the init
  # batch.
  assert config.init_batch_size <= batch_size
  init_batch = next(train_iter)['image']._numpy()[:config.init_batch_size]

  rng = jax.random.PRNGKey(config.seed)
  rng, init_rng, dropout_rng = jax.random.split(rng, 3)

  initial_variables = model(config).init({
      'params': init_rng,
      'dropout': dropout_rng
  }, init_batch)['params']
  optimizer_def = optim.Adam(beta1=0.95, beta2=0.9995)
  optimizer = optimizer_def.create(initial_variables)

  # restore_checkpoint falls back to `initial_variables` as the EMA when no
  # checkpoint is found; reassigning `ema = initial_variables` afterwards (as
  # the original code did) would clobber a restored EMA, so it is not done.
  optimizer, ema = restore_checkpoint(workdir, optimizer, initial_variables)
  step_offset = int(optimizer.state.step)
  optimizer, ema = jax_utils.replicate((optimizer, ema))

  # Learning rate schedule
  learning_rate_fn = lambda step: config.learning_rate * config.lr_decay**step

  # pmap the train and eval functions
  p_train_step = jax.pmap(
      functools.partial(train_step, config, learning_rate_fn),
      axis_name='batch')
  p_eval_step = jax.pmap(
      functools.partial(eval_step, config=config), axis_name='batch')

  # Gather metrics
  train_metrics = []
  for step, batch in zip(range(step_offset, num_train_steps), train_iter):
    # Load and shard the TF batch
    batch = load_and_shard_tf_batch(batch)

    # Generate a PRNG key that will be rolled into the batch.
    rng, step_rng = jax.random.split(rng)
    sharded_rngs = common_utils.shard_prng_key(step_rng)

    # Train step
    optimizer, ema, metrics = p_train_step(optimizer, ema, batch, sharded_rngs)
    train_metrics.append(metrics)

    # Quick indication that training is happening.
    logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step)

    if (step + 1) % steps_per_epoch == 0:
      epoch = step // steps_per_epoch  # We've finished an epoch
      train_metrics = common_utils.get_metrics(train_metrics)
      # Get training epoch summary for logging
      train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)
      # Send stats to Tensorboard
      for key, vals in train_metrics.items():
        for i, val in enumerate(vals):
          train_summary_writer.scalar(key, val, step - len(vals) + i + 1)
      # Reset train metrics
      train_metrics = []

      # Evaluation
      eval_metrics = []
      for eval_batch in eval_ds:
        # Load and shard the TF batch
        eval_batch = load_and_shard_tf_batch(eval_batch)
        # Step
        metrics = p_eval_step(ema, eval_batch)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      # Get eval epoch summary for logging
      eval_summary = jax.tree_map(lambda x: x.mean(), eval_metrics)

      # Log epoch summary
      logging.info('Epoch %d: TRAIN loss=%.6f, EVAL loss=%.6f', epoch,
                   train_summary['loss'], eval_summary['loss'])

      eval_summary_writer.scalar('loss', eval_summary['loss'], step)
      train_summary_writer.flush()
      eval_summary_writer.flush()

    if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_train_steps:
      save_checkpoint(workdir, optimizer, ema, step)
def replicate(self):
  return jax_utils.replicate(self).replace(
      dropout_rng=shard_prng_key(self.dropout_rng))
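# Added self-contained sketch of the replicate pattern above (the dataclass is
# hypothetical; only the jax_utils.replicate / shard_prng_key combination
# mirrors the real method): every leaf gains a leading device axis, but
# dropout_rng is re-split per device so replicas draw independent dropout masks.
import jax
from flax import jax_utils, struct
from flax.training.common_utils import shard_prng_key

@struct.dataclass
class DemoState:
  step: int
  dropout_rng: jax.numpy.ndarray

  def replicate(self):
    return jax_utils.replicate(self).replace(
        dropout_rng=shard_prng_key(self.dropout_rng))

p_state = DemoState(step=0, dropout_rng=jax.random.PRNGKey(0)).replicate()
# p_state.dropout_rng now has shape [jax.local_device_count(), 2]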
def main(_):
  if FLAGS.config.precrop_iters > 0 and FLAGS.config.batching:
    raise ValueError(
        "'precrop_iters' has no effect when 'batching' the dataset")
  # render_factor == 0 (render at full resolution) is handled below, so allow
  # zero here (the original asserted render_factor > 0, which would have made
  # the render_factor == 0 branches unreachable).
  assert FLAGS.config.down_factor > 0 and FLAGS.config.render_factor >= 0

  logging.info("JAX host: %d / %d", jax.process_index(), jax.host_count())
  logging.info("JAX local devices: %r", jax.local_devices())

  platform.work_unit().set_task_status(
      f"host_id: {jax.process_index()}, host_count: {jax.host_count()}")
  platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                       FLAGS.model_dir, "model_dir")

  os.makedirs(FLAGS.model_dir, exist_ok=True)
  rng = jax.random.PRNGKey(FLAGS.seed)
  rng, rng_coarse, rng_fine, data_rng, step_rng = jax.random.split(rng, 5)
  rngs = common_utils.shard_prng_key(step_rng)

  ### Load dataset and data values
  datasets, counts, optics, render_datasets = get_dataset(
      FLAGS.data_dir, FLAGS.config, rng=data_rng,
      num_poses=FLAGS.config.num_poses)
  train_ds, val_ds, test_ds = datasets
  *_, test_items = counts
  hwf, r_hwf, near, far = optics
  render_ds, render_vdirs_ds, num_poses = render_datasets
  iter_render_ds = zip(range(num_poses), render_ds)
  iter_vdirs_ds = zip(range(num_poses), render_vdirs_ds)
  iter_test_ds = zip(range(test_items), test_ds)
  img_h, img_w, _ = hwf

  logging.info("Num poses: %d", num_poses)
  logging.info("Splits: train - %d, val - %d, test - %d", *counts)
  logging.info("Images: height %d, width %d, focal %.5f", *hwf)
  logging.info("Render: height %d, width %d, focal %.5f", *r_hwf)

  ### Init model parameters and optimizer
  initialized_ = functools.partial(initialized,
                                   model_config=FLAGS.config.model)
  pts_shape = (FLAGS.config.num_rand, FLAGS.config.num_samples, 3)
  views_shape = (FLAGS.config.num_rand, 3)
  model_coarse, params_coarse = initialized_(rng_coarse, pts_shape,
                                             views_shape)

  schedule_fn = optax.exponential_decay(
      init_value=FLAGS.config.learning_rate,
      transition_steps=FLAGS.config.lr_decay * 1000,
      decay_rate=FLAGS.config.decay_factor,
  )
  tx = optax.adam(learning_rate=schedule_fn)
  state = train_state.TrainState.create(
      apply_fn=(model_coarse.apply, None),
      params={"coarse": params_coarse},
      tx=tx)

  if FLAGS.config.num_importance > 0:
    pts_shape = (
        FLAGS.config.num_rand,
        FLAGS.config.num_importance + FLAGS.config.num_samples,
        3,
    )
    model_fine, params_fine = initialized_(rng_fine, pts_shape, views_shape)
    state = train_state.TrainState.create(
        apply_fn=(model_coarse.apply, model_fine.apply),
        params={
            "coarse": params_coarse,
            "fine": params_fine
        },
        tx=tx,
    )

  state = checkpoints.restore_checkpoint(FLAGS.model_dir, state)
  start_step = int(state.step)

  # cycle already seen examples if resuming from checkpoint
  # (only useful for ensuring deterministic dataset, slow for large start_step)
  # if start_step != 0:
  #   for _ in range(start_step):
  #     _ = next(train_ds)

  # parameter_overview.log_parameter_overview(state.optimizer_coarse.target)
  # if FLAGS.config.num_importance > 0:
  #   parameter_overview.log_parameter_overview(state.optimizer_fine.target)

  state = jax.device_put_replicated(state, jax.local_devices())

  ### Build "pmapped" functions for distributed training
  train_fn = functools.partial(train_step, near, far, FLAGS.config,
                               schedule_fn)
  p_train_step = jax.pmap(
      train_fn,
      axis_name="batch",
      in_axes=(0, 0, None, 0),
      # donate_argnums=(0, 1, 2),
  )

  def render_fn(state, rays):
    step_fn = functools.partial(eval_step, FLAGS.config, near, far, state)
    return lax.map(step_fn, rays)

  p_eval_step = jax.pmap(
      render_fn,
      axis_name="batch",
      # in_axes=(0, 0, None),
      # donate_argnums=(0, 1))
  )

  # TODO: add hparams
  writer = metric_writers.create_default_writer(
      FLAGS.model_dir, just_logging=jax.process_index() > 0)

  logging.info("Starting training loop.")
  hooks = []
  profiler = periodic_actions.Profile(num_profile_steps=5,
                                      logdir=FLAGS.model_dir)
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=FLAGS.config.num_steps, writer=writer)
  if jax.process_index() == 0:
    hooks += [profiler, report_progress]
  train_metrics = []
  gen_video_ = functools.partial(gen_video, FLAGS.model_dir)

  for step in range(start_step, FLAGS.config.num_steps + 1):
    is_last_step = step == FLAGS.config.num_steps

    batch = next(train_ds)
    coords = None
    if not FLAGS.config.batching:
      coords = jnp.meshgrid(jnp.arange(img_h), jnp.arange(img_w),
                            indexing="ij")
      if step < FLAGS.config.precrop_iters:
        dH = int(img_h // 2 * FLAGS.config.precrop_frac)
        dW = int(img_w // 2 * FLAGS.config.precrop_frac)
        coords = jnp.meshgrid(
            jnp.arange(img_h // 2 - dH, img_h // 2 + dH),
            jnp.arange(img_w // 2 - dW, img_w // 2 + dW),
            indexing="ij",
        )
      coords = jnp.stack(coords, axis=-1).reshape([-1, 2])

    with jax.profiler.StepTraceAnnotation("train", step_num=step):
      state, metrics = p_train_step(batch, state, coords, rngs)
    train_metrics.append(metrics)
    logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)
    _ = [h(step) for h in hooks]

    ### Write train summaries to TB
    if step % FLAGS.config.i_print == 0 or is_last_step:
      with report_progress.timed("training_metrics"):
        train_metrics = common_utils.get_metrics(train_metrics)
        train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)
        summary = {f"train/{k}": v for k, v in train_summary.items()}
        writer.write_scalars(step, summary)
      train_metrics = []

    ### Eval a random validation image and plot it to TB
    if step % FLAGS.config.i_img == 0 and step > 0 or is_last_step:
      with report_progress.timed("validation"):
        inputs = next(val_ds)
        rays, padding = prepare_render_data(inputs["rays"]._numpy())
        outputs = p_eval_step(state, rays)
        preds, preds_c, z_std = jax.tree_map(
            lambda x: to_np(x, hwf, padding), outputs)
        loss = np.mean((preds["rgb"] - inputs["image"])**2)
        summary = {"val/loss": loss, "val/psnr": psnr_fn(loss)}
        writer.write_scalars(step, summary)

        summary = {
            "val/rgb": to_rgb(preds["rgb"]),
            "val/target": to_np(inputs["image"], hwf, padding),
            "val/disp": disp_post(preds["disp"], FLAGS.config),
            "val/acc": preds["acc"],
        }
        if FLAGS.config.num_importance > 0:
          summary["val/rgb_c"] = to_rgb(preds_c["rgb"])
          summary["val/disp_c"] = disp_post(preds_c["disp"], FLAGS.config)
          summary["val/z_std"] = z_std
        writer.write_images(step, summary)

    ### Render a video with test poses
    if step % FLAGS.config.i_video == 0 and step > 0:
      with report_progress.timed("video_render"):
        logging.info("Rendering video at step %d", step)
        rgb_list = []
        disp_list = []
        for idx, inputs in tqdm(iter_render_ds, desc="Rays render"):
          rays, padding = prepare_render_data(inputs["rays"].numpy())
          preds, *_ = p_eval_step(state, rays)
          preds = jax.tree_map(lambda x: to_np(x, r_hwf, padding), preds)
          rgb_list.append(preds["rgb"])
          disp_list.append(preds["disp"])
        gen_video_(np.stack(rgb_list), "rgb", r_hwf, step)
        disp = np.stack(disp_list)
        gen_video_(disp_post(disp, FLAGS.config), "disp", r_hwf, step, ch=1)

        if FLAGS.config.use_viewdirs:
          rgb_list = []
          for idx, inputs in tqdm(iter_vdirs_ds, desc="Viewdirs render"):
            rays, padding = prepare_render_data(inputs["rays"].numpy())
            preds, *_ = p_eval_step(state, rays)
            rgb_list.append(to_np(preds["rgb"], r_hwf, padding))
          gen_video_(np.stack(rgb_list), "rgb_still", r_hwf, step)

    ### Save images in the test set
    if step % FLAGS.config.i_testset == 0 and step > 0:
      with report_progress.timed("test_render"):
        logging.info("Rendering test set at step %d", step)
        test_losses = []
        for idx, inputs in tqdm(iter_test_ds, desc="Test render"):
          rays, padding = prepare_render_data(inputs["rays"].numpy())
          preds, *_ = p_eval_step(state, rays)
          save_test_imgs(FLAGS.model_dir, preds["rgb"], r_hwf, step, idx)

          if FLAGS.config.render_factor == 0:
            loss = np.mean((preds["rgb"] - inputs["image"])**2.0)
            test_losses.append(loss)
        if FLAGS.config.render_factor == 0:
          loss = np.mean(test_losses)
          summary = {"test/loss": loss, "test/psnr": psnr_fn(loss)}
          writer.write_scalars(step, summary)
      writer.flush()

    ### Save ckpt
    if step % FLAGS.config.i_weights == 0 or is_last_step:
      with report_progress.timed("checkpoint"):
        save_checkpoint(state, FLAGS.model_dir)
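# Added sketch of the optax schedule built above (values are illustrative):
# optax.exponential_decay yields
# init_value * decay_rate ** (step / transition_steps), i.e. one full
# decay_rate factor every transition_steps steps.
import optax

schedule = optax.exponential_decay(
    init_value=5e-4, transition_steps=250_000, decay_rate=0.1)
assert abs(schedule(250_000) - 5e-5) < 1e-9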