Exemplo n.º 1
0
def save_samples_to_json(features: List[Dict[str, Any]],
                         config: ml_collections.ConfigDict, step: int):
    """Append samples to a per-step, per-process JSON-lines file.

    Nothing is written unless both gates pass:
      * step gate: `save_samples_every_steps` is set (truthy) and divides
        `step`;
      * process gate: `save_samples_process_ids` (default 0) is a list that
        contains this process index, is -1 (every process), or equals this
        process index.

    Args:
      features: items to serialize; each one goes through `json.dump` and is
        followed by a newline.
      config: experiment config; `model_dir` is only read when writing.
      step: current step, used for the gate and for the file name.
    """
    every_n_steps = config.get('save_samples_every_steps')
    step_matches = every_n_steps and (step % every_n_steps == 0)

    this_process = jax.process_index()
    allowed = config.get('save_samples_process_ids', 0)
    if isinstance(allowed, list):
        process_matches = this_process in allowed
    else:
        process_matches = allowed == -1 or this_process == allowed

    if not (step_matches and process_matches):
        return

    logging.info('Saving samples at step %d, process %d', step,
                 this_process)
    file_name = 'step_%d.process_%d.json' % (step, this_process)
    path = os.path.join(config.model_dir, 'samples', file_name)
    tf.io.gfile.makedirs(os.path.dirname(path))
    # Opened in append mode: repeated calls for the same step and process
    # add lines to the existing file.
    with tf.io.gfile.GFile(path, 'ab') as fp:
        for sample in features:
            json.dump(sample, fp)
            fp.write('\n')
Exemplo n.º 2
0
def config_to_opt_args(config: ml_collections.ConfigDict):
    """Collect optimizer keyword arguments from the `opt_*` config entries.

    Entries whose config value is None (which includes absent entries) are
    left out. `weight_decay` falls back to 0 when `opt_weight_decay` is
    absent, so it is present unless explicitly set to None.

    Args:
      config: config object queried through `.get`.

    Returns:
      A dict mapping optimizer argument names to their configured values.
    """
    candidates = (
        ('eps', config.get('opt_eps')),
        ('decay', config.get('opt_decay')),
        ('momentum', config.get('opt_momentum')),
        ('beta1', config.get('opt_beta1')),
        ('beta2', config.get('opt_beta2')),
        ('weight_decay', config.get('opt_weight_decay', 0)),
    )
    return {name: value for name, value in candidates if value is not None}
Exemplo n.º 3
0
def get_optimizer(
    config: ml_collections.ConfigDict) -> tf.keras.optimizers.Optimizer:
  """Builds the tf.keras optimizer described by `config`.

  Only the Adam optimizer is supported. Defaults for optional optimizer
  parameters are taken from TensorFlow Core v2.4.1:
  https://www.tensorflow.org/api_docs/python/tf/keras/optimizers.

  The learning rate is either the fixed `config.opt.learning_rate` (default
  0.001) or, when `config.schedule.schedule == 'exponential'`, an
  `ExponentialDecay` schedule starting from that value. Unless
  `config.opt.use_model_averaging` is set to a falsy value, the optimizer is
  wrapped in `tfa.optimizers.MovingAverage`.

  Args:
    config: A ConfigDict containing a `config.opt` sub-config. `config.dataset`
      and `config.train.initial_epoch` are read as well.

  Returns:
    A tf.keras optimizer.

  Raises:
    ValueError: `config` missing the `opt` sub-config.
    ValueError: `config.opt` contains an invalid optimizer class.
    ValueError: `config.schedule` contains an invalid schedule class.
  """
  opt_config = config.get('opt', None)
  if opt_config is None:
    raise ValueError(f'Provided `config` missing `opt` sub-config: {config}')

  base_lr = opt_config.get('learning_rate', 0.001)
  # True division: this may be fractional; it is truncated where step counts
  # are derived from it below.
  steps_per_epoch = config.dataset.num_train_examples / config.dataset.batch_size

  schedule_config = config.get('schedule', {})
  schedule_name = schedule_config.get('schedule', None)
  if schedule_name is None:
    learning_rate = base_lr
    print(f'No LR schedule provided. Using a fixed LR of "{learning_rate}".')
  elif schedule_name == 'exponential':
    learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(
        base_lr,
        decay_steps=int(schedule_config.epochs_per_decay * steps_per_epoch),
        decay_rate=schedule_config.decay_rate,
        staircase=schedule_config.staircase)
  else:
    raise ValueError(f'Unknown scheduler name: {schedule_name}')

  opt_name = opt_config.get('optimizer', None)
  if opt_name != 'adam':
    raise ValueError(f'Unknown optimizer name: {opt_name}')

  adam = tf.keras.optimizers.Adam(
      learning_rate=learning_rate,
      beta_1=opt_config.get('beta_1', 0.9),
      beta_2=opt_config.get('beta_2', 0.999),
      epsilon=opt_config.get('epsilon', 1e-07),
      amsgrad=opt_config.get('amsgrad', False))
  averaging_start_step = int(steps_per_epoch * config.train.initial_epoch)
  if not opt_config.get('use_model_averaging', True):
    return adam
  return tfa.optimizers.MovingAverage(
      adam, average_decay=0.9999, start_step=averaging_start_step)
Exemplo n.º 4
0
def get_defaults():
    """Return the default configuration.

    The result holds the `jitter` value and a `transformations` sub-config.
    That sub-config defines two named entries (`positive_transform` ->
    `tfb.Softplus`, `identity_transform` -> `tfb.Identity`) and maps each
    parameter name to the name of one of those entries.
    """
    transformations = ConfigDict()
    # Named transformations that the parameter entries below refer to by key.
    transformations.positive_transform = tfb.Softplus
    transformations.identity_transform = tfb.Identity

    for parameter in ("lengthscale", "variance", "obs_noise"):
        transformations[parameter] = "positive_transform"
    for parameter in ("latent", "basis_fns"):
        transformations[parameter] = "identity_transform"

    config = ConfigDict()
    # Covariance matrix stabilising jitter
    config.jitter = 1e-6
    # Parameter transformations
    config.transformations = transformations
    return config
Exemplo n.º 5
0
def build_model_graph(
    config: ml_collections.ConfigDict,
    outcomes: List[ml_collections.ConfigDict]) -> tf.keras.Model:
  """Builds the tf.keras.Model selected by `config.model.backbone`.

  The only backbone handled here is 'inceptionv3'.

  Args:
    config: A ConfigDict containing a `config.model` sub-config.
    outcomes: A list of outcome ConfigDict instances.

  Returns:
    A tf.keras.Model.

  Raises:
    ValueError: `config` missing the `model` sub-config.
    ValueError: `config.model` contains an invalid model backbone.
  """
  model_config = config.get('model', None)
  if model_config is None:
    raise ValueError(f'Provided `config` missing `model` sub-config: {config}')

  backbone_name = model_config.get('backbone', None)
  if backbone_name != 'inceptionv3':
    raise ValueError(f'Unknown model backbone: {backbone_name}')

  return inceptionv3(model_config, outcomes)
Exemplo n.º 6
0
def main(cfg, make_env, experiment_name):
    """Create a fresh experiment directory and run SAC training in it.

    Args:
        cfg: experiment config; `save_dir`, `log_frequency`, `seed`, `device`,
            `sac`, `replay_buffer_capacity` and `save_video` are read here,
            and `cfg.sac` is filled in with sizes taken from the environment.
        make_env: zero-argument callable returning the environment.
        experiment_name: sub-directory of `cfg.save_dir` for this run.

    Raises:
        ValueError: the experiment directory already exists.
    """
    exp_dir = os.path.join(cfg.save_dir, experiment_name)
    if os.path.exists(exp_dir):
        raise ValueError("Experiment already exists.")

    # New experiment: create the directory and snapshot the config into it.
    os.makedirs(exp_dir)
    config_path = os.path.join(exp_dir, "config.yaml")
    with open(config_path, "w") as fp:
        yaml.dump(ConfigDict.to_dict(cfg), fp)

    logger = Logger(
        exp_dir, save_tb=True, log_frequency=cfg.log_frequency, agent="sac",
    )
    utils.set_seed_everywhere(cfg.seed)
    device = torch.device(cfg.device)

    env = make_env()
    # The agent config depends on the environment, so complete it here.
    cfg.sac.obs_dim = env.observation_space.shape[0]
    cfg.sac.action_dim = env.action_space.shape[0]
    lowest_action = float(env.action_space.low.min())
    highest_action = float(env.action_space.high.max())
    cfg.sac.action_range = [lowest_action, highest_action]

    agent = SACAgent(cfg.sac, device)
    replay_buffer = ReplayBuffer(
        env.observation_space.shape,
        env.action_space.shape,
        int(cfg.replay_buffer_capacity),
        device,
    )
    video_dir = exp_dir if cfg.save_video else None
    video_recorder = VideoRecorder(video_dir)
    train(env, logger, video_recorder, cfg, agent, replay_buffer)
Exemplo n.º 7
0
    def dummy_input(config: ml_collections.ConfigDict) -> Dict[Text, Any]:
        """Builds a placeholder input batch for the model. See BaseTask.

        The text length is `config.max_length_with_entity_tokens` when that is
        set, otherwise `config.model_config.encoder_config.max_length`. Text
        features have shape (per_device_batch_size, length); mention features
        are sized by `config.max_mentions`. All arrays are int32.
        """
        max_length = config.get('max_length_with_entity_tokens')
        if max_length is None:
            max_length = config.model_config.encoder_config.max_length

        batch_size = config.per_device_batch_size
        text_shape = (batch_size, max_length)
        num_mentions = config.max_mentions
        dtype = jnp.int32

        # Each row is 0..max_length-1, repeated once per batch element.
        positions = np.tile(np.arange(max_length), (batch_size, 1))

        return {
            'text_ids': jnp.ones(text_shape, dtype),
            'text_mask': jnp.ones(text_shape, dtype),
            'position_ids': jnp.asarray(positions, dtype),
            'segment_ids': jnp.zeros(text_shape, dtype),
            'classifier_target': jnp.ones(batch_size, dtype),
            'mention_start_positions': jnp.zeros(num_mentions, dtype),
            'mention_end_positions': jnp.zeros(num_mentions, dtype),
            'mention_mask': jnp.ones(num_mentions, dtype),
            'mention_batch_positions': jnp.ones(num_mentions, dtype),
        }
Exemplo n.º 8
0
 def get_default_config():
     """Return a ConfigDict populated with the default settings."""
     defaults = (
         ('online', False),
         ('prefix', 'SimpleSAC'),
         ('project', 'sac'),
         ('output_dir', '/tmp/SimpleSAC'),
         ('random_delay', 0.0),
         ('experiment_id', ''),
     )
     config = ConfigDict()
     for key, value in defaults:
         config[key] = value
     return config
Exemplo n.º 9
0
def load_config_from_dir(exp_dir):
    """Load the experiment config stored in `exp_dir`.

    Args:
        exp_dir: directory expected to contain a `config.yaml` file.

    Returns:
        A ConfigDict built from the parsed YAML content.

    Raises:
        FileNotFoundError: `config.yaml` does not exist in `exp_dir`.
    """
    config_path = os.path.join(exp_dir, "config.yaml")
    # A missing file surfaces as FileNotFoundError directly from open(); no
    # catch-and-re-raise is needed for callers to see it.
    with open(config_path, "r") as fp:
        # NOTE(review): FullLoader can construct more than plain YAML types;
        # only load config files from trusted experiment directories.
        cfg = yaml.load(fp, Loader=yaml.FullLoader)
    return ConfigDict(cfg)
Exemplo n.º 10
0
def get_config():
    """Return the model config: constructor path plus its hyperparameters."""
    hparams = ConfigDict()
    hparams.tree_method = "gpu_hist"
    hparams.booster = "gbtree"
    hparams.n_estimators = 250
    hparams.learning_rate = 1e-2
    hparams.max_depth = 5
    hparams.reg_alpha = 1.0
    hparams.reg_lambda = 1.0
    hparams.min_child_weight = 0.0
    hparams.subsample = 0.8
    hparams.colsample_bytree = 0.8
    hparams.num_parallel_tree = 1

    config = ConfigDict()
    config.constructor = "xgboost.XGBRegressor"
    config.hparams = hparams
    return config
Exemplo n.º 11
0
 def get_config(self):
   """Return a ConfigDict with small model sizes."""
   return ConfigDict({
       'hidden_size': 128,
       'ff_size': 256,
       'num_heads': 2,
       'num_encoder_layers': 2,
       'num_symbols': 8,
   })
Exemplo n.º 12
0
def load_data(idx, mode='test'):
    """Return a dataloader for the dataset configured by run `idx`.

    Args:
        idx: run identifier, joined with USER and PROJECT to look up the run.
        mode: 'test' for the test dataloader, 'train' for the train one.

    Returns:
        The requested dataloader, or None for any other `mode`.
    """
    run_path = os.path.join(USER, PROJECT, idx)
    run = api.run(run_path)
    data_module, _ = load_dataset(ConfigDict(run.config))
    data_module.prepare_data()

    if mode == 'test':
        data_module.setup('test')
        return data_module.test_dataloader()
    if mode == 'train':
        data_module.setup('fit')
        return data_module.train_dataloader()
    # Unrecognized modes fall through without a dataloader.
    return None
Exemplo n.º 13
0
def build_outcome_head(config: ml_collections.ConfigDict,
                       inputs: tf.Tensor,
                       l2: float = 0.0) -> tf.Tensor:
  """Returns an output head tensor configured for the given outcome.

  Supports regression, binary classification, and multinomial classification
  outcomes.

  Note: binary classification labels are assumed to be of shape (2,). Binary
  heads consist of a `tf.keras.layers.Dense(2)` with a softmax activation.

  Args:
    config: An outcome ConfigDict. `type` selects the head; `name` names the
      layer; `num_classes` is read for classification outcomes.
    inputs: The backbone output tensor; used as the input to the head.
    l2: The l2 regularization factor used in `tf.keras.layers.Dense` layers.
      A falsy value (the default 0.0) disables the kernel regularizer.

  Returns:
    A tensor representing the output of the given head.

  Raises:
    ValueError: `config` missing a valid `type`.
    ValueError: A classification config uses `num_classes` below 2 (e.g. a
      binary head specified with num_classes=1 rather than num_classes=2).
  """
  outcome_type = config.get('type', None)
  if outcome_type is None:
    raise ValueError(f'Provided `config` missing `type`: {config}')

  l2_regularizer = tf.keras.regularizers.L2(l2) if l2 else None

  if outcome_type == 'regression':
    # Single linear unit: no activation is applied to regression outputs.
    head = tf.keras.layers.Dense(
        1,
        dtype=tf.float32,
        name=config.name,
        kernel_regularizer=l2_regularizer)
    return head(inputs)

  if outcome_type == 'classification':
    if config.num_classes < 2:
      # The two literals are concatenated, so the separating space must be
      # part of one of them.
      raise ValueError('Binary heads should specify `config.num_classes=2`. '
                       'Binary labels are assumed to be one-hot vectors.')

    head = tf.keras.layers.Dense(
        config.num_classes,
        activation='softmax',
        dtype=tf.float32,
        name=config.name,
        kernel_regularizer=l2_regularizer)
    return head(inputs)

  raise ValueError(f'Unknown outcome type: {outcome_type}')
Exemplo n.º 14
0
def inceptionv3(model_config: ml_collections.ConfigDict,
                outcomes: List[ml_collections.ConfigDict]) -> tf.keras.Model:
  """Builds an InceptionV3-based model with one output head per outcome.

  See https://tensorflow.org/api_docs/python/tf/keras/applications/InceptionV3.

  The backbone is created without its top layers; its output goes through
  dropout and then into every outcome head. When `weight_decay` is non-zero it
  is applied to the backbone via `add_l2_regularizers` and passed to the heads
  as their l2 factor.

  Args:
    model_config: A ConfigDict containing model hyperparameters. Optional keys
      and their defaults: `input_shape` (DEFAULT_IMAGE_SHAPE), `weights`
      ('imagenet'), `pooling` ('avg'), `weight_decay` (0.0),
      `backbone_drop_rate` (0.2). `backbone` is used as the model name.
    outcomes: A list of outcome ConfigDict instances.

  Returns:
    An InceptionV3-based model.
  """
  input_shape = model_config.get('input_shape', DEFAULT_IMAGE_SHAPE)
  weight_decay = model_config.get('weight_decay', 0.0)
  drop_rate = model_config.get('backbone_drop_rate', 0.2)

  backbone = tf.keras.applications.InceptionV3(
      include_top=False,
      weights=model_config.get('weights', 'imagenet'),
      input_shape=input_shape,
      pooling=model_config.get('pooling', 'avg'))
  if weight_decay:
    backbone = add_l2_regularizers(
        backbone, tf.keras.layers.Conv2D, l2=weight_decay)

  image = tf.keras.Input(shape=input_shape, name='image')
  features = tf.keras.layers.Dropout(drop_rate)(backbone(image))

  heads = [
      build_outcome_head(outcome, features, l2=weight_decay)
      for outcome in outcomes
  ]

  model = tf.keras.Model(
      inputs=[image], outputs=heads, name=model_config.backbone)
  model.summary()
  print(f'Number of l2 regularizers: {len(model.losses)}.')
  return model
Exemplo n.º 15
0
def setup_experiment(exp_dir):
    """Initializes a training experiment.

    A new directory gets a `config.yaml` snapshot of `FLAGS.config`. For an
    existing directory, the stored `config.yaml` is loaded into `FLAGS.config`,
    which is only permitted when `FLAGS.resume` is set.

    Raises:
        ValueError: `exp_dir` exists and `FLAGS.resume` is not set.
    """
    config_path = os.path.join(exp_dir, "config.yaml")

    if not os.path.exists(exp_dir):
        # Fresh experiment: record the config it starts with.
        os.makedirs(exp_dir)
        with open(config_path, "w") as fp:
            yaml.dump(ConfigDict.to_dict(FLAGS.config), fp)
        return

    if not FLAGS.resume:
        raise ValueError(
            "Experiment already exists. Run with --resume to continue.")
    # Resuming: the stored config overrides the current flag values.
    with open(config_path, "r") as fp:
        stored_cfg = yaml.load(fp, Loader=yaml.FullLoader)
    FLAGS.config.update(stored_cfg)
def create_train_state(config: ml_collections.ConfigDict, params, model_state):
    """Create initial training state.

    Builds the optimizer from `config.opt` and the `opt_*` config entries, an
    EMA state from `config.ema_decay`, and a `flax.optim.DynamicScale` only
    when `config.half_precision` is set and the first local device is a GPU.

    Args:
      config: experiment configuration.
      params: parameters handed to `create_optim`.
      model_state: model state stored in the train state and the EMA state.

    Returns:
      A TrainState at step 0.
    """
    platform = jax.local_devices()[0].platform
    if config.half_precision and platform == 'gpu':
        dynamic_scale = flax.optim.DynamicScale()
    else:
        dynamic_scale = None

    candidates = (
        ('eps', config.get('opt_eps')),
        ('beta1', config.get('opt_beta1')),
        ('beta2', config.get('opt_beta2')),
        ('weight_decay', config.get('opt_weight_decay', 0)),
    )
    # Drop unset entries so they are not passed to `create_optim` at all.
    opt_kwargs = {
        name: value for name, value in candidates if value is not None
    }

    optimizer = create_optim(config.opt, params, **opt_kwargs)
    ema = EmaState.create(config.ema_decay, optimizer.target, model_state)

    return TrainState(step=0,
                      optimizer=optimizer,
                      model_state=model_state,
                      dynamic_scale=dynamic_scale,
                      ema=ema)
Exemplo n.º 17
0
 def test_factorized_attention(self):
     """FactorizedAttention keeps the shape of a (8, 8, 8, 256) input."""
     config = ConfigDict({
         'hidden_size': 256,
         'ff_size': 256,
         'num_encoder_layers': 2,
         'num_heads': 2,
     })
     attention = layers.FactorizedAttention(config)
     expected_shape = (8, 8, 8, 256)
     inputs = tf.random.uniform(shape=(8, 8, 8, 256))
     self.assertEqual(attention(inputs).shape, expected_shape)
  def make_loss_fn(
      cls, config: ml_collections.ConfigDict
  ) -> Callable[..., Tuple[float, MetricGroups, Dict[str, Any]]]:
    """Builds the loss function for this task.

    See BaseTask.

    The returned function computes the entity linking loss.

    Args:
      config: contains experiment hyperparameters; `el_score_mode` (default
        'dot') is forwarded to the entity linking loss.

    Returns:
      Loss function.
    """
    el_score_mode = config.get('el_score_mode', 'dot')

    def loss_fn(
        model_config: ml_collections.FrozenConfigDict,
        model_params: Dict[str, Any],
        model_vars: Dict[str, Any],  # pylint: disable=unused-argument
        batch: Dict[str, Any],
        deterministic: bool,
        dropout_rng: Optional[Dict[str, Array]] = None,
    ) -> Tuple[float, MetricGroups, Dict[str, Any]]:
      """Task-specific loss function. See BaseTask."""
      model = cls.build_model(model_config)
      # Only the loss helpers are needed; the logging helpers are discarded.
      loss_helpers, _ = model.apply(
          {'params': model_params},
          batch,
          deterministic=deterministic,
          rngs=dropout_rng)

      target_ids = batch['mention_target_ids']
      target_weights = batch['mention_target_weights']
      # Multiplying by the weights zeroes the ids of zero-weight targets.
      weighted_target_ids = target_ids * target_weights

      loss, el_final_metrics, _ = mention_losses.entity_linking_loss(
          loss_helpers['target_mention_encodings'],
          loss_helpers['entity_embeddings'], weighted_target_ids,
          target_weights, el_score_mode)

      return loss, {'agg': el_final_metrics}, {}

    return loss_fn
Exemplo n.º 19
0
    def __init__(self, config: ml_collections.ConfigDict):
        """Reads memory settings from `config`; memory itself is loaded lazily.

        Args:
          config: supplies `memory_reduction`, the three `memory_*_pattern`
            entries, the optional `save_k_retrieval` (default 10),
            `model_config.encoder_config` and `memory_prop`.

        Raises:
          Exception: `save_k_retrieval` is not None while the encoder config
            has no `k_top_post_selection`.
          AssertionError: `config.memory_prop` is not None.
        """
        copied_fields = (
            'memory_reduction',
            'memory_entity_id_pattern',
            'memory_text_pattern',
            'memory_positions_pattern',
        )
        for field in copied_fields:
            setattr(self, field, getattr(config, field))

        self.save_k_retrieval = config.get('save_k_retrieval', 10)
        if self.save_k_retrieval is not None:
            encoder_config = config.model_config.encoder_config
            if encoder_config.get('k_top_post_selection') is None:
                raise Exception(
                    'save_k_retrieval only allowed with k_top_post_selection')

        # Lazy load for memory
        self.memory_entity_id = None
        self.memory_text = None
        self.memory_positions = None

        # TODO(urikz): Move `memory_prop` to `data_utils.load_sharded_array`,
        # e.g. array = array[:int(memory_prop * array.shape[0])]
        assert config.memory_prop is None
Exemplo n.º 20
0
def get_loss(config: ml_collections.ConfigDict) -> tf.losses.Loss:
  """Returns the loss named by `config.loss`.

  Recognized names: 'ce' (categorical cross-entropy), 'bce' (binary
  cross-entropy) and 'mse' (mean squared error). Both cross-entropy losses are
  created with `from_logits=False`.

  Args:
    config: A ConfigDict containing a `config.loss` name.

  Returns:
    A loss function.

  Raises:
    ValueError: `config.loss` is missing or an unknown loss name.
  """
  loss_name = config.get('loss', None)

  # Constructors are wrapped in lambdas so only the selected loss is built.
  loss_factories = (
      ('ce',
       lambda: tf.keras.losses.CategoricalCrossentropy(from_logits=False)),
      ('bce', lambda: tf.keras.losses.BinaryCrossentropy(from_logits=False)),
      ('mse', lambda: tf.keras.losses.MeanSquaredError()),
  )
  for known_name, make_loss in loss_factories:
    if loss_name == known_name:
      return make_loss()

  raise ValueError(f'Unknown loss name: {loss_name}')
Exemplo n.º 21
0
def get_config():
    """Experiment configuration for the `color_upsampler` model."""
    config = ConfigDict({
        # Data.
        'dataset': 'imagenet',
        'downsample': True,
        'downsample_res': 64,
        'resolution': [256, 256],
        'random_channel': True,
        # Training.
        'batch_size': 1,
        'max_train_steps': 15000,
        'save_checkpoint_secs': 900,
        'num_epochs': -1,
        'polyak_decay': 0.999,
        'eval_num_examples': 20000,
        'eval_batch_size': 16,
        'eval_checkpoint_wait_secs': -1,
    })

    config.optimizer = ConfigDict({
        'type': 'rmsprop',
        'learning_rate': 3e-4,
    })

    # Model.
    config.model = ConfigDict({
        'hidden_size': 32,
        'ff_size': 32,
        'num_heads': 1,
        'num_encoder_layers': 1,
        'resolution': [64, 64],
        'name': 'color_upsampler',
    })

    config.sample = ConfigDict({
        'gen_data_dir': '',
        'log_dir': 'samples',
        'batch_size': 1,
        'mode': 'argmax',
        'num_samples': 1,
        'num_outputs': 1,
        'skip_batches': 0,
        'gen_file': 'gen0',
    })

    return config
Exemplo n.º 22
0
def train_and_evaluate(config: ml_collections.ConfigDict,
                       workdir: str) -> TrainState:
    """Execute model training and evaluation loop.

    Trains from the step stored in the restored checkpoint up to the
    configured number of steps, writes averaged training metrics every
    `config.log_every_steps` steps (when set), evaluates at the end of every
    epoch and saves a checkpoint every 10 epochs and at the final step.

    Args:
      config: Hyperparameter configuration for training and evaluation.
      workdir: Directory where the tensorboard summaries are written to. It is
        also passed to the checkpoint restore/save helpers and the profiler.

    Returns:
      Final TrainState.

    Raises:
      ValueError: `config.batch_size` is not divisible by the device count.
    """

    # Hosts other than host 0 get `just_logging=True`.
    writer = metric_writers.create_default_writer(
        logdir=workdir, just_logging=jax.host_id() != 0)

    rng = random.PRNGKey(0)

    image_size = 224

    if config.batch_size % jax.device_count() > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')
    # `config.batch_size` is the global batch; each process gets an equal share.
    local_batch_size = config.batch_size // jax.process_count()

    platform = jax.local_devices()[0].platform

    # Half precision uses bfloat16 on TPU and float16 on any other platform.
    if config.half_precision:
        if platform == 'tpu':
            input_dtype = tf.bfloat16
        else:
            input_dtype = tf.float16
    else:
        input_dtype = tf.float32

    dataset_builder = tfds.builder(config.dataset)
    train_iter = create_input_iter(dataset_builder,
                                   local_batch_size,
                                   image_size,
                                   input_dtype,
                                   train=True,
                                   cache=config.cache)
    eval_iter = create_input_iter(dataset_builder,
                                  local_batch_size,
                                  image_size,
                                  input_dtype,
                                  train=False,
                                  cache=config.cache)

    steps_per_epoch = (dataset_builder.info.splits['train'].num_examples //
                       config.batch_size)

    # -1 means "derive the step count from `config.num_epochs`".
    if config.num_train_steps == -1:
        num_steps = int(steps_per_epoch * config.num_epochs)
    else:
        num_steps = config.num_train_steps

    # -1 means "one pass over the validation split per evaluation".
    if config.steps_per_eval == -1:
        num_validation_examples = dataset_builder.info.splits[
            'validation'].num_examples
        steps_per_eval = num_validation_examples // config.batch_size
    else:
        steps_per_eval = config.steps_per_eval

    # Checkpoint every 10 epochs.
    steps_per_checkpoint = steps_per_epoch * 10

    # The configured learning rate is scaled linearly with the global batch
    # size, relative to a batch size of 256.
    base_learning_rate = config.learning_rate * config.batch_size / 256.

    model_cls = getattr(models, config.model)
    model = create_model(model_cls=model_cls,
                         half_precision=config.half_precision)

    learning_rate_fn = create_learning_rate_fn(config, base_learning_rate,
                                               steps_per_epoch)

    state = create_train_state(rng, config, model, image_size,
                               learning_rate_fn)
    state = restore_checkpoint(state, workdir)
    # step_offset > 0 if restarting from checkpoint
    step_offset = int(state.step)
    state = jax_utils.replicate(state)

    p_train_step = jax.pmap(functools.partial(
        train_step, learning_rate_fn=learning_rate_fn),
                            axis_name='batch')
    p_eval_step = jax.pmap(eval_step, axis_name='batch')

    train_metrics = []
    hooks = []
    # The profiling hook is only installed on process 0.
    if jax.process_index() == 0:
        hooks += [
            periodic_actions.Profile(num_profile_steps=5, logdir=workdir)
        ]
    train_metrics_last_t = time.time()
    logging.info('Initial compilation, this might take some minutes...')
    # zip() ends the loop at `num_steps` or when `train_iter` is exhausted,
    # whichever comes first.
    for step, batch in zip(range(step_offset, num_steps), train_iter):
        state, metrics = p_train_step(state, batch)
        for h in hooks:
            h(step)
        if step == step_offset:
            logging.info('Initial compilation completed.')

        # Training metrics are only collected when `log_every_steps` is set.
        if config.get('log_every_steps'):
            train_metrics.append(metrics)
            if (step + 1) % config.log_every_steps == 0:
                train_metrics = common_utils.get_metrics(train_metrics)
                # Average each metric over the steps since the last report.
                summary = {
                    f'train_{k}': v
                    for k, v in jax.tree_map(lambda x: x.mean(),
                                             train_metrics).items()
                }
                summary['steps_per_second'] = config.log_every_steps / (
                    time.time() - train_metrics_last_t)
                writer.write_scalars(step + 1, summary)
                train_metrics = []
                train_metrics_last_t = time.time()

        # End of an epoch: run evaluation.
        if (step + 1) % steps_per_epoch == 0:
            epoch = step // steps_per_epoch
            eval_metrics = []

            # sync batch statistics across replicas
            state = sync_batch_stats(state)
            for _ in range(steps_per_eval):
                eval_batch = next(eval_iter)
                metrics = p_eval_step(state, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
            logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                         summary['loss'], summary['accuracy'] * 100)
            writer.write_scalars(
                step + 1, {f'eval_{key}': val
                           for key, val in summary.items()})
            writer.flush()
        # Checkpoint periodically and always after the last step.
        if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
            state = sync_batch_stats(state)
            save_checkpoint(state, workdir)

    # Wait until computations are done before exiting
    jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()

    return state
Exemplo n.º 23
0
def get_config():
    """Experiment configuration for the `coltran_core` model."""
    # encoder
    encoder = ConfigDict({
        'ff_size': 512,
        'hidden_size': 512,
        'num_heads': 4,
        'num_encoder_layers': 4,
        'dropout': 0.0,
    })

    # decoder
    decoder = ConfigDict({
        'ff_size': 512,
        'hidden_size': 512,
        'resolution': [64, 64],
        'num_heads': 4,
        'num_inner_layers': 2,
        'num_outer_layers': 2,
        'dropout': 0.0,
        'skip': True,
        # conditional MLP
        'cond_mlp': 'affine',
        'cond_mlp_act': 'identity',
        # conditional layer norm
        'cond_ln_act': 'identity',
        'cond_ln': True,
        'cond_ln_seq': 'sc',
        'cond_ln_sp_ave': 'learnable',
        'cond_ln_init': 'glorot_uniform',
        # conditional attention
        'cond_att_init': 'glorot_uniform',
        'cond_att_v': True,
        'cond_att_q': True,
        'cond_att_k': True,
        'cond_att_scale': True,
        'cond_att_act': 'identity',
    })

    # Model.
    model = ConfigDict({
        'hidden_size': 512,
        'stage': 'encoder_decoder',
        'resolution': [64, 64],
        'name': 'coltran_core',
    })
    model.encoder = encoder
    model.decoder = decoder

    config = ConfigDict({
        # Data.
        'dataset': 'imagenet',
        'downsample': True,
        'downsample_res': 64,
        'resolution': [256, 256],
        # Training.
        'batch_size': 7,
        'max_train_steps': 450000,
        'save_checkpoint_secs': 900,
        'num_epochs': -1,
        'polyak_decay': 0.999,
        'eval_num_examples': 20000,
        'eval_batch_size': 16,
        'eval_checkpoint_wait_secs': -1,
        # loss hparams.
        'loss_factor': 0.99,
        'encoder_loss_factor': 0.01,
    })
    config.optimizer = ConfigDict({
        'type': 'rmsprop',
        'learning_rate': 3e-4,
    })
    config.model = model
    config.sample = ConfigDict({
        'log_dir': '',
        'batch_size': 1,
        'mode': 'sample',
        'num_samples': 1,
        'num_outputs': 1,
        'skip_batches': 0,
        'gen_file': 'gen0',
    })
    return config
Exemplo n.º 24
0
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
  """Runs a training and evaluation loop.

  Every `config.eval_every_steps` steps (and on the last step) the loop writes
  averaged training metrics, runs evaluation, and translates the prediction
  dataset to report a BLEU score. Checkpoints are written by host 0 every
  `config.checkpoint_every_steps` steps and on the last step when
  `config.save_checkpoints` is set.

  Args:
    config: Configuration to use. `config.vocab_path` is filled in with a path
      under `workdir` when it is None.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint.
  """
  tf.io.gfile.makedirs(workdir)

  # Default the vocabulary location to the work directory and store it back
  # into the config, which is also passed to the dataset loader below.
  vocab_path = config.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(workdir, "sentencepiece_model")
    config.vocab_path = vocab_path
  tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info("Initializing dataset.")
  train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
      n_devices=jax.local_device_count(),
      config=config,
      reverse_translation=config.reverse_translation,
      vocab_path=vocab_path)

  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

  def decode_tokens(toks):
    # Keep tokens up to and including the first EOS.
    # NOTE(review): if `toks` contains no EOS, np.argmax returns 0 and only
    # the first token is kept - confirm this is intended.
    valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
    return encoder.detokenize(valid_toks).numpy().decode("utf-8")

  # A non-positive `num_predict_steps` leaves the prediction dataset untrimmed.
  if config.num_predict_steps > 0:
    predict_ds = predict_ds.take(config.num_predict_steps)

  logging.info("Initializing model, optimizer, and step functions.")

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = models.TransformerConfig(
      vocab_size=vocab_size,
      output_vocab_size=vocab_size,
      share_embeddings=config.share_embeddings,
      logits_via_embedding=config.logits_via_embedding,
      dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32,
      emb_dim=config.emb_dim,
      num_heads=config.num_heads,
      num_layers=config.num_layers,
      qkv_dim=config.qkv_dim,
      mlp_dim=config.mlp_dim,
      max_len=max(config.max_target_length, config.max_eval_target_length),
      dropout_rate=config.dropout_rate,
      attention_dropout_rate=config.attention_dropout_rate,
      deterministic=False,
      decode=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))
  # Eval and predict configs differ from the train config only in the
  # `deterministic` / `decode` flags.
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(deterministic=True, decode=True)

  start_step = 0
  rng = jax.random.PRNGKey(config.seed)
  rng, init_rng = jax.random.split(rng)
  input_shape = (config.per_device_batch_size, config.max_target_length)
  target_shape = (config.per_device_batch_size, config.max_target_length)

  # Parameters are initialized with the deterministic (eval) configuration.
  m = models.Transformer(eval_config)
  initial_variables = jax.jit(m.init)(init_rng,
                                      jnp.ones(input_shape, jnp.float32),
                                      jnp.ones(target_shape, jnp.float32))

  # apply an optimizer to this tree
  optimizer_def = optim.Adam(
      config.learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=config.weight_decay)
  optimizer = optimizer_def.create(initial_variables["params"])

  # We access model params only from optimizer below via optimizer.target.
  del initial_variables

  if config.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)

  # Hosts other than host 0 get `just_logging=True`.
  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.host_id() > 0)
  # Hyperparameters are only recorded when training starts from step 0.
  if start_step == 0:
    writer.write_hparams(dict(config))

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=config.learning_rate, warmup_steps=config.warmup_steps)

  # compile multidevice versions of train/eval/predict step and cache init fn.
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          config=train_config,
          learning_rate_fn=learning_rate_fn,
          label_smoothing=config.label_smoothing),
      axis_name="batch",
      donate_argnums=(0,))  # pytype: disable=wrong-arg-types
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step, config=eval_config),
      axis_name="batch")
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          max_decode_len=config.max_predict_length,
          config=predict_config),
      axis_name="batch")
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step, config=predict_config, beam_size=config.beam_size),
      axis_name="batch",
      static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  dropout_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  logging.info("Starting training loop.")
  hooks = []
  # `report_progress` is created on every host because its `timed` context
  # manager is used below; it is only registered as a hook on host 0.
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_train_steps, writer=writer)
  if jax.host_id() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(logdir=workdir, num_profile_steps=5)
    ]
  train_metrics = []
  with metric_writers.ensure_flushes(writer):
    for step in range(start_step, config.num_train_steps):
      is_last_step = step == config.num_train_steps - 1

      # Shard data to devices and do a training step.
      with jax.profiler.StepTraceAnnotation("train", step_num=step):
        batch = common_utils.shard(jax.tree_map(np.asarray, next(train_iter)))
        optimizer, metrics = p_train_step(
            optimizer, batch, dropout_rng=dropout_rngs)
        train_metrics.append(metrics)

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)
      for h in hooks:
        h(step)

      # Periodic metric handling.
      # Note that `step % eval_every_steps == 0` also holds at step 0.
      if step % config.eval_every_steps == 0 or is_last_step:
        with report_progress.timed("training_metrics"):
          logging.info("Gathering training metrics.")
          train_metrics = common_utils.get_metrics(train_metrics)
          lr = train_metrics.pop("learning_rate").mean()
          # Metrics are summed and then normalized by the summed
          # "denominator" entry; the learning rate is averaged separately.
          metrics_sums = jax.tree_map(jnp.sum, train_metrics)
          denominator = metrics_sums.pop("denominator")
          summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
          summary["learning_rate"] = lr
          summary = {"train_" + k: v for k, v in summary.items()}
          writer.write_scalars(step, summary)
          train_metrics = []

        with report_progress.timed("eval"):
          eval_results = evaluate(
              p_eval_step=p_eval_step,
              target=optimizer.target,
              eval_ds=eval_ds,
              num_eval_steps=config.num_eval_steps)
          writer.write_scalars(
              step, {"eval_" + k: v for k, v in eval_results.items()})

        with report_progress.timed("translate_and_bleu"):
          exemplars, bleu_score = translate_and_calculate_bleu(
              p_pred_step=p_pred_step,
              p_init_cache=p_init_cache,
              target=optimizer.target,
              predict_ds=predict_ds,
              decode_tokens=decode_tokens,
              max_predict_length=config.max_predict_length)
          writer.write_scalars(step, {"bleu": bleu_score})
          writer.write_texts(step, {"samples": exemplars})

      # Save a checkpoint on one host after every checkpoint_freq steps.
      save_checkpoint = (step % config.checkpoint_every_steps == 0 or
                         is_last_step)
      if config.save_checkpoints and save_checkpoint and jax.host_id() == 0:
        with report_progress.timed("checkpoint"):
          checkpoints.save_checkpoint(workdir, jax_utils.unreplicate(optimizer),
                                      step)
Exemplo n.º 25
0
def inference_time(config: ml_collections.ConfigDict, workdir: str):
  """Runs a number of steps and measures inference time.

  Builds the VisionTransformer selected by `config.model_name`, initializes
  its parameters on CPU, replicates them over the local devices and then runs
  the pmapped forward pass on an all-ones image batch: first
  `config.initial_steps` untimed warm-up inferences, then `config.steps` timed
  inferences. Per-core throughput statistics (images / second / local device)
  are logged and written as scalars at step 0.

  Args:
    config: Configuration. Reads `batch`, `num_classes`, `image_size`,
      `model_name`, `initial_steps` and `steps`.
    workdir: Directory handed to the metric writer.
  """

  assert config.batch, f'Expected --config.batch={config.batch} > 0'
  assert config.num_classes, (
      f'Expected --config.num_classes={config.num_classes} > 0')
  assert config.image_size, (
      f'Expected --config.image_size={config.image_size} > 0')

  # Build VisionTransformer architecture
  model_config = config_lib.MODEL_CONFIGS[config.model_name]
  model = models.VisionTransformer(
      num_classes=config.num_classes, **model_config)

  # Make sure initial model parameters (before replication) are on CPU only.
  @functools.partial(jax.jit, backend='cpu')
  def init(rng):
    return model.init(
        rng,
        # Discard the "num_local_devices" dimension for initialization.
        inputs=jnp.ones([1, config.image_size, config.image_size, 3],
                        jnp.float32),
        train=False)

  variables = init(jax.random.PRNGKey(0))

  params_repl = flax_utils.replicate(variables['params'])

  # pmap replicates the models over all TPUs/GPUs
  vit_fn_repl = jax.pmap(functools.partial(model.apply, train=False))
  # One leading axis entry per local device; the global batch is split evenly
  # (integer division) across them.
  images = jnp.ones([
      jax.local_device_count(), config.batch // jax.local_device_count(),
      config.image_size, config.image_size, 3
  ], jnp.float32)

  writer = metric_writers.create_default_writer(workdir, asynchronous=False)
  writer.write_hparams(config.to_dict())

  # The first call triggers compilation, so it is kept out of all timings.
  logging.info('Starting training loop; initial compile can take a while...')
  logits = vit_fn_repl(flax.core.FrozenDict(params=params_repl), images)
  logits.block_until_ready()
  logging.info('Done.')

  logging.info('Going to run %d inferences WITHOUT measuring...',
               config.initial_steps)
  for _ in range(config.initial_steps):
    logits = vit_fn_repl(flax.core.FrozenDict(params=params_repl), images)
    logits.block_until_ready()

  logging.info('Going to run %d inferences measuring...', config.steps)
  times = []
  # Time exactly `config.steps` inferences (the warm-up count is separate).
  for _ in range(config.steps):
    t0 = time.time()
    logits = vit_fn_repl(flax.core.FrozenDict(params=params_repl), images)
    # Wait for the device computation to finish before stopping the clock.
    logits.block_until_ready()
    times.append(time.time() - t0)
  logging.info('times=%s', times)
  imgs_sec_core = config.batch / jax.local_device_count() / np.array(times)
  stats = dict(
      imgs_sec_core_min=imgs_sec_core.min(),
      imgs_sec_core_max=imgs_sec_core.max(),
      imgs_sec_core_mean=imgs_sec_core.mean(),
      imgs_sec_core_std=imgs_sec_core.std(),
  )
  logging.info('imgs_sec_core_min=%f', stats['imgs_sec_core_min'])
  logging.info('imgs_sec_core_max=%f', stats['imgs_sec_core_max'])
  logging.info('imgs_sec_core_mean=%f', stats['imgs_sec_core_mean'])
  logging.info('imgs_sec_core_std=%f', stats['imgs_sec_core_std'])
  writer.write_scalars(0, stats)
Exemplo n.º 26
0
def train_and_evaluate(config: ml_collections.ConfigDict,
                       workdir: str) -> TrainState:
    """Train on GLUE/SST-2 and evaluate on its validation split every epoch.

    As a side effect, `config.vocab_size` is set to the size of the training
    vocabulary before the model is built.

    Args:
        config: Hyperparameter configuration for training and evaluation.
        workdir: Directory where the tensorboard summaries are written to.

    Returns:
        The final train state that includes the trained parameters.
    """
    # Input pipelines: bucketed + shuffled for training, plain for eval.
    sst2_train = input_pipeline.TextDataset(tfds_name='glue/sst2',
                                            split='train')
    sst2_validation = input_pipeline.TextDataset(tfds_name='glue/sst2',
                                                 split='validation')
    train_batches = sst2_train.get_bucketed_batches(
        config.batch_size,
        config.bucket_size,
        max_input_length=config.max_input_length,
        drop_remainder=True,
        shuffle=True,
        shuffle_seed=config.seed)
    eval_batches = sst2_validation.get_batches(batch_size=config.batch_size)

    # The embedder reads the vocabulary size from the config.
    config.vocab_size = len(sst2_train.vocab)

    # Jitted per-batch step functions.
    jitted_train_step = jax.jit(train_step)
    jitted_eval_step = jax.jit(eval_step)

    # Model and initial train state (which holds the parameters).
    rng = jax.random.PRNGKey(config.seed)
    model = model_from_config(config)
    state = create_train_state(rng, config, model)

    tb_writer = tensorboard.SummaryWriter(workdir)
    tb_writer.hparams(dict(config))

    logging.info('Starting training...')
    last_epoch = config.num_epochs
    # Epochs are numbered from 1.
    for epoch in range(1, last_epoch + 1):
        # A fresh dropout key for every epoch.
        rng, dropout_rng = jax.random.split(rng)
        state, train_metrics = train_epoch(jitted_train_step, state,
                                           train_batches, epoch,
                                           {'dropout': dropout_rng})

        # Score the updated state on the validation batches.
        eval_metrics = evaluate_model(jitted_eval_step, state, eval_batches,
                                      epoch)

        # Accuracies are reported scaled by 100.
        epoch_scalars = (
            ('train_loss', train_metrics.loss),
            ('train_accuracy', train_metrics.accuracy * 100),
            ('eval_loss', eval_metrics.loss),
            ('eval_accuracy', eval_metrics.accuracy * 100),
        )
        for tag, value in epoch_scalars:
            tb_writer.scalar(tag, value, epoch)

    tb_writer.flush()
    return state
Exemplo n.º 27
0
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
  """Runs training interleaved with evaluation.

  Loads pretrained parameters from `config.pretrained_dir`, restores the
  latest checkpoint found in `workdir` (if any) and then runs the training
  loop up to `config.total_steps`, periodically reporting progress, evaluating
  on the test split and storing checkpoints.

  Args:
    config: Configuration. Among others reads `dataset`, `model`,
      `model_type`, `model_or_filename`, `pretrained_dir`, `total_steps`,
      `base_lr`, `decay_type`, `warmup_steps`, `accum_steps`, `optim_dtype`,
      `grad_norm_clip`, `prefetch`, `batch`, `batch_eval`, `progress_every`,
      `eval_every` and `checkpoint_every`.
    workdir: Directory used for checkpoints, metrics and profiles.

  Returns:
    The unreplicated optimizer as of the last executed step.

  Raises:
    ValueError: If the pretrained checkpoint file cannot be found.
  """

  # Setup input pipeline
  dataset_info = input_pipeline.get_dataset_info(config.dataset, 'train')

  ds_train, ds_test = input_pipeline.get_datasets(config)
  # A single batch is pulled only to learn the input shape/dtype for init.
  batch = next(iter(ds_train))
  logging.info(ds_train)
  logging.info(ds_test)

  # Build VisionTransformer architecture
  model_cls = {'ViT': models.VisionTransformer,
               'Mixer': models.MlpMixer}[config.get('model_type', 'ViT')]
  model = model_cls(num_classes=dataset_info['num_classes'], **config.model)

  def init_model():
    return model.init(
        jax.random.PRNGKey(0),
        # Discard the "num_local_devices" dimension for initialization.
        jnp.ones(batch['image'].shape[1:], batch['image'].dtype.name),
        train=False)

  # Use JIT to make sure params reside in CPU memory.
  variables = jax.jit(init_model, backend='cpu')()

  model_or_filename = config.get('model_or_filename')
  if model_or_filename:
    # Loading model from repo published with  "How to train your ViT? Data,
    # Augmentation, and Regularization in Vision Transformers" paper.
    # https://arxiv.org/abs/2106.10270
    if '-' in model_or_filename:
      # A name containing '-' is taken to be a full checkpoint filename.
      filename = model_or_filename
    else:
      # Select best checkpoint from i21k pretraining by final upstream
      # validation accuracy.
      df = checkpoint.get_augreg_df(directory=config.pretrained_dir)
      sel = df.filename.apply(
          lambda filename: filename.split('-')[0] == model_or_filename)
      best = df.loc[sel].query('ds=="i21k"').sort_values('final_val').iloc[-1]
      filename = best.filename
      logging.info('Selected fillename="%s" for "%s" with final_val=%.3f',
                   filename, model_or_filename, best.final_val)
  else:
    # ViT / Mixer papers
    filename = config.model.name

  # The checkpoint path is always derived from the `filename` selected above.
  pretrained_path = os.path.join(config.pretrained_dir, f'{filename}.npz')
  if not tf.io.gfile.exists(pretrained_path):
    raise ValueError(
        f'Could not find "{pretrained_path}" - you can download models from '
        '"gs://vit_models/imagenet21k" or directly set '
        '--config.pretrained_dir="gs://vit_models/imagenet21k".')
  params = checkpoint.load_pretrained(
      pretrained_path=pretrained_path,
      init_params=variables['params'],
      model_config=config.model)

  total_steps = config.total_steps
  lr_fn = utils.create_learning_rate_schedule(total_steps, config.base_lr,
                                              config.decay_type,
                                              config.warmup_steps)

  update_fn_repl = make_update_fn(
      apply_fn=model.apply, accum_steps=config.accum_steps, lr_fn=lr_fn)
  infer_fn_repl = jax.pmap(functools.partial(model.apply, train=False))

  # Create optimizer and replicate it over all TPUs/GPUs
  opt = momentum_clip.Optimizer(
      dtype=config.optim_dtype,
      grad_norm_clip=config.grad_norm_clip).create(params)

  # `(opt, initial_step)` is the restore target; it is passed through
  # `restore_checkpoint` and replaced by whatever that call returns.
  initial_step = 1
  opt, initial_step = flax_checkpoints.restore_checkpoint(
      workdir, (opt, initial_step))
  logging.info('Will start/continue training at initial_step=%d', initial_step)

  opt_repl = flax.jax_utils.replicate(opt)

  # Delete references to the objects that are not needed anymore
  del opt
  del params

  # Replicate the RNG key that is threaded through the update function.
  update_rng_repl = flax.jax_utils.replicate(jax.random.PRNGKey(0))

  # Setup metric writer & hooks.
  writer = metric_writers.create_default_writer(workdir, asynchronous=False)
  writer.write_hparams(config.to_dict())
  hooks = [
      periodic_actions.Profile(logdir=workdir),
      periodic_actions.ReportProgress(
          num_train_steps=total_steps, writer=writer),
  ]

  # Run training loop
  logging.info('Starting training loop; initial compile can take a while...')
  t0 = lt0 = time.time()
  lstep = initial_step
  for step, batch in zip(
      range(initial_step, total_steps + 1),
      input_pipeline.prefetch(ds_train, config.prefetch)):

    with jax.profiler.StepTraceContext('train', step_num=step):
      opt_repl, loss_repl, update_rng_repl = update_fn_repl(
          opt_repl, flax.jax_utils.replicate(step), batch, update_rng_repl)

    for hook in hooks:
      hook(step)

    if step == initial_step:
      # Reset the clocks after the first step so that its duration (logged
      # below) is excluded from throughput and ETA figures.
      logging.info('First step took %.1f seconds.', time.time() - t0)
      t0 = time.time()
      lt0, lstep = time.time(), step

    # Report training metrics
    if config.progress_every and step % config.progress_every == 0:
      img_sec_core_train = (config.batch * (step - lstep) /
                            (time.time() - lt0)) / jax.device_count()
      lt0, lstep = time.time(), step
      writer.write_scalars(
          step,
          dict(
              train_loss=float(flax.jax_utils.unreplicate(loss_repl)),
              img_sec_core_train=img_sec_core_train))
      done = step / total_steps
      logging.info(f'Step: {step}/{total_steps} {100*done:.1f}%, '  # pylint: disable=logging-format-interpolation
                   f'img/sec/core: {img_sec_core_train:.1f}, '
                   f'ETA: {(time.time()-t0)/done*(1-done)/3600:.2f}h')

    # Run evaluation
    if ((config.eval_every and step % config.eval_every == 0) or
        (step == total_steps)):

      accuracies = []
      lt0 = time.time()
      for test_batch in input_pipeline.prefetch(ds_test, config.prefetch):
        logits = infer_fn_repl(
            dict(params=opt_repl.target), test_batch['image'])
        # Both logits and labels are reduced with argmax over the last axis.
        accuracies.append(
            (np.argmax(logits,
                       axis=-1) == np.argmax(test_batch['label'],
                                             axis=-1)).mean())
      accuracy_test = np.mean(accuracies)
      img_sec_core_test = (
          config.batch_eval * ds_test.cardinality().numpy() /
          (time.time() - lt0) / jax.device_count())
      # Restart the training throughput clock so that evaluation time is not
      # attributed to training.
      lt0 = time.time()

      lr = float(lr_fn(step))
      logging.info(f'Step: {step} '  # pylint: disable=logging-format-interpolation
                   f'Learning rate: {lr:.7f}, '
                   f'Test accuracy: {accuracy_test:0.5f}, '
                   f'img/sec/core: {img_sec_core_test:.1f}')
      writer.write_scalars(
          step,
          dict(
              accuracy_test=accuracy_test,
              lr=lr,
              img_sec_core_test=img_sec_core_test))

    # Store checkpoint every `config.checkpoint_every` steps (when set) and
    # always after the final step.
    if ((config.checkpoint_every and step % config.checkpoint_every == 0) or
        step == total_steps):
      checkpoint_path = flax_checkpoints.save_checkpoint(
          workdir, (flax.jax_utils.unreplicate(opt_repl), step), step)
      logging.info('Stored checkpoint at step %d to "%s"', step,
                   checkpoint_path)

  return flax.jax_utils.unreplicate(opt_repl)
Exemplo n.º 28
0
def generate(config: ml_collections.ConfigDict):
    """Generates memories.

    Builds the memory-generation model, initializes it under `pmap`,
    optionally merges weights loaded via the task, then iterates over the
    data, feeding every batch through the pmapped prediction step and handing
    the results to a `MemorySaver`. Iteration stops once the saver reports at
    least this process's share of `config.num_total_memories`; the collected
    memories are then saved to `config.output_dir`.

    Args:
        config: Job configuration. Among others reads `model_config`, `seed`,
            `load_weights`, `output_dir`, `num_total_memories`, `memory_dim`,
            `max_mentions_per_sample`, `per_device_batch_size`,
            `log_every_steps`, `num_shards` and `shard_size_divisible`.

    Raises:
        ValueError: If weights are loaded and any features are reported as
            missing by the merge.
    """
    # Host / device topology.
    n_local_devices = jax.local_device_count()
    n_devices_total = jax.device_count()
    n_processes = jax.process_count()
    this_process = jax.process_index()

    task_cls = memory_generation_task.MemoryGenerationTask
    frozen_model_config = ml_collections.FrozenConfigDict(config.model_config)
    model = task_cls.build_model(frozen_model_config)
    predict_fn = functools.partial(
        task_cls.make_prediction_fn(config),
        frozen_model_config,
    )
    p_predict_step = jax.pmap(predict_fn, axis_name='batch')
    rng = jax.random.PRNGKey(config.seed)

    # Initialization needs to be pmapped because models use collective ops.
    # Tile every dummy feature along a new leading axis, one copy per local
    # device.
    dummy_input = {}
    for name, array in task_cls.dummy_input(config).items():
        repeats = (n_local_devices, ) + (1, ) * array.ndim
        dummy_input[name] = jnp.tile(array, repeats)

    rng, init_rng = jax.random.split(rng)
    per_device_init_rng = jax.random.split(init_rng, n_local_devices)

    logging.info('Initializing model.')
    p_init = jax.pmap(model.init, 'batch', static_broadcasted_argnums=2)
    frozen_variables = p_init(per_device_init_rng, dummy_input, True)
    logging.info('Finished initializing model.')
    variables = frozen_variables.unfreeze()

    if config.load_weights is not None:
        logging.info('Loading model weights from file')
        unexpected, missing = checkpoint_utils.merge_nested_dicts(
            variables, task_cls.load_weights(config))
        logging.info('*** Unexpected features: ***')
        for name in unexpected:
            logging.info('\t%s', name)
        # In the prediction mode we don't allow any features to be missing
        # pylint: disable=g-explicit-length-test
        if len(missing) > 0:
            raise ValueError('Missing features: %s' % ','.join(missing))

    # Split the variables into parameters and everything else.
    model_params = variables['params']
    model_vars = {}
    for collection, value in variables.items():
        if collection != 'params':
            model_vars[collection] = value
    # From here on only the two pieces above are used.
    del variables

    # Processes other than the first only log their metrics.
    writer = metric_writers.create_default_writer(
        config.output_dir, just_logging=this_process > 0)

    max_length = config.get('max_length_with_entity_tokens',
                            frozen_model_config.encoder_config.max_length)

    # Each process generates an equal (rounded up) share of the memories.
    memories_per_process = math.ceil(config.num_total_memories / n_processes)
    saver = memory_generation_task.MemorySaver(
        num_total_memories=memories_per_process,
        memory_dim=config.memory_dim,
        max_length=max_length,
        max_mentions_per_sample=config.max_mentions_per_sample,
        memory_key_dim=config.get('memory_key_dim'))
    samples_seen = 0
    batches = get_data_iterator(config)

    logging.info('Start memory generation.')
    with metric_writers.ensure_flushes(writer):
        for step, batch in enumerate(batches):
            batch = jax.tree_map(jnp.asarray, batch)
            predictions = jax.device_get(
                p_predict_step(model_params, model_vars, batch))
            saver.add_memories(batch, predictions)

            n_devices, batch_size, _ = batch['text_ids'].shape
            logging.log_first_n(
                logging.INFO, 'Process %d / %d: '
                'Finished generating step %d, local devices %d, batch size %d',
                5, this_process, n_processes, step, n_devices, batch_size)

            samples_seen += n_devices_total * config.per_device_batch_size
            if (step % config.log_every_steps == 0
                    or saver.get_num_memories() >= memories_per_process):
                writer.write_scalars(
                    step,
                    dict(n_memories=saver.get_num_memories(),
                         n_samples=samples_seen))

            # Stop as soon as this process has produced its share.
            if saver.get_num_memories() >= memories_per_process:
                break

    logging.info('Process %d / %d: Finished generating memories: %d out of %d',
                 this_process, n_processes, saver.get_num_memories(),
                 memories_per_process)

    save_started_at = time.time()
    logging.info('Process %d / %d: Start saving generated memories to files.',
                 this_process, n_processes)
    saver.save(config.output_dir,
               num_shards=config.num_shards,
               stride=n_processes,
               offset=this_process,
               shard_size_divisible=config.shard_size_divisible)

    logging.info(
        'Process %d / %d: Finished saving generated memories to files in %.2f seconds',
        this_process, n_processes,
        time.time() - save_started_at)
Exemplo n.º 29
0
def _get_augment_element_fn(
    dataset_config: ml_collections.ConfigDict
) -> Callable[[TensorDict, TensorDict, TensorDict], TensorDictTriple]:
    """Builds the per-element augmentation function for `IMAGE_KEY` images.

    Depending on `dataset_config`, the returned function applies, in order:

      - tf.image.random_flip_left_right  (`random_horizontal_flip`)
      - tf.image.random_flip_up_down     (`random_vertical_flip`)
      - tf.image.random_brightness       (`random_brightness_max_delta`)
      - tf.image.random_hue              (`random_hue_max_delta`)
      - tf.image.random_saturation       (`random_saturation_lower/upper`)
      - tf.image.random_contrast         (`random_contrast_lower/upper`)

    Every option is tested for truthiness, so a missing, `None`, `False` or
    zero value turns the corresponding transformation off. Saturation and
    contrast are only applied when both of their bounds are truthy.

    Important: the returned function asserts that the image tensor has dtype
    tf.float32 and values in [0, 1]; the augmented image is clipped back to
    that range.

    Args:
        dataset_config: A ConfigDict used to build the set of applied
            augmentations.

    Returns:
        A function that applies the configured augmentations to the
        `IMAGE_KEY` entry of its `inputs` and returns
        `(inputs, labels, weights)`.

    Raises:
        ValueError: If an enabled saturation or contrast range has
            `upper <= lower`.
    """

    flip_left_right = dataset_config.get('random_horizontal_flip', False)
    flip_up_down = dataset_config.get('random_vertical_flip', False)
    max_brightness_delta = dataset_config.get('random_brightness_max_delta',
                                              None)
    max_hue_delta = dataset_config.get('random_hue_max_delta', None)

    # Saturation range: enabled only if both bounds are truthy.
    sat_lo = dataset_config.get('random_saturation_lower', None)
    sat_hi = dataset_config.get('random_saturation_upper', None)
    use_saturation = sat_lo and sat_hi
    if use_saturation and not sat_hi > sat_lo:
        raise ValueError(f'Invalid saturation range: ({sat_lo}, {sat_hi})')

    # Contrast range: same rules as saturation.
    con_lo = dataset_config.get('random_contrast_lower', None)
    con_hi = dataset_config.get('random_contrast_upper', None)
    use_contrast = con_lo and con_hi
    if use_contrast and not con_hi > con_lo:
        raise ValueError(f'Invalid contrast range: ({con_lo}, {con_hi})')

    def _augment_element_fn(inputs: TensorDict, labels: TensorDict,
                            weights: TensorDict) -> TensorDictTriple:

        augmented = inputs[IMAGE_KEY]

        # The augmentations below assume a tf.float32 image with pixel values
        # in [0, 1]; verify that up front.
        tf.debugging.assert_type(augmented, tf.float32)
        tf.debugging.assert_less_equal(tf.math.reduce_max(augmented), 1.0)
        tf.debugging.assert_greater_equal(tf.math.reduce_min(augmented), 0.0)

        if flip_left_right:
            augmented = tf.image.random_flip_left_right(augmented)
        if flip_up_down:
            augmented = tf.image.random_flip_up_down(augmented)
        if max_brightness_delta:
            augmented = tf.image.random_brightness(
                augmented, max_delta=max_brightness_delta)
        if max_hue_delta:
            augmented = tf.image.random_hue(augmented,
                                            max_delta=max_hue_delta)
        if use_saturation:
            augmented = tf.image.random_saturation(augmented,
                                                   lower=sat_lo,
                                                   upper=sat_hi)
        if use_contrast:
            augmented = tf.image.random_contrast(augmented,
                                                 lower=con_lo,
                                                 upper=con_hi)

        # Clip back to [0.0, 1.0] and store the result in the input dict.
        inputs[IMAGE_KEY] = tf.clip_by_value(augmented, 0.0, 1.0)

        return inputs, labels, weights

    return _augment_element_fn
Exemplo n.º 30
0
    )
    parser.add_argument("--print_batch_metrics",
                        action='store_true',
                        default=False,
                        help="Set to print metrics for every batch.")
    args = parser.parse_args()

    assert (
        args.checkpoint is not None
    ), "A checkpoint needs to be specified via commandline argument (--checkpoint)"
    assert (
        args.config is not None
    ), "A config needs to be specified via commandline argument (--config)"

    with open(args.config) as f:
        cfg = ConfigDict(yaml.load(f, Loader=yaml.Loader))

    cfg.checkpoint = args.checkpoint

    if args.annotations is not None:
        cfg.test_annotations = args.annotations
    if args.imagedir is not None:
        cfg.test_imagedir = args.imagedir
    if args.uncertainty_threshold is not None:
        cfg.uncertainty_threshold = args.uncertainty_threshold
    if args.uncertainty_gate_type is not None:
        cfg.uncertainty_gate_type = args.uncertainty_gate_type
    if args.weighted_prediction is not None:
        cfg.weighted_prediction = args.weighted_prediction

    assert (