コード例 #1
0
ファイル: train_mnist_byteps.py プロジェクト: yuxihu/byteps
num_workers = bps.size()

# Build model
model = conv_nets()
model.cast(args.dtype)

# Initialize parameters on `context`.
model.initialize(mx.init.MSRAPrelu(), ctx=context)
# The dummy input must live on the same device as the parameters, so use
# `context` instead of a hard-coded GPU (assumes `context` is a single
# mx.Context -- confirm where it is defined).
# NOTE(review): this runs on every rank on purpose; the forward pass may be
# what completes deferred parameter initialization before the broadcast
# below -- confirm before restricting it to rank 0.
model.summary(nd.ones((1, 1, 28, 28), ctx=context))
model.hybridize()

# BytePS: fetch and broadcast parameters so all workers start from the
# values held by rank 0.
params = model.collect_params()
if params is not None:
    bps.broadcast_parameters(params, root_rank=0)

# BytePS: create DistributedTrainer, a subclass of gluon.Trainer.
# The learning rate is scaled linearly with the number of workers.
optimizer_params = {
    'momentum': args.momentum,
    'learning_rate': args.lr * num_workers,
}
trainer = bps.DistributedTrainer(params, "sgd", optimizer_params)

# Create loss function and train metric
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
metric = mx.metric.Accuracy()

# Train model
# NOTE(review): the rest of the epoch loop body is not part of this excerpt.
for epoch in range(args.epochs):
    tic = time.time()
コード例 #2
0
ファイル: fit_byteps.py プロジェクト: zprhhs/byteps
def _create_initializer(args):
    """Return the weight initializer named by ``args.initializer``.

    ``'default'`` picks an initializer based on ``args.network``; the other
    recognised names map directly onto an ``mx.init`` class. Returns ``None``
    for an unrecognised name so the caller can decide whether that matters
    (it only does when parameters have to be freshly initialized).
    """
    if args.initializer == 'default':
        if args.network == 'alexnet':
            # AlexNet will not converge using Xavier
            return mx.init.Normal()
        if 'vgg' in args.network:
            # VGG will not trend to converge using Xavier-Gaussian
            return mx.init.Xavier()
        return mx.init.Xavier(rnd_type='gaussian',
                              factor_type="in",
                              magnitude=2)
    factories = {
        'xavier': mx.init.Xavier,
        'msra': mx.init.MSRAPrelu,
        'orthogonal': mx.init.Orthogonal,
        'normal': mx.init.Normal,
        'uniform': mx.init.Uniform,
        'one': mx.init.One,
        'zero': mx.init.Zero,
    }
    factory = factories.get(args.initializer)
    return factory() if factory is not None else None


def _warmup_optimizer_params(args):
    """Return the extra optimizer options used by warmup optimizers (lbsgd/lbnag)."""
    nworkers = max(bps.size(), 1)
    # Batches per worker per epoch, never less than 1.
    epoch_size = max(args.num_examples / args.batch_size / nworkers, 1)
    # The macrobatch can never be smaller than one global batch.
    macrobatch_size = max(args.macrobatch_size, args.batch_size * nworkers)
    batch_scale = math.ceil(
        float(macrobatch_size) / args.batch_size / nworkers)
    return {
        'updates_per_epoch': epoch_size,
        'begin_epoch': args.load_epoch if args.load_epoch else 0,
        'batch_scale': batch_scale,
        'warmup_strategy': args.warmup_strategy,
        'warmup_epochs': args.warmup_epochs,
        'num_epochs': args.num_epochs,
    }


def _build_eval_metrics(args, network):
    """Return the evaluation metrics requested by ``args.top_k`` and ``args.loss``."""
    eval_metrics = ['accuracy']
    if args.top_k > 0:
        eval_metrics.append(
            mx.metric.create('top_k_accuracy', top_k=args.top_k))

    if not args.loss:
        return eval_metrics

    # ce or nll loss is only applicable to softmax output
    if 'softmax_output' not in network.list_outputs():
        logging.warning(
            "The output is not softmax_output, loss argument will be skipped!"
        )
        return eval_metrics

    supported_loss = ('ce', 'nll_loss')
    for loss_type in args.loss.split(','):
        loss_type = loss_type.strip()
        if loss_type == 'nll':
            loss_type = 'nll_loss'
        if loss_type in supported_loss:
            eval_metrics.append(mx.metric.create(loss_type))
        else:
            logging.warning(loss_type + ' is not an valid loss type, only cross-entropy or '
                            'negative likelihood loss is supported!')
    return eval_metrics


def fit(args, network, data_loader, **kwargs):
    """
    train a model with the MXNet Module API, using BytePS instead of a kvstore

    args : argparse returns
    network : the symbol definition of the neural network
    data_loader : function that returns the train and val data iterators;
        called as ``data_loader(args, (rank, size, local_rank))``
    kwargs :
        arg_params / aux_params -- pre-loaded parameters; when both are given,
            ``_load_model`` is skipped
        batch_end_callback -- a callback or list of callbacks run after the
            default Speedometer

    Returns None. With ``args.test_io`` set, only iterates the training data
    and logs read throughput, then returns without training.

    Raises ValueError when parameters must be freshly initialized and
    ``args.initializer`` is not a recognised initializer name.
    """
    # logging
    head = '%(asctime)-15s Node[' + str(bps.rank()) + '] %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)
    logging.info('start with arguments %s', args)

    # data iterators; the loader receives this worker's BytePS topology
    (train, val) = data_loader(args,
                               (bps.rank(), bps.size(), bps.local_rank()))
    if args.test_io:
        tic = time.time()
        for i, batch in enumerate(train):
            for j in batch.data:
                j.wait_to_read()
            if (i + 1) % args.disp_batches == 0:
                logging.info(
                    'Batch [%d]\tSpeed: %.2f samples/sec', i,
                    args.disp_batches * args.batch_size / (time.time() - tic))
                tic = time.time()

        return

    # load model
    if 'arg_params' in kwargs and 'aux_params' in kwargs:
        arg_params = kwargs['arg_params']
        aux_params = kwargs['aux_params']
    else:
        sym, arg_params, aux_params = _load_model(args, bps.rank())
        if sym is not None:
            assert sym.tojson() == network.tojson()

    # save model
    checkpoint = _save_model(args, bps.rank())

    # devices for training: a single device selected by the local rank
    if args.cpu_train:
        devs = [mx.cpu(bps.local_rank())]
    else:
        logging.info('Launch BytePS process on GPU-%d', bps.local_rank())
        devs = [mx.gpu(bps.local_rank())]

    # learning rate
    lr, lr_scheduler = _get_lr_scheduler(args)

    # create model
    model = mx.mod.Module(context=devs, symbol=network)

    optimizer_params = {
        'learning_rate': lr,
        'wd': args.wd,
        'lr_scheduler': lr_scheduler,
        'multi_precision': True
    }
    # NOTE(review): no 'rescale_grad' is set, so the optimizer keeps its
    # default; confirm gradient normalisation is what this loss expects.

    # Only a limited number of optimizers have 'momentum' property
    if args.optimizer in {'sgd', 'dcasgd', 'nag'}:
        optimizer_params['momentum'] = args.mom

    monitor = mx.mon.Monitor(args.monitor,
                             pattern=".*") if args.monitor > 0 else None

    # A limited number of optimizers have a warmup period
    if args.optimizer in {'lbsgd', 'lbnag'}:
        optimizer_params.update(_warmup_optimizer_params(args))

    # None when args.initializer is unrecognised; only an error if needed below
    initializer = _create_initializer(args)

    # evaluation metrics
    eval_metrics = _build_eval_metrics(args, network)

    # callbacks that run after each batch
    batch_end_callbacks = [
        mx.callback.Speedometer(args.batch_size, args.disp_batches)
    ]
    if 'batch_end_callback' in kwargs:
        cbs = kwargs['batch_end_callback']
        batch_end_callbacks += cbs if isinstance(cbs, list) else [cbs]

    # BytePS wrapper around the optimizer (training runs with kvstore=None)
    opt = mx.optimizer.create(args.optimizer, sym=network, **optimizer_params)
    opt = bps.DistributedOptimizer(opt)

    # BytePS: better to explicitly init, then broadcast rank 0's parameters
    # so every worker starts from the same values.
    model.bind(data_shapes=train.provide_data,
               label_shapes=train.provide_label)
    if arg_params is None and aux_params is None:
        if initializer is None:
            raise ValueError('unknown initializer: %s' % args.initializer)
        model.init_params(initializer)
        (arg_params, aux_params) = model.get_params()
    if arg_params is not None:
        bps.broadcast_parameters(arg_params, root_rank=0)
    if aux_params is not None:
        bps.broadcast_parameters(aux_params, root_rank=0)
    model.set_params(arg_params=arg_params, aux_params=aux_params)

    # run
    # NOTE(review): optimizer_params is presumably ignored by Module.fit when
    # an Optimizer instance is passed -- confirm.
    model.fit(train,
              begin_epoch=args.load_epoch if args.load_epoch else 0,
              num_epoch=args.num_epochs,
              eval_data=val,
              eval_metric=eval_metrics,
              kvstore=None,
              optimizer=opt,
              optimizer_params=optimizer_params,
              batch_end_callback=batch_end_callbacks,
              epoch_end_callback=checkpoint,
              allow_missing=True,
              monitor=monitor)