def multi_attention(state, hidden_states, encoders, encoder_input_length, pos=None,
                    aggregation_method='sum', prev_weights=None, **kwargs):
    attns = []
    weights = []

    context_vector = None
    for i, (hidden, encoder, input_length) in enumerate(zip(hidden_states, encoders,
                                                            encoder_input_length)):
        pos_ = pos[i] if pos is not None else None
        prev_weights_ = prev_weights[i] if prev_weights is not None else None

        hidden = beam_search.resize_like(hidden, state)
        input_length = beam_search.resize_like(input_length, state)

        context_vector, weights_ = attention(
            state=state, hidden_states=hidden, encoder=encoder,
            encoder_input_length=input_length, pos=pos_, context=context_vector,
            prev_weights=prev_weights_, **kwargs)

        attns.append(context_vector)
        weights.append(weights_)

    if aggregation_method == 'sum':
        context_vector = tf.reduce_sum(tf.stack(attns, axis=2), axis=2)
    else:
        context_vector = tf.concat(attns, axis=1)

    return context_vector, weights
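
# Illustrative sketch, not part of the original module: how the two aggregation
# modes combine the per-encoder context vectors, re-implemented with NumPy so the
# shapes are easy to see. 'sum' requires every encoder to produce contexts of the
# same size; any other value concatenates one slice per encoder. The function
# name and shapes below are assumptions for this demo only.
def _demo_aggregate_contexts(attns, aggregation_method='sum'):
    import numpy as np
    if aggregation_method == 'sum':
        # stack to [batch, dim, num_encoders] and add over the encoder axis
        return np.stack(attns, axis=2).sum(axis=2)
    # otherwise concatenate along the feature axis: [batch, dim * num_encoders]
    return np.concatenate(attns, axis=1)
# e.g. _demo_aggregate_contexts([np.ones((2, 4)), np.ones((2, 4))]) -> all-twos, shape (2, 4)
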
def update_pos(pos, symbol, max_pos=None):
    # `decoder`, `utils`, `beam_search` and `tf` are expected to be bound in the
    # enclosing scope; this helper only does something when the decoder predicts
    # edit operations (KEEP / DEL / INS).
    if not decoder.pred_edits:
        return pos

    is_keep = tf.equal(symbol, utils.KEEP_ID)
    is_del = tf.equal(symbol, utils.DEL_ID)
    is_not_ins = tf.logical_or(is_keep, is_del)

    pos = beam_search.resize_like(pos, symbol)

    # KEEP and DEL consume a source token and advance the pointer, INS does not
    pos += tf.to_float(is_not_ins)
    if max_pos is not None:
        # only resize max_pos when it is given, then clamp to the source length
        max_pos = beam_search.resize_like(max_pos, symbol)
        pos = tf.minimum(pos, tf.to_float(max_pos))
    return pos
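
# Illustrative sketch, not part of the original module: the edit-based position
# update re-implemented with NumPy. KEEP and DEL advance the source pointer, INS
# does not, and positions are clamped at the source length. The id constants and
# example values are assumptions made for this demo only.
def _demo_update_pos():
    import numpy as np
    KEEP_ID, DEL_ID, INS_ID = 2, 3, 4        # hypothetical vocabulary ids
    pos = np.array([3., 5.])                 # current source positions
    symbol = np.array([KEEP_ID, INS_ID])     # predicted edit operations
    max_pos = np.array([4., 5.])             # source sentence lengths
    is_not_ins = (symbol == KEEP_ID) | (symbol == DEL_ID)
    return np.minimum(pos + is_not_ins.astype(float), max_pos)  # -> [4., 5.]
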
def multi_attention(state, hidden_states, encoders, encoder_input_length, pos=None,
                    aggregation_method='sum', prev_weights=None, sim_score=0.0, **kwargs):
    attns = []
    weights = []

    attn_code = None
    attn_exemplar = None
    attn_exemplar_code = None
    attn_ast = None

    context_vector = None
    for i, (hidden, encoder, input_length) in enumerate(zip(hidden_states, encoders,
                                                            encoder_input_length)):
        pos_ = pos[i] if pos is not None else None
        prev_weights_ = prev_weights[i] if prev_weights is not None else None

        if hidden is None:
            continue

        hidden = beam_search.resize_like(hidden, state)
        input_length = beam_search.resize_like(input_length, state)

        context_vector, weights_ = attention(
            state=state, hidden_states=hidden, encoder=encoder,
            encoder_input_length=input_length, pos=pos_, context=context_vector,
            prev_weights=prev_weights_, **kwargs)

        # keep the context vector of each named encoder for the fusion below
        if encoder.name == 'code':
            attn_code = context_vector
        elif encoder.name == 'exemplar':
            attn_exemplar = context_vector
        elif encoder.name == 'exemplar_code':
            attn_exemplar_code = context_vector
        elif encoder.name == 'ast':
            attn_ast = context_vector

        attns.append(context_vector)
        weights.append(weights_)

    # tile the per-sentence similarity score across beam copies so it matches the
    # (possibly beam-expanded) batch dimension, then reshape it to [batch_size, 1]
    score_shape = tf.shape(sim_score)[0]
    batch_size = tf.shape(attn_code)[0]
    sim_score = tf.tile(tf.expand_dims(sim_score, axis=1), [1, batch_size // score_shape, 1])
    sim_score = tf.reshape(sim_score, tf.stack([batch_size, 1]))

    # attn_fused = dense(tf.concat([attn_code, attn_ast], axis=1), 2 * encoder.cell_size,
    #                    activation=None, name='fuse1', use_bias=False)

    # convex combination of the 'code' and 'exemplar' contexts, gated by sim_score
    context_vector = attn_code * (1 - sim_score) + attn_exemplar * sim_score

    return context_vector, weights
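
# Illustrative sketch, not part of the original module: the similarity-score gating
# at the end of the modified multi_attention, re-implemented with NumPy. One score
# per source sentence is repeated across its beam copies so it lines up with the
# beam-expanded batch, then used as a convex combination weight between the 'code'
# and 'exemplar' context vectors. All names and shapes here are demo assumptions.
def _demo_fuse_code_and_exemplar():
    import numpy as np
    beam_size, dim = 2, 4
    attn_code = np.ones((2 * beam_size, dim))        # beam-expanded code contexts
    attn_exemplar = np.zeros((2 * beam_size, dim))   # beam-expanded exemplar contexts
    sim_score = np.array([[0.25], [0.75]])           # one score per source sentence
    sim_score = np.repeat(sim_score, beam_size, axis=0)   # -> shape [4, 1]
    return attn_code * (1 - sim_score) + attn_exemplar * sim_score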