コード例 #1
0
def parse_args():
    """Parse CLI options for launching detection training through submitit.

    Starts from the training script's own argument parser and layers the
    cluster / Slurm scheduling options on top of it.

    Returns:
        argparse.Namespace: the fully parsed command-line arguments.
    """
    base_parser = train.get_args_parser()
    parser = argparse.ArgumentParser("Submitit for detection", parents=[base_parser])
    add = parser.add_argument  # shorthand for the repeated registrations below
    add("--ngpus", default=1, type=int,
        help="Number of gpus to request on each node")
    add("--vram", default="12GB", type=str)
    add("--num_gpus", default=1, type=int)
    add("--mem_per_gpu", default=20, type=int)
    add("--nodes", default=1, type=int, help="Number of nodes to request")
    add("--timeout", default=60, type=int, help="Duration of the job")
    add("--job_dir", default="", type=str,
        help="Job dir. Leave empty for automatic.")
    add("--cluster", default=None, type=str,
        help="Use to run jobs locally.")
    add("--slurm_partition", default="NORMAL", type=str,
        help="Partition. Leave empty for automatic.")
    add("--slurm_constraint", default="", type=str,
        help="Constraint. Leave empty for automatic.")
    add("--slurm_comment", default="", type=str)
    add("--slurm_gres", default="", type=str)
    add("--slurm_exclude", default="", type=str)
    add("--checkpoint_name", default="last.ckpt", type=str)
    return parser.parse_args()
コード例 #2
0
        model.zero_grad()  # clear gradients left over from the previous sample

        # Forward pass, then softmax over the class dimension to obtain
        # per-class probabilities.
        output1 = model(image.to(device))
        output = F.softmax(output1, dim=1)
        # Top-1 prediction: highest probability and its class index.
        prediction_score, pred_label_idx = torch.topk(output, 1)

        pred_label_idx.squeeze_()  # in-place: drop singleton dims to a scalar tensor
        # NOTE(review): the +1 suggests class ids are reported 1-indexed — confirm
        # against the dataset's label convention.
        predicted_label = str(pred_label_idx.item() + 1)
        print('Predicted:', predicted_label, '(',
              prediction_score.squeeze().item(), ')')

        # Run the gradient/CAM visualisation for the current extractor, then
        # remove its hooks so they do not accumulate across iterations.
        make_grad(extractor, output1, image_orl, args.grad_min_level,
                  cam_extractors_names[idx])
        extractor.clear_hooks()


if __name__ == '__main__':
    # Build the CLI on top of the shared training-argument parser.
    parser = argparse.ArgumentParser('model training and evaluation script',
                                     parents=[get_args_parser()])
    args = parser.parse_args()

    # argparse may hand these back as strings; coerce each one to the
    # numeric type the evaluation code expects.
    coercions = {
        'num_classes': int,
        'lambda_value': float,
        'power': int,
        'slots_per_class': int,
    }
    args_dict = vars(args)
    for name, caster in coercions.items():
        args_dict[name] = caster(args_dict[name])

    for_vis(args)
コード例 #3
0
    # Resize the label map to the working resolution; nearest-neighbour keeps
    # class ids intact (no interpolation between classes).
    label = cv2.resize(label, (1024, 512), interpolation=cv2.INTER_NEAREST)
    # NHWC uint8 -> NCHW float32 scaled to [0, 1].
    # assumes `images` is a batched HWC array — TODO confirm from the caller.
    images = torch.from_numpy(np.array(np.transpose(images, (0, 3, 1, 2)), dtype="float32")/255)
    labels = torch.from_numpy(np.array([label], dtype="int64"))
    # Visualise the colourised ground-truth label before inference.
    show_single(args, ColorTransition().recover(labels[0]), "origin")
    inputs = images.to(device, dtype=torch.float32)
    labels = labels.to(device, dtype=torch.int64)
    # Run inference; `iou` accumulates the metric across calls — see iou_demo below.
    predict = for_test(args, model, inputs, labels, iou, lstm=args.lstm, need=True)
    print(predict.size())
    # Colourise and display the predicted segmentation.
    color_img = ColorTransition().recover(predict[0])
    show_single(args, color_img, "ls_noise")
    epoch_iou = iou.iou_demo()
    print(epoch_iou)


if __name__ == '__main__':
    # Build the CLI on top of the shared training-argument parser.
    parser = argparse.ArgumentParser('model training and evaluation script', parents=[get_args_parser()])
    args = parser.parse_args()
    args.use_pre = False  # presumably disables loading pretrained weights — confirm in load_model
    device = args.device
    # Path to the trained PSPNet+LSTM checkpoint to evaluate.
    model_name = "../saved_model/PSPNet_lstm_noise.pt"
    init_model = load_model(args)
    # NOTE(review): convert_model likely rewrites layers (e.g. BatchNorm
    # variants) — confirm against its definition.
    init_model = convert_model(init_model)
    init_model = init_model.to(device)
    if args.multi_gpu:
        # Wrap in DataParallel (single device id) so checkpoint key names match.
        init_model = nn.DataParallel(init_model, device_ids=[0])
    # strict=True: every checkpoint key must match the model exactly.
    init_model.load_state_dict(torch.load(model_name), strict=True)
    init_model.eval()  # inference mode: freezes dropout / batch-norm statistics
    print("load param over")

    # for lstm model inference
    make_image_lstm(init_model, device)
コード例 #4
0
def _make_loader(dataset_val, batch_size):
    """Wrap *dataset_val* in the sequential single-worker evaluation DataLoader
    used by every dataset branch in main()."""
    return torch.utils.data.DataLoader(dataset_val,
                                       batch_size,
                                       shuffle=False,
                                       num_workers=1,
                                       pin_memory=True)


def _to_pil_rgb(image):
    """Convert a CHW float tensor (assumed to be in [0, 1] — TODO confirm) to
    an RGB PIL image."""
    return Image.fromarray(
        (image.cpu().detach().numpy() * 255).astype(np.uint8).transpose(
            (1, 2, 0)),
        mode='RGB')


def main():
    """Load one validation sample and a trained SlotModel checkpoint, then run
    the `test` visualisation routine on that sample.

    Supported datasets: ConText, ImageNet, MNIST, CUB200. The checkpoint file
    name is reconstructed from the same flags the training script used.
    """
    parser = argparse.ArgumentParser('model training and evaluation script',
                                     parents=[get_args_parser()])
    args = parser.parse_args()

    # argparse may hand these back as strings; coerce them to the numeric
    # types model construction expects.
    args_dict = vars(args)
    args_for_evaluation = [
        'num_classes', 'lambda_value', 'power', 'slots_per_class'
    ]
    args_type = [int, float, int, int]
    for arg_id, arg in enumerate(args_for_evaluation):
        args_dict[arg] = args_type[arg_id](args_dict[arg])

    # The checkpoint file name encodes the training configuration.
    model_name = f"{args.dataset}_" + f"{'use_slot_' if args.use_slot else 'no_slot_'}"\
                + f"{'negative_' if args.use_slot and args.loss_status != 1 else ''}"\
                + f"{'for_area_size_'+str(args.lambda_value) + '_'+ str(args.slots_per_class) + '_' if args.cal_area_size else ''}" + 'checkpoint.pth'
    args.use_pre = False

    device = torch.device(args.device)

    # First-stage transform: resize + to-tensor. After the sample is drawn it
    # is rebound to the dataset-specific normalisation applied at the end.
    transform = transforms.Compose([
        transforms.Resize((args.img_size, args.img_size)),
        transforms.ToTensor(),
    ])
    # Con-text
    if args.dataset == 'ConText':
        train, val = MakeList(args).get_data()
        dataset_val = ConText(val, transform=transform)
        data_loader_val = _make_loader(dataset_val, args.batch_size)
        # BUG FIX: Python-3 iterators have no .next() method; use next(...).
        data = next(iter(data_loader_val))
        image = data["image"][0]
        label = data["label"][0].item()
        image_orl = _to_pil_rgb(image)
        image = transform(image_orl)
        transform = transforms.Compose([
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    elif args.dataset == 'ImageNet':
        train, val = MakeListImage(args).get_data()
        dataset_val = ConText(val, transform=transform)
        data_loader_val = _make_loader(dataset_val, args.batch_size)
        # The original advanced the iterator exactly once inside
        # `for i in range(0, 1)`; a single next(...) is equivalent and also
        # fixes the Python-2-style .next() call.
        data = next(iter(data_loader_val))
        image = data["image"][0]
        label = data["label"][0].item()
        image_orl = _to_pil_rgb(image)
        image = transform(image_orl)
        transform = transforms.Compose([
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    # MNIST
    elif args.dataset == 'MNIST':
        dataset_val = datasets.MNIST('./data/mnist',
                                     train=False,
                                     transform=transform)
        data_loader_val = _make_loader(dataset_val, args.batch_size)
        # MNIST batches are (images, targets) tuples, not dicts.
        image = next(iter(data_loader_val))[0][0]
        label = ''
        image_orl = Image.fromarray(
            (image.cpu().detach().numpy() * 255).astype(np.uint8)[0], mode='L')
        image = transform(image_orl)
        transform = transforms.Compose(
            [transforms.Normalize((0.1307, ), (0.3081, ))])
    # CUB
    elif args.dataset == 'CUB200':
        dataset_val = CUB_200(args, train=False, transform=transform)
        data_loader_val = _make_loader(dataset_val, args.batch_size)
        data = next(iter(data_loader_val))
        image = data["image"][0]
        label = data["label"][0].item()
        image_orl = _to_pil_rgb(image)
        image = transform(image_orl)
        transform = transforms.Compose([
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    # Second stage: dataset-specific normalisation (transform was rebound above).
    image = transform(image)

    print("label\t", label)
    model = SlotModel(args)
    # Map model to be loaded to specified single gpu.
    checkpoint = torch.load(f"{args.output_dir}/" + model_name,
                            map_location=args.device)
    # Print the checkpoint's top-level keys for inspection (values unused).
    for k in checkpoint:
        print(k)
    model.load_state_dict(checkpoint["model"])

    test(args, model, device, image_orl, image, label, vis_id=args.vis_id)