Esempio n. 1
0
 def forward(self, mid, ref):
     """For every pixel of ``mid``, gather the ``self.K`` best-matching ``ref`` features.

     Both inputs are L2-normalised along axis 1. ``compute_cost_volume`` then
     scores every displacement in a ``(2d+1) x (2d+1)`` window
     (``d = self.d``). The ``self.K`` highest-scoring displacements per pixel
     are used to gather features from the padded ``ref`` that
     ``compute_cost_volume`` returns. The K gathered maps are concatenated
     along axis 1 and passed through ``self.conv``.

     Args:
         mid: 4-D tensor whose shape is unpacked as (B, C, H, W).
         ref: tensor matched against ``mid``. It presumably has the same
             (B, C, H, W) shape (TODO confirm with callers).

     Returns:
         ``self.conv`` applied to the (B, K*C, H, W) concatenation of the
         gathered feature maps.
     """
     B, C, H, W = mid.shape
     win = 2 * self.d + 1  # side length of the displacement search window
     # compute_cost_volume returns ref padded by d on every side (see the
     # reshape below), so a flattened row of it is W + 2d elements wide.
     padded_h = H + 2 * self.d
     padded_w = W + 2 * self.d

     mid = F.normalize(mid, p=2, axis=1)
     ref = F.normalize(ref, p=2, axis=1)
     cost_volume, ref = compute_cost_volume(
         mid, ref, max_displacement=self.d)  # [B, (2d+1)**2, H, W]
     # Move the displacement axis last so each row holds one pixel's scores.
     cost_volume = F.dimshuffle(cost_volume, (0, 2, 3, 1))
     cost_volume = cost_volume.reshape((-1, win**2))  # [B*H*W, (2d+1)**2]
     # Indices of the K best displacements for every pixel.
     indices = F.top_k(cost_volume, k=self.K,
                       descending=True)[1]  # [B*H*W, K]
     del cost_volume  # release the large cost volume early

     # Row / column of every output pixel (F.arange yields float32).
     origin_i_j = F.arange(0, H * W, 1)
     origin_i = F.floor(origin_i_j / W)  # (H*W, )
     origin_j = F.mod(origin_i_j, W)  # (H*W, )
     del origin_i_j

     # Flatten the spatial dims of the padded ref so F.gather can index them.
     ref = ref.reshape((B, C, padded_h * padded_w))

     ref_list = []  # K tensors of shape [B, C, H, W]
     for i in range(self.K):
         index = indices[:, i].reshape((-1, H * W))  # [B, H*W]
         # Displacement id -> (row, col) inside the window. Adding the pixel's
         # own (row, col) gives its match position in padded coordinates.
         # Assumes compute_cost_volume orders displacements row-major
         # (TODO confirm).
         index_i = F.floor(index / win) + origin_i  # [B, H*W]
         index_j = F.mod(index, win) + origin_j  # [B, H*W]
         # Compute the flat index from each pixel's (i, j). The stride must be
         # the padded row width, not W, or the wrong features are gathered.
         index = (index_i * padded_w + index_j).astype('int32')  # [B, H*W]
         # Expand to [B, C, H*W] so every channel uses the same spatial index.
         index = F.add_axis(index, axis=1)  # [B, 1, H*W]
         index = F.broadcast_to(index, (B, C, H * W))
         output = F.gather(ref, axis=2, index=index)  # [B, C, H*W]
         ref_list.append(output.reshape((B, C, H, W)))
     return self.conv(F.concat(ref_list, axis=1))
Esempio n. 2
0
def main():
    """Classify one image with a model from ``M`` and print the top-5 classes.

    Command-line options:
      -a/--arch     name of a model constructor looked up on ``M``.
      -m/--model    optional checkpoint path. When omitted, the constructor is
                    called with ``pretrained=True``; when given, the checkpoint
                    is loaded with ``strict=False``.
      -i/--image    image to classify (defaults to the bundled cat picture).
      --quantized   run ``quantize(model)`` before inference.

    The asset paths are relative to the current working directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-a",
                        "--arch",
                        default="shufflenet_v1_x0_5_g3",
                        type=str)
    parser.add_argument("-m", "--model", default=None, type=str)
    parser.add_argument("-i", "--image", default=None, type=str)
    parser.add_argument("--quantized",
                        action="store_true",
                        help="inference by quantized model, cpu only")
    args = parser.parse_args()

    model = getattr(M, args.arch)(pretrained=(args.model is None))

    if args.model:
        state_dict = mge.load(args.model)
        model.load_state_dict(state_dict, strict=False)

    if args.quantized:
        quantize(model)

    if args.image is None:
        path = "../../../assets/cat.jpg"
    else:
        path = args.image
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    # cv2.imread reports failure (missing file, bad format, ...) by returning
    # None instead of raising. Fail here with a clear message rather than
    # deep inside the transform pipeline.
    if image is None:
        parser.error("could not read image: {}".format(path))

    transform = T.Compose([
        T.Resize(256),
        T.CenterCrop(224),
        T.Normalize(mean=[128.0, 128.0, 128.0], std=[1.0, 1.0, 1.0]),  # BGR
        T.ToMode("CHW"),
    ])

    @jit.trace(symbolic=False)
    def infer_func(processed_img):
        model.eval()
        logits = model(processed_img)
        probs = F.softmax(logits)
        return probs

    # Add a leading batch axis of size 1.
    processed_img = transform.apply(image)[np.newaxis, :]
    probs = infer_func(processed_img)

    top_probs, classes = F.top_k(probs, k=5, descending=True)

    with open("../../../assets/imagenet_class_info.json",
              encoding="utf-8") as fp:
        imagenet_class_index = json.load(fp)

    for rank, (prob, classid) in enumerate(
            zip(top_probs.numpy().reshape(-1),
                classes.numpy().reshape(-1))):
        print("{}: class = {:20s} with probability = {:4.1f} %".format(
            rank, imagenet_class_index[str(classid)][1], 100 * prob))
Esempio n. 3
0
def main():
    """Classify one image with a model from ``M`` and print the top-5 classes.

    Command-line options:
      -a/--arch     name of a model constructor looked up on ``M``.
      -m/--model    optional checkpoint path. When omitted, the constructor is
                    called with ``pretrained=True``.
      -i/--image    image to classify (defaults to the bundled cat picture).

    The asset paths are relative to the current working directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-a", "--arch", default="resnet50_frelu", type=str)
    parser.add_argument("-m", "--model", default=None, type=str)
    parser.add_argument("-i", "--image", default=None, type=str)
    args = parser.parse_args()

    model = getattr(M, args.arch)(pretrained=(args.model is None))
    if args.model:
        state_dict = mge.load(args.model)
        model.load_state_dict(state_dict)

    if args.image is None:
        path = "../../../assets/cat.jpg"  # please find the files in https://github.com/MegEngine/Models/tree/master/official/assets
    else:
        path = args.image
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    # cv2.imread reports failure (missing file, bad format, ...) by returning
    # None instead of raising. Fail here with a clear message rather than
    # deep inside the transform pipeline.
    if image is None:
        parser.error("could not read image: {}".format(path))

    transform = T.Compose([
        T.Resize(256),
        T.CenterCrop(224),
        T.Normalize(mean=[103.530, 116.280, 123.675],
                    std=[57.375, 57.120, 58.395]),  # BGR
        T.ToMode("CHW"),
    ])

    @jit.trace(symbolic=True)
    def infer_func(processed_img):
        model.eval()
        logits = model(processed_img)
        probs = F.softmax(logits)
        return probs

    # Add a leading batch axis of size 1.
    processed_img = transform.apply(image)[np.newaxis, :]
    probs = infer_func(processed_img)

    top_probs, classes = F.top_k(probs, k=5, descending=True)

    with open(
            "../../../assets/imagenet_class_info.json", encoding="utf-8"
    ) as fp:  # please find the files in https://github.com/MegEngine/Models/tree/master/official/assets
        imagenet_class_index = json.load(fp)

    for rank, (prob, classid) in enumerate(
            zip(top_probs.numpy().reshape(-1),
                classes.numpy().reshape(-1))):
        print("{}: class = {:20s} with probability = {:4.1f} %".format(
            rank, imagenet_class_index[str(classid)][1], 100 * prob))
Esempio n. 4
0
    def calc_loss(self):
        """Return the training loss for the batch stored in ``self.inputs``.

        ``self.forward`` is run on ``self.inputs["image"]``. Every stage
        output then contributes two terms:

        * for each element except the last, ``F.square_loss`` against the
          matching ``self.inputs["heatmap"][:, idx]``, whose entries are zeroed
          where ``heat_valid <= 1.1``, scaled by ``1 / 4 / len(outs)``;
        * for the last element, an OHKM term. Squared errors against
          ``heatmap[:, -1]`` are averaged over axes 3 and 2 and zeroed where
          ``heat_valid <= 0.1``. The ``self.keypoint_num // 2`` largest entries
          of each sample are averaged, and that value is averaged over the batch.
        """
        outs = self.forward(self.inputs["image"])
        heatmaps = self.inputs["heatmap"]
        heat_valid = self.inputs["heat_valid"]
        stage_count = len(outs)
        hardest_count = self.keypoint_num // 2
        # Only entries with heat_valid above 1.1 are kept in the L2 targets.
        target_mask = (heat_valid > 1.1)[:, :, None, None]

        total = 0
        for stage_out in outs:
            for scale_idx, prediction in enumerate(stage_out[:-1]):
                target = heatmaps[:, scale_idx] * target_mask
                # NOTE(review): the constant 4 is hard-coded rather than
                # derived from len(stage_out); confirm it is intentional.
                total += F.square_loss(prediction, target) / 4 / stage_count

            # OHKM loss for the largest heatmap: per-channel (presumably
            # per-keypoint) mean squared error, masked by heat_valid > 0.1.
            sq_err = (stage_out[-1] - heatmaps[:, -1])**2
            per_channel = sq_err.mean(3).mean(2) * (heat_valid > 0.1)
            batch_size = per_channel.shape[0]
            hard_sum = 0
            for sample_idx in range(batch_size):
                hardest, _ = F.top_k(per_channel[sample_idx],
                                     hardest_count,
                                     descending=True)
                hard_sum += hardest.mean()
            total += hard_sum / batch_size
        return total
Esempio n. 5
0
def main():
    """Classify one image in float, QAT or int8-quantized mode and print the top-5.

    Command-line options:
      -a/--arch        key into ``models.__dict__`` selecting the constructor.
      -c/--checkpoint  optional weights file. A top-level "state_dict" entry is
                       used when present, and weights load with ``strict=False``.
      -i/--image       image to classify (defaults to ../assets/cat.jpg).
      -m/--mode        "normal", "qat" or "quantized" (default).
      --dump           after inference, write the traced graph to
                       ``<arch>.<mode>.megengine`` and the model weights to
                       ``<arch>.<mode>.pkl``.

    The asset paths are relative to the current working directory.
    """
    # RawTextHelpFormatter keeps the explicit "\n" breaks in the --mode help;
    # the default formatter collapses them into a single line.
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("-a", "--arch", default="resnet18", type=str)
    parser.add_argument("-c", "--checkpoint", default=None, type=str)
    parser.add_argument("-i", "--image", default=None, type=str)

    parser.add_argument(
        "-m",
        "--mode",
        default="quantized",
        type=str,
        choices=["normal", "qat", "quantized"],
        help="Quantization Mode\n"
        "normal: no quantization, using float32\n"
        "qat: quantization aware training, simulate int8\n"
        "quantized: convert model to int8 quantized, inference only")
    parser.add_argument("--dump",
                        action="store_true",
                        help="Dump quantized model")
    args = parser.parse_args()

    if args.mode == "quantized":
        # Quantized mode runs on the "cpux" device.
        mge.set_default_device("cpux")

    model = models.__dict__[args.arch]()

    # Both "qat" and "quantized" wrap the model for QAT before the checkpoint
    # is loaded; "quantized" additionally converts it after loading.
    if args.mode != "normal":
        Q.quantize_qat(model, Q.ema_fakequant_qconfig)

    if args.checkpoint:
        logger.info("Load pretrained weights from %s", args.checkpoint)
        ckpt = mge.load(args.checkpoint)
        ckpt = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
        model.load_state_dict(ckpt, strict=False)

    if args.mode == "quantized":
        Q.quantize(model)

    if args.image is None:
        path = "../assets/cat.jpg"
    else:
        path = args.image
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    # cv2.imread reports failure (missing file, bad format, ...) by returning
    # None instead of raising. Fail here with a clear message rather than
    # deep inside the transform pipeline.
    if image is None:
        parser.error("could not read image: {}".format(path))

    transform = T.Compose([
        T.Resize(256),
        T.CenterCrop(224),
        T.Normalize(mean=128),
        T.ToMode("CHW"),
    ])

    @jit.trace(symbolic=True)
    def infer_func(processed_img):
        model.eval()
        logits = model(processed_img)
        probs = F.softmax(logits)
        return probs

    # Add a leading batch axis of size 1.
    processed_img = transform.apply(image)[np.newaxis, :]

    # Cast the input per mode: float32 for "normal", int8 for "quantized";
    # "qat" keeps the transform's output dtype.
    if args.mode == "normal":
        processed_img = processed_img.astype("float32")
    elif args.mode == "quantized":
        processed_img = processed_img.astype("int8")

    probs = infer_func(processed_img)

    top_probs, classes = F.top_k(probs, k=5, descending=True)

    if args.dump:
        output_file = ".".join([args.arch, args.mode, "megengine"])
        logger.info("Dump to %s", output_file)
        infer_func.dump(output_file, arg_names=["data"])
        # Build the weights path directly instead of str.replace, which would
        # also rewrite any "megengine" substring inside the arch name.
        weights_file = ".".join([args.arch, args.mode, "pkl"])
        mge.save(model.state_dict(), weights_file)

    with open("../assets/imagenet_class_info.json", encoding="utf-8") as fp:
        imagenet_class_index = json.load(fp)

    for rank, (prob, classid) in enumerate(
            zip(top_probs.numpy().reshape(-1),
                classes.numpy().reshape(-1))):
        print("{}: class = {:20s} with probability = {:4.1f} %".format(
            rank, imagenet_class_index[str(classid)][1], 100 * prob))