Пример #1
0
class Game_Master():
    """Run a speaker/listener referential game and track its statistics.

    Each round a target class label is drawn and one example per class is
    sampled. The speaker emits a symbol for the target example. The listener
    then picks a class from the candidate examples plus the speaker's one-hot
    symbol. Both agents receive reward 1.0 when the listener picks the target
    label and 0.0 otherwise.

    Args:
        target_size: number of candidate classes per round.
        vocab_size: number of symbols the speaker can emit.
        bias: listener bias; the speaker is given ``1.0 - bias``.
        test_data: mapping with "X_train" / "X_test" entries. Each entry is
            indexed by class id in ``range(target_size)`` and yields an array
            whose first axis holds that class's samples.
        seed: forwarded to both the Speaker and the Listener.
        episodes: number of training episodes run by ``play``.
    """

    def __init__(self, target_size, vocab_size, bias, test_data, seed, episodes=1000):
        self.data = test_data
        self.vocab_size = vocab_size
        self.target_size = target_size
        self.speaker_bias = 1.0 - bias
        self.listener_bias = bias
        # NOTE(review): 512 is the first constructor argument of both agents,
        # presumably a hidden/feature size -- confirm against Speaker/Listener.
        self.speaker = Speaker(512, self.target_size, self.vocab_size, self.speaker_bias, seed)
        self.listener = Listener(512, self.vocab_size, self.target_size, self.listener_bias, seed)
        self.episode_rewards = []
        # Last 100 episode accuracies, shown as the rolling accuracy in play().
        self.accuracy_history = deque(maxlen=100)
        # NOTE(review): not read or written elsewhere in this class.
        self.symbols_used = {}
        # synonym_array[label, symbol] counts how often the speaker emitted
        # `symbol` for `label` during test().
        self.synonym_array = np.zeros((target_size, vocab_size))
        self.accuracy_record = []
        # Honor the argument; it used to be hard-coded to 1000.
        self.episodes = episodes

    def create_labels(self, num_labels, num_repeats):
        """Return a shuffled float array holding each of range(num_labels) num_repeats times."""
        # One vectorized repeat instead of concatenating inside a loop; float
        # dtype keeps labels identical to what callers have always received.
        labels = np.repeat(np.arange(num_labels, dtype=float), num_repeats)
        np.random.shuffle(labels)
        return labels

    def sample(self, label, test_or_train="X_train"):
        """Draw one random example per class from the ``test_or_train`` split.

        Returns:
            ``(target_image, image_list, target_one_hot)``:
            - image_list holds one example per class, in class order, each
              with a leading axis of length 1 added.
            - target_image is the entry whose class index equals ``label``.
            - target_one_hot is ``to_categorical(label, target_size)``.
        """
        split = self.data[test_or_train]
        target_one_hot = to_categorical(label, num_classes=self.target_size)
        target_image = 0
        image_list = []
        for i in range(self.target_size):
            class_samples = split[i]
            # Use each class's own sample count; classes need not be equal size.
            idx = np.random.choice(class_samples.shape[0])
            sampled_image = np.expand_dims(class_samples[idx, :], 0)
            if i == label:
                target_image = sampled_image
            image_list.append(sampled_image)
        return target_image, image_list, target_one_hot

    def cross_entropy(self, predictions, targets, epsilon=1e-12):
        """Return sum(-targets * log(clip(predictions))) divided by the row count.

        Predictions are clipped to [epsilon, 1 - epsilon], so ``epsilon`` is the
        floor inside the log (a zero-probability target costs -log(epsilon)).
        """
        predictions = np.clip(predictions, epsilon, 1. - epsilon)
        n_rows = predictions.shape[0]
        return -np.sum(targets * np.log(predictions)) / n_rows

    def _play_round(self, label, split):
        """Play one speaker->listener exchange for ``label`` on data ``split``.

        Returns ``(target_image, image_list, speaker_action, speaker_probs,
        speaker_action_class, listener_action, listener_probs)``. image_list
        ends with the speaker's one-hot symbol, exactly as given to the listener.
        """
        target_image, image_list, _ = self.sample(label, split)
        speaker_action, speaker_probs = self.speaker.act(target_image)
        speaker_action_class = np.expand_dims(
            to_categorical(speaker_action, num_classes=self.vocab_size), 0)
        image_list.append(speaker_action_class)
        listener_action, listener_probs = self.listener.act(image_list)
        return (target_image, image_list, speaker_action, speaker_probs,
                speaker_action_class, listener_action, listener_probs)

    def test(self, num_repeats=1):
        """Play evaluation rounds on "X_test" without saving or training.

        Every class label is used ``num_repeats`` times, in shuffled order.
        Each round increments ``synonym_array[label, speaker_action]``.

        Returns:
            Fraction of rounds in which the listener chose the target label
            (0.0 when no rounds were played).
        """
        labels = self.create_labels(self.target_size, num_repeats)
        correct = 0
        for label in labels:
            _, _, speaker_action, _, _, listener_action, _ = self._play_round(label, "X_test")
            if listener_action == label:
                correct += 1
            self.synonym_array[int(label)][int(speaker_action)] += 1
        return correct / labels.shape[0] if labels.shape[0] else 0.0

    def play(self, num_repeats=100):
        """Train both agents for ``self.episodes`` episodes, then run ``test()``.

        Each episode plays every class label ``num_repeats`` times on
        "X_train" and stores each transition via ``speaker.save`` and
        ``listener.save``. It records accuracy and total speaker reward,
        reports cross-entropy and accuracy on the progress bar, and then calls
        ``speaker.train()`` and ``listener.train()``.
        """
        labels = self.create_labels(self.target_size, num_repeats)
        # Same width as the listener's per-class probabilities.
        one_hot = to_categorical(labels, num_classes=self.target_size)
        bar = tqdm(np.arange(self.episodes))
        for _ in bar:
            prob_rows = []
            right = 0
            for label in labels:
                (target_image, image_list, speaker_action, speaker_probs,
                 speaker_action_class, listener_action, listener_probs) = self._play_round(label, "X_train")
                if listener_action == label:
                    reward = 1.0
                    right += 1
                else:
                    reward = 0.0
                prob_rows.append(listener_probs)
                self.speaker.save(target_image, speaker_action, speaker_probs, reward)
                self.listener.save(image_list, speaker_action_class, listener_action, listener_probs, reward)
            accuracy = right / labels.shape[0]
            total = np.sum(self.speaker.params["rewards"])
            # Stack once per episode (not per round); float64 as before.
            cross = self.cross_entropy(np.vstack(prob_rows).astype(float), one_hot)
            self.episode_rewards.append(total)
            self.accuracy_history.append(accuracy)
            self.accuracy_record.append(accuracy)
            bar.set_description("Cross Entropy: " + str(cross) + ", Rolling Acc 100: " + str(np.mean(self.accuracy_history)) + ", Accuracy: " + str(accuracy))
            self.speaker.train()
            self.listener.train()
        self.test()