def build_cnf():
    """Assemble a CNF whose dynamics network is an ``ODEGatedGraphnet``.

    Three objects from ``layers`` are stacked: the gated graph network (the
    diffeq), an ``ODEfunc`` wrapping it, and a ``CNF`` wrapping that.
    Nothing is passed in -- every setting is read from enclosing-scope names
    ``args``, ``hidden_dims``, ``data_shape``, ``strides`` and
    ``regularization_fns``.

    Returns:
        The constructed ``layers.CNF`` instance.
    """
    # Dynamics network; architecture, gating and node-autoencoder switches
    # all come straight from the parsed ``args``.
    gated_net = layers.ODEGatedGraphnet(
        hidden_dims=hidden_dims,
        input_shape=data_shape,
        strides=strides,
        conv=args.conv,
        layer_type=args.layer_type,
        nonlinearity=args.nonlinearity,
        ifgate=args.if_graph_gate,
        node_autoencode=args.node_autoencode,
        num_func=args.num_func,
        num_layers=args.num_layers,
    )

    # Wrap the network with the divergence / residual / rademacher options.
    ode_func = layers.ODEfunc(
        diffeq=gated_net,
        divergence_fn=args.divergence_fn,
        residual=args.residual,
        rademacher=args.rademacher,
    )

    # Integration horizon and solver are also taken from ``args``.
    return layers.CNF(
        odefunc=ode_func,
        T=args.time_length,
        regularization_fns=regularization_fns,
        solver=args.solver,
    )
 def build_cnf():
     """Assemble a CNF whose dynamics network is an ``ODEGraphnet``.

     The graph network is always built with ``conv=False`` and receives an
     embedding configuration list. It is wrapped in ``layers.ODEfunc`` and
     then ``layers.CNF``. Reads enclosing-scope names ``args``,
     ``hidden_dims``, ``data_shape``, ``strides``, ``vocab_size`` and
     ``regularization_fns``.

     Returns:
         The constructed ``layers.CNF`` instance.
     """
     # 4-item list handed to ODEGraphnet as ``embed_config``:
     # [args.embed_dim > 0, args.num_func, vocab_size, args.embed_dim].
     # The first item is presumably an "embedding enabled" flag -- confirm
     # against ODEGraphnet.
     embedding = [
         args.embed_dim > 0,
         args.num_func,
         vocab_size,
         args.embed_dim,
     ]
     graph_net = layers.ODEGraphnet(
         hidden_dims=hidden_dims,
         input_shape=data_shape,
         strides=strides,
         conv=False,  # fixed here; args.conv is not consulted
         layer_type=args.layer_type,
         nonlinearity=args.nonlinearity,
         ifgate=args.if_graph_gate,
         num_func=args.num_func,
         embed_config=embedding,
     )

     ode_func = layers.ODEfunc(
         diffeq=graph_net,
         divergence_fn=args.divergence_fn,
         residual=args.residual,
         rademacher=args.rademacher,
     )

     return layers.CNF(
         odefunc=ode_func,
         T=args.time_length,
         regularization_fns=regularization_fns,
         solver=args.solver,
     )
Exemplo n.º 3
0
        def _make_odefunc(size, network_choice):
            """Instantiate ``network_choice`` and wrap it in an ODE function.

            The pair (``self.unit_type``, ``self.mp_type``) selects both the
            arguments passed to ``network_choice`` and the ``layers`` wrapper
            applied to the resulting network:

            * ``("conv", "generic")``   -> ``layers.ODEfunc``
            * ``("linear", "generic")`` -> ``layers.ODEfunc``
            * ``("ae", "generic")``     -> ``layers.GraphAutoencoderODEfunc``
            * ``("conv", "affine")``    -> ``layers.ODEAffineGraphfunc``

            Relies on enclosing-scope names ``self``, ``idims``, ``strides``,
            ``nonlinearity``, ``conv_embed_config`` and ``layers``.

            Args:
                size: Second positional argument of ``network_choice``
                    (presumably the input shape of this stage -- confirm).
                network_choice: Network class or factory to instantiate.

            Returns:
                The wrapped ODE function.

            Raises:
                ValueError: If the (unit_type, mp_type) pair is not one of the
                    four combinations listed above.
            """
            unit_type, mp_type = self.unit_type, self.mp_type

            if unit_type == "conv" and mp_type == "generic":
                # Only this branch forwards the gate flag and conv embedding.
                net = network_choice(idims,
                                     size,
                                     strides,
                                     self.ifconv,
                                     layer_type="concat",
                                     nonlinearity=nonlinearity,
                                     ifgate=self.ifgate,
                                     num_func=5,
                                     conv_embed_config=conv_embed_config)
                return layers.ODEfunc(net)

            if unit_type == "linear" and mp_type == "generic":
                net = network_choice(idims,
                                     size,
                                     strides,
                                     self.ifconv,
                                     layer_type="concat",
                                     nonlinearity=nonlinearity,
                                     num_func=5,
                                     reshape=True)
                return layers.ODEfunc(net)

            if unit_type == "ae" and mp_type == "generic":
                net = network_choice(idims,
                                     size,
                                     strides,
                                     self.ifconv,
                                     layer_type="concat",
                                     nonlinearity=nonlinearity,
                                     num_graph_layers=self.num_graph_layers)
                return layers.GraphAutoencoderODEfunc(net)

            if unit_type == "conv" and mp_type == "affine":
                net = network_choice(idims,
                                     size,
                                     strides,
                                     self.ifconv,
                                     layer_type="concat",
                                     nonlinearity=nonlinearity,
                                     reshape=True)
                return layers.ODEAffineGraphfunc(net)

            # Without this, an unhandled pair fell through to ``return f`` with
            # ``f`` never assigned and surfaced as an opaque UnboundLocalError.
            raise ValueError(
                "unsupported combination unit_type=%r, mp_type=%r"
                % (unit_type, mp_type))
Exemplo n.º 4
0
 def _make_odefunc(size):
     """Build an ``ODEnet`` for ``size`` and wrap it in ``layers.ODEfunc``.

     The network always uses the "concat" layer type and is given ``True``
     as its fourth positional argument. That argument is presumably ODEnet's
     ``conv`` flag -- confirm. Reads enclosing-scope names ``idims``,
     ``strides`` and ``nonlinearity``.
     """
     network = ODEnet(
         idims, size, strides, True,
         layer_type="concat",
         nonlinearity=nonlinearity,
     )
     return layers.ODEfunc(network)
 def build_cnf():
     """Assemble a CNF whose dynamics network is a plain ``layers.ODEnet``.

     Unlike the graph-network variants, this CNF also receives
     ``train_T=args.train_T``. Reads enclosing-scope names ``args``,
     ``hidden_dims``, ``data_shape``, ``strides`` and
     ``regularization_fns``.

     Returns:
         The constructed ``layers.CNF`` instance.
     """
     dynamics = layers.ODEnet(
         hidden_dims=hidden_dims,
         input_shape=data_shape,
         strides=strides,
         conv=args.conv,
         layer_type=args.layer_type,
         nonlinearity=args.nonlinearity,
     )

     ode_func = layers.ODEfunc(
         diffeq=dynamics,
         divergence_fn=args.divergence_fn,
         residual=args.residual,
         rademacher=args.rademacher,
     )

     return layers.CNF(
         odefunc=ode_func,
         T=args.time_length,
         train_T=args.train_T,
         regularization_fns=regularization_fns,
         solver=args.solver,
     )
Exemplo n.º 6
0
 def build_cnf():
     """Assemble a CNF whose dynamics network is an ``ODEGraphnetGraphGen``.

     The graph network is built with fixed settings ``conv=False``,
     ``num_squeeze=0`` and ``ifgate=False``. The remaining options come from
     ``args``. Reads enclosing-scope names ``args``, ``hidden_dims``,
     ``data_shape``, ``strides`` and ``regularization_fns``.

     Returns:
         The constructed ``layers.CNF`` instance.
     """
     generator_net = layers.ODEGraphnetGraphGen(
         hidden_dims=hidden_dims,
         input_shape=data_shape,
         strides=strides,
         conv=False,  # hard-coded, not taken from args
         layer_type=args.layer_type,
         nonlinearity=args.nonlinearity,
         num_squeeze=0,
         ifgate=False,  # hard-coded, not taken from args
     )

     ode_func = layers.ODEfunc(
         diffeq=generator_net,
         divergence_fn=args.divergence_fn,
         residual=args.residual,
         rademacher=args.rademacher,
     )

     return layers.CNF(
         odefunc=ode_func,
         T=args.time_length,
         regularization_fns=regularization_fns,
         solver=args.solver,
     )