def get_model(args):
    """Construct the invertible-ResNet model described by ``args``.

    When ``args.multiScale`` is truthy a ``multiscale_iResNet`` is built,
    with its hyper-parameters passed positionally; otherwise a single-scale
    ``iResNet`` is built with keyword arguments. Both variants receive the
    same ``actnorm``, ``learn_prior`` and ``nonlin`` settings.

    Args:
        args: options namespace (presumably produced by argparse — confirm)
            providing multiScale, nBlocks, nStrides, nChannels, init_ds,
            inj_pad, coeff, densityEstimation, nClasses, numTraceSamples,
            numSeriesTerms, powerIterSpectralNorm, noActnorm, fixedPrior
            and nonlin.

    Returns:
        The constructed model instance.

    NOTE(review): ``in_shape`` is looked up in the enclosing/module scope,
    not taken from ``args``. If this function shares a module with another
    ``get_model`` definition, the later definition shadows it — confirm.
    """
    # The CLI exposes these as negative flags ("no actnorm", "fixed prior"),
    # so invert them once for both constructors.
    use_actnorm = not args.noActnorm
    learn_prior = not args.fixedPrior

    if not args.multiScale:
        return iResNet(
            nBlocks=args.nBlocks,
            nStrides=args.nStrides,
            nChannels=args.nChannels,
            nClasses=args.nClasses,
            init_ds=args.init_ds,
            inj_pad=args.inj_pad,
            in_shape=in_shape,
            coeff=args.coeff,
            numTraceSamples=args.numTraceSamples,
            numSeriesTerms=args.numSeriesTerms,
            n_power_iter=args.powerIterSpectralNorm,
            density_estimation=args.densityEstimation,
            actnorm=use_actnorm,
            learn_prior=learn_prior,
            nonlin=args.nonlin,
        )

    # multiscale_iResNet takes its hyper-parameters positionally; the order
    # below must match that constructor's signature. The multi-scale model
    # receives the boolean ``init_ds == 2`` rather than the raw value.
    return multiscale_iResNet(
        in_shape,
        args.nBlocks,
        args.nStrides,
        args.nChannels,
        args.init_ds == 2,
        args.inj_pad,
        args.coeff,
        args.densityEstimation,
        args.nClasses,
        args.numTraceSamples,
        args.numSeriesTerms,
        args.powerIterSpectralNorm,
        actnorm=use_actnorm,
        learn_prior=learn_prior,
        nonlin=args.nonlin,
    )
def get_model(args):
    """Construct the multi-scale invertible-ResNet model described by ``args``.

    Only the multi-scale architecture is supported by this version. When
    ``args.multiScale`` is falsy, a message is printed to stdout and the
    process is terminated with a non-zero exit status.

    Args:
        args: options namespace (presumably produced by argparse — confirm)
            providing multiScale, nBlocks, nStrides, nChannels, doAttention,
            init_ds, coeff, nClasses, numTraceSamples, numSeriesTerms,
            powerIterSpectralNorm, noActnorm, nonlin and use_label.

    Returns:
        The constructed ``multiscale_iResNet`` instance.

    Raises:
        SystemExit: with exit status 1 when ``args.multiScale`` is falsy.
    """
    if not args.multiScale:
        # The single-scale ``iResNet`` path is disabled in this version.
        print("Only multiscale model supported.")
        # Raise SystemExit(1) rather than calling the ``site``-injected
        # ``exit()``. Bare ``exit()`` reports success (status 0), closes
        # sys.stdin, and is not defined when Python runs without ``site``.
        raise SystemExit(1)

    # ``in_shape`` comes from the enclosing/module scope, not from ``args``.
    # The positional order below must match multiscale_iResNet's signature.
    model = multiscale_iResNet(
        in_shape,
        args.nBlocks,
        args.nStrides,
        args.nChannels,
        args.doAttention,
        args.init_ds == 2,  # boolean flag, not the raw init_ds value
        args.coeff,
        args.nClasses,
        args.numTraceSamples,
        args.numSeriesTerms,
        args.powerIterSpectralNorm,
        actnorm=(not args.noActnorm),  # CLI flag is negative ("no actnorm")
        nonlin=args.nonlin,
        use_label=args.use_label,
    )
    return model