Example #1
def test_model(args):
    # create model
    model = dla.__dict__[args.arch](pretrained=args.pretrained,
                                    pool_size=args.crop_size // 32)
    model = torch.nn.DataParallel(model)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {} prec {:.03f}) "
                  .format(args.resume, checkpoint['epoch'], best_prec1))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    data = dataset.get_data(args.data_name)
    if data is None:
        data = dataset.load_dataset_info(args.data, data_name=args.data_name)
    if data is None:
        raise ValueError(
            '{} is not pre-defined in dataset.py and info.json does not '
            'exist in {}'.format(args.data_name, args.data))
    # Data loading code
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=data.mean, std=data.std)

    if args.crop_10:
        t = transforms.Compose([
            transforms.Resize(args.scale_size),
            transforms.ToTensor(),
            normalize])
    else:
        t = transforms.Compose([
            transforms.Resize(args.scale_size),
            transforms.CenterCrop(args.crop_size),
            transforms.ToTensor(),
            normalize])
    val_loader = torch.utils.data.DataLoader(
        ImageFolder(valdir, t, out_name=args.crop_10),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()

    if args.cuda:
        model = model.cuda()
        criterion = criterion.cuda()

    if args.crop_10:
        validate_10(args, val_loader, model,
                    '{}_i_{}_c_10.txt'.format(args.arch, args.start_epoch))
    else:
        validate(args, val_loader, model, criterion)
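
Everything test_model needs comes from a single args namespace. Below is a minimal sketch of a matching command-line front end; the flag names simply mirror the attributes read above, and the defaults are illustrative assumptions, not the listing's actual CLI:

import argparse

def parse_test_args():
    # Flags mirror the attributes test_model reads; defaults are examples only.
    parser = argparse.ArgumentParser(description='evaluate a DLA model (sketch)')
    parser.add_argument('--arch', default='dla34')           # key into dla.__dict__
    parser.add_argument('--data', default='data/imagenet')   # root containing val/
    parser.add_argument('--data-name', default='imagenet')
    parser.add_argument('--resume', default='', type=str)
    parser.add_argument('--pretrained', action='store_true')
    parser.add_argument('--crop-size', default=224, type=int)
    parser.add_argument('--scale-size', default=256, type=int)
    parser.add_argument('--crop-10', action='store_true')
    parser.add_argument('--batch-size', default=256, type=int)
    parser.add_argument('--workers', default=4, type=int)
    parser.add_argument('--start-epoch', default=0, type=int)
    parser.add_argument('--cuda', action='store_true')
    return parser.parse_args()

if __name__ == '__main__':
    test_model(parse_test_args())
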
Example #2
def test_seg(args):
    batch_size = args.batch_size
    num_workers = args.workers
    phase = args.phase

    for k, v in args.__dict__.items():
        print(k, ':', v)

    single_model = dla_up.__dict__.get(args.arch)(args.classes,
                                                  down_ratio=args.down)

    model = torch.nn.DataParallel(single_model).cuda()

    data_dir = args.data_dir
    info = dataset.load_dataset_info(data_dir)
    normalize = transforms.Normalize(mean=info.mean, std=info.std)
    # scales = [0.5, 0.75, 1.25, 1.5, 1.75]
    scales = [0.5, 0.75, 1.25, 1.5]
    t = []
    if args.crop_size > 0:
        t.append(transforms.PadToSize(args.crop_size))
    t.extend([transforms.ToTensor(), normalize])
    if args.ms:
        data = SegListMS(data_dir, phase, transforms.Compose(t), scales)
    else:
        data = SegList(data_dir,
                       phase,
                       transforms.Compose(t),
                       out_name=True,
                       out_size=True,
                       binary=args.classes == 2)
    test_loader = torch.utils.data.DataLoader(data,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              num_workers=num_workers,
                                              pin_memory=False)

    cudnn.benchmark = True

    # optionally resume from a checkpoint
    start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    out_dir = '{}_{:03d}_{}'.format(args.arch, start_epoch, phase)
    if len(args.test_suffix) > 0:
        out_dir += '_' + args.test_suffix

    if args.ms:
        out_dir += '_ms'
        mAP = test_ms(test_loader,
                      model,
                      args.classes,
                      save_vis=True,
                      has_gt=phase != 'test' or args.with_gt,
                      output_dir=out_dir,
                      scales=scales)
    else:
        mAP = test(test_loader,
                   model,
                   args.classes,
                   save_vis=True,
                   has_gt=phase != 'test' or args.with_gt,
                   output_dir=out_dir)
    print('mAP: ', mAP)
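
The resume branch above expects a checkpoint dict with 'epoch', 'best_prec1', and 'state_dict' keys, saved and loaded through the same DataParallel wrapper so the 'module.' key prefix matches. A minimal round-trip sketch, with a toy model standing in for the DLA network:

import torch
import torch.nn as nn

# Toy stand-in; in the listing this is the DataParallel-wrapped DLA model.
model = nn.DataParallel(nn.Linear(8, 2))

torch.save({
    'epoch': 10,
    'best_prec1': 0.0,
    'state_dict': model.state_dict(),  # keys carry the 'module.' prefix
}, 'checkpoint_latest.pth.tar')

checkpoint = torch.load('checkpoint_latest.pth.tar')
model.load_state_dict(checkpoint['state_dict'])  # same wrapper, keys match
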
Example #3
def train_seg(args):
    batch_size = args.batch_size
    num_workers = args.workers
    crop_size = args.crop_size

    print(' '.join(sys.argv))

    for k, v in args.__dict__.items():
        print(k, ':', v)

    pretrained_base = args.pretrained_base
    single_model = dla_up.__dict__.get(args.arch)(args.classes,
                                                  pretrained_base,
                                                  down_ratio=args.down)
    model = torch.nn.DataParallel(single_model).cuda()
    if args.edge_weight > 0:
        weight = torch.from_numpy(
            np.array([1, args.edge_weight], dtype=np.float32))
        # nn.NLLLoss handles 4-d inputs directly; nn.NLLLoss2d is deprecated
        criterion = nn.NLLLoss(ignore_index=255, weight=weight)
    else:
        criterion = nn.NLLLoss(ignore_index=255)

    criterion.cuda()

    data_dir = args.data_dir
    info = dataset.load_dataset_info(data_dir)
    normalize = transforms.Normalize(mean=info.mean, std=info.std)
    t = []
    if args.random_rotate > 0:
        t.append(transforms.RandomRotate(args.random_rotate))
    if args.random_scale > 0:
        t.append(transforms.RandomScale(args.random_scale))
    t.append(transforms.RandomCrop(crop_size))
    if args.random_color:
        t.append(transforms.RandomJitter(0.4, 0.4, 0.4))
    t.extend(
        [transforms.RandomHorizontalFlip(),
         transforms.ToTensor(), normalize])
    train_loader = torch.utils.data.DataLoader(SegList(
        data_dir, 'train', transforms.Compose(t), binary=(args.classes == 2)),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=num_workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        SegList(
            data_dir,
            'val',
            transforms.Compose([
                transforms.RandomCrop(crop_size),
                # transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]),
            binary=(args.classes == 2)),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True)
    optimizer = torch.optim.SGD(single_model.optim_parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    cudnn.benchmark = True
    best_prec1 = 0
    start_epoch = 0

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.evaluate:
        validate(val_loader, model, criterion, eval_score=accuracy)
        return

    for epoch in range(start_epoch, args.epochs):
        lr = adjust_learning_rate(args, optimizer, epoch)
        print('Epoch: [{0}]\tlr {1:.06f}'.format(epoch, lr))
        # train for one epoch
        train(train_loader,
              model,
              criterion,
              optimizer,
              epoch,
              eval_score=accuracy)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, eval_score=accuracy)

        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        checkpoint_path = 'checkpoint_latest.pth.tar'
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            },
            is_best,
            filename=checkpoint_path)
        if (epoch + 1) % args.save_freq == 0:
            history_path = 'checkpoint_{:03d}.pth.tar'.format(epoch + 1)
            shutil.copyfile(checkpoint_path, history_path)
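
save_checkpoint itself is not part of the listing. A minimal sketch of the conventional helper (as in the PyTorch ImageNet example) that matches the call sites above; the 'model_best.pth.tar' name is an assumption:

import shutil
import torch

def save_checkpoint(state, is_best, filename='checkpoint_latest.pth.tar'):
    # Always write the latest state; keep a separate copy of the best epoch.
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')
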
Example #4
def train_net(net, paras):
    # parameters
    img_dir = paras.image_dir
    anno_path = paras.anno_path
    checkpoint_dir = paras.model_save_dir
    val_percent = 0.1
    epochs = paras.epochs
    batch_size = paras.batch_size
    lr = paras.learning_rate
    num_workers = 2

    # torch model saver
    saver = ModelSaver(max_save_num=5)

    # load dataset info
    dataset = load_dataset_info(img_dir, anno_path)
    train_set_info, valid_set_info = split_dataset_info(dataset, val_percent)

    # build dataloader
    building_trainset = Building_Dataset(train_set_info)
    building_validset = Building_Dataset(valid_set_info)
    train_dataloader = torch.utils.data.DataLoader(building_trainset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   num_workers=num_workers)
    valid_dataloader = torch.utils.data.DataLoader(building_validset,
                                                   batch_size=batch_size,
                                                   shuffle=False,
                                                   num_workers=num_workers)

    # optimizer
    optimizer = optim.SGD(net.parameters(),
                          lr=lr,
                          momentum=0.9,
                          weight_decay=0.0005)

    # loss function
    #criterion = nn.L1Loss(reduce=True, size_average=True)
    criterion = nn.BCELoss()

    train_num = len(building_trainset)
    valid_num = len(building_validset)
    print('''
    Starting training:
        Total Epochs: {}
        Batch size: {}
        Learning rate: {}
        Training size: {}
        Validation size: {}
        Checkpoints save dir: {}
    '''.format(epochs, batch_size, lr, train_num, valid_num, checkpoint_dir))

    # ------------------------
    # start training...
    # ------------------------
    best_valid_loss = float('inf')  # lowest validation loss seen so far
    for epoch in range(1, epochs + 1):
        print('Starting epoch {}/{}.'.format(epoch, epochs))

        # training
        net.train()
        epoch_loss = 0
        for idx, data in enumerate(train_dataloader):
            imgs, true_masks = data

            imgs = imgs.cuda()
            true_masks = true_masks.cuda()

            pred_masks = net(imgs)

            # compute loss
            loss = criterion(pred_masks, true_masks)
            epoch_loss += loss.item()

            if idx % 10 == 0:
                print(f'{idx}/{len(train_dataloader)}, loss: {loss.item()}')

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        epoch_loss = epoch_loss / len(train_dataloader)
        print('Epoch finished! Loss: {}\n'.format(epoch_loss))

        # validation
        net.eval()
        valid_loss = 0
        with torch.no_grad():
            for idx, data in enumerate(valid_dataloader):
                if idx % 10 == 0:
                    print(idx, '/', len(valid_dataloader))
                imgs, true_masks = data

                imgs = imgs.cuda()
                true_masks = true_masks.cuda()

                # inference
                pred_masks = net(imgs)
                # compute loss
                loss = criterion(pred_masks, true_masks)
                valid_loss += loss.item()
        valid_loss = valid_loss / len(valid_dataloader)
        print('Validation finished! Loss: {}  Best loss before: {}\n'.format(
            valid_loss, best_valid_loss))

        # save check_point
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            print('New best model found, saving checkpoint {}...'.format(epoch))
            model_save_path = os.path.join(
                checkpoint_dir, '{}_CP{}.pth'.format(best_valid_loss, epoch))
            #torch.save(net.state_dict(), model_save_path)
            saver.save_new_model(net, model_save_path)
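
ModelSaver is also external to this listing. A minimal sketch of what a saver bounded by max_save_num could look like, assuming it simply evicts the oldest checkpoint on disk; the real class may differ:

import os
from collections import deque

import torch

class ModelSaver:
    """Keep at most max_save_num checkpoint files on disk (sketch)."""

    def __init__(self, max_save_num=5):
        self.max_save_num = max_save_num
        self.saved_paths = deque()

    def save_new_model(self, net, path):
        if len(self.saved_paths) >= self.max_save_num:
            oldest = self.saved_paths.popleft()
            if os.path.exists(oldest):
                os.remove(oldest)  # evict the oldest checkpoint
        torch.save(net.state_dict(), path)
        self.saved_paths.append(path)
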
Example #5
def run_training(args):
    model = dla.__dict__[args.arch](
        pretrained=args.pretrained, num_classes=args.classes,
        pool_size=args.crop_size // 32)
    model = torch.nn.DataParallel(model)

    best_prec1 = 0

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    data = dataset.get_data(args.data_name)
    if data is None:
        data = dataset.load_dataset_info(args.data, data_name=args.data_name)
    if data is None:
        raise ValueError(
            '{} is not pre-defined in dataset.py and info.json does not '
            'exist in {}'.format(args.data_name, args.data))

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = data_transforms.Normalize(mean=data.mean, std=data.std)
    tt = [data_transforms.RandomResizedCrop(
        args.crop_size, min_area_ratio=args.min_area_ratio,
        aspect_ratio=args.aspect_ratio)]
    if data.eigval is not None and data.eigvec is not None \
            and args.random_color:
        lighting = data_transforms.Lighting(0.1, data.eigval, data.eigvec)
        jitter = data_transforms.RandomJitter(0.4, 0.4, 0.4)
        tt.extend([jitter, lighting])
    tt.extend([data_transforms.RandomHorizontalFlip(),
               data_transforms.ToTensor(),
               normalize])

    train_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(traindir, data_transforms.Compose(tt)),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(args.scale_size),
            transforms.CenterCrop(args.crop_size),
            transforms.ToTensor(),
            normalize
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.cuda:
        model = model.cuda()
        criterion = criterion.cuda()

    if args.evaluate:
        validate(args, val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(args, optimizer, epoch)

        # train for one epoch
        train(args, train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(args, val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        checkpoint_path = 'checkpoint_latest.pth.tar'
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
        }, is_best, filename=checkpoint_path)
        if (epoch + 1) % args.check_freq == 0:
            history_path = 'checkpoint_{:03d}.pth.tar'.format(epoch + 1)
            shutil.copyfile(checkpoint_path, history_path)
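
adjust_learning_rate is not shown in the listing either; both training examples call it once per epoch, and train_seg uses its return value. A common step-decay version consistent with those call sites; the every-30-epochs, divide-by-10 schedule is an assumption:

def adjust_learning_rate(args, optimizer, epoch):
    # Step decay: divide the base lr by 10 every 30 epochs (assumed schedule).
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr  # train_seg prints this value
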