def create_model(session, batch_size):
    """Create a LinearModel with freshly initialized parameters.

    This variant never loads a checkpoint: it always builds the model and
    runs the global variable initializer in ``session``.

    Args:
      session: tensorflow session in which the variables are initialized.
      batch_size: integer. Number of examples in each batch.

    Returns:
      model: the newly created model.
    """
    model = linear_model.LinearModel(
        FLAGS.linear_size,
        FLAGS.num_layers,
        FLAGS.residual,
        FLAGS.batch_norm,
        FLAGS.max_norm,
        batch_size,
        FLAGS.learning_rate,
        summaries_dir,
        dtype=tf.float32,
    )
    print("\n\nCreating model with fresh parameters.")
    session.run(tf.global_variables_initializer())
    return model
def load_model(session, batch_size):
    """Create a LinearModel and restore the checkpoint selected by FLAGS.load.

    Args:
      session: tensorflow session the saved weights are restored into.
      batch_size: integer. Number of examples in each batch.

    Returns:
      model: the model with weights restored from
        ``train_dir/checkpoint-<FLAGS.load>``.

    Raises:
      ValueError: if ``train_dir`` has no checkpoint state, or if the
        checkpoint requested by FLAGS.load has no ``.index`` file on disk.
    """
    model = linear_model.LinearModel(
        FLAGS.linear_size,
        FLAGS.num_layers,
        FLAGS.residual,
        FLAGS.batch_norm,
        FLAGS.max_norm,
        batch_size,
        FLAGS.learning_rate,
        summaries_dir,
        FLAGS.predict_14,
        dtype=tf.float16 if FLAGS.use_fp16 else tf.float32,
    )

    # Load a previously saved model
    ckpt = tf.train.get_checkpoint_state(train_dir, latest_filename="checkpoint")
    print("train_dir", train_dir)

    if not (ckpt and ckpt.model_checkpoint_path):
        print("Could not find checkpoint. Aborting.")
        # ckpt may be None here, so the message must not dereference it.
        # (Raising a tuple, as before, is a TypeError in Python 3.)
        raise ValueError("No checkpoint found in {0}".format(train_dir))

    # Check that the specific checkpoint requested via FLAGS.load exists.
    ckpt_name = os.path.join(train_dir, "checkpoint-{0}".format(FLAGS.load))
    if not os.path.isfile(ckpt_name + ".index"):
        raise ValueError(
            "Asked to load checkpoint {0}, but it does not seem to exist".format(
                FLAGS.load
            )
        )

    print("Loading model {0}".format(ckpt_name))
    # Restore the checkpoint that was asked for and verified above, rather
    # than whatever the checkpoint state file lists as the latest one.
    model.saver.restore(session, ckpt_name)
    return model
def create_model(session, batch_size):
    """Create a LinearModel and restore it from the latest checkpoint.

    Args:
      session: tensorflow session the saved weights are restored into.
      batch_size: integer. Number of examples in each batch.

    Returns:
      model: the model with weights restored from the most recent checkpoint
        found in ``summaries_dir``.

    Raises:
      ValueError: if ``summaries_dir`` contains no checkpoint.
    """
    model = linear_model.LinearModel(
        FLAGS.linear_size,
        FLAGS.num_layers,
        FLAGS.residual,
        FLAGS.batch_norm,
        FLAGS.max_norm,
        batch_size,
        dtype=tf.float32,
    )

    to_restore = tf.train.latest_checkpoint(summaries_dir)
    # latest_checkpoint returns None when nothing was found; fail with a clear
    # message instead of printing "Restoring model from None" and letting
    # saver.restore error out.
    if to_restore is None:
        raise ValueError("No checkpoint found in {}".format(summaries_dir))

    print("\nRestoring model from {}\n".format(to_restore))
    model.saver.restore(session, to_restore)  # restore model from last checkpoint
    return model
def create_model(session, batch_size):
    """Create a LinearModel and initialize it or load its parameters.

    If ``FLAGS.load <= 0`` the variables are freshly initialized. Otherwise
    the weights are restored from ``train_dir/checkpoint-<FLAGS.load>``.

    Args:
      session: tensorflow session
      batch_size: integer. Number of examples in each batch

    Returns:
      model: The created (or loaded) model

    Raises:
      ValueError: if asked to load a model, but ``train_dir`` has no
        checkpoint state or the checkpoint specified by FLAGS.load cannot
        be found.
    """
    model = linear_model.LinearModel(
        FLAGS.linear_size,
        FLAGS.num_layers,
        FLAGS.residual,
        FLAGS.batch_norm,
        FLAGS.max_norm,
        batch_size,
        FLAGS.learning_rate,
        summaries_dir,
        FLAGS.predict_14,
        dtype=tf.float16 if FLAGS.use_fp16 else tf.float32,
    )

    if FLAGS.load <= 0:
        # Create a new model from scratch
        print("Creating model with fresh parameters.")
        session.run(tf.global_variables_initializer())
        return model

    # Load a previously saved model
    ckpt = tf.train.get_checkpoint_state(train_dir, latest_filename="checkpoint")
    print("train_dir", train_dir)

    if not (ckpt and ckpt.model_checkpoint_path):
        print("Could not find checkpoint. Aborting.")
        # ckpt may be None here, so the message must not dereference it.
        raise ValueError("No checkpoint found in {0}".format(train_dir))

    # FLAGS.load > 0 is guaranteed past the early return above, so only the
    # specific-checkpoint path is reachable.
    ckpt_name = os.path.join(train_dir, "checkpoint-{0}".format(FLAGS.load))
    if not os.path.isfile(ckpt_name + ".index"):
        raise ValueError(
            "Asked to load checkpoint {0}, but it does not seem to exist".format(
                FLAGS.load
            )
        )

    print("Loading model {0}".format(ckpt_name))
    # Restore the requested checkpoint, not merely the latest one.
    model.saver.restore(session, ckpt_name)
    return model
def create_model(session, actions, batch_size):
    """Create a LinearModel and initialize it or load its parameters.

    If ``FLAGS.load <= 0`` the variables are freshly initialized. Otherwise
    all variables are initialized and then overwritten from
    ``train_dir/checkpoint-<FLAGS.load>``.

    Args:
      session: tensorflow session
      actions: accepted for interface compatibility; not used here.
      batch_size: integer. Number of examples in each batch

    Returns:
      model: The created (or loaded) model

    Raises:
      ValueError: if ``train_dir`` has no checkpoint state or the checkpoint
        specified by FLAGS.load cannot be found.
    """
    model = linear_model.LinearModel(
        FLAGS.linear_size,
        FLAGS.num_layers,
        FLAGS.residual,
        FLAGS.batch_norm,
        FLAGS.max_norm,
        batch_size,
        FLAGS.learning_rate,
        summaries_dir,
        FLAGS.predict_14,
        dtype=tf.float16 if FLAGS.use_fp16 else tf.float32,
    )

    if FLAGS.load <= 0:
        # Create a new model from scratch
        print("Creating model with fresh parameters.")
        session.run(tf.global_variables_initializer())
        return model

    # Load a previously saved model
    ckpt = tf.train.get_checkpoint_state(train_dir, latest_filename="checkpoint")
    print("train_dir", train_dir)

    if not (ckpt and ckpt.model_checkpoint_path):
        print("Could not find checkpoint. Aborting.")
        # ckpt may be None here, so the message must not dereference it.
        raise ValueError("No checkpoint found in {0}".format(train_dir))

    # FLAGS.load > 0 is guaranteed past the early return above.
    ckpt_name = os.path.join(train_dir, "checkpoint-{0}".format(FLAGS.load))
    if not os.path.isfile(ckpt_name + ".index"):
        raise ValueError(
            "Asked to load checkpoint {0}, but it does not seem to exist".format(
                FLAGS.load
            )
        )

    print("Loading model {0}".format(ckpt_name))
    # Initialize everything before restoring. Presumably this is so that
    # variables not covered by model.saver still get a value (TODO confirm).
    # The previous version wrapped this in an extra, unused tf.Session(); the
    # work always ran on `session`, so that second session is dropped.
    init_g = tf.global_variables_initializer()
    init_l = tf.local_variables_initializer()
    session.run(init_g)
    session.run(init_l)
    # Restore the requested checkpoint, not merely the latest one.
    model.saver.restore(session, ckpt_name)
    return model
# Top-left panel: plot of the true system via viz_1_1d.plot_1_1d_system.
ax1 = fig.add_subplot(2, 2, 1, projection='3d')
ax1.set_title("System Transition Function")
viz_1_1d.plot_1_1d_system(system, ax=ax1)

# Roll out one trajectory under random_control, then split the states into
# "current" (all but the last) and "next" (all but the first) sequences.
states, actions, rewards = system.sim_traj(random_control)
states_prime = states[1:]
states = states[:-1]

# One tuple per transition, in the layout model.fit() expects:
# (state, action, reward, next_state, False). The trailing False is presumably
# a terminal flag; confirm against LinearModel.fit.
data = [
    (
        np.array([state]),
        np.array([actions[i]]),
        rewards[i],
        np.array([states_prime[i]]),
        False,
    )
    for i, state in enumerate(states)
]
print("Fitting on {} datapoints".format(len(data)))

# Fit a 1-D linear model to the transitions and attach the system's state
# limits to it.
model = linear_model.LinearModel(1, 1)
model.fit(data)
model.limits_low = np.array([system.state_limits[0]])
model.limits_high = np.array([system.state_limits[1]])
print(
    "Model is A={}, B={} (not fitting reward function now)\n".format(
        model.A, model.B
    )
)

# Top-right panel: the fitted model, drawn over the system's action range.
ax2 = fig.add_subplot(2, 2, 2, projection='3d')
ax2.set_title("Linear Model of Transition Function")
viz_1_1d.plot_1_1d_model(
    model,
    ax=ax2,
    action_limit_low=system.action_limits[0],
    action_limit_high=system.action_limits[1],
)

# Compute optimal controller for fit model
system.reset()
def create_model(session, actions, batch_size):
    """Create a LinearModel and initialize it or load its parameters.

    The function has two paths: either build a model with freshly initialized
    variables (``FLAGS.load <= 0``), or restore a previously saved checkpoint
    ``train_dir/checkpoint-<FLAGS.load>``.

    Args:
      session: tensorflow session
      actions: list of string. Actions to train/test on (not used by this
        function; kept for interface compatibility).
      batch_size: integer. Number of examples in each batch

    Returns:
      model: The created (or loaded) model

    Raises:
      ValueError: if asked to load a model, but ``train_dir`` has no
        checkpoint state or the checkpoint specified by FLAGS.load cannot
        be found.
    """
    model = linear_model.LinearModel(
        FLAGS.linear_size,
        FLAGS.num_layers,
        FLAGS.residual,
        FLAGS.batch_norm,
        FLAGS.max_norm,
        batch_size,
        FLAGS.learning_rate,
        summaries_dir,
        FLAGS.predict_14,
        dtype=tf.float16 if FLAGS.use_fp16 else tf.float32,
    )

    if FLAGS.load <= 0:
        # Create a new model from scratch. TensorFlow variables must be
        # explicitly initialized in the session before their values can be used.
        print("Creating model with fresh parameters.")
        session.run(tf.global_variables_initializer())
        return model

    # Load a previously saved model, using the "checkpoint" state file in train_dir.
    ckpt = tf.train.get_checkpoint_state(train_dir, latest_filename="checkpoint")
    print("train_dir", train_dir)

    # ---------------------- checkpoint restoration ----------------------
    if not (ckpt and ckpt.model_checkpoint_path):
        print("Could not find checkpoint. Aborting.")
        # ckpt may be None here, so the message must not dereference it.
        raise ValueError("No checkpoint found in {0}".format(train_dir))

    # Check that the specific checkpoint exists, i.e. that
    # <train_dir>/checkpoint-<FLAGS.load>.index is a file.
    # FLAGS.load > 0 is guaranteed past the early return above.
    ckpt_name = os.path.join(train_dir, "checkpoint-{0}".format(FLAGS.load))
    if not os.path.isfile(ckpt_name + ".index"):
        raise ValueError(
            "Asked to load checkpoint {0}, but it does not seem to exist".format(
                FLAGS.load
            )
        )

    print("Loading model {0}".format(ckpt_name))
    # Load the saved weights into the current session. Restore the checkpoint
    # that was requested and verified above, not merely the latest one.
    model.saver.restore(session, ckpt_name)
    return model