def get_action(self, belief, state, det=False, scale=None):
    """Sample an action from the tanh-squashed policy head.

    When ``scale`` is truthy, the action is drawn from a broadened
    exploration (proposal) distribution and the log-likelihoods under
    both the proposal and the true policy are returned (importance
    sampling); otherwise a plain sample (or approximate mode) of the
    policy is returned.

    Args:
        belief: deterministic recurrent state fed to ``self.forward``.
        state: stochastic latent state fed to ``self.forward``.
        det: if True, take the approximate mode instead of sampling.
        scale: optional exploration scale in (0, 1]; smaller values
            broaden the proposal std more.

    Returns:
        ``action`` alone, or ``(action, policy_loglike, proposal_loglike)``
        when ``scale`` is given.
    """
    action_mean, action_std = self.forward(belief, state)

    def squashed(std):
        # Tanh-squash a Normal, couple the action dimensions, and wrap in
        # SampleDist so mode/log_prob are approximated by sampling.
        base = Normal(action_mean, std)
        base = TransformedDistribution(base, TanhBijector())
        base = torch.distributions.Independent(base, 1)
        return SampleDist(base)

    if scale:
        # Proposal std is widened by a detached copy of the policy std so the
        # extra exploration noise does not receive gradients.
        proposal = squashed(action_std + action_std.detach() * (1 - scale))
        action = proposal.mode() if det else proposal.rsample()
        proposal_loglike = proposal.log_prob(action).detach()
        # Likelihood of the same action under the true (unwidened) policy.
        policy_loglike = squashed(action_std).log_prob(action)
        return action, policy_loglike, proposal_loglike

    policy = squashed(action_std)
    return policy.mode() if det else policy.rsample()
def __init__(self, data_dim=28 * 28, device='cpu'):
    """Build a factorized standard-logistic base distribution over ``data_dim`` dims.

    Logistic(0, 1) is constructed as logit(Uniform(0, 1)); the trailing
    AffineTransform with loc=0, scale=1 is an identity kept so loc/scale
    could later be made learnable without changing the structure.
    """
    zeros = torch.zeros(data_dim, device=device)
    ones = torch.ones(data_dim, device=device)
    self.m = TransformedDistribution(
        Uniform(zeros, ones),
        [SigmoidTransform().inv, AffineTransform(zeros, ones)],
    )
def forward(self, belief, state, deterministic=False, with_logprob=False,):
    """Map (belief, state) to an action via a squashed-Gaussian policy head.

    The MLP outputs mean and (pre-softplus) std, the mean is soft-bounded via a
    scaled tanh, and the resulting Normal is pushed through a tanh built from
    affine+sigmoid transforms. When ``self.fix_speed`` is set, only the steering
    dimension is produced and a constant throttle is concatenated on return.

    Returns:
        (action, logp_pi) where ``logp_pi`` is None unless ``with_logprob``.
    """
    # softplus(raw_init_std) == init_std, so std starts near self.init_std.
    raw_init_std = np.log(np.exp(self.init_std) - 1)
    hidden = self.act_fn(self.fc1(torch.cat([belief, state], dim=-1)))
    hidden = self.act_fn(self.fc2(hidden))
    hidden = self.act_fn(self.fc3(hidden))
    hidden = self.act_fn(self.fc4(hidden))
    hidden = self.fc5(hidden)
    mean, std = torch.chunk(hidden, 2, dim=-1)
    # Soft-bound the mean to [-mean_scale, mean_scale]: inverting tanh for
    # log-probabilities is numerically unstable in highly saturated regions.
    mean = self.mean_scale * torch.tanh(mean / self.mean_scale)
    std = F.softplus(std + raw_init_std) + self.min_std
    dist = torch.distributions.Normal(mean, std)
    # TanhTransform expressed as ComposeTransform([AffineTransform(0., 2.),
    # SigmoidTransform(), AffineTransform(-1., 2.)]).
    if self.fix_speed:
        transform = [AffineTransform(0., 2.), SigmoidTransform(), AffineTransform(-1., 2.)]
    else:
        # Extra affine rescales the throttle dimension around throtlle_base.
        # NOTE(review): device is hard-coded to "cuda" here — breaks CPU runs;
        # TODO confirm and derive the device from `mean` instead.
        transform = [AffineTransform(0., 2.), SigmoidTransform(), AffineTransform(-1., 2.),  # TanhTransform
                     AffineTransform(loc=torch.tensor([0.0, self.throtlle_base]).to("cuda"),
                                     scale=torch.tensor([1.0, 0.2]).to("cuda"))]  # TODO: this is limited at donkeycar env
    dist = TransformedDistribution(dist, transform)
    # dist = torch.distributions.independent.Independent(dist, 1)  # Introduces dependence between actions dimension
    # After transforming, entropy/mean/mode become invalid on the base class,
    # so SampleDist approximates them by sampling.
    dist = SampleDist(dist)
    if deterministic:
        action = dist.mean
    else:
        # rsample() keeps the reparameterization trick (gradients flow).
        action = dist.rsample()
    # not use logprob now
    if with_logprob:
        logp_pi = dist.log_prob(action).sum(dim=1)
    else:
        logp_pi = None
    # action dim: [batch, act_dim], log_pi dim: [batch]; with fix_speed the
    # constant throttle column is appended without gradient.
    return action if not self.fix_speed else torch.cat((action, self.throtlle_base*torch.ones_like(action, requires_grad=False)), dim=-1), logp_pi
def cdf(self, x):
    """CDF of a mixture-of-logistics evaluated elementwise at ``x``.

    Each mixture component is Logistic(loc_k, scale_k), built as an affine
    transform of logit(Uniform(0, 1)); the mixture CDF is the softmax-weighted
    convex combination of the component CDFs.
    """
    base_distribution = Uniform(0, 1)
    transforms = [
        SigmoidTransform().inv,
        AffineTransform(loc=self.loc, scale=self.scale, event_dim=1)
    ]
    logistic = TransformedDistribution(base_distribution, transforms)
    # Broadcast x against the component axis (last dim of self.loc).
    # assumes the incoming x is 3-D so the expand produces 4-D — TODO confirm
    x = x.unsqueeze(-1).expand(-1, -1, -1, self.loc.size(-1))
    cdfs = logistic.cdf(x)
    # Weighted sum over components; weights come from the categorical logits.
    return torch.sum(F.softmax(self.categorical_logits, dim=-1) * cdfs, dim=-1)
class StandardLogisticDistribution:
    """Factorized standard Logistic(0, 1) distribution over ``data_dim`` variables.

    Used as the base distribution of a normalizing flow; exposes per-sample
    log-density and sampling.
    """

    def __init__(self, data_dim=28 * 28, device='cpu'):
        zeros = torch.zeros(data_dim, device=device)
        ones = torch.ones(data_dim, device=device)
        # logit(Uniform(0, 1)) is Logistic(0, 1); the affine with loc=0,
        # scale=1 is an identity kept for structural symmetry.
        self.m = TransformedDistribution(
            Uniform(zeros, ones),
            [SigmoidTransform().inv, AffineTransform(zeros, ones)],
        )

    def log_pdf(self, z):
        """Return the log density of each row of ``z``, summed over features."""
        return self.m.log_prob(z).sum(dim=1)

    def sample(self):
        """Draw a single sample of shape ``(data_dim,)``."""
        return self.m.sample()
def get_action(self, belief, state, det=False):
    """Sample (or take the mode of) an action for the configured head type.

    ``self._dist`` selects between a tanh-squashed Gaussian for continuous
    actions and a one-hot categorical for discrete actions.
    """
    actor_out = self.forward(belief, state)
    if self._dist == 'tanh_normal':
        # actor_out.size() == (N x (action_size * 2))
        # softplus(raw_init_std) == _init_std, so std starts near _init_std.
        # replace the below workaround
        raw_init_std = np.log(np.exp(self._init_std) - 1)
        # tmp = torch.tensor(self._init_std,
        #                    device=actor_out.get_device())
        # raw_init_std = torch.log(torch.exp(tmp) - 1)
        action_mean, action_std_dev = torch.chunk(actor_out, 2, dim=1)
        # Soft-bound the mean; inverting tanh for log-probs is unstable when
        # the distribution is highly saturated.
        action_mean = self._mean_scale * torch.tanh(
            action_mean / self._mean_scale)
        action_std = F.softplus(action_std_dev + raw_init_std) + self._min_std
        dist = Normal(action_mean, action_std)
        dist = TransformedDistribution(dist, TanhBijector())
        dist = torch.distributions.Independent(dist, 1)
        # SampleDist approximates mode/entropy of the transformed distribution.
        dist = SampleDist(dist)
    elif self._dist == 'onehot':
        # actor_out.size() == (N x action_size)
        # fix for RuntimeError: CUDA error: device-side assert triggered —
        # squashes logits into [0, 1]. NOTE(review): this changes the logits'
        # scale, not just their validity; confirm it is intended.
        actor_out = (torch.tanh(actor_out) + 1.0) * 0.5
        dist = Categorical(logits=actor_out)
        dist = OneHotDist(dist)
    else:
        raise NotImplementedError(self._dist)
    if det:
        return dist.mode()
    else:
        return dist.sample()
def forward(self, mean, log_std, deterministic=False):
    """Sample a tanh-squashed Gaussian action and its log-probability.

    Args:
        mean: Gaussian mean, shape (batch, act_dim).
        log_std: log standard deviation, clamped to
            [self.log_std_min, self.log_std_max].
        deterministic: if True, return tanh(mean) instead of sampling.

    Returns:
        (action_sample, log_prob) with log_prob summed over the action dim.
    """
    clamped = torch.clamp(log_std, self.log_std_min, self.log_std_max)
    dist = TransformedDistribution(
        Normal(mean, torch.exp(clamped)), TanhTransform(cache_size=1))
    # Deterministic path bypasses sampling but keeps the squashing.
    action_sample = torch.tanh(mean) if deterministic else dist.rsample()
    log_prob = dist.log_prob(action_sample).sum(dim=1)
    return action_sample, log_prob
def policy(self, x):
    """Return the transformed multivariate-Gaussian policy distribution for ``x``."""
    features = ActorNetwork.forward(self, x)
    mu = self._mean_layer(features)
    log_sigma = self._std_layer(features).clamp(self._log_std_min,
                                                self._log_std_max)
    # NOTE(review): the diagonal matrix of stds is passed as covariance_matrix,
    # so the variances equal std (not std**2) — confirm this is intended.
    base_distribution = MultivariateNormal(mu, torch.diag_embed(log_sigma.exp()))
    return TransformedDistribution(base_distribution, self._transforms)
def get_action(self, belief, state, det=False):
    """Sample an action from the tanh-squashed policy (approximate mode if ``det``)."""
    mean, std = self.forward(belief, state)
    # Tanh-squash the Normal, couple the action dims, and approximate
    # mode/entropy by sampling.
    policy = SampleDist(
        torch.distributions.Independent(
            TransformedDistribution(Normal(mean, std), TanhBijector()), 1))
    return policy.mode() if det else policy.rsample()
def __init__(self, dim, act=nn.ReLU(), num_hiddens=[50], nout=1, conf=dict()):
    """SGLD-sampled BNN with a Gamma prior on both weight precision (lambda)
    and observation-noise precision.

    Args:
        dim: input dimensionality.
        act: activation module shared by all hidden layers.
        num_hiddens: hidden-layer widths.
        nout: number of outputs.
        conf: optional hyperparameter overrides (see the conf.get calls).
    """
    nn.Module.__init__(self)
    BNN.__init__(self)
    self.dim = dim
    self.act = act
    self.num_hiddens = num_hiddens
    self.nout = nout
    # MCMC schedule: burn-in steps, total sampling steps, thinning interval.
    self.steps_burnin = conf.get('steps_burnin', 2500)
    self.steps = conf.get('steps', 2500)
    self.keep_every = conf.get('keep_every', 50)
    self.batch_size = conf.get('batch_size', 32)
    self.warm_start = conf.get('warm_start', False)
    # Separate learning rates for weights, noise precision, and lambda.
    self.lr_weight = np.float32(conf.get('lr_weight', 1e-3))
    self.lr_noise = np.float32(conf.get('lr_noise', 1e-3))
    self.lr_lambda = np.float32(conf.get('lr_lambda', 1e-3))
    # Gamma hyperparameters: (alpha_w, beta_w) for the weight precision,
    # (alpha_n, beta_n) for the noise precision.
    self.alpha_w = torch.as_tensor(1. * conf.get('alpha_w', 6.))
    self.beta_w = torch.as_tensor(1. * conf.get('beta_w', 6.))
    self.alpha_n = torch.as_tensor(1. * conf.get('alpha_n', 6.))
    self.beta_n = torch.as_tensor(1. * conf.get('beta_n', 6.))
    # A user-suggested noise level overrides alpha_n/beta_n: the Gamma prior
    # is re-parameterized to match the implied precision mean and variance.
    self.noise_level = conf.get('noise_level', None)
    if self.noise_level is not None:
        prec = 1 / self.noise_level**2
        prec_var = (prec * 0.25)**2
        self.beta_n = torch.as_tensor(prec / prec_var)
        self.alpha_n = torch.as_tensor(prec * self.beta_n)
        print("Reset alpha_n = %g, beta_n = %g" % (self.alpha_n, self.beta_n))
    # Priors over the *log* of each precision: Gamma pushed through log.
    self.prior_log_lambda = TransformedDistribution(
        Gamma(self.alpha_w, self.beta_w),
        ExpTransform().inv)  # log of gamma distribution
    self.prior_log_precision = TransformedDistribution(
        Gamma(self.alpha_n, self.beta_n), ExpTransform().inv)
    # Learnable log precisions, registered as parameters for the sampler.
    self.log_lambda = nn.Parameter(torch.tensor(0.))
    self.log_precs = nn.Parameter(torch.zeros(self.nout))
    self.nn = NN(dim, self.act, self.num_hiddens, self.nout)
    self.init_nn()
def train_moons(model, optimizer, n_epochs=10001, base_distr="normal", d=2,
                device=None, plot_val=True, plot_interval=1000,
                input_grad=False):
    """Train a normalizing flow on freshly sampled two-moons batches.

    Args:
        model: flow returning (latents, log_det) for a batch of inputs.
        optimizer: optimizer over the model's parameters.
        n_epochs: number of gradient steps (one fresh batch per step).
        base_distr: "normal" or "logistic" base distribution in latent space.
        d: latent/data dimensionality (two-moons is 2-D).
        device: computation device; auto-selects cuda when available.
        plot_val: whether to run the validation/plot helper periodically.
        plot_interval: steps between validation calls.
        input_grad: if True, inputs require grad (used by the *_grad val path).

    Returns:
        List of per-step training losses.

    Raises:
        ValueError: if ``base_distr`` is not "normal" or "logistic".
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if base_distr == "normal":
        distr = torch.distributions.multivariate_normal.MultivariateNormal(
            torch.zeros(d, device=device), torch.eye(d, device=device))
    elif base_distr == "logistic":
        # Factorized Logistic(0, 1): logit(Uniform(0, 1)).
        distr = TransformedDistribution(
            Uniform(torch.zeros(d, device=device),
                    torch.ones(d, device=device)), SigmoidTransform().inv)
    else:
        raise ValueError("wrong base distribution")
    train_loss = []
    pbar = trange(n_epochs)
    for i in pbar:  # range(n_epochs):
        # Fresh batch each step — this is stochastic training, not epochs
        # over a fixed dataset.
        x, y = datasets.make_moons(128, noise=.1)
        # NOTE(review): .to(device) on a requires_grad tensor makes x a
        # non-leaf; gradients w.r.t. x on GPU need x.retain_grad() — confirm.
        x = torch.tensor(x, dtype=torch.float32,
                         requires_grad=input_grad).to(device)
        model.train()
        z, log_det = model(x)
        # `loss` is a module-level helper; z[-1] is the final latent.
        l = loss(z[-1], log_det, distr, base_distr)
        l.backward()
        optimizer.step()
        optimizer.zero_grad()
        train_loss.append(l.item())
        if i % 100 == 0:
            pbar.set_postfix_str(f"loss = {train_loss[-1]:.3f}")
        if plot_val and i % plot_interval == 0:
            print(i, train_loss[-1])
            if input_grad:
                val_moons_grad(model, distr, i, device, base_distr)
            else:
                val_moons(model, distr, i, device, base_distr)
    return train_loss
def __init__(self, dim, act=nn.ReLU(), num_hiddens=[50], nout=1, conf=dict()):
    """SGLD-sampled BNN with a Gamma prior on the observation-noise precision.

    Unlike the lambda variant, weight regularization comes from Xavier-scaled
    Gaussian priors (see log_prior) rather than a learned precision.

    Args:
        dim: input dimensionality.
        act: activation module shared by all hidden layers.
        num_hiddens: hidden-layer widths.
        nout: number of outputs.
        conf: optional hyperparameter overrides (see the conf.get calls).
    """
    nn.Module.__init__(self)
    BNN.__init__(self)
    self.dim = dim
    self.act = act
    self.num_hiddens = num_hiddens
    self.nout = nout
    # MCMC schedule: burn-in steps, total sampling steps, thinning interval.
    self.steps_burnin = conf.get('steps_burnin', 2500)
    self.steps = conf.get('steps', 2500)
    self.keep_every = conf.get('keep_every', 50)
    self.batch_size = conf.get('batch_size', 32)
    self.warm_start = conf.get('warm_start', False)
    self.lr_weight = conf.get('lr_weight', 2e-2)
    self.lr_noise = conf.get('lr_noise', 1e-1)
    # Gamma hyperparameters for the noise precision.
    self.alpha_n = torch.as_tensor(1. * conf.get('alpha_n', 1e-2))
    self.beta_n = torch.as_tensor(1. * conf.get('beta_n', 1e-1))
    # user can specify a suggested noise value, this will override alpha_n
    # and beta_n: the Gamma prior is re-parameterized to match the implied
    # precision mean and variance.
    self.noise_level = conf.get('noise_level', None)
    if self.noise_level is not None:
        prec = 1 / self.noise_level**2
        prec_var = (prec * 0.25)**2
        self.beta_n = torch.as_tensor(prec / prec_var)
        self.alpha_n = torch.as_tensor(prec * self.beta_n)
        print("Reset alpha_n = %g, beta_n = %g" % (self.alpha_n, self.beta_n))
    # Prior over the *log* noise precision: Gamma pushed through log.
    self.prior_log_precision = TransformedDistribution(
        Gamma(self.alpha_n, self.beta_n), ExpTransform().inv)
    self.log_precs = nn.Parameter(torch.zeros(self.nout))
    self.nn = NN(dim, self.act, self.num_hiddens, self.nout)
    self.gain = 5. / 3  # Assume tanh activation
    self.init_nn()
def test_logistic():
    """Logistic(loc, scale) agrees with torch's TransformedDistribution reference.

    Builds the reference as AffineTransform(logit(Uniform(0, 1))) and checks
    log-probabilities and sample shapes of the custom Logistic wrapper.
    """
    base_distribution = Uniform(0, 1)
    transforms = [SigmoidTransform().inv,
                  AffineTransform(loc=torch.tensor([2.]),
                                  scale=torch.tensor([1.]))]
    model = TransformedDistribution(base_distribution, transforms)
    transform = Logistic(2., 1.)

    def assert_log_prob_close(x):
        # Bug fix: the original checked only `diff < 1e-4` (one-sided), which
        # passes even when the log-probs disagree wildly in the negative
        # direction. Compare the absolute error instead.
        diff = transform.log_prob(x) - model.log_prob(x).view(-1)
        assert torch.all(diff.abs() < 1e-4)

    x = model.sample((4,)).reshape(-1, 1)
    assert_log_prob_close(x)
    x = transform.sample(4)
    assert x.shape == (4, 1)
    assert_log_prob_close(x)
    x = transform.sample(1)
    assert x.shape == (1, 1)
    assert_log_prob_close(x)
    # Smoke-check that parameter access works.
    transform.get_parameters()
class BNN_SGDMC(nn.Module, BNN):
    """Bayesian NN posterior-sampled with preconditioned SGLD.

    Keeps a Gamma prior on the observation-noise precision and Xavier-scaled
    Gaussian priors on the weights; thinned network snapshots collected during
    sampling form the approximate posterior.
    """

    def __init__(self, dim, act=nn.ReLU(), num_hiddens=[50], nout=1,
                 conf=dict()):
        """See class docstring; ``conf`` overrides the sampler hyperparameters."""
        nn.Module.__init__(self)
        BNN.__init__(self)
        self.dim = dim
        self.act = act
        self.num_hiddens = num_hiddens
        self.nout = nout
        # MCMC schedule: burn-in, total steps, thinning interval.
        self.steps_burnin = conf.get('steps_burnin', 2500)
        self.steps = conf.get('steps', 2500)
        self.keep_every = conf.get('keep_every', 50)
        self.batch_size = conf.get('batch_size', 32)
        self.warm_start = conf.get('warm_start', False)
        self.lr_weight = conf.get('lr_weight', 2e-2)
        self.lr_noise = conf.get('lr_noise', 1e-1)
        self.alpha_n = torch.as_tensor(1. * conf.get('alpha_n', 1e-2))
        self.beta_n = torch.as_tensor(1. * conf.get('beta_n', 1e-1))
        # user can specify a suggested noise value, this will override
        # alpha_n and beta_n
        self.noise_level = conf.get('noise_level', None)
        if self.noise_level is not None:
            prec = 1 / self.noise_level**2
            prec_var = (prec * 0.25)**2
            self.beta_n = torch.as_tensor(prec / prec_var)
            self.alpha_n = torch.as_tensor(prec * self.beta_n)
            print("Reset alpha_n = %g, beta_n = %g" % (self.alpha_n,
                                                       self.beta_n))
        # Prior over the log noise precision: Gamma pushed through log.
        self.prior_log_precision = TransformedDistribution(
            Gamma(self.alpha_n, self.beta_n), ExpTransform().inv)
        self.log_precs = nn.Parameter(torch.zeros(self.nout))
        self.nn = NN(dim, self.act, self.num_hiddens, self.nout)
        self.gain = 5. / 3  # Assume tanh activation
        self.init_nn()

    def init_nn(self):
        """Reset log precisions to the prior mean (in log space) and
        Xavier-initialize the linear layers."""
        self.log_precs.data = (self.alpha_n / self.beta_n).log() * torch.ones(
            self.nout)
        for l in self.nn.nn:
            if type(l) == nn.Linear:
                nn.init.xavier_uniform_(l.weight, gain=self.gain)

    def log_prior(self):
        """Log prior: Gamma over the log precisions plus Xavier-scaled
        Gaussians over the weight matrices (biases carry no prior)."""
        log_p = self.prior_log_precision.log_prob(self.log_precs).sum()
        for n, p in self.nn.nn.named_parameters():
            if "weight" in n:
                # Same std as Xavier init: gain * sqrt(2 / (fan_in + fan_out)).
                std = self.gain * np.sqrt(2. / (p.shape[0] + p.shape[1]))
                log_p += torch.distributions.Normal(0, std).log_prob(p).sum()
        return log_p

    def log_lik(self, X, y):
        """Gaussian log likelihood of y given the network output, with a
        learned per-output precision."""
        y = y.view(-1, self.nout)
        # NOTE: local `nout` (the prediction tensor) shadows self.nout.
        nout = self.nn(X).view(-1, self.nout)
        precs = self.log_precs.exp()
        log_lik = -0.5 * precs * (
            y - nout)**2 + 0.5 * self.log_precs - 0.5 * np.log(2 * np.pi)
        return log_lik.sum()

    def sgld_steps(self, num_steps, num_train):
        """Run ``num_steps`` pSGLD updates on minibatches from self.loader;
        the likelihood is rescaled by num_train / batch_size to approximate
        the full-data posterior. Returns the last minibatch loss."""
        step_cnt = 0
        loss = 0.
        while (step_cnt < num_steps):
            for bx, by in self.loader:
                log_prior = self.log_prior()
                log_lik = self.log_lik(bx, by)
                loss = -1 * (log_lik * (num_train / bx.shape[0]) + log_prior)
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()
                self.scheduler.step()
                step_cnt += 1
                if step_cnt >= num_steps:
                    break
        return loss

    def train(self, X, y):
        """Sample the posterior: burn-in, then collect a deep-copied network
        snapshot every ``keep_every`` steps into self.nns.

        NOTE(review): this overrides nn.Module.train(mode) with a different
        signature — confirm callers never rely on the Module semantics.
        """
        y = y.view(-1, self.nout)
        num_train = X.shape[0]
        # Separate learning rates for the network weights and the noise term.
        params = [{
            'params': self.nn.nn.parameters(),
            'lr': self.lr_weight
        }, {
            'params': self.log_precs,
            'lr': self.lr_noise
        }]
        self.opt = pSGLD(params)
        self.scheduler = optim.lr_scheduler.LambdaLR(
            self.opt, lambda iter: np.float32((1 + iter)**-0.33))
        # XXX: learning rate scheduler, as suggested in Teh, Yee Whye,
        # Alexandre H. Thiery, and Sebastian J. Vollmer. "Consistency and
        # fluctuations for stochastic gradient Langevin dynamics." The Journal
        # of Machine Learning Research 17.1 (2016): 193-225.
        # XXX: I'm not sure if this scheduler is still optimal for
        # preconditioned SGLD
        self.loader = DataLoader(TensorDataset(X, y),
                                 batch_size=self.batch_size,
                                 shuffle=True)
        step_cnt = 0
        self.nns = []
        self.lrs = []
        if not self.warm_start:
            self.init_nn()
        _ = self.sgld_steps(self.steps_burnin, num_train)  # burn-in
        while (step_cnt < self.steps):
            loss = self.sgld_steps(self.keep_every, num_train)
            step_cnt += self.keep_every
            prec = self.log_precs.exp().mean()
            print('Step %4d, loss = %8.2f, precision = %g' %
                  (step_cnt, loss, prec),
                  flush=True)
            self.nns.append(deepcopy(self.nn))
        print('Number of samples: %d' % len(self.nns))

    def sample(self, num_samples=1):
        """Return ``num_samples`` distinct posterior network snapshots,
        chosen uniformly without replacement."""
        assert (num_samples <= len(self.nns))
        return np.random.permutation(self.nns)[:num_samples]

    def sample_predict(self, nns, input):
        """Stack predictions from each snapshot: (num_samples, num_x, nout)."""
        num_samples = len(nns)
        num_x = input.shape[0]
        pred = torch.empty(num_samples, num_x, self.nout)
        for i in range(num_samples):
            pred[i] = nns[i](input)
        return pred

    def report(self):
        """Print the underlying network architecture."""
        print(self.nn.nn)
class BNN_SGDMC(nn.Module, BNN):
    """Bayesian NN posterior-sampled with preconditioned SGLD.

    This variant additionally learns a global weight precision ``lambda``
    (Gamma prior, log-parameterized) alongside the per-output noise
    precisions; thinned network snapshots form the approximate posterior.
    """

    def __init__(self, dim, act=nn.ReLU(), num_hiddens=[50], nout=1,
                 conf=dict()):
        """See class docstring; ``conf`` overrides the sampler hyperparameters."""
        nn.Module.__init__(self)
        BNN.__init__(self)
        self.dim = dim
        self.act = act
        self.num_hiddens = num_hiddens
        self.nout = nout
        # MCMC schedule: burn-in, total steps, thinning interval.
        self.steps_burnin = conf.get('steps_burnin', 2500)
        self.steps = conf.get('steps', 2500)
        self.keep_every = conf.get('keep_every', 50)
        self.batch_size = conf.get('batch_size', 32)
        self.warm_start = conf.get('warm_start', False)
        # Separate learning rates for weights, noise precision, and lambda.
        self.lr_weight = np.float32(conf.get('lr_weight', 1e-3))
        self.lr_noise = np.float32(conf.get('lr_noise', 1e-3))
        self.lr_lambda = np.float32(conf.get('lr_lambda', 1e-3))
        self.alpha_w = torch.as_tensor(1. * conf.get('alpha_w', 6.))
        self.beta_w = torch.as_tensor(1. * conf.get('beta_w', 6.))
        self.alpha_n = torch.as_tensor(1. * conf.get('alpha_n', 6.))
        self.beta_n = torch.as_tensor(1. * conf.get('beta_n', 6.))
        # A user-suggested noise level overrides alpha_n/beta_n.
        self.noise_level = conf.get('noise_level', None)
        if self.noise_level is not None:
            prec = 1 / self.noise_level**2
            prec_var = (prec * 0.25)**2
            self.beta_n = torch.as_tensor(prec / prec_var)
            self.alpha_n = torch.as_tensor(prec * self.beta_n)
            print("Reset alpha_n = %g, beta_n = %g" % (self.alpha_n,
                                                       self.beta_n))
        self.prior_log_lambda = TransformedDistribution(
            Gamma(self.alpha_w, self.beta_w),
            ExpTransform().inv)  # log of gamma distribution
        self.prior_log_precision = TransformedDistribution(
            Gamma(self.alpha_n, self.beta_n), ExpTransform().inv)
        self.log_lambda = nn.Parameter(torch.tensor(0.))
        self.log_precs = nn.Parameter(torch.zeros(self.nout))
        self.nn = NN(dim, self.act, self.num_hiddens, self.nout)
        self.init_nn()

    def init_nn(self):
        """Draw log_lambda/log_precs from their priors and initialize the
        weights from N(0, 1/sqrt(lambda)); biases start at zero."""
        self.log_lambda.data = self.prior_log_lambda.sample()
        self.log_precs.data = self.prior_log_precision.sample((self.nout, ))
        for layer in self.nn.nn:
            if isinstance(layer, nn.Linear):
                layer.weight.data = torch.distributions.Normal(
                    0, 1 / self.log_lambda.exp().sqrt()).sample(
                        layer.weight.shape)
                layer.bias.data = torch.zeros(layer.bias.shape)

    def log_prior(self):
        """Log prior: Gamma over log_lambda and log_precs, plus isotropic
        Gaussians N(0, 1/lambda) over the weight matrices (no bias prior)."""
        log_p = self.prior_log_lambda.log_prob(self.log_lambda).sum()
        log_p += self.prior_log_precision.log_prob(self.log_precs).sum()
        lambd = self.log_lambda.exp()
        for n, p in self.nn.nn.named_parameters():
            if "weight" in n:
                # Gaussian log density written out with precision lambda.
                log_p += -0.5 * lambd * torch.sum(p**2) + 0.5 * p.numel() * (
                    self.log_lambda - np.log(2 * np.pi))
        return log_p

    def log_lik(self, X, y):
        """Gaussian log likelihood of y given the network output, with a
        learned per-output precision."""
        y = y.view(-1, self.nout)
        # NOTE: local `nout` (the prediction tensor) shadows self.nout.
        nout = self.nn(X).view(-1, self.nout)
        precs = self.log_precs.exp()
        log_lik = -0.5 * precs * (
            y - nout)**2 + 0.5 * self.log_precs - 0.5 * np.log(2 * np.pi)
        return log_lik.sum()

    def sgld_steps(self, num_steps, num_train):
        """Run ``num_steps`` pSGLD updates on minibatches from self.loader;
        the likelihood is rescaled by num_train / batch_size to approximate
        the full-data posterior. Returns the last minibatch loss."""
        step_cnt = 0
        loss = 0.
        while (step_cnt < num_steps):
            for bx, by in self.loader:
                log_prior = self.log_prior()
                log_lik = self.log_lik(bx, by)
                loss = -1 * (log_lik * (num_train / bx.shape[0]) + log_prior)
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()
                self.scheduler.step()
                step_cnt += 1
                if step_cnt >= num_steps:
                    break
        return loss

    def train(self, X, y):
        """Sample the posterior: burn-in, then collect a deep-copied network
        snapshot every ``keep_every`` steps into self.nns.

        NOTE(review): this overrides nn.Module.train(mode) with a different
        signature — confirm callers never rely on the Module semantics.
        """
        y = y.view(-1, self.nout)
        num_train = X.shape[0]
        params = [{
            'params': self.nn.nn.parameters(),
            'lr': self.lr_weight
        }, {
            'params': self.log_precs,
            'lr': self.lr_noise
        }, {
            'params': self.log_lambda,
            'lr': self.lr_lambda
        }]
        # self.opt = aSGHMC(params, num_burn_in_steps = self.steps_burnin)
        # self.scheduler = optim.lr_scheduler.LambdaLR(self.opt, lambda iter : np.float32(1.))
        self.opt = pSGLD(params)
        # Polynomial decay (1 + t)^-0.33 as suggested for SGLD consistency.
        self.scheduler = optim.lr_scheduler.LambdaLR(
            self.opt, lambda iter: np.float32((1 + iter)**-0.33))
        self.loader = DataLoader(TensorDataset(X, y),
                                 batch_size=self.batch_size,
                                 shuffle=True)
        step_cnt = 0
        self.nns = []
        self.lrs = []
        if not self.warm_start:
            self.init_nn()
        _ = self.sgld_steps(self.steps_burnin, num_train)  # burn-in
        while (step_cnt < self.steps):
            loss = self.sgld_steps(self.keep_every, num_train)
            step_cnt += self.keep_every
            prec = self.log_precs.exp().mean()
            wstd = 1 / self.log_lambda.exp().sqrt()
            print('Step %4d, loss = %8.2f, precision = %g, weight_std = %g' %
                  (step_cnt, loss, prec, wstd),
                  flush=True)
            self.nns.append(deepcopy(self.nn))
        print('Number of samples: %d' % len(self.nns))

    def sample(self, num_samples=1):
        """Return ``num_samples`` distinct posterior network snapshots,
        chosen uniformly without replacement."""
        assert (num_samples <= len(self.nns))
        return np.random.permutation(self.nns)[:num_samples]

    def sample_predict(self, nns, input):
        """Stack predictions from each snapshot: (num_samples, num_x, nout)."""
        num_samples = len(nns)
        num_x = input.shape[0]
        pred = torch.empty(num_samples, num_x, self.nout)
        for i in range(num_samples):
            pred[i] = nns[i](input)
        return pred

    def report(self):
        """Print the underlying network architecture."""
        print(self.nn.nn)
def create_distribution(self, scale, shape, shift):
    """Build a Weibull distribution shifted to start at ``shift``.

    Args:
        scale: Weibull scale parameter.
        shape: Weibull concentration (shape) parameter.
        shift: location offset added to every sample (support becomes
            [shift, inf)).

    Returns:
        A TransformedDistribution: shift + Weibull(scale, shape).
    """
    shifted = AffineTransform(loc=shift, scale=1.)
    return TransformedDistribution(
        Weibull(scale=scale, concentration=shape), shifted)
def get_action(self, belief, posterior_state, explore=False, det=False):
    """Plan an action with CEM, seeded by an ensemble of learned actors.

    The first optimisation iteration samples candidate action sequences
    around each actor's proposal; later iterations refit a Gaussian to the
    top-K candidates ranked by predicted return under the learned model.
    The final action is drawn from a tanh-squashed Gaussian fitted to the
    first planned step.
    """
    state = posterior_state
    B, H, Z = belief.size(0), belief.size(1), state.size(1)
    # Per-actor proposal means/stds over the planning horizon.
    actions_l_mean_lists, actions_l_std_lists = self.get_action_sequence(
        belief, state, B)
    # Tile belief/state so every candidate rolls out from the same start.
    belief, state = belief.unsqueeze(dim=1).expand(
        B, self.candidates, H).reshape(-1, H), state.unsqueeze(dim=1).expand(
            B, self.candidates, Z).reshape(-1, Z)
    # Initialize factorized belief over action sequences q(a_t:t+H) ~ N(0, I)
    action_mean, action_std_dev = None, None
    # NOTE: `_` is both the CEM iteration index and (later in the body) a
    # throwaway unpacking target; the `_ == 0` test reads the loop value set
    # at the top of each iteration, so this works but is fragile.
    for _ in range(self.optimisation_iters):
        # Evaluate J action sequences from the current belief (over the
        # entire sequence at once, batched over particles).
        if _ == 0:
            # First iteration: seed candidates from the actor-pool proposals,
            # each actor contributing candidates // len(actor_pool) sequences.
            sub_action_list = []
            for id in range(len(self.actor_pool)):
                action = (
                    actions_l_mean_lists[id] +
                    actions_l_std_lists[id] *
                    torch.randn(self.top_planning_horizon,
                                B,
                                self.candidates // len(self.actor_pool),
                                self.action_size,
                                device=belief.device)
                ).view(
                    self.top_planning_horizon,
                    B * self.candidates // len(self.actor_pool),
                    self.action_size
                )  # Sample actions (time x (batch x candidates) x actions)
                sub_action_list.append(action)
            actions = torch.cat(sub_action_list, dim=1)
        else:
            # Later iterations: sample around the refitted Gaussian.
            actions = (action_mean + action_std_dev * torch.randn(
                self.top_planning_horizon,
                B,
                self.candidates,
                self.action_size,
                device=belief.device)).view(
                    self.top_planning_horizon, B * self.candidates,
                    self.action_size
                )  # Sample actions (time x (batch x candidates) x actions)
        # Roll the candidates through the learned transition model.
        beliefs, states, _, _ = self.upper_transition_model(
            state, actions, belief)
        # Calculate expected returns (technically sum of rewards over the
        # planning horizon); reward output is reshaped to
        # (horizon, B * candidates) before summing over time.
        returns = self.reward_model(beliefs.view(-1, H), states.view(
            -1, Z
        )).view(self.top_planning_horizon, -1).sum(dim=0)
        # Re-fit belief to the K best action sequences.
        _, topk = returns.reshape(B, self.candidates).topk(
            self.top_candidates, dim=1, largest=True, sorted=False)
        # Fix indices for unrolled actions (offset each batch row).
        topk += self.candidates * torch.arange(
            0, B, dtype=torch.int64, device=topk.device).unsqueeze(dim=1)
        best_actions = actions[:, topk.view(-1)].reshape(
            self.top_planning_horizon, B, self.top_candidates,
            self.action_size)
        # Update belief with new means and standard deviations.
        action_mean, action_std_dev = best_actions.mean(
            dim=2, keepdim=True), best_actions.std(dim=2,
                                                   unbiased=False,
                                                   keepdim=True)
    # Return a sampled action for the first planned step, squashed by tanh.
    dist = Normal(action_mean[0].squeeze(dim=1),
                  action_std_dev[0].squeeze(dim=1))
    dist = TransformedDistribution(dist, TanhBijector())
    dist = torch.distributions.Independent(dist, 1)
    dist = SampleDist(dist)
    if det:
        tmp = dist.mode()
        return tmp
    else:
        tmp = dist.rsample()
        return tmp
def wrap_prior_dist(prior_dist, transforms):
    """Wrap a prior distribution with the given transform(s).

    Args:
        prior_dist: base torch distribution.
        transforms: a Transform or list of Transforms to apply.

    Returns:
        The corresponding TransformedDistribution.
    """
    wrapped = TransformedDistribution(prior_dist, transforms)
    return wrapped