def main(args):
    """Retrain an AddressParser on a pickled dataset, then evaluate it.

    Dataset paths and the training hyper-parameters (epochs, batch size,
    learning rate) are read from ``args``. Evaluation is run with
    ``checkpoint="best"`` and the same logging path used for training.
    """
    # NOTE(review): the directory is spelled "chekpoints" (sic). It is kept
    # unchanged so previously written checkpoints/logs are still found.
    log_dir = f"./chekpoints/{args.model_type}"
    model = AddressParser(model_type=args.model_type, device=0)

    training_data = PickleDatasetContainer(args.train_dataset_path)
    model.retrain(
        training_data,
        0.8,  # positional retrain() argument; presumably the train/validation split ratio -- confirm
        epochs=args.epochs,
        batch_size=args.batch_size,
        num_workers=6,
        learning_rate=args.learning_rate,
        callbacks=[StepLR(step_size=20)],
        logging_path=log_dir,
    )

    testing_data = PickleDatasetContainer(args.test_dataset_path)
    model.test(
        testing_data,
        batch_size=args.batch_size,
        num_workers=4,
        logging_path=log_dir,
        checkpoint="best",
    )
def main(args):
    """Retrain and/or evaluate an AddressParser, depending on ``args.mode``.

    ``args.mode`` must be one of:

    * ``"train"`` -- retrain on ``args.train_dataset_path`` only;
    * ``"test"``  -- evaluate on ``args.test_dataset_path`` only, using the
      checkpoint returned by ``handle_pre_trained_checkpoint(args.model_type)``;
    * ``"both"``  -- retrain, then evaluate with ``checkpoint="best"``.

    Training and evaluation both use ``./chekpoints/<model_type>`` as their
    logging path.

    Raises:
        ValueError: if ``args.mode`` is not one of the modes listed above.
    """
    valid_modes = ("train", "test", "both")
    # Fail fast, before the (potentially expensive) model construction, so a
    # mistyped mode is reported instead of the script silently doing nothing.
    if args.mode not in valid_modes:
        raise ValueError(
            f"Unsupported mode {args.mode!r}; expected one of {valid_modes}."
        )

    # NOTE(review): "chekpoints" (sic) is kept unchanged for compatibility
    # with existing checkpoint/log directories.
    logging_path = f"./chekpoints/{args.model_type}"
    address_parser = AddressParser(model_type=args.model_type, device=0)

    if args.mode in ("train", "both"):
        train_container = PickleDatasetContainer(args.train_dataset_path)
        lr_scheduler = StepLR(step_size=20)
        address_parser.retrain(
            train_container,
            0.8,  # presumably the train/validation split ratio -- confirm against retrain()
            epochs=100,
            batch_size=1024,
            num_workers=6,
            learning_rate=0.001,
            callbacks=[lr_scheduler],
            logging_path=logging_path,
        )

    if args.mode in ("test", "both"):
        test_container = PickleDatasetContainer(args.test_dataset_path)
        if args.mode == "test":
            # No retraining happened in this run, so resolve a pre-trained
            # checkpoint for the requested model type.
            checkpoint = handle_pre_trained_checkpoint(args.model_type)
        else:
            # "both": retrain() just ran with the same logging path; presumably
            # "best" refers to the best checkpoint it saved there -- confirm.
            checkpoint = "best"
        address_parser.test(
            test_container,
            batch_size=2048,
            num_workers=4,
            logging_path=logging_path,
            checkpoint=checkpoint,
        )
def test_exception_is_thrown_on_optimizer_argument(self):
    # Passing an optimizer positionally to StepLR must raise ValueError.
    self.assertRaises(ValueError, StepLR, self.optimizer, step_size=3)
def test_step_lr_integration(self):
    # Run the shared fit-with-callback integration check using a StepLR
    # callback built with step_size=3.
    self._fit_with_callback_integration(StepLR(step_size=3))