num_workers = bps.size()

# Build model
model = conv_nets()
model.cast(args.dtype)

# Initialize parameters
model.initialize(mx.init.MSRAPrelu(), ctx=context)
# if bps.rank() == 0:
#     model.summary(nd.ones((1, 1, 28, 28), ctx=mx.gpu(bps.local_rank())))
model.hybridize()

# BytePS: fetch and broadcast parameters
params = model.collect_params()
if params is not None:
    bps.broadcast_parameters(params, root_rank=0)

# BytePS: create DistributedTrainer, a subclass of gluon.Trainer
optimizer_params = {'momentum': args.momentum,
                    'learning_rate': args.lr * num_workers}
trainer = bps.DistributedTrainer(params, "sgd", optimizer_params)

# Create loss function and train metric
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
metric = mx.metric.Accuracy()

# Train model
for epoch in range(args.epochs):
    tic = time.time()
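    # NOTE: the snippet above breaks off at the start of the epoch loop. The
    # following is a hedged sketch (not part of the original) of a typical
    # Gluon epoch body built from the objects created above; `train_data` is
    # an assumed DataLoader name yielding (data, label) batches.
    metric.reset()
    for data, label in train_data:
        data = data.as_in_context(context)
        label = label.as_in_context(context)
        with mx.autograd.record():
            output = model(data)
            loss = loss_fn(output, label)
        loss.backward()
        # BytePS: DistributedTrainer aggregates gradients across workers on step()
        trainer.step(args.batch_size)
        metric.update([label], [output])
    name, acc = metric.get()
    print('Epoch [%d]: %s=%f, elapsed %.1fs' % (epoch, name, acc, time.time() - tic))
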
def fit(args, network, data_loader, **kwargs):
    """
    train a model
    args : argparse returns
    network : the symbol definition of the neural network
    data_loader : function that returns the train and val data iterators
    """
    # kvstore is not used; BytePS handles gradient aggregation instead
    # kv = mx.kvstore.create(args.kv_store)
    # if args.gc_type != 'none':
    #     kv.set_gradient_compression({'type': args.gc_type,
    #                                  'threshold': args.gc_threshold})

    # logging
    head = '%(asctime)-15s Node[' + str(bps.rank()) + '] %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)
    logging.info('start with arguments %s', args)

    # data iterators, partitioned by (rank, nworker, local_rank)
    (train, val) = data_loader(args, (bps.rank(), bps.size(), bps.local_rank()))
    if args.test_io:
        # only benchmark data-loading speed, no training
        tic = time.time()
        for i, batch in enumerate(train):
            for j in batch.data:
                j.wait_to_read()
            if (i + 1) % args.disp_batches == 0:
                logging.info('Batch [%d]\tSpeed: %.2f samples/sec', i,
                             args.disp_batches * args.batch_size /
                             (time.time() - tic))
                tic = time.time()
        return

    # load model
    if 'arg_params' in kwargs and 'aux_params' in kwargs:
        arg_params = kwargs['arg_params']
        aux_params = kwargs['aux_params']
    else:
        sym, arg_params, aux_params = _load_model(args, bps.rank())
        if sym is not None:
            assert sym.tojson() == network.tojson()

    # save model
    checkpoint = _save_model(args, bps.rank())

    # devices for training: pin each BytePS worker to its local rank
    if args.cpu_train:
        devs = [mx.cpu(bps.local_rank())]
    else:
        logging.info('Launch BytePS process on GPU-%d', bps.local_rank())
        devs = [mx.gpu(bps.local_rank())]

    # learning rate
    lr, lr_scheduler = _get_lr_scheduler(args)

    # create model
    model = mx.mod.Module(context=devs, symbol=network)

    optimizer_params = {
        'learning_rate': lr,
        'wd': args.wd,
        'lr_scheduler': lr_scheduler,
        'multi_precision': True}

    # Only a limited number of optimizers have a 'momentum' property
    has_momentum = {'sgd', 'dcasgd', 'nag'}
    if args.optimizer in has_momentum:
        optimizer_params['momentum'] = args.mom

    monitor = mx.mon.Monitor(
        args.monitor, pattern=".*") if args.monitor > 0 else None

    # A limited number of optimizers have a warmup period
    has_warmup = {'lbsgd', 'lbnag'}
    if args.optimizer in has_warmup:
        nworkers = bps.size() if bps.size() > 1 else 1
        epoch_size = args.num_examples / args.batch_size / nworkers
        if epoch_size < 1:
            epoch_size = 1
        macrobatch_size = args.macrobatch_size
        if macrobatch_size < args.batch_size * nworkers:
            macrobatch_size = args.batch_size * nworkers
        # batch_scale = round(float(macrobatch_size) / args.batch_size / nworkers + 0.4999)
        batch_scale = math.ceil(
            float(macrobatch_size) / args.batch_size / nworkers)
        optimizer_params['updates_per_epoch'] = epoch_size
        optimizer_params['begin_epoch'] = args.load_epoch if args.load_epoch else 0
        optimizer_params['batch_scale'] = batch_scale
        optimizer_params['warmup_strategy'] = args.warmup_strategy
        optimizer_params['warmup_epochs'] = args.warmup_epochs
        optimizer_params['num_epochs'] = args.num_epochs

    if args.initializer == 'default':
        if args.network == 'alexnet':
            # AlexNet will not converge using Xavier
            initializer = mx.init.Normal()
        elif 'vgg' in args.network:
            # VGG tends not to converge using Xavier-Gaussian
            initializer = mx.init.Xavier()
        else:
            initializer = mx.init.Xavier(
                rnd_type='gaussian', factor_type="in", magnitude=2)
            # initializer = mx.init.Xavier(factor_type="in", magnitude=2.34)
    elif args.initializer == 'xavier':
        initializer = mx.init.Xavier()
    elif args.initializer == 'msra':
        initializer = mx.init.MSRAPrelu()
    elif args.initializer == 'orthogonal':
        initializer = mx.init.Orthogonal()
    elif args.initializer == 'normal':
        initializer = mx.init.Normal()
    elif args.initializer == 'uniform':
        initializer = mx.init.Uniform()
    elif args.initializer == 'one':
        initializer = mx.init.One()
    elif args.initializer == 'zero':
        initializer = mx.init.Zero()

    # evaluation metrics
    eval_metrics = ['accuracy']
    if args.top_k > 0:
        eval_metrics.append(mx.metric.create(
            'top_k_accuracy', top_k=args.top_k))

    supported_loss = ['ce', 'nll_loss']
    if len(args.loss) > 0:
        # ce or nll loss is only applicable to softmax output
        loss_type_list = args.loss.split(',')
        if 'softmax_output' in network.list_outputs():
            for loss_type in loss_type_list:
                loss_type = loss_type.strip()
                if loss_type == 'nll':
                    loss_type = 'nll_loss'
                if loss_type not in supported_loss:
                    logging.warning(loss_type + ' is not a valid loss type; only cross-entropy or '
                                    'negative log-likelihood loss is supported!')
                else:
                    eval_metrics.append(mx.metric.create(loss_type))
        else:
            logging.warning(
                "The output is not softmax_output, loss argument will be skipped!")

    # callbacks that run after each batch
    batch_end_callbacks = [mx.callback.Speedometer(
        args.batch_size, args.disp_batches)]
    if 'batch_end_callback' in kwargs:
        cbs = kwargs['batch_end_callback']
        batch_end_callbacks += cbs if isinstance(cbs, list) else [cbs]

    # BytePS: wrap the base optimizer so gradient updates go through BytePS
    opt = mx.optimizer.create(args.optimizer, sym=network, **optimizer_params)
    opt = bps.DistributedOptimizer(opt)

    # BytePS: better to explicitly initialize parameters and broadcast them
    # from root rank 0 so every worker starts from identical weights
    model.bind(data_shapes=train.provide_data,
               label_shapes=train.provide_label)
    if arg_params is None and aux_params is None:
        model.init_params(initializer)
        (arg_params, aux_params) = model.get_params()
    if arg_params is not None:
        bps.broadcast_parameters(arg_params, root_rank=0)
    if aux_params is not None:
        bps.broadcast_parameters(aux_params, root_rank=0)
    model.set_params(arg_params=arg_params, aux_params=aux_params)

    # run; kvstore=None because BytePS replaces the kvstore for synchronization
    model.fit(train,
              begin_epoch=args.load_epoch if args.load_epoch else 0,
              num_epoch=args.num_epochs,
              eval_data=val,
              eval_metric=eval_metrics,
              kvstore=None,
              optimizer=opt,
              optimizer_params=optimizer_params,
              batch_end_callback=batch_end_callbacks,
              epoch_end_callback=checkpoint,
              allow_missing=True,
              monitor=monitor)
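

# Hedged usage sketch (not part of the original file): how fit() is typically
# driven from a training script. bps.init() is the BytePS entry point; the
# argument setup and the `build_network` / `get_data_iters` helpers below are
# hypothetical stand-ins for whatever the surrounding example provides.
if __name__ == '__main__':
    import argparse

    bps.init()  # start the BytePS worker before any rank()/size() calls

    parser = argparse.ArgumentParser()
    # ... register the data / model / optimizer arguments that fit() reads ...
    args = parser.parse_args()

    # build_network(args) would return the mx.sym symbol to train;
    # get_data_iters(args, (rank, nworker, local_rank)) would return the
    # (train, val) iterators partitioned for this worker.
    network = build_network(args)
    fit(args, network, get_data_iters)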