# Example #1
def conv1d_sd(input, filters, image_shape, filter_shape, border_mode='valid',
              subsample=(1,), filter_flip=True):
    """
    Strided 'valid' 1D convolution computed using a single dot product.

    The input is zero-padded (or truncated) and cut into
    ``filter_length // stride`` shifted views. Each view is reshaped into
    consecutive, non-overlapping windows of ``filter_length`` samples. The
    views are stacked and contracted with the filters in one ``tensordot``
    call, after which the per-shift results are interleaved into the output.

    ``T`` and ``np`` are expected to be provided at module level
    (presumably ``theano.tensor`` and ``numpy`` -- confirm against the
    file's imports).

    Parameters
    ----------
    input
        Tensor indexed as (batch, channel, position).
    filters
        Tensor indexed as (filter, channel, position).
    image_shape
        ``(batch_size, num_input_channels, input_length)``; the values are
        used in plain Python arithmetic.
    filter_shape
        ``(num_filters, num_input_channels, filter_length)``.
    border_mode
        Only ``'valid'``, ``0`` or ``(0,)`` are accepted.
    subsample
        Sequence whose first element is the stride. The filter length must
        be a multiple of the stride.
    filter_flip
        If true, the filters are reversed along their last axis before the
        dot product.

    Returns
    -------
    Tensor of shape ``(batch_size, num_filters, output_length)`` with
    ``output_length = (input_length - filter_length + stride) // stride``.

    Raises
    ------
    RuntimeError
        If ``border_mode`` is unsupported or the filter length is not a
        multiple of the stride.
    """
    if border_mode not in ('valid', 0, (0,)):
        # Wrap in a tuple so that tuple-valued border modes are formatted
        # as a whole instead of being unpacked as format arguments.
        raise RuntimeError("Unsupported border_mode for conv1d_sd: "
                           "%s" % (border_mode,))

    batch_size, num_input_channels, input_length = image_shape
    num_filters, num_input_channels_, filter_length = filter_shape
    stride = subsample[0]

    if filter_length % stride > 0:
        raise RuntimeError("Filter length (%d) is not a multiple of the "
                           "stride (%d)" % (filter_length, stride))

    num_steps = filter_length // stride
    output_length = (input_length - filter_length + stride) // stride

    # pad the input so all the shifted dot products fit inside.
    # shape is (b, c, l)
    padded_length = ((input_length // filter_length) * filter_length +
                     (num_steps - 1) * stride)

    # at this point, it is possible that the padded_length is SMALLER than the
    # input size. so then we have to truncate first.
    truncated_length = min(input_length, padded_length)
    input_truncated = input[:, :, :truncated_length]

    input_padded_shape = (batch_size, num_input_channels, padded_length)
    input_padded = T.zeros(input_padded_shape)
    input_padded = T.set_subtensor(input_padded[:, :, :truncated_length],
                                   input_truncated)

    inputs = []
    for num in range(num_steps):
        shift = num * stride
        # With the padding above this evaluates to
        # input_length // filter_length for every shift, so all views have
        # the same shape and can be stacked.
        length = (padded_length - shift) // filter_length

        r_input_shape = (batch_size, num_input_channels, length, filter_length)
        r_input = input_padded[
            :, :, shift:length * filter_length + shift].reshape(r_input_shape)

        inputs.append(r_input)

    inputs_stacked = T.stack(*inputs)  # shape is (n, b, c, w, f)
    filters_flipped = filters[:, :, ::-1] if filter_flip else filters

    # contract the channel and within-window axes against the filters
    r_conved = T.tensordot(inputs_stacked, filters_flipped,
                           np.asarray([[2, 4], [1, 2]]))
    # resulting shape is (n, b, w, n_filters)
    # output needs to be (b, n_filters, w * n)
    r_conved = r_conved.dimshuffle(1, 3, 2, 0)  # (b, n_filters, w, n)
    # merging (w, n) interleaves the shifts: position = window * n + shift
    conved = r_conved.reshape((r_conved.shape[0], r_conved.shape[1],
                               r_conved.shape[2] * r_conved.shape[3]))
    # result is (b, n_f, l)

    # remove padding
    return conved[:, :, :output_length]
# Example #2
def conv1d_md(input, filters, image_shape, filter_shape, border_mode='valid',
              subsample=(1,), filter_flip=True):
    """
    Strided 'valid' 1D convolution computed using multiple dot products.

    For each of the ``filter_length // stride`` shifts, the input is cut
    into consecutive, non-overlapping windows of ``filter_length`` samples
    starting at that shift. Each set of windows is contracted with the
    filters in its own ``tensordot`` call, and the result is written into
    every ``num_steps``-th position of the output.

    ``T`` and ``np`` are expected to be provided at module level
    (presumably ``theano.tensor`` and ``numpy`` -- confirm against the
    file's imports).

    Parameters
    ----------
    input
        Tensor indexed as (batch, channel, position).
    filters
        Tensor indexed as (filter, channel, position).
    image_shape
        ``(batch_size, num_input_channels, input_length)``; the values are
        used in plain Python arithmetic.
    filter_shape
        ``(num_filters, num_input_channels, filter_length)``.
    border_mode
        Only ``'valid'``, ``0`` or ``(0,)`` are accepted.
    subsample
        Sequence whose first element is the stride. The filter length must
        be a multiple of the stride.
    filter_flip
        If true, the filters are reversed along their last axis before the
        dot products.

    Returns
    -------
    Tensor of shape ``(batch_size, num_filters, output_length)`` with
    ``output_length = (input_length - filter_length + stride) // stride``.

    Raises
    ------
    RuntimeError
        If ``border_mode`` is unsupported or the filter length is not a
        multiple of the stride.
    """
    if border_mode not in ('valid', 0, (0,)):
        # Wrap in a tuple so that tuple-valued border modes are formatted
        # as a whole instead of being unpacked as format arguments.
        raise RuntimeError("Unsupported border_mode for conv1d_md: "
                           "%s" % (border_mode,))

    batch_size, num_input_channels, input_length = image_shape
    num_filters, num_input_channels_, filter_length = filter_shape
    stride = subsample[0]

    if filter_length % stride > 0:
        raise RuntimeError("Filter length (%d) is not a multiple of the "
                           "stride (%d)" % (filter_length, stride))

    num_steps = filter_length // stride
    output_length = (input_length - filter_length + stride) // stride
    output_shape = (batch_size, num_filters, output_length)

    filters_flipped = filters[:, :, ::-1] if filter_flip else filters

    conved = T.zeros(output_shape)

    for num in range(num_steps):
        shift = num * stride
        # number of whole filter-sized windows that fit after this shift
        length = (input_length - shift) // filter_length

        if length == 0:
            # we can safely skip this product, it doesn't contribute to the
            # final convolution.
            continue

        r_input_shape = (batch_size, num_input_channels, length, filter_length)
        r_input = input[
            :, :, shift:length * filter_length + shift].reshape(r_input_shape)

        # shape (b, l, n_filters)
        r_conved = T.tensordot(r_input, filters_flipped,
                               np.asarray([[1, 3], [1, 2]]))
        r_conved = r_conved.dimshuffle(0, 2, 1)  # shape is (b, n_filters, l)
        # this shift produces output positions num, num + num_steps, ...
        conved = T.set_subtensor(conved[:, :, num::num_steps], r_conved)

    return conved