Exemplo n.º 1
0
def main():
    """Train offline from saved replay memory, or collect experience into it.

    Configuration comes from module-level names: ``args`` (``rom``,
    ``train``, ``train_total_steps``) and constants such as ``IMAGE_SIZE``,
    ``FRAME_SKIP``, ``MEMORY_SIZE``, ``CONTEXT_LEN``, ``GAMMA``,
    ``LEARNING_RATE`` and ``gpu_num``.

    If ``./model.ckpt`` exists, the agent is restored from it first.

    * ``args.train`` truthy: run ``run_train_step`` on the replay memory,
      then save the agent back to ``./model.ckpt``.
    * otherwise: run ``collect_exp`` in the environment and persist the
      replay memory with ``save_memory()``.
    """
    ckpt_path = './model.ckpt'

    player = get_player(
        args.rom, image_size=IMAGE_SIZE, train=True, frame_skip=FRAME_SKIP)
    # load_file=True: load replay memory data from "memory.npz".
    replay = ReplayMemory(
        MEMORY_SIZE,
        IMAGE_SIZE,
        CONTEXT_LEN,
        load_file=True,
        file_path="memory.npz")
    n_actions = player.action_space.n

    # Model -> algorithm -> agent, constructed in that order.
    agent = AtariAgent(
        DQN(AtariModel(n_actions),
            act_dim=n_actions,
            gamma=GAMMA,
            lr=LEARNING_RATE * gpu_num),
        act_dim=n_actions,
        total_step=args.train_total_steps)

    if os.path.isfile(ckpt_path):
        logger.info("load model from file")
        agent.restore(ckpt_path)

    if not args.train:
        logger.info("collect experience")
        collect_exp(player, replay, agent)
        replay.save_memory()
        logger.info("finish collecting, save successfully")
        return

    logger.info("train with memory data")
    run_train_step(agent, replay)
    logger.info("finish training. Save the model.")
    agent.save(ckpt_path)
Exemplo n.º 2
0
if __name__ == '__main__':
    # Evaluate the checkpoint saved at ./dqn_model.ckpt for `test_number`
    # episodes, logging each episode's reward and finally the mean reward.

    parser = argparse.ArgumentParser()
    parser.add_argument('--game_name', default='Phoenix-v0')
    # The flag used to be declared but never parsed, so the game was always
    # 'Phoenix-v0'. parse_known_args keeps tolerating any extra CLI flags,
    # which were silently ignored before.
    args, _ = parser.parse_known_args()

    test_env = get_player(args.game_name,
                          image_size=IMAGE_SIZE,
                          context_len=CONTEXT_LEN)
    save_path = './dqn_model.ckpt'

    act_dim = test_env.action_space.n

    model = AtariModel(act_dim)
    algorithm = parl.algorithms.DQN(model, act_dim=act_dim, gamma=GAMMA)

    agent = AtariAgent(algorithm,
                       act_dim=act_dim,
                       start_lr=LEARNING_RATE,
                       total_step=test_number,
                       update_freq=UPDATE_FREQ)

    agent.restore(save_path)

    eval_rewards = []
    for episode in range(test_number):
        eval_reward = run_evaluate_episode(test_env, agent)
        eval_rewards.append(eval_reward)
        logger.info("eval_agent done, (steps, eval_reward): ({}, {})".format(
            episode, eval_reward))

    # Summarize the run; guarded so test_number == 0 does not divide by zero.
    if eval_rewards:
        logger.info("mean eval_reward over {} episodes: {}".format(
            len(eval_rewards), sum(eval_rewards) / len(eval_rewards)))
Exemplo n.º 3
0
def main():
    """Train a prioritized-replay DQN / Double-DQN agent on pseudoslam.

    Configuration comes from the module-level ``args`` namespace (``goal``,
    ``obs``, ``batch_size``, ``suffix``, ``Rp``, ``logdir``, ``modeldir``,
    ``alg``, ``load``, ``train_total_steps``, ``test_every_steps``) and
    module constants (``MEMORY_SIZE``, ``MEMORY_WARMUP_SIZE``, ``GAMMA``,
    ``LEARNING_RATE``, ``UPDATE_FREQ``).

    Steps:

    1. Build the ``MonitorEnv``-wrapped environment and a ``ProportionalPER``
       buffer.
    2. Create per-run log and checkpoint directories, named from
       suffix/Rp/goal/obs.
    3. Optionally restore the agent from ``args.load``.
    4. Warm up the replay memory with ``MEMORY_WARMUP_SIZE`` transitions.
    5. Train until ``args.train_total_steps`` environment steps. Scalars are
       logged to ``summary``. The agent is evaluated (mean of 3 episodes)
       each time a multiple of ``args.test_every_steps`` is crossed, and a
       checkpoint is saved every 100000 steps.

    Raises:
        ValueError: if ``args.alg`` is neither ``'ddqn'`` nor ``'dqn'``.
    """
    save_every_steps = 100000
    num_eval_episodes = 3

    env = gym.make("pseudoslam:RobotExploration-v0")
    env = MonitorEnv(env, param={'goal': args.goal, 'obs': args.obs})

    # Prioritized replay memory (seg_num is presumably one sampling segment
    # per batch element -- confirm against ProportionalPER).
    per = ProportionalPER(alpha=0.6, seg_num=args.batch_size, size=MEMORY_SIZE)

    suffix = args.suffix + "_Rp{}_Goal{}_Obs{}".format(args.Rp, args.goal,
                                                       args.obs)
    logdir = os.path.join(args.logdir, suffix)
    # makedirs also creates missing parents (e.g. args.logdir itself), which
    # os.mkdir would fail on.
    os.makedirs(logdir, exist_ok=True)
    logger.set_dir(logdir)
    modeldir = os.path.join(args.modeldir, suffix)
    os.makedirs(modeldir, exist_ok=True)

    # Prepare PARL agent.
    act_dim = env.action_space.n
    model = AtariModel(act_dim)
    if args.alg == 'ddqn':
        algorithm_cls = PrioritizedDoubleDQN
    elif args.alg == 'dqn':
        algorithm_cls = PrioritizedDQN
    else:
        # Previously an unknown value left `algorithm` unbound -> NameError.
        raise ValueError(
            "unknown alg {!r}; expected 'ddqn' or 'dqn'".format(args.alg))
    algorithm = algorithm_cls(model,
                              act_dim=act_dim,
                              gamma=GAMMA,
                              lr=LEARNING_RATE)
    agent = AtariAgent(algorithm, act_dim=act_dim, update_freq=UPDATE_FREQ)
    # os.path.exists(None) raises TypeError, so guard an unset load path.
    if args.load and os.path.exists(args.load):
        agent.restore(args.load)

    # Replay memory warm-up: run whole episodes until at least
    # MEMORY_WARMUP_SIZE steps have been taken. The bar's total matches the
    # loop bound (it previously used MEMORY_SIZE and never completed).
    warmup_steps = 0
    mem = []
    with tqdm(total=MEMORY_WARMUP_SIZE,
              desc='[Replay Memory Warm Up]') as pbar:
        while warmup_steps < MEMORY_WARMUP_SIZE:
            _, steps, _, _ = run_episode(env,
                                         agent,
                                         per,
                                         mem=mem,
                                         warmup=True)
            warmup_steps += steps
            pbar.update(steps)
    # The last episode can overshoot; keep at most MEMORY_WARMUP_SIZE entries.
    per.elements.from_list(mem[:int(MEMORY_WARMUP_SIZE)])

    # test_flag is the index of the next evaluation milestone
    # (milestone k is reached when total_steps // test_every_steps >= k).
    test_flag = 0
    total_steps = 0
    save_steps = 0
    with tqdm(total=args.train_total_steps) as pbar:
        while total_steps < args.train_total_steps:
            total_reward, steps, loss, info = run_episode(env,
                                                          agent,
                                                          per,
                                                          train=True)
            total_steps += steps
            save_steps += steps
            pbar.set_description('[train]exploration:{}'.format(
                agent.exploration))
            summary.add_scalar('train/score', total_reward, total_steps)
            summary.add_scalar('train/loss', loss,
                               total_steps)  # mean of total loss
            summary.add_scalar('train/exploration', agent.exploration,
                               total_steps)
            summary.add_scalar('train/steps', steps, total_steps)
            for key, value in info.items():
                summary.add_scalar('train/' + key, value, total_steps)
            pbar.update(steps)

            milestone = total_steps // args.test_every_steps
            if milestone >= test_flag:
                # Evaluate once even if several milestones were crossed in
                # one episode, then skip past all of them.
                test_flag = milestone + 1
                # pbar.write keeps the live progress bar intact (print
                # does not).
                pbar.write('start test!')
                pbar.write("testing")
                test_rewards = []
                for _ in tqdm(range(num_eval_episodes), desc='eval agent'):
                    test_rewards.append(run_evaluate_episode(env, agent))
                eval_reward = np.mean(test_rewards)
                logger.info(
                    "eval_agent done, (steps, eval_reward): ({}, {})".format(
                        total_steps, eval_reward))
                summary.add_scalar('eval/reward', eval_reward, total_steps)

            if save_steps >= save_every_steps:
                modeldir_ = os.path.join(modeldir,
                                         'itr_{}'.format(total_steps))
                os.makedirs(modeldir_, exist_ok=True)
                pbar.write('save model! {}'.format(modeldir_))
                agent.save(modeldir_)
                save_steps = 0