Ejemplo n.º 1
0
    def __init__(self,
                 optimizer,
                 model,
                 training_dataloader,
                 validation_dataloader,
                 log_dir=False,
                 max_epoch=100,
                 resume=False,
                 persist_stride=1,
                 verbose=False):
        """Set up a training session: logging, model, anchors and resume state.

        Args:
            optimizer: Optimizer for the model; its state is restored from the
                checkpoint when ``resume`` is given.
            model: Network to train. It is cast to float, moved to the
                module-level ``device`` and initialised non-strictly from the
                weights at ``Config.VGG16_PRETRAINED_WEIGHTS``.
            training_dataloader: Stored as ``self.training_dataloader``.
            validation_dataloader: Stored as ``self.validation_dataloader``.
            log_dir: Directory that receives ``log.txt``. A falsy value selects
                a ``logs`` directory next to this source file. The directory
                is created if it does not exist.
            max_epoch: Stored as ``self.max_epoch``; presumably the last epoch
                the training loop runs -- TODO confirm.
            resume: Checkpoint identifier resolved by ``seek_model``. When
                truthy it is converted to ``str`` and the model/optimizer
                states and epoch counters are restored from that checkpoint.
            persist_stride: Stored as ``self.persist_stride``; presumably the
                epoch interval between checkpoint saves -- TODO confirm.
            verbose: Stored as ``self.verbose``.
        """
        self.start_epoch = 1
        self.current_epoch = 1

        self.verbose = verbose
        self.max_epoch = max_epoch
        self.persist_stride = persist_stride

        # initialize log. The default directory must be resolved and created
        # *before* logging is configured: os.path.join(False, ...) raises
        # TypeError, and the FileHandler created by basicConfig cannot open a
        # file inside a directory that does not exist yet.
        self.log_dir = log_dir or os.path.join(
            os.path.dirname(os.path.realpath(__file__)), 'logs')
        os.makedirs(self.log_dir, exist_ok=True)
        logging.basicConfig(filename=os.path.join(self.log_dir, 'log.txt'),
                            level=logging.DEBUG)

        # initialize model
        self.optimizer = optimizer
        self.model = model.float().to(device)
        # strict=False: keys missing from, or unexpected in, the pretrained
        # state dict are ignored, so only matching parameters are initialised.
        self.model.load_state_dict(
            model_zoo.load_url(Config.VGG16_PRETRAINED_WEIGHTS), strict=False)
        self.resume = str(resume) if resume else False

        self.training_dataloader = training_dataloader
        self.validation_dataloader = validation_dataloader

        # initialize anchors: stack every generated anchor into one array,
        # plus a copy converted by change_coordinate.
        self.anchors = np.vstack([
            np.array(anchor)
            for anchor in generate_anchors(Config.ANCHOR_STRIDE,
                                           Config.ANCHOR_SIZE,
                                           Config.IMAGE_SIZE)
        ])
        self.anchors_coord_changed = change_coordinate(self.anchors)
        self.len_anchors = len(self.anchors)

        # resume from some model
        if self.resume:
            state_file = seek_model(self.resume)

            print("loading checkpoint {}".format(state_file))
            # map_location remaps tensors saved on another device (e.g. a CUDA
            # checkpoint on a CPU-only host) onto the active device.
            checkpoint = torch.load(state_file, map_location=device)
            self.start_epoch = self.current_epoch = checkpoint['epoch']
            self.model.load_state_dict(checkpoint['state_dict'], strict=True)
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            print("loaded checkpoint {} (epoch {})".format(
                state_file, self.current_epoch))
Ejemplo n.º 2
0
 def __init__(self,
              model,
              image_size=Config.IMAGE_SIZE,
              threshold=Config.PREDICTION_THRESHOLD):
     """Load a trained ``Net`` from a checkpoint for inference.

     Args:
         model: Identifier resolved to a checkpoint file by ``seek_model``.
             The checkpoint's ``'state_dict'`` entry is loaded strictly.
         image_size: Stored as ``self.image_size``.
         threshold: Stored as ``self.threshold``; presumably the score
             cut-off applied to predictions -- TODO confirm.
     """
     # map_location remaps tensors saved on another device (e.g. a CUDA
     # checkpoint on a CPU-only machine) onto the active device.
     checkpoint = torch.load(seek_model(model), map_location=device)
     self.model = Net().to(device)
     self.model.load_state_dict(checkpoint['state_dict'], strict=True)
     self.threshold = threshold
     self.image_size = image_size
Ejemplo n.º 3
0
    def __init__(self, model, image_size=Config.IMAGE_SIZE, threshold=Config.PREDICTION_THRESHOLD):
        """Prepare a model for inference.

        Args:
            model: Either a checkpoint identifier (``str``, resolved by
                ``seek_model``) whose ``'state_dict'`` is loaded strictly into
                a fresh ``Net`` on ``device``, or an already-built model that
                is used as-is.
            image_size: Stored as ``self.image_size``.
            threshold: Stored as ``self.threshold``; presumably the score
                cut-off applied to predictions -- TODO confirm.
        """
        # isinstance (rather than an exact type comparison) also accepts str
        # subclasses as checkpoint identifiers.
        if isinstance(model, str):
            # map_location remaps tensors saved on another device onto the
            # active device (e.g. a CUDA checkpoint on a CPU-only machine).
            checkpoint = torch.load(seek_model(model), map_location=device)
            self.model = Net().to(device)
            self.model.load_state_dict(checkpoint['state_dict'], strict=True)
        else:
            # A pre-built model is not moved to ``device`` here.
            self.model = model
        # Put the model in evaluation mode (affects dropout / batch-norm).
        self.model.eval()
        self.threshold = threshold
        self.image_size = image_size

        # NOTE(review): not used in the code shown here; confirm whether the
        # remainder of this method (and a Config.IMAGE_SIZE entry) is missing.
        anchor_configs = (
            Config.ANCHOR_STRIDE,
            Config.ANCHOR_SIZE,
        )
Ejemplo n.º 4
0
    def __init__(self,
                 model,
                 image_size=Config.IMAGE_SIZE,
                 threshold=Config.PREDICTION_THRESHOLD):
        """Load a trained ``Net`` for inference and precompute its anchors.

        Args:
            model: Identifier resolved to a checkpoint file by ``seek_model``.
                The checkpoint's ``'state_dict'`` entry is loaded strictly.
            image_size: Stored as ``self.image_size``.
            threshold: Stored as ``self.threshold``; presumably the score
                cut-off applied to predictions -- TODO confirm.
        """
        # map_location remaps tensors saved on another device (e.g. a CUDA
        # checkpoint on a CPU-only machine) onto the active device.
        checkpoint = torch.load(seek_model(model), map_location=device)
        self.model = Net().to(device)
        self.model.load_state_dict(checkpoint['state_dict'], strict=True)
        self.threshold = threshold
        self.image_size = image_size

        anchor_configs = (Config.ANCHOR_STRIDE, Config.ANCHOR_SIZE,
                          Config.IMAGE_SIZE)
        # Stack every generated anchor into one array before converting it.
        stacked_anchors = np.vstack(
            [np.array(anchor) for anchor in generate_anchors(*anchor_configs)])
        # float32 tensor of the converted anchors; it is created on the CPU and
        # is not moved to ``device`` here.
        self.anchors = torch.tensor(change_coordinate(stacked_anchors)).float()
Ejemplo n.º 5
0
 def __init__(self, model, image_size=Config.IMAGE_SIZE, keep=200):
     """Load a trained ``Net`` from a checkpoint for inference.

     Args:
         model: Identifier resolved to a checkpoint file by ``seek_model``.
             The checkpoint's ``'state_dict'`` entry is loaded strictly.
         image_size: Stored as ``self.image_size``.
         keep: Stored as ``self.keep``; presumably the maximum number of
             detections retained -- TODO confirm.
     """
     # map_location remaps tensors saved on another device (e.g. a CUDA
     # checkpoint on a CPU-only machine) onto the active device.
     checkpoint = torch.load(seek_model(model), map_location=device)
     self.model = Net().to(device)
     self.model.load_state_dict(checkpoint['state_dict'], strict=True)
     self.keep = keep
     self.image_size = image_size