def lenet5(input_shape, num_classes):
  """Builds the LeNet5 architecture as a Keras model.

  Args:
    input_shape: Shape of the input tensor, excluding the batch dimension.
    num_classes: Number of output classes.

  Returns:
    A `tf.keras.Model` whose output is an `ed.Categorical` over the classes.
  """
  inputs = tf.keras.layers.Input(shape=input_shape)
  net = tf.keras.layers.Conv2D(
      6, kernel_size=5, padding='SAME', activation='relu')(inputs)
  net = tf.keras.layers.MaxPooling2D(
      pool_size=[2, 2], strides=[2, 2], padding='SAME')(net)
  net = tf.keras.layers.Conv2D(
      16, kernel_size=5, padding='SAME', activation='relu')(net)
  net = tf.keras.layers.MaxPooling2D(
      pool_size=[2, 2], strides=[2, 2], padding='SAME')(net)
  net = tf.keras.layers.Conv2D(
      120, kernel_size=5, padding='SAME', activation=tf.nn.relu)(net)
  net = tf.keras.layers.Flatten()(net)
  net = tf.keras.layers.Dense(84, activation=tf.nn.relu)(net)
  logits = tf.keras.layers.Dense(num_classes)(net)
  # Wrap the logits in a categorical random variable output.
  outputs = tf.keras.layers.Lambda(
      lambda lgts: ed.Categorical(logits=lgts))(logits)
  return tf.keras.Model(inputs=inputs, outputs=outputs)
def next_state(self, old_state, inputs, parameters=None):
  """Samples a state transition conditioned on a previous state and input.

  Args:
    old_state: a Value whose `state` key represents the previous state.
    inputs: a Value whose `input` key represents the inputs.
    parameters: optionally a `Value` with fields corresponding to the
      tensor-valued entity parameters to be set at simulation time.

  Returns:
    A `Value` containing the sampled state as well as any additional random
    variables sampled during state generation.

  Raises:
    RuntimeError: if `parameters` has neither been provided here nor at
      construction.
  """
  if parameters is None:
    parameters = self._get_static_parameters_or_die()
  transition_logits = parameters.get('transition_parameters')
  # First index the transition table by the taken action...
  logits_given_action = tf.gather(
      transition_logits, inputs.get('input'), batch_dims=self._batch_dims)
  # ...then by the previous state, yielding next-state logits.
  logits_given_state = tf.gather(
      logits_given_action,
      old_state.get('state'),
      batch_dims=self._batch_dims)
  return Value(state=ed.Categorical(logits=logits_given_state))
def initial_state(self):
  """The initial state value."""
  # 70% of topics are trashy, the rest are nutritious.
  num_trashy = int(self._num_topics * 0.7)
  num_nutritious = self._num_topics - num_trashy
  topic_quality_means = tf.concat(
      [
          tf.linspace(self._topic_min_utility, 0., num_trashy),
          tf.linspace(0., self._topic_max_utility, num_nutritious),
      ],
      axis=0)
  # Zero logits => equal probability of each topic.
  doc_topic = ed.Categorical(
      logits=tf.zeros((self._num_docs, self._num_topics)), dtype=tf.int32)
  # Quality is Gaussian around the topic mean with a fixed scale (std dev).
  doc_quality = ed.Normal(
      loc=tf.gather(topic_quality_means, doc_topic), scale=0.1)
  # Noisy 1-hot topic indicator serves as the document features.
  doc_features = ed.Normal(
      loc=tf.one_hot(doc_topic, depth=self._num_topics), scale=0.7)
  # All videos have the same length.
  video_length = ed.Deterministic(
      loc=tf.ones((self._num_docs,)) * self._video_length)
  return Value(
      # doc_id=0 is reserved for the "null" doc.
      doc_id=ed.Deterministic(
          loc=tf.range(start=1, limit=self._num_docs + 1, dtype=tf.int32)),
      doc_topic=doc_topic,
      doc_quality=doc_quality,
      doc_features=doc_features,
      doc_length=video_length)
def resnet_v1(input_shape, depth, num_classes, batch_norm):
  """Builds ResNet v1.

  Args:
    input_shape: tf.Tensor.
    depth: ResNet depth.
    num_classes: Number of output classes.
    batch_norm: Whether to apply batch normalization.

  Returns:
    tf.keras.Model.

  Raises:
    ValueError: If `depth` is not of the form 6n+2.
  """
  if (depth - 2) % 6 != 0:
    raise ValueError('depth must be 6n+2 (e.g. 20, 32, 44).')
  num_res_blocks = (depth - 2) // 6
  filters = 16
  logging.info('Starting ResNet build.')
  inputs = tf.keras.layers.Input(shape=input_shape)
  x = resnet_layer(inputs, filters=filters, activation='relu')
  for stack in range(3):
    for res_block in range(num_res_blocks):
      logging.info('Starting ResNet stack #%d block #%d.', stack, res_block)
      # First block of every stack after the first downsamples.
      is_stack_transition = stack > 0 and res_block == 0
      strides = 2 if is_stack_transition else 1
      y = resnet_layer(
          x,
          filters=filters,
          strides=strides,
          activation='relu',
          batch_norm=batch_norm)
      y = resnet_layer(
          y, filters=filters, activation=None, batch_norm=batch_norm)
      if is_stack_transition:
        # Linear projection residual shortcut to match changed dims.
        x = resnet_layer(
            x,
            filters=filters,
            kernel_size=1,
            strides=strides,
            activation=None,
            batch_norm=False)
      x = tf.keras.layers.add([x, y])
      x = tf.keras.layers.Activation('relu')(x)
    filters *= 2
  # v1 does not use BN after last shortcut connection-ReLU.
  x = tf.keras.layers.AveragePooling2D(pool_size=8)(x)
  x = tf.keras.layers.Flatten()(x)
  x = tf.keras.layers.Dense(num_classes, kernel_initializer='he_normal')(x)
  outputs = tf.keras.layers.Lambda(
      lambda lgts: ed.Categorical(logits=lgts))(x)
  return tf.keras.models.Model(inputs=inputs, outputs=outputs)
def choice(self, slate_document_logits):
  """Samples a choice from a set of items.

  Args:
    slate_document_logits: a tensor with shape [b1, ..., bk, slate_size]
      representing the logits of each item in the slate.

  Returns:
    A `Value` containing choice random variables with shape [b1, ..., bk].
  """
  slate_size = tf.shape(slate_document_logits)[-1]
  # Additive bias that grows linearly with the slate position.
  position_bias = tf.expand_dims(
      tf.linspace(
          0.,
          self._positional_bias * tf.cast(slate_size - 1, tf.float32),
          slate_size),
      0)
  biased_logits = slate_document_logits + position_bias
  # Append the no-choice option as the final logit.
  all_option_logits = tf.concat(
      (biased_logits, self._nochoice_logits), axis=-1)
  return Value(
      choice=ed.Categorical(
          logits=all_option_logits, name='choice_Categorical'))
def initial_state(self, parameters=None):
  """Samples a state tensor for a batch of actors.

  Args:
    parameters: optionally a `Value` with fields corresponding to the
      tensor-valued entity parameters to be set at simulation time.

  Returns:
    A `Value` containing the sampled state as well as any additional random
    variables sampled during state generation.

  Raises:
    RuntimeError: if `parameters` has neither been provided here nor at
      construction.
  """
  params = parameters
  if params is None:
    params = self._get_static_parameters_or_die()
  initial_logits = params.get('initial_dist_logits')
  return Value(state=ed.Categorical(logits=initial_logits))
def res_net(n_examples, input_shape, num_classes, batchnorm=False, variational='full'):
  """Wrapper for build_resnet_v1.

  Args:
    n_examples (int): number of training points.
    input_shape (list): input shape.
    num_classes (int): number of classes (CIFAR10 has 10).
    batchnorm (bool): use of batchnorm layers.
    variational (str): 'none', 'hybrid', 'full'. whether to use variational
      inference for zero, some, or all layers.

  Returns:
    model (Model): Keras model instance whose output is a
      tfp.distributions.Categorical distribution.
  """
  inputs = tf.keras.layers.Input(shape=input_shape)
  x = build_resnet_v1(
      inputs,
      depth=20,
      variational=variational,
      batchnorm=batchnorm,
      n_examples=n_examples)

  p_fn, q_fn = mean_field_fn(empirical_bayes=True)

  def normalized_kl_fn(q, p, _):
    # KL scaled by dataset size so the per-example ELBO is correctly weighted.
    # tf.cast replaces the deprecated tf.to_float (removed in TF 2.x).
    return tfp.distributions.kl_divergence(q, p) / tf.cast(
        n_examples, tf.float32)

  logits = tfp.layers.DenseLocalReparameterization(
      num_classes,
      kernel_prior_fn=p_fn,
      kernel_posterior_fn=q_fn,
      bias_prior_fn=p_fn,
      bias_posterior_fn=q_fn,
      kernel_divergence_fn=normalized_kl_fn,
      bias_divergence_fn=normalized_kl_fn)(x)
  outputs = tf.keras.layers.Lambda(lambda x: ed.Categorical(logits=x))(
      logits)
  return tf.keras.models.Model(inputs=inputs, outputs=outputs)
def lenet5(n_examples, input_shape, num_classes):
  """Builds Bayesian LeNet5."""
  p_fn, q_fn = mean_field_fn(empirical_bayes=True)

  def normalized_kl_fn(q, p, _):
    # KL scaled by dataset size so the per-example ELBO is correctly weighted.
    return q.kl_divergence(p) / tf.cast(n_examples, tf.float32)

  # Shared prior/posterior/divergence configuration for every Bayesian layer.
  variational_kwargs = dict(
      kernel_prior_fn=p_fn,
      kernel_posterior_fn=q_fn,
      bias_prior_fn=p_fn,
      bias_posterior_fn=q_fn,
      kernel_divergence_fn=normalized_kl_fn,
      bias_divergence_fn=normalized_kl_fn)

  def flipout_conv(filters, previous):
    # 5x5 Flipout convolution with ReLU, SAME padding.
    return tfp.layers.Convolution2DFlipout(
        filters,
        kernel_size=5,
        padding='SAME',
        activation=tf.nn.relu,
        **variational_kwargs)(previous)

  def max_pool(previous):
    # 2x2 max pooling with stride 2, SAME padding.
    return tf.keras.layers.MaxPooling2D(
        pool_size=[2, 2], strides=[2, 2], padding='SAME')(previous)

  inputs = tf.keras.layers.Input(shape=input_shape)
  net = flipout_conv(6, inputs)
  net = max_pool(net)
  net = flipout_conv(16, net)
  net = max_pool(net)
  net = flipout_conv(120, net)
  net = tf.keras.layers.Flatten()(net)
  net = tfp.layers.DenseLocalReparameterization(
      84, activation=tf.nn.relu, **variational_kwargs)(net)
  net = tfp.layers.DenseLocalReparameterization(
      num_classes, **variational_kwargs)(net)
  outputs = tf.keras.layers.Lambda(lambda x: ed.Categorical(logits=x))(net)
  return tf.keras.models.Model(inputs=inputs, outputs=outputs)
def resnet_v1(input_shape, depth, num_classes, batch_norm, prior_stddev, dataset_size):
  """Builds ResNet v1.

  Args:
    input_shape: tf.Tensor.
    depth: ResNet depth.
    num_classes: Number of output classes.
    batch_norm: Whether to apply batch normalization.
    prior_stddev: Standard deviation of weight priors.
    dataset_size: Total number of examples in an epoch.

  Returns:
    tf.keras.Model.

  Raises:
    ValueError: If `depth` is not of the form 6n+2.
  """
  if (depth - 2) % 6 != 0:
    raise ValueError('depth must be 6n+2 (e.g. 20, 32, 44).')
  num_res_blocks = (depth - 2) // 6
  filters = 16
  layer = functools.partial(
      resnet_layer,
      depth=depth,
      dataset_size=dataset_size,
      prior_stddev=prior_stddev)
  logging.info('Starting ResNet build.')
  inputs = tf.keras.layers.Input(shape=input_shape)
  x = layer(inputs, filters=filters, activation='selu')
  for stack in range(3):
    for res_block in range(num_res_blocks):
      logging.info('Starting ResNet stack #%d block #%d.', stack, res_block)
      # First block of every stack after the first downsamples.
      is_stack_transition = stack > 0 and res_block == 0
      strides = 2 if is_stack_transition else 1
      y = layer(
          x,
          filters=filters,
          strides=strides,
          activation='selu',
          batch_norm=batch_norm)
      y = layer(
          y,
          filters=filters,
          activation=None,
          batch_norm=batch_norm,
          bayesian=True)
      if is_stack_transition:
        # Linear projection residual shortcut to match changed dims.
        x = layer(
            x,
            filters=filters,
            kernel_size=1,
            strides=strides,
            activation=None,
            batch_norm=False)
      x = tf.keras.layers.add([x, y])
      x = tf.keras.layers.Activation('selu')(x)
    filters *= 2
  # v1 does not use BN after last shortcut connection-ReLU.
  x = tf.keras.layers.AveragePooling2D(pool_size=8)(x)
  x = tf.keras.layers.Flatten()(x)
  x = ed.layers.DenseVariationalDropout(
      num_classes,
      kernel_initializer='trainable_he_normal',
      kernel_regularizer=NormalKLDivergenceWithTiedMean(
          stddev=prior_stddev, scale_factor=1. / dataset_size))(x)
  outputs = tf.keras.layers.Lambda(
      lambda lgts: ed.Categorical(logits=lgts))(x)
  return tf.keras.models.Model(inputs=inputs, outputs=outputs)
def next_state(self, old_state, inputs, parameters=None):
  """Delegates to the parent transition, adding a categorical sample of it."""
  parent_value = super().next_state(old_state, inputs, parameters)
  parent_state = parent_value.get('state')
  # `state2` draws a categorical using the parent state values as logits.
  return Value(
      state=parent_state,
      state2=ed.Categorical(logits=parent_state))
def initial_state(self, parameters=None):
  """Delegates to the parent initial state, adding a categorical sample of it."""
  parent_value = super().initial_state(parameters=parameters)
  parent_state = parent_value.get('state')
  # `state2` draws a categorical using the parent state values as logits.
  return Value(
      state=parent_state,
      state2=ed.Categorical(logits=parent_state))
def initial_state(self, parameters=None):
  """Samples an initial state from uniform (all-ones) logits."""
  del parameters  # Unused.
  # tf.range(1, ndims + 2) yields [1, 2, ..., ndims + 1], used here as the
  # shape of the logits tensor; constant logits make the categorical uniform.
  logits_shape = tf.range(1, self._batch_ndims + 2)
  return Value(state=ed.Categorical(logits=tf.ones(logits_shape)))