Example 1
def run(args):
    # Write logs to args.LOG_FILE and also echo them to the console.
    logging.basicConfig(filename=args.LOG_FILE, level=logging.DEBUG)
    logging.getLogger().addHandler(logging.StreamHandler())

    # Emulator wrapper; each frame is cropped and resized to the configured size.
    game_handler = GameStateHandler(random_seed=123, frame_skip=args.FRAME_SKIP, use_sdl=False,
                                    image_processing=lambda x: crop_and_resize(x, args.IMAGE_HEIGHT, args.IMAGE_WIDTH))
    game_handler.loadROM(args.ROM_FILE)

    height, width = game_handler.getScreenDims()
    logging.info('Screen resolution is %dx%d' % (height, width))
    num_actions = game_handler.num_actions

    # Theano-based Q-network that takes stacks of args.STATE_FRAMES frames as input.
    net = theano_qnetwork.DeepQNetwork(args.IMAGE_HEIGHT, args.IMAGE_WIDTH, num_actions, args.STATE_FRAMES, args.DISCOUNT_FACTOR)

    # Replay memory storing past transitions for experience replay.
    replay_memory = ReplayMemoryManager(args.IMAGE_HEIGHT, args.IMAGE_WIDTH, args.STATE_FRAMES, args.REPLAY_MEMORY_SIZE)

    monitor = Monitoring(log_train_step_every=100, smooth_episode_scores_over=50)
    agent = Agent(game_handler, net, replay_memory, None, monitor, args.TRAIN_FREQ, batch_size=args.BATCH_SIZE)

    start_epsilon = args.START_EPSILON
    exploring_duration = args.EXPLORING_DURATION

    # Pre-fill the replay memory, then train with epsilon decayed from start_epsilon
    # to args.FINAL_EPSILON over exploring_duration steps.
    agent.populate_replay_memory(args.MIN_REPLAY_MEMORY)
    agent.play(train_steps_limit=args.LEARNING_BEYOND_EXPLORING + args.EXPLORING_DURATION, start_eps=start_epsilon,
               final_eps=args.FINAL_EPSILON, exploring_duration=exploring_duration)
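
`run` only assumes that `args` exposes the upper-case hyperparameters it reads. Below is a minimal launch sketch, assuming `run` and its helpers are importable from this module; all default values are illustrative, not the settings used in the original experiments.

# Illustrative launcher (hypothetical defaults; attribute names mirror those read by run()).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--ROM_FILE', default='roms/breakout.bin')
parser.add_argument('--LOG_FILE', default='train.log')
parser.add_argument('--FRAME_SKIP', type=int, default=4)
parser.add_argument('--IMAGE_HEIGHT', type=int, default=84)
parser.add_argument('--IMAGE_WIDTH', type=int, default=84)
parser.add_argument('--STATE_FRAMES', type=int, default=4)
parser.add_argument('--DISCOUNT_FACTOR', type=float, default=0.99)
parser.add_argument('--REPLAY_MEMORY_SIZE', type=int, default=1000000)
parser.add_argument('--MIN_REPLAY_MEMORY', type=int, default=50000)
parser.add_argument('--TRAIN_FREQ', type=int, default=4)
parser.add_argument('--BATCH_SIZE', type=int, default=32)
parser.add_argument('--START_EPSILON', type=float, default=1.0)
parser.add_argument('--FINAL_EPSILON', type=float, default=0.1)
parser.add_argument('--EXPLORING_DURATION', type=int, default=1000000)
parser.add_argument('--LEARNING_BEYOND_EXPLORING', type=int, default=1000000)

if __name__ == '__main__':
    run(parser.parse_args())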
Example 2
def main(_=None):
    # allow_soft_placement lets TensorFlow fall back to another device when an op
    # has no kernel for the requested one.
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        # Make sure the log directory exists and, unless we are resuming a previous
        # run, truncate the old log file; log both to the file and to the console.
        d = os.path.dirname(args.log_file)
        if not os.path.exists(d):
            os.makedirs(d)
        if not args.continue_training:
            with open(args.log_file, 'w') as f:
                f.write('')
        logging.basicConfig(filename=args.log_file, level=logging.DEBUG)
        logging.getLogger().addHandler(logging.StreamHandler())

        # Emulator wrapper; every frame is cropped and resized before it is used.
        game_handler = GameStateHandler(
                args.rom_directory + args.rom_name,
                random_seed=args.random_seed,
                frame_skip=args.frame_skip,
                use_sdl=args.use_sdl,
                repeat_action_probability=args.repeat_action_probability,
                minimum_actions=args.minimum_action_set,
                test_mode=args.test_mode,
                image_processing=lambda x: crop_and_resize(x, args.image_height, args.image_width, args.cut_top))
        num_actions = game_handler.num_actions

        if args.optimizer == 'rmsprop':
            optimizer = tf.train.RMSPropOptimizer(
                    learning_rate=args.learning_rate,
                    decay=args.decay,
                    momentum=0.0,
                    epsilon=args.rmsprop_epsilon)
        else:
            # Without this branch `optimizer` would be undefined below for any other value.
            raise ValueError('Unsupported optimizer: %s' % args.optimizer)

        # Choose the network: single-GPU DQN or double DQN (with a target network),
        # or a multi-GPU double DQN spread over four GPUs.
        if not args.multi_gpu:
            if args.double_dqn:
                net = qnetwork.DualDeepQNetwork(args.image_height, args.image_width, sess, num_actions,
                                                args.state_frames, args.discount_factor, args.target_net_refresh_rate,
                                                net_type=args.net_type, optimizer=optimizer)
            else:
                net = qnetwork.DeepQNetwork(args.image_height, args.image_width, sess, num_actions, args.state_frames,
                                            args.discount_factor, net_type=args.net_type, optimizer=optimizer)
        else:
            net = multi_gpu_qnetwork.MultiGPUDualDeepQNetwork(args.image_height, args.image_width, sess, num_actions,
                                                              args.state_frames, args.discount_factor,
                                                              optimizer=optimizer, gpus=[0, 1, 2, 3])

        # Reuse a previously saved replay memory if one is found; otherwise start fresh.
        saver = Saver(sess, args.data_dir, args.continue_training)
        if saver.replay_memory_found():
            replay_memory = saver.get_replay_memory()
        else:
            if args.test_mode:
                logging.error('NO SAVED NETWORKS IN TEST MODE!!!')
            replay_memory = ReplayMemoryManager(args.image_height, args.image_width, args.state_frames,
                                                args.replay_memory_size, reward_clip_min=args.reward_clip_min,
                                                reward_clip_max=args.reward_clip_max)

        # todo: add parameters to handle monitor
        monitor = Monitoring(log_train_step_every=100, smooth_episode_scores_over=50)

        # The agent wires together the environment, network, replay memory, saver and monitor.
        agent = Agent(
                game_handler=game_handler,
                qnetwork=net,
                replay_memory=replay_memory,
                saver=saver,
                monitor=monitor,
                train_freq=args.train_freq,
                test_mode=args.test_mode,
                batch_size=args.batch_size,
                save_every_x_episodes=args.saving_freq)

        sess.run(tf.initialize_all_variables())  # pre-TF-1.0 name for global_variables_initializer
        saver.restore(args.data_dir)
        # Resume the linear epsilon schedule at the point where the restored run stopped.
        start_epsilon = max(args.final_epsilon,
                            args.start_epsilon - saver.get_start_frame() * (args.start_epsilon - args.final_epsilon) / args.exploration_duration)
        exploring_duration = max(args.exploration_duration - saver.get_start_frame(), 1)

        # Test mode: gather just enough frames to form a state, then evaluate with a
        # fixed epsilon; training mode: pre-fill the replay memory and train.
        if args.test_mode:
            agent.populate_replay_memory(args.state_frames, force_early_stop=True)
            agent.play_in_test_mode(args.epsilon_in_test_mode)
        else:
            agent.populate_replay_memory(args.min_replay_memory)
            agent.play(train_steps_limit=args.number_of_train_steps, start_eps=start_epsilon,
                       final_eps=args.final_epsilon, exploring_duration=exploring_duration)
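
The epsilon bookkeeping after the restore step implements a linear schedule that can resume mid-decay. Below is a standalone sketch of that arithmetic; `resume_epsilon` is a hypothetical helper name, not part of the repository.

def resume_epsilon(start_eps, final_eps, exploration_duration, frames_done):
    # Hypothetical helper mirroring the computation in main():
    # epsilon falls linearly from start_eps to final_eps over exploration_duration
    # frames; when resuming after frames_done frames, start part-way down the ramp.
    step = (start_eps - final_eps) / exploration_duration
    eps = max(final_eps, start_eps - frames_done * step)
    remaining = max(exploration_duration - frames_done, 1)
    return eps, remaining

# Example: resuming after 250000 of a 1000000-frame exploration phase.
eps, remaining = resume_epsilon(1.0, 0.1, 1000000, 250000)
# eps == 0.775, remaining == 750000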