def test(**kwargs):
    """Evaluate the model on the test set and print top-1 accuracy.

    Keyword arguments are forwarded to ``conf.parse`` so any config value
    (e.g. ``BATCH_SIZE``, ``LOAD_MODEL_PATH``) can be overridden per run.
    """
    conf.parse(kwargs)

    # Resolve the device BEFORE loading the checkpoint so that a
    # GPU-saved state_dict can be remapped onto CPU (and vice versa);
    # the original called torch.load without map_location, which crashes
    # when the checkpoint's device differs from the evaluation device.
    device = torch.device('cuda:0' if conf.USE_GPU else 'cpu')

    model = Network().eval()
    if conf.LOAD_MODEL_PATH:
        print(conf.LOAD_MODEL_PATH)
        model.load_state_dict(
            torch.load(conf.CHECKPOINTS_ROOT + conf.LOAD_MODEL_PATH,
                       map_location=device))
    model.to(device)

    test_set = ImageFolder(conf.TEST_DATA_ROOT, transform)
    test_loader = DataLoader(test_set, conf.BATCH_SIZE, shuffle=False,
                             num_workers=conf.NUM_WORKERS)

    results = list()  # per-sample 1.0/0.0 correctness flags, all batches
    with torch.no_grad():
        for step, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outs = model(inputs)
            # argmax over class logits -> predicted class index
            pred = torch.max(outs, 1)[1]
            acc = (pred == targets).float().sum() / len(targets)
            results += (pred == targets).float().to('cpu').numpy().tolist()
            print('[%5d] acc: %.3f' % (step + 1, acc))

    results = np.array(results)
    # Guard the division: an empty test folder would otherwise divide by 0.
    if len(results):
        print('Top 1 acc: %.3f' % (np.sum(results) / len(results)))
    else:
        print('Top 1 acc: n/a (empty test set)')
def train(**kwargs):
    """Train the model, saving a timestamped checkpoint after every epoch.

    Keyword arguments are forwarded to ``conf.parse`` so any config value
    (e.g. ``MAX_EPOCH``, ``LEARNING_RATE``) can be overridden per run.
    """
    conf.parse(kwargs)

    train_set = ImageFolder(conf.TRAIN_DATA_ROOT, transform)
    train_loader = DataLoader(train_set, conf.BATCH_SIZE, shuffle=True,
                              num_workers=conf.NUM_WORKERS)

    device = torch.device('cuda:0' if conf.USE_GPU else 'cpu')

    model = Network()
    if conf.LOAD_MODEL_PATH:
        print(conf.LOAD_MODEL_PATH)
        # map_location lets a GPU-saved checkpoint resume on CPU and
        # vice versa; without it torch.load raises on a device mismatch.
        model.load_state_dict(
            torch.load(conf.CHECKPOINTS_ROOT + conf.LOAD_MODEL_PATH,
                       map_location=device))
    model.to(device)

    criterion = nn.CrossEntropyLoss().to(device)
    lr = conf.LEARNING_RATE
    optim = torch.optim.Adam(params=model.parameters(), lr=lr,
                             weight_decay=conf.WEIGHT_DECAY)

    for epoch in range(conf.MAX_EPOCH):
        model.train()
        running_loss = 0
        for step, (inputs, targets) in tqdm(enumerate(train_loader)):
            inputs, targets = inputs.to(device), targets.to(device)

            optim.zero_grad()
            outs = model(inputs)
            loss = criterion(outs, targets)
            loss.backward()
            optim.step()

            running_loss += loss.item()
            if step % conf.PRINT_FREQ == conf.PRINT_FREQ - 1:
                # Report the mean loss over the last PRINT_FREQ steps.
                running_loss = running_loss / conf.PRINT_FREQ
                print('[%d, %5d] loss: %.3f'
                      % (epoch + 1, step + 1, running_loss))
                running_loss = 0

        torch.save(model.state_dict(),
                   conf.CHECKPOINTS_ROOT
                   + time.strftime('%Y-%m-%d-%H-%M-%S.pth'))

        # BUG FIX: decay the learning rate exactly once per epoch.  The
        # original multiplied `lr` inside the param-group loop, so with
        # more than one parameter group the decay compounded once per
        # group instead of once per epoch.
        lr *= conf.LEARNING_RATE_DECAY
        for param_group in optim.param_groups:
            param_group['lr'] = lr
    # Tail of a multi-line constructor call whose opening sits above this
    # chunk — presumably building the keypoint detector (`det`) consumed by
    # Network(...) below; confirm against the full file.
    cfg.MODEL.KSIZE,
    cfg.MODEL.padding,
    cfg.MODEL.dilation,
    cfg.MODEL.scale_list,
)
# Descriptor head: margin and correspondence threshold come from the config.
des = HardNetNeiMask(cfg.HARDNET.MARGIN, cfg.MODEL.COO_THRSH)
model = Network(det, des, cfg.LOSS.SCORE, cfg.LOSS.PAIR, cfg.PATCH.SIZE, cfg.TRAIN.TOPK)

print(f"{gct()} : to device")
# NOTE(review): hard-coded CUDA device — this path will fail on a CPU-only
# machine; confirm whether a fallback is handled elsewhere.
device = torch.device("cuda")
model = model.to(device)
resume = args.resume
print(f"{gct()} : in {resume}")
checkpoint = torch.load(resume)
model.load_state_dict(checkpoint["state_dict"])

###############################################################################
# detect and compute
###############################################################################
# Two image paths packed into one CLI argument, separated by '@'.
img1_path, img2_path = args.imgpath.split("@")
# detectAndCompute returns (keypoints, descriptors, image, _, _); the last
# argument looks like a target (height, width) — note the two images are
# resized to transposed shapes.  TODO confirm the size convention.
kp1, des1, img1, _, _ = model.detectAndCompute(img1_path, device, (600, 460))
kp2, des2, img2, _, _ = model.detectAndCompute(img2_path, device, (460, 600))  #(460, 600)
# kp2, des2, img2,_,_ = model.detectAndCompute(img2_path, device, (600, 460)) #(460, 600)

# Lowe-style ratio test (threshold 0.9): predict_label flags kp1 entries
# with a confident nearest neighbour in des2.
predict_label, nn_kp2 = nearest_neighbor_distance_ratio_match(
    des1, des2, kp2, 0.9)
idx = predict_label.nonzero().view(-1)
# Matched keypoints from image 1 — this call continues past this chunk.
mkp1 = kp1.index_select(dim=0,
    # Tail of a call started above this chunk — presumably the factory that
    # builds the model and/or its two optimizers (detector + descriptor
    # learning rates and weight decays); confirm against the full file.
    det_lr=cfg.TRAIN.DET_LR,
    des_lr=cfg.TRAIN.DES_LR,
    det_wd=cfg.TRAIN.DET_WD,
    des_wd=cfg.TRAIN.DES_WD,
    mgpu=mgpu,
)

###############################################################################
# resume model if exists
###############################################################################
if args.resume:
    if os.path.isfile(args.resume):
        print(f"{gct()} : Loading checkpoint {args.resume}")
        checkpoint = torch.load(args.resume)
        # Restore the epoch counter and weights only; optimizer-state
        # restore is intentionally left disabled below.
        args.start_epoch = checkpoint["epoch"]
        model.load_state_dict(checkpoint["state_dict"])
        # det_optim.load_state_dict(checkpoint["det_optim"])
        # des_optim.load_state_dict(checkpoint["des_optim"])
    else:
        # NOTE(review): message has a typo ("Cannot found") and the run
        # silently continues from scratch rather than aborting — confirm
        # that is intended.
        print(f"{gct()} : Cannot found checkpoint {args.resume}")
else:
    args.start_epoch = 0

###############################################################################
# Visualization
###############################################################################
# TensorBoard writers: separate log streams for train and test curves.
train_writer = SummaryWriter(f"{args.save}/log/train")
test_writer = SummaryWriter(f"{args.save}/log/test")

###############################################################################
# Training function
def parse_args():
    """Parse the CLI: positional paths to an input image and a trained model."""
    parser = argparse.ArgumentParser()
    parser.add_argument('image', help='Load image to classify')
    parser.add_argument('model', help='Load trained model')
    return parser.parse_args()


def to_img(tensor: torch.Tensor):
    # Reshape a flat 784-value tensor back to 1x28x28 and convert it to a
    # PIL image for display (moved to CPU first, as ToPILImage requires).
    return ToPILImage()(tensor.cpu().view(1, 28, 28))


if __name__ == '__main__':
    args = parse_args()
    # NOTE(review): `device` and `Network` are defined above this chunk —
    # the Network(1, 64, 5, 10) arguments are presumably (in_channels,
    # hidden, kernel, n_classes); confirm against the class definition.
    net = Network(1, 64, 5, 10).to(device)
    net.load_state_dict(torch.load(args.model, map_location=device))

    # Normalise the input to MNIST format: grayscale, 28x28, batch of 1.
    image = Image.open(args.image).convert('L').resize((28, 28))
    image = ToTensor()(image).to(device)
    image = image.view(1, *image.shape)

    # Forward pass returning (logits, intermediate feature); softmax turns
    # the logits into class probabilities.
    pred, feat = net.predict_with_feature(image)
    pred = F.softmax(pred, dim=1)

    # 3x2 figure: input image in the top-left cell; the remaining plotting
    # code continues past the end of this chunk.
    grid = plt.GridSpec(3, 2, wspace=0.2, hspace=0.2)
    orig_img_plot = plt.subplot(grid[0, 0])
    orig_img_plot.set_title('input image')
    # Hide axis ticks for the image panel.
    orig_img_plot.xaxis.set_major_locator(plt.NullLocator())
    orig_img_plot.yaxis.set_major_locator(plt.NullLocator())