def main():
    """Evaluate a checkpointed network on the validation split.

    Parses the command line, loads the configuration file, builds the
    network named by ``conf["network"]["arch"]`` inside ``DataParallel``,
    restores its weights from ``args.checkpoint`` and runs ``validate`` over
    the image folder ``<args.data>/val``.

    Side effects: sets the module-level ``args`` and ``conf``.
    """
    global args, conf
    args = parser.parse_args()
    conf = config.load_config(args.config)

    # Network constructors are looked up by name ("net_<arch>") in `models`.
    net_conf = conf["network"]
    net_kwargs = utils.get_model_params(net_conf)
    build_net = models.__dict__["net_" + net_conf["arch"]]
    model = torch.nn.DataParallel(build_net(**net_kwargs)).cuda()

    # Restore the stored weights into the wrapped model.
    snapshot = torch.load(args.checkpoint)
    model.load_state_dict(snapshot["state_dict"])

    # Validation data pipeline.
    val_root = os.path.join(args.data, "val")
    val_transforms = get_transforms(conf["input"])

    # With --ten_crops the configured batch size is divided by ten
    # (presumably because each image expands to ten crops -- confirm against
    # the transforms in use).
    if args.ten_crops:
        batch_size = conf["optimizer"]["batch_size"] // 10
    else:
        batch_size = conf["optimizer"]["batch_size"]

    val_set = datasets.ImageFolder(val_root, transforms.Compose(val_transforms))
    val_loader = torch.utils.data.DataLoader(
        val_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
    )

    loss_fn = nn.CrossEntropyLoss().cuda()
    validate(val_loader, model, loss_fn)
示例#2
0
def main():
    """Evaluate a checkpointed network on the validation split.

    Single-process and multi-process runs are both supported. The number of
    participating processes is read from the ``WORLD_SIZE`` environment
    variable; when it is greater than one the process group is initialised
    from the environment (``env://``) and the model is wrapped in
    ``DistributedDataParallel``, otherwise it is wrapped in ``SingleGPU``.

    Side effects: sets the module-level ``args`` and ``conf``, selects CUDA
    device ``args.local_rank`` and runs ``utils.validate``.
    """
    global args, conf
    args = parser.parse_args()

    torch.cuda.set_device(args.local_rank)

    # A missing or malformed WORLD_SIZE means "single process".  Only the
    # errors that lookup/conversion can raise are caught, so that
    # KeyboardInterrupt / SystemExit are no longer swallowed here.
    try:
        world_size = int(os.environ["WORLD_SIZE"])
    except (KeyError, ValueError):
        world_size = 1
    distributed = world_size > 1

    if distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method="env://")

    # Load configuration
    conf = config.load_config(args.config)

    # Create model; constructors are looked up by name ("net_<arch>").
    model_params = utils.get_model_params(conf["network"])
    model = models.__dict__["net_" + conf["network"]["arch"]](**model_params)
    model.cuda()
    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank)
    else:
        model = SingleGPU(model)

    # Resume from checkpoint.  Load onto the CPU first: without map_location
    # the tensors are restored onto the device they were saved from, so every
    # process would allocate memory on that one GPU.  load_state_dict then
    # copies the values into the parameters on this process's device.
    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])

    # Data loading code
    valdir = os.path.join(args.data, "val")
    val_transforms = utils.create_test_transforms(conf["input"], args.crop,
                                                  args.scale, args.ten_crops)

    # With --ten_crops the configured batch size is divided by ten
    # (presumably because each image expands to ten crops -- confirm).
    batch_size = (conf["optimizer"]["batch_size"] if not args.ten_crops else
                  conf["optimizer"]["batch_size"] // 10)
    dataset = datasets.ImageFolder(valdir, transforms.Compose(val_transforms))
    val_loader = torch.utils.data.DataLoader(
        dataset,
        # The configured batch size is split evenly across processes.
        batch_size=batch_size // world_size,
        shuffle=False,
        sampler=TestDistributedSampler(dataset),
        num_workers=args.workers,
        pin_memory=True,
    )

    criterion = nn.CrossEntropyLoss().cuda()
    utils.validate(val_loader, model, criterion, args.ten_crops,
                   args.print_freq)
示例#3
0
def main():
    """Train a network, validating and checkpointing after every epoch.

    Steps, in order: parse the command line, create the ``SummaryWriter``,
    optionally initialise the process group (when ``args.world_size > 1``),
    load the configuration, build and wrap the model, create the loss,
    optimizer and scheduler, optionally resume from ``args.resume``, build the
    train/val data loaders, and then either run a single validation pass
    (``args.evaluate``) or the full training loop.

    Side effects: sets the module-level ``args``, ``best_prec1``, ``logger``
    and ``conf``; sets ``cudnn.benchmark``; calls ``save_checkpoint`` once per
    epoch.
    """
    global args, best_prec1, logger, conf
    args = parser.parse_args()

    args.distributed = args.world_size > 1
    logger = SummaryWriter(args.log_dir)

    if args.distributed:
        dist.init_process_group(
            backend=args.dist_backend,
            init_method=args.dist_url,
            world_size=args.world_size,
        )

    conf = config.load_config(args.config)

    # Network constructors are looked up by name ("net_<arch>") in `models`.
    net_conf = conf["network"]
    net_kwargs = utils.get_model_params(net_conf)
    model = models.__dict__["net_" + net_conf["arch"]](**net_kwargs)

    if args.distributed:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)
    else:
        model = torch.nn.DataParallel(model).cuda()

    # Loss, optimizer and learning-rate scheduler.
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer, scheduler = utils.create_optimizer(conf["optimizer"], model)

    # Either start from freshly initialised weights or restore a checkpoint.
    if not args.resume:
        init_weights(model)
        args.start_epoch = 0
    elif os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        snapshot = torch.load(args.resume)
        args.start_epoch = snapshot['epoch']
        best_prec1 = snapshot['best_prec1']
        model.load_state_dict(snapshot['state_dict'])
        optimizer.load_state_dict(snapshot['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.resume, snapshot['epoch']))
    else:
        # Only a message is printed; execution continues with the model as
        # constructed and args.start_epoch as parsed.
        print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Datasets live in <args.data>/train and <args.data>/val.
    train_root = os.path.join(args.data, 'train')
    val_root = os.path.join(args.data, 'val')

    train_transforms, val_transforms = utils.create_transforms(conf["input"])
    train_dataset = datasets.ImageFolder(
        train_root, transforms.Compose(train_transforms))

    # In distributed mode a DistributedSampler is used and the loader's own
    # shuffling is switched off.
    train_sampler = (
        torch.utils.data.distributed.DistributedSampler(train_dataset)
        if args.distributed else None
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=conf["optimizer"]["batch_size"],
        shuffle=(train_sampler is None),
        num_workers=args.workers,
        pin_memory=True,
        sampler=train_sampler,
    )

    val_dataset = datasets.ImageFolder(val_root,
                                       transforms.Compose(val_transforms))
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=conf["optimizer"]["batch_size"],
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
    )

    # Evaluation-only mode: one validation pass, no training.
    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    first_epoch = args.start_epoch
    last_epoch = conf["optimizer"]["schedule"]["epochs"]
    for epoch in range(first_epoch, last_epoch):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        train(train_loader, model, criterion, optimizer, scheduler, epoch)

        # `it` is the number of training iterations at the start of this epoch.
        global_step = epoch * len(train_loader)
        prec1 = validate(val_loader, model, criterion, it=global_step)

        # Track the best top-1 precision seen so far and store a checkpoint.
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        state = {
            'epoch': epoch + 1,
            'arch': conf["network"]["arch"],
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
        }
        save_checkpoint(state, is_best)
示例#4
0
def main():
    """Train a network, validating and checkpointing after every epoch.

    The number of participating processes is read from the ``WORLD_SIZE``
    environment variable; when it is greater than one the process group is
    initialised from the environment (``env://``) and the model is wrapped in
    ``DistributedDataParallel`` on device ``args.local_rank``, otherwise it is
    wrapped in ``SingleGPU``.  The configured batch size is divided by the
    world size for both loaders.

    With ``args.evaluate`` set, a single validation pass is run and the
    function returns; otherwise the training loop runs from
    ``args.start_epoch`` to the configured number of epochs.

    Side effects: sets the module-level ``args``, ``best_prec1``, ``conf`` and
    ``tb``; selects the CUDA device; sets ``cudnn.benchmark``; calls
    ``save_checkpoint`` once per epoch.
    """
    global args, best_prec1, logger, conf, tb
    args = parser.parse_args()

    torch.cuda.set_device(args.local_rank)

    # A missing or malformed WORLD_SIZE means "single process".  Only the
    # errors that lookup/conversion can raise are caught, so that
    # KeyboardInterrupt / SystemExit are no longer swallowed here.
    try:
        world_size = int(os.environ['WORLD_SIZE'])
    except (KeyError, ValueError):
        world_size = 1
    distributed = world_size > 1

    if distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method='env://')

    rank = 0 if not distributed else dist.get_rank()
    # NOTE(review): `logger` is used below but never assigned in this
    # function; presumably init_logger sets the module-level logger -- confirm.
    init_logger(rank, args.log_dir)
    # Only rank 0 gets a SummaryWriter; other ranks pass tb=None.
    tb = SummaryWriter(args.log_dir) if rank == 0 else None

    # Load configuration
    conf = config.load_config(args.config)

    # Create model; constructors are looked up by name ("net_<arch>").
    model_params = utils.get_model_params(conf["network"])
    model = models.__dict__["net_" + conf["network"]["arch"]](**model_params)

    model.cuda()
    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank)
    else:
        model = SingleGPU(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer, scheduler = utils.create_optimizer(conf["optimizer"], model)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            # Load onto the CPU first: without map_location the tensors are
            # restored onto the device they were saved from, so every process
            # would allocate memory on that one GPU.  The load_state_dict
            # calls below move the values to this process's device.
            checkpoint = torch.load(args.resume, map_location="cpu")
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            # Only a warning is logged; execution continues with the model as
            # constructed and args.start_epoch as parsed.
            logger.warning("=> no checkpoint found at '{}'".format(
                args.resume))
    else:
        init_weights(model)
        args.start_epoch = 0

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')

    train_transforms, val_transforms = utils.create_transforms(conf["input"])
    train_dataset = datasets.ImageFolder(traindir,
                                         transforms.Compose(train_transforms))

    # In distributed mode a DistributedSampler is used and the loader's own
    # shuffling is switched off.
    if distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=conf["optimizer"]["batch_size"] // world_size,
        shuffle=(train_sampler is None),
        num_workers=args.workers,
        pin_memory=True,
        sampler=train_sampler)

    val_dataset = datasets.ImageFolder(valdir,
                                       transforms.Compose(val_transforms))
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=conf["optimizer"]["batch_size"] // world_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
        sampler=TestDistributedSampler(val_dataset))

    # Evaluation-only mode: one validation pass, no training.
    if args.evaluate:
        utils.validate(val_loader,
                       model,
                       criterion,
                       print_freq=args.print_freq,
                       tb=tb,
                       logger=logger.info)
        return

    for epoch in range(args.start_epoch,
                       conf["optimizer"]["schedule"]["epochs"]):
        if distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, scheduler, epoch)

        # evaluate on validation set; `it` is the number of training
        # iterations at the start of this epoch
        prec1 = utils.validate(val_loader,
                               model,
                               criterion,
                               it=epoch * len(train_loader),
                               print_freq=args.print_freq,
                               tb=tb,
                               logger=logger.info)

        # remember best prec@1 and save checkpoint
        # NOTE(review): every rank reaches this call; confirm that
        # save_checkpoint guards against concurrent writes.
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': conf["network"]["arch"],
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best)