示例#1
0
    def build_model(self) -> nn.Module:
        """Build the convolutional classifier and return it freshly initialised.

        The network is two convolutional stages (two 3x3 convolutions, max
        pooling and channel-wise dropout each) followed by a two-layer fully
        connected head that ends in a softmax over the class scores.

        Dropout rates are read from the ``layer1_dropout``, ``layer2_dropout``
        and ``layer3_dropout`` hyperparameters.

        Returns:
            The assembled ``nn.Sequential`` after ``reset_parameters`` has been
            applied to it.
        """
        model = nn.Sequential(
            # 32 output channels: this must match the in_channels of the next
            # convolution, so it is spelled out rather than borrowed from the
            # unrelated IMAGE_SIZE constant.
            nn.Conv2d(NUM_CHANNELS, 32, kernel_size=(3, 3)),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=(3, 3)),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),
            nn.Dropout2d(pedl.get_hyperparameter("layer1_dropout")),
            nn.Conv2d(32, 64, (3, 3), padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, (3, 3)),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),
            nn.Dropout2d(pedl.get_hyperparameter("layer2_dropout")),
            Flatten(),
            # 2304 = 64 channels * 6 * 6, which presumably corresponds to
            # 32x32 input images -- TODO confirm against the data loader.
            nn.Linear(2304, 512),
            nn.ReLU(),
            # Plain Dropout: the activations are 2-D (batch, features) here,
            # and Dropout2d on 2-D input is deprecated in PyTorch.
            nn.Dropout(pedl.get_hyperparameter("layer3_dropout")),
            nn.Linear(512, NUM_CLASSES),
            # Normalise over the class axis (dim=1), not across the batch.
            nn.Softmax(dim=1),
        )

        # If loading backbone weights, do not call reset_parameters() or
        # call before loading the backbone weights.
        reset_parameters(model)
        return model
示例#2
0
    def build_model(self) -> nn.Module:
        """Create a ``MultiNet``, apply ``reset_parameters`` to it and return it."""
        net = MultiNet()

        # When backbone weights are going to be loaded, either skip
        # reset_parameters() or make sure it runs before they are loaded.
        reset_parameters(net)
        return net
示例#3
0
    def build_model(self) -> nn.Module:
        """Build the two-convolution classifier and return it freshly initialised.

        The network is two 5x5 convolutions (each followed by 2x2 max pooling
        and ReLU), then a 50-unit hidden layer with dropout and a 10-way
        log-softmax output.

        The filter counts come from the ``n_filters1`` / ``n_filters2``
        hyperparameters and the dropout rate from ``dropout``.

        Returns:
            The assembled ``nn.Sequential`` after ``reset_parameters`` has been
            applied to it.
        """
        # Look each filter count up once; n_filters1 is needed by two layers.
        n_filters1 = pedl.get_hyperparameter("n_filters1")
        n_filters2 = pedl.get_hyperparameter("n_filters2")

        model = nn.Sequential(
            nn.Conv2d(1, n_filters1, kernel_size=5),
            nn.MaxPool2d(2),
            nn.ReLU(),
            nn.Conv2d(n_filters1, n_filters2, kernel_size=5),
            nn.MaxPool2d(2),
            nn.ReLU(),
            Flatten(),
            # 16 spatial positions per filter -- presumably a 4x4 feature map
            # from 28x28 inputs; TODO confirm against the data loader.
            nn.Linear(16 * n_filters2, 50),
            nn.ReLU(),
            # Plain Dropout: the activations are 2-D (batch, features) here,
            # and Dropout2d on 2-D input is deprecated in PyTorch.
            nn.Dropout(pedl.get_hyperparameter("dropout")),
            nn.Linear(50, 10),
            # Explicit class axis; the implicit-dim form is deprecated and
            # resolves to dim=1 for 2-D input anyway.
            nn.LogSoftmax(dim=1),
        )

        # If loading backbone weights, do not call reset_parameters() or
        # call before loading the backbone weights.
        reset_parameters(model)
        return model