Code example #1
File: tmp_mix.py (project: leliyliu/codes-for-papers)

import math

import torch
import torch.nn as nn
from torch.nn import Parameter

# grad_scale and round_pass are the LSQ-style straight-through helpers used
# throughout this file; a sketch of both follows this example.
class ActDPQ(nn.Module):
    def __init__(self, signed=True, nbits=4, qmin=1e-3, qmax=100, dmin=1e-5, dmax=10):
        """
        :param nbits: the initial quantization bit width of activation
        :param signed: whether the activation data is signed
        """
        super(ActDPQ, self).__init__()
        self.qmin = qmin
        self.qmax = qmax
        self.dmin = dmin 
        self.dmax = dmax
        self.signed = signed
        self.nbits = nbits
        self.alpha = Parameter(torch.Tensor(1))
        self.xmax = Parameter(torch.Tensor(1))
        self.register_buffer('init_state', torch.zeros(1))

    def get_nbits(self):
        self.xmax.data.copy_(self.xmax.clamp(self.qmin, self.qmax))
        self.alpha.data.copy_(self.alpha.clamp(self.dmin, self.dmax))
        if self.signed:
            nbits = (torch.log(self.xmax/self.alpha + 1) / math.log(2) + 1).ceil()
        else:
            nbits = (torch.log(self.xmax/self.alpha + 1) / math.log(2)).ceil()
        self.nbits = int(nbits.item())
        return nbits

    def forward(self, x):
        if self.alpha is None:  # quantization disabled
            return x
        
        if self.init_state == 0:
            # one-time LSQ-style initialisation of the step size (alpha) and
            # quantization range (xmax) from the first batch of activations
            Qp = 2 ** (self.nbits - 1) - 1
            self.alpha.data.copy_(2 * x.abs().mean() / math.sqrt(Qp))
            self.xmax.data.copy_(self.alpha * Qp)
            self.init_state.fill_(1)

        self.xmax.data.copy_(self.xmax.clamp(self.qmin, self.qmax))
        self.alpha.data.copy_(self.alpha.clamp(self.dmin, self.dmax))
        Qp = (self.xmax / self.alpha).item()
        # gradient scaling keeps the step-size/range updates stable (LSQ heuristic)
        g = 1.0 / math.sqrt(x.numel() * Qp)
        alpha = grad_scale(self.alpha, g)
        xmax = grad_scale(self.xmax, g)

        if self.signed: 
            x = round_pass((torch.clamp(x/xmax, -1, 1)*xmax)/alpha) * alpha
        else:
            x = round_pass((torch.clamp(x/xmax, 0, 1)*xmax)/alpha) * alpha
        
        return x
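The grad_scale and round_pass helpers referenced above are not shown in this listing; they are the standard straight-through-estimator utilities from LSQ-style quantizers. Below is a minimal sketch of both, assuming the project's definitions match the usual ones, followed by a small smoke test of ActDPQ (shapes are illustrative):

def grad_scale(x, scale):
    # forward: identity; backward: gradient multiplied by `scale`
    y_grad = x * scale
    return (x - y_grad).detach() + y_grad

def round_pass(x):
    # forward: round; backward: straight-through identity gradient
    return (x.round() - x).detach() + x

act_q = ActDPQ(signed=True, nbits=4)
x = torch.randn(2, 8, 16, 16)
xq = act_q(x)                 # first call initialises alpha and xmax from the batch
print(act_q.get_nbits())      # tensor([4.]) right after initialisation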
Code example #2

import torch
from torch.nn import Module, Parameter
import torch.nn.functional as F
class TPReLU(Module):
    def __init__(self, num_parameters=1, init=0.25):
        super(TPReLU, self).__init__()
        self.num_parameters = num_parameters
        self.weight = Parameter(torch.Tensor(num_parameters).fill_(init))
        self.bias = Parameter(torch.zeros(num_parameters))

    def forward(self, input):
        # reshape the bias to (1, C, 1, ..., 1) so it broadcasts over the input
        bias_resize = self.bias.view(1, self.num_parameters,
                                     *((1,) * (input.dim() - 2))).expand_as(input)
        # translated PReLU: shift by the learned bias, apply PReLU with the
        # slope clamped to [0, 1], then shift back
        return F.prelu(input - bias_resize, self.weight.clamp(0, 1)) + bias_resize

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.num_parameters) + ')'
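A quick smoke test: TPReLU behaves like PReLU with a learned per-channel translation, so it can stand in wherever PReLU would (shapes here are illustrative):

tprelu = TPReLU(num_parameters=8)   # one (slope, bias) pair per channel
x = torch.randn(4, 8, 32, 32)
y = tprelu(x)                       # same shape as the input
print(tprelu)                       # TPReLU (8)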
Code example #3
File: tmp_mix.py (project: leliyliu/codes-for-papers)
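Conv2dDPQ below lives in the same tmp_mix.py as ActDPQ from example #1, so it shares that example's imports and the grad_scale/round_pass helpers. It additionally needs torch.nn.functional and the project's Qmodes granularity enum, which is not shown in this listing; here is a minimal stand-in consistent with how it is used below:

import torch.nn.functional as F
from enum import Enum

class Qmodes(Enum):
    # quantization granularity: one step size per layer vs. per output channel
    layer_wise = 1
    kernel_wise = 2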
class Conv2dDPQ(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1,
                 qmin=1e-3, qmax=100, dmin=1e-5, dmax=10, bias=True, sign=True, wbits=4, abits=4, mode=Qmodes.layer_wise):
        """
        :param qmin/qmax: clamp range for xmax, the learned quantization range of the weights
        :param dmin/dmax: clamp range for alpha, the learned quantization step size
        :param wbits/abits: initial bit widths for the weight and activation quantizers
        :param mode: Qmodes.layer_wise or Qmodes.kernel_wise
        """

        super(Conv2dDPQ, self).__init__(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                        stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        
        self.qmin = qmin
        self.qmax = qmax
        self.dmin = dmin 
        self.dmax = dmax
        self.q_mode = mode
        self.sign = sign
        self.nbits = wbits 
        self.act_dpq = ActDPQ(signed=True, nbits=abits)
        if self.q_mode == Qmodes.kernel_wise:
            self.alpha = Parameter(torch.Tensor(out_channels))
        else:
            self.alpha = Parameter(torch.Tensor(1))
        self.xmax = Parameter(torch.Tensor(1))
        self.weight.requires_grad_(True)
        if bias:
            self.bias.requires_grad_(True)
        self.register_buffer('init_state', torch.zeros(1))

    def get_nbits(self):
        abits = self.act_dpq.get_nbits()
        self.xmax.data.copy_(self.xmax.clamp(self.qmin, self.qmax))
        self.alpha.data.copy_(self.alpha.clamp(self.dmin, self.dmax))
        # bit width implied by the range/step ratio (one extra bit for the sign)
        if self.sign:
            nbits = (torch.log(self.xmax/self.alpha + 1) / math.log(2) + 1).ceil()
        else:
            nbits = (torch.log(self.xmax/self.alpha + 1) / math.log(2)).ceil()
        self.nbits = int(nbits.item())
        return abits, nbits

    def get_quan_filters(self, filters):
        if self.training and self.init_state == 0:
            # one-time LSQ-style initialisation of the step size (alpha) and
            # quantization range (xmax) from the current weights
            Qp = 2 ** (self.nbits - 1) - 1
            self.alpha.data.copy_(2 * filters.abs().mean() / math.sqrt(Qp))
            self.xmax.data.copy_(self.alpha * Qp)
            self.init_state.fill_(1)

        self.xmax.data.copy_(self.xmax.clamp(self.qmin, self.qmax))
        self.alpha.data.copy_(self.alpha.clamp(self.dmin, self.dmax))
        Qp = (self.xmax.detach() / self.alpha.detach()).item()
        g = 1.0 / math.sqrt(filters.numel() * Qp)
        alpha = grad_scale(self.alpha, g)
        xmax = grad_scale(self.xmax, g)

        if self.sign:
            wq = round_pass((torch.clamp(filters/xmax, -1, 1) * xmax)/alpha) * alpha
        else:
            wq = round_pass((torch.clamp(filters/xmax, 0, 1) * xmax)/alpha) * alpha

        return wq

    def forward(self, x):
        # quantize the weights (shares the one-time initialisation and
        # clamping logic with get_quan_filters above)
        wq = self.get_quan_filters(self.weight)

        # quantize the activations
        if self.act_dpq is not None:
            x = self.act_dpq(x)

        return F.conv2d(x, wq, self.bias, self.stride, self.padding, self.dilation, self.groups)
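A minimal smoke test of the layer, assuming the Qmodes stand-in above and the grad_scale/round_pass sketches from example #1:

conv = Conv2dDPQ(16, 32, kernel_size=3, padding=1, wbits=4, abits=4)
x = torch.rand(2, 16, 8, 8)
y = conv(x)                 # new modules start in training mode, so this first
                            # call runs the one-time scale initialisation
abits, wbits = conv.get_nbits()
print(int(abits.item()), int(wbits.item()))   # current activation/weight bit widths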
Code example #4

import torch
import torch.nn as nn
from torch.nn import Parameter

# _LearnableFakeQuantizePerTensorOp and _LearnableFakeQuantizePerChannelOp are
# the companion autograd functions from the same module; they are referenced
# below but not shown in this listing.
class _LearnableFakeQuantize(nn.Module):
    r""" This is an extension of the FakeQuantize module in fake_quantize.py, which
    supports more generalized lower-bit quantization and support learning of the scale
    and zero point parameters through backpropagation. For literature references,
    please see the class _LearnableFakeQuantizePerTensorOp.

    In addition to the attributes in the original FakeQuantize module, the _LearnableFakeQuantize
    module also includes the following attributes to support quantization parameter learning.

    * :attr: `channel_len` defines the length of the channel when initializing scale and zero point
             for the per channel case.

    * :attr: `grad_factor` defines a factor that will be multiplied to the gradients for scale
             and zero point during the backward path for the learnable fake quantization operators.
             By default, it is 1.

    * :attr: `fake_quant_enabled` defines the flag for enabling fake quantization on the output.

    * :attr: `static_enabled` defines the flag for using observer's static estimation for
             scale and zero point.

    * attr: `learning_enabled` defines the flag for enabling backpropagation for scale and zero point.
    """
    def __init__(self,
                 observer,
                 quant_min=0,
                 quant_max=255,
                 scale=1.,
                 zero_point=0.,
                 channel_len=-1,
                 grad_factor=1.):
        super(_LearnableFakeQuantize, self).__init__()
        assert quant_min < quant_max, 'quant_min must be strictly less than quant_max.'
        self.quant_min = quant_min
        self.quant_max = quant_max
        self.grad_factor = grad_factor

        if channel_len == -1:
            self.scale = Parameter(torch.tensor([scale]))
            self.zero_point = Parameter(torch.tensor([zero_point]))
        else:
            assert isinstance(
                channel_len, int
            ) and channel_len > 0, "Channel size must be a positive integer."
            self.scale = Parameter(torch.tensor([scale] * channel_len))
            self.zero_point = Parameter(
                torch.tensor([zero_point] * channel_len))

        self.activation_post_process = observer
        assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, \
               'quant_min out of bound'
        assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, \
               'quant_max out of bound'
        self.dtype = self.activation_post_process.dtype
        self.qscheme = self.activation_post_process.qscheme
        self.ch_axis = self.activation_post_process.ch_axis \
            if hasattr(self.activation_post_process, 'ch_axis') else -1
        self.register_buffer('fake_quant_enabled',
                             torch.tensor([1], dtype=torch.uint8))
        self.register_buffer('static_enabled',
                             torch.tensor([1], dtype=torch.uint8))
        self.register_buffer('learning_enabled',
                             torch.tensor([0], dtype=torch.uint8))

        bitrange = torch.tensor(quant_max - quant_min + 1).double()
        self.bitwidth = int(torch.log2(bitrange).item())

    @torch.jit.export
    def enable_param_learning(self):
        r"""Enables learning of quantization parameters and
        disables static observer estimates. Forward path returns fake quantized X.
        """
        self.toggle_qparam_learning(enabled=True) \
            .toggle_fake_quant(enabled=True) \
            .toggle_observer_update(enabled=False)
        return self

    @torch.jit.export
    def enable_static_estimate(self):
        r"""Enables static observer estimates and disbales learning of
        quantization parameters. Forward path returns fake quantized X.
        """
        self.toggle_qparam_learning(enabled=False) \
            .toggle_fake_quant(enabled=True) \
            .toggle_observer_update(enabled=True)

    @torch.jit.export
    def enable_static_observation(self):
        r"""Enables static observer accumulating data from input but doesn't
        update the quantization parameters. Forward path returns the original X.
        """
        self.toggle_qparam_learning(enabled=False) \
            .toggle_fake_quant(enabled=False) \
            .toggle_observer_update(enabled=True)

    @torch.jit.export
    def toggle_observer_update(self, enabled=True):
        self.static_enabled[0] = int(enabled)
        return self

    @torch.jit.export
    def toggle_qparam_learning(self, enabled=True):
        self.learning_enabled[0] = int(enabled)
        self.scale.requires_grad = enabled
        self.zero_point.requires_grad = enabled
        return self

    @torch.jit.export
    def toggle_fake_quant(self, enabled=True):
        self.fake_quant_enabled[0] = int(enabled)
        return self

    @torch.jit.export
    def observe_quant_params(self):
        print('_LearnableFakeQuantize Scale: {}'.format(self.scale.detach()))
        print('_LearnableFakeQuantize Zero Point: {}'.format(
            self.zero_point.detach()))

    @torch.jit.export
    def calculate_qparams(self):
        return self.activation_post_process.calculate_qparams()

    def forward(self, X):
        self.activation_post_process(X.detach())
        _scale, _zero_point = self.calculate_qparams()
        _scale = _scale.to(self.scale.device)
        _zero_point = _zero_point.to(self.zero_point.device)

        if self.static_enabled[0] == 1:
            self.scale.data.copy_(_scale)
            self.zero_point.data.copy_(_zero_point)

        if self.fake_quant_enabled[0] == 1:
            if self.learning_enabled[0] == 1:
                self.zero_point.data.clamp_(self.quant_min, self.quant_max)
                if self.qscheme in (torch.per_channel_symmetric,
                                    torch.per_channel_affine):
                    X = _LearnableFakeQuantizePerChannelOp.apply(
                        X, self.scale, self.zero_point, self.ch_axis,
                        self.quant_min, self.quant_max, self.grad_factor)
                else:
                    X = _LearnableFakeQuantizePerTensorOp.apply(
                        X, self.scale, self.zero_point, self.quant_min,
                        self.quant_max, self.grad_factor)
            else:
                if self.qscheme == torch.per_channel_symmetric or \
                        self.qscheme == torch.per_channel_affine:
                    X = torch.fake_quantize_per_channel_affine(
                        X, self.scale, self.zero_point, self.ch_axis,
                        self.quant_min, self.quant_max)
                else:
                    X = torch.fake_quantize_per_tensor_affine(
                        X, float(self.scale.item()),
                        int(self.zero_point.item()), self.quant_min,
                        self.quant_max)

        return X

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        # Save the static state of scale and zero_point (rather than as dynamic params).
        super(_LearnableFakeQuantize,
              self)._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + 'scale'] = self.scale.data
        destination[prefix + 'zero_point'] = self.zero_point.data

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        local_state = ['scale', 'zero_point']
        for name in local_state:
            key = prefix + name
            if key in state_dict:
                # copy into the existing Parameter instead of replacing it, so
                # the learnable qparams keep their identity for the optimizer
                getattr(self, name).data.copy_(state_dict[key])
            elif strict:
                missing_keys.append(key)
        super(_LearnableFakeQuantize,
              self)._load_from_state_dict(state_dict, prefix, local_metadata,
                                          strict, missing_keys,
                                          unexpected_keys, error_msgs)
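A self-contained smoke test follows. The toy observer is only a stand-in so the snippet runs on its own; in practice the module would receive e.g. a MovingAverageMinMaxObserver instance. With learning disabled, the forward path exercises only PyTorch's built-in fake_quantize_per_tensor_affine, so the companion autograd ops are not needed here:

class _ToyObserver(nn.Module):
    # stand-in observer: tracks min/max and derives affine qparams for uint8
    dtype = torch.uint8
    qscheme = torch.per_tensor_affine

    def forward(self, x):
        self.min_val, self.max_val = float(x.min()), float(x.max())
        return x

    def calculate_qparams(self):
        scale = max(self.max_val - self.min_val, 1e-8) / 255.0
        zero_point = round(-self.min_val / scale)
        return torch.tensor([scale]), torch.tensor([float(zero_point)])

fq = _LearnableFakeQuantize(_ToyObserver())
fq.enable_static_estimate()        # observer drives scale/zero_point
y = fq(torch.randn(4, 8))          # fake-quantized output, same shape as input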
Code example #5
File: readouts.py (project: jcbyts/V1FreeViewingCode)

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from torch.nn import Parameter

# Readout is the project's base class (a pytorch-lightning module that
# provides save_hyperparameters/hparams); it is not shown in this listing.
class Point2DGaussian(Readout):
    """
    A readout using a spatial transformer layer whose positions are sampled from one Gaussian per neuron. Mean
    and covariance of that Gaussian are learned.

    Args:
        in_shape (list, tuple): shape of the input feature map [channels, width, height]
        outdims (int): number of output units
        bias (bool): adds a bias term
        init_mu_range (float): initialises the mean with Uniform([-init_mu_range, init_mu_range])
                            [expected: positive value <=1]. Default: 0.1
        init_sigma (float): initialises the standard deviation of the Gaussian with `init_sigma` when `gauss_type` is
            'isotropic' or 'uncorrelated'. When `gauss_type='full'`, initialises the square root of the
            covariance matrix with Uniform([-init_sigma, init_sigma]). Default: 1
        batch_sample (bool): if True, samples a position for each image in the batch separately
                            [default: True as it decreases convergence time and performs just as well]
        align_corners (bool): Keyword argument to grid_sample for bilinear interpolation.
                It changed behavior in PyTorch 1.3. The default of align_corners = True sets the
                behavior to pre PyTorch 1.3 functionality for comparability.
        gauss_type (str): Which Gaussian to use. Options are 'isotropic', 'uncorrelated' (default), or 'full'.
        shifter (dict): Parameters for a predictor of shifting grid locations. Has to have a form like
                        {
                        'hidden_layers':1,
                        'hidden_features':20,
                        'final_tanh': False,
                        }
    """

    def __init__(self, in_shape=[10,10,10],
                outdims=10,
                bias=True,
                init_mu_range=0.1,
                init_sigma=1,
                gamma_l1=0.001,
                gamma_l2=0.1,
                batch_sample=True,
                align_corners=True,
                gauss_type='uncorrelated',
                shifter=None,
                constrain_positive=False,
                **kwargs):

        super().__init__()

        # pytorch lightning helper to save all hyperparameters
        self.save_hyperparameters()

        # determines whether the Gaussian is isotropic or not
        self.gauss_type = gauss_type

        if init_mu_range > 1.0 or init_mu_range <= 0.0 or init_sigma <= 0.0:
            raise ValueError("init_mu_range must be in (0.0, 1.0] and init_sigma must be positive")

        # store statistics about the images and neurons
        self.in_shape = in_shape
        self.outdims = outdims

        # sample a different location per example
        self.batch_sample = batch_sample

        # constrain feature vector to be positive
        self.constrain_positive = constrain_positive

        # position grid shape
        self.grid_shape = (1, outdims, 1, 2)

        # initialize means
        self._mu = Parameter(torch.Tensor(*self.grid_shape))  # mean location of gaussian for each neuron

        if gauss_type == 'full':
            self.sigma_shape = (1, outdims, 2, 2)
        elif gauss_type == 'uncorrelated':
            self.sigma_shape = (1, outdims, 1, 2)
        elif gauss_type == 'isotropic':
            self.sigma_shape = (1, outdims, 1, 1)
        else:
            raise ValueError(f'gauss_type "{gauss_type}" not known')

        self.init_sigma = init_sigma
        self.sigma = Parameter(torch.Tensor(*self.sigma_shape))  # standard deviation for gaussian for each neuron

        self.initialize_features()

        if shifter:
            self.shifter = nn.Sequential()
            layer = OrderedDict()
            if shifter["hidden_layers"]==0:
                layer["linear"] = nn.Linear(2, 2, bias=True)
                if shifter["final_tanh"]:
                    layer["activation"] = nn.Tanh()
            else:
                layer["linear"] = nn.Linear(2, shifter["hidden_features"], bias=False)
                if "activation" in shifter.keys():
                    if shifter["activation"]=="relu":
                        layer["activation"] = nn.ReLU()
                    elif shifter["activation"]=="softplus":
                        if "lengthscale" in shifter.keys():
                            layer["activation"] = nn.Softplus(beta=shifter["lengthscale"])
                        else:
                            layer["activation"] = nn.Softplus()
                else:
                        layer["activation"] = nn.ReLU()
            
            self.shifter.add_module("layer0", nn.Sequential(layer))

            for l in range(1,shifter['hidden_layers']+1):
                layer = OrderedDict()
                if l == shifter['hidden_layers']: # is final layer
                    layer["linear"] = nn.Linear(shifter["hidden_features"],2,bias=True)
                    if shifter["final_tanh"]:
                        layer["activation"] = nn.Tanh()
                else:
                    layer["linear"] = nn.Linear(shifter["hidden_features"],shifter["hidden_features"],bias=True)
                    if "activation" in shifter.keys():
                        if shifter["activation"]=="relu":
                            layer["activation"] = nn.ReLU()
                        elif shifter["activation"]=="softplus":
                            if "lengthscale" in shifter.keys():
                                layer["activation"] = nn.Softplus(beta=shifter["lengthscale"])
                            else:
                                layer["activation"] = nn.Softplus()
                    else:
                        layer["activation"] = nn.ReLU()

                self.shifter.add_module("layer{}".format(l), nn.Sequential(layer))
        else:
            self.shifter = None

        if bias:
            bias = Parameter(torch.Tensor(outdims))
            self.register_parameter("bias", bias)
        else:
            self.register_parameter("bias", None)

        self.register_buffer("regvalplaceholder", torch.zeros((1,2)))
        self.init_mu_range = init_mu_range
        self.align_corners = align_corners
        self.initialize()

    @property
    def features(self):
        if self.constrain_positive:
            feat = F.relu(self._features)
        else:
            feat = self._features

        if self._shared_features:
            feat = self.scales * feat[..., self.feature_sharing_index]
        
        return feat

    @property
    def grid(self):
        return self.sample_grid(batch_size=1, sample=False)

    def feature_l1(self, average=True):
        """
        Returns the L1 regularization term: either the mean or the sum of the absolute feature weights
        Args:
            average (bool): if True, use the mean of the weights for regularization
        """
        if self._original_features:
            if average:
                return self._features.abs().mean()
            else:
                return self._features.abs().sum()
        else:
            return 0

    @property
    def mu(self):
        return self._mu

    def sample_grid(self, batch_size, sample=None):
        """
        Returns the grid locations from the core by sampling from a Gaussian distribution
        Args:
            batch_size (int): size of the batch
            sample (bool/None): determines whether to draw a sample from the Gaussian distribution, N(mu, sigma), defined per neuron,
                            or to use the mean, mu, of the Gaussian distribution without sampling.
                           If sample is None (default), samples from N(mu, sigma) during the training phase and
                           fixes to the mean, mu, during the evaluation phase.
                           If sample is True/False, it overrides the model state (i.e. training or eval) and does as instructed
        """
        with torch.no_grad():
            self.mu.clamp_(min=-1, max=1)  # at eval time, only self.mu is used so it must belong to [-1,1]
            if self.gauss_type != 'full':
                self.sigma.clamp_(min=0)  # sigma/variance is always a positive quantity

        grid_shape = (batch_size,) + self.grid_shape[1:]

        sample = self.training if sample is None else sample
        if sample:
            norm = self.mu.new(*grid_shape).normal_()
        else:
            norm = self.mu.new(*grid_shape).zero_()  # for consistency and CUDA capability

        if self.gauss_type != 'full':
            return (norm * self.sigma + self.mu).clamp(-1,1) # grid locations in feature space sampled randomly around the mean self.mu
        else:
            return (torch.einsum('ancd,bnid->bnic', self.sigma, norm) + self.mu).clamp_(-1,1) # grid locations in feature space sampled randomly around the mean self.mu


    def initialize(self):
        """
        Initializes the mean and sigma of the Gaussian readout along with the feature weights
        """

        self._mu.data.uniform_(-self.init_mu_range, self.init_mu_range)

        if self.gauss_type != 'full':
            self.sigma.data.fill_(self.init_sigma)
        else:
            self.sigma.data.uniform_(-self.init_sigma, self.init_sigma)

        self._features.data.fill_(1 / self.in_shape[0])

        if self.bias is not None:
            self.bias.data.fill_(0)

    def initialize_features(self, match_ids=None):
        """
        The internal attribute `_original_features` in this function denotes whether this instance of the FullGaussian2d
        learns the original features (True) or if it uses a copy of the features from another instance of FullGaussian2d
        via the `shared_features` (False). If it uses a copy, the feature_l1 regularizer for this copy will return 0
        """
        c, w, h = self.in_shape
        self._original_features = True
        if match_ids is not None:
            raise ValueError(f'match_ids to combine across session "{match_ids}" is not implemented yet')
        else:
            self._features = Parameter(torch.Tensor(1, c, 1, self.outdims))  # feature weights for each channel of the core
            self._shared_features = False
    

    def forward(self, x, sample=None, shift=None, out_idx=None):
        """
        Propagates the input forwards through the readout
        Args:
            x: input data
            sample (bool/None): determines whether to draw a sample from the Gaussian distribution, N(mu, sigma), defined per neuron,
                            or to use the mean, mu, of the Gaussian distribution without sampling.
                           If sample is None (default), samples from N(mu, sigma) during the training phase and
                           fixes to the mean, mu, during the evaluation phase.
                           If sample is True/False, it overrides the model state (i.e. training or eval) and does as instructed
            shift (Tensor or None): shifts the location of the grid (from eye-tracking data)
            out_idx (ndarray or None): index of neurons to be predicted

        Returns:
            y: neuronal activity
        """
        N, c, w, h = x.size()
        c_in, w_in, h_in = self.in_shape
        if (c_in, w_in, h_in) != (c, w, h):
            raise ValueError("the specified feature map dimension is not the readout's expected input dimension")
        feat = self.features
        feat = feat.reshape(1, c, self.outdims)
        bias = self.bias
        outdims = self.outdims

        if self.batch_sample:
            # sample the grid_locations separately per sample per batch
            grid = self.sample_grid(batch_size=N, sample=sample)  # sample determines sampling from Gaussian
        else:
            # use one sampled grid_locations for all sample in the batch
            grid = self.sample_grid(batch_size=1, sample=sample).expand(N, outdims, 1, 2)

        if out_idx is not None:
            if isinstance(out_idx, np.ndarray):
                if out_idx.dtype == bool:
                    out_idx = np.where(out_idx)[0]
            feat = feat[:, :, out_idx]
            grid = grid[:, out_idx]
            if bias is not None:
                bias = bias[out_idx]
            outdims = len(out_idx)

        if shift is not None:
            # shifter is run outside the readout forward
            grid = grid + shift[:, None, None, :]

        y = F.grid_sample(x, grid, align_corners=self.align_corners)
        y = (y.squeeze(-1) * feat).sum(1).view(N, outdims)

        if self.bias is not None:
            y = y + bias
        return y

    def regularizer(self):
        if self.shifter is None:
            out = 0
        else:
            # encourage the shifter to produce zero shift at the (0, 0) input
            out = self.shifter(self.regvalplaceholder).abs().sum() * 10
        feat = self.features
        out = out + self.hparams.gamma_l2 * feat.pow(2).mean().sqrt() + self.hparams.gamma_l1 * feat.abs().mean()
        return out

    def __repr__(self):
        """
        returns a string with setup of this model
        """
        c, w, h = self.in_shape
        r = self.gauss_type + ' '
        r += self.__class__.__name__ + " (" + "{} x {} x {}".format(c, w, h) + " -> " + str(self.outdims) + ")"
        if self.bias is not None:
            r += " with bias"
        if self.shifter is not None:
            r += " with shifter"

        for ch in self.children():
            r += "  -> " + ch.__repr__() + "\n"
        return r
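A usage sketch, assuming the project's Readout base class is importable (it supplies save_hyperparameters and hparams via pytorch-lightning, which regularizer() relies on); shapes are illustrative:

readout = Point2DGaussian(in_shape=[16, 20, 20], outdims=50,
                          gauss_type='uncorrelated', gamma_l1=0.001, gamma_l2=0.1)
feats = torch.randn(8, 16, 20, 20)   # batch of core feature maps
rates = readout(feats)               # shape (8, 50): one activity per neuron
reg = readout.regularizer()          # shifter penalty + feature-weight penalties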
Code example #6

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn import Parameter

# PoissonFunction, QPFunctionJK, normalize_JK, cancel_rate_belief_cXt and
# CumSumNoGrad are project-specific helpers referenced below but not shown in
# this listing. Variable is a pre-0.4 PyTorch construct; in current PyTorch it
# is a plain alias for Tensor.
class SafetyNet(nn.Module):
    def __init__(self,
                 nKnapsackCategories,
                 nThresholds,
                 starting_thresholds,
                 nineq=1,
                 neq=0,
                 eps=1e-8,
                 cancel_rate_target=.05,
                 cancel_rate_evaluation=.05,
                 accept_rate_target=.75,
                 accept_rate_evaluation=.75,
                 cancel_initializer=.02,
                 inventory_initializer=3,
                 cancel_coef_initializer=-.2,
                 cancel_intercept_initializer=.3,
                 price_initializer=1,
                 parametric_knapsack=False,
                 knapsack_type=None):
        super().__init__()
        self.nKnapsackCategories = nKnapsackCategories
        self.nThresholds = nThresholds
        #self.nBatch = nBatch
        self.nineq = nineq
        self.neq = neq
        self.eps = eps
        self.cancel_rate_evaluation = cancel_rate_evaluation
        self.accept_rate_evaluation = accept_rate_evaluation
        self.benchmark_thresholds = Variable(starting_thresholds)
        #self.accept_rate_original=Parameter(accept_rate*torch.ones(1))
        #self.cancel_rate_original=Parameter(cancel_rate*torch.ones(1))
        #self.cancel_rate= self.cancel_rate_original*1.0
        #self.accept_rate=self.accept_rate_original*1.0
        self.accept_rate_param = Parameter(accept_rate_target * torch.ones(1))
        self.cancel_rate_param = Parameter(cancel_rate_target * torch.ones(1))
        self.inventory_initializer = inventory_initializer
        self.parametric_knapsack = parametric_knapsack
        self.h = Variable(torch.ones(self.nineq))
        ##Add matrix to make all variables >=0
        self.PosValMatrix = -1 * Variable(
            torch.eye(self.nKnapsackCategories * self.nThresholds))
        self.PosValVector = Variable(
            torch.zeros(self.nKnapsackCategories * self.nThresholds))

        #Equality constraints. These will be the constraints to choose one variable per category
        ##These will be Variables as they are not something that is estimated by the model
        A = torch.zeros(self.nKnapsackCategories,
                        self.nKnapsackCategories * self.nThresholds)
        for row in range(self.nKnapsackCategories):
            A[row][self.nThresholds * row:self.nThresholds * (row + 1)] = 1
        self.A = Variable(A)
        self.b = Variable(torch.ones(self.nKnapsackCategories))

        self.Q_zeros = Variable(
            torch.zeros(nKnapsackCategories * nThresholds,
                        nKnapsackCategories * nThresholds))
        #Initialize thresholds
        self.thresholds = Variable(torch.arange(0, self.nThresholds))
        #Initialize cancel and revenue parameters
        if self.parametric_knapsack:
            self.thresholds_raw_matrix = Variable(starting_thresholds)
            #self.cancel_scale = Parameter((torch.rand(self.nKnapsackCategories)+.5)*cancel_initializer)
            #self.cancel_lam = Parameter(torch.ones(self.nKnapsackCategories)*cancel_initializer)
            #self.cancel_spread = Variable(torch.ones(self.nKnapsackCategories))
            #self.revenue_scale = Parameter((torch.rand(self.nKnapsackCategories)+.5))
            #self.revenue_lam = Parameter(torch.ones(self.nKnapsackCategories))
            #self.revenue_spread = Variable(torch.ones(self.nKnapsackCategories))
        else:
            #self.thresholds_raw_matrix = Parameter(torch.ones(self.nKnapsackCategories,self.nThresholds)*(1.0/self.nThresholds))
            self.thresholds_raw_matrix = Parameter(starting_thresholds)
        self.thresholds_raw_matrix_norm = torch.div(
            self.thresholds_raw_matrix,
            torch.sum(self.thresholds_raw_matrix,
                      dim=1).unsqueeze(1).expand_as(
                          self.thresholds_raw_matrix))

        #Inventory distribution parameters
        self.inventory_lam_opt = Parameter(
            torch.ones(self.nKnapsackCategories) * inventory_initializer)
        #Cancel distribution parameters
        self.cancel_coef_opt = Parameter(
            torch.ones(self.nKnapsackCategories) * cancel_coef_initializer)
        self.cancel_intercept_opt = Parameter(
            torch.ones(self.nKnapsackCategories) *
            cancel_intercept_initializer)
        self.prices_opt = Parameter(
            torch.ones(self.nKnapsackCategories) * price_initializer)
        self.demand_distribution_opt = Parameter(
            torch.ones(self.nKnapsackCategories) *
            (1.0 / self.nKnapsackCategories))

        self.inventory_lam_est = Parameter(
            torch.ones(self.nKnapsackCategories) * inventory_initializer)
        #Cancel distribution parameters
        self.cancel_coef_est = Parameter(
            torch.ones(self.nKnapsackCategories) * cancel_coef_initializer)
        self.cancel_intercept_est = Parameter(
            torch.ones(self.nKnapsackCategories) *
            cancel_intercept_initializer)
        self.prices_est = Parameter(
            torch.ones(self.nKnapsackCategories) * price_initializer)
        self.demand_distribution_est = Parameter(
            torch.ones(self.nKnapsackCategories) *
            (1.0 / self.nKnapsackCategories))

    def normalize_thresholds(self):
        self.thresholds_raw_matrix.data.clamp_(min=self.eps,
                                               max=1.0 - self.eps)
        param_sums = torch.ger(
            self.thresholds_raw_matrix.sum(dim=1).squeeze(),
            Variable(torch.ones(self.nThresholds)))
        self.thresholds_raw_matrix.data.div_(param_sums.data)

    def normalize_demand_params(self):
        self.demand_distribution_est.data.clamp_(min=self.eps)
        self.demand_distribution_est.data.div_(
            self.demand_distribution_est.data.sum())
        self.demand_distribution_opt.data.clamp_(min=self.eps)
        self.demand_distribution_opt.data.div_(
            self.demand_distribution_opt.data.sum())

    def forward(self, category, inv_count, price, cancel,
                collection_thresholds):
        #print("collection_thresholds",collection_thresholds)
        self.lp_infeasible = 0
        self.cancel_coef_neg_est = self.cancel_coef_est.clamp(max=0)
        self.cancel_coef_neg_opt = self.cancel_coef_opt.clamp(max=0)
        self.nBatch = category.size(0)
        #x = x.view(nBatch, -1)

        #We want to compute everything we can without thresholds first. This will allow us to use our learned parameters to feed the LP
        self.inventory_distribution_raw_est = PoissonFunction(
            self.nKnapsackCategories, self.nThresholds, verbose=-1)(
                self.inventory_lam_est, self.thresholds) + self.eps
        #self.inventory_distribution_norm_est = normalize_JK(self.inventory_distribution_raw_est,dim=1)
        self.inventory_distribution_batch_by_threshold_est = torch.mm(
            category, self.inventory_distribution_raw_est) + self.eps

        self.inventory_distribution_raw_opt = PoissonFunction(
            self.nKnapsackCategories, self.nThresholds, verbose=-1)(
                self.inventory_lam_opt, self.thresholds) + self.eps
        #self.inventory_distribution_norm_opt = normalize_JK(self.inventory_distribution_raw_opt,dim=1)
        self.inventory_distribution_batch_by_threshold_opt = torch.mm(
            category, self.inventory_distribution_raw_opt) + self.eps

        ##Here we'll calculate cancel probability by inventory
        self.belief_cancel_rate_cXt_est = cancel_rate_belief_cXt(
            self.cancel_coef_neg_est, self.cancel_intercept_est,
            self.thresholds.unsqueeze(0).expand(self.nKnapsackCategories,
                                                self.nThresholds))
        belief_fill_rate_cXt_est = 1 - self.belief_cancel_rate_cXt_est
        price_cXt_est = self.prices_est.unsqueeze(1).expand(
            self.nKnapsackCategories, self.nThresholds)

        ##Here we'll calculate cancel probability by inventory
        self.belief_cancel_rate_cXt_opt = cancel_rate_belief_cXt(
            self.cancel_coef_neg_opt, self.cancel_intercept_opt,
            self.thresholds.unsqueeze(0).expand(self.nKnapsackCategories,
                                                self.nThresholds))
        belief_fill_rate_cXt_opt = 1 - self.belief_cancel_rate_cXt_opt
        price_cXt_opt = self.prices_opt.unsqueeze(1).expand(
            self.nKnapsackCategories, self.nThresholds)

        self.belief_total_demand_cXt_est = self.inventory_distribution_raw_est * (
            self.demand_distribution_est.unsqueeze(1).expand(
                self.nKnapsackCategories, self.nThresholds))
        belief_total_demand_c_vector_est = torch.sum(
            self.belief_total_demand_cXt_est, dim=1)

        self.belief_total_demand_cXt_opt = self.inventory_distribution_raw_opt * (
            self.demand_distribution_opt.unsqueeze(1).expand(
                self.nKnapsackCategories, self.nThresholds))
        belief_total_demand_c_vector_opt = torch.sum(
            self.belief_total_demand_cXt_opt, dim=1)

        if self.parametric_knapsack:

            self.belief_total_demand_opt = torch.sum(
                self.belief_total_demand_cXt_opt)
            self.belief_total_cancels_cXt_opt = self.belief_cancel_rate_cXt_opt * self.belief_total_demand_cXt_opt
            self.belief_total_fills_cXt_opt = belief_fill_rate_cXt_opt * self.belief_total_demand_cXt_opt
            self.knapsack_cancels_matrix = torch.div(
                torch.sum(self.belief_total_cancels_cXt_opt, dim=1).expand_as(
                    self.belief_total_cancels_cXt_opt) -
                torch.cumsum(self.belief_total_cancels_cXt_opt, dim=1) +
                self.belief_total_cancels_cXt_opt,
                self.belief_total_demand_opt.expand(self.nKnapsackCategories,
                                                    self.nThresholds))
            self.knapsack_fills_matrix = torch.div(
                torch.sum(self.belief_total_fills_cXt_opt, dim=1).expand_as(
                    self.belief_total_fills_cXt_opt) -
                torch.cumsum(self.belief_total_fills_cXt_opt, dim=1) +
                self.belief_total_fills_cXt_opt,
                self.belief_total_demand_opt.expand(self.nKnapsackCategories,
                                                    self.nThresholds))
            self.knapsack_revenues_matrix = self.knapsack_fills_matrix * price_cXt_opt
            self.knapsack_cancels = self.knapsack_cancels_matrix.view(1, -1)
            self.knapsack_fills = self.knapsack_fills_matrix.view(1, -1)
            self.knapsack_revenues = self.knapsack_revenues_matrix.view(-1)
            Q = self.Q_zeros + self.eps * Variable(
                torch.eye(self.nKnapsackCategories * self.nThresholds))
            self.inequalityMatrix = torch.cat(
                (self.knapsack_cancels, -1 * self.knapsack_fills,
                 self.PosValMatrix))
            self.knapsack_cancels_RHS = torch.sum(
                self.knapsack_cancels_matrix * self.benchmark_thresholds)
            self.knapsack_fills_RHS = torch.sum(self.knapsack_fills_matrix *
                                                self.benchmark_thresholds)
            #self.inequalityVector = torch.cat((self.cancel_rate_param*self.h,-1*self.accept_rate_param*self.h,self.PosValVector))
            self.inequalityVector = torch.cat(
                (self.knapsack_cancels_RHS * self.h,
                 -1 * self.knapsack_fills_RHS * self.h, self.PosValVector))
            try:
                thresholds_raw = QPFunctionJK(verbose=1)(
                    Q, -1 * self.knapsack_revenues, self.inequalityMatrix,
                    self.inequalityVector, self.A, self.b)
                self.thresholds_raw_matrix = thresholds_raw.view(
                    self.nKnapsackCategories, -1)
                #self.accept_rate=1.0*self.accept_rate_original
                #self.cancel_rate=1.0*self.cancel_rate_original
            except AssertionError:
                print("Error solving LP, likely infeasible")
                self.lp_infeasible = 1
                #print("New Accept and Cancel Rates:",self.accept_rate,self.cancel_rate)
            self.thresholds_raw_matrix = F.relu(
                self.thresholds_raw_matrix) + self.eps
        self.thresholds_raw_matrix_norm = normalize_JK(
            self.thresholds_raw_matrix, dim=1)
        #This cXt matrix shows the probability of accepting an order under the learned thresholds, obtained either through direct optimization or through solving an LP
        accept_probability_cXt = torch.cumsum(
            self.thresholds_raw_matrix_norm, dim=1
        )  #this gives the accept probability by cXt under parameterized thresholds

        #category is BxC matrix, so summing across dim 0 gets the number of accepted orders per category
        accept_probability_collection_bXt = torch.cumsum(collection_thresholds,
                                                         dim=1)
        reject_probability_collection_bXt = 1 - accept_probability_collection_bXt
        accept_percent_collection_bXt = accept_probability_collection_bXt * self.inventory_distribution_batch_by_threshold_est
        accept_percent_collection_b_vector = torch.sum(
            accept_percent_collection_bXt, dim=1
        ).squeeze(
        )  #This is the believed acceptance rate of general orders of the categories corresponding with the batch under the collection thresholds
        reject_percent_collection_b_vector = 1 - accept_percent_collection_b_vector
        self.batch_total_demand_b_vector = (
            1 / accept_percent_collection_b_vector)  #.clamp(min=0,max=100)

        #new to v37
        reject_percent_collection_expanded_bXt = reject_percent_collection_b_vector.unsqueeze(
            1).expand(self.nBatch, self.nThresholds)
        self.truncated_orders_distribution_bXt = torch.div(
            reject_probability_collection_bXt *
            self.inventory_distribution_batch_by_threshold_est,
            reject_percent_collection_expanded_bXt + self.eps)
        truncated_demand_b_vector = self.batch_total_demand_b_vector - 1  #self.belief_total_demand_cXt
        truncated_demand_bXt = truncated_demand_b_vector.unsqueeze(1).expand(
            self.nBatch,
            self.nThresholds) * self.truncated_orders_distribution_bXt
        batch_total_demand_bXt = truncated_demand_bXt + inv_count
        self.batch_total_demand_cXt = torch.mm(category.t(),
                                               batch_total_demand_bXt)
        batch_total_demand_c_vector = torch.sum(self.batch_total_demand_cXt,
                                                dim=1)
        batch_zero_demand_c_vector = batch_total_demand_c_vector.lt(0)  # i.e. not (demand >= 0)
        #batch_supplement_demand = torch.masked_select(belief_total_demand_c_vector_est,batch_zero_demand_c_vector)
        self.estimated_batch_total_demand = torch.sum(
            self.batch_total_demand_b_vector
        )  #+torch.sum(batch_supplement_demand)

        #Now we want to see how accurate our inventory distributions are for the batch
        accept_probability_batch_by_threshold = CumSumNoGrad(
            verbose=-1)(collection_thresholds) + self.eps
        self.inventory_distribution_batch_by_thresholds = torch.mm(
            category, self.inventory_distribution_raw_est)
        arrival_probability_batch_by_threshold_unnormed = self.inventory_distribution_batch_by_thresholds * accept_probability_batch_by_threshold
        arrival_probability_batch_by_threshold = torch.div(
            arrival_probability_batch_by_threshold_unnormed,
            torch.sum(arrival_probability_batch_by_threshold_unnormed,
                      dim=1).unsqueeze(1).expand_as(
                          arrival_probability_batch_by_threshold_unnormed))
        log_arrival_prob = torch.log(arrival_probability_batch_by_threshold +
                                     self.eps)

        #Like we do for inventory, we want to measure the accuracy of our cancel params for the batch
        self.belief_cancel_rate_bXt = torch.mm(category,
                                               self.belief_cancel_rate_cXt_est)
        belief_fill_rate_bXt = 1 - self.belief_cancel_rate_bXt
        self.belief_cancel_rate_b_vector = torch.sum(
            self.belief_cancel_rate_bXt * inv_count, dim=1).squeeze()
        belief_fill_rate_b_vector = 1 - self.belief_cancel_rate_b_vector
        log_cancel_prob = torch.log(
            torch.cat((belief_fill_rate_b_vector.unsqueeze(1),
                       self.belief_cancel_rate_b_vector.unsqueeze(1)), 1) +
            self.eps)

        self.belief_category_dist_bXc = self.demand_distribution_est.unsqueeze(
            0).expand(self.nBatch, self.nKnapsackCategories)
        log_category_prob = torch.log(self.belief_category_dist_bXc + self.eps)

        ##This is new in v37. We want to combine the actual results observed in the batch but add in estimated effects of truncation

        accept_probability_using_threshold_params_bXt = torch.mm(
            category, accept_probability_cXt)
        truncated_accept_estimate = truncated_demand_bXt * accept_probability_using_threshold_params_bXt  #This is the number of truncated orders we expect to accept (using param thresholds) at each inventory level corresponding to each order in the batch
        truncated_cancel_estimate = truncated_accept_estimate * self.belief_cancel_rate_bXt
        truncated_fill_estimate = truncated_accept_estimate * belief_fill_rate_bXt
        truncated_revenue_estimate = truncated_fill_estimate * (
            price.unsqueeze(1).expand(self.nBatch, self.nThresholds))
        truncated_revenue_estimate_sum = torch.sum(truncated_revenue_estimate)
        self.truncated_cancel_estimate_sum = torch.sum(
            truncated_cancel_estimate)
        truncated_fill_estimate_sum = torch.sum(truncated_fill_estimate)
        self.truncated_accept_estimate_sum = torch.sum(
            truncated_accept_estimate)

        ##This is new in v37. We want to combine the actual results observed in the batch but add in estimated effects of truncation
        fill = 1 - cancel
        batch_cancel_bXt = cancel.unsqueeze(1).expand(
            self.nBatch, self.nThresholds
        ) * inv_count * accept_probability_using_threshold_params_bXt
        batch_fill_bXt = fill.unsqueeze(1).expand(
            self.nBatch, self.nThresholds
        ) * inv_count * accept_probability_using_threshold_params_bXt
        batch_cancel_b_vector = torch.sum(batch_cancel_bXt, dim=1).squeeze()
        batch_fill_b_vector = torch.sum(batch_fill_bXt, dim=1).squeeze()
        batch_accept_b_vector = torch.sum(
            inv_count * accept_probability_using_threshold_params_bXt,
            dim=1).squeeze()
        #print("sanity check",batch_accept_b_vector, batch_fill_b_vector+batch_cancel_b_vector)
        #print("sanity check 2", torch.sum(batch_accept_b_vector), torch.sum(batch_fill_b_vector+batch_cancel_b_vector))

        batch_revenue_b_vector = price * batch_fill_b_vector
        self.batch_fill_sum = torch.sum(batch_fill_b_vector, dim=0)
        self.batch_revenue_sum = torch.sum(batch_revenue_b_vector, dim=0)
        self.batch_cancel_sum = torch.sum(batch_cancel_b_vector, dim=0)
        self.batch_accept_sum = torch.sum(batch_accept_b_vector, dim=0)

        new_objective_loss = -(1.0 / 50000) * (truncated_revenue_estimate_sum +
                                               self.batch_revenue_sum)
        new_cancel_constraint_loss = self.truncated_cancel_estimate_sum + self.batch_cancel_sum - (
            self.truncated_accept_estimate_sum +
            self.batch_accept_sum) * self.cancel_rate_evaluation
        new_accept_constraint_loss = (1.0 / 7.0) * (
            (self.truncated_accept_estimate_sum + self.batch_accept_sum) *
            self.accept_rate_evaluation - truncated_fill_estimate_sum -
            self.batch_fill_sum)
        #new_cancel_constraint_loss = truncated_cancel_estimate_sum+self.batch_cancel_sum-self.estimated_batch_total_demand*self.cancel_rate_param
        #new_accept_constraint_loss = (1.0/7.0)*(self.estimated_batch_total_demand*self.accept_rate_param-truncated_fill_estimate_sum-self.batch_fill_sum)

        observed_cancel_constraint_loss = self.batch_cancel_sum - (
            self.batch_accept_sum) * self.cancel_rate_evaluation
        observed_accept_constraint_loss = (
            1.0 / 7.0) * (self.batch_accept_sum * self.accept_rate_evaluation -
                          self.batch_fill_sum)

        return new_objective_loss, new_cancel_constraint_loss, new_accept_constraint_loss, arrival_probability_batch_by_threshold, log_arrival_prob, log_cancel_prob, log_category_prob, self.estimated_batch_total_demand, observed_cancel_constraint_loss, observed_accept_constraint_loss, self.lp_infeasible
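A construction-only smoke test: the forward pass additionally requires the project helpers listed at the top of this example, but the module can be built and its threshold normalisation exercised with plain PyTorch (sizes here are illustrative):

nC, nT = 4, 10
net = SafetyNet(nKnapsackCategories=nC, nThresholds=nT,
                starting_thresholds=torch.ones(nC, nT) / nT)
net.normalize_thresholds()
print(net.thresholds_raw_matrix.sum(dim=1))   # each row sums to ~1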