def schedule_fn(count):
    """Return the learning-rate multiplier for training step `count`.

    The free names `num_train_steps`, `warmup_ratio`, `cosine_decay_schedule`,
    `decay_at_steps` and `decay` are resolved from the surrounding scope.

    Three mutually exclusive modes are supported, checked in this order:
      * cosine decay preceded by a linear warmup,
      * piecewise-constant decay at `decay_at_steps`, optionally with a
        linear warmup,
      * linear warmup only (or a constant 1.0 when `warmup_ratio` is 0).

    Args:
        count: Current training step. May be a traced JAX value.

    Returns:
        The scalar multiplier to apply to the base learning rate.
    """
    progress = count / num_train_steps

    if cosine_decay_schedule:
        logging.info("Uses cosine decay learning rate with linear warmup on %f "
                     "first steps.", warmup_ratio)

        def cosine_decay_fn(_):
            return (jnp.cos(jnp.pi * (progress - warmup_ratio) /
                            (1 - warmup_ratio)) + 1) / 2

        return jax.lax.cond(
            progress < warmup_ratio,
            lambda _: progress / warmup_ratio,  # linear warmup
            cosine_decay_fn,
            None)

    if decay_at_steps:
        logging.info(
            "Learning rate will be multiplied by %f at steps %s [total steps %d]",
            decay, decay_at_steps, num_train_steps)
        logging.info("Learning rate has a linear warmup on %f first steps.",
                     warmup_ratio)
        decay_at_steps_dict = {s: decay for s in decay_at_steps}
        fn = optax.piecewise_constant_schedule(1., decay_at_steps_dict)
        step_size = fn(count)
        if warmup_ratio > 0.0:
            step_size *= jnp.minimum(progress / warmup_ratio, 1.0)
        return step_size

    if warmup_ratio > 0.0:
        # Use jnp.minimum, not the builtin min(): the builtin compares its
        # arguments with `<`, which forces a bool conversion and fails when
        # `count` is a tracer (e.g. when the schedule is evaluated under jit).
        return jnp.minimum(progress / warmup_ratio, 1.0)
    return 1.0
def get_learning_rate_schedule(
        total_batch_size, steps_per_epoch, total_steps, optimizer_config):
    """Build the learning rate schedule function.

    Args:
        total_batch_size: Batch size passed to `_get_batch_scaled_lr` when
            scaling learning-rate values.
        steps_per_epoch: Number of steps per epoch; used to convert
            `warmup_epochs` into warmup steps for the 'cosine' schedule.
        total_steps: Total number of training steps.
        optimizer_config: Config object providing `base_lr`, `scale_by_batch`,
            `schedule_type` and the matching `*_kwargs` sub-config.

    Returns:
        An optax schedule function.

    Raises:
        ValueError: If `optimizer_config.schedule_type` is not one of
            'steps', 'cosine' or 'constant_cosine'.
    """
    base_lr = _get_batch_scaled_lr(total_batch_size,
                                   optimizer_config.base_lr,
                                   optimizer_config.scale_by_batch)

    schedule_type = optimizer_config.schedule_type
    if schedule_type == 'steps':
        step_kwargs = optimizer_config.step_decay_kwargs
        # Sort a copy: an in-place `.sort()` would mutate the caller's config
        # and fails outright when the boundaries are stored as a tuple.
        boundaries = sorted(step_kwargs.decay_boundaries)
        decay_rate = step_kwargs.decay_rate
        # Boundaries are fractions of the run; convert them to step indices.
        boundaries_and_scales = {
            int(boundary * total_steps): decay_rate
            for boundary in boundaries
        }
        schedule_fn = optax.piecewise_constant_schedule(
            init_value=base_lr,
            boundaries_and_scales=boundaries_and_scales)
    elif schedule_type == 'cosine':
        cosine_kwargs = optimizer_config.cosine_decay_kwargs
        warmup_steps = cosine_kwargs.warmup_epochs * steps_per_epoch
        # Batch scale the other lr values as well:
        init_value = _get_batch_scaled_lr(total_batch_size,
                                          cosine_kwargs.init_value,
                                          optimizer_config.scale_by_batch)
        end_value = _get_batch_scaled_lr(total_batch_size,
                                         cosine_kwargs.end_value,
                                         optimizer_config.scale_by_batch)
        schedule_fn = optax.warmup_cosine_decay_schedule(
            init_value=init_value,
            peak_value=base_lr,
            warmup_steps=warmup_steps,
            decay_steps=total_steps,
            end_value=end_value)
    elif schedule_type == 'constant_cosine':
        constant_cosine_kwargs = optimizer_config.constant_cosine_decay_kwargs
        # Convert end_value to alpha, used by cosine_decay_schedule.
        # NOTE(review): unlike the 'cosine' branch, end_value is not batch
        # scaled here while base_lr is -- confirm this is intended.
        alpha = constant_cosine_kwargs.end_value / base_lr

        # Number of steps spent in constant phase.
        constant_steps = int(
            constant_cosine_kwargs.constant_fraction * total_steps)
        decay_steps = total_steps - constant_steps

        constant_phase = optax.constant_schedule(value=base_lr)
        decay_phase = optax.cosine_decay_schedule(
            init_value=base_lr,
            decay_steps=decay_steps,
            alpha=alpha)
        schedule_fn = optax.join_schedules(
            schedules=[constant_phase, decay_phase],
            boundaries=[constant_steps])
    else:
        raise ValueError(f'Unknown learning rate schedule: {schedule_type}')

    return schedule_fn
def get_step_schedule(
        max_learning_rate: float,
        total_steps: int,
        warmup_steps: int = 0) -> optax.Schedule:
    """Create a warm-up schedule followed by a single step decay.

    The learning rate ramps linearly from 0 to `max_learning_rate` over
    `warmup_steps`. If the run is shorter than the warm-up, only the ramp is
    returned. Otherwise the ramp is joined, at `warmup_steps`, with a
    piecewise-constant schedule that multiplies the rate by 0.1 at its own
    step `total_steps * 2 // 3`.

    Args:
        max_learning_rate: Learning rate reached at the end of warm-up.
        total_steps: Total number of training steps.
        warmup_steps: Length of the linear warm-up.

    Returns:
        The resulting optax schedule.
    """
    warmup = optax.linear_schedule(
        init_value=0.,
        end_value=max_learning_rate,
        transition_steps=warmup_steps)

    # Training ends before the warm-up does: nothing to join.
    if total_steps < warmup_steps:
        return warmup

    decay_boundary = total_steps * 2 // 3
    stepped = optax.piecewise_constant_schedule(
        init_value=max_learning_rate,
        boundaries_and_scales={decay_boundary: .1})

    return optax.join_schedules([warmup, stepped], [warmup_steps])
def run_experiment(
    initial_lr=1e-1,
    lr_boundaries=None,
    seed=0,
    augmentation=False,
    epochs=1,
    batch_size=128,
    net=ResNet18,
    l2=True,
    momentum=True,
    adversarial_dataset=False,
    R=1,
    zero_out_ratio=0,
    testing=False,
    run_weights_name=None,
):
    """Runs the experiment with given parameters and auto logs to wandb.

    In case of errors, make sure you've run 'wandb login'.

    :param initial_lr - Initial learning rate; passed negated to the optax
        piecewise-constant schedule
    :param lr_boundaries - Epochs at which the learning rate is divided by 10.
        Defaults to [150, 250] when None
    :param seed - Seed forwarded to `train`
    :param augmentation - Forwarded to `datasets.get_cifar` for the train split
    :param epochs - Number of epochs forwarded to `train`
    :param batch_size - Batch size used for every dataloader
    :param net - Network forwarded to `train`
    :param l2 - Forwarded to `train`
    :param momentum - Forwarded to `train`
    :param adversarial_dataset - If True, train on
        `datasets.get_adversarial_cifar` and use no test dataloader
    :param R - Forwarded to `datasets.get_adversarial_cifar`
    :param zero_out_ratio - Forwarded to `datasets.get_adversarial_cifar`
    :param testing - Dummy testing variable to quickly filter out runs in wandb
    :param run_weights_name - Name of the experiment RUN PATH in Wandb to load
        weights from. If None, default weight initialization is used. Note
        that the experiment name is different than its run path
    """
    # Resolve the default here instead of using a mutable list in the
    # signature. Rebinding a parameter does not add a new local, so locals()
    # below still contains exactly the function arguments.
    if lr_boundaries is None:
        lr_boundaries = [150, 250]

    # locals() is a slightly hacky way of getting function arguments - see
    # https://stackoverflow.com/a/582097
    # Must be used before setting any local variables
    wandb.init(project='bad-global-minima', entity='data-icmc', config=locals())

    if run_weights_name is not None:
        weights_file = download_weights_file(wandb.config.run_weights_name)
    else:
        weights_file = None

    if wandb.config.adversarial_dataset:
        dataloader = datasets.get_adversarial_cifar(
            data_root='.',
            download_data=True,
            batch_size=wandb.config.batch_size,
            R=wandb.config.R,
            zero_out_ratio=wandb.config.zero_out_ratio)
        dataloader_test = None
    else:
        dataloader = datasets.get_cifar(
            data_root='.',
            download_data=True,
            split='train',
            batch_size=wandb.config.batch_size,
            augmentation=wandb.config.augmentation)
        dataloader_test = datasets.get_cifar(
            data_root='.',
            split='test',
            batch_size=wandb.config.batch_size)

    # Boundaries are given in epochs; the schedule is indexed in steps.
    boundaries_and_scales = {
        ep * len(dataloader): 1 / 10
        for ep in wandb.config.lr_boundaries
    }
    # The initial value is negated before being handed to the schedule.
    schedule_fn = optax.piecewise_constant_schedule(-wandb.config.initial_lr,
                                                    boundaries_and_scales)

    train(net, wandb.config.epochs, dataloader, dataloader_test, schedule_fn,
          wandb.config.l2, wandb.config.momentum, wandb.config.seed,
          weights_file, wandb.config.run_weights_name)
def main(unused_argv):
    """Evaluate a WideResNet on clean and adversarial test examples.

    Loads the test split selected by `_DATASET`, builds a WideResNet, loads
    (or, for the 'dummy' checkpoint, randomly initializes) its parameters and
    prints the clean accuracy and the robust accuracy under a PGD-40 attack
    with margin loss.

    Args:
        unused_argv: Ignored positional command-line arguments.

    Raises:
        ValueError: If `_DATASET` is not 'mnist', 'cifar10' or 'cifar100'.
    """
    print(f'Loading "{_CKPT.value}"')
    print(f'Using a WideResNet with depth {_DEPTH.value} and width '
          f'{_WIDTH.value}.')

    # Create dataset.
    if _DATASET.value == 'mnist':
        _, data_test = tf.keras.datasets.mnist.load_data()
        normalize_fn = datasets.mnist_normalize
        num_classes = 10
    elif _DATASET.value == 'cifar10':
        _, data_test = tf.keras.datasets.cifar10.load_data()
        normalize_fn = datasets.cifar10_normalize
        num_classes = 10
    elif _DATASET.value == 'cifar100':
        _, data_test = tf.keras.datasets.cifar100.load_data()
        normalize_fn = datasets.cifar100_normalize
        num_classes = 100
    else:
        # Explicit error instead of `assert`, which is stripped under -O.
        raise ValueError(f'Unsupported dataset: {_DATASET.value!r}')

    # Create model.
    @hk.transform_with_state
    def model_fn(x, is_training=False):
        # The number of output classes must follow the dataset (100 for
        # cifar100), otherwise logits and labels do not line up.
        model = model_zoo.WideResNet(
            num_classes=num_classes, depth=_DEPTH.value, width=_WIDTH.value,
            activation='swish')
        return model(normalize_fn(x), is_training=is_training)

    # Build dataset.
    images, labels = data_test
    images = images.astype(np.float32) / 255.
    if images.ndim == 3:
        # Grayscale data (MNIST) comes without a channel axis.
        # NOTE(review): assumes normalize_fn / the model expect NHWC input --
        # confirm against datasets.mnist_normalize.
        images = images[..., np.newaxis]
    # Flatten labels: CIFAR labels are (N, 1) while MNIST labels are already
    # (N,), for which np.squeeze(axis=-1) would raise.
    labels = np.reshape(labels, -1).astype(np.int64)
    num_examples = images.shape[0]
    data = tf.data.Dataset.from_tensor_slices(
        (images, labels)).batch(_BATCH_SIZE.value)
    test_loader = tfds.as_numpy(data)

    # Load model parameters.
    rng_seq = hk.PRNGSequence(0)
    if _CKPT.value == 'dummy':
        # Only one batch is needed to initialize randomly.
        sample_images, _ = next(iter(test_loader))
        params, state = model_fn.init(next(rng_seq), sample_images,
                                      is_training=True)
        # Reset iterator.
        test_loader = tfds.as_numpy(data)
    else:
        # SECURITY: allow_pickle=True executes arbitrary code embedded in the
        # file; only load checkpoints from trusted sources.
        params, state = np.load(_CKPT.value, allow_pickle=True)

    # Create adversarial attack. We run a PGD-40 attack with margin loss.
    epsilon = 8 / 255
    eval_attack = attacks.UntargetedAttack(
        attacks.PGD(
            attacks.Adam(learning_rate_fn=optax.piecewise_constant_schedule(
                init_value=.1,
                boundaries_and_scales={20: .1, 30: .01})),
            num_steps=40,
            initialize_fn=attacks.linf_initialize_fn(epsilon),
            project_fn=attacks.linf_project_fn(epsilon, bounds=(0., 1.))),
        loss_fn=attacks.untargeted_margin)

    def logits_fn(x, rng):
        return model_fn.apply(params, state, rng, x)[0]

    # Evaluation.
    correct = 0
    adv_correct = 0
    total = 0
    batch_count = 0
    # Ceiling division; a non-positive _NUM_BATCHES means "use every batch",
    # matching the early-exit condition at the bottom of the loop.
    available_batches = (num_examples - 1) // _BATCH_SIZE.value + 1
    if _NUM_BATCHES.value > 0:
        total_batches = min(available_batches, _NUM_BATCHES.value)
    else:
        total_batches = available_batches
    for images, labels in tqdm.tqdm(test_loader, total=total_batches):
        rng = next(rng_seq)
        # The same rng is used for the clean and the adversarial forward pass.
        loop_logits_fn = functools.partial(logits_fn, rng=rng)

        # Clean examples.
        outputs = loop_logits_fn(images)
        correct += (np.argmax(outputs, 1) == labels).sum().item()

        # Adversarial examples.
        adv_images = eval_attack(loop_logits_fn, next(rng_seq), images, labels)
        outputs = loop_logits_fn(adv_images)
        predicted = np.argmax(outputs, 1)
        adv_correct += (predicted == labels).sum().item()

        total += labels.shape[0]
        batch_count += 1
        if _NUM_BATCHES.value > 0 and batch_count >= _NUM_BATCHES.value:
            break

    print(f'Accuracy on the {total} test images: {100 * correct / total:.2f}%')
    print(f'Robust accuracy: {100 * adv_correct / total:.2f}%')
def get_config():
    """Assemble and return the training configuration object."""
    config = base_config.get_base_config()

    # Problem size and training length.
    num_classes = 10
    num_epochs = 400
    # Gowal et al. (2020) and Rebuffi et al. (2021) train with a batch size of
    # 1024. A smaller batch size may call for retuning the batch normalization
    # decay or the learning rate. When limited to a batch size of 256, set the
    # number of emulated workers to 1 (this should reproduce the results of a
    # batch size of 1024 with 4 workers).
    train_batch_size = 1024
    test_batch_size = train_batch_size

    def epochs_to_steps(n):
        # Convert a (possibly fractional) epoch count to steps, at least one.
        step_count = int(n * 50_000 / train_batch_size)
        return max(step_count, 1)

    num_steps = epochs_to_steps(num_epochs)

    # Path to the downloaded extra data. The data is available from
    # https://github.com/deepmind/deepmind-research/tree/master/adversarial_robustness.
    # When the path is "cifar10_ddpm.npz" and no such file exists in the
    # current directory, the corresponding data will be downloaded.
    extra_npz = 'cifar10_ddpm.npz'

    # Learning rate schedule.
    batch_scale = max(train_batch_size / 256, 1.)
    learning_rate = .1 * batch_scale
    learning_rate_warmup = epochs_to_steps(10)
    learning_rate_fn = utils.get_cosine_schedule(
        learning_rate, num_steps, learning_rate_warmup)

    # Model definition.
    model_ctor = model_zoo.WideResNet
    model_kwargs = {
        'num_classes': num_classes,
        'depth': 28,
        'width': 10,
        'activation': 'swish',
    }

    epsilon = 8 / 255

    # Attack used during training (can be None).
    train_optimizer = attacks.Adam(
        optax.piecewise_constant_schedule(
            init_value=.1, boundaries_and_scales={5: .1}))
    train_pgd = attacks.PGD(
        train_optimizer,
        num_steps=10,
        initialize_fn=attacks.linf_initialize_fn(epsilon),
        project_fn=attacks.linf_project_fn(epsilon, bounds=(0., 1.)))
    train_attack = attacks.UntargetedAttack(
        train_pgd, loss_fn=attacks.untargeted_kl_divergence)

    # Attack used during evaluation (can be None).
    eval_optimizer = attacks.Adam(
        learning_rate_fn=optax.piecewise_constant_schedule(
            init_value=.1, boundaries_and_scales={20: .1, 30: .01}))
    eval_pgd = attacks.PGD(
        eval_optimizer,
        num_steps=40,
        initialize_fn=attacks.linf_initialize_fn(epsilon),
        project_fn=attacks.linf_project_fn(epsilon, bounds=(0., 1.)))
    eval_attack = attacks.UntargetedAttack(
        eval_pgd, loss_fn=attacks.untargeted_margin)

    training_kwargs = {
        'batch_size': train_batch_size,
        'learning_rate': learning_rate_fn,
        'weight_decay': 5e-4,
        'swa_decay': .995,
        'use_cutmix': False,
        'supervised_batch_ratio': .3,
        'extra_data_path': extra_npz,
        'extra_label_smoothing': .1,
        'attack': train_attack,
    }
    evaluation_kwargs = {
        # A positive `interval` evaluates synchronously at regular intervals.
        # Zero disables evaluation during training, unless `--jaxline_mode`
        # is `train_eval_multithreaded`, which evaluates checkpoints
        # asynchronously.
        'interval': epochs_to_steps(40),
        'batch_size': test_batch_size,
        'attack': eval_attack,
    }
    experiment_config = {
        'epsilon': epsilon,
        'num_classes': num_classes,
        # Published results typically use 4 worker machines, so running with
        # fewer workers gives slightly different numbers. To compensate, the
        # missing workers are emulated. Set to zero when using more than 4
        # workers.
        'emulated_workers': 4,
        'dry_run': False,
        'save_final_checkpoint_as_npy': True,
        'model': {'constructor': model_ctor, 'kwargs': model_kwargs},
        'training': training_kwargs,
        'evaluation': evaluation_kwargs,
    }
    config.experiment_kwargs = config_dict.ConfigDict(
        {'config': experiment_config})

    config.checkpoint_dir = '/tmp/jaxline/robust'
    config.train_checkpoint_all_hosts = False
    config.training_steps = num_steps
    config.interval_type = 'steps'
    config.log_train_data_interval = epochs_to_steps(.5)
    config.log_tensors_interval = epochs_to_steps(.5)
    config.save_checkpoint_interval = epochs_to_steps(40)
    config.eval_specific_checkpoint_dir = ''

    return config
def _create_jax_schedule(self):
    """Return an optax piecewise-constant schedule built from this object.

    The schedule starts at `self.initial_rate` and is rescaled according to
    `self.boundaries_and_scales`.
    """
    # optax is imported at call time rather than at module import.
    import optax

    schedule_kwargs = {
        'init_value': self.initial_rate,
        'boundaries_and_scales': self.boundaries_and_scales,
    }
    return optax.piecewise_constant_schedule(**schedule_kwargs)