Example #1
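
# Test a trained segmentation network on one data split: for each image, turn the
# pixel-wise prediction into an image-level benign (1) / malignant (2) call using
# seg_thresh, then report accuracy, per-class precision/recall/F1 and a confusion matrix.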
def main(NetClass, key_name, dataset_type='test'):
    torch.set_grad_enabled(False)

    model_id = NetClass.model_id

    test_dataset_path = '{}/{}/{}'.format(dataset_path, key_name, dataset_type)
    test_dataset = DatasetReader(test_dataset_path)

    ck_name = '{}/model_{}_{}.pt'.format(seg_net_save_dir, model_id, key_name)
    cm_test_name = '{}/cm_{}_{}_{}.png'.format(seg_net_save_dir, dataset_type,
                                               model_id, key_name)

    net = NetClass(in_dim)
    net.load_state_dict(torch.load(ck_name, map_location='cpu'))  # load on CPU, then move to device below
    net = net.to(device)
    net.eval()

    all_pred = []
    all_label = []

    for i in range(len(test_dataset)):
        im, _cm, cls = test_dataset.get_im_patch_list_to_combind_predict(i)  # the second return value is unused here

        # normalize 16-bit pixel values to [0, 1] and convert NHWC -> NCHW
        batch_im = torch.tensor([im], dtype=torch.float) / 65535
        batch_im = batch_im.permute(0, 3, 1, 2)
        batch_im = batch_im.to(device)

        net_out = net(batch_im)
        out = torch.argmax(net_out, 1)
        all_label.append(cls)

        # count pixels predicted as benign (class 1) and malignant (class 2)
        cls1_pixel_num = (out == 1).sum().item()
        cls2_pixel_num = (out == 2).sum().item()

        # image-level decision: call the image malignant when the malignant share
        # of class-1/2 pixels exceeds seg_thresh, otherwise benign
        if cls1_pixel_num + cls2_pixel_num == 0:
            all_pred.append(1)
        elif cls2_pixel_num / (cls1_pixel_num + cls2_pixel_num) > seg_thresh:
            all_pred.append(2)
        else:
            all_pred.append(1)

    _accuracy = accuracy_score(all_label, all_pred)
    _malignant_precision, _malignant_recall, _malignant_f1, _ = \
        precision_recall_fscore_support(all_label, all_pred, pos_label=2, average='binary')

    _benign_precision, _benign_recall, _benign_f1, _ = \
        precision_recall_fscore_support(all_label, all_pred, pos_label=1, average='binary')

    _accuracy = float(_accuracy)
    _malignant_precision = float(_malignant_precision)
    _malignant_recall = float(_malignant_recall)
    _malignant_f1 = float(_malignant_f1)
    _benign_precision = float(_benign_precision)
    _benign_recall = float(_benign_recall)
    _benign_f1 = float(_benign_f1)

    out_line = '{} acc: {:.3f} m_prec: {:.3f} m_rec: {:.3f} m_f1: {:.3f} '\
               'b_prec: {:.3f} b_rec: {:.3f} b_f1: {:.3f} model {}_{}'.format(dataset_type, _accuracy,
                                                    _malignant_precision, _malignant_recall, _malignant_f1,
                                                    _benign_precision, _benign_recall, _benign_f1, model_id, key_name)

    print(out_line)
    test_out.append(out_line)

    cm = confusion_matrix(all_label, all_pred)
    # labels for the 2x2 confusion matrix: skip the first class name, since
    # predictions and labels here only take the values 1 and 2
    draw_confusion_matrix(cm,
                          list(test_dataset.class2id.keys())[1:], cm_test_name)
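
# Train the patch-classification ("simple") network at one input scale (32, 64 or 128):
# evaluate every 3 epochs, keep the checkpoint with the best average of benign and
# malignant F1, and log training/evaluation metrics to TensorBoard.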
def main(NetClass, key_name, scale=32):
    torch.set_grad_enabled(True)

    assert scale in [32, 64, 128]
    model_id = NetClass.model_id

    save_dir = '{}.{}'.format(simple_net_save_dir_prefix, scale)
    os.makedirs(save_dir, exist_ok=True)

    train_dataset_path = '{}/{}/train'.format(dataset_path, key_name)
    eval_dataset_path = '{}/{}/eval'.format(dataset_path, key_name)

    ck_name = '{}/model_{}_{}.pt'.format(save_dir, model_id, key_name)
    ck_extra_name = '{}/extra_{}_{}.yml'.format(save_dir, model_id, key_name)
    cm_name = '{}/cm_valid_{}_{}.png'.format(save_dir, model_id, key_name)

    logdir = '{}_{}_{}.{}'.format(simple_net_train_logs_dir_prefix, model_id,
                                  key_name, scale)
    sw = SummaryWriter(logdir)

    train_dataset = DatasetReader(train_dataset_path,
                                  is_require_cls_blance=True,
                                  target_hw=(scale, scale))
    eval_dataset = DatasetReader(eval_dataset_path, target_hw=(scale, scale))

    net = NetClass(in_dim)

    net = net.to(device)

    batch_count = train_dataset.get_batch_count(batch_size)

    optim = torch.optim.Adam(net.parameters(), 1e-3, eps=1e-8)
    optim_adjust = torch.optim.lr_scheduler.MultiStepLR(optim, [90, 180, 270],
                                                        gamma=0.1)

    max_valid_value = 0.

    class_weight_for_loss = torch.tensor([1, 1],
                                         dtype=torch.float,
                                         device=device)

    for e in range(epoch):
        net.train()
        train_acc = 0
        train_loss = 0
        for b in range(batch_count):

            batch_im, batch_cls = train_dataset.get_batch(batch_size)

            # normalize 16-bit pixel values to [0, 1]
            batch_im = torch.tensor(batch_im.astype(np.int32),
                                    dtype=torch.float) / 65535
            # optional noise augmentation (disabled):
            # batch_im += (torch.rand_like(batch_im) * 0.1 - 0.05)
            batch_cls = torch.tensor(batch_cls, dtype=torch.long)

            batch_im = batch_im.permute(0, 3, 1, 2)

            batch_im = batch_im.to(device)
            batch_cls = batch_cls.to(device)

            net_out = net(batch_im)
            # net_out = net_train(batch_im)

            with torch.no_grad():
                out = torch.argmax(net_out, 1)
                acc = torch.eq(out,
                               batch_cls).sum(dtype=torch.float) / len(out)

            loss = F.cross_entropy(net_out, batch_cls, class_weight_for_loss)

            train_acc += acc.item()
            train_loss += loss.item()

            print('epoch: {} train acc: {:.3f} loss: {:.3f}'.format(
                e, acc.item(), loss.item()))
            optim.zero_grad()
            loss.backward()
            optim.step()

        train_acc = train_acc / batch_count
        train_loss = train_loss / batch_count
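        # step the LR scheduler once per epoch; the current PyTorch API takes no epoch argument
        optim_adjust.step()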

        sw.add_scalar('train_acc', train_acc, global_step=e)
        sw.add_scalar('train_loss', train_loss, global_step=e)

        # evaluate on the validation set every 3 epochs
        if (e + 1) % 3 == 0:
            with torch.no_grad():
                net.eval()

                all_pred = []
                all_label = []

                for i in range(len(eval_dataset)):
                    ims, cls = eval_dataset.get_im_patch_list_to_combind_predict(
                        i, one_im=False)

                    batch_im = torch.tensor(ims.astype(np.int32),
                                            dtype=torch.float) / 65535
                    # batch_cls = torch.tensor([cls]).repeat(len(batch_im))

                    batch_im = batch_im.permute(0, 3, 1, 2)

                    batch_im = batch_im.to(device)
                    # batch_cls = batch_cls.to(device)

                    net_out = net(batch_im)
                    out = torch.argmax(net_out, 1)
                    all_label.append(cls)

                    # image-level decision: predict malignant (1) when the fraction of
                    # patches classified as malignant exceeds simple_thresh
                    if out.sum(dtype=torch.float).item() > out.shape[0] * simple_thresh:
                        all_pred.append(1)
                    else:
                        all_pred.append(0)

                _accuracy = accuracy_score(all_label, all_pred)
                _malignant_precision, _malignant_recall, _malignant_f1, _ =\
                    precision_recall_fscore_support(all_label, all_pred, pos_label=1, average='binary')

                _benign_precision, _benign_recall, _benign_f1, _ =\
                    precision_recall_fscore_support(all_label, all_pred, pos_label=0, average='binary')

                _accuracy = float(_accuracy)
                _malignant_precision = float(_malignant_precision)
                _malignant_recall = float(_malignant_recall)
                _malignant_f1 = float(_malignant_f1)
                _benign_precision = float(_benign_precision)
                _benign_recall = float(_benign_recall)
                _benign_f1 = float(_benign_f1)

                sw.add_scalar('eval_acc', _accuracy, global_step=e)
                sw.add_scalar('eval_m_prec',
                              _malignant_precision,
                              global_step=e)
                sw.add_scalar('eval_m_recall',
                              _malignant_recall,
                              global_step=e)
                sw.add_scalar('eval_m_f1', _malignant_f1, global_step=e)
                sw.add_scalar('eval_b_prec', _benign_precision, global_step=e)
                sw.add_scalar('eval_b_recall', _benign_recall, global_step=e)
                sw.add_scalar('eval_b_f1', _benign_f1, global_step=e)

                print(
                    'epoch: {} eval acc: {:.3f} m_prec: {:.3f} m_rec: {:.3f} m_f1: {:.3f} '
                    'b_prec: {:.3f} b_rec: {:.3f} b_f1: {:.3f}'.format(
                        e, _accuracy, _malignant_precision, _malignant_recall,
                        _malignant_f1, _benign_precision, _benign_recall,
                        _benign_f1))

                avg_f1 = (_malignant_f1 + _benign_f1) / 2  # model-selection score

                #if _benign_precision - _malignant_precision > 0.2:
                #    class_weight_for_loss[1] += 0.1

                if avg_f1 >= max_valid_value:
                    max_valid_value = avg_f1
                    torch.save(net.state_dict(), ck_name)
                    extra = {
                        'accuracy': _accuracy,
                        'm_precision': _malignant_precision,
                        'm_recall': _malignant_recall,
                        'm_f1': _malignant_f1,
                        'b_precision': _benign_precision,
                        'b_recall': _benign_recall,
                        'b_f1': _benign_f1,
                    }
                    yaml.safe_dump(extra, open(ck_extra_name, 'w'))
                    cm = confusion_matrix(all_label, all_pred)
                    draw_confusion_matrix(cm,
                                          list(eval_dataset.class2id.keys()),
                                          cm_name)

            # early exit once validation accuracy reaches 1.0
            if _accuracy == 1.:
                print('validation accuracy reached 1.0, early exit')
                break

    sw.close()
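
# Test a trained simple network at a single input scale on the test split and save
# its confusion matrix.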
def main(NetClass, key_name, scale):
    assert scale in [32, 64, 128]
    torch.set_grad_enabled(False)

    model_id = NetClass.model_id
    
    save_dir = '{}.{}'.format(simple_net_save_dir_prefix, scale)
    os.makedirs(save_dir, exist_ok=True)

    test_dataset_path = '{}/{}/test'.format(dataset_path, key_name)

    ck_name = '{}/model_{}_{}.pt'.format(save_dir, model_id, key_name)
    cm_test_name = '{}/cm_test_{}_{}.png'.format(save_dir, model_id, key_name)

    test_dataset = DatasetReader(test_dataset_path, target_hw=(scale, scale))

    net = NetClass(in_dim)

    net.load_state_dict(torch.load(ck_name, map_location='cpu'))

    net = net.to(device)

    net.eval()

    all_pred = []
    all_label = []

    for i in range(len(test_dataset)):
        ims, cls = test_dataset.get_im_patch_list_to_combind_predict(i, one_im=False)

        batch_im = torch.tensor(ims.astype(np.int32), dtype=torch.float) / 65535
        # batch_cls = torch.tensor([cls]).repeat(len(batch_im))

        batch_im = batch_im.permute(0, 3, 1, 2)

        batch_im = batch_im.to(device)
        # batch_cls = batch_cls.to(device)

        net_out = net(batch_im)
        out = torch.argmax(net_out, 1)
        all_label.append(cls)

        if out.sum(dtype=torch.float).item() > out.shape[0] * simple_thresh:
            all_pred.append(1)
        else:
            all_pred.append(0)

    _accuracy = accuracy_score(all_label, all_pred)
    _malignant_precision, _malignant_recall, _malignant_f1, _ = \
        precision_recall_fscore_support(all_label, all_pred, pos_label=1, average='binary')

    _benign_precision, _benign_recall, _benign_f1, _ = \
        precision_recall_fscore_support(all_label, all_pred, pos_label=0, average='binary')

    _accuracy = float(_accuracy)
    _malignant_precision = float(_malignant_precision)
    _malignant_recall = float(_malignant_recall)
    _malignant_f1 = float(_malignant_f1)
    _benign_precision = float(_benign_precision)
    _benign_recall = float(_benign_recall)
    _benign_f1 = float(_benign_f1)

    out_line = 'test acc: {:.3f} m_prec: {:.3f} m_rec: {:.3f} m_f1: {:.3f} '\
               'b_prec: {:.3f} b_rec: {:.3f} b_f1: {:.3f} model {}_{} x{}'.format(_accuracy,
                                            _malignant_precision, _malignant_recall, _malignant_f1,
                                            _benign_precision, _benign_recall, _benign_f1, model_id, key_name, scale)

    print(out_line)
    test_out.append(out_line)

    cm = confusion_matrix(all_label, all_pred)
    draw_confusion_matrix(cm, list(test_dataset.class2id.keys()), cm_test_name)
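
# Multi-scale test: load the 32/64/128 checkpoints of the simple network, let each
# scale cast a benign/malignant vote per image, and merge the votes with
# simple_merge_thresh.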
def main(NetClass, key_name):
    torch.set_grad_enabled(False)

    model_id = NetClass.model_id

    test_dataset_path = '{}/{}/test'.format(dataset_path, key_name)

    ck_32_name = '{}.32/model_{}_{}.pt'.format(simple_net_save_dir_prefix,
                                               model_id, key_name)
    ck_64_name = '{}.64/model_{}_{}.pt'.format(simple_net_save_dir_prefix,
                                               model_id, key_name)
    ck_128_name = '{}.128/model_{}_{}.pt'.format(simple_net_save_dir_prefix,
                                                 model_id, key_name)

    cm_net3_test_name = '{}_{}_{}.png'.format(
        simple_net_3_merge_test_cm_prefix, model_id, key_name)
    os.makedirs(os.path.split(cm_net3_test_name)[0], exist_ok=True)

    test_dataset_32 = DatasetReader(test_dataset_path, target_hw=(32, 32))
    test_dataset_64 = DatasetReader(test_dataset_path, target_hw=(64, 64))
    test_dataset_128 = DatasetReader(test_dataset_path, target_hw=(128, 128))

    net_32 = NetClass(in_dim)
    net_64 = NetClass(in_dim)
    net_128 = NetClass(in_dim)

    net_32.load_state_dict(
        torch.load(ck_32_name, map_location=torch.device('cpu')))
    net_64.load_state_dict(
        torch.load(ck_64_name, map_location=torch.device('cpu')))
    net_128.load_state_dict(
        torch.load(ck_128_name, map_location=torch.device('cpu')))

    net_32 = net_32.to(device)
    net_64 = net_64.to(device)
    net_128 = net_128.to(device)

    net_32.eval()
    net_64.eval()
    net_128.eval()

    all_pred = []
    all_label = []

    for i in range(len(test_dataset_32)):
        ims_32, cls_32 = test_dataset_32.get_im_patch_list_to_combind_predict(
            i, one_im=False)
        ims_64, cls_64 = test_dataset_64.get_im_patch_list_to_combind_predict(
            i, one_im=False)
        ims_128, cls_128 = test_dataset_128.get_im_patch_list_to_combind_predict(
            i, one_im=False)

        assert cls_32 == cls_64 == cls_128

        tmp_x = [[net_32, ims_32, cls_32], [net_64, ims_64, cls_64],
                 [net_128, ims_128, cls_128]]
        tmp_y = []

        for net, ims, cls in tmp_x:
            batch_im = torch.tensor(ims.astype(np.int32),
                                    dtype=torch.float) / 65535
            # batch_cls = torch.tensor([cls]).repeat(len(batch_im))

            batch_im = batch_im.permute(0, 3, 1, 2)

            batch_im = batch_im.to(device)
            # batch_cls = batch_cls.to(device)

            net_out = net(batch_im)
            out = torch.argmax(net_out, 1)

            # per-scale vote: 1 (malignant) when the fraction of patches classified
            # as malignant exceeds simple_thresh
            if out.sum(dtype=torch.float).item() > out.shape[0] * simple_thresh:
                tmp_y.append(1)
            else:
                tmp_y.append(0)

        all_label.append(cls_32)  # all three scales share the same label (asserted above)

        # merge the per-scale votes: predict malignant when more than
        # simple_merge_thresh of the three scales vote malignant
        if np.sum(tmp_y) > simple_merge_thresh:
            all_pred.append(1)
        else:
            all_pred.append(0)

    _accuracy = accuracy_score(all_label, all_pred)
    _malignant_precision, _malignant_recall, _malignant_f1, _ = \
        precision_recall_fscore_support(all_label, all_pred, pos_label=1, average='binary')

    _benign_precision, _benign_recall, _benign_f1, _ = \
        precision_recall_fscore_support(all_label, all_pred, pos_label=0, average='binary')

    _accuracy = float(_accuracy)
    _malignant_precision = float(_malignant_precision)
    _malignant_recall = float(_malignant_recall)
    _malignant_f1 = float(_malignant_f1)
    _benign_precision = float(_benign_precision)
    _benign_recall = float(_benign_recall)
    _benign_f1 = float(_benign_f1)

    out_line = 'test acc: {:.3f} m_prec: {:.3f} m_rec: {:.3f} m_f1: {:.3f} '\
               'b_prec: {:.3f} b_rec: {:.3f} b_f1: {:.3f} model {}_{}'.format(_accuracy,
                                                    _malignant_precision, _malignant_recall, _malignant_f1,
                                                    _benign_precision, _benign_recall, _benign_f1, model_id, key_name)

    print(out_line)
    test_out.append(out_line)

    cm = confusion_matrix(all_label, all_pred)
    draw_confusion_matrix(cm, list(test_dataset_32.class2id.keys()),
                          cm_net3_test_name)