コード例 #1
0
ファイル: TransitionTest.py プロジェクト: gotian/rjohn
 def testStringCommands(self):
     """Test commands that operate on strings (insert a quoted string)."""
     word = "aLaMakota123"
     t = Transition()

     # A0"..." puts the quoted string in front of the word.
     rules = "A0\"testowe pole\""
     self.assertEqual(t.transform(rules, word), "testowe poleaLaMakota123")

     # Az"..." puts the quoted string after the end of the word.
     rules = "Az\"testowe pole\""
     self.assertEqual(t.transform(rules, word), "aLaMakota123testowe pole")
コード例 #2
0
ファイル: TransitionTest.py プロジェクト: gotian/rjohn
 def testNoRule(self):
     """Test that a rule string naming a nonexistent command is rejected.

     'H' is expected to be an unsupported command, so transform() must
     raise ValueError. assertRaises reports a clear "ValueError not
     raised" failure if it does not.
     """
     word = "aLaMakota123"
     t = Transition()

     rules = 'H'
     with self.assertRaises(ValueError):
         t.transform(rules, word)
コード例 #3
0
ファイル: TransitionTest.py プロジェクト: gotian/rjohn
 def testIDCommands(self):
     """Test commands performing insert and delete operations."""
     word = "aLaMakota123"
     t = Transition()

     # '[' drops the first character.
     rules = '['
     self.assertEqual(t.transform(rules, word), "LaMakota123")

     # ']' drops the last character.
     rules = ']'
     self.assertEqual(t.transform(rules, word), "aLaMakota12")

     # 'D5' deletes the character at position 5 ('k').
     rules = 'D5'
     self.assertEqual(t.transform(rules, word), "aLaMaota123")

     # 'x43' extracts 3 characters starting at position 4.
     rules = 'x43'
     self.assertEqual(t.transform(rules, word), "ako")

     # 'i7H' inserts 'H' at position 7.
     rules = 'i7H'
     self.assertEqual(t.transform(rules, word), "aLaMakoHta123")

     # 'o8Q' overwrites the character at position 8 with 'Q'.
     rules = 'o8Q'
     self.assertEqual(t.transform(rules, word), "aLaMakotQ123")
コード例 #4
0
ファイル: TransitionTest.py プロジェクト: gotian/rjohn
 def testLengthControlCommands(self):
     """Test length-control commands.

     A rejected word is signalled by transform() returning None.
     """
     word = "aLaMakota123"
     t = Transition()

     # Rejected: the 12-character word is not shorter than 5.
     rules = '<5'
     self.assertIsNone(t.transform(rules, word))

     # Kept unchanged with 'Z' as the length operand.
     rules = '<Z'
     self.assertEqual(t.transform(rules, word), "aLaMakota123")

     # Rejected with 'Z' as the length operand.
     rules = '>Z'
     self.assertIsNone(t.transform(rules, word))

     # Kept: the word is longer than 5.
     rules = '>5'
     self.assertEqual(t.transform(rules, word), "aLaMakota123")

     # "'9" truncates the word to its first 9 characters.
     rules = '\'9'
     self.assertEqual(t.transform(rules, word), "aLaMakota")
コード例 #5
0
ファイル: TransitionTest.py プロジェクト: gotian/rjohn
 def testCharacterClassCommands(self):
     """Test commands based on character classes.

     A rejected word is signalled by transform() returning None.
     """
     word = "aLaMakota123"
     t = Transition()

     # 'sab' replaces every 'a' with 'b'.
     rules = 'sab'
     self.assertEqual(t.transform(rules, word), "bLbMbkotb123")

     # 's?vH' replaces every vowel with 'H'.
     rules = 's?vH'
     self.assertEqual(t.transform(rules, word), "HLHMHkHtH123")

     # '@M' removes every 'M'.
     rules = '@M'
     self.assertEqual(t.transform(rules, word), "aLaakota123")

     # '@?c' removes every consonant.
     rules = '@?c'
     self.assertEqual(t.transform(rules, word), "aaaoa123")

     # '!1': the word contains '1', so it is rejected.
     rules = '!1'
     self.assertIsNone(t.transform(rules, word))

     # '!?v': the word contains a vowel, so it is rejected.
     rules = '!?v'
     self.assertIsNone(t.transform(rules, word))

     # '/1': the word contains '1', so it is kept.
     rules = '/1'
     self.assertEqual(t.transform(rules, word), "aLaMakota123")

     # '/?v': the word contains a vowel, so it is kept.
     rules = '/?v'
     self.assertEqual(t.transform(rules, word), "aLaMakota123")

     # '=0a': position 0 holds 'a', so the word is kept.
     rules = '=0a'
     self.assertEqual(t.transform(rules, word), "aLaMakota123")

     # '=A?d': position A holds a digit, so the word is kept.
     rules = '=A?d'
     self.assertEqual(t.transform(rules, word), "aLaMakota123")

     # '(b': the first character is not 'b', so the word is rejected.
     rules = '(b'
     self.assertIsNone(t.transform(rules, word))

     # '(?d': the first character is not a digit, so the word is rejected.
     rules = '(?d'
     self.assertIsNone(t.transform(rules, word))

     # ')H': the last character is not 'H', so the word is rejected.
     rules = ')H'
     self.assertIsNone(t.transform(rules, word))

     # ')?v': the last character is not a vowel, so the word is rejected.
     rules = ')?v'
     self.assertIsNone(t.transform(rules, word))

     # '%4a': the word has 4 'a' characters, so it is kept.
     rules = '%4a'
     self.assertEqual(t.transform(rules, word), "aLaMakota123")

     # '%3?d': the word has 3 digits, so it is kept.
     rules = '%3?d'
     self.assertEqual(t.transform(rules, word), "aLaMakota123")
コード例 #6
0
ファイル: TransitionTest.py プロジェクト: gotian/rjohn
 def testCharacterClasses(self):
     """Test character classes.

     Each case applies the '(' command with one character class and checks
     that a one-character word belonging to that class is passed through
     unchanged.
     """
     t = Transition()

     # (rules, word): ``word`` is a single character from the class named
     # in ``rules``.
     cases = [
         ('(??', "?"),
         ('(?v', "a"),
         ('(?c', "b"),
         ('(?w', " "),
         ('(?p', ";"),
         ('(?s', "*"),
         ('(?l', "v"),
         ('(?u', "P"),
         ('(?d', "6"),
         ('(?a', "s"),
         ('(?x', "3"),
         ('(?z', " "),
     ]
     for rules, word in cases:
         # The message identifies which class failed, since all cases
         # share one assertion line.
         self.assertEqual(t.transform(rules, word), word,
                          "rule %r did not keep word %r" % (rules, word))
コード例 #7
0
ファイル: TransitionTest.py プロジェクト: gotian/rjohn
 def testSimpleCommands(self):
     """Test simple (argument-less or single-argument) commands."""
     word = "aLaMakota123"
     t = Transition()

     # ':' leaves the word unchanged.
     rules = ':'
     self.assertEqual(t.transform(rules, word), word)

     # 'c' capitalizes: first letter upper, the rest lower.
     rules = 'c'
     self.assertEqual(t.transform(rules, word), "Alamakota123")

     # 'l' lowercases everything.
     rules = 'l'
     self.assertEqual(t.transform(rules, word), "alamakota123")

     # 'u' uppercases everything.
     rules = 'u'
     self.assertEqual(t.transform(rules, word), "ALAMAKOTA123")

     # 'C' is the inverse of 'c': first letter lower, the rest upper.
     rules = 'C'
     self.assertEqual(t.transform(rules, word), "aLAMAKOTA123")

     # 't' toggles the case of every letter.
     rules = 't'
     self.assertEqual(t.transform(rules, word), "AlAmAKOTA123")

     # 't0' toggles the case of the letter at position 0 only.
     rules = 't0'
     self.assertEqual(t.transform(rules, word), "ALaMakota123")

     # 'r' reverses the word.
     rules = 'r'
     self.assertEqual(t.transform(rules, word), "321atokaMaLa")

     # 'd' duplicates the word.
     rules = 'd'
     self.assertEqual(t.transform(rules, word), "aLaMakota123aLaMakota123")

     # '$q' appends 'q'.
     rules = '$q'
     self.assertEqual(t.transform(rules, word), "aLaMakota123q")

     # '^q' prepends 'q'.
     rules = '^q'
     self.assertEqual(t.transform(rules, word), "qaLaMakota123")

     # '{' rotates the word one position to the left.
     rules = '{'
     self.assertEqual(t.transform(rules, word), "LaMakota123a")

     # '}' rotates the word one position to the right.
     rules = '}'
     self.assertEqual(t.transform(rules, word), "3aLaMakota12")

     # 'f' appends the reversed word (reflection).
     rules = 'f'
     self.assertEqual(t.transform(rules, word), "aLaMakota123321atokaMaLa")
コード例 #8
0
    def replay(self):
        """Run one Double-DQN training step on a mini-batch from replay memory.

        Does nothing until ``self.memory`` holds at least ``self.batch_size``
        transitions. Otherwise it:

        1. samples a mini-batch;
        2. computes Q(s_t, a_t) with ``self.policy_net``;
        3. builds the TD target
           ``r + gamma * Q_target(s_{t+1}, argmax_a Q_policy(s_{t+1}, a))``,
           with zero future value for final next states;
        4. takes one ``self.optimizer`` step on the smooth-L1 (Huber) loss;
        5. when ``self.prioritized`` is set, writes each sample's absolute
           TD error back via ``self.memory.update``.
        """
        # Do nothing while the memory holds fewer samples than one batch.
        if len(self.memory) < self.batch_size:
            return

        # Sample transitions plus their memory indices (used for priorities).
        transitions, indexes = self.memory.sample(self.batch_size)

        # --- Build the mini-batch ------------------------------------------
        # Transpose batch_size Transition records into one Transition whose
        # fields are tuples (state x batch, action x batch, ...), then do the
        # same for the per-sample State records.
        batch = Transition(*zip(*transitions))
        batch_state = State(*zip(*batch.state))

        # A next state counts as final when it is stored as None or with a
        # None pose. Deriving the mask and the concatenated next-state rows
        # from the same flags keeps them the same length.
        # NOTE(review): confirm how the memory stores terminal next states.
        non_final_flags = [s is not None and s.pose is not None
                           for s in batch.next_state]
        non_final_mask = torch.tensor(non_final_flags,
                                      dtype=torch.bool).to(self.device)
        non_final_next_states = [
            s for s, keep in zip(batch.next_state, non_final_flags) if keep
        ]

        # Concatenate the per-sample tensors along dim 0 into batch tensors.
        pose_batch = torch.cat(batch_state.pose).to(self.device)
        lidar_batch = torch.cat(batch_state.lidar).to(self.device)
        image_batch = torch.cat(batch_state.image).to(self.device)
        mask_batch = torch.cat(batch_state.mask).to(self.device)
        action_batch = torch.cat(batch.action).to(self.device)
        reward_batch = torch.cat(batch.reward).to(self.device)

        # The forward passes below run with the policy network in eval mode.
        self.policy_net.eval()

        # Q(s_t, a_t): gather the Q value of the action actually taken.
        state_action_values = self.policy_net(pose_batch, lidar_batch,
                                              image_batch, mask_batch).gather(
                                                  1, action_batch)

        # Future value per sample; stays 0 for final next states. The guard
        # avoids torch.cat on an empty list when every next state is final.
        next_state_values = torch.zeros(self.batch_size,
                                        dtype=torch.float32,
                                        device=self.device)
        if non_final_next_states:
            next_poses = torch.cat(
                [s.pose for s in non_final_next_states]).to(self.device)
            next_lidars = torch.cat(
                [s.lidar for s in non_final_next_states]).to(self.device)
            next_images = torch.cat(
                [s.image for s in non_final_next_states]).to(self.device)
            next_masks = torch.cat(
                [s.mask for s in non_final_next_states]).to(self.device)

            # Bootstrap targets must not receive gradients.
            with torch.no_grad():
                # Double DQN: the policy net selects the next action...
                next_actions = self.policy_net(
                    next_poses, next_lidars, next_images,
                    next_masks).max(1)[1].view(-1, 1)
                # ...and the target net evaluates it.
                next_state_values[non_final_mask] = self.target_net(
                    next_poses, next_lidars, next_images,
                    next_masks).gather(1, next_actions).squeeze(1)

        # TD target r + gamma * V(s_{t+1}), shaped (batch_size, 1) to match
        # state_action_values.
        expected_state_action_values = (
            reward_batch + self.gamma * next_state_values).unsqueeze(1)

        # Switch back to training mode.
        # TODO: No need? (every forward pass above already ran in eval mode)
        self.policy_net.train()

        # smooth_l1_loss is the Huber loss.
        loss = F.smooth_l1_loss(state_action_values,
                                expected_state_action_values)

        # Update the network.
        self.optimizer.zero_grad()  # reset gradients
        loss.backward()  # backpropagate
        self.optimizer.step()  # update parameters

        # Update priorities with each sample's absolute TD error.
        # `is not None` (not `!=`) so that an array-like `indexes` cannot
        # produce an ambiguous elementwise truth value. Copying both columns
        # to Python once avoids a device sync per element.
        if self.prioritized and indexes is not None:
            targets = expected_state_action_values.detach().reshape(
                -1).tolist()
            predictions = state_action_values.detach().reshape(-1).tolist()
            for index, target, prediction in zip(indexes, targets,
                                                 predictions):
                self.memory.update(index, abs(target - prediction))