Beispiel #1
0
def evaluate_model(barista_net, model, num_batches):
    """Return the average predicted Q-value over several minibatches.

    Loads ``model`` into ``barista_net.net``, then for each of
    ``num_batches`` minibatches loads the batch, runs the forward pass with
    ``end='Q_out'`` and takes ``np.mean`` of the ``Q_out`` blob's data. The
    per-batch means are averaged (so with unequal batch sizes this is a mean
    of means, not the mean over all samples).

    Args:
        barista_net: network wrapper exposing ``net``, ``load_minibatch()``,
            ``forward(end=...)`` and ``blobs``.
        model: parameters handed to ``set_net_params``.
        num_batches: number of minibatches to evaluate; must be positive.

    Returns:
        The average over batches of the mean ``Q_out`` value.

    Raises:
        ValueError: if ``num_batches`` is not positive. Validated before any
            parameters are loaded, since an average over zero batches is
            undefined.
    """
    if num_batches <= 0:
        raise ValueError("num_batches must be positive, got %r" % (num_batches,))

    set_net_params(barista_net.net, model)
    total_q = 0.0
    for _ in xrange(num_batches):
        barista_net.load_minibatch()
        # end='Q_out': only the Q-value output blob is read below.
        barista_net.forward(end='Q_out')
        total_q += np.mean(barista_net.blobs['Q_out'].data)
    return total_q / num_batches
Beispiel #2
0
def main():
    """Play Snake with a trained network and dump every frame as a PNG.

    Builds the network from ``args.architecture``/``args.model``, restores
    parameters from ``args.checkpoint``, then plays ``args.num_games`` games
    with an epsilon-greedy policy (``args.epsilon``). After every action the
    latest frame is gray-scaled, enlarged 10x with nearest-neighbour
    interpolation and written to ``args.output_dir`` as
    ``frame-<index>.png``.

    Raises:
        IOError: if a frame image cannot be written.
    """
    args = get_args()

    # Instantiate network
    bnet = baristanet.BaristaNet(args.architecture, args.model, None)

    # Load parameters from the checkpoint into the model
    params = load_saved_checkpoint(args.checkpoint)
    netutils.set_net_params(bnet.net, params)

    # Initialize game player. The dataset file is opened with overwrite=True;
    # dset_size budgets 300 per game (presumably transitions -- confirm
    # against ReplayDataset).
    replay_dataset = ReplayDataset("temp-dset.hdf5", bnet.state[0].shape,
                                   dset_size=300*args.num_games,
                                   overwrite=True)

    game = SnakeGame()
    preprocessor = generate_preprocessor(bnet.state.shape[2:], gray_scale)
    exp_gain = ExpGain(bnet, ['w', 'a', 's', 'd'], preprocessor, game.cpu_play,
                       replay_dataset, game.encode_state())

    if not os.path.isdir(args.output_dir):
        os.makedirs(args.output_dir)

    # Generate experiences
    frame_index = 0
    num_games_played = 0
    state = exp_gain.get_preprocessed_state()
    while num_games_played < args.num_games:
        # Epsilon-greedy: random action with probability epsilon, otherwise
        # the network's choice for the current state.
        if random.random() < args.epsilon:
            action = random.choice(exp_gain.actions)
        else:
            idx = bnet.select_action([state])
            action = exp_gain.actions[idx]

        exp_gain.play_action(action)

        # Render the latest frame: gray_scale is applied to a batch of one
        # (leading axis of size 1 added) and the single result taken back out.
        last_frame = exp_gain.sequence[-1]
        frame = gray_scale(last_frame.reshape((1,) + last_frame.shape))[-1]
        big_frame = cv2.resize(frame, (0, 0), fx=10, fy=10,
                               interpolation=cv2.INTER_NEAREST)
        frame_path = os.path.join(args.output_dir, "frame-%d.png" % frame_index)
        # cv2.imwrite reports failure by returning False instead of raising;
        # without this check frames would be dropped silently.
        if not cv2.imwrite(frame_path, big_frame):
            raise IOError("could not write frame to %s" % frame_path)
        frame_index += 1

        # Check if Snake has died
        if exp_gain.game_over:
            print("Game Over")
            exp_gain.reset_game()
            num_games_played += 1

        # Get next state
        state = exp_gain.get_preprocessed_state()
    def evaluate(self, model, num_trials):
        """Run |num_trials| games and return the average score.

        Loads ``model`` into every engine's network and resets all engines.
        It then steps them in lockstep: each step, the states of all engines
        go into one ``self.net.select_action`` call. When an engine's game
        ends, its accumulated score counts as one trial and the engine is
        reset. This continues until ``num_trials`` trials have completed.

        Args:
            model: parameters handed to ``set_net_params`` for each engine.
            num_trials: number of finished games to average; must be
                positive.

        Returns:
            The mean score (float) over the first ``num_trials`` games to
            finish.

        Raises:
            ValueError: if ``num_trials`` is not positive. Checked before any
                engine is touched, since the average would divide by zero.
        """
        if num_trials <= 0:
            raise ValueError("num_trials must be positive, got %r" % (num_trials,))

        # NOTE(review): parameters are loaded into each eg.net.net, but the
        # actions below come from self.net; this assumes they are the same
        # network (or share weights) -- confirm.
        for eg in self.engines:
            set_net_params(eg.net.net, model)
            eg.reset_game()

        total_score = 0
        trials_completed = 0
        # One running score per engine, sized by the engines actually
        # iterated below so scores[i] can never index past the list.
        scores = [0] * len(self.engines)
        while trials_completed < num_trials:
            states = [eg.get_preprocessed_state() for eg in self.engines]
            actions = self.net.select_action(states,
                                             batch_size=self.batch_size)
            for i, (action, eg) in enumerate(zip(actions, self.engines)):
                scores[i] += eg.play_action(eg.actions[action])
                if eg.game_over:
                    # Trials are counted in completion order and unfinished
                    # games are discarded at the end, so shorter games are
                    # more likely to be included in the average.
                    total_score += scores[i]
                    trials_completed += 1
                    if trials_completed == num_trials:
                        break
                    eg.reset_game()
                    scores[i] = 0

        return float(total_score) / num_trials