Example #1
    def cell(self, c, h, xi, xf, xo, xu, mask=None):
        hps = self.hps

        assert hps.isteps >= 2, "multiply and add steps of mLSTM require 2 internal steps"

        '''
        for step in range(hps.isteps):

            # we can share one set of params for all isteps
            p = "h%d" % (0 if hps.share_isteps else step)

            if step == 0:
                h = self.linear(p, h)
                if hps.sproj_add is None:
                    h = ew.multiply(h, m)
                else:
                    h = hps.sproj_add.scatter_mul(h, m)
            elif step == 1:
                h = self.linear(p, h)
                if hps.sproj_mul is None:
                    h = ew.add(h, a)
                else:
                    h = hps.sproj_mul.scatter_add(h, a)
                h = ew.relu(h)
            else:
                h = self.linear(p, h, relu=True)
        '''

        i = self.linear("hi", h)
        f = self.linear("hf", h)
        o = self.linear("ho", h)
        u = self.linear("hu", h)

        # apply update dropout, saving mask if we need to recompute forward pass
        if self.train and hps.dropout > 0:
            if mask is None:
                u, mask = ew.dropout(u, keep_prob=1.0-hps.dropout)
            else:
                u = ew.dropout(u, mask=mask)
        else:
            mask = None

        if xi is not None:
            i = ew.add(i, xi)
            f = ew.add(f, xf)
            o = ew.add(o, xo)
            u = ew.add(u, xu)

        c, h = ew.fused_lstm_gates(c, i, f, o, u)
        return c, h, mask
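
The final line hands the four gate pre-activations and the previous cell state to a fused elementwise kernel. Below is a minimal NumPy sketch of the standard LSTM gate arithmetic such a fused call is generally expected to perform; lstm_gates_reference is a hypothetical name, and the exact activation choices inside ew.fused_lstm_gates may differ.

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_gates_reference(c, i, f, o, u):
    # i, f, o are gate pre-activations; u is the candidate update (sketch only).
    c_next = sigmoid(f) * c + sigmoid(i) * np.tanh(u)
    h_next = sigmoid(o) * np.tanh(c_next)
    return c_next, h_next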
Example #2
    def cell(self, h, m, a):
        hps = self.hps
        assert hps.isteps >= 2, "multiply and add steps of mLSTM require 2 internal steps"

        for step in range(hps.isteps):

            # we can share one set of params for all isteps
            p = "h%d" % (0 if hps.share_isteps else step)

            if step == 0:
                h = self.linear(p, h)
                if hps.sproj_add is None:
                    h = ew.multiply(h, m)
                else:
                    h = hps.sproj_add.scatter_mul(h, m)
            elif step == 1:
                h = self.linear(p, h)
                if hps.sproj_mul is None:
                    h = ew.add(h, a)
                else:
                    h = hps.sproj_mul.scatter_add(h, a)
                h = ew.relu(h)
            else:
                h = self.linear(p, h, relu=True)

        return h
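
For reference, the first two internal steps above implement the multiplicative and additive interactions of the mLSTM: a linear projection followed by an elementwise multiply with m, then another projection followed by an add with a and a ReLU. A simplified dense sketch of that step order is given below; W0 and W1 are hypothetical dense stand-ins for the per-step linear layers (the real code uses shared or block-sparse projections), and m and a are the precomputed multiplicative and additive inputs.

import numpy as np

def mlstm_internal_steps(h, m, a, W0, W1):
    h = h @ W0                 # step 0: linear projection
    h = h * m                  # step 0: multiplicative interaction
    h = h @ W1                 # step 1: linear projection
    h = h + a                  # step 1: additive interaction
    return np.maximum(h, 0.0)  # step 1: ReLU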
Example #3
    def backward(self, grad_ys):
        hps = self.hps
        w_params = []
        g_params = []
        b_params = []
        for p in self.param_names:
            g, b, w = self.params[p][1:4]
            w_params.append(w)
            g_params.append(g)
            b_params.append(b)
        params = w_params + g_params + b_params
        nparams = len(params)
        nsegments = len(self.segments)

        # memory efficient gradients by recomputing forward pass
        if nsegments > 0:
            param_steps = []
            input_grads = []
            for i in range(nsegments - 1, -1, -1):
                with tf.name_scope("b_seg_%04d" % i):

                    h_grads = grad_ys[i * hps.recompute:(i + 1) *
                                      hps.recompute]
                    if i == nsegments - 1:
                        c_grad = tf.zeros(h_grads[0].get_shape())
                    else:
                        fc_matmul_op = get_parents(h_grads[0],
                                                   "BlocksparseMatmulDX")[0]

                        # delay matmul to avoid memory expansion till just prior to use
                        add_control_input(fc_matmul_op, h_grad.op)

                        h_grads[-1] = ew.add(h_grads[-1], h_grad)

                    s = self.segments[i]
                    x = self.inputs.pop()
                    c_prev, h_prev = s[0]
                    c_next = s[-1][0]
                    h_next = [seg[1] for seg in s[1:]]

                    # ensure the forward segments are computed in the backward pass only.
                    add_control_input(c_prev.op, h_grads[-1].op)
                    add_control_input(h_prev.op, h_grads[-1].op)

                    grads = tf.gradients([c_next] + h_next,
                                         params + [c_prev, h_prev, x],
                                         [c_grad] + h_grads,
                                         aggregation_method=agg_method)

                    param_steps.append(grads[0:nparams])
                    c_grad = grads[nparams + 0]
                    h_grad = grads[nparams + 1]
                    input_grads.insert(0, grads[nparams + 2])

            param_grads = []
            for i in range(nparams):
                param_grads.append(tf.add_n([g[i] for g in param_steps]))

        # Normal gradients for small models
        else:
            grads = tf.gradients(self.outputs,
                                 params + self.inputs,
                                 grad_ys,
                                 aggregation_method=agg_method)
            param_grads = grads[0:nparams]
            input_grads = grads[nparams:]

        # group param grad matmuls to efficiently accumulate
        for i, p in enumerate(self.param_names):
            # a and m are already grouped
            if p not in 'am':
                param_grads[i] = group_param_grads(param_grads[i])

        grads = list(zip(param_grads, params))

        return grads, input_grads
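
When segments are present, the loop above computes gradients one segment at a time (in reverse order, recomputing the forward activations of each segment) and then sums each parameter's per-segment gradients with tf.add_n; summing is valid because the gradients of shared parameters are additive across the timesteps that use them. A minimal sketch of that accumulation step, assuming TF 1.x-style graph tensors and the same param_steps layout as above, is:

import tensorflow as tf

def accumulate_param_grads(param_steps, nparams):
    # param_steps: one entry per segment, each a list of per-parameter gradients.
    # Returns a single summed gradient tensor per parameter.
    return [tf.add_n([step[i] for step in param_steps]) for i in range(nparams)]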