示例#1
0
 def get_switched():
     """
     Return a randomly corrupted version of the label tensor ``x``.

     ``x``, ``data``, ``time_factor``, ``tf`` and ``where_bc`` are not defined
     in this function; they come from an enclosing scope outside this view.
     ``target_num_labels``, ``targetb_blank_idx`` and ``random_mask`` are
     looked up by name through ``eval`` each time this function runs.

     :return: ``x`` after random label replacement followed by ``random_mask``
     """
     x_ = x
     shape = tf.shape(x)
     # NOTE(review): n_batch and n_time are not read again in this function.
     # The ops are still created, which matters if implicit op seeds are
     # derived from the graph's op count -- confirm before removing them.
     n_batch = tf.shape(x)[data.batch_dim_axis]
     n_time = tf.shape(x)[data.time_dim_axis]
     # Elementwise mask: True where a uniform draw from [0, 1) is below 0.05.
     take_rnd_mask = tf.less(
         tf.random.uniform(shape=shape, minval=0., maxval=1.), 0.05)
     # Independent second draw: True where it is below 0.5.
     take_blank_mask = tf.less(
         tf.random.uniform(shape=shape, minval=0., maxval=1.), 0.5)
     # One random int32 per element, drawn from 0 up to (not including)
     # target_num_labels.
     rnd_label = tf.random.uniform(shape=shape,
                                   minval=0,
                                   maxval=eval("target_num_labels"),
                                   dtype=tf.int32)
     # presumably where_bc(cond, a, b) picks a where cond holds and b
     # elsewhere, with broadcasting -- confirm against its definition.
     # Under that reading the replacement label is the blank index where
     # take_blank_mask holds, and the random label otherwise.
     rnd_label = where_bc(take_blank_mask, eval("targetb_blank_idx"),
                          rnd_label)
     # Swap in the replacement label at the positions flagged by take_rnd_mask.
     x_ = where_bc(take_rnd_mask, rnd_label, x_)
     # random_mask is defined elsewhere; from the arguments it is asked to
     # write targetb_blank_idx along the time axis, with
     # max_num = max(T // (50 // time_factor), 1) and
     # max_dims = 20 // time_factor. The exact meaning of min_num / max_num /
     # max_dims is not visible here -- TODO confirm.
     x_ = eval("random_mask")(x_,
                              batch_axis=data.batch_dim_axis,
                              axis=data.time_dim_axis,
                              min_num=0,
                              max_num=tf.maximum(
                                  tf.shape(x)[data.time_dim_axis] //
                                  (50 // time_factor), 1),
                              max_dims=20 // time_factor,
                              mask_value=eval("targetb_blank_idx"))
     # Debug aid, disabled:
     # x_ = tf.Print(x_, ["switch", x[0], "to", x_[0]], summarize=100)
     return x_
def _mask(x, batch_axis, axis, pos, max_amount):
    """
    Overwrite one random-length span per batch entry along ``axis`` with 0.0.

    The span of entry ``b`` starts at ``pos[b]``; its length is drawn
    uniformly from ``1..max_amount`` and the span end is clipped to the
    size of ``axis``.

    :param tf.Tensor x: (batch,time,feature)
    :param int batch_axis: index of the batch axis of ``x``
    :param int axis: index of the axis of ``x`` to mask along
    :param tf.Tensor pos: (batch,) start index of the span per batch entry
    :param int|tf.Tensor max_amount: inclusive upper bound of the span length
    :return: result of ``where_bc(span_mask, 0.0, x)``
    :rtype: tf.Tensor
    """
    from returnn.tf.compat import v1 as tf
    rank = x.get_shape().ndims
    batch_size = tf.shape(x)[batch_axis]
    axis_len = tf.shape(x)[axis]
    # maxval is exclusive, hence the +1 to make max_amount itself reachable.
    span_len = tf.random_uniform(
        shape=(batch_size, ),
        minval=1,
        maxval=max_amount + 1,
        dtype=tf.int32,
    )
    span_end = tf.minimum(pos + span_len, axis_len)
    positions = tf.expand_dims(tf.range(0, axis_len), 0)  # (1,dim)
    start_bc = tf.expand_dims(pos, 1)  # (batch,1)
    end_bc = tf.expand_dims(span_end, 1)  # (batch,1)
    not_before_start = tf.greater_equal(positions, start_bc)
    before_end = tf.less(positions, end_bc)
    in_span = tf.logical_and(not_before_start, before_end)  # (batch,dim)
    if batch_axis > axis:
        # Put the two axes in the order in which they appear in x.
        in_span = tf.transpose(in_span)  # (dim,batch)
    # Shape with size 1 on every axis except batch_axis and axis, so that the
    # mask can be broadcast against x.
    bc_shape = []
    for i in range(rank):
        if i in (batch_axis, axis):
            bc_shape.append(tf.shape(x)[i])
        else:
            bc_shape.append(1)
    in_span = tf.reshape(in_span, bc_shape)
    from TFUtil import where_bc
    return where_bc(in_span, 0.0, x)
示例#3
0
def targetb_search_or_fallback(source, **kwargs):
    """
    Pick between two target tensors per batch entry, driven by a third input.

    :param source: callable giving access to the inputs by index; called as
        ``source(0)``, ``source(1)`` and ``source(2, auto_convert=False)``
    :param kwargs: accepted but not used
    :return: ``where_bc(cond, source(1), source(0))`` where ``cond`` is
        ``source(2)[:, None] < 0.01``
    """
    import tensorflow as tf
    from TFUtil import where_bc
    linear_targets = source(0)  # (B,T)
    search_targets = source(1)  # (B,T)
    selector = source(2, auto_convert=False)  # (B,)
    # Add a trailing axis so the per-entry condition broadcasts over time.
    use_search = tf.less(selector[:, None], 0.01)
    return where_bc(use_search, search_targets, linear_targets)
def _mask(x, axis, pos, max_amount):
    """
    Replace the positions selected by ``_get_mask`` with 0.0.

    The mask returned by ``_get_mask(x, axis, pos, max_amount)`` is reshaped
    so that only axis 0 and ``axis`` keep the size they have in ``x``; all
    other axes get size 1 for broadcasting.

    :param tf.Tensor x: input tensor
    :param int axis: axis of ``x`` to mask along
    :param pos: forwarded unchanged to ``_get_mask``
    :param max_amount: forwarded unchanged to ``_get_mask``
    :return: result of ``where_bc(mask, 0.0, x)``
    """
    from returnn.tf.compat import v1 as tf
    rank = x.get_shape().ndims
    mask = _get_mask(x, axis, pos, max_amount)
    bc_shape = []
    for i in range(rank):
        if i in (0, axis):
            bc_shape.append(tf.shape(x)[i])
        else:
            bc_shape.append(1)
    mask = tf.reshape(mask, bc_shape)
    from TFUtil import where_bc
    return where_bc(mask, 0.0, x)
示例#5
0
    def __init__(self, sources, seed=None, use_time_mask=None, \
                 infer_threshold=None, first_reset_value=1., exp_energy_cumsum=False, sigmoid_energy_cumprod=False,
                 **kwargs):
        """
        Turn an energy tensor into weights ``reset * cumprod(1 - reset)``.

        The energy is read from ``sources[0]`` and brought into time-major
        layout. A ``reset`` gate in [0, 1] is derived from it (see the
        ``exp_energy_cumsum`` / ``sigmoid_energy_cumprod`` options), and the
        result is ``reset`` times the exclusive, reversed cumulative product
        of ``1 - reset`` along the time axis.

        ``tf``, ``concat_sources`` and ``get_endpoint_compare_to`` are
        resolved from module scope, which is outside this view.

        :param list sources: source layers; only ``sources[0]`` is read here
        :param seed: not read anywhere in this method
        :param bool|None use_time_mask: mask padded frames with ``-inf``;
            ``None`` enables it exactly when the time axis is dynamic
        :param infer_threshold: when not None and ``network.train_flag`` is
            None or False, ``reset`` is additionally gated through
            ``get_endpoint_compare_to``
        :param float|None first_reset_value: when not None, the first time
            frame of ``reset`` is overwritten with this constant
        :param bool exp_energy_cumsum: use
            ``reset = 1 / (1 + cumsum(exp(-energy)))``
        :param bool sigmoid_energy_cumprod: use
            ``reset = cumprod(sigmoid(energy))`` computed in log space
        :param kwargs: forwarded to the base class constructor
        """
        assert sources
        super(GatedRecurrentContextLayer, self).__init__(sources=sources,
                                                         **kwargs)
        from TFUtil import where_bc
        energy_data = concat_sources([self.sources[0]
                                      ])  # expected (enc-T,B,H); (B,enc-T,H) is transposed below
        assert energy_data.dtype.startswith("float")
        energy = energy_data.placeholder  # (enc-T,B,H) or (B,enc-T,H) at this point
        axis = 0  # time axis after the optional transpose; was: energy_data.batch_ndim - 1

        orig_time_axis = self._get_axis_to_reduce(input_data=energy_data,
                                                  axis="T",
                                                  exception_prefix=self)
        if orig_time_axis == 1:  # batch-major input: transpose (B,enc-T,H)->(enc-T,B,H)
            print(
                "original time axis of energy was 1, and is changed with 0-th axis! (to make (enc-T,B,H))"
            )
            energy = tf.transpose(energy, perm=(1, 0, 2))

        energy_shape = tf.shape(energy)  # dynamic shape, (enc-T,B,H)
        # NOTE(review): check_shape_equal is imported but not used in this method.
        from TFUtil import check_shape_equal
        assert energy_data.have_time_axis()
        # if the time-axis is static, we can skip the masking
        if use_time_mask is None:
            use_time_mask = energy_data.is_axis_dynamic(orig_time_axis)
        if use_time_mask:
            assert energy_data.is_axis_dynamic(
                orig_time_axis
            ), "%s: use_time_mask True, dyn time axis expected" % self
            energy_mask = energy_data.get_sequence_mask_broadcast(
                axis=orig_time_axis)
            if orig_time_axis == 1:  # bring the mask into the same layout: (B,enc-T,H)->(enc-T,B,H)
                energy_mask = tf.transpose(energy_mask, perm=(1, 0, 2))
            # Positions outside the sequence get -inf energy.
            energy = where_bc(energy_mask,
                              energy,
                              float("-inf"),
                              name="energy_masked")
        ### Attention
        # The network is only consulted for its train_flag further down.
        network = self.sources[0].network

        # Operates on (enc-T,B,H) tensors.
        def safe_cumprod(x, *args, **kwargs):
            # Cumulative product computed as exp(cumsum(log(x))), with x
            # clipped to [tiny, 1] so that log never sees 0.
            with tf.name_scope(None, "SafeCumprod", [x]):
                x = tf.convert_to_tensor(x, name="x")
                import numpy as np
                tiny = np.finfo(x.dtype.as_numpy_dtype).tiny
                return tf.exp(
                    tf.cumsum(tf.log(tf.clip_by_value(x, tiny, 1)), *args,
                              **kwargs))

        # NOTE(review): with assertions disabled (python -O) and both flags
        # set, this chain assigns nothing and `reset` stays undefined.
        if exp_energy_cumsum and sigmoid_energy_cumprod:
            assert False, "Use only 1 among exp_energy_cumsum or sigmoid_energy_cumprod."
        elif sigmoid_energy_cumprod:
            sigmoid_energy = tf.sigmoid(energy)  #(enc-T,B,H)
            # No axis argument is passed, so tf.cumsum's default axis is used.
            reset = safe_cumprod(sigmoid_energy)  #(enc-T,B,H)
        elif exp_energy_cumsum:
            exp_energy = tf.exp(-energy)
            exp_energy_accum = tf.cumsum(
                exp_energy, axis=0)  #(enc-T,B,H), non-decreasing along time
            reset = 1. / (
                1. + exp_energy_accum
            )  #(enc-T,B,H), in 0~1, non-increasing along time
        else:
            reset = tf.sigmoid(energy)  #(enc-T,B,H), 0~1
        # Replace time_pads[0] leading and time_pads[1] trailing time frames
        # of x with the constant `value`.
        # NOTE(review): the list default is never mutated here, but a tuple
        # would be the safer default. tf.fill with a Python float presumably
        # yields float32, which would not concat with a non-float32 x --
        # confirm.
        def substitute(x, time_pads=[0, 0], value=0.):
            T = tf.shape(x)[0]
            x_left = tf.fill([time_pads[0], energy_shape[1], energy_shape[2]],
                             value)
            x_middle = x[time_pads[0]:T - time_pads[1], :, :]
            x_right = tf.fill([time_pads[1], energy_shape[1], energy_shape[2]],
                              value)
            return tf.concat([x_left, x_middle, x_right], axis=axis)

        if first_reset_value is not None:
            # Overwrite the first time frame of reset with first_reset_value.
            reset = substitute(reset,
                               time_pads=[1, 0],
                               value=first_reset_value)
        # Inference-only thresholding (train_flag is None or False).
        if ((network.train_flag is None or network.train_flag is False)
                and infer_threshold is not None):
            print("----------------------------------------------------------")
            print("--------------------INFER_simple_thresholding-------------")
            print("----------------------------------------------------------")
            # get_endpoint_compare_to is defined elsewhere; presumably it
            # marks the first frame where reset falls below infer_threshold
            # -- TODO confirm.
            low_threshold_point = get_endpoint_compare_to(
                reset, infer_threshold, "l", "first")
            # Reversed cumulative sum along time of the marker tensor, used
            # as a multiplicative gate on reset.
            before_low_threshold = tf.cumsum(low_threshold_point,
                                             axis=0,
                                             reverse=True)
            reset = before_low_threshold * reset
        # safe_cumprod computes cumprod in logspace with numeric checks
        cumprod_1mreset = safe_cumprod(1 - reset,
                                       axis=axis,
                                       exclusive=True,
                                       reverse=True)  #(enc-T,B,H)
        # Compute recurrence relation solution
        weights = reset * cumprod_1mreset
        ###
        if orig_time_axis == 1:  # restore the input layout: (enc-T,B,H)->(B,enc-T,H)
            weights = tf.transpose(weights, perm=(1, 0, 2))
        # NOTE(review): this reshape keeps the first two dims and forces the
        # last one to 1, so it only succeeds when H == 1 -- confirm intended.
        weights = tf.reshape(weights,
                             [tf.shape(weights)[0],
                              tf.shape(weights)[1], 1])
        self.output.placeholder = weights  # same axis order as the input, last dim 1
示例#6
0
    def __init__(self, sources, energy_factor=None, sigmoid_noise=1.0, seed=None,\
                 chunk_size=1, test_same_as_train=False,
                 use_time_mask=None, train_cumcalc_mode="recursive", **kwargs):
        """
        Compute monotonic (optionally chunkwise) attention weights.

        The energy is read from ``sources[0]`` and the attention weights of
        the previous step from ``sources[1]``. When ``chunk_size > 1`` the
        chunk energy is read from ``sources[2]``. All tensors are brought
        into time-major layout (enc-T,B,H) for the computation.

        ``tf`` and ``concat_sources`` are resolved from module scope, which
        is outside this view.

        :param list sources: ``[energy, previous_attention]`` plus the chunk
            energy as third entry when ``chunk_size > 1``
        :param float|None energy_factor: if truthy, energy and chunk energy
            are multiplied by it
        :param float sigmoid_noise: scale of the Gaussian noise added to the
            score while ``network.train_flag is not False``; values ``<= 0``
            add no noise
        :param int|None seed: seed forwarded to the random ops
        :param int|float|None chunk_size: values ``> 1`` enable the chunkwise
            part
        :param bool test_same_as_train: use the training-time formulas even
            when ``network.train_flag`` is False
        :param bool|None use_time_mask: mask padded frames with ``-inf``;
            ``None`` enables it exactly when the time axis is dynamic
        :param str train_cumcalc_mode: ``"recursive"`` (asserts: not
            implemented) or ``"parallel"``
        :param kwargs: forwarded to the base class constructor
        """
        assert sources
        super(MonotonicHardAttention2Layer, self).__init__(sources=sources,
                                                           **kwargs)
        from TFUtil import where_bc
        energy_data = concat_sources([self.sources[0]
                                      ])  # expected (enc-T,B,H); (B,enc-T,H) is transposed below
        assert energy_data.dtype.startswith("float")
        previous_attention_data = concat_sources([self.sources[1]
                                                  ])  #(enc-T,B,H)
        axis = 0  # time axis after the optional transposes
        energy = energy_data.placeholder  #(enc-T,B,H)

        chunk_energy = None
        if chunk_size is not None and isinstance(
                chunk_size, (int, float)) and chunk_size > 1:
            chunk_size = int(chunk_size)
            chunk_energy_data = concat_sources(
                [self.sources[2]])  #(enc-T,B,H)
            chunk_energy = chunk_energy_data.placeholder
        orig_time_axis = self._get_axis_to_reduce(input_data=energy_data,
                                                  axis="T",
                                                  exception_prefix=self)
        if orig_time_axis == 1:  # batch-major input: transpose (B,enc-T,H)->(enc-T,B,H)
            print(
                "original time axis of energy was 1, and is changed with 0-th axis! (to make (enc-T,B,H))"
            )
            energy = tf.transpose(energy, perm=(1, 0, 2))
            if chunk_energy is not None:
                print("  =>(did same for chunk_energy)")
                chunk_energy = tf.transpose(chunk_energy, perm=(1, 0, 2))

        previous_attention = previous_attention_data.placeholder  #(enc-T,B,H)
        orig_time_axis_prevatt = self._get_axis_to_reduce(
            input_data=previous_attention_data,
            axis="T",
            exception_prefix=self)
        if orig_time_axis_prevatt == 1:  # batch-major input: transpose (B,enc-T,H)->(enc-T,B,H)
            print(
                "original time axis of previous_attention was 1, and is changed with 0-th axis! (to make (enc-T,B,H))"
            )
            previous_attention = tf.transpose(previous_attention,
                                              perm=(1, 0, 2))

        # If the previous attention is all zeros, start from a one-hot on the
        # first encoder frame instead.
        energy_shape = tf.shape(energy)  # dynamic shape, (enc-T,B,H)
        init_ones = tf.ones([1, energy_shape[1], energy_shape[2]],
                            dtype=energy.dtype)  #(1,B,H)
        init_zeros = tf.zeros(
            [energy_shape[0] - 1, energy_shape[1], energy_shape[2]],
            dtype=energy.dtype)  #(enc-T - 1,B,H)
        init_attention = tf.concat([init_ones, init_zeros], axis=axis)
        # The lambdas are invoked while tf.cond builds the graph, i.e. before
        # the name `previous_attention` is rebound to the result.
        previous_attention = tf.cond(
            tf.equal(tf.reduce_sum(tf.abs(previous_attention)),
                     tf.constant(0., dtype=previous_attention.dtype)),
            true_fn=lambda: init_attention,
            false_fn=lambda: previous_attention,
        )

        from TFUtil import check_shape_equal
        assert energy_data.have_time_axis()
        assert previous_attention_data.have_time_axis()
        # if the time-axis is static, we can skip the masking
        if use_time_mask is None:
            use_time_mask = energy_data.is_axis_dynamic(orig_time_axis)
        if use_time_mask:
            assert energy_data.is_axis_dynamic(
                orig_time_axis
            ), "%s: use_time_mask True, dyn time axis expected" % self
            energy_mask = energy_data.get_sequence_mask_broadcast(
                axis=orig_time_axis)
            if orig_time_axis == 1:  # bring the mask into the same layout: (B,enc-T,H)->(enc-T,B,H)
                energy_mask = tf.transpose(energy_mask, perm=(1, 0, 2))
            # Positions outside the sequence get -inf energy.
            energy = where_bc(energy_mask,
                              energy,
                              float("-inf"),
                              name="energy_masked")
            if chunk_energy is not None:
                chunk_energy = where_bc(energy_mask,
                                        chunk_energy,
                                        float("-inf"),
                                        name="chunk_energy_masked")
        if energy_factor:
            energy = tf.multiply(energy, energy_factor, name="energy_scaled")
            if chunk_energy is not None:
                chunk_energy = tf.multiply(chunk_energy,
                                           energy_factor,
                                           name="chunk_energy_scaled")

        ### main part (https://github.com/NVIDIA/OpenSeq2Seq/blob/master/open_seq2seq/parts/rnns/attention_wrapper.py)
        # Add sigmoid noise only at training
        network = self.sources[0].network
        score = energy  #(enc-T,B,H)
        if network.train_flag is not False:
            print("----------------------------------------------------------")
            print("---------- NOW TRAIN TIME !!!!!!(sigmoide noise add)------")
            print("----------------------------------------------------------")
            if sigmoid_noise > 0:
                noise = tf.random.normal(tf.shape(score),
                                         dtype=score.dtype,
                                         seed=seed)
                # Must stay inside the guard: `noise` only exists when
                # sigmoid_noise > 0 (previously sigmoid_noise=0 raised
                # UnboundLocalError here).
                score += sigmoid_noise * noise
        # Calculate p_choose_i
        if (network.train_flag is not False or test_same_as_train):
            print("----------------------------------------------------------")
            print("---------- NOW TRAIN TIME !!!!!!(p=sigmoid(score))--------")
            print("----------------------------------------------------------")
            p_choose_i = tf.sigmoid(score)  #(enc-T,B,H)
        else:
            print("----------------------------------------------------------")
            print("---------- NOW TEST TIME !!!!!!(p=1(score>0))------------")
            print("----------------------------------------------------------")
            if True:
                # Hard decision: 1 where score > 0, else 0.
                p_choose_i = tf.cast(score > 0, score.dtype)  #(enc-T,B,H)
            else:  # sampling variant, disabled (not to be used)
                p_choose_i = tf.sigmoid(1. * score)  #(enc-T,B,H)
                z = tf.random.uniform(tf.shape(score),
                                      dtype=score.dtype,
                                      seed=seed)  #(enc-T,B,H)
                p_choose_i = tf.cast(p_choose_i > z, score.dtype)
        # Calculate weights
        if (network.train_flag is not False
                or test_same_as_train) and train_cumcalc_mode == "recursive":
            # The code after this assert is only reached when assertions are
            # disabled.
            assert False, "Recursive mode is not implemented yet."
            print("----------------------------------------------------------")
            print("---------------- NOW TRAIN TIME !!!!!!(recursive)---------")
            print("----------------------------------------------------------")
            # Use the static dims when they are known, or fall back on the symbolic shape
            batch_size = p_choose_i.shape[1].value or tf.shape(p_choose_i)[1]
            num_heads = p_choose_i.shape[2].value or tf.shape(p_choose_i)[2]
            # Compute [1, 1 - p_choose_i[0], 1 - p_choose_i[1], ..., 1 - p_choose_i[-2]]
            shifted_1mp_choose_i = tf.concat([
                tf.ones((1, batch_size, num_heads)), 1 - p_choose_i[:-1, :, :]
            ],
                                             axis=0)  #(enc-T,B,H)
            # Compute attention distribution recursively as
            # q[i] = (1 - p_choose_i[i])*q[i - 1] + previous_attention[i]
            # attention[i] = p_choose_i[i]*q[i]
            # (the transposes with perm=(0, 1, 2) are identity permutations)
            weights = p_choose_i * tf.transpose(
                tf.scan(
                    # Need to use reshape to remind TF of the shape between loop
                    # iterations
                    lambda x, yz: tf.reshape(yz[0] * x + yz[1],
                                             (batch_size, num_heads)),
                    # Loop variables yz[0] and yz[1]
                    [
                        # (enc-T,B,H)
                        tf.transpose(shifted_1mp_choose_i, perm=(0, 1, 2)),
                        tf.transpose(previous_attention, perm=(0, 1, 2))
                    ],
                    # Initial value of x is just zeros
                    tf.zeros((batch_size, num_heads)),  #(B,H)
                    swap_memory=True,
                    parallel_iterations=1,
                ),
                # (enc-T,B,H)
                perm=(0, 1, 2))
        elif (network.train_flag is not False
              or test_same_as_train) and train_cumcalc_mode == "parallel":
            print("----------------------------------------------------------")
            print("---------------- NOW TRAIN TIME !!!!!!(parallel)----------")
            print("----------------------------------------------------------")

            def safe_cumprod(x, *args, **kwargs):
                # Cumulative product computed as exp(cumsum(log(x))), with x
                # clipped to [tiny, 1] so that log never sees 0.
                with tf.name_scope(None, "SafeCumprod", [x]):
                    x = tf.convert_to_tensor(x, name="x")
                    import numpy as np
                    tiny = np.finfo(x.dtype.as_numpy_dtype).tiny
                    return tf.exp(
                        tf.cumsum(tf.log(tf.clip_by_value(x, tiny, 1)), *args,
                                  **kwargs))

            # safe_cumprod computes cumprod in logspace with numeric checks
            cumprod_1mp_choose_i = safe_cumprod(
                1 - p_choose_i, axis=axis,
                exclusive=True)  #(enc-T,B,H)
            # Compute recurrence relation solution
            weights = p_choose_i * cumprod_1mp_choose_i * tf.cumsum(
                previous_attention /
                # Clip cumprod_1mp to avoid divide-by-zero
                tf.clip_by_value(cumprod_1mp_choose_i, 1e-10, 1.),
                axis=axis)  #(enc-T,B,H)
        elif (network.train_flag is not False or test_same_as_train):
            assert False, "train_cumcalc_mode must be in [\"recursive\",\"parallel\"]"
        else:
            print("----------------------------------------------------------")
            print("---------------- NOW TEST TIME !!!!!!!--------------------")
            print("----------------------------------------------------------")
            ####### ORIG(openseq2seq) ##########
            # p_choose_i          : [0,0,1,1,1,1,0,0]
            # tf.cumsum(prev_att) : [0,0,1,1,1,1,1,1]
            # 1-p_choose_i        : [1,1,0,0,0,0,1,1]
            # tf.cumprod('')      : [1,1,1,0,0,0,0,0]
            #weights = p_choose_i * tf.cumsum(previous_attention, axis=axis) *\
            #          tf.cumprod(1 - p_choose_i, axis=axis, exclusive=True)  #(enc-T,B,H)
            ######## ADDED ########
            # Reduce the previous attention to a one-hot at its LAST non-zero
            # time position.
            # NOTE(review): float32 is hard-coded here while p_choose_i uses
            # score.dtype -- confirm that non-float32 energies are not expected.
            prev_att_existance = tf.cast(
                previous_attention > 0., dtype=tf.float32
            )  #e.g. [0, 0.1, 0.2, 0.8, 0] => [0, 1, 1, 1, 0]
            reverse_filter = tf.cumsum(
                prev_att_existance, axis=axis, exclusive=True,
                reverse=True)  #e.g. [0, 1, 1, 1, 0] => [3, 2, 1, 0, 0]
            reverse_filter_existance = tf.cast(
                reverse_filter > 0.,
                dtype=tf.float32)  #e.g. [3, 2, 1, 0, 0] => [1, 1, 1, 0, 0]
            filter_existance = 1 - reverse_filter_existance  #e.g. [1, 1, 1, 0, 0] => [0, 0, 0, 1, 1]
            previous_hard_attention = prev_att_existance * filter_existance  #e.g. [0, 1, 1, 1, 0] * [0, 0, 0, 1, 1] = [0, 0, 0, 1, 0]
            # p_choose_i          : [1,0,0,0,1,0,1,0]
            # prev_attention      : [0,0,1,0,0,0,0,0]
            # p_c *tf.cumsum('')  : [0,0,0,0,1,0,1,0]
            # tf.cumprod(1-(''))  : [1,1,1,1,1,0,0,0]
            # weights             : [0,0,0,0,1,0,0,0]
            cs_hard_attention = p_choose_i * tf.cumsum(previous_hard_attention,
                                                       axis=axis)  #(enc-T,B,H)
            weights = cs_hard_attention * tf.cumprod(
                1 - cs_hard_attention, axis=axis, exclusive=True)  #(enc-T,B,H)
            ##############################
        if isinstance(chunk_size, (int, float)) and chunk_size > 1:
            alpha = weights  #(enc-T,B,H); at test time at most one t_i along enc-T is 1, the rest 0
            alpha_shape = tf.shape(alpha)
            if (network.train_flag is not False or test_same_as_train):

                def moving_sum(x, b, f):  #x:(T,B,C)
                    # Windowed sum along time, implemented as a conv1d with a
                    # filter of ones that is diagonal over the channels;
                    # b and f set how many ones sit on either side of the
                    # filter centre (b - 1 before it, f - 1 after it).
                    x_shape = tf.shape(x)
                    x = tf.transpose(x, perm=(1, 0, 2))  #(T,B,C)->(B,T,C)
                    I = tf.expand_dims(
                        tf.eye(x_shape[2]), axis=0
                    )  #(1,C,C), no operation applied on head dimension
                    filt_half = max(b, f) - 1  #assume b,f are not tensors
                    filt = tf.concat(
                        [
                            tf.zeros([filt_half - (b - 1), 1, 1],
                                     dtype=x.dtype),
                            tf.ones([(b - 1) + (1) + (f - 1), 1, 1],
                                    dtype=x.dtype),
                            tf.zeros([filt_half - (f - 1), 1, 1],
                                     dtype=x.dtype)
                        ],
                        axis=0,
                    )
                    W = I * filt  #(2*max(b,f)-1, C, C)
                    return tf.transpose(
                        tf.nn.conv1d(x, W, stride=1, padding="SAME"),  #(B,T,C)
                        perm=(1, 0, 2)  #(B,T,C)->(T,B,C)
                    )  #zero-padding is enough, assuming that exp(u) comes in as input. (exp(-inf)==0)

                exp_u = tf.exp(chunk_energy)  #(enc-T,B,H)
                # 1e-6 guards the division against an all-zero window.
                beta = exp_u * moving_sum(
                    alpha / (moving_sum(exp_u, chunk_size, 1) + 1e-6), 1,
                    chunk_size)  #(enc-T,B,H)
            else:
                # Index of the attended frame per (batch, head).
                t = tf.argmax(alpha, axis=0)  #(B,H)
                # True outside the window [t - chunk_size + 1, t].
                chunk_energy_mask = tf.logical_or(
                    tf.sequence_mask(
                        t + 1 - chunk_size,
                        maxlen=alpha_shape[0],
                        name='chunk_energy_mask_pre'),  #(B,H,enc-T), bool
                    tf.logical_not(
                        tf.sequence_mask(t + 1,
                                         maxlen=alpha_shape[0],
                                         name='chunk_energy_mask_post')
                    )  #(B,H,enc-T), bool
                )
                # Additive mask: -inf outside the window, 0 inside.
                chunk_energy_mask = tf.where(
                    tf.transpose(chunk_energy_mask,
                                 perm=(2, 0, 1)),  #(B,H,enc-T) => (enc-T,B,H)
                    x=tf.ones(alpha_shape, dtype=tf.float32) * float('-inf'),
                    y=tf.zeros(alpha_shape, dtype=tf.float32),
                )
                # softmax over (t_i-chunk_size+1,t_i)
                chunk_energy += chunk_energy_mask  #(enc-T,B,H)
                # Where alpha is all zero along time, output zeros instead of
                # the softmax.
                beta = tf.where(
                    tf.ones_like(alpha) *
                    tf.reduce_sum(tf.abs(alpha), axis=0, keepdims=True) >
                    0.,  #(enc-T,B,H)
                    x=tf.nn.softmax(chunk_energy, axis=0),
                    y=tf.zeros_like(chunk_energy))
            weights = beta
        ############################################
        if orig_time_axis == 1:  # restore the input layout: (enc-T,B,H)->(B,enc-T,H)
            weights = tf.transpose(weights, perm=(1, 0, 2))
        # NOTE(review): this reshape keeps the first two dims and forces the
        # last one to 1, so it only succeeds when H == 1 -- confirm intended.
        weights = tf.reshape(weights,
                             [tf.shape(weights)[0],
                              tf.shape(weights)[1], 1])
        self.output.placeholder = weights  # same axis order as the energy input, last dim 1