예제 #1
0
    def reset(self, deterministic_input=False, deterministic_pair=None):
        """Start a new episode and return its initial observation.

        A (source, target) pair is taken either from ``deterministic_pair`` or
        by sampling ``self.train_data``. The source token ids are fed through
        ``self.encoder`` one step at a time. The results are packed into a
        5-element observation list.

        Args:
            deterministic_input: if True, use ``deterministic_pair`` instead of
                sampling a random pair from ``self.train_data``.
            deterministic_pair: the pair to use when ``deterministic_input`` is
                True. Only its first element (the source key) is read.

        Returns:
            list: ``[src_plain, encoder_padded, decoder_hidden, curr_input,
            curr_index]``. ``encoder_padded`` and ``decoder_hidden`` are
            detached numpy arrays. ``curr_input`` is ``np.array([SOS_token])``
            and ``curr_index`` is 0.

        Raises:
            ValueError: if ``deterministic_input`` is True but
                ``deterministic_pair`` is None.
        """
        if deterministic_input:
            if deterministic_pair is None:
                raise ValueError(
                    "deterministic_pair must be provided when deterministic_input is True"
                )
            pairs = deterministic_pair
        else:
            pairs = random.choice(self.train_data)
        src_plain = pairs[0]

        # input_ids[src_plain] holds (source ids, target ids); encoding only
        # needs the source ids.
        src_id = self.encoder.input_ids[src_plain][0]

        # Everything returned below is detached and converted to numpy, so no
        # autograd graph is needed while encoding.
        with torch.no_grad():
            encoder_hidden = self.encoder.initHidden()
            encoder_outputs = torch.zeros(
                self.env_max_step, self.encoder.hidden_size, device=device
            )
            src_tensor = torch.tensor(src_id, device=device)
            # NOTE(review): raises IndexError if len(src_id) > self.env_max_step.
            for ei in range(len(src_id)):
                encoder_output, encoder_hidden = self.encoder(src_tensor[ei], encoder_hidden)
                encoder_outputs[ei] += encoder_output[0, 0]

            # Add a leading batch dimension: (1, env_max_step, hidden_size).
            # Assumes encoder and decoder share the same hidden_size.
            encoder_padded = torch.zeros(1, self.env_max_step, self.decoder.hidden_size)
            encoder_padded[0] = encoder_outputs

        # The decoder starts from the encoder's final hidden state.
        decoder_hidden = encoder_hidden

        # detach so gradients never propagate into the environment.
        return [
            src_plain,
            ptu.to_numpy(encoder_padded.detach()),
            ptu.to_numpy(decoder_hidden.detach()),
            np.array([SOS_token]),
            0,
        ]
예제 #2
0
    def step(self, observation, action):
        """Advance the episode by one decoding step.

        Args:
            observation: ``[src_plain, encoder_padded, decoder_hidden,
                curr_input, curr_index]``. ``encoder_padded`` and
                ``decoder_hidden`` are numpy arrays, as built by ``reset``.
            action: the chosen token id(s). ``action[0]`` is fed to the
                decoder, and ``action`` is compared with the target token.

        Returns:
            tuple: ``(next_observation, reward, done)``.
            - ``reward`` is 10 if ``action`` equals the target token at
              ``curr_index``, else -1.
            - After the last target position, ``done`` is True and
              ``next_observation`` is ``[]``.
            - Otherwise ``next_observation`` carries the new decoder hidden
              state, ``action`` as the next input, and ``curr_index + 1``.
        """
        src_plain = observation[0]
        target_id = self.input_ids[src_plain][1]
        curr_index = observation[4]
        # Sanity check before doing any decoder work.
        assert len(target_id) == len(src_plain), src_plain

        # All decoder inputs must live on self.device, including encoder_padded.
        # The new hidden state is detached, so no autograd graph is needed.
        with torch.no_grad():
            encoder_padded = torch.from_numpy(observation[1]).to(self.device)
            prev_hidden = torch.from_numpy(observation[2]).to(self.device)
            action_cur = torch.tensor([[action[0]]], device=self.device)
            decoded_result = self.decoder(action_cur, prev_hidden, encoder_padded)
        next_hidden = ptu.to_numpy(decoded_result[1].detach())

        # The reward can't be too small, otherwise there is no signal.
        # It can't be too large, otherwise the agent is satisfied after
        # learning little. With +10 / -1, one correct prediction offsets
        # ten incorrect ones.
        if action == target_id[curr_index]:
            reward = 10
        else:
            reward = -1

        done = bool(curr_index + 1 == len(target_id))
        if done:
            next_observation = []
        else:
            next_observation = [
                src_plain, observation[1], next_hidden, action, curr_index + 1
            ]
        return next_observation, reward, done
예제 #3
0
    def step(self, observation, action):
        """Advance the episode by one decoding step.

        Args:
            observation: ``[src_plain, encoder_padded, decoder_hidden,
                curr_input, curr_index]``. ``encoder_padded`` and
                ``decoder_hidden`` are numpy arrays, as built by ``reset``.
            action: either the chosen token id(s) directly, or a list
                ``[token_ids, q_values]``. Only the token ids are used:
                ``token_ids[0]`` is fed to the decoder, and ``token_ids`` is
                compared with the target token.

        Returns:
            tuple: ``(next_observation, reward, done)``.
            - ``reward`` is 10 if the token ids equal the target token at
              ``curr_index``, else -1.
            - After the last target position, ``done`` is True and
              ``next_observation`` is ``[]``.
            - Otherwise ``next_observation`` carries the new decoder hidden
              state, the token ids as the next input, and ``curr_index + 1``.
        """
        # Accept either a bare action or an [action, q_values] pair.
        # The comparison below must use the token ids, never the whole pair:
        # a list never equals an int, so comparing the pair would always
        # yield -1.
        ac = action[0] if isinstance(action, list) else action

        src_plain = observation[0]
        target_id = self.input_ids[src_plain][1]
        curr_index = observation[4]
        # Sanity check before doing any decoder work.
        assert len(target_id) == len(src_plain), src_plain

        # All decoder inputs must live on self.device, including encoder_padded.
        # The new hidden state is detached, so no autograd graph is needed.
        with torch.no_grad():
            encoder_padded = torch.from_numpy(observation[1]).to(self.device)
            prev_hidden = torch.from_numpy(observation[2]).to(self.device)
            action_cur = torch.tensor([[ac[0]]], device=self.device)
            decoded_result = self.decoder(action_cur, prev_hidden, encoder_padded)
        next_hidden = ptu.to_numpy(decoded_result[1].detach())

        # The reward can't be too small, otherwise there is no signal.
        # It can't be too large, otherwise the agent is satisfied after
        # learning little. With +10 / -1, one correct prediction offsets
        # ten incorrect ones.
        # NOTE: a confidence-weighted variant was prototyped here and is not
        # active. It scaled the reward by q_max/q_sum and treated true edits
        # differently from positions that need no edit.
        if ac == target_id[curr_index]:
            reward = 10
        else:
            reward = -1

        done = bool(curr_index + 1 == len(target_id))
        if done:
            next_observation = []
        else:
            next_observation = [
                src_plain, observation[1], next_hidden, ac, curr_index + 1
            ]
        return next_observation, reward, done
예제 #4
0
    def estimate_advantage(self, obs, q_values):
        """Return advantages ``q_values - baseline(obs)``.

        The actor's baseline prediction is rescaled to the mean/std of
        ``q_values`` before subtraction. Presumably the baseline is trained on
        normalized targets; confirm against the actor's update code.

        Args:
            obs: observations passed unchanged to ``self.actor.get_baseline``.
            q_values: per-sample returns, array-like.

        Returns:
            numpy.ndarray: advantages with the same shape as ``q_values``.

        Raises:
            ValueError: if the baseline prediction does not have the same
                number of elements as ``q_values``.
        """
        q_values = np.asarray(q_values)
        baselines_unnormalized = ptu.to_numpy(self.actor.get_baseline(obs))
        # Match the shape of q_values explicitly. Otherwise an (N, 1)
        # prediction would silently broadcast against (N,) q-values into an
        # (N, N) result.
        baselines_unnormalized = baselines_unnormalized.reshape(q_values.shape)
        baselines = baselines_unnormalized * np.std(q_values) + np.mean(q_values)
        return q_values - baselines
예제 #5
0
    def get_action(self, ob):
        """Pick, for each observation, the most probable token in its confusion set.

        For every row, the source token at position ``curr_index`` is found.
        It is looked up in ``self.lang.word2index`` and mapped through
        ``self.lang.correct_confused`` to a list of candidate token ids. The
        candidate with the highest probability in ``prob`` is chosen.

        Args:
            ob: one observation or a batch of observations. Each is
                ``[src_plain, encoder_padded, decoder_hidden, curr_input,
                curr_index]``.

        Returns:
            numpy array of shape ``(batch, 1)`` holding the chosen token ids.
        """
        ob = np.array(ob, dtype=object).reshape(-1, 5)
        # Only the probabilities are needed. Presumably prob is
        # (batch, vocab); TODO confirm.
        _, prob = self.get_action_distribution(ob)

        chosen = []
        for row, (src_plain, position) in enumerate(zip(ob[:, 0], ob[:, 4])):
            src_token = self.lang.word2index[src_plain[position]]
            candidates = self.lang.correct_confused[src_token]
            candidate_probs = prob[row, candidates].squeeze()
            chosen.append([candidates[int(torch.argmax(candidate_probs))]])
        return ptu.to_numpy(torch.tensor(chosen))
예제 #6
0
 def qa_values(self, obs):
     """Evaluate ``self.q_net`` on ``obs`` and return its output as a numpy array."""
     return ptu.to_numpy(self.q_net(obs))