Example #1
0
def dirichlet_nll(parameters, target):
    """Mean negative log-likelihood of ``target`` under Dirichlet distributions.

    ``parameters`` is used as the Dirichlet concentration. ``target`` has its
    first two dimensions swapped before scoring, so it is presumably laid out
    as (num_categories, batch) while ``parameters`` is (batch, num_categories)
    -- TODO confirm against callers.

    Returns a scalar tensor: the average of -log p(target_i) over the batch.
    """
    observations = target.transpose(0, 1)
    distribution = torch_dirichlet.Dirichlet(concentration=parameters)
    log_likelihood = distribution.log_prob(observations)
    return (-log_likelihood).mean()
Example #2
0
def train_GP_WLB(xtrain,
                 ytrain,
                 xtest=None,
                 return_mean=True,
                 return_samples=False,
                 num_bootstraps=1,
                 samples_per_bootstrap=10,
                 num_iters=5000,
                 verbose=True):
    '''
    Samples from the WLB (=NPL with alpha=0)

    Each bootstrap draws Dirichlet(1, ..., 1) weights over the training points,
    rescales them to sum to N and fits a weighted GP via train_GP_weighted.

    xtrain: NxD torch tensor of training covariates
    ytrain: Nx1 torch tensor of training targets
    xtest (optional): MxD torch tensor of test covariates
    return_mean: Boolean, if true and xtest is not None, includes mean and standard deviation of predictions at xtest
    return_samples: Boolean, if true and xtest is not None, includes the raw samples in the return object
    num_bootstraps: Integer, number of bootstrap samples
    samples_per_bootstrap: Integer, number of predictive samples to generate per bootstrap (only used if xtest is not None)
    num_iters: number of iterations, forwarded to train_GP_weighted for every bootstrap
    verbose: Boolean, forwarded to train_GP_weighted

    returns: Dict with following entries:
    'lengthscale': list of sampled lengthscales, one per bootstrap
    'sigma_n': list of sampled noise standard deviations, one per bootstrap
    'mean': Mean function at xtest (only included if xtest is not None and return_mean=True)
    'std': Standard deviation at xtest (only included if xtest is not None and return_mean=True)
    'samples': predictive samples at xtest, stacked row-wise across bootstraps
               (only included if xtest is not None and return_samples=True)
    '''
    num_train = len(ytrain)
    weight_generator = d.Dirichlet(torch.ones(num_train))

    lengthscale = []
    sigma_n = []
    sample_blocks = []
    for _ in range(num_bootstraps):
        # Dirichlet draws sum to 1; rescale so the weights sum to N (mean weight 1).
        weights = weight_generator.sample() * num_train
        bres = train_GP_weighted(xtrain,
                                 ytrain,
                                 xtest=xtest,
                                 return_mean=False,
                                 num_samples=samples_per_bootstrap,
                                 num_iters=num_iters,  # honour the caller's iteration budget
                                 weights=weights,
                                 verbose=verbose)
        if xtest is not None:
            sample_blocks.append(bres['samples'])
        lengthscale.append(bres['lengthscale'])
        sigma_n.append(bres['sigma_n'])

    res = {'lengthscale': lengthscale, 'sigma_n': sigma_n}
    # Predictive summaries only exist when test covariates were supplied;
    # without this guard the default return_mean=True fails for xtest=None.
    if xtest is not None:
        # The empty (0, M) seed keeps the result 2-D even when num_bootstraps == 0.
        samples = np.vstack([np.zeros((0, len(xtest)))] + sample_blocks)
        if return_mean:
            res['mean'] = np.mean(samples, axis=0)
            res['std'] = np.std(samples, axis=0)
        if return_samples:
            res['samples'] = samples

    return res
Example #3
0
 def dirichlet_number_generator(self, total_size):
     """Draw one sample from a symmetric three-component Dirichlet.

     Every component uses ``total_size`` as its concentration; the result is
     a length-3 tensor whose entries are positive and sum to 1.
     """
     concentration = torch.Tensor([total_size] * 3)
     return dirichlet.Dirichlet(concentration).sample()
Example #4
0
def train_GP_NPL(xtrain,
                 ytrain,
                 x_prior_dist,
                 y_prior_dist,
                 xtest=None,
                 return_mean=True,
                 return_samples=False,
                 num_bootstraps=1,
                 samples_per_bootstrap=10,
                 num_iters=5000,
                 alpha=1.,
                 num_pseudo=10,
                 verbose=True):
    '''
    Samples from the NPL

    Each bootstrap draws num_pseudo pseudo-observations from the prior
    distributions and appends them to the training data. It then draws Dirichlet
    weights (concentration 1 for each real point, alpha/num_pseudo for each
    pseudo point) and fits a weighted GP via train_GP_weighted.

    xtrain: NxD torch tensor of training covariates
    ytrain: Nx1 torch tensor of training targets
    x_prior_dist: torch prior distribution over covariates
    y_prior_dist: torch prior distribution over targets
    xtest (optional): MxD torch tensor of test covariates
    return_mean: Boolean, if true and xtest is not None, includes mean and standard deviation of predictions at xtest
    return_samples: Boolean, if true and xtest is not None, includes the raw samples in the return object
    num_bootstraps: Integer, number of bootstrap samples
    samples_per_bootstrap: Integer, number of predictive samples to generate per bootstrap (only used if xtest is not None)
    num_iters: number of iterations, forwarded to train_GP_weighted for every bootstrap
    alpha: concentration parameter (>0)
    num_pseudo: number of pseudo-samples
    verbose: Boolean, forwarded to train_GP_weighted

    returns: Dict with following entries:
    'lengthscale': list of sampled lengthscales, one per bootstrap
    'sigma_n': list of sampled noise standard deviations, one per bootstrap
    'mean': Mean function at xtest (only included if xtest is not None and return_mean=True)
    'std': Standard deviation at xtest (only included if xtest is not None and return_mean=True)
    'samples': predictive samples at xtest, stacked row-wise across bootstraps
               (only included if xtest is not None and return_samples=True)
    '''
    num_train = len(ytrain)
    dirichlet_weight = torch.cat(
        (torch.ones(num_train),
         (alpha / num_pseudo) * torch.ones(num_pseudo)), 0)
    weight_generator = d.Dirichlet(dirichlet_weight)
    # The concentrations sum to N + alpha, so rescaling by it gives each real
    # observation an expected weight of 1.
    weight_scale = num_train + alpha

    lengthscale = []
    sigma_n = []
    sample_blocks = []
    for _ in range(num_bootstraps):
        weights = weight_generator.sample() * weight_scale
        pseudo_x = x_prior_dist.sample(sample_shape=torch.Size([num_pseudo]))
        pseudo_y = y_prior_dist.sample(sample_shape=torch.Size([num_pseudo]))
        # NOTE(review): torch.cat needs pseudo_x / pseudo_y to match the trailing
        # dims of xtrain / ytrain; presumably the priors are constructed with
        # those shapes -- confirm against callers.
        both_x = torch.cat((xtrain, pseudo_x), 0)
        both_y = torch.cat((ytrain, pseudo_y), 0)

        bres = train_GP_weighted(both_x,
                                 both_y,
                                 xtest=xtest,
                                 return_mean=False,
                                 num_samples=samples_per_bootstrap,
                                 num_iters=num_iters,  # honour the caller's iteration budget
                                 weights=weights,
                                 verbose=verbose)
        if xtest is not None:
            sample_blocks.append(bres['samples'])
        lengthscale.append(bres['lengthscale'])
        sigma_n.append(bres['sigma_n'])

    res = {'lengthscale': lengthscale, 'sigma_n': sigma_n}
    # Predictive summaries only exist when test covariates were supplied;
    # without this guard the default return_mean=True fails for xtest=None.
    if xtest is not None:
        # The empty (0, M) seed keeps the result 2-D even when num_bootstraps == 0.
        samples = np.vstack([np.zeros((0, len(xtest)))] + sample_blocks)
        if return_mean:
            res['mean'] = np.mean(samples, axis=0)
            res['std'] = np.std(samples, axis=0)
        if return_samples:
            res['samples'] = samples

    return res
Example #5
0
 def create_generic_dir_stat(self):
     """Derive summary statistics of the Dirichlet defined by ``self.target_alpha``.

     Sets ``alpha0`` (sum of concentrations), ``generic_mean`` (concentrations
     normalised by ``alpha0``) and ``generic_dir`` (the Dirichlet itself), then
     calls ``self.creat_covariance()``.
     """
     concentration = self.target_alpha
     total = torch.sum(concentration)
     self.alpha0 = total
     self.generic_mean = concentration / total
     self.generic_dir = d.Dirichlet(concentration)
     self.creat_covariance()
Example #6
0
class MCTS:
    '''Monte Carlo tree searcher.
    Nodes are selected with the PUCT variant used in AlphaGo.
    First rollout the tree then choose. To input a move
    use make_move(..) and set_root(..)

    args:
        root (Go_MCTS): node representing current game state
        policy_net (PolicyNet): PolicyNet for getting prior distributions
        value_net (ValueNet): for getting board value (between -1 and 1)
                   If no valuenet is given, rewards are based on simulations only
    kwargs:
        no_sim (bool): disable simulations and evaluate only with value net
                       (default True; requires value_net)
        expand_thresh (int): number of visits before leaf is expanded (default 100)
        branch_num (int): number of children to expand. If not specified, all legal moves expanded
        exploration_weight (float): scalar for prior prediction (default 4.0)
        value_net_weight (float): scalar between 0 and 1 for mixing value network
                                  and simulation rewards (default 0.5; forced to 1.0
                                  when no_sim is set and to 0.0 without a value_net)
        noise_weight (float): scalar between 0 and 1 for adding Dirichlet noise (default 0)
        device (torch.device): device the networks are moved to (default cpu)
    Attributes:
        Q: dict containing total simulation rewards of each node
        N: dict containing total visits to each node
        V: dict containing accumulated value of each node
        children: dict containing children of each node
    '''

    # Class-level objects are shared by every MCTS instance.
    _dirichlet = dirichlet.Dirichlet(0.1 * torch.ones(go.N**2))
    _val_cache = dict()
    _dist_cache = dict()
    _fts_cache = dict()

    def __init__(self,
                 root,
                 policy_net: PolicyNet = None,
                 value_net: ValueNet = None,
                 **kwargs):
        self.Q = defaultdict(int)
        self.N = defaultdict(int)
        self.V = defaultdict(float)
        self.children = dict()
        if policy_net is None:
            raise TypeError("Missing required keyword argument: 'policy_net'")
        self.policy_net = policy_net
        self.value_net = value_net
        self.no_sim = kwargs.get("no_sim", True)
        if self.value_net is None and self.no_sim:
            raise TypeError(
                "Keyword argument 'value_net' is required for no simulation mode"
            )
        self.expand_thresh = kwargs.get("expand_thresh", 100)
        self.branch_num = kwargs.get("branch_num")
        self.exploration_weight = kwargs.get("exploration_weight", 4.0)
        self.noise_weight = kwargs.get("noise_weight", 0)
        if self.no_sim:
            self.value_net_weight = 1.0
        elif self.value_net is None:
            self.value_net_weight = 0.0
        else:
            self.value_net_weight = kwargs.get("value_net_weight", 0.5)

        #for GPU computations
        self.device = kwargs.get("device", torch.device("cpu"))
        policy_net.to(self.device)
        if value_net is not None:
            value_net.to(self.device)

        #initialize the root
        self.set_root(root)

    def __deepcopy__(self, memo):
        '''Copy the search tree; networks and settings are shared, not copied.'''
        cls = self.__class__
        new_tree = cls.__new__(cls)
        # Register the copy before recursing. Nodes hold a `tree` back-reference
        # (see set_root/__setstate__), which then resolves to new_tree instead of
        # re-entering __deepcopy__. Threading `memo` through every call also keeps
        # root and the nodes inside children/N/Q/V as the same copied objects.
        memo[id(self)] = new_tree
        new_tree.__dict__.update(self.__dict__)
        new_tree.root = deepcopy(self.root, memo)
        new_tree.V = deepcopy(self.V, memo)
        new_tree.Q = deepcopy(self.Q, memo)
        new_tree.N = deepcopy(self.N, memo)
        new_tree.children = deepcopy(self.children, memo)
        return new_tree

    #For pickling
    def __getstate__(self):
        state_dict = self.__dict__.copy()
        # Networks are not pickled; __setstate__ resets them to None.
        del state_dict["policy_net"]
        del state_dict["value_net"]
        return state_dict

    def __setstate__(self, state_dict):
        self.__dict__.update(state_dict)
        #give the nodes a reference to the tree
        for n in self.children:
            n.tree = self
            for c in self.children[n]:
                c.tree = self
        #set the policy net and value net manually
        self.policy_net = None
        self.value_net = None

    def choose(self, node=None):
        '''Choose the best child of root and set it as the new root
        optional:
            node: choose from different node (doesn't affect root)'''
        if node is None:
            node = self.root
        if node._terminal:
            return node
        if node not in self.children:
            return node.find_random_child()

        def score(n):
            if self.N[n] == 0:
                return float("-inf")  # avoid unseen moves
            return self.N[n]

        # Choose most visited node
        best = max(self.children[node], key=score)
        if node == self.root:
            self.set_root(best)
        return best

    def rollout(self, n=1, analyze_dict=None):
        '''Do rollouts from the root
        args:
            n (int): number of rollouts
            analyze_dict: (optional) dict to store variations
        '''
        for _ in range(n):
            # Get path to leaf of current search tree
            path = self._descend()
            leaf = path[-1]

            if analyze_dict is not None and len(path) > 2:
                analyze_dict[path[1]] = path[1:]

            if not self.no_sim:
                score = self._simulate(leaf, gnu=True)
            else:
                score = None
            self._backpropagate(path, score, leaf.value)

    def set_root(self, node):
        '''Make `node` the root, attach it to this tree, add noise and expand it.'''
        self.root = node
        self.root.tree = self
        self.root._add_noise(self.noise_weight)
        self._expand(self.root)

    def winrate(self, node=None):
        '''Returns float between 0.0 and 1.0 representing winrate
        from perspective of the root (0 if the node has no visits)
        optional:
            node: return winrate of a different node'''
        w = self.value_net_weight
        if node is None:
            node = self.root
        if self.N[node] > 0:
            v = ((1 - w) * self.Q[node] + w * self.V[node]) / self.N[node]
            # Map the mixed value from [-1, 1] onto [0, 1].
            return (v + 1) / 2
        return 0

    def _descend(self):
        "Return a path from root down to leaf via PUCT selection"
        path = [self.root]
        node = self.root
        while True:
            # Is node a leaf?
            if node not in self.children or not self.children[node]:
                if self.N[node] > self.expand_thresh:
                    self._expand(node)
                return path
            node = self._puct_select(node)  # descend a layer deeper
            path.append(node)

    def _expand(self, node):
        "Update the `children` dict with the children of `node`"
        if node in self.children:
            return  # already expanded
        if self.branch_num:
            self.children[node] = node.find_children(k=self.branch_num)
        else:
            self.children[node] = node.find_children()

    # Need to make this faster (ideally at least 10x)
    def _simulate(self, node, gnu=False):
        '''Returns the reward for a random simulation (to completion) of node
        optional:
            gnu: if True, score with gnugo (default False)'''
        invert_reward = node.turn % 2 != 0  #invert if it is white's turn
        while True:
            if node._terminal:
                reward = node.reward(gnu)
                if invert_reward:
                    reward = -reward
                return reward
            node = node.find_random_child()

    def _backpropagate(self, path, reward, leaf_val):
        '''Send the reward back up to the ancestors of the leaf'''
        for node in reversed(path):
            self.N[node] += 1
            # Rewards alternate sign each ply (negamax convention).
            if reward:
                self.Q[node] += reward
                reward = -reward
            if self.value_net is not None:
                self.V[node] += leaf_val
                leaf_val = -leaf_val

    def _puct_select(self, node):
        "Select a child of node with PUCT"
        total_visits = sum(self.N[n] for n in self.children[node])
        # First visit selects policy's top choice
        if total_visits == 0:
            total_visits = 1

        def puct(n):
            last_move_prob = node.dist.probs[n.last_move].item()
            avg_reward = 0 if self.N[n] == 0 else \
                    ((1 - self.value_net_weight) * self.Q[n]
                     + self.value_net_weight * self.V[n]) / self.N[n]
            # avg_reward is from the child's perspective, hence the negation.
            return -avg_reward + (self.exploration_weight * last_move_prob *
                                  sqrt(total_visits) / (1 + self.N[n]))

        return max(self.children[node], key=puct)

    def _prune(self):
        '''Prune the tree leaving only root and its descendants.

        Every node reachable from the root keeps its N/Q/V statistics,
        including unexpanded leaves (children that are not keys of
        `children`). All other nodes are dropped from children, N, Q and V.'''
        reachable = {self.root}
        new_children = dict()
        stack = [self.root]
        while stack:
            n = stack.pop()
            kids = self.children.get(n)
            if kids:
                new_children[n] = kids
                for c in kids:
                    # The visited check guards against re-walking shared nodes.
                    if c not in reachable:
                        reachable.add(c)
                        stack.append(c)
        self.children = new_children
        # Q/N/V are defaultdicts whose key sets differ (reads create entries),
        # so each is cleaned independently rather than assuming shared keys.
        for stats in (self.N, self.Q, self.V):
            stale = [n for n in stats if n not in reachable]
            for n in stale:
                del stats[n]