示例#1
0
def main():
    """Train ``Network`` on the ILSVRC2012 training split.

    Configuration is read from the module-level ``args`` object (``gpu``,
    ``data``, ``classes``, ``batch``, ``checkpoint``, ``model``, ``lr``,
    ``epochs`` and ``iter_start``).

    Steps performed:

    * When ``args.gpu`` is set, restrict CUDA to that device through the
      ``CUDA_DEVICE_ORDER`` / ``CUDA_VISIBLE_DEVICES`` environment variables.
    * Build the training dataset and a shuffling ``torch`` data loader.
    * Restore weights from the last (by sorted file name) file containing
      ``'pth'`` in ``args.checkpoint``; if there is none, fall back to
      ``args.model`` when it is given.  Resuming overwrites
      ``args.iter_start`` with the step number parsed from the file name.
    * Run SGD with cross-entropy loss, printing progress every step, logging
      scalars every 20 steps and saving the network every 1000 steps.
    * Stop after the current epoch if ``<checkpoint>/stop.txt`` exists.
    """
    if args.gpu is not None:
        print('Using GPU %d' % args.gpu)
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    else:
        print('CPU mode')

    ## DataLoader initialize ILSVRC2012_train_processed
    train_data = DataLoader(args.data + '/ILSVRC2012_img_train',
                            args.data + '/ilsvrc12_train.txt',
                            classes=args.classes)
    train_loader = torch.utils.data.DataLoader(dataset=train_data,
                                               batch_size=args.batch,
                                               shuffle=True,
                                               num_workers=16)
    N = train_data.N

    iter_per_epoch = N / args.batch
    # Number of batches pulled from the loader per epoch (one less than the
    # number of full batches, as in the original loop bound).
    batches_per_epoch = int(float(N) / args.batch) - 1
    print('Images: %d' % (N))

    # Network initialize
    net = Network(args.classes)
    if args.gpu is not None:
        net.cuda()

    ############## Load from checkpoint if exists, otherwise from model ###############
    checkpoints = []
    if os.path.exists(args.checkpoint):
        checkpoints = sorted(f for f in os.listdir(args.checkpoint)
                             if 'pth' in f)
    if checkpoints:
        ckp = checkpoints[-1]
        # Checkpoints are written as '<checkpoint>/jps_EEE_SSSSSS.pth.tar',
        # so the directory and file name must be joined with a separator.
        net.load_state_dict(torch.load(args.checkpoint + '/' + ckp))
        # 'jps_EEE_SSSSSS.pth.tar' -> SSSSSS (the global step counter).
        args.iter_start = int(ckp.split(".")[-3].split("_")[-1])
        print('Starting from: ', ckp)
    elif args.model is not None:
        net.load(args.model)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(),
                                lr=args.lr,
                                momentum=0.9,
                                weight_decay=5e-4)

    logger = Logger(args.checkpoint + '/train')
    # NOTE(review): not used in this function; kept because constructing the
    # Logger may have side effects (e.g. creating the directory) - confirm.
    logger_test = Logger(args.checkpoint + '/test')

    ############## TRAINING ###############
    print('Start training: lr %f, batch size %d, classes %d' %
          (args.lr, args.batch, args.classes))
    print('Checkpoint: ' + args.checkpoint)

    # Train the Model
    # Sliding windows (last 100 samples) of data-loading / forward-pass times.
    batch_time, net_time = [], []
    steps = args.iter_start
    for epoch in range(int(args.iter_start / iter_per_epoch), args.epochs):
        lr = adjust_learning_rate(optimizer,
                                  epoch,
                                  init_lr=args.lr,
                                  step=20,
                                  decay=0.1)

        it = iter(train_loader)
        for i in range(batches_per_epoch):
            t = time()
            # next(it) works on both Python 2 and 3, unlike it.next().
            images, labels, _ = next(it)
            batch_time.append(time() - t)
            if len(batch_time) > 100:
                del batch_time[0]

            images = Variable(images)
            labels = Variable(labels)
            if args.gpu is not None:
                images = images.cuda()
                labels = labels.cuda()

            # Forward + Backward + Optimize
            optimizer.zero_grad()
            t = time()
            outputs = net(images)
            net_time.append(time() - t)
            if len(net_time) > 100:
                del net_time[0]

            prec1, prec5 = compute_accuracy(outputs.cpu().data,
                                            labels.cpu().data,
                                            topk=(1, 5))
            acc = prec1[0]

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            loss = float(loss.cpu().data.numpy())

            # Progress is reported on every step.
            print(
                '[%2d/%2d] %5d) [batch load % 2.2fsec, net %1.2fsec], LR %.5f, Loss: % 1.3f, Accuracy % 2.1f%%'
                % (epoch + 1, args.epochs, steps, np.mean(batch_time),
                   np.mean(net_time), lr, loss, acc))

            if steps % 20 == 0:
                logger.scalar_summary('accuracy', acc, steps)
                logger.scalar_summary('loss', loss, steps)

            steps += 1

            if steps % 1000 == 0:
                filename = '%s/jps_%03i_%06d.pth.tar' % (args.checkpoint,
                                                         epoch, steps)
                net.save(filename)
                print('Saved: ' + args.checkpoint)

        if os.path.exists(args.checkpoint + '/stop.txt'):
            # break without using CTRL+C
            break
def main():
    """Train (or only evaluate) ``Network`` on ILSVRC2012.

    Configuration is read from the module-level ``args`` object (``gpu``,
    ``data``, ``classes``, ``batch``, ``cores``, ``checkpoint``, ``model``,
    ``evaluate``, ``lr``, ``epochs`` and ``iter_start``).

    Steps performed:

    * When ``args.gpu`` is set, restrict CUDA to that device through the
      ``CUDA_DEVICE_ORDER`` / ``CUDA_VISIBLE_DEVICES`` environment variables.
    * Build training and validation datasets/loaders, using the
      ``*_255x255`` image directories instead when they exist.
    * Restore weights from the last (by sorted file name) file containing
      ``'pth'`` in ``args.checkpoint``; if there is none, fall back to
      ``args.model`` when it is given.  Resuming overwrites
      ``args.iter_start`` with the step number parsed from the file name.
    * If ``args.evaluate`` is set, call ``test`` once on the validation
      loader and return.
    * Otherwise run SGD with cross-entropy loss; ``test`` is called at the
      start of every 10th epoch (except epoch 0), progress is printed and
      logged every 20 steps, and the network is saved every 1000 steps.
    * Stop after the current epoch if ``<checkpoint>/stop.txt`` exists.
    """
    if args.gpu is not None:
        print('Using GPU %d' % args.gpu)
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    else:
        print('CPU mode')

    print('Process number: %d' % (os.getpid()))

    ## DataLoader initialize ILSVRC2012_train_processed
    # Use the '_255x255' variant of each image directory when it is present.
    trainpath = args.data + '/ILSVRC2012_img_train'
    if os.path.exists(trainpath + '_255x255'):
        trainpath += '_255x255'
    train_data = DataLoader(trainpath,
                            args.data + '/ilsvrc12_train.txt',
                            classes=args.classes)
    train_loader = torch.utils.data.DataLoader(dataset=train_data,
                                               batch_size=args.batch,
                                               shuffle=True,
                                               num_workers=args.cores)

    valpath = args.data + '/ILSVRC2012_img_val'
    if os.path.exists(valpath + '_255x255'):
        valpath += '_255x255'
    val_data = DataLoader(valpath,
                          args.data + '/ilsvrc12_val.txt',
                          classes=args.classes)
    val_loader = torch.utils.data.DataLoader(dataset=val_data,
                                             batch_size=args.batch,
                                             shuffle=True,
                                             num_workers=args.cores)

    iter_per_epoch = train_data.N / args.batch
    print('Images: train %d, validation %d' % (train_data.N, val_data.N))

    # Network initialize
    net = Network(args.classes)
    if args.gpu is not None:
        net.cuda()

    ############## Load from checkpoint if exists, otherwise from model ###############
    checkpoints = []
    if os.path.exists(args.checkpoint):
        checkpoints = sorted(f for f in os.listdir(args.checkpoint)
                             if 'pth' in f)
    if checkpoints:
        ckp = checkpoints[-1]
        net.load_state_dict(torch.load(args.checkpoint + '/' + ckp))
        # 'jps_EEE_SSSSSS.pth.tar' -> SSSSSS (the global step counter).
        args.iter_start = int(ckp.split(".")[-3].split("_")[-1])
        print('Starting from: ', ckp)
    elif args.model is not None:
        net.load(args.model)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(),
                                lr=args.lr,
                                momentum=0.9,
                                weight_decay=5e-4)

    logger = Logger(args.checkpoint + '/train')
    logger_test = Logger(args.checkpoint + '/test')

    ############## TESTING ###############
    if args.evaluate:
        test(net, criterion, None, val_loader, 0)
        return

    ############## TRAINING ###############
    print('Start training: lr %f, batch size %d, classes %d' %
          (args.lr, args.batch, args.classes))
    print('Checkpoint: ' + args.checkpoint)

    # Train the Model
    # Sliding windows (last 100 samples) of data-loading / forward-pass times.
    batch_time, net_time = [], []
    steps = args.iter_start
    for epoch in range(int(args.iter_start / iter_per_epoch), args.epochs):
        if epoch % 10 == 0 and epoch > 0:
            test(net, criterion, logger_test, val_loader, steps)
        lr = adjust_learning_rate(optimizer,
                                  epoch,
                                  init_lr=args.lr,
                                  step=20,
                                  decay=0.1)

        # 'end' marks when the previous step finished, so the difference
        # below is the time spent waiting for the next batch.
        end = time()
        for i, (images, labels, original) in enumerate(train_loader):
            batch_time.append(time() - end)
            if len(batch_time) > 100:
                del batch_time[0]

            images = Variable(images)
            labels = Variable(labels)
            if args.gpu is not None:
                images = images.cuda()
                labels = labels.cuda()

            # Forward + Backward + Optimize
            optimizer.zero_grad()
            t = time()
            outputs = net(images)
            net_time.append(time() - t)
            if len(net_time) > 100:
                del net_time[0]

            prec1, prec5 = compute_accuracy(outputs.cpu().data,
                                            labels.cpu().data,
                                            topk=(1, 5))
            acc = prec1[0]

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            loss = float(loss.cpu().data.numpy())

            # Console report and summaries share the same 20-step cadence.
            if steps % 20 == 0:
                print(
                    '[%2d/%2d] %5d) [batch load % 2.3fsec, net %1.2fsec], LR %.5f, Loss: % 1.3f, Accuracy % 2.2f%%'
                    % (epoch + 1, args.epochs, steps, np.mean(batch_time),
                       np.mean(net_time), lr, loss, acc))

                logger.scalar_summary('accuracy', acc, steps)
                logger.scalar_summary('loss', loss, steps)

                # Take element 0 of every entry in ``original``.
                # NOTE(review): presumably the 9 tiles (3 x 75 x 75) of the
                # first sample in the batch - confirm against the dataset.
                tiles = [im[0] for im in original]
                imgs = np.zeros([9, 75, 75, 3])
                for ti, tile in enumerate(tiles):
                    channels = []
                    for im in tile.numpy():
                        lo = im.min()
                        span = im.max() - lo
                        # Min-max scale each channel to [0, 1]; a constant
                        # channel would divide by zero, so map it to zeros.
                        if span > 0:
                            channels.append((im - lo) / span)
                        else:
                            channels.append(np.zeros_like(im))
                    # Channels-first -> channels-last for the image summary.
                    imgs[ti] = np.stack(channels, axis=2)

                logger.image_summary('input', imgs, steps)

            steps += 1

            if steps % 1000 == 0:
                filename = '%s/jps_%03i_%06d.pth.tar' % (args.checkpoint,
                                                         epoch, steps)
                net.save(filename)
                print('Saved: ' + args.checkpoint)

            end = time()

        if os.path.exists(args.checkpoint + '/stop.txt'):
            # break without using CTRL+C
            break