Code example #1
0
def main(args):
    """Run HOI detection on the V-COCO validation split and pickle the results.

    Builds the DETRHOI model from ``args``, loads weights from
    ``args.param_path``, generates detections over the validation set, and
    dumps them (pickle protocol 2) to ``args.save_path``.
    """
    print("git:\n  {}\n".format(utils.get_sha()))
    print(args)

    # Object category ids that actually occur in the annotations; the gaps
    # (12, 26, 29, 30, ...) are ids unused by the dataset.
    valid_obj_ids = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18,
                     19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35,
                     36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50,
                     51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
                     65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82,
                     84, 85, 86, 87, 88, 89, 90)

    # Verb (interaction) class names; the _obj/_instr suffixes distinguish
    # the role of the object in the interaction.
    verb_classes = [
        'hold_obj', 'stand', 'sit_instr', 'ride_instr', 'walk', 'look_obj',
        'hit_instr', 'hit_obj', 'eat_obj', 'eat_instr', 'jump_instr',
        'lay_instr', 'talk_on_phone_instr', 'carry_obj', 'throw_obj',
        'catch_obj', 'cut_instr', 'cut_obj', 'run', 'work_on_computer_instr',
        'ski_instr', 'surf_instr', 'skateboard_instr', 'smile', 'drink_instr',
        'kick_obj', 'point_instr', 'read_obj', 'snowboard_instr'
    ]

    device = torch.device(args.device)

    # Validation data: deterministic order, no dropped tail.
    val_dataset = build_dataset(image_set='val', args=args)
    val_sampler = torch.utils.data.SequentialSampler(val_dataset)
    val_loader = DataLoader(val_dataset,
                            args.batch_size,
                            sampler=val_sampler,
                            drop_last=False,
                            collate_fn=utils.collate_fn,
                            num_workers=args.num_workers)

    # Inference-only configuration: frozen backbone, no mask head.
    args.lr_backbone = 0
    args.masks = False

    model = DETRHOI(build_backbone(args), build_transformer(args),
                    len(valid_obj_ids) + 1, len(verb_classes),
                    args.num_queries)
    postprocessor = PostProcessHOI(args.num_queries, args.subject_category_id,
                                   val_dataset.correct_mat)
    model.to(device)
    postprocessor.to(device)

    # Restore trained weights (always loaded onto CPU first, then moved).
    ckpt = torch.load(args.param_path, map_location='cpu')
    model.load_state_dict(ckpt['model'])

    detections = generate(model, postprocessor, val_loader, device,
                          verb_classes, args.missing_category_id)

    # Protocol 2 keeps the pickle readable from Python 2 evaluation tools.
    with open(args.save_path, 'wb') as f:
        pickle.dump(detections, f, protocol=2)
Code example #2
0
File: detr.py — Project: whq-hqw/detr_change
def build(args):
    """Assemble the DETR model, its training criterion, and post-processors.

    Returns:
        (model, criterion, postprocessors) where ``postprocessors`` maps
        output kind ('bbox', optionally 'segm'/'panoptic') to a module.
    """
    # Number of object classes depends on the dataset variant.
    if args.dataset_file == "coco_panoptic":
        num_classes = 250
    elif args.dataset_file == 'coco':
        num_classes = 91
    else:
        num_classes = 20
    device = torch.device(args.device)

    backbone = build_backbone(args)
    transformer = build_transformer(args)
    model = DETR(
        args,
        backbone,
        transformer,
        num_classes=num_classes,
        num_queries=args.num_queries,
        aux_loss=args.aux_loss,
    )
    # Optional segmentation head; the detector is frozen when fine-tuning
    # from pre-trained detection weights.
    if args.masks:
        model = DETRsegm(model, freeze_detr=(args.frozen_weights is not None))

    matcher = build_matcher(args)

    # Per-loss weights used by SetCriterion.
    weight_dict = {
        'loss_ce': 1,
        'loss_bbox': args.bbox_loss_coef,
        'loss_giou': args.giou_loss_coef,
    }
    if args.masks:
        weight_dict["loss_mask"] = args.mask_loss_coef
        weight_dict["loss_dice"] = args.dice_loss_coef
    # TODO this is a hack: replicate every weight once per auxiliary decoder
    # layer, suffixing the key with the layer index.
    if args.aux_loss:
        weight_dict.update({
            f'{key}_{layer}': coef
            for layer in range(args.dec_layers - 1)
            for key, coef in weight_dict.items()
        })

    losses = ['labels', 'boxes', 'cardinality'] + (['masks'] if args.masks else [])
    criterion = SetCriterion(num_classes,
                             matcher=matcher,
                             weight_dict=weight_dict,
                             eos_coef=args.eos_coef,
                             losses=losses)
    criterion.to(device)

    postprocessors = {'bbox': PostProcess()}
    if args.masks:
        postprocessors['segm'] = PostProcessSegm()
        if args.dataset_file == "coco_panoptic":
            # COCO convention: category ids <= 90 are "things".
            is_thing_map = {i: i <= 90 for i in range(201)}
            postprocessors["panoptic"] = PostProcessPanoptic(is_thing_map,
                                                             threshold=0.85)

    return model, criterion, postprocessors
Code example #3
0
File: pose_transformer.py — Project: sheeranshan/PRTR
def get_pose_net(cfg, is_train, **kwargs):
    """Build a PoseTransformer from the config node ``cfg``.

    Transformer hyper-parameters come from ``cfg.MODEL.EXTRA``; pretrained
    backbone weights are loaded only when training with INIT_WEIGHTS set.
    """
    extra = cfg.MODEL.EXTRA

    transformer = build_transformer(
        hidden_dim=extra.HIDDEN_DIM,
        dropout=extra.DROPOUT,
        nheads=extra.NHEADS,
        dim_feedforward=extra.DIM_FEEDFORWARD,
        enc_layers=extra.ENC_LAYERS,
        dec_layers=extra.DEC_LAYERS,
        pre_norm=extra.PRE_NORM,
    )

    use_pretrained = is_train and cfg.MODEL.INIT_WEIGHTS
    backbone = build_backbone(cfg, use_pretrained)

    return PoseTransformer(cfg, backbone, transformer, **kwargs)
Code example #4
0
        # Smoke-test of the Paddle-dygraph DETR port: feed zero tensors through
        # backbone, transformer, and the full detector, printing shapes.
        # NOTE(review): `fake_image` and `args` are defined above this chunk —
        # assumed to be a [4, 3, 512, 512] image batch and the build config.
        # FIX: `np.bool` was deprecated in NumPy 1.20 and removed in 1.24;
        # use the "bool" dtype string (consistent with the mask below).
        mask = dg.to_variable(np.zeros([4, 512, 512], dtype="bool"))
        fake_data = NestedTensor(fake_image, mask)

        # Dump the backbone's parameter shapes for inspection.
        for k, v in backbone.state_dict().items():
            print(k + ': ' + str(v.shape))

        out, pos = backbone(fake_data)

        for feature_map in out:
            print(feature_map.tensors.shape)  # [4, 2048, 16, 16]
            print(feature_map.mask.shape)  # [4, 16, 16]

        for pos_tensor in pos:
            print(pos_tensor.shape)  # [4, 256, 16, 16]

        # Drive the transformer directly with zero inputs shaped like the
        # backbone outputs (256-dim features on a 16x16 grid, 100 queries).
        transformer = build_transformer(args)
        features = dg.to_variable(np.zeros([4, 256, 16, 16], dtype="float32"))
        mask = dg.to_variable(np.zeros([4, 16, 16], dtype="bool"))
        query_embed = dg.to_variable(np.zeros([100, 256], dtype="float32"))
        pos_embed = dg.to_variable(np.zeros([4, 256, 16, 16], dtype="float32"))

        hs, memory = transformer(features, mask, query_embed, pos_embed)
        print(hs.shape)  # [6, 4, 100, 256]
        print(memory.shape)  # [4, 256, 16, 16]

        # End-to-end: build the full model and print its (list-valued) outputs.
        detr, criterion, postprocessors = build(args)
        out = detr(fake_data)
        for name, tensor in out.items():
            if isinstance(tensor, list):
                print(name)
                print()