def create_network(self):
    """Build the single-channel N3Net model described by ``self.args``.

    The network maps 1 input channel to 1 output channel. Its DnCNN-style
    blocks are configured from ``args.dncnn`` (with ``residual`` forced to
    True). Its non-local (N3) blocks are configured from the ``args.nl_*``
    options and ``args.embedcnn``.

    Returns:
        The ``n3net.N3Net`` instance. The top-level ``residual`` flag is False.
    """
    args = self.args
    ninchannels = 1
    noutchannels = 1

    n3block_opt = dict(
        k=args.nl_k,
        patchsize=args.nl_patchsize,
        stride=args.nl_stride,
        temp_opt=args.nl_temp,
        embedcnn_opt=args.embedcnn,
    )

    # Copy before overriding: the original aliased ``args.dncnn`` and wrote
    # ``residual`` into the caller's config in place. Assumes ``args.dncnn``
    # is a mapping (it is indexed with ["residual"]) -- TODO confirm.
    dncnn_opt = dict(args.dncnn)
    dncnn_opt["residual"] = True

    net = n3net.N3Net(
        ninchannels,
        noutchannels,
        args.nfeatures_interm,
        nblocks=args.ndncnn,
        block_opt=dncnn_opt,
        nl_opt=n3block_opt,
        residual=False,
    )
    return net
def create_network(self):
    """Build an N3Net model whose channel counts depend on ``self.args``.

    Output channels are 4 when ``args.bayer`` is truthy and 1 otherwise.
    Input channels equal the output channels, plus 2 extra planes when
    ``args.inputnoisemap`` is truthy. The DnCNN-style blocks come from
    ``args.dncnn`` (with ``residual`` forced to True). The non-local (N3)
    blocks come from the ``args.nl_*`` options and ``args.embedcnn``.

    Returns:
        The ``n3net.N3Net`` instance. Its first block has ``nplanes_residual``
        set to the number of output channels.
    """
    args = self.args
    noutchannels = 4 if args.bayer else 1
    ninchannels = noutchannels
    if args.inputnoisemap:
        # Two additional input planes; presumably a noise-level map -- the
        # exact content is defined by the data pipeline, not visible here.
        ninchannels += 2

    n3block_opt = dict(
        k=args.nl_k,
        patchsize=args.nl_patchsize,
        stride=args.nl_stride,
        temp_opt=args.nl_temp,
        embedcnn_opt=args.embedcnn,
    )

    # Copy before overriding: the original aliased ``args.dncnn`` and wrote
    # ``residual`` into the caller's config in place. Assumes ``args.dncnn``
    # is a mapping (it is indexed with ["residual"]) -- TODO confirm.
    dncnn_opt = dict(args.dncnn)
    dncnn_opt["residual"] = True

    net = n3net.N3Net(
        ninchannels,
        noutchannels,
        args.nfeatures_interm,
        nblocks=args.ndncnn,
        block_opt=dncnn_opt,
        nl_opt=n3block_opt,
        residual=False,
    )
    # NOTE(review): presumably limits the first block's residual connection
    # to the first ``noutchannels`` input planes, so that any extra input
    # planes are not added to the output -- confirm against n3net.
    net.blocks[0].nplanes_residual = noutchannels
    return net