예제 #1
0
파일: prednet.py 프로젝트: ecreagar/Prednet
    def get_initial_state(self, x):
        """Return all-zero initial states for every recurrent unit.

        The zeros are derived from ``x`` (``zeros_like`` followed by a dot
        product with an all-zero matrix), so the batch dimension of every
        state follows the runtime batch size of ``x``.

        Args:
            x: The layer input; the inline comments below assume it is
                ``(samples, timesteps) + image_shape``.

        Returns:
            A list of zero tensors ordered as: the ``'r'`` states (one per
            layer), then the ``'c'`` states, then the ``'e'`` states. When
            ``self.extrap_start_time`` is set, these are followed by a single
            layer-0 ``'ahat'`` state and finally an integer timestep variable
            initialised to 0. Each state for layer ``l`` has spatial size
            ``input_size // 2**l`` and is laid out per ``self.data_format``.
        """
        input_shape = self.input_spec[0].shape
        init_nb_row = input_shape[self.row_axis]
        init_nb_col = input_shape[self.column_axis]

        # Collapse a zero copy of ``x`` to (samples, nb_channels): sum away
        # the two spatial axes, then the time axis.
        base_initial_state = K.zeros_like(
            x)  # (samples, timesteps) + image_shape
        non_channel_axis = -1 if self.data_format == 'channels_first' else -2
        for _ in range(2):
            base_initial_state = K.sum(base_initial_state,
                                       axis=non_channel_axis)
        base_initial_state = K.sum(base_initial_state,
                                   axis=1)  # (samples, nb_channels)

        initial_states = []
        states_to_pass = ['r', 'c', 'e']
        nlayers_to_pass = {u: self.nb_layers for u in states_to_pass}
        if self.extrap_start_time is not None:
            # Pass the prediction in the states so it can be used as the
            # actual frame at t+1 when extrapolating (only layer 0 is needed).
            states_to_pass.append('ahat')
            nlayers_to_pass['ahat'] = 1
        for u in states_to_pass:
            for l in range(nlayers_to_pass[u]):
                # States of layer ``l`` are downsampled by 2**l per spatial axis.
                ds_factor = 2**l
                nb_row = init_nb_row // ds_factor
                nb_col = init_nb_col // ds_factor
                if u in ['r', 'c']:
                    stack_size = self.R_stack_sizes[l]
                elif u == 'e':
                    stack_size = 2 * self.stack_sizes[l]
                elif u == 'ahat':
                    stack_size = self.stack_sizes[l]
                output_size = stack_size * nb_row * nb_col  # flattened size

                # Multiplying by an all-zero matrix turns the
                # (samples, nb_channels) zeros into (samples, output_size) zeros.
                reducer = K.zeros((input_shape[self.channel_axis],
                                   output_size))  # (nb_channels, output_size)
                initial_state = K.dot(base_initial_state,
                                      reducer)  # (samples, output_size)
                if self.data_format == 'channels_first':
                    output_shp = (-1, stack_size, nb_row, nb_col)
                else:
                    output_shp = (-1, nb_row, nb_col, stack_size)
                initial_states.append(K.reshape(initial_state, output_shp))

        # Query the backend through the public ``K.backend()`` (as done below
        # for the timestep variable) instead of the private ``K._BACKEND``
        # attribute, which is not part of the public API.
        if K.backend() == 'theano':
            from theano import tensor as T
            # There is a known issue in the Theano scan op when dealing with
            # inputs whose shape is 1 along a dimension. In our case this is a
            # problem when training on grayscale images; unbroadcasting axes
            # 0 and 1 works around it.
            initial_states = [
                T.unbroadcast(init_state, 0, 1)
                for init_state in initial_states
            ]

        if self.extrap_start_time is not None:
            # The last state corresponds to the current timestep.
            initial_states.append(
                K.variable(0, int if K.backend() != 'tensorflow' else 'int32'))
        return initial_states
예제 #2
0
def dot_product(x, kernel):
    """Backend-agnostic dot product of ``x`` with ``kernel``.

    Wraps ``K.dot`` so the same call works under both Theano and TensorFlow.

    Args:
        x: Input tensor.
        kernel: Weights (presumably a 1-D vector — confirm at call sites).

    Returns:
        ``K.dot(x, kernel)`` on non-TensorFlow backends. On TensorFlow, the
        kernel is given a trailing axis with ``K.expand_dims``, multiplied, and
        that trailing axis is squeezed off the result again.
    """
    if K.backend() != 'tensorflow':
        return K.dot(x, kernel)
    # TensorFlow path: promote the kernel to a column, multiply, then drop the
    # extra trailing axis (presumably to match the shape of the non-TF path).
    kernel_column = K.expand_dims(kernel)
    product = K.dot(x, kernel_column)
    return K.squeeze(product, axis=-1)
예제 #3
0
 def step(self, inputs, states):
     """Run the wrapped model for a single timestep.

     Args:
         inputs: The input for this timestep (ignored when ``self.decode``).
         states: The recurrent states. When ``self.teacher_force`` is set, the
             tail is ``[..., counter, ground_truth, readout]``, where
             ``counter`` is an int32 tensor whose first element is the current
             step index.

     Returns:
         ``(output, new_states)``. When ``self.readout`` is set, ``output`` is
         appended to the states. With teacher forcing, the tail becomes
         ``[..., counter + 1, ground_truth, output]`` so the next step sees
         the same layout.
     """
     states = list(states)
     if self.teacher_force:
         readout = states.pop()
         ground_truth = states.pop()
         assert K.ndim(ground_truth) == 3, K.ndim(ground_truth)
         counter = states.pop()
         if K.backend() == 'tensorflow':
             # Create the constants outside any enclosing control-dependency
             # context.
             with tf.control_dependencies(None):
                 zero = K.cast(K.zeros((1, ))[0], 'int32')
                 one = K.cast(K.ones((1, ))[0], 'int32')
         else:
             zero = K.cast(K.zeros((1, ))[0], 'int32')
             one = K.cast(K.ones((1, ))[0], 'int32')
         # Index of the previous timestep: counter - 1, clamped to 0 on the
         # first step (where the switch below keeps the given readout anyway).
         prev_index = counter[0] - K.switch(counter[0], one, zero)
         # A tuple (not a list) so every backend reads it as a multi-axis
         # index rather than as advanced indexing.
         slices = (slice(None), prev_index) + (slice(None), ) * (
             K.ndim(ground_truth) - 2)
         ground_truth_slice = ground_truth[slices]
         # In the training phase, feed ground_truth[:, t - 1] for t > 0;
         # otherwise keep the model's own readout.
         readout = K.in_train_phase(
             K.switch(counter[0], ground_truth_slice, readout), readout)
         states.append(readout)
     if self.decode:
         model_input = states
     else:
         model_input = [inputs] + states
     # Temporarily strip ``_keras_shape`` (else Keras internals get messed up).
     # Remember which tensor each shape came from so it is restored onto that
     # same tensor, even if the model call raises.
     saved_shapes = []
     for x in model_input:
         if hasattr(x, '_keras_shape'):
             saved_shapes.append((x, x._keras_shape))
             del x._keras_shape
     try:
         model_output = _to_list(self.model.call(model_input))
     finally:
         for x, shape in saved_shapes:
             x._keras_shape = shape
     if self.decode:
         model_output.insert(1, model_input[0])
     for tensor in model_output:
         tensor._uses_learning_phase = self.uses_learning_phase
     states = model_output[1:]
     output = model_output[0]
     if self.readout:
         states += [output]
         if self.teacher_force:
             states.insert(-1, counter + 1)
             states.insert(-1, ground_truth)
     return output, states
예제 #4
0
    def call(self,
             inputs,
             initial_state=None,
             initial_readout=None,
             ground_truth=None,
             mask=None,
             training=None):
        """Run the wrapped model over the time axis of ``inputs``.

        Args:
            inputs: A tensor, or a list laid out as ``[inputs] +
                initial_states + [initial_readout (if readout)] + ... +
                [ground_truth (if teacher_force, last)]``. Optional entries
                may be optional-input placeholders, which are replaced with
                defaults.
            initial_state: Initial state tensor or list of tensors; used only
                when ``inputs`` is not a list.
            initial_readout: Initial readout; a zero readout shaped like the
                model's first output is built when missing.
            ground_truth: Required when ``self.teacher_force`` is set.
            mask: Mask tensor (or a list whose first element is used).
            training: Forwarded to ``K.in_train_phase`` when the layer uses
                the learning phase.

        Returns:
            The full output sequence if ``self.return_sequences`` else the
            last output. When ``self.return_states`` is set, a list
            ``[y] + final_states`` is returned instead.

        Raises:
            Exception: If the layer has no model, or teacher forcing is on and
                no ground truth was given.
            ValueError: If the number of initial states is wrong, or
                unrolling is requested with an undefined time dimension.
        """
        # input shape: `(samples, time (padded with zeros), input_dim)`
        # note that the .build() method of subclasses MUST define
        # self.input_spec and self.state_spec with complete input shapes.
        if type(mask) is list:
            mask = mask[0]
        if self.model is None:
            raise Exception('Empty RecurrentModel.')
        num_req_states = self.num_states
        # With readout, the last of ``num_states`` is the readout, which is
        # supplied separately from the model's own states.
        if self.readout:
            num_actual_states = num_req_states - 1
        else:
            num_actual_states = num_req_states
        if type(inputs) is list:
            inputs_list = inputs[:]
            inputs = inputs_list.pop(0)
            initial_states = inputs_list[:num_actual_states]
            if len(initial_states) > 0:
                if self._is_optional_input_placeholder(initial_states[0]):
                    initial_states = self.get_initial_state(inputs)
            inputs_list = inputs_list[num_actual_states:]
            if self.readout:
                initial_readout = inputs_list.pop(0)
                if self.teacher_force:
                    ground_truth = inputs_list.pop()
        else:
            if initial_state is not None:
                if not isinstance(initial_state, (list, tuple)):
                    initial_states = [initial_state]
                else:
                    initial_states = list(initial_state)
                # Guard the empty case (as the list-input branch does) so an
                # empty sequence reaches the state-count check below instead
                # of raising IndexError here.
                if initial_states and self._is_optional_input_placeholder(
                        initial_states[0]):
                    initial_states = self.get_initial_state(inputs)

            elif self.stateful:
                # Copy: this list is appended/inserted into below, which must
                # not mutate the layer's persistent ``self.states``.
                initial_states = list(self.states)
            else:
                initial_states = self.get_initial_state(inputs)
        if self.readout:
            if initial_readout is None or self._is_optional_input_placeholder(
                    initial_readout):
                # Build a zero readout shaped like the model's first output,
                # taking the batch size from ``inputs`` at runtime.
                output_shape = K.int_shape(_to_list((self.model.output))[0])
                output_ndim = len(output_shape)
                input_ndim = K.ndim(inputs)
                initial_readout = K.zeros_like(inputs)
                # Tuple index so every backend treats it as per-axis indexing.
                slices = (slice(None), ) + (0, ) * (input_ndim - 1)
                initial_readout = initial_readout[slices]  # (batch_size,)
                initial_readout = K.reshape(initial_readout,
                                            (-1, ) + (1, ) * (output_ndim - 1))
                initial_readout = K.tile(initial_readout,
                                         (1, ) + tuple(output_shape[1:]))
            initial_states.append(initial_readout)
            if self.teacher_force:
                if ground_truth is None or self._is_optional_input_placeholder(
                        ground_truth):
                    raise Exception(
                        'ground_truth must be provided for RecurrentModel with teacher_force=True.'
                    )
                if K.backend() == 'tensorflow':
                    with tf.control_dependencies(None):
                        counter = K.zeros((1, ))
                else:
                    counter = K.zeros((1, ))
                counter = K.cast(counter, 'int32')
                # Resulting tail: [..., counter, ground_truth, readout], the
                # layout the step function pops from.
                initial_states.insert(-1, counter)
                initial_states.insert(-1, ground_truth)
                num_req_states += 2
        if len(initial_states) != num_req_states:
            raise ValueError('Layer requires ' + str(num_req_states) +
                             ' states but was passed ' +
                             str(len(initial_states)) + ' initial states.')
        input_shape = K.int_shape(inputs)
        if self.unroll and input_shape[1] is None:
            raise ValueError('Cannot unroll a RNN if the '
                             'time dimension is undefined. \n'
                             '- If using a Sequential model, '
                             'specify the time dimension by passing '
                             'an `input_shape` or `batch_input_shape` '
                             'argument to your first layer. If your '
                             'first layer is an Embedding, you can '
                             'also use the `input_length` argument.\n'
                             '- If using the functional API, specify '
                             'the time dimension by passing a `shape` '
                             'or `batch_shape` argument to your Input layer.')
        # NOTE(review): ``training`` is not forwarded here (None is passed
        # explicitly) — confirm this is intended.
        preprocessed_input = self.preprocess_input(inputs, training=None)
        constants = self.get_constants(inputs, training=None)
        if self.decode:
            # In decode mode the input is carried as the first state and the
            # scan runs over a dummy sequence of length ``output_length``.
            initial_states.insert(0, inputs)
            preprocessed_input = K.zeros((1, self.output_length, 1))
            input_length = self.output_length
        else:
            input_length = input_shape[1]
        if self.uses_learning_phase:
            # Build the recurrence once per learning phase and select between
            # the two graphs with ``K.in_train_phase``.
            with learning_phase_scope(0):
                last_output_test, outputs_test, states_test, updates = rnn(
                    self.step,
                    preprocessed_input,
                    initial_states,
                    go_backwards=self.go_backwards,
                    mask=mask,
                    constants=constants,
                    unroll=self.unroll,
                    input_length=input_length)
            # NOTE(review): ``updates`` from the test-phase pass above are
            # overwritten here; only the training-phase updates are registered.
            with learning_phase_scope(1):
                last_output_train, outputs_train, states_train, updates = rnn(
                    self.step,
                    preprocessed_input,
                    initial_states,
                    go_backwards=self.go_backwards,
                    mask=mask,
                    constants=constants,
                    unroll=self.unroll,
                    input_length=input_length)

            last_output = K.in_train_phase(last_output_train,
                                           last_output_test,
                                           training=training)
            outputs = K.in_train_phase(outputs_train,
                                       outputs_test,
                                       training=training)
            states = []
            for state_train, state_test in zip(states_train, states_test):
                states.append(
                    K.in_train_phase(state_train,
                                     state_test,
                                     training=training))

        else:
            last_output, outputs, states, updates = rnn(
                self.step,
                preprocessed_input,
                initial_states,
                go_backwards=self.go_backwards,
                mask=mask,
                constants=constants,
                unroll=self.unroll,
                input_length=input_length)
        # Strip the bookkeeping entries added above so only model states remain.
        states = list(states)
        if self.decode:
            states.pop(0)
        if self.readout:
            states.pop()
            if self.teacher_force:
                states.pop()
                states.pop()
        if len(updates) > 0:
            self.add_update(updates)
        if self.stateful:
            updates = []
            for i in range(len(states)):
                updates.append((self.states[i], states[i]))
            self.add_update(updates, inputs)

        # Properly set learning phase
        if 0 < self.dropout + self.recurrent_dropout:
            last_output._uses_learning_phase = True
            outputs._uses_learning_phase = True

        if self.return_sequences:
            y = outputs
        else:
            y = last_output
        if self.return_states:
            return [y] + states
        else:
            return y
예제 #5
0
from keras.layers import *
from keras.models import Model
from keras import initializers
from keras.backend import tensorflow_backend as K
from K import rnn, learning_phase_scope
from .generic_utils import serialize_function, deserialize_function
from keras.engine.base_layer import Node, _collect_previous_mask, _collect_input_shape
import inspect
import tensorflow as tf

if K.backend() == 'tensorflow':
    import tensorflow as tf


def _to_list(x):
    if type(x) is not list:
        x = [x]
    return x


class _OptionalInputPlaceHolder(Layer):
    """Layer owning a constant zero tensor of shape ``(2,)``.

    The tensor is tagged with Keras bookkeeping attributes (``_keras_shape``,
    ``_uses_learning_phase``, ``_keras_history``) so it can be passed where a
    Keras tensor is expected. Presumably it marks an optional input that was
    not supplied; ``RecurrentModel.call`` tests for such inputs via
    ``_is_optional_input_placeholder`` — confirm.
    """

    def __init__(self, name=None, **kwargs):
        if not name:
            prefix = 'optional_input_placeholder'
            name = prefix + '_' + str(K.get_uid(prefix))
        # Forward the (caller-supplied or generated) name to ``Layer`` so it
        # is not silently discarded.
        kwargs['name'] = name
        kwargs['batch_input_shape'] = (2, )
        super(_OptionalInputPlaceHolder, self).__init__(**kwargs)
        self.tensor = K.zeros(shape=(2, ))
        self.tensor._keras_shape = (2, )
        self.tensor._uses_learning_phase = False
        # Point the tensor's history back at this layer.
        self.tensor._keras_history = (self, 0, 0)
예제 #6
0
파일: prednet.py 프로젝트: ecreagar/Prednet
    def build(self, input_shape):
        """Create and build the per-layer convolutions and declare the states.

        For each of the ``self.nb_layers`` layers this creates:
        - ``'i'``, ``'f'``, ``'c'``, ``'o'`` gate convolutions with
          ``R_stack_sizes[l]`` filters;
        - an ``'ahat'`` convolution with ``stack_sizes[l]`` filters;
        - for every layer except the last, an ``'a'`` convolution with
          ``stack_sizes[l + 1]`` filters.

        Every convolution is then built with an input shape derived from
        ``input_shape``, and its trainable weights are collected into
        ``self.trainable_weights``. The layers are visited in sorted key
        order, then in layer order.

        Note: the order in which the ``Conv2D`` layers are constructed and
        built determines their auto-generated names and the order of
        ``self.trainable_weights``; keep it stable.

        Args:
            input_shape: Full input shape; the last three axes are image axes
                laid out per ``self.data_format``.
        """
        self.input_spec = [InputSpec(shape=input_shape)]
        self.conv_layers = {c: [] for c in ['i', 'f', 'c', 'o', 'a', 'ahat']}

        for l in range(self.nb_layers):
            # Gate convolutions: the candidate ('c') uses LSTM_activation, the
            # other gates use LSTM_inner_activation.
            for c in ['i', 'f', 'c', 'o']:
                act = self.LSTM_activation if c == 'c' else self.LSTM_inner_activation
                self.conv_layers[c].append(
                    Conv2D(self.R_stack_sizes[l],
                           self.R_filt_sizes[l],
                           padding='same',
                           activation=act,
                           data_format=self.data_format))

            # Prediction convolution: ReLU on layer 0, A_activation above.
            act = 'relu' if l == 0 else self.A_activation
            self.conv_layers['ahat'].append(
                Conv2D(self.stack_sizes[l],
                       self.Ahat_filt_sizes[l],
                       padding='same',
                       activation=act,
                       data_format=self.data_format))

            # The top layer has no 'a' convolution (it feeds no layer above).
            if l < self.nb_layers - 1:
                self.conv_layers['a'].append(
                    Conv2D(self.stack_sizes[l + 1],
                           self.A_filt_sizes[l],
                           padding='same',
                           activation=self.A_activation,
                           data_format=self.data_format))

        self.upsample = UpSampling2D(data_format=self.data_format)
        self.pool = MaxPooling2D(data_format=self.data_format)

        self.trainable_weights = []
        # Spatial size of the input frames, depending on the channel layout.
        nb_row, nb_col = (
            input_shape[-2],
            input_shape[-1]) if self.data_format == 'channels_first' else (
                input_shape[-3], input_shape[-2])
        for c in sorted(self.conv_layers.keys()):
            for l in range(len(self.conv_layers[c])):
                # Layer l operates at 1 / 2**l of the input resolution.
                ds_factor = 2**l
                if c == 'ahat':
                    # 'ahat' reads a tensor with R_stack_sizes[l] channels.
                    nb_channels = self.R_stack_sizes[l]
                elif c == 'a':
                    # 'a' reads a tensor with 2 * stack_sizes[l] channels
                    # (the size used for the 'e' state).
                    nb_channels = 2 * self.stack_sizes[l]
                else:
                    # Gates read 2 * stack_sizes[l] + R_stack_sizes[l]
                    # channels, plus R_stack_sizes[l + 1] from the layer above
                    # when one exists (presumably its upsampled R — confirm).
                    nb_channels = self.stack_sizes[l] * 2 + self.R_stack_sizes[
                        l]
                    if l < self.nb_layers - 1:
                        nb_channels += self.R_stack_sizes[l + 1]
                in_shape = (input_shape[0], nb_channels, nb_row // ds_factor,
                            nb_col // ds_factor)
                if self.data_format == 'channels_last':
                    in_shape = (in_shape[0], in_shape[2], in_shape[3],
                                in_shape[1])
                with K.name_scope('layer_' + c + '_' + str(l)):
                    self.conv_layers[c][l].build(in_shape)
                self.trainable_weights += self.conv_layers[c][
                    l].trainable_weights

        # Three recurrent states per layer.
        self.states = [None] * self.nb_layers * 3

        if self.extrap_start_time is not None:
            # Integer variable holding the timestep at which extrapolation
            # starts.
            self.t_extrap = K.variable(
                self.extrap_start_time,
                int if K.backend() != 'tensorflow' else 'int32')
            self.states += [None] * 2  # [previous frame prediction, timestep]