Example no. 1
def augmentAndMix(x_orig, k, alpha, preprocess):
    # k : number of chains
    # alpha : sampling constant

    x_aug = torch.zeros_like(preprocess(x_orig))
    mixing_weight_dist = Dirichlet(torch.empty(k).fill_(alpha))
    mixing_weights = mixing_weight_dist.sample()

    for i in range(k):
        x_temp = x_orig  # each augmentation chain starts from the original image
        sampled_augs = random.sample(augmentations, k)
        aug_chain_length = random.choice(range(1, k + 1))
        aug_chain = sampled_augs[:aug_chain_length]

        for aug in aug_chain:
            severity = random.choice(range(1, 6))
            x_temp = aug(x_temp, severity)

        x_aug += mixing_weights[i] * preprocess(x_temp)

    skip_conn_weight_dist = Beta(torch.tensor([alpha]), torch.tensor([alpha]))
    skip_conn_weight = skip_conn_weight_dist.sample()

    x_augmix = skip_conn_weight * x_aug + (
        1 - skip_conn_weight) * preprocess(x_orig)

    return x_augmix
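A minimal sketch of the two weight samples this function relies on (assuming only torch; the `augmentations` list and `preprocess` transform above are external to this snippet): the Dirichlet draw gives k convex chain weights and the Beta draw gives the skip-connection weight.

import torch
from torch.distributions import Dirichlet, Beta

k, alpha = 3, 1.0
mixing_weights = Dirichlet(torch.full((k,), alpha)).sample()
skip_conn_weight = Beta(torch.tensor([alpha]), torch.tensor([alpha])).sample()
# The chain weights lie on the probability simplex: non-negative and summing to 1.
assert torch.isclose(mixing_weights.sum(), torch.tensor(1.0))
print(mixing_weights, skip_conn_weight.item())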
Example no. 2
    def get_best_move(self, s, v, rm=None):

        proba = torch.tensor([v[a] for a in s.legal_moves]) # pylint: disable=E
        proba = F.softmax(self.c.eval_move * proba, dim=0).numpy()
        i_max_no_noise = proba.argmax()

        if self.add_noise:
            dir_dist = Dirichlet(torch.zeros(len(s.legal_moves)) + self.c.alpha_dir)
            noise = dir_dist.sample().numpy()
            proba = (1 - self.c.eps_dir) * proba + self.c.eps_dir * noise
        
        # Best move
        i_max = proba.argmax()
        best_move = s.legal_moves[i_max]
        
        # For RunManager
        if rm:
            proba_dictate = int(i_max_no_noise == i_max)
            rm.proba(
                proba[i_max],
                self.c.eps_dir * noise[i_max] if self.add_noise else 0.0,
                proba_dictate
            )
        
        return best_move
Example no. 3
    def get_best_move(self, s, v, rm=None):

        # Build a boolean mask of the legal moves over the tensor v.
        legal_mask = torch.zeros(v.shape, dtype=torch.bool)
        for a in s.legal_moves:
            legal_mask[encoding.a_id(a)] = True

        # Compute the probability of each legal move.
        proba = torch.from_numpy(v[legal_mask])
        proba = F.softmax(self.c.eval_move * proba, dim=0).numpy()
        i_max_no_noise = proba.argmax()

        # Add Dirichlet exploration noise if enabled.
        if self.add_noise:
            dir_dist = Dirichlet(
                torch.zeros(len(s.legal_moves)) + self.c.alpha_dir)
            noise = dir_dist.sample().numpy()
            proba = (1 - self.c.eps_dir) * proba + self.c.eps_dir * noise

        # Best move
        i_max = proba.argmax()
        best_move = s.legal_moves[i_max]

        # For RunManager
        if rm:
            best_move_code = np.ravel_multi_index(encoding.a_id(best_move),
                                                  v.shape)
            v_dictate = int(v.argmax() == best_move_code)
            proba_dictate = int(i_max_no_noise == i_max)
            rm.proba(v.max(), proba[i_max],
                     self.c.eps_dir * noise[i_max] if self.add_noise else 0.0,
                     v[legal_mask.logical_not()].max().item(), v_dictate,
                     proba_dictate)

        return best_move
Example no. 4
 def max_ucb_noise_node(self, n):
     Nc = len(n.children)
     dir_dist = Dirichlet(torch.zeros(Nc) + self.alpha_dir)
     noises = dir_dist.sample().numpy()
     i_max = max(range(Nc),
                 key=lambda i: self.ucb_noise(n.children[i], noises[i]))
     return n.children[i_max]
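Examples 2-4 all use the same root-noise pattern used in AlphaZero-style MCTS: mix a Dirichlet sample into the move priors so exploration never fully ignores any legal move. A standalone sketch (the concentration, epsilon and priors below are illustrative):

import torch
from torch.distributions import Dirichlet

alpha_dir, eps_dir = 0.3, 0.25
priors = torch.tensor([0.1, 0.4, 0.3, 0.2])
noise = Dirichlet(torch.full_like(priors, alpha_dir)).sample()
noisy_priors = (1 - eps_dir) * priors + eps_dir * noise  # still a probability vector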
Example no. 5
 def sample(self, labels, max_length, sos_id, scale=1):
     lab_emb = self.label_lookup(labels).unsqueeze(0)
     mu, logvar = self.z_prior(lab_emb)
     z = self.reparameterize(mu, logvar)
     if scale != 1:
         z = mu + (z - mu) * scale
     alphas = self.topic_prior(torch.cat([z, lab_emb], dim=2))
     dist = Dirichlet(alphas.cpu())
     topics = dist.sample().to(alphas.device)
     return self.generate(z, topics, lab_emb, max_length, sos_id)
Example no. 6
    def sample(self, num_samples, max_length, sos_id, device):
        """Randomly sample latent code to sample texts. 
        Note that num_samples should not be too large. 

        """
        z_size = self.fcmu.out_features
        z = torch.randn(1, num_samples, z_size, device=device)
        alphas = self.topic_prior(z)
        dist = Dirichlet(alphas.cpu())
        topics = dist.sample().to(device)
        return self.generate(z, topics, max_length, sos_id)
Example no. 7
class DirichletSkillTanhGaussianPolicy(SkillTanhGaussianPolicy):
    def __init__(
            self,
            hidden_sizes,
            obs_dim,
            action_dim,
            std=None,
            init_w=1e-3,
            skill_dim=10,
            gamma=1e-3,
            **kwargs
    ):
        super().__init__(
            hidden_sizes=hidden_sizes,
            obs_dim=obs_dim,
            action_dim=action_dim,
            std=std,
            init_w=init_w,
            skill_dim=skill_dim,
            **kwargs
        )
        self.gamma = gamma
        self.skill_space = Dirichlet(torch.ones(self.skill_dim))
        self.skill = self.skill_space.sample().cpu().numpy()

    def get_action(self, obs_np, deterministic=False):
        # concatenate the current Dirichlet-sampled skill vector to the observation
        # (used when acting online during rollouts)
        obs_np = np.concatenate((obs_np, self.skill), axis=0)
        actions = self.get_actions(obs_np[None], deterministic=deterministic)
        return actions[0, :], {"skill": self.skill}

    def skill_reset(self):
        self.skill = self.skill_space.sample().cpu().numpy()

    def alpha_update(self, epoch, tau):
        # Anneal the Dirichlet concentration from gamma toward 1 over tau epochs.
        d_alpha = min(self.gamma + (1 - self.gamma) * epoch / tau, 1) * torch.ones(self.skill_dim)
        self.skill_space = Dirichlet(d_alpha)

    def alpha_reset(self):
        self.skill_space = Dirichlet(torch.ones(self.skill_dim))
Example no. 8
def plot_dir(alpha, size):
    model = Dirichlet(torch.tensor(alpha))
    sample = model.sample(torch.Size([size])).data
    fig = plt.figure()
    ax = plt.axes(projection='3d')
    ax.scatter3D(sample[:, 0], sample[:, 1], sample[:, 2], color='red')
    ax.plot([0, 0], [1, 0], [0, 1], linewidth=3, color='purple')
    ax.plot([0, 1], [0, 0], [1, 0], linewidth=3, color='purple')
    ax.plot([0, 1], [1, 0], [0, 0], linewidth=3, color='purple')
    ax.set_xlim((0, 1))
    ax.set_ylim((0, 1))
    ax.set_zlim((0, 1))
    ax.view_init(60, 35)
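A possible way to invoke this helper (assuming `matplotlib.pyplot` is imported as `plt` and Dirichlet as above; the concentration values are illustrative):

# Scatter 500 draws from a symmetric Dirichlet over the 3-simplex.
plot_dir([2.0, 2.0, 2.0], 500)
plt.show()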
Example no. 9
 def _sample_volume_alphas(self, n_related):
     if self.uniform_volumes:
         u = Uniform(0.25, 1.25)
         return u.sample().repeat(n_related)
     if isinstance(self.concentration, (float, int)):
         concentration = self.concentration
     else:
         concentration = self.concentration.rvs()
     dirichlet = Dirichlet(
         torch.tensor([concentration for _ in range(n_related)]))
     if self.random_seed is not None:
         torch.manual_seed(self.random_seed)
     return dirichlet.sample() * float(self.n_classes)
Example no. 10
class AugMix(nn.Module):
    def __init__(self, k=3, alpha=1, severity=3):
        super(AugMix, self).__init__()
        self.k = k
        self.alpha = alpha
        self.severity = severity
        self.dirichlet = Dirichlet(torch.full(torch.Size([k]), alpha, dtype=torch.float32))
        self.beta = Beta(alpha, alpha)
        self.augs = augmentations
        self.kl = nn.KLDivLoss(reduction='batchmean')

    def augment_and_mix(self, images, preprocess):
        '''
        Args:
            images: PIL Image
            preprocess: transform[ToTensor, Normalize]

        Returns: AugmentAndMix Tensor
        '''
        mix = torch.zeros_like(preprocess(images))
        w = self.dirichlet.sample()
        for i in range(self.k):
            aug = images.copy()
            depth = np.random.randint(1, 4)
            for _ in range(depth):
                op = np.random.choice(self.augs)
                aug = op(aug, self.severity)
            mix = mix + w[i] * preprocess(aug)

        m = self.beta.sample()

        augmix = m * preprocess(images) + (1 - m) * mix

        return augmix

    def jensen_shannon(self, logits_o, logits_1, logits_2):
        p_o = F.softmax(logits_o, dim=1)
        p_1 = F.softmax(logits_1, dim=1)
        p_2 = F.softmax(logits_2, dim=1)

        # kl(q.log(), p) -> KL(p, q)
        M = torch.clamp((p_o + p_1 + p_2) / 3, 1e-7, 1)  # to avoid exploding
        js = (self.kl(M.log(), p_o) + self.kl(M.log(), p_1) + self.kl(M.log(), p_2)) / 3
        return js
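A rough usage sketch of the class above (hypothetical: it assumes the module-level `augmentations` list exists, that `img` is a PIL image, and that `model` maps a batch of tensors to logits; the normalization constants are placeholders):

from torchvision import transforms

preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
augmix = AugMix(k=3, alpha=1)

x_clean = preprocess(img)
x_aug1 = augmix.augment_and_mix(img, preprocess)
x_aug2 = augmix.augment_and_mix(img, preprocess)

# The three views feed the Jensen-Shannon consistency term during training.
js_loss = augmix.jensen_shannon(model(x_clean[None]),
                                model(x_aug1[None]),
                                model(x_aug2[None]))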
Example no. 11
    def one_hot_dirichlet(self, y):

        # y may arrive as a numpy array or already as a tensor.
        try:
            y = th.from_numpy(y)
        except TypeError:
            pass

        y_1d = y
        y_hot = th.zeros((y.size(0), th.max(y).int()+1))
        y_dir = th.zeros((y.size(0), th.max(y).int()+1))

        for i in range(y.size(0)):
            y_hot[i, y_1d[i].int()] = 10000

        for i in range(y.size(0)):
            m = Dirichlet(y_hot[i] + 100)
            y_dir[i] = m.sample()

        return y_dir
Example no. 12
    def forward(self, inputs):
        """
        get embedding of support and query
        :param
            inputs:
                support:    (N_way*N_shot)x3x84x84
                s_labels:   (N_way*N_shot)xN_way, one-hot [25, 5]
                query:      (N_way*N_query)x3x84x84
                q_labels:   (N_way*N_query)xN_way, one-hot
        :return:
            emb_support : (1, Nc, 64*5*5)
            emb_query : (Nq, 1, 64*5*5)
        """
        ## process inputs
        self.num_classes, self.num_support, self.num_queries, self.concat_s_q = preprocess_input(inputs)
        self.q_labels = inputs[-1]

        ## encoding part
        emb   = self.encoder(self.concat_s_q) # emb shape:(100*64*5*5)
        # emb_s:(Ns*64*5*5), emb_q:(Nq*64*5*5)
        emb_s, emb_q = torch.split(emb, [self.num_classes*self.num_support, self.num_classes*self.num_queries], 0)

        ## prototype part
        alpha = 0.2
        m = Dirichlet(torch.tensor([alpha]*5))
        convexhull_weights = m.sample((emb_s.size(0),)).cuda(0)

        if self.training:
            # convex combination
            emb_s = emb_s.view(self.num_classes, self.num_support, np.prod(emb_s.shape[1:]))  # (5, 5, 64*5*5), (Nc*Ns*s_dim)
            emb_s = torch.cat([torch.matmul(emb_s[i].transpose(0,1), convexhull_weights[i].view(5,1)) for i in range(5)], dim=1).transpose(0,1)    # (5, 64*5*5), (Nc*s_dim)
        else:    
            # prototype
            emb_s = emb_s.view(self.num_classes, self.num_support, np.prod(emb_s.shape[1:])).mean(1)  # (5, 64*5*5), (Nc*s_dim)
        
        emb_q = emb_q.view(-1, np.prod(emb_q.shape[1:]))    # (Nq, 64*5*5)
        assert emb_s.shape[-1] == emb_q.shape[-1], 'the dimension of embeddings must be equal'
        emb_s = torch.unsqueeze(emb_s, 0)     # 1xNxD, (1, Nc, 64*5*5)
        emb_q = torch.unsqueeze(emb_q, 1)     # Nx1xD, (Nq, 1, 64*5*5)

        return emb_s, emb_q
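During training, the block above replaces each class prototype (the mean of its support embeddings) with a random point in the convex hull of those embeddings, using Dirichlet weights. A minimal sketch for a single class (the shapes and the concentration 0.2 mirror the snippet; the embeddings are random stand-ins):

import torch
from torch.distributions import Dirichlet

n_shot, emb_dim = 5, 64 * 5 * 5
support = torch.randn(n_shot, emb_dim)               # support embeddings of one class
w = Dirichlet(torch.full((n_shot,), 0.2)).sample()   # convex weights summing to 1
prototype = w @ support                              # a random point in the convex hull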
Example no. 13
class DirichletActionSelector(object):
    def __init__(self, INITIAL_EPSILON, FINAL_EPSILON, policy_net, EPS_DECAY,
                 n_actions, lamb, device):
        self._eps = INITIAL_EPSILON
        self._FINAL_EPSILON = FINAL_EPSILON
        self._INITIAL_EPSILON = INITIAL_EPSILON
        self._policy_net = policy_net
        self._EPS_DECAY = EPS_DECAY
        self._n_actions = n_actions
        self._device = device
        distn_params = [
            1 / lamb for _ in range(policy_net.get_num_ensembles())
        ]
        self.distn = Dirichlet(torch.tensor(distn_params))

    def select_action(self, state, training=True):
        sample = random.random()
        if training:
            self._eps -= (self._INITIAL_EPSILON -
                          self._FINAL_EPSILON) / self._EPS_DECAY
            self._eps = max(self._eps, self._FINAL_EPSILON)
        if sample > self._eps:
            with torch.no_grad():
                if training:
                    alpha = self.distn.sample()
                    q_val = 0
                    for i in range(self._policy_net.get_num_ensembles()):
                        q_val += alpha[i] * \
                            self._policy_net(state.to(self._device), ens_num=i)
                    a = q_val.max(1)[1].cpu().view(1, 1)
                else:
                    a = self._policy_net(state.to(
                        self._device)).max(1)[1].cpu().view(1, 1)
        else:
            a = torch.tensor([[random.randrange(self._n_actions)]],
                             device='cpu',
                             dtype=torch.long)

        return a.numpy()[0, 0].item(), self._eps
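A condensed sketch of the training branch above: a Dirichlet draw provides random convex weights over the ensemble heads, and the action is greedy with respect to the mixed Q-values (the ensemble size, lamb and action count are illustrative, and the Q-values are random stand-ins for self._policy_net outputs):

import torch
from torch.distributions import Dirichlet

num_ensembles, lamb, n_actions = 4, 0.5, 6
alpha = Dirichlet(torch.full((num_ensembles,), 1 / lamb)).sample()
q_values = torch.randn(num_ensembles, 1, n_actions)   # stand-in per-head Q(s, .) batches
q_mix = (alpha.view(-1, 1, 1) * q_values).sum(dim=0)  # random convex combination of heads
action = q_mix.max(1)[1]                               # greedy action w.r.t. the mixture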
Example no. 14
 def reconstruct(self,
                 inputs,
                 topics,
                 lengths,
                 max_length,
                 sos_id,
                 fix_z=False,
                 fix_t=True):
     enc_emb = self.lookup(inputs)
     topics.unsqueeze_(0)
     hn, _ = self.encoder(enc_emb, lengths)
     if self.is_joint:
         fix_t = True
         hn = torch.cat([hn, topics], dim=2)
     mu, logvar = self.fcmu(hn), self.fclogvar(hn)
     if fix_z:
         z = mu
     else:
         z = self.reparameterize(mu, logvar)
     if not fix_t:
         alphas = self.topic_prior(z)
         dist = Dirichlet(alphas.cpu())
         topics = dist.sample().to(z.device)
     return self.generate(z, topics, max_length, sos_id)
Example no. 15
class ColoredGraphGamePlayer:
    def __init__(self, graph_location: "str", color_count: "int"):
        self.episode = 0
        self.episode_cond = None
        self._color_count = color_count
        self._game = ColoredGraphGame(graph_location=graph_location)

        self._model = MCTS()
        self._model.initiate_sample(self._game.graph())

        _tmp_dirichlet = 10 / (
            (self._color_count) * self._game._colored_graph.vertex_count)
        self._dirichlet = Dirichlet(
            torch.tensor([_tmp_dirichlet for _ in range(self._color_count)]))

        self._resign_treshold = float("inf")
        self._count_uncolored_vertices = 0

    # Sets its "_resign_treshold" parameter.
    def set_resign_treshold(self, treshold_: "int"):
        self._resign_treshold = treshold_

    # Returns the game board.
    def return_board(self) -> "1d Int Numpy Array":
        return self._game.board()

    # Train the nnet.
    def reinforce_nnet(
            self, data_set: "[[[Int List], [float List], float], ...] List",
            learning_rate_: "float", epoch_: "int", batch_size_: "int",
            l2_coef: "float"):
        self._model.reinforce(data_set, learning_rate_, epoch_, batch_size_,
                              l2_coef)

    # Mixes Dirichlet noise into the policy vector pi.
    def calculate_dirichlet_probability(self,
                                        pi_: "float list") -> "float list":
        dirProb = np.float64(
            [round(i.item(), 3) for i in self._dirichlet.sample()])
        if sum(dirProb) != 1:
            indexMax = np.argmax(dirProb)
            dirProb[indexMax] = dirProb[indexMax] + 1.0 - sum(dirProb)
        return np.array(
            [0.750 * pi_[i] + 0.250 * dirProb[i] for i in range(len(pi_))])

    # Prints episode's condition.
    def print_episode_conditions(self):
        print("Episode: ", self.episode, ", Condition is: ", self.episode_cond)

    # Change nnet of the model.
    def set_nnet(self, nnet_: "NeuralNet"):
        self._model.set_nnet(nnet_)

    # Return nnet of the model.
    def return_nnet(self) -> "NeuralNet":
        return self._model.return_nnet()

    # Changes the pi value to a more trainable pi and returns color.
    def finalize_pi_color(self, pi_: "float List", turn_: "int",
                          turn_threshold_: "int") -> "float list, int":
        if pi_ is False:
            color = 0
            pi = np.array([0 for _ in range(self._color_count)])
        else:
            T = (turn_ <
                 self._game._colored_graph.vertex_count * turn_threshold_)
            if T:
                pi_ = self.calculate_dirichlet_probability(pi_)
                color = np.random.choice(len(pi_), p=pi_) + 1
                pi = pi_
            else:
                pi = [0 for _ in range(len(pi_))]
                max_index = np.argmax(pi_)
                pi[max_index] = 1
                color = max_index + 1
        return pi, color

    # Plays a number of episodes; in each episode the simulation count is increased by a coefficient. Finally returns the average condition.
    def play_arena(self, episode_count_: "int", simulation_count: "int",
                   simulation_coef: "float", c: "int",
                   timer_: "GeneralTimer") -> "float":
        conditions_ = []
        for episode in range(episode_count_):
            while not self._game.eog():
                for _ in range(int(simulation_count + episode * simulation_coef)):
                    self._model.simulate(c, timer_)
                    self._model.synchronize_sample(self._game.graph())

                # Predict the action.
                timer_.simulation.start_time()
                self._model.calculate_legal_pi(self._game.graph(),
                                               self._color_count + 1)

                if self._model.pi is not False:
                    color = np.argmax(self._model.pi) + 1
                else:
                    color = 0
                self._game.play_turn(color)
                self._model.sync_sample_update_tree(self._game.graph(), color)
                timer_.simulation.stop_time()

            if self._game.get_condition() == float(1):
                print("Solution is found on the arena, Simulation count was",
                      simulation_count + episode * simulation_coef,
                      end=". ")
                return self._game.get_condition()
            else:
                conditions_.append(self._game.get_condition())
                self.reset_tree_game()
        return (sum(conditions_) / float(len(conditions_)))

    # Plays episode (set of turns).
    def play_game(
            self, simulation_count: "int", turn_threshold_: "int", c: "int",
            derive_data: "bool", timer_: "GeneralTimer"
    ) -> "2d [[Int], [float], float] List or bool":
        self._count_uncolored_vertices = 0
        episodic_data = ExpandableEpisodicDataSet(self.episode,
                                                  self._color_count)
        while not self._game.eog():
            for _ in range(simulation_count):
                self._model.simulate(c, timer_)
                self._model.synchronize_sample(self._game.graph())

            # Predict the action.
            timer_.simulation.start_time()
            self._model.calculate_legal_pi(self._game.graph(),
                                           self._color_count + 1)
            pi, color = self.finalize_pi_color(self._model.pi, self._game.turn,
                                               turn_threshold_)

            if self.check_resign(color):
                timer_.simulation.stop_time()
                del episodic_data
                return False
            else:
                episodic_data.insert(self._game.board(), pi)
                self._game.play_turn(color)
                self._model.sync_sample_update_tree(self._game.graph(), color)
                timer_.simulation.stop_time()

        episodic_data.set_reward(self._game.graph().condition())
        self.episode_cond = self._game.graph().condition()
        if derive_data:
            episodic_data.derive()
        self.episode += 1
        return episodic_data.finalize()

    # If the uncolored vertex count exceeds the resign threshold, then resign.
    def check_resign(self, color_: "int") -> "bool":
        if color_ == 0:
            self._count_uncolored_vertices += 1
        if self._count_uncolored_vertices > self._resign_treshold:
            return True
        else:
            return False

    # End of episode or comparison.
    def reset_tree_game(self):
        self._game.reset_game()
        self._model.reset_sync_tree()

    # Compares the given nnet with the model's current nnet, playing a number of
    # episodes where each episode increases the number of simulations.
    # If a solution is found, returns True.
    # Called at the end of one iteration (episode set).
    def set_best_nnet(self, old_nnet_: "NeuralNet", episode_count_: "int",
                      simulation_count_: "int", simulation_coef_: "int",
                      c_: "int", timer_: "GeneralTimer") -> "bool":
        self.episode = 0
        self.episode_cond = 0
        new_nnet_condition = self.play_arena(episode_count_, simulation_count_,
                                             simulation_coef_, c_, timer_)
        if new_nnet_condition == float(1):
            print("Trained Network")
            return True

        new_nnet_ = self.return_nnet()
        self.set_nnet(old_nnet_)

        old_nnet_condition = self.play_arena(episode_count_, simulation_count_,
                                             simulation_coef_, c_, timer_)
        if old_nnet_condition == float(1):
            print("Old Network")
            return True

        print("Trained Network's average condition is: ", new_nnet_condition,
              "on", episode_count_, "games.")
        print("Old Network's average condition is: ", old_nnet_condition, "on",
              episode_count_, "games.")
        if new_nnet_condition >= old_nnet_condition:
            self.set_nnet(new_nnet_)
            print("Trained Network is Selected!")
        else:
            print("Old Network is selected.")
        print("=" * 99)
        return False
Example no. 16
import torch
from torch import Tensor
from torch.distributions import Dirichlet
from torch.nn.functional import one_hot


def one_hot_dirichlet(x: Tensor, num_classes=-1):
    labels = one_hot(x, num_classes=num_classes).float()
    labels = (labels * 10000) + 100
    distribution = Dirichlet(labels)
    return distribution.sample()
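A quick check of the behavior (the label tensor is illustrative):

y = torch.tensor([0, 2, 1])
soft_labels = one_hot_dirichlet(y, num_classes=3)
# The true class gets concentration 10100 versus 100 for the others, so each sampled
# row is a probability vector concentrated near the one-hot target.
print(soft_labels.sum(dim=1))  # each row sums to 1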