def test_serializable():
    with suppress_params_loading():
        # Cloning should rebuild the object under the new name scope
        obj = Simple(name="obj")
        obj1 = Serializable.clone(obj, name="obj1")
        assert obj.w.name.startswith('obj/')
        assert obj1.w.name.startswith('obj1/')

        # Positional varargs and keyword arguments must survive a clone unchanged
        obj2 = AllArgs(0, *(1,), **{'kwarg': 2})
        obj3 = Serializable.clone(obj2)
        assert obj3.vararg == 0
        assert len(obj3.args) == 1 and obj3.args[0] == 1
        assert len(obj3.kwargs) == 1 and obj3.kwargs['kwarg'] == 2
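
# --- Hedged sketch: possible shapes of the fixtures the test above relies on.
# The real `Simple` and `AllArgs` are defined elsewhere in the repo; the class
# names below (suffixed "Sketch"), the base class, and the import path are
# assumptions for illustration only, following the usual rllab pattern of
# calling Serializable.quick_init with the constructor's locals().
import tensorflow as tf
from rllab.core.serializable import Serializable


class SimpleSketch(Serializable):
    """Creates one TF variable under `name/`, so clones get fresh scopes."""

    def __init__(self, name):
        Serializable.quick_init(self, locals())
        with tf.variable_scope(name):
            self.w = tf.get_variable("w", shape=(3, 3))


class AllArgsSketch(Serializable):
    """Records vararg/args/kwargs so Serializable.clone can replay them."""

    def __init__(self, vararg, *args, **kwargs):
        Serializable.quick_init(self, locals())
        self.vararg = vararg
        self.args = args
        self.kwargs = kwargs
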
def train(self):
    gc_dump_time = time.time()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # This seems like a rather sequential method
        pool = SimpleReplayPool(
            max_pool_size=self.replay_pool_size,
            observation_dim=self.env.observation_space.flat_dim,
            action_dim=self.env.action_space.flat_dim,
            replacement_prob=self.replacement_prob,
        )
        self.start_worker()

        self.init_opt()
        # Re-run the initializer to cover variables created by init_opt()
        sess.run(tf.global_variables_initializer())
        itr = 0
        path_length = 0
        path_return = 0
        terminal = False
        initial = False
        observation = self.env.reset()

        # Alternative (disabled): clone the sampling policy by pickling
        # with tf.variable_scope("sample_policy"):
        #     with suppress_params_loading():
        #         sample_policy = pickle.loads(pickle.dumps(self.policy))
        with tf.variable_scope("sample_policy"):
            sample_policy = Serializable.clone(self.policy)

        for epoch in range(self.n_epochs):
            logger.push_prefix('epoch #%d | ' % epoch)
            logger.log("Training started")
            train_qf_itr, train_policy_itr = 0, 0
            for epoch_itr in pyprind.prog_bar(range(self.epoch_length)):
                # Execute policy
                if terminal:  # or path_length > self.max_path_length:
                    # Note that if the last time step ends an episode, the very
                    # last state and observation will be ignored and not added
                    # to the replay pool
                    observation = self.env.reset()
                    self.es.reset()
                    sample_policy.reset()
                    self.es_path_returns.append(path_return)
                    path_length = 0
                    path_return = 0
                    initial = True
                else:
                    initial = False
                action = self.es.get_action(itr, observation, policy=sample_policy)  # qf=qf)

                next_observation, reward, terminal, _ = self.env.step(action)
                path_length += 1
                path_return += reward

                if not terminal and path_length >= self.max_path_length:
                    terminal = True
                    # Only include the horizon-induced terminal transition
                    # if the flag was set
                    if self.include_horizon_terminal_transitions:
                        pool.add_sample(observation, action,
                                        reward * self.scale_reward,
                                        terminal, initial)
                else:
                    pool.add_sample(observation, action,
                                    reward * self.scale_reward,
                                    terminal, initial)

                observation = next_observation

                if pool.size >= self.min_pool_size:
                    for update_itr in range(self.n_updates_per_sample):
                        # Train policy and Q-function on a random minibatch
                        batch = pool.random_batch(self.batch_size)
                        itrs = self.do_training(itr, batch)
                        train_qf_itr += itrs[0]
                        train_policy_itr += itrs[1]
                    # Refresh the behaviour policy with the learner's weights
                    sample_policy.set_param_values(self.policy.get_param_values())

                itr += 1
                # Periodically force garbage collection
                if time.time() - gc_dump_time > 100:
                    gc.collect()
                    gc_dump_time = time.time()

            logger.log("Training finished")
            logger.log("Trained qf %d steps, policy %d steps"
                       % (train_qf_itr, train_policy_itr))
            if pool.size >= self.min_pool_size:
                self.evaluate(epoch, pool)
                params = self.get_epoch_snapshot(epoch)
                logger.save_itr_params(epoch, params)
            logger.dump_tabular(with_prefix=False)
            logger.pop_prefix()
            if self.plot:
                self.update_plot()
                if self.pause_for_plot:
                    input("Plotting evaluation run: Press Enter to "
                          "continue...")
        self.env.terminate()
        self.policy.terminate()
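
# --- Hedged sketch: the behaviour-policy pattern train() relies on, isolated.
# A Serializable clone of the learner policy lives in its own variable scope
# (so its TF variables do not collide with the original's), acts in the
# environment, and is re-synced to the learner after each round of updates.
# The helper names below are hypothetical; only the calls they wrap
# (tf.variable_scope, Serializable.clone, get/set_param_values) appear in the
# training loop above.
def make_sample_policy(policy):
    # Clone under a dedicated scope, mirroring train() above.
    with tf.variable_scope("sample_policy"):
        return Serializable.clone(policy)


def sync_sample_policy(policy, sample_policy):
    # Copy the learner's flattened parameter vector into the behaviour policy,
    # as done after the inner n_updates_per_sample loop.
    sample_policy.set_param_values(policy.get_param_values())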