Example #1
  def __init__(self,
               update_ops=None,
               collections=_tools.GraphKeys.CUSTOM_UPDATE_OPS,
               save_secs=None,
               save_steps=None):
    """Builds the object.

    Args:
      update_ops: an op or a list of ops.
      collections: a TensorFlow collection key or a list of collection keys.
      save_secs: number of seconds to wait until saving.
      save_steps: number of steps to wait until saving.
    """

    logging.info('Create UpdateOpsHook.')

    update_ops = nest.flatten(update_ops) if update_ops else []
    if collections:
      for collection in nest.flatten(collections):
        update_ops.extend(tf.get_collection(collection))

    self._update_op = tf.group(update_ops)
    self._timer = tf.train.SecondOrStepTimer(every_secs=save_secs,
                                             every_steps=save_steps)

    self._steps_per_run = 1
    self._last_run = -1
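
A minimal usage sketch (not part of the original example) of why the constructor flattens update_ops: nest.flatten turns a lone op into a one-element list, so a single op and a list of ops are handled uniformly. The import paths assume TF 1.x-style APIs.

import tensorflow.compat.v1 as tf
from tensorflow.python.util import nest

single_op = tf.no_op()
print(nest.flatten(single_op))                # one-element list containing the op
print(nest.flatten([single_op, tf.no_op()]))  # a list stays a flat list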
Example #2
  def _build(self, data):

    x = data[self._input_key]
    presence = data[self._presence_key] if self._presence_key else None

    inputs = nest.flatten(x)
    if presence is not None:
      inputs.append(presence)

    h = self._encoder(*inputs)
    res = self._decoder(h, *inputs)

    n_points = int(res.posterior_mixing_probs.shape[1])
    mass_explained_by_capsule = tf.reduce_sum(res.posterior_mixing_probs, 1)

    (res.posterior_within_sparsity_loss,
     res.posterior_between_sparsity_loss) = _capsule.sparsity_loss(
         self._posterior_sparsity_loss_type,
         mass_explained_by_capsule / n_points,
         num_classes=self._n_classes)

    (res.prior_within_sparsity_loss,
     res.prior_between_sparsity_loss) = _capsule.sparsity_loss(
         self._prior_sparsity_loss_type,
         res.caps_presence_prob,
         num_classes=self._n_classes,
         within_example_constant=self._prior_within_example_constant)

    return res
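
A hedged sketch of the flatten-then-splat pattern above: `x` may be a single tensor or a nested structure, and nest.flatten always yields a flat list that can be splatted into positional arguments. The encoder below is a hypothetical stand-in for self._encoder, not the original module.

import tensorflow.compat.v1 as tf
from tensorflow.python.util import nest

def encoder(*tensors):  # hypothetical stand-in for self._encoder
  return tf.concat([tf.reshape(t, [2, -1]) for t in tensors], axis=-1)

x = (tf.zeros([2, 3]), tf.zeros([2, 4]))  # x could also be a single tensor
inputs = nest.flatten(x)                  # always a flat list of tensors
h = encoder(*inputs)                      # shape [2, 7]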
Example #3
  def __init__(self, tfrecord_files, img_shape, labeled=False):
    super(Dataset, self).__init__()
    self._tfrecords = nest.flatten(tfrecord_files)
    self._img_shape = img_shape
    tf.logging.info(tfrecord_files)
    tf.logging.info(labeled)
    if labeled:
      self._feature_description['labeled'] = tf.FixedLenFeature([], tf.int64)
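
For reference (my addition, not from the source): nest.flatten treats a string as a leaf rather than a sequence, so this constructor accepts either a single TFRecord path or a list of paths.

from tensorflow.python.util import nest

print(nest.flatten('train.tfrecord'))              # ['train.tfrecord']
print(nest.flatten(['a.tfrecord', 'b.tfrecord']))  # ['a.tfrecord', 'b.tfrecord']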
Example #4
def ensure_length(x, length):
    """Enusres that the input is an array of a given length.
  Args:
    x: tensor or a list of tensors.
    length: int.
  Returns:
    list of tensors of a given length.
  """
    x = nest.flatten(x)
    if len(x) == 1:
        x *= length

    return x
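
A short usage sketch. Note that the function only replicates when the flattened input has exactly one element; a longer list is returned as-is, even if its length differs from `length`.

print(ensure_length('conv', 3))   # ['conv', 'conv', 'conv']
print(ensure_length([1, 2], 5))   # [1, 2] -- not checked against length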
Example #5
    def __init__(self,
                 n_caps,
                 n_caps_dims,
                 n_votes,
                 n_caps_params=None,
                 n_hiddens=128,
                 learn_vote_scale=False,
                 deformations=True,
                 noise_type=None,
                 noise_scale=0.,
                 similarity_transform=True,
                 caps_dropout_rate=0.0):
        """Builds the module.

    Args:
      n_caps: int, number of capsules.
      n_caps_dims: int, number of capsule coordinates.
      n_votes: int, number of votes generated by each capsule.
      n_caps_params: int, number of capsule parameters or None. If it is None,
        then the module uses encoder features directly.
      n_hiddens: int or sequence of ints, number of hidden units for an MLP
        which predicts capsule params from the input encoding.
      learn_vote_scale: bool, learns input-dependent scale for each
        capsules' votes.
      deformations: bool, allows input-dependent deformations of capsule-part
        relationships.
      noise_type: 'normal', 'logistic' or None; noise type injected into
        presence logits.
      noise_scale: float >= 0. scale parameters for the noise.
      similarity_transform: boolean; uses similarity transforms if True.
      caps_dropout_rate: float in [0, 1].
    """
        super(CapsuleLayer, self).__init__()
        self._n_caps = n_caps
        self._n_caps_dims = n_caps_dims
        self._n_caps_params = n_caps_params

        self._n_votes = n_votes
        self._n_hiddens = nest.flatten(n_hiddens)
        self._learn_vote_scale = learn_vote_scale
        self._deformations = deformations
        self._noise_type = noise_type
        self._noise_scale = noise_scale

        self._similarity_transform = similarity_transform
        self._caps_dropout_rate = caps_dropout_rate

        assert n_caps_dims == 2, (
            'This is the only value implemented now due to '
            'the restriction of similarity transform.')
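
The nest.flatten(n_hiddens) call is what makes the "int or sequence of ints" contract in the docstring work; a sketch:

from tensorflow.python.util import nest

print(nest.flatten(128))        # [128] -- a single hidden-layer width
print(nest.flatten([128, 64]))  # [128, 64] -- two hidden layers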
Example #6
  def __init__(self, n_hiddens,
               activation=tf.nn.relu,
               activate_final=False,
               initializers=None,
               use_bias=True,
               tile_dims=(0,)):

    super(BatchMLP, self).__init__()
    self._n_hiddens = nest.flatten(n_hiddens)
    self._activation = activation
    self._activate_final = activate_final
    self._initializers = initializers
    self._use_bias = use_bias
    self._tile_dims = tile_dims
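
The module's _build is not shown; below is a purely illustrative NumPy sketch (my assumption, not the original implementation) of how _n_hiddens, _activation and _activate_final typically interact in such an MLP.

import numpy as np

def mlp_sketch(h, n_hiddens, activation=np.tanh, activate_final=False):
  # One linear layer per entry of n_hiddens; activate every layer except
  # the last one, unless activate_final is set.
  for i, n in enumerate(n_hiddens):
    h = h @ (np.random.randn(h.shape[-1], n) * 0.01)
    if i < len(n_hiddens) - 1 or activate_final:
      h = activation(h)
  return h

print(mlp_sketch(np.ones([4, 8]), [16, 2]).shape)  # (4, 2)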
Example #7
def create(which,
           batch_size,
           subset=None,
           n_replicas=1,
           transforms=None,
           **kwargs):
    """Creates data loaders according to the dataset name `which`."""

    func = globals().get('_create_{}'.format(which), None)
    if func is None:
        raise ValueError('Dataset "{}" not supported. Only {} are'
                         ' supported.'.format(which, SUPPORTED_DATSETS))

    dataset = func(subset, batch_size, **kwargs)

    if transforms is not None:
        if not isinstance(transforms, dict):
            transforms = {'image': transforms}

        for k, v in transforms.items():
            transforms[k] = snt.Sequential(nest.flatten(v))

    if transforms is not None or n_replicas > 1:

        def map_func(data):
            """Replicates data if necessary."""
            data = dict(data)

            if n_replicas > 1:
                tile_by_batch = snt.TileByDim([0], [n_replicas])
                data = {k: tile_by_batch(v) for k, v in data.items()}

            if transforms is not None:
                img = data['image']

                for k, transform in transforms.items():
                    data[k] = transform(img)

            return data

        dataset = dataset.map(map_func)

    iter_data = dataset.make_one_shot_iterator()
    input_batch = iter_data.get_next()
    for _, v in input_batch.items():
        v.set_shape([batch_size * n_replicas] + v.shape[1:].as_list())

    return input_batch
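
A small standalone sketch (without Sonnet) of the transforms normalization above: a single callable is treated as an image transform, and nest.flatten lets each value be either one transform or a list of transforms to chain.

from tensorflow.python.util import nest

transforms = lambda img: img / 255.       # a single callable...
if not isinstance(transforms, dict):
  transforms = {'image': transforms}      # ...is keyed under 'image'
print(nest.flatten(transforms['image']))  # [<lambda>] -- ready to chain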
Example #8
  def __init__(self, mixing_logits, component_stats, component_class,
               presence=None):
    """Builds the module.
    Args:
      mixing_logits: tensor [B, k, ...] with k the number of components.
      component_stats: list of tensors of shape [B, k, ...] or broadcastable
        to these shapes; they are argument to the chosen distribution class.
      component_class: callable; returns a distribution object.
      presence: [B, k] tensor of floats in [0, 1] or None.
    """
    super(MixtureDistribution, self).__init__()
    if presence is not None:
      mixing_logits += make_brodcastable(safe_log(presence), mixing_logits)

    self._mixing_logits = mixing_logits

    component_stats = nest.flatten(component_stats)
    self._distributions = component_class(*component_stats)
    self._presence = presence
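
A hedged usage sketch of the flatten-then-splat idiom at the end: component_stats may be a (nested) structure of parameters that is flattened into the positional arguments of the distribution class. Using TensorFlow Probability here is my assumption about the intended component_class.

import tensorflow_probability as tfp
from tensorflow.python.util import nest

component_stats = (0., 1.)  # (loc, scale) for a Normal component
dist = tfp.distributions.Normal(*nest.flatten(component_stats))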