예제 #1
0
def get_model(args):
    """Build the RefNet model, optionally seeded with pretrained detection weights.

    The input feature width is assembled from the point-cloud feature flags on
    ``args``. When ``args.use_pretrained`` is set, a second RefNet is built
    with ``no_reference=True``. Its weights are loaded non-strictly from
    ``CONF.PATH.OUTPUT/<args.use_pretrained>/model_last.pth``, and its
    ``backbone_net``, ``vgen`` and ``proposal`` submodules are transplanted
    into the returned model. If ``args.no_detection`` is also set, the
    parameters of those transplanted submodules stop requiring gradients.

    Args:
        args: parsed command-line namespace. Attributes read here:
            use_multiview, use_normal, use_color, no_height, num_proposals,
            no_lang_cls, use_bidir, no_reference, batch_size,
            use_pretrained, no_detection.

    Returns:
        The RefNet instance after ``.cuda()``.
    """
    # Each enabled flag contributes this many input channels.
    feature_dim = (
        128 * int(args.use_multiview)
        + 3 * int(args.use_normal)
        + 3 * int(args.use_color)
        + int(not args.no_height)
    )

    model = RefNet(
        num_class=DC.num_class,
        num_heading_bin=DC.num_heading_bin,
        num_size_cluster=DC.num_size_cluster,
        mean_size_arr=DC.mean_size_arr,
        input_feature_dim=feature_dim,
        num_proposal=args.num_proposals,
        use_lang_classifier=not args.no_lang_cls,
        use_bidir=args.use_bidir,
        no_reference=args.no_reference,
        batch_size=args.batch_size,
    )

    if args.use_pretrained:
        print("loading pretrained VoteNet...")
        donor = RefNet(
            num_class=DC.num_class,
            num_heading_bin=DC.num_heading_bin,
            num_size_cluster=DC.num_size_cluster,
            mean_size_arr=DC.mean_size_arr,
            num_proposal=args.num_proposals,
            input_feature_dim=feature_dim,
            use_bidir=args.use_bidir,
            no_reference=True,
        )

        checkpoint_file = os.path.join(
            CONF.PATH.OUTPUT, args.use_pretrained, "model_last.pth"
        )
        # strict=False: keys missing from / unexpected in the checkpoint are
        # skipped instead of raising.
        donor.load_state_dict(torch.load(checkpoint_file), strict=False)

        # Transplant the PointNet++ backbone, the voting module and the
        # detector from the pretrained network.
        model.backbone_net, model.vgen, model.proposal = (
            donor.backbone_net,
            donor.vgen,
            donor.proposal,
        )

        if args.no_detection:
            # Freeze all three transplanted parts.
            for frozen_part in (model.backbone_net, model.vgen, model.proposal):
                for weight in frozen_part.parameters():
                    weight.requires_grad = False

    return model.cuda()
예제 #2
0
def get_model(args):
    """Build the PointGroup-based RefNet, optionally grafting in pretrained weights.

    ``model_fn_decorator()`` is called first and its result is handed to the
    RefNet constructor as ``model_fn``.

    When ``args.use_pretrained`` is set, it is interpreted in one of two ways:
      * a value ending in ``.pth`` is a checkpoint path joined onto
        ``CONF.PATH.BASE``;
      * anything else is joined onto ``CONF.PATH.OUTPUT`` together with
        ``model_last.pth``.
    A RefNet with ``no_reference=True`` is loaded (non-strictly) from that
    file, and only its ``pointgroup`` submodule is transplanted. With
    ``args.no_pg`` the transplanted parameters stop requiring gradients.

    Args:
        args: parsed command-line namespace. Attributes read here:
            use_multiview, use_normal, use_color, no_height, num_proposals,
            no_lang_cls, use_bidir, no_reference, batch_size,
            fix_match_module_input, use_pretrained, no_pg.

    Returns:
        The RefNet instance after ``.cuda()``.
    """
    # Each enabled flag contributes this many input channels.
    feature_dim = (
        128 * int(args.use_multiview)
        + 3 * int(args.use_normal)
        + 3 * int(args.use_color)
        + int(not args.no_height)
    )
    model_fn = model_fn_decorator()
    model = RefNet(
        num_class=DC.num_class,
        num_heading_bin=DC.num_heading_bin,
        num_size_cluster=DC.num_size_cluster,
        mean_size_arr=DC.mean_size_arr,
        input_feature_dim=feature_dim,
        num_proposal=args.num_proposals,
        use_lang_classifier=not args.no_lang_cls,
        use_bidir=args.use_bidir,
        no_reference=args.no_reference,
        batch_size=args.batch_size,
        fix_match_module_input=args.fix_match_module_input,
        model_fn=model_fn,
    )

    if not args.use_pretrained:
        return model.cuda()

    print("loading pretrained PointGroup...")
    donor = RefNet(
        num_class=DC.num_class,
        num_heading_bin=DC.num_heading_bin,
        num_size_cluster=DC.num_size_cluster,
        mean_size_arr=DC.mean_size_arr,
        num_proposal=args.num_proposals,
        input_feature_dim=feature_dim,
        use_bidir=args.use_bidir,
        no_reference=True,
    )

    if args.use_pretrained.endswith(".pth"):
        # An explicit checkpoint file, relative to the project base.
        checkpoint_file = os.path.join(CONF.PATH.BASE, args.use_pretrained)
    else:
        # An output sub-directory holding the last saved model.
        checkpoint_file = os.path.join(
            CONF.PATH.OUTPUT, args.use_pretrained, "model_last.pth"
        )
    # strict=False: keys missing from / unexpected in the checkpoint are skipped.
    donor.load_state_dict(torch.load(checkpoint_file), strict=False)

    model.pointgroup = donor.pointgroup
    print("loaded pretrained PG model: ", checkpoint_file)

    if args.no_pg:
        for weight in model.pointgroup.parameters():
            weight.requires_grad = False
        print("freezed pg params")

    return model.cuda()
예제 #3
0
def get_model(args, config):
    """Build RefNet from ``config``, wrap it in ``nn.DataParallel`` and load weights.

    Args:
        args: parsed command-line namespace. Attributes read here:
            use_multiview, use_normal, use_color, no_height, num_proposals,
            no_lang_cls, use_bidir, self_attn, devices, detection, folder.
            ``args.devices`` is an iterable of values convertible with
            ``int``; its first entry is used as the primary CUDA device.
        config: object supplying num_class, num_heading_bin,
            num_size_cluster and mean_size_arr.

    The checkpoint is ``CONF.PATH.BASE/<args.folder>/model_last.pth`` when
    ``args.detection`` is truthy, otherwise ``.../model.pth``. It is loaded
    into the DataParallel wrapper, so its keys are matched against the
    wrapper's ``module.``-prefixed parameter names. With ``strict=False``,
    non-matching keys are skipped silently.
    NOTE(review): a checkpoint saved from an unwrapped model would therefore
    load nothing; confirm how checkpoints are written.

    Returns:
        The ``nn.DataParallel``-wrapped model, switched to eval mode.
    """
    # Each enabled flag contributes this many input channels.
    input_channels = (
        int(args.use_multiview) * 128
        + int(args.use_normal) * 3
        + int(args.use_color) * 3
        + int(not args.no_height)
    )
    model = RefNet(
        num_class=config.num_class,
        num_heading_bin=config.num_heading_bin,
        num_size_cluster=config.num_size_cluster,
        mean_size_arr=config.mean_size_arr,
        num_proposal=args.num_proposals,
        input_feature_dim=input_channels,
        use_lang_classifier=(not args.no_lang_cls),
        use_bidir=args.use_bidir,
        attn=args.self_attn,
    )

    devices = [int(x) for x in args.devices]
    print("devices", devices, "torch.cuda.device_count()", torch.cuda.device_count())

    # nn.DataParallel requires the wrapped module's parameters and buffers to
    # live on device_ids[0]. A bare .cuda() places them on the current CUDA
    # device (cuda:0 unless changed), so forward() would raise whenever the
    # first requested device is not 0.
    primary_device = torch.device("cuda", devices[0])
    model = nn.DataParallel(model.to(primary_device), device_ids=devices)

    model_name = "model_last.pth" if args.detection else "model.pth"
    path = os.path.join(CONF.PATH.BASE, args.folder, model_name)
    # map_location makes loading independent of the GPU the checkpoint was
    # saved from (which may not exist on this machine).
    model.load_state_dict(torch.load(path, map_location=primary_device), strict=False)
    model.eval()

    return model
예제 #4
0
def get_model(args):
    """Instantiate a RefNet sized for the dataset config ``DC`` and move it to the GPU.

    Args:
        args: parsed command-line namespace. Attributes read here:
            use_multiview, use_normal, use_color, no_height, num_proposals,
            no_lang_cls.

    Returns:
        The freshly initialised RefNet after ``.cuda()``; no weights are loaded.
    """
    # Each enabled flag contributes this many input channels.
    feature_dim = (
        128 * int(args.use_multiview)
        + 3 * int(args.use_normal)
        + 3 * int(args.use_color)
        + int(not args.no_height)
    )
    net = RefNet(
        num_class=DC.num_class,
        num_heading_bin=DC.num_heading_bin,
        num_size_cluster=DC.num_size_cluster,
        mean_size_arr=DC.mean_size_arr,
        num_proposal=args.num_proposals,
        input_feature_dim=feature_dim,
        use_lang_classifier=not args.no_lang_cls,
    )
    return net.cuda()
예제 #5
0
def get_model(args):
    """Build a RefNet on the GPU, load ``model.pth`` from a run folder, and set eval mode.

    Args:
        args: parsed command-line namespace. Attributes read here:
            use_multiview, use_normal, use_color, no_height, num_proposals,
            folder. The checkpoint is read from
            ``CONF.PATH.OUTPUT/<args.folder>/model.pth``.

    Returns:
        The RefNet instance in eval mode.
    """
    # Each enabled flag contributes this many input channels.
    feature_dim = (
        128 * int(args.use_multiview)
        + 3 * int(args.use_normal)
        + 3 * int(args.use_color)
        + int(not args.no_height)
    )
    net = RefNet(
        num_class=DC.num_class,
        num_heading_bin=DC.num_heading_bin,
        num_size_cluster=DC.num_size_cluster,
        mean_size_arr=DC.mean_size_arr,
        num_proposal=args.num_proposals,
        input_feature_dim=feature_dim,
    ).cuda()

    checkpoint_file = os.path.join(CONF.PATH.OUTPUT, args.folder, "model.pth")
    saved_state = torch.load(checkpoint_file)
    # strict=False: keys missing from / unexpected in the checkpoint are skipped.
    net.load_state_dict(saved_state, strict=False)
    # Module.eval() returns the module itself.
    return net.eval()
예제 #6
0
def get_model(args, config):
    """Build RefNet from ``config`` on the GPU, load a saved checkpoint, and set eval mode.

    Args:
        args: parsed command-line namespace. Attributes read here:
            use_multiview, use_normal, use_color, no_height, num_proposals,
            no_lang_cls, use_bidir, detection, folder.
        config: object supplying num_class, num_heading_bin,
            num_size_cluster and mean_size_arr.

    The checkpoint is ``CONF.PATH.OUTPUT/<args.folder>/model_last.pth`` when
    ``args.detection`` is truthy, otherwise ``.../model.pth``.

    Returns:
        The RefNet instance in eval mode.
    """
    # Each enabled flag contributes this many input channels.
    feature_dim = (
        128 * int(args.use_multiview)
        + 3 * int(args.use_normal)
        + 3 * int(args.use_color)
        + int(not args.no_height)
    )
    net = RefNet(
        num_class=config.num_class,
        num_heading_bin=config.num_heading_bin,
        num_size_cluster=config.num_size_cluster,
        mean_size_arr=config.mean_size_arr,
        num_proposal=args.num_proposals,
        input_feature_dim=feature_dim,
        use_lang_classifier=not args.no_lang_cls,
        use_bidir=args.use_bidir,
    ).cuda()

    if args.detection:
        weights_name = "model_last.pth"
    else:
        weights_name = "model.pth"
    checkpoint_file = os.path.join(CONF.PATH.OUTPUT, args.folder, weights_name)
    # strict=False: keys missing from / unexpected in the checkpoint are skipped.
    net.load_state_dict(torch.load(checkpoint_file), strict=False)
    # Module.eval() returns the module itself.
    return net.eval()
예제 #7
0
def get_model(args):
    """Build RefNet (with self-attention option), optionally seed it, and wrap it in DataParallel.

    When ``args.use_pretrained`` is set, a RefNet with ``no_reference=True``
    is loaded non-strictly from ``<args.use_pretrained>/model_last.pth``.
    Its ``backbone_net``, ``vgen`` and ``proposal`` submodules are then
    transplanted into the new model. With ``args.no_detection`` the
    parameters of those submodules stop requiring gradients.

    Args:
        args: parsed command-line namespace. Attributes read here:
            use_multiview, use_normal, use_color, no_height, num_proposals,
            no_lang_cls, use_bidir, no_reference, self_attn, use_pretrained,
            no_detection, devices. ``args.devices`` is an iterable of values
            convertible with ``int``; its first entry is the primary CUDA
            device.

    Returns:
        The model wrapped in ``nn.DataParallel``. Without CUDA the model stays
        on the CPU, and DataParallel then simply forwards to it.
    """
    # Each enabled flag contributes this many input channels.
    input_channels = (
        int(args.use_multiview) * 128
        + int(args.use_normal) * 3
        + int(args.use_color) * 3
        + int(not args.no_height)
    )
    model = RefNet(
        num_class=DC.num_class,
        num_heading_bin=DC.num_heading_bin,
        num_size_cluster=DC.num_size_cluster,
        mean_size_arr=DC.mean_size_arr,
        input_feature_dim=input_channels,
        num_proposal=args.num_proposals,
        use_lang_classifier=(not args.no_lang_cls),
        use_bidir=args.use_bidir,
        no_reference=args.no_reference,
        attn=args.self_attn,
    )

    # trainable model
    if args.use_pretrained:
        print("loading pretrained VoteNet...")
        pretrained_model = RefNet(
            num_class=DC.num_class,
            num_heading_bin=DC.num_heading_bin,
            num_size_cluster=DC.num_size_cluster,
            mean_size_arr=DC.mean_size_arr,
            num_proposal=args.num_proposals,
            input_feature_dim=input_channels,
            use_bidir=args.use_bidir,
            no_reference=True,
            attn=args.self_attn,
        )

        pretrained_path = os.path.join(args.use_pretrained, "model_last.pth")
        # strict=False: keys missing from / unexpected in the checkpoint are skipped.
        pretrained_model.load_state_dict(torch.load(pretrained_path), strict=False)

        # mount the pretrained backbone, voting module and detector
        model.backbone_net = pretrained_model.backbone_net
        model.vgen = pretrained_model.vgen
        model.proposal = pretrained_model.proposal

        if args.no_detection:
            # freeze pointnet++ backbone, voting and detector
            for module in (model.backbone_net, model.vgen, model.proposal):
                for param in module.parameters():
                    param.requires_grad = False

    devices = [int(x) for x in args.devices]
    if torch.cuda.is_available():
        # nn.DataParallel requires the wrapped module's parameters and buffers
        # to live on device_ids[0]; a hard-coded cuda:0 made forward() raise
        # whenever the first requested device was not 0.
        device = torch.device("cuda", devices[0])
    else:
        device = torch.device("cpu")
    model = model.to(device)
    print("devices", devices, "torch.cuda.device_count()", torch.cuda.device_count())
    model = nn.DataParallel(model, device_ids=devices)

    return model