예제 #1
0
    def get_inputs(cls, num_prediction_steps, num_results=None):
        """Build a test model, input data, parameters and mocked predictions.

        Args:
            num_prediction_steps: Number of prediction steps. Multiplied by
                ``image_generation_params.batch_size`` to derive
                ``num_results`` when that argument is ``None``.
            num_results: Number of prediction rows to mock. When ``None``
                (the default), ``batch_size * num_prediction_steps`` is used.

        Returns:
            A tuple ``(model, input_data, input_params,
            image_generation_params, prediction_results)``.
            ``model.predict_generator`` is replaced by a ``MagicMock`` that
            returns ``prediction_results``.
        """
        #Arrange
        model = load_test_model()
        args = get_args(model_name, Path())
        input_params = InputParameters(args)
        image_generation_params = ImageGenerationParameters(args)
        input_data = get_input_data()

        #Update input data parameters
        # The class count is the number of distinct labels in the label column.
        num_classes = len(set(input_data[image_generation_params.label_col]))
        update_params(image_generation_params,
                      num_classes=num_classes,
                      image_cols=['Image'])

        #Num results
        # Fall back to the derived count only when nothing was passed; an
        # explicit 0 is honoured instead of being silently replaced.
        if num_results is None:
            num_results = image_generation_params.batch_size * num_prediction_steps

        #Mocks
        # One-hot rows cycling through the classes: row i predicts class
        # i % num_classes.
        prediction_results = np.zeros((num_results, num_classes))
        for row_id in range(num_results):
            prediction_results[row_id, row_id % num_classes] = 1

        model.predict_generator = MagicMock(return_value=prediction_results)

        return model, input_data, input_params, image_generation_params, prediction_results
예제 #2
0
    def test_configure_base_model(self):
        """Configure the base model and check the unfrozen layer count.

        Verification is delegated to ``self.verify_unfrozen_layers``.
        """
        #Arrange
        operation = Operation(num_unfrozen_layers, configure_base=True, base_level=0)
        test_model = load_test_model()

        #Act
        configured_model = operation.configure(test_model)

        #Assert
        self.verify_unfrozen_layers(configured_model, num_unfrozen_layers)
예제 #3
0
    def test_set_model(self):
        """After ``set_model`` the checkpoint's ``_model`` is not ``None``."""
        #Arrange
        (checkpoint, _, _, _) = get_checkpoint()
        test_model = load_test_model()

        #Act
        checkpoint.set_model(test_model)

        #Assert
        self.assertIsNotNone(checkpoint._model)
예제 #4
0
    def test_save(self):
        """``ModelInput.save`` calls ``model.save`` with the stringified file name.

        The expected path is ``str(model_input.file_name(batch_id, epoch_id))``.
        """
        #Arrange
        test_model = load_test_model()
        model_input = ModelInput(model_name)

        #Mocks
        test_model.save = MagicMock()

        #Act
        model_input.save(test_model, batch_id, epoch_id)

        #Assert
        expected_path = str(model_input.file_name(batch_id, epoch_id))
        test_model.save.assert_called_with(expected_path)
예제 #5
0
    def test_on_batch_end_save_called(self):
        """Drive a checkpoint into batch 1 and run the end-of-batch checks.

        The checkpoint receives a model and input data, then the epoch and
        batch begin hooks run. The actual assertions live in
        ``self.on_batch_end``.
        """
        #Arrange
        current_batch = 1
        checkpoint, batch_files, _, _ = get_checkpoint()
        data = get_input_data()
        test_model = load_test_model()
        checkpoint.set_model(test_model)
        checkpoint.set_input_data(data)
        checkpoint.on_epoch_begin(epoch_id)
        checkpoint.on_batch_begin(current_batch)

        #Act & Assert
        self.on_batch_end(checkpoint, current_batch, batch_files)
예제 #6
0
def get_train_args(num_classes=64):
    """Build a test model, input data and an ``ImageTraining`` trainer.

    Args:
        num_classes: Value assigned to ``image_generation_params.num_classes``
            before the trainer is constructed. Defaults to 64, the value
            previously hard-coded here.

    Returns:
        A tuple ``(model, input_data, trainer)``.
    """
    input_data = get_input_data()
    model = load_test_model()
    input_params, training_params, image_generation_params, transformation_params = get_params()
    image_generation_params.num_classes = num_classes

    # The fifth positional argument is a MagicMock stand-in; presumably a
    # collaborator these tests do not exercise. TODO confirm against
    # ImageTraining's signature.
    trainer = ImageTraining(
        input_params,
        training_params,
        image_generation_params,
        transformation_params,
        MagicMock(),
        summary=False,
    )

    return model, input_data, trainer
예제 #7
0
    def test_on_epoch_end(self):
        """End an epoch and check the first epoch-end result file.

        That file's ``save`` is mocked before the checkpoint is driven
        through its epoch/batch begin hooks and ``on_epoch_end``.
        Verification is delegated to ``self.on_result_file_save``.
        """
        #Arrange
        checkpoint, _, _, epoch_end_files = get_checkpoint()
        test_model = load_test_model()
        data = get_input_data()
        first_result_file = epoch_end_files[0]
        first_result_file.save = MagicMock()
        checkpoint.set_model(test_model)
        checkpoint.set_input_data(data)
        checkpoint.on_epoch_begin(epoch_id)
        checkpoint.on_batch_begin(batch_id)

        #Act
        checkpoint.on_epoch_end(epoch_id)

        #Assert
        self.on_result_file_save(checkpoint, first_result_file)