Code Example #1
File: ppo_main.py Project: zizai/tensor2tensor
def main(argv):
    del argv
    logging.info("Starting PPO Main.")

    if FLAGS.jax_debug_nans:
        config.update("jax_debug_nans", True)

    if FLAGS.use_tpu:
        config.update("jax_platform_name", "tpu")
    else:
        config.update("jax_platform_name", "gpu")

    gin_configs = FLAGS.config or []
    gin.parse_config_files_and_bindings(FLAGS.config_file, gin_configs)

    # TODO(pkozakowski): Find a better way to determine this.
    env_kwargs = {}
    train_env_kwargs = {}
    eval_env_kwargs = {}
    if "OnlineTuneEnv" in FLAGS.env_problem_name:
        # TODO(pkozakowski): Separate env output dirs by train/eval and epoch.
        train_env_kwargs = {}
        train_env_kwargs.update(env_kwargs)
        train_env_kwargs["output_dir"] = os.path.join(FLAGS.output_dir,
                                                      "envs/train")

        eval_env_kwargs = {}
        eval_env_kwargs.update(env_kwargs)
        eval_env_kwargs["output_dir"] = os.path.join(FLAGS.output_dir,
                                                     "envs/eval")

    if "ClientEnv" in FLAGS.env_problem_name:
        train_env_kwargs["per_env_kwargs"] = [{
            "remote_env_address":
            os.path.join(FLAGS.train_server_bns, str(replica))
        } for replica in range(FLAGS.batch_size)]

        eval_env_kwargs["per_env_kwargs"] = [{
            "remote_env_address":
            os.path.join(FLAGS.eval_server_bns, str(replica))
        } for replica in range(FLAGS.eval_batch_size)]

    # Make an env here.
    env = make_env(batch_size=FLAGS.batch_size, **train_env_kwargs)
    assert env

    eval_env = make_env(batch_size=FLAGS.eval_batch_size, **eval_env_kwargs)
    assert eval_env

    def run_training_loop():
        """Runs the training loop."""
        logging.info("Starting the training loop.")

        policy_and_value_net_fn = functools.partial(
            ppo.policy_and_value_net,
            bottom_layers_fn=common_layers,
            two_towers=FLAGS.two_towers)
        policy_and_value_optimizer_fn = get_optimizer_fn(FLAGS.learning_rate)

        ppo.training_loop(
            output_dir=FLAGS.output_dir,
            env=env,
            eval_env=eval_env,
            env_name=str(FLAGS.env_problem_name),
            policy_and_value_net_fn=policy_and_value_net_fn,
            policy_and_value_optimizer_fn=policy_and_value_optimizer_fn,
        )

    if FLAGS.jax_debug_nans or FLAGS.disable_jit:
        with jax.disable_jit():
            run_training_loop()
    else:
        run_training_loop()
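
The FLAGS-gated jax.disable_jit() wrapper above is a common debugging pattern. Below is a minimal, self-contained sketch of the same idea, not taken from the project; debug_mode stands in for the FLAGS checks.

import jax
import jax.numpy as jnp


def run(x):
    return jnp.tanh(x).sum()


def run_maybe_unjitted(x, debug_mode=False):
    if debug_mode:
        # Under disable_jit() ops run eagerly, so NaN checks and debuggers work.
        with jax.disable_jit():
            return run(x)
    return run(x)


print(run_maybe_unjitted(jnp.arange(4.0), debug_mode=True))
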
Code Example #2
    @parameterized.parameters(
        ('double_q', ),
        ('double_q_v', ),
        ('double_q_p', ),
        ('double_q_pv', ),
        ('q_regression', ),
    )
    def test_passive_loss_gradients(self, loss_type):
        loss_fn = losses.make_loss_fn(loss_type, active=False)

        def fn(q_tm1, q_t, q_t_target, transition, rng_key):
            return loss_fn(q_tm1, q_t, q_t_target, transition, rng_key)

        grad_fn = self.variant(jax.grad(fn, argnums=(0, 1, 2)))

        dldq_tm1, dldq_t, dldq_t_target = grad_fn(self.qs, self.qs, self.qs,
                                                  self.transition,
                                                  self.rng_key)
        # Assert that only passive net gets nonzero gradients.
        self.assertGreater(np.sum(np.abs(dldq_tm1.passive.q_values)), 0.)
        self.assertTrue(np.all(dldq_t.passive.q_values == 0.))
        self.assertTrue(np.all(dldq_t_target.passive.q_values == 0.))
        self.assertTrue(np.all(dldq_t.active.q_values == 0.))
        self.assertTrue(np.all(dldq_tm1.active.q_values == 0.))
        self.assertTrue(np.all(dldq_t_target.active.q_values == 0.))


if __name__ == '__main__':
    config.update('jax_numpy_rank_promotion', 'raise')
    absltest.main()
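
Side note: the 'jax_numpy_rank_promotion' = 'raise' setting used above turns silent broadcasting across different ranks into an error. A minimal sketch (not from the project; recent JAX raises a ValueError here):

import jax
import jax.numpy as jnp

jax.config.update("jax_numpy_rank_promotion", "raise")

a = jnp.ones((3,))
b = jnp.ones((2, 3))
try:
    _ = a + b  # rank-1 + rank-2 needs implicit rank promotion
except ValueError as err:
    print("rank promotion disallowed:", err)
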
Code Example #3
File: debug_nans_test.py Project: gnecula/jax
 def setUp(self):
     self.cfg = config.read("jax_debug_infs")
     config.update("jax_debug_infs", True)
Code Example #4
File: run_atari.py Project: WendyShang/dqn_zoo
    log_output = [
        ('iteration', state.iteration, '%3d'),
        ('frame', state.iteration * FLAGS.num_train_frames, '%5d'),
        ('eval_episode_return', eval_stats['episode_return'], '% 2.2f'),
        ('train_episode_return', train_stats['episode_return'], '% 2.2f'),
        ('eval_num_episodes', eval_stats['num_episodes'], '%3d'),
        ('train_num_episodes', train_stats['num_episodes'], '%3d'),
        ('eval_frame_rate', eval_stats['step_rate'], '%4.0f'),
        ('train_frame_rate', train_stats['step_rate'], '%4.0f'),
        ('importance_sampling_exponent',
         train_agent.importance_sampling_exponent, '%.3f'),
        ('max_seen_priority', train_agent.max_seen_priority, '%.3f'),
        ('normalized_return', human_normalized_score, '%.3f'),
        ('capped_normalized_return', capped_human_normalized_score, '%.3f'),
        ('human_gap', 1. - capped_human_normalized_score, '%.3f'),
    ]
    log_output_str = ', '.join(('%s: ' + f) % (n, v) for n, v, f in log_output)
    logging.info(log_output_str)
    writer.write(collections.OrderedDict((n, v) for n, v, _ in log_output))
    state.iteration += 1
    checkpoint.save()

  writer.close()


if __name__ == '__main__':
  config.update('jax_platform_name', 'gpu')  # Default to GPU.
  config.update('jax_numpy_rank_promotion', 'raise')
  config.config_with_absl()
  app.run(main)
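
The (name, value, format) logging idiom above is easy to gloss over; here is a standalone illustration with made-up values:

log_output = [
    ('iteration', 3, '%3d'),
    ('eval_episode_return', 12.5, '% 2.2f'),
]
log_output_str = ', '.join(('%s: ' + f) % (n, v) for n, v, f in log_output)
print(log_output_str)  # iteration:   3, eval_episode_return:  12.50
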
Code Example #5
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensornetwork
import pytest
import numpy as np
import tensorflow as tf
from jax.config import config
import tensornetwork.config as config_file

config.update("jax_enable_x64", True)
tf.compat.v1.enable_v2_behavior()


def test_split_node_qr_disable(backend):
    net = tensornetwork.TensorNetwork(backend=backend)
    a = net.add_node(np.zeros((2, 3, 4, 5, 6)))
    left_edges = []
    for i in range(3):
        left_edges.append(a[i])
    right_edges = []
    for i in range(3, 5):
        right_edges.append(a[i])
    _, _ = net.split_node_qr(a, left_edges, right_edges)
    with pytest.raises(ValueError):
        a.edges[0]
Code Example #6
File: nn_test.py Project: Jakob-Unfried/jax
 def setUp(self):
     super().setUp()
     config.update("jax_numpy_rank_promotion", "raise")
Code Example #7
def main(argv):
    del argv

    if FLAGS.jax_debug_nans:
        config.update("jax_debug_nans", True)
    if FLAGS.use_tpu:
        config.update("jax_platform_name", "tpu")

    # TODO(afrozm): Refactor.
    if "NoFrameskip" in FLAGS.env_problem_name and FLAGS.xm:
        FLAGS.atari_roms_path = "local_ram_fs_tmp"
        atari_utils.copy_roms()

    # Make an env here.
    env = make_env(batch_size=FLAGS.batch_size)
    assert env

    eval_env = make_env(batch_size=FLAGS.eval_batch_size)
    assert eval_env

    def run_training_loop():
        """Runs the training loop."""
        policy_net_fun = None
        value_net_fun = None
        policy_and_value_net_fun = None
        policy_optimizer_fun = None
        value_optimizer_fun = None
        policy_and_value_optimizer_fun = None

        if FLAGS.combined_network:
            policy_and_value_net_fun = functools.partial(
                ppo.policy_and_value_net,
                bottom_layers_fn=common_layers,
                two_towers=FLAGS.two_towers)
            policy_and_value_optimizer_fun = get_optimizer_fun(
                FLAGS.learning_rate)
        else:
            policy_net_fun = functools.partial(ppo.policy_net,
                                               bottom_layers=common_layers())
            value_net_fun = functools.partial(ppo.value_net,
                                              bottom_layers=common_layers())
            policy_optimizer_fun = get_optimizer_fun(
                FLAGS.policy_only_learning_rate)
            value_optimizer_fun = get_optimizer_fun(
                FLAGS.value_only_learning_rate)

        random_seed = None
        try:
            random_seed = int(FLAGS.random_seed)
        except Exception:  # pylint: disable=broad-except
            pass

        ppo.training_loop(
            env=env,
            epochs=FLAGS.epochs,
            policy_net_fun=policy_net_fun,
            value_net_fun=value_net_fun,
            policy_and_value_net_fun=policy_and_value_net_fun,
            policy_optimizer_fun=policy_optimizer_fun,
            value_optimizer_fun=value_optimizer_fun,
            policy_and_value_optimizer_fun=policy_and_value_optimizer_fun,
            num_optimizer_steps=FLAGS.num_optimizer_steps,
            policy_only_num_optimizer_steps=FLAGS.policy_only_num_optimizer_steps,
            value_only_num_optimizer_steps=FLAGS.value_only_num_optimizer_steps,
            print_every_optimizer_steps=FLAGS.print_every_optimizer_steps,
            batch_size=FLAGS.batch_size,
            target_kl=FLAGS.target_kl,
            boundary=FLAGS.boundary,
            max_timestep=FLAGS.truncation_timestep,
            max_timestep_eval=FLAGS.truncation_timestep_eval,
            random_seed=random_seed,
            c1=FLAGS.value_coef,
            c2=FLAGS.entropy_coef,
            gamma=FLAGS.gamma,
            lambda_=FLAGS.lambda_,
            epsilon=FLAGS.epsilon,
            enable_early_stopping=FLAGS.enable_early_stopping,
            output_dir=FLAGS.output_dir,
            eval_every_n=FLAGS.eval_every_n,
            eval_env=eval_env)

    if FLAGS.jax_debug_nans or FLAGS.disable_jit:
        with jax.disable_jit():
            run_training_loop()
    else:
        run_training_loop()
Code Example #8
def main(_):
    if FLAGS.debug_nans:
        jax_config.update("jax_debug_nans", True)
    if FLAGS.spoof_multi_device:
        os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

    gin.parse_config_files_and_bindings(FLAGS.gin_configs, FLAGS.gin_bindings)

    # Optionally disable jit and pmap.
    if FLAGS.disable_jit:
        chex.fake_jit().__enter__()
        chex.fake_pmap().__enter__()

    # We clean up root_dir so that on preemption we start from scratch.
    if gfile.Exists(FLAGS.root_dir):
        logging.info('Logdir (%s) already exists, removing it...',
                     FLAGS.root_dir)
        gfile.DeleteRecursively(FLAGS.root_dir)

    create_env_fn = lambda: envs.create_environment(
        FLAGS.task_class, FLAGS.task_name, FLAGS.single_precision_env)
    create_data_iter_fn = (lambda: data.create_data_iter(
        FLAGS.task_class, FLAGS.task_name, FLAGS.batch_size))

    # for testing the data loader
    # batch = next(create_data_iter_fn())
    # import pdb; pdb.set_trace()
    # import time
    # data_iter = create_data_iter_fn()
    # for _ in range(50):
    #   begin_time = time.time()
    #   for _ in range(20):
    #     next(data_iter)
    #   print(f'\n\n{(time.time() - begin_time)/20.}\n\n')

    environment = create_env_fn()
    spec = specs.make_environment_spec(environment)
    # import time
    # import numpy as np
    # for _ in range(10):
    #   print('EPISODE')
    #   ep_begin = time.time()
    #   environment.reset()
    #   for _ in range(10):
    #     environment.step(np.zeros(shape=(12,)))
    #   ep_time = time.time() - ep_begin
    #   print('\n\n', ep_time)

    train_logger_factory = build_create_learner_logger(FLAGS.root_dir)
    rl_components = agents.create_agent(FLAGS.algorithm, spec,
                                        create_data_iter_fn,
                                        train_logger_factory)
    builder = rl_components.make_builder()

    counter = counting.Counter(time_delta=0.)
    is_offline_agent = not builder.make_replay_tables(spec)
    networks = rl_components.make_networks()
    seed = FLAGS.seed
    random_key = jax.random.PRNGKey(seed)
    learner_counter = counting.Counter(counter, 'learner', time_delta=0.)
    if is_offline_agent:
        random_key, sub_key = jax.random.split(random_key)
        learner = builder.make_learner(
            sub_key,
            networks,
            dataset=iter(()),  # dummy iterator
            counter=learner_counter)
        variable_source = learner
        train_loop = learner
    else:
        actor_counter_name = 'actor'
        steps_label = f'{actor_counter_name}_steps'

        agent = rl_components.make_agent(  # pytype: disable=attribute-error
            networks,
            learner_counter,
            FLAGS.seed,
        )
        variable_source = agent._learner

        actor_counter = counting.Counter(counter,
                                         actor_counter_name,
                                         time_delta=0.)
        train_loop_cls = acme.EnvironmentLoop
        train_loop = train_loop_cls(environment,
                                    agent,
                                    counter=actor_counter,
                                    logger=build_create_env_loop_logger(
                                        FLAGS.root_dir,
                                        label='train_env_loop',
                                        steps_key=steps_label)())

    random_key, sub_key = jax.random.split(random_key)
    all_eval_loops = []

    eval_actor = builder.make_actor(
        random_key=sub_key,
        policy_network=rl_components.make_eval_behavior_policy(networks),
        variable_source=variable_source)
    eval_env = create_env_fn()
    eval_counter = counting.Counter(counter, 'eval_loop', time_delta=0.)
    # eval_loop = acme.EnvironmentLoop(
    #     eval_env,
    #     eval_actor,
    #     counter=eval_counter,
    #     label='eval_loop',
    #     logger=create_eval_loop_logger())
    eval_loop = evaluation.EvaluatorStandardWithFinalRewardLogging(
        eval_actor=eval_actor,
        environment=eval_env,
        num_episodes=FLAGS.episodes_per_eval,
        counter=eval_counter,
        logger=build_create_env_loop_logger(FLAGS.root_dir,
                                            label='eval_loop')(),
        eval_sync=None,
        progress_counter_name='eval_actor_steps',
        min_steps_between_evals=None,
        self_cleanup=False)
    all_eval_loops.append(eval_loop)

    if FLAGS.eval_with_q_filter:
        random_key, sub_key = jax.random.split(random_key)

        # pytype: disable=attribute-error
        old_value = builder._config.eval_with_q_filter
        builder._config.eval_with_q_filter = True
        q_filter_eval_actor = builder.make_actor(
            random_key=sub_key,
            policy_network=rl_components.make_eval_behavior_policy(
                networks,
                force_eval_with_q_filter=True,
                q_filter_with_unif=True),
            variable_source=variable_source,
        )
        builder._config.eval_with_q_filter = old_value
        # pytype: enable=attribute-error
        q_filter_eval_env = create_env_fn()
        q_filter_eval_counter = counting.Counter(counter,
                                                 'q_filter_eval_loop',
                                                 time_delta=0.)
        q_filter_eval_loop = evaluation.EvaluatorStandardWithFinalRewardLogging(
            eval_actor=q_filter_eval_actor,
            environment=q_filter_eval_env,
            num_episodes=FLAGS.episodes_per_eval,
            counter=q_filter_eval_counter,
            logger=build_create_env_loop_logger(FLAGS.root_dir,
                                                label='q_filter_eval_loop')(),
            eval_sync=None,
            progress_counter_name='q_filter_eval_actor_steps',
            min_steps_between_evals=None,
            self_cleanup=False)
        all_eval_loops.append(q_filter_eval_loop)

        # pytype: disable=attribute-error
        old_value = builder._config.eval_with_q_filter
        builder._config.eval_with_q_filter = True
        q_filter_eval_actor = builder.make_actor(
            random_key=sub_key,
            policy_network=rl_components.make_eval_behavior_policy(
                networks,
                force_eval_with_q_filter=True,
                q_filter_with_unif=False),
            variable_source=variable_source,
        )
        builder._config.eval_with_q_filter = old_value
        # pytype: enable=attribute-error
        q_filter_eval_env = create_env_fn()
        q_filter_eval_counter = counting.Counter(counter,
                                                 'q_filter_eval_loop_no_unif',
                                                 time_delta=0.)
        q_filter_eval_loop = evaluation.EvaluatorStandardWithFinalRewardLogging(
            eval_actor=q_filter_eval_actor,
            environment=q_filter_eval_env,
            num_episodes=FLAGS.episodes_per_eval,
            counter=q_filter_eval_counter,
            logger=build_create_env_loop_logger(
                FLAGS.root_dir, label='q_filter_eval_loop_no_unif')(),
            eval_sync=None,
            progress_counter_name='q_filter_no_unif_eval_actor_steps',
            min_steps_between_evals=None,
            self_cleanup=False)
        all_eval_loops.append(q_filter_eval_loop)

    # Run the training loop interleaved with some evaluations.
    assert FLAGS.num_steps % FLAGS.eval_every_steps == 0
    num_iterations = FLAGS.num_steps // FLAGS.eval_every_steps

    for _ in range(num_iterations):
        train_loop.run(num_steps=FLAGS.eval_every_steps)
        for el in all_eval_loops:
            el.run_once()

    # Final eval at the end of training
    for el in all_eval_loops:
        el.run_once()

    # saved model policy
    if FLAGS.create_saved_model_actor:
        _select_action = rl_components.make_eval_behavior_policy(networks)

        def select_action(params, x):
            obs, rng_seed = x[0], x[1]
            rng = jax.random.PRNGKey(rng_seed)
            return _select_action(params, rng, obs)

        input_spec = (
            tf.TensorSpec(eval_env.reset().observation.shape,
                          dtype=tf.float32),
            tf.TensorSpec(shape=(), dtype=tf.int32),
        )
        saved_model_lib.convert_and_save_model(
            select_action,
            params=eval_actor._variable_client.params,  # pytype: disable=attribute-error
            model_dir=os.path.join(FLAGS.root_dir, 'saved_model'),
            input_signatures=[input_spec],
        )

    # Make sure to properly tear down the evaluators.
    # For e.g. to flush the files recording the episodes.
    for el in all_eval_loops:
        if hasattr(el, 'tear_down'):
            el.tear_down()
Code Example #9
def setUpModule():
    jax_config.update('jax_enable_x64', True)
Code Example #10
def debugging():
    config.update("jax_enable_x64", True)
    config.update("jax_debug_nans", True)
Code Example #11
def startup():
    # Jax config (needs to be executed right at startup)
    config.update("jax_enable_x64", True)
Code Example #12
def pytest_runtest_setup(item):
    config.update('jax_platform_name', 'cpu')
    if 'JAX_ENABLE_x64' in os.environ:
        config.update('jax_enable_x64', True)
    set_rng_seed(0)
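
Pinning "jax_platform_name" to "cpu" keeps test runs deterministic across machines, while x64 stays opt-in via an environment variable. The same hook reduced to a standalone conftest.py sketch (np.random.seed is a stand-in for the project's set_rng_seed):

# conftest.py (sketch)
import os

import numpy as np
from jax import config


def pytest_runtest_setup(item):
    config.update("jax_platform_name", "cpu")  # run every test on CPU
    if "JAX_ENABLE_x64" in os.environ:         # opt-in double precision
        config.update("jax_enable_x64", True)
    np.random.seed(0)                          # stand-in for set_rng_seed(0)
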
Code Example #13
def main():
    parser = argparse.ArgumentParser(description="Trains VAE-GPLVM model")
    parser.add_argument(
        "-p",
        "--predict",
        help=
        "run predictions (0 - only train, 1 - train and predict, 2 - only predict)",
        action="count",
        default=0)
    parser.add_argument("-d",
                        "--debug",
                        help="start TF debug",
                        action="store_true")
    group = parser.add_mutually_exclusive_group()
    group.add_argument("-r",
                       "--restore",
                       help="restore checkpoint",
                       action="store_true")
    group.add_argument("-e",
                       "--erase",
                       help="erase summary dirs before training",
                       action="store_true")
    args = parser.parse_args()

    if args.debug:
        cfig.update("jax_debug_nans", True)

    # Time-series datasets
    # datasets = ["toy_dataset_linear_combination", "toy_dataset_linear_combination_different_rotations",
    #             "toy_dataset_linear_combination_with_softplus", "toy_dataset_expanding_circle",
    #             "toy_dataset_rotating_rectangle", "toy_dataset_sine_draw",
    #             "toy_dataset_1d_chirp",
    #             "eeg_eye_state", "exchange_rate",
    #             "missile_to_air"]
    datasets = ["toy_dataset_1d_chirp"]

    model_name = "gp_sde"

    constant_diffusion = False
    constant_diffusion_legend = {
        False: "varying_diffusion",
        True: "constant_diffusion"
    }

    # time_dependent_drifts = [False, True]
    time_dependent_drifts = [False]
    time_dependent_drift_legend = {
        False: "time_independent_drift",
        True: "time_dependent_drift"
    }

    for time_dependent_drift in time_dependent_drifts:
        for dataset in datasets:
            parent_folder = f"results/{model_name}/"
            parent_folder += f"{time_dependent_drift_legend[time_dependent_drift]}/"
            parent_folder += f"{constant_diffusion_legend[constant_diffusion]}/{dataset}"

            model = None
            trainer = None
            predictor = None

            if model_name == "gp_sde":
                model = models.GPSDE
                trainer = trainers.GPSDETrainer
                predictor = predictors.GPSDEPredictor

            config = cd.config_dict(parent_folder, dataset)
            config["time_dependent_gp"] = time_dependent_drift

            config["constant_diffusion"] = constant_diffusion

            if not constant_diffusion:
                config["solver"] = "EulerMaruyamaSolver"

            # Create the experiment dirs
            if args.erase:
                if os.path.exists(config["summary_dir"]):
                    shutil.rmtree(config["summary_dir"], ignore_errors=True)
                    shutil.rmtree(config["checkpoint_dir"], ignore_errors=True)

            config["results_dir"] += f"{config['delta_t']}deltaT_{config['num_steps']}tsteps"
            bm.create_dirs([
                config["summary_dir"], config["checkpoint_dir"],
                config["results_dir"]
            ])

            with open(os.path.join(config["summary_dir"], "config_dict.txt"),
                      'w') as f:
                f.write(json.dumps(config))

            data = None
            if dataset == "toy_dataset_sine_draw":
                data = dg.DataGeneratorToyDatasetSineDraw(config)
            elif dataset == "toy_dataset_1d_chirp":
                data = dg.DataGeneratorToyDataset1DChirp(config)
            elif dataset == "toy_dataset_expanding_circle":
                data = dg.DataGeneratorToyDatasetExpandingCircle(config)
            elif dataset == "toy_dataset_rotating_rectangle":
                data = dg.DataGeneratorToyDatasetRotatingRectangle(config)
            elif dataset == "toy_dataset_linear_combination":
                data = dg.DataGeneratorToyDatasetLinearCombination(config)
            elif dataset == "toy_dataset_linear_combination_different_rotations":
                data = dg.DataGeneratorToyDatasetLinearCombinationDifferentRotations(
                    config)
            elif dataset == "toy_dataset_linear_combination_with_softplus":
                data = dg.DataGeneratorToyDatasetLinearCombinationWithSoftplus(
                    config)
            elif dataset == "eeg_eye_state":
                data = dg.DataGeneratorEEGEyeState(config)
            elif dataset == "exchange_rate":
                data = dg.DataGeneratorExchangeRate(config)
            elif dataset == "missile_to_air":
                data = dg.DataGeneratorMissile2Air(config)

            run(config, model, trainer, predictor, data, args)
Code Example #14
# -*- coding: utf-8 -*-
import jax
from jax.config import config
config.update("jax_enable_x64", True)  # use 64-bit floating point numbers
import jax.numpy as npx


def create_tensor(x):
    return npx.array(x, dtype=npx.float64)

def affine(x, w, b):
    return npx.dot(w, x) + b

def softmax(x):
    return npx.exp(x) / npx.sum(npx.exp(x), axis=0)

def normal_distribution(x, mean, variance):
    return npx.exp(-.5 * npx.square(x - mean) / variance) / npx.sqrt(2. * npx.pi * variance)

def log_normal_distribution(x, mean, variance):
    # log of the density above: -(x - mean)^2 / (2 * variance) - 0.5 * log(2 * pi * variance)
    return -.5 * npx.square(x - mean) / variance - .5 * npx.log(2. * npx.pi * variance)
Code Example #15
import os
import argparse
import time
import datetime
import itertools
import numpy.random as onpr

import jax
import jax.numpy as jnp
from jax import jit, vmap, random
from jax import value_and_grad, grad, jacfwd, jacrev
# from jax.experimental import optimizers

from jax.config import config
config.update("jax_debug_nans", True)
jax.config.update('jax_enable_x64', True)

from flax.core.frozen_dict import freeze, unfreeze, FrozenDict
from flax import serialization, jax_utils
from flax import linen as nn
from flax import optim

#ravh
from data import *
from flax_mlp import *

# mase to jax
# from monomials import f_monomials as f_mono
# from polynomials import f_polynomials as f_poly

Ha2cm = 220000
Code Example #16
File: rl_trainer.py Project: zongdaofu/trax
def train_rl(
    output_dir,
    train_batch_size,
    eval_batch_size,
    env_name='Acrobot-v1',
    max_timestep=None,
    clip_rewards=False,
    rendered_env=False,
    resize=False,
    resize_dims=(105, 80),
    trainer_class=rl_trainers.PPO,
    n_epochs=10000,
    trajectory_dump_dir=None,
    num_actions=None,
):
    """Train the RL agent.

  Args:
    output_dir: Output directory.
    train_batch_size: Number of parallel environments to use for training.
    eval_batch_size: Number of parallel environments to use for evaluation.
    env_name: Name of the environment.
    max_timestep: Int or None, the maximum number of timesteps in a trajectory.
      The environment is wrapped in a TimeLimit wrapper.
    clip_rewards: Whether to clip and discretize the rewards.
    rendered_env: Whether the environment has visual input. If so, a
      RenderedEnvProblem will be used.
    resize: Whether to resize the visual observations.
    resize_dims: Pair (height, width), dimensions to resize the visual
      observations to.
    trainer_class: RLTrainer class to use.
    n_epochs: Number of epochs to run the training for.
    trajectory_dump_dir: Directory to dump trajectories to.
    num_actions: None unless one wants to use the discretization wrapper. Then
      num_actions specifies the number of discrete actions.
  """

    if FLAGS.jax_debug_nans:
        config.update('jax_debug_nans', True)

    if FLAGS.use_tpu:
        config.update('jax_platform_name', 'tpu')
    else:
        config.update('jax_platform_name', 'gpu')

    # TODO(pkozakowski): Find a better way to determine this.
    train_env_kwargs = {}
    eval_env_kwargs = {}
    if 'OnlineTuneEnv' in env_name:
        envs_output_dir = FLAGS.envs_output_dir or os.path.join(
            output_dir, 'envs')
        train_env_output_dir = os.path.join(envs_output_dir, 'train')
        eval_env_output_dir = os.path.join(envs_output_dir, 'eval')
        train_env_kwargs = {'output_dir': train_env_output_dir}
        eval_env_kwargs = {'output_dir': eval_env_output_dir}

    parallelism = multiprocessing.cpu_count() if FLAGS.parallelize_envs else 1

    logging.info('Num discretized actions %s', num_actions)
    logging.info('Resize %d', resize)

    train_env = env_problem_utils.make_env(batch_size=train_batch_size,
                                           env_problem_name=env_name,
                                           rendered_env=rendered_env,
                                           resize=resize,
                                           resize_dims=resize_dims,
                                           max_timestep=max_timestep,
                                           clip_rewards=clip_rewards,
                                           parallelism=parallelism,
                                           use_tpu=FLAGS.use_tpu,
                                           num_actions=num_actions,
                                           **train_env_kwargs)
    assert train_env

    eval_env = env_problem_utils.make_env(batch_size=eval_batch_size,
                                          env_problem_name=env_name,
                                          rendered_env=rendered_env,
                                          resize=resize,
                                          resize_dims=resize_dims,
                                          max_timestep=max_timestep,
                                          clip_rewards=clip_rewards,
                                          parallelism=parallelism,
                                          use_tpu=FLAGS.use_tpu,
                                          num_actions=num_actions,
                                          **eval_env_kwargs)
    assert eval_env

    def run_training_loop():
        """Runs the training loop."""
        logging.info('Starting the training loop.')

        trainer = trainer_class(
            output_dir=output_dir,
            train_env=train_env,
            eval_env=eval_env,
            trajectory_dump_dir=trajectory_dump_dir,
            async_mode=FLAGS.async_mode,
        )
        trainer.training_loop(n_epochs=n_epochs)

    if FLAGS.jax_debug_nans or FLAGS.disable_jit:
        with jax.disable_jit():
            run_training_loop()
    else:
        run_training_loop()
Code Example #17
    "num_warmup": 3000,  #def: 3000
    "num_data": 100,  #def: 100
    "num_hidden": 10,  #def: 10
    "device": 'cpu',  #def: cpu
    "save_directory": "./results",
}

# PREPARE TO SAVE RESULTS
try:
    os.stat(args["save_directory"])
except:
    os.mkdir(args["save_directory"])

sigmas = [1 / 5, 1, 5]  # def: [1/5, 1, 5]

jax_config.update('jax_platform_name', args["device"])
N, D_X, D_H = args["num_data"], 1, args["num_hidden"]

# GENERATE ARTIFICIAL DATA
X, Y, X_test = get_data(functions, ranges, num_samples=500)
mean = X.mean()
X = X / mean
X, Y = shuffle(X, Y)

X_test = onp.arange(0, 2, 0.01).reshape(-1, 1)

# PLOTTING
plt.cla()  # Clear axis
plt.clf()  # Clear figure
plt.close()  # Close a figure window
# make plots
Code Example #18
File: __init__.py Project: romanodev/deltapv
import os
from jax.config import config
config.update("jax_enable_x64", True)
if os.environ.get("DEBUGNANS") == "TRUE":
    config.update("jax_debug_nans", True)
if os.environ.get("NOJIT") == "TRUE":
    config.update('jax_disable_jit', True)

import logging
logging.basicConfig(format="")
logger = logging.getLogger("deltapv")
logger.setLevel("INFO")

from deltapv import simulator, materials, plotting, objects, spline, physics, util
from deltapv.simulator import make_design, incident_light, equilibrium, simulate, eff_at_bias, empty_design, add_material, doping, contacts
from deltapv.materials import create_material, load_material
from deltapv.plotting import plot_band_diagram, plot_bars, plot_charge, plot_iv_curve

util.print_ascii()
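
The environment-variable gating above is a useful pattern on its own. A self-contained reduction (the variable names MYPKG_DEBUGNANS and MYPKG_NOJIT are made up for illustration):

import os

import jax

jax.config.update("jax_enable_x64", True)
if os.environ.get("MYPKG_DEBUGNANS", "").upper() == "TRUE":
    jax.config.update("jax_debug_nans", True)   # re-run jitted code op-by-op on NaN
if os.environ.get("MYPKG_NOJIT", "").upper() == "TRUE":
    jax.config.update("jax_disable_jit", True)  # easier to step through with a debugger
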
Code Example #19
File: nn_test.py Project: Jakob-Unfried/jax
 def tearDown(self):
     super().tearDown()
     config.update("jax_numpy_rank_promotion", "allow")
Code Example #20
 def setUp(self):
     super().setUp()
     jax_config.update('jax_enable_x64', False)
Code Example #21
File: debug_nans_test.py Project: yashk2810/jax
 def setUp(self):
   self.cfg = config._read("jax_debug_nans")
   config.update("jax_debug_nans", True)
Code Example #22
import numpy as np  # np.array / np.diag are used below
from helpers import col_vec, suppress_stdout_stderr
from pathlib import Path
import pickle
from tqdm import tqdm

# jax related imports
import jax.numpy as jnp
from jax import grad, jit, jacfwd, jacrev
from jax.lax import scan
from jax.ops import index, index_update
from jax.config import config

# optimisation module imports (needs to be done before the jax config update)
from optimisation import solve_chance_logbarrier, log_barrier_cosine_cost

config.update("jax_enable_x64", True)  # run jax in 64 bit mode for accuracy

# Control parameters
z_star = np.array([[0], [np.pi], [0.0], [0.0]],
                  dtype=float)  # desired set point in z1
Ns = 200  # number of samples we will use for MC MPC
Nh = 25  # horizon length of MPC algorithm
sqc_v = np.array([1, 30., 1e-5, 1e-5], dtype=float)  # cost on state error
sqc = np.diag(sqc_v)
src = np.array([[0.001]])

# define the state constraints, (these need to be tuples)
state_bound = 0.75 * np.pi
input_bound = 18.0
state_constraints = (lambda z: state_bound - z[[0], :, :],
                     lambda z: z[[0], :, :] + state_bound)
Code Example #23
from functools import partial

import numpy as np

from jax import random, vmap
from jax.config import config
config.update("jax_platform_name", "cpu")

import numpyro.handlers as handler
from numpyro.mcmc import MCMC, NUTS

from brmp.backend import Backend, Model, apply_default_hmc_args
from brmp.fit import Samples
from brmp.numpyro_codegen import gen

# The types described in the comments in pyro_backend.py as follows
# in this back end:
#
# bs: dict from parameter names to JAX numpy arrays
# ps: JAX numpy array


def get_param(samples, name):
    assert type(samples) == dict
    # Reminder to use correct interface.
    assert not name == 'mu', 'Use `location` to fetch `mu`.'
    return samples[name]


# Extract the underlying numpy array (rather than using JAX numpy) to
# match the interface exactly.
Code Example #24
File: test_handlers.py Project: schmolly/timemachine
from jax.config import config
config.update("jax_enable_x64", True)
import jax

import pytest
import numpy as np

from rdkit import Chem
from rdkit.Chem import AllChem
from ff.handlers import nonbonded, bonded
from ff.handlers.deserialize import deserialize_handlers

import functools


def test_harmonic_bond():

    patterns = [
        ['[#6X4:1]-[#6X4:2]', 0.1, 0.2],
        ['[#6X4:1]-[#6X3:2]', 99., 99.],
        ['[#6X4:1]-[#6X3:2]=[#8X1+0]', 99., 99.],
        ['[#6X3:1]-[#6X3:2]', 99., 99.],
        ['[#6X3:1]:[#6X3:2]', 99., 99.],
        ['[#6X3:1]=[#6X3:2]', 99., 99.],
        ['[#6:1]-[#7:2]',0.1, 0.2],
        ['[#6X3:1]-[#7X3:2]', 99., 99.],
        ['[#6X4:1]-[#7X3:2]-[#6X3]=[#8X1+0]', 99., 99.],
        ['[#6X3:1](=[#8X1+0])-[#7X3:2]', 99., 99.],
        ['[#6X3:1]-[#7X2:2]', 99., 99.],
        ['[#6X3:1]:[#7X2,#7X3+1:2]', 99., 99.],
        ['[#6X3:1]=[#7X2,#7X3+1:2]', 99., 99.],
Code Example #25
File: jax_backend.py Project: gradhep/pyhf
from jax.config import config

config.update('jax_enable_x64', True)

import jax.numpy as jnp
from jax.scipy.special import gammaln
from jax.scipy import special
from jax.scipy.stats import norm
import numpy as np
import scipy.stats as osp_stats
import logging

log = logging.getLogger(__name__)


class _BasicPoisson:
    def __init__(self, rate):
        self.rate = rate

    def sample(self, sample_shape):
        # TODO: Support other dtypes
        return jnp.asarray(
            osp_stats.poisson(self.rate).rvs(size=sample_shape +
                                             self.rate.shape),
            dtype=jnp.float64,
        )

    def log_prob(self, value):
        tensorlib = jax_backend()
        return tensorlib.poisson_logpdf(value, self.rate)
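
Why the backend enables x64 up front: without it, asking JAX for float64 silently yields float32 arrays. A small sketch mirroring the sample() method above (not part of pyhf):

import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import numpy as np
import scipy.stats as osp_stats

rate = jnp.asarray([5.0])
sample = jnp.asarray(
    osp_stats.poisson(np.asarray(rate)).rvs(size=(3,) + rate.shape),
    dtype=jnp.float64,
)
print(sample.dtype)  # float64 only because jax_enable_x64 is enabled
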
Code Example #26
File: rl_trainer.py Project: wintersurvival/trax
def train_rl(
    output_dir,
    train_batch_size,
    eval_batch_size,
    env_name='Acrobot-v1',
    max_timestep=None,
    clip_rewards=False,
    rendered_env=False,
    resize=False,
    resize_dims=(105, 80),
    trainer_class=None,
    n_epochs=10000,
    trajectory_dump_dir=None,
    num_actions=None,
    light_rl=True,
    light_rl_trainer=light_trainers.PolicyGradient,
):
    """Train the RL agent.

  Args:
    output_dir: Output directory.
    train_batch_size: Number of parallel environments to use for training.
    eval_batch_size: Number of parallel environments to use for evaluation.
    env_name: Name of the environment.
    max_timestep: Int or None, the maximum number of timesteps in a trajectory.
      The environment is wrapped in a TimeLimit wrapper.
    clip_rewards: Whether to clip and discretize the rewards.
    rendered_env: Whether the environment has visual input. If so, a
      RenderedEnvProblem will be used.
    resize: Whether to resize the visual observations.
    resize_dims: Pair (height, width), dimensions to resize the visual
      observations to.
    trainer_class: RLTrainer class to use.
    n_epochs: Number of epochs to run the training for.
    trajectory_dump_dir: Directory to dump trajectories to.
    num_actions: None unless one wants to use the discretization wrapper. Then
      num_actions specifies the number of discrete actions.
    light_rl: whether to use the light RL setting (experimental).
    light_rl_trainer: which light RL trainer to use (experimental).
  """
    tf_np.set_allow_float64(FLAGS.tf_allow_float64)

    if light_rl:
        task = rl_task.RLTask()
        env_name = task.env_name
    else:
        # TODO(lukaszkaiser): remove the name light and all references.
        # It was kept for now to make sure all regression tests pass first,
        # so that if we need to revert we save some work.
        raise ValueError('Non-light RL is deprecated.')

    if FLAGS.jax_debug_nans:
        config.update('jax_debug_nans', True)

    if FLAGS.use_tpu:
        config.update('jax_platform_name', 'tpu')
    else:
        config.update('jax_platform_name', '')

    if light_rl:
        trainer = light_rl_trainer(task=task, output_dir=output_dir)

        def light_training_loop():
            """Run the trainer for n_epochs and call close on it."""
            try:
                logging.info('Starting RL training for %d epochs.', n_epochs)
                trainer.run(n_epochs, n_epochs_is_total_epochs=True)
                logging.info('Completed RL training for %d epochs.', n_epochs)
                trainer.close()
                logging.info('Trainer is now closed.')
            except Exception as e:
                raise e
            finally:
                logging.info(
                    'Encountered an exception, still calling trainer.close()')
                trainer.close()
                logging.info('Trainer is now closed.')

        if FLAGS.jax_debug_nans or FLAGS.disable_jit:
            fastmath.disable_jit()
            with jax.disable_jit():
                light_training_loop()
        else:
            light_training_loop()
        return

    # TODO(pkozakowski): Find a better way to determine this.
    train_env_kwargs = {}
    eval_env_kwargs = {}
    if 'OnlineTuneEnv' in env_name:
        envs_output_dir = FLAGS.envs_output_dir or os.path.join(
            output_dir, 'envs')
        train_env_output_dir = os.path.join(envs_output_dir, 'train')
        eval_env_output_dir = os.path.join(envs_output_dir, 'eval')
        train_env_kwargs = {'output_dir': train_env_output_dir}
        eval_env_kwargs = {'output_dir': eval_env_output_dir}

    parallelism = multiprocessing.cpu_count() if FLAGS.parallelize_envs else 1

    logging.info('Num discretized actions %s', num_actions)
    logging.info('Resize %d', resize)

    train_env = env_problem_utils.make_env(batch_size=train_batch_size,
                                           env_problem_name=env_name,
                                           rendered_env=rendered_env,
                                           resize=resize,
                                           resize_dims=resize_dims,
                                           max_timestep=max_timestep,
                                           clip_rewards=clip_rewards,
                                           parallelism=parallelism,
                                           use_tpu=FLAGS.use_tpu,
                                           num_actions=num_actions,
                                           **train_env_kwargs)
    assert train_env

    eval_env = env_problem_utils.make_env(batch_size=eval_batch_size,
                                          env_problem_name=env_name,
                                          rendered_env=rendered_env,
                                          resize=resize,
                                          resize_dims=resize_dims,
                                          max_timestep=max_timestep,
                                          clip_rewards=clip_rewards,
                                          parallelism=parallelism,
                                          use_tpu=FLAGS.use_tpu,
                                          num_actions=num_actions,
                                          **eval_env_kwargs)
    assert eval_env

    def run_training_loop():
        """Runs the training loop."""
        logging.info('Starting the training loop.')

        trainer = trainer_class(
            output_dir=output_dir,
            train_env=train_env,
            eval_env=eval_env,
            trajectory_dump_dir=trajectory_dump_dir,
            async_mode=FLAGS.async_mode,
        )
        trainer.training_loop(n_epochs=n_epochs)

    if FLAGS.jax_debug_nans or FLAGS.disable_jit:
        fastmath.disable_jit()
        with jax.disable_jit():
            run_training_loop()
    else:
        run_training_loop()
Code Example #27
File: 2d_rbig_demo.py Project: IPL-UV/rbig_jax
# library functions
from rbig_jax.data import generate_2d_grid, get_classic
from rbig_jax.plots import plot_info_loss, plot_joint, plot_joint_prob
from rbig_jax.transforms.block import get_default_rbig_block

# spyder up to find the root
root = here(project_files=[".here"])

# append to path
sys.path.append(str(here()))

# import chex
config.update("jax_enable_x64", False)

sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)

# ==========================
# PARAMETERS
# ==========================

parser = ArgumentParser(
    description="2D Data Demo with Iterative Gaussianization method"
Code Example #28
File: autodiffjax.py Project: ronan-keane/havsim
import numpy as np  #note using np and scipy functions shouldn't work; you have to use the jax versions

import jax.numpy as jnp

import jax

from jax.ops import *

from jax import jit
import pickle
from jax import lax
import copy

from jax.config import config

config.update("jax_enable_x64",
              True)  #if you want float 64 in jax this is the command


#%% examples
def intdiv(x, n):
    #integer divide x by n
    out = x / n
    out = jnp.floor(out)
    return out


def eg1(x, *args):
    #    x[0] = x[0] // 2
    #    x = index_update(x,0,intdiv(x[0],2))
    x[0] = x[0] // 2
    return jnp.tanh(x[0]**2)
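
The commented-out index_update lines hint at the underlying issue: JAX arrays are immutable, so item assignment such as x[0] = ... fails on them. A minimal sketch (not from the project) using the .at[].set() API that replaced jax.ops.index_update in current JAX:

import jax.numpy as jnp

x = jnp.array([5.0, 1.0, 2.0])
x = x.at[0].set(jnp.floor(x[0] / 2))  # functional stand-in for x[0] = x[0] // 2
print(x)  # [2. 1. 2.]
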
Code Example #29
File: debug_nans_test.py Project: gnecula/jax
 def tearDown(self):
     config.update("jax_debug_infs", self.cfg)
Code Example #30
def main(_):
    if FLAGS.jax_backend_target:
        logging.info("Using JAX backend target %s", FLAGS.jax_backend_target)
        jax_config.update("jax_xla_backend", "tpu_driver")
        jax_config.update("jax_backend_target", FLAGS.jax_backend_target)

    logging.info("JAX host: %d / %d", jax.host_id(), jax.host_count())
    logging.info("JAX local devices: %r", jax.local_devices())

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(FLAGS.model_dir)
        # summary_writer.hparams(dict(FLAGS.config))

    rng = random.PRNGKey(FLAGS.seed)
    rng, init_rng_coarse, init_rng_fine = random.split(rng, 3)
    n_devices = jax.device_count()

    ### Load dataset and data values
    if FLAGS.config.dataset_type == "blender":
        images, poses, render_poses, hwf, counts = load_blender.load_data(
            FLAGS.data_dir,
            half_res=FLAGS.config.half_res,
            testskip=FLAGS.config.testskip,
        )
        logging.info("Loaded blender, total images: %d", images.shape[0])

        near = 2.0
        far = 6.0

        if FLAGS.config.white_bkgd:
            images = images[..., :3] * images[..., -1:] + (1.0 - images[..., -1:])
        else:
            images = images[..., :3]

    elif FLAGS.config.dataset_type == "deepvoxels":
        images, poses, render_poses, hwf, counts = load_deepvoxels.load_dv_data(
            FLAGS.data_dir,
            scene=FLAGS.config.shape,
            testskip=FLAGS.config.testskip,
        )
        hemi_R = np.mean(np.linalg.norm(poses[:, :3, -1], axis=-1))
        near = hemi_R - 1.0
        far = hemi_R + 1.0
        logging.info(
            "Loaded deepvoxels (%s), total images: %d",
            FLAGS.config.shape,
            images.shape[0],
        )
    else:
        raise ValueError(f"Dataset '{FLAGS.config.dataset_type}' is not available.")

    img_h, img_w, focal = hwf
    logging.info("Images splits: %s", counts)
    logging.info("Render poses: %s", render_poses.shape)
    logging.info("Image height: %d, image width: %d, focal: %.5f", img_h, img_w, focal)

    train_imgs, val_imgs, test_imgs, *_ = np.split(images, np.cumsum(counts))
    train_poses, val_poses, test_poses, *_ = np.split(poses, np.cumsum(counts))

    if FLAGS.config.render_factor > 0:
        # render downsampled for speed
        r_img_h = img_h // FLAGS.config.render_factor
        r_img_w = img_w // FLAGS.config.render_factor
        r_focal = focal / FLAGS.config.render_factor
        r_hwf = r_img_h, r_img_w, r_focal
    else:
        r_hwf = hwf

    to_np = lambda x, h=img_h, w=img_w: np.reshape(x, [h, w, -1]).astype(np.float32)
    psnr_fn = lambda x: -10.0 * np.log(x) / np.log(10.0)

    ### Pre-compute rays
    @functools.partial(jax.jit, static_argnums=(0,))
    def prep_rays(hwf, c2w, c2w_sc=None):
        if c2w_sc is not None:
            c2w_sc = c2w_sc[:3, :4]
        return prepare_rays(None, hwf, FLAGS.config, near, far, c2w[:3, :4], c2w_sc)

    rays_render = lax.map(lambda x: prep_rays(r_hwf, x), render_poses)
    render_shape = [-1, n_devices, r_hwf[1], rays_render.shape[-1]]
    rays_render = jnp.reshape(rays_render, render_shape)
    logging.info("Render rays shape: %s", rays_render.shape)

    if FLAGS.config.use_viewdirs:
        rays_render_vdirs = lax.map(
            lambda x: prep_rays(r_hwf, x, render_poses[0]), render_poses
        ).reshape(render_shape)

    if FLAGS.config.batching:
        train_rays = lax.map(lambda pose: prep_rays(hwf, pose), train_poses)
        train_rays = jnp.reshape(train_rays, [-1, train_rays.shape[-1]])
        train_imgs = jnp.reshape(train_imgs, [-1, 3])
        logging.info("Batched rays shape: %s", train_rays.shape)
        val_rays = lax.map(lambda pose: prep_rays(hwf, pose), val_poses)

    test_rays = lax.map(lambda pose: prep_rays(r_hwf, pose), test_poses)
    test_rays = jnp.reshape(test_rays, render_shape)

    ### Init model parameters and optimizer
    input_pts_shape = (FLAGS.config.num_rand, FLAGS.config.num_samples, 3)
    input_views_shape = (FLAGS.config.num_rand, 3)
    model_coarse, params_coarse = initialized(
        init_rng_coarse, input_pts_shape, input_views_shape, FLAGS.config.model
    )

    optimizer = optim.Adam()
    state = TrainState(
        step=0, optimizer_coarse=optimizer.create(params_coarse), optimizer_fine=None
    )
    model_fn = (model_coarse.apply, None)

    if FLAGS.config.num_importance > 0:
        input_pts_shape = (
            FLAGS.config.num_rand,
            FLAGS.config.num_importance + FLAGS.config.num_samples,
            3,
        )
        model_fine, params_fine = initialized(
            init_rng_fine, input_pts_shape, input_views_shape, FLAGS.config.model_fine
        )
        state = state.replace(optimizer_fine=optimizer.create(params_fine))
        model_fn = (model_coarse.apply, model_fine.apply)

    state = checkpoints.restore_checkpoint(FLAGS.model_dir, state)
    start_step = int(state.step)
    state = jax_utils.replicate(state)

    ### Build 'pmapped' functions for distributed training
    learning_rate_fn = create_learning_rate_scheduler(
        factors=FLAGS.config.lr_schedule,
        base_learning_rate=FLAGS.config.learning_rate,
        decay_factor=FLAGS.config.decay_factor,
        steps_per_decay=FLAGS.config.lr_decay * 1000,
    )
    p_train_step = jax.pmap(
        functools.partial(
            train_step,
            model_fn,
            FLAGS.config,
            learning_rate_fn,
            (hwf, near, far),
        ),
        axis_name="batch",
        donate_argnums=(0,),
    )
    p_eval_step = jax.pmap(
        functools.partial(eval_step, model_fn, FLAGS.config),
        axis_name="batch",
    )

    t = time.time()
    train_metrics = []

    for step in range(start_step, FLAGS.config.num_steps + 1):
        rng, sample_rng, step_rng, test_rng = random.split(rng, 4)
        sharded_rngs = common_utils.shard_prng_key(step_rng)
        coords = None

        if FLAGS.config.batching:
            select_idx = random.randint(
                sample_rng,
                [n_devices * FLAGS.config.num_rand],
                minval=0,
                maxval=train_rays.shape[0],
            )
            inputs = train_rays[select_idx, ...]
            inputs = jnp.reshape(inputs, [n_devices, FLAGS.config.num_rand, -1])
            target = train_imgs[select_idx, ...]
            target = jnp.reshape(target, [n_devices, FLAGS.config.num_rand, 3])
        else:
            img_idx = random.randint(
                sample_rng, [n_devices], minval=0, maxval=counts[0]
            )
            inputs = train_poses[img_idx, ...]  # [n_devices, 4, 4]
            target = train_imgs[img_idx, ...]  # [n_devices, img_h, img_w, 3]

            if step < FLAGS.config.precrop_iters:
                dH = int(img_h // 2 * FLAGS.config.precrop_frac)
                dW = int(img_w // 2 * FLAGS.config.precrop_frac)
                coords = jnp.meshgrid(
                    jnp.arange(img_h // 2 - dH, img_h // 2 + dH),
                    jnp.arange(img_w // 2 - dW, img_w // 2 + dW),
                    indexing="ij",
                )
                coords = jax_utils.replicate(
                    jnp.stack(coords, axis=-1).reshape([-1, 2])
                )

        state, metrics, coarse_res, fine_res = p_train_step(
            state, (inputs, target), coords, rng=sharded_rngs
        )
        train_metrics.append(metrics)

        ### Write summaries to TB
        if step % FLAGS.config.i_print == 0 and step > 0:
            steps_per_sec = time.time() - t
            train_metrics = common_utils.get_metrics(train_metrics)
            train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)
            if jax.host_id() == 0:
                logging.info(
                    "Step: %6d, %.3f s/step, loss %.5f, psnr %6.3f",
                    step,
                    steps_per_sec,
                    train_summary["loss"],
                    train_summary["psnr"],
                )
                for key, val in train_summary.items():
                    summary_writer.scalar(f"train/{key}", val, step)

                summary_writer.scalar("steps per second", steps_per_sec, step)
                summary_writer.histogram("raw_c", np.array(coarse_res["raw"]), step)
                if FLAGS.config.num_importance > 0:
                    summary_writer.histogram("raw_f", np.array(fine_res["raw"]), step)
            train_metrics = []

            ### Eval a random validation image and plot it in TB
            if step % FLAGS.config.i_img == 0:
                val_idx = random.randint(test_rng, [1], minval=0, maxval=counts[1])
                if FLAGS.config.batching:
                    inputs = val_rays[tuple(val_idx)].reshape(render_shape)
                else:
                    inputs = prep_rays(hwf, val_poses[tuple(val_idx)])
                    inputs = jnp.reshape(inputs, render_shape)
                target = val_imgs[tuple(val_idx)]
                preds, preds_c, z_std = lax.map(lambda x: p_eval_step(state, x), inputs)
                rgb = to_np(preds["rgb"])
                loss = np.mean((rgb - target) ** 2)

                summary_writer.scalar(f"val/loss", loss, step)
                summary_writer.scalar(f"val/psnr", psnr_fn(loss), step)

                rgb = 255 * np.clip(rgb, 0, 1)
                summary_writer.image("val/rgb", rgb.astype(np.uint8), step)
                summary_writer.image("val/target", target, step)
                summary_writer.image("val/disp", to_np(preds["disp"]), step)
                summary_writer.image("val/acc", to_np(preds["acc"]), step)

                if FLAGS.config.num_importance > 0:
                    rgb = 255 * np.clip(to_np(preds_c["rgb"]), 0, 1)
                    summary_writer.image("val/rgb_c", rgb.astype(np.uint8), step)
                    summary_writer.image("val/disp_c", to_np(preds_c["disp"]), step)
                    summary_writer.image("val/z_std", to_np(z_std), step)

        ### Render a video with test poses
        if step % FLAGS.config.i_video == 0 and step > 0:
            logging.info("Rendering video at step %d", step)
            t = time.time()
            preds, *_ = lax.map(lambda x: p_eval_step(state, x), rays_render)
            gen_video(preds["rgb"], "rgb", r_hwf, step)
            gen_video(preds["disp"] / jnp.max(preds["disp"]), "disp", r_hwf, step, ch=1)

            if FLAGS.config.use_viewdirs:
                preds = lax.map(
                    lambda x: p_eval_step(state, x)[0]["rgb"], rays_render_vdirs
                )
                gen_video(preds, "rgb_still", r_hwf, step)
            logging.info("Video rendering done in %ds", time.time() - t)

        ### Save images in the test set
        if step % FLAGS.config.i_testset == 0 and step > 0:
            logging.info("Rendering test set at step %d", step)
            preds = lax.map(lambda x: p_eval_step(state, x)[0]["rgb"], test_rays)
            save_test_imgs(preds, r_hwf, step)

            if FLAGS.config.render_factor == 0:
                loss = np.mean((preds.reshape(test_imgs.shape) - test_imgs) ** 2.0)
                summary_writer.scalar(f"test/loss", loss, step)
                summary_writer.scalar(f"test/psnr", psnr_fn(loss), step)

        ### Save ckpt
        if step % FLAGS.config.i_weights == 0 and step > 0:
            if jax.host_id() == 0:
                checkpoints.save_checkpoint(
                    FLAGS.model_dir,
                    jax_utils.unreplicate(state),
                    step,
                    keep=5,
                )
        t = time.time()
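
The replicate / pmap / unreplicate plumbing in the script above is easy to lose among the NeRF details. A much-reduced sketch of just that pattern, with a toy state and batch instead of the project's TrainState:

import jax
import jax.numpy as jnp
from flax import jax_utils

state = {"w": jnp.zeros((3,))}          # toy stand-in for the train state
state = jax_utils.replicate(state)      # one copy per local device


@jax.pmap
def train_step(state, batch):
    grads = batch.mean(axis=0)          # pretend gradient, shape (3,)
    return {"w": state["w"] - 0.1 * grads}


n = jax.local_device_count()
batch = jnp.ones((n, 8, 3))             # leading axis must equal the device count
state = train_step(state, batch)
state = jax_utils.unreplicate(state)    # back to one copy, e.g. for checkpointing
print(state["w"])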