def __call__(self, M, *inputs):
    """Compute the updated memory from the current memory ``M`` and ``inputs``.

    Two sigmoid gates (``z`` and ``r``) and a tanh candidate (``Mtarg``) are
    each built from the same kind of sum: every input projected through its
    own weight matrix, a projection of the memory, and a bias. The result
    blends the old memory and the candidate as ``(1 - z) * M + z * Mtarg``,
    which is the structure of a GRU-style update.

    Args:
        M: Current memory; ``M.shape[0]`` is used as the number of times each
            bias is repeated along axis 0 (presumably the batch size --
            TODO confirm against the caller).
        *inputs: One input per entry of ``self.Wizs``; the i-th input is
            multiplied with the i-th matrix of ``self.Wizs``, ``self.Wirs``
            and ``self.Wims``.

    Returns:
        The new memory expression ``(1 - z) * M + z * Mtarg``.

    Raises:
        AssertionError: If the number of inputs differs from
            ``len(self.Wizs)``.
    """
    assert len(inputs) == len(self.Wizs)
    batch = M.shape[0]

    def preactivation(input_weights, recurrent, recurrent_weight, bias):
        # Keep the summand order handed to add_multi: per-input projections
        # first, then the recurrent projection, then the repeated bias.
        terms = [x.dot(w) for x, w in zip(inputs, input_weights)]
        terms.append(recurrent.dot(recurrent_weight))
        terms.append(cgt.repeat(bias, batch, axis=0))
        return cgt.add_multi(terms)

    z = cgt.sigmoid(preactivation(self.Wizs, M, self.Wmz, self.bz))
    r = cgt.sigmoid(preactivation(self.Wirs, M, self.Wmr, self.br))

    # The candidate sees the memory scaled elementwise by the gate r.
    candidate = preactivation(self.Wims, r * M, self.Wmm, self.bm)
    Mtarg = cgt.tanh(candidate)  # pylint: disable=E1111

    return (1 - z) * M + z * Mtarg
def __call__(self, M, *inputs):
    """Compute the updated memory from the current memory ``M`` and ``inputs``.

    Two sigmoid gates (``z`` and ``r``) and a tanh candidate (``Mtarg``) are
    each built from the same kind of sum: every input projected through its
    own weight matrix, a projection of the memory, and a bias. The result
    blends the old memory and the candidate as ``(1 - z) * M + z * Mtarg``,
    which is the structure of a GRU-style update.

    Args:
        M: Current memory; ``M.shape[0]`` is used as the number of times each
            bias is repeated along axis 0 (presumably the batch size --
            TODO confirm against the caller).
        *inputs: One input per entry of ``self.Wizs``; the i-th input is
            multiplied with the i-th matrix of ``self.Wizs``, ``self.Wirs``
            and ``self.Wims``.

    Returns:
        The new memory expression ``(1 - z) * M + z * Mtarg``.

    Raises:
        AssertionError: If the number of inputs differs from
            ``len(self.Wizs)``.
    """
    assert len(inputs) == len(self.Wizs)
    batch = M.shape[0]

    def preactivation(input_weights, recurrent, recurrent_weight, bias):
        # Keep the summand order handed to add_multi: per-input projections
        # first, then the recurrent projection, then the repeated bias.
        terms = [x.dot(w) for x, w in zip(inputs, input_weights)]
        terms.append(recurrent.dot(recurrent_weight))
        terms.append(cgt.repeat(bias, batch, axis=0))
        return cgt.add_multi(terms)

    z = cgt.sigmoid(preactivation(self.Wizs, M, self.Wmz, self.bz))
    r = cgt.sigmoid(preactivation(self.Wirs, M, self.Wmr, self.br))

    # The candidate sees the memory scaled elementwise by the gate r.
    candidate = preactivation(self.Wims, r * M, self.Wmm, self.bm)
    Mtarg = cgt.tanh(candidate)  # pylint: disable=E1111

    return (1 - z) * M + z * Mtarg
def make_funcs(config, dbg_out=None):
    """Build the LSTM graph from ``config`` and compile its cgt functions.

    Args:
        config: Mapping read for the keys ``'rnn_steps'``, ``'num_inputs'``,
            ``'num_outputs'``, ``'num_units'``, ``'num_mems'`` and
            ``'weight_decay'``.
        dbg_out: Accepted but not used by this function.

    Returns:
        A tuple ``(params, f_step, f_loss, f_grad, f_init, f_surr)``:

        * ``params`` -- the parameters returned by ``lstm_network``.
        * ``f_step`` -- maps ``[Xs[0]] + C_0 + H_0`` to
          ``[Ys[0]] + C_1 + H_1``.
        * ``f_loss`` -- maps the network inputs plus targets to the loss.
        * ``f_grad`` -- maps the same inputs to the gradient of the loss
          with respect to ``params``.
        * ``f_init`` -- builds zero-filled initial ``(c_0, h_0)`` lists.
        * ``f_surr`` -- maps the same inputs to
          ``[loss] + net_outputs + grad``.
    """
    params, Xs, Ys, C_0, H_0, C_T, H_T, C_1, H_1 = lstm_network(
        config['rnn_steps'],
        config['num_inputs'],
        config['num_outputs'],
        config['num_units'],
        config['num_mems'],
    )

    # Symbolic sizes taken from the first input and the first output.
    size_batch = Xs[0].shape[0]
    dY = Ys[0].shape[-1]

    # One target matrix and one (size_batch, dY, dY) tensor per output.
    Ys_gt = [
        cgt.matrix(fixed_shape=(size_batch, dY), name='Y%d' % t)
        for t, _ in enumerate(Ys)
    ]
    Ys_var = [cgt.tensor3(fixed_shape=(size_batch, dY, dY)) for _ in Ys]

    net_inputs = Xs + C_0 + H_0 + Ys_var
    net_outputs = Ys + C_T + H_T

    # Sum dist.gaussian.logprob over all outputs.
    step_terms = [
        dist.gaussian.logprob(target, pred, var)
        for target, pred, var in zip(Ys_gt, Ys, Ys_var)
    ]
    loss_vec = cgt.add_multi(step_terms)

    decay = config['weight_decay']
    if decay > 0.:
        # Subtract the scaled sum of squared parameters from the summed terms.
        params_flat = cgt.concatenate([p.flatten() for p in params])
        loss_vec -= decay * cgt.sum(params_flat ** 2)

    loss = cgt.sum(loss_vec) / config['rnn_steps'] / size_batch
    grad = cgt.grad(loss, params)

    def f_init(size_batch):
        """Return zero arrays ``(c_0, h_0)`` for every positive memory size."""
        mem_sizes = [m for m in config['num_mems'] if m > 0]
        c_0 = [np.zeros((size_batch, m)) for m in mem_sizes]
        h_0 = [np.zeros((size_batch, m)) for m in mem_sizes]
        return c_0, h_0

    f_step = cgt.function([Xs[0]] + C_0 + H_0, [Ys[0]] + C_1 + H_1)
    f_loss = cgt.function(net_inputs + Ys_gt, loss)
    f_grad = cgt.function(net_inputs + Ys_gt, grad)
    f_surr = cgt.function(net_inputs + Ys_gt, [loss] + net_outputs + grad)

    return params, f_step, f_loss, f_grad, f_init, f_surr
def make_funcs(config, dbg_out=None):
    """Build the LSTM graph from ``config`` and compile its cgt functions.

    Args:
        config: Mapping read for the keys ``'rnn_steps'``, ``'num_inputs'``,
            ``'num_outputs'``, ``'num_units'``, ``'num_mems'`` and
            ``'weight_decay'``.
        dbg_out: Accepted but not used by this function.

    Returns:
        A tuple ``(params, f_step, f_loss, f_grad, f_init, f_surr)``:

        * ``params`` -- the parameters returned by ``lstm_network``.
        * ``f_step`` -- maps ``[Xs[0]] + C_0 + H_0`` to
          ``[Ys[0]] + C_1 + H_1``.
        * ``f_loss`` -- maps the network inputs plus targets to the loss.
        * ``f_grad`` -- maps the same inputs to the gradient of the loss
          with respect to ``params``.
        * ``f_init`` -- builds zero-filled initial ``(c_0, h_0)`` lists.
        * ``f_surr`` -- maps the same inputs to
          ``[loss] + net_outputs + grad``.
    """
    params, Xs, Ys, C_0, H_0, C_T, H_T, C_1, H_1 = lstm_network(
        config['rnn_steps'],
        config['num_inputs'],
        config['num_outputs'],
        config['num_units'],
        config['num_mems'],
    )

    # Symbolic sizes taken from the first input and the first output.
    size_batch = Xs[0].shape[0]
    dY = Ys[0].shape[-1]

    # One target matrix and one (size_batch, dY, dY) tensor per output.
    Ys_gt = [
        cgt.matrix(fixed_shape=(size_batch, dY), name='Y%d' % t)
        for t, _ in enumerate(Ys)
    ]
    Ys_var = [cgt.tensor3(fixed_shape=(size_batch, dY, dY)) for _ in Ys]

    net_inputs = Xs + C_0 + H_0 + Ys_var
    net_outputs = Ys + C_T + H_T

    # Sum dist.gaussian.logprob over all outputs.
    step_terms = [
        dist.gaussian.logprob(target, pred, var)
        for target, pred, var in zip(Ys_gt, Ys, Ys_var)
    ]
    loss_vec = cgt.add_multi(step_terms)

    decay = config['weight_decay']
    if decay > 0.:
        # Subtract the scaled sum of squared parameters from the summed terms.
        params_flat = cgt.concatenate([p.flatten() for p in params])
        loss_vec -= decay * cgt.sum(params_flat ** 2)

    loss = cgt.sum(loss_vec) / config['rnn_steps'] / size_batch
    grad = cgt.grad(loss, params)

    def f_init(size_batch):
        """Return zero arrays ``(c_0, h_0)`` for every positive memory size."""
        mem_sizes = [m for m in config['num_mems'] if m > 0]
        c_0 = [np.zeros((size_batch, m)) for m in mem_sizes]
        h_0 = [np.zeros((size_batch, m)) for m in mem_sizes]
        return c_0, h_0

    f_step = cgt.function([Xs[0]] + C_0 + H_0, [Ys[0]] + C_1 + H_1)
    f_loss = cgt.function(net_inputs + Ys_gt, loss)
    f_grad = cgt.function(net_inputs + Ys_gt, grad)
    f_surr = cgt.function(net_inputs + Ys_gt, [loss] + net_outputs + grad)

    return params, f_step, f_loss, f_grad, f_init, f_surr