def forward(self, mid, ref):
    """Gather, for every pixel of ``mid``, its top-K matching reference features.

    Both inputs are L2-normalised along the channel axis and passed to
    ``compute_cost_volume`` with ``max_displacement=self.d``. For each
    pixel the ``self.K`` highest-scoring entries of the cost volume are
    selected, the matching feature vectors are gathered from the reference
    tensor returned by ``compute_cost_volume``, concatenated along the
    channel axis and fed through ``self.conv``.

    Args:
        mid: tensor of shape [B, C, H, W].
        ref: reference tensor (presumably the same shape as ``mid`` --
            TODO confirm against the caller).

    Returns:
        ``self.conv`` applied to the concatenation of the K gathered
        [B, C, H, W] tensors.

    NOTE(review): the index arithmetic assumes cost-volume channel ``k``
    encodes the displacement ``(k // (2d+1), k % (2d+1))`` and that the
    returned ``ref`` is laid out as an (H+2d) x (W+2d) grid -- confirm
    against ``compute_cost_volume``.
    """
    B, C, H, W = mid.shape
    window = 2 * self.d + 1      # side length of the search window
    padded_w = W + 2 * self.d    # row length of the flattened reference grid

    mid = F.normalize(mid, p=2, axis=1)
    ref = F.normalize(ref, p=2, axis=1)
    cost_volume, ref = compute_cost_volume(
        mid, ref, max_displacement=self.d)  # [B, (2d+1)**2, H, W]
    cost_volume = F.dimshuffle(cost_volume, (0, 2, 3, 1))
    cost_volume = cost_volume.reshape((-1, window ** 2))

    # argmax: keep only the indices of the K best displacements per pixel
    indices = F.top_k(cost_volume, k=self.K, descending=True)[1]  # [B*H*W, K]
    del cost_volume

    ref_list = []  # K tensors of shape [B, C, H, W]
    origin_i_j = F.arange(0, H * W, 1)  # float32
    origin_i = F.floor(origin_i_j / W)  # (H*W, ) row of each pixel
    origin_j = F.mod(origin_i_j, W)  # (H*W, ) column of each pixel
    del origin_i_j

    # flatten the spatial dims of ref so a single gather along axis 2 works
    ref = ref.reshape((B, C, (H + 2 * self.d) * padded_w))
    for i in range(self.K):
        index = indices[:, i]  # [B*H*W, ]
        index = index.reshape((-1, H * W))
        index_i = F.floor(index / window) + origin_i  # [B, H*W]
        index_j = F.mod(index, window) + origin_j  # [B, H*W]
        # Compute the flat index from each pixel's (i, j). ref was flattened
        # from an (H+2d) x (W+2d) grid, so the row stride is W + 2d, not W.
        index = index_i * padded_w + index_j  # [B, H*W]
        index = index.astype('int32')
        # add axis
        index = F.add_axis(index, axis=1)  # [B, 1, H*W]
        # broadcast the same spatial index to every channel
        index = F.broadcast_to(index, (B, C, H * W))
        # gather
        output = F.gather(ref, axis=2, index=index)  # [B, C, H*W]
        ref_list.append(output.reshape((B, C, H, W)))
    return self.conv(F.concat(ref_list, axis=1))
def main():
    """Run single-image classification and print the five most probable classes.

    Command-line options:
        -a/--arch:    attribute name of the model constructor in ``M``.
        -m/--model:   optional state-dict file; when omitted the model is
                      built with ``pretrained=True``.
        -i/--image:   image path; defaults to ``../../../assets/cat.jpg``.
        --quantized:  call ``quantize(model)`` before inference.

    Raises:
        OSError: if the image cannot be read by ``cv2.imread``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-a", "--arch", default="shufflenet_v1_x0_5_g3", type=str)
    parser.add_argument("-m", "--model", default=None, type=str)
    parser.add_argument("-i", "--image", default=None, type=str)
    parser.add_argument("--quantized", action="store_true",
                        help="inference by quantized model, cpu only")
    args = parser.parse_args()

    # Only fetch pretrained weights when no explicit checkpoint is given.
    model = getattr(M, args.arch)(pretrained=(args.model is None))
    if args.model:
        state_dict = mge.load(args.model)
        model.load_state_dict(state_dict, strict=False)

    if args.quantized:
        quantize(model)

    path = "../../../assets/cat.jpg" if args.image is None else args.image
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with a clear message rather than deep inside the transform.
        raise OSError("failed to read image: {}".format(path))

    transform = T.Compose([
        T.Resize(256),
        T.CenterCrop(224),
        T.Normalize(mean=[128.0, 128.0, 128.0], std=[1.0, 1.0, 1.0]),  # BGR
        T.ToMode("CHW"),
    ])

    @jit.trace(symbolic=False)
    def infer_func(processed_img):
        model.eval()
        logits = model(processed_img)
        probs = F.softmax(logits)
        return probs

    # Add a leading batch axis before feeding the network.
    processed_img = transform.apply(image)[np.newaxis, :]
    probs = infer_func(processed_img)

    top_probs, classes = F.top_k(probs, k=5, descending=True)

    with open("../../../assets/imagenet_class_info.json") as fp:
        imagenet_class_index = json.load(fp)

    for rank, (prob, classid) in enumerate(
            zip(top_probs.numpy().reshape(-1), classes.numpy().reshape(-1))):
        print("{}: class = {:20s} with probability = {:4.1f} %".format(
            rank, imagenet_class_index[str(classid)][1], 100 * prob))
def main():
    """Run single-image classification and print the five most probable classes.

    Command-line options:
        -a/--arch:   attribute name of the model constructor in ``M``.
        -m/--model:  optional state-dict file; when omitted the model is
                     built with ``pretrained=True``.
        -i/--image:  image path; defaults to ``../../../assets/cat.jpg``.

    The default image and the class-index JSON can be found at
    https://github.com/MegEngine/Models/tree/master/official/assets

    Raises:
        OSError: if the image cannot be read by ``cv2.imread``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-a", "--arch", default="resnet50_frelu", type=str)
    parser.add_argument("-m", "--model", default=None, type=str)
    parser.add_argument("-i", "--image", default=None, type=str)
    args = parser.parse_args()

    # Only fetch pretrained weights when no explicit checkpoint is given.
    model = getattr(M, args.arch)(pretrained=(args.model is None))
    if args.model:
        state_dict = mge.load(args.model)
        model.load_state_dict(state_dict)

    # please find the files in https://github.com/MegEngine/Models/tree/master/official/assets
    path = "../../../assets/cat.jpg" if args.image is None else args.image
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with a clear message rather than deep inside the transform.
        raise OSError("failed to read image: {}".format(path))

    transform = T.Compose([
        T.Resize(256),
        T.CenterCrop(224),
        T.Normalize(mean=[103.530, 116.280, 123.675],
                    std=[57.375, 57.120, 58.395]),  # BGR
        T.ToMode("CHW"),
    ])

    @jit.trace(symbolic=True)
    def infer_func(processed_img):
        model.eval()
        logits = model(processed_img)
        probs = F.softmax(logits)
        return probs

    # Add a leading batch axis before feeding the network.
    processed_img = transform.apply(image)[np.newaxis, :]
    probs = infer_func(processed_img)

    top_probs, classes = F.top_k(probs, k=5, descending=True)

    # please find the files in https://github.com/MegEngine/Models/tree/master/official/assets
    with open("../../../assets/imagenet_class_info.json") as fp:
        imagenet_class_index = json.load(fp)

    for rank, (prob, classid) in enumerate(
            zip(top_probs.numpy().reshape(-1), classes.numpy().reshape(-1))):
        print("{}: class = {:20s} with probability = {:4.1f} %".format(
            rank, imagenet_class_index[str(classid)][1], 100 * prob))
def calc_loss(self):
    """Compute the training loss over all stage outputs.

    Runs ``self.forward`` on ``self.inputs["image"]``. For every stage:

    * each output except the last is compared against the matching
      heatmap slice with ``F.square_loss``; the target is zeroed where
      ``heat_valid <= 1.1`` and the term is scaled by ``1 / 4 / num_stages``;
    * the last output contributes an OHKM term: the per-keypoint mean
      squared error (masked by ``heat_valid > 0.1``) is reduced, per
      sample, to the mean of its ``self.keypoint_num // 2`` largest
      values, and these are averaged over the batch.

    Returns:
        The accumulated loss.
    """
    stage_outputs = self.forward(self.inputs["image"])
    num_stages = len(stage_outputs)

    total = 0
    for per_stage in stage_outputs:
        # Plain squared-error supervision on all but the final output.
        for scale_idx, prediction in enumerate(per_stage[:-1]):
            keep = (self.inputs["heat_valid"] > 1.1)[:, :, None, None]
            target = self.inputs["heatmap"][:, scale_idx] * keep
            total += F.square_loss(prediction, target) / 4 / num_stages

        # OHKM loss for the largest heatmap
        squared_err = (per_stage[-1] - self.inputs["heatmap"][:, -1]) ** 2
        per_keypoint = squared_err.mean(3).mean(2) * (
            self.inputs["heat_valid"] > 0.1)

        batch = per_keypoint.shape[0]
        hard_total = 0
        for sample in range(batch):
            hardest = F.top_k(per_keypoint[sample], self.keypoint_num // 2,
                              descending=True)[0]
            hard_total += hardest.mean()
        hard_total /= batch
        total += hard_total
    return total
def main():
    """Run single-image classification in normal, QAT or quantized mode.

    Command-line options:
        -a/--arch:        key of the model constructor in ``models.__dict__``.
        -c/--checkpoint:  optional checkpoint; either a bare state dict or a
                          dict holding it under ``"state_dict"``.
        -i/--image:       image path; defaults to ``../assets/cat.jpg``.
        -m/--mode:        ``normal``, ``qat`` or ``quantized`` (default).
        --dump:           dump the traced function to
                          ``<arch>.<mode>.megengine`` and the state dict to
                          ``<arch>.<mode>.pkl``.

    Prints the five most probable classes with their probabilities.

    Raises:
        OSError: if the image cannot be read by ``cv2.imread``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-a", "--arch", default="resnet18", type=str)
    parser.add_argument("-c", "--checkpoint", default=None, type=str)
    parser.add_argument("-i", "--image", default=None, type=str)
    parser.add_argument(
        "-m", "--mode", default="quantized", type=str,
        choices=["normal", "qat", "quantized"],
        help="Quantization Mode\n"
        "normal: no quantization, using float32\n"
        "qat: quantization aware training, simulate int8\n"
        "quantized: convert mode to int8 quantized, inference only")
    parser.add_argument("--dump", action="store_true",
                        help="Dump quantized model")
    args = parser.parse_args()

    if args.mode == "quantized":
        mge.set_default_device("cpux")

    model = models.__dict__[args.arch]()

    # Both "qat" and "quantized" start from the QAT-converted model so the
    # checkpoint is loaded into the same module structure.
    if args.mode != "normal":
        Q.quantize_qat(model, Q.ema_fakequant_qconfig)

    if args.checkpoint:
        logger.info("Load pretrained weights from %s", args.checkpoint)
        ckpt = mge.load(args.checkpoint)
        ckpt = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
        model.load_state_dict(ckpt, strict=False)

    if args.mode == "quantized":
        Q.quantize(model)

    path = "../assets/cat.jpg" if args.image is None else args.image
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with a clear message rather than deep inside the transform.
        raise OSError("failed to read image: {}".format(path))

    transform = T.Compose([
        T.Resize(256),
        T.CenterCrop(224),
        T.Normalize(mean=128),
        T.ToMode("CHW"),
    ])

    @jit.trace(symbolic=True)
    def infer_func(processed_img):
        model.eval()
        logits = model(processed_img)
        probs = F.softmax(logits)
        return probs

    # Add a leading batch axis, then cast to the dtype the chosen mode expects.
    processed_img = transform.apply(image)[np.newaxis, :]
    if args.mode == "normal":
        processed_img = processed_img.astype("float32")
    elif args.mode == "quantized":
        processed_img = processed_img.astype("int8")
    probs = infer_func(processed_img)

    top_probs, classes = F.top_k(probs, k=5, descending=True)

    if args.dump:
        stem = [args.arch, args.mode]
        output_file = ".".join(stem + ["megengine"])
        logger.info("Dump to %s", output_file)
        infer_func.dump(output_file, arg_names=["data"])
        # Build the .pkl name directly: str.replace on the dump name would
        # also rewrite any "megengine" substring inside the arch name.
        mge.save(model.state_dict(), ".".join(stem + ["pkl"]))

    with open("../assets/imagenet_class_info.json") as fp:
        imagenet_class_index = json.load(fp)

    for rank, (prob, classid) in enumerate(
            zip(top_probs.numpy().reshape(-1), classes.numpy().reshape(-1))):
        print("{}: class = {:20s} with probability = {:4.1f} %".format(
            rank, imagenet_class_index[str(classid)][1], 100 * prob))