def init_net(is_train, imdb_classes, args):
    """Build the Faster R-CNN detector selected by ``args.net``.

    The backbone name is validated before any constructor is looked up, so
    an unsupported name fails fast without touching ``alexnet``/``vgg``/
    ``resnet``.

    Args:
        is_train: Forwarded as ``pretrained=`` to the backbone constructor.
        imdb_classes: Class list forwarded as the first positional argument
            of the backbone constructor.
        args: Namespace-like object; reads ``args.net`` (backbone name) and
            ``args.class_agnostic`` (forwarded as ``class_agnostic=``).

    Returns:
        The constructed detector, after ``create_architecture()`` has been
        called on it.

    Raises:
        ValueError: If ``args.net`` is not one of the supported names.
            (``ValueError`` subclasses ``Exception``, so callers that caught
            the previous generic ``Exception`` keep working.)
    """
    # Depth is passed positionally as the second argument, exactly as the
    # constructors were called before this table-driven rewrite.
    vgg_depths = {"vgg11": 11, "vgg13": 13, "vgg16": 16, "vgg19": 19}
    resnet_depths = {
        "res18": 18,
        "res34": 34,
        "res50": 50,
        "res101": 101,
        "res152": 152,
    }

    net = args.net
    supported = {"alexnet"} | vgg_depths.keys() | resnet_depths.keys()
    if net not in supported:
        raise ValueError(
            f"network is not defined: {net!r} "
            f"(supported: {', '.join(sorted(supported))})"
        )

    # Keyword arguments shared by every backbone constructor.
    common = {"pretrained": is_train, "class_agnostic": args.class_agnostic}

    if net == "alexnet":
        model = alexnet(imdb_classes, **common)
    elif net in vgg_depths:
        model = vgg(imdb_classes, vgg_depths[net], **common)
    else:
        model = resnet(imdb_classes, resnet_depths[net], **common)

    model.create_architecture()
    return model
示例#2
0
    load_name = 'models/res101/saba_20171219_train/faster_rcnn_1_100_833.pth'
    result_dir = 'result_res101'

  elif args.net == 'res50':
    fasterRCNN = resnet(pascal_classes, 50, pretrained=False, class_agnostic=args.class_agnostic)
    
  elif args.net == 'res152':
    fasterRCNN = resnet(pascal_classes, 152, pretrained=False, class_agnostic=args.class_agnostic)

  elif args.net == 'res18':
    fasterRCNN = Resnet18(pascal_classes, 18, pretrained=False, class_agnostic=args.class_agnostic)
    load_name = 'models/res18/saba_20171219_train/faster_rcnn_1_100_833.pth'
    result_dir = 'result_res18'

  elif args.net == 'alexnet':
    fasterRCNN = alexnet(pascal_classes, pretrained=False, class_agnostic=args.class_agnostic)
    load_name = 'models/alexnet/saba_20171219_train/faster_rcnn_1_100_833.pth'
    result_dir = 'result_alexnet'

  elif args.net == 'inceptionv3':
    fasterRCNN = Inceptionv3(pascal_classes, pretrained=False, class_agnostic=args.class_agnostic)
    load_name = 'models/inceptionv3/saba_20171219_train/faster_rcnn_1_100_833.pth'
    result_dir = 'result_inceptionv3'
    cfg.TEST.MAX_SIZE = 600

  elif args.net == 'dense121':
    fasterRCNN = Dense121(pascal_classes, pretrained=False, class_agnostic=args.class_agnostic)
    load_name = 'models/dense121/saba_20171219_train/faster_rcnn_1_100_833.pth'
    result_dir = 'result_dense121'

  else:
        gt_boxes = gt_boxes.cuda()

    #  make variable
    im_data = Variable(im_data)
    im_info = Variable(im_info)
    num_boxes = Variable(num_boxes)
    gt_boxes = Variable(gt_boxes)

    if args.cuda:
        cfg.CUDA = True

    # initilize the network here.

    if args.s_net == 'alexnet':
        student_net = alexnet(imdb.classes,
                              pretrained=True,
                              class_agnostic=args.class_agnostic)
    else:
        print("student network is not defined")
        pdb.set_trace()

    if args.t_net == 'vgg16':
        teacher_net = vgg16(imdb.classes,
                            pretrained=False,
                            class_agnostic=args.class_agnostic,
                            teaching=True)
    elif args.t_net == 'res101':
        teacher_net = resnet(imdb.classes,
                             101,
                             pretrained=False,
                             class_agnostic=args.class_agnostic,