コード例 #1
0
def get_sweep(hyper):
    """Chains fine-tuning sweeps for CIFAR-10, CIFAR-100 and ImageNet.

    All three datasets use `size=384` and 500 warmup steps. The CIFAR sweeps
    train for 10k steps and the ImageNet sweep for 20k steps, with
    `include_ood_maha=False` passed to the ImageNet helper.

    Args:
      hyper: sweep-building module; only `product` and `chainit` are used.

    Returns:
      A `hyper.chainit` over the per-dataset sweeps, ordered CIFAR-10,
      CIFAR-100, ImageNet.
    """
    resolution = 384
    warmup_steps = 500
    per_dataset = [
        hyper.product(
            sweep_utils.cifar10(hyper,
                                size=resolution,
                                steps=10_000,
                                warmup=warmup_steps)),
        hyper.product(
            sweep_utils.cifar100(hyper,
                                 size=resolution,
                                 steps=10_000,
                                 warmup=warmup_steps)),
        hyper.product(
            sweep_utils.imagenet(hyper,
                                 size=resolution,
                                 steps=20_000,
                                 warmup=warmup_steps,
                                 include_ood_maha=False)),
    ]
    return hyper.chainit(per_dataset)
コード例 #2
0
def get_sweep(hyper):
  """Chains per-dataset sweeps and crosses them with shared model settings.

  Each dataset sweep pins `config.model.num_factors` to one value and sweeps
  `config.lr.base` over a grid following Table 4 of the Vision Transformer
  paper. The chained result is crossed with a `config.model.temperature`
  sweep and a sweep over the pretrained checkpoint(s).

  Args:
    hyper: sweep-building module; `product`, `chainit` and `sweep` are used.

  Returns:
    The combined sweep produced by `hyper.product`.
  """

  def dataset_sweep(base, num_factors, learning_rates):
    # `base` is the list returned by a `sweep_utils` dataset helper; it is
    # extended in place before being turned into a product.
    base.extend([
        hyper.sweep('config.model.num_factors', [num_factors]),
        hyper.sweep('config.lr.base', learning_rates),
    ])
    return hyper.product(base)

  checkpoints = ['/path/to/pretrained_model_ckpt.npz']
  per_dataset = [
      dataset_sweep(sweep_utils.cifar10(hyper, val_split='train[98%:]'), 3,
                    [0.03, 0.01, 0.003, 0.001]),
      dataset_sweep(sweep_utils.cifar100(hyper, val_split='train[98%:]'), 10,
                    [0.03, 0.01, 0.003, 0.001]),
      dataset_sweep(sweep_utils.imagenet(hyper, val_split='train[99%:]'), 15,
                    [0.06, 0.03, 0.01, 0.003]),
  ]
  chained = hyper.chainit(per_dataset)
  return hyper.product([
      chained,
      hyper.sweep('config.model.temperature', [0.35, 1.0, 2.0]),
      hyper.sweep('config.model_init', checkpoints),
  ])
コード例 #3
0
def get_sweep(hyper):
    """Crosses per-dataset schedule/LR sweeps with shared model settings.

    For each dataset, the training length and the 500-step warmup are both
    scaled by 0.5x, 1x, 1.5x and 2x (base length: 10k steps for the CIFAR
    datasets, 20k for ImageNet). The chained schedules are crossed with a
    `config.lr.base` grid. The chained dataset sweeps are then crossed with
    sweeps over `config.fast_weight_lr_multiplier`,
    `config.model.transformer.random_sign_init` and the pretrained checkpoint.

    Args:
      hyper: sweep-building module; `product`, `chainit` and `sweep` are used.

    Returns:
      The combined sweep produced by `hyper.product`.
    """
    checkpoints = ['/path/to/pretrained_model_ckpt.npz']

    def schedule_and_lr_sweep(dataset_fn, base_steps, learning_rates):
        # One sub-sweep per schedule scale, chained, then crossed with LRs.
        schedules = []
        for scale in [0.5, 1.0, 1.5, 2.0]:
            schedules.append(
                hyper.product(
                    dataset_fn(hyper,
                               steps=int(base_steps * scale),
                               warmup=int(500 * scale))))
        return hyper.product([
            hyper.chainit(schedules),
            hyper.sweep('config.lr.base', learning_rates),
        ])

    dataset_sweeps = hyper.chainit([
        schedule_and_lr_sweep(sweep_utils.cifar10, 10_000,
                              [0.03, 0.01, 0.003, 0.001]),
        schedule_and_lr_sweep(sweep_utils.cifar100, 10_000,
                              [0.03, 0.01, 0.003, 0.001]),
        schedule_and_lr_sweep(sweep_utils.imagenet, 20_000,
                              [0.06, 0.03, 0.01, 0.003]),
    ])
    model_settings = hyper.product([
        hyper.sweep('config.fast_weight_lr_multiplier', [0.5, 1.0, 2.0]),
        hyper.sweep('config.model.transformer.random_sign_init', [-0.5, 0.5]),
        hyper.sweep('config.model_init', checkpoints),
    ])
    return hyper.product([dataset_sweeps, model_settings])
コード例 #4
0
def get_sweep(hyper):
    """Crosses acquisition methods, per-dataset LR sweeps and model settings.

    CIFAR-10, CIFAR-100, ImageNet and Places365-small each run for 1000 steps
    with their own `config.lr.base` values. The chained dataset sweeps are
    crossed with the `config.acquisition_method` sweep and with sweeps over
    `config.fast_weight_lr_multiplier`,
    `config.model.transformer.random_sign_init` and the pretrained checkpoint.

    NOTE(review): `acquisition_methods` is not defined in this function; it is
    presumably a module-level sequence — confirm.

    Args:
      hyper: sweep-building module; `product`, `chainit` and `sweep` are used.

    Returns:
      The combined sweep produced by `hyper.product`.
    """
    checkpoints = ['/path/to/pretrained_model_ckpt.npz']

    def dataset_with_lrs(dataset_fn, learning_rates):
        # Cross one dataset's 1000-step sweep with its learning-rate values.
        return hyper.product([
            hyper.product(dataset_fn(hyper, steps=1000)),
            hyper.sweep('config.lr.base', learning_rates),
        ])

    per_dataset = [
        dataset_with_lrs(sweep_utils.cifar10, [0.01]),
        dataset_with_lrs(sweep_utils.cifar100, [0.03]),
        dataset_with_lrs(sweep_utils.imagenet, [0.06, 0.03]),
        dataset_with_lrs(sweep_utils.places365_small, [0.06, 0.03]),
    ]

    acquisition_sweep = hyper.sweep('config.acquisition_method',
                                    acquisition_methods)
    dataset_sweeps = hyper.chainit(per_dataset)
    model_settings = hyper.product([
        hyper.sweep('config.fast_weight_lr_multiplier', [0.5, 1.0, 2.0]),
        hyper.sweep('config.model.transformer.random_sign_init', [-0.5, 0.5]),
        hyper.sweep('config.model_init', checkpoints),
    ])
    return hyper.product([acquisition_sweep, dataset_sweeps, model_settings])
コード例 #5
0
def get_sweep(hyper, *, use_jft=True, sweep_lr=False):
    """Sweeps over CIFAR-10/100, ImageNet and few-shot ImageNet.

    Args:
      hyper: sweep-building module; `product`, `chainit`, `sweep` and `fixed`
        are used.
      use_jft: if True, use the fixed learning rates / warmups chosen for the
        JFT-300M checkpoint; otherwise use the ImageNet-21K settings. Ignored
        when `sweep_lr` is True.
      sweep_lr: if True, instead of fixed settings, sweep over learning rates
        (following Table 4 of the Vision Transformer paper) and training
        schedules (following the E^3 paper), using only the first checkpoint.

    Returns:
      `hyper.product` of the chained per-dataset sweeps (CIFAR-10, CIFAR-100,
      ImageNet, ImageNet 1-/5-/10-shot) with the `config.model_init` sweep.
    """
    checkpoints = ['/path/to/pretrained_model_ckpt.npz']
    if sweep_lr:
        # Tune on a single checkpoint only.
        checkpoints = [checkpoints[0]]
        dataset_sweeps = _lr_and_schedule_sweeps(hyper)
    else:
        dataset_sweeps = _fixed_lr_sweeps(hyper, use_jft)
    return hyper.product([
        hyper.chainit(dataset_sweeps),
        hyper.sweep('config.model_init', checkpoints),
    ])


def _fixed_lr_sweeps(hyper, use_jft):
    """Returns per-dataset sweeps, each pinned to one tuned learning rate."""
    if use_jft:  # Settings for the JFT-300M checkpoint.
        cifar10_lr, cifar100_lr, imagenet_lr = 0.01, 0.03, 0.03
        # (fewshot split, steps, warmup, log_eval_steps, learning rate)
        fewshot_settings = [
            ('1shot', 200, 40, 20, 0.01),
            ('5shot', 1000, 40, 100, 0.02),
            ('10shot', 2000, 50, 200, 0.02),
        ]
    else:  # Settings for the ImageNet-21K checkpoint.
        cifar10_lr, cifar100_lr, imagenet_lr = 0.003, 0.01, 0.01
        fewshot_settings = [
            ('1shot', 200, 10, 20, 0.01),
            ('5shot', 1000, 30, 100, 0.03),
            ('10shot', 2000, 10, 200, 0.01),
        ]

    def with_fixed_lr(dataset_sweep, lr):
        # `dataset_sweep` is the list returned by a `sweep_utils` helper.
        dataset_sweep.append(hyper.fixed('config.lr.base', lr, length=1))
        return hyper.product(dataset_sweep)

    sweeps = [
        with_fixed_lr(sweep_utils.cifar10(hyper), cifar10_lr),
        with_fixed_lr(sweep_utils.cifar100(hyper), cifar100_lr),
        with_fixed_lr(sweep_utils.imagenet(hyper), imagenet_lr),
    ]
    for fewshot, steps, warmup, log_eval_steps, lr in fewshot_settings:
        dataset_sweep = sweep_utils.imagenet_fewshot(
            hyper,
            fewshot=fewshot,
            steps=steps,
            warmup=warmup,
            log_eval_steps=log_eval_steps)
        sweeps.append(with_fixed_lr(dataset_sweep, lr))
    return sweeps


def _lr_and_schedule_sweeps(hyper):
    """Returns per-dataset sweeps over learning rates and schedule lengths."""

    def scaled_schedules(dataset_fn, base_steps):
        # Training length and the 500-step warmup are scaled together.
        return hyper.chainit([
            hyper.product(
                dataset_fn(hyper,
                           steps=int(base_steps * s),
                           warmup=int(500 * s)))
            for s in [0.5, 1.0, 1.5, 2.0]
        ])

    def fewshot_warmups(fewshot, steps):
        # Fixed training length; only the warmup is swept.
        return hyper.chainit([
            hyper.product(
                sweep_utils.imagenet_fewshot(hyper,
                                             fewshot=fewshot,
                                             steps=steps,
                                             warmup=w))
            for w in [1, 5, 10, 20, 30, 40, 50]
        ])

    def with_lrs(schedules, learning_rates):
        return hyper.product([
            schedules,
            hyper.sweep('config.lr.base', learning_rates),
        ])

    return [
        with_lrs(scaled_schedules(sweep_utils.cifar10, 10_000),
                 [0.03, 0.01, 0.003, 0.001]),
        with_lrs(scaled_schedules(sweep_utils.cifar100, 10_000),
                 [0.03, 0.01, 0.003, 0.001]),
        with_lrs(scaled_schedules(sweep_utils.imagenet, 20_000),
                 [0.06, 0.03, 0.01, 0.003]),
        with_lrs(fewshot_warmups('1shot', 200),
                 [0.06, 0.05, 0.04, 0.03, 0.02, 0.01]),
        with_lrs(fewshot_warmups('5shot', 1000),
                 [0.06, 0.05, 0.04, 0.03, 0.02, 0.01]),
        with_lrs(fewshot_warmups('10shot', 2000),
                 [0.06, 0.05, 0.04, 0.03, 0.02, 0.01]),
    ]
コード例 #6
0
    def sweep_checkpoints(use_jft, sweep_lr=sweep_lr):
        """Builds the dataset sweep crossed with checkpoints and acquisitions.

        Every dataset is trained for 1000 steps. `hyper`, `sweep_utils` and
        `acquisition_methods` are read from the enclosing scope. The default for
        `sweep_lr` is the enclosing `sweep_lr` value at the time this function
        is defined.

        Args:
          use_jft: if True, use the fixed learning rates for the JFT-300M
            checkpoint; otherwise use the ImageNet-21K ones. Ignored when
            `sweep_lr` is True.
          sweep_lr: if True, sweep over learning rates following Table 4 of the
            Vision Transformer paper, using only the first checkpoint.

        Returns:
          `hyper.product` of the chained dataset sweeps (CIFAR-10, CIFAR-100,
          Places365-small, ImageNet), the `config.model_init` sweep and the
          `config.acquisition_method` sweep.
        """
        checkpoints = ['/path/to/pretrained_model_ckpt.npz']
        num_steps = 1000

        def lr_grid_sweep(dataset_fn, learning_rates):
            # Crosses one dataset's sweep (at `num_steps`) with an LR grid.
            return hyper.product([
                hyper.product(dataset_fn(hyper, steps=num_steps)),
                hyper.sweep('config.lr.base', learning_rates),
            ])

        if sweep_lr:
            # Tune on a single checkpoint only.
            checkpoints = [checkpoints[0]]
            cifar10_sweep = lr_grid_sweep(sweep_utils.cifar10,
                                          [0.03, 0.01, 0.003, 0.001])
            cifar100_sweep = lr_grid_sweep(sweep_utils.cifar100,
                                           [0.03, 0.01, 0.003, 0.001])
            # ImageNet also runs for `num_steps`, so the LR is tuned on the same
            # schedule the fixed-LR ImageNet runs use.
            imagenet_sweep = lr_grid_sweep(sweep_utils.imagenet,
                                           [0.06, 0.03, 0.01, 0.003])
            places365_sweep = lr_grid_sweep(sweep_utils.places365_small,
                                            [0.03, 0.01, 0.003, 0.001])
        else:
            if use_jft:  # JFT-300M checkpoint.
                cifar10_lr, cifar100_lr, imagenet_lr, places365_lr = (
                    0.01, 0.03, 0.03, 0.03)
            else:  # ImageNet-21K checkpoint.
                cifar10_lr, cifar100_lr, imagenet_lr, places365_lr = (
                    0.003, 0.01, 0.01, 0.01)

            def fixed_lr_sweep(dataset_fn, lr):
                # `dataset_fn` returns a list that is extended with the LR.
                dataset_sweep = dataset_fn(hyper, steps=num_steps)
                dataset_sweep.append(
                    hyper.fixed('config.lr.base', lr, length=1))
                return hyper.product(dataset_sweep)

            cifar10_sweep = fixed_lr_sweep(sweep_utils.cifar10, cifar10_lr)
            cifar100_sweep = fixed_lr_sweep(sweep_utils.cifar100, cifar100_lr)
            imagenet_sweep = fixed_lr_sweep(sweep_utils.imagenet, imagenet_lr)
            places365_sweep = lr_grid_sweep(sweep_utils.places365_small,
                                            [places365_lr])

        return hyper.product([
            hyper.chainit([
                cifar10_sweep,
                cifar100_sweep,
                places365_sweep,
                imagenet_sweep,
            ]),
            hyper.sweep('config.model_init', checkpoints),
            hyper.sweep('config.acquisition_method', acquisition_methods),
        ])