Example No. 1
def create_model(base_model, ckpt_file=None, from_measurement=None):
    # prepare model
    if base_model == 'modelnet_x3_l4':
        net = PointCNN.modelnet_x3_l4()
    elif base_model == 'shapenet_x8_2048_fps':
        net = PointCNN.shapenet_x8_2048_fps()
    else:
        raise NotImplementedError
    if ckpt_file is not None:
        ModelRecorder.resume_model(net,
                                   ckpt_file,
                                   from_measurement=from_measurement)
    return net
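A minimal usage sketch of create_model (the checkpoint path 'checkpoints/best_model.pth' is hypothetical; from_measurement names the metric the checkpoint was recorded under, as in the examples below):

# Minimal usage sketch; the checkpoint path is hypothetical.
net = create_model('modelnet_x3_l4')
net = create_model('shapenet_x8_2048_fps',
                   ckpt_file='checkpoints/best_model.pth',
                   from_measurement='acc')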
Example No. 2
def visualization_process():
    # prepare data
    image_num = 2
    image_file = "assets/1/000{}.jpg".format(image_num)
    pred_img = preprocess_image(Image.open(image_file))
    # prepare model
    print("evaluation on : {}".format(config.base_model))
    net = create_model(config.base_model)
    print("load pretained model from {}".format(config.test.pretrained_model))
    ModelRecorder.resume_model(net,
                               config.test.pretrained_model,
                               from_measurement="acc")
    show_first_conv_feature(net, pred_img, "{}_C".format(image_num))
    visualize(net, pred_img, 1, "{}_C".format(image_num))
Example No. 3
def evaluation_process(detail=False):
    valid_set = ds.ClipArt(config.validation.data_set, data_aug=False)

    valid_loader = DataLoader(valid_set,
                              batch_size=config.validation.batch_size,
                              shuffle=False,
                              num_workers=config.num_workers,
                              drop_last=False)
    print('valid set size: {}'.format(len(valid_set)))

    # prepare model
    print("evaluation on : {}".format(config.base_model))
    net = create_model(config.base_model)
    print("load pretained model from {}".format(config.test.pretrained_model))
    ModelRecorder.resume_model(net,
                               config.test.pretrained_model,
                               from_measurement="acc")

    net = nn.DataParallel(net)
    if not detail:
        with torch.no_grad():
            acc, conf_matrix = evaluate(valid_loader, net, True)
        plot_conf_matrix(valid_set.get_categories(),
                         conf_matrix,
                         save_file='{}/{}.pdf'.format(config.result_sub_folder,
                                                      config.comment))
    else:
        with torch.no_grad():
            acc, conf_matrix, features, labels = evaluate(
                valid_loader, net, True, True)
        plot_conf_matrix(valid_set.get_categories(),
                         conf_matrix,
                         save_file='{}/{}.pdf'.format(config.result_sub_folder,
                                                      config.comment))
        array2tsv(features, '{}/{}.tsv'.format(config.result_sub_folder,
                                               config.comment))
        labels_file = open(
            '{}/{}_label.tsv'.format(config.result_sub_folder, config.comment),
            'w')
        labels_txt = ""
        for label in labels:
            labels_txt += '{}\n'.format(int(label))
        labels_file.write(labels_txt[:-1])
        labels_file.close()
    print("Finished!")
Example No. 4
def extract_test_rst_process():
    test_set = ds.ClipArtTest(config.test.data_set)

    test_loader = DataLoader(test_set,
                             batch_size=config.test.batch_size,
                             shuffle=False,
                             num_workers=config.num_workers,
                             drop_last=False)
    print('test set size: {}'.format(len(test_set)))

    # prepare model
    print("evaluation on : {}".format(config.base_model))
    net = create_model(config.base_model)
    print("load pretained model from {}".format(config.test.pretrained_model))
    ModelRecorder.resume_model(net,
                               config.test.pretrained_model,
                               from_measurement="acc")
    net = nn.DataParallel(net)

    net.eval()
    rst = "id,label\n"

    with torch.no_grad():
        for i, (batch_data, img_names) in enumerate(test_loader):
            batch_data = batch_data.to(config.device)
            raw_out = net(batch_data)
            pred = torch.argmax(raw_out.detach(), dim=1)
            pred = list(pred.cpu().numpy())
            assert len(pred) == len(img_names)
            for j, p in enumerate(pred):
                rst += '{},{}\n'.format(img_names[j], ds.CONVERT_TABLE[p])
    save_file_name = 's_test_rst.txt'
    save_file = open(save_file_name, 'w')
    save_file.write(rst)
    save_file.close()
    print("Finished!")
Example No. 5
def evaluation_process(detail=False):

    if config.dataset == "ModelNet40":
        valid_set = ModelNet40(partition='test')
    elif config.dataset == "ModelNet10":
        valid_set = ModelNet10(partition='test')
    elif config.dataset == "ShapeNetParts":
        valid_set = ShapeNetPart(partition='test')
    else:
        raise NotImplementedError

    valid_loader = DataLoader(valid_set,
                              batch_size=config.validation.batch_size,
                              shuffle=False,
                              num_workers=config.num_workers,
                              drop_last=False)
    print('valid set size: {}'.format(len(valid_set)))

    # prepare model
    print("evaluation on : {}".format(config.base_model))
    net = create_model(config.base_model).to(config.device)
    print("load pretained model from {}".format(config.test.pretrained_model))
    ModelRecorder.resume_model(net,
                               config.test.pretrained_model,
                               from_measurement="acc")

    net = nn.DataParallel(net)
    if not detail:
        with torch.no_grad():
            if config.task == "seg":
                validation_loss, validation_acc, avg_per_class_acc, val_ious = evaluate(
                    valid_loader, net, html_path="validation_output")
            else:
                validation_loss, acc, conf_matrix = evaluate(
                    valid_loader, net, True)
        if not config.task == "seg":
            plot_conf_matrix(valid_set.get_categories(),
                             conf_matrix,
                             save_file='{}conf_matrix.pdf'.format(
                                 config.result_sub_folder))
    else:
        if not config.task == "seg":
            with torch.no_grad():
                validation_loss, acc, conf_matrix, features, labels = evaluate(
                    valid_loader, net, True, True)
            plot_conf_matrix(valid_set.get_categories(),
                             conf_matrix,
                             save_file='{}conf_matrix.pdf'.format(
                                 config.result_sub_folder))
            array2tsv(features,
                      '{}features.tsv'.format(config.result_sub_folder))

            np.save('{}features.npy'.format(config.result_sub_folder),
                    features)
            np.save('{}labels.npy'.format(config.result_sub_folder), labels)
            labels_file = open('{}labels.tsv'.format(config.result_sub_folder),
                               'w')
            labels_txt = ""
            for label in labels:
                labels_txt += '{}\n'.format(int(label))
            labels_file.write(labels_txt[:-1])
            labels_file.close()
        else:
            raise NotImplementedError
    print("Finished!")