def setAsVariable(*args):
    """Wrap every positional argument in ``Variable`` and return them as a list.

    Args:
        *args: Objects to wrap (presumably tensors -- confirm at call sites).

    Returns:
        list: ``Variable(a)`` for each ``a`` in ``args``, in the same order.
        An empty list when called with no arguments.
    """
    return [Variable(arg) for arg in args]


# ---- The model ---- #
# get the model definition/architecture
# get network
import DAENet

# Build the encoder/decoder pair. opt.useDense picks the Dense_* variants;
# otherwise the non-dense variants are used. The encoder is always
# constructed before the decoder.
encoders, decoders = (
    (DAENet.Dense_Encoders_Intrinsic(opt),
     DAENet.Dense_DecodersIntegralWarper2_Intrinsic(opt))
    if opt.useDense else
    (DAENet.Encoders_Intrinsic(opt),
     DAENet.DecodersIntegralWarper2_Intrinsic(opt))
)

# NOTE: a DAENet.LightingTransfer module used to be built (and moved to the
# GPU) alongside these networks; it is currently disabled.

# Move both networks to the GPU when CUDA is requested.
if opt.cuda:
    encoders.cuda()
    decoders.cuda()

# Optional reload of a previously trained model when opt.modelPath is set.
if not opt.modelPath == '':
    # TODO(review): despite the message below, this branch does not load any
    # weights -- it only prints opt.modelPath (the original note here was
    # "rewrite here"). The load_encoder / load_decoder helpers of this file
    # are defined further down, so they cannot be called from this point as-is.
    print('Reload previous model at: ' + opt.modelPath)
    # NOTE(review): the function below looks like the body of
    # DenseEncoder.__init__ pasted in without its class header. Here it is
    # just a function named ``__init__`` bound at module level whenever this
    # branch runs; it is never called in the visible code. Its default
    # arguments reference ``nn`` (e.g. nn.LeakyReLU), so defining it requires
    # ``nn`` to be in scope here, and calling it would require a
    # ``DenseEncoder`` class -- confirm both, or remove this dead code.
    def __init__(self,
                 nc=10,
                 ndf=32,
                 ndim=128,
                 activation=nn.LeakyReLU,
                 args=[0.2, False],
                 f_activation=nn.Sigmoid,
                 f_args=[],
                 norm_layer=nn.BatchNorm2d):
        """Build a DenseNet-style encoder stack as ``self.main``.

        Args:
            nc: Number of input channels (in_channels of the first Conv2d).
            ndf: Base channel width; later stages use multiples of it.
            ndim: Stored as ``self.ndim``; passed as the second channel
                argument of the last transition block (presumably its output
                channel count -- confirm in DAENet).
            activation: Activation class handed to each transition block.
            args: Positional arguments handed along with ``activation``.
            f_activation: Activation class instantiated last in ``self.main``.
            f_args: Positional arguments for ``f_activation``.
            norm_layer: Normalization class handed to the dense and
                transition blocks. The stem below hard-codes nn.BatchNorm2d
                and nn.ReLU instead of using norm_layer / activation.
        """
        super(DenseEncoder, self).__init__()
        self.ndim = ndim
        self.main = nn.Sequential(
            # input is (nc) x 256 x 256
            nn.Conv2d(nc, ndf, 7, stride=2, padding=3),
            nn.BatchNorm2d(ndf),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),

            # input is (ndf) x 64 x 64
            DAENet.DenseBlockEncoder(ndf, 4, norm_layer=norm_layer),
            DAENet.DenseTransitionBlockEncoder(ndf,
                                               ndf * 2,
                                               2,
                                               activation=activation,
                                               args=args,
                                               norm_layer=norm_layer),

            # input is (ndf*2) x 32 x 32
            DAENet.DenseBlockEncoder(ndf * 2, 6, norm_layer=norm_layer),
            DAENet.DenseTransitionBlockEncoder(ndf * 2,
                                               ndf * 4,
                                               2,
                                               activation=activation,
                                               args=args,
                                               norm_layer=norm_layer),

            # input is (ndf*4) x 16 x 16
            DAENet.DenseBlockEncoder(ndf * 4, 12, norm_layer=norm_layer),
            DAENet.DenseTransitionBlockEncoder(ndf * 4,
                                               ndf * 8,
                                               2,
                                               activation=activation,
                                               args=args,
                                               norm_layer=norm_layer),

            # input is (ndf*8) x 8 x 8
            DAENet.DenseBlockEncoder(ndf * 8, 24, norm_layer=norm_layer),
            DAENet.DenseTransitionBlockEncoder(ndf * 8,
                                               ndf * 8,
                                               2,
                                               activation=activation,
                                               args=args,
                                               norm_layer=norm_layer),

            # input is (ndf*8) x 4 x 4
            DAENet.DenseBlockEncoder(ndf * 8, 16, norm_layer=norm_layer),
            DAENet.DenseTransitionBlockEncoder(ndf * 8,
                                               ndim,
                                               4,
                                               activation=activation,
                                               args=args,
                                               norm_layer=norm_layer),
            f_activation(*f_args),
        )
def load_decoder(path):
    """Build the dense intrinsic decoder and load its weights from ``path``.

    The model is always built as
    ``DAENet.Dense_DecodersIntegralWarper2_Intrinsic(opt)``, regardless of
    ``opt.useDense``, so the checkpoint must be a state_dict for that class.

    Args:
        path: Filesystem path of the saved state_dict.

    Returns:
        The decoder after passing through ``setCuda`` and ``eval()``.
    """
    model = DAENet.Dense_DecodersIntegralWarper2_Intrinsic(opt)
    # Deserialize onto the CPU: load_state_dict copies the values into the
    # model's existing parameters anyway, and this lets checkpoints saved on
    # a GPU be read on CPU-only machines. setCuda (defined elsewhere)
    # presumably handles GPU placement afterwards.
    model.load_state_dict(torch.load(path, map_location='cpu'))
    model, = setCuda(model)
    model.eval()
    return model
def load_encoder(path):
    """Build the dense intrinsic encoder and load its weights from ``path``.

    The model is always built as ``DAENet.Dense_Encoders_Intrinsic(opt)``,
    regardless of ``opt.useDense``, so the checkpoint must be a state_dict
    for that class.

    Args:
        path: Filesystem path of the saved state_dict.

    Returns:
        The encoder after passing through ``setCuda`` and ``eval()``.
    """
    model = DAENet.Dense_Encoders_Intrinsic(opt)
    # Deserialize onto the CPU: load_state_dict copies the values into the
    # model's existing parameters anyway, and this lets checkpoints saved on
    # a GPU be read on CPU-only machines. setCuda (defined elsewhere)
    # presumably handles GPU placement afterwards.
    model.load_state_dict(torch.load(path, map_location='cpu'))
    model, = setCuda(model)
    model.eval()
    return model