def testObservesWrappedFunction(self):
  """A Module wrapping a plain function reports its connection to observers.

  Calls `base.Module(tf.nn.relu)` under `base.observe_connections` and checks
  that exactly one connected subgraph was recorded, referring to this module,
  the input tensor (as the first positional argument) and the returned output.
  """
  relu_module = base.Module(tf.nn.relu)
  with base.observe_connections(self._connection_observer):
    result = relu_module(self._inputs)

  self.assertEqual(1, len(self._connected_subgraphs))
  subgraph = self._connected_subgraphs[0]
  self.assertIs(relu_module, subgraph.module)
  self.assertIs(self._inputs, subgraph.inputs["args"][0])
  self.assertIs(subgraph.outputs, result)
def testSharing(self):
  """Calling one Module on two equal-valued inputs yields equal outputs.

  Both calls go through the same `base.Module` instance, so matching outputs
  are expected when its parameters are shared between calls. Also checks the
  module's `scope_name`.
  """
  batch_size, in_size = 3, 4
  input_data = np.random.rand(batch_size, in_size)
  # Two distinct constant tensors that hold identical values.
  first_inputs = tf.constant(input_data)
  second_inputs = tf.constant(input_data)

  model = base.Module(
      functools.partial(_make_model_with_params, output_size=10))
  self.assertEqual(model.scope_name, "make_model_with_params")

  first_outputs = model(first_inputs)
  second_outputs = model(second_inputs)

  self.evaluate(tf.global_variables_initializer())
  first_values, second_values = self.evaluate([first_outputs, second_outputs])
  self.assertAllClose(first_values, second_values)
def testSharing(self):
  """Feeding the same data through two connections of one Module matches.

  Builds two float32 placeholders, connects the same `base.Module` to each,
  feeds both with one random array and checks the two outputs agree.
  """
  batch_size, in_size = 3, 4
  input_shape = [batch_size, in_size]
  inputs1 = tf.placeholder(tf.float32, shape=input_shape)
  inputs2 = tf.placeholder(tf.float32, shape=input_shape)

  model = base.Module(build=partial(_make_model_with_params, output_size=10))
  outputs1 = model(inputs1)
  outputs2 = model(inputs2)

  # Random data is drawn only after the graph has been built.
  input_data = np.random.rand(batch_size, in_size)
  feed = {inputs1: input_data, inputs2: input_data}
  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    values1, values2 = sess.run([outputs1, outputs2], feed_dict=feed)
    self.assertAllClose(values1, values2)
def inverse(self, name=None):
  """Returns a `sonnet` module to compute inverse affine transforms.

  The function first assembles a network that, given the constraints of the
  current AffineGridWarper and a set of input parameters, retrieves the
  coefficients of the corresponding inverse affine transform. It then feeds
  that output into a new AffineGridWarper set up to warp the `output` space
  into the `source` space.

  Args:
    name: Name of the module implementing the inverse grid transformation.
      When None, defaults to `self.module_name + '_inverse'`.

  Returns:
    A `sonnet` module performing the inverse affine transform of a reference
    grid of points via an AffineGridWarper module.

  Raises:
    tf.errors.UnimplementedError: If the function is called on a non 2D
      instance of AffineGridWarper.
  """
  if self._num_coeff != 6:
    # TF OpError subclasses are constructed as (node_def, op, message).
    # Passing only the message would raise a TypeError instead of this error.
    raise tf.errors.UnimplementedError(
        None, None,
        'AffineGridWarper currently supports inversion only for the 2D case.')

  def _affine_grid_warper_inverse(inputs):
    """Assembles network to compute inverse affine transformation.

    Each `inputs` row potentially contains [a, b, tx, c, d, ty] corresponding
    to an affine matrix:

      A = [a, b, tx],
          [c, d, ty]

    Coefficients fixed by `self.constraints` are not read from `inputs`.
    Those entries are filled with their constant value, and only the
    unconstrained coefficients are taken from the columns of `inputs`, in
    order.

    We want to generate a tensor containing the coefficients of the
    corresponding inverse affine transformation in a constraints-aware
    fashion. Calling M:

      M = [a, b]
          [c, d]

    the affine matrix for the inverse transform is:

      A_inv = [M^(-1), M^(-1) * [-tx, -ty]^T]

    where

      M^(-1) = (ad - bc)^(-1) * [ d, -b]
                                [-c,  a]

    Args:
      inputs: 2-D tensor with one row of transformation parameters per batch
        element.

    Returns:
      The output of a new `AffineGridWarper(self.output_shape,
      self.source_shape)` applied to the inverse transformation coefficients.
    """
    batch_size = tf.expand_dims(tf.shape(inputs)[0], 0)
    # Shape [batch_size, 1], used to broadcast constrained coefficients.
    constant_shape = tf.concat([batch_size, tf.convert_to_tensor((1,))], 0)

    index = iter(range(6))

    def get_variable(constraint):
      if constraint is None:
        # Built-in `next` works on Python 2 and 3; `.next()` is Python-2-only.
        i = next(index)
        return inputs[:, i:i+1]
      else:
        return tf.fill(constant_shape, tf.constant(constraint,
                                                   dtype=inputs.dtype))

    constraints = chain.from_iterable(self.constraints)
    a, b, tx, c, d, ty = (get_variable(constr) for constr in constraints)

    # NOTE(review): no guard against a singular M (det == 0).
    det = a * d - b * c
    a_inv = d / det
    b_inv = -b / det
    c_inv = -c / det
    d_inv = a / det

    m_inv = basic.BatchReshape(
        [2, 2])(tf.concat([a_inv, b_inv, c_inv, d_inv], 1))

    txy = tf.expand_dims(tf.concat([tx, ty], 1), 2)

    # M^(-1) * [tx, ty]^T; the sign flip is applied below.
    txy_inv = basic.BatchFlatten()(tf.matmul(m_inv, txy))
    tx_inv = txy_inv[:, 0:1]
    ty_inv = txy_inv[:, 1:2]

    inverse_gw_inputs = tf.concat(
        [a_inv, b_inv, -tx_inv, c_inv, d_inv, -ty_inv], 1)

    agw = AffineGridWarper(self.output_shape,
                           self.source_shape)

    return agw(inverse_gw_inputs)  # pylint: disable=not-callable

  if name is None:
    name = self.module_name + '_inverse'
  return base.Module(_affine_grid_warper_inverse, name=name)
def testFunctionType(self):
  """A non-callable `build` argument makes `base.Module` raise TypeError."""
  expected_message = "Input 'build' must be callable."
  with self.assertRaises(TypeError) as context:
    base.Module(build="not_a_function")
  self.assertEqual(str(context.exception), expected_message)