Code Example #1
    def do_lipschitz_projection(self):
        """
        Perform the Lipschitz projection step by solving the QP
        """

        with torch.no_grad():
            if self.QP == "qpth":
                # qpth library
                proj_coefficients = QPFunction(verbose=False)(
                    nn.Parameter(self.Q), -2.0 * self.coefficients,
                    nn.Parameter(self.G), nn.Parameter(self.h),
                    nn.Parameter(self.e), nn.Parameter(self.e))
                self.coefficients_vect_.data = proj_coefficients.view(-1)

            elif self.QP == "cvxpy":
                # cvxpylayers library
                """
                # row_wise verification 
                proj_coefficients = torch.empty(self.coefficients.data.shape)
                for i in range(self.coefficients.data.shape[0]):
                    proj_coefficient, = self.qp(-2.0*self.coefficients.data[i, :])
                    proj_coefficients[i, :] = proj_coefficient
                self.coefficients_vect_.data = proj_coefficients.view(-1)
                """
                proj_coefficients, = self.qp(-2.0 * self.coefficients.data)
                self.coefficients_vect_.data = proj_coefficients.view(-1)
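For reference, qpth's QPFunction used above follows the signature QPFunction(verbose=...)(Q, p, G, h, A, b), minimizing 0.5*z^T Q z + p^T z subject to G z <= h and A z = b; empty tensors can be passed for A and b when there are no equality constraints. A minimal standalone sketch on toy data (hypothetical values, assuming qpth is installed):

import torch
from qpth.qp import QPFunction

# Toy problem: minimize 0.5*||z||^2 + p^T z subject to z >= 0 (no equality constraints).
Q = torch.eye(2).unsqueeze(0)          # (1, 2, 2) quadratic term
p = torch.tensor([[1.0, -1.0]])        # (1, 2) linear term
G = -torch.eye(2).unsqueeze(0)         # -z <= 0  <=>  z >= 0
h = torch.zeros(1, 2)
e = torch.Tensor()                     # empty tensor = no equality constraints
z = QPFunction(verbose=False)(Q, p, G, h, e, e)
print(z)                               # expected roughly [[0., 1.]]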
Code Example #2
File: test1_1mod.py  Project: lopa23/flim_optcrf
    def forward(self, x, Q, p, G, h, m):
        print("Cuda current device", torch.cuda.current_device())
        nBatch = x.size(0)

        if (m > 1):
            p = p.float().t()
        else:
            p = p.float()

        G = G.float()  #.cuda()
        Q = Q.float()
        if (m >= 2):
            Q = Q.unsqueeze(0)
        h = h.float()
        #print(Q.size(),p.size(),G.size(),h.size())

        e = Variable(torch.Tensor(), requires_grad=True)

        x = QPFunction(verbose=True)(Q, p, G, h, e, e)  #.cuda()

        x = x.view(10, -1)  ##this was not needed earlier

        return F.log_softmax(x, dim=1)
Code Example #3
    def forward(  # type: ignore
            self,
            tokens: TextFieldTensors,
            verb_indicator: torch.Tensor,
            sentence_end: torch.LongTensor,
            metadata: List[Any],
            tags: torch.LongTensor = None,
            offsets: torch.LongTensor = None):
        """
        # Parameters

        tokens : `TextFieldTensors`, required
            The output of `TextField.as_array()`, which should typically be passed directly to a
            `TextFieldEmbedder`. For this model, this must be a `SingleIdTokenIndexer` which
            indexes wordpieces from the BERT vocabulary.
        verb_indicator : `torch.LongTensor`, required
            An integer `SequenceFeatureField` representation of the position of the verb
            in the sentence. This should have shape (batch_size, num_tokens) and importantly, can be
            all zeros, in the case that the sentence has no verbal predicate.
        tags : `torch.LongTensor`, optional (default = `None`)
            A torch tensor representing the sequence of integer gold class labels
            of shape `(batch_size, num_tokens)`
        metadata : `List[Dict[str, Any]]`, required
            metadata containing the original words in the sentence, the verb to compute the
            frame for, and start offsets for converting wordpieces back to a sequence of words,
            under 'words', 'verb' and 'offsets' keys, respectively.

        # Returns

        An output dictionary consisting of:
        logits : `torch.FloatTensor`
            A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
            unnormalised log probabilities of the tag classes.
        class_probabilities : `torch.FloatTensor`
            A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
            a distribution of the tag classes per word.
        loss : `torch.FloatTensor`, optional
            A scalar loss to be optimised.
        """

        if isinstance(self.bert_model,
                      PretrainedTransformerMismatchedEmbedder):
            encoder_inputs = tokens["tokens"]
            if self.bert_config.type_vocab_size > 1:
                encoder_inputs["type_ids"] = verb_indicator
            encoded_text = self.bert_model(**encoder_inputs)
            batch_size = encoded_text.shape[0]
            if self.bert_config.type_vocab_size == 1:
                verb_embeddings = encoded_text[
                    torch.arange(batch_size).to(encoded_text.device),
                    verb_indicator.argmax(1), :]
                verb_embeddings = torch.where(
                    (verb_indicator.sum(1, keepdim=True) > 0).repeat(
                        1, verb_embeddings.shape[-1]), verb_embeddings,
                    torch.zeros_like(verb_embeddings))
                encoded_text = torch.cat(
                    (encoded_text, verb_embeddings.unsqueeze(1).repeat(
                        1, encoded_text.shape[1], 1)),
                    dim=2)
            mask = tokens["tokens"]["mask"]
            index = mask.sum(1).argmax().item()
            # print(mask.shape, encoded_text.shape, tokens["tokens"]["token_ids"].shape, tags.shape, max([len(x['words']) for x in metadata]), mask.sum(1)[index].item())
            # print(tokens["tokens"]["token_ids"][index,:])
        else:
            mask = get_text_field_mask(tokens)
            bert_embeddings, _ = self.bert_model(
                input_ids=util.get_token_ids_from_text_field_tensors(tokens),
                # token_type_ids=verb_indicator,
                attention_mask=mask,
            )

            batch_size, _ = mask.size()
            embedded_text_input = self.embedding_dropout(bert_embeddings)
            # Restrict to sentence part
            sentence_mask = (torch.arange(mask.shape[1]).unsqueeze(0).repeat(
                batch_size, 1).to(mask.device) <
                             sentence_end.unsqueeze(1).repeat(
                                 1, mask.shape[1])).long()
            cutoff = sentence_end.max().item()
            if self._encoder is None:
                encoded_text = embedded_text_input
                mask = sentence_mask[:, :cutoff].contiguous()
                encoded_text = encoded_text[:, :cutoff, :]
                tags = tags[:, :cutoff].contiguous()
            else:
                predicate_embeddings = self.predicate_embedding(verb_indicator)
                encoder_inputs = torch.cat(
                    (embedded_text_input, predicate_embeddings), dim=-1)
                encoded_text = self._encoder(encoder_inputs,
                                             mask=sentence_mask.bool())
                # print(verb_indicator)
                predicate_index = (verb_indicator * torch.arange(
                    start=verb_indicator.shape[-1] - 1, end=-1,
                    step=-1).to(mask.device).unsqueeze(0).repeat(
                        batch_size, 1)).argmax(1)
                # print(predicate_index)
                predicate_hidden = encoded_text[
                    torch.arange(batch_size).to(mask.device), predicate_index]
                predicate_exists, _ = verb_indicator.max(1)
                encoded_text = encoded_text[:, :cutoff, :]
                tags = tags[:, :cutoff].contiguous()
                mask = sentence_mask[:, :cutoff].contiguous()
                predicate_exists = predicate_exists.unsqueeze(1).repeat(
                    1, encoded_text.shape[-1])
                predicate_hidden = torch.where(
                    predicate_exists > 0, predicate_hidden,
                    torch.zeros_like(predicate_hidden))
                encoded_text = torch.cat(
                    (encoded_text, predicate_hidden.unsqueeze(1).repeat(
                        1, encoded_text.shape[1], 1)),
                    dim=-1)

        sequence_length = encoded_text.shape[1]
        logits = self.tag_projection_layer(encoded_text)
        # print(mask, logits)
        if self._lp and sequence_length <= 100:
            eps = 1e-4
            Q = eps * torch.eye(
                sequence_length * self.num_classes,
                sequence_length * self.num_classes).unsqueeze(0).repeat(
                    batch_size, 1, 1).to(logits.device).float()
            p = logits.view(batch_size, -1)
            G = -1 * torch.eye(
                sequence_length * self.num_classes).unsqueeze(0).repeat(
                    batch_size, 1, 1).to(logits.device).float()
            h = torch.zeros_like(p)
            A = torch.arange(sequence_length *
                             self.num_classes).unsqueeze(0).repeat(
                                 sequence_length, 1)
            A2 = torch.arange(sequence_length).unsqueeze(1).repeat(
                1, sequence_length * self.num_classes) * self.num_classes
            A = torch.where((A >= A2) & (A < A2 + self.num_classes),
                            torch.ones_like(A), torch.zeros_like(A))
            A = A.unsqueeze(0).repeat(batch_size, 1,
                                      1).to(logits.device).float()
            b = torch.ones_like(A[:, :, 0])
            probs = QPFunction()(Q, p, torch.autograd.Variable(torch.Tensor()),
                                 torch.autograd.Variable(torch.Tensor()), A, b)
            probs = probs.view(batch_size, sequence_length, self.num_classes)
            """logits_shape = logits.shape
            logits = torch.where(mask.bool().unsqueeze(-1).repeat(1, 1, logits.shape[-1]), logits, logits-10000)
            max_sequence_length = min([l for l in self.lengths if l >= sequence_length])
            if max_sequence_length > logits_shape[1]:
                logits = torch.cat((logits, torch.zeros((batch_size, max_sequence_length-logits_shape[1], logits_shape[2])).to(logits.device)), dim=1)
            lp_layer = self._layer_list[self.length_map[max_sequence_length]]
            probs, = lp_layer(logits)
            print(torch.isnan(probs).any())
            if max_sequence_length > logits_shape[1]:
                probs = probs[:,:logits_shape[1],:]"""
            logits = (torch.nn.functional.relu(probs) + 1e-4).log()
        if self._lpsmap:
            if self._lpsmap_core_only:
                all_logits = logits
            else:
                all_logits = torch.cat((logits, 0.5 * torch.ones(
                    (batch_size, 1, logits.shape[-1])).to(logits.device)),
                                       dim=1)
            probs = []
            for i in range(batch_size):
                if self.constrain_crf_decoding:
                    unaries = logits[i, :, :].view(-1).cpu()
                    additionals = self.crf.transitions.view(-1).repeat(
                        sequence_length) + 10000 * (
                            self.crf._constraint_mask[:-2, :-2] -
                            1).view(-1).repeat(sequence_length)
                    start_transitions = self.crf.start_transitions + 10000 * (
                        self.crf._constraint_mask[-2, :-2] - 1)
                    end_transitions = self.crf.end_transitions + 10000 * (
                        self.crf._constraint_mask[-1, :-2] - 1)
                    additionals = torch.cat(
                        (additionals, start_transitions, end_transitions),
                        dim=0).cpu()
                    fg = TorchFactorGraph()
                    x = fg.variable_from(unaries)
                    f = PFactorSequence()

                    f.initialize(
                        [self.num_classes for _ in range(sequence_length)])
                    factor = TorchOtherFactor(f, x, additionals)
                    fg.add(factor)
                    # add budget constraint for each state
                    for state in self._core_roles:
                        vars_state = x[state::self.num_classes]
                        fg.add(AtMostOne(vars_state))
                    # solve SparseMAP
                    fg.solve(max_iter=200)
                    probs.append(
                        unaries.to(logits.device).view(sequence_length,
                                                       self.num_classes))
                else:
                    fg = TorchFactorGraph()
                    x = fg.variable_from(all_logits[i, :, :].cpu())
                    for j in range(sequence_length):
                        fg.add(Xor(x[j, :]))
                    for j in self._core_roles:
                        fg.add(AtMostOne(x[:sequence_length, j]))
                    if not self._lpsmap_core_only:
                        full_sequence = list(range(sequence_length))
                        base_roles = set([
                            second
                            for (_, second) in self._r_roles + self._c_roles
                        ])
                        """for (r_role, base_role) in self._r_roles+self._c_roles:
                            for j in range(sequence_length):
                                fg.add(Imply(x[full_sequence+[j],[base_role]*sequence_length+[r_role]], negated=[True]*(sequence_length+1)))"""
                        for base_role in base_roles:
                            fg.add(OrOut(x[:, base_role]))
                        for (r_role,
                             base_role) in self._r_roles + self._c_roles:
                            fg.add(OrOut(x[:, r_role]))
                            fg.add(
                                Or(x[[sequence_length, sequence_length],
                                     [r_role, base_role]],
                                   negated=[True, False]))
                    max_iter = 100
                    if not self._lpsmap_core_only:
                        max_iter = min(max_iter, 400)
                    elif (not self.training) and not self._val_inference:
                        max_iter = min(max_iter, 200)
                    fg.solve(max_iter=max_iter)
                    probs.append(x.value[:sequence_length, :].contiguous().to(
                        logits.device))
            class_probabilities = torch.stack(probs)
            # class_probabilities = self.lpsmap(logits)
            max_seq_length = 200
            # if self.lpsmap is None:
            """with torch.no_grad():
                # self.lpsmap = LpSparseMap(num_rows=sequence_length, num_cols=self.num_classes, batch_size=batch_size, device=logits.device, constraints=[('xor', ('row', list(range(sequence_length)))), ('budget', ('col', self._core_roles))])
                max_iter = 1000
                constraint_types = ["xor", "budget"]
                constraint_dims = ["row", "col"]
                constraint_sets = [list(range(sequence_length)), self._core_roles]
                class_probabilities = lpsmap(logits, constraint_types, constraint_dims, constraint_sets, max_iter)
                # if max_seq_length > sequence_length:
                #     logits = torch.cat((logits, -9999.*torch.ones((batch_size, max_seq_length-sequence_length, self.num_classes)).to(logits.device)), dim=1)
                # class_probabilities = self.lpsmap.solve(logits, max_iter=max_iter)"""
            # logits = (class_probabilities+1e-4).log()
        else:
            reshaped_log_probs = logits.view(-1, self.num_classes)
            class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view(
                [batch_size, sequence_length, self.num_classes])
        output_dict = {
            "logits": logits,
            "class_probabilities": class_probabilities
        }
        # We need to retain the mask in the output dictionary
        # so that we can crop the sequences to remove padding
        # when we do viterbi inference in self.make_output_human_readable.
        output_dict["mask"] = mask
        # We add in the offsets here so we can compute the un-wordpieced tags.
        words, verbs, offsets = zip(*[(x["words"], x["verb"], x["offsets"])
                                      for x in metadata])
        output_dict["words"] = list(words)
        output_dict["verb"] = list(verbs)
        output_dict["wordpiece_offsets"] = list(offsets)

        if tags is not None:
            # print(mask.shape, tags.shape, logits.shape, tags.max(), tags.min())
            if self._lpsmap:
                loss = LpsmapLoss.apply(logits, class_probabilities, tags,
                                        mask)
                # tags_1hot = torch.zeros_like(class_probabilities).scatter_(2, tags.unsqueeze(-1), torch.ones_like(class_probabilities))
                # loss = -(tags_1hot*class_probabilities*mask.unsqueeze(-1).repeat(1, 1, class_probabilities.shape[-1])).sum()
            elif self.constrain_crf_decoding:
                loss = -self.crf(logits, tags, mask)
            else:
                loss = sequence_cross_entropy_with_logits(
                    logits, tags, mask, label_smoothing=self._label_smoothing)
            if not self.ignore_span_metric and self.span_metric is not None and not self.training:
                batch_verb_indices = [
                    example_metadata["verb_index"]
                    for example_metadata in metadata
                ]
                batch_sentences = [
                    example_metadata["words"] for example_metadata in metadata
                ]
                # Get the BIO tags from make_output_human_readable()
                # TODO (nfliu): This is kind of a hack, consider splitting out part
                # of make_output_human_readable() to a separate function.
                batch_bio_predicted_tags = self.make_output_human_readable(
                    output_dict).pop("tags")
                from allennlp_models.structured_prediction.models.srl import (
                    convert_bio_tags_to_conll_format, )

                if self.constrain_crf_decoding and not self._lpsmap:
                    batch_conll_predicted_tags = [
                        convert_bio_tags_to_conll_format([
                            self.vocab.get_token_from_index(
                                tag, namespace=self._label_namespace)
                            for tag in seq
                        ]) for (seq, _) in self.crf.viterbi_tags(logits, mask)
                    ]
                else:
                    batch_conll_predicted_tags = [
                        convert_bio_tags_to_conll_format(tags)
                        for tags in batch_bio_predicted_tags
                    ]
                batch_bio_gold_tags = [
                    example_metadata["gold_tags"]
                    for example_metadata in metadata
                ]
                # print(batch_bio_gold_tags)
                batch_conll_gold_tags = [
                    convert_bio_tags_to_conll_format(tags)
                    for tags in batch_bio_gold_tags
                ]
                self.span_metric(
                    batch_verb_indices,
                    batch_sentences,
                    batch_conll_predicted_tags,
                    batch_conll_gold_tags,
                )
            output_dict["loss"] = loss
            output_dict["gold_tags"] = [x["gold_tags"] for x in metadata]
        return output_dict
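The equality constraints assembled for the QP branch above encode one probability simplex per token: row t of the A matrix is 1 exactly on the num_classes columns belonging to token t, so A @ probs = b (a vector of ones) forces each token's class probabilities to sum to one. A small standalone sketch of that construction, with hypothetical toy sizes:

import torch

sequence_length, num_classes = 3, 2    # hypothetical toy sizes
A = torch.arange(sequence_length * num_classes).unsqueeze(0).repeat(sequence_length, 1)
A2 = torch.arange(sequence_length).unsqueeze(1).repeat(
    1, sequence_length * num_classes) * num_classes
# Row t is 1 exactly on the num_classes columns that belong to token t.
A = torch.where((A >= A2) & (A < A2 + num_classes),
                torch.ones_like(A), torch.zeros_like(A))
print(A)
# tensor([[1, 1, 0, 0, 0, 0],
#         [0, 0, 1, 1, 0, 0],
#         [0, 0, 0, 0, 1, 1]])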
Code Example #4
    def forward(self, category, inv_count, price, cancel,
                collection_thresholds):
        #print("collection_thresholds",collection_thresholds)
        self.lp_infeasible = 0
        self.cancel_coef_neg_est = self.cancel_coef_est.clamp(max=0)
        self.cancel_coef_neg_opt = self.cancel_coef_opt.clamp(max=0)
        self.nBatch = category.size(0)
        #x = x.view(nBatch, -1)

        #We want to compute everything we can without thresholds first. This will allow us to use our learned parameters to feed the LP
        self.inventory_distribution_raw_est = PoissonFunction(
            self.nKnapsackCategories, self.nThresholds, verbose=-1)(
                self.inventory_lam_est, self.thresholds) + self.eps
        #self.inventory_distribution_norm_est = normalize_JK(self.inventory_distribution_raw_est,dim=1)
        self.inventory_distribution_batch_by_threshold_est = torch.mm(
            category, self.inventory_distribution_raw_est) + self.eps

        self.inventory_distribution_raw_opt = PoissonFunction(
            self.nKnapsackCategories, self.nThresholds, verbose=-1)(
                self.inventory_lam_opt, self.thresholds) + self.eps
        #self.inventory_distribution_norm_opt = normalize_JK(self.inventory_distribution_raw_opt,dim=1)
        self.inventory_distribution_batch_by_threshold_opt = torch.mm(
            category, self.inventory_distribution_raw_opt) + self.eps

        ##Here we'll calculate cancel probability by inventory
        self.belief_cancel_rate_cXt_est = cancel_rate_belief_cXt(
            self.cancel_coef_neg_est, self.cancel_intercept_est,
            self.thresholds.unsqueeze(0).expand(self.nKnapsackCategories,
                                                self.nThresholds))
        belief_fill_rate_cXt_est = 1 - self.belief_cancel_rate_cXt_est
        price_cXt_est = self.prices_est.unsqueeze(1).expand(
            self.nKnapsackCategories, self.nThresholds)

        ##Here we'll calculate cancel probability by inventory
        self.belief_cancel_rate_cXt_opt = cancel_rate_belief_cXt(
            self.cancel_coef_neg_opt, self.cancel_intercept_opt,
            self.thresholds.unsqueeze(0).expand(self.nKnapsackCategories,
                                                self.nThresholds))
        belief_fill_rate_cXt_opt = 1 - self.belief_cancel_rate_cXt_opt
        price_cXt_opt = self.prices_opt.unsqueeze(1).expand(
            self.nKnapsackCategories, self.nThresholds)

        self.belief_total_demand_cXt_est = self.inventory_distribution_raw_est * (
            self.demand_distribution_est.unsqueeze(1).expand(
                self.nKnapsackCategories, self.nThresholds))
        belief_total_demand_c_vector_est = torch.sum(
            self.belief_total_demand_cXt_est, dim=1)

        self.belief_total_demand_cXt_opt = self.inventory_distribution_raw_opt * (
            self.demand_distribution_opt.unsqueeze(1).expand(
                self.nKnapsackCategories, self.nThresholds))
        belief_total_demand_c_vector_opt = torch.sum(
            self.belief_total_demand_cXt_opt, dim=1)

        if self.parametric_knapsack:

            self.belief_total_demand_opt = torch.sum(
                self.belief_total_demand_cXt_opt)
            self.belief_total_cancels_cXt_opt = self.belief_cancel_rate_cXt_opt * self.belief_total_demand_cXt_opt
            self.belief_total_fills_cXt_opt = belief_fill_rate_cXt_opt * self.belief_total_demand_cXt_opt
            self.knapsack_cancels_matrix = torch.div(
                torch.sum(self.belief_total_cancels_cXt_opt, dim=1).expand_as(
                    self.belief_total_cancels_cXt_opt) -
                torch.cumsum(self.belief_total_cancels_cXt_opt, dim=1) +
                self.belief_total_cancels_cXt_opt,
                self.belief_total_demand_opt.expand(self.nKnapsackCategories,
                                                    self.nThresholds))
            self.knapsack_fills_matrix = torch.div(
                torch.sum(self.belief_total_fills_cXt_opt, dim=1).expand_as(
                    self.belief_total_fills_cXt_opt) -
                torch.cumsum(self.belief_total_fills_cXt_opt, dim=1) +
                self.belief_total_fills_cXt_opt,
                self.belief_total_demand_opt.expand(self.nKnapsackCategories,
                                                    self.nThresholds))
            self.knapsack_revenues_matrix = self.knapsack_fills_matrix * price_cXt_opt
            self.knapsack_cancels = self.knapsack_cancels_matrix.view(1, -1)
            self.knapsack_fills = self.knapsack_fills_matrix.view(1, -1)
            self.knapsack_revenues = self.knapsack_revenues_matrix.view(-1)
            Q = self.Q_zeros + self.eps * Variable(
                torch.eye(self.nKnapsackCategories * self.nThresholds))
            self.inequalityMatrix = torch.cat(
                (self.knapsack_cancels, -1 * self.knapsack_fills,
                 self.PosValMatrix))
            self.knapsack_cancels_RHS = torch.sum(
                self.knapsack_cancels_matrix * self.benchmark_thresholds)
            self.knapsack_fills_RHS = torch.sum(self.knapsack_fills_matrix *
                                                self.benchmark_thresholds)
            #self.inequalityVector = torch.cat((self.cancel_rate_param*self.h,-1*self.accept_rate_param*self.h,self.PosValVector))
            self.inequalityVector = torch.cat(
                (self.knapsack_cancels_RHS * self.h,
                 -1 * self.knapsack_fills_RHS * self.h, self.PosValVector))
            try:
                thresholds_raw = QPFunctionJK(verbose=1)(
                    Q, -1 * self.knapsack_revenues, self.inequalityMatrix,
                    self.inequalityVector, self.A, self.b)
                self.thresholds_raw_matrix = thresholds_raw.view(
                    self.nKnapsackCategories, -1)
                #self.accept_rate=1.0*self.accept_rate_original
                #self.cancel_rate=1.0*self.cancel_rate_original
            except AssertionError:
                print("Error solving LP, likely infeasible")
                self.lp_infeasible = 1
                #print("New Accept and Cancel Rates:",self.accept_rate,self.cancel_rate)
            self.thresholds_raw_matrix = F.relu(
                self.thresholds_raw_matrix) + self.eps
        self.thresholds_raw_matrix_norm = normalize_JK(
            self.thresholds_raw_matrix, dim=1)
        #This cXt matrix shows the probability of accepting an order under the learned thresholds, obtained either through direct optimization or through solving an LP
        accept_probability_cXt = torch.cumsum(
            self.thresholds_raw_matrix_norm, dim=1
        )  #this gives the accept probability by cXt under parameterized thresholds

        #category is BxC matrix, so summing across dim 0 gets the number of accepted orders per category
        accept_probability_collection_bXt = torch.cumsum(collection_thresholds,
                                                         dim=1)
        reject_probability_collection_bXt = 1 - accept_probability_collection_bXt
        accept_percent_collection_bXt = accept_probability_collection_bXt * self.inventory_distribution_batch_by_threshold_est
        accept_percent_collection_b_vector = torch.sum(
            accept_percent_collection_bXt, dim=1
        ).squeeze(
        )  #This is the believed acceptance rate of general orders of the categories corresponding with the batch under the collection thresholds
        reject_percent_collection_b_vector = 1 - accept_percent_collection_b_vector
        self.batch_total_demand_b_vector = (
            1 / accept_percent_collection_b_vector)  #.clamp(min=0,max=100)

        #new to v37
        reject_percent_collection_expanded_bXt = reject_percent_collection_b_vector.unsqueeze(
            1).expand(self.nBatch, self.nThresholds)
        self.truncated_orders_distribution_bXt = torch.div(
            reject_probability_collection_bXt *
            self.inventory_distribution_batch_by_threshold_est,
            reject_percent_collection_expanded_bXt + self.eps)
        truncated_demand_b_vector = self.batch_total_demand_b_vector - 1  #self.belief_total_demand_cXt
        truncated_demand_bXt = truncated_demand_b_vector.unsqueeze(1).expand(
            self.nBatch,
            self.nThresholds) * self.truncated_orders_distribution_bXt
        batch_total_demand_bXt = truncated_demand_bXt + inv_count
        self.batch_total_demand_cXt = torch.mm(category.t(),
                                               batch_total_demand_bXt)
        batch_total_demand_c_vector = torch.sum(self.batch_total_demand_cXt,
                                                dim=1)
        batch_zero_demand_c_vector = 1 - batch_total_demand_c_vector.ge(0)
        #batch_supplement_demand = torch.masked_select(belief_total_demand_c_vector_est,batch_zero_demand_c_vector)
        self.estimated_batch_total_demand = torch.sum(
            self.batch_total_demand_b_vector
        )  #+torch.sum(batch_supplement_demand)

        #Now we want to see how accurate our inventory distributions are for the batch
        accept_probability_batch_by_threshold = CumSumNoGrad(
            verbose=-1)(collection_thresholds) + self.eps
        self.inventory_distribution_batch_by_thresholds = torch.mm(
            category, self.inventory_distribution_raw_est)
        arrival_probability_batch_by_threshold_unnormed = self.inventory_distribution_batch_by_thresholds * accept_probability_batch_by_threshold
        arrival_probability_batch_by_threshold = torch.div(
            arrival_probability_batch_by_threshold_unnormed,
            torch.sum(arrival_probability_batch_by_threshold_unnormed,
                      dim=1).unsqueeze(1).expand_as(
                          arrival_probability_batch_by_threshold_unnormed))
        log_arrival_prob = torch.log(arrival_probability_batch_by_threshold +
                                     self.eps)

        #Like we do for inventory, we want to measure the accuracy of our cancel params for the batch
        self.belief_cancel_rate_bXt = torch.mm(category,
                                               self.belief_cancel_rate_cXt_est)
        belief_fill_rate_bXt = 1 - self.belief_cancel_rate_bXt
        self.belief_cancel_rate_b_vector = torch.sum(
            self.belief_cancel_rate_bXt * inv_count, dim=1).squeeze()
        belief_fill_rate_b_vector = 1 - self.belief_cancel_rate_b_vector
        log_cancel_prob = torch.log(
            torch.cat((belief_fill_rate_b_vector.unsqueeze(1),
                       self.belief_cancel_rate_b_vector.unsqueeze(1)), 1) +
            self.eps)

        self.belief_category_dist_bXc = self.demand_distribution_est.unsqueeze(
            0).expand(self.nBatch, self.nKnapsackCategories)
        log_category_prob = torch.log(self.belief_category_dist_bXc + self.eps)

        ##This is new in v37. We want to combine the actual results observed in the batch but add in estimated effects of truncation

        accept_probability_using_threshold_params_bXt = torch.mm(
            category, accept_probability_cXt)
        truncated_accept_estimate = truncated_demand_bXt * accept_probability_using_threshold_params_bXt  #This is the number of truncated orders we expect to accept (using param thresholds) at each inventory level corresponding to each order in the batch
        truncated_cancel_estimate = truncated_accept_estimate * self.belief_cancel_rate_bXt
        truncated_fill_estimate = truncated_accept_estimate * belief_fill_rate_bXt
        truncated_revenue_estimate = truncated_fill_estimate * (
            price.unsqueeze(1).expand(self.nBatch, self.nThresholds))
        truncated_revenue_estimate_sum = torch.sum(truncated_revenue_estimate)
        self.truncated_cancel_estimate_sum = torch.sum(
            truncated_cancel_estimate)
        truncated_fill_estimate_sum = torch.sum(truncated_fill_estimate)
        self.truncated_accept_estimate_sum = torch.sum(
            truncated_accept_estimate)

        ##This is new in v37. We want to combine the actual results observed in the batch but add in estimated effects of truncation
        fill = 1 - cancel
        batch_cancel_bXt = cancel.unsqueeze(1).expand(
            self.nBatch, self.nThresholds
        ) * inv_count * accept_probability_using_threshold_params_bXt
        batch_fill_bXt = fill.unsqueeze(1).expand(
            self.nBatch, self.nThresholds
        ) * inv_count * accept_probability_using_threshold_params_bXt
        batch_cancel_b_vector = torch.sum(batch_cancel_bXt, dim=1).squeeze()
        batch_fill_b_vector = torch.sum(batch_fill_bXt, dim=1).squeeze()
        batch_accept_b_vector = torch.sum(
            inv_count * accept_probability_using_threshold_params_bXt,
            dim=1).squeeze()
        #print("sanity check",batch_accept_b_vector, batch_fill_b_vector+batch_cancel_b_vector)
        #print("sanity check 2", torch.sum(batch_accept_b_vector), torch.sum(batch_fill_b_vector+batch_cancel_b_vector))

        batch_revenue_b_vector = price * batch_fill_b_vector
        self.batch_fill_sum = torch.sum(batch_fill_b_vector, dim=0)
        self.batch_revenue_sum = torch.sum(batch_revenue_b_vector, dim=0)
        self.batch_cancel_sum = torch.sum(batch_cancel_b_vector, dim=0)
        self.batch_accept_sum = torch.sum(batch_accept_b_vector, dim=0)

        new_objective_loss = -(1.0 / 50000) * (truncated_revenue_estimate_sum +
                                               self.batch_revenue_sum)
        new_cancel_constraint_loss = self.truncated_cancel_estimate_sum + self.batch_cancel_sum - (
            self.truncated_accept_estimate_sum +
            self.batch_accept_sum) * self.cancel_rate_evaluation
        new_accept_constraint_loss = (1.0 / 7.0) * (
            (self.truncated_accept_estimate_sum + self.batch_accept_sum) *
            self.accept_rate_evaluation - truncated_fill_estimate_sum -
            self.batch_fill_sum)
        #new_cancel_constraint_loss = truncated_cancel_estimate_sum+self.batch_cancel_sum-self.estimated_batch_total_demand*self.cancel_rate_param
        #new_accept_constraint_loss = (1.0/7.0)*(self.estimated_batch_total_demand*self.accept_rate_param-truncated_fill_estimate_sum-self.batch_fill_sum)

        observed_cancel_constraint_loss = self.batch_cancel_sum - (
            self.batch_accept_sum) * self.cancel_rate_evaluation
        observed_accept_constraint_loss = (
            1.0 / 7.0) * (self.batch_accept_sum * self.accept_rate_evaluation -
                          self.batch_fill_sum)

        return new_objective_loss, new_cancel_constraint_loss, new_accept_constraint_loss, arrival_probability_batch_by_threshold, log_arrival_prob, log_cancel_prob, log_category_prob, self.estimated_batch_total_demand, observed_cancel_constraint_loss, observed_accept_constraint_loss, self.lp_infeasible
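The quadratic term in the knapsack branch above is only eps times the identity, so the QP being solved is essentially a linear program maximizing knapsack_revenues subject to the cancel- and fill-rate inequalities. A minimal sketch of that LP-as-QP trick on toy data, using the standard qpth QPFunction rather than the project's QPFunctionJK wrapper:

import torch
from qpth.qp import QPFunction

eps = 1e-4
c = torch.tensor([[1.0, 2.0]])                # maximize c^T x  <=>  minimize (-c)^T x
Q = eps * torch.eye(2).unsqueeze(0)           # tiny quadratic term turns the LP into a QP
G = torch.cat((torch.eye(2), -torch.eye(2))).unsqueeze(0)   # box constraints 0 <= x <= 1
h = torch.cat((torch.ones(2), torch.zeros(2))).unsqueeze(0)
e = torch.Tensor()                            # no equality constraints
x = QPFunction(verbose=False)(Q, -c, G, h, e, e)
print(x)                                      # expected roughly [[1., 1.]]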
Code Example #5
File: emd_utils.py  Project: zhushaoquan/DeepEMD
def emd_inference_qpth(distance_matrix,
                       weight1,
                       weight2,
                       form='QP',
                       l2_strength=0.0001):
    """
    to use the QP solver QPTH to derive EMD (LP problem),
    one can transform the LP problem to QP,
    or omit the QP term by multiplying it with a small value,i.e. l2_strngth.
    :param distance_matrix: nbatch * element_number * element_number
    :param weight1: nbatch  * weight_number
    :param weight2: nbatch  * weight_number
    :return:
    emd distance: nbatch*1
    flow : nbatch * weight_number *weight_number

    """

    weight1 = (weight1 * weight1.shape[-1]) / weight1.sum(1).unsqueeze(1)
    weight2 = (weight2 * weight2.shape[-1]) / weight2.sum(1).unsqueeze(1)

    nbatch = distance_matrix.shape[0]
    nelement_distmatrix = distance_matrix.shape[1] * distance_matrix.shape[2]
    nelement_weight1 = weight1.shape[1]
    nelement_weight2 = weight2.shape[1]

    Q_1 = distance_matrix.view(-1, 1, nelement_distmatrix).double()

    if form == 'QP':
        # version: QTQ
        Q = torch.bmm(Q_1.transpose(2, 1), Q_1).double().cuda(
        ) + 1e-4 * torch.eye(nelement_distmatrix).double().cuda().unsqueeze(
            0).repeat(nbatch, 1, 1)  # 0.00001 *
        p = torch.zeros(nbatch, nelement_distmatrix).double().cuda()
    elif form == 'L2':
        # version: regularizer
        Q = (l2_strength * torch.eye(nelement_distmatrix).double()
             ).cuda().unsqueeze(0).repeat(nbatch, 1, 1)
        p = distance_matrix.view(nbatch, nelement_distmatrix).double()
    else:
        raise ValueError('Unknown form')

    h_1 = torch.zeros(nbatch, nelement_distmatrix).double().cuda()
    h_2 = torch.cat([weight1, weight2], 1).double()
    h = torch.cat((h_1, h_2), 1)

    G_1 = -torch.eye(nelement_distmatrix).double().cuda().unsqueeze(0).repeat(
        nbatch, 1, 1)
    G_2 = torch.zeros(
        [nbatch, nelement_weight1 + nelement_weight2,
         nelement_distmatrix]).double().cuda()
    # sum_j(x_ij) <= s_i
    for i in range(nelement_weight1):
        G_2[:, i, nelement_weight2 * i:nelement_weight2 * (i + 1)] = 1
    # sum_i(x_ij) <= d_j
    for j in range(nelement_weight2):
        G_2[:, nelement_weight1 + j, j::nelement_weight2] = 1
    #xij>=0, sum_j(xij) <= si,sum_i(xij) <= dj, sum_ij(x_ij) = min(sum(si), sum(dj))
    G = torch.cat((G_1, G_2), 1)
    A = torch.ones(nbatch, 1, nelement_distmatrix).double().cuda()
    b = torch.min(torch.sum(weight1, 1), torch.sum(weight2,
                                                   1)).unsqueeze(1).double()
    flow = QPFunction(verbose=-1)(Q, p, G, h, A, b)

    emd_score = torch.sum((1 - Q_1).squeeze() * flow, 1)
    return emd_score, flow.view(-1, nelement_weight1, nelement_weight2)
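A hypothetical usage sketch of emd_inference_qpth on random inputs (shapes and values are illustrative only; a CUDA device is assumed because the function moves its tensors to the GPU internally):

import torch

nbatch, n1, n2 = 4, 5, 5
distance_matrix = torch.rand(nbatch, n1, n2).cuda()
weight1 = torch.rand(nbatch, n1).cuda()
weight2 = torch.rand(nbatch, n2).cuda()
emd_score, flow = emd_inference_qpth(distance_matrix, weight1, weight2, form='L2')
print(emd_score.shape, flow.shape)   # expected: torch.Size([4]) torch.Size([4, 5, 5])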