Exemplo n.º 1
0
            continue
        best_top1_acc = max(best_top1_acc, trial.last_result["top1_valid"])
    print("iter", self._iteration, "top1_acc=%.3f" % best_top1_acc, cnts, end="\r")
    return original(self)


# Monkey-patch Ray Tune's TrialRunner so each runner step goes through
# ``step_w_log``, which prints the best "top1_valid" value seen across trials
# and then calls ``original(self)`` (presumably the unpatched ``step``).
# NOTE: per gorilla's docs, a patch may only replace an attribute that already
# exists on the destination when ``allow_hit`` is enabled; ``step`` is
# presumably such an existing method, hence ``allow_hit=True``.
patch = gorilla.Patch(
    destination=ray.tune.trial_runner.TrialRunner,
    name="step",
    obj=step_w_log,
    settings=gorilla.Settings(allow_hit=True),
)
gorilla.apply(patch)


# Module-wide logger for this search script.
logger = get_logger("Fast AutoAugment")


def _get_path(dataset, model, tag):
    """Return the model-file path for a (dataset, model, tag) combination.

    The file is placed in a ``models`` directory next to this source file and
    is named ``<dataset>_<model>_<tag>.model``.

    Args:
        dataset: First component of the file name.
        model: Second component of the file name.
        tag: Third component of the file name (free-form suffix).

    Returns:
        Absolute path as a string. Nothing is created or checked on disk.
    """
    # realpath resolves symlinks, so the path is anchored at the real location
    # of this module regardless of the current working directory.
    here = os.path.dirname(os.path.realpath(__file__))
    filename = "%s_%s_%s.model" % (dataset, model, tag)
    # "models" is passed as its own component so os.path.join inserts the
    # platform separator instead of a hard-coded "/".
    # TODO: the models directory is fixed next to this module; consider
    # making it configurable.
    return os.path.join(here, "models", filename)


@ray.remote(num_gpus=1, max_calls=1)
def train_model(
    config, dataroot, augment, cv_ratio_test, cv_fold, save_path=None, skip_exist=False
):
    C.get()
    C.get().conf = config
Exemplo n.º 2
0
from torchvision.transforms import transforms
from sklearn.model_selection import StratifiedShuffleSplit
from theconf import Config as C

from FastAutoAugment.archive import arsaug_policy, autoaug_policy, autoaug_paper_cifar10, random_search2048, \
    fa_reduced_imagenet, fa_reduced_cifar10
# fa_wresnet40x2, fa_wresnet28x10, fa_pyramid_e300, fa_pyramid_c100, fa_wresnet40x2_c100, \
# fa_resnet50_imagenet_minusloss, fa_wresnet40x2_cifar100, fa_wresnet40x2_cifar100_r5, fa_wresnet40x2_cifar10, \
# fa_wresnet28x10_cifar100, fa_wresnet28x10_cifar10, fa_pyramid_cifar10, fa_pyramid_cifar100, \
# fa_resnet50_rimagenet, fa_shake26_2x96d_cifar100, fa_shake26_2x96d_cifar10, fa_reduced_cifar10_progressive
from FastAutoAugment.augmentations import *
from FastAutoAugment.common import get_logger
from FastAutoAugment.samplers.distributed_sampler import DistributedStratifiedSampler
from FastAutoAugment.samplers.stratified_sampler import StratifiedSampler

# Shared logger for the data-loading helpers in this module.
logger = get_logger('Fast AutoAugment')
# Logger-level threshold: records below INFO (e.g. DEBUG) are dropped here.
logger.setLevel(level=logging.INFO)
# PCA statistics of ImageNet pixel colour: three eigenvalues ('eigval', shape
# (3,)) and the matching 3x3 eigenvector matrix ('eigvec').
# NOTE(review): presumably consumed by a PCA "lighting" noise augmentation;
# confirm at the call site which axis of 'eigvec' holds the eigenvectors.
_IMAGENET_PCA = dict(
    eigval=torch.Tensor([0.2175, 0.0188, 0.0045]),
    eigvec=torch.Tensor(
        [
            [-0.5675, 0.7192, 0.4009],
            [-0.5808, -0.0045, -0.8140],
            [-0.5836, -0.6948, 0.4203],
        ]
    ),
)


def get_dataloaders(dataset,
                    batch,