def do_lipschitz_projection(self):
    """
    Run the Lipschitz projection step by solving a quadratic program.

    The backend is selected by ``self.QP``:

    * ``"qpth"``  -- solve with ``QPFunction`` using ``self.Q``, ``self.G``,
      ``self.h`` and ``self.e`` (``self.e`` is passed for both equality
      arguments) and ``-2.0 * self.coefficients`` as linear term.
    * ``"cvxpy"`` -- solve with the callable ``self.qp`` applied to
      ``-2.0 * self.coefficients.data``.

    The flattened solution is written to ``self.coefficients_vect_.data``.
    Any other value of ``self.QP`` leaves the coefficients untouched.
    Everything runs under ``torch.no_grad()``.
    """
    with torch.no_grad():
        if self.QP == "qpth":
            # qpth library
            solver = QPFunction(verbose=False)
            quad_term = nn.Parameter(self.Q)
            linear_term = -2.0 * self.coefficients
            ineq_lhs = nn.Parameter(self.G)
            ineq_rhs = nn.Parameter(self.h)
            eq_lhs = nn.Parameter(self.e)
            eq_rhs = nn.Parameter(self.e)
            projected = solver(quad_term, linear_term, ineq_lhs, ineq_rhs,
                               eq_lhs, eq_rhs)
        elif self.QP == "cvxpy":
            # cvxpylayers library: all rows are projected in one call.
            # (A row-by-row loop over self.coefficients.data was used
            # previously as a verification of this batched call.)
            linear_term = -2.0 * self.coefficients.data
            (projected,) = self.qp(linear_term)
        else:
            return

        self.coefficients_vect_.data = projected.view(-1)
def forward(self, x, Q, p, G, h, m):
    """
    Solve a QP with qpth and return row-wise log-softmax of the solution.

    Parameters
    ----------
    x : ignored on input; the name is reused for the QP solution.
    Q : quadratic term, cast to float; gets a leading batch axis when
        ``m >= 2``.
    p : linear term, cast to float; transposed when ``m > 1``.
    G, h : inequality constraint matrix / vector, cast to float.
    m : switches the transposition of ``p`` and the unsqueeze of ``Q``.

    Returns
    -------
    ``F.log_softmax`` over ``dim=1`` of the solution reshaped to 10 rows.
    """
    # Debug output only. torch.cuda.current_device() raises on machines
    # without CUDA, so it must not run unconditionally.
    if torch.cuda.is_available():
        print("Cuda current device", torch.cuda.current_device())

    if (m > 1):
        p = p.float().t()
    else:
        p = p.float()
    G = G.float()
    Q = Q.float()
    if (m >= 2):
        Q = Q.unsqueeze(0)
    h = h.float()

    # Empty tensor passed for both A and b: no equality constraints.
    e = Variable(torch.Tensor(), requires_grad=True)
    x = QPFunction(verbose=True)(Q, p, G, h, e, e)

    # NOTE(review): the row count 10 is hard-coded (original remark: "this
    # was not needed earlier"); presumably the number of classes -- confirm.
    x = x.view(10, -1)
    return F.log_softmax(x, dim=1)
def forward(  # type: ignore
        self,
        tokens: TextFieldTensors,
        verb_indicator: torch.Tensor,
        sentence_end: torch.LongTensor,
        metadata: List[Any],
        tags: torch.LongTensor = None,
        offsets: torch.LongTensor = None):
    """
    Compute tag logits and class probabilities, and the loss when gold
    `tags` are supplied.

    # Parameters

    tokens : `TextFieldTensors`, required
        The output of `TextField.as_array()`, which should typically be passed directly to a
        `TextFieldEmbedder`. For this model, this must be a `SingleIdTokenIndexer` which
        indexes wordpieces from the BERT vocabulary.
    verb_indicator: `torch.LongTensor`, required.
        An integer `SequenceFeatureField` representation of the position of the verb
        in the sentence. This should have shape (batch_size, num_tokens) and importantly, can be
        all zeros, in the case that the sentence has no verbal predicate.
    sentence_end : `torch.LongTensor`, required.
        One value per example; positions with index `< sentence_end` are kept as the
        sentence part, and sequences are truncated to `sentence_end.max()`. Only used when
        `self.bert_model` is not a `PretrainedTransformerMismatchedEmbedder`.
    metadata : `List[Dict[str, Any]]`, required
        metadata containing the original words in the sentence, the verb to compute the
        frame for, and start offsets for converting wordpieces back to a sequence of words,
        under 'words', 'verb' and 'offsets' keys, respectively. 'verb_index' and 'gold_tags'
        are read as well when the span metric is updated.
    tags : `torch.LongTensor`, optional (default = `None`)
        A torch tensor representing the sequence of integer gold class labels
        of shape `(batch_size, num_tokens)`
    offsets : `torch.LongTensor`, optional (default = `None`)
        Not read by this method; the offsets are taken from `metadata`.

    # Returns

    An output dictionary consisting of:
    logits : `torch.FloatTensor`
        A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
        unnormalised log probabilities of the tag classes.
    class_probabilities : `torch.FloatTensor`
        A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
        a distribution of the tag classes per word.
    mask, words, verb, wordpiece_offsets
        The mask used for the logits and the per-example metadata fields.
    loss : `torch.FloatTensor`, optional
        A scalar loss to be optimised. Only present when `tags` is given.
    gold_tags : optional
        Copied from `metadata` when every example provides it.
    """
    if isinstance(self.bert_model, PretrainedTransformerMismatchedEmbedder):
        # NOTE: this adds a key to the caller's tokens["tokens"] dict.
        encoder_inputs = tokens["tokens"]
        if self.bert_config.type_vocab_size > 1:
            encoder_inputs["type_ids"] = verb_indicator
        encoded_text = self.bert_model(**encoder_inputs)
        batch_size = encoded_text.shape[0]
        if self.bert_config.type_vocab_size == 1:
            # No segment embeddings available: append the verb's own
            # embedding (zeros when there is no verb) to every position.
            verb_embeddings = encoded_text[
                torch.arange(batch_size).to(encoded_text.device),
                verb_indicator.argmax(1), :]
            verb_embeddings = torch.where(
                (verb_indicator.sum(1, keepdim=True) > 0).repeat(
                    1, verb_embeddings.shape[-1]),
                verb_embeddings, torch.zeros_like(verb_embeddings))
            encoded_text = torch.cat(
                (encoded_text, verb_embeddings.unsqueeze(1).repeat(
                    1, encoded_text.shape[1], 1)),
                dim=2)
        mask = tokens["tokens"]["mask"]
    else:
        mask = get_text_field_mask(tokens)
        bert_embeddings, _ = self.bert_model(
            input_ids=util.get_token_ids_from_text_field_tensors(tokens),
            # token_type_ids=verb_indicator,
            attention_mask=mask,
        )
        batch_size, _ = mask.size()
        embedded_text_input = self.embedding_dropout(bert_embeddings)

        # Restrict to sentence part
        sentence_mask = (torch.arange(mask.shape[1]).unsqueeze(0).repeat(
            batch_size, 1).to(mask.device) < sentence_end.unsqueeze(1).repeat(
                1, mask.shape[1])).long()
        cutoff = sentence_end.max().item()

        if self._encoder is None:
            encoded_text = embedded_text_input[:, :cutoff, :]
        else:
            predicate_embeddings = self.predicate_embedding(verb_indicator)
            encoder_inputs = torch.cat(
                (embedded_text_input, predicate_embeddings), dim=-1)
            encoded_text = self._encoder(encoder_inputs,
                                         mask=sentence_mask.bool())
            # Weights decrease with position, so for a 0/1 indicator the
            # argmax selects the first marked position.
            predicate_index = (verb_indicator * torch.arange(
                start=verb_indicator.shape[-1] - 1, end=-1,
                step=-1).to(mask.device).unsqueeze(0).repeat(
                    batch_size, 1)).argmax(1)
            predicate_hidden = encoded_text[
                torch.arange(batch_size).to(mask.device), predicate_index]
            predicate_exists, _ = verb_indicator.max(1)
            encoded_text = encoded_text[:, :cutoff, :]
            predicate_exists = predicate_exists.unsqueeze(1).repeat(
                1, encoded_text.shape[-1])
            # Zero the predicate state for examples without a predicate.
            predicate_hidden = torch.where(
                predicate_exists > 0, predicate_hidden,
                torch.zeros_like(predicate_hidden))
            encoded_text = torch.cat(
                (encoded_text, predicate_hidden.unsqueeze(1).repeat(
                    1, encoded_text.shape[1], 1)),
                dim=-1)

        mask = sentence_mask[:, :cutoff].contiguous()
        # `tags` is optional; only truncate gold labels when they exist.
        if tags is not None:
            tags = tags[:, :cutoff].contiguous()

    sequence_length = encoded_text.shape[1]
    logits = self.tag_projection_layer(encoded_text)

    if self._lp and sequence_length <= 100:
        # Small ridge-regularised QP over all (position, class) variables
        # with one equality row per position forcing its classes to sum
        # to one.
        eps = 1e-4
        num_vars = sequence_length * self.num_classes
        Q = eps * torch.eye(num_vars, num_vars).unsqueeze(0).repeat(
            batch_size, 1, 1).to(logits.device).float()
        p = logits.view(batch_size, -1)
        # NOTE(review): the original also built G = -I and h = 0 but never
        # passed them to the solver (empty tensors are passed instead);
        # those unused allocations were dropped. Negative entries of the
        # solution are clipped by the relu below.
        A = torch.arange(num_vars).unsqueeze(0).repeat(sequence_length, 1)
        A2 = torch.arange(sequence_length).unsqueeze(1).repeat(
            1, num_vars) * self.num_classes
        A = torch.where((A >= A2) & (A < A2 + self.num_classes),
                        torch.ones_like(A), torch.zeros_like(A))
        A = A.unsqueeze(0).repeat(batch_size, 1, 1).to(
            logits.device).float()
        b = torch.ones_like(A[:, :, 0])
        probs = QPFunction()(Q, p, torch.autograd.Variable(torch.Tensor()),
                             torch.autograd.Variable(torch.Tensor()), A, b)
        probs = probs.view(batch_size, sequence_length, self.num_classes)
        # (An alternative path through pre-built per-length LP layers,
        # self._layer_list / self.length_map, used to live here.)
        logits = (torch.nn.functional.relu(probs) + 1e-4).log()

    if self._lpsmap:
        if self._lpsmap_core_only:
            all_logits = logits
        else:
            # Extra row (index `sequence_length`) used by the Or factors
            # below.
            all_logits = torch.cat((logits, 0.5 * torch.ones(
                (batch_size, 1, logits.shape[-1])).to(logits.device)),
                                   dim=1)
        probs = []
        for i in range(batch_size):
            if self.constrain_crf_decoding:
                unaries = logits[i, :, :].view(-1).cpu()
                additionals = self.crf.transitions.view(-1).repeat(
                    sequence_length) + 10000 * (
                        self.crf._constraint_mask[:-2, :-2] -
                        1).view(-1).repeat(sequence_length)
                start_transitions = self.crf.start_transitions + 10000 * (
                    self.crf._constraint_mask[-2, :-2] - 1)
                # NOTE(review): built from `start_transitions`, not
                # `end_transitions` -- looks like a copy/paste slip;
                # confirm before changing.
                end_transitions = self.crf.start_transitions + 10000 * (
                    self.crf._constraint_mask[-1, :-2] - 1)
                additionals = torch.cat(
                    (additionals, start_transitions, end_transitions),
                    dim=0).cpu()
                fg = TorchFactorGraph()
                x = fg.variable_from(unaries)
                f = PFactorSequence()
                f.initialize(
                    [self.num_classes for _ in range(sequence_length)])
                factor = TorchOtherFactor(f, x, additionals)
                fg.add(factor)
                # add budget constraint for each state
                for state in self._core_roles:
                    vars_state = x[state::self.num_classes]
                    fg.add(AtMostOne(vars_state))
                # solve SparseMAP
                fg.solve(max_iter=200)
                # NOTE(review): appends `unaries`, not the solved variable
                # `x` -- kept as in the original; confirm intent.
                probs.append(
                    unaries.to(logits.device).view(sequence_length,
                                                   self.num_classes))
            else:
                fg = TorchFactorGraph()
                x = fg.variable_from(all_logits[i, :, :].cpu())
                # Exactly one tag per position.
                for j in range(sequence_length):
                    fg.add(Xor(x[j, :]))
                # Each core role at most once per sequence.
                for j in self._core_roles:
                    fg.add(AtMostOne(x[:sequence_length, j]))
                if not self._lpsmap_core_only:
                    base_roles = {
                        second
                        for (_, second) in self._r_roles + self._c_roles
                    }
                    for base_role in base_roles:
                        fg.add(OrOut(x[:, base_role]))
                    for (r_role, base_role) in self._r_roles + self._c_roles:
                        fg.add(OrOut(x[:, r_role]))
                        fg.add(
                            Or(x[[sequence_length, sequence_length],
                                 [r_role, base_role]],
                               negated=[True, False]))
                # Both caps below exceed 100, so max_iter is currently
                # always 100.
                max_iter = 100
                if not self._lpsmap_core_only:
                    max_iter = min(max_iter, 400)
                elif (not self.training) and not self._val_inference:
                    max_iter = min(max_iter, 200)
                fg.solve(max_iter=max_iter)
                probs.append(x.value[:sequence_length, :].contiguous().to(
                    logits.device))
        class_probabilities = torch.stack(probs)
    else:
        reshaped_log_probs = logits.view(-1, self.num_classes)
        class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view(
            [batch_size, sequence_length, self.num_classes])

    output_dict = {
        "logits": logits,
        "class_probabilities": class_probabilities
    }
    # We need to retain the mask in the output dictionary
    # so that we can crop the sequences to remove padding
    # when we do viterbi inference in self.make_output_human_readable.
    output_dict["mask"] = mask
    # We add in the offsets here so we can compute the un-wordpieced tags.
    words, verbs, wordpiece_offsets = zip(*[(x["words"], x["verb"],
                                             x["offsets"])
                                            for x in metadata])
    output_dict["words"] = list(words)
    output_dict["verb"] = list(verbs)
    output_dict["wordpiece_offsets"] = list(wordpiece_offsets)

    if tags is not None:
        if self._lpsmap:
            loss = LpsmapLoss.apply(logits, class_probabilities, tags, mask)
        elif self.constrain_crf_decoding:
            loss = -self.crf(logits, tags, mask)
        else:
            loss = sequence_cross_entropy_with_logits(
                logits, tags, mask, label_smoothing=self._label_smoothing)
        if not self.ignore_span_metric and self.span_metric is not None and not self.training:
            batch_verb_indices = [
                example_metadata["verb_index"]
                for example_metadata in metadata
            ]
            batch_sentences = [
                example_metadata["words"] for example_metadata in metadata
            ]
            # Get the BIO tags from make_output_human_readable()
            # TODO (nfliu): This is kind of a hack, consider splitting out part
            # of make_output_human_readable() to a separate function.
            batch_bio_predicted_tags = self.make_output_human_readable(
                output_dict).pop("tags")
            from allennlp_models.structured_prediction.models.srl import (
                convert_bio_tags_to_conll_format, )

            if self.constrain_crf_decoding and not self._lpsmap:
                batch_conll_predicted_tags = [
                    convert_bio_tags_to_conll_format([
                        self.vocab.get_token_from_index(
                            tag, namespace=self._label_namespace)
                        for tag in seq
                    ]) for (seq, _) in self.crf.viterbi_tags(logits, mask)
                ]
            else:
                batch_conll_predicted_tags = [
                    convert_bio_tags_to_conll_format(bio_tags)
                    for bio_tags in batch_bio_predicted_tags
                ]
            batch_bio_gold_tags = [
                example_metadata["gold_tags"]
                for example_metadata in metadata
            ]
            batch_conll_gold_tags = [
                convert_bio_tags_to_conll_format(bio_tags)
                for bio_tags in batch_bio_gold_tags
            ]
            self.span_metric(
                batch_verb_indices,
                batch_sentences,
                batch_conll_predicted_tags,
                batch_conll_gold_tags,
            )
        output_dict["loss"] = loss

    # Gold tags are only available when the reader stored them.
    if all("gold_tags" in x for x in metadata):
        output_dict["gold_tags"] = [x["gold_tags"] for x in metadata]
    return output_dict
def forward(self, category, inv_count, price, cancel, collection_thresholds):
    """
    Evaluate the threshold policy on one batch and return its loss terms.

    The method (1) builds the estimated ("_est") and optimisation ("_opt")
    inventory / cancel beliefs from the learned parameters, (2) when
    ``self.parametric_knapsack`` is set, solves the knapsack QP for the
    raw thresholds, (3) estimates the demand truncated away by the
    collection thresholds, and (4) combines truncated estimates with the
    observed batch into objective and constraint losses. Many
    intermediates are stored on ``self``.

    Parameters (shapes are assumptions inferred from the ``torch.mm`` /
    ``expand`` calls -- TODO confirm against the caller)
    ----------
    category : (nBatch, nKnapsackCategories) membership matrix.
    inv_count : (nBatch, nThresholds) tensor.
    price : (nBatch,) tensor.
    cancel : (nBatch,) numeric tensor; ``fill`` is computed as ``1 - cancel``.
    collection_thresholds : (nBatch, nThresholds) tensor; its cumulative
        sum along dim 1 is used as the acceptance probability under which
        the batch was collected.

    Returns
    -------
    An 11-tuple: ``(new_objective_loss, new_cancel_constraint_loss,
    new_accept_constraint_loss, arrival_probability_batch_by_threshold,
    log_arrival_prob, log_cancel_prob, log_category_prob,
    self.estimated_batch_total_demand, observed_cancel_constraint_loss,
    observed_accept_constraint_loss, self.lp_infeasible)``.
    """
    self.lp_infeasible = 0
    self.cancel_coef_neg_est = self.cancel_coef_est.clamp(max=0)
    self.cancel_coef_neg_opt = self.cancel_coef_opt.clamp(max=0)
    self.nBatch = category.size(0)

    # We want to compute everything we can without thresholds first. This
    # will allow us to use our learned parameters to feed the LP.
    self.inventory_distribution_raw_est = PoissonFunction(
        self.nKnapsackCategories, self.nThresholds, verbose=-1)(
            self.inventory_lam_est, self.thresholds) + self.eps
    self.inventory_distribution_batch_by_threshold_est = torch.mm(
        category, self.inventory_distribution_raw_est) + self.eps
    self.inventory_distribution_raw_opt = PoissonFunction(
        self.nKnapsackCategories, self.nThresholds, verbose=-1)(
            self.inventory_lam_opt, self.thresholds) + self.eps
    self.inventory_distribution_batch_by_threshold_opt = torch.mm(
        category, self.inventory_distribution_raw_opt) + self.eps

    # Here we'll calculate cancel probability by inventory.
    thresholds_cXt = self.thresholds.unsqueeze(0).expand(
        self.nKnapsackCategories, self.nThresholds)
    self.belief_cancel_rate_cXt_est = cancel_rate_belief_cXt(
        self.cancel_coef_neg_est, self.cancel_intercept_est, thresholds_cXt)
    self.belief_cancel_rate_cXt_opt = cancel_rate_belief_cXt(
        self.cancel_coef_neg_opt, self.cancel_intercept_opt, thresholds_cXt)
    belief_fill_rate_cXt_opt = 1 - self.belief_cancel_rate_cXt_opt
    price_cXt_opt = self.prices_opt.unsqueeze(1).expand(
        self.nKnapsackCategories, self.nThresholds)

    self.belief_total_demand_cXt_est = self.inventory_distribution_raw_est * (
        self.demand_distribution_est.unsqueeze(1).expand(
            self.nKnapsackCategories, self.nThresholds))
    self.belief_total_demand_cXt_opt = self.inventory_distribution_raw_opt * (
        self.demand_distribution_opt.unsqueeze(1).expand(
            self.nKnapsackCategories, self.nThresholds))

    if self.parametric_knapsack:
        self.belief_total_demand_opt = torch.sum(
            self.belief_total_demand_cXt_opt)
        self.belief_total_cancels_cXt_opt = (
            self.belief_cancel_rate_cXt_opt * self.belief_total_demand_cXt_opt)
        self.belief_total_fills_cXt_opt = (
            belief_fill_rate_cXt_opt * self.belief_total_demand_cXt_opt)

        # total - cumsum + current == suffix sum along the threshold axis.
        # keepdim=True keeps the row totals as a column so they broadcast
        # per category; without it the (C,) totals were aligned with the
        # threshold axis instead.
        total_demand_cXt = self.belief_total_demand_opt.expand(
            self.nKnapsackCategories, self.nThresholds)
        self.knapsack_cancels_matrix = torch.div(
            torch.sum(self.belief_total_cancels_cXt_opt, dim=1,
                      keepdim=True).expand_as(
                          self.belief_total_cancels_cXt_opt) -
            torch.cumsum(self.belief_total_cancels_cXt_opt, dim=1) +
            self.belief_total_cancels_cXt_opt, total_demand_cXt)
        self.knapsack_fills_matrix = torch.div(
            torch.sum(self.belief_total_fills_cXt_opt, dim=1,
                      keepdim=True).expand_as(
                          self.belief_total_fills_cXt_opt) -
            torch.cumsum(self.belief_total_fills_cXt_opt, dim=1) +
            self.belief_total_fills_cXt_opt, total_demand_cXt)
        self.knapsack_revenues_matrix = (
            self.knapsack_fills_matrix * price_cXt_opt)

        self.knapsack_cancels = self.knapsack_cancels_matrix.view(1, -1)
        self.knapsack_fills = self.knapsack_fills_matrix.view(1, -1)
        self.knapsack_revenues = self.knapsack_revenues_matrix.view(-1)

        # eps * I makes the LP a strictly convex QP for the solver.
        Q = self.Q_zeros + self.eps * Variable(
            torch.eye(self.nKnapsackCategories * self.nThresholds))
        self.inequalityMatrix = torch.cat(
            (self.knapsack_cancels, -1 * self.knapsack_fills,
             self.PosValMatrix))
        self.knapsack_cancels_RHS = torch.sum(
            self.knapsack_cancels_matrix * self.benchmark_thresholds)
        self.knapsack_fills_RHS = torch.sum(
            self.knapsack_fills_matrix * self.benchmark_thresholds)
        self.inequalityVector = torch.cat(
            (self.knapsack_cancels_RHS * self.h,
             -1 * self.knapsack_fills_RHS * self.h, self.PosValVector))
        try:
            thresholds_raw = QPFunctionJK(verbose=1)(
                Q, -1 * self.knapsack_revenues, self.inequalityMatrix,
                self.inequalityVector, self.A, self.b)
            self.thresholds_raw_matrix = thresholds_raw.view(
                self.nKnapsackCategories, -1)
        except AssertionError:
            # NOTE(review): on failure the previous value of
            # self.thresholds_raw_matrix is reused below; if none exists
            # yet the next line raises AttributeError.
            print("Error solving LP, likely infeasible")
            self.lp_infeasible = 1

    self.thresholds_raw_matrix = F.relu(
        self.thresholds_raw_matrix) + self.eps
    self.thresholds_raw_matrix_norm = normalize_JK(
        self.thresholds_raw_matrix, dim=1)
    # This cXt matrix shows the probability of accepting an order under the
    # learned thresholds, obtained either through direct optimization or
    # through solving an LP.
    accept_probability_cXt = torch.cumsum(self.thresholds_raw_matrix_norm,
                                          dim=1)

    # Acceptance under the thresholds the batch was collected with.
    accept_probability_collection_bXt = torch.cumsum(collection_thresholds,
                                                     dim=1)
    reject_probability_collection_bXt = 1 - accept_probability_collection_bXt
    accept_percent_collection_bXt = (
        accept_probability_collection_bXt *
        self.inventory_distribution_batch_by_threshold_est)
    # This is the believed acceptance rate of general orders of the
    # categories corresponding with the batch under the collection
    # thresholds. sum(dim=1) already yields (nBatch,); a further squeeze()
    # would collapse a batch of one to 0-dim and break unsqueeze(1) below.
    accept_percent_collection_b_vector = torch.sum(
        accept_percent_collection_bXt, dim=1)
    reject_percent_collection_b_vector = 1 - accept_percent_collection_b_vector
    self.batch_total_demand_b_vector = (
        1 / accept_percent_collection_b_vector)  # new to v37
    reject_percent_collection_expanded_bXt = (
        reject_percent_collection_b_vector.unsqueeze(1).expand(
            self.nBatch, self.nThresholds))
    self.truncated_orders_distribution_bXt = torch.div(
        reject_probability_collection_bXt *
        self.inventory_distribution_batch_by_threshold_est,
        reject_percent_collection_expanded_bXt + self.eps)
    truncated_demand_b_vector = self.batch_total_demand_b_vector - 1
    truncated_demand_bXt = truncated_demand_b_vector.unsqueeze(1).expand(
        self.nBatch,
        self.nThresholds) * self.truncated_orders_distribution_bXt
    batch_total_demand_bXt = truncated_demand_bXt + inv_count
    self.batch_total_demand_cXt = torch.mm(category.t(),
                                           batch_total_demand_bXt)
    # The former zero-demand mask (`1 - x.ge(0)`) was unused and raises on
    # bool tensors in current PyTorch; it has been removed.
    self.estimated_batch_total_demand = torch.sum(
        self.batch_total_demand_b_vector)

    # Now we want to see how accurate our inventory distributions are for
    # the batch.
    accept_probability_batch_by_threshold = CumSumNoGrad(
        verbose=-1)(collection_thresholds) + self.eps
    self.inventory_distribution_batch_by_thresholds = torch.mm(
        category, self.inventory_distribution_raw_est)
    arrival_probability_batch_by_threshold_unnormed = (
        self.inventory_distribution_batch_by_thresholds *
        accept_probability_batch_by_threshold)
    arrival_probability_batch_by_threshold = torch.div(
        arrival_probability_batch_by_threshold_unnormed,
        torch.sum(arrival_probability_batch_by_threshold_unnormed,
                  dim=1).unsqueeze(1).expand_as(
                      arrival_probability_batch_by_threshold_unnormed))
    log_arrival_prob = torch.log(arrival_probability_batch_by_threshold +
                                 self.eps)

    # Like we do for inventory, we want to measure the accuracy of our
    # cancel params for the batch.
    self.belief_cancel_rate_bXt = torch.mm(category,
                                           self.belief_cancel_rate_cXt_est)
    belief_fill_rate_bXt = 1 - self.belief_cancel_rate_bXt
    self.belief_cancel_rate_b_vector = torch.sum(
        self.belief_cancel_rate_bXt * inv_count, dim=1)
    belief_fill_rate_b_vector = 1 - self.belief_cancel_rate_b_vector
    log_cancel_prob = torch.log(
        torch.cat((belief_fill_rate_b_vector.unsqueeze(1),
                   self.belief_cancel_rate_b_vector.unsqueeze(1)), 1) +
        self.eps)
    self.belief_category_dist_bXc = self.demand_distribution_est.unsqueeze(
        0).expand(self.nBatch, self.nKnapsackCategories)
    log_category_prob = torch.log(self.belief_category_dist_bXc + self.eps)

    # This is new in v37. We want to combine the actual results observed in
    # the batch but add in estimated effects of truncation.
    accept_probability_using_threshold_params_bXt = torch.mm(
        category, accept_probability_cXt)
    # This is the number of truncated orders we expect to accept (using
    # param thresholds) at each inventory level corresponding to each
    # order in the batch.
    truncated_accept_estimate = (
        truncated_demand_bXt * accept_probability_using_threshold_params_bXt)
    truncated_cancel_estimate = (
        truncated_accept_estimate * self.belief_cancel_rate_bXt)
    truncated_fill_estimate = truncated_accept_estimate * belief_fill_rate_bXt
    truncated_revenue_estimate = truncated_fill_estimate * (
        price.unsqueeze(1).expand(self.nBatch, self.nThresholds))
    truncated_revenue_estimate_sum = torch.sum(truncated_revenue_estimate)
    self.truncated_cancel_estimate_sum = torch.sum(truncated_cancel_estimate)
    truncated_fill_estimate_sum = torch.sum(truncated_fill_estimate)
    self.truncated_accept_estimate_sum = torch.sum(truncated_accept_estimate)

    # Observed outcomes, reweighted by acceptance under the parameterised
    # thresholds.
    fill = 1 - cancel
    batch_cancel_bXt = cancel.unsqueeze(1).expand(
        self.nBatch, self.nThresholds
    ) * inv_count * accept_probability_using_threshold_params_bXt
    batch_fill_bXt = fill.unsqueeze(1).expand(
        self.nBatch, self.nThresholds
    ) * inv_count * accept_probability_using_threshold_params_bXt
    batch_cancel_b_vector = torch.sum(batch_cancel_bXt, dim=1)
    batch_fill_b_vector = torch.sum(batch_fill_bXt, dim=1)
    batch_accept_b_vector = torch.sum(
        inv_count * accept_probability_using_threshold_params_bXt, dim=1)
    batch_revenue_b_vector = price * batch_fill_b_vector
    self.batch_fill_sum = torch.sum(batch_fill_b_vector, dim=0)
    self.batch_revenue_sum = torch.sum(batch_revenue_b_vector, dim=0)
    self.batch_cancel_sum = torch.sum(batch_cancel_b_vector, dim=0)
    self.batch_accept_sum = torch.sum(batch_accept_b_vector, dim=0)

    # The 1/50000 and 1/7 factors are fixed scalings from the original.
    total_accept_sum = self.truncated_accept_estimate_sum + self.batch_accept_sum
    new_objective_loss = -(1.0 / 50000) * (truncated_revenue_estimate_sum +
                                           self.batch_revenue_sum)
    new_cancel_constraint_loss = (
        self.truncated_cancel_estimate_sum + self.batch_cancel_sum -
        total_accept_sum * self.cancel_rate_evaluation)
    new_accept_constraint_loss = (1.0 / 7.0) * (
        total_accept_sum * self.accept_rate_evaluation -
        truncated_fill_estimate_sum - self.batch_fill_sum)
    observed_cancel_constraint_loss = self.batch_cancel_sum - (
        self.batch_accept_sum) * self.cancel_rate_evaluation
    observed_accept_constraint_loss = (
        1.0 / 7.0) * (self.batch_accept_sum * self.accept_rate_evaluation -
                      self.batch_fill_sum)

    return (new_objective_loss, new_cancel_constraint_loss,
            new_accept_constraint_loss,
            arrival_probability_batch_by_threshold, log_arrival_prob,
            log_cancel_prob, log_category_prob,
            self.estimated_batch_total_demand,
            observed_cancel_constraint_loss,
            observed_accept_constraint_loss, self.lp_infeasible)
def emd_inference_qpth(distance_matrix,
                       weight1,
                       weight2,
                       form='QP',
                       l2_strength=0.0001):
    """
    Solve the EMD transport problem with the qpth QP solver.

    EMD is a linear program; to use a QP solver it is either turned into
    a QP (``form='QP'``: quadratic term ``d d^T + 1e-4 * I``, zero linear
    term) or given a small ridge term (``form='L2'``: quadratic term
    ``l2_strength * I``, linear term ``d``), where ``d`` is the flattened
    distance matrix.

    All tensors are created on the device of ``distance_matrix``, so the
    function runs on CPU as well as on GPU.

    :param distance_matrix: nbatch * element_number * element_number
    :param weight1: nbatch * weight_number
    :param weight2: nbatch * weight_number
    :param form: 'QP' or 'L2' (see above)
    :param l2_strength: ridge coefficient used when ``form='L2'``
    :return: emd_score: shape (nbatch,), the sum over all cells of
                 ``(1 - distance) * flow``
             flow: nbatch * weight_number * weight_number
    :raises ValueError: if ``form`` is neither 'QP' nor 'L2'
    """
    device = distance_matrix.device

    # Rescale each weight vector so that it sums to its own length.
    weight1 = (weight1 * weight1.shape[-1]) / weight1.sum(1).unsqueeze(1)
    weight2 = (weight2 * weight2.shape[-1]) / weight2.sum(1).unsqueeze(1)

    nbatch = distance_matrix.shape[0]
    nelement_distmatrix = distance_matrix.shape[1] * distance_matrix.shape[2]
    nelement_weight1 = weight1.shape[1]
    nelement_weight2 = weight2.shape[1]

    Q_1 = distance_matrix.view(-1, 1, nelement_distmatrix).double()

    if form == 'QP':
        # version: QTQ
        identity = torch.eye(nelement_distmatrix,
                             dtype=torch.float64,
                             device=device)
        Q = torch.bmm(Q_1.transpose(2, 1), Q_1).double().to(
            device) + 1e-4 * identity.unsqueeze(0).repeat(nbatch, 1, 1)
        p = torch.zeros(nbatch,
                        nelement_distmatrix,
                        dtype=torch.float64,
                        device=device)
    elif form == 'L2':
        # version: regularizer
        identity = torch.eye(nelement_distmatrix,
                             dtype=torch.float64,
                             device=device)
        Q = (l2_strength * identity).unsqueeze(0).repeat(nbatch, 1, 1)
        p = distance_matrix.view(nbatch,
                                 nelement_distmatrix).double().to(device)
    else:
        raise ValueError('Unkown form')

    # Inequalities G x <= h:
    #   xij >= 0, sum_j(xij) <= si, sum_i(xij) <= dj
    h_1 = torch.zeros(nbatch,
                      nelement_distmatrix,
                      dtype=torch.float64,
                      device=device)
    h_2 = torch.cat([weight1, weight2], 1).double().to(device)
    h = torch.cat((h_1, h_2), 1)

    G_1 = -torch.eye(nelement_distmatrix, dtype=torch.float64,
                     device=device).unsqueeze(0).repeat(nbatch, 1, 1)
    G_2 = torch.zeros(
        [nbatch, nelement_weight1 + nelement_weight2, nelement_distmatrix],
        dtype=torch.float64,
        device=device)
    # sum_j(xij) = si
    for i in range(nelement_weight1):
        G_2[:, i, nelement_weight2 * i:nelement_weight2 * (i + 1)] = 1
    # sum_i(xij) = dj
    for j in range(nelement_weight2):
        G_2[:, nelement_weight1 + j, j::nelement_weight2] = 1
    G = torch.cat((G_1, G_2), 1)

    # Equality A x = b: sum_ij(x_ij) = min(sum(si), sum(dj))
    A = torch.ones(nbatch,
                   1,
                   nelement_distmatrix,
                   dtype=torch.float64,
                   device=device)
    b = torch.min(torch.sum(weight1, 1),
                  torch.sum(weight2, 1)).unsqueeze(1).double().to(device)

    flow = QPFunction(verbose=-1)(Q, p, G, h, A, b)

    # squeeze(1) removes only the singleton middle axis; a bare squeeze()
    # would also drop the cell axis for a 1x1 distance matrix and turn the
    # product into an (nbatch, nbatch) broadcast.
    emd_score = torch.sum((1 - Q_1).squeeze(1) * flow, 1)
    return emd_score, flow.view(-1, nelement_weight1, nelement_weight2)