def create_networks(self):
     """Construct the single encoder MLP and store it as ``self.r_net``.

     The network is configured with ``d_in = d_x + d_y`` and
     ``d_out = d_lo``; layer sizes, activation and seed are taken from the
     instance attributes of the same names.
     """
     net_config = dict(
         name="encoder_r",
         d_in=self.d_x + self.d_y,
         d_out=self.d_lo,
         mlp_layers=self.mlp_layers,
         f_act=self.f_act,
         seed=self.seed,
     )
     self.r_net = MLP(**net_config)
 def create_networks(self):
     """Instantiate the decoder MLP(s) selected by ``self.arch``.

     * ``"separate_networks_combined_input"`` and
       ``"separate_networks_separate_input"`` create ``self.mu_y_net`` and
       ``self.std_y_net`` (each with ``d_out = d_y``).
     * ``"two_heads"`` creates ``self.mu_y_std_y_net`` (``d_out = 2 * d_y``),
       built from ``mlp_layers_mu_y`` only; ``mlp_layers_std_y`` is ignored.

     Raises:
         ValueError: if ``self.arch`` is none of the names above.
     """
     arch = self.arch
     two_network_archs = (
         "separate_networks_combined_input",
         "separate_networks_separate_input",
     )
     # reject unknown names first so that nothing is built or assigned
     if arch not in two_network_archs and arch != "two_heads":
         raise ValueError("Unknown network type: {}!".format(self.arch))

     # only the "separate input" variant uses a latent block of width d_z;
     # the other two architectures use 2 * d_z
     if arch == "separate_networks_separate_input":
         d_in = self.d_x + self.d_z
     else:
         d_in = self.d_x + self.d_z * 2

     if arch == "two_heads":
         self.mu_y_std_y_net = MLP(
             name="decoder_mu_cov",
             d_in=d_in,
             d_out=2 * self.d_y,
             mlp_layers=self.mlp_layers_mu_y,  # ignore self.mlp_layers_std_y
             f_act=self.f_act,
             seed=self.seed,
         )
         return

     self.mu_y_net = MLP(
         name="decoder_mu",
         d_in=d_in,
         d_out=self.d_y,
         mlp_layers=self.mlp_layers_mu_y,
         f_act=self.f_act,
         seed=self.seed,
     )
     self.std_y_net = MLP(
         name="decoder_cov",
         d_in=d_in,
         d_out=self.d_y,
         mlp_layers=self.mlp_layers_std_y,
         f_act=self.f_act,
         seed=self.seed,
     )
Example #3
0
 def __init__(self, **kwargs):
     """Set up the two MLPs that map an aggregated state to latent mean/cov.

     Reads ``d_z``, ``d_lo``, ``mlp_layers``, ``f_act`` and ``seed`` from
     ``kwargs``; all keyword arguments are also forwarded to the parent
     constructor.
     """
     super().__init__(**kwargs)
     self.d_z = kwargs["d_z"]

     # both networks share everything except their name and output function
     shared = dict(
         d_in=kwargs["d_lo"],
         d_out=kwargs["d_z"],
         mlp_layers=kwargs["mlp_layers"],
         f_act=kwargs["f_act"],
         seed=kwargs["seed"],
     )
     self.mean_z_mlp = MLP(name="agg2mean_z", f_out=(None, {}), **shared)
     self.cov_z_mlp = MLP(name="agg2cov_z", f_out=("exp", {}), **shared)
 def create_networks(self):
     """Instantiate the decoder MLP(s) selected by ``self.arch``.

     * ``"separate_networks"`` creates ``self.mu_y_net`` and
       ``self.std_y_net`` (each ``d_in = d_x + d_z``, ``d_out = d_y``).
     * ``"two_heads"`` creates ``self.mu_y_std_y_net`` with
       ``d_out = 2 * d_y``; it uses ``mlp_layers_mu_y`` and ignores
       ``mlp_layers_std_y``.

     Raises:
         ValueError: if ``self.arch`` is neither of the names above. The
             message reports the offending value.
     """
     if self.arch == "separate_networks":
         self.mu_y_net = MLP(
             name="decoder_mu",
             d_in=self.d_x + self.d_z,
             d_out=self.d_y,
             mlp_layers=self.mlp_layers_mu_y,
             f_act=self.f_act,
             seed=self.seed,
         )
         self.std_y_net = MLP(
             name="decoder_cov",
             d_in=self.d_x + self.d_z,
             d_out=self.d_y,
             mlp_layers=self.mlp_layers_std_y,
             f_act=self.f_act,
             seed=self.seed,
         )
     elif self.arch == "two_heads":
         self.mu_y_std_y_net = MLP(
             name="decoder_mu_cov",
             d_in=self.d_x + self.d_z,
             d_out=2 * self.d_y,
             mlp_layers=self.mlp_layers_mu_y,  # ignore mlp_layers_std_y
             f_act=self.f_act,
             seed=self.seed,
         )
     else:
         # this builds decoder networks, so do not call it an "encoder" type
         raise ValueError("Unknown network type: {}!".format(self.arch))
 def create_networks(self):
     """Instantiate the encoder MLP(s) selected by ``self.arch``.

     * ``"separate_networks"`` creates ``self.r_net`` and ``self.cov_r_net``
       (each with ``d_out = d_lo``).
     * ``"two_heads"`` creates ``self.r_cov_r_net`` with
       ``d_out = 2 * d_lo``; it uses ``mlp_layers_r`` and ignores
       ``mlp_layers_cov_r``.

     Raises:
         ValueError: if ``self.arch`` is neither of the names above.
     """
     arch = self.arch
     # reject unknown names first so that nothing is built or assigned
     if arch not in ("separate_networks", "two_heads"):
         raise ValueError("Unknown encoder network type!")

     d_in = self.d_x + self.d_y

     if arch == "two_heads":
         self.r_cov_r_net = MLP(
             name="encoder_mu_cov",
             d_in=d_in,
             d_out=2 * self.d_lo,
             mlp_layers=self.mlp_layers_r,  # ignore self.mlp_layers_cov_r
             f_act=self.f_act,
             seed=self.seed,
         )
         return

     self.r_net = MLP(
         name="encoder_mu",
         d_in=d_in,
         d_out=self.d_lo,
         mlp_layers=self.mlp_layers_r,
         f_act=self.f_act,
         seed=self.seed,
     )
     self.cov_r_net = MLP(
         name="encoder_cov",
         d_in=d_in,
         d_out=self.d_lo,
         mlp_layers=self.mlp_layers_cov_r,
         f_act=self.f_act,
         seed=self.seed,
     )
class EncoderNetworkMA:
    """Encoder with a single MLP that maps concatenated ``(x, y)`` to ``r``.

    Required keyword arguments: ``d_x``, ``d_y``, ``d_lo``, ``mlp_layers``,
    ``f_act`` and ``seed``.
    """

    def __init__(self, **kwargs):
        """Copy the configuration from ``kwargs`` and build the network."""
        for key in ("d_x", "d_y", "d_lo", "mlp_layers", "f_act", "seed"):
            setattr(self, key, kwargs[key])

        self.r_net = None
        self.create_networks()

    def set_device(self, device):
        """Move the encoder network to ``device``."""
        moved = self.r_net.to(device)
        self.r_net = moved

    def init_weights(self):
        """Do nothing; no explicit weight initialisation is performed here."""
        pass

    def save_weights(self, logpath, epoch):
        """Store the network weights for ``epoch`` under ``logpath``."""
        self.r_net.save(path=logpath, epoch=epoch)

    def load_weights(self, logpath, epoch):
        """Restore the network weights for ``epoch`` from ``logpath``."""
        self.r_net.load_weights(path=logpath, epoch=epoch)

    def create_networks(self):
        """Construct the encoder MLP (``d_in = d_x + d_y``, ``d_out = d_lo``)."""
        net_config = dict(
            name="encoder_r",
            d_in=self.d_x + self.d_y,
            d_out=self.d_lo,
            mlp_layers=self.mlp_layers,
            f_act=self.f_act,
            seed=self.seed,
        )
        self.r_net = MLP(**net_config)

    def encode(self, x, y):
        """Encode 3-D tensors ``x`` and ``y`` and return a 1-tuple ``(r,)``.

        ``x`` and ``y`` are concatenated along axis 2 before being passed
        through the network.
        """
        assert x.ndim == y.ndim == 3

        # join inputs and targets feature-wise, then run the encoder
        joint = torch.cat([x, y], dim=2)
        return (self.r_net(joint), )

    @property
    def parameters(self):
        """List of the encoder network's parameters."""
        return [*self.r_net.parameters()]
class DecoderNetworkSamples:
    """Decoder mapping inputs ``x`` and latent samples ``z`` to ``(mu_y, std_y)``.

    Depending on ``arch`` the mean and standard deviation come from two
    separate MLPs (``"separate_networks"``) or from one MLP whose output is
    split in half along the last axis (``"two_heads"``).

    Required keyword arguments: ``d_x``, ``d_y``, ``d_z``, ``arch``,
    ``mlp_layers_mu_y``, ``mlp_layers_std_y``, ``f_act`` and ``seed``.
    """

    def __init__(self, **kwargs):
        """Copy the configuration from ``kwargs`` and build the network(s)."""
        self.d_x = kwargs["d_x"]
        self.d_y = kwargs["d_y"]
        self.d_z = kwargs["d_z"]
        self.arch = kwargs["arch"]

        # process network shapes
        self.mlp_layers_mu_y = kwargs["mlp_layers_mu_y"]
        self.mlp_layers_std_y = kwargs["mlp_layers_std_y"]

        self.f_act = kwargs["f_act"]
        self.seed = kwargs["seed"]

        self.mu_y_net = self.std_y_net = self.mu_y_std_y_net = None
        self.create_networks()

    def set_device(self, device):
        """Move the network(s) used by the current architecture to ``device``."""
        if self.arch == "separate_networks":
            self.mu_y_net = self.mu_y_net.to(device)
            self.std_y_net = self.std_y_net.to(device)
        elif self.arch == "two_heads":
            self.mu_y_std_y_net = self.mu_y_std_y_net.to(device)

    def init_weights(self):
        """Do nothing."""
        pass  # pytorch takes care of this

    def save_weights(self, logpath, epoch):
        """Store the weights of the active network(s) for ``epoch``."""
        if self.arch == "separate_networks":
            self.mu_y_net.save(path=logpath, epoch=epoch)
            self.std_y_net.save(path=logpath, epoch=epoch)
        elif self.arch == "two_heads":
            self.mu_y_std_y_net.save(path=logpath, epoch=epoch)

    def load_weights(self, logpath, epoch):
        """Restore the weights of the active network(s) for ``epoch``."""
        if self.arch == "separate_networks":
            self.mu_y_net.load_weights(path=logpath, epoch=epoch)
            self.std_y_net.load_weights(path=logpath, epoch=epoch)
        elif self.arch == "two_heads":
            self.mu_y_std_y_net.load_weights(path=logpath, epoch=epoch)

    def create_networks(self):
        """Instantiate the decoder MLP(s) selected by ``self.arch``.

        Raises:
            ValueError: if ``self.arch`` is neither ``"separate_networks"``
                nor ``"two_heads"``.
        """
        if self.arch == "separate_networks":
            self.mu_y_net = MLP(
                name="decoder_mu",
                d_in=self.d_x + self.d_z,
                d_out=self.d_y,
                mlp_layers=self.mlp_layers_mu_y,
                f_act=self.f_act,
                seed=self.seed,
            )
            self.std_y_net = MLP(
                name="decoder_cov",
                d_in=self.d_x + self.d_z,
                d_out=self.d_y,
                mlp_layers=self.mlp_layers_std_y,
                f_act=self.f_act,
                seed=self.seed,
            )
        elif self.arch == "two_heads":
            self.mu_y_std_y_net = MLP(
                name="decoder_mu_cov",
                d_in=self.d_x + self.d_z,
                d_out=2 * self.d_y,
                mlp_layers=self.mlp_layers_mu_y,  # ignore mlp_layers_std_y
                f_act=self.f_act,
                seed=self.seed,
            )
        else:
            # this class builds decoders, so do not call it an "encoder" type
            raise ValueError("Unknown network type: {}!".format(self.arch))

    def decode(self, x, z):
        """Decode every (input, latent sample) pair into ``(mu_y, std_y)``.

        Args:
            x: tensor of shape ``(n_tsk, n_tst, d_x)``.
            z: tensor of shape ``(n_tsk, n_ls, n_marg, d_z)``.

        Returns:
            Tuple ``(mu_y, std_y)``. ``std_y`` is the exponential of the raw
            network output. With networks that act on the last axis, both
            have shape ``(n_tsk, n_ls, n_marg, n_tst, d_y)``.

        Raises:
            ValueError: if ``self.arch`` is not a known architecture.
        """
        assert x.ndim == 3  # (n_tsk, n_tst, d_x)
        assert z.ndim == 4  # (n_tsk, n_ls, n_marg, d_z)

        n_tsk, n_ls, n_marg = z.shape[:3]
        n_tst = x.shape[1]

        # add latent-state-wise batch dimensions to x
        x = x[:, None, None, :, :].expand(n_tsk, n_ls, n_marg, n_tst, self.d_x)

        # add dataset-wise batch dimension to the samples
        z = z[:, :, :, None, :].expand(n_tsk, n_ls, n_marg, n_tst, self.d_z)

        # prepare input to decoder network
        decoder_input = torch.cat((x, z), dim=4)

        # decode
        if self.arch == "separate_networks":
            mu_y = self.mu_y_net(decoder_input)
            std_y = self.std_y_net(decoder_input)
        elif self.arch == "two_heads":
            mu_std_y = self.mu_y_std_y_net(decoder_input)
            # split along the LAST axis: the output is 5-D, so a plain
            # [:, :d_y] would slice the n_ls axis instead of the features
            mu_y = mu_std_y[..., :self.d_y]
            std_y = mu_std_y[..., self.d_y:]
        else:
            raise ValueError("Unknown network type: {}!".format(self.arch))

        # deparametrize
        std_y = torch.exp(std_y)

        return mu_y, std_y

    @property
    def parameters(self):
        """List of the parameters of the network(s) used by ``self.arch``."""
        params = []
        if self.arch == "separate_networks":
            params += list(self.mu_y_net.parameters())
            params += list(self.std_y_net.parameters())
        elif self.arch == "two_heads":
            params += list(self.mu_y_std_y_net.parameters())
        return params
class DecoderNetworkPB:
    """Decoder mapping ``x`` and latent mean/covariance to ``(mu_y, std_y)``.

    Supported ``arch`` values:

    * ``"separate_networks_combined_input"``: two MLPs, both fed
      ``(x, mu_z, cov_z)``.
    * ``"separate_networks_separate_input"``: the mean MLP is fed
      ``(x, mu_z)`` and the std MLP ``(x, cov_z)``.
    * ``"two_heads"``: one MLP fed ``(x, mu_z, cov_z)`` whose output is
      split in half along the last axis.

    Required keyword arguments: ``d_x``, ``d_y``, ``d_z``, ``arch``,
    ``mlp_layers_mu_y``, ``mlp_layers_std_y``, ``f_act``, ``safe_log`` and
    ``seed``.
    """

    # architectures that use two independent networks for mean and std
    _SEPARATE_ARCHS = (
        "separate_networks_combined_input",
        "separate_networks_separate_input",
    )

    def __init__(self, **kwargs):
        """Copy the configuration from ``kwargs`` and build the network(s)."""
        self.d_x = kwargs["d_x"]
        self.d_y = kwargs["d_y"]
        self.d_z = kwargs["d_z"]
        self.arch = kwargs["arch"]

        # process network shapes
        self.mlp_layers_mu_y = kwargs["mlp_layers_mu_y"]
        self.mlp_layers_std_y = kwargs["mlp_layers_std_y"]

        self.f_act = kwargs["f_act"]
        self.safe_log = kwargs["safe_log"]
        self.seed = kwargs["seed"]

        self.mu_y_net = self.std_y_net = self.mu_y_std_y_net = None
        self.create_networks()

    def set_device(self, device):
        """Move the network(s) used by the current architecture to ``device``."""
        if self.arch in self._SEPARATE_ARCHS:
            self.mu_y_net = self.mu_y_net.to(device)
            self.std_y_net = self.std_y_net.to(device)
        elif self.arch == "two_heads":
            self.mu_y_std_y_net = self.mu_y_std_y_net.to(device)

    def init_weights(self):
        """Do nothing."""
        pass  # pytorch takes care of this

    def save_weights(self, logpath, epoch):
        """Store the weights of the active network(s) for ``epoch``."""
        if self.arch in self._SEPARATE_ARCHS:
            self.mu_y_net.save(path=logpath, epoch=epoch)
            self.std_y_net.save(path=logpath, epoch=epoch)
        elif self.arch == "two_heads":
            self.mu_y_std_y_net.save(path=logpath, epoch=epoch)

    def load_weights(self, logpath, epoch):
        """Restore the weights of the active network(s) for ``epoch``."""
        if self.arch in self._SEPARATE_ARCHS:
            self.mu_y_net.load_weights(path=logpath, epoch=epoch)
            self.std_y_net.load_weights(path=logpath, epoch=epoch)
        elif self.arch == "two_heads":
            # mirror save_weights: without this branch the two-head
            # network's saved weights would never be restored
            self.mu_y_std_y_net.load_weights(path=logpath, epoch=epoch)

    def create_networks(self):
        """Instantiate the decoder MLP(s) selected by ``self.arch``.

        Raises:
            ValueError: if ``self.arch`` is not one of the supported names.
        """
        if self.arch == "separate_networks_combined_input":
            self.mu_y_net = MLP(
                name="decoder_mu",
                d_in=self.d_x + self.d_z * 2,
                d_out=self.d_y,
                mlp_layers=self.mlp_layers_mu_y,
                f_act=self.f_act,
                seed=self.seed,
            )
            self.std_y_net = MLP(
                name="decoder_cov",
                d_in=self.d_x + self.d_z * 2,
                d_out=self.d_y,
                mlp_layers=self.mlp_layers_std_y,
                f_act=self.f_act,
                seed=self.seed,
            )
        elif self.arch == "separate_networks_separate_input":
            self.mu_y_net = MLP(
                name="decoder_mu",
                d_in=self.d_x + self.d_z,
                d_out=self.d_y,
                mlp_layers=self.mlp_layers_mu_y,
                f_act=self.f_act,
                seed=self.seed,
            )
            self.std_y_net = MLP(
                name="decoder_cov",
                d_in=self.d_x + self.d_z,
                d_out=self.d_y,
                mlp_layers=self.mlp_layers_std_y,
                f_act=self.f_act,
                seed=self.seed,
            )
        elif self.arch == "two_heads":
            self.mu_y_std_y_net = MLP(
                name="decoder_mu_cov",
                d_in=self.d_x + self.d_z * 2,
                d_out=2 * self.d_y,
                mlp_layers=self.mlp_layers_mu_y,  # ignore mlp_layers_std_y
                f_act=self.f_act,
                seed=self.seed,
            )
        else:
            raise ValueError("Unknown network type: {}!".format(self.arch))

    def decode(self, x, mu_z, cov_z):
        """Decode inputs and latent statistics into ``(mu_y, std_y)``.

        Args:
            x: 3-D tensor; axis 1 is used as ``n_tst`` and it is expanded to
                ``(n_tsk, n_ls, n_tst, d_x)``.
            mu_z: 3-D tensor whose first two axes are ``(n_tsk, n_ls)``.
            cov_z: 3-D tensor like ``mu_z``; it is passed through
                :meth:`parametrize_latent_cov` before being fed to the
                network(s).

        Returns:
            Tuple ``(mu_y, std_y)``. ``std_y`` is the exponential of the raw
            network output. With networks that act on the last axis, both
            have shape ``(n_tsk, n_ls, n_tst, d_y)``.

        Raises:
            ValueError: if ``self.arch`` is not a known architecture.
        """
        assert x.ndim == 3
        assert mu_z.ndim == cov_z.ndim == 3

        n_tsk, n_ls = mu_z.shape[:2]
        n_tst = x.shape[1]

        # covariance parametrization
        cov_z = self.parametrize_latent_cov(cov_z)

        # add latent-state-wise batch dimension to x
        x = x[:, None, :, :].expand(n_tsk, n_ls, n_tst, self.d_x)

        # add dataset-wise batch dimension to latent states
        mu_z = mu_z[:, :, None, :]
        cov_z = cov_z[:, :, None, :]

        if self.arch == "separate_networks_separate_input":
            # each network only sees its own latent statistic
            mu_z = mu_z.expand((n_tsk, n_ls, n_tst, self.d_z))
            cov_z = cov_z.expand((n_tsk, n_ls, n_tst, self.d_z))
            mu_y = self.mu_y_net(torch.cat((x, mu_z), dim=3))
            std_y = self.std_y_net(torch.cat((x, cov_z), dim=3))
        elif self.arch in ("separate_networks_combined_input", "two_heads"):
            # both statistics are concatenated into one shared input
            mu_z_cov_z = torch.cat((mu_z, cov_z), dim=3)
            mu_z_cov_z = mu_z_cov_z.expand((n_tsk, n_ls, n_tst, self.d_z * 2))
            decoder_input = torch.cat((x, mu_z_cov_z), dim=3)
            if self.arch == "two_heads":
                mu_y_std_y = self.mu_y_std_y_net(decoder_input)
                # split along the LAST axis: the output is 4-D, so a plain
                # [:, :d_y] would slice the n_ls axis instead of the features
                mu_y = mu_y_std_y[..., :self.d_y]
                std_y = mu_y_std_y[..., self.d_y:]
            else:
                mu_y = self.mu_y_net(decoder_input)
                std_y = self.std_y_net(decoder_input)
        else:
            raise ValueError("Unknown network type: {}!".format(self.arch))

        # deparametrize
        std_y = torch.exp(std_y)

        return mu_y, std_y

    def parametrize_latent_cov(self, cov):
        """Return ``log(cov + safe_log)``, the form fed to the network(s)."""
        return torch.log(cov + self.safe_log)

    @property
    def parameters(self):
        """List of the parameters of the network(s) used by ``self.arch``."""
        params = []
        if self.arch in self._SEPARATE_ARCHS:
            params += list(self.mu_y_net.parameters())
            params += list(self.std_y_net.parameters())
        elif self.arch == "two_heads":
            params += list(self.mu_y_std_y_net.parameters())
        return params
Example #9
0
class MeanAggregatorRtoZ(MeanAggregator):
    """Mean aggregator with two MLPs mapping the aggregate to latent mean/cov.

    Required keyword arguments (in addition to whatever the parent
    constructor consumes): ``d_z``, ``d_lo``, ``mlp_layers``, ``f_act`` and
    ``seed``.
    """

    def __init__(self, **kwargs):
        """Initialise the parent and build ``mean_z_mlp`` and ``cov_z_mlp``."""
        super().__init__(**kwargs)
        self.d_z = kwargs["d_z"]
        self.mean_z_mlp = MLP(
            name="agg2mean_z",
            d_in=kwargs["d_lo"],
            d_out=kwargs["d_z"],
            mlp_layers=kwargs["mlp_layers"],
            f_act=kwargs["f_act"],
            f_out=(None, {}),
            seed=kwargs["seed"],
        )
        # NOTE(review): f_out=("exp", {}) presumably selects an exponential
        # output function for the covariance head; confirm against MLP
        self.cov_z_mlp = MLP(
            name="agg2cov_z",
            d_in=kwargs["d_lo"],
            d_out=kwargs["d_z"],
            mlp_layers=kwargs["mlp_layers"],
            f_act=kwargs["f_act"],
            f_out=("exp", {}),
            seed=kwargs["seed"],
        )

    def agg2latent(self, agg_state):
        """Map ``agg_state[0]`` to the pair ``(mean_z, cov_z)``.

        Both networks are invoked by calling the network object rather than
        ``.forward()`` directly, so the two heads are treated identically.
        """
        aggregate = agg_state[0]
        return self.mean_z_mlp(aggregate), self.cov_z_mlp(aggregate)

    def set_device(self, device):
        """Remember ``device`` and move both networks to it."""
        self.device = device
        self.mean_z_mlp.to(device)
        self.cov_z_mlp.to(device)

    @property
    def parameters(self):
        """List of the parameters of both networks (mean first, then cov)."""
        return list(self.mean_z_mlp.parameters()) + list(
            self.cov_z_mlp.parameters())

    def init_weights(self):
        """Do nothing; no explicit weight initialisation is performed here."""
        pass

    def save_weights(self, logpath, epoch):
        """Store the weights of both networks for ``epoch`` under ``logpath``."""
        self.mean_z_mlp.save(path=logpath, epoch=epoch)
        self.cov_z_mlp.save(path=logpath, epoch=epoch)

    def load_weights(self, logpath, epoch):
        """Restore the weights of both networks for ``epoch`` from ``logpath``."""
        self.mean_z_mlp.load_weights(path=logpath, epoch=epoch)
        self.cov_z_mlp.load_weights(path=logpath, epoch=epoch)

    def delete_all_weight_files(self, logpath):
        """Delete the weight files of both networks under ``logpath``."""
        self.mean_z_mlp.delete_all_weight_files(path=logpath)
        self.cov_z_mlp.delete_all_weight_files(path=logpath)
class EncoderNetworkBA:
    """Encoder producing a representation ``r`` and an associated ``cov_r``.

    With ``arch == "separate_networks"`` two MLPs are used; with
    ``arch == "two_heads"`` a single MLP of output width ``2 * d_lo`` is
    split in half along the last axis.

    Required keyword arguments: ``d_x``, ``d_y``, ``d_lo``, ``arch``,
    ``f_act``, ``seed``, ``mlp_layers_r`` and ``mlp_layers_cov_r``.
    """

    def __init__(self, **kwargs):
        """Copy the configuration from ``kwargs`` and build the network(s)."""
        config_keys = (
            "d_x", "d_y", "d_lo", "arch", "f_act", "seed",
            # network shapes
            "mlp_layers_r", "mlp_layers_cov_r",
        )
        for key in config_keys:
            setattr(self, key, kwargs[key])

        self.r_net = None
        self.cov_r_net = None
        self.r_cov_r_net = None
        self.create_networks()

    def _net_names(self):
        """Attribute names of the networks in use; empty for unknown archs."""
        if self.arch == "separate_networks":
            return ("r_net", "cov_r_net")
        if self.arch == "two_heads":
            return ("r_cov_r_net", )
        return ()

    def set_device(self, device):
        """Move the network(s) used by the current architecture to ``device``."""
        for attr in self._net_names():
            setattr(self, attr, getattr(self, attr).to(device))

    def init_weights(self):
        """Do nothing; no explicit weight initialisation is performed here."""
        pass

    def save_weights(self, logpath, epoch):
        """Store the weights of the active network(s) for ``epoch``."""
        for attr in self._net_names():
            getattr(self, attr).save(path=logpath, epoch=epoch)

    def load_weights(self, logpath, epoch):
        """Restore the weights of the active network(s) for ``epoch``."""
        for attr in self._net_names():
            getattr(self, attr).load_weights(path=logpath, epoch=epoch)

    def create_networks(self):
        """Instantiate the encoder MLP(s) selected by ``self.arch``.

        Raises:
            ValueError: if ``self.arch`` is neither ``"separate_networks"``
                nor ``"two_heads"``.
        """
        arch = self.arch
        # reject unknown names first so that nothing is built or assigned
        if arch not in ("separate_networks", "two_heads"):
            raise ValueError("Unknown encoder network type!")

        d_in = self.d_x + self.d_y

        if arch == "two_heads":
            self.r_cov_r_net = MLP(
                name="encoder_mu_cov",
                d_in=d_in,
                d_out=2 * self.d_lo,
                mlp_layers=self.mlp_layers_r,  # ignore self.mlp_layers_cov_r
                f_act=self.f_act,
                seed=self.seed,
            )
            return

        self.r_net = MLP(
            name="encoder_mu",
            d_in=d_in,
            d_out=self.d_lo,
            mlp_layers=self.mlp_layers_r,
            f_act=self.f_act,
            seed=self.seed,
        )
        self.cov_r_net = MLP(
            name="encoder_cov",
            d_in=d_in,
            d_out=self.d_lo,
            mlp_layers=self.mlp_layers_cov_r,
            f_act=self.f_act,
            seed=self.seed,
        )

    def encode(self, x, y):
        """Encode 3-D tensors ``x`` and ``y`` into ``(r, cov_r)``.

        ``x`` and ``y`` are concatenated along axis 2; ``cov_r`` is the
        exponential of the raw network output.
        """
        assert x.ndim == y.ndim == 3

        # join inputs and targets feature-wise
        joint = torch.cat([x, y], dim=2)

        if self.arch == "separate_networks":
            r = self.r_net(joint)
            raw_cov_r = self.cov_r_net(joint)
        elif self.arch == "two_heads":
            fused = self.r_cov_r_net(joint)
            # first d_lo features are r, the remainder the raw covariance
            r, raw_cov_r = fused[:, :, :self.d_lo], fused[:, :, self.d_lo:]

        return r, torch.exp(raw_cov_r)

    @property
    def parameters(self):
        """List of the active networks' parameters (``None`` for unknown archs)."""
        names = self._net_names()
        if not names:
            return None
        collected = []
        for attr in names:
            collected.extend(getattr(self, attr).parameters())
        return collected