Example #1
  def make_report(self, data):
    """Builds an AttrDict of report tensors: the loss plus model metrics."""
    res = self.connect(data)
    exprs = AttrDict(loss=self._loss(data, res))
    exprs.update(self._report(data, res))

    # Make sure every reported value is a tf.Tensor so it can be fetched.
    for k, v in exprs.items():
      if not isinstance(v, tf.Tensor):
        exprs[k] = tf.convert_to_tensor(v)

    return exprs
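
A minimal usage sketch; `model`, `sess` and `batch` are assumptions, not part of the example:

# Hypothetical usage: fetch every report tensor in a single session run.
report = model.make_report(batch)     # AttrDict of tf.Tensors
values = sess.run(dict(report))       # plain dict of numpy values
for name, value in values.items():
  print('{}: {}'.format(name, value))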
Example #2
def collect_results(sess, tensors, n_batches):
    """Collects `n_batches` of tensors and aggregates the results."""

    res = AttrDict({k: [] for k in tensors})

    print('Collecting: 0/{}'.format(n_batches), end='')
    for i in range(n_batches):
        print('\rCollecting: {}/{}'.format(i + 1, n_batches), end='')

        vals = sess.run(tensors)
        for k, v in vals.items():
            res[k].append(v)

    print('')
    for k, v in res.items():
        if v[0].shape:
            # Batched arrays: concatenate along the leading (batch) axis.
            res[k] = np.concatenate(v, 0)
        else:
            # Scalars: stack into a 1-d array.
            res[k] = np.stack(v)

    return res
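
A hedged usage sketch; `loss_t` and `acc_t` are hypothetical metric tensors:

tensors = {'loss': loss_t, 'acc': acc_t}
metrics = collect_results(sess, tensors, n_batches=100)
print('mean loss: {:.4f}'.format(metrics.loss.mean()))
print('mean acc:  {:.4f}'.format(metrics.acc.mean()))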
Example #3
def collect_results(sess, tensors, n_batches):
    """Collects `n_batches` of tensors and aggregates the results."""

    res = AttrDict({k: [] for k in tensors})

    print('Collecting: 0/{}'.format(n_batches), end='')

    # Request a full execution trace so per-op timings can be inspected
    # later (TensorFlow 1.x profiling API).
    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    run_metadata = tf.RunMetadata()

    for i in range(n_batches):
        print('\rCollecting: {}/{}'.format(i + 1, n_batches), end='')

        if i == 10:
            # Trace a single mid-run batch (after warm-up) and dump it in
            # Chrome trace format; this needs
            # `from tensorflow.python.client import timeline`.
            print('\nTracing batch {}'.format(i + 1))
            vals = sess.run(tensors, options=run_options, run_metadata=run_metadata)
            tl = timeline.Timeline(run_metadata.step_stats)
            ctf = tl.generate_chrome_trace_format()
            with open('timeline_mnist_8.json', 'w') as f:
                f.write(ctf)
        else:
            vals = sess.run(tensors)

        for k, v in vals.items():
            res[k].append(v)

    print('')
    for k, v in res.items():
        if v[0].shape:
            res[k] = np.concatenate(v, 0)
        else:
            res[k] = np.stack(v)

    return res
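
The dumped file is in Chrome trace format and can be opened at chrome://tracing. A hypothetical quick inspection in Python:

import collections
import json

with open('timeline_mnist_8.json') as f:
    trace = json.load(f)

# Count trace events per op name to spot the most frequent ops.
events = collections.Counter(
    e.get('name', '?') for e in trace.get('traceEvents', []))
print(events.most_common(10))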
Example #4
def main(_=None):
  FLAGS = flags.FLAGS  # pylint: disable=invalid-name,redefined-outer-name
  config = FLAGS
  FLAGS.__dict__['config'] = config

  # Build the graph
  with tf.Graph().as_default():

    model_dict = model_config.get(FLAGS)
    data_dict = data_config.get(FLAGS)

    model = model_dict.model
    trainset = data_dict.trainset
    validset = data_dict.validset

    # Optimisation target
    validset = tools.maybe_convert_dataset(validset)
    trainset = tools.maybe_convert_dataset(trainset)

    train_tensors = model(trainset)
    valid_tensors = model(validset)

    sess = tf.Session()
    saver = tf.train.Saver()
    saver.restore(sess, FLAGS.snapshot)

  valid_results = _collect_results(sess, valid_tensors, validset,
                                   10000 // FLAGS.batch_size)

  train_results = _collect_results(sess, train_tensors, trainset,
                                   60000 // FLAGS.batch_size)

  results = AttrDict(train=train_results, valid=valid_results)

  # Linear classification
  print('Linear classification accuracy:')
  for k, v in results.items():
    print('\t{}: prior={:.04f}, posterior={:.04f}'.format(
        k, v.prior_acc.mean(), v.posterior_acc.mean()))

  # Unsupervised classification via clustering
  print('Bipartite matching classification accuracy:')
  for field in 'posterior_pres prior_pres'.split():
    kmeans = sklearn.cluster.KMeans(
        n_clusters=10,
        precompute_distances=True,
        n_jobs=-1,
        max_iter=1000,
    ).fit(results.train[field])

    train_acc = cluster_classify(results.train[field], results.train.label, 10,
                                 kmeans)
    valid_acc = cluster_classify(results.valid[field], results.valid.label, 10,
                                 kmeans)

    print('\t{}: train_acc={:.04f}, valid_acc={:.04f}'.format(field, train_acc,
                                                              valid_acc))

  checkpoint_folder = osp.dirname(FLAGS.snapshot)
  figure_filename = osp.join(checkpoint_folder, FLAGS.tsne_figure_name)
  print('Saving TSNE plot at "{}"'.format(figure_filename))
  make_tsne_plot(valid_results.posterior_pres, valid_results.label,
                 figure_filename)
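
The `cluster_classify` helper is not shown in this example; the sketch below assumes it maps k-means clusters to labels by bipartite matching, a common way to score unsupervised classification (an illustration, not the repository's exact code):

import numpy as np
import scipy.optimize

def cluster_classify(features, labels, n_classes, kmeans):
  # Hypothetical helper: assign each cluster to one label so that the
  # number of correctly matched samples is maximised, then score accuracy.
  cluster_ids = kmeans.predict(features)
  counts = np.zeros((n_classes, n_classes), dtype=np.int64)
  for c, l in zip(cluster_ids, labels):
    counts[c, l] += 1
  rows, cols = scipy.optimize.linear_sum_assignment(-counts)
  mapping = dict(zip(rows, cols))
  predictions = np.array([mapping[c] for c in cluster_ids])
  return (predictions == labels).mean()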
Example #5
  def _build(self, h, x, presence=None):
    """Builds the module.

    Args:
      h: Tensor of encodings of shape [B, n_enc_dims].
      x: Tensor of inputs of shape [B, n_points, n_input_dims].
      presence: Tensor of shape [B, n_points, 1] or None; if it exists, it
        indicates which input points exist.

    Returns:
      An AttrDict with votes, scales and presence probabilities, plus
      auxiliary terms (mixing KL, sparsity loss, capsule presence probs).
    """
    batch_size, n_input_points, _ = x.shape.as_list()
    res = AttrDict(
        dynamic_weights_l2=tf.constant(0.)
    )

    output_shapes = (
        [1],  # per-capsule presence
        [self._n_votes],  # per-vote-presence
        [self._n_votes],  # per-vote scale
        [self._n_votes, self._n_caps_dims]
    )

    splits = [np.prod(i).astype(np.int32) for i in output_shapes]
    n_outputs = sum(splits)
    batch_mlp = neural.BatchMLP([self._n_hiddens, self._n_hiddens, n_outputs],
                                use_bias=True)

    all_params = batch_mlp(h)
    all_params = tf.split(all_params, splits, -1)
    batch_shape = [batch_size, self._n_caps]
    all_params = [tf.reshape(i, batch_shape + s)
                  for (i, s) in zip(all_params, output_shapes)]

    def add_noise(tensor):
      # Perturb the presence logits with uniform noise in [-2, 2].
      return tensor + tf.random.uniform(tensor.shape, minval=-.5, maxval=.5) * 4.

    res.pres_logit_per_caps = add_noise(all_params[0])
    res.pres_logit_per_vote = add_noise(all_params[1])
    res.scale = tf.nn.softplus(all_params[2] + .5) + 1e-6
    res.vote_presence = (tf.nn.sigmoid(res.pres_logit_per_caps)
                         * tf.nn.sigmoid(res.pres_logit_per_vote))
    res.vote = all_params[3]

    for k, v in res.items():
      if v.shape.ndims > 0:
        res[k] = snt.MergeDims(1, 2)(v)

    likelihood = _capsule.OrderInvariantCapsuleLikelihood(self._n_votes,
                                                          res.vote, res.scale,
                                                          res.vote_presence)
    ll_res = likelihood(x, presence)
    res.update(ll_res._asdict())

    # post processing
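    # KL between the per-point posterior over capsules and a uniform prior
    # over the input points.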
    mixing_probs = tf.nn.softmax(ll_res.mixing_logits, 1)
    prior_mixing_log_prob = tf.log(1. / n_input_points)
    mixing_kl = mixing_probs * (ll_res.mixing_log_prob - prior_mixing_log_prob)
    mixing_kl = tf.reduce_mean(tf.reduce_sum(mixing_kl, -1))

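    # For every input point, mark the capsule that explains it, then count
    # how many points each capsule wins.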
    wins_per_caps = tf.one_hot(ll_res.is_from_capsule, depth=self._n_caps)

    if presence is not None:
      wins_per_caps *= tf.expand_dims(presence, -1)

    wins_per_caps = tf.reduce_sum(wins_per_caps, 1)

    has_any_wins = tf.to_float(tf.greater(wins_per_caps, 0))
    should_be_active = tf.to_float(tf.greater(wins_per_caps, 1))

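    # Capsules that won at least one point are pushed to be present only if
    # they explain more than one point; unused capsules are not penalised.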
    sparsity_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=should_be_active, logits=res.pres_logit_per_caps)

    sparsity_loss = tf.reduce_sum(sparsity_loss * has_any_wins, -1)
    sparsity_loss = tf.reduce_mean(sparsity_loss)

    caps_presence_prob = tf.reduce_max(
        tf.reshape(res.vote_presence,
                   [batch_size, self._n_caps, self._n_votes]), 2)

    res.update(dict(
        mixing_kl=mixing_kl,
        sparsity_loss=sparsity_loss,
        caps_presence_prob=caps_presence_prob,
        mean_scale=tf.reduce_mean(res.scale)
    ))
    return res
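
A hedged sketch of how the returned quantities might feed a training loss; the weights, the `capsule_layer` instance, and the assumption that the likelihood result exposes a `log_prob` field are all illustrative:

res = capsule_layer(h, x, presence)   # hypothetical module instance
# Maximise data log-likelihood while penalising the auxiliary terms;
# the 0.1 and 0.5 weights are placeholders, not tuned values.
loss = -res.log_prob + 0.1 * res.mixing_kl + 0.5 * res.sparsity_loss
train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)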