Example 1
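This example is the `main()` entry point of a PeleeNet ImageNet training script: it sets up (optionally distributed) data loading, builds the model, loss, and optimizer, optionally resumes from or evaluates a checkpoint, hooks into ilit for quantization tuning, and runs the epoch loop. The excerpt relies on module-level state it does not show; a minimal preamble consistent with how `args` and `best_acc1` are used could look like the sketch below (flag names, defaults, and the `peleenet` import path are assumptions):

import argparse
import os

import torch
import torch.nn as nn
import torch.distributed as dist
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from peleenet import PeleeNet  # assumed module providing the model

parser = argparse.ArgumentParser(description='PeleeNet ImageNet training')
parser.add_argument('data', help='dataset root containing train/ and val/')
parser.add_argument('--arch', default='peleenet')
parser.add_argument('--input-dim', default=224, type=int)
parser.add_argument('--epochs', default=120, type=int)
parser.add_argument('--start-epoch', default=0, type=int)
parser.add_argument('--batch-size', default=256, type=int)
parser.add_argument('--workers', default=4, type=int)
parser.add_argument('--lr', default=0.1, type=float)
parser.add_argument('--momentum', default=0.9, type=float)
parser.add_argument('--weight-decay', default=1e-4, type=float)
parser.add_argument('--resume', default='', type=str)
parser.add_argument('--pretrained', action='store_true')
parser.add_argument('--weights', default='', type=str)
parser.add_argument('--evaluate', action='store_true')
parser.add_argument('--tune', action='store_true')
parser.add_argument('--world-size', default=1, type=int)
parser.add_argument('--dist-url', default='tcp://127.0.0.1:23456', type=str)
parser.add_argument('--dist-backend', default='gloo', type=str)

best_acc1 = 0  # module-level; main() updates it via the global statement
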
def main():
    global args, best_acc1
    args = parser.parse_args()
    print('args:', args)

    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    # Val data loading
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    val_dataset = datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(args.input_dim + 32),
            transforms.CenterCrop(args.input_dim),
            transforms.ToTensor(),
            normalize,
        ]))

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # ImageFolder derives one class per sub-directory, so the class count
    # follows the validation-set layout
    num_classes = len(val_dataset.classes)
    print('Total classes:', num_classes)

    # create model
    print("=> creating model '{}'".format(args.arch))
    if args.arch == 'peleenet':
        model = PeleeNet(num_classes=num_classes)
    else:
        print(
            "=> unsupported model '{}'. creating PeleeNet by default.".format(
                args.arch))
        model = PeleeNet(num_classes=num_classes)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model)
    else:
        # DataParallel will divide and allocate batch_size to all
        # available GPUs
        model = torch.nn.DataParallel(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    elif args.pretrained:
        if os.path.isfile(args.weights):
            checkpoint = torch.load(args.weights,
                                    map_location=torch.device('cpu'))
            model.load_state_dict(checkpoint['state_dict'])

            print("=> loaded checkpoint '{}' (epoch {}, acc@1 {})".format(
                args.pretrained, checkpoint['epoch'], checkpoint['best_acc1']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    if args.tune:
        model.eval()
        model.module.fuse_model()
        # ilit (Intel Low Precision Optimization Tool) runs accuracy-driven
        # quantization tuning configured by conf.yaml
        import ilit
        tuner = ilit.Tuner("./conf.yaml")
        q_model = tuner.tune(model)
        return

    # Training data loading
    traindir = os.path.join(args.data, 'train')

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(args.input_dim),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        # DistributedSampler gives each process a distinct shard of the
        # training set; set_epoch() below reshuffles the shards every epoch
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion)

        # remember best Acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }, is_best)
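
`save_checkpoint` is not part of the excerpt; its call signature matches the helper from the official PyTorch ImageNet example, so a minimal version would look like this (filenames are assumptions, `torch` as imported above):

import shutil

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    # Persist the latest state every epoch; keep a copy of the best model so far
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')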
Example 2
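This fragment is the body of a per-epoch loop from a PeleeNet fine-tuning script (the surrounding `for epoch ...` header is not part of the excerpt). The first two epochs run a warmup pass; afterwards the scheduler is stepped on the training loss, validation metrics are logged through a TensorBoard-style `writer`, and a checkpoint is saved every epoch. Because `schedule.step(train_loss)` receives a metric, `schedule` is presumably a `ReduceLROnPlateau` scheduler, constructed along these lines (an assumption):

schedule = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_conv, mode='min', factor=0.1, patience=10)
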
        if epoch < 2:
            count = train_warmup(train_set, val_set, model, optimizer_conv,
                                 criteria1, epoch, writer, count)
        else:
            count, train_loss = train(train_set, val_set, model,
                                      optimizer_conv, criteria1, writer, count,
                                      epoch)
            print("train_loss : {0}, lr : {1}".format(
                train_loss, optimizer_conv.param_groups[0]['lr']))
            schedule.step(train_loss)
            #schedule.step()

        val_top1, val_top3, val_top5, val_loss = validate(
            val_set, model, criteria1)
        writer = add_summery(writer, 'val', val_loss, val_top1, val_top5,
                             count)
        if_best_model = (val_top1 > best_acc)
        best_acc = max(val_top1, best_acc)

        filepath = 'weights/epoch_' + str(epoch) + '_checkpoint.pth.tar'
        save_step(
            {
                'epoch': epoch + 1,
                'arch': 'peleenet',
                'state_dict': model.state_dict(),
                'acc': val_top1,
            },
            if_best_model,
            'test',
            filename=filepath)
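
`save_step` is not shown in the excerpt. Judging from the call site it mirrors the usual checkpoint helper with an extra tag argument; a minimal sketch (argument meanings and file layout are assumptions):

import shutil

import torch

def save_step(state, is_best, tag, filename='checkpoint.pth.tar'):
    # Persist this epoch's snapshot ...
    torch.save(state, filename)
    # ... and keep a tagged copy whenever it beats the best top-1 so far
    if is_best:
        shutil.copyfile(filename, 'weights/' + tag + '_model_best.pth.tar')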