Example #1
def test_feature(test_loader, mol, cuda, name):
    test_time = time.time()
    print('#### Start %s testing with %d batches ####' %
          (name, len(test_loader)))

    feature, labels, loss = get_feature_loss(test_loader, mol, cuda)
    similar_mat = cal_cos(feature)
    accuracy, _ = cal_accuracy(similar_mat, labels, topk=1)
    top5accuracy, _ = cal_accuracy(similar_mat, labels, topk=5)
    print(
        '[Testing] Feature accuracy = %.5f%%; top5 accuracy = %.5f%%; Decoder loss = %.6f; time cost %.2fs'
        % (np.mean(accuracy) * 100, np.mean(top5accuracy) * 100, loss,
           time.time() - test_time))
    return accuracy, top5accuracy, loss
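
cal_cos and cal_accuracy come from the repository this snippet was extracted from. As a rough, hypothetical sketch of a retrieval-style helper pair with this signature (an assumption, not the actual implementation):

# Hypothetical sketch, not the repo's code: cosine-similarity matrix over
# L2-normalized features, then top-k nearest-neighbor label matching per query.
import numpy as np

def cal_cos(features):
    feats = np.asarray(features, dtype=np.float32)
    feats = feats / (np.linalg.norm(feats, axis=1, keepdims=True) + 1e-12)
    return feats @ feats.T

def cal_accuracy(similar_mat, labels, topk=1):
    labels = np.asarray(labels)
    hits, neighbors = [], []
    for i in range(len(labels)):
        order = np.argsort(-similar_mat[i])       # most similar first
        order = order[order != i][:topk]          # drop the query itself
        neighbors.append(order)
        hits.append(float(np.any(labels[order] == labels[i])))
    return np.array(hits), neighbors
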
Example #2
def evaluate():
    '''
    Evaluate the AE model on feature extraction work:
        1. Extract features from encode result in AE.
        2. Find top k similar pictures and compare their labels
    '''
    fea_c = '' if args.fea_c is None else args.fea_c
    model_name = '%s_%s%s_model-%s' % (args.main_model, args.model, fea_c,
                                       args.dataset)
    output_file = 'feature/%s.json' % model_name

    if os.path.exists(output_file) and args.load_feature:
        print('Loading features...')
        with open(output_file) as f:
            res = json.load(f)
    else:
        print('Generating features ...')
        res = generate_feature(output_file)
    similar_mat = cal_distance(normalize_feature(res['features']))
    accuracy, similar_pic = cal_accuracy(similar_mat, res['labels'],
                                         model_name, args.top_k)
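
normalize_feature and cal_distance are likewise repo helpers; a minimal hypothetical sketch of what such helpers could look like (L2 normalization plus a pairwise distance matrix), not the repo's actual code:

# Hypothetical helpers matching the calls above (assumptions, not the repo's code).
import numpy as np

def normalize_feature(features):
    feats = np.asarray(features, dtype=np.float32)
    return feats / (np.linalg.norm(feats, axis=1, keepdims=True) + 1e-12)

def cal_distance(feats):
    # pairwise squared Euclidean distance; for unit-norm rows this is 2 - 2*cos
    sq = np.sum(feats ** 2, axis=1)
    return sq[:, None] + sq[None, :] - 2.0 * feats @ feats.T
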
Example #3
def train(model):
    optimizer = optim.SGD(model.parameters(),
                          lr=opt.lr,
                          momentum=opt.momentum,
                          weight_decay=opt.weight_decay)

    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=opt.lr_decay_in_epoch,
                                          gamma=opt.gamma)

    # metrics=cfg.metrics
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    best_epoch = 0  # avoids a NameError if no val epoch ever improves
    txt_log = os.path.join(log_dir, 'log.txt')
    f_out = open(txt_log, 'w')

    param = 'lr: {}\tprob_thresh: {}\ttop_k: {}\tlr_decay_in_epoch: {}\toctave_test_thresh: {}\n'.format(
        opt.lr, cfg.octave_prob_thresh, cfg.octave_top_k,
        opt.lr_decay_in_epoch, cfg.octave_test_thresh)

    best_Wscore = 0
    print(param)
    f_out.write(param)
    for epoch in range(opt.epochs):
        print('Epoch {}/{}'.format(epoch, opt.epochs - 1))
        print('-' * 10)

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()
            start_time = time.time()
            running_loss = 0.0

            running_accuracy = 0
            running_precision = 0
            running_recall = 0
            running_pos_keys = 0
            running_total_keys = 0
            running_true_keys = 0
            pressed_keys = []
            for batch_i, (img_paths, imgs,
                          targets) in enumerate(data_loader[phase]):
                batches_done = len(data_loader[phase]) * epoch + batch_i
                # for i in range(len(imgs)):
                #     a = imgs[i][0]
                #     img_path = img_paths[i][0]
                #     print(img_path)
                #     test_img=np.array(a*255,dtype=np.uint8).transpose((1,2,0))
                #     opencv_img = cv2.cvtColor(test_img, cv2.COLOR_RGB2BGR)
                #     cv2.imwrite('./{}.jpg'.format(i), opencv_img)
                if isinstance(imgs, list):
                    img_lists = []
                    for img in imgs:
                        img = img.to(device)
                        img_lists.append(img)
                    imgs = img_lists[:]
                else:
                    imgs = imgs.to(device)

                #--- Because of the dataset size, not every batch is exactly batch_size; the last iteration usually doesn't divide evenly
                if isinstance(targets, list):
                    press_label = targets[0].to(device)
                    pos_label = targets[1].to(device)
                    batch_shape = targets[0].shape[0]
                else:
                    press_label = targets.to(device)
                    batch_shape = targets.shape[0]

                # targets = targets.to(device)
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(imgs).type(torch.double)  # shape: (batch, 12)

                    correct, pos_keys, recall, true_keys, TN, total_keys, pressed_keys = cal_accuracy(
                        outputs, press_label, cfg.octave_top_k,
                        cfg.octave_prob_thresh, pressed_keys, epoch)

                    func = nn.Sigmoid()
                    loss = criterion(func(outputs), press_label)
                    # print(loss)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * press_label.size(0)
                running_accuracy += (correct + TN)
                running_total_keys += total_keys
                running_precision += correct
                running_pos_keys += pos_keys
                running_recall += recall
                running_true_keys += true_keys

            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_accuracy / running_total_keys
            epoch_prec = running_precision / running_pos_keys
            epoch_recall = running_recall / running_true_keys
            F = 2.0 * epoch_recall * epoch_prec / (epoch_recall + epoch_prec +
                                                   1e-6)
            # torch.save(model.state_dict(),'./checkpoints/epoch_{}.pth'.format(epoch))

            data = '{}\tLoss: {:.4f}\tAcc: {:.4f}\tPrec: {:.4f}\tRecall: {:.4f}\t Fscore: {:.4f}'.format(
                phase, epoch_loss, epoch_acc, epoch_prec, epoch_recall, F)
            print(data)

            if phase == 'val':
                average_Wscore = eval_video(model)
                print("average_Wscore: {}".format(average_Wscore))
            if phase == 'val' and average_Wscore > best_Wscore:
                best_acc = F
                best_model_wts = copy.deepcopy(model.state_dict())
                best_epoch = epoch
                best_Wscore = average_Wscore

            f_out.write('epoch: {}:\n'.format(epoch))
            f_out.write(data)
            f_out.write('\n')
            f_out.write('\n')

            time_elapsed = time.time() - start_time
            print('current epoch time cost {:.2f} minutes'.format(
                (time_elapsed) / 60))
            print('\n')

    print('Epoch {} has the best Acc: {}'.format(best_epoch, best_acc))

    f_out.close()
    if opt.with_Time:
        if opt.Data_type:
            #-- includes both position and time information
            torch.save(
                best_model_wts,
                "checkpoints/wb_{}_octave_time_with_pos_keys_epoch_{}_Acc_{:.3f}.pth"
                .format(opt.mode, best_epoch, best_acc))
        else:
            #-- time information only
            torch.save(
                best_model_wts,
                "checkpoints/wb_{}_octave_time_keys_epoch_{}_Acc_{:.3f}.pth".
                format(opt.mode, best_epoch, best_acc))
    elif opt.Data_type:
        #--- position information but no time information
        torch.save(
            best_model_wts,
            "checkpoints/wb_{}_octave_with_pos_keys_epoch_{}_Acc_{:.3f}.pth".
            format(opt.mode, best_epoch, best_acc))
    else:
        #--- neither position nor time information
        torch.save(
            best_model_wts,
            "checkpoints/wb_{}_octave_keys_epoch_{}_Acc_{:.3f}.pth".format(
                opt.mode, best_epoch, best_acc))
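
The seven values accumulated in the training loop above come from a cal_accuracy defined elsewhere in the repository. A hypothetical sketch of a multi-label scorer with the same signature and return order (assumptions throughout, not the repo's implementation):

# Hypothetical sketch: keep at most top_k keys per sample, only those whose
# sigmoid probability exceeds prob_thresh, then count TP/TN against the labels.
import torch

def cal_accuracy(outputs, targets, top_k, prob_thresh, pressed_keys, epoch=None):
    probs = torch.sigmoid(outputs)                  # (batch, num_keys)
    topv, topi = probs.topk(top_k, dim=1)
    pred = torch.zeros_like(probs)
    pred.scatter_(1, topi, (topv > prob_thresh).to(probs.dtype))

    gt = (targets > 0.5).to(pred.dtype)
    TP = ((pred == 1) & (gt == 1)).sum().item()
    TN = ((pred == 0) & (gt == 0)).sum().item()

    correct = TP                                    # numerator for precision
    pos_keys = int(pred.sum().item())               # predicted positive keys
    recall = TP                                     # numerator for recall
    true_keys = int(gt.sum().item())                # ground-truth positive keys
    total_keys = gt.numel()                         # all key slots in the batch
    for row in pred:                                # per-sample predicted indices
        pressed_keys.append(row.nonzero(as_tuple=True)[0].tolist())
    return correct, pos_keys, recall, true_keys, TN, total_keys, pressed_keys
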
Example #4
def train(model):
    # modify learning rate of last layer
    finetune_params = modify_last_layer_lr(model.named_parameters(), opt.lr,
                                           cfg.lr_mult_w, cfg.lr_mult_b)

    optimizer = optim.SGD(finetune_params,
                          opt.lr,
                          momentum=opt.momentum,
                          weight_decay=opt.weight_decay)

    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=opt.lr_decay_in_epoch,
                                          gamma=opt.gamma)

    # metrics=cfg.metrics
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    best_epoch = 0  # avoids a NameError if no val epoch ever improves
    txt_log = os.path.join(log_dir, 'log.txt')
    f_out = open(txt_log, 'w')
    param = 'prob_thresh: {}\ttop_k: {}\n'.format(cfg.prob_thresh, cfg.top_k)
    print(param)
    f_out.write(param)
    for epoch in range(opt.epochs):
        print('Epoch {}/{}'.format(epoch, opt.epochs - 1))
        print('-' * 10)

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()
            start_time = time.time()
            running_loss = 0.0

            running_accuracy = 0
            running_precision = 0
            running_recall = 0
            running_pos_keys = 0
            running_total_keys = 0
            running_true_keys = 0
            pressed_keys = []

            for batch_i, (img_paths, imgs,
                          targets) in enumerate(data_loader[phase]):

                # test_img=np.array(imgs[0]*255,dtype=np.uint8).transpose((1,2,0))
                # opencv_img=cv2.cvtColor(test_img,cv2.COLOR_RGB2BGR)
                # cv2.imwrite('./test_img.jpg',opencv_img)
                # index=torch.where(targets[0]==1)[0]
                # print(targets[0])
                # embed()

                imgs = imgs.to(device)
                targets = targets.to(device)
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(imgs).type(torch.double)  # shape: (batch, 88)
                    # print(outputs.shape)

                    correct, pos_keys, recall, true_keys, TN, total_keys, pressed_keys = cal_accuracy(
                        outputs, targets, cfg.top_k, cfg.prob_thresh,
                        pressed_keys)

                    func = nn.Sigmoid()
                    loss = criterion(func(outputs), targets)
                    # print(loss)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * imgs.size(0)
                running_accuracy += (correct + TN)
                running_total_keys += total_keys
                running_precision += correct
                running_pos_keys += pos_keys
                running_recall += recall
                running_true_keys += true_keys

            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_accuracy / running_total_keys
            epoch_prec = running_precision / running_pos_keys
            epoch_recall = running_recall / running_true_keys
            F = 2.0 * epoch_recall * epoch_prec / (epoch_recall + epoch_prec +
                                                   1e-6)
            # torch.save(model.state_dict(),'./checkpoints/epoch_{}.pth'.format(epoch))

            if phase == 'val' and F > best_acc:
                best_acc = F
                best_model_wts = copy.deepcopy(model.state_dict())
                best_epoch = epoch

            data = '{}\tLoss: {:.4f}\tAcc: {:.4f}\tPrec: {:.4f}\tRecall: {:.4f}\t Fscore: {:.4f}'.format(
                phase, epoch_loss, epoch_acc, epoch_prec, epoch_recall, F)
            print(data)
            f_out.write('epoch: {}:\n'.format(epoch))
            f_out.write(data)
            f_out.write('\n')
            f_out.write('\n')

            end_time = time.time()
            print('current epoch time cost {:.2f} minutes'.format(
                (end_time - start_time) / 60))
            print('\n')

    print('Epoch {} has the best Fscore: {}'.format(best_epoch, best_acc))
    torch.save(
        best_model_wts, "checkpoints/keys_epoch_{}_Fscore_{:.3f}.pth".format(
            best_epoch, best_acc))
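
Both training loops pass nn.Sigmoid()(outputs) to a criterion defined elsewhere. Assuming that criterion is nn.BCELoss, nn.BCEWithLogitsLoss is the numerically more stable equivalent, as this small self-contained check illustrates:

# Assumes the `criterion` used above is nn.BCELoss; BCEWithLogitsLoss fuses the sigmoid.
import torch
import torch.nn as nn

logits = torch.randn(4, 88, dtype=torch.double)       # stand-in for raw model outputs
targets = torch.randint(0, 2, (4, 88)).double()       # stand-in multi-hot key labels

loss_a = nn.BCELoss()(nn.Sigmoid()(logits), targets)  # pattern used above
loss_b = nn.BCEWithLogitsLoss()(logits, targets)      # fused, numerically stabler
assert torch.allclose(loss_a, loss_b)
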
Example #5
def main(img_lists, val_dataloader):
    model = resnet18(pretrained=False, num_classes=cfg.num_classes).to(device)
    model.load_state_dict(torch.load(cfg.ckpt_path))
    model.eval()

    model.layer4[1].conv2.register_forward_hook(farward_hook)
    model.layer4[1].conv2.register_backward_hook(backward_hook)

    print('the ckpt path is {}'.format(cfg.ckpt_path))
    # test_lines=select_img1

    out_dir = './output'

    if opt.img_lists:
        for img_path in img_lists:
            # if not 'test_0013' in os.path.basename(img_path):continue
            print(img_path)
            # ori_img=cv2.imread(img_path,1)
            thresh = 0.4
            img = Image.open(img_path)
            # file_seq=os.path.basename(os.path.split(img_path)[0])
            # if file_seq in cfg.crop_file_seq:
            #     rect=cfg.EVALUATE_MAP[file_seq]['rect']
            #     img=img.crop((rect))  #--img.size-> w,h

            img = img.resize((cfg.input_size))

            img = transforms.ToTensor()(img)
            if img.dim() == 2:           # (H, W) -> (1, H, W)
                img = img.unsqueeze(0)
            if img.shape[0] == 1:        # grayscale: replicate the channel to 3
                img = img.expand(3, *img.shape[1:])
            img = img.unsqueeze(0).to(device)

            output = model(img)
            func = nn.Sigmoid()
            prob = func(output)
            value, index = prob.topk(cfg.top_k)
            final_index = (index[value > thresh]).cpu().numpy()

            if len(final_index) > 0:
                img_draw = get_key_nums1(img_path, out_dir, final_index)
                cam_img = (get_cam_img(cfg.input_size, img_draw, output,
                                       final_index, grad_block, fmap_block,
                                       cfg.num_classes)).astype(np.uint8)
                path_cam_img = os.path.join(out_dir,
                                            os.path.basename(img_path))
                cv2.imwrite(path_cam_img, cam_img)

            for index in final_index:
                print('the predicted pressed key is {}'.format(cfg.labels[index] + 1))
            print('\n')
    else:
        running_accuracy = 0
        running_precision = 0
        running_recall = 0
        running_pos_keys = 0
        running_total_keys = 0
        running_true_keys = 0
        fout = open(cfg.res_txt_path, 'w')
        for batch_i, (paths, imgs, targets) in enumerate(val_dataloader):
            imgs = imgs.to(device)
            targets = targets.to(device)
            outputs = model(imgs)
            pressed_keys = []
            correct, pos_keys, recall, true_keys, TN, total_keys, pressed_keys = cal_accuracy(
                outputs, targets, cfg.top_k, cfg.prob_thresh, pressed_keys)

            running_accuracy += (correct + TN)
            running_total_keys += total_keys
            running_precision += correct
            running_pos_keys += pos_keys
            running_recall += recall
            running_true_keys += true_keys

            epoch_acc = running_accuracy / running_total_keys
            epoch_prec = running_precision / running_pos_keys
            epoch_recall = running_recall / running_true_keys
            F = 2.0 * epoch_recall * epoch_prec / (epoch_recall + epoch_prec +
                                                   1e-6)
            data = 'Acc: {:.4f}\tPrec: {:.4f}\tRecall: {:.4f}\t Fscore: {:.4f}'.format(
                epoch_acc, epoch_prec, epoch_recall, F)
            print(data)

            for i, path in enumerate(paths):
                fout.write('{} '.format(path))
                for key in pressed_keys[i]:
                    fout.write('{} '.format(cfg.labels[key] + 1))
                fout.write('\n')
            print('one batch is done')
            # embed()
            fout.close()
            break  # only the first batch is evaluated and written out
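
farward_hook, backward_hook, fmap_block and grad_block are defined elsewhere in the repository; a minimal sketch of what such Grad-CAM style hooks usually look like (an assumption, not this repo's code):

# Hypothetical hook pair for the register_forward_hook / register_backward_hook
# calls above: cache the feature map and its gradient for the CAM computation.
fmap_block = []   # read later by get_cam_img
grad_block = []

def farward_hook(module, inputs, output):
    fmap_block.append(output.detach())

def backward_hook(module, grad_input, grad_output):
    grad_block.append(grad_output[0].detach())
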