def _build(self, inputs, memory, treat_input_as_matrix=False):
  """Adds relational memory to the TensorFlow graph.

  Args:
    inputs: Tensor input.
    memory: Memory output from the previous time step.
    treat_input_as_matrix: Optional, whether to treat `input` as a sequence
      of matrices. Defaults to False, in which case the input is flattened
      into a vector.

  Returns:
    output: This time step's output.
    next_memory: The next version of memory to use.
  """
  if treat_input_as_matrix:
    inputs = basic.BatchFlatten(preserve_dims=2)(inputs)
    inputs_reshape = basic.BatchApply(
        basic.Linear(self._mem_size), n_dims=2)(inputs)
  else:
    inputs = basic.BatchFlatten()(inputs)
    inputs = basic.Linear(self._mem_size)(inputs)
    inputs_reshape = tf.expand_dims(inputs, 1)

  memory_plus_input = tf.concat([memory, inputs_reshape], axis=1)
  next_memory = self._attend_over_memory(memory_plus_input)

  # Drop the rows that correspond to the appended input, keeping only the
  # updated memory slots.
  n = inputs_reshape.get_shape().as_list()[1]
  next_memory = next_memory[:, :-n, :]

  if self._gate_style == 'unit' or self._gate_style == 'memory':
    self._input_gate, self._forget_gate = self._create_gates(
        inputs_reshape, memory)
    next_memory = self._input_gate * tf.tanh(next_memory)
    next_memory += self._forget_gate * memory

  output = basic.BatchFlatten()(next_memory)
  return output, next_memory
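
# Illustrative usage sketch (not part of the module): stepping the relational
# memory core above once. Assumes the class is exposed as
# `snt.RelationalMemory`; the constructor arguments below are example values,
# not required defaults.
import tensorflow as tf
import sonnet as snt

rmc = snt.RelationalMemory(mem_slots=4, head_size=32, num_heads=2,
                           gate_style='unit')
batch_size = 16
step_inputs = tf.random_normal([batch_size, 10])   # one input vector per step
memory = rmc.initial_state(batch_size)             # [batch_size, mem_slots, mem_size]
output, next_memory = rmc(step_inputs, memory)     # invokes _build under the hood
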
def _create_gates(self, inputs, memory):
  """Create input and forget gates for this step using `inputs` and `memory`.

  Args:
    inputs: Tensor input.
    memory: The current state of memory.

  Returns:
    input_gate: An LSTM-like insert gate.
    forget_gate: An LSTM-like forget gate.
  """
  # We'll create the input and forget gates at once. Hence, calculate double
  # the gate size.
  num_gates = 2 * self._calculate_gate_size()

  memory = tf.tanh(memory)
  inputs = basic.BatchFlatten()(inputs)
  gate_inputs = basic.BatchApply(basic.Linear(num_gates), n_dims=1)(inputs)
  gate_inputs = tf.expand_dims(gate_inputs, axis=1)
  gate_memory = basic.BatchApply(basic.Linear(num_gates))(memory)
  gates = tf.split(gate_memory + gate_inputs, num_or_size_splits=2, axis=2)
  input_gate, forget_gate = gates

  input_gate = tf.sigmoid(input_gate + self._input_bias)
  forget_gate = tf.sigmoid(forget_gate + self._forget_bias)

  return input_gate, forget_gate
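
# Shape walk-through (illustrative sketch, plain TensorFlow): the gate sum
# above broadcasts a per-step input projection of shape [B, 1, 2 * gate_size]
# against a per-slot memory projection of shape [B, N, 2 * gate_size]. The
# sizes below are assumptions chosen only for the example.
import tensorflow as tf

batch, mem_slots, gate_size = 2, 4, 8
gate_inputs_demo = tf.zeros([batch, 1, 2 * gate_size])           # from inputs
gate_memory_demo = tf.zeros([batch, mem_slots, 2 * gate_size])   # from memory
input_gate_demo, forget_gate_demo = tf.split(
    gate_memory_demo + gate_inputs_demo, num_or_size_splits=2, axis=2)
print(input_gate_demo.shape)   # (2, 4, 8): one gate value per slot and unit
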
def _multihead_attention(self, memory):
  """Perform multi-head attention from 'Attention is All You Need'.

  Implementation of the attention mechanism from
  https://arxiv.org/abs/1706.03762.

  Args:
    memory: Memory tensor to perform attention on.

  Returns:
    new_memory: New memory tensor.
  """
  key_size = self._key_size
  value_size = self._head_size

  qkv_size = 2 * key_size + value_size
  total_size = qkv_size * self._num_heads  # Denote as F.
  qkv = basic.BatchApply(basic.Linear(total_size))(memory)
  qkv = basic.BatchApply(layer_norm.LayerNorm())(qkv)

  mem_slots = memory.get_shape().as_list()[1]  # Denoted as N.

  # [B, N, F] -> [B, N, H, F/H]
  qkv_reshape = basic.BatchReshape(
      [mem_slots, self._num_heads, qkv_size])(qkv)

  # [B, N, H, F/H] -> [B, H, N, F/H]
  qkv_transpose = tf.transpose(qkv_reshape, [0, 2, 1, 3])
  q, k, v = tf.split(qkv_transpose, [key_size, key_size, value_size], -1)

  q *= qkv_size ** -0.5
  dot_product = tf.matmul(q, k, transpose_b=True)  # [B, H, N, N]
  weights = tf.nn.softmax(dot_product)

  output = tf.matmul(weights, v)  # [B, H, N, V]

  # [B, H, N, V] -> [B, N, H, V]
  output_transpose = tf.transpose(output, [0, 2, 1, 3])

  # [B, N, H, V] -> [B, N, H * V]
  new_memory = basic.BatchFlatten(preserve_dims=2)(output_transpose)
  return new_memory
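
# Shape walk-through (illustrative sketch, plain TensorFlow): the same
# reshape / transpose / split / softmax pipeline as above, with example sizes.
# B = batch, N = memory slots, H = heads, K = key size, V = value size are
# assumptions chosen only for the example.
import tensorflow as tf

B, N, H, K, V = 2, 4, 3, 16, 32
qkv_size_demo = 2 * K + V                                    # per-head size, F/H
qkv_demo = tf.zeros([B, N, H * qkv_size_demo])               # [B, N, F]
qkv_demo = tf.reshape(qkv_demo, [B, N, H, qkv_size_demo])    # [B, N, H, F/H]
qkv_demo = tf.transpose(qkv_demo, [0, 2, 1, 3])              # [B, H, N, F/H]
q_demo, k_demo, v_demo = tf.split(qkv_demo, [K, K, V], axis=-1)
weights_demo = tf.nn.softmax(
    tf.matmul(q_demo * qkv_size_demo ** -0.5, k_demo, transpose_b=True))
print(weights_demo.shape)                          # (2, 3, 4, 4)  == [B, H, N, N]
print(tf.matmul(weights_demo, v_demo).shape)       # (2, 3, 4, 32) == [B, H, N, V]
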
def _affine_grid_warper_inverse(inputs):
  """Assembles network to compute inverse affine transformation.

  Each `inputs` row potentially contains [a, b, tx, c, d, ty]
  corresponding to an affine matrix:

        A = [a, b, tx],
            [c, d, ty]

  We want to generate a tensor containing the coefficients of the
  corresponding inverse affine transformation in a constraints-aware
  fashion. Calling M:

        M = [a, b]
            [c, d]

  the affine matrix for the inverse transform is:

        A_inv = [M^(-1), M^(-1) * [-tx, -ty]^T]

  where

        M^(-1) = (ad - bc)^(-1) * [ d, -b]
                                  [-c,  a]

  Args:
    inputs: Tensor containing a batch of transformation parameters.

  Returns:
    A TensorFlow graph performing the inverse affine transformation
    parametrized by the input coefficients.
  """
  # Note: this function is defined as a closure inside the enclosing
  # AffineGridWarper method, so `self` and `chain` (itertools.chain) refer to
  # the enclosing scope.
  batch_size = tf.expand_dims(tf.shape(inputs)[0], 0)
  constant_shape = tf.concat([batch_size, tf.convert_to_tensor((1,))], 0)

  index = iter(range(6))
  def get_variable(constraint):
    if constraint is None:
      i = next(index)
      return inputs[:, i:i+1]
    else:
      return tf.fill(constant_shape,
                     tf.constant(constraint, dtype=inputs.dtype))

  constraints = chain.from_iterable(self.constraints)
  a, b, tx, c, d, ty = (get_variable(constr) for constr in constraints)

  det = a * d - b * c
  a_inv = d / det
  b_inv = -b / det
  c_inv = -c / det
  d_inv = a / det

  m_inv = basic.BatchReshape(
      [2, 2])(tf.concat([a_inv, b_inv, c_inv, d_inv], 1))

  txy = tf.expand_dims(tf.concat([tx, ty], 1), 2)
  txy_inv = basic.BatchFlatten()(tf.matmul(m_inv, txy))
  tx_inv = txy_inv[:, 0:1]
  ty_inv = txy_inv[:, 1:2]

  inverse_gw_inputs = tf.concat(
      [a_inv, b_inv, -tx_inv, c_inv, d_inv, -ty_inv], 1)

  agw = AffineGridWarper(self.output_shape,
                         self.source_shape)

  return agw(inverse_gw_inputs)  # pylint: disable=not-callable
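
# Illustrative usage sketch (not part of the module): obtaining the inverse
# warper from an affine grid warper. Assumes this closure backs
# `snt.AffineGridWarper.inverse()`; the shapes below are example values.
import tensorflow as tf
import sonnet as snt

warper = snt.AffineGridWarper(source_shape=(28, 28), output_shape=(14, 14))
inverse_warper = warper.inverse()                # module built from the closure above
params = tf.placeholder(tf.float32, [None, 6])   # [a, b, tx, c, d, ty] per row
inverse_grid = inverse_warper(params)            # sampling grid for the inverse transform
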
def _build(self, inputs, keep_prob=None, is_training=True,
           test_local_stats=True):
  """Connects the AlexNet module into the graph.

  Args:
    inputs: A Tensor of size [batch_size, input_height, input_width,
      input_channels], representing a batch of input images.
    keep_prob: A scalar Tensor representing the dropout keep probability.
    is_training: Boolean to indicate to `snt.BatchNorm` if we are
      currently training. By default `True`.
    test_local_stats: Boolean to indicate to `snt.BatchNorm` if batch
      normalization should use local batch statistics at test time.
      By default `True`.

  Returns:
    A Tensor of size [batch_size, output_size], where `output_size` depends
    on the mode the network was constructed in.

  Raises:
    base.IncompatibleShapeError: If any of the input image dimensions
      (input_height, input_width) are too small for the given network mode.
  """
  input_shape = inputs.get_shape().as_list()

  if input_shape[1] < self._min_size or input_shape[2] < self._min_size:
    raise base.IncompatibleShapeError(
        "Image shape too small: ({:d}, {:d}) < {:d}".format(
            input_shape[1], input_shape[2], self._min_size))

  net = inputs

  for i, params in enumerate(self._conv_layers):
    output_channels, conv_params, max_pooling = params
    kernel_size, stride = conv_params

    conv_mod = conv.Conv2D(name="conv_{}".format(i),
                           output_channels=output_channels,
                           kernel_shape=kernel_size,
                           stride=stride,
                           padding=conv.VALID,
                           initializers=self._initializers,
                           partitioners=self._partitioners,
                           regularizers=self._regularizers)

    if not self.is_connected:
      self._conv_modules.append(conv_mod)

    net = conv_mod(net)

    if self._use_batch_norm:
      bn = batch_norm.BatchNorm(**self._batch_norm_config)
      net = bn(net, is_training, test_local_stats)

    net = tf.nn.relu(net)

    if max_pooling is not None:
      pooling_kernel_size, pooling_stride = max_pooling
      net = tf.nn.max_pool(
          net,
          ksize=[1, pooling_kernel_size, pooling_kernel_size, 1],
          strides=[1, pooling_stride, pooling_stride, 1],
          padding=conv.VALID)

  net = basic.BatchFlatten(name="flatten")(net)

  for i, output_size in enumerate(self._fc_layers):
    linear_mod = basic.Linear(name="fc_{}".format(i),
                              output_size=output_size,
                              partitioners=self._partitioners)

    if not self.is_connected:
      self._linear_modules.append(linear_mod)

    net = linear_mod(net)

    if self._use_batch_norm:
      bn = batch_norm.BatchNorm(**self._batch_norm_config)
      net = bn(net, is_training, test_local_stats)

    net = tf.nn.relu(net)

    if keep_prob is not None:
      net = tf.nn.dropout(net, keep_prob=keep_prob)

  return net
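
# Illustrative usage sketch (not part of the module). The constructor name
# (`snt.nets.AlexNetMini`) and the image size are assumptions; any mode whose
# minimum input size is satisfied connects the same way.
import tensorflow as tf
import sonnet as snt

model = snt.nets.AlexNetMini(use_batch_norm=False)
images = tf.random_normal([8, 64, 64, 3])          # [batch, height, width, channels]
train_features = model(images, keep_prob=0.5)      # dropout applied after each FC layer
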
def _build(self, inputs, keep_prob=None, is_training=None,
           test_local_stats=True):
  """Connects the AlexNet module into the graph.

  The `is_training` flag only controls the batch norm settings; if `False`,
  it does not disable dropout by overriding any input `keep_prob`. To avoid
  the confusion this may cause, an error is raised if `is_training=False`
  and `keep_prob` would cause dropout to be applied.

  Args:
    inputs: A Tensor of size [batch_size, input_height, input_width,
      input_channels], representing a batch of input images.
    keep_prob: A scalar Tensor representing the dropout keep probability.
      When `is_training=False` this must be None or 1 to give no dropout.
    is_training: Boolean to indicate if we are currently training. Must be
      specified if batch normalization or dropout is used.
    test_local_stats: Boolean to indicate to `snt.BatchNorm` if batch
      normalization should use local batch statistics at test time.
      By default `True`.

  Returns:
    A Tensor of size [batch_size, output_size], where `output_size` depends
    on the mode the network was constructed in.

  Raises:
    base.IncompatibleShapeError: If any of the input image dimensions
      (input_height, input_width) are too small for the given network mode.
    ValueError: If `keep_prob` is not None or 1 when `is_training=False`.
    ValueError: If `is_training` is not explicitly specified when using
      batch normalization.
  """
  # Check input shape
  if (self._use_batch_norm or keep_prob is not None) and is_training is None:
    raise ValueError("Boolean is_training flag must be explicitly specified "
                     "when using batch normalization or dropout.")

  input_shape = inputs.get_shape().as_list()
  if input_shape[1] < self._min_size or input_shape[2] < self._min_size:
    raise base.IncompatibleShapeError(
        "Image shape too small: ({:d}, {:d}) < {:d}".format(
            input_shape[1], input_shape[2], self._min_size))

  net = inputs

  # Check keep prob
  if keep_prob is not None:
    valid_inputs = tf.logical_or(is_training, tf.equal(keep_prob, 1.))
    keep_prob_check = tf.assert_equal(
        valid_inputs, True,
        message="Input `keep_prob` must be None or 1 if `is_training=False`.")
    with tf.control_dependencies([keep_prob_check]):
      net = tf.identity(net)

  for i, params in enumerate(self._conv_layers):
    output_channels, conv_params, max_pooling = params
    kernel_size, stride = conv_params

    conv_mod = conv.Conv2D(name="conv_{}".format(i),
                           output_channels=output_channels,
                           kernel_shape=kernel_size,
                           stride=stride,
                           padding=conv.VALID,
                           initializers=self._initializers,
                           partitioners=self._partitioners,
                           regularizers=self._regularizers)

    if not self.is_connected:
      self._conv_modules.append(conv_mod)

    net = conv_mod(net)

    if self._use_batch_norm:
      bn = batch_norm.BatchNorm(**self._batch_norm_config)
      net = bn(net, is_training, test_local_stats)

    net = tf.nn.relu(net)

    if max_pooling is not None:
      pooling_kernel_size, pooling_stride = max_pooling
      net = tf.nn.max_pool(
          net,
          ksize=[1, pooling_kernel_size, pooling_kernel_size, 1],
          strides=[1, pooling_stride, pooling_stride, 1],
          padding=conv.VALID)

  net = basic.BatchFlatten(name="flatten")(net)

  for i, output_size in enumerate(self._fc_layers):
    linear_mod = basic.Linear(name="fc_{}".format(i),
                              output_size=output_size,
                              initializers=self._initializers,
                              partitioners=self._partitioners)

    if not self.is_connected:
      self._linear_modules.append(linear_mod)

    net = linear_mod(net)

    if self._use_batch_norm and self._bn_on_fc_layers:
      bn = batch_norm.BatchNorm(**self._batch_norm_config)
      net = bn(net, is_training, test_local_stats)

    net = tf.nn.relu(net)

    if keep_prob is not None:
      net = tf.nn.dropout(net, keep_prob=keep_prob)

  return net
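
# Illustrative usage sketch (not part of the module): the stricter contract of
# this version. `snt.nets.AlexNetMini` is an assumed entry point; the point
# being shown is that `is_training` must be given whenever batch norm or
# dropout is used, and `keep_prob` must be None (or 1) when `is_training=False`.
import tensorflow as tf
import sonnet as snt

model = snt.nets.AlexNetMini(use_batch_norm=True)
images = tf.random_normal([8, 64, 64, 3])
train_out = model(images, keep_prob=0.5, is_training=True)    # dropout + BN training stats
test_out = model(images, keep_prob=None, is_training=False)   # no dropout, BN inference
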