Ejemplo n.º 1
0
    def test_extract(self, T, n, d):
        """Check that per-timestep step blobs can be retrieved after a run.

        Builds an outer net plus a step net that multiplies ``input_t`` by
        ``output_t_prev``, wires them together with
        ``recurrent.recurrent_net``, runs the outer net, and asserts that
        ``recurrent.retrieve_step_blobs`` reports a blob named
        ``extractTest_input_t<i>`` for every ``i`` in ``range(T)``.

        Args:
            T: Leading dimension of the random ``input`` tensor. Also passed
                as the second positional argument to ``workspace.RunNet`` and
                used as the number of timesteps whose blobs are checked.
            n: Second dimension of both random tensors.
            d: Third dimension of both random tensors.

        NOTE(review): T, n and d are presumably supplied by a
        parameterizing decorator outside this view -- confirm.
        """
        model = ModelHelper(name='external')
        workspace.ResetWorkspace()

        input_blob, initial_input_blob = model.net.AddExternalInputs(
            'input', 'initial_input')

        # The step net shares parameters with the outer model and computes
        # output_t = input_t * output_t_prev.
        step = ModelHelper(name='step', param_model=model)
        input_t, output_t_prev = step.net.AddExternalInput(
            'input_t', 'output_t_prev')
        output_t = step.net.Mul([input_t, output_t_prev])
        step.net.AddExternalOutput(output_t)

        inputs = np.random.randn(T, n, d).astype(np.float32)
        initial_input = np.random.randn(1, n, d).astype(np.float32)
        recurrent.recurrent_net(
            net=model.net,
            cell_net=step.net,
            inputs=[(input_t, input_blob)],
            initial_cell_inputs=[(output_t_prev, initial_input_blob)],
            links={output_t_prev: output_t},
            scope="test_rnn_sum_mull",
        )

        workspace.blobs[input_blob] = inputs
        workspace.blobs[initial_input_blob] = initial_input

        workspace.RunNetOnce(model.param_init_net)
        workspace.CreateNet(model.net)

        prefix = "extractTest"

        workspace.RunNet(model.net.Proto().name, T)
        retrieved_blobs = recurrent.retrieve_step_blobs(
            model.net, prefix
        )

        # Blob names can come back as bytes rather than str (seen on
        # Python 3.6); decode only those so values that are already text
        # pass through untouched instead of raising AttributeError.
        retrieved_blobs = [
            x.decode() if isinstance(x, bytes) else x
            for x in retrieved_blobs
        ]

        # Adjacent literals instead of backslash continuations inside the
        # string, so source indentation does not leak into the message.
        failure_template = (
            "blob extraction failed on timestep {}. \n\n"
            " Extracted Blobs: {} \n\n"
            " Looking for {}."
        )

        for i in range(T):
            blob_name = "{}_input_t{}".format(prefix, i)
            self.assertIn(
                blob_name,
                retrieved_blobs,
                failure_template.format(i, retrieved_blobs, blob_name),
            )
Ejemplo n.º 2
0
    def test_extract(self, T, n, d):
        """Check that per-timestep step blobs can be retrieved after a run.

        Builds an outer net plus a step net that multiplies ``input_t`` by
        ``output_t_prev``, wires them together with
        ``recurrent.recurrent_net``, runs the outer net, and asserts that
        ``recurrent.retrieve_step_blobs`` reports a blob named
        ``extractTest_input_t<i>`` for every ``i`` in ``range(T)``.

        Args:
            T: Leading dimension of the random ``input`` tensor. Also passed
                as the second positional argument to ``workspace.RunNet`` and
                used as the number of timesteps whose blobs are checked.
            n: Second dimension of both random tensors.
            d: Third dimension of both random tensors.

        NOTE(review): T, n and d are presumably supplied by a
        parameterizing decorator outside this view -- confirm.
        """
        model = ModelHelper(name='external')
        workspace.ResetWorkspace()

        input_blob, initial_input_blob = model.net.AddExternalInputs(
            'input', 'initial_input')

        # The step net shares parameters with the outer model and computes
        # output_t = input_t * output_t_prev.
        step = ModelHelper(name='step', param_model=model)
        input_t, output_t_prev = step.net.AddExternalInput(
            'input_t', 'output_t_prev')
        output_t = step.net.Mul([input_t, output_t_prev])
        step.net.AddExternalOutput(output_t)

        inputs = np.random.randn(T, n, d).astype(np.float32)
        initial_input = np.random.randn(1, n, d).astype(np.float32)
        recurrent.recurrent_net(
            net=model.net,
            cell_net=step.net,
            inputs=[(input_t, input_blob)],
            initial_cell_inputs=[(output_t_prev, initial_input_blob)],
            links={output_t_prev: output_t},
            scope="test_rnn_sum_mull",
        )

        workspace.blobs[input_blob] = inputs
        workspace.blobs[initial_input_blob] = initial_input

        workspace.RunNetOnce(model.param_init_net)
        workspace.CreateNet(model.net)

        prefix = "extractTest"

        workspace.RunNet(model.net.Proto().name, T)
        retrieved_blobs = recurrent.retrieve_step_blobs(
            model.net, prefix
        )

        # Blob names can come back as bytes rather than str (seen on
        # Python 3.6); decode only those so values that are already text
        # pass through untouched instead of raising AttributeError.
        retrieved_blobs = [
            x.decode() if isinstance(x, bytes) else x
            for x in retrieved_blobs
        ]

        # Adjacent literals instead of backslash continuations inside the
        # string, so source indentation does not leak into the message.
        failure_template = (
            "blob extraction failed on timestep {}. \n\n"
            " Extracted Blobs: {} \n\n"
            " Looking for {}."
        )

        for i in range(T):
            blob_name = "{}_input_t{}".format(prefix, i)
            self.assertIn(
                blob_name,
                retrieved_blobs,
                failure_template.format(i, retrieved_blobs, blob_name),
            )