# Example #1
def main(argv, experiment_class: experiment.AbstractExperiment):
  """Runs the experiment, restoring/saving model state around platform.main.

  Args:
    argv: Command-line arguments forwarded to `platform.main`.
    experiment_class: The concrete `AbstractExperiment` subclass to run.
  """
  # Restore previously saved state, if a restore path was configured.
  checkpoint_to_restore = FLAGS.config.restore_path
  if checkpoint_to_restore:
    _restore_state_to_in_memory_checkpointer(checkpoint_to_restore)

  # Decide how (or whether) to persist the model when training stops.
  model_dir = os.path.join(FLAGS.config.checkpoint_dir, 'models')
  if FLAGS.config.one_off_evaluate:
    # One-off evaluation: nothing worth checkpointing.
    def persist_model_fn():
      return None
  else:
    persist_model_fn = functools.partial(
        _save_state_from_in_memory_checkpointer, model_dir, experiment_class)
  _setup_signals(persist_model_fn)  # Save on Ctrl+C (continue) or Ctrl+\ (exit).

  try:
    platform.main(experiment_class, argv)
  finally:
    persist_model_fn()  # Save at the end of training or in case of exception.
      # NOTE(review): this `return` is the tail of a predicate (`pred_gb`,
      # judging by its use two lines below) whose `def` line precedes this
      # chunk; it matches scale/offset/gain/bias-style parameter names --
      # confirm against the full file.
      return (name in ['scale', 'offset', 'b']
              or 'gain' in name or 'bias' in name)
    # Split parameters into gains/biases vs. the remaining weights so each
    # group can receive different optimizer settings below.
    gains_biases, weights = hk.data_structures.partition(pred_gb, self._params)
    def pred_fc(mod, name, val):
      # Predicate: the module path contains 'linear' but is not inside a
      # squeeze-excite block.
      del name, val  # Only the module path is needed.
      return 'linear' in mod and 'squeeze_excite' not in mod
    # Peel the fully-connected weights off the remaining weights.
    fc_weights, weights = hk.data_structures.partition(pred_fc, weights)
    # Lr schedule with batch-based LR scaling
    if self.config.lr_scale_by_bs:
      # Linear scaling rule: peak LR scales with batch size relative to 256.
      max_lr = (self.config.lr * self.config.train_batch_size) / 256
    else:
      max_lr = self.config.lr
    # Look up the schedule constructor by name and build it with the peak LR.
    lr_sched_fn = getattr(optim, self.config.lr_schedule.name)
    lr_schedule = lr_sched_fn(max_val=max_lr, **self.config.lr_schedule.kwargs)
    # Optimizer; no need to broadcast!
    # Copy the configured kwargs so the config itself is not mutated when the
    # schedule is inserted as 'lr'.
    opt_kwargs = {key: val for key, val in self.config.optimizer.kwargs.items()}
    opt_kwargs['lr'] = lr_schedule
    opt_module = getattr(optim, self.config.optimizer.name)
    # Three parameter groups: gains/biases get weight_decay=None, the final
    # linear ("fc") weights get clipping=None, everything else uses defaults.
    self.opt = opt_module([{'params': gains_biases, 'weight_decay': None,},
                           {'params': fc_weights, 'clipping': None},
                           {'params': weights}], **opt_kwargs)
    if self._opt_state is None:
      # First call: capture the optimizer's freshly-initialized state.
      self._opt_state = self.opt.states()
    else:
      # A saved state exists: load it back into the optimizer.
      self.opt.plugin(self._opt_state)


if __name__ == '__main__':
  # Script entry point: require `--config` on the command line before running.
  flags.mark_flag_as_required('config')
  # NOTE(review): `Experiment` is not defined in this chunk -- presumably the
  # concrete AbstractExperiment subclass defined elsewhere in the file.
  platform.main(Experiment, sys.argv[1:])