def finetune(self, mixture_or_task_name, finetune_steps, pretrained_model_dir,
               pretrained_checkpoint_step=-1, split="train"):
    """Trains further, starting from a pretrained model's config and weights.

    The operative gin config found in `pretrained_model_dir` is parsed before
    `self.train` is invoked.

    Args:
      mixture_or_task_name: str, the name of the Mixture or Task to finetune
        on. Must be pre-registered in the global `TaskRegistry` or
        `MixtureRegistry.`
      finetune_steps: int, the number of additional steps to train for.
      pretrained_model_dir: str, directory with pretrained model checkpoints and
        operative config.
      pretrained_checkpoint_step: int, checkpoint to initialize weights from.
        -1 (default) looks up the latest checkpoint in the pretrained model
        directory; 0 trains for `finetune_steps` with `init_checkpoint=None`.
      split: str, the mixture/task split to finetune on.
    """
    if pretrained_checkpoint_step == 0:
        # Step 0: only the pretrained config is reused, no weights are passed.
        with gin.unlock_config():
            gin.parse_config_file(_operative_config_path(pretrained_model_dir))
        self.train(mixture_or_task_name, finetune_steps,
                   init_checkpoint=None, split=split)
        return

    start_step = pretrained_checkpoint_step
    if start_step == -1:
        start_step = utils.get_latest_checkpoint_from_dir(pretrained_model_dir)

    with gin.unlock_config():
        gin.parse_config_file(_operative_config_path(pretrained_model_dir))

    init_checkpoint = os.path.join(pretrained_model_dir,
                                   "model.ckpt-" + str(start_step))
    # The step target passed to `train` is the starting step plus the
    # requested number of additional steps.
    self.train(mixture_or_task_name, start_step + finetune_steps,
               init_checkpoint=init_checkpoint, split=split)
  def estimator(self, vocabulary, init_checkpoint=None, disable_tpu=False,
                score_in_predict_mode=False):
    """Builds an estimator from this instance's settings.

    When no TPU is configured, or `disable_tpu` is set, the variable slice and
    activation dtypes are bound to "float32" first. The instance's gin
    bindings are then re-parsed and everything is forwarded to
    `utils.get_estimator`.

    Args:
      vocabulary: forwarded unchanged to `utils.get_estimator`.
      init_checkpoint: forwarded unchanged to `utils.get_estimator`.
      disable_tpu: bool, if True an empty mesh shape, no mesh devices and
        `use_tpu=None` are passed instead of the instance's TPU settings.
      score_in_predict_mode: forwarded unchanged to `utils.get_estimator`.

    Returns:
      The return value of `utils.get_estimator`.
    """
    if not self._tpu or disable_tpu:
      float32_bindings = (
          "utils.get_variable_dtype.slice_dtype",
          "utils.get_variable_dtype.activation_dtype",
      )
      with gin.unlock_config():
        for binding_name in float32_bindings:
          gin.bind_parameter(binding_name, "float32")

    with gin.unlock_config():
      gin.parse_config(self._gin_bindings)

    # Resolve the TPU-dependent arguments up front.
    if disable_tpu:
      mesh_shape = mtf.Shape([])
      mesh_devices = None
      use_tpu = None
    else:
      mesh_shape = self._mesh_shape
      mesh_devices = self._mesh_devices
      use_tpu = self._tpu

    estimator_kwargs = dict(
        model_type=self._model_type,
        vocabulary=vocabulary,
        layout_rules=self._layout_rules,
        mesh_shape=mesh_shape,
        mesh_devices=mesh_devices,
        model_dir=self._model_dir,
        batch_size=self.batch_size,
        sequence_length=self._sequence_length,
        autostack=self._autostack,
        learning_rate_schedule=self._learning_rate_schedule,
        keep_checkpoint_max=self._keep_checkpoint_max,
        save_checkpoints_steps=self._save_checkpoints_steps,
        optimizer=self._optimizer,
        predict_fn=self._predict_fn,
        variable_filter=self._variable_filter,
        ensemble_inputs=self._ensemble_inputs,
        use_tpu=use_tpu,
        tpu_job_name=self._tpu_job_name,
        iterations_per_loop=self._iterations_per_loop,
        cluster=self._cluster,
        init_checkpoint=init_checkpoint,
        score_in_predict_mode=score_in_predict_mode,
    )
    return utils.get_estimator(**estimator_kwargs)
示例#3
0
def load_model(instrument_model, audio_length):
    """Locate a solo-instrument checkpoint and resize the gin config to fit.

    Args:
        instrument_model: str, instrument name; lower-cased to form the
            checkpoint folder name under `CKPT_DIR`.
        audio_length: length of the audio in samples; rounded down to a whole
            number of hops.

    Returns:
        Tuple `(ckpt, time_steps, n_samples)`: the checkpoint path prefix and
        the adjusted frame and sample counts written into the gin config.
    """
    model_dir = os.path.join(CKPT_DIR,
                             "solo_%s_ckpt" % instrument_model.lower())

    # Only the first matching file is used, so the folder is expected to hold
    # a single checkpoint named 'model.ckpt-[iter]'.
    candidates = [
        name for name in tf.gfile.ListDirectory(model_dir)
        if "model.ckpt" in name
    ]
    prefix_parts = candidates[0].split(".")[:2]
    ckpt = os.path.join(model_dir, ".".join(prefix_parts))

    # Load the config the model was trained with.
    gin_file = os.path.join(model_dir, "operative_config-0.gin")
    with gin.unlock_config():
        gin.parse_config_file(gin_file, skip_unknown=True)

    # Derive the hop size from the training config, then fit the requested
    # audio length to a whole number of hops.
    time_steps_train = gin.query_parameter("DefaultPreprocessor.time_steps")
    n_samples_train = gin.query_parameter("Additive.n_samples")
    hop_size = int(n_samples_train / time_steps_train)

    time_steps = int(audio_length / hop_size)
    n_samples = time_steps * hop_size

    overrides = (
        ("Additive.n_samples", n_samples),
        ("FilteredNoise.n_samples", n_samples),
        ("DefaultPreprocessor.time_steps", time_steps),
    )
    gin_params = [
        "{} = {}".format(binding, value) for binding, value in overrides
    ]

    with gin.unlock_config():
        gin.parse_config(gin_params)

    return ckpt, time_steps, n_samples
示例#4
0
    def load_model(self, model):
        """Load a pretrained autoencoder and size it to the current audio.

        Parses the pretrained operative gin config, overrides the sample and
        frame counts to match `self.audio`, trims `self.audio_features` to
        those lengths, restores the checkpoint into
        `ddsp.training.models.Autoencoder` and runs one batch through it so
        the model is built.

        Args:
            model: str, name of the pretrained model. Must be one of
                'Violin', 'Flute', 'Flute2', 'Trumpet', 'Tenor_Saxophone'.

        Raises:
            ValueError: if `model` is not a supported pretrained model.
        """
        self.model_name = model

        pretrained_models = ('Violin', 'Flute', 'Flute2', 'Trumpet',
                             'Tenor_Saxophone')
        if model not in pretrained_models:
            # Without this guard `gin_file` and `model_dir` were never bound,
            # and the function crashed further down with an unrelated error.
            raise ValueError('Unsupported model {!r}; expected one of: {}'.format(
                model, ', '.join(pretrained_models)))

        # Every pretrained model is read from the same folder.
        model_dir = 'pretrained'
        gin_file = os.path.join(model_dir, 'operative_config-0.gin')

        # Parse gin config,
        with gin.unlock_config():
            gin.parse_config_file(gin_file, skip_unknown=True)

        # Assumes only one checkpoint in the folder, 'ckpt-[iter]`.
        ckpt_files = [f for f in tf.io.gfile.listdir(model_dir) if 'ckpt' in f]
        ckpt_name = ckpt_files[0].split('.')[0]
        ckpt = os.path.join(model_dir, ckpt_name)

        # Ensure dimensions and sampling rates are equal: derive the hop size
        # from the training config and fit the audio to a whole number of hops.
        time_steps_train = gin.query_parameter(
            'DefaultPreprocessor.time_steps')
        n_samples_train = gin.query_parameter('Additive.n_samples')
        hop_size = int(n_samples_train / time_steps_train)
        print(self.audio.shape[1])
        time_steps = int(self.audio.shape[1] / hop_size)
        print(time_steps)
        n_samples = time_steps * hop_size
        print(n_samples)
        gin_params = [
            'Additive.n_samples = {}'.format(n_samples),
            'FilteredNoise.n_samples = {}'.format(n_samples),
            'DefaultPreprocessor.time_steps = {}'.format(time_steps),
        ]

        with gin.unlock_config():
            gin.parse_config(gin_params)

        # Trim all input vectors to correct lengths
        for key in ['f0_hz', 'f0_confidence', 'loudness_db']:
            print(type(self.audio_features[key]))
            self.audio_features[key] = self.audio_features[key][:time_steps]

        print(self.audio_features['audio'].shape)
        print(n_samples)
        self.audio_features['audio'] = self.audio_features[
            'audio'][:, :n_samples]

        # Set up the model just to predict audio given new conditioning
        self.model = ddsp.training.models.Autoencoder()
        self.model.restore(ckpt)

        # Build model by running a batch through it.
        # NOTE: the timer below covers only this forward pass, not restore().
        start_time = time.time()
        _ = self.model(self.audio_features, training=False)
        print('Restoring model took %.1f seconds' % (time.time() - start_time))
示例#5
0
    def _worker(self, root_dir, parameters, device_queue):
        """Run one training job on a device borrowed from `device_queue`.

        Args:
            root_dir: passed through to `train_eval`.
            parameters: dict mapping gin binding names to values; each pair
                is parsed as a `name=value` gin binding before training.
            device_queue: queue of device ids. One id is taken for the
                duration of the job and put back afterwards, including when
                the job fails.

        Raises:
            Exception: any error raised while running the job is logged and
                re-raised.
        """
        try:
            # sleep for random seconds to avoid crowded launching
            time.sleep(random.uniform(0, 3))

            device = device_queue.get()
            try:
                if self._conf.use_gpu:
                    os.environ["CUDA_VISIBLE_DEVICES"] = str(device)
                else:
                    os.environ["CUDA_VISIBLE_DEVICES"] = ""  # run on cpu

                from alf.utils.common import set_per_process_memory_growth
                set_per_process_memory_growth()

                logging.set_verbosity(logging.INFO)

                logging.info("parameters %s" % parameters)
                with gin.unlock_config():
                    gin.parse_config(
                        ['%s=%s' % (k, v) for k, v in parameters.items()])
                train_eval(root_dir)
            finally:
                # Return the device even when the job fails; otherwise the
                # id is lost and later jobs wait on an emptied queue.
                device_queue.put(device)
        except Exception as e:
            logging.info(e)
            raise
def latent1d(ctx, rows, cols, plot, filename, **kwargs):
    """Run a 1D latent space traversal visualization.

    Adds the `evaluate/visual/latent1d.gin` config, parses it with seeding
    enabled, applies the command-line overrides as gin bindings and calls
    `disentangled.visualize.traversal1d` on the model and dataset pipeline
    stored in `ctx.obj`.

    Args:
        ctx: context object; `ctx.obj["dataset"]` and `ctx.obj["model"]` are
            read.
        rows: optional override for `traversal1d.dimensions`.
        cols: optional override for `traversal1d.steps`.
        plot: value bound to `show.output.show_plot` (always bound).
        filename: optional override for `show.output.filename`.
        **kwargs: accepted and ignored.
    """
    add_gin(ctx, "config", ["evaluate/visual/latent1d.gin"])
    parse(ctx, set_seed=True)

    # Overrides that are applied only when a value was supplied.
    optional_bindings = (
        ("disentangled.visualize.show.output.filename", filename),
        ("disentangled.visualize.traversal1d.dimensions", rows),
        ("disentangled.visualize.traversal1d.steps", cols),
    )
    with gin.unlock_config():
        gin.bind_parameter("disentangled.visualize.show.output.show_plot",
                           plot)
        for binding_name, value in optional_bindings:
            if value is not None:
                gin.bind_parameter(binding_name, value)

    pipeline = ctx.obj["dataset"].pipeline()
    # The remaining arguments are marked gin.REQUIRED so that they must come
    # from the gin configuration.
    disentangled.visualize.traversal1d(
        ctx.obj["model"],
        pipeline,
        dimensions=gin.REQUIRED,
        offset=gin.REQUIRED,
        skip_batches=gin.REQUIRED,
        steps=gin.REQUIRED,
    )
  def score(self,
            inputs,
            targets,
            scores_file=None,
            checkpoint_steps=-1,
            vocabulary=None):
    """Scores targets by computing their per-example log-likelihood.

    Args:
      inputs: optional - a string (filename), or a list of strings (inputs)
      targets: a string (filename), or a list of strings (targets)
      scores_file: str, path to write example scores to, one per line.
      checkpoint_steps: int, list of ints, or None. If an int or list of ints,
        inference will be run on the checkpoint files in `model_dir` whose
        global steps are closest to the global steps provided. If None, run
        inference continuously waiting for new checkpoints. If -1, get the
        latest checkpoint from the model directory.
      vocabulary: vocabularies.Vocabulary object to use for tokenization, or
        None to use the vocabulary from `t5.data.get_default_vocabulary()`.
    """
    model_dir = self._model_dir
    if checkpoint_steps == -1:
      checkpoint_steps = _get_latest_checkpoint_from_dir(model_dir)

    with gin.unlock_config():
      gin.parse_config_file(_operative_config_path(model_dir))
      # Switch predict mode from inference to scoring.
      gin.bind_parameter("tpu_estimator_model_fn.score_in_predict_mode", True)

    if vocabulary is None:
      vocabulary = t5.data.get_default_vocabulary()

    scoring_estimator = self.estimator(vocabulary)
    utils.score_from_strings(
        scoring_estimator,
        vocabulary,
        self._model_type,
        self.batch_size,
        self._sequence_length,
        model_dir,
        checkpoint_steps,
        inputs,
        targets,
        scores_file)
示例#8
0
  def export(self, export_dir=None, checkpoint_step=-1, beam_size=1,
             temperature=1.0,
             sentencepiece_model_path=t5.data.DEFAULT_SPM_PATH):
    """Writes a TensorFlow SavedModel for one checkpoint of this model.

    Args:
      export_dir: str, a directory in which to export SavedModels. Will use
        `model_dir` if unspecified.
      checkpoint_step: int, checkpoint to export. If -1 (default), use the
        latest checkpoint found in this model's `model_dir`.
      beam_size: int, a number >= 1 specifying the number of beams to use for
        beam search.
      temperature: float, a value between 0 and 1 (must be 0 if beam_size > 1)
        0.0 means argmax, 1.0 means sample according to predicted distribution.
      sentencepiece_model_path: str, path to the SentencePiece model file to use
        for decoding. Must match the one used during training.
    """
    if checkpoint_step == -1:
      checkpoint_step = _get_latest_checkpoint_from_dir(self._model_dir)

    # Decoding settings plus float32 dtypes for the exported graph.
    export_bindings = (
        ("Bitransformer.decode.beam_size", beam_size),
        ("Bitransformer.decode.temperature", temperature),
        ("utils.get_variable_dtype.slice_dtype", "float32"),
        ("utils.get_variable_dtype.activation_dtype", "float32"),
    )
    with gin.unlock_config():
      gin.parse_config_file(_operative_config_path(self._model_dir))
      for binding_name, value in export_bindings:
        gin.bind_parameter(binding_name, value)

    vocabulary = t5.data.SentencePieceVocabulary(sentencepiece_model_path)
    checkpoint_path = os.path.join(
        self._model_dir, "model.ckpt-" + str(checkpoint_step))
    if not export_dir:
      export_dir = self._model_dir

    # The estimator used for export is always built with the TPU disabled.
    export_estimator = self.estimator(vocabulary, disable_tpu=True)
    utils.export_model(
        export_estimator,
        export_dir,
        vocabulary,
        self._sequence_length,
        batch_size=self.batch_size,
        checkpoint_path=checkpoint_path)
示例#9
0
    def configure_gin(self, ckpt):
        """Load the operative config for `ckpt`, then apply streaming overrides.

        The overrides swap the preprocessor for `F0PowerPreprocessor` with a
        single time step, shrink the synthesizers to one frame's worth of
        samples, and replace the processor DAG with harmonic + filtered noise
        summed by `Add`.

        Args:
            ckpt: checkpoint reference handed to `parse_operative_config`.
        """
        parse_operative_config(ckpt)

        # Look up the trained frame count through whichever preprocessor the
        # autoencoder is configured with.
        preprocessor_selector = gin.query_parameter(
            'Autoencoder.preprocessor').scoped_selector
        time_steps = gin.query_parameter(
            f'{preprocessor_selector}.time_steps')

        n_samples = gin.query_parameter('Harmonic.n_samples')
        if not isinstance(n_samples, int):
            # Not a plain int: read the `n_samples` gin macro instead.
            n_samples = gin.query_parameter('%n_samples')
        samples_per_frame = int(n_samples / time_steps)

        # Processor group without the reverb and crop processors.
        streaming_dag = """ProcessorGroup.dag = [
    (@synths.Harmonic(),
      ['amps', 'harmonic_distribution', 'f0_hz']),
    (@synths.FilteredNoise(),
      ['noise_magnitudes']),
    (@processors.Add(),
      ['filtered_noise/signal', 'harmonic/signal']),
    ]"""

        streaming_config = [
            'Autoencoder.preprocessor = @F0PowerPreprocessor()',
            'F0PowerPreprocessor.time_steps = 1',
            f'Harmonic.n_samples = {samples_per_frame}',
            f'FilteredNoise.n_samples = {samples_per_frame}',
            streaming_dag,
        ]

        with gin.unlock_config():
            gin.parse_config(streaming_config)
示例#10
0
 def testSingleTrainingStepArchitectures(self,
                                         use_predictor,
                                         project_y=True,
                                         self_supervision="none"):
     """Trains an S3GAN estimator for one step with the given S3GAN options.

     Args:
       use_predictor: value bound to `S3GAN.use_predictor`.
       project_y: value bound to `S3GAN.project_y`.
       self_supervision: value bound to `S3GAN.self_supervision`.
     """
     gan_parameters = {
         "architecture": c.RESNET_BIGGAN_ARCH,
         "lambda": 1,
         "z_dim": 120,
     }
     test_bindings = (
         ("ModularGAN.conditional", True),
         ("loss.fn", loss_lib.hinge),
         ("S3GAN.use_predictor", use_predictor),
         ("S3GAN.project_y", project_y),
         ("S3GAN.self_supervision", self_supervision),
     )
     with gin.unlock_config():
         for binding_name, value in test_bindings:
             gin.bind_parameter(binding_name, value)

     dataset = datasets.get_dataset("imagenet_128")
     model_dir = self._get_empty_model_dir()
     tpu_config = tf.contrib.tpu.TPUConfig(iterations_per_loop=1)
     run_config = tf.contrib.tpu.RunConfig(
         model_dir=model_dir, tpu_config=tpu_config)

     gan = S3GAN(
         dataset=dataset,
         parameters=gan_parameters,
         model_dir=model_dir,
         g_optimizer_fn=tf.train.AdamOptimizer,
         g_lr=0.0002,
         rotated_batch_fraction=2,
     )
     # One CPU training step is enough to exercise the model graph.
     estimator = gan.as_estimator(run_config, batch_size=8, use_tpu=False)
     estimator.train(gan.input_fn, steps=1)
def descartes_builder(name='out', params=None):
    """Write one gin config file per point of a Cartesian-product grid.

    For each name in `params`, the value currently bound in gin is queried
    and treated as the iterable of candidate values. For every combination
    of candidates, the parameters are re-bound to that combination and the
    resulting gin config is written to `grids/<name>/<index>/config.gin`.

    Note that the gin bindings are left at the last combination when the
    function returns.

    Args:
        name: str, sub-folder of `grids` that receives the experiments.
        params: iterable of gin parameter names to sweep over. `None`
            (default) is treated as an empty list, which yields a single
            experiment containing the unmodified config.
    """
    # A None default avoids sharing one mutable list across calls.
    if params is None:
        params = []
    print("Building grid search with parameters: ", params)

    directory = os.path.join('grids', name)
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(directory, exist_ok=True)

    # Materialize once so that one-shot iterables can be reused below.
    all_params = list(params)
    all_values = [gin.query_parameter(param) for param in all_params]

    for index, combination in enumerate(itertools.product(*all_values)):
        exp_directory = os.path.join(directory, str(index))
        os.makedirs(exp_directory, exist_ok=True)

        with gin.unlock_config():
            for param, value in zip(all_params, combination):
                gin.bind_parameter(param, value)

        config_str = gin.config_str()
        with open(os.path.join(exp_directory, 'config.gin'), 'w') as f:
            f.write(config_str)
示例#12
0
    def finetune(self, mixture_or_task_name, finetune_steps,
                 pretrained_model_dir, checkpoint_step):
        """Trains further, starting from a pretrained checkpoint.

        The `operative_config.gin` stored in `pretrained_model_dir` is parsed
        before `self.train` is invoked.

        Args:
          mixture_or_task_name: str, the name of the Mixture or Task to
            finetune on. Must be pre-registered in the global `TaskRegistry`
            or `MixtureRegistry.`
          finetune_steps: int, the number of additional steps to train for.
          pretrained_model_dir: str, directory with pretrained model
            checkpoints and operative config.
          checkpoint_step: int, checkpoint to initialize weights from. If -1,
            use the latest checkpoint from the pretrained model directory.
        """
        start_step = checkpoint_step
        if start_step == -1:
            start_step = get_latest_checkpoint_from_dir(pretrained_model_dir)

        config_path = os.path.join(pretrained_model_dir,
                                   "operative_config.gin")
        with gin.unlock_config():
            gin.parse_config_file(config_path)

        init_checkpoint = os.path.join(pretrained_model_dir,
                                       "model.ckpt-" + str(start_step))
        # The step target passed to `train` is the starting step plus the
        # requested number of additional steps.
        self.train(mixture_or_task_name,
                   start_step + finetune_steps,
                   init_checkpoint=init_checkpoint)
示例#13
0
  def configure_gin(self, ckpt):
    """Load the operative config for `ckpt`, then apply streaming overrides.

    The overrides set the preprocessor to a single time step, shrink the
    synthesizers to one frame's worth of samples, and replace the processor
    DAG with harmonic + filtered noise summed by `Add`.

    Args:
      ckpt: checkpoint reference handed to `parse_operative_config`.
    """
    parse_operative_config(ckpt)

    # One frame's worth of samples, derived from the trained configuration.
    trained_time_steps = gin.query_parameter('F0PowerPreprocessor.time_steps')
    trained_n_samples = gin.query_parameter('Harmonic.n_samples')
    samples_per_frame = int(trained_n_samples / trained_time_steps)

    # Processor group without the reverb processor.
    streaming_dag = """ProcessorGroup.dag = [
    (@synths.Harmonic(),
      ['amps', 'harmonic_distribution', 'f0_hz']),
    (@synths.FilteredNoise(),
      ['noise_magnitudes']),
    (@processors.Add(),
      ['filtered_noise/signal', 'harmonic/signal']),
    ]"""

    streaming_config = [
        'F0PowerPreprocessor.time_steps = 1',
        f'Harmonic.n_samples = {samples_per_frame}',
        f'FilteredNoise.n_samples = {samples_per_frame}',
        streaming_dag,
    ]

    with gin.unlock_config():
      gin.parse_config(streaming_config)
示例#14
0
  def _runSingleTrainingStep(self, architecture, loss_fn, penalty_fn):
    """Trains a CLGAN estimator on cifar10 for a single step.

    Args:
      architecture: value stored under "architecture" in the GAN parameters.
      loss_fn: value bound to the gin parameter `loss.fn`.
      penalty_fn: value bound to the gin parameter `penalty.fn`.
    """
    gan_parameters = {
        "architecture": architecture,
        "lambda": 1,
        "z_dim": 128,
    }
    test_bindings = (
        ("penalty.fn", penalty_fn),
        ("loss.fn", loss_fn),
        ("G.batch_norm_fn", evonorm_s0),
    )
    with gin.unlock_config():
      for binding_name, value in test_bindings:
        gin.bind_parameter(binding_name, value)

    model_dir = self._get_empty_model_dir()
    tpu_config = tf.contrib.tpu.TPUConfig(iterations_per_loop=1)
    run_config = tf.contrib.tpu.RunConfig(
        model_dir=model_dir, tpu_config=tpu_config)
    dataset = datasets.get_dataset("cifar10")

    gan = CLGAN(
        dataset=dataset,
        parameters=gan_parameters,
        model_dir=model_dir,
        g_optimizer_fn=tf.train.AdamOptimizer,
        g_lr=0.0002,
    )
    # One CPU training step is enough to exercise the model graph.
    estimator = gan.as_estimator(run_config, batch_size=2, use_tpu=False)
    estimator.train(gan.input_fn, steps=1)
示例#15
0
  def test_build_layer(self, kwarg_modules):
    """Checks that the DAG layer builds and emits correctly shaped outputs.

    Args:
      kwarg_modules: if truthy, configure the layer from
        `self.gin_config_kwarg_modules`; otherwise from
        `self.gin_config_dag_modules`.
    """
    if kwarg_modules:
      layer_config = self.gin_config_kwarg_modules
    else:
      layer_config = self.gin_config_dag_modules

    # Start from a clean gin state so bindings cannot leak between cases.
    with gin.unlock_config():
      gin.clear_config()
      gin.parse_config(layer_config)

    layer = ConfigurableDAGLayer()
    outputs = layer(self.inputs)
    self.assertIsInstance(outputs, dict)

    latent = outputs['bottleneck']['z_bottleneck']
    decoder_rec = outputs['decoder']['reconstruction']
    final_rec = outputs['out']['reconstruction']

    # Shapes: input passthrough, reconstruction and latent width.
    self.assertEqual(outputs['test_data'].shape, self.x.shape)
    self.assertEqual(outputs['inputs']['test_data'].shape, self.x.shape)
    self.assertEqual(decoder_rec.shape, self.x.shape)
    self.assertEqual(latent.shape[-1], self.z_dims)
    self.assertAllClose(decoder_rec, final_rec)

    # The layer must expose its modules' variables: 3 weights and 3 biases.
    self.assertLen(layer.trainable_variables, 6)
示例#16
0
def parse_gin(restore_dir):
    """Parse gin config from defaults, the restore directory and flags.

    Sources are parsed in this order, so later ones override earlier ones:
    optimization defaults, the latest operative config found for
    `restore_dir` (if it exists), then `--gin_file` / `--gin_param`.

    Args:
        restore_dir: directory passed to
            `train_util.get_latest_operative_config`.
    """
    # Make user folders visible to gin's config file lookup.
    search_paths = [GIN_PATH] + FLAGS.gin_search_path
    for search_path in search_paths:
        gin.add_config_file_search_path(search_path)

    optimization_file = 'base_tpu.gin' if FLAGS.tpu else 'base.gin'

    with gin.unlock_config():
        # 1) Optimization defaults.
        gin.parse_config_file(os.path.join('optimization', optimization_file))

        # 2) Operative config of an already trained model, if there is one.
        operative_config = train_util.get_latest_operative_config(restore_dir)
        if tf.io.gfile.exists(operative_config):
            # Copy the config file from gstorage and parse the copy.
            helper_functions.copy_config_file_from_gstorage(
                operative_config, LAST_OPERATIVE_CONFIG_PATH)
            logging.info('Using operative config: %s', operative_config)
            gin.parse_config_file(LAST_OPERATIVE_CONFIG_PATH,
                                  skip_unknown=True)

        # 3) User gin files and bindings from flags.
        gin.parse_config_files_and_bindings(
            FLAGS.gin_file, FLAGS.gin_param, skip_unknown=True)
示例#17
0
def parse_gin(model_dir):
    """Parse gin config from defaults, the model directory and flags.

    Sources are parsed in this order, so later ones override earlier ones:
    optimization defaults, `operative_config-0.gin` in `model_dir` (if it
    exists), the TPU cumsum switch, then `--gin_file` / `--gin_param`.

    Args:
        model_dir: directory that may contain `operative_config-0.gin`.
    """
    # Make user folders visible to gin's config file lookup.
    search_paths = [GIN_PATH] + FLAGS.gin_search_path
    for search_path in search_paths:
        gin.add_config_file_search_path(search_path)

    use_tpu = bool(FLAGS.tpu)
    optimization_file = 'base_tpu.gin' if use_tpu else 'base.gin'
    operative_config = os.path.join(model_dir, 'operative_config-0.gin')

    with gin.unlock_config():
        # 1) Optimization defaults.
        gin.parse_config_file(os.path.join('optimization', optimization_file))

        # 2) Operative config of an already trained model, if there is one.
        if tf.io.gfile.exists(operative_config):
            gin.parse_config_file(operative_config, skip_unknown=True)

        # 3) The custom cumsum is enabled only when running on TPU.
        gin.parse_config('ddsp.core.cumsum.use_tpu={}'.format(use_tpu))

        # 4) User gin files and bindings from flags.
        gin.parse_config_files_and_bindings(
            FLAGS.gin_file, FLAGS.gin_param, skip_unknown=True)
示例#18
0
    def reset_config(self, new_config):
        """Resets trainer config for fast PBT implementation.

        Applies `new_config['mutable_bindings']` under an unlocked gin config,
        then clears whichever cached members the `reset_*` flags in
        `new_config` request, stores the new config and re-saves the
        operative config.

        Args:
            new_config: mapping with a required `'mutable_bindings'` entry and
                optional truthy flags `'reset_session'`, `'reset_optimizer'`,
                `'reset_problem'`, `'reset_generators'` and `'reset_model'`.

        Returns:
            Always `True`.
        """
        with gin.unlock_config():
            _parse_config_item(None, new_config['mutable_bindings'])

        if new_config.get('reset_session'):
            # A session reset takes precedence: the finer-grained flags below
            # are not consulted in this case.
            self.reset()
        else:
            if new_config.get('reset_optimizer'):
                # The model is cleared together with the optimizer.
                self._optimizer = None
                self._model = None
            if new_config.get('reset_problem'):
                # The generators are cleared together with the problem.
                self._problem = None
                self._generators = None
            if new_config.get('reset_generators'):
                self._generators = None
            if new_config.get('reset_model'):
                # Round-trip through a temporary checkpoint: save, drop the
                # model, then restore from what was just saved.
                # NOTE(review): presumably the model is rebuilt under the new
                # bindings during restore - confirm against `_restore`.
                with tempfile.TemporaryDirectory() as tmp_dir:
                    checkpoint = self._save(tmp_dir)
                    self._model = None
                    self._restore(checkpoint)
        self.config = new_config
        self._reset_callbacks()
        self._save_operative_config()
        return True
示例#19
0
文件: inference.py 项目: zeeps31/ddsp
 def parse_gin_config(self, ckpt):
     """Parse the operative config next to `ckpt`, then apply streaming overrides.

     The overrides set the preprocessor to a single time step, shrink the
     synthesizers to one frame's worth of samples, and replace the processor
     DAG with harmonic + filtered noise summed by `Add`.

     Args:
       ckpt: checkpoint path; its directory is searched for the latest
         operative config.
     """
     # Processor group without the reverb processor.
     streaming_dag = """ProcessorGroup.dag = [
   (@synths.Harmonic(),
     ['amps', 'harmonic_distribution', 'f0_hz']),
   (@synths.FilteredNoise(),
     ['noise_magnitudes']),
   (@processors.Add(),
     ['filtered_noise/signal', 'harmonic/signal']),
   ]"""

     with gin.unlock_config():
         operative_config = train_util.get_latest_operative_config(
             os.path.dirname(ckpt))
         print(f'Parsing from operative_config {operative_config}')
         gin.parse_config_file(operative_config, skip_unknown=True)

         # One frame's worth of samples, derived from the trained config.
         trained_time_steps = gin.query_parameter(
             'F0PowerPreprocessor.time_steps')
         trained_n_samples = gin.query_parameter('Harmonic.n_samples')
         samples_per_frame = int(trained_n_samples / trained_time_steps)

         streaming_config = [
             'F0PowerPreprocessor.time_steps=1',
             f'Harmonic.n_samples={samples_per_frame}',
             f'FilteredNoise.n_samples={samples_per_frame}',
             streaming_dag,
         ]
         gin.parse_config(streaming_config)
 def create_env(self):
     """Create the RL environment according to config.

     If `self.config` contains `'env_config_file'`, that gin file is parsed
     first.

     Returns:
       The return value of `load_env()`.
     """
     if 'env_config_file' not in self.config:
         return load_env()

     with gin.unlock_config():
         gin.parse_config_file(self.config['env_config_file'])
     return load_env()
示例#21
0
    def _setup(self, config):
        """Initialize TensorFlow, gin and the trainer state from `config`.

        Args:
            config: mapping with required entries `'config_files'` (one name
                or a list of names), `'bindings'` and `'mutable_bindings'`,
                and optional entries `'allow_growth'`, `'verbosity'`,
                `'config_dir'` and `'initial_weights_path'`.
        """
        util.tf_init(gpus=None,
                     allow_growth=config.get('allow_growth', True),
                     eager=False)
        logging.set_verbosity(config.get('verbosity', logging.INFO))

        logging.info("calling setup")

        # Accept a single config file name as well as a list of names.
        config_files = config['config_files']
        if isinstance(config_files, six.string_types):
            config_files = [config_files]

        with gin.unlock_config():
            config_dir = config.get('config_dir')
            if config_dir is None:
                config_dir = util.get_config_dir()
            # Finalization is deferred until the mutable bindings are applied.
            util.parse_config(config_dir,
                              config_files,
                              config['bindings'],
                              finalize_config=False)
            _parse_config_item(None, config['mutable_bindings'])
            gin.finalize()

        # Clear every cached member.
        self._generators = self._problem = self._optimizer = self._model = None
        self._reset_callbacks()

        initial_weights_path = config.get('initial_weights_path')
        if initial_weights_path is not None:
            self.model.load_weights(initial_weights_path)

        self._save_operative_config()
示例#22
0
def parse_gin_defaults_and_flags() -> None:
    """Parse the default gin files, then the files and bindings from flags.

    Also expands the `{hostname}` and `{timestamp}` placeholders in the
    configured `utils.run.model_dir` and binds the result back into gin.
    """
    args = _parse_args()  # TODO: should we require that args be passed in?

    # User-supplied prefixes come first, followed by the packaged locations.
    search_paths = list(args.gin_location_prefix or [])
    search_paths.append(Path(__file__).parent.parent.joinpath("config"))
    search_paths.append(pkg_resources.resource_filename("t5.models", "gin"))
    search_paths.append(
        pkg_resources.resource_filename("mesh_tensorflow.transformer", "gin"))
    for search_path in search_paths:
        gin.add_config_file_search_path(search_path)

    # Optional defaults are parsed first so the flags can override them. The
    # first missing file ends the attempt; any remaining ones are skipped.
    try:
        for optional_file in ("defaults.gin", "operative_config.gin"):
            gin.parse_config_file(optional_file)
    except IOError:
        pass

    gin.parse_config_files_and_bindings(args.gin_file, args.gin_param)

    # Filling in host and timestamp means a unique model_dir does not have to
    # be specified for each run.
    model_dir_template = gin.query_parameter("utils.run.model_dir")
    model_dir = model_dir_template.format(
        hostname=platform.node(),
        timestamp=RUN_TIMESTAMP,
    )
    with gin.unlock_config():
        gin.bind_parameter("utils.run.model_dir", model_dir)

    tf_logging()
示例#23
0
def evaluate_tune(log_dir,
                  split='validation',
                  verbose=True,
                  eval_batch_size=None):
    """Evaluate the final checkpoint of a tuning run found under `log_dir`.

    Picks the `.gin` file ranked highest by `parse_operative_config_path` and
    the sub-directory whose name ends in the largest `_<int>` suffix, parses
    that config and runs `train.evaluate`.

    Args:
        log_dir: run directory; user and environment variables are expanded.
        split: passed through to `train.evaluate`.
        verbose: passed through to `train.evaluate`.
        eval_batch_size: batch size for evaluation; `blocks.batch_size()` is
            used when None.
    """
    from pointnet import train

    log_dir = os.path.realpath(os.path.expanduser(os.path.expandvars(log_dir)))
    entries = [
        os.path.join(log_dir, name) for name in tf.io.gfile.listdir(log_dir)
    ]

    def _config_rank(path):
        return parse_operative_config_path(path)[1]

    def _dir_index(path):
        return int(path.split('_')[-1])

    op_config = max((p for p in entries if p.endswith('.gin')),
                    key=_config_rank)
    final_dir = max((p for p in entries if tf.io.gfile.isdir(p)),
                    key=_dir_index)
    logging.info('Using operative_config {}, checkpoint {}'.format(
        op_config, final_dir))

    with gin.unlock_config():
        gin.parse_config_file(op_config)
    chkpt_callback = cb.ModelCheckpoint(final_dir)
    if eval_batch_size is None:
        eval_batch_size = blocks.batch_size()

    train.evaluate(
        blocks.problem(),
        blocks.model_fn(),
        blocks.optimizer(),
        eval_batch_size,
        chkpt_callback,
        split=split,
        verbose=verbose,
    )
示例#24
0
    def _worker(self, root_dir, parameters, device_queue):
        """Run one training job on a device borrowed from `device_queue`.

        Args:
            root_dir: passed through to `train_eval`.
            parameters: dict mapping gin binding names to values; each pair
                is parsed as a `name=value` gin binding before training.
            device_queue: queue of device ids. One id is taken for the
                duration of the job and put back afterwards, including when
                the job fails.

        Raises:
            Exception: any error raised while running the job is logged with
                its traceback and re-raised.
        """
        try:
            # sleep for random seconds to avoid crowded launching
            time.sleep(random.uniform(0, 3))

            device = device_queue.get()
            try:
                if self._conf.use_gpu:
                    os.environ["CUDA_VISIBLE_DEVICES"] = str(device)
                else:
                    os.environ["CUDA_VISIBLE_DEVICES"] = ""  # run on cpu

                if torch.cuda.is_available():
                    alf.set_default_device("cuda")
                logging.set_verbosity(logging.INFO)

                logging.info("Search parameters %s" % parameters)
                with gin.unlock_config():
                    gin.parse_config(
                        ['%s=%s' % (k, v) for k, v in parameters.items()])
                    gin.parse_config(
                        "TrainerConfig.confirm_checkpoint_upon_crash=False")
                train_eval(FLAGS.ml_type, root_dir)
            finally:
                # Return the device even when the job fails; otherwise the
                # id is lost and later jobs wait on an emptied queue.
                device_queue.put(device)
        except Exception:
            logging.info(traceback.format_exc())
            raise
def fixed(ctx, batch_size, filename, rows, cols, plot, verbose, **kwargs):
    """View/save images of a dataset with one latent factor held fixed.

    Args:
        ctx: context object; `ctx.obj['dataset']` holds the dataset name.
        batch_size: passed to `fixed_factor_dataset`.
        filename: value bound to `show.output.filename` (always bound).
        rows: optional override for `fixed_factor_data.rows`.
        cols: optional override for `fixed_factor_data.cols`.
        plot: value bound to `show.output.show_plot` (always bound).
        verbose: passed to `fixed_factor_data`.
        **kwargs: accepted and ignored.
    """
    dataset_name = ctx.obj['dataset']
    add_gin(ctx, 'config', ['evaluate/dataset/{}.gin'.format(dataset_name)])
    parse(ctx)

    # Grid overrides that are applied only when a value was supplied.
    optional_bindings = (
        ('disentangled.visualize.fixed_factor_data.rows', rows),
        ('disentangled.visualize.fixed_factor_data.cols', cols),
    )
    with gin.unlock_config():
        gin.bind_parameter('disentangled.visualize.show.output.show_plot',
                           plot)
        gin.bind_parameter('disentangled.visualize.show.output.filename',
                           filename)
        for binding_name, value in optional_bindings:
            if value is not None:
                gin.bind_parameter(binding_name, value)

    num_values_per_factor = disentangled.dataset.get(
        dataset_name).num_values_per_factor
    supervised = disentangled.dataset.get(dataset_name).supervised()

    fixed_batches, _ = disentangled.metric.utils.fixed_factor_dataset(
        supervised, batch_size, num_values_per_factor)

    # rows/cols are marked gin.REQUIRED so that they must come from the gin
    # configuration.
    disentangled.visualize.fixed_factor_data(
        fixed_batches,
        rows=gin.REQUIRED,
        cols=gin.REQUIRED,
        verbose=verbose,
    )
示例#26
0
def parse_gin(restore_dir):
    """Parse gin config from defaults, the model directory, and flags.

    Sources are parsed in this order, so later ones override earlier ones:
    the optimization and eval defaults, the latest operative config found for
    ``restore_dir`` (only if that file exists), and finally ``--gin_file``
    together with ``--gin_param``.

    Args:
        restore_dir: directory passed to
            ``train_util.get_latest_operative_config``.
    """
    # Let gin resolve config files from the built-in and user folders.
    search_paths = [GIN_PATH] + FLAGS.gin_search_path
    for path in search_paths:
        gin.add_config_file_search_path(path)

    # A set --tpu flag selects the TPU optimization defaults.
    optimization_file = 'base_tpu.gin' if FLAGS.tpu else 'base.gin'

    with gin.unlock_config():
        # Defaults first.
        gin.parse_config_file(os.path.join('optimization', optimization_file))
        gin.parse_config_file('eval/basic.gin')

        # Then the operative config, when one is present on disk.
        latest_config = train_util.get_latest_operative_config(restore_dir)
        if tf.io.gfile.exists(latest_config):
            logging.info('Using operative config: %s', latest_config)
            local_config = cloud.make_file_paths_local(latest_config, GIN_PATH)
            gin.parse_config_file(local_config, skip_unknown=True)

        # Finally the user's gin files and individual bindings from flags.
        user_gin_file = cloud.make_file_paths_local(FLAGS.gin_file, GIN_PATH)
        gin.parse_config_files_and_bindings(
            user_gin_file, FLAGS.gin_param, skip_unknown=True)
  def eval(self, mixture_or_task_name, checkpoint_steps=None, summary_dir=None,
           split="validation"):
    """Runs evaluation of the model on a Mixture or Task.

    Args:
      mixture_or_task_name: str, name of the Mixture or Task to evaluate on.
        It must already be registered in the global `TaskRegistry` or
        `MixtureRegistry`.
      checkpoint_steps: int, list of ints, or None. With an int or a list of
        ints, evaluation runs on the checkpoint files in `model_dir` whose
        global steps are closest to the given ones. With None, evaluation runs
        continuously, waiting for new checkpoints. With -1, the latest
        checkpoint in the model directory is used.
      summary_dir: str, where TensorBoard event summaries for eval are
        written. If None, model_dir/eval_{split} is used.
      split: str, the mixture/task split to evaluate on.
    """
    if checkpoint_steps == -1:
      # -1 is the sentinel for "latest checkpoint in the model directory".
      checkpoint_steps = _get_latest_checkpoint_from_dir(self._model_dir)

    mesh = t5.models.mesh_transformer
    vocabulary = mesh.get_vocabulary(mixture_or_task_name)
    dataset_fn = functools.partial(
        mesh.mesh_eval_dataset_fn, mixture_or_task_name=mixture_or_task_name)

    # Parse the operative config from the model directory before the
    # estimator is built.
    operative_config = _operative_config_path(self._model_dir)
    with gin.unlock_config():
      gin.parse_config_file(operative_config)

    estimator = self.estimator(vocabulary)
    utils.eval_model(
        estimator,
        vocabulary,
        self._sequence_length,
        self.batch_size,
        split,
        self._model_dir,
        dataset_fn,
        summary_dir,
        checkpoint_steps,
    )
def visual(ctx, rows, cols, plot, filename, **kwargs):
    """Qualitative evaluation of output.

    Args:
        ctx: context object; ``ctx.obj["dataset"]`` and ``ctx.obj["model"]``
            are read from it.
        rows: bound to ``reconstructed.rows`` only when not ``None``.
        cols: bound to ``reconstructed.cols`` only when not ``None``.
        plot: value always bound to the ``show.output.show_plot`` gin
            parameter.
        filename: bound to ``show.output.filename`` only when not ``None``.
        **kwargs: accepted and ignored.
    """
    parse(ctx, set_seed=True)

    # These override the parsed configuration only when a value was given.
    optional_overrides = (
        ("disentangled.visualize.show.output.filename", filename),
        ("disentangled.visualize.reconstructed.rows", rows),
        ("disentangled.visualize.reconstructed.cols", cols),
    )

    with gin.unlock_config():
        gin.bind_parameter("disentangled.visualize.show.output.show_plot",
                           plot)
        for parameter, value in optional_overrides:
            if value is not None:
                gin.bind_parameter(parameter, value)

    pipeline = ctx.obj["dataset"].pipeline()
    # rows/cols are marked gin.REQUIRED so gin supplies them from its bindings.
    disentangled.visualize.reconstructed(
        ctx.obj["model"],
        pipeline,
        rows=gin.REQUIRED,
        cols=gin.REQUIRED,
    )
  def export(self, export_dir=None, checkpoint_step=-1, beam_size=1,
             temperature=1.0, vocabulary=None):
    """Exports a TensorFlow SavedModel.

    Args:
      export_dir: str, directory in which to export SavedModels. Falls back to
        `model_dir` when unspecified.
      checkpoint_step: int, checkpoint to export. If -1 (default), the latest
        checkpoint in the model directory is used.
      beam_size: int, a number >= 1 giving the number of beams to use for beam
        search.
      temperature: float, a value between 0 and 1 (must be 0 if beam_size > 1).
        0.0 means argmax, 1.0 means sample according to the predicted
        distribution.
      vocabulary: vocabularies.Vocabulary object to use for tokenization, or
        None to use the vocabulary from `t5.data.get_default_vocabulary()`.

    Returns:
      The string path to the exported directory.
    """
    if checkpoint_step == -1:
      checkpoint_step = _get_latest_checkpoint_from_dir(self._model_dir)

    # Parse the operative config, then override the decoding parameters with
    # the values requested for this export.
    with gin.unlock_config():
      gin.parse_config_file(_operative_config_path(self._model_dir))
      gin.bind_parameter("Bitransformer.decode.beam_size", beam_size)
      gin.bind_parameter("Bitransformer.decode.temperature", temperature)

    if vocabulary is None:
      vocabulary = t5.data.get_default_vocabulary()

    target_dir = export_dir if export_dir else self._model_dir
    estimator = self.estimator(vocabulary, disable_tpu=True)
    checkpoint_path = os.path.join(
        self._model_dir, "model.ckpt-" + str(checkpoint_step))

    return utils.export_model(
        estimator,
        target_dir,
        vocabulary,
        self._sequence_length,
        batch_size=self.batch_size,
        checkpoint_path=checkpoint_path,
    )
示例#30
0
def load_operative_gin_configurations(operative_config_dir):
    """Parse the operative Gin config found for a directory, then finalize.

    Args:
        operative_config_dir: directory handed to ``operative_config_path``
            to locate the operative config file.
    """
    config_file = operative_config_path(operative_config_dir)

    with gin.unlock_config():
        gin.parse_config_file(config_file)

    gin.finalize()
    logging.info('Operative Gin configurations loaded from %s.', config_file)