Exemplo n.º 1
0
import shutil
import tempfile

from absl import flags
import apache_beam as beam
from apache_beam.transforms import combiners
from moonlight.pipeline import pipeline_flags
from moonlight.training.clustering import staffline_patches_dofn
import tensorflow as tf
from tensorflow.contrib.learn.python.learn import learn_runner
from tensorflow.python.lib.io import file_io
from tensorflow.python.lib.io import tf_record

FLAGS = flags.FLAGS

flags.DEFINE_multi_string('music_pattern', [],
                          'Pattern for the input music score PNGs.')
flags.DEFINE_string('output_path', None, 'Path to the output TFRecords.')
flags.DEFINE_integer('patch_height', 18,
                     'The normalized height of a staffline.')
flags.DEFINE_integer('patch_width', 15,
                     'The width of a horizontal patch of a staffline.')
flags.DEFINE_integer('num_stafflines', 19,
                     'The number of stafflines to extract.')
flags.DEFINE_integer('num_pages', 0, 'Subsample the pages to run on.')
flags.DEFINE_integer('num_outputs', 0, 'Number of output patches.')
flags.DEFINE_integer('max_patches_per_page', 10,
                     'Sample patches per page if above this amount.')
flags.DEFINE_integer('timeout_ms', 600000, 'Timeout for processing a page.')
flags.DEFINE_integer('kmeans_num_clusters', 1000,
                     'Number of k-means clusters.')
flags.DEFINE_integer('kmeans_batch_size', 10000,
Exemplo n.º 2
0
flags.DEFINE_integer('decode_top_k', 8,
                     'Maximum number of tokens to consider for begin/end.')

flags.DEFINE_integer('decode_max_size', 16,
                     'Maximum number of sentence pieces in an answer.')

flags.DEFINE_float('dropout_rate', 0.1, 'Dropout rate for hidden layers.')

flags.DEFINE_float('attention_dropout_rate', 0.3,
                   'Dropout rate for attention layers.')

flags.DEFINE_float('label_smoothing', 1e-1, 'Degree of label smoothing.')

flags.DEFINE_multi_string(
    'gin_bindings', [],
    'Gin bindings to override the values set in the config files')

FLAGS = flags.FLAGS


@contextlib.contextmanager
def worker_context():
  """Context manager that optionally scopes its body to the worker job.

  Yields:
    The value produced by entering `tf.device('/job:worker')` when
    `FLAGS.master` is truthy; otherwise `None`, with no device scope entered.
  """
  if not FLAGS.master:
    # No master configured: run the body without any device placement.
    yield
    return
  with tf.device('/job:worker') as worker_device:
    yield worker_device


def read_sentencepiece_model(path):
# from mesh_tensorflow.transformer import utils
import gin
import t5

_DEFAULT_MODULE_IMPORTS = []

FLAGS = flags.FLAGS

flags.DEFINE_string("task", None, "A registered Task.")
flags.DEFINE_integer("max_examples", -1,
                     "maximum number of examples. -1 for no limit")
flags.DEFINE_string("format_string", "{inputs}\t{targets}",
                    "format for printing examples")
# Extra modules to import (default: _DEFAULT_MODULE_IMPORTS); per the help
# text, importing them is what registers externally defined Tasks.
flags.DEFINE_multi_string(
    "module_import", _DEFAULT_MODULE_IMPORTS,
    "Modules to import. Use this when your Task is defined outside "
    "of the T5 codebase so that it is registered.")
flags.DEFINE_string("split", "train",
                    "which split of the dataset, e.g. train or validation")

flags.DEFINE_bool("detokenize", False, "If True, then decode ids to strings.")


@gin.configurable
def sequence_length(value=512):
    """Sequence length used when tokenizing.

  Args:
    value: an integer or dictionary
  Returns:
    a dictionary
Exemplo n.º 4
0
from official.common import registry_imports  # pylint: disable=unused-import
from official.core import exp_factory
from official.modeling import hyperparams
from official.vision.beta.serving import export_saved_model_lib

FLAGS = flags.FLAGS

flags.DEFINE_string('experiment', None,
                    'experiment type, e.g. retinanet_resnetfpn_coco')
flags.DEFINE_string('export_dir', None, 'The export directory.')
flags.DEFINE_string('checkpoint_path', None, 'Checkpoint path.')
flags.DEFINE_multi_string(
    'config_file',
    default=None,
    help='YAML/JSON files which specifies overrides. The override order '
    'follows the order of args. Note that each file '
    'can be used as an override template to override the default parameters '
    'specified in Python. If the same parameter is specified in both '
    '`--config_file` and `--params_override`, `config_file` will be used '
    'first, followed by params_override.')
# Free-form override applied on top of the `config_file` template; defaults
# to the empty string.
flags.DEFINE_string(
    'params_override', '',
    'The JSON/YAML file or string which specifies the parameter to be overridden'
    ' on top of `config_file` template.')
flags.DEFINE_integer('batch_size', None, 'The batch size.')
flags.DEFINE_string(
    'input_type', 'image_tensor',
    'One of `image_tensor`, `image_bytes`, `tf_example` and `tflite`.')
flags.DEFINE_string(
    'input_image_size', '224,224',
    'The comma-separated string of two integers representing the height,width '
Exemplo n.º 5
0
		(though how far this holds still needs to be verified) if we update the priorities on the RB while we are iterating
		over the dataset, this does not become a problem, because at every iteration self.get_next() is called, which runs
		with the new (and correct) priorities.
		N.B.
		the prefetch() method applied to the dataset probably plays a role, because items that have already been fetched
		apparently do not get their priority updated, but we can probably ignore these subtleties and count on them not changing much
	2) Implementing the change would not be particularly troublesome for the driver, since we could use the
		DynamicStepDriver, whose output gives you the state of things at the last step, which you can then pass to the
		DynamicStepDriver itself at the next iteration so that it resumes from there (and this should maintain consistency
		with respect to important metrics such as AverageReturn)
"""

flags.DEFINE_string('root_dir', os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'),
                    'Root directory for writing logs/summaries/checkpoints.')
# Repeatable flag; the two help literals are concatenated, so the first one
# must end with a space to keep "(e.g." apart from the quoted example path.
flags.DEFINE_multi_string(
    'gin_files', [], 'List of paths to gin configuration files (e.g. '
    '"configs/hanabi_rainbow.gin").')
flags.DEFINE_multi_string(
    'gin_bindings', [],
    'Gin bindings to override the values set in the config files '
    '(e.g. "train_eval.num_iterations=100").')

FLAGS = flags.FLAGS


#TODO Very much unfinished function. it should run an episode stopping step by step
# and printing everything we might want to see.
def run_verbose_mode(agent_1, agent_2):
    env = rl_env.make('Hanabi-Full-CardKnowledge', num_players=2)
    tf_env = tf_py_environment.TFPyEnvironment(env)
Exemplo n.º 6
0
from typing import Any, Callable, Dict, List, Optional, Text

EXPORTER_FN = Callable[[
    model_interface.ModelInterface, abstract_export_generator.
    AbstractExportGenerator
], List[tf.estimator.Exporter]]

FLAGS = flags.FLAGS

try:
    flags.DEFINE_list(
        'gin_configs', None, 'A comma-separated list of paths to Gin '
        'configuration files.')
    flags.DEFINE_multi_string(
        'gin_bindings', [],
        'A newline separated list of Gin parameter bindings.')
except flags.DuplicateFlagError:
    pass

gin_configurable_eval_spec = gin.external_configurable(
    tf.estimator.EvalSpec, name='tf.estimator.EvalSpec')


def print_spec(tensor_spec):
    """Iterate over a spec and print its values in sorted order.

  Args:
    tensor_spec: A dict, (named)tuple, list or a hierarchy thereof filled by
      TensorSpecs(subclasses) or Tensors.
  """
Exemplo n.º 7
0
# Step-count flags, each defaulting to 100. Every help string now names its
# own flag instead of repeating the pre-train description.
# NOTE(review): 'ctrl steps' only mirrors the flag name -- expand once the
# intended meaning of "ctrl" is confirmed.
flags.DEFINE_integer('pre_train_steps', 100, help=('pretrain steps'))
flags.DEFINE_integer('finetune_steps', 100, help=('finetune steps'))
flags.DEFINE_integer('ctrl_steps', 100, help=('ctrl steps'))

flags.DEFINE_string(
    'param_file',
    None,
    help=(
        'Base set of model parameters to use with this model. To see '
        'documentation on the parameters, see the docstring in resnet_params.'
    ))
flags.DEFINE_multi_string(
    'param_overrides',
    None,
    help=(
        'Model parameter overrides for this model. For example, if '
        'experimenting with larger numbers of train_steps, a possible value '
        'is --param_overrides=train_steps=28152. If you have a collection of '
        'parameters that make sense to use together repeatedly, consider '
        'extending resnet_params.param_sets_table.'))
flags.DEFINE_string(
    'data_dir',
    '',
    help=('The directory where the ImageNet input data is stored. Please see'
          ' the README.md for the expected data format.'))
flags.DEFINE_string(
    'model_dir',
    None,
    help=('The directory where the model and training/evaluation summaries are'
          ' stored.'))
flags.DEFINE_string('mode',
Exemplo n.º 8
0
from tapas.scripts import calc_metrics_utils
from tapas.scripts import prediction_utils as pred_utils
from tapas.utils import experiment_utils  # pylint: disable=unused-import
import tensorflow.compat.v1 as tf
FLAGS = flags.FLAGS

flags.DEFINE_string("data_format", "tfrecord", "The input data format.")

flags.DEFINE_string(
    "compression_type",
    "GZIP",
    "Compression to use when reading tfrecords. '' for no compression.",
)

flags.DEFINE_multi_string(
    "input_file_train", None,
    "Input TF example files (can be a glob or comma separated).")

flags.DEFINE_multi_string(
    "input_file_eval", None,
    "Input TF example files (can be a glob or comma separated).")

flags.DEFINE_multi_string(
    "input_file_predict", None,
    "Input TF example files (can be a glob or comma separated).")

flags.DEFINE_string(
    "prediction_output_dir", None,
    "If not none or empty writes predictions to this directory. Otherwise "
    "writes predictions to model_dir.")
Exemplo n.º 9
0
# On by default; pass --nodeploy to skip the deployment phase.
flags.DEFINE_bool(
    'deploy', True, 'Deploy the test environment. '
    'Set to false to skip the deployment phase and go straight to tests')
flags.DEFINE_bool(
    'skip_before_all', False, 'True to skip @before_all methods. '
    'Like --nodeploy, this is used to skip set up steps. '
    'Useful when developing new tests.')
flags.DEFINE_bool(
    'no_external_access', False, 'True to skip creating RDP/SSH firewall '
    'rules during deployment. Should be used in automated test runs.')
flags.DEFINE_bool('cleanup', False,
                  'Clean up the host environment after the test')
flags.DEFINE_string('error_logs_dir', None,
                    'Where to collect extra logs on test failures')
flags.DEFINE_multi_string('test_arg', None, 'Flags passed to tests')


def ConfigureLogging():
  """Quiet selected loggers and emit one probe record per log level.

  Sets the loggers named 'googleapiclient.discovery_cache', 'google.auth' and
  'google_auth_httplib2' to ERROR, installs an 'ignore' warnings filter for
  the service-account recommendation message, and then logs one message at
  each of the error, warning and info levels so the visible threshold can be
  read off the output.
  """
  # Low-level loggers that should only surface errors.
  for noisy_name in ('googleapiclient.discovery_cache', 'google.auth',
                     'google_auth_httplib2'):
    noisy_logger = logging.getLogger(noisy_name)
    noisy_logger.setLevel(logging.ERROR)

  # The filter text is used as a regular expression, hence the '.*' prefix.
  recommendation = (
      'We recommend that most server applications use service accounts.')
  warnings.filterwarnings('ignore', '.*%s' % recommendation)

  # One record per level, each tagged with this file's path.
  this_file = __file__
  logging.error("%s: Logging level error is visible." % this_file)
  logging.warning("%s: Logging level warning is visible." % this_file)
  logging.info("%s: Logging level info is visible." % this_file)
Exemplo n.º 10
0
import gin.tf 
import time 

import experiment_prototype 

from absl import app 
from absl import flags 
import argparse 

# Command-line flags for one experiment run.
flags.DEFINE_string('env', 'cartpole', 'Env name')
flags.DEFINE_string('base_dir', './results', 'Base dir')
flags.DEFINE_string('agent_id', 'mmd', 'Agent id')
flags.DEFINE_string('agent_name', None, 'Agent name')

flags.DEFINE_integer('run_id', 0, 'run id')
# Multi-string defaults are written as lists so the default has the same
# shape as the parsed value (a bare string is wrapped into a one-item list).
flags.DEFINE_multi_string(
    'gin_files', ['./configs/mmd_atari.gin'],
    'List of paths to gin configuration files')
flags.DEFINE_multi_string(
    'gin_bindings', [],
    'Gin bindings to override the values in the config files')
# Boolean flag: use a real bool for the default rather than the integer 0.
flags.DEFINE_bool('debug', False, 'Debug')
FLAGS = flags.FLAGS


def main(unused_argv):
    agent_name = FLAGS.agent_id if FLAGS.agent_name is None else FLAGS.agent_name 

    if FLAGS.debug:
        base_dir = os.path.join('./results/tmp',str(time.time())) 
    else:
        base_dir = os.path.join(FLAGS.base_dir, FLAGS.env, agent_name, 'run_%d'%(FLAGS.run_id))  

    tf.logging.set_verbosity(tf.logging.INFO)
    experiment_prototype.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
Exemplo n.º 11
0
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os

from absl import app
from absl import flags
import gin.tf

from bisimulation_aaai2020.grid_world import grid_world

flags.DEFINE_string('grid_file', None, 'Path to file defining grid world MDP.')
flags.DEFINE_string('base_dir', None, 'Base directory to store stats.')
flags.DEFINE_multi_string('gin_files', [],
                          'List of paths to gin configuration files.')
flags.DEFINE_multi_string(
    'gin_bindings', [],
    'Gin bindings to override the values set in the config files.')
flags.DEFINE_bool('exact_metric', True,
                  'Whether to compute the metric using the exact method.')
flags.DEFINE_bool('sampled_metric', True,
                  'Whether to compute the metric using sampling.')
flags.DEFINE_bool('learn_metric', True,
                  'Whether to compute the metric using learning.')
# Enabled by default; the help text notes that it needs a learned metric.
flags.DEFINE_bool('sample_distance_pairs', True,
                  'Whether to aggregate states (needs a learned metric).')
flags.DEFINE_integer('num_samples_per_cell', 100,
                     'Number of samples per cell when aggregating.')
flags.DEFINE_bool('verbose', False, 'Whether to print verbose messages.')
Exemplo n.º 12
0
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Flags for trainer.py and rl_trainer.py.

We keep these flags in sync across the trainer and the rl_trainer binaries.
"""

from absl import flags
from absl import logging

# Common flags.
flags.DEFINE_string('output_dir', None,
                    'Path to the directory to save logs and checkpoints.')
flags.DEFINE_multi_string('config_file', None,
                          'Configuration file with parameters (.gin).')
flags.DEFINE_multi_string('config', None,
                          'Configuration parameters (gin string).')

# TPU Flags
flags.DEFINE_bool('use_tpu', False, "Whether we're running on TPU.")
# The help literals are concatenated, so the second one needs a trailing
# space to keep '"tpu_driver"' apart from 'for'.
flags.DEFINE_string(
    'jax_xla_backend', 'xla',
    'Either "xla" for the XLA service directly, or "tpu_driver" '
    'for a TPU Driver backend.')
flags.DEFINE_string(
    'jax_backend_target', 'local',
    'Either "local" or "rpc:address" to connect to a '
    'remote service target.')

# trainer.py flags.
Exemplo n.º 13
0
from absl import flags
'''
flags.DEFINE_multi_string("gin_file", "dataset.gin", "Path to a Gin file.")
flags.DEFINE_multi_string("gin_file", "gs://t5-data/pretrained_models/small/operative_config.gin", "Path to a Gin file.")
flags.DEFINE_multi_string("gin_param", "utils.tpu_mesh_shape.model_parallelism = 1", "Gin parameter binding.")
flags.DEFINE_multi_string("gin_param", "utils.tpu_mesh_shape.tpu_topology = '2x2'", "Gin parameter binding.")
flags.DEFINE_multi_string("gin_param", "MIXTURE_NAME = 'glue_mrpc_v002'", "Gin parameter binding.")
flags.DEFINE_list("gin_location_prefix", [], "Gin file search path.")
'''
flags.DEFINE_string(
    "tpu_job_name", None,
    "Name of TPU worker binary. Only necessary if job name is changed from "
    "default tpu_worker.")

flags.DEFINE_string("model_dir", "transformer_standalone",
                    "Estimator model_dir")

flags.DEFINE_string(
    "tpu", None,
    "The Cloud TPU to use for training. This should be either the name "
    "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 url."
)

flags.DEFINE_string(
    "gcp_project", None,
    "Project name for the Cloud TPU-enabled project. If not specified, we "
    "will attempt to automatically detect the GCE project from metadata.")

flags.DEFINE_string(
    "tpu_zone", None,
    "GCE zone where the Cloud TPU is located in. If not specified, we "
Exemplo n.º 14
0
FLAGS = flags.FLAGS
flags.DEFINE_string('clust_spec_csv_path', '../sandbox/clust_spec_example.csv',
                    'cluster specification csv file path.')
flags.DEFINE_string('remote_working_folder', '/root/code',
                    'unified remote working folder across all machines.')
flags.DEFINE_boolean('force_overwrite_remote', True,
                     'Should force overwriting remote working folder?')
# Optional command for each remote worker; empty by default.
flags.DEFINE_string(
    'remote_worker_pre_cmd', '', "pre command for each REMOTE worker. "
    "e.g., activate python virtual env")
flags.DEFINE_multi_string(
    "prepare_python_package_path", [],
    "pip installable package path in comma separated format: "
    "package_local_fullpath,remote_fullpath "
    "It can occur multiple times. E.g., \n"
    "-p /root/Arena,/root/Arena\n"
    "-p /home/me/pysc2,/home/work/pysc2\n"
    "...",
    short_name='p')
flags.DEFINE_integer('n_process', 1,
                     'number of parallel processes for connections.')


def _parse_packages(args):
    def _known_package_ext_whitelist(package_basename):
        if package_basename.lower() == 'pysc2':  # Tencent PySC2 Extension
            return ['.py', '.serialized', '.SC2Map', '.md']
        if package_basename.lower() == 'arena':
            return ['.py', '.md', '.wad', '.cfg']
        return ['.py', '.md', '.csv']  # default to these typical files
Exemplo n.º 15
0
    'For more information see train_util.get_strategy().')
flags.DEFINE_boolean(
    'allow_memory_growth', False,
    'Whether to grow the GPU memory usage as is needed by the '
    'process. Prevents crashes on GPUs with smaller memory.')
flags.DEFINE_boolean(
    'hypertune', False,
    'Enable metric reporting for hyperparameter tuning, such '
    'as on Google Cloud AI-Platform.')
flags.DEFINE_float(
    'early_stop_loss_value', None,
    'Stops training early when the `total_loss` reaches below '
    'this value during training.')

# Gin config flags.
flags.DEFINE_multi_string('gin_search_path', [],
                          'Additional gin file search paths.')
flags.DEFINE_multi_string(
    'gin_file', [], 'List of paths to the config files. If file '
    'in gstorage bucket specify whole gstorage path: '
    'gs://bucket-name/dir/in/bucket/file.gin.')
flags.DEFINE_multi_string('gin_param', [],
                          'Newline separated list of Gin parameter bindings.')

# Evaluation/sampling specific flags.
flags.DEFINE_boolean('run_once', False, 'Whether evaluation will run once.')
flags.DEFINE_integer('initial_delay_secs', None,
                     'Time to wait before evaluation starts')

GIN_PATH = pkg_resources.resource_filename(__name__, 'gin')

Exemplo n.º 16
0
flags.DEFINE_enum(
    'inference_model',
    'streaming_f0_pw',
    [
        'autoencoder',
        'streaming_f0_pw',
        'vst_extract_features',
        'vst_predict_controls',
        'vst_synthesize',
    ],
    'Specify the ddsp.training.inference model to use for '
    'converting a checkpoint to a SavedModel. Names are '
    'snake_case versions of class names.')

# Optional flags.
flags.DEFINE_multi_string('gin_param', [],
                          'Gin parameters for custom inference model kwargs.')
flags.DEFINE_boolean('debug', False, 'DEBUG: Do not save the model')

# Conversion formats.
flags.DEFINE_boolean('tfjs', True,
                     'Convert SavedModel to TFJS for deploying on the web.')
flags.DEFINE_boolean('tflite', True,
                     'Convert SavedModel to TFLite for embedded C++ apps.')
flags.DEFINE_string('metadata_file', None,
                    'Optional metadata file to pack into TFLite model.')

FLAGS = flags.FLAGS


def get_inference_model(ckpt):
  """Restore model from checkpoint using global FLAGS.
Exemplo n.º 17
0
    def test_write_help_in_xmlformat(self):
        """Compare write_help_in_xml_format output with the expected XML.

        Flags of several kinds are registered on a private FlagValues, more
        are added through module_bar, and two of the module_bar flags are
        declared key. The generated help must equal the filled-in expected
        template and must parse as XML.
        """
        registry = flags.FlagValues()

        # These flags are defined by the top module, so all of them are key.
        flags.DEFINE_integer(
            'index', 17, 'An integer flag', flag_values=registry)
        flags.DEFINE_integer(
            'nb_iters', 17, 'An integer flag',
            lower_bound=5, upper_bound=27, flag_values=registry)
        flags.DEFINE_string(
            'file_path', '/path/to/my/dir', 'A test string flag.',
            flag_values=registry)
        flags.DEFINE_boolean(
            'use_gpu', False, 'Use gpu for performance.',
            flag_values=registry)
        flags.DEFINE_enum(
            'cc_version', 'stable', ['stable', 'experimental'],
            'Compiler version to use.', flag_values=registry)
        flags.DEFINE_list(
            'files', 'a.cc,a.h,archive/old.zip', 'Files to process.',
            flag_values=registry)
        flags.DEFINE_list(
            'allow_users', ['alice', 'bob'], 'Users with access.',
            flag_values=registry)
        flags.DEFINE_spaceseplist(
            'dirs', 'src libs bins', 'Directories to create.',
            flag_values=registry)
        flags.DEFINE_multi_string(
            'to_delete', ['a.cc', 'b.h'], 'Files to delete',
            flag_values=registry)
        flags.DEFINE_multi_integer(
            'cols', [5, 7, 23], 'Columns to select', flag_values=registry)
        flags.DEFINE_multi_enum(
            'flavours', ['APPLE', 'BANANA'], ['APPLE', 'BANANA', 'CHERRY'],
            'Compilation flavour.', flag_values=registry)

        # Add flags from a second module and promote only two of them to
        # key flags, so the registry mixes key and non-key flags that were
        # defined in different modules.
        module_bar.define_flags(flag_values=registry)
        flags.declare_key_flag('tmod_bar_z', flag_values=registry)
        flags.declare_key_flag('tmod_bar_u', flag_values=registry)

        # Render the XML help into an in-memory stream.
        xml_stream = io.StringIO() if six.PY3 else io.BytesIO()
        registry.write_help_in_xml_format(xml_stream)

        # Build the expected template; the two per-module sections appear in
        # the order given by comparing the module names.
        main_module_name = sys.argv[0]
        module_bar_name = module_bar.__name__
        if main_module_name < module_bar_name:
            module_sections = (EXPECTED_HELP_XML_FOR_FLAGS_FROM_MAIN_MODULE,
                               EXPECTED_HELP_XML_FOR_FLAGS_FROM_MODULE_BAR)
        else:
            module_sections = (EXPECTED_HELP_XML_FOR_FLAGS_FROM_MODULE_BAR,
                               EXPECTED_HELP_XML_FOR_FLAGS_FROM_MAIN_MODULE)
        expected_output_template = (EXPECTED_HELP_XML_START +
                                    module_sections[0] +
                                    module_sections[1] +
                                    EXPECTED_HELP_XML_END)

        # Fill the template; the whitespace list separators are substituted
        # in their XML representation.
        substitutions = {
            'basename_of_argv0': os.path.basename(sys.argv[0]),
            'usage_doc': sys.modules['__main__'].__doc__,
            'main_module_name': main_module_name,
            'module_bar_name': module_bar_name,
            'whitespace_separators': _list_separators_in_xmlformat(
                string.whitespace, indent='    '),
        }
        expected_output = expected_output_template % substitutions

        actual_output = xml_stream.getvalue()
        self.assertMultiLineEqual(expected_output, actual_output)

        # The result must also be valid XML: minidom.parseString throws an
        # xml.parsers.expat.ExpatError in case of an error.
        xml.dom.minidom.parseString(actual_output)
Exemplo n.º 18
0
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for asserting against golden files."""

import contextlib
import difflib
import io
import os.path
import re
import sys
import traceback
from typing import Dict, Optional

from absl import flags

flags.DEFINE_multi_string('golden', [], 'List of golden files available.')
flags.DEFINE_bool('update_goldens', False, 'Set true to update golden files.')
flags.DEFINE_bool('verbose', False, 'Set true to show golden diff output.')

FLAGS = flags.FLAGS


class MismatchedGoldenError(RuntimeError):
  """RuntimeError subclass that adds no behavior of its own.

  Going by its name, it marks a mismatch against a golden file; callers can
  catch it separately from other RuntimeErrors.
  """


_filename_to_golden_map: Optional[Dict[str, str]] = None


def _filename_to_golden_path(filename: str) -> str:
  """Retrieve the `--golden` path flag for a golden file named `filename`."""
    ############################################################################
    # Finally, write commands, script, and results to disk
    ############################################################################
    # for reproducibility, copy command and script contents to results
    if results_dir not in ('.', ):
        cmd = 'python ' + ' '.join(sys.argv)
        with open(os.path.join(results_dir, 'command.sh'), 'w') as f:
            f.write(cmd)
        this_script = open(__file__, 'r').readlines()
        with open(os.path.join(results_dir, __file__), 'w') as f:
            f.write(''.join(this_script))

    results_filename = os.path.join(results_dir, 'results.p')
    with open(results_filename, 'wb') as f:
        _ = pickle.dump(results, f)

    # Finally, write gin config to disk
    with open(os.path.join(results_dir, 'config.gin'), 'w') as f:
        f.write(gin.operative_config_str())


if __name__ == "__main__":
    # Script entry point: register the gin flags, expose the flag registry
    # as FLAGS, then hand `main` to app.run.
    flags.DEFINE_string(
        'gin_file',
        './config/multi_step_simulation.gin',
        'Config file path.',
    )
    flags.DEFINE_multi_string(
        'gin_param',
        None,
        'Newline separated list of Gin parameter bindings.',
    )
    FLAGS = flags.FLAGS

    app.run(main)
Exemplo n.º 20
0
from tf_agents.metrics import tf_metrics
from tf_agents.metrics import batched_py_metric
from tf_agents.networks import actor_distribution_network
from tf_agents.networks import normal_projection_network
from tf_agents.networks.utils import mlp_layers
from tf_agents.policies import greedy_policy
from tf_agents.policies import py_tf_policy
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.utils import common
from tf_agents.utils import episode_utils
from IPython import embed

flags.DEFINE_string('root_dir', os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'),
                    'Root directory for writing logs/summaries/checkpoints.')
flags.DEFINE_multi_string('gin_file', None, 'Path to the gin config files.')
flags.DEFINE_multi_string('gin_param', None, 'Gin binding to pass through.')

flags.DEFINE_integer('num_iterations', 1000000,
                     'Total number train/eval iterations to perform.')
flags.DEFINE_integer(
    'initial_collect_steps', 1000,
    'Number of steps to collect at the beginning of training using random policy'
)
flags.DEFINE_integer(
    'collect_steps_per_iteration', 1,
    'Number of steps to collect and be added to the replay buffer after every training iteration'
)
flags.DEFINE_integer('num_parallel_environments', 1,
                     'Number of environments to run in parallel')
flags.DEFINE_integer('num_parallel_environments_eval', 1,
Exemplo n.º 21
0
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app
from absl import flags

from dopamine.discrete_domains import run_experiment

import tensorflow.compat.v1 as tf

flags.DEFINE_string('base_dir', None,
                    'Base directory to host all required sub-directories.')
flags.DEFINE_integer('level', 1, 'Game Level to Start')
# Repeatable flag; the two help literals are concatenated, so the first one
# must end with a space to keep "(e.g." apart from the quoted example path.
flags.DEFINE_multi_string(
    'gin_files', [], 'List of paths to gin configuration files (e.g. '
    '"dopamine/agents/dqn/dqn.gin").')
flags.DEFINE_multi_string(
    'gin_bindings', [],
    'Gin bindings to override the values set in the config files '
    '(e.g. "DQNAgent.epsilon_train=0.1",'
    '      "create_environment.game_name="Pong"").')

FLAGS = flags.FLAGS


def main(unused_argv):
    """Main method.

  Args:
    unused_argv: Arguments (unused).
Exemplo n.º 22
0
import gin
import gym
import gym_minigrid  # pylint: disable=unused-import
from gym_minigrid.wrappers import RGBImgObsWrapper
import matplotlib.pylab as plt
import tensorflow as tf

from minigrid_basics.custom_wrappers import tabular_wrapper  # pylint: disable=unused-import
from minigrid_basics.envs import mon_minigrid

FLAGS = flags.FLAGS

flags.DEFINE_string('file_path', '/tmp/rw_four_directions',
                    'Path in which we will save the observations.')
flags.DEFINE_multi_string(
    'gin_bindings', [], 'Gin bindings to override default parameter values '
    '(e.g. "MonMiniGridEnv.stochasticity=0.1").')


def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    gin.parse_config_files_and_bindings(
        [os.path.join(mon_minigrid.GIN_FILES_PREFIX, 'classic_fourrooms.gin')],
        bindings=FLAGS.gin_bindings,
        skip_unknown=False)
    env_id = mon_minigrid.register_environment()
    env = gym.make(env_id)
    env = RGBImgObsWrapper(env)  # Get pixel observations
    # Get tabular observation and drop the 'mission' field:
Exemplo n.º 23
0
import numpy as np
import relabelling_replay_buffer
import tensorflow as tf
from tf_agents.agents.ddpg import critic_network
from tf_agents.agents.sac import sac_agent
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import actor_distribution_network
from tf_agents.policies import greedy_policy
from tf_agents.policies import random_tf_policy
from tf_agents.utils import common
import utils

flags.DEFINE_string("root_dir", None,
                    "Root directory for writing logs/summaries/checkpoints.")
flags.DEFINE_multi_string("gin_file", None,
                          "Path to the trainer config files.")
flags.DEFINE_multi_string("gin_bindings", None, "Gin binding to pass through.")

FLAGS = flags.FLAGS


@gin.configurable
def train_eval(
    root_dir,
    env_name="HalfCheetah-v2",
    num_iterations=1000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Params for collect
Exemplo n.º 24
0
from absl import flags
from absl import logging

import gin
from tensor2tensor.jax import j2j

import tensorflow as tf

FLAGS = flags.FLAGS

flags.DEFINE_string("dataset", None, "Which dataset to use.")
flags.DEFINE_string("model", None, "Which model to train.")
flags.DEFINE_string("data_dir", None, "Path to the directory with data.")
flags.DEFINE_string("output_dir", None,
                    "Path to the directory to save logs and checkpoints.")
flags.DEFINE_multi_string("config_file", None,
                          "Configuration file with parameters (.gin).")
flags.DEFINE_multi_string("config", None,
                          "Configuration parameters (gin string).")

# For iterators over datasets so we can do "for example in dataset".
tf.enable_v2_behavior()


def j2j_train(model_name,
              dataset_name,
              data_dir=None,
              output_dir=None,
              config_file=None,
              config=None):
    """Main function to train the given model on the given dataset.
Exemplo n.º 25
0
def define_flags():
    """Registers the command-line flags used to configure a run.

    Every flag defined here defaults to ``None``. The flags are:

    * ``experiment``: the registered experiment type.
    * ``mode``: one of `train`, `eval`, `train_and_eval`, `continuous_eval`,
      `continuous_train_and_eval`.
    * ``model_dir``: where the model and summaries are stored.
    * ``config_file`` (repeatable) and ``params_override``: parameter
      overrides; per the help text, config files are applied first and
      ``params_override`` last.
    * ``gin_file`` / ``gin_params`` (both repeatable): Gin config files and
      bindings.
    * ``tpu``: Cloud TPU name or grpc:// address.
    * ``tf_data_service``: the tf.data service address.
    """
    flags.DEFINE_string('experiment',
                        default=None,
                        help='The experiment type registered.')

    flags.DEFINE_enum('mode',
                      default=None,
                      enum_values=[
                          'train', 'eval', 'train_and_eval', 'continuous_eval',
                          'continuous_train_and_eval'
                      ],
                      help='Mode to run: `train`, `eval`, `train_and_eval`, '
                      '`continuous_eval`, and `continuous_train_and_eval`.')

    # The adjacent string literals are concatenated verbatim, so the first
    # fragment must end with a space to keep "summaries are" as two words.
    flags.DEFINE_string(
        'model_dir',
        default=None,
        help='The directory where the model and training/evaluation summaries '
        'are stored.')

    flags.DEFINE_multi_string(
        'config_file',
        default=None,
        help='YAML/JSON files which specifies overrides. The override order '
        'follows the order of args. Note that each file '
        'can be used as an override template to override the default parameters '
        'specified in Python. If the same parameter is specified in both '
        '`--config_file` and `--params_override`, `config_file` will be used '
        'first, followed by params_override.')

    flags.DEFINE_string(
        'params_override',
        default=None,
        help='a YAML/JSON string or a YAML file which specifies additional '
        'overrides over the default parameters and those specified in '
        '`--config_file`. Note that this is supposed to be used only to override '
        'the model parameters, but not the parameters like TPU specific flags. '
        'One canonical use case of `--config_file` and `--params_override` is '
        'users first define a template config file using `--config_file`, then '
        'use `--params_override` to adjust the minimal set of tuning parameters, '
        'for example setting up different `train_batch_size`. The final override '
        'order of parameters: default_model_params --> params from config_file '
        '--> params in params_override. See also the help message of '
        '`--config_file`.')

    flags.DEFINE_multi_string('gin_file',
                              default=None,
                              help='List of paths to the config files.')

    flags.DEFINE_multi_string(
        'gin_params',
        default=None,
        help='Newline separated list of Gin parameter bindings.')

    flags.DEFINE_string(
        'tpu',
        default=None,
        help='The Cloud TPU to use for training. This should be either the name '
        'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 '
        'url.')

    flags.DEFINE_string('tf_data_service',
                        default=None,
                        help='The tf.data service address')
Exemplo n.º 26
0
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import pickle
from absl import app
from absl import flags

from tleague.model_pools.model_pool_apis import ModelPoolAPIs
from tleague.model_pools.model import Model

FLAGS = flags.FLAGS

# Model pool address string; main() splits this value on ','.
flags.DEFINE_string(
    "model_pool_addrs",
    default="localhost:10003:10004",
    help="Model Pool address.")

# Repeatable flags; main() pairs model_key entries with model_path entries
# positionally via zip().
flags.DEFINE_multi_string("model_path", default=[], help="model file path")
flags.DEFINE_multi_string("model_key", default=[], help="model_keys")


def main(_):
    """Re-push pickled models under keys that already exist in the pool.

    Pairs each ``--model_key`` with the ``--model_path`` at the same position
    (extra entries on the longer list are ignored by ``zip``). Keys that the
    model pool does not report are skipped silently. For each remaining key
    the current pool entry is pulled, the model is unpickled from disk
    (unwrapping a ``Model`` wrapper to its ``.model`` attribute), and pushed
    back with the pulled entry's hyperparam, key and timestamps.

    Args:
        _: Unused positional argument.
    """
    pool = ModelPoolAPIs(FLAGS.model_pool_addrs.split(','))
    known_keys = pool.pull_keys()
    for wanted_key, pickle_path in zip(FLAGS.model_key, FLAGS.model_path):
        if wanted_key not in known_keys:
            continue
        existing = pool.pull_model(wanted_key)
        with open(pickle_path, 'rb') as fin:
            # NOTE(review): pickle.load can execute arbitrary code; only
            # point --model_path at trusted files.
            payload = pickle.load(fin)
        if isinstance(payload, Model):
            payload = payload.model
        pool.push_model(payload, existing.hyperparam, existing.key,
                        existing.createtime, existing.freezetime,
                        existing.updatetime)
Exemplo n.º 27
0
# coding=utf-8
# Copyright 2022 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Common flags."""

from absl import flags

# Integer flags.
flags.DEFINE_integer(
    'trial_id', default=0, help='The trial ID from 0 to num_trials-1.')
flags.DEFINE_integer(
    'max_episode_len', default=1000, help='Number of steps in an episode.')

# String flags.
flags.DEFINE_string(
    'env_name', default='cartpole-swingup', help='Name of the environment.')
flags.DEFINE_string(
    'root_dir',
    default=None,
    help='Path to output trajectories from data collection')

# Gin configuration; both flags may be repeated.
flags.DEFINE_multi_string(
    'gin_files', default=None, help='Paths to the gin-config files.')
flags.DEFINE_multi_string(
    'gin_bindings', default=None, help='Gin binding parameters.')

# No default seed is set.
flags.DEFINE_integer(
    'seed', default=None, help='Random Seed for model_dir/data_dir')
Exemplo n.º 28
0
# Register the shared TPU and hparams flags before the flags defined below.
common_tpu_flags.define_common_tpu_flags()
common_hparams_flags.define_common_hparams_flags()

FLAGS = flags.FLAGS

FAKE_DATA_DIR = 'gs://cloud-tpu-test-datasets/fake_imagenet'

flags.DEFINE_string(
    'hparams_file',
    default=None,
    help=('Set of model parameters to override the default mparams.'))

# Repeatable flag: each occurrence carries one override string.
flags.DEFINE_multi_string(
    'hparams',
    default=None,
    help=('This is used to override only the model hyperparameters. It should '
          'not be used to override the other parameters like the tpu specific '
          'flags etc. For example, if experimenting with larger numbers of '
          'train_steps, a possible value is '
          '--hparams=train_steps=28152.'))

# The default path is resolved relative to this file's directory, not the
# current working directory.
flags.DEFINE_string(
    'default_hparams_file',
    default=os.path.join(os.path.dirname(__file__), './configs/default.yaml'),
    help=(
        'Default set of model parameters to use with this model. Look at '
        'configs/default.yaml for this.'))

flags.DEFINE_integer(
    'resnet_depth',
    default=None,
    help=('Depth of ResNet model to use. Must be one of {18, 34, 50, 101, 152,'
Exemplo n.º 29
0
from six.moves import range
import tensorflow as tf
from tf_agents.agents.ddpg import critic_network
from tf_agents.agents.sac import tanh_normal_projection_network
from tf_agents.drivers import dynamic_step_driver
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import actor_distribution_network
from tf_agents.policies import greedy_policy
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.utils import common

# root_dir defaults to the TEST_UNDECLARED_OUTPUTS_DIR environment variable,
# which is read once at import time (None when the variable is unset).
flags.DEFINE_string(
    'root_dir',
    default=os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'),
    help='Root directory for writing logs/summaries/checkpoints.')
# Both gin flags may be repeated on the command line.
flags.DEFINE_multi_string(
    'gin_file', default=None, help='Path to the trainer config files.')
flags.DEFINE_multi_string(
    'gin_bindings', default=None, help='Gin binding to pass through.')

FLAGS = flags.FLAGS


@gin.configurable
def bce_loss(y_true, y_pred, label_smoothing=0):
  """Binary cross-entropy with reduction disabled.

  Both inputs are indexed with `[:, None]`, which inserts a new axis at
  position 1, before being passed to a Keras `BinaryCrossentropy` built with
  `Reduction.NONE`.

  Args:
    y_true: Target values; must support `[:, None]` indexing.
    y_pred: Predicted values; must support `[:, None]` indexing.
    label_smoothing: Forwarded to `BinaryCrossentropy`. Defaults to 0.

  Returns:
    The unreduced loss returned by the Keras loss object.
  """
  unreduced_bce = tf.keras.losses.BinaryCrossentropy(
      label_smoothing=label_smoothing,
      reduction=tf.keras.losses.Reduction.NONE)
  targets = y_true[:, None]
  predictions = y_pred[:, None]
  return unreduced_bce(targets, predictions)


@gin.configurable
class ClassifierCriticNetwork(critic_network.CriticNetwork):
  """Creates a critic network."""
Exemplo n.º 30
0
        params = utils.get_gin_params_as_dict(gin.config._CONFIG)
        neptune.init(project_qualified_name="melindafkiss/sandbox")

        exp = neptune.create_experiment(params=params, name="exp")
        #ONLY WORKS FOR ONE GIN-CONFIG FILE
        with open(FLAGS.gin_file[0]) as ginf:
            param = ginf.readline()
            while param:
                param = param.replace('.','-').replace('=','-').replace(' ','').replace('\'','').replace('\n','').replace('@','')
                #neptune.append_tag(param)
                param = ginf.readline()
        #for tag in opts['tags'].split(','):
        #  neptune.append_tag(tag)
    else:
        neptune.init('shared/onboarding', api_token='ANONYMOUS', backend=neptune.OfflineBackend())

    er = ExperimentRunner(prefix=exp.id)
    er.train()

    params = utils.get_gin_params_as_dict(gin.config._OPERATIVE_CONFIG)
    for k, v in params.items():
        neptune.set_property(k, v)
    neptune.stop()
    print('fin')

if __name__ == '__main__':
    # Flags are registered only when the module runs as a script, then the
    # absl app entry point parses them and invokes main.
    flags.DEFINE_multi_string(
        'gin_file',
        default=None,
        help='List of paths to the config files.')
    flags.DEFINE_multi_string(
        'gin_param',
        default=None,
        help='Newline separated list of Gin parameter bindings.')
    FLAGS = flags.FLAGS
    app.run(main)