def create_cnf_model(args, data_shape, regularization_fns):
    hidden_dims = tuple(map(int, args.dims.split(",")))
    strides = tuple(map(int, args.strides.split(",")))

    def build_cnf():
        diffeq = layers.ODEnet(
            hidden_dims=hidden_dims,
            input_shape=data_shape,
            strides=strides,
            conv=args.conv,
            layer_type=args.layer_type,
            nonlinearity=args.nonlinearity,
        )
        odefunc = layers.ODEfunc(
            diffeq=diffeq,
            divergence_fn=args.divergence_fn,
            residual=args.residual,
            rademacher=args.rademacher,
        )
        cnf = layers.CNF(
            odefunc=odefunc,
            T=args.time_length,
            train_T=args.train_T,
            regularization_fns=regularization_fns,
            solver=args.solver,
        )
        return cnf

    # Chain: input squashing (logit or zero-mean) -> CNF blocks -> optional moving batch norm.
    chain = [layers.LogitTransform(alpha=args.cnf_alpha)] if args.cnf_alpha > 0 else [layers.ZeroMeanTransform()]
    chain = chain + [build_cnf() for _ in range(args.num_blocks)]
    if args.batch_norm:
        chain.append(layers.MovingBatchNorm2d(data_shape[0]))
    model = layers.SequentialFlow(chain)
    return model
def _build_net(self, input_size):
    _, c, h, w = input_size
    transforms = []
    transforms.append(
        ParallelCNFLayers(
            initial_size=(c, h, w),
            idims=self.intermediate_dims,
            init_layer=(layers.LogitTransform(self.alpha) if self.alpha > 0 else layers.ZeroMeanTransform()),
            n_blocks=self.n_blocks,
            time_length=self.time_length,
        )
    )
    return nn.ModuleList(transforms)
def _build_net(self, input_size):
    _, c, h, w = input_size
    transforms = []
    for i in range(self.n_scale):
        transforms.append(
            StackedCNFLayers(
                initial_size=(c, h, w),
                idims=self.intermediate_dims,
                squeeze=(i < self.n_scale - 1),  # don't squeeze last layer
                init_layer=(
                    layers.LogitTransform(self.alpha) if self.alpha > 0 else layers.ZeroMeanTransform()
                ) if self.squash_input and i == 0 else None,
                n_blocks=self.n_blocks,
                cnf_kwargs=self.cnf_kwargs,
                nonlinearity=self.nonlinearity,
            )
        )
        c, h, w = c * 2, h // 2, w // 2
    return nn.ModuleList(transforms)
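# --- Shape-tracing sketch (not part of the original file) -------------------
# Hypothetical helper illustrating how the per-scale input shape evolves in
# _build_net above: every scale except the last squeezes spatially, and the
# loop applies the same update rule c, h, w = c * 2, h // 2, w // 2.
def _trace_scale_shapes(c, h, w, n_scale):
    shapes = []
    for _ in range(n_scale):
        shapes.append((c, h, w))
        c, h, w = c * 2, h // 2, w // 2
    return shapes

# e.g. a 3x32x32 image with n_scale=3:
# _trace_scale_shapes(3, 32, 32, 3) -> [(3, 32, 32), (6, 16, 16), (12, 8, 8)]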
def create_model(args, data_shape, regularization_fns):
    hidden_dims = tuple(map(int, args.dims.split(",")))
    strides = tuple(map(int, args.strides.split(",")))

    if args.multiscale:
        model = odenvp.ODENVP(
            (args.batch_size, *data_shape),
            n_blocks=args.num_blocks,
            intermediate_dims=hidden_dims,
            nonlinearity=args.nonlinearity,
            alpha=args.alpha,
            cnf_kwargs={"T": args.time_length, "train_T": args.train_T, "regularization_fns": regularization_fns},
        )
    elif args.parallel:
        model = multiscale_parallel.MultiscaleParallelCNF(
            (args.batch_size, *data_shape),
            n_blocks=args.num_blocks,
            intermediate_dims=hidden_dims,
            alpha=args.alpha,
            time_length=args.time_length,
        )
    else:
        if args.autoencode:

            def build_cnf():
                autoencoder_diffeq = layers.AutoencoderDiffEqNet(
                    hidden_dims=hidden_dims,
                    input_shape=data_shape,
                    strides=strides,
                    conv=args.conv,
                    layer_type=args.layer_type,
                    nonlinearity=args.nonlinearity,
                )
                odefunc = layers.AutoencoderODEfunc(
                    autoencoder_diffeq=autoencoder_diffeq,
                    divergence_fn=args.divergence_fn,
                    residual=args.residual,
                    rademacher=args.rademacher,
                )
                cnf = layers.CNF(
                    odefunc=odefunc,
                    T=args.time_length,
                    regularization_fns=regularization_fns,
                    solver=args.solver,
                )
                return cnf

        else:

            def build_cnf():
                diffeq = layers.ODEnet(
                    hidden_dims=hidden_dims,
                    input_shape=data_shape,
                    strides=strides,
                    conv=args.conv,
                    layer_type=args.layer_type,
                    nonlinearity=args.nonlinearity,
                )
                odefunc = layers.ODEfunc(
                    diffeq=diffeq,
                    divergence_fn=args.divergence_fn,
                    residual=args.residual,
                    rademacher=args.rademacher,
                )
                cnf = layers.CNF(
                    odefunc=odefunc,
                    T=args.time_length,
                    train_T=args.train_T,
                    regularization_fns=regularization_fns,
                    solver=args.solver,
                )
                return cnf

        # Single-scale flow: input squashing, a chain of CNF blocks, then an
        # optional moving batch norm over the channel dimension.
        chain = [layers.LogitTransform(alpha=args.alpha)] if args.alpha > 0 else [layers.ZeroMeanTransform()]
        chain = chain + [build_cnf() for _ in range(args.num_blocks)]
        if args.batch_norm:
            chain.append(layers.MovingBatchNorm2d(data_shape[0]))
        model = layers.SequentialFlow(chain)
    return model
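# --- Usage sketch (not part of the original file) ---------------------------
# A minimal, hedged example of how create_model above might be invoked. The
# field names mirror the attributes create_model reads from `args`; the
# values are illustrative, not the repo's defaults, and the FFJORD-style
# `layers`, `odenvp`, and `multiscale_parallel` modules are assumed to be
# importable in this scope.
from argparse import Namespace

example_args = Namespace(
    dims="64,64,64",            # hidden widths of the ODE net
    strides="1,1,1,1",          # one stride per hidden layer plus the output layer
    multiscale=False, parallel=False, autoencode=False,
    conv=True, layer_type="concat", nonlinearity="softplus",
    divergence_fn="approximate", residual=False, rademacher=False,
    alpha=1e-6,                 # logit-transform alpha (<= 0 selects ZeroMeanTransform)
    time_length=1.0, train_T=True, solver="dopri5",
    num_blocks=2, batch_norm=True, batch_size=64,
)
data_shape = (3, 32, 32)        # (C, H, W), e.g. CIFAR-10-sized images
# regularization_fns=None assumes the CNF layer accepts no regularizers;
# otherwise pass whatever regularizer tuple the training script builds.
model = create_model(example_args, data_shape, regularization_fns=None)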