Example #1
0
    def test_adaptive_stage_using_state_update_tensors(self):
        """Adaptive stage under a PlusOne parent, driven by state updates.

        Runs four rounds of get_params / encode / decode / update_state on the
        scalar 1.0, feeding the state produced by each round back into the
        next one, and checks after every round that decoding recovers the
        input and that the adaptive stage's state, state update tensor and
        encoded value respect the expected bounds.
        """
        composer = core_encoder.EncoderComposer(
            test_utils.AdaptiveNormalizeEncodingStage())
        encoder = composer.add_parent(
            test_utils.PlusOneEncodingStage(), P1_VALS).make()
        x = tf.constant(1.0)
        state = encoder.initial_state()

        for _ in range(4):
            previous_state = state
            encode_params, decode_params = encoder.get_params(previous_state)
            encoded_x, update_tensors, input_shapes = encoder.encode(
                x, encode_params)
            decoded_x = encoder.decode(encoded_x, decode_params, input_shapes)
            state = encoder.update_state(previous_state, update_tensors)
            data = self.evaluate(
                test_utils.TestData(x, encoded_x, decoded_x, previous_state,
                                    update_tensors, state))

            self.assertAllClose(data.x, data.decoded_x)

            # The adaptive stage is attached under the P1_VALS output of the
            # PlusOne root stage, so its state and update tensors live there.
            adaptive_state = data.initial_state[CHILDREN][P1_VALS][STATE]
            adaptive_update = (
                data.state_update_tensors[CHILDREN][P1_VALS][TENSORS])
            self.assertLessEqual(adaptive_state[AN_FACTOR_STATE], 1.0)
            # NOTE(review): 2.0 presumably reflects PlusOneEncodingStage
            # mapping the input 1.0 to 2.0 before the adaptive stage sees it.
            self.assertEqual(adaptive_update[AN_NORM_UPDATE], 2.0)
            self.assertLessEqual(data.encoded_x[P1_VALS][AN_VALS], 2.0)
Example #2
0
    def test_correct_structure(self):
        """Tests that structured objects look like what they should.

        This test creates the following encoding tree:
        SignIntFloatEncodingStage
            [SIF_SIGNS] -> TimesTwoEncodingStage
            [SIF_INTS] -> PlusOneEncodingStage
            [SIF_FLOATS] -> PlusOneOverNEncodingStage
                [PN_VALS] -> AdaptiveNormalizeEncodingStage
        And verifies that the structured objects created by the methods of
        `Encoder` are of the expected structure, and that selected tensors in
        them are created under the expected name scopes.
        """

        def node(key, value, children=None):
            """Returns `{key: value, CHILDREN: children}` for one tree node."""
            return {key: value, CHILDREN: {} if children is None else children}

        def expected_tree(key, root, ints, signs, floats, pn_vals):
            """Returns the expected nested structure of the encoding tree.

            Args:
              key: The per-node key (e.g. STATE, PARAMS, TENSORS, SHAPE or
                COMMUTE).
              root: Value expected under `key` for the SignIntFloat root.
              ints: Value expected under `key` for the SIF_INTS child.
              signs: Value expected under `key` for the SIF_SIGNS child.
              floats: Value expected under `key` for the SIF_FLOATS child.
              pn_vals: Value expected under `key` for the PN_VALS grandchild.
            """
            return node(key, root, {
                SIF_INTS: node(key, ints),
                SIF_SIGNS: node(key, signs),
                SIF_FLOATS: node(key, floats, {PN_VALS: node(key, pn_vals)}),
            })

        def scope_name(root_scope, path, suffix):
            """Returns `root_scope/path[0]/.../path[-1]` followed by `suffix`."""
            return '/'.join((root_scope,) + path) + suffix

        sif_stage = test_utils.SignIntFloatEncodingStage()
        times_two_stage = test_utils.TimesTwoEncodingStage()
        plus_one_stage = test_utils.PlusOneEncodingStage()
        plus_n_squared_stage = test_utils.PlusOneOverNEncodingStage()
        adaptive_normalize_stage = test_utils.AdaptiveNormalizeEncodingStage()

        encoder = core_encoder.EncoderComposer(sif_stage)
        encoder.add_child(times_two_stage, SIF_SIGNS)
        encoder.add_child(plus_one_stage, SIF_INTS)
        encoder.add_child(plus_n_squared_stage,
                          SIF_FLOATS).add_child(adaptive_normalize_stage,
                                                PN_VALS)
        encoder = encoder.make()

        # Create all intermediary objects.
        x = tf.constant(1.0)
        initial_state = encoder.initial_state()
        encode_params, decode_params = encoder.get_params(initial_state)
        encoded_x, state_update_tensors, input_shapes = encoder.encode(
            x, encode_params)
        decoded_x = encoder.decode(encoded_x, decode_params, input_shapes)
        updated_state = encoder.update_state(initial_state,
                                             state_update_tensors)
        commuting_structure = encoder.commuting_structure

        # Name-scope paths (below the top-level Encoder method scope) of the
        # stages whose tensor names are checked.
        ints_path = (sif_stage.name, SIF_INTS, plus_one_stage.name)
        signs_path = (sif_stage.name, SIF_SIGNS, times_two_stage.name)
        floats_path = (sif_stage.name, SIF_FLOATS, plus_n_squared_stage.name)
        pn_vals_path = floats_path + (PN_VALS, adaptive_normalize_stage.name)

        # Verify the structure and naming of those objects is as expected.
        expected_state = expected_tree(
            STATE,
            root={},
            ints={},
            signs={},
            floats={PN_ITER_STATE: None},
            pn_vals={AN_FACTOR_STATE: None})
        for state in [initial_state, updated_state]:
            tf.nest.assert_same_structure(expected_state, state)

        initial_suffix = encoding_stage.INITIAL_STATE_SCOPE_SUFFIX
        self.assertIn(
            scope_name('encoder_initial_state', floats_path, initial_suffix),
            initial_state[CHILDREN][SIF_FLOATS][STATE][PN_ITER_STATE].name)
        self.assertIn(
            scope_name('encoder_initial_state', pn_vals_path, initial_suffix),
            initial_state[CHILDREN][SIF_FLOATS][CHILDREN][PN_VALS][STATE]
            [AN_FACTOR_STATE].name)
        update_suffix = encoding_stage.UPDATE_STATE_SCOPE_SUFFIX
        self.assertIn(
            scope_name('encoder_update_state', floats_path, update_suffix),
            updated_state[CHILDREN][SIF_FLOATS][STATE][PN_ITER_STATE].name)
        self.assertIn(
            scope_name('encoder_update_state', pn_vals_path, update_suffix),
            updated_state[CHILDREN][SIF_FLOATS][CHILDREN][PN_VALS][STATE]
            [AN_FACTOR_STATE].name)

        expected_params = expected_tree(
            PARAMS,
            root={},
            ints={P1_ADD_PARAM: None},
            signs={T2_FACTOR_PARAM: None},
            floats={PN_ADD_PARAM: None},
            pn_vals={AN_FACTOR_PARAM: None})
        params_suffix = encoding_stage.GET_PARAMS_SCOPE_SUFFIX
        for params in [encode_params, decode_params]:
            tf.nest.assert_same_structure(expected_params, params)
            self.assertIn(
                scope_name('encoder_get_params', ints_path, params_suffix),
                params[CHILDREN][SIF_INTS][PARAMS][P1_ADD_PARAM].name)
            self.assertIn(
                scope_name('encoder_get_params', signs_path, params_suffix),
                params[CHILDREN][SIF_SIGNS][PARAMS][T2_FACTOR_PARAM].name)
            self.assertIn(
                scope_name('encoder_get_params', floats_path, params_suffix),
                params[CHILDREN][SIF_FLOATS][PARAMS][PN_ADD_PARAM].name)
            # Note: we do not check the *name* of
            # params[CHILDREN][SIF_FLOATS][CHILDREN][PN_VALS][PARAMS][AN_FACTOR_PARAM]
            # because the get_params method of adaptive_normalize_stage does
            # not modify the graph; it only passes through the provided state
            # tensor, so that tensor is not created under the
            # 'encoder_get_params' scope.

        tf.nest.assert_same_structure(
            {
                SIF_INTS: {P1_VALS: None},
                SIF_SIGNS: {T2_VALS: None},
                SIF_FLOATS: {PN_VALS: {AN_VALS: None}},
            }, encoded_x)
        encode_suffix = encoding_stage.ENCODE_SCOPE_SUFFIX
        self.assertIn(
            scope_name('encoder_encode', ints_path, encode_suffix),
            encoded_x[SIF_INTS][P1_VALS].name)
        self.assertIn(
            scope_name('encoder_encode', signs_path, encode_suffix),
            encoded_x[SIF_SIGNS][T2_VALS].name)
        self.assertIn(
            scope_name('encoder_encode', pn_vals_path, encode_suffix),
            encoded_x[SIF_FLOATS][PN_VALS][AN_VALS].name)

        tf.nest.assert_same_structure(
            expected_tree(
                TENSORS,
                root={},
                ints={},
                signs={},
                floats={},
                pn_vals={AN_NORM_UPDATE: None}),
            state_update_tensors)
        tf.nest.assert_same_structure(state_update_tensors,
                                      encoder.state_update_aggregation_modes)
        # State update tensors are returned by encoder.encode, so they are
        # expected under the 'encoder_encode' scope.
        self.assertIn(
            scope_name('encoder_encode', pn_vals_path, encode_suffix),
            state_update_tensors[CHILDREN][SIF_FLOATS][CHILDREN][PN_VALS]
            [TENSORS][AN_NORM_UPDATE].name)

        tf.nest.assert_same_structure(
            expected_tree(
                SHAPE, root=None, ints=None, signs=None, floats=None,
                pn_vals=None),
            input_shapes)
        self.assertTrue(tf.is_tensor(decoded_x))
        self.assertIn('encoder_decode/', decoded_x.name)

        tf.nest.assert_same_structure(
            expected_tree(
                COMMUTE, root=None, ints=None, signs=None, floats=None,
                pn_vals=None),
            commuting_structure)
        # assertEqual rather than assertFalse: a falsy value such as None
        # must not pass, only values that compare equal to False.
        for item in tf.nest.flatten(commuting_structure):
            self.assertEqual(False, item)
 def default_encoding_stage(self):
   """Returns a newly constructed `AdaptiveNormalizeEncodingStage`.

   Overrides the base-class hook; see the base class for the contract.
   """
   stage = test_utils.AdaptiveNormalizeEncodingStage()
   return stage