def __init__(self, **kwargs):
    """Build a stack of transposed convolutions that upsamples a 1x1 input.

    Layer layout (spatial sizes follow from the shape comments: kernel k,
    stride 1, no padding grows the size by k - 1):

    * ``Concat`` (presumably joins the inputs, e.g. noise and condition --
      confirm against ``Concat``'s definition)
    * ``Conv2DTranspose(numhid * 8, 4)``: 1x1 -> 4x4
    * six ``Conv2DTranspose(c, 5)`` layers: 4 -> 8 -> 12 -> 16 -> 20 -> 24 -> 28
    * ``Conv2DTranspose(outnum, 5)``: 28 -> 32
    * optional output activation selected by ``outact``

    Every hidden transposed convolution is followed by an optional
    ``BatchNorm`` and a ``ReLU``.

    Keyword Args:
        outnum (int): Channels of the final transposed convolution. Default 1.
        outact (str or None): ``'relu'``, ``'clip'``, ``'tanh'`` or
            ``'batchnorm'``. Any other value, including the default ``None``,
            appends ``Identity``. Unrecognised names are not rejected.
        numhid (int): Base hidden channel count. Default 64.
        clip (sequence of two floats): ``[low, high]`` used when
            ``outact == 'clip'``. Default ``[-1., 1.]``.
        use_bias (bool): Passed to every ``Conv2DTranspose``. Default False.
        use_bnorm (bool): Insert ``BatchNorm`` after each hidden transposed
            convolution. Default True, which reproduces the fixed layout this
            constructor always built.
        patest (dict): Overrides for the ``estimators`` key per layer kind
            (``relu``, ``out``, ``pixel``, ``gauss``). All default to
            ``'linear'``. ``pixel`` is accepted but unused here.
        explain (dict): Overrides for the ``explain`` argument per layer kind.
            Defaults are ``relu='zplus'``, ``out='zclip'``, ``pixel='zb'`` and
            ``gauss='wsquare'``. ``pixel`` is unused here.

    Any remaining keyword arguments are forwarded to the parent constructor.
    """
    outnum = kwargs.pop('outnum', 1)
    outact = kwargs.pop('outact', None)
    numhid = kwargs.pop('numhid', 64)
    clip = kwargs.pop('clip', [-1., 1.])
    use_bias = kwargs.pop('use_bias', False)
    # Same option as the sibling generator; default keeps the old layout.
    use_bnorm = kwargs.pop('use_bnorm', True)

    patest = dict(relu='linear', out='linear', pixel='linear', gauss='linear')
    patest.update(kwargs.pop('patest', {}))
    explain = dict(relu='zplus', out='zclip', pixel='zb', gauss='wsquare')
    explain.update(kwargs.pop('explain', {}))
    super().__init__(**kwargs)

    with self.name_scope():
        self.add(Concat())

        # Same as Dense + reshape since we're coming from 1x1.
        # 1x1 -> (numhid * 8) x 4 x 4
        self.add(
            Conv2DTranspose(numhid * 8, 4, strides=1, padding=0,
                            use_bias=use_bias, explain=explain['gauss'],
                            regimes=estimators[patest['gauss']]()))
        if use_bnorm:
            self.add(BatchNorm())
        self.add(ReLU())

        # Each 5x5, stride-1, unpadded transposed conv adds 4 to the spatial
        # size: 4 -> 8 -> 12 -> 16 -> 20 -> 24 -> 28.
        # Order (and therefore auto-generated layer names) matches the
        # original hand-unrolled sequence.
        for channels in (numhid * 4, numhid * 2, numhid * 2,
                         numhid, numhid, numhid):
            self.add(
                Conv2DTranspose(channels, 5, strides=1, padding=0,
                                output_padding=0, use_bias=use_bias,
                                explain=explain['relu'],
                                regimes=estimators[patest['relu']]()))
            if use_bnorm:
                self.add(BatchNorm())
            self.add(ReLU())

        # numhid x 28 x 28 -> outnum x 32 x 32
        self.add(
            Conv2DTranspose(outnum, 5, strides=1, padding=0,
                            output_padding=0, use_bias=use_bias,
                            explain=explain['out'],
                            regimes=estimators[patest['out']](),
                            noregconst=-1.))

        if outact == 'relu':
            self.add(ReLU())
        elif outact == 'clip':
            self.add(Clip(low=clip[0], high=clip[1]))
        elif outact == 'tanh':
            self.add(Tanh())
        elif outact == 'batchnorm':
            # Two layers here (vs. one in the other branches), kept as in the
            # original.
            self.add(BatchNorm(scale=False, center=False))
            self.add(Identity())
        else:
            self.add(Identity())
def __init__(self, **kwargs):
    """Build a stack of transposed convolutions that upsamples 1x1 to 28x28.

    Layer layout. Spatial sizes use the transposed-conv relation
    ``out = (in - 1) * stride - 2 * padding + kernel``, which matches the
    original shape comments:

    * ``Concat`` (presumably joins the inputs, e.g. noise and condition --
      confirm against ``Concat``'s definition)
    * ``Conv2DTranspose(numhid * 8, 4, stride 1, pad 0)``: 1 -> 4
    * ``Conv2DTranspose(numhid * 4, 4, stride 1, pad 0)``: 4 -> 7
    * ``Conv2DTranspose(numhid * 2, 4, stride 2, pad 1)``: 7 -> 14
    * ``Conv2DTranspose(outnum, 4, stride 2, pad 1)``: 14 -> 28
    * optional output activation selected by ``outact``

    Every hidden transposed convolution is followed by an optional
    ``BatchNorm`` and a ``ReLU``.

    Keyword Args:
        outnum (int): Channels of the final transposed convolution. Default 1.
        outact (str or None): ``'relu'``, ``'clip'``, ``'tanh'`` or
            ``'batchnorm'``. Any other value, including the default ``None``,
            appends ``Identity``. Unrecognised names are not rejected.
        numhid (int): Base hidden channel count. Default 64.
        clip (sequence of two floats): ``[low, high]`` used when
            ``outact == 'clip'``. Default ``[-1., 1.]``.
        use_bias (bool): Passed to every ``Conv2DTranspose``. Default False.
        use_bnorm (bool): Insert ``BatchNorm`` after each hidden transposed
            convolution. Default True.
        patest (dict): Overrides for the ``estimators`` key per layer kind
            (``relu``, ``out``, ``pixel``, ``gauss``). All default to
            ``'linear'``. ``pixel`` is accepted but unused here.
        explain (dict): Overrides for the ``explain`` argument per layer kind.
            Defaults are ``relu='zplus'``, ``out='zclip'``, ``pixel='zb'`` and
            ``gauss='wsquare'``. ``pixel`` is unused here.

    Any remaining keyword arguments are forwarded to the parent constructor.
    """
    outnum = kwargs.pop('outnum', 1)
    outact = kwargs.pop('outact', None)
    numhid = kwargs.pop('numhid', 64)
    clip = kwargs.pop('clip', [-1., 1.])
    use_bias = kwargs.pop('use_bias', False)
    use_bnorm = kwargs.pop('use_bnorm', True)

    patest = dict(relu='linear', out='linear', pixel='linear', gauss='linear')
    patest.update(kwargs.pop('patest', {}))
    explain = dict(relu='zplus', out='zclip', pixel='zb', gauss='wsquare')
    explain.update(kwargs.pop('explain', {}))
    super().__init__(**kwargs)

    with self.name_scope():
        self.add(Concat())

        # 1x1 -> (numhid * 8) x 4 x 4   (kernel 4, stride 1, no padding)
        self.add(
            Conv2DTranspose(numhid * 8, 4, strides=1, padding=0,
                            use_bias=use_bias, explain=explain['gauss'],
                            regimes=estimators[patest['gauss']]()))
        if use_bnorm:
            self.add(BatchNorm())
        self.add(ReLU())

        # 4x4 -> (numhid * 4) x 7 x 7   (kernel 4, stride 1, no padding)
        self.add(
            Conv2DTranspose(numhid * 4, 4, strides=1, padding=0,
                            use_bias=use_bias, explain=explain['relu'],
                            regimes=estimators[patest['relu']]()))
        if use_bnorm:
            self.add(BatchNorm())
        self.add(ReLU())

        # 7x7 -> (numhid * 2) x 14 x 14   (kernel 4, stride 2, padding 1)
        self.add(
            Conv2DTranspose(numhid * 2, 4, strides=2, padding=1,
                            use_bias=use_bias, explain=explain['relu'],
                            regimes=estimators[patest['relu']]()))
        if use_bnorm:
            self.add(BatchNorm())
        self.add(ReLU())

        # 14x14 -> outnum x 28 x 28   (kernel 4, stride 2, padding 1)
        self.add(
            Conv2DTranspose(outnum, 4, strides=2, padding=1,
                            use_bias=use_bias, explain=explain['out'],
                            regimes=estimators[patest['out']](),
                            noregconst=-1.))

        if outact == 'relu':
            self.add(ReLU())
        elif outact == 'clip':
            self.add(Clip(low=clip[0], high=clip[1]))
        elif outact == 'tanh':
            self.add(Tanh())
        elif outact == 'batchnorm':
            # Two layers here (vs. one in the other branches), kept as in the
            # original.
            self.add(BatchNorm(scale=False, center=False))
            self.add(Identity())
        else:
            self.add(Identity())