示例#1
0
    def load_model(self):
        """Create the SST network and restore the ``config['resume']`` weights.

        The checkpoint is always read with ``map_location='cpu'``; no device
        transfer happens here.  The network ends up in ``self.sst`` and is
        switched to evaluation mode.
        """
        self.sst = network = build_sst('test', 900)
        checkpoint = torch.load(config['resume'], map_location='cpu')
        network.load_state_dict(checkpoint)
        network.eval()
示例#2
0
    def init(img1_path, img2_path, model_path, cuda):
        """Initialise the shared ``CompareTools`` state: two images and the SST model.

        Reads both images with OpenCV and converts each with
        ``CompareTools.convert_image``.  It concatenates the raw images along
        axis 0 into ``CompareTools.img`` and keeps an untouched copy in
        ``CompareTools.img_org``.  It then builds the SST network in 'test'
        mode with the weights stored at *model_path*.

        Args:
            img1_path: path of the first image.
            img2_path: path of the second image.
            model_path: path of the saved state dict for the SST network.
            cuda: when truthy, enable cudnn benchmarking and move the network
                to the GPU; otherwise the checkpoint is mapped onto the CPU.

        Raises:
            ValueError: if any of the three paths does not exist, or if either
                image cannot be decoded by ``cv2.imread``.
        """
        print('start init >>>>>>>>>>>>>>')
        if not os.path.exists(img1_path) or not os.path.exists(
                img2_path) or not os.path.exists(model_path):
            raise ValueError("input parameter not right")

        CompareTools.cuda = cuda

        print('load image...')
        # load image
        CompareTools.img1 = cv2.imread(img1_path)
        CompareTools.img2 = cv2.imread(img2_path)
        # cv2.imread reports an unreadable/undecodable file by returning None
        # instead of raising; fail here with a clear message rather than deep
        # inside convert_image / np.concatenate.
        for path, image in ((img1_path, CompareTools.img1),
                            (img2_path, CompareTools.img2)):
            if image is None:
                raise ValueError("cannot read image: {}".format(path))
        CompareTools.img1_convert = CompareTools.convert_image(
            CompareTools.img1, CompareTools.cuda)
        CompareTools.img2_convert = CompareTools.convert_image(
            CompareTools.img2, CompareTools.cuda)
        CompareTools.img = np.concatenate(
            [CompareTools.img1, CompareTools.img2], axis=0)
        CompareTools.img_org = np.copy(CompareTools.img)

        print('load model...')
        # load net
        CompareTools.sst = build_sst('test', 900, CompareTools.cuda)
        if cuda:
            cudnn.benchmark = True
            CompareTools.sst.load_state_dict(torch.load(model_path))
            CompareTools.sst = CompareTools.sst.cuda()
        else:
            # map_location lets a checkpoint that was saved from GPU tensors
            # be loaded on a host without CUDA.
            CompareTools.sst.load_state_dict(
                torch.load(model_path, map_location='cpu'))

        print('finish init <<<<<<<<<<<<')
示例#3
0
 def load_model(self):
     """Build the SST network, restore ``config['resume']`` and switch to eval mode.

     If ``self.cuda`` is falsy, the checkpoint is mapped onto the CPU.
     Otherwise cudnn benchmarking is enabled, the checkpoint is loaded as
     saved, and the network is moved to the GPU.
     """
     self.sst = build_sst('test', 900)
     if not self.cuda:
         self.sst.load_state_dict(
             torch.load(config['resume'], map_location='cpu'))
     else:
         cudnn.benchmark = True
         self.sst.load_state_dict(torch.load(config['resume']))
         self.sst = self.sst.cuda()
     self.sst.eval()
示例#4
0
 def load_model(self):
     """Build the SST network, restore ``config['resume']`` and freeze its weights.

     If ``self.cuda`` is falsy, the checkpoint is mapped onto the CPU.
     Otherwise cudnn benchmarking is enabled, the checkpoint is loaded as
     saved, and the network is moved to the GPU.  Afterwards ``requires_grad``
     is cleared on every parameter.

     NOTE(review): unlike the sibling loaders, this never calls ``eval()``, so
     the module's train/eval mode is whatever ``build_sst`` left it in.
     Confirm that this is intended.
     """
     self.sst = build_sst('test', 900)
     if not self.cuda:
         self.sst.load_state_dict(torch.load(config['resume'], map_location='cpu'))
     else:
         cudnn.benchmark = True
         self.sst.load_state_dict(torch.load(config['resume']))
         self.sst = self.sst.cuda()
     # Freeze the restored weights so autograd does not track them.
     for weight in self.sst.parameters():
         weight.requires_grad = False
示例#5
0
    def __init__(self):
        """Set up the tracker state and load the SST network for inference.

        Settings come from ``TrackerConfig``: the maximum number of track
        nodes to draw, whether to use CUDA, and the path of the SST weights.
        The network is put into evaluation mode after loading.
        """
        self.tracks = list()
        self.max_drawing_track = TrackerConfig.max_draw_track_node
        self.cuda = TrackerConfig.cuda
        self.recorder = FeatureRecorder()
        self.frame_index = 0

        # load the model
        self.sst = build_sst('test', 900)
        # Use one checkpoint for both devices; the CPU branch previously read
        # config['resume'] while the CUDA branch read TrackerConfig, so the
        # weights silently depended on the hardware.
        model_path = TrackerConfig.sst_model_path
        if self.cuda:
            cudnn.benchmark = True
            self.sst.load_state_dict(torch.load(model_path))
            self.sst = self.sst.cuda()
        else:
            self.sst.load_state_dict(
                torch.load(model_path, map_location='cpu'))
        self.sst.eval()
示例#6
0
from torch.autograd import Variable
import torch.utils.data as data
import numpy as np
import argparse

from data.ua import UATrainDataset
from config.config import config
from layer.sst import build_sst
from layer.sst_loss import SSTLoss
from utils.augmentations import SSJEvalAugment, collate_fn
import time
from utils.operation import show_circle, show_batch_circle_image
import cv2

# Build the SST network in 'test' mode and restore the config['resume']
# weights: moved to the GPU when config['cuda'] is set, otherwise the
# checkpoint is mapped onto the CPU.  The network is then put in eval mode.
sst = build_sst('test', 900)
if not config['cuda']:
    sst.load_state_dict(torch.load(config['resume'], map_location='cpu'))
else:
    cudnn.benchmark = True
    sst.load_state_dict(torch.load(config['resume']))
    sst = sst.cuda()
sst.eval()

# Dataset over the UA image / detection / ignore roots from config, with the
# evaluation-time augmentation parameterised by sst_dim and mean_pixel.
dataset = UATrainDataset(
    config['ua_image_root'],
    config['ua_detection_root'],
    config['ua_ignore_root'],
    SSJEvalAugment(config['sst_dim'], config['mean_pixel']),
)

data_loader = data.DataLoader(dataset,
                              config['batch_size'],
示例#7
0
    save_weights_iteration = config['save_weight_every_epoch_num'] * config['epoch_size']
else:
    stepvalues = (90000, 95000)
    save_weights_iteration = 5000

# Values copied from the parsed command-line arguments.
gamma, momentum = args.gamma, args.momentum


if args.tensorboard:
    # Imported here so tensorboardX is only needed when logging is enabled.
    from tensorboardX import SummaryWriter
    if not os.path.exists(config['log_folder']):
        # makedirs (not mkdir) also creates any missing parent directories,
        # which os.mkdir would fail on.
        os.makedirs(config['log_folder'])
    writer = SummaryWriter(log_dir=config['log_folder'])

# sst_net is the bare SST module; net is the same module, wrapped in
# DataParallel (with cudnn benchmarking enabled) when args.cuda is set.
sst_net = build_sst('train')
if args.cuda:
    net = torch.nn.DataParallel(sst_net)
    cudnn.benchmark = True
else:
    net = sst_net

# Either resume from the checkpoint at args.resume (whole SST network), or
# load the weights at args.basenet into the vgg sub-module only.
if args.resume:
    print('Resuming training, loading {}...'.format(args.resume))
    sst_net.load_weights(args.resume)
else:
    # assumes args.basenet is a state dict matching sst_net.vgg — TODO confirm
    vgg_weights = torch.load(args.basenet)

    print('Loading the base network...')
    sst_net.vgg.load_state_dict(vgg_weights)