예제 #1
0
파일: models.py 프로젝트: peria1/trainer
    def __init__(self, problem):
        """Build a YOLACT network and its MultiBox loss for training.

        The YOLACT source checkout at ``'../yolact/'`` is temporarily
        inserted at ``sys.path[1]`` unless ``sys.path[1]`` already mentions
        ``'yolact'``. With that path in place, YOLACT's modules are imported,
        ``Yolact`` is instantiated, ``net.train()`` is called, and backbone
        weights are loaded from ``'../yolact/weights/'`` plus
        ``data.cfg.backbone.path``. A ``MultiBoxLoss`` is then built from the
        thresholds and ratio in ``data.cfg``.

        Args:
            problem: Accepted for interface compatibility; not used here.

        Side effects:
            Sets ``self.net``, ``self.criterion`` and ``self.D`` (YOLACT's
            ``data`` module), plus these helpers: ``FastBaseTransform``,
            ``cv2``, ``plt``, ``postprocess``, ``undo_image_transformation``,
            ``timer`` and ``color_cache``. If ``sys.path`` was extended, it
            is restored in place before returning, even when an import or
            the network construction raises.
        """
        import sys
        from collections import defaultdict

        super().__init__()

        import cv2
        self.cv2 = cv2

        import matplotlib.pyplot as plt
        self.plt = plt

        # Per-instance cache whose missing keys default to a fresh empty dict.
        self.color_cache = defaultdict(dict)

        # Guard the length check so a short sys.path cannot raise IndexError.
        # NOTE(review): '../yolact/' is resolved relative to the current
        # working directory, not to this file; confirm callers run from the
        # expected directory.
        saved_path = None
        if len(sys.path) < 2 or 'yolact' not in sys.path[1]:
            saved_path = list(sys.path)
            sys.path.insert(1, '../yolact/')

        try:
            # These modules live in the YOLACT checkout, so import them only
            # after the checkout is reachable on sys.path.
            from utils.augmentations import FastBaseTransform
            self.FastBaseTransform = FastBaseTransform

            from layers.output_utils import postprocess, undo_image_transformation
            self.postprocess = postprocess
            self.undo_image_transformation = undo_image_transformation

            from utils import timer
            self.timer = timer

            from yolact import Yolact
            from train import MultiBoxLoss
            import data as D
            self.D = D

            net = Yolact()
            net.train()
            net.init_weights(backbone_path='../yolact/weights/' +
                             D.cfg.backbone.path)

            criterion = MultiBoxLoss(num_classes=D.cfg.num_classes,
                                     pos_threshold=D.cfg.positive_iou_threshold,
                                     neg_threshold=D.cfg.negative_iou_threshold,
                                     negpos_ratio=D.cfg.ohem_negpos_ratio)

            self.net = net
            self.criterion = criterion
        finally:
            # Restore in place (not by rebinding) so any other holder of the
            # sys.path list object also sees the original contents. Use
            # 'is not None' so an originally empty sys.path is restored too.
            if saved_path is not None:
                sys.path[:] = saved_path
예제 #2
0
            img, gt, gt_masks, h, w, num_crowd = dataset.pull_item(i_img)
            batch = img.unsqueeze(0).cuda()
            print(type(batch), batch.size())
            preds = net(batch)
            img_numpy = local_prep_display(preds, img, h, w)
            plt.imshow(img_numpy)
            plt.pause(0.5)

            #-----------------------
            try:
                datum = next(data_loader_iterator)
            except StopIteration:
                break

            images, targets, masks, num_crowds = local_prepare_data(datum)
            net.train()
            predsT = net(images[0])
            losses = criterion(net, predsT, targets[0], masks[0],
                               num_crowds[0])
            loss = sum([losses[k] for k in losses])
            print(loss)
            # no_inf_mean removes some components from the loss, so make sure to backward through all of it
            # all_loss = sum([v.mean() for v in losses.values()])

            # Backprop
            loss.backward(
            )  # Do this to free up vram even if loss is not finite
#-----------------------

#@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
#    if mode == 'train':