def get_sweep(hyper):
    """Chain the CIFAR-10, CIFAR-100 and ImageNet dataset sweeps.

    Each dataset's parameters come from the matching ``sweep_utils`` helper,
    requested with ``size=384`` and ``warmup=500``. The CIFAR datasets use
    10,000 steps; ImageNet uses 20,000 steps and passes
    ``include_ood_maha=False``.

    Args:
        hyper: Hyperparameter helper. It is forwarded to ``sweep_utils`` and
            its ``product`` and ``chainit`` methods are used here.

    Returns:
        The value of ``hyper.chainit`` applied to the three per-dataset
        products, in the order CIFAR-10, CIFAR-100, ImageNet.
    """
    # List elements are evaluated left to right, so the datasets are built in
    # the same order in which they are chained.
    per_dataset = [
        hyper.product(
            sweep_utils.cifar10(hyper, size=384, steps=10_000, warmup=500)),
        hyper.product(
            sweep_utils.cifar100(hyper, size=384, steps=10_000, warmup=500)),
        hyper.product(
            sweep_utils.imagenet(hyper, size=384, steps=20_000, warmup=500,
                                 include_ood_maha=False)),
    ]
    return hyper.chainit(per_dataset)
def get_sweep(hyper):
    """Sweep datasets, learning rate, temperature and initial checkpoint.

    For every dataset the parameters returned by ``sweep_utils`` are extended
    with a single-valued ``config.model.num_factors`` sweep and a
    ``config.lr.base`` sweep, then combined with ``hyper.product``. The
    per-dataset products are chained and finally crossed with
    ``config.model.temperature`` and ``config.model_init``.

    Args:
        hyper: Hyperparameter helper. It is forwarded to ``sweep_utils`` and
            its ``sweep``, ``product`` and ``chainit`` methods are used here.

    Returns:
        The value of the outer ``hyper.product`` call.
    """

    def dataset_sweep(dataset_params, num_factors, learning_rates):
        # ``dataset_params`` is extended in place, exactly as the object
        # returned by ``sweep_utils`` was in the straight-line version
        # (presumably a list -- confirm against ``sweep_utils``).
        dataset_params.extend([
            hyper.sweep('config.model.num_factors', [num_factors]),
            hyper.sweep('config.lr.base', learning_rates),
        ])
        return hyper.product(dataset_params)

    checkpoints = ['/path/to/pretrained_model_ckpt.npz']

    # Apply a learning rate sweep following Table 4 of Vision Transformer
    # paper.
    cifar10 = dataset_sweep(
        sweep_utils.cifar10(hyper, val_split='train[98%:]'),
        3, [0.03, 0.01, 0.003, 0.001])
    cifar100 = dataset_sweep(
        sweep_utils.cifar100(hyper, val_split='train[98%:]'),
        10, [0.03, 0.01, 0.003, 0.001])
    imagenet = dataset_sweep(
        sweep_utils.imagenet(hyper, val_split='train[99%:]'),
        15, [0.06, 0.03, 0.01, 0.003])

    all_datasets = hyper.chainit([cifar10, cifar100, imagenet])
    temperatures = hyper.sweep('config.model.temperature', [0.35, 1.0, 2.0])
    model_inits = hyper.sweep('config.model_init', checkpoints)
    return hyper.product([all_datasets, temperatures, model_inits])
def get_sweep(hyper):
    """Sweep datasets, training-schedule scale, learning rate and init knobs.

    For each dataset, the step count and warmup are scaled by each of
    0.5, 1.0, 1.5 and 2.0 (truncated with ``int``). The resulting schedules
    are chained and crossed with a ``config.lr.base`` sweep. The chained
    datasets are then crossed with ``config.fast_weight_lr_multiplier``,
    ``config.model.transformer.random_sign_init`` and ``config.model_init``.

    Args:
        hyper: Hyperparameter helper. It is forwarded to ``sweep_utils`` and
            its ``sweep``, ``product`` and ``chainit`` methods are used here.

    Returns:
        The value of the outer ``hyper.product`` call.
    """

    def schedule_and_lr_sweep(dataset_fn, base_steps, learning_rates):
        # One product per schedule scale; the warmup base is 500 for every
        # dataset while the step base differs per dataset.
        schedules = []
        for scale in [0.5, 1.0, 1.5, 2.0]:
            dataset_params = dataset_fn(
                hyper,
                steps=int(base_steps * scale),
                warmup=int(500 * scale))
            schedules.append(hyper.product(dataset_params))
        return hyper.product([
            hyper.chainit(schedules),
            hyper.sweep('config.lr.base', learning_rates),
        ])

    checkpoints = ['/path/to/pretrained_model_ckpt.npz']

    cifar10 = schedule_and_lr_sweep(
        sweep_utils.cifar10, 10_000, [0.03, 0.01, 0.003, 0.001])
    cifar100 = schedule_and_lr_sweep(
        sweep_utils.cifar100, 10_000, [0.03, 0.01, 0.003, 0.001])
    imagenet = schedule_and_lr_sweep(
        sweep_utils.imagenet, 20_000, [0.06, 0.03, 0.01, 0.003])

    all_datasets = hyper.chainit([cifar10, cifar100, imagenet])
    model_knobs = hyper.product([
        hyper.sweep('config.fast_weight_lr_multiplier', [0.5, 1.0, 2.0]),
        hyper.sweep('config.model.transformer.random_sign_init', [-0.5, 0.5]),
        hyper.sweep('config.model_init', checkpoints),
    ])
    return hyper.product([all_datasets, model_knobs])
def get_sweep(hyper):
    """Sweep acquisition method, datasets, learning rate and init knobs.

    Each of CIFAR-10, CIFAR-100, ImageNet and Places365 (small) is requested
    from ``sweep_utils`` with ``steps=1000`` and crossed with its own
    ``config.lr.base`` values. The chained datasets are crossed with
    ``config.acquisition_method``, ``config.fast_weight_lr_multiplier``,
    ``config.model.transformer.random_sign_init`` and ``config.model_init``.

    NOTE(review): ``acquisition_methods`` is not defined inside this
    function; presumably it is a module-level name -- confirm.

    Args:
        hyper: Hyperparameter helper. It is forwarded to ``sweep_utils`` and
            its ``sweep``, ``product`` and ``chainit`` methods are used here.

    Returns:
        The value of the outer ``hyper.product`` call.
    """
    checkpoints = ['/path/to/pretrained_model_ckpt.npz']

    # (dataset helper, learning rates), in chaining order.
    dataset_settings = (
        (sweep_utils.cifar10, [0.01]),
        (sweep_utils.cifar100, [0.03]),
        (sweep_utils.imagenet, [0.06, 0.03]),
        (sweep_utils.places365_small, [0.06, 0.03]),
    )
    per_dataset = [
        hyper.product([
            hyper.product(dataset_fn(hyper, steps=1000)),
            hyper.sweep('config.lr.base', learning_rates),
        ])
        for dataset_fn, learning_rates in dataset_settings
    ]

    acquisition = hyper.sweep('config.acquisition_method', acquisition_methods)
    all_datasets = hyper.chainit(per_dataset)
    model_knobs = hyper.product([
        hyper.sweep('config.fast_weight_lr_multiplier', [0.5, 1.0, 2.0]),
        hyper.sweep('config.model.transformer.random_sign_init', [-0.5, 0.5]),
        hyper.sweep('config.model_init', checkpoints),
    ])
    return hyper.product([acquisition, all_datasets, model_knobs])
def get_sweep(hyper, *, use_jft=True, sweep_lr=False):
    """Sweep over datasets (full and few-shot ImageNet) and checkpoints.

    Two modes are supported:

    * Fixed learning rate (``sweep_lr=False``, the default): every dataset
      gets one ``config.lr.base`` value, taken from the JFT-300M settings
      when ``use_jft`` is true and from the ImageNet-21K settings otherwise.
    * Learning-rate sweep (``sweep_lr=True``): every dataset is crossed with
      a grid of learning rates and a set of training schedules, and only the
      first checkpoint is used. ``use_jft`` has no effect in this mode.

    The two toggles used to be hard-coded locals (``use_jft = True``,
    ``sweep_lr = False``). They are now keyword-only parameters with the
    same defaults, so ``get_sweep(hyper)`` behaves exactly as before.

    Args:
        hyper: Hyperparameter helper. It is forwarded to ``sweep_utils`` and
            its ``fixed``, ``sweep``, ``product`` and ``chainit`` methods are
            used here.
        use_jft: Whether to use JFT-300M or ImageNet-21K settings.
        sweep_lr: Whether to sweep over learning rates.

    Returns:
        The value of ``hyper.product`` over the chained per-dataset sweeps
        and the ``config.model_init`` sweep.
    """
    checkpoints = ['/path/to/pretrained_model_ckpt.npz']

    def fixed_lr_sweep(dataset_fn, lr, **dataset_kwargs):
        # Append one fixed learning rate to the dataset parameters.
        dataset_params = dataset_fn(hyper, **dataset_kwargs)
        dataset_params.append(hyper.fixed('config.lr.base', lr, length=1))
        return hyper.product(dataset_params)

    def schedule_and_lr_sweep(dataset_fn, schedules, learning_rates):
        # Chain one dataset config per schedule, then cross with the LR grid.
        return hyper.product([
            hyper.chainit([
                hyper.product(dataset_fn(hyper, **schedule))
                for schedule in schedules
            ]),
            hyper.sweep('config.lr.base', learning_rates),
        ])

    if sweep_lr:
        # Sweep over learning rates following Table 4 of Vision Transformer
        # paper and training steps following E^3 paper.
        checkpoints = [checkpoints[0]]

        def scaled_schedules(base_steps):
            return [
                dict(steps=int(base_steps * s), warmup=int(500 * s))
                for s in [0.5, 1.0, 1.5, 2.0]
            ]

        def fewshot_schedules(fewshot, steps):
            return [
                dict(fewshot=fewshot, steps=steps, warmup=s)
                for s in [1, 5, 10, 20, 30, 40, 50]
            ]

        dataset_sweeps = [
            schedule_and_lr_sweep(
                sweep_utils.cifar10, scaled_schedules(10_000),
                [0.03, 0.01, 0.003, 0.001]),
            schedule_and_lr_sweep(
                sweep_utils.cifar100, scaled_schedules(10_000),
                [0.03, 0.01, 0.003, 0.001]),
            schedule_and_lr_sweep(
                sweep_utils.imagenet, scaled_schedules(20_000),
                [0.06, 0.03, 0.01, 0.003]),
            schedule_and_lr_sweep(
                sweep_utils.imagenet_fewshot, fewshot_schedules('1shot', 200),
                [0.06, 0.05, 0.04, 0.03, 0.02, 0.01]),
            schedule_and_lr_sweep(
                sweep_utils.imagenet_fewshot, fewshot_schedules('5shot', 1000),
                [0.06, 0.05, 0.04, 0.03, 0.02, 0.01]),
            schedule_and_lr_sweep(
                sweep_utils.imagenet_fewshot,
                fewshot_schedules('10shot', 2000),
                [0.06, 0.05, 0.04, 0.03, 0.02, 0.01]),
        ]
    else:
        # The fixed-LR sweeps are only built when they are actually used;
        # previously they were always built and discarded under sweep_lr.
        if use_jft:
            full_lrs = {'cifar10': 0.01, 'cifar100': 0.03, 'imagenet': 0.03}
            # (fewshot, lr, steps, warmup, log_eval_steps)
            fewshot_settings = [
                ('1shot', 0.01, 200, 40, 20),
                ('5shot', 0.02, 1000, 40, 100),
                ('10shot', 0.02, 2000, 50, 200),
            ]
        else:
            full_lrs = {'cifar10': 0.003, 'cifar100': 0.01, 'imagenet': 0.01}
            fewshot_settings = [
                ('1shot', 0.01, 200, 10, 20),
                ('5shot', 0.03, 1000, 30, 100),
                ('10shot', 0.01, 2000, 10, 200),
            ]

        dataset_sweeps = [
            fixed_lr_sweep(sweep_utils.cifar10, full_lrs['cifar10']),
            fixed_lr_sweep(sweep_utils.cifar100, full_lrs['cifar100']),
            fixed_lr_sweep(sweep_utils.imagenet, full_lrs['imagenet']),
        ]
        for fewshot, lr, steps, warmup, log_eval_steps in fewshot_settings:
            dataset_sweeps.append(
                fixed_lr_sweep(
                    sweep_utils.imagenet_fewshot, lr,
                    fewshot=fewshot, steps=steps, warmup=warmup,
                    log_eval_steps=log_eval_steps))

    return hyper.product([
        hyper.chainit(dataset_sweeps),
        hyper.sweep('config.model_init', checkpoints),
    ])
def sweep_checkpoints(use_jft, sweep_lr=sweep_lr):
    """Build the dataset x checkpoint x acquisition-method sweep.

    CIFAR-10, CIFAR-100, ImageNet and Places365 (small) are each combined
    with a single learning rate chosen by ``use_jft``. When ``sweep_lr`` is
    true those sweeps are replaced by learning-rate grids and only the first
    checkpoint is kept.

    NOTE(review): ``hyper`` and ``acquisition_methods`` are not parameters
    or locals of this function, and the default for ``sweep_lr`` is read from
    an outer ``sweep_lr`` name when the ``def`` statement runs. Presumably
    all three are module-level names -- confirm.

    Args:
        use_jft: Whether to use JFT-300M or ImageNet-21K settings.
        sweep_lr: Whether to replace the fixed learning rates with a
            learning-rate sweep.

    Returns:
        The value of ``hyper.product`` over the chained datasets, the
        ``config.model_init`` sweep and the ``config.acquisition_method``
        sweep.
    """

    def with_fixed_lr(dataset_params, lr):
        dataset_params.append(hyper.fixed('config.lr.base', lr, length=1))
        return hyper.product(dataset_params)

    def with_lr_sweep(dataset_params, learning_rates):
        return hyper.product([
            hyper.product(dataset_params),
            hyper.sweep('config.lr.base', learning_rates),
        ])

    checkpoints = ['/path/to/pretrained_model_ckpt.npz']

    if use_jft:
        cifar10_lr, cifar100_lr, imagenet_lr, places365_lr = (
            0.01, 0.03, 0.03, 0.03)
    else:
        cifar10_lr, cifar100_lr, imagenet_lr, places365_lr = (
            0.003, 0.01, 0.01, 0.01)

    cifar10_sweep = with_fixed_lr(
        sweep_utils.cifar10(hyper, steps=1000), cifar10_lr)
    cifar100_sweep = with_fixed_lr(
        sweep_utils.cifar100(hyper, steps=1000), cifar100_lr)
    imagenet_sweep = with_fixed_lr(
        sweep_utils.imagenet(hyper, steps=1000), imagenet_lr)
    # Places365 uses a one-element ``hyper.sweep`` rather than ``hyper.fixed``.
    places365_sweep = with_lr_sweep(
        sweep_utils.places365_small(hyper, steps=1000), [places365_lr])

    if sweep_lr:
        # Apply a learning rate sweep following Table 4 of Vision Transformer
        # paper.
        checkpoints = [checkpoints[0]]
        cifar10_sweep = with_lr_sweep(
            sweep_utils.cifar10(hyper, steps=1000),
            [0.03, 0.01, 0.003, 0.001])
        cifar100_sweep = with_lr_sweep(
            sweep_utils.cifar100(hyper, steps=1000),
            [0.03, 0.01, 0.003, 0.001])
        # NOTE(review): ImageNet is requested without ``steps=1000`` here,
        # unlike the other datasets; kept as is -- confirm it is intended.
        imagenet_sweep = with_lr_sweep(
            sweep_utils.imagenet(hyper),
            [0.06, 0.03, 0.01, 0.003])
        places365_sweep = with_lr_sweep(
            sweep_utils.places365_small(hyper, steps=1000),
            [0.03, 0.01, 0.003, 0.001])

    all_datasets = hyper.chainit([
        cifar10_sweep,
        cifar100_sweep,
        places365_sweep,
        imagenet_sweep,
    ])
    return hyper.product([
        all_datasets,
        hyper.sweep('config.model_init', checkpoints),
        hyper.sweep('config.acquisition_method', acquisition_methods),
    ])