示例#1
0
    def _RunModel(self, input_1_val, input_2_val, dropout, custom_grad):
        """Builds a fresh RevNet graph, evaluates it once, and returns results.

        The cached session is cleared and the default graph reset first. The
        layer is built from `_SimpleRevNetParams('revnet', dropout,
        custom_grad)` with Xavier init (scale=1.0, seed=0).

        Args:
          input_1_val: Value fed to the `split1` input placeholder.
          input_2_val: Value fed to the `split2` input placeholder.
          dropout: Forwarded to `_SimpleRevNetParams` (presumably a dropout
            setting for the RevNet -- confirm against that helper).
          custom_grad: Forwarded to `_SimpleRevNetParams` (presumably toggles
            a custom gradient path -- confirm against that helper).

        Returns:
          A tuple `(h_val, dfx_val, dfw_val)`:
            - the evaluated output of `FPropDefaultTheta`;
            - gradients of the flattened output w.r.t. the two inputs;
            - gradients of the flattened output w.r.t. every flattened theta
              entry, with unconnected gradients reported as zeros.
        """
        self._ClearCachedSession()
        tf.reset_default_graph()
        with self.session() as sess:
            # NOTE(review): ops are created in the same order as before; with
            # a graph-level seed set, seeded random ops may depend on it.
            tf.random.set_seed(321)
            split1_ph = tf.placeholder(tf.float32)
            split2_ph = tf.placeholder(tf.float32)

            params = self._SimpleRevNetParams('revnet', dropout, custom_grad)
            params.params_init = py_utils.WeightInit.Xavier(scale=1.0, seed=0)
            layer = params.Instantiate()

            out = layer.FPropDefaultTheta(
                py_utils.NestedMap(split1=split1_ph, split2=split2_ph))
            out_flat = out.Flatten()

            grads_wrt_inputs = tf.gradients(out_flat, [split1_ph, split2_ph])
            # Weights the output does not depend on get zero gradients rather
            # than None, so they can be fetched by sess.run.
            grads_wrt_theta = tf.gradients(
                out_flat,
                layer.theta.Flatten(),
                unconnected_gradients=tf.UnconnectedGradients.ZERO)

            tf.global_variables_initializer().run()
            feeds = {split1_ph: input_1_val, split2_ph: input_2_val}
            input_grads, theta_grads, out_val = sess.run(
                [grads_wrt_inputs, grads_wrt_theta, out], feed_dict=feeds)
        return out_val, input_grads, theta_grads
示例#2
0
    def setUp(self):
        """Prepares minibatches, expected outputs and a layer config.

        Stores `self.input_batch_list` (from `_GenerateListOfMinibatches`),
        `self.expected_outputs` (from `_GetExpectedTestOutputs`) and
        `self.params`, a `CumulativeStatisticsLayer` params object configured
        with `stats_type='PASS_THRU'` and `use_weighted_frames=False`.
        """
        super().setUp()
        tf.reset_default_graph()

        self.input_batch_list = self._GenerateListOfMinibatches()
        self.expected_outputs = self._GetExpectedTestOutputs()

        # The input dimension is taken from axis 2 of the first minibatch's
        # features.
        feature_dim = self.input_batch_list[0].features.shape[2]

        params = cumulative_statistics_layer.CumulativeStatisticsLayer.Params()
        params.stats_type = 'PASS_THRU'
        params.use_weighted_frames = False
        params.input_dim = feature_dim
        params.features_name = 'features'
        params.paddings_name = 'paddings'

        # frame_weight_ffn sub-params: one layer of width 1, SIGMOID
        # activation, with bias, over the same input dimension.
        ffn = params.frame_weight_ffn
        ffn.activation = ['SIGMOID']
        ffn.has_bias = [True]
        ffn.hidden_layer_dims = [1]
        ffn.input_dim = feature_dim

        self.params = params

    def testRespectsInfeedBatchSize(self):
        """Checks that the produced batch size follows `InfeedBatchSize()`.

        First with the configured `batch_size`, then on a fresh graph and
        generator whose `InfeedBatchSize()` is patched to return 42.
        """
        params = ToyInputGenerator.Params()
        params.batch_size = 3
        tmpdir, input_files = _CreateFakeTFRecordFiles()
        self._tmpdir = tmpdir
        params.input_files = input_files
        params.dataset_type = tf.data.TFRecordDataset

        # Unpatched: the leading audio dimension equals the configured size.
        generator = params.Instantiate()
        first_batch = generator.GetPreprocessedInputBatch()
        self.assertEqual(first_batch.audio.shape[0], params.batch_size)
        self.assertEqual(params.batch_size, generator.InfeedBatchSize())

        # Patched: the batch must be sized by InfeedBatchSize(), not by
        # params.batch_size.
        tf.reset_default_graph()
        generator = params.Instantiate()
        patched_size = 42
        with mock.patch.object(generator, 'InfeedBatchSize',
                               return_value=patched_size) as infeed_mock:
            patched_batch = generator.GetPreprocessedInputBatch()
            self.assertEqual(patched_batch.audio.shape[0], patched_size)
        infeed_mock.assert_called()
示例#4
0
    def setUp(self):
        """Builds a test example and an `AttentiveScoringLayer` config.

        Sets `self.test_data`, `self.enroll_data` and `self.data_info` from
        `_GenerateExample()`. Sets `self.params`, whose key/value sizes are
        read from `self.data_info`.
        """
        super().setUp()
        tf.reset_default_graph()

        example = self._GenerateExample()
        self.test_data, self.enroll_data, self.data_info = example

        params = attentive_scoring_layer.AttentiveScoringLayer.Params()
        # These size fields are copied by name from the example's metadata.
        for size_field in ('num_keys', 'key_dim', 'value_dim'):
            setattr(params, size_field, self.data_info[size_field])
        params.scale_factor = 1.0
        params.use_trainable_scale_factor = False
        # L2-norm flags: enabled for values only.
        params.apply_l2_norm_to_keys = False
        params.apply_l2_norm_to_values = True
        params.apply_global_l2_norm_to_concat_form = False

        self.params = params
 def setUp(self):
     """Runs the base-class setUp, then resets the default TF graph."""
     # Zero-argument super() avoids hard-coding the enclosing class name
     # (BaseExampleInputGeneratorTest), so renames or copies cannot silently
     # skip part of the MRO.
     super().setUp()
     tf.reset_default_graph()
示例#6
0
 def setUp(self):
   """Runs the base-class setUp, then resets the default TF graph."""
   # Chain to the parent TestCase so its own setup is not skipped.
   super().setUp()
   tf.reset_default_graph()
示例#7
0
 def setUp(self):
     """Performs parent-class setup, then starts from an empty default graph."""
     super().setUp()
     # Drop any ops left in the default graph by earlier tests.
     tf.reset_default_graph()