Example #1
def export_model(saver: tf.train.Saver, input_names: list, output_name: str, model_name: str):
    """
    You can find node names by using debugger: just connect it right after model is created and look for nodes in the inspec
    :param saver:
    :param input_names:
    :param output_name:
    :param model_name:
    :return:
    """
    os.makedirs("./out", exist_ok=True)
    tf.train.write_graph(K.get_session().graph_def, 'out',
                         model_name + '_graph.pbtxt')

    saver.save(K.get_session(), 'out/' + model_name + '.chkp')

    # freeze the graph: combine the human-readable .pbtxt graph definition
    # with the checkpoint weights into a single binary .pb file
    freeze_graph.freeze_graph('out/' + model_name + '_graph.pbtxt', None,
                              False, 'out/' + model_name + '.chkp', output_name,
                              "save/restore_all", "save/Const:0",
                              'out/frozen_' + model_name + '.pb', True, "")

    input_graph_def = tf.GraphDef()
    with tf.gfile.Open('out/frozen_' + model_name + '.pb', "rb") as f:
        input_graph_def.ParseFromString(f.read())

    # optimization of the graph so we can use it in the android app
    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
        input_graph_def, input_names, [output_name],
        tf.float32.as_datatype_enum)

    # write the optimized graph in protobuf format; this is the file we'll use in our Android app
    with tf.gfile.FastGFile('out/opt_' + model_name + '.pb', "wb") as f:
        f.write(output_graph_def.SerializeToString())

    print("graph saved!")
Example #2
def load(sess: tf.Session, saver: tf.train.Saver, checkpoint_dir: str) -> bool:
    '''Loads the most recent checkpoint from checkpoint_dir.

    Args
    - sess: tf.Session
    - saver: tf.train.Saver
    - checkpoint_dir: str, path to directory containing checkpoint(s)

    Returns: bool, True if successful at restoring checkpoint from given dir
    '''
    print(f'Reading from checkpoint dir: {checkpoint_dir}')
    if checkpoint_dir is None:
        raise ValueError('No checkpoint path given, cannot load checkpoint')
    if not os.path.isdir(checkpoint_dir):
        raise ValueError('Given path is not a valid directory.')

    # read the CheckpointState proto from 'checkpoint' file in checkpoint_dir
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        print(f'Loading checkpoint: {ckpt_name}')
        # checkpoint_path_exists() and LoadNoFileError are project-specific helpers
        # assumed to be defined elsewhere in this codebase
        if not checkpoint_path_exists(ckpt.model_checkpoint_path):
            raise LoadNoFileError(
                'Checkpoint could not be loaded because it does not exist,'
                ' but its information is in the checkpoint meta-data file.')
        saver.restore(sess, ckpt.model_checkpoint_path)
        return True
    return False
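
A minimal usage sketch, assuming the checkpoints live under ./checkpoints; if nothing can be restored, fall back to fresh initialization:

sess = tf.Session()
saver = tf.train.Saver()
if not load(sess, saver, './checkpoints'):
    # no checkpoint found: start from freshly initialized variables
    sess.run(tf.global_variables_initializer())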
Example #3
def save_model(self,
               sess: tf.Session,
               model_params: dict,
               saver: tf.train.Saver,
               epochs: int) -> None:
    saver.save(sess, model_params['model_path'] + '/' + self.name, global_step=epochs)
    with open(model_params['model_path'] + '/model_params.json', 'w+') as f:
        json.dump(model_params, f)
Example #4
def model_save(saver: tf.train.Saver, sess: tf.Session, model_path: str,
               model_prefix: str, batch_id: int):
    try:
        saver.save(sess, f"{model_path}/{model_prefix}", global_step=batch_id)
        print(f"Successful saving model[{batch_id}] ...")
        return True

    except Exception as error:
        print(f"Failed saving model[{batch_id}] : {error}")
        return False
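
In a training loop this would typically be called on a fixed cadence, e.g. (a sketch; save_every, num_batches, and train_op are assumptions):

save_every = 100                     # hypothetical save interval
for batch_id in range(num_batches):  # hypothetical batch loop
    sess.run(train_op)               # hypothetical training op
    if batch_id % save_every == 0:
        model_save(saver, sess, './checkpoints', 'my_model', batch_id)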
Example #5
def restore_model(sess: tf.Session, saver: tf.train.Saver, c: TrainConfig):
    model_name = get_model_name(c)
    save_dir = os.path.join(c.save_dir, model_name)
    model_path = os.path.join(c.save_dir, model_name, 'model')
    latest_ckpt = tf.train.latest_checkpoint(save_dir)
    if latest_ckpt:
        print(f'💽 restoring latest checkpoint from: {latest_ckpt}')
        saver.restore(sess, latest_ckpt)
    else:
        print('💽 training new model')
    return model_path
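
The returned model_path is the checkpoint prefix under save_dir; a caller might use it like this (a sketch; c, train_op, and num_steps are assumed to exist):

model_path = restore_model(sess, saver, c)
for step in range(1, num_steps + 1):
    sess.run(train_op)                                  # hypothetical training op
    if step % 1000 == 0:
        saver.save(sess, model_path, global_step=step)  # next restore picks this up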
Example #6
def load(saver: tf.train.Saver, sess: tf.Session, checkpoint_dir: str):
    print(" [*] Reading checkpoints from %s..." % checkpoint_dir)

    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        saver.restore(sess, os.path.join(checkpoint_dir, ckpt_name))
        # the last run of digits in the file name is the global step, e.g. "model-1200" -> 1200
        counter = int(next(re.finditer(r"(\d+)(?!.*\d)", ckpt_name)).group(0))
        print(" [*] Successfully read {}".format(ckpt_name))
        return True, counter
    else:
        print(" [*] Failed to find checkpoints")
        return False, 0
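
A calling sketch, assuming checkpoint files named with a trailing step number (e.g. model-1200) under ./checkpoints:

could_load, counter = load(saver, sess, './checkpoints')
start_step = counter if could_load else 0  # resume from the recovered step counter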
Example #7
def restore_model_and_get_initial_epoch(self, session: tf.Session,
                                        saver: tf.train.Saver,
                                        load_path: str):
    print("load_path", load_path, flush=True)
    checkpoint_path = self.get_checkpoint_path(load_path)
    print("checkpoint_path", checkpoint_path, flush=True)
    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
    print("latest_checkpoint", latest_checkpoint, flush=True)
    if latest_checkpoint is not None:
        saver.restore(session, latest_checkpoint)
        # assumes checkpoints were saved with the epoch number as the file prefix
        return int(pathlib.Path(latest_checkpoint).name)
    else:
        return 0
Example #8
def restore_model(sess: tf.Session, path: str, saver: tf.train.Saver = None) -> tf.train.Saver:
    """
    Loads a tensorflow session from the given path.
    NOTE: This currently loads *all* variables in the saved file, unless one passes in a custom Saver object.
    :param sess: The tensorflow session to restore into
    :param path: The path to the saved data
    :param saver: A custom saver object to use. This can be used to only load certain variables. If None,
    creates a saver object that loads all variables.
    :return: The saver object used.
    """
    if saver is None:
        # tf.all_variables() is deprecated; tf.global_variables() is the TF 1.x replacement
        saver = tf.train.Saver(tf.global_variables())
    saver.restore(sess, path)
    return saver
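
The returned saver can then be reused for later saves (a sketch; the checkpoint paths are hypothetical):

saver = restore_model(sess, 'out/model.ckpt-500')  # hypothetical checkpoint path
saver.save(sess, 'out/model.ckpt', global_step=501)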
Example #9
def _restore_checkpoint(saver: tf.train.Saver,
                        sess: tf.Session,
                        path: Optional[str] = None):
    if path and saver:
        # if a directory is given instead of a path, try to find a checkpoint file there
        checkpoint_file = tf.train.latest_checkpoint(path) if os.path.isdir(path) else path

        if checkpoint_file and tf.train.checkpoint_exists(checkpoint_file):
            saver.restore(sess, checkpoint_file)
            tf.logging.info("Model loaded from {}".format(checkpoint_file))
        else:
            tf.logging.info(
                "No valid checkpoint has been found at {}. Ignoring.".format(path))
Example #10
def save_graph(self, epoch, logs, session: tf.Session,
               saver: tf.train.Saver, checkpoint_path: str,
               save_summary_writer: tf.summary.FileWriter):
    save_summary_writer.add_graph(session.graph)
    saver.save(session, save_path=f"{checkpoint_path}/{epoch}")
Example #11
def saver_fn(i, session: tf.Session, saver: tf.train.Saver):
    print('Saving Session {}. . .'.format(i))
    saver.save(session, '/tmp/new_save/checkpoint', global_step=i)
Example #12
def save(saver: tf.train.Saver, sess: tf.Session, checkpoint_dir, step):
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    saver.save(sess, checkpoint_dir, global_step=step, write_meta_graph=False)
Example #13
def reconstruct_saved_model_variables(cls, sess: tf.Session, saver: tf.train.Saver, model_dir: str) -> None:
    print('Loading saved model variables from', model_dir)
    latest_checkpoint = tf.train.latest_checkpoint(model_dir)
    saver.restore(sess, latest_checkpoint)
Example #14
def save(self, sess: tf.Session, saver: tf.train.Saver, ckpt_dir,
         global_step):
    self.logger.info("saving model ...")
    saver.save(sess,
               "{}/model.ckpt".format(ckpt_dir),
               global_step=global_step)
Example #15
def eval_model(is_training: tf.Variable, sess: tf.Session, best_iou: float,
               val_loss: tf.Tensor, val_acc: tf.Tensor,
               val_iou_update: tf.Operation, val_iou: tf.Tensor,
               val_iou_reset: tf.Operation, val_writer: tf.summary.FileWriter,
               epoch: int, saver: tf.train.Saver) -> float:
    """
    evaluates model with one pass over validation set

    :param is_training: tf var which indicates if model is training
    :param sess: tf sess
    :param best_iou: best validation iou until now
    :param val_loss: val loss tensor
    :param val_acc: val accuracy tensor
    :param val_iou_update: val iou update operation
    :param val_iou: val iou tensor
    :param val_iou_reset: val iou reset operation
    :param val_writer: val summary writer
    :param epoch: index of current epoch
    :param saver: tf model saver
    :return: new best iou
    """
    acc_sum, loss_sum = 0, 0

    # toggle training off
    assign_op = is_training.assign(False)
    sess.run(assign_op)

    val_batches = N_VAL_SAMPLES // BATCH_SIZE
    print(f"starting evaluation {val_batches} batches")

    for j in range(val_batches):
        loss_val, acc_val, _, val_iou_val = sess.run(
            [val_loss, val_acc, val_iou_update, val_iou])
        print(
            f"\tevaluation epoch: {epoch:03d}\tbatch {j:03d} eval:"
            f"\tloss: {loss_val:.4f}\taccuracy: {acc_val:.4f}\taccumulated iou {val_iou_val:.4f}"
        )
        acc_sum += acc_val
        loss_sum += loss_val

    # validation summary
    loss = loss_sum / val_batches
    acc = acc_sum / val_batches
    iou = val_iou_val
    summary = get_tf_summary(loss, acc, iou)
    val_writer.add_summary(summary, epoch)
    print(
        f"evaluation:\tmean loss: {loss:.4f}\tmean acc: {acc:.4f}\tmean iou {iou:.4f}\n"
    )

    # save model if it is better
    if iou > best_iou:
        best_iou = iou
        save_path = saver.save(
            sess,
            os.path.join(LOG_DIR + "_train",
                         f"best_model_epoch_{epoch:03d}.ckpt"))
        print(f"Model saved in file: {save_path}\n")

    # reset accumulator
    sess.run(val_iou_reset)

    # toggle training on
    assign_op = is_training.assign(True)
    sess.run(assign_op)

    return best_iou
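
get_tf_summary is a project helper not shown here; a minimal sketch of what it might look like, building a TF 1.x summary proto by hand:

def get_tf_summary(loss: float, acc: float, iou: float) -> tf.Summary:
    # hypothetical implementation: wrap the scalar metrics in a tf.Summary proto
    return tf.Summary(value=[
        tf.Summary.Value(tag="val/loss", simple_value=loss),
        tf.Summary.Value(tag="val/accuracy", simple_value=acc),
        tf.Summary.Value(tag="val/iou", simple_value=iou),
    ])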
Example #16
def restore_from_checkpoint(self, sess: tf.Session, loader: tf.train.Saver,
                            path_to_checkpoint_dir: str):
    # Saver.restore() needs a checkpoint prefix, not a directory,
    # so resolve the latest checkpoint in the directory first
    loader.restore(sess, tf.train.latest_checkpoint(path_to_checkpoint_dir))
Example #17
def train(parameters: dict,
          path: str = os.path.join('./model', 'model'),
          saver: tf.train.Saver = None) -> None:

    weights = {
        'h1':
        tf.Variable(
            tf.random_normal(
                [parameters['image_size'], parameters['n_hidden_1']])),
        'h2':
        tf.Variable(
            tf.random_normal(
                [parameters['n_hidden_1'], parameters['n_hidden_2']])),
        'out':
        tf.Variable(
            tf.random_normal(
                [parameters['n_hidden_2'], parameters['n_classes']]))
    }
    biases = {
        'b1': tf.Variable(tf.random_normal([parameters['n_hidden_1']])),
        'b2': tf.Variable(tf.random_normal([parameters['n_hidden_2']])),
        'out': tf.Variable(tf.random_normal([parameters['n_classes']]))
    }
    x, y = init_x_y(parameters)  # project helper defined elsewhere; returns input/label placeholders
    pred = multilayer_perceptron(x, weights, biases)  # project helper defined elsewhere
    # Define loss and optimizer
    cost = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
    optimizer = tf.train.AdamOptimizer(
        learning_rate=parameters['learning_rate']).minimize(cost)

    # Initializing the variables
    init = tf.global_variables_initializer()
    batch_size = parameters['batch_size']

    # Launch the graph
    with tf.Session() as sess:
        sess.run(init)

        # Training cycle
        for epoch in range(parameters['training_epochs']):
            avg_cost = 0.
            # `mnist` is assumed to be a module-level dataset handle, e.g. from
            # tensorflow.examples.tutorials.mnist input_data.read_data_sets()
            total_batch = int(mnist.train.num_examples / batch_size)
            # Loop over all batches
            for i in range(total_batch):
                batch_x, batch_y = mnist.train.next_batch(batch_size)
                # Run optimization op (backprop) and cost op (to get loss value)
                _, c = sess.run([optimizer, cost],
                                feed_dict={
                                    x: batch_x,
                                    y: batch_y
                                })
                # Compute average loss
                avg_cost += c / total_batch
            # Display logs per epoch step
            if epoch % parameters['display_step'] == 0:
                print("Epoch:", '%04d' % (epoch + 1), "cost=",
                      "{:.9f}".format(avg_cost))
        print("Optimization Finished!")
        if saver:
            saver.save(sess, save_path=path)

        # Test model
        correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))

        # Calculate accuracy
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
        print("Accuracy:",
              accuracy.eval({
                  x: mnist.test.images,
                  y: mnist.test.labels
              }))
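
A calling sketch with a hypothetical hyper-parameter dict; the keys mirror the ones the function reads, and the values are typical MNIST settings rather than anything from the source:

params = {
    'image_size': 784,        # 28*28 flattened MNIST images
    'n_hidden_1': 256,
    'n_hidden_2': 256,
    'n_classes': 10,
    'learning_rate': 0.001,
    'batch_size': 100,
    'training_epochs': 15,
    'display_step': 1,
}
train(params)  # pass a tf.train.Saver as `saver` to also checkpoint the trained model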