Esempio n. 1
0
 def _get_onnx_sym(model: gluon.HybridBlock, num_inputs: int) -> sym.Symbol:
     """
     Trace ``model`` symbolically and return its outputs as one grouped symbol.

     Placeholder inputs named ``Data0`` .. ``Data<num_inputs - 1>`` are created
     and passed to ``model`` while ``ScopedOnnxEnable(model)`` is active
     (presumably switching the block into its ONNX-export behaviour — confirm
     against ScopedOnnxEnable).

     :param model: gluon HybridBlock that constructs the symbolic graph
     :param num_inputs: number of placeholder inputs to create
     :return: ``sym.Group`` of the model's flattened outputs
     """
     placeholders = [sym.Variable('Data%d' % idx) for idx in range(num_inputs)]
     with ScopedOnnxEnable(model):
         raw_outputs = model(*placeholders)
         # Element 0 of _flatten's result is the flat list of output symbols.
         flat_outputs = gluon.block._flatten(raw_outputs, "output")[0]
         return sym.Group(flat_outputs)
def test_onnx_export_multi_output():
    """Check that a HybridBlock returning several separate outputs can be exported.

    The block applies ten independent Dense layers to the same input and
    returns all ten results as a tuple. The traced symbol must expose exactly
    ten outputs before the block is handed to
    ``_check_onnx_export(..., group_outputs=True)``.
    """
    num_heads = 10

    class MultiOutputBlock(nn.HybridBlock):
        def __init__(self):
            super(MultiOutputBlock, self).__init__()
            with self.name_scope():
                self.net = nn.HybridSequential()
                for head in range(num_heads):
                    # Layer widths are 100, 110, ..., 190.
                    self.net.add(nn.Dense(100 + head * 10, activation='relu'))

        def hybrid_forward(self, F, x):
            # Every child layer sees the same input; each result is a separate output.
            return tuple(layer(x) for layer in self.net._children.values())

    model = MultiOutputBlock()
    traced = sym.Group(model(sym.Variable('data')))
    assert len(traced.list_outputs()) == num_heads
    _check_onnx_export(model, group_outputs=True)
Esempio n. 3
0
    def get_op(self, row_cnt):
        """
        Build (or fetch from cache) the training op for a batch of ``row_cnt`` rows.

        The op appears to implement CD-k (contrastive divergence with
        ``self.k`` Gibbs steps). It groups a reconstruction loss with the
        gradients for ``W``, the visible bias and the hidden bias. Built ops
        are memoized in ``self.ops`` keyed by ``row_cnt``, so each distinct
        batch size is constructed only once.

        :param row_cnt: batch row count; used as the cache key and forwarded
            to the graph's sampling ops.
        :return: ``sym.Group([loss, grad_W, grad_v_bias, grad_h_bias])``
        :raises ValueError: if no op is cached for ``row_cnt`` and
            ``self.k < 1``. Without at least one Gibbs step the reconstruction
            terms would never be defined.
        """
        if row_cnt in self.ops:
            return self.ops[row_cnt]
        if self.k < 1:
            # With k < 1 the Gibbs loop below would never run, leaving
            # vn_probs / vn_sample / hn_probs unbound.
            raise ValueError(
                'get_op requires k >= 1 Gibbs steps, got {}'.format(self.k))

        graph = self.model.graph
        X = graph.X

        # Data-driven hidden statistics; h0_probs feeds the "data" side of the
        # gradients, hn_sample seeds the first Gibbs step.
        h0_probs, hn_sample = graph.get_op_sample_h_given_v(X, row_cnt)
        # self.k chained Gibbs steps, each fed the previous hidden sample.
        for _ in range(self.k):
            vn_probs, vn_sample, hn_probs, hn_sample = graph.get_op_gibbs_hvh(
                hn_sample, row_cnt)

        # Each gradient is (reconstruction statistic) - (data statistic).
        # The bias terms are summed over axis 0 (presumably the row/batch
        # axis — confirm), not averaged.
        grad_W = (sym.dot(sym.transpose(vn_sample), hn_probs) -
                  sym.dot(sym.transpose(X), h0_probs))
        grad_v_bias = sym.sum(vn_sample - X, axis=0)
        grad_h_bias = sym.sum(hn_probs - h0_probs, axis=0)

        # Mean squared reconstruction error between the data and vn_probs.
        loss = sym.mean(sym.square(X - vn_probs))
        op = sym.Group([loss, grad_W, grad_v_bias, grad_h_bias])
        self.ops[row_cnt] = op
        return op
def _optional_group(symbols, group=False):
    if group:
        return sym.Group(symbols)
    else:
        return symbols