예제 #1
0
 def __init_meters__(self) -> List[Union[str, List[str]]]:
     """Create the training/validation meters and return the report layout.

     Every dice meter is sized by ``self.model.arch_dict["num_classes"]`` and
     reports over ``self.axis``. The resulting ``MeterInterface`` is stored on
     ``self.METERINTERFACE``.

     Returns:
         The columns to report: a plain string is a single column, and a
         nested list groups the DSC1-DSC3 columns of one dice meter.
     """
     num_classes = self.model.arch_dict["num_classes"]
     axes = self.axis

     def dsc_columns(meter_name):
         # NOTE(review): always three DSC columns, independent of num_classes
         # and self.axis -- confirm this matches what the dice meters emit.
         return [f"{meter_name}_DSC{k}" for k in (1, 2, 3)]

     self.METERINTERFACE = MeterInterface(
         {
             "lr": AverageValueMeter(),
             "trloss": AverageValueMeter(),
             "trdice": SliceDiceMeter(C=num_classes, report_axises=axes),
             "valloss": AverageValueMeter(),
             "valdice": SliceDiceMeter(C=num_classes, report_axises=axes),
             "valbdice": BatchDiceMeter(C=num_classes, report_axises=axes),
         }
     )
     return [
         "trloss_mean",
         dsc_columns("trdice"),
         "valloss_mean",
         dsc_columns("valdice"),
         dsc_columns("valbdice"),
         "lr_mean",
     ]
예제 #2
0
 def setUp(self) -> None:
     """Build a meter interface over one average meter and two 2-class slice dice meters."""
     super().setUp()
     config = dict(
         avg1=AverageValueMeter(),
         dice1=SliceDiceMeter(C=2),
         dice2=SliceDiceMeter(C=2),
     )
     self._meter_config = config
     self.meters = MeterInterface(self._meter_config)
예제 #3
0
 def setUp(self) -> None:
     """Build a meter interface with a Haussdorff meter plus slice and batch dice meters.

     All three meters use ``C=3``; both dice meters report axes 1 and 2.
     """
     super().setUp()
     num_classes = 3
     # Each dice meter gets its own axes list, as in the original literals.
     self.meter = MeterInterface(
         {
             "hd_meter": HaussdorffDistance(C=num_classes),
             "s_dice": SliceDiceMeter(C=num_classes, report_axises=[1, 2]),
             "b_dice": BatchDiceMeter(C=num_classes, report_axises=[1, 2]),
         }
     )
    def test_resume(self):
        """Exercise ``MeterInterface`` state-dict round-tripping.

        A first interface is stepped through three epochs with no data, then
        ten populated epochs, with ``avg2`` registered at ``register_epoch``.
        Its ``state_dict()`` is loaded into a second interface whose meter set
        also contains ``avg3``, and that interface is stepped for ten more
        epochs. Both summaries are printed; nothing is asserted.
        """
        register_epoch = 2  # epoch at which "avg2" joins the first interface

        def add_dice_batch(interface):
            # One random (1, 4, 224, 224) prediction with integer labels in [0, 4).
            interface["dice1"].add(
                torch.randn(1, 4, 224, 224), torch.randint(0, 4, size=(1, 224, 224))
            )

        meterinterface = MeterInterface(
            {"avg1": AverageValueMeter(), "dice1": SliceDiceMeter()}
        )
        # Advance a few epochs before any value has been recorded.
        for _ in range(3):
            meterinterface.step()

        for epoch in range(10):
            if epoch == register_epoch:
                meterinterface.register_meter("avg2", AverageValueMeter())
            for _ in range(10):
                meterinterface["avg1"].add(1)
                add_dice_batch(meterinterface)
                # "avg2" exists only after registration. An explicit check
                # replaces a bare `except: pass`, which also hid real failures
                # (and even KeyboardInterrupt) raised from `add`.
                if epoch >= register_epoch:
                    meterinterface["avg2"].add(2)
            meterinterface.step()
        print(meterinterface.summary())
        state_dict = meterinterface.state_dict()

        # Resume into an interface with an extra meter ("avg3") that the
        # saved state does not contain.
        meterinterface2 = MeterInterface(
            {
                "avg1": AverageValueMeter(),
                "avg2": AverageValueMeter(),
                "dice1": SliceDiceMeter(),
                "avg3": AverageValueMeter(),
            }
        )
        meterinterface2.load_state_dict(state_dict)

        for _ in range(10):
            for _ in range(10):
                meterinterface2["avg3"].add(1)
                add_dice_batch(meterinterface2)
            meterinterface2.step()
        print(meterinterface2.summary())
 def test_meter_interface(self):
     """Drive a ``MeterInterface`` through ten epochs, registering a meter midway.

     ``avg1`` and ``dice1`` are fed every iteration. ``avg2`` is registered
     at ``register_epoch`` and fed from then on. The summary is printed
     before and after; nothing is asserted.
     """
     register_epoch = 2  # epoch at which "avg2" is added
     meterinterface = MeterInterface(
         {"avg1": AverageValueMeter(), "dice1": SliceDiceMeter()}
     )
     print(meterinterface.summary())
     for epoch in range(10):
         if epoch == register_epoch:
             meterinterface.register_meter("avg2", AverageValueMeter())
         for _ in range(10):
             meterinterface["avg1"].add(1)
             # One random (1, 4, 224, 224) prediction with integer labels in [0, 4).
             meterinterface["dice1"].add(
                 torch.randn(1, 4, 224, 224), torch.randint(0, 4, size=(1, 224, 224))
             )
             # "avg2" exists only after registration. An explicit check replaces
             # a bare `except: pass`, which also hid real failures from `add`.
             if epoch >= register_epoch:
                 meterinterface["avg2"].add(2)
         meterinterface.step()
     print(meterinterface.summary())
 def setUp(self) -> None:
     """Prepare the meter configuration and loss criterion used by the tests."""
     # NOTE(review): unlike the other setUp methods here, this one does not call
     # super().setUp() -- confirm the base class needs no setup.
     # Meter name -> meter instance. SliceDiceMeter(C=5) presumably means
     # 5 classes; verify against the data used by the tests.
     self.meter_config = {
         "loss": AverageValueMeter(),
         "tra_dice": SliceDiceMeter(C=5),
     }
     self.criterion = nn.CrossEntropyLoss()