Example 1
    def __init__(
        self,
        model,
        output_dir,
        trainer_class=trax.Trainer,
        loss_fn=trax.loss,
        optimizer=trax_opt.Adafactor,
        inputs=trax_inputs.inputs,
        action_multipliers=None,
        observation_metrics=(
            ("train", "metrics/accuracy"),
            ("train", "metrics/loss"),
            ("eval", "metrics/accuracy"),
            ("eval", "metrics/loss"),
        ),
        include_lr_in_observation=False,
        reward_metric=("eval", "metrics/accuracy"),
        train_steps=100,
        eval_steps=10,
        env_steps=100,
        start_lr=0.001,
        max_lr=10.0,
        observation_range=(0.0, 5.0),
        # Don't save checkpoints by default, as they tend to use a lot of
        # space.
        should_save_checkpoints=False):
        if action_multipliers is None:
            action_multipliers = self.DEFAULT_ACTION_MULTIPLIERS
        self._model = model
        self._trainer = trainer_class(
            model=model,
            loss_fn=loss_fn,
            optimizer=optimizer,
            lr_schedule=(lambda history: lambda step: self._current_lr),
            inputs=inputs,
            should_save=should_save_checkpoints,
        )
        self._action_multipliers = action_multipliers
        self._observation_metrics = observation_metrics
        self._include_lr_in_observation = include_lr_in_observation
        self._reward_metric = reward_metric
        self._train_steps = train_steps
        self._eval_steps = eval_steps
        self._env_steps = env_steps
        self._start_lr = start_lr
        self._max_lr = max_lr

        self._output_dir = output_dir
        gfile.makedirs(self._output_dir)
        # Action is an index in self._action_multipliers.
        self.action_space = gym.spaces.Discrete(len(self._action_multipliers))
        # Observation is a vector with the values of the metrics specified in
        # observation_metrics plus optionally the learning rate.
        observation_dim = (len(self._observation_metrics) +
                           int(self._include_lr_in_observation))
        self._observation_range = observation_range
        (low, high) = self._observation_range
        self.observation_space = gym.spaces.Box(low=low,
                                                high=high,
                                                shape=(observation_dim, ))
Example 2
 def mkdir(self, path: str) -> bool:
     try:
         gfile.makedirs(path)
         return True
     except Exception as e:  # pylint: disable=broad-except
         logging.error('Error during create %s', e)
     return False
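For symmetry, a matching removal helper could be written the same way; the sketch below is an assumption (not part of the source) and presumes gfile is tf.io.gfile, whose rmtree deletes a directory tree recursively.

 def rmdir(self, path: str) -> bool:
     try:
         # Recursively delete the directory; mirrors mkdir() above.
         gfile.rmtree(path)
         return True
     except Exception as e:  # pylint: disable=broad-except
         logging.error('Error during remove %s', e)
     return False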
Example 3
def main(args):
    # The config below is the config for the area dataset used in the
    # robustness study paper. The configs for the rotation and location datasets
    # are included above.

    areas = list(np.arange(0.1, 0.9, 0.1))
    areas = [round(x, 2) for x in areas]

    config = {
        'coord': [(0.5, 0.5)],
        'area': areas,
        'rotation': [0],
        'bg_resolution': [(512, 512)],
    }

    new_dataset_dir = path.join(args.new_dataset_parent_dir, args.dataset_name,
                                '')
    if not gfile.exists(new_dataset_dir):
        gfile.makedirs(new_dataset_dir)

    dataset = synthetic.Dataset(config=config,
                                foregrounds_dir='foreground_samples/',
                                backgrounds_dir='background_samples/',
                                new_dataset_dir=new_dataset_dir,
                                num_bgs_per_fg_instance=2,
                                min_pct_inside_image=0.95)

    dataset.generate_dataset()
Example 4
 def __init__(self, path, compress=True):  # pylint: disable=redefined-outer-name
     self.path = path
     self.compress = compress
     if gfile.exists(self.path):
         gfile.Remove(self.path)
     elif not gfile.exists(os.path.dirname(self.path)):
         gfile.makedirs(os.path.dirname(self.path))
Example 5
def main(_):
    master = jax.host_id() == 0
    # make sure TF does not allocate gpu memory
    tf.config.experimental.set_visible_devices([], 'GPU')

    # The pool is used to perform misc operations such as logging in async way.
    pool = multiprocessing.pool.ThreadPool()

    # load configs from a config json string
    hparams = FLAGS.config
    logging.info('=========== Hyperparameters ============')
    logging.info(hparams)

    if hparams.get('debug'):
        logging.warning('DEBUG MODE IS ENABLED!')

    # set tensorflow random seed
    tf.random.set_seed(jax.host_id() + hparams.rng_seed)
    experiment_dir = FLAGS.experiment_dir
    logging.info('Experiment directory: %s', experiment_dir)
    summary_writer = None

    if master and hparams.write_summary:
        tensorboard_dir = os.path.join(experiment_dir, 'tb_summaries')
        gfile.makedirs(tensorboard_dir)
        summary_writer = tensorboard.SummaryWriter(tensorboard_dir)

    run(hparams, experiment_dir, summary_writer)

    pool.close()
    pool.join()
Example 6
def train(dataset,
          agent,
          ckpt_dir,
          optimizer_type,
          learning_rate,
          batch_size,
          num_epochs,
          loss_fn,
          l2_weight,
          summary_dir=None,
          epochs_to_eval=(),
          shuffle_seed=None):
    """Train agent on dataset."""
    # print(agent.network.summary())
    if optimizer_type == 'adam':
        optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    else:
        optimizer = tf.keras.optimizers.RMSprop(learning_rate=learning_rate)

    ckpt_path = None
    if ckpt_dir is not None:
        ckpt_path = os.path.join(ckpt_dir, 'ckpt')
        if not gfile.exists(os.path.dirname(ckpt_path)):
            gfile.makedirs(os.path.dirname(ckpt_path))
        # agent.save(ckpt_path + '_init')

    best_epoch = (train_on_dataset(dataset, agent, optimizer, batch_size,
                                   num_epochs, loss_fn, l2_weight, ckpt_path,
                                   summary_dir, epochs_to_eval, shuffle_seed))

    # Restore to best epoch.
    if ckpt_dir is not None:
        agent.load(ckpt_path)
    return best_epoch
Example 7
    def __init__(
        self,
        train_env,
        eval_env,
        output_dir,
        trajectory_dump_dir=None,
        trajectory_dump_min_count_per_shard=16,
    ):
        """Base class constructor.

    Args:
      train_env: EnvProblem to use for training. Settable.
      eval_env: EnvProblem to use for evaluation. Settable.
      output_dir: Directory to save checkpoints and metrics to.
      trajectory_dump_dir: Directory to dump trajectories to. Trajectories
        are saved in shards of name <epoch>.pkl under this directory. Settable.
      trajectory_dump_min_count_per_shard: Minimum number of trajectories to
        collect before dumping in a new shard. Sharding is for efficient
        shuffling for model training in SimPLe.
    """
        self.train_env = train_env
        self.eval_env = eval_env
        self._output_dir = output_dir
        gfile.makedirs(self._output_dir)
        self.trajectory_dump_dir = trajectory_dump_dir
        self._trajectory_dump_min_count_per_shard = (
            trajectory_dump_min_count_per_shard)
        self._trajectory_buffer = []
Example 8
    def test_preview_dataset_and_feature_metrics(self):
        # write data
        gfile.makedirs(self.default_dataset2.path)
        meta_path = os.path.join(self.default_dataset2.path, '_META')
        meta_data = {
            'dtypes': {
                'f01': 'bigint'
            },
            'samples': [
                [1],
                [0],
            ],
        }
        with gfile.GFile(meta_path, 'w') as f:
            f.write(json.dumps(meta_data))

        features_path = os.path.join(self.default_dataset2.path, '_FEATURES')
        features_data = {
            'f01': {
                'count': '2',
                'mean': '0.0015716767309123998',
                'stddev': '0.03961485047808605',
                'min': '0',
                'max': '1',
                'missing_count': '0'
            }
        }
        with gfile.GFile(features_path, 'w') as f:
            f.write(json.dumps(features_data))

        hist_path = os.path.join(self.default_dataset2.path, '_HIST')
        hist_data = {
            "f01": {
                "x": [
                    0.0, 0.1, 0.2, 0.30000000000000004, 0.4, 0.5,
                    0.6000000000000001, 0.7000000000000001, 0.8, 0.9, 1
                ],
                "y": [12070, 0, 0, 0, 0, 0, 0, 0, 0, 19]
            }
        }
        with gfile.GFile(hist_path, 'w') as f:
            f.write(json.dumps(hist_data))

        response = self.client.get('/api/v2/datasets/2/preview')
        self.assertEqual(response.status_code, 200)
        preview_data = self.get_response_data(response)
        meta_data['metrics'] = features_data
        self.assertEqual(preview_data, meta_data, 'should have preview data')

        feat_name = 'f01'
        feature_response = self.client.get(
            f'/api/v2/datasets/2/feature_metrics?name={feat_name}')
        self.assertEqual(feature_response.status_code, 200)
        feature_data = self.get_response_data(feature_response)
        self.assertEqual(
            feature_data, {
                'name': feat_name,
                'metrics': features_data.get(feat_name, {}),
                'hist': hist_data.get(feat_name, {})
            }, 'should have feature data')
Example 9
def main(argv):
    """Main function."""
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Make sure tf does not allocate gpu memory.
    tf.config.experimental.set_visible_devices([], 'GPU')

    # TODO(mohitreddy): Change to flags.mark_flag_as_required('model_dir').
    assert FLAGS.model_dir is not None, 'Please provide model_dir.'
    if not gfile.exists(FLAGS.model_dir):
        gfile.makedirs(FLAGS.model_dir)

    train_and_evaluate(seed=FLAGS.seed,
                       model_dir=FLAGS.model_dir,
                       num_epochs=FLAGS.num_epochs,
                       batch_size=FLAGS.batch_size,
                       embedding_size=FLAGS.embedding_size,
                       hidden_size=FLAGS.hidden_size,
                       min_freq=FLAGS.min_freq,
                       max_seq_len=FLAGS.max_seq_len,
                       dropout=FLAGS.dropout,
                       emb_dropout=FLAGS.emb_dropout,
                       word_dropout_rate=FLAGS.word_dropout_rate,
                       learning_rate=FLAGS.learning_rate,
                       checkpoints_to_keep=FLAGS.checkpoints_to_keep,
                       l2_reg=FLAGS.l2_reg)
Example 10
def main(_):
    with gfile.GFile(
            os.path.join(FLAGS.expert_policy_dir, f'{FLAGS.task}.pickle'),
            'rb') as f:
        agent = pickle.load(f)
    env = gym.make(f'visual-{FLAGS.task}-v0')
    env.seed(FLAGS.seed)
    im_size = FLAGS.image_size
    if im_size is not None:
        env.env.im_size = im_size

    if FLAGS.logdir is None:
        log_path = None
    else:
        logdir = os.path.join(FLAGS.logdir, f'{FLAGS.task}')
        run_id = '' if FLAGS.run_id is None else '_' + FLAGS.run_id
        if FLAGS.record_failed:
            run_id += '_all'
        if im_size is not None and im_size != adroit_ext.camera_kwargs[
                'im_size']:
            run_id += f'_{im_size}px'
        increment_str = 'i' if FLAGS.increment_seed else ''
        log_path = os.path.join(
            logdir,
            f's{FLAGS.seed}{increment_str}_e{FLAGS.num_episodes}{run_id}')
        gfile.makedirs(os.path.dirname(log_path))
        print('Writing to', log_path)
    env_loop(env, agent, FLAGS.num_episodes, log_path, FLAGS.record_failed,
             FLAGS.seed, FLAGS.increment_seed)
Example 11
def save_file(path, **data):
    """Saves numpy data arrays."""

    gfile.makedirs(path)
    for arr_name, arr in data.items():
        with gfile.GFile(os.path.join(path, arr_name + '.npy'), 'wb') as f:
            np.save(f, arr)
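A complementary loader is not shown in the source; a minimal sketch, assuming the same gfile/numpy imports and the .npy layout written by save_file, could look like this:

def load_file(path, *arr_names):
    """Loads the named arrays previously written by save_file (hypothetical helper)."""
    arrays = {}
    for arr_name in arr_names:
        with gfile.GFile(os.path.join(path, arr_name + '.npy'), 'rb') as f:
            arrays[arr_name] = np.load(f)
    return arrays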
Example 12
def make_randoms(source_dir, number_of_images_per_folder,
                 number_of_random_folders):

    logging.basicConfig(filename=source_dir + '/logger.log',
                        level=logging.INFO)

    # Run script to download data to source_dir
    if not gfile.exists(source_dir):
        gfile.makedirs(source_dir)
    if not gfile.exists(os.path.join(
            source_dir, 'broden1_224/')) or not gfile.exists(
                os.path.join(source_dir, 'inception5h')):
        subprocess.call(['bash', 'FetchDataAndModels.sh', source_dir])

    # make targets from imagenet
    imagenet_dataframe = fetcher.make_imagenet_dataframe(
        "/home/tomohiro/code/tcav/tcav/tcav_examples/image_models/imagenet/imagenet_url_map.csv"
    )

    # Make random folders. If we want to run N random experiments with tcav, we need N+1 folders.
    fetcher.generate_random_folders(
        working_directory=source_dir,
        random_folder_prefix="random50",
        number_of_random_folders=number_of_random_folders + 1,
        number_of_examples_per_folder=number_of_images_per_folder,
        imagenet_dataframe=imagenet_dataframe)
Example 13
def main(_):
    config = flags.FLAGS

    gfile.makedirs(config.checkpoint_dir)
    if config.mode == "train":
        train(config)
    elif config.mode == "evaluate_pair":
        while True:
            checkpoint_path = utils.maybe_pick_models_to_evaluate(
                checkpoint_dir=config.checkpoint_dir)
            if checkpoint_path:
                evaluate_pair(
                    config=config,
                    batch_size=config.batch_size,
                    checkpoint_path=checkpoint_path,
                    data_dir=config.data_dir,
                    dataset=config.dataset,
                    num_examples_for_eval=config.num_examples_for_eval)
            else:
                logging.info(
                    "No models to evaluate found, sleeping for %d seconds",
                    EVALUATOR_SLEEP_PERIOD)
                time.sleep(EVALUATOR_SLEEP_PERIOD)
    else:
        raise Exception(
            "Unexpected mode %s, supported modes are \"train\" or \"evaluate_pair\""
            % (config.mode))
Example 14
 def append_to_pickle(self, path):
     if not gfile.exists(os.path.dirname(path)):
         gfile.makedirs(os.path.dirname(path))
     with gfile.GFile(path, 'ab') as f:
         pickle.dump(
             (self.actions, self.base_actions, self.residual_actions), f)
     self._reset()
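Reading a shard back would mean looping pickle.load until EOF; the helper below is a hedged sketch, not part of the source:

 def read_pickle(self, path):
     """Yields every (actions, base_actions, residual_actions) tuple in the file."""
     with gfile.GFile(path, 'rb') as f:
         while True:
             try:
                 yield pickle.load(f)
             except EOFError:
                 break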
Example 15
    def __init__(
        self,
        train_env,
        eval_env,
        output_dir,
        trajectory_dump_dir=None,
        trajectory_dump_min_count_per_shard=16,
        async_mode=False,
    ):
        """Base class constructor.

    Args:
      train_env: EnvProblem to use for training. Settable.
      eval_env: EnvProblem to use for evaluation. Settable.
      output_dir: Directory to save checkpoints and metrics to.
      trajectory_dump_dir: Directory to dump trajectories to. Trajectories
        are saved in shards of name <epoch>.pkl under this directory. Settable.
      trajectory_dump_min_count_per_shard: Minimum number of trajectories to
        collect before dumping in a new shard. Sharding is for efficient
        shuffling for model training in SimPLe.
      async_mode: (bool) If True, this means we are in async mode and we read
        trajectories from a location rather than interact with the environment.
    """
        self.train_env = train_env
        self.eval_env = eval_env
        self._output_dir = output_dir
        gfile.makedirs(self._output_dir)
        self.trajectory_dump_dir = trajectory_dump_dir
        self._trajectory_dump_min_count_per_shard = (
            trajectory_dump_min_count_per_shard)
        self._trajectory_buffer = []
        self._async_mode = async_mode
Example 16
def main(unused_argv):
  del unused_argv  # Unused
  corpus = get_lm_corpus(FLAGS.data_dir)

  save_dir = os.path.join(FLAGS.data_dir, "tfrecords")
  if not exists(save_dir):
    makedirs(save_dir)

  # test mode
  if FLAGS.per_host_test_bsz > 0:
    corpus.convert_to_tfrecords("test", save_dir, FLAGS.per_host_test_bsz,
                                FLAGS.tgt_len, FLAGS.num_core_per_host,
                                FLAGS=FLAGS)
    return

  for split, batch_size in zip(
      ["train", "valid"],
      [FLAGS.per_host_train_bsz, FLAGS.per_host_valid_bsz]):

    if batch_size <= 0:
      continue

    print("Converting {} set...".format(split))
    corpus.convert_to_tfrecords(split, save_dir, batch_size, FLAGS.tgt_len,
                                FLAGS.num_core_per_host, FLAGS=FLAGS)
Example 17
    def make_concept_folder(dataframe, concept):
        # Create the folder and save the dataframe as a csv file there
        path = os.path.join(source_dir, concept)
        if not gfile.exists(path):
            gfile.makedirs(path)

        concept_file_name = os.path.join(path, concept + ".csv")
        dataframe.to_csv(concept_file_name, index=False)
Example 18
    def begin(self):
        self._global_step_tensor = tf.train.get_global_step()
        if self._global_step_tensor is None:
            raise RuntimeError(
                'Global step should be created to use PlottingHook.')

        if not gfile.exists(self._logdir):
            gfile.makedirs(self._logdir)
Example 19
def make_concepts_targets_and_randoms(source_dir, number_of_images_per_folder,
                                      number_of_random_folders):

    logging.basicConfig(filename=source_dir + '/logger.log',
                        level=logging.INFO)

    # Run script to download data to source_dir
    if not gfile.exists(source_dir):
        gfile.makedirs(source_dir)
    if not gfile.exists(os.path.join(
            source_dir, 'broden1_224/')) or not gfile.exists(
                os.path.join(source_dir, 'inception5h')):
        subprocess.call(['bash', 'FetchDataAndModels.sh', source_dir])

    # make targets from imagenet
    imagenet_dataframe = fetcher.make_imagenet_dataframe(
        "/home/tomohiro/code/tcav/tcav/tcav_examples/image_models/imagenet/imagenet_url_map.csv"
    )
    all_class = imagenet_dataframe["class_name"].values.tolist()

    # Determine classes that we will fetch
    imagenet_classes = ['fire engine']
    broden_concepts = ['striped', 'dotted', 'zigzagged']
    random_except_concepts = ['zebra', 'fire engine']
    except_words = [
        'cat', 'shark', 'apron', 'dogsled', 'dumbbell', 'ball', 'bus'
    ]
    for e_word in except_words:
        random_except_concepts.extend([
            element for element in all_class
            if e_word == str(element)[-len(e_word):]
        ])

    tf.logging.info('imagenet_classes %s' % imagenet_classes)
    tf.logging.info('concepts %s' % broden_concepts)
    tf.logging.info('random_except_concepts %s' % random_except_concepts)

    for image in imagenet_classes:
        fetcher.fetch_imagenet_class(source_dir, image,
                                     number_of_images_per_folder,
                                     imagenet_dataframe)
    # Make concepts from broden
    for concept in broden_concepts:
        fetcher.download_texture_to_working_folder(
            broden_path=os.path.join(source_dir, 'broden1_224'),
            saving_path=source_dir,
            texture_name=concept,
            number_of_images=number_of_images_per_folder)

    # Make random folders. If we want to run N random experiments with tcav, we need N+1 folders.
    # (Changed) Specify the classes to exclude from the random folders.
    fetcher.generate_random_folders(
        working_directory=source_dir,
        random_folder_prefix="random500",
        number_of_random_folders=number_of_random_folders + 1,
        number_of_examples_per_folder=number_of_images_per_folder,
        imagenet_dataframe=imagenet_dataframe,
        random_except_concepts=random_except_concepts)
Example 20
    def _make_root_class_dirs(self):
        """Make dataset root dir and subdir for each class."""
        if not gfile.exists(self.new_dataset_dir):
            gfile.makedirs(self.new_dataset_dir)

        for class_name in self.fg_classes:
            class_dir = path.join(self.new_dataset_dir, class_name, '')
            if not gfile.exists(class_dir):
                gfile.mkdir(class_dir)
Example 21
def create_fake_data(root_dir, data):
    fake_examples_dir = os.path.join(root_dir, 'testing', 'test_data',
                                     'fake_examples', '{dataset_name}')
    fake_examples_dir = fake_examples_dir.format(**data)
    gfile.makedirs(fake_examples_dir)

    fake_path = os.path.join(fake_examples_dir,
                             'TODO-add_fake_data_in_this_directory.txt')
    with gfile.GFile(fake_path, 'w') as f:
        f.write('{TODO}: Add fake data in this directory'.format(**data))
Example 22
    def __init__(self, log_dir):
        """Create a new SummaryWriter.

    Args:
      log_dir: path to record tfevents files in.
    """
        # If needed, create log_dir directory as well as missing parent directories.
        if not gfile.isdir(log_dir):
            gfile.makedirs(log_dir)

        self._event_writer = EventFileWriter(log_dir, 10, 120, None)
        self._step = 0
        self._closed = False
Example 23
def save_config_file(config_file, dest_dir):
    if not gfile.exists(dest_dir):
        gfile.makedirs(dest_dir)

    config_file_dest = os.path.join(dest_dir, 'blueoil_config.yaml')

    # HACK: This is for tensorflow bug workaround.
    # We can remove following 2 lines once it's been resolved in tensorflow
    # issue link: https://github.com/tensorflow/tensorflow/issues/28508
    if gfile.exists(config_file_dest):
        gfile.remove(config_file_dest)

    return gfile.copy(config_file, config_file_dest)
Example 24
    def _training(self, sess, step, fetches, profiler, collective_graph_key):

        should_profile = profiler and 0 <= step < 20
        need_options_and_metadata = (should_profile or collective_graph_key > 0
                                     or (self.trace_filename and step == 0))
        if need_options_and_metadata:
            run_options = tf.RunOptions()
            if (self.trace_filename and step == 0) or should_profile:
                run_options.trace_level = tf.RunOptions.FULL_TRACE
            if collective_graph_key > 0:
                run_options.experimental.collective_graph_key = collective_graph_key
            run_metadata = tf.RunMetadata()
        else:
            run_options = None
            run_metadata = None

        batch_start_time = time.time()
        results = sess.run(fetches,
                           options=run_options,
                           run_metadata=run_metadata)
        seconds_per_batch = time.time() - batch_start_time
        examples_per_second = self.batch_size / seconds_per_batch
        step = results['global_step']

        to_print = step % self.params.frequency_log_steps == 0
        if (self.is_master and to_print) or step == 1:
            epoch = ((step * self.batch_size) / self.reader.n_train_files)
            self.message.add("epoch", epoch, format="4.2f")
            self.message.add("step", step, width=5, format=".0f")
            self.message.add("lr", results['learning_rate'], format=".6f")
            self.message.add("loss", results['loss'], format=".4f")
            self.message.add("imgs/sec",
                             examples_per_second,
                             width=5,
                             format=".0f")
            logging.info(self.message.get_message())

        if need_options_and_metadata:
            if should_profile:
                profiler.add_step(step, run_metadata)
            if self.trace_filename and step == -2:
                logging.info('Dumping trace to {}'.format(self.trace_filename))
                trace_dir = os.path.dirname(self.trace_filename)
                if not gfile.exists(trace_dir):
                    gfile.makedirs(trace_dir)
                with gfile.GFile(self.trace_filename, 'w') as trace_file:
                    trace = timeline.Timeline(
                        step_stats=run_metadata.step_stats)
                    trace_file.write(
                        trace.generate_chrome_trace_format(show_memory=True))
        return step
Example 25
def save_yaml(output_dir, config):
    """Save two yaml files.

    1. 'config.yaml' is duplication of python config file as yaml.
    2. 'meta.yaml' for application. The yaml's keys defined by `PARAMS_FOR_EXPORT`.
    """

    if not gfile.exists(output_dir):
        gfile.makedirs(output_dir)

    config_yaml_path = _save_config_yaml(output_dir, config)
    meta_yaml_path = _save_meta_yaml(output_dir, config)

    return config_yaml_path, meta_yaml_path
Example 26
def save_weights(model, path, overwrite=True):
  """Customized version of Keras model.save_weights().

  - Works with both local and remote storage.
  - Creates intermediate directories if missing.
  """
  tmp_dir = tempfile.mkdtemp()
  dirname, basename = os.path.split(path)
  tmp_path = os.path.join(tmp_dir, basename)

  model.save_weights(tmp_path)
  gfile.makedirs(dirname)
  gfile.copy(tmp_path, path, overwrite=overwrite)
  gfile.remove(tmp_path)
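A matching load path (an assumed counterpart, not in the source) would stage the checkpoint in local temporary storage first, since Keras cannot always read remote paths directly; like save_weights above, it presumes a single-file checkpoint such as HDF5:

def load_weights(model, path):
  """Counterpart of save_weights(): restore weights from local or remote storage."""
  tmp_dir = tempfile.mkdtemp()
  tmp_path = os.path.join(tmp_dir, os.path.basename(path))
  # Copy the checkpoint locally, then let Keras read it.
  gfile.copy(path, tmp_path, overwrite=True)
  model.load_weights(tmp_path)
  gfile.remove(tmp_path)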
Example 27
    def __init__(self, log_dir):
        """Create a new SummaryWriter.

    Args:
      log_dir: path to record tfevents files in.
    """
        # If needed, create log_dir directory as well as missing parent directories.
        if not gfile.isdir(log_dir):
            gfile.makedirs(log_dir)

        self.writer = tf.summary.FileWriter(log_dir, graph=None)
        self.end_summaries = []
        self.step = 0
        self.closed = False
Example 28
    def reset(self, output_dir):
        """Reset the model parameters.

    Restores the parameters from the given output_dir if a checkpoint exists,
    otherwise randomly initializes them.

    Does not re-jit the model.

    Args:
      output_dir: Output directory.
    """
        self._output_dir = output_dir
        gfile.makedirs(output_dir)
        # Create summary writers and history.
        if self._should_write_summaries:
            self._train_sw = jaxboard.SummaryWriter(os.path.join(
                output_dir, 'train'),
                                                    enable=self.is_chief)
            self._eval_sw = jaxboard.SummaryWriter(os.path.join(
                output_dir, 'eval'),
                                                   enable=self.is_chief)

        # Reset the train and eval streams.
        self._train_stream = self._inputs.train_stream()
        # TODO(lukaszkaiser): add an option to evaluate exactly on the full eval
        #   set by adding a padding and stopping the stream when too large.
        self._eval_stream = _repeat_stream(self._inputs.eval_stream)
        self._train_eval_stream = _repeat_stream(
            self._inputs.train_eval_stream)

        # Restore the training state.
        state = load_trainer_state(output_dir)
        self._step = state.step or 0
        history = state.history
        self._lr_fn = self._lr_schedule(history)
        self._history = history
        if state.opt_state:
            opt_state = state.opt_state
            model_state = state.model_state
        else:
            opt_state, model_state = self._new_opt_state_and_model_state()
            model_state = layers.nested_map(self._maybe_replicate, model_state)
        self._opt_state = OptState(
            *layers.nested_map(self._maybe_replicate, opt_state))
        self._model_state = model_state
        if not state.opt_state and self.is_chief:
            self._maybe_save_state(keep=False)

        self.update_nontrainable_params()
Example 29
    def __init__(self, remote_dir: str, local_dir: str):
        self._remote_dir = remote_dir
        self._local_dir = local_dir
        self._mu = threading.Lock()
        self._cond = threading.Condition(lock=self._mu)
        self._stopping = False
        self._epoch = 0
        gfile.makedirs(local_dir)

        remote_ents = _list_dir(remote_dir)
        for name, ent in remote_ents.items():
            _copy_file(remote_dir, local_dir, name)

        self._thread = threading.Thread(target=self._loop)
        self._thread.start()
Example 30
def dump_object(object_to_dump, output_path):
    """Pickle the object and save to the output_path.

    Args:
      object_to_dump: Python object to be pickled
      output_path: (string) output path which can be Google Cloud Storage

    Returns:
      None
    """
    path = f"gs://{output_path}"
    if not gfile.exists(path):
        gfile.makedirs(os.path.dirname(path))
    with gfile.GFile(path, "w") as wf:
        joblib.dump(object_to_dump, wf)
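The inverse operation is not shown; below is a hedged sketch of a loader, assuming the same gfile/joblib imports and the gs:// prefix convention used by dump_object:

def load_object(input_path):
    """Load a pickled object written by dump_object.

    Args:
      input_path: (string) Cloud Storage path without the "gs://" prefix.

    Returns:
      The unpickled Python object.
    """
    path = f"gs://{input_path}"
    with gfile.GFile(path, "rb") as rf:
        return joblib.load(rf)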