Code example #1
    def relabel(self, json_dir, method='ssd'):
        submission_file = os.path.join(json_dir, 'person_instances_val2017_{}_results.json'.format(method))
        img_id_list = list()
        object_list = list()

        for json_file in os.listdir(json_dir):
            json_path = os.path.join(json_dir, json_file)
            shotname, extensions = os.path.splitext(json_file)
            shotname = shotname.rstrip().split('_')[-1]
            try:
                img_id = int(shotname)
            except ValueError:
                Log.info('Invalid Json file: {}'.format(json_file))
                continue

            img_id_list.append(img_id)
            with open(json_path, 'r') as json_stream:
                info_tree = json.load(json_stream)
                for obj in info_tree['objects']:
                    object_dict = dict()
                    object_dict['image_id'] = img_id
                    object_dict['category_id'] = int(self.configer.get('details', 'coco_cat_seq')[obj['label']])
                    object_dict['score'] = obj['score']
                    # COCO results use [x, y, w, h], so convert from [x1, y1, x2, y2].
                    object_dict['bbox'] = [obj['bbox'][0], obj['bbox'][1],
                                           obj['bbox'][2] - obj['bbox'][0],
                                           obj['bbox'][3] - obj['bbox'][1]]

                    object_list.append(object_dict)

        with open(submission_file, 'w') as write_stream:
            write_stream.write(json.dumps(object_list))

        Log.info('Evaluate {} images...'.format(len(img_id_list)))
        return submission_file, img_id_list
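
The file written here is in the standard COCO detection-results format, so it can be scored directly with pycocotools. A minimal sketch of that step, using the two values relabel() returns (the ground-truth path 'person_instances_val2017.json' below is hypothetical):

    from pycocotools.coco import COCO
    from pycocotools.cocoeval import COCOeval

    coco_gt = COCO('person_instances_val2017.json')  # hypothetical ground-truth file
    coco_dt = coco_gt.loadRes(submission_file)       # file returned by relabel()
    coco_eval = COCOeval(coco_gt, coco_dt, 'bbox')
    coco_eval.params.imgIds = img_id_list            # restrict scoring to the converted images
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()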
Code example #2
    def get_valloader(self, dataset=None):
        dataset = 'val' if dataset is None else dataset
        if self.configer.get('dataset') == 'default_pix2pix':
            dataset = DefaultPix2pixDataset(root_dir=self.configer.get('data', 'data_dir'), dataset=dataset,
                                            aug_transform=self.aug_val_transform,
                                            img_transform=self.img_transform,
                                            configer=self.configer)

        elif self.configer.get('dataset') == 'default_cyclegan':
            dataset = DefaultCycleGANDataset(root_dir=self.configer.get('data', 'data_dir'), dataset=dataset,
                                             aug_transform=self.aug_val_transform,
                                             img_transform=self.img_transform,
                                             configer=self.configer)

        elif self.configer.get('dataset') == 'default_facegan':
            dataset = DefaultFaceGANDataset(root_dir=self.configer.get('data', 'data_dir'),
                                            dataset=dataset, tag=self.configer.get('data', 'tag'),
                                            aug_transform=self.aug_val_transform,
                                            img_transform=self.img_transform,
                                            configer=self.configer)

        else:
            Log.error('{} dataset is invalid.'.format(self.configer.get('dataset')))
            exit(1)

        valloader = data.DataLoader(
            dataset,
            batch_size=self.configer.get('val', 'batch_size'), shuffle=False,
            num_workers=self.configer.get('data', 'workers'), pin_memory=True,
            collate_fn=lambda *args: collate(
                *args, trans_dict=self.configer.get('val', 'data_transformer')
            )
        )

        return valloader
Code example #3
    def __list_dirs(self, root_dir, dataset):
        img_list = list()
        label_list = list()
        image_dir = os.path.join(root_dir, 'leftImg8bit', dataset)
        label_dir = os.path.join(root_dir, 'gtFine', dataset)

        for image_file in FileHelper.list_dir(image_dir):
            image_name = '_'.join(image_file.split('_')[:-1])
            label_file = '{}_gtFine_labelIds.png'.format(image_name)
            img_path = os.path.join(image_dir, image_file)
            label_path = os.path.join(label_dir, label_file)
            if not (os.path.exists(label_path) and os.path.exists(img_path)):
                Log.warn('Image/Label Path: {} not exists.'.format(image_name))
                continue

            img_list.append(img_path)
            label_list.append(label_path)

        if dataset == 'train' and self.configer.get('data', 'include_val'):
            image_dir = os.path.join(root_dir, 'leftImg8bit/val')
            label_dir = os.path.join(root_dir, 'gtFine/val')

            for image_file in FileHelper.list_dir(image_dir):
                image_name = '_'.join(image_file.split('_')[:-1])
                label_file = '{}_gtFine_labelIds.png'.format(image_name)
                img_path = os.path.join(image_dir, image_file)
                label_path = os.path.join(label_dir, label_file)
                if not (os.path.exists(label_path) and os.path.exists(img_path)):
                    Log.warn('Image/Label Path: {} not exists.'.format(image_name))
                    continue

                img_list.append(img_path)
                label_list.append(label_path)

        return img_list, label_list
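
The pairing leans on the Cityscapes naming convention, where an image and its label share a stem and differ only in suffix. A quick illustration of the stem manipulation above:

    image_file = 'aachen_000000_000019_leftImg8bit.png'
    image_name = '_'.join(image_file.split('_')[:-1])        # 'aachen_000000_000019'
    label_file = '{}_gtFine_labelIds.png'.format(image_name)
    # -> 'aachen_000000_000019_gtFine_labelIds.png'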
Code example #4
    def get_trainloader(self):
        if self.configer.get('dataset', default=None) == 'default_pix2pix':
            dataset = DefaultPix2pixDataset(root_dir=self.configer.get('data', 'data_dir'), dataset='train',
                                            aug_transform=self.aug_train_transform,
                                            img_transform=self.img_transform,
                                            configer=self.configer)

        elif self.configer.get('dataset') == 'default_cyclegan':
            dataset = DefaultCycleGANDataset(root_dir=self.configer.get('data', 'data_dir'), dataset='train',
                                             aug_transform=self.aug_train_transform,
                                             img_transform=self.img_transform,
                                             configer=self.configer)

        elif self.configer.get('dataset') == 'default_facegan':
            dataset = DefaultFaceGANDataset(root_dir=self.configer.get('data', 'data_dir'),
                                            dataset='train', tag=self.configer.get('data', 'tag'),
                                            aug_transform=self.aug_train_transform,
                                            img_transform=self.img_transform,
                                            configer=self.configer)

        else:
            Log.error('{} dataset is invalid.'.format(self.configer.get('dataset')))
            exit(1)

        trainloader = data.DataLoader(
            dataset,
            batch_size=self.configer.get('train', 'batch_size'), shuffle=True,
            num_workers=self.configer.get('data', 'workers'), pin_memory=True,
            drop_last=self.configer.get('data', 'drop_last'),
            collate_fn=lambda *args: collate(
                *args, trans_dict=self.configer.get('train', 'data_transformer')
            )
        )

        return trainloader
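
The `lambda *args: collate(*args, ...)` wrapper exists only to bind `trans_dict`. `functools.partial` is an equivalent spelling and, unlike a lambda, stays picklable, which matters if the DataLoader workers are started with the 'spawn' method:

    from functools import partial

    # equivalent to the lambda above
    collate_fn = partial(collate, trans_dict=self.configer.get('train', 'data_transformer'))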
Code example #5
File: default_dataset.py Project: zihua/torchcv
    def __read_json_file(self, root_dir, dataset):
        img_list = list()
        label_list = list()

        with open(os.path.join(root_dir, '{}.json'.format(dataset)),
                  'r') as file_stream:
            items = json.load(file_stream)
            for item in items:
                img_path = os.path.join(root_dir, 'dataset',
                                        item['image_path'])
                if not os.path.exists(img_path):
                    Log.warn('Image Path: {} not exists.'.format(img_path))
                    continue

                img_list.append(img_path)
                label_list.append(item['label'])

        if dataset == 'train' and self.configer.get('data', 'include_val'):
            with open(os.path.join(root_dir, 'val.json'), 'r') as file_stream:
                items = json.load(file_stream)
                for item in items:
                    img_path = os.path.join(root_dir, 'dataset',
                                            item['image_path'])
                    if not os.path.exists(img_path):
                        Log.warn('Image Path: {} not exists.'.format(img_path))
                        continue

                    img_list.append(img_path)
                    label_list.append(item['label'])

        return img_list, label_list
Code example #6
    def val(self):
        """
          Validation function during the train phase.
        """
        self.gan_net.eval()
        start_time = time.time()

        for j, data_dict in enumerate(self.val_loader):
            with torch.no_grad():
                # Forward pass.
                out_dict = self.gan_net(data_dict)
                # Compute the loss of the val batch.

            self.val_losses.update(
                out_dict['loss_G'].mean().item() +
                out_dict['loss_D'].mean().item(),
                len(DCHelper.tolist(data_dict['meta'])))
            # Update the vars of the val phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()

        RunnerHelper.save_net(self, self.gan_net, val_loss=self.val_losses.avg)

        # Print the log info & reset the states.
        Log.info('Test Time {batch_time.sum:.3f}s, ({batch_time.avg:.3f})\t'
                 'Loss {loss.avg:.8f}\n'.format(batch_time=self.batch_time,
                                                loss=self.val_losses))
        self.batch_time.reset()
        self.val_losses.reset()
        self.gan_net.train()
Code example #7
    def vis_rois(self, inputs, indices_and_rois, rois_labels=None, name='default', sub_dir='rois'):
        base_dir = os.path.join(self.configer.get('project_dir'), DET_DIR, sub_dir)

        if not os.path.exists(base_dir):
            Log.error('Dir:{} not exists!'.format(base_dir))
            os.makedirs(base_dir)

        for i in range(inputs.size(0)):
            rois = indices_and_rois[indices_and_rois[:, 0] == i][:, 1:]
            ori_img = DeNormalize(div_value=self.configer.get('normalize', 'div_value'),
                                  mean=self.configer.get('normalize', 'mean'),
                                  std=self.configer.get('normalize', 'std'))(inputs[i])
            ori_img = ori_img.data.cpu().squeeze().numpy().transpose(1, 2, 0).astype(np.uint8)
            ori_img = cv2.cvtColor(ori_img, cv2.COLOR_RGB2BGR)
            color_num = len(self.configer.get('details', 'color_list'))

            for j in range(len(rois)):
                label = 1 if rois_labels is None else rois_labels[j]
                if label == 0:
                    continue

                class_name = self.configer.get('details', 'name_seq')[label - 1]
                cv2.rectangle(ori_img,
                              (int(rois[j][0]), int(rois[j][1])),
                              (int(rois[j][2]), int(rois[j][3])),
                              color=self.configer.get('details', 'color_list')[(label - 1) % color_num], thickness=3)
                cv2.putText(ori_img, class_name,
                            (int(rois[j][0]) + 5, int(rois[j][3]) - 5),
                            cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.5,
                            color=self.configer.get('details', 'color_list')[(label - 1) % color_num], thickness=2)

            img_path = os.path.join(base_dir, '{}_{}_{}.jpg'.format(name, i, time.time()))

            cv2.imwrite(img_path, ori_img)
Code example #8
    def __call__(self):
        # the 30th layer of features is relu of conv5_3
        model = vgg16(pretrained=False)
        if self.configer.get('network', 'pretrained') is not None:
            Log.info('Loading pretrained model: {}'.format(
                self.configer.get('network', 'pretrained')))
            model.load_state_dict(
                torch.load(self.configer.get('network', 'pretrained')))

        features = list(model.features)[:30]
        classifier = model.classifier

        classifier = list(classifier)
        del classifier[6]
        if not self.configer.get('network', 'use_drop'):
            del classifier[5]
            del classifier[2]

        classifier = nn.Sequential(*classifier)

        # freeze top4 conv
        for layer in features[:10]:
            for p in layer.parameters():
                p.requires_grad = False

        return nn.Sequential(*features), classifier
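
For reference, torchvision's stock vgg16 has 31 modules in `features` (index 29 is the ReLU after conv5_3, hence the `[:30]` slice) and a 7-module `classifier`. A small sketch of what the deletions leave behind:

    from torchvision.models import vgg16

    m = vgg16(pretrained=False)
    # m.classifier: [Linear(25088, 4096), ReLU, Dropout, Linear(4096, 4096), ReLU, Dropout, Linear(4096, 1000)]
    head = list(m.classifier)
    del head[6]           # drop the 1000-way ImageNet head
    del head[5], head[2]  # drop both Dropout layers (the use_drop=False case)
    # head is now [Linear(25088, 4096), ReLU, Linear(4096, 4096), ReLU]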
Code example #9
File: data_loader.py Project: lizhaoliu-Lec/torchcv
    def get_trainloader(self):
        if self.configer.get('dataset', default=None) in [None, 'default']:
            dataset = DefaultDataset(root_dir=self.configer.get(
                'data', 'data_dir'),
                                     dataset='train',
                                     aug_transform=self.aug_train_transform,
                                     img_transform=self.img_transform,
                                     configer=self.configer)

        else:
            Log.error('{} dataset is invalid.'.format(
                self.configer.get('dataset')))
            exit(1)

        sampler = None
        if self.configer.get('network.distributed'):
            sampler = torch.utils.data.distributed.DistributedSampler(dataset)

        trainloader = data.DataLoader(
            dataset,
            sampler=sampler,
            batch_size=self.configer.get('train', 'batch_size'),
            shuffle=(sampler is None),
            num_workers=self.configer.get('data', 'workers'),
            pin_memory=True,
            drop_last=self.configer.get('data', 'drop_last'),
            collate_fn=lambda *args: collate(*args,
                                             trans_dict=self.configer.get(
                                                 'train', 'data_transformer')))

        return trainloader
Code example #10
    def vis_peaks(self, heatmap_in, ori_img_in, name='default', sub_dir='peaks'):
        base_dir = os.path.join(self.configer.get('project_dir'), POSE_DIR, sub_dir)
        if not os.path.exists(base_dir):
            Log.error('Dir:{} not exists!'.format(base_dir))
            os.makedirs(base_dir)

        if not isinstance(heatmap_in, np.ndarray):
            if len(heatmap_in.size()) != 3:
                Log.error('Heatmap size is not valid.')
                exit(1)

            heatmap = heatmap_in.clone().data.cpu().numpy().transpose(1, 2, 0)
        else:
            heatmap = heatmap_in.copy()

        if not isinstance(ori_img_in, np.ndarray):
            ori_img = DeNormalize(div_value=self.configer.get('normalize', 'div_value'),
                                  mean=self.configer.get('normalize', 'mean'),
                                  std=self.configer.get('normalize', 'std'))(ori_img_in.clone())
            ori_img = ori_img.data.cpu().squeeze().numpy().transpose(1, 2, 0).astype(np.uint8)
            ori_img = cv2.cvtColor(ori_img, cv2.COLOR_RGB2BGR)
        else:
            ori_img = ori_img_in.copy()

        for j in range(self.configer.get('data', 'num_kpts')):
            peaks = self.__get_peaks(heatmap[:, :, j])

            for peak in peaks:
                ori_img = cv2.circle(ori_img, (peak[0], peak[1]),
                                     self.configer.get('vis', 'circle_radius'),
                                     self.configer.get('details', 'color_list')[j], thickness=-1)

            cv2.imwrite(os.path.join(base_dir, '{}_{}.jpg'.format(name, j)), ori_img)
Code example #11
    def __test_img(self, image_path, save_path):
        Log.info('Image Path: {}'.format(image_path))
        ori_image = ImageHelper.read_image(
            image_path,
            tool=self.configer.get('data', 'image_tool'),
            mode=self.configer.get('data', 'input_mode'))

        ori_width, ori_height = ImageHelper.get_size(ori_image)
        ori_img_bgr = ImageHelper.get_cv2_bgr(ori_image,
                                              mode=self.configer.get(
                                                  'data', 'input_mode'))
        heatmap_avg = np.zeros(
            (ori_height, ori_width, self.configer.get('network',
                                                      'heatmap_out')))
        for i, scale in enumerate(self.configer.get('test', 'scale_search')):
            image = self.blob_helper.make_input(ori_image,
                                                input_size=self.configer.get(
                                                    'test', 'input_size'),
                                                scale=scale)
            with torch.no_grad():
                heatmap_out_list = self.pose_net(image)
                heatmap_out = heatmap_out_list[-1]

                # extract outputs, resize, and remove padding
                heatmap = heatmap_out.squeeze(0).cpu().numpy().transpose(
                    1, 2, 0)
                heatmap = cv2.resize(heatmap, (ori_width, ori_height),
                                     interpolation=cv2.INTER_CUBIC)

                heatmap_avg = heatmap_avg + heatmap / len(
                    self.configer.get('test', 'scale_search'))

        all_peaks = self.__extract_heatmap_info(heatmap_avg)
        image_canvas = self.__draw_key_point(all_peaks, ori_img_bgr)
        ImageHelper.save(image_canvas, save_path)
Code example #12
    def forward(self, loc_preds, conf_preds, loc_targets, conf_targets):
        """Compute loss between (loc_preds, loc_targets) and (conf_preds, conf_targets).

        Args:
          loc_preds(tensor): predicted locations, sized [batch_size, 8732, 4]
          loc_targets(tensor): encoded target locations, sized [batch_size, 8732, 4]
          conf_preds(tensor): predicted class confidences, sized [batch_size, 8732, num_classes]
          conf_targets(tensor): encoded target classes, sized [batch_size, 8732]

        Returns:
          (tensor) loss = SmoothL1Loss(loc_preds, loc_targets) + CrossEntropyLoss(conf_preds, conf_targets)
          loc_loss = SmoothL1Loss(pos_loc_preds, pos_loc_targets)
          conf_loss = CrossEntropyLoss(pos_conf_preds, pos_conf_targets)
                    + CrossEntropyLoss(neg_conf_preds, neg_conf_targets)
        """
        # loc_targets, conf_targets = self.ssd_target_generator(feat_list, data_dict)
        batch_size, num_boxes, _ = loc_preds.size()

        pos = conf_targets > 0  # [N,8732], pos means the box matched.
        num_matched_boxes = pos.data.float().sum()
        if num_matched_boxes == 0:
            print("No matched boxes")

        # loc_loss.
        pos_mask = pos.unsqueeze(2).expand_as(loc_preds)  # [N, 8732, 4]
        pos_loc_preds = loc_preds[pos_mask].view(-1, 4)  # [pos,4]
        pos_loc_targets = loc_targets[pos_mask].view(-1, 4)  # [pos,4]
        loc_loss = self.smooth_l1_loss(
            pos_loc_preds, pos_loc_targets
        )  # F.smooth_l1_loss(pos_loc_preds, pos_loc_targets, reduction='sum')

        # conf_loss.
        conf_loss = self._cross_entropy_loss(
            conf_preds.view(-1, self.num_classes),
            conf_targets.view(-1))  # [N*8732,]
        neg = self._hard_negative_mining(conf_loss, pos)  # [N,8732]
        pos_mask = pos.unsqueeze(2).expand_as(conf_preds)  # [N,8732,21]
        neg_mask = neg.unsqueeze(2).expand_as(conf_preds)  # [N,8732,21]
        mask = (pos_mask + neg_mask).gt(0)
        pos_and_neg = (pos + neg).gt(0)
        preds = conf_preds[mask].view(-1, self.num_classes)  # [pos + neg,21]
        targets = conf_targets[pos_and_neg]  # [pos + neg,]
        conf_loss = F.cross_entropy(preds,
                                    targets,
                                    reduction='sum',
                                    ignore_index=-1)

        if num_matched_boxes > 0:
            loc_loss = loc_loss / num_matched_boxes
            conf_loss = conf_loss / num_matched_boxes
        else:
            return conf_loss + loc_loss

        Log.debug("loc_loss: %f, cls_loss: %f" %
                  (float(loc_loss.item()), float(conf_loss.item())))

        return loc_loss + conf_loss
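
`self._hard_negative_mining` is not part of this listing. The usual SSD formulation keeps only the negatives with the largest confidence loss, at roughly a 3:1 negative-to-positive ratio; a common implementation (an assumption, not necessarily the one used in this project) looks like:

    def _hard_negative_mining(self, conf_loss, pos, neg_ratio=3):
        # conf_loss: [N*8732] per-box CE loss; pos: [N, 8732] bool mask of matched boxes.
        batch_size, num_boxes = pos.size()
        conf_loss = conf_loss.view(batch_size, -1).clone()
        conf_loss[pos] = 0                            # positives never count as hard negatives
        _, idx = conf_loss.sort(1, descending=True)   # rank boxes by loss, largest first
        _, rank = idx.sort(1)
        num_neg = torch.clamp(neg_ratio * pos.long().sum(1, keepdim=True), max=num_boxes - 1)
        return rank < num_neg.expand_as(rank)         # [N, 8732] bool mask of kept negatives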
Code example #13
    def get_valloader(self):
        if self.configer.get('dataset', default=None) in [None, 'default']:
            dataset = DefaultDataset(root_dir=self.configer.get(
                'data', 'data_dir'),
                                     dataset='val',
                                     aug_transform=self.aug_val_transform,
                                     img_transform=self.img_transform,
                                     label_transform=self.label_transform,
                                     configer=self.configer)

        elif self.configer.get('dataset', default=None) == 'cityscapes':
            dataset = CityscapesDataset(root_dir=self.configer.get(
                'data', 'data_dir'),
                                        dataset='val',
                                        aug_transform=self.aug_val_transform,
                                        img_transform=self.img_transform,
                                        label_transform=self.label_transform,
                                        configer=self.configer)

        else:
            Log.error('{} dataset is invalid.'.format(
                self.configer.get('dataset')))
            exit(1)

        valloader = data.DataLoader(
            dataset,
            batch_size=self.configer.get('val', 'batch_size'),
            shuffle=False,
            num_workers=self.configer.get('data', 'workers'),
            pin_memory=True,
            collate_fn=lambda *args: collate(*args,
                                             trans_dict=self.configer.get(
                                                 'val', 'data_transformer')))

        return valloader
Code example #14
    def _make_parallel(runner, net):
        if runner.configer.get('network.distributed', default=False):
            local_rank = runner.configer.get('local_rank')
            torch.cuda.set_device(local_rank)
            torch.distributed.init_process_group(backend='nccl',
                                                 init_method='env://')
            if runner.configer.get('network.syncbn', default=False):
                Log.info('Converting syncbn model...')
                net = nn.SyncBatchNorm.convert_sync_batchnorm(net)

            net = nn.parallel.DistributedDataParallel(
                net.cuda(),
                find_unused_parameters=True,
                device_ids=[local_rank],
                output_device=local_rank)
            # if runner.configer.get('network.syncbn', default=False):
            #     Log.info('Converting syncbn model...')
            #     from apex.parallel import convert_syncbn_model
            #     net = convert_syncbn_model(net)
            # from apex.parallel import DistributedDataParallel
            # net = DistributedDataParallel(net.cuda(), delay_allreduce=True)
            return net

        net = net.to(
            torch.device(
                'cpu' if runner.configer.get('gpu') is None else 'cuda'))
        from lib.parallel.data_parallel import ParallelModel
        return ParallelModel(net,
                             gather_=runner.configer.get('network', 'gather'))
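
`init_method='env://'` makes `init_process_group` read its rendezvous information from environment variables, which a launcher such as `python -m torch.distributed.launch --nproc_per_node=N train.py` normally sets. For a single-process smoke test they can be set by hand (values below are illustrative):

    import os

    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    os.environ.setdefault('RANK', '0')
    os.environ.setdefault('WORLD_SIZE', '1')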
Code example #15
    def vis_default_bboxes(self,
                           ori_img_in,
                           default_bboxes,
                           labels,
                           name='default',
                           sub_dir='encode'):
        base_dir = os.path.join(self.configer.get('project_dir'), DET_DIR,
                                sub_dir)

        if not os.path.exists(base_dir):
            Log.error('Dir:{} not exists!'.format(base_dir))
            os.makedirs(base_dir)

        if not isinstance(ori_img_in, np.ndarray):
            ori_img = DeNormalize(
                div_value=self.configer.get('normalize', 'div_value'),
                mean=self.configer.get('normalize', 'mean'),
                std=self.configer.get('normalize', 'std'))(ori_img_in.clone())
            ori_img = ori_img.data.cpu().squeeze().numpy().transpose(
                1, 2, 0).astype(np.uint8)
            ori_img = cv2.cvtColor(ori_img, cv2.COLOR_RGB2BGR)
        else:
            ori_img = ori_img_in.copy()

        assert labels.size(0) == default_bboxes.size(0)

        bboxes = torch.cat([
            default_bboxes[:, :2] - default_bboxes[:, 2:] / 2,
            default_bboxes[:, :2] + default_bboxes[:, 2:] / 2
        ], 1)
        height, width, _ = ori_img.shape
        for i in range(labels.size(0)):
            if labels[i] == 0:
                continue

            class_name = self.configer.get('details',
                                           'name_seq')[labels[i] - 1]
            color_num = len(self.configer.get('details', 'color_list'))

            cv2.rectangle(
                ori_img,
                (int(bboxes[i][0] * width), int(bboxes[i][1] * height)),
                (int(bboxes[i][2] * width), int(bboxes[i][3] * height)),
                color=self.configer.get(
                    'details', 'color_list')[(labels[i] - 1) % color_num],
                thickness=3)

            cv2.putText(ori_img,
                        class_name, (int(bboxes[i][0] * width) + 5,
                                     int(bboxes[i][3] * height) - 5),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        fontScale=0.5,
                        color=self.configer.get('details',
                                                'color_list')[(labels[i] - 1) %
                                                              color_num],
                        thickness=2)

        img_path = os.path.join(base_dir, '{}.jpg'.format(name))

        cv2.imwrite(img_path, ori_img)
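
The `torch.cat` above converts the default boxes from center-size form (cx, cy, w, h) to corner form (x1, y1, x2, y2) before they are scaled to pixels. A one-box example:

    import torch

    db = torch.tensor([[0.5, 0.5, 0.2, 0.2]])  # (cx, cy, w, h), normalized
    corners = torch.cat([db[:, :2] - db[:, 2:] / 2,
                         db[:, :2] + db[:, 2:] / 2], 1)
    # tensor([[0.4000, 0.4000, 0.6000, 0.6000]])  i.e. (x1, y1, x2, y2)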
Code example #16
    def vis_bboxes(self,
                   image_in,
                   bboxes_list,
                   name='default',
                   sub_dir='bbox'):
        """
          Show the different bboxes of individuals.
        """
        base_dir = os.path.join(self.configer.get('project_dir'), DET_DIR,
                                sub_dir)

        if isinstance(image_in, Image.Image):
            image = ImageHelper.rgb2bgr(ImageHelper.to_np(image_in))

        else:
            image = image_in.copy()

        if not os.path.exists(base_dir):
            Log.error('Dir:{} not exists!'.format(base_dir))
            os.makedirs(base_dir)

        img_path = os.path.join(
            base_dir,
            name if ImageHelper.is_img(name) else '{}.jpg'.format(name))

        for bbox in bboxes_list:
            image = cv2.rectangle(image, (bbox[0], bbox[1]),
                                  (bbox[2], bbox[3]), (0, 255, 0), 2)

        cv2.imwrite(img_path, image)
Code example #17
File: configer.py Project: lizhaoliu-Lec/torchcv
    def update(self, key, value, append=False):
        if key not in self.params_root:
            Log.error('{} Key: {} not existed!!!'.format(
                self._get_caller(), key))
            exit(1)

        self.params_root.put(key, value, append)
Code example #18
    def val(self):
        """
          Validation function during the train phase.
        """
        self.pose_net.eval()
        start_time = time.time()

        with torch.no_grad():
            for i, data_dict in enumerate(self.val_loader):
                # Forward pass.
                out = self.pose_net(data_dict)
                # Compute the loss of the val batch.
                loss_dict = self.pose_loss(out)

                self.val_losses.update({key: loss.item() for key, loss in loss_dict.items()}, data_dict['img'].size(0))

                # Update the vars of the val phase.
                self.batch_time.update(time.time() - start_time)
                start_time = time.time()

            self.runner_state['val_loss'] = self.val_losses.avg['loss']
            RunnerHelper.save_net(self, self.pose_net, val_loss=self.val_losses.avg['loss'])
            # Print the log info & reset the states.
            Log.info(
                'Test Time {batch_time.sum:.3f}s, ({batch_time.avg:.3f})\t'
                'Loss {0}\n'.format(self.val_losses.info(), batch_time=self.batch_time))
            self.batch_time.reset()
            self.val_losses.reset()
            self.pose_net.train()
Code example #19
    def train(runner):
        Log.info('Training start...')
        if runner.configer.get('network',
                               'resume') is not None and runner.configer.get(
                                   'network', 'resume_val'):
            runner.val()

        if runner.configer.get('solver', 'lr')['metric'] == 'epoch':
            while runner.runner_state['epoch'] < runner.configer.get(
                    'solver', 'max_epoch'):
                if runner.configer.get('network.distributed'):
                    runner.train_loader.sampler.set_epoch(
                        runner.runner_state['epoch'])

                runner.train()
                if runner.runner_state['epoch'] == runner.configer.get(
                        'solver', 'max_epoch'):
                    runner.val()
                    break
        else:
            while runner.runner_state['iters'] < runner.configer.get(
                    'solver', 'max_iters'):
                if runner.configer.get('network.distributed'):
                    runner.train_loader.sampler.set_epoch(
                        runner.runner_state['epoch'])

                runner.train()
                if runner.runner_state['iters'] == runner.configer.get(
                        'solver', 'max_iters'):
                    runner.val()
                    break

        Log.info('Training end...')
Code example #20
File: data_loader.py Project: zihua/torchcv
    def get_valloader(self, dataset=None):
        dataset = 'val' if dataset is None else dataset
        if self.configer.get('dataset', default=None) == 'default_cpm':
            dataset = DefaultCPMDataset(root_dir=self.configer.get('data', 'data_dir'), dataset=dataset,
                                        aug_transform=self.aug_val_transform,
                                        img_transform=self.img_transform,
                                        configer=self.configer)

        elif self.configer.get('dataset', default=None) == 'default_openpose':
            dataset = DefaultOpenPoseDataset(root_dir=self.configer.get('data', 'data_dir'), dataset=dataset,
                                             aug_transform=self.aug_val_transform,
                                             img_transform=self.img_transform,
                                             configer=self.configer)

        else:
            Log.error('{} dataset is invalid.'.format(self.configer.get('dataset')))
            exit(1)

        valloader = data.DataLoader(
            dataset,
            batch_size=self.configer.get('val', 'batch_size'), shuffle=False,
            num_workers=self.configer.get('data', 'workers'), pin_memory=True,
            collate_fn=lambda *args: collate(
                *args, trans_dict=self.configer.get('val', 'data_transformer')
            )
        )
        return valloader
Code example #21
    def __list_dirs(self, root_dir, dataset):
        img_list = list()
        label_list = list()
        image_dir = os.path.join(root_dir, dataset, 'image')
        label_dir = os.path.join(root_dir, dataset, 'label')

        for file_name in os.listdir(label_dir):
            image_name = '.'.join(file_name.split('.')[:-1])
            label_path = os.path.join(label_dir, file_name)
            img_path = ImageHelper.imgpath(image_dir, image_name)
            if not os.path.exists(label_path) or img_path is None:
                Log.warn('Label Path: {} not exists.'.format(label_path))
                continue

            img_list.append(img_path)
            label_list.append(label_path)

        if dataset == 'train' and self.configer.get('data', 'include_val'):
            image_dir = os.path.join(root_dir, 'val/image')
            label_dir = os.path.join(root_dir, 'val/label')
            for file_name in os.listdir(label_dir):
                image_name = '.'.join(file_name.split('.')[:-1])
                label_path = os.path.join(label_dir, file_name)
                img_path = ImageHelper.imgpath(image_dir, image_name)
                if not os.path.exists(label_path) or img_path is None:
                    Log.warn('Label Path: {} not exists.'.format(label_path))
                    continue

                img_list.append(img_path)
                label_list.append(label_path)

        return img_list, label_list
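
`ImageHelper.imgpath` does not appear in this listing; judging from how it is called (a directory plus an extension-less stem, with `None` when nothing matches), one plausible implementation would probe the common image extensions:

    import os

    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')  # assumed extension set

    def imgpath(image_dir, image_name):
        for ext in IMG_EXTENSIONS:
            candidate = os.path.join(image_dir, image_name + ext)
            if os.path.exists(candidate):
                return candidate
        return None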
Code example #22
    def train(self):
        """
          Train function of every epoch during train phase.
        """
        self.det_net.train()
        start_time = time.time()
        # Adjust the learning rate after every epoch.
        self.runner_state['epoch'] += 1

        # data_tuple: (inputs, heatmap, maskmap, vecmap)
        for i, data_dict in enumerate(self.train_loader):
            Trainer.update(self,
                           warm_list=(0, ),
                           warm_lr_list=(self.configer.get('solver',
                                                           'lr')['base_lr'], ),
                           solver_dict=self.configer.get('solver'))

            self.data_time.update(time.time() - start_time)
            # Forward pass.
            out_dict = self.det_net(data_dict)
            # Compute the loss of the train batch & backward.
            loss = out_dict['loss'].mean()
            self.train_losses.update(loss.item(),
                                     len(DCHelper.tolist(data_dict['meta'])))
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Update the vars of the train phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()
            self.runner_state['iters'] += 1

            # Print the log info & reset the states.
            if self.runner_state['iters'] % self.configer.get(
                    'solver', 'display_iter') == 0:
                Log.info(
                    'Train Epoch: {0}\tTrain Iteration: {1}\t'
                    'Time {batch_time.sum:.3f}s / {2}iters, ({batch_time.avg:.3f})\t'
                    'Data load {data_time.sum:.3f}s / {2}iters, ({data_time.avg:.3f})\n'
                    'Learning rate = {3}\tLoss = {loss.val:.8f} (ave = {loss.avg:.8f})\n'
                    .format(self.runner_state['epoch'],
                            self.runner_state['iters'],
                            self.configer.get('solver', 'display_iter'),
                            RunnerHelper.get_lr(self.optimizer),
                            batch_time=self.batch_time,
                            data_time=self.data_time,
                            loss=self.train_losses))
                self.batch_time.reset()
                self.data_time.reset()
                self.train_losses.reset()

            if self.configer.get('solver', 'lr')['metric'] == 'iters' \
                    and self.runner_state['iters'] == self.configer.get('solver', 'max_iters'):
                break

            # Check to val the current model.
            if self.runner_state['iters'] % self.configer.get(
                    'solver', 'test_interval') == 0:
                self.val()
Code example #23
def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights.
    Parameters:
        net (network)   -- network to be initialized
        init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)    -- scaling factor for normal, xavier and orthogonal.
    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.
    """

    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find(
                'BatchNorm2d') != -1:  # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    Log.info('initialize network with {}'.format(init_type))
    net.apply(init_func)  # apply the initialization function <init_func>
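
Usage is a single call on a freshly built network: `net.apply` visits every submodule, so Conv, Linear and BatchNorm2d layers are all reinitialized in place. A minimal example (the model here is illustrative):

    from torch import nn

    model = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1),
                          nn.BatchNorm2d(64),
                          nn.ReLU(inplace=True))
    init_weights(model, init_type='xavier', init_gain=0.02)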
Code example #24
    def load_net(runner, net, model_path=None, map_location='cpu'):
        if model_path is not None or runner.configer.get('network', 'resume') is not None:
            resume_path = runner.configer.get('network', 'resume')
            resume_path = model_path if model_path is not None else resume_path
            Log.info('Resuming from {}'.format(resume_path))
            resume_dict = torch.load(resume_path, map_location=map_location)
            if 'state_dict' in resume_dict:
                checkpoint_dict = resume_dict['state_dict']

            elif 'model' in resume_dict:
                checkpoint_dict = resume_dict['model']

            elif isinstance(resume_dict, OrderedDict):
                checkpoint_dict = resume_dict

            else:
                raise RuntimeError(
                    'No state_dict found in checkpoint file {}'.format(runner.configer.get('network', 'resume')))

            # load state_dict
            if hasattr(net, 'module'):
                RunnerHelper.load_state_dict(net.module, checkpoint_dict,
                                             runner.configer.get('network', 'resume_strict'))
            else:
                RunnerHelper.load_state_dict(net, checkpoint_dict, runner.configer.get('network', 'resume_strict'))

            if runner.configer.get('network', 'resume_continue'):
                # runner.configer.resume(resume_dict['config_dict'])
                runner.runner_state = resume_dict['runner_state']

        net = RunnerHelper._make_parallel(runner, net)
        return net
Code example #25
    def __init__(self, configer):
        self.configer = configer

        if self.configer.get('data', 'image_tool') == 'pil':
            self.aug_train_transform = pil_aug_trans.PILAugCompose(
                self.configer, split='train')
        elif self.configer.get('data', 'image_tool') == 'cv2':
            self.aug_train_transform = cv2_aug_trans.CV2AugCompose(
                self.configer, split='train')
        else:
            Log.error('Not support {} image tool.'.format(
                self.configer.get('data', 'image_tool')))
            exit(1)

        if self.configer.get('data', 'image_tool') == 'pil':
            self.aug_val_transform = pil_aug_trans.PILAugCompose(self.configer,
                                                                 split='val')
        elif self.configer.get('data', 'image_tool') == 'cv2':
            self.aug_val_transform = cv2_aug_trans.CV2AugCompose(self.configer,
                                                                 split='val')
        else:
            Log.error('Not support {} image tool.'.format(
                self.configer.get('data', 'image_tool')))
            exit(1)

        self.img_transform = trans.Compose([
            trans.ToTensor(),
            trans.Normalize(**self.configer.get('data', 'normalize')),
        ])

        self.label_transform = trans.Compose([
            trans.ToLabel(),
            trans.ReLabel(255, -1),
        ])
Code example #26
    def test(self, test_dir, out_dir):
        for _, data_dict in enumerate(
                self.test_loader.get_testloader(test_dir=test_dir)):
            data_dict['testing'] = True
            out_dict = self.det_net(data_dict)
            meta_list = DCHelper.tolist(data_dict['meta'])
            batch_detections = self.decode(out_dict['loc'], out_dict['conf'],
                                           self.configer, meta_list)
            for i in range(len(meta_list)):
                ori_img_bgr = ImageHelper.read_image(meta_list[i]['img_path'],
                                                     tool='cv2',
                                                     mode='BGR')
                json_dict = self.__get_info_tree(batch_detections[i])
                image_canvas = self.det_parser.draw_bboxes(
                    ori_img_bgr.copy(),
                    json_dict,
                    conf_threshold=self.configer.get('res', 'vis_conf_thre'))
                ImageHelper.save(image_canvas,
                                 save_path=os.path.join(
                                     out_dir, 'vis/{}.png'.format(
                                         meta_list[i]['filename'])))

                Log.info('Json Path: {}'.format(
                    os.path.join(
                        out_dir,
                        'json/{}.json'.format(meta_list[i]['filename']))))
                JsonHelper.save_file(json_dict,
                                     save_path=os.path.join(
                                         out_dir, 'json/{}.json'.format(
                                             meta_list[i]['filename'])))
Code example #27
    def __list_dirs(self, root_dir, dataset):
        imgA_list = list()
        imgB_list = list()

        imageA_dir = os.path.join(root_dir, dataset, 'imageA')
        imageB_dir = os.path.join(root_dir, dataset, 'imageB')

        for file_name in os.listdir(imageA_dir):
            image_name = '.'.join(file_name.split('.')[:-1])
            imgA_path = ImageHelper.imgpath(imageA_dir, image_name)
            imgB_path = ImageHelper.imgpath(imageB_dir, image_name)
            if imgA_path is None or imgB_path is None:
                Log.warn('Img Path: {} not exists.'.format(imgA_path))
                continue

            imgA_list.append(imgA_path)
            imgB_list.append(imgB_path)

        if dataset == 'train' and self.configer.get('data', 'include_val'):
            imageA_dir = os.path.join(root_dir, 'val/imageA')
            imageB_dir = os.path.join(root_dir, 'val/imageB')
            for file_name in os.listdir(imageA_dir):
                image_name = '.'.join(file_name.split('.')[:-1])
                imgA_path = ImageHelper.imgpath(imageA_dir, image_name)
                imgB_path = ImageHelper.imgpath(imageB_dir, image_name)
                if imgA_path is None or imgB_path is None:
                    Log.warn('Img Path: {} not exists.'.format(imgA_path))
                    continue

                imgA_list.append(imgA_path)
                imgB_list.append(imgB_path)

        return imgA_list, imgB_list
Code example #28
    def save_file(json_dict, save_path):
        dir_name = os.path.dirname(save_path)
        if not os.path.exists(dir_name):
            Log.info('Json Dir: {} not exists.'.format(dir_name))
            os.makedirs(dir_name)

        with open(save_path, 'w') as write_stream:
            write_stream.write(json.dumps(json_dict))
Code example #29
File: configer.py Project: lizhaoliu-Lec/torchcv
    def get(self, *key, **kwargs):
        key = '.'.join(key)
        if key in self.params_root or 'default' in kwargs:
            return self.params_root.get(key, **kwargs)

        else:
            Log.error('{} KeyError: {}.'.format(self._get_caller(), key))
            exit(1)
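
Keys are joined with dots, so `get('train', 'batch_size')` and `get('train.batch_size')` address the same entry, and passing `default=` suppresses the hard exit on a missing key. A hypothetical usage, assuming `configer` was built from a JSON config:

    batch_size = configer.get('train', 'batch_size')                  # exits if the key is absent
    distributed = configer.get('network.distributed', default=False)  # falls back to False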
Code example #30
File: image_helper.py Project: lizhaoliu-Lec/torchcv
    def read_image(image_path, tool='pil', mode='RGB'):
        if tool == 'pil':
            return ImageHelper.pil_read_image(image_path, mode=mode)
        elif tool == 'cv2':
            return ImageHelper.cv2_read_image(image_path, mode=mode)
        else:
            Log.error('Not support tool {}'.format(tool))
            exit(1)