Example 1
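This test checks best-state saving: a best checkpoint should be written only on epochs where the mean training loss improves, while the regular last checkpoint is written every epoch.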
    def test_saving_best_states(self):
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
        model = SimpleModel()
        metrics_processor = MetricsProcessor()
        stages = [TrainStage(TestDataProducer([[{'data': torch.rand(1, 3), 'target': torch.rand(1)}
                                                for _ in range(20)]]), metrics_processor)]
        trainer = Trainer(TrainConfig(model, stages, SimpleLoss(), torch.optim.SGD(model.parameters(), lr=0.1)),
                          fsm).set_epoch_num(3).enable_best_states_saving(lambda: np.mean(stages[0].get_losses()))

        checkpoint_file = os.path.join(self.base_dir, 'checkpoints', 'last', 'last_checkpoint.zip')
        best_checkpoint_file = os.path.join(self.base_dir, 'checkpoints', 'best', 'best_checkpoint.zip')

        class Val:
            def __init__(self):
                self.v = None

        first_val = Val()

        def on_epoch_end(val):
            if val.v is not None and np.mean(stages[0].get_losses()) < val.v:
                self.assertTrue(os.path.exists(best_checkpoint_file))
                os.remove(best_checkpoint_file)
                val.v = np.mean(stages[0].get_losses())
                return

            val.v = np.mean(stages[0].get_losses())

            self.assertTrue(os.path.exists(checkpoint_file))
            self.assertFalse(os.path.exists(best_checkpoint_file))
            os.remove(checkpoint_file)

        trainer.add_on_epoch_end_callback(lambda: on_epoch_end(first_val))
        trainer.train()
Example 2
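This test checks learning-rate decaying: the target-value callback never improves, so over 10 epochs the trainer is expected to halve the initial learning rate of 0.1 three times.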
    def test_lr_decaying(self):
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
        model = SimpleModel()
        metrics_processor = MetricsProcessor()
        stages = [
            TrainStage(
                TestDataProducer([[{
                    'data': torch.rand(1, 3),
                    'target': torch.rand(1)
                } for _ in range(20)]]), metrics_processor),
            ValidationStage(
                TestDataProducer([[{
                    'data': torch.rand(1, 3),
                    'target': torch.rand(1)
                } for _ in range(20)]]), metrics_processor)
        ]
        trainer = Trainer(
            model,
            TrainConfig(stages, SimpleLoss(),
                        torch.optim.SGD(model.parameters(), lr=0.1)),
            fsm).set_epoch_num(10)

        def target_value_clbk() -> float:
            return 1

        trainer.enable_lr_decaying(0.5, 3, target_value_clbk)
        trainer.train()

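        # the target value never improves, so the lr is expected to be decayed by the
        # 0.5 coefficient three times over the 10 epochs: 0.1 * 0.5 ** 3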
        self.assertAlmostEqual(trainer.data_processor().get_lr(),
                               0.1 * (0.5**3),
                               delta=1e-6)
Example 3
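This test checks that the 'last' checkpoint file is written at the end of every epoch.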
    def test_saving_states(self):
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
        model = SimpleModel()
        metrics_processor = MetricsProcessor()
        stages = [
            TrainStage(
                TestDataProducer([[{
                    'data': torch.rand(1, 3),
                    'target': torch.rand(1)
                } for _ in range(20)]]), metrics_processor)
        ]
        trainer = Trainer(
            model,
            TrainConfig(stages, SimpleLoss(),
                        torch.optim.SGD(model.parameters(), lr=0.1)),
            fsm).set_epoch_num(3)

        checkpoint_file = os.path.join(self.base_dir, 'checkpoints', 'last',
                                       'last_checkpoint.zip')

        def on_epoch_end():
            self.assertTrue(os.path.exists(checkpoint_file))
            os.remove(checkpoint_file)

        trainer.add_on_epoch_end_callback(on_epoch_end)
        trainer.train()
Example 4
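A minimal end-to-end run: a single epoch of training and validation on random data.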
    def test_train(self):
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
        model = SimpleModel()
        metrics_processor = MetricsProcessor()
        stages = [TrainStage(TestDataProducer([[{'data': torch.rand(1, 3), 'target': torch.rand(1)}
                                                for _ in range(20)]]), metrics_processor),
                  ValidationStage(TestDataProducer([[{'data': torch.rand(1, 3), 'target': torch.rand(1)}
                                                     for _ in range(20)]]), metrics_processor)]
        Trainer(TrainConfig(model, stages, SimpleLoss(), torch.optim.SGD(model.parameters(), lr=1)), fsm) \
            .set_epoch_num(1).train()
Example 5
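An excerpt from a segmentation training script: it defines a metrics processor with Jaccard and Dice metrics, builds train and validation stages (with hard negative mining on the train stage), and assembles the TrainConfig, Trainer, and monitors inside train().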

class SegmentationMetricsProcessor(MetricsProcessor):
    def __init__(self, stage_name: str):
        super().__init__()
        self.add_metrics_group(MetricsGroup(stage_name).add(JaccardMetric()).add(DiceMetric()))


###################################
# define train config and train model
###################################

train_data_producer = DataProducer([train_dataset], batch_size=2, num_workers=3)
val_data_producer = DataProducer([val_dataset], batch_size=2, num_workers=3)

train_stage = TrainStage(train_data_producer, SegmentationMetricsProcessor('train')).enable_hard_negative_mining(0.1)
val_metrics_processor = SegmentationMetricsProcessor('validation')
val_stage = ValidationStage(val_data_producer, val_metrics_processor)


def train():
    model = resnet18(classes_num=1, in_channels=3, pretrained=True)
    train_config = TrainConfig(model, [train_stage, val_stage], torch.nn.BCEWithLogitsLoss(),
                               torch.optim.Adam(model.parameters(), lr=1e-4))

    file_struct_manager = FileStructManager(base_dir='data', is_continue=False)

    trainer = Trainer(train_config, file_struct_manager, torch.device('cuda:0')).set_epoch_num(2)

    tensorboard = TensorboardMonitor(file_struct_manager, is_continue=False, network_name='PortraitSegmentation')
    log = LogMonitor(file_struct_manager).write_final_metrics()
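    # NOTE: the example appears to be cut off here. An assumed completion, based on the
    # monitors created above, would attach them to the trainer and start training
    # (the monitor_hub API is an assumption, not confirmed by this excerpt):
    trainer.monitor_hub.add_monitor(tensorboard).add_monitor(log)
    trainer.train()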