Example #1
def get_model():

    # Get model from config
    if config.model == "resnet18":
        model = models.resnet18(pretrained=config.pretrained)
    elif config.model == "resnet34":
        model = models.resnet34(pretrained=config.pretrained)
    elif config.model == 'resnet50':
        model = models.resnet50(pretrained=config.pretrained)
    elif config.model == "resnet101":
        model = models.resnet101(pretrained=config.pretrained)
    elif config.model == "resnet152":
        model = models.resnet152(pretrained=config.pretrained)
    elif config.model == "resnext50_32x4d":
        model = models.resnet34(pretrained=config.pretrained)
    elif config.model == 'resnext101_32x8d':
        model = models.resnet50(pretrained=config.pretrained)
    elif config.model == "wide_resnet50_2":
        model = models.resnet101(pretrained=config.pretrained)
    elif config.model == "wide_resnet101_2":
        model = models.resnet152(pretrained=config.pretrained)
    else:
        raise ValueError('{} not supported'.format(config.model))

    # Initialize fc layer
    (in_features, out_features) = model.fc.in_features, model.fc.out_features
    model.fc = torch.nn.Linear(in_features, out_features)
    return model
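A table-driven variant of the dispatch above can replace the if/elif chain. The sketch below is illustrative only: `_BACKBONES` and `get_model_from_table` are hypothetical names, and it assumes the same torchvision `models` import and a `config` object exposing `model` and `pretrained`.

import torch
from torchvision import models

# Hypothetical table-driven equivalent of get_model() above.
_BACKBONES = {
    "resnet18": models.resnet18,
    "resnet34": models.resnet34,
    "resnet50": models.resnet50,
    "resnet101": models.resnet101,
    "resnet152": models.resnet152,
    "resnext50_32x4d": models.resnext50_32x4d,
    "resnext101_32x8d": models.resnext101_32x8d,
    "wide_resnet50_2": models.wide_resnet50_2,
    "wide_resnet101_2": models.wide_resnet101_2,
}

def get_model_from_table(config):
    try:
        builder = _BACKBONES[config.model]
    except KeyError:
        raise ValueError('{} not supported'.format(config.model))
    model = builder(pretrained=config.pretrained)
    # Re-initialize the fc layer, as in the example above
    in_features, out_features = model.fc.in_features, model.fc.out_features
    model.fc = torch.nn.Linear(in_features, out_features)
    return model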
Example #2
def create_model(name, num_classes):
    if name == 'resnet34':
        model = models.resnet34(True)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        nn.init.xavier_uniform_(model.fc.weight)
        nn.init.constant_(model.fc.bias, 0)
    elif name == 'resnet152':
        model = models.resnet152(True)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        nn.init.xavier_uniform_(model.fc.weight)
        nn.init.constant_(model.fc.bias, 0)
    elif name == 'densenet121':
        model = models.densenet121(True)
        model.classifier = nn.Linear(model.classifier.in_features, num_classes)
        nn.init.xavier_uniform_(model.classifier.weight)
        nn.init.constant_(model.classifier.bias, 0)
    elif name == 'vgg11_bn':
        model = models.vgg11_bn(pretrained=False, num_classes=num_classes)
    elif name == 'vgg19_bn':
        model = models.vgg19_bn(True)
        model.classifier._modules['6'] = nn.Linear(model.classifier._modules['6'].in_features, num_classes)
        nn.init.xavier_uniform_(model.classifier._modules['6'].weight)
        nn.init.constant_(model.classifier._modules['6'].bias, 0)
    elif name == 'alexnet':
        model = models.alexnet(True)
        model.classifier._modules['6'] = nn.Linear(model.classifier._modules['6'].in_features, num_classes)
        nn.init.xavier_uniform_(model.classifier._modules['6'].weight)
        nn.init.constant_(model.classifier._modules['6'].bias, 0)
    else:
        model = Net(num_classes)

    return model
Example #3
def create_model(name, num_classes):
    if name == 'resnet34':
        model = models.resnet34(True)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        nn.init.xavier_uniform_(model.fc.weight)
        nn.init.constant_(model.fc.bias, 0)
    elif name == 'resnet50':
        model = models.resnet50(True)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        nn.init.xavier_uniform_(model.fc.weight)
        nn.init.constant_(model.fc.bias, 0)
    elif name == 'resnet152':
        model = models.resnet152(True)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        nn.init.xavier_uniform_(model.fc.weight)
        nn.init.constant_(model.fc.bias, 0)
    elif name == 'seresnet50':
        model = models.se_resnet50()
        model.last_linear = nn.Linear(model.last_linear.in_features,
                                      num_classes,
                                      bias=True)
    elif name == 'seresnet152':
        model = models.se_resnet152()
        model.last_linear = nn.Linear(model.last_linear.in_features,
                                      num_classes,
                                      bias=True)
    elif name == 'dpn131':
        model = models.dpn131()
        model.classifier = nn.Conv2d(2688,
                                     num_classes,
                                     kernel_size=1,
                                     bias=True)
    elif name == 'densenet121':
        model = models.densenet121(True)
        model.classifier = nn.Linear(model.classifier.in_features, num_classes)
        nn.init.xavier_uniform_(model.classifier.weight)
        nn.init.constant_(model.classifier.bias, 0)
    elif name == 'vgg11_bn':
        model = models.vgg11_bn(pretrained=False, num_classes=num_classes)
    elif name == 'vgg19_bn':
        model = models.vgg19_bn(True)
        model.classifier._modules['6'] = nn.Linear(
            model.classifier._modules['6'].in_features, num_classes)
        nn.init.xavier_uniform_(model.classifier._modules['6'].weight)
        nn.init.constant_(model.classifier._modules['6'].bias, 0)
    elif name == 'alexnet':
        model = models.alexnet(True)
        model.classifier._modules['6'] = nn.Linear(
            model.classifier._modules['6'].in_features, num_classes)
        nn.init.xavier_uniform_(model.classifier._modules['6'].weight)
        nn.init.constant_(model.classifier._modules['6'].bias, 0)
    else:
        model = Net(num_classes)

    return model
Example #4
def demo2(image_paths, output_dir, cuda):
    """
    Generate Grad-CAM at different layers of ResNet-152
    """

    device = get_device(cuda)

    # Synset words
    classes = get_classtable()

    # Model
    model = models.resnet152(pretrained=True)
    model.to(device)
    model.eval()

    # The stem ReLU and the four residual layers
    target_layers = ["relu", "layer1", "layer2", "layer3", "layer4"]
    target_class = 243  # "bull mastiff"

    # Images
    images = []
    raw_images = []
    print("Images:")
    for i, image_path in enumerate(image_paths):
        print("\t#{}: {}".format(i, image_path))
        image, raw_image = preprocess(image_path)
        images.append(image)
        raw_images.append(raw_image)
    images = torch.stack(images).to(device)

    gcam = GradCAM(model=model)
    probs, ids = gcam.forward(images)
    ids_ = torch.LongTensor([[target_class]] * len(images)).to(device)
    gcam.backward(ids=ids_)

    for target_layer in target_layers:
        print("Generating Grad-CAM @{}".format(target_layer))

        # Grad-CAM
        regions = gcam.generate(target_layer=target_layer)

        for j in range(len(images)):
            print("\t#{}: {} ({:.5f})".format(
                j, classes[target_class], float(probs[ids == target_class])))

            save_gradcam(
                filename=osp.join(
                    output_dir,
                    "{}-{}-gradcam-{}-{}.png".format(j, "resnet152",
                                                     target_layer,
                                                     classes[target_class]),
                ),
                gcam=regions[j, 0],
                raw_image=raw_images[j],
            )
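`save_gradcam` above comes from the surrounding project and is not shown here. The sketch below is an assumed minimal version (not the project's actual code) that overlays a JET colormap on the input; it takes `gcam` as a 2-D tensor with values in [0, 1] and `raw_image` as a BGR uint8 array of the same spatial size.

import cv2
import numpy as np

def save_gradcam_sketch(filename, gcam, raw_image):
    # Assumed shapes: gcam is a 2-D torch tensor in [0, 1]; raw_image is an HxWx3 BGR uint8 array.
    heatmap = cv2.applyColorMap(np.uint8(gcam.cpu().numpy() * 255.0), cv2.COLORMAP_JET)
    overlay = 0.5 * heatmap.astype(np.float64) + 0.5 * raw_image.astype(np.float64)
    cv2.imwrite(filename, np.uint8(overlay))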
Example #5
def initModel(args):
    # Setup Model
    if args.arch == "resnet50":
        model = models.resnet50(pretrained=True,
                                num_classes=18,
                                scale=args.scale)
    elif args.arch == "resnet101":
        model = models.resnet101(pretrained=True,
                                 num_classes=18,
                                 scale=args.scale)
    elif args.arch == "resnet152":
        model = models.resnet152(pretrained=True,
                                 num_classes=18,
                                 scale=args.scale)
    elif args.arch == "vgg16":
        model = models.vgg16(pretrained=True, num_classes=18)
    elif args.arch == "googlenet":
        model = models.googlenet(pretrained=True, num_classes=18)

    for param in model.parameters():
        param.requires_grad = False

    #model = model.cuda()

    if args.resume is not None:
        if os.path.isfile(args.resume):
            print(("Loading model and optimizer from checkpoint '{}'".format(
                args.resume)))
            checkpoint = torch.load(args.resume,
                                    map_location=torch.device('cpu'))

            # model.load_state_dict(checkpoint['state_dict'])
            d = collections.OrderedDict()
            for key, value in list(checkpoint['state_dict'].items()):
                tmp = key[7:]
                d[tmp] = value
            model.load_state_dict(d)

            print(("Loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
            sys.stdout.flush()
        else:
            print(("No checkpoint found at '{}'".format(args.resume)))
            sys.stdout.flush()

    model.eval()

    summary(model, (3, 640, 640))

    return model
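The `key[7:]` loop used when resuming strips the `module.` prefix that nn.DataParallel adds to every parameter name; the same idiom recurs in several examples below. A small reusable helper, written here only as a sketch:

import collections

def strip_module_prefix(state_dict):
    # Drop the 'module.' prefix that nn.DataParallel prepends to every key,
    # so the weights can be loaded into a plain (non-parallel) model.
    cleaned = collections.OrderedDict()
    for key, value in state_dict.items():
        cleaned[key[len('module.'):] if key.startswith('module.') else key] = value
    return cleaned

# Hypothetical usage: model.load_state_dict(strip_module_prefix(checkpoint['state_dict']))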
Example #6
def test(args):
    data_loader = IC15TestLoader(long_size=args.long_size)
    test_loader = torch.utils.data.DataLoader(
        data_loader,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        drop_last=True)

    submit_path = 'outputs'
    if os.path.exists(submit_path):
        shutil.rmtree(submit_path)

    os.mkdir(submit_path)

    # Setup Model
    if args.arch == "resnet50":
        model = models.resnet50(pretrained=True, num_classes=7, scale=args.scale)
    elif args.arch == "resnet101":
        model = models.resnet101(pretrained=True, num_classes=7, scale=args.scale)
    elif args.arch == "resnet152":
        model = models.resnet152(pretrained=True, num_classes=7, scale=args.scale)
    
    for param in model.parameters():
        param.requires_grad = False

    model = model.cuda()
    
    if args.resume is not None:                                         
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            
            # model.load_state_dict(checkpoint['state_dict'])
            d = collections.OrderedDict()
            for key, value in checkpoint['state_dict'].items():
                tmp = key[7:]
                d[tmp] = value
            model.load_state_dict(d)

            print("Loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
            sys.stdout.flush()
        else:
            print("No checkpoint found at '{}'".format(args.resume))
            sys.stdout.flush()
    state_dict = model.state_dict() #if data_parallel else model.state_dict()
    torch.save(state_dict,'./ctpn_pse_batch_new_pre_train_ic15.pth')
Example #7
def build_model(model_name, num_classes, pretrain):
    if model_name == 'resnet50':
        net = resnet50(num_classes=num_classes, pretrain=pretrain)
    elif model_name == 'resnet18':
        net = resnet18(num_classes=num_classes, pretrain=pretrain)
    elif model_name == 'resnet34':
        net = resnet34(num_classes=num_classes, pretrain=pretrain)
    elif model_name == 'resnet101':
        net = resnet101(num_classes=num_classes, pretrain=pretrain)
    elif model_name == 'resnet152':
        net = resnet152(num_classes=num_classes, pretrain=pretrain)
    elif model_name == 'resnet50se':
        net = resnet50se(num_classes=num_classes, pretrain=pretrain)
    elif model_name == 'resnet50dilated':
        net = resnet50_dilated(num_classes=num_classes, pretrain=pretrain)
    elif model_name == 'resnet50dcse':
        net = resnet50_dcse(num_classes=num_classes, pretrain=pretrain)
    else:
        print('wait a minute')
    return net
Example #8
File: test_base.py  Project: cnzeki/PSENet
def load_model(args):
    epoch = 0
    # Setup Model
    if args.arch == "resnet50":
        model = models.resnet50(pretrained=True, num_classes=7, scale=args.scale)
    elif args.arch == "resnet101":
        model = models.resnet101(pretrained=True, num_classes=7, scale=args.scale)
    elif args.arch == "resnet152":
        model = models.resnet152(pretrained=True, num_classes=7, scale=args.scale)

    for param in model.parameters():
        param.requires_grad = False

    model = model.cuda()

    if args.resume is not None:
        if os.path.isfile(args.resume):
            print(("Loading model and optimizer from checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)

            # model.load_state_dict(checkpoint['state_dict'])
            d = collections.OrderedDict()
            for key, value in list(checkpoint['state_dict'].items()):
                tmp = key[7:]
                d[tmp] = value
            model.load_state_dict(d)
            epoch = checkpoint['epoch']
            print(("Loaded checkpoint '{}' (epoch {})"
                   .format(args.resume, checkpoint['epoch'])))
            sys.stdout.flush()
        else:
            print(("No checkpoint found at '{}'".format(args.resume)))
            sys.stdout.flush()

    model.eval()
    return model, epoch
Example #9
def test(args):
    import torch
    data_loader = IC15TestLoader(root_dir=args.root_dir, long_size=args.long_size)
    test_loader = torch.utils.data.DataLoader(
        data_loader,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        drop_last=True)

    # Setup Model
    if args.arch == "resnet50":
        model = models.resnet50(pretrained=False, num_classes=1, scale=args.scale, train_mode=False)
    elif args.arch == "resnet101":
        model = models.resnet101(pretrained=True, num_classes=1, scale=args.scale)
    elif args.arch == "resnet152":
        model = models.resnet152(pretrained=True, num_classes=1, scale=args.scale)
    
    for param in model.parameters():
        param.requires_grad = False

    if args.gpus > 0:
        model = model.cuda()
    
    if args.resume is not None:                                         
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(args.resume))
            device = torch.device('cpu') if args.gpus < 0 else None
            checkpoint = torch.load(args.resume, map_location=device)
            
            # model.load_state_dict(checkpoint['state_dict'])
            d = collections.OrderedDict()
            for key, value in checkpoint['state_dict'].items():
                tmp = key[7:]
                d[tmp] = value
            model.load_state_dict(d)

            print("Loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
            sys.stdout.flush()
        else:
            print("No checkpoint found at '{}'".format(args.resume))
            sys.stdout.flush()
    model.eval()
    if args.onnx:
        import torch.onnx.symbolic_opset9 as onnx_symbolic
        def upsample_nearest_2d(g, input, output_size):
            scales = g.op('Constant', value_t=torch.Tensor([1., 1., 2., 2.]))
            return g.op("Upsample", input, scales, mode_s='nearest')
        onnx_symbolic.upsample_nearest2d = upsample_nearest_2d
        dummy_input = torch.autograd.Variable(torch.randn(1, 3, 640, 640)).cpu()
        torch.onnx.export(model, dummy_input, 'sanet.onnx', verbose=False,
                          input_names=["input"],
                          output_names=["gaussian_map", "border_map"],
                          dynamic_axes={'input': {0: 'b', 2: 'h', 3: 'w'},
                                        'gaussian_map': {0: 'b', 2: 'h', 3: 'w'},
                                        'border_map': {0: 'b', 2: 'h', 3: 'w'}})
        return 0
    total_frame = 0.0
    total_time = 0.0
    
    for idx, (org_img, img, scale_val) in enumerate(test_loader):
        print('progress: %d / %d'%(idx, len(test_loader)))
        sys.stdout.flush()
        if args.gpus > 0:
            img = Variable(img.cuda(), volatile=True)
        org_img = org_img.numpy().astype('uint8')[0]
        text_box = org_img.copy()
        resize_img = img.cpu().numpy().astype('uint8')[0].transpose((1, 2, 0)).copy()
        if args.gpus > 0:
            torch.cuda.synchronize()
        start = time.time()

        outputs = model(img)
        infer_time = time.time()
        probability_map, border_map = outputs[0].sigmoid(), outputs[1].sigmoid()
#         print(probability_map.max(), probability_map.min())
        score = probability_map[0, 0]
        border_score = border_map[0, 0]
#         prediction_map = textfill(score.cpu().numpy(), border_score, top_threshold=0.7, end_thershold=0.2)
        post_time = time.time()
        center_text = torch.where(score > 0.7, torch.ones_like(score), torch.zeros_like(score))
        center_text = center_text.data.cpu().numpy().astype(np.uint8)

        text_region = torch.where(score > 0.5, torch.ones_like(score), torch.zeros_like(score))
        border_region = torch.where(border_score > 0.9, torch.ones_like(border_score), torch.zeros_like(border_score))
        prediction_map = text_region.data.cpu().numpy()
        border_region = border_region.data.cpu().numpy()
        prediction_map[border_region==1] = 0
        
        prob_map = probability_map.cpu().numpy()[0, 0] * 255
        bord_map = border_map[0, 0].cpu().numpy() * 255
        out_path = 'outputs/vis_ic15/'
        image_name = data_loader.img_paths[idx].split('/')[-1].split('.')[0]
#         cv2.imwrite(out_path + image_name + '_prob.png', prob_map.astype(np.uint8))
#         cv2.imwrite(out_path + image_name + '_bd.png', bord_map.astype(np.uint8))
#         cv2.imwrite(out_path + image_name + '_tr.png', text_region.astype(np.uint8) * 255)
#         cv2.imwrite(out_path + image_name + '_fl.png', prediction_map.astype(np.uint8) * 255)
    
        scale = (org_img.shape[1] * 1.0 / img.shape[1], org_img.shape[0] * 1.0 / img.shape[0])
        bboxes = []
        scale_val = scale_val.cpu().numpy()
        nLabels, labels, stats, centroids = cv2.connectedComponentsWithStats(prediction_map.astype(np.uint8), connectivity=4)
        t5 = time.time()
#         nLabels = prediction_map.max()
#         print("nLabels:", nLabels)
        img_h, img_w = prediction_map.shape[:2]
        for k in range(1, nLabels):
        
            size = stats[k, cv2.CC_STAT_AREA]
#             if size < 10: continue
            # make segmentation map
            segmap = np.zeros(prediction_map.shape, dtype=np.uint8)
            segmap[labels==k] = 255
#             segmap[np.logical_and(border_score > 0.7, score.cpu().numpy() < 0.05)] = 0   # remove link area
            x, y = stats[k, cv2.CC_STAT_LEFT], stats[k, cv2.CC_STAT_TOP]
            w, h = stats[k, cv2.CC_STAT_WIDTH], stats[k, cv2.CC_STAT_HEIGHT]
#             print("xywh:", x, y, w, h, " size:", size)
            niter = int(math.sqrt(size * min(w, h) / (w * h)) * 4.3)
            sx, ex, sy, ey = x - niter, x + w + niter + 1, y - niter, y + h + niter + 1
#             print("info:", sy, ey, sx, ex)
            # boundary check
            if sx < 0 : sx = 0
            if sy < 0 : sy = 0
            if ex >= img_w: ex = img_w
            if ey >= img_h: ey = img_h
            kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(1 + niter, 1 + niter))

            segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel)
            ############### original postprocess ################
#             # make segmentation map
#             segmap = np.zeros(score.shape, dtype=np.uint8)
#             segmap[prediction_map==k] = 255

#             # contourexpand
#             text_area = np.sum(segmap)
#             kernel = dilated_kernel(text_area)
            ############## original postprocess ################
            
#             segmap = cv2.dilate(segmap, kernel, iterations=1)
            np_contours = np.roll(np.array(np.where(segmap!=0)),1,axis=0).transpose().reshape(-1,2)
            rectangle = cv2.minAreaRect(np_contours)
            box = cv2.boxPoints(rectangle) * 4
            box = box / scale_val
            box = box.astype('int32')
            bboxes.append(box)
        t6 = time.time()
        print("infer_time:{}, post_time:{}, expand_time:{}".format(infer_time-start, post_time-infer_time, t6-t5))
        # find contours
        bboxes = np.array(bboxes)
        num_box = bboxes.shape[0]
        if args.gpus > 0:
            torch.cuda.synchronize()
        end = time.time()
        total_frame += 1
        total_time += (end - start)
        sys.stdout.flush()

        for bbox in bboxes:
            cv2.drawContours(text_box, [bbox.reshape(4, 2)], -1, (0, 255, 0), 2)
        image_name = ".".join(data_loader.img_paths[idx].split('/')[-1].split('.')[:-1])
        write_result_as_txt(image_name, bboxes.reshape((-1, 8)), 'outputs/submit_ic15/')
        
        debug(idx, data_loader.img_paths, [[text_box]], 'outputs/vis_ic15/')
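Once `torch.onnx.export` has written `sanet.onnx`, the graph can be sanity-checked outside PyTorch. A minimal sketch, assuming the `onnxruntime` package is available and reusing the input/output names and dynamic axes declared in the export call above:

import numpy as np
import onnxruntime as ort

# Load the exported graph on CPU and run one dummy forward pass.
session = ort.InferenceSession('sanet.onnx', providers=['CPUExecutionProvider'])
dummy = np.random.randn(1, 3, 640, 640).astype(np.float32)
gaussian_map, border_map = session.run(['gaussian_map', 'border_map'], {'input': dummy})
print(gaussian_map.shape, border_map.shape)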
Example #10
def train_net(args):
    torch.manual_seed(7)
    np.random.seed(7)
    checkpoint = args.checkpoint
    start_epoch = 0
    best_acc = float('-inf')
    writer = SummaryWriter()
    epochs_since_improvement = 0

    # Initialize / load checkpoint
    if checkpoint is None:
        if args.network == 'r18':
            model = resnet18(args)
        elif args.network == 'r34':
            model = resnet34(args)
        elif args.network == 'r50':
            model = resnet50(args)
        elif args.network == 'r101':
            model = resnet101(args)
        elif args.network == 'r152':
            model = resnet152(args)
        elif args.network == 'mobile':
            model = MobileNetV2()
        else:
            raise TypeError('network {} is not supported.'.format(
                args.network))

        # print(model)
        model = nn.DataParallel(model)
        metric_fc = ArcMarginModel(args)
        metric_fc = nn.DataParallel(metric_fc)

        if args.optimizer == 'sgd':
            optimizer = torch.optim.SGD([{
                'params': model.parameters()
            }, {
                'params': metric_fc.parameters()
            }],
                                        lr=args.lr,
                                        momentum=args.mom,
                                        weight_decay=args.weight_decay)
        else:
            optimizer = torch.optim.Adam([{
                'params': model.parameters()
            }, {
                'params': metric_fc.parameters()
            }],
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)

    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        model = checkpoint['model']
        metric_fc = checkpoint['metric_fc']
        optimizer = checkpoint['optimizer']

    logger = get_logger()

    # Move to GPU, if available
    model = model.to(device)
    metric_fc = metric_fc.to(device)

    # Loss function
    if args.focal_loss:
        criterion = FocalLoss(gamma=args.gamma).to(device)
    else:
        criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    train_dataset = ArcFaceDataset('train')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4)

    scheduler = StepLR(optimizer, step_size=args.lr_step, gamma=0.1)

    # Epochs
    for epoch in range(start_epoch, args.end_epoch):
        # One epoch's training
        train_loss, train_acc = train(train_loader=train_loader,
                                      model=model,
                                      metric_fc=metric_fc,
                                      criterion=criterion,
                                      optimizer=optimizer,
                                      epoch=epoch,
                                      logger=logger)

        writer.add_scalar('model/train_loss', train_loss, epoch)
        writer.add_scalar('model/train_acc', train_acc, epoch)

        # One epoch's validation
        lfw_acc, threshold = lfw_test(model)
        writer.add_scalar('model/valid_acc', lfw_acc, epoch)
        writer.add_scalar('model/valid_thres', threshold, epoch)

        # Check if there was an improvement
        is_best = lfw_acc > best_acc
        best_acc = max(lfw_acc, best_acc)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(epoch, epochs_since_improvement, model, metric_fc,
                        optimizer, best_acc, is_best)

        scheduler.step(epoch)
Example #11
def train_net(args):
    torch.manual_seed(7)
    np.random.seed(7)
    checkpoint = args.checkpoint
    start_epoch = 0
    best_acc = 0
    writer = SummaryWriter()
    epochs_since_improvement = 0

    # Initialize / load checkpoint
    if checkpoint is None:
        if args.network == 'r18':
            model = resnet18(args)
        elif args.network == 'r34':
            model = resnet34(args)
        elif args.network == 'r50':
            model = resnet50(args)
        elif args.network == 'r101':
            model = resnet101(args)
        elif args.network == 'r152':
            model = resnet152(args)
        elif args.network == 'mobile':
            model = MobileNet(1.0)
        elif args.network == 'mr18':
            print("mr18")
            model = myResnet18()
        else:
            model = resnet_face18(args.use_se)
        model = nn.DataParallel(model)
        metric_fc = ArcMarginModel(args)
        metric_fc = nn.DataParallel(metric_fc)

        if args.optimizer == 'sgd':
            optimizer = torch.optim.SGD([{
                'params': model.parameters()
            }, {
                'params': metric_fc.parameters()
            }],
                                        lr=args.lr,
                                        momentum=args.mom,
                                        weight_decay=args.weight_decay)
        else:
            optimizer = torch.optim.Adam([{
                'params': model.parameters()
            }, {
                'params': metric_fc.parameters()
            }],
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)

    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        model = checkpoint['model']
        metric_fc = checkpoint['metric_fc']
        optimizer = checkpoint['optimizer']

    logger = get_logger()

    # Move to GPU, if available
    model = model.to(device)
    metric_fc = metric_fc.to(device)

    # Loss function
    if args.focal_loss:
        criterion = FocalLoss(gamma=args.gamma).to(device)
    else:
        criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    train_dataset = ArcFaceDataset('train')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True)

    scheduler = StepLR(optimizer, step_size=args.lr_step, gamma=0.1)

    # Epochs
    for epoch in range(start_epoch, args.end_epoch):
        scheduler.step()

        if args.full_log:
            lfw_acc, threshold = lfw_test(model)
            writer.add_scalar('LFW_Accuracy', lfw_acc, epoch)
            full_log(epoch)

        start = datetime.now()
        # One epoch's training
        train_loss, train_top5_accs = train(train_loader=train_loader,
                                            model=model,
                                            metric_fc=metric_fc,
                                            criterion=criterion,
                                            optimizer=optimizer,
                                            epoch=epoch,
                                            logger=logger,
                                            writer=writer)

        writer.add_scalar('Train_Loss', train_loss, epoch)
        writer.add_scalar('Train_Top5_Accuracy', train_top5_accs, epoch)

        end = datetime.now()
        delta = end - start
        print('{} seconds'.format(delta.seconds))

        # One epoch's validation
        lfw_acc, threshold = lfw_test(model)
        writer.add_scalar('LFW Accuracy', lfw_acc, epoch)

        # Check if there was an improvement
        is_best = lfw_acc > best_acc
        best_acc = max(lfw_acc, best_acc)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(epoch, epochs_since_improvement, model, metric_fc,
                        optimizer, best_acc, is_best)
Example #12
trainloader = torch.utils.data.DataLoader(trainset,
                                          batch_size=BATCH_SIZE,
                                          shuffle=True,
                                          num_workers=4)
testloader = torch.utils.data.DataLoader(testset,
                                         batch_size=BATCH_SIZE,
                                         shuffle=False,
                                         num_workers=4)

net = None
if args.depth == 18:
    net = models.resnet18()
if args.depth == 50:
    net = models.resnet50()
if args.depth == 101:
    net = models.resnet101()
if args.depth == 152:
    net = models.resnet152()

net.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=LR, weight_decay=5e-4, momentum=0.9)

if __name__ == "__main__":
    best_acc = 0
    for epoch in range(args.epoch):
        if epoch in [60, 140, 180]:
            for param_group in optimizer.param_groups:
                param_group['lr'] /= 10
        net.train()
        sum_loss = 0.0
        correct = 0.0
        total = 0.0
Example #13
def train_net(args):
    torch.manual_seed(7)
    np.random.seed(7)
    checkpoint = args.checkpoint
    start_epoch = 0
    best_acc = float('-inf')
    writer = SummaryWriter()
    epochs_since_improvement = 0

    # Initialize / load checkpoint
    if checkpoint is None:
        if args.network == 'r18':
            model = resnet18(args)
        elif args.network == 'r34':
            model = resnet34(args)
        elif args.network == 'r50':
            model = resnet50(args)
        elif args.network == 'r101':
            model = resnet101(args)
        elif args.network == 'r152':
            model = resnet152(args)
        else:
            raise TypeError('network {} is not supported.'.format(
                args.network))

        if args.pretrained:
            model.load_state_dict(torch.load('insight-face-v3.pt'))

        model = nn.DataParallel(model)
        metric_fc = ArcMarginModel(args)
        metric_fc = nn.DataParallel(metric_fc)

        if args.optimizer == 'sgd':
            optimizer = torch.optim.SGD([{
                'params': model.parameters()
            }, {
                'params': metric_fc.parameters()
            }],
                                        lr=args.lr,
                                        momentum=args.mom,
                                        nesterov=True,
                                        weight_decay=args.weight_decay)
        else:
            optimizer = torch.optim.Adam([{
                'params': model.parameters()
            }, {
                'params': metric_fc.parameters()
            }],
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)

    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        model = checkpoint['model']
        metric_fc = checkpoint['metric_fc']
        optimizer = checkpoint['optimizer']

    # Move to GPU, if available
    model = model.to(device)
    metric_fc = metric_fc.to(device)

    # Loss function
    if args.focal_loss:
        criterion = FocalLoss(gamma=args.gamma)
    else:
        criterion = nn.CrossEntropyLoss()

    # Custom dataloaders
    # train_dataset = ArcFaceDataset('train')
    # train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
    #                                            num_workers=num_workers)
    train_dataset = ArcFaceDatasetBatched('train', img_batch_size)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size //
                                               img_batch_size,
                                               shuffle=True,
                                               num_workers=num_workers,
                                               collate_fn=batched_collate_fn)

    scheduler = MultiStepLR(optimizer, milestones=[8, 16, 24, 32], gamma=0.1)

    # Epochs
    for epoch in range(start_epoch, args.end_epoch):
        lr = optimizer.param_groups[0]['lr']
        logger.info('\nCurrent effective learning rate: {}\n'.format(lr))
        # print('Step num: {}\n'.format(optimizer.step_num))
        writer.add_scalar('model/learning_rate', lr, epoch)

        # One epoch's training
        train_loss, train_top1_accs = train(train_loader=train_loader,
                                            model=model,
                                            metric_fc=metric_fc,
                                            criterion=criterion,
                                            optimizer=optimizer,
                                            epoch=epoch)

        writer.add_scalar('model/train_loss', train_loss, epoch)
        writer.add_scalar('model/train_accuracy', train_top1_accs, epoch)

        scheduler.step(epoch)

        if args.eval_ds == "LFW":
            from lfw_eval import lfw_test

            # One epoch's validation
            accuracy, threshold = lfw_test(model)

        elif args.eval_ds == "Megaface":
            from megaface_eval import megaface_test

            accuracy = megaface_test(model)

        else:
            accuracy = -1

        writer.add_scalar('model/evaluation_accuracy', accuracy, epoch)

        # Check if there was an improvement
        is_best = accuracy > best_acc
        best_acc = max(accuracy, best_acc)
        if not is_best:
            epochs_since_improvement += 1
            logger.info("\nEpochs since last improvement: %d\n" %
                        (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(epoch, epochs_since_improvement, model, metric_fc,
                        optimizer, best_acc, is_best, scheduler)
Example #14
def main():
    # parser = argparse.ArgumentParser(description='Hyperparams')
    # parser.add_argument('--arch', nargs='?', type=str, default='resnet50')
    # parser.add_argument('--img_size', nargs='?', type=int, default=640,
    #                     help='Height of the input image')
    # parser.add_argument('--n_epoch', nargs='?', type=int, default=600,
    #                     help='# of the epochs')
    # parser.add_argument('--schedule', type=int, nargs='+', default=[200, 400],
    #                     help='Decrease learning rate at these epochs.')
    # parser.add_argument('--batch_size', nargs='?', type=int, default=1,
    #                     help='Batch Size')
    # parser.add_argument('--lr', nargs='?', type=float, default=1e-3,
    #                     help='Learning Rate')
    # parser.add_argument('--resume', nargs='?', type=str, default=None,
    #                     help='Path to previous saved model to restart from')
    # parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
    #                     help='path to save checkpoint (default: checkpoint)')
    # args = parser.parse_args()

    # lr = args.lr
    # schedule = args.schedule
    # batch_size = args.batch_size
    # n_epoch = args.n_epoch
    # image_size = args.img_size
    # resume = args.resume
    # checkpoint_path = args.checkpoint
    # arch = args.arch

    lr = 1e-3
    schedule = [200, 400]
    batch_size = 16
    # batch_size = 1
    n_epoch = 100
    image_size = 640
    checkpoint_path = ''
    # arch = 'resnet50'
    arch = 'mobilenetV2'
    resume = "checkpoints/ReCTS_%s_bs_%d_ep_%d" % (arch, batch_size, 5)
    # resume = None

    if checkpoint_path == '':
        checkpoint_path = "checkpoints/ReCTS_%s_bs_%d_ep_%d" % (
            arch, batch_size, n_epoch)

    print('checkpoint path: %s' % checkpoint_path)
    print('init lr: %.8f' % lr)
    print('schedule: ', schedule)
    sys.stdout.flush()

    if not os.path.isdir(checkpoint_path):
        os.makedirs(checkpoint_path)

    kernel_num = 7
    min_scale = 0.4
    start_epoch = 0

    data_loader = ReCTSDataLoader(
        need_transform=True,
        img_size=image_size,
        kernel_num=kernel_num,
        min_scale=min_scale,
        train_data_dir='../ocr_data/ReCTS/img/',
        train_gt_dir='../ocr_data/ReCTS/gt/'
        # train_data_dir='/kaggle/input/rects-ocr/img/',
        # train_gt_dir='/kaggle/input/rects-ocr/gt/'
    )

    ctw_root_dir = 'data/'

    train_loader = torch.utils.data.DataLoader(data_loader,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=3,
                                               drop_last=True,
                                               pin_memory=True)

    if arch == "resnet50":
        model = models.resnet50(pretrained=False, num_classes=kernel_num)
    elif arch == "resnet101":
        model = models.resnet101(pretrained=False, num_classes=kernel_num)
    elif arch == "resnet152":
        model = models.resnet152(pretrained=False, num_classes=kernel_num)
    elif arch == "mobilenetV2":
        model = PSENet(backbone="mobilenetv2",
                       pretrained=False,
                       result_num=kernel_num,
                       scale=1)

    if torch.cuda.is_available():
        model = torch.nn.DataParallel(model).cuda()
        device = 'cuda'
    else:
        model = torch.nn.DataParallel(model)
        device = 'cpu'

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=lr,
                                momentum=0.99,
                                weight_decay=5e-4)

    title = 'ReCTS'
    if resume:
        print('Resuming from checkpoint.')
        checkpoint_file_path = os.path.join(resume, "checkpoint.pth.tar")
        assert os.path.isfile(
            checkpoint_file_path
        ), 'Error: no checkpoint directory: %s found!' % checkpoint_file_path

        checkpoint = torch.load(checkpoint_file_path,
                                map_location=torch.device(device))
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        shutil.copy(os.path.join(resume, 'log.txt'),
                    os.path.join(checkpoint_path, 'log.txt'))
        logger = Logger(os.path.join(checkpoint_path, 'log.txt'),
                        title=title,
                        resume=True)
    else:
        print('Training from scratch.')
        logger = Logger(os.path.join(checkpoint_path, 'log.txt'), title=title)
        logger.set_names(
            ['Learning Rate', 'Train Loss', 'Train Acc.', 'Train IOU.'])

    for epoch in range(start_epoch, n_epoch):
        lr = adjust_learning_rate(schedule, lr, optimizer, epoch)
        print('\nEpoch: [%d | %d] LR: %f' %
              (epoch + 1, n_epoch, optimizer.param_groups[0]['lr']))

        stat(model, (3, image_size, image_size))

        train_loss, train_te_acc, train_ke_acc, train_te_iou, train_ke_iou = train(
            train_loader, model, dice_loss, optimizer, epoch, lr,
            checkpoint_path)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'lr': lr,
                'optimizer': optimizer.state_dict(),
            },
            checkpoint=checkpoint_path)

        logger.append([
            optimizer.param_groups[0]['lr'], train_loss, train_te_acc,
            train_te_iou
        ])
    logger.close()
Example #15
def train_net(args):
    torch.manual_seed(7)
    np.random.seed(7)
    checkpoint = args.checkpoint
    start_epoch = 0
    best_acc = float('-inf')
    writer = SummaryWriter()
    epochs_since_improvement = 0

    # Initialize / load checkpoint
    if checkpoint is None:
        if args.network == 'r18':
            model = resnet18(args)
        elif args.network == 'r34':
            model = resnet34(args)
        elif args.network == 'r50':
            model = resnet50(args)
        elif args.network == 'r101':
            model = resnet101(args)
        elif args.network == 'r152':
            model = resnet152(args)
        elif args.network == 'mobile':
            from mobilenet_v2 import MobileNetV2
            model = MobileNetV2()
        else:
            raise TypeError('network {} is not supported.'.format(
                args.network))

        metric_fc = ArcMarginModel(args)

        if args.optimizer == 'sgd':
            optimizer = torch.optim.SGD([{
                'params': model.parameters()
            }, {
                'params': metric_fc.parameters()
            }],
                                        lr=args.lr,
                                        momentum=args.mom,
                                        weight_decay=args.weight_decay)
        else:
            optimizer = torch.optim.Adam([{
                'params': model.parameters()
            }, {
                'params': metric_fc.parameters()
            }],
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)

    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        model = checkpoint['model']
        metric_fc = checkpoint['metric_fc']
        optimizer = checkpoint['optimizer']

    model = nn.DataParallel(model)
    metric_fc = nn.DataParallel(metric_fc)

    # Move to GPU, if available
    model = model.to(device)
    metric_fc = metric_fc.to(device)

    # Loss function
    if args.focal_loss:
        criterion = FocalLoss(gamma=args.gamma).to(device)
    else:
        criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    train_dataset = ArcFaceDataset('train')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=num_workers)

    # Epochs
    for epoch in range(start_epoch, args.end_epoch):
        # Decay learning rate if there is no improvement for 2 consecutive epochs, and terminate training after 10
        if epochs_since_improvement == 10:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 2 == 0:
            checkpoint = 'BEST_checkpoint.tar'
            checkpoint = torch.load(checkpoint)
            model = checkpoint['model']
            metric_fc = checkpoint['metric_fc']
            optimizer = checkpoint['optimizer']

            adjust_learning_rate(optimizer, 0.5)

        # One epoch's training
        train_loss, train_top1_accs = train(train_loader=train_loader,
                                            model=model,
                                            metric_fc=metric_fc,
                                            criterion=criterion,
                                            optimizer=optimizer,
                                            epoch=epoch)
        lr = optimizer.param_groups[0]['lr']
        print('\nCurrent effective learning rate: {}\n'.format(lr))
        # print('Step num: {}\n'.format(optimizer.step_num))

        writer.add_scalar('model/train_loss', train_loss, epoch)
        writer.add_scalar('model/train_accuracy', train_top1_accs, epoch)
        writer.add_scalar('model/learning_rate', lr, epoch)

        if epoch % 5 == 0:
            # One epoch's validation
            megaface_acc = megaface_test(model)
            writer.add_scalar('model/megaface_accuracy', megaface_acc, epoch)

            # Check if there was an improvement
            is_best = megaface_acc > best_acc
            best_acc = max(megaface_acc, best_acc)
            if not is_best:
                epochs_since_improvement += 1
                print("\nEpochs since last improvement: %d\n" %
                      (epochs_since_improvement, ))
            else:
                epochs_since_improvement = 0

            # Save checkpoint
            save_checkpoint(epoch, epochs_since_improvement, model, metric_fc,
                            optimizer, best_acc, is_best)
Example #16
def main(args):
    if args.checkpoint == '':
        args.checkpoint = "checkpoints/ctw1500_%s_bs_%d_ep_%d" % (
            args.arch, args.batch_size, args.n_epoch)
    if args.pretrain:
        if 'synth' in args.pretrain:
            args.checkpoint += "_pretrain_synth"
        else:
            args.checkpoint += "_pretrain_ic17"

    print(('checkpoint path: %s' % args.checkpoint))
    print(('init lr: %.8f' % args.lr))
    print(('schedule: ', args.schedule))
    sys.stdout.flush()

    if not os.path.isdir(args.checkpoint):
        os.makedirs(args.checkpoint)

    kernel_num = 7
    min_scale = 0.4
    start_epoch = 0

    data_loader = CTW1500Loader(is_transform=True,
                                img_size=args.img_size,
                                kernel_num=kernel_num,
                                min_scale=min_scale)
    train_loader = torch.utils.data.DataLoader(data_loader,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=3,
                                               drop_last=True,
                                               pin_memory=True)

    if args.arch == "resnet50":
        model = models.resnet50(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resnet101":
        model = models.resnet101(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resnet152":
        model = models.resnet152(pretrained=True, num_classes=kernel_num)

    model = torch.nn.DataParallel(model).cuda()

    if hasattr(model.module, 'optimizer'):
        optimizer = model.module.optimizer
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=0.99,
                                    weight_decay=5e-4)

    title = 'CTW1500'
    if args.pretrain:
        print('Using pretrained model.')
        assert os.path.isfile(
            args.pretrain), 'Error: no checkpoint directory found!'
        checkpoint = torch.load(args.pretrain)
        model.load_state_dict(checkpoint['state_dict'])
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(
            ['Learning Rate', 'Train Loss', 'Train Acc.', 'Train IOU.'])
    elif args.resume:
        print('Resuming from checkpoint.')
        assert os.path.isfile(
            args.resume), 'Error: no checkpoint directory found!'
        checkpoint = torch.load(args.resume)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'),
                        title=title,
                        resume=True)
    else:
        print('Training from scratch.')
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(
            ['Learning Rate', 'Train Loss', 'Train Acc.', 'Train IOU.'])

    for epoch in range(start_epoch, args.n_epoch):
        adjust_learning_rate(args, optimizer, epoch)
        print(('\nEpoch: [%d | %d] LR: %f' %
               (epoch + 1, args.n_epoch, optimizer.param_groups[0]['lr'])))

        train_loss, train_te_acc, train_ke_acc, train_te_iou, train_ke_iou = train(
            train_loader, model, dice_loss, optimizer, epoch)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'lr': args.lr,
                'optimizer': optimizer.state_dict(),
            },
            checkpoint=args.checkpoint)

        logger.append([
            optimizer.param_groups[0]['lr'], train_loss, train_te_acc,
            train_te_iou
        ])
    logger.close()
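`adjust_learning_rate(args, optimizer, epoch)` is defined elsewhere in that repository. A common step-decay version compatible with the `args.lr` and `args.schedule` values used above, given here only as an assumed sketch:

def adjust_learning_rate_sketch(args, optimizer, epoch):
    # Assumed behavior: multiply the base lr by 0.1 for every milestone in
    # args.schedule that has already been reached.
    lr = args.lr * (0.1 ** sum(1 for milestone in args.schedule if epoch >= milestone))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr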
Example #17
def test(args):
    data_loader = IC15TestLoader(long_size=args.long_size)
    test_loader = torch.utils.data.DataLoader(data_loader,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=2,
                                              drop_last=True)

    # Setup Model
    if args.arch == "resnet50":
        model = models.resnet50(pretrained=True,
                                num_classes=6,
                                scale=args.scale)
    elif args.arch == "resnet101":
        model = models.resnet101(pretrained=True,
                                 num_classes=7,
                                 scale=args.scale)
    elif args.arch == "resnet152":
        model = models.resnet152(pretrained=True,
                                 num_classes=7,
                                 scale=args.scale)

    for param in model.parameters():
        param.requires_grad = False

    model = model.cuda()

    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(args.resume)

            # model.load_state_dict(checkpoint['state_dict'])
            d = collections.OrderedDict()
            for key, value in checkpoint['state_dict'].items():
                tmp = key[7:]
                d[tmp] = value
            model.load_state_dict(d)

            print("Loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            sys.stdout.flush()
        else:
            print("No checkpoint found at '{}'".format(args.resume))
            sys.stdout.flush()

    model.eval()

    total_frame = 0.0
    total_time = 0.0
    bboxs = []
    bboxes = []
    for idx, (org_img, img) in enumerate(test_loader):
        print('progress: %d / %d' % (idx, len(test_loader)))
        sys.stdout.flush()

        img = Variable(img.cuda())
        org_img = org_img.numpy().astype('uint8')[0]
        text_box = org_img.copy()

        with torch.no_grad():
            outputs = model(img)

        torch.cuda.synchronize()
        start = time.time()

        similarity_vector = outputs[0, 2:, :, :]
        similarity_vector_ori = similarity_vector.permute((1, 2, 0))

        score = torch.sigmoid(outputs[:, 0, :, :])
        score = score.data.cpu().numpy()[0].astype(np.float32)

        outputs = (torch.sign(outputs - 1.0) + 1) / 2

        text = outputs[0, 0, :, :]
        kernel = outputs[0, 1, :, :] * text

        tag_cat, label_kernel, label_text = get_cat_tag(text, kernel)
        image_name = data_loader.img_paths[idx].split('/')[-1].split('.')[0]
        #         cv2.imwrite('./test_result/image/text_'+image_name+'.jpg',label_text*255)
        #         cv2.imwrite('./test_result/image/kernel_'+image_name+'.jpg',label_kernel*255)

        label_text = torch.Tensor(label_text).cuda()
        label_kernel = torch.Tensor(label_kernel).cuda()

        w, h, _ = similarity_vector_ori.shape
        similarity_vector = similarity_vector.permute(
            (1, 2, 0)).data.cpu().numpy()
        bboxs = []
        bboxes = []
        scale = (org_img.shape[1] * 1.0 / text.shape[1],
                 org_img.shape[0] * 1.0 / text.shape[0])

        for item in tag_cat:

            similarity_vector_ori1 = similarity_vector_ori.clone()
            #             mask = torch.zeros((w,h)).cuda()
            index_k = (label_kernel == item[0])
            index_t = (label_text == item[1])
            similarity_vector_k = torch.sum(
                similarity_vector_ori1[index_k],
                0) / similarity_vector_ori1[index_k].shape[0]
            #             similarity_vector_t = similarity_vector_ori1[index_t]

            #             similarity_vector_t = similarity_vector_ori1[index_t]
            similarity_vector_ori1[~index_t] = similarity_vector_k
            similarity_vector_ori1 = similarity_vector_ori1.reshape(-1, 4)
            out = torch.norm((similarity_vector_ori1 - similarity_vector_k), 2,
                             1)

            #             out = torch.norm((similarity_vector_t-similarity_vector_k),2,1)
            #             print(out.shape)
            #             mask[index_t] = out
            out = out.reshape(w, h)

            out = out * ((text > 0).float())
            #             out = mask*((text>0).float())

            out[out > 0.8] = 0
            out[out > 0] = 1
            out_im = (text * out).data.cpu().numpy()
            #             cv2.imwrite('./test_result/image/out_'+image_name+'.jpg',out_im*255)
            points = np.array(np.where(out_im == out_im.max())).transpose(
                (1, 0))[:, ::-1]

            if points.shape[0] < 800:
                continue

            score_i = np.mean(score[out_im == out_im.max()])
            if score_i < 0.93:
                continue

            rect = cv2.minAreaRect(points)
            bbox = cv2.boxPoints(rect) * scale
            bbox = bbox.astype('int32')
            bboxs.append(bbox)
            bboxes.append(bbox.reshape(-1))


#         text_box = scale(text_box, long_size=2240)
        torch.cuda.synchronize()
        end = time.time()
        total_frame += 1
        total_time += (end - start)
        print('fps: %.2f' % (total_frame / total_time))

        for bbox in bboxs:
            text_box = cv2.line(text_box, (bbox[0, 0], bbox[0, 1]),
                                (bbox[1, 0], bbox[1, 1]), (0, 0, 255), 2)
            text_box = cv2.line(text_box, (bbox[1, 0], bbox[1, 1]),
                                (bbox[2, 0], bbox[2, 1]), (0, 0, 255), 2)
            text_box = cv2.line(text_box, (bbox[2, 0], bbox[2, 1]),
                                (bbox[3, 0], bbox[3, 1]), (0, 0, 255), 2)
            text_box = cv2.line(text_box, (bbox[3, 0], bbox[3, 1]),
                                (bbox[0, 0], bbox[0, 1]), (0, 0, 255), 2)
        write_result_as_txt(image_name, bboxes, 'test_result/submit_ic15/')
        cv2.imwrite('./test_result/image/' + image_name + '.jpg', text_box)
Example #18
def main():
    args = parser.parse_args()

    data_dir = args.data_dir
    val_file = args.list_files
    ext_batch_sz = int(args.ext_batch_sz)
    int_batch_sz = int(args.int_batch_sz)
    start_instance = int(args.start_instance)
    end_instance = int(args.end_instance)
    checkpoint = args.checkpoint_path

    model_start_time = time.time()
    if args.architecture == "inception_v3":
        new_size = 299
        num_categories = 3528, 3468, 2048
        spatial_net = models.inception_v3(pretrained=(checkpoint == ""),
                                          num_outputs=len(num_categories))
    else:  #resnet
        new_size = 224
        num_categories = 8192, 4096, 2048
        spatial_net = models.resnet152(pretrained=(checkpoint == ""),
                                       num_outputs=len(num_categories))

    if os.path.isfile(checkpoint):
        print('loading checkpoint {} ...'.format(checkpoint))
        params = torch.load(checkpoint)
        model_dict = spatial_net.state_dict()

        # 1. filter out unnecessary keys
        pretrained_dict = {
            k: v
            for k, v in params['state_dict'].items() if k in model_dict
        }

        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict)
        # 3. load the new state dict
        spatial_net.load_state_dict(model_dict)
        print('loaded')
    else:
        print(checkpoint)
        print('ERROR: No checkpoint found')

    spatial_net.cuda()
    spatial_net.eval()
    model_end_time = time.time()
    model_time = model_end_time - model_start_time
    print("Action recognition model is loaded in %4.4f seconds." %
          (model_time))

    f_val = open(val_file, "r")
    val_list = f_val.readlines()[start_instance:end_instance]
    print("we got %d test videos" % len(val_list))

    line_id = 1
    match_count = 0

    for line_id, line in enumerate(val_list):
        print("sample %d/%d" % (line_id + 1, len(val_list)))
        line_info = line.split(" ")
        clip_path = os.path.join(data_dir, line_info[0])
        num_frames = int(line_info[1])
        input_video_label = int(line_info[2])

        spatial_prediction = VideoSpatialPrediction(clip_path, spatial_net,
                                                    num_categories, num_frames,
                                                    ext_batch_sz, int_batch_sz,
                                                    new_size)

        for ii in range(len(spatial_prediction)):
            for vr_ind, vr in enumerate(spatial_prediction[ii]):
                folder_name = args.architecture + "_" + args.dataset + "_VR" + str(
                    ii)
                if not os.path.isdir(folder_name + '/' + line_info[0]):
                    print("creating folder: " + folder_name + "/" +
                          line_info[0])
                    os.makedirs(folder_name + "/" + line_info[0])
                vr_name = folder_name + '/' + line_info[
                    0] + '/vr_{0:02d}.png'.format(vr_ind)
                vr_gray = normalize_maxmin(vr.transpose()).transpose() * 255.
                cv2.imwrite(vr_name, vr_gray)
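normalize_maxmin is not defined in this snippet. Assuming it simply rescales an array to [0, 1] (a guess based on how vr_gray is built), a minimal sketch would be:

import numpy as np

def normalize_maxmin(arr):
    # hypothetical: per-array min-max scaling to [0, 1], guarding against constant input
    arr = arr.astype(np.float64)
    span = arr.max() - arr.min()
    return (arr - arr.min()) / span if span > 0 else np.zeros_like(arr)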
예제 #19
0
def test(args):
    data_loader = IC15TestLoader(long_size=args.long_size)
    test_loader = torch.utils.data.DataLoader(data_loader,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=2,
                                              drop_last=True)

    # Setup Model
    if args.arch == "resnet50":
        model = models.resnet50(pretrained=True,
                                num_classes=18,
                                scale=args.scale)
    elif args.arch == "resnet101":
        model = models.resnet101(pretrained=True,
                                 num_classes=18,
                                 scale=args.scale)
    elif args.arch == "resnet152":
        model = models.resnet152(pretrained=True,
                                 num_classes=18,
                                 scale=args.scale)
    elif args.arch == "vgg16":
        model = models.vgg16(pretrained=True, num_classes=18)

    for param in model.parameters():
        param.requires_grad = False

    model = model.cuda()

    if args.resume is not None:
        if os.path.isfile(args.resume):
            print(("Loading model and optimizer from checkpoint '{}'".format(
                args.resume)))
            checkpoint = torch.load(args.resume)

            # model.load_state_dict(checkpoint['state_dict'])
            # strip the 'module.' prefix that nn.DataParallel adds to state_dict keys
            d = collections.OrderedDict()
            for key, value in list(checkpoint['state_dict'].items()):
                tmp = key[7:]
                d[tmp] = value
            model.load_state_dict(d)

            print(("Loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
            sys.stdout.flush()
        else:
            print(("No checkpoint found at '{}'".format(args.resume)))
            sys.stdout.flush()

    model.eval()

    total_frame = 0.0
    total_time = 0.0
    for idx, (org_img, img) in enumerate(test_loader):
        print(('progress: %d / %d' % (idx, len(test_loader))))
        sys.stdout.flush()

        img = img.cuda()
        org_img = org_img.numpy().astype('uint8')[0]
        text_box = org_img.copy()

        torch.cuda.synchronize()
        start = time.time()

        cls_logits, link_logits = model(img)

        outputs = torch.cat((cls_logits, link_logits), dim=1)
        shape = outputs.shape
        pixel_pos_scores = F.softmax(outputs[:, 0:2, :, :], dim=1)[:, 1, :, :]
        # pixel_pos_scores=torch.sigmoid(outputs[:,1,:,:])
        # FIXME the dimension should be changed
        link_scores = outputs[:, 2:, :, :].view(shape[0], 2, 8, shape[2],
                                                shape[3])
        link_pos_scores = F.softmax(link_scores, dim=1)[:, 1, :, :, :]

        mask, bboxes = to_bboxes(org_img,
                                 pixel_pos_scores.cpu().numpy(),
                                 link_pos_scores.cpu().numpy())

        score = pixel_pos_scores[0, :, :]
        score = score.data.cpu().numpy().astype(np.float32)

        torch.cuda.synchronize()
        end = time.time()
        total_frame += 1
        total_time += (end - start)
        print(('fps: %.2f' % (total_frame / total_time)))
        sys.stdout.flush()

        for bbox in bboxes:
            cv2.drawContours(text_box, [bbox.reshape(4, 2)], -1, (0, 255, 0),
                             2)

        image_name = data_loader.img_paths[idx].split('/')[-1].split('.')[0]
        write_result_as_txt(image_name, bboxes, 'outputs/submit_ic15/')

        text_box = cv2.resize(text_box, (org_img.shape[1], org_img.shape[0]))
        score_s = cv2.resize(
            np.repeat(score[:, :, np.newaxis] * 255, 3, 2).astype(np.uint8),
            (org_img.shape[1], org_img.shape[0]))
        mask = cv2.resize(
            np.repeat(mask[:, :, np.newaxis], 3, 2).astype(np.uint8),
            (org_img.shape[1], org_img.shape[0]))

        link_score = (link_pos_scores[0, 0, :, :]).cpu().numpy() * (
            score > 0.5).astype(np.float32)  # np.float was removed from recent NumPy
        link_score = cv2.resize(
            np.repeat(link_score[:, :, np.newaxis] * 255, 3,
                      2).astype(np.uint8),
            (org_img.shape[1], org_img.shape[0]))
        debug(idx, data_loader.img_paths,
              [[text_box, score_s], [link_score, mask]], 'outputs/vis_ic15/')

    cmd = 'cd %s;zip -j %s %s/*' % ('./outputs/', 'submit_ic15.zip',
                                    'submit_ic15')
    print(cmd)
    sys.stdout.flush()
    util.cmd.cmd(cmd)
    cmd_eval = 'cd eval;sh eval_ic15.sh'
    sys.stdout.flush()
    util.cmd.cmd(cmd_eval)
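The FIXME above concerns how the 16 link logits are grouped before the softmax. A small self-contained check of the grouping this example assumes (2 channels per direction, 8 directions; toy shapes only) is:

import torch
import torch.nn.functional as F

outputs = torch.randn(1, 18, 64, 64)                       # 2 pixel logits + 16 link logits
link_scores = outputs[:, 2:, :, :].view(1, 2, 8, 64, 64)   # grouped as (neg/pos, direction)
link_pos_scores = F.softmax(link_scores, dim=1)[:, 1, :, :, :]
assert link_pos_scores.shape == (1, 8, 64, 64)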
예제 #20
0
def train_net(args):
    torch.manual_seed(7)  # torch random seed; affects calls such as torch.randn
    np.random.seed(7)
    checkpoint = args.checkpoint
    start_epoch = 0
    best_acc = 0
    writer = SummaryWriter()  #tensorboard
    epochs_since_improvement = 0

    # Initialize / load checkpoint
    if checkpoint is None:
        if args.network == 'r18':
            model = resnet18(args)
        elif args.network == 'r34':
            model = resnet34(args)
        elif args.network == 'r50':
            model = resnet50(args)
        elif args.network == 'r101':
            model = resnet101(args)
        elif args.network == 'r152':
            model = resnet152(args)
        elif args.network == 'mobile':
            model = MobileNet(1.0)
        else:
            model = resnet_face18(args.use_se)

        model = nn.DataParallel(model)
        metric_fc = ArcMarginModel(args)
        metric_fc = nn.DataParallel(metric_fc)

        if args.optimizer == 'sgd':
            # optimizer = torch.optim.SGD([{'params': model.parameters()}, {'params': metric_fc.parameters()}],
            #                             lr=args.lr, momentum=args.mom, weight_decay=args.weight_decay)
            optimizer = InsightFaceOptimizer(
                torch.optim.SGD([{
                    'params': model.parameters()
                }, {
                    'params': metric_fc.parameters()
                }],
                                lr=args.lr,
                                momentum=args.mom,
                                weight_decay=args.weight_decay))
        else:
            optimizer = torch.optim.Adam([{
                'params': model.parameters()
            }, {
                'params': metric_fc.parameters()
            }],
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)

    else:
        checkpoint = torch.load(checkpoint)
        # the checkpoint contents still need to be loaded back in here
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        model = checkpoint['model']
        metric_fc = checkpoint['metric_fc']
        optimizer = checkpoint['optimizer']

    logger = get_logger()

    # Move to GPU, if available
    model = model.to(device)
    metric_fc = metric_fc.to(device)

    # Loss function
    if args.focal_loss:
        criterion = FocalLoss(gamma=args.gamma).to(device)
    else:
        criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    train_dataset = Dataset(root=args.train_path,
                            phase='train',
                            input_shape=(3, 112, 112))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=8)

    # Epochs
    for epoch in range(start_epoch, args.end_epoch):
        # One epoch's training
        # wrapping one epoch in a train() function keeps this loop very concise; a pattern worth learning from
        train_loss, train_top1_accs = train(train_loader=train_loader,
                                            model=model,
                                            metric_fc=metric_fc,
                                            criterion=criterion,
                                            optimizer=optimizer,
                                            epoch=epoch,
                                            logger=logger)
        print('\nCurrent effective learning rate: {}\n'.format(optimizer.lr))
        print('Step num: {}\n'.format(optimizer.step_num))

        writer.add_scalar('model/train_loss', train_loss, epoch)
        writer.add_scalar('model/train_accuracy', train_top1_accs, epoch)
        writer.add_scalar('model/learning_rate', optimizer.lr, epoch)

        # Save checkpoint
        if epoch % 10 == 0:
            save_checkpoint(epoch, epochs_since_improvement, model, metric_fc,
                            optimizer, best_acc)
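ArcMarginModel itself is not included in this example. For orientation, a minimal ArcFace-style margin head (a sketch of the standard formulation; the repository's ArcMarginModel may take different arguments and defaults) looks roughly like:

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class ArcMarginSketch(nn.Module):
    # hypothetical stand-in for ArcMarginModel: scale * cos(theta + m) on the target class
    def __init__(self, emb_size=512, num_classes=10000, s=64.0, m=0.5):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(num_classes, emb_size))
        nn.init.xavier_uniform_(self.weight)
        self.s, self.m = s, m

    def forward(self, embeddings, labels):
        cosine = F.linear(F.normalize(embeddings), F.normalize(self.weight))
        sine = torch.sqrt((1.0 - cosine.pow(2)).clamp(0, 1))
        phi = cosine * math.cos(self.m) - sine * math.sin(self.m)  # cos(theta + m)
        one_hot = F.one_hot(labels, cosine.size(1)).float()
        return self.s * (one_hot * phi + (1.0 - one_hot) * cosine)

The logits returned by such a head would then be fed to the CrossEntropyLoss (or FocalLoss) set up above.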
예제 #21
0
def test(args):
    data_loader = IC15TestLoader(long_size=args.long_size)
    test_loader = torch.utils.data.DataLoader(data_loader,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=2,
                                              drop_last=True)

    submit_path = 'outputs'
    if os.path.exists(submit_path):
        shutil.rmtree(submit_path)

    os.mkdir(submit_path)

    # Setup Model
    if args.arch == "resnet50":
        model = models.resnet50(pretrained=True,
                                num_classes=7,
                                scale=args.scale)
    elif args.arch == "resnet101":
        model = models.resnet101(pretrained=True,
                                 num_classes=7,
                                 scale=args.scale)
    elif args.arch == "resnet152":
        model = models.resnet152(pretrained=True,
                                 num_classes=7,
                                 scale=args.scale)

    for param in model.parameters():
        param.requires_grad = False

    model = model.cuda()

    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(args.resume)

            # model.load_state_dict(checkpoint['state_dict'])
            d = collections.OrderedDict()
            for key, value in checkpoint['state_dict'].items():
                tmp = key[7:]
                d[tmp] = value
            model.load_state_dict(d)

            print("Loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            sys.stdout.flush()
        else:
            print("No checkpoint found at '{}'".format(args.resume))
            sys.stdout.flush()

    model.eval()

    total_frame = 0.0
    total_time = 0.0
    for idx, (org_img, img) in enumerate(test_loader):
        print('progress: %d / %d' % (idx, len(test_loader)))
        sys.stdout.flush()

        img = img.cuda()  # Variable/volatile are deprecated; gradients are already disabled above
        org_img = org_img.numpy().astype('uint8')[0]
        text_box = org_img.copy()

        torch.cuda.synchronize()
        start = time.time()
        # with torch.no_grad():
        #     outputs = model(img)
        #     outputs = torch.sigmoid(outputs)
        #     score = outputs[:, 0, :, :]
        #     outputs = outputs > args.threshold # torch.uint8
        #     text = outputs[:, 0, :, :]
        #     kernels = outputs[:, 0:args.kernel_num, :, :] * text
        # score = score.squeeze(0).cpu().numpy()
        # text = text.squeeze(0).cpu().numpy()
        # kernels = kernels.squeeze(0).cpu().numpy()
        # print(img.shape)
        outputs = model(img)

        score = torch.sigmoid(outputs[:, 0, :, :])
        outputs = (torch.sign(outputs - args.binary_th) + 1) / 2

        text = outputs[:, 0, :, :]
        kernels = outputs[:, 0:args.kernel_num, :, :] * text
        # print(score.shape)
        # score = score.data.cpu().numpy()[0].astype(np.float32)
        # text = text.data.cpu().numpy()[0].astype(np.uint8)
        # kernels = kernels.data.cpu().numpy()[0].astype(np.uint8)
        score = score.squeeze(0).cpu().numpy().astype(np.float32)
        text = text.squeeze(0).cpu().numpy().astype(np.uint8)
        kernels = kernels.squeeze(0).cpu().numpy().astype(np.uint8)
        tmp_marker = kernels[-1, :, :]
        # for i in range(args.kernel_num-2, -1, -1):
        # sure_fg = tmp_marker
        # sure_bg = kernels[i, :, :]
        # watershed_source = cv2.cvtColor(sure_bg, cv2.COLOR_GRAY2BGR)
        # unknown = cv2.subtract(sure_bg,sure_fg)
        # ret, marker = cv2.connectedComponents(sure_fg)
        # label_num = np.max(marker)
        # marker += 1
        # marker[unknown==1] = 0
        # marker = cv2.watershed(watershed_source, marker)
        # marker[marker==-1] = 1
        # marker -= 1
        # tmp_marker = np.asarray(marker, np.uint8)
        sure_fg = kernels[-1, :, :]
        sure_bg = text
        watershed_source = cv2.cvtColor(sure_bg, cv2.COLOR_GRAY2BGR)
        unknown = cv2.subtract(sure_bg, sure_fg)
        ret, marker = cv2.connectedComponents(sure_fg)
        label_num = np.max(marker)
        marker += 1
        marker[unknown == 1] = 0
        marker = cv2.watershed(watershed_source, marker)
        marker -= 1
        label = marker

        # label = tmp_marker
        # scale = (w / marker.shape[1], h / marker.shape[0])
        scale = (org_img.shape[1] * 1.0 / marker.shape[1],
                 org_img.shape[0] * 1.0 / marker.shape[0])
        bboxes = []
        # print(label_num)
        for i in range(1, label_num + 1):
            # get [x,y] pair, points.shape=[n, 2]
            points = np.array(np.where(label == i)).transpose((1, 0))[:, ::-1]
            # similar to pixellink's min_area when post-processing
            if points.shape[0] < args.min_area / (args.scale * args.scale):
                continue
            #this filter op is very important, f-score=68.0(without) vs 69.1(with)
            score_i = np.mean(score[label == i])
            if score_i < args.min_score:
                continue
            rect = cv2.minAreaRect(points)
            bbox = cv2.boxPoints(rect) * scale
            # bbox=.tolist()
            # bbox.append(score_i)
            bbox = bbox.astype('int32')
            bboxes.append(bbox.reshape(-1))

        torch.cuda.synchronize()
        end = time.time()
        total_frame += 1
        total_time += (end - start)
        print('fps: %.2f' % (total_frame / total_time))
        sys.stdout.flush()

        for bbox in bboxes:
            cv2.drawContours(text_box, [bbox.reshape(4, 2)], -1, (0, 255, 0),
                             2)

        image_name = data_loader.img_paths[idx].split('/')[-1].split('.')[0]
        write_result_as_txt(image_name, bboxes, 'outputs/submit_ic15/')

        text_box = cv2.resize(text_box, (text.shape[1], text.shape[0]))
        # if idx % 200 == 0:
        #     debug(idx, data_loader.img_paths, [[text_box]], 'outputs/vis_ic15/')

    cmd = 'cd %s;zip -j %s %s/*' % ('./outputs/', 'submit_ic15.zip',
                                    'submit_ic15')
    print(cmd)
    sys.stdout.flush()
    util.cmd.Cmd(cmd)
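The (torch.sign(outputs - args.binary_th) + 1) / 2 expression used above is just a hard threshold; apart from values exactly equal to the threshold, it matches a plain comparison:

import torch

outputs = torch.randn(2, 7, 8, 8)
binary_th = 1.0
a = (torch.sign(outputs - binary_th) + 1) / 2
b = (outputs > binary_th).float()
print(torch.equal(a, b))  # True whenever no value equals binary_th exactly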
예제 #22
0
def test(args):
    data_loader = CTW1500TestLoader(long_size=args.long_size)
    test_loader = torch.utils.data.DataLoader(data_loader,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=2,
                                              drop_last=True)
    submit_path = 'outputs_shape'
    if os.path.exists(submit_path):
        shutil.rmtree(submit_path)

    os.mkdir(submit_path)
    # Setup Model
    if args.arch == "resnet50":
        model = models.resnet50(pretrained=True,
                                num_classes=7,
                                scale=args.scale)
    elif args.arch == "resnet101":
        model = models.resnet101(pretrained=True,
                                 num_classes=7,
                                 scale=args.scale)
    elif args.arch == "resnet152":
        model = models.resnet152(pretrained=True,
                                 num_classes=7,
                                 scale=args.scale)

    for param in model.parameters():
        param.requires_grad = False

    model = model.cuda()

    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(args.resume)

            # model.load_state_dict(checkpoint['state_dict'])
            d = collections.OrderedDict()
            for key, value in checkpoint['state_dict'].items():
                tmp = key[7:]
                d[tmp] = value
            model.load_state_dict(d)

            print("Loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            sys.stdout.flush()
        else:
            print("No checkpoint found at '{}'".format(args.resume))
            sys.stdout.flush()

    model.eval()

    total_frame = 0.0
    total_time = 0.0
    for idx, (org_img, img) in enumerate(test_loader):
        print('progress: %d / %d' % (idx, len(test_loader)))
        sys.stdout.flush()

        img = img.cuda()  # Variable/volatile are deprecated; gradients are already disabled above
        org_img = org_img.numpy().astype('uint8')[0]
        text_box = org_img.copy()

        torch.cuda.synchronize()
        start = time.time()

        outputs = model(img)

        score = torch.sigmoid(outputs[:, 0, :, :])
        outputs = (torch.sign(outputs - args.binary_th) + 1) / 2

        text = outputs[:, 0, :, :]
        kernels = outputs[:, 0:args.kernel_num, :, :] * text

        score = score.data.cpu().numpy()[0].astype(np.float32)
        text = text.data.cpu().numpy()[0].astype(np.uint8)
        kernels = kernels.data.cpu().numpy()[0].astype(np.uint8)

        # # c++ version pse
        # # pred = pse(kernels, args.min_kernel_area / (args.scale * args.scale))
        # # python version pse
        # pred = pypse(kernels, args.min_kernel_area / (args.scale * args.scale))

        # # scale = (org_img.shape[0] * 1.0 / pred.shape[0], org_img.shape[1] * 1.0 / pred.shape[1])
        # scale = (org_img.shape[1] * 1.0 / pred.shape[1], org_img.shape[0] * 1.0 / pred.shape[0])
        # label = pred
        # label_num = np.max(label) + 1
        # bboxes = []

        tmp_marker = kernels[-1, :, :]
        # for i in range(args.kernel_num-2, -1, -1):
        # sure_fg = tmp_marker
        # sure_bg = kernels[i, :, :]
        # watershed_source = cv2.cvtColor(sure_bg, cv2.COLOR_GRAY2BGR)
        # unknown = cv2.subtract(sure_bg,sure_fg)
        # ret, marker = cv2.connectedComponents(sure_fg)
        # label_num = np.max(marker)
        # marker += 1
        # marker[unknown==1] = 0
        # marker = cv2.watershed(watershed_source, marker)
        # marker[marker==-1] = 1
        # marker -= 1
        # tmp_marker = np.asarray(marker, np.uint8)
        sure_fg = kernels[-1, :, :]
        sure_bg = text
        watershed_source = cv2.cvtColor(sure_bg, cv2.COLOR_GRAY2BGR)
        unknown = cv2.subtract(sure_bg, sure_fg)
        ret, marker = cv2.connectedComponents(sure_fg)
        label_num = np.max(marker)
        marker += 1
        marker[unknown == 1] = 0
        marker = cv2.watershed(watershed_source, marker)
        marker -= 1
        label = marker

        # label = tmp_marker
        # scale = (w / marker.shape[1], h / marker.shape[0])
        scale = (org_img.shape[1] * 1.0 / marker.shape[1],
                 org_img.shape[0] * 1.0 / marker.shape[0])
        bboxes = []
        for i in range(1, label_num):
            points = np.array(np.where(label == i)).transpose((1, 0))[:, ::-1]

            if points.shape[0] < args.min_area / (args.scale * args.scale):
                continue

            score_i = np.mean(score[label == i])
            if score_i < args.min_score:
                continue

            # rect = cv2.minAreaRect(points)
            binary = np.zeros(label.shape, dtype='uint8')
            binary[label == i] = 1
            # a=cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
            # print(a)
            contours, _ = cv2.findContours(binary, cv2.RETR_TREE,
                                           cv2.CHAIN_APPROX_SIMPLE)
            contour = contours[0]
            # epsilon = 0.01 * cv2.arcLength(contour, True)
            # bbox = cv2.approxPolyDP(contour, epsilon, True)
            bbox = contour

            if bbox.shape[0] <= 2:
                continue

            bbox = bbox * scale
            bbox = bbox.astype('int32')
            bboxes.append(bbox.reshape(-1))

        torch.cuda.synchronize()
        end = time.time()
        total_frame += 1
        total_time += (end - start)
        print('fps: %.2f' % (total_frame / total_time))
        sys.stdout.flush()

        for bbox in bboxes:
            cv2.drawContours(text_box, [bbox.reshape(bbox.shape[0] // 2, 2)],
                             -1, (0, 255, 0), 2)

        image_name = data_loader.img_paths[idx].split('/')[-1].split('.')[0]
        write_result_as_txt(image_name, bboxes,
                            'outputs_shape/submit_ctw1500/')

        text_box = cv2.resize(text_box, (text.shape[1], text.shape[0]))
        debug(idx, data_loader.img_paths, [[text_box]],
              'outputs_shape/vis_ctw1500/')
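The two-value unpacking of cv2.findContours above assumes OpenCV 4.x (or 2.x); OpenCV 3.x returns three values. A version-agnostic pattern, shown on a toy mask, is:

import cv2
import numpy as np

binary = np.zeros((8, 8), dtype=np.uint8)
binary[2:6, 2:6] = 1
res = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
contours = res[0] if len(res) == 2 else res[1]  # OpenCV 4.x/2.x vs 3.x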
예제 #23
0
def test(args, file=None):
    result = []
    data_loader = DataLoader(long_size=args.long_size, file=file)
    test_loader = torch.utils.data.DataLoader(data_loader,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=2,
                                              drop_last=True)

    slice = 0
    # Setup Model
    if args.arch == "resnet50":
        model = models.resnet50(pretrained=True,
                                num_classes=7,
                                scale=args.scale)
    elif args.arch == "resnet101":
        model = models.resnet101(pretrained=True,
                                 num_classes=7,
                                 scale=args.scale)
    elif args.arch == "resnet152":
        model = models.resnet152(pretrained=True,
                                 num_classes=7,
                                 scale=args.scale)
    elif args.arch == "mobilenet":
        model = models.Mobilenet(pretrained=True,
                                 num_classes=6,
                                 scale=args.scale)
        slice = -1

    for param in model.parameters():
        param.requires_grad = False

    # model = model.cuda()

    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(args.resume)

            # model.load_state_dict(checkpoint['state_dict'])
            d = collections.OrderedDict()
            for key, value in checkpoint['state_dict'].items():
                tmp = key[7:]
                d[tmp] = value

            try:
                model.load_state_dict(d)
            except:
                model.load_state_dict(checkpoint['state_dict'])

            print("Loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            sys.stdout.flush()
        else:
            print("No checkpoint found at '{}'".format(args.resume))
            sys.stdout.flush()

    model.eval()

    total_frame = 0.0
    total_time = 0.0
    for idx, (org_img, img) in enumerate(test_loader):
        print('progress: %d / %d' % (idx, len(test_loader)))
        sys.stdout.flush()

        # img = Variable(img.cuda(), volatile=True)
        org_img = org_img.numpy().astype('uint8')[0]
        text_box = org_img.copy()

        # torch.cuda.synchronize()
        start = time.time()

        # angle detection
        # org_img, angle = detect_angle(org_img)
        outputs = model(img)

        score = torch.sigmoid(outputs[:, slice, :, :])
        outputs = (torch.sign(outputs - args.binary_th) + 1) / 2

        text = outputs[:, slice, :, :]
        kernels = outputs
        # kernels = outputs[:, 0:args.kernel_num, :, :] * text

        score = score.data.cpu().numpy()[0].astype(np.float32)
        text = text.data.cpu().numpy()[0].astype(np.uint8)
        kernels = kernels.data.cpu().numpy()[0].astype(np.uint8)

        if args.arch == 'mobilenet':
            pred = pse2(kernels,
                        args.min_kernel_area / (args.scale * args.scale))
        else:
            # c++ version pse
            pred = pse(kernels,
                       args.min_kernel_area / (args.scale * args.scale))
            # python version pse
            # pred = pypse(kernels, args.min_kernel_area / (args.scale * args.scale))

        # scale = (org_img.shape[0] * 1.0 / pred.shape[0], org_img.shape[1] * 1.0 / pred.shape[1])
        scale = (org_img.shape[1] * 1.0 / pred.shape[1],
                 org_img.shape[0] * 1.0 / pred.shape[0])
        label = pred
        label_num = np.max(label) + 1
        bboxes = []
        rects = []
        for i in range(1, label_num):
            points = np.array(np.where(label == i)).transpose((1, 0))[:, ::-1]

            if points.shape[0] < args.min_area / (args.scale * args.scale):
                continue

            score_i = np.mean(score[label == i])
            if score_i < args.min_score:
                continue

            rect = cv2.minAreaRect(points)
            bbox = cv2.boxPoints(rect) * scale
            bbox = bbox.astype('int32')
            bbox = order_point(bbox)
            # bbox = np.array([bbox[1], bbox[2], bbox[3], bbox[0]])
            bboxes.append(bbox.reshape(-1))

            rec = []
            rec.append(rect[-1])
            rec.append(rect[1][1] * scale[1])
            rec.append(rect[1][0] * scale[0])
            rec.append(rect[0][0] * scale[0])
            rec.append(rect[0][1] * scale[1])
            rects.append(rec)

        # torch.cuda.synchronize()
        end = time.time()
        total_frame += 1
        total_time += (end - start)
        print('fps: %.2f' % (total_frame / total_time))
        sys.stdout.flush()

        for bbox in bboxes:
            cv2.drawContours(text_box, [bbox.reshape(4, 2)], -1, (0, 255, 0),
                             2)

        image_name = data_loader.img_paths[idx].split('/')[-1].split('.')[0]
        write_result_as_txt(image_name, bboxes, 'outputs/submit_invoice/')

        text_box = cv2.resize(text_box, (text.shape[1], text.shape[0]))
        debug(idx, data_loader.img_paths, [[text_box]], 'data/images/tmp/')

        result = crnnRec(cv2.cvtColor(org_img, cv2.COLOR_BGR2RGB), rects)
        result = formatResult(result)

    # cmd = 'cd %s;zip -j %s %s/*' % ('./outputs/', 'submit_invoice.zip', 'submit_invoice')
    # print(cmd)
    # sys.stdout.flush()
    # util.cmd.Cmd(cmd)
    return result
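order_point is another helper that is not shown. A common convention is to order the four boxPoints as top-left, top-right, bottom-right, bottom-left before passing them on; a hypothetical sketch (the real order_point may differ) is:

import numpy as np

def order_point(pts):
    # hypothetical: order 4 points as top-left, top-right, bottom-right, bottom-left
    pts = np.asarray(pts, dtype=np.float32).reshape(4, 2)
    s = pts.sum(axis=1)
    d = np.diff(pts, axis=1).reshape(-1)
    tl, br = pts[np.argmin(s)], pts[np.argmax(s)]
    tr, bl = pts[np.argmin(d)], pts[np.argmax(d)]
    return np.array([tl, tr, br, bl], dtype=np.int32)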
예제 #24
0
def train_net(args):
    torch.manual_seed(7)
    np.random.seed(7)
    checkpoint = args.checkpoint
    start_epoch = 0
    best_acc = 0
    writer = SummaryWriter()
    epochs_since_improvement = 0

    # Initialize / load checkpoint
    if checkpoint is None:
        if args.network == 'r18':
            model = resnet18(args)
        elif args.network == 'r34':
            model = resnet34(args)
        elif args.network == 'r50':
            model = resnet50(args)
        elif args.network == 'r101':
            model = resnet101(args)
        elif args.network == 'r152':
            model = resnet152(args)
        elif args.network == 'mobile':
            model = MobileNet(1.0)
        else:
            model = resnet_face18(args.use_se)
        model = nn.DataParallel(model)
        metric_fc = ArcMarginModel(args)
        metric_fc = nn.DataParallel(metric_fc)

        if args.optimizer == 'sgd':
            # optimizer = torch.optim.SGD([{'params': model.parameters()}, {'params': metric_fc.parameters()}],
            #                             lr=args.lr, momentum=args.mom, weight_decay=args.weight_decay)
            optimizer = InsightFaceOptimizer(
                torch.optim.SGD([{
                    'params': model.parameters()
                }, {
                    'params': metric_fc.parameters()
                }],
                                lr=args.lr,
                                momentum=args.mom,
                                weight_decay=args.weight_decay))
        else:
            optimizer = torch.optim.Adam([{
                'params': model.parameters()
            }, {
                'params': metric_fc.parameters()
            }],
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)

    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        model = checkpoint['model']
        metric_fc = checkpoint['metric_fc']
        optimizer = checkpoint['optimizer']

    logger = get_logger()

    # Move to GPU, if available
    model = model.to(device)
    metric_fc = metric_fc.to(device)

    # Loss function
    if args.focal_loss:
        criterion = FocalLoss(gamma=args.gamma).to(device)
    else:
        criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    train_dataset = ArcFaceDataset('train')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=8)

    # Epochs
    for epoch in range(start_epoch, args.end_epoch):
        # One epoch's training
        train_loss, train_top1_accs = train(train_loader=train_loader,
                                            model=model,
                                            metric_fc=metric_fc,
                                            criterion=criterion,
                                            optimizer=optimizer,
                                            epoch=epoch,
                                            logger=logger)
        print('\nCurrent effective learning rate: {}\n'.format(optimizer.lr))
        print('Step num: {}\n'.format(optimizer.step_num))

        writer.add_scalar('model/train_loss', train_loss, epoch)
        writer.add_scalar('model/train_accuracy', train_top1_accs, epoch)
        writer.add_scalar('model/learning_rate', optimizer.lr, epoch)

        # One epoch's validation
        megaface_acc = megaface_test(model)
        writer.add_scalar('model/megaface_accuracy', megaface_acc, epoch)

        # Check if there was an improvement
        is_best = megaface_acc > best_acc
        best_acc = max(megaface_acc, best_acc)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(epoch, epochs_since_improvement, model, metric_fc,
                        optimizer, best_acc, is_best)
예제 #25
0
def main(args):
    if args.checkpoint == '':
        args.checkpoint = "checkpoints/ctw1500_%s_bs_%d_ep_%d" % (
            args.arch, args.batch_size, args.n_epoch)
    if args.pretrain:
        if 'synth' in args.pretrain:
            args.checkpoint += "_pretrain_synth"
        else:
            args.checkpoint += "_pretrain_ic17"

    print('checkpoint path: %s' % args.checkpoint)
    print('init lr: %.8f' % args.lr)
    print('schedule: ', args.schedule)
    sys.stdout.flush()

    if not os.path.isdir(args.checkpoint):
        os.makedirs(args.checkpoint)

    kernel_num = 7
    min_scale = 0.4
    start_epoch = 0

    data_loader = CTW1500Loader(is_transform=True,
                                img_size=args.img_size,
                                kernel_num=kernel_num,
                                min_scale=min_scale)
    #train_loader = ctw_train_loader(data_loader, batch_size=args.batch_size)

    if args.arch == "resnet50":
        model = models.resnet50(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resnet101":
        model = models.resnet101(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resnet152":
        model = models.resnet152(pretrained=True, num_classes=kernel_num)

    # resnet18 and 34 don't implement pretrained weights
    elif args.arch == "resnet18":
        model = models.resnet18(pretrained=False, num_classes=kernel_num)
    elif args.arch == "resnet34":
        model = models.resnet34(pretrained=False, num_classes=kernel_num)

    elif args.arch == "mobilenetv2":
        model = models.resnet152(pretrained=True, num_classes=kernel_num)
    elif args.arch == "mobilenetv3large":
        model = models.mobilenetv3_large(pretrained=False,
                                         num_classes=kernel_num)

    elif args.arch == "mobilenetv3small":
        model = models.mobilenetv3_small(pretrained=False,
                                         num_classes=kernel_num)

    optimizer = tf.keras.optimizers.SGD(learning_rate=args.lr,
                                        momentum=0.99,
                                        decay=5e-4)

    title = 'CTW1500'
    if args.pretrain:
        print('Using pretrained model.')
        assert os.path.isfile(
            args.pretrain), 'Error: no checkpoint directory found!'

        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(
            ['Learning Rate', 'Train Loss', 'Train Acc.', 'Train IOU.'])
    elif args.resume:
        print('Resuming from checkpoint.')

        model.load_weights(args.resume)

        logger = Logger(os.path.join(args.checkpoint, 'log.txt'),
                        title=title,
                        resume=True)
    else:
        print('Training from scratch.')
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(
            ['Learning Rate', 'Train Loss', 'Train Acc.', 'Train IOU.'])

    for epoch in range(start_epoch, args.n_epoch):
        optimizer = get_new_optimizer(args, optimizer, epoch)
        print(
            '\nEpoch: [%d | %d] LR: %f' %
            (epoch + 1, args.n_epoch, optimizer.get_config()['learning_rate']))

        train_loader = ctw_train_loader(data_loader,
                                        batch_size=args.batch_size)

        train_loss, train_te_acc, train_ke_acc, train_te_iou, train_ke_iou = train(train_loader, model, dice_loss,\
                                                                                   optimizer, epoch)

        model.save_weights('%s%s' % (args.checkpoint, '/model_tf/weights'))

        logger.append([
            optimizer.get_config()['learning_rate'], train_loss, train_te_acc,
            train_te_iou
        ])
    logger.close()
예제 #26
0
def main():
    # Init logger
    if not os.path.isdir(args.save_path): os.makedirs(args.save_path)
    log = open(
        os.path.join(
            args.save_path,
            'cluster_seed_{}_{}.txt'.format(args.manualSeed, time_for_file())),
        'w')
    print_log('save path : {}'.format(args.save_path), log)
    print_log('------------ Options -------------', log)
    for k, v in sorted(vars(args).items()):
        print_log('Parameter : {:20} = {:}'.format(k, v), log)
    print_log('-------------- End ----------------', log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("python version : {}".format(sys.version.replace('\n', ' ')),
              log)
    print_log("Pillow version : {}".format(PIL.__version__), log)
    print_log("torch  version : {}".format(torch.__version__), log)
    print_log("cudnn  version : {}".format(torch.backends.cudnn.version()),
              log)

    # General Data Augmentation
    mean_fill = tuple([int(x * 255) for x in [0.485, 0.456, 0.406]])
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transform = transforms.Compose([
        transforms.PreCrop(args.pre_crop_expand),
        transforms.TrainScale2WH((args.crop_width, args.crop_height)),
        transforms.ToTensor(), normalize
    ])

    args.downsample = 8  # By default
    args.sigma = args.sigma * args.scale_eval

    data = datasets.GeneralDataset(transform, args.sigma, args.downsample,
                                   args.heatmap_type, args.dataset_name)
    data.load_list(args.train_list, args.num_pts, True)
    loader = torch.utils.data.DataLoader(data,
                                         batch_size=args.batch_size,
                                         shuffle=False,
                                         num_workers=args.workers,
                                         pin_memory=True)

    # Load all lists
    all_lines = {}
    for file_path in args.train_list:
        listfile = open(file_path, 'r')
        listdata = listfile.read().splitlines()
        listfile.close()
        for line in listdata:
            temp = line.split(' ')
            assert len(temp) == 6 or len(
                temp) == 7, 'This line has the wrong format : {}'.format(line)
            image_path = temp[0]
            all_lines[image_path] = line

    assert args.n_clusters >= 2, 'The cluster number must be at least 2'
    resnet = models.resnet152(True).cuda()
    all_features = []
    for i, (inputs, target, mask, points, image_index, label_sign,
            ori_size) in enumerate(loader):
        with torch.no_grad():  # Variable(..., volatile=True) is deprecated
            features, classifications = resnet(inputs.cuda())
        features = features.cpu().data.numpy()
        all_features.append(features)
        if i % args.print_freq == 0:
            print_log(
                '{} {}/{} extract features'.format(time_string(), i,
                                                   len(loader)), log)
    all_features = np.concatenate(all_features, axis=0)
    kmeans_result = KMeans(n_clusters=args.n_clusters,
                           n_jobs=args.workers).fit(all_features)
    print_log('kmeans [{}] calculate done'.format(args.n_clusters), log)
    labels = kmeans_result.labels_.copy()

    cluster_idx = []
    for iL in range(args.n_clusters):
        indexes = np.where(labels == iL)[0]
        cluster_idx.append(len(indexes))
    cluster_idx = np.argsort(cluster_idx)

    for iL in range(args.n_clusters):
        ilabel = cluster_idx[iL]
        indexes = np.where(labels == ilabel)
        if isinstance(indexes, tuple) or isinstance(indexes, list):
            indexes = indexes[0]
        cluster_features = all_features[indexes, :].copy()
        filtered_index = filter_cluster(indexes.copy(), cluster_features, 0.8)

        print_log(
            '{:} [{:2d} / {:2d}] has {:4d} / {:4d} -> {:4d} = {:.2f} images '.
            format(time_string(), iL, args.n_clusters, indexes.size, len(data),
                   len(filtered_index), indexes.size * 1. / len(data)), log)
        indexes = filtered_index.copy()
        save_dir = osp.join(
            args.save_path,
            'cluster-{:02d}-{:02d}'.format(iL, args.n_clusters))
        save_path = save_dir + '.lst'
        #if not osp.isdir(save_path): os.makedirs( save_path )
        print_log('save into {}'.format(save_path), log)
        txtfile = open(save_path, 'w')
        for idx in indexes:
            image_path = data.datas[idx]
            assert image_path in all_lines, 'Not find {}'.format(image_path)
            txtfile.write('{}\n'.format(all_lines[image_path]))
            #basename = osp.basename( image_path )
            #os.system( 'cp {} {}'.format(image_path, save_dir) )
        txtfile.close()
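filter_cluster is not shown either; judging only from its call signature it keeps a subset of each cluster. One plausible sketch (purely an assumption about its behaviour) keeps the given ratio of samples closest to the cluster centroid:

import numpy as np

def filter_cluster(indexes, features, ratio):
    # hypothetical: keep the `ratio` fraction of samples nearest the cluster centroid
    center = features.mean(axis=0, keepdims=True)
    dists = np.linalg.norm(features - center, axis=1)
    keep = max(1, int(round(len(indexes) * ratio)))
    return indexes[np.argsort(dists)[:keep]]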
예제 #27
0
def test(args):
    data_loader = IC19TestLoader(long_size=args.long_size,
                                 indic=True,
                                 part_num=5)
    test_loader = torch.utils.data.DataLoader(data_loader,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=2,
                                              drop_last=True)

    # Setup Model
    if args.arch == "resnet50":
        model = models.resnet50(pretrained=True,
                                num_classes=7,
                                scale=args.scale)
    elif args.arch == "resnet101":
        model = models.resnet101(pretrained=True,
                                 num_classes=7,
                                 scale=args.scale)
    elif args.arch == "resnet152":
        model = models.resnet152(pretrained=True,
                                 num_classes=7,
                                 scale=args.scale)

    for param in model.parameters():
        param.requires_grad = False

    model = model.cuda()

    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(args.resume)

            # model.load_state_dict(checkpoint['state_dict'])
            d = collections.OrderedDict()
            for key, value in checkpoint['state_dict'].items():
                tmp = key[7:]
                d[tmp] = value
            model.load_state_dict(d)

            print("Loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            sys.stdout.flush()
        else:
            print("No checkpoint found at '{}'".format(args.resume))
            sys.stdout.flush()

    model.eval()

    total_frame = 0.0
    total_time = 0.0
    for idx, (org_img, img) in enumerate(test_loader):
        print('progress: %d / %d' % (idx, len(test_loader)))
        sys.stdout.flush()

        img = img.cuda()  # the Variable wrapper is a no-op in modern PyTorch
        org_img = org_img.numpy().astype('uint8')[0]
        text_box = org_img.copy()

        torch.cuda.synchronize()
        start = time.time()

        outputs = model(img)

        score = torch.sigmoid(outputs[:, 0, :, :])
        outputs = (torch.sign(outputs - args.binary_th) + 1) / 2

        text = outputs[:, 0, :, :]
        kernels = outputs[:, 0:args.kernel_num, :, :] * text

        score = score.data.cpu().numpy()[0].astype(np.float32)
        text = text.data.cpu().numpy()[0].astype(np.uint8)
        kernels = kernels.data.cpu().numpy()[0].astype(np.uint8)

        # c++ version pse
        #pred = pse(kernels, args.min_kernel_area / (args.scale * args.scale))
        # python version pse
        pred = pypse(kernels, args.min_kernel_area / (args.scale * args.scale))

        # scale = (org_img.shape[0] * 1.0 / pred.shape[0], org_img.shape[1] * 1.0 / pred.shape[1])
        scale = (org_img.shape[1] * 1.0 / pred.shape[1],
                 org_img.shape[0] * 1.0 / pred.shape[0])
        label = pred
        label_num = np.max(label) + 1
        bboxes = []
        for i in range(1, label_num):
            points = np.array(np.where(label == i)).transpose((1, 0))[:, ::-1]

            if points.shape[0] < args.min_area / (args.scale * args.scale):
                continue

            score_i = np.mean(score[label == i])
            if score_i < args.min_score:
                continue

            rect = cv2.minAreaRect(points)
            bbox = cv2.boxPoints(rect) * scale
            bbox = bbox.astype('int32')
            bboxes.append(bbox.reshape(-1))

        torch.cuda.synchronize()
        end = time.time()
        total_frame += 1
        total_time += (end - start)
        print('fps: %.2f' % (total_frame / total_time))
        sys.stdout.flush()

        for bbox in bboxes:
            cv2.drawContours(text_box, [bbox.reshape(4, 2)], -1, (0, 255, 0),
                             2)

        image_name = data_loader.img_paths[idx].split('/')[-1].split('.')[0]
        write_result_as_txt(image_name, bboxes, 'outputs/BoundingBoxCords/')

        text_box = cv2.resize(text_box, (text.shape[1], text.shape[0]))
        debug(idx, data_loader.img_paths, [[text_box]], 'outputs/Detections/')
예제 #28
0
import os
import PIL
import torch
import models
import argparse
import numpy as np

from torchvision import transforms

parser = argparse.ArgumentParser()
parser.add_argument('--gpu_num', default='cuda:0', help='GPU device')
parser.add_argument('--session', default='session1')
opt = parser.parse_args()
device = torch.device(opt.gpu_num)

model = models.resnet152(pretrained=True, progress=False, opt=opt)
for param in model.parameters():
    param.requires_grad = False
model.to(device)

feature_extractor = models.FeatureExtractor(model)
for param in feature_extractor.parameters():
    param.requires_grad = False
feature_extractor.to(device)

transform_list = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]
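transform_list is built above but never composed in this fragment. A typical way to finish the pipeline and pull features for one image (the path 'image.jpg' is a placeholder) would be:

from PIL import Image

transform = transforms.Compose(transform_list)
img = Image.open('image.jpg').convert('RGB')   # placeholder input path
batch = transform(img).unsqueeze(0).to(device)
with torch.no_grad():
    features = feature_extractor(batch)
print(features.shape)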
예제 #29
0
def main(args):

    #initial setup
    if args.checkpoint == '':
        args.checkpoint = "checkpoints1/ic19val_%s_bs_%d_ep_%d" % (
            args.arch, args.batch_size, args.n_epoch)
    if args.pretrain:
        if 'synth' in args.pretrain:
            args.checkpoint += "_pretrain_synth"
        else:
            args.checkpoint += "_pretrain_ic17"

    print(('checkpoint path: %s' % args.checkpoint))
    print(('init lr: %.8f' % args.lr))
    print('schedule: ', args.schedule)
    sys.stdout.flush()

    if not os.path.isdir(args.checkpoint):
        os.makedirs(args.checkpoint)

    kernel_num = 7
    min_scale = 0.4
    start_epoch = 0
    validation_split = 0.1
    random_seed = 42
    prev_val_loss = -1
    val_loss_list = []
    loggertf = tfLogger('./log/' + args.arch)
    #end

    #setup data loaders
    data_loader = IC19Loader(is_transform=True,
                             img_size=args.img_size,
                             kernel_num=kernel_num,
                             min_scale=min_scale)
    dataset_size = len(data_loader)
    indices = list(range(dataset_size))
    split = int(np.floor(validation_split * dataset_size))
    np.random.seed(random_seed)
    np.random.shuffle(indices)
    train_indices, val_indices = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_indices)
    validate_sampler = SubsetRandomSampler(val_indices)

    train_loader = torch.utils.data.DataLoader(data_loader,
                                               batch_size=args.batch_size,
                                               num_workers=3,
                                               drop_last=True,
                                               pin_memory=True,
                                               sampler=train_sampler)

    validate_loader = torch.utils.data.DataLoader(data_loader,
                                                  batch_size=args.batch_size,
                                                  num_workers=3,
                                                  drop_last=True,
                                                  pin_memory=True,
                                                  sampler=validate_sampler)
    #end

    #Setup architecture and optimizer
    if args.arch == "resnet50":
        model = models.resnet50(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resnet101":
        model = models.resnet101(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resnet152":
        model = models.resnet152(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resPAnet50":
        model = models.resPAnet50(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resPAnet101":
        model = models.resPAnet101(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resPAnet152":
        model = models.resPAnet152(pretrained=True, num_classes=kernel_num)

    model = torch.nn.DataParallel(model).cuda()

    if hasattr(model.module, 'optimizer'):
        optimizer = model.module.optimizer
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=0.99,
                                    weight_decay=5e-4)
    #end

    #options to resume/use pretrained model/train from scratch
    title = 'icdar2019MLT'
    if args.pretrain:
        print('Using pretrained model.')
        assert os.path.isfile(
            args.pretrain), 'Error: no checkpoint directory found!'
        checkpoint = torch.load(args.pretrain)
        model.load_state_dict(checkpoint['state_dict'])
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names([
            'Learning Rate', 'Train Loss', 'Train Acc.', 'Train IOU.',
            'Validate Loss', 'Validate Acc', 'Validate IOU'
        ])
    elif args.resume:
        print('Resuming from checkpoint.')
        assert os.path.isfile(
            args.resume), 'Error: no checkpoint directory found!'
        checkpoint = torch.load(args.resume)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'),
                        title=title,
                        resume=True)
    else:
        print('Training from scratch.')
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names([
            'Learning Rate', 'Train Loss', 'Train Acc.', 'Train IOU.',
            'Validate Loss', 'Validate Acc', 'Validate IOU'
        ])
    #end

    #start training model
    for epoch in range(start_epoch, args.n_epoch):
        adjust_learning_rate(args, optimizer, epoch)
        print(('\nEpoch: [%d | %d] LR: %f' %
               (epoch + 1, args.n_epoch, optimizer.param_groups[0]['lr'])))

        train_loss, train_te_acc, train_ke_acc, train_te_iou, train_ke_iou = train(
            train_loader, model, dice_loss, optimizer, epoch, loggertf)
        val_loss, val_te_acc, val_ke_acc, val_te_iou, val_ke_iou = validate(
            validate_loader, model, dice_loss)

        #logging on tensorboard
        loggertf.scalar_summary('Training/Accuracy', train_te_acc, epoch + 1)
        loggertf.scalar_summary('Training/Loss', train_loss, epoch + 1)
        loggertf.scalar_summary('Training/IoU', train_te_iou, epoch + 1)
        loggertf.scalar_summary('Validation/Accuracy', val_te_acc, epoch + 1)
        loggertf.scalar_summary('Validation/Loss', val_loss, epoch + 1)
        loggertf.scalar_summary('Validation/IoU', val_te_iou, epoch + 1)
        #end

        # Boring bookkeeping
        print("End of Epoch %d" % (epoch + 1))
        print((
            "Train Loss: {loss:.4f} | Train Acc: {acc: .4f} | Train IOU: {iou_t: .4f}"
            .format(loss=train_loss, acc=train_te_acc, iou_t=train_te_iou)))
        print((
            "Validation Loss: {loss:.4f} | Validation Acc: {acc: .4f} | Validation IOU: {iou_t: .4f}"
            .format(loss=val_loss, acc=val_te_acc, iou_t=val_te_iou)))
        #end

        #Saving improving and Best Models
        val_loss_list.append(val_loss)
        if (val_loss < prev_val_loss or prev_val_loss == -1):
            checkpointname = "{loss:.3f}".format(
                loss=val_loss) + "_epoch" + str(epoch +
                                                1) + "_checkpoint.pth.tar"
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'lr': args.lr,
                    'optimizer': optimizer.state_dict(),
                },
                checkpoint=args.checkpoint,
                filename=checkpointname)
        # val_loss is already in val_loss_list, so '<' could never hold; '<=' detects a new best
        if (val_loss <= min(val_loss_list)):
            checkpointname = "best_checkpoint.pth.tar"
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'lr': args.lr,
                    'optimizer': optimizer.state_dict(),
                },
                checkpoint=args.checkpoint,
                filename=checkpointname)
        #end
        prev_val_loss = val_loss
        logger.append([
            optimizer.param_groups[0]['lr'], train_loss, train_te_acc,
            train_te_iou, val_loss, val_te_acc, val_te_iou
        ])
    #end training model
    logger.close()
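adjust_learning_rate is not included in this example. A common step-decay version driven by args.schedule (a sketch of the usual convention, not necessarily the repository's exact rule) is:

def adjust_learning_rate(args, optimizer, epoch):
    # hypothetical: decay the base LR by 10x at every milestone in args.schedule
    lr = args.lr * (0.1 ** sum(1 for s in args.schedule if epoch >= s))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr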
예제 #30
0
import models
import torch
num_classes = 18
inputs = torch.rand([1, 3, 224, 224])
test = models.resnet34(num_classes=num_classes, pretrained='imagenet')
assert test(inputs).size()[1] == num_classes
print('ok')
test = models.resnet50(num_classes=num_classes, pretrained='imagenet')
assert test(inputs).size()[1] == num_classes
print('ok')
test = models.resnet101(num_classes=num_classes, pretrained='imagenet')
assert test(inputs).size()[1] == num_classes
print('ok')
test = models.resnet152(num_classes=num_classes, pretrained='imagenet')
assert test(inputs).size()[1] == num_classes
print('ok')
test = models.alexnet(num_classes=num_classes, pretrained='imagenet')
assert test(inputs).size()[1] == num_classes
print('ok')
test = models.densenet121(num_classes=num_classes, pretrained='imagenet')
assert test(inputs).size()[1] == num_classes
print('ok')
test = models.densenet169(num_classes=num_classes, pretrained='imagenet')
assert test(inputs).size()[1] == num_classes
print('ok')
test = models.densenet201(num_classes=num_classes, pretrained='imagenet')
assert test(inputs).size()[1] == num_classes
print('ok')