Пример #1
0
Файл: conf.py Проект: zyg11/VKD
    def add_default_args(parser: ArgumentParser):
        """Register the shared command-line arguments on ``parser``.

        Adds the positional ``dataset_name`` argument, the network options
        (``--backbone``, ``--pretrained``) and the remaining options listed
        under "Others", then hands the same parser back.

        Args:
            parser: Parser that is extended in place.

        Returns:
            The ``parser`` object that was passed in.
        """
        parser.add_argument('dataset_name',
                            choices=list(DATASETS.keys()),
                            type=str,
                            help='dataset name')
        # Network
        parser.add_argument('--backbone',
                            type=str,
                            choices=BACKBONES,
                            default=Backbone.RESNET_50,
                            help='Backbone network type.')
        # The previous help text ('No pretraining.') contradicted the flag:
        # it is named "pretrained" and is enabled by default.
        parser.add_argument('--pretrained',
                            type=str2bool,
                            default=True,
                            help='Whether to use pretraining '
                                 '(enabled by default).')

        # Others
        parser.add_argument('--set_determinism', type=str2bool, default=False)
        parser.add_argument('--test_batch', default=32, type=int)
        parser.add_argument('--img_test_batch', default=512, type=int)
        parser.add_argument('--verbose',
                            type=str2bool,
                            default=True,
                            help='Debug mode')
        parser.add_argument('-j', '--workers', default=4, type=int)
        # NOTE(review): --p and --k ship with empty help text; their meaning
        # is not visible from this function, so it is left untouched here.
        parser.add_argument('--p', type=int, default=18, help='')
        parser.add_argument('--k', type=int, default=4, help='')

        parser.add_argument('--num_test_images', type=int, default=8)

        return parser
Пример #2
0
def main(_argv):
    """Run ``measure_model`` for every selected configuration.

    Walks over each model version in ``FLAGS.model_version``, each dataset,
    each train-split label and each method. For every combination the output
    directory is created (if missing) and ``measure_model`` is called with it.

    Datasets, split labels and methods default to everything in ``DATASETS``,
    ``SPLIT_OPTIONS`` and ``METHODS``; each is narrowed when the matching flag
    (``dataset``, ``train_skip``, ``method``) is non-empty.

    Args:
        _argv: Not used by this function.
    """
    # (DATASETS key, step) pairs; applied in order, so a later match wins.
    step_overrides = (
        ("food101", 5),
        ("stanford-dogs", 2),
        ("plant-data", 2),
    )

    datasets = FLAGS.dataset if FLAGS.dataset else DATASETS.keys()
    items = (
        {name: SPLIT_OPTIONS[name] for name in FLAGS.train_skip}
        if FLAGS.train_skip
        else SPLIT_OPTIONS
    )
    methods = (
        [METHODS[name] for name in FLAGS.method]
        if FLAGS.method
        else METHODS.values()
    )

    for model_version in FLAGS.model_version:
        for dataset in datasets:
            for label in items:
                run_name = f"{model_version}-{label}"
                weights_dir = os.path.join(
                    model_folder, f"{model_version}-{dataset}-{label}.pth"
                )
                # An explicit weights file is honoured only when neither the
                # split nor the dataset selection was narrowed by flags.
                narrowed = (
                    FLAGS.weights is None or FLAGS.train_skip or FLAGS.dataset
                )
                if not narrowed:
                    weights_dir = FLAGS.weights

                for method in methods:
                    target_dir = os.path.join(
                        out_folder, dataset, run_name, method
                    )
                    Path(target_dir).mkdir(parents=True, exist_ok=True)

                    step = 1
                    for dataset_key, dataset_step in step_overrides:
                        if dataset == DATASETS[dataset_key]:
                            step = dataset_step

                    measure_model(
                        model_version,
                        dataset,
                        target_dir,
                        weights_dir,
                        device,
                        method=method,
                        step=step,
                    )
Пример #3
0
def main(_argv):
    """Train, save and test a model for every configuration.

    For each model version in ``FLAGS.model_version``, each key of
    ``DATASETS`` and each entry of ``SPLIT_OPTIONS``, the matching training
    function is run, the resulting ``state_dict`` is written to
    ``<model_folder>/<model_version>-<dataset>-<label>.pth`` and
    ``test_model`` is called on that file.

    Args:
        _argv: Not used by this function.

    Raises:
        ValueError: If a requested model version is not covered by any of the
            training branches. Previously this case either failed with an
            unbound ``trained_model`` or silently re-saved the model trained
            in the preceding iteration under the new file name.
    """
    for model_version in FLAGS.model_version:
        # Pick the training routine once per model version so that an
        # unsupported version fails before any training work is started.
        # NOTE(review): the [:1] slice matches only AVAILABLE_MODELS[0], so
        # AVAILABLE_MODELS[1] is handled by no branch - confirm whether
        # [:2] was intended.
        if model_version in AVAILABLE_MODELS[:1]:
            train_fn = train_resnet
        elif model_version == AVAILABLE_MODELS[2]:
            train_fn = train_efficientnet
        elif model_version == AVAILABLE_MODELS[3]:
            train_fn = train_densenet
        else:
            raise ValueError(
                f"No training routine available for model version "
                f"'{model_version}'"
            )

        for dataset in DATASETS.keys():
            for label, skip in SPLIT_OPTIONS.items():
                print(
                    f"Training {model_version} model with {label} of data ({dataset})"
                )
                trained_model = train_fn(dataset,
                                         model_version,
                                         data_dir=data_dir,
                                         skip=skip)

                # Single source of truth for the checkpoint location; it is
                # used for the log line, the save and the test run.
                weights_dir = os.path.join(
                    model_folder, f"{model_version}-{dataset}-{label}.pth")

                print(f"Saving model to '{weights_dir}'")
                torch.save(trained_model.state_dict(), weights_dir)

                print(
                    f"Testing {model_version} model trained on {label} of data ({dataset})"
                )
                test_model(model_version, dataset, out_folder, weights_dir,
                           device, label)
Пример #4
0
def main(_argv):
    """Run ``measure_filter_model`` for every selected configuration.

    Walks over each model version in ``FLAGS.model_version``, each dataset,
    each train-split label and each method. For every combination the output
    directory is created (if missing) and ``measure_filter_model`` is called
    with it, forwarding the ``use_infidelity``, ``use_sensitivity`` and
    ``render_results`` flags.

    Datasets, split labels and methods default to everything in ``DATASETS``,
    ``SPLIT_OPTIONS`` and ``METHODS``; each is narrowed when the matching flag
    (``dataset``, ``train_skip``, ``method``) is non-empty.

    Args:
        _argv: Not used by this function.
    """
    # (DATASETS key, step) pairs; applied in order, so a later match wins.
    step_overrides = (
        ("food101", 1300),
        ("stanford-dogs", 250),
        ("plant-data", 250),
    )

    datasets = FLAGS.dataset if FLAGS.dataset else DATASETS.keys()
    items = (
        {name: SPLIT_OPTIONS[name] for name in FLAGS.train_skip}
        if FLAGS.train_skip
        else SPLIT_OPTIONS
    )
    methods = (
        [METHODS[name] for name in FLAGS.method]
        if FLAGS.method
        else METHODS.values()
    )

    # Each comma-separated value is converted to int and decremented by one
    # (presumably 1-based ids on the command line - confirm with the callee).
    ids = (
        None
        if FLAGS.image_ids is None
        else [int(token) - 1 for token in FLAGS.image_ids.split(',')]
    )

    for model_version in FLAGS.model_version:
        for dataset in datasets:
            for label in items:
                run_name = f"{model_version}-{label}"
                weights_dir = os.path.join(
                    model_folder, f"{model_version}-{dataset}-{label}.pth"
                )
                # An explicit weights file is honoured only when neither the
                # split nor the dataset selection was narrowed by flags.
                narrowed = (
                    FLAGS.weights is None or FLAGS.train_skip or FLAGS.dataset
                )
                if not narrowed:
                    weights_dir = FLAGS.weights

                for method in methods:
                    target_dir = os.path.join(
                        out_folder, dataset, run_name, method
                    )
                    Path(target_dir).mkdir(parents=True, exist_ok=True)

                    step = 20
                    for dataset_key, dataset_step in step_overrides:
                        if dataset == DATASETS[dataset_key]:
                            step = dataset_step

                    measure_filter_model(
                        model_version,
                        dataset,
                        target_dir,
                        weights_dir,
                        device,
                        method=method,
                        step=step,
                        use_infidelity=FLAGS.use_infidelity,
                        use_sensitivity=FLAGS.use_sensitivity,
                        render=FLAGS.render_results,
                        ids=ids,
                    )
Пример #5
0
import warnings

# Suppress all warnings from this point on.
warnings.filterwarnings("ignore")

# Module-level alias for the flag container populated by the definitions below.
FLAGS = flags.FLAGS

# Model versions to process; defaults to a single entry, "resnet18".
flags.DEFINE_multi_enum(
    name="model_version",
    default=["resnet18"],
    enum_values=AVAILABLE_MODELS,
    help=f"Model version {AVAILABLE_MODELS}",
)

# Dataset selection; defaults to an empty list.
flags.DEFINE_multi_enum(
    name="dataset",
    default=[],
    enum_values=list(DATASETS.keys()),
    help=f"(optional) Dataset name, one of available datasets: {list(DATASETS.keys())}",
)

# Train-split selection; defaults to an empty list.
flags.DEFINE_multi_enum(
    name="train_skip",
    default=[],
    enum_values=list(SPLIT_OPTIONS.keys()),
    help=f"(optional) version of the train dataset size: {list(SPLIT_OPTIONS.keys())}",
)

# Method selection; defaults to an empty list.
flags.DEFINE_multi_enum(
    name="method",
    default=[],
    enum_values=list(METHODS.keys()),
    help=f"(optional) select on of available methods: {list(METHODS.keys())}",
)
flags.DEFINE_string(