Example 1
# Checkpoint paths for the expert models (deep DQN, ResNet DQN, policy-gradient network)
experts_chkpt_lis = [
    './core/checkpoints/q_learning/skip_connection/q5_train_atari_nature/deepdqn_weights/',
    './core/checkpoints/q_learning/skip_connection/q5_train_atari_nature/resnet_weights/',
    './core/checkpoints/policy_gradients/policy_network.ckpt',
]
experts = []

# experts_meta_lis is assumed to be defined earlier with the matching meta-graph paths
for meta_path, chkpt_path in zip(experts_meta_lis, experts_chkpt_lis):
    # (debug) inspect the nodes currently in the default graph:
    # print([n.name for n in tf.get_default_graph().as_graph_def().node])
    if "deepdqn" in meta_path:
        model = NatureQN(env, config)
    elif "resnet" in meta_path:
        model = ResnetQN(env, config)
    elif "policy" in meta_path:
        # skip the policy-gradient checkpoint in this loop
        continue
    model.initialize(meta_path, chkpt_path)
    experts.append(model)

print("LOADED ALL MODELS")

for i in range(len(experts)):
    guide = experts[i]
    guide_experience = [[]]
    num_points = 0
    state = env.reset()
    guide_replay_buffer = ReplayBuffer(
        config.buffer_size, config.state_history)
    while True:
        # store the last state in the replay buffer
        idx = guide_replay_buffer.store_frame(state)
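        # --- The original example is cut off here. A plausible continuation
        # --- (an assumption, following the usual DQN replay-buffer rollout) would
        # --- query the expert for an action, step the environment, and record the
        # --- transition. `get_best_action`, `encode_recent_observation` and
        # --- `store_effect` are assumed to match the standard starter-code API.
        q_input = guide_replay_buffer.encode_recent_observation()
        action, _ = guide.get_best_action(q_input)
        new_state, reward, done, info = env.step(action)
        guide_replay_buffer.store_effect(idx, action, reward, done)
        guide_experience[-1].append((state, action, reward, done))
        num_points += 1
        state = new_state
        if done:
            state = env.reset()
            guide_experience.append([])
        if num_points >= config.buffer_size:
            break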
Example 2
If so, please report your hyperparameters.

You'll find the results, logs and video recordings of your agent every 250k steps
under the corresponding files in the results folder. A good way to monitor the
progress of training is to use TensorBoard. The starter code writes summaries of
different variables.

To launch TensorBoard, open a terminal window and run
tensorboard --logdir=results/
Then, connect remotely to
address-ip-of-the-server:6006
6006 is the default port used by TensorBoard.
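If the port is not reachable directly from your machine, one common option (not part
of the starter code) is to forward it over SSH, for example
ssh -L 6006:localhost:6006 your-username@address-ip-of-the-server
(both placeholders stand in for your own credentials and server), and then open
http://localhost:6006 in your local browser.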
"""
if __name__ == '__main__':
    # make env
    env = gym.make(config.env_name)
    env = MaxAndSkipEnv(env, skip=config.skip_frame)
    env = PreproWrapper(env,
                        prepro=greyscale,
                        shape=(80, 80, 1),
                        overwrite_render=config.overwrite_render)

    # load model
    model = NatureQN(env, config)
    model.initialize()
    loaded = load_model(model)
    assert loaded is not False, "Loading failed"

    # evaluate one episode of data
    model.evaluate(env, 1)
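
# A minimal sketch of what the `load_model` helper called above might look like,
# assuming TF1-style checkpoints and that the model exposes its session as
# `model.sess`. This is an illustration under those assumptions, not the project's
# actual implementation; the default checkpoint directory below is hypothetical.
import tensorflow as tf

def load_model(model, chkpt_dir='./results/model_weights/'):
    # look up the latest checkpoint in the (hypothetical) directory
    ckpt = tf.train.get_checkpoint_state(chkpt_dir)
    if ckpt is None or not ckpt.model_checkpoint_path:
        return False
    # restore the saved variables into the model's session
    saver = tf.train.Saver()
    saver.restore(model.sess, ckpt.model_checkpoint_path)
    return True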