def build_cnf():
    """Build a CNF layer whose dynamics come from ``layers.ODEGatedGraphnet``.

    ``hidden_dims``, ``data_shape``, ``strides``, ``args`` and
    ``regularization_fns`` are free names in this function; they must be
    bound in an enclosing scope that is not visible here.

    Returns:
        A ``layers.CNF`` wrapping a ``layers.ODEfunc`` whose ``diffeq`` is the
        gated graph network.
    """
    gated_net_options = {
        "hidden_dims": hidden_dims,
        "input_shape": data_shape,
        "strides": strides,
        "conv": args.conv,
        "layer_type": args.layer_type,
        "nonlinearity": args.nonlinearity,
        "ifgate": args.if_graph_gate,
        "node_autoencode": args.node_autoencode,
        "num_func": args.num_func,
        "num_layers": args.num_layers,
    }
    gated_net = layers.ODEGatedGraphnet(**gated_net_options)

    dynamics = layers.ODEfunc(
        diffeq=gated_net,
        divergence_fn=args.divergence_fn,
        residual=args.residual,
        rademacher=args.rademacher,
    )
    return layers.CNF(
        odefunc=dynamics,
        T=args.time_length,
        regularization_fns=regularization_fns,
        solver=args.solver,
    )
 def build_cnf():
     """Build a CNF layer whose dynamics come from ``layers.ODEGraphnet``.

     ``hidden_dims``, ``data_shape``, ``strides``, ``args``, ``vocab_size``
     and ``regularization_fns`` are free names here and must be bound in an
     enclosing scope that is not visible. The graph network is always built
     with ``conv=False``.

     Returns:
         A ``layers.CNF`` wrapping a ``layers.ODEfunc`` around the graph net.
     """
     # embed_config is the 4-item list
     # [args.embed_dim > 0, args.num_func, vocab_size, args.embed_dim];
     # the leading bool is presumably an "embedding enabled" flag -- confirm
     # against ODEGraphnet.
     use_embedding = args.embed_dim > 0
     embed_config = [use_embedding, args.num_func, vocab_size, args.embed_dim]

     graph_net = layers.ODEGraphnet(
         hidden_dims=hidden_dims,
         input_shape=data_shape,
         strides=strides,
         conv=False,
         layer_type=args.layer_type,
         nonlinearity=args.nonlinearity,
         ifgate=args.if_graph_gate,
         num_func=args.num_func,
         embed_config=embed_config,
     )
     ode_function = layers.ODEfunc(
         diffeq=graph_net,
         divergence_fn=args.divergence_fn,
         residual=args.residual,
         rademacher=args.rademacher,
     )
     flow = layers.CNF(
         odefunc=ode_function,
         T=args.time_length,
         regularization_fns=regularization_fns,
         solver=args.solver,
     )
     return flow
Example #3
0
    def __init__(
        self,
        initial_size,
        idims=(32, ),
        nonlinearity="softplus",
        squeeze=True,
        init_layer=None,
        n_blocks=1,
        cnf_kwargs=None,
    ):
        """Stack ``n_blocks`` CNF layers, optionally around a squeeze layer.

        The resulting chain is, in order: ``init_layer`` (if given), then
        either
          * ``n_blocks`` CNFs at ``initial_size``, ``layers.SqueezeLayer(2)``,
            and ``n_blocks`` CNFs at ``(c * 4, h // 2, w // 2)`` when
            ``squeeze`` is true, or
          * ``n_blocks`` CNFs at ``initial_size`` otherwise.
        Every CNF gets its own freshly built ``layers.ODEfunc``/``ODEnet``.

        Args:
            initial_size: Input size; unpacked as ``c, h, w`` when
                ``squeeze`` is true.
            idims: Hidden dims passed to ``ODEnet``.
            nonlinearity: Nonlinearity name passed to ``ODEnet``.
            squeeze: Whether to insert ``layers.SqueezeLayer(2)`` between two
                stages of CNFs.
            init_layer: Optional layer prepended to the chain.
            n_blocks: Number of CNFs per stage.
            cnf_kwargs: Extra keyword arguments for every ``layers.CNF``.
                ``None`` (the default) means no extras.
        """
        # A None sentinel avoids a mutable default shared across calls.
        cnf_kwargs = {} if cnf_kwargs is None else cnf_kwargs

        # len(idims) + 1 strides, all equal to 1.
        strides = (1, ) + tuple(1 for _ in idims)
        chain = []
        if init_layer is not None:
            chain.append(init_layer)

        def _make_odefunc(size):
            # Fourth positional argument is True; presumably ODEnet's `conv`
            # flag -- confirm against ODEnet's signature.
            net = ODEnet(idims,
                         size,
                         strides,
                         True,
                         layer_type="concat",
                         nonlinearity=nonlinearity)
            f = layers.ODEfunc(net)
            return f

        if squeeze:
            c, h, w = initial_size
            # Odd h/w are floored here; whether SqueezeLayer(2) tolerates
            # odd spatial sizes is not visible here -- verify.
            after_squeeze_size = c * 4, h // 2, w // 2
            pre = [
                layers.CNF(_make_odefunc(initial_size), **cnf_kwargs)
                for _ in range(n_blocks)
            ]
            post = [
                layers.CNF(_make_odefunc(after_squeeze_size), **cnf_kwargs)
                for _ in range(n_blocks)
            ]
            chain += pre + [layers.SqueezeLayer(2)] + post
        else:
            chain += [
                layers.CNF(_make_odefunc(initial_size), **cnf_kwargs)
                for _ in range(n_blocks)
            ]

        super(StackedCNFLayers, self).__init__(chain)
 def build_cnf():
     """Build a CNF layer whose dynamics come from a plain ``layers.ODEnet``.

     ``hidden_dims``, ``data_shape``, ``strides``, ``args`` and
     ``regularization_fns`` are free names here and must be bound in an
     enclosing scope that is not visible. ``train_T=args.train_T`` is
     forwarded to ``layers.CNF``.

     Returns:
         A ``layers.CNF`` wrapping a ``layers.ODEfunc`` around the ODE net.
     """
     net_options = {
         "hidden_dims": hidden_dims,
         "input_shape": data_shape,
         "strides": strides,
         "conv": args.conv,
         "layer_type": args.layer_type,
         "nonlinearity": args.nonlinearity,
     }
     func_options = {
         "divergence_fn": args.divergence_fn,
         "residual": args.residual,
         "rademacher": args.rademacher,
     }
     ode_net = layers.ODEnet(**net_options)
     ode_function = layers.ODEfunc(diffeq=ode_net, **func_options)
     return layers.CNF(
         odefunc=ode_function,
         T=args.time_length,
         train_T=args.train_T,
         regularization_fns=regularization_fns,
         solver=args.solver,
     )
Example #5
0
 def build_cnf():
     """Build a CNF layer driven by ``layers.ODEGraphnetGraphGen``.

     The graph network is always built with ``conv=False``,
     ``num_squeeze=0`` and ``ifgate=False``. ``hidden_dims``, ``data_shape``,
     ``strides``, ``args`` and ``regularization_fns`` are free names here and
     must be bound in an enclosing scope that is not visible.

     Returns:
         A ``layers.CNF`` wrapping a ``layers.ODEfunc`` around the graph net.
     """
     # Fixed settings for this variant; the rest comes from `args`.
     fixed_settings = {"conv": False, "num_squeeze": 0, "ifgate": False}
     generator_net = layers.ODEGraphnetGraphGen(
         hidden_dims=hidden_dims,
         input_shape=data_shape,
         strides=strides,
         layer_type=args.layer_type,
         nonlinearity=args.nonlinearity,
         **fixed_settings
     )
     dynamics = layers.ODEfunc(
         diffeq=generator_net,
         divergence_fn=args.divergence_fn,
         residual=args.residual,
         rademacher=args.rademacher,
     )
     flow = layers.CNF(
         odefunc=dynamics,
         T=args.time_length,
         regularization_fns=regularization_fns,
         solver=args.solver,
     )
     return flow
Example #6
0
    def __init__(
        self,
        initial_size,
        idims=(32, ),
        strides=None,
        nonlinearity="softplus",
        squeeze=True,
        init_layer=None,
        n_blocks=1,
        unit_type="conv",
        mp_type="generic",
        num_graph_layers=1,
        conv=True,
        ifgate=False,
        cnf_kwargs=None,
        network_choice=None,
        prior_config=None,
        conv_embed_config=None,
    ):
        """Stack graph-ODE CNF blocks, optionally around a squeeze layer.

        The chain is ``init_layer`` (if given), then either
          * ``n_blocks`` CNFs at ``initial_size``, ``layers.SqueezeLayer(2)``,
            and ``n_blocks`` CNFs at ``(c * 4, h // 2, w // 2)`` when
            ``squeeze`` is true, or
          * ``n_blocks`` CNFs at ``initial_size`` otherwise.
        The chain and ``prior_config`` are passed to the parent constructor.

        Args:
            initial_size: Input size; unpacked as ``c, h, w`` when ``squeeze``
                is true.
            idims: Hidden dims passed to ``network_choice``.
            strides: Passed to ``network_choice`` unchanged.
            nonlinearity: Nonlinearity name passed to ``network_choice``.
            squeeze: Whether to insert ``layers.SqueezeLayer(2)`` between two
                stages of CNFs.
            init_layer: Optional layer prepended to the chain.
            n_blocks: Number of CNFs per stage.
            unit_type, mp_type: Select how each ODE function is built.
                Supported pairs are ("conv", "generic"), ("linear",
                "generic"), ("ae", "generic") and ("conv", "affine").
            num_graph_layers: Forwarded only for the ("ae", "generic") unit.
            conv: Stored as ``self.ifconv`` and passed as the fourth
                positional argument to ``network_choice``.
            ifgate: Forwarded only for the ("conv", "generic") unit.
            cnf_kwargs: Extra keyword arguments for every ``layers.CNF``.
                ``None`` (the default) means no extras.
            network_choice: Callable that builds the ODE network. It is
                required whenever at least one CNF block is built.
            prior_config: Stored and forwarded to the parent constructor.
            conv_embed_config: Forwarded only for the ("conv", "generic")
                unit.

        Raises:
            TypeError: ``network_choice`` is None when a block must be built.
            ValueError: ``(unit_type, mp_type)`` is not a supported pair.
                This is raised only when a block is actually built.
        """
        # A None sentinel avoids a mutable default shared across calls.
        cnf_kwargs = {} if cnf_kwargs is None else cnf_kwargs

        self.unit_type = unit_type
        self.ifconv = conv
        self.prior_config = prior_config
        self.conv_embed_config = conv_embed_config
        self.num_graph_layers = num_graph_layers
        self.mp_type = mp_type
        self.ifgate = ifgate
        chain = []
        if init_layer is not None:
            chain.append(init_layer)

        def _make_odefunc(size, network_choice):
            # Validate lazily so configs that build no blocks
            # (n_blocks == 0) behave as before.
            if network_choice is None:
                raise TypeError(
                    "network_choice must be a callable that builds the ODE "
                    "network, got None")

            if self.unit_type == "conv" and self.mp_type == "generic":
                net = network_choice(idims,
                                     size,
                                     strides,
                                     self.ifconv,
                                     layer_type="concat",
                                     nonlinearity=nonlinearity,
                                     ifgate=self.ifgate,
                                     num_func=5,
                                     conv_embed_config=conv_embed_config)
                f = layers.ODEfunc(net)

            elif self.unit_type == "linear" and self.mp_type == "generic":
                net = network_choice(idims,
                                     size,
                                     strides,
                                     self.ifconv,
                                     layer_type="concat",
                                     nonlinearity=nonlinearity,
                                     num_func=5,
                                     reshape=True)
                f = layers.ODEfunc(net)

            elif self.unit_type == "ae" and self.mp_type == "generic":
                net = network_choice(idims,
                                     size,
                                     strides,
                                     self.ifconv,
                                     layer_type="concat",
                                     nonlinearity=nonlinearity,
                                     num_graph_layers=self.num_graph_layers)
                f = layers.GraphAutoencoderODEfunc(net)

            elif self.unit_type == "conv" and self.mp_type == "affine":
                net = network_choice(idims,
                                     size,
                                     strides,
                                     self.ifconv,
                                     layer_type="concat",
                                     nonlinearity=nonlinearity,
                                     reshape=True)
                f = layers.ODEAffineGraphfunc(net)

            else:
                # Previously this fell through and `return f` raised an
                # opaque UnboundLocalError.
                raise ValueError(
                    "Unsupported (unit_type, mp_type) combination: "
                    "({!r}, {!r})".format(self.unit_type, self.mp_type))

            return f

        if squeeze:
            c, h, w = initial_size
            after_squeeze_size = c * 4, h // 2, w // 2
            pre = [
                layers.CNF(_make_odefunc(initial_size, network_choice),
                           **cnf_kwargs) for _ in range(n_blocks)
            ]
            post = [
                layers.CNF(_make_odefunc(after_squeeze_size, network_choice),
                           **cnf_kwargs) for _ in range(n_blocks)
            ]
            chain += pre + [layers.SqueezeLayer(2)] + post
        else:
            chain += [
                layers.CNF(_make_odefunc(initial_size, network_choice),
                           **cnf_kwargs) for _ in range(n_blocks)
            ]

        super(GraphStackedCNFLayers, self).__init__(chain, prior_config)