def make_ntm_initial_states(opt):
    """Create the three initial state arrays, each repeated ``opt.b`` times.

    Reads ``opt.n``, ``opt.m``, ``opt.h`` and ``opt.b``.  Three shared arrays
    with a leading axis of length 1 are created:

    * a ``(1, n, m)`` array filled with ``0.1 * randn`` samples,
    * a ``(1, 2*h, n)`` array filled with ``0.1 * rand`` samples, which is
      then exponentiated and passed through ``sum_normalize2``
      (presumably a normalisation so entries sum to one -- confirm against
      the helper's definition),
    * a ``(1, h, m)`` array of zeros.

    Returns a list with those three arrays, in that order, each repeated
    ``opt.b`` times along axis 0.
    """
    dim_n, dim_m, dim_h, reps = opt.n, opt.m, opt.h, opt.b

    # Random draws happen in this order (randn first, then rand) so the
    # consumed random stream matches for a given seed.
    mem_init = cgt.shared(.1 * nr.randn(1, dim_n, dim_m))

    raw_weights = cgt.shared(.1 * nr.rand(1, 2 * dim_h, dim_n))
    weight_init = sum_normalize2(cgt.exp(raw_weights))

    read_init = cgt.shared(np.zeros((1, dim_h, dim_m)))

    states = []
    for single in (mem_init, weight_init, read_init):
        states.append(cgt.repeat(single, reps, axis=0))
    return states
def __init__(self, obs_dim, ctrl_dim):
    """Build the symbolic policy graph and compile the functions used on it.

    Args:
        obs_dim: number of columns of the observation matrix input.
        ctrl_dim: number of columns of the action matrix input.

    The network is two tanh hidden layers of width 32 followed by an affine
    layer producing a ``ctrl_dim``-wide mean; a separate ``(1, ctrl_dim)``
    log-std parameter is exponentiated and repeated over the batch axis.

    Attributes set: ``obs_dim``, ``ctrl_dim``, ``logstd``,
    ``_compute_surr_kl``, ``_compute_grad_lagrangian``, ``f_pdist``,
    ``f_objs`` and ``pc``.  Also switches CGT to double precision, which is
    a global setting.
    """
    cgt.set_precision('double')
    Serializable.__init__(self, obs_dim, ctrl_dim)

    self.obs_dim = obs_dim
    self.ctrl_dim = ctrl_dim

    # Symbolic inputs.
    obs = cgt.matrix("o_no", fixed_shape=(None, obs_dim))
    act = cgt.matrix("a_na", fixed_shape=(None, ctrl_dim))
    adv = cgt.vector("adv_n")
    # Old distribution: first ctrl_dim columns are means, the rest are stds.
    old_pdist = cgt.matrix("oldpdist", fixed_shape=(None, 2 * ctrl_dim))

    log_std_row = nn.parameter(np.zeros((1, self.ctrl_dim)), name="std_1a")
    self.logstd = log_std_row
    std_row = cgt.exp(log_std_row)

    # Network: each layer is created and applied immediately so that
    # parameters are instantiated in the same order as before.
    hidden_width = 32
    layer_in = nn.Affine(obs_dim, hidden_width,
                         weight_init=nn.IIDGaussian(std=0.1))
    hidden_a = cgt.tanh(layer_in(obs))
    layer_mid = nn.Affine(hidden_width, hidden_width,
                          weight_init=nn.IIDGaussian(std=0.1))
    hidden_b = cgt.tanh(layer_mid(hidden_a))
    layer_out = nn.Affine(hidden_width, ctrl_dim,
                          weight_init=nn.IIDGaussian(std=0.01))
    mean = layer_out(hidden_b)

    batch = cgt.size(obs, 0)
    std = cgt.repeat(std_row, batch, axis=0)

    old_mean = old_pdist[:, 0:self.ctrl_dim]
    old_std = old_pdist[:, self.ctrl_dim:2 * self.ctrl_dim]

    # Log-densities; additive constants common to both are omitted because
    # only their difference is used below.
    scaled_err = (act - mean) / std
    logp = ((-.5) * cgt.square(scaled_err).sum(axis=1)) - log_std_row.sum()
    old_scaled_err = (act - old_mean) / old_std
    old_logp = (((-.5) * cgt.square(old_scaled_err).sum(axis=1))
                - cgt.log(old_std).sum(axis=1))

    ratio = cgt.exp(logp - old_logp)
    surr = (ratio * adv).mean()

    pdists = cgt.concatenate([mean, std], axis=1)

    params = nn.get_parameters(surr)

    # KL term between the old and the current per-row distributions,
    # summed over action dimensions and averaged over rows.
    old_var = cgt.square(old_std)
    var = cgt.square(std)
    log_std_ratio = cgt.log(std / old_std)
    quad_term = (old_var + cgt.square(old_mean - mean)) / (2 * var)
    kl = (log_std_ratio + quad_term - .5).sum(axis=1).mean()

    # Penalised objective: surrogate minus lam times the KL term.
    lam = cgt.scalar()
    penobj = surr - lam * kl

    self._compute_surr_kl = cgt.function(
        [old_pdist, obs, act, adv], [surr, kl])

    grad_inputs = [lam, old_pdist, obs, act, adv]
    flat_grad = cgt.concatenate(
        [g.flatten() for g in cgt.grad(penobj, params)])
    self._compute_grad_lagrangian = cgt.function(grad_inputs, flat_grad)

    self.f_pdist = cgt.function([obs], pdists)
    self.f_objs = cgt.function([old_pdist, obs, act, adv], [surr, kl])

    self.pc = ParamCollection(params)
def __call__(self, M, *inputs):
    """Apply one gated update to ``M`` and return the new value.

    Args:
        M: current state; its first axis length is used to repeat the bias
            terms.
        *inputs: one input per entry of ``self.Wizs`` (checked by assert).

    Returns:
        ``(1 - z) * M + z * Mtarg`` where ``z`` and ``r`` are sigmoid gates
        and ``Mtarg`` is a tanh candidate computed from the inputs and
        ``r * M``.
    """
    assert len(inputs) == len(self.Wizs)
    nrows = M.shape[0]

    # Gate z.
    z_terms = [x.dot(w) for (x, w) in zip(inputs, self.Wizs)]
    z_terms.append(M.dot(self.Wmz))
    z_terms.append(cgt.repeat(self.bz, nrows, axis=0))
    z = cgt.sigmoid(cgt.add_multi(z_terms))

    # Gate r.
    r_terms = [x.dot(w) for (x, w) in zip(inputs, self.Wirs)]
    r_terms.append(M.dot(self.Wmr))
    r_terms.append(cgt.repeat(self.br, nrows, axis=0))
    r = cgt.sigmoid(cgt.add_multi(r_terms))

    # Candidate state: the state contribution is scaled by r first.
    cand_terms = [x.dot(w) for (x, w) in zip(inputs, self.Wims)]
    cand_terms.append((r * M).dot(self.Wmm))
    cand_terms.append(cgt.repeat(self.bm, nrows, axis=0))
    Mtarg = cgt.tanh(cgt.add_multi(cand_terms)) #pylint: disable=E1111

    return (1 - z) * M + z * Mtarg
def tile(x, reps):
    """Repeat ``x`` along successive axes according to ``reps``.

    ``reps[i]`` is the repeat count for axis ``i``.  Counts that are not
    greater than 1 are skipped, so ``x`` is returned unchanged when no
    count exceeds 1.
    """
    result = x
    for axis, count in enumerate(reps):
        # Same comparison as the original (`> 1`), just inverted as a guard.
        if not count > 1:
            continue
        result = cgt.repeat(result, count, axis=axis)
    return result
def __init__(self, obs_dim, ctrl_dim):
    """Construct the policy's computation graph and its compiled functions.

    Args:
        obs_dim: width of the observation matrix fed to the network.
        ctrl_dim: width of the action matrix and of the mean output.

    Compiled functions stored on the instance:

    * ``_compute_surr_kl`` and ``f_objs``: (oldpdist, obs, act, adv) ->
      [surrogate, kl]
    * ``_compute_grad_lagrangian``: (lam, oldpdist, obs, act, adv) ->
      flattened gradient of ``surrogate - lam * kl`` w.r.t. the parameters
    * ``f_pdist``: (obs) -> means and stds concatenated along axis 1

    Also stores ``obs_dim``, ``ctrl_dim``, ``logstd`` and ``pc`` and sets
    CGT's precision to double (a global setting).
    """
    cgt.set_precision('double')
    Serializable.__init__(self, obs_dim, ctrl_dim)

    self.obs_dim = obs_dim
    self.ctrl_dim = ctrl_dim

    # --- symbolic inputs -------------------------------------------------
    observations = cgt.matrix("o_no", fixed_shape=(None, obs_dim))
    actions = cgt.matrix("a_na", fixed_shape=(None, ctrl_dim))
    advantages = cgt.vector("adv_n")
    prev_dist = cgt.matrix("oldpdist", fixed_shape=(None, 2 * ctrl_dim))

    # --- learnable log standard deviation --------------------------------
    self.logstd = nn.parameter(np.zeros((1, self.ctrl_dim)), name="std_1a")
    log_sigma = self.logstd
    sigma_row = cgt.exp(log_sigma)

    # --- mean network: two tanh layers of width 32, then an affine map ---
    width = 32
    activ = observations
    activ = cgt.tanh(
        nn.Affine(obs_dim, width, weight_init=nn.IIDGaussian(std=0.1))(activ))
    activ = cgt.tanh(
        nn.Affine(width, width, weight_init=nn.IIDGaussian(std=0.1))(activ))
    mu = nn.Affine(
        width, ctrl_dim, weight_init=nn.IIDGaussian(std=0.01))(activ)

    n_rows = cgt.size(observations, 0)
    sigma = cgt.repeat(sigma_row, n_rows, axis=0)

    # prev_dist holds the old means in its first ctrl_dim columns and the
    # old standard deviations in the remaining ctrl_dim columns.
    prev_mu = prev_dist[:, 0:self.ctrl_dim]
    prev_sigma = prev_dist[:, self.ctrl_dim:2 * self.ctrl_dim]

    # --- likelihood ratio and surrogate -----------------------------------
    new_logp = ((-.5) * cgt.square(
        (actions - mu) / sigma).sum(axis=1)) - log_sigma.sum()
    prev_logp = ((-.5) * cgt.square(
        (actions - prev_mu) / prev_sigma).sum(axis=1)
    ) - cgt.log(prev_sigma).sum(axis=1)
    likelihood_ratio = cgt.exp(new_logp - prev_logp)
    surrogate = (likelihood_ratio * advantages).mean()

    dist_params = cgt.concatenate([mu, sigma], axis=1)

    trainable = nn.get_parameters(surrogate)

    # --- KL term, summed over columns and averaged over rows --------------
    prev_var = cgt.square(prev_sigma)
    new_var = cgt.square(sigma)
    kl_div = (cgt.log(sigma / prev_sigma)
              + (prev_var + cgt.square(prev_mu - mu)) / (2 * new_var)
              - .5).sum(axis=1).mean()

    penalty_coeff = cgt.scalar()
    penalized = surrogate - penalty_coeff * kl_div

    # --- compiled functions -------------------------------------------------
    # A fresh list is handed to every cgt.function call, as in the original.
    data_inputs = (prev_dist, observations, actions, advantages)

    self._compute_surr_kl = cgt.function(
        list(data_inputs), [surrogate, kl_div])
    self._compute_grad_lagrangian = cgt.function(
        [penalty_coeff] + list(data_inputs),
        cgt.concatenate(
            [g.flatten() for g in cgt.grad(penalized, trainable)]))
    self.f_pdist = cgt.function([observations], dist_params)
    self.f_objs = cgt.function(list(data_inputs), [surrogate, kl_div])

    self.pc = ParamCollection(trainable)
def __call__(self,M,*inputs):
    """Compute the gated update of ``M`` given ``inputs``.

    Each of the three pre-activations is the sum of every input multiplied
    by its weight matrix, a state term multiplied by a state weight matrix,
    and a bias repeated along axis 0 to ``M.shape[0]`` rows.

    Args:
        M: current state.
        *inputs: must have the same length as ``self.Wizs`` (asserted).

    Returns:
        The blended state ``(1 - z) * M + z * Mtarg``.
    """
    assert len(inputs) == len(self.Wizs)
    n = M.shape[0]

    def weighted_sum(input_weights, state, state_weight, bias):
        # Sum of input projections, the state projection and the tiled bias.
        pieces = [Xi.dot(Wi) for (Xi, Wi) in zip(inputs, input_weights)]
        pieces += [state.dot(state_weight), cgt.repeat(bias, n, axis=0)]
        return cgt.add_multi(pieces)

    z = cgt.sigmoid(weighted_sum(self.Wizs, M, self.Wmz, self.bz))
    r = cgt.sigmoid(weighted_sum(self.Wirs, M, self.Wmr, self.br))
    # The candidate sees the state scaled elementwise by r.
    Mtarg = cgt.tanh(weighted_sum(self.Wims, r * M, self.Wmm, self.bm)) #pylint: disable=E1111
    Mnew = (1 - z) * M + z * Mtarg
    return Mnew