Exemplo n.º 1
0
 def compute_sparse_loss(self, a_in):
   """Return the scaled KL-divergence sparsity penalty for activity `a_in`.

   The activity is averaged over every axis except the last one, giving one
   mean activation per unit. Each mean is compared with `self.target_act`
   through the two-term Bernoulli KL form
   target * (log target - log mean) + (1 - target) * (log(1 - target) - log(1 - mean)),
   using the project helper `ef.safe_log` for the logarithms. The per-unit
   terms are summed and multiplied by `self.sparse_mult`.

   Args:
     a_in: activity tensor whose static rank is known (the last axis is kept).

   Returns:
     The `sparse_loss` tensor built inside the "unsupervised" variable scope.
   """
   with tf.compat.v1.variable_scope("unsupervised"):
     rank = len(a_in.get_shape().as_list())
     # Mean activity per unit: collapse every axis but the final one.
     mean_act = tf.reduce_mean(input_tensor=a_in, axis=tuple(range(rank - 1)),
       name="batch_avg_activity")
     target = self.target_act
     # "On" term of the Bernoulli KL divergence.
     kl_on = target * tf.subtract(ef.safe_log(target), ef.safe_log(mean_act),
       name="kl_p")
     # "Off" term of the Bernoulli KL divergence.
     kl_off = (1 - target) * tf.subtract(ef.safe_log(1 - target),
       ef.safe_log(1 - mean_act), name="kl_q")
     total_kl = tf.reduce_sum(input_tensor=tf.add(kl_on, kl_off), name="kld")
     return tf.multiply(self.sparse_mult, total_kl, name="sparse_loss")
Exemplo n.º 2
0
  def build_graph_from_input(self, input_node):
    """Build the TensorFlow graph object.

    Creates the scalar/label placeholders, builds the VAE module on
    `input_node`, and feeds either the VAE reconstruction
    (`params.train_on_recon`) or the VAE latent activity into an MLP module.
    It then defines a `total_loss` gated by the `train_vae` flag, plus
    reconstruction-quality (`pSNRdB`) and classification-accuracy metrics.

    Args:
      input_node: tensor fed to the VAE module. NOTE(review): the metrics use
        axis 1 as the per-sample axis and reduce MSE over axes [1, 0], which
        presumes a 2-D (batch, features) input — confirm.

    Raises:
      ValueError: if `params.mlp_layer_types[0]` is neither "fc" nor "conv"
        when training on reconstructions, or is not "fc" when training on the
        VAE latent activity.
    """
    with tf.device(self.params.device):
      with self.graph.as_default():
        with tf.compat.v1.variable_scope("auto_placeholders") as scope:
          self.label_placeholder = tf.compat.v1.placeholder(tf.float32,
            shape=self.label_shape, name="input_labels")
          self.w_decay_mult = tf.compat.v1.placeholder(tf.float32, shape=(), name="w_decay_mult")
          self.w_norm_mult = tf.compat.v1.placeholder(tf.float32, shape=(), name="w_norm_mult")
          self.kld_mult = tf.compat.v1.placeholder(tf.float32, shape=(), name="kld_mult")
          self.train_vae = tf.compat.v1.placeholder(tf.bool, shape=(), name="train_vae")

        with tf.compat.v1.variable_scope("placeholders") as scope:
          self.mlp_dropout_keep_probs = tf.compat.v1.placeholder(tf.float32, shape=[None],
            name="mlp_dropout_keep_probs")
          self.ae_dropout_keep_probs = tf.compat.v1.placeholder(tf.float32, shape=[None],
            name="ae_dropout_keep_probs")

        # Rebind the bool placeholder to a float32 0/1 value so it can gate the
        # two losses below (the attribute no longer points at the placeholder).
        self.train_vae = tf.cast(self.train_vae, tf.float32)

        self.vae_module = self.build_vae_module(input_node)
        self.trainable_variables.update(self.vae_module.trainable_variables)

        # Select the MLP input. Explicit raises (instead of asserts) keep this
        # configuration check active under `python -O`.
        first_mlp_layer = self.params.mlp_layer_types[0]
        if self.params.train_on_recon:
          if first_mlp_layer == "conv":
            # Conv layers need the full data shape back; batch size is dynamic.
            data_shape = [tf.shape(input=input_node)[0]]+self.params.full_data_shape
            mlp_input = tf.reshape(self.vae_module.reconstruction, shape=data_shape)
          elif first_mlp_layer == "fc":
            mlp_input = self.vae_module.reconstruction
          else:
            raise ValueError("params.mlp_layer_types must be 'fc' or 'conv'")
        else: # train on VAE latent encoding
          if first_mlp_layer != "fc":
            raise ValueError("MLP must have FC layers to train on VAE activity")
          mlp_input = self.vae_module.a
        self.mlp_module = self.build_mlp_module(mlp_input)
        self.trainable_variables.update(self.mlp_module.trainable_variables)

        with tf.compat.v1.variable_scope("loss") as scope:
          # Loss switches based on the train_vae flag (1.0 -> VAE, 0.0 -> MLP).
          self.total_loss = self.train_vae * self.vae_module.total_loss + \
            (1-self.train_vae) * self.mlp_module.total_loss

        self.label_est = tf.identity(self.mlp_module.label_est, name="label_est")

        with tf.compat.v1.variable_scope("performance_metrics") as scope:
          # VAE metrics: 10 * ef.safe_log(var(input, axis=1)**2 / MSE), one value
          # per sample. NOTE(review): the log base of ef.safe_log is not visible here.
          MSE = tf.reduce_mean(input_tensor=tf.square(tf.subtract(input_node, self.vae_module.reconstruction)),
            axis=[1, 0], name="mean_squared_error")
          pixel_var = tf.nn.moments(x=input_node, axes=[1])[1]
          self.pSNRdB = tf.multiply(10.0, ef.safe_log(tf.math.divide(tf.square(pixel_var), MSE)),
            name="recon_quality")
          with tf.compat.v1.variable_scope("prediction_bools"):
            # Per-example correctness: argmax of estimate vs. argmax of label.
            self.correct_prediction = tf.equal(tf.argmax(input=self.label_est, axis=1),
              tf.argmax(input=self.label_placeholder, axis=1), name="individual_accuracy")
          with tf.compat.v1.variable_scope("accuracy"):
            self.accuracy = tf.reduce_mean(input_tensor=tf.cast(self.correct_prediction,
              tf.float32), name="avg_accuracy")
Exemplo n.º 3
0
  def build_graph_from_input(self, input_node):
    """Build the TensorFlow graph object.

    Builds an LCA module on `input_node` and an unrolled encoder with weights
    `w` ("w_enc") and lateral connectivity `s` ("lateral_connectivity").
    The encoder is trained to match the LCA module's activity (`lista_loss`).
    `total_loss` picks between the LCA loss and `lista_loss` according to the
    `train_lca` flag. Also adds a weight-normalization op and a
    reconstruction-quality metric.

    Args:
      input_node: tensor fed to the LCA module and multiplied by `self.w`.
        NOTE(review): the matmul and the axis-1 metrics suggest a 2-D
        (batch, features) input — confirm.
    """
    with tf.device(self.params.device):
      with self.graph.as_default():
        with tf.compat.v1.variable_scope("auto_placeholders") as scope:
          self.sparse_mult = tf.compat.v1.placeholder(tf.float32, shape=(), name="sparse_mult")
          self.train_lca = tf.compat.v1.placeholder(tf.bool, shape=(), name="train_lca")

        # Rebind the bool placeholder to a float32 0/1 value used to gate the
        # two losses below (the attribute no longer points at the placeholder).
        self.train_lca = tf.cast(self.train_lca, tf.float32)

        self.lca_module = self.build_lca_module(input_node)

        with tf.compat.v1.variable_scope("weight_inits") as scope:
          self.w_init = tf.compat.v1.truncated_normal_initializer(stddev=0.01)
          # Lateral weights start with zero diagonal gain and small off-diagonal gain.
          self.s_init = init_ops.GDNGammaInitializer(diagonal_gain=0.0, off_diagonal_gain=0.001)

        with tf.compat.v1.variable_scope("weights") as scope:
          self.w = tf.compat.v1.get_variable(name="w_enc", shape=self.w_shape, dtype=tf.float32,
            initializer=self.w_init, trainable=True)
          self.s = tf.compat.v1.get_variable(name="lateral_connectivity", shape=self.s_shape,
            dtype=tf.float32, initializer=self.s_init, trainable=True)
        # NOTE(review): only w and s are registered here; the LCA module's own
        # variables are not added in this method — confirm build_lca_module does it.
        self.trainable_variables.update({self.w.name:self.w, self.s.name:self.s})

        with tf.compat.v1.variable_scope("inference") as scope:
          feedforward_drive = tf.matmul(input_node, self.w, name="feedforward_drive")
          # a_list[0] thresholds the feedforward drive alone; each of the
          # params.num_layers following entries thresholds the feedforward drive
          # plus the previous activity projected through s (num_layers + 1 total).
          self.a_list = [self.lca_module.threshold_units(feedforward_drive, name="a_init")]
          for layer_id in range(self.params.num_layers):
            self.a_list.append(self.lca_module.threshold_units(feedforward_drive
              + tf.matmul(self.a_list[layer_id], self.s)))
          self.a = self.a_list[-1]

        with tf.compat.v1.variable_scope("loss") as scope:
          reduc_dim = list(range(1, len(self.lca_module.a.shape)))
          # LCA activity is the regression target; stop_gradient keeps the
          # LISTA loss from back-propagating into the LCA module.
          labels = tf.stop_gradient(self.lca_module.a)
          # Squared error summed over non-leading axes, then averaged.
          self.lista_loss = tf.reduce_mean(input_tensor=tf.reduce_sum(input_tensor=tf.square(labels - self.a),
            axis=reduc_dim))
          # Loss switches based on train_lca flag (1.0 -> LCA loss, 0.0 -> LISTA loss).
          self.total_loss = self.train_lca * self.lca_module.total_loss + \
            (1-self.train_lca) * self.lista_loss

        with tf.compat.v1.variable_scope("norm_weights") as scope:
          # Assigns w <- w normalized along axis 0. NOTE(review): for a 2-D w this
          # normalizes each column, despite the op name "row_l2_norm" — confirm intent.
          self.norm_lista_w = self.w.assign(tf.nn.l2_normalize(self.w, axis=0, epsilon=self.params.eps,
            name="row_l2_norm"))
          self.norm_weights = tf.group(self.norm_lista_w, name="l2_normalization")

        with tf.compat.v1.variable_scope("performance_metrics") as scope:
          # LCA metrics: 10 * ef.safe_log(var(input, axis=1)**2 / MSE), per sample.
          # NOTE(review): the log base of ef.safe_log is not visible here.
          MSE = tf.reduce_mean(input_tensor=tf.square(tf.subtract(input_node, self.lca_module.reconstruction)),
            axis=[1, 0], name="mean_squared_error")
          pixel_var = tf.nn.moments(x=input_node, axes=[1])[1]
          self.pSNRdB = tf.multiply(10.0, ef.safe_log(tf.math.divide(tf.square(pixel_var),
            MSE)), name="recon_quality")
Exemplo n.º 4
0
 def add_inference_ops_to_graph(self, num_imgs=1, num_inference_steps=None):
     """Add unrolled inference ops to the model graph for step-by-step analysis.

     Unrolls `num_inference_steps` states of the module's inference (the
     initial state plus `num_inference_steps - 1` calls to `step_inference`).
     For each state it records:
       - the internal state in `self.u_list`;
       - the thresholded activity in `self.a_list`;
       - the activity times `self.lca_g` in `self.ga_list`;
       - the recon/sparse/total losses in `self.loss_dict`;
       - a pSNR estimate in `self.psnr_list`. The entry for the initial state
         is the constant 0.0.

     Args:
       num_imgs: unused in this method; kept for interface compatibility.
       num_inference_steps: number of states to unroll. Defaults to
         `self.model_params.num_steps`.
     """
     if num_inference_steps is None:
         # This default is replicated in self.inference_analysis.
         num_inference_steps = self.model_params.num_steps
     with tf.device(self.model_params.device):
         with self.model.graph.as_default():
             module = self.model.module
             self.lca_b = module.compute_excitatory_current()
             self.lca_g = module.compute_inhibitory_connectivity()
             self.u_list = [module.u_zeros]
             self.a_list = [module.threshold_units(self.u_list[0])]
             self.ga_list = [tf.matmul(self.a_list[0], self.lca_g)]
             # No pSNR is computed for the initial state; 0.0 stands in for it.
             self.psnr_list = [tf.constant(0.0)]
             current_recon = self.model.compute_recon_from_encoding(
                 self.a_list[0])
             current_loss_list = [
                 [module.compute_recon_loss(current_recon)],
                 [module.compute_sparse_loss(self.a_list[0])],
             ]
             self.loss_dict = dict(
                 zip(["recon_loss", "sparse_loss"], current_loss_list))
             self.loss_dict["total_loss"] = [
                 tf.add_n([item[0] for item in current_loss_list],
                          name="total_loss")
             ]
             # The input's per-sample variance is the same at every step, so
             # build it once instead of once per unrolled step. Reduce over
             # every non-batch axis so inputs of rank > 2 are handled; for 2-D
             # inputs this equals axes=[1]. Unknown rank falls back to [1].
             input_node = self.model.input_placeholder
             input_rank = input_node.get_shape().ndims
             reduc_dim = [1] if input_rank is None else list(range(1, input_rank))
             pixel_var = tf.nn.moments(x=input_node, axes=reduc_dim)[1]
             squared_pixel_var = tf.square(pixel_var)
             for step in range(num_inference_steps - 1):
                 u, ga = module.step_inference(
                     self.u_list[step], self.a_list[step], self.lca_b,
                     self.lca_g, step)
                 self.u_list.append(u)
                 self.ga_list.append(ga)
                 self.a_list.append(module.threshold_units(u))
                 current_recon = self.model.compute_recon_from_encoding(
                     self.a_list[-1])
                 current_loss_list = [
                     module.compute_recon_loss(current_recon),
                     module.compute_sparse_loss(self.a_list[-1]),
                 ]
                 self.loss_dict["recon_loss"].append(current_loss_list[0])
                 self.loss_dict["sparse_loss"].append(current_loss_list[1])
                 self.loss_dict["total_loss"].append(
                     tf.add_n(current_loss_list, name="total_loss"))
                 MSE = tf.reduce_mean(input_tensor=tf.square(
                     tf.subtract(input_node, current_recon)))
                 # 10 * ef.safe_log(var**2 / MSE); NOTE(review): log base of
                 # ef.safe_log is not visible here.
                 current_pSNRdB = tf.multiply(
                     10.0,
                     ef.safe_log(tf.math.divide(squared_pixel_var, MSE)))
                 self.psnr_list.append(current_pSNRdB)
Exemplo n.º 5
0
 def add_inference_ops_to_graph(self, num_imgs, num_inference_steps=None):
     """Add unrolled inference ops to the model graph for step-by-step analysis.

     Unrolls `num_inference_steps` states of the module's inference (the
     initial state plus `num_inference_steps - 1` calls to `step_inference`).
     For each state it records:
       - the internal state in `self.u_list`;
       - the thresholded activity in `self.a_list`;
       - the recon/sparse/total losses in `self.loss_dict`;
       - a pSNR estimate in `self.psnr_list`. The entry for the initial state
         is the constant 0.0.

     Args:
       num_imgs: unused in this method; kept for interface compatibility.
       num_inference_steps: number of states to unroll. Defaults to
         `self.model_params.num_steps`.
     """
     if num_inference_steps is None:
         num_inference_steps = self.model_params.num_steps
     with tf.device(self.model_params.device):
         with self.model.graph.as_default():
             module = self.model.module
             self.u_list = [module.u_zeros]
             self.a_list = [module.threshold_units(self.u_list[0])]
             # No pSNR is computed for the initial state; 0.0 stands in for it.
             self.psnr_list = [tf.constant(0.0)]
             current_recon = self.model.compute_recon_from_encoding(
                 self.a_list[0])
             current_loss_list = [
                 [module.compute_recon_loss(current_recon)],
                 [module.compute_sparse_loss(self.a_list[0])],
             ]
             self.loss_dict = dict(
                 zip(["recon_loss", "sparse_loss"], current_loss_list))
             self.loss_dict["total_loss"] = [
                 tf.add_n([item[0] for item in current_loss_list],
                          name="total_loss")
             ]
             # The reduction axes and the input's per-sample variance are the
             # same at every step, so build them once rather than per step.
             # Unknown static rank falls back to axis 1 instead of raising.
             input_node = self.model.input_placeholder
             input_rank = input_node.get_shape().ndims
             reduc_dim = [1] if input_rank is None else list(range(1, input_rank))
             pixel_var = tf.nn.moments(x=input_node, axes=reduc_dim)[1]
             squared_pixel_var = tf.square(pixel_var)
             for step in range(num_inference_steps - 1):
                 u = module.step_inference(
                     self.u_list[step], self.a_list[step], step)
                 self.u_list.append(u)
                 self.a_list.append(module.threshold_units(u))
                 current_recon = self.model.compute_recon_from_encoding(
                     self.a_list[-1])
                 current_loss_list = [
                     module.compute_recon_loss(current_recon),
                     module.compute_sparse_loss(self.a_list[-1]),
                 ]
                 self.loss_dict["recon_loss"].append(current_loss_list[0])
                 self.loss_dict["sparse_loss"].append(current_loss_list[1])
                 self.loss_dict["total_loss"].append(
                     tf.add_n(current_loss_list, name="total_loss"))
                 MSE = tf.reduce_mean(input_tensor=tf.square(
                     tf.subtract(input_node, current_recon)))
                 # 10 * ef.safe_log(var**2 / MSE); NOTE(review): log base of
                 # ef.safe_log is not visible here.
                 pSNRdB = tf.multiply(
                     10.0,
                     ef.safe_log(tf.math.divide(squared_pixel_var, MSE)))
                 self.psnr_list.append(pSNRdB)
Exemplo n.º 6
0
    def build_graph_from_input(self, input_node):
        """Build the TensorFlow graph object.

        Creates the `sparse_mult` placeholder and builds the module on
        `input_node`, exposing its encodings as `self.a` ("activity").
        Also adds:
          - a `latent_input` placeholder with the same static shape as `self.a`;
          - a decoder path from that placeholder (`decoder_recon`);
          - the reconstruction of `self.a`;
          - a weight-normalization group op;
          - a reconstruction-quality metric (`pSNRdB`).

        Args:
            input_node: tensor fed to `self.build_module`. NOTE(review): the
                metric uses axis 1 as the per-sample axis, which presumes a
                2-D (batch, features) input — confirm.
        """
        with tf.device(self.params.device):
            with self.graph.as_default():
                with tf.compat.v1.variable_scope("auto_placeholders") as scope:
                    self.sparse_mult = tf.compat.v1.placeholder(
                        tf.float32, shape=(), name="sparse_mult")

                self.module = self.build_module(input_node)
                self.trainable_variables.update(
                    self.module.trainable_variables)

                with tf.compat.v1.variable_scope("inference") as scope:
                    self.a = tf.identity(self.get_encodings(), name="activity")

                # NOTE: `sess` is bound to a variable scope here, not a session.
                with tf.compat.v1.variable_scope("placeholders") as sess:
                    # Lets callers feed arbitrary latent codes to the decoder.
                    self.latent_input = tf.compat.v1.placeholder(
                        tf.float32,
                        shape=self.a.get_shape().as_list(),
                        name="latent_input")

                with tf.compat.v1.variable_scope("norm_weights") as scope:
                    self.norm_weights = tf.group(self.module.norm_w,
                                                 name="l2_normalization")

                with tf.compat.v1.variable_scope("output") as scope:
                    # Decoder applied to the fed latent_input placeholder.
                    self.decoder_recon = self.module.build_decoder(
                        self.latent_input, name="latent_recon")
                    # Reconstruction computed from the inferred activity self.a.
                    self.reconstruction = tf.identity(
                        self.compute_recon_from_encoding(self.a),
                        name="reconstruction")

                with tf.compat.v1.variable_scope(
                        "performance_metrics") as scope:
                    # MSE over all elements against self.module.reconstruction.
                    # NOTE(review): self.reconstruction (above) is not used here
                    # — confirm the two tensors are meant to agree.
                    MSE = tf.reduce_mean(input_tensor=tf.square(
                        tf.subtract(input_node, self.module.reconstruction)),
                                         name="mean_squared_error")
                    pixel_var = tf.nn.moments(x=input_node, axes=[1])[1]
                    # 10 * ef.safe_log(var**2 / MSE), one value per sample.
                    # NOTE(review): log base of ef.safe_log is not visible here.
                    self.pSNRdB = tf.multiply(
                        10.0,
                        ef.safe_log(tf.math.divide(tf.square(pixel_var), MSE)),
                        name="recon_quality")