def create_networks(self):
    """Instantiate the encoder MLP and store it as ``self.r_net``.

    The network is configured with ``d_x + d_y`` inputs and ``d_lo`` outputs;
    hidden layout, activation and seed are taken from the instance.
    """
    net_config = dict(
        name="encoder_r",
        d_in=self.d_x + self.d_y,
        d_out=self.d_lo,
        mlp_layers=self.mlp_layers,
        f_act=self.f_act,
        seed=self.seed,
    )
    self.r_net = MLP(**net_config)
def create_networks(self):
    """Build the decoder MLP(s) selected by ``self.arch``.

    * ``"separate_networks_combined_input"``: ``self.mu_y_net`` and
      ``self.std_y_net``, each with ``d_x + 2 * d_z`` inputs and ``d_y``
      outputs.
    * ``"separate_networks_separate_input"``: the same two networks, each
      with ``d_x + d_z`` inputs.
    * ``"two_heads"``: one ``self.mu_y_std_y_net`` with ``d_x + 2 * d_z``
      inputs and ``2 * d_y`` outputs.

    Raises:
        ValueError: If ``self.arch`` is none of the values above.
    """
    arch = self.arch

    if arch == "two_heads":
        self.mu_y_std_y_net = MLP(
            name="decoder_mu_cov",
            d_in=self.d_x + self.d_z * 2,
            d_out=2 * self.d_y,
            # the joint network uses the mean layout;
            # self.mlp_layers_std_y is ignored
            mlp_layers=self.mlp_layers_mu_y,
            f_act=self.f_act,
            seed=self.seed,
        )
        return

    # Both remaining architectures build two networks that differ only in
    # their input width.
    if arch == "separate_networks_combined_input":
        d_in = self.d_x + self.d_z * 2
    elif arch == "separate_networks_separate_input":
        d_in = self.d_x + self.d_z
    else:
        raise ValueError("Unknown network type: {}!".format(self.arch))

    self.mu_y_net = MLP(
        name="decoder_mu",
        d_in=d_in,
        d_out=self.d_y,
        mlp_layers=self.mlp_layers_mu_y,
        f_act=self.f_act,
        seed=self.seed,
    )
    self.std_y_net = MLP(
        name="decoder_cov",
        d_in=d_in,
        d_out=self.d_y,
        mlp_layers=self.mlp_layers_std_y,
        f_act=self.f_act,
        seed=self.seed,
    )
def __init__(self, **kwargs):
    """Initialise the base class and create the two latent-parameter MLPs.

    Reads ``d_z``, ``d_lo``, ``mlp_layers``, ``f_act`` and ``seed`` from
    ``kwargs``. Both MLPs have ``d_lo`` inputs and ``d_z`` outputs; they
    differ only in ``name`` and ``f_out``.
    """
    super().__init__(**kwargs)
    self.d_z = kwargs["d_z"]

    def build_head(name, f_out):
        # Everything except the name and the output spec is shared.
        return MLP(
            name=name,
            d_in=kwargs["d_lo"],
            d_out=kwargs["d_z"],
            mlp_layers=kwargs["mlp_layers"],
            f_act=kwargs["f_act"],
            f_out=f_out,
            seed=kwargs["seed"],
        )

    # f_out=(None, {}) presumably means "no output activation" and
    # ("exp", {}) an exponential one -- confirm against MLP.
    self.mean_z_mlp = build_head("agg2mean_z", (None, {}))
    self.cov_z_mlp = build_head("agg2cov_z", ("exp", {}))
def create_networks(self):
    """Build the decoder MLP(s) selected by ``self.arch``.

    * ``"separate_networks"``: creates ``self.mu_y_net`` and
      ``self.std_y_net``, each with ``d_x + d_z`` inputs and ``d_y`` outputs.
    * ``"two_heads"``: creates a single ``self.mu_y_std_y_net`` with
      ``d_x + d_z`` inputs and ``2 * d_y`` outputs.

    Raises:
        ValueError: If ``self.arch`` is neither of the values above. The
            message names the offending value.
    """
    if self.arch == "separate_networks":
        self.mu_y_net = MLP(
            name="decoder_mu",
            d_in=self.d_x + self.d_z,
            d_out=self.d_y,
            mlp_layers=self.mlp_layers_mu_y,
            f_act=self.f_act,
            seed=self.seed,
        )
        self.std_y_net = MLP(
            name="decoder_cov",
            d_in=self.d_x + self.d_z,
            d_out=self.d_y,
            mlp_layers=self.mlp_layers_std_y,
            f_act=self.f_act,
            seed=self.seed,
        )
    elif self.arch == "two_heads":
        self.mu_y_std_y_net = MLP(
            name="decoder_mu_cov",
            d_in=self.d_x + self.d_z,
            d_out=2 * self.d_y,
            mlp_layers=self.mlp_layers_mu_y,  # ignore mlp_layers_std_y
            f_act=self.f_act,
            seed=self.seed,
        )
    else:
        # These are decoder networks; report the bad value so a
        # misconfiguration can be diagnosed from the message alone.
        raise ValueError(
            "Unknown decoder network type: {}!".format(self.arch))
def create_networks(self):
    """Build the encoder MLP(s) selected by ``self.arch``.

    * ``"separate_networks"``: ``self.r_net`` and ``self.cov_r_net``, each
      with ``d_x + d_y`` inputs and ``d_lo`` outputs.
    * ``"two_heads"``: one ``self.r_cov_r_net`` with ``d_x + d_y`` inputs and
      ``2 * d_lo`` outputs.

    Raises:
        ValueError: If ``self.arch`` is neither of the values above.
    """
    if self.arch not in ("separate_networks", "two_heads"):
        raise ValueError("Unknown encoder network type!")

    if self.arch == "two_heads":
        self.r_cov_r_net = MLP(
            name="encoder_mu_cov",
            d_in=self.d_x + self.d_y,
            d_out=2 * self.d_lo,
            # the joint network uses the mean layout;
            # self.mlp_layers_cov_r is ignored
            mlp_layers=self.mlp_layers_r,
            f_act=self.f_act,
            seed=self.seed,
        )
        return

    self.r_net = MLP(
        name="encoder_mu",
        d_in=self.d_x + self.d_y,
        d_out=self.d_lo,
        mlp_layers=self.mlp_layers_r,
        f_act=self.f_act,
        seed=self.seed,
    )
    self.cov_r_net = MLP(
        name="encoder_cov",
        d_in=self.d_x + self.d_y,
        d_out=self.d_lo,
        mlp_layers=self.mlp_layers_cov_r,
        f_act=self.f_act,
        seed=self.seed,
    )
class EncoderNetworkMA:
    """Encoder that maps concatenated ``(x, y)`` inputs to ``r`` with one MLP."""

    def __init__(self, **kwargs):
        """Store the configuration from ``kwargs`` and build the network.

        Required keys: ``d_x``, ``d_y``, ``d_lo``, ``mlp_layers``, ``f_act``
        and ``seed``.
        """
        for key in ("d_x", "d_y", "d_lo", "mlp_layers", "f_act", "seed"):
            setattr(self, key, kwargs[key])

        self.r_net = None
        self.create_networks()

    def set_device(self, device):
        """Move the encoder network to ``device``."""
        self.r_net = self.r_net.to(device)

    def init_weights(self):
        """Do nothing; no explicit weight initialisation is performed here."""
        pass

    def save_weights(self, logpath, epoch):
        """Save the encoder network under ``logpath`` for ``epoch``."""
        self.r_net.save(path=logpath, epoch=epoch)

    def load_weights(self, logpath, epoch):
        """Load the encoder network stored under ``logpath`` for ``epoch``."""
        self.r_net.load_weights(path=logpath, epoch=epoch)

    def create_networks(self):
        """Instantiate ``self.r_net`` (``d_x + d_y`` inputs, ``d_lo`` outputs)."""
        net_config = dict(
            name="encoder_r",
            d_in=self.d_x + self.d_y,
            d_out=self.d_lo,
            mlp_layers=self.mlp_layers,
            f_act=self.f_act,
            seed=self.seed,
        )
        self.r_net = MLP(**net_config)

    def encode(self, x, y):
        """Encode ``(x, y)`` and return the result as a 1-tuple ``(r,)``.

        ``x`` and ``y`` must both be 3-dimensional; they are concatenated
        along the last axis (``dim=2``) before being passed to the network.
        """
        assert x.ndim == y.ndim == 3

        xy = torch.cat((x, y), dim=2)
        return (self.r_net(xy), )

    @property
    def parameters(self):
        """List of the encoder network's parameters."""
        return [*self.r_net.parameters()]
class DecoderNetworkSamples:
    """Decoder mapping inputs ``x`` and latent samples ``z`` to ``(mu_y, std_y)``.

    Depending on ``arch`` it uses either two independent MLPs
    (``"separate_networks"``) or a single MLP whose output is split into a
    mean half and a standard-deviation half (``"two_heads"``).
    """

    def __init__(self, **kwargs):
        """Store the configuration from ``kwargs`` and build the network(s).

        Required keys: ``d_x``, ``d_y``, ``d_z``, ``arch``,
        ``mlp_layers_mu_y``, ``mlp_layers_std_y``, ``f_act`` and ``seed``.

        Raises:
            ValueError: If ``arch`` is not a known architecture.
        """
        self.d_x = kwargs["d_x"]
        self.d_y = kwargs["d_y"]
        self.d_z = kwargs["d_z"]
        self.arch = kwargs["arch"]

        # process network shapes
        self.mlp_layers_mu_y = kwargs["mlp_layers_mu_y"]
        self.mlp_layers_std_y = kwargs["mlp_layers_std_y"]

        self.f_act = kwargs["f_act"]
        self.seed = kwargs["seed"]

        self.mu_y_net = self.std_y_net = self.mu_y_std_y_net = None
        self.create_networks()

    def set_device(self, device):
        """Move the active network(s) to ``device``."""
        if self.arch == "separate_networks":
            self.mu_y_net = self.mu_y_net.to(device)
            self.std_y_net = self.std_y_net.to(device)
        elif self.arch == "two_heads":
            self.mu_y_std_y_net = self.mu_y_std_y_net.to(device)

    def init_weights(self):
        """Do nothing; weight initialisation is left to PyTorch."""
        pass  # pytorch takes care of this

    def save_weights(self, logpath, epoch):
        """Save the active network(s) under ``logpath`` for ``epoch``."""
        if self.arch == "separate_networks":
            self.mu_y_net.save(path=logpath, epoch=epoch)
            self.std_y_net.save(path=logpath, epoch=epoch)
        elif self.arch == "two_heads":
            self.mu_y_std_y_net.save(path=logpath, epoch=epoch)

    def load_weights(self, logpath, epoch):
        """Load the active network(s) stored under ``logpath`` for ``epoch``."""
        if self.arch == "separate_networks":
            self.mu_y_net.load_weights(path=logpath, epoch=epoch)
            self.std_y_net.load_weights(path=logpath, epoch=epoch)
        elif self.arch == "two_heads":
            self.mu_y_std_y_net.load_weights(path=logpath, epoch=epoch)

    def create_networks(self):
        """Build the decoder MLP(s) selected by ``self.arch``.

        Raises:
            ValueError: If ``self.arch`` is neither ``"separate_networks"``
                nor ``"two_heads"``.
        """
        if self.arch == "separate_networks":
            self.mu_y_net = MLP(
                name="decoder_mu",
                d_in=self.d_x + self.d_z,
                d_out=self.d_y,
                mlp_layers=self.mlp_layers_mu_y,
                f_act=self.f_act,
                seed=self.seed,
            )
            self.std_y_net = MLP(
                name="decoder_cov",
                d_in=self.d_x + self.d_z,
                d_out=self.d_y,
                mlp_layers=self.mlp_layers_std_y,
                f_act=self.f_act,
                seed=self.seed,
            )
        elif self.arch == "two_heads":
            self.mu_y_std_y_net = MLP(
                name="decoder_mu_cov",
                d_in=self.d_x + self.d_z,
                d_out=2 * self.d_y,
                mlp_layers=self.mlp_layers_mu_y,  # ignore mlp_layers_std_y
                f_act=self.f_act,
                seed=self.seed,
            )
        else:
            # This is a decoder; report the bad value for easier diagnosis.
            raise ValueError(
                "Unknown decoder network type: {}!".format(self.arch))

    def decode(self, x, z):
        """Decode every (latent sample, test input) pair.

        Args:
            x: Tensor of shape ``(n_tsk, n_tst, d_x)``.
            z: Tensor of shape ``(n_tsk, n_ls, n_marg, d_z)``.

        Returns:
            Tuple ``(mu_y, std_y)``; ``std_y`` is the exponential of the raw
            network output. Assuming the network acts on the last axis
            only, both have shape ``(n_tsk, n_ls, n_marg, n_tst, d_y)``.
        """
        assert x.ndim == 3  # (n_tsk, n_tst, d_x)
        assert z.ndim == 4  # (n_tsk, n_ls, n_marg, d_z)

        n_tsk = z.shape[0]
        n_ls = z.shape[1]
        n_marg = z.shape[2]
        n_tst = x.shape[1]

        # add latent-state-wise batch dimensions to x
        x = x[:, None, None, :, :]
        x = x.expand(n_tsk, n_ls, n_marg, n_tst, self.d_x)

        # add dataset-wise batch dimension to sample
        z = z[:, :, :, None, :]
        z = z.expand((n_tsk, n_ls, n_marg, n_tst, self.d_z))

        # prepare input to decoder network
        decoder_input = torch.cat((x, z), dim=4)

        # decode
        if self.arch == "separate_networks":
            mu_y = self.mu_y_net(decoder_input)
            std_y = self.std_y_net(decoder_input)
        elif self.arch == "two_heads":
            mu_cov_y = self.mu_y_std_y_net(decoder_input)
            # Split along the LAST axis (features). The network output is
            # 5-dimensional here, so indexing axis 1 would cut the n_ls axis.
            mu_y = mu_cov_y[..., :self.d_y]
            std_y = mu_cov_y[..., self.d_y:]

        # deparametrize
        std_y = torch.exp(std_y)

        return mu_y, std_y

    @property
    def parameters(self):
        """List of the parameters of the active network(s)."""
        params = []
        if self.arch == "separate_networks":
            params += list(self.mu_y_net.parameters())
            params += list(self.std_y_net.parameters())
        elif self.arch == "two_heads":
            params += list(self.mu_y_std_y_net.parameters())

        return params
class DecoderNetworkPB:
    """Decoder mapping ``x`` and latent parameters ``(mu_z, cov_z)`` to ``(mu_y, std_y)``.

    Supported values of ``arch``:

    * ``"separate_networks_combined_input"``: two MLPs, both fed
      ``(x, mu_z, log-cov_z)``.
    * ``"separate_networks_separate_input"``: two MLPs; the mean network is
      fed ``(x, mu_z)``, the std network ``(x, log-cov_z)``.
    * ``"two_heads"``: one MLP fed ``(x, mu_z, log-cov_z)`` whose output is
      split into a mean half and a standard-deviation half.
    """

    # Architectures that use two independent networks for mean and std.
    _SEPARATE_ARCHS = (
        "separate_networks_combined_input",
        "separate_networks_separate_input",
    )

    def __init__(self, **kwargs):
        """Store the configuration from ``kwargs`` and build the network(s).

        Required keys: ``d_x``, ``d_y``, ``d_z``, ``arch``,
        ``mlp_layers_mu_y``, ``mlp_layers_std_y``, ``f_act``, ``safe_log``
        and ``seed``.

        Raises:
            ValueError: If ``arch`` is not a known architecture.
        """
        self.d_x = kwargs["d_x"]
        self.d_y = kwargs["d_y"]
        self.d_z = kwargs["d_z"]
        self.arch = kwargs["arch"]

        # process network shapes
        self.mlp_layers_mu_y = kwargs["mlp_layers_mu_y"]
        self.mlp_layers_std_y = kwargs["mlp_layers_std_y"]

        self.f_act = kwargs["f_act"]
        self.safe_log = kwargs["safe_log"]
        self.seed = kwargs["seed"]

        self.mu_y_net = self.std_y_net = self.mu_y_std_y_net = None
        self.create_networks()

    def set_device(self, device):
        """Move the active network(s) to ``device``."""
        if self.arch in self._SEPARATE_ARCHS:
            self.mu_y_net = self.mu_y_net.to(device)
            self.std_y_net = self.std_y_net.to(device)
        elif self.arch == "two_heads":
            self.mu_y_std_y_net = self.mu_y_std_y_net.to(device)

    def init_weights(self):
        """Do nothing; weight initialisation is left to PyTorch."""
        pass  # pytorch takes care of this

    def save_weights(self, logpath, epoch):
        """Save the active network(s) under ``logpath`` for ``epoch``."""
        if self.arch in self._SEPARATE_ARCHS:
            self.mu_y_net.save(path=logpath, epoch=epoch)
            self.std_y_net.save(path=logpath, epoch=epoch)
        elif self.arch == "two_heads":
            self.mu_y_std_y_net.save(path=logpath, epoch=epoch)

    def load_weights(self, logpath, epoch):
        """Load the active network(s) stored under ``logpath`` for ``epoch``."""
        if self.arch in self._SEPARATE_ARCHS:
            self.mu_y_net.load_weights(path=logpath, epoch=epoch)
            self.std_y_net.load_weights(path=logpath, epoch=epoch)
        elif self.arch == "two_heads":
            # Mirrors save_weights: without this branch a saved two-head
            # model would silently keep its current weights.
            self.mu_y_std_y_net.load_weights(path=logpath, epoch=epoch)

    def create_networks(self):
        """Build the decoder MLP(s) selected by ``self.arch``.

        Raises:
            ValueError: If ``self.arch`` is not a known architecture.
        """
        if self.arch == "separate_networks_combined_input":
            self.mu_y_net = MLP(
                name="decoder_mu",
                d_in=self.d_x + self.d_z * 2,
                d_out=self.d_y,
                mlp_layers=self.mlp_layers_mu_y,
                f_act=self.f_act,
                seed=self.seed,
            )
            self.std_y_net = MLP(
                name="decoder_cov",
                d_in=self.d_x + self.d_z * 2,
                d_out=self.d_y,
                mlp_layers=self.mlp_layers_std_y,
                f_act=self.f_act,
                seed=self.seed,
            )
        elif self.arch == "separate_networks_separate_input":
            self.mu_y_net = MLP(
                name="decoder_mu",
                d_in=self.d_x + self.d_z,
                d_out=self.d_y,
                mlp_layers=self.mlp_layers_mu_y,
                f_act=self.f_act,
                seed=self.seed,
            )
            self.std_y_net = MLP(
                name="decoder_cov",
                d_in=self.d_x + self.d_z,
                d_out=self.d_y,
                mlp_layers=self.mlp_layers_std_y,
                f_act=self.f_act,
                seed=self.seed,
            )
        elif self.arch == "two_heads":
            self.mu_y_std_y_net = MLP(
                name="decoder_mu_cov",
                d_in=self.d_x + self.d_z * 2,
                d_out=2 * self.d_y,
                mlp_layers=self.
                mlp_layers_mu_y,  # ignore self.mlp_layers_std_y
                f_act=self.f_act,
                seed=self.seed,
            )
        else:
            raise ValueError("Unknown network type: {}!".format(self.arch))

    def decode(self, x, mu_z, cov_z):
        """Decode every (latent state, test input) pair.

        Args:
            x: Tensor of shape ``(n_tsk, n_tst, d_x)``.
            mu_z: Tensor of shape ``(n_tsk, n_ls, d_z)``.
            cov_z: Tensor with the same number of dimensions as ``mu_z``;
                it is log-transformed (see ``parametrize_latent_cov``)
                before being fed to the network.

        Returns:
            Tuple ``(mu_y, std_y)``; ``std_y`` is the exponential of the raw
            network output. Assuming the network acts on the last axis
            only, both have shape ``(n_tsk, n_ls, n_tst, d_y)``.
        """
        assert x.ndim == 3
        assert mu_z.ndim == cov_z.ndim == 3

        n_tsk = mu_z.shape[0]
        n_ls = mu_z.shape[1]
        n_tst = x.shape[1]

        # covariance parametrization
        cov_z = self.parametrize_latent_cov(cov_z)

        # add latent-state-wise batch dimension to X
        x = x[:, None, :, :]
        x = x.expand(n_tsk, n_ls, n_tst, self.d_x)

        # add dataset-wise batch dimension to latent states
        mu_z = mu_z[:, :, None, :]
        cov_z = cov_z[:, :, None, :]

        # prepare input to decoder network
        if self.arch == "separate_networks_combined_input":
            mu_z_cov_z = torch.cat((mu_z, cov_z), dim=3)
            mu_z_cov_z = mu_z_cov_z.expand((n_tsk, n_ls, n_tst, self.d_z * 2))
            input_mu = input_std = torch.cat((x, mu_z_cov_z), dim=3)
        elif self.arch == "separate_networks_separate_input":
            mu_z = mu_z.expand((n_tsk, n_ls, n_tst, self.d_z))
            cov_z = cov_z.expand((n_tsk, n_ls, n_tst, self.d_z))
            input_mu = torch.cat((x, mu_z), dim=3)
            input_std = torch.cat((x, cov_z), dim=3)
        elif self.arch == "two_heads":
            mu_z_cov_z = torch.cat((mu_z, cov_z), dim=3)
            mu_z_cov_z = mu_z_cov_z.expand((n_tsk, n_ls, n_tst, self.d_z * 2))
            input_two_head = torch.cat((x, mu_z_cov_z), dim=3)

        # decode
        if self.arch in self._SEPARATE_ARCHS:
            mu_y = self.mu_y_net(input_mu)
            std_y = self.std_y_net(input_std)
        elif self.arch == "two_heads":
            mu_y_std_y = self.mu_y_std_y_net(input_two_head)
            # Split along the LAST axis (features). The network output is
            # 4-dimensional here, so indexing axis 1 would cut the n_ls axis.
            mu_y = mu_y_std_y[..., :self.d_y]
            std_y = mu_y_std_y[..., self.d_y:]

        # deparametrize
        std_y = torch.exp(std_y)

        return mu_y, std_y

    def parametrize_latent_cov(self, cov):
        """Return ``log(cov + self.safe_log)``.

        ``self.safe_log`` is added before taking the logarithm; presumably a
        small positive offset guarding against ``log(0)`` -- confirm with
        the caller's configuration.
        """
        cov = cov + self.safe_log
        parametrized_cov = torch.log(cov)

        return parametrized_cov

    @property
    def parameters(self):
        """List of the parameters of the active network(s)."""
        params = []
        if self.arch in self._SEPARATE_ARCHS:
            params += list(self.mu_y_net.parameters()) + list(
                self.std_y_net.parameters())
        elif self.arch == "two_heads":
            params += list(self.mu_y_std_y_net.parameters())

        return params
class MeanAggregatorRtoZ(MeanAggregator):
    """Mean aggregator with two extra MLPs mapping the aggregate to latent parameters.

    ``mean_z_mlp`` and ``cov_z_mlp`` both take ``d_lo`` inputs and produce
    ``d_z`` outputs; they differ only in ``name`` and ``f_out``.
    """

    def __init__(self, **kwargs):
        """Initialise the base aggregator and create the two MLPs.

        Reads ``d_z``, ``d_lo``, ``mlp_layers``, ``f_act`` and ``seed`` from
        ``kwargs`` (the base class receives all of ``kwargs`` as well).
        """
        super().__init__(**kwargs)
        self.d_z = kwargs["d_z"]

        # Settings common to both heads.
        shared = dict(
            d_in=kwargs["d_lo"],
            d_out=kwargs["d_z"],
            mlp_layers=kwargs["mlp_layers"],
            f_act=kwargs["f_act"],
            seed=kwargs["seed"],
        )
        # f_out=(None, {}) presumably means "no output activation" and
        # ("exp", {}) an exponential one -- confirm against MLP.
        self.mean_z_mlp = MLP(name="agg2mean_z", f_out=(None, {}), **shared)
        self.cov_z_mlp = MLP(name="agg2cov_z", f_out=("exp", {}), **shared)

    def agg2latent(self, agg_state):
        """Map ``agg_state[0]`` to a ``(mean_z, cov_z)`` tuple."""
        aggregate = agg_state[0]
        mean_z = self.mean_z_mlp.forward(aggregate)
        cov_z = self.cov_z_mlp(aggregate)
        return mean_z, cov_z

    def set_device(self, device):
        """Remember ``device`` and move both MLPs to it."""
        self.device = device
        for mlp in (self.mean_z_mlp, self.cov_z_mlp):
            mlp.to(device)

    @property
    def parameters(self):
        """Parameters of the mean MLP followed by those of the cov MLP."""
        return [*self.mean_z_mlp.parameters(), *self.cov_z_mlp.parameters()]

    def init_weights(self):
        """Do nothing; no explicit weight initialisation is performed here."""
        pass

    def save_weights(self, logpath, epoch):
        """Save both MLPs under ``logpath`` for ``epoch``."""
        for mlp in (self.mean_z_mlp, self.cov_z_mlp):
            mlp.save(path=logpath, epoch=epoch)

    def load_weights(self, logpath, epoch):
        """Load both MLPs stored under ``logpath`` for ``epoch``."""
        for mlp in (self.mean_z_mlp, self.cov_z_mlp):
            mlp.load_weights(path=logpath, epoch=epoch)

    def delete_all_weight_files(self, logpath):
        """Delete the weight files of both MLPs under ``logpath``."""
        for mlp in (self.mean_z_mlp, self.cov_z_mlp):
            mlp.delete_all_weight_files(path=logpath)
class EncoderNetworkBA:
    """Encoder mapping concatenated ``(x, y)`` inputs to ``(r, cov_r)``.

    Depending on ``arch`` it uses either two independent MLPs
    (``"separate_networks"``) or one MLP whose output is split into an ``r``
    half and a ``cov_r`` half (``"two_heads"``).
    """

    def __init__(self, **kwargs):
        """Store the configuration from ``kwargs`` and build the network(s).

        Required keys: ``d_x``, ``d_y``, ``d_lo``, ``arch``, ``f_act``,
        ``seed``, ``mlp_layers_r`` and ``mlp_layers_cov_r``.

        Raises:
            ValueError: If ``arch`` is not a known architecture.
        """
        for key in ("d_x", "d_y", "d_lo", "arch", "f_act", "seed"):
            setattr(self, key, kwargs[key])

        # process network shapes
        self.mlp_layers_r = kwargs["mlp_layers_r"]
        self.mlp_layers_cov_r = kwargs["mlp_layers_cov_r"]

        self.r_net = None
        self.cov_r_net = None
        self.r_cov_r_net = None
        self.create_networks()

    def set_device(self, device):
        """Move the active network(s) to ``device``."""
        if self.arch == "two_heads":
            self.r_cov_r_net = self.r_cov_r_net.to(device)
        elif self.arch == "separate_networks":
            self.r_net = self.r_net.to(device)
            self.cov_r_net = self.cov_r_net.to(device)

    def init_weights(self):
        """Do nothing; no explicit weight initialisation is performed here."""
        pass

    def save_weights(self, logpath, epoch):
        """Save the active network(s) under ``logpath`` for ``epoch``."""
        if self.arch == "separate_networks":
            active = (self.r_net, self.cov_r_net)
        elif self.arch == "two_heads":
            active = (self.r_cov_r_net, )
        else:
            active = ()
        for net in active:
            net.save(path=logpath, epoch=epoch)

    def load_weights(self, logpath, epoch):
        """Load the active network(s) stored under ``logpath`` for ``epoch``."""
        if self.arch == "separate_networks":
            active = (self.r_net, self.cov_r_net)
        elif self.arch == "two_heads":
            active = (self.r_cov_r_net, )
        else:
            active = ()
        for net in active:
            net.load_weights(path=logpath, epoch=epoch)

    def create_networks(self):
        """Build the encoder MLP(s) selected by ``self.arch``.

        Raises:
            ValueError: If ``self.arch`` is neither ``"separate_networks"``
                nor ``"two_heads"``.
        """
        if self.arch not in ("separate_networks", "two_heads"):
            raise ValueError("Unknown encoder network type!")

        if self.arch == "two_heads":
            self.r_cov_r_net = MLP(
                name="encoder_mu_cov",
                d_in=self.d_x + self.d_y,
                d_out=2 * self.d_lo,
                # the joint network uses the mean layout;
                # self.mlp_layers_cov_r is ignored
                mlp_layers=self.mlp_layers_r,
                f_act=self.f_act,
                seed=self.seed,
            )
            return

        self.r_net = MLP(
            name="encoder_mu",
            d_in=self.d_x + self.d_y,
            d_out=self.d_lo,
            mlp_layers=self.mlp_layers_r,
            f_act=self.f_act,
            seed=self.seed,
        )
        self.cov_r_net = MLP(
            name="encoder_cov",
            d_in=self.d_x + self.d_y,
            d_out=self.d_lo,
            mlp_layers=self.mlp_layers_cov_r,
            f_act=self.f_act,
            seed=self.seed,
        )

    def encode(self, x, y):
        """Encode ``(x, y)`` and return ``(r, cov_r)``.

        ``x`` and ``y`` must both be 3-dimensional; they are concatenated
        along the last axis (``dim=2``). The raw ``cov_r`` output is passed
        through ``torch.exp`` before being returned.
        """
        assert x.ndim == y.ndim == 3

        # prepare input to encoder network
        xy = torch.cat((x, y), dim=2)

        # encode
        if self.arch == "separate_networks":
            r = self.r_net(xy)
            raw_cov_r = self.cov_r_net(xy)
        elif self.arch == "two_heads":
            joint = self.r_cov_r_net(xy)
            r = joint[:, :, :self.d_lo]
            raw_cov_r = joint[:, :, self.d_lo:]

        return r, torch.exp(raw_cov_r)

    @property
    def parameters(self):
        """List of the parameters of the active network(s)."""
        if self.arch == "separate_networks":
            return [*self.r_net.parameters(), *self.cov_r_net.parameters()]
        elif self.arch == "two_heads":
            return [*self.r_cov_r_net.parameters()]