Example #1
0
    def update(self, labels, preds):
        """Accumulate mismatch counts between predictions and labels.

        For each (label, prediction) pair, a prediction whose shape differs
        from its label is first reduced with ``ndarray.argmax_channel``. Both
        arrays are then cast to int32 and flattened. Positions whose label
        equals -1 are skipped. ``sum_metric`` grows by the number of remaining
        positions where prediction and label differ. ``num_inst`` grows by the
        number of remaining positions.
        """
        for label, pred in zip(labels, preds):
            if pred.shape != label.shape:
                # Raw scores: collapse to class indices before comparing.
                pred = ndarray.argmax_channel(pred)
            pred_flat = pred.asnumpy().astype('int32').flatten()
            label_flat = label.asnumpy().astype('int32').flatten()

            keep = np.where(label_flat != -1)
            kept_pred = pred_flat[keep]
            self.sum_metric += (kept_pred != label_flat[keep]).sum()
            self.num_inst += len(kept_pred)
Example #2
0
    def update(self, labels, preds):
        """Accumulate accuracy over the positions whose label is non-zero.

        Each prediction is reduced to class indices with
        ``ndarray.argmax_channel`` and compared element-wise (in flattened
        order) with its label. Positions where the label is 0 are excluded.
        Presumably 0 is a padding/background id; confirm against the data
        pipeline.
        """
        # Raises if the number of labels and predictions differ, which also
        # makes the zip() below equivalent to indexing both lists.
        mx.metric.check_label_shapes(labels, preds)

        for label_nd, pred_nd in zip(labels, preds):
            pred_label = ndarray.argmax_channel(pred_nd).asnumpy().astype('int32')
            label = label_nd.asnumpy().astype('int32')

            mx.metric.check_label_shapes(label, pred_label)

            # Flat indices of the labelled (non-zero) positions.
            ind = np.flatnonzero(label)
            pred_label_real = pred_label.ravel()[ind]
            label_real = label.ravel()[ind]
            self.sum_metric += (pred_label_real == label_real).sum()
            self.num_inst += len(pred_label_real)
Example #3
0
    def update(self, labels, preds):
        """Add correct predictions at non-zero label positions to the metric.

        Predictions are turned into class ids via ``ndarray.argmax_channel``.
        Only flattened positions where the ground-truth label is non-zero
        contribute to ``sum_metric`` (matches) and ``num_inst`` (count).
        """
        mx.metric.check_label_shapes(labels, preds)

        for pos, label_nd in enumerate(labels):
            class_ids = ndarray.argmax_channel(
                preds[pos]).asnumpy().astype('int32')
            truth = label_nd.asnumpy().astype('int32')

            mx.metric.check_label_shapes(truth, class_ids)

            nonzero_pos = np.nonzero(truth.flat)
            pred_kept = class_ids.flat[nonzero_pos]
            truth_kept = truth.flat[nonzero_pos]
            self.sum_metric += (pred_kept == truth_kept).sum()
            self.num_inst += len(pred_kept)
    def update(self, labels, preds):
        """Accumulate pixel accuracy, ignoring positions labelled 255.

        A prediction whose shape differs from its label is reduced to class
        indices with ``ndarray.argmax_channel``. Prediction and label are then
        both cast to int32 and flattened per sample to ``(batch, -1)``. This
        makes the validity mask built from the label line up
        element-for-element with the prediction. ``sum_metric`` counts matches
        at valid positions; ``num_inst`` counts valid positions.
        """
        check_label_shapes(labels, preds)

        for label, pred_label in zip(labels, preds):
            if pred_label.shape != label.shape:
                pred_label = ndarray.argmax_channel(pred_label)
            pred_label = pred_label.asnumpy().astype(np.int32).reshape(
                pred_label.shape[0], -1)
            # Flatten the label exactly like the prediction. Otherwise the two
            # arrays disagree:
            #  - a (N,) label vs a (N, 1) prediction makes the mask pick whole
            #    rows and the comparison broadcast to a (k, k) matrix;
            #  - a (N, H, W) label vs a (N, H*W) prediction raises IndexError.
            label = label.asnumpy().astype(np.int32).reshape(label.shape[0], -1)

            check_label_shapes(label, pred_label)

            # 255 is treated as the "ignore" label.
            valid_index = label != 255
            self.sum_metric += (
                label[valid_index] == pred_label[valid_index]).sum()
            self.num_inst += valid_index.sum()
    def update(self, labels, preds):
        """Accumulate plain element-wise accuracy.

        Each prediction is reduced to class indices with
        ``ndarray.argmax_channel`` and each label is flattened to 1-D. Every
        position counts; no label value is masked out.
        """
        check_label_shapes(labels, preds)

        for truth_nd, score_nd in zip(labels, preds):
            guess = ndarray.argmax_channel(score_nd).asnumpy().astype('int32')
            flat_truth = truth_nd.reshape((truth_nd.size,))
            truth = flat_truth.asnumpy().astype('int32')

            check_label_shapes(truth, guess)

            matches = guess.flat == truth.flat
            self.sum_metric += matches.sum()
            self.num_inst += len(guess.flat)
Example #6
0
    def train_one_batch(self, st, stpo, at, rt, tt, unfreeze_weight=False):
        """Run one Double-DQN training step on a minibatch and return its loss.

        st : state_at_time_t
        stpo : state_at_time_t_plus_one
        at : action at time t
        rt : reward at time t --> instant reward
        tt : termination of the game at time t : 0 or 1
        unfreeze_weight : when True, copy the online weights into the frozen
            target network after this update

        The Double-DQN target is
            y = rt + gamma * (1 - tt) * Q_target(stpo, argmax_a Q_online(stpo, a))
        and the online network is trained on the action actually taken, ``at``.
        Returns ``loss[0]``, where ``loss`` is the summed squared TD error
        computed from the Q values of the training forward pass (i.e. before
        the weight update).
        """
        # 1 : get the y_ddqn
        # a) forward on the q vector (net with theta -)
        target_q = self.target_q_mod.forward(is_train=False, data=stpo)
        # b) retrieve a* with theta_i param : a = argmax Q(stpo, theta_i)
        a_q = self.loss_q_mod.forward(is_train=False, data=stpo)
        a = nd.argmax_channel(a_q[0])
        # c) combine q vec with the a*. A terminal transition (tt == 1) has no
        #    future return, so the bootstrap term is masked with (1 - tt).
        y_ddqn = rt + self.gamma * (1.0 - tt) * nd.choose_element_0index(
            target_q[0], a)

        # 2 : build the loss on the current batch
        # The loss is | y_ddqn - Q(st, at) |: it must be evaluated on the action
        # taken in st (``at``), not on a*, which is the greedy action for stpo.
        current_q = self.loss_q_mod.forward(is_train=True,
                                            data=st,
                                            loss_action=at,
                                            loss_target=y_ddqn)
        self.loss_q_mod.backward()

        # 3 : Update parameters
        self.update_weights(self.loss_q_mod, self.updater)

        # 4 : Calculate the loss
        loss = nd.sum((y_ddqn - nd.choose_element_0index(current_q[0], at))**2,
                      axis=0)

        # 5 : the occasional forward weight updater
        if unfreeze_weight:
            self.copy_to_freezed_network()

        return loss[0]
Example #7
0
def calculate_avg_reward(game, qnet, test_steps=125000, exploartion=0.05):
    """Evaluate ``qnet`` on ``game`` and return the mean reward per episode.

    Episodes are played back to back until ``test_steps`` game steps have been
    consumed. While ``game.state_enabled`` is true the agent acts
    epsilon-greedily: a random action with probability ``exploartion``,
    otherwise the arg-max of ``qnet``'s Q values. When the state is not
    enabled, it acts uniformly at random. Per-episode FPS, remaining steps and
    reward are printed.

    Returns 0.0 when no episode is played (``test_steps <= 0``) instead of
    dividing by zero.
    """
    game.force_restart()
    action_num = len(game.action_set)
    total_reward = 0
    steps_left = test_steps
    episode = 0
    while steps_left > 0:
        # Running New Episode
        episode += 1
        game.begin_episode(steps_left)
        start = time.time()
        while not game.episode_terminate:
            # 1. We need to choose a new action based on the current game status
            if game.state_enabled:
                do_exploration = (npy_rng.rand() < exploartion)
                if do_exploration:
                    action = npy_rng.randint(action_num)
                else:
                    # TODO Here we can in fact play multiple gaming instances simultaneously and make actions for each
                    # We can simply stack the current_state() of gaming instances and give prediction for all of them
                    # We need to wait after calling calc_score(.), which makes the program slow
                    # TODO Profiling the speed of this part!
                    current_state = game.current_state()
                    state = nd.array(
                        current_state.reshape((1, ) + current_state.shape),
                        ctx=qnet.ctx) / float(255.0)
                    action = int(
                        nd.argmax_channel(
                            qnet.forward(is_train=False,
                                         data=state)[0]).asscalar())
            else:
                action = npy_rng.randint(action_num)

            # 2. Play the game for a single mega-step (Inside the game, the action may be repeated for several times)
            game.play(action)
        end = time.time()
        steps_left -= game.episode_step
        elapsed = end - start
        # A very short episode can finish within the timer resolution.
        fps = game.episode_step / elapsed if elapsed > 0 else float('inf')
        print('Episode:%d, FPS:%s, Steps Left:%d, Reward:%d' \
              % (episode, fps, steps_left, game.episode_reward))
        total_reward += game.episode_reward
    if episode == 0:
        return 0.0
    avg_reward = total_reward / float(episode)
    return avg_reward
Example #8
0
def calculate_avg_reward(game, qnet, test_steps=125000, exploartion=0.05):
    """Evaluate ``qnet`` on ``game`` and return the mean reward per episode.

    Episodes are played back to back until ``test_steps`` game steps have been
    consumed. While ``game.state_enabled`` is true the agent acts
    epsilon-greedily: a random action with probability ``exploartion``,
    otherwise the arg-max of ``qnet``'s Q values. When the state is not
    enabled, it acts uniformly at random. Per-episode FPS, remaining steps and
    reward are printed.

    Returns 0.0 when no episode is played (``test_steps <= 0``) instead of
    dividing by zero.
    """
    game.force_restart()
    action_num = len(game.action_set)
    total_reward = 0
    steps_left = test_steps
    episode = 0
    while steps_left > 0:
        # Running New Episode
        episode += 1
        game.begin_episode(steps_left)
        start = time.time()
        while not game.episode_terminate:
            # 1. We need to choose a new action based on the current game status
            if game.state_enabled:
                do_exploration = (npy_rng.rand() < exploartion)
                if do_exploration:
                    action = npy_rng.randint(action_num)
                else:
                    # TODO Here we can in fact play multiple gaming instances simultaneously and make actions for each
                    # We can simply stack the current_state() of gaming instances and give prediction for all of them
                    # We need to wait after calling calc_score(.), which makes the program slow
                    # TODO Profiling the speed of this part!
                    current_state = game.current_state()
                    state = nd.array(current_state.reshape((1,) + current_state.shape),
                                     ctx=qnet.ctx) / float(255.0)
                    # asscalar() yields a float; actions are integer indices
                    # (the random branches produce ints via randint).
                    action = int(nd.argmax_channel(
                        qnet.forward(is_train=False, data=state)[0]).asscalar())
            else:
                action = npy_rng.randint(action_num)

            # 2. Play the game for a single mega-step (Inside the game, the action may be repeated for several times)
            game.play(action)
        end = time.time()
        steps_left -= game.episode_step
        elapsed = end - start
        # A very short episode can finish within the timer resolution.
        fps = game.episode_step / elapsed if elapsed > 0 else float('inf')
        print('Episode:%d, FPS:%s, Steps Left:%d, Reward:%d' \
              % (episode, fps, steps_left, game.episode_reward))
        total_reward += game.episode_reward
    if episode == 0:
        return 0.0
    avg_reward = total_reward / float(episode)
    return avg_reward
Example #9
0
def main():
    """Train a DQN agent on an Atari ROM with Nature-DQN hyper-parameters.

    Command-line options choose the ROM, the optimizer and its settings,
    whether to use Double DQN, and an optional kvstore (plain or EASGD) for
    distributed training. Per-episode statistics are logged, and the
    Q-network parameters are saved to ``--dir-path`` at the end of every epoch.
    """

    def _parse_bool(text):
        # ``type=bool`` would turn every non-empty string -- including "False"
        # and "0" -- into True, so ``--double-q False`` enabled Double DQN.
        # Parse the usual spellings explicitly instead.
        value = str(text).strip().lower()
        if value in ('1', 'true', 't', 'yes', 'y'):
            return True
        if value in ('0', 'false', 'f', 'no', 'n', ''):
            return False
        raise argparse.ArgumentTypeError(
            'expected a boolean value, got %r' % (text, ))

    parser = argparse.ArgumentParser(
        description='Script to test the trained network on a game.')
    parser.add_argument('-r',
                        '--rom',
                        required=False,
                        type=str,
                        default=os.path.join('arena', 'games', 'roms',
                                             'breakout.bin'),
                        help='Path of the ROM File.')
    parser.add_argument('-v',
                        '--visualization',
                        required=False,
                        type=int,
                        default=0,
                        help='Visualize the runs.')
    parser.add_argument('--lr',
                        required=False,
                        type=float,
                        default=0.01,
                        help='Learning rate of the AdaGrad optimizer')
    parser.add_argument('--eps',
                        required=False,
                        type=float,
                        default=0.01,
                        help='Eps of the AdaGrad optimizer')
    parser.add_argument('--clip-gradient',
                        required=False,
                        type=float,
                        default=None,
                        help='Clip threshold of the AdaGrad optimizer')
    parser.add_argument('--double-q',
                        required=False,
                        type=_parse_bool,
                        default=False,
                        help='Use Double DQN')
    parser.add_argument('--wd',
                        required=False,
                        type=float,
                        default=0.0,
                        help='Weight of the L2 Regularizer')
    parser.add_argument(
        '-c',
        '--ctx',
        required=False,
        type=str,
        default='gpu',
        help='Running Context. E.g `-c gpu` or `-c gpu1` or `-c cpu`')
    parser.add_argument('-d',
                        '--dir-path',
                        required=False,
                        type=str,
                        default='',
                        help='Saving directory of model files.')
    parser.add_argument(
        '--start-eps',
        required=False,
        type=float,
        default=1.0,
        help='Eps of the epsilon-greedy policy at the beginning')
    parser.add_argument('--replay-start-size',
                        required=False,
                        type=int,
                        default=50000,
                        help='The step that the training starts')
    parser.add_argument(
        '--kvstore-update-period',
        required=False,
        type=int,
        default=1,
        help='The period that the worker updates the parameters from the sever'
    )
    parser.add_argument(
        '--kv-type',
        required=False,
        type=str,
        default=None,
        help=
        'type of kvstore, default will not use kvstore, could also be dist_async'
    )
    parser.add_argument('--optimizer',
                        required=False,
                        type=str,
                        default="adagrad",
                        help='type of optimizer')
    args = parser.parse_args()

    if args.dir_path == '':
        rom_name = os.path.splitext(os.path.basename(args.rom))[0]
        args.dir_path = 'dqn-%s-lr%g' % (rom_name, args.lr)
    replay_start_size = args.replay_start_size
    max_start_nullops = 30
    replay_memory_size = 1000000
    history_length = 4
    rows = 84
    cols = 84

    ctx = parse_ctx(args.ctx)
    q_ctx = mx.Context(*ctx[0])

    game = AtariGame(rom_path=args.rom,
                     resize_mode='scale',
                     replay_start_size=replay_start_size,
                     resized_rows=rows,
                     resized_cols=cols,
                     max_null_op=max_start_nullops,
                     replay_memory_size=replay_memory_size,
                     display_screen=args.visualization,
                     history_length=history_length)

    ##RUN NATURE
    freeze_interval = 10000
    epoch_num = 200
    steps_per_epoch = 250000
    update_interval = 4
    discount = 0.99

    eps_start = args.start_eps
    eps_min = 0.1
    eps_decay = (eps_start - eps_min) / 1000000
    eps_curr = eps_start
    # The target network is refreshed every ``freeze_interval`` game steps,
    # i.e. every ``freeze_interval // update_interval`` training steps.
    # Integer division keeps it an int on both Python 2 and 3.
    freeze_interval //= update_interval
    minibatch_size = 32
    action_num = len(game.action_set)

    data_shapes = {
        'data': (minibatch_size, history_length) + (rows, cols),
        'dqn_action': (minibatch_size, ),
        'dqn_reward': (minibatch_size, )
    }
    dqn_sym = dqn_sym_nature(action_num)
    qnet = Base(data_shapes=data_shapes,
                sym_gen=dqn_sym,
                name='QNet',
                initializer=DQNInitializer(factor_type="in"),
                ctx=q_ctx)
    target_qnet = qnet.copy(name="TargetQNet", ctx=q_ctx)

    use_easgd = False
    if args.optimizer != "easgd":
        optimizer = mx.optimizer.create(name=args.optimizer,
                                        learning_rate=args.lr,
                                        eps=args.eps,
                                        clip_gradient=args.clip_gradient,
                                        rescale_grad=1.0,
                                        wd=args.wd)
    else:
        use_easgd = True
        easgd_beta = 0.9
        easgd_p = 4
        easgd_alpha = easgd_beta / (args.kvstore_update_period * easgd_p)
        server_optimizer = mx.optimizer.create(name="ServerEASGD",
                                               learning_rate=easgd_alpha)
        local_optimizer = mx.optimizer.create(name='adagrad',
                                              learning_rate=args.lr,
                                              eps=args.eps,
                                              clip_gradient=args.clip_gradient,
                                              rescale_grad=1.0,
                                              wd=args.wd)
        central_weight = OrderedDict([(n, nd.zeros(v.shape, ctx=q_ctx))
                                      for n, v in qnet.params.items()])
    # Create KVStore
    if args.kv_type is not None:
        kv = kvstore.create(args.kv_type)

        #Initialize KVStore
        for idx, v in enumerate(qnet.params.values()):
            kv.init(idx, v)

        # Set Server optimizer on KVStore
        if not use_easgd:
            kv.set_optimizer(optimizer)
        else:
            kv.set_optimizer(server_optimizer)
            local_updater = mx.optimizer.get_updater(local_optimizer)
        kvstore_update_period = args.kvstore_update_period
        args.dir_path = args.dir_path + "-" + str(kv.rank)
    else:
        # ``optimizer`` only exists on the non-EASGD path; without a kvstore
        # EASGD reduces to plain updates with its local optimizer.
        updater = mx.optimizer.get_updater(
            local_optimizer if use_easgd else optimizer)

    qnet.print_stat()
    target_qnet.print_stat()

    # Begin Playing Game
    training_steps = 0
    total_steps = 0
    for epoch in range(epoch_num):
        # Run Epoch
        steps_left = steps_per_epoch
        episode = 0
        epoch_reward = 0
        start = time.time()
        game.start()
        while steps_left > 0:
            # Running New Episode
            episode += 1
            episode_loss = 0.0
            episode_q_value = 0.0
            episode_update_step = 0
            episode_action_step = 0
            time_episode_start = time.time()
            game.begin_episode(steps_left)
            while not game.episode_terminate:
                # 1. We need to choose a new action based on the current game status
                if game.state_enabled and game.replay_memory.sample_enabled:
                    do_exploration = (npy_rng.rand() < eps_curr)
                    eps_curr = max(eps_curr - eps_decay, eps_min)
                    if do_exploration:
                        action = npy_rng.randint(action_num)
                    else:
                        # TODO Here we can in fact play multiple gaming instances simultaneously and make actions for each
                        # We can simply stack the current_state() of gaming instances and give prediction for all of them
                        # We need to wait after calling calc_score(.), which makes the program slow
                        # TODO Profiling the speed of this part!
                        current_state = game.current_state()
                        state = nd.array(
                            current_state.reshape((1, ) + current_state.shape),
                            ctx=q_ctx) / float(255.0)
                        qval_npy = qnet.forward(is_train=False,
                                                data=state)[0].asnumpy()
                        action = numpy.argmax(qval_npy)
                        episode_q_value += qval_npy[0, action]
                        episode_action_step += 1
                else:
                    action = npy_rng.randint(action_num)

                # 2. Play the game for a single mega-step (Inside the game, the action may be repeated for several times)
                game.play(action)
                total_steps += 1

                # 3. Update our Q network if we can start sampling from the replay memory
                #    Also, we update every `update_interval`
                if total_steps % update_interval == 0 and game.replay_memory.sample_enabled:
                    # 3.1 Draw sample from the replay_memory
                    training_steps += 1
                    episode_update_step += 1
                    states, actions, rewards, next_states, terminate_flags \
                        = game.replay_memory.sample(batch_size=minibatch_size)
                    states = nd.array(states, ctx=q_ctx) / float(255.0)
                    next_states = nd.array(next_states,
                                           ctx=q_ctx) / float(255.0)
                    actions = nd.array(actions, ctx=q_ctx)
                    rewards = nd.array(rewards, ctx=q_ctx)
                    terminate_flags = nd.array(terminate_flags, ctx=q_ctx)

                    # 3.2 Use the target network to compute the scores and
                    #     get the corresponding target rewards
                    target_qval = target_qnet.forward(is_train=False,
                                                      data=next_states)[0]
                    if args.double_q:
                        # Double DQN: pick a* with the online network and
                        # evaluate it with the target network.
                        qval = qnet.forward(is_train=False,
                                            data=next_states)[0]
                        best_actions = nd.argmax_channel(qval)
                    else:
                        best_actions = nd.argmax_channel(target_qval)
                    target_rewards = rewards + nd.choose_element_0index(
                        target_qval, best_actions) \
                        * (1.0 - terminate_flags) * discount
                    outputs = qnet.forward(is_train=True,
                                           data=states,
                                           dqn_action=actions,
                                           dqn_reward=target_rewards)
                    qnet.backward()

                    if args.kv_type is not None:
                        if use_easgd:
                            if total_steps % kvstore_update_period == 0:
                                for ind, k in enumerate(qnet.params.keys()):
                                    kv.pull(ind,
                                            central_weight[k],
                                            priority=-ind)
                                    # Elastic move towards the central weight.
                                    qnet.params[k][:] -= easgd_alpha * \
                                                         (qnet.params[k] - central_weight[k])
                                    kv.push(ind, qnet.params[k], priority=-ind)
                            qnet.update(updater=local_updater)
                        else:
                            update_on_kvstore(kv, qnet.params,
                                              qnet.params_grad)
                    else:
                        qnet.update(updater=updater)

                    # 3.3 Calculate Loss (Huber: quadratic up to 1, linear beyond)
                    diff = nd.abs(
                        nd.choose_element_0index(outputs[0], actions) -
                        target_rewards)
                    quadratic_part = nd.clip(diff, -1, 1)
                    loss = 0.5 * nd.sum(nd.square(quadratic_part)).asnumpy()[0] +\
                           nd.sum(diff - quadratic_part).asnumpy()[0]
                    episode_loss += loss

                    # 3.4 Update the target network every freeze_interval
                    # (We can do annealing instead of hard copy)
                    if training_steps % freeze_interval == 0:
                        qnet.copy_params_to(target_qnet)
            steps_left -= game.episode_step
            time_episode_end = time.time()
            # Update the statistics
            epoch_reward += game.episode_reward
            if args.kv_type is not None:
                info_str = "Node[%d]: " % kv.rank
            else:
                info_str = ""
            info_str += "Epoch:%d, Episode:%d, Steps Left:%d/%d, Reward:%f, fps:%f, Exploration:%f" \
                        % (epoch, episode, steps_left, steps_per_epoch, game.episode_reward,
                           game.episode_step / (time_episode_end - time_episode_start), eps_curr)
            if episode_update_step > 0:
                info_str += ", Avg Loss:%f/%d" % (
                    episode_loss / episode_update_step, episode_update_step)
            if episode_action_step > 0:
                info_str += ", Avg Q Value:%f/%d" % (
                    episode_q_value / episode_action_step, episode_action_step)
            logging.info(info_str)
        end = time.time()
        fps = steps_per_epoch / (end - start)
        qnet.save_params(dir_path=args.dir_path, epoch=epoch)
        if args.kv_type is not None:
            logging.info(
                "Node[%d]: Epoch:%d, FPS:%f, Avg Reward: %f/%d" %
                (kv.rank, epoch, fps, epoch_reward / float(episode), episode))
        else:
            logging.info("Epoch:%d, FPS:%f, Avg Reward: %f/%d" %
                         (epoch, fps, epoch_reward / float(episode), episode))
Example #10
0
def main():
    parser = argparse.ArgumentParser(
        description='Script to test the trained network on a game.')
    parser.add_argument('-r',
                        '--rom',
                        required=False,
                        type=str,
                        default=os.path.join('arena', 'games', 'roms',
                                             'breakout.bin'),
                        help='Path of the ROM File.')
    parser.add_argument('-v',
                        '--visualization',
                        required=False,
                        type=int,
                        default=0,
                        help='Visualize the runs.')
    parser.add_argument('--lr',
                        required=False,
                        type=float,
                        default=0.01,
                        help='Learning rate of the AdaGrad optimizer')
    parser.add_argument('--eps',
                        required=False,
                        type=float,
                        default=0.01,
                        help='Eps of the AdaGrad optimizer')
    parser.add_argument('--rms-decay',
                        required=False,
                        type=float,
                        default=0.95,
                        help='Decay rate of the RMSProp')
    parser.add_argument('--clip-gradient',
                        required=False,
                        type=float,
                        default=None,
                        help='Clip threshold of the AdaGrad optimizer')
    parser.add_argument('--double-q',
                        required=False,
                        type=bool,
                        default=False,
                        help='Use Double DQN')
    parser.add_argument('--wd',
                        required=False,
                        type=float,
                        default=0.0,
                        help='Weight of the L2 Regularizer')
    parser.add_argument(
        '-c',
        '--ctx',
        required=False,
        type=str,
        default=None,
        help='Running Context. E.g `-c gpu` or `-c gpu1` or `-c cpu`')
    parser.add_argument('-d',
                        '--dir-path',
                        required=False,
                        type=str,
                        default='',
                        help='Saving directory of model files.')
    parser.add_argument(
        '--start-eps',
        required=False,
        type=float,
        default=1.0,
        help='Eps of the epsilon-greedy policy at the beginning')
    parser.add_argument('--replay-start-size',
                        required=False,
                        type=int,
                        default=50000,
                        help='The step that the training starts')
    parser.add_argument(
        '--kvstore-update-period',
        required=False,
        type=int,
        default=16,
        help='The period that the worker updates the parameters from the sever'
    )
    parser.add_argument(
        '--kv-type',
        required=False,
        type=str,
        default=None,
        help=
        'type of kvstore, default will not use kvstore, could also be dist_async'
    )
    parser.add_argument('--optimizer',
                        required=False,
                        type=str,
                        default="adagrad",
                        help='type of optimizer')
    parser.add_argument('--nactor',
                        required=False,
                        type=int,
                        default=16,
                        help='number of actor')
    parser.add_argument('--exploration-period',
                        required=False,
                        type=int,
                        default=4000000,
                        help='length of annealing of epsilon greedy policy')
    parser.add_argument('--replay-memory-size',
                        required=False,
                        type=int,
                        default=100,
                        help='size of replay memory')
    parser.add_argument('--single-batch-size',
                        required=False,
                        type=int,
                        default=5,
                        help='batch size for every actor')
    parser.add_argument('--symbol',
                        required=False,
                        type=str,
                        default="nature",
                        help='type of network, nature or nips')
    parser.add_argument('--sample-policy',
                        required=False,
                        type=str,
                        default="recent",
                        help='minibatch sampling policy, recent or random')
    parser.add_argument('--epoch-num',
                        required=False,
                        type=int,
                        default=50,
                        help='number of epochs')
    parser.add_argument('--param-update-period',
                        required=False,
                        type=int,
                        default=5,
                        help='Parameter update period')
    parser.add_argument('--resize-mode',
                        required=False,
                        type=str,
                        default="scale",
                        help='Resize mode, scale or crop')
    parser.add_argument('--eps-update-period',
                        required=False,
                        type=int,
                        default=8000,
                        help='eps greedy policy update period')
    parser.add_argument('--server-optimizer',
                        required=False,
                        type=str,
                        default="easgd",
                        help='type of server optimizer')
    parser.add_argument('--nworker',
                        required=False,
                        type=int,
                        default=1,
                        help='number of kv worker')
    parser.add_argument('--easgd-alpha',
                        required=False,
                        type=float,
                        default=0.01,
                        help='easgd alpha')
    args, unknown = parser.parse_known_args()
    logging.info(str(args))

    if args.dir_path == '':
        rom_name = os.path.splitext(os.path.basename(args.rom))[0]
        time_str = time.strftime("%m%d_%H%M_%S", time.localtime())
        args.dir_path = ('dqn-%s-%d_' % (rom_name,int(args.lr*10**5)))+time_str \
                        + "_" + os.environ.get('DMLC_TASK_ID')
        logging.info("saving to dir: " + args.dir_path)
    if args.ctx == None:
        args.ctx = os.environ.get('CTX')
    logging.info("Context: %s" % args.ctx)
    ctx = re.findall('([a-z]+)(\d*)', args.ctx)
    ctx = [(device, int(num)) if len(num) > 0 else (device, 0)
           for device, num in ctx]

    # Async verision
    nactor = args.nactor
    param_update_period = args.param_update_period

    replay_start_size = args.replay_start_size
    max_start_nullops = 30
    replay_memory_size = args.replay_memory_size
    history_length = 4
    rows = 84
    cols = 84
    q_ctx = mx.Context(*ctx[0])
    games = []
    for g in range(nactor):
        games.append(
            AtariGame(rom_path=args.rom,
                      resize_mode=args.resize_mode,
                      replay_start_size=replay_start_size,
                      resized_rows=rows,
                      resized_cols=cols,
                      max_null_op=max_start_nullops,
                      replay_memory_size=replay_memory_size,
                      display_screen=args.visualization,
                      history_length=history_length))

    ##RUN NATURE
    freeze_interval = 40000 / nactor
    freeze_interval /= param_update_period
    epoch_num = args.epoch_num
    steps_per_epoch = 4000000 / nactor
    discount = 0.99
    save_screens = False
    eps_start = numpy.ones((3, )) * args.start_eps
    eps_min = numpy.array([0.1, 0.01, 0.5])
    eps_decay = (eps_start - eps_min) / (args.exploration_period / nactor)
    eps_curr = eps_start
    eps_id = numpy.zeros((nactor, ))
    eps_update_period = args.eps_update_period
    eps_update_count = numpy.zeros((nactor, ))

    single_batch_size = args.single_batch_size
    minibatch_size = nactor * single_batch_size
    action_num = len(games[0].action_set)
    data_shapes = {
        'data': (minibatch_size, history_length) + (rows, cols),
        'dqn_action': (minibatch_size, ),
        'dqn_reward': (minibatch_size, )
    }

    if args.symbol == "nature":
        dqn_sym = dqn_sym_nature(action_num)
    elif args.symbol == "nips":
        dqn_sym = dqn_sym_nips(action_num)
    else:
        raise NotImplementedError
    qnet = Base(data_shapes=data_shapes,
                sym=dqn_sym,
                name='QNet',
                initializer=DQNInitializer(factor_type="in"),
                ctx=q_ctx)
    target_qnet = qnet.copy(name="TargetQNet", ctx=q_ctx)

    if args.optimizer == "adagrad":
        optimizer = mx.optimizer.create(name=args.optimizer,
                                        learning_rate=args.lr,
                                        eps=args.eps,
                                        clip_gradient=args.clip_gradient,
                                        rescale_grad=1.0,
                                        wd=args.wd)
    elif args.optimizer == "rmsprop" or args.optimizer == "rmspropnoncentered":
        optimizer = mx.optimizer.create(name=args.optimizer,
                                        learning_rate=args.lr,
                                        eps=args.eps,
                                        clip_gradient=args.clip_gradient,
                                        gamma1=args.rms_decay,
                                        gamma2=0,
                                        rescale_grad=1.0,
                                        wd=args.wd)
        lr_decay = (args.lr - 0) / (steps_per_epoch * epoch_num /
                                    param_update_period)

    # Create kvstore
    use_easgd = False
    if args.kv_type != None:
        kvType = args.kv_type
        kv = kvstore.create(kvType)
        #Initialize kvstore
        for idx, v in enumerate(qnet.params.values()):
            kv.init(idx, v)
        if args.server_optimizer == "easgd":
            use_easgd = True
            easgd_beta = 0.9
            easgd_alpha = args.easgd_alpha
            server_optimizer = mx.optimizer.create(name="ServerEasgd",
                                                   learning_rate=easgd_alpha)
            easgd_eta = 0.00025
            central_weight = OrderedDict([(n, v.copyto(q_ctx))
                                          for n, v in qnet.params.items()])
            kv.set_optimizer(server_optimizer)
            updater = mx.optimizer.get_updater(optimizer)
        else:
            kv.set_optimizer(optimizer)
        kvstore_update_period = args.kvstore_update_period
        npy_rng = numpy.random.RandomState(123456 + kv.rank)
    else:
        updater = mx.optimizer.get_updater(optimizer)

    qnet.print_stat()
    target_qnet.print_stat()

    states_buffer_for_act = numpy.zeros(
        (nactor, history_length) + (rows, cols), dtype='uint8')
    states_buffer_for_train = numpy.zeros(
        (minibatch_size, history_length + 1) + (rows, cols), dtype='uint8')
    next_states_buffer_for_train = numpy.zeros(
        (minibatch_size, history_length) + (rows, cols), dtype='uint8')
    actions_buffer_for_train = numpy.zeros((minibatch_size, ), dtype='uint8')
    rewards_buffer_for_train = numpy.zeros((minibatch_size, ), dtype='float32')
    terminate_flags_buffer_for_train = numpy.zeros((minibatch_size, ),
                                                   dtype='bool')
    # Begin Playing Game
    training_steps = 0
    total_steps = 0
    ave_fps = 0
    ave_loss = 0
    time_for_info = time.time()
    parallel_executor = concurrent.futures.ThreadPoolExecutor(nactor)
    for epoch in xrange(epoch_num):
        # Run Epoch
        steps_left = steps_per_epoch
        episode = 0
        epoch_reward = 0
        start = time.time()
        #
        for g, game in enumerate(games):
            game.start()
            game.begin_episode()
            eps_rand = npy_rng.rand()
            if eps_rand < 0.4:
                eps_id[g] = 0
            elif eps_rand < 0.7:
                eps_id[g] = 1
            else:
                eps_id[g] = 2
        episode_stats = [EpisodeStat() for i in range(len(games))]
        while steps_left > 0:
            for g, game in enumerate(games):
                if game.episode_terminate:
                    episode += 1
                    epoch_reward += game.episode_reward
                    if args.kv_type != None:
                        info_str = "Node[%d]: " % kv.rank
                    else:
                        info_str = ""
                    info_str += "Epoch:%d, Episode:%d, Steps Left:%d/%d, Reward:%f, fps:%f, Exploration:%f" \
                                % (epoch, episode, steps_left, steps_per_epoch, game.episode_reward,
                                   ave_fps, (eps_curr[eps_id[g]]))
                    info_str += ", Avg Loss:%f" % ave_loss
                    if episode_stats[g].episode_action_step > 0:
                        info_str += ", Avg Q Value:%f/%d" % (
                            episode_stats[g].episode_q_value /
                            episode_stats[g].episode_action_step,
                            episode_stats[g].episode_action_step)
                    if g == 0: logging.info(info_str)
                    if eps_update_count[g] * eps_update_period < total_steps:
                        eps_rand = npy_rng.rand()
                        if eps_rand < 0.4:
                            eps_id[g] = 0
                        elif eps_rand < 0.7:
                            eps_id[g] = 1
                        else:
                            eps_id[g] = 2
                        eps_update_count[g] += 1
                    game.begin_episode(steps_left)
                    episode_stats[g] = EpisodeStat()

            if total_steps > history_length:
                for g, game in enumerate(games):
                    current_state = game.current_state()
                    states_buffer_for_act[g] = current_state

            states = nd.array(states_buffer_for_act, ctx=q_ctx) / float(255.0)

            qval_npy = qnet.forward(batch_size=nactor,
                                    data=states)[0].asnumpy()
            actions_that_max_q = numpy.argmax(qval_npy, axis=1)
            actions = [0] * nactor
            for g, game in enumerate(games):
                # 1. We need to choose a new action based on the current game status
                if games[g].state_enabled and games[
                        g].replay_memory.sample_enabled:
                    do_exploration = (npy_rng.rand() < eps_curr[eps_id[g]])
                    if do_exploration:
                        action = npy_rng.randint(action_num)
                    else:
                        # TODO Here we can in fact play multiple gaming instances simultaneously and make actions for each
                        # We can simply stack the current_state() of gaming instances and give prediction for all of them
                        # We need to wait after calling calc_score(.), which makes the program slow
                        # TODO Profiling the speed of this part!
                        action = actions_that_max_q[g]
                        episode_stats[g].episode_q_value += qval_npy[g, action]
                        episode_stats[g].episode_action_step += 1
                else:
                    action = npy_rng.randint(action_num)
                actions[g] = action
            # t0=time.time()
            for ret in parallel_executor.map(play_game, zip(games, actions)):
                pass
            # t1=time.time()
            # logging.info("play time: %f" % (t1-t0))
            eps_curr = numpy.maximum(eps_curr - eps_decay, eps_min)
            total_steps += 1
            steps_left -= 1
            if total_steps % 100 == 0:
                this_time = time.time()
                ave_fps = (100 / (this_time - time_for_info))
                time_for_info = this_time

            # 3. Update our Q network if we can start sampling from the replay memory
            #    Also, we update every `update_interval`
            if total_steps > minibatch_size and \
                total_steps % (param_update_period) == 0 and \
                games[-1].replay_memory.sample_enabled:
                if use_easgd and training_steps % kvstore_update_period == 0:
                    for paramIndex in range(len(qnet.params)):
                        k = qnet.params.keys()[paramIndex]
                        kv.pull(paramIndex,
                                central_weight[k],
                                priority=-paramIndex)
                        qnet.params[k][:] -= easgd_alpha * (qnet.params[k] -
                                                            central_weight[k])
                        kv.push(paramIndex,
                                qnet.params[k],
                                priority=-paramIndex)
                # 3.1 Draw sample from the replay_memory
                for g, game in enumerate(games):
                    episode_stats[g].episode_update_step += 1
                    nsample = single_batch_size
                    i0 = (g * nsample)
                    i1 = (g + 1) * nsample
                    if args.sample_policy == "recent":
                        action, reward, terminate_flag=game.replay_memory.sample_last(batch_size=nsample,\
                            states=states_buffer_for_train,offset=i0)
                    elif args.sample_policy == "random":
                        action, reward, terminate_flag=game.replay_memory.sample_inplace(batch_size=nsample,\
                            states=states_buffer_for_train,offset=i0)
                    actions_buffer_for_train[i0:i1] = action
                    rewards_buffer_for_train[i0:i1] = reward
                    terminate_flags_buffer_for_train[i0:i1] = terminate_flag
                states = nd.array(states_buffer_for_train[:, :-1],
                                  ctx=q_ctx) / float(255.0)
                next_states = nd.array(states_buffer_for_train[:, 1:],
                                       ctx=q_ctx) / float(255.0)
                actions = nd.array(actions_buffer_for_train, ctx=q_ctx)
                rewards = nd.array(rewards_buffer_for_train, ctx=q_ctx)
                terminate_flags = nd.array(terminate_flags_buffer_for_train,
                                           ctx=q_ctx)

                # 3.2 Use the target network to compute the scores and
                #     get the corresponding target rewards
                if not args.double_q:
                    target_qval = target_qnet.forward(
                        batch_size=minibatch_size, data=next_states)[0]
                    target_rewards = rewards + nd.choose_element_0index(target_qval,
                                                            nd.argmax_channel(target_qval))\
                                       * (1.0 - terminate_flags) * discount
                else:
                    target_qval = target_qnet.forward(
                        batch_size=minibatch_size, data=next_states)[0]
                    qval = qnet.forward(batch_size=minibatch_size,
                                        data=next_states)[0]

                    target_rewards = rewards + nd.choose_element_0index(target_qval,
                                                            nd.argmax_channel(qval))\
                                       * (1.0 - terminate_flags) * discount

                outputs = qnet.forward(batch_size=minibatch_size,
                                       is_train=True,
                                       data=states,
                                       dqn_action=actions,
                                       dqn_reward=target_rewards)
                qnet.backward(batch_size=minibatch_size)

                if args.kv_type is None or use_easgd:
                    qnet.update(updater=updater)
                else:
                    update_on_kvstore(kv, qnet.params, qnet.params_grad)

                # 3.3 Calculate Loss
                diff = nd.abs(
                    nd.choose_element_0index(outputs[0], actions) -
                    target_rewards)
                quadratic_part = nd.clip(diff, -1, 1)
                loss = (0.5 * nd.sum(nd.square(quadratic_part)) +
                        nd.sum(diff - quadratic_part)).asscalar()
                if ave_loss == 0:
                    ave_loss = loss
                else:
                    ave_loss = 0.95 * ave_loss + 0.05 * loss

                # 3.3 Update the target network every freeze_interval
                # (We can do annealing instead of hard copy)
                if training_steps % freeze_interval == 0:
                    qnet.copy_params_to(target_qnet)

                if args.optimizer == "rmsprop" or args.optimizer == "rmspropnoncentered":
                    optimizer.lr -= lr_decay

                if save_screens and training_steps % (
                        60 * 60 * 2 / param_update_period) == 0:
                    logging.info("saving screenshots")
                    for g in range(nactor):
                        screen = states_buffer_for_train[(
                            g * single_batch_size), -2, :, :].reshape(
                                states_buffer_for_train.shape[2:])
                        cv2.imwrite("screen_" + str(g) + ".png", screen)
                training_steps += 1

        end = time.time()
        fps = steps_per_epoch / (end - start)
        qnet.save_params(dir_path=args.dir_path, epoch=epoch)
        if args.kv_type != None:
            logging.info(
                "Node[%d]: Epoch:%d, FPS:%f, Avg Reward: %f/%d" %
                (kv.rank, epoch, fps, epoch_reward / float(episode), episode))
        else:
            logging.info("Epoch:%d, FPS:%f, Avg Reward: %f/%d" %
                         (epoch, fps, epoch_reward / float(episode), episode))
def parse_bool_arg(value):
    """Convert a command-line token such as ``True``/``false``/``1``/``no`` to a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``; this converter interprets the text instead.

    Args:
        value: The raw option text, or a ``bool`` (returned unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean
            spelling (argparse turns this into a usage error).
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', 'off', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got %r' % (value,))


def parse_dqn_args(argv=None):
    """Parse the command-line options used by :func:`main`.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace``. When ``--dir-path`` is empty, ``dir_path``
        is filled in as ``dqn-<rom name>-lr<lr>``.
    """
    parser = argparse.ArgumentParser(
        description='Script to test the trained network on a game.')
    parser.add_argument('-r',
                        '--rom',
                        required=False,
                        type=str,
                        default=os.path.join('roms', 'breakout.bin'),
                        help='Path of the ROM File.')
    parser.add_argument('-v',
                        '--visualization',
                        required=False,
                        type=int,
                        default=0,
                        help='Visualize the runs.')
    parser.add_argument('--lr',
                        required=False,
                        type=float,
                        default=0.01,
                        help='Learning rate of the AdaGrad optimizer')
    parser.add_argument('--eps',
                        required=False,
                        type=float,
                        default=0.01,
                        help='Eps of the AdaGrad optimizer')
    parser.add_argument('--clip-gradient',
                        required=False,
                        type=float,
                        default=None,
                        help='Clip threshold of the AdaGrad optimizer')
    # `type=bool` would turn "--double-q False" into True; parse the text
    # instead. A bare "--double-q" (no value) enables it via `const`.
    parser.add_argument('--double-q',
                        required=False,
                        type=parse_bool_arg,
                        nargs='?',
                        const=True,
                        default=False,
                        help='Use Double DQN')
    parser.add_argument('--wd',
                        required=False,
                        type=float,
                        default=0.0,
                        help='Weight of the L2 Regularizer')
    parser.add_argument(
        '-c',
        '--ctx',
        required=False,
        type=str,
        default='gpu',
        help='Running Context. E.g `-c gpu` or `-c gpu1` or `-c cpu`')
    parser.add_argument('-d',
                        '--dir-path',
                        required=False,
                        type=str,
                        default='',
                        help='Saving directory of model files.')
    # Defaults below used to be forced after parsing ("## custom"), which made
    # the flags dead; they are now real defaults so explicit values are honoured.
    parser.add_argument(
        '--start-eps',
        required=False,
        type=float,
        default=0.2,
        help='Eps of the epsilon-greedy policy at the beginning')
    parser.add_argument('--replay-start-size',
                        required=False,
                        type=int,
                        default=1000,
                        help='The step that the training starts')
    parser.add_argument(
        '--kvstore-update-period',
        required=False,
        type=int,
        default=1,
        help='The period that the worker updates the parameters from the sever'
    )
    parser.add_argument(
        '--kv-type',
        required=False,
        type=str,
        default=None,
        help=
        'type of kvstore, default will not use kvstore, could also be dist_async'
    )
    parser.add_argument('--optimizer',
                        required=False,
                        type=str,
                        default="adagrad",
                        help='type of optimizer')
    args = parser.parse_args(argv)

    if args.dir_path == '':
        rom_name = os.path.splitext(os.path.basename(args.rom))[0]
        args.dir_path = 'dqn-%s-lr%g' % (rom_name, args.lr)
    return args


def main():
    """Train a DQN agent on a single Atari game and log progress.

    Options come from :func:`parse_dqn_args`. Each epoch runs until
    ``steps_per_epoch`` game steps have been consumed. Actions are chosen
    epsilon-greedily once the replay memory can be sampled. Every
    ``update_interval`` steps a minibatch is sampled and the Q-network is
    updated against targets from the target network. The Q-network parameters
    are copied into the target network every ``freeze_interval`` training
    steps. Episode stats are logged every 100 episodes; epoch stats are logged
    every epoch. Parameter checkpointing is currently disabled (commented out
    below).
    """
    args = parse_dqn_args()

    replay_start_size = args.replay_start_size
    max_start_nullops = 30
    replay_memory_size = 1000000
    history_length = 4
    rows = 84
    cols = 84

    ctx = parse_ctx(args.ctx)
    q_ctx = mx.Context(*ctx[0])

    game = AtariGame(rom_path=args.rom,
                     resize_mode='resize',
                     replay_start_size=replay_start_size,
                     resized_rows=rows,
                     resized_cols=cols,
                     max_null_op=max_start_nullops,
                     replay_memory_size=replay_memory_size,
                     display_screen=args.visualization,
                     history_length=history_length)

    ##RUN NATURE
    freeze_interval = 10000
    epoch_num = 30
    steps_per_epoch = 100000
    update_interval = 4
    discount = 0.99

    eps_start = args.start_eps
    eps_min = 0.1
    # Linear annealing from eps_start to eps_min over 1e6 decisions.
    eps_decay = (eps_start - eps_min) / 1000000
    eps_curr = eps_start
    # freeze_interval is compared against training_steps, which advance once
    # every update_interval game steps; keep it an integer.
    freeze_interval //= update_interval
    minibatch_size = 32
    action_num = len(game.action_set)

    data_shapes = {
        'data': (minibatch_size, history_length) + (rows, cols),
        'dqn_action': (minibatch_size, ),
        'dqn_reward': (minibatch_size, )
    }
    dqn_sym = dqn_sym_nature(action_num)
    qnet = Base(data_shapes=data_shapes,
                sym_gen=dqn_sym,
                name='QNet',
                initializer=DQNInitializer(factor_type="in"),
                ctx=q_ctx)
    target_qnet = qnet.copy(name="TargetQNet", ctx=q_ctx)

    optimizer = mx.optimizer.create(name=args.optimizer,
                                    learning_rate=args.lr,
                                    eps=args.eps,
                                    clip_gradient=args.clip_gradient,
                                    rescale_grad=1.0,
                                    wd=args.wd)
    updater = mx.optimizer.get_updater(optimizer)

    qnet.print_stat()
    target_qnet.print_stat()

    # NOTE(review): `npy_rng` is not defined in this function; it is assumed
    # to be a module-level numpy RandomState created elsewhere in the file.

    # Begin Playing Game
    training_steps = 0
    total_steps = 0
    for epoch in range(epoch_num):
        # Run Epoch
        steps_left = steps_per_epoch
        episode = 0
        epoch_reward = 0
        start = time.time()
        game.start()
        while steps_left > 0:
            # Running New Episode
            episode += 1
            episode_loss = 0.0
            episode_q_value = 0.0
            episode_update_step = 0
            episode_action_step = 0
            time_episode_start = time.time()
            game.begin_episode(steps_left)
            while not game.episode_terminate:
                # 1. Choose an action: random until the replay memory can be
                #    sampled, epsilon-greedy afterwards.
                if game.state_enabled and game.replay_memory.sample_enabled:
                    do_exploration = (npy_rng.rand() < eps_curr)
                    eps_curr = max(eps_curr - eps_decay, eps_min)
                    if do_exploration:
                        action = npy_rng.randint(action_num)
                    else:
                        current_state = game.current_state()
                        # Add a leading batch axis of size 1; scale pixels to [0, 1].
                        state = nd.array(
                            current_state.reshape((1, ) + current_state.shape),
                            ctx=q_ctx) / 255.0
                        qval_npy = qnet.forward(is_train=False,
                                                data=state)[0].asnumpy()
                        action = numpy.argmax(qval_npy)
                        episode_q_value += qval_npy[0, action]
                        episode_action_step += 1
                else:
                    action = npy_rng.randint(action_num)

                # 2. Play the game for a single mega-step (Inside the game, the action may be repeated for several times)
                game.play(action)
                total_steps += 1

                # 3. Update the Q network every `update_interval` steps once
                #    the replay memory can be sampled.
                if total_steps % update_interval == 0 and game.replay_memory.sample_enabled:
                    # 3.1 Draw sample from the replay_memory
                    training_steps += 1
                    episode_update_step += 1
                    states, actions, rewards, next_states, terminate_flags \
                        = game.replay_memory.sample(batch_size=minibatch_size)
                    states = nd.array(states, ctx=q_ctx) / 255.0
                    next_states = nd.array(next_states, ctx=q_ctx) / 255.0
                    actions = nd.array(actions, ctx=q_ctx)
                    rewards = nd.array(rewards, ctx=q_ctx)
                    terminate_flags = nd.array(terminate_flags, ctx=q_ctx)

                    # 3.2 Target: r + discount * Q_target(s', a*) for
                    #     non-terminal transitions. a* is argmax of the target
                    #     net (plain DQN) or of the online net (Double DQN).
                    target_qval = target_qnet.forward(is_train=False,
                                                      data=next_states)[0]
                    if args.double_q:
                        online_qval = qnet.forward(is_train=False,
                                                   data=next_states)[0]
                        next_actions = nd.argmax_channel(online_qval)
                    else:
                        next_actions = nd.argmax_channel(target_qval)
                    target_rewards = rewards + nd.choose_element_0index(
                        target_qval, next_actions) \
                        * (1.0 - terminate_flags) * discount

                    outputs = qnet.forward(is_train=True,
                                           data=states,
                                           dqn_action=actions,
                                           dqn_reward=target_rewards)
                    qnet.backward()
                    qnet.update(updater=updater)

                    # 3.3 Huber-style loss for logging: quadratic for |diff| <= 1,
                    #     linear beyond. Fetched with a single device sync.
                    diff = nd.abs(
                        nd.choose_element_0index(outputs[0], actions) -
                        target_rewards)
                    quadratic_part = nd.clip(diff, -1, 1)
                    loss = (0.5 * nd.sum(nd.square(quadratic_part)) +
                            nd.sum(diff - quadratic_part)).asscalar()
                    episode_loss += loss

                    # 3.4 Update the target network every freeze_interval
                    if training_steps % freeze_interval == 0:
                        qnet.copy_params_to(target_qnet)
            steps_left -= game.episode_step
            time_episode_end = time.time()
            # Update the statistics
            epoch_reward += game.episode_reward
            if episode % 100 == 0:
                info_str = "Epoch:%d, Episode:%d, Steps Left:%d/%d, Reward:%f, fps:%f, Exploration:%f" \
                           % (epoch, episode, steps_left, steps_per_epoch, game.episode_reward,
                              game.episode_step / (time_episode_end - time_episode_start), eps_curr)
                if episode_update_step > 0:
                    info_str += ", Avg Loss:%f/%d" % (
                        episode_loss / episode_update_step, episode_update_step)
                if episode_action_step > 0:
                    info_str += ", Avg Q Value:%f/%d" % (
                        episode_q_value / episode_action_step, episode_action_step)
                logging.info(info_str)
        end = time.time()
        fps = steps_per_epoch / (end - start)
        # qnet.save_params(dir_path=args.dir_path, epoch=epoch)
        logging.info("Epoch:%d, FPS:%f, Avg Reward: %f/%d" %
                     (epoch, fps, epoch_reward / float(episode), episode))
Example #12
0
def main():
    parser = argparse.ArgumentParser(description='Script to test the trained network on a game.')
    parser.add_argument('-r', '--rom', required=False, type=str,
                        default=os.path.join('arena', 'games', 'roms', 'breakout.bin'),
                        help='Path of the ROM File.')
    parser.add_argument('-v', '--visualization', required=False, type=int, default=0,
                        help='Visualize the runs.')
    parser.add_argument('--lr', required=False, type=float, default=0.01,
                        help='Learning rate of the AdaGrad optimizer')
    parser.add_argument('--eps', required=False, type=float, default=0.01,
                        help='Eps of the AdaGrad optimizer')
    parser.add_argument('--clip-gradient', required=False, type=float, default=None,
                        help='Clip threshold of the AdaGrad optimizer')
    parser.add_argument('--double-q', required=False, type=bool, default=False,
                        help='Use Double DQN')
    parser.add_argument('--wd', required=False, type=float, default=0.0,
                        help='Weight of the L2 Regularizer')
    parser.add_argument('-c', '--ctx', required=False, type=str, default='gpu',
                        help='Running Context. E.g `-c gpu` or `-c gpu1` or `-c cpu`')
    parser.add_argument('-d', '--dir-path', required=False, type=str, default='',
                        help='Saving directory of model files.')
    parser.add_argument('--start-eps', required=False, type=float, default=1.0,
                        help='Eps of the epsilon-greedy policy at the beginning')
    parser.add_argument('--replay-start-size', required=False, type=int, default=50000,
                        help='The step that the training starts')
    parser.add_argument('--kvstore-update-period', required=False, type=int, default=1,
                        help='The period that the worker updates the parameters from the sever')
    parser.add_argument('--kv-type', required=False, type=str, default=None,
                        help='type of kvstore, default will not use kvstore, could also be dist_async')
    args, unknown = parser.parse_known_args()
    if args.dir_path == '':
        rom_name = os.path.splitext(os.path.basename(args.rom))[0]
        args.dir_path = 'dqn-%s' % rom_name
    ctx = re.findall('([a-z]+)(\d*)', args.ctx)
    ctx = [(device, int(num)) if len(num) >0 else (device, 0) for device, num in ctx]
    replay_start_size = args.replay_start_size
    max_start_nullops = 30
    replay_memory_size = 1000000
    history_length = 4
    rows = 84
    cols = 84
    q_ctx = mx.Context(*ctx[0])

    game = AtariGame(rom_path=args.rom, resize_mode='scale', replay_start_size=replay_start_size,
                     resized_rows=rows, resized_cols=cols, max_null_op=max_start_nullops,
                     replay_memory_size=replay_memory_size, display_screen=args.visualization,
                     history_length=history_length)

    ##RUN NATURE
    freeze_interval = 10000
    epoch_num = 200
    steps_per_epoch = 250000
    update_interval = 4
    discount = 0.99

    eps_start = args.start_eps
    eps_min = 0.1
    eps_decay = (eps_start - 0.1) / 1000000
    eps_curr = eps_start
    freeze_interval /= update_interval
    minibatch_size = 32
    action_num = len(game.action_set)

    data_shapes = {'data': (minibatch_size, history_length) + (rows, cols),
                   'dqn_action': (minibatch_size,), 'dqn_reward': (minibatch_size,)}
    #optimizer = mx.optimizer.create(name='sgd', learning_rate=args.lr,wd=args.wd)
    optimizer = mx.optimizer.Nop()
    dqn_output_op = DQNOutputNpyOp()
    dqn_sym = dqn_sym_nature(action_num, dqn_output_op)
    qnet = Base(data_shapes=data_shapes, sym=dqn_sym, name='QNet',
                  initializer=DQNInitializer(factor_type="in"),
                  ctx=q_ctx)
    target_qnet = qnet.copy(name="TargetQNet", ctx=q_ctx)
    # Create kvstore
    testShape = (1,1686180*100)
    testParam = nd.ones(testShape,ctx=q_ctx)
    testGrad = nd.zeros(testShape,ctx=q_ctx)

    # Create kvstore

    if args.kv_type != None:
        kvType = args.kv_type
        kvStore = kvstore.create(kvType)
        #Initialize kvstore
        for idx,v in enumerate(qnet.params.values()):
            kvStore.init(idx,v);
        # Set optimizer on kvstore
        kvStore.set_optimizer(optimizer)
        kvstore_update_period = args.kvstore_update_period
    else:
        updater = mx.optimizer.get_updater(optimizer)

    # if args.kv_type != None:
    #     kvType = args.kv_type
    #     kvStore = kvstore.create(kvType)
    #     kvStore.init(0,testParam)
    #     testOptimizer = mx.optimizer.Nop()
    #     kvStore.set_optimizer(testOptimizer)
    #     kvstore_update_period = args.kvstore_update_period


    qnet.print_stat()
    target_qnet.print_stat()
    # Begin Playing Game
    training_steps = 0
    total_steps = 0
    while(1):
        time_before_wait = time.time()

        # kvStore.push(0,testGrad,priority=0)
        # kvStore.pull(0,testParam,priority=0)
        # testParam.wait_to_read()

        for paramIndex in range(len(qnet.params)):#range(6):#
            k=qnet.params.keys()[paramIndex]
            kvStore.push(paramIndex,qnet.params_grad[k],priority=-paramIndex)
            kvStore.pull(paramIndex,qnet.params[k],priority=-paramIndex)

        for v in qnet.params.values():
            v.wait_to_read()
        logging.info("wait time %f" %(time.time()-time_before_wait))

    for epoch in xrange(epoch_num):
        # Run Epoch
        steps_left = steps_per_epoch
        episode = 0
        epoch_reward = 0
        start = time.time()
        game.start()
        while steps_left > 0:
            # Running New Episode
            episode += 1
            episode_loss = 0.0
            episode_q_value = 0.0
            episode_update_step = 0
            episode_action_step = 0
            time_episode_start = time.time()
            game.begin_episode(steps_left)
            while not game.episode_terminate:
                # 1. We need to choose a new action based on the current game status
                if game.state_enabled and game.replay_memory.sample_enabled:
                    do_exploration = (npy_rng.rand() < eps_curr)
                    eps_curr = max(eps_curr - eps_decay, eps_min)
                    if do_exploration:
                        action = npy_rng.randint(action_num)
                    else:
                        # TODO Here we can in fact play multiple gaming instances simultaneously and make actions for each
                        # We can simply stack the current_state() of gaming instances and give prediction for all of them
                        # We need to wait after calling calc_score(.), which makes the program slow
                        # TODO Profiling the speed of this part!
                        current_state = game.current_state()
                        state = nd.array(current_state.reshape((1,) + current_state.shape),
                                         ctx=q_ctx) / float(255.0)
                        qval_npy = qnet.forward(batch_size=1, data=state)[0].asnumpy()
                        action = numpy.argmax(qval_npy)
                        episode_q_value += qval_npy[0, action]
                        episode_action_step += 1
                else:
                    action = npy_rng.randint(action_num)

                # 2. Play the game for a single mega-step (Inside the game, the action may be repeated for several times)
                game.play(action)
                total_steps += 1

                # 3. Update our Q network if we can start sampling from the replay memory
                #    Also, we update every `update_interval`
                if total_steps % update_interval == 0 and game.replay_memory.sample_enabled:
                    # 3.1 Draw sample from the replay_memory
                    training_steps += 1
                    episode_update_step += 1
                    states, actions, rewards, next_states, terminate_flags \
                        = game.replay_memory.sample(batch_size=minibatch_size)
                    states = nd.array(states, ctx=q_ctx) / float(255.0)
                    next_states = nd.array(next_states, ctx=q_ctx) / float(255.0)
                    actions = nd.array(actions, ctx=q_ctx)
                    rewards = nd.array(rewards, ctx=q_ctx)
                    terminate_flags = nd.array(terminate_flags, ctx=q_ctx)

                    # 3.2 Use the target network to compute the scores and
                    #     get the corresponding target rewards
                    if not args.double_q:
                        target_qval = target_qnet.forward(batch_size=minibatch_size,
                                                         data=next_states)[0]
                        target_rewards = rewards + nd.choose_element_0index(target_qval,
                                                                nd.argmax_channel(target_qval))\
                                           * (1.0 - terminate_flags) * discount
                    else:
                        target_qval = target_qnet.forward(batch_size=minibatch_size,
                                                         data=next_states)[0]
                        qval = qnet.forward(batch_size=minibatch_size, data=next_states)[0]

                        target_rewards = rewards + nd.choose_element_0index(target_qval,
                                                                nd.argmax_channel(qval))\
                                           * (1.0 - terminate_flags) * discount
                    outputs = qnet.forward(batch_size=minibatch_size,is_train=True, data=states,
                                              dqn_action=actions,
                                              dqn_reward=target_rewards)
                    qnet.backward(batch_size=minibatch_size)
                    nd.waitall()
                    time_before_update = time.time()

                    if args.kv_type != None:
                        if total_steps % kvstore_update_period == 0:
                            update_to_kvstore(kvStore,qnet.params,qnet.params_grad)
                    else:
                        qnet.update(updater=updater)
                    logging.info("update time %f" %(time.time()-time_before_update))
                    time_before_wait = time.time()
                    nd.waitall()
                    logging.info("wait time %f" %(time.time()-time_before_wait))

                    '''nd.waitall()
                    time_before_wait = time.time()
                    kvStore.push(0,testGrad,priority=0)
                    kvStore.pull(0,testParam,priority=0)
                    nd.waitall()
                    logging.info("wait time %f" %(time.time()-time_before_wait))'''
                    # 3.3 Calculate Loss
                    diff = nd.abs(nd.choose_element_0index(outputs[0], actions) - target_rewards)
                    quadratic_part = nd.clip(diff, -1, 1)
                    loss = (0.5 * nd.sum(nd.square(quadratic_part)) + nd.sum(diff - quadratic_part)).asscalar()
                    episode_loss += loss

                    # 3.3 Update the target network every freeze_interval
                    # (We can do annealing instead of hard copy)
                    if training_steps % freeze_interval == 0:
                        qnet.copy_params_to(target_qnet)
            steps_left -= game.episode_step
            time_episode_end = time.time()
            # Update the statistics
            epoch_reward += game.episode_reward
            info_str = "Epoch:%d, Episode:%d, Steps Left:%d/%d, Reward:%f, fps:%f, Exploration:%f" \
                        % (epoch, episode, steps_left, steps_per_epoch, game.episode_reward,
                           game.episode_step / (time_episode_end - time_episode_start), eps_curr)
            if episode_update_step > 0:
                info_str += ", Avg Loss:%f/%d" % (episode_loss / episode_update_step,
                                                  episode_update_step)
            if episode_action_step > 0:
                info_str += ", Avg Q Value:%f/%d" % (episode_q_value / episode_action_step,
                                                  episode_action_step)
            logging.info(info_str)
        end = time.time()
        fps = steps_per_epoch / (end - start)
        qnet.save_params(dir_path=args.dir_path, epoch=epoch)
        logging.info("Epoch:%d, FPS:%f, Avg Reward: %f/%d"
                     % (epoch, fps, epoch_reward / float(episode), episode))
Example #13
0
def main():
    """Train a DQN agent that jointly controls two hunters in HunterWorld.

    Run mode is chosen by command-line flags:

    * ``--testing 1``: load ``Qnet`` weights saved for ``--testing-epoch``
      and enable on-screen rendering.  NOTE(review): the epoch loop below
      never checks ``testing``, so such a run still explores, trains and
      calls ``Qnet.save_params(epoch)`` for epochs ``0 .. epoch_num - 1``;
      confirm this cannot overwrite checkpoints you want to keep.
    * ``--continue-training 1``: load weights saved for ``start_epoch - 1``
      and number the new epochs ``start_epoch .. start_epoch + epoch_num - 1``.
    * otherwise: train without loading any saved weights.

    The two hunters' action sets are combined into one discrete action
    space (one Q-network output per pair of per-hunter actions), and the
    per-hunter rewards returned by ``env.act`` are summed into one scalar.

    Depends on names defined outside this function: ``Config``,
    ``HunterWorld``, ``PLE``, ``ReplayMemory``, ``Qnetwork``,
    ``copyTargetQNetwork``, ``logging_config``, ``print_params``, ``dir``
    and ``file_name``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--num-envs', type=int, default=1)
    parser.add_argument('--t-max', type=int, default=1)
    parser.add_argument('--learning-rate', type=float, default=0.0002)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--steps-per-epoch', type=int, default=100000)
    parser.add_argument('--testing', type=int, default=0)
    parser.add_argument('--continue-training', type=int, default=0)
    parser.add_argument('--epoch-num', type=int, default=40)
    parser.add_argument('--start-epoch', type=int, default=20)
    parser.add_argument('--testing-epoch', type=int, default=3)
    parser.add_argument('--save-log', type=str, default='basic/log')
    parser.add_argument('--signal-num', type=int, default=4)
    parser.add_argument('--toxin', type=int, default=0)
    parser.add_argument('--a1-AC-folder', type=str, default='basic/a1_Qnet')
    parser.add_argument('--eps-start', type=float, default=1.0)
    parser.add_argument('--replay-start-size', type=int, default=50000)
    parser.add_argument('--decay-rate', type=int, default=500000)
    parser.add_argument('--replay-memory-size', type=int, default=1000000)
    parser.add_argument('--eps-min', type=float, default=0.05)

    # Reward table passed to PLE as ``reward_values``.
    rewards = {
        "positive": 1.0,
        "negative": -1.0,
        "tick": -0.002,
        "loss": -2.0,
        "win": 2.0
    }

    args = parser.parse_args()
    config = Config(args)
    q_ctx = config.ctx
    steps_per_epoch = args.steps_per_epoch
    np.random.seed(args.seed)
    start_epoch = args.start_epoch
    testing_epoch = args.testing_epoch
    save_log = args.save_log
    epoch_num = args.epoch_num
    epoch_range = range(epoch_num)
    toxin = args.toxin
    a1_Qnet_folder = args.a1_AC_folder

    freeze_interval = 10000
    update_interval = 5
    replay_memory_size = args.replay_memory_size
    discount = 0.99
    replay_start_size = args.replay_start_size
    history_length = 1
    eps_start = args.eps_start
    eps_min = args.eps_min
    # Epsilon drops linearly by eps_decay on every epsilon-greedy step until
    # it reaches eps_min (after ``decay_rate`` such steps).
    eps_decay = (eps_start - eps_min) / args.decay_rate
    eps_curr = eps_start
    # freeze_interval above is counted in environment steps, but it is
    # compared with training_steps (one per update_interval env steps).
    # Floor division keeps it an integer on Python 3 as well as Python 2.
    freeze_interval //= update_interval
    minibatch_size = 32

    testing = args.testing == 1
    continue_training = args.continue_training == 1

    game = HunterWorld(width=256,
                       height=256,
                       num_preys=10,
                       draw=False,
                       num_hunters=2,
                       num_toxins=toxin)

    env = PLE(game,
              fps=30,
              force_fps=True,
              display_screen=False,
              reward_values=rewards,
              resized_rows=80,
              resized_cols=80,
              num_steps=2)

    # NOTE(review): assumes each flattened observation has 148 entries.
    replay_memory = ReplayMemory(state_dim=(148, ),
                                 history_length=history_length,
                                 memory_size=replay_memory_size,
                                 replay_start_size=replay_start_size,
                                 state_dtype='float32')

    # Joint action space: one entry per (hunter-1 action, hunter-2 action)
    # pair; the Q-network's output index selects a row of action_map.
    action_set = env.get_action_set()
    action_map = np.array([[action1, action2]
                           for action1 in action_set[0].values()
                           for action2 in action_set[1].values()])
    action_num = action_map.shape[0]

    # target1 (batch 1) selects actions, target32 (batch 32) computes TD
    # targets, and Qnet is the network being trained.  `dir` is not defined
    # in this function: it resolves to a module-level `dir` if the file has
    # one, otherwise to the builtin dir() -- TODO confirm.
    target1 = Qnetwork(actions_num=action_num,
                       q_ctx=q_ctx,
                       isTrain=False,
                       batch_size=1,
                       dir=dir,
                       folder=a1_Qnet_folder)
    target32 = Qnetwork(actions_num=action_num,
                        q_ctx=q_ctx,
                        isTrain=False,
                        batch_size=32,
                        dir=dir,
                        folder=a1_Qnet_folder)
    Qnet = Qnetwork(actions_num=action_num,
                    q_ctx=q_ctx,
                    isTrain=True,
                    batch_size=32,
                    dir=dir,
                    folder=a1_Qnet_folder)

    if testing:
        env.force_fps = False
        env.game.draw = True
        env.display_screen = True
        Qnet.load_params(testing_epoch)
    elif continue_training:
        epoch_range = range(start_epoch, epoch_num + start_epoch)
        Qnet.load_params(start_epoch - 1)
        logging_config(logging, dir, save_log, file_name)
    else:
        logging_config(logging, dir, save_log, file_name)

    # Start both frozen copies from Qnet's current (possibly loaded) weights.
    copyTargetQNetwork(Qnet.model, target1.model)
    copyTargetQNetwork(Qnet.model, target32.model)

    logging.info('args=%s' % args)
    logging.info('config=%s' % config.__dict__)
    print_params(logging, Qnet.model)

    training_steps = 0
    total_steps = 0
    for epoch in epoch_range:
        steps_left = steps_per_epoch
        episode = 0
        epoch_reward = 0
        start = time.time()
        env.reset_game()
        while steps_left > 0:
            episode += 1
            episode_loss = 0.0
            episode_q_value = 0.0
            episode_update_step = 0
            episode_action_step = 0
            episode_reward = 0
            episode_step = 0
            # Counts steps whose summed reward is negative (logged as "Collision").
            collisions = 0.0
            time_episode_start = time.time()
            env.reset_game()
            while not env.game_over():
                # Act epsilon-greedily (and decay epsilon) only once the replay
                # memory holds more than replay_start_size transitions; before
                # that, act uniformly at random.
                if replay_memory.size >= history_length and replay_memory.size > replay_start_size:
                    do_exploration = (np.random.rand() < eps_curr)
                    eps_curr = max(eps_curr - eps_decay, eps_min)
                    if do_exploration:
                        action = np.random.randint(action_num)
                    else:
                        # Greedy action from the frozen copy target1, not Qnet.
                        current_state = replay_memory.latest_slice()
                        state = nd.array(
                            current_state.reshape((1, ) + current_state.shape),
                            ctx=q_ctx)
                        target1.model.forward(mx.io.DataBatch([state], []))
                        q_value = target1.model.get_outputs()[0].asnumpy()[0]
                        action = np.argmax(q_value)
                        episode_q_value += q_value[action]
                        episode_action_step += 1
                else:
                    action = np.random.randint(action_num)

                next_ob, reward, terminal_flag = env.act(action_map[action])

                # Per-hunter rewards are summed into one scalar reward.
                reward = np.sum(reward)
                replay_memory.append(
                    np.array(next_ob).flatten(), action, reward, terminal_flag)

                total_steps += 1
                episode_reward += reward
                if reward < 0:
                    collisions += 1
                episode_step += 1

                if total_steps % update_interval == 0 and replay_memory.size > replay_start_size:
                    training_steps += 1

                    (state_batch, sampled_actions, sampled_rewards,
                     nextstate_batch, terminate_flags) = replay_memory.sample(
                        batch_size=minibatch_size)
                    state_batch = nd.array(state_batch, ctx=q_ctx)
                    actions_batch = nd.array(sampled_actions, ctx=q_ctx)
                    reward_batch = nd.array(sampled_rewards, ctx=q_ctx)
                    terminate_flags = nd.array(terminate_flags, ctx=q_ctx)

                    target32.model.forward(
                        mx.io.DataBatch([nd.array(nextstate_batch, ctx=q_ctx)],
                                        []))
                    Qvalue = target32.model.get_outputs()[0]

                    # TD target: r + discount * max_a Q_target(s', a), with the
                    # bootstrap term zeroed for terminal transitions.
                    y_batch = reward_batch + nd.choose_element_0index(
                        Qvalue, nd.argmax_channel(Qvalue)) * (
                            1.0 - terminate_flags) * discount

                    # The network's inputs include the taken actions and the TD
                    # targets; presumably its symbol builds the loss from them --
                    # TODO confirm against Qnetwork.
                    Qnet.model.forward(mx.io.DataBatch(
                        [state_batch, actions_batch, y_batch], []),
                                       is_train=True)
                    Qnet.model.backward()
                    Qnet.model.update()

                    # Logging only: every 10th training step, record
                    # 0.5 * sum of squared TD errors from the forward pass above.
                    if training_steps % 10 == 0:
                        loss1 = 0.5 * nd.square(
                            nd.choose_element_0index(
                                Qnet.model.get_outputs()[0], actions_batch) -
                            y_batch)
                        episode_loss += nd.sum(loss1).asscalar()
                        episode_update_step += 1

                    # Refresh both frozen copies every freeze_interval training steps.
                    if training_steps % freeze_interval == 0:
                        copyTargetQNetwork(Qnet.model, target1.model)
                        copyTargetQNetwork(Qnet.model, target32.model)

            steps_left -= episode_step
            time_episode_end = time.time()
            epoch_reward += episode_reward
            info_str = "Epoch:%d, Episode:%d, Steps Left:%d/%d/%d, Reward:%f, fps:%f, Exploration:%f" \
                       % (epoch, episode, steps_left, episode_step, steps_per_epoch, episode_reward,
                          episode_step / (time_episode_end - time_episode_start), eps_curr)

            # Guard against an episode that ends before any step is taken.
            collision_rate = collisions / episode_step if episode_step > 0 else 0.0
            info_str += ", Collision:%f/%d " % (collision_rate, collisions)

            if episode_update_step > 0:
                info_str += ", Avg Loss:%f/%d" % (
                    episode_loss / episode_update_step, episode_update_step)
            if episode_action_step > 0:
                info_str += ", Avg Q Value:%f/%d " % (
                    episode_q_value / episode_action_step, episode_action_step)

            logging.info(info_str)
            print(info_str)

        end = time.time()
        fps = steps_per_epoch / (end - start)
        Qnet.save_params(epoch)
        print("Epoch:%d, FPS:%f, Avg Reward: %f/%d" % (
            epoch, fps, epoch_reward / float(episode), episode))
Example #14
0
def main():
    """Train a DQN (``dqn_sym_nature`` network) on an Atari ROM.

    Each epoch runs episodes until ``steps_per_epoch`` game steps (as counted
    by ``game.episode_step``) are used up.  The agent acts epsilon-greedily
    once the replay memory can be sampled and does one minibatch update
    every ``update_interval`` steps.  With ``--double-q`` the online network
    picks the next-state action and the target network scores it; otherwise
    the target network does both.  Parameters are saved to
    ``args.dir_path`` at the end of every epoch.

    NOTE(review): ``--kvstore-update-period`` and ``--kv-type`` are parsed
    but not used here.  ``--eps`` is always forwarded to
    ``mx.optimizer.create``; confirm the optimizer chosen with
    ``--optimizer`` accepts an ``eps`` argument.

    Depends on names defined outside this function: ``parse_ctx``,
    ``AtariGame``, ``dqn_sym_nature``, ``Base``, ``DQNInitializer`` and the
    random generator ``npy_rng``.
    """
    parser = argparse.ArgumentParser(description='Script to test the trained network on a game.')
    parser.add_argument('-r', '--rom', required=False, type=str,
                        default=os.path.join('roms', 'breakout.bin'),
                        help='Path of the ROM File.')
    parser.add_argument('-v', '--visualization', action='store_true',
                        help='Visualize the runs.')
    parser.add_argument('--lr', required=False, type=float, default=0.01,
                        help='Learning rate of the AdaGrad optimizer')
    parser.add_argument('--eps', required=False, type=float, default=0.01,
                        help='Eps of the AdaGrad optimizer')
    parser.add_argument('--clip-gradient', required=False, type=float, default=None,
                        help='Clip threshold of the AdaGrad optimizer')
    parser.add_argument('--double-q', action='store_true',
                        help='Use Double DQN only if specified')
    parser.add_argument('--wd', required=False, type=float, default=0.0,
                        help='Weight of the L2 Regularizer')
    parser.add_argument('-c', '--ctx', required=False, type=str, default='gpu',
                        help='Running Context. E.g `-c gpu` or `-c gpu1` or `-c cpu`')
    parser.add_argument('-d', '--dir-path', required=False, type=str, default='',
                        help='Saving directory of model files.')
    parser.add_argument('--start-eps', required=False, type=float, default=1.0,
                        help='Eps of the epsilon-greedy policy at the beginning')
    parser.add_argument('--replay-start-size', required=False, type=int, default=50000,
                        help='The step that the training starts')
    parser.add_argument('--kvstore-update-period', required=False, type=int, default=1,
                        help='The period that the worker updates the parameters from the sever')
    parser.add_argument('--kv-type', required=False, type=str, default=None,
                        help='type of kvstore, default will not use kvstore, could also be dist_async')
    parser.add_argument('--optimizer', required=False, type=str, default="adagrad",
                        help='type of optimizer')
    args = parser.parse_args()

    # Default output directory: dqn-<rom file stem>-lr<learning rate>.
    if args.dir_path == '':
        rom_name = os.path.splitext(os.path.basename(args.rom))[0]
        args.dir_path = 'dqn-%s-lr%g' % (rom_name, args.lr)
    replay_start_size = args.replay_start_size
    max_start_nullops = 30
    replay_memory_size = 1000000
    history_length = 4
    rows = 84
    cols = 84

    ctx = parse_ctx(args.ctx)
    # Only the first parsed context is used.
    q_ctx = mx.Context(*ctx[0])

    game = AtariGame(rom_path=args.rom, resize_mode='scale', replay_start_size=replay_start_size,
                     resized_rows=rows, resized_cols=cols, max_null_op=max_start_nullops,
                     replay_memory_size=replay_memory_size, display_screen=args.visualization,
                     history_length=history_length)

    # Hyper-parameters of the "Nature" DQN setup.
    freeze_interval = 10000
    epoch_num = 200
    steps_per_epoch = 250000
    update_interval = 4
    discount = 0.99

    eps_start = args.start_eps
    eps_min = 0.1
    # Epsilon drops linearly to eps_min over 1,000,000 epsilon-greedy steps.
    eps_decay = (eps_start - eps_min) / 1000000
    eps_curr = eps_start
    # freeze_interval above is counted in game steps, but it is compared with
    # training_steps (one per update_interval game steps).  Floor division
    # keeps it an integer on Python 3 as well as Python 2.
    freeze_interval //= update_interval
    minibatch_size = 32
    action_num = len(game.action_set)

    data_shapes = {'data': (minibatch_size, history_length) + (rows, cols),
                   'dqn_action': (minibatch_size,), 'dqn_reward': (minibatch_size,)}
    dqn_sym = dqn_sym_nature(action_num)
    qnet = Base(data_shapes=data_shapes, sym_gen=dqn_sym, name='QNet',
                initializer=DQNInitializer(factor_type="in"),
                ctx=q_ctx)
    target_qnet = qnet.copy(name="TargetQNet", ctx=q_ctx)

    optimizer = mx.optimizer.create(name=args.optimizer, learning_rate=args.lr, eps=args.eps,
                                    clip_gradient=args.clip_gradient,
                                    rescale_grad=1.0, wd=args.wd)
    updater = mx.optimizer.get_updater(optimizer)

    qnet.print_stat()
    target_qnet.print_stat()

    # Begin Playing Game
    training_steps = 0
    total_steps = 0
    for epoch in range(epoch_num):
        # Run Epoch
        steps_left = steps_per_epoch
        episode = 0
        epoch_reward = 0
        start = time.time()
        game.start()
        while steps_left > 0:
            # Running New Episode
            episode += 1
            episode_loss = 0.0
            episode_q_value = 0.0
            episode_update_step = 0
            episode_action_step = 0
            time_episode_start = time.time()
            game.begin_episode(steps_left)
            while not game.episode_terminate:
                # 1. Choose an action: epsilon-greedy once a state is available and
                #    the replay memory can be sampled, uniformly random otherwise.
                if game.state_enabled and game.replay_memory.sample_enabled:
                    do_exploration = (npy_rng.rand() < eps_curr)
                    eps_curr = max(eps_curr - eps_decay, eps_min)
                    if do_exploration:
                        action = npy_rng.randint(action_num)
                    else:
                        # TODO: play several game instances at once and batch their
                        # current_state() for a single forward pass; the per-step
                        # forward/asnumpy round trip is a likely bottleneck (profile).
                        current_state = game.current_state()
                        # Scale by 1/255 (presumably 8-bit pixel values).
                        state = nd.array(current_state.reshape((1,) + current_state.shape),
                                         ctx=q_ctx) / float(255.0)
                        qval_npy = qnet.forward(is_train=False, data=state)[0].asnumpy()
                        action = numpy.argmax(qval_npy)
                        episode_q_value += qval_npy[0, action]
                        episode_action_step += 1
                else:
                    action = npy_rng.randint(action_num)

                # 2. Play one mega-step (the game may repeat the action internally).
                game.play(action)
                total_steps += 1

                # 3. Every update_interval steps, train once if the replay memory
                #    can be sampled.
                if total_steps % update_interval == 0 and game.replay_memory.sample_enabled:
                    # 3.1 Draw a minibatch from the replay memory.
                    training_steps += 1
                    episode_update_step += 1
                    states, actions, rewards, next_states, terminate_flags \
                        = game.replay_memory.sample(batch_size=minibatch_size)
                    states = nd.array(states, ctx=q_ctx) / float(255.0)
                    next_states = nd.array(next_states, ctx=q_ctx) / float(255.0)
                    actions = nd.array(actions, ctx=q_ctx)
                    rewards = nd.array(rewards, ctx=q_ctx)
                    terminate_flags = nd.array(terminate_flags, ctx=q_ctx)

                    # 3.2 TD targets: r + discount * Q_target(s', a*), zeroed for
                    #     terminal transitions.  a* comes from the target network
                    #     (standard DQN) or from the online network (double DQN).
                    if not args.double_q:
                        target_qval = target_qnet.forward(is_train=False, data=next_states)[0]
                        target_rewards = rewards + nd.choose_element_0index(target_qval,
                                                                nd.argmax_channel(target_qval))\
                                           * (1.0 - terminate_flags) * discount
                    else:
                        target_qval = target_qnet.forward(is_train=False, data=next_states)[0]
                        qval = qnet.forward(is_train=False, data=next_states)[0]

                        target_rewards = rewards + nd.choose_element_0index(target_qval,
                                                                nd.argmax_channel(qval))\
                                           * (1.0 - terminate_flags) * discount
                    outputs = qnet.forward(is_train=True,
                                           data=states,
                                           dqn_action=actions,
                                           dqn_reward=target_rewards)
                    qnet.backward()
                    qnet.update(updater=updater)

                    # 3.3 Loss for logging only: Huber form (quadratic for TD errors
                    #     up to 1, linear beyond).
                    diff = nd.abs(nd.choose_element_0index(outputs[0], actions) - target_rewards)
                    quadratic_part = nd.clip(diff, -1, 1)
                    loss = 0.5 * nd.sum(nd.square(quadratic_part)).asscalar() + \
                           nd.sum(diff - quadratic_part).asscalar()
                    episode_loss += loss

                    # 3.4 Copy the online weights into the target network every
                    #     freeze_interval training steps.
                    if training_steps % freeze_interval == 0:
                        qnet.copy_params_to(target_qnet)
            steps_left -= game.episode_step
            time_episode_end = time.time()
            # Update the statistics
            epoch_reward += game.episode_reward
            info_str = "Epoch:%d, Episode:%d, Steps Left:%d/%d, Reward:%f, fps:%f, Exploration:%f" \
                        % (epoch, episode, steps_left, steps_per_epoch, game.episode_reward,
                           game.episode_step / (time_episode_end - time_episode_start), eps_curr)
            if episode_update_step > 0:
                info_str += ", Avg Loss:%f/%d" % (episode_loss / episode_update_step,
                                                  episode_update_step)
            if episode_action_step > 0:
                info_str += ", Avg Q Value:%f/%d" % (episode_q_value / episode_action_step,
                                                  episode_action_step)
            # Per-episode statistics are only logged every 100th episode.
            if episode % 100 == 0:
                logging.info(info_str)
        end = time.time()
        fps = steps_per_epoch / (end - start)
        qnet.save_params(dir_path=args.dir_path, epoch=epoch)
        logging.info("Epoch:%d, FPS:%f, Avg Reward: %f/%d"
                     % (epoch, fps, epoch_reward / float(episode), episode))
Example #15
0
 def get_action(self, st):
     """Return the argmax action for one unbatched state ``st``.

     A leading batch axis of size 1 is added, ``self.infer_q_mod`` is run
     with ``is_train=False``, and ``nd.argmax_channel`` is applied to its
     first output.  The result is presumably an NDArray rather than a Python
     int -- callers needing an int must convert it.
     """
     batched_state = nd.expand_dims(st, axis=0)
     q_outputs = self.infer_q_mod.forward(is_train=False, data=batched_state)
     return nd.argmax_channel(q_outputs[0])