def __init__(self, yaml_file, tokenizer=None, add_od_labels=True,
             max_img_seq_length=50, max_seq_length=70, max_seq_a_length=40,
             is_train=True, mask_prob=0.15, max_masked_tokens=3,
             add_conf=False, **kwargs):
    """Constructor.

    Loads the yaml config, resolves the label / feature / caption file
    paths listed in it, opens the TSV files, loads the captions (when a
    caption file is present) and builds the key lookup tables.

    Args:
        yaml_file: yaml file with all required data (image feature,
            caption, labels, etc). Its directory is used as the root
            passed to ``find_file_path_in_yaml``.
        tokenizer: tokenizer for text processing. Required when
            ``is_train`` is true.
        add_od_labels: whether to add labels from yaml file to BERT.
            When true, the label file must exist.
        max_img_seq_length: max image sequence length.
        max_seq_length: max text sequence length.
        max_seq_a_length: max caption sequence length.
        is_train: train or test mode. In train mode the caption file
            must exist.
        mask_prob: probability to mask a input token.
        max_masked_tokens: maximum number of tokens to be masked in one
            sentence.
        add_conf: stored as ``self.add_conf``; not otherwise used by
            this constructor.
        kwargs: other arguments, stored as ``self.kwargs``.

    Raises:
        AssertionError: if the feature file is missing, if
            ``add_od_labels`` is set and the label file is missing, or if
            ``is_train`` is set and the caption file is missing or no
            tokenizer was given.
    """
    self.yaml_file = yaml_file
    self.cfg = load_from_yaml_file(yaml_file)
    self.root = op.dirname(yaml_file)

    self.label_file = find_file_path_in_yaml(self.cfg['label'], self.root)
    self.feat_file = find_file_path_in_yaml(self.cfg['feature'], self.root)
    # 'caption' is an optional config entry, hence .get().
    self.caption_file = find_file_path_in_yaml(self.cfg.get('caption'), self.root)

    # The code below treats label_file / caption_file as possibly falsy, so
    # check truthiness before op.isfile(): a missing path then fails with
    # the intended AssertionError rather than an error from op.isfile(None).
    assert self.feat_file and op.isfile(self.feat_file), \
        'feature file not found: {}'.format(self.feat_file)
    if add_od_labels:
        assert self.label_file and op.isfile(self.label_file), \
            'label file not found: {}'.format(self.label_file)
    if is_train:
        assert self.caption_file and op.isfile(self.caption_file), \
            'caption file not found: {}'.format(self.caption_file)
        assert tokenizer is not None, 'a tokenizer is required when is_train=True'

    self.label_tsv = TSVFile(self.label_file) if self.label_file else None
    self.feat_tsv = TSVFile(self.feat_file)

    # Always define the attribute so later reads cannot raise
    # AttributeError when no caption file is available.
    self.captions = None
    if self.caption_file and op.isfile(self.caption_file):
        # JSON text is UTF-8; do not depend on the platform default encoding.
        with open(self.caption_file, 'r', encoding='utf-8') as f:
            self.captions = json.load(f)

    self.tokenizer = tokenizer
    self.tensorizer = CaptionTensorizer(self.tokenizer, max_img_seq_length,
                                        max_seq_length, max_seq_a_length,
                                        mask_prob, max_masked_tokens,
                                        is_train=is_train)
    self.add_od_labels = add_od_labels
    self.is_train = is_train
    self.add_conf = add_conf
    self.kwargs = kwargs

    # Derived lookup tables are built last, once all plain state is set.
    self.image_keys = self.prepare_image_keys()
    self.key2index = self.prepare_image_key_to_index()
    self.key2captions = self.prepare_image_key_to_captions()
def __init__(self, yaml_file, nms_threshold=0.85, max_given_constraints=3, **kwargs):
    """Set up the base dataset, then the constraint helpers.

    After the parent constructor has run, four extra entries are read from
    ``self.cfg`` (``cbs_box``, ``cbs_constraint``, ``cbs_tokenforms`` and
    ``cbs_hierarchy``), each resolved with ``find_file_path_in_yaml``
    against ``self.root``, and used to build the boxes reader, the
    constraint filter and the finite state machine builder.

    Args:
        yaml_file: yaml config file, forwarded to the parent constructor.
        nms_threshold: passed to ``ConstraintFilter``.
            NOTE(review): presumably a non-maximum-suppression overlap
            threshold -- confirm against ``ConstraintFilter``.
        max_given_constraints: passed to both ``ConstraintFilter`` and
            ``FiniteStateMachineBuilder``.
        kwargs: forwarded unchanged to the parent constructor.
    """
    super().__init__(yaml_file, **kwargs)

    # Resolve the config entries in a fixed order: boxes, constraint->tokens,
    # token forms, hierarchy.
    cfg_keys = ('cbs_box', 'cbs_constraint', 'cbs_tokenforms', 'cbs_hierarchy')
    box_path, constraint_path, tokenforms_path, hierarchy_path = [
        find_file_path_in_yaml(self.cfg[key], self.root) for key in cfg_keys
    ]

    self._boxes_reader = ConstraintBoxesReader(box_path)
    self._constraint_filter = ConstraintFilter(
        hierarchy_path, nms_threshold, max_given_constraints
    )
    # The builder uses the tokenizer stored by the parent constructor.
    self._fsm_builder = FiniteStateMachineBuilder(
        self.tokenizer, constraint_path, tokenforms_path, max_given_constraints
    )