示例#1
0
def main():
    """Train a dueling DQN on Pong and save the resulting model.

    Logs go to the directory chosen by ``logger.configure()``; the trained
    model is written to ``pong_model.pkl`` in the working directory.
    """
    logger.configure()

    # Monitor wraps the raw emulator env; the DQN preprocessing wrappers are
    # layered on top of it.
    pong = make_atari('PongNoFrameskip-v4')
    pong = bench.Monitor(pong, logger.get_dir())
    pong = deepq.wrap_atari_dqn(pong)

    hyperparams = dict(
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        hiddens=[256],
        dueling=True,
        lr=1e-4,
        total_timesteps=int(1e7),
        buffer_size=10000,
        exploration_fraction=0.1,
        exploration_final_eps=0.01,
        train_freq=4,
        learning_starts=10000,
        target_network_update_freq=1000,
        gamma=0.99,
    )
    trained = deepq.learn(pong, "conv_only", **hyperparams)

    trained.save('pong_model.pkl')
    pong.close()
示例#2
0
def main(name, size, num_episodes=100):
    """Evaluate a saved DQN model and append its episode rewards to a CSV.

    Loads ``logs/<name>NoFrameskip-v4_<size>/model.pkl``, plays
    ``num_episodes`` episodes with the policy returned by ``deepq.load``,
    prints each episode's reward and appends one line
    ``name,size,r1,r2,...`` to ``validationStats.csv``.

    Args:
        name: Atari game name without the ``NoFrameskip-v4`` suffix (str).
        size: identifies the training-run directory; written to the CSV
            row via ``str(size)``.
        num_episodes: number of evaluation episodes (default 100).
    """
    env_name = name + "NoFrameskip-v4"
    model_location = "logs/" + env_name + "_" + str(size) + "/model.pkl"

    env = make_atari(env_name)
    env = deepq.wrap_atari_dqn(env)
    act = deepq.load(model_location)
    episode_rewards = []

    try:
        for _ in range(num_episodes):
            obs, done = env.reset(), False
            episode_rew = 0
            while not done:
                action = act(obs[None])[0]
                obs, rew, done, _ = env.step(action)
                episode_rew += rew
            print(episode_rew)
            episode_rewards.append(episode_rew)
    finally:
        # Always release the environment, even if evaluation is interrupted.
        env.close()

    # One CSV line per evaluated model: name, size, then every episode reward.
    row = ",".join([name, str(size)] + [str(r) for r in episode_rewards])
    with open("validationStats.csv", 'a') as stats_file:
        stats_file.write(row + "\n")
示例#3
0
def main():
    """Train a DQN agent on an Atari game from command-line settings.

    Flags: ``--env`` (gym id), ``--seed``, ``--prioritized`` and ``--dueling``
    (ints interpreted as booleans), and ``--num-timesteps``. Uses a
    three-layer conv net with one 256-unit hidden layer. The trained policy
    is NOT saved.
    """
    cli = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    cli.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4')
    cli.add_argument('--seed', help='RNG seed', type=int, default=0)
    cli.add_argument('--prioritized', type=int, default=1)
    cli.add_argument('--dueling', type=int, default=1)
    cli.add_argument('--num-timesteps', type=int, default=int(10e6))
    opts = cli.parse_args()

    logger.configure()
    set_global_seeds(opts.seed)

    # Monitor wraps the raw emulator env; DQN preprocessing goes on top.
    env = deepq.wrap_atari_dqn(
        bench.Monitor(make_atari(opts.env), logger.get_dir())
    )

    q_func = deepq.models.cnn_to_mlp(
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        hiddens=[256],
        dueling=bool(opts.dueling),
    )
    deepq.learn(
        env,
        q_func=q_func,
        lr=1e-4,
        max_timesteps=opts.num_timesteps,
        buffer_size=10000,
        exploration_fraction=0.1,
        exploration_final_eps=0.01,
        train_freq=4,
        learning_starts=10000,
        target_network_update_freq=1000,
        gamma=0.99,
        prioritized_replay=bool(opts.prioritized),
    )
    # Saving is disabled; the original left `act.save("pong_model.pkl")`
    # commented out with an XXX marker.
    env.close()
示例#4
0
def test_deepq():
    """Smoke test: run DeepQ with prioritized replay on an Atari environment.

    Trains for ``NUM_TIMESTEPS`` steps on ``ENV_ID`` with seed ``SEED``
    (module-level constants defined elsewhere). The test has no assertions:
    it passes as long as training completes without raising.
    """
    clear_tf_session()
    logger.configure()
    set_global_seeds(SEED)

    env = deepq.wrap_atari_dqn(
        bench.Monitor(make_atari(ENV_ID), logger.get_dir()))

    learn_kwargs = dict(
        q_func=deepq.models.cnn_to_mlp(
            convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
            hiddens=[256],
            dueling=True,
        ),
        learning_rate=1e-4,
        max_timesteps=NUM_TIMESTEPS,
        buffer_size=10000,
        exploration_fraction=0.1,
        exploration_final_eps=0.01,
        train_freq=4,
        learning_starts=10000,
        target_network_update_freq=1000,
        gamma=0.99,
        prioritized_replay=True,
        prioritized_replay_alpha=0.6,
        checkpoint_freq=10000,
    )
    deepq.learn(env, **learn_kwargs)

    env.close()
示例#5
0
def main():
    """Train DQN on Pong using the ``defaults.atari()`` hyperparameters.

    Logs (stdout, text log, CSV and TensorBoard) and checkpoints go to
    ``./runs/pong``. The final model is saved as ``pong_model.pkl``.
    """
    exp_dir = './runs/pong'

    # Pass an explicit directory: by default the logger writes its output
    # (including CSV logs) to the OS temp directory.
    logger.configure(dir=exp_dir,
                     format_strs=['stdout', 'log', 'csv', 'tensorboard'],
                     log_suffix=None)

    # Atari environment with no-op reset and max-pooling over the last two frames.
    env = make_atari('PongNoFrameskip-v4')
    # Monitor logs episode rewards to the logger directory.
    env = bench.Monitor(env, logger.get_dir())
    env = deepq.wrap_atari_dqn(env)

    learn_params = defaults.atari()
    overrides = {
        'checkpoint_path': exp_dir,
        'checkpoint_freq': 100000,
        'print_freq': 10,
    }
    for key, value in overrides.items():
        learn_params[key] = value

    # The network architecture is not passed here. Per the original note, the
    # defaults already use convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)] and
    # hiddens=[256].
    model = deepq.learn(env, total_timesteps=int(1e7), **learn_params)

    model.save('pong_model.pkl')
    env.close()
示例#6
0
def main():
    """Train a non-dueling DQN for 5e7 steps and save it as ``<game>_model.pkl``.

    Depends on two module-level globals defined elsewhere:

    * ``game`` names the log directory (``<game>_train_log``), the
      checkpoint directory (``<game>_checkpoints``) and the saved model.
    * ``name`` selects the Atari environment (``<name>NoFrameskip-v4``).
    """
    # NOTE(review): two different globals (`game`, `name`) identify the game.
    # If they ever differ, artifacts are labelled with a different game than
    # the one trained. Confirm they are meant to hold the same value.
    logger.configure(dir=game + "_train_log")
    env = deepq.wrap_atari_dqn(
        bench.Monitor(make_atari(name + 'NoFrameskip-v4'), logger.get_dir()))

    hyperparameters = {
        'convs': [(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        'hiddens': [512],
        'dueling': False,
        'lr': 0.00025,
        'total_timesteps': int(5e7),
        'buffer_size': 1000000,
        'exploration_fraction': 0.02,
        'exploration_final_eps': 0.1,
        'train_freq': 4,
        'learning_starts': 50000,
        'target_network_update_freq': 10000,
        'gamma': 0.99,
        'print_freq': 1000,
        'checkpoint_path': game + "_checkpoints",
    }
    model = deepq.learn(env, "conv_only", **hyperparameters)

    model.save(game + '_model.pkl')
    env.close()
示例#7
0
def main():
    """Train a dueling DQN agent on Pong for 1e7 steps and save it.

    The model is written to ``pong_model.pkl`` in the working directory.
    Logs go wherever ``logger.configure()`` points.
    """
    logger.configure()
    env = deepq.wrap_atari_dqn(
        bench.Monitor(make_atari('PongNoFrameskip-v4'), logger.get_dir()))

    # Presumably (filters, kernel size, stride) per conv layer -- confirm.
    conv_layers = [(32, 8, 4), (64, 4, 2), (64, 3, 1)]
    model = deepq.learn(
        env,
        "conv_only",
        convs=conv_layers,
        hiddens=[256],
        dueling=True,
        lr=1e-4,
        total_timesteps=int(1e7),
        buffer_size=10000,
        exploration_fraction=0.1,
        exploration_final_eps=0.01,
        train_freq=4,
        learning_starts=10000,
        target_network_update_freq=1000,
        gamma=0.99,
    )

    model.save('pong_model.pkl')
    env.close()
示例#8
0
    def get_player(self, train=False):
        """Return an environment for this agent.

        If ``self.env`` is already set (truthy), it is returned unchanged.
        Otherwise a new environment named ``self.config['ENV_NAME']`` is
        built according to ``self.config['ENV_TYPE']``:

        * ``'Classic'``: plain ``gym.make``.
        * ``'Atari'``: DQN Atari wrappers are applied. When ``train`` is true,
          the raw env comes from ``make_atari`` and is wrapped in
          ``bench.Monitor`` logging to ``self.logger.get_dir()``. Otherwise
          it comes from ``gym.make``.

        A newly built environment is returned but not stored on ``self``.

        Raises:
            Exception: if ``ENV_TYPE`` is neither ``'Classic'`` nor ``'Atari'``.
        """
        if self.env:
            # Return the cached environment. Referencing the bare local `env`
            # here raised UnboundLocalError.
            return self.env

        env_type = self.config['ENV_TYPE']
        if env_type == 'Classic':
            env = gym.make(self.config['ENV_NAME'])
        elif env_type == 'Atari':
            if train:
                env = make_atari(self.config['ENV_NAME'])
                env = bench.Monitor(env, self.logger.get_dir())
            else:
                env = gym.make(self.config['ENV_NAME'])
            env = deepq.wrap_atari_dqn(env)
        else:
            raise Exception('Environment Type %s - Not Supported' % env_type)
        return env
示例#9
0
def make_env(game_name):
    """Create a DQN-ready Atari environment together with its Monitor.

    Returns a ``(env, monitored_env)`` pair:

    * ``monitored_env`` is the ``bench.Monitor`` around the raw
      ``<game_name>NoFrameskip-v4`` environment.
    * ``env`` is that monitor with the DQN Atari wrappers applied on top.
    """
    # The Monitor goes directly on the raw env so it records rewards and step
    # counts (in `info`) before the observation-simplifying wrappers run.
    monitor = bench.Monitor(gym.make(game_name + "NoFrameskip-v4"), logger.get_dir())
    # Downsamples / converts observations to black-and-white, among others.
    wrapped = deepq.wrap_atari_dqn(monitor)
    return wrapped, monitor
示例#10
0
def train():
    """Train a DQN agent on an Atari game configured by the global ``args``.

    Steps:

    1. Create a run directory ``<args.log_dir>/<args.env>_<MMDDHHMM>``.
    2. Store ``vars(args)`` there as ``learning_prop.json``.
    3. Train with ``deepq.learn`` under ``tf.device(args.device)``.
    4. Save the policy as ``model.pkl`` in the run directory.
    5. Plot the returned training records.

    Raises:
        ValueError: if the run directory already exists (e.g. two runs
            started within the same minute), so an earlier run's files are
            not overwritten.
    """
    logger.configure()
    set_global_seeds(args.seed)

    timestamp = datetime.datetime.now().strftime("%m%d%H%M")
    directory = os.path.join(args.log_dir, '_'.join([args.env, timestamp]))
    if os.path.exists(directory):
        # This error used to be constructed but never raised, so an existing
        # run directory was silently reused and overwritten.
        raise ValueError("The directory already exists...", directory)
    os.makedirs(directory)
    # `with` closes the file deterministically; a bare open() leaked the handle.
    with open(os.path.join(directory, 'learning_prop.json'), 'w') as prop_file:
        json.dump(vars(args), prop_file)

    env = make_atari(args.env)
    env = bench.Monitor(env, logger.get_dir())
    env = deepq.wrap_atari_dqn(env)

    # Non-positive values are passed to deepq.learn as None.
    nb_test_steps = args.nb_test_steps if args.nb_test_steps > 0 else None
    if args.record == 1:
        env = Monitor(env, directory=args.log_dir)
    with tf.device(args.device):
        model = deepq.models.cnn_to_mlp(
            convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
            hiddens=[256],
            dueling=bool(args.dueling),
        )

        act, records = deepq.learn(
            env,
            q_func=model,
            lr=args.learning_rate,
            max_timesteps=args.nb_train_steps,
            buffer_size=args.buffer_size,
            exploration_fraction=args.eps_fraction,
            exploration_final_eps=args.eps_min,
            train_freq=4,
            print_freq=1000,
            checkpoint_freq=int(args.nb_train_steps/10),
            learning_starts=args.nb_warmup_steps,
            target_network_update_freq=args.target_update_freq,
            gamma=0.99,
            prioritized_replay=bool(args.prioritized),
            prioritized_replay_alpha=args.prioritized_replay_alpha,
            epoch_steps=args.nb_epoch_steps,
            gpu_memory=args.gpu_memory,
            double_q=args.double_q,
            directory=directory,
            nb_test_steps=nb_test_steps,
            scope=args.scope,
        )
        print("Saving model to model.pkl")
        act.save(os.path.join(directory, "model.pkl"))
    env.close()
    plot(records, directory)
示例#11
0
def main():
    """Train DQN (optionally prioritized/dueling) with periodic checkpoints.

    Flags:

    * ``--env``, ``--seed``
    * ``--prioritized``, ``--prioritized-replay-alpha``, ``--dueling``
    * ``--num-timesteps``
    * ``--checkpoint-freq``, ``--checkpoint-path``

    All of these are forwarded to ``deepq.learn``. The trained policy is not
    saved explicitly.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4')
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    parser.add_argument('--prioritized', type=int, default=1)
    parser.add_argument('--prioritized-replay-alpha', type=float, default=0.6)
    parser.add_argument('--dueling', type=int, default=1)
    parser.add_argument('--num-timesteps', type=int, default=int(10e6))
    parser.add_argument('--checkpoint-freq', type=int, default=10000)
    # NOTE(review): this default is an absolute path on one developer's
    # machine (a previous default was None). Pass --checkpoint-path
    # explicitly when running elsewhere.
    parser.add_argument(
        '--checkpoint-path',
        type=str,
        default='/home/yangxu/PycharmProjects/ros_inwork/baselines/deepq/experiments/save',
    )
    args = parser.parse_args()

    logger.configure()
    set_global_seeds(args.seed)
    env = make_atari(args.env)
    env = bench.Monitor(env, logger.get_dir())
    env = deepq.wrap_atari_dqn(env)

    # Three conv layers followed by two fully connected layers. The first FC
    # layer is num_convs*hiddens and the second is hiddens*num_action.
    # NOTE(review): with dueling enabled, cnn_to_mlp presumably also builds a
    # separate state-value stream -- confirm.
    q_network = deepq.models.cnn_to_mlp(
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        hiddens=[256],
        dueling=bool(args.dueling),
    )

    train_config = dict(
        lr=1e-4,
        max_timesteps=args.num_timesteps,
        buffer_size=10000,
        exploration_fraction=0.1,
        exploration_final_eps=0.01,
        train_freq=4,
        learning_starts=10000,
        target_network_update_freq=1000,
        gamma=0.99,
        prioritized_replay=bool(args.prioritized),
        prioritized_replay_alpha=args.prioritized_replay_alpha,
        checkpoint_freq=args.checkpoint_freq,
        checkpoint_path=args.checkpoint_path,
    )
    deepq.learn(env, q_func=q_network, **train_config)

    env.close()
示例#12
0
def main():
    """Watch a saved Pong DQN policy play forever, printing each episode's reward.

    Loads ``pong_model.pkl`` and renders every step. Never returns.
    """
    env = deepq.wrap_atari_dqn(gym.make("PongNoFrameskip-v4"))
    policy = deepq.load("pong_model.pkl")

    while True:
        obs = env.reset()
        done = False
        total = 0
        while not done:
            env.render()
            action = policy(obs[None])[0]
            obs, reward, done, _ = env.step(action)
            total += reward
        print("Episode reward", total)
def main():
    """Train DQN with a scheduled exploration rate and learning rate.

    Flags: ``--env`` (game name without the ``NoFrameskip-v4`` suffix),
    ``--seed``, ``--prioritized``, ``--num-timesteps`` and the positional
    ``experiment_id``.

    Output goes below ``./experiments/<experiment_id>--<env>``:

    * stdout/TensorBoard/JSON logs at the top level;
    * models in ``models/``, with the final policy saved as
      ``models/act_model.pkl``.

    Exploration follows the endpoints (0, 1) -> (1e6, 0.1) -> (5e6, 0.01);
    interpolation is presumably linear (PiecewiseSchedule) -- confirm.
    ``learn`` is told to move the learning rate from 1e-4 at step 1e6 to
    5e-5 at step 5e6.
    """
    def ensure_dir(path):
        # Create `path` (with parents) unless it already exists; return it.
        if not path.exists():
            path.mkdir(parents=True)
        return path

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--env', help='environment ID', default='Breakout')
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    parser.add_argument('--prioritized', type=int, default=1)
    parser.add_argument('--num-timesteps', type=int, default=int(10e6))
    parser.add_argument('experiment_id')
    args = parser.parse_args()

    logging_directory = ensure_dir(
        Path('./experiments/{}--{}'.format(args.experiment_id, args.env)))
    logger.configure(str(logging_directory), ['stdout', 'tensorboard', 'json'])
    model_directory = ensure_dir(logging_directory / 'models')
    set_global_seeds(args.seed)

    env = make_atari(args.env + "NoFrameskip-v4")
    env = bench.Monitor(env, logger.get_dir())
    env = deepq.wrap_atari_dqn(env)

    q_func = models.cnn_to_mlp(
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        hiddens=[256],
    )
    exploration = PiecewiseSchedule(
        endpoints=[(0, 1), (1e6, 0.1), (5 * 1e6, 0.01)], outside_value=0.01)

    act = learn(
        env,
        q_func=q_func,
        beta1=0.9,
        beta2=0.99,
        epsilon=1e-4,
        max_timesteps=args.num_timesteps,
        buffer_size=1000000,
        exploration_schedule=exploration,
        start_lr=1e-4,
        end_lr=5 * 1e-5,
        start_step=1e6,
        end_step=5 * 1e6,
        train_freq=4,
        print_freq=10,
        batch_size=32,
        learning_starts=50000,
        target_network_update_freq=10000,
        gamma=0.99,
        prioritized_replay=bool(args.prioritized),
        model_directory=model_directory,
    )
    act.save(str(model_directory / "act_model.pkl"))
    env.close()
def main():
    """Evaluate a latency-trained Pong model under a given input latency.

    The model file is chosen by how it was trained:

    * ``data.L<n>/PongNoFrameskip-v4.L<n>.pkl`` for a fixed latency
      ``--trained_with_latency`` (used when ``--trained_with_all_latency_mode``
      is 0);
    * ``data.LM<m>/PongNoFrameskip-v4.LM<m>.pkl`` otherwise.

    The model plays ``--game_runs`` episodes. For each episode, the reward,
    the wall-clock duration in seconds and the frame count are written to
    ``<model file>.on-L<game_latency>.csv`` and echoed to stdout.

    Input latency is simulated with a FIFO pre-filled with ``--game_latency``
    copies of action 0. Each step the policy's new action is enqueued and
    the oldest queued action is the one sent to the environment.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--game_latency", type=int, default=0, help="Input-Latency used in game run.")
    parser.add_argument("--trained_with_latency", type=int, default=0, help="Input-Latency the model was trained with.")
    parser.add_argument("--trained_with_all_latency_mode", type=int, default=0, help="Input-Latency-All-Mode the model was trained with.")
    parser.add_argument("--game_runs", type=int, default=25, help="How often the game should be run.")
    args = parser.parse_args()

    if args.trained_with_all_latency_mode == 0:
        tag = "L" + str(args.trained_with_latency)
    else:
        tag = "LM" + str(args.trained_with_all_latency_mode)
    modelfile = "data." + tag + "/PongNoFrameskip-v4." + tag + ".pkl"
    # The CSV lives next to the model and records the latency it was run at.
    csvfile = modelfile + ".on-L" + str(args.game_latency) + ".csv"

    env = make_atari("PongNoFrameskip-v4")
    env = deepq.wrap_atari_dqn(env)
    act = deepq.load(modelfile)

    try:
        with open(csvfile, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(["Reward", "Time", "Frames"])
            for _ in range(args.game_runs):
                obs, done = env.reset(), False
                timeslice_start = datetime.datetime.now()
                frames = 0
                # Pending actions; action 0 is presumably NOOP in ALE -- confirm.
                q_input = deque([0] * args.game_latency)

                episode_rew = 0
                while not done:
                    # Uncomment env.render() / time.sleep(0.01) here to watch the game.
                    q_input.append(act(obs[None])[0])
                    obs, rew, done, _ = env.step(q_input.popleft())
                    frames += 1
                    episode_rew += rew
                seconds = (datetime.datetime.now() - timeslice_start).total_seconds()
                writer.writerow([episode_rew, seconds, frames])
                # Flush per episode so partial results survive an interrupted run.
                f.flush()
                print(str(episode_rew) + ", " + str(seconds) + ", " + str(frames))
    finally:
        env.close()
示例#15
0
def main():
    """Run the Atari DQN training example inside an explicit TF session.

    Flags select the environment, seed, prioritized replay (and its alpha),
    dueling heads, timestep budget and checkpoint frequency/path. Training
    runs with a ``tf.Session`` entered as a context manager.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4')
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    parser.add_argument('--prioritized', type=int, default=1)
    parser.add_argument('--prioritized-replay-alpha', type=float, default=0.6)
    parser.add_argument('--dueling', type=int, default=1)
    parser.add_argument('--num-timesteps', type=int, default=int(10e6))
    parser.add_argument('--checkpoint-freq', type=int, default=10000)
    parser.add_argument('--checkpoint-path', type=str, default=None)
    options = parser.parse_args()

    logger.configure()
    set_global_seeds(options.seed)

    env = deepq.wrap_atari_dqn(
        bench.Monitor(make_atari(options.env), logger.get_dir()))
    q_func = deepq.models.cnn_to_mlp(
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        hiddens=[256],
        dueling=bool(options.dueling),
    )
    settings = {
        'q_func': q_func,
        'learning_rate': 1e-4,
        'max_timesteps': options.num_timesteps,
        'buffer_size': 10000,
        'exploration_fraction': 0.1,
        'exploration_final_eps': 0.01,
        'train_freq': 4,
        'learning_starts': 10000,
        'target_network_update_freq': 1000,
        'gamma': 0.99,
        'prioritized_replay': bool(options.prioritized),
        'prioritized_replay_alpha': options.prioritized_replay_alpha,
        'checkpoint_freq': options.checkpoint_freq,
        'checkpoint_path': options.checkpoint_path,
    }

    with tf.Session():
        deepq.learn(env, **settings)

    env.close()
示例#16
0
def main():
    """Train DQN with simulated input latency and save the policy.

    ``--train-with-latency`` and ``--train-with-all-latency-mode`` are
    forwarded to ``deepq.learn``. Logs and the final model go to
    ``./data.<id>/``, where ``<id>`` is:

    * ``LM<mode>`` when the all-latency mode is non-zero;
    * ``L<latency>`` otherwise.

    The model is saved as ``<env>.<id>.pkl``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4')
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    parser.add_argument('--prioritized', type=int, default=1)
    parser.add_argument('--dueling', type=int, default=1)
    parser.add_argument('--num-timesteps', type=int, default=int(10e6))
    parser.add_argument('--train-with-latency', type=int, default=0)
    parser.add_argument('--train-with-all-latency-mode', type=int, default=0)
    args = parser.parse_args()

    # The all-latency mode takes precedence over a fixed latency when naming the run.
    if args.train_with_all_latency_mode != 0:
        loggerid = "LM" + str(args.train_with_all_latency_mode)
    else:
        loggerid = "L" + str(args.train_with_latency)
    loggerdir = "./data." + loggerid + "/"

    logger.configure(dir=loggerdir)
    set_global_seeds(args.seed)
    env = deepq.wrap_atari_dqn(
        bench.Monitor(make_atari(args.env), logger.get_dir()))

    q_func = deepq.models.cnn_to_mlp(
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        hiddens=[256],
        dueling=bool(args.dueling),
    )
    act = deepq.learn(
        env,
        q_func=q_func,
        lr=1e-4,
        max_timesteps=args.num_timesteps,
        buffer_size=10000,
        exploration_fraction=0.1,
        exploration_final_eps=0.01,
        train_freq=4,
        learning_starts=10000,
        target_network_update_freq=1000,
        gamma=0.99,
        prioritized_replay=bool(args.prioritized),
        print_freq=1,
        train_with_latency=args.train_with_latency,
        train_with_all_latency_mode=args.train_with_all_latency_mode,
    )
    act.save(loggerdir + args.env + "." + loggerid + ".pkl")
    env.close()
示例#17
0
def main():
    """Show a Pong DQN agent playing forever, printing each episode's reward.

    The network is built via ``deepq.learn`` with ``total_timesteps=0``
    (no training steps). No weights are loaded here.
    """
    # NOTE(review): nothing is loaded, so unless deepq.learn restores weights
    # on its own, the agent acts with freshly initialized parameters.
    # Consider passing a load_path -- confirm.
    env = gym.make("PongNoFrameskip-v4")
    env = deepq.wrap_atari_dqn(env)
    model = deepq.learn(env,
                        "conv_only",
                        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
                        hiddens=[256],
                        dueling=True,
                        total_timesteps=0)

    def play_episode():
        # Run one rendered episode with the model's chosen actions; return its reward.
        obs = env.reset()
        done, total = False, 0
        while not done:
            env.render()
            obs, rew, done, _ = env.step(model(obs[None])[0])
            total += rew
        return total

    while True:
        print("Episode reward", play_episode())
示例#18
0
def test():
    """Replay a saved DQN policy forever, reporting each episode's return and length.

    Reads the global ``args``:

    * ``env``: Atari environment id;
    * ``log_dir`` / ``log_fname``: location of the model to load;
    * ``record``: when truthy, the env is wrapped in ``Monitor`` writing to
      ``log_dir`` and on-screen rendering is skipped.
    """
    env = deepq.wrap_atari_dqn(make_atari(args.env))
    act = deepq.load(os.path.join(args.log_dir, args.log_fname))
    if args.record:
        env = Monitor(env, directory=args.log_dir)
    show_on_screen = not args.record

    while True:
        obs, done = env.reset(), False
        episode_rew, t = 0, 0
        while not done:
            if show_on_screen:
                env.render()
            action = act(obs[None])[0]
            obs, rew, done, _ = env.step(action)
            episode_rew += rew
            t += 1
        print("Episode reward %.2f after %d steps" % (episode_rew, t))
示例#19
0
def main(envName='BreakoutNoFrameskip-v4', bufferSize=10000, timesteps=3e6):
    """Train a dueling DQN with prioritized replay and save the policy.

    The environment id, replay-buffer size and timestep budget are function
    arguments; the remaining hyperparameters are fixed below (seed 0,
    lr 1e-4, prioritized-replay alpha 0.6, ...). Logger output and the final
    ``model.pkl`` go to ``<cwd>/logs/<envName>_<bufferSize>``.

    Args:
        envName: gym Atari environment id.
        bufferSize: replay-buffer capacity; also part of the run directory name.
        timesteps: training budget (converted with ``int``).
    """
    # Build the run directory once so logging and saving cannot disagree.
    run_dir = os.path.join(os.getcwd(), "logs", str(envName) + "_" + str(bufferSize))
    logger.configure(dir=run_dir)
    set_global_seeds(0)
    env = make_atari(envName)
    env = bench.Monitor(env, logger.get_dir())
    env = deepq.wrap_atari_dqn(env)
    model = deepq.models.cnn_to_mlp(
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        hiddens=[512],
        dueling=True,
    )

    act = deepq.learn(
        env,
        q_func=model,
        lr=1e-4,
        max_timesteps=int(timesteps),
        buffer_size=bufferSize,
        exploration_fraction=0.1,
        exploration_final_eps=0.01,
        train_freq=4,
        learning_starts=10000,
        target_network_update_freq=1000,
        gamma=0.99,
        prioritized_replay=True,
        prioritized_replay_alpha=0.6,
        checkpoint_freq=10000,
    )
    act.save(os.path.join(run_dir, "model.pkl"))
    env.close()
示例#20
0
文件: run_atari.py 项目: erincmer/BPO
def main():
    """Train a binned-value DQN variant on BeamRider.

    The Q-value range is [-50, 50] with 1000 bins; this appears to be a
    distributional variant -- confirm against the custom deepq code. Logs go
    to ``./log/BeamRider``. The trained policy is not saved.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # BeamRider instead of Breakout: larger rewards make progress easier to see.
    parser.add_argument('--env', help='environment ID', default='BeamRiderNoFrameskip-v4')
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    # Prioritized replay is off by default: the code complained when it was on.
    parser.add_argument('--prioritized', type=int, default=0)
    # Dueling is off by default for code simplicity.
    parser.add_argument('--dueling', type=int, default=0)
    parser.add_argument('--num-timesteps', type=int, default=int(10e6))
    args = parser.parse_args()

    logger.configure("./log/BeamRider")  # results are logged under BeamRider
    set_global_seeds(args.seed)
    env = deepq.wrap_atari_dqn(
        bench.Monitor(make_atari(args.env), logger.get_dir()))

    nbins = 1000  # number of value bins, shared by the model and the learner
    model = deepq.models.cnn_to_mlp(
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        hiddens=[256],
        dueling=bool(args.dueling),
        nbins=nbins,
    )
    act = deepq.learn(
        env,
        q_func=model,
        lr=1e-4,
        max_timesteps=args.num_timesteps,
        buffer_size=10000,
        exploration_fraction=0.1,
        exploration_final_eps=0.01,
        train_freq=4,
        learning_starts=10000,
        target_network_update_freq=1000,
        gamma=0.99,
        prioritized_replay=bool(args.prioritized),
        min_Val=-50,  # minimum Q value
        max_Val=50,   # maximum Q value
        nbins=nbins,
    )
    # Saving is disabled (original: `act.save("pong_model.pkl")`, marked XXX).
    env.close()
示例#21
0
def main():
    """Train DQN on an Atari game with CLI-configurable replay and checkpoints.

    Flags: ``--env``, ``--seed``, ``--prioritized``,
    ``--prioritized-replay-alpha``, ``--dueling``, ``--num-timesteps``,
    ``--checkpoint-freq`` and ``--checkpoint-path``. The trained policy is
    not saved explicitly.
    """
    arg_parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    arg_parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4')
    arg_parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    arg_parser.add_argument('--prioritized', type=int, default=1)
    arg_parser.add_argument('--prioritized-replay-alpha', type=float, default=0.6)
    arg_parser.add_argument('--dueling', type=int, default=1)
    arg_parser.add_argument('--num-timesteps', type=int, default=int(10e6))
    arg_parser.add_argument('--checkpoint-freq', type=int, default=10000)
    arg_parser.add_argument('--checkpoint-path', type=str, default=None)
    cfg = arg_parser.parse_args()

    logger.configure()
    set_global_seeds(cfg.seed)

    atari_env = make_atari(cfg.env)
    atari_env = bench.Monitor(atari_env, logger.get_dir())
    atari_env = deepq.wrap_atari_dqn(atari_env)

    network = deepq.models.cnn_to_mlp(
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        hiddens=[256],
        dueling=bool(cfg.dueling),
    )
    replay_settings = dict(
        buffer_size=10000,
        prioritized_replay=bool(cfg.prioritized),
        prioritized_replay_alpha=cfg.prioritized_replay_alpha,
    )
    checkpoint_settings = dict(
        checkpoint_freq=cfg.checkpoint_freq,
        checkpoint_path=cfg.checkpoint_path,
    )
    deepq.learn(
        atari_env,
        q_func=network,
        lr=1e-4,
        max_timesteps=cfg.num_timesteps,
        exploration_fraction=0.1,
        exploration_final_eps=0.01,
        train_freq=4,
        learning_starts=10000,
        target_network_update_freq=1000,
        gamma=0.99,
        **replay_settings,
        **checkpoint_settings,
    )

    atari_env.close()
示例#22
0
def main():
    """Render a Pong DQN agent forever, printing each episode's total reward.

    The network comes from ``deepq.learn`` with ``total_timesteps=0``
    (no training steps). No weights are loaded here.
    """
    # NOTE(review): without a load_path the agent presumably acts with freshly
    # initialized weights -- confirm whether a saved model should be loaded.
    pong = deepq.wrap_atari_dqn(gym.make("PongNoFrameskip-v4"))
    agent = deepq.learn(
        pong,
        "conv_only",
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        hiddens=[256],
        dueling=True,
        total_timesteps=0,
    )

    while True:
        observation, finished = pong.reset(), False
        score = 0
        while not finished:
            pong.render()
            chosen = agent(observation[None])[0]
            observation, reward, finished, _ = pong.step(chosen)
            score += reward
        print("Episode reward", score)
示例#23
0
def test(env0, act_greedy, nb_itrs=5, nb_test_steps=10000):
    """Evaluate a policy on fresh copies of ``env0``'s environment.

    ``env0`` is unwrapped by following ``.env`` attributes to reach the base
    environment. For each of ``nb_itrs`` iterations, a new environment with
    the same ``spec.id`` is created: with the DQN Atari wrappers if the base
    env has an ``ale`` attribute, plain ``gym.make`` otherwise. The
    environment is closed after the iteration.

    Each iteration is scored as follows:

    * ``nb_test_steps is None``: one full episode; its return is recorded.
    * otherwise: exactly ``nb_test_steps`` steps, resetting after each
      finished episode. The mean return of the completed episodes is
      recorded, or ``nan`` if none finished within the budget.

    Args:
        env0: (possibly wrapped) environment whose spec id is evaluated.
        act_greedy: callable taking a batch of one observation
            (``np.array(obs)[None]``) and returning an indexable of actions.
        nb_itrs: number of evaluation iterations.
        nb_test_steps: step budget per iteration, or None for one episode.

    Returns:
        float32 numpy array with one score per iteration.
    """
    env = env0
    while hasattr(env, 'env'):
        env = env.env

    total_rewards = []
    for _ in range(nb_itrs):
        if hasattr(env, 'ale'):
            from baselines.common.atari_wrappers import make_atari
            env_new = make_atari(env.spec.id)
            env_new = deepq.wrap_atari_dqn(env_new)
        else:
            env_new = gym.make(env.spec.id)
        try:
            obs = env_new.reset()
            if nb_test_steps is None:
                done = False
                episode_reward = 0
                while not done:
                    action = act_greedy(np.array(obs)[None])[0]
                    obs, rew, done, _ = env_new.step(action)
                    episode_reward += rew
                total_rewards.append(episode_reward)
            else:
                episodes = []
                episode_reward = 0
                for _step in range(nb_test_steps):
                    action = act_greedy(np.array(obs)[None])[0]
                    obs, rew, done, _ = env_new.step(action)
                    episode_reward += rew
                    if done:
                        episodes.append(episode_reward)
                        episode_reward = 0
                        obs = env_new.reset()
                # np.mean([]) is nan plus a RuntimeWarning; return nan explicitly.
                total_rewards.append(np.mean(episodes) if episodes else np.nan)
        finally:
            # One environment is created per iteration; release each of them.
            env_new.close()

    return np.array(total_rewards, dtype=np.float32)
示例#24
0
def main():
    """Train or test a DQN agent on an Atari game.

    Usage: ``[--env ID] [--seed N] {train,test} ...``

    * ``train`` accepts ``--num-timesteps`` (default 50e6, i.e. 200e6
      frames with the 4-frame skip introduced by ``make_atari``) and
      ``--out-dir`` (checkpoint directory).
    * ``test`` accepts ``--fps`` and a positional ``model`` file.

    The environment is created inside a TF session with GPU memory growth
    enabled. It is handed, together with the parsed args, to the
    module-level ``train`` or ``test`` function. A missing subcommand is
    reported as a usage error (argparse exits).
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--env', help='environment ID',
                        default='SeaquestNoFrameskip-v4')
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    subparsers = parser.add_subparsers(dest='subparser',
                                       help='train model or test existing model')
    train_parser = subparsers.add_parser('train')
    # Mnih et al (2015) and other DeepMind work usually train for 200e6 frames,
    # which is 50e6 time steps with 4 frameskip (introduced by wrapper in
    # make_atari.)
    train_parser.add_argument('--num-timesteps', type=int, default=int(50e6))
    train_parser.add_argument('--out-dir', type=str, default=None,
                              help='checkpoint directory')

    test_parser = subparsers.add_parser('test')
    test_parser.add_argument('--fps', type=int, default=int(1e3))
    test_parser.add_argument('model', help='model file', type=str)

    args = parser.parse_args()
    if args.subparser not in ('train', 'test'):
        # Subcommands are optional by default, so a missing one leaves
        # `args.subparser` as None. Report a usage error before any expensive
        # setup instead of a bare `assert False` (stripped under python -O).
        parser.error("a subcommand is required: choose 'train' or 'test'")

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config):
        logger.configure()
        set_global_seeds(args.seed)

        env = make_atari(args.env)
        env = bench.Monitor(env, logger.get_dir())
        env = deepq.wrap_atari_dqn(env)

        if args.subparser == 'train':
            train(env, args)
        else:
            test(env, args)
示例#25
0
def main():
    """Render a saved DQN policy playing an Atari game forever.

    ``--env`` picks the gym environment id and ``--dir`` the path handed to
    ``deepq.load``. Prints each episode's reward. Never returns.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4')
    # NOTE(review): the default '' is presumably not a loadable model path.
    parser.add_argument('--dir', help='model directory', default='')
    args = parser.parse_args()

    print("Args env: " + args.env)
    env = deepq.wrap_atari_dqn(gym.make(args.env))
    policy = deepq.load(args.dir)

    while True:
        observation = env.reset()
        done = False
        episode_rew = 0
        while not done:
            env.render()
            observation, reward, done, _ = env.step(policy(observation[None])[0])
            episode_rew += reward
        print("Episode reward", episode_rew)
示例#26
0
def make_env(game_name):
    """Build a DQN-ready Atari env for ``<game_name>NoFrameskip-v4``.

    Wrapping order, innermost first: ``bench.Monitor`` (log dir ``None``),
    ``SimpleMonitor``, then ``wrap_atari_dqn``.
    """
    # NOTE(review): two monitors are stacked. A None log dir presumably
    # disables Monitor's file output -- confirm both are needed.
    base = gym.make(game_name + "NoFrameskip-v4")
    monitored = SimpleMonitor(bench.Monitor(base, None))
    return wrap_atari_dqn(monitored)
示例#27
0
def main():
    """Export a trained deepq Breakout network to PyTorch, then play with it.

    Steps:
    1. Load ``checkpoints/model.pkl`` into the TF graph built by
       ``deepq.learn``. ``total_timesteps=0`` means no training happens.
    2. Copy the conv and fully connected weights into a PyTorch ``Net``.
    3. Save that network's state dict to ``pytorch_breakout_dqn.pt``.
    4. Render greedy episodes forever, printing each episode's reward.

    The environment is closed when the loop is interrupted (e.g. Ctrl-C).

    Raises:
        KeyError: an expected ``deepq/q_func/...`` trainable variable is
            missing from the loaded graph.
    """
    logger.configure(dir="breakout_train_log")
    env = make_atari('BreakoutNoFrameskip-v4')
    env = bench.Monitor(env, logger.get_dir())
    env = deepq.wrap_atari_dqn(env)

    # Called for its side effect: building the graph and loading the
    # checkpoint into the default session.
    deepq.learn(
        env,
        "conv_only",
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        hiddens=[256],
        dueling=False,
        total_timesteps=0,
        load_path="checkpoints/model.pkl",
    )
    sess = get_session()
    pytorch_network = Net()

    variables_names = [v.name for v in tf.trainable_variables()]
    tf_values = dict(zip(variables_names, sess.run(variables_names)))

    def tf_tensor(scope):
        """Return TF variable ``deepq/q_func/<scope>:0`` as a torch tensor."""
        key = 'deepq/q_func/' + scope + ':0'
        if key not in tf_values:
            raise KeyError('trainable variable not found in graph: ' + key)
        return torch.from_numpy(tf_values[key])

    # TF conv kernels are (H, W, in, out); torch expects (out, in, H, W).
    for layer, scope in ((pytorch_network.conv1, 'convnet/Conv'),
                         (pytorch_network.conv2, 'convnet/Conv_1'),
                         (pytorch_network.conv3, 'convnet/Conv_2')):
        layer.weight = torch.nn.Parameter(
            tf_tensor(scope + '/weights').permute(3, 2, 0, 1))
        layer.bias = torch.nn.Parameter(tf_tensor(scope + '/biases'))

    # TF dense kernels are (in, out); torch Linear weights are (out, in).
    for layer, scope in ((pytorch_network.fc1,
                          'action_value/fully_connected'),
                         (pytorch_network.fc2,
                          'action_value/fully_connected_1')):
        layer.weight = torch.nn.Parameter(
            tf_tensor(scope + '/weights').permute(1, 0))
        layer.bias = torch.nn.Parameter(tf_tensor(scope + '/biases'))

    torch.save(pytorch_network.state_dict(), 'pytorch_breakout_dqn.pt')

    try:
        while True:
            obs, done = env.reset(), False
            episode_rew = 0
            while not done:
                env.render()
                # obs[None] is channel-last; permute to channel-first (NCHW).
                with torch.no_grad():
                    q_values = pytorch_network(
                        torch.tensor(obs[None], dtype=torch.float).permute(
                            0, 3, 1, 2))
                # Hand gym a plain Python int rather than a 0-d tensor.
                action = int(torch.argmax(q_values))
                obs, rew, done, _ = env.step(action)
                episode_rew += rew
            print("Episode reward", episode_rew)
    finally:
        env.close()
示例#28
0
def make_env(game_name):
    """Create ``<game_name>NoFrameskip-v4`` and return two views of it.

    Returns:
        A ``(monitored, wrapped)`` pair.
        - ``monitored``: the ``bench.Monitor`` layer, with no log directory.
        - ``wrapped``: the same env after ``deepq.wrap_atari_dqn`` has been
          applied on top of it.
    """
    raw = gym.make(game_name + "NoFrameskip-v4")
    monitored = bench.Monitor(raw, None)
    wrapped = deepq.wrap_atari_dqn(monitored)
    return monitored, wrapped
def main():
    """Train an (alpha-)DQN agent on an Atari game.

    Command-line flags select the variant:
    - double / dueling / prioritized replay,
    - alpha-DQN in its expected or surrogate form,
    - n-step returns, and a few more options.

    The run is logged to
    ``../experiments/AlphaGreedy/<game>/<typename>-<timestamp>``, where
    ``typename`` encodes the enabled options.

    ``deepq.learn`` here is a project-specific variant that accepts the
    alpha-DQN keyword arguments passed below.
    """
    game = 'qbert'
    default_env = game.capitalize() + "NoFrameskip-v4"
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--env', help='environment ID', default=default_env)
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    parser.add_argument('--prioritized', type=int, default=0)
    parser.add_argument('--dueling', type=int, default=0)
    parser.add_argument('--expected-dueling', type=int, default=0)
    parser.add_argument('--double', type=int, default=1)
    parser.add_argument('--alpha', type=int, default=1)  # Enables alpha-dqn
    parser.add_argument('--sample', type=int, default=0)
    parser.add_argument('--expected', type=int,
                        default=1)  # Turns on Expected-alpha-dqn
    parser.add_argument('--surrogate', type=int,
                        default=0)  # Turns on Surrogate-alpha-dqn
    parser.add_argument('--steps', type=int, default=1)
    parser.add_argument('--eps-val', type=float, default=0.01)
    parser.add_argument('--alpha-val', type=float, default=0.01)
    parser.add_argument('--piecewise-schedule', type=int, default=0)
    parser.add_argument('--alpha-epsilon', type=int, default=1)
    parser.add_argument('--num-timesteps', type=int, default=int(50e6))
    # Despite its name this is a step count.
    # It is divided by --num-timesteps below to get the fraction passed to
    # deepq.learn.
    parser.add_argument('--exploration-fraction', type=int, default=int(1e6))
    parser.add_argument('--optimal-test', type=int, default=0)

    args = parser.parse_args()
    # The expected and surrogate alpha-DQN variants cannot run together.
    # parser.error (unlike assert) still runs under `python -O`.
    if args.alpha and bool(args.expected) == bool(args.surrogate):
        parser.error('with --alpha, exactly one of --expected and '
                     '--surrogate must be enabled')
    surrogate = bool(args.surrogate)

    # Run name describing the enabled options.
    typename = ''
    if args.double:
        typename += 'D'
    typename += "DQN"
    typename += '_EPS_{}'.format(args.eps_val)
    if args.alpha and args.eps_val != args.alpha_val:
        typename += '_ALPHA_{}'.format(args.alpha_val)
    if surrogate:
        typename += '_Surrogate'
    else:
        typename += '_Expected'
        if args.sample:
            typename += '_Sample'
    if args.piecewise_schedule:
        typename += '_Piecewise'
    if args.alpha_epsilon:
        typename += '_ALPHA=EPS'
    typename += '_{}_step'.format(args.steps)
    if args.prioritized:
        typename += '_PRIORITY'
    if args.dueling:
        typename += '_DUEL'

    # e.g. 'QbertNoFrameskip-v4' -> 'qbert'
    game = args.env[:-len('NoFrameskip-v4')].lower()
    # Only the timestamp goes through strftime; typename is appended
    # verbatim so it is never interpreted as a format string.
    log_dir = osp.join(
        '../experiments/AlphaGreedy/' + game,
        typename + datetime.datetime.now().strftime("-%d-%m-%H-%M-%f"))
    logger.configure(dir=log_dir)
    set_global_seeds(args.seed)

    env = make_atari(args.env)
    env = bench.Monitor(env, logger.get_dir(), allow_early_resets=True)
    env = deepq.wrap_atari_dqn(env)
    model = deepq.models.cnn_to_mlp(
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        hiddens=[256],
        dueling=bool(args.dueling),
    )
    act = deepq.learn(
        env,
        q_func=model,
        lr=1e-4,
        max_timesteps=args.num_timesteps,
        buffer_size=10000,
        exploration_fraction=(float(args.exploration_fraction) /
                              float(args.num_timesteps)),
        exploration_final_eps=args.eps_val,
        alpha_val=args.alpha_val,
        train_freq=4,
        learning_starts=10000,
        target_network_update_freq=1000,
        gamma=0.99,
        prioritized_replay=bool(args.prioritized),
        double=bool(args.double),
        epsilon=bool(args.alpha),  # turns on alpha-dqn
        eps_val=args.eps_val,
        q1=surrogate,  # q1 is surrogate, else, expected
        n_steps=args.steps,
        sample=bool(args.sample),
        piecewise_schedule=bool(args.piecewise_schedule),
        test_agent=1e6)
    act.save()
    env.close()
示例#30
0
def main(game: str, model: str):
    """Export a trained deepq Atari network to PyTorch and evaluate it.

    Args:
        game: snake_case game name, e.g. ``'video_pinball'``. It is turned
            into the gym id ``VideoPinballNoFrameskip-v4``.
        model: checkpoint file name inside ``trained_networks/``. ``None``
            is also accepted and means ``<game>_model.pkl``.

    Steps:
    1. Load the deepq checkpoint into the TF graph. ``total_timesteps=0``
       means nothing is trained.
    2. Copy its weights into an equivalent PyTorch ``Net``.
    3. Save that network's state dict to ``pytorch_<game>.pt``.
    4. Play 100 episodes with the PyTorch net, sampling actions from
       ``policy(q_values, 0.05)``, and print the mean episode reward.

    Raises:
        KeyError: an expected ``deepq/q_func/...`` trainable variable is
            missing from the loaded graph.
    """
    name = ''.join([g.capitalize() for g in game.split('_')])
    env = make_atari(name + 'NoFrameskip-v4')
    env = bench.Monitor(env, logger.get_dir(), allow_early_resets=True)
    env = deepq.wrap_atari_dqn(env)
    n_actions = env.action_space.n

    class Net(nn.Module):
        """PyTorch replica of the deepq ``conv_only`` Q-network.

        The explicit padding gives the spatial sizes 84 -> 21 -> 11 -> 11,
        hence fc1's 11 * 11 * 64 = 7744 inputs. This assumes 84x84
        observations (implied by fc1's size) and is presumably meant to
        mirror TF 'SAME' padding.
        """

        def __init__(self):
            super(Net, self).__init__()
            # 84 -> 21 with symmetric padding of 2.
            self.conv1 = nn.Conv2d(4, 32, 8, stride=4, padding=2)
            self.relu1 = nn.ReLU()
            self.conv2 = nn.Conv2d(32, 64, 4, stride=2, padding=0)
            self.relu2 = nn.ReLU()
            # Asymmetric padding (1 before, 2 after) ahead of conv2:
            # 21 -> 24 -> 11.
            self.pad2 = nn.ConstantPad2d((1, 2, 1, 2), value=0)
            self.conv3 = nn.Conv2d(64, 64, 3, stride=1, padding=0)
            # 11 -> 13 -> 11 around conv3.
            self.pad3 = nn.ConstantPad2d((1, 1, 1, 1), value=0)
            self.relu3 = nn.ReLU()
            self.fc1 = nn.Linear(7744, 512)
            self.relu4 = nn.ReLU()
            self.fc2 = nn.Linear(512, n_actions)

        def forward(self, x):
            """Return Q-values for a channel-first batch of frames."""
            # Scale pixel values (presumably 0-255) to [0, 1].
            x = x / 255.
            x = self.relu1(self.conv1(x))
            x = self.pad2(x)
            x = self.relu2(self.conv2(x))
            x = self.pad3(x)
            x = self.relu3(self.conv3(x))
            # Flatten in NHWC order so features line up with TF's dense
            # layer weights.
            x = x.permute(0, 2, 3, 1).contiguous()
            x = x.view(-1, self.num_flat_features(x))
            x = self.relu4(self.fc1(x))
            x = self.fc2(x)
            return x

        def num_flat_features(self, x):
            """Return the product of all non-batch dimensions of ``x``."""
            size = x.size()[1:]  # all dimensions except the batch dimension
            num_features = 1
            for s in size:
                num_features *= s
            return num_features

    if model is None:
        model = game + "_model.pkl"

    # Called for its side effect: building the graph and loading the
    # checkpoint into the default session.
    deepq.learn(
        env,
        "conv_only",
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        hiddens=[512],
        dueling=False,
        total_timesteps=0,
        exploration_final_eps=0.05,
        load_path="trained_networks/" + model,
    )
    sess = get_session()
    pytorch_network = Net()

    variables_names = [v.name for v in tf.trainable_variables()]
    tf_values = dict(zip(variables_names, sess.run(variables_names)))

    def tf_tensor(scope):
        """Return TF variable ``deepq/q_func/<scope>:0`` as a torch tensor."""
        key = 'deepq/q_func/' + scope + ':0'
        if key not in tf_values:
            raise KeyError('trainable variable not found in graph: ' + key)
        return torch.from_numpy(tf_values[key])

    # TF conv kernels are (H, W, in, out); torch expects (out, in, H, W).
    for layer, scope in ((pytorch_network.conv1, 'convnet/Conv'),
                         (pytorch_network.conv2, 'convnet/Conv_1'),
                         (pytorch_network.conv3, 'convnet/Conv_2')):
        layer.weight = torch.nn.Parameter(
            tf_tensor(scope + '/weights').permute(3, 2, 0, 1))
        layer.bias = torch.nn.Parameter(tf_tensor(scope + '/biases'))

    # TF dense kernels are (in, out); torch Linear weights are (out, in).
    for layer, scope in ((pytorch_network.fc1,
                          'action_value/fully_connected'),
                         (pytorch_network.fc2,
                          'action_value/fully_connected_1')):
        layer.weight = torch.nn.Parameter(
            tf_tensor(scope + '/weights').permute(1, 0))
        layer.bias = torch.nn.Parameter(tf_tensor(scope + '/biases'))

    torch.save(pytorch_network.state_dict(), 'pytorch_' + game + '.pt')

    rewards = np.zeros(100)
    for episode in range(100):
        obs, done = env.reset(), False
        episode_rew = 0
        while not done:
            # obs[None] is channel-last; permute to channel-first (NCHW).
            with torch.no_grad():
                q_values = pytorch_network(
                    torch.tensor(obs[None],
                                 dtype=torch.float).permute(0, 3, 1, 2))[0]
            probabilities = policy(q_values, 0.05)
            action = np.random.choice(np.arange(len(probabilities)),
                                      p=probabilities)
            obs, rew, done, _ = env.step(action)
            episode_rew += rew
        rewards[episode] = episode_rew
        print("Episode " + str(episode) + " reward", episode_rew)
    env.close()
    print("Avg: ", np.mean(rewards))
示例#31
0
def main():
    """Train a DQN on an Atari environment described by ``parse_args()``.

    ``parse_args()`` returns a mapping that is consumed with ``pop``:
    - ``logdir``, ``policy_mode``, ``save_array_flag``,
      ``use_my_env_wrapper``, ``env_id`` and ``gpu`` are handled here;
    - every remaining key is forwarded to ``deepq.learn`` as a keyword
      argument.

    The final model is saved as ``Breakout_final_model.pkl`` in the log dir.

    Raises:
        AssertionError: ``policy_mode == "large_variance"`` but
            ``exploitation_ratio_on_bottleneck`` or
            ``bottleneck_threshold_ratio`` is ``None``.
    """
    args = parse_args()
    logdir = args.pop('logdir')
    logger.configure(dir=logdir, enable_std_out=False)
    # Record every hyper-parameter except logdir next to the logs.
    with open(os.path.join(logger.get_dir(), "hyparam.yaml"), 'w') as f:
        yaml.dump(args, f, default_flow_style=False)

    policy_mode = args.pop('policy_mode')
    save_array_flag = args.pop('save_array_flag')
    use_my_env_wrapper = args.pop('use_my_env_wrapper')
    env_id = args.pop('env_id')

    if policy_mode == "large_variance":
        if (args["exploitation_ratio_on_bottleneck"] is None
                or args["bottleneck_threshold_ratio"] is None):
            raise AssertionError(
                'policy_mode "large_variance" requires both '
                'exploitation_ratio_on_bottleneck and '
                'bottleneck_threshold_ratio')
        # Both ratios are set at this point, so the array logger is always
        # configured in this mode.
        array_logger = array_logger_getter.get_logger()
        array_logger.set_log_dir(logdir, exist_ok=True)
        array_logger.set_save_array_flag(save_array_flag)

    if use_my_env_wrapper:
        env = make_atari_nature(env_id)
    else:
        env = make_atari(env_id)
    env = deepq.wrap_atari_dqn(env)

    num_cpu = 1
    config = tf.ConfigProto(
        allow_soft_placement=True,
        inter_op_parallelism_threads=num_cpu,
        intra_op_parallelism_threads=num_cpu,
        # "gpu" is popped here so it is not forwarded to deepq.learn below.
        gpu_options=tf.GPUOptions(visible_device_list=args.pop("gpu"),
                                  allow_growth=True),
    )
    model = deepq.learn(
        env,
        "conv_only",
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        hiddens=[256],
        dueling=False,
        lr=1e-4,
        buffer_size=10000,
        train_freq=4,
        learning_starts=1000,
        target_network_update_freq=500,
        gamma=0.99,
        batch_size=64,
        print_freq=1000,
        config=config,
        bottleneck_threshold_update_freq=1000,
        **args,
    )

    model.save(os.path.join(logger.get_dir(), 'Breakout_final_model.pkl'))
    env.close()
def main(seed=10086,
         time=500,
         percentile=99.9,
         game='video_pinball',
         model='pytorch_video_pinball.pt',
         episode=50):
    """Convert a PyTorch Atari DQN to a spiking network and evaluate it.

    Args:
        seed: seed for ``set_seed`` and ``env.seed``.
        time: number of SNN simulation steps run per observation.
        percentile: passed to ``ann_to_snn`` together with the collected
            normalisation images.
        game: snake_case game name, e.g. ``'video_pinball'``. It is
            converted to ``VideoPinballNoFrameskip-v4``. A full id that
            already contains ``NoFrameskip`` is used unchanged.
        model: path of the PyTorch ``state_dict`` loaded into ``Net``.
        episode: number of counted games to play with the SNN.

    For each counted game, one ``reward, mix, spike`` line is appended to a
    CSV named after the run parameters:
    - ``reward``: the game's total reward;
    - ``mix``: fraction of steps on which the SNN action (spike counts
      plus final voltage of layer ``'12'``) matched the ANN's greedy action;
    - ``spike``: the same fraction using spike counts only.
    """
    n_examples = 15000  # cap on observations collected for normalisation
    epsilon = 0.05      # exploration rate passed to policy()

    if torch.cuda.is_available():
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        device = 'cuda'
    else:
        device = 'cpu'

    print("device", device)
    print("game", game, "episode", episode, "time", time, "seed", seed,
          "percentile", percentile)
    set_seed(seed)

    # make_atari expects a full NoFrameskip gym id
    # (e.g. 'VideoPinballNoFrameskip-v4'), not the short snake_case name.
    if 'NoFrameskip' in game:
        env_id = game
    else:
        env_id = (''.join(g.capitalize() for g in game.split('_')) +
                  'NoFrameskip-v4')
    env = make_atari(env_id, max_episode_steps=18000)
    env = bench.Monitor(env, logger.get_dir(), allow_early_resets=True)
    env = deepq.wrap_atari_dqn(env)
    env.seed(seed)
    n_actions = env.action_space.n

    class Net(nn.Module):
        """PyTorch DQN mirroring the deepq ``conv_only`` Q-network.

        NOTE(review): the SNN output is read back below as layer ``'12'``.
        That index presumably follows the order in which child modules are
        registered here (ann_to_snn walks the children). Do not reorder,
        add or remove modules without updating it.
        """

        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(4, 32, 8, stride=4, padding=2)
            self.relu1 = nn.ReLU()
            self.pad2 = nn.ConstantPad2d((1, 2, 1, 2), value=0)
            self.conv2 = nn.Conv2d(32, 64, 4, stride=2, padding=0)
            self.relu2 = nn.ReLU()
            self.pad3 = nn.ConstantPad2d((1, 1, 1, 1), value=0)
            self.conv3 = nn.Conv2d(64, 64, 3, stride=1, padding=0)
            self.relu3 = nn.ReLU()
            self.perm = Permute((1, 2, 0))
            self.fc1 = nn.Linear(7744, 512)
            self.relu4 = nn.ReLU()
            self.fc2 = nn.Linear(512, n_actions)

        def forward(self, x):
            """Return Q-values; ``x`` is scaled by 1/255 first."""
            x = x / 255.0
            x = self.relu1(self.conv1(x))
            x = self.pad2(x)
            x = self.relu2(self.conv2(x))
            x = self.pad3(x)
            x = self.relu3(self.conv3(x))
            x = x.permute(0, 2, 3, 1).contiguous()
            x = x.view(-1, self.num_flat_features(x))
            x = self.relu4(self.fc1(x))
            x = self.fc2(x)
            return x

        def show(self, x):
            """Return the greedy action index for the first batch element.

            Unlike ``forward``, ``x`` is used as given (no 1/255 scaling).
            """
            x = self.relu1(self.conv1(x))
            x = self.pad2(x)
            x = self.relu2(self.conv2(x))
            x = self.pad3(x)
            x = self.relu3(self.conv3(x))
            x = x.permute(0, 2, 3, 1).contiguous()
            x = x.view(-1, self.num_flat_features(x))
            x = self.relu4(self.fc1(x))
            x = self.fc2(x)
            return torch.max(x, 1)[1].data[0]

        def num_flat_features(self, x):
            """Return the product of all non-batch dimensions of ``x``."""
            size = x.size()[1:]  # all dimensions except the batch dimension
            num_features = 1
            for s in size:
                num_features *= s
            return num_features

    ANN_model = Net()
    ANN_model.load_state_dict(torch.load(model))
    ANN_model.eval()
    ANN_model = ANN_model.to(device)

    # Collect observations for SNN normalisation. Only games that ended
    # with all lives lost are kept, up to n_examples frames in total.
    images = []
    norm_cnt = 0
    for _ in range(1000):
        cur_images = []
        cur_cnt = 0
        obs, done = env.reset(), False
        while not done:
            image = torch.from_numpy(obs[None]).permute(0, 3, 1, 2)
            cur_images.append(image.detach().numpy())
            with torch.no_grad():
                actions_value = ANN_model(image.to(device)).cpu()[0]
            probs, _ = policy(actions_value, epsilon)
            action = np.random.choice(np.arange(len(probs)), p=probs)
            obs, rew, done, info = env.step(action)
            cur_cnt += 1
        if info['ale.lives'] == 0:
            if cur_cnt + norm_cnt < n_examples:
                norm_cnt += cur_cnt
                images += cur_images
            else:
                print("normalization image cnt", norm_cnt)
                break

    images = torch.from_numpy(np.array(images)).reshape(-1, 4, 84,
                                                        84).float() / 255

    SNN = ann_to_snn(ANN_model,
                     input_shape=(4, 84, 84),
                     data=images.to(device),
                     percentile=percentile)
    SNN = SNN.to(device)

    for l in SNN.layers:
        if l != 'Input':
            SNN.add_monitor(Monitor(SNN.layers[l],
                                    state_vars=['s', 'v'],
                                    time=time),
                            name=l)

    for c in SNN.connections:
        if isinstance(SNN.connections[c], MaxPool2dConnection):
            SNN.add_monitor(Monitor(SNN.connections[c],
                                    state_vars=['firing_rates'],
                                    time=time),
                            name=f'{c[0]}_{c[1]}_rates')

    csv_path = ("game" + game + "episode" + str(episode) + "time" +
                str(time) + "percentile" + str(percentile) + ".csv")
    game_cnt = 0
    mix_cnt = 0
    spike_cnt = 0
    cnt = 0
    rewards = np.zeros(episode)
    with open(csv_path, 'a') as f:
        while game_cnt < episode:
            obs, done = env.reset(), False
            while not done:
                image = torch.from_numpy(obs[None]).permute(
                    0, 3, 1, 2).float() / 255
                image = image.to(device)

                with torch.no_grad():
                    ANN_action = ANN_model.show(image)

                inpts = {'Input': image.repeat(time, 1, 1, 1, 1)}
                SNN.run(inputs=inpts, time=time)

                spikes = {
                    l: SNN.monitors[l].get('s')
                    for l in SNN.monitors
                    if 's' in SNN.monitors[l].state_vars
                }
                voltages = {
                    l: SNN.monitors[l].get('v')
                    for l in SNN.monitors
                    if 'v' in SNN.monitors[l].state_vars
                }

                # Output-layer readouts: spike count plus last-step voltage
                # ("mix"), and spike count alone ("spike").
                actions_value = (spikes['12'].sum(0).cpu() +
                                 voltages['12'][time - 1].cpu())
                action = torch.max(actions_value, 1)[1].data.numpy()[0]

                spike_actions_value = spikes['12'].sum(0).cpu()
                spike_action = torch.max(spike_actions_value,
                                         1)[1].data.numpy()[0]

                cnt += 1
                if ANN_action == action:
                    mix_cnt += 1
                if ANN_action == spike_action:
                    spike_cnt += 1

                probs, _ = policy(actions_value[0], epsilon)
                action = np.random.choice(np.arange(len(probs)), p=probs)

                SNN.reset_state_variables()
                obs, rew, done, info = env.step(action)

            # A game is counted once all lives are lost or the time limit
            # truncated it. Other `done`s (presumably per-life episode ends
            # from the DQN wrapper) keep accumulating into the same game.
            truncated = bool(info.get('TimeLimit.truncated', False))
            if info['ale.lives'] == 0 or truncated:
                rewards[game_cnt] = info['episode']['r']
                print("Episode " + str(game_cnt) + " reward",
                      rewards[game_cnt])
                print("cnt", cnt, "mix", mix_cnt / cnt, "spike",
                      spike_cnt / cnt)
                f.write(
                    str(rewards[game_cnt]) + ", " + str(mix_cnt / cnt) +
                    ", " + str(spike_cnt / cnt) + "\n")
                game_cnt += 1
                mix_cnt = 0
                spike_cnt = 0
                cnt = 0

    env.close()
    print("Avg: ", np.mean(rewards))
示例#33
0
def main():
    """Train deepq (optionally with a K-FAC optimizer) on an Atari game.

    Log directory: ``--logdir`` (with ``~`` expanded), plus a sub-directory
    named by joining every other argument value with ``_``.

    All ``--kfac_*`` arguments are collected, with the ``kfac_`` prefix
    stripped, into the ``kfac_paras`` dict passed to ``deepq.learn``.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--env', help='environment ID', default='PongNoFrameskip-v4')
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    parser.add_argument('--prioritized', type=int, default=1)
    parser.add_argument('--num-timesteps', type=int, default=int(10e6))
    parser.add_argument('--logdir', type=str, default='~/.tmp/deepq')

    # General hyper-parameters
    parser.add_argument('--isKfac', type=int, default=0)
    parser.add_argument('--dueling', type=int, default=1)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--batch_size', type=int, default=160)
    parser.add_argument('--target_network_update_freq', type=int, default=1000)

    # Kfac parameters
    parser.add_argument('--kfac_fisher_metric', type=str, default='gn')
    parser.add_argument('--kfac_momentum', type=float, default=0.9)
    parser.add_argument('--kfac_clip_kl', type=float, default=0.01)
    parser.add_argument('--kfac_epsilon', type=float, default=1e-2)
    parser.add_argument('--kfac_stats_decay', type=float, default=0.99)
    parser.add_argument('--kfac_cold_iter', type=float, default=10)

    args = parser.parse_args()
    arg_values = vars(args)
    print('_'.join(str(arg) for arg in arg_values))
    # expanduser: the default starts with '~'. Without it a literal '~'
    # directory would be created under the current working directory.
    run_name = '_'.join(str(value) for arg, value in arg_values.items()
                        if arg != 'logdir')
    logdir = osp.join(osp.expanduser(args.logdir), run_name)
    logger.configure(dir=logdir)

    # 'kfac_momentum' -> 'momentum', etc.
    kfac_paras = {arg[len('kfac_'):]: value
                  for arg, value in arg_values.items()
                  if arg.startswith('kfac')}

    set_global_seeds(args.seed)
    env = make_atari(args.env)
    env = bench.Monitor(env, logger.get_dir())
    env = deepq.wrap_atari_dqn(env)
    model = deepq.models.cnn_to_mlp(
        convs=[(32, 8, 4), (64, 4, 2), (32, 3, 1)],
        hiddens=[256],
        dueling=bool(args.dueling),
    )
    act = deepq.learn(
        env,
        q_func=model,
        isKfac=args.isKfac,
        kfac_paras=kfac_paras,
        lr=args.lr,
        max_timesteps=args.num_timesteps,
        batch_size=args.batch_size,
        buffer_size=10000,
        exploration_fraction=0.1,
        exploration_final_eps=0.01,
        train_freq=4,
        learning_starts=10000,
        target_network_update_freq=args.target_network_update_freq,
        gamma=0.99,
        prioritized_replay=bool(args.prioritized)
    )
    # act.save("pong_model.pkl") XXX
    env.close()