Пример #1
0
class Test(object):
    """Run a trained DSSD detector on single images and visualise the result.

    The class list is hard-coded to 21 entries (background plus the 20 Pascal
    VOC categories); it must match the checkpoint being loaded.
    """

    def __init__(self, args, work_dir):
        """Build the network and load its weights from a checkpoint.

        Args:
            args: parsed command-line namespace; ``args.net`` selects the
                network and ``args.device`` is where the model is placed.
            work_dir: checkpoint path (or file-like) passed to ``torch.load``;
                it must contain a ``'state_dict'`` entry.

        Raises:
            NotImplementedError: if ``args.net`` is not ``'resnet'``.
        """
        self.args = args
        self.cfg = cfg  # global config object (not a parameter)
        self.time = Timer()
        self.input_size = cfg.input_size
        self.num_classes = 21
        self.class_name = ('__background__',  # always index 0
                           'aeroplane', 'bicycle', 'bird', 'boat',
                           'bottle', 'bus', 'car', 'cat', 'chair',
                           'cow', 'diningtable', 'dog', 'horse',
                           'motorbike', 'person', 'pottedplant',
                           'sheep', 'sofa', 'train', 'tvmonitor')
        # Define Network
        if args.net == 'resnet':
            model = DSSD(args=args,
                         cfg=cfg,
                         net=args.net,
                         output_stride=32,
                         num_classes=self.num_classes,
                         img_size=self.input_size,
                         pretrained=True)
        else:
            raise NotImplementedError(
                "network '{}' is not supported".format(args.net))
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        checkpoint = torch.load(work_dir, map_location=self.args.device)
        model.load_state_dict(checkpoint['state_dict'])
        self.model = model.to(self.args.device)

    def test(self, img_path):
        """Detect objects in one image, print the detections and plot them.

        Boxes are drawn on the letterboxed network input (``input_size`` x
        ``input_size``), not on the original image.

        Args:
            img_path: path of the image, read with ``cv2.imread`` (BGR).

        Raises:
            FileNotFoundError: if the image cannot be read.
        """
        self.time.batch()
        self.model.eval()
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError("could not read image '{}'".format(img_path))
        # NOTE: the ratio/left/top values returned here would be needed to map
        # boxes back to the original image (x / ratio - left, y / ratio - top).
        letterboxed, _, _, _ = self.letterbox(image, self.input_size)

        # HWC BGR -> CHW RGB float32, then per-channel normalisation.
        img = letterboxed[:, :, ::-1].transpose(2, 0, 1)
        img = np.ascontiguousarray(img, dtype=np.float32)
        img = self.normalize(img)

        # Add the batch dimension and move to the model's device.
        inputs = torch.from_numpy(img).unsqueeze(0).to(self.args.device)

        # Inference only: no_grad avoids building an autograd graph and keeps
        # ``output`` convertible to NumPy for plotting.
        with torch.no_grad():
            output = self.model(inputs)

        output = output.squeeze(0).cpu()
        # Keep rows with a positive value in column 4 (presumably the
        # detection score -- confirm against DSSD's output format).
        output = output[output[:, 4].gt(0)]
        # Assumes box coordinates are normalised to [0, 1]; scale to pixels.
        output[:, :4] *= self.input_size
        for ii, name in enumerate(self.class_name):
            print(ii, ':', name, end='. ')
        print(output)
        self.show_image(letterboxed, output)

        print("Time cost: %7.3gs" % self.time.batch())

    def letterbox(self, img, height=512, color=(127.5, 127.5, 127.5),
                  interpolation=None):
        """Pad an image to a square with ``color`` and resize it to ``height``.

        Args:
            img: image array whose first two dimensions are (height, width).
            height: side length of the returned square image.
            color: constant border value used for the padding.
            interpolation: OpenCV interpolation flag for the resize; defaults
                to ``cv2.INTER_LINEAR``.

        Returns:
            ``(img, ratio, left, top)``: the padded and resized image, the
            resize scale ``height / max(H, W)``, and the left/top padding in
            pixels of the image before resizing.
        """
        shape = img.shape[:2]  # (height, width)
        side = max(shape)
        ratio = float(height) / side  # new size / old size
        dw = (side - shape[1]) / 2  # width padding per side
        dh = (side - shape[0]) / 2  # height padding per side
        # The +/-0.1 splits an odd total padding as floor/ceil between sides.
        left, right = round(dw - 0.1), round(dw + 0.1)
        top, bottom = round(dh - 0.1), round(dh + 0.1)

        # Padded square, then resized (the border is kept in the output).
        img = cv2.copyMakeBorder(img, top, bottom, left, right,
                                 cv2.BORDER_CONSTANT, value=color)
        if interpolation is None:
            # Fixed flag so inference is repeatable; a randomised
            # interpolation is a training-time augmentation.
            interpolation = cv2.INTER_LINEAR
        img = cv2.resize(img, (height, height), interpolation=interpolation)

        return img, ratio, left, top

    def normalize(self, img, mean=(0.485, 0.456, 0.406),
                  std=(0.229, 0.224, 0.225)):
        """Scale pixel values by 1/255 and apply ``(x - mean) / std`` per channel.

        Args:
            img: channel-first array (C, H, W); values are divided by 255.
            mean: per-channel means (length C).
            std: per-channel standard deviations (length C).

        Returns:
            float32 array with the same shape as ``img``.
        """
        # Reshape to (C, 1, 1) so they broadcast over the spatial dims.
        mean = np.asarray(mean, dtype=np.float64)[:, np.newaxis, np.newaxis]
        std = np.asarray(std, dtype=np.float64)[:, np.newaxis, np.newaxis]
        img = img / 255.0
        return ((img - mean) / std).astype(np.float32)

    def show_image(self, img, labels):
        """Display a BGR image with each row of ``labels`` drawn as a box.

        Args:
            img: BGR image array (H, W, 3); channels are flipped for matplotlib.
            labels: 2-D array/tensor whose columns 0-3 are x1, y1, x2, y2 in
                ``img`` pixel coordinates.
        """
        import matplotlib.pyplot as plt
        plt.figure(figsize=(10, 10))
        plt.subplot(1, 1, 1).imshow(img[:, :, ::-1])
        # Trace each outline (x1,y1)->(x2,y1)->(x2,y2)->(x1,y2)->(x1,y1).
        plt.plot(labels[:, [0, 2, 2, 0, 0]].T, labels[:, [1, 1, 3, 3, 1]].T, '-')
        plt.show()
Пример #2
0
class Trainer(object):
    def __init__(self, args):
        """Set up data loading, model, optimizer, loss, lr schedule and logging.

        Args:
            args: parsed command-line namespace. Fields read here include
                ``net``, ``batch_size``, ``workers``, ``use_balanced_weights``,
                ``dataset``, ``lr_scheduler``, ``epochs``, ``resume``, ``ft``,
                ``device``, ``ng``, ``use_multi_gpu``, ``gpu_ids`` and
                ``visdom``. ``args.start_epoch`` is overwritten when resuming.

        Raises:
            NotImplementedError: if ``args.net`` is not ``'resnet'``.
            RuntimeError: if ``args.resume`` is set but is not an existing file.
        """
        self.args = args
        self.cfg = cfg  # global config object (not a parameter)
        self.time = Timer()

        # Define Saver
        self.saver = Saver(args, cfg)
        self.saver.save_experiment_config()

        # Define Dataloader
        train_dataset = Detection_Dataset(args, cfg, cfg.train_split, 'train')
        self.num_classes = train_dataset.num_classes
        self.input_size = train_dataset.input_size
        self.train_loader = data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            num_workers=self.args.workers,
            shuffle=True,
            pin_memory=True,
            drop_last=True,
            collate_fn=train_dataset.collate_fn)
        self.num_batch = len(self.train_loader)

        # Define Network
        if args.net == 'resnet':
            model = DSSD(args=args,
                         cfg=cfg,
                         net=args.net,
                         output_stride=32,
                         num_classes=self.num_classes,
                         img_size=self.input_size,
                         pretrained=True)
        else:
            raise NotImplementedError(
                "network '{}' is not supported".format(args.net))
        # Place the model on its device before the optimizer exists or any
        # optimizer state is loaded: Optimizer.load_state_dict casts state
        # tensors to each parameter's *current* device.
        model = model.to(self.args.device)

        # Two parameter groups: base lr and 10x lr (split defined by DSSD).
        train_params = [
            {'params': model.get_1x_lr_params(), 'lr': cfg.lr},
            {'params': model.get_10x_lr_params(), 'lr': cfg.lr * 10},
        ]

        # Define Optimizer.
        # NOTE(review): momentum used to be cfg.lr (a learning rate). Use
        # cfg.momentum when defined; the 0.9 fallback is an assumption.
        optimizer = torch.optim.SGD(train_params,
                                    momentum=getattr(cfg, 'momentum', 0.9),
                                    weight_decay=cfg.weight_decay,
                                    nesterov=False)

        # Define Criterion, optionally with class-balanced weights.
        if args.use_balanced_weights:
            # The path must name a file; joining the dataset root alone made
            # isfile() always False, so weights were recomputed on every run.
            # NOTE(review): file name follows the pytorch-deeplab convention;
            # confirm it matches what calculate_weigths_labels saves.
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                # Use the dataset's class count so the weight vector matches
                # the criterion built below.
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.num_classes)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = MultiBoxLoss(args, cfg, self.num_classes, weight)
        self.model, self.optimizer = model, optimizer

        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, cfg.lr, args.epochs,
                                      len(self.train_loader))

        # Resuming Checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume, map_location=self.args.device)
            args.start_epoch = checkpoint['epoch']
            self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Multi-GPU wrapping happens after weights are loaded into the bare model.
        if args.ng > 1 and args.use_multi_gpu:
            print("Using multiple gpu")
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=args.gpu_ids)
        # Clear start epoch if fine-tuning
        self.start_epoch = 0 if args.ft else args.start_epoch

        # Visdom
        if args.visdom:
            vis = visdom.Visdom()
            vis_legend = ['Loss_local', 'Loss_confidence', 'mAP', 'mF1']
            self.epoch_plot = create_vis_plot(vis, 'Epoch', 'Loss',
                                              'train loss', vis_legend[0:2])
            self.batch_plot = create_vis_plot(vis, 'Batch', 'Loss',
                                              'batch loss', vis_legend[0:2])
            self.val_plot = create_vis_plot(vis, 'Epoch', 'result', 'val loss',
                                            vis_legend[2:4])
            self.vis = vis
            self.vis_legend = vis_legend
        model_info(self.model)

    def training(self, epoch):
        """Train the model for one epoch and log per-batch and epoch losses.

        A batch is skipped when any of its images has no targets. Per-batch
        info is written to the saver log every batch and printed every 50
        batches; epoch averages over the batches actually trained on are
        printed and logged at the end.

        Args:
            epoch: current epoch index, passed to the lr scheduler and logs.

        Raises:
            AssertionError: if the total loss is NaN.
        """
        self.time.epoch()
        self.model.train()
        ave_loss_l = 0.
        ave_loss_c = 0.
        num_used = 0  # batches trained on; skipped batches are not counted
        for ii, (images, targets, _, _) in enumerate(self.train_loader):
            # continue if any image in the batch has no target.
            if any(len(ann) == 0 for ann in targets):
                continue
            self.time.batch()
            images = images.to(self.args.device)
            targets = [ann.to(self.args.device) for ann in targets]
            self.scheduler(self.optimizer, ii, epoch, self.best_pred)
            self.optimizer.zero_grad()

            output = self.model(images)

            loss_l, loss_c = self.criterion(output, targets)
            loss = loss_l + loss_c
            assert not torch.isnan(
                loss), 'WARNING: nan loss detected, ending training'
            loss.backward()
            self.optimizer.step()

            # Plain floats: accumulating the loss tensors themselves would
            # chain every batch's autograd history into the running averages.
            loss_l_val, loss_c_val = float(loss_l), float(loss_c)
            num_used += 1
            ave_loss_c += (loss_c_val - ave_loss_c) / num_used
            ave_loss_l += (loss_l_val - ave_loss_l) / num_used

            # visdom
            if self.args.visdom:
                update_vis_plot(self.vis, ii, [loss_l, loss_c],
                                self.batch_plot, 'append')

            show_info = '[mode: train, ' +\
                'Epoch: [%d][%d/%d], ' % (epoch, ii, self.num_batch) +\
                'lr: %5.4g, ' % self.optimizer.param_groups[0]['lr'] +\
                'loc_loss: %5.3g, conf_loss: %5.3g, time: %5.2gs]' %\
                (loss_l_val, loss_c_val, self.time.batch())
            if (ii + 1) % 50 == 0:
                print(show_info)

            # Save log info
            self.saver.save_log(show_info)

        epoch_show_info = '[mode: train, ' +\
            'Epoch: [%d], ' % epoch +\
            'lr: %5.4g, ' % self.optimizer.param_groups[0]['lr'] +\
            'average_loc_loss: %5.3g, ' % ave_loss_l +\
            'average_conf_loss: %5.3g, ' % ave_loss_c +\
            'time: %5.2gm]' % self.time.epoch()
        print(epoch_show_info)

        # Save log info
        self.saver.save_log(epoch_show_info)

        # visdom
        if self.args.visdom:
            update_vis_plot(self.vis, epoch, [ave_loss_l, ave_loss_c],
                            self.epoch_plot, 'append')