import torch

# Import paths follow the standard HTCN / faster-rcnn.pytorch layout; the
# module providing the multi-target resnet (m_resnet) is an assumption —
# point it at the actual multi-target variant in this repo.
from model.utils.config import cfg
from model.faster_rcnn.vgg16_HTCN import vgg16
from model.faster_rcnn.resnet_HTCN import resnet
from model.faster_rcnn.resnet_multi_HTCN import resnet as m_resnet


def init_htcn_multi_adv_model(LA_ATT, MID_ATT, class_agnostic, device, gc, imdb, lc,
                              load_name, net, pretrained=True, strict=True, target_num=1):
    """Build an HTCN multi-adversarial model and optionally load a checkpoint."""
    if net == 'vgg16':
        fasterRCNN = vgg16(imdb.classes, pretrained=pretrained, class_agnostic=class_agnostic,
                           lc=lc, gc=gc, la_attention=LA_ATT, mid_attention=MID_ATT,
                           target_num=target_num)
    elif net == 'res101':
        fasterRCNN = m_resnet(imdb.classes, 101, pretrained=pretrained,
                              class_agnostic=class_agnostic, lc=lc, gc=gc,
                              la_attention=LA_ATT, mid_attention=MID_ATT,
                              target_num=target_num)
    elif net == 'res50':
        fasterRCNN = m_resnet(imdb.classes, 50, pretrained=pretrained,
                              class_agnostic=class_agnostic, lc=lc, gc=gc,
                              la_attention=LA_ATT, mid_attention=MID_ATT,
                              target_num=target_num)
    else:
        raise NotImplementedError("Not implemented for other architectures")

    fasterRCNN.create_architecture()
    fasterRCNN.to(device)

    # An empty load_name skips checkpoint loading and keeps the fresh weights.
    if load_name != "":
        checkpoint = torch.load(load_name)
        fasterRCNN.load_state_dict(checkpoint['model'], strict=strict)
        if 'pooling_mode' in checkpoint:
            cfg.POOLING_MODE = checkpoint['pooling_mode']
        print('Loaded pretrained weights from {}'.format(load_name))

    return fasterRCNN
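
# Hedged usage sketch (not from the original source): how the builder above
# might be called to load a trained multi-adversarial model for evaluation.
# The flag values and the vgg16 backbone choice are illustrative assumptions.
def build_eval_model(imdb, load_name, device):
    model = init_htcn_multi_adv_model(
        LA_ATT=True, MID_ATT=True, class_agnostic=False, device=device,
        gc=True, imdb=imdb, lc=True, load_name=load_name, net='vgg16',
        pretrained=True, strict=True, target_num=1)
    model.eval()  # switch off dropout / batch-norm updates for inference
    return model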
def init_htcn_model_optimizer_with_od(alr, LA_ATT, MID_ATT, class_agnostic, device, gc, imdb,
                                      lc, load_name, net, optimizer, resume, session,
                                      start_epoch, teacher, distiller_fn, is_all_params=False):
    """Build an HTCN model plus a teacher-student distiller and their optimizer."""
    if net == 'vgg16':
        fasterRCNN = vgg16(imdb.classes, pretrained=True, class_agnostic=class_agnostic,
                           lc=lc, gc=gc, la_attention=LA_ATT, mid_attention=MID_ATT,
                           target_num=1)
    elif net == 'res101':
        fasterRCNN = resnet(imdb.classes, 101, pretrained=True, class_agnostic=class_agnostic,
                            lc=lc, gc=gc, la_attention=LA_ATT, mid_attention=MID_ATT)
    elif net == 'res50':
        fasterRCNN = resnet(imdb.classes, 50, pretrained=True, class_agnostic=class_agnostic,
                            lc=lc, gc=gc, la_attention=LA_ATT, mid_attention=MID_ATT)
    elif net == 'res152':
        fasterRCNN = resnet(imdb.classes, 152, pretrained=True, class_agnostic=class_agnostic,
                            lc=lc, gc=gc, la_attention=LA_ATT, mid_attention=MID_ATT)
    else:
        raise NotImplementedError("Not implemented for other architectures")

    fasterRCNN.create_architecture()
    distill = distiller_fn(teacher, fasterRCNN)

    lr = alr  # caller-supplied learning rate overrides cfg.TRAIN.LEARNING_RATE

    # Build per-parameter optimizer groups: biases optionally get a doubled
    # learning rate and no weight decay, per the standard Faster R-CNN cfg.
    params = []
    for key, value in dict(fasterRCNN.named_parameters()).items():
        if value.requires_grad or is_all_params:
            if 'bias' in key:
                params += [{'params': [value],
                            'lr': lr * (cfg.TRAIN.DOUBLE_BIAS + 1),
                            'weight_decay': cfg.TRAIN.WEIGHT_DECAY if cfg.TRAIN.BIAS_DECAY else 0}]
            else:
                params += [{'params': [value],
                            'lr': lr,
                            'weight_decay': cfg.TRAIN.WEIGHT_DECAY}]
    # Note: the original code passes cfg.TRAIN.BIAS_DECAY (a flag) as the
    # decay value for the distiller parameters; kept as-is to preserve behavior.
    params += [{'params': distill.get_parameters(), 'lr': lr,
                'weight_decay': cfg.TRAIN.BIAS_DECAY}]

    if optimizer == "adam":
        lr = lr * 0.1  # scale the returned bookkeeping lr for Adam
        optimizer = torch.optim.Adam(params)
    elif optimizer == "sgd":
        optimizer = torch.optim.SGD(params, momentum=cfg.TRAIN.MOMENTUM)

    fasterRCNN.to(device)
    distill.to(device)

    if resume:
        checkpoint = torch.load(load_name)
        session = checkpoint['session']
        start_epoch = checkpoint['epoch']
        fasterRCNN.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr = optimizer.param_groups[0]['lr']
        if 'pooling_mode' in checkpoint:
            cfg.POOLING_MODE = checkpoint['pooling_mode']
        print("loaded checkpoint %s" % load_name)

    return fasterRCNN, lr, optimizer, session, start_epoch, distill
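
# Hedged sketch (assumption, mirroring the usual Faster R-CNN training loop):
# the optimizer returned above holds per-parameter groups, so learning-rate
# decay must scale every group. gamma=0.1 is an illustrative default, not a
# value taken from this file.
def decay_learning_rate(optimizer, gamma=0.1):
    for param_group in optimizer.param_groups:
        param_group['lr'] *= gamma  # bias groups keep their DOUBLE_BIAS ratio
    return optimizer.param_groups[0]['lr']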
im_data = Variable(im_data)
im_info = Variable(im_info)
num_boxes = Variable(num_boxes)
gt_boxes = Variable(gt_boxes)

if args.cuda:
    cfg.CUDA = True

# Initialize the network here.
from model.faster_rcnn.vgg16_HTCN import vgg16
from model.faster_rcnn.resnet_HTCN import resnet

if args.net == 'vgg16':
    fasterRCNN = vgg16(imdb.classes, pretrained=True, class_agnostic=args.class_agnostic,
                       lc=args.lc, gc=args.gc, la_attention=args.LA_ATT,
                       mid_attention=args.MID_ATT)
elif args.net == 'res101':
    fasterRCNN = resnet(imdb.classes, 101, pretrained=True, class_agnostic=args.class_agnostic,
                        lc=args.lc, gc=args.gc, la_attention=args.LA_ATT,
                        mid_attention=args.MID_ATT)
# elif args.net == 'res50':
#     fasterRCNN = resnet(imdb.classes, 50, pretrained=True,
#                         class_agnostic=args.class_agnostic, context=args.context)
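
# Hedged sketch (assumption): in HTCN-style training scripts the Variable
# holders created above are reused every iteration by copying each batch into
# them in place. The dataloader's 4-tuple layout (image, im_info, gt_boxes,
# num_boxes) is the conventional one in this codebase, stated as an assumption.
def load_batch(data_iter, im_data, im_info, gt_boxes, num_boxes):
    data = next(data_iter)
    im_data.data.resize_(data[0].size()).copy_(data[0])
    im_info.data.resize_(data[1].size()).copy_(data[1])
    gt_boxes.data.resize_(data[2].size()).copy_(data[2])
    num_boxes.data.resize_(data[3].size()).copy_(data[3])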