Ejemplo n.º 1
0
    def remap_var_list(self, var_list):
        """Re-key the entries of ``var_list`` listed in ``self.load_param_dict``.

        Args:
            var_list (dict): var names mapped to variables (or some related
                quantity, such as variable shapes). Modified in place.

        Returns:
            dict: the same ``var_list`` object that was passed in.

        NOTE(review): ``load_param_dict`` is documented elsewhere as
        ``{checkpoint_name: current_name}``, but this method only pops each
        matching checkpoint name and stores the value back under the *same*
        key. Names are therefore never changed; only the entry moves to the
        end of the dict's insertion order. The ``current_name`` side of the
        mapping is unused. Callers in this module (e.g. ``get_restore_vars``)
        may rely on the result staying keyed by checkpoint names, so review
        them before making this method actually rename keys.
        """
        mapping = self.load_param_dict
        if mapping is None:
            log.info('No variable mapping specified.')
            return var_list
        for ckpt_name in mapping:
            if ckpt_name in var_list:
                # Pop and re-insert under the same key (see NOTE above).
                var_list[ckpt_name] = var_list.pop(ckpt_name)
        return var_list
Ejemplo n.º 2
0
def create_train_tpu_config(model_dir,
                            model_params,
                            tpu_name,
                            gcp_project,
                            steps_per_checkpoint,
                            tpu_zone=DEFAULT_TPU_ZONE,
                            num_shards=DEFAULT_NUM_SHARDS,
                            keep_checkpoint_max=5,
                            iterations_per_loop=DEFAULT_ITERATIONS_PER_LOOP):
    """Build a TPU ``RunConfig`` for training.

    If ``iterations_per_loop`` is -1, or is larger than a non-None
    ``steps_per_checkpoint``, it is replaced by ``steps_per_checkpoint``.
    The replacement value is also written back into
    ``model_params['iterations_per_loop']``.

    NOTE(review): with ``iterations_per_loop == -1`` and
    ``steps_per_checkpoint is None`` this passes ``iterations_per_loop=None``
    to ``TPUConfig``. Confirm whether that is supported.

    Args:
        model_dir: directory passed to ``RunConfig(model_dir=...)``.
        model_params (dict): model parameters; may be updated in place.
        tpu_name: TPU name given to the cluster resolver.
        gcp_project: GCP project given to the cluster resolver.
        steps_per_checkpoint: used for ``save_checkpoints_steps``.
        tpu_zone: zone given to the cluster resolver.
        num_shards: number of TPU shards for ``TPUConfig``.
        keep_checkpoint_max: maximum number of checkpoints to keep.
        iterations_per_loop: TPU iterations per loop; also used as
            ``log_step_count_steps``.

    Returns:
        The ``tpu_config_lib.RunConfig`` instance.
    """
    resolver = tpu_cluster_resolver_lib.TPUClusterResolver(
        tpu=[tpu_name], zone=tpu_zone, project=gcp_project)

    checkpoint_is_shorter = (steps_per_checkpoint is not None
                             and steps_per_checkpoint < iterations_per_loop)
    if iterations_per_loop == -1 or checkpoint_is_shorter:
        log.info(
            'Setting iterations_per_loop ({}) to be the same as steps_per_checkpoint ({}).'
            .format(iterations_per_loop, steps_per_checkpoint))
        iterations_per_loop = steps_per_checkpoint
        model_params['iterations_per_loop'] = iterations_per_loop

    tpu_settings = tpu_config_lib.TPUConfig(
        iterations_per_loop=iterations_per_loop, num_shards=num_shards)
    return tpu_config_lib.RunConfig(
        cluster=resolver,
        model_dir=model_dir,
        save_checkpoints_steps=steps_per_checkpoint,
        save_checkpoints_secs=None,
        keep_checkpoint_max=keep_checkpoint_max,
        log_step_count_steps=iterations_per_loop,
        tpu_config=tpu_settings)
Ejemplo n.º 3
0
    def get_restore_vars(self, save_file, all_vars=None):
        """Create the `var_list` init argument to tf.Saver from save_file.

        Extracts the subset of current variables that match the name and
        shape of variables saved in the checkpoint file, and returns them
        keyed by checkpoint name.

        To support multi-model training, a model prefix is prepended to all
        tf global_variable names, although this prefix is stripped from
        all variables before they are saved to a checkpoint. Thus both the
        checkpoint names and the current variable names are compared with
        ``self.params['model_params']['prefix']`` stripped off.

        If ``self.load_param_dict`` is None, current variables are matched to
        checkpoint variables by identical name. Otherwise each
        ``{checkpoint_name: current_name}`` entry restores the current
        variable ``current_name`` from ``checkpoint_name``. Entries whose
        checkpoint variable is missing or whose shape differs are skipped
        and logged.

        Args:
            save_file: path of tf.train.Saver checkpoint.
            all_vars (dict, optional): prefix-stripped variable names mapped
                to variables. If None, built from ``tf.global_variables()``
                plus ``tf.local_variables()``.

        Returns:
            dict: checkpoint variable names mapped to the current variables
            that should be restored from them.

        """
        prefix = self.params['model_params']['prefix']
        reader = tf.train.NewCheckpointReader(save_file)
        var_shapes = reader.get_variable_to_shape_map()
        log.info('Saved Vars:\n' + str(var_shapes.keys()))

        var_shapes = {  # Strip the prefix off saved var names.
            strip_prefix_from_name(prefix, name): shape
            for name, shape in var_shapes.items()}

        # Map old vars from checkpoint to new vars via load_param_dict.
        # remap_var_list may modify its argument in place, so hand it a copy:
        # `var_shapes` must stay keyed by checkpoint names for the shape
        # check below.
        mapped_var_shapes = self.remap_var_list(dict(var_shapes))
        log.info('Saved shapes:\n' + str(mapped_var_shapes))

        if all_vars is None:
            all_vars = tf.global_variables() + tf.local_variables()  # get list of all variables
            all_vars = strip_prefix(prefix, all_vars)

        # Specify which vars are to be restored vs. reinitialized.
        if self.load_param_dict is None:
            restore_vars = {name: var for name, var in all_vars.items()
                            if name in mapped_var_shapes}
        else:
            # Associate checkpoint names with actual variables (dict lookup
            # instead of scanning all_vars for every mapping entry).
            restore_vars = {ckpt_var_name: all_vars[curr_var_name]
                            for ckpt_var_name, curr_var_name in self.load_param_dict.items()
                            if curr_var_name in all_vars}

        restore_vars = self.filter_var_list(restore_vars)

        # Ensure the vars to restore exist in the checkpoint with the correct shape.
        var_list = {}
        for name, var in restore_vars.items():
            if name not in var_shapes:
                log.info('Variable %s not found in checkpoint; skipping.' % name)
                continue
            var_shape = var.get_shape().as_list()
            ckpt_shape = var_shapes[name]
            if var_shape == ckpt_shape:
                var_list[name] = var
            else:
                log.info('Shape mismatch for %s: current %s, checkpoint %s'
                         % (name, var_shape, ckpt_shape))
        return var_list
Ejemplo n.º 4
0
    def get_restore_vars(self, save_file):
        """Create the `var_list` init argument to tf.Saver from save_file.

        Extracts the subset of variables in ``self.var_list`` that match the
        name and shape of variables saved in the checkpoint file, and returns
        them as a dict keyed by checkpoint name.

        If ``self.load_param_dict`` is empty or None, current variables are
        matched to checkpoint variables by identical name. Otherwise each
        ``{checkpoint_name: current_name}`` entry restores the current
        variable ``current_name`` from ``checkpoint_name``. Entries whose
        checkpoint variable is missing or whose shape differs are skipped
        and logged.

        Args:
            save_file: path of tf.train.Saver checkpoint.

        Returns:
            dict: checkpoint variable names mapped to the current variables
            that should be restored from them.

        """
        reader = tf.train.NewCheckpointReader(save_file)
        var_shapes = reader.get_variable_to_shape_map()
        log.info('Saved vars and shapes:\n' + str(var_shapes))

        # Specify which vars are to be restored vs. reinitialized.
        all_vars = self.var_list
        if not self.load_param_dict:
            restore_vars = {name: var for name, var in all_vars.items()
                            if name in var_shapes}
        else:
            # Associate checkpoint names with actual variables.
            restore_vars = {ckpt_var_name: all_vars[curr_var_name]
                            for ckpt_var_name, curr_var_name in self.load_param_dict.items()
                            if curr_var_name in all_vars}

        restore_vars = self.filter_var_list(restore_vars)

        # Ensure the vars to restore exist in the checkpoint with the correct shape.
        var_list = {}
        for name, var in restore_vars.items():
            if name not in var_shapes:
                # A load_param_dict entry may name a var absent from the checkpoint.
                log.info('Variable %s not found in checkpoint; skipping.' % name)
                continue
            var_shape = var.get_shape().as_list()
            ckpt_shape = var_shapes[name]
            if var_shape == ckpt_shape:
                var_list[name] = var
            else:
                log.info('Shape mismatch for %s: current %s, checkpoint %s'
                         % (name, var_shape, ckpt_shape))
        return var_list
Ejemplo n.º 5
0
    def initialize(self, no_scratch=False):
        """Initialize session variables, restoring from a checkpoint if possible.

        When ``self.do_restore`` is set, the checkpoint is taken from
        ``self.from_ckpt`` if given. Otherwise it comes from the database
        record in ``self.load_data``, fetched with ``self.load_rec()`` if not
        loaded yet. Variables selected by ``get_restore_vars`` are restored
        with a ``tf.train.Saver``, and every remaining variable is
        initialized. If ``do_restore`` is false or no checkpoint is
        available, all global and local variables are initialized from
        scratch.

        Args:
            no_scratch (bool): not used by this method's visible code.
        """
        if self.do_restore:

            # First, determine which checkpoint to use.
            if self.from_ckpt is not None:
                # Use a cached checkpoint file.
                ckpt_filename = self.from_ckpt
                log.info('Restoring variables from checkpoint %s ...' % ckpt_filename)
            else:
                # Otherwise, use a database checkpoint (fetch the record lazily).
                if self.load_data is None:
                    self.load_rec()
                if self.load_data is not None:
                    rec, ckpt_filename = self.load_data
                    log.info('Restoring variables from record %s (step %d)...' %
                             (str(rec['_id']), rec['step']))
                else:
                    # No db checkpoint to load.
                    ckpt_filename = None

            if ckpt_filename is not None:
                prefix = self.params['model_params']['prefix']
                all_vars = tf.global_variables() + tf.local_variables()  # get list of all variables
                self.all_vars = strip_prefix(prefix, all_vars)

                # Next, determine which vars should be restored from the specified checkpoint.
                restore_vars = self.get_restore_vars(ckpt_filename, self.all_vars)
                restore_stripped = strip_prefix(prefix, list(restore_vars.values()))
                restore_names = list(restore_stripped.keys())
                # Actually load the vars.
                log.info('Restored Vars:\n' + str(restore_names))
                tf_saver_restore = tf.train.Saver(restore_vars)
                tf_saver_restore.restore(self.sess, ckpt_filename)
                log.info('... done restoring.')

                # Reinitialize all other, unrestored vars. A set gives O(1)
                # membership tests instead of scanning restore_names per var.
                restored = set(restore_names)
                unrestored = [(name, var) for name, var in self.all_vars.items()
                              if name not in restored]
                unrestored_var_names = [name for name, _ in unrestored]
                unrestored_vars = [var for _, var in unrestored]
                log.info('Unrestored Vars:\n' + str(unrestored_var_names))
                self.sess.run(tf.variables_initializer(unrestored_vars))  # initialize variables not restored
                uninitialized = self.sess.run(tf.report_uninitialized_variables())
                assert len(uninitialized) == 0, uninitialized

        if not self.do_restore or (self.load_data is None and self.from_ckpt is None):
            self.sess.run(tf.global_variables_initializer())
            self.sess.run(tf.local_variables_initializer())
Ejemplo n.º 6
0
def version_check_and_info(module):
    """Return either git info or standard module version if not a git repo.

    Args:
        module (module): python module object to get info for.

    Returns:
        dict: ``git_info(repo)`` when the module's source file lives inside a
        git repository, otherwise ``version_info(module)``. In both cases a
        ``'source_path'`` entry holding the module's source file path is
        added (None if it cannot be determined).

    """
    srcpath = inspect.getsourcefile(module)
    repo = None
    # getsourcefile returns None when no source file can be identified
    # (e.g. extension modules). git.Repo(None) would then inspect the
    # current working directory and report an unrelated repository.
    if srcpath is not None:
        try:
            repo = git.Repo(srcpath, search_parent_directories=True)
        except git.InvalidGitRepositoryError:
            repo = None
    if repo is None:
        log.info('module %s not in a git repo, checking package version' %
                 module.__name__)
        info = version_info(module)
    else:
        info = git_info(repo)
    info['source_path'] = srcpath
    return info
Ejemplo n.º 7
0
 def do_tpu_validation():
     """Evaluate ``eval_cls`` on the validation input and save the results.

     Relies on the enclosing scope for ``eval_cls``, ``valid_fn``,
     ``valid_hooks``, ``valid_steps``, ``valid_k`` and ``trarg``.

     Returns:
         The value returned by ``eval_cls.evaluate``.
     """
     log.info('Starting to evaluate.')
     results = eval_cls.evaluate(
         input_fn=valid_fn, hooks=valid_hooks, steps=valid_steps)
     log.info('Saving eval results to database.')
     db = trarg['dbinterface']
     db.save(valid_res={valid_k: results}, validation_only=True)
     log.info('Done saving eval results to database.')
     return results
Ejemplo n.º 8
0
def test_estimator(cls_dict, param, ttarg):
    """Run prediction for every validation target and save the results.

    Args:
        cls_dict (dict): validation target name mapped to the estimator
            whose ``predict`` is called for that target.
        param (dict): full parameter dict; reads ``load_params``,
            ``save_params`` and ``validation_params``.
        ttarg (dict): mutable argument dict; ``ttarg['dbinterface']`` is set
            to the DBInterface used for saving.

    Returns:
        tuple: ``(eval_results, res)``. ``eval_results`` is the list of
        predictions for the last target processed, or None if ``cls_dict``
        is empty. ``res`` is a one-element list holding the DBInterface's
        ``outrecs`` after syncing with the host.
    """
    load_params = param['load_params']
    # load params query stores path to checkpoint
    if load_params['do_restore'] and load_params['query'] is not None:
        # path to specific checkpoint
        load_dir = load_params['query']
    else:
        # gets latest checkpoint from model_dir
        load_dir = None

    dbinterface = DBInterface(sess=None,
                              params=param,
                              save_params=param['save_params'],
                              load_params=load_params)
    ttarg['dbinterface'] = dbinterface
    dbinterface.start_time_step = time.time()

    m_predictions = {}
    eval_results = None
    for valid_k, cls in cls_dict.items():
        valid_params = param['validation_params'][valid_k]
        valid_fn = valid_params['data_params']['func']
        # can use to filter particular params to save, if not there will set to None and all saved
        filter_keys = valid_params.get('keys_to_save')
        session_hooks = valid_params.get('hooks')
        log.info('Starting to evaluate ({}).'.format(valid_k))
        predictions = cls.predict(input_fn=valid_fn,
                                  predict_keys=filter_keys,
                                  hooks=session_hooks,
                                  checkpoint_path=load_dir)
        # Materialize once so the same predictions are both saved and returned.
        eval_results = list(predictions)
        m_predictions[valid_k] = eval_results

    log.info('Saving eval results to database.')
    # set validation only to be True to just save the results and not filters
    dbinterface.save(valid_res=m_predictions, validation_only=True)
    log.info('Done saving eval results to database.')

    # sync with hosts
    dbinterface.sync_with_host()
    res = [dbinterface.outrecs]
    # returning final eval results for convenience
    return eval_results, res
Ejemplo n.º 9
0
def train_estimator(train_cls, eval_cls, param, trarg):
    """Train an estimator in chunks, validating between chunks.

    Training resumes from the global step found in ``model_dir``
    (``save_params['cache_dir']``) and runs up to
    ``train_params['num_steps']`` in chunks of ``steps_per_eval`` steps.
    When validation params are given, the first validation target is
    evaluated after each chunk, and optionally before training when
    ``train_params['tpu_validate_first']`` is set. The results are saved
    through ``trarg['dbinterface']``.

    Args:
        train_cls: estimator used for training.
        eval_cls: estimator used for evaluation; defaults to ``train_cls``.
        param (dict): full parameter dict (save/train/validation/model/load
            params). ``save_params['save_filters_freq']`` may be filled in.
        trarg (dict): mutable argument dict; ``trarg['dbinterface']`` is set
            to the DBInterface used for saving.

    Returns:
        tuple: ``(eval_results, res)``. ``eval_results`` is the result of the
        last evaluation, or None if none ran. ``res`` is a one-element list
        holding the DBInterface's ``outrecs`` after syncing with the host.
    """
    if eval_cls is None:
        eval_cls = train_cls

    save_params = param['save_params']
    validation_params = param['validation_params']
    model_dir = save_params.get('cache_dir', '')
    train_steps = param['train_params']['num_steps']
    # only single targets during eval mode
    need_val = len(validation_params) > 0
    steps_per_eval = save_params.get('save_valid_freq')
    valid_k = valid_fn = valid_steps = valid_hooks = None
    if need_val:
        # next(iter(...)) works with both Python 2 lists and Python 3 key views.
        valid_k = next(iter(validation_params))
        valid_target_params = validation_params[valid_k]
        valid_steps = valid_target_params['num_steps']
        valid_fn = valid_target_params['data_params']['func']
        valid_hooks = valid_target_params.get('hooks')
        if steps_per_eval is None:
            steps_per_eval = save_params['save_filters_freq']
        else:
            save_filters_freq = save_params.get('save_filters_freq')
            if save_filters_freq is not None:
                # these need to be the same right now because estimator loads
                # from last checkpoint after validating
                assert (steps_per_eval == save_filters_freq)
            else:
                save_params['save_filters_freq'] = steps_per_eval
    train_fn = param['train_params']['data_params']['func']

    model_params = param['model_params']
    iterations_per_loop = model_params.get('iterations_per_loop',
                                           DEFAULT_ITERATIONS_PER_LOOP)

    # eval steps cannot be less than TPU iterations
    if steps_per_eval is None or steps_per_eval < iterations_per_loop:
        log.info(
            'Setting save_valid_freq ({}) to be the same as iterations_per_loop ({}).'
            .format(steps_per_eval, iterations_per_loop))
        steps_per_eval = iterations_per_loop

    train_hooks = param['train_params'].get('hooks')

    current_step = estimator._load_global_step_from_checkpoint_dir(model_dir)
    # initialize db here (currently no support for loading and saving to different places. May need to modify init so load_params can load from different dir, estimator interface limited
    #    when loading and saving to different paths, may need to create a new config)

    dbinterface = DBInterface(sess=None,
                              params=param,
                              global_step=current_step,
                              save_params=save_params,
                              load_params=param['load_params'],
                              cache_dir=model_dir)
    trarg['dbinterface'] = dbinterface

    log.info('Training beginning ...')
    log.info('Training for %d steps. Current '
             'step %d' % (train_steps, current_step))

    dbinterface.start_time_step = time.time()

    tpu_validate_first = param['train_params'].get('tpu_validate_first', False)

    def do_tpu_validation():
        """Evaluate on the first validation target and save the results."""
        log.info('Starting to evaluate.')
        results = eval_cls.evaluate(input_fn=valid_fn,
                                    hooks=valid_hooks,
                                    steps=valid_steps)
        log.info('Saving eval results to database.')
        dbinterface.save(valid_res={valid_k: results},
                         validation_only=True)
        log.info('Done saving eval results to database.')
        return results

    # Stays None when no evaluation is run (no validation params given).
    eval_results = None
    if tpu_validate_first:
        if need_val:
            eval_results = do_tpu_validation()
        else:
            log.info('tpu_validate_first is set but no validation_params '
                     'were given; skipping initial validation.')

    while current_step < train_steps:
        next_eval = min(current_step + steps_per_eval, train_steps)
        log.info('Training until step %d' % next_eval)
        train_cls.train(input_fn=train_fn,
                        max_steps=next_eval,
                        hooks=train_hooks)
        current_step = next_eval

        if need_val:
            eval_results = do_tpu_validation()

    # sync with hosts
    dbinterface.sync_with_host()
    res = [dbinterface.outrecs]
    # returning final eval results for convenience
    return eval_results, res
Ejemplo n.º 10
0
    def model_fn(features, labels, mode, params):
        """Estimator ``model_fn`` assembled from tfutils-style parameter dicts.

        Builds the model, loss, optional training op and evaluation metrics.
        The result is wrapped in a ``TPUEstimatorSpec`` when ``use_tpu`` is
        set (from the enclosing scope), or in an ``EstimatorSpec`` otherwise.

        Args:
            features: passed to ``model_params['func']`` as ``inputs``.
            labels: passed to the loss function and to the metric kwargs.
            mode: a ``tf.estimator.ModeKeys`` value.
            params (dict): must contain ``model_params``, ``opt_params``,
                ``loss_agg_func``, ``loss_per_case_func``,
                ``loss_func_kwargs``, ``loss_agg_func_kwargs`` and
                ``lr_params``. It must also contain ``batch_size`` when
                ``use_tpu`` is set.

        Returns:
            ``TPUEstimatorSpec`` or ``tf.estimator.EstimatorSpec``.
        """
        # Shallow-copy the sub-dicts that are modified/popped below so the
        # caller's params keep their 'func'/'optimizer' entries and repeated
        # calls to model_fn with the same params keep working.
        model_params = dict(params['model_params'])
        opt_params = dict(params['opt_params'])
        lr_params = dict(params['lr_params'])
        loss_agg_func = params['loss_agg_func']
        loss_per_case_func = params['loss_per_case_func']
        loss_func_kwargs = params['loss_func_kwargs']
        loss_agg_func_kwargs = params['loss_agg_func_kwargs']

        model_params['train'] = (mode == tf.estimator.ModeKeys.TRAIN)
        if use_tpu:
            model_params['batch_size'] = params[
                'batch_size']  # per shard batch_size

        model_func = model_params.pop('func')

        outputs = model_func(inputs=features, **model_params)
        logit_key = model_params.get('logit_key', 'logits')
        if isinstance(outputs, dict):
            logits = outputs[logit_key]
        else:
            logits = outputs

        loss = loss_per_case_func(outputs, labels, **loss_func_kwargs)
        loss = loss_agg_func(loss, **loss_agg_func_kwargs)

        # The optimizer receives the aggregated loss unchanged (possibly a
        # list); the spec reports a single scalar, the sum when it is a list.
        optimizer_loss = loss
        spec_loss = tf.add_n(loss) if isinstance(loss, list) else loss

        global_step = tf.train.get_global_step()

        lr_func = lr_params.pop('func')
        learning_rate = lr_func(global_step=global_step, **lr_params)

        if mode == tf.estimator.ModeKeys.TRAIN:
            opt_func = opt_params.pop('optimizer', ClipOptimizer)
            # For deprecated parameter func
            old_opt_func = opt_params.pop('func', None)
            if old_opt_func:
                log.info('func in optimizer_params is deprecated, ' + \
                        'please use optimizer')
                opt_func = old_opt_func

            log.info('Passing optimizer class to CrossShardMultiOptimizer')
            optimizer_base = CrossShardMultiOptimizer(
                opt_func(learning_rate=learning_rate, **opt_params))

            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer_base.minimize(optimizer_loss, global_step)
        else:
            train_op = None

        eval_metrics = None
        if mode == tf.estimator.ModeKeys.EVAL:
            metric_fn_kwargs = {'labels': labels, 'logits': logits}
            if use_tpu:
                # tpu estimators currently only support single targets :(
                assert len(validation_params) == 1
                # next(iter(...)) works on both Python 2 and Python 3 dicts.
                first_valid = next(iter(validation_params))
                valid_target = validation_params[first_valid]['targets']
                metric_fn = valid_target['func']
                if isinstance(outputs, dict):
                    # Forward every non-logit model output to the metric fn;
                    # integer keys become 'i<k>' so they are valid kwarg names.
                    for kw, kw_val in outputs.items():
                        if kw != logit_key:
                            new_kw = 'i%i' % kw if isinstance(kw, int) else kw
                            metric_fn_kwargs[new_kw] = kw_val

                # add any additional kwargs given as dict-valued target entries
                for v in valid_target.values():
                    if isinstance(v, dict):
                        metric_fn_kwargs.update(v)
                eval_metrics = (metric_fn, metric_fn_kwargs)
            else:
                # normal estimators expect dicts and can support multiple targets (but same dataset and eval_steps etc)
                eval_dict = {}
                for k, k_params in validation_params.items():
                    k_target = k_params['targets']
                    # Copy so one target's extra kwargs do not leak into the next.
                    k_metric_fn_kwargs = dict(metric_fn_kwargs)
                    for kw, kw_val in k_target.items():
                        if kw != 'func':
                            k_metric_fn_kwargs[kw] = kw_val
                    eval_dict[k] = (k_target['func'], k_metric_fn_kwargs)
                # NOTE(review): tf.estimator.EstimatorSpec documents
                # eval_metric_ops values as (metric_tensor, update_op) pairs or
                # Metric objects, not (metric_fn, kwargs) tuples -- confirm how
                # this dict is meant to be consumed.
                eval_metrics = eval_dict

        if use_tpu:
            return tpu_estimator_lib.TPUEstimatorSpec(
                mode=mode,
                loss=spec_loss,
                train_op=train_op,
                eval_metrics=eval_metrics)
        else:
            return tf.estimator.EstimatorSpec(mode=mode,
                                              loss=spec_loss,
                                              train_op=train_op,
                                              eval_metric_ops=eval_metrics)
Ejemplo n.º 11
0
def train_from_params(
        save_params,
        model_params,
        train_params,
        loss_params=None,
        learning_rate_params=None,
        optimizer_params=None,
        validation_params=None,
        load_params=None,
        log_device_placement=DEFAULT_PARAMS[
            'log_device_placement'],  # advanced
        dont_run=DEFAULT_PARAMS['dont_run'],  # advanced
        skip_check=DEFAULT_PARAMS['skip_check'],  # advanced
        use_estimator=False):
    """
    Main training interface function.

    Args:
        save_params (dict): 
            Describing the parameters used to construct the save database, and
            control saving. These include:

            - host (str)
                Hostname where database connection lives
            - port (int)
                Port where database connection lives
            - dbname (str)
                Name of database for storage
            - collname (str)
                Name of collection for storage
            - exp_id (str)
                Experiment id descriptor
                NOTE: the variables host/port/dbname/coll/exp_id control
                the location of the saved data for the run, in order of
                increasing specificity.  When choosing these, note that:

                - If a given host/port/dbname/coll/exp_id already has saved checkpoints,\
                then any new call to start training with these same location variables\
                will start to train from the most recent saved checkpoint.  If you mistakenly\
                try to start training a new model with different variable names, or structure,\
                from that existing checkpoint, an error will be raised, as the model will be\
                incompatiable with the saved variables.

                - When choosing what dbname, coll, and exp_id, to use, keep in mind that mongodb\
                queries only operate over a single collection.  So if you want to analyze\
                results from a bunch of experiments together using mongod queries, you should\
                put them all in the same collection, but with different exp_ids. If, on the\
                other hand, you never expect to analyze data from two experiments together,\
                you can put them in different collections or different databases. Choosing\
                between putting two experiments in two collections in the same database\
                or in two totally different databases will depend on how you want to organize\
                your results and is really a matter of preference.

            - do_save (bool, default: True)
                Whether to save to database
            - save_initial_filters (bool, default: True)
                Whether to save initial model filters at step = 0,
            - save_metrics_freq (int, default: 5)
                How often to store train results to database
            - save_valid_freq (int, default: 3000)
                How often to calculate and store validation results
                to database
            - save_filters_freq (int, default: 30000)
                How often to save filter values to database
            - cache_filters_freq (int, default: 3000)
                How often to cache filter values locally and save
                to ___RECENT database
            - cache_max_num (int, default: 6)
                Maximal number of cached filters to keep in __RECENT database
            - cache_dir (str, default: None)
                Path where caches will be saved locally. If None, will default to
                ~/.tfutils/<host:post>/<dbname>/<collname>/<exp_id>.

        model_params (dict): Containing function that produces model and arguments to that function.

            - model_params['func'] 
                The function producing the model.

                The function's signature is:

                Args:

                - ``inputs``: data object
                - ``train`` (boolean): if in training or testing 
                - ``seed`` (int): seed for use in random generation

                Returns:

                - ``outputs`` (tf.Operations): train output tensorflow nodes
                - Additional configurations you want to store in database

            - Remaining items in model_params are dictionary of arguments passed to func.

        train_params (dict): Containing params for data sources and targets in training.

            - train_params['data_params'] 
                This contains params for the data

                - ``train_params['data_params']['func']`` is the function that constructs the data:

                    The function's signature is:

                    Args:

                    - ``batch_size``: Batch size for input data

                    Returns:

                    - ``inputs``: A dictionary of tensors that will be sent to model function

                - ``train_params['data_params']['batch_size']`` batch size of the data, will be sent to func

                - Remainder of ``train_params['data_params']`` are kwargs passed to func

            - train_params['targets'] (optional) 
                contains params for additional train targets

                - ``train_params['targets']['func']`` is a function that produces tensorflow nodes as training targets:

                    The function's signature is:

                    Args:

                    - ``inputs``: returned values of ``train_params['data_params']['func']``
                    - ``output``: first returned value of ``train_params['model_params']['func']``

                    Returns:

                    A dictionary of tensors that will be computed and stored in the database

                - Remainder of ``train_parms['targets']`` are arguments to func.

            - train_params['validate_first'] (optional, bool, default is True):
                controls whether validating before training

            - train_params['thres_loss'] (optional, float, default: 100): 
                If loss exceeds this during training, HiLossError is thrown

            - train_params['num_steps'] (int or None, default: None): 
                How many total steps of the optimization are run.
                If None, train is run until process is cancelled.

        loss_params (dict): Parameters for helper.get_loss_base function to build loss.

            - loss_params['pred_targets'] (a string or a list of strings):
                contain the names of inputs nodes that will be sent into the loss function

            - loss_params['loss_func']:
                the function used to calculate the loss. Must be provided.

            - loss_params['loss_func_kwargs'] (dict):
                Keyword parameters sent to ``loss_params['loss_func']``. Default is {}.

            - loss_params['agg_func']:
                The aggregate function, default is None.

            - loss_params['agg_func_kwargs']: 
                Keyword parameters sent to ``loss_params['agg_func']``. Default is {}.

            - loss_params['loss_per_case_func'] (Deprecated):
                Deprecated parameter, the same as ``loss_params['loss_func']``.

            - loss_params['targets'] (Deprecated):
                Deprecated parameter, the same as ``loss_params['targets']``.

        learning_rate_params (dict): Parameters for specifying learning_rate.

            - learning_rate_params['func']:
                The function producing tensorflow node acting as learning rate. 
                This function must accept argument ``global_step``.

            - remainder of learning_rate_params are arguments to func.

        optimizer_params (dict): Parameters for creating optimizer.

            - optimizer_params['optimizer']:
                A class producing an optimizer object, 
                which should have function ``compute_gradients`` and ``apply_gradients``. 
                The signatures of these two functions are similar as tensorflow basic optimizer classes.

                Must accept:

                - "learning_rate" -- the result of the learning_rate_func call

                - Remainder of optimizer_params (aside form "optimizer") are arguments
                  to the optimizer func

            - optimizer_params['func'] (Deprecated):
                Deprecated parameter, the same as ``optimizer_params['optimizer']``.

        validation_params (dict): Dictionary of validation sources. The structure if this dictionary is:

            {
                <validation_target_name_1>: {
                    data: {
                        'func': (callable) data source function for this validation,

                        <kwarg1>: <value1> for 'func',

                        ...
                        },
                    targets: {
                        'func': (callable) returning targets,

                        <kwarg1>: <value1> for 'func',

                        ...
                        },
                    num_steps (int): 
                        number of batches of validation source to compute,
                    agg_func (optional, callable):  
                        how to aggregate validation results
                        across batches after computation. Signature is:

                            - one input argument: the list of validation batch results
                            - one output: aggregated version
                        Default is ``utils.identity_func``
                    online_agg_func (optional, callable):  
                        how to aggregate validation results
                        on a per-batch basis. Siganture is:

                            - three input arguments: (current aggregate, new result, step)
                            - one output: new aggregated result
                        On first step, current aggregate passed in is None.
                        The final result is passed to the "agg_func".
                        Default is ``utils.append_and_return``
                },

                <validation_target_name_2>: ...
            }

            For each validation_target_name key, the targets are computed and then added to
            the output dictionary to be computed every so often -- unlike train_targets which
            are computed on each time step, these are computed on a basic controlled by the
            valid_save_freq specific in the save_params.

        load_params (dict):
            Similar to save_params, if you want loading to happen from a different
            location than where saving occurs. Parameters include:

            - host (str)
                Hostname where database connection lives
            - port (int)
                Port where database connection lives
            - dbname (str)
                Name of database for storage
            - collname (str)
                Name of collection for storage
            - exp_id (str)
                Experiment id descriptor
            - do_restore (bool, default: True)
                Whether to restore from saved model
            - query (dict)
                mongodb query describing how to load from loading database
            - from_ckpt (string)
                Path to load from a TensorFlow checkpoint (instead of from the db)
            - to_restore (list of strings or a regex/callable which returns strings)
                Specifies which variables should be loaded from the checkpoint.
                Any variables not specified here will be reinitialized.
            - load_param_dict (dict)
                A dictionary whose keys are the names of the variables that are to be loaded
                from the checkpoint, and the values are the names of the variables of the model
                that you want to restore with the value of the corresponding checkpoint variable.

        log_device_placement (bool, default is False): 
            Advanced parameter. Whether to log device placement in tensorflow session

        dont_run (bool, default is False): 
            Advanced parameter. Whether returning everything, not actually training 

        skip_check (bool, default is False): 
            Advanced parameter. Whether skipping github check, could be useful when working in detached head

    """

    # use tpu only if a tpu_name has been specified and not a multi-model
    if isinstance(model_params, list):  # multi-model mode
        use_tpu = (model_params[0].get('tpu_name', None) is not None)
        assert (use_tpu is False)
    else:
        use_tpu = (model_params.get('tpu_name', None) is not None)
    if use_tpu:
        log.info('Using tpu: %s' % model_params['tpu_name'])
    params, train_args = parse_params(
        'train',
        model_params,
        dont_run=dont_run,
        skip_check=skip_check,
        load_params=load_params,
        loss_params=loss_params,
        save_params=save_params,
        train_params=train_params,
        optimizer_params=optimizer_params,
        validation_params=validation_params,
        learning_rate_params=learning_rate_params,
        log_device_placement=log_device_placement,
        use_tpu=use_tpu or use_estimator)

    if use_estimator or use_tpu:
        return tpu_train_from_params(params, train_args, use_tpu=use_tpu)
    else:
        with tf.Graph().as_default(), tf.device(DEFAULT_HOST):
            # For convenience, use list of dicts instead of dict of lists
            _params = [{key: value[i]
                        for (key, value) in params.items()}
                       for i in range(len(params['model_params']))]
            _trargs = [{key: value[i]
                        for (key, value) in train_args.items()}
                       for i in range(len(params['model_params']))]

            # Use a single dataprovider for all models.
            data_params = _params[0]['train_params']['data_params']

            _params[0]['train_params']['data_params'], inputs \
                    = get_data(**data_params)

            # Build a graph for each distinct model.
            var_manager_list = []
            for param, trarg in zip(_params, _trargs):
                _, _, param, trarg, var_manager \
                        = get_model(inputs,
                                param['model_params'],
                                param=param,
                                trarg=trarg)

                trarg['validation_targets'], _ = \
                        get_valid_targets_dict(
                                var_manager=var_manager,
                                **param)
                var_manager_list.append(var_manager)

            # Create session.
            gpu_options = tf.GPUOptions(allow_growth=True)
            sess = tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True,
                gpu_options=gpu_options,
                log_device_placement=log_device_placement,
            ))

            # Initialize variables here
            init_op_global = tf.global_variables_initializer()
            sess.run(init_op_global)
            init_op_local = tf.local_variables_initializer()
            sess.run(init_op_local)
            log.info('Initialized from scratch first')

            # Build database interface for each model
            # This interface class will handle the records saving, model saving, and
            # model restoring.
            for param, trarg, var_manager in zip(_params, _trargs,
                                                 var_manager_list):

                trarg['dbinterface'] = DBInterface(
                    sess=sess,
                    params=param,
                    var_manager=var_manager,
                    global_step=trarg['global_step'],
                    save_params=param['save_params'],
                    load_params=param['load_params'])
                ## Model will be restored from saved database here
                trarg['dbinterface'].initialize()

            # Convert back to a dictionary of lists
            params = {
                key: [param[key] for param in _params]
                for key in _params[0].keys()
            }
            train_args = {
                key: [trarg[key] for trarg in _trargs]
                for key in _trargs[0].keys()
            }

            if dont_run:
                return train_args

            return train(sess, **train_args)
Ejemplo n.º 12
0
    def load_from_db(self,
                     query,
                     cache_filters=False,
                     collfs=None,
                     collfs_recent=None):
        """Load the latest checkpoint record matching ``query`` from the database.

        Both the permanent (``collfs``) and the recent (``collfs_recent``)
        GridFS stores are searched for records with ``saved_filters=True``
        that match ``query``. The matching record with the most recent
        ``uploadDate`` wins.

        Args:
            query (dict): MongoDB query. The caller's dict is not modified; a
                copy with ``saved_filters=True`` added is used for the lookup.
            cache_filters (bool): If True, make sure the checkpoint is present
                in ``self.cache_dir`` (downloading it, and untarring V2
                checkpoints, when no local copy exists) and return its path.
            collfs: GridFS store of permanent checkpoints. Defaults to
                ``self.collfs``.
            collfs_recent: GridFS store of recent checkpoints. Defaults to
                ``self.collfs_recent``.

        Returns:
            ``None`` if no record matches. Otherwise a tuple
            ``(ckpt_record, cache_filename)``, where ``cache_filename`` is the
            local checkpoint path, or ``None`` unless ``cache_filters`` is True.

        Raises:
            er.OperationFailure: if querying the recent store fails. The
                message suggests that an over-long dbname may be the cause.
        """
        if collfs is None:
            collfs = self.collfs
        coll = collfs._GridFS__files
        if collfs_recent is None:
            collfs_recent = self.collfs_recent
        coll_recent = collfs_recent._GridFS__files

        # Only records that actually carry a checkpoint are candidates. Work on
        # a copy so that the caller's query dict is left untouched.
        query = dict(query)
        query['saved_filters'] = True

        ckpt_record = None
        loading_from = None
        count = collfs.find(query).count()
        if count > 0:  # get latest that matches query
            ckpt_record = coll.find(query, sort=[('uploadDate', -1)])[0]
            loading_from = coll

        try:
            count_recent = collfs_recent.find(query).count()
        except Exception as inst:
            # str(inst) also works when args is empty or not a string.
            raise er.OperationFailure(
                str(inst) +
                "\n Is your dbname too long? Mongo requires that dbnames be no longer than 64 characters."
            )
        if count_recent > 0:  # get latest that matches query
            ckpt_record_recent = coll_recent.find(query,
                                                  sort=[('uploadDate', -1)])[0]
            # Use whichever of the two records was uploaded last.
            if ckpt_record is None or ckpt_record_recent[
                    'uploadDate'] > ckpt_record['uploadDate']:
                loading_from = coll_recent
                ckpt_record = ckpt_record_recent

        if ckpt_record is None:  # no matches for query
            log.warning('No matching checkpoint for query "{}"'.format(
                repr(query)))
            return None

        database = loading_from._Collection__database
        log.info('Loading checkpoint from %s' % loading_from.full_name)

        if not cache_filters:
            return ckpt_record, None

        filename = os.path.basename(ckpt_record['filename'])
        cache_filename = os.path.join(self.cache_dir, filename)
        is_v2 = ckpt_record['_saver_write_version'] == saver_pb2.SaverDef.V2

        # check if there is no local copy
        if not os.path.isfile(cache_filename):
            log.info('No cache file at %s, loading from DB' % cache_filename)
            fsbucket = gridfs.GridFSBucket(
                database, bucket_name=loading_from.name.split('.')[0])
            try:
                # Stream the GridFS file into a new local binary file. The
                # ``with`` closes it even if the download fails.
                with open(cache_filename, 'wb') as load_dest:
                    fsbucket.download_to_stream(ckpt_record['_id'], load_dest)
            except Exception:
                # Do not leave a partial file behind: a later call would
                # mistake it for a valid cached checkpoint.
                if os.path.exists(cache_filename):
                    os.remove(cache_filename)
                raise
            if is_v2:
                # V2 checkpoints are stored as a single tar of their files.
                assert cache_filename.endswith('.tar')
                # NOTE(review): extractall trusts member paths in the archive;
                # fine for archives this code wrote, unsafe for foreign ones.
                with tarfile.open(cache_filename) as tar:
                    tar.extractall(path=self.cache_dir)
                cache_filename = os.path.splitext(cache_filename)[0]
                verify_pb2_v2_files(cache_filename, ckpt_record)
        else:
            if is_v2:
                cache_filename = os.path.splitext(cache_filename)[0]
                verify_pb2_v2_files(cache_filename, ckpt_record)
            log.info('Cache file found at %s, using that to load' %
                     cache_filename)
        return ckpt_record, cache_filename
Ejemplo n.º 13
0
    def initialize(self):
        """Fetch record then uses tf's saver.restore.

        When ``self.do_restore`` is set, the checkpoint is chosen as follows:

        - ``self.from_ckpt`` is used if given.
        - Otherwise the database record in ``self.load_data`` is used. It is
          fetched with ``self.load_rec()`` first if it has not been loaded.

        The variables returned by ``self.get_restore_vars`` are restored with
        a ``tf.train.Saver``. Every other variable in ``self.var_list`` is
        re-initialized, and the method then asserts that no variable is left
        uninitialized.

        If nothing is restored (``do_restore`` is off, or no checkpoint was
        found), all global and local variables are initialized instead.

        In both cases the variable manager's post-init ops (if any) are run
        afterwards.
        """
        if self.do_restore:
            # First, determine which checkpoint to use.
            if self.from_ckpt is not None:
                # Use an explicitly given checkpoint file.
                ckpt_filename = self.from_ckpt
                log.info('Restoring variables from checkpoint %s ...' %
                         ckpt_filename)
            else:
                # Otherwise, use a database checkpoint, fetching the record
                # first if that has not happened yet.
                if self.load_data is None:
                    self.load_rec()
                if self.load_data is not None:
                    rec, ckpt_filename = self.load_data
                    log.info('Restoring variables from record %s (step %d)...' %
                             (str(rec['_id']), rec['step']))
                else:
                    # No db checkpoint to load.
                    ckpt_filename = None

            if ckpt_filename is not None:
                # Determine which vars should be restored from the checkpoint.
                restore_vars = self.get_restore_vars(ckpt_filename)
                restore_names = list(restore_vars)
                # Rename checkpoint names that load_param_dict maps to a
                # different graph name, so the bookkeeping below compares
                # against names in self.var_list.
                if self.load_param_dict:
                    for old_name, new_name in self.load_param_dict.items():
                        if old_name in restore_names:
                            restore_names.remove(old_name)
                            restore_names.append(new_name)

                # Actually load the vars.
                log.info('Restored Vars (in ckpt, in graph):\n' +
                         str(restore_names))
                tf_saver_restore = tf.train.Saver(restore_vars)
                tf_saver_restore.restore(self.sess, ckpt_filename)
                log.info('... done restoring.')

                # Run post init_ops if needed
                if self.var_manager:
                    self.sess.run(
                        tf.group(*self.var_manager.get_post_init_ops()))

                # Reinitialize all other, unrestored vars. A set keeps the
                # membership tests O(1).
                restored_names = set(restore_names)
                unrestored_vars = [
                    var for name, var in self.var_list.items()
                    if name not in restored_names
                ]
                # Only used for logging. Names ending in one of OPTIMIZER_NAMES
                # (presumably optimizer slot variables) are left out of the log.
                unrestored_var_names = [
                    name for name in self.var_list
                    if name not in restored_names and
                    not any(name.endswith(s) for s in OPTIMIZER_NAMES)
                ]
                log.info('Unrestored Vars (in graph, not in ckpt):\n' +
                         str(unrestored_var_names))
                self.sess.run(tf.variables_initializer(
                    unrestored_vars))  # initialize variables not restored
                # Evaluate the report once and reuse it as the assert message.
                uninitialized = self.sess.run(
                    tf.report_uninitialized_variables())
                assert len(uninitialized) == 0, uninitialized

        if not self.do_restore \
                or (self.load_data is None and self.from_ckpt is None):
            init_op_global = tf.global_variables_initializer()
            self.sess.run(init_op_global)
            init_op_local = tf.local_variables_initializer()
            self.sess.run(init_op_local)
            if self.var_manager:
                self.sess.run(tf.group(*self.var_manager.get_post_init_ops()))
Ejemplo n.º 14
0
def tpu_train_from_params(params, train_args, use_tpu=False):
    """
    Main tpu training interface function, called by train_from_params in tfutils.train.
    See the doc string there for info.

    Args:
        params (dict): dict of per-model lists of parameter groups
            (``model_params``, ``train_params``, ``validation_params``, ...).
            Only a single model is supported.
        train_args (dict): dict of per-model lists of training arguments.
        use_tpu (bool): If True, train with a ``TPUEstimator``. In pod mode
            (``num_shards`` > 8), a second ``TPUEstimator`` is also built on
            ``val_tpu_name`` for validation. If False, train with a plain
            ``tf.estimator.Estimator`` and no separate validation estimator.

    Returns:
        The result of ``train_estimator``.
    """

    # use this for tpu and estimator logging
    tf.logging.set_verbosity(tf.logging.INFO)

    # Support only single model: take its entry from each dict of lists.
    assert len(params['model_params']) == 1
    param = {key: value[0] for (key, value) in params.items()}
    trarg = {key: value[0] for (key, value) in train_args.items()}
    train_data_params = param['train_params']['data_params']

    model_params = param['model_params']
    lr_params = param['learning_rate_params']
    opt_params = param['optimizer_params']
    loss_params = param['loss_params']
    validation_params = param['validation_params']
    save_params = param['save_params']
    # set up estimator func
    estimator_fn, params_to_pass = create_train_estimator_fn(
        use_tpu=use_tpu,
        model_params=model_params,
        lr_params=lr_params,
        opt_params=opt_params,
        loss_params=loss_params,
        validation_params=validation_params)

    # Only built in TPU pod mode, but passed to train_estimator in every mode,
    # so it must always be bound.
    val_estimator_classifier = None

    if use_tpu:
        if validation_params:
            # The eval batch size is taken from the first validation target.
            valid_k = next(iter(validation_params))
            validation_data_params = validation_params[valid_k]['data_params']
            eval_batch_size = validation_data_params['batch_size']
        else:
            eval_batch_size = None
        # grab tpu name and gcp, etc from model params
        train_m_config = create_train_tpu_config(
            model_dir=save_params.get('cache_dir', ''),
            tpu_name=model_params.get('tpu_name', None),
            gcp_project=model_params.get('gcp_project', None),
            steps_per_checkpoint=save_params.get('save_filters_freq', None),
            tpu_zone=model_params.get('tpu_zone', DEFAULT_TPU_ZONE),
            num_shards=model_params.get('num_shards', DEFAULT_NUM_SHARDS),
            keep_checkpoint_max=save_params.get('checkpoint_max', 5),
            iterations_per_loop=model_params.get('iterations_per_loop',
                                                 DEFAULT_ITERATIONS_PER_LOOP),
            model_params=model_params)
        train_estimator_classifier = tpu_estimator_lib.TPUEstimator(
            use_tpu=True,
            model_fn=estimator_fn,
            config=train_m_config,
            train_batch_size=train_data_params['batch_size'],
            eval_batch_size=eval_batch_size,
            params=params_to_pass)

        if model_params.get('num_shards', DEFAULT_NUM_SHARDS) > 8:
            log.info("You are training in pod mode")
            log.info(
                "Setting up validation on a single independent TPU device")
            assert model_params.get('val_tpu_name') is not None
            val_m_config = create_train_tpu_config(
                model_dir=save_params.get('cache_dir', ''),
                tpu_name=model_params.get('val_tpu_name', None),
                gcp_project=model_params.get('gcp_project', None),
                steps_per_checkpoint=save_params.get('save_filters_freq',
                                                     None),
                tpu_zone=model_params.get('val_tpu_zone', DEFAULT_TPU_ZONE),
                num_shards=8,
                keep_checkpoint_max=save_params.get('checkpoint_max', 5),
                iterations_per_loop=model_params.get(
                    'iterations_per_loop', DEFAULT_ITERATIONS_PER_LOOP),
                model_params=model_params)

            val_estimator_classifier = tpu_estimator_lib.TPUEstimator(
                use_tpu=True,
                model_fn=estimator_fn,
                config=val_m_config,
                train_batch_size=train_data_params['batch_size'],
                eval_batch_size=eval_batch_size,
                params=params_to_pass)

    else:
        train_estimator_classifier = tf.estimator.Estimator(
            model_fn=estimator_fn, params=params_to_pass)

    return train_estimator(train_cls=train_estimator_classifier,
                           eval_cls=val_estimator_classifier,
                           param=param,
                           trarg=trarg)
Ejemplo n.º 15
0
    def save(self,
             train_res=None,
             valid_res=None,
             step=None,
             validation_only=False):
        """Actually save record into DB and makes local filter caches.

        Results accumulate in ``self.rec_to_save`` across calls. The record is
        written (on a background ``CoordinatedThread`` running
        ``self._save_thread``) only when ``self.do_save`` is set and one of
        these holds:

        - one of the save frequencies divides ``step``;
        - ``validation_only`` is True (the record is always written).

        Args:
            train_res (dict, optional): training results for this step.
                ``'optimizer'`` and ``'__grads__'`` entries are dropped.
                ndarray entries whose first dimension is longer than 1 are
                replaced by their mean.
            valid_res (dict, optional): validation results keyed by
                validation target name.
            step (int, optional): global step. When omitted (and not
                ``validation_only``) it is read from ``self.global_step``.
            validation_only (bool): mark the record as validating the loaded
                record ``self.load_data[0]`` instead of a training step.

        Raises:
            NoGlobalStepError: if ``step`` is None, ``validation_only`` is
                False and ``self.global_step`` has no ``eval`` method.
        """
        if train_res is None:
            train_res = {}
        if valid_res is None:
            valid_res = {}

        if (not validation_only) and (step is None):
            if not hasattr(self.global_step, 'eval'):
                raise NoGlobalStepError(
                    'If step is none, you must pass global_step'
                    ' tensorflow operation to the saver.')
            step = self.global_step.eval(session=self.sess)

        # Shallow copies, so that entries popped/replaced below do not affect
        # the caller's dicts.
        train_res = copy.copy(train_res)
        valid_res = {_k: copy.copy(_v) for _k, _v in valid_res.items()}
        duration = time.time() - self.start_time_step

        # A single pending record collects results until it is written out.
        if self.rec_to_save is None:
            rec = {
                'exp_id': self.exp_id,
                'params': self.sonified_params,
                'saved_filters': False,
                'duration': duration
            }
            self.rec_to_save = rec
        else:
            rec = self.rec_to_save
        rec['step'] = step

        if len(train_res) > 0:
            # TODO: also include error rate of the train set to monitor overfitting
            message = 'Step {} ({:.0f} ms) -- '.format(step, 1000 * duration)

            # If ndarray found, get the mean of it. The ndim check avoids
            # len() raising TypeError on 0-d arrays.
            for k, v in train_res.items():
                if k not in ['optimizer', '__grads__'] and \
                        isinstance(v, np.ndarray) and v.ndim > 0 and len(v) > 1:
                    train_res[k] = np.mean(v)

            msg2 = [
                '{}: {:.4f}'.format(k, v) for k, v in train_res.items()
                if k not in ['optimizer', '__grads__']
                and k not in self.save_to_gfs
            ]
            message += ', '.join(msg2)
            log.info(message)

            if '__grads__' in train_res:
                del train_res['__grads__']
            if 'optimizer' in train_res:
                del train_res['optimizer']
            if 'train_results' not in rec:
                rec['train_results'] = []
            rec['train_results'].append(train_res)

        # print validation set performance
        if len(valid_res) > 0:
            rec['validation_results'] = valid_res
            message = 'Validation -- '
            message += ', '.join('{}: {}'.format(
                k,
                {_k: _v
                 for _k, _v in v.items() if _k not in self.save_to_gfs})
                                 for k, v in valid_res.items())
            log.info(message)

        if validation_only:
            rec['validates'] = self.load_data[0]['_id']
            save_filters_permanent = save_filters_tmp = False
            need_to_save = True
        else:
            save_filters_permanent = (
                (step % self.save_filters_freq == 0)
                and (step > 0 or
                     (self.save_initial_filters and not self.load_data)))
            save_filters_tmp = (
                (step % self.cache_filters_freq == 0)
                and (step > 0 or
                     (self.save_initial_filters and not self.load_data)))
            save_metrics_now = step % self.save_metrics_freq == 0
            save_valid_now = step % self.save_valid_freq == 0
            need_to_save = save_filters_permanent or save_filters_tmp or save_metrics_now or save_valid_now

        need_to_save = self.do_save and need_to_save

        if need_to_save:
            self.rec_to_save = None
            self.sync_with_host()
            # Move the items listed in self.save_to_gfs out of the record;
            # they are stored separately in GridFS by _save_thread.
            save_to_gfs = {}
            for _k in self.save_to_gfs:
                if train_res:
                    if 'train_results' not in save_to_gfs:
                        save_to_gfs['train_results'] = {}
                    if _k in train_res:
                        gfs_values = [
                            r.pop(_k) for r in rec['train_results'] if _k in r
                        ]
                        # Unwrap a single accumulated value.
                        if len(gfs_values) == 1:
                            gfs_values = gfs_values[0]
                        save_to_gfs['train_results'][_k] = gfs_values
                if valid_res:
                    if 'validation_results' not in save_to_gfs:
                        save_to_gfs['validation_results'] = {}
                    for _vk in valid_res:
                        if _vk not in save_to_gfs['validation_results']:
                            save_to_gfs['validation_results'][_vk] = {}
                        if _k in valid_res[_vk]:
                            save_to_gfs['validation_results'][_vk][
                                _k] = valid_res[_vk].pop(_k)

            save_rec = sonify(rec, skip=self._skip_check)
            make_mongo_safe(save_rec)

            coord = tf.train.Coordinator()
            thread = CoordinatedThread(coord=coord,
                                       target=self._save_thread,
                                       args=(save_filters_permanent,
                                             save_filters_tmp, save_rec, step,
                                             save_to_gfs))
            thread.daemon = True
            thread.start()
            self.checkpoint_thread = thread
            self.checkpoint_coord = coord
Ejemplo n.º 16
0
def test_from_params(load_params,
                     model_params,
                     validation_params,
                     log_device_placement=False,
                     save_params=None,
                     dont_run=False,
                     skip_check=False,
                     use_estimator=False):
    """
    Main testing interface function.

    Same as train_from_parameters; but just performs testing without training.

    For documentation, see argument descriptions in train_from_params.

    Returns:
        If ``use_estimator`` is set or ``model_params`` names a ``tpu_name``,
        the result of ``tpu_test_from_params``. Otherwise:

        - when ``dont_run`` is True, the dict of per-model test arguments
          (the session is left open);
        - otherwise, the result of ``test``, with the session closed
          afterwards even if ``test`` raises.
    """
    # use tpu only if a tpu_name has been specified and not a multi-model
    if isinstance(model_params, list):  # multi-model mode
        use_tpu = (model_params[0].get('tpu_name', None) is not None)
        assert (use_tpu is False)
    else:
        use_tpu = (model_params.get('tpu_name', None) is not None)
    if use_tpu:
        log.info('Using tpu: %s' % model_params['tpu_name'])

    # NOTE(review): train_from_params passes ``use_tpu or use_estimator`` to
    # parse_params; confirm whether the estimator-only path needs the same.
    params, test_args = parse_params('test',
                                     model_params,
                                     dont_run=dont_run,
                                     skip_check=skip_check,
                                     save_params=save_params,
                                     load_params=load_params,
                                     validation_params=validation_params,
                                     log_device_placement=log_device_placement,
                                     use_tpu=use_tpu)

    # do not need to create sess with estimator interface
    if use_estimator or use_tpu:
        return tpu_test_from_params(params, test_args, use_tpu=use_tpu)

    with tf.Graph().as_default(), tf.device(DEFAULT_HOST):

        # create session
        gpu_options = tf.GPUOptions(allow_growth=True)
        sess = tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True,
            gpu_options=gpu_options,
            log_device_placement=log_device_placement,
        ))

        init_op_global = tf.global_variables_initializer()
        sess.run(init_op_global)
        init_op_local = tf.local_variables_initializer()
        sess.run(init_op_local)
        log.info('Initialized from scratch first')

        # For convenience, use list of dicts instead of dict of lists
        num_models = len(params['model_params'])
        _params = [{key: value[i]
                    for (key, value) in params.items()}
                   for i in range(num_models)]
        _ttargs = [{key: value[i]
                    for (key, value) in test_args.items()}
                   for i in range(num_models)]

        # Build a graph for each distinct model.
        for param, ttarg in zip(_params, _ttargs):
            model_load_params = param['load_params']
            log.info('load_params: {}'.format(model_load_params))
            use_ckpt = model_load_params.get('from_ckpt') is not None

            # Fall back to this model's save cache_dir. Using the per-model
            # dicts (not the raw arguments) keeps this working when
            # save_params was not passed and in multi-model mode.
            if 'cache_dir' not in model_load_params:
                temp_cache_dir = param['save_params'].get('cache_dir', None)
                model_load_params['cache_dir'] = temp_cache_dir
                log.info('cache_dir not found in load_params, '
                         'using cache_dir ({}) from save_params'.format(
                             temp_cache_dir))

            ttarg['dbinterface'] = DBInterface(
                var_manager=None,
                params=param,
                load_params=model_load_params)
            if not use_ckpt:
                ttarg['dbinterface'].load_rec()
                ld = ttarg['dbinterface'].load_data
                assert ld is not None, "No load data found for query, aborting"
                ld = ld[0]
                # TODO: have option to reconstitute model_params entirely from
                # saved object ("revivification")
                param['model_params']['seed'] = ld['params'][
                    'model_params']['seed']
                cfg_final = ld['params']['model_params']['cfg_final']
            else:
                cfg_final = param['model_params'].get('cfg_final', {})

            ttarg['validation_targets'], var_manager \
                    = get_valid_targets_dict(
                        loss_params=None,
                        cfg_final=cfg_final,
                        **param)

            model_load_params['do_restore'] = True
            param['model_params']['cfg_final'] = cfg_final

            # Build database interface class, loading model
            ttarg['dbinterface'] = DBInterface(
                sess=sess,
                params=param,
                var_manager=var_manager,
                load_params=model_load_params,
                save_params=param['save_params'])
            ttarg['dbinterface'].initialize()

            ttarg['save_intermediate_freq'] \
                    = param['save_params'].get('save_intermediate_freq')

        # Convert back to a dictionary of lists
        test_args = {
            key: [ttarg[key] for ttarg in _ttargs]
            for key in _ttargs[0].keys()
        }

        if dont_run:
            return test_args

        try:
            return test(sess, **test_args)
        finally:
            sess.close()
Ejemplo n.º 17
0
    def _save_thread(self, save_filters_permanent, save_filters_tmp, save_rec,
                     step, save_to_gfs):
        """Write a checkpoint and/or a result record to the database.

        This is the target of the background thread started by ``save``.

        Args:
            save_filters_permanent (bool): save a TF checkpoint and store it,
                with ``save_rec`` as its metadata, in ``self.collfs``.
            save_filters_tmp (bool): save a TF checkpoint into
                ``self.collfs_recent``. Afterwards the oldest entries there
                beyond ``self.cache_max_num`` are deleted.
            save_rec (dict): record to store. It is updated in place with
                ``saved_filters`` and saver-version metadata.
            step (int): global step, passed to the saver as ``global_step``.
            save_to_gfs (dict): items that are pickled into a separate GridFS
                file, linked to the record through ``item_for``.
        """
        import itertools

        if save_filters_permanent or save_filters_tmp:
            save_rec['saved_filters'] = True
            save_path = os.path.join(self.cache_dir, 'checkpoint')
            log.info('Saving model with path prefix %s ... ' % save_path)
            saved_path = self.tf_saver.save(self.sess,
                                            save_path=save_path,
                                            global_step=step,
                                            write_meta_graph=False)
            log.info('... done saving with path prefix %s' % saved_path)
            putfs = self.collfs if save_filters_permanent else self.collfs_recent
            log.info('Putting filters into %s database' % repr(putfs))
            save_rec['_saver_write_version'] = self.tf_saver._write_version
            if self.tf_saver._write_version == saver_pb2.SaverDef.V2:
                # A V2 checkpoint is several files; bundle them into one tar
                # so that it can be stored as a single GridFS file.
                file_data = get_saver_pb2_v2_files(saved_path)
                save_rec['_saver_num_data_files'] = file_data['num_data_files']
                upload_path = saved_path + '.tar'
                with tarfile.open(upload_path, 'w') as tar:
                    for _f in file_data['files']:
                        tar.add(_f, arcname=os.path.split(_f)[1])
            else:
                upload_path = saved_path
            with open(upload_path, 'rb') as _fp:
                outrec = putfs.put(_fp, filename=upload_path, **save_rec)
            log.info('... done putting filters into database.')

            if not save_filters_permanent:
                recent_gridfs_files = self.collfs_recent._GridFS__files
                recent_query_result = recent_gridfs_files.find(
                    {'saved_filters': True}, sort=[('uploadDate', 1)])
                num_cached_filters = recent_query_result.count()
                num_to_delete = num_cached_filters - self.cache_max_num
                if num_to_delete > 0:
                    log.info('Cleaning up cached filters')
                    fsbucket = gridfs.GridFSBucket(
                        recent_gridfs_files._Collection__database,
                        bucket_name=recent_gridfs_files.name.split('.')[0])
                    # Collect the ids of the oldest entries before deleting.
                    # Indexing the cursor (cursor[i]) runs a fresh skip/limit
                    # query each time, so deleting while indexing would shift
                    # positions and remove the wrong entries.
                    stale_ids = [
                        doc['_id'] for doc in itertools.islice(
                            recent_query_result, num_to_delete)
                    ]
                    for stale_id in stale_ids:
                        fsbucket.delete(stale_id)

        if not save_filters_permanent:
            save_rec['saved_filters'] = False
            log.info('Inserting record into database.')
            outrec = self.collfs._GridFS__files.insert_one(save_rec)

        # GridFS put returns an ObjectId; insert_one returns a result object.
        if not isinstance(outrec, ObjectId):
            outrec = outrec.inserted_id

        if save_to_gfs:
            idval = str(outrec)
            save_to_gfs_path = idval + "_fileitems"
            self.collfs.put(cPickle.dumps(save_to_gfs),
                            filename=save_to_gfs_path,
                            item_for=outrec)

        sys.stdout.flush()  # flush the stdout buffer
        self.outrecs.append(outrec)
Ejemplo n.º 18
0
def train(sess,
          dbinterface,
          train_loop,
          train_targets,
          global_step,
          num_minibatches=1,
          num_steps=float('inf'),
          thres_loss=DEFAULT_TRAIN_THRES_LOSS,
          validate_first=True,
          validation_targets=None):
    """Actually runs the training evaluation loop.

    Every argument except ``sess`` is normally a list with one entry per
    model; entry ``i`` of each list belongs to model ``i`` and the number of
    models is ``len(train_targets)``. A value that is not a list/tuple (for
    example one of the scalar defaults in the signature) is shared by all
    models.

    Args:
        sess (tensorflow.Session):
            Object in which to run calculations.

        dbinterface (list of DBInterface objects): Savers through which to
            save results.

        train_loop (list of callables with args: sess and train_targets):
            Callables that specify a custom training loop. Only the first one
            is used; it is called once per step with the whole
            ``train_targets`` list and must return one result dict (with a
            'loss' entry) per model.
        train_targets (list of dicts of tensorflow nodes): Targets to train.
            One item in each dict must be "optimizer" or similar
            to make anything happen.
        global_step (list of tensorflow variables): Step counters that each
            training step must increment.
        num_minibatches (int or list of int): How many minibatches to use
            before applying a gradient update. The value of the last model is
            passed to ``train_loop``.
        num_steps (int or list of int): How many steps to train to before
            quitting.
        thres_loss (float or list of float, default: DEFAULT_TRAIN_THRES_LOSS):
            If loss exceeds this during training, HiLossError is thrown.
        validate_first (bool or list of bool, default: True): Whether to run
            validation before training when the global step is 0.
        validation_targets (list of dicts of tensorflow objects, default: None):
            Objects on which validation will be computed. None means nothing
            to validate.

    Returns:
        list: ``dbinterface.outrecs`` of every model after syncing, or None if
        training was cancelled because a model already reached ``num_steps``.

    Raises:
        NoChangeError: If a training step did not increment the global step.
        NanLossError: If the training loss became NaN.
        HiLossError: If the training loss exceeded ``thres_loss``.

    """
    def _for_model(value, index):
        # Per-model arguments are lists (or tuples); any other value is
        # shared by every model so the scalar signature defaults work.
        if isinstance(value, (list, tuple)):
            return value[index]
        return value

    # Collect args in a dict of lists
    train_args = {
        'num_steps': num_steps,
        'thres_loss': thres_loss,
        'train_loop': train_loop,
        'global_step': global_step,
        'dbinterface': dbinterface,
        'train_targets': train_targets,
        'validate_first': validate_first,
        'num_minibatches': num_minibatches,
        'validation_targets': validation_targets
    }

    # Convert to a list of dicts, one per model
    trargs = [{key: _for_model(value, i)
               for (key, value) in train_args.items()}
              for i in range(len(train_targets))]
    for trarg in trargs:
        # None means "nothing to validate"; the training loop below uses {}
        # for exactly that purpose.
        if trarg['validation_targets'] is None:
            trarg['validation_targets'] = {}

    num_steps = [t['num_steps'] for t in trargs]
    steps = [t['global_step'].eval(session=sess) for t in trargs]

    # Start initial validation
    for (step, trarg) in zip(steps, trargs):

        if step >= trarg['num_steps']:
            log.info('Training cancelled since step ({}) is >= num_steps ({})'.
                     format(step, trarg['num_steps']))
            return

        log.info('Training beginning ...')

        if step == 0:
            trarg['dbinterface'].start_time_step = time.time()
            if trarg['validate_first']:
                run_all_validations(
                    sess,
                    trarg['validation_targets'],
                    dbinterface=trarg['dbinterface'])

    # A single training loop (the first model's) drives all models at once.
    train_loop = _for_model(train_loop, 0)

    # Run training
    while any(step < num_step for (step, num_step) in zip(steps, num_steps)):

        start_time_step = time.time()
        # The loop only runs when trargs is non-empty, so trargs[-1] exists;
        # the last model's num_minibatches is used for the shared loop.
        train_results = train_loop(sess,
                                   train_targets,
                                   num_minibatches=trargs[-1]['num_minibatches'])

        for (step, trarg, train_res) in zip(steps, trargs, train_results):

            old_step = step
            step = trarg['global_step'].eval(session=sess)

            if step <= old_step:
                raise NoChangeError(
                        'Your optimizer should have incremented the global step,'
                        ' but did not: old_step=%d, new_step=%d'
                        % (old_step, step))
            if np.isnan(train_res['loss']):
                raise NanLossError(
                        'Loss has become NaN')
            if train_res['loss'] > trarg['thres_loss']:
                raise HiLossError(
                        'Loss {:.2f} exceeded the threshold {:.2f}'.format(
                            train_res['loss'],
                            trarg['thres_loss']))

            # Validation: only every save_valid_freq steps, otherwise nothing.
            vtargs = trarg['validation_targets'] \
                    if step % trarg['dbinterface'].save_valid_freq == 0 else {}
            valid_res = run_all_validations(sess, vtargs)

            # Save
            trarg['dbinterface'].start_time_step = start_time_step
            trarg['dbinterface'].save(train_res=train_res,
                                      valid_res=valid_res,
                                      validation_only=False)

        steps = [t['global_step'].eval(session=sess) for t in trargs]

    # Sync and close the session
    res = []
    for trarg in trargs:
        trarg['dbinterface'].sync_with_host()
        res.append(trarg['dbinterface'].outrecs)

    sess.close()
    return res
Ejemplo n.º 19
0
def test_from_params(load_params,
                     model_params,
                     validation_params,
                     log_device_placement=False,
                     save_params=None,
                     dont_run=False,
                     skip_check=False,
                     ):
    """
    Main testing interface function.

    Same as train_from_params, but just performs testing (validation of a
    previously saved model) without training.

    For documentation, see argument descriptions in train_from_params.

    Returns:
        If ``dont_run`` is True, the dict of per-model argument lists that
        would be passed to ``test`` (nothing is run and the session is left
        open). Otherwise, the result of ``test(sess, **test_args)``.

    """
    params, test_args = parse_params(
            'test',
            model_params,
            dont_run=dont_run,
            skip_check=skip_check,
            save_params=save_params,
            load_params=load_params,
            validation_params=validation_params,
            log_device_placement=log_device_placement,
            )

    with tf.Graph().as_default(), tf.device(DEFAULT_HOST):

        # create session
        sess = tf.Session(
                config=tf.ConfigProto(
                    allow_soft_placement=True,
                    log_device_placement=log_device_placement,
                    ))

        # NOTE(review): these initializers run before any model graph is
        # built below, so they only cover variables already in the graph;
        # model variables are presumably restored by
        # DBInterface.initialize(no_scratch=True) — confirm.
        init_op_global = tf.global_variables_initializer()
        sess.run(init_op_global)
        init_op_local = tf.local_variables_initializer()
        sess.run(init_op_local)
        log.info('Initialized from scratch first')

        # For convenience, use list of dicts instead of dict of lists
        _params = [{key: value[i] for (key, value) in params.items()}
                   for i in range(len(params['model_params']))]
        _ttargs = [{key: value[i] for (key, value) in test_args.items()}
                   for i in range(len(params['model_params']))]

        # Build a graph for each distinct model.
        for param, ttarg in zip(_params, _ttargs):

            # Fall back to this model's save_params cache_dir when its
            # load_params name none. Use the per-model dicts (the ones the
            # DBInterface below actually receives) rather than the raw
            # arguments; the raw save_params defaults to None.
            model_load_params = param['load_params']
            if 'cache_dir' not in model_load_params:
                model_save_params = param['save_params'] or {}
                temp_cache_dir = model_save_params.get('cache_dir', None)
                model_load_params['cache_dir'] = temp_cache_dir
                log.info('cache_dir not found in load_params, using cache_dir ({}) from save_params'.format(temp_cache_dir))

            # Temporary interface used only to look up the saved record.
            ttarg['dbinterface'] = DBInterface(params=param, load_params=param['load_params'])
            ttarg['dbinterface'].load_rec()
            ld = ttarg['dbinterface'].load_data
            assert ld is not None, "No load data found for query, aborting"
            ld = ld[0]
            # TODO: have option to reconstitute model_params entirely from
            # saved object ("revivification")
            param['model_params']['seed'] = ld['params']['model_params']['seed']
            cfg_final = ld['params']['model_params']['cfg_final']

            ttarg['validation_targets'] = \
                    get_valid_targets_dict(
                        loss_params=None,
                        cfg_final=cfg_final,
                        **param)

            # tf.get_variable_scope().reuse_variables()

            param['load_params']['do_restore'] = True
            param['model_params']['cfg_final'] = cfg_final

            # Map this model's saveable objects to names without its scope
            # prefix before handing them to the restoring interface.
            prefix = param['model_params']['prefix'] + '/'
            all_vars = variables._all_saveable_objects()
            var_list = strip_prefix(prefix, all_vars)

            ttarg['dbinterface'] = DBInterface(sess=sess,
                                               params=param,
                                               var_list=var_list,
                                               load_params=param['load_params'],
                                               save_params=param['save_params'])
            ttarg['dbinterface'].initialize(no_scratch=True)
            ttarg['save_intermediate_freq'] = param['save_params'].get('save_intermediate_freq')

        # Convert back to a dictionary of lists
        params = {key: [param[key] for param in _params]
                  for key in _params[0].keys()}
        test_args = {key: [ttarg[key] for ttarg in _ttargs]
                     for key in _ttargs[0].keys()}

        if dont_run:
            return test_args

        res = test(sess, **test_args)
        sess.close()
        return res