def __init__(self, update_ops=None,
             collections=_tools.GraphKeys.CUSTOM_UPDATE_OPS,
             save_secs=None, save_steps=None):
    """Builds the object.

    Gathers the explicitly passed ops together with every op found in the
    given graph collections and groups them into a single update op.

    Args:
      update_ops: an op or a (possibly nested) list of ops; None means that
        no explicit ops are added.
      collections: tensorflow collection key or a list of collections keys;
        a falsy value skips the collection lookup.
      save_secs: number of seconds to wait until saving; forwarded to
        `tf.train.SecondOrStepTimer` as `every_secs`.
      save_steps: number of steps to wait until saving; forwarded to
        `tf.train.SecondOrStepTimer` as `every_steps`.
    """
    logging.info('Create UpdateOpsHook.')

    # Test against None instead of relying on truthiness: truth-testing a
    # tf.Tensor raises in graph mode, so passing a single tensor-valued op
    # would otherwise crash before it is ever flattened.
    update_ops = [] if update_ops is None else nest.flatten(update_ops)

    if collections:
        for collection in nest.flatten(collections):
            update_ops.extend(tf.get_collection(collection))

    self._update_op = tf.group(update_ops)
    self._timer = tf.train.SecondOrStepTimer(
        every_secs=save_secs, every_steps=save_steps)
    self._steps_per_run = 1
    self._last_run = -1
def _build(self, data):
    """Encodes and decodes the inputs, then attaches sparsity losses.

    Args:
      data: mapping indexed by `self._input_key` and, when
        `self._presence_key` is truthy, by `self._presence_key`.

    Returns:
      The decoder output, extended with the attributes
      `posterior_within_sparsity_loss`, `posterior_between_sparsity_loss`,
      `prior_within_sparsity_loss` and `prior_between_sparsity_loss`.
    """
    inputs = nest.flatten(data[self._input_key])

    presence = None
    if self._presence_key:
        presence = data[self._presence_key]
    if presence is not None:
        # The presence tensor rides along as the last positional input.
        inputs.append(presence)

    encoding = self._encoder(*inputs)
    res = self._decoder(encoding, *inputs)

    # Axis 1 of the mixing probabilities is both the size used for
    # normalisation and the axis that gets summed out.
    n_points = int(res.posterior_mixing_probs.shape[1])
    mass_explained_by_capsule = tf.reduce_sum(res.posterior_mixing_probs, 1)

    posterior_losses = _capsule.sparsity_loss(
        self._posterior_sparsity_loss_type,
        mass_explained_by_capsule / n_points,
        num_classes=self._n_classes)
    (res.posterior_within_sparsity_loss,
     res.posterior_between_sparsity_loss) = posterior_losses

    prior_losses = _capsule.sparsity_loss(
        self._prior_sparsity_loss_type,
        res.caps_presence_prob,
        num_classes=self._n_classes,
        within_example_constant=self._prior_within_example_constant)
    (res.prior_within_sparsity_loss,
     res.prior_between_sparsity_loss) = prior_losses

    return res
def __init__(self, tfrecord_files, img_shape, labeled=False):
    """Builds the dataset object.

    Args:
      tfrecord_files: a single entry or a nested structure of entries
        (presumably TFRecord file paths -- confirm against callers); stored
        flattened as a list.
      img_shape: stored unchanged as `self._img_shape`.
      labeled: if truthy, a scalar int64 feature named 'labeled' is added to
        `self._feature_description`.
    """
    super(Dataset, self).__init__()
    self._tfrecords = nest.flatten(tfrecord_files)
    self._img_shape = img_shape

    # Log the raw arguments exactly as they were received.
    tf.logging.info(tfrecord_files)
    tf.logging.info(labeled)

    if not labeled:
        return

    # NOTE(review): `_feature_description` is not created here; if it is a
    # class-level dict, this write is shared by every instance -- confirm.
    self._feature_description['labeled'] = tf.FixedLenFeature([], tf.int64)
def ensure_length(x, length):
    """Ensures that the input is a list of a given length.

    The input is flattened first. A single element is repeated `length`
    times; a flattened input with any other number of elements is returned
    as is, without checking it against `length`.

    Args:
      x: tensor or a list of tensors.
      length: int.

    Returns:
      list of tensors of a given length.
    """
    flat = nest.flatten(x)
    if len(flat) != 1:
        return flat
    return flat * length
def __init__(self, n_caps, n_caps_dims, n_votes, n_caps_params=None,
             n_hiddens=128, learn_vote_scale=False, deformations=True,
             noise_type=None, noise_scale=0., similarity_transform=True,
             caps_dropout_rate=0.0):
    """Builds the module.

    Args:
      n_caps: int, number of capsules.
      n_caps_dims: int, number of capsule coordinates; must be 2.
      n_votes: int, number of votes generated by each capsule.
      n_caps_params: int, number of capsule parameters or None. If it is
        None, then the module uses encoder features directly.
      n_hiddens: int or sequence of ints, number of hidden units for an MLP
        which predicts capsule params from the input encoding; stored as a
        flat list.
      learn_vote_scale: bool, learns input-dependent scale for each capsules'
        votes.
      deformations: bool, allows input-dependent deformations of
        capsule-part relationships.
      noise_type: 'normal', 'logistic' or None; noise type injected into
        presence logits.
      noise_scale: float >= 0. scale parameters for the noise.
      similarity_transform: boolean; uses similarity transforms if True.
      caps_dropout_rate: float in [0, 1].

    Raises:
      AssertionError: if `n_caps_dims` is not 2. The check is a plain
        `assert`, so it does not run under `python -O`.
    """
    super(CapsuleLayer, self).__init__()

    # Capsule geometry.
    self._n_caps = n_caps
    self._n_caps_dims = n_caps_dims
    self._n_votes = n_votes
    self._n_caps_params = n_caps_params

    # Parameter-predicting MLP; a bare int becomes a one-element list.
    self._n_hiddens = nest.flatten(n_hiddens)

    # Vote behaviour.
    self._learn_vote_scale = learn_vote_scale
    self._deformations = deformations
    self._similarity_transform = similarity_transform

    # Noise and dropout settings.
    self._noise_type = noise_type
    self._noise_scale = noise_scale
    self._caps_dropout_rate = caps_dropout_rate

    assert n_caps_dims == 2, (
        'This is the only value implemented now due to '
        'the restriction of similarity transform.')
def __init__(self, n_hiddens, activation=tf.nn.relu, activate_final=False,
             initializers=None, use_bias=True, tile_dims=(0,)):
    """Builds the module by storing its configuration.

    Args:
      n_hiddens: a single value or a nested structure of values; stored as a
        flat list (presumably hidden-layer sizes -- confirm against the
        build method).
      activation: stored unchanged; defaults to `tf.nn.relu`.
      activate_final: stored unchanged; defaults to False.
      initializers: stored unchanged; defaults to None.
      use_bias: stored unchanged; defaults to True.
      tile_dims: stored unchanged; defaults to `(0,)`.
    """
    super(BatchMLP, self).__init__()

    # A bare int becomes a one-element list.
    self._n_hiddens = nest.flatten(n_hiddens)

    self._activation = activation
    self._activate_final = activate_final

    self._use_bias = use_bias
    self._initializers = initializers
    self._tile_dims = tile_dims
def create(which, batch_size, subset=None, n_replicas=1, transforms=None,
           **kwargs):
    """Creates data loaders according to the dataset name `which`.

    Args:
      which: dataset name; resolved to a module-level factory function named
        `_create_<which>`.
      batch_size: passed to the factory and used, multiplied by `n_replicas`,
        as the static leading dimension of every returned tensor.
      subset: passed to the factory as its first positional argument.
      n_replicas: int; if greater than 1, every entry of a batch is tiled
        with `snt.TileByDim([0], [n_replicas])`.
      transforms: None, a dict, or anything else. A non-dict value is treated
        as `{'image': transforms}`. Each dict value is flattened and wrapped
        in `snt.Sequential`; the result of applying it to `data['image']` is
        stored under the corresponding key. The dict passed by the caller is
        not modified.
      **kwargs: forwarded to the factory.

    Returns:
      The structure returned by the one-shot iterator's `get_next()`, with
      the leading dimension of every tensor set to `batch_size * n_replicas`.

    Raises:
      ValueError: if there is no factory function for `which`.
    """
    func = globals().get('_create_{}'.format(which))
    if func is None:
        raise ValueError('Dataset "{}" not supported. Only {} are'
                         ' supported.'.format(which, SUPPORTED_DATSETS))

    dataset = func(subset, batch_size, **kwargs)

    if transforms is not None:
        if not isinstance(transforms, dict):
            transforms = {'image': transforms}
        # Build a fresh dict rather than writing the wrapped modules back
        # into the argument: the caller's dict stays untouched, so reusing
        # it for another call does not wrap already-wrapped transforms.
        transforms = {
            key: snt.Sequential(nest.flatten(value))
            for key, value in transforms.items()
        }

    if transforms is not None or n_replicas > 1:

        def map_func(data):
            """Replicates data if necessary and applies the transforms."""
            data = dict(data)
            if n_replicas > 1:
                tile_by_batch = snt.TileByDim([0], [n_replicas])
                data = {k: tile_by_batch(v) for k, v in data.items()}

            if transforms is not None:
                # Every transform reads the same (possibly tiled) image, so
                # the outputs do not depend on the iteration order.
                img = data['image']
                for k, transform in transforms.items():
                    data[k] = transform(img)
            return data

        dataset = dataset.map(map_func)

    iter_data = dataset.make_one_shot_iterator()
    input_batch = iter_data.get_next()
    # Fix the leading dimension statically; all other dims are kept.
    for tensor in input_batch.values():
        tensor.set_shape(
            [batch_size * n_replicas] + tensor.shape[1:].as_list())
    return input_batch
def __init__(self, mixing_logits, component_stats, component_class,
             presence=None):
    """Builds the module.

    Args:
      mixing_logits: tensor [B, k, ...] with k the number of components.
      component_stats: list of tensors of shape [B, k, ...] or broadcastable
        to these shapes; they are argument to the chosen distribution class.
      component_class: callable; returns a distribution object.
      presence: [B, k] tensor of floats in [0, 1] or None. When given, its
        safe logarithm is added to `mixing_logits`.
    """
    super(MixtureDistribution, self).__init__()

    if presence is not None:
        # Fold the presence into the logits as an additive log-term.
        log_presence = make_brodcastable(safe_log(presence), mixing_logits)
        mixing_logits += log_presence
    self._mixing_logits = mixing_logits

    # The flattened statistics become positional arguments of the component
    # distribution.
    stats = nest.flatten(component_stats)
    self._distributions = component_class(*stats)
    self._presence = presence