def __init__(self, config_ms):
    """
    Build the head used for the RGB input: an `edsr.MeanShift` followed by a `Head` with Cin=3.

    :param config_ms: multiscale config. `config_ms.enc.cls` is checked here; the config object
        itself is forwarded unchanged to `Head`.
    :raises AssertionError: if 'Subsampling' is contained in `config_ms.enc.cls`.
    """
    super(RGBHead, self).__init__()

    # Subsampling encoders are expected to be paired with an identity head, not with this one.
    assert 'Subsampling' not in config_ms.enc.cls, 'For Subsampling encoders, head should be ID'

    # NOTE(review): presumably zero mean / 128 "std" rescales the RGB input -- confirm against
    # edsr.MeanShift, which is not visible from here.
    mean_shift = edsr.MeanShift(0, (0., 0., 0.), (128., 128., 128.))
    rgb_head = Head(config_ms, Cin=3)

    self.head = nn.Sequential(mean_shift, rgb_head)
    self._repr = 'MeanShift//Head(C=3)'
def __init__(self, config_ms):
    """
    Assemble the per-scale heads, nets and probability classifiers.

    Registers three `nn.ModuleList`s (`self.heads`, `self.nets`, `self.prob_clfs`), each with
    `config_ms.num_scales` entries, plus `self.sub_rgb_mean`.

    :param config_ms: multiscale config. Read here: `rgb_bicubic_baseline`, `dec.skip`,
        `num_scales` and, for the non-baseline model with more than one scale, `q.C`. The object is
        also forwarded to every submodule that is constructed and stored as `self.config_ms`.
    """
    super(MultiscaleNetwork, self).__init__()

    rgb_baseline = config_ms.rgb_bicubic_baseline

    # Set for the RGB baselines
    self._rgb = rgb_baseline  # if set, make sure no backprop through sub_mean
    # True for L3C and RGB, not for RGB Shared
    self._fuse_feat = config_ms.dec.skip
    self._show_input = global_config.get('showinp', False)

    # For the first scale, where input is RGB with C=3
    channel_means = (0.4488, 0.4371, 0.4040)
    channel_stds = (1.0, 1.0, 1.0)
    self.sub_rgb_mean = edsr.MeanShift(255., channel_means, channel_stds)  # to interval -128, 128

    num_scales = config_ms.num_scales
    self.scales = num_scales
    self.config_ms = config_ms

    # NOTES about naming: See README

    # In both branches the modules are constructed in the order heads -> nets -> classifiers.
    if rgb_baseline:
        print('*** Multiscale RGB Pyramid')
        # For RGB Baselines, we feed subsampled version of RGB directly to the next subsampler
        # (see Fig A2, A3 in appendix of paper). Thus, the heads are just identity.
        head_modules = []
        for _ in range(num_scales):
            head_modules.append(pe.LambdaModule(lambda x: x, name='ID'))

        net_modules = []
        for scale in range(num_scales):
            net_modules.append(Net(config_ms, scale))

        classifier_modules = []
        for _ in range(num_scales):
            classifier_modules.append(AtrousProbabilityClassifier(config_ms, C=3))
    else:
        # Heads are used to make the code work for L3C as well as the RGB baselines.
        # For RGB, each encoder gets a bicubically downsampled RGB image as input, with 3 channels.
        # Otherwise, the encoder gets the final feature before the quantizer, with Cf channels.
        # The Heads map either of these to Cf channels, such that encoders always get a feature
        # map with Cf channels.
        head_modules = [RGBHead(config_ms)]
        for scale in range(num_scales - 1):
            head_modules.append(Head(config_ms, Cin=self.get_Cin_for_scale(scale)))

        net_modules = []
        for scale in range(num_scales):
            net_modules.append(Net(config_ms, scale))

        # First classifier uses C=3; every further one is built with config_ms.q.C.
        classifier_modules = [AtrousProbabilityClassifier(config_ms, C=3)]
        for _ in range(num_scales - 1):
            classifier_modules.append(AtrousProbabilityClassifier(config_ms, config_ms.q.C))

    self.heads = nn.ModuleList(head_modules)
    self.nets = nn.ModuleList(net_modules)
    self.prob_clfs = nn.ModuleList(classifier_modules)  # len == #scales

    self.extra_repr_str = 'scales={} / {} nets / {} ps'.format(
        self.scales, len(self.nets), len(self.prob_clfs))