Code example #1
def test_arguments_invalid(datafiles):
    filenames = [str(f) for f in datafiles.listdir()]
    parser = ConfigArgumentParser(filename=filenames[0])
    with pytest.raises(SystemExit) as e:
        _ = parser.parse_args(args=['-c', filenames[0], '--baz', 'test'])
    # pytest: error: unrecognized arguments: --baz test

    Config.clear()

    parser = ConfigArgumentParser(filename=filenames[0])
    with pytest.raises(SystemExit) as e:
        _ = parser.parse_args(args=['-c', filenames[0], '--bar', 'test'])
    # pytest: error: argument --bar: invalid int value: 'test'

    Config.clear()

    parser = ConfigArgumentParser(filename=filenames[0])
    with pytest.raises(Exception) as e:
        parser.add_argument('--foo', type=int, default=1)
    assert str(e.value) in [
        'argument --foo: conflicting option string: --foo',
        'argument --foo: conflicting option string(s): --foo'
    ]

    Config.clear()
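
The conflict error exercised above explains why nearly every example further down constructs the parser with conflict_handler='resolve': once the YAML config has registered an option such as --foo, re-adding it via add_argument is only allowed if argparse resolves the conflict by letting the later definition replace the earlier one. A minimal sketch of the difference follows; the config.yaml path and its foo key are illustrative assumptions, and it is assumed that the conflict_handler keyword is forwarded to argparse as the later examples suggest.

from theconf import Config, ConfigArgumentParser

# Default conflict handler: re-registering an option already defined by the YAML raises,
# exactly as asserted in the test above.
parser = ConfigArgumentParser(filename='config.yaml')
try:
    parser.add_argument('--foo', type=int, default=1)
except Exception as e:
    print(e)  # argument --foo: conflicting option string(s): --foo
Config.clear()

# With conflict_handler='resolve', the later add_argument call simply replaces the
# option registered from the config file instead of raising.
parser = ConfigArgumentParser(filename='config.yaml', conflict_handler='resolve')
parser.add_argument('--foo', type=int, default=1)
Config.clear()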
Code example #2
def test_arguments_simple(datafiles):
    filenames = [str(f) for f in datafiles.listdir()]
    parser = ConfigArgumentParser(filename=filenames[0])
    args = parser.parse_args(args=['-c', filenames[0]])

    assert args.foo == 'test'
    assert args.bar == 1234
    assert Config.get_instance()['foo'] == 'test'
    assert Config.get_instance()['bar'] == 1234

    Config.clear()

    parser = ConfigArgumentParser(filename=filenames[0])
    args = parser.parse_args(args=['-c', filenames[0], '--foo', 'value'])

    assert args.foo == 'value'
    assert args.bar == 1234
    assert Config.get_instance()['foo'] == 'value'
    assert Config.get_instance()['bar'] == 1234

    Config.clear()

    parser = ConfigArgumentParser(filename=filenames[0])
    args = parser.parse_args(
        args=['-c', filenames[0], '--foo', 'value', '--bar', '4321'])

    assert args.foo == 'value'
    assert args.bar == 4321
    assert Config.get_instance()['foo'] == 'value'
    assert Config.get_instance()['bar'] == 4321

    Config.clear()

    parser = ConfigArgumentParser(filename=filenames[0])
    parser.add_argument('--baz', type=float, default=0.1)
    args = parser.parse_args(
        args=['-c', filenames[0], '--foo', 'value', '--bar', '4321'])

    assert args.foo == 'value'
    assert args.bar == 4321
    assert args.baz == 0.1
    assert Config.get_instance()['foo'] == 'value'
    assert Config.get_instance()['bar'] == 4321
    assert Config.get_instance()['baz'] == 0.1

    Config.clear()
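
Taken together, the two tests imply that the fixture file maps foo to 'test' and bar to 1234, that every YAML key becomes both an argparse option and a Config entry, and that command-line values override the file. Below is a self-contained sketch of the same flow; the temporary path and file contents are assumptions reconstructed from the assertions, not the actual test fixture.

import os
import tempfile

from theconf import Config, ConfigArgumentParser

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, 'config.yaml')
    with open(path, 'w') as fp:
        fp.write('foo: test\nbar: 1234\n')  # assumed to mirror the fixture above

    parser = ConfigArgumentParser(filename=path)
    args = parser.parse_args(args=['-c', path, '--bar', '4321'])

    print(args.foo, args.bar)            # test 4321
    print(Config.get_instance()['bar'])  # 4321 -- the CLI value wins over the file
    Config.clear()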
Code example #3
File: search.py Project: sjjdd/fast-autoaugment-1
    tune.track.log(minus_loss=metrics['minus_loss'],
                   top1_valid=metrics['correct'],
                   elapsed_time=gpu_secs,
                   done=True)
    return metrics['correct']


if __name__ == '__main__':
    import json
    from pystopwatch2 import PyStopwatch

    w = PyStopwatch()

    parser = ConfigArgumentParser(conflict_handler='resolve')
    parser.add_argument('--dataroot',
                        type=str,
                        default='/home/noam/data/private/pretrainedmodels',
                        help='torchvision data folder')
    parser.add_argument('--until', type=int, default=5)
    parser.add_argument('--num-op', type=int, default=2)
    parser.add_argument('--num-policy', type=int, default=5)
    parser.add_argument('--num-search', type=int, default=200)
    parser.add_argument('--cv-ratio', type=float, default=0.4)
    parser.add_argument('--decay', type=float, default=-1)
    parser.add_argument('--redis',
                        type=str,
                        default='gpu-cloud-vnode30.dakao.io:23655')
    parser.add_argument('--per-class', action='store_true')
    parser.add_argument('--resume', action='store_true')
    parser.add_argument('--smoke-test', action='store_true', default=True)
    args = parser.parse_args()
Code example #4
    else:
        stats = []

    return {
        'loss': np.mean(losses),
        'prediction': preds,
        'feature': feats,
        'labels': ys,
        'f1_scores': f1s,
        'stats': stats
    }


if __name__ == '__main__':
    parser = ConfigArgumentParser(conflict_handler='resolve')
    parser.add_argument('--load', type=bool, default=False)
    parser.add_argument('--name', type=str, default='test')
    parser.add_argument('--eval', type=bool, default=False)
    parser.add_argument('--extdata', type=int, default=1)
    # parser.add_argument('--dump', type=str, default=None, help='config dump filepath')
    parsed_args = parser.parse_args()
    print(C.get_instance())

    writer = SummaryWriter('asset/log2/%s' % C.get()['name'])

    _, _, _, ids_test = get_dataset()
    d_train, d_valid, d_cvalid, d_tests = get_dataloaders(
        tests_aug=C.get()['eval'])

    models = {
        'resnet34': Resnet34,
Code example #5
                            'model': model.state_dict()
                        },
                        save_path.replace(
                            '.pth', '_e%d_top1_%.3f_%.3f' %
                            (epoch, rs['train']['top1'], rs['test']['top1']) +
                            '.pth'))

    del model

    result['top1_test'] = best_top1
    return result


if __name__ == '__main__':
    parser = ConfigArgumentParser(conflict_handler='resolve')
    parser.add_argument('--tag', type=str, default='')
    parser.add_argument('-d',
                        '--dataroot',
                        type=str,
                        default='cifar-10-batches-py',
                        help='torchvision data folder')
    parser.add_argument('--save', type=str, default='')
    parser.add_argument('--cv-ratio', type=float, default=0.15)
    parser.add_argument('--cv', type=int, default=0)
    parser.add_argument('--only-eval', action='store_true')
    parser.add_argument('-n',
                        '--name',
                        help='Add optional string to save files.',
                        type=str)
    args = parser.parse_args()
Code example #6
def main():
    w = PyStopwatch()

    parser = ConfigArgumentParser(conflict_handler="resolve")
    parser.add_argument(
        "--dataroot",
        type=str,
        default="/data/private/pretrainedmodels",
        help="torchvision data folder",
    )
    parser.add_argument("--until", type=int, default=5)
    parser.add_argument("--num-op", type=int, default=2)
    parser.add_argument("--num-policy", type=int, default=5)
    parser.add_argument("--num-search", type=int, default=200)
    parser.add_argument("--cv-ratio", type=float, default=0.4)
    parser.add_argument("--decay", type=float, default=-1)
    parser.add_argument("--redis", type=str, default="gpu-cloud-vnode30.dakao.io:23655")
    parser.add_argument("--per-class", action="store_true")
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--smoke-test", action="store_true")
    args = parser.parse_args()

    if args.decay > 0:
        logger.info("decay=%.4f" % args.decay)
        C.get()["optimizer"]["decay"] = args.decay

    add_filehandler(
        logger,
        os.path.join(
            "models",
            "%s_%s_cv%.1f.log"
            % (C.get()["dataset"], C.get()["model"]["type"], args.cv_ratio),
        ),
    )
    logger.info("configuration...")
    logger.info(json.dumps(C.get().conf, sort_keys=True, indent=4))
    logger.info("initialize ray...")
    ray.init(address=args.redis)

    num_result_per_cv = 10
    cv_num = 5
    copied_c = copy.deepcopy(C.get().conf)

    logger.info(
        "search augmentation policies, dataset=%s model=%s"
        % (C.get()["dataset"], C.get()["model"]["type"])
    )
    logger.info(
        "----- Train without Augmentations cv=%d ratio(test)=%.1f -----"
        % (cv_num, args.cv_ratio)
    )
    w.start(tag="train_no_aug")
    paths = [
        _get_path(
            C.get()["dataset"],
            C.get()["model"]["type"],
            "ratio%.1f_fold%d" % (args.cv_ratio, i),
        )
        for i in range(cv_num)
    ]
    print(paths)
    reqs = [
        train_model.remote(
            copy.deepcopy(copied_c),
            args.dataroot,
            C.get()["aug"],
            args.cv_ratio,
            i,
            save_path=paths[i],
            skip_exist=True,
        )
        for i in range(cv_num)
    ]

    tqdm_epoch = tqdm(range(C.get()["epoch"]))
    is_done = False
    for epoch in tqdm_epoch:
        while True:
            epochs_per_cv = OrderedDict()
            for cv_idx in range(cv_num):
                try:
                    latest_ckpt = torch.load(paths[cv_idx])
                    if "epoch" not in latest_ckpt:
                        epochs_per_cv["cv%d" % (cv_idx + 1)] = C.get()["epoch"]
                        continue
                    epochs_per_cv["cv%d" % (cv_idx + 1)] = latest_ckpt["epoch"]
                except Exception as e:
                    continue
            tqdm_epoch.set_postfix(epochs_per_cv)
            if (
                len(epochs_per_cv) == cv_num
                and min(epochs_per_cv.values()) >= C.get()["epoch"]
            ):
                is_done = True
            if len(epochs_per_cv) == cv_num and min(epochs_per_cv.values()) >= epoch:
                break
            time.sleep(10)
        if is_done:
            break

    logger.info("getting results...")
    pretrain_results = ray.get(reqs)
    for r_model, r_cv, r_dict in pretrain_results:
        logger.info(
            "model=%s cv=%d top1_train=%.4f top1_valid=%.4f"
            % (r_model, r_cv + 1, r_dict["top1_train"], r_dict["top1_valid"])
        )
    logger.info("processed in %.4f secs" % w.pause("train_no_aug"))

    if args.until == 1:
        sys.exit(0)

    logger.info("----- Search Test-Time Augmentation Policies -----")
    w.start(tag="search")

    ops = augment_list(False)
    space = {}
    for i in range(args.num_policy):
        for j in range(args.num_op):
            space["policy_%d_%d" % (i, j)] = hp.choice(
                "policy_%d_%d" % (i, j), list(range(0, len(ops)))
            )
            space["prob_%d_%d" % (i, j)] = hp.uniform("prob_%d_ %d" % (i, j), 0.0, 1.0)
            space["level_%d_%d" % (i, j)] = hp.uniform(
                "level_%d_ %d" % (i, j), 0.0, 1.0
            )

    final_policy_set = []
    total_computation = 0
    reward_attr = "top1_valid"  # top1_valid or minus_loss
    for _ in range(1):  # run multiple times.
        for cv_fold in range(cv_num):
            name = "search_%s_%s_fold%d_ratio%.1f" % (
                C.get()["dataset"],
                C.get()["model"]["type"],
                cv_fold,
                args.cv_ratio,
            )
            print(name)

            # def train(augs, rpt):
            def train(config, reporter):
                return eval_tta(
                    copy.deepcopy(copied_c), config, reporter, num_class, get_model, get_dataloaders
                )

            register_trainable(name, train)
            algo = HyperOptSearch(
                space, max_concurrent=4 * 20, metric=reward_attr, mode="max"
            )

            results = run(
                train,
                name=name,
                config={
                    "dataroot": args.dataroot,
                    "save_path": paths[cv_fold],
                    "cv_ratio_test": args.cv_ratio,
                    "cv_fold": cv_fold,
                    "num_op": args.num_op,
                    "num_policy": args.num_policy,
                },
                num_samples=4 if args.smoke_test else args.num_search,
                resources_per_trial={"gpu": 1},
                stop={"training_iteration": args.num_policy},
                search_alg=algo,
                scheduler=None,
                verbose=0,
                queue_trials=True,
                resume=args.resume,
                raise_on_failed_trial=False,
            )
            print()
            df = results.results_df

            import pickle

            with open("results.pickle", "wb") as fp:
                pickle.dump(results, fp)
            df.to_csv("df.csv")

            results = df.sort_values(by=reward_attr, ascending=False)
            # results = [x for x in results if x.last_result is not None]
            # results = sorted(results, key=lambda x: x.last_result[reward_attr], reverse=True)

            # calculate computation usage
            for _, result in results.iterrows():
                total_computation += result["elapsed_time"]

            for _, result in results.iloc[:num_result_per_cv].iterrows():
                final_policy = policy_decoder(
                    result, args.num_policy, args.num_op, prefix="config."
                )
                logger.info(
                    "loss=%.12f top1_valid=%.4f %s"
                    % (result["minus_loss"], result["top1_valid"], final_policy)
                )

                final_policy = remove_deplicates(final_policy)
                final_policy_set.extend(final_policy)

    logger.info(json.dumps(final_policy_set))
    logger.info("final_policy=%d" % len(final_policy_set))
    logger.info(
        "processed in %.4f secs, gpu hours=%.4f"
        % (w.pause("search"), total_computation / 3600.0)
    )
    logger.info(
        "----- Train with Augmentations model=%s dataset=%s aug=%s ratio(test)=%.1f -----"
        % (C.get()["model"]["type"], C.get()["dataset"], C.get()["aug"], args.cv_ratio)
    )
    w.start(tag="train_aug")

    num_experiments = 5
    default_path = [
        _get_path(
            C.get()["dataset"],
            C.get()["model"]["type"],
            "ratio%.1f_default%d" % (args.cv_ratio, _),
        )
        for _ in range(num_experiments)
    ]
    augment_path = [
        _get_path(
            C.get()["dataset"],
            C.get()["model"]["type"],
            "ratio%.1f_augment%d" % (args.cv_ratio, _),
        )
        for _ in range(num_experiments)
    ]
    reqs = [
        train_model.remote(
            copy.deepcopy(copied_c),
            args.dataroot,
            C.get()["aug"],
            0.0,
            0,
            save_path=default_path[_],
            skip_exist=True,
        )
        for _ in range(num_experiments)
    ] + [
        train_model.remote(
            copy.deepcopy(copied_c),
            args.dataroot,
            final_policy_set,
            0.0,
            0,
            save_path=augment_path[_],
        )
        for _ in range(num_experiments)
    ]

    tqdm_epoch = tqdm(range(C.get()["epoch"]))
    is_done = False
    for epoch in tqdm_epoch:
        while True:
            epochs = OrderedDict()
            for exp_idx in range(num_experiments):
                try:
                    if os.path.exists(default_path[exp_idx]):
                        latest_ckpt = torch.load(default_path[exp_idx])
                        epochs["default_exp%d" % (exp_idx + 1)] = latest_ckpt["epoch"]
                except:
                    pass
                try:
                    if os.path.exists(augment_path[exp_idx]):
                        latest_ckpt = torch.load(augment_path[exp_idx])
                        epochs["augment_exp%d" % (exp_idx + 1)] = latest_ckpt["epoch"]
                except:
                    pass

            tqdm_epoch.set_postfix(epochs)
            if (
                len(epochs) == num_experiments * 2
                and min(epochs.values()) >= C.get()["epoch"]
            ):
                is_done = True
            if len(epochs) == num_experiments * 2 and min(epochs.values()) >= epoch:
                break
            time.sleep(10)
        if is_done:
            break

    logger.info("getting results...")
    final_results = ray.get(reqs)

    for train_mode in ["default", "augment"]:
        avg = 0.0
        for _ in range(num_experiments):
            r_model, r_cv, r_dict = final_results.pop(0)
            logger.info(
                "[%s] top1_train=%.4f top1_test=%.4f"
                % (train_mode, r_dict["top1_train"], r_dict["top1_test"])
            )
            avg += r_dict["top1_test"]
        avg /= num_experiments
        logger.info(
            "[%s] top1_test average=%.4f (#experiments=%d)"
            % (train_mode, avg, num_experiments)
        )
    logger.info("processed in %.4f secs" % w.pause("train_aug"))

    logger.info(w)
Code example #7
File: sample.py Project: sublee/theconf
# -*- coding: utf-8 -*-
from __future__ import absolute_import
import os
import sys

base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(base_dir)
from theconf import Config as C
from theconf import ConfigArgumentParser

if __name__ == '__main__':
    parser = ConfigArgumentParser(conflict_handler='resolve')
    parser.add_argument('--added',
                        type=str,
                        default='NOT_EXIST_CONFIG',
                        help='ADDED_FROM_ARGPARSER')
    parser.add_argument('--dump',
                        type=str,
                        default=None,
                        help='config dump filepath')
    parsed_args = parser.parse_args()
    print(parsed_args)
    print(C.get().dump())

    if parsed_args.dump:
        C.get().dump(parsed_args.dump)
        print('dumped at', parsed_args.dump)
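
For reference, sample.py is normally driven through the -c option that ConfigArgumentParser itself provides, as the command line shown later in example #22 also illustrates. A hedged sketch of invoking it programmatically; the YAML content, key name, and file locations are illustrative assumptions.

import os
import subprocess
import sys
import tempfile

with tempfile.TemporaryDirectory() as tmpdir:
    conf = os.path.join(tmpdir, 'sample.yaml')
    with open(conf, 'w') as fp:
        fp.write('dataset: cifar10\n')  # any YAML mapping; its keys become Config entries

    # Roughly: python sample.py -c sample.yaml --added from-cli --dump dumped.yaml
    subprocess.run([sys.executable, 'sample.py', '-c', conf,
                    '--added', 'from-cli',
                    '--dump', os.path.join(tmpdir, 'dumped.yaml')],
                   check=True)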
Code example #8
                            },
                            'optimizer': optimizer.state_dict(),
                            'model': model.state_dict(),
                            'ema':
                            ema.state_dict() if ema is not None else None,
                        }, save_path)

    del model

    result['top1_test'] = best_top1
    return result


if __name__ == '__main__':
    parser = ConfigArgumentParser(conflict_handler='resolve')
    parser.add_argument('--tag', type=str, default='')
    parser.add_argument('--dataroot',
                        type=str,
                        default='/data/private/pretrainedmodels',
                        help='torchvision data folder')
    parser.add_argument('--save', type=str, default='test.pth')
    parser.add_argument('--cv-ratio', type=float, default=0.0)
    parser.add_argument('--cv', type=int, default=0)
    parser.add_argument('--local_rank', type=int, default=-1)
    parser.add_argument('--evaluation-interval', type=int, default=5)
    parser.add_argument('--only-eval', action='store_true')
    args = parser.parse_args()

    assert (
        args.only_eval and args.save
    ) or not args.only_eval, 'checkpoint path not provided in evaluation mode.'
Code example #9
    return result


def reproducibility(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


if __name__ == '__main__':
    parser = ConfigArgumentParser(conflict_handler='resolve')
    parser.add_argument('--tag', type=str, default='')
    parser.add_argument('--dataroot',
                        type=str,
                        default='../data/',
                        help='torchvision data folder')
    parser.add_argument('--save', type=str, default='weights/test.pth')
    parser.add_argument('--cv-ratio', type=float, default=0.0)
    parser.add_argument('--cv', type=int, default=0)
    parser.add_argument('--horovod', action='store_true')
    parser.add_argument('--only-eval', action='store_true')
    parser.add_argument('--seed', type=int, default=0)
    args = parser.parse_args()

    reproducibility(args.seed)

    assert not (args.horovod and args.only_eval
Code example #10
                        'train': rs['train'].get_dict(),
                        'valid': rs['valid'].get_dict(),
                        'test': rs['test'].get_dict(),
                    },
                    'optimizer': optimizer.state_dict(),
                    'model': model.state_dict()
                }, filename)

    del model

    return result


if __name__ == '__main__':
    parser = ConfigArgumentParser(conflict_handler='resolve')
    parser.add_argument('--tag', type=str, default='')
    parser.add_argument('--dataroot', type=str, default='/usr/share/bind_mount/data/cifar_100')
    parser.add_argument('--save', type=str, default='/app/results/checkpoints')
    parser.add_argument('--cv-ratio', type=float, default=0.0)
    parser.add_argument('--cv', type=int, default=0)
    parser.add_argument('--decay', type=float, default=-1)
    parser.add_argument('--horovod', action='store_true')
    parser.add_argument('--only-eval', action='store_true')
    args = parser.parse_args()
    print(args)

    assert not (args.horovod and args.only_eval), 'can not use horovod when evaluation mode is enabled.'
    assert (args.only_eval and not args.save) or not args.only_eval, 'checkpoint path not provided in evaluation mode.'

    if args.decay > 0:
        logger.info('decay reset=%.8f' % args.decay)
Code example #11
                                'train': rs['train'].get_dict(),
                                'valid': rs['valid'].get_dict(),
                                'test': rs['test'].get_dict(),
                            },
                            'optimizer': optimizer.state_dict(),
                            'model': model.state_dict()
                        }, save_path)

    del model

    return result


if __name__ == '__main__':
    parser = ConfigArgumentParser(conflict_handler='resolve')
    parser.add_argument('--tag', type=str, default='')
    parser.add_argument('--dataroot',
                        type=str,
                        default='',
                        help='torchvision data folder')
    parser.add_argument('--save', type=str, default='')
    parser.add_argument('--pretrained', type=str,
                        default='')  # path of the pretrained model to load
    parser.add_argument('--cv_ratio', type=float, default=0.0)
    parser.add_argument('--cv', type=int, default=0)
    parser.add_argument('--decay', type=float, default=-1)
    parser.add_argument('--only_eval', action='store_true')
    args = parser.parse_args()

    if args.decay > 0:
        logger.info('decay reset=%.8f' % args.decay)
Code example #12
                            cv_ratio_test,
                            cv_fold,
                            save_path=save_path,
                            only_eval=skip_exist,
                            reduced=reduced)
    return C.get()['model']['type'], cv_fold, result


if __name__ == '__main__':
    import json
    from pystopwatch2 import PyStopwatch
    w = PyStopwatch()

    parser = ConfigArgumentParser(conflict_handler='resolve')
    parser.add_argument('--dataroot',
                        type=str,
                        default='/mnt/ssd/data/',
                        help='torchvision data folder')
    parser.add_argument('--until', type=int, default=5)
    parser.add_argument('--num-op', type=int, default=2)
    parser.add_argument('--num-policy', type=int, default=5)
    parser.add_argument('--num-search', type=int, default=200)
    parser.add_argument('--cv-ratio', type=float, default=0.40)
    parser.add_argument('--decay', type=float, default=-1)
    parser.add_argument('--redis', type=str)
    parser.add_argument('--per-class', action='store_true')
    parser.add_argument('--resume', action='store_true')
    parser.add_argument('--smoke-test', action='store_true')
    parser.add_argument('--cv-num', type=int, default=1)
    parser.add_argument('--exp_name', type=str)

    parser.add_argument('--lstm-size', type=int, default=100)
Code example #13
def prepare() -> argparse.Namespace:
    parser = ConfigArgumentParser(conflict_handler='resolve')
    # parser.add_argument('--dataroot', type=str, default='~/datasets', help='torchvision data folder')
    parser.add_argument('--until', type=int, default=5)
    parser.add_argument('--num_fold', type=int, default=5)
    parser.add_argument('--num_result_per_fold', type=int, default=10)
    parser.add_argument('--num_op', type=int, default=2)
    parser.add_argument('--num_policy', type=int, default=5)
    parser.add_argument('--num_search', type=int, default=200)
    parser.add_argument('--retrain_times', type=int, default=5)
    parser.add_argument('--cv_ratio', type=float, default=0.4)
    parser.add_argument('--decay', type=float, default=-1)
    parser.add_argument('--redis', type=str, default='')
    # parser.add_argument('--per_class', action='store_true')
    parser.add_argument('--resume', action='store_true')
    parser.add_argument('--smoke_test', action='store_true')
    args: argparse.Namespace = parser.parse_args()

    add_filehandler(
        logger,
        '%s_%s_cv%.1f.log' % (Config.get()['dataset'],
                              Config.get()['model']['type'], args.cv_ratio))

    logger.info('args type: %s' % str(type(args)))

    global EXEC_ROOT, MODEL_ROOT, MODEL_PATHS, DATASET_ROOT

    EXEC_ROOT = os.getcwd()  # fast-autoaugment/experiments/xxx
    logger.info('EXEC_ROOT: %s' % EXEC_ROOT)
    MODEL_ROOT = os.path.join(
        EXEC_ROOT, 'models')  # fast-autoaugment/experiments/xxx/models
    logger.info('MODEL_ROOT: %s' % MODEL_ROOT)

    DATASET_ROOT = os.path.abspath(
        os.path.join(os.path.expanduser('~'), 'datasets',
                     Config.get()['dataset'].lower()))  # ~/datasets/cifar10
    logger.info('DATASET_ROOT: %s' % DATASET_ROOT)

    _check_directory(MODEL_ROOT)
    _check_directory(DATASET_ROOT)

    MODEL_PATHS = [
        _get_model_path(
            dataset=Config.get()['dataset'],
            model=Config.get()['model']['type'],
            config='ratio%.1f_fold%d' % (args.cv_ratio, i)  # without_aug
        ) for i in range(args.num_fold)
    ]
    print('MODEL_PATHS:', MODEL_PATHS)
    logger.info('MODEL_PATHS: %s' % MODEL_PATHS)

    if args.decay > 0:
        logger.info('decay=%.4f' % args.decay)
        Config.get()['optimizer']['decay'] = args.decay

    logger.info('configuration...')
    logger.info(json.dumps(Config.get().conf, sort_keys=True, indent=4))
    logger.info('initialize ray...')
    # ray.init(redis_address=args.redis)
    address_info = ray.init(include_webui=True)
    logger.info('ray initialization: address information:')
    logger.info(str(address_info))
    logger.info('start searching augmentation policies, dataset=%s model=%s' %
                (Config.get()['dataset'], Config.get()['model']['type']))

    return args
Code example #14
    ts = moment.now().format("YYYY_MMDD_HHmm_ss")
    mname = C.get()['model']['type']
    os.makedirs('models', exist_ok=True)
    torch.save(
        model.state_dict(),
        f'models/{mname}__{ts}.pth'
    )
    del model

    result['top1_test'] = best_top1
    return result


if __name__ == '__main__':
    parser = ConfigArgumentParser(conflict_handler='resolve')
    parser.add_argument('--tag', type=str, default='')
    parser.add_argument('--dataroot', type=str, default='/home/tanimu/data/cifar10', help='torchvision data folder')
    parser.add_argument('--save', type=str, default='test.pth')
    parser.add_argument('--cv-ratio', type=float, default=0.0)
    parser.add_argument('--cv', type=int, default=0)
    parser.add_argument('--local_rank', type=int, default=-1)
    parser.add_argument('--evaluation-interval', type=int, default=5)
    parser.add_argument('--only-eval', action='store_true')
    parser.add_argument('--gpu_id', type=int, default=0)
    args = parser.parse_args()

    assert (args.only_eval and args.save) or not args.only_eval, 'checkpoint path not provided in evaluation mode.'

    if not args.only_eval:
        if args.save:
            logger.info('checkpoint will be saved at %s' % args.save)
Code example #15
File: phase2.py Project: zwzhu-d/cores
                        'log': {
                            'train': rs['train'].get_dict(),
                            'test': rs['test'].get_dict(),
                        },
                        'optimizer': optimizer.state_dict(),
                        'state_dict': model.state_dict()
                    }, save_path)

    del model

    return result


if __name__ == '__main__':
    parser = ConfigArgumentParser(conflict_handler='resolve')
    parser.add_argument('--tag', type=str, default='')
    parser.add_argument('--dataroot',
                        type=str,
                        default='../data/',
                        help='torchvision data folder')
    parser.add_argument('--save', type=str, default='')
    parser.add_argument('--decay', type=float, default=-1)
    parser.add_argument('--unsupervised', action='store_true')
    parser.add_argument('--only-eval', action='store_true')
    parser.add_argument('--gpu', type=int, nargs='+', default=None)
    parser.add_argument('--resume', action='store_true')

    args = parser.parse_args()
    print('Unsupervised', args.unsupervised)
    assert (
        args.only_eval and not args.save
Code example #16
    reporter(minus_loss=metrics['minus_loss'],
             top1_valid=metrics['correct'],
             elapsed_time=gpu_secs,
             done=True)
    return metrics['correct']


if __name__ == '__main__':
    import json
    from pystopwatch2 import PyStopwatch
    w = PyStopwatch()  # initialize a stopwatch

    # ? unclear where the '-c xxx.yaml' seen on the command line is defined (examples #1 and #2 suggest ConfigArgumentParser registers -c itself)
    parser = ConfigArgumentParser(conflict_handler='resolve')
    parser.add_argument('--dataroot',
                        type=str,
                        default='/data/private/pretrainedmodels',
                        help='torchvision data folder')
    parser.add_argument('--until', type=int, default=5)  # ?
    parser.add_argument('--num-op', type=int, default=2)  # number of ops in each sub-policy
    parser.add_argument('--num-policy', type=int,
                        default=5)  # each policy contains 5 sub-policies
    parser.add_argument('--num-search', type=int,
                        default=200)  # ? not certain; the paper describes it as the size of the policy set B per Bayesian-optimization round
    parser.add_argument('--cv-ratio', type=float, default=0.4)  # ? cross-validation split ratio
    parser.add_argument('--decay', type=float, default=-1)  # ? probably learning-rate decay
    parser.add_argument('--redis',
                        type=str,
                        default='gpu-cloud-vnode30.dakao.io:23655')  # related to distributed execution
    parser.add_argument('--per-class', action='store_true')  # ?
    parser.add_argument('--resume', action='store_true')  # ? presumably whether to resume from existing model parameters
    parser.add_argument('--smoke-test', action='store_true')  # ?
Code example #17
                            "optimizer": optimizer.state_dict(),
                            "model": model.state_dict(),
                            "ema": ema.state_dict() if ema is not None else None,
                        },
                        save_path,
                    )

    del model

    result["top1_test"] = best_top1
    return result


if __name__ == "__main__":
    parser = ConfigArgumentParser(conflict_handler="resolve")
    parser.add_argument("--tag", type=str, default="")
    parser.add_argument(
        "--dataroot",
        type=str,
        default="/data/private/pretrainedmodels",
        help="torchvision data folder",
    )
    parser.add_argument("--save", type=str, default="test.pth")
    parser.add_argument("--cv-ratio", type=float, default=0.0)
    parser.add_argument("--cv", type=int, default=0)
    parser.add_argument("--local_rank", type=int, default=-1)
    parser.add_argument("--evaluation-interval", type=int, default=5)
    parser.add_argument("--only-eval", action="store_true")
    args = parser.parse_args()

    assert (
Code example #18
                        'log': {
                            'train': rs['train'].get_dict(),
                            'test': rs['test'].get_dict(),
                        },
                        'optimizer': optimizer.state_dict(),
                        'model': model.state_dict()
                    }, save_path)

    del model

    return result


if __name__ == '__main__':
    parser = ConfigArgumentParser(conflict_handler='resolve')
    parser.add_argument('--tag', type=str, default='')
    parser.add_argument('--dataroot',
                        type=str,
                        default='/data/private/pretrainedmodels',
                        help='torchvision data folder')
    parser.add_argument('--save', type=str, default='')
    parser.add_argument('--decay', type=float, default=-1)
    parser.add_argument('--unsupervised', action='store_true')
    parser.add_argument('--only-eval', action='store_true')
    args = parser.parse_args()

    assert (
        args.only_eval and not args.save
    ) or not args.only_eval, 'checkpoint path not provided in evaluation mode.'

    if args.decay > 0:
Code example #19
        for aug in policy:  # aug: [n_op, 3]
            for key_aug in bin:
                if is_same_aug(aug, key_aug):
                    bin[key_aug] += 1
                    in_bin = True
                    cnt += 1
                    break
            if not in_bin:
                bin[rec_tuple(aug)] = 1
                cnt += 1
            in_bin = False
        bins.append(bin)
        print(len(bin))
        # print(bin)
        maxkey = max(bin, key=bin.get)
        print(maxkey)
        print(bin[maxkey])
    np.savez(f"{args.save_path}.npz",
             inputs=inputs.cpu(),
             pols=pols,
             bins=bins)
    print(cnt)


if __name__ == "__main__":
    parser = ConfigArgumentParser(conflict_handler='resolve')
    parser.add_argument('--load_path', type=str)
    parser.add_argument('--save_path', type=str)
    args = parser.parse_args()
    main(args)
Code example #20
                            "test": rs["test"].get_dict(),
                        },
                        "optimizer": optimizer.state_dict(),
                        "model": model.state_dict(),
                    },
                    save_path,
                )

    del model

    return result


if __name__ == "__main__":
    parser = ConfigArgumentParser(conflict_handler="resolve")
    parser.add_argument("--tag", type=str, default="")
    parser.add_argument(
        "--dataroot",
        type=str,
        default="./.data",
        help="torchvision data folder",
    )
    parser.add_argument("--save", type=str, default="")
    parser.add_argument("--method", type=str, choices=["UDA", "IIC"], default="UDA")
    parser.add_argument("--decay", type=float, default=-1)
    parser.add_argument("--unsupervised", action="store_true")
    parser.add_argument("--only-eval", action="store_true")
    parser.add_argument("--alpha", type=float, default=5.0)
    parser.add_argument("--resume", default=False, action="store_true")
    parser.add_argument("--labeled_sample_num", default=4000, type=int)
    args = parser.parse_args()
Code example #21
    gpu_secs = (time.time() - start_t) * torch.cuda.device_count()
    reporter(minus_loss=metrics['minus_loss'],
             top1_valid=metrics['correct'],
             elapsed_time=gpu_secs,
             done=True)
    return metrics['correct']


if __name__ == '__main__':
    import json
    from pystopwatch2 import PyStopwatch
    w = PyStopwatch()

    parser = ConfigArgumentParser(conflict_handler='resolve')
    parser.add_argument('--dataroot',
                        type=str,
                        default='/mnt/hdd1/data/',
                        help='torchvision data folder')
    parser.add_argument('--until', type=int, default=5)
    parser.add_argument('--num-op', type=int, default=2)
    parser.add_argument('--num-policy', type=int, default=5)
    parser.add_argument('--num-search', type=int, default=200)
    parser.add_argument('--cv-ratio', type=float, default=0.4)
    parser.add_argument('--decay', type=float, default=-1)
    parser.add_argument('--redis', type=str)
    parser.add_argument('--per-class', action='store_true')
    parser.add_argument('--resume', action='store_true')
    parser.add_argument('--smoke-test', action='store_true')
    parser.add_argument('--cv-num', type=int, default=1)
    parser.add_argument('--exp_name', type=str)
    parser.add_argument('--rpc', type=int, default=10)
    parser.add_argument('--repeat', type=int, default=1)
Code example #22
File: bart_train.py Project: softsys4ai/athena
    trans_list.append(TRANSFORMATION.noise_poisson)
    trans_list.append(TRANSFORMATION.geo_swirl)
    trans_list.append(TRANSFORMATION.filter_rank)
    trans_list.append(TRANSFORMATION.filter_median)

    return trans_list


if __name__ == '__main__':
    # command:
    # python train.py -c confs/<config_file> --aug <augmentation> --dataroot=<folder storing the dataset> --dataset <dataset> --save <model_file>
    # to evaluate an existing model, use --save and --only-eval
    # e.g.,
    # python train.py -c confs/wresnet28x10_cifar10_b128.yaml --aug fa_reduced_cifar10 --dataroot=data --dataset cifar100 --save cifar100_wres28x10.pth --only-eval
    parser = ConfigArgumentParser(conflict_handler='resolve')
    parser.add_argument('--tag', type=str, default='')
    parser.add_argument('--dataroot',
                        type=str,
                        default='/data/private/pretrainedmodels',
                        help='torchvision data folder')
    parser.add_argument('--save', type=str, default='test.pth')
    parser.add_argument('--cv-ratio', type=float, default=0.0)
    parser.add_argument('--cv', type=int, default=0)
    parser.add_argument('--only-eval', action='store_true')
    parser.add_argument('--selected-combo',
                        type=str,
                        default='revisionES4_ens1')
    args = parser.parse_args()

    assert (
        args.only_eval and args.save
Code example #23
    reporter(minus_loss=metrics['minus_loss'],
             top1_valid=metrics['correct'],
             elapsed_time=gpu_secs,
             done=True)
    return metrics['minus_loss']


if __name__ == '__main__':
    import json
    from pystopwatch2 import PyStopwatch

    w = PyStopwatch()

    parser = ConfigArgumentParser(conflict_handler='resolve')
    parser.add_argument('--dataroot',
                        type=str,
                        default='../data',
                        help='torchvision data folder')
    parser.add_argument('--until', type=int, default=5)
    parser.add_argument('--num-op', type=int, default=2)
    parser.add_argument('--num_cv', type=int, default=5)
    parser.add_argument('--num-policy', type=int, default=5)
    parser.add_argument('--num-search', type=int, default=100)
    parser.add_argument('--cv-ratio', type=float, default=0.4)
    parser.add_argument('--dc_model',
                        type=str,
                        default='pointnetv7',
                        choices=['pointnet', 'pointnetv5', 'pointnetv7'])
    parser.add_argument('--topk', type=int, default=8)
    parser.add_argument('--emd_coeff', type=int, default=10)
    parser.add_argument('--decay', type=float, default=-1)
    parser.add_argument('--random_range', type=float, default=0.3)
Code example #24
def parse_args():
    parser = ConfigArgumentParser(conflict_handler='resolve')
    parser.add_argument('--tag', type=str, default='')
    parser.add_argument('--dataroot',
                        type=str,
                        default='/data/private/pretrainedmodels',
                        help='torchvision data folder')
    parser.add_argument('--save', type=str, default='')
    parser.add_argument('--cv-ratio', type=float, default=0.0)
    parser.add_argument('--cv', type=int, default=0)
    parser.add_argument('--only-eval', action='store_true')
    parser.add_argument('--local_rank', default=None, type=int)
    return parser.parse_args()
Code example #25
                        'log': {
                            'train': rs['train'].get_dict(),
                            'test': rs['test'].get_dict(),
                        },
                        'optimizer': optimizer.state_dict(),
                        'model': model.state_dict()
                    }, save_path)

    del model

    return result


if __name__ == '__main__':
    parser = ConfigArgumentParser(conflict_handler='resolve')
    parser.add_argument('--tag', type=str, default='')
    parser.add_argument('--dataroot',
                        type=str,
                        default='.data',
                        help='torchvision data folder')
    parser.add_argument('--save', type=str, default='')
    parser.add_argument('--decay', type=float, default=-1)
    parser.add_argument('--unsupervised', action='store_true')
    parser.add_argument('--only-eval', action='store_true')
    parser.add_argument('--sample',
                        default='None',
                        type=str,
                        help='sampling strategy')
    parser.add_argument('--train_mode',
                        default='ssl',
                        type=str,
Code example #26
File: search.py Project: VCBE123/fast-autoaugment
    gpu_secs = (time.time() - start_t) * torch.cuda.device_count()
    reporter(minus_loss=metrics['minus_loss'],
             top1_valid=metrics['correct'],
             elapsed_time=gpu_secs,
             done=True)
    return metrics['correct']


if __name__ == '__main__':
    import json
    from pystopwatch2 import PyStopwatch
    w = PyStopwatch()

    parser = ConfigArgumentParser(conflict_handler='resolve')
    parser.add_argument('--dataroot',
                        type=str,
                        default='/data/lirui/pretrainedmodels',
                        help='torchvision data folder')
    parser.add_argument('--until', type=int, default=5)
    parser.add_argument('--num-op', type=int, default=2)
    parser.add_argument('--num-policy', type=int, default=5)
    parser.add_argument('--num-search', type=int, default=200)
    parser.add_argument('--cv-ratio', type=float, default=0.4)
    parser.add_argument('--decay', type=float, default=-1)
    parser.add_argument('--redis', type=str, default='192.168.0.112:23655')
    parser.add_argument('--per-class', action='store_true')
    parser.add_argument('--resume', action='store_true')
    parser.add_argument('--smoke-test', action='store_true')
    args = parser.parse_args()

    if args.decay > 0:
        logger.info('decay=%.4f' % args.decay)