Beispiel #1
0
    def __init__(self, config={}, **kwargs):
        """Initialise the representation and declare its output keys.

        Args:
            config: Configuration forwarded unchanged to
                ``TorchNNRepresentation.__init__``.
            **kwargs: Extra keyword arguments forwarded the same way.
        """
        # Original author's note: calls all constructors up to BaseDNN (MRO).
        TorchNNRepresentation.__init__(self, config=config, **kwargs)

        # The outputs are the encoder's keys followed by "recon_x".
        encoder_keys = self.network.encoder.output_keys_list
        self.output_keys_list = encoder_keys + ["recon_x"]
Beispiel #2
0
 def set_network(self, network_name, network_parameters):
     """Build the base network, then attach a decoder and a discriminator.

     Args:
         network_name: Name used to look up the decoder and discriminator
             classes (and forwarded to the base implementation).
         network_parameters: Configuration handed to the base
             implementation and to both added modules as ``config``.
     """
     TorchNNRepresentation.set_network(
         self, network_name, network_parameters)

     # BiGAN needs a decoder on top of the base network ...
     self.network.decoder = decoders.get_decoder(network_name)(
         config=network_parameters)

     # ... as well as a discriminator.
     self.network.discriminator = discriminators.get_discriminator(
         network_name)(config=network_parameters)
Beispiel #3
0
 def set_logger(self, logger_config):
     """Set up the logger and, if one exists, record the model graph in it.

     Delegates to ``TorchNNRepresentation.set_logger``. When ``self.logger``
     is not ``None`` afterwards, a uniformly random dummy input with a
     leading dimension of 1, the configured number of channels and the
     configured input size is passed to ``self.logger.add_graph``.

     NOTE(review): the model is switched to eval mode here and is not
     switched back afterwards - confirm callers expect that.

     Args:
         logger_config: Forwarded unchanged to the base implementation.
     """
     TorchNNRepresentation.set_logger(self, logger_config)
     if self.logger is None:
         return

     # Dummy input: values drawn uniformly in [0, 1), cast to the
     # configured dtype and moved to the configured device.
     net_params = self.config.network.parameters
     dummy_size = (1, net_params.n_channels) + net_params.input_size
     dummy_input = torch.Tensor(size=dummy_size).uniform_(0, 1)
     dummy_input = dummy_input.type(self.config.dtype).to(self.config.device)

     # Save the graph in the logger without tracking gradients.
     self.eval()
     with torch.no_grad():
         self.logger.add_graph(self, dummy_input, verbose=False)
Beispiel #4
0
    def default_config():
        """Return the default configuration for this representation.

        Starts from ``TorchNNRepresentation.default_config()`` and replaces
        its ``network``, ``loss`` and ``optimizer`` sections: a "Burgess"
        network with "pytorch" weight initialization, a "SimCLR" loss
        (temperature 0.5, cosine distance) and an "Adam" optimizer.
        """
        cfg = TorchNNRepresentation.default_config()

        # network parameters
        cfg.network = Dict()
        network = cfg.network
        network.name = "Burgess"
        network.parameters = Dict()
        net_params = network.parameters
        net_params.n_channels = 1
        net_params.input_size = (64, 64)
        net_params.n_latents = 10
        net_params.n_conv_layers = 4
        net_params.feature_layer = 2
        net_params.encoder_conditional_type = "gaussian"

        # initialization parameters
        network.weights_init = Dict()
        weights_init = network.weights_init
        weights_init.name = "pytorch"
        weights_init.parameters = Dict()

        # loss parameters
        cfg.loss = Dict()
        loss = cfg.loss
        loss.name = "SimCLR"
        loss.parameters = Dict()
        loss_params = loss.parameters
        loss_params.temperature = 0.5
        loss_params.distance = 'cosine'

        # optimizer parameters
        cfg.optimizer = Dict()
        optimizer = cfg.optimizer
        optimizer.name = "Adam"
        optimizer.parameters = Dict()
        optim_params = optimizer.parameters
        optim_params.lr = 1e-3
        optim_params.weight_decay = 1e-5
        return cfg
Beispiel #5
0
    def default_config():
        """Return the default configuration for this representation.

        Starts from ``TorchNNRepresentation.default_config()`` and replaces
        its ``network``, ``loss`` and ``optimizer`` sections: a "Burgess"
        network with "pytorch" initialization, a "Triplet" loss (squared
        Euclidean distance, margin 1.0, attention enabled) and an "Adam"
        optimizer.
        """
        cfg = TorchNNRepresentation.default_config()

        # network parameters
        cfg.network = Dict()
        network = cfg.network
        network.name = "Burgess"
        network.parameters = Dict()
        net_params = network.parameters
        net_params.n_channels = 1
        net_params.input_size = (64, 64)
        net_params.n_latents = 10
        net_params.n_conv_layers = 4
        net_params.feature_layer = 2
        net_params.encoder_conditional_type = "gaussian"

        # initialization parameters
        network.initialization = Dict()
        initialization = network.initialization
        initialization.name = "pytorch"
        initialization.parameters = Dict()

        # loss parameters
        cfg.loss = Dict()
        loss = cfg.loss
        loss.name = "Triplet"
        loss.parameters = Dict()
        loss_params = loss.parameters
        loss_params.distance = "squared_euclidean"
        loss_params.margin = 1.0
        loss_params.use_attention = True

        # optimizer parameters
        cfg.optimizer = Dict()
        optimizer = cfg.optimizer
        optimizer.name = "Adam"
        optimizer.parameters = Dict()
        optim_params = optimizer.parameters
        optim_params.lr = 1e-3
        optim_params.weight_decay = 1e-5
        return cfg
Beispiel #6
0
    def default_config():
        """Return the default configuration for this representation.

        Starts from ``TorchNNRepresentation.default_config()`` and replaces
        its ``network``, ``loss`` and ``optimizer`` sections: a "Dumoulin"
        network with "pytorch" initialization, a "BiGAN" loss with no extra
        parameters and an "Adam" optimizer.
        """
        cfg = TorchNNRepresentation.default_config()

        # network parameters
        cfg.network = Dict()
        network = cfg.network
        network.name = "Dumoulin"
        network.parameters = Dict()
        net_params = network.parameters
        net_params.n_channels = 1
        net_params.input_size = (64, 64)
        net_params.n_latents = 10
        net_params.n_conv_layers = 4
        net_params.feature_layer = 2
        net_params.encoder_conditional_type = "gaussian"

        # initialization parameters
        network.initialization = Dict()
        initialization = network.initialization
        initialization.name = "pytorch"
        initialization.parameters = Dict()

        # loss parameters
        cfg.loss = Dict()
        loss = cfg.loss
        loss.name = "BiGAN"
        loss.parameters = Dict()

        # optimizer parameters
        cfg.optimizer = Dict()
        optimizer = cfg.optimizer
        optimizer.name = "Adam"
        optimizer.parameters = Dict()
        optim_params = optimizer.parameters
        optim_params.lr = 1e-3
        optim_params.weight_decay = 1e-5
        return cfg
Beispiel #7
0
    def default_config():
        """Return the default configuration for this representation.

        Starts from ``TorchNNRepresentation.default_config()`` and replaces
        its ``network``, ``loss`` and ``optimizer`` sections: a "Burgess"
        network with "pytorch" weight initialization, a "VAE" loss with a
        "bernoulli" reconstruction distribution and an "Adam" optimizer.
        """
        cfg = TorchNNRepresentation.default_config()

        # network parameters
        cfg.network = Dict()
        network = cfg.network
        network.name = "Burgess"
        network.parameters = Dict()
        net_params = network.parameters
        net_params.n_channels = 1
        net_params.input_size = (64, 64)
        net_params.n_latents = 10
        net_params.n_conv_layers = 4
        net_params.feature_layer = 2
        net_params.encoder_conditional_type = "gaussian"

        # weights_init parameters
        network.weights_init = Dict()
        weights_init = network.weights_init
        weights_init.name = "pytorch"
        weights_init.parameters = Dict()

        # loss parameters
        cfg.loss = Dict()
        loss = cfg.loss
        loss.name = "VAE"
        loss.parameters = Dict()
        loss.parameters.reconstruction_dist = "bernoulli"

        # optimizer parameters
        cfg.optimizer = Dict()
        optimizer = cfg.optimizer
        optimizer.name = "Adam"
        optimizer.parameters = Dict()
        optim_params = optimizer.parameters
        optim_params.lr = 1e-3
        optim_params.weight_decay = 1e-5
        return cfg
Beispiel #8
0
        def __init__(self, config={}, **kwargs):
            """Construct via ``base_class`` and optionally restore a saved model.

            When ``config.load_pretrained_model`` is truthy and the file at
            ``config.pretrained_model_filepath`` exists, the saved model is
            loaded with ``TorchNNRepresentation.load`` and its weights are
            copied into this instance's network.

            NOTE(review): the flags are read from the raw ``config`` argument
            with attribute access, so the plain-dict default ``{}`` would not
            support it - confirm callers always pass an attribute-style config.

            Args:
                config: Configuration forwarded to ``base_class.__init__`` and
                    inspected for the pretrained-model settings.
                **kwargs: Extra keyword arguments forwarded to the base
                    constructor and merged into the restored config.
            """
            base_class.__init__(self, config=config, **kwargs)

            # Nothing to restore unless requested and the file is present.
            if not (config.load_pretrained_model
                    and os.path.exists(config.pretrained_model_filepath)):
                return

            pretrained = TorchNNRepresentation.load(
                config.pretrained_model_filepath,
                map_location=self.config.device)

            if hasattr(pretrained, "config"):
                # Adopt the saved model's config, then overlay the kwargs.
                self.config = pretrained.config
                # NOTE(review): updating the config with itself looks like a
                # no-op; possibly another source was intended - confirm.
                self.config.update(self.config)
                self.config.update(kwargs)

            if not self.network.encoder.config.use_attention:
                # Without attention the whole network is restored at once.
                self.network.load_state_dict(pretrained.network.state_dict())
                return

            # With attention only the lf/gf/ef encoder parts and the decoder
            # are restored, each from its counterpart in the saved model.
            for part in ("lf", "gf", "ef"):
                getattr(self.network.encoder, part).load_state_dict(
                    getattr(pretrained.network.encoder, part).state_dict())
            self.network.decoder.load_state_dict(
                pretrained.network.decoder.state_dict())
Beispiel #9
0
 def set_network(self, network_name, network_parameters):
     """Build the base network, then attach a projection head to it.

     Args:
         network_name: Forwarded to the base implementation.
         network_parameters: Forwarded to the base implementation and used
             as the ``config`` of the ``ProjectionHead``.
     """
     TorchNNRepresentation.set_network(
         self, network_name, network_parameters)

     # add a projection head to the network for the SimCLR
     projection_head = ProjectionHead(config=network_parameters)
     self.network.projection_head = projection_head
Beispiel #10
0
 def set_network(self, network_name, network_parameters):
     """Build the network by delegating to ``TorchNNRepresentation``.

     Args:
         network_name: Forwarded unchanged to the base implementation.
         network_parameters: Forwarded unchanged to the base implementation.
     """
     TorchNNRepresentation.set_network(
         self, network_name, network_parameters)
Beispiel #11
0
 def set_network(self, network_name, network_parameters):
     """Build the base network and, with attention enabled, add ``fc_cast``.

     ``fc_cast`` is a linear layer taking ``4 * n_latents`` inputs and
     producing ``n_latents`` outputs, where ``n_latents`` is read from
     ``self.config.network.parameters`` (not from ``network_parameters``).

     Args:
         network_name: Forwarded to the base implementation.
         network_parameters: Forwarded to the base implementation.
     """
     TorchNNRepresentation.set_network(
         self, network_name, network_parameters)

     # The extra layer is only needed when the encoder uses attention.
     if not self.network.encoder.config.use_attention:
         return

     # add attention head
     n_latents = self.config.network.parameters.n_latents
     self.network.fc_cast = nn.Linear(n_latents * 4, n_latents)
Beispiel #12
0
 def __init__(self, config={}, **kwargs):
     """Initialise by delegating to ``TorchNNRepresentation.__init__``.

     Args:
         config: Configuration forwarded unchanged to the base constructor.
         **kwargs: Extra keyword arguments forwarded the same way.
     """
     TorchNNRepresentation.__init__(
         self,
         config=config,
         **kwargs,
     )
Beispiel #13
0
 def set_network(self, network_name, network_parameters):
     """Build the base network, then attach the matching decoder.

     Args:
         network_name: Name used to look up the decoder class (and
             forwarded to the base implementation).
         network_parameters: Configuration handed to the base
             implementation and to the decoder as ``config``.
     """
     TorchNNRepresentation.set_network(
         self, network_name, network_parameters)

     # add a decoder to the network for the VAE
     self.network.decoder = decoders.get_decoder(network_name)(
         config=network_parameters)