def augmentAndMix(x_orig, k, alpha, preprocess, augs=None):
    """Apply AugMix to ``x_orig``.

    ``k`` augmentation chains are built, each starting from the original input;
    their preprocessed results are mixed with Dirichlet(alpha) weights, and the
    mix is blended with the preprocessed original using a Beta(alpha, alpha)
    weight.

    Args:
        x_orig: input handed to the augmentation ops (presumably a PIL image —
            TODO confirm); it is never modified here, only passed to the ops.
        k: number of chains. It is also the number of ops drawn per chain and
            the maximum chain length.
        alpha: concentration for both the Dirichlet chain weights and the Beta
            skip-connection weight.
        preprocess: callable turning an image into a tensor.
        augs: pool of ``op(image, severity)`` callables. It must hold at least
            ``k`` ops. Defaults to the module-level ``augmentations``.

    Returns:
        ``m * mix + (1 - m) * preprocess(x_orig)`` with ``m ~ Beta(alpha, alpha)``.
        ``m`` has shape (1,) and broadcasts over the tensor.
    """
    if augs is None:
        augs = augmentations
    clean = preprocess(x_orig)
    x_aug = torch.zeros_like(clean)
    mixing_weights = Dirichlet(torch.empty(k).fill_(alpha)).sample()
    for i in range(k):
        # Every chain restarts from the original input; previously the chains
        # compounded on top of each other.
        x_temp = x_orig
        sampled_augs = random.sample(augs, k)
        aug_chain_length = random.choice(range(1, k + 1))
        for aug in sampled_augs[:aug_chain_length]:
            severity = random.choice(range(1, 6))
            x_temp = aug(x_temp, severity)
        x_aug += mixing_weights[i] * preprocess(x_temp)
    # float(): an integer alpha would give a LongTensor, which Beta cannot sample.
    alpha_t = torch.tensor([float(alpha)])
    skip_conn_weight = Beta(alpha_t, alpha_t).sample()
    return skip_conn_weight * x_aug + (1 - skip_conn_weight) * clean
def get_best_move(self, s, v, rm=None):
    """Choose a move from ``s.legal_moves`` using the values ``v[a]``.

    The legal moves' values are scaled by ``self.c.eval_move`` and then
    softmax-normalised. If ``self.add_noise`` is set, Dirichlet noise with
    concentration ``self.c.alpha_dir`` is mixed in with weight
    ``self.c.eps_dir``. The move with the highest resulting probability is
    returned.

    When ``rm`` is truthy, ``rm.proba`` receives three values:
    - the chosen move's probability;
    - the noise contribution to that probability (0.0 when no noise was added);
    - 1 if the noise-free argmax is the chosen move, else 0.
    """
    proba = torch.tensor([v[a] for a in s.legal_moves])  # pylint: disable=E
    proba = F.softmax(self.c.eval_move * proba, dim=0).numpy()
    i_max_no_noise = proba.argmax()
    noise = None
    if self.add_noise:
        dir_dist = Dirichlet(torch.zeros(len(s.legal_moves)) + self.c.alpha_dir)
        noise = dir_dist.sample().numpy()
        proba = (1 - self.c.eps_dir) * proba + self.c.eps_dir * noise
    # Best move
    i_max = proba.argmax()
    best_move = s.legal_moves[i_max]
    # For RunManager
    if rm:
        # Without noise the contribution is zero; previously `noise` was
        # unbound here and the call raised.
        noise_term = self.c.eps_dir * noise[i_max] if noise is not None else 0.0
        proba_dictate = int(i_max_no_noise == i_max)
        rm.proba(proba[i_max], noise_term, proba_dictate)
    return best_move
def get_best_move(self, s, v, rm=None):
    """Choose a move from ``s.legal_moves`` using the action-space scores ``v``.

    ``v`` is indexed with ``encoding.a_id(move)`` and passed to
    ``np.ravel_multi_index(..., v.shape)``. It is presumably a numpy array
    covering the whole action space (TODO confirm).

    The legal moves' scores are scaled by ``self.c.eval_move`` and
    softmax-normalised. Optional Dirichlet noise is mixed in with weight
    ``self.c.eps_dir``. The argmax move is returned.

    When ``rm`` is truthy, ``rm.proba`` receives six values:
    - ``v.max()``;
    - the chosen move's probability;
    - its noise contribution (0.0 without noise);
    - the max score over illegal actions (-inf if every action is legal);
    - 1/0 for "``v``'s global argmax is the chosen move";
    - 1/0 for "the noise-free argmax is the chosen move".
    """
    # Action-space index of each legal move, kept in s.legal_moves order.
    legal_ids = [tuple(encoding.a_id(a)) for a in s.legal_moves]
    legal_mask = np.zeros(v.shape, dtype=bool)
    for idx in legal_ids:
        legal_mask[idx] = True
    # Gather scores in s.legal_moves order so proba[i] belongs to
    # s.legal_moves[i]. Boolean-mask indexing (v[legal_mask]) would return them
    # in row-major order of v instead.
    proba = torch.from_numpy(np.array([v[idx] for idx in legal_ids], dtype=v.dtype))
    proba = F.softmax(self.c.eval_move * proba, dim=0).numpy()
    i_max_no_noise = proba.argmax()
    # Add noise if so.
    noise = None
    if self.add_noise:
        dir_dist = Dirichlet(torch.zeros(len(s.legal_moves)) + self.c.alpha_dir)
        noise = dir_dist.sample().numpy()
        proba = (1 - self.c.eps_dir) * proba + self.c.eps_dir * noise
    # Best move
    i_max = proba.argmax()
    best_move = s.legal_moves[i_max]
    # For RunManager
    if rm:
        best_move_code = np.ravel_multi_index(legal_ids[i_max], v.shape)
        v_dictate = int(v.argmax() == best_move_code)
        proba_dictate = int(i_max_no_noise == i_max)
        # Previously `noise` was unbound here when add_noise was False.
        noise_term = self.c.eps_dir * noise[i_max] if noise is not None else 0.0
        illegal_values = v[~legal_mask]
        max_illegal = illegal_values.max().item() if illegal_values.size else float("-inf")
        rm.proba(v.max(), proba[i_max], noise_term, max_illegal, v_dictate, proba_dictate)
    return best_move
def max_ucb_noise_node(self, n):
    """Return the child of ``n`` with the largest noisy UCB score.

    One Dirichlet(``self.alpha_dir``) vector is drawn per call. Its i-th entry
    is passed to ``self.ucb_noise`` together with the i-th child. Ties go to
    the earliest child.
    """
    children = n.children
    concentration = torch.zeros(len(children)) + self.alpha_dir
    noise = Dirichlet(concentration).sample().numpy()
    best_child, _ = max(
        zip(children, noise),
        key=lambda pair: self.ucb_noise(pair[0], pair[1]),
    )
    return best_child
def sample(self, labels, max_length, sos_id, scale=1):
    """Generate text conditioned on ``labels`` from the label-conditional prior.

    Steps:
    1. Embed the labels and add a leading dim.
    2. Sample a latent from ``self.z_prior``. When ``scale != 1``, its deviation
       from the prior mean is multiplied by ``scale``.
    3. Draw topic proportions from a Dirichlet whose concentration comes from
       ``self.topic_prior`` applied to [latent, label embedding].
    4. Decode with ``self.generate``.

    Sampling runs on CPU and the topics are moved back to the concentration's
    device.
    """
    label_emb = self.label_lookup(labels).unsqueeze(0)
    mu, logvar = self.z_prior(label_emb)
    latent = self.reparameterize(mu, logvar)
    if scale != 1:
        latent = mu + (latent - mu) * scale
    topic_alphas = self.topic_prior(torch.cat((latent, label_emb), dim=2))
    topics = Dirichlet(topic_alphas.cpu()).sample().to(topic_alphas.device)
    return self.generate(latent, topics, label_emb, max_length, sos_id)
def sample(self, num_samples, max_length, sos_id, device):
    """Randomly sample latent code to sample texts.

    The latent codes are standard-normal draws of shape
    ``(1, num_samples, self.fcmu.out_features)``. Topic proportions come from
    a Dirichlet whose concentration is ``self.topic_prior(latent)``; they are
    sampled on CPU and then moved to ``device``.

    Note that num_samples should not be too large.
    """
    latent_dim = self.fcmu.out_features
    latent = torch.randn(1, num_samples, latent_dim, device=device)
    topic_dist = Dirichlet(self.topic_prior(latent).cpu())
    topics = topic_dist.sample().to(device)
    return self.generate(latent, topics, max_length, sos_id)
class DirichletSkillTanhGaussianPolicy(SkillTanhGaussianPolicy):
    """Skill-conditioned policy whose skill vector is a Dirichlet sample.

    The active skill ``self.skill`` is drawn from ``self.skill_space``, a
    symmetric Dirichlet over ``skill_dim`` entries. It is appended to every
    observation in ``get_action``. ``alpha_update`` anneals the concentration
    from ``gamma`` (epoch 0) up to 1 (epoch >= tau). A new distribution only
    affects ``self.skill`` after the next ``skill_reset``.
    """

    def __init__(
            self,
            hidden_sizes,
            obs_dim,
            action_dim,
            std=None,
            init_w=1e-3,
            skill_dim=10,
            gamma=1e-3,
            **kwargs
    ):
        super().__init__(
            hidden_sizes=hidden_sizes,
            obs_dim=obs_dim,
            action_dim=action_dim,
            std=std,
            init_w=init_w,
            skill_dim=skill_dim,
            **kwargs
        )
        # Starting concentration of the annealing schedule in alpha_update.
        self.gamma = gamma
        self.skill_space = Dirichlet(torch.ones(self.skill_dim))
        self.skill = self.skill_space.sample().cpu().numpy()

    def get_action(self, obs_np, deterministic=False):
        """Return ``(action, {"skill": skill})`` for one observation."""
        # Append the current (Dirichlet-sampled) skill vector to the
        # observation. This assumes obs_np is 1-D (TODO confirm).
        obs_np = np.concatenate((obs_np, self.skill), axis=0)
        actions = self.get_actions(obs_np[None], deterministic=deterministic)
        return actions[0, :], {"skill": self.skill}

    def skill_reset(self):
        """Draw a fresh skill vector from the current skill distribution."""
        self.skill = self.skill_space.sample().cpu().numpy()

    def alpha_update(self, epoch, tau):
        """Set the concentration to ``min(gamma + (1 - gamma) * epoch / tau, 1)``."""
        concentration = min(self.gamma + (1 - self.gamma) * epoch / tau, 1)
        # Build the tensor directly; torch.tensor(<tensor>) copy-constructs and
        # emits a UserWarning.
        self.skill_space = Dirichlet(torch.full((self.skill_dim,), float(concentration)))

    def alpha_reset(self):
        """Restore the uniform Dirichlet(1, ..., 1) skill distribution."""
        self.skill_space = Dirichlet(torch.ones(self.skill_dim))
def plot_dir(alpha, size):
    """Scatter ``size`` Dirichlet(``alpha``) samples in 3-D with the simplex edges drawn.

    Only the first three coordinates of each sample are plotted, so ``alpha``
    is expected to have three entries.

    Returns:
        The matplotlib figure that was created.
    """
    # Dirichlet sampling needs a floating-point concentration;
    # torch.tensor([1, 1, 1]) alone would be an integer tensor.
    model = Dirichlet(torch.tensor(alpha, dtype=torch.float))
    sample = model.sample(torch.Size([size]))
    fig = plt.figure()
    ax = plt.axes(projection='3d')
    ax.scatter3D(sample[:, 0], sample[:, 1], sample[:, 2], color='red')
    # Edges of the triangle x + y + z = 1 in the positive octant.
    ax.plot([0, 0], [1, 0], [0, 1], linewidth=3, color='purple')
    ax.plot([0, 1], [0, 0], [1, 0], linewidth=3, color='purple')
    ax.plot([0, 1], [1, 0], [0, 0], linewidth=3, color='purple')
    ax.set_xlim((0, 1))
    ax.set_ylim((0, 1))
    ax.set_zlim((0, 1))
    ax.view_init(60, 35)
    return fig
def _sample_volume_alphas(self, n_related): if self.uniform_volumes: u = Uniform(0.25, 1.25) return u.sample().repeat(n_related) if isinstance(self.concentration, (float, int)): concentration = self.concentration else: concentration = self.concentration.rvs() dirichlet = Dirichlet( torch.tensor([concentration for _ in range(n_related)])) if self.random_seed is not None: torch.manual_seed(self.random_seed) return dirichlet.sample() * float(self.n_classes)
class AugMix(nn.Module):
    """AugMix augmentation plus the Jensen-Shannon consistency loss.

    Args:
        k: number of augmentation chains mixed per image.
        alpha: concentration of the Dirichlet chain weights and of the Beta
            skip weight.
        severity: severity passed to every augmentation op.
        augs: pool of ``op(image, severity)`` callables. Defaults to the
            module-level ``augmentations``.
    """

    def __init__(self, k=3, alpha=1, severity=3, augs=None):
        super(AugMix, self).__init__()
        self.k = k
        self.alpha = alpha
        self.severity = severity
        self.dirichlet = Dirichlet(torch.full(torch.Size([k]), alpha, dtype=torch.float32))
        self.beta = Beta(alpha, alpha)
        self.augs = augmentations if augs is None else augs
        self.kl = nn.KLDivLoss(reduction='batchmean')

    def augment_and_mix(self, images, preprocess):
        '''
        Args:
            images: PIL Image
            preprocess: transform[ToTensor, Normalize]
        Returns:
            AugmentAndMix Tensor

        Each of the ``k`` chains copies ``images`` and applies 1-3 randomly
        chosen ops at ``self.severity``. The chains are mixed with Dirichlet
        weights, and the mix is blended with the clean image as
        ``m * clean + (1 - m) * mix``, where ``m ~ Beta``.
        '''
        clean = preprocess(images)
        mix = torch.zeros_like(clean)
        w = self.dirichlet.sample()
        for i in range(self.k):
            aug = images.copy()
            depth = np.random.randint(1, 4)
            for _ in range(depth):
                op = np.random.choice(self.augs)
                # Use the configured severity; it was previously hard-coded to 3.
                aug = op(aug, self.severity)
            mix = mix + w[i] * preprocess(aug)
        m = self.beta.sample()
        return m * clean + (1 - m) * mix

    def jensen_shannon(self, logits_o, logits_1, logits_2):
        """Jensen-Shannon divergence between the softmaxes of three logit batches (class dim 1)."""
        p_o = F.softmax(logits_o, dim=1)
        p_1 = F.softmax(logits_1, dim=1)
        p_2 = F.softmax(logits_2, dim=1)
        # kl(q.log(), p) -> KL(p, q)
        M = torch.clamp((p_o + p_1 + p_2) / 3, 1e-7, 1)  # clamp so M.log() stays finite
        js = (self.kl(M.log(), p_o) + self.kl(M.log(), p_1) + self.kl(M.log(), p_2)) / 3
        return js
def one_hot_dirichlet(self, y):
    """Turn integer class labels into soft labels drawn from a peaked Dirichlet.

    Args:
        y: labels as a numpy array, torch tensor or sequence. Presumably 1-D of
            length N (TODO confirm); values are truncated to integers.

    Returns:
        Float tensor of shape ``(N, max(y) + 1)``. Row ``i`` is a Dirichlet
        sample with concentration 10100 at column ``y[i]`` and 100 elsewhere.
    """
    # as_tensor accepts tensors, numpy arrays and sequences; this replaces the
    # try/except-TypeError round trip through from_numpy.
    y = th.as_tensor(y)
    labels = y.long()
    n = y.size(0)
    num_classes = int(th.max(y).int()) + 1
    concentration = th.full((n, num_classes), 100.0)
    concentration[th.arange(n), labels] += 10000.0
    # One batched Dirichlet draws every row at once instead of a Python loop.
    return Dirichlet(concentration).sample()
def forward(self, inputs):
    """
    get embedding of support and query
    :param inputs:
        support: (N_way*N_shot)x3x84x84
        s_labels: (N_way*N_shot)xN_way, one-hot [25, 5]
        query: (N_way*N_query)x3x84x84
        q_labels: (N_way*N_query)xN_way, one-hot
    :return:
        emb_support : (1, Nc, 64*5*5)
        emb_query : (Nq, 1, 64*5*5)

    In training mode each class embedding is a random convex combination of
    that class's support embeddings, with weights drawn from Dirichlet(0.2)
    over the Ns shots. In eval mode it is the mean of the support embeddings.
    """
    ## process inputs
    self.num_classes, self.num_support, self.num_queries, self.concat_s_q = preprocess_input(inputs)
    self.q_labels = inputs[-1]
    ## encoding part
    emb = self.encoder(self.concat_s_q)
    # emb_s: (Nc*Ns, ...), emb_q: (Nc*Nq, ...)
    emb_s, emb_q = torch.split(
        emb, [self.num_classes * self.num_support, self.num_classes * self.num_queries], 0)
    emb_s = emb_s.view(self.num_classes, self.num_support, -1)  # (Nc, Ns, D)
    ## prototype part
    if self.training:
        alpha = 0.2
        # One weight vector over the Ns shots for each of the Nc classes,
        # placed on emb_s's device/dtype. This was previously hard-coded to
        # 5-way/5-shot and to .cuda(0).
        weights = Dirichlet(torch.full((self.num_support,), alpha)).sample(
            (self.num_classes,)).to(emb_s)  # (Nc, Ns)
        emb_s = torch.einsum('cs,csd->cd', weights, emb_s)  # (Nc, D)
    else:
        emb_s = emb_s.mean(1)  # (Nc, D)
    emb_q = emb_q.view(-1, np.prod(emb_q.shape[1:]))  # (Nq, D)
    assert emb_s.shape[-1] == emb_q.shape[-1], 'the dimension of embeddings must be equal'
    emb_s = torch.unsqueeze(emb_s, 0)  # 1xNxD, (1, Nc, D)
    emb_q = torch.unsqueeze(emb_q, 1)  # Nx1xD, (Nq, 1, D)
    return emb_s, emb_q
class DirichletActionSelector(object): def __init__(self, INITIAL_EPSILON, FINAL_EPSILON, policy_net, EPS_DECAY, n_actions, lamb, device): self._eps = INITIAL_EPSILON self._FINAL_EPSILON = FINAL_EPSILON self._INITIAL_EPSILON = INITIAL_EPSILON self._policy_net = policy_net self._EPS_DECAY = EPS_DECAY self._n_actions = n_actions self._device = device distn_params = [ 1 / lamb for _ in range(policy_net.get_num_ensembles()) ] self.distn = Dirichlet(torch.tensor(distn_params)) def select_action(self, state, training=True): sample = random.random() if training: self._eps -= (self._INITIAL_EPSILON - self._FINAL_EPSILON) / self._EPS_DECAY self._eps = max(self._eps, self._FINAL_EPSILON) if sample > self._eps: with torch.no_grad(): if training: alpha = self.distn.sample() q_val = 0 for i in range(self._policy_net.get_num_ensembles()): q_val += alpha[i] * \ self._policy_net(state.to(self._device), ens_num=i) a = q_val.max(1)[1].cpu().view(1, 1) else: a = self._policy_net(state.to( self._device)).max(1)[1].cpu().view(1, 1) else: a = torch.tensor([[random.randrange(self._n_actions)]], device='cpu', dtype=torch.long) return a.numpy()[0, 0].item(), self._eps
def reconstruct(self, inputs, topics, lengths, max_length, sos_id, fix_z=False, fix_t=True):
    """Encode ``inputs`` and decode a reconstruction.

    Args:
        inputs: token inputs passed to ``self.lookup``.
        topics: topic proportions. A leading dim is added internally; the
            caller's tensor is left unchanged.
        lengths: sequence lengths passed to ``self.encoder``.
        max_length, sos_id: forwarded to ``self.generate``.
        fix_z: decode from the posterior mean instead of a reparameterised sample.
        fix_t: keep the given topics. If False, new topics are drawn from
            Dirichlet(``self.topic_prior(z)``). This is forced to True for joint
            models (``self.is_joint``), whose encoder also consumes the topics.
    """
    enc_emb = self.lookup(inputs)
    # Out-of-place unsqueeze; unsqueeze_ mutated the caller's tensor and added
    # one more leading dim on every call.
    topics = topics.unsqueeze(0)
    hn, _ = self.encoder(enc_emb, lengths)
    if self.is_joint:
        fix_t = True
        hn = torch.cat([hn, topics], dim=2)
    mu, logvar = self.fcmu(hn), self.fclogvar(hn)
    z = mu if fix_z else self.reparameterize(mu, logvar)
    if not fix_t:
        alphas = self.topic_prior(z)
        topics = Dirichlet(alphas.cpu()).sample().to(z.device)
    return self.generate(z, topics, max_length, sos_id)
class ColoredGraphGamePlayer:
    """Plays the graph-coloring game with an MCTS model.

    Holds one ``ColoredGraphGame`` and one ``MCTS`` model:
    - ``play_game`` runs a self-play episode and returns its training data;
    - ``play_arena`` measures a network's average final condition;
    - ``set_best_nnet`` keeps the better of the current network and a previous one.

    Colors are 1-based. Color 0 is emitted when the model has no ``pi``, and
    ``check_resign`` counts those turns.
    """

    def __init__(self, graph_location: "str", color_count: "int"):
        self.episode = 0
        self.episode_cond = None
        self._color_count = color_count
        self._game = ColoredGraphGame(graph_location=graph_location)
        self._model = MCTS()
        self._model.initiate_sample(self._game.graph())
        # Symmetric Dirichlet exploration noise over the colors. The
        # concentration 10 / (colors * vertices) shrinks on larger problems.
        _tmp_dirichlet = 10 / (
            (self._color_count) * self._game._colored_graph.vertex_count)
        self._dirichlet = Dirichlet(
            torch.tensor([_tmp_dirichlet for _ in range(self._color_count)]))
        self._resign_treshold = float("inf")
        self._count_uncolored_vertices = 0

    def set_resign_treshold(self, treshold_: "int"):
        """Resign once more than ``treshold_`` turns have produced color 0."""
        self._resign_treshold = treshold_

    def return_board(self) -> "1d Int Numpy Array":
        """Return the game board."""
        return self._game.board()

    def reinforce_nnet(
            self, data_set: "[[[Int List], [float List], float], ...] List",
            learning_rate_: "float", epoch_: "int", batch_size_: "int",
            l2_coef: "float"):
        """Train the model's network on ``data_set``."""
        self._model.reinforce(data_set, learning_rate_, epoch_, batch_size_,
                              l2_coef)

    def calculate_dirichlet_probability(self,
                                        pi_: "float list") -> "float list":
        """Return ``0.75 * pi_ + 0.25 * noise`` as a float64 numpy array.

        ``noise`` is a Dirichlet sample renormalised in float64. The result
        therefore sums to 1 within ``np.random.choice``'s tolerance whenever
        ``pi_`` does. This replaces the earlier 3-decimal rounding and
        exact-equality fixup.
        """
        noise = self._dirichlet.sample().double().numpy()
        noise /= noise.sum()
        pi_arr = np.asarray(pi_, dtype=np.float64)
        return 0.750 * pi_arr + 0.250 * noise[:len(pi_arr)]

    def print_episode_conditions(self):
        """Print the current episode number and its final condition."""
        print("Episode: ", self.episode, ", Condition is: ", self.episode_cond)

    def set_nnet(self, nnet_: "NeuralNet"):
        """Replace the model's network."""
        self._model.set_nnet(nnet_)

    def return_nnet(self) -> "NeuralNet":
        """Return the model's network."""
        return self._model.return_nnet()

    def finalize_pi_color(self, pi_: "float List", turn_: "int",
                          turn_threshold_: "int") -> "float list, int":
        """Turn the model's ``pi`` into a training target and pick a color.

        - ``pi_ is False``: all-zero target (length ``color_count``), color 0.
        - ``turn_ < vertex_count * turn_threshold_``: ``pi_`` is mixed with
          Dirichlet noise and the color is sampled from it; the mixed vector
          is the target.
        - Otherwise: one-hot target at ``argmax(pi_)``, with the matching color.
        """
        if pi_ is False:
            return np.array([0 for _ in range(self._color_count)]), 0
        explore = turn_ < self._game._colored_graph.vertex_count * turn_threshold_
        if explore:
            pi_ = self.calculate_dirichlet_probability(pi_)
            color = np.random.choice(len(pi_), p=pi_) + 1
            return pi_, color
        pi = [0 for _ in range(len(pi_))]
        max_index = np.argmax(pi_)
        pi[max_index] = 1
        return pi, max_index + 1

    def play_arena(self, episode_count_: "int", simulation_count: "int",
                   simulation_coef: "float", c: "int",
                   timer_: "GeneralTimer") -> "float":
        """Play greedy episodes and return the average final condition.

        Episode ``e`` runs ``int(simulation_count + e * simulation_coef)``
        simulations per turn. If an episode ends with condition 1.0, a message
        is printed and 1.0 is returned at once. The solved game is not reset,
        so its board stays available. Other episodes are reset after they
        finish.
        """
        conditions_ = []
        for episode in range(episode_count_):
            # int(): simulation_coef may be fractional and range() needs an int.
            sim_budget = int(simulation_count + episode * simulation_coef)
            while not self._game.eog():
                for _ in range(sim_budget):
                    self._model.simulate(c, timer_)
                    self._model.synchronize_sample(self._game.graph())
                # Predict the action.
                timer_.simulation.start_time()
                self._model.calculate_legal_pi(self._game.graph(),
                                               self._color_count + 1)
                if self._model.pi is not False:
                    color = np.argmax(self._model.pi) + 1
                else:
                    color = 0
                self._game.play_turn(color)
                self._model.sync_sample_update_tree(self._game.graph(), color)
                timer_.simulation.stop_time()
            condition = self._game.get_condition()
            if condition == float(1):
                print("Solution is found on the arena, Simulation count was",
                      sim_budget, end=". ")
                return condition
            conditions_.append(condition)
            self.reset_tree_game()
        return (sum(conditions_) / float(len(conditions_)))

    def play_game(
            self, simulation_count: "int", turn_threshold_: "int", c: "int",
            derive_data: "bool", timer_: "GeneralTimer"
    ) -> "2d [[Int], [float], float] List or bool":
        """Play one self-play episode.

        Returns the finalised episodic data, or False if the player resigned
        (see ``check_resign``).
        """
        self._count_uncolored_vertices = 0
        episodic_data = ExpandableEpisodicDataSet(self.episode,
                                                  self._color_count)
        while not self._game.eog():
            for _ in range(simulation_count):
                self._model.simulate(c, timer_)
                self._model.synchronize_sample(self._game.graph())
            # Predict the action.
            timer_.simulation.start_time()
            self._model.calculate_legal_pi(self._game.graph(),
                                           self._color_count + 1)
            pi, color = self.finalize_pi_color(self._model.pi, self._game.turn,
                                               turn_threshold_)
            if self.check_resign(color):
                timer_.simulation.stop_time()
                return False
            episodic_data.insert(self._game.board(), pi)
            self._game.play_turn(color)
            self._model.sync_sample_update_tree(self._game.graph(), color)
            timer_.simulation.stop_time()
        episodic_data.set_reward(self._game.graph().condition())
        self.episode_cond = self._game.graph().condition()
        if derive_data:
            episodic_data.derive()
        self.episode += 1
        return episodic_data.finalize()

    def check_resign(self, color_: "int") -> "bool":
        """Count color-0 turns; return True once the count exceeds the resign threshold."""
        if color_ == 0:
            self._count_uncolored_vertices += 1
        return self._count_uncolored_vertices > self._resign_treshold

    def reset_tree_game(self):
        """Reset the game and the model's synchronised tree (end of an episode)."""
        self._game.reset_game()
        self._model.reset_sync_tree()

    def set_best_nnet(self, old_nnet_: "NeuralNet", episode_count_: "int",
                      simulation_count_: "int", simulation_coef_: "int",
                      c_: "int", timer_: "GeneralTimer") -> "bool":
        """Keep whichever of the current and ``old_nnet_`` scores better in the arena.

        Returns True as soon as either network solves the graph (and leaves
        that network in place), otherwise False. Called at the end of one
        iteration (episode set).
        """
        self.episode = 0
        self.episode_cond = 0
        new_nnet_condition = self.play_arena(episode_count_, simulation_count_,
                                             simulation_coef_, c_, timer_)
        if new_nnet_condition == float(1):
            print("Trained Network")
            return True
        new_nnet_ = self.return_nnet()
        self.set_nnet(old_nnet_)
        old_nnet_condition = self.play_arena(episode_count_, simulation_count_,
                                             simulation_coef_, c_, timer_)
        if old_nnet_condition == float(1):
            print("Old Network")
            return True
        print("Trained Network's average condition is: ", new_nnet_condition,
              "on", episode_count_, "games.")
        print("Old Network's average condition is: ", old_nnet_condition, "on",
              episode_count_, "games.")
        if new_nnet_condition >= old_nnet_condition:
            self.set_nnet(new_nnet_)
            print("Trained Network is Selected!")
        else:
            print("Old Network is selected.")
        print("=" * 99)
        return False
def one_hot_dirichlet(x: Tensor, num_classes=-1):
    """Sample soft labels peaked on the classes in ``x``.

    Each label becomes a Dirichlet draw with concentration 10100 on its own
    class and 100 on every other class. ``num_classes`` is forwarded to
    ``one_hot``; -1 presumably means "infer from x" (TODO confirm).
    """
    concentration = one_hot(x, num_classes=num_classes).float().mul(10000).add(100)
    soft_labels = Dirichlet(concentration)
    return soft_labels.sample()