def store_other_items(self, hard_an, hard_ap, prefix: str):
    """
    Record four diagnostic statistics for a batch of triplet distances.

    Every statistic is kept in ``self.meter_dict`` under ``prefix`` + suffix.
    A meter that does not exist yet is created on first use and is then
    updated with a Python scalar obtained through ``.item()``:

    'prec': precision, the mean of ``hard_an > hard_ap``
    'sm':   the proportion of triplets that satisfy the margin, i.e. the
            mean of ``hard_an > hard_ap + self.cfg.margin``
    'h_ap': average (anchor, positive) distance
    'h_an': average (anchor, negative) distance

    :param hard_an: (anchor, negative) distances; must support ``>``,
        ``.float()`` on the comparison result and ``.mean().item()``
        (presumably a torch tensor -- confirm against the caller)
    :param hard_ap: (anchor, positive) distances, same kind of object
    :param prefix: string prepended to every meter key
    :return: None
    """
    percent_fmt = '{:.2%}'

    def meter_for(suffix, **meter_kwargs):
        # Get-or-create: the meter is registered before its value is computed,
        # exactly as in the step-by-step form.
        name = prefix + suffix
        if name not in self.meter_dict:
            self.meter_dict[name] = Meter(name=name, **meter_kwargs)
        return self.meter_dict[name]

    # The two ratio statistics are created with a percentage format string.
    meter_for('prec', fmt=percent_fmt).update(
        (hard_an > hard_ap).float().mean().item())
    meter_for('sm', fmt=percent_fmt).update(
        (hard_an > hard_ap + self.cfg.margin).float().mean().item())

    # The two distance averages use the Meter's default format.
    meter_for('h_ap').update(hard_ap.mean().item())
    meter_for('h_an').update(hard_an.mean().item())
def store_calculate_loss(self, loss):
    """
    Push the current total loss into the meter named ``self.cfg.name``.

    The meter is looked up in ``self.meter_dict`` and created on first use.

    :param loss: the summed loss, e.g. ``torch.stack(loss_list).sum()``;
        only its ``.item()`` method is used here
    :return: None
    """
    meter_name = self.cfg.name
    registry = self.meter_dict

    # Lazily register the meter the first time a loss is stored.
    # (The original note states that Meter is RecentAverageMeter; the import
    # is outside this view.)
    if meter_name not in registry:
        registry[meter_name] = Meter(name=meter_name)

    # Feed the whole (summed) loss to the meter as a Python scalar.
    loss_value = loss.item()
    registry[meter_name].update(loss_value)
def store_score_accuracy(self, acc, accuracy_name):
    """
    Push an accuracy value into the meter named ``accuracy_name``.

    The meter is looked up in ``self.meter_dict`` and created on first use.

    :param acc: accuracy value; only its ``.item()`` method is used here
    :param accuracy_name: key of the meter in ``self.meter_dict`` (also
        passed to ``Meter`` as its name)
    :return: None
    """
    registry = self.meter_dict

    # Lazily register the meter the first time this accuracy is stored.
    if accuracy_name not in registry:
        registry[accuracy_name] = Meter(name=accuracy_name)

    # Feed the current accuracy to the meter as a Python scalar.
    acc_value = acc.item()
    registry[accuracy_name].update(acc_value)
def store_other_items(self, avg_prob, acc, avg_hit_prob, avg_unhit_prob, prefix: str):
    """
    Record four probability/accuracy statistics in ``self.meter_dict``.

    Each value is stored under ``prefix`` + suffix, where the suffixes are
    'avg_prob', 'acc', 'avg_hit_prob' and 'avg_unhit_prob'. A meter that does
    not exist yet is created with the percentage format ``'{:.2%}'``.

    Every argument may be either an object exposing ``.item()`` (for example
    a tensor scalar) or a plain Python number. Previously only the two
    hit/unhit values accepted a plain value, and only when its type was
    exactly ``float``, so an ``int`` such as ``0`` raised ``AttributeError``.

    :param avg_prob: average probability
    :param acc: accuracy
    :param avg_hit_prob: average hit probability
    :param avg_unhit_prob: average unhit probability
    :param prefix: string prepended to every meter key
    :return: None
    """
    def as_scalar(value):
        # Duck-typed on purpose: anything with .item() is unwrapped and any
        # plain number is passed through unchanged. This keeps the old
        # behaviour for exact floats and for objects that provide .item().
        return value.item() if hasattr(value, 'item') else value

    stats = (
        ('avg_prob', avg_prob),
        ('acc', acc),
        ('avg_hit_prob', avg_hit_prob),
        ('avg_unhit_prob', avg_unhit_prob),
    )
    for suffix, value in stats:
        key = prefix + suffix
        if key not in self.meter_dict:
            self.meter_dict[key] = Meter(name=key, fmt='{:.2%}')
        self.meter_dict[key].update(as_scalar(value))
def may_calculate_part_loss(self, loss_list):
    """
    Track each part loss in its own meter when there is more than one part.

    Meters are keyed by ``self.part_fmt.format(k)`` with ``k`` counting the
    parts from 1, and are created on first use. When ``loss_list`` holds zero
    or one entry nothing is recorded.

    :param loss_list: the per-part losses; each entry must provide ``.item()``
    :return: None
    """
    # A single (or missing) part needs no per-part breakdown.
    if len(loss_list) <= 1:
        return

    for part_no, part_loss in enumerate(loss_list, start=1):
        key = self.part_fmt.format(part_no)
        # Create the meter for this part the first time it is seen.
        if key not in self.meter_dict:
            self.meter_dict[key] = Meter(name=key)
        # Store the current loss of this part as a Python scalar.
        self.meter_dict[key].update(part_loss.item())