_PREDICT_MODELS = None  # (backbone, head) pair, built lazily by _load_predict_models()


def _load_predict_models():
    """Build the IR-50 backbone and ArcFace head once, then reuse them.

    The first call constructs both modules, moves them to ``DEVICE``, loads
    the weights from ``./trained_model`` and switches them to eval mode.
    Later calls return the same cached modules without touching the disk.

    Returns:
        tuple: ``(backbone, head)``, both on ``DEVICE`` and in eval mode.
    """
    global _PREDICT_MODELS
    if _PREDICT_MODELS is None:
        backbone = IR_50(INPUT_SIZE).to(DEVICE)
        head = ArcFace(in_features=EMBEDDING_SIZE, out_features=1000,
                       device_id=GPU_ID).to(DEVICE)
        # map_location lets checkpoints saved on a GPU load when DEVICE is CPU.
        backbone.load_state_dict(torch.load(
            './trained_model/Backbone_IR_50_ArcFace_30.pth',
            map_location=DEVICE))
        head.load_state_dict(torch.load(
            './trained_model/Head_IR_50_ArcFace_30.pth',
            map_location=DEVICE))
        backbone.eval()
        head.eval()
        _PREDICT_MODELS = (backbone, head)
    return _PREDICT_MODELS


def predict(image):
    """Return the top-1 class index and its score for a single input.

    ``image`` is passed through ``image_loader``. The result must be a 5-D
    tensor ``(bs, ncrops, c, h, w)``, i.e. a batch of multi-crop stacks.
    Every crop is embedded by the backbone and scored by the ArcFace head.
    The head outputs are then averaged over the crops before the arg-max
    is taken.

    Args:
        image: any value ``image_loader`` accepts (presumably a path or PIL
            image; TODO confirm against ``image_loader``).

    Returns:
        tuple[int, float]: ``(label, score)``. ``score`` is the crop-averaged
        raw head output for ``label``. No softmax is applied here, so it is
        not a probability.

    Note:
        Only a single image (``bs == 1``) is supported. With more than one
        image, ``Tensor.item()`` raises because one label/score pair cannot
        describe several images.
    """
    image = image_loader(image=image)
    backbone, head = _load_predict_models()
    image = image.to(DEVICE)
    bs, ncrops, c, h, w = image.size()
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        features = backbone(image.view(-1, c, h, w))  # fold crops into the batch dim
        outputs = head(features, None)
        outputs = outputs.view(bs, ncrops, -1).mean(1)  # average scores over crops
        top_prob, top_lab = outputs.topk(1)
    return int(top_lab.item()), float(top_prob.item())
print("=" * 60) if os.path.isfile(BACKBONE_RESUME_ROOT) and os.path.isfile(HEAD_RESUME_ROOT): print("Loading Backbone Checkpoint '{}'".format(BACKBONE_RESUME_ROOT)) BACKBONE.load_state_dict(torch.load(BACKBONE_RESUME_ROOT)) print("Loading Head Checkpoint '{}'".format(HEAD_RESUME_ROOT)) HEAD.load_state_dict(torch.load(HEAD_RESUME_ROOT)) else: print("No Checkpoint Found at '{}' and '{}'. Please Have a Check or Continue to Train from Scratch".format(BACKBONE_RESUME_ROOT, HEAD_RESUME_ROOT)) print("=" * 60) if MULTI_GPU: # multi-GPU setting BACKBONE = nn.DataParallel(BACKBONE, device_ids = GPU_ID) BACKBONE = BACKBONE.to(DEVICE) HEAD = nn.DataParallel(HEAD, device_ids = GPU_ID) HEAD = HEAD.to(DEVICE) else: # single-GPU setting BACKBONE = BACKBONE.to(DEVICE) HEAD = HEAD.to(DEVICE) #======= train & validation & save checkpoint =======# DISP_FREQ = len(train_loader) // 100 # frequency to display training loss & acc NUM_EPOCH_WARM_UP = NUM_EPOCH // 25 # use the first 1/25 epochs to warm up NUM_BATCH_WARM_UP = len(train_loader) * NUM_EPOCH_WARM_UP # use the first 1/25 epochs to warm up batch = 0 # batch index for epoch in range(NUM_EPOCH): # start training process