import numpy as np
import torch

# neural-pipeline imports; exact module paths may vary between library versions
from neural_pipeline import Trainer, TrainConfig, FileStructManager
from neural_pipeline.builtin.models.albunet import resnet18
from neural_pipeline.builtin.monitors.tensorboard import TensorboardMonitor
from neural_pipeline.monitoring import LogMonitor


def train():
    # train_stage and val_stage (TrainStage / ValidationStage instances)
    # are defined elsewhere in the original example
    model = resnet18(classes_num=1, in_channels=3, pretrained=True)
    train_config = TrainConfig([train_stage, val_stage],
                               torch.nn.BCEWithLogitsLoss(),
                               torch.optim.Adam(model.parameters(), lr=1e-4))

    file_struct_manager = FileStructManager(base_dir='data', is_continue=False)

    trainer = Trainer(model, train_config, file_struct_manager,
                      torch.device('cuda:0')).set_epoch_num(2)

    tensorboard = TensorboardMonitor(file_struct_manager,
                                     is_continue=False,
                                     network_name='PortraitSegmentation')
    log = LogMonitor(file_struct_manager).write_final_metrics()
    trainer.monitor_hub.add_monitor(tensorboard).add_monitor(log)
    # save a checkpoint whenever the monitored value (mean training loss) improves
    trainer.enable_best_states_saving(
        lambda: np.mean(train_stage.get_losses()))

    # halve the learning rate when the monitored loss stops improving
    # for 10 consecutive epochs
    trainer.enable_lr_decaying(
        coeff=0.5,
        patience=10,
        target_val_clbk=lambda: np.mean(train_stage.get_losses()))
    # log the current learning rate to TensorBoard after every epoch
    trainer.add_on_epoch_end_callback(
        lambda: tensorboard.update_scalar('params/lr',
                                          trainer.data_processor().get_lr()))
    trainer.train()
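
Example #1 above uses train_stage and val_stage without defining them. A minimal sketch of how they could be built with the library's DataProducer; the dataset objects here are hypothetical placeholders for any dataset yielding {'data': ..., 'target': ...} items:

from neural_pipeline import DataProducer, TrainStage, ValidationStage

# train_dataset / val_dataset are hypothetical placeholders
train_stage = TrainStage(DataProducer([train_dataset], batch_size=4, num_workers=2))
val_stage = ValidationStage(DataProducer([val_dataset], batch_size=4, num_workers=2))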
Example #2

from typing import Type


def train(config_type: Type[BaseSegmentationTrainConfig]):
    # BaseSegmentationTrainConfig and its experiment_dir attribute come from
    # the surrounding project, not from the library itself
    fsm = FileStructManager(base_dir=config_type.experiment_dir, is_continue=False)

    # per-stage dataset index files, as supplied in the original example
    config = config_type({'train': ['train_seg.npy'], 'val': 'val_seg.npy'})

    trainer = Trainer(config, fsm, device=torch.device('cuda'))
    tensorboard = TensorboardMonitor(fsm, is_continue=False)
    trainer.monitor_hub.add_monitor(tensorboard)

    trainer.set_epoch_num(300)
    trainer.enable_lr_decaying(
        coeff=0.5,
        patience=10,
        target_val_clbk=lambda: np.mean(config.val_stage.get_losses()))
    trainer.add_on_epoch_end_callback(
        lambda: tensorboard.update_scalar('params/lr',
                                          trainer.data_processor().get_lr()))
    trainer.enable_best_states_saving(
        lambda: np.mean(config.val_stage.get_losses()))
    # stop once LR decay has pushed the learning rate below 1e-6
    trainer.add_stop_rule(lambda: trainer.data_processor().get_lr() < 1e-6)

    trainer.train()
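
A call site for this function passes the config class itself; the subclass name below is a hypothetical stand-in:

# UNetSegmentationTrainConfig is a hypothetical subclass of
# BaseSegmentationTrainConfig defined elsewhere in the project
train(UNetSegmentationTrainConfig)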
Example #3
    def test_lr_decaying(self):
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
        model = SimpleModel()
        metrics_processor = MetricsProcessor()
        stages = [
            TrainStage(
                TestDataProducer([[{
                    'data': torch.rand(1, 3),
                    'target': torch.rand(1)
                } for _ in range(20)]]), metrics_processor),
            ValidationStage(
                TestDataProducer([[{
                    'data': torch.rand(1, 3),
                    'target': torch.rand(1)
                } for _ in range(20)]]), metrics_processor)
        ]
        trainer = Trainer(
            model,
            TrainConfig(stages, SimpleLoss(),
                        torch.optim.SGD(model.parameters(), lr=0.1)),
            fsm).set_epoch_num(10)

        def target_value_clbk() -> float:
            # constant target: the monitored value never improves,
            # so the LR decays every `patience` epochs
            return 1

        trainer.enable_lr_decaying(0.5, 3, target_value_clbk)
        trainer.train()

        # 10 epochs with patience=3 and a never-improving target
        # give exactly three decays: lr = 0.1 * 0.5 ** 3
        self.assertAlmostEqual(trainer.data_processor().get_lr(),
                               0.1 * (0.5**3),
                               delta=1e-6)
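
SimpleModel and SimpleLoss are test fixtures not shown here. One possible minimal implementation, assuming the (1, 3) inputs and (1,) targets used above:

import torch

class SimpleModel(torch.nn.Module):
    """Tiny stand-in network: one linear layer from 3 features to 1."""
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(3, 1)

    def forward(self, x):
        return self.fc(x)

class SimpleLoss(torch.nn.Module):
    """Mean squared error, written out explicitly."""
    def forward(self, output, target):
        return ((output - target) ** 2).mean()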
Example #4
def train():
    # PoseNetTrainConfig and EPOCH_NUM are defined elsewhere in the project
    train_config = PoseNetTrainConfig()

    file_struct_manager = FileStructManager(
        base_dir=PoseNetTrainConfig.experiment_dir, is_continue=False)

    trainer = Trainer(train_config, file_struct_manager, torch.device('cuda'))
    trainer.set_epoch_num(EPOCH_NUM)

    tensorboard = TensorboardMonitor(file_struct_manager, is_continue=False)
    log = LogMonitor(file_struct_manager).write_final_metrics()
    trainer.monitor_hub.add_monitor(tensorboard).add_monitor(log)
    trainer.enable_best_states_saving(
        lambda: np.mean(train_config.val_stage.get_losses()))

    trainer.enable_lr_decaying(
        coeff=0.5,
        patience=10,
        target_val_clbk=lambda: np.mean(train_config.val_stage.get_losses()))
    trainer.add_on_epoch_end_callback(
        lambda: tensorboard.update_scalar('params/lr',
                                          trainer.data_processor().get_lr()))
    trainer.train()
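
For completeness, a minimal entry point for this example; EPOCH_NUM and PoseNetTrainConfig belong to the original project, so the value below is only a placeholder:

EPOCH_NUM = 100  # placeholder; the original project defines its own value

if __name__ == '__main__':
    train()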