class Game_Master:
    """Run a speaker/listener referential game over per-class image data.

    In each round the speaker sees a target image and emits one symbol out of
    ``vocab_size``. The listener receives one candidate image per class plus
    that symbol (one-hot encoded) and picks a class. Both agents are rewarded
    1.0 when the listener's pick equals the target label, and 0.0 otherwise.
    """

    def __init__(self, target_size, vocab_size, bias, test_data, seed, episodes=1000):
        """Build both agents and the bookkeeping used during training.

        Args:
            target_size: number of classes; also the number of candidate
                images shown to the listener each round.
            vocab_size: number of distinct symbols the speaker can emit.
            bias: bias passed to the Listener; the Speaker receives
                ``1.0 - bias``.
            test_data: indexed as ``test_data[split][class]`` (e.g. split
                "X_train" / "X_test"); presumably each entry is a 2-D array of
                samples for that class — confirm against caller.
            seed: forwarded to both the Speaker and Listener constructors.
            episodes: number of training episodes run by ``play``.
        """
        self.data = test_data
        self.vocab_size = vocab_size
        self.target_size = target_size
        self.speaker_bias = 1.0 - bias
        self.listener_bias = bias
        self.speaker = Speaker(512, self.target_size, self.vocab_size, self.speaker_bias, seed)
        self.listener = Listener(512, self.vocab_size, self.target_size, self.listener_bias, seed)
        self.episode_rewards = []
        # Per-episode accuracies for the last 100 episodes only (rolling mean).
        self.accuracy_history = deque(maxlen=100)
        self.symbols_used = {}
        # synonym_array[label, symbol] counts how often the speaker used
        # `symbol` for `label` during test().
        self.synonym_array = np.zeros((target_size, vocab_size))
        self.accuracy_record = []
        # Previously hard-coded to 1000, silently ignoring the argument.
        self.episodes = episodes

    def create_labels(self, num_labels, num_repeats):
        """Return a shuffled float array holding each of 0..num_labels-1 num_repeats times."""
        # float dtype matches the original np.concatenate-onto-np.array([]) result.
        arr = np.repeat(np.arange(num_labels), num_repeats).astype(float)
        np.random.shuffle(arr)
        return arr

    def sample(self, label, test_or_train="X_train"):
        """Draw one random example from every class of the given split.

        Args:
            label: target class index.
            test_or_train: key into ``self.data`` selecting the split.

        Returns:
            ``(target_image, image_list, target_one_hot)``. ``image_list``
            holds one example per class, in class order, each with a leading
            axis of length 1. ``target_image`` is the entry drawn for
            ``label``; it stays ``0`` if ``label`` is outside
            ``range(target_size)``. ``target_one_hot`` is ``label`` one-hot
            encoded over ``target_size`` classes.
        """
        split = self.data[test_or_train]
        target_one_hot = to_categorical(label, num_classes=self.target_size)
        target_image = 0
        image_list = []
        for cls in range(self.target_size):
            class_samples = split[cls]
            # Index with this class's own sample count; classes may differ in
            # size (the old code used class 0's size for every class).
            row = np.random.choice(class_samples.shape[0])
            sampled_image = np.expand_dims(class_samples[row, :], 0)
            if cls == label:
                target_image = sampled_image
            image_list.append(sampled_image)
        return target_image, image_list, target_one_hot

    def cross_entropy(self, predictions, targets, epsilon=1e-12):
        """Return the mean categorical cross-entropy over the rows of ``predictions``.

        ``predictions`` are clipped into ``[epsilon, 1 - epsilon]`` so ``log``
        never sees 0. The clip alone is the floor; no extra offset is added.
        """
        predictions = np.clip(predictions, epsilon, 1. - epsilon)
        n_rows = predictions.shape[0]
        return -np.sum(targets * np.log(predictions)) / n_rows

    def test(self):
        """Play one round per class on the "X_test" split and tally symbol usage.

        Each round increments ``synonym_array[label, symbol]`` for the symbol
        the speaker chose. Nothing is saved to the agents and no training
        happens here.
        """
        for label in self.create_labels(self.target_size, 1):
            target_image, image_list, _ = self.sample(label, "X_test")
            speaker_action, _ = self.speaker.act(target_image)
            speaker_action_class = np.expand_dims(
                to_categorical(speaker_action, num_classes=self.vocab_size), 0)
            image_list.append(speaker_action_class)
            # The listener still acts, mirroring a real round, but its choice is not scored.
            self.listener.act(image_list)
            self.synonym_array[int(label)][int(speaker_action)] += 1

    def play(self):
        """Train both agents for ``self.episodes`` episodes, then run ``test``.

        Each episode plays 100 rounds per class in shuffled order. Every
        transition is stored through the agents' ``save`` methods. The
        progress bar shows cross-entropy and accuracy, and each agent's
        ``train`` is called once per episode.
        """
        y = self.create_labels(self.target_size, 100)
        one_hot = to_categorical(y, num_classes=self.target_size)
        bar = tqdm(np.arange(self.episodes))
        for _ in bar:
            prob_rows = []
            right = 0
            for label in y:
                target_image, image_list, _ = self.sample(label)
                speaker_action, speaker_probs = self.speaker.act(target_image)
                speaker_action_class = np.expand_dims(
                    to_categorical(speaker_action, num_classes=self.vocab_size), 0)
                image_list.append(speaker_action_class)
                listener_action, listener_probs = self.listener.act(image_list)
                # Shared reward: both agents succeed or fail together.
                if listener_action == label:
                    reward = 1.0
                    right += 1
                else:
                    reward = 0.0
                prob_rows.append(listener_probs)
                self.speaker.save(target_image, speaker_action, speaker_probs, reward)
                self.listener.save(image_list, speaker_action_class, listener_action,
                                   listener_probs, reward)
            accuracy = right / y.shape[0]
            # Read before train(), in case train() resets the stored rewards.
            total = np.sum(self.speaker.params["rewards"])
            # Stack once per episode instead of re-copying on every round.
            cross = self.cross_entropy(np.vstack(prob_rows), one_hot)
            self.episode_rewards.append(total)
            self.accuracy_history.append(accuracy)
            self.accuracy_record.append(accuracy)
            bar.set_description("Cross Entropy: " + str(cross) + ", Rolling Acc 100: "
                                + str(np.mean(self.accuracy_history)) + ", Accuracy: "
                                + str(accuracy))
            self.speaker.train()
            self.listener.train()
        # NOTE(review): the source's indentation was lost; this assumes test()
        # runs once after training rather than every episode — confirm.
        self.test()