Example #1
    def get_html_report(self, lod: LOD) -> str:
        """
        Generates report data and return back
        :param lod: Level of details. See class LOD
        :return: str
        """
        report = self.report
        keys = BoxList()
        statuses = BoxList()
        xl_records = BoxList()
        db_records = BoxList()
        mismatches = BoxList()

        for cnt, k in enumerate(self.report.keys):
            if lod in (LOD.ALL_FULL, LOD.ALL_MID) or (lod == LOD.MISMATCH and report.statuses[cnt] != Status.NO_CHANGE):
                keys.append(k)
                statuses.append(report.statuses[cnt])
                mismatches.append(report.mismatches[cnt])
                # Full record contents are shown only at ALL_FULL, or for mismatched rows
                show_full = lod == LOD.ALL_FULL or (
                        lod == LOD.MISMATCH and report.statuses[cnt] != Status.NO_CHANGE)
                xl_records.append(report.xl_records[cnt] if show_full else '-')
                db_records.append(report.db_records[cnt] if show_full else '-')

        html_text = f'<p>Total xl_records: {self.total_xl_records}</p><p>Total DB records: {self.total_db_records}</p>'\
                    f'<p>Number of Issues: {report.issue_cnt}</p>'
        if lod != LOD.SUMMARY:
            data = {f'key{self.index_keys}': keys, 'status': statuses, 'db_update': None,
                    'mismatch': mismatches}
            if report.xl_records:
                data['xl_record'] = xl_records
            if report.db_records:
                data['db_record'] = db_records
            html_text += pd.DataFrame(data).to_html()
        return html_text
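
Side note: BoxList subclasses Python's list, so the column lists built above can be handed straight to pandas. A minimal sketch of that interaction (illustrative names only; assumes python-box and pandas are installed):

from box import BoxList
import pandas as pd

keys = BoxList(['k1', 'k2'])
statuses = BoxList(['NO_CHANGE', 'MISMATCH'])
# BoxList is a list subclass, so DataFrame accepts it as a column directly
html = pd.DataFrame({'key': keys, 'status': statuses}).to_html()
print(html.splitlines()[0])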
Example #2
    def assign_jobs(self) -> Tuple[BoxList, List]:
        new_jobs = BoxList()
        exceptions = []
        for job in self.jobs_db.where('status', '==', JOB_STATUS_CREATED):
            try:
                if self.should_start_job(job):
                    log.info(f'Assigning job {box2json(job)}...')
                    self.assign_job(job)
                    new_jobs.append(job)
            except Exception as e:
                # Could have been a network failure, so just try again.
                # More granular exceptions should be handled before this
                # which can set the job to not run again
                # if that's what's called for.

                log.exception(f'Exception triggering eval for job {job}, '
                              f'will try again shortly.')
                exceptions.append(e)

        # TODO: Check for failed / crashed instance once per minute
        # TODO: Stop instances if they have been idle for longer than timeout
        # TODO: Cap total instances
        # TODO: Cap instances per bot owner, using first part of docker tag
        # TODO: Delete instances over threshold of stopped+started

        return new_jobs, exceptions
Example #3
    def check_gce_ops_in_progress(self):
        ops_still_in_progress = BoxList()
        for op in self.gce_ops_in_progress:

            try:
                op_result = Box(
                    self.gce.zoneOperations().get(project=self.project,
                                                  zone=self.zone,
                                                  operation=op.name).execute())
            except Exception:
                log.exception('Could not get op_result')
                break
            if op_result.status == 'DONE':
                if 'error' in op_result:
                    log.error(f'GCE operation resulted in an error: '
                              f'{op_result.error}\nOperation was:'
                              f'\n{box2json(op)}')
                    if op.operationType == 'insert':
                        # Retry the creation?
                        pass
                    # elif op.operationType == 'start':
                    #
            else:
                ops_still_in_progress.append(op)
        self.gce_ops_in_progress = ops_still_in_progress
Example #4
def save_problem_ci_results(ci_error, db, error, eval_data, gist, problem_ci,
                            results, should_merge):
    if not should_merge:
        # If problem_ci fails, don't save to aggregate bot scores collection
        if ci_error:
            log.error('Problem CI failed, not saving to bots '
                      'official scores as this is likely an issue '
                      'with the new version of the problem.')
            problem_ci.status = PROBLEM_CI_STATUS_FAILED
            problem_ci.error = ci_error
            update_pr_status_problem_ci(ci_error, problem_ci, eval_data)
        else:
            log.info('Problem CI not yet finished')

    else:
        # Aggregate data from bot evals now that they're done
        gists = BoxList()
        for bot_eval_key in problem_ci.bot_eval_keys:
            bot_eval = db.get(get_eval_db_key(bot_eval_key))
            save_to_bot_scores(
                bot_eval, bot_eval.eval_key,
                Box(score=bot_eval.results.score, eval_key=bot_eval.eval_key))
            gists.append(bot_eval.gist)
        problem_ci.gists = gists
        update_pr_status_problem_ci(error, problem_ci, eval_data)
        problem_ci.status = PROBLEM_CI_STATUS_PASSED
    db.set(problem_ci.id, problem_ci)
Example #5
 def test_box_list(self):
     new_list = BoxList({"item": x} for x in range(0, 10))
     new_list.extend([{"item": 22}])
     assert new_list[-1].item == 22
     new_list.append([{"bad_item": 33}])
     assert new_list[-1][0].bad_item == 33
     assert repr(new_list).startswith("<BoxList:")
     for x in new_list.to_list():
         assert not isinstance(x, (BoxList, Box))
     new_list.insert(0, {"test": 5})
     new_list.insert(1, ["a", "b"])
     new_list.append("x")
     assert new_list[0].test == 5
     assert isinstance(str(new_list), str)
     assert isinstance(new_list[1], BoxList)
     assert not isinstance(new_list.to_list(), BoxList)
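
The assertions above hinge on BoxList's conversion rules: dicts added to a BoxList become Box objects, nested lists become BoxLists, and to_list() converts everything back to plain types. A minimal sketch of just that behavior (assuming a recent python-box):

from box import Box, BoxList

bl = BoxList()
bl.append({'a': 1})            # dict is converted to Box on insert
assert isinstance(bl[0], Box) and bl[0].a == 1
bl.append([{'b': 2}])          # list is converted to BoxList on insert
assert isinstance(bl[1], BoxList)
plain = bl.to_list()           # converts back to plain dict/list types
assert isinstance(plain[0], dict) and isinstance(plain[1], list)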
Example #6
 def test_boxlist(self):
     new_list = BoxList({'item': x} for x in range(0, 10))
     new_list.extend([{'item': 22}])
     assert new_list[-1].item == 22
     new_list.append([{'bad_item': 33}])
     assert new_list[-1][0].bad_item == 33
     assert repr(new_list).startswith("<BoxList:")
     for x in new_list.to_list():
         assert not isinstance(x, (BoxList, Box, LightBox))
     new_list.insert(0, {'test': 5})
     new_list.insert(1, ['a', 'b'])
     new_list.append('x')
     assert new_list[0].test == 5
     assert isinstance(str(new_list), str)
     assert isinstance(new_list[1], BoxList)
     assert not isinstance(new_list.to_list(), BoxList)
Example #7
    def getPatientDict(self):
        parseDir = self.sourceDir

        if os.path.isfile(os.path.join(parseDir, 'Patient')):
            self.patientBaseDict = pinn2Json().read(os.path.join(parseDir, 'Patient'))
        else:
            logging.error('may be an empty plan! skip')
            return None

        patientImageSetList = BoxList()
        if 'ImageSetList' in self.patientBaseDict:
            # image_set_list = baseDict.get('ImageSetList')
            for imageSet in self.patientBaseDict.ImageSetList:
                logging.info('ImageSet_%s', imageSet.ImageSetID)
                if 'phantom' in imageSet.SeriesDescription:
                    logging.warning('this is Phantom for QA, skip!')
                    continue
                # read CT image set of this plan
                (imageSet['CTHeader'], imageSet['CTInfo'], imageSet['CTData']
                 ) = self.readCT(imageSet.ImageName)
                patientImageSetList.append(imageSet)
        self.patientBaseDict['imageSetListRawData'] = patientImageSetList

        patientPlanList = BoxList()
        if 'PlanList' in self.patientBaseDict:
            # plan_list = baseDict.get('PlanList')
            for plan in self.patientBaseDict.PlanList:
                logging.info('plan_%s, based on ImageSet_%s',
                             plan.PlanID, plan.PrimaryCTImageSetID)
                if 'QA' in plan.PlanName or 'copy' in plan.PlanName:
                    logging.warning('this is Copy or QA plan, skip!')
                else:
                    planDirName = 'Plan_' + str(plan.PlanID)
                    logging.info('Reading plan:%s ......', planDirName)
                    plan['planData'] = self.readPlan(
                        planDirName, plan.PrimaryCTImageSetID)
                patientPlanList.append(plan)
        self.patientBaseDict['planListRawData'] = patientPlanList
Example #8
 def test_frozen_list(self):
     bl = BoxList([5, 4, 3], frozen_box=True)
     with pytest.raises(BoxError):
         bl.pop(1)
     with pytest.raises(BoxError):
         bl.remove(4)
     with pytest.raises(BoxError):
         bl.sort()
     with pytest.raises(BoxError):
         bl.reverse()
     with pytest.raises(BoxError):
         bl.append("test")
     with pytest.raises(BoxError):
         bl.extend([4])
     with pytest.raises(BoxError):
         del bl[0]
     with pytest.raises(BoxError):
         bl[0] = 5
     bl2 = BoxList([5, 4, 3])
     del bl2[0]
     assert bl2[0] == 4
     bl2[1] = 4
     assert bl2[1] == 4
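
As the test shows, frozen_box=True makes the whole BoxList immutable: append, extend, pop, remove, sort, reverse, deletion, and item assignment all raise BoxError, while an ordinary BoxList stays mutable. A minimal sketch (assuming a recent python-box, which exports BoxError at package level):

from box import BoxList, BoxError

frozen = BoxList([1, 2, 3], frozen_box=True)
try:
    frozen.append(4)           # any mutation raises BoxError
except BoxError:
    print('frozen BoxList rejected the append')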
Example #9
def get_defaults(defaults):
    f = Box(default_box=True, read_only=False)
    config_f = None
    if 'formatting' in defaults:
        config_f = defaults.formatting

    # Sheet level
    f.read_only = getdictvalue(config_f, 'read_only', False)
    f.tab_color = "D9D9D9"
    f.position = getdictvalue(config_f, 'position', -1)

    # Table level
    table_style = getdictvalue(config_f, 'table_style', None)  # config_f may be None here
    f.table_style.name = getdictvalue(table_style, 'name',
                                      'TableStyleMedium2')
    # f.table_style.show_first_column = getdictvalue(table_style, 'show_first_column', False)
    f.table_style.show_last_column = getdictvalue(table_style,
                                                  'show_last_column', False)
    f.table_style.show_row_stripes = getdictvalue(table_style,
                                                  'show_row_stripes', True)
    # f.table_style.show_column_stripes = getdictvalue(table_style, 'show_column_stripes', True)

    # Column level
    # col_formatting = Box(default_box=True)
    data = BoxList()
    if not getdictvalue(config_f, 'data', None):
        # for column_config in defval_dict(config_f, 'data', []):
        #     for c in column_config.attributes:
        #         col_formatting[c] = _get_setting(column_config)
        col_formatting = _get_col_setting(None)
        col_formatting['attributes'] = "['*']"
        data.append(col_formatting)
    else:
        data = config_f.data
    f.data = data
    defaults._box_config['default_box'] = True
    defaults.formatting = f
    return defaults
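
get_defaults leans on default_box=True: attribute access on a missing key yields an intermediate Box instead of raising, which is what lets chained assignments like f.table_style.name work without creating 'table_style' first. A minimal sketch:

from box import Box

f = Box(default_box=True)
f.table_style.name = 'TableStyleMedium2'   # 'table_style' Box is created on the fly
assert f.table_style.name == 'TableStyleMedium2'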
Example #10
def run_botleague_ci(branch,
                     version,
                     set_version_fn,
                     pr_message,
                     supported_problems,
                     sim_url=None,
                     container_postfix=None) -> bool:
    # Send pull request to Botleague
    log.info('Sending pull requests to botleague for supported problems')
    github_token = os.environ['BOTLEAGUE_GITHUB_TOKEN']
    github_client = Github(github_token)
    # Get our fork owner
    botleague_fork_owner = 'deepdrive'

    # NOTE: Fork on github was manually created
    botleague_fork = github_client.get_repo(
        f'{botleague_fork_owner}/botleague')
    problem_cis = BoxList()
    for problem in supported_problems:
        hash_to_branch_from = get_head_commit('botleague/botleague',
                                              github_token)
        botleague_branch_name = f'deepdrive_{version}_' \
            f'id-{generate_rand_alphanumeric(3)}'
        fork_ref = botleague_fork.create_git_ref(
            ref=f'refs/heads/{botleague_branch_name}', sha=hash_to_branch_from)
        problem_json_path = f'problems/deepdrive/{problem}/problem.json'
        problem_def, problem_sha = get_file_from_github(
            botleague_fork, problem_json_path, ref=botleague_branch_name)

        set_version_fn(problem_def, version)

        # Add a newline before comments
        content = problem_def.to_json(indent=2) \
            .replace('\n  "$comment-', '\n\n  "$comment-')

        update_resp = botleague_fork.update_file(problem_json_path,
                                                 pr_message,
                                                 content=content,
                                                 sha=problem_sha,
                                                 branch=botleague_branch_name)
        pull = Box(
            title=f'CI trigger for {botleague_branch_name}',
            body='',
            head=f'{botleague_fork_owner}:{botleague_branch_name}',
            base='master',
        )

        set_pull_body(pull, sim_url, container_postfix)

        # if branch not in ['master']:  # Change to this after v3 is merged.
        if branch not in blconfig.release_branches:
            pull.draft = True

        pull_resp = create_pull_request(pull,
                                        repo_full_name='botleague/botleague',
                                        token=github_token)

        head_sha = Box(update_resp).commit.sha
        problem_cis.append(
            Box(pr_number=pull_resp.json()['number'], commit=head_sha))
    problem_cis = wait_for_problem_cis(problem_cis)
    if all(p.status == 'passed' for p in problem_cis):
        log.success(f'Problem CIs passed! Problem CIs were: '
                    f'{box2json(problem_cis)}')
        return True
    else:
        url_prefix = 'https://github.com/botleague/botleague/pull/'
        ci_urls = '\n'.join(
            [f'{url_prefix}{p.pr_number}' for p in problem_cis])
        raise RuntimeError(
            f'Problem CIs failed! Problem CIs were: '
            f'{box2json(problem_cis)}. Check the PRs for errors: {ci_urls}')
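
A detail worth calling out from this example: wrapping a nested response dict in Box, as in Box(update_resp).commit.sha, gives dotted access all the way down. A minimal sketch with a hypothetical response shape:

from box import Box

update_resp = {'commit': {'sha': 'abc123'}}   # hypothetical response shape
head_sha = Box(update_resp).commit.sha        # dotted access into the nested dict
assert head_sha == 'abc123'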
Example #11
    def update_db(self, force_update=False):
        # Let's filter out all FKEYs and M2Ms
        m2m_fields = [f.name for f in self.model._meta.many_to_many]
        fkey_fields = [f.name for f in self.model._meta.fields if f.many_to_one]
        concrete_fields = [f.name for f in self.model._meta.concrete_fields if not f.many_to_one]

        err_records = BoxList()

        for i, r in self.records.items():  # TODO: HG: Db update record counter should be returned and updated in the logs
            if r.status == Status.NO_CHANGE:
                continue
            if r.status in (Status.XL, Status.MISMATCH):
                # We cannot create record if FKEY doesn't exist in referenced DB table or record has MISMATCH
                datadict = Box(default_box=False)
                filter = Box()
                invalid_ref = False
                for f, refobj in r.refobjs.items():
                    if not refobj or None in refobj:
                        v = ','.join([re.sub(r'^\* ', '', i) for i in r.xl_record[f].rsplit('\n')])
                        logging.error(
                            f" {nm(self.model)} - Won't update record [{i}] since "
                            f"[{f}={v}] has a missing "
                            f"reference object. Ensure the reference record exists either in DB or XLS")
                        invalid_ref = True
                        break

                if invalid_ref:
                    continue

                for f in self.index_keys:
                    if not r.refobjs or f not in r.refobjs:
                        filter[f] = getattr(r.xl_record, f)
                    else:
                        filter[f + '_id'] = r.refobjs[f][0].db_record.pk

                for f, v in r.xl_record.items():
                    if f in concrete_fields:
                        datadict[f] = v
                    elif r.refobjs:
                        if not force_update and f in r.refobjs and any(ref.status != Status.NO_CHANGE for ref in r.refobjs[f]):
                            v = ','.join([re.sub(r'^\* ', '', i) for i in r.xl_record[f].rsplit('\n')])
                            logging.error(
                                f" {nm(self.model)} - Won't update record [{i}] since "
                                f"[{f}={v}] has a change. Use --force_update to update the reference and this record.")
                            logging.info(f'Mismatch Reference Object - {r.refobjs[f]}')
                            invalid_ref = True
                            break
                        elif f in [*fkey_fields, *r.refobjs] and (force_update or r.status == Status.XL):
                            datadict[f + '_id'] = r.refobjs[f][0].db_record.pk
                    else:
                        v = ','.join([re.sub(r'^\* ', '', i) for i in r.xl_record[f].rsplit('\n')])
                        logging.error(f" {nm(self.model)} - Won't update record [{i}] since [{f}={v}] "
                                      f"is invalid (it is neither concrete nor a reference)")
                        invalid_ref = True
                        break

                if invalid_ref:
                    continue

                if force_update or r.status == Status.XL:
                    # Don't update DB if the record is MISMATCH and force_update is False.
                    # The filter should always return exactly 1 object if it exists, else 0.
                    # TODO: HG: Below code is not compatible with django reference columns defined using `db_column`
                    # see https://docs.djangoproject.com/en/3.1/ref/models/fields/#database-representation

                    (dbobj, created) = self.model.objects.update_or_create(**filter, defaults=datadict)
                    r.db_record = dbobj
                    if r.refobjs:
                        for f in r.refobjs.keys() & m2m_fields:
                            getattr(dbobj, f).set([ro.db_record for ro in r.refobjs[f]]) # set will add all the records

                    dbobj.save()
                    r.status = Status.NO_CHANGE
                    logging.debug(f' {nm(self.model)}: {"created" if created else "updated"} object: {dbobj}')
                else:
                    err_records.append(r)

            elif r.status in (Status.MISMATCH, Status.XL) and force_update:
                # We need to update DB with XL values.
                #  Insert missing FKEYs in referenced table.
                pass
Example #12
    def compare(self, xl_record, db_record):
        """ Compares xl record vs db record.
         Logic:
         1. know xl record field, using parser resolve FKEY inplace of field but also stores the xl special field _originals
         2. compare FKEY values (read from _originals dictionary) with respective model values (Possible FKEY records are DIFFERENT records in xl and DB, then compare against both)
         3. compare all concrete fields AS IS

         Returns tuple (mismatches, references)
         """

        mismatches = BoxList()
        refobjs = Box()
        if not xl_record:
            mismatches.append(Mismatch(field="", type=nm(None), status=Status.DB, message="", extra_info=None))
        else:  # compare each field value
            m2m_fields = [f.name for f in self.model._meta.many_to_many]
            for f, value in xl_record.items():
                references = self.config_data[f].references  # TODO: HG: This can throw key error
                if references:
                    ref_model = ""
                    if f in m2m_fields:
                        values = [re.sub(r'^\* ', '', i) for i in value.rsplit('\n')]
                    else:
                        values = [value]

                    for v in values:
                        refs = Box()
                        if len(references) != len(v.split(' - ')):
                            # We have an invalid configuration for this reference and the exported data doesn't honor this config.
                            # We cannot proceed with further checking of this record.
                            mismatches.append(
                                Mismatch(field=f, type=nm(type(value)), status=Status.CANNOT_COMPARE,
                                         message=f'Configuration mismatch for field [{f}] and xls data. Cannot get importer.',
                                         extra_info=None))
                            continue
                        for idx, ref in enumerate(references):
                            # We can have multiple references
                            # e.g. compdependency.component_version => [(componentversion,component.name),(componentversion,version)]
                            # We are safe to use ref[0] and keep overwriting it because we currently only
                            # support $model-style refs. TODO: rewrite the logic below to support model
                            # names in references in config.yaml
                            ref_model = ref[0]
                            ref_field = ref[1]
                            # ref values can be a combination of multiple fields separated by ' - '
                            refs[ref_field] = v.split(' - ')[idx]

                        # IMP: Ideally we should pass values and search for the record, but as of now
                        # the reference fields are the same as the reference model's index
                        # e.g. ref field ['name - version'] will be defined as the ref model's index as well,
                        #   hence it is easy to find the ref record from the ref model's importable_sheet's
                        try:
                            # TODO: HG: Refactor need - pull out ref object checker in separate function
                            ref_importer = Registry.importer.get_sheet(ref_model)
                            if not ref_importer:
                                if self.config_data[f].formatting.read_only:
                                    logging.info(f'Skipping comparison for field [{f}] with value [{value}], '
                                                 f'as its marked as read_only in config.'
                                                 f' Cannot get importer for its model [{ref_model}]')
                                else:
                                    mismatches.append(
                                        Mismatch(field=f, type=nm(type(value)), status=Status.CANNOT_COMPARE,
                                                 message=f'Cannot get importer for model [{ref_model}]',
                                                 extra_info=None))
                                    # To solve this issue - add ref_model in config.yml and mark it read_only
                            else:
                                ref_obj = ref_importer.get_record_from_dict(refs)
                                refobjs.setdefault(f, []).append(ref_obj)
                                if not ref_obj or ref_obj.status == Status.MISMATCH:
                                    mismatches.append(Mismatch(field=f, type=nm(type(value)), status=Status.MISMATCH,
                                                               message='no referenced record found; record won\'t be '
                                                                       'created/updated in DB' if not ref_obj else
                                                               'referenced record has MISMATCH',
                                                               extra_info=Box(reference_record=ref_obj)))
                        except KeyError as e:
                            msg = f"{nm(self.model)}.{f}'s value {value} - record not available in reference model" \
                                  f" {ref_model}. Exception: {e}"
                            logging.error(msg)
                            mismatches.append(Mismatch(field=f, type=nm(type(value)), status=Status.MISMATCH,
                                                       message=msg, extra_info=None))
                elif db_record and not hasattr(db_record, f):
                    mismatches.append(Mismatch(field=f, type=nm(type(value)), status=Status.XL,
                                               message=f'field: "{f}" doesn\'t exist in DB',
                                               extra_info=None))
                elif db_record and getattr(db_record, f) != type(getattr(db_record, f))(value):
                    mismatches.append(Mismatch(field=f, type=nm(type(value)), status=Status.MISMATCH,
                                               message=f'values differ, dbvalue: "{getattr(db_record, f)}" and xlsvalue: "{value}"',
                                               extra_info=None))

            # We will fill in the refobjs which aren't part of xls but in DB
        return (refobjs, mismatches)
Example #13
class JobManager:
    """
    The evaluation endpoint implementation for the Deepdrive Problem Endpoint.

    - `problem` is the string identifier for the problem.
    - `eval_id` is the unique identifier for this evaluation run.
    - `eval_key` is the evaluation key to pass back to the Botleague liaison.
    - `seed` is the seed to use for random number generation.
    - `docker_tag` is the tag for the bot container image.
    - `pull_request` is the relevant pull request details, or None.
    """
    def __init__(self, jobs_db=None, instances_db=None):
        self.gce_ops_in_progress = BoxList()
        self.instances_db: DB = instances_db or get_worker_instances_db()
        self.jobs_db: DB = jobs_db or get_jobs_db()
        self.gce = googleapiclient.discovery.build('compute', 'v1')
        self.project: str = GCP_PROJECT
        self.zone: str = GCP_ZONE

    def run(self):
        self.assign_jobs()
        self.check_gce_ops_in_progress()
        self.check_jobs_in_progress()
        self.check_for_idle_instances()
        # TODO: restart instances that have been evaluating for more than
        #  problem timeout
        # TODO: self.delete_idle_instances_over_threshold()

    def assign_jobs(self) -> Tuple[BoxList, List]:
        new_jobs = BoxList()
        exceptions = []
        for job in self.jobs_db.where('status', '==', JOB_STATUS_CREATED):
            try:
                if self.should_start_job(job):
                    log.info(f'Assigning job {box2json(job)}...')
                    self.assign_job(job)
                    new_jobs.append(job)
            except Exception as e:
                # Could have been a network failure, so just try again.
                # More granular exceptions should be handled before this
                # which can set the job to not run again
                # if that's what's called for.

                log.exception(f'Exception triggering eval for job {job}, '
                              f'will try again shortly.')
                exceptions.append(e)

        # TODO: Check for failed / crashed instance once per minute
        # TODO: Stop instances if they have been idle for longer than timeout
        # TODO: Cap total instances
        # TODO: Cap instances per bot owner, using first part of docker tag
        # TODO: Delete instances over threshold of stopped+started

        return new_jobs, exceptions

    def check_jobs_in_progress(self):
        for job in self.jobs_db.where('status', '==', JOB_STATUS_RUNNING):
            if SHOULD_TIMEOUT_JOBS:
                # TODO: We need to stop the job if it's still running before
                #  returning the worker back to the instance pool
                self.handle_timed_out_jobs(job)

    def handle_timed_out_jobs(self, job):
        max_seconds = Box(job, default_box=True).eval_spec.max_seconds
        used_default_max_seconds = False
        if not max_seconds:
            used_default_max_seconds = True
            if job.job_type == JOB_TYPE_EVAL:
                max_seconds = 60 * 5
            elif job.job_type in [
                    JOB_TYPE_SIM_BUILD, JOB_TYPE_DEEPDRIVE_BUILD
            ]:
                max_seconds = 60 * 60
            else:
                log.error(f'Unexpected job type {job.job_type} for job: '
                          f'{box2json(job)} setting timeout to 5 minutes')
                max_seconds = 60 * 5
        if time.time() - job.started_at.timestamp() > max_seconds:
            if used_default_max_seconds:
                log.debug('No max_seconds in problem definition, used default')
            log.error(f'Job took longer than {max_seconds} seconds, '
                      f'consider stopping instance: {job.instance_id} '
                      f'in case the instance is bad. Job:\n{box2json(job)}')
            job.status = JOB_STATUS_TIMED_OUT
            self.jobs_db.set(job.id, job)
            self.make_instance_available(job.instance_id)
            # TODO: Stop the instance in case there's an issue with the
            #  instance itself
            # TODO: Set job error timeout
            pass

    def make_instance_available(self, instance_id):
        # TODO: Move this into problem-constants and rename
        #  problem-helpers as it's shared with problem-worker
        instance = self.instances_db.get(instance_id)
        if not instance:
            log.warning('Instance does not exist, perhaps it was terminated.')
        elif instance.status != constants.INSTANCE_STATUS_AVAILABLE:
            instance.status = constants.INSTANCE_STATUS_AVAILABLE
            instance.time_last_available = SERVER_TIMESTAMP
            self.instances_db.set(instance_id, instance)
            log.info(f'Made instance {instance_id} available')
        else:
            log.warning(f'Instance {instance_id} already available')

    def check_for_finished_jobs(self):
        # TODO: Make this more efficient by querying instances or just
        #   disable or don't do this at all in the loop
        #   since callback will do it for us.

        try:
            for job in self.jobs_db.where('status', '==', JOB_STATUS_FINISHED):
                if 'instance_id' in job:
                    inst_id = job.instance_id
                    if inst_id != LOCAL_INSTANCE_ID:
                        instance = self.instances_db.get(inst_id)
                        if not instance:
                            log.debug(
                                f'Instance "{inst_id}" not found for job:\n'
                                f'{box2json(job)}')
                        elif instance.status == INSTANCE_STATUS_USED:
                            self.make_instance_available(inst_id)
        except Exception:
            log.exception('Unable to check for finished jobs')

    def check_gce_ops_in_progress(self):
        ops_still_in_progress = BoxList()
        for op in self.gce_ops_in_progress:

            try:
                op_result = Box(
                    self.gce.zoneOperations().get(project=self.project,
                                                  zone=self.zone,
                                                  operation=op.name).execute())
            except Exception:
                log.exception('Could not get op_result')
                break
            if op_result.status == 'DONE':
                if 'error' in op_result:
                    log.error(f'GCE operation resulted in an error: '
                              f'{op_result.error}\nOperation was:'
                              f'\n{box2json(op)}')
                    if op.operationType == 'insert':
                        # Retry the creation?
                        pass
                    # elif op.operationType == 'start':
                    #
            else:
                ops_still_in_progress.append(op)
        self.gce_ops_in_progress = ops_still_in_progress

    def assign_job(self, job) -> Optional[Box]:
        if dbox(job).run_local_debug:
            log.warning(f'Run local debug is true, setting instance id to '
                        f'{constants.LOCAL_INSTANCE_ID}')
            self.assign_job_to_instance(constants.LOCAL_INSTANCE_ID, job)
            return job
        worker_instances = self.get_worker_instances()
        self.prune_terminated_instances(worker_instances)
        available_running_instances, available_stopped_instances = \
            self.get_available_instances(worker_instances)
        if available_running_instances:
            self.assign_to_running_instance(available_running_instances, job)
        elif available_stopped_instances:
            self.start_instance_and_assign(available_stopped_instances, job)
        else:
            if len(worker_instances) < MAX_WORKER_INSTANCES:
                self.create_instance_and_assign(job, worker_instances)
            else:
                log.warning(
                    f'Over instance limit, waiting for instances to become '
                    f'available to run job {job.id}')
                return job

        # TODO(Challenge): For network separation: Set DEEPDRIVE_SIM_HOST
        # TODO(Challenge): For network separation: Set network tags between
        #  bot and problem container for port 5557
        return job

    def create_instance_and_assign(self, job, worker_instances):
        create_op = self.create_instance(current_instances=worker_instances)
        instance_id = create_op.targetId
        instance_name = create_op.targetLink.split('/')[-1]
        self.save_worker_instance(
            Box(id=instance_id,
                name=instance_name,
                status=INSTANCE_STATUS_USED,
                assigned_at=SERVER_TIMESTAMP,
                started_at=SERVER_TIMESTAMP,
                created_at=SERVER_TIMESTAMP))
        self.assign_job_to_instance(instance_id, job)
        self.gce_ops_in_progress.append(create_op)
        log.success(f'Created instance {instance_id} for job {job.id}')

    def start_instance_and_assign(self, available_stopped_instances, job):
        self.assign_job_to_stopped_instance(available_stopped_instances, job)

    def assign_job_to_stopped_instance(self, available_stopped_instances, job):
        # We can't assume stopped instances are available
        # as starting an instance is not immediately reflected in GCE's API
        # :(
        inst = available_stopped_instances[0]
        self.save_worker_instance(
            Box(
                id=inst.id,
                name=inst.name,
                inst=inst,
                status=INSTANCE_STATUS_USED,
                assigned_at=SERVER_TIMESTAMP,
                started_at=SERVER_TIMESTAMP,
            ))
        self.assign_job_to_instance(inst.id, job)
        self.start_instance(inst)
        log.success(f'Started instance {inst.id} for job {job.id}')

    def assign_to_running_instance(self, available_running_instances, job):
        inst = available_running_instances[0]
        # Set the instance to used before starting the job in case
        # it calls back to /results very quickly before setting status.
        self.save_worker_instance(
            Box(id=inst.id,
                name=inst.name,
                inst=inst,
                status=INSTANCE_STATUS_USED,
                assigned_at=SERVER_TIMESTAMP))
        self.assign_job_to_instance(inst.id, job)
        log.success(f'Marked job {job.id} to start on '
                    f'running instance {inst.id}')

    def get_available_instances(self, worker_instances):
        def get_available(instances):
            ret = []
            for inst in instances:
                inst_meta = dbox(self.instances_db.get(inst.id))
                if not inst_meta:
                    log.error(f'Could not find instance {inst.id} in DB')
                elif inst_meta.status == INSTANCE_STATUS_AVAILABLE:
                    ret.append(inst)
            return ret

        # https://cloud.google.com/compute/docs/instances/instance-life-cycle
        instances_by_status = group_instances_by_status(worker_instances)
        provisioning_instances = instances_by_status.get('provisioning')
        staging_instances = instances_by_status['staging']
        running_instances = instances_by_status['running']
        # TODO: Handle these
        stopping_instances = instances_by_status['stopping']
        # TODO: Handle these
        repairing_instances = instances_by_status['repairing']
        stopped_instances = instances_by_status['terminated']
        available_running_instances = get_available(running_instances)
        available_stopped_instances = get_available(stopped_instances)
        return available_running_instances, available_stopped_instances

    def get_worker_instances(self):
        return self.list_instances(WORKER_INSTANCE_LABEL)

    def confirm_evaluation(self, job) -> bool:
        if in_test():
            status = JOB_STATUS_CREATED
            ret = True
        elif dbox(job).confirmed:
            log.info(f'Job already confirmed {box2json(job)}')
            status = JOB_STATUS_CREATED
            ret = True
        else:
            url = f'{job.botleague_liaison_host}/confirm'
            json = {'eval_key': job.eval_spec.eval_key}
            log.info(f'Confirming eval {json} at {url}...')
            confirmation = requests.post(url, json=json)
            if 400 <= confirmation.status_code < 500:
                status = JOB_STATUS_DENIED_CONFIRMATION
                log.error('Botleague denied confirmation of job, skipping')
                ret = False
            elif not confirmation.ok:
                status = JOB_STATUS_CREATED
                log.error('Unable to confirm job with botleague liaison, '
                          'will try again shortly')
                ret = False
            else:
                status = JOB_STATUS_CREATED
                log.success(f'Confirmed eval job {box2json(job)} at {url}')
                ret = True
        job.status = status
        job.confirmed = ret
        self.save_job(job)
        return ret

    def start_instance(self, inst):
        if in_test():
            log.warning('Not starting instance in test')
        else:
            op = self.gce.instances().start(project=self.project,
                                            zone=self.zone,
                                            instance=inst.name).execute()
            self.gce_ops_in_progress.append(op)

    def assign_job_to_instance(self, instance_id, job):
        # TODO: Compare and swap
        job.status = JOB_STATUS_ASSIGNED
        job.instance_id = instance_id
        job.started_at = SERVER_TIMESTAMP
        self.save_job(job)

    def save_worker_instance(self, worker_instance):
        self.instances_db.set(worker_instance.id, worker_instance)

    def save_job(self, job):
        self.jobs_db.set(job.id, job)
        return job.id

    def set_eval_data(self, inst, eval_spec):
        inst.eval_spec = eval_spec
        self.instances_db.set(inst.id, inst)

    def list_instances(self, label) -> BoxList:
        if label:
            query_filter = f'labels.{label}:*'
        else:
            query_filter = None
        ret = self.gce.instances().list(project=self.project,
                                        zone=self.zone,
                                        filter=query_filter).execute()
        ret = BoxList(ret.get('items', []))
        return ret

    def create_instance(self, current_instances):
        if in_test():
            log.warning('Not creating instance in test')
            return None
        instance_name = self.get_next_instance_name(current_instances)
        config_path = os.path.join(ROOT, INSTANCE_CONFIG_PATH)
        config = Box.from_json(filename=config_path)
        # TODO: If job is CI, no GPU needed, but maybe more CPU
        config.name = instance_name
        config.disks[0].deviceName = instance_name
        create_op = Box(
            self.gce.instances().insert(project=self.project,
                                        zone=self.zone,
                                        body=config.to_dict()).execute())
        return create_op

    @staticmethod
    def get_next_instance_name(current_instances):
        current_instance_names = [i.name for i in current_instances]
        current_instance_indexes = []
        for name in current_instance_names:
            index = name[name.rindex('-') + 1:]
            if index.isdigit():
                current_instance_indexes.append(int(index))
            else:
                log.warning('Instance with non-numeric index in name found ' +
                            name)
        if not current_instance_indexes:
            next_index = 0
        else:
            next_index = max(current_instance_indexes) + 1
        instance_name = INSTANCE_NAME_PREFIX + str(next_index)
        return instance_name

    def should_start_job(self, job) -> bool:
        if job.job_type == JOB_TYPE_EVAL:
            if not self.confirm_evaluation(job):
                ret = False
            else:
                problem = job.eval_spec.problem

                # Verify that the specified problem is supported
                if problem not in SUPPORTED_PROBLEMS:
                    log.error(f'Unsupported problem "{problem}"')
                    job.status = JOB_STATUS_DENIED_CONFIRMATION
                    self.save_job(job)
                    ret = False
                else:
                    ret = True
        elif job.job_type in JOB_TYPES:
            ret = True
        else:
            log.error(f'Unsupported job type {job.job_type}, skipping job '
                      f'{box2json(job)}')
            ret = False
        return ret

    def check_for_idle_instances(self):
        available_instances = self.instances_db.where(
            'status', '==', INSTANCE_STATUS_AVAILABLE)
        gce_workers = self.get_worker_instances()
        for instance in available_instances:
            last_dt = get_datetime_from_datetime_nanos(instance)
            idle_time = datetime.utcnow() - last_dt
            gce_worker = [w for w in gce_workers if w.id == dbox(instance).id]
            if gce_worker:
                gce_status = gce_worker[0].status
            else:
                gce_status = 'NOT FOUND'

            if idle_time > timedelta(minutes=5) and gce_status == 'RUNNING':
                log.info(f'Stopping idle instance {box2json(instance)}')
                stop_op = self.gce.instances().stop(
                    project=self.project,
                    zone=self.zone,
                    instance=instance.name).execute()
                return stop_op

    def prune_terminated_instances(self, worker_instances):
        worker_ids = [w.id for w in worker_instances]
        db_instances = self.instances_db.where('id', '>', '')
        term_db = get_db('deepdrive_worker_instances_terminated')
        for dbinst in db_instances:
            if dbinst.id not in worker_ids:
                term_db.set(dbinst.id, dbinst)
                self.instances_db.delete(dbinst.id)
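
A pattern used throughout JobManager: re-wrapping a record as Box(job, default_box=True) (the dbox helper appears to do the same) makes deep lookups such as .eval_spec.max_seconds return an empty, falsy Box instead of raising when an intermediate key is missing. A minimal sketch:

from box import Box

job = Box({'id': 'job-1'})                        # no eval_spec key at all
max_seconds = Box(job, default_box=True).eval_spec.max_seconds
if not max_seconds:                               # an empty Box is falsy
    max_seconds = 60 * 5                          # fall back to a default timeout
assert max_seconds == 300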
Example #14
class parseWholePatient(object):
    def __init__(self, sourceDir):
        if not os.path.isdir(sourceDir):
            logging.error('target dir %s does not exist!', sourceDir)
            raise IOError(sourceDir)

        self.sourceDir = sourceDir
        self.patientBaseDict = None
        self.patientImageSetList = BoxList()
        self.patientPlanList = BoxList()

    # def readPatient(self,parseDir):
    #     """
    #     one patient may contain multi-plans, parse one by one
    #     :return: self.patient.planList
    #     """
    #     infoDict = None
    #     if os.path.isfile(os.path.join(parseDir, 'Patient')):
    #         infoDict = pinn2Json().read(os.path.join(parseDir, 'Patient'))
    #     else:
    #         logging.error('not a valid plan!')
    #     return infoDict

    def getPatientDict(self):
        parseDir = self.sourceDir

        if os.path.isfile(os.path.join(parseDir, 'Patient')):
            baseDict = pinn2Json().read(os.path.join(parseDir, 'Patient'))
        else:
            logging.error('may be an empty plan! skip')
            return None

        if 'ImageSetList' in baseDict:
            image_set_list = baseDict.get('ImageSetList')
            for imageSet in image_set_list:
                logging.info('ImageSet_%s', imageSet.ImageSetID)
                if 'phantom' in imageSet.SeriesDescription:
                    logging.warning('this is Phantom for QA, skip!')
                    continue
                # read CT image set of this plan
                (imageSet['CTHeader'],
                 imageSet['CTData']) = self.readCT(imageSet.ImageName)
                self.patientImageSetList.append(imageSet)

        if 'PlanList' in baseDict:
            plan_list = baseDict.get('PlanList')
            for plan in plan_list:
                logging.info('plan_%s, based on ImageSet_%s', plan.PlanID,
                             plan.PrimaryCTImageSetID)
                if 'QA' in plan.PlanName or 'copy' in plan.PlanName:
                    logging.warning('this is Copy or QA plan, skip!')
                else:
                    planDirName = 'Plan_' + str(plan.PlanID)
                    logging.info('Reading plan:%s ......', planDirName)
                    plan['planData'] = self.readPlan(planDirName)
                self.patientPlanList.append(plan)

    def readPlan(self, planDirRefPath):
        """
        read one plan:
            data List:
            plan.Points,
            plan.roi,
            plan.Trial,
        :param planDir: plan relative path ./Plan_N
        :return: dict plan
        """
        planDict = None
        planDirAbsPath = os.path.join(self.sourceDir, planDirRefPath)
        if not os.path.isdir(planDirAbsPath):
            logging.error("directory %s does not exist!", planDirAbsPath)
            raise IOError

        if os.path.isfile(os.path.join(planDirAbsPath, 'plan.PlanInfo')):
            planDict = pinn2Json().read(
                os.path.join(planDirAbsPath, 'plan.PlanInfo'))

        if os.path.isfile(os.path.join(planDirAbsPath, 'plan.Points')):
            planDict['Points'] = pinn2Json().read(
                os.path.join(planDirAbsPath, 'plan.Points'))

        if os.path.isfile(os.path.join(planDirAbsPath, 'plan.VolumeInfo')):
            planDict['VolumeInfo'] = pinn2Json().read(
                os.path.join(planDirAbsPath, 'plan.VolumeInfo'))

        logging.info('Reading ROIs, this may take a long time, please wait...')
        if os.path.isfile(os.path.join(planDirAbsPath, 'plan.roi')):
            planDict['rois'] = pinn2Json().read(
                os.path.join(planDirAbsPath, 'plan.roi'))
            self.getContours(planDict['rois'])

        # if os.path.isfile(os.path.join(planDirAbsPath, 'plan.Pinnacle.Machines')):
        #     planDict['machines'] = pinn2Json().read(
        #         os.path.join(planDirAbsPath, 'plan.Pinnacle.Machines'))

        if os.path.isfile(os.path.join(planDirAbsPath, 'plan.Trial')):
            planTrialData = pinn2Json().read(
                os.path.join(planDirAbsPath, 'plan.Trial'))
            if 'TrialList' in planTrialData:
                currentTrialList = planTrialData['TrialList']
                logging.info("Plan has %d Trials", len(currentTrialList))
                for currentTrial in currentTrialList:
                    logging.info('======================')
                    logging.info("Trial:%s", currentTrial.Name)
                    data = self.readTrialMatrixData(planDirAbsPath,
                                                    currentTrial, planDict)
            else:
                logging.info('======================')
                logging.info("Trial:%s", planTrialData.Trial.Name)
                data = self.readTrialMatrixData(planDirAbsPath,
                                                planTrialData['Trial'],
                                                planDict)

            planDict['Trial'] = data
        return planDict

    def printPatientBaseInfo(self, patientDataDict):
        """
        Parse the "Patient_XX/Patient" file, get the plan Frame.
        """
        if patientDataDict:
            logging.info("PatientName:%s%s", patientDataDict.LastName,
                         patientDataDict.Firstname)
            logging.info("MRN:%s", patientDataDict.MedicalRecordNumber)
            logging.info("\nimageList:")
            if 'ImageSetList' in patientDataDict:
                for imageSet in patientDataDict.ImageSetList:
                    logging.info('%s %s', imageSet.ImageName, imageSet.ImageSetID)

            logging.info("\nplanList:")
            if 'PlanList' in patientDataDict:
                for plan in patientDataDict.PlanList:
                    logging.info('%s %s %s', plan.PlanName, plan.PlanID,
                                 plan.PrimaryCTImageSetID)

    def readCT(self, CTName):
        """
        Read a CT cube for a plan
        """
        imHdr = pinn2Json().read(
            os.path.join(self.sourceDir, (CTName + '.header')))

        # Read the data from the file
        imData = np.fromfile(os.path.join(self.sourceDir, (CTName + '.img')),
                             dtype='int16')

        # Reshape to a 3D array
        imData = imData.reshape((imHdr.z_dim, imHdr.y_dim, imHdr.x_dim))

        # Solaris uses big endian schema.
        if sys.byteorder == 'little':
            if imHdr.byte_order == 1:
                imData = imData.byteswap(True)
        else:
            if imHdr.byte_order == 0:
                imData = imData.byteswap(True)

        ctVoxSize = [imHdr.z_pixdim, imHdr.y_pixdim, imHdr.x_pixdim]

        # f1 = imView.slicesView(imData, voxSize=ctVoxSize)

        return imHdr, imData

    def getContours(self, planContourData):

        if 'roiList' in planContourData:
            roiList = planContourData['roiList']
            for curROI in roiList:
                logging.info(curROI.name)
                logging.info(curROI.num_curve)

    ####################################################################################################################################################
    # Function: getstructshift()
    # Purpose: reads in values from ImageSet_0.header to get x and y shift
    ####################################################################################################################################################
    @staticmethod
    def getstructshift(imageHeadFile):
        xshift = 0
        yshift = 0
        zshift = 0

        imgHdr = pinn2Json().read(imageHeadFile)
        x_dim = float(imgHdr.x_dim)
        y_dim = float(imgHdr.y_dim)
        z_dim = float(imgHdr.z_dim)
        xpixdim = float(imgHdr.x_pixdim)
        ypixdim = float(imgHdr.y_pixdim)
        zpixdim = float(imgHdr.z_pixdim)

        # pinnacle version differences
        # xstart = float(imgHdr.x_start_dicom)
        # ystart = float(imgHdr.y_start_dicom)

        xstart = float(imgHdr.x_start)
        ystart = float(imgHdr.y_start)
        zstart = float(imgHdr.z_start)
        patient_position = imgHdr.patient_position
        # with open("%s%s/ImageSet_%s.header" % (Inputf, patientfolder, imagesetnumber), "rt", encoding=u'utf-8',
        #           errors='ignore') as f2:
        #     for line in f2:
        #         if "x_dim =" in line:
        #             x_dim = float((line.split(" ")[-1]).replace(';', ''))
        #         if "y_dim =" in line:
        #             y_dim = float((line.split(" ")[-1]).replace(';', ''))
        #         if "x_pixdim =" in line:
        #             xpixdim = float((line.split(" ")[-1]).replace(';', ''))
        #         if "y_pixdim =" in line:
        #             ypixdim = float((line.split(" ")[-1]).replace(';', ''))
        #         if "x_start =" in line and "index" not in line:
        #             xstart = float((line.split(" ")[-1]).replace(';', ''))
        #             print("xstart = ", xstart)
        #         if "y_start =" in line:
        #             ystart = float((line.split(" ")[-1]).replace(';', ''))
        #         if "z_dim =" in line:
        #             z_dim = float((line.split(" ")[-1]).replace(';', ''))
        #         if "z_pixdim =" in line:
        #             zpixdim = float((line.split(" ")[-1]).replace(';', ''))
        #         if "z_start =" in line and "index" not in line:
        #             zstart = float((line.split(" ")[-1]).replace(';', ''))
        if patient_position == 'HFS':
            xshift = ((x_dim * xpixdim / 2) + xstart) * 10
            yshift = -((y_dim * ypixdim / 2) + ystart) * 10
            zshift = -((z_dim * zpixdim / 2) + zstart) * 10
        elif patient_position == 'HFP':
            xshift = -((x_dim * xpixdim / 2) + xstart) * 10
            yshift = ((y_dim * ypixdim / 2) + ystart) * 10
            zshift = -((z_dim * zpixdim / 2) + zstart) * 10
        elif patient_position == 'FFP':
            xshift = ((x_dim * xpixdim / 2) + xstart) * 10
            yshift = ((y_dim * ypixdim / 2) + ystart) * 10
            zshift = ((z_dim * zpixdim / 2) + zstart) * 10
        elif patient_position == 'FFS':
            xshift = -((x_dim * xpixdim / 2) + xstart) * 10
            yshift = -((y_dim * ypixdim / 2) + ystart) * 10
            zshift = ((z_dim * zpixdim / 2) + zstart) * 10

        logging.info("X shift = %s", xshift)
        logging.info("Y shift = %s", yshift)
        logging.info("Z shift = %s", zshift)
        return (xshift, yshift, zshift)

    def readTrialMatrixData(self, trialBasePath, curTrial, planDict):

        planPoints = planDict['Points']
        doseHdr = curTrial.DoseGrid
        dose = np.zeros(
            (doseHdr.Dimension.Z, doseHdr.Dimension.Y, doseHdr.Dimension.X))
        for pInd, ps in enumerate(curTrial.PrescriptionList):
            logging.info('%s:%d:%d', ps.Name, ps.PrescriptionDose,
                         ps.NumberOfFractions)

            for bInd, bm in enumerate(curTrial.BeamList):
                try:
                    # Get the name of the file where the beam dose is saved -
                    # PREVIOUSLY USED DoseVarVolume ?
                    doseFile = os.path.join(
                        trialBasePath, "plan.Trial.binary.%03d" %
                        int(bm.DoseVolume.split('-')[1]))

                    # Read the dose from the file
                    bmDose = np.fromfile(doseFile, dtype='float32')

                    if bmDose.nbytes == 0:
                        raise DoseInvalidException('')

                except (IOError, SystemError):
                    raise DoseInvalidException('')
                # Reshape to a 3D array
                bmDose = bmDose.reshape(
                    (doseHdr.Dimension.Z, doseHdr.Dimension.Y,
                     doseHdr.Dimension.X))

                # Solaris uses big endian schema.
                # Almost everything else is little endian
                if sys.byteorder == 'little':
                    bmDose = bmDose.byteswap(True)

                bmFactor = bm.MonitorUnitInfo.NormalizedDose * \
                    bm.MonitorUnitInfo.CollimatorOutputFactor * \
                    bm.MonitorUnitInfo.TotalTransmissionFraction
                dosePerMU = 0.665
                # getting dose/Mu from the plan.Pinnacle.Machines file
                # dosePerMU = self.getDosePerMU()
                MUs = bm.MonitorUnitInfo.PrescriptionDose / \
                    (bmFactor * dosePerMU)
                logging.info('%s:%d', bm.Name, MUs)

                # Weight the dose cube by the beam weight
                dose += (bmDose * bm.Weight)

                # rescale dose to prescriptionDose
                totalPrescriptionDose = ps.PrescriptionDose * ps.NumberOfFractions
                doseAtPoint = totalPrescriptionDose * 1
                if ps.Name == bm.PrescriptionName:
                    if ps.WeightsProportionalTo == 'Point Dose':
                        for pt in planPoints['PoiList']:
                            if pt.Name == ps.PrescriptionPoint:
                                doseAtPoint = self.doseAtCoord(
                                    dose, doseHdr, pt.XCoord, pt.YCoord,
                                    pt.ZCoord)

                                logging.info(doseAtPoint)

                dose = dose * (doseAtPoint / totalPrescriptionDose)
        return dose, doseHdr

    def readDoses(self, planTrialData, planBasePath):
        """
        input:  Read a dose cube for a trial in a given plan and
        return: a numpy array

        Currently tested for:
                (1) Dose is prescribed to a norm point;
                        beam weights are proportional to point dose
                        and control point dose is not stored.
                (2) Dose is prescribed to mean dose of target;
        """
        # trialFile = os.path.join(self.sourceDir, 'plan.Trial')
        # if not os.path.isfile(trialFile):
        #     self.logging.info("not such file %s", trialFile)
        #     return None
        # trialData = pinn2Json().read(trialFile)
        # pts = pinn2Json().read(os.path.join(self.sourceDir, 'plan.Points'))
        if not planTrialData:
            raise IOError

        pts = pinn2Json().read(os.path.join(planBasePath, 'plan.Points'))
        # pts = pointsList['']

        trialList = []
        doseDataDict = {}
        dose = None
        if 'TrialList' in planTrialData:
            logging.info('plan has %d Trials', len(planTrialData.TrialList))
            for curTrial in planTrialData.TrialList:
                trialList.append(curTrial)
        else:
            logging.info('plan has 1 Trial')
            trialList.append(planTrialData.Trial)

        for curTrial in trialList:
            doseHdr = curTrial.DoseGrid
            # Accumulate the weighted beam doses for this trial
            dose = np.zeros((doseHdr.Dimension.Z, doseHdr.Dimension.Y,
                             doseHdr.Dimension.X))

            for bInd, bm in enumerate(curTrial.BeamList):
                try:
                    # Get the name of the file where the beam dose is saved -
                    # PREVIOUSLY USED DoseVarVolume ?
                    doseFile = os.path.join(
                        planBasePath, "plan.Trial.binary.%03d" %
                        int(bm.DoseVolume.split('-')[1]))

                    # Read the dose from the file
                    bmDose = np.fromfile(doseFile, dtype='float32')

                    if bmDose.nbytes == 0:
                        raise DoseInvalidException('')

                except (IOError, SystemError):
                    raise DoseInvalidException('')
                # Reshape to a 3D array
                bmDose = bmDose.reshape(
                    (doseHdr.Dimension.Z, doseHdr.Dimension.Y,
                     doseHdr.Dimension.X))

                # Solaris uses big endian schema.
                # Almost everything else is little endian
                if sys.byteorder == 'little':
                    bmDose = bmDose.byteswap(True)

                doseFactor = 1.0

                # Weight the dose cube by the beam weight
                # Assume dose is prescribed to a norm point and beam weights are proportional to point dose
                doseAtPoint = 0.0

                prescriptionPoint = []
                prescriptionDose = []
                prescriptionPointDose = []
                prescriptionPointDoseFactor = []

                for pp in curTrial.PrescriptionList:
                    if pp.Name == bm.PrescriptionName:
                        prescriptionDose.append(pp.PrescriptionDose *
                                                pp.NumberOfFractions)
                        if pp.WeightsProportionalTo == 'Point Dose':
                            for pt in pts.PoiList:
                                if pt.Name == pp.PrescriptionPoint:
                                    doseAtPoint = self.doseAtCoord(
                                        bmDose, doseHdr, pt.XCoord, pt.YCoord,
                                        pt.ZCoord)
                                    doseFactor = pp.PrescriptionDose * \
                                        pp.NumberOfFractions * \
                                        (bm.Weight * 0.01 / doseAtPoint)

                                    prescriptionPoint.append(
                                        [pt.XCoord, pt.YCoord, pt.ZCoord])
                                    prescriptionPointDose.append(doseAtPoint)
                                    prescriptionPointDoseFactor.append(
                                        doseFactor)
                        elif pp.WeightsProportionalTo == 'ROI Mean':
                            logging.info('get ROI mean dose')

                dose += (bmDose * doseFactor)
        for bm, pD, pp in zip(range(len(prescriptionPointDose)),
                              prescriptionPointDose, prescriptionPoint):
            indPP = self.coordToIndex(doseHdr, pp[0], pp[1], pp[2])

        return dose, doseHdr
        #         doseData += bmDose
        #     doseDataDict[(curTrial.Name + 'DoseArray')] = doseData
        # return doseDataDict

    def coordToIndex(self, imHdr, xCoord, yCoord, zCoord):
        """
        Convert coordinate positions to coordinate indices
        """

        # coord in cm from primary image centre
        xCoord -= imHdr.Origin.X
        yCoord = imHdr.Origin.Y + imHdr.Dimension.Y * imHdr.VoxelSize.Y - yCoord
        zCoord -= imHdr.Origin.Z

        # coord now in cm from start of dose cube
        xCoord /= imHdr.VoxelSize.X
        yCoord /= imHdr.VoxelSize.Y
        zCoord /= imHdr.VoxelSize.Z

        # coord now in pixels from start of dose cube
        return xCoord, yCoord, zCoord

    # ----------------------------------------- #

    def doseAtCoord(self, doseData, doseHdr, xCoord, yCoord, zCoord):
        """
        Linearly interpolate the dose at a set of coordinates
        """
        xCoord, yCoord, zCoord = self.coordToIndex(doseHdr, xCoord, yCoord,
                                                   zCoord)

        # Floor to integer voxel indices (float indices are rejected by numpy)
        xP = int(np.floor(xCoord))
        yP = int(np.floor(yCoord))
        zP = int(np.floor(zCoord))

        xF = xCoord - xP
        yF = yCoord - yP
        zF = zCoord - zP

        dose = self.doseAtIndex(doseData, zP, yP, xP) * (1.0 - zF) * (1.0 - yF) * (1.0 - xF) + \
            self.doseAtIndex(doseData, zP, yP, xP + 1) * (1.0 - zF) * (1.0 - yF) * xF + \
            self.doseAtIndex(doseData, zP, yP + 1, xP) * (1.0 - zF) * yF * (1.0 - xF) + \
            self.doseAtIndex(doseData, zP, yP + 1, xP + 1) * (1.0 - zF) * yF * xF + \
            self.doseAtIndex(doseData, zP + 1, yP, xP) * zF * (1.0 - yF) * (1.0 - xF) + \
            self.doseAtIndex(doseData, zP + 1, yP, xP + 1) * zF * (1.0 - yF) * xF + \
            self.doseAtIndex(doseData, zP + 1, yP + 1, xP) * zF * yF * (1.0 - xF) + \
            self.doseAtIndex(doseData, zP + 1, yP + 1, xP + 1) * zF * yF * xF

        return dose

    # ----------------------------------------- #

    def doseAtIndex(self, dose, indZ, indY, indX):
        """
        Return dose at indices.
        Beyond end of dose array return zero
        """
        try:
            if indZ < 0 or indY < 0 or indX < 0:
                # Negative indices would wrap around in numpy, so treat them as outside the cube
                return 0.0
            return dose[indZ, indY, indX]
        except IndexError:
            return 0.0
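
doseAtCoord implements trilinear interpolation by hand over the eight voxels surrounding a fractional (z, y, x) index. Where scipy is available, the same sampling can be cross-checked with scipy.ndimage.map_coordinates using order=1, which performs the identical linear interpolation (a sketch on a toy cube; the coordToIndex conversion is assumed to have been applied already):

import numpy as np
from scipy.ndimage import map_coordinates

dose = np.random.rand(4, 5, 6).astype('float32')   # toy (z, y, x) dose cube
z, y, x = 1.3, 2.7, 3.1                            # fractional voxel indices
# order=1 -> trilinear interpolation; the default mode='constant', cval=0
# matches doseAtIndex returning 0.0 outside the cube
sampled = map_coordinates(dose, [[z], [y], [x]], order=1)[0]
print(sampled)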