예제 #1
0
파일: m2det.py 프로젝트: vuanh96/Thesis
    def init_model(self, base_model_path):
        """Load backbone weights where applicable, then initialise the other sub-modules.

        Args:
            base_model_path: Checkpoint location handed to ``torch.load``. It is
                only read when ``self.backbone == 'vgg16'`` and the value is a
                ``str``; in every other case no checkpoint is loaded here.
        """
        if self.backbone == 'vgg16':
            if isinstance(base_model_path, str):
                checkpoint = torch.load(base_model_path)
                print_info('Loading base network...')
                self.base.load_state_dict(checkpoint)
        elif 'res' in self.backbone:
            # Nothing to load: pretrained seresnet models are initially
            # loaded when they are defined.
            pass

        def weights_init(m):
            # Entries are selected purely by their state-dict key: the last
            # dotted component decides weight vs. bias, and the substrings
            # 'conv' / 'bn' anywhere in the key decide how a weight is filled.
            for key, tensor in m.state_dict().items():
                suffix = key.rsplit('.', 1)[-1]
                if suffix == 'bias':
                    tensor[...] = 0
                    continue
                if suffix != 'weight':
                    continue
                if 'conv' in key:
                    init.kaiming_normal_(tensor, mode='fan_out')
                if 'bn' in key:
                    tensor[...] = 1

        print_info('Initializing weights for [tums, reduce, up_reduce, leach, loc, conf]...')
        # The per-level modules are stored as attributes unet1 .. unet<num_levels>.
        for level in range(1, self.num_levels + 1):
            getattr(self, 'unet{}'.format(level)).apply(weights_init)
        # Attributes are looked up one at a time, in the same order as before.
        for attr_name in ('reduce', 'up_reduce', 'leach', 'loc', 'conf'):
            getattr(self, attr_name).apply(weights_init)
예제 #2
0
    def init_model(self, pretained_model):
        """Load the feature-extractor weights and initialise the remaining sub-modules.

        Args:
            pretained_model: Checkpoint location handed to ``torch.load``; the
                loaded state dict is applied to ``self.features``.
        """
        self.features.load_state_dict(torch.load(pretained_model))
        print_info('Loading base network...')

        def weights_init(m):
            """Initialise one module; meant to be passed to ``Module.apply``.

            ``nn.Conv2d``: Xavier-uniform weight, and zero bias when a bias exists.
            ``nn.BatchNorm2d``: weight set to 1 and bias set to 0.
            Any other module type is left untouched.
            """
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
                # A convolution built without a bias has no bias to reset.
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        print_info('Initializing weights for [extras, resblock,multibox]...')
        # Attributes are looked up one at a time, in the same order as before.
        for attr_name in ('extras', 'resblock', 'loc', 'conf'):
            getattr(self, attr_name).apply(weights_init)
예제 #3
0
 def __init__(self, phase, size, config=None):
     """Set up an STDN instance.

     Stores ``phase`` and ``size`` on the instance, hands ``config`` to
     ``init_params``, prints a construction banner and finally calls
     ``construct_modules``.
     """
     super(STDN, self).__init__()
     self.phase, self.size = phase, size
     self.init_params(config)
     banner_style = ['yellow', 'bold']
     print_info('===> Constructing STDN model', banner_style)
     self.construct_modules()
예제 #4
0
파일: m2det.py 프로젝트: vuanh96/Thesis
 def __init__(self, phase, size, config=None):
     """M2Det: Multi-level Multi-scale single-shot object Detector.

     Stores ``phase`` and ``size`` on the instance, hands ``config`` to
     ``init_params``, prints a construction banner and finally calls
     ``construct_modules``.
     """
     super(M2Det, self).__init__()
     self.phase, self.size = phase, size
     self.init_params(config)
     banner_style = ['yellow', 'bold']
     print_info('===> Constructing M2Det model', banner_style)
     self.construct_modules()
예제 #5
0
파일: m2det.py 프로젝트: vuanh96/Thesis
 def load_weights(self, base_file):
     """Load a saved state dict from ``base_file`` into this module.

     Only files whose extension is ``.pkl`` or ``.pth`` are loaded; for any
     other extension a message is printed and nothing is loaded.

     Args:
         base_file: Path of the checkpoint file, passed to ``torch.load``.
     """
     _, ext = os.path.splitext(base_file)
     # The earlier condition ``ext == '.pkl' or '.pth'`` was always truthy
     # (a non-empty string literal), so unsupported files were never rejected.
     if ext in ('.pkl', '.pth'):
         print_info('Loading weights into state dict...')
         self.load_state_dict(torch.load(base_file))
         print_info('Finished!')
     else:
         print_info('Sorry only .pth and .pkl files supported.')
예제 #6
0
파일: demo.py 프로젝트: sarrrrry/M2Det
import cv2
import numpy as np
import os
import torch
import torch.backends.cudnn as cudnn

from configs.CC import Config
from m2det.datasets import BaseTransform
from layers.functions import Detect, PriorBox
from m2det import build_net
from utils.core import print_info, anchors, init_net
from utils.nms_wrapper import nms

# Start-up banner, emitted once when the module is executed; the second
# argument is the list of style flags handed to print_info.
print_info(
    '\n'.join((
        ' ----------------------------------------------------------------------',
        '|                       M2Det Demo Program                             |',
        ' ----------------------------------------------------------------------',
    )),
    ['yellow', 'bold'],
)


def get_args():
    parser = argparse.ArgumentParser(description='M2Det Testing')
    parser.add_argument('-c',
                        '--config',
                        default='configs/m2det320_vgg.py',
                        type=str)
    parser.add_argument('-f',
                        '--directory',
                        default='imgs/',
                        help='the path to demo images')
    parser.add_argument('-m',
예제 #7
0
from m2det.datasets import detection_collate
from configs.CC import Config

# from utils.core import *

def _str2bool(value):
    """Convert a command-line token to ``bool``.

    ``argparse`` calls ``type`` with the raw string, and ``bool('False')`` is
    ``True`` because every non-empty string is truthy. This converter treats
    the usual "false" spellings (case-insensitive) and the empty string as
    ``False``; any other value is ``True``, as it was before.
    """
    if isinstance(value, bool):
        return value
    return str(value).strip().lower() not in ('', '0', 'false', 'f', 'no', 'n', 'off')


# Command-line options of the training script.
parser = argparse.ArgumentParser(description='M2Det Training')
parser.add_argument('-c', '--config', default='configs/m2det320_vgg.py')
parser.add_argument('-d', '--dataset', default='COCO', help='VOC or COCO dataset')
parser.add_argument('--ngpu', default=1, type=int, help='gpus')
parser.add_argument('--resume_net', default=None, help='resume net for retraining')
parser.add_argument('--resume_epoch', default=0, type=int, help='resume iter for retraining')
# Was ``type=bool``, which turned ``-t False`` into True. ``nargs='?'`` with
# ``const=True`` additionally lets a bare ``-t`` switch the option on.
parser.add_argument('-t', '--tensorboard', type=_str2bool, nargs='?', const=True, default=False,
                    help='Use tensorborad to show the Loss Graph')
args = parser.parse_args()

# Start-up banner; the second argument is the list of style flags handed to
# print_info.
print_info(
    '\n'.join((
        '----------------------------------------------------------------------',
        '|                       M2Det Training Program                       |',
        '----------------------------------------------------------------------',
    )),
    ['yellow', 'bold'],
)

logger = set_logger(args.tensorboard)

# NOTE: a ``global`` statement at module scope has no effect; it is kept so
# the script stays statement-for-statement equivalent.
global cfg
cfg = Config.fromfile(args.config)

# Only 320, 512, 704 and 800 are supported as input sizes.
net = build_net(
    'train',
    size=cfg.model.input_size,
    config=cfg.model.m2det_config,
)

# Init the network with pretrained weights or resumed weights.
init_net(net, cfg, args.resume_net)

# Wrap the network for multi-GPU use before it is moved to CUDA.
if args.ngpu > 1:
    net = torch.nn.DataParallel(net)
if cfg.train_cfg.cuda:
    net.cuda()
    cudnn.benchmark = True