Beispiel #1
0
    def test_get_hierarchy(self, mocker, container_export):
        """Building an exporter resolves its hierarchy from the origin container."""
        from_container = mocker.patch(
            "container_export.ContainerHierarchy.from_container"
        )
        # ExportLog is patched so the real class is not constructed; the
        # resulting mock is not inspected by this test.
        mocker.patch("container_export.ExportLog")

        _exporter, fixture_mocks = container_export(
            "test", "test", flywheel.Session()
        )

        expected_client = fixture_mocks["context"].client
        from_container.assert_called_once_with(expected_client, flywheel.Session())
Beispiel #2
0
class TestGetDestination:
    """Tests for ``get_destination``."""

    @pytest.mark.parametrize(
        "parent,raising",
        [
            (flywheel.Subject(label="test"), does_not_raise()),
            (flywheel.Session(label="test"), does_not_raise()),
            (flywheel.Group(label="test"), pytest.raises(ValueError)),
            (flywheel.Project(label="test"), pytest.raises(ValueError)),
            (flywheel.Acquisition(label="test"), pytest.raises(ValueError)),
        ],
    )
    def test_container(self, sdk_mock, parent, raising):
        """Subject/session parents are returned; other parent types raise ValueError."""
        container = flywheel.models.analysis_output.AnalysisOutput(
            parent=parent, id="test"
        )
        sdk_mock.get_analysis.return_value = container
        sdk_mock.get.return_value = parent

        with raising:
            dest = get_destination(sdk_mock, "test")

            # Only reached in the non-raising cases.
            sdk_mock.get_analysis.assert_called_once_with("test")
            assert isinstance(dest, parent.__class__)

    def test_analysis_does_not_exist(self, sdk_mock):
        """An ApiException raised by the SDK lookup propagates to the caller."""
        container = flywheel.models.analysis_output.AnalysisOutput(
            parent=flywheel.Project(), id="test"
        )
        sdk_mock.get.side_effect = flywheel.rest.ApiException(status=404)
        sdk_mock.get_analysis.return_value = container
        # Nothing after the raising call inside this block would execute, so
        # the call is the only statement here.
        with pytest.raises(flywheel.rest.ApiException):
            get_destination(sdk_mock, "test")
Beispiel #3
0
    def test_log(self, mocker, container_export):
        """Reading ``log`` requests a logger named after the export."""
        exporter, _fixture_mocks = container_export(
            "test", "test", flywheel.Session(), mock=True
        )
        get_logger = mocker.patch("container_export.logging.getLogger")

        # Accessing the property is the action under test; its value is unused.
        _ = exporter.log

        get_logger.assert_called_once_with("GRP-9 Session None Export")
def test_define_created():
    """define_created describes a session by type and id, flagged as new."""
    session_oid = bson.ObjectId()
    expected = {"container": "session", "id": session_oid, "new": True}

    result = define_created(flywheel.Session(id=session_oid))

    assert result == expected
Beispiel #5
0
class TestNeedsExport:
    """Tests for ``container_needs_export``."""

    # Every (container type, tags, force) combination, in the same order as
    # an explicit table. A container tagged EXPORTED only needs exporting
    # again when force_export is set.
    @pytest.mark.parametrize(
        "dest, tags, force, result",
        [
            (container_cls(id="test"), list(tags), force, force or not tags)
            for container_cls in (flywheel.Subject, flywheel.Session)
            for tags in ([], ["EXPORTED"])
            for force in (True, False)
        ],
    )
    def test_container_exists(self, dest, tags, force, result):
        dest.tags = tags
        settings = {"force_export": force}

        assert container_needs_export(dest, settings) == result
Beispiel #6
0
    def test_session_errors(self):
        """Invalid session requests are rejected with the expected HTTP status."""
        fw = self.fw

        # Creating a session without a project id is rejected with 400.
        with self.assertRaises(
                flywheel.ApiException,
                msg='Expected ApiException creating invalid session!') as ctx:
            fw.add_session(flywheel.Session(label=self.rand_string()))
        self.assertEqual(ctx.exception.status, 400)

        # Fetching a session that doesn't exist is rejected with 404.
        with self.assertRaises(
                flywheel.ApiException,
                msg='Expected ApiException retrieving invalid session!') as ctx:
            fw.get_session('DOES_NOT_EXIST')
        self.assertEqual(ctx.exception.status, 404)
Beispiel #7
0
    def test_session_analysis(self):
        """An analysis added to a session is listed with its file input."""
        client = self.fw
        file_name = 'yeats.txt'

        # Create the session the analysis will hang off
        new_session = flywheel.Session(project=self.project_id,
                                       label=self.rand_string())
        new_session_id = client.add_session(new_session)
        self.assertNotEmpty(new_session_id)

        client.upload_file_to_session(
            new_session_id,
            flywheel.FileSpec(file_name,
                              'When a vast image out of Spiritus Mundi'))

        input_ref = flywheel.FileReference(id=new_session_id,
                                           type='session',
                                           name=file_name)
        analysis_input = flywheel.AnalysisInput(label=self.rand_string(),
                                                description=self.rand_string(),
                                                inputs=[input_ref])

        # Attach the analysis
        new_analysis_id = client.add_session_analysis(new_session_id,
                                                      analysis_input)
        self.assertNotEmpty(new_analysis_id)

        # The session now lists exactly that analysis
        listed = client.get_session_analyses(new_session_id)
        self.assertEqual(len(listed), 1)

        found = listed[0]
        self.assertEqual(found.id, new_analysis_id)
        self.assertEmpty(found.job)

        self.assertTimestampBeforeNow(found.created)
        self.assertGreaterEqual(found.modified, found.created)

        self.assertEqual(len(found.inputs), 1)
        self.assertEqual(found.inputs[0].name, file_name)
Beispiel #8
0
                call('tar zcvf %s.tar.gz %s' % (pjoin(
                    top_level, extra_file), pjoin(top_level, extra_file)),
                     shell=True)
                tarball = '%s.tar.gz' % pjoin(top_level, extra_file)
                fw.upload_file_to_acquisition(upload_id, tarball)
                call('/bin/rm %s' % tarball, shell=True)
            else:
                print('taring folder %s and uploading to %s' %
                      (pjoin(top_level, extra_file), folder))
                fw.upload_file_to_acquisition(upload_id,
                                              pjoin(top_level, extra_file))
    return True


# Upload every scan id in `scans` that is not already present in Flywheel.
for scanid in scans:
    print('uploading %s' % scanid)
    # A scan counts as uploaded if either its full id or the part before the
    # first '.' is already among the Flywheel scan ids.
    if scanid in fw_scanids or scanid.split('.')[0] in fw_scanids:
        print('already uploaded')
        continue
    print('adding session %s to project %s' % (scanid, study))
    session_id = fw.add_session(
        flywheel.Session(project=study_id, label=scanid))
    fw.modify_session(session_id, {'subject': {'code': scanid}})
    runs = upload_fmri(session_id, scanid)
    # NOTE(review): upload_fmri appears to return False on failure and run
    # information otherwise; behavioral data is only uploaded in the latter case.
    if runs is not False:
        upload_behavioral(session_id, scanid, runs)
    # The return values of these uploads were previously stored in a variable
    # that was overwritten each time and never read; they are called for
    # their side effects only.
    upload_anatomical(session_id, scanid)
    upload_dwi(session_id, scanid)
    upload_extra(session_id, scanid)
Beispiel #9
0
    def test_sessions(self):
        """Walk a session through add, get, list, modify, notes/tags/info,
        info-field deletion and delete."""
        client = self.fw

        label = self.rand_string()
        subject = flywheel.Subject(code=self.rand_string_lower(),
                                   firstname=self.rand_string(),
                                   lastname=self.rand_string(),
                                   sex='other',
                                   age=util.years_to_seconds(56),
                                   info={'some-subject-key': 37})
        new_session = flywheel.Session(label=label,
                                       project=self.project_id,
                                       info={'some-key': 37},
                                       subject=subject)

        # Add
        new_id = client.add_session(new_session)
        self.assertNotEmpty(new_id)

        # Get
        fetched = client.get_session(new_id)
        self.assertEqual(fetched.id, new_id)
        self.assertEqual(fetched.label, label)
        self.assertIn('some-key', fetched.info)
        self.assertEqual(fetched.info['some-key'], 37)
        self.assertTimestampBeforeNow(fetched.created)
        self.assertGreaterEqual(fetched.modified, fetched.created)
        self.assertIsNotNone(fetched.subject)
        self.assertEqual(fetched.subject.firstname,
                         new_session.subject.firstname)
        self.assertEqual(fetched.age_years, 56)

        # The generic getter returns the same data
        self.assertEqual(client.get(new_id).to_dict(), fetched.to_dict())

        # Listing every session includes the new one
        all_sessions = client.get_all_sessions()
        self.assertNotEmpty(all_sessions)

        self.sanitize_for_collection(fetched)
        self.assertIn(fetched, all_sessions)

        # ...and so does listing the parent project's sessions
        project_sessions = client.get_project_sessions(self.project_id)
        self.assertIn(fetched, project_sessions)

        # Modify the label
        renamed = self.rand_string()
        client.modify_session(new_id, flywheel.Session(label=renamed))

        modified = client.get_session(new_id)
        self.assertEqual(modified.label, renamed)
        self.assertEqual(modified.created, fetched.created)
        self.assertGreater(modified.modified, fetched.modified)

        # Notes and tags
        note_text = 'This is a note'
        client.add_session_note(new_id, note_text)

        tag_value = 'example-tag'
        client.add_session_tag(new_id, tag_value)

        # Replace the info wholesale, then set additional keys on top
        client.replace_session_info(new_id, {'foo': 3, 'bar': 'qaz'})
        client.set_session_info(new_id, {'foo': 42, 'hello': 'world'})

        # Verify notes, tags and info
        current = client.get_session(new_id)

        self.assertEqual(len(current.notes), 1)
        self.assertEqual(current.notes[0].text, note_text)

        self.assertEqual(len(current.tags), 1)
        self.assertEqual(current.tags[0], tag_value)

        self.assertEqual(current.info['foo'], 42)
        self.assertEqual(current.info['bar'], 'qaz')
        self.assertEqual(current.info['hello'], 'world')

        # Remove selected info fields
        client.delete_session_info_fields(new_id, ['foo', 'bar'])

        current = client.get_session(new_id)
        for removed_key in ('foo', 'bar'):
            self.assertNotIn(removed_key, current.info)
        self.assertEqual(current.info['hello'], 'world')

        # Delete the session and confirm it is no longer listed
        client.delete_session(new_id)

        remaining = client.get_all_sessions()
        self.sanitize_for_collection(current)
        self.assertNotIn(current, remaining)
Beispiel #10
0
    def test_session_files(self):
        """Upload, inspect, modify, classify, annotate and delete a session file."""
        client = self.fw
        file_name = 'yeats.txt'

        new_id = client.add_session(
            flywheel.Session(label=self.rand_string(), project=self.project_id))

        # Upload a file
        poem = 'The best lack all conviction, while the worst'
        client.upload_file_to_session(new_id, flywheel.FileSpec(file_name, poem))

        # The file is attached to the session with the expected metadata
        uploaded = client.get_session(new_id)
        self.assertEqual(len(uploaded.files), 1)
        entry = uploaded.files[0]
        self.assertEqual(entry.name, file_name)
        self.assertEqual(entry.size, 45)
        self.assertEqual(entry.mimetype, 'text/plain')

        # Content round-trips via direct download...
        self.assertDownloadFileTextEquals(
            client.download_file_from_session_as_data, new_id, file_name, poem)

        # ...and via an unauthenticated, ticketed download URL
        self.assertDownloadFileTextEqualsWithTicket(
            client.get_session_download_url, new_id, file_name, poem)

        # Default file attributes
        self.assertEqual(entry.modality, None)
        self.assertEmpty(entry.classification)
        self.assertEqual(entry.type, 'text')

        response = client.modify_session_file(
            new_id, file_name,
            flywheel.FileEntry(modality='modality', type='type'))

        # No jobs were triggered, and the attributes changed
        self.assertEqual(response.jobs_spawned, 0)

        entry = client.get_session(new_id).files[0]
        self.assertEqual(entry.modality, "modality")
        self.assertEmpty(entry.classification)
        self.assertEqual(entry.type, 'type')

        # Replace the classification outright
        response = client.replace_session_file_classification(
            new_id, file_name, {'Custom': ['measurement1', 'measurement2']})
        self.assertEqual(response.modified, 1)
        self.assertEqual(response.jobs_spawned, 0)

        entry = client.get_session(new_id).files[0]
        self.assertEqual(entry.classification,
                         {'Custom': ['measurement1', 'measurement2']})

        # Add and delete individual classification values
        classification_delta = {
            'add': {'Custom': ['HelloWorld']},
            'delete': {'Custom': ['measurement2']},
        }
        response = client.modify_session_file_classification(
            new_id, file_name, classification_delta)
        self.assertEqual(response.modified, 1)
        self.assertEqual(response.jobs_spawned, 0)

        entry = client.get_session(new_id).files[0]
        self.assertEqual(entry.classification,
                         {'Custom': ['measurement1', 'HelloWorld']})

        # File info: replace wholesale, then set one key on top
        self.assertEmpty(entry.info)
        client.replace_session_file_info(new_id, file_name,
                                         {'a': 1, 'b': 2, 'c': 3, 'd': 4})

        client.set_session_file_info(new_id, file_name, {'c': 5})

        file_info = client.get_session(new_id).files[0].info
        for key, expected in (('a', 1), ('b', 2), ('c', 5), ('d', 4)):
            self.assertEqual(file_info[key], expected)

        # Remove selected info fields
        client.delete_session_file_info_fields(new_id, file_name, ['c', 'd'])
        file_info = client.get_session(new_id).files[0].info
        self.assertEqual(file_info['a'], 1)
        self.assertEqual(file_info['b'], 2)
        self.assertNotIn('c', file_info)
        self.assertNotIn('d', file_info)

        # Replacing with an empty mapping clears the info
        client.replace_session_file_info(new_id, file_name, {})
        self.assertEmpty(client.get_session(new_id).files[0].info)

        # Delete the file
        client.delete_session_file(new_id, file_name)
        self.assertEmpty(client.get_session(new_id).files)

        # Delete the session
        client.delete_session(new_id)
Beispiel #11
0
class TestValidateContext:
    """Tests for ``validate_context``."""

    @pytest.mark.parametrize(
        "config, call_num",
        [
            (
                {
                    "export_project": "test1",
                    "force_export": True,
                    "check_gear_rules": True,
                },
                1,
            ),
            (
                {
                    "export_project": "test1",
                    "force_export": False,
                    "check_gear_rules": True,
                },
                1,
            ),
            (
                {
                    "export_project": "test1",
                    "archive_project": "test2",
                    "force_export": True,
                },
                2,
            ),
            (
                {
                    "export_project": "test1",
                    "archive_project": "test2",
                    "force_export": False,
                },
                2,
            ),
        ],
    )
    def test_validate_calls(self, mocker, gear_context, config, call_num, caplog):
        """Each configured project is looked up once, and gear rules are only
        validated when ``check_gear_rules`` is present in the config."""
        caplog.set_level(logging.INFO)
        # A single Project. The previous trailing comma made this a 1-tuple.
        mock_proj = flywheel.Project(
            label="test",
            parents=flywheel.models.container_parents.ContainerParents(
                group="test"
            ),
        )
        gear_context.config = config
        get_proj_mock = mocker.patch("validate.get_project")
        get_proj_mock.return_value = mock_proj

        get_dest_mock = mocker.patch("validate.get_destination")
        get_dest_mock.return_value = flywheel.Subject(label="test")

        check_exported_mock = mocker.patch("validate.container_needs_export")
        check_exported_mock.return_value = True

        check_gear_rules_mock = mocker.patch("validate.validate_gear_rules")
        check_gear_rules_mock.return_value = True

        export, archive, dest = validate_context(gear_context)

        assert get_proj_mock.call_count == call_num
        get_dest_mock.assert_called_once_with(gear_context.client, "test")
        check_exported_mock.assert_called_once_with(
            flywheel.Subject(label="test"), config
        )
        msgs = [rec.message for rec in caplog.records]
        if "check_gear_rules" in config:
            check_gear_rules_mock.assert_called_once_with(
                gear_context.client, mock_proj
            )
            assert "No enabled rules were found. Moving on..." in msgs
        else:
            check_gear_rules_mock.assert_not_called()
            assert "No enabled rules were found. Moving on..." not in msgs

    @pytest.mark.parametrize(
        "proj",
        [
            {"export_project": flywheel.Project(label="export")},
            {"archive_project": flywheel.Project(label="archvie")},
        ],
    )
    def test_get_proj_errors(self, mocker, sdk_mock, gear_context, caplog, proj):
        """A failing project lookup aborts with SystemExit and logs an error."""
        gear_context.config = {"export_project": "test", "archive_project": "test"}
        gear_context.config.update(proj)

        def get_proj_side_effect(fw, project):
            # Plain-string project ids fail the lookup; Project objects
            # return None.
            if not hasattr(project, "label"):
                raise flywheel.rest.ApiException(status=600)

        get_proj_mock = mocker.patch("validate.get_project")
        get_proj_mock.side_effect = get_proj_side_effect

        with pytest.raises(SystemExit):
            validate_context(gear_context)

        # Checked outside the ``with`` block. Statements after the raising
        # call inside it never run, and teardown-phase records are not yet
        # available while the test body runs.
        assert any(rec.levelno >= logging.ERROR for rec in caplog.records)

    @pytest.mark.parametrize(
        "to_mock,val,to_err,raises,log",
        [
            (
                ["get_project"],
                [None],
                dict(),
                True,
                "Export project needs to be specified",
            ),
            (
                ["get_project", "validate_gear_rules"],
                [flywheel.Project(label="test"), False],
                dict(),
                True,
                "Aborting Session Export: test has ENABLED GEAR RULES and 'check_gear_rules' == True. If you would like to force the export regardless of enabled gear rules re-run.py the gear with 'check_gear_rules' == False. Warning: Doing so may result in undesired behavior.",
            ),
            (
                ["get_project", "validate_gear_rules", "get_destination"],
                ["test", True, None],
                {"get_destination": ValueError("test")},
                True,
                "Could not find destination with id test",
            ),
            (
                ["get_project", "validate_gear_rules", "get_destination"],
                ["test", True, None],
                {"get_destination": ApiException(status=20)},
                True,
                "Could not find destination with id test",
            ),
            (
                [
                    "get_project",
                    "validate_gear_rules",
                    "get_destination",
                    "container_needs_export",
                ],
                ["test", True, flywheel.Session(label="test"), False],
                dict(),
                True,
                "session test has already been exported and <force_export> = False. Nothing to do!",
            ),
        ],
    )
    def test_errors(
        self, mocker, to_mock, val, to_err, raises, log, caplog, gear_context
    ):
        """Each failing validation step ends the gear with SystemExit.

        ``to_mock``/``val`` pair names in ``validate`` with the return value of
        their patch; ``to_err`` maps a patched name to the exception it raises.
        """
        gear_context.config = {
            "destination": {"id": "test", "container_type": "session", "label": "test"},
            "export_project": "test",
            "archive_project": "test",
            "check_gear_rules": True,
        }
        mocks = {}
        # Distinct loop names so the ``val`` parameter is not shadowed.
        for name, return_value in zip(to_mock, val):
            mocks[name] = mocker.patch(f"validate.{name}")
            mocks[name].return_value = return_value
        for name, err in to_err.items():
            if name in mocks:
                mocks[name].side_effect = err
        my_raise = pytest.raises(SystemExit) if raises else does_not_raise()
        with my_raise:
            validate_context(gear_context)
        # NOTE(review): ``log`` holds the message each case expects, but it is
        # not asserted yet. Consider checking it against ``caplog.text`` once
        # the exact wording is confirmed against validate.py.
Beispiel #12
0
class TestContainerExporter:
    """Tests for ``ContainerExporter``."""

    @pytest.mark.parametrize(
        "origin,raises",
        [
            (flywheel.Session(label="origin"), does_not_raise()),
            ("origin", pytest.raises(AttributeError)),
        ],
    )
    def test_init(self, mocker, origin, raises):
        """The constructor resolves the hierarchy. For a container origin it
        also sets its attributes; a plain string origin raises AttributeError."""
        gear_context_mock = MagicMock(
            spec=dir(flywheel_gear_toolkit.GearToolkitContext))
        hierarchy_patch = mocker.patch(
            "container_export.ContainerExporter.get_hierarchy")
        log_patch = mocker.patch("container_export.ExportLog")

        exporter = None
        with raises:
            exporter = ContainerExporter("export", "archive", origin,
                                         gear_context_mock)

        hierarchy_patch.assert_called_once_with(origin)

        # Validate attributes only when the exporter was actually built
        if hasattr(origin, "container_type"):
            log_patch.assert_called_once_with("export", "archive")

            for attr in ["status", "_log"]:
                assert getattr(exporter, attr) is None

            assert exporter.gear_context == gear_context_mock
            assert exporter.origin_container == origin
            assert exporter.container_type == origin.container_type

    @pytest.mark.parametrize(
        "origin",
        [flywheel.Subject(code="origin"),
         flywheel.Session(label="origin")])
    def test_from_gear_context(self, mocker, origin):
        """from_gear_context builds an exporter from validate_context's result."""
        gear_context_mock = MagicMock(
            spec=dir(flywheel_gear_toolkit.GearToolkitContext))
        log_patch = mocker.patch("container_export.ExportLog")
        hierarchy_patch = mocker.patch(
            "container_export.ContainerExporter.get_hierarchy")

        validate_patch = mocker.patch("container_export.validate_context")
        export_proj = flywheel.Project(label="export")
        # A single Project, mirroring export_proj. The previous trailing comma
        # made this a 1-tuple.
        archive_proj = flywheel.Project(label="archive")
        validate_patch.return_value = [
            export_proj,
            archive_proj,
            origin,
        ]

        exporter = ContainerExporter.from_gear_context(gear_context_mock)

        assert exporter.origin_container == origin
        log_patch.assert_called_once_with(export_proj, archive_proj)
        hierarchy_patch.assert_called_once_with(origin)

    def test_log(self, mocker, container_export):
        export, mocks = container_export("test",
                                         "test",
                                         flywheel.Session(),
                                         mock=True)
        log_mock = mocker.patch("container_export.logging.getLogger")

        # Accessing the property is what triggers getLogger.
        _ = export.log
        log_mock.assert_called_once_with("GRP-9 Session None Export")

    @pytest.mark.parametrize(
        "container,exp",
        [
            (flywheel.Session(), "test-None_export_log.csv"),
            (flywheel.Subject(), "test_export_log.csv"),
        ],
    )
    def test_csv_path(self, mocker, container_export, container, exp):
        export, mocks = container_export("test", "test", container, mock=True)
        mocks["hierarchy"].return_value.subject.label = "test"
        mocks["context"].output_dir = "/tmp/gear"

        path = export.csv_path

        assert path == f"/tmp/gear/{exp}"

    def test_get_hierarchy(self, mocker, container_export):
        hierarchy_mock = mocker.patch(
            "container_export.ContainerHierarchy.from_container")
        # Patched so the real ExportLog is not constructed; not inspected.
        mocker.patch("container_export.ExportLog")
        _, mocks = container_export("test", "test", flywheel.Session())

        hierarchy_mock.assert_called_once_with(mocks["context"].client,
                                               flywheel.Session())

    @pytest.mark.parametrize("info", [{"test": "test"}, {"test": None}, {}])
    @pytest.mark.parametrize("ctype,other", [("Session", {
        "age": "10"
    }), ("Subject", {
        "sex": "F"
    })])
    def test_get_create_container_kwargs(self, mocker, c, info, ctype, other):
        """The create kwargs keep the container's fields and stamp the hashed
        origin id into ``info.export``."""
        # ``info`` dicts are shared between parametrized cases, so neither the
        # test nor the code under test may mutate them: pass a copy and build
        # the expectation as a new dict.
        container = c(ctype, id="test", info=dict(info), **other)
        out = ContainerExporter.get_create_container_kwargs(container)

        assert all(key in out for key in other)
        origin_id = hash_value("test")
        assert out.get("info").get("export").get("origin_id") == origin_id
        expected_info = {**info, "export": {"origin_id": origin_id}}
        assert expected_info == out.get("info")

    @pytest.mark.parametrize(
        "container,label",
        [
            (
                flywheel.Session(id="test"),
                (f"info.export.origin_id={hash_value('test')}", ),
            ),
            (
                flywheel.Subject(label="test", code="test"),
                ("label=test", "code=test"),
            ),
            (
                flywheel.Subject(label="5", code="5"),
                ('label="5"', 'code="5"'),
            ),
        ],
    )
    def test_get_container_find_queries(self, container, label):

        queries = ContainerExporter.get_container_find_queries(container)

        assert queries == label

    @pytest.mark.parametrize("same", [True, False])
    @pytest.mark.parametrize(
        "origin,export,parent,par_type",
        [
            (
                flywheel.Session(id="test"),
                flywheel.Session(
                    id="test2",
                    info={"export": {
                        "origin_id": hash_value("test")
                    }}),
                # ``list.extend`` returns None (which disabled the spec), so
                # the extra attribute is appended with ``+``.
                MagicMock(spec=dir(flywheel.Subject) + ["sessions"]),
                "subject",
            ),
            (
                flywheel.Subject(label="test"),
                flywheel.Subject(label="test"),
                MagicMock(spec=dir(flywheel.Project) + ["subjects"]),
                "project",
            ),
        ],
    )
    def test_find_container_copy(self, origin, export, parent, mocker,
                                 par_type, same):
        parent.container_type = par_type
        parent.id = "test_parent"
        export.parents = flywheel.ContainerParents(**{par_type: parent.id})
        origin.parents = flywheel.ContainerParents(
            **{par_type: parent.id if same else "test_parent2"})

        finder_mock = getattr(parent, f"{export.container_type}s").find_first
        finder_mock.return_value = export

        out = ContainerExporter.find_container_copy(origin, parent)
        if not same:
            if par_type == "project":
                finder_mock.assert_called_once_with("label=test")
            else:
                finder_mock.assert_called_once_with(
                    f"info.export.origin_id={hash_value('test')}")
        else:
            assert out == origin

    @pytest.mark.parametrize(
        "origin,parent",
        [
            (flywheel.Session(id="test"),
             MagicMock(spec=dir(flywheel.Subject))),
            (
                flywheel.Session(id="test", tags=["test", "one"]),
                MagicMock(spec=dir(flywheel.Subject)),
            ),
        ],
    )
    def test_create_container_copy(self, origin, parent):
        add_mock = getattr(parent, f"add_{origin.container_type}")
        add_mock.return_value = origin
        out = ContainerExporter.create_container_copy(origin, parent)
        assert out.label == origin.label
        assert out.tags == origin.tags

    @pytest.mark.parametrize("found", [None, flywheel.Subject(label="test")])
    def test_find_or_create_container_copy(self, mocker, found):
        """An existing copy is returned as found; otherwise a new copy is created."""
        find_mock = mocker.patch(
            "container_export.ContainerExporter.find_container_copy")
        find_mock.return_value = found
        create_mock = mocker.patch(
            "container_export.ContainerExporter.create_container_copy")
        create_mock.return_value = flywheel.Subject(label="test2")
        container, created = ContainerExporter.find_or_create_container_copy(
            flywheel.Subject(label="test"), "test")

        assert created == (found is None)
        # Compute the expectation first. The old single-line form parsed as
        # ``(container == a) if cond else b``, so the created branch only
        # asserted that ``b`` was truthy.
        expected = (flywheel.Subject(label="test")
                    if found is not None else flywheel.Subject(label="test2"))
        assert container == expected

    @pytest.mark.parametrize("base", [flywheel.Subject, flywheel.Session])
    def test_export_container_files(self, sdk_mock, mocker, base):

        exporter_mock = mocker.patch(
            "container_export.FileExporter.from_client")
        origin = base(files=[])
        side_effect = []
        for i in range(10):
            origin.files.append(flywheel.FileEntry(name=str(i)))
            # Odd-numbered files fail (None); of the rest, multiples of 3 are
            # reported as newly created.
            side_effect.append((str(i) if i % 2 == 0 else None, i % 3 == 0))
        exporter_mock.return_value.find_or_create_file_copy.side_effect = side_effect

        found, created, failed = ContainerExporter.export_container_files(
            sdk_mock, origin, "other", None)
        assert failed == ["1", "3", "5", "7", "9"]
        assert created == ["0", "6"]
        assert found == ["2", "4", "8"]

    @pytest.mark.parametrize(
        "container",
        [
            flywheel.Subject(label="test"),
            flywheel.Session(label="test",
                             subject=flywheel.Subject(label="test")),
        ],
    )
    def test_get_subject_export_params(self, mocker, container_export,
                                       container):
        if container.container_type == "subject":
            cont_mock = mocker.patch.object(container, "reload")
            cont_mock.return_value = "mocked"
        else:
            cont_mock = mocker.patch.object(container.subject, "reload")
            cont_mock.return_value = "mocked"

        export, mocks = container_export("test", "test", container, mock=True)

        orig, proj, att, hier = export.get_subject_export_params()

        assert orig == "mocked"
        assert proj == "test"
        if container.container_type == "subject":
            assert att is None
        else:
            assert att is False

        assert mocks["hierarchy"].call_count == 1

    @pytest.mark.parametrize(
        "origin,ctype",
        [
            (MagicMock(spec=dir(flywheel.Session)), "session"),
            (MagicMock(spec=(dir(flywheel.Subject) + ["sessions"])),
             "subject"),
        ],
    )
    def test_get_origin_sessions(self, container_export, origin, ctype):
        origin.container_type = ctype
        if ctype == "subject":
            origin.sessions.iter.return_value = ["1", "2", "3"]
        else:
            origin.reload.return_value = origin
        container_ex, mocks = container_export("test",
                                               "test",
                                               origin,
                                               mock=True)

        sess = container_ex.get_origin_sessions()

        if ctype == "subject":
            assert sess == ["1", "2", "3"]
        else:
            assert sess == [origin]
Beispiel #13
0
def test_container_hierarchy():
    """Exercise ContainerHierarchy construction, navigation and DICOM mapping."""
    hierarchy_dict = {
        "group":
        flywheel.Group(id="test_group", label="Test Group"),
        "project":
        flywheel.Project(label="test_project"),
        "subject":
        flywheel.Subject(label="test_subject", sex="other"),
        "session":
        flywheel.Session(
            age=31000000,
            label="test_session",
            weight=50,
        ),
    }
    # test from_dict
    test_hierarchy = ContainerHierarchy.from_dict(hierarchy_dict)
    # test deepcopy
    # NOTE(review): this asserts the copy compares unequal, which only holds
    # if ContainerHierarchy uses identity equality. Confirm this is intended.
    assert deepcopy(test_hierarchy) != test_hierarchy
    # test path
    assert test_hierarchy.path == "test_group/test_project/test_subject/test_session"
    # test parent
    assert test_hierarchy.parent.label == "test_subject"
    # test from_container
    mock_client = MagicMock(spec=dir(flywheel.Client))
    parent_dict = dict()
    for item in ("group", "project", "subject"):
        value = hierarchy_dict.get(item)
        parent_dict[item] = item
        # Bind ``value`` as a default argument. A plain closure would read
        # the variable at call time, so every getter would return the last
        # container (the subject).
        setattr(mock_client, f"get_{item}", lambda x, value=value: value)
    session = flywheel.Session(age=31000000, label="test_session", weight=50)
    session.parents = parent_dict
    assert (ContainerHierarchy.from_container(
        mock_client, session).container_type == "session")
    # test _get_container
    assert test_hierarchy._get_container(None, None, None) is None
    # The message is checked via ``match``. An assertion placed after the
    # raising call inside the block would never run.
    with pytest.raises(ValueError,
                       match="Cannot get a container of type garbage"):
        test_hierarchy._get_container(None, "garbage", "garbage_id")
    mock_client = MagicMock(spec=dir(flywheel.Client))
    mock_client.get_session = lambda x: x
    assert (test_hierarchy._get_container(mock_client, "session",
                                          "session_id") == "session_id")
    # test container_type
    assert test_hierarchy.container_type == "session"
    # test dicom_map
    exp_map = {
        "PatientWeight": 50,
        "PatientAge": "011M",
        "ClinicalTrialTimePointDescription": "test_session",
        "PatientSex": "O",
        "PatientID": "test_subject",
    }
    assert exp_map == test_hierarchy.dicom_map
    # test get
    assert test_hierarchy.get("container_type") == "session"
    # test get_patient_sex_from_subject
    assert test_hierarchy.get_patientsex_from_subject(flywheel.Subject()) == ""
    # test get_patientage_from_session
    assert test_hierarchy.get_patientage_from_session(
        flywheel.Session()) is None

    # test get_child_hierarchy
    test_acquisition = flywheel.Acquisition(label="test_acquisition")
    acq_hierarchy = test_hierarchy.get_child_hierarchy(test_acquisition)
    assert acq_hierarchy.dicom_map[
        "SeriesDescription"] == test_acquisition.label
    # test get_parent_hierarchy
    parent_hierarchy = test_hierarchy.get_parent_hierarchy()
    assert parent_hierarchy.container_type == "subject"