def get_model(args):
    """Build the VoteNet-based RefNet, optionally grafting pretrained detector weights.

    NOTE(review): another ``get_model`` is defined immediately after this one
    in the same file, so this definition is shadowed at import time.

    Args:
        args: parsed options. Attributes read here: ``use_multiview``,
            ``use_normal``, ``use_color``, ``no_height``, ``num_proposals``,
            ``no_lang_cls``, ``use_bidir``, ``no_reference``, ``batch_size``,
            ``use_pretrained`` and, when ``use_pretrained`` is truthy,
            ``no_detection``. ``use_pretrained`` names a folder under
            ``CONF.PATH.OUTPUT`` whose ``model_last.pth`` is loaded.

    Returns:
        The RefNet instance after ``.cuda()``.
    """
    # Input feature width: +128 if use_multiview, +3 if use_normal,
    # +3 if use_color, +1 unless no_height.
    feature_dim = (
        128 * int(args.use_multiview)
        + 3 * int(args.use_normal)
        + 3 * int(args.use_color)
        + int(not args.no_height)
    )

    # Constructor arguments common to the trainable model and to the donor
    # model that is only used to load pretrained weights.
    shared = dict(
        num_class=DC.num_class,
        num_heading_bin=DC.num_heading_bin,
        num_size_cluster=DC.num_size_cluster,
        mean_size_arr=DC.mean_size_arr,
        input_feature_dim=feature_dim,
        num_proposal=args.num_proposals,
        use_bidir=args.use_bidir,
    )

    # trainable model
    model = RefNet(
        use_lang_classifier=not args.no_lang_cls,
        no_reference=args.no_reference,
        batch_size=args.batch_size,
        **shared,
    )

    if args.use_pretrained:
        print("loading pretrained VoteNet...")
        donor = RefNet(no_reference=True, **shared)
        checkpoint = os.path.join(CONF.PATH.OUTPUT, args.use_pretrained, "model_last.pth")
        # strict=False: missing/unexpected keys in the checkpoint do not raise.
        donor.load_state_dict(torch.load(checkpoint), strict=False)

        # Replace the trainable model's detection submodules with the
        # donor's (now pretrained) ones.
        detection_stages = ("backbone_net", "vgen", "proposal")
        for stage in detection_stages:
            setattr(model, stage, getattr(donor, stage))

        if args.no_detection:
            # Freeze backbone, voting and proposal submodules.
            for stage in detection_stages:
                for param in getattr(model, stage).parameters():
                    param.requires_grad = False

    # to CUDA
    return model.cuda()
def get_model(args):
    """Build the PointGroup-based RefNet, optionally loading pretrained PointGroup weights.

    NOTE(review): this redefines the ``get_model`` that precedes it in the
    same file; this later definition is the one in effect after import.

    Args:
        args: parsed options. Attributes read here: ``use_multiview``,
            ``use_normal``, ``use_color``, ``no_height``, ``num_proposals``,
            ``no_lang_cls``, ``use_bidir``, ``no_reference``, ``batch_size``,
            ``fix_match_module_input``, ``use_pretrained`` and, when
            ``use_pretrained`` is truthy, ``no_pg``.

            ``use_pretrained`` is resolved in one of two ways:

            * ending in ``.pth``: a checkpoint path joined onto
              ``CONF.PATH.BASE`` (an absolute path is used unchanged, since
              ``os.path.join`` discards earlier components in that case);
            * otherwise: a folder under ``CONF.PATH.OUTPUT`` whose
              ``model_last.pth`` is loaded.

    Returns:
        The RefNet instance after ``.cuda()``.
    """
    # initiate model
    # Input feature width: +128 if use_multiview, +3 if use_normal,
    # +3 if use_color, +1 unless no_height.
    input_channels = (
        int(args.use_multiview) * 128
        + int(args.use_normal) * 3
        + int(args.use_color) * 3
        + int(not args.no_height)
    )
    model_fn = model_fn_decorator()
    model = RefNet(
        num_class=DC.num_class,
        num_heading_bin=DC.num_heading_bin,
        num_size_cluster=DC.num_size_cluster,
        mean_size_arr=DC.mean_size_arr,
        input_feature_dim=input_channels,
        num_proposal=args.num_proposals,
        use_lang_classifier=(not args.no_lang_cls),
        use_bidir=args.use_bidir,
        no_reference=args.no_reference,
        batch_size=args.batch_size,
        fix_match_module_input=args.fix_match_module_input,
        model_fn=model_fn,
    )

    # trainable model
    if args.use_pretrained:
        # load model
        print("loading pretrained PointGroup...")
        # Donor model: only its ``pointgroup`` submodule is kept below.
        pretrained_model = RefNet(
            num_class=DC.num_class,
            num_heading_bin=DC.num_heading_bin,
            num_size_cluster=DC.num_size_cluster,
            mean_size_arr=DC.mean_size_arr,
            num_proposal=args.num_proposals,
            input_feature_dim=input_channels,
            use_bidir=args.use_bidir,
            no_reference=True,
        )
        if args.use_pretrained.endswith(".pth"):
            # Explicit checkpoint file, relative to the project base directory.
            pretrained_path = os.path.join(CONF.PATH.BASE, args.use_pretrained)
        else:
            # Name of a previous run's output folder.
            pretrained_path = os.path.join(
                CONF.PATH.OUTPUT, args.use_pretrained, "model_last.pth"
            )
        # strict=False: missing/unexpected keys in the checkpoint do not raise.
        pretrained_model.load_state_dict(torch.load(pretrained_path), strict=False)

        # mount
        model.pointgroup = pretrained_model.pointgroup
        print("loaded pretrained PG model: ", pretrained_path)

        if args.no_pg:
            # freeze PG
            for param in model.pointgroup.parameters():
                param.requires_grad = False
            print("freezed pg params")

    # to CUDA
    model = model.cuda()
    return model