Code Example #1
def main():
    """
    Constants
    """
    k: int = 20
    epsilon: float = 0.01
    init_val: int = 10
    c: float = 1.0
    max_time: int = 1000
    rounds: int = 1000
    policy: Policy = EpsilonGreedyPolicy(epsilon)
    agent = Agent(k, policy)
    bandits = [Bandit() for _ in range(k)]

    def play_round():
        """
        Simulates one round of the game.
        """
        # get the next action
        action = agent.choose_action()
        # get a reward from the bandit
        reward = bandits[action].get_reward()
        # play the action
        agent.play_action(action, reward)
        return reward

    def reset():
        agent.reset()
        for bandit in bandits:
            bandit.reset()
        optimal_bandit = np.argmax([bandit.get_reward() for bandit in bandits])

    def print_bandits():
        for i, bandit in enumerate(bandits):
            print('Bandit {} reward={}'.format(i, bandit.get_reward()))

    def experiment():
        scores = np.zeros(max_time, dtype=float)
        for _ in range(rounds):
            for t in range(max_time):
                scores[t] += play_round()
            reset()

        return scores / rounds

    def plot(label):
        print_bandits()
        scores = experiment()
        time = range(max_time)
        plt.title(label + " for k = " + str(k))
        plt.ylim([0.0, 2.0])
        plt.xlabel('Steps')
        plt.ylabel('Avg. Reward')
        plt.scatter(x=time, y=scores, s=0.5)
        plt.show()

    plot(str(policy))
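
Every example on this page constructs an EpsilonGreedyPolicy without ever showing its definition. For reference only, a minimal sketch of epsilon-greedy action selection over a vector of value estimates could look like the code below; the constructor signature and the choose method are illustrative assumptions, not taken from any of the projects listed here.

import numpy as np


class EpsilonGreedyPolicy:
    """Illustrative sketch only; the classes used in the examples on this
    page have their own project-specific constructors and method names."""

    def __init__(self, epsilon):
        self.epsilon = epsilon

    def choose(self, q_values):
        # Explore with probability epsilon, otherwise exploit the best estimate.
        if np.random.random() < self.epsilon:
            return int(np.random.randint(len(q_values)))
        return int(np.argmax(q_values))

    def __str__(self):
        return "Epsilon-Greedy (epsilon={})".format(self.epsilon)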
Code Example #2
File: train.py  Project: gabrielhuang/connectfour
def q_learn(board_prototype, nepisodes, alpha, gamma, epsilon):
    '''
    Q-Learning using Epsilon-greedy policy
    http://webdocs.cs.ualberta.ca/~sutton/book/ebook/node65.html
    '''
    global Q
    Q = {}
    for episode in xrange(nepisodes):
        # Create empty board with right size
        board = board_prototype.clone()
        for i in range(board.ncols()*board.nrows()):
            q_greedy_policy = QGreedyPolicy(Q)
            eps_greedy_policy = EpsilonGreedyPolicy(q_greedy_policy, epsilon)
            
            color = Board.BLACK if i%2 else Board.RED
            
            old_state = board.to_tuple()      # s
            
            if color == Board.RED:
                board.flip()
            action = eps_greedy_policy.take_action(board) # a
            winner = board.play(color, action)
            reward = get_reward(board, we_are=Board.BLACK) # r_t
            if color == Board.RED:
                board.flip()            
            
            new_state = board.to_tuple()         # s'
            
            Q.setdefault(old_state, {})
            Q[old_state].setdefault(action, 0.)
            current = Q[old_state][action] # Q(s,a)

            Q.setdefault(new_state, {})
            best = max_action(Q[new_state], value_if_empty=0.) # max_a Q(s',a)
            
            # Q(s,a) <- Q(s,a) + alpha * (r_t + gamma * max_a Q(s',a) - Q(s,a))
            Q[old_state][action] = current + alpha * (reward + gamma * best - current)
            if winner != Board.EMPTY:
                break 
    return Q
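
The snippet above calls max_action, which is defined elsewhere in the project. From its usage (gamma * best), it returns the largest action value recorded for a state, falling back to value_if_empty for states with no entries yet. A minimal sketch consistent with that usage (the body is an assumption; the real helper may differ):

def max_action(action_values, value_if_empty=0.):
    # action_values maps action -> estimated value for a single state.
    # Return the largest estimated value, or the fallback for unseen states.
    if not action_values:
        return value_if_empty
    return max(action_values.values())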
Code Example #3
def run_bandit(epsilon, n, num_trials, num_sessions):
    # Runs the bandit for a single epsilon, n
    policy = EpsilonGreedyPolicy(epsilon)
    bandit = GaussianBandit(n)
    agent = Agent(n, policy, num_trials)
    env = Environment(bandit, agent, num_trials, num_sessions)
    rewards, num_best = env.run()

    plot_ave_reward(rewards)
    plt.show()

    plot_percent_best_action(num_best)
    plt.show()
Code Example #4
def compare_epsilons(n, epsilons):
    # Compare various values of n and epsilon

    # maximizer: epsilon = 1, complete exploration
    # satisficer: epsilon = 0, complete exploitation
    rewards = np.zeros((len(epsilons), num_sessions, num_trials))
    num_best = np.zeros((len(epsilons), num_sessions, num_trials))
    ave_reward = np.zeros((len(epsilons), num_trials))
    cum_reward = np.zeros(num_sessions)
    ave_cum_reward = np.zeros((len(epsilons), 2))

    for i in range(len(epsilons)):
        policy = EpsilonGreedyPolicy(epsilons[i])
        bandit = GaussianBandit(n)
        agent = Agent(n, policy, num_trials)
        env = Environment(bandit, agent, num_trials, num_sessions)
        rewards[i, :, :], num_best[i, :, :] = env.run()

    # Compare average reward across values of epsilon
    color = iter(cm.rainbow(np.linspace(0, 1, len(epsilons))))
    for i in range(len(epsilons)):
        c = next(color)
        ave_reward[i, :] = rewards[i, :, :].mean(axis=0)
        plt.plot(ave_reward[i, :], label="Epsilon:" + str(epsilons[i]), c=c)
        plt.title("Average Reward" + ", n: " + str(n))
        plt.xlabel('Trial')
        plt.ylabel('Reward')
        plt.legend(loc="upper left")
        plt.rc('legend', fontsize='x-small')
    plt.show()

    color2 = iter(cm.rainbow(np.linspace(0, 1, len(epsilons))))
    for i in range(len(epsilons)):
        c = next(color2)
        ave_percent_best = num_best[i, :, :].mean(axis=0)
        plt.plot(ave_percent_best, label="Epsilon:" + str(epsilons[i]), c=c)
        plt.title("Average Percent Best Option" + ", n: " + str(n))
        plt.xlabel('Trial')
        plt.ylabel('Percent Best Option')
        plt.legend(loc="upper left")
        plt.rc('legend', fontsize='x-small')
    plt.show()

    for i in range(len(epsilons)):
        for j in range(num_sessions):
            cum_reward[j] = rewards[i, j, :].sum()
        ave_cum_reward[i, :] = [epsilons[i], np.mean(cum_reward)]
    print(np.shape(cum_reward))
    print(np.shape(ave_cum_reward))
    print(ave_cum_reward)
Code Example #5
 def __init__(self, epsilon, lambda_=0.5):
     self.feedSrv = rospy.Service('SuturoMlHeadNextAction', SuturoMlNextAction, self.nextActionCallback)
     self.policyPringPub = rospy.Publisher('SuturoMlPolicy', String, queue_size=10, latch=True)
     self.policy = []
     # self.q = defaultdict(lambda : 10)
     # self.actions = filter(lambda x: x[0].startswith('Const'), [(a,s) for a,s in vars(SuturoMlAction).iteritems()])
     self.actions = ["GRAB-SIDE blue_handle",
                     "GRAB-SIDE red_cube",
                     "TURN",
                     "OPEN-GRIPPER",
                     "GRAB-TOP blue_handle",
                     "GRAB-TOP red_cube",
                     "PLACE-IN-ZONE"]
     self.q = None
     self.policyMaker = EpsilonGreedyPolicy(self.q, self.actions, epsilon)
     # self.policyMaker = Haxx0rPolicy(self.actions)
     self.learner = SarsaLambdaLearner(self.policyMaker, l=lambda_)
     self.q = self.learner.get_q()
     self.policyMaker.updateQ(self.q)
     # self.policyMaker = ReverseGreedyPolicy(self.q, self.actions)
     rospy.wait_for_service('json_prolog/simple_query')
     self.prolog = Prolog()
     print("SuturoMlHeadLearnerPolicyFeeder started.")
Code Example #6
def compare_n(n_list):
    # Compare across values of n
    rewards = np.zeros((len(n_list), num_sessions, num_trials))
    num_best = np.zeros((len(n_list), num_sessions, num_trials))
    cum_reward = np.zeros(num_sessions)
    ave_cum_reward = np.zeros((len(n_list), 2))

    for i in range(len(n_list)):
        policy = EpsilonGreedyPolicy(epsilon)
        bandit = GaussianBandit(n_list[i])
        agent = Agent(n_list[i], policy, num_trials)
        env = Environment(bandit, agent, num_trials, num_sessions)
        rewards[i, :, :], num_best[i, :, :] = env.run()

    # Compare average reward across values of n
    color = iter(cm.rainbow(np.linspace(0, 1, len(n_list))))
    for i in range(len(n_list)):
        c = next(color)
        ave_reward = rewards[i, :, :].mean(axis=0)
        plt.plot(ave_reward, label="n:" + str(n_list[i]), c=c)
        plt.title("Average Reward")
        plt.xlabel('Trial')
        plt.ylabel('Reward')
        plt.legend(loc="upper left")
    plt.show()

    color2 = iter(cm.rainbow(np.linspace(0, 1, len(n_list))))
    for i in range(len(n_list)):
        c = next(color2)
        ave_percent_best = num_best[i, :, :].mean(axis=0)
        plt.plot(ave_percent_best, label="n:" + str(n_list[i]), c=c)
        plt.title("Average Percent Best Option")
        plt.xlabel('Trial')
        plt.ylabel('Percent Best Option')
        plt.legend(loc="upper left")
    plt.show()

    for i in range(len(n_list)):
        for j in range(num_sessions):
            cum_reward[j] = rewards[i, j, :].sum()
        ave_cum_reward[i, :] = [n_list[i], np.mean(cum_reward)]
    print(np.shape(cum_reward))
    print(np.shape(ave_cum_reward))
    print(ave_cum_reward)
Code Example #7
File: joint2.py  Project: rbcommits/RobotLearning
def main():
    # define arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--render",
                        action="store_true",
                        help="Render the state")
    parser.add_argument("--render_interval",
                        type=int,
                        default=10,
                        help="Number of rollouts to skip before rendering")
    parser.add_argument("--num_rollouts",
                        type=int,
                        default=-1,
                        help="Number of max rollouts")
    parser.add_argument("--logfile",
                        type=str,
                        help="Indicate where to save rollout data")
    parser.add_argument(
        "--load_params",
        type=str,
        help="Load previously learned parameters from [LOAD_PARAMS]")
    parser.add_argument("--save_params",
                        type=str,
                        help="Save learned parameters to [SAVE_PARAMS]")
    args = parser.parse_args()

    signal.signal(signal.SIGINT, stopsigCallback)
    global stopsig

    # create the basketball environment
    env = BasketballVelocityEnv(fps=60.0,
                                timeInterval=0.1,
                                goal=[0, 5, 0],
                                initialLengths=np.array([0, 0, 1, 1, 0, 0, 0]),
                                initialAngles=np.array([0, 45, 0, 0, 0, 0, 0]))

    # create space
    stateSpace = ContinuousSpace(ranges=env.state_range())
    actionRange = env.action_range()
    actionSpace = DiscreteSpace(
        intervals=[15 for i in range(2)] + [1],
        ranges=[actionRange[1], actionRange[2], actionRange[7]])
    processor = JointProcessor(actionSpace)

    # create the model and policy functions
    modelFn = MxFullyConnected(sizes=[stateSpace.n + actionSpace.n, 64, 32, 1],
                               alpha=0.001,
                               use_gpu=True)
    if args.load_params:
        print("loading params...")
        modelFn.load_params(args.load_params)

    softmax = lambda s: np.exp(s) / np.sum(np.exp(s))
    policyFn = EpsilonGreedyPolicy(
        epsilon=0.5,
        getActionsFn=lambda state: actionSpace.sample(1024),
        distributionFn=lambda qstate: softmax(modelFn(qstate)))
    dataset = ReplayBuffer()
    if args.logfile:
        log = open(args.logfile, "a")

    rollout = 0
    while args.num_rollouts == -1 or rollout < args.num_rollouts:
        print("Iteration:", rollout)
        state = env.reset()
        reward = 0
        done = False
        steps = 0
        while not done:
            if stopsig:
                break
            action = policyFn(state)
            nextState, reward, done, info = env.step(
                createAction(processor.process_env_action(action)))
            dataset.append(state, action, reward, nextState)
            state = nextState
            steps += 1
            if args.render and rollout % args.render_interval == 0:
                env.render()
        if stopsig:
            break

        dataset.reset()  # push trajectory into the dataset buffer
        modelFn.fit(processor.process_Q(dataset.sample(1024)), num_epochs=10)
        print("Reward:", reward if (reward >= 0.00001) else 0, "with Error:",
              modelFn.score(), "with steps:", steps)
        if args.logfile:
            log.write("[" + str(rollout) + ", " + str(reward) + ", " +
                      str(modelFn.score()) + "]\n")

        rollout += 1
        if rollout % 100 == 0:
            policyFn.epsilon *= 0.95
            print("Epsilon is now:", policyFn.epsilon)

    if args.logfile:
        log.close()
    if args.save_params:
        print("saving params...")
        modelFn.save_params(args.save_params)
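
One caveat about the example above: the inline softmax = lambda s: np.exp(s) / np.sum(np.exp(s)) overflows once the model's scores grow large. A numerically stable variant (not part of the original project) subtracts the maximum before exponentiating:

import numpy as np

def stable_softmax(s):
    # Subtracting the max leaves the result unchanged but keeps exp() bounded.
    z = np.exp(np.asarray(s) - np.max(s))
    return z / np.sum(z)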
Code Example #8
	
	return episode


if __name__ == '__main__':
	track = read_track('track.bmp')
	env = TrackEnvironment(track)

	Q = QFunction(track.size(), VELOCITY_RANGE)
	Pi = build_max_policy(Q)
	
	rewards = []
	best_reward_for_report_range = -10000
	
	for i in range(TRAIN_STEPS):
		soft_policy = EpsilonGreedyPolicy(EPSILON, Pi, NUM_ACTIONS)
		episode_data = rollout(env, soft_policy)
		episode = episode_data.get_steps()
		rewards.append(episode_data.get_total_reward())
		
		best_reward_for_report_range = max(best_reward_for_report_range, episode_data.get_total_reward())
		
		if i % REPORT_EVERY == 0:
			print("Training step {}...".format(i))
			print("Best reward: {}".format(best_reward_for_report_range))
			best_reward_for_report_range = -1000
			print("Average reward: {}".format(np.average(rewards[max(0, i - REPORT_EVERY):])))
		
		T = len(episode) - 1
		
		G = 0.0
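
The snippet is cut off right after G = 0.0. In a typical Monte Carlo control loop driven by an epsilon-greedy behaviour policy, the remainder walks the episode backwards, accumulates the discounted return, updates the action values, and rebuilds the greedy policy. The lines below are only a sketch of that usual pattern; the (state, action, reward) step layout, the GAMMA constant, and the Q.update method are assumptions rather than the project's actual code.

	# Sketch of the usual continuation (all names below are assumptions):
	for t in range(T - 1, -1, -1):
		state, action, reward = episode[t]
		G = GAMMA * G + reward        # discounted return following step t
		Q.update(state, action, G)    # e.g. running average of observed returns
	Pi = build_max_policy(Q)          # greedy improvement before the next rollout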
Code Example #9
             'C': [(c0, ), (c0, ), (c0, ), (2 * c0, )]
         }
     else:
         kwargs = {
             'D': D,
             'M': M,
             'learning_rate': args.alpha,
             'F': [(1, 1)],
             'c0': c0,
             'C': [(c0, )]
         }
     graph_f = cnn.build_graph
 else:
     kwargs = {'D': D, 'M': M, 'learning_rate': args.alpha}
     graph_f = ann.build_graph
 pol = EpsilonGreedyPolicy(eps=1.0, decay_f=decay_f)
 pol.n = args.eps_start
 if args.mode == 'leaf':
     sv = tdleaf.TDLeafSupervisor(pol,
                                  mv_limit=args.move_count,
                                  depth=args.depth,
                                  y=args.gamma,
                                  l=args.lambd)
 else:
     sv = tdstem.TDStemSupervisor(pol,
                                  mv_limit=args.move_count,
                                  depth=args.depth,
                                  y=args.gamma,
                                  l=args.lambd)
 sv.run(args.I,
        args.N,
Code Example #10
File: dqn.py  Project: timrobot/ArmRL
def main():
    # define arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--render",
                        action="store_true",
                        help="Render the state")
    parser.add_argument("--render_interval",
                        type=int,
                        default=10,
                        help="Number of rollouts to skip before rendering")
    parser.add_argument("--num_rollouts",
                        type=int,
                        default=-1,
                        help="Number of max rollouts")
    parser.add_argument("--logfile",
                        type=str,
                        help="Indicate where to save rollout data")
    parser.add_argument(
        "--load_params",
        type=str,
        help="Load previously learned parameters from [LOAD_PARAMS]")
    parser.add_argument("--save_params",
                        type=str,
                        help="Save learned parameters to [SAVE_PARAMS]")
    parser.add_argument("--silent",
                        action="store_true",
                        help="Suppress print of the DQN config")
    parser.add_argument("--gamma",
                        type=float,
                        default=0.99,
                        help="Discount factor")
    parser.add_argument("--epsilon",
                        type=float,
                        default=0.1,
                        help="Random factor (for Epsilon-greedy)")
    parser.add_argument("--eps_anneal",
                        type=int,
                        default=0,
                        help="The amount of episodes to anneal epsilon by")
    parser.add_argument("--sample_size",
                        type=int,
                        default=256,
                        help="Number of samples from the dataset per episode")
    parser.add_argument("--num_epochs",
                        type=int,
                        default=50,
                        help="Number of epochs to run per episode")
    parser.add_argument("--episode_length",
                        type=int,
                        default=128,
                        help="Number of rollouts per episode")
    parser.add_argument("--noise",
                        type=float,
                        help="Amount of noise to add to the actions")
    parser.add_argument("--test", action="store_true", help="Test the params")
    args = parser.parse_args()

    signal.signal(signal.SIGINT, stopsigCallback)
    global stopsig

    # create the basketball environment
    env = BasketballVelocityEnv(fps=60.0,
                                timeInterval=1.0,
                                goal=[0, 5, 0],
                                initialLengths=np.array([0, 0, 1, 1, 1, 0, 1]),
                                initialAngles=np.array(
                                    [0, 45, -20, -20, 0, -20, 0]))

    # create space
    stateSpace = ContinuousSpace(ranges=env.state_range())
    actionSpace = DiscreteSpace(intervals=[25 for i in range(7)] + [1],
                                ranges=env.action_range())
    processor = DQNProcessor(actionSpace)

    # create the model and policy functions
    modelFn = DQNNetwork(
        sizes=[stateSpace.n + actionSpace.n, 128, 256, 256, 128, 1],
        alpha=0.001,
        use_gpu=True,
        momentum=0.9)
    if args.load_params:
        print("Loading params...")
        modelFn.load_params(args.load_params)

    allActions = actionSpace.sampleAll()
    policyFn = EpsilonGreedyPolicy(
        epsilon=args.epsilon if not args.test else 0,
        getActionsFn=lambda state: allActions,
        distributionFn=lambda qstate: modelFn(qstate),
        processor=processor)
    replayBuffer = RingBuffer(max_limit=2048)
    if args.logfile:
        log = open(args.logfile, "a")

    if not args.silent:
        print("Env space range:", env.state_range())
        print("Env action range:", env.action_range())
        print("State space:", stateSpace.n)
        print("Action space:", actionSpace.n)
        print("Action space bins:", actionSpace.bins)
        print("Epsilon:", args.epsilon)
        print("Epsilon anneal episodes:", args.eps_anneal)
        print("Gamma:", args.gamma)
        __actionShape = policyFn.getActions(None).shape
        totalActions = np.prod(actionSpace.bins)
        print("Actions are sampled:", __actionShape[0] != totalActions)
        print("Number of actions:", totalActions)

    rollout = 0
    if not args.silent and not args.test:
        iterationBar = ProgressBar(maxval=args.episode_length)
    while args.num_rollouts == -1 or rollout < args.num_rollouts:
        if stopsig: break
        if not args.silent and not args.test:
            iterationBar.printProgress(rollout % args.episode_length,
                                       prefix="Query(s,a,s',r)",
                                       suffix="epsilon: " +
                                       str(policyFn.epsilon))
        state = env.reset()
        reward = 0
        done = False
        steps = 0
        while not done and steps < 5:  # 5 step max
            action = policyFn(state)
            if steps == 4:  # throw immediately
                action[-2] = 0
                action[-1] = 1
            envAction = processor.process_env_action(action)
            if args.noise:
                envAction[:7] += np.random.normal(scale=np.ones([7]) *
                                                  args.noise)
            nextState, reward, done, info = env.step(envAction)
            replayBuffer.append([state, action, nextState, reward, done])
            if args.test and done: print("Reward:", reward)
            state = nextState
            steps += 1
            if args.render and (rollout + 1) % args.render_interval == 0:
                env.render()

        rollout += 1
        if args.eps_anneal > 0:  # linear anneal
            epsilon_diff = args.epsilon - min(0.1, args.epsilon)
            policyFn.epsilon = args.epsilon - min(rollout, args.eps_anneal) / \
                float(args.eps_anneal) * epsilon_diff

        if rollout % args.episode_length == 0 and not args.test:
            dataset = replayBuffer.sample(args.sample_size)
            states = np.array([d[0] for d in dataset])
            actions = np.array([d[1] for d in dataset])
            nextStates = [d[2] for d in dataset]
            rewards = np.array([[d[3]]
                                for d in dataset])  # rewards require extra []
            terminal = [d[4] for d in dataset]

            QS0 = processor.process_Qstate(states, actions)
            Q1 = np.zeros(rewards.shape, dtype=np.float32)
            if not args.silent:
                progressBar = ProgressBar(maxval=len(nextStates))
            for i, nextState in enumerate(nextStates):
                if stopsig: break
                if not args.silent:
                    progressBar.printProgress(i,
                                              prefix="Creating Q(s,a)",
                                              suffix="%s / %s" %
                                              (i + 1, len(nextStates)))
                if terminal[i]: continue  # 0
                dist = modelFn(
                    processor.process_Qstate(
                        repmat(nextState, allActions.shape[0], 1), allActions))
                Q1[i, 0] = np.max(dist)  # max[a' in A]Q(s', a')
            if stopsig: break
            Q0_ = rewards + args.gamma * Q1
            modelFn.fit({
                "qstates": QS0,
                "qvalues": Q0_
            },
                        num_epochs=args.num_epochs)

            avgQ = np.sum(Q0_) / Q0_.shape[0]
            avgR = np.sum(rewards) / rewards.shape[0]
            print("Rollouts:", rollout, "Error:", modelFn.score(),
                  "Average Q:", avgQ, "Average R:", avgR)
            print("")
            if args.logfile:
                log.write("[" + str(rollout) + ", " + str(modelFn.score()) +
                          ", " + str(avgQ) + ", " + str(avgR) + "]\n")

    if args.logfile:
        log.close()
    if args.save_params:
        print("Saving params...")
        modelFn.save_params(args.save_params)
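
The training step above builds the one-step DQN target r + gamma * max_a' Q(s', a'), with the bootstrap term zeroed for terminal transitions (Q1 stays 0 wherever terminal[i] is true). Stripped of the project-specific processors, that computation reduces to roughly the sketch below (function and argument names are invented for illustration):

import numpy as np

def dqn_targets(rewards, next_q_max, terminal, gamma=0.99):
    # rewards and next_q_max have shape (batch, 1); terminal is a sequence of bools.
    # Terminal transitions drop the bootstrap term max_a' Q(s', a').
    mask = 1.0 - np.asarray(terminal, dtype=np.float32).reshape(-1, 1)
    return rewards + gamma * mask * next_q_max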
Code Example #11
class SuturoMlHeadLearner(object):

    def __init__(self, epsilon, lambda_=0.5):
        self.feedSrv = rospy.Service('SuturoMlHeadNextAction', SuturoMlNextAction, self.nextActionCallback)
        self.policyPringPub = rospy.Publisher('SuturoMlPolicy', String, queue_size=10, latch=True)
        self.policy = []
        # self.q = defaultdict(lambda : 10)
        # self.actions = filter(lambda x: x[0].startswith('Const'), [(a,s) for a,s in vars(SuturoMlAction).iteritems()])
        self.actions = ["GRAB-SIDE blue_handle",
                        "GRAB-SIDE red_cube",
                        "TURN",
                        "OPEN-GRIPPER",
                        "GRAB-TOP blue_handle",
                        "GRAB-TOP red_cube",
                        "PLACE-IN-ZONE"]
        self.q = None
        self.policyMaker = EpsilonGreedyPolicy(self.q, self.actions, epsilon)
        # self.policyMaker = Haxx0rPolicy(self.actions)
        self.learner = SarsaLambdaLearner(self.policyMaker, l=lambda_)
        self.q = self.learner.get_q()
        self.policyMaker.updateQ(self.q)
        # self.policyMaker = ReverseGreedyPolicy(self.q, self.actions)
        rospy.wait_for_service('json_prolog/simple_query')
        self.prolog = Prolog()
        print("SuturoMlHeadLearnerPolicyFeeder started.")


    def nextActionCallback(self, nextActionRequest):
        r = SuturoMlNextActionResponse()
        r.action.action = self.policyMaker.getNextAction(nextActionRequest.state)
        return r

    def doTheShit(self):
        q = self.prolog.query("suturo_learning:get_learning_sequence(A)")
        print("start learning")
        for solution in q.solutions():
            # print sol
            policy = solution["A"]
            self.q = self.learner.learn(policy)
            self.policyMaker.updateQ(self.q)
        print("learning done.\n")

        ppp = defaultdict(lambda : ((-999999999999,-999999999999),))
        tmp_q = deepcopy(self.q)
        for s in self.q.iterkeys():
            state = s[0]
            for a in self.actions:
                b = tmp_q[(state,a)]
                if ppp[state][0][1] == b and (a, b) not in ppp[state]:
                    ppp[state] = ppp[state] + ((a,b),)
                elif ppp[state][0][1] < b:
                    ppp[state] = ((a,b),)

        # for a,b in self.q.iteritems():
        #     print a,b
        muh = []
        for a,b in ppp.iteritems():
            muh.append((a, b))

        def cmpmuh(x,y):
            for i in range(len(x[0])):
                if x[0][i] > y[0][i]:
                    return 1
                elif x[0][i] < y[0][i]:
                    return -1
            if x[1][1] > y[1][1]:
                return 1
            elif x[1][1] < y[1][1]:
                return -1
            return 0

        muh.sort(cmp=cmpmuh)
        msg = ""
        for a in muh:
            print a
            msg += str(a[0]) +"\n" +str(a[1]) + "\n\n"
        s = String(msg)
        self.policyPringPub.publish(s)

    def pub_policy(self):
        pass
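
The greedy-policy extraction inside doTheShit relies on Python 2 constructs (iterkeys, iteritems, sorting with cmp=). A compact Python 3 sketch of the same idea, assuming the same (state, action) -> value keying of self.q, might look like this:

def best_actions_per_state(q, actions, default=-999999999999):
    # q maps (state, action) -> value; keep every action tied for the maximum value.
    best = {}
    for state in {s for s, _ in q}:
        values = [(a, q.get((state, a), default)) for a in actions]
        top = max(v for _, v in values)
        best[state] = tuple((a, v) for a, v in values if v == top)
    return best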
Code Example #12
                        type=float)
    parser.add_argument('--lambd', default=0.7, help='lambda', type=float)
    parser.add_argument('--eps-start',
                        default=0,
                        help='epsilon decay',
                        type=int)
    args = parser.parse_args()

    if args.old_model is None:
        D = faster_featurize('8/6k1/2R5/8/3K4/8/8/8 w - -').shape[1]
        print D
        M = [int(m) for m in args.M.split()]
        kwargs = {'D': D, 'M': M, 'learning_rate': 1e-4}
        graph_f = ann.build_graph
        pol = EpsilonGreedyPolicy(eps=1.0,
                                  decay_f=lambda n:
                                  (n + 1)**(-args.eps_factor))
        pol.n = args.eps_start
        sv = TDStemSupervisor(pol,
                              mv_limit=args.move_count,
                              depth=args.depth,
                              y=args.gamma,
                              l=args.lambd)
        sv.run(args.I,
               args.N,
               graph_f,
               kwargs,
               state=args.state,
               name=args.name)
    else:
        assert args.ckpt is not None