def test_estimator(cls_dict, param, ttarg):
    # The load params query stores the path to a checkpoint.
    if param['load_params']['do_restore'] and (param['load_params']['query'] is not None):
        # Path to a specific checkpoint.
        load_dir = param['load_params']['query']
    else:
        # Use the latest checkpoint from model_dir.
        load_dir = None

    ttarg['dbinterface'] = DBInterface(sess=None,
                                       params=param,
                                       save_params=param['save_params'],
                                       load_params=param['load_params'])
    ttarg['dbinterface'].start_time_step = time.time()

    m_predictions = {}
    for valid_k in cls_dict.keys():
        cls = cls_dict[valid_k]
        validation_data_params = param['validation_params'][valid_k][
            'data_params']
        # Can be used to filter which prediction keys to save; if absent, it
        # is set to None and all keys are saved.
        filter_keys = param['validation_params'][valid_k].get('keys_to_save')
        session_hooks = param['validation_params'][valid_k].get('hooks')
        valid_fn = validation_data_params['func']
        log.info('Starting to evaluate ({}).'.format(valid_k))
        eval_results = cls.predict(input_fn=valid_fn,
                                   predict_keys=filter_keys,
                                   hooks=session_hooks,
                                   checkpoint_path=load_dir)
        m_predictions[valid_k] = list(eval_results)

    log.info('Saving eval results to database.')
    # Set validation_only to True to save only the results, not the filters.
    ttarg['dbinterface'].save(valid_res=m_predictions, validation_only=True)
    log.info('Done saving eval results to database.')

    # Sync with hosts.
    res = []
    ttarg['dbinterface'].sync_with_host()
    res.append(ttarg['dbinterface'].outrecs)

    # Return the final eval results for convenience.
    return eval_results, res
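# A minimal sketch of the structures test_estimator consumes, inferred from
# its use of `cls.predict` above. All names here are hypothetical stand-ins:
# each key of `cls_dict` is a validation target name mapped to a
# tf.estimator.Estimator, and `param['validation_params']` holds the matching
# data function (plus the optional predict-key filter and hooks) under the
# same key.
def _example_test_estimator_inputs(my_estimator, my_valid_input_fn):
    """Illustrative only; `my_estimator` and `my_valid_input_fn` are
    user-supplied stand-ins."""
    cls_dict = {'valid0': my_estimator}
    validation_params = {
        'valid0': {
            'data_params': {'func': my_valid_input_fn},
            'keys_to_save': None,  # None means all predict keys are saved
            'hooks': None,         # optional SessionRunHooks
        },
    }
    return cls_dict, validation_params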
def train_from_params(
        save_params,
        model_params,
        train_params,
        loss_params=None,
        learning_rate_params=None,
        optimizer_params=None,
        validation_params=None,
        load_params=None,
        log_device_placement=DEFAULT_PARAMS['log_device_placement'],  # advanced
        dont_run=DEFAULT_PARAMS['dont_run'],  # advanced
        skip_check=DEFAULT_PARAMS['skip_check'],  # advanced
        use_estimator=False):
    """Main training interface function.

    Args:
        save_params (dict): Describes the parameters used to construct the
            save database and to control saving. These include:

            - host (str)
                Hostname where the database connection lives
            - port (int)
                Port where the database connection lives
            - dbname (str)
                Name of the database for storage
            - collname (str)
                Name of the collection for storage
            - exp_id (str)
                Experiment id descriptor

            NOTE: the variables host/port/dbname/collname/exp_id control the
            location of the saved data for the run, in order of increasing
            specificity. When choosing these, note that:

            - If a given host/port/dbname/collname/exp_id already has saved
              checkpoints, then any new call to start training with these
              same location variables will start to train from the most
              recent saved checkpoint. If you mistakenly try to start
              training a new model with different variable names or
              structure from that existing checkpoint, an error will be
              raised, as the model will be incompatible with the saved
              variables.
            - When choosing which dbname, collname, and exp_id to use, keep
              in mind that MongoDB queries only operate over a single
              collection. So if you want to analyze results from a bunch of
              experiments together using MongoDB queries, you should put them
              all in the same collection, but with different exp_ids. If, on
              the other hand, you never expect to analyze data from two
              experiments together, you can put them in different collections
              or different databases. Choosing between putting two
              experiments in two collections in the same database or in two
              totally different databases will depend on how you want to
              organize your results and is really a matter of preference.

            - do_save (bool, default: True)
                Whether to save to the database
            - save_initial_filters (bool, default: True)
                Whether to save the initial model filters at step = 0
            - save_metrics_freq (int, default: 5)
                How often to store train results in the database
            - save_valid_freq (int, default: 3000)
                How often to calculate and store validation results in the
                database
            - save_filters_freq (int, default: 30000)
                How often to save filter values to the database
            - cache_filters_freq (int, default: 3000)
                How often to cache filter values locally and save them to the
                ___RECENT database
            - cache_max_num (int, default: 6)
                Maximum number of cached filters to keep in the ___RECENT
                database
            - cache_dir (str, default: None)
                Path where caches will be saved locally. If None, defaults to
                ~/.tfutils/<host:port>/<dbname>/<collname>/<exp_id>.

        model_params (dict): Contains the function that produces the model
            and the arguments to that function.

            - model_params['func']
                The function producing the model. Its signature is:

                Args:
                - ``inputs``: data object
                - ``train`` (boolean): whether in training or testing mode
                - ``seed`` (int): seed for use in random generation

                Returns:
                - ``outputs`` (tf.Operations): train output tensorflow nodes
                - additional configurations you want to store in the database

            - The remaining items in model_params are a dictionary of
              arguments passed to func.

        train_params (dict): Contains params for data sources and targets in
            training.
            - train_params['data_params']
                This contains params for the data:

                - ``train_params['data_params']['func']`` is the function
                  that constructs the data. Its signature is:

                  Args:
                  - ``batch_size``: batch size for the input data

                  Returns:
                  - ``inputs``: a dictionary of tensors that will be sent to
                    the model function

                - ``train_params['data_params']['batch_size']``: batch size
                  of the data, will be sent to func
                - The remainder of ``train_params['data_params']`` are kwargs
                  passed to func.

            - train_params['targets'] (optional)
                Contains params for additional train targets:

                - ``train_params['targets']['func']`` is a function that
                  produces tensorflow nodes as training targets. Its
                  signature is:

                  Args:
                  - ``inputs``: the returned values of
                    ``train_params['data_params']['func']``
                  - ``output``: the first returned value of
                    ``train_params['model_params']['func']``

                  Returns:
                  a dictionary of tensors that will be computed and stored
                  in the database

                - The remainder of ``train_params['targets']`` are arguments
                  to func.

            - train_params['validate_first'] (optional, bool, default: True)
                Controls whether to validate before training.
            - train_params['thres_loss'] (optional, float, default: 100)
                If the loss exceeds this value during training, a HiLossError
                is thrown.
            - train_params['num_steps'] (int or None, default: None)
                How many total steps of the optimization are run. If None,
                training runs until the process is cancelled.

        loss_params (dict): Parameters for the helper.get_loss_base function
            used to build the loss.

            - loss_params['pred_targets'] (a string or a list of strings):
              the names of the input nodes that will be sent to the loss
              function
            - loss_params['loss_func']: the function used to calculate the
              loss. Must be provided.
            - loss_params['loss_func_kwargs'] (dict): keyword parameters sent
              to ``loss_params['loss_func']``. Default is {}.
            - loss_params['agg_func']: the aggregation function. Default is
              None.
            - loss_params['agg_func_kwargs']: keyword parameters sent to
              ``loss_params['agg_func']``. Default is {}.
            - loss_params['loss_per_case_func'] (Deprecated): deprecated,
              the same as ``loss_params['loss_func']``.
            - loss_params['targets'] (Deprecated): deprecated, the same as
              ``loss_params['pred_targets']``.

        learning_rate_params (dict): Parameters for specifying the learning
            rate.

            - learning_rate_params['func']: the function producing a
              tensorflow node acting as the learning rate. This function
              must accept the argument ``global_step``.
            - The remainder of learning_rate_params are arguments to func.

        optimizer_params (dict): Parameters for creating the optimizer.

            - optimizer_params['optimizer']: a class producing an optimizer
              object, which should have the methods ``compute_gradients``
              and ``apply_gradients``. The signatures of these two methods
              are similar to those of the basic tensorflow optimizer
              classes. Must accept:
              - "learning_rate" -- the result of the learning_rate func call
            - The remainder of optimizer_params (aside from "optimizer") are
              arguments to the optimizer class.
            - optimizer_params['func'] (Deprecated): deprecated, the same as
              ``optimizer_params['optimizer']``.

        validation_params (dict): Dictionary of validation sources. The
            structure of this dictionary is:

            {
                <validation_target_name_1>: {
                    data: {
                        'func': (callable) data source function for this
                                validation,
                        <kwarg1>: <value1> for 'func',
                        ...
                    },
                    targets: {
                        'func': (callable) returning targets,
                        <kwarg1>: <value1> for 'func',
                        ...
                    },
                    num_steps (int): number of batches of the validation
                        source to compute,
                    agg_func (optional, callable): how to aggregate
                        validation results across batches after computation.
                        Its signature is:
                        - one input argument: the list of validation batch
                          results
                        - one output: the aggregated version
                        Default is ``utils.identity_func``.
                    online_agg_func (optional, callable): how to aggregate
                        validation results on a per-batch basis. Its
                        signature is:
                        - three input arguments: (current aggregate, new
                          result, step)
                        - one output: the new aggregated result
                        On the first step, the current aggregate passed in
                        is None. The final result is passed to agg_func.
                        Default is ``utils.append_and_return``.
                },
                <validation_target_name_2>: ...
            }

            For each validation_target_name key, the targets are computed
            and then added to the output dictionary every so often -- unlike
            train_targets, which are computed on each time step, these are
            computed on a basis controlled by the save_valid_freq specified
            in save_params.

        load_params (dict): Similar to save_params; used if you want loading
            to happen from a different location than where saving occurs.
            Parameters include:

            - host (str)
                Hostname where the database connection lives
            - port (int)
                Port where the database connection lives
            - dbname (str)
                Name of the database for storage
            - collname (str)
                Name of the collection for storage
            - exp_id (str)
                Experiment id descriptor
            - do_restore (bool, default: True)
                Whether to restore from a saved model
            - query (dict)
                MongoDB query describing how to load from the loading
                database
            - from_ckpt (string)
                Path to load from a TensorFlow checkpoint (instead of from
                the db)
            - to_restore (list of strings, or a regex/callable which returns
              strings)
                Specifies which variables should be loaded from the
                checkpoint. Any variables not specified here will be
                reinitialized.
            - load_param_dict (dict)
                A dictionary whose keys are the names of the variables that
                are to be loaded from the checkpoint, and whose values are
                the names of the model variables that you want to restore
                with the value of the corresponding checkpoint variable.

        log_device_placement (bool, default: False): Advanced parameter.
            Whether to log device placement in the tensorflow session.

        dont_run (bool, default: False): Advanced parameter. Whether to
            return everything without actually training.

        skip_check (bool, default: False): Advanced parameter. Whether to
            skip the github check; this can be useful when working in a
            detached head.

    """
    # Use a TPU only if a tpu_name has been specified and this is not
    # multi-model mode.
    if isinstance(model_params, list):
        # Multi-model mode.
        use_tpu = (model_params[0].get('tpu_name', None) is not None)
        assert (use_tpu is False)
    else:
        use_tpu = (model_params.get('tpu_name', None) is not None)
    if use_tpu:
        log.info('Using tpu: %s' % model_params['tpu_name'])

    params, train_args = parse_params(
        'train',
        model_params,
        dont_run=dont_run,
        skip_check=skip_check,
        load_params=load_params,
        loss_params=loss_params,
        save_params=save_params,
        train_params=train_params,
        optimizer_params=optimizer_params,
        validation_params=validation_params,
        learning_rate_params=learning_rate_params,
        log_device_placement=log_device_placement,
        use_tpu=use_tpu or use_estimator)

    if use_estimator or use_tpu:
        return tpu_train_from_params(params, train_args, use_tpu=use_tpu)
    else:
        with tf.Graph().as_default(), tf.device(DEFAULT_HOST):
            # For convenience, use a list of dicts instead of a dict of
            # lists.
            _params = [{key: value[i] for (key, value) in params.items()}
                       for i in range(len(params['model_params']))]
            _trargs = [{key: value[i] for (key, value) in train_args.items()}
                       for i in range(len(params['model_params']))]

            # Use a single data provider for all models.
            data_params = _params[0]['train_params']['data_params']
            _params[0]['train_params']['data_params'], inputs \
                = get_data(**data_params)

            # Build a graph for each distinct model.
            var_manager_list = []
            for param, trarg in zip(_params, _trargs):
                _, _, param, trarg, var_manager \
                    = get_model(inputs,
                                param['model_params'],
                                param=param,
                                trarg=trarg)

                trarg['validation_targets'], _ = \
                    get_valid_targets_dict(
                        var_manager=var_manager,
                        **param)
                var_manager_list.append(var_manager)

            # Create the session.
            gpu_options = tf.GPUOptions(allow_growth=True)
            sess = tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True,
                gpu_options=gpu_options,
                log_device_placement=log_device_placement,
            ))

            # Initialize variables here.
            init_op_global = tf.global_variables_initializer()
            sess.run(init_op_global)
            init_op_local = tf.local_variables_initializer()
            sess.run(init_op_local)
            log.info('Initialized from scratch first')

            # Build a database interface for each model. This interface
            # class handles record saving, model saving, and model restoring.
            for param, trarg, var_manager in zip(
                    _params, _trargs, var_manager_list):
                trarg['dbinterface'] = DBInterface(
                    sess=sess,
                    params=param,
                    var_manager=var_manager,
                    global_step=trarg['global_step'],
                    save_params=param['save_params'],
                    load_params=param['load_params'])
                # The model will be restored from the saved database here.
                trarg['dbinterface'].initialize()

            # Convert back to a dictionary of lists.
            params = {
                key: [param[key] for param in _params]
                for key in _params[0].keys()
            }
            train_args = {
                key: [trarg[key] for trarg in _trargs]
                for key in _trargs[0].keys()
            }

            if dont_run:
                return train_args

            return train(sess, **train_args)
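# A minimal sketch of a train_from_params call under the parameter structure
# documented above. `my_model_fn` and `my_train_data_fn` are hypothetical
# user-supplied functions (with the signatures described in the docstring),
# and the database location names are placeholders for a local MongoDB on
# the default port; this is illustrative, not a definitive configuration.
# The learning_rate_params kwargs mirror those used in the test suite below.
def _example_train_from_params(my_model_fn, my_train_data_fn):
    return train_from_params(
        save_params={'host': 'localhost',
                     'port': 27017,
                     'dbname': 'my_db',       # hypothetical location names
                     'collname': 'my_coll',
                     'exp_id': 'my_exp_0',
                     'save_valid_freq': 3000,
                     'save_filters_freq': 30000},
        model_params={'func': my_model_fn},
        train_params={
            'data_params': {'func': my_train_data_fn, 'batch_size': 256},
            'num_steps': 100000},
        loss_params={
            'pred_targets': ['labels'],
            'loss_func': tf.nn.sparse_softmax_cross_entropy_with_logits,
            'agg_func': tf.reduce_mean},
        learning_rate_params={'learning_rate': 0.05,
                              'decay_steps': 10000 // 256,
                              'decay_rate': 0.95,
                              'staircase': True},
        optimizer_params={'optimizer': tf.train.MomentumOptimizer,
                          'momentum': 0.9})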
def test_from_params(load_params,
                     model_params,
                     validation_params,
                     log_device_placement=False,
                     save_params=None,
                     dont_run=False,
                     skip_check=False,
                     use_estimator=False):
    """Main testing interface function.

    Same as train_from_params, but just performs testing without training.
    For documentation, see the argument descriptions in train_from_params.

    """
    # Use a TPU only if a tpu_name has been specified and this is not
    # multi-model mode.
    if isinstance(model_params, list):
        # Multi-model mode.
        use_tpu = (model_params[0].get('tpu_name', None) is not None)
        assert (use_tpu is False)
    else:
        use_tpu = (model_params.get('tpu_name', None) is not None)
    if use_tpu:
        log.info('Using tpu: %s' % model_params['tpu_name'])

    params, test_args = parse_params('test',
                                     model_params,
                                     dont_run=dont_run,
                                     skip_check=skip_check,
                                     save_params=save_params,
                                     load_params=load_params,
                                     validation_params=validation_params,
                                     log_device_placement=log_device_placement,
                                     use_tpu=use_tpu)

    # No need to create a session with the estimator interface.
    if use_estimator or use_tpu:
        return tpu_test_from_params(params, test_args, use_tpu=use_tpu)
    else:
        with tf.Graph().as_default(), tf.device(DEFAULT_HOST):
            # Create the session.
            gpu_options = tf.GPUOptions(allow_growth=True)
            sess = tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True,
                gpu_options=gpu_options,
                log_device_placement=log_device_placement,
            ))

            init_op_global = tf.global_variables_initializer()
            sess.run(init_op_global)
            init_op_local = tf.local_variables_initializer()
            sess.run(init_op_local)
            log.info('Initialized from scratch first')

            # For convenience, use a list of dicts instead of a dict of
            # lists.
            _params = [{key: value[i] for (key, value) in params.items()}
                       for i in range(len(params['model_params']))]
            _ttargs = [{key: value[i] for (key, value) in test_args.items()}
                       for i in range(len(params['model_params']))]

            # Build a graph for each distinct model.
            for param, ttarg in zip(_params, _ttargs):
                print(param['load_params'])
                from_ckpt = param['load_params'].get('from_ckpt')
                use_ckpt = (from_ckpt is not None)

                if 'cache_dir' not in load_params:
                    temp_cache_dir = save_params.get('cache_dir', None)
                    load_params['cache_dir'] = temp_cache_dir
                    log.info('cache_dir not found in load_params, '
                             'using cache_dir ({}) from save_params'.format(
                                 temp_cache_dir))

                ttarg['dbinterface'] = DBInterface(
                    var_manager=None,
                    params=param,
                    load_params=param['load_params'])
                if not use_ckpt:
                    ttarg['dbinterface'].load_rec()
                    ld = ttarg['dbinterface'].load_data
                    assert ld is not None, "No load data found for query, aborting"
                    ld = ld[0]
                    # TODO: have an option to reconstitute model_params
                    # entirely from the saved object ("revivification").
                    param['model_params']['seed'] = \
                        ld['params']['model_params']['seed']
                    cfg_final = ld['params']['model_params']['cfg_final']
                else:
                    cfg_final = param['model_params'].get('cfg_final', {})

                ttarg['validation_targets'], var_manager \
                    = get_valid_targets_dict(
                        loss_params=None,
                        cfg_final=cfg_final,
                        **param)

                param['load_params']['do_restore'] = True
                param['model_params']['cfg_final'] = cfg_final

                # Build the database interface class, loading the model.
                ttarg['dbinterface'] = DBInterface(
                    sess=sess,
                    params=param,
                    var_manager=var_manager,
                    load_params=param['load_params'],
                    save_params=param['save_params'])
                ttarg['dbinterface'].initialize()

                ttarg['save_intermediate_freq'] \
                    = param['save_params'].get('save_intermediate_freq')

            # Convert back to a dictionary of lists.
            params = {
                key: [param[key] for param in _params]
                for key in _params[0].keys()
            }
            test_args = {
                key: [ttarg[key] for ttarg in _ttargs]
                for key in _ttargs[0].keys()
            }

            if dont_run:
                return test_args

            res = test(sess, **test_args)
            sess.close()
            return res
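# A minimal sketch of the corresponding test_from_params call, reusing the
# hypothetical location names from the training sketch above: load_params
# points at the experiment to restore, and validation_params names what to
# evaluate. `my_model_fn` and `my_valid_data_fn` are stand-ins for
# user-supplied functions; this is illustrative, not a definitive
# configuration.
def _example_test_from_params(my_model_fn, my_valid_data_fn):
    return test_from_params(
        load_params={'host': 'localhost',
                     'port': 27017,
                     'dbname': 'my_db',
                     'collname': 'my_coll',
                     'exp_id': 'my_exp_0',
                     'do_restore': True},
        model_params={'func': my_model_fn},
        validation_params={
            'valid0': {
                'data_params': {'func': my_valid_data_fn,
                                'batch_size': 256},
                'num_steps': 10,
            },
        },
        save_params={'host': 'localhost',
                     'port': 27017,
                     'dbname': 'my_db',
                     'collname': 'my_coll',
                     'exp_id': 'my_exp_0_eval'})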
def train_estimator(train_cls, eval_cls, param, trarg):
    if eval_cls is None:
        eval_cls = train_cls

    model_dir = param['save_params'].get('cache_dir', '')
    train_steps = param['train_params']['num_steps']

    # Only single targets are supported during eval mode.
    need_val = len(param['validation_params'].keys()) > 0
    steps_per_eval = param['save_params'].get('save_valid_freq')
    if need_val:
        valid_k = list(param['validation_params'].keys())[0]
        validation_data_params = param['validation_params'][valid_k][
            'data_params']
        valid_steps = param['validation_params'][valid_k]['num_steps']
        valid_fn = validation_data_params['func']
        if steps_per_eval is None:
            steps_per_eval = param['save_params']['save_filters_freq']
        else:
            save_filters_freq = param['save_params'].get('save_filters_freq')
            if save_filters_freq is not None:
                # These need to be the same right now because the estimator
                # loads from the last checkpoint after validating.
                assert (steps_per_eval == save_filters_freq)
            else:
                param['save_params']['save_filters_freq'] = steps_per_eval

    train_fn = param['train_params']['data_params']['func']
    model_params = param['model_params']
    iterations_per_loop = model_params.get('iterations_per_loop',
                                           DEFAULT_ITERATIONS_PER_LOOP)
    if (steps_per_eval is None) or (steps_per_eval < iterations_per_loop):
        # Eval steps cannot be less than TPU iterations.
        log.info('Setting save_valid_freq ({}) to be the same as '
                 'iterations_per_loop ({}).'.format(steps_per_eval,
                                                    iterations_per_loop))
        steps_per_eval = iterations_per_loop

    train_hooks = param['train_params'].get('hooks')
    if need_val:
        valid_hooks = param['validation_params'][valid_k].get('hooks')
    else:
        valid_hooks = None

    current_step = estimator._load_global_step_from_checkpoint_dir(model_dir)

    # Initialize the db here. (Currently there is no support for loading and
    # saving to different places; init may need to be modified so that
    # load_params can load from a different dir. The estimator interface is
    # limited when loading and saving to different paths and may need a new
    # config.)
    trarg['dbinterface'] = DBInterface(sess=None,
                                       params=param,
                                       global_step=current_step,
                                       save_params=param['save_params'],
                                       load_params=param['load_params'],
                                       cache_dir=model_dir)

    log.info('Training beginning ...')
    log.info('Training for %d steps. Current '
             'step %d' % (train_steps, current_step))
    trarg['dbinterface'].start_time_step = time.time()

    tpu_validate_first = param['train_params'].get('tpu_validate_first',
                                                   False)

    def do_tpu_validation():
        log.info('Starting to evaluate.')
        eval_results = eval_cls.evaluate(input_fn=valid_fn,
                                         hooks=valid_hooks,
                                         steps=valid_steps)
        log.info('Saving eval results to database.')
        trarg['dbinterface'].save(valid_res={valid_k: eval_results},
                                  validation_only=True)
        log.info('Done saving eval results to database.')
        return eval_results

    if tpu_validate_first:
        eval_results = do_tpu_validation()

    while current_step < train_steps:
        next_eval = min(current_step + steps_per_eval, train_steps)
        log.info('Training until step %d' % next_eval)
        train_cls.train(input_fn=train_fn,
                        max_steps=next_eval,
                        hooks=train_hooks)
        current_step = next_eval

        if need_val:
            eval_results = do_tpu_validation()

    # Sync with hosts.
    res = []
    trarg['dbinterface'].sync_with_host()
    res.append(trarg['dbinterface'].outrecs)

    # Return the final eval results for convenience.
    return eval_results, res
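# An illustration of the alternating train/validate schedule produced by the
# loop in train_estimator above: with train_steps=1000 and steps_per_eval=300
# (hypothetical numbers), training pauses for validation at the checkpoint
# boundaries 300, 600, 900, and 1000.
def _eval_boundaries(train_steps, steps_per_eval, current_step=0):
    """Return the steps at which train_estimator stops to validate."""
    boundaries = []
    while current_step < train_steps:
        # Same boundary arithmetic as the training loop above.
        current_step = min(current_step + steps_per_eval, train_steps)
        boundaries.append(current_step)
    return boundaries

assert _eval_boundaries(1000, 300) == [300, 600, 900, 1000]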
class TestDBInterface(unittest.TestCase):

    PORT = 29101
    HOST = 'localhost'
    EXP_ID = 'TEST_EXP_ID'
    DATABASE_NAME = 'TFUTILS_TESTDB'
    COLLECTION_NAME = 'TFUTILS_TESTCOL'
    CACHE_DIR = 'TFUTILS_TEST_CACHE_DIR'

    @classmethod
    def setUpClass(cls):
        """Set up class once before any test methods are run."""
        cls.setup_log()
        cls.setup_conn()
        cls.setup_cache()
        cls.setup_params()

    @classmethod
    def tearDownClass(cls):
        """Tear down class after all test methods have run."""
        cls.remove_directory(cls.CACHE_DIR)
        cls.remove_database(cls.DATABASE_NAME)

        # Close the primary MongoDB connection.
        cls.conn.close()

    def setUp(self):
        """Set up class before _each_ test method is executed.

        Creates a tensorflow session and instantiates a dbinterface.

        """
        self.setup_model()
        self.sess = tf.Session(
            config=tf.ConfigProto(
                allow_soft_placement=True,
                gpu_options=tf.GPUOptions(allow_growth=True),
                log_device_placement=self.params['log_device_placement'],
            ))
        # TODO: Determine whether this should be called here or
        # in dbinterface.initialize()
        self.sess.run(tf.global_variables_initializer())
        self.dbinterface = DBInterface(
            sess=self.sess,
            params=self.params,
            cache_dir=self.CACHE_DIR,
            save_params=self.save_params,
            load_params=self.load_params)
        self.step = 0

    def tearDown(self):
        """Tear down after _each_ test method is executed."""
        self.sess.close()

    @unittest.skip("skipping")
    def test_init(self):
        # TODO: Test all permutations of __init__ params.
        pass

    @unittest.skip("skipping")
    def test_load_rec(self):
        pass

    @unittest.skip("skipping")
    def test_initialize(self):
        pass

    def test_get_restore_vars(self):
        # First, train the model and save a checkpoint.
        self.train_model()  # weights_name='Weights'
        saved_path = self.save_test_checkpoint()

        # Create a new model with different variable names.
        self.setup_model(weights_name='Filters')

        # Reset var_list in DBInterface.
        self.dbinterface.var_list = {
            var.op.name: var for var in tf.global_variables()}

        # Restore first checkpoint vars.
        mapping = {'Weights': 'Filters'}
        self.dbinterface.load_param_dict = mapping
        restore_vars = self.dbinterface.get_restore_vars(saved_path)

        self.log.info('restore_vars:')
        for name, var in restore_vars.items():
            self.log.info('(name, var.name): ({}, {})'.format(name, var.name))
            self.assertEqual(var.op.name, mapping[name])

    def test_filter_var_list(self):
        var_list = {var.op.name: var for var in tf.global_variables()}

        # Test None.
        self.dbinterface.to_restore = None
        filtered_var_list = self.dbinterface.filter_var_list(var_list)
        self.assertEqual(filtered_var_list, var_list)

        # Test a list of strings.
        self.dbinterface.to_restore = ['Weights']
        filtered_var_list = self.dbinterface.filter_var_list(var_list)
        for name, var in filtered_var_list.items():
            self.assertIn(name, ['Weights'])
            self.assertNotIn(name, ['Bias', 'global_step'])

        # Test a regex.
        self.dbinterface.to_restore = re.compile(r'Bias')
        filtered_var_list = self.dbinterface.filter_var_list(var_list)
        for name, var in filtered_var_list.items():
            self.assertIn(name, ['Bias'])
            self.assertNotIn(name, ['Weights', 'global_step'])

        # Test an invalid type (should raise a TypeError).
        self.dbinterface.to_restore = {'invalid_key': 'invalid_value'}
        with self.assertRaises(TypeError):
            filtered_var_list = self.dbinterface.filter_var_list(var_list)

    @unittest.skip("skipping")
    def test_tf_saver(self):
        pass

    @unittest.skip("skipping")
    def test_load_from_db(self):
        pass

    @unittest.skip("skipping")
    def test_save(self):
        self.dbinterface.initialize()
        self.dbinterface.start_time_step = time.time()
        train_res = self.train_model(num_steps=100)
        self.dbinterface.save(train_res=train_res, step=self.step)

    @unittest.skip("skipping")
    def test_sync_with_host(self):
        pass

    @unittest.skip("skipping")
    def test_save_thread(self):
        pass

    @unittest.skip("skipping")
    def test_initialize_from_ckpt(self):
        save_path = self.save_test_checkpoint()
        self.load_test_checkpoint(save_path)

    def train_model(self, num_steps=100):
        x_train = [1, 2, 3, 4]
        y_train = [0, -1, -2, -3]
        x = tf.get_default_graph().get_tensor_by_name('x:0')
        y = tf.get_default_graph().get_tensor_by_name('y:0')

        feed_dict = {x: x_train, y: y_train}
        pre_global_step = self.sess.run(self.global_step)
        for step in range(num_steps):
            train_res = self.sess.run(self.train_targets, feed_dict=feed_dict)
            self.log.info('Step: {}, loss: {}'.format(step, train_res['loss']))

        post_global_step = self.sess.run(self.global_step)
        self.assertEqual(pre_global_step + num_steps, post_global_step)
        self.step += num_steps
        return train_res

    def save_test_checkpoint(self):
        self.log.info('Saving checkpoint to {}'.format(self.save_path))
        saved_checkpoint_path = self.dbinterface.tf_saver.save(
            self.sess,
            save_path=self.save_path,
            global_step=self.global_step,
            write_meta_graph=False)
        self.log.info('Checkpoint saved to {}'.format(saved_checkpoint_path))
        return saved_checkpoint_path

    def load_test_checkpoint(self, save_path):
        reader = tf.train.NewCheckpointReader(save_path)
        saved_shapes = reader.get_variable_to_shape_map()
        self.log.info('Saved Vars:\n' + str(saved_shapes.keys()))
        for name in saved_shapes.keys():
            self.log.info(
                'Name: {}, Tensor: {}'.format(name, reader.get_tensor(name)))

    def setup_model(self, weights_name='Weights', bias_name='Bias'):
        """Set up a simple tensorflow model."""
        tf.reset_default_graph()
        self.global_step = tf.get_variable(
            'global_step', [],
            dtype=tf.int64,
            trainable=False,
            initializer=tf.constant_initializer(0))

        # Model parameters and placeholders.
        x = tf.placeholder(tf.float32, name='x')
        y = tf.placeholder(tf.float32, name='y')
        W = tf.get_variable(weights_name, [1], dtype=tf.float32)
        b = tf.get_variable(bias_name, [1], dtype=tf.float32)

        # Model output, loss and optimizer.
        linear_model = W * x + b
        loss = tf.reduce_sum(tf.square(linear_model - y))
        optimizer_base = tf.train.GradientDescentOptimizer(0.01)

        # Model train op.
        optimizer = optimizer_base.minimize(
            loss, global_step=self.global_step)

        # Train targets.
        self.train_targets = {'loss': loss, 'optimizer': optimizer}

    @classmethod
    def setup_log(cls):
        cls.log = logging.getLogger(':'.join([__name__, cls.__name__]))
        cls.log.setLevel('DEBUG')

    @classmethod
    def setup_conn(cls):
        cls.conn = pymongo.MongoClient(host=cls.HOST, port=cls.PORT)

    @classmethod
    def setup_cache(cls):
        cls.cache_dir = os.path.join(cls.CACHE_DIR,
                                     '%s:%d' % (cls.HOST, cls.PORT),
                                     cls.DATABASE_NAME,
                                     cls.COLLECTION_NAME,
                                     cls.EXP_ID)
        cls.makedirs(cls.cache_dir)
        cls.save_path = os.path.join(cls.cache_dir, 'checkpoint')

    @classmethod
    def setup_params(cls):
        cls.model_params = {'func': model.mnist_tfutils_new,
                            'devices': ['/gpu:0', '/gpu:1'],
                            'prefix': 'model_0'}

        cls.save_params = {
            'host': cls.HOST,
            'port': cls.PORT,
            'dbname': cls.DATABASE_NAME,
            'collname': cls.COLLECTION_NAME,
            'exp_id': cls.EXP_ID,
            'save_valid_freq': 20,
            'save_filters_freq': 200,
            'cache_filters_freq': 100}

        cls.train_params = {
            'data_params': {'func': data.build_data,
                            'batch_size': 100,
                            'group': 'train',
                            'directory': TFUTILS_HOME},
            'num_steps': 500}

        cls.loss_params = {
            'targets': ['labels'],
            'agg_func': tf.reduce_mean,
            'loss_per_case_func': tf.nn.sparse_softmax_cross_entropy_with_logits}

        cls.load_params = {'do_restore': True}

        cls.optimizer_params = {'func': optimizer.ClipOptimizer,
                                'optimizer_class': tf.train.MomentumOptimizer,
                                'clip': True,
                                'momentum': 0.9}

        cls.learning_rate_params = {'learning_rate': 0.05,
                                    'decay_steps': 10000 // 256,
                                    'decay_rate': 0.95,
                                    'staircase': True}

        cls.params = {
            'dont_run': False,
            'skip_check': True,
            'model_params': cls.model_params,
            'train_params': cls.train_params,
            'validation_params': {},
            'log_device_placement': False,
            'save_params': cls.save_params,
            'load_params': cls.load_params,
            'loss_params': cls.loss_params,
            'optimizer_params': cls.optimizer_params,
            'learning_rate_params': cls.learning_rate_params}

    @classmethod
    def remove_checkpoint(cls, checkpoint):
        """Remove a tf.train.Saver checkpoint."""
        cls.log.info('Removing checkpoint: {}'.format(checkpoint))
        # TODO: remove ckpt
        raise NotImplementedError

    @classmethod
    def remove_directory(cls, directory):
        """Remove a directory."""
        cls.log.info('Removing directory: {}'.format(directory))
        shutil.rmtree(directory)
        cls.log.info('Directory successfully removed.')

    @classmethod
    def remove_database(cls, database_name):
        """Remove a MongoDB database."""
        cls.log.info('Removing database: {}'.format(database_name))
        cls.conn.drop_database(database_name)
        cls.log.info('Database successfully removed.')

    @classmethod
    def remove_collection(cls, collection_name):
        """Remove a MongoDB collection."""
        cls.log.debug('Removing collection: {}'.format(collection_name))
        cls.conn[cls.DATABASE_NAME][collection_name].drop()
        cls.log.info('Collection successfully removed.')

    @classmethod
    def remove_document(cls, document):
        raise NotImplementedError

    @staticmethod
    def makedirs(dir):
        try:
            os.makedirs(dir)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise
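# A minimal sketch of running the suite above programmatically, using the
# standard unittest loader/runner. The suite assumes a MongoDB instance is
# listening on HOST:PORT (localhost:29101); the helper name is hypothetical.
def _run_dbinterface_tests():
    """Run the TestDBInterface suite (requires MongoDB on localhost:29101)."""
    suite = unittest.TestLoader().loadTestsFromTestCase(TestDBInterface)
    return unittest.TextTestRunner(verbosity=2).run(suite)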
def test_from_params(load_params,
                     model_params,
                     validation_params,
                     log_device_placement=False,
                     save_params=None,
                     dont_run=False,
                     skip_check=False,
                     ):
    """Main testing interface function.

    Same as train_from_params, but just performs testing without training.
    For documentation, see the argument descriptions in train_from_params.

    """
    params, test_args = parse_params(
        'test',
        model_params,
        dont_run=dont_run,
        skip_check=skip_check,
        save_params=save_params,
        load_params=load_params,
        validation_params=validation_params,
        log_device_placement=log_device_placement,
    )

    with tf.Graph().as_default(), tf.device(DEFAULT_HOST):
        # Create the session.
        sess = tf.Session(
            config=tf.ConfigProto(
                allow_soft_placement=True,
                log_device_placement=log_device_placement,
            ))

        init_op_global = tf.global_variables_initializer()
        sess.run(init_op_global)
        init_op_local = tf.local_variables_initializer()
        sess.run(init_op_local)
        log.info('Initialized from scratch first')

        # For convenience, use a list of dicts instead of a dict of lists.
        _params = [{key: value[i] for (key, value) in params.items()}
                   for i in range(len(params['model_params']))]
        _ttargs = [{key: value[i] for (key, value) in test_args.items()}
                   for i in range(len(params['model_params']))]

        # Build a graph for each distinct model.
        for param, ttarg in zip(_params, _ttargs):

            if 'cache_dir' not in load_params:
                temp_cache_dir = save_params.get('cache_dir', None)
                load_params['cache_dir'] = temp_cache_dir
                log.info('cache_dir not found in load_params, '
                         'using cache_dir ({}) from save_params'.format(
                             temp_cache_dir))

            ttarg['dbinterface'] = DBInterface(
                params=param,
                load_params=param['load_params'])
            ttarg['dbinterface'].load_rec()
            ld = ttarg['dbinterface'].load_data
            assert ld is not None, "No load data found for query, aborting"
            ld = ld[0]
            # TODO: have an option to reconstitute model_params entirely
            # from the saved object ("revivification").
            param['model_params']['seed'] = \
                ld['params']['model_params']['seed']
            cfg_final = ld['params']['model_params']['cfg_final']

            ttarg['validation_targets'] = \
                get_valid_targets_dict(
                    loss_params=None,
                    cfg_final=cfg_final,
                    **param)

            # tf.get_variable_scope().reuse_variables()

            param['load_params']['do_restore'] = True
            param['model_params']['cfg_final'] = cfg_final

            prefix = param['model_params']['prefix'] + '/'
            all_vars = variables._all_saveable_objects()
            var_list = strip_prefix(prefix, all_vars)

            ttarg['dbinterface'] = DBInterface(
                sess=sess,
                params=param,
                var_list=var_list,
                load_params=param['load_params'],
                save_params=param['save_params'])
            ttarg['dbinterface'].initialize(no_scratch=True)

            ttarg['save_intermediate_freq'] = \
                param['save_params'].get('save_intermediate_freq')

        # Convert back to a dictionary of lists.
        params = {key: [param[key] for param in _params]
                  for key in _params[0].keys()}
        test_args = {key: [ttarg[key] for ttarg in _ttargs]
                     for key in _ttargs[0].keys()}

        if dont_run:
            return test_args

        res = test(sess, **test_args)
        sess.close()
        return res