Ejemplo n.º 1
0
def get_neox_args(context):
    """Build the GPT-NeoX arguments for a trial from its context.

    The hyperparameters are wrapped in an ``AttrMap``; their
    ``overwrite_values`` entry is popped off and extended with values from
    the experiment config so that NeoX and the experiment config agree on
    them: ``train_iters`` (searcher ``max_length``), ``save_interval`` and
    ``eval_interval`` (``min_validation_period``), ``global_num_gpus``
    (distributed size) and ``seed`` (trial seed).

    Args:
        context: trial context providing ``get_hparams()``,
            ``get_experiment_config()``, ``distributed.get_size()`` and
            ``env.trial_seed``.

    Returns:
        The result of ``NeoXArgs.process_parsed_deepy_args``.

    Raises:
        ValueError: if ``searcher.max_length`` or ``min_validation_period``
            is not expressed in batches.
        KeyError: if the hyperparameters have no ``overwrite_values`` entry.
    """
    args = AttrMap(context.get_hparams())
    exp_config = context.get_experiment_config()

    # Gather overrides.
    overwrite_values = args.pop("overwrite_values")

    # We are going to overwrite certain neox_args with determined config values
    # from the experiment config to ensure consistency. Validate explicitly
    # rather than with ``assert``, which is stripped when Python runs with -O.
    max_length = exp_config["searcher"]["max_length"]
    if "batches" not in max_length:
        raise ValueError("Please specify max_length in batches.")
    min_validation_period = exp_config["min_validation_period"]
    if "batches" not in min_validation_period:
        raise ValueError("Please specify min_validation_period in batches.")

    validation_period = min_validation_period["batches"]
    overwrite_values.update({
        "train_iters": max_length["batches"],
        "save_interval": validation_period,
        "eval_interval": validation_period,
        "global_num_gpus": context.distributed.get_size(),
        "seed": context.env.trial_seed,
    })
    for key, value in overwrite_values.items():
        logging.info("Setting neox_args.%s to %s", key, value)

    # Build neox args.
    return NeoXArgs.process_parsed_deepy_args(
        args, overwrite_values=overwrite_values)
Ejemplo n.º 2
0
def configJobInternal(config):
    """Build the job configuration from ``config`` and run its job generator.

    The entries of ``config`` (minus ``config`` itself) become the
    ``Cluster`` section. A ``General`` section is added with paths that the
    generator templates (.ini files) may use: this module's directory and
    path, the current working directory, the directory of
    ``Cluster.jobGeneratorConfig`` and the script's command line.
    ``Cluster.jobGeneratorOutputDir`` is made absolute. Finally the class
    named by ``Cluster.jobGenerator`` is imported and its ``generate()`` run.

    Note: ``AttrMap`` wraps ``config`` without copying, so removing the
    ``config`` entry also affects the caller's mapping.

    Returns:
        0 once the generator has finished.
    """
    cluster = AttrMap(config)
    cluster.pop("config")
    jobConfig = AttrMap({"Cluster": cluster})

    # Values available to the generator templates (.ini files).
    generatorConfigPath = os.path.abspath(jobConfig.Cluster.jobGeneratorConfig)
    jobConfig["General"] = {
        "configuratorModuleDir": configuratorModuleDir,
        "configuratorModulePath": configuratorModulePath,
        "currentWorkDir": os.getcwd(),
        "jobDir": os.path.dirname(generatorConfigPath),
        "configureScriptArgs": " ".join(sys.argv),
    }

    outputDir = jobConfig.Cluster.jobGeneratorOutputDir
    jobConfig.Cluster.jobGeneratorOutputDir = os.path.abspath(outputDir)

    # Import the configured generator class dynamically and run it.
    _, GeneratorClass = iH.importClassFromModuleString(
        jobConfig.Cluster.jobGenerator, verbose=True)
    GeneratorClass(jobConfig).generate()

    return 0
Ejemplo n.º 3
0
 def __init__(self, name=None, **fields):
     """Initialize the resource from ``fields``, generating a name if needed.

     An explicit ``name`` takes precedence over ``fields["name"]``. When
     neither is given, ``resource-<n>`` is used with ``n`` taken from the
     module-level ``NEXT_RESOURCE_NUM`` counter, which is then incremented.
     ``fields["tags"]`` (default: empty list) is wrapped in ``Tags``.
     """
     global NEXT_RESOURCE_NUM
     if name is None and "name" not in fields:
         name = "resource-%d" % NEXT_RESOURCE_NUM
         NEXT_RESOURCE_NUM += 1
     if name is not None:
         fields["name"] = name
     raw_tags = fields.get('tags', [])
     fields['tags'] = Tags(raw_tags)
     AttrMap.__init__(self, fields)
Ejemplo n.º 4
0
def get_stats(root: Any) -> AttrMap:
    """
    Fetch the statistics for the logged-in user.

    :param root: application object whose ``LOGIN`` is sent with the request.
    :return: {stats: [dict(name, score) for user], me: position}
    """
    stats = request('stats', login=root.LOGIN)
    return AttrMap(stats)
Ejemplo n.º 5
0
def sign_in(root: Any, _locals: dict, login: str, password: str,
            password2: str):
    """
    Register a new user.

    The form is validated in order: login present, login at most 20
    characters, password present, confirmation present, passwords equal.
    On the first failing check the related entry widget (if any) is focused
    and an alert is shown. Otherwise a ``sing_in`` request is sent; if its
    ``response`` is truthy, the credentials are saved to a hidden ``.auth``
    file, ``root.LOGIN``/``USER_NAME``/``USER_ID``/``profile`` are set and
    the history view is opened.
    """
    # (check, entry widget to focus or None, alert text); checks are lambdas
    # so later ones (e.g. len(login)) only run if the earlier ones passed.
    checks = (
        (lambda: login == '', 'entry_login', 'Введите логин'),
        (lambda: len(login) > 20, 'entry_login',
         'Логин не должен быть длиннее 20 символов'),
        (lambda: password == '', 'entry_password', 'Введите пароль'),
        (lambda: password2 == '', 'entry_password2', 'Повторите пароль'),
        (lambda: password != password2, None, 'Пароли не совпадают'),
    )
    for failed, entry, message in checks:
        if failed():
            if entry is not None:
                _locals[entry].focus()
            root.Alert.show(message)
            return

    response = request('sing_in', login=login, password=password)
    if not response['response']:
        return

    # Presumably removed first because a hidden file cannot be reopened for
    # writing on Windows ("attrib +H" below hides it) — confirm.
    if os.path.exists('.auth'):
        os.remove('.auth')
    # NOTE(review): the password is stored in plain text.
    with open('.auth', 'w') as auth_file:
        auth_file.write(f'{response["login"]}::{password}')
    subprocess.check_call(['attrib', '+H', '.auth'])

    root.LOGIN = response['login']
    root.USER_NAME, root.USER_ID = response['login'].split('#')
    root.profile = AttrMap(request(f'profile/{root.USER_ID}'))
    root.history_view(_locals=_locals)
Ejemplo n.º 6
0
    def _get_obs(self):
        """Collect the current observation from all sensors.

        Returns an ``AttrMap`` with:
          * ``camera``: ``color``/``depth`` from ``self.camera.get_image()``;
          * ``digits``: one ``{"color", "depth"}`` dict per image pair
            rendered by ``self.digits``;
          * ``robot``: ``self.robot.get_states()``;
          * ``object``: ``position``/``orientation`` of ``self.obj``'s base
            pose, converted with ``np.array``.
        """
        color_img, depth_img = self.camera.get_image()

        # Update the positions of objects registered with the digits, then render.
        self.digits.update()
        digit_colors, digit_depths = self.digits.render()

        pose = self.obj.get_base_pose()

        observation = {
            "camera": {"color": color_img, "depth": depth_img},
            "digits": [
                {"color": digit_color, "depth": digit_depth}
                for digit_color, digit_depth in zip(digit_colors, digit_depths)
            ],
            "robot": self.robot.get_states(),
            "object": {
                "position": np.array(pose[0]),
                "orientation": np.array(pose[1]),
            },
        }
        return AttrMap(observation)
Ejemplo n.º 7
0
    def upward_wrapper(self):
        # Evaluate ``func.__name__`` on every child (calling it when it is
        # callable), merge those results with ``func(self)``, and wrap the
        # merged dict: AttrMap for "get_states", SpaceDict otherwise.
        # Raises AttributeError if a key comes from both a child and self.
        from_children = {}
        for child_name, child in self.children().items():
            member = getattr(child, func.__name__)
            from_children[child_name] = member() if callable(member) else member
        own = func(self)

        clashes = set(from_children.keys()).intersection(own.keys())
        if clashes:
            raise AttributeError(
                f"Found keys {clashes} in both childrens_attrs and self_attrs"
            )

        merged = _remove_empty_dict_leaf({**from_children, **own})

        # TODO(poweic): It's confusing to have both SpaceDict & AttrMap at the same
        # time. Also, return AttrMap in new() method of SpaceDict is weird.
        # Consider creating that's both SpaceDict and AttrMap at the same time.
        wrapper = AttrMap if func.__name__ == "get_states" else SpaceDict
        return wrapper(merged)
Ejemplo n.º 8
0
    def load(self, path):
        """Merge the YAML file at ``path`` into the current configuration.

        Only top-level sections that already exist in ``self._config`` are
        considered. Each of those present in the file is updated in place
        with the file's values; other sections in the file are ignored.
        """
        with open(path, "rb") as stream:
            loaded = yaml.safe_load(stream)

        overrides = AttrMap(loaded)
        present = (section for section in self._config if section in overrides)
        for section in present:
            self._config[section].update(overrides[section])
Ejemplo n.º 9
0
def complete_quest(root: Any, quest_data: "interface.QuestProcess"):
    """
    Save the test results.

    Sends the completed count, score and answers of ``quest_data`` to the
    quest endpoint of the current user's profile and stores the returned
    profile on ``root.profile``.
    """
    endpoint = f'profile/{root.USER_ID}/quest/{quest_data.quest.name}'
    updated_profile = request(endpoint,
                              completed_count=f'{quest_data.completed_count}',
                              score=f'{quest_data.score}',
                              answers=f'[{quest_data.answers}]')
    root.profile = AttrMap(updated_profile)
Ejemplo n.º 10
0
def _main(root: Any, _locals: dict = None):
    """
    Start the application: resume a saved session or show the login view.

    If all saved credentials (module-level LOGIN, PASSWORD, USER_NAME,
    USER_ID) are set and the server accepts them, the profile is loaded and
    the home view is shown. If the server rejects them, ``.auth`` is deleted
    and the login view is shown. Without saved credentials, a ``test``
    request is sent (presumably a connectivity check) before the login view.
    """
    has_saved_session = LOGIN and PASSWORD and USER_NAME and USER_ID
    if not has_saved_session:
        request('test')
        root.log_in_view(_locals=_locals)
        return

    if not request('auth', login=LOGIN, password=PASSWORD)['response']:
        os.remove('.auth')
        root.log_in_view(_locals=_locals)
        return

    root.profile = AttrMap(request(f'profile/{USER_ID}'))
    root.home_view(_locals=_locals)
Ejemplo n.º 11
0
    def __init__(self, name=None, **fields):
        """
        Build a resource from its name and other fields.

        Parameters
        ----------
        name : string
            Name of the resource; it must be unique within the collection.
            Overrides ``fields["name"]``. If neither is given, a name of the
            form ``resource-<n>`` is generated from the module-level
            ``NEXT_RESOURCE_NUM`` counter, which is then incremented.

        **fields : string -> strings, dicts, and lists
            Remaining fields of the resource. ``tags`` (default: an empty
            list) is wrapped in ``Tags``.
        """
        global NEXT_RESOURCE_NUM
        if name is not None:
            fields["name"] = name
        elif "name" not in fields:
            generated = "resource-%d" % NEXT_RESOURCE_NUM
            NEXT_RESOURCE_NUM += 1
            fields["name"] = generated
        fields['tags'] = Tags(fields.get('tags', []))
        AttrMap.__init__(self, fields)
def _load_untrusted_json(json_str):
    try:
        obj = json.loads(json_str, parse_int=int, parse_constant=bool)
        if (obj['klass'] == 'Command' and isinstance(obj['type'], str)
                and (obj['params'] is None or isinstance(obj['params'], dict))
                and isinstance(obj['nth_turn'], int)):
            return AttrMap(obj)
    except Exception as e:
        logger.debug('[S] Failed to decode json.')
        logger.debug(str(e))

    return None
Ejemplo n.º 13
0
    def __init__(self, context: DeepSpeedTrialContext) -> None:
        """Set up the GPT-NeoX/Megatron training state for this trial.

        Parses the NeoX arguments from ``context``, attaches a Tensorboard
        writer, builds the tokenizer, initializes Megatron, creates the
        model/optimizer/LR scheduler (wrapping the model engine with the
        context), and initializes the timers and tracking state.

        Raises:
            InvalidHP: if the NeoX arguments cannot be built from the
                hyperparameters and experiment config.
        """
        self.context = context
        self.exp_config = self.context.get_experiment_config()
        self.args = AttrMap(self.context.get_hparams())

        # Initialize and get arguments, timers, and Tensorboard writer.
        try:
            self.neox_args = get_neox_args(self.context)
        except Exception as err:
            # Catch Exception (not a bare except) so KeyboardInterrupt and
            # SystemExit are not reported as invalid hyperparameters, and
            # chain the original error as the cause.
            traceback.print_exc()
            raise InvalidHP("Could not parse neox_args.") from err
        self.wrapped_writer = TorchWriter()
        self.neox_args.tensorboard_writer = self.wrapped_writer.writer
        self.neox_args.configure_distributed_args()
        # The tokenizer needs to be built before model initialization in order to set the
        # required padded_vocab_size argument.
        self.neox_args.build_tokenizer()
        megatron_train.initialize_megatron(neox_args=self.neox_args)
        self.timers = megatron_utils.Timers(
            use_wandb=False,
            tensorboard_writer=self.neox_args.tensorboard_writer)

        # Model, optimizer, and learning rate.
        self.timers("model and optimizer").start()
        (
            model,
            self.optimizer,
            self.lr_scheduler,
        ) = megatron_train.setup_model_and_optimizer(neox_args=self.neox_args)
        self.model = self.context.wrap_model_engine(model)
        self.timers("model and optimizer").stop()

        # Print setup timing.
        megatron_utils.print_rank_0("done with setups ...")
        self.timers.log(["model and optimizer"])
        megatron_utils.print_rank_0("training ...")

        # For tracking.
        # NOTE(review): self.reducer is only set when search_world_size is
        # falsy; confirm nothing reads it in the other case.
        if not self.args.search_world_size:
            self.reducer = self.context.wrap_reducer(
                LMReducers(self.neox_args),
                for_training=False,
                for_validation=True)
        self.report_memory_flag = True
        self.total_train_loss_dict = {}
        self.total_val_loss_dict = {}
        self.tflops = 0
        self.reported_flops = False
        self.overflow_monitor = megatron_utils.OverflowMonitor(self.optimizer)
        self.noise_scale_logger = megatron_utils.get_noise_scale_logger(
            self.neox_args)
        self.timers("interval time").start()
Ejemplo n.º 14
0
def test_mutable():
    """
    Nested values of an AttrMap built with ``sequence_type=list`` can be modified.

    Note: in an AttrDict, sequence values are automatically converted to
    tuples, i.e. an immutable form, so objects inside a list cannot be
    modified. To allow modifying objects inside a list, use AttrMap and
    specify ``sequence_type = list``.
    """
    phone_numbers = [
        {"label": "home", "number": "111-222-3333"},
        {"label": "work", "number": "444-555-6666"},
        {"label": "mobile", "number": "777-888-9999"},
    ]
    drivers_license = {"state": "DC", "license_number": "DC-1234-5678"}
    record = {
        "id": "EN-0001",
        "phone_numbers": phone_numbers,
        "profile": {"SSN": "123-45-6789", "drivers_license": drivers_license},
    }
    user = AttrMap(record, sequence_type=list)

    def assert_id(expected):
        # Attribute and item access must agree on the top-level value.
        assert user.id == expected
        assert user["id"] == expected

    def assert_first_number(expected):
        # Dicts nested inside the list are attribute-accessible too.
        assert user.phone_numbers[0].number == expected
        assert user.phone_numbers[0]["number"] == expected

    assert_id("EN-0001")
    user.id = "EN-0002"
    assert_id("EN-0002")

    assert_first_number("111-222-3333")
    user.phone_numbers[0].number = "111-111-1111"
    assert_first_number("111-111-1111")
Ejemplo n.º 15
0
def test_invalid_name():
    """Keys that are not valid attribute names are reachable only by item access.

    AttrMap refuses attribute access for names starting with ``_``; keys
    containing spaces cannot be written as attributes at all.
    """
    user_data = {
        "_id": 1,
        "first name": "John",
        "last name": "David",
        "email": "*****@*****.**",
    }
    user = AttrMap(user_data, sequence_type=list)

    # Cannot be accessed in dict.attr style; expect the specific error
    # rather than any Exception, so unrelated failures are not masked.
    with pytest.raises(AttributeError):
        user._id
    # But can still be accessed in dict[attr] style.
    assert user["_id"] == 1
    assert user["first name"] == "John"
    assert user["last name"] == "David"
Ejemplo n.º 16
0
def test_coco():
    """Run one epoch of one step of training/evaluation on the COCO test data."""
    # args
    args = AttrMap({
        'epochs': 1,
        'steps': 1,
        'imagenet_weights': False,
        'snapshots': False,
        'dataset_type': 'coco',
        'coco_path': 'tests/test-data/coco',
    })
    # Ignore warnings only for this test: a bare simplefilter() would change
    # the process-wide filters for every test that runs afterwards.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        # run training / evaluation
        keras_retinanet.bin.train.main_(args)
Ejemplo n.º 17
0
def test_csv():
    """Run one epoch of one step of training/evaluation on the CSV test data."""
    args = AttrMap({
        'epochs': 1,
        'steps': 1,
        'imagenet_weights': False,
        'snapshots': False,
        'dataset_type': 'csv',
        'annotations': 'tests/test-data/csv/annotations.csv',
        'classes': 'tests/test-data/csv/classes.csv'
    })

    # Ignore warnings only for this test: a bare simplefilter() would change
    # the process-wide filters for every test that runs afterwards.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        # run training / evaluation
        keras_retinanet.bin.train.main_(args)
Ejemplo n.º 18
0
def test_vgg():
    """Run one epoch/step with a frozen VGG16 backbone on the COCO test data."""
    args = AttrMap({
        'backbone': 'vgg16',
        'epochs': 1,
        'steps': 1,
        'imagenet_weights': False,
        'snapshots': False,
        'freeze_backbone': True,
        'dataset_type': 'coco',
        'coco_path': 'tests/test-data/coco'
    })

    # Ignore warnings only for this test: a bare simplefilter() would change
    # the process-wide filters for every test that runs afterwards.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        # run training / evaluation
        keras_retinanet.bin.train.main_(args)
Ejemplo n.º 19
0
    def __init__(
        self,
        message: Optional[int] = None,
        *,
        peer_id: Optional[int] = None,
        random_id: Optional[int] = None,
        user_id: Optional[int] = None,
        domain: Optional[str] = None,
        chat_id: Optional[int] = None,
        user_ids: Optional[List[int]] = None,
        peer_ids: Optional[List[int]] = None,
        lat: Optional[float] = None,
        long: Optional[float] = None,
        attachment: Optional[List[str]] = None,
        reply_to: Optional[int] = None,
        forward_messages: Optional[List[int]] = None,
        sticker_id: Optional[int] = None,
        group_id: Optional[int] = None,
        keyboard: Optional[str] = None,
        payload: Optional[str] = None,
        dont_parse_links: Optional[bool] = None,
        disable_mentions: Optional[bool] = None,
        intent: Optional[str] = None,
        expire_ttl: Optional[int] = None,
        silent: Optional[bool] = None,
        **kwargs,
    ):
        """Collect the request parameters from the arguments.

        Every argument (including extra ``**kwargs``) whose value is not
        ``None`` is stored in ``self._params`` as an ``AttrMap``. When
        ``random_id`` is not given, a random signed 32-bit integer is used.

        Note: parameters are gathered with ``locals()``, so no other local
        variable may be created before that call.
        """
        if random_id is None:
            # Full signed int32 range (randint is inclusive on both ends);
            # the previous ``2 * 31`` typo limited ids to [-62, 62].
            random_id = randint(-(2 ** 31), 2 ** 31 - 1)

        preload_data = locals().copy()
        del preload_data["self"]
        kwargs_vals = preload_data.pop("kwargs")
        preload_data.update(kwargs_vals)

        self._params = AttrMap(
            dict(filter(lambda x: x[1] is not None, preload_data.items())))
        self._set_path()
        self._join_attach()
Ejemplo n.º 20
0
def log_in(root: Any, _locals: dict, login: str, password: str):
    """
    Log a user in.

    An empty login or password focuses the matching entry widget and shows
    an alert. Otherwise the credentials are checked with an ``auth``
    request: on success they are saved to a hidden ``.auth`` file,
    ``root.LOGIN``/``USER_NAME``/``USER_ID``/``profile`` are set and the
    history view is opened; on failure an alert is shown.
    """
    required = (
        ('entry_login', login, 'Введите логин'),
        ('entry_password', password, 'Введите пароль'),
    )
    for entry, value, message in required:
        if value == '':
            _locals[entry].focus()
            root.Alert.show(message)
            return

    if not request('auth', login=login, password=password)['response']:
        root.Alert.show('Неправильный логин или пароль')
        return

    # Presumably removed first because a hidden file cannot be reopened for
    # writing on Windows ("attrib +H" below hides it) — confirm.
    if os.path.exists('.auth'):
        os.remove('.auth')
    # NOTE(review): the password is stored in plain text.
    with open('.auth', 'w') as auth_file:
        auth_file.write(f'{login}::{password}')
    subprocess.check_call(['attrib', '+H', '.auth'])

    root.LOGIN = login
    root.USER_NAME, root.USER_ID = login.split('#')
    root.profile = AttrMap(request(f'profile/{root.USER_ID}'))
    root.history_view(_locals=_locals)
Ejemplo n.º 21
0
# Global build settings. sequence_type=list makes list values come back as
# lists (the AttrMap default converts them to tuples), so they can be mutated.
G = AttrMap(
    dict(
        APPEND=[],
        ACTIONS=[],
        BUILD=[],
        BUILDER=["./build-boot-image.py"],
        BUNDLE_FILE='',
        BUNDLED='bundled',
        DEBUG=0,
        DEFAULT='default',
        DRY_RUN=0,
        EDIT='',
        BOOT=False,
        EXCLUDES=[],
        FLAGS=[],
        IMAGE='boot-image.raw',
        ISO='centos.iso',
        LABEL='',
        MODE='',
        MODES=['default'],
        ONLY=[],
        SUPPORTED_VERSION='1',
        TOPDIR='bundled',
        TAR_CLEANUP=[],
        VARS=[],
        VERBOSE=0,
        STRICT=1,
        FORMAT='',
    ),
    sequence_type=list,
)
Ejemplo n.º 22
0
    def __init__(self):
        """Load the default configuration from the YAML file ``self.DEFAULT``."""
        with open(self.DEFAULT, "rb") as stream:
            parsed = yaml.safe_load(stream)
        self._config = AttrMap(parsed)
Ejemplo n.º 23
0
                log['gen/loss'] = loss_g.item()
                log['dis/loss'] = loss_d.item()

                logreport(log)

        with torch.no_grad():
            log_test = test(config, test_data_loader, gen, criterionMSE, epoch)
            testreport(log_test)

        if epoch % config.snapshot_interval == 0:
            checkpoint(config, epoch, gen, dis)

        logreport.save_lossgraph()
        testreport.save_lossgraph()


if __name__ == '__main__':
    # Use safe_load: yaml.load() without an explicit Loader can construct
    # arbitrary Python objects and raises TypeError on PyYAML >= 6.
    # Assumes config.yml only uses standard YAML tags — confirm.
    with open('config.yml', 'r') as f:
        config = yaml.safe_load(f)
    config = AttrMap(config)

    # Give every run its own numbered output directory.
    utils.make_manager()
    n_job = utils.job_increment()
    config.out_dir = os.path.join(config.out_dir, '{:06}'.format(n_job))
    os.makedirs(config.out_dir)
    print('Job number: {:04d}'.format(n_job))

    # Keep a copy of the configuration next to the results.
    shutil.copyfile('config.yml', os.path.join(config.out_dir, 'config.yml'))

    train(config)
def main():
    """Combine old and new file-validation infos into one output file.

    ``{old validation file infos}`` is compared to ``{new validation file
    infos}`` and the outcome is written as the new file-validation info.

    The new set is loaded from ``--valFileInfoGlobNew`` and/or built by
    validating the files found under ``--searchDirNew``. Entries from
    ``--valFileInfoGlobOld`` are added for every hash missing from the new
    set. The merged list is written to ``--output``. If ``--statusFolder``
    is given, its ``finished``/``recover`` subfolders are recreated and
    filled with symlinks to the files according to their status.

    Returns:
        ``None`` on success, ``1`` if an exception occurred (the traceback
        and the usage help are printed).
    """
    parser = MyOptParser()

    parser.add_argument("-s", "--searchDirNew", dest="searchDirNew",
            help="""This is the search directory where it is looked for output files (.tiff,.exr,.rib.gz). """,
            metavar="<path>", default=None, required=False)

    parser.add_argument("--valFileInfoGlobNew", dest="valFileInfoGlobNew",
            help="""
            The globbing expression for all input xmls with file status which 
            are consolidated into a new file info under --output. The found and validated files in --searchDir (if specified) are
            added to the set of new files.
            """, default=None, metavar="<glob>", required=False)

    parser.add_argument("--valFileInfoGlobOld", dest="valFileInfoGlobOld",
            help="""
            The globbing expression for all old input xmls with file status which 
            are consolidated with the new files into a combined file info under --output.
            """, default=None, metavar="<glob>", required=False)

    parser.add_argument("--pipelineSpecs", dest="pipelineSpecs", default="",
            help="""Json file with info about the pipeline, fileValidation, fileValidationTools.                 
                 """, metavar="<string>", required=True)

    parser.add_argument("--statusFolder", dest="statusFolder", default=None,
            help="""The output status folder which contains links to files which are finished, or can be recovered.                
                 """, metavar="<string>", required=False)

    parser.add_argument("--validateOnlyLastModified", dest="validateOnlyLastModified", type=cF.toBool, default=True,
            help="""The file with the moset recent modified time is only validated, all others are set to finished!.""", required=False)

    parser.add_argument("-o", "--output", dest="output",
            help="""The output xml which is written, which proivides validation info for each file found""", metavar="<path>", required=True)

    try:
        print("====================== FileValidation ===========================")

        opts = AttrMap(vars(parser.parse_args()))
        if not opts.searchDirNew and not opts.valFileInfoGlobNew:
            raise ValueError("You need to define either searchDirNew or valFileInfoGlobNew!")

        if opts.valFileInfoGlobOld == "":
            opts.valFileInfoGlobOld = None

        print("searchDir: %s" % opts.searchDirNew)
        print("valFileInfoGlobNew: %s" % opts.valFileInfoGlobNew)
        print("valFileInfoGlobOld: %s" % opts.valFileInfoGlobOld)
        print("output: %s" % opts.output)

        d = cF.jsonLoad(opts.pipelineSpecs)
        pipelineTools = d["pipelineTools"]
        fileValidationSpecs = d["fileValidationSpecs"]
        fileValidationTools = d["fileValidationTools"]

        valDataAllNew = dict()
        # NOTE(review): old files superseded by a new file with the same hash
        # are collected here, but nothing in this function deletes them;
        # confirm whether deletion is still intended.
        deleteFiles = []

        # Load the new validation data.
        if opts.valFileInfoGlobNew is not None:
            print("Load new validation files")
            valDataAllNew, _ = loadValidationFiles(opts.valFileInfoGlobNew)
            preferGlobalPaths(valDataAllNew)

        # Add the validated files found in searchDirNew to the new set.
        if opts.searchDirNew is not None:
            print("Validate all files in: %s with pipeLineSpecs: %s" % (opts.searchDirNew, opts.pipelineSpecs))
            allFiles = searchFiles(opts.searchDirNew, opts, fileValidationSpecs, fileValidationTools, pipelineTools)
            for ha, fileInfo in allFiles.items():
                if ha in valDataAllNew:
                    print("""WARNING: File %s already found in validation data set 
                           from globbing expr. %s """ % (fileInfo["absPath"], opts.valFileInfoGlobNew))
                else:
                    valDataAllNew[ha] = fileInfo

        # Load the old validation data and merge it into the new set.
        if opts.valFileInfoGlobOld is not None:
            print("Load old validation files")
            valDataAllOld, _ = loadValidationFiles(opts.valFileInfoGlobOld)
            preferGlobalPaths(valDataAllOld)

            for ha, valInfo in valDataAllOld.items():
                if ha not in valDataAllNew:
                    # This old file hash is not in the new set, so add it.
                    # If absPath no longer exists, try relPath relative to
                    # the directory of the validation file it came from.
                    if not os.path.exists(valInfo["absPath"]):
                        absPath = os.path.join(os.path.dirname(valInfo["validatationInfoPath"]), valInfo["relPath"])
                        if not os.path.exists(absPath):
                            print(valInfo["validatationInfoPath"])
                            raise NameError("""File path in valid. info file: %s 
                                            does not exist, extended rel. path to: %s does also not exist!""" % (valInfo["absPath"], absPath))
                        # print() does not do %-formatting of extra arguments.
                        print("Replacing inexisting path %s with %s" % (valInfo["absPath"], absPath))
                        valInfo["absPath"] = absPath

                    valDataAllNew[ha] = valInfo
                elif os.path.realpath(valDataAllNew[ha]["absPath"]) != os.path.realpath(valInfo["absPath"]):
                    # Same hash in the new set: keep the new entry and mark
                    # the old file for deletion unless both point to the
                    # same file.
                    deleteFiles.append(valInfo["absPath"])

        # Make the final list.
        finalFiles = list(valDataAllNew.values())

        printSummary(finalFiles, pipelineTools, False)

        print("Make output validation file")
        with open(opts.output, "w+") as outFile:
            cF.jsonDump(finalFiles, outFile, sort_keys=True)

        # Renew the status folder: one symlink per file, grouped by status.
        if opts.statusFolder is not None:
            print("Renew status folder:")
            finished = os.path.join(opts.statusFolder, "finished")
            recover = os.path.join(opts.statusFolder, "recover")

            cF.makeDirectory(finished, interact=False, defaultMakeEmpty=True)
            cF.makeDirectory(recover, interact=False, defaultMakeEmpty=True)
            paths = {"recover": recover, "finished": finished}

            for fileInfo in finalFiles:
                h = fileInfo["hash"]
                statusDir = paths[fileInfo["status"]]
                p = os.path.relpath(fileInfo["absPath"], start=statusDir)
                head, ext = os.path.splitext(os.path.basename(p))
                os.symlink(p, os.path.join(statusDir, head + "-uuid-" + h + ext))

        print("=================================================================")

    except Exception as e:
        print("====================================================================")
        print("Exception occured: " + str(e))
        print("====================================================================")
        traceback.print_exc(file=sys.stdout)
        parser.print_help()
        return 1
Ejemplo n.º 25
0
            out_rgb = np.clip(out_[:3], -1, 1)
            out_cloud = np.clip(out_[3], -1, 1)
            allim[0, 0, :] = np.repeat(in_nir[None, :, :], repeats=3,
                                       axis=0) * 127.5 + 127.5
            allim[0, 1, :] = in_rgb * 127.5 + 127.5
            allim[0, 2, :] = out_rgb * 127.5 + 127.5
            allim[0, 3, :] = np.repeat(
                out_cloud[None, :, :], repeats=3, axis=0) * 127.5 + 127.5
            allim = allim.transpose(0, 3, 1, 4, 2)
            allim = allim.reshape((h * p, w * p, c))

            save_image(args.out_dir, allim, i, 1, filename=filename)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True)
    parser.add_argument('--test_dir', type=str, required=True)
    parser.add_argument('--out_dir', type=str, required=True)
    parser.add_argument('--pretrained', type=str, required=True)
    parser.add_argument('--cuda', action='store_true')
    # nargs='+' so an explicit value is a list like the default, not a bare int.
    parser.add_argument('--gpu_ids', type=int, nargs='+', default=[0])
    parser.add_argument('--manualSeed', type=int, default=0)
    args = parser.parse_args()

    # Use safe_load: yaml.load() without an explicit Loader can construct
    # arbitrary Python objects and raises TypeError on PyYAML >= 6.
    # Assumes the config only uses standard YAML tags — confirm.
    with open(args.config, 'r') as f:
        config = yaml.safe_load(f)
    config = AttrMap(config)

    predict(config, args)
Ejemplo n.º 26
0
def process(xmlfile, *xmlfiles, **options):
    """
    Main program code.

        1. Load XML file.
        2. Export structured annotation data from XML to relational database.
        3. Save database tables as CSV files.
        4. Create polygons from the annotations.
        5. Create polygonal selection by adding and subtracting polygons.
        6. Save polygonal selection as a binary object.
        7. Create binary mask at the specified resolution with the specified
           field-of-view (FOV) from the polygonal selection.
        8. Export binary mask (values: 0-255) to an 8-bit TIFF file.
        +1 (optional): display polygonal selection and binary mask.

    Each input file is processed independently: a failure is logged and
    counted, and processing continues with the next file. The log is written
    to a temporary file and finally copied to ``<outbase>.log`` of the last
    file for which an output base name was set.

    :param xmlfile: first annotation XML file.
    :param xmlfiles: further annotation XML files.
    :param options: program options, read as attributes (e.g. ``verbose``,
        ``outdir``, ``csv``, ``display``, ``bin``, ``mask``, ``scale``,
        ``histo``, shape/scale/tile settings).
    """
    p = AttrMap(options)

    # Start logging
    global logger
    log_fd, temp_logpath = tempfile.mkstemp()
    # mkstemp() also returns an open OS-level descriptor; logging reopens the
    # file by path, so close ours instead of leaking it.
    os.close(log_fd)
    logging.basicConfig(format='[%(asctime)s] %(message)s',
                        datefmt='%d-%m-%Y %H:%M:%S',
                        filename=temp_logpath,
                        filemode="w")
    logger = logging.getLogger("roi")
    # Each verbosity step lowers the threshold below CRITICAL (50).
    logger.setLevel(max(50 - p.verbose, 1))
    logger.critical("The program started with the command: {}".format(
        " ".join(sys.argv)))

    xmlfiles = (xmlfile, ) + xmlfiles
    err = 0
    for f in xmlfiles:
        try:
            # Parse data in XML file and create data tables
            logger.info("{}".format(f))
            if not str(f).lower().endswith(".xml"):
                logger.warning("Input does not appear to be an XML file.")
            points, regions, annotations = parse_xml(f)

            # Create output base name
            p.basename = os.path.splitext(os.path.split(f)[-1])[0]
            if p.outdir:
                os.makedirs(p.outdir, exist_ok=True)
                p.outbase = os.path.join(p.outdir, p.basename)
            else:
                p.outbase = os.path.splitext(f)[0]

            # Save the tables of the relational database
            if p.csv:
                points.to_csv(p.outbase + "_points.csv")
                regions.to_csv(p.outbase + "_regions.csv")
                annotations.to_csv(p.outbase + "_annotations.csv")

            # Retrieve polygons
            polygons = create_polygons(points)

            # Treat annotation layers separately
            for layer in annotations.index.unique():
                layerdir = p.outdir or os.path.split(f)[0]
                layerdir = os.path.join(
                    layerdir, "AnnotationLayer_{0:02d}".format(layer))
                os.makedirs(layerdir, exist_ok=True)
                layerbase = os.path.join(layerdir, p.basename)

                # Set algebra
                selection = create_selection(polygons, regions, layer=layer)
                if selection is None:
                    logger.warning("Annotation layer {} does not have any "
                                   "polygonal selections.".format(layer))
                    continue

                # Display selection
                if p.display:
                    visualise_polygon(selection, show=True, save=False)

                # Export the polygonal selection object to a binary file
                if p.bin:
                    with open(layerbase + "_selection.obj", "wb") as fp:
                        dill.dump(selection, fp)

                # Generate binary mask
                if p.mask:
                    if len(p.scale) == 1:
                        scale_x, scale_y = p.scale * 2  # p.scale is a tuple!
                    elif len(p.scale) == 2:
                        scale_x, scale_y = p.scale  # p.scale is a tuple!
                    else:
                        raise ValueError(
                            "The number of scaling factors must be 2.")
                    mask = create_mask(selection,
                                       original_shape=p.original_shape,
                                       target_shape=p.target_shape,
                                       scale_x=scale_x,
                                       dimscale_x=p.dimscale_x,
                                       scale_y=scale_y,
                                       dimscale_y=p.dimscale_y,
                                       tile=p.tile,
                                       fill_value=p.fill_value)
                    # Display binary mask
                    if p.display:
                        plt.imshow(mask, cmap="gray", aspect="equal")
                        plt.show()
                    # Save binary mask
                    Image.fromarray(mask).save(layerbase + "_mask.tif")
                    # Save the corresponding histology image
                    if p.histo:
                        histo = create_histo(img=p.image,
                                             dimlevel=p.dimlevel,
                                             original_shape=p.original_shape,
                                             target_shape=p.target_shape,
                                             scale_x=scale_x,
                                             dimscale_x=p.dimscale_x,
                                             scale_y=scale_y,
                                             dimscale_y=p.dimscale_y,
                                             tile=p.tile)
                        Image.fromarray(histo).save(layerbase + "_histo.tif")
        except Exception as exc:
            # Log the exception itself: exc.args may be empty, and indexing
            # it here would raise inside the handler and abort the run.
            logger.critical("FAILURE: {}. Exception: {}".format(f, exc))
            err += 1

    # Conclude run
    if err == 0:
        logger.critical("All tasks were successfully completed.")
    else:
        logger.critical("Tasks were completed with {} error(s).".format(err))

    # Save logs next to the outputs of the last file that got an output base
    # name; if no file got that far there is nowhere to put them.
    outbase = p.get("outbase")
    if outbase is not None:
        try:
            shutil.copy(temp_logpath, outbase + ".log")
        except PermissionError:
            pass
Ejemplo n.º 27
0
def test_add_fragment_urls(mock_log_change):
    """Exercise ``add_fragment_urls.Command.add_fragment_urls`` scenario by scenario.

    Covers: an unknown shelfmark; adding a URL with ``overwrite`` off (existing
    IIIF URL kept) and on (IIIF URL replaced); skipping a URL that already
    matches; skipping vs. updating a mismatched URL depending on ``overwrite``;
    and ``dryrun`` leaving the database untouched and nothing logged.
    """
    view_url = "https://cudl.lib.cam.ac.uk/view/MS-TS-NS-J-00600"
    old_iiif_url = "https://cudl.lib.cam.ac.uk/iiif/MS-TS-NS-J-00490"

    def make_command(overwrite, dryrun):
        """Return a fresh Command with the given option flags set."""
        cmd = add_fragment_urls.Command()
        cmd.overwrite = overwrite
        cmd.dryrun = dryrun
        return cmd

    def make_fragment_and_row(shelfmark):
        """Create a fragment holding ``old_iiif_url`` and a row pointing it at ``view_url``."""
        frag = Fragment.objects.create(shelfmark=shelfmark, iiif_url=old_iiif_url)
        return frag, AttrMap({"shelfmark": frag.shelfmark, "url": view_url})

    def refetch(frag):
        """Reload ``frag`` from the database so assertions see persisted state."""
        return Fragment.objects.get(shelfmark=frag.shelfmark)

    def rerun(cmd, row):
        """Clear ``cmd.stats`` and process ``row`` again."""
        cmd.stats = defaultdict(int)
        cmd.add_fragment_urls(row)

    # Unknown shelfmark: counted as not found (the test would fail if an
    # error were raised) and nothing is logged.
    command = add_fragment_urls.Command()
    command.add_fragment_urls(AttrMap({"shelfmark": "mm", "url": "example.com"}))
    assert command.stats["not_found"] == 1
    assert not mock_log_change.call_count

    # Without overwrite, the URL is added but the existing IIIF URL is kept.
    command = make_command(overwrite=None, dryrun=None)
    orig_frag, row = make_fragment_and_row("T-S NS 305.66")
    command.add_fragment_urls(row)
    fragment = refetch(orig_frag)
    assert fragment.url == row["url"]
    assert fragment.iiif_url == orig_frag.iiif_url
    assert command.stats["url_added"] == 1
    assert not command.stats["iiif_added"]
    assert not command.stats["iiif_updated"]
    assert not command.stats["url_updated"]
    mock_log_change.assert_called_with(fragment, "added URL")

    # With overwrite, the IIIF URL is replaced by the one matching the view URL.
    command = make_command(overwrite=True, dryrun=None)
    orig_frag, row = make_fragment_and_row("T-S NS 305.75")
    command.add_fragment_urls(row)
    fragment = refetch(orig_frag)
    assert fragment.iiif_url != orig_frag.iiif_url
    assert fragment.iiif_url == "https://cudl.lib.cam.ac.uk/iiif/MS-TS-NS-J-00600"
    assert command.stats["iiif_updated"] == 1
    mock_log_change.assert_called_with(fragment,
                                       "added URL and updated IIIF URL")

    # URL already matches the row: skipped.
    fragment.url = row.url
    fragment.save()
    rerun(command, row)
    assert not command.stats["url_updated"]
    assert not command.stats["url_added"]
    assert command.stats["skipped"] == 1

    # URL is set but differs, overwrite off: skipped.
    fragment.url = "http://example.com/fragment/view"
    fragment.save()
    command.overwrite = False
    rerun(command, row)
    assert not command.stats["url_updated"]
    assert not command.stats["url_added"]
    assert command.stats["skipped"] == 1

    # URL differs, overwrite on: updated.
    command.overwrite = True
    rerun(command, row)
    assert command.stats["url_updated"] == 1
    assert not command.stats["url_added"]
    assert not command.stats["skipped"]

    # Dry run: nothing is persisted and nothing is logged.
    mock_log_change.reset_mock()
    command = make_command(overwrite=None, dryrun=True)
    orig_frag, row = make_fragment_and_row("T-S NS 305.80")
    command.add_fragment_urls(row)
    fragment = refetch(orig_frag)
    assert fragment.iiif_url == orig_frag.iiif_url
    # overwrite is off here, so the IIIF check above would pass even without
    # dryrun (see the overwrite=None scenario); the URL check is what actually
    # shows the change was not saved.
    assert fragment.url != row["url"]
    assert not mock_log_change.call_count
Ejemplo n.º 28
0
import os
import sys
import yaml
from attrdict import AttrMap

current_dir = os.path.dirname(os.path.abspath(__file__))
# Decode explicitly as UTF-8 so parsing does not depend on the platform's
# default locale encoding (e.g. cp1252 on Windows).
# NOTE(review): FullLoader is kept for compatibility; yaml.safe_load would be
# sufficient (and safer) if settings.yml only holds plain YAML data.
with open(os.path.join(current_dir, 'settings.yml'), 'r',
          encoding='utf-8') as stream:
    config = yaml.load(stream, Loader=yaml.FullLoader)
# Replace this module object with an AttrMap of the parsed settings, so that
# importing this module yields the settings mapping itself.
sys.modules[__name__] = AttrMap(config)
Ejemplo n.º 29
0
 def new(self):
     """Build a fresh AttrMap with the same keys as ``self.spaces``.

     Entries whose value is a Mapping are filled by calling that value's own
     ``new()``; every other entry is set to ``None``.
     """
     # TODO(poweic): instead of None, use torch.Tensor? (placeholder + strict schema)
     fresh = {}
     for name, space in self.spaces.items():
         if isinstance(space, collections.abc.Mapping):
             fresh[name] = space.new()
         else:
             fresh[name] = None
     return AttrMap(fresh)
Ejemplo n.º 30
0
 def get_children_states(self):
     """Gather each child's ``get_states()`` into a single AttrMap.

     Children come from ``self.children()`` (keyed by name). The combined dict
     is passed through ``_remove_empty_dict_leaf`` before being wrapped.
     """
     states = {}
     for name, child in self.children().items():
         states[name] = child.get_states()
     pruned = _remove_empty_dict_leaf(states)
     return AttrMap(pruned)