def dirichlet_nll(parameters, target):
    '''Mean negative log-likelihood of `target` under a Dirichlet
    distribution parameterised by the predicted concentrations.'''
    distr = torch_dirichlet.Dirichlet(concentration=parameters)
    # target is (classes x batch); transpose to (batch x classes) for log_prob
    neg_log_prob = -distr.log_prob(torch.transpose(target, 0, 1))
    loss = torch.mean(neg_log_prob)
    return loss
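# Minimal usage sketch of dirichlet_nll (an assumption: torch_dirichlet is an
# alias for torch.distributions.dirichlet; the tensors below are illustrative).
import torch
from torch.distributions import dirichlet as torch_dirichlet

# Batch of 2 predicted concentration vectors over 3 classes.
parameters = torch.tensor([[2.0, 1.0, 1.0],
                           [1.0, 1.0, 5.0]])
# Targets as a (classes x batch) tensor, hence the transpose inside
# dirichlet_nll; each column lies on the probability simplex.
target = torch.tensor([[0.6, 0.1],
                       [0.2, 0.1],
                       [0.2, 0.8]])
print(dirichlet_nll(parameters, target))  # scalar mean negative log-likelihood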
def train_GP_WLB(xtrain, ytrain, xtest=None, return_mean=True,
                 return_samples=False, num_bootstraps=1,
                 samples_per_bootstrap=10, num_iters=5000, verbose=True):
    '''
    Samples from the WLB (= NPL with alpha=0).

    xtrain: NxD torch tensor of training covariates
    ytrain: Nx1 torch tensor of training targets
    xtest (optional): MxD torch tensor of test covariates
    return_mean: Boolean, if True and xtest is not None, include the mean and
        standard deviation of the predictions at xtest
    return_samples: Boolean, if True and xtest is not None, include the raw
        samples in the return object
    num_bootstraps: Integer, number of bootstrap samples
    samples_per_bootstrap: Integer, number of predictive samples to generate
        per bootstrap if xtest is not None
    num_iters: number of optimisation iterations per bootstrap
    verbose: print progress information

    returns: Dict with the following entries:
        'lengthscale': Sampled lengthscales
        'sigma_n': Sampled noise standard deviations
        'mean': Mean function at xtest (only if xtest is not None and
            return_mean=True)
        'std': Standard deviation at xtest (only if xtest is not None and
            return_mean=True)
        'samples': Samples from the posterior predictive (only if xtest is not
            None and return_samples=True)
    '''
    # Uniform Dirichlet weights give the weighted likelihood bootstrap.
    weight_generator = d.Dirichlet(torch.ones(len(ytrain)))
    lengthscale = []
    sigma_n = []
    if xtest is not None:
        samples = np.zeros((0, len(xtest)))
    for b in range(num_bootstraps):
        # Rescale the Dirichlet draw so the weights sum to n.
        weights = weight_generator.sample() * len(ytrain)
        bres = train_GP_weighted(xtrain, ytrain, xtest=xtest, return_mean=False,
                                 num_samples=samples_per_bootstrap,
                                 num_iters=num_iters, weights=weights,
                                 verbose=verbose)
        if xtest is not None:
            samples = np.vstack((samples, bres['samples']))
        lengthscale.append(bres['lengthscale'])
        sigma_n.append(bres['sigma_n'])
    res = {'lengthscale': lengthscale, 'sigma_n': sigma_n}
    if xtest is not None:
        if return_mean:
            res['mean'] = np.mean(samples, axis=0)
            res['std'] = np.std(samples, axis=0)
        if return_samples:
            res['samples'] = samples
    return res
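# Hypothetical usage sketch for train_GP_WLB; assumes train_GP_weighted and
# the module-level imports (torch, numpy as np, torch.distributions as d) are
# available, and uses an illustrative 1-D regression problem.
import torch

xtrain = torch.linspace(-3, 3, 40).unsqueeze(-1)              # N x 1 covariates
ytrain = torch.sin(xtrain) + 0.1 * torch.randn_like(xtrain)   # N x 1 noisy targets
xtest = torch.linspace(-4, 4, 100).unsqueeze(-1)               # M x 1 test points

res = train_GP_WLB(xtrain, ytrain, xtest=xtest,
                   num_bootstraps=20, samples_per_bootstrap=5,
                   num_iters=1000, verbose=False)
print(len(res['lengthscale']))   # one hyperparameter entry per bootstrap
print(res['mean'], res['std'])   # pointwise predictive mean and std at xtest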
def dirichlet_number_generator(self, total_size):
    m = dirichlet.Dirichlet(
        torch.Tensor([total_size, total_size, total_size]))
    return m.sample()
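# Small standalone illustration of the symmetric Dirichlet draw above
# (hypothetical, outside the owning class): every sample lies on the
# 3-simplex, and larger total_size concentrates the draws near (1/3, 1/3, 1/3).
import torch
from torch.distributions import dirichlet

for total_size in (1, 100):
    m = dirichlet.Dirichlet(torch.Tensor([total_size] * 3))
    sample = m.sample()
    print(total_size, sample, sample.sum())  # components always sum to 1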
def train_GP_NPL(xtrain, ytrain, x_prior_dist, y_prior_dist, xtest=None,
                 return_mean=True, return_samples=False, num_bootstraps=1,
                 samples_per_bootstrap=10, num_iters=5000, alpha=1.,
                 num_pseudo=10, verbose=True):
    '''
    Samples from the NPL posterior.

    xtrain: NxD torch tensor of training covariates
    ytrain: Nx1 torch tensor of training targets
    x_prior_dist: torch prior distribution over covariates
    y_prior_dist: torch prior distribution over targets
    xtest (optional): MxD torch tensor of test covariates
    return_mean: Boolean, if True and xtest is not None, include the mean and
        standard deviation of the predictions at xtest
    return_samples: Boolean, if True and xtest is not None, include the raw
        samples in the return object
    num_bootstraps: Integer, number of bootstrap samples
    samples_per_bootstrap: Integer, number of predictive samples to generate
        per bootstrap if xtest is not None
    num_iters: number of optimisation iterations per bootstrap
    alpha: concentration parameter (>0)
    num_pseudo: number of pseudo-samples drawn from the priors

    returns: Dict with the following entries:
        'lengthscale': Sampled lengthscales
        'sigma_n': Sampled noise standard deviations
        'mean': Mean function at xtest (only if xtest is not None and
            return_mean=True)
        'std': Standard deviation at xtest (only if xtest is not None and
            return_mean=True)
        'samples': Samples from the posterior predictive (only if xtest is not
            None and return_samples=True)
    '''
    # Observed points get weight 1 each; the pseudo-points share the prior mass alpha.
    dirichlet_weight = torch.cat(
        (torch.ones(len(ytrain)),
         (alpha / num_pseudo) * torch.ones(num_pseudo)), 0)
    weight_generator = d.Dirichlet(dirichlet_weight)
    lengthscale = []
    sigma_n = []
    if xtest is not None:
        samples = np.zeros((0, len(xtest)))
    for b in range(num_bootstraps):
        # Rescale the Dirichlet draw so the weights sum to n + alpha.
        weights = weight_generator.sample() * (len(ytrain) + alpha)
        # Draw fresh pseudo-data from the priors for every bootstrap.
        pseudo_x = x_prior_dist.sample(sample_shape=torch.Size([num_pseudo]))
        pseudo_y = y_prior_dist.sample(sample_shape=torch.Size([num_pseudo]))
        both_x = torch.cat((xtrain, pseudo_x), 0)
        both_y = torch.cat((ytrain, pseudo_y), 0)
        bres = train_GP_weighted(both_x, both_y, xtest=xtest, return_mean=False,
                                 num_samples=samples_per_bootstrap,
                                 num_iters=num_iters, weights=weights,
                                 verbose=verbose)
        if xtest is not None:
            samples = np.vstack((samples, bres['samples']))
        lengthscale.append(bres['lengthscale'])
        sigma_n.append(bres['sigma_n'])
    res = {'lengthscale': lengthscale, 'sigma_n': sigma_n}
    if xtest is not None:
        if return_mean:
            res['mean'] = np.mean(samples, axis=0)
            res['std'] = np.std(samples, axis=0)
        if return_samples:
            res['samples'] = samples
    return res
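# Hypothetical usage sketch for train_GP_NPL; the covariate and target priors
# below are illustrative choices, and train_GP_weighted is assumed to come
# from the same module as above.
import torch
import torch.distributions as d

xtrain = torch.linspace(-3, 3, 40).unsqueeze(-1)
ytrain = torch.sin(xtrain) + 0.1 * torch.randn_like(xtrain)
xtest = torch.linspace(-4, 4, 100).unsqueeze(-1)

# Priors used to draw the pseudo-samples that regularise the bootstrap;
# their batch shapes match the 1-D covariates/targets above.
x_prior = d.Uniform(torch.tensor([-4.0]), torch.tensor([4.0]))
y_prior = d.Normal(torch.tensor([0.0]), torch.tensor([1.0]))

res = train_GP_NPL(xtrain, ytrain, x_prior, y_prior, xtest=xtest,
                   num_bootstraps=20, samples_per_bootstrap=5,
                   alpha=1.0, num_pseudo=10, num_iters=1000, verbose=False)
print(res['mean'], res['std'])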
def create_generic_dir_stat(self):
    self.alpha0 = torch.sum(self.target_alpha)
    self.generic_mean = self.target_alpha / self.alpha0
    self.generic_dir = d.Dirichlet(self.target_alpha)
    self.creat_covariance()
class MCTS:
    '''Monte Carlo tree searcher.

    Nodes are selected with the PUCT variant used in AlphaGo.
    First rollout the tree, then choose. To input a move, use
    make_move(..) and set_root(..).

    args:
        root (Go_MCTS): node representing the current game state
        policy_net (PolicyNet): PolicyNet for getting prior distributions
        value_net (ValueNet): for getting the board value (between -1 and 1).
            If no value net is given, rewards are based on simulations only
        no_sim (bool): disable simulations and evaluate only with the value net

    kwargs:
        expand_thresh (int): number of visits before a leaf is expanded
            (default 100)
        branch_num (int): number of children to expand. If not specified,
            all legal moves are expanded
        exploration_weight (float): scalar for the prior prediction (default 4.0)
        value_net_weight (float): scalar between 0 and 1 for mixing value
            network and simulation rewards (default 0.5)
        noise_weight (float): scalar between 0 and 1 for adding Dirichlet
            noise (default 0.25)

    Attributes:
        Q: dict containing total simulation rewards of each node
        N: dict containing total visits to each node
        V: dict containing accumulated value of each node
        children: dict containing children of each node
    '''

    _dirichlet = dirichlet.Dirichlet(0.1 * torch.ones(go.N**2))
    _val_cache = dict()
    _dist_cache = dict()
    _fts_cache = dict()

    def __init__(self, root, policy_net: PolicyNet = None,
                 value_net: ValueNet = None, **kwargs):
        self.Q = defaultdict(int)
        self.N = defaultdict(int)
        self.V = defaultdict(float)
        self.children = dict()

        if policy_net is None:
            raise TypeError("Missing required keyword argument: 'policy_net'")
        self.policy_net = policy_net
        self.value_net = value_net
        self.no_sim = kwargs.get("no_sim", True)
        if self.value_net is None and self.no_sim:
            raise TypeError(
                "Keyword argument 'value_net' is required for no-simulation mode"
            )

        self.expand_thresh = kwargs.get("expand_thresh", 100)
        self.branch_num = kwargs.get("branch_num")
        self.exploration_weight = kwargs.get("exploration_weight", 4.0)
        self.noise_weight = kwargs.get("noise_weight", 0)
        if self.no_sim:
            self.value_net_weight = 1.0
        elif self.value_net is None:
            self.value_net_weight = 0.0
        else:
            self.value_net_weight = kwargs.get("value_net_weight", 0.5)

        # For GPU computations
        self.device = kwargs.get("device", torch.device("cpu"))
        policy_net.to(self.device)
        if value_net is not None:
            value_net.to(self.device)

        # Initialize the root
        self.set_root(root)

    def __deepcopy__(self, memo):
        cls = self.__class__
        new_tree = cls.__new__(cls)
        new_tree.__dict__.update(self.__dict__)
        new_tree.root = deepcopy(self.root)
        new_tree.V = deepcopy(self.V)
        new_tree.Q = deepcopy(self.Q)
        new_tree.N = deepcopy(self.N)
        new_tree.children = deepcopy(self.children)
        return new_tree

    # For pickling
    def __getstate__(self):
        state_dict = self.__dict__.copy()
        del state_dict["policy_net"]
        del state_dict["value_net"]
        return state_dict

    def __setstate__(self, state_dict):
        self.__dict__.update(state_dict)
        # Give the nodes a reference to the tree
        for n in self.children:
            n.tree = self
            for c in self.children[n]:
                c.tree = self
        # Set the policy net and value net manually
        self.policy_net = None
        self.value_net = None

    def choose(self, node=None):
        '''Choose the best child of root and set it as the new root.

        optional:
            node: choose from a different node (doesn't affect root)'''
        if node is None:
            node = self.root
        if node._terminal:
            # print(f"{node} Board is terminal")
            return node
        if node not in self.children:
            return node.find_random_child()

        def score(n):
            if self.N[n] == 0:
                return float("-inf")  # avoid unseen moves
            return self.N[n]  # choose the most visited node

        best = max(self.children[node], key=score)
        if node == self.root:
            self.set_root(best)
        return best

    def rollout(self, n=1, analyze_dict=None):
        '''Do rollouts from the root.

        args:
            n (int): number of rollouts
            analyze_dict: (optional) dict to store variations
        '''
        for _ in range(n):
            # Get a path to a leaf of the current search tree
            path = self._descend()
            leaf = path[-1]
            if analyze_dict is not None and len(path) > 2:
                analyze_dict[path[1]] = path[1:]
            if not self.no_sim:
                score = self._simulate(leaf, gnu=True)
            else:
                score = None
            self._backpropagate(path, score, leaf.value)

    def set_root(self, node):
        self.root = node
        self.root.tree = self
        self.root._add_noise(self.noise_weight)
        self._expand(self.root)

    def winrate(self, node=None):
        '''Returns a float between 0.0 and 1.0 representing the winrate from
        the perspective of the root.

        optional:
            node: return the winrate of a different node'''
        w = self.value_net_weight
        if node is None:
            node = self.root
        if self.N[node] > 0:
            v = ((1 - w) * self.Q[node] + w * self.V[node]) / self.N[node]
            return (v + 1) / 2
        return 0

    def _descend(self):
        "Return a path from the root down to a leaf via PUCT selection"
        path = [self.root]
        node = self.root
        while True:
            # Is node a leaf?
            if node not in self.children or not self.children[node]:
                if self.N[node] > self.expand_thresh:
                    self._expand(node)
                return path
            node = self._puct_select(node)  # descend a layer deeper
            path.append(node)

    def _expand(self, node):
        "Update the `children` dict with the children of `node`"
        if node in self.children:
            return  # already expanded
        if self.branch_num:
            self.children[node] = node.find_children(k=self.branch_num)
        else:
            self.children[node] = node.find_children()

    # Need to make this faster (ideally at least 10x)
    def _simulate(self, node, gnu=False):
        '''Returns the reward for a random simulation (to completion) of node.

        optional:
            gnu: if True, score with gnugo (default False)'''
        invert_reward = not (node.turn % 2 == 0)  # invert if it is white's turn
        while True:
            if node._terminal:
                reward = node.reward(gnu)
                if invert_reward:
                    reward = -reward
                return reward
            node = node.find_random_child()

    def _backpropagate(self, path, reward, leaf_val):
        '''Send the reward back up to the ancestors of the leaf'''
        for node in reversed(path):
            self.N[node] += 1
            if reward:
                self.Q[node] += reward
                reward = -reward
            if self.value_net is not None:
                self.V[node] += leaf_val
                leaf_val = -leaf_val

    def _puct_select(self, node):
        "Select a child of node with PUCT"
        total_visits = sum(self.N[n] for n in self.children[node])
        # First visit selects the policy's top choice
        if total_visits == 0:
            total_visits = 1

        def puct(n):
            last_move_prob = node.dist.probs[n.last_move].item()
            avg_reward = 0 if self.N[n] == 0 else \
                ((1 - self.value_net_weight) * self.Q[n]
                 + self.value_net_weight * self.V[n]) / self.N[n]
            return -avg_reward + (self.exploration_weight * last_move_prob
                                  * sqrt(total_visits) / (1 + self.N[n]))

        return max(self.children[node], key=puct)

    def _prune(self):
        '''Prune the tree, leaving only the root and its descendants'''
        new_children = defaultdict(int)
        q = [self.root]
        while q:
            n = q.pop()
            c = self.children.get(n)
            if c:
                new_children[n] = c
                q.extend(c)
        self.children = new_children
        remove_me = set()
        for n in self.N:
            if n not in new_children:
                remove_me.add(n)
        for n in remove_me:
            del self.Q[n]
            del self.N[n]
            if n in self.V:
                del self.V[n]
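# Hypothetical usage sketch for MCTS; Go_MCTS, PolicyNet and ValueNet are the
# classes referenced in the docstring above, and the constructor calls shown
# here (including their arguments) are assumptions for illustration only.
import torch

policy_net = PolicyNet()
value_net = ValueNet()
root = Go_MCTS()  # assumed to represent an empty board with black to play

tree = MCTS(root, policy_net=policy_net, value_net=value_net,
            no_sim=True, expand_thresh=50, noise_weight=0.25,
            device=torch.device("cpu"))

tree.rollout(n=200)          # grow the search tree from the root
best_child = tree.choose()   # most-visited child becomes the new root
print(tree.winrate())        # winrate estimate from the root's perspective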