コード例 #1
0
def init_htcn_multi_adv_model(LA_ATT,
                              MID_ATT,
                              class_agnostic,
                              device,
                              gc,
                              imdb,
                              lc,
                              load_name,
                              net,
                              pretrained=True,
                              strict=True,
                              target_num=1):
    """Build an HTCN Faster R-CNN detector and optionally load weights into it.

    Args:
        LA_ATT: Forwarded to the backbone constructor as ``la_attention``.
        MID_ATT: Forwarded to the backbone constructor as ``mid_attention``.
        class_agnostic: Forwarded to the backbone constructor.
        device: Target passed to ``fasterRCNN.to(device)``.
        gc: Forwarded to the backbone constructor.
        imdb: Dataset object; only ``imdb.classes`` is read.
        lc: Forwarded to the backbone constructor.
        load_name: Checkpoint path. An empty string skips loading.
        net: Backbone name: ``'vgg16'``, ``'res101'`` or ``'res50'``.
        pretrained: Forwarded to the backbone constructor.
        strict: Passed to ``load_state_dict``.
        target_num: Forwarded to the backbone constructor.

    Returns:
        The constructed model (``create_architecture()`` already called),
        moved to ``device``.

    Raises:
        NotImplementedError: If ``net`` is not one of the supported backbones.

    Side effects:
        If the loaded checkpoint has a ``'pooling_mode'`` entry, it overwrites
        the global ``cfg.POOLING_MODE``.
    """
    # Constructor options shared by every supported backbone.
    common_kwargs = dict(pretrained=pretrained,
                         class_agnostic=class_agnostic,
                         lc=lc,
                         gc=gc,
                         la_attention=LA_ATT,
                         mid_attention=MID_ATT,
                         target_num=target_num)
    resnet_depths = {'res101': 101, 'res50': 50}

    if net == 'vgg16':
        fasterRCNN = vgg16(imdb.classes, **common_kwargs)
    elif net in resnet_depths:
        fasterRCNN = m_resnet(imdb.classes, resnet_depths[net], **common_kwargs)
    else:
        raise NotImplementedError("Not implemented for other architecture")

    fasterRCNN.create_architecture()
    fasterRCNN.to(device)

    if load_name != "":
        # Load onto the CPU first so a checkpoint saved on GPU can be read on
        # any host; load_state_dict then copies the values into the model's
        # parameters on `device`.
        checkpoint = torch.load(load_name, map_location='cpu')
        fasterRCNN.load_state_dict(checkpoint['model'], strict=strict)
        if 'pooling_mode' in checkpoint:
            cfg.POOLING_MODE = checkpoint['pooling_mode']
        print('Loading pretrained weight from {}'.format(load_name))
    return fasterRCNN
コード例 #2
0
def init_htcn_model_optimizer_with_od(alr,
                                      LA_ATT,
                                      MID_ATT,
                                      class_agnostic,
                                      device,
                                      gc,
                                      imdb,
                                      lc,
                                      load_name,
                                      net,
                                      optimizer,
                                      resume,
                                      session,
                                      start_epoch,
                                      teacher,
                                      distiller_fn,
                                      is_all_params=False):
    """Build an HTCN detector, its distiller and the optimizer covering both.

    Args:
        alr: Base learning rate for every parameter group.
        LA_ATT: Forwarded to the backbone constructor as ``la_attention``.
        MID_ATT: Forwarded to the backbone constructor as ``mid_attention``.
        class_agnostic: Forwarded to the backbone constructor.
        device: Target passed to ``.to(device)`` for model and distiller.
        gc: Forwarded to the backbone constructor.
        imdb: Dataset object; only ``imdb.classes`` is read.
        lc: Forwarded to the backbone constructor.
        load_name: Checkpoint path, read only when ``resume`` is truthy.
        net: Backbone name: ``'vgg16'``, ``'res101'``, ``'res50'`` or
            ``'res152'``.
        optimizer: Optimizer name, ``'adam'`` or ``'sgd'``.
        resume: If truthy, restore model/optimizer state from ``load_name``.
        session: Returned unchanged unless restored from the checkpoint.
        start_epoch: Returned unchanged unless restored from the checkpoint.
        teacher: First argument to ``distiller_fn``.
        distiller_fn: Called as ``distiller_fn(teacher, fasterRCNN)``. The
            result must provide ``get_parameters()`` and ``to()``.
        is_all_params: If True, also add parameters with
            ``requires_grad=False`` to the optimizer.

    Returns:
        Tuple ``(fasterRCNN, lr, optimizer, session, start_epoch, distill)``.

    Raises:
        NotImplementedError: If ``net`` is not a supported backbone.
        ValueError: If ``optimizer`` is not ``'adam'`` or ``'sgd'``.

    Side effects:
        On resume, a ``'pooling_mode'`` checkpoint entry overwrites the global
        ``cfg.POOLING_MODE``.
    """
    resnet_depths = {'res101': 101, 'res50': 50, 'res152': 152}

    # Validate cheap arguments before the (expensive) model construction.
    if net != 'vgg16' and net not in resnet_depths:
        raise NotImplementedError("Not implemented for other architecture")
    if optimizer not in ("adam", "sgd"):
        # Previously an unknown name was returned as-is in place of an
        # optimizer object and only failed later during training.
        raise ValueError(
            "Unsupported optimizer {!r}; expected 'adam' or 'sgd'".format(optimizer))

    if net == 'vgg16':
        fasterRCNN = vgg16(imdb.classes,
                           pretrained=True,
                           class_agnostic=class_agnostic,
                           lc=lc,
                           gc=gc,
                           la_attention=LA_ATT,
                           mid_attention=MID_ATT,
                           target_num=1)
    else:
        fasterRCNN = resnet(imdb.classes,
                            resnet_depths[net],
                            pretrained=True,
                            class_agnostic=class_agnostic,
                            lc=lc,
                            gc=gc,
                            la_attention=LA_ATT,
                            mid_attention=MID_ATT)
    fasterRCNN.create_architecture()
    distill = distiller_fn(teacher, fasterRCNN)

    lr = alr
    # cfg.TRAIN.BIAS_DECAY is an on/off flag: when set, bias-like groups use
    # the regular weight decay; otherwise they get none.
    bias_weight_decay = cfg.TRAIN.WEIGHT_DECAY if cfg.TRAIN.BIAS_DECAY else 0

    params = []
    for key, value in fasterRCNN.named_parameters():
        if not (value.requires_grad or is_all_params):
            continue
        if 'bias' in key:
            params.append({
                'params': [value],
                'lr': lr * (cfg.TRAIN.DOUBLE_BIAS + 1),
                'weight_decay': bias_weight_decay,
            })
        else:
            params.append({
                'params': [value],
                'lr': lr,
                'weight_decay': cfg.TRAIN.WEIGHT_DECAY,
            })

    # The distiller's own parameters. The flag was previously passed directly
    # as the decay coefficient, which meant a decay of 1.0 when BIAS_DECAY was
    # True. With the flag off the result is 0 either way.
    params.append({
        'params': distill.get_parameters(),
        'lr': lr,
        'weight_decay': bias_weight_decay,
    })

    if optimizer == "adam":
        # NOTE(review): the param groups above were already built with the
        # unscaled lr, so Adam actually trains with `alr`. The 0.1 factor only
        # changes the returned `lr` value. Kept as-is to preserve training
        # behavior; confirm which one is intended.
        lr = lr * 0.1
        optimizer = torch.optim.Adam(params)
    else:  # "sgd" (validated above)
        optimizer = torch.optim.SGD(params, momentum=cfg.TRAIN.MOMENTUM)

    fasterRCNN.to(device)
    distill.to(device)
    if resume:
        # CPU map_location lets GPU-saved checkpoints load anywhere.
        # load_state_dict (model and optimizer) moves values onto the
        # parameters' device.
        checkpoint = torch.load(load_name, map_location='cpu')
        session = checkpoint['session']
        start_epoch = checkpoint['epoch']
        fasterRCNN.load_state_dict(checkpoint['model'])
        # NOTE(review): the distiller's own weights are not restored here; only
        # its optimizer state (via the optimizer's param group) is.
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr = optimizer.param_groups[0]['lr']
        if 'pooling_mode' in checkpoint:
            cfg.POOLING_MODE = checkpoint['pooling_mode']
        print("loaded checkpoint %s" % (load_name))
    return fasterRCNN, lr, optimizer, session, start_epoch, distill
コード例 #3
0
ファイル: trainval_net_HTCN.py プロジェクト: yuzhouzhili/HTCN
    im_data = Variable(im_data)
    im_info = Variable(im_info)
    num_boxes = Variable(num_boxes)
    gt_boxes = Variable(gt_boxes)
    if args.cuda:
        cfg.CUDA = True

    # initilize the network here.
    from model.faster_rcnn.vgg16_HTCN import vgg16
    from model.faster_rcnn.resnet_HTCN import resnet

    if args.net == 'vgg16':
        fasterRCNN = vgg16(imdb.classes,
                           pretrained=True,
                           class_agnostic=args.class_agnostic,
                           lc=args.lc,
                           gc=args.gc,
                           la_attention=args.LA_ATT,
                           mid_attention=args.MID_ATT)
    elif args.net == 'res101':
        fasterRCNN = resnet(imdb.classes,
                            101,
                            pretrained=True,
                            class_agnostic=args.class_agnostic,
                            lc=args.lc,
                            gc=args.gc,
                            la_attention=args.LA_ATT,
                            mid_attention=args.MID_ATT)
    # elif args.net == 'res50':
    #     fasterRCNN = resnet(imdb.classes, 50, pretrained=True, class_agnostic=args.class_agnostic, context=args.context)