示例#1
0
    def build_optimizer(self,
                        clip=15.0,
                        lr=5e-4,
                        warmup=2000,
                        cosine_decay_steps=None,
                        optimizer_name="adabelief") -> GradientTransformation:
        """Assemble the optax transformation chain used as the optimizer.

        The chain is built in this order: gradient preconditioner, learning
        rate scaling (optionally with warm-up), optional cosine decay factor,
        and finally an optional ``optax.clip`` on the resulting updates.

        Args:
            clip: Bound handed to ``optax.clip``. Falsy or non-positive values
                disable clipping.
            lr: Base learning rate. It is negated internally so that applying
                the updates minimizes the objective.
            warmup: Forwarded as ``warmup`` to
                ``util.linear_warmup_lr_schedule``. Falsy or non-positive
                values select a constant ``-lr`` scale instead.
            cosine_decay_steps: Number of steps for the cosine decay factor
                (which goes from 1.0 down to ``alpha=0.1``). Falsy or
                non-positive values disable the decay.
            optimizer_name: Either ``"adabelief"`` or ``"adam"``.

        Returns:
            The chained ``GradientTransformation``.

        Raises:
            ValueError: If ``optimizer_name`` is not a supported name.
        """
        if optimizer_name == "adabelief":
            chain = [util.scale_by_belief()]
        elif optimizer_name == "adam":
            chain = [optax.scale_by_adam()]
        else:
            # An explicit exception instead of ``assert 0``: asserts are
            # stripped under ``python -O``, which would silently yield a chain
            # without any preconditioner.
            raise ValueError(
                f"Unknown optimizer_name {optimizer_name!r}; "
                "expected 'adabelief' or 'adam'.")

        # Make sure to use the negative learning rate so that we minimize
        if warmup and warmup > 0:
            warmup_schedule = partial(util.linear_warmup_lr_schedule,
                                      warmup=warmup,
                                      lr_decay=1.0,
                                      lr=-lr)
            chain.append(optax.scale_by_schedule(warmup_schedule))
        else:
            chain.append(optax.scale(-lr))

        if cosine_decay_steps and cosine_decay_steps > 0:
            # Multiplicative factor applied on top of the learning rate above.
            cosine_lr = optax.cosine_decay_schedule(
                init_value=1.0, decay_steps=cosine_decay_steps, alpha=1e-1)
            chain.append(optax.scale_by_schedule(cosine_lr))

        if clip and clip > 0:
            chain.append(optax.clip(clip))

        return optax.chain(*chain)
示例#2
0
def create_learning_rate_fn(config,):
  """Build a linear warm-up followed by the decay chosen in `config.train`.

  The warm-up ramps from 0 to `config.train.lr_init` over
  `config.train.warmup_steps` steps. After that boundary the schedule named by
  `config.train.scheduler` takes over: "linear", "cosine" or "step" (halving
  every 50000 steps of the decay phase).

  Raises:
    NotImplementedError: for any other scheduler name.
  """
  train_cfg = config.train

  # Linear warmup
  ramp_up = optax.linear_schedule(
      init_value=0.,
      end_value=train_cfg.lr_init,
      transition_steps=train_cfg.warmup_steps)

  kind = train_cfg.scheduler
  if kind == "linear":
    after_warmup = optax.linear_schedule(
        init_value=train_cfg.lr_init,
        end_value=0.,
        transition_steps=train_cfg.max_steps - train_cfg.warmup_steps)
  elif kind == "cosine":
    # Clamp to at least one step so the cosine phase always has a length.
    after_warmup = optax.cosine_decay_schedule(
        init_value=train_cfg.lr_init,
        decay_steps=max(train_cfg.max_steps - train_cfg.warmup_steps, 1))
  elif kind == "step":
    # Computed but never read; kept so the same config fields are touched.
    step_steps = max(train_cfg.max_steps - train_cfg.warmup_steps, 1)  # pylint: disable=unused-variable

    def after_warmup(count):
      # Reads the config at call time, halving once per 50000 steps.
      return config.train.lr_init * (0.5**(count // 50000))
  else:
    raise NotImplementedError

  return optax.join_schedules(
      schedules=[ramp_up, after_warmup],
      boundaries=[train_cfg.warmup_steps])
示例#3
0
def get_learning_rate_schedule(
    total_batch_size, steps_per_epoch, total_steps, optimizer_config):
  """Build the learning rate schedule function.

  Args:
    total_batch_size: Batch size passed to `_get_batch_scaled_lr` for every
      learning-rate value that is batch scaled.
    steps_per_epoch: Steps in one epoch; only used to convert
      `cosine_decay_kwargs.warmup_epochs` into steps.
    total_steps: Total number of training steps. Step-decay boundaries are
      given as fractions of this value.
    optimizer_config: Config object providing `base_lr`, `scale_by_batch`,
      `schedule_type` and the matching `*_kwargs` sub-config.

  Returns:
    An optax schedule for one of the supported types: 'steps', 'cosine' or
    'constant_cosine'.

  Raises:
    ValueError: If `optimizer_config.schedule_type` is not supported.
  """
  base_lr = _get_batch_scaled_lr(total_batch_size, optimizer_config.base_lr,
                                 optimizer_config.scale_by_batch)

  schedule_type = optimizer_config.schedule_type
  if schedule_type == 'steps':
    # sorted() returns a new list. An in-place .sort() would reorder the
    # sequence stored in the caller's config (and fail on tuples).
    boundaries = sorted(optimizer_config.step_decay_kwargs.decay_boundaries)

    decay_rate = optimizer_config.step_decay_kwargs.decay_rate
    # Each boundary is a fraction of the run; convert it to a step index.
    boundaries_and_scales = {
        int(boundary * total_steps): decay_rate for boundary in boundaries}
    schedule_fn = optax.piecewise_constant_schedule(
        init_value=base_lr, boundaries_and_scales=boundaries_and_scales)
  elif schedule_type == 'cosine':
    warmup_steps = (optimizer_config.cosine_decay_kwargs.warmup_epochs
                    * steps_per_epoch)
    # Batch scale the other lr values as well:
    init_value = _get_batch_scaled_lr(
        total_batch_size,
        optimizer_config.cosine_decay_kwargs.init_value,
        optimizer_config.scale_by_batch)
    end_value = _get_batch_scaled_lr(
        total_batch_size,
        optimizer_config.cosine_decay_kwargs.end_value,
        optimizer_config.scale_by_batch)

    schedule_fn = optax.warmup_cosine_decay_schedule(
        init_value=init_value,
        peak_value=base_lr,
        warmup_steps=warmup_steps,
        decay_steps=total_steps,
        end_value=end_value)
  elif schedule_type == 'constant_cosine':
    # Convert end_value to alpha, used by cosine_decay_schedule.
    # NOTE(review): unlike the 'cosine' branch, end_value is not batch scaled
    # here - confirm this is intended.
    alpha = optimizer_config.constant_cosine_decay_kwargs.end_value / base_lr

    # Number of steps spent in constant phase.
    constant_steps = int(
        optimizer_config.constant_cosine_decay_kwargs.constant_fraction
        * total_steps)
    decay_steps = total_steps - constant_steps

    constant_phase = optax.constant_schedule(value=base_lr)
    decay_phase = optax.cosine_decay_schedule(
        init_value=base_lr,
        decay_steps=decay_steps,
        alpha=alpha)
    schedule_fn = optax.join_schedules(
        schedules=[constant_phase, decay_phase],
        boundaries=[constant_steps])
  else:
    raise ValueError(f'Unknown learning rate schedule: {schedule_type}')

  return schedule_fn
示例#4
0
def create_learning_rate_fn(workload: spec.Workload,
                            hparams: spec.Hyperparameters):
    """Return a linear warm-up followed by a cosine decay.

    The rate ramps from 0 to ``hparams.learning_rate`` over
    ``hparams.warmup_steps`` steps, then follows a cosine decay lasting
    ``workload.step_hint - hparams.warmup_steps`` steps.
    """
    peak_lr = hparams.learning_rate
    n_warmup = hparams.warmup_steps

    ramp = optax.linear_schedule(
        init_value=0., end_value=peak_lr, transition_steps=n_warmup)
    decay = optax.cosine_decay_schedule(
        init_value=peak_lr, decay_steps=workload.step_hint - n_warmup)

    # The decay schedule takes over once the warm-up boundary is reached.
    return optax.join_schedules(schedules=[ramp, decay],
                                boundaries=[n_warmup])
示例#5
0
def get_cosine_schedule(
    max_learning_rate: float,
    total_steps: int,
    warmup_steps: int = 0) -> optax.Schedule:
  """Builds a cosine decay schedule with initial warm-up.

  Args:
    max_learning_rate: Peak rate reached at the end of the warm-up and used as
      the starting value of the cosine decay.
    total_steps: Total number of steps, warm-up included.
    warmup_steps: Number of linear warm-up steps from 0 to the peak rate.

  Returns:
    The warm-up schedule alone when no steps remain after the warm-up
    (`total_steps <= warmup_steps`); otherwise the warm-up joined with a
    cosine decay over the remaining `total_steps - warmup_steps` steps.
  """
  warmup = optax.linear_schedule(init_value=0., end_value=max_learning_rate,
                                 transition_steps=warmup_steps)
  # `<=` rather than `<`: when both counts are equal the decay phase would have
  # zero steps, which `optax.cosine_decay_schedule` rejects.
  if total_steps <= warmup_steps:
    return warmup
  decay = optax.cosine_decay_schedule(init_value=max_learning_rate,
                                      decay_steps=total_steps - warmup_steps)
  return optax.join_schedules([warmup, decay], [warmup_steps])
示例#6
0
def create_learning_rate_fn(hparams: spec.Hyperparamters, steps_per_epoch: int):
  """Return a linear warm-up followed by a cosine decay, counted in epochs.

  The peak rate is `hparams.learning_rate` scaled by the 'imagenet' batch size
  divided by 256. The cosine phase lasts at least one epoch.
  """
  peak_lr = hparams.learning_rate * get_batch_size('imagenet') / 256.
  warmup_steps = hparams.warmup_epochs * steps_per_epoch
  decay_epochs = max(hparams.num_epochs - hparams.warmup_epochs, 1)

  phases = [
      optax.linear_schedule(
          init_value=0., end_value=peak_lr, transition_steps=warmup_steps),
      optax.cosine_decay_schedule(
          init_value=peak_lr, decay_steps=decay_epochs * steps_per_epoch),
  ]
  return optax.join_schedules(schedules=phases, boundaries=[warmup_steps])
示例#7
0
def create_learning_rate_fn(config: ml_collections.ConfigDict,
                            base_learning_rate: float, steps_per_epoch: int):
    """Return a warm-up + cosine schedule built from epoch counts in ``config``.

    The rate rises linearly from 0 to ``base_learning_rate`` during
    ``config.warmup_epochs`` epochs, then decays along a cosine for the
    remaining ``config.num_epochs - config.warmup_epochs`` epochs (at least
    one).
    """
    n_warmup_steps = config.warmup_epochs * steps_per_epoch
    ramp_fn = optax.linear_schedule(init_value=0.,
                                    end_value=base_learning_rate,
                                    transition_steps=n_warmup_steps)

    # Never allow a zero-length decay phase.
    n_decay_epochs = max(config.num_epochs - config.warmup_epochs, 1)
    n_decay_steps = n_decay_epochs * steps_per_epoch
    decay_fn = optax.cosine_decay_schedule(init_value=base_learning_rate,
                                           decay_steps=n_decay_steps)

    return optax.join_schedules(schedules=[ramp_fn, decay_fn],
                                boundaries=[n_warmup_steps])
示例#8
0
 def build_optimizer(lr,
                     momentum,
                     steps_per_epoch,
                     n_epochs,
                     nesterov,
                     warmup_epochs=5):
     """Build SGD with momentum, scaled by a warm-up/cosine multiplier.

     The multiplier is the minimum of a cosine decay from 1 (over
     ``n_epochs * steps_per_epoch`` steps) and a linear ramp from 0 to 1 (over
     ``warmup_epochs * steps_per_epoch`` steps).

     Args:
         lr: Peak learning rate passed to ``optax.sgd``.
         momentum: Momentum passed to ``optax.sgd``.
         steps_per_epoch: Number of optimizer steps in one epoch.
         n_epochs: Number of epochs covered by the cosine decay.
         nesterov: Whether ``optax.sgd`` uses Nesterov momentum.
         warmup_epochs: Length of the linear warm-up in epochs. A value that
             yields no warm-up steps disables the warm-up entirely.

     Returns:
         The chained optax optimizer.
     """
     cosine_schedule = optax.cosine_decay_schedule(1,
                                                   decay_steps=n_epochs *
                                                   steps_per_epoch,
                                                   alpha=1e-10)
     warmup_steps = warmup_epochs * steps_per_epoch
     if warmup_steps > 0:
         warmup_schedule = optax.polynomial_schedule(
             init_value=0.0,
             end_value=1.0,
             power=1,
             transition_steps=warmup_steps,
         )

         def schedule(x):
             return jnp.minimum(cosine_schedule(x), warmup_schedule(x))
     else:
         # With no transition steps polynomial_schedule stays at init_value
         # (0.0) forever, so the minimum would pin the learning rate at zero.
         # Use the cosine multiplier on its own instead.
         schedule = cosine_schedule

     optimizer = optax.sgd(lr, momentum, nesterov=nesterov)
     optimizer = optax.chain(optimizer, optax.scale_by_schedule(schedule))
     return optimizer
示例#9
0
    def test_loop_over_loader(self, variant):
        """Scan ``loop_over_loader`` over one shuffled epoch and check the loss.

        Builds an AdamW optimizer with a cosine-decayed learning rate, wraps
        ``loop_over_loader`` with the supplied ``variant`` and runs it with
        ``jax.lax.scan`` over the batched training data. The final carried
        loss is compared with a recorded reference value.

        Args:
            variant: Callable applied to the partially-bound
                ``loop_over_loader`` before it is scanned.
        """
        x_train, _ = self.dataset.data['train']
        batch_size = 2
        epochs = 1

        # set up optimizer and lr scheduler
        scheduler = 'Cosine'
        optimizer_params = {'learning_rate': 1e-3, 'weight_decay': 1e-4}
        lr_mult = 1.0
        steps_per_epoch = float(x_train.shape[0] / batch_size)
        decay_steps = int((epochs + 1) * steps_per_epoch)
        # print() does not interpolate logging-style "%s" arguments, so the
        # values are formatted explicitly.
        print(f'steps_per_epoch: {steps_per_epoch}')
        print(f'decay_steps: {decay_steps}')
        cosine_scheduler_fn = optax.cosine_decay_schedule(
            init_value=optimizer_params['learning_rate'],
            decay_steps=decay_steps)
        optimizer_params['learning_rate'] = cosine_scheduler_fn
        print(f'optimizer_params: {optimizer_params}')
        optim = optax.adamw(**optimizer_params)
        optim_state = optim.init(self.sim)

        loop_over_loader_partial = functools.partial(loop_over_loader,
                                                     optim=optim,
                                                     rollout_fn=rollout,
                                                     scheduler=scheduler)

        variant_loop_over_loader = variant(loop_over_loader_partial)

        prng_key = jax.random.PRNGKey(0)
        x, y, prng_key = get_shuffled_and_batched_data(self.dataset,
                                                       batch_size, 'train',
                                                       prng_key)

        # Carry is (model, optimizer state, lr multiplier, loss).
        (self.sim, optim_state, lr_mult,
         loss), _ = jax.lax.scan(variant_loop_over_loader,
                                 (self.sim, optim_state, lr_mult, 0.), (x, y))

        print('test_loop_over_loader loss:' + str(loss))
        self.assertTrue(jnp.allclose(float(loss), 0.89504164))
示例#10
0
def train_controller(
    controller,
    sim,
    pip_feed="parallel",  # or "sequential"
    mode="multipip",  # or "singular"
    duration=0.87,
    dt=0.03,
    epochs=100,
    use_noise=False,
    optimizer=optax.adamw,
    optimizer_params={
        "learning_rate": 1e-3,
        "weight_decay": 1e-4
    },
    loss_fn=lambda x, y: (jnp.abs(x - y)).mean(),
    scheduler="Cosine",
    tensorboard_dir=None,
    model_parameters={},  # used for tensorboard
    print_loss=1,
):
    """Train a controller against a simulator by differentiating rollouts.

    Args:
        controller: Controller to optimize; it is the object handed to
            ``optim.init`` and ``optax.apply_updates``.
        sim: Simulator passed through to the rollout functions.
        pip_feed: ``"parallel"`` does one update per epoch over all PIP values
            via ``rollout_parallel``; ``"sequential"`` does one update per PIP
            value via ``rollout``.
        mode: ``"multipip"`` trains on PIPs 10..35 in steps of 5;
            ``"singular"`` trains on PIP 35 only.
        duration: End of the time grid.
        dt: Step used to derive the number of grid points
            (``int(duration / dt)``).
        epochs: Number of training epochs.
        use_noise: Forwarded to the rollout functions.
        optimizer: Factory called as ``optimizer(**params)``.
        optimizer_params: Keyword arguments for ``optimizer``. The dict is
            deep-copied, so it is never modified.
        loss_fn: Loss forwarded to the rollout functions.
        scheduler: ``"Cosine"`` replaces the learning rate with a cosine decay
            over all updates; any other value keeps the learning rate given in
            ``optimizer_params``.
        tensorboard_dir: When not ``None``, scores are written under
            ``tensorboard_dir + str(model_parameters)``.
        model_parameters: Hyper-parameters logged to the summary writer.
        print_loss: Evaluate and print the score every ``print_loss`` epochs.

    Returns:
        A tuple ``(controller, per_step_loss, score)`` holding the trained
        controller, the last rollout loss divided by the number of time steps,
        and the last ``test_controller`` score. The last two are ``None`` when
        ``epochs`` is 0.

    Raises:
        ValueError: If ``mode`` or ``pip_feed`` is not a supported value.
    """
    peep = 5
    if mode == "multipip":
        pips = [10, 15, 20, 25, 30, 35]
    elif mode == "singular":
        pips = [35]
    else:
        raise ValueError(
            f"Unknown mode {mode!r}; expected 'multipip' or 'singular'.")

    if pip_feed == "parallel":
        steps_per_epoch = 1
    elif pip_feed == "sequential":
        steps_per_epoch = len(pips)
    else:
        raise ValueError(
            f"Unknown pip_feed {pip_feed!r}; "
            "expected 'parallel' or 'sequential'.")

    # setup optimizer
    optim_params = copy.deepcopy(optimizer_params)
    if scheduler == "Cosine":
        decay_steps = int(epochs * steps_per_epoch)
        print("steps_per_epoch:" + str(steps_per_epoch))
        print("decay_steps:" + str(decay_steps))
        cosine_scheduler_fn = optax.cosine_decay_schedule(
            init_value=optim_params["learning_rate"], decay_steps=decay_steps)
        optim_params["learning_rate"] = cosine_scheduler_fn
        print("optim_params:" + str(optim_params))
    # Built for every scheduler value, so `optim` is always bound below.
    optim = optimizer(**optim_params)
    optim_state = optim.init(controller)

    # setup Tensorboard writer
    if tensorboard_dir is not None:
        trial_name = str(model_parameters)
        write_path = tensorboard_dir + trial_name
        summary_writer = metric_writers.create_default_writer(
            logdir=write_path, just_logging=jax.process_index() != 0)
        # summary_writer = tensorboard.SummaryWriter(write_path)
        summary_writer.write_hparams(model_parameters)

    tt = jnp.linspace(0, duration, int(duration / dt))
    losses = []
    # Defined up front so the return statement is valid even with epochs == 0.
    per_step_loss = None
    score = None
    for epoch in range(epochs):
        if pip_feed == "parallel":
            value, grad = jax.value_and_grad(rollout_parallel)(controller, sim,
                                                               tt, use_noise,
                                                               peep,
                                                               jnp.array(pips),
                                                               loss_fn)
            updates, optim_state = optim.update(grad, optim_state, controller)
            controller = optax.apply_updates(controller, updates)
            per_step_loss = value / len(tt)
            losses.append(per_step_loss)
            if epoch % print_loss == 0:
                # make new controller with trained parameters and normal clamp
                score = test_controller(controller, sim, pips, peep)
                print(f"Epoch: {epoch}\tLoss: {score:.2f}")
                if tensorboard_dir is not None:
                    summary_writer.write_scalars(epoch, {"score": score})
        else:  # pip_feed == "sequential"
            for pip in pips:
                value, grad = jax.value_and_grad(rollout)(controller, sim, tt,
                                                          use_noise, peep,
                                                          pip, loss_fn,
                                                          jnp.array(0.))
                updates, optim_state = optim.update(grad, optim_state,
                                                    controller)
                controller = optax.apply_updates(controller, updates)
                per_step_loss = value / len(tt)
                losses.append(per_step_loss)
                if epoch % print_loss == 0:
                    # make new controller with trained parameters and normal clamp
                    score = test_controller(controller, sim, pips, peep)
                    print(f"Epoch: {epoch}, pip: {pip}\tLoss: {score:.2f}")
                    if tensorboard_dir is not None:
                        # NOTE(review): the value logged under "per_step_loss"
                        # is the test score, not per_step_loss - confirm the
                        # intended tag.
                        summary_writer.write_scalars(epoch,
                                                     {"per_step_loss": score})
    return controller, per_step_loss, score
示例#11
0
def train_simulator(
    dataset,
    model,
    num_boundary_models,
    activation_fn_name,
    R,
    C,
    # idx 0 to num_boundary_models-1 are boundary models,
    # idx num_boundary_models is default_model
    train_key="train",
    test_key="test",
    batch_size=512,
    epochs=500,
    optimizer=optax.adamw,
    optimizer_params={
        "learning_rate": 1e-3,
        "weight_decay": 1e-4
    },
    patience=10,
    lr_decay_factor=0.1,
    scheduler="ReduceLROnPlateau",  # or "Cosine"
    loss_fn=lambda x, y: (jnp.abs(x - y)).mean(),
    print_loss=10,
    use_tensorboard=False,
    mode="train",
    user_name="alexjyu-brain",
    tb_dir=None,
):
  """Train a simulator model and return it with its last test loss.

  Each epoch reshuffles and batches the training split, then scans
  `loop_over_loader` over the batches with `jax.lax.scan`. Every `print_loss`
  epochs the model is evaluated on the full train and test splits.

  Args:
    dataset: Object whose `data` dict holds `(inputs, targets)` pairs. The
      "train" and "test" entries are converted to jnp arrays in place.
    model: Model to train; it is the object handed to `optim.init`.
    num_boundary_models: Not read in this function body.
    activation_fn_name: Stored in the hyper-parameters logged to tensorboard.
    R: Not read in this function body.
    C: Not read in this function body.
    train_key: Key of the split used for training and train-loss evaluation.
    test_key: Key of the split used for test-loss evaluation.
    batch_size: Batch size used when batching the training split.
    epochs: Training runs for `epochs + 1` epochs (0 through `epochs`).
    optimizer: Factory called as `optimizer(**params)`.
    optimizer_params: Keyword arguments for `optimizer`. A copy is taken, so
      neither the caller's dict nor the default is modified.
    patience: Number of consecutive epochs with rising loss that triggers a
      learning-rate reduction ("ReduceLROnPlateau" only).
    lr_decay_factor: Factor applied to `lr_mult` on each reduction.
    scheduler: "ReduceLROnPlateau" or "Cosine".
    loss_fn: Not read in this function body.
    print_loss: Evaluation and logging period in epochs.
    use_tensorboard: Whether to log hyper-parameters and losses.
    mode: Must be "train" when `use_tensorboard` is set.
    user_name: Not read in this function body.
    tb_dir: Prefix of the tensorboard log path.

  Returns:
    A tuple `(model, test_loss)` with the trained model and the test loss from
    the most recent evaluation.

  Raises:
    ValueError: If `scheduler` is unknown, or if `use_tensorboard` is set with
      a `mode` other than "train".
  """
  # evaluate on these at end of epoch
  for key in ["train", "test"]:
    dataset.data[key] = (jnp.array(dataset.data[key][0]),
                         jnp.array(dataset.data[key][1]))
  X_train, y_train = dataset.data[train_key]
  X_test, y_test = dataset.data[test_key]

  # set up optimizer and lr scheduler
  # Work on a copy: the "Cosine" branch replaces the learning rate with a
  # schedule function, which must not leak into the caller's dict or into the
  # shared default argument (that would break the next call).
  optim_params = dict(optimizer_params)
  lr_mult = 1.0
  if scheduler == "ReduceLROnPlateau":
    patience_cnt = 0
    prev_loss = float("inf")
  elif scheduler == "Cosine":
    steps_per_epoch = float(X_train.shape[0] / batch_size)
    decay_steps = int((epochs + 1) * steps_per_epoch)
    logging.info("steps_per_epoch: %s", str(steps_per_epoch))
    logging.info("decay_steps: %s", str(decay_steps))
    cosine_scheduler_fn = optax.cosine_decay_schedule(
        init_value=optim_params["learning_rate"], decay_steps=decay_steps)
    optim_params["learning_rate"] = cosine_scheduler_fn
    logging.info("optimizer_params: %s", str(optim_params))
  else:
    raise ValueError(
        f"Unknown scheduler {scheduler!r}; "
        "expected 'ReduceLROnPlateau' or 'Cosine'.")
  optim = optimizer(**optim_params)
  optim_state = optim.init(model)

  loop_over_loader_partial = functools.partial(
      loop_over_loader, optim=optim, rollout_fn=rollout, scheduler=scheduler)

  # Tensorboard writer
  if use_tensorboard:
    if mode != "train":
      # Only the "train" mode defines a log file name.
      raise ValueError(
          f"Tensorboard logging is only supported for mode='train', "
          f"got {mode!r}.")
    config = copy.deepcopy(model.default_model_parameters)
    del config["activation_fn"]
    config["activation_fn_name"] = activation_fn_name

    file_name = str(config)
    write_path = tb_dir + file_name
    summary_writer = metric_writers.create_default_writer(
        logdir=write_path, just_logging=jax.process_index() != 0)
    # NOTE(review): this overwrites the writer created just above - confirm
    # which of the two is intended and that it provides write_hparams /
    # write_scalars.
    summary_writer = tensorboard.SummaryWriter(write_path)
    summary_writer.write_hparams(dict(config))

  # Main Training Loop
  prng_key = jax.random.PRNGKey(0)
  for epoch in range(epochs + 1):
    if epoch % 10 == 0:
      logging.info("epoch: %s", str(epoch))
    X, y, prng_key = get_shuffled_and_batched_data(dataset, batch_size,
                                                   train_key, prng_key)
    if epoch == 0:
      logging.info("X.shape: %s", str(X.shape))
      logging.info("y.shape: %s", str(y.shape))

    # Carry is (model, optimizer state, lr multiplier, loss); one scan step
    # per batch.
    (model, optim_state, lr_mult,
     loss), _ = jax.lax.scan(loop_over_loader_partial,
                             (model, optim_state, lr_mult, 0.), (X, y))

    if scheduler == "ReduceLROnPlateau":
      if loss > prev_loss:
        patience_cnt = patience_cnt + 1
      else:
        patience_cnt = 0
      if patience_cnt == patience:
        lr_mult = lr_mult * lr_decay_factor
        patience_cnt = 0
      prev_loss = loss

    if epoch % print_loss == 0:
      if scheduler == "ReduceLROnPlateau":
        logging.info("loss: %s", str(loss))
        logging.info("prev_loss: %s", str(prev_loss))
        logging.info("patience_cnt: %s", str(patience_cnt))
        logging.info("lr_mult: %s", str(lr_mult))
      # expensive end-of-epoch eval, just for intuition
      train_loss = map_rollout_over_batch(model, (X_train, y_train), rollout)
      # cross-validation
      test_loss = map_rollout_over_batch(model, (X_test, y_test), rollout)

      logging.info(
          f"Epoch {epoch:2d}: train={train_loss.item():.5f}, test_loss={test_loss.item():.5f}"
      )
      logging.info("-----------------------------------")
      if use_tensorboard:
        summary_writer.write_scalars(epoch, {"train_loss": train_loss})
        summary_writer.write_scalars(epoch, {"test_loss": test_loss})
  if use_tensorboard:
    summary_writer.flush()
  logging.info("finished looping over epochs")
  return model, test_loss
示例#12
0
def train(base_dir, config):
    """Run the training loop described by ``config`` and return the final state.

    Builds the dataset, an ``MLPEncoder`` and the module class named in
    ``config.train.module.name``, then optimizes with Adam under a cosine
    learning-rate decay spanning ``config.num_train_steps`` steps.
    Checkpoints are handled by a ``checkpoint.Checkpoint`` rooted at
    ``base_dir / 'train'``; a SIGTERM handler saves a checkpoint as well.

    Args:
        base_dir: Base output directory; must support the ``/`` operator
            (presumably a ``pathlib.Path`` - confirm against the caller).
        config: Configuration object. Fields read here: ``seed``,
            ``num_tasks``, ``encoder``, ``train``, ``learning_rate``,
            ``num_train_steps`` and ``log_metrics_every``.

    Returns:
        The train state after the last step, which is also saved as a final
        checkpoint.
    """
    print(config)
    chkpt_manager = checkpoint.Checkpoint(str(base_dir / 'train'))

    writer = create_default_writer()

    # Initialize dataset
    key = jax.random.PRNGKey(config.seed)
    key, subkey = jax.random.split(key)
    ds = dataset.get_dataset(config, subkey, num_tasks=config.num_tasks)
    ds_iter = iter(ds)

    # Fresh subkey; it is used further down to initialize the module params.
    key, subkey = jax.random.split(key)
    encoder = MLPEncoder(**config.encoder)

    # 'method' and 'module' are popped off; whatever remains in train_config
    # is bound as keyword arguments of the train step below.
    train_config = config.train.to_dict()
    train_method = train_config.pop('method')

    module_config = train_config.pop('module')
    module_class = module_config.pop('name')

    # Both the module class and the train step are looked up by name in this
    # file's module globals.
    module = globals().get(module_class)(encoder, **module_config)
    train_step = globals().get(f'train_step_{train_method}')
    train_step = functools.partial(train_step, **train_config)

    # The first element of one dataset batch serves as the example input for
    # parameter initialization.
    params = module.init(subkey, next(ds_iter)[0])
    lr = optax.cosine_decay_schedule(config.learning_rate,
                                     config.num_train_steps)
    optim = optax.chain(optax.adam(lr),
                        # optax.adaptive_grad_clip(0.15)
                        )

    state = TrainState.create(apply_fn=module.apply, params=params, tx=optim)
    # Either picks up a previously saved state or keeps the fresh one.
    state = chkpt_manager.restore_or_initialize(state)

    # Hooks
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_train_steps, writer=writer)
    hooks = [
        report_progress,
        periodic_actions.Profile(num_profile_steps=5, logdir=str(base_dir))
    ]

    def handle_preemption(signal_number, _):
        # `state` is resolved when the handler runs, so the checkpoint holds
        # whatever `state` is bound to at that moment.
        logging.info('Received signal %d, saving checkpoint.', signal_number)
        with report_progress.timed('checkpointing'):
            chkpt_manager.save(state)
        logging.info('Finished saving checkpoint.')

    signal.signal(signal.SIGTERM, handle_preemption)

    metrics = TrainMetrics.empty()
    with metric_writers.ensure_flushes(writer):
        # Starts at state.step, i.e. at the step stored in the state returned
        # by restore_or_initialize.
        for step in tqdm.tqdm(range(state.step, config.num_train_steps)):
            with jax.profiler.StepTraceAnnotation('train', step_num=step):
                states, targets = next(ds_iter)
                state, metrics = train_step(state, metrics, states, targets)

            logging.log_first_n(logging.INFO, 'Finished training step %d', 5,
                                step)

            # Metrics are written and then reset every log_metrics_every steps.
            if step % config.log_metrics_every == 0:
                writer.write_scalars(step, metrics.compute())
                metrics = TrainMetrics.empty()

            # if step % config.log_eval_metrics_every == 0 and isinstance(
            #     ds, dataset.MDPDataset):
            #   eval_metrics = evaluate_mdp(state, ds.aux_task_matrix, config)
            #   writer.write_scalars(step, eval_metrics.compute())

            for hook in hooks:
                hook(step)

    # Final checkpoint after the loop has finished.
    chkpt_manager.save(state)
    return state
示例#13
0
def evaluate(base_dir, config, *, train_state):
    """Fit a ``LinearModule`` on top of ``train_state`` and log a test loss.

    A ``LinearModule`` with ``config.eval.num_tasks`` outputs is optimized
    with Adam under a cosine learning-rate decay for ``config.num_eval_steps``
    steps via ``evaluate_step``. Afterwards an L2 loss is computed on a batch
    from ``dataset.EvalDataset`` and written as ``test_loss``. Checkpoints are
    handled by a ``checkpoint.Checkpoint`` rooted at ``base_dir / 'eval'``.

    Args:
        base_dir: Base output directory; must support the ``/`` operator
            (presumably a ``pathlib.Path`` - confirm against the caller).
        config: Configuration object. Fields read here: ``eval.seed``,
            ``eval.num_tasks``, ``eval.learning_rate``,
            ``encoder.embedding_dim``, ``num_eval_steps`` and
            ``log_metrics_every``.
        train_state: Keyword-only. State whose ``apply_fn`` and ``params``
            produce the ``phi`` outputs that the linear module consumes.

    Returns:
        None. The function has no return statement; results are emitted
        through the metric writer.
    """
    chkpt_manager = checkpoint.Checkpoint(str(base_dir / 'eval'))

    writer = create_default_writer()

    key = jax.random.PRNGKey(config.eval.seed)
    model_init_key, ds_key = jax.random.split(key)

    linear_module = LinearModule(config.eval.num_tasks)
    # A zero vector of length embedding_dim serves as the example input for
    # parameter initialization.
    params = linear_module.init(model_init_key,
                                jnp.zeros((config.encoder.embedding_dim, )))
    lr = optax.cosine_decay_schedule(config.eval.learning_rate,
                                     config.num_eval_steps)
    optim = optax.adam(lr)

    ds = dataset.get_dataset(config, ds_key, num_tasks=config.eval.num_tasks)
    ds_iter = iter(ds)

    state = TrainState.create(apply_fn=linear_module.apply,
                              params=params,
                              tx=optim)
    # Either picks up a previously saved state or keeps the fresh one.
    state = chkpt_manager.restore_or_initialize(state)

    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_eval_steps, writer=writer)
    hooks = [
        report_progress,
        periodic_actions.Profile(num_profile_steps=5, logdir=str(base_dir))
    ]

    def handle_preemption(signal_number, _):
        # `state` is resolved when the handler runs, so the checkpoint holds
        # whatever `state` is bound to at that moment.
        logging.info('Received signal %d, saving checkpoint.', signal_number)
        with report_progress.timed('checkpointing'):
            chkpt_manager.save(state)
        logging.info('Finished saving checkpoint.')

    signal.signal(signal.SIGTERM, handle_preemption)

    metrics = EvalMetrics.empty()
    with metric_writers.ensure_flushes(writer):
        # Starts at state.step, i.e. at the step stored in the state returned
        # by restore_or_initialize.
        for step in tqdm.tqdm(range(state.step, config.num_eval_steps)):
            with jax.profiler.StepTraceAnnotation('eval', step_num=step):
                states, targets = next(ds_iter)
                state, metrics = evaluate_step(train_state, state, metrics,
                                               states, targets)

            # Metrics are written and then reset every log_metrics_every steps.
            if step % config.log_metrics_every == 0:
                writer.write_scalars(step, metrics.compute())
                metrics = EvalMetrics.empty()

            for hook in hooks:
                hook(step)

        # Finally, evaluate on the true(ish) test aux task matrix.
        # Note that the same ds_key used for the dataset above is reused here.
        states, targets = dataset.EvalDataset(config, ds_key).get_batch()

        @jax.jit
        def loss_fn():
            # Features come from train_state; the fitted linear module is
            # vmapped over them (axis 0) with shared parameters.
            outputs = train_state.apply_fn(train_state.params, states)
            phis = outputs.phi
            predictions = jax.vmap(state.apply_fn,
                                   in_axes=(None, 0))(state.params, phis)
            return jnp.mean(optax.l2_loss(predictions, targets))

        test_loss = loss_fn()
        # Written one step past the last eval step.
        writer.write_scalars(config.num_eval_steps + 1,
                             {'test_loss': test_loss})
示例#14
0
 def _create_jax_schedule(self):
     """Return the optax cosine decay schedule configured on this object.

     Reads ``initial_rate``, ``decay_steps`` and ``alpha`` from ``self``.
     optax is imported on use rather than at module import time.
     """
     import optax

     schedule_kwargs = dict(
         init_value=self.initial_rate,
         decay_steps=self.decay_steps,
         alpha=self.alpha,
     )
     return optax.cosine_decay_schedule(**schedule_kwargs)