Esempio n. 1
0
def main(args):
    """Retrain an AddressParser on a pickled training set, then evaluate it.

    Training and evaluation share the same logging directory. The test
    phase loads the checkpoint named "best" from that directory.

    Args:
        args: parsed command-line namespace. This function reads
            ``model_type``, ``train_dataset_path``, ``test_dataset_path``,
            ``epochs``, ``batch_size`` and ``learning_rate``.
    """
    # NOTE(review): "chekpoints" looks like a typo of "checkpoints"; kept
    # byte-for-byte because other tooling may rely on this exact path.
    log_dir = f"./chekpoints/{args.model_type}"

    # NOTE(review): device=0 presumably selects the first GPU -- confirm.
    parser = AddressParser(model_type=args.model_type, device=0)

    # 0.8 is passed positionally; presumably the train/validation split
    # ratio -- confirm against AddressParser.retrain.
    parser.retrain(PickleDatasetContainer(args.train_dataset_path),
                   0.8,
                   epochs=args.epochs,
                   batch_size=args.batch_size,
                   num_workers=6,
                   learning_rate=args.learning_rate,
                   callbacks=[StepLR(step_size=20)],
                   logging_path=log_dir)

    parser.test(PickleDatasetContainer(args.test_dataset_path),
                batch_size=args.batch_size,
                num_workers=4,
                logging_path=log_dir,
                checkpoint="best")
Esempio n. 2
0
def main(args):
    """Train and/or test an AddressParser depending on ``args.mode``.

    ``args.mode`` selects the phases. With "train" or "both", the model is
    retrained on the pickled training set. With "test" or "both", it is
    evaluated on the pickled test set. A test-only run resolves its
    checkpoint through ``handle_pre_trained_checkpoint``. A combined run
    tests the "best" checkpoint produced by training.

    Args:
        args: parsed command-line namespace. This function reads
            ``model_type``, ``mode``, ``train_dataset_path`` and
            ``test_dataset_path``.
    """
    # NOTE(review): "chekpoints" looks like a typo of "checkpoints"; kept
    # byte-for-byte because other tooling may rely on this exact path.
    log_dir = f"./chekpoints/{args.model_type}"

    parser = AddressParser(model_type=args.model_type, device=0)

    should_train = args.mode in ("train", "both")
    should_test = args.mode in ("test", "both")

    if should_train:
        training_data = PickleDatasetContainer(args.train_dataset_path)
        scheduler = StepLR(step_size=20)
        # Hyperparameters are fixed here rather than read from args.
        parser.retrain(training_data,
                       0.8,
                       epochs=100,
                       batch_size=1024,
                       num_workers=6,
                       learning_rate=0.001,
                       callbacks=[scheduler],
                       logging_path=log_dir)

    if not should_test:
        return

    # The test container is built before the checkpoint is resolved,
    # matching the original side-effect order.
    testing_data = PickleDatasetContainer(args.test_dataset_path)
    chosen_checkpoint = (handle_pre_trained_checkpoint(args.model_type)
                         if args.mode == "test" else "best")

    parser.test(testing_data,
                batch_size=2048,
                num_workers=4,
                logging_path=log_dir,
                checkpoint=chosen_checkpoint)
Esempio n. 3
0
 def test_exception_is_thrown_on_optimizer_argument(self):
     """Constructing StepLR with an optimizer passed positionally raises ValueError."""
     self.assertRaises(ValueError, StepLR, self.optimizer, step_size=3)
Esempio n. 4
0
 def test_step_lr_integration(self):
     """Run the shared callback-integration fit helper with a StepLR(step_size=3)."""
     self._fit_with_callback_integration(StepLR(step_size=3))