def run_test(folder_path, override_dict, test_path, snapshot_iter, is_large,
             save_img_data):
    """
    Rebuild a Pipeline from a results folder and run its full test.

    The options pickled in ``<folder_path>/PARAM.p`` are combined with
    *override_dict* through recursive_merge_dicts. A Pipeline is then built on
    *folder_path* with hyper-parameter auto-saving and logging disabled, and
    run_full_test_from_checkpoint is invoked inside the Pipeline's graph with
    the remaining arguments.

    Args:
        folder_path: directory holding PARAM.p; also used as model_dir.
        override_dict: options merged over the saved ones.
        test_path, snapshot_iter, is_large, save_img_data: forwarded unchanged
            to run_full_test_from_checkpoint.

    NOTE(review): pickle.load can execute arbitrary code from the file; only
    point this at trusted result folders.
    """
    print("Folder path: %s" % folder_path)

    param_path = os.path.join(folder_path, "PARAM.p")
    with open(param_path, 'rb') as param_file:
        saved_opt = pickle.load(param_file)

    merged_opt = recursive_merge_dicts(saved_opt, override_dict)
    pipeline = Pipeline(
        None,
        merged_opt,
        model_dir=folder_path,
        auto_save_hyperparameters=False,
        use_logging=False,
    )
    print(pipeline.opt)

    test_kwargs = {
        "test_path": test_path,
        "snapshot_iter": snapshot_iter,
        "is_large": is_large,
        "save_img_data": save_img_data,
    }
    with pipeline.graph.as_default():
        session = pipeline.create_session()
        pipeline.run_full_test_from_checkpoint(session, **test_kwargs)
示例#2
0
文件: prepare.py 项目: liek51/civet
def insert_tasks(PL, task_file):
    """
    Insert one Job per managed task of a pipeline into the session and commit
    them as a single batch.

    Args:
        PL: parsed pipeline object. Only ``PL.name``, ``PL.log_dir`` and
            ``PL.prepare_managed_tasks()`` are used here. The task dicts must
            be ordered so that every task appears after the tasks it depends
            on, because dependencies are resolved against jobs already created.
        task_file: destination reported in the progress and log messages.

    Returns:
        int: the number of tasks inserted.

    If any task cannot be added (including a dependency on a task that has not
    been processed yet), the error is logged, printed to stderr, and the
    process exits with status 6 before the session is committed.
    """
    pipeline = Pipeline(PL.name, PL.log_dir)
    logging.debug("Pipeline is: %s", pipeline)

    task_list = PL.prepare_managed_tasks()
    logging.debug("Task list is: %s", [x['name'] for x in task_list])

    # Each task stores its dependencies as a list of other task names, but Job
    # needs references to Job objects already added to the session. Build a
    # task['name'] -> Job map as jobs are created so later tasks can resolve
    # their dependencies.
    deps_to_job = {}
    print("  Inserting tasks into {}".format(task_file))
    logging.info("Inserting tasks into %s", task_file)
    try:
        for task in task_list:
            print("    -> {}".format(task['name']))
            # Read the key outside the inner try so a task missing its
            # 'dependencies' entry is not misreported as an unresolved
            # dependency.
            dep_names = task['dependencies']
            try:
                dependencies = [deps_to_job[d] for d in dep_names]
            except KeyError as e:
                logging.exception("Key error processing dependencies")
                msg = ("Task {} depends on a task that hasn't been "
                       "processed ({}). Check your Pipeline XML".format(
                           task['name'], e.args[0]))
                raise Exception(msg)
            job = Job(pipeline, task['name'], task['threads'],
                      task['stdout_path'], task['stderr_path'],
                      task['script_path'], task['epilogue_path'], task['mem'],
                      task['email_list'], task['mail_options'],
                      task['batch_env'], dependencies, task['queue'],
                      task['walltime'])

            deps_to_job[task['name']] = job
            logging.debug("Adding job %s (log dir: %s) to session",
                          job.job_name, job.pipeline.log_directory)
            Session.add(job)
    except Exception as e:
        logging.exception("Error inserting tasks into database")
        print("Error inserting tasks into database: {}".format(e),
              file=sys.stderr)
        sys.exit(6)

    # Only reached when every job was added to the session without an
    # exception, so the batch is committed all-or-nothing.
    Session.commit()

    logging.info("  %s tasks have been inserted into task file %s; "
                 "(log dir: %s)", len(task_list), task_file, PL.log_dir)

    return len(task_list)
 def __init__(self):
     """
     Build the extension loader and its initial pipeline.

     Sets the algorithm/method directories, scans them via scan_dirs(), reads
     'config.xml', resolves the configured methods via scan_meths(), turns them
     into a method container and wraps that in a Pipeline.

     Instance vars:
         self.algdir -- a directory path for algorithms
         self.methdir -- a directory path for methods
         self.all_algs -- algorithm files found by scan_dirs()
         self.all_meths -- method files found by scan_dirs()
         self.pipeline -- Pipeline instance
     """
     model_root = 'model'
     self.algdir = os.path.join(model_root, 'algorithms')
     self.methdir = os.path.join(model_root, 'methods')
     self.scan_dirs()
     # presumably an xml.etree object parsed from config.xml -- confirm
     config_tree = read_config('config.xml')
     configured_methods = self.scan_meths(config_tree)
     method_container = self._get_meth_container(configured_methods)
     self.pipeline = Pipeline(method_container, config_tree)
示例#4
0
from model.pipeline import Pipeline

from tensorflow.python import debug as tf_debug


if __name__ == "__main__":
    # Script entry point: collects hyper-parameters into the `opt` dict that is
    # passed to Pipeline. NOTE(review): the `opt` literal below is cut off in
    # this fragment; its remaining entries are not visible here.

    # Model-size / loss constants; they are forwarded into
    # opt["encoder_options"] further down (keypoint_num, patch_feature_dim,
    # keypoint_decoding_heatmap_levels, keypoint_transform_loss_weight).
    num_keypoints = 25
    patch_feature_dim = 8
    decoding_levels = 5
    kp_transform_loss = 5000

    # Reconstruction weight schedule: "piecewise_constant" with boundaries
    # [100000, 200000] and three values (x1, x10, x100 of the base weight).
    # Presumably the boundaries are training steps -- confirm against
    # Pipeline.ValueScheduler.
    base_recon_weight = 0.0001
    recon_weight = Pipeline.ValueScheduler(
        "piecewise_constant",
        [100000, 200000],
        [base_recon_weight, base_recon_weight * 10, base_recon_weight * 100],
    )

    # Learning-rate schedule with the same boundaries, decaying x0.1 then
    # x0.01 of the base rate.
    base_learning_rate = 0.001
    learning_rate = Pipeline.ValueScheduler(
        "piecewise_constant",
        [100000, 200000],
        [base_learning_rate, base_learning_rate * 0.1, base_learning_rate * 0.01],
    )

    # Forwarded as keypoint_separation_bandwidth /
    # keypoint_separation_loss_weight in the encoder options.
    keypoint_separation_bandwidth = 0.08
    keypoint_separation_loss_weight = 20.0

    # NOTE(review): whether recon_weight / learning_rate are placed into this
    # dict is not visible in this fragment.
    opt = {
        "optimizer": "Adam",
class ExtensionLoader:
    """
    A class that initializes the pipeline and fills it with available method
    instances and algorithms. It scans for new files in /model/algorithms and
    /model/methods. It checks these files for compliance with the interface,
    analyses which algorithm belongs to which method and creates a
    corresponding mapping that is used by Method class upon instantiating.
    """
    def __init__(self):
        """
        Class constructor
        Instance vars:
            self.algdir -- a directory path for algorithms
            self.methdir -- a directory path for methods
            self.all_algs -- a list of algorithm file names (set by scan_dirs)
            self.all_meths -- a list of method file names (set by scan_dirs)
            self.pipeline -- Pipeline instance
        Private vars:
            _default_pipe -- xml.etree object of config.xml
            _found_methods -- a sorted list of imported methods
            _container -- a list with Method instances
        """
        self.algdir = os.path.join('model', 'algorithms')
        self.methdir = os.path.join('model', 'methods')
        self.scan_dirs()
        _default_pipe = read_config('config.xml')
        _found_methods = self.scan_meths(_default_pipe)
        _container = self._get_meth_container(_found_methods)
        self.pipeline = Pipeline(_container, _default_pipe)

    def new_pipeline(self, pipe_path):
        """
        Replace currently active pipeline with a new one.
        Args:
            pipe_path -- a file path to a new pipeline config
        """
        _new_pipe = read_config(pipe_path)  # xml.etree object
        _found_methods = self.scan_meths(_new_pipe)
        self.pipeline.container = []  # erasing current method container
        _new_container = self._get_meth_container(_found_methods)
        self.pipeline.load_new_pipeline(_new_container, _new_pipe)

    def scan_dirs(self):
        """
        Search for new files in model directory and check their
        interface compliance. Incompatible files shall be not included into
        the pipeline and registered with the UI!
        """
        # re.match anchors every alternative at the start of the name. Dots
        # are escaped so e.g. "mypyc_alg.py" is not mistaken for a compiled
        # file; "__pycache__" is listed explicitly because the former
        # unescaped ".*.pyc" pattern only excluded it by accident.
        _ignored = re.compile(
            r'.*\.pyc|__init__|__pycache__|_category\.py|HOWTO\.txt')
        # Real lists (not lazy filter objects) so they can be iterated more
        # than once and rebuilt by _check_compliance on Python 2 and 3 alike.
        self.all_algs = [name for name in os.listdir(self.algdir)
                         if not _ignored.match(name)]
        self.all_meths = [name for name in os.listdir(self.methdir)
                          if not _ignored.match(name)]
        self._check_compliance()

    def scan_meths(self, pipe_config):
        """
        Import the methods referenced by a pipeline config and return them
        sorted in config order.

        For every child element of pipe_config, each file in self.all_meths is
        imported and kept when its get_name() equals the element's 'method'
        attribute.
        Args:
            pipe_config -- xml.etree object of config.xml
        Returns:
            imported_meths -- a list of imported method modules sorted by the
                order of the <settings> elements in pipe_config, or 1 when a
                configured method has no matching file (or vice versa).
                NOTE(review): __init__ and new_pipeline pass the result
                straight to _get_meth_container, which cannot handle 1.
        """
        print('> ExtLoader: Importing methods...')
        meth_list = []
        imported_meths = []
        for settings in pipe_config:
            meth_name = settings.attrib['method']
            for met in self.all_meths:
                imported = __import__(met.split('.')[0])
                imp_name = getattr(imported, 'get_name')()
                if imp_name == meth_name:
                    imported_meths.append(imported)
                    meth_list.append(imp_name)

        meth_order = [m.attrib['method'] for m in pipe_config.iter('settings')]
        missing_meth = [i for i in meth_list if i not in meth_order]
        missing_conf = [i for i in meth_order if i not in meth_list]
        if missing_meth:
            print('MethodNotPresentError: {0} not found in config.xml'.format(missing_meth))
            return 1
        elif missing_conf:
            print('MethodNotPresentError: {0} not found in methods dir'.format(missing_conf))
            return 1
        # sorting methods according to config.xml order
        print(meth_order)
        imported_meths.sort(key=lambda x: meth_order.index(x.get_name()))
        return imported_meths

    def _missing_names(self, fpath, required):
        """
        Return the names from `required` that do not occur in file `fpath`.

        This is a substring test ('belong' also matches 'belongs'), so it is
        only a coarse sanity check of the extension interface.
        """
        with open(fpath, 'r') as extfile:
            fdata = extfile.read()
        return [name for name in required if name not in fdata]

    def _check_compliance(self):
        """
        Check extension's code compliance. If a file does not mention every
        required interface name, skip it (it won't be displayed in UI).

        The lists are rebuilt rather than mutated while being iterated, which
        previously skipped the file following each removed one.
        """
        _alg_required = ('apply', 'belong', 'main', 'get_name')
        compliant_algs = []
        for ex in self.all_algs:
            fpath = os.path.join(self.algdir, ex)
            missing = self._missing_names(fpath, _alg_required)
            if missing:
                print(missing)
                print('AlgorithmSyntaxError: {0} does not comply with code '
                      'requirements, skipping.'.format(fpath))
            else:
                compliant_algs.append(ex)
        self.all_algs = compliant_algs

        _meth_required = ('new', 'get_name')
        compliant_meths = []
        for me in self.all_meths:
            fpath = os.path.join(self.methdir, me)
            if self._missing_names(fpath, _meth_required):
                print('MethodSyntaxError: {0} does not comply with code '
                      'requirements, skipping.'.format(fpath))
            else:
                compliant_meths.append(me)
        self.all_meths = compliant_meths

    def _get_meth_container(self, found_meths):
        """
        Build alg -> method mapping.
        Create a container with methods that represent a pipeline with
        selected algorithms and predefined settings.
        Args:
            found_meths -- a list of found and imported methods
        Returns:
            meth_container -- a list with Method instances
        """
        alg_meth_map = {}
        for ext in self.all_algs:
            imported = __import__(ext.split('.')[0])
            alg_meth_map[imported] = getattr(imported, 'belongs')()

        meth_container = []
        for meth in found_meths:
            inst = getattr(meth, 'new')(alg_meth_map)  # creating Method objects
            meth_container.append(inst)
        return meth_container

    def _sniff(self, alg_fun):
        """
        Inspect the algorithm function arguments and their respective types
        in order to inform the UI about the widgets that are required to
        correctly display the algorithm settings.

        <There should be two widgets available. Basically two of them
        actually make sense: checkbox and slider. I think other widgets are
        redundant and only add more unneeded complexity. If we detect that
        algorithm function requires bool argument we create a checkbox,
            else a slider.>
        """
        getargspec = getattr(inspect, 'getargspec', None)
        if getargspec is None:
            # inspect.getargspec was removed in Python 3.11.
            return inspect.getfullargspec(alg_fun)
        return getargspec(alg_fun)

    def get_pipeline(self):
        """
        Return a pipeline instance filled with methods and algorithms.
        """
        return self.pipeline
示例#6
0
    }

    # Encoder hyper-parameters, mostly forwarded from the constants defined
    # earlier in this script. NOTE(review): opt["recon_name"] must be set by
    # the part of the `opt` literal that is not visible in this fragment.
    opt["encoder_options"] = {
        "keypoint_num": num_keypoints,
        "patch_feature_dim": patch_feature_dim,
        "ae_recon_type": opt["recon_name"],
        "keypoint_concentration_loss_weight": 100.,
        "keypoint_axis_balancing_loss_weight": 200.,
        "keypoint_separation_loss_weight": keypoint_separation_loss_weight,
        "keypoint_separation_bandwidth": keypoint_separation_bandwidth,
        "keypoint_transform_loss_weight": kp_transform_loss,
        "keypoint_decoding_heatmap_levels": decoding_levels,
        # 0.5**(1 / 2) is sqrt(0.5) ~= 0.707 under Python 3 true division;
        # under Python 2 without `from __future__ import division` it is 1.0.
        "keypoint_decoding_heatmap_level_base": 0.5**(1 / 2),
        "image_channels": 3,
    }
    # Decoder gets a copy of the encoder options. Presumably `copy` is
    # copy.copy (a shallow copy) -- verify the import.
    opt["decoder_options"] = copy(opt["encoder_options"])

    # -------------------------------------
    # NOTE(review): os.path.join with a single argument returns it unchanged.
    model_dir = os.path.join("results/aflw_30")
    checkpoint_dir = 'pretrained_results'
    checkpoint_filename = 'celeba_30/model/snapshot_step_205317'
    vp = Pipeline(None, opt, model_dir=model_dir)
    print(vp.opt)
    with vp.graph.as_default():
        sess = vp.create_session()
        # Calls run_full_train_from_checkpoint with the pretrained checkpoint
        # location, then run_full_test on the same session.
        vp.run_full_train_from_checkpoint(
            sess,
            checkpoint_dir=checkpoint_dir,
            checkpoint_filename=checkpoint_filename)
        vp.run_full_test(sess)
def batch_mode(args):
    """
    Process images in console mode, without the UI.

    Args:
        | *args* : parsed command-line options as returned by
          ``ArgumentParser.parse_args``; the ``pipeline``, ``dir``, ``file``
          and ``out`` attributes are read.

    When both ``dir`` and ``file`` are given, ``dir`` is used as the input.
    """
    ui_enabled = False
    pipeline = Pipeline(ExtensionLoader().cats_container, ui_enabled)

    pipeline_file = args.pipeline
    if pipeline_file:
        # Replace the default pipeline with the one from the given JSON file.
        pipeline.load_pipeline_json(pipeline_file)

    # A source directory takes precedence over a single image file.
    source = args.dir or args.file
    if source:
        pipeline.set_input(source)

    if args.out:
        pipeline.set_output_dir(args.out)
    pipeline.process()