# Example #1 (rating: 0)
def train():
    """Debug variant of the RefineDet training loop with anchor visualization.

    In addition to the normal forward/backward pass, every iteration
    re-runs the prior/ground-truth matching for the first image of the
    batch and uses OpenCV to draw the matched ground-truth boxes (white)
    and the anchors assigned to the small (green), middle (blue) and big
    (red) scale branches, pausing 2 s after each group.

    Relies on module-level names: ``args``, ``voc_refinedet``,
    ``ExpVOCDetection``, ``build_multitridentrefinedet``,
    ``RefineDetMultiBoxLoss``, ``multitridentMultiBoxLoss``,
    ``refine_match_return_matches``, ``scaleAssign``, ``weights_init``,
    ``bottleneck_init``, ``detection_collate`` and
    ``adjust_learning_rate``.
    """

    cfg = voc_refinedet["exp"]
    dataset = ExpVOCDetection(root=args.dataset_root, transform=None)

    # im_names = "000069.jpg"
    # image_file = '/home/yiling/data/VOCdevkit/VOC2007/JPEGImages/' + im_names
    # image = cv2.imread(image_file, cv2.IMREAD_COLOR)  # uncomment if dataset not download

    refinedet_net = build_multitridentrefinedet('train', cfg['min_dim'],
                                                cfg['num_classes'])
    # Keep a reference to the raw module: DataParallel wraps ``net`` only,
    # while weight loading/saving below goes through ``refinedet_net``.
    net = refinedet_net
    print(net)
    #input()

    if args.cuda:
        net = torch.nn.DataParallel(refinedet_net)
        cudnn.benchmark = True

    if args.resume:
        print('Resuming training, loading {}...'.format(args.resume))
        refinedet_net.load_weights(args.resume)
    else:
        if args.withBN:
            # Load a VGG-with-batchnorm checkpoint, keeping only the keys
            # that exist in this model's backbone.
            vgg_bn_weights = torch.load(args.basenetBN)
            print('Loading base network...')
            model_dict = refinedet_net.vgg.state_dict()
            pretrained_dict = {
                k: v
                for k, v in vgg_bn_weights.items() if k in model_dict
            }
            model_dict.update(pretrained_dict)
            refinedet_net.vgg.load_state_dict(model_dict)
        else:
            # vgg_weights = torch.load(args.save_folder + args.basenet)
            vgg_weights = torch.load(args.basenet)
            print('Loading base network...')
            refinedet_net.vgg.load_state_dict(vgg_weights)

    if args.cuda:
        net = net.cuda()

    if not args.resume:
        print('Initializing weights...')
        # initialize newly added layers' weights with xavier method
        refinedet_net.extras.apply(weights_init)
        refinedet_net.arm_loc.apply(weights_init)
        refinedet_net.arm_conf.apply(weights_init)
        refinedet_net.trm_loc.apply(weights_init)
        refinedet_net.trm_conf.apply(weights_init)
        refinedet_net.branch_for_arm0.apply(bottleneck_init)
        refinedet_net.branch_for_arm1.apply(bottleneck_init)
        refinedet_net.branch_for_arm2.apply(bottleneck_init)
        refinedet_net.branch_for_arm3.apply(bottleneck_init)
        refinedet_net.tcb0.apply(weights_init)
        refinedet_net.tcb1.apply(weights_init)
        refinedet_net.tcb2.apply(weights_init)

        refinedet_net.se0.apply(weights_init)
        refinedet_net.se1.apply(weights_init)
        refinedet_net.se2.apply(weights_init)
        refinedet_net.se3.apply(weights_init)
        # refinedet_net.decov.apply(weights_init)

    optimizer = optim.SGD(net.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    # ARM loss is class-agnostic (2 classes: object vs background); the TRM
    # loss covers the full class set and filters anchors through the ARM.
    arm_criterion = RefineDetMultiBoxLoss(2, 0.5, True, 0, True, 3, 0.5, False,
                                          args.cuda)
    trm_criterion = multitridentMultiBoxLoss(cfg['num_classes'],
                                             0.5,
                                             True,
                                             0,
                                             True,
                                             3,
                                             0.5,
                                             False,
                                             args.cuda,
                                             use_ARM=True,
                                             use_multiscale=True)

    net.train()
    # loss counters
    # NOTE(review): these counters are reset at the top of every iteration
    # and the ``+=`` lines below are commented out, so they never accumulate.
    arm_loc_loss = 0
    arm_conf_loss = 0
    trm_loc_s_loss = 0
    trm_loc_m_loss = 0
    trm_loc_b_loss = 0
    trm_conf_s_loss = 0
    trm_conf_m_loss = 0
    trm_conf_b_loss = 0
    epoch = 0
    print('Loading the dataset...')

    # epoch_size = len(dataset) // args.batch_size
    # print('Training RefineDet on:', dataset.name)
    print('Using the specified args:')
    print(args)

    step_index = 0

    data_loader = data.DataLoader(dataset,
                                  args.batch_size,
                                  num_workers=args.num_workers,
                                  shuffle=True,
                                  collate_fn=detection_collate,
                                  pin_memory=True)
    # create batch iterator
    # Running per-batch anchor counts; averaged and reset every 10 iterations.
    num_all = np.array(0)
    num_small = np.array(0)
    num_middle = np.array(0)
    num_big = np.array(0)
    batch_iterator = iter(data_loader)

    for iteration in range(args.start_iter, cfg['max_iter']):

        # reset epoch loss counters
        arm_loc_loss = 0
        arm_conf_loss = 0
        trm_loc_s_loss = 0
        trm_loc_m_loss = 0
        trm_loc_b_loss = 0
        trm_conf_s_loss = 0
        trm_conf_m_loss = 0
        trm_conf_b_loss = 0
        epoch += 1

        if iteration in cfg['lr_steps']:
            step_index += 1
            adjust_learning_rate(optimizer, args.gamma, step_index, iteration)

        # load train data
        try:
            images, targets = next(batch_iterator)
        except StopIteration:
            # The loader is exhausted once per epoch; restart it.
            batch_iterator = iter(data_loader)
            images, targets = next(batch_iterator)

        # if dataset.getmyimg() != []:
        #     plt.imshow(dataset.getmyimg())
        #     plt.show()
        # First image of the batch converted to HWC layout for cv2 drawing.
        img = np.array(images)[0].transpose(1, 2, 0)
        # cv2.imshow("image",img)
        # cv2.waitKey(0)

        images = images.type(torch.FloatTensor)
        if args.cuda:
            images = images.cuda()
            targets = [ann.cuda() for ann in targets]
        else:
            images = images
            targets = [ann for ann in targets]

        # forward
        t0 = time.time()
        out = net(images)

        arm_loc_data, arm_conf_data, trm_loc_data1, trm_conf_data1, trm_loc_data2, trm_conf_data2, trm_loc_data3, trm_conf_data3, priors = out
        # Re-run anchor matching for visualization only.
        # NOTE(review): batch size 1 and 6375 priors are hard-coded here;
        # this presumably matches the "exp" config only -- confirm before reuse.
        use_ARM = False
        threshold = 0.5
        pos_for_small = torch.ByteTensor(1, 6375)
        pos_for_middle = torch.ByteTensor(1, 6375)
        pos_for_big = torch.ByteTensor(1, 6375)
        loc_t = torch.Tensor(1, 6375, 4)
        conf_t = torch.LongTensor(1, 6375)
        matches_list = torch.Tensor(1, 6375, 4)
        defaults_list = torch.Tensor(1, 6375, 4)
        for idx in range(1):
            truths = targets[idx][:, :-1].data
            labels = targets[idx][:, -1].data
            if True:
                # Collapse class labels to a binary objectness flag.
                labels = labels >= 0
            defaults = priors.data
            if use_ARM:
                matches, best_pri_overlap, best_pri_idx = refine_match_return_matches(
                    threshold, truths, defaults, cfg['variance'], labels,
                    loc_t, conf_t, idx, arm_loc_data[idx].data)
            else:
                matches, best_pri_overlap, best_pri_idx = refine_match_return_matches(
                    threshold, truths, defaults, cfg['variance'], labels,
                    loc_t, conf_t, idx)
            matches_list[idx] = matches
            defaults_list[idx] = defaults
            # Split the positive anchors into small/middle/big scale groups.
            pos_for_small[idx], pos_for_middle[idx], pos_for_big[
                idx] = scaleAssign(matches, conf_t, idx)

        # cv2.destroyAllWindows()
        # NOTE(review): ``set()`` over the mask-indexed tensor deduplicates
        # by object identity, not by box value -- verify this really yields
        # the intended set of unique ground-truth boxes.
        small_gt_set = set(matches_list[pos_for_small])
        middle_gt_set = set(matches_list[pos_for_middle])
        big_gt_set = set(matches_list[pos_for_big])

        small_anchs = defaults_list[pos_for_small]
        middle_anchs = defaults_list[pos_for_middle]
        big_anchs = defaults_list[pos_for_big]

        # Draw matched GT (white) and anchors (center form -> corner form),
        # scaled by the 320x320 input size.
        # NOTE(review): cv2.rectangle requires integer pixel coordinates in
        # newer OpenCV releases; these float products may need int() casts.
        img_copy = img.copy()
        for rect in small_gt_set:
            cv2.rectangle(img_copy, (rect[0] * 320, rect[1] * 320),
                          (rect[2] * 320, rect[3] * 320), (255, 255, 255), 2)
        for rect in middle_gt_set:
            cv2.rectangle(img_copy, (rect[0] * 320, rect[1] * 320),
                          (rect[2] * 320, rect[3] * 320), (255, 255, 255), 2)
        for rect in big_gt_set:
            cv2.rectangle(img_copy, (rect[0] * 320, rect[1] * 320),
                          (rect[2] * 320, rect[3] * 320), (255, 255, 255), 2)
        for rect in small_anchs:
            x1 = (rect[0] - rect[2] / 2) * 320
            y1 = (rect[1] - rect[3] / 2) * 320
            x2 = (rect[0] + rect[2] / 2) * 320
            y2 = (rect[1] + rect[3] / 2) * 320
            cv2.rectangle(img_copy, (x1, y1), (x2, y2), (0, 255, 0))  # green
        cv2.imshow("image", img_copy)
        cv2.waitKey(1000 * 2)
        for rect in middle_anchs:
            x1 = (rect[0] - rect[2] / 2) * 320
            y1 = (rect[1] - rect[3] / 2) * 320
            x2 = (rect[0] + rect[2] / 2) * 320
            y2 = (rect[1] + rect[3] / 2) * 320
            cv2.rectangle(img_copy, (x1, y1), (x2, y2),
                          color=(255, 0, 0))  # blue
        cv2.imshow("image", img_copy)
        cv2.waitKey(1000 * 2)
        for rect in big_anchs:
            x1 = (rect[0] - rect[2] / 2) * 320
            y1 = (rect[1] - rect[3] / 2) * 320
            x2 = (rect[0] + rect[2] / 2) * 320
            y2 = (rect[1] + rect[3] / 2) * 320
            cv2.rectangle(img_copy, (x1, y1), (x2, y2),
                          color=(0, 0, 255))  # red
        cv2.imshow("image", img_copy)
        cv2.waitKey(1000 * 2)

        # backprop
        optimizer.zero_grad()
        arm_loss_l, arm_loss_c = arm_criterion(out, targets)
        trm_loss_s_l, trm_loss_m_l, trm_loss_b_l, trm_loss_s_c, trm_loss_m_c, trm_loss_b_c, n_all, n_small, n_middle, n_big = trm_criterion(
            out, targets)

        #input()
        arm_loss = arm_loss_l + arm_loss_c
        trm_loss = trm_loss_s_l + trm_loss_m_l + trm_loss_b_l + trm_loss_s_c + trm_loss_m_c + trm_loss_b_c
        loss = arm_loss + trm_loss
        loss.backward()
        # trm_loss.backward()
        optimizer.step()
        t1 = time.time()
        # arm_loc_loss += arm_loss_l.item()
        # arm_conf_loss += arm_loss_c.item()
        # trm_loc_s_loss += trm_loss_s_l.item()
        # trm_loc_m_loss += trm_loss_m_l.item()
        # trm_loc_b_loss += trm_loss_b_l.item()
        # trm_conf_s_loss += trm_loss_s_c.item()
        # trm_conf_m_loss += trm_loss_m_c.item()
        # trm_conf_b_loss += trm_loss_b_c.item()
        num_all = np.append(num_all, n_all)
        num_small = np.append(num_small, n_small)
        num_middle = np.append(num_middle, n_middle)
        num_big = np.append(num_big, n_big)

        # The trident criterion appears to return a plain float (no positive
        # anchors at that scale) or a tensor; unwrap conditionally for logging.
        if type(trm_loss_s_l) != float:
            trm_loss_s_l_value = trm_loss_s_l.item()
        else:
            trm_loss_s_l_value = trm_loss_s_l
        if type(trm_loss_m_l) != float:
            trm_loss_m_l_value = trm_loss_m_l.item()
        else:
            trm_loss_m_l_value = trm_loss_m_l
        if type(trm_loss_b_l) != float:
            trm_loss_b_l_value = trm_loss_b_l.item()
        else:
            trm_loss_b_l_value = trm_loss_b_l
        if type(trm_loss_s_c) != float:
            trm_loss_s_c_value = trm_loss_s_c.item()
        else:
            trm_loss_s_c_value = trm_loss_s_c
        if type(trm_loss_m_c) != float:
            trm_loss_m_c_value = trm_loss_m_c.item()
        else:
            trm_loss_m_c_value = trm_loss_m_c
        if type(trm_loss_b_c) != float:
            trm_loss_b_c_value = trm_loss_b_c.item()
        else:
            trm_loss_b_c_value = trm_loss_b_c

        if iteration % 10 == 0:
            print('timer: %.4f sec.' % (t1 - t0))
            print('iter ' + repr(iteration) + ' || ARM_L: %.4f ARM_C: %.4f TRM_s_L: %.4f TRM_s_C: %.4f TRM_m_L: %.4f TRM_m_C: %.4f TRM_b_L: %.4f TRM_b_C: %.4f ||' \
            % (arm_loss_l.item(), arm_loss_c.item(), trm_loss_s_l_value, trm_loss_s_c_value, trm_loss_m_l_value, trm_loss_m_c_value, trm_loss_b_l_value, trm_loss_b_c_value), end=' ')
            print('\n' + 'all:{}  small:{}  middle:{}  big:{} lr:{}'.format(
                num_all.mean(), num_small.mean(), num_middle.mean(),
                num_big.mean(), optimizer.param_groups[0]["lr"]))
            num_all = np.array(0)
            num_small = np.array(0)
            num_middle = np.array(0)
            num_big = np.array(0)

        if iteration != 0 and iteration % 5000 == 0:
            print('Saving state, iter:', iteration)
            torch.save(
                refinedet_net.state_dict(),
                args.save_folder + '/RefineDet{}_{}_{}.pth'.format(
                    args.input_size, args.dataset, repr(iteration)))
    torch.save(
        refinedet_net.state_dict(), args.save_folder +
        '/RefineDet{}_{}_final.pth'.format(args.input_size, args.dataset))
def train():
    if args.dataset == 'COCO':
        '''if args.dataset_root == VOC_ROOT:
            if not os.path.exists(COCO_ROOT):
                parser.error('Must specify dataset_root if specifying dataset')
            print("WARNING: Using default COCO dataset_root because " +
                  "--dataset_root was not specified.")
            args.dataset_root = COCO_ROOT
        cfg = coco
        dataset = COCODetection(root=args.dataset_root,
                                transform=SSDAugmentation(cfg['min_dim'],
                                                          MEANS))'''
    elif args.dataset == 'VOC':
        '''if args.dataset_root == COCO_ROOT:
            parser.error('Must specify dataset if specifying dataset_root')'''
        cfg = voc_refinedet[args.input_size]
        dataset = VOCDetection(root=args.dataset_root,
                               transform=SSDAugmentation(cfg['min_dim'],
                                                         MEANS))

    if args.visdom:
        import visdom
        viz = visdom.Visdom()

    refinedet_net = build_multitridentrefinedet('train', cfg['min_dim'], cfg['num_classes'])
    net = refinedet_net
    print(net)
    #input()

    if args.cuda:
        net = torch.nn.DataParallel(refinedet_net)
        cudnn.benchmark = True

    if args.resume:
        print('Resuming training, loading {}...'.format(args.resume))
        refinedet_net.load_weights(args.resume)
    else:
        if args.withBN:
            vgg_bn_weights = torch.load(args.basenetBN)
            print('Loading base network...')
            model_dict = refinedet_net.vgg.state_dict()
            pretrained_dict = {k: v for k, v in vgg_bn_weights.items() if k in model_dict}
            model_dict.update(pretrained_dict)
            refinedet_net.vgg.load_state_dict(model_dict)
        else:
            # vgg_weights = torch.load(args.save_folder + args.basenet)
            vgg_weights = torch.load(args.basenet)
            print('Loading base network...')
            refinedet_net.vgg.load_state_dict(vgg_weights)


    if args.cuda:
        net = net.cuda()

    if not args.resume:
        print('Initializing weights...')
        # initialize newly added layers' weights with xavier method
        refinedet_net.extras.apply(weights_init)
        refinedet_net.arm_loc.apply(weights_init)
        refinedet_net.arm_conf.apply(weights_init)
        refinedet_net.trm_loc.apply(weights_init)
        refinedet_net.trm_conf.apply(weights_init)
        refinedet_net.branch_for_arm0.apply(bottleneck_init)
        refinedet_net.branch_for_arm1.apply(bottleneck_init)
        refinedet_net.branch_for_arm2.apply(bottleneck_init)
        refinedet_net.branch_for_arm3.apply(bottleneck_init)
        # refinedet_net.tcb0.apply(weights_init)
        # refinedet_net.tcb1.apply(weights_init)
        # refinedet_net.tcb2.apply(weights_init)

        # refinedet_net.se0.apply(weights_init)
        # refinedet_net.se1.apply(weights_init)
        # refinedet_net.se2.apply(weights_init)
        # refinedet_net.se3.apply(weights_init)
        # refinedet_net.decov.apply(weights_init)

    optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum,
                          weight_decay=args.weight_decay)
    arm_criterion = RefineDetMultiBoxLoss(2, 0.5, True, 0, True, 3, 0.5,
                             False, args.cuda)
    trm_criterion = multitridentMultiBoxLoss(cfg['num_classes'], 0.5, True, 0, True, 3, 0.5,
                             False, args.cuda, use_ARM=True, use_multiscale=True)

    net.train()
    # loss counters
    arm_loc_loss = 0
    arm_conf_loss = 0
    trm_loc_s_loss = 0
    trm_loc_m_loss = 0
    trm_loc_b_loss = 0
    trm_conf_s_loss = 0
    trm_conf_m_loss = 0
    trm_conf_b_loss = 0
    epoch = 0
    print('Loading the dataset...')

    epoch_size = len(dataset) // args.batch_size
    print('Training RefineDet on:', dataset.name)
    print('Using the specified args:')
    print(args)

    step_index = 0

    if args.visdom:
        vis_title = 'RefineDet.PyTorch on ' + dataset.name
        vis_legend = ['Loc Loss', 'Conf Loss', 'Total Loss']
        iter_plot = create_vis_plot('Iteration', 'Loss', vis_title, vis_legend)
        epoch_plot = create_vis_plot('Epoch', 'Loss', vis_title, vis_legend)

    data_loader = data.DataLoader(dataset, args.batch_size,
                                  num_workers=args.num_workers,
                                  shuffle=True, collate_fn=detection_collate,
                                  pin_memory=True)
    # create batch iterator
    num_all = np.array(0)
    num_small = np.array(0)
    num_middle = np.array(0)
    num_big = np.array(0)
    batch_iterator = iter(data_loader)

    for iteration in range(args.start_iter, cfg['max_iter']):
        if args.visdom and iteration != 0 and (iteration % epoch_size == 0):
            update_vis_plot(epoch, arm_loc_loss, arm_conf_loss, epoch_plot, None,
                            'append', epoch_size)
            # reset epoch loss counters
            arm_loc_loss = 0
            arm_conf_loss = 0
            trm_loc_s_loss = 0
            trm_loc_m_loss = 0
            trm_loc_b_loss = 0
            trm_conf_s_loss = 0
            trm_conf_m_loss = 0
            trm_conf_b_loss = 0
            epoch += 1

        if iteration in cfg['lr_steps']:
            step_index += 1
            adjust_learning_rate(optimizer, args.gamma, step_index, iteration)

        # load train data
        try:
            images, targets = next(batch_iterator)
        except StopIteration:
            batch_iterator = iter(data_loader)
            images, targets = next(batch_iterator)

        # if dataset.getmyimg() != []:
        #     plt.imshow(dataset.getmyimg())
        #     plt.show()

        if args.cuda:
            images = images.cuda()
            targets = [ann.cuda() for ann in targets]
        else:
            images = images
            targets = [ann for ann in targets]



        # forward
        t0 = time.time()
        out = net(images)
        # backprop
        optimizer.zero_grad()
        arm_loss_l, arm_loss_c = arm_criterion(out, targets)
        trm_loss_s_l, trm_loss_m_l, trm_loss_b_l, trm_loss_s_c, trm_loss_m_c, trm_loss_b_c, n_all, n_small, n_middle, n_big = trm_criterion(out, targets)

        #input()
        arm_loss = arm_loss_l + arm_loss_c
        trm_loss = trm_loss_s_l+ trm_loss_m_l+ trm_loss_b_l+trm_loss_s_c+ trm_loss_m_c+trm_loss_b_c
        loss = arm_loss + trm_loss
        loss.backward()
        # trm_loss.backward()
        optimizer.step()
        t1 = time.time()
        # arm_loc_loss += arm_loss_l.item()
        # arm_conf_loss += arm_loss_c.item()
        # trm_loc_s_loss += trm_loss_s_l.item()
        # trm_loc_m_loss += trm_loss_m_l.item()
        # trm_loc_b_loss += trm_loss_b_l.item()
        # trm_conf_s_loss += trm_loss_s_c.item()
        # trm_conf_m_loss += trm_loss_m_c.item()
        # trm_conf_b_loss += trm_loss_b_c.item()
        num_all = np.append(num_all, n_all)
        num_small = np.append(num_small, n_small)
        num_middle = np.append(num_middle, n_middle)
        num_big = np.append(num_big, n_big)

        if type(trm_loss_s_l) != float:
            trm_loss_s_l_value = trm_loss_s_l.item()
        else:
            trm_loss_s_l_value = trm_loss_s_l
        if type(trm_loss_m_l) != float:
            trm_loss_m_l_value = trm_loss_m_l.item()
        else:
            trm_loss_m_l_value = trm_loss_m_l
        if type(trm_loss_b_l) != float:
            trm_loss_b_l_value = trm_loss_b_l.item()
        else:
            trm_loss_b_l_value = trm_loss_b_l
        if type(trm_loss_s_c) != float:
            trm_loss_s_c_value = trm_loss_s_c.item()
        else:
            trm_loss_s_c_value = trm_loss_s_c
        if type(trm_loss_m_c) != float:
            trm_loss_m_c_value = trm_loss_m_c.item()
        else:
            trm_loss_m_c_value = trm_loss_m_c
        if type(trm_loss_b_c) != float:
            trm_loss_b_c_value = trm_loss_b_c.item()
        else:
            trm_loss_b_c_value = trm_loss_b_c

        if iteration % 10 == 0:
            print('timer: %.4f sec.' % (t1 - t0))
            print('iter ' + repr(iteration) + ' || ARM_L: %.4f ARM_C: %.4f TRM_s_L: %.4f TRM_s_C: %.4f TRM_m_L: %.4f TRM_m_C: %.4f TRM_b_L: %.4f TRM_b_C: %.4f ||' \
            % (arm_loss_l.item(), arm_loss_c.item(), trm_loss_s_l_value, trm_loss_s_c_value, trm_loss_m_l_value, trm_loss_m_c_value, trm_loss_b_l_value, trm_loss_b_c_value), end=' ')
            print('\n'+'all:{}  small:{}  middle:{}  big:{} lr:{}'.format(num_all.mean(), num_small.mean(), num_middle.mean(), num_big.mean(), optimizer.param_groups[0]["lr"]))
            num_all = np.array(0)
            num_small = np.array(0)
            num_middle = np.array(0)
            num_big = np.array(0)

        if args.visdom:
            update_vis_plot(iteration, arm_loss_l.data[0], arm_loss_c.data[0],
                            iter_plot, epoch_plot, 'append')

        if iteration != 0 and iteration % 5000 == 0:
            print('Saving state, iter:', iteration)
            torch.save(refinedet_net.state_dict(), args.save_folder
            + '/RefineDet{}_{}_{}.pth'.format(args.input_size, args.dataset,
            repr(iteration)))
    torch.save(refinedet_net.state_dict(), args.save_folder
            + '/RefineDet{}_{}_final.pth'.format(args.input_size, args.dataset))
import os
import time

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.autograd import Variable

# Select the GPU BEFORE any CUDA call: setting CUDA_VISIBLE_DEVICES after
# torch.cuda.is_available() has initialized the driver has no effect.
# (The original also used ``os`` without importing it.)
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

if torch.cuda.is_available():
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

size = 320
from models.multitrident_refinedet import build_multitridentrefinedet

net = build_multitridentrefinedet('test', size, 21)  # initialize the detector
net.load_weights('../weights/experiment/320*320/RefineDet320_VOC_315000.pth')

# Candidate VOC image ids for the demo:
# 000210 000111 000144 009539 009589 000069 009539 001275 002333 002338
# 002341 002695 002713 003681 003874 003673 003740
im_names = "002695.jpg"

image_file = '/home/amax/data/VOCdevkit/VOC2007/JPEGImages/' + im_names
image = cv2.imread(image_file,
                   cv2.IMREAD_COLOR)  # uncomment if dataset not download
from matplotlib import pyplot as plt
from data import VOCDetection, VOC_ROOT, VOCAnnotationTransform
# here we specify year (07 or 12) and dataset ('test', 'val', 'train')
# Example #4 (rating: 0)
    with open(det_file, 'wb') as f:
        pickle.dump(all_boxes, f, pickle.HIGHEST_PROTOCOL)

    print('Evaluating detections')
    evaluate_detections(all_boxes, output_dir, dataset)


def evaluate_detections(box_list, output_dir, dataset):
    """Write detections in the Pascal VOC results format, then score them.

    Args:
        box_list: per-class, per-image detection boxes to evaluate.
        output_dir: directory where evaluation output is written.
        dataset: dataset wrapper used to map detections to image ids.
    """
    # Step 1: serialize the detections into the VOC results files.
    write_voc_results_file(box_list, dataset)
    # Step 2: run the official python VOC evaluation on what was written.
    do_python_eval(output_dir)


if __name__ == '__main__':
    # load net
    num_classes = len(labelmap) + 1  # +1 for background
    net = build_multitridentrefinedet('test', int(args.input_size),
                                      num_classes)  # initialize SSD
    net.load_state_dict(torch.load(args.trained_model))
    net.eval()
    print('Finished loading model!')
    # load data
    dataset = VOCDetection(args.voc_root, [('2007', set_type)],
                           BaseTransform(int(args.input_size), dataset_mean),
                           VOCAnnotationTransform())
    if args.cuda:
        net = net.cuda()
        cudnn.benchmark = True
    # evaluation
    test_net(args.save_folder,
             net,
             args.cuda,
             dataset,