Example #1
0
    def test_milstm_params(self):
        """Check that every parameter MILSTM registers has ParameterInfo.

        Builds an ``rnn_cell.MILSTM`` (dim_in=20, dim_out=[40, 20],
        drop_states=True, return_last_layer_only=True) on CPU inside a fresh
        ``ModelHelper``, then asserts that ``model.get_param_info`` returns a
        non-None entry for each blob listed by ``model.GetParams()``.
        """
        model = ModelHelper(name="milstm_params_test")

        with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU, 0)):
            # Only the parameters registered on `model` matter here; the
            # outputs are unused, but the four-way unpacking still guards
            # MILSTM's return arity.
            _output, _, _, _ = rnn_cell.MILSTM(
                model=model,
                input_blob="input",
                seq_lengths="seqlengths",
                initial_states=None,
                dim_in=20,
                dim_out=[40, 20],
                scope="test",
                drop_states=True,
                return_last_layer_only=True,
            )

        params = model.GetParams()
        # Guard against a vacuous pass: an empty parameter list would make
        # the loop below assert nothing at all.
        self.assertGreater(len(params), 0)
        for param in params:
            self.assertIsNotNone(
                model.get_param_info(param),
                "no ParameterInfo registered for {}".format(param),
            )