    def build(self,
              input_values,
              input_shape_visualisation,
              hparams,
              name='diff_plasticity'):
        """Initializes the model parameters.

    Args:
        input_values: Tensor containing input
        input_shape_visualisation: The shape of the input, for display (internal is vectorized)
        hparams: The hyperparameters for the model as tf.contrib.training.HParams.
        name: A globally unique graph name used as a prefix for all tensors and ops.
    """
        self._name = name
        self._hparams = hparams
        self._dual = DualData(self._name)
        self._summary_op = None
        self._summary_values = None
        self._get_weights_op = None
        self._input_shape_visualisation = input_shape_visualisation
        self._input_values = input_values

        length_sparse = int(self._hparams.filters *
                            self._hparams.input_sparsity)
        self._active_bits = int(
            self._hparams.filters - length_sparse
        )  # do it this way to match the int rounding in generator
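        # For illustration (hypothetical values): filters=200 and
        # input_sparsity=0.1 give length_sparse=20 and active_bits=180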

        self._build()
Example #2
  def build(self, hparams, name='dg_stub'):
    """Builds the DG Stub."""
    self._name = name
    self._hparams = hparams
    self._dual = DualData(self._name)

    batch_arr = self._dg_stub_batch()

    the_one_batch = tf.convert_to_tensor(batch_arr, dtype=tf.float32)
    self._dual.set_op('encoding', the_one_batch)

    # Add a stub of the secondary decoding as well; it is expected by the workflow
    self._dual.add('secondary_decoding_input', shape=the_one_batch.shape,
                   default_value=1.0).add_pl(default=True)
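    # (add_pl(default=True) is assumed to expose this dual entry as a
    # placeholder-with-default, so the workflow can feed it or fall back to
    # the 1.0 default_value)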
Example #3
    def build(self,
              input_values,
              input_shape,
              hparams,
              name='deep_ae',
              input_cue_raw=None,
              encoding_shape=None):
        """Initializes the model parameters.

    Args:
        input_values: Tensor containing input
        input_shape: The shape of the input, for display (internal is vectorized)
        hparams: The hyperparameters for the model as tf.contrib.training.HParams.
        name: A globally unique graph name used as a prefix for all tensors and ops.
        input_cue_raw: Optional secondary input tensor; when provided and
            pm_raw_type is not 'none', the raw pattern mapping (PM) path is built.
        encoding_shape: The shape to be used to display encoded (hidden layer) structures
    """
        self._name = name
        self._hidden_name = 'hidden'
        self._hparams = hparams
        self._dual = DualData(self._name)
        self._summary_training_op = None
        self._summary_encoding_op = None
        self._summary_values = None
        self._weights = None

        self._input_shape = input_shape
        self._input_values = input_values
        self._encoding_shape = encoding_shape

        if self._encoding_shape is None:
            self._encoding_shape = self._create_encoding_shape_4d(
                input_shape)  # .as_list()

        if self._hparams.pm_raw_type != 'none' and input_cue_raw is not None:
            self._use_pm_raw = True
            self._input_cue_raw = input_cue_raw

        self._batch_type = None

        with tf.variable_scope(self._name, reuse=tf.AUTO_REUSE):
            # 1) Build the deep autoencoder
            self._build()

            # 2) build Pattern Mapping
            # ---------------------------------------------------
            if self._use_pm or self._use_pm_raw:
                self._build_pm()

            self.reset()
Example #4
class DualComponent(Component):
    """
  A Component that uses a DualData object to manage on-graph/off-graph data.
  It also has a unique name() property.
  """
    def __init__(self, name=None):
        super().__init__()

        self._dual = DualData(name)
        # self._name = name  # Maybe discovered after instantiation time

    @property
    def name(self):
        return self._dual.get_root_name()

    @name.setter
    def name(self, name):
        self._dual.set_root_name(name)

    def get_dual(self):
        return self._dual

    def get_op(self, key):
        return self._dual.get_op(key)

    def get_shape(self, key):
        return self._dual.get(key).get_shape()

    def get_values(self, key):
        return self._dual.get_values(key)
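

# A minimal usage sketch (hypothetical subclass, not from the source): a
# DualComponent subclass only needs to populate its DualData with ops, after
# which the accessors above (get_op / get_values / get_shape) work unchanged.
class CounterComponent(DualComponent):
    def build(self):
        # Register a graph op under a key (assumes tf is imported, as in the
        # other examples)
        self._dual.set_op('count', tf.constant(0, dtype=tf.int32))

# counter = CounterComponent(name='counter')
# counter.build()
# count_op = counter.get_op('count')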
Example #5
class DeepAutoencoderComponent(AutoencoderComponent):
    """Deep Autoencoder with untied weights (and untied biases)."""
    @staticmethod
    def default_hparams():
        """Builds an HParam object with default hyperparameters."""
        return tf.contrib.training.HParams(
            learning_rate=0.005,
            loss_type='mse',
            num_layers=3,
            nonlinearity=['relu', 'relu', 'relu'],
            output_nonlinearity='sigmoid',
            batch_size=64,
            filters=[128, 64, 32],
            pm_type='none',  # Not relevant; use pm_raw_type
            pm_raw_type='none',  # 'none' or 'nn': map stable PC patterns back to VC input (image space)
            pm_l1_size=100,  # hidden layer of PM path (pattern mapping)
            pm_raw_hidden_size=[100],
            pm_raw_l2_regularizer=0.0,
            pm_raw_nonlinearity='leaky_relu',
            pm_noise_type='s',  # 's' for salt, 'sp' for salt + pepper
            pm_train_with_noise=0.0,
            pm_train_with_noise_pp=0.0,
            pm_train_dropout_input_keep_prob=1.0,
            pm_train_dropout_hidden_keep_prob=[1.0],
            optimizer='adam',
            momentum=0.9,
            momentum_nesterov=False,
            summarize_level=SummarizeLevels.ALL.value,
            max_outputs=3  # Number of outputs in TensorBoard
        )
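
    # A usage sketch (assumed workflow, not from the source): the defaults above
    # can be overridden per experiment before build() is called, using the
    # standard tf.contrib.training.HParams API.
    #
    #   hparams = DeepAutoencoderComponent.default_hparams()
    #   hparams.set_hparam('learning_rate', 0.001)
    #   hparams.set_hparam('filters', [256, 128, 64])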

    def __init__(self):
        super().__init__()

        self._use_pm = False
        self._use_pm_raw = False

    @property
    def use_input_cue(self):
        return False

    @property
    def use_pm(self):
        return self._use_pm

    @property
    def use_pm_raw(self):
        return self._use_pm_raw

    def use_nn_in_pr_path(self):
        return True

    def get_ec_out_raw_op(self):
        return None

    def variables_networks(self, outer_scope):
        vars_nets = []

        # Selectively include/exclude optimizer parameters
        optim_ae = False
        optim_pm_raw = True

        vars_nets += self._variables_encoder(outer_scope)
        vars_nets += self._variables_decoder(outer_scope)

        if optim_ae:
            vars_nets += self._variables_ae_optimizer(outer_scope)

        if self.use_pm_raw:
            vars_nets += self._variables_pm_raw(outer_scope)
            if optim_pm_raw:
                vars_nets += self._variables_pm_raw_optimizer(outer_scope)

        return vars_nets

    def get_input(self, batch_type):
        """Unlike Hopfield, a standard component only has the one input, so return that regardless of batch type."""
        del batch_type
        return self.get_inputs()

    def get_losses_pm(self, default=0):
        loss = self._dual.get_values('pm_loss')
        loss_raw = self._dual.get_values('pm_loss_raw')

        if loss is None:
            loss = default

        if loss_raw is None:
            loss_raw = default

        return loss, loss_raw

    def build(self,
              input_values,
              input_shape,
              hparams,
              name='deep_ae',
              input_cue_raw=None,
              encoding_shape=None):
        """Initializes the model parameters.

    Args:
        input_values: Tensor containing input
        input_shape: The shape of the input, for display (internal is vectorized)
        hparams: The hyperparameters for the model as tf.contrib.training.HParams.
        name: A globally unique graph name used as a prefix for all tensors and ops.
        input_cue_raw: Optional secondary input tensor; when provided and
            pm_raw_type is not 'none', the raw pattern mapping (PM) path is built.
        encoding_shape: The shape to be used to display encoded (hidden layer) structures
    """
        self._name = name
        self._hidden_name = 'hidden'
        self._hparams = hparams
        self._dual = DualData(self._name)
        self._summary_training_op = None
        self._summary_encoding_op = None
        self._summary_values = None
        self._weights = None

        self._input_shape = input_shape
        self._input_values = input_values
        self._encoding_shape = encoding_shape

        if self._encoding_shape is None:
            self._encoding_shape = self._create_encoding_shape_4d(
                input_shape)  # .as_list()

        if self._hparams.pm_raw_type != 'none' and input_cue_raw is not None:
            self._use_pm_raw = True
            self._input_cue_raw = input_cue_raw

        self._batch_type = None

        with tf.variable_scope(self._name, reuse=tf.AUTO_REUSE):
            # 1) Build the deep autoencoder
            self._build()

            # 2) build Pattern Mapping
            # ---------------------------------------------------
            if self._use_pm or self._use_pm_raw:
                self._build_pm()

            self.reset()

    def _create_encoding_shape_4d(self, input_shape):  # pylint: disable=W0613
        """Put it into convolutional geometry: [batches, filter h, filter w, filters]"""
        return [self._hparams.batch_size, 1, 1, self._hparams.filters[-1]]

    def _build_optimizer(self, loss_op, training_op_name, scope=None):
        """Minimise loss using initialised a tf.train.Optimizer."""

        logging.info(
            "-----------> Adding optimiser for op {0}".format(loss_op))

        if scope is not None:
            scope = 'optimizer/' + str(scope)
        else:
            scope = 'optimizer'

        with tf.variable_scope(scope):
            optimizer = self._setup_optimizer()
            training = optimizer.minimize(
                loss_op, global_step=tf.train.get_or_create_global_step())

            self._dual.set_op(training_op_name, training)

    def _setup_optimizer(self):
        """Initialise the Optimizer class specified by a hyperparameter."""
        if self._hparams.optimizer == 'adam':
            optimizer = tf.train.AdamOptimizer(self._hparams.learning_rate)
        elif self._hparams.optimizer == 'momentum':
            optimizer = tf.train.MomentumOptimizer(
                self._hparams.learning_rate,
                self._hparams.momentum,
                use_nesterov=self._hparams.momentum_nesterov)
        elif self._hparams.optimizer == 'sgd':
            optimizer = tf.train.GradientDescentOptimizer(
                self._hparams.learning_rate)
        else:
            raise NotImplementedError('Optimizer not implemented: ' +
                                      str(self._hparams.optimizer))

        return optimizer

    def _build(self):
        """Build the autoencoder network"""

        self._batch_type = tf.placeholder_with_default(input='training',
                                                       shape=[],
                                                       name='batch_type')

        self._dual.set_op('inputs', self._input_values)
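
        # NOTE: the dense encoder/decoder layers below are commented out; as it
        # stands, _build only registers the batch-type placeholder and the
        # 'inputs' op.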

        # input_shape = self._input_values.get_shape().as_list()
        # output_shape = np.prod(input_shape[1:])

        # kernel_initializer = build_kernel_initializer('xavier')

        # assert self._hparams.num_layers == len(self._hparams.filters)
        # assert self._hparams.num_layers == len(self._hparams.nonlinearity)

        # with tf.variable_scope('encoder'):
        #   encoder_output = tf.layers.flatten(self._input_values)

        #   print('input', encoder_output)
        #   for i in range(self._hparams.num_layers):
        #     encoder_output = tf.layers.dense(encoder_output, self._hparams.filters[i],
        #                                      activation=type_activation_fn(self._hparams.nonlinearity[i]),
        #                                      kernel_initializer=kernel_initializer)
        #     print('encoder', i, encoder_output)

        #   self._dual.set_op('encoding', tf.stop_gradient(encoder_output))

        # with tf.variable_scope('decoder'):
        #   decoder_output = encoder_output
        #   decoder_filters = self._hparams.filters[:-1][::-1]  # Remove last filter (bottleneck), reverse filters
        #   decoder_filters += [output_shape]

        #   decoder_nonlinearity = self._hparams.nonlinearity[:-1][::-1]
        #   decoder_nonlinearity += [self._hparams.output_nonlinearity]

        #   for i in range(self._hparams.num_layers):
        #     decoder_output = tf.layers.dense(decoder_output, decoder_filters[i],
        #                                      activation=type_activation_fn(decoder_nonlinearity[i]),
        #                                      kernel_initializer=kernel_initializer)
        #     print('decoder', i, decoder_output)

        #   output = tf.reshape(decoder_output, [-1] + input_shape[1:])
        #   print('output', output)

        #   self._dual.set_op('decoding', output)
        #   self._dual.set_op('output', tf.stop_gradient(output))

        # # Build loss and optimizer
        # loss = self._build_loss_fn(self._input_values, output)
        # self._dual.set_op('loss', loss)
        # self._build_optimizer(loss, 'training', scope='ae')

    def _build_pm(self):
        """Preprocess the inputs and build the pattern mapping components."""
        def normalize(x):
            return (x - tf.reduce_min(x)) / (tf.reduce_max(x) -
                                             tf.reduce_min(x))

        # map to input
        # x_nn = self._dual.get_op('output')  # output of Deep AE
        x_nn = self._dual.get_op('inputs')
        x_nn = tf.layers.flatten(x_nn)

        x_nn = normalize(x_nn)

        # Apply noise during training, to regularise / test generalisation
        # --------------------------------------------------------------------------
        if self._hparams.pm_noise_type == 's':  # salt noise
            x_nn = tf.cond(
                tf.equal(self._batch_type, 'training'),
                lambda: image_utils.add_image_salt_noise_flat(
                    x_nn,
                    None,
                    noise_val=self._hparams.pm_train_with_noise,
                    noise_factor=self._hparams.pm_train_with_noise_pp,
                    mode='replace'), lambda: x_nn)

        elif self._hparams.pm_noise_type == 'sp':  # salt + pepper noise
            # Inspired by denoising AE.
            # Add salt+pepper noise to mimic missing/extra bits in PC space.
            # Use a fairly high rate of noising to mitigate few training iters.
            x_nn = tf.cond(
                tf.equal(self._batch_type, 'training'),
                lambda: image_utils.add_image_salt_pepper_noise_flat(
                    x_nn,
                    None,
                    salt_val=self._hparams.pm_train_with_noise,
                    pepper_val=-self._hparams.pm_train_with_noise,
                    noise_factor=self._hparams.pm_train_with_noise_pp),
                lambda: x_nn)

        else:
            raise NotImplementedError('PM noise type not supported: ' +
                                      str(self._hparams.pm_noise_type))

        # apply dropout during training
        keep_prob = self._hparams.pm_train_dropout_input_keep_prob
        x_nn = tf.cond(tf.equal(self._batch_type, 'training'),
                       lambda: tf.nn.dropout(x_nn, keep_prob), lambda: x_nn)

        # Build PM
        # --------------------------------------------------------------------------
        if self.use_pm:
            ec_in = self._input_cue
            output_nonlinearity = type_activation_fn('leaky_relu')
            ec_out = self._build_pm_core(x=x_nn,
                                         target=ec_in,
                                         # list-wrapped: _build_pm_core iterates over hidden sizes
                                         hidden_size=[self._hparams.pm_l1_size],
                                         non_linearity1=tf.nn.leaky_relu,
                                         non_linearity2=output_nonlinearity,
                                         loss_fn=tf.losses.mean_squared_error)
            self._dual.set_op('ec_out', ec_out)

        if self._use_pm_raw:
            ec_in = self._input_cue_raw
            output_nonlinearity = type_activation_fn(
                self._hparams.pm_raw_nonlinearity)
            ec_out_raw = self._build_pm_core(
                x=x_nn,
                target=ec_in,
                hidden_size=self._hparams.pm_raw_hidden_size,
                non_linearity1=tf.nn.leaky_relu,
                non_linearity2=output_nonlinearity,
                loss_fn=tf.losses.mean_squared_error,
                name_suffix="_raw")
            self._dual.set_op('ec_out_raw', ec_out_raw)

    def _build_pm_core(self,
                       x,
                       target,
                       hidden_size,
                       non_linearity1,
                       non_linearity2,
                       loss_fn,
                       name_suffix=""):
        """Build the layers of the PM network, with optional L2 regularization."""
        target_shape = target.get_shape().as_list()
        target_size = np.prod(target_shape[1:])
        l2_size = target_size

        weights = []
        scope = 'pm' + name_suffix
        with tf.variable_scope(scope):
            out = x
            keep_prob = self._hparams.pm_train_dropout_hidden_keep_prob

            # Build encoding layers
            for i, num_units in enumerate(hidden_size):
                hidden_layer = tf.layers.Dense(units=num_units,
                                               activation=non_linearity1)
                out = hidden_layer(out)

                weights.append(hidden_layer.weights[0])
                weights.append(hidden_layer.weights[1])

                # apply dropout during training
                out = tf.cond(tf.equal(self._batch_type, 'training'),
                              lambda: tf.nn.dropout(out, keep_prob[i]),
                              lambda: out)

            # Store final hidden state
            self._dual.set_op('encoding', tf.stop_gradient(out))
            self._dual.set_op('decoding', tf.stop_gradient(out))

            # Build output layer
            f_layer = tf.layers.Dense(units=l2_size, activation=non_linearity2)
            f = f_layer(out)

            weights.append(f_layer.weights[0])
            weights.append(f_layer.weights[1])

            # Stop gradients so they don't leak into other networks in the PC
            y = tf.stop_gradient(f)
            self._dual.set_op('output', y)

        target_flat = tf.reshape(target, shape=[-1, target_size])
        loss = loss_fn(f, target_flat)

        self._dual.set_op('loss', loss)
        self._dual.set_op('pm_loss' + name_suffix, loss)

        if self._hparams.pm_raw_l2_regularizer > 0.0:
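            # Total PM loss with L2 weight decay: loss + lambda * sum_w l2_loss(w),
            # where lambda = pm_raw_l2_regularizer and tf.nn.l2_loss(w) = sum(w**2) / 2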
            all_losses = [loss]

            for weight in weights:
                weight_loss = tf.nn.l2_loss(weight)
                weight_loss_sum = tf.reduce_sum(weight_loss)
                weight_loss_scaled = weight_loss_sum * self._hparams.pm_raw_l2_regularizer
                all_losses.append(weight_loss_scaled)

            all_losses_op = tf.add_n(all_losses, name='total_pm_loss')
            self._build_optimizer(all_losses_op, 'training_pm' + name_suffix,
                                  scope)
        else:
            self._build_optimizer(loss, 'training_pm' + name_suffix, scope)

        return y

    @staticmethod
    def _variables_ae_optimizer(outer_scope):
        return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                 scope=outer_scope + "/optimizer/ae")

    @staticmethod
    def _variables_pm_raw_optimizer(outer_scope):
        return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                 scope=outer_scope + "/optimizer/pm_raw")

    @staticmethod
    def _variables_encoder(outer_scope):
        return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                 scope=outer_scope + "/encoder")

    @staticmethod
    def _variables_decoder(outer_scope):
        return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                 scope=outer_scope + "/decoder")

    @staticmethod
    def _variables_pm_raw(outer_scope):
        return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                 scope=outer_scope + "/pm_raw")

    # OP ACCESS ------------------------------------------------------------------
    def get_encoding_op(self):
        return self._dual.get_op('encoding')

    def get_decoding_op(self):
        return self._dual.get_op('decoding')

    def get_ec_out_raw_op(self):
        return self._dual.get_op('ec_out_raw')

    # MODULAR INTERFACE ------------------------------------------------------------------
    def update_feed_dict(self, feed_dict, batch_type='training'):
        if batch_type == 'training':
            self.update_training_dict(feed_dict)
        if batch_type == 'encoding':
            self.update_encoding_dict(feed_dict)

    def add_fetches(self, fetches, batch_type='training'):
        if batch_type == 'training':
            self.add_training_fetches(fetches)
        if batch_type == 'encoding':
            self.add_encoding_fetches(fetches)

    def set_fetches(self, fetched, batch_type='training'):
        if batch_type == 'training':
            self.set_training_fetches(fetched)
        if batch_type == 'encoding':
            self.set_encoding_fetches(fetched)

    def build_summaries(self, batch_types=None, scope=None):
        """Builds all summaries."""
        if not scope:
            scope = self._name + '/summaries/'
        with tf.name_scope(scope):
            for batch_type in batch_types:
                if batch_type == 'training':
                    self.build_training_summaries()
                if batch_type == 'encoding':
                    self.build_encoding_summaries()

    def write_summaries(self, step, writer, batch_type='training'):
        """Write the summaries fetched into _summary_values"""
        if self._summary_values is not None:
            writer.add_summary(self._summary_values, step)
            writer.flush()
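
    # A sketch of the fetch/feed cycle this modular interface supports (assumed
    # driver code, not from the source; session, writer and step come from the
    # surrounding workflow):
    #
    #   feed_dict, fetches = {}, {}
    #   component.update_feed_dict(feed_dict, batch_type='training')
    #   component.add_fetches(fetches, batch_type='training')
    #   fetched = session.run(fetches, feed_dict=feed_dict)
    #   component.set_fetches(fetched, batch_type='training')
    #   component.write_summaries(step, writer, batch_type='training')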

    # TRAINING ------------------------------------------------------------------
    def update_training_dict(self, feed_dict):
        feed_dict.update({self._batch_type: 'training'})

    def add_training_fetches(self, fetches):
        # names = ['loss', 'training', 'encoding', 'decoding', 'inputs']
        names = ['loss', 'encoding', 'decoding', 'inputs']

        if self._use_pm_raw:
            names.extend(['training_pm_raw', 'ec_out_raw'])

        self._dual.add_fetches(fetches, names)

        if self._summary_training_op is not None:
            fetches[self._name]['summaries'] = self._summary_training_op

    def set_training_fetches(self, fetched):
        self_fetched = fetched[self._name]
        self._loss = self_fetched['loss']

        names = ['encoding', 'decoding', 'inputs']

        if self._use_pm_raw:
            names.extend(['ec_out_raw'])

        self._dual.set_fetches(fetched, names)

        if self._summary_training_op is not None:
            self._summary_values = fetched[self._name]['summaries']

    # ENCODING ------------------------------------------------------------------
    def update_encoding_dict(self, feed_dict):
        feed_dict.update({self._batch_type: 'encoding'})

    def add_encoding_fetches(self, fetches):
        names = ['encoding', 'decoding', 'inputs']

        if self._use_pm_raw:
            names.extend(['pm_loss_raw', 'ec_out_raw'])

        self._dual.add_fetches(fetches, names)

        if self._summary_encoding_op is not None:
            fetches[self._name]['summaries'] = self._summary_encoding_op

    def set_encoding_fetches(self, fetched):
        names = ['encoding', 'decoding', 'inputs']

        if self._use_pm_raw:
            names.extend(['pm_loss_raw', 'ec_out_raw'])

        self._dual.set_fetches(fetched, names)

        if self._summary_encoding_op is not None:
            self._summary_values = fetched[self._name]['summaries']

    def get_inputs(self):
        return self._dual.get_values('inputs')

    def get_encoding(self):
        return self._dual.get_values('encoding')

    def get_decoding(self):
        return self._dual.get_values('decoding')

    def get_ec_out_raw(self):
        return self._dual.get_values('ec_out_raw')

    def get_batch_type(self):
        return self._batch_type

    # SUMMARIES ------------------------------------------------------------------
    def write_filters(self, session, folder=None):
        pass

    def build_training_summaries(self):
        with tf.name_scope('training'):
            summaries = self._build_summaries()
            if len(summaries) > 0:
                self._summary_training_op = tf.summary.merge(summaries)
            return self._summary_training_op

    def build_encoding_summaries(self):
        with tf.name_scope('encoding'):
            summaries = self._build_summaries()
            if len(summaries) > 0:
                self._summary_encoding_op = tf.summary.merge(summaries)
            return self._summary_encoding_op

    def _build_summarise_pm(self, summaries, max_outputs):
        if not (self.use_pm_raw or self._use_pm):
            return

        with tf.name_scope('pm'):

            # original vc for visuals
            if self.use_pm_raw:
                ec_in = self._input_cue_raw
                ec_out = self._dual.get_op('ec_out_raw')
                ec_recon = image_utils.concat_images([ec_in, ec_out],
                                                     self._hparams.batch_size)
                summaries.append(
                    tf.summary.image('ec_recon_raw',
                                     ec_recon,
                                     max_outputs=max_outputs))

                pm_loss_raw = self._dual.get_op('pm_loss_raw')
                summaries.append(tf.summary.scalar('pm_loss_raw', pm_loss_raw))

    def _build_summaries(self):
        """Build the summaries for TensorBoard."""

        # summarise_stuff = ['general', 'pm']
        summarise_stuff = ['pm']

        max_outputs = self._hparams.max_outputs
        summaries = []

        if self._hparams.summarize_level == SummarizeLevels.OFF.value:
            return summaries

        if 'general' in summarise_stuff:
            encoding_op = self.get_encoding_op()
            decoding_op = self.get_decoding_op()

            summary_input_shape = image_utils.get_image_summary_shape(
                self._input_shape)

            input_summary_reshape = tf.reshape(self._input_values,
                                               summary_input_shape)
            decoding_summary_reshape = tf.reshape(decoding_op,
                                                  summary_input_shape)

            summary_reconstruction = tf.concat(
                [input_summary_reshape, decoding_summary_reshape], axis=1)
            reconstruction_summary_op = tf.summary.image(
                'reconstruction',
                summary_reconstruction,
                max_outputs=max_outputs)
            summaries.append(reconstruction_summary_op)

            # Show the input on its own
            input_alone = True
            if input_alone:
                summaries.append(
                    tf.summary.image('input',
                                     input_summary_reshape,
                                     max_outputs=max_outputs))

            summaries.append(
                self._summary_hidden(encoding_op, 'encoding', max_outputs))

            # Loss
            loss_summary = tf.summary.scalar('loss', self._dual.get_op('loss'))
            summaries.append(loss_summary)

        if 'pm' in summarise_stuff:
            self._build_summarise_pm(summaries, max_outputs)

        return summaries

    def _summary_hidden(self, hidden, name, max_outputs=3):
        """Return an image summary op of the hidden-state tensor, reshaped 'as square as possible'."""
        hidden_shape_4d = self._hidden_image_summary_shape()  # [batches, height=1, width=filters, 1]
        summary_reshape = tf.reshape(hidden, hidden_shape_4d)
        summary_op = tf.summary.image(name,
                                      summary_reshape,
                                      max_outputs=max_outputs)
        return summary_op
Example #6
class EpisodicComponent(CompositeComponent):
    """
  A component to implement episodic memory, inspired by the Medial Temporal Lobe.
  Currently, it consists of a Visual Cortex (Sparse Autoencoder, SAE) and
  a Pattern Completer similar to DG/CA3 (Differentiable Plasticity or SAE).
  """
    @staticmethod
    def default_hparams():
        """Builds an HParam object with default hyperparameters."""

        # create component level hparams (this will be a multi hparam, with hparams from sub components)
        batch_size = 40
        max_outputs = 3
        hparam = tf.contrib.training.HParams(
            batch_size=batch_size,
            output_features='pc',  # the output of this subcomponent is used as the component's features
            pc_type='sae',  # none, hl = hopfield like, sae = sparse autoencoder, dp = differentiable-plasticity
            dg_type='fc',  # 'none', 'fc', or 'conv' Dentate Gyrus
            ll_vc_type='none',  # vc label learner: 'none', 'fc'
            ll_pc_type='none',  # pc label learner: 'none', 'fc'
            use_cue_to_pc=False,  # use a secondary input as a cue to pc (EC perforant path to CA3)
            use_pm=False,  # pattern mapping (reconstruct inputs from PC output)
            use_interest_filter=False,  # this replaces VC (attentional system zones in on interesting features)
            summarize_level=SummarizeLevels.ALL.value,  # for the top summaries (leave individual comps to decide on own)
            vc_norm_per_filter=False,
            vc_norm_per_sample=False,
            max_pool_vc_final_size=2,
            max_pool_vc_final_stride=1,
            max_outputs=max_outputs)

        # create all possible sub component hparams (must create one for every possible sub component)
        if HVC_ENABLED:
            vc = VisualCortexComponent.default_hparams()
        else:
            vc = SparseConvAutoencoderMaxPoolComponent.default_hparams()

        dg_fc = DGSAE.default_hparams()
        dg_conv = DGSCAE.default_hparams()
        dg_stub = DGStubComponent.default_hparams()
        pc_sae = SparseAutoencoderComponent.default_hparams()
        pc_dae = DeepAutoencoderComponent.default_hparams()
        pc_dp = DifferentiablePlasticityComponent.default_hparams()
        pc_hl = HopfieldlikeComponent.default_hparams()
        ifi = InterestFilter.default_hparams()
        ll_vc = LabelLearnerFC.default_hparams()
        ll_pc = LabelLearnerFC.default_hparams()

        subcomponents = [
            vc, dg_fc, dg_conv, dg_stub, pc_sae, pc_dae, pc_dp, pc_hl, ll_vc,
            ll_pc
        ]  # all possible subcomponents

        # default overrides of sub-component hparam defaults
        if not HVC_ENABLED:
            vc.set_hparam('learning_rate', 0.001)
            vc.set_hparam('sparsity', 25)
            vc.set_hparam('sparsity_output_factor', 1.5)
            vc.set_hparam('filters', 64)
            vc.set_hparam('filters_field_width', 6)
            vc.set_hparam('filters_field_height', 6)
            vc.set_hparam('filters_field_stride', 3)

            vc.set_hparam('pool_size', 2)
            vc.set_hparam('pool_strides', 2)

            # Note that DG will get the pooled->unpooled encoding
            vc.set_hparam('use_max_pool', 'none')  # none, encoding, training

        dg_fc.set_hparam('learning_rate', 0.001)
        dg_fc.set_hparam('sparsity', 20)
        dg_fc.set_hparam('sparsity_output_factor', 1.0)
        dg_fc.set_hparam('filters', 784)

        pc_hl.set_hparam('learning_rate', 0.0001)
        pc_hl.set_hparam('optimizer', 'adam')
        pc_hl.set_hparam('momentum', 0.9)
        pc_hl.set_hparam('momentum_nesterov', False)
        pc_hl.set_hparam('use_feedback', True)
        pc_hl.set_hparam('memorise_method', 'pinv')
        pc_hl.set_hparam('nonlinearity', 'none')
        pc_hl.set_hparam('update_n_neurons', -1)

        # default hparams in individual component should be consistent with component level hparams
        HParamMulti.set_hparam_in_subcomponents(subcomponents, 'batch_size',
                                                batch_size)

        # add sub components to the composite hparams
        HParamMulti.add(source=vc, multi=hparam, component='vc')
        HParamMulti.add(source=dg_fc, multi=hparam, component='dg_fc')
        HParamMulti.add(source=dg_conv, multi=hparam, component='dg_conv')
        HParamMulti.add(source=dg_stub, multi=hparam, component='dg_stub')
        HParamMulti.add(source=pc_dp, multi=hparam, component='pc_dp')
        HParamMulti.add(source=pc_sae, multi=hparam, component='pc_sae')
        HParamMulti.add(source=pc_dae, multi=hparam, component='pc_dae')
        HParamMulti.add(source=pc_hl, multi=hparam, component='pc_hl')
        HParamMulti.add(source=ifi, multi=hparam, component='ifi')
        HParamMulti.add(source=ll_vc, multi=hparam, component='ll_vc')
        HParamMulti.add(source=ll_pc, multi=hparam, component='ll_pc')

        return hparam

    def __init__(self):
        super(EpisodicComponent, self).__init__()

        self._name = None
        self._hparams = None
        self._summary_op = None
        self._summary_result = None
        self._dual = None
        self._input_shape = None
        self._input_values = None
        self._summary_values = None

        self._sub_components = {}  # map {name, component}

        self._pc_mode = PCMode.Combined

        self._pc_input = None
        self._pc_input_vis_shape = None

        self._degrade_type = 'random'  # if degrading is used, then a degrade type: vertical, horizontal, random

        self._signals = {}  # signals at each stage: convenient container for significant signals
        self._show_episodic_level_summary = True

        self._interest_filter = None

    def batch_size(self):
        return self._hparams.batch_size

    def get_vc_encoding(self):
        return self._dual.get_values('vc_encoding')

    def is_build_dg(self):
        return self._hparams.dg_type != 'none'

    def is_build_ll_vc(self):
        return self._hparams.ll_vc_type != 'none'

    def is_build_ll_pc(self):
        return self._hparams.ll_pc_type != 'none'

    def is_build_ll_ensemble(self):
        build_ll_ensemble = True
        return self.is_build_ll_vc() and self.is_build_ll_pc() and build_ll_ensemble

    def is_build_pc(self):
        return self._hparams.pc_type != 'none'

    def is_pc_hopfield(self):
        return isinstance(self.get_pc(), HopfieldlikeComponent)

    @staticmethod
    def is_vc_hierarchical():
        # return isinstance(self.get_vc(), VisualCortexComponent)
        return HVC_ENABLED  # need to use this before _component is instantiated

    def pc_combined(self):
        self._pc_mode = PCMode.Combined

    def pc_exclude(self):
        self._pc_mode = PCMode.Exclude

    def pc_only(self):
        self._pc_mode = PCMode.PCOnly

    @property
    def name(self):
        return self._name

    def get_interest_filter_masked_encodings(self):
        return self._dual.get_values('masked_encodings')

    def get_interest_filter_positional_encodings(self):
        return self._dual.get_values('positional_encodings')

    def set_signal(self, key, val, val_shape):
        """Set as a significant signal, that should be summarised"""
        self._signals.update({key: (val, val_shape)})

    def get_signal(self, key):
        val, val_shape = self._signals[key]
        return val, val_shape

    def get_loss(self):
        """Define loss as the loss of the subcomponent selected for output features: using _hparam.output_features"""

        if self._hparams.output_features == 'vc':
            comp = self.get_vc()
        elif self._hparams.output_features == 'dg':
            comp = self.get_dg()
        else:  # assumes output_features == 'pc'
            comp = self.get_pc()
        return comp.get_loss()

    @staticmethod
    def degrader(degrade_step_pl,
                 degrade_type,
                 random_value_pl,
                 input_values,
                 degrade_step,
                 name=None):

        return tf.cond(
            tf.equal(degrade_step_pl, degrade_step),
            lambda: image_utils.degrade_image(input_values,
                                              degrade_type=degrade_type,
                                              random_value=random_value_pl),
            lambda: input_values,
            name=name)

    def _build_vc(self, input_values, input_shape):

        if HVC_ENABLED:
            vc = VisualCortexComponent()
            hparams_vc = VisualCortexComponent.default_hparams()
        else:
            vc = SparseConvAutoencoderMaxPoolComponent()
            hparams_vc = SparseConvAutoencoderMaxPoolComponent.default_hparams(
            )

        hparams_vc = HParamMulti.override(multi=self._hparams,
                                          target=hparams_vc,
                                          component='vc')
        vc.build(input_values, input_shape, hparams_vc, 'vc')
        self._add_sub_component(vc, 'vc')

        # Update 'next' value/shape for DG
        if HVC_ENABLED:
            # Since pooling/unpooling is applied within the VC component,
            # use the get_output() method to get the final layer of VC with the
            # appropriate pooling/unpooling setting.
            input_values_next = vc.get_output_op()
        else:
            # Otherwise, get the encoding or unpooled encoding as appropriate
            input_values_next = vc.get_encoding_op()
            if hparams_vc.use_max_pool == 'encoding':
                input_values_next = vc.get_encoding_unpooled_op()
        print('vc', 'output', input_values_next)

        # Optionally norm VC per filter (this should probably be only first layer, but only one layer for now anyway)
        for _, layer in vc.get_sub_components().items():
            layer.set_norm_filters(self._hparams.vc_norm_per_filter)

        # Add InterestFilter to mask in and blur position of interesting visual filters (and block out the rest)
        if self._hparams.use_interest_filter:
            self._interest_filter = InterestFilter()
            image_tensor, image_shape = self.get_signal('input')
            vc_tensor = input_values_next
            vc_shape = vc_tensor.get_shape().as_list()

            assert image_shape[:-1] == vc_shape[:-1], \
                "The VC encoding must be the same height and width as the image, i.e. conv stride 1"

            hparams_ifi = InterestFilter.default_hparams()
            hparams_ifi = HParamMulti.override(multi=self._hparams,
                                               target=hparams_ifi,
                                               component='ifi')
            _, input_values_next = self._interest_filter.build(
                image_tensor, vc_tensor, hparams_ifi)
            self._dual.set_op(
                'masked_encodings',
                self._interest_filter.get_image('masked_encodings'))
            self._dual.set_op(
                'positional_encodings',
                self._interest_filter.get_image('positional_encodings'))

        # Optionally pool the final output of the VC (simply to reduce dimensionality)
        pool_size = self._hparams.max_pool_vc_final_size
        pool_stride = self._hparams.max_pool_vc_final_stride
        if pool_size > 1:
            input_values_next = tf.layers.max_pooling2d(
                input_values_next, pool_size, pool_stride, 'SAME')
            print('vc final pooled', input_values_next)

        # Optionally norm the output samples so that they are comparable to the next stage
        def normalize_min_max_4d(x):
            sample_mins = tf.reduce_min(x, axis=[1, 2, 3], keepdims=True)
            sample_maxs = tf.reduce_max(x, axis=[1, 2, 3], keepdims=True)
            return (x - sample_mins) / (sample_maxs - sample_mins)

        if self._hparams.vc_norm_per_sample:
            frobenius_norm = tf.sqrt(
                tf.reduce_sum(tf.square(input_values_next),
                              axis=[1, 2, 3],
                              keepdims=True))
            input_values_next = input_values_next / frobenius_norm
            # Alternative: input_values_next = normalize_min_max_4d(input_values_next)

        # Unpack the conv cells shape
        input_volume = np.prod(input_values_next.get_shape().as_list()[1:])
        input_next_vis_shape, _ = image_utils.square_image_shape_from_1d(
            input_volume)

        return input_values_next, input_next_vis_shape

    def _build_ll_vc(self,
                     target_output,
                     train_input,
                     test_input,
                     name='ll_vc'):
        """Build the label learning component for LTM."""
        ll_vc = None

        # Min-max normalize the label-learner inputs
        train_input = normalize_minmax(train_input)
        test_input = normalize_minmax(test_input)

        if self._hparams.ll_vc_type == 'fc':
            ll_vc = LabelLearnerFC()
            self._add_sub_component(ll_vc, name)
            hparams_ll_vc = LabelLearnerFC.default_hparams()
            hparams_ll_vc = HParamMulti.override(multi=self._hparams,
                                                 target=hparams_ll_vc,
                                                 component='ll_vc')
            ll_vc.build(target_output, train_input, test_input, hparams_ll_vc,
                        name)

        return ll_vc

    def _build_ll_pc(self,
                     target_output,
                     train_input,
                     test_input,
                     name='ll_pc'):
        """Build the label learning component for PC."""
        ll_pc = None

        train_input = normalize_minmax(train_input)
        test_input = normalize_minmax(test_input)

        if self._hparams.ll_pc_type == 'fc':
            ll_pc = LabelLearnerFC()
            self._add_sub_component(ll_pc, name)
            hparams_ll_pc = LabelLearnerFC.default_hparams()
            hparams_ll_pc = HParamMulti.override(multi=self._hparams,
                                                 target=hparams_ll_pc,
                                                 component='ll_pc')
            ll_pc.build(target_output, train_input, test_input, hparams_ll_pc,
                        name)

        return ll_pc

    def _build_dg(self, input_next, input_next_vis_shape):
        """Builds the pattern separation component."""
        dg_type = self._hparams.dg_type

        if dg_type == 'stub':

            # create fc, so that we can use the encodings etc. without breaking other stuff
            dg = DGStubComponent()
            self._add_sub_component(dg, 'dg')
            hparams_dg = DGStubComponent.default_hparams()
            hparams_dg = HParamMulti.override(multi=self._hparams,
                                              target=hparams_dg,
                                              component='dg_stub')

            dg.build(hparams_dg, 'dg')

            # Update 'next' value/shape for PC
            input_next = dg.get_encoding_op()
            input_next_vis_shape, _ = image_utils.square_image_shape_from_1d(
                hparams_dg.filters)

            dg_sparsity = hparams_dg.sparsity
        elif dg_type == 'fc':
            dg = DGSAE()
            self._add_sub_component(dg, 'dg')
            hparams_dg = DGSAE.default_hparams()
            hparams_dg = HParamMulti.override(multi=self._hparams,
                                              target=hparams_dg,
                                              component='dg_fc')
            dg.build(input_next, input_next_vis_shape, hparams_dg, 'dg')

            # Update 'next' value/shape for PC
            input_next = dg.get_encoding_op()
            input_next_vis_shape, _ = image_utils.square_image_shape_from_1d(
                hparams_dg.filters)

            dg_sparsity = hparams_dg.sparsity
        elif dg_type == 'conv':
            input_next_vis_shape = [-1] + input_next.get_shape().as_list()[1:]

            print('dg', 'input', input_next)

            dg = DGSCAE()
            self._add_sub_component(dg, 'dg')
            hparams_dg = DGSCAE.default_hparams()
            hparams_dg = HParamMulti.override(multi=self._hparams,
                                              target=hparams_dg,
                                              component='dg_conv')
            dg.build(input_next, input_next_vis_shape, hparams_dg, 'dg')

            # Update 'next' value/shape for PC
            input_next = dg.get_encoding_op()

            print('dg', 'output', input_next)

            # Unpack the conv cells shape
            input_volume = np.prod(input_next.get_shape().as_list()[1:])
            input_next_vis_shape, _ = image_utils.square_image_shape_from_1d(
                input_volume)
            input_next = tf.reshape(input_next, [-1, input_volume])

            dg_sparsity = hparams_dg.sparsity

        else:
            raise NotImplementedError('Dentate Gyrus not implemented: ' +
                                      dg_type)

        return input_next, input_next_vis_shape, dg_sparsity

    def _build_pc(self, input_next, input_next_vis_shape, dg_sparsity):

        pc_type = self._hparams.pc_type
        use_cue_to_pc = self._hparams.use_cue_to_pc
        use_pm = self._hparams.use_pm

        if use_cue_to_pc:
            cue = self._signals['vc'][0]
        else:
            cue = None

        cue_raw = None
        if use_pm:
            cue_raw = self._signals['vc_input'][0]

        if pc_type == 'sae':
            pc = SparseAutoencoderComponent()
            self._add_sub_component(pc, 'pc')
            hparams_sae = SparseAutoencoderComponent.default_hparams()
            hparams_sae = HParamMulti.override(multi=self._hparams,
                                               target=hparams_sae,
                                               component='pc_sae')
            pc.build(input_next, input_next_vis_shape, hparams_sae, 'pc')
        elif pc_type == 'dae':
            pc = DeepAutoencoderComponent()
            self._add_sub_component(pc, 'pc')
            hparams_dae = DeepAutoencoderComponent.default_hparams()
            hparams_dae = HParamMulti.override(multi=self._hparams,
                                               target=hparams_dae,
                                               component='pc_dae')
            input_next_shape = input_next.get_shape().as_list()
            pc.build(input_next,
                     input_next_shape,
                     hparams_dae,
                     'pc',
                     input_cue_raw=cue_raw)
        elif pc_type == 'dp':
            # DP works with batches differently, so not prescriptive for input shape (used for summaries only)
            input_next_vis_shape[0] = -1

            # ensure DP receives binary values (all k winners will be 1)
            input_next = tf.greater(input_next, 0)
            input_next = tf.to_float(input_next)

            pc = DifferentiablePlasticityComponent()
            self._add_sub_component(pc, 'pc')
            hparams_pc_dp = DifferentiablePlasticityComponent.default_hparams()
            hparams_pc_dp = HParamMulti.override(multi=self._hparams,
                                                 target=hparams_pc_dp,
                                                 component='pc_dp')
            pc.build(input_next, input_next_vis_shape, hparams_pc_dp, 'pc')
        elif pc_type == 'hl':
            pc = HopfieldlikeComponent()
            self._add_sub_component(pc, 'pc')
            hparams_hl = HopfieldlikeComponent.default_hparams()
            hparams_hl = HParamMulti.override(multi=self._hparams,
                                              target=hparams_hl,
                                              component='pc_hl')

            if dg_sparsity == 0:
                raise RuntimeError(
                    "Could not establish dg per sample sparsity to pass to Hopfield."
                )

            hparams_hl.cue_nn_label_sparsity = dg_sparsity

            pc.build(input_next,
                     input_next_vis_shape,
                     hparams_hl,
                     'pc',
                     input_cue=cue,
                     input_cue_raw=cue_raw)
        else:
            raise NotImplementedError('Pattern completer not implemented: ' +
                                      pc_type)

        pc_output = pc.get_decoding_op()

        if pc_type == 'dae':
            input_volume = np.prod(pc_output.get_shape().as_list()[1:])
            pc_output_shape, _ = image_utils.square_image_shape_from_1d(
                input_volume)
        else:
            pc_output_shape = input_next_vis_shape  # output is same shape and size as input

        return pc_output, pc_output_shape

    def _build_ll_ensemble(self):
        """Builds ensemble of VC and PC classifiers."""
        distributions = []
        distribution_mass = []
        num_classes = self._label_values.get_shape().as_list()[-1]

        aha_mass = 0.495
        ltm_mass = 0.495
        uniform_mass = 0.01
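
        # These masses sum to 1.0 (0.495 + 0.495 + 0.01), so the interpolated
        # ensemble remains a normalized distribution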

        if aha_mass > 0.0:
            aha_prediction = self.get_ll_pc().get_op('preds')
            distributions.append(aha_prediction)
            distribution_mass.append(aha_mass)

        if ltm_mass > 0.0:
            ltm_prediction = self.get_ll_vc().get_op('preds')
            distributions.append(ltm_prediction)
            distribution_mass.append(ltm_mass)

        if uniform_mass > 0.0:
            uniform = np_uniform(num_classes)
            distributions.append(uniform)
            distribution_mass.append(uniform_mass)

        unseen_sum = 1
        unseen_idxs = (0, unseen_sum)

        # Build the final distribution, calculate loss
        ensemble_preds = tf_build_interpolate_distributions(
            distributions, distribution_mass, num_classes)

        ensemble_correct_preds = tf.equal(tf.argmax(ensemble_preds, 1),
                                          tf.argmax(self._label_values, 1))
        ensemble_correct_preds = tf.cast(ensemble_correct_preds, tf.float32)

        ensemble_accuracy = tf.reduce_mean(ensemble_correct_preds)
        ensemble_accuracy_unseen = tf.reduce_mean(
            ensemble_correct_preds[unseen_idxs[0]:unseen_idxs[1]])

        self._dual.set_op('ensemble_preds', ensemble_preds)
        self._dual.set_op('ensemble_accuracy', ensemble_accuracy)
        self._dual.set_op('ensemble_accuracy_unseen', ensemble_accuracy_unseen)

    def build(self,
              input_values,
              input_shape,
              hparams,
              label_values=None,
              name='episodic'):
        """Initializes the model parameters.

    Args:
        hparams: The hyperparameters for the model as tf.contrib.training.HParams.
        :param input_values:
        :param input_shape:
        :param hparams:
        :param name:
    """

        self._name = name
        self._hparams = hparams
        self._summary_op = None
        self._summary_result = None
        self._dual = DualData(self._name)
        self._input_values = input_values
        self._input_shape = input_shape
        self._label_values = label_values

        input_area = np.prod(input_shape[1:])

        logging.debug('Input Shape: %s', input_shape)
        logging.debug('Input Area: %s', input_area)

        with tf.variable_scope(self._name, reuse=tf.AUTO_REUSE):

            # Replay mode
            # ------------------------------------------------------------------------
            replay_mode = 'pixel'  # pixel or encoding
            replay = self._dual.add('replay', shape=[],
                                    default_value=False).add_pl(default=True,
                                                                dtype=tf.bool)

            # Replace labels during replay
            replay_labels = self._dual.add('replay_labels',
                                           shape=label_values.shape,
                                           default_value=0.0).add_pl(
                                               default=True,
                                               dtype=label_values.dtype)

            self._label_values = tf.cond(tf.equal(replay, True),
                                         lambda: replay_labels,
                                         lambda: self._label_values)

            # Replay pixel inputs during replay, if using 'pixel' replay mode
            if replay_mode == 'pixel':
                replay_inputs = self._dual.add('replay_inputs',
                                               shape=input_values.shape,
                                               default_value=0.0).add_pl(
                                                   default=True,
                                                   dtype=input_values.dtype)

                self._input_values = tf.cond(tf.equal(replay, True),
                                             lambda: replay_inputs,
                                             lambda: self._input_values)

            self.set_signal('input', self._input_values, self._input_shape)

            # Build the LTM
            # ------------------------------------------------------------------------

            # Optionally degrade input to VC
            degrade_step_pl = self._dual.add(
                'degrade_step',
                shape=[],  # e.g. hidden, input, none
                default_value='none').add_pl(default=True, dtype=tf.string)
            degrade_random_pl = self._dual.add('degrade_random',
                                               shape=[],
                                               default_value=0.0).add_pl(
                                                   default=True,
                                                   dtype=tf.float32)
            input_values = self.degrader(degrade_step_pl,
                                         self._degrade_type,
                                         degrade_random_pl,
                                         self._input_values,
                                         degrade_step='input',
                                         name='vc_input_values')

            print('vc', 'input', input_values)
            self.set_signal('vc_input', input_values, input_shape)

            # Build the VC
            input_next, input_next_vis_shape = self._build_vc(
                input_values, input_shape)

            vc_encoding = input_next

            # Replace the encoding during replay, if using 'encoding' replay mode
            if replay_mode == 'encoding':
                replay_inputs = self._dual.add('replay_inputs',
                                               shape=vc_encoding.shape,
                                               default_value=0.0).add_pl(
                                                   default=True,
                                                   dtype=vc_encoding.dtype)

                vc_encoding = tf.cond(tf.equal(replay, True),
                                      lambda: replay_inputs,
                                      lambda: vc_encoding)

            self.set_signal('vc', vc_encoding, input_next_vis_shape)
            self._dual.set_op('vc_encoding', vc_encoding)

            # Build the softmax classifier
            if self.is_build_ll_vc() and self._label_values is not None:
                self._build_ll_vc(self._label_values, vc_encoding, vc_encoding)

            # Build AHA
            # ------------------------------------------------------------------------

            # Build the DG
            dg_sparsity = 0
            if self.is_build_dg():
                input_next, input_next_vis_shape, dg_sparsity = self._build_dg(
                    input_next, input_next_vis_shape)
                dg_encoding = input_next
                self.set_signal('dg', dg_encoding, input_next_vis_shape)

            # Build the PC
            if self.is_build_pc():
                # Optionally degrade input to PC

                # Not all degrade types are supported for embedding in the graph
                # (but they may still be used directly on the test set)
                if self._degrade_type != 'rect' and self._degrade_type != 'circle':
                    input_next = self.degrader(degrade_step_pl,
                                               self._degrade_type,
                                               degrade_random_pl,
                                               input_next,
                                               degrade_step='hidden',
                                               name='pc_input_values')
                logging.debug('pc_input: %s', input_next)
                self.set_signal('pc_input', input_next, input_next_vis_shape)

                pc_output, pc_output_shape = self._build_pc(
                    input_next, input_next_vis_shape, dg_sparsity)
                self.set_signal('pc', pc_output, pc_output_shape)

                if self._hparams.use_pm:
                    ec_out_raw = self.get_pc().get_ec_out_raw_op()
                    self.set_signal('ec_out_raw', ec_out_raw, input_shape)

                if self.is_build_ll_pc() and self.is_build_dg(
                ) and self._label_values is not None:
                    self._build_ll_pc(self._label_values, dg_encoding,
                                      pc_output)

            if self.is_build_ll_ensemble():
                self._build_ll_ensemble()

        self.reset()

    def get_vc(self):
        vc = self.get_sub_component('vc')
        return vc

    def get_pc(self):
        pc = self.get_sub_component('pc')
        return pc

    def get_dg(self):
        dg = self.get_sub_component('dg')
        return dg

    def get_decoding(self):
        return self.get_pc().get_decoding()

    def get_ll_vc(self):
        return self.get_sub_component('ll_vc')

    def get_ll_pc(self):
        return self.get_sub_component('ll_pc')

    def get_batch_type(self, name=None):
        """
    Return dic of batch types for each component (key is component)
    If component does not have a persistent batch type, then don't include in dictionary,
    assumption is that in that case, it doesn't have any effect.
    """
        if name is None:
            batch_types = dict.fromkeys(self._sub_components.keys(), None)
            for c in self._sub_components:
                if hasattr(self._sub_components[c], 'get_batch_type'):
                    batch_types[c] = self._sub_components[c].get_batch_type()
                else:
                    batch_types.pop(c)
            return batch_types

        return self._sub_components[name].get_batch_type()

    def get_features(self, batch_type='training'):
        """
    The output of the component is taken from one of the subcomponents, depending on hparams.
    If not vc or dg, the fallback is to take from pc regardless of value of the hparam
    """
        del batch_type
        if self._hparams.output_features == 'vc':
            features = self.get_vc().get_features()
        elif self._hparams.output_features == 'dg':
            features = self.get_dg().get_features()
        else:  # self._hparams.output_features == 'pc':
            features = self.get_pc().get_features()
        return features

    def _is_skip_for_pc(self, name):
        if self._pc_mode == PCMode.PCOnly:  # only pc
            if name != 'pc':
                return True
        elif self._pc_mode == PCMode.Exclude:  # skip pc
            if name == 'pc':
                return True
        return False

    def update_feed_dict_input_gain_pl(self, feed_dict, gain):
        """
    This is relevant for the PC, and is only be called when it is being run recursively.
    """
        if self.get_pc() is not None:
            self.get_pc().update_feed_dict_input_gain_pl(feed_dict, gain)

    def update_feed_dict(self, feed_dict, batch_type='training'):
        for name, comp in self._sub_components.items():
            if self._is_skip_for_pc(name):
                continue
            comp.update_feed_dict(feed_dict,
                                  self._select_batch_type(batch_type, name))

    def add_fetches(self, fetches, batch_type='training'):
        # each component adds its own
        for name, comp in self._sub_components.items():
            if self._is_skip_for_pc(name):
                continue
            comp.add_fetches(fetches,
                             self._select_batch_type(batch_type, name))

        # Episodic Component specific
        # ------------------------------
        # Interest Filter and other
        names = []

        if self._hparams.use_interest_filter:
            names.extend(['masked_encodings', 'positional_encodings'])

        if self.is_build_ll_ensemble():
            names.extend([
                'ensemble_preds', 'ensemble_accuracy',
                'ensemble_accuracy_unseen'
            ])

        # Other
        names.extend(['vc_encoding'])

        if len(names) > 0:
            self._dual.add_fetches(fetches, names)

        # Episodic Component specific - summaries
        bt = self._select_batch_type(batch_type, self._name)
        summary_op = self._dual.get_op(generic_utils.summary_name(bt))
        if summary_op is not None:
            fetches[self._name]['summaries'] = summary_op

    def set_fetches(self, fetched, batch_type='training'):
        # each component adds its own
        for name, comp in self._sub_components.items():
            if self._is_skip_for_pc(name):
                continue
            comp.set_fetches(fetched,
                             self._select_batch_type(batch_type, name))

        # Episodic Component specific
        # ----------------------------
        # Interest Filter
        names = []

        if self._hparams.use_interest_filter:
            names.extend(['masked_encodings', 'positional_encodings'])

        if self.is_build_ll_ensemble():
            names.extend([
                'ensemble_preds', 'ensemble_accuracy',
                'ensemble_accuracy_unseen'
            ])

        # other
        names.extend(['vc_encoding'])

        if len(names) > 0:
            self._dual.set_fetches(fetched, names)

        # Episodic Component specific - summaries
        bt = self._select_batch_type(batch_type, self._name)
        summary_op = self._dual.get_op(generic_utils.summary_name(bt))
        if summary_op is not None:
            self._summary_values = fetched[self._name]['summaries']

    def build_summaries(self, batch_types=None):
        if batch_types is None:
            batch_types = []

        components = self._sub_components.copy()

        consolidate_graph_view = False

        if self._show_episodic_level_summary:
            components.update({self._name: self})

        for name, comp in components.items():
            scope = name + '-summaries'  # this is best for visualising images in summaries
            if consolidate_graph_view:
                scope = self._name + '/' + name + '/summaries/'

            bt = self._select_batch_type(batch_types, name, as_list=True)

            if name == self._name:
                comp.build_summaries_episodic(bt, scope=scope)
            else:
                comp.build_summaries(bt, scope=scope)

    def write_summaries(self, step, writer, batch_type='training'):
        # the episodic component itself
        if self._summary_values is not None:
            writer.add_summary(
                self._summary_values,
                step)  # Write the summaries fetched into _summary_values
            writer.flush()

        super().write_summaries(step, writer, batch_type)

    def write_recursive_summaries(self, step, writer, batch_type=None):
        for name, comp in self._sub_components.items():
            if hasattr(comp, 'write_recursive_summaries'):
                comp.write_recursive_summaries(step, writer, batch_type)

    def build_summaries_episodic(self, batch_types=None, scope=None):
        """Builds all summaries."""

        if not scope:
            scope = self._name + '/summaries/'
        with tf.name_scope(scope):
            for batch_type in batch_types:

                # build 'batch_type' summary subgraph
                with tf.name_scope(batch_type):
                    summaries = self._build_summaries(batch_type)
                    if len(summaries) > 0:
                        self._dual.set_op(
                            generic_utils.summary_name(batch_type),
                            tf.summary.merge(summaries))

    def _build_summaries(self, batch_type='training'):
        """Assumes appropriate name scope has been set."""
        max_outputs = self._hparams.max_outputs
        summaries = []

        if self._hparams.summarize_level != SummarizeLevels.OFF.value:
            for key, pair in self._signals.items():
                val = pair[0]
                val_shape = pair[1]

                summary_shape = image_utils.get_image_summary_shape(val_shape)
                reshaped = tf.reshape(val, summary_shape)
                summaries.append(
                    tf.summary.image(key, reshaped, max_outputs=max_outputs))

        if self._hparams.use_interest_filter and self._interest_filter.summarize_level(
        ) != SummarizeLevels.OFF.value:
            with tf.name_scope('interest_filter'):
                self._interest_filter.add_summaries(summaries)

        return summaries
Example #7
    def build(self,
              input_values,
              input_shape,
              hparams,
              label_values=None,
              name='episodic'):
        """Initializes the model parameters.

    Args:
        hparams: The hyperparameters for the model as tf.contrib.training.HParams.
        :param input_values:
        :param input_shape:
        :param hparams:
        :param name:
    """

        self._name = name
        self._hparams = hparams
        self._summary_op = None
        self._summary_result = None
        self._dual = DualData(self._name)
        self._input_values = input_values
        self._input_shape = input_shape
        self._label_values = label_values

        input_area = np.prod(input_shape[1:])

        logging.debug('Input Shape: %s', input_shape)
        logging.debug('Input Area: %s', input_area)

        with tf.variable_scope(self._name, reuse=tf.AUTO_REUSE):

            # Replay mode
            # ------------------------------------------------------------------------
            replay_mode = 'pixel'  # pixel or encoding
            replay = self._dual.add('replay', shape=[],
                                    default_value=False).add_pl(default=True,
                                                                dtype=tf.bool)

            # Replace labels during replay
            replay_labels = self._dual.add('replay_labels',
                                           shape=label_values.shape,
                                           default_value=0.0).add_pl(
                                               default=True,
                                               dtype=label_values.dtype)

            self._label_values = tf.cond(tf.equal(replay,
                                                  True), lambda: replay_labels,
                                         lambda: self._label_values)

            # Replay pixel inputs during replay, if using 'pixel' replay mode
            if replay_mode == 'pixel':
                replay_inputs = self._dual.add('replay_inputs',
                                               shape=input_values.shape,
                                               default_value=0.0).add_pl(
                                                   default=True,
                                                   dtype=input_values.dtype)

                self._input_values = tf.cond(tf.equal(replay, True),
                                             lambda: replay_inputs,
                                             lambda: self._input_values)

            self.set_signal('input', self._input_values, self._input_shape)

            # Build the LTM
            # ------------------------------------------------------------------------

            # Optionally degrade input to VC
            degrade_step_pl = self._dual.add(
                'degrade_step',
                shape=[],  # e.g. hidden, input, none
                default_value='none').add_pl(default=True, dtype=tf.string)
            degrade_random_pl = self._dual.add('degrade_random',
                                               shape=[],
                                               default_value=0.0).add_pl(
                                                   default=True,
                                                   dtype=tf.float32)
            input_values = self.degrader(degrade_step_pl,
                                         self._degrade_type,
                                         degrade_random_pl,
                                         self._input_values,
                                         degrade_step='input',
                                         name='vc_input_values')

            logging.debug('vc_input: %s', input_values)
            self.set_signal('vc_input', input_values, input_shape)

            # Build the VC
            input_next, input_next_vis_shape = self._build_vc(
                input_values, input_shape)

            vc_encoding = input_next

            # Replace the encoding during replay, if using 'encoding' replay mode
            if replay_mode == 'encoding':
                replay_inputs = self._dual.add('replay_inputs',
                                               shape=vc_encoding.shape,
                                               default_value=0.0).add_pl(
                                                   default=True,
                                                   dtype=vc_encoding.dtype)

                vc_encoding = tf.cond(tf.equal(replay,
                                               True), lambda: replay_inputs,
                                      lambda: vc_encoding)

            self.set_signal('vc', vc_encoding, input_next_vis_shape)
            self._dual.set_op('vc_encoding', vc_encoding)

            # Build the softmax classifier
            if self.is_build_ll_vc() and self._label_values is not None:
                self._build_ll_vc(self._label_values, vc_encoding, vc_encoding)

            # Build AHA
            # ------------------------------------------------------------------------

            # Build the DG
            dg_sparsity = 0
            if self.is_build_dg():
                input_next, input_next_vis_shape, dg_sparsity = self._build_dg(
                    input_next, input_next_vis_shape)
                dg_encoding = input_next
                self.set_signal('dg', dg_encoding, input_next_vis_shape)

            # Build the PC
            if self.is_build_pc():
                # Optionally degrade input to PC

                # not all degrade types are supported for embedding in graph (but may still be used directly on test set)
                if self._degrade_type != 'rect' and self._degrade_type != 'circle':
                    input_next = self.degrader(degrade_step_pl,
                                               self._degrade_type,
                                               degrade_random_pl,
                                               input_next,
                                               degrade_step='hidden',
                                               name='pc_input_values')
                logging.debug('pc_input: %s', input_next)
                self.set_signal('pc_input', input_next, input_next_vis_shape)

                pc_output, pc_output_shape = self._build_pc(
                    input_next, input_next_vis_shape, dg_sparsity)
                self.set_signal('pc', pc_output, pc_output_shape)

                if self._hparams.use_pm:
                    ec_out_raw = self.get_pc().get_ec_out_raw_op()
                    self.set_signal('ec_out_raw', ec_out_raw, input_shape)

                if self.is_build_ll_pc() and self.is_build_dg(
                ) and self._label_values is not None:
                    self._build_ll_pc(self._label_values, dg_encoding,
                                      pc_output)

            if self.is_build_ll_ensemble():
                self._build_ll_ensemble()

        self.reset()
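The replay mechanism in the build method above swaps the live inputs (or labels) for a placeholder-fed batch via tf.cond. Below is a minimal standalone sketch of the same pattern; the names and shapes are illustrative assumptions, not taken from the component.

import numpy as np
import tensorflow as tf  # TF 1.x graph mode, matching the component

live_input = tf.placeholder(tf.float32, shape=[4, 8], name='live_input')
replay_pl = tf.placeholder_with_default(False, shape=[], name='replay')
replay_input_pl = tf.placeholder_with_default(
    np.zeros([4, 8], dtype=np.float32), shape=[4, 8], name='replay_input')

# Select the replay batch only when the replay flag is fed as True.
selected = tf.cond(tf.equal(replay_pl, True),
                   lambda: replay_input_pl,
                   lambda: live_input)

with tf.Session() as sess:
    x = np.ones([4, 8], dtype=np.float32)
    live = sess.run(selected, feed_dict={live_input: x})                       # live path
    replayed = sess.run(selected, feed_dict={live_input: x, replay_pl: True})  # replay path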
Example #8
    def __init__(self, name=None):
        super().__init__()

        self._dual = DualData(name)
Example #9
class DifferentiablePlasticityComponent(SummaryComponent):
    """
  Differentiable Plasticity algorithm from Uber AI Labs (Miconi et al.)

  WARNINGS:
    - currently number of neurons (filters) must be equal to input vector size.
    - currently only returns correct dimensions if input shape = [batch size, input sample area]

  encoding == no equivalent here
  decoding == the output of the layer

  batch_types:
    to be consistent with ae, using 'encoding' in place of 'testing'
    So:
      training  - use BPTT to update weights
      encoding  - just use for inference

  """
    @staticmethod
    def default_hparams():
        """Builds an HParam object with default hyperparameters (use values closest to Miconi as default)."""
        return tf.contrib.training.HParams(
            batch_size=5,
            learning_rate=0.0003,
            loss_type='sse',  # sum square error (only option is sse)
            nonlinearity='tanh',  # tanh or sigmoid
            filters=1024,
            bias_neurons=0,  # add this many 'active' bias neurons
            bias=False,  # include a bias value (to be trained)
            use_batch_transformer=True,
            bt_presentation_repeat=2,  # number of times the whole sequence of repeats (with blanks) is presented
            bt_sample_repeat=6,  # number of repeats of each original sample (1 = identical to input)
            bt_blank_repeat=4,  # number of zero samples between each group of repeated samples
            bt_amplify_factor=20,  # amplify input by this amount
            bt_degrade=True,  # randomly select a sample from the batch, degrade it, and append it together with the non-degraded copy
            bt_degrade_repeat=6,  # number of repeats of the degraded sample
            bt_degrade_value=0.0,  # when degrading, set pixels to this value
            bt_degrade_factor=0.5,  # proportion of bits to knock out
            bt_degrade_type='random',  # options: 'random' = randomly degrade,
            # 'vertical' = degrade a random half along the vertical axis of symmetry,
            # 'horizontal' = same but along the horizontal axis
            input_sparsity=0.5,
            max_outputs=3)
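As a rough off-graph illustration of what bt_degrade_factor and bt_degrade_value control (the graph versions live in image_utils and tf_utils; this numpy sketch and its names are assumptions for illustration only):

import numpy as np

def degrade_random_sketch(sample, degrade_factor=0.5, degrade_value=0.0, rng=None):
    """Knock out a proportion of the currently-active (high) bits."""
    rng = rng or np.random.RandomState(0)
    degraded = sample.copy()
    active = np.flatnonzero(sample > 0.99)            # indices of the high bits
    num_knockout = int(len(active) * degrade_factor)  # bt_degrade_factor
    chosen = rng.choice(active, size=num_knockout, replace=False)
    degraded[chosen] = degrade_value                  # bt_degrade_value
    return degraded

sample = np.array([1., 0., 1., 1., 0., 1.], dtype=np.float32)
print(degrade_random_sketch(sample))  # e.g. [1. 0. 0. 1. 0. 0.] - half the active bits cleared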

    def __init__(self):
        self._name = None
        self._hparams = None
        self._dual = None
        self._summary_op = None
        self._summary_values = None
        self._get_weights_op = None
        self._input_shape_visualisation = None
        self._input_values = None
        self._loss = None
        self._active_bits = None
        self._blank_indices = []

    def build(self,
              input_values,
              input_shape_visualisation,
              hparams,
              name='diff_plasticity'):
        """Initializes the model parameters.

    Args:
        input_values: Tensor containing input
        input_shape_visualisation: The shape of the input, for display (internal is vectorized)
        hparams: The hyperparameters for the model as tf.contrib.training.HParams.
        name: A globally unique graph name used as a prefix for all tensors and ops.
    """
        self._name = name
        self._hparams = hparams
        self._dual = DualData(self._name)
        self._summary_op = None
        self._summary_values = None
        self._get_weights_op = None
        self._input_shape_visualisation = input_shape_visualisation
        self._input_values = input_values

        length_sparse = int(self._hparams.filters *
                            self._hparams.input_sparsity)
        self._active_bits = int(
            self._hparams.filters - length_sparse
        )  # do it this way to match the int rounding in generator

        self._build()

    @property
    def name(self):
        return self._name

    def get_loss(self):
        return self._loss

    def get_dual(self):
        return self._dual

    def _build(self):
        """Build the component"""

        with tf.variable_scope(self._name, reuse=tf.AUTO_REUSE):
            input_tensor = self._input_values

            if self._hparams.use_batch_transformer:
                input_tensor = self._build_batch_transformer(
                    self._input_values)

            # add bias neurons - important that it is after batch_transformer, and therefore the degradations
            num_bias_neurons = self._hparams.bias_neurons
            if num_bias_neurons > 0:
                num_batches = input_tensor.shape.as_list()[0]
                input_shape_with_bias = [num_batches, num_bias_neurons]
                bias_neurons = tf.ones(shape=input_shape_with_bias,
                                       dtype=tf.float32)
                input_tensor = tf.concat([input_tensor, bias_neurons], 1)

                # add bias to the raw input also, so that visualisations align (in terms of expectations and sizes)
                num_batches = self._input_values.shape.as_list()[0]
                input_shape_with_bias = [num_batches, num_bias_neurons]
                bias_neurons = tf.ones(shape=input_shape_with_bias,
                                       dtype=tf.float32)
                self._input_values = tf.concat(
                    [self._input_values, bias_neurons], 1)

            _, testing = self._build_rnn(input_tensor)

            # output fork - this path doesn't accumulate gradients
            # -----------------------------------------------------------------
            stop_gradient = tf.stop_gradient(testing)
            self._dual.set_op('output', stop_gradient)

    def _build_batch_transformer(self, input_tensor):
        """
    input_tensor = input batch, shape = [input batch size, sample tensor shape]
    The output is transformed, shape = [output batch size, sample tensor shape].
    Adds a stop gradient at the end because nothing in here is trainable.

    - controls the number of:
      - repeats of each sample
      - repeats of blanks between groups of repeated samples
      - presentations of samples (with repeats)
    - randomly selects a sample, degrades it, and appends it at the end together
      with the un-corrupted copy

    The batch order is shuffled for each presentation.
    """

        self._dual.set_op('bt_input', input_tensor)

        # get hyper params
        sample_repeat = self._hparams.bt_sample_repeat
        blank_repeat = self._hparams.bt_blank_repeat
        presentation_repeat = self._hparams.bt_presentation_repeat
        is_degrade = self._hparams.bt_degrade
        degrade_repeat = self._hparams.bt_degrade_repeat
        degrade_type = self._hparams.bt_degrade_type
        degrade_value = self._hparams.bt_degrade_value
        degrade_factor = self._hparams.bt_degrade_factor

        # note: x_batch = input batch, y_batch = transformed batch

        with tf.variable_scope('batch_transformer'):

            # convert input to shape [batch, samples]
            input_shape = input_tensor.get_shape().as_list()
            input_area = np.prod(input_shape[1:])
            batch_input_shape = (-1, input_area)

            input_vector = tf.reshape(input_tensor,
                                      batch_input_shape,
                                      name='input_vector')
            logging.debug(input_vector)

            x_batch_length = input_shape[0]
            y_batch_length = presentation_repeat * x_batch_length * (sample_repeat + blank_repeat) + \
                             (1 + degrade_repeat if is_degrade else 0)

            self._blank_indices = []
            for p in range(presentation_repeat):
                start_pres = p * x_batch_length * (sample_repeat +
                                                   blank_repeat)
                for s in range(x_batch_length):
                    blank_start = start_pres + s * (
                        sample_repeat + blank_repeat) + sample_repeat
                    self._blank_indices.append(
                        [blank_start, blank_start + blank_repeat - 1])

            # start with all blanks, in this case zero tensors
            y_batch = tf.get_variable(
                initializer=tf.zeros(shape=[y_batch_length, input_area]),
                trainable=False,
                name='blanks')

            # use scatter updates to fill with repeats; scatter cannot do 1-to-many,
            # so it is done `presentation_repeat * sample_repeat` times
            presentation_length = x_batch_length * (sample_repeat +
                                                    blank_repeat)
            for p in range(presentation_repeat):

                input_vector = tf.random_shuffle(input_vector)

                for i in range(sample_repeat):
                    x2y = []  # the idx itself is the x_idx, val = y_idx
                    for x_idx in range(x_batch_length):
                        y_idx = (p * presentation_length
                                 ) + x_idx * (sample_repeat + blank_repeat) + i
                        x2y.append(y_idx)
                    xy_scatter_map = tf.constant(value=x2y,
                                                 name='x_y_scatter_' + str(i))
                    y_batch = tf.scatter_update(y_batch,
                                                xy_scatter_map,
                                                input_vector,
                                                name="sample_repeat")

            # append degraded and non-degraded samples
            if is_degrade:
                # randomly choose one of the input vectors
                input_shuffled = tf.random_shuffle(input_vector)
                target = input_shuffled[0]

                if degrade_type == 'horizontal':
                    degraded = image_utils.degrade_image(
                        input_shuffled,
                        label=None,
                        degrade_type='horizontal',
                        degrade_value=degrade_value)[0]
                elif degrade_type == 'vertical':
                    raise NotImplementedError(
                        'vertical degradation not implemented')
                elif degrade_type == 'random':

                    # The commented-out line below caused a major malfunction: the single
                    # sample was passed to the degrade function in place of the whole batch.
                    # degraded_samples = tf.reshape(target, batch_input_shape)  # for efficiency, only degrade one image

                    min_value_0 = True
                    if min_value_0 is not True:
                        degraded = image_utils.degrade_image(
                            image=input_shuffled,
                            label=None,
                            degrade_type='random',
                            degrade_value=degrade_value,
                            degrade_factor=degrade_factor)[0]
                    else:
                        # degrade the high bits (not the bits that are already zero)
                        eps = 0.01
                        degrade_mask = tf.greater(target, 1.0 - eps)
                        degrade_mask = tf.to_float(degrade_mask)
                        degrade_mask = tf_print(degrade_mask,
                                                "degrade_mask",
                                                mute=True)

                        degraded = tf_utils.degrade_by_mask(
                            input_tensor=input_shuffled,
                            num_active=self._active_bits,
                            degrade_mask=degrade_mask,
                            degrade_factor=degrade_factor,
                            degrade_value=degrade_value)[0]

                else:
                    raise NotImplementedError('Unknown degradation type.')

                degraded_repeated = tf.ones([degrade_repeat, input_area])
                degraded_repeated = degraded_repeated * degraded

                target = tf.reshape(target, [1, input_area])
                degraded_and_target = tf.concat([degraded_repeated, target], 0)

                index_map = []
                for i in range(degrade_repeat + 1):
                    index_map.insert(0, y_batch_length - 1 - i)

                y_batch = tf.scatter_update(y_batch,
                                            tf.constant(index_map),
                                            degraded_and_target,
                                            name="degradetarget")
                y_batch = tf.stop_gradient(y_batch)

        self._dual.set_op('bt_output', y_batch)
        return y_batch
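To make the transformed batch layout concrete, the length and blank-index arithmetic above can be reproduced off-graph. A short sketch with assumed toy values:

# Toy values: 3 input samples, each repeated twice, one blank after each group,
# the whole presentation repeated twice, plus a degraded sample and its target.
x_batch_length = 3
sample_repeat, blank_repeat, presentation_repeat = 2, 1, 2
is_degrade, degrade_repeat = True, 2

y_batch_length = presentation_repeat * x_batch_length * (sample_repeat + blank_repeat) + \
                 (1 + degrade_repeat if is_degrade else 0)
# = 2 * 3 * 3 + 3 = 21

blank_indices = []
for p in range(presentation_repeat):
    start_pres = p * x_batch_length * (sample_repeat + blank_repeat)
    for s in range(x_batch_length):
        blank_start = start_pres + s * (sample_repeat + blank_repeat) + sample_repeat
        blank_indices.append([blank_start, blank_start + blank_repeat - 1])
# blank_indices == [[2, 2], [5, 5], [8, 8], [11, 11], [14, 14], [17, 17]]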

    def _build_rnn(self, input_tensor):
        """
    Build the encoder network

    input_tensor = 1 batch = 1 episode (batch size, #neurons)
    Assumes second last item is degraded input, last is target
    """

        w_trainable = False
        x_shift_trainable = False
        eta_trainable = True

        input_shape = input_tensor.get_shape().as_list()
        input_area = np.prod(input_shape[1:])
        batch_input_shape = (-1, input_area)

        filters = self._hparams.filters + self._hparams.bias_neurons
        hidden_size = [filters]
        weights_shape = [filters, filters]

        with tf.variable_scope("rnn"):
            init_state_pl = self._dual.add('init_pl',
                                           shape=hidden_size,
                                           default_value=0).add_pl()
            init_hebb_pl = self._dual.add('hebb_init_pl',
                                          shape=weights_shape,
                                          default_value=0).add_pl()

            # ensure init placeholders are being reset every iteration
            init_hebb_pl = tf_print(init_hebb_pl,
                                    "Init Hebb:",
                                    summarize=100,
                                    mute=True)

            # Input reshape: Ensure flat (vector) x batch size input (batches, inputs)
            # -----------------------------------------------------------------
            input_vector = tf.reshape(input_tensor,
                                      batch_input_shape,
                                      name='input_vector')

            # unroll input into a series so that we can iterate over it easily
            x_series = tf.unstack(
                input_vector, axis=0,
                name="ep-series")  # batch_size of hidden_size

            # get the target and degraded samples
            target = input_vector[-1]
            target = tf_print(target, "TARGET\n", mute=True)
            degraded_extracted = input_vector[-2]
            degraded_extracted = tf_print(degraded_extracted,
                                          "DEGRADED-extracted\n",
                                          mute=True)
            self._dual.set_op('target', target)
            self._dual.set_op('degraded_raw', degraded_extracted)

            y_current = tf.reshape(init_state_pl, [1, filters],
                                   name="init-curr-state")
            hebb = init_hebb_pl

            with tf.variable_scope("slow-weights"):
                w_default = 0.01
                alpha_default = 0.1
                eta_default = 0.1
                x_shift_default = 0.01
                bias_default = 1.0 * w_default  # emulates the Miconi method of an additional input fixed at 20, i.e.
                # it creates an output of 1.0, which is then multiplied by the weight (here we have a plain bias, no weight)

                if w_trainable:
                    w = tf.get_variable(
                        name="w",
                        initializer=(w_default *
                                     tf.random_uniform(weights_shape)))
                else:
                    w = tf.zeros(weights_shape)

                alpha = tf.get_variable(
                    name="alpha",
                    initializer=(alpha_default *
                                 tf.random_uniform(weights_shape)))

                if eta_trainable:
                    eta = tf.get_variable(name="eta",
                                          initializer=(eta_default *
                                                       tf.ones(shape=[1])))
                else:
                    eta = eta_default * tf.ones([1])

                if x_shift_trainable:
                    x_shift = tf.get_variable(name="x_shift",
                                              initializer=(x_shift_default *
                                                           tf.ones(shape=[1])))
                else:
                    x_shift = 0

                self._dual.set_op('w', w)
                self._dual.set_op('alpha', alpha)
                self._dual.set_op('eta', eta)
                self._dual.set_op('x_shift', x_shift)

                if self._hparams.bias:
                    bias = tf.get_variable(name="bias",
                                           initializer=(bias_default *
                                                        tf.ones(filters)))
                    self._dual.set_op('bias', bias)
                    bias = tf_print(bias,
                                    "*** bias ***",
                                    mute=MUTE_DEBUG_GRAPH)

            with tf.variable_scope("layers"):
                hebb = tf_print(hebb,
                                "*** initial hebb ***",
                                mute=MUTE_DEBUG_GRAPH)
                y_current = tf_print(y_current, "*** initial state ***")
                w = tf_print(w, "*** w ***", mute=MUTE_DEBUG_GRAPH)
                alpha = tf_print(alpha, "*** alpha ***", mute=MUTE_DEBUG_GRAPH)

                i = 0
                last_x = None
                outer_first = None
                outer_last = None
                for x in x_series:
                    # last sample is target, so don't process it again
                    if i == len(x_series) - 1:  # [0:x, 1:d, 2:t], l=3
                        break
                    layer_name = "layer-" + str(i)
                    with tf.variable_scope(layer_name):
                        x = self._hparams.bt_amplify_factor * x
                        x = tf_print(x,
                                     str(i) + ": x_input",
                                     mute=MUTE_DEBUG_GRAPH)
                        y_current = tf_print(y_current,
                                             str(i) + ": y(t-1)",
                                             mute=MUTE_DEBUG_GRAPH)

                        # neurons latch on as they have bidirectional connections
                        # attempt to remove this issue by knocking out lateral connections
                        remove = 'random'
                        if remove == 'circular':
                            diagonal_mask = tf.convert_to_tensor(
                                np.tril(
                                    np.ones(weights_shape, dtype=np.float32),
                                    0))
                            alpha = tf.multiply(alpha, diagonal_mask)
                        elif remove == 'random':
                            size = np.prod(weights_shape[:])
                            knockout_mask = np.ones(size)
                            knockout_mask[:int(size / 2)] = 0
                            np.random.shuffle(knockout_mask)
                            knockout_mask = np.reshape(knockout_mask,
                                                       weights_shape)
                            alpha = tf.multiply(alpha, knockout_mask)

                        # ---------- Calculate next output of the RNN
                        weighted_sum = tf.add(
                            tf.matmul(y_current - x_shift,
                                      tf.add(w,
                                             tf.multiply(alpha,
                                                         hebb,
                                                         name='lyr-mul'),
                                             name="lyr-add_w_ah"),
                                      name='lyr-mul-add-matmul'), x,
                            "weighted_sum")

                        if self._hparams.bias:
                            weighted_sum = tf.add(
                                weighted_sum, bias)  # weighted sum with bias

                        y_next, _ = activation_fn(weighted_sum,
                                                  self._hparams.nonlinearity)

                        with tf.variable_scope("fast_weights"):
                            # ---------- Update Hebbian fast weights
                            # outer product of (yin * yout) = (current_state * next_state)
                            outer = tf.matmul(tf.reshape(y_current,
                                                         shape=[filters, 1]),
                                              tf.reshape(y_next,
                                                         shape=[1, filters]),
                                              name="outer-product")
                            outer = tf_print(outer,
                                             str(i) +
                                             ": *** outer = y(t-1) * y(t) ***",
                                             mute=MUTE_DEBUG_GRAPH)

                            if i == 1:  # first outer is zero
                                outer_first = outer
                            outer_last = outer

                            hebb = (1.0 - eta) * hebb + eta * outer
                            hebb = tf_print(hebb,
                                            str(i) + ": *** hebb ***",
                                            mute=MUTE_DEBUG_GRAPH)

                        # record for visualisation the output when presented with the last blank
                        idx_blank_first = self._blank_indices[-1][0]
                        idx_blank_last = self._blank_indices[-1][1]

                        if i == idx_blank_first:
                            blank_output_first = y_next
                            self._dual.set_op('blank_output_first',
                                              blank_output_first)

                        if i == idx_blank_last:
                            blank_output_last = y_next
                            self._dual.set_op('blank_output_last',
                                              blank_output_last)

                        y_current = y_next
                        last_x = x
                        i = i + 1

            self._dual.set_op('hebb', hebb)
            self._dual.set_op('outer_first', outer_first)
            self._dual.set_op('outer_last', outer_last)

            last_x = tf_print(last_x, str(i) + ": LAST-X", mute=True)
            self._dual.set_op('degraded', last_x)

            output_pre_masked = tf.squeeze(y_current)
            self._dual.set_op('output_pre_masked',
                              output_pre_masked)  # pre-masked output

        # External masking
        # -----------------------------------------------------------------
        with tf.variable_scope("masking"):
            mask_pl = self._dual.add('mask',
                                     shape=hidden_size,
                                     default_value=1.0).add_pl()
            y_masked = tf.multiply(y_current, mask_pl, name='y_masked')

        # Setup the training operations
        # -----------------------------------------------------------------
        with tf.variable_scope("optimizer"):
            loss_op = self._build_loss_op(y_masked, target)
            self._dual.set_op('loss', loss_op)

            self._optimizer = tf.train.AdamOptimizer(
                self._hparams.learning_rate)
            training_op = self._optimizer.minimize(
                loss_op,
                global_step=tf.train.get_or_create_global_step(),
                name='training_op')
            self._dual.set_op('training', training_op)

        return y_masked, y_masked
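For reference, each unrolled step above implements the differentiable-plasticity update: y(t) = f((y(t-1) - x_shift) . (w + alpha * Hebb) + x) and Hebb <- (1 - eta) * Hebb + eta * outer(y(t-1), y(t)). A minimal numpy sketch with toy values (no bias, no knockout mask):

import numpy as np

filters = 4
rng = np.random.RandomState(0)
w = np.zeros((filters, filters))                    # fixed slow weights (w_trainable=False)
alpha = 0.1 * rng.uniform(size=(filters, filters))  # per-connection plasticity coefficients
eta = 0.1                                           # learning rate of the fast weights
hebb = np.zeros((filters, filters))                 # Hebbian fast weights
y = np.zeros((1, filters))                          # initial state

for x in rng.uniform(size=(5, filters)):            # a toy episode of 5 inputs
    weighted_sum = y @ (w + alpha * hebb) + x       # x_shift = 0 here
    y_next = np.tanh(weighted_sum)
    outer = y.reshape(filters, 1) @ y_next.reshape(1, filters)
    hebb = (1.0 - eta) * hebb + eta * outer         # decay old traces, blend in the new one
    y = y_next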

    def _build_loss_op(self, output, target):
        if self._hparams.loss_type == 'sse':
            losses = tf.subtract(output, target, name="losses")
            self._dual.set_op("losses", losses)
            return tf.reduce_sum(tf.square(losses))
        else:
            raise NotImplementedError('Loss function not implemented: ' +
                                      str(self._hparams.loss_type))

    # OP ACCESS ------------------------------------------------------------------
    def get_decoding_op(self):
        """The equivalent of decoding is simply the output"""
        return self._dual.get_op('output')

    def get_encoding_op(self):
        raise NotImplementedError(
            'There is no equivalent to encoding for this component')

    def get_loss_op(self):
        return self._dual.get_op('loss')

    # VALUE ACCESS ------------------------------------------------------------------
    def get_decoding(self):
        return self._dual.get_values('output')

    def get_encoding(self):
        raise NotImplementedError(
            'There is no equivalent to encoding for this component')

    # MODULAR INTERFACE ------------------------------------------------------------------
    def reset(self):
        self._loss = 0.0
        logging.warning(
            'reset() not yet properly implemented for this component.')

    def update_feed_dict(self, feed_dict, batch_type='training'):
        if batch_type == 'training':
            self.update_training_dict(feed_dict)
        if batch_type == 'encoding':
            self.update_testing_dict(feed_dict)

    def add_fetches(self, fetches, batch_type='training'):
        if batch_type == 'training':
            self.add_training_fetches(fetches)
        if batch_type == 'encoding':
            self.add_testing_fetches(fetches)

    def set_fetches(self, fetched, batch_type='training'):
        if batch_type == 'training':
            self.set_training_fetches(fetched)
        if batch_type == 'encoding':
            self.set_testing_fetches(fetched)

    def get_features(self, batch_type='training'):
        return self._dual.get_values('encoding')

    def build_summaries(self, batch_types=None, scope=None):
        """Builds all summaries."""
        if not scope:
            scope = self._name + '/summaries/'
        with tf.name_scope(scope):
            for batch_type in batch_types:
                self._build_batchtype_summaries(
                    batch_type=batch_type
                )  # same for all batch_types.... for now

    def write_summaries(self, step, writer, batch_type='training'):

        if batch_type == 'training':
            # debugging batch_transformer
            if self._hparams.use_batch_transformer:
                bt_input = self._dual.get_values('bt_input')
                bt_output = self._dual.get_values('bt_output')
                logging.debug("BT Input", bt_input[:4][:2])
                logging.debug("BT Output", bt_output[:4][:2])

        if self._summary_values is not None:
            writer.add_summary(self._summary_values, step)
            writer.flush()

    def _update_with_dual(self, feed_dict, name):
        """Convenience function to update a feed dict with a dual, using its placeholder and values."""
        dual = self._dual.get(name)
        dual_pl = dual.get_pl()
        dual_values = dual.get_values()
        feed_dict.update({
            dual_pl: dual_values,
        })

    # TRAINING ------------------------------------------------------------------
    def update_training_dict(self, feed_dict):
        """Update the feed_dict for training mode (batch_types)"""
        self._update_with_dual(
            feed_dict,
            'hebb_init_pl')  # re-initialize hebbian at the start of episode
        self._update_with_dual(
            feed_dict,
            'init_pl')  # re-initialize rnn state at the start of episode
        self._update_with_dual(feed_dict, 'mask')  # mask

    def add_training_fetches(self, fetches):
        """Add the fetches (ops to be evaluated) for training mode (batch_types)"""
        fetches[self._name] = {
            'loss': self._dual.get_op('loss'),  # the calculation of loss
            'training': self._dual.get_op('training'),  # the optimisation
            'output': self._dual.get_op('output'),  # the output value
            # debugging
            'target': self._dual.get_op('target'),
            'degraded': self._dual.get_op('degraded')
        }

        if self._hparams.use_batch_transformer:
            fetches[self._name]['bt_input'] = self._dual.get_op('bt_input')
            fetches[self._name]['bt_output'] = self._dual.get_op('bt_output')

        if self._summary_op is not None:
            fetches[self._name]['summaries'] = self._summary_op

    def set_training_fetches(self, fetched):
        """Add the fetches for training - the ops whose values will be available"""
        self_fetched = fetched[self._name]
        self._loss = self_fetched['loss']

        names = ['loss', 'output', 'target', 'degraded']

        if self._hparams.use_batch_transformer:
            names = names + ['bt_input', 'bt_output']

        self._dual.set_fetches(fetched, names)

        if self._summary_op is not None:
            self._summary_values = fetched[self._name]['summaries']

    # TESTING ------------------------------------------------------------------
    def update_testing_dict(self, feed_dict):
        self.update_training_dict(
            feed_dict)  # do the same for testing and training

    def add_testing_fetches(self, fetches):
        fetches[self._name] = {
            'loss': self._dual.get_op('loss'),  # the calculation of loss
            'output': self._dual.get_op('output')  # the output value
        }

        if self._summary_op is not None:
            fetches[self._name]['summaries'] = self._summary_op

    def set_testing_fetches(self, fetched):
        self_fetched = fetched[self._name]
        self._loss = self_fetched['loss']

        names = ['loss', 'output']
        self._dual.set_fetches(fetched, names)

        if self._summary_op is not None:
            self._summary_values = fetched[self._name]['summaries']

    # SUMMARIES ------------------------------------------------------------------
    def write_filters(self, session):
        """Write the learned filters to disk."""

        w = self._dual.get_op('w')
        weights_values = session.run(w)
        weights_transpose = np.transpose(weights_values)

        filter_height = self._input_shape_visualisation[1]
        filter_width = self._input_shape_visualisation[2]
        np_write_filters(weights_transpose, [filter_height, filter_width])

    def _build_batchtype_summaries(self, batch_type):
        """Same for all batch types right now. Other components have separate method for each batch_type."""
        with tf.name_scope(batch_type):
            summaries = self._build_summaries()
            self._summary_op = tf.summary.merge(summaries)
            return self._summary_op

    def _build_summaries(self):
        """
    Build the summaries for TensorBoard.
    Note that the outer scopes have already been set by `build_summaries()` and
    `_build_batchtype_summaries()`, i.e. 'component_name' / 'batch_type' / 'summary name'.
    """
        max_outputs = 3
        summaries = []

        # images
        # ------------------------------------------------
        summary_input_shape = image_utils.get_image_summary_shape(
            self._input_shape_visualisation)

        # input images
        input_summary_reshape = tf.reshape(self._input_values,
                                           summary_input_shape,
                                           name='input_summary_reshape')
        input_summary_op = tf.summary.image('input_images',
                                            input_summary_reshape,
                                            max_outputs=max_outputs)
        summaries.append(input_summary_op)

        # degraded, target and completed images, and histograms where relevant
        target = self._dual.get_op('target')
        degraded = self._dual.get_op('degraded')
        decoding_op = self.get_decoding_op()

        output_hist = tf.summary.histogram("output", decoding_op)
        summaries.append(output_hist)

        input_hist = tf.summary.histogram("input", self._input_values)
        summaries.append(input_hist)

        # network output when presented with blank
        blank_output_first = self._dual.get_op('blank_output_first')
        blank_first = tf.summary.image(
            'blank_first', tf.reshape(blank_output_first, summary_input_shape))
        summaries.append(blank_first)

        blank_output_last = self._dual.get_op('blank_output_last')
        blank_last = tf.summary.image(
            'blank_last', tf.reshape(blank_output_last, summary_input_shape))
        summaries.append(blank_last)

        with tf.name_scope('optimize'):
            completed_summary_reshape = tf.reshape(
                decoding_op, summary_input_shape, 'completed_summary_reshape')
            summaries.append(
                tf.summary.image('b_completed', completed_summary_reshape))

            if self._hparams.bt_degrade:
                degraded_summary_reshape = tf.reshape(
                    degraded, summary_input_shape, 'degraded_summary_reshape')
                summaries.append(
                    tf.summary.image('a_degraded', degraded_summary_reshape))

                target_summary_reshape = tf.reshape(target,
                                                    summary_input_shape,
                                                    'target_summary_reshape')
                summaries.append(
                    tf.summary.image('c_target', target_summary_reshape))

        # display slow weights as images and distributions
        with tf.name_scope('slow-weights'):
            w = self._dual.get_op('w')
            add_square_as_square(summaries, w, 'w')

            w_hist = tf.summary.histogram("w", w)
            summaries.append(w_hist)

            alpha = self._dual.get_op('alpha')
            add_square_as_square(summaries, alpha, 'alpha')

            alpha_hist = tf.summary.histogram("alpha", alpha)
            summaries.append(alpha_hist)

            if self._hparams.bias:
                bias = self._dual.get_op('bias')
                bias_image_shape, _ = image_utils.square_image_shape_from_1d(
                    self._hparams.filters)
                bias_image = tf.reshape(bias,
                                        bias_image_shape,
                                        name='bias_summary_reshape')
                summaries.append(tf.summary.image('bias', bias_image))

                bias_hist = tf.summary.histogram("bias", bias)
                summaries.append(bias_hist)

            # eta
            eta_op = self._dual.get_op('eta')
            eta_scalar = tf.reduce_sum(eta_op)
            eta_summary = tf.summary.scalar('eta', eta_scalar)
            summaries.append(eta_summary)

            # x_shift
            x_shift_op = self._dual.get_op('x_shift')
            xs_scalar = tf.reduce_sum(x_shift_op)
            xs_summary = tf.summary.scalar('x_shift', xs_scalar)
            summaries.append(xs_summary)

        # display fast weights (eta and hebbian), as image, scalars and histogram
        with tf.name_scope('fast-weights'):

            # as images
            hebb = self._dual.get_op('hebb')
            add_square_as_square(summaries, hebb, 'hebb')

            # as scalars
            hebb_summary = tf_build_stats_summaries_short(hebb, 'hebb')
            summaries.append(hebb_summary)

            # as histograms
            hebb_hist = tf.summary.histogram("hebb", hebb)
            summaries.append(hebb_hist)

            hebb_per_neuron = tf.reduce_sum(tf.abs(hebb), 0)
            hebb_per_neuron = tf.summary.histogram('hebb_pn', hebb_per_neuron)
            summaries.append(hebb_per_neuron)

            # outer products
            outer_first = self._dual.get_op('outer_first')
            outer_last = self._dual.get_op('outer_last')
            add_square_as_square(summaries, outer_first, 'outer_first')
            add_square_as_square(summaries, outer_last, 'outer_last')

        # optimization related quantities
        with tf.name_scope('optimize'):
            # loss
            loss_op = self.get_loss_op()
            loss_summary = tf.summary.scalar('loss', loss_op)
            summaries.append(loss_summary)

            # losses as an image
            losses = self._dual.get_op("losses")
            shape = losses.get_shape().as_list()
            volume = np.prod(shape[1:])
            losses_image_shape, _ = image_utils.square_image_shape_from_1d(
                volume)
            losses_image = tf.reshape(losses, losses_image_shape)
            summaries.append(tf.summary.image('losses', losses_image))

        input_stats_summary = tf_build_stats_summaries_short(
            self._input_values, 'input-stats')
        summaries.append(input_stats_summary)

        return summaries
Example #10
class DGStubComponent(SummaryComponent):
  """Dentate Gyrus (DG) Stub Component to use in the Episodic Component."""

  @staticmethod
  def default_hparams():
    return tf.contrib.training.HParams(
        batch_size=20,
        filters=50,
        sparsity=5,
        summarize_levels=SummarizeLevels.ALL.value,   # does nothing
        max_outputs=3                                 # does nothing
    )

  def __init__(self):
    self._name = None
    self._hidden_name = None
    self._hparams = None
    self._dual = None

  @property
  def name(self):
    return self._name

  def reset(self):
    pass

  def update_feed_dict(self, feed_dict, batch_type='training'):
    """Add items to the feed dict to run a batch."""

  def build_summaries(self, batch_types=None, max_outputs=3, scope=None):
    """Build any summaries needed to debug the module."""

  def _build_summaries(self, batch_type, max_outputs=3):
    """Build summaries for this batch type. Can be same for all batch types."""

  def write_summaries(self, step, writer, batch_type='training'):
    """Write any summaries needed to debug the module."""

  def add_fetches(self, fetches, batch_type='training'):
    """Add graph ops to the fetches dict so they are evaluated."""
    names = ['encoding']
    self._dual.add_fetches(fetches, names)

  def set_fetches(self, fetched, batch_type='training'):
    """Store results of graph ops in the fetched dict so they are available as needed."""
    names = ['encoding']
    self._dual.set_fetches(fetched, names)

  def _dg_stub_batch(self):
    """
    Return a batch of non-overlapping n-hot samples with values in {0, 1}.
    All off-graph.
    """

    batch_size = self._hparams.batch_size
    sample_size = self._hparams.filters
    n = self._hparams.sparsity

    assert batch_size * n <= sample_size, "Can't produce batch_size {0} non-overlapping samples, " \
           "reduce n {1} or increase sample_size {2}".format(batch_size, n, sample_size)

    batch = np.zeros(shape=(batch_size, sample_size))

    # fill each sample with n consecutive active bits, shifted by n per row
    for idx in range(batch_size):
      start_idx = idx * n
      end_idx = start_idx + n
      batch[idx][start_idx:end_idx] = 1

    return batch
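A worked example of the batch this produces, with assumed toy sizes (batch_size=3, filters=12, sparsity=2):

import numpy as np

batch_size, sample_size, n = 3, 12, 2
batch = np.zeros(shape=(batch_size, sample_size))
for idx in range(batch_size):
  batch[idx][idx * n:(idx + 1) * n] = 1
# batch:
# [[1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
#  [0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
#  [0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0.]]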

  def get_encoding_op(self):
    return self._dual.get_op('encoding')

  def get_encoding(self):
    return self._dual.get_values('encoding')

  def get_decoding(self):
    return self._dual.get_values('encoding')    # return encoding, same thing for both

  def get_loss(self):
    return 0.0

  def build(self, hparams, name='dg_stub'):
    """Builds the DG Stub."""
    self._name = name
    self._hparams = hparams
    self._dual = DualData(self._name)

    batch_arr = self._dg_stub_batch()

    the_one_batch = tf.convert_to_tensor(batch_arr, dtype=tf.float32)
    self._dual.set_op('encoding', the_one_batch)

    # add a stub of secondary decoding also, it is expected by workflow
    self._dual.add('secondary_decoding_input', shape=the_one_batch.shape, default_value=1.0).add_pl(default=True)