Ejemplo n.º 1
0
    def test_DenseNet161_freezing(self):
        """After one training epoch, trainable weights change, frozen ones do not.

        Snapshots the first parameter tensor of ``trainyou_wewill`` (expected
        to be trained) and of ``features`` (expected to be frozen), trains a
        single epoch with ``freeze_point`` forced to 20, then compares the
        snapshots against the post-training values.
        """
        hyper_params = Hyperparameter_Set.from_file(
            "./Test/Params/DenseNet161_23_overfit.json")
        hyper_params["solver_params"]["num_epochs"] = 1
        hyper_params["model_params"]["freeze_point"] = 20

        model = DenseNet161(
            freeze_point=hyper_params["model_params"]["freeze_point"],
            num_classes=hyper_params["model_params"]["num_classes"])
        # Tensor.numpy() returns an array that shares storage with the CPU
        # tensor. Without .copy(), an in-place optimizer update would also
        # rewrite the "before" snapshot and make both comparisons meaningless.
        train_params_before = next(
            model.trainyou_wewill.parameters()).data.numpy().copy()
        freezed_params_before = next(
            model.features.parameters()).data.numpy().copy()

        solver = Solver(hyper_params["solver_params"], model)
        solver.train(
            self.train_loader,
            self.train_loader,
            self.train_loader,
            hyper_params["solver_params"]["num_epochs"],
            log_frequency=hyper_params["solver_params"]["log_frequency"])

        # Bring the parameters back to host memory so they can be read as numpy.
        model.cpu()

        # The trainable layer must differ from its snapshot...
        testing.assert_raises(
            AssertionError, testing.assert_array_equal, train_params_before,
            next(model.trainyou_wewill.parameters()).data.numpy())
        # ...while the frozen layer must be bit-for-bit unchanged.
        testing.assert_array_equal(
            freezed_params_before,
            next(model.features.parameters()).data.numpy())
Ejemplo n.º 2
0
    def test_get_log_iterations(self):
        """Check ``_get_log_iterations(10, k)`` against hand-computed lists.

        Covers k in {10, 3, 2, 1}; each expected list starts at 9 and counts down.
        """
        solver = Solver(Constants.converted_valid_solver_params, self.dummy_net)

        cases = (
            (10, list(range(9, -1, -1))),
            (3, [9, 6, 3]),
            (2, [9, 4]),
            (1, [9]),
        )
        for k, expected in cases:
            self.assertEqual(solver._get_log_iterations(10, k), expected)
Ejemplo n.º 3
0
    def test_DenseNet161_requires_grads_disabled(self):
        """Check gradient flags on the sub-modules once the Solver is built.

        Every parameter of ``features`` must have ``requires_grad`` off, while
        all parameters of ``trainyou_wewill`` and ``classifier`` must have it on.
        """
        params = Hyperparameter_Set.from_file(
            "./Test/Params/DenseNet161_23_overfit.json")
        model_params = params["model_params"]
        net = DenseNet161(freeze_point=model_params["freeze_point"],
                          num_classes=model_params["num_classes"])
        solver = Solver(params["solver_params"], net)

        for frozen in solver.model.features.parameters():
            self.assertFalse(frozen.requires_grad)

        for trainable_module in (solver.model.trainyou_wewill,
                                 solver.model.classifier):
            for trainable in trainable_module.parameters():
                self.assertTrue(trainable.requires_grad)
Ejemplo n.º 4
0
# Evaluate the model saved at ``results_dir``/model.pt on the "E2" test
# subset, reporting plain accuracy and top-k accuracy.
data_params = hyper_params["data_params"]
solver_params = hyper_params["solver_params"]
model_params = hyper_params["model_params"]

data_root_dir = data_params["data_root_dir"]
split_config = data_params["split_config"]
augmentation_params = data_params["augmentation_params"]
# Restrict the test split to subset "E2". This writes into the split_config
# object obtained from hyper_params.
split_config["test_subsets"] = ["E2"]
top_k = 5

_, val_set, test_set = DatasetContainer(
    data_root_dir, split_config, augmentation_params).get_datasets()
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=data_params["batch_size"],
    shuffle=False,
    num_workers=data_params["num_workers"],
)

print("Loading the model...")
model = Model(freeze_point=model_params["freeze_point"],
              num_classes=model_params["num_classes"])
# map_location="cpu" lets weights saved from a GPU run load on a CPU-only
# machine; load_state_dict then copies them into the model's own tensors.
model.load_state_dict(
    torch.load(os.path.join(results_dir, "model.pt"), map_location="cpu"))

print("Initializing the solver...")
solver = Solver(solver_params, model)

print("Let's test the model...")
test_acc, confusion_matrix = solver.test(test_loader,
                                         create_confusion_matrix=True)
print("Test accuracy:", test_acc)

print("Let's test Top %i accuracy ..." % top_k)
top_k_test_acc = solver.test_top_k(test_loader, top_k)
print("Top%i accuracy:" % top_k, top_k_test_acc)
Ejemplo n.º 5
0
 def test_overfit_test(self):
     """Overfit a small net on ``self.test_loader`` (final train acc > 0.9)."""
     net = Fully_Connected_Net(in_size=3 * 224 * 224, out_size=24)
     overfit_solver = Solver(Constants.overfit_params_finger_images, net)
     loader = self.test_loader
     log, weights = overfit_solver.train(loader, loader, loader,
                                         epoch_count=15, log_frequency=1)
     final_train_acc = log["train_acc"][-1]
     self.assertGreater(final_train_acc, 0.9)
     self.assertIsNotNone(weights)
Ejemplo n.º 6
0
 def test_overfit_train(self):
     """Overfit a small net on ``self.val_loader`` (final train acc > 0.9).

     A pass also indicates that the dataset behind the loader was built
     correctly.
     """
     # NOTE(review): the test name says "train" but it trains on
     # self.val_loader; confirm whether self.train_loader was intended.
     net = Fully_Connected_Net(in_size=3 * 224 * 224, out_size=24)
     overfit_solver = Solver(Constants.overfit_params_finger_images, net)
     loader = self.val_loader
     log, weights = overfit_solver.train(loader, loader, loader,
                                         epoch_count=15, log_frequency=1)
     final_train_acc = log["train_acc"][-1]
     self.assertGreater(final_train_acc, 0.9)
     self.assertIsNotNone(weights)
Ejemplo n.º 7
0
def _subset_loader(dataset, data_params, num_samples, drop_last=False):
    """Build a DataLoader over ``dataset`` sampled by ``RandomSubsetSampler``.

    ``shuffle`` stays False because randomness comes from the sampler.
    torch's DataLoader does not accept ``shuffle=True`` together with an
    explicit ``sampler``.
    """
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=data_params["batch_size"],
        shuffle=False,
        num_workers=data_params["num_workers"],
        sampler=RandomSubsetSampler(dataset, num_samples),
        drop_last=drop_last)


def train(hyper_params,
          architecture,
          result_dir,
          train_set,
          val_set,
          model_path=None):
    """Train a model, then save its results in a new timestamped directory.

    The results directory ``<result_dir>/<timestamp>`` receives the validation
    confusion matrix (``confusion_matrix_val.npy``), the training log
    (``training.log``, JSON), the best weights (``model.pt``) and the
    hyper-parameters (``params.json``).

    Args:
        hyper_params: parameter set with "data_params", "solver_params" and
            "model_params" sections; must also provide ``save(path)``.
        architecture: model class, called with ``freeze_point`` and
            ``num_classes`` keyword arguments.
        result_dir: parent directory for the timestamped results folder;
            missing parent directories are created.
        train_set: dataset used for training.
        val_set: dataset used for both validation loaders.
        model_path: optional path to a saved state dict to start from.

    Raises:
        FileExistsError: if the timestamped results directory already exists.
    """
    data_params = hyper_params["data_params"]
    solver_params = hyper_params["solver_params"]
    model_params = hyper_params["model_params"]

    print("Initializing the data loaders...")
    # Only the training loader drops a trailing partial batch.
    train_loader = _subset_loader(train_set, data_params,
                                  data_params["samples_per_epoch"],
                                  drop_last=True)
    val_loader = _subset_loader(val_set, data_params,
                                data_params["samples_per_epoch"] * 2)
    small_val_loader = _subset_loader(val_set, data_params,
                                      data_params["small_val_set_size"])

    print("Loading the model...")
    model = architecture(freeze_point=model_params["freeze_point"],
                         num_classes=model_params["num_classes"])

    if model_path is not None:
        # map_location="cpu": a checkpoint written from a GPU run still loads
        # on a CPU-only host; load_state_dict copies into the model's tensors.
        model.load_state_dict(torch.load(model_path, map_location="cpu"))

    print("Initializing the solver...")
    solver = Solver(solver_params, model)

    print("Finished initialization, ready to train!")
    print("#" * 30)

    log, best_weights = solver.train(
        train_loader,
        val_loader,
        small_val_loader,
        solver_params["num_epochs"],
        log_frequency=solver_params["log_frequency"],
        log_loss=solver_params["log_loss"],
        verbose=solver_params["verbose"])

    print("#" * 30)

    print("Creating confusion matrix for the validation set...")
    solver.set_model_params(best_weights)
    _, confusion_matrix = solver.test(val_loader, create_confusion_matrix=True)

    print("Creating directory for results...")
    result_dir = os.path.join(result_dir, get_timestamp())
    # makedirs also creates a missing parent directory. Like os.mkdir, it
    # still raises if the timestamped directory itself already exists.
    os.makedirs(result_dir)

    print("Saving the confusion matrix...")
    np.save(os.path.join(result_dir, "confusion_matrix_val.npy"),
            confusion_matrix)

    print("Saving the log...")
    # Converts the log in place before JSON serialization (presumably to
    # plain Python floats -- TODO confirm against cast_log_to_float).
    cast_log_to_float(log)
    with open(os.path.join(result_dir, "training.log"), mode="w",
              encoding="utf-8") as f:
        json.dump(log, f)

    print("Saving the model...")
    torch.save(best_weights, os.path.join(result_dir, "model.pt"))

    print("Saving the params...")
    hyper_params.save(os.path.join(result_dir, "params.json"))

    print("Clean up...")
    # First drop every local that may keep tensor memory alive, including
    # best_weights. Then collect reference cycles, and only then release the
    # CUDA cache. Emptying the cache before gc.collect() cannot return memory
    # still held by objects that are only freed by the collector.
    del solver, model, log, best_weights, confusion_matrix
    gc.collect()
    torch.cuda.empty_cache()
Ejemplo n.º 8
0
    def test_overfit(self):
        """Training on ``self.data_loader`` alone must overfit (train acc > 0.9)."""
        overfit_solver = Solver(Constants.overfit_params, self.dummy_net)
        loader = self.data_loader
        log, weights = overfit_solver.train(
            loader, loader, loader, epoch_count=10, log_frequency=1)

        final_train_acc = log["train_acc"][-1]
        self.assertGreater(final_train_acc, 0.9)
        self.assertIsNotNone(weights)
Ejemplo n.º 9
0
 def test_invalid_initialization(self):
     """Constructing a Solver from invalid parameters must raise.

     The previous version raised AssertionError inside a bare ``try`` whose
     ``except: pass`` swallowed it, so the test could never fail.
     """
     with self.assertRaises(Exception):
         Solver(Constants.invalid_solver_params, self.dummy_net)
Ejemplo n.º 10
0
 def test_valid_initilization(self):
     """A Solver built from valid params exposes them as ``solver.params``."""
     expected_params = Constants.converted_valid_solver_params
     built = Solver(expected_params, self.dummy_net)
     self.assertEqual(built.params, expected_params)