Exemple #1
0
from ml_collections import config_flags

import ppo_lib
import models
import env_utils

FLAGS = flags.FLAGS

flags.DEFINE_string(
    'workdir',
    default='/tmp/ppo_training',
    help=('Directory to save checkpoints and logging info.'))

config_flags.DEFINE_config_file(
    'config',
    "configs/default.py",
    'File path to the default configuration file.',
    lock_config=True)


def main(argv):
  """Build the ActorCritic model and run PPO training on the selected game."""
  # Keep TensorFlow off the GPU so it cannot reserve device memory.
  tf.config.experimental.set_visible_devices([], 'GPU')
  cfg = FLAGS.config
  env_name = f'{cfg.game}NoFrameskip-v4'
  action_count = env_utils.get_num_actions(env_name)
  print(f'Playing {env_name} with {action_count} actions')
  actor_critic = models.ActorCritic(num_outputs=action_count)
  ppo_lib.train(actor_critic, cfg, FLAGS.workdir)

if __name__ == '__main__':
Exemple #2
0
from absl import app
from absl import flags
from absl import logging

from clu import platform
import train
import jax
from ml_collections import config_flags
import tensorflow as tf

FLAGS = flags.FLAGS

flags.DEFINE_string('workdir', None, 'Directory to store model data.')
config_flags.DEFINE_config_file(
    'config',
    'configs/default.py',
    'File path to the training hyperparameter configuration.',
    lock_config=True)
flags.mark_flags_as_required(['config', 'workdir'])


def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
  # it unavailable to JAX.
  tf.config.experimental.set_visible_devices([], 'GPU')

  logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count())
  logging.info('JAX local devices: %r', jax.local_devices())
## Data loaders

## Resnet model definition built on flax
from stochastic_polyak import models
from stochastic_polyak.get_solver import get_solver
from stochastic_polyak.utils import create_dumpfile
from stochastic_polyak.utils import get_datasets


FLAGS = flags.FLAGS

flags.DEFINE_string("workdir", "/tmp/stochastic_polyak/",
                    "Parent directory to store model data.")
config_flags.DEFINE_config_file(
    "config",
    "stochastic_polyak/configs/sps.py",
    "File path to the training hyperparameter configuration.",
    lock_config=True)
flags.DEFINE_integer("max_steps_per_epoch", -1,
                     "Maximum number of steps in an epoch.")
flags.DEFINE_float(
    "slack_lmbda", -1,
    "The lmbda regularization parameter of the slack formulation.")
flags.DEFINE_float("slack_delta", -1,
                   "The delta dampening parameter of the slack formulation.")
flags.DEFINE_float("momentum", 0.0, "The momentum parameter.")
flags.DEFINE_integer(
    "choose_update", 1,
    "What solver to use in the SSPS methods. Can take values [1, 2, 3, 4, 5] SSPS."
)
Exemple #4
0
import sys

from absl import app
from absl import flags
import matplotlib.pyplot as plt
from ml_collections import config_flags
import torchvision
from xirl.common import get_pretraining_dataloaders

FLAGS = flags.FLAGS

flags.DEFINE_boolean("debug", False, "Debug mode.")

config_flags.DEFINE_config_file(
    "config",
    "xirl/config.py",
    "File path to the training hyperparameter configuration.",
    lock_config=True,
)


def main(_):
    num_ctx_frames = FLAGS.config.FRAME_SAMPLER.NUM_CONTEXT_FRAMES
    num_frames = FLAGS.config.FRAME_SAMPLER.NUM_FRAMES_PER_SEQUENCE
    pretrain_loaders = get_pretraining_dataloaders(FLAGS.config, FLAGS.debug)
    try:
        loader = pretrain_loaders["train"]
        print("Total videos: ", loader.dataset.total_vids)
        for batch_idx, batch in enumerate(loader):
            print(f"Batch #{batch_idx}")
            frames = batch["frames"]
            b, _, c, h, w = frames.shape
from jax.config import config
from ml_collections import config_flags
import tensorflow as tf
from tensorflow.io import gfile

from gift.models import all_models
from gift.pipelines import all_pipelines
from gift.tasks import all_tasks

# Enable flax xprof trace labelling.
os.environ['FLAX_PROFILE'] = 'true'

FLAGS = flags.FLAGS
flags.DEFINE_string('experiment_dir', None, 'Experiment directory.')
config_flags.DEFINE_config_file('config',
                                None,
                                'Path to the experiment configuration.',
                                lock_config=True)


def run(hparams, experiment_dir, summary_writer=None):
    """Prepares model, and dataset for training.

  This creates summary directories, summary writers, model definition, and
  builds datasets to be sent to the main training script.

  Args:
    hparams:  ConfigDict; Hyper parameters.
    experiment_dir: string; Root directory for the experiment.
    summary_writer: Summary writer object.

  Returns:
from torchkit.utils.py_utils import Stopwatch
from utils import setup_experiment
from xirl import common

# pylint: disable=logging-fstring-interpolation

FLAGS = flags.FLAGS

flags.DEFINE_string("experiment_name", None, "Experiment name.")
flags.DEFINE_boolean("resume", False, "Whether to resume training.")
flags.DEFINE_string("device", "cuda:0", "The compute device.")
flags.DEFINE_boolean("raw_imagenet", False, "")

config_flags.DEFINE_config_file(
    "config",
    "base_configs/pretrain.py",
    "File path to the training hyperparameter configuration.",
)


@experiment.pdb_fallback
def main(_):
    # Make sure we have a valid config that inherits all the keys defined in the
    # base config.
    validate_config(FLAGS.config, mode="pretrain")

    config = FLAGS.config
    exp_dir = osp.join(config.root_dir, FLAGS.experiment_name)
    setup_experiment(exp_dir, config, FLAGS.resume)

    # No need to do any pretraining if we're loading the raw pretrained
Exemple #7
0
from utils import (
    disp_post,
    eval_step,
    gen_video,
    prepare_render_data,
    save_test_imgs,
    to_np,
)

def psnr_fn(x):
    """Convert a mean-squared error value to PSNR in decibels.

    Args:
      x: Mean-squared error (positive scalar or array).

    Returns:
      The PSNR value -10 * log10(x), elementwise for array input.
    """
    # np.log is the natural log, so divide by log(10) to get log10.
    return -10.0 * np.log(x) / np.log(10.0)

tf.config.experimental.set_visible_devices([], "GPU")

FLAGS = flags.FLAGS

config_flags.DEFINE_config_file(
    "config", None, "File path to the hyperparameter configuration.")
flags.DEFINE_integer("seed", default=0, help="Initialization seed.")
flags.DEFINE_string("data_dir",
                    default=None,
                    help="Directory containing data files.")
flags.DEFINE_string("model_dir",
                    default=None,
                    help="Directory to save model data.")
flags.DEFINE_string("save_dir",
                    default=None,
                    help="Directory to save outputs.")
flags.DEFINE_string(
    "render_video_set",
    default="render",
    help="Subset of data to use to render the video.",
)
Exemple #8
0
from ml_collections import config_flags

import tensorflow as tf

# Local imports.
import train

FLAGS = flags.FLAGS

flags.DEFINE_string('workdir',
                    default=None,
                    help=('Directory to store model data.'))

config_flags.DEFINE_config_file(
    'config', os.path.join(os.path.dirname(__file__), 'configs/default.py'),
    'File path to the Training hyperparameter configuration.')


def main(argv):
    """Entry point: validates CLI arguments, configures TF/JAX, runs training.

    Args:
      argv: Command-line arguments from absl.app; only the program name is
        expected.

    Raises:
      app.UsageError: If extra positional arguments are supplied.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Make sure tf does not allocate gpu memory.
    tf.config.experimental.set_visible_devices([], 'GPU')
    # Require JAX omnistaging mode.
    # NOTE(review): enable_omnistaging() was removed in newer JAX releases
    # (omnistaging became always-on) — confirm the pinned JAX version still
    # exposes it.
    jax.config.enable_omnistaging()

    train.train_and_evaluate(workdir=FLAGS.workdir, config=FLAGS.config)

Exemple #9
0
from jax.config import config as jax_config
from ml_collections import config_flags
from typing import Any, Optional

from datasets import load_blender, load_deepvoxels
from model import NeRF
from rays_utils import prepare_rays
from utils import eval_step, psnr_fn

jax_config.enable_omnistaging()

FLAGS = flags.FLAGS

config_flags.DEFINE_config_file(
    "config",
    os.path.join(os.path.dirname(__file__), "configs/default.py"),
    "File path to the hyperparameter configuration.",
)

flags.DEFINE_integer("seed", default=0, help=("Initialization seed."))
flags.DEFINE_string("data_dir", default=None, help=("Directory containing data files."))
flags.DEFINE_string("model_dir", default=None, help=("Directory to store model data."))
flags.DEFINE_string("save_dir", default=None, help=("Directory to store outputs."))
flags.DEFINE_string(
    "render_video_set",
    default="render",
    help=("Subset of data to use to render the video."),
)
flags.DEFINE_bool("render_video", default=True, help=("Whether to render video."))
flags.DEFINE_bool("render_testset", default=True, help=("Whether to render testset."))
flags.DEFINE_string(
Exemple #10
0
from absl import logging
import jax.numpy as jnp
from jax_verify.extensions.functional_lagrangian import attacks
from jax_verify.extensions.functional_lagrangian import bounding
from jax_verify.extensions.functional_lagrangian import data
from jax_verify.extensions.functional_lagrangian import dual_solve
from jax_verify.extensions.functional_lagrangian import model
from jax_verify.extensions.functional_lagrangian import verify_utils
from jax_verify.extensions.sdp_verify import utils as sdp_utils
import ml_collections
from ml_collections import config_flags

PROJECT_PATH = os.getcwd()

config_flags.DEFINE_config_file(
    'config', f'{PROJECT_PATH}/configs/config_ood_stochastic_model.py',
    'ConfigDict for the experiment.')

FLAGS = flags.FLAGS


def make_logger(log_message: str) -> Callable[[int, Mapping[str, Any]], None]:
  """Creates a logger.

  Args:
    log_message: description message for the logs.

  Returns:
    Function that accepts a step counter and measurements, and logs them.
  """
Exemple #11
0
r"""Example of basic DEFINE_flag_dict usage.

To run this example with basic config file:
python define_config_dict_basic.py -- \
  --my_config=ml_collections/config_flags/examples/config.py
  \
  --my_config.field1=8 --my_config.nested.field=2.1 \
  --my_config.tuple='(1, 2, (1, 2))'

To run this example with parameterised config file:
python define_config_dict_basic.py -- \
  --my_config=ml_collections/config_flags/examples/parameterised_config.py:linear
  \
  --my_config.model_config.output_size=256'
"""
# pylint: enable=line-too-long

from absl import app

from ml_collections import config_flags

_CONFIG = config_flags.DEFINE_config_file('my_config')


def main(_):
    """Print the configuration parsed from the --my_config flag."""
    parsed_config = _CONFIG.value
    print(parsed_config)


if __name__ == '__main__':
    app.run(main)
Exemple #12
0
from flax import struct
import flax.linen as nn
from flax.training.train_state import TrainState
import jax
import jax.numpy as jnp
from ml_collections import config_dict
from ml_collections import config_flags
import optax
import tqdm

from aux_tasks.grid import dataset
from aux_tasks.grid import loss_utils
from aux_tasks.grid import utils

_BASE_DIR = flags.DEFINE_string('base_dir', None, 'Base directory')
_CONFIG = config_flags.DEFINE_config_file('config', lock_config=True)


@struct.dataclass
class TrainMetrics(clu_metrics.Collection):
    """CLU metric collection accumulated during training steps."""
    # Running average of values reported under the 'loss' output key.
    loss: clu_metrics.Average.from_output('loss')
    # Running average of values reported under the 'rank' output key.
    rank: clu_metrics.Average.from_output('rank')


# @struct.dataclass
# class EvalMetrics(clu_metrics.Collection):
#   grassman_distance: clu_metrics.LastValue.from_output('grassman_distance')
#   dot_product: clu_metrics.LastValue.from_output('dot_product')
#   top_singular_value: clu_metrics.LastValue.from_output('top_singular_value')

# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Example of ConfigDict usage.

This example includes loading a ConfigDict in FLAGS, locking it, type
safety, iteration over fields, checking for a particular field, unpacking with
`**`, and loading dictionary from string representation.
"""

from absl import app
from ml_collections import config_flags
import yaml

_CONFIG = config_flags.DEFINE_config_file(
    'my_config', default='ml_collections/config_dict/examples/config.py')


def dummy_function(string, **unused_kwargs):
    """Return a greeting for *string*; extra keyword arguments are ignored."""
    return f'Hello {string}'


def print_section(name):
    """Print *name* in upper case, framed by dashed rules and blank lines."""
    rule = '-' * len(name)
    print(f'\n\n{rule}\n{name.upper()}\n{rule}\n')

Exemple #14
0
import jax.random
import tensorflow as tf
from ml_collections import config_flags

import ppo_lib
import models
import env_utils

FLAGS = flags.FLAGS

flags.DEFINE_string('workdir',
                    default='/tmp/ppo_training',
                    help=('Directory to save checkpoints and logging info.'))

config_flags.DEFINE_config_file('config',
                                None,
                                'File path to the default configuration file.',
                                lock_config=True)

flags.mark_flags_as_required(['config'])


def main(argv):
    """Set up absl file logging and keep TensorFlow off the GPU.

    Args:
      argv: Command-line arguments from absl.app; only the program name is
        expected.

    Raises:
      app.UsageError: If extra positional arguments are supplied.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Mirror absl logs into the work directory at INFO stderr verbosity.
    FLAGS.log_dir = FLAGS.workdir
    FLAGS.stderrthreshold = 'info'
    logging.get_absl_handler().start_logging_to_file()

    # Make sure tf does not allocate gpu memory.
    tf.config.experimental.set_visible_devices([], 'GPU')
Exemple #15
0
than can be easily tested and imported in Colab.
"""

from absl import app
from absl import flags
from absl import logging
from clu import platform
import train
import jax
from ml_collections import config_flags
import tensorflow as tf

FLAGS = flags.FLAGS

config_flags.DEFINE_config_file("config",
                                "configs/default.py",
                                "Training configuration.",
                                lock_config=True)
flags.DEFINE_string("workdir", None, "Work unit directory.")
flags.DEFINE_string("jax_backend_target", None,
                    "JAX backend target to use. Can be used with UPTC.")
flags.mark_flags_as_required(["config", "workdir"])


def main(argv):
    """Validate CLI arguments and reserve GPU memory for JAX.

    Args:
      argv: Command-line arguments from absl.app; only the program name is
        expected.

    Raises:
      app.UsageError: If extra positional arguments are supplied.
    """
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")

    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], "GPU")
Exemple #16
0
flags.DEFINE_integer('max_episode_steps', 100, 'Environment step limit.')

# For goal image
flags.DEFINE_string(
    'robot_data_path', None,
    'Path to gzipped pickle file with robot interaction data.')

# Distance function
flags.DEFINE_string(
    'distance_ckpt_to_load', None,
    'Path to flax checkpoint of trained distance model used to define reward.')
flags.DEFINE_enum('distance_ckpt_format', 'flax', ['flax', 'scenic'],
                  'Format of distance_ckpt_to_load.')
config_flags.DEFINE_config_file(
    'scenic_config',
    None,
    'Path to scenic config (if loading a scenic trained distance model).',
    lock_config=True)
flags.DEFINE_bool(
    'use_true_distance', False,
    'If True, set reward to true distance between end-effector '
    'and puck + puck and target.')
flags.DEFINE_list(
    'encoder_conv_filters', [16, 16, 32],
    'Number and sizes of convolutional filters in the embedding network.')  # pytype: disable=wrong-arg-types
flags.DEFINE_integer('encoder_conv_size', 5,
                     'Convolution kernel size in the embedding network.')
flags.DEFINE_float('learning_rate', 3e-4, 'Learning rate for training.')

# Reward function
flags.DEFINE_float(
Exemple #17
0
import chex
import jax

from jaxline import base_config
from jaxline import train
from jaxline import utils
from ml_collections import config_dict
from ml_collections import config_flags
import numpy as np

import tensorflow as tf

FLAGS = flags.FLAGS

# TODO(tomhennigan) Add support for ipdb and pudb.
config_flags.DEFINE_config_file("config",
                                help_string="Training configuration file.")
# This flag is expected to be used only internally by jaxline.
# It is prefixed by "jaxline" to prevent a conflict with a "mode" flag defined
# by Monarch.
_JAXLINE_MODE = flags.DEFINE_string(
    "jaxline_mode", "train",
    ("Execution mode: `train` will run training, `eval` will run evaluation."))
_JAXLINE_TPU_DRIVER = flags.DEFINE_string("jaxline_tpu_driver", "",
                                          "Whether to use tpu_driver.")
_JAXLINE_ENSURE_TPU = flags.DEFINE_bool(
    "jaxline_ensure_tpu", False, "Whether to ensure we have a TPU connected.")


def create_checkpointer(config: config_dict.ConfigDict, mode: str) -> Any:
    """Creates an object to be used as a checkpointer.

    Args:
      config: Experiment configuration passed through to the checkpointer.
      mode: Execution mode (presumably one of the --jaxline_mode values,
        e.g. 'train' or 'eval' — confirm against jaxline's API).

    Returns:
      A jaxline in-memory checkpointer bound to the given config and mode.
    """
    return utils.InMemoryCheckpointer(config, mode)
    'upsampler.')
flags.DEFINE_string('store_dir', None, 'Path to store generated images.')
flags.DEFINE_string('master', 'local',
                    'BNS name of the TensorFlow master to use.')
flags.DEFINE_string('tpu_worker_name', 'tpu_worker', 'Name of the TPU worker.')
flags.DEFINE_enum('accelerator_type', 'GPU', ['CPU', 'GPU', 'TPU'],
                  'Hardware type.')
flags.DEFINE_enum('mode', 'colorize', ['colorize', 'recolorize'],
                  'Whether to colorizer or recolorize images.')
flags.DEFINE_integer('steps_per_summaries', 100, 'Steps per summaries.')
flags.DEFINE_integer(
    'batch_size', None,
    'Batch size. If not provided, use the optimal batch-size '
    'for each model.')
config_flags.DEFINE_config_file('config',
                                default='test_configs/colorizer.py',
                                help_string='Training configuration file.')
FLAGS = flags.FLAGS


def create_grayscale_dataset_from_images(image_dir, batch_size):
    """Creates a dataset of grayscale images from the input image directory."""
    def load_and_preprocess_image(path, child_path):
        image_str = tf.io.read_file(path)
        num_channels = 1 if FLAGS.mode == 'colorize' else 3
        image = tf.image.decode_image(image_str, channels=num_channels)

        # Central crop to square and resize to 256x256.
        image = datasets.resize_to_square(image, resolution=256, train=False)

        # Resize to a low resolution image.
Exemple #19
0
FLAGS = flags.FLAGS
CFG = None
PyTreeDef = type(jax.tree_structure(None))
TransformerConfig = models.TransformerConfig
jax.config.parse_flags_with_absl()

flags.DEFINE_string('model_dir',
                    default=None,
                    help='Directory to store model data.')

flags.DEFINE_string('data_dir',
                    default=None,
                    help='Tensorflow datasets directory.')

config_flags.DEFINE_config_file(name='config',
                                default='configs/t5_small_glue.py',
                                help_string='training config file.')

ConfigDict = ml_collections.ConfigDict


def get_configs(config):
    """Get train, eval, and predict model configs.

  Args:
    config: The config dict for the experiment.

  Returns:
    A triple (train_config, eval_config, predict_config).
  """
    train_config = TransformerConfig(
Exemple #20
0
from ml_collections import config_flags

import tensorflow as tf

import ppo_lib
import models
import env_utils

FLAGS = flags.FLAGS

flags.DEFINE_string('logdir',
                    default='/tmp/ppo_training',
                    help=('Directory to save checkpoints and logging info.'))

config_flags.DEFINE_config_file(
    'config', os.path.join(os.path.dirname(__file__), 'default_config.py'),
    'File path to the default configuration file.')


def main(argv):
    # Make sure tf does not allocate gpu memory.
    tf.config.experimental.set_visible_devices([], 'GPU')

    config = FLAGS.config
    game = config.game + 'NoFrameskip-v4'
    num_actions = env_utils.get_num_actions(game)
    print(f'Playing {game} with {num_actions} actions')
    key = jax.random.PRNGKey(0)
    key, subkey = jax.random.split(key)
    model = models.create_model(subkey, num_outputs=num_actions)
    optimizer = models.create_optimizer(model,
from absl import app
from absl import flags
from absl import logging

from ml_collections import config_flags
import numpy as np
import tensorflow.compat.v1 as tf

from symbolic_functionals.syfes import loss
from symbolic_functionals.syfes.evolution import common
from symbolic_functionals.syfes.evolution import regularized_evolution
from symbolic_functionals.syfes.symbolic import optimizers
from symbolic_functionals.syfes.symbolic import search_utils
from symbolic_functionals.syfes.symbolic import xc_functionals

config_flags.DEFINE_config_file('config', 'config.py')

flags.DEFINE_integer('xid', 0, 'The experiment id.')
flags.DEFINE_integer('wid', 0, 'The work unit id.')

FLAGS = flags.FLAGS


class XCFunctionalPopulation(regularized_evolution.Population):
    """Population of functional forms."""
    def create_initial_population(self):
        """Creates the initial population with mutation base."""
        if 'json' in self._cfg.xc.mutation_base:
            with tf.io.gfile.GFile(self._cfg.xc.mutation_base, 'r') as f:
                functional_base = xc_functionals.XCFunctional.from_dict(
                    json.load(f))
Exemple #22
0
import jax
import jax.numpy as jnp
from lib import data
from lib import models
from lib import utils
import lib.classification_utils as classification_lib
from lib.layers import sample_patches
import ml_collections
import ml_collections.config_flags as config_flags
import optax
import tensorflow as tf

FLAGS = flags.FLAGS

config_flags.DEFINE_config_file("config",
                                None,
                                "Training configuration.",
                                lock_config=True)
flags.DEFINE_string("workdir", None, "Work unit directory.")


class ClassificationModule(nn.Module):
    """A module that does classification."""
    def apply(self, x, config, num_classes, train=True):
        """Creates a model definition."""

        if config.get("append_position_to_input", False):
            b, h, w, _ = x.shape
            coords = utils.create_grid([h, w], value_range=(0., 1.))
            x = jnp.concatenate(
                [x, coords[jnp.newaxis, Ellipsis].repeat(b, axis=0)], axis=-1)
Exemple #23
0
from aqt.jax.imagenet.configs.paper.resnet50_w8_a8_auto import get_config as w8a8auto_paper_config
from aqt.utils import hparams_utils as os_hparams_utils
from aqt.utils import report_utils
from aqt.utils import summary_utils


FLAGS = flags.FLAGS


flags.DEFINE_string(
    'model_dir', default=None, help=('Directory to store model data.'))

flags.DEFINE_string(
    'data_dir', default=None, help=('Directory where imagenet data is stored.'))

config_flags.DEFINE_config_file('hparams_config_dict', None,
                                'Path to file defining a config dict.')

flags.DEFINE_integer(
    'config_idx',
    default=None,
    help=(
        'Identifies which config within the sweep this training run should use.'
    ))

flags.DEFINE_integer(
    'batch_size', default=128, help=('Batch size for training.'))

flags.DEFINE_bool('cache', default=False, help=('If True, cache the dataset.'))


flags.DEFINE_bool(
Exemple #24
0
from flax.training import checkpoints
from flax.training import common_utils
import jax
from jax import random
import jax.nn
import jax.numpy as jnp
from lra_benchmarks.image import task_registry
from lra_benchmarks.models.transformer import transformer
from lra_benchmarks.utils import train_utils
from ml_collections import config_flags
import tensorflow.compat.v2 as tf

FLAGS = flags.FLAGS

config_flags.DEFINE_config_file('config',
                                None,
                                'Training configuration.',
                                lock_config=True)
flags.DEFINE_string('model_dir',
                    default=None,
                    help='Directory to store model data.')
flags.DEFINE_string('task_name', default='mnist', help='Name of the task')
flags.DEFINE_bool('eval_only',
                  default=False,
                  help='Run the evaluation on the test data.')


def create_model(key, flax_module, input_shape, model_kwargs):
    """Creates and initializes the model."""
    @functools.partial(jax.jit, backend='cpu')
    def _create_model(key):
        module = flax_module.partial(**model_kwargs)
Exemple #25
0
import pathlib

from absl import app
from absl import flags
import input_pipeline
import ml_collections
from ml_collections import config_flags
import model_utils
import predict_utils
import tensorflow as tf

FLAGS = flags.FLAGS

config_flags.DEFINE_config_file(
    'config',
    None,
    'A resource path to the ConfigDict used during training.',
    lock_config=True)
flags.DEFINE_string('workdir', None, 'The `workdir` used during training')
flags.DEFINE_string('output_filepath', 'predictions.csv',
                    'The filepath at which the output CSV is written.')


def predict(
    workdir: pathlib.Path,
    config: ml_collections.ConfigDict,
    output_filepath: str,
) -> None:
    """Generates model predictions using the best available checkpoint."""

    # Set seed for reproducibility.
Exemple #26
0
from clu import platform
import jax
from ml_collections import config_flags
import tensorflow as tf

from vit_jax import inference_time
from vit_jax import train
from vit_jax import utils

FLAGS = flags.FLAGS

_WORKDIR = flags.DEFINE_string('workdir', None,
                               'Directory to store logs and model data.')
config_flags.DEFINE_config_file(
    'config',
    None,
    'File path to the training hyperparameter configuration.',
    lock_config=True)
flags.mark_flags_as_required(['config', 'workdir'])
# Flags --jax_backend_target and --jax_xla_backend are available through JAX.


def main(argv):
    """Set up gfile logging and keep TensorFlow off the GPU.

    Args:
      argv: Command-line arguments from absl.app; only the program name is
        expected.

    Raises:
      app.UsageError: If extra positional arguments are supplied.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Mirror logs into the work directory given by --workdir.
    utils.add_gfile_logger(_WORKDIR.value)

    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], 'GPU')
Exemple #27
0
import chex
import jax

from jaxline import base_config
from jaxline import train
from jaxline import utils
from ml_collections import config_dict
from ml_collections import config_flags
import numpy as np

import tensorflow as tf

# TODO(tomhennigan) Add support for ipdb and pudb.
_CONFIG = config_flags.DEFINE_config_file(
    name="config",
    help_string="Training configuration file.",
)
# This flag is expected to be used only internally by jaxline.
# It is prefixed by "jaxline" to prevent a conflict with a "mode" flag defined
# by Monarch.
_JAXLINE_MODE = flags.DEFINE_string(
    name="jaxline_mode",
    default="train",
    help=("Execution mode. "
          " `train` will run training, `eval` will run evaluation."),
)
_JAXLINE_TPU_DRIVER = flags.DEFINE_string(
    name="jaxline_tpu_driver",
    default="",
    help="Whether to use tpu_driver.",
)