def test_feature(test_loader, mol, cuda, name):
    """Run a retrieval-style test pass and report top-1 / top-5 accuracy.

    Features, labels and the decoder loss are obtained from
    ``get_feature_loss``; a similarity matrix is built with ``cal_cos`` and
    scored twice by ``cal_accuracy`` (``topk=1`` and ``topk=5``).

    Args:
        test_loader: Loader handed to ``get_feature_loss``; its ``len`` is
            reported as the number of batches.
        mol: Model object forwarded to ``get_feature_loss``.
        cuda: Flag/device argument forwarded to ``get_feature_loss``.
        name: Label used in the start banner.

    Returns:
        Tuple ``(accuracy, top5accuracy, loss)`` exactly as produced by the
        helpers (the accuracies are not averaged before being returned).
    """
    started_at = time.time()
    banner = '#### Start %s testing with %d batches ####' % (name, len(test_loader))
    print(banner)

    feature, labels, loss = get_feature_loss(test_loader, mol, cuda)
    similar_mat = cal_cos(feature)

    # Only the first element of each cal_accuracy result is used here.
    accuracy, _ = cal_accuracy(similar_mat, labels, topk=1)
    top5accuracy, _ = cal_accuracy(similar_mat, labels, topk=5)

    top1_percent = np.mean(accuracy) * 100
    top5_percent = np.mean(top5accuracy) * 100
    elapsed = time.time() - started_at
    summary = (
        '[Testing] Feature accuracy = %.5f%%; top5 accuracy = %.5f%%; Decoder loss = %.6f; time cost %.2fs'
        % (top1_percent, top5_percent, loss, elapsed))
    print(summary)

    return accuracy, top5accuracy, loss
def evaluate():
    """Evaluate the AE model on the feature-extraction task.

    Steps:
        1. Obtain features from the AE encoder output: either load them from
           the cached ``feature/<model_name>.json`` file (when it exists and
           ``args.load_feature`` is set) or regenerate them with
           ``generate_feature``.
        2. Find the top-k most similar pictures and compare their labels via
           ``cal_accuracy``.

    All settings are read from the module-level ``args``. The visible code
    does not return the computed accuracy.
    """
    fea_suffix = args.fea_c if args.fea_c is not None else ''
    model_name = '%s_%s%s_model-%s' % (args.main_model, args.model, fea_suffix, args.dataset)
    output_file = 'feature/%s.json' % model_name

    use_cached = os.path.exists(output_file) and args.load_feature
    if not use_cached:
        print('Generating features ...')
        res = generate_feature(output_file)
    else:
        print('Loading features...')
        with open(output_file) as f:
            res = json.load(f)

    normalized = normalize_feature(res['features'])
    similar_mat = cal_distance(normalized)
    accuracy, similar_pic = cal_accuracy(similar_mat, res['labels'], model_name, args.top_k)
def train(model):
    """Train ``model`` and save the weights with the best validation W-score.

    Each epoch runs a ``'train'`` and a ``'val'`` phase over
    ``data_loader[phase]``. Per-phase loss / accuracy / precision / recall /
    F-score are printed and appended to ``<log_dir>/log.txt``. After each
    validation phase ``eval_video(model)`` is called; whenever its score
    exceeds the best seen so far, the current weights are snapshotted.

    Relies on module-level names: ``opt``, ``cfg``, ``device``,
    ``data_loader``, ``dataset_sizes``, ``criterion``, ``log_dir``,
    ``cal_accuracy`` and ``eval_video``.

    Args:
        model: Network to optimise; trained in place.

    Returns:
        None. The best weights are written under ``checkpoints/``; the file
        name encodes ``opt.mode``, whether time (``opt.with_Time``) and
        position (``opt.Data_type``) information are used, the best epoch and
        the F-score of that epoch. If no validation pass ever improves on the
        initial W-score of 0, the initial weights are saved with epoch ``-1``.
    """
    optimizer = optim.SGD(model.parameters(),
                          lr=opt.lr,
                          momentum=opt.momentum,
                          weight_decay=opt.weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=opt.lr_decay_in_epoch,
                                          gamma=opt.gamma)
    # Stateless module: build it once instead of once per batch.
    sigmoid = nn.Sigmoid()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    best_Wscore = 0
    # Bound up front so the summary below cannot hit an unbound name when no
    # validation pass improves on the initial W-score (or opt.epochs == 0).
    best_epoch = -1

    param = 'lr: {}\tprob_thresh: {}\ttop_k: {}\tlr_decay_in_epoch: {}\toctave_test_thresh: {}\n'.format(
        opt.lr, cfg.octave_prob_thresh, cfg.octave_top_k,
        opt.lr_decay_in_epoch, cfg.octave_test_thresh)
    print(param)

    # The context manager guarantees the log is flushed and closed even if
    # training raises part-way through.
    with open(os.path.join(log_dir, 'log.txt'), 'w') as f_out:
        f_out.write(param)

        for epoch in range(opt.epochs):
            print('Epoch {}/{}'.format(epoch, opt.epochs - 1))
            print('-' * 10)

            for phase in ['train', 'val']:
                if phase == 'train':
                    model.train()
                else:
                    model.eval()

                start_time = time.time()
                running_loss = 0.0
                running_accuracy = 0
                running_precision = 0
                running_recall = 0
                running_pos_keys = 0
                running_total_keys = 0
                running_true_keys = 0
                pressed_keys = []

                for img_paths, imgs, targets in data_loader[phase]:
                    # Inputs may arrive as one tensor or a list of tensors.
                    if isinstance(imgs, list):
                        imgs = [img.to(device) for img in imgs]
                    else:
                        imgs = imgs.to(device)

                    # When targets is a list, its first entry holds the
                    # key-press labels used for the loss and the metrics.
                    if isinstance(targets, list):
                        press_label = targets[0].to(device)
                    else:
                        press_label = targets.to(device)

                    optimizer.zero_grad()
                    with torch.set_grad_enabled(phase == 'train'):
                        # (batch, 12) according to the original author's note.
                        outputs = model(imgs).type(torch.double)
                        (correct, pos_keys, recall, true_keys, TN, total_keys,
                         pressed_keys) = cal_accuracy(
                             outputs, press_label, cfg.octave_top_k,
                             cfg.octave_prob_thresh, pressed_keys, epoch)
                        loss = criterion(sigmoid(outputs), press_label)
                        if phase == 'train':
                            loss.backward()
                            optimizer.step()

                    # Weight by the actual batch size: the last batch of an
                    # epoch can be smaller than the configured batch size.
                    running_loss += loss.item() * press_label.size(0)
                    running_accuracy += (correct + TN)
                    running_total_keys += total_keys
                    running_precision += correct
                    running_pos_keys += pos_keys
                    running_recall += recall
                    running_true_keys += true_keys

                if phase == 'train':
                    scheduler.step()

                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_acc = running_accuracy / running_total_keys
                epoch_prec = running_precision / running_pos_keys
                epoch_recall = running_recall / running_true_keys
                # The epsilon keeps the F-score finite when both terms are 0.
                F = 2.0 * epoch_recall * epoch_prec / (epoch_recall + epoch_prec + 1e-6)

                data = '{}\tLoss: {:.4f}\tAcc: {:.4f}\tPrec: {:.4f}\tRecall: {:.4f}\t Fscore: {:.4f}'.format(
                    phase, epoch_loss, epoch_acc, epoch_prec, epoch_recall, F)
                print(data)

                if phase == 'val':
                    average_Wscore = eval_video(model)
                    print("average_Wscore= : {}".format(average_Wscore))
                    # Model selection is driven by the W-score; the F-score
                    # of the winning epoch is only recorded for reporting.
                    if average_Wscore > best_Wscore:
                        best_acc = F
                        best_model_wts = copy.deepcopy(model.state_dict())
                        best_epoch = epoch
                        best_Wscore = average_Wscore

                f_out.write('epoch: {}:\n{}\n\n'.format(epoch, data))

                time_elapsed = time.time() - start_time
                print('current epoch time cost {:.2f} minutes'.format(
                    (time_elapsed) / 60))
                print('\n')

        print('Epoch {} has the best Acc is {} '.format(best_epoch, best_acc))

    # File-name tag: "_time" when temporal information is used and
    # "_with_pos" when position information is used (either, both or none).
    variant = 'octave'
    if opt.with_Time:
        variant += '_time'
    if opt.Data_type:
        variant += '_with_pos'

    # Make sure the target directory exists so a finished run is not lost.
    os.makedirs('checkpoints', exist_ok=True)
    torch.save(
        best_model_wts,
        'checkpoints/wb_{}_{}_keys_epoch_{}_Acc_{:.3f}.pth'.format(
            opt.mode, variant, best_epoch, best_acc))
def train(model):
    """Fine-tune ``model`` and save the weights with the best validation F-score.

    The optimiser is built from ``modify_last_layer_lr`` so the last layer
    can use its own learning-rate multipliers (``cfg.lr_mult_w`` /
    ``cfg.lr_mult_b``). Each epoch runs a ``'train'`` and a ``'val'`` phase
    over ``data_loader[phase]``; per-phase metrics are printed and appended
    to ``<log_dir>/log.txt``.

    Relies on module-level names: ``opt``, ``cfg``, ``device``,
    ``data_loader``, ``dataset_sizes``, ``criterion``, ``log_dir``,
    ``cal_accuracy`` and ``modify_last_layer_lr``.

    Args:
        model: Network to optimise; trained in place.

    Returns:
        None. The best weights are written to
        ``checkpoints/keys_epoch_<epoch>_Fscore_<score>.pth``. If no
        validation F-score ever exceeds 0, the initial weights are saved with
        epoch ``-1``.
    """
    # Modify the learning rate of the last layer.
    finetune_params = modify_last_layer_lr(model.named_parameters(), opt.lr,
                                           cfg.lr_mult_w, cfg.lr_mult_b)
    optimizer = optim.SGD(finetune_params,
                          opt.lr,
                          momentum=opt.momentum,
                          weight_decay=opt.weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=opt.lr_decay_in_epoch,
                                          gamma=opt.gamma)
    # Stateless module: build it once instead of once per batch.
    sigmoid = nn.Sigmoid()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    # Bound up front so the summary below cannot hit an unbound name when no
    # validation F-score exceeds 0 (or opt.epochs == 0).
    best_epoch = -1

    param = 'prob_thresh: {}\ttop_k: {}\n'.format(cfg.prob_thresh, cfg.top_k)
    print(param)

    # The context manager guarantees the log is flushed and closed even if
    # training raises part-way through.
    with open(os.path.join(log_dir, 'log.txt'), 'w') as f_out:
        f_out.write(param)

        for epoch in range(opt.epochs):
            print('Epoch {}/{}'.format(epoch, opt.epochs - 1))
            print('-' * 10)

            for phase in ['train', 'val']:
                if phase == 'train':
                    model.train()
                else:
                    model.eval()

                start_time = time.time()
                running_loss = 0.0
                running_accuracy = 0
                running_precision = 0
                running_recall = 0
                running_pos_keys = 0
                running_total_keys = 0
                running_true_keys = 0
                pressed_keys = []

                for img_paths, imgs, targets in data_loader[phase]:
                    imgs = imgs.to(device)
                    targets = targets.to(device)

                    optimizer.zero_grad()
                    with torch.set_grad_enabled(phase == 'train'):
                        # (batch, 88) according to the original author's note.
                        outputs = model(imgs).type(torch.double)
                        (correct, pos_keys, recall, true_keys, TN, total_keys,
                         pressed_keys) = cal_accuracy(
                             outputs, targets, cfg.top_k, cfg.prob_thresh,
                             pressed_keys)
                        loss = criterion(sigmoid(outputs), targets)
                        if phase == 'train':
                            loss.backward()
                            optimizer.step()

                    # Weight the loss by the actual number of samples seen.
                    running_loss += loss.item() * imgs.size(0)
                    running_accuracy += (correct + TN)
                    running_total_keys += total_keys
                    running_precision += correct
                    running_pos_keys += pos_keys
                    running_recall += recall
                    running_true_keys += true_keys

                if phase == 'train':
                    scheduler.step()

                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_acc = running_accuracy / running_total_keys
                epoch_prec = running_precision / running_pos_keys
                epoch_recall = running_recall / running_true_keys
                # The epsilon keeps the F-score finite when both terms are 0.
                F = 2.0 * epoch_recall * epoch_prec / (epoch_recall + epoch_prec + 1e-6)

                if phase == 'val' and F > best_acc:
                    best_acc = F
                    best_model_wts = copy.deepcopy(model.state_dict())
                    best_epoch = epoch

                data = '{}\tLoss: {:.4f}\tAcc: {:.4f}\tPrec: {:.4f}\tRecall: {:.4f}\t Fscore: {:.4f}'.format(
                    phase, epoch_loss, epoch_acc, epoch_prec, epoch_recall, F)
                print(data)
                f_out.write('epoch: {}:\n{}\n\n'.format(epoch, data))

                end_time = time.time()
                print('current epoch time cost {:.2f} minutes'.format(
                    (end_time - start_time) / 60))
                print('\n')

    # Report the best validation F-score, not the F-score of whichever phase
    # happened to run last.
    print('Epoch {} has the best Fscore is {} '.format(best_epoch, best_acc))

    # Make sure the target directory exists so a finished run is not lost.
    os.makedirs('checkpoints', exist_ok=True)
    torch.save(
        best_model_wts,
        "checkpoints/keys_epoch_{}_Fscore_{:.3f}.pth".format(
            best_epoch, best_acc))
def main(img_lists, val_dataloader):
    """Run the trained key classifier on single images or on a validation batch.

    A ResNet-18 is built with ``cfg.num_classes`` outputs and loaded from
    ``cfg.ckpt_path``; forward/backward hooks are registered on
    ``layer4[1].conv2`` so ``get_cam_img`` can use the captured
    ``fmap_block`` / ``grad_block``.

    * If ``opt.img_lists`` is truthy, every path in ``img_lists`` is
      classified; keys whose sigmoid probability is among the top
      ``cfg.top_k`` and above 0.4 are printed, and a CAM overlay is written
      to ``./output/<image name>``.
    * Otherwise the first batch of ``val_dataloader`` is scored with
      ``cal_accuracy``; metrics are printed and the predicted keys for each
      sample are written to ``cfg.res_txt_path``.

    Relies on module-level names: ``opt``, ``cfg``, ``device``,
    ``farward_hook``, ``backward_hook``, ``grad_block``, ``fmap_block``,
    ``get_key_nums1``, ``get_cam_img`` and ``cal_accuracy``.

    Args:
        img_lists: Iterable of image paths (used only in image mode).
        val_dataloader: Loader yielding ``(paths, imgs, targets)`` batches
            (used only in evaluation mode).

    Returns:
        None.
    """
    model = resnet18(pretrained=False, num_classes=cfg.num_classes).to(device)
    model.load_state_dict(torch.load(cfg.ckpt_path))
    model.eval()

    cam_layer = model.layer4[1].conv2
    cam_layer.register_forward_hook(farward_hook)
    cam_layer.register_backward_hook(backward_hook)
    print('the ckpt path is {}'.format(cfg.ckpt_path))

    sigmoid = nn.Sigmoid()
    out_dir = './output'

    if opt.img_lists:
        thresh = 0.4
        to_tensor = transforms.ToTensor()
        # cv2.imwrite fails silently when the target directory is missing.
        os.makedirs(out_dir, exist_ok=True)

        for img_path in img_lists:
            print(img_path)
            # resize() returns a new, fully loaded image, so the file handle
            # can be released as soon as the block exits.
            with Image.open(img_path) as pil_img:
                resized = pil_img.resize(cfg.input_size)
            img = to_tensor(resized)

            # Bring single-channel input up to the 3 channels the network
            # expects; -1 leaves the spatial dimensions unchanged.
            if img.dim() == 2:
                img = img.unsqueeze(0)
            if img.shape[0] == 1:
                img = img.expand(3, -1, -1)

            img = img.unsqueeze(0).to(device)
            output = model(img)
            prob = sigmoid(output)
            value, index = prob.topk(cfg.top_k)
            final_index = index[value > thresh].cpu().numpy()

            if len(final_index) > 0:
                img_draw = get_key_nums1(img_path, out_dir, final_index)
                cam_img = get_cam_img(cfg.input_size, img_draw, output,
                                      final_index, grad_block, fmap_block,
                                      cfg.num_classes).astype(np.uint8)
                path_cam_img = os.path.join(out_dir, os.path.basename(img_path))
                cv2.imwrite(path_cam_img, cam_img)

            for key_index in final_index:
                print('the predict press key is {}'.format(cfg.labels[key_index] + 1))
            print('\n')
        return

    running_accuracy = 0
    running_precision = 0
    running_recall = 0
    running_pos_keys = 0
    running_total_keys = 0
    running_true_keys = 0

    # The context manager closes the result file even when the loader is
    # empty or scoring raises.
    with open(cfg.res_txt_path, 'w') as fout:
        for paths, imgs, targets in val_dataloader:
            imgs = imgs.to(device)
            targets = targets.to(device)
            outputs = model(imgs)

            pressed_keys = []
            (correct, pos_keys, recall, true_keys, TN, total_keys,
             pressed_keys) = cal_accuracy(outputs, targets, cfg.top_k,
                                          cfg.prob_thresh, pressed_keys)
            running_accuracy += (correct + TN)
            running_total_keys += total_keys
            running_precision += correct
            running_pos_keys += pos_keys
            running_recall += recall
            running_true_keys += true_keys

            epoch_acc = running_accuracy / running_total_keys
            epoch_prec = running_precision / running_pos_keys
            epoch_recall = running_recall / running_true_keys
            # The epsilon keeps the F-score finite when both terms are 0.
            F = 2.0 * epoch_recall * epoch_prec / (epoch_recall + epoch_prec + 1e-6)
            data = 'Acc: {:.4f}\tPrec: {:.4f}\tRecall: {:.4f}\t Fscore: {:.4f}'.format(
                epoch_acc, epoch_prec, epoch_recall, F)
            print(data)

            # One output line per sample: its path, then its predicted keys.
            for i, path in enumerate(paths):
                fout.write('{} '.format(path))
                for key in pressed_keys[i]:
                    fout.write('{} '.format(cfg.labels[key] + 1))
                fout.write('\n')
            print('one batch is down')
            # Only the first batch is evaluated.
            break