예제 #1
0
    def store_other_items(self, hard_an, hard_ap, prefix: str):
        """Record hard-triplet statistics in meters whose keys start with ``prefix``.

        :param hard_an: hardest (anchor, negative) distances; compared
            element-wise with ``hard_ap`` (presumably torch tensors of equal
            shape -- TODO confirm against caller)
        :param hard_ap: hardest (anchor, positive) distances
        :param prefix: string prepended to every meter key
        :return: None

        Meters updated (each created on first use):
            ``prefix + 'prec'``: fraction of entries with hard_an > hard_ap
            ``prefix + 'sm'``: fraction with hard_an > hard_ap + self.cfg.margin
            ``prefix + 'h_ap'``: mean of hard_ap
            ``prefix + 'h_an'``: mean of hard_an
        """
        def meter_for(suffix, **meter_kwargs):
            # Return the meter stored under ``prefix + suffix``, creating it
            # with the given constructor kwargs if it does not exist yet.
            name = prefix + suffix
            if name not in self.meter_dict:
                self.meter_dict[name] = Meter(name=name, **meter_kwargs)
            return self.meter_dict[name]

        percent = '{:.2%}'
        # Fraction of triplets where the negative is farther than the positive.
        meter_for('prec', fmt=percent).update(
            (hard_an > hard_ap).float().mean().item())
        # Same comparison, but the gap must also exceed the configured margin.
        meter_for('sm', fmt=percent).update(
            (hard_an > hard_ap + self.cfg.margin).float().mean().item())
        # Mean distances; these meters are built without an explicit fmt.
        meter_for('h_ap').update(hard_ap.mean().item())
        meter_for('h_an').update(hard_an.mean().item())
예제 #2
0
    def store_calculate_loss(self, loss):
        """Record the total loss in the meter keyed by ``self.cfg.name``.

        :param loss: the summed loss, described by the original author as
            ``torch.stack(loss_list).sum()``; its ``.item()`` value is recorded
        :return: None

        The meter is created on first use. The original author notes that
        ``Meter`` is a RecentAverageMeter (averages recent values) -- confirm
        against the import.
        """
        meter_key = self.cfg.name
        meter = self.meter_dict.get(meter_key)
        if meter is None:
            # First time this loss is seen: create and register its meter.
            meter = Meter(name=meter_key)
            self.meter_dict[meter_key] = meter
        meter.update(loss.item())
예제 #3
0
    def store_score_accuracy(self, acc, accuracy_name):
        """Record an accuracy value in the meter named ``accuracy_name``.

        :param acc: accuracy object exposing ``.item()`` (presumably a 0-d
            tensor -- confirm against caller)
        :param accuracy_name: key of the meter in ``self.meter_dict``; the
            meter is created on first use
        :return: None
        """
        meters = self.meter_dict
        if accuracy_name not in meters:
            # Lazily register a meter the first time this accuracy appears.
            meters[accuracy_name] = Meter(name=accuracy_name)
        current = meters[accuracy_name]
        # Store the current accuracy value (this method records accuracy, not loss).
        current.update(acc.item())
예제 #4
0
 def store_other_items(self, avg_prob, acc, avg_hit_prob, avg_unhit_prob,
                       prefix: str):
     """Record probability and accuracy statistics in percent-formatted meters.

     Each value may be a plain Python number (float, int, bool) or an object
     exposing ``.item()`` (e.g. a 0-d tensor or NumPy scalar); the latter is
     unwrapped with ``.item()`` before being recorded.

     :param avg_prob: average probability, stored under ``prefix + 'avg_prob'``
     :param acc: accuracy, stored under ``prefix + 'acc'``
     :param avg_hit_prob: average "hit" probability, stored under
         ``prefix + 'avg_hit_prob'``
     :param avg_unhit_prob: average "unhit" probability, stored under
         ``prefix + 'avg_unhit_prob'``
     :param prefix: string prepended to every meter key
     :return: None
     """
     def to_scalar(value):
         # Tensors / NumPy scalars expose ``.item()``; plain Python numbers do
         # not and are passed through unchanged. Using hasattr instead of
         # ``type(value) == float`` also accepts ints such as a literal 0.
         return value.item() if hasattr(value, 'item') else value

     def record(suffix, value):
         # Fetch or lazily create the meter for ``prefix + suffix`` (all of
         # these are displayed as percentages), then record ``value``.
         key = prefix + suffix
         if key not in self.meter_dict:
             self.meter_dict[key] = Meter(name=key, fmt='{:.2%}')
         self.meter_dict[key].update(to_scalar(value))

     record('avg_prob', avg_prob)
     record('acc', acc)
     record('avg_hit_prob', avg_hit_prob)
     record('avg_unhit_prob', avg_unhit_prob)
예제 #5
0
    def may_calculate_part_loss(self, loss_list):
        """Record each part's loss in its own meter when there is more than one part.

        :param loss_list: per-part loss objects exposing ``.item()``
            (presumably scalar tensors -- confirm against caller)
        :return: None

        Meter keys are ``self.part_fmt.format(k)`` with ``k`` counting from 1;
        each meter is created on first use. Nothing is recorded when
        ``loss_list`` has one element or fewer.
        NOTE(review): presumably the single-part case is skipped because it
        would duplicate the whole-loss meter -- confirm.
        """
        if len(loss_list) <= 1:
            return
        part_fmt = self.part_fmt
        for part_no, part_loss in enumerate(loss_list, start=1):
            # Build the key once per part instead of re-formatting it three times.
            key = part_fmt.format(part_no)
            if key not in self.meter_dict:
                self.meter_dict[key] = Meter(name=key)
            self.meter_dict[key].update(part_loss.item())