Code Example #1
# Keras 2.x import paths are assumed throughout.
import inspect

import numpy as np

from keras import backend as K
from keras import initializers
from keras.engine import InputSpec
from keras.layers import Input
from keras.layers.recurrent import Recurrent
from keras.models import Model

# NOTE: helper utilities referenced below (_to_list, _collect_previous_mask,
# _is_all_none, _collect_input_shape, rnn, learning_phase_scope, _get_cells,
# _OptionalInputPlaceHolder, RNNCellFromModel) are internals of the
# surrounding library and are assumed to be in scope; they are not part of
# the public Keras API.

class RecurrentModel(Recurrent):

    # INITIALIZATION

    def __init__(self, input, output, initial_states=None, final_states=None,
                 readout_input=None, teacher_force=False, decode=False,
                 output_length=None, return_states=False,
                 state_initializer=None, **kwargs):
        inputs = [input]
        outputs = [output]
        state_spec = None
        if initial_states is not None:
            if not isinstance(initial_states, (list, tuple)):
                initial_states = [initial_states]
            state_spec = [InputSpec(shape=K.int_shape(state)) for state in initial_states]
            if final_states is None:
                raise Exception('Missing argument: final_states')
            self.states = [None] * len(initial_states)
            inputs += initial_states
        else:
            self.states = []
            state_spec = []

        if final_states is not None:
            if not isinstance(final_states, (list, tuple)):
                final_states = [final_states]
            # Check for a missing initial_states before taking len() of it,
            # which would otherwise fail with a TypeError on None.
            if initial_states is None:
                raise Exception('Missing argument: initial_states')
            assert len(initial_states) == len(final_states), 'initial_states and final_states should have the same number of tensors.'
            outputs += final_states
        self.decode = decode
        self.output_length = output_length
        if decode:
            if output_length is None:
                raise Exception('output_length should be specified for decoder')
            kwargs['return_sequences'] = True
        self.return_states = return_states
        if readout_input is not None:
            self.readout = True
            state_spec += [Input(batch_shape=K.int_shape(outputs[0]))]
            self.states += [None]
            inputs += [readout_input]
        else:
            self.readout = False
        if teacher_force and not self.readout:
            raise Exception('Readout should be enabled for teacher forcing.')
        self.teacher_force = teacher_force
        self.model = Model(inputs, outputs)
        super(RecurrentModel, self).__init__(**kwargs)
        input_shape = list(K.int_shape(input))
        if not decode:
            input_shape.insert(1, None)
        self.input_spec = InputSpec(shape=tuple(input_shape))
        self.state_spec = state_spec
        self._optional_input_placeholders = {}
        if state_initializer:
            if not isinstance(state_initializer, (list, tuple)):
                state_initializer = [state_initializer] * self.num_states
            else:
                state_initializer += [None] * (self.num_states - len(state_initializer))
            state_initializer = [initializers.get(init) if init else initializers.get('zeros') for init in state_initializer]
        self.state_initializer = state_initializer

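    # A minimal construction sketch (assuming the Keras 2.x functional API;
    # the layer sizes and tensor names below are illustrative, not part of
    # this class):
    #
    #   x_t = Input((5,))       # input at time t
    #   h_tm1 = Input((10,))    # hidden state at time t - 1
    #   h_t = Activation('tanh')(add([Dense(10)(x_t),
    #                                 Dense(10, use_bias=False)(h_tm1)]))
    #   rnn = RecurrentModel(input=x_t, output=h_t,
    #                        initial_states=[h_tm1], final_states=[h_t])
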
    def build(self, input_shape):
        if type(input_shape) is list:
            input_shape = input_shape[0]
        if not self.decode:
            input_length = input_shape[1]
            if input_length is not None:
                input_shape = list(self.input_spec.shape)
                input_shape[1] = input_length
                input_shape = tuple(input_shape)
                self.input_spec = InputSpec(shape=input_shape)
        if type(self.model.input) is list:
            model_input_shape = self.model.input_shape[0]
        else:
            model_input_shape = self.model.input_shape
        if not self.decode:
            input_shape = input_shape[:1] + input_shape[2:]
        for i, j in zip(input_shape, model_input_shape):
            if i is not None and j is not None and i != j:
                raise Exception('Model expected input with shape ' + str(model_input_shape) +
                                '. Received input with shape ' + str(input_shape))
        if self.stateful:
            self.reset_states()
        self.built = True

    # STATES

    @property
    def num_states(self):
        model_input = self.model.input
        if type(model_input) is list:
            return len(model_input[1:])
        else:
            return 0

    def get_initial_state(self, inputs):
        if type(self.model.input) is not list:
            return []
        state_shapes = list(map(K.int_shape, self.model.input[1:]))
        states = []
        if self.readout:
            state_shapes.pop()
            # default value for initial_readout is handled in call()
        for shape in state_shapes:
            if None in shape[1:]:
                raise Exception('Only the batch dimension of a state can be left unspecified. Got state with shape ' + str(shape))
            if shape[0] is None:
                ndim = K.ndim(inputs)
                z = K.zeros_like(inputs)
                slices = [slice(None)] + [0] * (ndim - 1)
                z = z[slices]  # (batch_size,)
                state_ndim = len(shape)
                z = K.reshape(z, (-1,) + (1,) * (state_ndim - 1))
                z = K.tile(z, (1,) + tuple(shape[1:]))
                states.append(z)
            else:
                states.append(K.zeros(shape))
        state_initializer = self.state_initializer
        if state_initializer:
            # some initializers don't accept symbolic shapes
            for i in range(len(state_shapes)):
                if state_shapes[i][0] is None:
                    if hasattr(self, 'batch_size'):
                        state_shapes[i] = (self.batch_size,) + state_shapes[i][1:]
                if None in state_shapes[i]:
                    state_shapes[i] = K.shape(states[i])
            num_state_init = len(state_initializer)
            num_state = self.num_states
            assert num_state_init == num_state, 'RNN has ' + str(num_state) + ' states, but was provided ' + str(num_state_init) + ' state initializers.'
            for i in range(len(states)):
                init = state_initializer[i]
                shape = state_shapes[i]
                try:
                    if not isinstance(init, initializers.Zeros):
                        states[i] = init(shape)
                except Exception:
                    raise Exception('Seems the initializer ' + init.__class__.__name__ + ' does not support symbolic shapes (' + str(shape) + '). Try providing the full input shape (including the batch dimension) for your RecurrentModel.')
        return states
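
    # state_initializer accepts a single Keras initializer identifier or a
    # list with one entry per state; a sketch of the intended usage
    # (identifier names assume the standard Keras initializers):
    #
    #   rnn = RecurrentModel(input=x_t, output=h_t,
    #                        initial_states=[h_tm1], final_states=[h_t],
    #                        state_initializer='random_normal')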

    def reset_states(self, states_value=None):
        if len(self.states) == 0:
            return
        if not self.stateful:
            raise AttributeError('Layer must be stateful.')
        if self.states[0] is None:
            state_shapes = list(map(K.int_shape, self.model.input[1:]))
            self.states = list(map(K.zeros, state_shapes))

        if states_value is not None:
            if type(states_value) not in (list, tuple):
                states_value = [states_value] * len(self.states)
            assert len(states_value) == len(self.states), 'Your RNN has ' + str(len(self.states)) + ' states, but was provided ' + str(len(states_value)) + ' state values.'
            if not isinstance(states_value[0], np.ndarray):
                states_value = list(map(np.array, states_value))
            if states_value[0].shape == tuple():
                for state, val in zip(self.states, states_value):
                    K.set_value(state, K.get_value(state) * 0. + val)
            else:
                for state, val in zip(self.states, states_value):
                    K.set_value(state, val)
        else:
            if self.state_initializer:
                for state, init in zip(self.states, self.state_initializer):
                    if isinstance(init, initializers.Zeros):
                        K.set_value(state, 0 * K.get_value(state))
                    else:
                        K.set_value(state, K.eval(init(K.get_value(state).shape)))
            else:
                for state in self.states:
                    K.set_value(state, 0 * K.get_value(state))

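    # Usage sketch for a stateful layer (hypothetical values; assumes the
    # layer was built with a fixed batch size, e.g. 32, and 10 state units):
    #
    #   rnn.reset_states()                      # re-initialize every state
    #   rnn.reset_states(0.5)                   # broadcast a scalar value
    #   rnn.reset_states([np.zeros((32, 10))])  # one array per state
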
    # EXECUTION

    def __call__(self, inputs, initial_state=None, initial_readout=None, ground_truth=None, **kwargs):
        req_num_inputs = 1 + self.num_states
        inputs = _to_list(inputs)
        inputs = inputs[:]
        if len(inputs) == 1:
            if initial_state is not None:
                if type(initial_state) is list:
                    inputs += initial_state
                else:
                    inputs.append(initial_state)
            else:
                if self.readout:
                    initial_state = self._get_optional_input_placeholder('initial_state', self.num_states - 1)
                else:
                    initial_state = self._get_optional_input_placeholder('initial_state', self.num_states)
                inputs += _to_list(initial_state)
            if self.readout:
                if initial_readout is None:
                    initial_readout = self._get_optional_input_placeholder('initial_readout')
                inputs.append(initial_readout)
            if self.teacher_force:
                req_num_inputs += 1
                if ground_truth is None:
                    ground_truth = self._get_optional_input_placeholder('ground_truth')
                inputs.append(ground_truth)
        assert len(inputs) == req_num_inputs, "Required " + str(req_num_inputs) + " inputs, received " + str(len(inputs)) + "."
        with K.name_scope(self.name):
            if not self.built:
                self.build(K.int_shape(inputs[0]))
                if self._initial_weights is not None:
                    self.set_weights(self._initial_weights)
                    self._initial_weights = None
            previous_mask = _collect_previous_mask(inputs[:1])
            user_kwargs = kwargs.copy()
            if not _is_all_none(previous_mask):
                if 'mask' in inspect.getfullargspec(self.call).args:
                    if 'mask' not in kwargs:
                        kwargs['mask'] = previous_mask
            input_shape = _collect_input_shape(inputs)
            output = self.call(inputs, **kwargs)
            output_mask = self.compute_mask(inputs[0], previous_mask)
            output_shape = self.compute_output_shape(input_shape[0])
            self._add_inbound_node(input_tensors=inputs, output_tensors=output,
                                   input_masks=previous_mask, output_masks=output_mask,
                                   input_shapes=input_shape, output_shapes=output_shape,
                                   arguments=user_kwargs)
            if hasattr(self, 'activity_regularizer') and self.activity_regularizer is not None:
                regularization_losses = [self.activity_regularizer(x) for x in _to_list(output)]
                self.add_loss(regularization_losses, _to_list(inputs))
        return output
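
    # Calling sketch (tensor names are illustrative): the optional keyword
    # arguments become extra inputs to the wrapped Model when provided:
    #
    #   y = rnn(x)                         # states default to zeros
    #   y = rnn(x, initial_state=h0)       # explicit initial state(s)
    #   y = rnn(x, initial_readout=r0,     # readout + teacher forcing
    #           ground_truth=targets)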

    def call(self, inputs, initial_state=None, initial_readout=None, ground_truth=None, mask=None, training=None):
        # input shape: `(samples, time (padded with zeros), input_dim)`
        # note that the .build() method of subclasses MUST define
        # self.input_spec and self.state_spec with complete input shapes.
        if type(mask) is list:
            mask = mask[0]
        if self.model is None:
            raise Exception('Empty RecurrentModel.')
        num_req_states = self.num_states
        if self.readout:
            num_actual_states = num_req_states - 1
        else:
            num_actual_states = num_req_states
        if type(inputs) is list:
            inputs_list = inputs[:]
            inputs = inputs_list.pop(0)
            initial_states = inputs_list[:num_actual_states]
            if len(initial_states) > 0:
                if self._is_optional_input_placeholder(initial_states[0]):
                    initial_states = self.get_initial_state(inputs)
            inputs_list = inputs_list[num_actual_states:]
            if self.readout:
                initial_readout = inputs_list.pop(0)
                if self.teacher_force:
                    ground_truth = inputs_list.pop()
        else:
            if initial_state is not None:
                if not isinstance(initial_state, (list, tuple)):
                    initial_states = [initial_state]
                else:
                    initial_states = list(initial_state)
                if self._is_optional_input_placeholder(initial_states[0]):
                    initial_states = self.get_initial_state(inputs)

            elif self.stateful:
                initial_states = self.states
            else:
                initial_states = self.get_initial_state(inputs)
        if self.readout:
            if initial_readout is None or self._is_optional_input_placeholder(initial_readout):
                output_shape = K.int_shape(_to_list(self.model.output)[0])
                output_ndim = len(output_shape)
                input_ndim = K.ndim(inputs)
                initial_readout = K.zeros_like(inputs)
                slices = [slice(None)] + [0] * (input_ndim - 1)
                initial_readout = initial_readout[slices]  # (batch_size,)
                initial_readout = K.reshape(initial_readout, (-1,) + (1,) * (output_ndim - 1))
                initial_readout = K.tile(initial_readout, (1,) + tuple(output_shape[1:]))
            initial_states.append(initial_readout)
            if self.teacher_force:
                if ground_truth is None or self._is_optional_input_placeholder(ground_truth):
                    raise Exception('ground_truth must be provided for RecurrentModel with teacher_force=True.')
                # Step counter for teacher forcing; created as float zeros
                # and cast, which works across backends.
                counter = K.cast(K.zeros((1,)), 'int32')
                initial_states.insert(-1, counter)
                initial_states.insert(-1, ground_truth)
                num_req_states += 2
        if len(initial_states) != num_req_states:
            raise ValueError('Layer requires ' + str(num_req_states) +
                             ' states but was passed ' +
                             str(len(initial_states)) +
                             ' initial states.')
        input_shape = K.int_shape(inputs)
        if self.unroll and input_shape[1] is None:
            raise ValueError('Cannot unroll a RNN if the '
                             'time dimension is undefined. \n'
                             '- If using a Sequential model, '
                             'specify the time dimension by passing '
                             'an `input_shape` or `batch_input_shape` '
                             'argument to your first layer. If your '
                             'first layer is an Embedding, you can '
                             'also use the `input_length` argument.\n'
                             '- If using the functional API, specify '
                             'the time dimension by passing a `shape` '
                             'or `batch_shape` argument to your Input layer.')
        preprocessed_input = self.preprocess_input(inputs, training=None)
        constants = self.get_constants(inputs, training=None)
        if self.decode:
            initial_states.insert(0, inputs)
            preprocessed_input = K.zeros((1, self.output_length, 1))
            input_length = self.output_length
        else:
            input_length = input_shape[1]
        if self.uses_learning_phase:
            with learning_phase_scope(0):
                last_output_test, outputs_test, states_test, updates = rnn(self.step,
                                                                           preprocessed_input,
                                                                           initial_states,
                                                                           go_backwards=self.go_backwards,
                                                                           mask=mask,
                                                                           constants=constants,
                                                                           unroll=self.unroll,
                                                                           input_length=input_length)
            with learning_phase_scope(1):
                last_output_train, outputs_train, states_train, updates = rnn(self.step,
                                                                              preprocessed_input,
                                                                              initial_states,
                                                                              go_backwards=self.go_backwards,
                                                                              mask=mask,
                                                                              constants=constants,
                                                                              unroll=self.unroll,
                                                                              input_length=input_length)

            last_output = K.in_train_phase(last_output_train, last_output_test, training=training)
            outputs = K.in_train_phase(outputs_train, outputs_test, training=training)
            states = []
            for state_train, state_test in zip(states_train, states_test):
                states.append(K.in_train_phase(state_train, state_test, training=training))

        else:
            last_output, outputs, states, updates = rnn(self.step,
                                                        preprocessed_input,
                                                        initial_states,
                                                        go_backwards=self.go_backwards,
                                                        mask=mask,
                                                        constants=constants,
                                                        unroll=self.unroll,
                                                        input_length=input_length)
        states = list(states)
        if self.decode:
            states.pop(0)
        if self.readout:
            states.pop()  # drop the readout state
            if self.teacher_force:
                states.pop()  # drop the ground truth
                states.pop()  # drop the step counter
        if len(updates) > 0:
            self.add_update(updates)
        if self.stateful:
            updates = []
            for i in range(len(states)):
                updates.append((self.states[i], states[i]))
            self.add_update(updates, inputs)

        # Properly set learning phase
        if 0 < self.dropout + self.recurrent_dropout:
            last_output._uses_learning_phase = True
            outputs._uses_learning_phase = True

        if self.return_sequences:
            y = outputs
        else:
            y = last_output
        if self.return_states:
            return [y] + states
        else:
            return y

    def step(self, inputs, states):
        states = list(states)
        if self.teacher_force:
            readout = states.pop()
            ground_truth = states.pop()
            assert K.ndim(ground_truth) == 3, 'ground_truth must be 3D; got ndim ' + str(K.ndim(ground_truth))
            counter = states.pop()
            zero = K.cast(K.zeros((1,))[0], 'int32')
            one = K.cast(K.ones((1,))[0], 'int32')
            # Select the previous time step of the ground truth, clamping the
            # index at 0 while the counter is still zero.
            slices = [slice(None), counter[0] - K.switch(counter[0], one, zero)] + [slice(None)] * (K.ndim(ground_truth) - 2)
            ground_truth_slice = ground_truth[slices]
            readout = K.in_train_phase(K.switch(counter[0], ground_truth_slice, readout), readout)
            states.append(readout)
        if self.decode:
            model_input = states
        else:
            model_input = [inputs] + states
        shapes = []
        for x in model_input:
            if hasattr(x, '_keras_shape'):
                shapes.append(x._keras_shape)
                del x._keras_shape  # Else keras internals will get messed up.
        model_output = _to_list(self.model.call(model_input))
        for x, s in zip(model_input, shapes):
            setattr(x, '_keras_shape', s)
        if self.decode:
            model_output.insert(1, model_input[0])
        for tensor in model_output:
            tensor._uses_learning_phase = self.uses_learning_phase
        states = model_output[1:]
        output = model_output[0]
        if self.readout:
            states += [output]
            if self.teacher_force:
                states.insert(-1, counter + 1)
                states.insert(-1, ground_truth)
        return output, states
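
    # For reference, the trailing entries of `states` inside step() are
    # [..., counter, ground_truth, readout] when teacher forcing is enabled
    # and [..., readout] when only readout is enabled; step() pops them off,
    # runs the wrapped Model, and pushes the updated values back on.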

    # SHAPE, MASK, WEIGHTS

    def compute_output_shape(self, input_shape):
        if not self.decode:
            if type(input_shape) is list:
                input_shape[0] = self._remove_time_dim(input_shape[0])
            else:
                input_shape = self._remove_time_dim(input_shape)
        input_shape = _to_list(input_shape)
        # _to_list guards the case of a single (non-list) model input.
        input_shape = [input_shape[0]] + [K.int_shape(state) for state in _to_list(self.model.input)[1:]]
        output_shape = self.model.compute_output_shape(input_shape)
        if type(output_shape) is list:
            output_shape = output_shape[0]
        if self.return_sequences:
            if self.decode:
                output_shape = output_shape[:1] + (self.output_length,) + output_shape[1:]
            else:
                output_shape = output_shape[:1] + (self.input_spec.shape[1],) + output_shape[1:]
        if self.return_states and len(self.states) > 0:
            output_shape = [output_shape] + list(map(K.int_shape, self.model.output[1:]))
        return output_shape

    def compute_mask(self, input, input_mask=None):
        mask = input_mask[0] if type(input_mask) is list else input_mask
        mask = mask if self.return_sequences else None
        mask = [mask] + [None] * len(self.states) if self.return_states else mask
        return mask

    def set_weights(self, weights):
        self.model.set_weights(weights)

    def get_weights(self):
        return self.model.get_weights()

    # LAYER ATTRIBS

    @property
    def updates(self):
        return self.model.updates

    def add_update(self, updates, inputs=None):
        self.model.add_update(updates, inputs)

    @property
    def uses_learning_phase(self):
        return self.teacher_force or self.model.uses_learning_phase

    @property
    def _per_input_losses(self):
        if hasattr(self, 'model'):
            return getattr(self.model, '_per_input_losses', {})
        else:
            return {}

    @_per_input_losses.setter
    def _per_input_losses(self, val):
        if hasattr(self, 'model'):
            self.model._per_input_losses = val

    @property
    def losses(self):
        if hasattr(self, 'model'):
            return self.model.losses
        else:
            return []

    @losses.setter
    def losses(self, val):
        if hasattr(self, 'model'):
            self.model.losses = val

    def add_loss(self, losses, inputs=None):
        self.model.add_loss(losses, inputs)

    @property
    def constraints(self):
        return self.model.constraints

    @property
    def trainable_weights(self):
        return self.model.trainable_weights

    @property
    def non_trainable_weights(self):
        return self.model.non_trainable_weights

    def get_losses_for(self, inputs):
        return self.model.get_losses_for(inputs)

    def get_updates_for(self, inputs):
        return self.model.get_updates_for(inputs)

    def _remove_time_dim(self, shape):
        return shape[:1] + shape[2:]

    # SERIALIZATION

    def _serialize_state_initializer(self):
        si = self.state_initializer
        if si is None:
            return None
        elif type(si) is list:
            return list(map(initializers.serialize, si))
        else:
            return initializers.serialize(si)

    def get_config(self):
        config = {'model_config': self.model.get_config(),
                  'decode': self.decode,
                  'output_length': self.output_length,
                  'return_states': self.return_states,
                  'state_initializer': self._serialize_state_initializer()}
        base_config = super(RecurrentModel, self).get_config()
        config.update(base_config)
        return config

    @classmethod
    def from_config(cls, config, custom_objects=None):
        if custom_objects is None:
            custom_objects = {}
        elif type(custom_objects) is list:
            custom_objects = {obj.__name__: obj for obj in custom_objects}
        else:
            custom_objects = dict(custom_objects)  # don't mutate the caller's dict
        custom_objects.update(_get_cells())
        config = config.copy()
        model_config = config.pop('model_config')
        if model_config is None:
            # model=None would crash on model.input below; fail early instead.
            raise Exception('Cannot restore RecurrentModel: model_config is None.')
        model = Model.from_config(model_config, custom_objects)
        if type(model.input) is list:
            input = model.input[0]
            initial_states = model.input[1:]
        else:
            input = model.input
            initial_states = None
        if type(model.output) is list:
            output = model.output[0]
            final_states = model.output[1:]
        else:
            output = model.output
            final_states = None
        return cls(input, output, initial_states, final_states, **config)
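
    # Serialization round-trip sketch (assuming standard Keras config
    # semantics; weights are not part of the config):
    #
    #   config = rnn.get_config()
    #   rnn2 = RecurrentModel.from_config(config)
    #   rnn2.set_weights(rnn.get_weights())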

    def get_cell(self, **kwargs):
        return RNNCellFromModel(self.model, **kwargs)

    def _get_optional_input_placeholder(self, name=None, num=1):
        if name:
            if name not in self._optional_input_placeholders:
                if num > 1:
                    self._optional_input_placeholders[name] = [self._get_optional_input_placeholder() for _ in range(num)]
                else:
                    self._optional_input_placeholders[name] = self._get_optional_input_placeholder()
            return self._optional_input_placeholders[name]
        if num == 1:
            optional_input_placeholder = _to_list(_OptionalInputPlaceHolder()._inbound_nodes[0].output_tensors)[0]
            assert self._is_optional_input_placeholder(optional_input_placeholder)
            return optional_input_placeholder
        else:
            y = []
            for _ in range(num):
                optional_input_placeholder = _to_list(_OptionalInputPlaceHolder()._inbound_nodes[0].output_tensors)[0]
                assert self._is_optional_input_placeholder(optional_input_placeholder)
                y.append(optional_input_placeholder)
            return y

    def _is_optional_input_placeholder(self, x):
        if hasattr(x, '_keras_history'):
            if isinstance(x._keras_history[0], _OptionalInputPlaceHolder):
                return True
        return False
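

# A minimal end-to-end sketch in the style of the library's README, assuming
# the Keras 2.x functional API; shapes and layer sizes are illustrative:
if __name__ == '__main__':
    from keras.layers import Activation, Dense, add

    x_t = Input((5,))      # input to the RNN at time t
    h_tm1 = Input((10,))   # hidden state at time t - 1
    h_t = Activation('tanh')(add([Dense(10)(x_t),
                                  Dense(10, use_bias=False)(h_tm1)]))
    rnn = RecurrentModel(input=x_t, output=h_t,
                         initial_states=[h_tm1], final_states=[h_t])

    # Run the RNN over a batch of 32 random sequences of length 7.
    x = Input((7, 5))
    model = Model(x, rnn(x))
    print(model.predict(np.random.random((32, 7, 5))).shape)  # -> (32, 10)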