예제 #1
0
def get_models_to_evaluate():
    """
    Find the trained models to evaluate.

    Globs FLAGS.logdir with the pattern FLAGS.match. Each matching log
    directory name is parsed into its source/target/model/method parts and
    paired with the model directory of the same name under FLAGS.modeldir.

    Log directory names (without suffix) must have one of these shapes:
        source-target-model
        source-target-model-method
        source-target-model-method-num
        source-target-model-num
    The fourth component is used as the method name only if it is contained
    in ``methods``. That name is defined elsewhere; presumably it is the
    collection of known adaptation methods, so confirm this. Otherwise the
    component is ignored (e.g. a debug/run number) and the method is "none".

    Returns: [(log_dir, model_dir, source, target, model_name, method_name), ...]
        One tuple per matching log directory, ordered by log directory path
        so the result is reproducible.

    Raises:
        AssertionError: if a name has the wrong number of components, the
            model directory does not exist, or the model/source/target name
            is unknown. These checks use ``assert`` and are skipped when
            Python runs with -O.
    """
    models_to_evaluate = []

    # glob() yields entries in filesystem order; sort for a stable order.
    for log_dir in sorted(pathlib.Path(FLAGS.logdir).glob(FLAGS.match)):
        items = log_dir.stem.split("-")
        # The original "len >= 3 or len <= 5" was always true, so malformed
        # names slipped through and crashed later with a NameError.
        assert 3 <= len(items) <= 5, \
            "name should be one of source-target-model-{-method{-num,},-num,}"

        source, target, model_name = items[:3]
        method_name = "none"

        if len(items) >= 4:
            keyword = items[3]
            # A fourth component that isn't a known method is probably a
            # debug number, which we don't care about.
            if keyword in methods:
                method_name = keyword

        model_dir = os.path.join(FLAGS.modeldir, log_dir.stem)
        assert os.path.exists(model_dir), "Model does not exist "+str(model_dir)

        assert model_name in models.names(), "Unknown model "+str(model_name)
        assert source in load_datasets.names(), "Unknown source "+str(source)
        assert target in [""]+load_datasets.names(), "Unknown target "+str(target)

        models_to_evaluate.append((str(log_dir), model_dir, source, target,
            model_name, method_name))

    return models_to_evaluate
예제 #2
0
if __name__ == '__main__':
    # Script entry point: build the command-line interface for evaluating a
    # trained network.
    parser = argparse.ArgumentParser(description="Network evaluation example")
    # data
    parser.add_argument('-b', '--batch-size', type=int, default=64)
    # NOTE(review): presumably the number of data-loading worker processes — confirm.
    parser.add_argument('-j', '--workers', type=int, default=4)
    # NOTE(review): presumably the input image height/width fed to the network — confirm.
    parser.add_argument('--height', type=int, default=224)
    parser.add_argument('--width', type=int, default=224)
    # Random seed.
    parser.add_argument('--seed', type=int, default=1)
    # NOTE(review): presumably k for top-k reporting — confirm against the evaluation code.
    parser.add_argument('--topk', '-t', type=int, default=5)
    # model
    # --arch only accepts names returned by models.names().
    parser.add_argument('-a',
                        '--arch',
                        type=str,
                        default='resnet50',
                        choices=models.names())
    # No default: args.weights is None when --weights is not given.
    parser.add_argument('--weights',
                        '-w',
                        type=str,
                        metavar='PATH',
                        help='path to the checkpoint')
    parser.add_argument('--ann_file',
                        type=str,
                        metavar='PATH',
                        help="path to the annotation file")
    parser.add_argument('--data_dir',
                        type=str,
                        metavar='PATH',
                        help="path to the data folder")
    parser.add_argument('--print-freq',
                        '-p',
예제 #3
0
from absl import app
from absl import flags
from absl import logging

import models
import load_datasets

from metrics import Metrics
from checkpoints import CheckpointManager
from file_utils import last_modified_number, write_finished
from gpu_memory import set_gpu_memory

# Module-wide handle to the absl flag values (parsed when the app starts).
FLAGS = flags.FLAGS

# --- Model selection and output locations --------------------------------
# A None default leaves the flag unset unless it is given on the command line.
flags.DEFINE_enum(
    name="model",
    default=None,
    enum_values=models.names(),
    help="What model type to use")
flags.DEFINE_string(
    name="modeldir",
    default="models",
    help="Directory for saving model files")
flags.DEFINE_string(
    name="logdir",
    default="logs",
    help="Directory for saving log files")

# --- Domain adaptation setup ---------------------------------------------
flags.DEFINE_enum(
    name="method",
    default=None,
    enum_values=["none", "dann", "pseudo", "instance"],
    help="What method of domain adaptation to perform (or none)")
flags.DEFINE_enum(
    name="source",
    default=None,
    enum_values=load_datasets.names(),
    help="What dataset to use as the source")
# "" is an accepted value and the default; presumably it means "no target
# dataset" — confirm.
flags.DEFINE_enum(
    name="target",
    default="",
    enum_values=[""] + load_datasets.names(),
    help="What dataset to use as the target")

# --- Optimization --------------------------------------------------------
flags.DEFINE_integer(
    name="steps",
    default=80000,
    help="Number of training steps to run")
flags.DEFINE_float(
    name="lr",
    default=0.001,
    help="Learning rate for training")
flags.DEFINE_float(
    name="lr_domain_mult",
    default=1.0,
    help="Learning rate multiplier for training domain classifier")
flags.DEFINE_float(
    name="lr_target_mult",
    default=0.5,
    help="Learning rate multiplier for training target classifier")
flags.DEFINE_float("gpumem", 3350,