def __init__(self, in_channels, out_channels, kernel_size, prm, use_bias=False, stride=1, padding=0, dilation=1):
    """Construct a stochastic 2-D convolution layer.

    Stores the standard conv hyper-parameters, builds the stochastic
    parameter tensors via ``create_stochastic_layer`` and initializes
    them with ``init_stochastic_conv2d``.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: int or pair; an int ``k`` is expanded to ``(k, k)``
            by ``make_pair``.
        prm: hyper-parameter object; only ``prm.log_var_init`` is read
            here (presumably it also configures ``create_stochastic_layer``
            — TODO confirm against that helper).
        use_bias: whether to create a stochastic bias term.
        stride, padding, dilation: standard convolution options, stored
            for use at forward time.
    """
    super(StochasticConv2d, self).__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.use_bias = use_bias
    self.stride = stride
    self.padding = padding
    self.dilation = dilation
    kernel_size = make_pair(kernel_size)  # k --> (k, k)
    self.kernel_size = kernel_size
    weights_size = (out_channels, in_channels, kernel_size[0], kernel_size[1])
    if use_bias:
        # BUG FIX: the original wrote ``(out_channels)``, which is just an
        # int (no trailing comma), not a 1-tuple like ``weights_size`` is.
        bias_size = (out_channels,)
    else:
        bias_size = None
    self.create_stochastic_layer(weights_size, bias_size, prm)
    init_stochastic_conv2d(self, prm.log_var_init)
    # Std of the reparameterization noise; 0 disables sampling until a
    # caller sets it (the original had a commented-out 1.0 alternative).
    self.eps_std = 0
def init_module(m, log_var_init):
    """Initialize the parameters of module ``m`` in place, by type.

    Stock layers get the classic fan-in uniform scheme
    (``U(-1/sqrt(n), 1/sqrt(n))``); BatchNorm gets unit weight / zero
    bias; the project's stochastic layers are delegated to their own
    initializers with ``log_var_init``.

    NOTE: the branch order matters — plain ``nn.Conv2d``/``nn.Linear``
    are tested before the stochastic variants.
    """
    if isinstance(m, nn.Conv2d):
        # Standard Conv2d: fan-in = receptive field area * input channels.
        fan_in = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
        bound = 1. / math.sqrt(fan_in)
        m.weight.data.uniform_(-bound, bound)
        if m.bias is not None:
            m.bias.data.uniform_(-bound, bound)
    elif isinstance(m, nn.Linear):
        # Standard Linear: fan-in = number of input features.
        bound = 1. / math.sqrt(m.weight.size(1))
        m.weight.data.uniform_(-bound, bound)
        if m.bias is not None:
            m.bias.data.uniform_(-bound, bound)
    elif isinstance(m, nn.BatchNorm2d):
        # Identity affine transform to start with.
        m.weight.data.fill_(1)
        m.bias.data.zero_()
    elif isinstance(m, StochasticConv2d):
        init_stochastic_conv2d(m, log_var_init)
    elif isinstance(m, StochasticLinear):
        init_stochastic_linear(m, log_var_init)