class GraphConvolution(Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
    """

    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.detach().uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.detach().uniform_(-stdv, stdv)

    def forward(self, input, adj):
        support = torch.mm(input, self.weight)
        output = torch.spmm(adj, support)
        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.in_features) + ' -> ' \
            + str(self.out_features) + ')'
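
# Usage sketch (not from the original source): one GraphConvolution layer on a
# toy 4-node graph. The (normalized) adjacency is passed as a sparse tensor so
# torch.spmm applies.
layer = GraphConvolution(in_features=16, out_features=8)
x = torch.randn(4, 16)                                    # 4 nodes, 16 features
edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])   # toy edges
adj = torch.sparse_coo_tensor(edge_index, torch.ones(4), (4, 4))
out = layer(x, adj)                                       # shape: (4, 8)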
class NoisyLinear(nn.Module):

    def __init__(self, in_features, out_features, bias=True):
        super(NoisyLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = torch.Tensor(out_features, in_features)
        self.weight_epsilon = torch.Tensor(out_features, in_features)
        self.weight_mu = Parameter(torch.Tensor(out_features, in_features))
        self.weight_sigma = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = torch.Tensor(out_features)
            self.bias_epsilon = torch.Tensor(out_features)
            self.bias_mu = Parameter(torch.Tensor(out_features))
            self.bias_sigma = Parameter(torch.Tensor(out_features))
        else:
            self.bias = None
            self.bias_epsilon = None
            self.register_parameter('bias_mu', None)
            self.register_parameter('bias_sigma', None)
        self.reset_parameters()
        self.sampled = False

    def sample(self):
        if self.training:
            self.weight_epsilon.normal_()
            self.weight = self.weight_epsilon.mul(self.weight_sigma).add_(self.weight_mu)
            if self.bias is not None:
                self.bias_epsilon.normal_()
                self.bias = self.bias_epsilon.mul(self.bias_sigma).add_(self.bias_mu)
        else:
            self.weight = self.weight_mu.detach()
            if self.bias is not None:
                self.bias = self.bias_mu.detach()
        self.sampled = True

    def reset_parameters(self):
        stdv = math.sqrt(3.0 / self.weight.size(1))
        # The mu/sigma tensors are leaf Parameters, so in-place init must run
        # outside autograd tracking.
        with torch.no_grad():
            self.weight_mu.uniform_(-stdv, stdv)
            self.weight_sigma.fill_(0.017)
            if self.bias is not None:
                self.bias_mu.uniform_(-stdv, stdv)
                self.bias_sigma.fill_(0.017)

    def forward(self, input):
        if not self.sampled:
            self.sample()
        return F.linear(input, self.weight, self.bias)

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.in_features) + ' -> ' \
            + str(self.out_features) + ')'
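
# Usage sketch (illustrative): fresh noise is drawn on each sample() call
# during training; eval mode falls back to the deterministic mean weights.
noisy = NoisyLinear(in_features=8, out_features=4)
x = torch.randn(2, 8)
noisy.train()
noisy.sample()
y_train = noisy(x)   # stochastic weights: mu + sigma * epsilon
noisy.eval()
noisy.sample()
y_eval = noisy(x)    # deterministic weights: mu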
class MF(nn.Module):

    def __init__(self, num_users, num_items, embedding_size):
        super(MF, self).__init__()
        self.users = Parameter(torch.FloatTensor(num_users, embedding_size))
        self.items = Parameter(torch.FloatTensor(num_items, embedding_size))
        self.init_params()

    def init_params(self):
        stdv = 1. / math.sqrt(self.users.size(1))
        self.users.data.uniform_(-stdv, stdv)
        self.items.data.uniform_(-stdv, stdv)

    def pair_forward(self, user, item_p, item_n):
        user = self.users[user]
        item_p = self.items[item_p]
        item_n = self.items[item_n]
        p_score = torch.sum(user * item_p, 2)
        n_score = torch.sum(user * item_n, 2)
        return p_score, n_score

    def point_forward(self, user, item):
        user = self.users[user]
        item = self.items[item]
        score = torch.sum(user * item, 2)
        return score

    def get_item_embeddings(self):
        return self.items.detach().cpu().numpy().astype('float32')

    def get_user_embeddings(self):
        return self.users.detach().cpu().numpy().astype('float32')
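
# Usage sketch (an assumption about the expected index shape): indices arrive
# as (batch, 1) LongTensors so the dot products reduce over dim=2 as written
# in pair_forward. A BPR-style loss over positive/negative pairs:
model = MF(num_users=100, num_items=500, embedding_size=32)
u = torch.randint(0, 100, (16, 1))
i_pos = torch.randint(0, 500, (16, 1))
i_neg = torch.randint(0, 500, (16, 1))
p_score, n_score = model.pair_forward(u, i_pos, i_neg)    # each (16, 1)
bpr_loss = -torch.log(torch.sigmoid(p_score - n_score)).mean()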
class CapsuleShare(ModuleParall):

    def __init__(self, in_feature, out_feature, routings=3, bias=True,
                 retain_grad=False):
        super(CapsuleShare, self).__init__()
        self.in_feature = in_feature
        self.out_feature = out_feature
        self.routings = routings
        if routings < 1:
            raise ValueError('Routings should be at least 1!')
        self.retain_grad = retain_grad
        b = Variable(torch.zeros(1, 1, self.out_feature)).type(FloatTensor)
        self.c = F.softmax(b, 2)
        self.weight = Parameter(torch.Tensor(self.in_feature, 1, self.out_feature))
        if bias:
            self.bias = Parameter(torch.Tensor(self.out_feature))
        else:
            self.bias = None
            self.forward = self.no_bias
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(0))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, u):
        c = self.c
        u_hat = u.unsqueeze(-1) * self.weight
        u_hat = u_hat.view(-1, self.in_feature * u.size(-1), self.out_feature)
        u, bias = (u_hat, self.bias) if self.retain_grad \
            else (u_hat.detach(), self.bias.detach())
        b = 0.
        for _ in range(self.routings - 1):
            v = squash(torch.sum(u * c, dim=1) + bias, dim=1)
            b = b + u * v.view(-1, 1, self.out_feature)
            c = F.softmax(b, 2)
        v = squash(torch.sum(u_hat * c, dim=1) + self.bias, dim=1)
        return v

    def no_bias(self, u):
        c = self.c
        u_hat = u.unsqueeze(-1) * self.weight
        u_hat = u_hat.view(-1, self.in_feature * u.size(-1), self.out_feature)
        u = u_hat if self.retain_grad else u_hat.detach()
        b = 0.
        for _ in range(self.routings - 1):
            v = squash(torch.sum(u * c, dim=1), dim=1)
            b = b + u * v.view(-1, 1, self.out_feature)
            c = F.softmax(b, 2)
        v = squash(torch.sum(u_hat * c, dim=1), dim=1)
        return v

    def extra_repr(self):
        s = ('{in_feature}, {out_feature}, routings={routings}'
             ', bias=' + str(self.bias is not None) + ', retain_grad={retain_grad}')
        return s.format(**self.__dict__)
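
# CapsuleShare calls a squash() helper that is not shown in this snippet. A
# common definition (an assumption, following the dynamic-routing capsule
# literature, not necessarily the original author's version):
def squash(s, dim=-1):
    # Shrink vector norms into [0, 1) while preserving direction.
    sq_norm = torch.sum(s * s, dim=dim, keepdim=True)
    return (sq_norm / (1.0 + sq_norm)) * s / torch.sqrt(sq_norm + 1e-8)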
class Conv2d_filtermap_1x1_compression(_ConvNd_filtermap_1x1_compression):

    def __init__(self, in_channels, out_channels, channel_compression_1x1,
                 kernel_size, binary_filtermap=False, stride=1, padding=0,
                 dilation=1, groups=1, bias=True):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        super(Conv2d_filtermap_1x1_compression, self).__init__(
            in_channels, out_channels, channel_compression_1x1, kernel_size,
            stride, padding, dilation, False, _pair(0), groups, bias)
        self.binary_filtermap = binary_filtermap
        self.reset_parameters1()

    def reset_parameters1(self):
        fm_size = self.filtermap.size()
        fm_width = fm_size[2]
        fm_height = fm_size[1]
        fm_depth = fm_size[0]
        # Pad the spatial dims, except for 1x1 convs.
        if self.filtermap.size()[1] > 1 and self.filtermap.size()[2] > 1:
            self.fm_pad_width = fm_width + 1
            self.fm_pad_height = fm_height + 1
        else:
            self.fm_pad_width = fm_width
            self.fm_pad_height = fm_height
        self.fm_pad_depth = fm_depth * 2

        # Set the flat indices used to extract filters from the filtermap.
        out_channels = self.out_channels
        in_channels = self.in_channels // self.groups
        k_h = self.kernel_size[0]
        k_w = self.kernel_size[1]
        sample_y = self.sample_y
        sample_x = self.sample_x
        sample_c = self.sample_c
        stride_y = self.stride_y
        stride_x = self.stride_x
        stride_c = self.stride_c
        fm_depth = self.fm_pad_depth
        fm_height = self.fm_pad_height
        fm_width = self.fm_pad_width

        # Indices of one k_h x k_w window at the origin ...
        ids = torch.Tensor(range(0, k_h * k_w))
        tmp_count = 0
        for y in range(0, k_h):
            for x in range(0, k_w):
                ids[tmp_count] = y * fm_width + x
                tmp_count = tmp_count + 1
        # ... replicated across the input channels ...
        ids0 = ids
        for c in range(1, in_channels):
            ids_c = ids0 + c * fm_height * fm_width
            ids = torch.cat((ids, ids_c), 0)
        # ... shifted to every spatial sampling position ...
        ids0 = ids
        for y in range(0, sample_y):
            for x in range(0, sample_x):
                if y == 0 and x == 0:
                    continue
                ss = y * stride_y * fm_width + x * stride_x
                ids_ss = ids0 + ss
                ids = torch.cat((ids, ids_ss), 0)
        # ... and to every channel sampling position.
        ids0 = ids
        for c in range(1, sample_c):
            ids_c = ids0 + c * stride_c * fm_height * fm_width
            ids = torch.cat((ids, ids_c), 0)

        ids = ids.long()
        self.ids = Parameter(ids)
        self.ids.requires_grad = False

    def extract_filters(self):
        out_channels = self.out_channels
        in_channels = self.in_channels // self.groups
        k_h = self.kernel_size[0]
        k_w = self.kernel_size[1]
        # Duplicate along the channel dim (no channel compression here).
        filtermap_pad = torch.cat((self.filtermap, self.filtermap), 0)
        # Pad the spatial dims by wrapping row/column 1, except for 1x1 convs.
        if self.filtermap.size()[1] > 1 and self.filtermap.size()[2] > 1:
            filtermap_pad_s1 = filtermap_pad[:, 1, :]
            filtermap_pad_s1 = filtermap_pad_s1[:, None, :]
            filtermap_pad = torch.cat((filtermap_pad, filtermap_pad_s1), 1)
            filtermap_pad_s2 = filtermap_pad[:, :, 1]
            filtermap_pad_s2 = filtermap_pad_s2[:, :, None]
            filtermap_pad = torch.cat((filtermap_pad, filtermap_pad_s2), 2)
        ids = self.ids.detach()
        conv_weight = filtermap_pad.view(-1, 1).index_select(0, ids)
        conv_weight = conv_weight.view(out_channels, in_channels, k_h, k_w)
        if self.binary_filtermap:
            binary_conv_weight = conv_weight.clone()
            for nf in range(0, out_channels):
                float_filter = conv_weight[nf, :, :, :]
                L1_norm = torch.norm(float_filter.view(-1, 1), 1)
                sign_filter = torch.sign(float_filter)
                binary_filter = sign_filter * L1_norm
                binary_conv_weight[nf, :, :, :] = binary_filter
            return binary_conv_weight
        else:
            return conv_weight

    def forward(self, input):
        conv_weight = self.extract_filters()
        out = F.conv2d(input, conv_weight, self.bias, self.stride,
                       self.padding, self.dilation, self.groups)
        return out

    def fit_filtermap1(self, conv):
        conv_weight = conv.weight.data
        conv_weight_size = conv_weight.size()
        out_channels = conv_weight_size[0]
        in_channels = conv_weight_size[1]
        k_h = conv_weight_size[2]
        k_w = conv_weight_size[3]
        sample_y = self.sample_y
        sample_x = self.sample_x
        sample_c = self.sample_c
        stride_y = self.stride_y
        stride_x = self.stride_x
        stride_c = self.stride_c
        # Tile the reference conv filters into an extended filtermap ...
        fm_ext = torch.Tensor(sample_c * in_channels, sample_y * k_h, sample_x * k_w)
        for c in range(0, sample_c):
            for y in range(0, sample_y):
                for x in range(0, sample_x):
                    filter_count = c * sample_y * sample_x + y * sample_x + x
                    fm_ext[c * in_channels:(c + 1) * in_channels,
                           y * k_h:(y + 1) * k_h,
                           x * k_w:(x + 1) * k_w] = conv_weight[filter_count, :, :, :]
        # ... average the overlapping channel blocks ...
        for oc in range(0, sample_c):
            for c in range(1, sample_c):
                idx = sample_c - c + oc
                if idx > sample_c - 1:
                    idx = 0
                fm_ext[oc * stride_c:(oc + 1) * stride_c, :, :] += \
                    fm_ext[idx * stride_c:(idx + 1) * stride_c, :, :]
        fm_ext = fm_ext[0:in_channels, :, :] / sample_c
        # ... average the overlapping spatial borders ...
        fm_ext_h = fm_ext.size()[1]
        fm_ext_w = fm_ext.size()[2]
        for y in range(0, sample_y):
            fm_ext[:, y * k_h, :] += fm_ext[:, (y * k_h - 1 + fm_ext_h) % fm_ext_h, :]
            fm_ext[:, y * k_h, :] = fm_ext[:, y * k_h, :] / 2
        for x in range(0, sample_x):
            fm_ext[:, :, x * k_w] += fm_ext[:, :, (x * k_w - 1 + fm_ext_w) % fm_ext_w]
            fm_ext[:, :, x * k_w] = fm_ext[:, :, x * k_w] / 2
        # ... and subsample back down to the compact filtermap.
        y_ids = torch.Tensor([0, 1])
        y_ids0 = y_ids
        for y in range(1, sample_y):
            y_ids = torch.cat((y_ids, y_ids0 + y * k_h), 0)
        x_ids = torch.Tensor([0, 1])
        x_ids0 = x_ids
        for x in range(1, sample_x):
            x_ids = torch.cat((x_ids, x_ids0 + x * k_w), 0)
        fm = torch.index_select(fm_ext, 1, y_ids.long())
        fm = torch.index_select(fm, 2, x_ids.long())
        self.filtermap.data = fm
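
# The extraction step above, in isolation (illustrative sketch, independent of
# the _ConvNd_filtermap_1x1_compression base class, which is not shown here):
# every convolution filter is a window cut out of one shared "filtermap",
# gathered with a single index_select over the flattened map.
fm = torch.randn(8, 6, 6)                   # shared filtermap (depth, h, w)
win = torch.arange(3 * 3)                   # one 3x3 window at the origin
flat_ids = (win // 3) * 6 + (win % 3)       # row-major offsets into the map
filt = fm.view(-1).index_select(0, flat_ids).view(3, 3)   # equals fm[0, :3, :3]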
class _LearnableFakeQuantize(nn.Module):
    r"""This is an extension of the FakeQuantize module in fake_quantize.py, which
    supports more generalized lower-bit quantization and supports learning of the
    scale and zero point parameters through backpropagation. For literature
    references, please see the class _LearnableFakeQuantizePerTensorOp.

    In addition to the attributes in the original FakeQuantize module, the
    _LearnableFakeQuantize module also includes the following attributes to
    support quantization parameter learning.

    * :attr:`channel_len` defines the length of the channel when initializing scale
      and zero point for the per channel case.

    * :attr:`use_grad_scaling` defines the flag for whether the gradients for scale
      and zero point are normalized by the constant, which is proportional to the
      square root of the number of elements in the tensor. The related literature
      justifying the use of this particular constant can be found here:
      https://openreview.net/pdf?id=rkgO66VKDS.

    * :attr:`fake_quant_enabled` defines the flag for enabling fake quantization on
      the output.

    * :attr:`static_enabled` defines the flag for using observer's static estimation
      for scale and zero point.

    * :attr:`learning_enabled` defines the flag for enabling backpropagation for
      scale and zero point.
    """
    def __init__(self, observer, quant_min=0, quant_max=255, scale=1.,
                 zero_point=0., channel_len=-1, use_grad_scaling=False,
                 **observer_kwargs):
        super(_LearnableFakeQuantize, self).__init__()
        assert quant_min < quant_max, 'quant_min must be strictly less than quant_max.'
        self.quant_min = quant_min
        self.quant_max = quant_max
        self.use_grad_scaling = use_grad_scaling
        if channel_len == -1:
            self.scale = Parameter(torch.tensor([scale]))
            self.zero_point = Parameter(torch.tensor([zero_point]))
        else:
            assert isinstance(channel_len, int) and channel_len > 0, \
                "Channel size must be a positive integer."
            self.scale = Parameter(torch.tensor([scale] * channel_len))
            self.zero_point = Parameter(torch.tensor([zero_point] * channel_len))

        self.activation_post_process = observer(**observer_kwargs)
        assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, \
            'quant_min out of bound'
        assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, \
            'quant_max out of bound'
        self.dtype = self.activation_post_process.dtype
        self.qscheme = self.activation_post_process.qscheme
        self.ch_axis = self.activation_post_process.ch_axis \
            if hasattr(self.activation_post_process, 'ch_axis') else -1
        self.register_buffer('fake_quant_enabled', torch.tensor([1], dtype=torch.uint8))
        self.register_buffer('static_enabled', torch.tensor([1], dtype=torch.uint8))
        self.register_buffer('learning_enabled', torch.tensor([0], dtype=torch.uint8))

        bitrange = torch.tensor(quant_max - quant_min + 1).double()
        self.bitwidth = int(torch.log2(bitrange).item())

    @torch.jit.export
    def enable_param_learning(self):
        r"""Enables learning of quantization parameters and
        disables static observer estimates. Forward path returns fake quantized X.
        """
        self.toggle_qparam_learning(enabled=True) \
            .toggle_fake_quant(enabled=True) \
            .toggle_observer_update(enabled=False)
        return self

    @torch.jit.export
    def enable_static_estimate(self):
        r"""Enables static observer estimates and disables learning of
        quantization parameters. Forward path returns fake quantized X.
        """
        self.toggle_qparam_learning(enabled=False) \
            .toggle_fake_quant(enabled=True) \
            .toggle_observer_update(enabled=True)

    @torch.jit.export
    def enable_static_observation(self):
        r"""Enables static observer accumulating data from input but doesn't
        update the quantization parameters. Forward path returns the original X.
        """
        self.toggle_qparam_learning(enabled=False) \
            .toggle_fake_quant(enabled=False) \
            .toggle_observer_update(enabled=True)

    @torch.jit.export
    def toggle_observer_update(self, enabled=True):
        self.static_enabled[0] = int(enabled)
        return self

    @torch.jit.export
    def toggle_qparam_learning(self, enabled=True):
        self.learning_enabled[0] = int(enabled)
        self.scale.requires_grad = enabled
        self.zero_point.requires_grad = enabled
        return self

    @torch.jit.export
    def toggle_fake_quant(self, enabled=True):
        self.fake_quant_enabled[0] = int(enabled)
        return self

    @torch.jit.export
    def observe_quant_params(self):
        print('_LearnableFakeQuantize Scale: {}'.format(self.scale.detach()))
        print('_LearnableFakeQuantize Zero Point: {}'.format(self.zero_point.detach()))

    @torch.jit.export
    def calculate_qparams(self):
        return self.activation_post_process.calculate_qparams()

    def forward(self, X):
        self.activation_post_process(X.detach())
        _scale, _zero_point = self.calculate_qparams()
        _scale = _scale.to(self.scale.device)
        _zero_point = _zero_point.to(self.zero_point.device)
        if self.static_enabled[0] == 1:
            self.scale.data.copy_(_scale)
            self.zero_point.data.copy_(_zero_point)

        if self.fake_quant_enabled[0] == 1:
            if self.learning_enabled[0] == 1:
                if self.use_grad_scaling:
                    # grad scaling uses the element count of the input tensor
                    grad_factor = 1.0 / (X.numel() * self.quant_max) ** 0.5
                else:
                    grad_factor = 1.0
                if self.qscheme in (torch.per_channel_symmetric,
                                    torch.per_channel_affine):
                    X = _LearnableFakeQuantizePerChannelOp.apply(
                        X, self.scale, self.zero_point, self.ch_axis,
                        self.quant_min, self.quant_max, grad_factor)
                else:
                    X = _LearnableFakeQuantizePerTensorOp.apply(
                        X, self.scale, self.zero_point,
                        self.quant_min, self.quant_max, grad_factor)
            else:
                if self.qscheme == torch.per_channel_symmetric or \
                        self.qscheme == torch.per_channel_affine:
                    X = torch.fake_quantize_per_channel_affine(
                        X, self.scale, self.zero_point, self.ch_axis,
                        self.quant_min, self.quant_max)
                else:
                    X = torch.fake_quantize_per_tensor_affine(
                        X, float(self.scale.item()), int(self.zero_point.item()),
                        self.quant_min, self.quant_max)

        return X

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        # We will be saving the static state of scale (instead of as a dynamic param).
        super(_LearnableFakeQuantize, self)._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + 'scale'] = self.scale.data
        destination[prefix + 'zero_point'] = self.zero_point

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        local_state = ['scale', 'zero_point']
        for name in local_state:
            key = prefix + name
            if key in state_dict:
                val = state_dict[key]
                if name == 'scale':
                    self.scale.data.copy_(val)
                else:
                    setattr(self, name, val)
            elif strict:
                missing_keys.append(key)
        super(_LearnableFakeQuantize, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    with_args = classmethod(_with_args)
class LabelFts(nn.Module):
    """
    Class for encoding label features

    Arguments
    ----------
    input_size: int
        embeddings dimension for the label representation
    label_features: int
        number of tokens in the label text
    padding_idx: int (default=None)
    device: string (default="cuda:0")
    sparse: bool (default=False)
    fixed_features: bool (default=False)
    use_external_weights: bool (default=False)
    transform: nn.Module
        transformation over label features

    Returns:
        nn.Module
        Network block to encode label features
    """

    def __init__(self, input_size, label_features, device="cuda:0",
                 sparse=False, fixed_features=False,
                 use_external_weights=False, transform=None):
        super(LabelFts, self).__init__()
        self.device = device  # Useful in case of multiple GPUs
        self.input_size = input_size
        self.label_features = label_features
        self.use_external_weights = use_external_weights
        self.fixed_features = fixed_features
        if not self.use_external_weights:
            self.weight = Parameter(
                torch.Tensor(self.label_features, self.input_size))
            self.sparse = True
        else:
            self.sparse = sparse
        self.Rpp = transform
        self.reset_parameters()

    def _get_clf(self, labels, features_shortlist=None, weights=None):
        if features_shortlist is not None:
            weights = F.embedding(features_shortlist,
                                  weights,
                                  sparse=self.sparse,
                                  padding_idx=None).squeeze()
        lbl_clf = labels.to(weights.device).mm(weights)
        return lbl_clf

    def forward(self, labels, features_shortlist=None, weight=None):
        if self.fixed_features:
            return self.Rpp(labels)
        else:
            if not self.use_external_weights:
                weight = self.weight
            return self.Rpp(self._get_clf(labels, features_shortlist, weight))

    def to(self, device=None):
        if device is None:
            super().to(self.device)
        else:
            super().to(device)

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.input_size)
        if not self.use_external_weights:
            self.weight.data.uniform_(-stdv, stdv)

    def __repr__(self):
        s = "{name}({input_size}, {label_features}, {sparse}, {device}, {fixed_features}"
        if not self.use_external_weights:
            s += ", weight={}".format(self.weight.detach().cpu().numpy().shape)
        s += "\n%s" % (self.Rpp.__repr__())
        s += ")"
        return s.format(name=self.__class__.__name__, **self.__dict__)

    def _init_(self, state_dict):
        """
        Initializes model parameters
        """
        if not self.use_external_weights:
            keys = list(state_dict.keys())
            key = [x for x in keys if x.split(".")[-1] in ['weight']][0]
            weight = state_dict[key]
            self.weight.data.copy_(weight)
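
# Usage sketch (assumptions: a dense label-token matrix and an identity
# transform standing in for the Rpp block; real usage would pass sparse labels
# and a learned transform module).
label_enc = LabelFts(input_size=64, label_features=1000,
                     device="cpu", transform=torch.nn.Identity())
labels = torch.rand(32, 1000)        # (labels_in_batch, label_features)
lbl_repr = label_enc(labels)         # shape: (32, 64)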
class Linear(torch.nn.Module):
    r"""Applies a linear transformation to the incoming data

    .. math::
        \mathbf{x}^{\prime} = \mathbf{x} \mathbf{W}^{\top} + \mathbf{b}

    similar to :class:`torch.nn.Linear`.
    It supports lazy initialization and customizable weight and bias
    initialization.

    Args:
        in_channels (int): Size of each input sample. Will be initialized
            lazily in case it is given as :obj:`-1`.
        out_channels (int): Size of each output sample.
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        weight_initializer (str, optional): The initializer for the weight
            matrix (:obj:`"glorot"`, :obj:`"uniform"`, :obj:`"kaiming_uniform"`
            or :obj:`None`).
            If set to :obj:`None`, will match default weight initialization of
            :class:`torch.nn.Linear`. (default: :obj:`None`)
        bias_initializer (str, optional): The initializer for the bias vector
            (:obj:`"zeros"` or :obj:`None`).
            If set to :obj:`None`, will match default bias initialization of
            :class:`torch.nn.Linear`. (default: :obj:`None`)

    Shapes:
        - **input:** features :math:`(*, F_{in})`
        - **output:** features :math:`(*, F_{out})`
    """
    def __init__(self, in_channels: int, out_channels: int, bias: bool = True,
                 weight_initializer: Optional[str] = None,
                 bias_initializer: Optional[str] = None):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.weight_initializer = weight_initializer
        self.bias_initializer = bias_initializer

        if in_channels > 0:
            self.weight = Parameter(torch.Tensor(out_channels, in_channels))
        else:
            self.weight = nn.parameter.UninitializedParameter()
            self._hook = self.register_forward_pre_hook(
                self.initialize_parameters)

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self._load_hook = self._register_load_state_dict_pre_hook(
            self._lazy_load_hook)

        self.reset_parameters()

    def __deepcopy__(self, memo):
        out = Linear(self.in_channels, self.out_channels,
                     self.bias is not None, self.weight_initializer,
                     self.bias_initializer)
        if self.in_channels > 0:
            out.weight = copy.deepcopy(self.weight, memo)
        if self.bias is not None:
            out.bias = copy.deepcopy(self.bias, memo)
        return out

    def reset_parameters(self):
        if self.in_channels <= 0:
            pass
        elif self.weight_initializer == 'glorot':
            inits.glorot(self.weight)
        elif self.weight_initializer == 'uniform':
            bound = 1.0 / math.sqrt(self.weight.size(-1))
            torch.nn.init.uniform_(self.weight.data, -bound, bound)
        elif self.weight_initializer == 'kaiming_uniform':
            inits.kaiming_uniform(self.weight, fan=self.in_channels,
                                  a=math.sqrt(5))
        elif self.weight_initializer is None:
            inits.kaiming_uniform(self.weight, fan=self.in_channels,
                                  a=math.sqrt(5))
        else:
            raise RuntimeError(f"Linear layer weight initializer "
                               f"'{self.weight_initializer}' is not supported")

        if self.bias is None or self.in_channels <= 0:
            pass
        elif self.bias_initializer == 'zeros':
            inits.zeros(self.bias)
        elif self.bias_initializer is None:
            inits.uniform(self.in_channels, self.bias)
        else:
            raise RuntimeError(f"Linear layer bias initializer "
                               f"'{self.bias_initializer}' is not supported")

    def forward(self, x: Tensor) -> Tensor:
        r"""
        Args:
            x (Tensor): The features.
        """
        return F.linear(x, self.weight, self.bias)

    @torch.no_grad()
    def initialize_parameters(self, module, input):
        if is_uninitialized_parameter(self.weight):
            self.in_channels = input[0].size(-1)
            self.weight.materialize((self.out_channels, self.in_channels))
            self.reset_parameters()
        self._hook.remove()
        delattr(self, '_hook')

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        if is_uninitialized_parameter(self.weight):
            destination[prefix + 'weight'] = self.weight
        else:
            destination[prefix + 'weight'] = self.weight.detach()
        if self.bias is not None:
            destination[prefix + 'bias'] = self.bias.detach()

    def _lazy_load_hook(self, state_dict, prefix, local_metadata, strict,
                        missing_keys, unexpected_keys, error_msgs):
        weight = state_dict[prefix + 'weight']
        if is_uninitialized_parameter(weight):
            self.in_channels = -1
            self.weight = nn.parameter.UninitializedParameter()
            if not hasattr(self, '_hook'):
                self._hook = self.register_forward_pre_hook(
                    self.initialize_parameters)
        elif is_uninitialized_parameter(self.weight):
            self.in_channels = weight.size(-1)
            self.weight.materialize((self.out_channels, self.in_channels))
            if hasattr(self, '_hook'):
                self._hook.remove()
                delattr(self, '_hook')

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, bias={self.bias is not None})')
class Model:

    def __init__(self, config, args):
        self.saver = save.Saver(config.NPOP, config.MODELDIR,
                                'models', 'bests', resetTol=256)
        self.config, self.args = config, args
        self.init()
        if self.config.LOAD or self.config.BEST:
            self.load(self.config.BEST)

    def init(self):
        print('Initializing new model...')
        if self.config.SHAREINIT:
            self.shared(self.config.NPOP)
        else:
            self.unshared(self.config.NPOP)
        self.params = Parameter(torch.Tensor(np.array(self.models)))
        self.opt = None
        if not self.config.TEST:
            self.opt = ManualAdam([self.params], lr=0.001, weight_decay=0.00001)

    # Initialize a new network
    def initModel(self):
        return getParameters(trinity.ANN(self.config))

    def shared(self, n):
        model = self.initModel()
        self.models = [model for _ in range(n)]

    def unshared(self, n):
        self.models = [self.initModel() for _ in range(n)]

    # Grads and clip
    def stepOpt(self, gradDicts):
        grads = defaultdict(list)
        for gradDict in gradDicts:
            for worker, grad in gradDict.items():
                grads[worker].append(grad)
        for worker, gradList in grads.items():
            grad = np.array(gradList)
            grad = np.mean(grad, 0)
            grad = np.clip(grad, -5, 5)
            grads[worker] = grad
        gradAry = torch.zeros_like(self.params)
        for worker, grad in grads.items():
            gradAry[worker] = torch.Tensor(grad)
        self.opt.step(gradAry)

    def checkpoint(self, reward):
        if self.config.TEST:
            return
        self.saver.checkpoint(self.params, self.opt, reward)

    def load(self, best=False):
        print('Loading model...')
        epoch = self.saver.load(self.opt, self.params, best)

    @property
    def nParams(self):
        nParams = sum([len(e) for e in self.model])
        print('#Params: ', str(nParams / 1000), 'K')

    @property
    def model(self):
        return self.params.detach().numpy()
class MimoLinearDynamicalOperator(torch.nn.Module):
    r"""Applies a multi-input-multi-output linear dynamical filtering operation.

    Args:
        in_channels (int): Number of input channels
        out_channels (int): Number of output channels
        n_b (int): Number of learnable coefficients of the transfer function numerator
        n_a (int): Number of learnable coefficients of the transfer function denominator
        n_k (int, optional): Number of input delays in the numerator. Default: 0

    Shape:
        - Input: (batch_size, seq_len, in_channels)
        - Output: (batch_size, seq_len, out_channels)

    Attributes:
        b_coeff (Tensor): The learnable coefficients of the transfer function numerator
        a_coeff (Tensor): The learnable coefficients of the transfer function denominator

    Examples::

        >>> in_channels, out_channels = 2, 4
        >>> n_b, n_a, n_k = 2, 2, 1
        >>> G = MimoLinearDynamicalOperator(in_channels, out_channels, n_b, n_a, n_k)
        >>> batch_size, seq_len = 32, 100
        >>> u_in = torch.ones((batch_size, seq_len, in_channels))
        >>> y_out = G(u_in)  # shape: (batch_size, seq_len, out_channels)
    """

    def __init__(self, in_channels, out_channels, n_b, n_a, n_k=0):
        super(MimoLinearDynamicalOperator, self).__init__()
        self.b_coeff = Parameter(torch.zeros(out_channels, in_channels, n_b))
        self.a_coeff = Parameter(torch.zeros(out_channels, in_channels, n_a))
        self.out_channels = out_channels
        self.in_channels = in_channels
        self.n_a = n_a
        self.n_b = n_b
        self.n_k = n_k
        with torch.no_grad():
            init_range = 0.01
            self.a_coeff[:] = (torch.rand(self.a_coeff.shape) - 0.5) * 2 * init_range
            self.b_coeff[:] = (torch.rand(self.b_coeff.shape) - 0.5) * 2 * init_range

    def forward(self, u_in, y_0=None, u_0=None):
        if self.n_k != 0:
            # input sequence with delay: shift right by n_k steps on the time axis
            u_d = torch.empty_like(u_in)
            u_d[..., self.n_k:, :] = u_in[:, :-self.n_k, :]
            u_d[..., 0:self.n_k, :] = 0.0
        else:
            u_d = u_in
        return MimoLinearDynamicalOperatorFun.apply(self.b_coeff, self.a_coeff, u_d, y_0, u_0)

    def get_filtdata(self):
        r"""Returns the numerator and denominator coefficients of the transfer function
        :math:`q^{-1}`-polynomials.

        The polynomials are functions of the variable :math:`q^{-1}`.
        The polynomial coefficients b and a have length m and n, respectively,
        and are sorted in descending power order.

        For a certain input channel :math:`i` and output channel :math:`o`, the
        corresponding transfer function :math:`G_{i\rightarrow o}(z)` is:

        .. math::
            G_{i\rightarrow o}(z) = q^{-n_k}\frac{b[o, i, 0] + b[o, i, 1]q^{-1} + \dots + b[o, i, m-1]q^{-m+1}}
            {a[o, i, 0] + a[o, i, 1]q^{-1} + \dots + a[o, i, n-1]q^{-n+1}}

        Returns:
            np.array(out_channels, in_channels, m), np.array(out_channels, in_channels, n):
                numerator b and denominator a polynomial coefficients of the
                transfer function.

        Examples::

            >>> b_seq, a_seq = G.get_filtdata()
        """
        return self.__get_filtdata__()

    def get_tfdata(self):
        r"""Returns the numerator and denominator coefficients of the transfer function
        :math:`z`-polynomials.

        The polynomials are functions of the Z-transform variable :math:`z`.
        The polynomial coefficients :math:`\beta` and :math:`\alpha` have equal
        length p and are sorted in descending power order.

        For a certain input channel :math:`i` and output channel :math:`o`, the
        corresponding transfer function :math:`G_{i\rightarrow o}(z)` is:

        .. math::
            G_{i\rightarrow o}(z) = \frac{\beta[o, i, 0]z^{p-1} + \beta[o, i, 1]z^{p-2} + \dots + \beta[o, i, p-1]}
            {\alpha[o, i, 0]z^{p-1} + \alpha[o, i, 1]z^{p-2} + \dots + \alpha[o, i, p-1]}

        Returns:
            np.array(out_channels, in_channels, p), np.array(out_channels, in_channels, p):
                numerator :math:`\beta` and denominator :math:`\alpha` polynomial
                coefficients of the transfer function.

        Examples::

            >>> num, den = G.get_tfdata()
            >>> G_tf = control.TransferFunction(num, den, ts=1.0)
        """
        return self.__get_tfdata__()

    def __get_filtdata__(self):
        # returns the coefficients of the polynomials b and a as functions of q^{-1}
        b_coeff_np, a_coeff_np = self.__get_ba_coeff__()
        b_seq = np.zeros_like(b_coeff_np,
                              shape=(self.out_channels, self.in_channels, self.n_b + self.n_k))
        b_seq[:, :, self.n_k:] = b_coeff_np[:, :, :]
        a_seq = np.empty_like(a_coeff_np,
                              shape=(self.out_channels, self.in_channels, self.n_a + 1))
        a_seq[:, :, 0] = 1
        a_seq[:, :, 1:] = a_coeff_np[:, :, :]
        return b_seq, a_seq

    def __get_tfdata__(self):
        b_seq, a_seq = self.__get_filtdata__()
        M = self.n_b + self.n_k  # number of numerator coefficients of the q^{-1} polynomial
        N = self.n_a + 1  # number of denominator coefficients of the q^{-1} polynomial
        if M > N:
            num = b_seq
            den = np.c_[a_seq, np.zeros((self.out_channels, self.in_channels, M - N))]
        elif N > M:
            num = np.c_[b_seq, np.zeros((self.out_channels, self.in_channels, N - M))]
            den = a_seq
        else:  # N == M
            num = b_seq
            den = a_seq
        return num, den

    def __get_ba_coeff__(self):
        return self.b_coeff.detach().numpy(), self.a_coeff.detach().numpy()
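
# Usage sketch (an assumption: the companion MimoLinearDynamicalOperatorFun
# autograd Function from the surrounding package is needed only for forward,
# so the coefficient accessors below run on their own).
G = MimoLinearDynamicalOperator(in_channels=2, out_channels=4, n_b=2, n_a=2, n_k=1)
b_seq, a_seq = G.get_filtdata()     # q^{-1} polynomials, shapes (4, 2, 3) each
num, den = G.get_tfdata()           # z-domain polynomials, padded to equal length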
class NESConv2d(nn.Conv2d):

    def __init__(self, in_planes, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, multiplier=1.0,
                 spatial_multiplier=1., rep_dim=1, repeat_weight=True,
                 use_coeff=False):
        if rep_dim == 0:
            # repeat along the channel dim (dimension 0 of the weight tensor)
            super(NESConv2d, self).__init__(
                int(in_planes), int(np.ceil(out_channels / multiplier)),
                int(kernel_size * spatial_multiplier), stride=stride,
                padding=padding, dilation=dilation, groups=groups, bias=bias)
            self.rep_time = int(np.ceil(1. * out_channels / self.weight.shape[0]))
        elif rep_dim == 1:
            # repeat along the filter dim (dimension 1 of the weight tensor)
            super(NESConv2d, self).__init__(
                int(np.ceil(in_planes / multiplier)), int(out_channels),
                int(kernel_size * spatial_multiplier), stride=stride,
                padding=padding, dilation=dilation, groups=groups, bias=bias)
            self.rep_time = int(np.ceil(1. * in_planes / self.weight.shape[1]))
        self.in_planes = in_planes
        self.out_channels_ori = out_channels
        self.groups = groups
        self.multiplier = multiplier
        self.spatial_multiplier = spatial_multiplier
        # specify the range for the w and h direction
        self.kernel_size = kernel_size
        self.w_range = kernel_size * (spatial_multiplier - 1)
        self.h_range = kernel_size * (spatial_multiplier - 1)
        self.rep_dim = rep_dim
        self.repeat_weight = repeat_weight
        self.use_coeff = use_coeff
        # A small side network predicts the sharing coefficients; with a
        # spatial multiplier it predicts 3 values (channel mix + w/h offsets)
        # per repetition, otherwise one value per repetition.
        if spatial_multiplier > 1:
            out_num = int(self.rep_time * 3)
        else:
            out_num = int(self.rep_time)
        self.conv1_stride_lr_1 = nn.Conv2d(in_planes, in_planes, kernel_size=3,
                                           stride=2, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu = nn.ReLU6(inplace=True)
        self.conv1_stride_lr_2 = nn.Conv2d(in_planes, out_num, kernel_size=1,
                                           stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(out_num)
        self.coefficient = Parameter(torch.Tensor(out_num), requires_grad=False)
        self.reuse = False
        self.coeff_grad = None

    def generate_share_weight(self, base_weight, rep_num, coeff, nchannel, dim=0):
        '''Sample weights from the base weight.'''
        if rep_num == 1:
            return base_weight
        new_weight = []
        for i in range(rep_num):
            # fractional spatial offsets, indexed from the coefficient vector
            w_idx = coeff[i * 3 + 1] * self.w_range
            start_idx_w = int(w_idx)
            end_idx_w = start_idx_w + self.kernel_size
            w_frac = w_idx - start_idx_w
            h_idx = coeff[i * 3 + 2] * self.h_range
            start_idx_h = int(h_idx)
            end_idx_h = start_idx_h + self.kernel_size
            h_frac = h_idx - start_idx_h
            new_weight_temp = torch.cat(
                [base_weight[:, :, :, :], base_weight[:, :, :, :]],
                dim=2)[:, :, start_idx_w:end_idx_w, :] * (1 - w_frac) \
                + base_weight * w_frac
            new_weight_temp = torch.cat(
                [base_weight[:, :, :, :], base_weight[:, :, :, :]],
                dim=3)[:, :, :, start_idx_h:end_idx_h] * (1 - h_frac) \
                + base_weight * h_frac
            if dim == 0:
                new_weight_temp = torch.cat(
                    [base_weight[1:, :, :, :], base_weight[0:1, :, :, :]],
                    dim=0) * (1 - coeff[int(i * 3)])
            else:
                new_weight_temp = torch.cat(
                    [base_weight[:, 1:, :, :], base_weight[:, 0:1, :, :]],
                    dim=1) * (1 - coeff[i])
            new_weight.append(base_weight * coeff[int(i * 3)] + new_weight_temp)
        out = torch.cat(new_weight, dim=dim)
        if dim == 0:
            return out[:nchannel, :, :, :]
        else:
            return out[:, :nchannel, :, :]

    def forward(self, x):
        ih, iw = x.size()[-2:]
        kh, kw = self.weight.size()[-2:]
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2,
                          pad_h // 2, pad_h - pad_h // 2])
        # generate_share_feature / feature_wrapper are defined elsewhere in the
        # original source.
        if self.use_coeff:
            if self.training:
                # set reuse to True for coefficient sharing
                if not self.reuse:
                    lr_conv1 = self.relu(self.bn1(self.conv1_stride_lr_1(x)))
                    lr_conv1 = self.bn2(self.conv1_stride_lr_2(lr_conv1))
                    lr_conv1 = F.adaptive_avg_pool2d(lr_conv1, (1, 1))[:, :, 0, 0]
                    self.coefficient.set_(
                        F.normalize(torch.mean(lr_conv1, 0), dim=0).clone().detach())
                    self.coeff_grad = F.normalize(torch.mean(lr_conv1, 0), dim=0)
                if self.rep_dim == 0:
                    if self.repeat_weight:
                        out = F.conv2d(
                            x, self.generate_share_weight(
                                self.weight, self.rep_time, self.coeff_grad,
                                self.out_channels_ori, dim=0))
                    else:
                        out_tmp = F.conv2d(x, self.weight)
                        out = self.generate_share_feature(
                            self.rep_time, out_tmp,
                            F.normalize(torch.mean(lr_conv1, 0), dim=0),
                            self.out_channels_ori)
                else:
                    if self.repeat_weight:
                        out = F.conv2d(
                            x, self.generate_share_weight(
                                self.weight, self.rep_time, self.coeff_grad,
                                x.shape[1], dim=1))
                    else:
                        out_tmp = self.feature_wrapper(
                            self.rep_time, x,
                            F.normalize(torch.mean(lr_conv1, 0), dim=0),
                            self.out_channels_ori)
                        out = F.conv2d(out_tmp, self.weight)
            else:
                if self.rep_dim == 0:
                    if self.repeat_weight:
                        out = F.conv2d(
                            x, self.generate_share_weight(
                                self.weight, self.rep_time,
                                self.coefficient.detach(),
                                self.out_channels_ori, dim=0))
                    else:
                        out_tmp = F.conv2d(x, self.weight)
                        out = self.generate_share_feature(
                            self.rep_time, out_tmp, self.coefficient.detach(),
                            self.out_channels_ori)
                else:
                    if self.repeat_weight:
                        out = F.conv2d(
                            x, self.generate_share_weight(
                                self.weight, self.rep_time,
                                self.coefficient.detach(), x.shape[1], dim=1))
                    else:
                        out_tmp = self.feature_wrapper(
                            self.rep_time, x, self.coefficient.detach(),
                            self.out_channels_ori)
                        out = F.conv2d(out_tmp, self.weight)
        else:
            if self.rep_dim == 0:
                out = F.conv2d(
                    x, self.weight.repeat([self.rep_time, 1, 1, 1])
                    [:self.out_channels_ori, :, :, :], None, 1)
            else:
                out = F.conv2d(
                    x, self.weight.repeat([1, self.rep_time, 1, 1])
                    [:, :x.shape[1], :, :], None, 1)
        return out
class Embedding(torch.nn.Module):
    """
    General way to handle embeddings

    * Support for sequential models
    * Memory efficient way to compute weighted EmbeddingBag

    Arguments:
    ----------
    num_embeddings: int
        vocabulary size
    embedding_dim: int
        dimension for embeddings
    padding_idx: 0 or None, optional (default=None)
        index for <PAD>; embedding is not updated
    max_norm: None or float, optional (default=None)
        maintain norm of embeddings
    norm_type: int, optional (default=2)
        norm for max_norm
    scale_grad_by_freq: boolean, optional (default=False)
        Scale gradients by token frequency
    sparse: boolean, optional (default=False)
        sparse or dense gradients
        * the optimizer will infer from this parameter
    reduction: str or None, optional (default=None)
        * None: don't reduce
        * sum: sum over tokens
        * mean: mean over tokens
    pretrained_weights: torch.Tensor or None, optional (default=None)
        Initialize with these weights
        * first token is treated as a padding index
        * dim=1 should be one less than the num_embeddings
    device: str, optional (default="cuda:0")
        Keep embeddings on this device
    """

    def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
                 max_norm=None, norm_type=2, scale_grad_by_freq=False,
                 sparse=False, reduction=None, pretrained_weights=None,
                 device="cuda:0"):
        super(Embedding, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim))
        self.sparse = sparse
        self.reduce = self._construct_reduce(reduction)
        self.reduction = reduction
        self.device = torch.device(device)
        self.reset_parameters()
        if pretrained_weights is not None:
            self.from_pretrained(pretrained_weights)

    def _construct_reduce(self, reduction):
        if reduction is None:
            return self._reduce
        elif reduction == 'sum':
            return self._reduce_sum
        elif reduction == 'mean':
            return self._reduce_mean
        else:
            raise NotImplementedError(f"Unknown reduction: {reduction}")

    def reset_parameters(self):
        """
        Reset weights
        """
        torch.nn.init.xavier_uniform_(
            self.weight.data,
            gain=torch.nn.init.calculate_gain('relu'))
        if self.padding_idx is not None:
            self.weight.data[self.padding_idx].fill_(0)

    def to(self):
        super().to(self.device)

    def _reduce_sum(self, x, w):
        if w is None:
            return torch.sum(x, dim=1)
        else:
            return torch.sum(x * w.unsqueeze(2), dim=1)

    def _reduce_mean(self, x, w):
        if w is None:
            return torch.mean(x, dim=1)
        else:
            return torch.mean(x * w.unsqueeze(2), dim=1)

    def _reduce(self, x, *args):
        return x

    def forward(self, x, w=None):
        """
        Forward pass for embedding layer

        Arguments:
        ---------
        x: torch.LongTensor
            indices of tokens in a batch
            (batch_size, max_features_in_a_batch)
        w: torch.Tensor or None, optional (default=None)
            weights of tokens in a batch
            (batch_size, max_features_in_a_batch)

        Returns:
        --------
        out: torch.Tensor
            embedding for each sample
            Shape: (batch_size, seq_len, embedding_dims), if reduction is None
            Shape: (batch_size, embedding_dims), otherwise
        """
        x = F.embedding(x, self.weight, self.padding_idx, self.max_norm,
                        self.norm_type, self.scale_grad_by_freq, self.sparse)
        return self.reduce(x, w)

    def from_pretrained(self, embeddings):
        # first index is treated as padding index
        assert embeddings.shape[0] == self.num_embeddings - 1, \
            "Shapes don't match for pre-trained embeddings"
        self.weight.data[1:, :] = torch.from_numpy(embeddings)

    def get_weights(self):
        return self.weight.detach().cpu().numpy()[1:, :]

    def __repr__(self):
        s = '{name}({num_embeddings}, {embedding_dim}, {device}'
        s += ', reduction={reduction}'
        if self.padding_idx is not None:
            s += ', padding_idx={padding_idx}'
        if self.max_norm is not None:
            s += ', max_norm={max_norm}'
        if self.norm_type != 2:
            s += ', norm_type={norm_type}'
        if self.scale_grad_by_freq is not False:
            s += ', scale_grad_by_freq={scale_grad_by_freq}'
        if self.sparse is not False:
            s += ', sparse=True'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
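
# Usage sketch: weighted sum-reduction over token embeddings, i.e. an
# EmbeddingBag-style computation done in two steps (lookup, then reduce).
emb = Embedding(num_embeddings=5000, embedding_dim=64, padding_idx=0,
                reduction='sum', device="cpu")
tokens = torch.randint(1, 5000, (8, 20))    # token indices per sample
weights = torch.rand(8, 20)                 # per-token weights
out = emb(tokens, weights)                  # shape: (8, 64)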
class NeuroRNN(nn.Module):

    def __init__(self, input_size, output_dim, alpha_r, alpha_s, nonlinearity,
                 hidden_size, bias, ratio):
        super(NeuroRNN, self).__init__()
        self.input_size = input_size
        self.nonlinearity = nonlinearity
        self.hidden_size = hidden_size
        self.alpha_r = alpha_r
        self.alpha_s = alpha_s
        self.bias = bias
        self.output_dim = output_dim
        self.Win = Parameter(torch.Tensor(hidden_size, input_size))
        self.Win.requires_grad = True
        self.Wrec = Parameter(torch.Tensor(hidden_size, hidden_size))
        self.Wrec.requires_grad = True
        self.Wout = Parameter(torch.Tensor(output_dim, hidden_size))
        self.Wout.requires_grad = True
        self.ratio = ratio
        if bias:
            self.bin = Parameter(torch.Tensor(hidden_size, hidden_size))
            self.brec = Parameter(torch.Tensor(hidden_size))
        # init weights
        nn.init.orthogonal_(self.Win)
        nn.init.orthogonal_(self.Wrec)
        nn.init.uniform_(self.Wout)

    def nonlin(self, inp):
        # apply the nonlinearity functionally (nn.ReLU(inp) would only
        # construct a module, not apply it)
        if self.nonlinearity == 'relu':
            return torch.relu(inp)
        else:
            return torch.tanh(inp)

    def forward(self, x):
        # x shape: (batch_size, seq_len, input_size)
        I = torch.zeros(self.hidden_size, x.size(0)).to(device)  # (hidden_size, batch_size)
        r = torch.zeros(self.hidden_size, x.size(0)).to(device)
        out = torch.zeros(x.size(0), x.size(1), self.output_dim).to(device)  # (batch_size, seq_len, output_dim)
        for t in range(x.size(1)):
            I = (1 - self.alpha_s) * I + self.alpha_s * (
                torch.mm(self.Wrec, r) + torch.mm(self.Win, x[:, t].T))
            r = (1 - self.alpha_r) * r + self.alpha_r * self.nonlin(I)
            out[:, t, :] = torch.mm(self.Wout, r).T
        del I
        del r
        return out

    def dale_weight_init(self):
        with torch.no_grad():
            num_exc = int(self.ratio[0] * self.hidden_size)
            num_inh = self.hidden_size - num_exc
            D = torch.diag_embed(
                torch.cat((torch.ones(num_exc), -1 * torch.ones(num_inh))))
            self.Wrec = torch.nn.Parameter(
                torch.abs(self.Wrec.detach()).matmul(D))

    def enforce_dale(self):
        with torch.no_grad():
            num_exc = int(self.ratio[0] * self.hidden_size)
            # clamp_ (in-place) so the constraint actually modifies Wrec
            self.Wrec[:num_exc, :].clamp_(min=0)
            self.Wrec[num_exc:, :].clamp_(max=0)
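
# Usage sketch (assumption: forward() reads a module-level `device` global,
# which we pin to CPU here; the 80/20 ratio is an illustrative Dale split).
device = torch.device("cpu")
rnn = NeuroRNN(input_size=3, output_dim=2, alpha_r=0.2, alpha_s=0.5,
               nonlinearity='tanh', hidden_size=64, bias=False, ratio=(0.8, 0.2))
rnn.dale_weight_init()              # sign-structured recurrent weights
x = torch.randn(16, 50, 3)          # (batch, seq_len, input_size)
y = rnn(x)                          # (16, 50, 2)
rnn.enforce_dale()                  # re-impose the sign constraint after updates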
class Model:
    '''Model manager class

    Convenience class wrapping saving/loading, model initialization,
    optimization, and logging.

    Args:
        ann: Model to optimize. Used to initialize weights.
        config: A Config specification
    '''

    def __init__(self, ann, config):
        self.saver = save.Saver(config.MODELDIR, 'models', 'bests', resetTol=256)
        self.config = config
        print('Initializing new model...')
        self.net = ann(config)
        self.parameters = Parameter(
            torch.Tensor(np.array(getParameters(self.net))))
        # Have been experimenting with population based
        # training. Nothing stable yet -- advise avoiding
        if config.POPOPT:
            self.opt = PopulationOptimizer(self, config)
        else:
            self.opt = GradientOptimizer(self, config)
        if config.LOAD or config.BEST:
            self.load(self.opt, config.BEST)

    def step(self, recvs, blobs, log, lifetime):
        if self.config.TEST:
            return
        self.opt.step(recvs, blobs)
        perf, _ = self.checkpoint(self.opt, lifetime)
        return perf

    def load(self, opt, best=False):
        '''Load a model from file

        Args:
            best (bool): Whether to load the best (True)
                or most recent (False) checkpoint
        '''
        print('Loading model...')
        epoch = self.saver.load(opt, self.parameters, best)
        self.syncParameters()
        return self

    def checkpoint(self, opt, lifetime):
        '''Save the model to checkpoint

        Args:
            lifetime: Mean lifetime of the model
        '''
        return self.saver.checkpoint(self.parameters, opt, lifetime)

    def printParams(self):
        '''Print the number of model parameters'''
        nParams = len(self.weights)
        print('#Params: ', str(nParams / 1000), 'K')

    def syncParameters(self):
        parameters = self.parameters.detach().numpy().tolist()
        setParameters(self.net, parameters)

    @property
    def weights(self):
        '''Get model parameters

        Returns: a numpy array of model parameters
        '''
        return getParameters(self.net)
class Embedding(th.nn.Module):

    def __init__(self):
        super(Embedding, self).__init__()

    def get_embed_list(self, sent_pair: list) -> np.ndarray:
        raise NotImplementedError("get_embed_list must be defined by a subclass")

    def get_similarity(self, X1: np.ndarray, X2: np.ndarray) -> np.ndarray:
        similarity = X1 @ self.weight.detach().numpy() @ (X2.T)
        return similarity

    def normalize(self):
        self.weight = Parameter(f.normalize(self.weight))

    def apply_distortion(self, sim_matrix: np.ndarray, ratio: float = 0.5) -> np.ndarray:
        shape = sim_matrix.shape
        if (shape[0] < 2 or shape[1] < 2) or ratio == 0.0:
            return sim_matrix
        pos_x = np.array([[y / float(shape[1] - 1) for y in range(shape[1])]
                          for x in range(shape[0])])
        pos_y = np.array([[x / float(shape[0] - 1) for x in range(shape[0])]
                          for y in range(shape[1])])
        distortion_mask = -abs((pos_x - np.transpose(pos_y))) * ratio
        return sim_matrix + distortion_mask

    def eta_U(self, f, e, distortion=False):
        data = [f, e]
        src_embedding, trg_embedding = self.get_embed_list(data)
        data_save = {}
        scale = min(len(src_embedding), len(trg_embedding))
        data_save['scale'] = scale
        data_save['src_embedding'] = src_embedding
        data_save['trg_embedding'] = trg_embedding
        similarity = self.get_similarity(src_embedding, trg_embedding)
        self.save_for_backward(data_save)
        similarity = similarity - self.maxvalue - np.log(self.sum_exp_sim)
        if distortion:
            similarity = self.apply_distortion(similarity)
        return similarity

    def re_compute_sum_exp_norm(self):
        temp = self.src_words_vec @ self.weight.detach().numpy() @ (self.trg_words_vec.T)
        self.maxvalue = np.max(temp)
        exp_value = np.exp(temp - self.maxvalue)
        self.sum_exp_sim = np.sum(exp_value)
        softmax_value = exp_value / self.sum_exp_sim
        self.softmax_gradient = self.src_words_vec.T @ softmax_value @ self.trg_words_vec
        print(self.sum_exp_sim)

    def forward(self, data: list):
        f, e = data
        similarity = self.eta_U(f, e)
        return th.Tensor(similarity)

    def backward(self, g):
        data_save = self.saved_tensors
        temp = data_save['scale'] * self.softmax_gradient
        second = data_save['src_embedding'].T @ g.numpy() @ data_save['trg_embedding']
        self.weight.backward(th.Tensor(temp - second))

    def save_for_backward(self, t):
        self.saved_tensors = t
class _ConvBnNd(nn.modules.conv._ConvNd, nni._FusedModule):

    _version = 2
    _FLOAT_MODULE = MOD

    def __init__(self,
                 # ConvNd args
                 in_channels, out_channels, kernel_size, stride,
                 padding, dilation, transposed, output_padding,
                 groups, bias, padding_mode,
                 # BatchNormNd args
                 # num_features: out_channels
                 eps=1e-05, momentum=0.1,
                 # affine: True
                 # track_running_stats: True
                 # Args for this module
                 freeze_bn=False, qconfig=None, dim=2):
        nn.modules.conv._ConvNd.__init__(self, in_channels, out_channels, kernel_size,
                                         stride, padding, dilation, transposed,
                                         output_padding, groups, False, padding_mode)
        assert qconfig, 'qconfig must be provided for QAT module'
        self.qconfig = qconfig
        self.freeze_bn = freeze_bn if self.training else True
        self.bn = _BN_CLASS_MAP[dim](out_channels, eps, momentum, True, True)
        self.weight_fake_quant = self.qconfig.weight()
        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_bn_parameters()

        # this needs to be called after reset_bn_parameters,
        # as they modify the same state
        if self.training:
            if freeze_bn:
                self.freeze_bn_stats()
            else:
                self.update_bn_stats()
        else:
            self.freeze_bn_stats()

    def reset_running_stats(self):
        self.bn.reset_running_stats()

    def reset_bn_parameters(self):
        self.bn.reset_running_stats()
        init.uniform_(self.bn.weight)
        init.zeros_(self.bn.bias)
        # note: below is actually for conv, not BN
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def reset_parameters(self):
        super(_ConvBnNd, self).reset_parameters()

    def update_bn_stats(self):
        self.freeze_bn = False
        self.bn.training = True
        return self

    def freeze_bn_stats(self):
        self.freeze_bn = True
        self.bn.training = False
        return self

    def _forward(self, input):
        assert self.bn.running_var is not None
        running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
        scale_factor = self.bn.weight / running_std
        weight_shape = [1] * len(self.weight.shape)
        weight_shape[0] = -1
        bias_shape = [1] * len(self.weight.shape)
        bias_shape[1] = -1
        scaled_weight = self.weight_fake_quant(self.weight * scale_factor.reshape(weight_shape))
        # using zero bias here since the bias for original conv
        # will be added later
        if self.bias is not None:
            zero_bias = torch.zeros_like(self.bias)
        else:
            zero_bias = torch.zeros(self.out_channels, device=scaled_weight.device)
        conv = self._conv_forward(input, scaled_weight, zero_bias)
        conv_orig = conv / scale_factor.reshape(bias_shape)
        if self.bias is not None:
            conv_orig = conv_orig + self.bias.reshape(bias_shape)
        conv = self.bn(conv_orig)
        return conv

    def extra_repr(self):
        # TODO(jerryzh): extend
        return super(_ConvBnNd, self).extra_repr()

    def forward(self, input):
        return self._forward(input)

    def train(self, mode=True):
        """
        Batchnorm's training behavior is using the self.training flag. Prevent
        changing it if BN is frozen. This makes sure that calling `model.train()`
        on a model with a frozen BN will behave properly.
        """
        self.training = mode
        if not self.freeze_bn:
            for module in self.children():
                module.train(mode)
        return self

    # ===== Serialization version history =====
    #
    # Version 1/None
    #   self
    #   |--- weight : Tensor
    #   |--- bias : Tensor
    #   |--- gamma : Tensor
    #   |--- beta : Tensor
    #   |--- running_mean : Tensor
    #   |--- running_var : Tensor
    #   |--- num_batches_tracked : Tensor
    #
    # Version 2
    #   self
    #   |--- weight : Tensor
    #   |--- bias : Tensor
    #   |--- bn : Module
    #        |--- weight : Tensor (moved from v1.self.gamma)
    #        |--- bias : Tensor (moved from v1.self.beta)
    #        |--- running_mean : Tensor (moved from v1.self.running_mean)
    #        |--- running_var : Tensor (moved from v1.self.running_var)
    #        |--- num_batches_tracked : Tensor (moved from v1.self.num_batches_tracked)
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get('version', None)
        if version is None or version == 1:
            # BN related parameters and buffers were moved into the BN module for v2
            v2_to_v1_names = {
                'bn.weight': 'gamma',
                'bn.bias': 'beta',
                'bn.running_mean': 'running_mean',
                'bn.running_var': 'running_var',
                'bn.num_batches_tracked': 'num_batches_tracked',
            }
            for v2_name, v1_name in v2_to_v1_names.items():
                if prefix + v1_name in state_dict:
                    state_dict[prefix + v2_name] = state_dict[prefix + v1_name]
                    state_dict.pop(prefix + v1_name)
                elif prefix + v2_name in state_dict:
                    # there was a brief period where forward compatibility
                    # for this module was broken (between
                    # https://github.com/pytorch/pytorch/pull/38478
                    # and https://github.com/pytorch/pytorch/pull/38820)
                    # and modules emitted the v2 state_dict format while
                    # specifying that version == 1. This patches the forward
                    # compatibility issue by allowing the v2 style entries to
                    # be used.
                    pass
                elif strict:
                    missing_keys.append(prefix + v2_name)

        super(_ConvBnNd, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys,
            unexpected_keys, error_msgs)

    @classmethod
    def from_float(cls, mod):
        r"""Create a qat module from a float module or qparams_dict

        Args: `mod` a float module, either produced by torch.ao.quantization
        utilities or directly from user
        """
        # The ignore is because _FLOAT_MODULE is a TypeVar here where the bound
        # has no __name__ (code is fine though)
        assert type(mod) == cls._FLOAT_MODULE, 'qat.' + cls.__name__ + \
            '.from_float only works for ' + cls._FLOAT_MODULE.__name__  # type: ignore[attr-defined]
        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
        assert mod.qconfig, 'Input float module must have a valid qconfig'
        qconfig = mod.qconfig
        conv, bn = mod[0], mod[1]
        qat_convbn = cls(conv.in_channels, conv.out_channels, conv.kernel_size,
                         conv.stride, conv.padding, conv.dilation,
                         conv.groups, conv.bias is not None,
                         conv.padding_mode,
                         bn.eps, bn.momentum,
                         False, qconfig)
        qat_convbn.weight = conv.weight
        qat_convbn.bias = conv.bias
        qat_convbn.bn.weight = bn.weight
        qat_convbn.bn.bias = bn.bias
        qat_convbn.bn.running_mean = bn.running_mean
        qat_convbn.bn.running_var = bn.running_var
        # mypy error: Cannot determine type of 'num_batches_tracked'
        qat_convbn.bn.num_batches_tracked = bn.num_batches_tracked  # type: ignore[has-type]
        return qat_convbn

    def to_float(self):
        cls = type(self)
        conv = cls._FLOAT_CONV_MODULE(  # type: ignore[attr-defined]
            self.in_channels,
            self.out_channels,
            self.kernel_size,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
            self.bias is not None,
            self.padding_mode)
        conv.weight = torch.nn.Parameter(self.weight.detach())
        if self.bias is not None:
            conv.bias = torch.nn.Parameter(self.bias.detach())

        if cls._FLOAT_BN_MODULE:  # type: ignore[attr-defined]
            # fuse bn into conv
            conv.weight, conv.bias = fuse_conv_bn_weights(
                conv.weight,
                conv.bias,
                self.bn.running_mean,
                self.bn.running_var,
                self.bn.eps,
                self.bn.weight,
                self.bn.bias
            )

        if cls._FLOAT_RELU_MODULE:  # type: ignore[attr-defined]
            modules = []
            modules.append(conv)
            relu = cls._FLOAT_RELU_MODULE()  # type: ignore[attr-defined]
            modules.append(relu)
            conv_relu = cls._FUSED_FLOAT_MODULE(*modules)  # type: ignore[attr-defined]
            conv_relu.train(self.training)
            return conv_relu
        else:
            conv.train(self.training)
            return conv
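
# The BN-folding trick used in _forward above, in isolation (illustrative
# sketch): scaling conv weights by gamma / sqrt(running_var + eps) lets the
# weight fake-quantizer observe the weights the fused inference conv will use.
w = torch.randn(8, 3, 3, 3)                 # conv weight (out, in, kh, kw)
gamma = torch.rand(8)                       # BN affine scale
running_var = torch.rand(8)
scale_factor = gamma / torch.sqrt(running_var + 1e-5)
w_folded = w * scale_factor.reshape(-1, 1, 1, 1)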
class ZMUSWrapper(nn.Module): """Zero-mean Unit-STD States see: https://stats.stackexchange.com/questions/43159/how-to-calculate-pooled-variance-of-two-groups-given-known-group-variances-mean """ def __init__(self, policy_net, eps=1e-6): super(ZMUSWrapper, self).__init__() self.eps = eps self.policy_net = policy_net self.policy_cfg = self.policy_net.policy_cfg # Parameters state_size = self.policy_net.agent.observation_space.shape self.state_mean = Parameter(torch.Tensor(state_size, )) self.state_variance = Parameter(torch.Tensor(state_size, )) self.state_mean.requires_grad = False self.state_variance.requires_grad = False # cash self.size = 0 self.ep_states_data = [] self.first_pass = True def _get_state_mean(self): return self.state_mean.detach() def _get_state_variance(self): return self.state_variance.detach() def forward(self, s): self.size += 1 self.ep_states_data.append(s) if not self.first_pass: s = (s - self._get_state_mean()) / \ (torch.sqrt(self._get_state_variance()) + self.eps) return self.policy_net(s) def episode_callback(self): ep_states_tensor = torch.stack(self.ep_states_data) new_data_mean = torch.mean(ep_states_tensor, dim=0) new_data_var = torch.var(ep_states_tensor, dim=0) if self.first_pass: self.state_mean.data = new_data_mean self.state_variance.data = new_data_var self.first_pass = False else: n = len(self.ep_states_data) mean = self._get_state_mean() var = self._get_state_variance() new_data_mean_sq = torch.mul(new_data_mean, new_data_mean) size = min(self.policy_cfg.FORGET_COUNT_OBS_SCALER, self.size) new_mean = ((mean * size) + (new_data_mean * n)) / (size + n) new_var = (((size * (var + torch.mul(mean, mean))) + (n * (new_data_var + new_data_mean_sq))) / (size + n) - torch.mul(new_mean, new_mean)) self.state_mean.data = new_mean self.state_variance.data = torch.clamp( new_var, 0.) # occasionally goes negative, clip self.size += n def batch_callback(self): pass
class PhaseShifter(Module): r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b` This module supports :ref:`TensorFloat32<tf32_on_ampere>`. Args: in_features: size of each input sample out_features: size of each output sample Shape: - Input: :math:`(N, *, H_{in})` where :math:`*` means any number of additional dimensions and :math:`H_{in} = \text{in\_features}` - Output: :math:`(N, *, H_{out})` where all but the last dimension are the same shape as the input and :math:`H_{out} = \text{out\_features}`. Attributes: theta: the learnable weights of the module of shape :math:`(\text{out\_features}, \text{in\_features})`. The values are initialized from uniform(0,2*pi) Examples:: >>> m = nn.Linear(20, 30) >>> input = torch.randn(128, 20) >>> output = m(input) >>> print(output.size()) torch.Size([128, 30]) """ __constants__ = ['in_features', 'out_features'] in_features: int out_features: int scale: float theta: Tensor def __init__(self, in_features: int, out_features: int, scale: float=1, theta = None) -> None: super(PhaseShifter, self).__init__() self.in_features = in_features self.in_dim = self.in_features//2 self.out_features = out_features self.scale = scale # self.theta = Parameter(torch.Tensor(self.out_features, self.in_dim)) self.theta = Parameter(torch.Tensor(self.in_dim, self.out_features)) self.reset_parameters(theta) def reset_parameters(self, theta = None) -> None: if theta is None: init.uniform_(self.theta, a=0, b=2*np.pi) else: assert theta.shape == (self.in_dim,self.out_features) self.theta = Parameter(theta) self.real_kernel = (1 / self.scale) * torch.cos(self.theta) # self.imag_kernel = (1 / self.scale) * torch.sin(self.theta) # def forward(self, inputs: Tensor) -> Tensor: self.real_kernel = (1 / self.scale) * torch.cos(self.theta) # self.imag_kernel = (1 / self.scale) * torch.sin(self.theta) # cat_kernels_4_real = torch.cat( (self.real_kernel, -self.imag_kernel), dim=-1 ) cat_kernels_4_imag = torch.cat( (self.imag_kernel, self.real_kernel), dim=-1 ) cat_kernels_4_complex = torch.cat( (cat_kernels_4_real, cat_kernels_4_imag), dim=0 ) # This block matrix represents the conjugate transpose of the original: # [ W_R, -W_I; W_I, W_R] # output = F.linear(inputs, cat_kernels_4_complex) output = torch.matmul(inputs, cat_kernels_4_complex) return output def extra_repr(self) -> str: return 'in_features={}, out_features={}'.format( self.in_features, self.out_features ) def get_theta(self) -> torch.Tensor: return self.theta.detach().clone() def get_weights(self) -> torch.Tensor: with torch.no_grad(): real_kernel = (1 / self.scale) * torch.cos(self.theta) # imag_kernel = (1 / self.scale) * torch.sin(self.theta) # # cat_kernels_4_real = torch.cat( # (real_kernel, -imag_kernel), # dim=-1 # ) # cat_kernels_4_imag = torch.cat( # (imag_kernel, real_kernel), # dim=-1 # ) # cat_kernels_4_complex = torch.cat( # (cat_kernels_4_real, cat_kernels_4_imag), # dim=0 # ) # This block matrix represents the conjugate transpose of the original: # # [ W_R, -W_I; W_I, W_R] beam_weights = real_kernel + 1j*imag_kernel return beam_weights
class Quantization(nn.Module): def __init__(self, network_controller: NetworkQuantizationController, is_signed: bool, alpha: float = 0.9, weights_values=None, efficient=True): """ HMQ Block :param network_controller: The network controller :param is_signed: is this tensor signed :param alpha: the thresholds I.I.R value :param weights_values: In the case of weights quantized this is the tensors values :param efficient: Boolean flag stating to use the memory efficient """ super(Quantization, self).__init__() self.weights_values = weights_values if weights_values is None: self.tensor_type = TensorType.ACTIVATION self.tensor_size = None else: self.tensor_type = TensorType.COEFFICIENT self.tensor_size = np.prod(weights_values.shape) self.network_controller = network_controller self.alpha = alpha self.is_signed_tensor = torch.Tensor([float(is_signed)]).cuda() if efficient: self.base_q = EfficientBaseQuantization() else: self.base_q = BaseQuantization() self.gumbel_softmax = GumbelSoftmax(ste=network_controller.ste) self.bits_vector = None self.mv_shifts = None self.base_thresholds = None self.nb_shifts_points_div = None self.search_matrix = None def init_quantization_coefficients(self): """ This function initlized the HMQ parameters :return: None """ init_threshold = 0 n_bits_list, thresholds_shifts = self.network_controller.quantization_config.get_thresholds_bitwidth_lists(self) if self.is_coefficient(): init_threshold = torch.pow(2.0, self.weights_values.abs().max().log2().ceil() + 1).item() if self.is_activation(): n_bits_list = [8] self._init_quantization_params(n_bits_list, thresholds_shifts, init_threshold) self._init_search_matrix(self.network_controller.p, n_bits_list, len(thresholds_shifts)) def _init_quantization_params(self, bit_list, thresholds_shifts, init_thresholds): self.update_bits_list(bit_list) self.mv_shifts = Parameter(torch.Tensor(thresholds_shifts), requires_grad=False) self.thresholds_shifts_points_div = Parameter(torch.pow(2.0, self.mv_shifts), requires_grad=False) self.base_thresholds = Parameter(torch.Tensor(1), requires_grad=False) init.constant_(self.base_thresholds, init_thresholds) def _init_search_matrix(self, p, n_bits_list, n_thresholds_options): n_channels = 1 sm = -np.random.rand(n_channels, len(n_bits_list), n_thresholds_options, 1) n = np.prod(sm.shape) sm[:, 0, 0, 0] = np.log(p * n / (1 - p)) # for single channels self.search_matrix = Parameter(torch.Tensor(sm)) def _get_quantization_probability_matrix(self, batch_size=1, noise_disable=False): return self.gumbel_softmax(self.search_matrix, self.network_controller.temperature, batch_size=batch_size, noise_disable=noise_disable) def _get_bits_probability(self, batch_size=1, noise_disable=False): p = self._get_quantization_probability_matrix(batch_size=batch_size, noise_disable=noise_disable) return p.sum(dim=4).sum(dim=3).sum(dim=1) def _update_iir(self, x): # update scale using statistics if self.is_activation(): if self.tensor_size is None: self.tensor_size = np.prod(x.shape[1:]) # Remove batch axis max_value = x.abs().max() self.base_thresholds.data.add_(self.alpha * (max_value - self.base_thresholds)) def _calculate_expected_delta(self, p, max_scale): max_scales = max_scale / (self.thresholds_shifts_points_div.reshape(1, -1)) max_scales = max_scales.reshape(1, 1, 1, -1, 1) nb_shifts = self.nb_shifts_points_div.reshape(1, 1, -1, 1, 1) * torch.pow(2.0, -self.is_signed_tensor) delta = (max_scales / nb_shifts) * p return delta.sum(dim=-1).sum(dim=-1).sum(dim=-1).sum(dim=-1) def 
_calculate_expected_threshold(self, p, max_threshold): p_t = p.sum(dim=4).sum(dim=2).sum(dim=1) thresholds = max_threshold / (self.thresholds_shifts_points_div.reshape(1, -1)) return (p_t * thresholds).sum(dim=-1) def _calculate_expected_q_point(self, p, max_threshold, expected_delta, param_shape): t = self._calculate_expected_threshold(p, max_threshold=max_threshold).reshape(*param_shape) return t / expected_delta def _built_param_shape(self, x): random_size = x.shape[0] if self.is_activation() else x.shape[1] # select random if len(x.shape) == 4: param_shape = [random_size, -1, 1, 1] if self.is_activation() else [-1, random_size, 1, 1] elif len(x.shape) == 2: param_shape = [random_size, -1] if self.is_activation() else [-1, random_size] else: raise NotImplemented return random_size, param_shape def forward(self, x): """ The forward function of the HMQ module :param x: Input tensor x :return: A tensor after quantization """ if self.network_controller.statistics_update: self._update_iir(x) max_threshold = torch.pow(2.0, torch.ceil(torch.log2(self.base_thresholds.detach().abs()))).detach() # read scale if self.training and self.network_controller.temperature > 0: random_size, param_shape = self._built_param_shape(x) # axis according to tensor type (activation randomization is done over the batch axis, # coeff the randomization is done over the input channel axis) p = self._get_quantization_probability_matrix(batch_size=random_size) delta = self._calculate_expected_delta(p, max_threshold).reshape(*param_shape) q_points = self._calculate_expected_q_point(p, max_threshold, delta, param_shape).reshape(*param_shape) return self.base_q(x, delta, q_points, self.is_signed_tensor) else: # negative temperature/ infernce p = self._get_quantization_probability_matrix(batch_size=1, noise_disable=True).squeeze(dim=0) bits_index = torch.argmax(self._get_bits_probability(batch_size=1, noise_disable=True).squeeze(dim=0)) max_index = torch.argmax(p[:, bits_index, :, 0], dim=-1) q_points = self.nb_shifts_points_div[bits_index] * torch.pow(2.0, -self.is_signed_tensor) max_scales = (max_threshold / self.thresholds_shifts_points_div.reshape(1, -1)).detach() delta = torch.stack( [(max_scales[i, mv] / q_points) for i, mv in enumerate(max_index)]).flatten().detach() return self.base_q(x, delta, q_points, self.is_signed_tensor) def get_bit_width(self): """ This function return the selected bit-width :return: the bit-width of the HMQ """ return self.bits_vector[torch.argmax(self._get_bits_probability(noise_disable=True).flatten())].item() def get_expected_bits(self): """ This function return the expected bit-width :return: the expected bit-width of the HMQ """ return (self.bits_vector * self._get_bits_probability(noise_disable=True)).sum() def get_float_size(self): """ This function return the size of floating point tensor in bits Note: we assume 32 bits for floating point values :return: the floating point tensor size """ return 32 * self.tensor_size def get_fxp_size(self): """ This function return the size of quantized tensor in bits :return: the quantized tensor size """ return self.get_bit_width() * self.tensor_size def is_activation(self): """ This function return the boolean stating if this module quantize activation :return: a boolean flag stating if this activation quantization """ return self.tensor_type == TensorType.ACTIVATION def is_coefficient(self): """ This function return the boolean stating if this module quantize coefficient :return: a boolean flag stating if this coefficient quantization """ return 
self.tensor_type == TensorType.COEFFICIENT def get_expected_tensor_size(self): """ This function return the expected size of quantized tensor in bits :return: the expected size of quantized tensor """ return torch.Tensor([self.tensor_size]).cuda() def update_bits_list(self, bits_list): """ This function update the HMQ bit-width list :param bits_list: A list of new bit-widths :return: None """ if self.bits_vector is None: self.bits_vector = Parameter(torch.Tensor(bits_list), requires_grad=False) self.nb_shifts_points_div = Parameter( torch.pow(2.0, self.bits_vector), # - int(q_node.is_signed) requires_grad=False) # move to init else: self.bits_vector.add_(torch.Tensor(bits_list).cuda() - self.bits_vector) self.nb_shifts_points_div.add_(torch.pow(2.0, self.bits_vector) - self.nb_shifts_points_div)
class Linear(nn.Module): """Linear layer Parameters: ----------- input_size: int input size of transformation output_size: int output size of transformation bias: boolean, default=True whether to use bias or not device: str, default="cuda:0" keep on this device """ def __init__(self, input_size, output_size, bias=True, device="cuda:0"): super(Linear, self).__init__() self.device = device # Useful in case of multiple GPUs self.input_size = input_size self.output_size = output_size self.weight = Parameter(torch.Tensor(self.output_size, self.input_size)) if bias: self.bias = Parameter(torch.Tensor(self.output_size, 1)) else: self.register_parameter('bias', None) self.reset_parameters() def forward(self, input): if self.bias is not None: return F.linear(input.to(self.device), self.weight, self.bias.view(-1)) else: return F.linear(input.to(self.device), self.weight) def to(self): """Transfer to device """ super().to(self.device) def reset_parameters(self): """Initialize vectors """ torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) if self.bias is not None: fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out( self.weight) bound = 1 / math.sqrt(fan_in) torch.nn.init.uniform_(self.bias, -bound, bound) # stdv = 1. / math.sqrt(self.weight.size(1)) # self.weight.data.uniform_(-stdv, stdv) # if self.bias is not None: # self.bias.data.uniform_(-stdv, stdv) def get_weights(self): """Get weights as numpy array Bias is appended in the end """ _wts = self.weight.detach().cpu().numpy() if self.bias is not None: _bias = self.bias.detach().cpu().numpy() _wts = np.hstack([_wts, _bias]) return _wts def __repr__(self): s = '{name}({input_size}, {output_size}, {device}' if self.bias is not None: s += ', bias=True' s += ')' return s.format(name=self.__class__.__name__, **self.__dict__) @property def sparse(self): return False
class Conv2dDPQ(nn.Conv2d): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, qmin=1e-3, qmax=100, dmin=1e-5, dmax=10, bias=True, sign=True, wbits=4, abits=4, mode=Qmodes.layer_wise): """ :param d_init: the inital quantization stepsize (alpha) :param mode: Qmodes.layer_wise or Qmodes.kernel_wise :param xmax_init: the quantization range for whole weights """ super(Conv2dDPQ, self).__init__(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) self.qmin = qmin self.qmax = qmax self.dmin = dmin self.dmax = dmax self.q_mode = mode self.sign = sign self.nbits = wbits self.act_dpq = ActDPQ(signed=True, nbits=abits) if self.q_mode == Qmodes.kernel_wise: self.alpha = Parameter(torch.Tensor(out_channels)) else: self.alpha = Parameter(torch.Tensor(1)) self.xmax = Parameter(torch.Tensor(1)) self.weight.requires_grad_(True) if bias: self.bias.requires_grad_(True) self.register_buffer('init_state', torch.zeros(1)) def get_nbits(self): abits = self.act_dpq.get_nbits() # print('the xmax is : {} and the alpha is : {} '.format(self.xmax.data, self.alpha.data)) self.xmax.data.copy_(self.xmax.clamp(self.qmin, self.qmax)) self.alpha.data.copy_(self.alpha.clamp(self.dmin, self.dmax)) # print('after clamp, the result is :') # print('the xmax is : {} and the alpha is : {} '.format(self.xmax.data, self.alpha.data)) if self.sign: nbits = (torch.log(self.xmax/self.alpha + 1) / math.log(2) + 1).ceil() else: nbits = (torch.log(self.xmax/self.alpha + 1) / math.log(2)).ceil() # print('the nbits for weight is : {}'.format(nbits)) self.nbits = int(nbits.item()) return abits, nbits def get_quan_filters(self, filters): if self.training and self.init_state == 0: # print('initial alphas in current time !') Qp = 2 ** (self.nbits - 1) - 1 self.alpha.data.copy_(2 * filters.abs().mean() / math.sqrt(Qp)) self.xmax.data.copy_(self.alpha * Qp) # print('the xmax is : {} and the alpha is : {} '.format(self.xmax.data, self.alpha.data)) # self.xmax.data.copy_(self.weight.abs().max()) # wmean = self.weight.abs().mean() # self.alpha.data.copy_(4 * wmean * wmean / self.xmax) self.init_state.fill_(1) self.xmax.data.copy_(self.xmax.clamp(self.qmin, self.qmax)) self.alpha.data.copy_(self.alpha.clamp(self.dmin, self.dmax)) Qp = (self.xmax.detach()/self.alpha.detach()).item() g = 1.0 / math.sqrt(filters.numel() * Qp) alpha = grad_scale(self.alpha, g) xmax = grad_scale(self.xmax, g) # alpha = quantize_pow2(self.alpha) # xmax = self.xmax # alpha = self.alpha # g = 1.0 / math.sqrt(self.weight.numel() * (xmax.detach()/d.detach()).item()) # d = grad_scale(self.alpha, g) # xmax = torch.clamp(self.xmax, self.xmax_min, self.xmax_max) if self.sign: wq = round_pass((torch.clamp(filters/xmax, -1, 1) * xmax)/alpha) * alpha else: wq = round_pass((torch.clamp(filters/xmax, 0, 1) * xmax)/alpha) * alpha return wq def forward(self,x): if self.training and self.init_state == 0: # print('initial alphas in current time !') Qp = 2 ** (self.nbits - 1) - 1 self.alpha.data.copy_(2 * self.weight.abs().mean() / math.sqrt(Qp)) self.xmax.data.copy_(self.alpha * Qp) # print('the xmax is : {} and the alpha is : {} '.format(self.xmax.data, self.alpha.data)) # self.xmax.data.copy_(self.weight.abs().max()) # wmean = self.weight.abs().mean() # self.alpha.data.copy_(4 * wmean * wmean / self.xmax) self.init_state.fill_(1) # alpha = quantize_pow2(self.alpha) # alpha = self.alpha # g = 1.0 / math.sqrt(self.weight.numel() * Qp) # 
xmax = self.xmax self.xmax.data.copy_(self.xmax.clamp(self.qmin, self.qmax)) self.alpha.data.copy_(self.alpha.clamp(self.dmin, self.dmax)) Qp = (self.xmax.detach()/self.alpha.detach()).item() g = 1.0 / math.sqrt(self.weight.numel() * Qp) alpha = grad_scale(self.alpha, g) xmax = grad_scale(self.xmax, g) if self.sign: wq = round_pass((torch.clamp(self.weight/xmax, -1, 1) * xmax)/alpha) * alpha else: wq = round_pass((torch.clamp(self.weight/xmax, 0, 1) * xmax)/alpha) * alpha if self.act_dpq is not None: x = self.act_dpq(x) # print('the abits is : {} and the nbits is : {}'.format(self.act_dpq.nbits, self.nbits)) return F.conv2d(x, wq, self.bias, self.stride, self.padding, self.dilation, self.groups)
class ABAE(torch.nn.Module): """ The model described in the paper ``An Unsupervised Neural Attention Model for Aspect Extraction'' by He, Ruidan and Lee, Wee Sun and Ng, Hwee Tou and Dahlmeier, Daniel, ACL2017 https://aclweb.org/anthology/papers/P/P17/P17-1036/ """ def __init__(self, wv_dim: int = 200, asp_count: int = 30, ortho_reg: float = 0.1, maxlen: int = 201, init_aspects_matrix=None): """ Initializing the model :param wv_dim: word vector size :param asp_count: number of aspects :param ortho_reg: coefficient for tuning the ortho-regularizer's influence :param maxlen: sentence max length taken into account :param init_aspects_matrix: None or init. matrix for aspects """ super(ABAE, self).__init__() self.wv_dim = wv_dim self.asp_count = asp_count self.ortho = ortho_reg self.maxlen = maxlen self.attention = SelfAttention(wv_dim, maxlen) self.linear_transform = torch.nn.Linear(self.wv_dim, self.asp_count) self.softmax_aspects = torch.nn.Softmax() self.aspects_embeddings = Parameter( torch.empty(size=(wv_dim, asp_count))) if init_aspects_matrix is None: torch.nn.init.xavier_uniform(self.aspects_embeddings) else: self.aspects_embeddings.data = torch.from_numpy( init_aspects_matrix.T) def get_aspects_importances(self, text_embeddings): """ Takes embeddings of a sentence as input, returns attention weights """ # compute attention scores, looking at text embeddings average attention_weights = self.attention(text_embeddings) # multiplying text embeddings by attention scores -- and summing # (matmul: we sum every word embedding's coordinate with attention weights) weighted_text_emb = torch.matmul( attention_weights.unsqueeze(1), # (batch, 1, sentence) text_embeddings # (batch, sentence, wv_dim) ).squeeze() # encoding with a simple feed-forward layer (wv_dim) -> (aspects_count) raw_importances = self.linear_transform(weighted_text_emb) # computing 'aspects distribution in a sentence' aspects_importances = self.softmax_aspects(raw_importances) return attention_weights, aspects_importances, weighted_text_emb def forward(self, text_embeddings, negative_samples_texts): # negative samples are averaged averaged_negative_samples = torch.mean(negative_samples_texts, dim=2) # encoding: words embeddings -> sentence embedding, aspects importances _, aspects_importances, weighted_text_emb = self.get_aspects_importances( text_embeddings) # decoding: aspects embeddings matrix, aspects_importances -> recovered sentence embedding recovered_emb = torch.matmul( self.aspects_embeddings, aspects_importances.unsqueeze(2)).squeeze() # loss reconstruction_triplet_loss = ABAE._reconstruction_loss( weighted_text_emb, recovered_emb, averaged_negative_samples) max_margin = torch.max(reconstruction_triplet_loss, torch.zeros_like(reconstruction_triplet_loss)) return self.ortho * self._ortho_regularizer() + max_margin @staticmethod def _reconstruction_loss(text_emb, recovered_emb, averaged_negative_emb): positive_dot_products = torch.matmul( text_emb.unsqueeze(1), recovered_emb.unsqueeze(2)).squeeze() negative_dot_products = torch.matmul( averaged_negative_emb, recovered_emb.unsqueeze(2)).squeeze() reconstruction_triplet_loss = torch.sum( 1 - positive_dot_products.unsqueeze(1) + negative_dot_products, dim=1) return reconstruction_triplet_loss def _ortho_regularizer(self): return torch.norm( torch.matmul(self.aspects_embeddings.t(), self.aspects_embeddings) \ - torch.eye(self.asp_count)) def get_aspect_words(self, w2v_model, topn=15): words = [] # getting aspects embeddings aspects = self.aspects_embeddings.detach().numpy() # 
getting scalar products of word embeddings and aspect embeddings; # to obtain the ``probabilities'', one should also apply softmax words_scores = w2v_model.wv.syn0.dot(aspects) for row in range(aspects.shape[1]): argmax_scalar_products = np.argsort(-words_scores[:, row])[:topn] # print([w2v_model.wv.index2word[i] for i in argmax_scalar_products]) # print([w for w, dist in w2v_model.similar_by_vector(aspects.T[row])[:topn]]) words.append( [w2v_model.wv.index2word[i] for i in argmax_scalar_products]) return words
class _LearnableFakeQuantize(torch.ao.quantization.FakeQuantizeBase): r""" This is an extension of the FakeQuantize module in fake_quantize.py, which supports more generalized lower-bit quantization and support learning of the scale and zero point parameters through backpropagation. For literature references, please see the class _LearnableFakeQuantizePerTensorOp. In addition to the attributes in the original FakeQuantize module, the _LearnableFakeQuantize module also includes the following attributes to support quantization parameter learning. * :attr:`channel_len` defines the length of the channel when initializing scale and zero point for the per channel case. * :attr:`use_grad_scaling` defines the flag for whether the gradients for scale and zero point are normalized by the constant, which is proportional to the square root of the number of elements in the tensor. The related literature justifying the use of this particular constant can be found here: https://openreview.net/pdf?id=rkgO66VKDS. * :attr:`fake_quant_enabled` defines the flag for enabling fake quantization on the output. * :attr:`static_enabled` defines the flag for using observer's static estimation for scale and zero point. * :attr:`learning_enabled` defines the flag for enabling backpropagation for scale and zero point. """ def __init__(self, observer, quant_min=0, quant_max=255, scale=1., zero_point=0., channel_len=-1, use_grad_scaling=False, **observer_kwargs): super(_LearnableFakeQuantize, self).__init__() assert quant_min < quant_max, 'quant_min must be strictly less than quant_max.' self.quant_min = quant_min self.quant_max = quant_max # also pass quant_min and quant_max to observer observer_kwargs["quant_min"] = quant_min observer_kwargs["quant_max"] = quant_max self.use_grad_scaling = use_grad_scaling if channel_len == -1: self.scale = Parameter(torch.tensor([scale])) self.zero_point = Parameter(torch.tensor([zero_point])) else: assert isinstance( channel_len, int ) and channel_len > 0, "Channel size must be a positive integer." self.scale = Parameter(torch.tensor([scale] * channel_len)) self.zero_point = Parameter( torch.tensor([zero_point] * channel_len)) self.activation_post_process = observer(**observer_kwargs) assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, \ 'quant_min out of bound' assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, \ 'quant_max out of bound' self.dtype = self.activation_post_process.dtype self.qscheme = self.activation_post_process.qscheme self.ch_axis = self.activation_post_process.ch_axis \ if hasattr(self.activation_post_process, 'ch_axis') else -1 self.register_buffer('fake_quant_enabled', torch.tensor([1], dtype=torch.uint8)) self.register_buffer('static_enabled', torch.tensor([1], dtype=torch.uint8)) self.register_buffer('learning_enabled', torch.tensor([0], dtype=torch.uint8)) bitrange = torch.tensor(quant_max - quant_min + 1).double() self.bitwidth = int(torch.log2(bitrange).item()) self.register_buffer('eps', torch.tensor([torch.finfo(torch.float32).eps])) @torch.jit.export def enable_param_learning(self): r"""Enables learning of quantization parameters and disables static observer estimates. Forward path returns fake quantized X. """ self.toggle_qparam_learning(enabled=True) \ .toggle_fake_quant(enabled=True) \ .toggle_observer_update(enabled=False) return self @torch.jit.export def enable_static_estimate(self): r"""Enables static observer estimates and disbales learning of quantization parameters. 
Forward path returns fake quantized X. """ self.toggle_qparam_learning(enabled=False) \ .toggle_fake_quant(enabled=True) \ .toggle_observer_update(enabled=True) @torch.jit.export def enable_static_observation(self): r"""Enables static observer accumulating data from input but doesn't update the quantization parameters. Forward path returns the original X. """ self.toggle_qparam_learning(enabled=False) \ .toggle_fake_quant(enabled=False) \ .toggle_observer_update(enabled=True) @torch.jit.export def toggle_observer_update(self, enabled=True): self.static_enabled[0] = int(enabled) # type: ignore[operator] return self @torch.jit.export def enable_observer(self, enabled=True): self.toggle_observer_update(enabled) @torch.jit.export def toggle_qparam_learning(self, enabled=True): self.learning_enabled[0] = int(enabled) # type: ignore[operator] self.scale.requires_grad = enabled self.zero_point.requires_grad = enabled return self @torch.jit.export def toggle_fake_quant(self, enabled=True): self.fake_quant_enabled[0] = int(enabled) return self @torch.jit.export def observe_quant_params(self): print('_LearnableFakeQuantize Scale: {}'.format(self.scale.detach())) print('_LearnableFakeQuantize Zero Point: {}'.format( self.zero_point.detach())) @torch.jit.export def calculate_qparams(self): self.scale.data.clamp_(min=self.eps.item()) # type: ignore[operator] scale = self.scale.detach() zero_point = self.zero_point.detach().round().clamp( self.quant_min, self.quant_max).long() return scale, zero_point def forward(self, X): if self.static_enabled[0] == 1: # type: ignore[index] self.activation_post_process(X.detach()) _scale, _zero_point = self.activation_post_process.calculate_qparams( ) _scale = _scale.to(self.scale.device) _zero_point = _zero_point.to(self.zero_point.device) self.scale.data.copy_(_scale) self.zero_point.data.copy_(_zero_point) else: self.scale.data.clamp_( min=self.eps.item()) # type: ignore[operator] if self.fake_quant_enabled[0] == 1: if self.qscheme in (torch.per_channel_symmetric, torch.per_tensor_symmetric): self.zero_point.data.zero_() if self.use_grad_scaling: grad_factor = 1.0 / (X.numel() * self.quant_max)**0.5 else: grad_factor = 1.0 if self.qscheme in (torch.per_channel_symmetric, torch.per_channel_affine): X = torch._fake_quantize_learnable_per_channel_affine( X, self.scale, self.zero_point, self.ch_axis, self.quant_min, self.quant_max, grad_factor) else: X = torch._fake_quantize_learnable_per_tensor_affine( X, self.scale, self.zero_point, self.quant_min, self.quant_max, grad_factor) return X
class CustomEmbedding(torch.nn.Module): """ Memory efficient way to compute weighted EmbeddingBag """ def __init__(self, num_embeddings, embedding_dim, max_norm=None, norm_type=2, scale_grad_by_freq=False, sparse=False, device="cuda:0"): """ Args: num_embeddings: int: vocalubary size embedding_dim: int: dimension for embeddings padding_idx: int: index for <PAD>; embedding is not updated max_norm: norm_type: int: default: 2 scale_grad_by_freq: boolean: True/False sparse: boolean: sparse or dense gradients """ super(CustomEmbedding, self).__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim self.padding_idx = num_embeddings if self.padding_idx is not None: self.num_embeddings = num_embeddings + 1 self.max_norm = max_norm self.norm_type = norm_type self.scale_grad_by_freq = scale_grad_by_freq self.weight = Parameter( torch.Tensor(self.num_embeddings, embedding_dim)) self.sparse = sparse self.device = device self.reset_parameters() def reset_parameters(self): """ Reset weights """ self.weight.data.normal_(0, 1) if self.padding_idx is not None: self.weight.data[self.padding_idx].fill_(0) def to(self): super().to(self.device) def forward(self, batch_data): """ Forward pass for embedding layer Arguments ---------- batch_data: dict {'X': torch.LongTensor (BxN), 'X_w': torch.FloatTensor (BxN)} 'X': Feature indices 'X_w': Feature weights Returns: ---------- torch.Tensor embedding for each sample B x embedding_dims """ features = batch_data['X'].to(self.device) weights = batch_data['X_w'].to(self.device) out = F.embedding(features, self.weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse) out = weights.unsqueeze(1).bmm(out).squeeze() return out def get_weights(self): return self.weight.detach().cpu().numpy() def __repr__(self): s = '{name}({num_embeddings}, {embedding_dim}, {device}' if self.padding_idx is not None: s += ', padding_idx={padding_idx}' if self.max_norm is not None: s += ', max_norm={max_norm}' if self.norm_type != 2: s += ', norm_type={norm_type}' if self.scale_grad_by_freq is not False: s += ', scale_grad_by_freq={scale_grad_by_freq}' if self.sparse is not False: s += ', sparse=True' s += ')' return s.format(name=self.__class__.__name__, **self.__dict__) def _init_(self, state_dict): keys = list(state_dict.keys()) key = [x for x in keys if x.split(".")[-1] in ['weight']][0] weight = state_dict[key] self.weight.data.copy_(weight)
class CombineUV(nn.Module): """ Combines all classifier components for DECAF Arguments ---------- input_size: int Classifier's embeddinging dimension output_size: int Number of classifiers to train bias: bool (default=True) Flag to use bias use_classifier_wts: bool (default=False) If True learns a refinement vector padding_idx: int (default=None) Label padding index device: string (default="cuda:0") GPU id to keep the parameters sparse: bool (default=False) Type of optimizer to use for training """ def __init__(self, input_size, output_size, bias=True, use_classifier_wts=False, padding_idx=None, device="cuda:0", sparse=False): super(CombineUV, self).__init__() self._device = device # Useful in case of multiple GPUs self.input_size = input_size self.output_size = output_size self.use_classifier_wts = use_classifier_wts self.sparse = sparse if self.sparse: self.device = "cpu" self.padding_idx = padding_idx if self.use_classifier_wts: if bias: if self.sparse: self.bias = Parameter(torch.Tensor(self.output_size, 1)) else: self.bias = Parameter(torch.Tensor(self.output_size)) else: self.register_parameter('bias', None) self.weight = Parameter( torch.Tensor(self.output_size, self.input_size)) self.alpha = Parameter(torch.Tensor(1, self.input_size)) self.beta = Parameter(torch.Tensor(1, self.input_size)) else: self.register_parameter('bias', None) self.register_parameter('weight', None) self.prebias = None self.prepredict = None self.reset_parameters() def _get_clf(self, weights, shortlist=None, sparse=True, padding_idx=None): if shortlist is not None and weights is not None: return F.embedding(shortlist, weights, sparse=sparse, padding_idx=padding_idx, max_norm=None, norm_type=2., scale_grad_by_freq=False) return weights def _get_rebuild(self, lbl_clf, shortlist=None): if self.prepredict is None: bias = self._get_clf(self.bias, shortlist, self.sparse, self.padding_idx) lbl_clf = lbl_clf.to(self.device) _lbl_clf = self._get_clf(self.weight, shortlist, self.sparse, self.padding_idx) if self.use_classifier_wts: lbl_clf = self.alpha.sigmoid()*_lbl_clf.squeeze() + \ self.beta.sigmoid()*lbl_clf return lbl_clf, bias else: return self.prepredict, self.prebias def forward(self, input, labels, shortlist=None, shortlist_first=None): input = input.to(self.device) if labels is not None: labels = labels.to(self.device) if shortlist_first is not None: shortlist_first = shortlist_first.to(self.device) lbl_clf, bias = self._get_rebuild(labels, shortlist_first) if shortlist is not None: shortlist = shortlist.to(self.device) input = input.to(self.device) lbl_clf = self._get_clf(lbl_clf, shortlist, sparse=False) bias = self._get_clf(bias, shortlist, sparse=False) out = torch.matmul(input.unsqueeze(1), lbl_clf.permute(0, 2, 1)) if bias is not None: out = out.squeeze() + bias.squeeze() return out.squeeze() else: return F.linear(input, lbl_clf, bias) def _setup(self, lbl_clf): self.device = "cpu" if self.use_classifier_wts: norm_u = np.mean( torch.norm(self.weight, dim=1).detach().cpu().numpy()) print("Alpha=", torch.mean(self.alpha.detach().sigmoid()).cpu().numpy(), "Beta=", torch.mean(self.beta.detach().sigmoid()).cpu().numpy()) norm_v = np.mean(torch.norm(lbl_clf, dim=1).detach().cpu().numpy()) lbl_clf, _bias = self._get_rebuild(lbl_clf) norm_w = np.mean(torch.norm(lbl_clf, dim=1).detach().cpu().numpy()) print("||u||=%0.2f, ||v||=%0.2f, ||w||=%0.2f" % (norm_u, norm_v, norm_w)) self.prebias = _bias + 0 self.prepredict = lbl_clf def _pred_to(self): self.device = self._device self.prepredict = 
self.prepredict.to(self.device) if self.use_classifier_wts: self.prebias = self.prebias.to(self.device) def _pred_cpu(self): self.device = "cpu" self.prepredict = self.prepredict.cpu() if self.use_classifier_wts: self.prebias = self.prebias.cpu() def _clean(self): if self.prepredict is None: print("No need to clean") return print("Cleaing for GPU") self.prepredict = self.prepredict.cpu() if self.use_classifier_wts: self.prebias = self.prebias.cpu() self.prebias, self.prepredict = None, None def to(self): """ Transfer to device """ self.device = self._device super().to(self._device) def reset_parameters(self): """ Initialize vectors """ stdv = 1. / math.sqrt(self.input_size) if self.weight is not None: self.weight.data.uniform_(-stdv, stdv) self.alpha.data.fill_(0) self.beta.data.fill_(0) if self.bias is not None: self.bias.data.fill_(0) def get_weights(self): """ Get weights as dictionary """ parameters = {} if self.bias is not None: parameters['bias'] = self.bias.detach().cpu().numpy() if self.use_classifier_wts: parameters['weight'] = self.weight.detach().cpu().numpy() parameters['alpha'] = self.alpha.detach().cpu().numpy() parameters['beta'] = self.beta.detach().cpu().numpy() return parameters def __repr__(self): s = "{name}({input_size}, {output_size}, {_device}, {use_classifier_wts}, sparse={sparse}" if self.use_classifier_wts: if self.bias is not None: s += ', bias=True' s += ", Alpha={}".format(self.alpha.detach().cpu().numpy().shape) s += ", Beta={}".format(self.beta.detach().cpu().numpy().shape) s += ", weight={}".format(self.weight.detach().cpu().numpy().shape) s += ')' return s.format(name=self.__class__.__name__, **self.__dict__) def _init_(self, lbl_fts_mat, state_dict): if self.use_classifier_wts: print("INFO::Initializing the classifier") keys = list(state_dict.keys()) key = [x for x in keys if x.split(".")[-1] in ['weight']][0] weight = state_dict[key] self.weight.data.copy_(lbl_fts_mat.mm(weight))
class Model: '''Model manager class Convenience class wrapping saving/loading, model initialization, optimization, and logging. Args: ann: Model to optimize. Used to initialize weights. config: A Config specification args: Hook for additional user arguments ''' def __init__(self, ann, config, args): self.saver = save.Saver(config.MODELDIR, 'models', 'bests', resetTol=256) self.config, self.args = config, args self.init(ann) if self.config.LOAD or self.config.BEST: self.load(self.config.BEST) def init(self, ann): print('Initializing new model...') self.initModel(ann) self.opt = None if not self.config.TEST: self.opt = ManualAdam([self.params], lr=0.001, weight_decay=0.00001) #Initialize a new network def initModel(self, ann): self.models = ann(self.config).params() self.params = Parameter(torch.Tensor(np.array(self.models))) #Grads and clip def stepOpt(self, gradList): '''Clip the provided gradients and step the optimizer Args: gradList: a list of gradients ''' grad = np.array(gradList) grad = np.mean(grad, 0) grad = np.clip(grad, -5, 5) gradAry = torch.Tensor(grad) self.opt.step(gradAry) def checkpoint(self, reward): '''Save the model to checkpoint Args: reward: Mean reward of the model ''' if self.config.TEST: return self.saver.checkpoint(self.params, self.opt, reward) def load(self, best=False): '''Load a model from file Args: best (bool): Whether to load the best (True) or most recent (False) checkpoint ''' print('Loading model...') epoch = self.saver.load(self.opt, self.params, best) @property def nParams(self): '''Print the number of model parameters''' nParams = len(self.model) print('#Params: ', str(nParams / 1000), 'K') @property def model(self): '''Get model parameters Returns: a numpy array of model parameters ''' return self.params.detach().numpy()
class WSConv2d(nn.Conv2d): def __init__(self, in_planes, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, multiplier=1.0, rep_dim=1, repeat_weight=True, use_coeff=False): if rep_dim == 0: # this is repeat along the channel dim (dimension 0 of the weights tensor) super(WSConv2d, self).__init__(int(in_planes), int(np.ceil(out_channels / multiplier)), kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) self.rep_time = int( np.ceil(1. * out_channels / self.weight.shape[0])) elif rep_dim == 1: # this is to repeat along the filter dim(dimension 1 of the weights tensor) super(WSConv2d, self).__init__(int(np.ceil(in_planes / multiplier)), int(out_channels), kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) self.rep_time = int(np.ceil(1. * in_planes / self.weight.shape[1])) self.in_planes = in_planes self.out_channels_ori = out_channels self.groups = groups self.multiplier = multiplier self.rep_dim = rep_dim self.repeat_weight = repeat_weight self.use_coeff = use_coeff # print(self.weight.shape) # import pdb; pdb.set_trace() self.conv1_stride_lr_1 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=2, padding=0, bias=False) self.conv1_stride_lr_2 = nn.Conv2d(in_planes, self.rep_time, kernel_size=1, stride=1, padding=0, bias=False) self.coefficient = Parameter(torch.Tensor(self.rep_time), requires_grad=False) self.reuse = False self.coeff_grad = None def generate_share_weight(self, base_weight, rep_num, coeff, nchannel, dim=0): ''' sample weights from base weight''' # pdb.set_trace() if rep_num == 1: return base_weight new_weight = [] for i in range(rep_num): if dim == 0: new_weight_temp = torch.cat( [base_weight[1:, :, :, :], base_weight[0:1, :, :, :]], dim=0) * (1 - coeff[i]) else: new_weight_temp = torch.cat( [base_weight[:, 1:, :, :], base_weight[:, 0:1, :, :]], dim=1) * (1 - coeff[i]) new_weight.append(base_weight * coeff[i] + new_weight_temp) out = torch.cat(new_weight, dim=dim) if dim == 0: return out[:nchannel, :, :, :] else: return out[:, :nchannel, :, :] def forward(self, x): """ same padding as efficientnet tf version """ ih, iw = x.size()[-2:] kh, kw = self.weight.size()[-2:] sh, sw = self.stride oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) pad_h = max( (oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) pad_w = max( (ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) if pad_h > 0 or pad_w > 0: x = F.pad(x, [ pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2 ]) if self.use_coeff: if self.training: # set reuse to True for coefficient sharing if not self.reuse: lr_conv1 = self.conv1_stride_lr_1(x) # pdb.set_trace() lr_conv1 = self.conv1_stride_lr_2(lr_conv1) lr_conv1 = F.adaptive_avg_pool2d(lr_conv1, (1, 1))[:, :, 0, 0] self.coefficient.set_( F.normalize(torch.mean(lr_conv1, 0), dim=0).clone().detach()) # pdb.set_trace() self.coeff_grad = F.normalize(torch.mean(lr_conv1, 0), dim=0) if self.rep_dim == 0: if self.repeat_weight: out = F.conv2d( x, self.generate_share_weight(self.weight, self.rep_time, self.coeff_grad, self.out_channels_ori, dim=0)) else: out_tmp = F.conv2d(x, self.weight) out = self.generate_share_feature( self.rep_time, out_tmp, F.normalize(torch.mean(lr_conv1, 0), dim=0), self.out_channels_ori) # out = F.conv2d(x, self.generate_share_feature(self.rep_time, F.normalize(torch.mean(lr_conv1, 0), dim = 0))) else: if self.repeat_weight: out = F.conv2d( x, self.generate_share_weight(self.weight, self.rep_time, 
self.coeff_grad, x.shape[1], dim=1)) else: out_tmp = self.feature_wrapper( self.rep_time, x, F.normalize(torch.mean(lr_conv1, 0), dim=0), self.out_channels_ori) out = F.conv2d(out_tmp, self.weight) else: if self.rep_dim == 0: if self.repeat_weight: out = F.conv2d( x, self.generate_share_weight( self.weight, self.rep_time, self.coefficient.detach(), self.out_channels_ori, dim=0)) else: out_tmp = F.conv2d(x, self.weight) out = self.generate_share_feature( self.rep_time, out_tmp, self.coefficient.detach(), self.out_channels_ori) # out = F.conv2d(x, self.generate_share_feature(self.rep_time, self.coefficient.detach())) else: if self.repeat_weight: out = F.conv2d( x, self.generate_share_weight( self.weight, self.rep_time, self.coefficient.detach(), x.shape[1], dim=1)) else: out_tmp = self.feature_wrapper( self.rep_time, out_tmp, self.coefficient.detach(), self.out_channels_ori) out = F.conv2d(out_tmp, self.weight) else: # print("use_coeff == False") if self.rep_dim == 0: out = F.conv2d( x, self.weight.repeat([self.rep_time, 1, 1, 1])[:self.out_channels_ori, :, :, :], None, 1) else: out = F.conv2d( x, self.weight.repeat([1, self.rep_time, 1, 1])[:, :x.shape[1], :, :], None, 1) return out
class Conv2dDPQ(nn.Conv2d): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, qmin=1e-3, qmax=100, dmin=1e-5, dmax=10, bias=True, sign=True, wbits=4, abits=4, mode=Qmodes.layer_wise): """ :param d_init: the inital quantization stepsize (alpha) :param mode: Qmodes.layer_wise or Qmodes.kernel_wise :param xmax_init: the quantization range for whole weights """ super(Conv2dDPQ, self).__init__(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) self.qmin = qmin self.qmax = qmax self.dmin = dmin self.dmax = dmax self.q_mode = mode self.sign = sign self.nbits = wbits self.act_dpq = ActDPQ(signed=False, nbits=abits) self.alpha = Parameter(torch.Tensor(1)) self.xmax = Parameter(torch.Tensor(1)) self.weight.requires_grad_(True) if bias: self.bias.requires_grad_(True) self.register_buffer('init_state', torch.zeros(1)) def get_nbits(self): abits = self.act_dpq.get_nbits() xmax = self.xmax.abs().item() alpha = self.alpha.abs().item() if self.sign: nbits = math.ceil(math.log(xmax / alpha + 1) / math.log(2) + 1) else: nbits = math.cell(math.log(xmax / alpha + 1) / math.log(2)) self.nbits = nbits return abits, nbits def get_quan_filters(self, filters): if self.training and self.init_state == 0: Qp = 2**(self.nbits - 1) - 1 self.xmax.data.copy_(filters.abs().max()) self.alpha.data.copy_(self.xmax / Qp) # self.alpha[self.index].data.copy_(2 * filters.abs().mean() / math.sqrt(Qp)) # self.xmax[self.index].data.copy_(self.alpha[self.index] * Qp) self.init_state.fill_(1) Qp = (self.xmax.detach() / self.alpha.detach()).abs().item() g = 1.0 / math.sqrt(filters.numel() * Qp) alpha = grad_scale(self.alpha, g) xmax = grad_scale(self.xmax, g) w = F.hardtanh(filters / xmax.abs(), -1, 1) * xmax.abs() w = w / alpha.abs() wq = round_pass(w) * alpha.abs() return wq def forward(self, x): if self.act_dpq is not None: x = self.act_dpq(x) wq = self.get_quan_filters(self.weight) return F.conv2d(x, wq, self.bias, self.stride, self.padding, self.dilation, self.groups)
class _ConvBnNd(nn.modules.conv._ConvNd): _version = 2 def __init__( self, # ConvNd args in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias, padding_mode, # BatchNormNd args # num_features: out_channels eps=1e-05, momentum=0.1, # affine: True # track_running_stats: True # Args for this module freeze_bn_stats=False, rt_spec=None, dim=2): nn.modules.conv._ConvNd.__init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, False, padding_mode) assert rt_spec, 'Runtime spec be provided for quantized module {}'.format( self.__class__.__name__) self.bn_frozen = freeze_bn_stats if self.training else True self.dim = dim self.bn = _BN_CLASS_MAP[dim](out_channels, eps, momentum, True, True) self.weight_quantizer = rt_spec.get_weight_quantizer('weight') self.bias_quantizer = rt_spec.get_weight_quantizer('bias') if bias: self.bias = Parameter(torch.Tensor(out_channels)) else: self.register_parameter('bias', None) self.reset_bn_parameters() # this needs to be called after reset_bn_parameters, # as they modify the same state if self.training: if freeze_bn_stats: self.freeze_bn_stats() else: self.update_bn_stats() else: self.freeze_bn_stats() self.conv_bn_fused = False def reset_running_stats(self): self.bn.reset_running_stats() def reset_bn_parameters(self): self.bn.reset_running_stats() init.uniform_(self.bn.weight) init.zeros_(self.bn.bias) # note: below is actully for conv, not BN if self.bias is not None: fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) bound = 1 / math.sqrt(fan_in) init.uniform_(self.bias, -bound, bound) def batch_stats(self, x, bias=None): """Get the batch mean and variance of x and updates the BatchNorm's running mean and average. Args: x (torch.Tensor): input batch. bias (torch.Tensor): the bias that is to be applied to the batch. Returns: (mean, variance) Note: In case of `nn.Linear`, x may be of shape (N, C, L) or (N, L) where N is batch size, C is number of channels, L is the features size. The batch norm computes the stats over C in the first case or L on the second case. The batch normalization layer is (`nn.BatchNorm1d`)[https://pytorch.org/docs/stable/nn.html#batchnorm1d] In case of `nn.Conv2d`, x is of shape (N, C, H, W) where H,W are the image dimensions, and the batch norm computes the stats over C. The batch normalization layer is (`nn.BatchNorm2d`)[https://pytorch.org/docs/stable/nn.html#batchnorm2d] In case of `nn.Conv3d`, x is of shape (N, C, D, H, W) where H,W are the image dimensions, D is additional channel dimension, and the batch norm computes the stats over C. The batch normalization layer is (`nn.BatchNorm3d`)[https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm3d.html#torch.nn.BatchNorm3d] """ channel_size = self.bn.num_features self.bn.num_batches_tracked += 1 # Calculate current batch stats batch_mean = x.transpose(0, 1).contiguous().view(channel_size, -1).mean(1) # BatchNorm currently uses biased variance (without Bessel's correction) as was discussed at # https://github.com/pytorch/pytorch/issues/1410 # # also see the source code itself: # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Normalization.cpp#L216 batch_var = x.transpose(0, 1).contiguous().view(channel_size, -1).var( 1, unbiased=False) # Update running stats with torch.no_grad(): biased_batch_mean = batch_mean + (bias if bias is not None else 0) # However - running_var is updated using unbiased variance! 
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Normalization.cpp#L223 n = x.numel() / channel_size corrected_var = batch_var * (n / float(n - 1)) momentum = self.bn.momentum if momentum is None: # momentum is None - we compute a cumulative moving average # as noted in https://pytorch.org/docs/stable/nn.html#batchnorm2d momentum = 1. / float(self.bn.num_batches_tracked) self.bn.running_mean.mul_(1 - momentum).add_(momentum * biased_batch_mean) self.bn.running_var.mul_(1 - momentum).add_(momentum * corrected_var) return batch_mean, batch_var def reset_parameters(self): super(_ConvBnNd, self).reset_parameters() def merge_bn_to_conv(self): with torch.no_grad(): # Use the same implementation in nndct_shared/optimzation/fuse_conv_bn.py # to make sure the test accruacy is same as the deployable model. gamma = self.bn.weight.detach().cpu().numpy() beta = self.bn.bias.detach().cpu().numpy() running_var = self.bn.running_var.detach().cpu().numpy() running_mean = self.bn.running_mean.detach().cpu().numpy() epsilon = self.bn.eps scale = gamma / np.sqrt(running_var + epsilon) offset = beta - running_mean * scale weight = self.weight.detach().cpu().numpy() # Conv2d if self.dim == 2 and not self.transposed: # OIHW -> IHWO -> OIHW weight = np.multiply(weight.transpose(1, 2, 3, 0), scale).transpose(3, 0, 1, 2) # ConvTranspose2d elif self.dim == 2 and self.transposed: # IOHW -> IHWO -> IOHW weight = np.multiply(weight.transpose(0, 2, 3, 1), scale).transpose(0, 3, 1, 2) # Conv3D elif self.dim == 3 and not self.transposed: weight = np.multiply(weight.transpose(1, 2, 3, 4, 0), scale).transpose(4, 0, 1, 2, 3) # ConvTranspose3d elif self.dim == 3 and self.transposed: weight = np.multiply(weight.transpose(2, 3, 4, 0, 1), scale).transpose(3, 4, 0, 1, 2) else: raise RuntimeError( 'Unsupported combinations: (dim={}, transposed={})'.format( self.dim, self.transposed)) self.weight.copy_(torch.from_numpy(weight)) bias = self.bias.detach().cpu().numpy() if self.bias is not None else 0 bias = torch.from_numpy(bias * scale + offset) if self.bias is not None: self.bias.copy_(bias) else: self.bias = nn.Parameter(bias) self.conv_bn_fused = True def update_bn_stats(self): self.bn_frozen = False def freeze_bn_stats(self): self.bn_frozen = True def clear_non_native_bias(self): if self.bias is None: print('[WARNING] No bias to unmerge') return with torch.no_grad(): gamma = self.bn.weight.detach().cpu().numpy() beta = self.bn.bias.detach().cpu().numpy() running_var = self.bn.running_var.detach().cpu().numpy() running_mean = self.bn.running_mean.detach().cpu().numpy() epsilon = self.bn.eps scale = gamma / np.sqrt(running_var + epsilon) bias = self.bias.detach().cpu().numpy() beta = torch.from_numpy(bias * scale + beta) self.bn.bias.copy_(beta) self.bias = None def broadcast_correction(self, c): """Broadcasts a correction factor to the output for elementwise operations. Two tensors are “broadcastable” if the following rules hold: - Each tensor has at least one dimension. - When iterating over the dimension sizes, starting at the trailing dimension, the dimension sizes must either be equal, one of them is 1, or one of them does not exist. 
See https://pytorch.org/docs/stable/notes/broadcasting.html """ expected_output_dim = self.dim + 2 view_fillers_dim = expected_output_dim - c.dim() - 1 view_filler = (1,) * view_fillers_dim expected_view_shape = c.shape + view_filler return c.view(*expected_view_shape) def broadcast_correction_weight(self, c): """Broadcasts a correction factor to the weight.""" if c.dim() != 1: raise ValueError("Correction factor needs to have a single dimension") expected_weight_dim = self.dim + 2 view_fillers_dim = expected_weight_dim - c.dim() view_filler = (1,) * view_fillers_dim expected_view_shape = c.shape + view_filler return c.view(*expected_view_shape) def extra_repr(self): return super(_ConvBnNd, self).extra_repr() def forward(self, x, output_size=None): """ See https://arxiv.org/pdf/1806.08342.pdf section 3.2.2. bn(conv(x)) = (conv(x) - E(conv(x))) * gamma / std(conv(x)) + beta = (x*W + B - E(x*W + B)) * gamma / sqrt(E((x*W + B - E(x*W + B))^2)) + beta = (x*W - E(x*W)) * gamma / std(x*W) + beta """ gamma, beta = self.bn.weight, self.bn.bias if self.conv_bn_fused: quantized_weight = self.weight_quantizer(self.weight) quantized_bias = self.bias_quantizer(self.bias) return self._conv_forward(x, quantized_weight, quantized_bias, output_size) if self.training and not self.bn_frozen: batch_mean, batch_var = self.batch_stats( self._conv_forward(x, self.weight, output_size=output_size), self.bias) recip_sigma_batch = torch.rsqrt(batch_var + self.bn.eps) with torch.no_grad(): running_sigma = torch.sqrt(self.bn.running_var + self.bn.eps) w_corrected = self.weight * self.broadcast_correction_weight( gamma / running_sigma) w_quantized = self.weight_quantizer(w_corrected) recip_c = self.broadcast_correction(running_sigma * recip_sigma_batch) bias_corrected = beta - gamma * batch_mean * recip_sigma_batch bias_quantized = self.broadcast_correction( self.bias_quantizer(bias_corrected)) y = self._conv_forward(x, w_quantized, None, output_size) y.mul_(recip_c).add_(bias_quantized) else: with torch.no_grad(): recip_running_sigma = torch.rsqrt(self.bn.running_var + self.bn.eps) w_corrected = self.weight * self.broadcast_correction_weight( gamma * recip_running_sigma) w_quantized = self.weight_quantizer(w_corrected) mean_corrected = self.bn.running_mean - ( self.bias if self.bias is not None else 0) bias_corrected = beta - gamma * mean_corrected * recip_running_sigma bias_quantized = self.bias_quantizer(bias_corrected) y = self._conv_forward(x, w_quantized, bias_quantized, output_size) return y def train(self, mode=True): """Batchnorm's training behavior is using the self.training flag. Prevent changing it if BN is frozen. This makes sure that calling `model.train()` on a model with a frozen BN will behave properly. """ self.training = mode if not self.bn_frozen: for module in self.children(): module.train(mode) return self @property def is_quantized(self): return True @classmethod def from_float(cls, conv, bn, rt_spec): """Create a qat module from a float module.""" assert rt_spec, 'Input float module must have a valid rt_spec' convbn = cls(conv.in_channels, conv.out_channels, conv.kernel_size, conv.stride, conv.padding, conv.dilation, conv.groups, conv.bias is not None, conv.padding_mode, bn.eps, bn.momentum, False, rt_spec) convbn.weight = conv.weight convbn.bias = conv.bias convbn.bn.weight = bn.weight convbn.bn.bias = bn.bias convbn.bn.running_mean = bn.running_mean convbn.bn.running_var = bn.running_var convbn.bn.num_batches_tracked = bn.num_batches_tracked convbn.bn.eps = bn.eps return convbn
class Model: '''Model manager class Convenience class wrapping saving/loading, model initialization, optimization, and logging. Args: ann: Model to optimize. Used to initialize weights. config: A Config specification args: Hook for additional user arguments ''' def __init__(self, ann, config): self.saver = save.Saver(config.MODELDIR, 'models', 'bests', resetTol=256) self.config = config print('Initializing new model...') self.net = ann(config) self.parameters = Parameter( torch.Tensor(np.array(getParameters(self.net)))) def load(self, opt, best=False): '''Load a model from file Args: best (bool): Whether to load the best (True) or most recent (False) checkpoint ''' print('Loading model...') epoch = self.saver.load(opt, self.parameters, best) self.syncParameters() def checkpoint(self, opt, reward): '''Save the model to checkpoint Args: reward: Mean reward of the model ''' self.saver.checkpoint(self.parameters, opt, reward) self.saver.print() @property def nParams(self): '''Print the number of model parameters''' nParams = len(self.weights) print('#Params: ', str(nParams / 1000), 'K') def syncParameters(self): parameters = self.parameters.detach().numpy().tolist() setParameters(self.net, parameters) @property def weights(self): '''Get model parameters Returns: a numpy array of model parameters ''' return getParameters(self.net)