Example #1
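These snippets come from DLA-style training scripts, and none of them shows its import header. A plausible reconstruction, inferred from the call sites (the project-local modules, such as dla_up, data_transforms, the dataset classes, and the train/validate helpers, live in the surrounding codebase and are assumptions here):

import os
import shutil
import sys

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from tensorboardX import SummaryWriter  # or torch.utils.tensorboard

import dla_up                           # project-local model definitions
import data_transforms as transforms    # project-local augmentation ops
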
def train_seg(args):
    batch_size = args.batch_size
    num_workers = args.workers
    crop_size = args.crop_size
    checkpoint_dir = args.checkpoint_dir

    print(' '.join(sys.argv))

    for k, v in args.__dict__.items():
        print(k, ':', v)

    pretrained_base = args.pretrained_base  # unused in this variant
    single_model = dla_up.__dict__.get(args.arch)(classes=args.classes,
                                                  down_ratio=args.down)

    # convert BatchNorm layers for multi-GPU use (convert_model is presumably
    # a synchronized-BatchNorm utility such as sync_batchnorm's)
    single_model = convert_model(single_model)

    model = torch.nn.DataParallel(single_model).cuda()
    print('model_created')
    if args.edge_weight > 0:
        weight = torch.from_numpy(
            np.array([1, args.edge_weight], dtype=np.float32))
        # nn.NLLLoss2d is deprecated; nn.NLLLoss handles 2D targets directly
        criterion = nn.NLLLoss(ignore_index=-1, weight=weight)
    else:
        criterion = nn.NLLLoss(ignore_index=-1)

    criterion.cuda()

    t = []
    if args.random_rotate > 0:
        t.append(transforms.RandomRotate(args.random_rotate))
    if args.random_scale > 0:
        t.append(transforms.RandomScale(args.random_scale))
    t.append(transforms.RandomCrop(crop_size))  #TODO
    if args.random_color:
        t.append(transforms.RandomJitter(0.4, 0.4, 0.4))
    t.extend([transforms.RandomHorizontalFlip()])  #TODO

    t_val = [transforms.RandomCrop(crop_size)]

    dir_img = '/shared/xudongliu/data/argoverse-tracking/argo_track_all/train/image_02/'
    dir_mask = '/shared/xudongliu/data/argoverse-tracking/argo_track_all/train/' + args.target + '/'
    my_train = BasicDataset(dir_img,
                            dir_mask,
                            transforms.Compose(t),
                            is_train=True)

    val_dir_img = '/shared/xudongliu/data/argoverse-tracking/argo_track_all/val/image_02/'
    val_dir_mask = '/shared/xudongliu/data/argoverse-tracking/argo_track_all/val/' + args.target + '/'
    my_val = BasicDataset(val_dir_img,
                          val_dir_mask,
                          transforms.Compose(t_val),
                          is_train=True)

    train_loader = torch.utils.data.DataLoader(my_train,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=num_workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        my_val,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True)  #TODO  batch_size
    print("loader created")
    optimizer = torch.optim.SGD(single_model.optim_parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    lr_scheduler = None  #TODO

    cudnn.benchmark = True
    best_prec1 = 0
    start_epoch = 0

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    confusion_labels = np.arange(0, 5)
    val_confusion_matrix = RunningConfusionMatrix(confusion_labels,
                                                  ignore_label=-1)

    if args.evaluate:
        confusion_labels = np.arange(0, 2)
        val_confusion_matrix = RunningConfusionMatrix(confusion_labels,
                                                      ignore_label=-1,
                                                      reduce=True)
        validate(val_loader,
                 model,
                 criterion,
                 confusion_matrix=val_confusion_matrix)
        return
    writer = SummaryWriter(comment=args.log)

    # TODO test val
    # print("test val")
    # prec1 = validate(val_loader, model, criterion, confusion_matrix=val_confusion_matrix)

    for epoch in range(start_epoch, args.epochs):
        train_confusion_matrix = RunningConfusionMatrix(confusion_labels,
                                                        ignore_label=-1)
        lr = adjust_learning_rate(args, optimizer, epoch)
        print('Epoch: [{0}]\tlr {1:.06f}'.format(epoch, lr))
        # train for one epoch

        train(train_loader,
              model,
              criterion,
              optimizer,
              epoch,
              lr_scheduler,
              confusion_matrix=train_confusion_matrix,
              writer=writer)

        checkpoint_path = os.path.join(checkpoint_dir,
                                       'checkpoint_{}.pth.tar'.format(epoch))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict()
            },
            is_best=False,
            filename=checkpoint_path)

        # evaluate on validation set
        val_confusion_matrix = RunningConfusionMatrix(confusion_labels,
                                                      ignore_label=-1)
        prec1, loss_val = validate(val_loader,
                                   model,
                                   criterion,
                                   confusion_matrix=val_confusion_matrix)
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        writer.add_scalar('mIoU/epoch', prec1, epoch + 1)
        writer.add_scalar('loss/epoch', loss_val, epoch + 1)

        checkpoint_path = os.path.join(checkpoint_dir,
                                       'checkpoint_{}.pth.tar'.format(epoch))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            },
            is_best,
            filename=checkpoint_path)

        if (epoch + 1) % args.save_freq == 0:
            history_path = 'checkpoint_{:03d}.pth.tar'.format(epoch + 1)
            shutil.copyfile(checkpoint_path, history_path)

    writer.close()
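
Example #1 tracks mIoU with a RunningConfusionMatrix helper that is not shown; only its constructor arguments are visible above. A minimal sketch of what such a class could look like, assuming it accumulates a confusion matrix across batches, that reduce=True folds all foreground classes into one (plausible for the binary evaluate path), and that the update_matrix/mean_iou method names are hypothetical:

import numpy as np

class RunningConfusionMatrix:
    """Accumulate a confusion matrix over batches and report mean IoU."""

    def __init__(self, labels_list, ignore_label=-1, reduce=False):
        self.labels = list(labels_list)
        self.ignore_label = ignore_label
        self.reduce = reduce  # assumed: fold all foreground classes into one
        n = len(self.labels)
        self.matrix = np.zeros((n, n), dtype=np.int64)

    def update_matrix(self, target, pred):
        target = np.asarray(target).ravel()
        pred = np.asarray(pred).ravel()
        keep = target != self.ignore_label
        target, pred = target[keep], pred[keep]
        if self.reduce:
            target = (target > 0).astype(np.int64)
            pred = (pred > 0).astype(np.int64)
        n = len(self.labels)
        self.matrix += np.bincount(
            target * n + pred, minlength=n * n).reshape(n, n)

    def mean_iou(self):
        inter = np.diag(self.matrix)
        union = self.matrix.sum(0) + self.matrix.sum(1) - inter
        return float((inter / np.maximum(union, 1)).mean())
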
Example #2
def train_seg(args):
    writer = SummaryWriter(comment=args.log)
    batch_size = args.batch_size
    num_workers = args.workers
    crop_size = args.crop_size
    checkpoint_dir = args.checkpoint_dir

    print(' '.join(sys.argv))

    for k, v in args.__dict__.items():
        print(k, ':', v)

    pretrained_base = args.pretrained_base  # unused in this variant
    single_model = dla_up.__dict__.get(args.arch)(classes=args.classes,
                                                  down_ratio=args.down)
    model = torch.nn.DataParallel(single_model).cuda()
    print('model_created')
    if args.bg_weight > 0:
        weight_array = np.ones(args.classes, dtype=np.float32)
        weight_array[0] = args.bg_weight
        weight = torch.from_numpy(weight_array)
        # nn.NLLLoss2d is deprecated; nn.NLLLoss handles 2D targets directly
        criterion = nn.NLLLoss(ignore_index=255, weight=weight)
    else:
        criterion = nn.NLLLoss(ignore_index=255)

    criterion.cuda()

    t = []
    if args.random_rotate > 0:
        t.append(transforms.RandomRotate(args.random_rotate))
    if args.random_scale > 0:
        t.append(transforms.RandomScale(args.random_scale))
    t.append(transforms.RandomCrop(crop_size))  #TODO
    if args.random_color:
        t.append(transforms.RandomJitter(0.4, 0.4, 0.4))
    t.extend([transforms.RandomHorizontalFlip()])  #TODO

    t_val = [transforms.RandomCrop(crop_size)]

    train_json = '/shared/xudongliu/COCO/annotation2017/annotations/instances_train2017.json'
    train_root = '/shared/xudongliu/COCO/train2017/train2017'
    my_train = COCOSeg(train_root,
                       train_json,
                       transforms.Compose(t),
                       is_train=True)

    val_json = '/shared/xudongliu/COCO/annotation2017/annotations/instances_val2017.json'
    val_root = '/shared/xudongliu/COCO/2017val/val2017'
    my_val = COCOSeg(val_root,
                     val_json,
                     transforms.Compose(t_val),
                     is_train=True)

    train_loader = torch.utils.data.DataLoader(my_train,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=num_workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        my_val,
        batch_size=20,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True)  #TODO  batch_size
    print("loader created")

    # optimizer = torch.optim.Adam(single_model.optim_parameters(),
    #                             args.lr,
    #                              weight_decay=args.weight_decay) #TODO adam optimizer
    optimizer = torch.optim.SGD(
        single_model.optim_parameters(),
        args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay)  #TODO adam optimizer

    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                              T_max=32)  #TODO
    cudnn.benchmark = True
    best_prec1 = 0
    start_epoch = 0

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.evaluate:
        validate(val_loader, model, criterion, eval_score=accuracy)
        return

    # TODO test val
    # print("test val")
    # prec1 = validate(val_loader, model, criterion, eval_score=accuracy)

    for epoch in range(start_epoch, args.epochs):
        lr = adjust_learning_rate(args, optimizer, epoch)
        print('Epoch: [{0}]\tlr {1:.06f}'.format(epoch, lr))
        # train for one epoch

        train(train_loader,
              model,
              criterion,
              optimizer,
              epoch,
              lr_scheduler,
              eval_score=accuracy,
              writer=writer)

        checkpoint_path = os.path.join(checkpoint_dir,
                                       'checkpoint_{}.pth.tar'.format(epoch))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict()
            },
            is_best=False,
            filename=checkpoint_path)

        # evaluate on validation set
        prec1, loss_val, recall_val = validate(val_loader,
                                               model,
                                               criterion,
                                               eval_score=accuracy)
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        writer.add_scalar('accuracy/epoch', prec1, epoch + 1)
        writer.add_scalar('loss/epoch', loss_val, epoch + 1)
        writer.add_scalar('recall/epoch', recall_val, epoch + 1)

        checkpoint_path = os.path.join(checkpoint_dir,
                                       'checkpoint_{}.pth.tar'.format(epoch))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            },
            is_best,
            filename=checkpoint_path)

        if (epoch + 1) % args.save_freq == 0:
            history_path = 'checkpoint_{:03d}.pth.tar'.format(epoch + 1)
            shutil.copyfile(checkpoint_path, history_path)

    writer.close()
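
All the examples persist state through a save_checkpoint helper. It very likely follows the standard PyTorch recipe; a minimal sketch (the model_best.pth.tar name is an assumption):

import shutil

import torch

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    # always save the latest state; keep a separate copy of the best model
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')
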
Example #3
def train_seg(args):
    batch_size = args.batch_size
    num_workers = args.workers
    crop_size = args.crop_size

    print(' '.join(sys.argv))

    for k, v in args.__dict__.items():
        print(k, ':', v)

    pretrained_base = args.pretrained_base
    single_model = dla_up.__dict__.get(args.arch)(args.classes,
                                                  pretrained_base,
                                                  down_ratio=args.down)
    model = torch.nn.DataParallel(single_model).cuda()
    if args.edge_weight > 0:
        weight = torch.from_numpy(
            np.array([1, args.edge_weight], dtype=np.float32))
        criterion = nn.NLLLoss(ignore_index=255, weight=weight)
    else:
        criterion = nn.NLLLoss(ignore_index=255)

    criterion.cuda()

    data_dir = args.data_dir
    info = dataset.load_dataset_info(data_dir)
    normalize = transforms.Normalize(mean=info.mean, std=info.std)
    t = []
    if args.random_rotate > 0:
        t.append(transforms.RandomRotate(args.random_rotate))
    if args.random_scale > 0:
        t.append(transforms.RandomScale(args.random_scale))
    t.append(transforms.RandomCrop(crop_size))
    if args.random_color:
        t.append(transforms.RandomJitter(0.4, 0.4, 0.4))
    t.extend(
        [transforms.RandomHorizontalFlip(),
         transforms.ToTensor(), normalize])
    train_loader = torch.utils.data.DataLoader(SegList(
        data_dir, 'train', transforms.Compose(t), binary=(args.classes == 2)),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=num_workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        SegList(
            data_dir,
            'val',
            transforms.Compose([
                transforms.RandomCrop(crop_size),
                # transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]),
            binary=(args.classes == 2)),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True)
    optimizer = torch.optim.SGD(single_model.optim_parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    cudnn.benchmark = True
    best_prec1 = 0
    start_epoch = 0

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.evaluate:
        validate(val_loader, model, criterion, eval_score=accuracy)
        return

    for epoch in range(start_epoch, args.epochs):
        lr = adjust_learning_rate(args, optimizer, epoch)
        print('Epoch: [{0}]\tlr {1:.06f}'.format(epoch, lr))
        # train for one epoch
        train(train_loader,
              model,
              criterion,
              optimizer,
              epoch,
              eval_score=accuracy)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, eval_score=accuracy)

        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        checkpoint_path = 'checkpoint_latest.pth.tar'
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            },
            is_best,
            filename=checkpoint_path)
        if (epoch + 1) % args.save_freq == 0:
            history_path = 'checkpoint_{:03d}.pth.tar'.format(epoch + 1)
            shutil.copyfile(checkpoint_path, history_path)
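
Every example calls adjust_learning_rate(args, optimizer, epoch), and Examples #1 and #2 log the returned rate. A plausible sketch under the classic ImageNet schedule (the 10x decay every 30 epochs is an assumption; the original may use a different rule):

def adjust_learning_rate(args, optimizer, epoch):
    # assumed schedule: decay the base learning rate by 10x every 30 epochs
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
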
Example #4
def run_training(args):
    model = dla.__dict__[args.arch](
        pretrained=args.pretrained, num_classes=args.classes,
        pool_size=args.crop_size // 32)
    model = torch.nn.DataParallel(model)

    best_prec1 = 0

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    data = dataset.get_data(args.data_name)
    if data is None:
        data = dataset.load_dataset_info(args.data, data_name=args.data_name)
    if data is None:
        raise ValueError('{} is not pre-defined in dataset.py and info.json '
                         'does not exist in {}'.format(args.data_name,
                                                       args.data))

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = data_transforms.Normalize(mean=data.mean, std=data.std)
    tt = [data_transforms.RandomResizedCrop(
        args.crop_size, min_area_ratio=args.min_area_ratio,
        aspect_ratio=args.aspect_ratio)]
    if data.eigval is not None and data.eigvec is not None \
            and args.random_color:
        lighting = data_transforms.Lighting(0.1, data.eigval, data.eigvec)
        jitter = data_transforms.RandomJitter(0.4, 0.4, 0.4)
        tt.extend([jitter, lighting])
    tt.extend([data_transforms.RandomHorizontalFlip(),
               data_transforms.ToTensor(),
               normalize])

    train_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(traindir, data_transforms.Compose(tt)),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(args.scale_size),
            transforms.CenterCrop(args.crop_size),
            transforms.ToTensor(),
            normalize
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.cuda:
        model = model.cuda()
        criterion = criterion.cuda()

    if args.evaluate:
        validate(args, val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(args, optimizer, epoch)

        # train for one epoch
        train(args, train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(args, val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        checkpoint_path = 'checkpoint_latest.pth.tar'
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
        }, is_best, filename=checkpoint_path)
        if (epoch + 1) % args.check_freq == 0:
            history_path = 'checkpoint_{:03d}.pth.tar'.format(epoch + 1)
            shutil.copyfile(checkpoint_path, history_path)
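
For completeness, a hypothetical command-line driver for Example #4. The flag names mirror the attributes run_training reads; all defaults are assumptions:

import argparse

def parse_args():
    p = argparse.ArgumentParser(description='DLA classification training')
    p.add_argument('data', help='dataset root containing train/ and val/')
    p.add_argument('--arch', default='dla34')
    p.add_argument('--data-name', default=None)
    p.add_argument('--classes', type=int, default=1000)
    p.add_argument('--pretrained', action='store_true')
    p.add_argument('--crop-size', type=int, default=224)
    p.add_argument('--scale-size', type=int, default=256)
    p.add_argument('--min-area-ratio', type=float, default=0.08)
    p.add_argument('--aspect-ratio', type=float, default=4 / 3)
    p.add_argument('--random-color', action='store_true')
    p.add_argument('--batch-size', type=int, default=256)
    p.add_argument('--workers', type=int, default=4)
    p.add_argument('--lr', type=float, default=0.1)
    p.add_argument('--momentum', type=float, default=0.9)
    p.add_argument('--weight-decay', type=float, default=1e-4)
    p.add_argument('--epochs', type=int, default=120)
    p.add_argument('--start-epoch', type=int, default=0)
    p.add_argument('--check-freq', type=int, default=10)
    p.add_argument('--resume', default='')
    p.add_argument('--evaluate', action='store_true')
    p.add_argument('--cuda', action='store_true')
    return p.parse_args()

if __name__ == '__main__':
    run_training(parse_args())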