def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None,
                              initial_state_fw=None, initial_state_bw=None,
                              dtype=None, parallel_iterations=None,
                              swap_memory=False, time_major=False, scope=None):
    """Run ``tf.nn.bidirectional_dynamic_rnn`` on inputs with extra leading dims.

    ``inputs`` is collapsed with ``flatten(inputs, 2)`` so that every axis
    except the last two is merged into one batch axis (the inline comment
    below indicates the flattened layout is ``[-1, J, d]``). The bidirectional
    RNN is run on that tensor. The forward and backward outputs are then
    reshaped back with ``reconstruct(..., inputs, 2)``.

    Args:
        cell_fw: RNN cell for the forward direction.
        cell_bw: RNN cell for the backward direction.
        inputs: batch-major input tensor (presumably ``[..., J, d]`` --
            confirm against callers).
        sequence_length: optional lengths tensor. It is passed through
            ``flatten(sequence_length, 0)`` and cast to ``int64`` before
            being handed to TensorFlow. ``None`` is forwarded as ``None``.
        initial_state_fw, initial_state_bw, dtype, parallel_iterations,
        swap_memory, scope: forwarded unchanged to
            ``tf.nn.bidirectional_dynamic_rnn``.
        time_major: must be False; time-major inputs are not supported.

    Returns:
        ``((fw_outputs, bw_outputs), final_state)``. Both outputs are
        reshaped to match ``inputs``. ``final_state`` is returned exactly as
        TensorFlow produced it for the flattened batch (see FIXME below).

    Raises:
        NotImplementedError: if ``time_major`` is True.
    """
    # Explicit check instead of ``assert``: an assert is stripped under
    # ``python -O``, after which time-major data would silently be mangled by
    # the batch-major flatten/reconstruct calls below.
    if time_major:
        raise NotImplementedError(
            "time_major=True is not supported; pass batch-major inputs")
    flat_inputs = flatten(inputs, 2)  # [-1, J, d]
    flat_len = None if sequence_length is None else tf.cast(
        flatten(sequence_length, 0), 'int64')
    (flat_fw_outputs, flat_bw_outputs), final_state = \
        tf.nn.bidirectional_dynamic_rnn(
            cell_fw, cell_bw, flat_inputs,
            sequence_length=flat_len,
            initial_state_fw=initial_state_fw,
            initial_state_bw=initial_state_bw,
            dtype=dtype,
            parallel_iterations=parallel_iterations,
            swap_memory=swap_memory,
            time_major=time_major,
            scope=scope)
    fw_outputs = reconstruct(flat_fw_outputs, inputs, 2)
    bw_outputs = reconstruct(flat_bw_outputs, inputs, 2)
    # FIXME: final state is not reshaped! It keeps the flattened batch
    # dimension rather than the leading dims of ``inputs``.
    return (fw_outputs, bw_outputs), final_state
def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
                dtype=None, parallel_iterations=None, swap_memory=False,
                time_major=False, scope=None):
    """Run ``tf.nn.dynamic_rnn`` on inputs with extra leading dimensions.

    ``inputs`` is collapsed with ``flatten(inputs, 2)`` so every axis except
    the last two is merged into one batch axis (the inline comment below
    indicates the flattened layout is ``[-1, J, d]``). The RNN is run on that
    tensor, and its outputs are reshaped back with
    ``reconstruct(..., inputs, 2)``.

    Args:
        cell: the RNN cell to run.
        inputs: batch-major input tensor (presumably ``[..., J, d]`` --
            confirm against callers).
        sequence_length: optional lengths tensor. It is passed through
            ``flatten(sequence_length, 0)`` and cast to ``int64`` before
            being handed to TensorFlow. ``None`` is forwarded as ``None``.
        initial_state, dtype, parallel_iterations, swap_memory, scope:
            forwarded unchanged to ``tf.nn.dynamic_rnn``.
        time_major: must be False; time-major inputs are not supported yet.

    Returns:
        ``(outputs, final_state)``. ``outputs`` is reshaped to match
        ``inputs``. ``final_state`` is returned as TensorFlow produced it for
        the flattened batch (it is not reshaped).

    Raises:
        NotImplementedError: if ``time_major`` is True.
    """
    # TODO: support time-major inputs. Explicit check instead of ``assert`` so
    # the guard is not stripped under ``python -O``.
    if time_major:
        raise NotImplementedError(
            "time_major=True is not supported; pass batch-major inputs")
    flat_inputs = flatten(inputs, 2)  # [-1, J, d]
    flat_len = None if sequence_length is None else tf.cast(
        flatten(sequence_length, 0), 'int64')
    flat_outputs, final_state = tf.nn.dynamic_rnn(
        cell, flat_inputs,
        sequence_length=flat_len,
        initial_state=initial_state,
        dtype=dtype,
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory,
        time_major=time_major,
        scope=scope)
    outputs = reconstruct(flat_outputs, inputs, 2)
    return outputs, final_state
def linear(args, output_size, bias, bias_start=0.0, scope=None, squeeze=False,
           wd=0.0, input_keep_prob=1.0, is_train=None):
    """Apply a dense layer to one or more tensors of arbitrary rank.

    Each tensor is flattened with ``flatten(t, 1)``. When
    ``input_keep_prob < 1.0``, each flattened tensor gets optional dropout.
    The list is then handed to ``_linear``. The result is reshaped with
    ``reconstruct(..., first_tensor, 1)`` so that it takes the leading
    dimensions of the first input.

    Args:
        args: a tensor, or a non-empty list/tuple of tensors.
        output_size: forwarded to ``_linear`` (the output's last dimension).
        bias: forwarded to ``_linear`` (presumably toggles the bias term).
        bias_start: forwarded to ``_linear``.
        scope: forwarded to ``_linear``.
        squeeze: if true, squeeze axis ``rank(first_tensor) - 1`` of the
            output.
        wd: when truthy, ``add_reg_without_bias()`` is called. The value
            itself is not passed along.
        input_keep_prob: keep probability handed to ``tf.nn.dropout``.
            Dropout is only wired in when this is below 1.0.
        is_train: boolean tensor; ``tf.cond`` applies dropout when it is
            true. It must be provided when ``input_keep_prob < 1.0``.

    Returns:
        The reshaped (and optionally squeezed) output tensor.

    Raises:
        ValueError: if ``args`` is None or an empty list/tuple.
    """
    is_seq = isinstance(args, (tuple, list))
    if args is None or (is_seq and not args):
        raise ValueError("`args` must be specified")
    tensors = args if is_seq else [args]

    flat = [flatten(t, 1) for t in tensors]  # each [-1, d]
    if input_keep_prob < 1.0:
        assert is_train is not None
        # Bind ``t`` per iteration so each branch refers to its own tensor.
        flat = [
            tf.cond(is_train,
                    lambda t=t: tf.nn.dropout(t, input_keep_prob),
                    lambda t=t: t)
            for t in flat
        ]

    flat_out = _linear(flat, output_size, bias, bias_start=bias_start,
                       scope=scope)
    first = tensors[0]
    out = reconstruct(flat_out, first, 1)
    if squeeze:
        last_axis = len(first.get_shape().as_list()) - 1
        out = tf.squeeze(out, [last_axis])
    if wd:
        add_reg_without_bias()
    return out