Ejemplo n.º 1
0
def send_model_task(project_pk):
    """Train, save, predict, record a training set, and refill the queue.

    Steps, in order:
      1. Train and save a new model for the project (``train_and_save_model``).
      2. Unless the project's ``learning_method`` is ``"random"``, run
         ``predict_data`` with the newly trained model.
      3. Create a new ``TrainingSet`` whose ``set_number`` is one greater than
         the project's current training set.
      4. Recompute the normal queue's length from the batch size and the
         number of coders, saving the queue only if the length changed.
      5. Refill the normal queue (with the IRR queue and IRR percentage) via
         ``fill_queue``.

    Args:
        project_pk: Primary key of the ``Project`` to process.

    Raises:
        Project.DoesNotExist: if no project has ``pk=project_pk``.
        DoesNotExist / MultipleObjectsReturned (of the queue model): if the
            project does not have exactly one ``"normal"`` queue and exactly
            one ``"irr"`` queue, since both are fetched with ``.get()``.
    """
    # Function-scope imports, presumably to avoid import cycles at module
    # load time — NOTE(review): confirm before hoisting them.
    from core.models import Project, TrainingSet
    from core.utils.utils_model import predict_data, train_and_save_model
    from core.utils.utils_queue import fill_queue, find_queue_length

    project = Project.objects.get(pk=project_pk)
    queue = project.queue_set.get(type="normal")
    irr_queue = project.queue_set.get(type="irr")
    al_method = project.learning_method
    batch_size = project.batch_size

    model = train_and_save_model(project)
    # Predictions are only needed when the queue ordering depends on them.
    if al_method != "random":
        predict_data(project, model)
    TrainingSet.objects.create(
        project=project,
        set_number=project.get_current_training_set().set_number + 1,
    )

    # Determine if queue size has changed (num_coders changed) and re-fill queue.
    # Use a COUNT query instead of loading every permissions row just to
    # take len() of it. The +1 presumably accounts for the project owner,
    # who has no permissions row — TODO confirm.
    num_coders = project.projectpermissions_set.count() + 1
    q_length = find_queue_length(batch_size, num_coders)
    if q_length != queue.length:
        queue.length = q_length
        queue.save()

    fill_queue(
        queue,
        irr_queue=irr_queue,
        orderby=al_method,
        irr_percent=project.percentage_irr,
        batch_size=batch_size,
    )
Ejemplo n.º 2
0
def test_predict_data(test_project_with_trained_model, tmpdir):
    """Check the number and persistence of ``predict_data`` results.

    The expected count is (unlabeled data items) x (labels). Each returned
    item must be a ``DataPrediction``. ``assert_obj_exists`` must also find
    a ``DataPrediction`` with the same data, model, label and probability.
    """
    proj = test_project_with_trained_model
    trained_model = proj.model_set.get()

    result = predict_data(proj, trained_model)

    # Every unlabeled data item gets a prediction for each label.
    unlabeled_count = proj.data_set.filter(datalabel__isnull=True).count()
    expected_prediction_count = unlabeled_count * proj.labels.count()
    assert len(result) == expected_prediction_count

    for pred in result:
        assert isinstance(pred, DataPrediction)
        lookup = {
            'data': pred.data,
            'model': pred.model,
            'label': pred.label,
            'predicted_probability': pred.predicted_probability,
        }
        assert_obj_exists(DataPrediction, lookup)
Ejemplo n.º 3
0
def test_project_predicted_data(test_project_with_trained_model, tmpdir):
    """Run ``predict_data`` on the trained-model project and return that project.

    NOTE(review): this returns a value, so it looks like a pytest fixture
    whose decorator is outside this view — confirm. ``tmpdir`` is requested
    but not used.
    """
    predicted_project = test_project_with_trained_model
    predict_data(predicted_project, predicted_project.model_set.get())
    return predicted_project