Example #1
# Imports assumed by this snippet (an assumption: the original module header is
# not shown). `parser`, `data_utils`, `train`, and `validate` are
# project-specific and expected to be defined elsewhere in the file.
import os
import shutil

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision.models as models
import torchvision.transforms as transforms

# Module-level state mutated through `global` below; the initial values here
# are assumptions consistent with how the variables are used.
best_prec = 0
train_loss_list = []
val_acc_list = []


def main():
    global args, best_prec, train_loss_list, val_acc_list
    args = parser.parse_args()

    # create model
    print("=> creating model '{}'".format(args.arch))
    net = models.__dict__[args.arch]()

    # replace the final fully-connected layer with a 6-way classification head
    in_features = net.fc.in_features
    new_fc = nn.Linear(in_features, 6)
    net.fc = new_fc

    net.cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            net.load_state_dict(checkpoint['state_dict'])
            args.start_epoch = checkpoint['epoch']
            best_prec = checkpoint['best_prec']
            train_loss_list = checkpoint['train_loss_list']
            val_acc_list = checkpoint['val_acc_list']
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    params = net.parameters()
    snapshot_fname = "snapshots/%s.pth.tar" % args.arch
    snapshot_best_fname = "snapshots/%s_best.pth.tar" % args.arch

    cudnn.benchmark = True

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_loader = torch.utils.data.DataLoader(data_utils.ImageFolderCounting(
        args.data, '../dataset/counting_train.json',
        transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(data_utils.ImageFolderCounting(
        args.data, '../dataset/counting_val.json',
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(params,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # evaluate on validation set
    if args.evaluate:
        validate(val_loader, net, criterion, True)
        return

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train(train_loader, net, criterion, optimizer, epoch)

        # evaluate on validation set
        prec = validate(val_loader, net, criterion, False)

        # remember best prec@1 and save checkpoint
        is_best = prec > best_prec
        best_prec = max(prec, best_prec)
        torch.save(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': net.state_dict(),
                'best_prec': best_prec,
                'train_loss_list': train_loss_list,
                'val_acc_list': val_acc_list,
            }, snapshot_fname)
        if is_best:
            shutil.copyfile(snapshot_fname, snapshot_best_fname)
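
Both examples read their settings from a module-level argparse parser that is not shown. A minimal sketch consistent with the attributes used above (flag names and defaults are assumptions, not the original definition):

import argparse

# Hypothetical parser reconstructed from the attributes the snippets read.
parser = argparse.ArgumentParser(description='counting model training')
parser.add_argument('data', help='path to the image folder')
parser.add_argument('--arch', default='resnet18', help='model architecture name')
parser.add_argument('--epochs', type=int, default=90)
parser.add_argument('--start-epoch', type=int, default=0)
parser.add_argument('--batch-size', type=int, default=256)
parser.add_argument('--workers', type=int, default=4)
parser.add_argument('--lr', type=float, default=0.1)
parser.add_argument('--momentum', type=float, default=0.9)
parser.add_argument('--weight-decay', type=float, default=1e-4)
parser.add_argument('--resume', default='', help='path to a checkpoint to resume from')
parser.add_argument('--evaluate', action='store_true', help='run validation only')

With such a parser, a resume run could look like:

python train.py /path/to/images --arch resnet18 --resume snapshots/resnet18.pth.tar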
Example #2
# Same assumed imports and module-level globals as in Example #1.
def main():
    global args, best_prec, train_loss_list, val_acc_list
    args = parser.parse_args()

    # create model
    print("=> creating model '{}'".format(args.arch))
    net = models.__dict__[args.arch]()

    # replace the final fully-connected layer with a 2-way classification head
    in_features = net.fc.in_features

    print(in_features)
    new_fc = nn.Linear(in_features, 2)
    net.fc = new_fc

    params = net.parameters()
    snapshot_fname = "snapshots/%s.pth.tar" % args.arch
    snapshot_best_fname = "snapshots/%s_best.pth.tar" % args.arch

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_loader = torch.utils.data.DataLoader(data_utils.ImageFolderCounting(
        args.data, './train.json',
        transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(data_utils.ImageFolderCounting(
        args.data, './val.json',
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(params, 0.1, momentum=0.9, weight_decay=1e-4)

    # evaluate on validation set
    if args.evaluate:
        validate(val_loader, net, criterion, True)
        return

    for epoch in range(args.epochs):
        # train for one epoch
        train(train_loader, net, criterion, optimizer, epoch)

        # evaluate on validation set
        prec = validate(val_loader, net, criterion, False)

        # remember best prec@1 and save checkpoint
        is_best = prec > best_prec
        best_prec = max(prec, best_prec)
        torch.save(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': net.state_dict(),
                'best_prec': best_prec,
                'train_loss_list': train_loss_list,
                'val_acc_list': val_acc_list,
            }, snapshot_fname)
        if is_best:
            shutil.copyfile(snapshot_fname, snapshot_best_fname)
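
Neither example defines the train and validate helpers it calls. A minimal sketch matching the call signatures used above (the exact bookkeeping is an assumption; device transfers are handled generically so the sketch works for both the CUDA setup of Example #1 and the CPU setup of Example #2):

def train(train_loader, net, criterion, optimizer, epoch):
    # One pass over the training set; records the mean loss so the
    # checkpoint's train_loss_list stays current.
    net.train()
    device = next(net.parameters()).device
    total_loss = 0.0
    for images, targets in train_loader:
        images, targets = images.to(device), targets.to(device)
        output = net(images)
        loss = criterion(output, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    train_loss_list.append(total_loss / len(train_loader))


def validate(val_loader, net, criterion, verbose):
    # Computes top-1 accuracy and returns it so main() can track best_prec.
    net.eval()
    device = next(net.parameters()).device
    correct, total = 0, 0
    with torch.no_grad():
        for images, targets in val_loader:
            images, targets = images.to(device), targets.to(device)
            output = net(images)
            correct += (output.argmax(dim=1) == targets).sum().item()
            total += targets.size(0)
    prec = 100.0 * correct / total
    if verbose:
        print('Validation prec@1: {:.2f}%'.format(prec))
    val_acc_list.append(prec)
    return prec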