def __init_meters__(self) -> List[Union[str, List[str]]]:
    """Create the meters for this object and return the summary keys.

    Builds a ``MeterInterface`` (stored on ``self.METERINTERFACE``) holding
    average-value meters for the learning rate and the train/validation
    losses, slice-wise dice meters for training and validation, and a
    batch-wise dice meter for validation.

    Returns:
        A list of summary key names; nested lists group the per-class dice
        keys of a single meter. The dice keys are fixed to ``DSC1``..``DSC3``
        regardless of ``num_classes``.
        NOTE(review): presumably these are the columns to report/plot --
        confirm against the caller.
    """
    num_classes = self.model.arch_dict["num_classes"]
    axes = self.axis

    def dsc_keys(prefix: str) -> List[str]:
        # Per-class dice key names of one meter, e.g. "trdice_DSC1".
        return [f"{prefix}_DSC{index}" for index in (1, 2, 3)]

    # Insertion order is kept identical to the original configuration.
    meters = {}
    meters["lr"] = AverageValueMeter()
    meters["trloss"] = AverageValueMeter()
    meters["trdice"] = SliceDiceMeter(C=num_classes, report_axises=axes)
    meters["valloss"] = AverageValueMeter()
    meters["valdice"] = SliceDiceMeter(C=num_classes, report_axises=axes)
    meters["valbdice"] = BatchDiceMeter(C=num_classes, report_axises=axes)
    self.METERINTERFACE = MeterInterface(meters)

    return [
        "trloss_mean",
        dsc_keys("trdice"),
        "valloss_mean",
        dsc_keys("valdice"),
        dsc_keys("valbdice"),
        "lr_mean",
    ]
def setUp(self) -> None:
    """Prepare a ``MeterInterface`` fixture for each test.

    The configuration (kept on ``self._meter_config``) holds one
    ``AverageValueMeter`` and two ``SliceDiceMeter`` instances built with
    ``C=2``; the interface wrapping it is stored on ``self.meters``.
    """
    super().setUp()
    self._meter_config = dict(
        avg1=AverageValueMeter(),
        dice1=SliceDiceMeter(C=2),
        dice2=SliceDiceMeter(C=2),
    )
    self.meters = MeterInterface(self._meter_config)
def setUp(self) -> None:
    """Prepare a ``MeterInterface`` with a Hausdorff and two dice meters.

    All three meters are built with ``C=3``; both dice meters report on
    axes ``[1, 2]``. The interface is stored on ``self.meter``.
    """
    super().setUp()
    class_count = 3

    # Build the configuration step by step, keeping the original key order.
    config = {}
    config["hd_meter"] = HaussdorffDistance(C=class_count)
    config["s_dice"] = SliceDiceMeter(C=class_count, report_axises=[1, 2])
    config["b_dice"] = BatchDiceMeter(C=class_count, report_axises=[1, 2])

    self.meter = MeterInterface(config)
def test_resume(self):
    """Exercise saving a ``MeterInterface`` state and loading it elsewhere.

    A first interface is stepped three times up front, then run for ten
    epochs (registering ``avg2`` at epoch 2). Its ``state_dict`` is loaded
    into a second interface that declares an additional ``avg3`` meter,
    which is then run for ten more epochs. The summaries are printed; the
    test makes no assertions and passes as long as nothing raises.
    """
    meterinterface = MeterInterface(
        {"avg1": AverageValueMeter(), "dice1": SliceDiceMeter()}
    )
    # Advance the interface a few epochs before any value is recorded.
    meterinterface.step()
    meterinterface.step()
    meterinterface.step()

    for epoch in range(10):
        if epoch == 2:
            meterinterface.register_meter("avg2", AverageValueMeter())
        for _ in range(10):
            meterinterface["avg1"].add(1)
            meterinterface["dice1"].add(
                torch.randn(1, 4, 224, 224),
                torch.randint(0, 4, size=(1, 224, 224)),
            )
            # "avg2" is only registered from epoch 2 on, so this add is
            # best-effort. Catch Exception rather than using a bare
            # ``except`` so KeyboardInterrupt/SystemExit still propagate.
            try:
                meterinterface["avg2"].add(2)
            except Exception:
                pass
        meterinterface.step()
    print(meterinterface.summary())

    state_dict = meterinterface.state_dict()

    # The resumed interface declares one meter ("avg3") that is absent from
    # the saved state.
    meterinterface2 = MeterInterface(
        {
            "avg1": AverageValueMeter(),
            "avg2": AverageValueMeter(),
            "dice1": SliceDiceMeter(),
            "avg3": AverageValueMeter(),
        }
    )
    meterinterface2.load_state_dict(state_dict)

    for epoch in range(10):
        for _ in range(10):
            meterinterface2["avg3"].add(1)
            meterinterface2["dice1"].add(
                torch.randn(1, 4, 224, 224),
                torch.randint(0, 4, size=(1, 224, 224)),
            )
        meterinterface2.step()
    print(meterinterface2.summary())
def test_meter_interface(self):
    """Run a ``MeterInterface`` for ten epochs, registering a meter midway.

    The interface starts with ``avg1`` and ``dice1``; ``avg2`` is registered
    at epoch 2. The summary is printed before and after the run; the test
    makes no assertions and passes as long as nothing raises.
    """
    meterinterface = MeterInterface(
        {"avg1": AverageValueMeter(), "dice1": SliceDiceMeter()}
    )
    print(meterinterface.summary())

    for epoch in range(10):
        if epoch == 2:
            meterinterface.register_meter("avg2", AverageValueMeter())
        for _ in range(10):
            meterinterface["avg1"].add(1)
            meterinterface["dice1"].add(
                torch.randn(1, 4, 224, 224),
                torch.randint(0, 4, size=(1, 224, 224)),
            )
            # "avg2" is only registered from epoch 2 on, so this add is
            # best-effort. Catch Exception rather than using a bare
            # ``except`` so KeyboardInterrupt/SystemExit still propagate.
            try:
                meterinterface["avg2"].add(2)
            except Exception:
                pass
        meterinterface.step()
    print(meterinterface.summary())
def setUp(self) -> None:
    """Prepare the loss criterion and the meter configuration for each test.

    ``self.criterion`` is a cross-entropy loss; ``self.meter_config`` maps
    ``"loss"`` to an ``AverageValueMeter`` and ``"tra_dice"`` to a
    ``SliceDiceMeter`` built with ``C=5``.
    """
    self.criterion = nn.CrossEntropyLoss()
    self.meter_config = dict(
        loss=AverageValueMeter(),
        tra_dice=SliceDiceMeter(C=5),
    )