def run(train_data, eval_data, model_output_dir, gpu, training_params, _config):
    """Train an estimator, optionally evaluate it, and export a saved model after each round.

    Training proceeds in rounds of ``training_params.evaluate_every_epoch``
    epochs until ``training_params.n_epochs`` is reached. After each round the
    model is evaluated (when ``eval_data`` is given) and handed to an exporter
    that writes into ``<model_output_dir>/export``.

    Args:
        train_data: Path to a directory containing ``images`` and ``labels``
            sub-directories, or path to a ``.csv`` file.
        eval_data: Same forms as ``train_data``, or ``None`` to skip evaluation.
        model_output_dir: Directory receiving checkpoints, ``config.json`` and
            the ``export`` sub-directory. It must not already exist unless
            ``_config['restore_model']`` is truthy.
        gpu: Converted with ``str()`` and used as the visible GPU device list.
        training_params: Dict passed to ``utils.TrainingParams.from_dict``.
        _config: Mapping dumped to ``config.json`` and forwarded as ``params``
            to the estimator and to the input functions.

    Raises:
        AssertionError: If ``model_output_dir`` already exists and
            ``restore_model`` is not set, or if a data directory lacks its
            ``images`` / ``labels`` sub-directory.
        TypeError: If a data path is neither a directory nor a ``.csv`` file.
    """
    # Create output directory.
    # The checks below raise AssertionError explicitly instead of using the
    # `assert` statement: `assert` is stripped under `python -O`, which would
    # let an existing model directory be silently overwritten. The exception
    # type is kept so existing callers are unaffected.
    if not os.path.isdir(model_output_dir):
        os.makedirs(model_output_dir)
    elif not _config.get('restore_model'):
        raise AssertionError(
            '{0} already exists, you cannot use it as output directory. '
            'Set "restore_model=True" to continue training, or delete dir "rm -r {0}"'.format(model_output_dir))
    # Save config
    with open(os.path.join(model_output_dir, 'config.json'), 'w') as f:
        json.dump(_config, f, indent=4, sort_keys=True)

    # Create export directory for saved models
    saved_model_dir = os.path.join(model_output_dir, 'export')
    if not os.path.isdir(saved_model_dir):
        os.makedirs(saved_model_dir)

    training_params = utils.TrainingParams.from_dict(training_params)

    session_config = tf.ConfigProto()
    session_config.gpu_options.visible_device_list = str(gpu)
    session_config.gpu_options.per_process_gpu_memory_fraction = 0.9
    estimator_config = tf.estimator.RunConfig().replace(session_config=session_config,
                                                        save_summary_steps=10,
                                                        keep_checkpoint_max=1)
    estimator = tf.estimator.Estimator(estimator_fn.model_fn, model_dir=model_output_dir,
                                       params=_config, config=estimator_config)

    def get_dirs_or_files(input_data):
        """Resolve a data path into an ``(image_input, labels_input)`` pair.

        A directory yields its ``images`` and ``labels`` sub-directories; a
        ``.csv`` file yields ``(csv_path, None)``.
        """
        if os.path.isdir(input_data):
            image_input, labels_input = os.path.join(input_data, 'images'), os.path.join(input_data, 'labels')
            # Check that both expected sub-directories exist
            for required_dir in (image_input, labels_input):
                if not os.path.isdir(required_dir):
                    raise AssertionError("{} is not a directory".format(required_dir))

        elif os.path.isfile(input_data) and input_data.endswith('.csv'):
            image_input = input_data
            labels_input = None
        else:
            raise TypeError('input_data {} is neither a directory nor a csv file'.format(input_data))
        return image_input, labels_input

    train_input, train_labels_input = get_dirs_or_files(train_data)
    # Defined unconditionally so the names always exist; only read when eval_data is set.
    eval_input, eval_labels_input = None, None
    if eval_data is not None:
        eval_input, eval_labels_input = get_dirs_or_files(eval_data)

    # Configure exporter: keep the best models when evaluation results are
    # available, otherwise simply keep the most recent exports.
    serving_input_fn = input.serving_input_filename(training_params.input_resized_size)
    if eval_data is not None:
        exporter = tf.estimator.BestExporter(serving_input_receiver_fn=serving_input_fn, exports_to_keep=2)
    else:
        exporter = tf.estimator.LatestExporter(name='SimpleExporter', serving_input_receiver_fn=serving_input_fn,
                                               exports_to_keep=5)

    n_epochs = training_params.n_epochs
    evaluate_every_epoch = training_params.evaluate_every_epoch
    for i in trange(0, n_epochs, evaluate_every_epoch, desc='Evaluated epochs'):
        # Clamp the last round so the total never exceeds n_epochs when it is
        # not a multiple of evaluate_every_epoch (e.g. 10 epochs, step 3).
        epochs_this_round = min(evaluate_every_epoch, n_epochs - i)
        estimator.train(input.input_fn(train_input,
                                       input_label_dir=train_labels_input,
                                       num_epochs=epochs_this_round,
                                       batch_size=training_params.batch_size,
                                       data_augmentation=training_params.data_augmentation,
                                       make_patches=training_params.make_patches,
                                       image_summaries=True,
                                       params=_config,
                                       num_threads=32))

        if eval_data is not None:
            eval_result = estimator.evaluate(input.input_fn(eval_input,
                                                            input_label_dir=eval_labels_input,
                                                            batch_size=1,
                                                            data_augmentation=False,
                                                            make_patches=False,
                                                            image_summaries=False,
                                                            params=_config,
                                                            num_threads=32))
        else:
            eval_result = None

        exporter.export(estimator, saved_model_dir, checkpoint_path=None, eval_result=eval_result,
                        is_the_final_export=False)
# Example #2
# 0
import os