class GraphConvolution(Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
    """

    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.detach().uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.detach().uniform_(-stdv, stdv)

    def forward(self, input, adj):
        support = torch.mm(input, self.weight)
        output = torch.spmm(adj, support)
        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'
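
A minimal usage sketch for the layer above (assuming the imports the snippet relies on: math, torch, and torch.nn's Module and Parameter). The random adjacency here is only a stand-in for the normalized adjacency D^(-1/2)(A + I)D^(-1/2) from the paper; it is stored sparse so that torch.spmm in forward() applies.

import torch

n_nodes, in_feats, out_feats = 6, 8, 16
x = torch.randn(n_nodes, in_feats)
adj = (torch.eye(n_nodes) + torch.rand(n_nodes, n_nodes)).to_sparse()

gc = GraphConvolution(in_feats, out_feats)
out = gc(x, adj)   # shape: (n_nodes, out_feats)
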
class NoisyLinear(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super(NoisyLinear, self).__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.weight = torch.Tensor(out_features, in_features)
        self.weight_epsilon = torch.Tensor(out_features, in_features)
        self.weight_mu = Parameter(torch.Tensor(out_features, in_features))
        self.weight_sigma = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = torch.Tensor(out_features)
            self.bias_epsilon = torch.Tensor(out_features)
            self.bias_mu = Parameter(torch.Tensor(out_features))
            self.bias_sigma = Parameter(torch.Tensor(out_features))
        else:
            self.bias = None
            self.bias_epsilon = None
            self.register_parameter('bias_mu', None)
            self.register_parameter('bias_sigma', None)
        self.reset_parameters()
        self.sampled = False

    def sample(self):
        if self.training:
            self.weight_epsilon.normal_()
            self.weight = self.weight_epsilon.mul(self.weight_sigma).add_(
                self.weight_mu)
            if self.bias is not None:
                self.bias_epsilon.normal_()
                self.bias = self.bias_epsilon.mul(self.bias_sigma).add_(
                    self.bias_mu)
        else:
            self.weight = self.weight_mu.detach()
            if self.bias is not None:
                self.bias = self.bias_mu.detach()
        self.sampled = True

    def reset_parameters(self):
        stdv = math.sqrt(3.0 / self.weight.size(1))
        self.weight_mu.data.uniform_(-stdv, stdv)
        self.weight_sigma.data.fill_(0.017)
        if self.bias is not None:
            self.bias_mu.data.uniform_(-stdv, stdv)
            self.bias_sigma.data.fill_(0.017)

    def forward(self, input):
        if not self.sampled:
            self.sample()
        return F.linear(input, self.weight, self.bias)

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'
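
A hypothetical usage sketch for NoisyLinear (same assumed imports, plus torch.nn as nn and torch.nn.functional as F). Note that forward() samples noise only once; call sample() explicitly (or reset sampled) whenever fresh noise, or the mean weights in eval mode, are wanted.

layer = NoisyLinear(8, 4)
x = torch.randn(3, 8)
y_noisy = layer(x)        # training mode: weight = weight_mu + weight_sigma * eps

layer.eval()
layer.sample()            # re-sample so eval mode falls back to the mean weights
y_mean = layer(x)
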
Example #3
class MF(nn.Module):
    def __init__(self, num_users, num_items, embedding_size):

        super(MF, self).__init__()

        self.users = Parameter(torch.FloatTensor(num_users, embedding_size))
        self.items = Parameter(torch.FloatTensor(num_items, embedding_size))

        self.init_params()

    def init_params(self):

        stdv = 1. / math.sqrt(self.users.size(1))
        self.users.data.uniform_(-stdv, stdv)
        self.items.data.uniform_(-stdv, stdv)

    def pair_forward(self, user, item_p, item_n):

        user = self.users[user]
        item_p = self.items[item_p]
        item_n = self.items[item_n]

        p_score = torch.sum(user * item_p, 2)
        n_score = torch.sum(user * item_n, 2)

        return p_score, n_score

    def point_forward(self, user, item):

        user = self.users[user]
        item = self.items[item]

        score = torch.sum(user * item, 2)

        return score

    def get_item_embeddings(self):

        return self.items.detach().cpu().numpy().astype('float32')

    def get_user_embeddings(self):

        return self.users.detach().cpu().numpy().astype('float32')
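
A sketch of pairwise (BPR-style) training with the MF module above; the loss shown is illustrative and not part of the class. The index tensors are kept 2-D so that the sum over dim 2 in pair_forward applies.

model = MF(num_users=100, num_items=500, embedding_size=32)
user   = torch.randint(0, 100, (16, 1))
item_p = torch.randint(0, 500, (16, 1))
item_n = torch.randint(0, 500, (16, 1))

p_score, n_score = model.pair_forward(user, item_p, item_n)   # each (16, 1)
loss = -torch.log(torch.sigmoid(p_score - n_score)).mean()    # BPR-style objective
loss.backward()
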
Example #4
class CapsuleShare(ModuleParall):
    def __init__(self, in_feature, out_feature, routings=3, bias=True, retain_grad=False):
        super(CapsuleShare, self).__init__()
        self.in_feature = in_feature
        self.out_feature = out_feature
        self.routings = routings
        if routings < 1:
            raise ValueError('Routing should be at least 1!')
        self.retain_grad = retain_grad
        b = Variable(torch.zeros(1, 1, self.out_feature)).type(FloatTensor)
        self.c = F.softmax(b, 2)
        self.weight = Parameter(torch.Tensor(self.in_feature, 1, self.out_feature))
        if bias:
            self.bias = Parameter(torch.Tensor(self.out_feature))
        else:
            self.bias = None
            self.forward = self.no_bias
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(0))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, u):
        c = self.c
        u_hat = u.unsqueeze(-1) * self.weight
        u_hat = u_hat.view(-1,self.in_feature * u.size(-1),self.out_feature)
        u, bias = (u_hat,self.bias) if self.retain_grad else (u_hat.detach(),self.bias.detach())
        b = 0.
        for _ in range(self.routings - 1):
            v = squash(torch.sum(u * c, dim=1) + bias, dim=1)
            b = b + u * v.view(-1, 1, self.out_feature)
            c = F.softmax(b, 2)
        v = squash(torch.sum(u_hat * c, dim=1) + self.bias, dim=1)
        return v

    def no_bias(self, u):
        c = self.c
        u_hat = u.unsqueeze(-1) * self.weight
        u_hat = u_hat.view(-1,self.in_feature * u.size(-1),self.out_feature)
        u = u_hat if self.retain_grad else u_hat.detach()
        b = 0.
        for _ in range(self.routings - 1):
            v = squash(torch.sum(u * c, dim=1), dim=1)
            b = b + u * v.view(-1, 1, self.out_feature)
            c = F.softmax(b, 2)
        v = squash(torch.sum(u_hat * c, dim=1), dim=1)
        return v

    def extra_repr(self):
        s = ('{in_feature}, {out_feature}, routings={routings}'
             ', bias='+str(self.bias is not None)+', retain_grad={retain_grad}')
        return s.format(**self.__dict__)
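
CapsuleShare relies on a squash() helper (as well as a ModuleParall base class and a FloatTensor alias) defined elsewhere in its source; the following is one common formulation of the squashing nonlinearity, shown as an assumption rather than that repository's exact definition.

def squash(s, dim=-1, eps=1e-8):
    # v = (||s||^2 / (1 + ||s||^2)) * s / ||s||: shrinks the norm into [0, 1)
    # while preserving the direction of s.
    sq_norm = torch.sum(s * s, dim=dim, keepdim=True)
    return (sq_norm / (1.0 + sq_norm)) * s / torch.sqrt(sq_norm + eps)
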
Example #5
class Conv2d_filtermap_1x1_compression(_ConvNd_filtermap_1x1_compression):
    def __init__(self,
                 in_channels,
                 out_channels,
                 channel_compression_1x1,
                 kernel_size,
                 binary_filtermap=False,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)

        super(Conv2d_filtermap_1x1_compression,
              self).__init__(in_channels, out_channels,
                             channel_compression_1x1, kernel_size, stride,
                             padding, dilation, False, _pair(0), groups, bias)

        self.binary_filtermap = binary_filtermap
        self.reset_parameters1()
        #pdb.set_trace()
        #self.conv_weight = Parameter(torch.Tensor(out_channels, self.in_channels // self.groups, \
        #              self.kernel_size[0], self.kernel_size[1]))
        #self.conv_weight.requires_grad = False
        #self.conv_weight.cuda()
    def reset_parameters1(self):
        fm_size = self.filtermap.size()
        fm_width = fm_size[2]
        fm_height = fm_size[1]
        fm_depth = fm_size[0]
        # not for 1x1 conv, do the padding on the spatial
        if self.filtermap.size()[1] > 1 and self.filtermap.size()[2] > 1:
            self.fm_pad_width = fm_width + 1
            self.fm_pad_height = fm_height + 1
        #for 1x1 conv no padding on the spatial
        else:
            self.fm_pad_width = fm_width
            self.fm_pad_height = fm_height

        self.fm_pad_depth = fm_depth * 2
        #set the ids for extracting filters from filtermap
        out_channels = self.out_channels
        in_channels = self.in_channels // self.groups
        k_h = self.kernel_size[0]
        k_w = self.kernel_size[1]

        sample_y = self.sample_y
        sample_x = self.sample_x
        sample_c = self.sample_c

        stride_y = self.stride_y
        stride_x = self.stride_x
        stride_c = self.stride_c

        fm_depth = self.fm_pad_depth
        fm_height = self.fm_pad_height
        fm_width = self.fm_pad_width

        ids = (torch.Tensor(range(0, k_h * k_w)))
        tmp_count = 0
        for y in range(0, k_h):
            for x in range(0, k_w):
                ids[tmp_count] = y * fm_width + x
                tmp_count = tmp_count + 1

        ids0 = ids

        #pdb.set_trace()
        for c in range(1, in_channels):
            ids_c = ids0 + c * fm_height * fm_width
            ids = torch.cat((ids, ids_c), 0)

        #ids0 = ids
        #for x in range(1, out_channels):
        #    ids = torch.cat((ids,ids0),0)
        #pdb.set_trace()
        ids0 = ids
        for y in range(0, sample_y):
            for x in range(0, sample_x):
                if y == 0 and x == 0:
                    continue
                ss = y * stride_y * fm_width + x * stride_x
                ids_ss = ids0 + ss
                ids = torch.cat((ids, ids_ss), 0)

        #pdb.set_trace()
        ids0 = ids
        for c in range(1, sample_c):
            ids_c = ids0 + c * stride_c * fm_height * fm_width
            ids = torch.cat((ids, ids_c), 0)

        #pdb.set_trace()
        #ids = ids.long()
        #ids = ids.detach()

        #pdb.set_trace()
        ids = ids.long()
        self.ids = Parameter(ids)
        self.ids.requires_grad = False
        #self.register_parameter()
        #if torch.max(ids) >= fm_depth*fm_height*fm_width or torch.min(ids) < 0:
        #print(torch.max(ids))
        #ids = Variable(ids)
    def extract_filters(self):
        #pdb.set_trace()

        out_channels = self.out_channels
        in_channels = self.in_channels // self.groups
        k_h = self.kernel_size[0]
        k_w = self.kernel_size[1]

        #for compressing the channel by 2 times
        #if in_channels != 3:
        #   filtermap_pad_tmp = torch.cat((self.filtermap,self.filtermap),0)
        #else:
        #   filtermap_pad_tmp = self.filtermap
        #filtermap_pad = torch.cat((filtermap_pad_tmp,filtermap_pad_tmp),0)

        #for not compressing the channel
        filtermap_pad = torch.cat((self.filtermap, self.filtermap), 0)
        # not for 1x1 conv, do the padding on the spatial
        if self.filtermap.size()[1] > 1 and self.filtermap.size()[2] > 1:
            filtermap_pad_s1 = filtermap_pad[:, 1, :]
            filtermap_pad_s1 = filtermap_pad_s1[:, None, :]
            filtermap_pad = torch.cat((filtermap_pad, filtermap_pad_s1), 1)
            filtermap_pad_s2 = filtermap_pad[:, :, 1]
            filtermap_pad_s2 = filtermap_pad_s2[:, :, None]
            filtermap_pad = torch.cat((filtermap_pad, filtermap_pad_s2), 2)

        #pdb.set_trace()
        ids = self.ids.detach()
        conv_weight = filtermap_pad.view(-1, 1).index_select(0, ids)
        conv_weight = conv_weight.view(out_channels, in_channels, k_h, k_w)
        if self.binary_filtermap:
            binary_conv_weight = conv_weight.clone()
            for nf in range(0, out_channels):
                float_filter = conv_weight[nf, :, :, :]
                L1_norm = torch.norm(float_filter.view(-1, 1), 1)
                sign_filter = torch.sign(float_filter)
                binary_filter = sign_filter * L1_norm
                binary_conv_weight[nf, :, :, :] = binary_filter
            return binary_conv_weight
        else:
            return conv_weight
        #pdb.set_trace()
        #for c in range(0,sample_c):
        #   for y in range(0,sample_y):
        #      for x in range(0,sample_x):
        #          filter_count = c*sample_y*sample_x + y*sample_x + x
        #          conv_weight_clone[filter_count,:,:,:] = filtermap_pad[c*stride_c:c*stride_c+in_channels, \
        #                                      y*stride_y:y*stride_y+k_h, x*stride_x:x*stride_x+k_w]
        #return conv_weight
    def forward(self, input):
        #return F.conv2d(input, self.weight, self.bias, self.stride,
        #                self.padding, self.dilation, self.groups)

        #conv_weight = torch.mm(self.compressed_weight, torch.tanh(self.transform_mat))
        #conv_weight = torch.mm(self.compressed_weight, (self.transform_mat))

        #compressed_weight = torch.mm(self.input_weight, torch.tanh(self.transform_mat))
        #conv_weight = torch.mm(compressed_weight, torch.tanh(self.transform_back_mat))

        #conv_weight = conv_weight.view(self.in_channels // self.groups, self.kernel_size[0], \
        #        self.kernel_size[1], self.out_channels);
        #conv_weight = conv_weight.permute(3, 0, 1, 2)
        #conv_weight = conv_weight.contiguous()

        #fit_loss = torch.norm(conv_weight-self.ref_conv_weight,2)

        #pdb.set_trace()
        #conv_weight = Variable(torch.Tensor(self.out_channels, self.in_channels // self.groups, \
        #              self.kernel_size[0], self.kernel_size[1]))
        #conv_weight.cuda()
        conv_weight = self.extract_filters()
        #conv_weight[0,:,:,:] = self.filtermap[0:self.in_channels, \
        #                                      0:0+self.kernel_size[0], 0:0+self.kernel_size[1]]
        out = F.conv2d(input, conv_weight, self.bias, self.stride,
                       self.padding, self.dilation, self.groups)

        return out
        #return F.conv2d(input, conv_weight, self.bias, self.stride,
        #                 self.padding, self.dilation, self.groups)
    def fit_filtermap1(self, conv):
        conv_weight = conv.weight.data
        conv_weight_size = conv_weight.size()
        out_channels = conv_weight_size[0]
        in_channels = conv_weight_size[1]
        k_h = conv_weight_size[2]
        k_w = conv_weight_size[3]

        sample_y = self.sample_y
        sample_x = self.sample_x
        sample_c = self.sample_c

        stride_y = self.stride_y
        stride_x = self.stride_x
        stride_c = self.stride_c

        fm_ext = torch.Tensor(sample_c * in_channels, sample_y * k_h,
                              sample_x * k_w)
        for c in range(0, sample_c):
            for y in range(0, sample_y):
                for x in range(0, sample_x):
                    filter_count = c * sample_y * sample_x + y * sample_x + x
                    fm_ext[c * in_channels:(c + 1) * in_channels,
                           y * k_h:(y + 1) * k_h, x * k_w:(x + 1) *
                           k_w] = conv_weight[filter_count, :, :, :]

        #pdb.set_trace()
        for oc in range(0, sample_c):
            for c in range(1, sample_c):
                idx = sample_c - c + oc
                if idx > sample_c - 1:
                    idx = 0
                fm_ext[oc*stride_c:(oc+1)*stride_c,:,:] += \
                   fm_ext[idx*stride_c:(idx+1)*stride_c,:,:]

        fm_ext = fm_ext[0:in_channels, :, :] / sample_c

        fm_ext_h = fm_ext.size()[1]
        fm_ext_w = fm_ext.size()[2]

        for y in range(0, sample_y):
            fm_ext[:, y * k_h, :] += fm_ext[:, (y * k_h - 1 + fm_ext_h) %
                                            fm_ext_h, :]
            fm_ext[:, y * k_h, :] = fm_ext[:, y * k_h, :] / 2

        for x in range(0, sample_x):
            fm_ext[:, :,
                   x * k_w] += fm_ext[:, :,
                                      (x * k_w - 1 + fm_ext_w) % fm_ext_w]
            fm_ext[:, :, x * k_w] = fm_ext[:, :, x * k_w] / 2

        y_ids = torch.Tensor([0, 1])
        y_ids0 = y_ids
        for y in range(1, sample_y):
            y_ids = torch.cat((y_ids, y_ids0 + y * k_h), 0)

        x_ids = torch.Tensor([0, 1])
        x_ids0 = x_ids
        for x in range(1, sample_x):
            x_ids = torch.cat((x_ids, x_ids0 + x * k_w), 0)

        fm = torch.index_select(fm_ext, 1, y_ids.long())
        fm = torch.index_select(fm, 2, x_ids.long())

        self.filtermap.data = fm
Example #6
class _LearnableFakeQuantize(nn.Module):
    r""" This is an extension of the FakeQuantize module in fake_quantize.py, which
    supports more generalized lower-bit quantization and supports learning of the scale
    and zero point parameters through backpropagation. For literature references,
    please see the class _LearnableFakeQuantizePerTensorOp.

    In addition to the attributes in the original FakeQuantize module, the _LearnableFakeQuantize
    module also includes the following attributes to support quantization parameter learning.

    * :attr: `channel_len` defines the length of the channel when initializing scale and zero point
             for the per channel case.

    * :attr: `use_grad_scaling` defines the flag for whether the gradients for scale and zero point are
              normalized by the constant, which is proportional to the square root of the number of
              elements in the tensor. The related literature justifying the use of this particular constant
              can be found here: https://openreview.net/pdf?id=rkgO66VKDS.

    * :attr: `fake_quant_enabled` defines the flag for enabling fake quantization on the output.

    * :attr: `static_enabled` defines the flag for using observer's static estimation for
             scale and zero point.

    * :attr: `learning_enabled` defines the flag for enabling backpropagation for scale and zero point.
    """
    def __init__(self, observer, quant_min=0, quant_max=255, scale=1., zero_point=0., channel_len=-1,
                 use_grad_scaling=False, **observer_kwargs):
        super(_LearnableFakeQuantize, self).__init__()
        assert quant_min < quant_max, 'quant_min must be strictly less than quant_max.'
        self.quant_min = quant_min
        self.quant_max = quant_max
        self.use_grad_scaling = use_grad_scaling

        if channel_len == -1:
            self.scale = Parameter(torch.tensor([scale]))
            self.zero_point = Parameter(torch.tensor([zero_point]))
        else:
            assert isinstance(channel_len, int) and channel_len > 0, "Channel size must be a positive integer."
            self.scale = Parameter(torch.tensor([scale] * channel_len))
            self.zero_point = Parameter(torch.tensor([zero_point] * channel_len))

        self.activation_post_process = observer(**observer_kwargs)
        assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, \
               'quant_min out of bound'
        assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, \
               'quant_max out of bound'
        self.dtype = self.activation_post_process.dtype
        self.qscheme = self.activation_post_process.qscheme
        self.ch_axis = self.activation_post_process.ch_axis \
            if hasattr(self.activation_post_process, 'ch_axis') else -1
        self.register_buffer('fake_quant_enabled', torch.tensor([1], dtype=torch.uint8))
        self.register_buffer('static_enabled', torch.tensor([1], dtype=torch.uint8))
        self.register_buffer('learning_enabled', torch.tensor([0], dtype=torch.uint8))

        bitrange = torch.tensor(quant_max - quant_min + 1).double()
        self.bitwidth = int(torch.log2(bitrange).item())

    @torch.jit.export
    def enable_param_learning(self):
        r"""Enables learning of quantization parameters and
        disables static observer estimates. Forward path returns fake quantized X.
        """
        self.toggle_qparam_learning(enabled=True) \
            .toggle_fake_quant(enabled=True) \
            .toggle_observer_update(enabled=False)
        return self

    @torch.jit.export
    def enable_static_estimate(self):
        r"""Enables static observer estimates and disbales learning of
        quantization parameters. Forward path returns fake quantized X.
        """
        self.toggle_qparam_learning(enabled=False) \
            .toggle_fake_quant(enabled=True) \
            .toggle_observer_update(enabled=True)

    @torch.jit.export
    def enable_static_observation(self):
        r"""Enables static observer accumulating data from input but doesn't
        update the quantization parameters. Forward path returns the original X.
        """
        self.toggle_qparam_learning(enabled=False) \
            .toggle_fake_quant(enabled=False) \
            .toggle_observer_update(enabled=True)

    @torch.jit.export
    def toggle_observer_update(self, enabled=True):
        self.static_enabled[0] = int(enabled)
        return self

    @torch.jit.export
    def toggle_qparam_learning(self, enabled=True):
        self.learning_enabled[0] = int(enabled)
        self.scale.requires_grad = enabled
        self.zero_point.requires_grad = enabled
        return self

    @torch.jit.export
    def toggle_fake_quant(self, enabled=True):
        self.fake_quant_enabled[0] = int(enabled)
        return self

    @torch.jit.export
    def observe_quant_params(self):
        print('_LearnableFakeQuantize Scale: {}'.format(self.scale.detach()))
        print('_LearnableFakeQuantize Zero Point: {}'.format(self.zero_point.detach()))

    @torch.jit.export
    def calculate_qparams(self):
        return self.activation_post_process.calculate_qparams()

    def forward(self, X):
        self.activation_post_process(X.detach())
        _scale, _zero_point = self.calculate_qparams()
        _scale = _scale.to(self.scale.device)
        _zero_point = _zero_point.to(self.zero_point.device)

        if self.static_enabled[0] == 1:
            self.scale.data.copy_(_scale)
            self.zero_point.data.copy_(_zero_point)

        if self.fake_quant_enabled[0] == 1:
            if self.learning_enabled[0] == 1:
                if self.use_grad_scaling:
                    grad_factor = 1.0 / (X.numel() * self.quant_max) ** 0.5
                else:
                    grad_factor = 1.0
                if self.qscheme in (
                        torch.per_channel_symmetric, torch.per_channel_affine):
                    X = _LearnableFakeQuantizePerChannelOp.apply(
                        X, self.scale, self.zero_point, self.ch_axis,
                        self.quant_min, self.quant_max, grad_factor)
                else:
                    X = _LearnableFakeQuantizePerTensorOp.apply(
                        X, self.scale, self.zero_point,
                        self.quant_min, self.quant_max, grad_factor)
            else:
                if self.qscheme == torch.per_channel_symmetric or \
                        self.qscheme == torch.per_channel_affine:
                    X = torch.fake_quantize_per_channel_affine(
                        X, self.scale, self.zero_point, self.ch_axis,
                        self.quant_min, self.quant_max)
                else:
                    X = torch.fake_quantize_per_tensor_affine(
                        X, float(self.scale.item()), int(self.zero_point.item()),
                        self.quant_min, self.quant_max)

        return X

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        # We will be saving the static state of scale (instead of as a dynamic param).
        super(_LearnableFakeQuantize, self)._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + 'scale'] = self.scale.data
        destination[prefix + 'zero_point'] = self.zero_point.data

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        local_state = ['scale', 'zero_point']
        for name in local_state:
            key = prefix + name
            if key in state_dict:
                val = state_dict[key]
                if name == 'scale':
                    self.scale.data.copy_(val)
                else:
                    setattr(self, name, val)
            elif strict:
                missing_keys.append(key)
        super(_LearnableFakeQuantize, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys,
            unexpected_keys, error_msgs)

    with_args = classmethod(_with_args)
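
A rough usage sketch with a standard observer (the import path varies across PyTorch versions). Only the static-estimation path is exercised here; the learnable path additionally requires the _LearnableFakeQuantizePerTensorOp / _LearnableFakeQuantizePerChannelOp autograd functions referenced in forward().

import torch
from torch.quantization import MovingAverageMinMaxObserver

fq = _LearnableFakeQuantize(MovingAverageMinMaxObserver,
                            quant_min=0, quant_max=255,
                            dtype=torch.quint8,
                            qscheme=torch.per_tensor_affine)
fq.enable_static_estimate()   # observer drives scale/zero_point; output is fake-quantized
y = fq(torch.randn(4, 8))
fq.enable_param_learning()    # afterwards, scale/zero_point are tuned by backpropagation
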
Example #7
class LabelFts(nn.Module):
    """
        Class for encoding label features
        Arguments
            ----------
            input_size: int
                embedding dimension of the label representation
            label_features: int
                number of tokens in the label text
            padding_idx: int (default=None)
            device: string (default="cuda:0")
            sparse: bool (default=False)
            fixed_features: bool (default=False)
            use_external_weights: bool (default=False)
            transform: nn.Module
                transformation over label features
        Returns:
            nn.Module
                Network block to encode label features
    """
    def __init__(self,
                 input_size,
                 label_features,
                 device="cuda:0",
                 sparse=False,
                 fixed_features=False,
                 use_external_weights=False,
                 transform=None):
        super(LabelFts, self).__init__()
        self.device = device  # Useful in case of multiple GPUs
        self.input_size = input_size
        self.label_features = label_features
        self.use_external_weights = use_external_weights
        self.fixed_features = fixed_features
        if not self.use_external_weights:
            self.weight = Parameter(
                torch.Tensor(self.label_features, self.input_size))
            self.sparse = True
        else:
            self.sparse = sparse
        self.Rpp = transform
        self.reset_parameters()

    def _get_clf(self, labels, features_shortlist=None, weights=None):
        if features_shortlist is not None:
            weights = F.embedding(features_shortlist,
                                  weights,
                                  sparse=self.sparse,
                                  padding_idx=None).squeeze()
        lbl_clf = labels.to(weights.device).mm(weights)
        return lbl_clf

    def forward(self, labels, features_shortlist=None, weight=None):
        if self.fixed_features:
            return self.Rpp(labels)
        else:
            if not self.use_external_weights:
                weight = self.weight
            return self.Rpp(self._get_clf(labels, features_shortlist, weight))

    def to(self, device=None):
        if device is None:
            super().to(self.device)
        else:
            super().to(device)

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.input_size)
        if not self.use_external_weights:
            self.weight.data.uniform_(-stdv, stdv)

    def __repr__(self):
        s = "{name}({input_size}, {label_features}, {sparse}, {device}, {fixed_features}"
        if not self.use_external_weights:
            s += ", weight={}".format(self.weight.detach().cpu().numpy().shape)
        s += "\n%s" % (self.Rpp.__repr__())
        s += ")"
        return s.format(name=self.__class__.__name__, **self.__dict__)

    def _init_(self, state_dict):
        """
            Initializes model parameters
        """
        if not self.use_external_weights:
            keys = list(state_dict.keys())
            key = [x for x in keys if x.split(".")[-1] in ['weight']][0]
            weight = state_dict[key]
            self.weight.data.copy_(weight)
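
A hypothetical usage sketch: with dense bag-of-tokens label vectors and an identity transform, the block reduces to a (num_labels x label_features) by (label_features x input_size) projection.

lbl_enc = LabelFts(input_size=300, label_features=5000,
                   device="cpu", transform=torch.nn.Identity())
labels = torch.rand(64, 5000)      # 64 labels as bag-of-tokens vectors
lbl_emb = lbl_enc(labels)          # shape: (64, 300)
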
Example #8
class Linear(torch.nn.Module):
    r"""Applies a linear tranformation to the incoming data

    .. math::
        \mathbf{x}^{\prime} = \mathbf{x} \mathbf{W}^{\top} + \mathbf{b}

    similar to :class:`torch.nn.Linear`.
    It supports lazy initialization and customizable weight and bias
    initialization.

    Args:
        in_channels (int): Size of each input sample. Will be initialized
            lazily in case it is given as :obj:`-1`.
        out_channels (int): Size of each output sample.
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        weight_initializer (str, optional): The initializer for the weight
            matrix (:obj:`"glorot"`, :obj:`"uniform"`, :obj:`"kaiming_uniform"`
            or :obj:`None`).
            If set to :obj:`None`, will match default weight initialization of
            :class:`torch.nn.Linear`. (default: :obj:`None`)
        bias_initializer (str, optional): The initializer for the bias vector
            (:obj:`"zeros"` or :obj:`None`).
            If set to :obj:`None`, will match default bias initialization of
            :class:`torch.nn.Linear`. (default: :obj:`None`)

    Shapes:
        - **input:** features :math:`(*, F_{in})`
        - **output:** features :math:`(*, F_{out})`
    """
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 bias: bool = True,
                 weight_initializer: Optional[str] = None,
                 bias_initializer: Optional[str] = None):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.weight_initializer = weight_initializer
        self.bias_initializer = bias_initializer

        if in_channels > 0:
            self.weight = Parameter(torch.Tensor(out_channels, in_channels))
        else:
            self.weight = nn.parameter.UninitializedParameter()
            self._hook = self.register_forward_pre_hook(
                self.initialize_parameters)

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self._load_hook = self._register_load_state_dict_pre_hook(
            self._lazy_load_hook)

        self.reset_parameters()

    def __deepcopy__(self, memo):
        out = Linear(self.in_channels, self.out_channels, self.bias
                     is not None, self.weight_initializer,
                     self.bias_initializer)
        if self.in_channels > 0:
            out.weight = copy.deepcopy(self.weight, memo)
        if self.bias is not None:
            out.bias = copy.deepcopy(self.bias, memo)
        return out

    def reset_parameters(self):
        if self.in_channels <= 0:
            pass
        elif self.weight_initializer == 'glorot':
            inits.glorot(self.weight)
        elif self.weight_initializer == 'uniform':
            bound = 1.0 / math.sqrt(self.weight.size(-1))
            torch.nn.init.uniform_(self.weight.data, -bound, bound)
        elif self.weight_initializer == 'kaiming_uniform':
            inits.kaiming_uniform(self.weight,
                                  fan=self.in_channels,
                                  a=math.sqrt(5))
        elif self.weight_initializer is None:
            inits.kaiming_uniform(self.weight,
                                  fan=self.in_channels,
                                  a=math.sqrt(5))
        else:
            raise RuntimeError(f"Linear layer weight initializer "
                               f"'{self.weight_initializer}' is not supported")

        if self.bias is None or self.in_channels <= 0:
            pass
        elif self.bias_initializer == 'zeros':
            inits.zeros(self.bias)
        elif self.bias_initializer is None:
            inits.uniform(self.in_channels, self.bias)
        else:
            raise RuntimeError(f"Linear layer bias initializer "
                               f"'{self.bias_initializer}' is not supported")

    def forward(self, x: Tensor) -> Tensor:
        r"""
        Args:
            x (Tensor): The features.
        """
        return F.linear(x, self.weight, self.bias)

    @torch.no_grad()
    def initialize_parameters(self, module, input):
        if is_uninitialized_parameter(self.weight):
            self.in_channels = input[0].size(-1)
            self.weight.materialize((self.out_channels, self.in_channels))
            self.reset_parameters()
        self._hook.remove()
        delattr(self, '_hook')

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        if is_uninitialized_parameter(self.weight):
            destination[prefix + 'weight'] = self.weight
        else:
            destination[prefix + 'weight'] = self.weight.detach()
        if self.bias is not None:
            destination[prefix + 'bias'] = self.bias.detach()

    def _lazy_load_hook(self, state_dict, prefix, local_metadata, strict,
                        missing_keys, unexpected_keys, error_msgs):

        weight = state_dict[prefix + 'weight']
        if is_uninitialized_parameter(weight):
            self.in_channels = -1
            self.weight = nn.parameter.UninitializedParameter()
            if not hasattr(self, '_hook'):
                self._hook = self.register_forward_pre_hook(
                    self.initialize_parameters)

        elif is_uninitialized_parameter(self.weight):
            self.in_channels = weight.size(-1)
            self.weight.materialize((self.out_channels, self.in_channels))
            if hasattr(self, '_hook'):
                self._hook.remove()
                delattr(self, '_hook')

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, bias={self.bias is not None})')
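
A sketch of the lazy-initialization path (assuming the inits helpers and is_uninitialized_parameter used above are importable from the surrounding package): with in_channels=-1 the weight is materialized on the first forward call.

lin = Linear(-1, 32, weight_initializer='glorot')
x = torch.randn(10, 64)
out = lin(x)   # weight is materialized here with shape (32, 64)
print(lin)     # Linear(64, 32, bias=True)
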
Example #9
class Model:
    def __init__(self, config, args):
        self.saver = save.Saver(config.NPOP,
                                config.MODELDIR,
                                'models',
                                'bests',
                                resetTol=256)
        self.config, self.args = config, args

        self.init()
        if self.config.LOAD or self.config.BEST:
            self.load(self.config.BEST)

    def init(self):
        print('Initializing new model...')
        if self.config.SHAREINIT:
            self.shared(self.config.NPOP)
        else:
            self.unshared(self.config.NPOP)

        self.params = Parameter(torch.Tensor(np.array(self.models)))
        self.opt = None
        if not self.config.TEST:
            self.opt = ManualAdam([self.params],
                                  lr=0.001,
                                  weight_decay=0.00001)

    # Initialize a new network
    def initModel(self):
        return getParameters(trinity.ANN(self.config))

    def shared(self, n):
        model = self.initModel()
        self.models = [model for _ in range(n)]

    def unshared(self, n):
        self.models = [self.initModel() for _ in range(n)]

    # Grads and clip
    def stepOpt(self, gradDicts):
        grads = defaultdict(list)
        keysets = [grads.keys() for grads in gradDicts]
        for gradDict in gradDicts:
            for worker, grad in gradDict.items():
                grads[worker].append(grad)
        for worker, gradList in grads.items():
            grad = np.array(gradList)
            grad = np.mean(grad, 0)
            grad = np.clip(grad, -5, 5)
            grads[worker] = grad
        gradAry = torch.zeros_like(self.params)
        for worker, grad in grads.items():
            gradAry[worker] = torch.Tensor(grad)
        self.opt.step(gradAry)

    def checkpoint(self, reward):
        if self.config.TEST:
            return
        self.saver.checkpoint(self.params, self.opt, reward)

    def load(self, best=False):
        print('Loading model...')
        epoch = self.saver.load(self.opt, self.params, best)

    @property
    def nParams(self):
        nParams = sum([len(e) for e in self.model])
        print('#Params: ', str(nParams / 1000), 'K')
        return nParams

    @property
    def model(self):
        return self.params.detach().numpy()
Example #10
class MimoLinearDynamicalOperator(torch.nn.Module):
    r"""Applies a multi-input-multi-output linear dynamical filtering operation.

    Args:
        in_channels (int): Number of input channels
        out_channels (int): Number of output channels
        n_b (int): Number of learnable coefficients of the transfer function numerator
        n_a (int): Number of learnable coefficients of the transfer function denominator
        n_k (int, optional): Number of input delays in the numerator. Default: 0

    Shape:
        - Input: (batch_size, seq_len, in_channels)
        - Output: (batch_size, seq_len, out_channels)

    Attributes:
        b_coeff (Tensor): The learnable coefficients of the transfer function numerator
        a_coeff (Tensor): The learnable coefficients of the transfer function denominator

    Examples::

        >>> in_channels, out_channels = 2, 4
        >>> n_b, n_a, n_k = 2, 2, 1
        >>> G = MimoLinearDynamicalOperator(in_channels, out_channels, n_b, n_a, n_k)
        >>> batch_size, seq_len = 32, 100
        >>> u_in = torch.ones((batch_size, seq_len, in_channels))
        >>> y_out = G(u_in, y_0, u_0) # shape: (batch_size, seq_len, out_channels)
    """
    def __init__(self, in_channels, out_channels, n_b, n_a, n_k=0):
        super(MimoLinearDynamicalOperator, self).__init__()
        self.b_coeff = Parameter(torch.zeros(out_channels, in_channels, n_b))
        self.a_coeff = Parameter(torch.zeros(out_channels, in_channels, n_a))
        self.out_channels = out_channels
        self.in_channels = in_channels
        self.n_a = n_a
        self.n_b = n_b
        self.n_k = n_k

        with torch.no_grad():
            init_range = 0.01
            self.a_coeff[:] = (torch.rand(self.a_coeff.shape) -
                               0.5) * 2 * init_range
            self.b_coeff[:] = (torch.rand(self.b_coeff.shape) -
                               0.5) * 2 * init_range

    def forward(self, u_in, y_0=None, u_0=None):
        if self.n_k != 0:
            #u_d = u_in.roll(self.n_k, dims=-2)  # roll on the time axis
            #u_d[..., 0:self.n_k, :] = 0.0  # input sequence with delay
            u_d = torch.empty_like(u_in)
            u_d[..., self.n_k:, :] = u_in[:, :-self.n_k, :]
            u_d[..., 0:self.n_k, :] = 0.0
        else:
            u_d = u_in
        return MimoLinearDynamicalOperatorFun.apply(self.b_coeff, self.a_coeff,
                                                    u_d, y_0, u_0)

    def get_filtdata(self):
        r"""Returns the numerator and denominator coefficients of the transfer function :math:`q^{-1}`-polynomials.

        The polynomials are functions of the variable :math:`q^{-1}`.
        The polynomial coefficients b and a have length m and n, respectively, and are sorted in descending power order.

        For a certain input channel :math:`i` and output channel :math:`o`, the corresponding transfer
        function :math:`G_{i\rightarrow o}(z)` is:

        .. math::
            G_{i\rightarrow o}(z) = q^{-n_k}\frac{b[o, i, 0] + b[o, i, 1]q^{-1} + \dots + b[o, i, m-1]q^{-m+1}}
            {a[o, i, 0] + a[o, i, 1]q^{-1} + \dots + a[o, i, n-1]q^{-n+1}}

        Returns:
            np.array(in_channels, out_channels, m), np.array(in_channels, out_channels, n):
                numerator :math:`\beta` and denominator :math:`\alpha` polynomial coefficients of the transfer function.


        Examples::

            >>> num, den = G.get_filtdata()
            >>> G_filt = control.TransferFunction(num, den, ts=1.0)
        """
        return self.__get_filtdata__()

    def get_tfdata(self):
        r"""Returns the numerator and denominator coefficients of the transfer function :math:`z`-polynomials.

        The polynomials are functions of the Z-transform variable :math:`z`.
        The polynomial coefficients :math:`\beta` and :math:`\alpha` have equal length p and are sorted in descending power order.

        For a certain input channel :math:`i` and output channel :math:`o`, the corresponding transfer
        function :math:`G_{i\rightarrow o}(z)` is:

        .. math::
            G_{i\rightarrow o}(z) = \frac{\beta[o, i, 0]z^{p-1} + \beta[o, i, 1]z^{p-2} + \dots + \beta[o, i, p-1]}{\alpha[o, i, 0]z^{p-1} + \alpha[o, i, 1]z^{p-2} + \dots + \alpha[o, i, p-1]}

        Returns:
            np.array(in_channels, out_channels, p), np.array(in_channels, out_channels, p):
                numerator :math:`\beta` and denominator :math:`\alpha` polynomial coefficients of the transfer function.


        Examples::

            >>> num, den = G.get_tfdata()
            >>> G_tf = control.TransferFunction(num, den, ts=1.0)
        """
        return self.__get_tfdata__()

    def __get_filtdata__(self):
        # returns the coefficients of the polynomials b and a as function of q^{-1}
        b_coeff_np, a_coeff_np = self.__get_ba_coeff__()
        b_seq = np.zeros_like(b_coeff_np,
                              shape=(self.out_channels, self.in_channels,
                                     self.n_b + self.n_k))  #b_coeff_np
        b_seq[:, :, self.n_k:] = b_coeff_np[:, :, :]
        a_seq = np.empty_like(a_coeff_np,
                              shape=(self.out_channels, self.in_channels,
                                     self.n_a + 1))
        a_seq[:, :, 0] = 1
        a_seq[:, :, 1:] = a_coeff_np[:, :, :]
        return b_seq, a_seq

    def __get_tfdata__(self):
        b_seq, a_seq = self.__get_filtdata__()
        M = self.n_b + self.n_k  # number of numerator coefficients of the q^{-1} polynomial
        N = self.n_a + 1  # number of denominator coefficients of the q^{-1} polynomial
        if M > N:
            num = b_seq
            den = np.c_[a_seq,
                        np.zeros((self.out_channels, self.in_channels, M - N))]
        elif N > M:
            num = np.c_[b_seq,
                        np.zeros((self.out_channels, self.in_channels, N - M))]
            den = a_seq
        else:  # N == M
            num = b_seq
            den = a_seq

        return num, den

    def __get_ba_coeff__(self):
        return self.b_coeff.detach().numpy(), self.a_coeff.detach().numpy()
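
A sketch that cross-checks the coefficient layout returned by get_filtdata() against scipy for a single input/output channel; scipy is not used by the class itself, and numpy >= 1.17 is assumed for the shape argument of zeros_like/empty_like above.

import numpy as np
from scipy.signal import lfilter

G = MimoLinearDynamicalOperator(in_channels=1, out_channels=1, n_b=2, n_a=2, n_k=1)
b_seq, a_seq = G.get_filtdata()            # shapes: (1, 1, n_b + n_k) and (1, 1, n_a + 1)
u = np.random.randn(200)
y = lfilter(b_seq[0, 0], a_seq[0, 0], u)   # same q^{-1}-polynomial convention as lfilter
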
Example #11
class NESConv2d(nn.Conv2d):
    def __init__(self,
                 in_planes,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 multiplier=1.0,
                 spatial_multiplier=1.,
                 rep_dim=1,
                 repeat_weight=True,
                 use_coeff=False):
        if rep_dim == 0:
            # this is repeat along the channel dim (dimension 0 of the weights tensor)
            super(NESConv2d,
                  self).__init__(int(in_planes),
                                 int(np.ceil(out_channels / multiplier)),
                                 int(kernel_size * spatial_multiplier),
                                 stride=stride,
                                 padding=padding,
                                 dilation=dilation,
                                 groups=groups,
                                 bias=bias)
            self.rep_time = int(
                np.ceil(1. * out_channels / self.weight.shape[0]))
        elif rep_dim == 1:
            # this is to repeat along the filter dim(dimension 1 of the weights tensor)
            super(NESConv2d,
                  self).__init__(int(np.ceil(in_planes / multiplier)),
                                 int(out_channels),
                                 int(kernel_size * spatial_multiplier),
                                 stride=stride,
                                 padding=padding,
                                 dilation=dilation,
                                 groups=groups,
                                 bias=bias)
            self.rep_time = int(np.ceil(1. * in_planes / self.weight.shape[1]))

        self.in_planes = in_planes
        self.out_channels_ori = out_channels
        self.groups = groups
        self.multiplier = multiplier
        self.spatial_multiplier = spatial_multiplier
        # specify the range for the w and h direction
        self.kernel_size = kernel_size
        self.w_wange = kernel_size * (spatial_multiplier - 1)
        self.h_wange = kernel_size * (spatial_multiplier - 1)

        self.rep_dim = rep_dim
        self.repeat_weight = repeat_weight
        self.use_coeff = use_coeff
        # print(self.weight.shape)
        # import pdb; pdb.set_trace()
        if spatial_multiplier > 1:
            out_num = int(self.rep_time * 3)
            self.conv1_stride_lr_1 = nn.Conv2d(in_planes,
                                               in_planes,
                                               kernel_size=3,
                                               stride=2,
                                               padding=0,
                                               bias=False)
            self.bn1 = nn.BatchNorm2d(in_planes)
            self.relu = nn.ReLU6(inplace=True)
            self.conv1_stride_lr_2 = nn.Conv2d(in_planes,
                                               out_num,
                                               kernel_size=1,
                                               stride=1,
                                               padding=0,
                                               bias=False)
            self.bn2 = nn.BatchNorm2d(out_num)
            self.coefficient = Parameter(torch.Tensor(out_num),
                                         requires_grad=False)
        else:
            out_num = int(self.rep_time)
            self.conv1_stride_lr_1 = nn.Conv2d(in_planes,
                                               in_planes,
                                               kernel_size=3,
                                               stride=2,
                                               padding=0,
                                               bias=False)
            self.bn1 = nn.BatchNorm2d(in_planes)
            self.relu = nn.ReLU6(inplace=True)
            self.conv1_stride_lr_2 = nn.Conv2d(in_planes,
                                               out_num,
                                               kernel_size=1,
                                               stride=1,
                                               padding=0,
                                               bias=False)
            self.bn2 = nn.BatchNorm2d(out_num)
            self.coefficient = Parameter(torch.Tensor(out_num),
                                         requires_grad=False)
        self.reuse = False
        self.coeff_grad = None

    def generate_share_weight(self,
                              base_weight,
                              rep_num,
                              coeff,
                              nchannel,
                              dim=0):
        ''' sample weights from base weight'''
        # pdb.set_trace()
        if rep_num == 1:
            return base_weight
        new_weight = []
        for i in range(rep_num):
            w_idx = coeff[i * 3 + 1] * self.w_wange
            start_idx_w = int(w_idx)
            end_idx_w = start_idx_w + self.kernel_size
            w_frac = w_idx - start_idx_w

            h_idx = coeff[i * 3 + 2] * self.h_wange
            start_idx_h = int(h_idx)
            end_idx_h = start_idx_h + self.kernel_size
            h_frac = h_idx - start_idx_h

            new_weight_temp = torch.cat(
                [base_weight[:, :, :, :], base_weight[:, :, :, :]],
                dim=2)[:, :, start_idx_w:end_idx_w, :] * (
                    1 - w_frac) + base_weight * w_frac
            new_weight_temp = torch.cat(
                [base_weight[:, :, :, :], base_weight[:, :, :, :]],
                dim=3)[:, :, :, start_idx_h:end_idx_h] * (
                    1 - h_frac) + base_weight * h_frac
            if dim == 0:
                new_weight_temp = torch.cat(
                    [base_weight[1:, :, :, :], base_weight[0:1, :, :, :]],
                    dim=0) * (1 - coeff[int(i * 3)])
            else:
                new_weight_temp = torch.cat(
                    [base_weight[:, 1:, :, :], base_weight[:, 0:1, :, :]],
                    dim=1) * (1 - coeff[i])
            new_weight.append(base_weight * coeff[int(i * 3)] +
                              new_weight_temp)
        out = torch.cat(new_weight, dim=dim)

        if dim == 0:
            return out[:nchannel, :, :, :]
        else:
            return out[:, :nchannel, :, :]

    def forward(self, x):
        ih, iw = x.size()[-2:]
        kh, kw = self.weight.size()[-2:]
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        pad_h = max(
            (oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih,
            0)
        pad_w = max(
            (ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw,
            0)
        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, [
                pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2
            ])

        if self.use_coeff:
            if self.training:
                # set reuse to True for coefficient sharing
                if not self.reuse:
                    lr_conv1 = self.relu(self.bn1(self.conv1_stride_lr_1(x)))
                    # pdb.set_trace()
                    lr_conv1 = self.bn2(self.conv1_stride_lr_2(lr_conv1))
                    lr_conv1 = F.adaptive_avg_pool2d(lr_conv1, (1, 1))[:, :, 0,
                                                                       0]

                    self.coefficient.set_(
                        F.normalize(torch.mean(lr_conv1, 0),
                                    dim=0).clone().detach())
                    # pdb.set_trace()
                    self.coeff_grad = F.normalize(torch.mean(lr_conv1, 0),
                                                  dim=0)

                if self.rep_dim == 0:
                    if self.repeat_weight:
                        out = F.conv2d(
                            x,
                            self.generate_share_weight(self.weight,
                                                       self.rep_time,
                                                       self.coeff_grad,
                                                       self.out_channels_ori,
                                                       dim=0))
                    else:
                        out_tmp = F.conv2d(x, self.weight)
                        out = self.generate_share_feature(
                            self.rep_time, out_tmp,
                            F.normalize(torch.mean(lr_conv1, 0), dim=0),
                            self.out_channels_ori)
                        # out = F.conv2d(x, self.generate_share_feature(self.rep_time, F.normalize(torch.mean(lr_conv1, 0), dim = 0)))
                else:
                    if self.repeat_weight:
                        out = F.conv2d(
                            x,
                            self.generate_share_weight(self.weight,
                                                       self.rep_time,
                                                       self.coeff_grad,
                                                       x.shape[1],
                                                       dim=1))
                    else:
                        out_tmp = self.feature_wrapper(
                            self.rep_time, x,
                            F.normalize(torch.mean(lr_conv1, 0), dim=0),
                            self.out_channels_ori)
                        out = F.conv2d(out_tmp, self.weight)

            else:
                if self.rep_dim == 0:
                    if self.repeat_weight:
                        out = F.conv2d(
                            x,
                            self.generate_share_weight(
                                self.weight,
                                self.rep_time,
                                self.coefficient.detach(),
                                self.out_channels_ori,
                                dim=0))
                    else:
                        out_tmp = F.conv2d(x, self.weight)
                        out = self.generate_share_feature(
                            self.rep_time, out_tmp, self.coefficient.detach(),
                            self.out_channels_ori)
                        # out = F.conv2d(x, self.generate_share_feature(self.rep_time, self.coefficient.detach()))
                else:
                    if self.repeat_weight:
                        out = F.conv2d(
                            x,
                            self.generate_share_weight(
                                self.weight,
                                self.rep_time,
                                self.coefficient.detach(),
                                x.shape[1],
                                dim=1))
                    else:
                        out_tmp = self.feature_wrapper(
                            self.rep_time, x, self.coefficient.detach(),
                            self.out_channels_ori)

                        out = F.conv2d(out_tmp, self.weight)
        else:
            # print("use_coeff == False")
            if self.rep_dim == 0:
                out = F.conv2d(
                    x,
                    self.weight.repeat([self.rep_time, 1, 1,
                                        1])[:self.out_channels_ori, :, :, :],
                    None, 1)
            else:
                out = F.conv2d(
                    x,
                    self.weight.repeat([1, self.rep_time, 1,
                                        1])[:, :x.shape[1], :, :], None, 1)
        return out
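
A rough sketch of the weight-repetition path (use_coeff=False, rep_dim=0): with multiplier=2 only half of the 32 output filters are learned parameters and the remainder are repeats of them; the auxiliary coefficient branch is created but unused on this path.

conv = NESConv2d(16, 32, kernel_size=3, multiplier=2.0, rep_dim=0, use_coeff=False)
x = torch.randn(2, 16, 8, 8)
y = conv(x)   # (2, 32, 8, 8); the learned weight tensor itself is (16, 16, 3, 3)
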
Example #12
class Embedding(torch.nn.Module):
    """
    General way to handle embeddings

    * Support for sequential models
    * Memory efficient way to compute weighted EmbeddingBag

    Arguments:
    ----------
    num_embeddings: int
        vocabulary size
    embedding_dim: int
        dimension for embeddings
    padding_idx: 0 or None, optional (default=None)
        index for <PAD>; embedding is not updated
    max_norm: None or float, optional (default=None)
        maintain norm of embeddings
    norm_type: int, optional (default=2)
        norm for max_norm
    scale_grad_by_freq: boolean, optional (default=False)
        Scale gradients by token frequency
    sparse: boolean, optional (default=False)
        sparse or dense gradients
        * the optimizer will infer from this parameter
    reduction: str or None, optional (default=None)
        * None: don't reduce
        * sum: sum over tokens
        * mean: mean over tokens
    pretrained_weights: torch.Tensor or None, optional (default=None)
        Initialize with these weights
        * first token is treated as a padding index
        * shape[0] should be num_embeddings - 1 (the padding row is kept from initialization)
    device: str, optional (default="cuda:0")
        Keep embeddings on this device
    """
    def __init__(self,
                 num_embeddings,
                 embedding_dim,
                 padding_idx=None,
                 max_norm=None,
                 norm_type=2,
                 scale_grad_by_freq=False,
                 sparse=False,
                 reduction=None,
                 pretrained_weights=None,
                 device="cuda:0"):
        super(Embedding, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim))
        self.sparse = sparse
        self.reduce = self._construct_reduce(reduction)
        self.reduction = reduction
        self.device = torch.device(device)
        self.reset_parameters()
        if pretrained_weights is not None:
            self.from_pretrained(pretrained_weights)

    def _construct_reduce(self, reduction):
        if reduction is None:
            return self._reduce
        elif reduction == 'sum':
            return self._reduce_sum
        elif reduction == 'mean':
            return self._reduce_mean
        else:
            raise NotImplementedError(f"Unknown reduction: {reduction}")

    def reset_parameters(self):
        """
            Reset weights
        """
        torch.nn.init.xavier_uniform_(
            self.weight.data, gain=torch.nn.init.calculate_gain('relu'))
        if self.padding_idx is not None:
            self.weight.data[self.padding_idx].fill_(0)

    def to(self):
        super().to(self.device)

    def _reduce_sum(self, x, w):
        if w is None:
            return torch.sum(x, dim=1)
        else:
            return torch.sum(x * w.unsqueeze(2), dim=1)

    def _reduce_mean(self, x, w):
        if w is None:
            return torch.mean(x, dim=1)
        else:
            return torch.mean(x * w.unsqueeze(2), dim=1)

    def _reduce(self, x, *args):
        return x

    def forward(self, x, w=None):
        """
        Forward pass for embedding layer

        Arguments:
        ---------
        x: torch.LongTensor
            indices of tokens in a batch
            (batch_size, max_features_in_a_batch)
        w: torch.Tensor or None, optional (default=None)
            weights of tokens in a batch
            (batch_size, max_features_in_a_batch)

        Returns:
        --------
        out: torch.Tensor
            embedding for each sample
            Shape: (batch_size, seq_len, embedding_dims), if reduction is None
            Shape: (batch_size, embedding_dims), otherwise
        """
        x = F.embedding(x, self.weight, self.padding_idx, self.max_norm,
                        self.norm_type, self.scale_grad_by_freq, self.sparse)
        return self.reduce(x, w)

    def from_pretrained(self, embeddings):
        # first index is treated as padding index
        assert embeddings.shape[0] == self.num_embeddings-1, \
            "Shapes doesn't match for pre-trained embeddings"
        self.weight.data[1:, :] = torch.from_numpy(embeddings)

    def get_weights(self):
        return self.weight.detach().cpu().numpy()[1:, :]

    def __repr__(self):
        s = '{name}({num_embeddings}, {embedding_dim}, {device}'
        s += ', reduction={reduction}'
        if self.padding_idx is not None:
            s += ', padding_idx={padding_idx}'
        if self.max_norm is not None:
            s += ', max_norm={max_norm}'
        if self.norm_type != 2:
            s += ', norm_type={norm_type}'
        if self.scale_grad_by_freq is not False:
            s += ', scale_grad_by_freq={scale_grad_by_freq}'
        if self.sparse is not False:
            s += ', sparse=True'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
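
# Minimal usage sketch (not part of the original example): builds the Embedding
# defined above on CPU, feeds a small batch of token indices with per-token
# weights, and applies the 'sum' reduction. The vocabulary size and dimensions
# below are illustrative assumptions.
def _demo_embedding_reduction():
    emb = Embedding(num_embeddings=100, embedding_dim=8,
                    padding_idx=0, reduction='sum', device="cpu")
    tokens = torch.LongTensor([[1, 2, 0], [3, 4, 5]])      # (batch, max_tokens)
    weights = torch.Tensor([[0.5, 0.5, 0.0], [1.0, 1.0, 1.0]])
    out = emb(tokens, weights)                              # (batch, embedding_dim)
    assert out.shape == (2, 8)
    return out
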
Beispiel #13
0
class NeuroRNN(nn.Module):
    def __init__(self, input_size, output_dim, alpha_r, alpha_s, nonlinearity,
                 hidden_size, bias, ratio):

        super(NeuroRNN, self).__init__()
        self.input_size = input_size
        self.nonlinearity = nonlinearity
        self.hidden_size = hidden_size
        self.alpha_r = alpha_r
        self.alpha_s = alpha_s
        self.bias = bias
        self.output_dim = output_dim
        self.Win = Parameter(torch.Tensor(hidden_size, input_size))
        self.Win.requires_grad = True
        self.Wrec = Parameter(torch.Tensor(hidden_size, hidden_size))
        self.Wrec.requires_grad = True
        self.Wout = Parameter(torch.Tensor(output_dim, hidden_size))
        self.Wout.requires_grad = True
        self.ratio = ratio

        if bias:
            self.bin = Parameter(torch.Tensor(hidden_size, hidden_size))
            self.brec = Parameter(torch.Tensor(hidden_size))

        # init weights
        nn.init.orthogonal_(self.Win)
        nn.init.orthogonal_(self.Wrec)
        nn.init.uniform_(self.Wout)

    def nonlin(self, inp):

        if self.nonlinearity == 'relu':
            inp = torch.relu(inp)
            return inp
        else:
            inp = torch.tanh(inp)
            return inp

    def forward(self, x):

        # x shape: (batch_size, seq_len, input_size)
        I = torch.zeros(self.hidden_size,
                        x.size(0)).to(device)  # (hidden_size, batch_size)
        r = torch.zeros(self.hidden_size, x.size(0)).to(device)
        out = torch.zeros(x.size(0), x.size(1), self.output_dim).to(
            device)  # (batch_size, seq_len, output_dim)

        # for pascanu regularization
        #r_total = torch.zeros(x.size(0), self.hidden_size, x.size(1))
        #r_total.requires_grad = True

        for t in range(x.size(1)):
            I = (1 - self.alpha_s) * I + self.alpha_s * (
                torch.mm(self.Wrec, r) + torch.mm(self.Win, x[:, t].T))
            r = (1 - self.alpha_r) * r + self.alpha_r * (self.nonlin(I))
            #r_total[:, :, t] = r.T
            out[:, t, :] = torch.mm(self.Wout, r).T

        del I
        del r

        return out

    def dale_weight_init(self):

        with torch.no_grad():

            num_exc = int(self.ratio[0] * self.hidden_size)
            num_inh = int(self.hidden_size - num_exc)

            D = torch.diag_embed(
                torch.cat((torch.ones(num_exc), -1 * torch.ones(num_inh))))
            self.Wrec = torch.nn.Parameter(
                torch.abs(self.Wrec.detach()).matmul(D))

    def enforce_dale(self):

        with torch.no_grad():

            num_exc = int(self.ratio[0] * self.hidden_size)
            num_inh = int(self.hidden_size - num_exc)

            self.Wrec[:num_exc, :].clamp_(min=0)
            self.Wrec[num_exc:, :].clamp_(max=0)
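
# Minimal usage sketch (not part of the original example): runs the NeuroRNN
# defined above on a dummy batch and applies the Dale's-law initialization and
# sign projection. It assumes the module-level `device` used in forward() has
# been defined (e.g. device = torch.device('cpu')) and uses bias=False so the
# uninitialized bias parameters are never created.
def _demo_neuro_rnn():
    rnn = NeuroRNN(input_size=5, output_dim=2, alpha_r=0.5, alpha_s=0.5,
                   nonlinearity='relu', hidden_size=16, bias=False,
                   ratio=(0.8, 0.2))
    rnn.dale_weight_init()          # excitatory columns positive, inhibitory negative
    x = torch.randn(3, 10, 5)       # (batch, seq_len, input_size)
    out = rnn(x)                    # (batch, seq_len, output_dim)
    rnn.enforce_dale()              # re-clamp signs, e.g. after optimizer updates
    assert out.shape == (3, 10, 2)
    return out
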
Beispiel #14
0
class Model:
    '''Model manager class

   Convenience class wrapping saving/loading,
   model initialization, optimization, and logging.

   Args:
      ann: Model to optimize. Used to initialize weights.
      config: A Config specification
   '''
    def __init__(self, ann, config):
        self.saver = save.Saver(config.MODELDIR,
                                'models',
                                'bests',
                                resetTol=256)
        self.config = config

        print('Initializing new model...')
        self.net = ann(config)
        self.parameters = Parameter(
            torch.Tensor(np.array(getParameters(self.net))))

        #Have been experimenting with population based
        #training. Nothing stable yet -- advise avoiding
        if config.POPOPT:
            self.opt = PopulationOptimizer(self, config)
        else:
            self.opt = GradientOptimizer(self, config)

        if config.LOAD or config.BEST:
            self.load(self.opt, config.BEST)

    def step(self, recvs, blobs, log, lifetime):
        if self.config.TEST:
            return

        self.opt.step(recvs, blobs)
        perf, _ = self.checkpoint(self.opt, lifetime)
        return perf

    def load(self, opt, best=False):
        '''Load a model from file

      Args:
         best (bool): Whether to load the best (True)
             or most recent (False) checkpoint
      '''
        print('Loading model...')
        epoch = self.saver.load(opt, self.parameters, best)
        self.syncParameters()
        return self

    def checkpoint(self, opt, lifetime):
        '''Save the model to checkpoint

      Args:
         reward: Mean reward of the model
      '''
        return self.saver.checkpoint(self.parameters, opt, lifetime)

    def printParams(self):
        '''Print the number of model parameters'''
        nParams = len(self.weights)
        print('#Params: ', str(nParams / 1000), 'K')

    def syncParameters(self):
        parameters = self.parameters.detach().numpy().tolist()
        setParameters(self.net, parameters)

    @property
    def weights(self):
        '''Get model parameters

      Returns:
         a numpy array of model parameters
      '''
        return getParameters(self.net)
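
# Hypothetical sketch (not part of the original example): the Model class above
# relies on project-specific getParameters/setParameters helpers to move between
# a network's parameters and the single flat Parameter vector it stores. The
# helpers below are illustrative stand-ins, not the project's actual code.
def _flatten_parameters(net):
    """Concatenate all parameters of `net` into one flat Python list."""
    return [v for p in net.parameters()
            for v in p.detach().cpu().numpy().ravel().tolist()]


def _unflatten_parameters(net, flat):
    """Write a flat list of values back into the parameters of `net`."""
    idx = 0
    with torch.no_grad():
        for p in net.parameters():
            n = p.numel()
            p.copy_(torch.tensor(flat[idx:idx + n], dtype=p.dtype).view_as(p))
            idx += n
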
Beispiel #15
0
class Embedding(th.nn.Module):
    def __init__(self):
        super(Embedding, self).__init__()

    
    def get_embed_list(self, sent_pair: list) -> np.ndarray:
        raise "not defined"
    
    def get_similarity(self, X1 : np.ndarray, X2 : np.ndarray) -> np.ndarray:
        similarity = X1 @ self.weight.detach().numpy() @ (X2.T)
        return similarity
#        return th.Tensor((cosine_similarity(X1, X2) + 1.0) / 2.0)
        
    def normalize(self):
        self.weight = Parameter(f.normalize(self.weight))
    
    def apply_distortion(self, sim_matrix: np.ndarray, ratio: float = 0.5) -> np.ndarray:
        shape = sim_matrix.shape
        if (shape[0] < 2 or shape[1] < 2) or ratio == 0.0:
            return sim_matrix

        pos_x = np.array([[y / float(shape[1] - 1) for y in range(shape[1])] for x in range(shape[0])])
        pos_y = np.array([[x / float(shape[0] - 1) for x in range(shape[0])] for y in range(shape[1])])

        distortion_mask = -abs((pos_x - np.transpose(pos_y))) * ratio

        return sim_matrix + distortion_mask



    
    def eta_U(self, f, e, distortion=False):
        data = [f,e]
        src_embedding, trg_embedding = self.get_embed_list(data)

        data_save = {}

        scale = min(len(src_embedding), len(trg_embedding))
        # src_indexs = [int(w) for w in f]
        # trg_indexs = [int(w) for w in f]

        data_save['scale'] = scale
        data_save['src_embedding'] = src_embedding
        data_save['trg_embedding'] = trg_embedding

        similarity = self.get_similarity(src_embedding, trg_embedding)

        self.save_for_backward(data_save)

        similarity = similarity - self.maxvalue - np.log(self.sum_exp_sim)
        if distortion:
            similarity = self.apply_distortion(similarity)
        
        return similarity

    def re_compute_sum_exp_norm(self):
        temp = self.src_words_vec @ self.weight.detach().numpy() @  (self.trg_words_vec.T)
        self.maxvalue = np.max(temp)
        exp_value = np.exp( temp - self.maxvalue )
        self.sum_exp_sim = np.sum(exp_value)
        softmax_value = exp_value / self.sum_exp_sim 
        self.softmax_gradient = self.src_words_vec.T @ softmax_value @ self.trg_words_vec
        print(self.sum_exp_sim)
        #self.backword_softmax = th.Tensor(self.src_words_vec).T @ (th.Tensor(self.softmax_value) @ th.Tensor(self.trg_words_vec))
        
        
    def forward(self, data : list):
        f, e = data
        similarity = self.eta_U(f, e)
        return th.Tensor(similarity)
        
    
    def backward(self,g):
        data_save = self.saved_tensors
        temp = data_save['scale']*self.softmax_gradient 
        second = data_save['src_embedding'].T @ g.numpy() @ data_save['trg_embedding']

        
        self.weight.backward( th.Tensor(temp - second) )

        
    def save_for_backward(self, t):
        self.saved_tensors = t
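
# Minimal sketch (not part of the original example): shows the effect of
# apply_distortion defined above on a toy similarity matrix. Entries whose
# relative source/target positions are far apart are penalized in proportion
# to `ratio`; the matrix size below is an illustrative assumption.
def _demo_apply_distortion():
    emb = Embedding()                       # the alignment Embedding defined above
    sim = np.ones((3, 4))                   # toy source-by-target similarities
    distorted = emb.apply_distortion(sim, ratio=0.5)
    assert distorted.shape == (3, 4)
    # the off-diagonal corner cell receives the largest positional penalty
    assert distorted[0, -1] < distorted[0, 0]
    return distorted
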
Beispiel #16
0
class _ConvBnNd(nn.modules.conv._ConvNd, nni._FusedModule):

    _version = 2
    _FLOAT_MODULE = MOD

    def __init__(self,
                 # ConvNd args
                 in_channels, out_channels, kernel_size, stride,
                 padding, dilation, transposed, output_padding,
                 groups,
                 bias,
                 padding_mode,
                 # BatchNormNd args
                 # num_features: out_channels
                 eps=1e-05, momentum=0.1,
                 # affine: True
                 # track_running_stats: True
                 # Args for this module
                 freeze_bn=False,
                 qconfig=None,
                 dim=2):
        nn.modules.conv._ConvNd.__init__(self, in_channels, out_channels, kernel_size,
                                         stride, padding, dilation, transposed,
                                         output_padding, groups, False, padding_mode)
        assert qconfig, 'qconfig must be provided for QAT module'
        self.qconfig = qconfig
        self.freeze_bn = freeze_bn if self.training else True
        self.bn = _BN_CLASS_MAP[dim](out_channels, eps, momentum, True, True)
        self.weight_fake_quant = self.qconfig.weight()
        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_bn_parameters()

        # this needs to be called after reset_bn_parameters,
        # as they modify the same state
        if self.training:
            if freeze_bn:
                self.freeze_bn_stats()
            else:
                self.update_bn_stats()
        else:
            self.freeze_bn_stats()

    def reset_running_stats(self):
        self.bn.reset_running_stats()

    def reset_bn_parameters(self):
        self.bn.reset_running_stats()
        init.uniform_(self.bn.weight)
        init.zeros_(self.bn.bias)
        # note: below is actually for conv, not BN
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def reset_parameters(self):
        super(_ConvBnNd, self).reset_parameters()

    def update_bn_stats(self):
        self.freeze_bn = False
        self.bn.training = True
        return self

    def freeze_bn_stats(self):
        self.freeze_bn = True
        self.bn.training = False
        return self

    def _forward(self, input):
        assert self.bn.running_var is not None
        running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
        scale_factor = self.bn.weight / running_std
        weight_shape = [1] * len(self.weight.shape)
        weight_shape[0] = -1
        bias_shape = [1] * len(self.weight.shape)
        bias_shape[1] = -1
        scaled_weight = self.weight_fake_quant(self.weight * scale_factor.reshape(weight_shape))
        # using zero bias here since the bias for original conv
        # will be added later
        if self.bias is not None:
            zero_bias = torch.zeros_like(self.bias)
        else:
            zero_bias = torch.zeros(self.out_channels, device=scaled_weight.device)
        conv = self._conv_forward(input, scaled_weight, zero_bias)
        conv_orig = conv / scale_factor.reshape(bias_shape)
        if self.bias is not None:
            conv_orig = conv_orig + self.bias.reshape(bias_shape)
        conv = self.bn(conv_orig)
        return conv

    def extra_repr(self):
        # TODO(jerryzh): extend
        return super(_ConvBnNd, self).extra_repr()

    def forward(self, input):
        return self._forward(input)

    def train(self, mode=True):
        """
        Batchnorm's training behavior is using the self.training flag. Prevent
        changing it if BN is frozen. This makes sure that calling `model.train()`
        on a model with a frozen BN will behave properly.
        """
        self.training = mode
        if not self.freeze_bn:
            for module in self.children():
                module.train(mode)
        return self

    # ===== Serialization version history =====
    #
    # Version 1/None
    #   self
    #   |--- weight : Tensor
    #   |--- bias : Tensor
    #   |--- gamma : Tensor
    #   |--- beta : Tensor
    #   |--- running_mean : Tensor
    #   |--- running_var : Tensor
    #   |--- num_batches_tracked : Tensor
    #
    # Version 2
    #   self
    #   |--- weight : Tensor
    #   |--- bias : Tensor
    #   |--- bn : Module
    #        |--- weight : Tensor (moved from v1.self.gamma)
    #        |--- bias : Tensor (moved from v1.self.beta)
    #        |--- running_mean : Tensor (moved from v1.self.running_mean)
    #        |--- running_var : Tensor (moved from v1.self.running_var)
    #        |--- num_batches_tracked : Tensor (moved from v1.self.num_batches_tracked)
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get('version', None)
        if version is None or version == 1:
            # BN related parameters and buffers were moved into the BN module for v2
            v2_to_v1_names = {
                'bn.weight': 'gamma',
                'bn.bias': 'beta',
                'bn.running_mean': 'running_mean',
                'bn.running_var': 'running_var',
                'bn.num_batches_tracked': 'num_batches_tracked',
            }
            for v2_name, v1_name in v2_to_v1_names.items():
                if prefix + v1_name in state_dict:
                    state_dict[prefix + v2_name] = state_dict[prefix + v1_name]
                    state_dict.pop(prefix + v1_name)
                elif prefix + v2_name in state_dict:
                    # there was a brief period where forward compatibility
                    # for this module was broken (between
                    # https://github.com/pytorch/pytorch/pull/38478
                    # and https://github.com/pytorch/pytorch/pull/38820)
                    # and modules emitted the v2 state_dict format while
                    # specifying that version == 1. This patches the forward
                    # compatibility issue by allowing the v2 style entries to
                    # be used.
                    pass
                elif strict:
                    missing_keys.append(prefix + v2_name)

        super(_ConvBnNd, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

    @classmethod
    def from_float(cls, mod):
        r"""Create a qat module from a float module or qparams_dict

            Args: `mod` a float module, either produced by torch.ao.quantization utilities
            or directly from user
        """
        # The ignore is because _FLOAT_MODULE is a TypeVar here where the bound
        # has no __name__ (code is fine though)
        assert type(mod) == cls._FLOAT_MODULE, 'qat.' + cls.__name__ + '.from_float only works for ' + \
            cls._FLOAT_MODULE.__name__  # type: ignore[attr-defined]
        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
        assert mod.qconfig, 'Input float module must have a valid qconfig'
        qconfig = mod.qconfig
        conv, bn = mod[0], mod[1]
        qat_convbn = cls(conv.in_channels, conv.out_channels, conv.kernel_size,
                         conv.stride, conv.padding, conv.dilation,
                         conv.groups, conv.bias is not None,
                         conv.padding_mode,
                         bn.eps, bn.momentum,
                         False,
                         qconfig)
        qat_convbn.weight = conv.weight
        qat_convbn.bias = conv.bias
        qat_convbn.bn.weight = bn.weight
        qat_convbn.bn.bias = bn.bias
        qat_convbn.bn.running_mean = bn.running_mean
        qat_convbn.bn.running_var = bn.running_var
        # mypy error: Cannot determine type of 'num_batches_tracked'
        qat_convbn.bn.num_batches_tracked = bn.num_batches_tracked  # type: ignore[has-type]
        return qat_convbn

    def to_float(self):
        cls = type(self)
        conv = cls._FLOAT_CONV_MODULE(  # type: ignore[attr-defined]
            self.in_channels,
            self.out_channels,
            self.kernel_size,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
            self.bias is not None,
            self.padding_mode)
        conv.weight = torch.nn.Parameter(self.weight.detach())
        if self.bias is not None:
            conv.bias = torch.nn.Parameter(self.bias.detach())

        if cls._FLOAT_BN_MODULE:  # type: ignore[attr-defined]
            # fuse bn into conv
            conv.weight, conv.bias = fuse_conv_bn_weights(
                conv.weight,
                conv.bias,
                self.bn.running_mean,
                self.bn.running_var,
                self.bn.eps,
                self.bn.weight,
                self.bn.bias
            )

        if cls._FLOAT_RELU_MODULE:  # type: ignore[attr-defined]
            modules = []
            modules.append(conv)
            relu = cls._FLOAT_RELU_MODULE()  # type: ignore[attr-defined]
            modules.append(relu)
            conv_relu = cls._FUSED_FLOAT_MODULE(*modules)  # type: ignore[attr-defined]
            conv_relu.train(self.training)
            return conv_relu
        else:
            conv.train(self.training)
            return conv
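
# Minimal numerical sketch (not part of the original example): checks the
# batch-norm folding identity used in _ConvBnNd._forward above. Convolving with
# the weight scaled by gamma / sqrt(running_var + eps) and then dividing the
# output by the same per-channel factor recovers the unscaled convolution.
# Tensor shapes below are illustrative assumptions.
def _check_bn_folding_identity():
    import torch.nn.functional as F

    x = torch.randn(2, 3, 8, 8)
    weight = torch.randn(4, 3, 3, 3)
    gamma = torch.rand(4) + 0.5
    running_var = torch.rand(4) + 0.5
    eps = 1e-5

    scale = gamma / torch.sqrt(running_var + eps)            # per output channel
    scaled_weight = weight * scale.reshape(-1, 1, 1, 1)
    folded = F.conv2d(x, scaled_weight) / scale.reshape(1, -1, 1, 1)
    assert torch.allclose(folded, F.conv2d(x, weight), atol=1e-4)
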
class Conv2d_filtermap_1x1_compression(_ConvNd_filtermap_1x1_compression):
    def __init__(self, in_channels, out_channels, channel_compression_1x1, kernel_size, binary_filtermap = False, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        
        super(Conv2d_filtermap_1x1_compression, self).__init__(
            in_channels, out_channels, channel_compression_1x1, kernel_size, stride, padding, dilation,
            False, _pair(0), groups, bias)
        
        self.binary_filtermap = binary_filtermap
        self.reset_parameters1()
        #pdb.set_trace()
        #self.conv_weight = Parameter(torch.Tensor(out_channels, self.in_channels // self.groups, \
        #              self.kernel_size[0], self.kernel_size[1]))
        #self.conv_weight.requires_grad = False
        #self.conv_weight.cuda() 
    def reset_parameters1(self):
        fm_size = self.filtermap.size()
        fm_width = fm_size[2]
        fm_height = fm_size[1]
        fm_depth = fm_size[0]
        # not for 1x1 conv, do the padding on the spatial 
        if self.filtermap.size()[1] > 1 and self.filtermap.size()[2] > 1:
           self.fm_pad_width = fm_width + 1
           self.fm_pad_height = fm_height + 1
        #for 1x1 conv no padding on the spatial 
        else:
           self.fm_pad_width = fm_width
           self.fm_pad_height = fm_height

        self.fm_pad_depth = fm_depth*2
        #set the ids for extracting filters from filtermap
        out_channels = self.out_channels
        in_channels = self.in_channels // self.groups
        k_h = self.kernel_size[0]
        k_w = self.kernel_size[1]

        sample_y = self.sample_y
        sample_x = self.sample_x
        sample_c = self.sample_c

        stride_y = self.stride_y
        stride_x = self.stride_x
        stride_c = self.stride_c
        
 
        fm_depth = self.fm_pad_depth
        fm_height = self.fm_pad_height
        fm_width = self.fm_pad_width
        
        
        ids = (torch.Tensor(range(0,k_h*k_w)))
        tmp_count = 0
        for y in range(0,k_h):
            for x in range(0,k_w):
                ids[tmp_count] = y*fm_width+x
                tmp_count = tmp_count+1
 
        ids0 = ids
               
        #pdb.set_trace() 
        for c in range(1,in_channels):
            ids_c = ids0 + c*fm_height*fm_width
            ids = torch.cat((ids,ids_c),0)
        
        #ids0 = ids
        #for x in range(1, out_channels):
        #    ids = torch.cat((ids,ids0),0)
        #pdb.set_trace()
        ids0 = ids
        for y in range(0,sample_y):
            for x in range(0,sample_x):
                if y == 0 and x == 0:
                   continue
                ss = y*stride_y*fm_width + x*stride_x
                ids_ss = ids0+ss
                ids = torch.cat((ids,ids_ss),0)
        
        #pdb.set_trace() 
        ids0 = ids
        for c in range(1,sample_c):
            ids_c = ids0+c*stride_c*fm_height*fm_width
            ids = torch.cat((ids,ids_c),0)
        
        #pdb.set_trace()
        #ids = ids.long()
        #ids = ids.detach()

        #pdb.set_trace()
        ids = ids.long()
        self.ids = Parameter(ids)
        self.ids.requires_grad = False
        #self.register_parameter()
        #if torch.max(ids) >= fm_depth*fm_height*fm_width or torch.min(ids) < 0:
        #print(torch.max(ids))
        #ids = Variable(ids)
    def extract_filters(self):
        #pdb.set_trace()
         
        out_channels = self.out_channels
        in_channels = self.in_channels // self.groups 
        k_h = self.kernel_size[0]
        k_w = self.kernel_size[1]

        #for compressing the channel by 2 times
        #if in_channels != 3:
        #   filtermap_pad_tmp = torch.cat((self.filtermap,self.filtermap),0)
        #else:
        #   filtermap_pad_tmp = self.filtermap
        #filtermap_pad = torch.cat((filtermap_pad_tmp,filtermap_pad_tmp),0)
        
        #for not compressing the channel
        filtermap_pad = torch.cat((self.filtermap,self.filtermap),0)
        # not for 1x1 conv, do the padding on the spatial 
        if self.filtermap.size()[1] > 1 and self.filtermap.size()[2] > 1: 
           filtermap_pad_s1 = filtermap_pad[:,1,:]
           filtermap_pad_s1 = filtermap_pad_s1[:,None,:]
           filtermap_pad = torch.cat((filtermap_pad,filtermap_pad_s1),1)
           filtermap_pad_s2 = filtermap_pad[:,:,1]
           filtermap_pad_s2 = filtermap_pad_s2[:,:,None]
           filtermap_pad = torch.cat((filtermap_pad,filtermap_pad_s2),2)

        #pdb.set_trace()
        ids = self.ids.detach()
        conv_weight = filtermap_pad.view(-1,1).index_select(0,ids)
        conv_weight = conv_weight.view(out_channels,in_channels,k_h,k_w)
        if self.binary_filtermap:
           binary_conv_weight = conv_weight.clone()
           for nf in range(0,out_channels):
               float_filter = conv_weight[nf,:,:,:];
               L1_norm = torch.norm(float_filter.view(-1,1),1);
               sign_filter = torch.sign(float_filter);
               binary_filter = sign_filter*L1_norm;
               binary_conv_weight[nf,:,:,:] = binary_filter
           return binary_conv_weight
        else:
           return conv_weight
        #pdb.set_trace()
        #for c in range(0,sample_c):
        #   for y in range(0,sample_y):
        #      for x in range(0,sample_x):
        #          filter_count = c*sample_y*sample_x + y*sample_x + x
        #          conv_weight_clone[filter_count,:,:,:] = filtermap_pad[c*stride_c:c*stride_c+in_channels, \
        #                                      y*stride_y:y*stride_y+k_h, x*stride_x:x*stride_x+k_w]
        #return conv_weight
    def forward(self, input):
        #return F.conv2d(input, self.weight, self.bias, self.stride,
        #                self.padding, self.dilation, self.groups)
        
        #conv_weight = torch.mm(self.compressed_weight, torch.tanh(self.transform_mat))
        #conv_weight = torch.mm(self.compressed_weight, (self.transform_mat))
        
        #compressed_weight = torch.mm(self.input_weight, torch.tanh(self.transform_mat))
        #conv_weight = torch.mm(compressed_weight, torch.tanh(self.transform_back_mat))
       
        #conv_weight = conv_weight.view(self.in_channels // self.groups, self.kernel_size[0], \
        #        self.kernel_size[1], self.out_channels);
        #conv_weight = conv_weight.permute(3, 0, 1, 2)
        #conv_weight = conv_weight.contiguous()
        
        #fit_loss = torch.norm(conv_weight-self.ref_conv_weight,2)

        #pdb.set_trace()
        #conv_weight = Variable(torch.Tensor(self.out_channels, self.in_channels // self.groups, \
        #              self.kernel_size[0], self.kernel_size[1]))
        #conv_weight.cuda()
        conv_weight = self.extract_filters()
        #conv_weight[0,:,:,:] = self.filtermap[0:self.in_channels, \
        #                                      0:0+self.kernel_size[0], 0:0+self.kernel_size[1]]
        out = F.conv2d(input, conv_weight, self.bias, self.stride,
                         self.padding, self.dilation, self.groups)
   
        return out        
        #return F.conv2d(input, conv_weight, self.bias, self.stride,
        #                 self.padding, self.dilation, self.groups)
    def fit_filtermap1(self, conv):
        conv_weight = conv.weight.data
        conv_weight_size = conv_weight.size()
        out_channels = conv_weight_size[0]
        in_channels = conv_weight_size[1]
        k_h = conv_weight_size[2]
        k_w = conv_weight_size[3]

        sample_y = self.sample_y
        sample_x = self.sample_x
        sample_c = self.sample_c

        stride_y = self.stride_y
        stride_x = self.stride_x
        stride_c = self.stride_c

        fm_ext = torch.Tensor(sample_c*in_channels,sample_y*k_h,sample_x*k_w)
        for c in range(0,sample_c):
            for y in range(0,sample_y):
                for x in range(0,sample_x):
                    filter_count = c*sample_y*sample_x + y*sample_x + x
                    fm_ext[c*in_channels:(c+1)*in_channels,y*k_h:(y+1)*k_h,x*k_w:(x+1)*k_w] = conv_weight[filter_count,:,:,:]


        #pdb.set_trace()
        for oc in range(0,sample_c):
            for c in range(1,sample_c):
                idx = sample_c-c+oc
                if idx > sample_c-1:
                   idx = 0
                fm_ext[oc*stride_c:(oc+1)*stride_c,:,:] += \
                   fm_ext[idx*stride_c:(idx+1)*stride_c,:,:]

        fm_ext = fm_ext[0:in_channels,:,:]/sample_c

        fm_ext_h = fm_ext.size()[1]
        fm_ext_w = fm_ext.size()[2]

        for y in range(0,sample_y):
            fm_ext[:,y*k_h,:] += fm_ext[:,(y*k_h-1+fm_ext_h)%fm_ext_h,:]
            fm_ext[:,y*k_h,:] = fm_ext[:,y*k_h,:]/2

        for x in range(0,sample_x):
            fm_ext[:,:,x*k_w] += fm_ext[:,:,(x*k_w-1+fm_ext_w)%fm_ext_w] 
            fm_ext[:,:,x*k_w] = fm_ext[:,:,x*k_w]/2

        y_ids = torch.Tensor([0,1])
        y_ids0 = y_ids
        for y in range(1,sample_y):
            y_ids = torch.cat((y_ids,y_ids0+y*k_h),0)
    
        x_ids = torch.Tensor([0,1])
        x_ids0 = x_ids
        for x in range(1,sample_x):
            x_ids = torch.cat((x_ids,x_ids0+x*k_w),0)

        fm = torch.index_select(fm_ext,1,y_ids.long())
        fm = torch.index_select(fm,2,x_ids.long())
    
    
        self.filtermap.data = fm
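
# Minimal sketch (not part of the original example): the filter extraction in
# Conv2d_filtermap_1x1_compression above boils down to gathering overlapping
# windows out of one flattened "filtermap" tensor with index_select, so several
# filters share the same underlying values. The toy sizes and index offsets
# below are illustrative assumptions.
def _demo_filtermap_gather():
    filtermap = torch.arange(2 * 4 * 4, dtype=torch.float32).view(2, 4, 4)
    flat = filtermap.view(-1, 1)
    # two "filters" of shape (2, 2, 2), read from overlapping spatial offsets
    base = torch.tensor([0, 1, 4, 5, 16, 17, 20, 21])        # window at offset (0, 0)
    ids = torch.cat((base, base + 1))                        # window shifted by one column
    weights = flat.index_select(0, ids).view(2, 2, 2, 2)
    # both filters reuse entries of the same filtermap
    assert weights[0, 0, 0, 1] == weights[1, 0, 0, 0]
    return weights
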
Beispiel #18
0
class ZMUSWrapper(nn.Module):
    """Zero-mean Unit-STD States
        see: https://stats.stackexchange.com/questions/43159/how-to-calculate-pooled-variance-of-two-groups-given-known-group-variances-mean
    """
    def __init__(self, policy_net, eps=1e-6):
        super(ZMUSWrapper, self).__init__()

        self.eps = eps
        self.policy_net = policy_net
        self.policy_cfg = self.policy_net.policy_cfg

        # Parameters
        state_size = self.policy_net.agent.observation_space.shape
        self.state_mean = Parameter(torch.Tensor(*state_size))
        self.state_variance = Parameter(torch.Tensor(*state_size))
        self.state_mean.requires_grad = False
        self.state_variance.requires_grad = False

        # cache
        self.size = 0
        self.ep_states_data = []

        self.first_pass = True

    def _get_state_mean(self):
        return self.state_mean.detach()

    def _get_state_variance(self):
        return self.state_variance.detach()

    def forward(self, s):
        self.size += 1
        self.ep_states_data.append(s)
        if not self.first_pass:
            s = (s - self._get_state_mean()) / \
                (torch.sqrt(self._get_state_variance()) + self.eps)
        return self.policy_net(s)

    def episode_callback(self):
        ep_states_tensor = torch.stack(self.ep_states_data)
        new_data_mean = torch.mean(ep_states_tensor, dim=0)
        new_data_var = torch.var(ep_states_tensor, dim=0)
        if self.first_pass:
            self.state_mean.data = new_data_mean
            self.state_variance.data = new_data_var
            self.first_pass = False
        else:
            n = len(self.ep_states_data)
            mean = self._get_state_mean()
            var = self._get_state_variance()
            new_data_mean_sq = torch.mul(new_data_mean, new_data_mean)
            size = min(self.policy_cfg.FORGET_COUNT_OBS_SCALER, self.size)
            new_mean = ((mean * size) + (new_data_mean * n)) / (size + n)
            new_var = (((size * (var + torch.mul(mean, mean))) +
                        (n * (new_data_var + new_data_mean_sq))) / (size + n) -
                       torch.mul(new_mean, new_mean))
            self.state_mean.data = new_mean
            self.state_variance.data = torch.clamp(
                new_var, 0.)  # occasionally goes negative, clip
            self.size += n

    def batch_callback(self):
        pass
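
# Minimal numerical sketch (not part of the original example): verifies the
# pooled mean/variance update used in ZMUSWrapper.episode_callback above.
# Combining two groups via
#   new_mean = (m1*n1 + m2*n2) / (n1 + n2)
#   new_var  = (n1*(v1 + m1^2) + n2*(v2 + m2^2)) / (n1 + n2) - new_mean^2
# reproduces the statistics of the concatenated data (using population variance).
def _check_pooled_statistics():
    a = torch.randn(50, 4)
    b = torch.randn(30, 4) + 1.0
    n1, n2 = a.shape[0], b.shape[0]
    m1, v1 = a.mean(dim=0), a.var(dim=0, unbiased=False)
    m2, v2 = b.mean(dim=0), b.var(dim=0, unbiased=False)

    pooled_mean = (m1 * n1 + m2 * n2) / (n1 + n2)
    pooled_var = ((n1 * (v1 + m1 * m1) + n2 * (v2 + m2 * m2)) / (n1 + n2)
                  - pooled_mean * pooled_mean)

    full = torch.cat([a, b], dim=0)
    assert torch.allclose(pooled_mean, full.mean(dim=0), atol=1e-5)
    assert torch.allclose(pooled_var, full.var(dim=0, unbiased=False), atol=1e-5)
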
Beispiel #19
0
class PhaseShifter(Module):
    r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`
    This module supports :ref:`TensorFloat32<tf32_on_ampere>`.
    Args:
        in_features: size of each input sample
        out_features: size of each output sample
    Shape:
        - Input: :math:`(N, *, H_{in})` where :math:`*` means any number of
          additional dimensions and :math:`H_{in} = \text{in\_features}`
        - Output: :math:`(N, *, H_{out})` where all but the last dimension
          are the same shape as the input and :math:`H_{out} = \text{out\_features}`.
    Attributes:
        theta: the learnable weights of the module of shape
            :math:`(\text{out\_features}, \text{in\_features})`. The values are
            initialized from uniform(0,2*pi)
    Examples::
        >>> m = nn.Linear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    scale: float
    theta: Tensor

    def __init__(self, in_features: int, out_features: int, scale: float=1, theta = None) -> None:
        super(PhaseShifter, self).__init__()
        self.in_features = in_features
        self.in_dim = self.in_features//2
        self.out_features = out_features
        self.scale = scale
        # self.theta = Parameter(torch.Tensor(self.out_features, self.in_dim))
        self.theta = Parameter(torch.Tensor(self.in_dim, self.out_features)) 
        self.reset_parameters(theta)

    def reset_parameters(self, theta = None) -> None:
        if theta is None:
            init.uniform_(self.theta, a=0, b=2*np.pi)
        else:
            assert theta.shape == (self.in_dim,self.out_features)
            self.theta = Parameter(theta) 
        self.real_kernel = (1 / self.scale) * torch.cos(self.theta)  #
        self.imag_kernel = (1 / self.scale) * torch.sin(self.theta)  #
    
    def forward(self, inputs: Tensor) -> Tensor:
        self.real_kernel = (1 / self.scale) * torch.cos(self.theta)  #
        self.imag_kernel = (1 / self.scale) * torch.sin(self.theta)  #        
        cat_kernels_4_real = torch.cat(
            (self.real_kernel, -self.imag_kernel),
            dim=-1
        )
        cat_kernels_4_imag = torch.cat(
            (self.imag_kernel, self.real_kernel),
            dim=-1
        )
        cat_kernels_4_complex = torch.cat(
            (cat_kernels_4_real, cat_kernels_4_imag),
            dim=0
        )  # This block matrix represents the conjugate transpose of the original:
        # [ W_R, -W_I; W_I, W_R]

        # output = F.linear(inputs, cat_kernels_4_complex)
        output = torch.matmul(inputs, cat_kernels_4_complex)
        return output

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}'.format(
            self.in_features, self.out_features
        )
    
    def get_theta(self) -> torch.Tensor:
        return self.theta.detach().clone()
    
    def get_weights(self) -> torch.Tensor:
        with torch.no_grad():
            real_kernel = (1 / self.scale) * torch.cos(self.theta)  #
            imag_kernel = (1 / self.scale) * torch.sin(self.theta)  #        
            # cat_kernels_4_real = torch.cat(
            #     (real_kernel, -imag_kernel),
            #     dim=-1
            # )
            # cat_kernels_4_imag = torch.cat(
            #     (imag_kernel, real_kernel),
            #     dim=-1
            # )
            # cat_kernels_4_complex = torch.cat(
            #     (cat_kernels_4_real, cat_kernels_4_imag),
            #     dim=0
            # )  # This block matrix represents the conjugate transpose of the original:
            # # [ W_R, -W_I; W_I, W_R]
            beam_weights = real_kernel + 1j*imag_kernel
        return beam_weights
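
# Minimal numerical sketch (not part of the original example): checks that the
# real block matrix built in PhaseShifter.forward above implements a complex
# matrix product. With the input laid out as [real_part | imag_part] and the
# kernel W = (cos(theta) + j*sin(theta)) / scale, the layer's output equals
# [Re(z @ conj(W)) | Im(z @ conj(W))]. Sizes below are illustrative assumptions.
def _check_phase_shifter_equivalence():
    m = PhaseShifter(in_features=8, out_features=5, scale=2.0)
    x = torch.randn(3, 8)                          # [real (4) | imag (4)]
    out = m(x)                                     # (3, 2 * out_features)

    z = torch.complex(x[:, :4], x[:, 4:])          # complex view of the input
    w_conj = torch.complex(m.real_kernel, -m.imag_kernel)
    zw = z @ w_conj                                # complex matmul, shape (3, 5)
    expected = torch.cat((zw.real, zw.imag), dim=-1)
    assert torch.allclose(out, expected, atol=1e-5)
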
Beispiel #20
0
class Quantization(nn.Module):
    def __init__(self, network_controller: NetworkQuantizationController, is_signed: bool,
                 alpha: float = 0.9, weights_values=None, efficient=True):
        """
        HMQ Block
        :param network_controller: The network controller
        :param is_signed: is this tensor signed
        :param alpha: the IIR smoothing coefficient for the threshold statistics
        :param weights_values: in the case of weight quantization, the values of the weight tensor
        :param efficient: Boolean flag stating whether to use the memory-efficient quantization implementation
        """
        super(Quantization, self).__init__()
        self.weights_values = weights_values
        if weights_values is None:
            self.tensor_type = TensorType.ACTIVATION
            self.tensor_size = None
        else:
            self.tensor_type = TensorType.COEFFICIENT
            self.tensor_size = np.prod(weights_values.shape)

        self.network_controller = network_controller
        self.alpha = alpha
        self.is_signed_tensor = torch.Tensor([float(is_signed)]).cuda()

        if efficient:
            self.base_q = EfficientBaseQuantization()
        else:
            self.base_q = BaseQuantization()
        self.gumbel_softmax = GumbelSoftmax(ste=network_controller.ste)

        self.bits_vector = None
        self.mv_shifts = None
        self.base_thresholds = None
        self.nb_shifts_points_div = None
        self.search_matrix = None

    def init_quantization_coefficients(self):
        """
        This function initializes the HMQ parameters
        :return: None
        """
        init_threshold = 0
        n_bits_list, thresholds_shifts = self.network_controller.quantization_config.get_thresholds_bitwidth_lists(self)
        if self.is_coefficient():
            init_threshold = torch.pow(2.0, self.weights_values.abs().max().log2().ceil() + 1).item()
        if self.is_activation():
            n_bits_list = [8]

        self._init_quantization_params(n_bits_list, thresholds_shifts, init_threshold)
        self._init_search_matrix(self.network_controller.p, n_bits_list, len(thresholds_shifts))

    def _init_quantization_params(self, bit_list, thresholds_shifts, init_thresholds):
        self.update_bits_list(bit_list)
        self.mv_shifts = Parameter(torch.Tensor(thresholds_shifts), requires_grad=False)
        self.thresholds_shifts_points_div = Parameter(torch.pow(2.0, self.mv_shifts), requires_grad=False)
        self.base_thresholds = Parameter(torch.Tensor(1), requires_grad=False)
        init.constant_(self.base_thresholds, init_thresholds)

    def _init_search_matrix(self, p, n_bits_list, n_thresholds_options):
        n_channels = 1
        sm = -np.random.rand(n_channels, len(n_bits_list), n_thresholds_options, 1)
        n = np.prod(sm.shape)
        sm[:, 0, 0, 0] = np.log(p * n / (1 - p))  # for single channels
        self.search_matrix = Parameter(torch.Tensor(sm))

    def _get_quantization_probability_matrix(self, batch_size=1, noise_disable=False):
        return self.gumbel_softmax(self.search_matrix, self.network_controller.temperature, batch_size=batch_size,
                                   noise_disable=noise_disable)

    def _get_bits_probability(self, batch_size=1, noise_disable=False):
        p = self._get_quantization_probability_matrix(batch_size=batch_size, noise_disable=noise_disable)
        return p.sum(dim=4).sum(dim=3).sum(dim=1)

    def _update_iir(self, x):  # update scale using statistics
        if self.is_activation():
            if self.tensor_size is None:
                self.tensor_size = np.prod(x.shape[1:])  # Remove batch axis
            max_value = x.abs().max()
            self.base_thresholds.data.add_(self.alpha * (max_value - self.base_thresholds))

    def _calculate_expected_delta(self, p, max_scale):
        max_scales = max_scale / (self.thresholds_shifts_points_div.reshape(1, -1))
        max_scales = max_scales.reshape(1, 1, 1, -1, 1)

        nb_shifts = self.nb_shifts_points_div.reshape(1, 1, -1, 1, 1) * torch.pow(2.0, -self.is_signed_tensor)
        delta = (max_scales / nb_shifts) * p
        return delta.sum(dim=-1).sum(dim=-1).sum(dim=-1).sum(dim=-1)

    def _calculate_expected_threshold(self, p, max_threshold):
        p_t = p.sum(dim=4).sum(dim=2).sum(dim=1)
        thresholds = max_threshold / (self.thresholds_shifts_points_div.reshape(1, -1))
        return (p_t * thresholds).sum(dim=-1)

    def _calculate_expected_q_point(self, p, max_threshold, expected_delta, param_shape):
        t = self._calculate_expected_threshold(p, max_threshold=max_threshold).reshape(*param_shape)
        return t / expected_delta

    def _built_param_shape(self, x):
        random_size = x.shape[0] if self.is_activation() else x.shape[1]  # select random
        if len(x.shape) == 4:
            param_shape = [random_size, -1, 1, 1] if self.is_activation() else [-1, random_size, 1, 1]
        elif len(x.shape) == 2:
            param_shape = [random_size, -1] if self.is_activation() else [-1, random_size]
        else:
            raise NotImplementedError
        return random_size, param_shape

    def forward(self, x):
        """
        The forward function of the HMQ module

        :param x: Input tensor x
        :return: A tensor after quantization
        """
        if self.network_controller.statistics_update:
            self._update_iir(x)
        max_threshold = torch.pow(2.0,
                                  torch.ceil(torch.log2(self.base_thresholds.detach().abs()))).detach()  # read scale
        if self.training and self.network_controller.temperature > 0:
            random_size, param_shape = self._built_param_shape(x)
            # axis according to tensor type (activation randomization is done over the batch axis,
            # coeff the randomization is done over the input channel axis)
            p = self._get_quantization_probability_matrix(batch_size=random_size)
            delta = self._calculate_expected_delta(p, max_threshold).reshape(*param_shape)
            q_points = self._calculate_expected_q_point(p, max_threshold, delta,
                                                        param_shape).reshape(*param_shape)
            return self.base_q(x, delta, q_points, self.is_signed_tensor)
        else:  # negative temperature / inference
            p = self._get_quantization_probability_matrix(batch_size=1, noise_disable=True).squeeze(dim=0)
            bits_index = torch.argmax(self._get_bits_probability(batch_size=1, noise_disable=True).squeeze(dim=0))
            max_index = torch.argmax(p[:, bits_index, :, 0], dim=-1)
            q_points = self.nb_shifts_points_div[bits_index] * torch.pow(2.0,
                                                                         -self.is_signed_tensor)
            max_scales = (max_threshold / self.thresholds_shifts_points_div.reshape(1, -1)).detach()
            delta = torch.stack(
                [(max_scales[i, mv] / q_points) for i, mv in enumerate(max_index)]).flatten().detach()
            return self.base_q(x, delta, q_points, self.is_signed_tensor)

    def get_bit_width(self):
        """
        This function returns the selected bit-width
        :return: the bit-width of the HMQ
        """
        return self.bits_vector[torch.argmax(self._get_bits_probability(noise_disable=True).flatten())].item()

    def get_expected_bits(self):
        """
        This function returns the expected bit-width
        :return: the expected bit-width of the HMQ
        """
        return (self.bits_vector * self._get_bits_probability(noise_disable=True)).sum()

    def get_float_size(self):
        """
        This function returns the size of the floating-point tensor in bits
        Note: we assume 32 bits for floating point values
        :return: the floating point tensor size
        """
        return 32 * self.tensor_size

    def get_fxp_size(self):
        """
        This function returns the size of the quantized tensor in bits
        :return: the quantized tensor size
        """
        return self.get_bit_width() * self.tensor_size

    def is_activation(self):
        """
        This function returns a boolean stating whether this module quantizes activations
        :return: a boolean flag stating whether this is activation quantization
        """
        return self.tensor_type == TensorType.ACTIVATION

    def is_coefficient(self):
        """
        This function returns a boolean stating whether this module quantizes coefficients
        :return: a boolean flag stating whether this is coefficient quantization
        """
        return self.tensor_type == TensorType.COEFFICIENT

    def get_expected_tensor_size(self):
        """
         This function returns the expected size of the quantized tensor in bits
         :return: the expected size of quantized tensor
         """
        return torch.Tensor([self.tensor_size]).cuda()

    def update_bits_list(self, bits_list):
        """
        This function updates the HMQ bit-width list
        :param bits_list: A list of new bit-widths
        :return: None
        """
        if self.bits_vector is None:
            self.bits_vector = Parameter(torch.Tensor(bits_list), requires_grad=False)
            self.nb_shifts_points_div = Parameter(
                torch.pow(2.0, self.bits_vector),  # - int(q_node.is_signed)
                requires_grad=False)  # move to init
        else:
            self.bits_vector.add_(torch.Tensor(bits_list).cuda() - self.bits_vector)
            self.nb_shifts_points_div.add_(torch.pow(2.0, self.bits_vector) - self.nb_shifts_points_div)
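
# Illustrative stand-in (not part of the original example): the HMQ block above
# delegates the actual rounding to BaseQuantization/EfficientBaseQuantization,
# which are not shown in this snippet. The helper below sketches what a plain
# uniform quantizer with step size `delta` and `n_levels` levels could look
# like; it is an assumption, not the repository's implementation.
def _uniform_quantize(x, delta, n_levels, signed=True):
    """Round x to multiples of delta and clip to the representable range."""
    if signed:
        q = torch.clamp(torch.round(x / delta), -(n_levels // 2), n_levels // 2 - 1)
    else:
        q = torch.clamp(torch.round(x / delta), 0, n_levels - 1)
    return q * delta
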
class Linear(nn.Module):
    """Linear layer
    Parameters:
    -----------
    input_size: int
        input size of transformation
    output_size: int
        output size of transformation
    bias: boolean, default=True
        whether to use bias or not
    device: str, default="cuda:0"
        keep on this device
    """
    def __init__(self, input_size, output_size, bias=True, device="cuda:0"):
        super(Linear, self).__init__()
        self.device = device  # Useful in case of multiple GPUs
        self.input_size = input_size
        self.output_size = output_size
        self.weight = Parameter(torch.Tensor(self.output_size,
                                             self.input_size))
        if bias:
            self.bias = Parameter(torch.Tensor(self.output_size, 1))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def forward(self, input):
        if self.bias is not None:
            return F.linear(input.to(self.device), self.weight,
                            self.bias.view(-1))
        else:
            return F.linear(input.to(self.device), self.weight)

    def to(self):
        """Transfer to device
        """
        super().to(self.device)

    def reset_parameters(self):
        """Initialize vectors
        """
        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(
                self.weight)
            bound = 1 / math.sqrt(fan_in)
            torch.nn.init.uniform_(self.bias, -bound, bound)
        # stdv = 1. / math.sqrt(self.weight.size(1))
        # self.weight.data.uniform_(-stdv, stdv)
        # if self.bias is not None:
        #     self.bias.data.uniform_(-stdv, stdv)

    def get_weights(self):
        """Get weights as numpy array
        Bias is appended in the end
        """
        _wts = self.weight.detach().cpu().numpy()
        if self.bias is not None:
            _bias = self.bias.detach().cpu().numpy()
            _wts = np.hstack([_wts, _bias])
        return _wts

    def __repr__(self):
        s = '{name}({input_size}, {output_size}, {device}'
        if self.bias is not None:
            s += ', bias=True'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)

    @property
    def sparse(self):
        return False
Beispiel #22
0
class Conv2dDPQ(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, 
                 qmin=1e-3, qmax=100, dmin=1e-5, dmax=10, bias=True, sign=True, wbits=4, abits=4, mode=Qmodes.layer_wise):
    
        """
        :param d_init: the initial quantization step size (alpha)
        :param mode: Qmodes.layer_wise or Qmodes.kernel_wise
        :param xmax_init: the initial quantization range for the whole weight tensor
        """

        super(Conv2dDPQ, self).__init__(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                        stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        
        self.qmin = qmin
        self.qmax = qmax
        self.dmin = dmin 
        self.dmax = dmax
        self.q_mode = mode
        self.sign = sign
        self.nbits = wbits 
        self.act_dpq = ActDPQ(signed=True, nbits=abits)
        if self.q_mode == Qmodes.kernel_wise:
            self.alpha = Parameter(torch.Tensor(out_channels))
        else:
            self.alpha = Parameter(torch.Tensor(1))
        self.xmax = Parameter(torch.Tensor(1))
        self.weight.requires_grad_(True)
        if bias:
            self.bias.requires_grad_(True)
        self.register_buffer('init_state', torch.zeros(1))

    def get_nbits(self):
        abits = self.act_dpq.get_nbits()
        # print('the xmax is : {} and the alpha is : {} '.format(self.xmax.data, self.alpha.data))
        self.xmax.data.copy_(self.xmax.clamp(self.qmin, self.qmax))
        self.alpha.data.copy_(self.alpha.clamp(self.dmin, self.dmax))
        # print('after clamp, the result is :')
        # print('the xmax is : {} and the alpha is : {} '.format(self.xmax.data, self.alpha.data))
        if self.sign:
            nbits = (torch.log(self.xmax/self.alpha + 1) / math.log(2) + 1).ceil()
        else:
            nbits = (torch.log(self.xmax/self.alpha + 1) / math.log(2)).ceil()

        # print('the nbits for weight is  : {}'.format(nbits))
        self.nbits = int(nbits.item())
        return abits, nbits

    def get_quan_filters(self, filters):

        if self.training and self.init_state == 0:
            # print('initial alphas in current time !')
            Qp = 2 ** (self.nbits - 1) - 1
            self.alpha.data.copy_(2 * filters.abs().mean() / math.sqrt(Qp))
            self.xmax.data.copy_(self.alpha * Qp)
            # print('the xmax is : {} and the alpha is : {} '.format(self.xmax.data, self.alpha.data))
            # self.xmax.data.copy_(self.weight.abs().max())
            # wmean = self.weight.abs().mean()
            # self.alpha.data.copy_(4 * wmean * wmean / self.xmax)
            self.init_state.fill_(1)

        self.xmax.data.copy_(self.xmax.clamp(self.qmin, self.qmax))
        self.alpha.data.copy_(self.alpha.clamp(self.dmin, self.dmax))
        Qp = (self.xmax.detach()/self.alpha.detach()).item()
        g = 1.0 / math.sqrt(filters.numel() * Qp)
        alpha = grad_scale(self.alpha, g)
        xmax = grad_scale(self.xmax, g)
        # alpha = quantize_pow2(self.alpha)
        # xmax = self.xmax
        # alpha = self.alpha
        # g = 1.0 / math.sqrt(self.weight.numel() * (xmax.detach()/d.detach()).item())
        # d = grad_scale(self.alpha, g)
        # xmax = torch.clamp(self.xmax, self.xmax_min, self.xmax_max)

        if self.sign:
            wq = round_pass((torch.clamp(filters/xmax, -1, 1) * xmax)/alpha) * alpha
        else:
            wq = round_pass((torch.clamp(filters/xmax, 0, 1) * xmax)/alpha) * alpha 

        return wq

    def forward(self,x):
        if self.training and self.init_state == 0:
            # print('initial alphas in current time !')
            Qp = 2 ** (self.nbits - 1) - 1
            self.alpha.data.copy_(2 * self.weight.abs().mean() / math.sqrt(Qp))
            self.xmax.data.copy_(self.alpha * Qp)
            # print('the xmax is : {} and the alpha is : {} '.format(self.xmax.data, self.alpha.data))
            # self.xmax.data.copy_(self.weight.abs().max())
            # wmean = self.weight.abs().mean()
            # self.alpha.data.copy_(4 * wmean * wmean / self.xmax)
            self.init_state.fill_(1)

        # alpha = quantize_pow2(self.alpha)
        # alpha = self.alpha
        # g = 1.0 / math.sqrt(self.weight.numel() * Qp)
        # xmax = self.xmax
        self.xmax.data.copy_(self.xmax.clamp(self.qmin, self.qmax))
        self.alpha.data.copy_(self.alpha.clamp(self.dmin, self.dmax))
        Qp = (self.xmax.detach()/self.alpha.detach()).item()
        g = 1.0 / math.sqrt(self.weight.numel() * Qp)
        alpha = grad_scale(self.alpha, g)
        xmax = grad_scale(self.xmax, g)

        if self.sign:
            wq = round_pass((torch.clamp(self.weight/xmax, -1, 1) * xmax)/alpha) * alpha
        else:
            wq = round_pass((torch.clamp(self.weight/xmax, 0, 1) * xmax)/alpha) * alpha 
        
        if self.act_dpq is not None:
            x = self.act_dpq(x)
        
        # print('the abits is : {} and the nbits is : {}'.format(self.act_dpq.nbits, self.nbits))
        return F.conv2d(x, wq, self.bias, self.stride, self.padding, self.dilation, self.groups)
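
# Illustrative stand-in (not part of the original example): Conv2dDPQ above
# relies on `round_pass` and `grad_scale`, which are not defined in this
# snippet. The versions below are the usual LSQ-style straight-through helpers
# and are an assumption about the repository's implementation: round_pass
# rounds in the forward pass but lets gradients flow through unchanged, and
# grad_scale keeps the forward value while scaling its gradient by `scale`.
def round_pass(x):
    return (x.round() - x).detach() + x


def grad_scale(x, scale):
    return (x - x * scale).detach() + x * scale
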
Beispiel #23
0
class ABAE(torch.nn.Module):
    """
        The model described in the paper ``An Unsupervised Neural Attention Model for Aspect Extraction''
        by He, Ruidan and  Lee, Wee Sun  and  Ng, Hwee Tou  and  Dahlmeier, Daniel, ACL2017
        https://aclweb.org/anthology/papers/P/P17/P17-1036/

    """
    def __init__(self,
                 wv_dim: int = 200,
                 asp_count: int = 30,
                 ortho_reg: float = 0.1,
                 maxlen: int = 201,
                 init_aspects_matrix=None):
        """
        Initializing the model

        :param wv_dim: word vector size
        :param asp_count: number of aspects
        :param ortho_reg: coefficient for tuning the ortho-regularizer's influence
        :param maxlen: sentence max length taken into account
        :param init_aspects_matrix: None or init. matrix for aspects
        """
        super(ABAE, self).__init__()
        self.wv_dim = wv_dim
        self.asp_count = asp_count
        self.ortho = ortho_reg
        self.maxlen = maxlen

        self.attention = SelfAttention(wv_dim, maxlen)
        self.linear_transform = torch.nn.Linear(self.wv_dim, self.asp_count)
        self.softmax_aspects = torch.nn.Softmax()
        self.aspects_embeddings = Parameter(
            torch.empty(size=(wv_dim, asp_count)))

        if init_aspects_matrix is None:
            torch.nn.init.xavier_uniform_(self.aspects_embeddings)
        else:
            self.aspects_embeddings.data = torch.from_numpy(
                init_aspects_matrix.T)

    def get_aspects_importances(self, text_embeddings):
        """
            Takes embeddings of a sentence as input, returns attention weights
        """

        # compute attention scores, looking at text embeddings average
        attention_weights = self.attention(text_embeddings)

        # multiplying text embeddings by attention scores -- and summing
        # (matmul: we sum every word embedding's coordinate with attention weights)
        weighted_text_emb = torch.matmul(
            attention_weights.unsqueeze(1),  # (batch, 1, sentence)
            text_embeddings  # (batch, sentence, wv_dim)
        ).squeeze()

        # encoding with a simple feed-forward layer (wv_dim) -> (aspects_count)
        raw_importances = self.linear_transform(weighted_text_emb)

        # computing 'aspects distribution in a sentence'
        aspects_importances = self.softmax_aspects(raw_importances)

        return attention_weights, aspects_importances, weighted_text_emb

    def forward(self, text_embeddings, negative_samples_texts):

        # negative samples are averaged
        averaged_negative_samples = torch.mean(negative_samples_texts, dim=2)

        # encoding: words embeddings -> sentence embedding, aspects importances
        _, aspects_importances, weighted_text_emb = self.get_aspects_importances(
            text_embeddings)

        # decoding: aspects embeddings matrix, aspects_importances -> recovered sentence embedding
        recovered_emb = torch.matmul(
            self.aspects_embeddings,
            aspects_importances.unsqueeze(2)).squeeze()

        # loss
        reconstruction_triplet_loss = ABAE._reconstruction_loss(
            weighted_text_emb, recovered_emb, averaged_negative_samples)
        max_margin = torch.max(reconstruction_triplet_loss,
                               torch.zeros_like(reconstruction_triplet_loss))

        return self.ortho * self._ortho_regularizer() + max_margin

    @staticmethod
    def _reconstruction_loss(text_emb, recovered_emb, averaged_negative_emb):

        positive_dot_products = torch.matmul(
            text_emb.unsqueeze(1), recovered_emb.unsqueeze(2)).squeeze()
        negative_dot_products = torch.matmul(
            averaged_negative_emb, recovered_emb.unsqueeze(2)).squeeze()
        reconstruction_triplet_loss = torch.sum(
            1 - positive_dot_products.unsqueeze(1) + negative_dot_products,
            dim=1)

        return reconstruction_triplet_loss

    def _ortho_regularizer(self):
        return torch.norm(
            torch.matmul(self.aspects_embeddings.t(), self.aspects_embeddings) \
            - torch.eye(self.asp_count))

    def get_aspect_words(self, w2v_model, topn=15):
        words = []

        # getting aspects embeddings
        aspects = self.aspects_embeddings.detach().numpy()

        # getting scalar products of word embeddings and aspect embeddings;
        # to obtain the ``probabilities'', one should also apply softmax
        words_scores = w2v_model.wv.syn0.dot(aspects)

        for row in range(aspects.shape[1]):
            argmax_scalar_products = np.argsort(-words_scores[:, row])[:topn]
            # print([w2v_model.wv.index2word[i] for i in argmax_scalar_products])
            # print([w for w, dist in w2v_model.similar_by_vector(aspects.T[row])[:topn]])
            words.append(
                [w2v_model.wv.index2word[i] for i in argmax_scalar_products])

        return words
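# A minimal usage sketch (not part of the original example). It assumes the
# SelfAttention module used above is defined elsewhere in this file and returns
# per-token attention weights of shape (batch, sentence_len); the shapes below
# are purely illustrative.
import torch

abae = ABAE(wv_dim=200, asp_count=30, maxlen=201)
text = torch.randn(8, 201, 200)            # (batch, sentence_len, wv_dim)
negatives = torch.randn(8, 5, 201, 200)    # 5 negative samples per sentence
loss = abae(text, negatives).mean()        # max-margin loss + ortho regularizer
loss.backward()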
Example #24
class _LearnableFakeQuantize(torch.ao.quantization.FakeQuantizeBase):
    r""" This is an extension of the FakeQuantize module in fake_quantize.py, which
    supports more generalized lower-bit quantization and supports learning of the scale
    and zero point parameters through backpropagation. For literature references,
    please see the class _LearnableFakeQuantizePerTensorOp.

    In addition to the attributes in the original FakeQuantize module, the _LearnableFakeQuantize
    module also includes the following attributes to support quantization parameter learning.

    * :attr:`channel_len` defines the length of the channel when initializing scale and zero point
      for the per channel case.

    * :attr:`use_grad_scaling` defines the flag for whether the gradients for scale and zero point are
      normalized by the constant, which is proportional to the square root of the number of
      elements in the tensor. The related literature justifying the use of this particular constant
      can be found here: https://openreview.net/pdf?id=rkgO66VKDS.

    * :attr:`fake_quant_enabled` defines the flag for enabling fake quantization on the output.

    * :attr:`static_enabled` defines the flag for using observer's static estimation for
      scale and zero point.

    * :attr:`learning_enabled` defines the flag for enabling backpropagation for scale and zero point.
    """
    def __init__(self,
                 observer,
                 quant_min=0,
                 quant_max=255,
                 scale=1.,
                 zero_point=0.,
                 channel_len=-1,
                 use_grad_scaling=False,
                 **observer_kwargs):
        super(_LearnableFakeQuantize, self).__init__()
        assert quant_min < quant_max, 'quant_min must be strictly less than quant_max.'
        self.quant_min = quant_min
        self.quant_max = quant_max
        # also pass quant_min and quant_max to observer
        observer_kwargs["quant_min"] = quant_min
        observer_kwargs["quant_max"] = quant_max
        self.use_grad_scaling = use_grad_scaling
        if channel_len == -1:
            self.scale = Parameter(torch.tensor([scale]))
            self.zero_point = Parameter(torch.tensor([zero_point]))
        else:
            assert isinstance(
                channel_len, int
            ) and channel_len > 0, "Channel size must be a positive integer."
            self.scale = Parameter(torch.tensor([scale] * channel_len))
            self.zero_point = Parameter(
                torch.tensor([zero_point] * channel_len))

        self.activation_post_process = observer(**observer_kwargs)
        assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, \
            'quant_min out of bound'
        assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, \
            'quant_max out of bound'
        self.dtype = self.activation_post_process.dtype
        self.qscheme = self.activation_post_process.qscheme
        self.ch_axis = self.activation_post_process.ch_axis \
            if hasattr(self.activation_post_process, 'ch_axis') else -1
        self.register_buffer('fake_quant_enabled',
                             torch.tensor([1], dtype=torch.uint8))
        self.register_buffer('static_enabled',
                             torch.tensor([1], dtype=torch.uint8))
        self.register_buffer('learning_enabled',
                             torch.tensor([0], dtype=torch.uint8))

        bitrange = torch.tensor(quant_max - quant_min + 1).double()
        self.bitwidth = int(torch.log2(bitrange).item())
        self.register_buffer('eps',
                             torch.tensor([torch.finfo(torch.float32).eps]))

    @torch.jit.export
    def enable_param_learning(self):
        r"""Enables learning of quantization parameters and
        disables static observer estimates. Forward path returns fake quantized X.
        """
        self.toggle_qparam_learning(enabled=True) \
            .toggle_fake_quant(enabled=True) \
            .toggle_observer_update(enabled=False)
        return self

    @torch.jit.export
    def enable_static_estimate(self):
        r"""Enables static observer estimates and disables learning of
        quantization parameters. Forward path returns fake quantized X.
        """
        self.toggle_qparam_learning(enabled=False) \
            .toggle_fake_quant(enabled=True) \
            .toggle_observer_update(enabled=True)

    @torch.jit.export
    def enable_static_observation(self):
        r"""Enables static observer accumulating data from input but doesn't
        update the quantization parameters. Forward path returns the original X.
        """
        self.toggle_qparam_learning(enabled=False) \
            .toggle_fake_quant(enabled=False) \
            .toggle_observer_update(enabled=True)

    @torch.jit.export
    def toggle_observer_update(self, enabled=True):
        self.static_enabled[0] = int(enabled)  # type: ignore[operator]
        return self

    @torch.jit.export
    def enable_observer(self, enabled=True):
        self.toggle_observer_update(enabled)

    @torch.jit.export
    def toggle_qparam_learning(self, enabled=True):
        self.learning_enabled[0] = int(enabled)  # type: ignore[operator]
        self.scale.requires_grad = enabled
        self.zero_point.requires_grad = enabled
        return self

    @torch.jit.export
    def toggle_fake_quant(self, enabled=True):
        self.fake_quant_enabled[0] = int(enabled)
        return self

    @torch.jit.export
    def observe_quant_params(self):
        print('_LearnableFakeQuantize Scale: {}'.format(self.scale.detach()))
        print('_LearnableFakeQuantize Zero Point: {}'.format(
            self.zero_point.detach()))

    @torch.jit.export
    def calculate_qparams(self):
        self.scale.data.clamp_(min=self.eps.item())  # type: ignore[operator]
        scale = self.scale.detach()
        zero_point = self.zero_point.detach().round().clamp(
            self.quant_min, self.quant_max).long()
        return scale, zero_point

    def forward(self, X):
        if self.static_enabled[0] == 1:  # type: ignore[index]
            self.activation_post_process(X.detach())
            _scale, _zero_point = self.activation_post_process.calculate_qparams(
            )
            _scale = _scale.to(self.scale.device)
            _zero_point = _zero_point.to(self.zero_point.device)
            self.scale.data.copy_(_scale)
            self.zero_point.data.copy_(_zero_point)
        else:
            self.scale.data.clamp_(
                min=self.eps.item())  # type: ignore[operator]

        if self.fake_quant_enabled[0] == 1:
            if self.qscheme in (torch.per_channel_symmetric,
                                torch.per_tensor_symmetric):
                self.zero_point.data.zero_()

            if self.use_grad_scaling:
                grad_factor = 1.0 / (X.numel() * self.quant_max)**0.5
            else:
                grad_factor = 1.0
            if self.qscheme in (torch.per_channel_symmetric,
                                torch.per_channel_affine):
                X = torch._fake_quantize_learnable_per_channel_affine(
                    X, self.scale, self.zero_point, self.ch_axis,
                    self.quant_min, self.quant_max, grad_factor)
            else:
                X = torch._fake_quantize_learnable_per_tensor_affine(
                    X, self.scale, self.zero_point, self.quant_min,
                    self.quant_max, grad_factor)

        return X
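# A minimal usage sketch (assumption, not part of the original example): attach
# a per-tensor learnable fake-quantizer to an activation, calibrate it with the
# observer, then switch to gradient-based learning of scale and zero point.
import torch
from torch.ao.quantization.observer import MovingAverageMinMaxObserver

fq = _LearnableFakeQuantize(observer=MovingAverageMinMaxObserver,
                            quant_min=0, quant_max=255,
                            dtype=torch.quint8,
                            qscheme=torch.per_tensor_affine)
x = torch.randn(4, 16)
_ = fq(x)                    # static estimate: observer sets scale / zero point
fq.enable_param_learning()   # scale / zero point now learned via backprop
y = fq(x)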
Example #25
class CustomEmbedding(torch.nn.Module):
    """
        Memory efficient way to compute weighted EmbeddingBag
    """
    def __init__(self,
                 num_embeddings,
                 embedding_dim,
                 max_norm=None,
                 norm_type=2,
                 scale_grad_by_freq=False,
                 sparse=False,
                 device="cuda:0"):
        """
            Args:
                num_embeddings: int: vocabulary size
                embedding_dim: int: dimension for embeddings
                padding_idx: int: index for <PAD>; embedding is not updated
                max_norm: float: if given, embeddings with a larger norm are renormalized
                norm_type: int: default: 2
                scale_grad_by_freq: boolean: True/False
                sparse: boolean: sparse or dense gradients
        """
        super(CustomEmbedding, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.padding_idx = num_embeddings
        if self.padding_idx is not None:
            self.num_embeddings = num_embeddings + 1
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        self.weight = Parameter(
            torch.Tensor(self.num_embeddings, embedding_dim))
        self.sparse = sparse
        self.device = device
        self.reset_parameters()

    def reset_parameters(self):
        """
            Reset weights
        """
        self.weight.data.normal_(0, 1)
        if self.padding_idx is not None:
            self.weight.data[self.padding_idx].fill_(0)

    def to(self):
        super().to(self.device)

    def forward(self, batch_data):
        """
            Forward pass for embedding layer
            Arguments
                ----------
                batch_data: dict
                    {'X': torch.LongTensor (BxN),
                     'X_w': torch.FloatTensor (BxN)}
                    'X': Feature indices
                    'X_w': Feature weights
            Returns:
                ----------
                torch.Tensor
                    embedding for each sample B x embedding_dims
        """
        features = batch_data['X'].to(self.device)
        weights = batch_data['X_w'].to(self.device)
        out = F.embedding(features, self.weight, self.padding_idx,
                          self.max_norm, self.norm_type,
                          self.scale_grad_by_freq, self.sparse)
        out = weights.unsqueeze(1).bmm(out).squeeze()
        return out

    def get_weights(self):
        return self.weight.detach().cpu().numpy()

    def __repr__(self):
        s = '{name}({num_embeddings}, {embedding_dim}, {device}'
        if self.padding_idx is not None:
            s += ', padding_idx={padding_idx}'
        if self.max_norm is not None:
            s += ', max_norm={max_norm}'
        if self.norm_type != 2:
            s += ', norm_type={norm_type}'
        if self.scale_grad_by_freq is not False:
            s += ', scale_grad_by_freq={scale_grad_by_freq}'
        if self.sparse is not False:
            s += ', sparse=True'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)

    def _init_(self, state_dict):
        keys = list(state_dict.keys())
        key = [x for x in keys if x.split(".")[-1] in ['weight']][0]
        weight = state_dict[key]
        self.weight.data.copy_(weight)
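# A minimal usage sketch (assumption, not part of the original example):
# a weighted bag-of-features embedding computed on the CPU.
import torch

emb = CustomEmbedding(num_embeddings=100, embedding_dim=8, device="cpu")
batch = {'X': torch.randint(0, 100, (4, 10)),   # feature indices (B x N)
         'X_w': torch.rand(4, 10)}              # feature weights (B x N)
doc_emb = emb(batch)                            # (B x embedding_dim)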
Example #26
class CombineUV(nn.Module):
    """
        Combines all classifier components for DECAF
        Arguments
            ----------
            input_size: int
                Classifier's embedding dimension
            output_size: int
                Number of classifiers to train
            bias: bool (default=True)
                Flag to use bias
            use_classifier_wts: bool (default=False)
                If True learns a refinement vector
            padding_idx: int (default=None)
                Label padding index
            device: string (default="cuda:0")
                GPU id to keep the parameters
            sparse: bool (default=False)
                Type of optimizer to use for training
    """
    def __init__(self,
                 input_size,
                 output_size,
                 bias=True,
                 use_classifier_wts=False,
                 padding_idx=None,
                 device="cuda:0",
                 sparse=False):
        super(CombineUV, self).__init__()
        self._device = device  # Useful in case of multiple GPUs
        self.input_size = input_size
        self.output_size = output_size
        self.use_classifier_wts = use_classifier_wts
        self.sparse = sparse
        if self.sparse:
            self.device = "cpu"
        else:
            self.device = self._device
        self.padding_idx = padding_idx

        if self.use_classifier_wts:
            if bias:
                if self.sparse:
                    self.bias = Parameter(torch.Tensor(self.output_size, 1))
                else:
                    self.bias = Parameter(torch.Tensor(self.output_size))
            else:
                self.register_parameter('bias', None)
            self.weight = Parameter(
                torch.Tensor(self.output_size, self.input_size))
            self.alpha = Parameter(torch.Tensor(1, self.input_size))
            self.beta = Parameter(torch.Tensor(1, self.input_size))
        else:
            self.register_parameter('bias', None)
            self.register_parameter('weight', None)
        self.prebias = None
        self.prepredict = None
        self.reset_parameters()

    def _get_clf(self, weights, shortlist=None, sparse=True, padding_idx=None):
        if shortlist is not None and weights is not None:
            return F.embedding(shortlist,
                               weights,
                               sparse=sparse,
                               padding_idx=padding_idx,
                               max_norm=None,
                               norm_type=2.,
                               scale_grad_by_freq=False)
        return weights

    def _get_rebuild(self, lbl_clf, shortlist=None):
        if self.prepredict is None:
            bias = self._get_clf(self.bias, shortlist, self.sparse,
                                 self.padding_idx)
            lbl_clf = lbl_clf.to(self.device)
            _lbl_clf = self._get_clf(self.weight, shortlist, self.sparse,
                                     self.padding_idx)
            if self.use_classifier_wts:
                lbl_clf = self.alpha.sigmoid()*_lbl_clf.squeeze() + \
                    self.beta.sigmoid()*lbl_clf
            return lbl_clf, bias
        else:
            return self.prepredict, self.prebias

    def forward(self, input, labels, shortlist=None, shortlist_first=None):
        input = input.to(self.device)
        if labels is not None:
            labels = labels.to(self.device)
        if shortlist_first is not None:
            shortlist_first = shortlist_first.to(self.device)
        lbl_clf, bias = self._get_rebuild(labels, shortlist_first)
        if shortlist is not None:
            shortlist = shortlist.to(self.device)
            input = input.to(self.device)
            lbl_clf = self._get_clf(lbl_clf, shortlist, sparse=False)
            bias = self._get_clf(bias, shortlist, sparse=False)
            out = torch.matmul(input.unsqueeze(1), lbl_clf.permute(0, 2, 1))
            if bias is not None:
                out = out.squeeze() + bias.squeeze()
            return out.squeeze()
        else:
            return F.linear(input, lbl_clf, bias)

    def _setup(self, lbl_clf):
        self.device = "cpu"
        if self.use_classifier_wts:
            norm_u = np.mean(
                torch.norm(self.weight, dim=1).detach().cpu().numpy())
            print("Alpha=",
                  torch.mean(self.alpha.detach().sigmoid()).cpu().numpy(),
                  "Beta=",
                  torch.mean(self.beta.detach().sigmoid()).cpu().numpy())
            norm_v = np.mean(torch.norm(lbl_clf, dim=1).detach().cpu().numpy())
            lbl_clf, _bias = self._get_rebuild(lbl_clf)
            norm_w = np.mean(torch.norm(lbl_clf, dim=1).detach().cpu().numpy())
            print("||u||=%0.2f, ||v||=%0.2f, ||w||=%0.2f" %
                  (norm_u, norm_v, norm_w))
            self.prebias = _bias + 0
        self.prepredict = lbl_clf

    def _pred_to(self):
        self.device = self._device
        self.prepredict = self.prepredict.to(self.device)
        if self.use_classifier_wts:
            self.prebias = self.prebias.to(self.device)

    def _pred_cpu(self):
        self.device = "cpu"
        self.prepredict = self.prepredict.cpu()
        if self.use_classifier_wts:
            self.prebias = self.prebias.cpu()

    def _clean(self):
        if self.prepredict is None:
            print("No need to clean")
            return
        print("Cleaning for GPU")
        self.prepredict = self.prepredict.cpu()
        if self.use_classifier_wts:
            self.prebias = self.prebias.cpu()
        self.prebias, self.prepredict = None, None

    def to(self):
        """
            Transfer to device
        """
        self.device = self._device
        super().to(self._device)

    def reset_parameters(self):
        """
            Initialize vectors
        """
        stdv = 1. / math.sqrt(self.input_size)

        if self.weight is not None:
            self.weight.data.uniform_(-stdv, stdv)
            self.alpha.data.fill_(0)
            self.beta.data.fill_(0)

        if self.bias is not None:
            self.bias.data.fill_(0)

    def get_weights(self):
        """
            Get weights as dictionary
        """
        parameters = {}
        if self.bias is not None:
            parameters['bias'] = self.bias.detach().cpu().numpy()
        if self.use_classifier_wts:
            parameters['weight'] = self.weight.detach().cpu().numpy()
            parameters['alpha'] = self.alpha.detach().cpu().numpy()
            parameters['beta'] = self.beta.detach().cpu().numpy()
        return parameters

    def __repr__(self):
        s = "{name}({input_size}, {output_size}, {_device}, {use_classifier_wts}, sparse={sparse}"
        if self.use_classifier_wts:
            if self.bias is not None:
                s += ', bias=True'
            s += ", Alpha={}".format(self.alpha.detach().cpu().numpy().shape)
            s += ", Beta={}".format(self.beta.detach().cpu().numpy().shape)
            s += ", weight={}".format(self.weight.detach().cpu().numpy().shape)
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)

    def _init_(self, lbl_fts_mat, state_dict):
        if self.use_classifier_wts:
            print("INFO::Initializing the classifier")
            keys = list(state_dict.keys())
            key = [x for x in keys if x.split(".")[-1] in ['weight']][0]
            weight = state_dict[key]
            self.weight.data.copy_(lbl_fts_mat.mm(weight))
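# A minimal usage sketch (assumption, not part of the original example): score
# documents against label embeddings without the refinement vectors
# (use_classifier_wts=False), keeping everything on the CPU.
import torch

clf = CombineUV(input_size=64, output_size=500,
                use_classifier_wts=False, device="cpu")
clf.to()                                # move parameters to the configured device
doc_emb = torch.randn(32, 64)           # document embeddings
lbl_emb = torch.randn(500, 64)          # label (classifier) embeddings
scores = clf(doc_emb, lbl_emb)          # (32, 500) logits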
Example #27
class Model:
    '''Model manager class

   Convenience class wrapping saving/loading,
   model initialization, optimization, and logging.

   Args:
      ann: Model to optimize. Used to initialize weights.
      config: A Config specification
      args: Hook for additional user arguments
   '''
    def __init__(self, ann, config, args):
        self.saver = save.Saver(config.MODELDIR,
                                'models',
                                'bests',
                                resetTol=256)
        self.config, self.args = config, args

        self.init(ann)
        if self.config.LOAD or self.config.BEST:
            self.load(self.config.BEST)

    def init(self, ann):
        print('Initializing new model...')
        self.initModel(ann)

        self.opt = None
        if not self.config.TEST:
            self.opt = ManualAdam([self.params],
                                  lr=0.001,
                                  weight_decay=0.00001)

    #Initialize a new network
    def initModel(self, ann):
        self.models = ann(self.config).params()
        self.params = Parameter(torch.Tensor(np.array(self.models)))

    #Grads and clip
    def stepOpt(self, gradList):
        '''Clip the provided gradients and step the optimizer

      Args:
         gradList: a list of gradients
      '''
        grad = np.array(gradList)
        grad = np.mean(grad, 0)
        grad = np.clip(grad, -5, 5)

        gradAry = torch.Tensor(grad)
        self.opt.step(gradAry)

    def checkpoint(self, reward):
        '''Save the model to checkpoint

      Args:
         reward: Mean reward of the model
      '''
        if self.config.TEST:
            return
        self.saver.checkpoint(self.params, self.opt, reward)

    def load(self, best=False):
        '''Load a model from file

      Args:
         best (bool): Whether to load the best (True)
             or most recent (False) checkpoint
      '''
        print('Loading model...')
        epoch = self.saver.load(self.opt, self.params, best)

    @property
    def nParams(self):
        '''Print the number of model parameters'''
        nParams = len(self.model)
        print('#Params: ', str(nParams / 1000), 'K')

    @property
    def model(self):
        '''Get model parameters

      Returns:
         a numpy array of model parameters
      '''
        return self.params.detach().numpy()
Example #28
class WSConv2d(nn.Conv2d):
    def __init__(self,
                 in_planes,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 multiplier=1.0,
                 rep_dim=1,
                 repeat_weight=True,
                 use_coeff=False):
        if rep_dim == 0:
            # this repeats along the channel dim (dimension 0 of the weight tensor)
            super(WSConv2d,
                  self).__init__(int(in_planes),
                                 int(np.ceil(out_channels / multiplier)),
                                 kernel_size,
                                 stride=stride,
                                 padding=padding,
                                 dilation=dilation,
                                 groups=groups,
                                 bias=bias)
            self.rep_time = int(
                np.ceil(1. * out_channels / self.weight.shape[0]))
        elif rep_dim == 1:
            # this repeats along the filter dim (dimension 1 of the weight tensor)
            super(WSConv2d,
                  self).__init__(int(np.ceil(in_planes / multiplier)),
                                 int(out_channels),
                                 kernel_size,
                                 stride=stride,
                                 padding=padding,
                                 dilation=dilation,
                                 groups=groups,
                                 bias=bias)
            self.rep_time = int(np.ceil(1. * in_planes / self.weight.shape[1]))

        self.in_planes = in_planes
        self.out_channels_ori = out_channels
        self.groups = groups
        self.multiplier = multiplier
        self.rep_dim = rep_dim
        self.repeat_weight = repeat_weight
        self.use_coeff = use_coeff
        # print(self.weight.shape)
        # import pdb; pdb.set_trace()

        self.conv1_stride_lr_1 = nn.Conv2d(in_planes,
                                           in_planes,
                                           kernel_size=3,
                                           stride=2,
                                           padding=0,
                                           bias=False)
        self.conv1_stride_lr_2 = nn.Conv2d(in_planes,
                                           self.rep_time,
                                           kernel_size=1,
                                           stride=1,
                                           padding=0,
                                           bias=False)
        self.coefficient = Parameter(torch.Tensor(self.rep_time),
                                     requires_grad=False)
        self.reuse = False
        self.coeff_grad = None

    def generate_share_weight(self,
                              base_weight,
                              rep_num,
                              coeff,
                              nchannel,
                              dim=0):
        ''' sample weights from base weight'''
        # pdb.set_trace()
        if rep_num == 1:
            return base_weight
        new_weight = []
        for i in range(rep_num):
            if dim == 0:
                new_weight_temp = torch.cat(
                    [base_weight[1:, :, :, :], base_weight[0:1, :, :, :]],
                    dim=0) * (1 - coeff[i])
            else:
                new_weight_temp = torch.cat(
                    [base_weight[:, 1:, :, :], base_weight[:, 0:1, :, :]],
                    dim=1) * (1 - coeff[i])
            new_weight.append(base_weight * coeff[i] + new_weight_temp)
        out = torch.cat(new_weight, dim=dim)

        if dim == 0:
            return out[:nchannel, :, :, :]
        else:
            return out[:, :nchannel, :, :]

    def forward(self, x):
        """
            same padding as efficientnet tf version
        """
        ih, iw = x.size()[-2:]
        kh, kw = self.weight.size()[-2:]
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        pad_h = max(
            (oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih,
            0)
        pad_w = max(
            (ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw,
            0)
        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, [
                pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2
            ])

        if self.use_coeff:
            if self.training:
                # set reuse to True for coefficient sharing
                if not self.reuse:
                    lr_conv1 = self.conv1_stride_lr_1(x)
                    # pdb.set_trace()
                    lr_conv1 = self.conv1_stride_lr_2(lr_conv1)
                    lr_conv1 = F.adaptive_avg_pool2d(lr_conv1, (1, 1))[:, :, 0,
                                                                       0]

                    self.coefficient.set_(
                        F.normalize(torch.mean(lr_conv1, 0),
                                    dim=0).clone().detach())
                    # pdb.set_trace()
                    self.coeff_grad = F.normalize(torch.mean(lr_conv1, 0),
                                                  dim=0)

                if self.rep_dim == 0:
                    if self.repeat_weight:
                        out = F.conv2d(
                            x,
                            self.generate_share_weight(self.weight,
                                                       self.rep_time,
                                                       self.coeff_grad,
                                                       self.out_channels_ori,
                                                       dim=0))
                    else:
                        out_tmp = F.conv2d(x, self.weight)
                        out = self.generate_share_feature(
                            self.rep_time, out_tmp,
                            F.normalize(torch.mean(lr_conv1, 0), dim=0),
                            self.out_channels_ori)
                        # out = F.conv2d(x, self.generate_share_feature(self.rep_time, F.normalize(torch.mean(lr_conv1, 0), dim = 0)))
                else:
                    if self.repeat_weight:

                        out = F.conv2d(
                            x,
                            self.generate_share_weight(self.weight,
                                                       self.rep_time,
                                                       self.coeff_grad,
                                                       x.shape[1],
                                                       dim=1))
                    else:
                        out_tmp = self.feature_wrapper(
                            self.rep_time, x,
                            F.normalize(torch.mean(lr_conv1, 0), dim=0),
                            self.out_channels_ori)
                        out = F.conv2d(out_tmp, self.weight)

            else:
                if self.rep_dim == 0:
                    if self.repeat_weight:
                        out = F.conv2d(
                            x,
                            self.generate_share_weight(
                                self.weight,
                                self.rep_time,
                                self.coefficient.detach(),
                                self.out_channels_ori,
                                dim=0))
                    else:
                        out_tmp = F.conv2d(x, self.weight)
                        out = self.generate_share_feature(
                            self.rep_time, out_tmp, self.coefficient.detach(),
                            self.out_channels_ori)
                        # out = F.conv2d(x, self.generate_share_feature(self.rep_time, self.coefficient.detach()))
                else:
                    if self.repeat_weight:
                        out = F.conv2d(
                            x,
                            self.generate_share_weight(
                                self.weight,
                                self.rep_time,
                                self.coefficient.detach(),
                                x.shape[1],
                                dim=1))
                    else:
                        out_tmp = self.feature_wrapper(
                            self.rep_time, x, self.coefficient.detach(),
                            self.out_channels_ori)

                        out = F.conv2d(out_tmp, self.weight)
        else:
            # print("use_coeff == False")
            if self.rep_dim == 0:
                out = F.conv2d(
                    x,
                    self.weight.repeat([self.rep_time, 1, 1,
                                        1])[:self.out_channels_ori, :, :, :],
                    None, 1)
            else:
                out = F.conv2d(
                    x,
                    self.weight.repeat([1, self.rep_time, 1,
                                        1])[:, :x.shape[1], :, :], None, 1)
        return out
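# A minimal usage sketch (assumption, not part of the original example): a
# weight-sharing conv that stores only ceil(in_planes / multiplier) input
# filters and repeats them at run time (use_coeff=False path), with the same
# "same" padding behavior as the TF EfficientNet reference.
import torch

conv = WSConv2d(in_planes=16, out_channels=32, kernel_size=3,
                multiplier=2.0, rep_dim=1, use_coeff=False, bias=False)
x = torch.randn(2, 16, 32, 32)
y = conv(x)                             # (2, 32, 32, 32)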
Example #29
class Conv2dDPQ(nn.Conv2d):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 qmin=1e-3,
                 qmax=100,
                 dmin=1e-5,
                 dmax=10,
                 bias=True,
                 sign=True,
                 wbits=4,
                 abits=4,
                 mode=Qmodes.layer_wise):
        """
        :param d_init: the initial quantization step size (alpha)
        :param mode: Qmodes.layer_wise or Qmodes.kernel_wise
        :param xmax_init: the quantization range for the whole weight tensor
        """

        super(Conv2dDPQ, self).__init__(in_channels=in_channels,
                                        out_channels=out_channels,
                                        kernel_size=kernel_size,
                                        stride=stride,
                                        padding=padding,
                                        dilation=dilation,
                                        groups=groups,
                                        bias=bias)

        self.qmin = qmin
        self.qmax = qmax
        self.dmin = dmin
        self.dmax = dmax
        self.q_mode = mode
        self.sign = sign
        self.nbits = wbits
        self.act_dpq = ActDPQ(signed=False, nbits=abits)
        self.alpha = Parameter(torch.Tensor(1))
        self.xmax = Parameter(torch.Tensor(1))
        self.weight.requires_grad_(True)
        if bias:
            self.bias.requires_grad_(True)
        self.register_buffer('init_state', torch.zeros(1))

    def get_nbits(self):
        abits = self.act_dpq.get_nbits()
        xmax = self.xmax.abs().item()
        alpha = self.alpha.abs().item()
        if self.sign:
            nbits = math.ceil(math.log(xmax / alpha + 1) / math.log(2) + 1)
        else:
            nbits = math.ceil(math.log(xmax / alpha + 1) / math.log(2))

        self.nbits = nbits
        return abits, nbits

    def get_quan_filters(self, filters):
        if self.training and self.init_state == 0:
            Qp = 2**(self.nbits - 1) - 1
            self.xmax.data.copy_(filters.abs().max())
            self.alpha.data.copy_(self.xmax / Qp)
            # self.alpha[self.index].data.copy_(2 * filters.abs().mean() / math.sqrt(Qp))
            # self.xmax[self.index].data.copy_(self.alpha[self.index] * Qp)
            self.init_state.fill_(1)

        Qp = (self.xmax.detach() / self.alpha.detach()).abs().item()
        g = 1.0 / math.sqrt(filters.numel() * Qp)
        alpha = grad_scale(self.alpha, g)
        xmax = grad_scale(self.xmax, g)

        w = F.hardtanh(filters / xmax.abs(), -1, 1) * xmax.abs()
        w = w / alpha.abs()
        wq = round_pass(w) * alpha.abs()

        return wq

    def forward(self, x):
        if self.act_dpq is not None:
            x = self.act_dpq(x)

        wq = self.get_quan_filters(self.weight)
        return F.conv2d(x, wq, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)
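# The helpers grad_scale() and round_pass() used above (as well as ActDPQ and
# Qmodes) are external to this snippet. A common LSQ-style straight-through
# implementation of the two helpers (an assumption, shown for reference) is:
def grad_scale(x, scale):
    # forward: returns x unchanged; backward: gradient w.r.t. x is multiplied by `scale`
    return (x - x * scale).detach() + x * scale

def round_pass(x):
    # forward: returns round(x); backward: identity (straight-through estimator)
    return (x.round() - x).detach() + x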
Example #30
class _ConvBnNd(nn.modules.conv._ConvNd):

  _version = 2

  def __init__(
      self,
      # ConvNd args
      in_channels,
      out_channels,
      kernel_size,
      stride,
      padding,
      dilation,
      transposed,
      output_padding,
      groups,
      bias,
      padding_mode,
      # BatchNormNd args
      # num_features: out_channels
      eps=1e-05,
      momentum=0.1,
      # affine: True
      # track_running_stats: True
      # Args for this module
      freeze_bn_stats=False,
      rt_spec=None,
      dim=2):
    nn.modules.conv._ConvNd.__init__(self, in_channels, out_channels,
                                     kernel_size, stride, padding, dilation,
                                     transposed, output_padding, groups, False,
                                     padding_mode)
    assert rt_spec, 'Runtime spec must be provided for quantized module {}'.format(
        self.__class__.__name__)
    self.bn_frozen = freeze_bn_stats if self.training else True
    self.dim = dim
    self.bn = _BN_CLASS_MAP[dim](out_channels, eps, momentum, True, True)

    self.weight_quantizer = rt_spec.get_weight_quantizer('weight')
    self.bias_quantizer = rt_spec.get_weight_quantizer('bias')

    if bias:
      self.bias = Parameter(torch.Tensor(out_channels))
    else:
      self.register_parameter('bias', None)
    self.reset_bn_parameters()

    # this needs to be called after reset_bn_parameters,
    # as they modify the same state
    if self.training:
      if freeze_bn_stats:
        self.freeze_bn_stats()
      else:
        self.update_bn_stats()
    else:
      self.freeze_bn_stats()

    self.conv_bn_fused = False

  def reset_running_stats(self):
    self.bn.reset_running_stats()

  def reset_bn_parameters(self):
    self.bn.reset_running_stats()
    init.uniform_(self.bn.weight)
    init.zeros_(self.bn.bias)
    # note: below is actually for conv, not BN
    if self.bias is not None:
      fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
      bound = 1 / math.sqrt(fan_in)
      init.uniform_(self.bias, -bound, bound)

  def batch_stats(self, x, bias=None):
    """Get the batch mean and variance of x and update the BatchNorm's running mean and variance.

      Args:
        x (torch.Tensor): input batch.
        bias (torch.Tensor): the bias that is to be applied to the batch.

      Returns:
        (mean, variance)

      Note:
        In case of `nn.Linear`, x may be of shape (N, C, L) or (N, L)
        where N is batch size, C is number of channels, L is the features size.
        The batch norm computes the stats over C in the first case or L on the second case.
        The batch normalization layer is
        (`nn.BatchNorm1d`)[https://pytorch.org/docs/stable/nn.html#batchnorm1d]

        In case of `nn.Conv2d`, x is of shape (N, C, H, W)
        where H,W are the image dimensions, and the batch norm computes the stats over C.
        The batch normalization layer is
        (`nn.BatchNorm2d`)[https://pytorch.org/docs/stable/nn.html#batchnorm2d]

        In case of `nn.Conv3d`, x is of shape (N, C, D, H, W)
        where H,W are the image dimensions, D is additional channel dimension,
        and the batch norm computes the stats over C.
        The batch normalization layer is
        (`nn.BatchNorm3d`)[https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm3d.html#torch.nn.BatchNorm3d]
    """
    channel_size = self.bn.num_features
    self.bn.num_batches_tracked += 1

    # Calculate current batch stats
    batch_mean = x.transpose(0, 1).contiguous().view(channel_size, -1).mean(1)
    # BatchNorm currently uses biased variance (without Bessel's correction) as was discussed at
    # https://github.com/pytorch/pytorch/issues/1410
    #
    # also see the source code itself:
    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Normalization.cpp#L216
    batch_var = x.transpose(0, 1).contiguous().view(channel_size, -1).var(
        1, unbiased=False)

    # Update running stats
    with torch.no_grad():
      biased_batch_mean = batch_mean + (bias if bias is not None else 0)
      # However - running_var is updated using unbiased variance!
      # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Normalization.cpp#L223
      n = x.numel() / channel_size
      corrected_var = batch_var * (n / float(n - 1))
      momentum = self.bn.momentum
      if momentum is None:
        # momentum is None - we compute a cumulative moving average
        # as noted in https://pytorch.org/docs/stable/nn.html#batchnorm2d
        momentum = 1. / float(self.bn.num_batches_tracked)
      self.bn.running_mean.mul_(1 - momentum).add_(momentum * biased_batch_mean)
      self.bn.running_var.mul_(1 - momentum).add_(momentum * corrected_var)

    return batch_mean, batch_var

  def reset_parameters(self):
    super(_ConvBnNd, self).reset_parameters()

  def merge_bn_to_conv(self):
    with torch.no_grad():
      # Use the same implementation in nndct_shared/optimzation/fuse_conv_bn.py
      # to make sure the test accuracy is the same as the deployable model.
      gamma = self.bn.weight.detach().cpu().numpy()
      beta = self.bn.bias.detach().cpu().numpy()
      running_var = self.bn.running_var.detach().cpu().numpy()
      running_mean = self.bn.running_mean.detach().cpu().numpy()
      epsilon = self.bn.eps

      scale = gamma / np.sqrt(running_var + epsilon)
      offset = beta - running_mean * scale

      weight = self.weight.detach().cpu().numpy()
      # Conv2d
      if self.dim == 2 and not self.transposed:
        # OIHW -> IHWO -> OIHW
        weight = np.multiply(weight.transpose(1, 2, 3, 0),
                             scale).transpose(3, 0, 1, 2)
      # ConvTranspose2d
      elif self.dim == 2 and self.transposed:
        # IOHW -> IHWO -> IOHW
        weight = np.multiply(weight.transpose(0, 2, 3, 1),
                             scale).transpose(0, 3, 1, 2)
      # Conv3D
      elif self.dim == 3 and not self.transposed:
        weight = np.multiply(weight.transpose(1, 2, 3, 4, 0),
                             scale).transpose(4, 0, 1, 2, 3)
      # ConvTranspose3d
      elif self.dim == 3 and self.transposed:
        weight = np.multiply(weight.transpose(2, 3, 4, 0, 1),
                             scale).transpose(3, 4, 0, 1, 2)
      else:
        raise RuntimeError(
            'Unsupported combinations: (dim={}, transposed={})'.format(
                self.dim, self.transposed))
      self.weight.copy_(torch.from_numpy(weight))

      bias = self.bias.detach().cpu().numpy() if self.bias is not None else 0
      bias = torch.from_numpy(bias * scale + offset)
      if self.bias is not None:
        self.bias.copy_(bias)
      else:
        self.bias = nn.Parameter(bias)
    self.conv_bn_fused = True

  def update_bn_stats(self):
    self.bn_frozen = False

  def freeze_bn_stats(self):
    self.bn_frozen = True

  def clear_non_native_bias(self):
    if self.bias is None:
      print('[WARNING] No bias to unmerge')
      return

    with torch.no_grad():
      gamma = self.bn.weight.detach().cpu().numpy()
      beta = self.bn.bias.detach().cpu().numpy()
      running_var = self.bn.running_var.detach().cpu().numpy()
      running_mean = self.bn.running_mean.detach().cpu().numpy()
      epsilon = self.bn.eps

      scale = gamma / np.sqrt(running_var + epsilon)

      bias = self.bias.detach().cpu().numpy()
      beta = torch.from_numpy(bias * scale + beta)
      self.bn.bias.copy_(beta)
      self.bias = None

  def broadcast_correction(self, c):
    """Broadcasts a correction factor to the output for elementwise operations.

    Two tensors are “broadcastable” if the following rules hold:
      - Each tensor has at least one dimension.
      - When iterating over the dimension sizes, starting at the trailing
        dimension, the dimension sizes must either be equal,
        one of them is 1, or one of them does not exist.
    See https://pytorch.org/docs/stable/notes/broadcasting.html
    """
    expected_output_dim = self.dim + 2
    view_fillers_dim = expected_output_dim - c.dim() - 1
    view_filler = (1,) * view_fillers_dim
    expected_view_shape = c.shape + view_filler
    return c.view(*expected_view_shape)

  def broadcast_correction_weight(self, c):
    """Broadcasts a correction factor to the weight."""
    if c.dim() != 1:
      raise ValueError("Correction factor needs to have a single dimension")

    expected_weight_dim = self.dim + 2
    view_fillers_dim = expected_weight_dim - c.dim()
    view_filler = (1,) * view_fillers_dim
    expected_view_shape = c.shape + view_filler
    return c.view(*expected_view_shape)

  def extra_repr(self):
    return super(_ConvBnNd, self).extra_repr()

  def forward(self, x, output_size=None):
    """
    See https://arxiv.org/pdf/1806.08342.pdf section 3.2.2.
    bn(conv(x)) = (conv(x) - E(conv(x))) * gamma / std(conv(x)) + beta
                = (x*W + B - E(x*W + B)) * gamma / sqrt(E((x*W + B - E(x*W + B))^2)) + beta
                = (x*W - E(x*W)) * gamma / std(x*W) + beta
    """
    gamma, beta = self.bn.weight, self.bn.bias
    if self.conv_bn_fused:
      quantized_weight = self.weight_quantizer(self.weight)
      quantized_bias = self.bias_quantizer(self.bias)
      return self._conv_forward(x, quantized_weight, quantized_bias,
                                output_size)

    if self.training and not self.bn_frozen:
      batch_mean, batch_var = self.batch_stats(
          self._conv_forward(x, self.weight, output_size=output_size),
          self.bias)
      recip_sigma_batch = torch.rsqrt(batch_var + self.bn.eps)
      with torch.no_grad():
        running_sigma = torch.sqrt(self.bn.running_var + self.bn.eps)

      w_corrected = self.weight * self.broadcast_correction_weight(
          gamma / running_sigma)
      w_quantized = self.weight_quantizer(w_corrected)
      recip_c = self.broadcast_correction(running_sigma * recip_sigma_batch)
      bias_corrected = beta - gamma * batch_mean * recip_sigma_batch
      bias_quantized = self.broadcast_correction(
          self.bias_quantizer(bias_corrected))

      y = self._conv_forward(x, w_quantized, None, output_size)
      y.mul_(recip_c).add_(bias_quantized)
    else:
      with torch.no_grad():
        recip_running_sigma = torch.rsqrt(self.bn.running_var + self.bn.eps)
      w_corrected = self.weight * self.broadcast_correction_weight(
          gamma * recip_running_sigma)
      w_quantized = self.weight_quantizer(w_corrected)
      mean_corrected = self.bn.running_mean - (
          self.bias if self.bias is not None else 0)
      bias_corrected = beta - gamma * mean_corrected * recip_running_sigma
      bias_quantized = self.bias_quantizer(bias_corrected)
      y = self._conv_forward(x, w_quantized, bias_quantized, output_size)
    return y

  def train(self, mode=True):
    """BatchNorm's training behavior is controlled by the self.training flag. Prevent
    changing it if BN is frozen. This makes sure that calling `model.train()`
    on a model with a frozen BN behaves properly.
    """
    self.training = mode
    if not self.bn_frozen:
      for module in self.children():
        module.train(mode)
    return self

  @property
  def is_quantized(self):
    return True

  @classmethod
  def from_float(cls, conv, bn, rt_spec):
    """Create a qat module from a float module."""
    assert rt_spec, 'Input float module must have a valid rt_spec'
    convbn = cls(conv.in_channels, conv.out_channels, conv.kernel_size,
                 conv.stride, conv.padding, conv.dilation, conv.groups,
                 conv.bias is not None, conv.padding_mode, bn.eps, bn.momentum,
                 False, rt_spec)
    convbn.weight = conv.weight
    convbn.bias = conv.bias
    convbn.bn.weight = bn.weight
    convbn.bn.bias = bn.bias
    convbn.bn.running_mean = bn.running_mean
    convbn.bn.running_var = bn.running_var
    convbn.bn.num_batches_tracked = bn.num_batches_tracked
    convbn.bn.eps = bn.eps
    return convbn
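# A standalone sanity check (assumption, not part of the original example) of
# the Conv2d case handled by merge_bn_to_conv(): with frozen statistics,
# BN(conv(x)) equals a single conv whose weights are rescaled by
# gamma / sqrt(running_var + eps) and whose bias is shifted accordingly.
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, 3, bias=True)
bn = nn.BatchNorm2d(8).eval()
with torch.no_grad():
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)

    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)    # gamma / std
    fused = nn.Conv2d(3, 8, 3, bias=True)
    fused.weight.copy_(conv.weight * scale.view(-1, 1, 1, 1))  # scale along O of OIHW
    fused.bias.copy_(bn.bias + (conv.bias - bn.running_mean) * scale)

x = torch.randn(1, 3, 16, 16)
assert torch.allclose(bn(conv(x)), fused(x), atol=1e-4)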
Example #31
class Model:
    '''Model manager class

   Convenience class wrapping saving/loading,
   model initialization, optimization, and logging.

   Args:
      ann: Model to optimize. Used to initialize weights.
      config: A Config specification
      args: Hook for additional user arguments
   '''
    def __init__(self, ann, config):
        self.saver = save.Saver(config.MODELDIR,
                                'models',
                                'bests',
                                resetTol=256)
        self.config = config

        print('Initializing new model...')
        self.net = ann(config)
        self.parameters = Parameter(
            torch.Tensor(np.array(getParameters(self.net))))

    def load(self, opt, best=False):
        '''Load a model from file

      Args:
         best (bool): Whether to load the best (True)
             or most recent (False) checkpoint
      '''
        print('Loading model...')
        epoch = self.saver.load(opt, self.parameters, best)
        self.syncParameters()

    def checkpoint(self, opt, reward):
        '''Save the model to checkpoint

      Args:
         reward: Mean reward of the model
      '''
        self.saver.checkpoint(self.parameters, opt, reward)
        self.saver.print()

    @property
    def nParams(self):
        '''Print the number of model parameters'''
        nParams = len(self.weights)
        print('#Params: ', str(nParams / 1000), 'K')

    def syncParameters(self):
        parameters = self.parameters.detach().numpy().tolist()
        setParameters(self.net, parameters)

    @property
    def weights(self):
        '''Get model parameters

      Returns:
         a numpy array of model parameters
      '''
        return getParameters(self.net)
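# The helpers getParameters() / setParameters() used above are external to this
# snippet. A common flatten / unflatten implementation (an assumption, shown
# for reference) is:
import torch

def getParameters(net):
    # flatten all parameters of a module into a single list of floats
    return torch.cat([p.detach().reshape(-1) for p in net.parameters()]).tolist()

def setParameters(net, flat):
    # write a flat sequence of floats back into the module's parameters, in order
    flat = torch.as_tensor(flat, dtype=torch.float32)
    offset = 0
    for p in net.parameters():
        n = p.numel()
        p.data.copy_(flat[offset:offset + n].view_as(p))
        offset += n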