Example #1
    def __init__(self, dim_in, dim_feat=None, dim_chanl=None):
        # `dim_chanl` is ignored: the pretrained backbone fixes the output
        # channels. `dim_feat` sizes the hidden pre-net layer and, despite
        # the `None` default, must be provided.
        super(CNN_DCGANpretr_224, self).__init__()
        self._param_groups = []
        # pre-net: map the input to the DCGAN generator's 120-d latent space
        self.nn_pre = nn.Sequential(
                nn.Linear(dim_in, dim_feat), nn.Tanh(),
                nn.Linear(dim_feat, 120), nn.Tanh()
            )
        self.nn_pre.apply(weights_init)
        self._param_groups += [{"params": self.nn_pre.parameters(), "lr_ratio": 1.}]
        self.f_pre = self.nn_pre

        # backbone: pretrained DCGAN generator, fine-tuned at 0.1x the base lr
        self.nn_backbone = tc.hub.load('facebookresearch/pytorch_GAN_zoo:hub', 'DCGAN', # force_reload=True,
                pretrained=True, useGPU=False, model_name='cifar10').getOriginalG()
        self._param_groups += [{"params": self.nn_backbone.parameters(), "lr_ratio": 0.1}]
        self.f_backbone = wrap4_multi_batchdims(self.nn_backbone, ndim_vars=1)

        # post-net: upsample the 64x64 generator output to 224x224,
        # since l_out = 4*(64-1) + 4 - 2*16 = 224
        self.nn_post = nn.Sequential(
                nn.BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                nn.ReLU(inplace=True),
                nn.ConvTranspose2d(3, 3, 4, 4, 16, bias=False)
            )
        self.nn_post.apply(weights_init)
        self._param_groups += [{"params": self.nn_post.parameters(), "lr_ratio": 1.}]
        self.f_post = wrap4_multi_batchdims(self.nn_post, ndim_vars=3)
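
The `lr_ratio` entries collected in `_param_groups` suggest per-module learning-rate scaling (0.1x for the pretrained backbone, 1x for the freshly initialized heads). A minimal sketch of how such groups could feed a standard optimizer; the `build_optimizer` helper and the base learning rate are assumptions for illustration, not part of the source:

import torch

def build_optimizer(model, base_lr=1e-3):
    # Scale the base lr by each group's lr_ratio before handing the
    # groups to torch.optim (hypothetical helper, not from the source).
    groups = [{"params": g["params"], "lr": base_lr * g["lr_ratio"]}
              for g in model._param_groups]
    return torch.optim.SGD(groups, lr=base_lr, momentum=0.9)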
Example #2
    def __init__(self, dim_in, dim_feat, dim_chanl=3):
        super(CNN_DCGANvar_224, self).__init__()
        self.nn_main = nn.Sequential(
                # l_out = stride*(l_in - 1) + l_kernel - 2*padding. (*, *, l_kernel, stride, padding)
                # input is Z, going into a convolution
                nn.ConvTranspose2d(dim_in, dim_feat * 8, 7, 1, 0, bias=False),
                nn.BatchNorm2d(dim_feat * 8),
                nn.ReLU(True),
                # state size. (dim_feat*8) x 7 x 7
                nn.ConvTranspose2d(dim_feat * 8, dim_feat * 4, 4, 4, 0, bias=False),
                nn.BatchNorm2d(dim_feat * 4),
                nn.ReLU(True),
                # state size. (dim_feat*4) x 28 x 28
                nn.ConvTranspose2d(dim_feat * 4, dim_feat * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(dim_feat * 2),
                nn.ReLU(True),
                # state size. (dim_feat*2) x 56 x 56
                nn.ConvTranspose2d(dim_feat * 2, dim_feat, 4, 2, 1, bias=False),
                nn.BatchNorm2d(dim_feat),
                nn.ReLU(True),
                # state size. (dim_feat) x 112 x 112
                nn.ConvTranspose2d(dim_feat, dim_chanl, 4, 2, 1, bias=True),
                # no final Tanh: the output is left unbounded
                # state size. (dim_chanl) x 224 x 224
            )
        self.apply(weights_init)
        self._param_groups = [{"params": self.nn_main.parameters(), "lr_ratio": 1.}]
        self.f_main = wrap4_multi_batchdims(self.nn_main, ndim_vars=3)
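
The `state size` comments can be verified with the output-length formula quoted at the top of the block. A standalone sanity check of the 1 -> 7 -> 28 -> 56 -> 112 -> 224 spatial progression (plain Python, nothing assumed beyond the formula):

def convT_out(l_in, l_kernel, stride, padding):
    # one spatial axis of nn.ConvTranspose2d (dilation 1, no output_padding)
    return stride * (l_in - 1) + l_kernel - 2 * padding

l = 1  # the latent vector enters as a 1x1 "image"
for l_kernel, stride, padding in [(7, 1, 0), (4, 4, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1)]:
    l = convT_out(l, l_kernel, stride, padding)
    print(l)  # 7, 28, 56, 112, 224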
Example #3
    def __init__(self, dim_in, dim_feat=None, dim_chanl=None):
        # `dim_chanl` is ignored: the pretrained backbone fixes the output
        # channels. `dim_feat` sizes the hidden pre-net layer and, despite
        # the `None` default, must be provided.
        super(CNN_PGANpretr_224, self).__init__()
        self._param_groups = []
        # pre-net: map the input to the PGAN generator's 512-d latent space
        self.nn_pre = nn.Sequential(
                nn.Linear(dim_in, dim_feat), nn.Tanh(),
                nn.Linear(dim_feat, 512), nn.Tanh()
            )
        self.nn_pre.apply(weights_init)
        self._param_groups += [{"params": self.nn_pre.parameters(), "lr_ratio": 1.}]
        self.f_pre = self.nn_pre

        # backbone: pretrained PGAN generator, fine-tuned at 0.1x the base lr
        self.nn_backbone = tc.hub.load('facebookresearch/pytorch_GAN_zoo:hub', 'PGAN', # force_reload=True,
                pretrained=True, useGPU=False, model_name='celebAHQ-256').getOriginalG() # 'cifar10' unavailable for PGAN
        self._param_groups += [{"params": self.nn_backbone.parameters(), "lr_ratio": 0.1}]
        self.f_backbone = wrap4_multi_batchdims(self.nn_backbone, ndim_vars=1)

        # the celebAHQ-256 generator outputs 256x256 images; crop to 224x224
        self.f_post = tv.transforms.CenterCrop(224)
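
Unlike Example #1, no learnable post-net is needed here: the celebAHQ-256 generator already produces images larger than the 224x224 target, so a deterministic center crop suffices. A standalone shape check (torchvision transforms accept tensors directly; the random tensor stands in for a generator output):

import torch
import torchvision as tv

crop = tv.transforms.CenterCrop(224)
img = torch.randn(3, 256, 256)  # stand-in for a 256x256 PGAN sample
print(crop(img).shape)          # torch.Size([3, 224, 224])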
Example #4
    def __init__(self, backbone_stru: str, dim_bottleneck: int, dim_s: int, dim_y: int, dim_v: int,
            std_v1x_val: float, std_s1vx_val: float, # if <= 0, then learn the std.
            dims_bb2bn: list=None, dims_bn2s: list=None, dims_s2y: list=None,
            vbranch: bool=False, dims_bn2v: list=None):
        """ Based on MDD from <https://github.com/thuml/MDD>
        if not vbranch:
              (bb)   (bn)        (med)   (cls)
                       /->   v   -\
            x ====>  -|            |-> s ----> y
                       \-> parav -/
        else:
              (bb)   (bn)        (med)   (cls)
            x ====>  ---->       ----> s ----> y
                               \ (vbr)
                                \----> v
        """
        if not vbranch: assert dim_v <= dim_bottleneck
        super(CNNsvy1x, self).__init__()
        self.dim_s = dim_s; self.dim_v = dim_v; self.dim_y = dim_y
        self.shape_s = (dim_s,); self.shape_v = (dim_v,)
        self.vbranch = vbranch
        self.std_v1x_val = std_v1x_val; self.std_s1vx_val = std_s1vx_val
        self.learn_std_v1x = std_v1x_val <= 0 if type(std_v1x_val) is float else (std_v1x_val <= 0).any()
        self.learn_std_s1vx = std_s1vx_val <= 0 if type(std_s1vx_val) is float else (std_s1vx_val <= 0).any()

        self._x_cache_bb = self._bb_cache = None
        self._x_cache_bn = self._bn_cache = None
        self._param_groups = []

        if 'domainbed' in globals() and backbone_stru.startswith("DB"):
            self.nn_backbone = Featurizer((3,224,224),
                    {'resnet18': backbone_stru[2:]=='resnet18', 'resnet_dropout': 0.})
            self._param_groups += [{"params": self.nn_backbone.parameters(), "lr_ratio": 1.0}]
            self.nn_backbone.output_num = lambda: self.nn_backbone.n_outputs
            if dim_bottleneck is None: dim_bottleneck = self.nn_backbone.output_num() // 2
            if dim_s is None:
                dim_s = self.nn_backbone.output_num() // 4
                # keep the values cached above in sync with the filled-in default
                self.dim_s = dim_s; self.shape_s = (dim_s,)
        else:
            self.nn_backbone = backbone.network_dict[backbone_stru]()
            self._param_groups += [{"params": self.nn_backbone.parameters(), "lr_ratio": 0.1}]
        self.f_backbone = wrap4_multi_batchdims(self.nn_backbone, ndim_vars=3)

        if dims_bb2bn is None: dims_bb2bn = []
        self.nn_bottleneck = mlp.mlp_constructor(
                [self.nn_backbone.output_num()] + dims_bb2bn + [dim_bottleneck],
                nn.ReLU, lastactv = False
            )
        init_linear(self.nn_bottleneck, 0., 5e-3, 0.1)
        self._param_groups += [{"params": self.nn_bottleneck.parameters(), "lr_ratio": 1.}]
        self.f_bottleneck = self.nn_bottleneck

        if dims_bn2s is None: dims_bn2s = []
        self.nn_mediate = nn.Sequential(
                *([] if backbone_stru.startswith("DB") else [nn.BatchNorm1d(dim_bottleneck)]),
                nn.ReLU(),
                # nn.Dropout(0.5),
                mlp.mlp_constructor(
                    [dim_bottleneck] + dims_bn2s + [dim_s],
                    nn.ReLU, lastactv = False)
            )
        init_linear(self.nn_mediate, 0., 1e-2, 0.)
        self._param_groups += [{"params": self.nn_mediate.parameters(), "lr_ratio": 1.}]
        self.f_mediate = wrap4_multi_batchdims(self.nn_mediate, ndim_vars=1) # required by `BatchNorm1d`

        if dims_s2y is None: dims_s2y = []
        self.nn_classifier = nn.Sequential(
                nn.ReLU(),
                # nn.Dropout(0.5),
                mlp.mlp_constructor(
                    [dim_s] + dims_s2y + [dim_y],
                    nn.ReLU, lastactv = False)
            )
        init_linear(self.nn_classifier, 0., 1e-2, 0.)
        self._param_groups += [{"params": self.nn_classifier.parameters(), "lr_ratio": 1.}]
        self.f_classifier = self.nn_classifier

        if vbranch:
            if dims_bn2v is None: dims_bn2v = []
            self.nn_vbranch = nn.Sequential(
                    nn.BatchNorm1d(dim_bottleneck),
                    nn.ReLU(),
                    # nn.Dropout(0.5),
                    mlp.mlp_constructor(
                        [dim_bottleneck] + dims_bn2v + [dim_v],
                        nn.ReLU, lastactv = False)
                )
            init_linear(self.nn_vbranch, 0., 1e-2, 0.)
            self._param_groups += [{"params": self.nn_vbranch.parameters(), "lr_ratio": 1.}]
            self.f_vbranch = wrap4_multi_batchdims(self.nn_vbranch, ndim_vars=1)

        ## std models
        if self.learn_std_v1x:
            if not vbranch:
                self.nn_std_v = nn.Sequential(
                        mlp.mlp_constructor(
                            [self.nn_backbone.output_num()] + dims_bb2bn + [dim_v],
                            nn.ReLU, lastactv = False),
                        nn.Softplus()
                    )
            else:
                self.nn_std_v = nn.Sequential(
                        mlp.mlp_constructor(
                            [dim_bottleneck] + dims_bn2v + [dim_v],
                            nn.ReLU, lastactv = False),
                        nn.Softplus()
                    )
            init_linear(self.nn_std_v, 0., 1e-2, 0.)
            self._param_groups += [{"params": self.nn_std_v.parameters(), "lr_ratio": 1.}]
            self.f_std_v = self.nn_std_v

        if self.learn_std_s1vx:
            self.nn_std_s = nn.Sequential(
                    nn.BatchNorm1d(dim_bottleneck),
                    nn.ReLU(),
                    # nn.Dropout(0.5),
                    mlp.mlp_constructor(
                        [dim_bottleneck] + dims_bn2s + [dim_s],
                        nn.ReLU, lastactv = False),
                    nn.Softplus()
                )
            init_linear(self.nn_std_s, 0., 1e-2, 0.)
            self._param_groups += [{"params": self.nn_std_s.parameters(), "lr_ratio": 1.}]
            self.f_std_s = wrap4_multi_batchdims(self.nn_std_s, ndim_vars=1)
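
The `learn_std_*` flags above encode a small convention: a std given as a non-positive float, or as a tensor with any non-positive entry, means "learn the std via a Softplus head" rather than keeping it fixed. A standalone restatement of that check (the helper name is illustrative, not from the source):

import torch

def wants_learned_std(std_val):
    # floats: learn iff <= 0; tensors: learn if any per-dimension entry is <= 0
    if isinstance(std_val, float):
        return std_val <= 0
    return bool((std_val <= 0).any())

print(wants_learned_std(-1.))                      # True  -> Softplus head
print(wants_learned_std(.5))                       # False -> fixed std
print(wants_learned_std(torch.tensor([.5, -1.])))  # True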
Example #5
    def __init__(self,
                 dim_x,
                 dims_postx2prev,
                 dim_v,
                 dim_parav,
                 dims_postv2s,
                 dims_posts2prey,
                 dim_y,
                 actv="Sigmoid",
                 std_v1x_val: float = -1.,   # if <= 0, then learn the std
                 std_s1vx_val: float = -1.,  # if <= 0, then learn the std
                 after_actv: bool = True):
        """
                       /->   v   -\
        x ====> prev -|            |==> s ==> y
                       \-> parav -/
        """
        super(MLPsvy1x, self).__init__()
        if type(actv) is str: actv = getattr(nn, actv)
        self.dim_x, self.dim_v, self.dim_y = dim_x, dim_v, dim_y
        dim_prev, dim_s = dims_postx2prev[-1], dims_postv2s[-1]
        self.dim_prev, self.dim_s = dim_prev, dim_s
        self.shape_x, self.shape_v, self.shape_s = (dim_x,), (dim_v,), (dim_s,)
        self.dims_postx2prev, self.dim_parav, self.dims_postv2s, self.dims_posts2prey, self.actv \
                = dims_postx2prev, dim_parav, dims_postv2s, dims_posts2prey, actv
        self.f_x2prev = mlp_constructor([dim_x] + dims_postx2prev, actv)
        if after_actv:
            self.f_prev2v = nn.Sequential(nn.Linear(dim_prev, dim_v), actv())
            self.f_prev2parav = nn.Sequential(nn.Linear(dim_prev, dim_parav),
                                              actv())
            self.f_vparav2s = mlp_constructor([dim_v + dim_parav] +
                                              dims_postv2s, actv)
            self.f_s2y = mlp_constructor([dim_s] + dims_posts2prey + [dim_y],
                                         actv,
                                         lastactv=False)
        else:
            self.f_prev2v = nn.Linear(dim_prev, dim_v)
            self.f_prev2parav = nn.Linear(dim_prev, dim_parav)
            self.f_vparav2s = nn.Sequential(
                actv(),
                mlp_constructor([dim_v + dim_parav] + dims_postv2s,
                                actv,
                                lastactv=False))
            self.f_s2y = nn.Sequential(
                actv(),
                mlp_constructor([dim_s] + dims_posts2prey + [dim_y],
                                actv,
                                lastactv=False))

        self.std_v1x_val = std_v1x_val
        self.std_s1vx_val = std_s1vx_val
        self.learn_std_v1x = (std_v1x_val <= 0 if type(std_v1x_val) is float
                              else (std_v1x_val <= 0).any())
        self.learn_std_s1vx = (std_s1vx_val <= 0 if type(std_s1vx_val) is float
                               else (std_s1vx_val <= 0).any())

        self._prev_cache = self._x_cache_prev = None
        self._v_cache = self._x_cache_v = None
        self._parav_cache = self._x_cache_parav = None

        ## std models
        if self.learn_std_v1x:
            self.nn_std_v = nn.Sequential(
                mlp_constructor([dim_prev, dim_v], nn.ReLU, lastactv=False),
                nn.Softplus())
            init_linear(self.nn_std_v, 0., 1e-2, 0.)
            self.f_std_v = self.nn_std_v

        if self.learn_std_s1vx:
            self.nn_std_s = nn.Sequential(
                nn.BatchNorm1d(dim_v + dim_parav),
                nn.ReLU(),
                # nn.Dropout(0.5),
                mlp_constructor([dim_v + dim_parav] + dims_postv2s,
                                nn.ReLU,
                                lastactv=False),
                nn.Softplus())
            init_linear(self.nn_std_s, 0., 1e-2, 0.)
            self.f_std_s = wrap4_multi_batchdims(self.nn_std_s, ndim_vars=1)
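
The snippet shows only `__init__`; going by the docstring diagram, the mean path of the forward pass plausibly composes the submodules as below. This is a sketch under that assumption, not code from the source:

import torch

def forward_means(model, x):
    prev = model.f_x2prev(x)                             # x ==> prev
    v = model.f_prev2v(prev)                             # prev -> v
    parav = model.f_prev2parav(prev)                     # prev -> parav
    s = model.f_vparav2s(torch.cat([v, parav], dim=-1))  # (v, parav) ==> s
    y = model.f_s2y(s)                                   # s ==> y logits
    return v, s, y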