            continue
        best_top1_acc = max(best_top1_acc, trial.last_result["top1_valid"])
    print("iter", self._iteration, "top1_acc=%.3f" % best_top1_acc, cnts, end="\r")
    return original(self)


# Monkey-patch Ray Tune's TrialRunner.step so that every search step also logs
# trial-status counts and the best top-1 validation accuracy seen so far.
patch = gorilla.Patch(
    ray.tune.trial_runner.TrialRunner,
    "step",
    step_w_log,
    settings=gorilla.Settings(allow_hit=True),
)
gorilla.apply(patch)

logger = get_logger("Fast AutoAugment")


def _get_path(dataset, model, tag):
    return os.path.join(
        os.path.dirname(os.path.realpath(__file__)),
        "models/%s_%s_%s.model" % (dataset, model, tag),
    )  # TODO


# Remote training task: each call occupies one GPU and trains a single model
# with the given augmentation policy on one cross-validation fold.
@ray.remote(num_gpus=1, max_calls=1)
def train_model(
    config, dataroot, augment, cv_ratio_test, cv_fold, save_path=None, skip_exist=False
):
    C.get()
    C.get().conf = config  # load this trial's configuration into the global Config singleton
from torchvision.transforms import transforms
from sklearn.model_selection import StratifiedShuffleSplit
from theconf import Config as C

from FastAutoAugment.archive import arsaug_policy, autoaug_policy, autoaug_paper_cifar10, random_search2048, \
    fa_reduced_imagenet, fa_reduced_cifar10
    # fa_wresnet40x2, fa_wresnet28x10, fa_pyramid_e300, fa_pyramid_c100, fa_wresnet40x2_c100, \
    # fa_resnet50_imagenet_minusloss, fa_wresnet40x2_cifar100, fa_wresnet40x2_cifar100_r5, fa_wresnet40x2_cifar10, \
    # fa_wresnet28x10_cifar100, fa_wresnet28x10_cifar10, fa_pyramid_cifar10, fa_pyramid_cifar100, \
    # fa_resnet50_rimagenet, fa_shake26_2x96d_cifar100, fa_shake26_2x96d_cifar10, fa_reduced_cifar10_progressive
from FastAutoAugment.augmentations import *
from FastAutoAugment.common import get_logger
from FastAutoAugment.samplers.distributed_sampler import DistributedStratifiedSampler
from FastAutoAugment.samplers.stratified_sampler import StratifiedSampler

logger = get_logger('Fast AutoAugment')
logger.setLevel(logging.INFO)

# PCA eigenvalues/eigenvectors of ImageNet RGB pixel values, used for
# AlexNet-style "Lighting" (PCA color jitter) augmentation.
_IMAGENET_PCA = {
    'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]),
    'eigvec': torch.Tensor([
        [-0.5675, 0.7192, 0.4009],
        [-0.5808, -0.0045, -0.8140],
        [-0.5836, -0.6948, 0.4203],
    ])
}


def get_dataloaders(dataset, batch,