Example #1
0
    # store pruning models
    if not os.path.exists(args.prune_folder):
        os.mkdir(args.prune_folder)

    # ------------------------------------------- 1st prune: load model from state_dict
    model = build_refine('train', cfg['min_dim'], cfg['num_classes'], use_refine=True, use_tcb=True).cuda()
    state_dict = torch.load(args.trained_model)
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        head = k[:7] # head = k[:4]
        if head == 'module.': # head == 'vgg.'
            name = k[7:]  # name = 'base.' + k[4:]
        else:
            name = k
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)
    #model.load_state_dict(torch.load(args.trained_model))
    # ------------------------------------------- >= 2nd prune: load model from previous pruning
    # model = torch.load(args.trained_model).cuda()

    print('Finished loading model!')

    testset = VOCDetection(root=args.dataset_root, image_sets=[('2007', 'test')],
                                transform=BaseTransform(cfg['min_dim'], cfg['testset_mean']))
    arm_criterion = RefineMultiBoxLoss(2, 0.5, True, 0, True, 3, 0.5, False, 0, args.cuda)
    odm_criterion = RefineMultiBoxLoss(cfg['num_classes'], 0.5, True, 0, True, 3, 0.5, False, 0.01, args.cuda)  # 0.01 -> 0.99 negative confidence threshold

    prunner = Prunner_refineDet(testset, arm_criterion, odm_criterion, model)
    prunner.prune(cut_ratio=args.cut_ratio)
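# The 'module.'-stripping loop above reappears in several of the examples below.
# A minimal sketch of the same logic as a helper (hypothetical name, not part of
# the original code); it assumes a checkpoint saved from a DataParallel-wrapped
# model and leaves already-clean keys untouched.
def strip_module_prefix(state_dict):
    # DataParallel prefixes every parameter name with 'module.'; drop it.
    return {k[len('module.'):] if k.startswith('module.') else k: v
            for k, v in state_dict.items()}

# Usage sketch: model.load_state_dict(strip_module_prefix(torch.load(args.trained_model)))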
Example #2
0
if args.cuda:
    net = torch.nn.DataParallel(net, device_ids=[args.device])

if args.cuda:
    net.cuda()
    cudnn.benchmark = True

optimizer = optim.SGD(net.parameters(),
                      lr=args.lr,
                      momentum=args.momentum,
                      weight_decay=args.weight_decay)
#optimizer = optim.RMSprop(net.parameters(), lr=args.lr,alpha = 0.9, eps=1e-08,
#                      momentum=args.momentum, weight_decay=args.weight_decay)

criterion = RefineMultiBoxLoss(num_classes, 0.5, True, 0, True, 3, 0.5, False,
                               0.9)
priorbox = PriorBox(cfg)
#with torch.no_grad():
priors = Variable(priorbox.forward(), volatile=True)
if args.cuda:
    priors = priors.cuda()
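# Note: `Variable(priorbox.forward(), volatile=True)` above is the pre-0.4 PyTorch
# idiom for inference-only tensors. On current PyTorch the equivalent (used in
# Example #6 further down) would be roughly:
#     with torch.no_grad():
#         priors = priorbox.forward()
#     if args.cuda:
#         priors = priors.cuda()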


def train():
    net.train()
    # loss counters
    loc_loss = 0  # epoch
    conf_loss = 0
    epoch = 0 + args.resume_epoch
    print('Loading Dataset...')
Example #3
0
if args.gpu_id:
    net = torch.nn.DataParallel(net, device_ids=args.gpu_id)

if args.cuda:
    net.cuda()
    cudnn.benchmark = True

optimizer = optim.SGD(net.parameters(),
                      lr=args.lr,
                      momentum=args.momentum,
                      weight_decay=args.weight_decay)
#optimizer = optim.RMSprop(net.parameters(), lr=args.lr,alpha = 0.9, eps=1e-08,
#                     momentum=args.momentum, weight_decay=args.weight_decay)

arm_criterion = RefineMultiBoxLoss(2, 0.5, True, 0, True, 3, 0.5, False)
odm_criterion = RefineMultiBoxLoss(num_classes, 0.5, True, 0, True, 3, 0.5,
                                   False, 0.01)
priorbox = PriorBox(cfg)
detector = Detect(num_classes, 0, cfg, object_score=0.01)
priors = Variable(priorbox.forward(), volatile=True)
#dataset
print('Loading Dataset...')
if args.dataset == 'VOC':
    testset = VOCDetection(VOCroot, [('2007', 'test')], None,
                           AnnotationTransform())
    train_dataset = VOCDetection(VOCroot, train_sets,
                                 preproc(img_dim, rgb_means, p),
                                 AnnotationTransform())
elif args.dataset == 'COCO':
    testset = COCODetection(COCOroot, [('2014', 'minival')], None)
Example #4
0
def main():
    global args
    args = arg_parse()
    cfg_from_file(args.cfg_file)
    save_folder = args.save_folder
    batch_size = cfg.TRAIN.BATCH_SIZE
    bgr_means = cfg.TRAIN.BGR_MEAN
    p = 0.6
    gamma = cfg.SOLVER.GAMMA
    momentum = cfg.SOLVER.MOMENTUM
    weight_decay = cfg.SOLVER.WEIGHT_DECAY
    size = cfg.MODEL.SIZE
    thresh = cfg.TEST.CONFIDENCE_THRESH
    if cfg.DATASETS.DATA_TYPE == 'VOC':
        trainvalDataset = VOCDetection
        top_k = 1000
    else:
        trainvalDataset = COCODetection
        top_k = 1000
    dataset_name = cfg.DATASETS.DATA_TYPE
    dataroot = cfg.DATASETS.DATAROOT
    trainSet = cfg.DATASETS.TRAIN_TYPE
    valSet = cfg.DATASETS.VAL_TYPE
    num_classes = cfg.MODEL.NUM_CLASSES
    start_epoch = args.resume_epoch
    epoch_step = cfg.SOLVER.EPOCH_STEPS
    end_epoch = cfg.SOLVER.END_EPOCH
    if not os.path.exists(save_folder):
        os.mkdir(save_folder)
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    net = SSD(cfg)
    print(net)
    if cfg.MODEL.SIZE == '300':
        size_cfg = cfg.SMALL
    else:
        size_cfg = cfg.BIG
    optimizer = optim.SGD(
        net.parameters(),
        lr=cfg.SOLVER.BASE_LR,
        momentum=momentum,
        weight_decay=weight_decay)
    if args.resume_net is not None:
        checkpoint = torch.load(args.resume_net)
        state_dict = checkpoint['model']
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            head = k[:7]
            if head == 'module.':
                name = k[7:]  # remove `module.`
            else:
                name = k
            new_state_dict[name] = v
        net.load_state_dict(new_state_dict)
        optimizer.load_state_dict(checkpoint['optimizer'])
        print('Loading resume network...')
    if args.ngpu > 1:
        net = torch.nn.DataParallel(net)
    net.cuda()
    cudnn.benchmark = True

    criterion = list()
    if cfg.MODEL.REFINE:
        detector = Detect(cfg)
        arm_criterion = RefineMultiBoxLoss(cfg, 2)
        odm_criterion = RefineMultiBoxLoss(cfg, cfg.MODEL.NUM_CLASSES)
        criterion.append(arm_criterion)
        criterion.append(odm_criterion)
    else:
        detector = Detect(cfg)
        ssd_criterion = MultiBoxLoss(cfg)
        criterion.append(ssd_criterion)
    
    TrainTransform = preproc(size_cfg.IMG_WH, bgr_means, p)
    ValTransform = BaseTransform(size_cfg.IMG_WH, bgr_means, (2, 0, 1))

    val_dataset = trainvalDataset(dataroot, valSet, ValTransform, dataset_name)
    val_loader = data.DataLoader(
        val_dataset,
        batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=detection_collate)

    for epoch in range(start_epoch + 1, end_epoch + 1):
        train_dataset = trainvalDataset(dataroot, trainSet, TrainTransform,
                                        dataset_name)
        epoch_size = len(train_dataset)
        train_loader = data.DataLoader(
            train_dataset,
            batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            collate_fn=detection_collate)
        train(train_loader, net, criterion, optimizer, epoch, epoch_step,
              gamma, end_epoch, cfg)
        if (epoch % 5 == 0) or (epoch % 2 == 0 and epoch >= 60):
            save_checkpoint(net, epoch, size, optimizer)
        if (epoch >= 2 and epoch % 2 == 0):
            eval_net(
                val_dataset,
                val_loader,
                net,
                detector,
                cfg,
                ValTransform,
                top_k,
                thresh=thresh,
                batch_size=batch_size)
    save_checkpoint(net, end_epoch, size, optimizer)
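# detection_collate is handed to every DataLoader in these examples but never
# shown. A minimal sketch of what such a collate function usually does
# (hypothetical implementation; the real one ships with each repo's data package):
# stack the images into a single batch tensor and keep the variable-length
# box/label targets as a plain list.
import torch

def detection_collate(batch):
    images, targets = [], []
    for img, boxes in batch:
        images.append(img)
        targets.append(torch.FloatTensor(boxes))
    return torch.stack(images, 0), targets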
Example #5
0
def main_worker(gpu, ngpus_per_node, args):
    global best_map
    ## deal with args
    args.gpu = gpu
    cfg_from_file(args.cfg_file)
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

    # distributed cfgs
    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)
    torch.cuda.set_device(args.gpu)
    net = SSD(cfg)
    # print(net)
    if args.resume_net is not None:
        checkpoint = torch.load(args.resume_net)
        state_dict = checkpoint['model']
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            head = k[:7]
            if head == 'module.':
                name = k[7:]  # remove `module.`
            else:
                name = k
            new_state_dict[name] = v
        net.load_state_dict(new_state_dict)

        print('Loading resume network...')

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            # print(args.gpu)
            torch.cuda.set_device(args.gpu)
            net.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(args.workers / ngpus_per_node)
            net = torch.nn.parallel.DistributedDataParallel(
                net, device_ids=[args.gpu])
        else:
            net.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            net = torch.nn.parallel.DistributedDataParallel(net)
    elif args.gpu is not None:
        # torch.cuda.set_device(args.gpu)
        net = net.cuda(args.gpu)

    # args = arg_parse()

    batch_size = args.batch_size
    print("batch_size = ", batch_size)
    bgr_means = cfg.TRAIN.BGR_MEAN
    p = 0.6

    gamma = cfg.SOLVER.GAMMA
    momentum = cfg.SOLVER.MOMENTUM
    weight_decay = cfg.SOLVER.WEIGHT_DECAY
    size = cfg.MODEL.SIZE  # size =300
    thresh = cfg.TEST.CONFIDENCE_THRESH
    if cfg.DATASETS.DATA_TYPE == 'VOC':
        trainvalDataset = VOCDetection
        top_k = 1000
    else:
        trainvalDataset = COCODetection
        top_k = 1000
    dataset_name = cfg.DATASETS.DATA_TYPE
    dataroot = cfg.DATASETS.DATAROOT
    trainSet = cfg.DATASETS.TRAIN_TYPE
    valSet = cfg.DATASETS.VAL_TYPE
    num_classes = cfg.MODEL.NUM_CLASSES
    start_epoch = args.resume_epoch
    epoch_step = cfg.SOLVER.EPOCH_STEPS
    end_epoch = cfg.SOLVER.END_EPOCH
    args.num_workers = args.workers

    # optimizer

    optimizer = optim.SGD(net.parameters(),
                          lr=cfg.SOLVER.BASE_LR,
                          momentum=momentum,
                          weight_decay=weight_decay)

    if cfg.MODEL.SIZE == '300':
        size_cfg = cfg.SMALL
    else:
        size_cfg = cfg.BIG
    # if args.resume_net != None:
    #     checkpoint = torch.load(args.resume_net)
    #     optimizer.load_state_dict(checkpoint['optimizer'])

    cudnn.benchmark = True

    # deal with criterion
    criterion = list()
    if cfg.MODEL.REFINE:
        detector = Detect(cfg)
        arm_criterion = RefineMultiBoxLoss(cfg, 2)
        odm_criterion = RefineMultiBoxLoss(cfg, cfg.MODEL.NUM_CLASSES)
        arm_criterion.cuda(args.gpu)
        odm_criterion.cuda(args.gpu)
        criterion.append(arm_criterion)
        criterion.append(odm_criterion)
    else:
        detector = Detect(cfg)
        ssd_criterion = MultiBoxLoss(cfg)
        criterion.append(ssd_criterion)

    # deal with dataset
    TrainTransform = preproc(size_cfg.IMG_WH, bgr_means, p)
    ValTransform = BaseTransform(size_cfg.IMG_WH, bgr_means, (2, 0, 1))

    val_dataset = trainvalDataset(dataroot, valSet, ValTransform, dataset_name)
    val_loader = data.DataLoader(val_dataset,
                                 batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers * ngpus_per_node,
                                 collate_fn=detection_collate)
    # deal with training dataset
    train_dataset = trainvalDataset(dataroot, trainSet, TrainTransform,
                                    dataset_name)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.num_workers,
                                               collate_fn=detection_collate,
                                               pin_memory=True,
                                               sampler=train_sampler)
    ## set net in training phase
    net.train()

    for epoch in range(start_epoch + 1, end_epoch + 1):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # train_loader = data.DataLoader(train_dataset, batch_size, shuffle=True, num_workers=args.num_workers,
        #                               collate_fn=detection_collate)

        # Training
        train(train_loader, net, criterion, optimizer, epoch, epoch_step,
              gamma, end_epoch, cfg, args)

        if (epoch >= 0 and epoch % 10 == 0):
            #print("here",args.rank % ngpus_per_node)
            ## validation the model
            eval_net(val_dataset,
                     val_loader,
                     net,
                     detector,
                     cfg,
                     ValTransform,
                     args,
                     top_k,
                     thresh=thresh,
                     batch_size=cfg.TEST.BATCH_SIZE)

        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            if (epoch % 10 == 0) or (epoch % 5 == 0 and epoch >= 60):
                save_name = os.path.join(
                    args.save_folder,
                    cfg.MODEL.TYPE + "_epoch_{}_rank_{}_{}".format(
                        str(epoch), str(args.rank), str(size)) + '.pth')
                save_checkpoint(net, epoch, size, optimizer, batch_size,
                                save_name)
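# main_worker above follows the PyTorch ImageNet-example distributed pattern, so
# it is normally launched once per GPU. A minimal launcher sketch (hypothetical,
# assuming the same args namespace as above):
import torch
import torch.multiprocessing as mp

def launch(args):
    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # one process per GPU; world_size becomes the total GPU count across nodes
        args.world_size = ngpus_per_node * args.world_size
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        main_worker(args.gpu, ngpus_per_node, args)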
Example #6
0
if not args.resume:
    from model.networks import net_init
    net_init(ssd_net, args.backbone, logging, refine=args.refine, deform=args.deform, multihead=args.multihead)

if args.augm_type == 'ssd':
    data_transform = SSDAugmentation
else:
    data_transform = BaseTransform

optimizer = optim.SGD(net.parameters(),
                      lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
# criterion
if 'RefineDet' in args.backbone and args.refine:
    use_refine = True
    arm_criterion = RefineMultiBoxLoss(2, 0.5, True, 0, True, 3, 0.5, False, device=device, only_loc=True)
    criterion = RefineMultiBoxLoss(num_classes, 0.5, True, 0, True, 3, 0.5, False, device=device)
else:
    use_refine = False
    criterion = MultiBoxLoss(num_classes, 0.5, True, 0, True, 3, 0.5, False, device=device)

priorbox = PriorBox(cfg)
with torch.no_grad():
    priors = priorbox.forward().to(device)

def train():
    net.train()
    epoch = args.start_iter
    if args.dataset_name == 'COCO':
        dataset = COCODetection(COCOroot, year='trainval2014', image_sets=train_sets, transform=data_transform(ssd_dim, means), phase='train')
    else:
Example #7
0
def train():
    # network set-up
    ssd_net = build_refine('train',
                           cfg['min_dim'],
                           cfg['num_classes'],
                           use_refine=True,
                           use_tcb=True)
    net = ssd_net

    if args.cuda:
        net = torch.nn.DataParallel(
            ssd_net)  # state_dict will have .module. prefix
        cudnn.benchmark = True

    if args.resume:
        print('Resuming training, loading {}...'.format(args.resume))
        ssd_net.load_weights(args.resume)
    else:
        print('Using preloaded base network...')  # Preloaded.
        print('Initializing other weights...')
        # initialize newly added layers' weights with xavier method
        ssd_net.extras.apply(weights_init)
        ssd_net.trans_layers.apply(weights_init)
        ssd_net.latent_layrs.apply(weights_init)
        ssd_net.up_layers.apply(weights_init)
        ssd_net.arm_loc.apply(weights_init)
        ssd_net.arm_conf.apply(weights_init)
        ssd_net.odm_loc.apply(weights_init)
        ssd_net.odm_conf.apply(weights_init)

    if args.cuda:
        net = net.cuda()

    # optimizer and loss set-up
    optimizer = optim.SGD(net.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    arm_criterion = RefineMultiBoxLoss(2, 0.5, True, 0, True, 3, 0.5, False, 0,
                                       args.cuda)
    odm_criterion = RefineMultiBoxLoss(
        cfg['num_classes'], 0.5, True, 0, True, 3, 0.5, False, 0.01,
        args.cuda)  # 0.01 -> 0.99 negative confidence threshold

    # different from normal ssd, where the PriorBox is stored inside SSD object
    priorbox = PriorBox(cfg)
    priors = Variable(priorbox.forward(), volatile=True)
    # detector used in test_net for testing
    detector = RefineDetect(cfg['num_classes'], 0, cfg, object_score=0.01)

    net.train()
    # loss counters
    loc_loss = 0
    conf_loss = 0
    epoch = 0
    print('Loading the dataset...')

    epoch_size = len(dataset) // args.batch_size
    print('Training refineDet on:', dataset.name)
    print('Using the specified args:')
    print(args)

    if args.visdom:
        import visdom
        viz = visdom.Visdom()
        # initialize visdom loss plot
        vis_title = 'SSD.PyTorch on ' + dataset.name
        vis_legend = ['Loc Loss', 'Conf Loss', 'Total Loss']
        iter_plot = create_vis_plot('Iteration', 'Loss', vis_title, vis_legend)
        epoch_plot = create_vis_plot('Epoch', 'Loss', vis_title, vis_legend)

    # adjust learning rate based on epoch
    stepvalues_VOC = (150 * epoch_size, 200 * epoch_size, 250 * epoch_size)
    stepvalues_COCO = (90 * epoch_size, 120 * epoch_size, 140 * epoch_size)
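    # index the tuple with a bool: False (VOC) picks the first schedule, True (COCO) the second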
    stepvalues = (stepvalues_VOC, stepvalues_COCO)[args.dataset == 'COCO']
    step_index = 0

    # training data loader
    data_loader = data.DataLoader(dataset,
                                  args.batch_size,
                                  num_workers=args.num_workers,
                                  shuffle=True,
                                  collate_fn=detection_collate,
                                  pin_memory=True)
    # create batch iterator
    batch_iterator = iter(data_loader)
    #    batch_iterator = None
    mean_odm_loss_c = 0
    mean_odm_loss_l = 0
    mean_arm_loss_c = 0
    mean_arm_loss_l = 0
    # max_iter = cfg['max_epoch'] * epoch_size
    for iteration in range(args.start_iter,
                           cfg['max_epoch'] * epoch_size + 10):
        try:
            images, targets = next(batch_iterator)
        except StopIteration:
            batch_iterator = iter(
                data_loader)  # the exhausted iterator cannot be reused; re-create it
            images, targets = next(batch_iterator)

        if args.visdom and iteration != 0 and (iteration % epoch_size == 0):
            # update visdom loss plot
            update_vis_plot(epoch, loc_loss, conf_loss, epoch_plot, None,
                            'append', epoch_size)
            # reset epoch loss counters
            loc_loss = 0
            conf_loss = 0

        if iteration != 0 and (iteration % epoch_size == 0):
            #        adjust_learning_rate(optimizer, args.gamma, epoch)
            # evaluation
            if args.evaluate:
                # load net
                net.eval()
                APs, mAP = test_net(args.eval_folder,
                                    net,
                                    detector,
                                    priors,
                                    args.cuda,
                                    val_dataset,
                                    BaseTransform(net.module.size,
                                                  cfg['testset_mean']),
                                    args.max_per_image,
                                    thresh=args.confidence_threshold
                                    )  # 320 originally for cfg['min_dim']
                net.train()
            epoch += 1

        # update learning rate
        if iteration in stepvalues:
            step_index = stepvalues.index(iteration) + 1
        lr = adjust_learning_rate(optimizer, args.gamma, epoch, step_index,
                                  iteration, epoch_size)

        if args.cuda:
            images = Variable(images.cuda())
            targets = [Variable(ann.cuda(), volatile=True) for ann in targets]
        else:
            images = Variable(images)
            targets = [Variable(ann, volatile=True) for ann in targets]
        # forward
        t0 = time.time()
        out = net(images)
        arm_loc, arm_conf, odm_loc, odm_conf = out
        # backprop
        optimizer.zero_grad()
        #arm branch loss
        #priors = priors.type(type(images.data)) #convert to same datatype
        arm_loss_l, arm_loss_c = arm_criterion((arm_loc, arm_conf), priors,
                                               targets)
        #odm branch loss
        odm_loss_l, odm_loss_c = odm_criterion(
            (odm_loc, odm_conf), priors, targets, (arm_loc, arm_conf), False)

        mean_arm_loss_c += arm_loss_c.data[0]
        mean_arm_loss_l += arm_loss_l.data[0]
        mean_odm_loss_c += odm_loss_c.data[0]
        mean_odm_loss_l += odm_loss_l.data[0]

        loss = arm_loss_l + arm_loss_c + odm_loss_l + odm_loss_c
        loss.backward()
        optimizer.step()
        t1 = time.time()

        if iteration % 10 == 0:
            print('Epoch:' + repr(epoch) + ' || epochiter: ' +
                  repr(iteration % epoch_size) + '/' + repr(epoch_size) +
                  '|| Total iter ' + repr(iteration) +
                  ' || AL: %.4f AC: %.4f OL: %.4f OC: %.4f||' %
                  (mean_arm_loss_l / 10, mean_arm_loss_c / 10,
                   mean_odm_loss_l / 10, mean_odm_loss_c / 10) +
                  'Timer: %.4f sec. ||' % (t1 - t0) + 'Loss: %.4f ||' %
                  (loss.data[0]) + 'LR: %.8f' % (lr))

            mean_odm_loss_c = 0
            mean_odm_loss_l = 0
            mean_arm_loss_c = 0
            mean_arm_loss_l = 0


#        if args.visdom:
#            update_vis_plot(iteration, loss_l.data[0], loss_c.data[0],
#                            iter_plot, epoch_plot, 'append')

        if iteration != 0 and iteration % 5000 == 0:
            print('Saving state, iter:', iteration)
            torch.save(ssd_net.state_dict(),
                       'weights/ssd300_refineDet_' + repr(iteration) + '.pth')

    torch.save(ssd_net.state_dict(),
               args.save_folder + '' + args.dataset + '.pth')
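# adjust_learning_rate is called inside the loop above but not defined in this
# snippet. A common shape for it in these RefineDet/SSD training scripts (this
# version is an assumption, not the original implementation) is a short linear
# warmup followed by step decay by gamma at each entry of stepvalues; it reads
# the base rate from the module-level args, as the rest of this script does.
def adjust_learning_rate(optimizer, gamma, epoch, step_index, iteration, epoch_size,
                         warmup_epochs=5):
    if epoch < warmup_epochs:
        # ramp the rate linearly from ~0 up to args.lr over the warmup epochs
        lr = 1e-6 + (args.lr - 1e-6) * iteration / (epoch_size * warmup_epochs)
    else:
        lr = args.lr * (gamma ** step_index)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr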
Example #8
0
def main():
    global args
    args = arg_parse()
    save_folder = args.save_folder
    batch_size = args.batch_size
    bgr_means = (104, 117, 123)
    weight_decay = 0.0005
    p = 0.6
    gamma = 0.1
    momentum = 0.9
    dataset_name = args.dataset
    size = args.size
    channel_size = args.channel_size
    thresh = args.confidence_threshold
    use_refine = False
    if args.version.split("_")[0] == "refine":
        use_refine = True
    if dataset_name[0] == "V":
        cfg = cfg_dict["VOC"][args.version][str(size)]
        trainvalDataset = VOCDetection
        dataroot = VOCroot
        targetTransform = AnnotationTransform()
        valSet = datasets_dict["VOC2007"]
        top_k = 200
    else:
        cfg = cfg_dict["COCO"][args.version][str(size)]
        trainvalDataset = COCODetection
        dataroot = COCOroot
        targetTransform = None
        valSet = datasets_dict["COCOval"]
        top_k = 300
    num_classes = cfg['num_classes']
    start_epoch = args.resume_epoch
    epoch_step = cfg["epoch_step"]
    end_epoch = cfg["end_epoch"]
    if not os.path.exists(save_folder):
        os.mkdir(save_folder)
    if args.cuda and torch.cuda.is_available():
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    else:
        torch.set_default_tensor_type('torch.FloatTensor')
    net = model_builder(args.version, cfg, "train", int(size), num_classes,
                        args.channel_size)
    print(net)

    if args.resume_net is None:
        net.load_weights(pretrained_model[args.version])
    else:
        state_dict = torch.load(args.resume_net)
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            head = k[:7]
            if head == 'module.':
                name = k[7:]  # remove `module.`
            else:
                name = k
            new_state_dict[name] = v
        net.load_state_dict(new_state_dict)
        print('Loading resume network...')
    if args.ngpu > 1:
        net = torch.nn.DataParallel(net)
    if args.cuda:
        net.cuda()
        cudnn.benchmark = True
    optimizer = optim.SGD(net.parameters(),
                          lr=args.lr,
                          momentum=momentum,
                          weight_decay=weight_decay)
    criterion = list()
    if use_refine:
        detector = Detect(num_classes, 0, cfg, use_arm=use_refine)
        arm_criterion = RefineMultiBoxLoss(2, 0.5, True, 0, True, 3, 0.5,
                                           False, args.cuda)

        odm_criterion = RefineMultiBoxLoss(num_classes, 0.5, True, 0, True, 3,
                                           0.5, False, 0.01, args.cuda)
        criterion.append(arm_criterion)
        criterion.append(odm_criterion)
    else:
        detector = Detect(num_classes, 0, cfg, use_arm=use_refine)
        ssd_criterion = MultiBoxLoss(num_classes, 0.5, True, 0, True, 3, 0.5,
                                     False, args.cuda)
        criterion.append(ssd_criterion)
    TrainTransform = preproc(cfg["img_wh"], bgr_means, p)
    ValTransform = BaseTransform(cfg["img_wh"], bgr_means, (2, 0, 1))

    val_dataset = trainvalDataset(dataroot, valSet, ValTransform,
                                  targetTransform, dataset_name)
    val_loader = data.DataLoader(val_dataset,
                                 batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers,
                                 collate_fn=detection_collate)

    for epoch in range(start_epoch + 1, end_epoch + 1):
        train_dataset = trainvalDataset(dataroot, datasets_dict[dataset_name],
                                        TrainTransform, targetTransform,
                                        dataset_name)
        epoch_size = len(train_dataset)
        train_loader = data.DataLoader(train_dataset,
                                       batch_size,
                                       shuffle=True,
                                       num_workers=args.num_workers,
                                       collate_fn=detection_collate)
        train(train_loader, net, criterion, optimizer, epoch, epoch_step,
              gamma, use_refine)
        if (epoch % 10 == 0) or (epoch % 5 == 0 and epoch >= 200):
            save_checkpoint(net, epoch, size)
        if (epoch >= 150 and epoch % 10 == 0):
            eval_net(val_dataset,
                     val_loader,
                     net,
                     detector,
                     cfg,
                     ValTransform,
                     top_k,
                     thresh=thresh,
                     batch_size=batch_size)
    eval_net(val_dataset,
             val_loader,
             net,
             detector,
             cfg,
             ValTransform,
             top_k,
             thresh=thresh,
             batch_size=batch_size)
    save_checkpoint(net, end_epoch, size)
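# save_checkpoint is called throughout these examples (with slightly different
# argument lists) but never shown. A minimal sketch matching the three-argument
# call used here (hypothetical; the real helper lives next to the training script):
import os
import torch

def save_checkpoint(net, epoch, size, save_folder='weights/'):
    # save the bare state_dict so it reloads with or without DataParallel wrapping
    state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
    os.makedirs(save_folder, exist_ok=True)
    torch.save(state, os.path.join(save_folder,
                                   'refinedet_{}_epoch_{}.pth'.format(size, epoch)))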
Example #9
0
    print('Loading resume network...')

if args.ngpu > 1:
    # net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))
    net = torch.nn.DataParallel(net)

if args.cuda:
    net.cuda()
    cudnn.benchmark = True

optimizer = optim.SGD(net.parameters(),
                      lr=args.lr,
                      momentum=args.momentum,
                      weight_decay=args.weight_decay)

arm_criterion = RefineMultiBoxLoss(2, 0.5, True, 0, True, 3, 0.5, False,
                                   args.cuda)

criterion = RefineMultiBoxLoss(num_classes, 0.5, True, 0, True, 3, 0.5, False,
                               0.01, args.cuda)


def train():
    net.train()
    # loss counters
    loc_loss = 0  # epoch
    conf_loss = 0
    epoch = 0 + args.resume_epoch
    print('Loading Dataset...')

    epoch_size = len(train_dataset) // args.batch_size
    max_iter = args.max_epoch * epoch_size