def create_optimizer(model, learning_rate, weight_decay, layers=None):
    """Instantiates Adam multi-optimizer."""

    if layers is None:
        assert (
            isinstance(learning_rate, float) and isinstance(weight_decay, float)
        ), 'Specify float values for the learning rate and weight decay!'
        optimizer_def = optim.Adam(learning_rate=learning_rate,
                                   weight_decay=weight_decay)
        optimizer = optimizer_def.create(model)

    else:
        assert (
            len(learning_rate) == len(weight_decay) == len(layers)
        ), 'Number of specified learning rates, weight decays, and layers must be equal!'
        optimizers = []
        for lr, wd, layer in zip(learning_rate, weight_decay, layers):
            if lr > 0:
                opt = optim.Adam(learning_rate=lr, weight_decay=wd)
                filter_fn = functools.partial(path_inclusion_filter_fn,
                                              layer=layer)
                traversal = optim.ModelParamTraversal(filter_fn)
                traversal_opt = (traversal, opt)
                optimizers.append(traversal_opt)
        optimizer_def = optim.MultiOptimizer(*optimizers)
        optimizer = optimizer_def.create(model)

    return optimizer
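For context, here is a minimal, self-contained sketch of the per-group MultiOptimizer pattern this helper builds, using a plain parameter dict and illustrative group names; it assumes the legacy flax.optim API used throughout these examples.

import jax.numpy as jnp
from flax import optim

# Illustrative parameter pytree; real code would pass a flax model instead.
params = {'encoder': {'kernel': jnp.ones((2, 2))},
          'head': {'kernel': jnp.ones((2,))}}

# One (traversal, optimizer) pair per parameter group, as in create_optimizer.
encoder_traversal = optim.ModelParamTraversal(lambda path, _: 'encoder' in path)
head_traversal = optim.ModelParamTraversal(lambda path, _: 'head' in path)
optimizer_def = optim.MultiOptimizer(
    (encoder_traversal, optim.Adam(learning_rate=1e-4, weight_decay=0.01)),
    (head_traversal, optim.Adam(learning_rate=1e-3, weight_decay=0.0)))
optimizer = optimizer_def.create(params)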
Example No. 2
def create_optimizer(model, learning_rate=1e-4):
  """Create optimizer used for training model.

  MultiOptimizer is used to apply an Adam optimizer with weight decay to all
  parameters except layer_norm and bias, and an Adam optimizer without weight
  decay to the layer_norm and bias params.

  Args:
    model: JAX model to add optimizer to
    learning_rate: base learning rate used for initializing optimizer

  Returns:
    optimizer: model with Adam Optimizer to be used for training
  """
  weight_decay_def = optim.Adam(
      learning_rate=learning_rate, eps=1e-6, weight_decay=0.01)
  no_decay_def = optim.Adam(
      learning_rate=learning_rate, eps=1e-6, weight_decay=0.0)

  def filter_weight_decay(key, _):
    return 'layer_norm' not in key and 'bias' not in key
  def filter_other(key, _):
    return 'layer_norm' in key or 'bias' in key

  weight_decay_traversal = optim.ModelParamTraversal(filter_weight_decay)
  no_decay_traversal = optim.ModelParamTraversal(filter_other)
  optimizer_def = optim.MultiOptimizer(
      (weight_decay_traversal, weight_decay_def),
      (no_decay_traversal, no_decay_def))

  optimizer = optimizer_def.create(model)
  optimizer = optimizer.replicate()
  del model
  return optimizer
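To make the split concrete, here is a small sketch of how these filter functions classify parameter paths; the paths are illustrative examples of the '/'-joined keys that ModelParamTraversal passes in.

def filter_weight_decay(key, _):
  return 'layer_norm' not in key and 'bias' not in key

def filter_other(key, _):
  return 'layer_norm' in key or 'bias' in key

assert filter_weight_decay('/encoder/attention/kernel', None)      # decayed
assert not filter_weight_decay('/encoder/layer_norm/scale', None)  # no decay
assert filter_other('/encoder/mlp/bias', None)                     # no decay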
Example No. 3
class FlaxOptimizersEquivalenceTest(chex.TestCase):

  def setUp(self):
    super().setUp()
    self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
    self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.]))

  @parameterized.named_parameters(
      ('sgd',
       alias.sgd(LR),
       optim.GradientDescent(LR)),
      ('momentum',
       alias.sgd(LR, momentum=0.9),
       optim.Momentum(LR, beta=0.9)),  # Different names.
      ('nesterov_momentum',
       alias.sgd(LR, momentum=0.9, nesterov=True),
       optim.Momentum(LR, beta=0.9, nesterov=True)),
      ('rmsprop',
       alias.rmsprop(LR),
       optim.RMSProp(LR)),
      ('centered_rmsprop',
       alias.rmsprop(LR, centered=True),
       optim.RMSProp(LR, centered=True)),
      ('adam',
       alias.adam(LR),
       optim.Adam(LR)),
      ('adam_w',
       alias.adamw(LR, weight_decay=1e-4),
       optim.Adam(LR, weight_decay=1e-4)),  # Different name.
      ('adagrad',
       alias.adagrad(LR, initial_accumulator_value=0.),  # Different default!
       optim.Adagrad(LR)),
      ('lamb',
       alias.lamb(LR),
       optim.LAMB(LR)),
  )
  def test_flax_optim_equivalence(self, optax_optimizer, flax_optimizer):

    # flax/optim
    flax_params = self.init_params
    flax_optimizer = flax_optimizer.create(flax_params)
    for _ in range(STEPS):
      flax_optimizer = flax_optimizer.apply_gradient(
          self.per_step_updates)
      flax_params = flax_optimizer.target

    # optax
    optax_params = self.init_params
    state = optax_optimizer.init(optax_params)
    for _ in range(STEPS):
      updates, state = optax_optimizer.update(
          self.per_step_updates, state, optax_params)
      optax_params = update.apply_updates(optax_params, updates)

    # Check equivalence.
    chex.assert_tree_all_close(flax_params, optax_params, rtol=1e-4)
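For reference, a standalone sketch of the optax side of the loop above, using the public optax API and illustrative values in place of the LR and STEPS constants defined elsewhere in the test module.

import jax.numpy as jnp
import optax

params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
grads = (jnp.array([500., 5.]), jnp.array([300., 3.]))

tx = optax.adam(1e-2)            # illustrative learning rate
state = tx.init(params)
for _ in range(3):               # illustrative number of steps
  updates, state = tx.update(grads, state, params)
  params = optax.apply_updates(params, updates)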
Example No. 4
    def __init__(
        self,
        state_dim,
        action_dim,
        max_action,
        lr=3e-4,
        discount=0.99,
        tau=0.005,
        policy_noise=0.2,
        expl_noise=0.1,
        noise_clip=0.5,
        policy_freq=2,
        seed=0,
    ):

        self.rng = PRNGSequence(seed)

        actor_input_dim = [((1, state_dim), jnp.float32)]

        init_rng = next(self.rng)

        actor = build_td3_actor_model(actor_input_dim, action_dim, max_action,
                                      init_rng)
        self.actor_target = build_td3_actor_model(actor_input_dim, action_dim,
                                                  max_action, init_rng)
        actor_optimizer = optim.Adam(learning_rate=lr).create(actor)
        self.actor_optimizer = jax.device_put(actor_optimizer)

        init_rng = next(self.rng)

        critic_input_dim = [
            ((1, state_dim), jnp.float32),
            ((1, action_dim), jnp.float32),
        ]

        critic = build_td3_critic_model(critic_input_dim, init_rng)
        self.critic_target = build_td3_critic_model(critic_input_dim, init_rng)
        critic_optimizer = optim.Adam(learning_rate=lr).create(critic)
        self.critic_optimizer = jax.device_put(critic_optimizer)

        self.max_action = max_action
        self.discount = discount
        self.tau = tau
        self.policy_noise = policy_noise
        self.expl_noise = expl_noise
        self.noise_clip = noise_clip
        self.policy_freq = policy_freq

        self.total_it = 0
Example No. 5
def create_optimizer(model, model_kwargs, learning_rate=1e-4):
  """Create optimizer used for training model.

  MultiOptimizer is used to apply an Adam/LAMB optimizer with weight decay to
  all parameters except layer_norm and bias, and an Adam/LAMB optimizer
  without weight decay to the layer_norm and bias params.

  Args:
    model: JAX model to add optimizer to
    model_kwargs: Bert model config parameter dictionary.
    learning_rate: base learning rate used for initializing optimizer

  Returns:
    optimizer: model with Adam/LAMB Optimizer to be used for training
  """
  if FLAGS.use_lamb:
    weight_decay_def = bert_lamb.BertLAMB(
        learning_rate=learning_rate,
        beta1=FLAGS.lamb_beta_1, beta2=FLAGS.lamb_beta_2,
        eps=10**FLAGS.log_epsilon,
        weight_decay=FLAGS.lamb_weight_decay,
        num_layers=model_kwargs['num_layers'])
    no_decay_def = bert_lamb.BertLAMB(
        learning_rate=learning_rate,
        beta1=FLAGS.lamb_beta_1, beta2=FLAGS.lamb_beta_2,
        eps=10**FLAGS.log_epsilon, weight_decay=0.0,
        num_layers=model_kwargs['num_layers'])
  else:
    weight_decay_def = optim.Adam(
        learning_rate=learning_rate, eps=1e-6,
        weight_decay=FLAGS.lamb_weight_decay)
    no_decay_def = optim.Adam(
        learning_rate=learning_rate, eps=1e-6, weight_decay=0.0)

  def filter_weight_decay(key, _):
    return 'layer_norm' not in key and 'bias' not in key and 'layernorm' not in key

  def filter_other(key, _):
    return 'layer_norm' in key or 'bias' in key or 'layernorm' in key

  weight_decay_traversal = optim.ModelParamTraversal(filter_weight_decay)
  no_decay_traversal = optim.ModelParamTraversal(filter_other)
  optimizer_def = optim.MultiOptimizer(
      (weight_decay_traversal, weight_decay_def),
      (no_decay_traversal, no_decay_def))

  optimizer = optimizer_def.create(model)
  optimizer = jax_utils.replicate(optimizer)
  del model
  return optimizer
Example No. 6
    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        max_action: float,
        lr: float = 3e-4,
        discount: float = 0.99,
        tau: float = 0.005,
        policy_noise: float = 0.2,
        expl_noise: float = 0.1,
        noise_clip: float = 0.5,
        policy_freq: int = 2,
        seed: int = 0,
    ):
        self.rng = PRNGSequence(seed)

        actor_input_dim = (1, state_dim)

        init_rng = next(self.rng)

        actor_params = build_td3_actor_model(
            actor_input_dim, action_dim, max_action, init_rng
        )
        self.actor_target_params = build_td3_actor_model(
            actor_input_dim, action_dim, max_action, init_rng
        )
        actor_optimizer = optim.Adam(learning_rate=lr).create(actor_params)
        self.actor_optimizer = jax.device_put(actor_optimizer)

        init_rng = next(self.rng)

        critic_input_dim = [(1, state_dim), (1, action_dim)]

        critic_params = build_td3_critic_model(critic_input_dim, init_rng)
        self.critic_target_params = build_td3_critic_model(critic_input_dim, init_rng)
        critic_optimizer = optim.Adam(learning_rate=lr).create(critic_params)
        self.critic_optimizer = jax.device_put(critic_optimizer)

        self.max_action = max_action
        self.discount = discount
        self.tau = tau
        self.policy_noise = policy_noise
        self.expl_noise = expl_noise
        self.noise_clip = noise_clip
        self.policy_freq = policy_freq

        self.action_dim = action_dim

        self.total_it = 0
Example No. 7
def create_optimizer(config, model):
    common_kwargs = dict(
        learning_rate=config.learning_rate,
        beta1=0.9,
        beta2=0.999,
        eps=1e-6,
    )
    optimizer_decay_def = optim.Adam(weight_decay=0.01, **common_kwargs)
    optimizer_no_decay_def = optim.Adam(weight_decay=0.0, **common_kwargs)
    decay = optim.ModelParamTraversal(lambda path, _: 'bias' not in path)
    no_decay = optim.ModelParamTraversal(lambda path, _: 'bias' in path)
    optimizer_def = optim.MultiOptimizer((decay, optimizer_decay_def),
                                         (no_decay, optimizer_no_decay_def))
    optimizer = optimizer_def.create(model)
    return optimizer
Example No. 8
def create_optimizer(name='adam', learning_rate=6.25e-5, beta1=0.9, beta2=0.999,
                     eps=1.5e-4):
  if name == 'adam':
    return optim.Adam(
        learning_rate=learning_rate, beta1=beta1, beta2=beta2, eps=eps)
  else:
    raise ValueError(f'Unknown optimizer {name}')
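A hedged usage sketch for this factory; the parameter pytree is made up for illustration.

import jax.numpy as jnp

params = {'w': jnp.zeros((3, 3)), 'b': jnp.zeros((3,))}  # illustrative pytree
optimizer = create_optimizer('adam', learning_rate=1e-3).create(params)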
Example No. 9
def main(argv):
    key = random.PRNGKey(0)
    train_ds = tfds.load('mnist', split=tfds.Split.TRAIN)
    train_ds = train_ds.cache().shuffle(1000).batch(FLAGS.batch_size)
    test_ds = tfds.as_numpy(
        tfds.load('mnist', split=tfds.Split.TEST, batch_size=-1))

    _, params = VAE.init_by_shape(key, [((1, 784), jnp.float32)])
    vae = nn.Model(VAE, params)

    optimizer = optim.Adam(learning_rate=FLAGS.learning_rate).create(vae)

    for epoch in range(FLAGS.num_epochs):
        for batch in tfds.as_numpy(train_ds):
            batch['image'] = batch['image'].reshape(-1, 784) / 255.0
            optimizer = train_step(optimizer, batch)

        z = np.random.normal(size=(64, 20))
        metrics, comparison, sample = eval(optimizer.target, test_ds, z)
        save_image(comparison,
                   'results/reconstruction_' + str(epoch) + '.png',
                   nrow=8)
        save_image(sample, 'results/sample_' + str(epoch) + '.png', nrow=8)

        print("eval epoch: {}, loss: {:.4f}, BCE: {:.4f}, KLD: {:.4f}".format(
            epoch + 1, metrics['loss'], metrics['bce'], metrics['kld']))
Example No. 10
    def train(rho_g, nn_params):
        optimizer = optim.Adam(learning_rate=lr,
                               weight_decay=w_decay).create(nn_params)
        optimizer = jax.device_put(optimizer)

        train_loss = []
        loss0 = 1E16
        loss0_tot = 1E16
        itercount = itertools.count()
        f_params = init_params
        for epoch in range(n_epochs):
            for _ in range(n_batches):
                optimizer, loss_and_grad = train_step(optimizer, rho_g,
                                                      next(batches))
                loss, grad = loss_and_grad

#             f = open(f_out,'a+')
#             print(i,loss,file=f)
#             f.close()

            train_loss.append(loss)
#             params = optimizer.target
#             loss_tot = f_validation(params)

        nn_params = optimizer.target

        return nn_params, loss_and_grad, train_loss
Example No. 11
def train():
    """Run main training loop."""
    rng = random.PRNGKey(0)

    # Get Zachary's karate club graph dataset.
    node_feats, node_labels, sources, targets = get_karate_club_data()

    # Create model and optimizer.
    _, initial_params = GNN.init(rng,
                                 node_x=node_feats,
                                 edge_x=None,
                                 sources=sources,
                                 targets=targets)
    model = nn.Model(GNN, initial_params)
    optimizer = optim.Adam(learning_rate=0.01).create(model)

    # Train for 20 iterations.
    for iteration in range(20):
        optimizer, loss = train_step(optimizer, node_feats, sources, targets)

        accuracy = eval_step(  # Model is stored in `optimizer.target`.
            optimizer.target, node_feats, sources, targets, node_labels)

        print('iteration: %d, loss: %.4f, accuracy: %.2f' %
              (iteration + 1, loss, accuracy * 100))
Example No. 12
def create_optimizer(name='adam',
                     learning_rate=6.25e-5,
                     beta1=0.9,
                     beta2=0.999,
                     eps=1.5e-4):
    """Create an optimizer for training.

  Currently, only the Adam and RMSProp optimizers are supported.

  Args:
    name: str, name of the optimizer to create.
    learning_rate: float, learning rate to use in the optimizer.
    beta1: float, beta1 parameter for the optimizer.
    beta2: float, beta2 parameter for the optimizer.
    eps: float, epsilon parameter for the optimizer.

  Returns:
    A flax optimizer.
  """
    if name == 'adam':
        logging.info(
            'Creating Adam optimizer with settings lr=%f, beta1=%f, '
            'beta2=%f, eps=%f', learning_rate, beta1, beta2, eps)
        return optim.Adam(learning_rate=learning_rate,
                          beta1=beta1,
                          beta2=beta2,
                          eps=eps)
    elif name == 'rmsprop':
        logging.info(
            'Creating RMSProp optimizer with settings lr=%f, beta2=%f, '
            'eps=%f', learning_rate, beta2, eps)
        return optim.RMSProp(learning_rate=learning_rate, beta2=beta2, eps=eps)
    else:
        raise ValueError('Unsupported optimizer {}'.format(name))
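A brief usage sketch of the factory above, again with an illustrative parameter pytree; note that the 'rmsprop' branch ignores beta1.

import jax.numpy as jnp

params = {'kernel': jnp.ones((4, 4))}  # illustrative parameter pytree
opt_def = create_optimizer(name='rmsprop', learning_rate=6.25e-5)
optimizer = opt_def.create(params)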
Example No. 13
def init_optimizer_state(workload: spec.Workload,
                         model_params: spec.ParameterContainer,
                         model_state: spec.ModelAuxiliaryState,
                         hyperparameters: spec.Hyperparamters,
                         rng: spec.RandomState) -> spec.OptimizerState:
  del model_state
  del rng
  del workload

  optimizer_def = optim.Adam(
      learning_rate=hyperparameters.learning_rate,
      beta1=1.0 - hyperparameters.one_minus_beta_1,
      beta2=0.98,
      eps=hyperparameters.epsilon)
  optimizer = optimizer_def.create(model_params)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=hyperparameters.learning_rate, warmup_steps=1000)

  # compile multidevice versions of train.
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          config=models.TransformerConfig(
              dropout_rate=hyperparameters.dropout_rate,
              attention_dropout_rate=hyperparameters.attention_dropout_rate),
          learning_rate_fn=learning_rate_fn),
      axis_name="batch",
      donate_argnums=(0,))

  return optimizer, p_train_step
Example No. 14
def load_model(dataset_name, attention_mask_type, use_relative_attention,
               bos_special_attention, predict_config):
    """Loads a checkpoint."""
    rng = jax.random.PRNGKey(0)
    rng, init_rng = jax.random.split(rng)

    m = models.DecomposeAttentionTransformer(predict_config)
    initial_variables = jax.jit(m.init)({
        'params': init_rng,
        'dropout': init_rng
    }, jnp.ones(io_shape, jnp.float32), jnp.ones(io_shape, jnp.float32),
                                        jnp.ones(program_shape, jnp.float32))

    optimizer_def = optim.Adam(1e-3,
                               beta1=0.9,
                               beta2=0.98,
                               eps=1e-9,
                               weight_decay=1e-1)
    optimizer = optimizer_def.create(initial_variables['params'])

    checkpoint_fname = os.path.join(
        FLAGS.train_directory, 'train-{}/checkpoints/'
        'amt={},bsa={},ed=256,hd=512,l=0.001,nh=4,nl=3,s=0,ura={}/'.format(
            dataset_name, attention_mask_type, bos_special_attention,
            use_relative_attention))
    logging.info('Loading checkpoint: %s', checkpoint_fname)

    optimizer = checkpoints.restore_checkpoint(checkpoint_fname, optimizer)
    checkpoint_num_trained_steps = int(optimizer.state.step)
    logging.info('Found model checkpointed at step %s.',
                 checkpoint_num_trained_steps)
    optimizer = jax_utils.replicate(optimizer)

    return optimizer
Example No. 15
    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        max_action: float,
        discount: float = 0.99,
        tau: float = 0.005,
        policy_freq: int = 2,
        lr: float = 3e-4,
        entropy_tune: bool = True,
        seed: int = 0,
    ):
        self.rng = PRNGSequence(seed)

        actor_input_dim = (1, state_dim)

        actor_params = build_gaussian_policy_model(actor_input_dim, action_dim,
                                                   max_action, next(self.rng))
        actor_optimizer = optim.Adam(learning_rate=lr).create(actor_params)
        self.actor_optimizer = jax.device_put(actor_optimizer)

        init_rng = next(self.rng)

        critic_input_dim = [(1, state_dim), (1, action_dim)]

        critic_params = build_double_critic_model(critic_input_dim, init_rng)
        self.critic_target_params = build_double_critic_model(
            critic_input_dim, init_rng)
        critic_optimizer = optim.Adam(learning_rate=lr).create(critic_params)
        self.critic_optimizer = jax.device_put(critic_optimizer)

        self.entropy_tune = entropy_tune

        log_alpha_params = build_constant_model(-3.5, next(self.rng))
        log_alpha_optimizer = optim.Adam(
            learning_rate=lr).create(log_alpha_params)
        self.log_alpha_optimizer = jax.device_put(log_alpha_optimizer)
        self.target_entropy = -action_dim

        self.max_action = max_action
        self.discount = discount
        self.tau = tau
        self.policy_freq = policy_freq

        self.action_dim = action_dim

        self.total_it = 0
Example No. 16
def create_optimizer(model, learning_rate, weight_decay):
    optimizer_def = optim.Adam(learning_rate,
                               beta1=0.9,
                               beta2=0.98,
                               eps=1e-9,
                               weight_decay=weight_decay)
    optimizer = optimizer_def.create(model)
    return optimizer
Example No. 17
def create_optimizer(config, model):
    optimizer_def = optim.Adam(learning_rate=config.learning_rate,
                               beta1=0.9,
                               beta2=0.999,
                               eps=1e-6,
                               weight_decay=0.0)
    optimizer = optimizer_def.create(model)
    return optimizer
Example No. 18
def create_optimizer(model, learning_rate, weight_decay):
    """Instantiates Adam optimizer."""

    optimizer_def = optim.Adam(learning_rate=learning_rate,
                               weight_decay=weight_decay)
    optimizer = optimizer_def.create(model)

    return optimizer
Example No. 19
def create_optimizer(model, learning_rate):
    optimizer_def = optim.Adam(learning_rate,
                               beta1=0.9,
                               beta2=0.98,
                               eps=1e-9,
                               weight_decay=FLAGS.weight_decay)
    optimizer = optimizer_def.create(model)
    optimizer = jax_utils.replicate(optimizer)
    return optimizer
Example No. 20
 def init_fn(seed):
     rng = random.PRNGKey(seed)
     classifier = MLPClassifier.partial(hidden_layers=2,
                                        hidden_dim=512,
                                        n_classes=n_classes)
     _, initial_params = classifier.init_by_shape(rng,
                                                  [(128, *input_shape)])
     initial_model = nn.Model(classifier, initial_params)
     optimizer = optim.Adam(1e-4).create(initial_model)
     return optimizer
Example No. 21
def load_parameters(logdir, init_params):
  if has_checkpoint(logdir):
    print("Loading checkpoint from %s" % logdir)
    optimizer_def = optim.Adam()
    optimizer = optimizer_def.create(init_params)
    optimizer = checkpoints.restore_checkpoint(logdir, optimizer)
    print("Checkpoint loaded from step %d" % optimizer.state.step)
    return optimizer.target
  else:
    print("No checkpoint found in %s" % logdir)
    return None
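A small usage sketch, assuming an illustrative log directory and freshly initialized parameters; it falls back to the fresh parameters when no checkpoint is found.

import jax.numpy as jnp

init_params = {'dense': {'kernel': jnp.zeros((2, 2))}}      # illustrative params
restored = load_parameters('/tmp/run_logdir', init_params)  # hypothetical logdir
params = restored if restored is not None else init_params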
Example No. 22
 def create_model_optimizer(n_bins):
     ResNet50 = ResNet.partial(stage_sizes=[3, 4, 6, 3],
                               block_cls=ResNetBlock)
     module = ResNet50.partial(n_bins=n_bins, dtype=jnp.float32)
     input_shape = (1, training_data.shape[1], 1)
     with nn.stateful() as init_state:
         _, initial_params = module.init_by_shape(
             jax.random.PRNGKey(0), [(input_shape, jnp.float32)])
         model = nn.Model(module, initial_params)
     optimizer = optim.Adam(learning_rate=learning_rate).create(model)
     return model, optimizer
Example No. 23
def create_optimizer(config, model, initial_params):
    """Create a model, starting with a pre-trained checkpoint."""
    common_kwargs = dict(
        learning_rate=config.learning_rate,
        beta1=0.9,
        beta2=0.999,
        eps=1e-6,
    )
    optimizer_decay_def = optim.Adam(weight_decay=0.01, **common_kwargs)
    optimizer_no_decay_def = optim.Adam(weight_decay=0.0, **common_kwargs)
    decay = optim.ModelParamTraversal(lambda path, _: 'bias' not in path)
    no_decay = optim.ModelParamTraversal(lambda path, _: 'bias' in path)
    optimizer_def = optim.MultiOptimizer((decay, optimizer_decay_def),
                                         (no_decay, optimizer_no_decay_def))
    # TODO(marcvanzee): MultiOptimizer triggers double XLA compilation on TPU so
    # we use Adam here, but we should investigate why this happens.
    optimizer_def = optim.Adam(learning_rate=config.learning_rate)
    optimizer = optimizer_def.create(model)
    optimizer = optimizer.replicate()
    del model  # don't keep a copy of the initial model
    return optimizer
Example No. 24
def create_train_fn(model, model_dir, duration, batch, train_steps,
                    learning_rate):
    optimizer = optim.Adam()
    opt = optimizer.create(model)
    state = TrainState(optimizer=opt, step=0)  # pytype:disable=wrong-keyword-args
    state = checkpoints.restore_checkpoint(model_dir, state)
    state = jax_utils.replicate(state)
    iterator = None

    @functools.partial(jax.pmap, axis_name="batch")
    def train_step(obs, state):
        actions = obs["action"]
        rewards = obs["reward"]
        step = state.step
        optimizer = state.optimizer

        def loss(model):
            predictions = model(actions)
            l = (rewards - predictions)**2
            l = jnp.mean(l)
            return l

        grad_fn = jax.value_and_grad(loss)
        l, grads = grad_fn(state.optimizer.target)
        grads = lax.pmean(grads, axis_name="batch")
        new_optimizer = optimizer.apply_gradient(grads,
                                                 learning_rate=learning_rate)
        new_state = state.replace(step=step + 1, optimizer=new_optimizer)
        return new_state, l

    def train(data_path):
        nonlocal iterator
        nonlocal state

        if iterator is None:
            dataset = npz.load_dataset_from_directory(data_path, duration,
                                                      batch)
            iterator = dataset.make_one_shot_iterator()
            iterator = map(
                lambda x: jax.tree_map(
                    lambda leaf: np.reshape(
                        leaf,
                        (jax.local_device_count(), -1) + leaf.numpy().shape[1:]),
                    x),
                iterator)
            iterator = jax_utils.prefetch_to_device(iterator, 2)

        for _ in range(train_steps):
            obs = next(iterator)
            state, l = train_step(obs, state)
        local_state = get_first_device(state)
        l = get_first_device(l)
        checkpoints.save_checkpoint(model_dir, local_state, local_state.step)

    return train
Example No. 25
def make_optimizer(optimizer: str,
                   learning_rate: float,
                   weight_decay: float = 5.0e-4):
    if optimizer == 'SGD':
        return optim.GradientDescent(learning_rate=learning_rate)
    elif optimizer == 'Momentum':
        return optim.Momentum(learning_rate=learning_rate,
                              weight_decay=weight_decay)
    elif optimizer == 'Adam':
        return optim.Adam(learning_rate=learning_rate,
                          weight_decay=weight_decay)
    else:
        raise ValueError('Unknown optimizer spec.')
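A quick usage sketch of this factory with an illustrative parameter pytree.

import jax.numpy as jnp

params = {'conv': {'kernel': jnp.zeros((3, 3, 1, 8))}}  # illustrative pytree
optimizer = make_optimizer('Momentum', learning_rate=0.1).create(params)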
Example No. 26
  def test_init_state(self):
    params = onp.zeros((1,))
    optimizer_def = optim.Adam(learning_rate=0.1,
                               beta1=0.2,
                               beta2=0.9,
                               eps=0.01,
                               weight_decay=0.0)
    state = optimizer_def.init_state(params)

    expected_hyper_params = _AdamHyperParams(0.1, 0.2, 0.9, 0.01, 0.0)
    self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
    expected_state = optim.OptimizerState(
        0, _AdamParamState(onp.zeros((1,)), onp.zeros((1,))))
    self.assertEqual(state, expected_state)
Example No. 27
def main(argv):
    del argv

    # Make sure tf does not allocate gpu memory.
    tf.config.experimental.set_visible_devices([], 'GPU')

    rng = random.PRNGKey(0)
    rng, key = random.split(rng)

    ds_builder = tfds.builder('binarized_mnist')
    ds_builder.download_and_prepare()
    train_ds = ds_builder.as_dataset(split=tfds.Split.TRAIN)
    train_ds = train_ds.map(prepare_image)
    train_ds = train_ds.cache()
    train_ds = train_ds.repeat()
    train_ds = train_ds.shuffle(50000)
    train_ds = train_ds.batch(FLAGS.batch_size)
    train_ds = iter(tfds.as_numpy(train_ds))

    test_ds = ds_builder.as_dataset(split=tfds.Split.TEST)
    test_ds = test_ds.map(prepare_image).batch(10000)
    test_ds = np.array(list(test_ds)[0])
    test_ds = jax.device_put(test_ds)

    module = VAE.partial(latents=FLAGS.latents)
    _, params = module.init_by_shape(key, [(FLAGS.batch_size, 784)],
                                     z_rng=random.PRNGKey(0))
    vae = nn.Model(module, params)

    optimizer = optim.Adam(learning_rate=FLAGS.learning_rate).create(vae)
    optimizer = jax.device_put(optimizer)

    rng, z_key, eval_rng = random.split(rng, 3)
    z = random.normal(z_key, (64, FLAGS.latents))

    steps_per_epoch = 50000 // FLAGS.batch_size

    for epoch in range(FLAGS.num_epochs):
        for _ in range(steps_per_epoch):
            batch = next(train_ds)
            rng, key = random.split(rng)
            optimizer = train_step(optimizer, batch, key)

        metrics, comparison, sample = eval(optimizer.target, test_ds, z,
                                           eval_rng)
        save_image(comparison, f'results/reconstruction_{epoch}.png', nrow=8)
        save_image(sample, f'results/sample_{epoch}.png', nrow=8)

        print('eval epoch: {}, loss: {:.4f}, BCE: {:.4f}, KLD: {:.4f}'.format(
            epoch + 1, metrics['loss'], metrics['bce'], metrics['kld']))
Example No. 28
def train_model():
    """Train for a fixed number of steps and decode during training."""
    param = get_param(jax.random.PRNGKey(0))
    optimizer = optim.Adam(learning_rate=FLAGS.learning_rate).create(param)
    key = jax.random.PRNGKey(0)
    for step in range(FLAGS.num_train_steps):
        key, lstm_key = jax.random.split(key)
        batch = get_batch(FLAGS.batch_size)
        optimizer, metrics = train_step(optimizer, batch, lstm_key)
        if step % FLAGS.decode_frequency == 0:
            key, decode_key = jax.random.split(key)
            logging.info('train step: %d, loss: %.4f, accuracy: %.2f', step,
                         metrics['loss'], metrics['accuracy'] * 100)
            decode_batch(optimizer.target, 5, decode_key)
    return optimizer.target
Example No. 29
def create_optimizer(model, learning_rate):

    # def adam_optimizer(weight_decay):
    #     return optim.Adam(learning_rate=learning_rate, beta1=0.9,
    #         beta2=0.999, eps=1e-6, weight_decay=weight_decay)

    # optimizer_decay_def = adam_optimizer(weight_decay=0.01)
    # optimizer_no_decay_def = adam_optimizer(weight_decay=0.0)
    # decay = optim.ModelParamTraversal(lambda path, _: 'bias' not in path)
    # no_decay = optim.ModelParamTraversal(lambda path, _: 'bias' in path)
    # optimizer_def = optim.MultiOptimizer(
    #   (decay, optimizer_decay_def), (no_decay, optimizer_no_decay_def))
    optimizer_def = optim.Adam(learning_rate=learning_rate)
    optimizer = optimizer_def.create(model)
    return optimizer
Example No. 30
 def test_apply_gradient(self):
   optimizer_def = optim.Adam(learning_rate=0.1,
                              beta1=0.2,
                              beta2=0.9,
                              eps=0.01,
                              weight_decay=0.0)
   params = onp.array([1.])
   state = optim.OptimizerState(
       1, _AdamParamState(onp.array([0.1]), onp.array([0.9])))
   grads = onp.array([4.])
   new_params, new_state = optimizer_def.apply_gradient(
       optimizer_def.hyper_params, params, state, grads)
   expected_new_state = optim.OptimizerState(
       2, _AdamParamState(onp.array([3.22]), onp.array([2.41])))
   expected_new_params = onp.array([0.906085])
   onp.testing.assert_allclose(new_params, expected_new_params)
   self.assertEqual(new_state, expected_new_state)
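As a sanity check on where the expected values come from, here is a small numpy sketch of a single Adam step under the hyperparameters above, assuming flax's bias correction uses t = step + 1 = 2.

import numpy as onp

lr, beta1, beta2, eps = 0.1, 0.2, 0.9, 0.01
m, v, g, p, t = 0.1, 0.9, 4.0, 1.0, 2

m_new = beta1 * m + (1 - beta1) * g        # 3.22, matches the expected state
v_new = beta2 * v + (1 - beta2) * g ** 2   # 2.41, matches the expected state
m_hat = m_new / (1 - beta1 ** t)           # bias-corrected first moment
v_hat = v_new / (1 - beta2 ** t)           # bias-corrected second moment
p_new = p - lr * m_hat / (onp.sqrt(v_hat) + eps)
print(round(p_new, 6))                     # 0.906085, matches expected params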