Code Example #1
from pickle import load

import utils


def train():
    # Flickr8k split file listing the training image identifiers, one per line
    filename = 'Flickr_8k.trainImages.txt'
    train = utils.load_ids(filename)
    # cleaned captions and pre-extracted photo features, restricted to the training IDs
    train_captions = utils.load_clean_captions('descriptions.txt', train)
    train_features = utils.load_photos_features('features.pkl', train)
    # tokenizer previously fitted on the training captions
    tokenizer = load(open('tokenizer.pkl', 'rb'))
    vocab_size = len(tokenizer.word_index) + 1
    max_len = utils.get_max_length(train_captions)

    # caption_model is assumed to be defined elsewhere in the project
    model = caption_model(vocab_size, max_len)
    epochs = 20
    steps = len(train_captions)
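The helpers here live in the project's own utils module and are not shown in this listing. As a rough idea of what utils.load_ids might do in this snippet (an assumption inferred from how its return value is used, not the project's actual code), it could simply read one image identifier per line:

# hypothetical sketch of utils.load_ids (not the project's actual implementation)
def load_ids(filename):
    with open(filename, 'r') as f:
        lines = f.read().strip().split('\n')
    # keep only the image identifier (drop a trailing .jpg extension if present)
    return set(line.split('.')[0] for line in lines)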
Code Example #2
File: main.py Project: raihan2108/NActSeer
import yaml
from os.path import join


def main(args):
    with open(args.param_file) as read_file:
        # safe_load avoids the deprecated default Loader of yaml.load
        configs = yaml.safe_load(read_file)
    # pprint(configs)
    # forward and reverse user/item ID mappings
    user_id, reverse_user_id, item_id, reverse_item_id = \
        load_ids(configs['data']['dataset_dir'], configs['data']['ids_file_name'])
    # adjacency matrix of the user graph, sized by the number of users
    adj_mat = load_graph(
        join(configs['data']['dataset_dir'],
             configs['data']['graph_file_name']), len(user_id))

    if configs['model_name'] == 'netact':
        mc = NetActController(adj_mat, **configs)
        mc.run_train()
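load_ids, load_graph, and NetActController are project-local names from raihan2108/NActSeer. As a hedged sketch only: if the ids file were a pickle holding the four mappings (the file format and its key names are assumptions), load_ids could look roughly like this:

import os
import pickle

# hypothetical sketch of load_ids; the file format and key names are assumptions
def load_ids(dataset_dir, ids_file_name):
    with open(os.path.join(dataset_dir, ids_file_name), 'rb') as f:
        ids = pickle.load(f)
    return (ids['user_id'], ids['reverse_user_id'],
            ids['item_id'], ids['reverse_item_id'])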
Code Example #3
import os


def preprocess_annotations(annot_folder):
    '''
    Preprocessing of the raw .xml annotation data in <root_path>/VOCdevkit/VOC2007/Annotations

    Creates a single csv file containing all ground truth data
    csv columns:
    - img_name (000012)
    - class_id (see class_id.csv for matching)
    - shape ([width, height, depth])
    - bounding_box ([xmin, ymin, xmax, ymax])
    - difficult (0 or 1)
    - truncated (0 or 1)
    - segmented (0 or 1)

    TODO: segmentation mask not loaded for now
    '''

    output_path = os.path.join(annot_folder, 'annotations.csv')
    class_ids = load_ids('class_id.csv')

    count_xml = 0
    count_ignored = 0

    with open(output_path, 'w+') as f_out:

        for filename in os.listdir(annot_folder):

            # ignore non-xml file
            if not filename.endswith('.xml'):
                print('Ignoring file %s' % filename)
                count_ignored += 1
                continue

            count_xml += 1
            parsed = parse_xml_annot(os.path.join(annot_folder, filename))

            # yes we want it to fail if the key is not here
            # (csv_keys is assumed to be a module-level list of the csv column names)
            row = ';'.join([parsed[key] for key in csv_keys]) + '\n'
            f_out.write(row)

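parse_xml_annot is not shown in this listing. A minimal sketch of what it could look like, assuming the standard VOC2007 XML layout and simplifying to the first object per file and the raw class name instead of the mapped class_id:

import os
import xml.etree.ElementTree as ET

# hypothetical sketch of parse_xml_annot (simplified; not the project's actual code)
def parse_xml_annot(xml_path):
    root = ET.parse(xml_path).getroot()
    size = root.find('size')
    obj = root.find('object')      # only the first object, for illustration
    bbox = obj.find('bndbox')
    return {
        'img_name': os.path.splitext(root.find('filename').text)[0],
        'class_id': obj.find('name').text,
        'shape': str([int(size.find(k).text) for k in ('width', 'height', 'depth')]),
        'bounding_box': str([int(bbox.find(k).text)
                             for k in ('xmin', 'ymin', 'xmax', 'ymax')]),
        'difficult': obj.find('difficult').text,
        'truncated': obj.find('truncated').text,
        'segmented': root.find('segmented').text,
    }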
Code Example #4
import os
from collections import defaultdict


def preprocess(dataset_folder, dataset_type):

    output_file = os.path.join(dataset_folder, 'VOCdevkit/VOC2007/Annotations',
                               'annotations_multilabel_%s.csv' % dataset_type)
    classif_annot_folder = os.path.join(dataset_folder,
                                        'VOCdevkit/VOC2007/ImageSets/Main')

    class_ids = load_ids()
    data = defaultdict(lambda: [None] * len(class_ids))

    # read
    for filename in os.listdir(classif_annot_folder):
        if not filename.endswith('_%s.txt' % dataset_type):
            continue

        class_name = filename.split('_')[0]
        class_id = class_ids[class_name]['id']
        print('doing class %s (%s)' % (class_name, class_id))

        with open(os.path.join(classif_annot_folder, filename), 'r') as f_in:
            for line in f_in:
                img_id, class_gt = line.strip().split()
                data[img_id][class_id] = int(class_gt)

    # write data in output csv
    with open(output_file, 'w+') as f_out:
        for img_id, img_gt in data.items():
            # every class column must have a ground-truth value for this image
            assert all(x is not None for x in img_gt), \
                'missing class value for img_id %s (%s)' % (img_id, str(img_gt))
            write_line = img_id + ',' + ','.join([str(x)
                                                  for x in img_gt]) + '\n'
            f_out.write(write_line)

    print('multilabel annotations compiled in %s' % output_file)
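Both VOC snippets rely on a small load_ids helper that reads class_id.csv. A plausible sketch, assuming a two-column name,id csv (the file layout is an assumption), returning the nested mapping indexed as class_ids[class_name]['id'] above:

import csv

# hypothetical sketch of load_ids for class_id.csv (file layout is an assumption)
def load_ids(path='class_id.csv'):
    class_ids = {}
    with open(path, 'r') as f:
        for name, class_id in csv.reader(f):
            class_ids[name] = {'id': int(class_id)}
    return class_ids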
Code Example #5
    def __init__(self, adj_mx, **kwargs):
        # kwargs holds the parsed config; keep the data/model/train sub-sections handy
        self._kwargs = kwargs
        self._data_kwargs = kwargs.get('data')
        self._model_kwargs = kwargs.get('model')
        self._train_kwargs = kwargs.get('train')
        self.dataset_name = self._data_kwargs['dataset_dir'].split('/')[-1]
        self.adj_mx = adj_mx
        self.model_params = dict()
        self.model_params['seq_len'] = 30
        # cut-off values for the top-K evaluation metrics
        self.K = [1, 5, 10, 20, 50, 100]

        model_name = 'net_act_orig'  # self._kwargs['model_name']
        self.log_file_name = utils.get_log_dir(log_dir=self._kwargs['log_dir'],
                                               model_name=model_name,
                                               dataset_name=self.dataset_name)
        # create save_dir/<dataset_name>/<model_name> (and parents) if missing
        os.makedirs(os.path.join(self._kwargs['save_dir'], self.dataset_name,
                                 self._kwargs['model_name']),
                    exist_ok=True)

        log_level = self._kwargs.get('log_level', 'INFO')
        self._logger = utils.get_logger(self.log_file_name,
                                        name=__name__,
                                        level=log_level)
        # TensorFlow 1.x summary writer for TensorBoard logging
        self._writer = tf.summary.FileWriter(self.log_file_name)
        self._logger.info(json.dumps(kwargs, indent=2))
        self._saved_file_name = 'best_model.ckpt'

        user_id, reverse_user_id, item_id, reverse_item_id = \
            utils.load_ids(self._data_kwargs['dataset_dir'], self._data_kwargs['ids_file_name'])
        print(len(user_id), len(reverse_user_id), len(item_id),
              len(reverse_item_id))

        self.n_users = len(user_id)
        self.n_context = self._model_kwargs['context_size']

        data_examples, self.user_history, num_bins = utils.load_dataset_timestamp(
            self._data_kwargs['dataset_dir'],
            self._data_kwargs['dataset_name'], self.n_users, self.n_context,
            self.model_params['seq_len'])
        self.num_bins = num_bins

        self.model_params['batch_size'] = self._data_kwargs['batch_size']
        self.model_params['user_size'] = self.n_users
        self.model_params['item_size'] = len(item_id)
        self.model_params['state_size'] = self._model_kwargs['state_size']
        self.model_params['emb_size'] = self._model_kwargs['emb_size']
        self.model_params['lr'] = self._train_kwargs['base_lr']
        self.model_params['n_bins'] = self.num_bins
        self.model_params['context_size'] = self.n_context
        self.model_params['start_lr'] = len(
            data_examples) // self._data_kwargs['batch_size']
        self.model_params['min_lr'] = self._train_kwargs['min_learning_rate']
        self.model_params['use_attn'] = self._model_kwargs['use_attn']
        self.model_params['normalize'] = self._model_kwargs['normalize']
        self.model_params['max_diff'] = self._model_kwargs['max_diff']
        if self._model_kwargs['n_samples'] == -1:
            self.model_params['n_samples'] = len(item_id)
        else:
            self.model_params['n_samples'] = self._model_kwargs['n_samples']
        self.model_params['comb'] = self._model_kwargs['comb']

        self.data_iterator = utils.Loader(data_examples,
                                          options=self.model_params)
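The constructor above only works if the **kwargs config carries the sections and keys it reads. A placeholder config inferred from the accesses in this snippet (the values are made up, not the project's defaults):

# placeholder config matching the keys read by __init__ above (values are made up)
configs = {
    'model_name': 'netact',
    'log_dir': 'logs',
    'save_dir': 'saved',
    'log_level': 'INFO',
    'data': {'dataset_dir': 'data/sample', 'ids_file_name': 'ids.pkl',
             'dataset_name': 'sample', 'batch_size': 64},
    'model': {'context_size': 5, 'state_size': 128, 'emb_size': 64,
              'use_attn': True, 'normalize': False, 'max_diff': 100,
              'n_samples': -1, 'comb': 'concat'},
    'train': {'base_lr': 0.001, 'min_learning_rate': 1e-5},
}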
Code Example #6
from copy import deepcopy

import numpy as np


# utils, opt and the class-id constants (VALID_CLASS_IDS, CLASS_LABELS, ID_TO_LABEL)
# are assumed to be defined at module level in the original script
def assign_instances_for_scan(pred_file, gt_file, pred_path):
    try:
        pred_info = utils.read_instance_prediction_file(pred_file, pred_path)
    except Exception as e:
        # utils.print_error is assumed to abort execution when loading fails
        utils.print_error('unable to load ' + pred_file + ': ' + str(e))
    try:
        gt_ids = utils.load_ids(gt_file)
    except Exception as e:
        utils.print_error('unable to load ' + gt_file + ': ' + str(e))

    # get gt instances
    gt_instances = utils.get_instances(gt_ids, VALID_CLASS_IDS, CLASS_LABELS,
                                       ID_TO_LABEL)
    # associate
    gt2pred = deepcopy(gt_instances)
    for label in gt2pred:
        for gt in gt2pred[label]:
            gt['matched_pred'] = []
    pred2gt = {}
    for label in CLASS_LABELS:
        pred2gt[label] = []
    num_pred_instances = 0
    # mask of void labels in the groundtruth
    # (ground-truth ids encode the semantic class as id // 1000)
    bool_void = np.logical_not(np.in1d(gt_ids // 1000, VALID_CLASS_IDS))
    #bool_void = np.in1d(gt_ids, VALID_CLASS_IDS)
    # go thru all prediction masks
    for pred_mask_file in pred_info:
        label_id = int(pred_info[pred_mask_file]['label_id'])
        conf = pred_info[pred_mask_file]['conf']
        if label_id not in ID_TO_LABEL:
            continue
        label_name = ID_TO_LABEL[label_id]
        # read the mask
        pred_mask = utils.load_ids(pred_mask_file)
        if len(pred_mask) != len(gt_ids):
            utils.print_error(
                'wrong number of lines in ' + pred_mask_file +
                ' (%d) vs #mesh vertices (%d), please double check and/or re-download the mesh'
                % (len(pred_mask), len(gt_ids)))
        # convert to binary
        pred_mask = np.not_equal(pred_mask, 0)
        num = np.count_nonzero(pred_mask)
        if num < opt.min_region_sizes[0]:
            continue  # skip predictions below the minimum region size

        pred_instance = {}
        pred_instance['filename'] = pred_mask_file
        pred_instance['pred_id'] = num_pred_instances
        pred_instance['label_id'] = label_id
        pred_instance['vert_count'] = num
        pred_instance['confidence'] = conf
        pred_instance['void_intersection'] = np.count_nonzero(
            np.logical_and(bool_void, pred_mask))

        # matched gt instances
        matched_gt = []
        # go thru all gt instances with matching label
        for (gt_num, gt_inst) in enumerate(gt2pred[label_name]):
            intersection = np.count_nonzero(
                np.logical_and(gt_ids == gt_inst['instance_id'], pred_mask))
            if intersection > 0:
                gt_copy = gt_inst.copy()
                pred_copy = pred_instance.copy()
                gt_copy['intersection'] = intersection
                pred_copy['intersection'] = intersection
                matched_gt.append(gt_copy)
                gt2pred[label_name][gt_num]['matched_pred'].append(pred_copy)
        pred_instance['matched_gt'] = matched_gt
        num_pred_instances += 1
        pred2gt[label_name].append(pred_instance)

    return gt2pred, pred2gt
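In this benchmark-style script, utils.load_ids reads per-vertex integer ids and is applied both to the ground-truth file and to each predicted mask file. A minimal sketch under the assumption that the file stores one integer per mesh vertex, one per line:

import numpy as np

# hypothetical sketch of utils.load_ids for per-vertex id files (format is an assumption)
def load_ids(filename):
    with open(filename) as f:
        lines = f.read().splitlines()
    return np.array(lines, dtype=np.int64)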