def test_train(self):
    """Training-mode path of TrainDataProcessor.

    Checks that predict(is_train=True) keeps the model in training mode and
    returns a grad-tracking result, that process_batch without a real loss
    raises NotImplementedError, and that with SimpleLoss a processed batch
    produces gradients and returns the loss value as a numpy array.
    """
    model = SimpleModel().train()
    # TrainConfig here gets a bare torch.nn.Module() as the "loss" — it has no
    # forward, so process_batch is expected to fail below.
    train_config = TrainConfig([], torch.nn.Module(), torch.optim.SGD(model.parameters(), lr=0.1))
    dp = TrainDataProcessor(model=model, train_config=train_config)

    # Preconditions: model lives on CPU and is in training mode.
    self.assertFalse(model.fc.weight.is_cuda)
    self.assertTrue(model.training)

    res = dp.predict({'data': torch.rand(1, 3)}, is_train=True)
    # is_train=True must not flip the model to eval, and the output must
    # participate in autograd (requires_grad set, no .grad before backward).
    self.assertTrue(model.training)
    self.assertTrue(res.requires_grad)
    self.assertIsNone(res.grad)

    # A bare torch.nn.Module() as loss cannot compute anything.
    with self.assertRaises(NotImplementedError):
        dp.process_batch(
            {
                'data': torch.rand(1, 3),
                'target': torch.rand(1)
            }, is_train=True)

    # Now with a usable loss: a full train step must backprop through it.
    loss = SimpleLoss()
    train_config = TrainConfig([], loss, torch.optim.SGD(model.parameters(), lr=0.1))
    dp = TrainDataProcessor(model=model, train_config=train_config)
    res = dp.process_batch(
        {
            'data': torch.rand(1, 3),
            'target': torch.rand(1)
        }, is_train=True)
    self.assertTrue(model.training)
    # Gradients must have reached the loss's internal module after backward.
    self.assertTrue(loss.module.requires_grad)
    self.assertIsNotNone(loss.module.grad)
    # process_batch returns the loss value, matching loss.res detached to numpy.
    self.assertTrue(np.array_equal(res, loss.res.data.numpy()))
def test_savig_best_states(self):
    """Best-state checkpoint is written only when the mean train loss improves.

    NOTE(review): 'savig' is a typo for 'saving'; the name is kept unchanged so
    existing test discovery/reporting is unaffected.
    """
    fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
    model = SimpleModel()
    metrics_processor = MetricsProcessor()
    stages = [TrainStage(TestDataProducer([[{'data': torch.rand(1, 3), 'target': torch.rand(1)} for _ in list(range(20))]]), metrics_processor)]
    # Best-state saving keyed on the mean loss of the single train stage.
    trainer = Trainer(TrainConfig(model, stages, SimpleLoss(), torch.optim.SGD(model.parameters(), lr=0.1)), fsm).set_epoch_num(3).enable_best_states_saving(lambda: np.mean(stages[0].get_losses()))

    checkpoint_file = os.path.join(self.base_dir, 'checkpoints', 'last', 'last_checkpoint.zip')
    best_checkpoint_file = os.path.join(self.base_dir, 'checkpoints', 'best', 'best_checkpoint.zip')

    class Val:
        # Mutable holder so the epoch-end closure can carry the previous
        # epoch's loss across invocations.
        def __init__(self):
            self.v = None

    first_val = Val()

    def on_epoch_end(val):
        # Improvement case (not the first epoch): a 'best' checkpoint must
        # exist; remove it so a later improvement can be verified again.
        if val.v is not None and np.mean(stages[0].get_losses()) < val.v:
            self.assertTrue(os.path.exists(best_checkpoint_file))
            os.remove(best_checkpoint_file)
            val.v = np.mean(stages[0].get_losses())
            return
        # First epoch or no improvement: only the 'last' checkpoint may exist.
        val.v = np.mean(stages[0].get_losses())
        self.assertTrue(os.path.exists(checkpoint_file))
        self.assertFalse(os.path.exists(best_checkpoint_file))
        os.remove(checkpoint_file)

    trainer.add_on_epoch_end_callback(lambda: on_epoch_end(first_val))
    trainer.train()
def test_initialisation(self):
    """Constructing a TrainDataProcessor from a valid TrainConfig must not raise."""
    model = SimpleModel()
    train_config = TrainConfig(model, [], torch.nn.Module(),
                               torch.optim.SGD(model.parameters(), lr=0.1))
    try:
        TrainDataProcessor(train_config=train_config)
    # Narrowed from a bare `except:`, which would also swallow
    # SystemExit/KeyboardInterrupt and mask a Ctrl-C during the test run.
    except Exception:
        self.fail('DataProcessor initialisation raises exception')
def train():
    """Run a 2-epoch training session with tensorboard + log monitoring,
    best-state saving and LR decay on loss plateau."""
    net = resnet18(classes_num=1, in_channels=3, pretrained=True)
    config = TrainConfig([train_stage, val_stage],
                         torch.nn.BCEWithLogitsLoss(),
                         torch.optim.Adam(net.parameters(), lr=1e-4))
    fsm = FileStructManager(base_dir='data', is_continue=False)

    trainer = Trainer(net, config, fsm, torch.device('cuda:0')).set_epoch_num(2)

    # Attach monitoring backends.
    tb_monitor = TensorboardMonitor(fsm, is_continue=False, network_name='PortraitSegmentation')
    log_monitor = LogMonitor(fsm).write_final_metrics()
    trainer.monitor_hub.add_monitor(tb_monitor).add_monitor(log_monitor)

    # Both best-state saving and LR decay track the mean train-stage loss.
    trainer.enable_best_states_saving(lambda: np.mean(train_stage.get_losses()))
    trainer.enable_lr_decaying(coeff=0.5, patience=10,
                               target_val_clbk=lambda: np.mean(train_stage.get_losses()))

    # Log the current learning rate at the end of each epoch.
    trainer.add_on_epoch_end_callback(
        lambda: tb_monitor.update_scalar('params/lr', trainer.data_processor().get_lr()))
    trainer.train()
def test_lr_decaying(self):
    """LR must shrink by the decay coefficient each time patience is exhausted."""
    fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
    net = SimpleModel()
    mp = MetricsProcessor()

    def make_producer():
        # One epoch of 20 random (data, target) samples.
        return TestDataProducer([[{'data': torch.rand(1, 3), 'target': torch.rand(1)}
                                  for _ in range(20)]])

    pipeline = [TrainStage(make_producer(), mp), ValidationStage(make_producer(), mp)]
    trainer = Trainer(net,
                      TrainConfig(pipeline, SimpleLoss(),
                                  torch.optim.SGD(net.parameters(), lr=0.1)),
                      fsm).set_epoch_num(10)

    # The monitored value never improves, so decay fires every `patience` epochs:
    # 10 epochs / patience 3 -> three decays of 0.5 each from the initial 0.1.
    trainer.enable_lr_decaying(0.5, 3, lambda: 1)
    trainer.train()

    self.assertAlmostEqual(trainer.data_processor().get_lr(),
                           0.1 * (0.5 ** 3), delta=1e-6)
def test_savig_states(self):
    """A 'last' checkpoint file must exist at the end of every epoch."""
    fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
    net = SimpleModel()
    mp = MetricsProcessor()
    producer = TestDataProducer([[{'data': torch.rand(1, 3), 'target': torch.rand(1)}
                                  for _ in range(20)]])
    trainer = Trainer(net,
                      TrainConfig([TrainStage(producer, mp)], SimpleLoss(),
                                  torch.optim.SGD(net.parameters(), lr=0.1)),
                      fsm).set_epoch_num(3)

    last_checkpoint = os.path.join(self.base_dir, 'checkpoints', 'last',
                                   'last_checkpoint.zip')

    def verify_and_clean():
        # Checkpoint must have been written this epoch; remove it so the
        # next epoch's write is verified independently.
        self.assertTrue(os.path.exists(last_checkpoint))
        os.remove(last_checkpoint)

    trainer.add_on_epoch_end_callback(verify_and_clean)
    trainer.train()
def test_base_ops(self):
    """Training with no stages configured must raise Trainer.TrainerException."""
    fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
    net = SimpleModel()
    # Empty stage list: the trainer has nothing to run.
    config = TrainConfig(net, [], torch.nn.L1Loss(),
                         torch.optim.SGD(net.parameters(), lr=1))
    trainer = Trainer(config, fsm)
    with self.assertRaises(Trainer.TrainerException):
        trainer.train()
def test_train(self):
    """A single-epoch run over train + validation stages completes without error."""
    fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
    net = SimpleModel()
    mp = MetricsProcessor()

    def make_producer():
        # One epoch of 20 random (data, target) samples.
        return TestDataProducer([[{'data': torch.rand(1, 3), 'target': torch.rand(1)}
                                  for _ in range(20)]])

    pipeline = [TrainStage(make_producer(), mp), ValidationStage(make_producer(), mp)]
    config = TrainConfig(net, pipeline, SimpleLoss(),
                         torch.optim.SGD(net.parameters(), lr=1))
    Trainer(config, fsm).set_epoch_num(1).train()
def test_predict(self):
    """predict() without is_train must switch the model to eval and return a
    gradient-free tensor."""
    net = SimpleModel().train()
    config = TrainConfig(net, [], torch.nn.Module(),
                         torch.optim.SGD(net.parameters(), lr=0.1))
    processor = TrainDataProcessor(train_config=config)

    # Preconditions: CPU weights, training mode still enabled.
    self.assertFalse(net.fc.weight.is_cuda)
    self.assertTrue(net.training)

    prediction = processor.predict({'data': torch.rand(1, 3)})

    # Inference path: model flipped to eval, output detached from autograd.
    self.assertFalse(net.training)
    self.assertFalse(prediction.requires_grad)
    self.assertIsNone(prediction.grad)
def test_continue_from_checkpoint(self):
    """Saving state and re-loading it must round-trip model weights and LR.

    For each optimizer kind: save_state without a CheckpointsManager raises
    Model.ModelException, and succeeds once one is set; the same contract
    holds for load(). After loading, the model weights and the learning rate
    must equal the saved ones.
    """
    def on_node(n1, n2):
        # Leaf comparator for dict_pair_recursive_bypass: tensors must match.
        self.assertTrue(np.array_equal(n1.numpy(), n2.numpy()))

    model = SimpleModel().train()
    loss = SimpleLoss()
    for optim in [
            torch.optim.SGD(model.parameters(), lr=0.1),
            torch.optim.Adam(model.parameters(), lr=0.1)
    ]:
        train_config = TrainConfig([], loss, optim)
        dp_before = TrainDataProcessor(model=model, train_config=train_config)
        before_state_dict = model.state_dict().copy()
        dp_before.update_lr(0.023)

        # Without a CheckpointsManager, saving must fail loudly.
        with self.assertRaises(Model.ModelException):
            dp_before.save_state()
        try:
            fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
            dp_before.set_checkpoints_manager(CheckpointsManager(fsm))
            dp_before.save_state()
        # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt
        # still propagate instead of being reported as a test failure.
        except Exception:
            self.fail(
                "Exception on saving state when 'CheckpointsManager' specified"
            )

        fsm = FileStructManager(base_dir=self.base_dir, is_continue=True)
        dp_after = TrainDataProcessor(model=model, train_config=train_config)

        # Symmetric contract for loading.
        with self.assertRaises(Model.ModelException):
            dp_after.load()
        try:
            cm = CheckpointsManager(fsm)
            dp_after.set_checkpoints_manager(cm)
            dp_after.load()
        except Exception:  # narrowed from bare `except:` (see above)
            self.fail('DataProcessor initialisation raises exception')

        after_state_dict = model.state_dict().copy()
        dict_pair_recursive_bypass(before_state_dict, after_state_dict, on_node)
        self.assertEqual(dp_before.get_lr(), dp_after.get_lr())

        shutil.rmtree(self.base_dir)
def test_prediction_notrain_output(self):
    """Inference-mode predict() returns plain tensors for both a standard
    model and one with dict-shaped I/O, with gradients disabled."""
    net = SimpleModel()
    config = TrainConfig(net, [], torch.nn.Module(),
                         torch.optim.SGD(net.parameters(), lr=0.1))
    processor = TrainDataProcessor(train_config=config)
    self.assertFalse(net.fc.weight.is_cuda)

    output = processor.predict({'data': torch.rand(1, 3)}, is_train=False)
    self.assertIs(type(output), torch.Tensor)

    # A model with dict-valued input must keep the dict structure on output.
    net = NonStandardIOModel()
    config = TrainConfig(net, [], torch.nn.Module(),
                         torch.optim.SGD(net.parameters(), lr=0.1))
    processor = TrainDataProcessor(train_config=config)
    self.assertFalse(net.fc.weight.is_cuda)

    output = processor.predict({'data': {'data1': torch.rand(1, 3),
                                         'data2': torch.rand(1, 3)}},
                               is_train=False)
    self.assertIs(type(output), dict)
    for key in ('res1', 'res2'):
        self.assertIn(key, output)
        self.assertIs(type(output[key]), torch.Tensor)

    # Inference must leave the model in eval mode with detached outputs.
    self.assertFalse(net.training)
    for key in ('res1', 'res2'):
        self.assertFalse(output[key].requires_grad)
        self.assertIsNone(output[key].grad)