def decode_forward_with_state(self, last_encoded, all_encoded, mask,
                              input_data, state, new_seq):
    if new_seq:
        # First decoding step: seed the decoder hidden state from the
        # last encoder time step and zero the initial cell state.
        last_encoded.copyto(self.init_state_executor.arg_dict["encoded"])
        self.init_state_executor.forward()
        init_hs = self.init_state_executor.outputs[0]
        # init_hs.copyto(self.decode_executor.arg_dict["target_l0_init_h"])
        self.decode_executor.arg_dict["target_l0_init_c"][:] = 0.0
        state = LSTMState(
            c=self.decode_executor.arg_dict["target_l0_init_c"],
            h=init_hs)
    # Feed the encoder outputs, mask, previous target token, and the
    # carried LSTM state into the one-step decode executor.
    all_encoded.copyto(self.decode_executor.arg_dict["attended"])
    mask.copyto(self.decode_executor.arg_dict["encoded_mask"])
    input_data.copyto(self.decode_executor.arg_dict["target"])
    state.c.copyto(self.decode_executor.arg_dict["target_l0_init_c"])
    state.h.copyto(self.decode_executor.arg_dict["target_l0_init_h"])
    self.decode_executor.forward()
    prob = self.decode_executor.outputs[0]
    c = self.decode_executor.outputs[1]
    h = self.decode_executor.outputs[2]
    attention_weights = self.decode_executor.outputs[3]
    return prob, attention_weights, LSTMState(c=c, h=h)
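
# Hedged usage sketch (not part of the original source): one way a caller
# might drive decode_forward_with_state greedily, one token per step.
# Assumes a bound model object exposing encode() and
# decode_forward_with_state() as above; `model`, `max_decode_len`, and the
# argmax feedback loop are illustrative.
def _greedy_decode_sketch(model, input_ndarray, mask_ndarray,
                          target_ndarray, revert_vocab, max_decode_len):
    last_encoded, all_encoded = model.encode(input_ndarray, mask_ndarray)
    state, words = None, []
    for step in range(max_decode_len):
        # new_seq=True on the first step seeds the decoder state from the
        # encoder; afterwards the returned LSTMState is fed back in.
        prob, _, state = model.decode_forward_with_state(
            last_encoded, all_encoded, mask_ndarray, target_ndarray,
            state, step == 0)
        next_id = int(prob.asnumpy()[0].argmax())
        word = revert_vocab.get(next_id, '')
        if word == '</s>':
            break
        words.append(word)
        target_ndarray[:] = next_id  # feed the prediction back as input
    return ' '.join(words)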
def lstm_attention_decode_symbol(t_num_lstm_layer, t_seq_len, t_vocab_size,
                                 t_num_hidden, t_num_embed, t_num_label,
                                 t_dropout, attention, source_seq_len):
    # Build the one-step (seqidx == 0) decode symbol used at inference time.
    data = mx.sym.Variable("target")
    seqidx = 0
    embed_weight = mx.sym.Variable("target_embed_weight")
    cls_weight = mx.sym.Variable("target_cls_weight")
    cls_bias = mx.sym.Variable("target_cls_bias")
    input_weight = mx.sym.Variable("target_input_weight")
    # input_bias = mx.sym.Variable("target_input_bias")
    param_cells = []
    last_states = []
    for i in range(t_num_lstm_layer):
        param_cells.append(
            LSTMParam(i2h_weight=mx.sym.Variable("target_l%d_i2h_weight" % i),
                      i2h_bias=mx.sym.Variable("target_l%d_i2h_bias" % i),
                      h2h_weight=mx.sym.Variable("target_l%d_h2h_weight" % i),
                      h2h_bias=mx.sym.Variable("target_l%d_h2h_bias" % i)))
        state = LSTMState(c=mx.sym.Variable("target_l%d_init_c" % i),
                          h=mx.sym.Variable("target_l%d_init_h" % i))
        # state = LSTMState(c=mx.sym.Variable("target_l%d_init_c" % i),
        #                   h=init_hs[i])
        last_states.append(state)
    assert len(last_states) == t_num_lstm_layer

    hidden = mx.sym.Embedding(data=data,
                              input_dim=t_vocab_size,
                              output_dim=t_num_embed,
                              weight=embed_weight,
                              name="target_embed")
    all_encoded = mx.sym.Variable("attended")
    encoded = mx.sym.SliceChannel(data=all_encoded, axis=1,
                                  num_outputs=source_seq_len)
    # Attend over the encoder outputs using the first layer's previous
    # hidden state as the query.
    weights, weighted_encoded = attention.attend(attended=encoded,
                                                 concat_attended=all_encoded,
                                                 state=last_states[0].h,
                                                 attend_masks=None,
                                                 use_masking=False)
    con = mx.sym.Concat(hidden, weighted_encoded)
    hidden = mx.sym.FullyConnected(data=con, num_hidden=t_num_embed,
                                   weight=input_weight, no_bias=True,
                                   name='input_fc')
    # hidden = mx.sym.Activation(data=hidden, act_type='tanh', name='input_act')
    # Stack the LSTM layers; dropout is applied between layers only.
    for i in range(t_num_lstm_layer):
        dp = 0. if i == 0 else t_dropout
        next_state = lstm(t_num_hidden, indata=hidden,
                          prev_state=last_states[i],
                          param=param_cells[i],
                          seqidx=seqidx, layeridx=i, dropout=dp)
        hidden = next_state.h
        last_states[i] = next_state
    fc = mx.sym.FullyConnected(data=hidden, num_hidden=t_num_label,
                               weight=cls_weight, bias=cls_bias,
                               name='target_pred')
    sm = mx.sym.SoftmaxOutput(data=fc, name='target_softmax')
    # Group the softmax output with every layer's (c, h) and the attention
    # weights so the decode executor can read them back each step.
    output = [sm]
    for state in last_states:
        output.append(state.c)
        output.append(state.h)
    output.append(weights)
    return mx.sym.Group(output)
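
# Hedged sketch (illustrative, not from the original source): binding the
# one-step decode symbol above into an executor whose arg names match what
# decode_forward_with_state reads and writes. The shapes (beam size,
# hidden/embed sizes, a bidirectional 2*num_hidden encoder width) are
# assumptions; note that an "encoded_mask" input only exists when
# attention.attend is built with use_masking=True.
def _bind_decode_sketch(attention, ctx, beam_size, source_seq_len,
                        num_hidden, num_embed, vocab_size):
    sym = lstm_attention_decode_symbol(
        t_num_lstm_layer=1, t_seq_len=1, t_vocab_size=vocab_size,
        t_num_hidden=num_hidden, t_num_embed=num_embed,
        t_num_label=vocab_size, t_dropout=0.0, attention=attention,
        source_seq_len=source_seq_len)
    return sym.simple_bind(
        ctx=ctx, grad_req='null',
        target=(beam_size,),
        attended=(beam_size, source_seq_len, 2 * num_hidden),
        target_l0_init_c=(beam_size, num_hidden),
        target_l0_init_h=(beam_size, num_hidden))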
def translate_one_with_beam(max_decode_len, sentence, model_buckets,
                            unroll_len, source_vocab, target_vocab,
                            revert_vocab, target_ndarray, beam_size,
                            eos_index):
    input_length = len(sentence)
    cur_model = get_bucket_model(model_buckets, input_length)
    input_ndarray = mx.nd.zeros((beam_size, unroll_len))
    mask_ndarray = mx.nd.zeros((beam_size, unroll_len))
    beam = [[
        BeamNode(father=-1, content='<s>', score=0.0, acc_score=0.0,
                 finish=False, finishLen=0) for i in range(beam_size)
    ]]
    beam_state = [None]
    MakeInput_beam(sentence, source_vocab, unroll_len, input_ndarray,
                   mask_ndarray, beam_size)
    # last_encoded is the encoder hidden state at the last time step.
    last_encoded, all_encoded = cur_model.encode(input_ndarray, mask_ndarray)
    for i in range(max_decode_len):
        MakeTargetInput_beam(beam[-1], target_vocab, target_ndarray)
        prob, attention_weights, new_state = cur_model.decode_forward_with_state(
            last_encoded, all_encoded, mask_ndarray, target_ndarray,
            beam_state[-1], i == 0)
        log_prob = -mx.ndarray.log(prob)
        finished_beam = [t for t, x in enumerate(beam[-1]) if x.finish]
        for idx in range(beam_size):
            if not beam[-1][idx].finish:
                # Length-normalize the accumulated negative log-likelihood
                # over the number of generated tokens.
                log_prob[idx] = (log_prob[idx] +
                                 beam[-1][idx].acc_score *
                                 beam[-1][idx].finishLen) / (
                                     beam[-1][idx].finishLen + 1)
            else:
                log_prob[idx] = beam[-1][idx].acc_score
        # Force finished hypotheses to re-emit </s> so they keep their score.
        for idx in finished_beam:
            log_prob[idx][:eos_index] = np.inf
            log_prob[idx][eos_index + 1:] = np.inf
        (indexes, outputs), chosen_costs = _smallest(
            log_prob.asnumpy(), beam_size, only_first_row=(i == 0))
        next_chars = [
            revert_vocab[idx] if idx in revert_vocab else ''
            for idx in outputs
        ]
        # Reorder the LSTM states to follow the surviving hypotheses.
        next_state_h = mx.nd.empty(new_state.h.shape, ctx=mx.gpu(0))
        next_state_c = mx.nd.empty(new_state.c.shape, ctx=mx.gpu(0))
        for idx in range(beam_size):
            next_state_h[idx] = new_state.h[np.asscalar(indexes[idx])]
            next_state_c[idx] = new_state.c[np.asscalar(indexes[idx])]
        next_state = LSTMState(c=next_state_c, h=next_state_h)
        beam_state.append(next_state)
        next_beam = [
            BeamNode(
                father=indexes[idx],
                content=(next_chars[idx]
                         if not beam[-1][indexes[idx]].finish
                         else beam[-1][indexes[idx]].content),
                score=chosen_costs[idx] - beam[-1][indexes[idx]].acc_score,
                acc_score=chosen_costs[idx],
                finish=(next_chars[idx] == '</s>'
                        or beam[-1][indexes[idx]].finish),
                finishLen=(beam[-1][indexes[idx]].finishLen
                           if beam[-1][indexes[idx]].finish
                           else beam[-1][indexes[idx]].finishLen + 1))
            for idx in range(beam_size)
        ]
        beam.append(next_beam)
        finished = [node.finish for node in beam[-1]]
        if all(finished):
            break
    # Backtrack each beam entry through its father pointers to recover the
    # decoded token sequence, dropping </s> tokens.
    all_result = []
    all_score = []
    for k in range(beam_size):
        ptr = k
        result = []
        for idx in range(len(beam) - 2, 0, -1):
            word = beam[idx][ptr].content
            if word != '</s>':
                result.append(word)
            ptr = beam[idx][ptr].father
        result = result[::-1]
        all_result.append(' '.join(result))
        all_score.append(beam[-1][k].acc_score)
    return all_result, all_score
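
# Hedged usage sketch (illustrative names, not from the original source):
# calling the beam decoder on one tokenized sentence and keeping the
# hypothesis with the best (lowest) length-normalized negative log
# probability. model_buckets, the vocab dicts, and unroll_len are assumed
# to come from the surrounding model-loading code.
def _translate_sketch(model_buckets, source_vocab, target_vocab,
                      revert_vocab, unroll_len, beam_size=5):
    sentence = ['hello', 'world']                # already tokenized
    target_ndarray = mx.nd.zeros((beam_size,))   # previous target token ids
    results, scores = translate_one_with_beam(
        max_decode_len=50, sentence=sentence, model_buckets=model_buckets,
        unroll_len=unroll_len, source_vocab=source_vocab,
        target_vocab=target_vocab, revert_vocab=revert_vocab,
        target_ndarray=target_ndarray, beam_size=beam_size,
        eos_index=target_vocab['</s>'])
    best_idx = int(np.argmin(scores))
    return results[best_idx], scores[best_idx]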