    def training_process(self):
        if self.main_proc_flag:
            logger = Logger(self.cfg)

        self.model.train()

        total_loss = 0
        total_hm_loss = 0
        total_wh_loss = 0
        total_off_loss = 0

        epoch = 0
        self.training_loader.sampler.set_epoch(epoch)
        training_loader = iter(self.training_loader)

        for step in range(self.cfg.Train.iter_num):
            self.lr_sch.step()
            self.optimizer.zero_grad()

            try:
                imgs, annos, hms, whs, inds, offsets, reg_masks, names = next(
                    training_loader)
            except StopIteration:
                # Loader exhausted: bump the epoch so the DistributedSampler
                # reshuffles, then restart the iterator.
                epoch += 1
                self.training_loader.sampler.set_epoch(epoch)
                training_loader = iter(self.training_loader)
                imgs, annos, hms, whs, inds, offsets, reg_masks, names = next(
                    training_loader)

            # Move inputs and targets to this rank's GPU.
            imgs = imgs.cuda(self.cfg.Distributed.gpu_id)
            hms = hms.cuda(self.cfg.Distributed.gpu_id)
            whs = whs.cuda(self.cfg.Distributed.gpu_id)
            inds = inds.cuda(self.cfg.Distributed.gpu_id)
            offsets = offsets.cuda(self.cfg.Distributed.gpu_id)
            reg_masks = reg_masks.cuda(self.cfg.Distributed.gpu_id)

            targets = hms, whs, inds, offsets, reg_masks

            outs = self.model(imgs)

            hm_loss, wh_loss, off_loss = self.criterion(outs, targets)

            loss = hm_loss + (0.1 * wh_loss) + off_loss
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()
            total_hm_loss += hm_loss.item()
            total_wh_loss += wh_loss.item()
            total_off_loss += off_loss.item()

            if self.main_proc_flag:
                if step % self.cfg.Train.print_interval == self.cfg.Train.print_interval - 1:
                    # Loss
                    # Current learning rate (from the last param group).
                    lr = self.optimizer.param_groups[-1]['lr']
                    log_data = {
                        'scalar': {
                            'train/total_loss':
                            total_loss / self.cfg.Train.print_interval,
                            'train/hm_loss':
                            total_hm_loss / self.cfg.Train.print_interval,
                            'train/wh_loss':
                            total_wh_loss / self.cfg.Train.print_interval,
                            'train/off_loss':
                            total_off_loss / self.cfg.Train.print_interval,
                            'train/lr': lr
                        }
                    }

                    # Visualization
                    # CHW float tensor -> HWC uint8 image for drawing.
                    img = (denormalize(imgs[0].cpu()).permute(
                        1, 2, 0).numpy() * 255).astype(np.uint8)

                    hm, wh, offset = outs[0][0], outs[1][0], outs[2][0]
                    pred_bbox0 = self.transform_bbox(
                        hm,
                        wh,
                        offset,
                        scale_factor=self.cfg.Train.scale_factor).cpu()

                    hm, wh, offset = outs[0][1], outs[1][1], outs[2][1]
                    pred_bbox1 = self.transform_bbox(
                        hm,
                        wh,
                        offset,
                        scale_factor=self.cfg.Train.scale_factor).cpu()

                    # Do nms
                    pred_bbox0 = self._ext_nms(pred_bbox0)
                    pred_bbox1 = self._ext_nms(pred_bbox1)

                    pred0_on_img = visualize(img.copy(),
                                             pred_bbox0,
                                             xywh=False,
                                             with_score=True)
                    pred1_on_img = visualize(img.copy(),
                                             pred_bbox1,
                                             xywh=False,
                                             with_score=True)
                    gt_on_img = visualize(img, annos[0])
                    pred0_on_img = torch.from_numpy(pred0_on_img).permute(
                        2, 0, 1).unsqueeze(0).float() / 255.
                    pred1_on_img = torch.from_numpy(pred1_on_img).permute(
                        2, 0, 1).unsqueeze(0).float() / 255.
                    gt_on_img = torch.from_numpy(gt_on_img).permute(
                        2, 0, 1).unsqueeze(0).float() / 255.

                    log_data['imgs'] = {
                        'train': [pred0_on_img, pred1_on_img, gt_on_img]
                    }
                    logger.log(log_data, step)

                    total_loss = 0
                    total_hm_loss = 0
                    total_wh_loss = 0
                    total_off_loss = 0

                if step % self.cfg.Train.checkpoint_interval == self.cfg.Train.checkpoint_interval - 1 or \
                        step == self.cfg.Train.iter_num - 1:
                    self.save_ckp(self.model.module, step, logger.log_dir)
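The StopIteration handling above is the usual way to stream batches across epoch boundaries while keeping DistributedSampler shuffling correct: each wrap-around bumps the epoch and re-seeds the sampler. A minimal standalone sketch of the same pattern, with an illustrative generator name that is not part of the code above:

def infinite_batches(loader):
    """Yield batches forever, re-seeding the DistributedSampler each epoch."""
    epoch = 0
    loader.sampler.set_epoch(epoch)  # deterministic per-epoch reshuffle
    it = iter(loader)
    while True:
        try:
            yield next(it)
        except StopIteration:
            epoch += 1
            loader.sampler.set_epoch(epoch)  # new shuffle for the new epoch
            it = iter(loader)
            yield next(it)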
Example #2
    def training_process(self):
        if self.main_proc_flag:
            logger = Logger(self.cfg)

        self.model.train()

        total_loss = 0
        total_hm_loss = 0
        total_wh_loss = 0
        total_off_loss = 0
        total_s2_reg_loss = 0

        for step in range(self.cfg.Train.iter_num):
            self.lr_sch.step()
            self.optimizer.zero_grad()

            try:
                imgs, annos, gt_hms, gt_whs, gt_inds, gt_offsets, gt_reg_masks, names = \
                    self.training_loader.get_batch()
            except RuntimeError as e:
                # Skip this batch on CUDA OOM; re-raise anything else so real
                # errors are not silently swallowed.
                if 'out of memory' not in str(e):
                    raise
                print("WARNING: ran out of memory at step {}, skipping batch.".
                      format(step))
                continue

            outs = self.model(imgs)
            targets = gt_hms, gt_whs, gt_inds, gt_offsets, gt_reg_masks, annos
            hm_loss, wh_loss, offset_loss, s2_reg_loss = self.criterion(
                outs, targets)

            # Disable the second-stage regression loss for the first 2000
            # iterations, then enable it at full weight.
            s2_factor = 0 if step < 2000 else 1
            loss = hm_loss + 0.1 * wh_loss + offset_loss + s2_factor * s2_reg_loss
            loss.backward()
            self.optimizer.step()

            total_loss += float(loss)
            total_hm_loss += float(hm_loss)
            total_wh_loss += float(wh_loss)
            total_off_loss += float(offset_loss)
            total_s2_reg_loss += float(s2_reg_loss)

            if self.main_proc_flag:
                if step % self.cfg.Train.print_interval == self.cfg.Train.print_interval - 1:
                    # Loss
                    # Current learning rate (from the last param group).
                    lr = self.optimizer.param_groups[-1]['lr']
                    log_data = {
                        'scalar': {
                            'train/total_loss':
                            total_loss / self.cfg.Train.print_interval,
                            'train/hm_loss':
                            total_hm_loss / self.cfg.Train.print_interval,
                            'train/wh_loss':
                            total_wh_loss / self.cfg.Train.print_interval,
                            'train/off_loss':
                            total_off_loss / self.cfg.Train.print_interval,
                            'train/s2_reg_loss':
                            total_s2_reg_loss / self.cfg.Train.print_interval,
                            'train/lr':
                            lr
                        }
                    }

                    # Generate bboxs
                    s1_pred_bbox, s2_pred_bbox = self.generate_bbox(
                        outs, batch_idx=0)

                    # Visualization
                    # CHW float tensor -> HWC uint8 image for drawing.
                    img = (denormalize(imgs[0].cpu()).permute(
                        1, 2, 0).numpy() * 255).astype(np.uint8)
                    # Do nms
                    s2_pred_bbox = self._ext_nms(s2_pred_bbox)

                    s1_pred_on_img = visualize(img.copy(),
                                               s1_pred_bbox,
                                               xywh=True,
                                               with_score=True)
                    s2_pred_on_img = visualize(img.copy(),
                                               s2_pred_bbox,
                                               xywh=True,
                                               with_score=True)
                    gt_img = visualize(img.copy(), annos[0, :, :6], xywh=False)

                    s1_pred_on_img = torch.from_numpy(s1_pred_on_img).permute(
                        2, 0, 1).unsqueeze(0).float() / 255.
                    s2_pred_on_img = torch.from_numpy(s2_pred_on_img).permute(
                        2, 0, 1).unsqueeze(0).float() / 255.
                    gt_on_img = torch.from_numpy(gt_img).permute(
                        2, 0, 1).unsqueeze(0).float() / 255.
                    log_data['imgs'] = {
                        'Train': [s1_pred_on_img, s2_pred_on_img, gt_on_img]
                    }
                    logger.log(log_data, step)

                    total_loss = 0
                    total_hm_loss = 0
                    total_wh_loss = 0
                    total_off_loss = 0
                    total_s2_reg_loss = 0

                if step % self.cfg.Train.checkpoint_interval == self.cfg.Train.checkpoint_interval - 1 or \
                        step == self.cfg.Train.iter_num - 1:
                    self.save_ckp(self.model.module, step, logger.log_dir)
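Example #2 wraps the batch fetch so a CUDA out-of-memory error skips the batch instead of killing the run. The same idea as a reusable helper, a sketch only (the function name is hypothetical; non-OOM errors are deliberately re-raised):

def skip_on_oom(fetch, step):
    """Call fetch(); return its batch, or None if it hit a CUDA OOM."""
    try:
        return fetch()
    except RuntimeError as e:
        if 'out of memory' not in str(e):
            raise  # real bugs should surface, not be swallowed
        print("WARNING: ran out of memory at step {}, skipping.".format(step))
        return None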
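Example #2 also delays the second-stage regression term: it contributes nothing for the first 2000 iterations and then enters at full weight. The same schedule written as a small helper, mirroring the weights and threshold hard-coded above (the function name is illustrative):

def combined_loss(hm_loss, wh_loss, offset_loss, s2_reg_loss,
                  step, warmup_steps=2000, wh_weight=0.1):
    """Total detection loss with a step-gated second-stage term."""
    s2_factor = 0.0 if step < warmup_steps else 1.0
    return hm_loss + wh_weight * wh_loss + offset_loss + s2_factor * s2_reg_loss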
Example #3
    def training_process(self):
        logger = Logger(self.cfg, self.main_proc_flag)

        self.model.train()

        total_loss = 0
        total_cls_loss = 0
        total_loc_loss = 0

        epoch = 0
        self.training_loader.sampler.set_epoch(epoch)
        training_loader = iter(self.training_loader)

        for step in range(self.cfg.Train.iter_num):
            self.lr_sch.step()
            self.optimizer.zero_grad()

            try:
                imgs, annos, names = next(training_loader)
            except StopIteration:
                epoch += 1
                self.training_loader.sampler.set_epoch(epoch)
                training_loader = iter(self.training_loader)
                imgs, annos, names = next(training_loader)
            imgs = imgs.cuda(self.cfg.Distributed.gpu_id)
            annos = annos.cuda(self.cfg.Distributed.gpu_id)
            outs = self.model(imgs)
            cls_loss, loc_loss = self.criterion(outs, annos.clone())
            loss = cls_loss + loc_loss
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()
            total_cls_loss += cls_loss.item()
            total_loc_loss += loc_loss.item()

            if step % self.cfg.Train.print_interval == self.cfg.Train.print_interval - 1:
                # Loss
                log_data = {
                    'scalar': {
                        'train/total_loss':
                        total_loss / self.cfg.Train.print_interval,
                        'train/cls_loss':
                        total_cls_loss / self.cfg.Train.print_interval,
                        'train/loc_loss':
                        total_loc_loss / self.cfg.Train.print_interval,
                    }
                }

                # CHW float tensor -> HWC uint8 image for drawing.
                img = (
                    denormalize(imgs[0].cpu()).permute(1, 2, 0).numpy() *
                    255).astype(np.uint8)
                pred_bbox = self.transform_bbox(outs[1][0], outs[0][0]).cpu()
                # Draw on a copy so predictions don't bleed into the GT image;
                # annos was moved to the GPU above, so bring it back to draw.
                vis_img = visualize(img.copy(), pred_bbox)
                vis_gt_img = visualize(img, annos[0].cpu())
                vis_img = torch.from_numpy(vis_img).permute(
                    2, 0, 1).unsqueeze(0).float() / 255.
                vis_gt_img = torch.from_numpy(vis_gt_img).permute(
                    2, 0, 1).unsqueeze(0).float() / 255.

                log_data['imgs'] = {'train': [vis_img, vis_gt_img]}

                logger.log(log_data, step)

                total_loss = 0
                total_cls_loss = 0
                total_loc_loss = 0

            if self.main_proc_flag and (
                    step % self.cfg.Train.checkpoint_interval == self.cfg.Train.checkpoint_interval - 1 or\
                    step == self.cfg.Train.iter_num - 1):
                self.save_ckp(self.model.module, step, logger.log_dir)
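All three examples fire their logging and checkpointing with step % interval == interval - 1, which is true once every interval steps (at steps interval-1, 2*interval-1, ...), and they log interval means by dividing the running totals by print_interval before resetting them. A tiny helper that makes the trigger explicit (the name is illustrative):

def every(interval, step):
    """True once every `interval` steps: at steps interval-1, 2*interval-1, ..."""
    return step % interval == interval - 1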