def cell(self, c, h, xi, xf, xo, xu, mask=None):
    hps = self.hps
    assert hps.isteps >= 2, "multiply and add steps of mLSTM require 2 internal steps"
    '''
    for step in range(hps.isteps):
        # we can share one set of params for all isteps
        p = "h%d" % (0 if hps.share_isteps else step)
        if step == 0:
            h = self.linear(p, h)
            if hps.sproj_add is None:
                h = ew.multiply(h, m)
            else:
                h = hps.sproj_add.scatter_mul(h, m)
        elif step == 1:
            h = self.linear(p, h)
            if hps.sproj_mul is None:
                h = ew.add(h, a)
            else:
                h = hps.sproj_mul.scatter_add(h, a)
            h = ew.relu(h)
        else:
            h = self.linear(p, h, relu=True)
    '''
    i = self.linear("hi", h)
    f = self.linear("hf", h)
    o = self.linear("ho", h)
    u = self.linear("hu", h)

    # apply update dropout, saving mask if we need to recompute forward pass
    if self.train and hps.dropout > 0:
        if mask is None:
            u, mask = ew.dropout(u, keep_prob=1.0 - hps.dropout)
        else:
            u = ew.dropout(u, mask=mask)
    else:
        mask = None

    if xi is not None:
        i = ew.add(i, xi)
        f = ew.add(f, xf)
        o = ew.add(o, xo)
        u = ew.add(u, xu)

    c, h = ew.fused_lstm_gates(c, i, f, o, u)
    return c, h, mask
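# Illustrative sketch, not part of the model: ew.fused_lstm_gates above fuses the
# four gate pre-activations into a single kernel. Assuming the standard LSTM gate
# equations (an assumption about the fused kernel, not confirmed by this file),
# an unfused NumPy reference would look like this. All names below are
# illustrative only.
import numpy as np

def _sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_gates_reference(c, i, f, o, u):
    # c' = sigmoid(f) * c + sigmoid(i) * tanh(u);  h' = sigmoid(o) * tanh(c')
    c_next = _sigmoid(f) * c + _sigmoid(i) * np.tanh(u)
    h_next = _sigmoid(o) * np.tanh(c_next)
    return c_next, h_next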
def cell(self, h, m, a):
    hps = self.hps
    assert hps.isteps >= 2, "multiply and add steps of mLSTM require 2 internal steps"
    for step in range(hps.isteps):
        # we can share one set of params for all isteps
        p = "h%d" % (0 if hps.share_isteps else step)
        if step == 0:
            # multiply step: project h, then gate it elementwise with m
            h = self.linear(p, h)
            if hps.sproj_add is None:
                h = ew.multiply(h, m)
            else:
                h = hps.sproj_add.scatter_mul(h, m)
        elif step == 1:
            # add step: project h, add a, then apply relu
            h = self.linear(p, h)
            if hps.sproj_mul is None:
                h = ew.add(h, a)
            else:
                h = hps.sproj_mul.scatter_add(h, a)
            h = ew.relu(h)
        else:
            # any remaining internal steps are plain relu linears
            h = self.linear(p, h, relu=True)
    return h
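# Illustrative sketch, not the repo's API: the dense (no sparse projection) path
# of the internal steps above, written out in NumPy. W0 and W1 stand in for the
# "h0" and "h1" linear layers; m and a are the multiplicative and additive terms
# passed to cell(). All names here are illustrative assumptions.
import numpy as np

def mlstm_internal_steps_reference(h, m, a, W0, W1, extra_Ws=()):
    h = h @ W0                   # step 0: linear, then gate elementwise by m
    h = h * m
    h = h @ W1                   # step 1: linear, then add a and apply relu
    h = np.maximum(h + a, 0.0)
    for W in extra_Ws:           # remaining internal steps: plain relu linears
        h = np.maximum(h @ W, 0.0)
    return h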
def backward(self, grad_ys):
    hps = self.hps

    w_params = []
    g_params = []
    b_params = []
    for p in self.param_names:
        g, b, w = self.params[p][1:4]
        w_params.append(w)
        g_params.append(g)
        b_params.append(b)
    params = w_params + g_params + b_params
    nparams = len(params)
    nsegments = len(self.segments)

    # memory-efficient gradients by recomputing the forward pass segment by segment
    if nsegments > 0:
        param_steps = []
        input_grads = []
        for i in range(nsegments - 1, -1, -1):
            with tf.name_scope("b_seg_%04d" % i):
                h_grads = grad_ys[i * hps.recompute:(i + 1) * hps.recompute]
                if i == nsegments - 1:
                    c_grad = tf.zeros(h_grads[0].get_shape())
                else:
                    fc_matmul_op = get_parents(h_grads[0], "BlocksparseMatmulDX")[0]
                    # delay matmul to avoid memory expansion till just prior to use
                    add_control_input(fc_matmul_op, h_grad.op)
                    h_grads[-1] = ew.add(h_grads[-1], h_grad)

                s = self.segments[i]
                x = self.inputs.pop()
                c_prev, h_prev = s[0]
                c_next = s[-1][0]
                h_next = [seg[1] for seg in s[1:]]

                # ensure the forward segments are computed in the backward pass only
                add_control_input(c_prev.op, h_grads[-1].op)
                add_control_input(h_prev.op, h_grads[-1].op)

                grads = tf.gradients(
                    [c_next] + h_next,
                    params + [c_prev, h_prev, x],
                    [c_grad] + h_grads,
                    aggregation_method=agg_method)

                param_steps.append(grads[0:nparams])
                c_grad = grads[nparams + 0]
                h_grad = grads[nparams + 1]
                input_grads.insert(0, grads[nparams + 2])

        # sum the per-segment parameter gradients
        param_grads = []
        for i in range(nparams):
            param_grads.append(tf.add_n([g[i] for g in param_steps]))

    # Normal gradients for small models
    else:
        grads = tf.gradients(self.outputs, params + self.inputs, grad_ys,
                             aggregation_method=agg_method)
        param_grads = grads[0:nparams]
        input_grads = grads[nparams:]

    # group param grad matmuls to efficiently accumulate
    for i, p in enumerate(self.param_names):
        # a and m are already grouped
        if p not in 'am':
            param_grads[i] = group_param_grads(param_grads[i])

    grads = list(zip(param_grads, params))
    return grads, input_grads
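# Illustrative sketch (an assumption, not this repo's code): the same
# memory-saving idea backward() uses above, reduced to plain NumPy for a chain of
# tanh layers sharing one weight matrix W. Only segment-boundary activations are
# kept during the forward pass; in the backward pass each segment is recomputed
# before its gradients are taken, and per-segment weight gradients are summed
# (the role of tf.add_n above). All names are illustrative.
import numpy as np

def checkpointed_grad(W, h0, nsteps, recompute):
    # forward: keep only the hidden state at the start of every segment
    boundaries = []
    h = h0
    for t in range(nsteps):
        if t % recompute == 0:
            boundaries.append(h)
        h = np.tanh(h @ W)

    d_h = np.ones_like(h)        # gradient of loss = sum(h_final)
    d_W = np.zeros_like(W)

    # backward: walk segments last-to-first, recomputing their activations
    for i in range(len(boundaries) - 1, -1, -1):
        steps = min(recompute, nsteps - i * recompute)
        acts = [boundaries[i]]
        for _ in range(steps):   # recomputed forward for this segment only
            acts.append(np.tanh(acts[-1] @ W))
        for t in range(steps - 1, -1, -1):
            d_pre = d_h * (1.0 - acts[t + 1] ** 2)   # tanh backprop
            d_W += acts[t].T @ d_pre                  # accumulate param grad
            d_h = d_pre @ W.T                         # grad w.r.t. previous state
    return d_W, d_h

# example usage (illustrative): checkpointed_grad(np.eye(4) * 0.5, np.ones((2, 4)), nsteps=8, recompute=3)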