Example #1
import argparse

import numpy as np

import chainer
import chainer.links as L
from chainer import training
from chainer.datasets import split_dataset_random
from chainer.training import extensions

import chainermn

# MobileNet and PreprocessedDataset are project-local definitions that are
# not shown on this page; hedged sketches of both follow below.
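
Both examples rely on two local classes that the original listing does not include: PreprocessedDataset and MobileNet. As a reference point only, here is a minimal sketch of each; PreprocessedDataset follows the dataset class from Chainer's own ImageNet example (random crop, horizontal flip, mean subtraction), and DepthwiseSeparableConv shows the depthwise-separable building block a MobileNet chain would stack. All names and details here are assumptions, not the original code.

import random

import chainer.functions as F


class PreprocessedDataset(chainer.dataset.DatasetMixin):
    # Sketch (assumption): mirrors Chainer's ImageNet example. Reads an
    # image-label list file, subtracts a precomputed mean image, and returns
    # a random (training) or center (evaluation) crop of `crop_size`.
    def __init__(self, path, root, mean, crop_size, random_crop=True):
        self.base = chainer.datasets.LabeledImageDataset(path, root)
        self.mean = mean.astype(np.float32)
        self.crop_size = crop_size
        self.random_crop = random_crop

    def __len__(self):
        return len(self.base)

    def get_example(self, i):
        image, label = self.base[i]
        _, h, w = image.shape
        size = self.crop_size
        if self.random_crop:
            top = random.randint(0, h - size)
            left = random.randint(0, w - size)
            if random.randint(0, 1):
                image = image[:, :, ::-1]  # random horizontal flip
        else:
            top = (h - size) // 2
            left = (w - size) // 2
        image = image[:, top:top + size, left:left + size]
        image -= self.mean[:, top:top + size, left:left + size]
        image *= (1.0 / 255.0)
        return image, label


class DepthwiseSeparableConv(chainer.Chain):
    # Sketch (assumption): one MobileNet building block -- a 3x3 depthwise
    # convolution followed by a 1x1 pointwise convolution, each with
    # BatchNormalization and ReLU.
    def __init__(self, in_channels, out_channels, stride=1):
        super(DepthwiseSeparableConv, self).__init__()
        with self.init_scope():
            self.dw = L.DepthwiseConvolution2D(
                in_channels, 1, 3, stride=stride, pad=1, nobias=True)
            self.bn1 = L.BatchNormalization(in_channels)
            self.pw = L.Convolution2D(
                in_channels, out_channels, 1, nobias=True)
            self.bn2 = L.BatchNormalization(out_channels)

    def forward(self, x):
        h = F.relu(self.bn1(self.dw(x)))
        return F.relu(self.bn2(self.pw(h)))
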
def main():
    parser = argparse.ArgumentParser(description='Learning MobileNet')
    parser.add_argument('train', help='Path to training image-label list file')
    # parser.add_argument('val', help='Path to validation image-label list file')
    parser.add_argument('--batchsize',
                        '-B',
                        type=int,
                        default=50,
                        help='Learning minibatch size')
    parser.add_argument('--epoch',
                        '-E',
                        type=int,
                        default=10,
                        help='Number of epochs to train')
    parser.add_argument('--gpu', '-g', action='store_true', help='Use GPU')
    parser.add_argument('--communicator',
                        type=str,
                        default='hierarchical',
                        help='Type of communicator')
    parser.add_argument('--initmodel',
                        help='Initialize the model from given file')
    parser.add_argument('--loaderjob',
                        '-j',
                        type=int,
                        help='Number of parallel data loading processes')
    parser.add_argument('--mean',
                        '-m',
                        default='mean.npy',
                        help='Mean file (computed by compute_mean.py)')
    parser.add_argument('--resume',
                        '-r',
                        default='',
                        help='Initialize the trainer from given file')
    parser.add_argument('--out',
                        '-o',
                        default='result',
                        help='Output directory')
    parser.add_argument('--root',
                        '-R',
                        default='.',
                        help='Root directory path of image files')
    parser.add_argument('--val_batchsize',
                        '-b',
                        type=int,
                        default=250,
                        help='Validation minibatch size')
    parser.add_argument('--test', action='store_true')
    parser.add_argument('--noplot',
                        dest='plot',
                        action='store_false',
                        help='Disable PlotReport extension')
    parser.set_defaults(test=False)
    args = parser.parse_args()

    if args.gpu:
        if args.communicator == 'naive':
            print("Error: 'naive' communicator does not support GPU.\n")
            exit(-1)
        comm = chainermn.create_communicator(args.communicator)
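        # intra_rank is the process rank within this physical node, so it
        # can be used directly as the local GPU device ID.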
        device = comm.intra_rank
    else:
        if args.communicator != 'naive':
            print('Warning: using naive communicator '
                  'because only naive supports CPU-only execution')
        comm = chainermn.create_communicator('naive')
        device = -1

    if comm.rank == 0:
        print('==========================================')
        print('Num process (COMM_WORLD): {}'.format(comm.size))
        if args.gpu:
            print('Using GPUs')
        print('Using {} communicator'.format(args.communicator))
        print('Num Minibatch-size: {}'.format(args.batchsize))
        print('Num epoch: {}'.format(args.epoch))
        print('==========================================')

    model = MobileNet()

    model = L.Classifier(model)

    # model = L.Classifier(MobileNet())
    if args.initmodel:
        print('Load model from {}'.format(args.initmodel))
        chainer.serializers.load_npz(args.initmodel, model)
    if device >= 0:
        chainer.backends.cuda.get_device_from_id(
            device).use()  # Make the GPU current
        model.to_gpu()

    # Create a multi node optimizer from a standard Chainer optimizer.
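    # The wrapper allreduces gradients across all workers before each
    # parameter update, so every model replica stays in sync.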
    optimizer = chainermn.create_multi_node_optimizer(
        chainer.optimizers.MomentumSGD(lr=0.01, momentum=0.9), comm)
    optimizer.setup(model)

    # Split and distribute the dataset. Only worker 0 loads the whole dataset.
    # Datasets of worker 0 are evenly split and distributed to all workers.
    if comm.rank == 0:
        mean = np.load(args.mean)
        train_val = PreprocessedDataset(args.train, args.root, mean, 224)
        train, valid = split_dataset_random(train_val,
                                            int(len(train_val) * 0.9),
                                            seed=0)
    else:
        train, valid = None, None

    train = chainermn.scatter_dataset(train, comm, shuffle=True)
    valid = chainermn.scatter_dataset(valid, comm, shuffle=True)

    train_iter = chainer.iterators.MultiprocessIterator(
        train, args.batchsize, n_processes=args.loaderjob)
    # val_iter = chainer.iterators.MultiprocessIterator(
    #     valid, args.val_batchsize, repeat=False, n_processes=args.loaderjob)

    # Set up a trainer
    updater = training.StandardUpdater(train_iter, optimizer, device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), args.out)

    val_interval = (1 if args.test else 1000), 'iteration'
    log_interval = (1 if args.test else 1000), 'iteration'
    if comm.rank == 0:
        # print("val_interval:{}".format(val_interval))

        # trainer.extend(extensions.Evaluator(val_iter, model, device=device),
        #               trigger=val_interval)
        trainer.extend(extensions.dump_graph('main/loss'))
        trainer.extend(extensions.snapshot(), trigger=val_interval)
        trainer.extend(extensions.snapshot_object(
            model, 'model_iter_{.updater.iteration}'),
                       trigger=val_interval)
        # Be careful to pass the interval directly to LogReport
        # (it determines when to emit log rather than when to read observations)

        # Save two plot images to the result dir
        if args.plot and extensions.PlotReport.available():
            trainer.extend(
                extensions.PlotReport(['main/loss', 'validation/main/loss'],
                                      'iteration',
                                      file_name='loss.png'))
            trainer.extend(
                extensions.PlotReport(
                    ['main/accuracy', 'validation/main/accuracy'],
                    'iteration',
                    file_name='accuracy.png'))
        trainer.extend(extensions.LogReport(trigger=log_interval))
        trainer.extend(extensions.observe_lr(), trigger=log_interval)
        trainer.extend(
            extensions.PrintReport([
                'epoch', 'iteration', 'main/loss', 'validation/main/loss',
                'main/accuracy', 'validation/main/accuracy', 'lr'
            ]))

        trainer.extend(extensions.ProgressBar())

    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    trainer.run()
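
To make Example #1 runnable as a script, an entry-point guard can be appended; the mpiexec invocation and the script name in the comment are illustrative assumptions, not part of the original listing.

if __name__ == '__main__':
    # ChainerMN programs are started through MPI, one process per GPU, e.g.:
    #   mpiexec -n 4 python train_mobilenet_mn.py train.txt --gpu
    # (the script name is hypothetical)
    main()
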
Example #2

def main():
    parser = argparse.ArgumentParser(
        description='Learning convnet from ILSVRC2012 dataset')
    parser.add_argument('train', help='Path to training image-label list file')
    # parser.add_argument('val', help='Path to validation image-label list file')
    parser.add_argument('--batchsize', '-B', type=int, default=50,
                        help='Learning minibatch size')
    parser.add_argument('--epoch', '-E', type=int, default=10,
                        help='Number of epochs to train')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--initmodel',
                        help='Initialize the model from given file')
    parser.add_argument('--loaderjob', '-j', type=int,
                        help='Number of parallel data loading processes')
    parser.add_argument('--mean', '-m', default='mean.npy',
                        help='Mean file (computed by compute_mean.py)')
    parser.add_argument('--resume', '-r', default='',
                        help='Initialize the trainer from given file')
    parser.add_argument('--out', '-o', default='result',
                        help='Output directory')
    parser.add_argument('--root', '-R', default='.',
                        help='Root directory path of image files')
    parser.add_argument('--val_batchsize', '-b', type=int, default=250,
                        help='Validation minibatch size')
    parser.add_argument('--test', action='store_true')
    parser.add_argument('--noplot', dest='plot', action='store_false',
                        help='Disable PlotReport extension')
    parser.set_defaults(test=False)
    args = parser.parse_args()

    # Initialize the model to train
    # model = MyRegressor(MobileNet())

    model = MobileNet()

    model = L.Classifier(model)

    # model = L.Classifier(MobileNet())
    if args.initmodel:
        print('Load model from {}'.format(args.initmodel))
        chainer.serializers.load_npz(args.initmodel, model)
    if args.gpu >= 0:
        chainer.backends.cuda.get_device_from_id(
            args.gpu).use()  # Make the GPU current
        model.to_gpu()

    # Load the datasets and mean file
    mean = np.load(args.mean)
    train_val = PreprocessedDataset(args.train, args.root, mean, 224)

    # val = PreprocessedDataset(args.val, args.root, mean, model.insize, False)
    # These iterators load the images with subprocesses running in parallel to
    # the training/validation.

    train, valid = split_dataset_random(train_val, int(len(train_val)*0.9), seed=0)
    # train, valid = split_dataset_random(train_val, 10, seed=0)
    # np_train = np.asarray(train)
    # print(np_train[0][0].shape)
    print("train:{} / valid:{}".format(len(train), len(valid)))
    # print(train[0][0])
    # print(train[0][1])
    train_iter = chainer.iterators.MultiprocessIterator(
        train, args.batchsize, n_processes=args.loaderjob)
    val_iter = chainer.iterators.MultiprocessIterator(
        valid, args.val_batchsize, repeat=False, n_processes=args.loaderjob)

    # Set up an optimizer
    # optimizer = chainer.optimizers.RMSprop(lr=0.01, alpha=0.9)

    optimizer = chainer.optimizers.MomentumSGD(lr=0.01, momentum=0.9)
    optimizer.setup(model)

    # Set up a trainer
    updater = training.StandardUpdater(
        train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), args.out)

    val_interval = (1 if args.test else 10000), 'iteration'
    log_interval = (1 if args.test else 1000), 'iteration'

    print("val_interval:{}".format(val_interval))

    trainer.extend(extensions.Evaluator(val_iter, model, device=args.gpu),
                   trigger=val_interval)
    trainer.extend(extensions.dump_graph('main/loss'))
    trainer.extend(extensions.snapshot(), trigger=val_interval)
    trainer.extend(extensions.snapshot_object(
        model, 'model_iter_{.updater.iteration}'), trigger=val_interval)
    # Be careful to pass the interval directly to LogReport
    # (it determines when to emit log rather than when to read observations)

    # Save two plot images to the result dir
    if args.plot and extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(['main/loss', 'validation/main/loss'],
                                  'epoch', file_name='loss.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/accuracy', 'validation/main/accuracy'],
                'epoch', file_name='accuracy.png'))
    trainer.extend(extensions.LogReport(trigger=log_interval))
    trainer.extend(extensions.observe_lr(), trigger=log_interval)
    trainer.extend(extensions.PrintReport([
        'epoch', 'iteration', 'main/loss', 'validation/main/loss',
        'main/accuracy', 'validation/main/accuracy', 'lr'
    ]))

    trainer.extend(extensions.ProgressBar())

    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    trainer.run()
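
As with Example #1, an entry-point guard makes the listing runnable as a script; the invocation in the comment is an illustrative assumption.

if __name__ == '__main__':
    # Single-process run (the script name is hypothetical):
    #   python train_mobilenet.py train.txt --gpu 0
    main()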