Code Example #1
File: VIN.py Project: jhare96/reinforcement-learning
def backprop(self, states, locs, R, actions):
    x, y = zip(*locs)
    # Q-values for each state at its (x, y) grid location.
    Qsa = self.forward(totorch(states, self.device),
                       torch.tensor(x).to(self.device),
                       torch.tensor(y).to(self.device))
    actions_onehot = totorch(one_hot(actions, self.action_size), self.device)
    # Q-value of the action actually taken in each state.
    Qvalue = torch.sum(Qsa * actions_onehot, dim=1)
    # Mean squared error between the target return R and the predicted Q-value.
    loss = torch.mean(torch.square(totorch(R, self.device).float() - Qvalue))

    loss.backward()
    self.optim.step()
    self.optim.zero_grad()
    return loss.detach().cpu().numpy()
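The snippets in this listing lean on a few conversion helpers from the project (totorch, tonumpy, one_hot, and the *_many variants further down) whose definitions are not shown here. The following is a minimal sketch of what they are assumed to do; the actual utilities in jhare96/reinforcement-learning may handle dtypes and devices differently.

# Sketch only: assumed behaviour of the conversion helpers used above.
import numpy as np
import torch

def totorch(x, device='cpu'):
    # numpy array (or nested sequence) -> float32 tensor on `device`.
    return torch.as_tensor(np.asarray(x), dtype=torch.float32, device=device)

def tonumpy(x):
    # tensor -> numpy array on the host, detached from the graph.
    return x.detach().cpu().numpy()

def one_hot(actions, num_classes):
    # integer actions of shape (N,) -> one-hot matrix of shape (N, num_classes).
    return np.eye(num_classes, dtype=np.float32)[np.asarray(actions, dtype=np.int64)]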
Code Example #2
File: A3C.py Project: jhare96/reinforcement-learning
def train(global_model, model, env, nsteps, num_episodes, ID):
    opt = torch.optim.RMSprop(global_model.parameters(), lr=1e-3)
    episode = 0
    episode_steps = 0
    episode_score = 0
    T = 0
    state = env.reset()
    start = time.time()
    while episode < num_episodes:
        rollout = []
        for t in range(nsteps):
            with torch.no_grad():
                policy, value = model(totorch(state[None], device='cpu'))
                policy, value = tonumpy(policy), tonumpy(value)
            action = np.random.choice(policy.shape[1], p=policy[0])
            next_state, reward, done, info = env.step(action)
            episode_score += reward
            rollout.append((state, action, reward, value, done))
            state = next_state

            T += 1
            episode_steps += 1

            if done or t == nsteps-1:
                states, actions, rewards, values, dones = stack_many(*zip(*rollout))
                # Bootstrap from the value of the state that follows the rollout.
                with torch.no_grad():
                    _, last_values = model.forward(totorch(next_state[None], device='cpu'))
                    last_values = last_values.cpu().numpy()

                R = lambda_return(rewards, values, last_values, dones, gamma=0.9, lambda_=0.95, clip=False)

                loss = update_params(model, global_model, opt, states, actions, R)

                if done:
                    episode += 1
                    state = env.reset()
                    if episode % 1 == 0:
                        time_taken = time.time() - start
                        print(f'worker {ID}, total worker steps {T:,}, local episode {episode}, episode score {episode_score}, episode steps {episode_steps}, time taken {time_taken:,.1f}s, fps {episode_steps/time_taken:.2f}')
                    episode_steps = 0
                    episode_score = 0
                    start = time.time()
                    break
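The rollout above builds its targets with lambda_return, which is also defined elsewhere in the project. The sketch below shows the standard recursion G_t = r_t + gamma * (1 - d_t) * ((1 - lambda) * V(s_{t+1}) + lambda * G_{t+1}) that a function with this signature would typically implement; the repository's own version (including what its clip flag does) may differ.

# Sketch only: a lambda-return with the signature used above; `clip` is
# accepted but ignored here.
import numpy as np

def lambda_return(rewards, values, last_values, dones, gamma=0.99, lambda_=0.95, clip=False):
    # rewards, values, dones: arrays with leading time dimension T from the rollout;
    # last_values: value estimate for the state after the rollout (the bootstrap).
    T = len(rewards)
    R = np.zeros_like(values, dtype=np.float32)
    g = last_values  # lambda-return of the state that follows the rollout
    for t in reversed(range(T)):
        next_value = values[t + 1] if t + 1 < T else last_values
        not_done = 1.0 - dones[t]
        # Blend the one-step bootstrap with the recursive lambda-return,
        # cutting the bootstrap at episode boundaries.
        g = rewards[t] + gamma * not_done * ((1.0 - lambda_) * next_value + lambda_ * g)
        R[t] = g
    return R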
Code Example #3
File: VIN.py Project: jhare96/reinforcement-learning
def eval_state(self, state, loc):
    with torch.no_grad():
        x, y = zip(*loc)
        x, y = torch.tensor(x).to(self.device), torch.tensor(y).to(self.device)
        state_torch = totorch(state, self.device)
        Qsa = self.model(state_torch, x, y)
    return tonumpy(Qsa)
Code Example #4
def evaluate(self, state: np.ndarray, hidden=None, done=None):
    state = totorch(state, self.device)
    # hidden is a tuple of recurrent-state arrays, converted element-wise.
    hidden = totorch_many(*hidden, device=self.device) if hidden is not None else None
    with torch.no_grad():
        policy, value, hidden = self.forward(state, hidden, done)
    return tonumpy(policy), tonumpy(value), tonumpy_many(*hidden)
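totorch_many and tonumpy_many are again project utilities not shown in this listing. A plausible reading, building on the totorch/tonumpy sketch under Code Example #1, is that they simply map the single-tensor converters over a tuple such as an LSTM (h, c) pair:

# Sketch only: assumed element-wise behaviour of the *_many helpers.
def totorch_many(*arrays, device='cpu'):
    return tuple(totorch(a, device) for a in arrays)

def tonumpy_many(*tensors):
    return tuple(tonumpy(t) for t in tensors)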
Code Example #5
File: A3C.py Project: jhare96/reinforcement-learning
def update_params(lm, gm, gopt, states, actions, R):
    # lm: local (worker) model, gm: shared global model, gopt: optimiser over gm's parameters.
    states, R, actions = totorch(states, 'cpu'), totorch(R, 'cpu'), totorch(actions, 'cpu')
    actions_onehot = F.one_hot(actions.long(), num_classes=lm.action_size)
    policies, values = lm.forward(states)
    loss = lm.loss(policies, R, values, actions_onehot)

    # Compute gradients on the local model.
    loss.backward()

    if lm.grad_clip is not None:
        torch.nn.utils.clip_grad_norm_(lm.parameters(), lm.grad_clip)

    # Copy the local gradients onto the global parameters and apply them there.
    for local_param, global_param in zip(lm.parameters(), gm.parameters()):
        global_param._grad = local_param.grad

    gopt.step()
    gopt.zero_grad()

    # Synchronise the worker with the freshly updated global weights.
    lm.load_state_dict(gm.state_dict())
    return loss.detach().cpu().numpy()
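This is the usual asynchronous-update pattern: gradients are computed on the worker's local copy, written into the global model's .grad fields, applied by the shared optimiser, and the worker then re-synchronises with load_state_dict. The listing does not show how workers are launched; a hypothetical launcher that reuses the train function from Code Example #2 could look like the sketch below, where make_local_model and make_env are placeholder factories, not part of the repository. With torch.multiprocessing instead of threads, global_model.share_memory() would additionally be needed so the parameters are visible across processes.

# Sketch only: one local model per worker, all updating a single global model.
import threading

def launch(global_model, make_local_model, make_env, nsteps=20, num_episodes=1000, num_workers=4):
    workers = []
    for ID in range(num_workers):
        local_model = make_local_model()
        # Start each worker from the current global weights.
        local_model.load_state_dict(global_model.state_dict())
        t = threading.Thread(target=train,
                             args=(global_model, local_model, make_env(), nsteps, num_episodes, ID))
        t.start()
        workers.append(t)
    for t in workers:
        t.join()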
Code Example #6
def get_pixel_control(self, state: np.ndarray):
    with torch.no_grad():
        enc_state = self.policy.model(totorch(state, self.device))
        Qaux = self.Qaux(enc_state)
    return tonumpy(Qaux)
Code Example #7
def evaluate(self, state):
    with torch.no_grad():
        policy, value_extr, value_intr = self.forward(totorch(state, self.device))
    return tonumpy(policy), tonumpy(value_extr), tonumpy(value_intr)
Code Example #8
def evaluate(self, state: np.ndarray):
    with torch.no_grad():
        policy, _ = self.policy.forward(totorch(state, self.policy.device))
        value = self.value.forward(totorch(state, self.value.device))
    return tonumpy(policy), tonumpy(value)
Code Example #9
def get_value(self, state: np.ndarray):
    with torch.no_grad():
        value = self.value.forward(totorch(state, self.value.device))
    return tonumpy(value)
Code Example #10
def get_policy(self, state: np.ndarray):
    with torch.no_grad():
        policy, Adv = self.policy.forward(totorch(state, self.policy.device))
    return tonumpy(policy), tonumpy(Adv)
Code Example #11
def evaluate(self, state: np.ndarray):
    state = totorch(state, self.device)
    with torch.no_grad():
        policy, value = self.forward(state)
    return tonumpy(policy), tonumpy(value)
Code Example #12
def evaluate(self, state):
    with torch.no_grad():
        Qsa = self.forward(totorch(state, self.device))
    return Qsa.cpu().numpy()
Code Example #13
def get_pixel_control(self, state: np.ndarray, action_reward, hidden):
    state, action_reward = totorch(state, self.device), totorch(action_reward, self.device)
    hidden = totorch_many(*hidden, device=self.device)
    with torch.no_grad():
        lstm_state, _ = self.policy.lstm_forward(state, action_reward, hidden, done=None)
        Qaux = self.Qaux(lstm_state)
    return tonumpy(Qaux)