예제 #1
0
def session_overview(session_id: int):
    """Render the overview page for a single label session.

    The session is looked up by id, its dataset is resolved through
    ``backend.get_dataset`` and the overview template is chosen from
    ``label_session.session_type``.

    Args:
        session_id: Primary key of the label session to show.

    Returns:
        The rendered HTML of the matching overview template.

    Raises:
        Aborts with HTTP 404 when no session with this id is found.
    """
    label_session = sessions.get_session_by_id(db.session, session_id)
    if label_session is None:
        # NOTE(review): assumes get_session_by_id returns None for an unknown id.
        # Answer 404 (as dataset_overview does for an unknown dataset) instead of
        # failing with AttributeError on ``label_session.dataset`` below.
        abort(404)

    dataset = backend.get_dataset(label_session.dataset)

    has_thumbs = thumbnails.has_thumbnails(label_session)

    # Index of the first element that has no labels yet; None when every element
    # is labelled. Passed to the template as ``resume_point``.
    resume_point = next(
        (el.element_index for el in label_session.elements if not el.labels), None)

    if label_session.session_type == LabelSessionType.CATEGORICAL_IMAGE.name:
        # Only the image-level view needs the dataset's image list.
        images = backend.get_images(dataset)
        return render_template('session_overview_categorical.html',
                               label_session=label_session,
                               dataset=dataset,
                               images=images,
                               resume_point=resume_point,
                               has_thumbs=has_thumbs)

    elif label_session.session_type == LabelSessionType.CATEGORICAL_SLICE.name:
        return render_template('session_overview_categorical_slice.html',
                               label_session=label_session,
                               dataset=dataset,
                               resume_point=resume_point,
                               has_thumbs=has_thumbs)
    else:  # COMPARISON_SLICE (any other type also falls through to this view)
        return render_template('session_overview_comparison.html',
                               label_session=label_session,
                               dataset=dataset,
                               resume_point=resume_point,
                               has_thumbs=has_thumbs)
예제 #2
0
    def test_get_images_names(self):
        """The first three images of 'dataset1' carry the expected file names, in order."""
        images = backend.get_images(backend.get_dataset('dataset1'))

        expected_names = ('img1.nii.gz', 'img2.nii', 'img3')
        for position, expected in enumerate(expected_names):
            self.assertEqual(images[position].name, expected)
예제 #3
0
    def test_get_images_datasets(self):
        """Each of the first three images of 'dataset1' refers back to that dataset object."""
        owner = backend.get_dataset('dataset1')
        listed = backend.get_images(owner)

        for position in range(3):
            self.assertEqual(listed[position].dataset, owner)
예제 #4
0
    def test_get_images_paths(self):
        """Image paths of 'dataset1' are DATASETS_PATH/dataset1/<file name>, in listing order."""
        images = backend.get_images(backend.get_dataset('dataset1'))
        dataset_dir = os.path.join(backend.DATASETS_PATH, 'dataset1')

        for position, file_name in enumerate(['img1.nii.gz', 'img2.nii', 'img3']):
            self.assertEqual(images[position].path, os.path.join(dataset_dir, file_name))
예제 #5
0
def create_comparison_session(dataset_name: str):
    """Show and handle the form that creates a comparison-slice label session.

    On GET, or when validation fails, the creation form is rendered. On a valid
    submission the comparisons either come from freshly sampled slices (the
    ``comparisons`` field is 'create') or are copied from an existing
    comparison-slice session whose id was chosen. The new session is then stored
    and the user is redirected to the dataset overview.

    Args:
        dataset_name: Name of the dataset the new session belongs to.

    Returns:
        The rendered form, or a redirect to ``dataset_overview`` after creation.

    Raises:
        Aborts with HTTP 400 when the dataset does not exist.
    """
    dataset = backend.get_dataset(dataset_name)
    if dataset is None:
        abort(400)

    current_sessions = sessions.get_sessions(db.session, dataset)
    label_session_count = len(current_sessions)

    images = backend.get_images(dataset)
    total_image_count = len(images)

    form = CreateComparisonSessionForm(meta={'csrf': False})

    # Offer every existing comparison-slice session of this dataset as a source of
    # comparisons, in addition to the choices the form class itself defines.
    comparison_sessions = sessions.get_sessions(db.session, dataset, LabelSessionType.COMPARISON_SLICE)
    for sess in comparison_sessions:
        form.comparisons.choices.append((str(sess.id), sess.session_name))

    form.image_count.validators = [
        ComparisonNumberRange(min=1, max=total_image_count,
                              message='Must be between %(min)s and %(max)s (the dataset size).')
    ]

    if form.validate_on_submit():
        creating = form.comparisons.data == 'create'
        existing_names = {se.session_name for se in current_sessions}
        if form.session_name.data in existing_names:
            form.session_name.errors.append('Session name already in use.')
        elif creating and form.min_slice_percent.data >= form.max_slice_percent.data:
            # The slice-percent range only matters when new slices are sampled.
            form.max_slice_percent.errors.append('Max must be greater than min.')
        else:
            if creating:
                # The slice type is only needed to sample new slices, so resolve it
                # here (as create_sort_session does) rather than on every submission.
                slice_type = backend.SliceType[form.slice_type.data]
                slices = sampling.sample_slices(dataset, slice_type, form.image_count.data, form.slice_count.data,
                                                form.min_slice_percent.data, form.max_slice_percent.data)
                if form.comparison_count.data is None:
                    # No comparison count given: fall back to sampling.all_comparisons.
                    comparisons = sampling.all_comparisons(slices)
                else:
                    comparisons = sampling.sample_comparisons(slices, form.comparison_count.data,
                                                              form.max_comparisons_per_slice.data)
            else:
                # Any other choice value is the stringified id of an existing session.
                from_session = sessions.get_session_by_id(db.session, int(form.comparisons.data))
                comparisons = sampling.get_comparisons_from_session(from_session)
            # Label values are entered comma-separated; surrounding whitespace is trimmed.
            label_values = [v.strip() for v in form.label_values.data.split(',')]
            sessions.create_comparison_slice_session(db.session, form.session_name.data, form.prompt.data,
                                                     dataset, label_values, comparisons)
            return redirect(url_for('dataset_overview', dataset_name=dataset.name))

    return render_template('create_comparison_session.html',
                           dataset=dataset,
                           label_session_count=label_session_count,
                           total_image_count=total_image_count,
                           form=form)
예제 #6
0
def dataset_overview(dataset_name: str):
    """Render the overview page of one dataset.

    Aborts with HTTP 404 when no dataset named ``dataset_name`` exists. The
    template receives the dataset, its images, and its label sessions grouped
    by ``LabelSessionType``. Every type is present as a key, possibly mapped to
    an empty list.
    """
    dataset = backend.get_dataset(dataset_name)
    if dataset is None:
        abort(404)

    dataset_images = backend.get_images(dataset)
    all_sessions = sessions.get_sessions(db.session, dataset)

    grouped: Dict[LabelSessionType, List[LabelSession]] = {}
    for kind in LabelSessionType:
        grouped[kind] = []
    for sess in all_sessions:
        # session_type is stored as the enum member's name.
        kind = LabelSessionType[sess.session_type]
        grouped[kind].append(sess)

    return render_template('dataset_overview.html',
                           dataset=dataset,
                           images=dataset_images,
                           label_sessions=grouped)
예제 #7
0
def create_sort_session(dataset_name: str):
    """Show and handle the form that creates a sort-slice label session.

    On GET, or when validation fails, the creation form is rendered. On a valid
    submission the slices are either freshly sampled (the ``slices_from`` field
    is 'create') or taken from an existing slice-based session. The new session
    is then stored and the user is redirected to the dataset overview.

    Args:
        dataset_name: Name of the dataset the new session belongs to.

    Returns:
        The rendered form, or a redirect to ``dataset_overview`` after creation.

    Raises:
        Aborts with HTTP 400 when the dataset does not exist.
    """
    dataset = backend.get_dataset(dataset_name)
    if dataset is None:
        abort(400)

    current_sessions = sessions.get_sessions(db.session, dataset)
    label_session_count = len(current_sessions)

    images = backend.get_images(dataset)
    total_image_count = len(images)

    form = CreateSortSessionForm(meta={'csrf': False})

    form.image_count.validators = [
        NumberRange(min=1, max=total_image_count, message='Must be between %(min)s and %(max)s (the dataset size).')
    ]
    # Slice-based sessions can donate their slices. Reuse the list fetched above
    # rather than issuing the same get_sessions query a second time.
    for sess in current_sessions:
        if sess.session_type in SLICE_SESSION_NAMES:
            form.slices_from.choices.append((str(sess.id), sess.session_name))

    if form.validate_on_submit():
        creating = form.slices_from.data == 'create'
        existing_names = {se.session_name for se in current_sessions}
        if form.session_name.data in existing_names:
            form.session_name.errors.append('Session name already in use.')
        elif creating and form.min_slice_percent.data >= form.max_slice_percent.data:
            # The slice-percent range only matters when new slices are sampled
            # (same guard as create_comparison_session); reused slices ignore it.
            form.max_slice_percent.errors.append('Max must be greater than min.')
        else:
            if creating:
                slice_type = backend.SliceType[form.slice_type.data]
                slices = sampling.sample_slices(dataset, slice_type, form.image_count.data, form.slice_count.data,
                                                form.min_slice_percent.data, form.max_slice_percent.data)
            else:
                # Any other choice value is the stringified id of an existing session.
                from_session = sessions.get_session_by_id(db.session, int(form.slices_from.data))
                slices = sampling.get_slices_from_session(from_session)

            sessions.create_sort_slice_session(db.session, form.session_name.data, form.prompt.data, dataset, slices)
            return redirect(url_for('dataset_overview', dataset_name=dataset.name))
    return render_template('create_sort_session.html',
                           dataset=dataset,
                           label_session_count=label_session_count,
                           total_image_count=total_image_count,
                           form=form)
예제 #8
0
def sample_slices(dataset: Dataset, slice_type: SliceType, image_count: int, slice_count: int,
                  min_slice_percent: int, max_slice_percent: int) -> List[ImageSlice]:
    """Randomly sample up to ``slice_count`` distinct slices from a dataset.

    First ``image_count`` images are drawn without replacement from the dataset.
    Then, ``slice_count`` times, an image is picked from that subset and a slice
    index is drawn with ``random.randrange(lo, hi)``, where
    ``lo = int(width * min_slice_percent / 100)``,
    ``hi = int(width * max_slice_percent / 100)`` and ``width`` is the value of
    ``get_volume_width(image.path, slice_type)``. When ``lo == hi``, that value
    is used directly.

    Duplicate draws are collapsed, so fewer than ``slice_count`` slices may be
    returned. The result keeps first-draw order rather than hash-dependent set
    order.

    Raises:
        ValueError: from ``random.sample`` when ``image_count`` exceeds the
            number of images, or from ``random.randrange`` when ``lo > hi``.
    """
    images = random.sample(backend.get_images(dataset), image_count)
    slices: List[ImageSlice] = []
    # get_volume_width only depends on the path here (slice_type is fixed for the
    # whole call), so read each volume's width once instead of once per draw.
    widths = {}  # image path -> result of get_volume_width(path, slice_type)

    for _ in range(slice_count):
        im: DataImage = random.choice(images)
        if im.path not in widths:
            widths[im.path] = get_volume_width(im.path, slice_type)
        im_slice_max = widths[im.path]

        slice_min = int(im_slice_max * (min_slice_percent / 100))
        slice_max = int(im_slice_max * (max_slice_percent / 100))

        if slice_min == slice_max:
            # randrange would fail on an empty range; use the single bound instead.
            sl = slice_max
        else:
            sl = random.randrange(slice_min, slice_max)
        slices.append(ImageSlice(im.name, sl, slice_type))

    # dict.fromkeys drops duplicates like set() but keeps first-seen order, so the
    # result does not depend on (per-process randomized) string hashes.
    return list(dict.fromkeys(slices))
예제 #9
0
def session_overview(session_id: int):
    """Render the overview page for a single label session.

    The template is chosen from ``label_session.session_type``: categorical
    image, categorical slice, comparison slice or sort slice. For sort-slice
    sessions ``comparesort.add_next_comparison`` is called first, and the first
    item of its result is passed to the template as ``labels_complete``.

    Args:
        session_id: Primary key of the label session to show.

    Returns:
        The rendered HTML of the matching overview template.

    Raises:
        Aborts with HTTP 404 when no session with this id is found, and with
        HTTP 500 for an unrecognised session type.
    """
    label_session = sessions.get_session_by_id(db.session, session_id)
    if label_session is None:
        # NOTE(review): assumes get_session_by_id returns None for an unknown id.
        # Answer 404 (as dataset_overview does for an unknown dataset) instead of
        # failing with AttributeError on ``label_session.dataset`` below.
        abort(404)

    dataset = backend.get_dataset(label_session.dataset)

    # Index of the first element that has no labels yet; None when every element
    # is labelled. Passed to the template as ``resume_point``.
    resume_point = next(
        (el.element_index for el in label_session.elements if not el.labels), None)

    if label_session.session_type == LabelSessionType.CATEGORICAL_IMAGE.name:
        # Only the image-level view needs the dataset's image list.
        images = backend.get_images(dataset)
        return render_template('session_overview_categorical.html',
                               label_session=label_session,
                               dataset=dataset,
                               images=images,
                               resume_point=resume_point)

    elif label_session.session_type == LabelSessionType.CATEGORICAL_SLICE.name:
        return render_template('session_overview_categorical_slice.html',
                               label_session=label_session,
                               dataset=dataset,
                               resume_point=resume_point)
    elif label_session.session_type == LabelSessionType.COMPARISON_SLICE.name:
        return render_template('session_overview_comparison.html',
                               label_session=label_session,
                               dataset=dataset,
                               resume_point=resume_point)
    elif label_session.session_type == LabelSessionType.SORT_SLICE.name:
        # Called before the element lists below are built. NOTE(review): it
        # presumably may extend label_session.elements, so keep this order.
        labels_complete = comparesort.add_next_comparison(db.session, label_session)[0]
        return render_template('session_overview_sort.html',
                               label_session=label_session,
                               dataset=dataset,
                               resume_point=resume_point,
                               labels_complete=labels_complete,
                               slice_elements=[el for el in label_session.elements if not el.is_comparison()],
                               comparison_elements=[el for el in label_session.elements if el.is_comparison()])
    else:
        abort(500)
예제 #10
0
def create_categorical_image_session(session: Session, name: str, prompt: str,
                                     dataset: Dataset,
                                     label_values: List[str]):
    """Create a categorical-image label session covering every image of a dataset.

    One ``SessionElement`` is created per image returned by
    ``backend.get_images(dataset)``, indexed in that order. Everything is
    committed with a single ``session.commit()`` at the end.

    Args:
        session: Database session used to persist the new rows.
        name: Name of the new label session.
        prompt: Prompt text stored on the session.
        dataset: Dataset whose images become the session's elements.
        label_values: Allowed label values, stored comma-joined.
    """
    dataset_images = backend.get_images(dataset)

    new_session = LabelSession(
        dataset=dataset.name,
        session_name=name,
        session_type=LabelSessionType.CATEGORICAL_IMAGE.name,
        prompt=prompt,
        date_created=datetime.now(),
        label_values_str=','.join(label_values),
        element_count=len(dataset_images),
    )
    session.add(new_session)

    session.add_all([
        SessionElement(element_index=index, image_1_name=image.name, session=new_session)
        for index, image in enumerate(dataset_images)
    ])

    session.commit()
예제 #11
0
def dataset_list():
    """Render the list of all datasets.

    Each entry passed to the template is a ``(dataset, images, label_sessions)``
    tuple. Entries follow the order of ``backend.get_datasets()``.
    """
    rows = []
    for ds in backend.get_datasets():
        ds_images = backend.get_images(ds)
        ds_sessions = sessions.get_sessions(db.session, ds)
        rows.append((ds, ds_images, ds_sessions))
    return render_template('dataset_list.html', datasets=rows)
예제 #12
0
    def test_get_images_length(self):
        """'dataset1' contains exactly three images."""
        listed = backend.get_images(backend.get_dataset('dataset1'))
        self.assertEqual(len(listed), 3)