    return top1.avg, top3.avg, top5.avg, losses.avg


if __name__ == '__main__':
    # load data
    train_set = get_data_sets(train_cfg['data_set_dir'], 'train')
    val_set = get_data_sets(train_cfg['data_set_dir'], 'val')
    # load model
    if model_cfg['if_pre_train']:
        model = load_model(model_cfg['pre_trained_path'],
                           data_classes=model_cfg['data_classes'])
    else:
        model = PeleeNet(num_classes=model_cfg['model_classes'])
    if train_cfg['cuda']:
        model = model.cuda()

    # initialize the summary writer
    writer = init_summery('runs', 'train_val')

    optimizer_conv = optim.Adam(model.parameters(),
                                lr=train_cfg['init_lr'],
                                weight_decay=0.0001)
    schedule = optim.lr_scheduler.ReduceLROnPlateau(optimizer_conv,
                                                    factor=0.9,
                                                    patience=2,
                                                    threshold=0.01,
                                                    min_lr=train_cfg['min_lr'])
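
    # A minimal sketch of how this scheduler is typically driven (not part of
    # the original snippet, which is truncated here): ReduceLROnPlateau.step()
    # must receive the monitored metric once per epoch. `run_epoch` and
    # `evaluate_epoch` are hypothetical stand-ins for the project's own
    # train/validate helpers, and train_cfg['epochs'] is an assumed config key.
    #
    # for epoch in range(train_cfg['epochs']):
    #     run_epoch(model, train_set, optimizer_conv)           # hypothetical
    #     _, _, _, val_loss = evaluate_epoch(model, val_set)    # hypothetical
    #     schedule.step(val_loss)  # cuts LR by `factor` after `patience` flat epochs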


# Example 2

def main():
    global args, best_acc1
    args = parser.parse_args()
    print('args:', args)

    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size)

    # Val data loading
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    val_dataset = datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(args.input_dim + 32),
            transforms.CenterCrop(args.input_dim),
            transforms.ToTensor(),
            normalize,
        ]))
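
    # Resize(input_dim + 32) followed by CenterCrop(input_dim) is the standard
    # ImageNet evaluation recipe (e.g. resize the short side to 256, then take
    # a 224x224 center crop), matching the normalization statistics above.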

    val_loader = torch.utils.data.DataLoader(val_dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    num_classes = len(val_dataset.classes)
    print('Total classes:', num_classes)

    # create model
    print("=> creating model '{}'".format(args.arch))
    if args.arch == 'peleenet':
        model = PeleeNet(num_classes=num_classes)
    else:
        print("=> unsupported model '{}'; creating PeleeNet by default".format(args.arch))
        model = PeleeNet(num_classes=num_classes)

    if args.distributed:
        model.cuda()
        # DistributedDataParallel will divide and allocate batch_size to all
        # available GPUs if device_ids are not set
        model = torch.nn.parallel.DistributedDataParallel(model)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
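
    # This excerpt never changes the learning rate inside the epoch loop, so
    # any decay presumably happens in train(). For reference, a common
    # step-decay helper paired with SGD in ImageNet scripts looks like this
    # sketch (an assumption, not the original code):
    #
    # def adjust_learning_rate(optimizer, epoch):
    #     lr = args.lr * (0.1 ** (epoch // 30))  # decay 10x every 30 epochs
    #     for param_group in optimizer.param_groups:
    #         param_group['lr'] = lr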

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    elif args.pretrained:
        if os.path.isfile(args.pretrained):
            checkpoint = torch.load(args.pretrained)
            model.load_state_dict(checkpoint['state_dict'])

            print("=> loaded checkpoint '{}' (epoch {}, acc@1 {})"
                  .format(args.pretrained, checkpoint['epoch'], checkpoint['best_acc1']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True
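    # cudnn.benchmark lets cuDNN auto-tune its convolution algorithms; it pays
    # off here because the input size is fixed by the transforms above.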


    if args.evaluate:
        validate(val_loader, model, criterion)
        return
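
    # Typical evaluate-only invocation (script name and exact flag spellings
    # are assumptions inferred from the argparse fields used above):
    #   python main.py /path/to/imagenet --evaluate --resume checkpoint.pth.tar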

    # Training data loading
    traindir = os.path.join(args.data, 'train')

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(args.input_dim),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)


    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
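            # set_epoch() reseeds the DistributedSampler so each epoch draws a
            # different shuffle that stays consistent across processes.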

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion)

        # remember best Acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_acc1': best_acc1,
            'optimizer': optimizer.state_dict(),
        }, is_best)
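
# save_checkpoint() is used above but not defined in this excerpt. A minimal
# sketch of the conventional helper from the PyTorch ImageNet example that this
# script appears to follow (the filenames here are assumptions):
import shutil

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)  # always persist the latest state
    if is_best:
        # keep a separate copy of the best-performing weights
        shutil.copyfile(filename, 'model_best.pth.tar')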