def apply(self, state_below, mask_below=None, init_state=None, context=None):
    """Run the GRU recurrence over ``state_below`` with ``K.scan``.

    Args:
        state_below: Input sequence. A 2-D input gets a singleton axis
            inserted at position 1; afterwards it is treated as 3-D
            (presumably n_steps x batch x dim -- TODO confirm).
        mask_below: Per-step mask. Defaults to ones shaped like
            ``state_below`` with axis 2 reduced to size 1.
        init_state: Initial hidden state. Defaults to zeros of shape
            (state_below.shape[1], self.n_hids).
        context: Accepted for interface compatibility; not used here.

    Returns:
        The result of ``K.scan``; it is also stored on ``self.output``.
    """
    if K.ndim(state_below) == 2:
        state_below = K.expand_dims(state_below, 1)

    if mask_below is None:
        mask_below = K.ones_like(K.sum(state_below, axis=2, keepdims=True))

    if init_state is None:
        # Summing over axes 0 and 2 leaves one value per entry of axis 1
        # (nb_samples); zero it, add a trailing axis and tile it n_hids
        # times -> zeros of shape (nb_samples, n_hids).
        init_state = K.repeat_elements(
            K.expand_dims(K.zeros_like(K.sum(state_below, axis=[0, 2]))),
            self.n_hids, axis=1)

    # Input projections for every step are computed once, outside the
    # recurrence (presumably the GRU candidate / update / reset inputs,
    # going by the W_xh / W_xz / W_xr names -- TODO confirm in _step).
    state_below_xh = K.dot(state_below, self.W_xh)
    state_below_xz = K.dot(state_below, self.W_xz)
    state_below_xr = K.dot(state_below, self.W_xr)
    sequences = [state_below_xh, state_below_xz, state_below_xr, mask_below]

    if K._BACKEND == 'theano':
        # Theano-style: per-step sequence slices first, previous state last.
        def fn(x_h, x_z, x_r, x_m, h_tm1):
            return self._step(x_h, x_z, x_r, x_m, h_tm1)
    else:
        # Other backends: previous state first, then the tuple of per-step
        # slices. Unpacked in the body because tuple parameters are
        # Python-2-only syntax (removed by PEP 3113).
        def fn(h_tm1, step_inputs):
            x_h, x_z, x_r, x_m = step_inputs
            return self._step(x_h, x_z, x_r, x_m, h_tm1)

    self.output = K.scan(fn, sequences=sequences,
                         outputs_initials=init_state,
                         name=_p(self.pname, 'layers'))
    return self.output
def apply(self, state_below, mask_below=None, init_state=None,
          init_context=None, c=None, c_mask=None, one_step=False,
          cov_before=None, fertility=None):
    """Run the decoder GRU over ``state_below``, with or without attention.

    With ``self.with_attention`` each step attends over ``c`` via
    ``self._step_attention``, plus a coverage state when
    ``self.with_coverage``. Otherwise each step is conditioned on
    ``init_context`` via ``self._step_context``.

    Args:
        state_below: Target-side inputs. The original note describes them as
            n_steps * batch_size/1 * embedding.
        mask_below: Mask for ``state_below``. Defaults to ones (sampling /
            beam search). If its rank differs from ``state_below``'s, a
            trailing axis is added; the ranks must then match.
        init_state: Previous/initial decoder state. Required when
            ``one_step``; otherwise defaults to
            ``self.create_init_state(init_context)``.
        init_context: Context used for the default initial state and, without
            attention, projected through W_cz / W_cr / W_ch.
        c: Attention context. Earlier (disabled) checks expected it to be
            3-D, n_seq * batch_size * dim; this is not enforced here.
        c_mask: Mask over ``c``, forwarded to ``_step_attention``.
        one_step: If true, compute a single step and return it directly
            instead of scanning.
        cov_before: Previous coverage; used only on the one-step attention
            path.
        fertility: Used only on the one-step attention path. The scan path
            rebuilds it when coverage is enabled.

    Returns:
        On the one-step path, the return value of ``_step_attention`` or
        ``_step_context``. Otherwise the result of ``K.scan``, which is also
        stored on ``self.output``.
    """
    if mask_below is None:
        # Sampling or beam search: no mask supplied, every position is valid.
        mask_below = K.ones_like(K.sum(state_below, axis=-1, keepdims=True))
    if K.ndim(mask_below) != K.ndim(state_below):
        mask_below = K.expand_dims(mask_below)
    assert K.ndim(mask_below) == K.ndim(state_below)

    if one_step:
        assert init_state is not None, 'previous state must be provided'
    if init_state is None:
        init_state = self.create_init_state(init_context)

    # Input projections, computed for all steps before the recurrence.
    state_below_xh = K.dot(state_below, self.W_xh)
    state_below_xz = K.dot(state_below, self.W_xz)
    state_below_xr = K.dot(state_below, self.W_xr)

    if self.with_attention:
        # Project the attention context once, reshaped to
        # (c.shape[0], c.shape[1], n_hids).
        p_from_c = K.reshape(K.dot(c, self.A_cp),
                             shape=(K.shape(c)[0], K.shape(c)[1],
                                    self.n_hids))
    else:
        c_z = K.dot(init_context, self.W_cz)
        c_r = K.dot(init_context, self.W_cr)
        c_h = K.dot(init_context, self.W_ch)

    if one_step:
        if self.with_attention:
            return self._step_attention(
                state_below_xh, state_below_xz, state_below_xr, mask_below,
                init_state, c, c_mask, p_from_c,
                cov_tm1=cov_before, fertility=fertility)
        return self._step_context(
            state_below_xh, state_below_xz, state_below_xr, mask_below,
            init_state, c_z, c_r, c_h, init_context)

    sequences = [state_below_xh, state_below_xz, state_below_xr, mask_below]
    is_theano = K._BACKEND == 'theano'

    # Recurrent outputs, starting with the decoder hidden state.
    outputs_info = [init_state]

    # The non-theano step functions below unpack their tuple arguments in the
    # body: tuple parameters are Python-2-only syntax (removed by PEP 3113).
    if self.with_attention:
        # Attention also emits ctx and probs. Under theano they get no
        # initial value (None), and the theano step functions below do not
        # receive them. Other backends get zero initial values instead.
        if is_theano:
            outputs_info += [None, None]
        else:
            outputs_info += [K.zeros_like(K.sum(c, axis=0)),
                             K.zeros_like(K.sum(c, axis=-1))]

        if self.with_coverage:
            # Zero initial coverage: a coverage_dim vector per entry of c's
            # first two axes. TODO: check c is 3D.
            init_cov = K.repeat_elements(
                K.expand_dims(K.zeros_like(K.sum(c, axis=2))),
                self.coverage_dim, axis=2)
            outputs_info.append(init_cov)

            # Fertility is not constructed outside when training, so build
            # it here. Compare with ``==``: ``is`` tests object identity and
            # only matches when both strings happen to be the same object.
            if self.coverage_type == 'linguistic':
                fertility = self._get_fertility(c)
            else:
                fertility = K.zeros_like(K.sum(c, axis=2))

            if is_theano:
                def fn(x_h, x_z, x_r, x_m, h_tm1, cov_tm1):
                    return self._step_attention(
                        x_h, x_z, x_r, x_m, h_tm1, c, c_mask, p_from_c,
                        cov_tm1=cov_tm1, fertility=fertility)
            else:
                def fn(states, step_inputs):
                    h_tm1, _ctx_tm1, _probs_tm1, cov_tm1 = states
                    x_h, x_z, x_r, x_m = step_inputs
                    return self._step_attention(
                        x_h, x_z, x_r, x_m, h_tm1, c, c_mask, p_from_c,
                        cov_tm1=cov_tm1, fertility=fertility)
        else:
            if is_theano:
                def fn(x_h, x_z, x_r, x_m, h_tm1):
                    return self._step_attention(
                        x_h, x_z, x_r, x_m, h_tm1, c, c_mask, p_from_c)
            else:
                def fn(states, step_inputs):
                    h_tm1, _ctx_tm1, _probs_tm1 = states
                    x_h, x_z, x_r, x_m = step_inputs
                    return self._step_attention(
                        x_h, x_z, x_r, x_m, h_tm1, c, c_mask, p_from_c)
    else:
        if is_theano:
            def fn(x_h, x_z, x_r, x_m, h_tm1):
                return self._step_context(
                    x_h, x_z, x_r, x_m, h_tm1, c_z, c_r, c_h, init_context)
        else:
            def fn(states, step_inputs):
                (h_tm1,) = states
                x_h, x_z, x_r, x_m = step_inputs
                return self._step_context(
                    x_h, x_z, x_r, x_m, h_tm1, c_z, c_r, c_h, init_context)

    self.output = K.scan(fn, sequences=sequences,
                         outputs_initials=outputs_info,
                         name=_p(self.pname, 'layers'))
    return self.output