Example #1
0
    def apply(self, state_below, mask_below=None, init_state=None, context=None):
        """Run the GRU recurrence over ``state_below`` with ``K.scan``.

        Args:
            state_below: input sequence tensor. A 2-D input gets a singleton
                axis inserted at position 1. Axis 2 is treated as the feature
                axis and axis 1 as the sample axis (see the default mask and
                initial state below). NOTE(review): presumably
                (n_steps, nb_samples, dim) -- confirm against callers; if a
                2-D input means a single time step (nb_samples, dim), the
                singleton axis would belong at position 0 instead.
            mask_below: optional mask; defaults to all ones shaped like
                ``state_below`` with axis 2 reduced to size 1.
            init_state: optional initial hidden state; defaults to zeros of
                shape (state_below.shape[1], self.n_hids).
            context: accepted for interface compatibility; unused here.

        Returns:
            The result of ``K.scan``; it is also stored on ``self.output``.
        """
        if K.ndim(state_below) == 2:
            state_below = K.expand_dims(state_below, 1)

        if mask_below is None:
            mask_below = K.ones_like(K.sum(state_below, axis=2, keepdims=True))

        if init_state is None:
            # Summing over axes 0 and 2 leaves one entry per axis-1 element
            # (nb_samples); expand + repeat yields (nb_samples, n_hids) zeros.
            init_state = K.repeat_elements(
                K.expand_dims(K.zeros_like(K.sum(state_below, axis=[0, 2]))),
                self.n_hids,
                axis=1)

        # Input projections for every time step are computed once, outside
        # the recurrence; the step function only receives per-step slices.
        state_below_xh = K.dot(state_below, self.W_xh)
        state_below_xz = K.dot(state_below, self.W_xz)
        state_below_xr = K.dot(state_below, self.W_xr)
        sequences = [state_below_xh, state_below_xz, state_below_xr, mask_below]

        if K._BACKEND == 'theano':
            # Theano convention: sequence slices first, then the state.
            def fn(x_h, x_z, x_r, x_m, h_tm1):
                return self._step(x_h, x_z, x_r, x_m, h_tm1)
        else:
            # Other backend: fn(previous_state, sequence_slices). Unpack the
            # slices explicitly; tuple parameters are invalid in Python 3.
            def fn(h_tm1, x_t):
                x_h, x_z, x_r, x_m = x_t
                return self._step(x_h, x_z, x_r, x_m, h_tm1)

        self.output = K.scan(fn,
                             sequences=sequences,
                             outputs_initials=init_state,
                             name=_p(self.pname, 'layers'))
        return self.output
Example #2
0
    def apply(self,
              state_below,
              mask_below=None,
              init_state=None,
              init_context=None,
              c=None,
              c_mask=None,
              one_step=False,
              cov_before=None,
              fertility=None):
        """Run the (optionally attentional) GRU decoder over ``state_below``.

        Args:
            state_below: target-side input tensor. NOTE(review): described
                elsewhere as n_steps * batch_size/1 * embedding -- confirm.
            mask_below: optional mask; defaults to ones shaped like
                ``state_below`` with its last axis reduced to size 1
                (sampling / beam search). If its rank is lower than
                ``state_below``'s, a trailing axis is added.
            init_state: previous/initial hidden state. Required when
                ``one_step`` is true; otherwise built with
                ``self.create_init_state(init_context)`` when omitted.
            init_context: context used for the default initial state and,
                when ``self.with_attention`` is false, projected through
                W_cz / W_cr / W_ch and passed to every step.
            c: source annotations; required when ``self.with_attention`` is
                true. NOTE(review): presumably n_seq * batch_size * dim --
                ``K.sum(c, axis=2)`` below assumes at least 3 axes.
            c_mask: mask for ``c``, forwarded to ``_step_attention``.
            one_step: if true, run a single step and return its result
                directly instead of scanning.
            cov_before: previous coverage, forwarded as ``cov_tm1`` in the
                one-step attention path.
            fertility: forwarded in the one-step attention path; when
                scanning with coverage it is recomputed internally.

        Returns:
            The single-step result in ``one_step`` mode; otherwise the result
            of ``K.scan``, which is also stored on ``self.output``.

        Raises:
            ValueError: if ``self.with_attention`` is true and ``c`` is None.
            AssertionError: (via ``assert``) if ``one_step`` is true and
                ``init_state`` is None.
        """
        if self.with_attention and c is None:
            # p_from_c and every attention step below read the source context.
            raise ValueError('Context must be provided')

        # mask
        if mask_below is None:  # sampling or beamsearch
            mask_below = K.ones_like(K.sum(state_below, axis=-1,
                                           keepdims=True))  # nb_samples

        # Give the mask the same rank as the inputs.
        if K.ndim(mask_below) != K.ndim(state_below):
            mask_below = K.expand_dims(mask_below)

        assert K.ndim(mask_below) == K.ndim(state_below)

        if one_step:
            assert init_state is not None, 'previous state must be provided'

        if init_state is None:
            init_state = self.create_init_state(init_context)

        # Input projections for all time steps, computed once.
        state_below_xh = K.dot(state_below, self.W_xh)
        state_below_xz = K.dot(state_below, self.W_xz)
        state_below_xr = K.dot(state_below, self.W_xr)

        if self.with_attention:
            # time steps, nb_samples, n_hids
            p_from_c = K.reshape(K.dot(c, self.A_cp),
                                 shape=(K.shape(c)[0], K.shape(c)[1],
                                        self.n_hids))
        else:
            c_z = K.dot(init_context, self.W_cz)
            c_r = K.dot(init_context, self.W_cr)
            c_h = K.dot(init_context, self.W_ch)

        if one_step:
            if self.with_attention:
                return self._step_attention(state_below_xh,
                                            state_below_xz,
                                            state_below_xr,
                                            mask_below,
                                            init_state,
                                            c,
                                            c_mask,
                                            p_from_c,
                                            cov_tm1=cov_before,
                                            fertility=fertility)
            return self._step_context(state_below_xh, state_below_xz,
                                      state_below_xr, mask_below,
                                      init_state, c_z, c_r, c_h,
                                      init_context)

        sequences = [
            state_below_xh, state_below_xz, state_below_xr, mask_below
        ]
        # decoder hidden state
        outputs_info = [init_state]
        theano_backend = K._BACKEND == 'theano'
        # Step-function calling conventions used below: under Theano the
        # step receives the sequence slices followed by the non-None
        # recurrent states; under the other backend it receives
        # (previous_states, sequence_slices). Tuple parameters (Python 2
        # only, removed by PEP 3113) are replaced by explicit unpacking.
        if self.with_attention:
            # ctx, probs
            if theano_backend:
                outputs_info += [None, None]
            else:
                outputs_info += [
                    K.zeros_like(K.sum(c, axis=0)),
                    K.zeros_like(K.sum(c, axis=-1))
                ]

            if self.with_coverage:
                # initialization for coverage
                # TODO: check c is 3D
                init_cov = K.repeat_elements(
                    K.expand_dims(K.zeros_like(K.sum(c, axis=2))),
                    self.coverage_dim,
                    axis=2)
                outputs_info.append(init_cov)
                # fertility is not constructed outside when training.
                # Compare by value: `is` against a string literal depends on
                # interning and is not a reliable equality test.
                if self.coverage_type == 'linguistic':
                    fertility = self._get_fertility(c)
                else:
                    fertility = K.zeros_like(K.sum(c, axis=2))

                if theano_backend:
                    def fn(x_h, x_z, x_r, x_m, h_tm1, cov_tm1):
                        return self._step_attention(x_h, x_z, x_r, x_m,
                                                    h_tm1, c, c_mask,
                                                    p_from_c,
                                                    cov_tm1=cov_tm1,
                                                    fertility=fertility)
                else:
                    def fn(states_tm1, x_t):
                        h_tm1, _ctx_tm1, _probs_tm1, cov_tm1 = states_tm1
                        x_h, x_z, x_r, x_m = x_t
                        return self._step_attention(x_h, x_z, x_r, x_m,
                                                    h_tm1, c, c_mask,
                                                    p_from_c,
                                                    cov_tm1=cov_tm1,
                                                    fertility=fertility)
            else:
                if theano_backend:
                    def fn(x_h, x_z, x_r, x_m, h_tm1):
                        return self._step_attention(x_h, x_z, x_r, x_m,
                                                    h_tm1, c, c_mask,
                                                    p_from_c)
                else:
                    def fn(states_tm1, x_t):
                        h_tm1, _ctx_tm1, _probs_tm1 = states_tm1
                        x_h, x_z, x_r, x_m = x_t
                        return self._step_attention(x_h, x_z, x_r, x_m,
                                                    h_tm1, c, c_mask,
                                                    p_from_c)
        else:
            if theano_backend:
                def fn(x_h, x_z, x_r, x_m, h_tm1):
                    return self._step_context(x_h, x_z, x_r, x_m, h_tm1,
                                              c_z, c_r, c_h, init_context)
            else:
                def fn(states_tm1, x_t):
                    (h_tm1,) = states_tm1
                    x_h, x_z, x_r, x_m = x_t
                    return self._step_context(x_h, x_z, x_r, x_m, h_tm1,
                                              c_z, c_r, c_h, init_context)

        self.output = K.scan(fn,
                             sequences=sequences,
                             outputs_initials=outputs_info,
                             name=_p(self.pname, 'layers'))

        return self.output