Beispiel #1
0
def download_csv(problem_id):
    """Export the selected, positively labelled items of a problem as CSV.

    Expects a JSON body with a ``selectedIds`` list of dataset item ids.
    Each exported row holds the item's ``free_text`` and the text of a
    label whose LabelEvent has ``label_matches`` set to True. DISTINCT ON
    keeps at most one row per item.

    Returns:
        A ``text/csv`` attachment response named ``data.csv`` with a
        ``text,label`` header, or a JSON ``{"error": ...}`` payload when
        no ids were selected.
    """
    problem = Problem.query.get(problem_id)
    assert_rights_to_problem(problem)

    data = request.get_json()
    ids = [str(x) for x in data['selectedIds']]
    if not ids:
        return jsonify(error='No ids selected')

    elements = db.session.query(
        Dataset.id,
        Dataset.free_text,
        ProblemLabel.label,
    ).join(LabelEvent.data).join(ProblemLabel).filter(
        Dataset.problem_id == problem_id,
        # The selection used to be validated but never applied, so the
        # whole problem was exported regardless of what the user picked.
        Dataset.id.in_(ids),
        # Only confirmed labels: "no"/"skip" events store False/None and
        # must not end up as the exported label of an item.
        LabelEvent.label_matches.is_(True),
    ).group_by(
        Dataset.id, LabelEvent.id, ProblemLabel.id,
    ).distinct(Dataset.id).order_by(
        # PostgreSQL requires DISTINCT ON expressions to lead ORDER BY;
        # this also emits the CSV rows in item-id order.
        Dataset.id,
    ).all()

    stream = io.StringIO()
    writer = csv.writer(stream)
    writer.writerow(['text', 'label'])
    # Drop the leading Dataset.id; it is only needed for DISTINCT ON.
    writer.writerows(element[1:] for element in elements)

    response = make_response(stream.getvalue())
    response.headers['Content-Disposition'] = 'attachment; filename=data.csv'
    response.mimetype = 'text/csv'

    return response
def multi_class_delete_label_event(problem_id, data_id):
    """Reset an item's label events, optionally marking one label as chosen.

    All existing LabelEvents of the dataset item ``data_id`` are deleted.
    A fresh event is then created for every label of the item's problem.
    The label named by the ``label_id`` form field (if any) gets
    ``label_matches=True``; every other label gets ``None``.

    Flashes a summary and redirects to the ``train`` view of ``problem_id``.
    """
    item = Dataset.query.get(data_id)
    assert_rights_to_problem(item.problem)
    labels_of_problem = item.problem.labels
    requested_id = request.form.get('label_id')
    chosen = None
    if requested_id:
        chosen = ProblemLabel.query.get(requested_id)

    for stale_event in item.label_events:
        db.session.delete(stale_event)

    for candidate in labels_of_problem:
        # Only the chosen label is marked as matching; the rest stay undecided.
        verdict = True if candidate == chosen else None
        fresh_event = LabelEvent(data=item,
                                 label=candidate,
                                 created_by=current_user.username,
                                 label_matches=verdict)
        db.session.add(fresh_event)
    db.session.commit()

    message = (('Label event replaced with %s' % chosen.label)
               if chosen else 'Label events removed')
    flash(message)

    return redirect(url_for('train', problem_id=problem_id))
Beispiel #3
0
def batch_label(problem_id):
    """Apply (or undo) one label on many dataset items of a problem at once.

    Expects a JSON body with:
        selectedIds: ids of the dataset items to update.
        label: id of a ProblemLabel of this problem.
        value: stored as ``label_matches`` on the new events, or the string
            ``'undo'`` to only remove the existing events.

    Existing LabelEvents for that label on the selected items are removed
    first. Unless ``value`` is ``'undo'``, one new event is then recorded
    per selected item that belongs to the problem.

    Returns:
        JSON ``{"status": "ok", "labels_removed": <n>}`` on success, or
        ``{"error": ...}`` for an empty selection or a label that is not
        part of the problem.
    """
    problem = Problem.query.get(problem_id)
    assert_rights_to_problem(problem)

    data = request.get_json()
    ids = [str(x) for x in data['selectedIds']]
    if not ids:
        return jsonify(error='No ids selected')

    # The rights check above covers ``problem`` only; the label and the
    # items named in the request must belong to it as well.
    label = ProblemLabel.query.filter(ProblemLabel.id == data['label'],
                                      ProblemLabel.problem == problem).first()
    if label is None:
        return jsonify(error='Invalid label')

    owned_ids = [
        row[0] for row in db.session.query(Dataset.id).filter(
            Dataset.id.in_(ids), Dataset.problem_id == problem.id).all()
    ]

    labels = (db.session.query(LabelEvent.id).outerjoin(
        LabelEvent.data).filter(Dataset.id.in_(ids),
                                Dataset.problem_id == problem.id,
                                LabelEvent.label_id == label.id).all())
    if labels:
        LabelEvent.query.filter(LabelEvent.id.in_(
            [x[0] for x in labels])).delete(synchronize_session='fetch')

    if data['value'] != 'undo':
        # Iterate the verified ids; this also collapses duplicate ids.
        for dataset_id in owned_ids:
            db.session.add(
                LabelEvent(label_id=label.id,
                           label_matches=data['value'],
                           data_id=dataset_id,
                           created_by=current_user.username))

    db.session.commit()
    return jsonify(status='ok', labels_removed=len(labels))
Beispiel #4
0
def dataset(problem_id):
    """Render the dataset overview page for a problem.

    The rows and label metadata come from ``_dataset_selector``. The
    template receives them as ``data`` and ``problem_labels``, together
    with the ``problem`` itself.
    """
    problem = Problem.query.get(problem_id)
    assert_rights_to_problem(problem)

    rows, labels = _dataset_selector(problem)
    context = {'data': rows, 'problem': problem, 'problem_labels': labels}

    return render_template('dataset.html', **context)
Beispiel #5
0
def dataset(problem_id):
    """Render the dataset overview page for a problem, building rows inline.

    Each row holds the item's id, free_text, entity_id, table_name, meta,
    ``dataset_probabilities``, sort_value, ``label_matches`` and
    ``label_created_at``, in that order. The three named columns are
    correlated subqueries:

    * ``dataset_probabilities``: ``json_object_agg`` of label_id to
      probability over the item's DatasetLabelProbability rows.
    * ``label_matches``: ``json_agg`` of ``[label_id, label_matches]``
      pairs over the item's LabelEvents.
    * ``label_created_at``: the newest LabelEvent ``created_at`` of the
      item, formatted as ``YYYY-MM-DD HH24:MI:SS``.

    Rows are ordered by item id. ``problem_labels`` are (id, label) pairs
    ordered by ``order_index``.
    """
    problem = Problem.query.get(problem_id)
    assert_rights_to_problem(problem)

    probability_map = db.func.json_object_agg(
        DatasetLabelProbability.label_id, DatasetLabelProbability.probability)
    probabilities_col = (
        db.select([probability_map], from_obj=DatasetLabelProbability)
        .where(DatasetLabelProbability.data_id == Dataset.id)
        .correlate(Dataset.__table__)
        .label('dataset_probabilities'))

    newest_event_time = db.func.to_char(
        LabelEvent.created_at, db.text("'YYYY-MM-DD HH24:MI:SS'"))
    created_at_col = (
        db.select([newest_event_time], from_obj=LabelEvent)
        .where(LabelEvent.data_id == Dataset.id)
        .order_by(LabelEvent.created_at.desc())
        .limit(1)
        .correlate(Dataset.__table__)
        .label('label_created_at'))

    match_pairs = db.func.json_agg(
        db.func.json_build_array(LabelEvent.label_id,
                                 LabelEvent.label_matches))
    matches_col = (
        db.select([match_pairs], from_obj=LabelEvent)
        .where(LabelEvent.data_id == Dataset.id)
        .correlate(Dataset.__table__)
        .label('label_matches'))

    row_columns = (
        Dataset.id,
        Dataset.free_text,
        Dataset.entity_id,
        Dataset.table_name,
        Dataset.meta,
        probabilities_col,
        Dataset.sort_value,
        matches_col,
        created_at_col,
    )
    rows = (db.session.query(*row_columns)
            .filter(Dataset.problem_id == problem.id)
            .order_by(Dataset.id.asc())
            .all())
    labels = (db.session.query(ProblemLabel.id, ProblemLabel.label)
              .filter(ProblemLabel.problem == problem)
              .order_by(ProblemLabel.order_index)
              .all())

    return render_template('dataset.html',
                           data=rows,
                           problem=problem,
                           problem_labels=labels)
Beispiel #6
0
def training_job(problem_id):
    """Render the training-job history page of a problem.

    ``data`` lists (id, accuracy, created_at) for every TrainingJob of the
    problem, newest first. ``plot_data`` lists (created_at formatted as
    ``YYYY-MM-DD HH24:MI:SS``, accuracy) pairs, oldest first.
    """
    problem = Problem.query.get(problem_id)
    assert_rights_to_problem(problem)

    of_this_problem = TrainingJob.problem_id == problem.id

    history = (db.session.query(TrainingJob.id, TrainingJob.accuracy,
                                TrainingJob.created_at)
               .filter(of_this_problem)
               .order_by(TrainingJob.created_at.desc())
               .all())

    created_label = db.func.to_char(TrainingJob.created_at,
                                    db.text("'YYYY-MM-DD HH24:MI:SS'"))
    points = (db.session.query(created_label, TrainingJob.accuracy)
              .filter(of_this_problem)
              .order_by(TrainingJob.created_at.asc())
              .all())

    return render_template('training_job.html',
                           data=history,
                           plot_data=points,
                           problem=problem)
Beispiel #7
0
def train(problem_id):
    """Show the labelling screen for a problem and record submitted labels.

    On POST, every ``label_<id>`` form field records a LabelEvent for the
    dataset item named by ``data_id``: ``yes`` -> True, ``no`` -> False,
    ``skip`` -> None. A label that already has an event for that item is
    skipped with a flash message. Fields whose label is unknown, belongs to
    another problem, or whose value is not yes/no/skip are ignored.

    Afterwards (and on GET) the next item without any LabelEvent is chosen,
    lowest ``sort_value`` first, ties broken randomly. Renders
    ``train_no_data.html`` when the problem has no items at all.
    """
    problem = Problem.query.get(problem_id)
    assert_rights_to_problem(problem)

    if not Dataset.query.filter(Dataset.problem_id == problem.id).count():
        return render_template('train_no_data.html', problem=problem)

    if request.method == 'POST':
        answers = {'yes': True, 'no': False, 'skip': None}
        label_fields = [(key[len('label_'):], value)
                        for key, value in request.form.items()
                        if key.startswith('label_')]

        if label_fields:
            # Fetch the item once (not once per label) and make sure it
            # belongs to the problem the rights check was done for.
            item = Dataset.query.get(request.form['data_id'])
            if item is None or item.problem_id != problem.id:
                flash('Invalid item.')
            else:
                for label_id, value in label_fields:
                    label = ProblemLabel.query.get(label_id)
                    if (label is None or label.problem != problem
                            or value not in answers):
                        continue

                    if LabelEvent.query.filter_by(label=label,
                                                  data=item).count():
                        flash('This item has already been labeled...skipping?')
                        continue

                    db.session.add(
                        LabelEvent(label=label,
                                   label_matches=answers[value],
                                   data=item,
                                   created_by=current_user.username))
                # One commit per submission, so the form is stored as a unit.
                db.session.commit()

    sample = Dataset.query.filter(~Dataset.label_events.any(),
                                  Dataset.problem_id == problem.id).order_by(
                                      Dataset.sort_value,
                                      db.func.RANDOM()).first()

    return render_template(
        'train.html',
        sample=sample,
        problem=problem,
        problem_labels_arr=[
            dict(id=x.id, name=x.label, order_index=x.order_index)
            for x in sorted(problem.labels, key=lambda x: x.order_index)
        ])
Beispiel #8
0
def train_log(problem_id):
    """Render the labelling log and progress of a problem.

    ``progress`` is the fraction of the problem's dataset items that have at
    least one LabelEvent. It is 0.0 when the problem has no items.
    ``problem_labels_arr`` lists the problem's labels as dicts sorted by
    ``order_index``.
    """
    problem = Problem.query.get(problem_id)
    assert_rights_to_problem(problem)

    labeled_data_count = Dataset.query.filter(
        Dataset.label_events.any(),
        Dataset.problem_id == problem.id).group_by(Dataset.id).count()
    total_data_count = Dataset.query.filter(
        Dataset.problem_id == problem.id).count()
    # Without this guard a problem with no items raised ZeroDivisionError.
    progress = (labeled_data_count / total_data_count
                if total_data_count else 0.0)

    return render_template(
        'train_log.html',
        event_log=get_event_log(problem),
        progress=progress,
        labeled_data_count=labeled_data_count,
        problem=problem,
        problem_labels_arr=[
            dict(id=x.id, name=x.label, order_index=x.order_index)
            for x in sorted(problem.labels, key=lambda x: x.order_index)
        ])
Beispiel #9
0
def delete_label_event(problem_id, id):
    """Remove a label event, optionally replacing it with a new value.

    ``id`` is the LabelEvent to remove. When the ``value`` form field is
    given (``true``/``false``/``skip``), a new LabelEvent for the same item
    and label is first recorded with ``label_matches`` True/False/None.

    Unknown event ids and events of other problems are ignored. An
    unrecognised ``value`` is flashed as invalid and nothing is changed.
    Always redirects to the ``train`` view of ``problem_id``.
    """
    problem = Problem.query.get(problem_id)
    assert_rights_to_problem(problem)

    answers = {'true': True, 'false': False, 'skip': None}
    label_event = LabelEvent.query.get(id)
    value = request.form.get('value')
    back = redirect(url_for('train', problem_id=problem_id))

    # The rights check covers ``problem`` only; refuse events whose item
    # belongs to a different problem.
    owner = label_event.data if label_event else None
    if owner is None or owner.problem_id != problem.id:
        return back

    if value and value not in answers:
        # Previously an unknown value raised KeyError (HTTP 500).
        flash('Invalid value.')
        return back

    if value:
        db.session.add(
            LabelEvent(data=label_event.data,
                       label=label_event.label,
                       label_matches=answers[value],
                       created_by=current_user.username))
    db.session.delete(label_event)
    db.session.commit()
    if value:
        flash('Label event replaced with %s' % value)
    else:
        flash('Label event removed.')
    return back