예제 #1
0
파일: nets.py 프로젝트: changyi7231/MEF
    def forward(self, x, reverse=False, init=False):
        """Run the multi-level flow in the forward or reverse direction.

        Forward (``reverse`` falsy): for every level the tensor is squeezed,
        passed through ``self.blocks[level]`` and, on every level but the
        last, split in two along dim 1; the second half is set aside. The
        halves are then merged back in reverse order, with an unsqueeze
        after each merge.

        Reverse (``reverse`` truthy): the tensor is first squeezed/split
        level by level to recover the set-aside halves, then the blocks are
        applied from the last level to the first, merging one half back in
        before each block and unsqueezing after it.

        Args:
            x: Input tensor; must provide ``new_zeros`` and ``size``
                (presumably a ``torch.Tensor`` -- confirm against callers).
            reverse: Selects the reverse pass when truthy.
            init: Forwarded to every block as ``init=`` in the forward pass
                only.

        Returns:
            Tuple ``(out, log_det_sum)`` where ``log_det_sum`` starts as
            zeros of length ``x.size(0)`` and accumulates the log-det
            returned by each block.
        """
        last_level = self.num_levels - 1
        total_log_det = x.new_zeros(x.size(0))
        stash = []
        h = x

        if reverse:
            # Rebuild the per-level halves that the forward pass set aside.
            h = squeeze2d(h)
            for _ in range(last_level):
                keep, aside = split2d(h, h.size(1) // 2)
                stash.append(aside)
                h = squeeze2d(keep)
            # Undo the blocks from the deepest level back to the first.
            for level in range(last_level, -1, -1):
                if level < last_level:
                    h = unsplit2d([h, stash.pop()])
                h, log_det = self.blocks[level](h, reverse=reverse)
                total_log_det = total_log_det + log_det
                h = unsqueeze2d(h, factor=2)
            return h, total_log_det

        for level in range(self.num_levels):
            h = squeeze2d(h)
            h, log_det = self.blocks[level](h, init=init)
            total_log_det = total_log_det + log_det
            if level < last_level:
                keep, aside = split2d(h, h.size(1) // 2)
                stash.append(aside)
                h = keep
        # Merge the set-aside halves back, most recent first.
        h = unsqueeze2d(h)
        for _ in range(last_level):
            h = unsqueeze2d(unsplit2d([h, stash.pop()]), factor=2)

        return h, total_log_det
예제 #2
0
    def inverse(self,
                x,
                objective,
                yy=None,
                nlf0=None,
                nlf1=None,
                iso=None,
                shutter=None):
        """Apply the inverse of every bijector, level by level.

        At each level ``z`` (and ``yy`` when given) is squeezed, every
        bijector in ``self.model[i]`` is inverted, and the resulting inverse
        log-det-Jacobian is added to ``objective``. After every level except
        the last, ``split2d`` is applied to ``z`` and ``objective``.

        Args:
            x: Tensor to push through the inverse transforms.
            objective: Running objective the log-det terms are added to.
            yy: Optional conditioning tensor, squeezed alongside ``z``.
            nlf0, nlf1, iso, shutter: Extra conditioning values passed, with
                ``yy``, only to the conditional bijector types listed below.

        Returns:
            Tuple ``(z, objective)`` after all levels.
        """
        # Bijector types whose inverse methods take the conditioning
        # arguments. An exact type match is kept on purpose (no isinstance).
        conditional_types = (
            AffineCouplingCondY, AffineCouplingCondXY,
            AffineCouplingFitSdnGain2, AffineCouplingCondYG,
            AffineCouplingCamSdn, AffineCouplingCondXYG,
            AffineCouplingSdnGain, AffineCouplingSdn,
            AffineCouplingGain, AffineCouplingGainEx1,
            AffineCouplingGainEx2, AffineCouplingGainEx3,
            AffineCouplingSdnEx1, AffineCouplingSdnEx2,
            AffineCouplingSdnEx3, AffineCouplingSdnEx4,
            AffineCouplingGainEx4, AffineCouplingSdnEx5,
            AffineCouplingSdnEx6,
        )
        z = x
        squeeze_factor = self.hps.squeeze_factor

        for i in range(self.n_levels):
            z = squeeze2d(z, squeeze_factor, self.hps.squeeze_type)
            if yy is not None:
                yy = squeeze2d(yy, squeeze_factor, self.hps.squeeze_type)
            for bijector in self.model[i]:
                if type(bijector) in conditional_types:
                    cond_args = (yy, nlf0, nlf1, iso, shutter)
                else:
                    cond_args = ()
                try:
                    z, log_abs_det_J_inv = \
                        bijector._inverse_and_log_det_jacobian(z, *cond_args)
                except Exception as e:
                    # Best-effort fallback for bijectors without the fused
                    # method: report the error and use the separate calls.
                    print(e)
                    # The log-det must be evaluated at the input of the
                    # inverse, so compute it BEFORE z is overwritten.
                    log_abs_det_J_inv = bijector._inverse_log_det_jacobian(
                        z, *cond_args)
                    z = bijector._inverse(z, *cond_args)
                objective += log_abs_det_J_inv
            if i < self.n_levels - 1:
                z, objective = split2d("pool{}".format(i), z, objective)
        return z, objective
예제 #3
0
 def _forward_log_det_jacobian(self, x, yy, nlf0=None, nlf1=None, iso=None, cam=None):
     """Return the forward log-det-Jacobian of this coupling layer.

     The log-scale is predicted from the first ``self.ic // 2`` channels of
     ``x`` concatenated with ``yy`` (last axis), bounded with
     ``self.scale * tanh(.)``, negated and summed over axes 1, 2 and 3.

     Args:
         x: 4-D input tensor (indexed as ``x[:, :, :, c]``).
         yy: Conditioning tensor; squeezed by 2 when its dim 1 is twice
             that of ``x``.
         nlf0, nlf1, cam: Unused here; kept for a uniform signature.
         iso: Forwarded to ``self._shift_and_log_scale_fn``.

     Returns:
         A scalar zero constant when no log-scale is produced, otherwise
         the negated sum of the bounded log-scale per leading-axis entry.
     """
     if self._last_layer:
         x = tf.reshape(x, (-1, self.i0, self.i1, self.ic))
         yy = tf.reshape(yy, (-1, self.i0, self.i1, self.ic))
     if 2 * x.shape[1] == yy.shape[1]:
         yy = squeeze2d(yy, 2)
     x0 = x[:, :, :, :self.ic // 2]
     x0yy = tf.concat([x0, yy], axis=-1)
     _, log_scale = self._shift_and_log_scale_fn(x0yy, iso)
     # Check for a missing log-scale BEFORE applying tanh; otherwise the
     # None case raises inside tf.tanh and this branch is never reached.
     if log_scale is None:
         return tf.constant(0., dtype=x.dtype, name="fldj")
     log_scale = self.scale * tf.tanh(log_scale)
     return -tf.reduce_sum(log_scale, axis=[1, 2, 3])
예제 #4
0
 def _forward(self, x, yy, nlf0=None, nlf1=None, iso=None, cam=None):
     """Apply the forward affine transform conditioned on ``yy`` only.

     Computes ``(x - shift) * exp(-log_scale)`` where ``shift`` and
     ``log_scale`` come from ``self._shift_and_log_scale_fn(yy)`` and the
     log-scale is bounded with ``self.scale * tanh(.)``. Either term is
     skipped when the function returns ``None`` for it.

     Args:
         x: Input tensor; the whole tensor is transformed (no channel
             split in this variant).
         yy: Conditioning tensor; squeezed by 2 when its dim 1 is twice
             that of ``x``.
         nlf0, nlf1, iso, cam: Unused here; kept for a uniform signature.

     Returns:
         The transformed tensor.
     """
     if self._last_layer:
         x = tf.reshape(x, (-1, self.i0, self.i1, self.ic))
         yy = tf.reshape(yy, (-1, self.i0, self.i1, self.ic))
     if yy.shape[1] == 2 * x.shape[1]:  # needs squeezing
         yy = squeeze2d(yy, 2)
     shift, log_scale = self._shift_and_log_scale_fn(yy)
     # Only bound the log-scale when one was produced; tf.tanh(None) would
     # raise and make the None check below meaningless.
     if log_scale is not None:
         log_scale = self.scale * tf.tanh(log_scale)
     y = x
     if shift is not None:
         y -= shift
     if log_scale is not None:
         y *= tf.exp(-log_scale)
     return y
예제 #5
0
    def _forward(self, x, yy, nlf0=None, nlf1=None, iso=None, cam=None):
        """Scale ``x`` by the factor returned from ``sdn_model_params_ex1``.

        Args:
            x: Input tensor.
            yy: Conditioning tensor; squeezed by 2 when its dim 1 is twice
                that of ``x``.
            nlf0, nlf1, cam: Unused here; kept for a uniform signature.
            iso: Forwarded to ``sdn_model_params_ex1`` together with ``yy``.

        Returns:
            ``x`` multiplied by the scale (when it is not ``None``) plus a
            constant zero shift.
        """
        if self._last_layer:
            full_shape = (-1, self.i0, self.i1, self.ic)
            x = tf.reshape(x, full_shape)
            yy = tf.reshape(yy, full_shape)

        yy_is_double_res = yy.shape[1] == 2 * x.shape[1]
        if yy_is_double_res:
            yy = squeeze2d(yy, 2)

        gain = sdn_model_params_ex1(yy, iso)
        offset = 0.0  # this variant applies no shift

        out = x
        if gain is not None:
            out *= gain
        if offset is not None:
            out += offset
        return out
예제 #6
0
 def _forward(self, x, yy, nlf0=None, nlf1=None, iso=None, cam=None):
     """Apply the forward affine coupling transform.

     ``x`` is split along the last axis at ``self.ic // 2``. The first half
     ``x0`` passes through unchanged; together with ``yy`` it conditions
     the shift and log-scale applied to the second half:
     ``y1 = (x1 - shift) * exp(-log_scale)``, with the log-scale bounded by
     ``self.scale * tanh(.)``. Either term is skipped when it is ``None``.

     Args:
         x: 4-D input tensor (indexed as ``x[:, :, :, c]``).
         yy: Conditioning tensor; squeezed by 2 when its dim 1 is twice
             that of ``x``.
         nlf0, nlf1, cam: Unused here; kept for a uniform signature.
         iso: Forwarded to ``self._shift_and_log_scale_fn``.

     Returns:
         ``x0`` and the transformed second half concatenated on the last
         axis.
     """
     if self._last_layer:
         x = tf.reshape(x, (-1, self.i0, self.i1, self.ic))
         yy = tf.reshape(yy, (-1, self.i0, self.i1, self.ic))
     if 2 * x.shape[1] == yy.shape[1]:
         yy = squeeze2d(yy, 2)
     x0 = x[:, :, :, :self.ic // 2]
     x1 = x[:, :, :, self.ic // 2:]
     x0yy = tf.concat([x0, yy], axis=-1)
     shift, log_scale = self._shift_and_log_scale_fn(x0yy, iso)
     # Only bound the log-scale when one was produced; tf.tanh(None) would
     # raise and make the None check below meaningless.
     if log_scale is not None:
         log_scale = self.scale * tf.tanh(log_scale)
     y1 = x1
     if shift is not None:
         y1 -= shift
     if log_scale is not None:
         y1 *= tf.exp(-log_scale)
     y = tf.concat([x0, y1], axis=-1)
     return y
예제 #7
0
 def _forward_and_log_det_jacobian(self, x, yy, nlf0=None, nlf1=None, iso=None, cam=None):
     """Apply the forward coupling transform and return its log-det too.

     ``x`` is split along the last axis at ``self.ic // 2``. The first half
     ``x0`` passes through unchanged; together with ``yy`` it conditions
     the shift and log-scale applied to the second half:
     ``y1 = (x1 - shift) * exp(-log_scale)``, with the log-scale bounded by
     ``self.scale * tanh(.)``.

     Args:
         x: 4-D input tensor (indexed as ``x[:, :, :, c]``).
         yy: Conditioning tensor; squeezed by 2 when its dim 1 is twice
             that of ``x``.
         nlf0, nlf1, cam: Unused here; kept for a uniform signature.
         iso: Forwarded to ``self._shift_and_log_scale_fn``.

     Returns:
         Tuple ``(y, log_abs_det_J)``: the concatenation of ``x0`` with the
         transformed half, and either a scalar zero constant (no log-scale)
         or the negated sum of the bounded log-scale over axes 1, 2 and 3.
     """
     if self._last_layer:
         x = tf.reshape(x, (-1, self.i0, self.i1, self.ic))
         yy = tf.reshape(yy, (-1, self.i0, self.i1, self.ic))
     if 2 * x.shape[1] == yy.shape[1]:
         yy = squeeze2d(yy, 2)
     x0 = x[:, :, :, :self.ic // 2]
     x1 = x[:, :, :, self.ic // 2:]
     x0yy = tf.concat([x0, yy], axis=-1)
     shift, log_scale = self._shift_and_log_scale_fn(x0yy, iso)
     # Only bound the log-scale when one was produced; tf.tanh(None) would
     # raise and leave the None branches below unreachable.
     if log_scale is not None:
         log_scale = self.scale * tf.tanh(log_scale)
     y1 = x1
     if shift is not None:
         y1 -= shift
     if log_scale is not None:
         y1 *= tf.exp(-log_scale)
     y = tf.concat([x0, y1], axis=-1)
     if log_scale is None:
         log_abs_det_J = tf.constant(0., dtype=x.dtype, name="fldj")
     else:
         log_abs_det_J = -tf.reduce_sum(log_scale, axis=[1, 2, 3])
     return y, log_abs_det_J