コード例 #1
0
ファイル: cnf.py プロジェクト: klauscc/MixMatch-pytorch
def create_cnf_model(args, data_shape, regularization_fns):
    """Assemble a single-scale continuous normalizing flow from ``args``.

    The resulting flow is, in order:

    1. a preprocessing layer: ``LogitTransform(alpha=args.cnf_alpha)`` when
       ``args.cnf_alpha > 0``, otherwise ``ZeroMeanTransform()``;
    2. ``args.num_blocks`` independently constructed ``layers.CNF`` blocks;
    3. a ``MovingBatchNorm2d(data_shape[0])`` when ``args.batch_norm`` is set.

    Args:
        args: Namespace read for ``dims`` and ``strides`` (comma-separated
            integer lists), ``conv``, ``layer_type``, ``nonlinearity``,
            ``divergence_fn``, ``residual``, ``rademacher``, ``time_length``,
            ``train_T``, ``solver``, ``cnf_alpha``, ``num_blocks`` and
            ``batch_norm``.
        data_shape: Input shape given to every ``ODEnet``; its first entry is
            the feature count of the optional batch-norm layer.
        regularization_fns: Passed unchanged to every ``layers.CNF``.

    Returns:
        A ``layers.SequentialFlow`` over the assembled layers.
    """
    def parse_ints(csv_text):
        # "64,64,64" -> (64, 64, 64)
        return tuple(int(token) for token in csv_text.split(","))

    hidden_dims = parse_ints(args.dims)
    strides = parse_ints(args.strides)

    def make_block():
        # Dynamics network -> ODE function (adds divergence estimation) -> CNF.
        velocity_net = layers.ODEnet(
            hidden_dims=hidden_dims,
            input_shape=data_shape,
            strides=strides,
            conv=args.conv,
            layer_type=args.layer_type,
            nonlinearity=args.nonlinearity,
        )
        ode_function = layers.ODEfunc(
            diffeq=velocity_net,
            divergence_fn=args.divergence_fn,
            residual=args.residual,
            rademacher=args.rademacher,
        )
        return layers.CNF(
            odefunc=ode_function,
            T=args.time_length,
            train_T=args.train_T,
            regularization_fns=regularization_fns,
            solver=args.solver,
        )

    if args.cnf_alpha > 0:
        preprocess = layers.LogitTransform(alpha=args.cnf_alpha)
    else:
        preprocess = layers.ZeroMeanTransform()

    flow_layers = [preprocess]
    flow_layers.extend(make_block() for _ in range(args.num_blocks))
    if args.batch_norm:
        flow_layers.append(layers.MovingBatchNorm2d(data_shape[0]))
    return layers.SequentialFlow(flow_layers)
コード例 #2
0
 def _build_net(self, input_size):
     """Build the model's single ``ParallelCNFLayers`` transform.

     ``input_size`` is unpacked as ``(_, c, h, w)``; the leading entry is
     ignored and ``(c, h, w)`` becomes the stage's ``initial_size``. The
     stage's ``init_layer`` is ``LogitTransform(self.alpha)`` when
     ``self.alpha > 0``, otherwise ``ZeroMeanTransform()``.

     Returns:
         An ``nn.ModuleList`` holding exactly that one transform.
     """
     _, channels, height, width = input_size
     if self.alpha > 0:
         first_layer = layers.LogitTransform(self.alpha)
     else:
         first_layer = layers.ZeroMeanTransform()
     parallel_stage = ParallelCNFLayers(
         initial_size=(channels, height, width),
         idims=self.intermediate_dims,
         init_layer=first_layer,
         n_blocks=self.n_blocks,
         time_length=self.time_length,
     )
     return nn.ModuleList([parallel_stage])
コード例 #3
0
ファイル: odenvp.py プロジェクト: plai-group/daphne
 def _build_net(self, input_size):
     """Build one ``StackedCNFLayers`` stage per scale.

     ``input_size`` is unpacked as ``(_, c, h, w)``; the leading entry is
     ignored. Every stage except the last is created with ``squeeze=True``.
     Only the first stage, and only when ``self.squash_input`` is truthy,
     receives a preprocessing ``init_layer`` (``LogitTransform(self.alpha)``
     if ``self.alpha > 0``, else ``ZeroMeanTransform()``); all others get
     ``None``. After each stage the tracked shape becomes
     ``(2 * c, h // 2, w // 2)`` for the next one.

     Returns:
         An ``nn.ModuleList`` of ``self.n_scale`` stages, in order.
     """
     _, channels, height, width = input_size
     stages = []
     for scale in range(self.n_scale):
         if self.squash_input and scale == 0:
             if self.alpha > 0:
                 init_layer = layers.LogitTransform(self.alpha)
             else:
                 init_layer = layers.ZeroMeanTransform()
         else:
             init_layer = None
         stage = StackedCNFLayers(
             initial_size=(channels, height, width),
             idims=self.intermediate_dims,
             squeeze=scale < self.n_scale - 1,  # the last stage is not squeezed
             init_layer=init_layer,
             n_blocks=self.n_blocks,
             cnf_kwargs=self.cnf_kwargs,
             nonlinearity=self.nonlinearity,
         )
         stages.append(stage)
         # Input shape for the following stage.
         channels, height, width = channels * 2, height // 2, width // 2
     return nn.ModuleList(stages)
コード例 #4
0
        # Remove the logit transform.
        init_layer = layers.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    else:
        transform_train = transforms.Compose([
            transforms.Resize(args.imagesize),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            add_noise,
        ])
        transform_test = transforms.Compose([
            transforms.Resize(args.imagesize),
            transforms.ToTensor(),
            add_noise,
        ])
        init_layer = layers.LogitTransform(0.05)
    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(args.dataroot, train=True, transform=transform_train),
        batch_size=args.batchsize,
        shuffle=True,
        num_workers=args.nworkers,
    )
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(args.dataroot, train=False, transform=transform_test),
        batch_size=args.val_batchsize,
        shuffle=False,
        num_workers=args.nworkers,
    )
elif args.data == 'mnist':
    im_dim = 1
    init_layer = layers.LogitTransform(1e-6)
コード例 #5
0
def create_model(args, data_shape, regularization_fns):
    """Build the flow model described by ``args``.

    The layout is chosen in this priority order:

    * ``args.multiscale`` -> ``odenvp.ODENVP``; ``T``, ``train_T`` and
      ``regularization_fns`` reach each CNF through ``cnf_kwargs``.
    * ``args.parallel`` -> ``multiscale_parallel.MultiscaleParallelCNF``
      (``regularization_fns`` and ``train_T`` are not forwarded here).
    * otherwise -> ``layers.SequentialFlow`` made of a preprocessing layer
      (``LogitTransform(alpha=args.alpha)`` if ``args.alpha > 0``, else
      ``ZeroMeanTransform()``), ``args.num_blocks`` CNF blocks, and a trailing
      ``MovingBatchNorm2d(data_shape[0])`` when ``args.batch_norm`` is set.
      With ``args.autoencode`` each block's dynamics use
      ``AutoencoderDiffEqNet``/``AutoencoderODEfunc`` instead of
      ``ODEnet``/``ODEfunc``.

    Args:
        args: Parsed options namespace. ``args.dims`` and ``args.strides`` are
            comma-separated integer lists.
        data_shape: Shape of one example; ``args.batch_size`` is prepended to
            it for the multiscale and parallel models.
        regularization_fns: Forwarded to every CNF in the multiscale and
            single-scale layouts.

    Returns:
        The constructed model.
    """
    hidden_dims = tuple(map(int, args.dims.split(",")))
    strides = tuple(map(int, args.strides.split(",")))

    if args.multiscale:
        # NOTE(review): no solver is passed via cnf_kwargs; presumably it is
        # configured on the CNF modules elsewhere -- confirm.
        return odenvp.ODENVP(
            (args.batch_size, *data_shape),
            n_blocks=args.num_blocks,
            intermediate_dims=hidden_dims,
            nonlinearity=args.nonlinearity,
            alpha=args.alpha,
            cnf_kwargs={
                "T": args.time_length,
                "train_T": args.train_T,
                "regularization_fns": regularization_fns
            },
        )

    if args.parallel:
        return multiscale_parallel.MultiscaleParallelCNF(
            (args.batch_size, *data_shape),
            n_blocks=args.num_blocks,
            intermediate_dims=hidden_dims,
            alpha=args.alpha,
            time_length=args.time_length,
        )

    use_autoencoder = args.autoencode

    def build_cnf():
        """Create one CNF block; its dynamics network depends on ``use_autoencoder``."""
        net_kwargs = dict(
            hidden_dims=hidden_dims,
            input_shape=data_shape,
            strides=strides,
            conv=args.conv,
            layer_type=args.layer_type,
            nonlinearity=args.nonlinearity,
        )
        odefunc_kwargs = dict(
            divergence_fn=args.divergence_fn,
            residual=args.residual,
            rademacher=args.rademacher,
        )
        if use_autoencoder:
            odefunc = layers.AutoencoderODEfunc(
                autoencoder_diffeq=layers.AutoencoderDiffEqNet(**net_kwargs),
                **odefunc_kwargs
            )
        else:
            odefunc = layers.ODEfunc(
                diffeq=layers.ODEnet(**net_kwargs),
                **odefunc_kwargs
            )
        return layers.CNF(
            odefunc=odefunc,
            T=args.time_length,
            # Forwarded on both dynamics variants so args.train_T takes effect
            # with autoencode too, matching the multiscale layout above.
            train_T=args.train_T,
            regularization_fns=regularization_fns,
            solver=args.solver,
        )

    if args.alpha > 0:
        chain = [layers.LogitTransform(alpha=args.alpha)]
    else:
        chain = [layers.ZeroMeanTransform()]
    chain.extend(build_cnf() for _ in range(args.num_blocks))
    if args.batch_norm:
        chain.append(layers.MovingBatchNorm2d(data_shape[0]))
    return layers.SequentialFlow(chain)
コード例 #6
0
def build_model(args, state_dict):
    """Rebuild a single-scale CNF flow from ``args`` and load ``state_dict`` into it.

    The chain is a preprocessing layer (``LogitTransform(alpha=args.alpha)``
    if ``args.alpha > 0``, else ``ZeroMeanTransform()``), then
    ``args.num_blocks`` CNF blocks (1 if ``args`` has no ``num_blocks``), then
    an optional ``MovingBatchNorm2d(data_shape[0])`` when ``args.batch_norm``.
    Spectral norm is applied via ``add_spectral_norm`` before loading when
    ``args.spectral_norm`` is set, so the parameter names line up.

    Args:
        args: Options namespace; also passed to ``get_dataset``. With
            ``args.autoencode`` each block uses ``AutoencoderDiffEqNet`` /
            ``AutoencoderODEfunc`` instead of ``ODEnet`` / ``ODEfunc``.
        state_dict: Weights passed to ``model.load_state_dict``.

    Returns:
        ``(model, test_dataset)`` where ``test_dataset`` is the ``dataset``
        attribute of the test loader returned by ``get_dataset(args)``.
    """
    # load dataset (only the test split and the data shape are used here)
    _, test_loader, data_shape = get_dataset(args)

    hidden_dims = tuple(map(int, args.dims.split(",")))
    strides = tuple(map(int, args.strides.split(",")))

    use_autoencoder = args.autoencode

    # neural net that parameterizes the velocity field
    def build_cnf():
        net_kwargs = dict(
            hidden_dims=hidden_dims,
            input_shape=data_shape,
            strides=strides,
            conv=args.conv,
            layer_type=args.layer_type,
            nonlinearity=args.nonlinearity,
        )
        odefunc_kwargs = dict(
            divergence_fn=args.divergence_fn,
            residual=args.residual,
            rademacher=args.rademacher,
        )
        if use_autoencoder:
            odefunc = layers.AutoencoderODEfunc(
                autoencoder_diffeq=layers.AutoencoderDiffEqNet(**net_kwargs),
                **odefunc_kwargs
            )
        else:
            odefunc = layers.ODEfunc(
                diffeq=layers.ODEnet(**net_kwargs),
                **odefunc_kwargs
            )
        return layers.CNF(
            odefunc=odefunc,
            T=args.time_length,
            solver=args.solver,
        )

    # Mirror the single-scale layout of create_model so the rebuilt model
    # applies the same preprocessing and has the same number of CNF blocks.
    # NOTE(review): assumes the checkpoint was trained with that layout --
    # confirm. Older saved args may lack num_blocks, hence the default of 1.
    if args.alpha > 0:
        chain = [layers.LogitTransform(alpha=args.alpha)]
    else:
        chain = [layers.ZeroMeanTransform()]
    chain.extend(build_cnf() for _ in range(getattr(args, "num_blocks", 1)))
    if args.batch_norm:
        chain.append(layers.MovingBatchNorm2d(data_shape[0]))
    model = layers.SequentialFlow(chain)

    # Must happen before loading so the spectral-norm parameters exist.
    if args.spectral_norm:
        add_spectral_norm(model)

    model.load_state_dict(state_dict)

    return model, test_loader.dataset