Example #1
  def schedule_fn(count):
    progress = count / num_train_steps

    if cosine_decay_schedule:
      logging.info("Uses cosine decay learning rate with linear warmup on %f "
                   "first steps.", warmup_ratio)
      def cosine_decay_fn(_):
        return (jnp.cos(jnp.pi * (progress - warmup_ratio) /
                        (1 - warmup_ratio)) + 1) / 2

      return jax.lax.cond(
          progress < warmup_ratio,
          lambda _: progress / warmup_ratio,  # linear warmup
          cosine_decay_fn,
          None)

    if decay_at_steps:
      logging.info(
          "Learning rate will be multiplied by %f at steps %s [total steps %d]",
          decay, decay_at_steps, num_train_steps)
      logging.info("Learning rate has a linear warmup on %f first steps.",
                   warmup_ratio)
      decay_at_steps_dict = {s: decay for s in decay_at_steps}
      fn = optax.piecewise_constant_schedule(1., decay_at_steps_dict)
      step_size = fn(count)
      if warmup_ratio > 0.0:
        step_size *= jnp.minimum(progress / warmup_ratio, 1.0)
      return step_size

    if warmup_ratio > 0.0:
      return jnp.minimum(progress / warmup_ratio, 1.0)
    else:
      return 1.0
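
For orientation, a small sketch (not from the source) of what this multiplier looks like over training, assuming schedule_fn and hypothetical values for its closure variables (num_train_steps, warmup_ratio, cosine_decay_schedule, decay_at_steps, decay) are in scope:

# Hypothetical closure values for the nested schedule_fn above (sketch only).
num_train_steps, warmup_ratio = 10_000, 0.1
cosine_decay_schedule, decay_at_steps, decay = True, None, 0.1

# schedule_fn returns a multiplier in [0, 1]; the base learning rate is applied
# elsewhere, e.g. via optax.scale_by_schedule (assumed wiring, not from the source).
for step in (0, 500, 1_000, 5_000, 10_000):
  print(step, schedule_fn(step))  # roughly 0.0, 0.5, 1.0, 0.59, 0.0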
Example #2
def get_learning_rate_schedule(
    total_batch_size, steps_per_epoch, total_steps, optimizer_config):
  """Build the learning rate schedule function."""
  base_lr = _get_batch_scaled_lr(total_batch_size, optimizer_config.base_lr,
                                 optimizer_config.scale_by_batch)

  schedule_type = optimizer_config.schedule_type
  if schedule_type == 'steps':
    boundaries = optimizer_config.step_decay_kwargs.decay_boundaries
    boundaries.sort()

    decay_rate = optimizer_config.step_decay_kwargs.decay_rate
    boundaries_and_scales = {
        int(boundary * total_steps): decay_rate for boundary in boundaries}
    schedule_fn = optax.piecewise_constant_schedule(
        init_value=base_lr, boundaries_and_scales=boundaries_and_scales)
  elif schedule_type == 'cosine':
    warmup_steps = (optimizer_config.cosine_decay_kwargs.warmup_epochs
                    * steps_per_epoch)
    # Batch scale the other lr values as well:
    init_value = _get_batch_scaled_lr(
        total_batch_size,
        optimizer_config.cosine_decay_kwargs.init_value,
        optimizer_config.scale_by_batch)
    end_value = _get_batch_scaled_lr(
        total_batch_size,
        optimizer_config.cosine_decay_kwargs.end_value,
        optimizer_config.scale_by_batch)

    schedule_fn = optax.warmup_cosine_decay_schedule(
        init_value=init_value,
        peak_value=base_lr,
        warmup_steps=warmup_steps,
        decay_steps=total_steps,
        end_value=end_value)
  elif schedule_type == 'constant_cosine':
    # Convert end_value to alpha, used by cosine_decay_schedule.
    alpha = optimizer_config.constant_cosine_decay_kwargs.end_value / base_lr

    # Number of steps spent in constant phase.
    constant_steps = int(
        optimizer_config.constant_cosine_decay_kwargs.constant_fraction
        * total_steps)
    decay_steps = total_steps - constant_steps

    constant_phase = optax.constant_schedule(value=base_lr)
    decay_phase = optax.cosine_decay_schedule(
        init_value=base_lr,
        decay_steps=decay_steps,
        alpha=alpha)
    schedule_fn = optax.join_schedules(
        schedules=[constant_phase, decay_phase],
        boundaries=[constant_steps])
  else:
    raise ValueError(f'Unknown learning rate schedule: {schedule_type}')

  return schedule_fn
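
As a concrete illustration of the 'steps' branch above (with hypothetical config values, not from the source), the fractional decay_boundaries are turned into absolute step boundaries like this:

# Sketch of the 'steps' branch with hypothetical values.
total_steps = 100_000
boundaries = [0.5, 0.75]   # fractions of total training
decay_rate = 0.1
boundaries_and_scales = {int(b * total_steps): decay_rate for b in boundaries}
# -> {50000: 0.1, 75000: 0.1}: optax.piecewise_constant_schedule multiplies the
#    learning rate by 0.1 at step 50k and again at step 75k.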
Example #3
def get_step_schedule(
    max_learning_rate: float,
    total_steps: int,
    warmup_steps: int = 0) -> optax.Schedule:
  """Builds a step schedule with initial warm-up."""
  if total_steps < warmup_steps:
    return optax.linear_schedule(init_value=0., end_value=max_learning_rate,
                                 transition_steps=warmup_steps)
  return optax.join_schedules([
      optax.linear_schedule(init_value=0., end_value=max_learning_rate,
                            transition_steps=warmup_steps),
      optax.piecewise_constant_schedule(
          init_value=max_learning_rate,
          boundaries_and_scales={total_steps * 2 // 3: .1}),
  ], [warmup_steps])
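
A short usage sketch with hypothetical arguments; the step numbers in the comments assume optax.join_schedules offsets later schedules by their boundary:

import optax

schedule = get_step_schedule(max_learning_rate=0.1, total_steps=900, warmup_steps=90)
schedule(0)    # 0.0 - start of the linear warmup
schedule(90)   # 0.1 - warmup finished, full learning rate
# The 10x decay boundary is total_steps * 2 // 3 = 600 in the second schedule's
# own step count; shifted by warmup_steps, it takes effect at global step 690.
optimizer = optax.sgd(learning_rate=schedule)  # schedules plug into optimizers directly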
Example #4
def run_experiment(
    initial_lr=1e-1,
    lr_boundaries=[150, 250],
    seed=0,
    augmentation=False,
    epochs=1,
    batch_size=128,
    net=ResNet18,
    l2=True,
    momentum=True,
    adversarial_dataset=False,
    R=1,
    zero_out_ratio=0,
    testing=False,
    run_weights_name=None,
):
    """Runs the experiment with given parameters and auto logs to wandb.
    In case of errors, make sure you've runned 'wandb login'
    
    :param testing - Dummy testing variable to quickly filter out runs in wandb
    :param run_weights_name - Name of the experiment RUN PATH in Wandb to load weights from.
        If None, default weight initialization is used.
        Note that the experiment name is different than it's run path
    """
    # locals is a htly hacky way of getting function arguments - see https://stackoverflow.com/a/582097
    # Must be used before setting any local variables
    wandb.init(project='bad-global-minima',
               entity='data-icmc',
               config=locals())

    if run_weights_name is not None:
        weights_file = download_weights_file(wandb.config.run_weights_name)
    else:
        weights_file = None

    if wandb.config.adversarial_dataset:
        dataloader = datasets.get_adversarial_cifar(
            data_root='.',
            download_data=True,
            batch_size=wandb.config.batch_size,
            R=wandb.config.R,
            zero_out_ratio=wandb.config.zero_out_ratio)
        dataloader_test = None
    else:
        dataloader = datasets.get_cifar(data_root='.',
                                        download_data=True,
                                        split='train',
                                        batch_size=wandb.config.batch_size,
                                        augmentation=wandb.config.augmentation)
        dataloader_test = datasets.get_cifar(
            data_root='.', split='test', batch_size=wandb.config.batch_size)

    boundaries_and_scales = {
        ep * len(dataloader): 1 / 10
        for ep in wandb.config.lr_boundaries
    }
    schedule_fn = optax.piecewise_constant_schedule(-wandb.config.initial_lr,
                                                    boundaries_and_scales)
    train(net, wandb.config.epochs, dataloader, dataloader_test, schedule_fn,
          wandb.config.l2, wandb.config.momentum, wandb.config.seed,
          weights_file, wandb.config.run_weights_name)
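
To make the boundary arithmetic above concrete (the per-epoch batch count is an assumption, roughly what CIFAR-10 gives with batch_size=128), the schedule expands to something like:

import optax

# Sketch: with lr_boundaries=[150, 250] and ~391 batches per epoch,
# the piecewise_constant_schedule call above becomes
schedule_fn = optax.piecewise_constant_schedule(
    -0.1, {150 * 391: 1 / 10, 250 * 391: 1 / 10})
# i.e. the (negated, as in the original call) base learning rate 0.1 is divided
# by 10 at the start of epochs 150 and 250.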
Example #5
def main(unused_argv):
    print(f'Loading "{_CKPT.value}"')
    print(f'Using a WideResNet with depth {_DEPTH.value} and width '
          f'{_WIDTH.value}.')

    # Create dataset.
    if _DATASET.value == 'mnist':
        _, data_test = tf.keras.datasets.mnist.load_data()
        normalize_fn = datasets.mnist_normalize
    elif _DATASET.value == 'cifar10':
        _, data_test = tf.keras.datasets.cifar10.load_data()
        normalize_fn = datasets.cifar10_normalize
    else:
        assert _DATASET.value == 'cifar100'
        _, data_test = tf.keras.datasets.cifar100.load_data()
        normalize_fn = datasets.cifar100_normalize

    # Create model.
    @hk.transform_with_state
    def model_fn(x, is_training=False):
        model = model_zoo.WideResNet(num_classes=10,
                                     depth=_DEPTH.value,
                                     width=_WIDTH.value,
                                     activation='swish')
        return model(normalize_fn(x), is_training=is_training)

    # Build dataset.
    images, labels = data_test
    samples = (images.astype(np.float32) / 255.,
               np.squeeze(labels, axis=-1).astype(np.int64))
    data = tf.data.Dataset.from_tensor_slices(samples).batch(_BATCH_SIZE.value)
    test_loader = tfds.as_numpy(data)

    # Load model parameters.
    rng_seq = hk.PRNGSequence(0)
    if _CKPT.value == 'dummy':
        for images, _ in test_loader:
            break
        params, state = model_fn.init(next(rng_seq), images, is_training=True)
        # Reset iterator.
        test_loader = tfds.as_numpy(data)
    else:
        params, state = np.load(_CKPT.value, allow_pickle=True)

    # Create adversarial attack. We run a PGD-40 attack with margin loss.
    epsilon = 8 / 255
    eval_attack = attacks.UntargetedAttack(attacks.PGD(
        attacks.Adam(learning_rate_fn=optax.piecewise_constant_schedule(
            init_value=.1, boundaries_and_scales={
                20: .1,
                30: .01
            })),
        num_steps=40,
        initialize_fn=attacks.linf_initialize_fn(epsilon),
        project_fn=attacks.linf_project_fn(epsilon, bounds=(0., 1.))),
                                           loss_fn=attacks.untargeted_margin)

    def logits_fn(x, rng):
        return model_fn.apply(params, state, rng, x)[0]

    # Evaluation.
    correct = 0
    adv_correct = 0
    total = 0
    batch_count = 0
    total_batches = min((10_000 - 1) // _BATCH_SIZE.value + 1,
                        _NUM_BATCHES.value)
    for images, labels in tqdm.tqdm(test_loader, total=total_batches):
        rng = next(rng_seq)
        loop_logits_fn = functools.partial(logits_fn, rng=rng)

        # Clean examples.
        outputs = loop_logits_fn(images)
        correct += (np.argmax(outputs, 1) == labels).sum().item()

        # Adversarial examples.
        adv_images = eval_attack(loop_logits_fn, next(rng_seq), images, labels)
        outputs = loop_logits_fn(adv_images)
        predicted = np.argmax(outputs, 1)
        adv_correct += (predicted == labels).sum().item()

        total += labels.shape[0]
        batch_count += 1
        if _NUM_BATCHES.value > 0 and batch_count >= _NUM_BATCHES.value:
            break
    print(f'Accuracy on the {total} test images: {100 * correct / total:.2f}%')
    print(f'Robust accuracy: {100 * adv_correct / total:.2f}%')
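
For reference, the inner Adam step size used by the PGD-40 attack above follows this schedule; the per-step values in the comments assume optax applies each scale from its boundary step onwards:

import optax

attack_step_size = optax.piecewise_constant_schedule(
    init_value=.1, boundaries_and_scales={20: .1, 30: .01})
# PGD steps  0-19: 0.1
# PGD steps 20-29: 0.1 * 0.1   = 0.01
# PGD steps 30-39: 0.01 * 0.01 = 1e-4   (scales compound at each boundary)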
Example #6
def get_config():
    """Return config object for training."""
    config = base_config.get_base_config()

    # Batch size, training steps and data.
    num_classes = 10
    num_epochs = 400
    # Gowal et al. (2020) and Rebuffi et al. (2021) use 1024 as batch size.
    # Reducing this batch size may require further adjustments to the batch
    # normalization decay or the learning rate. If you have to use a batch size
    # of 256, reduce the number of emulated workers to 1 (it should match the
    # results of using a batch size of 1024 with 4 workers).
    train_batch_size = 1024

    def steps_from_epochs(n):
        return max(int(n * 50_000 / train_batch_size), 1)

    num_steps = steps_from_epochs(num_epochs)
    test_batch_size = train_batch_size
    # Specify the path to the downloaded data. You can download data from
    # https://github.com/deepmind/deepmind-research/tree/master/adversarial_robustness.
    # If the path is set to "cifar10_ddpm.npz" and is not found in the current
    # directory, the corresponding data will be downloaded.
    extra_npz = 'cifar10_ddpm.npz'

    # Learning rate.
    learning_rate = .1 * max(train_batch_size / 256, 1.)
    learning_rate_warmup = steps_from_epochs(10)
    learning_rate_fn = utils.get_cosine_schedule(learning_rate, num_steps,
                                                 learning_rate_warmup)

    # Model definition.
    model_ctor = model_zoo.WideResNet
    model_kwargs = dict(num_classes=num_classes,
                        depth=28,
                        width=10,
                        activation='swish')

    # Attack used during training (can be None).
    epsilon = 8 / 255
    train_attack = attacks.UntargetedAttack(
        attacks.PGD(attacks.Adam(
            optax.piecewise_constant_schedule(init_value=.1,
                                              boundaries_and_scales={5: .1})),
                    num_steps=10,
                    initialize_fn=attacks.linf_initialize_fn(epsilon),
                    project_fn=attacks.linf_project_fn(epsilon,
                                                       bounds=(0., 1.))),
        loss_fn=attacks.untargeted_kl_divergence)

    # Attack used during evaluation (can be None).
    eval_attack = attacks.UntargetedAttack(attacks.PGD(
        attacks.Adam(learning_rate_fn=optax.piecewise_constant_schedule(
            init_value=.1, boundaries_and_scales={
                20: .1,
                30: .01
            })),
        num_steps=40,
        initialize_fn=attacks.linf_initialize_fn(epsilon),
        project_fn=attacks.linf_project_fn(epsilon, bounds=(0., 1.))),
                                           loss_fn=attacks.untargeted_margin)

    config.experiment_kwargs = config_dict.ConfigDict(
        dict(config=dict(
            epsilon=epsilon,
            num_classes=num_classes,
            # Results from various publications use 4 worker machines, which leads
            # to slight differences when using fewer worker machines. To compensate
            # for such discrepancies, we emulate these additional workers. Set to
            # zero when using more than 4 workers.
            emulated_workers=4,
            dry_run=False,
            save_final_checkpoint_as_npy=True,
            model=dict(constructor=model_ctor, kwargs=model_kwargs),
            training=dict(batch_size=train_batch_size,
                          learning_rate=learning_rate_fn,
                          weight_decay=5e-4,
                          swa_decay=.995,
                          use_cutmix=False,
                          supervised_batch_ratio=.3,
                          extra_data_path=extra_npz,
                          extra_label_smoothing=.1,
                          attack=train_attack),
            evaluation=dict(
                # If `interval` is positive, synchronously evaluate at regular
                # intervals. Setting it to zero will not evaluate while training,
                # unless `--jaxline_mode` is set to `train_eval_multithreaded`, which
                # asynchronously evaluates checkpoints.
                interval=steps_from_epochs(40),
                batch_size=test_batch_size,
                attack=eval_attack),
        )))

    config.checkpoint_dir = '/tmp/jaxline/robust'
    config.train_checkpoint_all_hosts = False
    config.training_steps = num_steps
    config.interval_type = 'steps'
    config.log_train_data_interval = steps_from_epochs(.5)
    config.log_tensors_interval = steps_from_epochs(.5)
    config.save_checkpoint_interval = steps_from_epochs(40)
    config.eval_specific_checkpoint_dir = ''
    return config
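
Plugging the defaults into the helpers above gives a sense of scale (CIFAR-10 has 50,000 training images):

train_batch_size = 1024
num_steps = max(int(400 * 50_000 / train_batch_size), 1)     # 19_531
warmup_steps = max(int(10 * 50_000 / train_batch_size), 1)   # 488
peak_lr = .1 * max(train_batch_size / 256, 1.)               # 0.4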
Example #7
    def _create_jax_schedule(self):
        import optax
        return optax.piecewise_constant_schedule(
            init_value=self.initial_rate,
            boundaries_and_scales=self.boundaries_and_scales)
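
Finally, a minimal self-contained sketch of optax.piecewise_constant_schedule itself, with hypothetical values, showing how a schedule plugs into an optimizer:

import jax.numpy as jnp
import optax

# Start at 1e-2 and multiply by 0.1 at steps 100 and 200.
schedule = optax.piecewise_constant_schedule(
    init_value=1e-2, boundaries_and_scales={100: 0.1, 200: 0.1})
print(schedule(0), schedule(150), schedule(250))  # ~1e-2, ~1e-3, ~1e-4

optimizer = optax.sgd(learning_rate=schedule)  # schedules are accepted directly
params = {'w': jnp.zeros(3)}
opt_state = optimizer.init(params)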