Example #1
def export_model(config, model_creator):
    if not config.export_path:
        raise ValueError("Export path must be specified.")
    if not config.model_version:
        raise ValueError("Export model version must be specified.")

    utils.makedir(config.export_path)

    # Create model
    model = model_helper.create_model(model_creator, config, mode="infer")

    # TensorFlow model
    config_proto = utils.get_config_proto()
    sess = tf.Session(config=config_proto, graph=model.graph)

    with model.graph.as_default():
        loaded_model, global_step = model_helper.create_or_load_model(
            model.model, config.best_eval_loss_dir, sess, "infer")

        export_dir = os.path.join(config.export_path, config.model_version)
        builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
        inputs = {
            "word_ids1": tf.saved_model.utils.build_tensor_info(loaded_model.word_ids1),
            "word_ids2": tf.saved_model.utils.build_tensor_info(loaded_model.word_ids2),
            "word_len1": tf.saved_model.utils.build_tensor_info(loaded_model.word_len1),
            "word_len2": tf.saved_model.utils.build_tensor_info(loaded_model.word_len2),
            "char_ids1": tf.saved_model.utils.build_tensor_info(loaded_model.char_ids1),
            "char_ids2": tf.saved_model.utils.build_tensor_info(loaded_model.char_ids2),
            "char_len1": tf.saved_model.utils.build_tensor_info(loaded_model.char_len1),
            "char_len2": tf.saved_model.utils.build_tensor_info(loaded_model.char_len2),
        }
        outputs = {
            "simscore": tf.saved_model.utils.build_tensor_info(loaded_model.simscore),
        }
        prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(
            inputs=inputs,
            outputs=outputs,
            method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)

        builder.add_meta_graph_and_variables(
            sess, [tf.saved_model.tag_constants.SERVING],
            signature_def_map={
                tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                    prediction_signature
            })
        builder.save()
        logger.info("Export model succeed.")
Example #2
    def __init__(self, args, net_input_shape):
        '''
        Create the evaluation model and load the pre-trained weights for inference.
        '''
        self.net_input_shape = net_input_shape
        weights_path = join(args.weights_path)
        # Create the model in inference mode with the decoder layers disabled.
        _, eval_model, _ = create_model(args, net_input_shape, enable_decoder=False)

        # Load the weights trained on MS-COCO by name, because some output layers are disabled.
        eval_model.load_weights(weights_path, by_name=True)
        self.model = eval_model
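A hypothetical usage of this wrapper follows. The class name EvalModel, the weights path, and the 512x512 RGB input shape are illustrative assumptions, not taken from the source.

# Hypothetical usage sketch of the wrapper above.
import numpy as np
from argparse import Namespace

args = Namespace(weights_path="weights/coco_weights.h5")  # hypothetical path
evaluator = EvalModel(args, net_input_shape=(512, 512, 3))  # assumed class name

batch = np.zeros((1, 512, 512, 3), dtype=np.float32)  # dummy RGB batch
predictions = evaluator.model.predict(batch)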
Example #3
def test(config, model_creator):
    best_metric_label = "best_eval_loss"
    model_dir = getattr(config, best_metric_label + "_dir")

    # Evaluate the saved best model on each dataset split in turn.
    splits = [("training-set", config.train_file),
              ("dev-set", config.dev_file),
              ("test-set", config.test_file)]
    for split_name, data_file in splits:
        logger.info("Start evaluating saved best model on %s." % split_name)
        eval_model = model_helper.create_model(model_creator, config, mode="eval")
        session_config = utils.get_config_proto()
        eval_sess = tf.Session(config=session_config, graph=eval_model.graph)
        run_test(config, eval_model, eval_sess, data_file, model_dir)
Example #4
def inference(config, model_creator):
    output_file = "output_" + os.path.split(config.infer_file)[-1].split(".")[0]
    # Inference output directory
    pred_file = os.path.join(config.model_dir, output_file)
    utils.makedir(pred_file)

    # Inference
    model_dir = config.best_eval_loss_dir

    # Create model
    # model_creator = my_model.MyModel
    infer_model = model_helper.create_model(model_creator, config, mode="infer")

    # TensorFlow model
    sess_config = utils.get_config_proto()
    infer_sess = tf.Session(config=sess_config, graph=infer_model.graph)

    with infer_model.graph.as_default():
        loaded_infer_model, _ = model_helper.create_or_load_model(
            infer_model.model, model_dir, infer_sess, "infer")

    run_infer(config, loaded_infer_model, infer_sess, pred_file)
Example #5
def main(args):

    net_input_shape = (RESOLUTION, RESOLUTION, 3)  # RGB: 3 channels.
    # Create the model for training/testing/manipulation.
    # enable_decoder=False is used only for SegCaps R3, to disable the
    # reconstruction image output on the evaluation model and speed it up.
    model_list = create_model(args=args,
                              input_shape=net_input_shape,
                              enable_decoder=True)
    #print_summary(model=model_list[0], positions=[.38, .65, .75, 1.])

    args.output_name = 'split-' + str(args.split_num) + '_batch-' + str(args.batch_size) + \
                       '_shuff-' + str(args.shuffle_data) + '_aug-' + str(args.aug_data) + \
                       '_loss-' + str(args.loss) + '_slic-' + str(args.slices) + \
                       '_sub-' + str(args.subsamp) + '_strid-' + str(args.stride) + \
                       '_lr-' + str(args.initial_lr) + '_recon-' + str(args.recon_wei)

    #     args.output_name = 'sh-' + str(args.shuffle_data) + '_a-' + str(args.aug_data)

    args.time = time
    # Multiprocessing data loading is disabled on Windows.
    args.use_multiprocessing = platform.system() != 'Windows'
    args.check_dir = join(args.data_root, 'saved_models', args.net)
    try:
        makedirs(args.check_dir)
    except OSError:
        pass  # directory already exists

    args.log_dir = join(args.data_root, 'logs', args.net)
    try:
        makedirs(args.log_dir)
    except OSError:
        pass  # directory already exists

    args.tf_log_dir = join(args.log_dir, 'tf_logs')
    try:
        makedirs(args.tf_log_dir)
    except OSError:
        pass  # directory already exists

    args.output_dir = join(args.data_root, 'plots', args.net)
    try:
        makedirs(args.output_dir)
    except OSError:
        pass  # directory already exists

    # Run training (unconditional in this variant; compare the args.train
    # guard in Example #6).
    from train import train
    train(args, model_list[0], net_input_shape)

    if args.test:
        from test import test
        # Run testing
        test(args, net_input_shape)

    if args.manip:
        from manip import manip
        # Run manipulation of SegCaps. NOTE: test_list is not defined in this
        # variant of main(); it must be loaded as in Example #6.
        manip(args, test_list, model_list, net_input_shape)
Example #6
def main(args):
    # Ensure training, testing, and manipulation are not all turned off.
    assert (
        args.train or args.test or args.manip
    ), 'Cannot have train, test, and manip all set to 0; nothing to do.'

    # Load the training, validation, and testing data.
    try:
        train_list, val_list, test_list = load_data(args.data_root_dir,
                                                    args.split_num)
    except Exception:
        # Create the train/validation/test splits if they are not found.
        logging.info(
            'No existing train/validate/test split files found; generating them.')
        split_data(args.data_root_dir, num_splits=args.Kfold)
        train_list, val_list, test_list = load_data(args.data_root_dir,
                                                    args.split_num)

    # Get image properties from the first image; assume all images match.
    logging.info('Reading image file %s' %
                 join(args.data_root_dir, 'imgs', train_list[0][0]))
    image = sitk.GetArrayFromImage(
        sitk.ReadImage(join(args.data_root_dir, 'imgs', train_list[0][0])))
    img_shape = image.shape  # (x, y, channels)
    if args.dataset == 'luna16':
        net_input_shape = (img_shape[1], img_shape[2], args.slices)
    else:
        args.slices = 1
        if GRAYSCALE:
            net_input_shape = (RESOLUTION, RESOLUTION, 1)  # single channel
        else:
            net_input_shape = (RESOLUTION, RESOLUTION, 3)  # RGB: 3 channels.
    # Create the model for training/testing/manipulation.
    # enable_decoder=False is used only for SegCaps R3, to disable the
    # reconstruction image output on the evaluation model and speed it up.
    model_list = create_model(args=args,
                              input_shape=net_input_shape,
                              enable_decoder=True)
    print_summary(model=model_list[0], positions=[.38, .65, .75, 1.])

    args.output_name = 'split-' + str(args.split_num) + '_batch-' + str(args.batch_size) + \
                       '_shuff-' + str(args.shuffle_data) + '_aug-' + str(args.aug_data) + \
                       '_loss-' + str(args.loss) + '_slic-' + str(args.slices) + \
                       '_sub-' + str(args.subsamp) + '_strid-' + str(args.stride) + \
                       '_lr-' + str(args.initial_lr) + '_recon-' + str(args.recon_wei)

    #     args.output_name = 'sh-' + str(args.shuffle_data) + '_a-' + str(args.aug_data)

    args.time = time
    # Multiprocessing data loading is disabled on Windows.
    args.use_multiprocessing = platform.system() != 'Windows'
    args.check_dir = join(args.data_root_dir, 'saved_models', args.net)
    try:
        makedirs(args.check_dir)
    except OSError:
        pass  # directory already exists

    args.log_dir = join(args.data_root_dir, 'logs', args.net)
    try:
        makedirs(args.log_dir)
    except OSError:
        pass  # directory already exists

    args.tf_log_dir = join(args.log_dir, 'tf_logs')
    try:
        makedirs(args.tf_log_dir)
    except OSError:
        pass  # directory already exists

    args.output_dir = join(args.data_root_dir, 'plots', args.net)
    try:
        makedirs(args.output_dir)
    except OSError:
        pass  # directory already exists

    if args.train:
        from train import train
        # Run training
        train(args, train_list, val_list, model_list[0], net_input_shape)

    if args.test:
        from test import test
        # Run testing
        test(args, test_list, model_list, net_input_shape)

    if args.manip:
        from manip import manip
        # Run manipulation of SegCaps
        manip(args, test_list, model_list, net_input_shape)
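Both main() variants repeat the same try/makedirs/except pattern four times. A small helper (a sketch, not part of the original code) would centralize it while only ignoring the "directory already exists" case:

import os

def ensure_dir(path):
    # Create path if missing; re-raise anything that is not the
    # "already exists" case instead of swallowing every exception.
    try:
        os.makedirs(path)
    except OSError:
        if not os.path.isdir(path):
            raise

# Usage:
# for d in (args.check_dir, args.log_dir, args.tf_log_dir, args.output_dir):
#     ensure_dir(d)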
Example #7
def train(config, model_creator):
    steps_per_stats = config.steps_per_stats
    steps_per_eval = config.steps_per_eval
    model_dir = config.model_dir
    log_dir = config.log_dir
    ckpt_name = config.ckpt_name
    ckpt_path = os.path.join(model_dir, ckpt_name)

    # Create model
    train_model = model_helper.create_model(model_creator, config, "train")
    eval_model = model_helper.create_model(model_creator, config, "eval")
    # infer_model = model_helper.create_model(model_creator, config, "infer")

    train_data = data_helper.load_data(config.train_file,
                                       config.word_vocab_file,
                                       config.char_vocab_file,
                                       w_max_len1=config.max_word_len1,
                                       w_max_len2=config.max_word_len2,
                                       c_max_len1=config.max_char_len1,
                                       c_max_len2=config.max_char_len2,
                                       text_split="|",
                                       split="\t")
    train_iterator = data_helper.batch_iterator(train_data,
                                                batch_size=config.batch_size,
                                                shuffle=True)

    eval_data = data_helper.load_data(config.dev_file,
                                      config.word_vocab_file,
                                      config.char_vocab_file,
                                      w_max_len1=config.max_word_len1,
                                      w_max_len2=config.max_word_len2,
                                      c_max_len1=config.max_char_len1,
                                      c_max_len2=config.max_char_len2,
                                      text_split="|",
                                      split="\t")
    # eval_iterator = data_helper.batch_iterator(eval_data, batch_size=config.batch_size, shuffle=False)

    # TensorFlow model
    session_config = utils.get_config_proto()
    train_sess = tf.Session(config=session_config, graph=train_model.graph)
    eval_sess = tf.Session(config=session_config, graph=eval_model.graph)
    # infer_sess = tf.Session(config=config, graph=infer_model.graph)

    # Summary Writer
    train_summary_writer = tf.summary.FileWriter(
        os.path.join(log_dir, "train_log"), train_model.graph)
    eval_summary_writer = tf.summary.FileWriter(
        os.path.join(log_dir, "eval_log"), eval_model.graph)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")
        local_initializer = tf.local_variables_initializer()

        # running_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope="metrics")
        # running_vars_initializer = tf.variables_initializer(var_list=running_vars)

    step_time, train_loss, train_acc, train_rec, train_pre = 0.0, 0.0, 0.0, 0.0, 0.0
    train_f1, train_auc, gN = 0.0, 0.0, 0.0
    lr = loaded_train_model.learning_rate.eval(session=train_sess)
    last_stat_step = global_step
    last_eval_step = global_step

    logger.info("# Start step %d" % global_step)

    epoch_idx = 0
    while epoch_idx < config.num_train_epochs:
        start_time = time.time()
        try:
            # TODO: tf.metrics
            # train_sess.run(running_vars_initializer)
            train_sess.run(local_initializer)

            batch = next(train_iterator)
            (b_word_ids1, b_word_ids2, b_word_len1, b_word_len2,
             b_char_ids1, b_char_ids2, b_char_len1, b_char_len2, b_labels) = batch
            train_summary1, pred, step_loss, _, acc_op, rec_op, pre_op, auc_op, global_step, grad_norm, lr = \
                loaded_train_model.train(train_sess, b_word_ids1, b_word_ids2, b_word_len1, b_word_len2,
                                         b_char_ids1, b_char_ids2, b_char_len1, b_char_len2, b_labels)
            train_summary2, step_acc, step_rec, step_pre, step_auc = \
                train_sess.run([loaded_train_model.train_summary2,
                                loaded_train_model.accuracy,
                                loaded_train_model.recall,
                                loaded_train_model.precision,
                                loaded_train_model.auc])
            config.epoch_step += 1

        except StopIteration:
            # Finished one pass over the training data; start the next epoch.
            epoch_idx += 1
            config.epoch_step = 0
            train_iterator = data_helper.batch_iterator(
                train_data, batch_size=config.batch_size, shuffle=True)
            continue

        step_time += (time.time() - start_time)
        train_loss += step_loss
        train_acc += step_acc
        train_rec += step_rec
        train_pre += step_pre
        train_auc += step_auc
        gN += grad_norm

        if global_step - last_stat_step >= steps_per_stats:
            last_stat_step = global_step
            step_time /= steps_per_stats
            train_loss /= steps_per_stats
            train_acc /= steps_per_stats
            train_rec /= steps_per_stats
            train_pre /= steps_per_stats
            train_f1 = (2 * train_rec * train_pre) / (train_rec + train_pre + 1e-8)
            gN /= steps_per_stats

            logger.info(
                "  step %d lr %g step_time %.2fs loss %.4f acc %.4f rec %.4f pre %.4f f1 %.4f auc %.4f gN %.2f"
                % (global_step, lr, step_time, train_loss, train_acc,
                   train_rec, train_pre, train_f1, train_auc, gN))
            train_summary_writer.add_summary(train_summary1,
                                             global_step=global_step)
            train_summary_writer.add_summary(train_summary2,
                                             global_step=global_step)
            step_time, train_loss, train_acc, train_rec, train_pre = 0.0, 0.0, 0.0, 0.0, 0.0
            train_f1, train_auc, gN = 0.0, 0.0, 0.0

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step
            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          ckpt_path,
                                          global_step=global_step)
            # Evaluate on dev
            run_eval(config,
                     eval_model,
                     eval_sess,
                     eval_data,
                     model_dir,
                     ckpt_name,
                     eval_summary_writer,
                     save_on_best=True)

    logger.info("# Finished epoch %d, step %d." % (epoch_idx, global_step))

    # Done training
    loaded_train_model.saver.save(train_sess,
                                  ckpt_path,
                                  global_step=global_step)
    logger.info(
        "# Final, step %d lr %g step_time %.2fs loss %.4f acc %.4f rec %.4f pre %.4f f1 %.4f auc %.4f gN %.2f"
        % (global_step, lr, step_time, train_loss, train_acc, train_rec,
           train_pre, train_f1, train_auc, gN))
    logger.info("# Done training!")

    train_summary_writer.close()
    eval_summary_writer.close()
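The running statistics above compute F1 as the harmonic mean of the averaged precision and recall, with a small epsilon guarding against division by zero. A standalone sketch of that computation:

# Standalone sketch of the F1 computation used in the stats block above.
def f1_score(precision, recall, eps=1e-8):
    return 2.0 * precision * recall / (precision + recall + eps)

assert abs(f1_score(0.8, 0.6) - 0.685714) < 1e-5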