def load_and_run(model_config, root_dir):
    """Demonstrate restoring a model from the dataset and running it once.

    The model graph is built in inference mode (``is_training=False``), all
    global variables are initialized, the checkpointed parameters are loaded
    on top of them, and the logits tensor is evaluated a single time. The
    evaluated logits are discarded; the function returns ``None``.

    Args:
      model_config: A ModelConfig object that contains relevant
        hyperparameters of a model.
      root_dir: Directory containing the dataset.
    """
    checkpoint_path = model_config.get_checkpoint_path(root_dir)
    build_model = model_config.get_model_fn()

    with tf.Session() as session:
        make_inputs = data_util.get_input(
            data=model_config.dataset,
            data_format=model_config.data_format,
        )
        # Only the first element of the input pair is fed to the model; the
        # second one is ignored.
        batch_images, _ = make_inputs()
        output_logits = build_model(batch_images, is_training=False)

        # Initialize first, then overwrite with the checkpointed values.
        session.run(tf.global_variables_initializer())
        model_config.load_parameters(checkpoint_path, session)

        session.run(output_logits)
def evaluate_model(model_config, root_dir, batch_size=500, num_batches=20):
    """Evaluate a checkpointed model and return its classification accuracy.

    Builds the model in inference mode, restores its parameters from the
    checkpoint found under ``root_dir``, runs ``num_batches`` evaluation
    batches and compares the arg-max of the logits with the arg-max of the
    labels (labels are presumably one-hot encoded -- confirm against
    ``data_util.get_input``).

    Args:
      model_config: A ModelConfig object that contains relevant
        hyperparameters of a model.
      root_dir: Directory containing the dataset.
      batch_size: Value forwarded as ``batch_size`` to
        ``data_util.get_input``. Defaults to 500.
      num_batches: Number of batches to pull from the input pipeline.
        Defaults to 20, so the defaults cover 10000 examples.

    Returns:
      The fraction of evaluated examples whose predicted class matches the
      label, as a float. Returns 0.0 when no example was evaluated.
    """
    model_path = model_config.get_checkpoint_path(root_dir)
    model_fn = model_config.get_model_fn()

    num_correct = 0
    num_seen = 0
    with tf.Session() as sess:
        input_fn = data_util.get_input(
            batch_size=batch_size,
            data=model_config.dataset,
            data_format=model_config.data_format,
            mode=tf.estimator.ModeKeys.EVAL)
        images, labels = input_fn()
        logits = model_fn(images, is_training=False)
        predictions = tf.argmax(logits, axis=-1)
        true_labels = tf.argmax(labels, axis=-1)

        # Initialize first, then overwrite with the checkpointed values.
        sess.run(tf.global_variables_initializer())
        model_config.load_parameters(model_path, sess)

        for _ in range(num_batches):
            batch_prediction, batch_label = sess.run(
                [predictions, true_labels])
            num_correct += int(
                np.sum(np.equal(batch_prediction, batch_label)))
            # Count what was actually evaluated instead of assuming a fixed
            # total, so the denominator cannot drift from the loop settings.
            num_seen += int(np.size(batch_label))

    if num_seen == 0:
        return 0.0
    # float() keeps true division even without `from __future__ import
    # division`.
    return num_correct / float(num_seen)
results, 'r') as f: dd = json.load(f) eval_loss = dd['loss'] eval_cross_entropy = dd['CrossEntropy'] eval_global_step = dd['global_step'] eval_accuracy = dd['Accuracy'] except tf.errors.NotFoundError as e: print('Failed to load model results') print(e) continue try: # Other information input_fn = data_util.get_input( data_dir, data=model_config.dataset, data_format=model_config. data_format, repeat_num=1) all_activations, samples_per_object, layer_names, layer_indices, layer_n_neurons = elu.extract_layers( input_fn, root_dir, model_config) except: print('Failed to load model') raise #except tf.errors.InvalidArgumentError: # failures += [filename] # print('Failed reading %s'%filename) # continue #except tf.errors.NotFoundError: # failures += [filename] # print('Failed reading %s'%filename) # continue