def __init__(self, config=None, **kwargs):
    """Build the representation and declare the keys returned by a forward pass.

    Args:
        config: user configuration forwarded to ``TorchNNRepresentation.__init__``.
            ``None`` (the default) is replaced by a fresh empty dict, so no dict
            instance is shared between calls.
        **kwargs: extra configuration overrides, forwarded unchanged.
    """
    if config is None:
        config = {}
    # calls all constructors up to BaseDNN (MRO)
    TorchNNRepresentation.__init__(self, config=config, **kwargs)
    # `+` builds a new list, so the encoder's own output_keys_list is not mutated.
    self.output_keys_list = self.network.encoder.output_keys_list + ["recon_x"]
def set_network(self, network_name, network_parameters):
    """Create the base network, then attach the BiGAN decoder and discriminator.

    The decoder and discriminator classes are looked up with the same
    ``network_name`` as the base network. Both are instantiated with
    ``network_parameters`` as their ``config``.
    """
    TorchNNRepresentation.set_network(self, network_name, network_parameters)
    net = self.network
    net.decoder = decoders.get_decoder(network_name)(config=network_parameters)
    net.discriminator = discriminators.get_discriminator(network_name)(
        config=network_parameters
    )
def set_logger(self, logger_config):
    """Configure the logger and, if one exists, record the model graph in it.

    The graph is traced with one random dummy sample of shape
    ``(1, n_channels, *input_size)``. The sample is cast to ``self.config.dtype``
    and moved to ``self.config.device``. Tracing runs in eval mode under
    ``torch.no_grad()``. The module's previous train/eval mode is restored
    afterwards, so this call leaves no lasting mode change.

    Args:
        logger_config: forwarded to ``TorchNNRepresentation.set_logger``.
    """
    TorchNNRepresentation.set_logger(self, logger_config)
    if self.logger is None:
        return

    net_params = self.config.network.parameters
    # tuple() lets input_size also be stored as a list (e.g. a config loaded from file).
    dummy_size = (1, net_params.n_channels) + tuple(net_params.input_size)
    dummy_input = torch.rand(dummy_size).type(self.config.dtype).to(self.config.device)

    was_training = self.training
    self.eval()
    try:
        with torch.no_grad():
            self.logger.add_graph(self, dummy_input, verbose=False)
    finally:
        # Restore the caller's mode; otherwise eval() would persist after logging.
        self.train(was_training)
def default_config():
    """Return the default configuration for the SimCLR-loss representation.

    Starts from ``TorchNNRepresentation.default_config()`` and then fills in:
    - the network (Burgess architecture);
    - the weight-initialization section (``weights_init``);
    - the SimCLR loss (temperature 0.5, cosine distance);
    - the Adam optimizer.
    """
    config = TorchNNRepresentation.default_config()

    # network parameters
    network = config.network = Dict()
    network.name = "Burgess"
    net_params = network.parameters = Dict()
    net_params.n_channels = 1
    net_params.input_size = (64, 64)
    net_params.n_latents = 10
    net_params.n_conv_layers = 4
    net_params.feature_layer = 2
    net_params.encoder_conditional_type = "gaussian"

    # initialization parameters
    # NOTE(review): keyed `weights_init` here, while some other default configs in
    # this module use `initialization`; confirm which key the base class reads.
    init = network.weights_init = Dict()
    init.name = "pytorch"
    init.parameters = Dict()

    # loss parameters
    loss = config.loss = Dict()
    loss.name = "SimCLR"
    loss_params = loss.parameters = Dict()
    loss_params.temperature = 0.5
    loss_params.distance = 'cosine'

    # optimizer parameters
    optimizer = config.optimizer = Dict()
    optimizer.name = "Adam"
    optim_params = optimizer.parameters = Dict()
    optim_params.lr = 1e-3
    optim_params.weight_decay = 1e-5

    return config
def default_config():
    """Return the default configuration for the Triplet-loss representation.

    Starts from ``TorchNNRepresentation.default_config()`` and then fills in:
    - the network (Burgess architecture);
    - the ``initialization`` section;
    - the Triplet loss (squared-euclidean distance, margin 1.0, ``use_attention`` on);
    - the Adam optimizer.
    """
    config = TorchNNRepresentation.default_config()

    # network parameters
    network = config.network = Dict()
    network.name = "Burgess"
    net_params = network.parameters = Dict()
    net_params.n_channels = 1
    net_params.input_size = (64, 64)
    net_params.n_latents = 10
    net_params.n_conv_layers = 4
    net_params.feature_layer = 2
    net_params.encoder_conditional_type = "gaussian"

    # initialization parameters
    init = network.initialization = Dict()
    init.name = "pytorch"
    init.parameters = Dict()

    # loss parameters
    loss = config.loss = Dict()
    loss.name = "Triplet"
    loss_params = loss.parameters = Dict()
    loss_params.distance = "squared_euclidean"
    loss_params.margin = 1.0
    loss_params.use_attention = True

    # optimizer parameters
    optimizer = config.optimizer = Dict()
    optimizer.name = "Adam"
    optim_params = optimizer.parameters = Dict()
    optim_params.lr = 1e-3
    optim_params.weight_decay = 1e-5

    return config
def default_config():
    """Return the default configuration for the BiGAN-loss representation.

    Starts from ``TorchNNRepresentation.default_config()`` and then fills in:
    - the network (Dumoulin architecture);
    - the ``initialization`` section;
    - the BiGAN loss (no loss parameters);
    - the Adam optimizer.
    """
    config = TorchNNRepresentation.default_config()

    # network parameters
    network = config.network = Dict()
    network.name = "Dumoulin"
    net_params = network.parameters = Dict()
    net_params.n_channels = 1
    net_params.input_size = (64, 64)
    net_params.n_latents = 10
    net_params.n_conv_layers = 4
    net_params.feature_layer = 2
    net_params.encoder_conditional_type = "gaussian"

    # initialization parameters
    init = network.initialization = Dict()
    init.name = "pytorch"
    init.parameters = Dict()

    # loss parameters
    loss = config.loss = Dict()
    loss.name = "BiGAN"
    loss.parameters = Dict()

    # optimizer parameters
    optimizer = config.optimizer = Dict()
    optimizer.name = "Adam"
    optim_params = optimizer.parameters = Dict()
    optim_params.lr = 1e-3
    optim_params.weight_decay = 1e-5

    return config
def default_config():
    """Return the default configuration for the VAE-loss representation.

    Starts from ``TorchNNRepresentation.default_config()`` and then fills in:
    - the network (Burgess architecture);
    - the ``weights_init`` section;
    - the VAE loss (bernoulli reconstruction distribution);
    - the Adam optimizer.
    """
    config = TorchNNRepresentation.default_config()

    # network parameters
    network = config.network = Dict()
    network.name = "Burgess"
    net_params = network.parameters = Dict()
    net_params.n_channels = 1
    net_params.input_size = (64, 64)
    net_params.n_latents = 10
    net_params.n_conv_layers = 4
    net_params.feature_layer = 2
    net_params.encoder_conditional_type = "gaussian"

    # weights_init parameters
    # NOTE(review): keyed `weights_init` here, while some other default configs in
    # this module use `initialization`; confirm which key the base class reads.
    init = network.weights_init = Dict()
    init.name = "pytorch"
    init.parameters = Dict()

    # loss parameters
    loss = config.loss = Dict()
    loss.name = "VAE"
    loss_params = loss.parameters = Dict()
    loss_params.reconstruction_dist = "bernoulli"

    # optimizer parameters
    optimizer = config.optimizer = Dict()
    optimizer.name = "Adam"
    optim_params = optimizer.parameters = Dict()
    optim_params.lr = 1e-3
    optim_params.weight_decay = 1e-5

    return config
def __init__(self, config=None, **kwargs):
    """Build the model and optionally warm-start it from a saved checkpoint.

    Loading happens when ``config`` has a truthy ``load_pretrained_model`` and
    an existing ``pretrained_model_filepath``. The checkpoint is loaded with
    ``TorchNNRepresentation.load`` onto ``self.config.device``.

    If the checkpoint carries a ``config``, that config becomes the base. The
    caller's ``config`` and then ``kwargs`` are re-applied on top of it.

    Weights are then copied into the freshly built network:
    - if the encoder's config has ``use_attention`` set, only the encoder's
      ``lf``/``gf``/``ef`` sub-modules and the decoder are copied;
    - otherwise the whole network state dict is copied.

    Args:
        config: user configuration; ``None`` (the default) means an empty one.
        **kwargs: configuration overrides, forwarded to ``base_class.__init__``.
    """
    if config is None:
        config = {}
    base_class.__init__(self, config=config, **kwargs)

    # .get() works for plain dicts as well as attribute dicts, so the default
    # empty config does not raise AttributeError.
    pretrained_filepath = config.get("pretrained_model_filepath")
    if not (config.get("load_pretrained_model")
            and pretrained_filepath
            and os.path.exists(pretrained_filepath)):
        return

    model = TorchNNRepresentation.load(pretrained_filepath, map_location=self.config.device)

    if hasattr(model, "config"):
        # Saved config is the base; re-apply the caller's settings so they win.
        self.config = model.config
        self.config.update(config)
        self.config.update(kwargs)

    if self.network.encoder.config.use_attention:
        # Only these sub-modules are copied; any other sub-modules of this
        # network keep the values they were constructed with.
        encoder, saved_encoder = self.network.encoder, model.network.encoder
        encoder.lf.load_state_dict(saved_encoder.lf.state_dict())
        encoder.gf.load_state_dict(saved_encoder.gf.state_dict())
        encoder.ef.load_state_dict(saved_encoder.ef.state_dict())
        self.network.decoder.load_state_dict(model.network.decoder.state_dict())
    else:
        self.network.load_state_dict(model.network.state_dict())
def set_network(self, network_name, network_parameters):
    """Create the base network and attach the SimCLR projection head.

    Args:
        network_name: forwarded to ``TorchNNRepresentation.set_network``.
        network_parameters: forwarded to the base method; also used as the
            ``config`` of the ``ProjectionHead``.
    """
    TorchNNRepresentation.set_network(self, network_name, network_parameters)
    # SimCLR adds a projection head to the network (not a decoder).
    self.network.projection_head = ProjectionHead(config=network_parameters)
def set_network(self, network_name, network_parameters):
    """Build the network exactly as ``TorchNNRepresentation`` does; nothing extra is attached."""
    parent_set_network = TorchNNRepresentation.set_network
    parent_set_network(self, network_name, network_parameters)
def set_network(self, network_name, network_parameters):
    """Create the base network and, for attention encoders, add the ``fc_cast`` layer.

    When the encoder's config has ``use_attention`` set, ``self.network.fc_cast``
    becomes an ``nn.Linear`` mapping ``4 * n_latents`` features to ``n_latents``.
    Without attention, nothing is added beyond the base network.
    """
    TorchNNRepresentation.set_network(self, network_name, network_parameters)
    if not self.network.encoder.config.use_attention:
        return
    # NOTE(review): sized from self.config rather than network_parameters;
    # presumably the same values when called during construction — confirm.
    n_latents = self.config.network.parameters.n_latents
    self.network.fc_cast = nn.Linear(n_latents * 4, n_latents)
def __init__(self, config=None, **kwargs):
    """Initialize via ``TorchNNRepresentation.__init__``.

    Args:
        config: user configuration; ``None`` (the default) is replaced by a
            fresh empty dict, so no dict instance is shared between calls.
        **kwargs: extra configuration overrides, forwarded unchanged.
    """
    if config is None:
        config = {}
    TorchNNRepresentation.__init__(self, config=config, **kwargs)
def set_network(self, network_name, network_parameters):
    """Create the base network, then attach the VAE decoder registered under ``network_name``.

    The decoder is instantiated with ``network_parameters`` as its ``config``.
    """
    TorchNNRepresentation.set_network(self, network_name, network_parameters)
    self.network.decoder = decoders.get_decoder(network_name)(config=network_parameters)