def create_model(session, batch_size):
    """Build a LinearModel and give it freshly initialized parameters.

    Nothing is loaded from disk: all global variables are initialized in
    ``session`` before the model is returned.

    Args
      session: tensorflow session used to run the initializer
      batch_size: integer. Number of examples in each batch
    Returns
      model: The newly created model
    """
    ctor_args = (
        FLAGS.linear_size,
        FLAGS.num_layers,
        FLAGS.residual,
        FLAGS.batch_norm,
        FLAGS.max_norm,
        batch_size,
        FLAGS.learning_rate,
        summaries_dir,
    )
    model = linear_model.LinearModel(*ctor_args, dtype=tf.float32)

    print("\n\nCreating model with fresh parameters.")
    init_all = tf.global_variables_initializer()
    session.run(init_all)
    return model
예제 #2
0
def load_model(session, batch_size):
  """
  Build the model and restore checkpoint-<FLAGS.load> from train_dir into it.

  Args
    session: tensorflow session to restore the variables into
    batch_size: integer. Number of examples in each batch
  Returns
    model: The loaded model
  Raises
    ValueError if no checkpoint state is found in train_dir, or if the
    checkpoint requested with FLAGS.load does not exist there.
  """
  model = linear_model.LinearModel(
      FLAGS.linear_size,
      FLAGS.num_layers,
      FLAGS.residual,
      FLAGS.batch_norm,
      FLAGS.max_norm,
      batch_size,
      FLAGS.learning_rate,
      summaries_dir,
      FLAGS.predict_14,
      dtype=tf.float16 if FLAGS.use_fp16 else tf.float32)

  # Load a previously saved model
  ckpt = tf.train.get_checkpoint_state( train_dir, latest_filename="checkpoint")
  print( "train_dir", train_dir )

  if not (ckpt and ckpt.model_checkpoint_path):
    print("Could not find checkpoint. Aborting.")
    # ckpt may be None here, so report the directory rather than ckpt's path.
    # (Raising a (ValueError, msg) tuple is a TypeError in Python 3.)
    raise ValueError("Checkpoint does not seem to exist in {0}".format(train_dir))

  # Check if the specific checkpoint exists (TF writes <prefix>.index).
  ckpt_name = os.path.join(train_dir, "checkpoint-{0}".format(FLAGS.load))
  if not os.path.isfile(ckpt_name + ".index"):
    raise ValueError("Asked to load checkpoint {0}, but it does not seem to exist".format(FLAGS.load))

  print("Loading model {0}".format( ckpt_name ))
  # Restore the checkpoint that was requested and validated above, not
  # whichever checkpoint happens to be the latest one in train_dir.
  model.saver.restore( session, ckpt_name )
  return model
def create_model(session, batch_size):
    """Rebuild the model and restore its newest saved weights.

    The latest checkpoint found under ``summaries_dir`` is restored into
    ``session``; no fresh initialization is run.

    Args
      session: tensorflow session to restore the variables into
      batch_size: integer. Number of examples in each batch
    Returns
      model: The restored model
    """
    ctor_args = (
        FLAGS.linear_size,
        FLAGS.num_layers,
        FLAGS.residual,
        FLAGS.batch_norm,
        FLAGS.max_norm,
        batch_size,
    )
    model = linear_model.LinearModel(*ctor_args, dtype=tf.float32)

    latest_ckpt = tf.train.latest_checkpoint(summaries_dir)
    print("\nRestoring model from {}\n".format(latest_ckpt))
    # NOTE(review): tf.train.latest_checkpoint returns None when no checkpoint
    # exists, in which case the restore below fails.
    model.saver.restore(session, latest_ckpt)
    return model
def create_model( session, batch_size ):
  """
  Create model and initialize it or load its parameters in a session

  Args
    session: tensorflow session
    batch_size: integer. Number of examples in each batch
  Returns
    model: The created (or loaded) model
  Raises
    ValueError if asked to load a model (FLAGS.load > 0) but no checkpoint
    state exists in train_dir, or checkpoint-<FLAGS.load> cannot be found.
  """

  model = linear_model.LinearModel(
      FLAGS.linear_size,
      FLAGS.num_layers,
      FLAGS.residual,
      FLAGS.batch_norm,
      FLAGS.max_norm,
      batch_size,
      FLAGS.learning_rate,
      summaries_dir,
      FLAGS.predict_14,
      dtype=tf.float16 if FLAGS.use_fp16 else tf.float32)

  if FLAGS.load <= 0:
    # Create a new model from scratch
    print("Creating model with fresh parameters.")
    session.run( tf.global_variables_initializer() )
    return model

  # Load a previously saved model
  ckpt = tf.train.get_checkpoint_state( train_dir, latest_filename="checkpoint")
  print( "train_dir", train_dir )

  if not (ckpt and ckpt.model_checkpoint_path):
    print("Could not find checkpoint. Aborting.")
    # ckpt may be None here, so report the directory rather than ckpt's path.
    # (Raising a (ValueError, msg) tuple is a TypeError in Python 3.)
    raise ValueError("Checkpoint does not seem to exist in {0}".format(train_dir))

  # FLAGS.load > 0 is guaranteed here; check the requested checkpoint exists
  # (TF writes <prefix>.index next to the checkpoint data).
  ckpt_name = os.path.join(train_dir, "checkpoint-{0}".format(FLAGS.load))
  if not os.path.isfile(ckpt_name + ".index"):
    raise ValueError("Asked to load checkpoint {0}, but it does not seem to exist".format(FLAGS.load))

  print("Loading model {0}".format( ckpt_name ))
  # Restore the requested checkpoint, not just the latest one in train_dir.
  model.saver.restore( session, ckpt_name )
  return model
예제 #5
0
def create_model(session, actions, batch_size):
    """
    Create a model and either initialize it or restore it from a checkpoint.

    Args
      session: tensorflow session
      actions: list of string. Actions to train/test on (not used by this
        function; kept for interface compatibility)
      batch_size: integer. Number of examples in each batch
    Returns
      model: The created (or loaded) model
    Raises
      ValueError if asked to load a model (FLAGS.load > 0) but no checkpoint
      state exists in train_dir, or checkpoint-<FLAGS.load> cannot be found.
    """
    model = linear_model.LinearModel(
        FLAGS.linear_size,
        FLAGS.num_layers,
        FLAGS.residual,
        FLAGS.batch_norm,
        FLAGS.max_norm,
        batch_size,
        FLAGS.learning_rate,
        summaries_dir,
        FLAGS.predict_14,
        dtype=tf.float16 if FLAGS.use_fp16 else tf.float32)

    if FLAGS.load <= 0:
        # Create a new model from scratch
        print("Creating model with fresh parameters.")
        session.run(tf.global_variables_initializer())
        # Returning only here (not unconditionally) keeps the load path
        # below reachable when FLAGS.load > 0.
        return model

    # Load a previously saved model
    ckpt = tf.train.get_checkpoint_state(train_dir,
                                         latest_filename="checkpoint")
    print("train_dir", train_dir)

    if not (ckpt and ckpt.model_checkpoint_path):
        print("Could not find checkpoint. Aborting.")
        # ckpt may be None here, so report the directory rather than ckpt's
        # path. (Raising a (ValueError, msg) tuple is a TypeError in Python 3.)
        raise ValueError(
            "Checkpoint does not seem to exist in {0}".format(train_dir))

    # Check if the specific checkpoint exists (TF writes <prefix>.index).
    ckpt_name = os.path.join(train_dir, "checkpoint-{0}".format(FLAGS.load))
    if not os.path.isfile(ckpt_name + ".index"):
        raise ValueError(
            "Asked to load checkpoint {0}, but it does not seem to exist"
            .format(FLAGS.load))

    print("Loading model {0}".format(ckpt_name))
    # Run the global and local initializers on the caller's session before
    # restoring; restore() then overwrites the variables that model.saver
    # tracks. (The original opened an extra, unused tf.Session for this.)
    session.run(tf.global_variables_initializer())
    session.run(tf.local_variables_initializer())
    # Restore the requested checkpoint, not just the latest one in train_dir.
    model.saver.restore(session, ckpt_name)
    return model
예제 #6
0
    # NOTE(review): this is a fragment of a larger function whose `def` line
    # is not in this file; `fig`, `system`, `random_control`, `viz_1_1d`,
    # `np` and `linear_model` come from that missing enclosing scope/module.
    # Panel 1 of a 2x2 subplot grid (3-D projection): the true system's
    # transition function.
    ax1 = fig.add_subplot(2, 2, 1, projection='3d')
    ax1.set_title("System Transition Function")
    viz_1_1d.plot_1_1d_system(system, ax=ax1)

    # Collect data by applying random actions in the system
    states, actions, rewards = system.sim_traj(random_control)
    # Pair each state with its successor: states_prime[i] is the state that
    # follows states[i] in the simulated trajectory.
    states_prime = states[1:]
    states = states[:-1]

    # Matching expected format for model.fit()
    # Each sample is (state, action, reward, next_state, False); the trailing
    # False presumably marks a non-terminal transition — confirm against
    # model.fit(). Scalars are wrapped into 1-element arrays.
    data = [(np.array([states[i]]), np.array([actions[i]]), rewards[i],
             np.array([states_prime[i]]), False) for i in range(len(states))]
    print("Fitting on {} datapoints".format(len(data)))

    # Fit linear model to data
    # NOTE(review): LinearModel(1, 1) presumably means 1-D state and 1-D
    # action — confirm against the LinearModel signature.
    model = linear_model.LinearModel(1, 1)
    model.fit(data)
    # Copy the system's state bounds onto the fitted model.
    model.limits_low = np.array([system.state_limits[0]])
    model.limits_high = np.array([system.state_limits[1]])
    print("Model is A={}, B={} (not fitting reward function now)\n".format(
        model.A, model.B))

    # Panel 2 of the 2x2 grid: the fitted linear model, plotted over the
    # system's action range.
    ax2 = fig.add_subplot(2, 2, 2, projection='3d')
    ax2.set_title("Linear Model of Transition Function")
    viz_1_1d.plot_1_1d_model(model,
                             ax=ax2,
                             action_limit_low=system.action_limits[0],
                             action_limit_high=system.action_limits[1])

    # Compute optimal controller for fit model
    # NOTE(review): the snippet ends after this reset; the controller
    # computation announced above is not part of this file.
    system.reset()
예제 #7
0
def create_model(session, actions,
                 batch_size):  # Two paths: build a fresh model, or load a saved checkpoint
    """
    Create model and initialize it or load its parameters in a session

    Args
      session: tensorflow session
      actions: list of string. Actions to train/test on (not used by this
        function; kept for interface compatibility)
      batch_size: integer. Number of examples in each batch
    Returns
      model: The created (or loaded) model
    Raises
      ValueError if asked to load a model (FLAGS.load > 0) but no checkpoint
      state exists in train_dir, or checkpoint-<FLAGS.load> cannot be found.
    """

    model = linear_model.LinearModel(
        FLAGS.linear_size,
        FLAGS.num_layers,
        FLAGS.residual,
        FLAGS.batch_norm,
        FLAGS.max_norm,
        batch_size,
        FLAGS.learning_rate,
        summaries_dir,
        FLAGS.predict_14,
        dtype=tf.float16 if FLAGS.use_fp16 else tf.float32)

    if FLAGS.load <= 0:
        # Create a new model from scratch. In TensorFlow (1.x graph mode)
        # variables must be explicitly initialized in the session before
        # their values can be used.
        print("Creating model with fresh parameters.")
        session.run(tf.global_variables_initializer())
        return model

    # Load a previously saved model from the experiment's checkpoint state.
    ckpt = tf.train.get_checkpoint_state(
        train_dir, latest_filename="checkpoint")
    print("train_dir", train_dir)

    if not (ckpt and ckpt.model_checkpoint_path):
        print("Could not find checkpoint. Aborting.")
        # ckpt may be None here, so report the directory rather than ckpt's
        # path. (Raising a (ValueError, msg) tuple is a TypeError in Python 3.)
        raise ValueError(
            "Checkpoint does not seem to exist in {0}".format(train_dir))

    # FLAGS.load > 0 is guaranteed here. Check that the requested checkpoint
    # exists, e.g. experiments/checkpoint-24371.index for FLAGS.load=24371.
    ckpt_name = os.path.join(train_dir, "checkpoint-{0}".format(FLAGS.load))
    if not os.path.isfile(ckpt_name + ".index"):
        raise ValueError(
            "Asked to load checkpoint {0}, but it does not seem to exist"
            .format(FLAGS.load))

    print("Loading model {0}".format(ckpt_name))
    # Load the saved values of the requested checkpoint (not just the latest
    # one in train_dir) into the current session.
    model.saver.restore(session, ckpt_name)
    return model