def _saved_config_file_path():
    """Locate the config file saved in the experiment directory.

    Checks for 'config.py' first, then 'config.yaml'.

    Returns:
        str: path of the first config file that exists.

    Raises:
        FileNotFoundError: if neither candidate file exists.
    """
    candidates = [
        os.path.join(environment.EXPERIMENT_DIR, name)
        for name in ('config.py', 'config.yaml')
    ]
    found = next((path for path in candidates if file_io.exists(path)), None)
    if found is not None:
        return found
    raise FileNotFoundError("Config file not found: '{}'".format("' nor '".join(candidates)))
def save_yaml(output_dir, config):
    """Write two yaml files under output_dir.

    1. 'config.yaml': the python config duplicated as yaml.
    2. 'meta.yaml': application metadata whose keys are defined by
       `PARAMS_FOR_EXPORT`.

    Args:
        output_dir: directory to write into (created if missing).
        config: the loaded config object.

    Returns:
        tuple: (path to config.yaml, path to meta.yaml)
    """
    if not file_io.exists(output_dir):
        file_io.makedirs(output_dir)
    return (
        _save_config_yaml(output_dir, config),
        _save_meta_yaml(output_dir, config),
    )
def train(config_file, experiment_id=None, recreate=False, profile_step=-1):
    """Run a training job and report the resulting checkpoint.

    Args:
        config_file: path to the python/yaml config file.
        experiment_id: identifier for this run; defaults to
            '{config basename}_{timestamp}'.
        recreate: passed through to `run`.
        profile_step: passed through to `run`.

    Returns:
        tuple: (experiment_id, checkpoint_name) where checkpoint_name is the
        basename of the latest model checkpoint.

    Raises:
        Exception: if no checkpoint file was produced by the run.
    """
    if not experiment_id:
        # Default model_name will be taken from config file: {model_name}.yml.
        model_name = os.path.splitext(os.path.basename(config_file))[0]
        experiment_id = '{}_{:%Y%m%d%H%M%S}'.format(model_name, datetime.now())

    run(config_file, experiment_id, recreate, profile_step)

    output_dir = os.environ.get('OUTPUT_DIR', 'saved')
    experiment_dir = os.path.join(output_dir, experiment_id)
    checkpoint = os.path.join(experiment_dir, 'checkpoints', 'checkpoint')
    if not file_io.exists(checkpoint):
        raise Exception(
            'Checkpoints are not created in {}'.format(experiment_dir))

    with file_io.File(checkpoint) as stream:
        # The checkpoint index file is plain "key: value" text, so SafeLoader
        # is sufficient; yaml.Loader can construct arbitrary Python objects
        # and should not be used here.
        data = yaml.load(stream, Loader=yaml.SafeLoader)
    checkpoint_name = os.path.basename(data['model_checkpoint_path'])

    return experiment_id, checkpoint_name
def evaluate(config, restore_path, output_dir):
    """Evaluate a trained model on the test (or validation) subset.

    Restores the model from `restore_path` (or the latest checkpoint found in
    `environment.CHECKPOINTS_DIR` when None), runs one pass over the dataset,
    writes TensorBoard summaries, and saves a metrics json via `save_json`.

    Args:
        config: loaded config object (DATASET_CLASS, NETWORK_CLASS, NETWORK,
            IS_DEBUG, BATCH_SIZE, TASK, DATASET are read here).
        restore_path: checkpoint prefix to restore, or None to auto-discover.
        output_dir: where to write metrics json; defaults to
            '<checkpoint parent>/evaluate' when None.

    Raises:
        Exception: if the checkpoint's .index file does not exist.
    """
    if restore_path is None:
        restore_file = executor.search_restore_filename(environment.CHECKPOINTS_DIR)
        restore_path = os.path.join(environment.CHECKPOINTS_DIR, restore_file)

    # A TF checkpoint prefix is valid iff its ".index" companion file exists.
    if not file_io.exists("{}.index".format(restore_path)):
        raise Exception("restore file {} dont exists.".format(restore_path))

    if output_dir is None:
        output_dir = os.path.join(os.path.dirname(os.path.dirname(restore_path)), "evaluate")

    logger.info(f"restore_path:{restore_path}")

    DatasetClass = config.DATASET_CLASS
    ModelClass = config.NETWORK_CLASS
    # NETWORK config keys are upper-case; the model constructors take
    # lower-case keyword arguments.
    network_kwargs = {key.lower(): val for key, val in config.NETWORK.items()}

    # Prefer the dedicated "test" subset; fall back to "validation".
    if "test" in DatasetClass.available_subsets:
        subset = "test"
    else:
        subset = "validation"

    validation_dataset = setup_dataset(config, subset, seed=0)

    # Build the inference/metrics graph in a dedicated Graph.
    graph = tf.Graph()
    with graph.as_default():
        # Object-detection networks additionally require num_max_boxes.
        if ModelClass.__module__.startswith("blueoil.networks.object_detection"):
            model = ModelClass(
                classes=validation_dataset.classes,
                num_max_boxes=validation_dataset.num_max_boxes,
                is_debug=config.IS_DEBUG,
                **network_kwargs,
            )
        else:
            model = ModelClass(
                classes=validation_dataset.classes,
                is_debug=config.IS_DEBUG,
                **network_kwargs,
            )

        is_training = tf.constant(False, name="is_training")

        images_placeholder, labels_placeholder = model.placeholders()

        output = model.inference(images_placeholder, is_training)

        metrics_ops_dict, metrics_update_op = model.metrics(output, labels_placeholder)
        model.summary(output, labels_placeholder)

        summary_op = tf.compat.v1.summary.merge_all()
        metrics_summary_op = executor.metrics_summary_op(metrics_ops_dict)

        init_op = tf.compat.v1.global_variables_initializer()
        # Metric accumulators are local variables; this op resets them.
        reset_metrics_op = tf.compat.v1.local_variables_initializer()
        saver = tf.compat.v1.train.Saver(max_to_keep=None)

    session_config = None  # tf.ConfigProto(log_device_placement=True)
    sess = tf.compat.v1.Session(graph=graph, config=session_config)
    # Initialize before restoring so metric locals start from a clean state.
    sess.run([init_op, reset_metrics_op])

    validation_writer = tf.compat.v1.summary.FileWriter(environment.TENSORBOARD_DIR + "/evaluate")

    saver.restore(sess, restore_path)

    last_step = sess.run(model.global_step)

    # init metrics values
    test_step_size = int(math.ceil(validation_dataset.num_per_epoch / config.BATCH_SIZE))
    logger.info(f"test_step_size{test_step_size}")

    for test_step in range(test_step_size):
        logger.info(f"test_step{test_step}")
        images, labels = validation_dataset.feed()
        feed_dict = {
            images_placeholder: images,
            labels_placeholder: labels,
        }

        # Summarize at only last step.
        if test_step == test_step_size - 1:
            summary, _ = sess.run([summary_op, metrics_update_op], feed_dict=feed_dict)
            validation_writer.add_summary(summary, last_step)
        else:
            sess.run([metrics_update_op], feed_dict=feed_dict)

    metrics_summary = sess.run(metrics_summary_op)
    validation_writer.add_summary(metrics_summary, last_step)

    # TFDS datasets carry their name/data_dir in config; others use the
    # dataset class name and an empty path.
    is_tfds = "TFDS_KWARGS" in config.DATASET
    dataset_name = config.DATASET.TFDS_KWARGS["name"] if is_tfds else config.DATASET_CLASS.__name__
    dataset_path = config.DATASET.TFDS_KWARGS["data_dir"] if is_tfds else ""

    metrics_dict = {
        'task_type': config.TASK.value,
        'network_name': config.NETWORK_CLASS.__name__,
        'dataset_name': dataset_name,
        'dataset_path': dataset_path,
        'last_step': int(last_step),
        # Evaluate each accumulated metric op to a plain float for json.
        'metrics': {k: float(sess.run(op)) for k, op in metrics_ops_dict.items()},
    }
    save_json(output_dir, json.dumps(metrics_dict, indent=4,), metrics_dict["last_step"])
    validation_dataset.close()
def test_exists():
    """file_io.exists reports True for a file that is present on disk."""
    with tempfile.NamedTemporaryFile() as tmp:
        assert file_io.exists(tmp.name)