Exemplo n.º 1
0
    def loss(self, N, round_cl=1):
        """Assemble the training objective for the current round.

        Parameters
        ----------
        N : int
            Number of training samples; the SVI KL penalty is weighted by
            ``1 / N``.
        round_cl : int
            Last round in which the SVI weights are regularised towards a
            zero-centred prior (only when ``self.reg_lambda > 0``). In later
            rounds they are regularised towards the previous round's weights.

        Returns
        -------
        The network loss from ``self.network.get_loss()``, plus the scaled KL
        penalty when ``self.svi`` is enabled.
        """
        network = self.network
        objective = network.get_loss()

        # register nodes in the observables dict so they can be monitored
        # during training
        self.observables['loss.lprobs'] = network.lprobs
        self.observables['loss.iws'] = network.iws
        self.observables['loss.raw_loss'] = objective

        if not self.svi:
            return objective

        early_round = self.round <= round_cl
        if early_round and self.reg_lambda > 0:
            # early rounds: pull the weights towards a zero-centred prior
            kl, imvs = svi_kl_zero(network.mps, network.sps, self.reg_lambda)
        elif early_round:
            # early rounds without a positive reg_lambda: no penalty at all
            kl, imvs = 0, {}
        else:
            # later rounds: pull the weights towards the previous round's
            kl, imvs = svi_kl_init(network.mps, network.sps)

        scale = 1 / N
        objective = objective + scale * kl

        # expose the KL term and its intermediate values for monitoring too
        self.observables['loss.kl'] = kl
        self.observables.update(imvs)

        return objective
Exemplo n.º 2
0
    def loss(self, N):
        """Assemble the importance-weighted training objective.

        Parameters
        ----------
        N : int
            Number of training samples; the SVI KL penalty is weighted by
            ``1 / N``.

        Returns
        -------
        The negative mean of ``iws * lprobs``, plus the scaled KL penalty
        when ``self.svi`` is enabled.
        """
        weighted = self.network.iws * self.network.lprobs
        objective = -tt.mean(weighted)

        # register the importance weights so they can be monitored
        self.observables['loss.iws'] = self.network.iws

        if not self.svi:
            return objective

        net = self.network
        use_zero_prior = self.round == 1 or self.retain_data
        if use_zero_prior:
            # round 1 (or whenever self.retain_data is set): regularise the
            # weight/bias parameters towards a zero-centred prior
            kl, imvs = svi_kl_zero_diag_gauss(net.mps_wp, net.sps_wp,
                                              net.mps_bp, net.sps_bp)
        else:
            # otherwise: regularise towards the previous round's weights
            kl, imvs = svi_kl_init(net.mps, net.sps)

        scale = 1 / N
        objective = objective + scale * kl

        # expose the KL term and its intermediate values for monitoring too
        self.observables['loss.kl'] = kl
        self.observables.update(imvs)

        return objective
Exemplo n.º 3
0
    def loss(self, N, round_cl=1):
        """Assemble the (unweighted) training objective.

        Parameters
        ----------
        N : int
            Number of training samples; the SVI KL penalty is weighted by
            ``1 / N``.
        round_cl : int
            Last round in which the SVI weights are regularised towards a
            zero-centred prior (only when ``self.reg_lambda > 0``). In later
            rounds they are regularised towards the previous round's weights.

        Returns
        -------
        The negative mean of ``self.network.lprobs``, plus the scaled KL
        penalty when ``self.svi`` is enabled.
        """
        objective = -tt.mean(self.network.lprobs)

        if not self.svi:
            return objective

        early_round = self.round <= round_cl
        if early_round and self.reg_lambda > 0:
            # early rounds: pull the weights towards a zero-centred prior
            kl, _ = svi_kl_zero(self.network.mps, self.network.sps,
                                self.reg_lambda)
        elif early_round:
            # early rounds without a positive reg_lambda: no penalty at all
            kl = 0
        else:
            # later rounds: pull the weights towards the previous round's
            kl, _ = svi_kl_init(self.network.mps, self.network.sps)

        return objective + 1 / N * kl
Exemplo n.º 4
0
    def define_loss(self,
                    n,
                    round_cl=1,
                    proposal='gaussian',
                    combined_loss=False):
        """Loss function for training

        Parameters
        ----------
        n : int
            Number of training samples; the SVI KL term is weighted by
            ``1 / n``.
        round_cl : int
            Round after which to start continual learning. While
            ``self.round <= round_cl`` the SVI weights are regularised towards
            a zero-centred prior (only if ``self.reg_lambda > 0``); afterwards
            they are regularised towards the previous round's weights.
        proposal : str
            Specifier for type of proposal used (case-insensitive): 'prior'
            (the prior itself is the proposal), continuous proposals
            ('gaussian', 'mog'), or 'atomic'.
        combined_loss : bool
            Whether to include prior likelihood terms in addition to atomic
            (only forwarded when ``proposal`` is 'atomic').

        Returns
        -------
        loss
            The loss returned by the proposal-specific builder, plus the
            scaled KL term when ``self.svi`` is enabled.
        trn_inputs
            The training inputs returned by the proposal-specific builder.

        Raises
        ------
        NotImplementedError
            If ``proposal`` is not one of the supported specifiers.
        """
        # Normalise once so every branch matches case-insensitively
        # (previously only 'mog' did).
        proposal_key = proposal.lower()

        if proposal_key == 'prior':  # using prior as proposal
            loss, trn_inputs = snpe_loss_prior_as_proposal(self.network,
                                                           svi=self.svi)
        elif proposal_key == 'gaussian':
            assert isinstance(self.generator.proposal, dd.Gaussian), \
                "proposal='gaussian' requires a dd.Gaussian generator proposal"
            loss, trn_inputs = apt_loss_gaussian_proposal(self.network,
                                                          self.generator.prior,
                                                          svi=self.svi)
        elif proposal_key == 'mog':
            loss, trn_inputs = apt_loss_MoG_proposal(self.network,
                                                     self.generator.prior,
                                                     svi=self.svi)
        elif proposal_key == 'atomic':
            loss, trn_inputs = \
                apt_loss_atomic_proposal(self.network, svi=self.svi,
                                         combined_loss=combined_loss)
        else:
            # `NotImplemented` is a comparison sentinel, not an exception
            # class: `raise NotImplemented()` would itself fail with a
            # TypeError and hide the real problem.
            raise NotImplementedError(
                "unsupported proposal {!r}; expected one of 'prior', "
                "'gaussian', 'mog', 'atomic'".format(proposal))

        # adding nodes to dict s.t. they can be monitored during training
        self.observables['loss.lprobs'] = self.network.lprobs
        self.observables['loss.raw_loss'] = loss

        if self.svi:
            if self.round <= round_cl:
                # up to and including round `round_cl`: keep weights close to
                # a zero-centred prior (skipped when reg_lambda is not > 0)
                if self.reg_lambda > 0:
                    kl, imvs = svi_kl_zero(self.network.mps, self.network.sps,
                                           self.reg_lambda)
                else:
                    kl, imvs = 0, {}
            else:
                # weights close to those of previous round
                kl, imvs = svi_kl_init(self.network.mps, self.network.sps)

            loss = loss + 1 / n * kl

            # adding nodes to dict s.t. they can be monitored
            self.observables['loss.kl'] = kl
            self.observables.update(imvs)

        return loss, trn_inputs