def test_horovod_broadcast_deferred_init_parameters(self):
    """Test that the deferred initialized parameters are broadcasted."""
    hvd.init()
    root_rank = 0
    rank = hvd.rank()

    # This test does not apply if there is only one worker.
    if hvd.size() == 1:
        return

    mx.random.seed(rank)
    layer = mx.gluon.nn.Conv2D(10, 2)
    layer.initialize()
    hvd.broadcast_parameters(layer.collect_params(), root_rank=root_rank)

    x = mx.nd.ones((5, 4, 10, 10))
    layer(x)
    tensors = [p.data() for _, p in sorted(layer.collect_params().items())]
    root_tensors = []
    for tensor in tensors:
        root_tensors.append(hvd.broadcast(tensor, root_rank=root_rank))
    for tensor, root_tensor in zip(tensors, root_tensors):
        assert same(tensor.asnumpy(), root_tensor.asnumpy()), \
            'horovod did not broadcast deferred initialized parameter correctly'
def _init_trainer(self):
    if self.last_train is None:
        raise RuntimeError(
            'Cannot init trainer without knowing the size of training data')
    if isinstance(self.last_train, pd.DataFrame):
        train_size = len(self.last_train)
    elif isinstance(self.last_train, int):
        train_size = self.last_train
    else:
        raise ValueError("Unknown type of self.last_train: {}".format(
            type(self.last_train)))

    if self._cfg.train.lr_decay_period > 0:
        lr_decay_epoch = list(
            range(self._cfg.train.lr_decay_period, self._cfg.train.epochs,
                  self._cfg.train.lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in self._cfg.train.lr_decay_epoch]
    lr_decay_epoch = [
        e - self._cfg.train.warmup_epochs for e in lr_decay_epoch
    ]
    num_batches = train_size // self._cfg.train.batch_size
    lr_scheduler = LRSequential([
        LRScheduler('linear',
                    base_lr=0,
                    target_lr=self._cfg.train.lr,
                    nepochs=self._cfg.train.warmup_epochs,
                    iters_per_epoch=num_batches),
        LRScheduler(self._cfg.train.lr_mode,
                    base_lr=self._cfg.train.lr,
                    nepochs=self._cfg.train.epochs - self._cfg.train.warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_decay_epoch,
                    step_factor=self._cfg.train.lr_decay,
                    power=2),
    ])

    if self._cfg.horovod:
        hvd.broadcast_parameters(self.net.collect_params(), root_rank=0)
        self.trainer = hvd.DistributedTrainer(
            self.net.collect_params(), 'sgd', {
                'wd': self._cfg.train.wd,
                'momentum': self._cfg.train.momentum,
                'lr_scheduler': lr_scheduler
            })
    else:
        self.trainer = gluon.Trainer(
            self.net.collect_params(), 'sgd', {
                'wd': self._cfg.train.wd,
                'momentum': self._cfg.train.momentum,
                'lr_scheduler': lr_scheduler
            },
            kvstore='local',
            update_on_kvstore=(False if self._cfg.yolo3.amp else None))
    if self._cfg.yolo3.amp:
        amp.init_trainer(self.trainer)
def _init_trainer(self):
    kv_store_type = 'device' if (self._cfg.faster_rcnn.amp and 'nccl' in self._cfg.kv_store) \
        else self._cfg.kv_store
    kv = mx.kvstore.create(kv_store_type)
    optimizer_params = {
        'learning_rate': self._cfg.train.lr,
        'wd': self._cfg.train.wd,
        'momentum': self._cfg.train.momentum
    }
    if self._cfg.faster_rcnn.amp:
        optimizer_params['multi_precision'] = True
    if self._cfg.horovod:
        hvd.broadcast_parameters(self.net.collect_params(), root_rank=0)
        self.trainer = hvd.DistributedTrainer(
            self.net.collect_train_params(),  # fix batchnorm, fix first stage, etc...
            'sgd',
            optimizer_params)
    else:
        self.trainer = gluon.Trainer(
            self.net.collect_train_params(),  # fix batchnorm, fix first stage, etc...
            'sgd',
            optimizer_params,
            update_on_kvstore=(False if self._cfg.faster_rcnn.amp else None),
            kvstore=kv)
    if self._cfg.faster_rcnn.amp:
        amp.init_trainer(self.trainer)
def test_two_trainer(self):
    """Test using horovod allreduce in MXNet Gluon trainer."""
    from mxnet import gluon
    from mxnet.gluon import Block, nn, HybridBlock

    hvd.init()
    rank = hvd.rank()
    ctx = mx.cpu(rank)

    net1 = nn.Dense(20, in_units=10)
    net2 = nn.Dense(30, in_units=10)
    net1.initialize(ctx=ctx)
    net2.initialize(ctx=ctx)

    params1 = net1.collect_params()
    params2 = net2.collect_params()
    hvd.broadcast_parameters(params1, prefix="net1")
    hvd.broadcast_parameters(params2, prefix="net2")
    trainer1 = hvd.DistributedTrainer(params1, 'sgd', {'learning_rate': 0.1}, prefix="net1")
    trainer2 = hvd.DistributedTrainer(params2, 'sgd', {'learning_rate': 0.1}, prefix="net2")

    for i in range(10):
        data = mx.nd.ones((5, 10), ctx=ctx)
        with mx.autograd.record():
            pred1 = net1(data).sum()
            pred2 = net2(data).sum()
        mx.autograd.backward([pred1, pred2])
        trainer1.step(1.0)
        trainer2.step(1.0)
        l = pred1.asscalar() + pred2.asscalar()
def test_horovod_broadcast_parameters(self):
    """Test the correctness of broadcast_parameters."""
    hvd.init()
    rank = hvd.rank()
    size = hvd.size()

    # This test does not apply if there is only one worker.
    if size == 1:
        self.skipTest("Only one worker available")

    dtypes = ['int32', 'int64', 'float32', 'float64']
    dims = [1, 2, 3]
    ctx = self._current_context()
    count = 0
    shapes = [(), (17,), (17, 17), (17, 17, 17)]
    root_rank = 1
    tensor_dict = {}
    root_dict = {}
    for dtype, dim in itertools.product(dtypes, dims):
        tensor_dict[count] = mx.nd.ones(shapes[dim], ctx=ctx) * rank
        root_dict[count] = mx.nd.ones(shapes[dim], ctx=ctx) * root_rank
        tensor_dict[count] = tensor_dict[count].astype(dtype)
        root_dict[count] = root_dict[count].astype(dtype)
        count += 1

    hvd.broadcast_parameters(tensor_dict, root_rank=root_rank)
    for i in range(count):
        if not same(tensor_dict[i].asnumpy(), root_dict[i].asnumpy()):
            print("broadcast", i, tensor_dict[i].dtype, tensor_dict[i].shape)
            print("broadcast_tensor", hvd.rank(), tensor_dict[i])
            print("root_tensor", hvd.rank(), root_dict[i])
            print("comparison", hvd.rank(), tensor_dict[i] == root_dict[i])
        assert same(tensor_dict[i].asnumpy(), root_dict[i].asnumpy()), \
            'hvd.broadcast_parameters produces incorrect broadcasted tensor'
def train(net, train_data, batch_size, ctx, logging, args):
    if isinstance(ctx, mx.Context):
        ctx = [ctx]
    net.initialize(mx.init.Xavier(), ctx=ctx)

    optimizer_params = {
        'learning_rate': args.lr,
        'wd': args.wd,
        'momentum': args.momentum
    }
    optimizer = 'nag'

    hvd.broadcast_parameters(net.collect_params(), root_rank=0)
    trainer = hvd.DistributedTrainer(net.collect_params(), optimizer,
                                     optimizer_params)

    metric = mx.metric.Accuracy()
    train_metric = mx.metric.Accuracy()
    loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

    logging.info('Training Begins')
    for epoch in range(args.epochs):
        tic = time.time()
        train_metric.reset()
        metric.reset()
        train_loss = 0
        num_batch = len(train_data)

        for i, batch in enumerate(train_data):
            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
            with ag.record():
                output = [net(X) for X in data]
                loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
            for l in loss:
                l.backward()
            trainer.step(batch_size)
            train_loss += sum([l.sum().asscalar() for l in loss])
            train_metric.update(label, output)
            if i % 10 == 0:
                name, acc = train_metric.get()
                logging.info('[Epoch %d Batch %d] Training: %s=%f' %
                             (epoch, i, name, acc))

        train_loss /= batch_size * num_batch
        if hvd.rank() == 0:
            elapsed = time.time() - tic
            speed = num_batch * batch_size * hvd.size() / elapsed
            logging.info('Epoch[%d]\tSpeed=%.2f samples/s\tTime cost=%f',
                         epoch, speed, elapsed)
def run_training(self):
    """Run training benchmarks.

    Returns:
        Numpy array containing batch times (string, numpy array).
    """
    # Create data iterator and resize it to total number of iterations (no matter what input data size is)
    train_data = DataIteratorFactory.get(
        (self.worker_batch, ) + self.model.input_shape,
        (self.worker_batch, ) + self.model.labels_shape,
        self.model.labels_range,
        self.args,
        kv_store=self.kv_store)

    # https://github.com/apache/incubator-mxnet/blob/master/example/distributed_training-horovod/resnet50_imagenet.py
    optimizer_params = {'multi_precision': True} if self.args.dtype == 'float16' else {}
    if self.is_horovod:
        optimizer_params['rescale_grad'] = 1.0 / self.worker_batch
    opt = mx.optimizer.create('sgd', **optimizer_params)
    if self.is_horovod:
        opt = hvd.DistributedOptimizer(opt)

    mod = mx.mod.Module(symbol=self.model.output, context=self.devices[0])
    mod.bind(data_shapes=train_data.provide_data,
             label_shapes=train_data.provide_label,
             for_training=True)
    mod.init_params(
        mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2))
    if self.is_horovod:
        arg_params, aux_params = mod.get_params()
        if arg_params:
            hvd.broadcast_parameters(arg_params, root_rank=0)
        if aux_params:
            hvd.broadcast_parameters(aux_params, root_rank=0)
        mod.set_params(arg_params=arg_params, aux_params=aux_params)

    batch_end_callback = BatchEndCallback(self.args.num_warmup_batches,
                                          self.args.num_batches)
    # print ("Starting benchmarks.")
    # TODO: In the current implementation, the number of epochs must always equal 1. It is the
    # iterator's responsibility to iterate the right number of batches - warm up plus benchmark batches.
    mod.fit(train_data,
            kvstore=self.kv_store,
            optimizer=opt,
            optimizer_params=optimizer_params,
            eval_metric=self.model.eval_metric,
            batch_end_callback=[batch_end_callback],
            begin_epoch=0,
            num_epoch=1)
    if self.is_horovod:
        start_time = timeit.default_timer()
        mx.ndarray.waitall()
        logging.info("(horovod) wait time for all ndarrays is %.5f seconds",
                     timeit.default_timer() - start_time)
    return batch_end_callback.batch_times
def _init_trainer(self):
    if self._cfg.horovod:
        hvd.broadcast_parameters(self.net.collect_params(), root_rank=0)
        self.trainer = hvd.DistributedTrainer(
            self.net.collect_params(), 'sgd',
            {'learning_rate': self._cfg.train.lr,
             'wd': self._cfg.train.wd,
             'momentum': self._cfg.train.momentum})
    else:
        self.trainer = gluon.Trainer(
            self.net.collect_params(), 'sgd',
            {'learning_rate': self._cfg.train.lr,
             'wd': self._cfg.train.wd,
             'momentum': self._cfg.train.momentum},
            update_on_kvstore=(False if self._cfg.ssd.amp else None))
    if self._cfg.ssd.amp:
        amp.init_trainer(self.trainer)
# Creating the module
mod = mx.mod.Module(
    symbol=sym,
    context=context,
    data_names=[k[0] for k in train_iter.provide_data_single],
    label_names=[k[0] for k in train_iter.provide_label_single],
    fixed_param_names=fixed_param_names)

shape_dict = dict(train_iter.provide_data_single + train_iter.provide_label_single)
sym_inst.infer_shape(shape_dict)
arg_params, aux_params = load_param(config.network.pretrained,
                                    config.network.pretrained_epoch,
                                    convert=True)
hvd.broadcast_parameters(arg_params, root_rank=0)
hvd.broadcast_parameters(aux_params, root_rank=0)

if config.TRAIN.ONLY_PROPOSAL:
    sym_inst.init_weight_rpn(config, arg_params, aux_params)
else:
    sym_inst.init_weight_rcnn(config, arg_params, aux_params)

# Creating the metrics
eval_metric = metric.RPNAccMetric()
cls_metric = metric.RPNLogLossMetric()
bbox_metric = metric.RPNL1LossMetric()
rceval_metric = metric.RCNNAccMetric(config)
rccls_metric = metric.RCNNLogLossMetric(config)
rcbbox_metric = metric.RCNNL1LossCRCNNMetric(config)
def train(net, train_data, val_data, eval_metric, ctx, args):
    """Training pipeline"""
    net.collect_params().reset_ctx(ctx)
    if args.horovod:
        hvd.broadcast_parameters(net.collect_params(), root_rank=0)
        trainer = hvd.DistributedTrainer(
            net.collect_params(), 'sgd',
            {'learning_rate': args.lr, 'wd': args.wd, 'momentum': args.momentum})
    else:
        trainer = gluon.Trainer(
            net.collect_params(), 'sgd',
            {'learning_rate': args.lr, 'wd': args.wd, 'momentum': args.momentum},
            update_on_kvstore=(False if args.amp else None))

    if args.amp:
        amp.init_trainer(trainer)

    # lr decay policy
    lr_decay = float(args.lr_decay)
    lr_steps = sorted([float(ls) for ls in args.lr_decay_epoch.split(',') if ls.strip()])

    mbox_loss = gcv.loss.SSDMultiBoxLoss()
    ce_metric = mx.metric.Loss('CrossEntropy')
    smoothl1_metric = mx.metric.Loss('SmoothL1')

    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    log_file_path = args.save_prefix + '_train.log'
    log_dir = os.path.dirname(log_file_path)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)
    fh = logging.FileHandler(log_file_path)
    logger.addHandler(fh)
    logger.info(args)
    logger.info('Start training from [Epoch {}]'.format(args.start_epoch))
    best_map = [0]

    for epoch in range(args.start_epoch, args.epochs):
        while lr_steps and epoch >= lr_steps[0]:
            new_lr = trainer.learning_rate * lr_decay
            lr_steps.pop(0)
            trainer.set_learning_rate(new_lr)
            logger.info("[Epoch {}] Set learning rate to {}".format(epoch, new_lr))
        ce_metric.reset()
        smoothl1_metric.reset()
        tic = time.time()
        btic = time.time()
        net.hybridize(static_alloc=True, static_shape=True)

        for i, batch in enumerate(train_data):
            if args.dali:
                # dali iterator returns a mxnet.io.DataBatch
                data = [d.data[0] for d in batch]
                box_targets = [d.label[0] for d in batch]
                cls_targets = [nd.cast(d.label[1], dtype='float32') for d in batch]
            else:
                data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
                cls_targets = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
                box_targets = gluon.utils.split_and_load(batch[2], ctx_list=ctx, batch_axis=0)

            with autograd.record():
                cls_preds = []
                box_preds = []
                for x in data:
                    cls_pred, box_pred, _ = net(x)
                    cls_preds.append(cls_pred)
                    box_preds.append(box_pred)
                sum_loss, cls_loss, box_loss = mbox_loss(
                    cls_preds, box_preds, cls_targets, box_targets)
                if args.amp:
                    with amp.scale_loss(sum_loss, trainer) as scaled_loss:
                        autograd.backward(scaled_loss)
                else:
                    autograd.backward(sum_loss)
            # since we have already normalized the loss, we don't want to normalize
            # by batch-size anymore
            trainer.step(1)

            if (not args.horovod or hvd.rank() == 0):
                local_batch_size = int(args.batch_size // (hvd.size() if args.horovod else 1))
                ce_metric.update(0, [l * local_batch_size for l in cls_loss])
                smoothl1_metric.update(0, [l * local_batch_size for l in box_loss])
                if args.log_interval and not (i + 1) % args.log_interval:
                    name1, loss1 = ce_metric.get()
                    name2, loss2 = smoothl1_metric.get()
                    logger.info('[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}={:.3f}, {}={:.3f}'.format(
                        epoch, i, args.batch_size / (time.time() - btic), name1, loss1, name2, loss2))
                btic = time.time()

        if (not args.horovod or hvd.rank() == 0):
            name1, loss1 = ce_metric.get()
            name2, loss2 = smoothl1_metric.get()
            logger.info('[Epoch {}] Training cost: {:.3f}, {}={:.3f}, {}={:.3f}'.format(
                epoch, (time.time() - tic), name1, loss1, name2, loss2))
            if (epoch % args.val_interval == 0) or (args.save_interval and epoch % args.save_interval == 0):
                # consider reduce the frequency of validation to save time
                map_name, mean_ap = validate(net, val_data, ctx, eval_metric)
                val_msg = '\n'.join(['{}={}'.format(k, v) for k, v in zip(map_name, mean_ap)])
                logger.info('[Epoch {}] Validation: \n{}'.format(epoch, val_msg))
                current_map = float(mean_ap[-1])
            else:
                current_map = 0.
            save_params(net, best_map, current_map, epoch, args.save_interval, args.save_prefix)
def train(args): _, num_parts, rank, local_rank, _, ctx_l = init_comm( args.comm_backend, args.gpus) if args.comm_backend == 'horovod': logging_config( args.save_dir, name=f'train_transformer_rank{rank}_local{local_rank}_{num_parts}', console=(rank == 0)) logging.info(args) else: logging_config(args.save_dir, name='train_transformer', console=True) logging.info(args) use_amp = args.fp16 if use_amp: from mxnet import amp src_tokenizer = create_tokenizer(args.src_tokenizer, args.src_subword_model_path, args.src_vocab_path) tgt_tokenizer = create_tokenizer(args.tgt_tokenizer, args.tgt_subword_model_path, args.tgt_vocab_path) base_tgt_tokenizer = MosesTokenizer(args.tgt_lang) src_vocab = src_tokenizer.vocab tgt_vocab = tgt_tokenizer.vocab train_src_data, train_tgt_data = load_dataset_with_cache( args.train_src_corpus, args.train_tgt_corpus, src_tokenizer, tgt_tokenizer, args.overwrite_cache, local_rank, max_src_length=args.max_src_length, max_tgt_length=args.max_tgt_length, pretokenized=not args.tokenize) dev_src_data, dev_tgt_data = load_dataset_with_cache( args.dev_src_corpus, args.dev_tgt_corpus, src_tokenizer, tgt_tokenizer, args.overwrite_cache, local_rank, pretokenized=not args.tokenize) tgt_detok_sentences = [] tgt_raw_sentences = [] with open(args.dev_tgt_corpus, 'r') as in_f: for line in in_f: tgt_detok_sentences.append( base_tgt_tokenizer.decode( tgt_tokenizer.decode(line.split()).split())) with open(args.dev_tgt_raw_corpus, 'r') as in_f: for line in in_f: tgt_raw_sentences.append(line.strip()) data_train = gluon.data.SimpleDataset([ (src_tokens, tgt_tokens, len(src_tokens), len(tgt_tokens), i) for i, (src_tokens, tgt_tokens) in enumerate(zip(train_src_data, train_tgt_data)) ]) val_samples = [ (src_tokens, tgt_tokens, len(src_tokens), len(tgt_tokens), i) for i, (src_tokens, tgt_tokens) in enumerate(zip(dev_src_data, dev_tgt_data)) ] if args.comm_backend == 'horovod': slice_begin = rank * (len(val_samples) // num_parts) slice_end = min((rank + 1) * (len(val_samples) // num_parts), len(val_samples)) data_val = gluon.data.SimpleDataset(val_samples[slice_begin:slice_end]) else: data_val = gluon.data.SimpleDataset(val_samples) # Construct the model + loss function if args.cfg.endswith('.yml'): cfg = TransformerModel.get_cfg().clone_merge(args.cfg) else: cfg = TransformerModel.get_cfg(args.cfg) cfg.defrost() cfg.MODEL.src_vocab_size = len(src_vocab) cfg.MODEL.tgt_vocab_size = len(tgt_vocab) cfg.MODEL.layout = 'TN' cfg.freeze() model = TransformerModel.from_cfg(cfg) model.initialize(mx.init.Xavier(magnitude=args.magnitude), ctx=ctx_l) model.hybridize() for v in model.collect_params().values(): if v.grad_req != 'null': v.grad_req = 'add' # Do not apply weight decay to all the LayerNorm and bias for _, v in model.collect_params('.*beta|.*gamma|.*bias').items(): v.wd_mult = 0.0 param_dict = deduplicate_param_dict(model.collect_params()) inference_model = TransformerInference(model=model) inference_model.hybridize() if local_rank == 0: logging.info(model) with open(os.path.join(args.save_dir, 'config.yml'), 'w') as cfg_f: cfg_f.write(cfg.dump()) label_smooth_loss = LabelSmoothCrossEntropyLoss( num_labels=len(tgt_vocab), alpha=args.label_smooth_alpha, from_logits=False) label_smooth_loss.hybridize() # Construct the beam search sampler scorer = BeamSearchScorer(alpha=args.lp_alpha, K=args.lp_k, from_logits=False) beam_search_sampler = BeamSearchSampler(beam_size=args.beam_size, decoder=inference_model, vocab_size=len(tgt_vocab), eos_id=tgt_vocab.eos_id, scorer=scorer, stochastic=False, 
max_length_a=args.max_length_a, max_length_b=args.max_length_b) logging.info(beam_search_sampler) if args.comm_backend == 'horovod': hvd.broadcast_parameters(param_dict, root_rank=0) # Construct the trainer if args.lr is None: base_lr = 2.0 / math.sqrt(args.num_units) / math.sqrt( args.warmup_steps) else: base_lr = args.lr lr_scheduler = InverseSquareRootScheduler( warmup_steps=args.warmup_steps, base_lr=base_lr, warmup_init_lr=args.warmup_init_lr) optimizer_params = { 'learning_rate': args.lr, 'beta1': 0.9, 'beta2': 0.997, 'epsilon': 1e-9, 'lr_scheduler': lr_scheduler, 'wd': args.wd } user_provided_ptimizer_params = json.loads(args.optimizer_params) optimizer_params.update(user_provided_ptimizer_params) if args.fp16: optimizer_params.update({'multi_precision': True}) if args.comm_backend == 'horovod': trainer = hvd.DistributedTrainer(param_dict, args.optimizer, optimizer_params) else: trainer = gluon.Trainer(param_dict, args.optimizer, optimizer_params, update_on_kvstore=False) # Load Data if args.sampler == 'BoundedBudgetSampler': train_batch_sampler = BoundedBudgetSampler( lengths=[(ele[2], ele[3]) for ele in data_train], max_num_tokens=args.max_num_tokens, max_num_sentences=args.max_num_sentences, shuffle=True, seed=args.seed) elif args.sampler == 'FixedBucketSampler': if args.comm_backend == 'horovod': raise NotImplementedError( 'FixedBucketSampler does not support horovod at present') if args.bucket_scheme == 'constant': bucket_scheme = ConstWidthBucket() elif args.bucket_scheme == 'linear': bucket_scheme = LinearWidthBucket() elif args.bucket_scheme == 'exp': bucket_scheme = ExpWidthBucket(bucket_len_step=1.2) else: raise NotImplementedError # TODO(sxjscience) Support auto-bucket-size tuning train_batch_sampler = FixedBucketSampler(lengths=[ (ele[2], ele[3]) for ele in data_train ], batch_size=args.batch_size, num_buckets=args.num_buckets, ratio=args.bucket_ratio, shuffle=True, use_average_length=True, bucket_scheme=bucket_scheme, seed=args.seed) else: raise NotImplementedError num_updates_per_epoch = int( math.ceil( len(train_batch_sampler) / (num_parts * len(ctx_l) * args.num_accumulated))) # Convert the batch sampler to multiple shards if num_parts > 1: train_batch_sampler = ShardedIterator(train_batch_sampler, num_parts=num_parts, part_index=rank, even_size=True, seed=args.seed + 1000 * rank) logging.info(train_batch_sampler) batchify_fn = bf.Tuple(bf.Pad(), bf.Pad(), bf.Stack(), bf.Stack(), bf.Stack()) train_data_loader = gluon.data.DataLoader( data_train, batch_sampler=train_batch_sampler, batchify_fn=batchify_fn, num_workers=0) val_data_loader = gluon.data.DataLoader(data_val, batch_size=args.val_batch_size, batchify_fn=batchify_fn, num_workers=0, shuffle=False) params = [p for p in param_dict.values() if p.grad_req != 'null'] model_averager = AverageSGDTracker(param_dict) log_start_time = time.time() num_params, num_fixed_params = None, None # TODO(sxjscience) Add a log metric class log_avg_loss_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l] # Maintain the denominator of the loss. 
log_avg_loss_denom_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l] log_wc_l = [mx.np.array(0, dtype=np.int64, ctx=ctx) for ctx in ctx_l] log_tgt_wc_l = [mx.np.array(0, dtype=np.int64, ctx=ctx) for ctx in ctx_l] log_avg_grad_norm = 0 log_iter_num = 0 if local_rank == 0: writer = SummaryWriter( logdir=os.path.join(args.save_dir, 'tensorboard')) if use_amp: amp.init_trainer(trainer) train_multi_data_loader = grouper(repeat(train_data_loader), len(ctx_l)) # when args.epochs < 0, the model will keep training if args.epochs < 0: if args.max_update > 0: total_train_iters = args.max_update if args.num_averages > 0: assert args.num_averages <= total_train_iters // args.save_iterval_update avg_start_iter = ( total_train_iters // args.save_iterval_update - args.num_averages) * args.save_iterval_update else: avg_start_iter = -1 else: total_train_iters = np.inf avg_start_iter = -1 else: total_train_iters = args.epochs * num_updates_per_epoch if args.num_averages > 0: assert args.num_averages <= args.epochs avg_start_iter = (args.epochs - args.num_average) * num_updates_per_epoch else: avg_start_iter = -1 # Here, we are manually setting up the scale to 1.0 because # in horovod, the scale can be the number of workers: # See the code here: https://github.com/horovod/horovod/blob/125115583b7029196e2ec530decd4209459d5479/horovod/mxnet/__init__.py#L141 # Since we will need to use the dynamic scaling in amp, we will manually call amp.unscale(). # A scale that is larger than 1.0 can be problematic in this case. trainer._scale = 1.0 if args.max_num_tokens > 0: const_scale = args.max_num_tokens else: const_scale = 100 train_start_time = time.time() for train_iter in range(total_train_iters): model.zero_grad() loss_denom_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l] for i in range(args.num_accumulated): loss_l = [] sample_data_l = next(train_multi_data_loader) for j, (sample_data, ctx) in enumerate(zip(sample_data_l, ctx_l)): src_token_ids, tgt_token_ids, src_valid_length,\ tgt_valid_length, sample_ids = sample_data src_token_ids = src_token_ids.as_in_ctx(ctx) tgt_token_ids = tgt_token_ids.as_in_ctx(ctx) src_valid_length = src_valid_length.as_in_ctx(ctx) tgt_valid_length = tgt_valid_length.as_in_ctx(ctx) src_wc, tgt_wc, bs = src_valid_length.sum(), \ tgt_valid_length.sum(), src_token_ids.shape[0] log_wc_l[j] += src_wc + tgt_wc log_tgt_wc_l[j] += tgt_wc token_count = (tgt_valid_length - 1).sum() loss_denom_l[j] += token_count / const_scale log_avg_loss_denom_l[j] += token_count / const_scale with mx.autograd.record(): if model.layout == 'NT': tgt_pred = model(src_token_ids, src_valid_length, tgt_token_ids[:, :-1], tgt_valid_length - 1) tgt_labels = tgt_token_ids[:, 1:] loss = label_smooth_loss(tgt_pred, tgt_labels) loss = mx.npx.sequence_mask( loss, sequence_length=tgt_valid_length - 1, use_sequence_length=True, axis=1) loss = loss.sum() / const_scale loss_l.append(loss) elif model.layout == 'TN': tgt_pred = model(src_token_ids.T, src_valid_length, tgt_token_ids.T[:-1, :], tgt_valid_length - 1) tgt_labels = tgt_token_ids.T[1:, :] loss = label_smooth_loss(tgt_pred, tgt_labels) loss = mx.npx.sequence_mask( loss, sequence_length=tgt_valid_length - 1, use_sequence_length=True, axis=0) loss = loss.sum() / const_scale loss_l.append(loss) log_avg_loss_l[j] += loss if use_amp: with mx.autograd.record(): with amp.scale_loss(loss_l, trainer) as amp_loss_l: for loss in amp_loss_l: loss.backward() else: with mx.autograd.record(): for loss in loss_l: loss.backward() # Print the total number of parameters if local_rank == 0 
and num_params is None: num_params, num_fixed_params = count_parameters(param_dict) logging.info( 'Total Number of Parameters (not-fixed/fixed): {}/{}'.format( num_params, num_fixed_params)) # All-Reduce the gradient trainer.allreduce_grads() if args.comm_backend == 'horovod': # All-Reduce the loss denominator assert len(loss_denom_l) == 1 loss_denom = hvd.allreduce(loss_denom_l[0], average=False).asnumpy() else: loss_denom = sum([ele.asnumpy() for ele in loss_denom_l]) if use_amp: # We need to first unscale the gradient and then perform allreduce. grad_scale = trainer.amp_loss_scale * loss_denom else: grad_scale = loss_denom if args.max_grad_norm is not None: total_norm, ratio, is_finite\ = clip_grad_global_norm(params, args.max_grad_norm * grad_scale) total_norm = total_norm / grad_scale else: total_norm = grad_global_norm(params) total_norm = total_norm / grad_scale log_avg_grad_norm += total_norm log_iter_num += 1 trainer.update(loss_denom, ignore_stale_grad=True) if avg_start_iter > 0 and train_iter >= avg_start_iter: model_averager.step() if ((train_iter + 1) % args.log_interval == 0 or train_iter + 1 == total_train_iters): if args.comm_backend == 'horovod': # Use allreduce to get the total number of tokens and loss log_wc = hvd.allreduce(log_wc_l[0], average=False).asnumpy() log_tgt_wc = hvd.allreduce(log_tgt_wc_l[0], average=False).asnumpy() log_avg_loss = hvd.allreduce(log_avg_loss_l[0] / log_avg_loss_denom_l[0], average=True) log_avg_loss = log_avg_loss.asnumpy() else: log_wc = sum([ele.asnumpy() for ele in log_wc_l]) log_tgt_wc = sum([ele.asnumpy() for ele in log_tgt_wc_l]) log_avg_loss =\ sum([log_avg_loss_l[i].asnumpy() / log_avg_loss_denom_l[i].asnumpy() for i in range(len(log_avg_loss_l))]) / len(log_avg_loss_l) log_avg_grad_norm = log_avg_grad_norm / log_iter_num log_end_time = time.time() wps = log_wc / (log_end_time - log_start_time) epoch_id = train_iter // num_updates_per_epoch logging.info( '[Epoch {} Iter {}/{}, Overall {}/{}] loss={:.4f}, ppl={:.4f}, ' 'throughput={:.2f}K wps, total wc={:.2f}K, wpb={:.2f}K,' ' LR={}, gnorm={:.4f}, ETA={:.2f}h'.format( epoch_id, train_iter % num_updates_per_epoch + 1, num_updates_per_epoch, train_iter + 1, total_train_iters, log_avg_loss, np.exp(log_avg_loss), wps / 1000, log_wc / 1000, log_tgt_wc / 1000 / log_iter_num, trainer.learning_rate, log_avg_grad_norm, (log_end_time - train_start_time) / (train_iter + 1) * (total_train_iters - train_iter - 1) / 3600)) if local_rank == 0: writer.add_scalar('throughput_wps', wps, train_iter) writer.add_scalar('train_loss', log_avg_loss, train_iter) writer.add_scalar('lr', trainer.learning_rate, train_iter) writer.add_scalar('grad_norm', log_avg_grad_norm, train_iter) # Reinitialize the log variables log_start_time = time.time() log_avg_loss_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l] log_avg_loss_denom_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l] log_avg_grad_norm = 0 log_iter_num = 0 log_wc_l = [ mx.np.array(0, dtype=np.int64, ctx=ctx) for ctx in ctx_l ] log_tgt_wc_l = [ mx.np.array(0, dtype=np.int64, ctx=ctx) for ctx in ctx_l ] if (args.max_update > 0 and (train_iter + 1) % args.save_interval_update == 0) \ or ((train_iter + 1) % num_updates_per_epoch == 0) \ or train_iter + 1 == total_train_iters: epoch_id = (train_iter + 1) // num_updates_per_epoch if local_rank == 0: if args.max_update <= 0: model.save_parameters(os.path.join( args.save_dir, 'epoch{}.params'.format(epoch_id)), deduplicate=True) else: model.save_parameters(os.path.join( args.save_dir, 
'iter{}.params'.format(train_iter + 1)), deduplicate=True) avg_val_loss, ntokens, pred_sentences, pred_lengths, sentence_ids\ = validation(model, val_data_loader, inference_model, beam_search_sampler, tgt_tokenizer, ctx_l) if args.comm_backend == 'horovod': flatten_pred_sentences = np.concatenate(pred_sentences, axis=0) all_val_loss = hvd.allgather( mx.np.array([avg_val_loss * ntokens], dtype=np.float32, ctx=ctx_l[0])) all_ntokens = hvd.allgather( mx.np.array([ntokens], dtype=np.int64, ctx=ctx_l[0])) flatten_pred_sentences = hvd.allgather( mx.np.array(flatten_pred_sentences, dtype=np.int32, ctx=ctx_l[0])) pred_lengths = hvd.allgather( mx.np.array(pred_lengths, dtype=np.int64, ctx=ctx_l[0])) sentence_ids = hvd.allgather( mx.np.array(sentence_ids, dtype=np.int64, ctx=ctx_l[0])) avg_val_loss = all_val_loss.asnumpy().sum( ) / all_ntokens.asnumpy().sum() flatten_pred_sentences = flatten_pred_sentences.asnumpy() pred_lengths = pred_lengths.asnumpy() sentence_ids = sentence_ids.asnumpy() pred_sentences = [None for _ in range(len(sentence_ids))] ptr = 0 assert sentence_ids.min() == 0 and sentence_ids.max( ) == len(sentence_ids) - 1 for sentence_id, length in zip(sentence_ids, pred_lengths): pred_sentences[sentence_id] = flatten_pred_sentences[ptr:( ptr + length)] ptr += length if local_rank == 0: # Perform detokenization pred_sentences_bpe_decode = [] pred_sentences_raw = [] for sentence in pred_sentences: bpe_decode_sentence = tgt_tokenizer.decode( sentence.tolist()) raw_sentence = base_tgt_tokenizer.decode( bpe_decode_sentence.split()) pred_sentences_bpe_decode.append(bpe_decode_sentence) pred_sentences_raw.append(raw_sentence) detok_sacrebleu_out = sacrebleu.corpus_bleu( sys_stream=pred_sentences_bpe_decode, ref_streams=[tgt_detok_sentences]) raw_sacrebleu_out = sacrebleu.corpus_bleu( sys_stream=pred_sentences_raw, ref_streams=[tgt_raw_sentences]) with open( os.path.join(args.save_dir, f'epoch{epoch_id}_dev_prediction.txt'), 'w') as of: for line in pred_sentences_raw: of.write(line + '\n') logging.info( '[Epoch {}][Iter {}/{}] validation loss/ppl={:.4f}/{:.4f}, ' 'SacreBlEU={}, Detok SacreBLUE={}'.format( epoch_id, train_iter, total_train_iters, avg_val_loss, np.exp(avg_val_loss), raw_sacrebleu_out.score, detok_sacrebleu_out.score)) writer.add_scalar('valid_loss', avg_val_loss, train_iter) writer.add_scalar('valid_bleu', raw_sacrebleu_out.score, train_iter) if args.num_averages > 0: model_averager.copy_back( param_dict) # TODO(sxjscience) Rewrite using update model.save_parameters(os.path.join(args.save_dir, 'average.params'), deduplicate=True)
model.hybridize()

# Create optimizer
optimizer_params = {'momentum': args.momentum,
                    'learning_rate': args.lr * hvd.size()}
opt = mx.optimizer.create('sgd', **optimizer_params)

# Initialize parameters
initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in",
                             magnitude=2)
model.initialize(initializer, ctx=context)

# Horovod: fetch and broadcast parameters
params = model.collect_params()
if params is not None:
    hvd.broadcast_parameters(params, root_rank=0)

# Horovod: create DistributedTrainer, a subclass of gluon.Trainer
trainer = hvd.DistributedTrainer(
    params, opt,
    gradient_predivide_factor=args.gradient_predivide_factor)

# Create loss function and train metric
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
metric = mx.metric.Accuracy()

# Train model
for epoch in range(args.epochs):
    tic = time.time()
    train_data.reset()
    metric.reset()
    for nbatch, batch in enumerate(train_data, start=1):
def fit(args, model, data_loader):
    """
    train a model
    args : argparse returns
    model : the neural network model
    data_loader : function that returns the train and val data iterators
    """
    start_time = time.time()

    # select gpu for horovod process
    if 'horovod' in args.kv_store:
        args.gpus = [args.gpus[hvd.local_rank()]]

    if args.amp:
        amp.init()

    if args.seed is not None:
        logging.info('Setting seeds to {}'.format(args.seed))
        random.seed(args.seed)
        np.random.seed(args.seed)
        mx.random.seed(args.seed)

    # kvstore
    if 'horovod' in args.kv_store:
        kv = None
        rank = hvd.rank()
        num_workers = hvd.size()
    else:
        kv = mx.kvstore.create(args.kv_store)
        rank = kv.rank
        num_workers = kv.num_workers

    if args.test_io:
        train, val = data_loader(args, kv)

        if args.test_io_mode == 'train':
            data_iter = train
        else:
            data_iter = val

        tic = time.time()
        for i, batch in enumerate(data_iter):
            if isinstance(batch, list):
                for b in batch:
                    for j in b.data:
                        j.wait_to_read()
            else:
                for j in batch.data:
                    j.wait_to_read()
            if (i + 1) % args.disp_batches == 0:
                logging.info('Batch [{}]\tSpeed: {:.2f} samples/sec'.format(
                    i, args.disp_batches * args.batch_size / (time.time() - tic)))
                tic = time.time()
        return

    if not load_model(args, model):
        # all initializers should be specified in the model definition.
        # if not, this will raise an error
        model.initialize(mx.init.Initializer())

    # devices for training
    devs = list(map(mx.gpu, args.gpus))
    model.collect_params().reset_ctx(devs)

    if args.mode == 'pred':
        logging.info('Inferring image {}'.format(args.data_pred))
        model_pred(args, model, data.load_image(args, args.data_pred, devs[0]))
        return

    # learning rate
    lr_scheduler = get_lr_scheduler(args)

    optimizer_params = {
        'learning_rate': 0,
        'wd': args.wd,
        'multi_precision': True,
    }

    # Only a limited number of optimizers have 'momentum' property
    has_momentum = {'sgd', 'dcasgd', 'nag', 'signum', 'lbsgd'}
    if args.optimizer in has_momentum:
        optimizer_params['momentum'] = args.mom

    # evaluation metrics
    if not args.no_metrics:
        eval_metrics = ['accuracy']
        eval_metrics.append(mx.metric.create('top_k_accuracy', top_k=5))
    else:
        eval_metrics = []

    train, val = data_loader(args, kv)
    train = BenchmarkingDataIter(train, args.benchmark_iters)
    if val is not None:
        val = BenchmarkingDataIter(val, args.benchmark_iters)

    if 'horovod' in args.kv_store:
        # Fetch and broadcast parameters
        params = model.collect_params()
        if params is not None:
            hvd.broadcast_parameters(params, root_rank=0)

    global_metrics = CompositeMeter()
    if args.mode in ['train_val', 'train']:
        global_metrics.register_metric('train.loss', MinMeter())
        global_metrics.register_metric('train.ips', AvgMeter())
    if args.mode in ['train_val', 'val']:
        global_metrics.register_metric('val.accuracy', MaxMeter())
        global_metrics.register_metric('val.top_k_accuracy_5', MaxMeter())
        global_metrics.register_metric('val.ips', AvgMeter())
        global_metrics.register_metric('val.latency_avg', AvgMeter())
    if args.mode in ['val']:
        global_metrics.register_metric('val.latency_50', PercentileMeter(50))
        global_metrics.register_metric('val.latency_90', PercentileMeter(90))
        global_metrics.register_metric('val.latency_95', PercentileMeter(95))
        global_metrics.register_metric('val.latency_99', PercentileMeter(99))
        global_metrics.register_metric('val.latency_100', PercentileMeter(100))

    # run
    if args.mode in ['train_val', 'train']:
        model_fit(
            args,
            model,
            train,
            begin_epoch=args.begin_epoch,
            num_epoch=args.num_epochs,
            run_epoch=args.run_epochs,
            eval_data=val,
            eval_metric=eval_metrics,
            global_metrics=global_metrics,
            kvstore=args.kv_store,
            kv=kv,
            optimizer=args.optimizer,
            optimizer_params=optimizer_params,
            lr_scheduler=lr_scheduler,
            model_prefix=os.path.join(args.workspace, args.model_prefix),
        )
    elif args.mode == 'val':
        for epoch in range(args.num_epochs):  # loop for benchmarking
            score, duration_stats, durations = model_score(
                args, model, val, eval_metrics, args.kv_store)
            dllogger_data = dict(
                starmap(lambda key, val: ('val.{}'.format(key), val), zip(*score)))
            dllogger_data.update(
                starmap(lambda key, val: ('val.{}'.format(key), val),
                        duration_stats.items()))
            global_metrics.update_dict(dllogger_data)
            for percentile in [50, 90, 95, 99, 100]:
                metric_name = 'val.latency_{}'.format(percentile)
                dllogger_data[metric_name] = np.percentile(durations, percentile)
                global_metrics.update_metric(metric_name, durations)
            dllogger.log(step=(epoch, ), data=dllogger_data)
    else:
        raise ValueError('Wrong mode')

    mx.nd.waitall()
    dllogger.log(tuple(), data=global_metrics.get())
def train(data_train, data_eval, model): """Training function.""" # backend specific implementation param_dict = model.bert.collect_params() if backend == 'horovod': hvd.broadcast_parameters(param_dict, root_rank=0) mlm_metric = nlp.metric.MaskedAccuracy() nsp_metric = nlp.metric.MaskedAccuracy() mlm_metric.reset() nsp_metric.reset() logging.debug('Creating distributed trainer...') lr = args.lr optim_params = {'learning_rate': lr, 'epsilon': 1e-6, 'wd': 0.01} if args.dtype == 'float16': optim_params['multi_precision'] = True if args.optimizer == 'lamb': optim_params['bias_correction'] = True dynamic_loss_scale = args.dtype == 'float16' if dynamic_loss_scale: loss_scale_param = {'scale_window': 2000 / num_workers, 'init_scale': 1} else: loss_scale_param = None # backend specific implementation if backend == 'horovod': trainer = hvd.DistributedTrainer(param_dict, args.optimizer, optim_params) elif backend == 'byteps': trainer = bps.DistributedTrainer(param_dict, args.optimizer, optim_params) else: trainer = mx.gluon.Trainer(param_dict, args.optimizer, optim_params, update_on_kvstore=False) fp16_trainer = FP16Trainer(trainer, dynamic_loss_scale=dynamic_loss_scale, loss_scaler_params=loss_scale_param) if args.start_step: state_path = os.path.join(args.ckpt_dir, '%07d.states.%02d'%(args.start_step, local_rank)) logging.info('Loading trainer state from %s', state_path) nlp.utils.load_states(trainer, state_path) accumulate = args.accumulate num_train_steps = args.num_steps warmup_ratio = args.warmup_ratio num_warmup_steps = int(num_train_steps * warmup_ratio) params = [p for p in param_dict.values() if p.grad_req != 'null'] # Do not apply weight decay on LayerNorm and bias terms for _, v in model.collect_params('.*beta|.*gamma|.*bias').items(): v.wd_mult = 0.0 if accumulate > 1: for p in params: p.grad_req = 'add' train_begin_time = time.time() begin_time = time.time() running_mlm_loss, running_nsp_loss = 0, 0 local_mlm_loss, local_num_masks = 0, mx.nd.array([0], ctx=ctxs[0]) running_num_tks = 0 batch_num = 0 step_num = args.start_step logging.debug('Training started') logging.info('Generating the first batch of data, which may take a few minutes ...') # create dummy data loader if needed parallel_model = DataParallelBERT(model, trainer=fp16_trainer) num_ctxes = len(ctxs) parallel = nlp.utils.Parallel(num_ctxes if num_ctxes > 1 else 0, parallel_model) if backend == 'byteps': bps.byteps_declare_tensor("local_num_masks") bps.byteps_push_pull(local_num_masks, is_average=False, name="local_num_masks", priority=0) logging.debug('Broadcast local_num_masks tensor') next_batch = next(iter(get_dummy_dataloader(batch_size, args.max_seq_length, args.max_predictions_per_seq))) data_list = list(split_and_load(next_batch, ctxs)) parallel.put(data_list[0]) parallel.get() trainer._init_params() while step_num < num_train_steps: data_train_iter = iter(data_train) end_of_batch = False next_data_batch = next(data_train_iter) while not end_of_batch: data_batch = next_data_batch if step_num >= num_train_steps: break if batch_num % accumulate == 0: step_num += 1 # if accumulate > 1, grad_req is set to 'add', and zero_grad is required if accumulate > 1: param_dict.zero_grad() # update learning rate if step_num <= num_warmup_steps: new_lr = lr * step_num / num_warmup_steps else: offset = lr * step_num / num_train_steps new_lr = lr - offset trainer.set_learning_rate(new_lr) if args.profile: profile(step_num, 10, 14, profile_name=args.profile + str(rank)) if early_stop and step_num == 10: mx.nd.waitall() exit() # load 
data data_list = list(split_and_load(data_batch, ctxs)) ns_label_list, ns_pred_list = [], [] mask_label_list, mask_pred_list, mask_weight_list = [], [], [] with mx.autograd.record(): num_data = len(data_list) for i in range(num_data): parallel.put(data_list[i]) for _ in range(num_data): (next_sentence_label, classified, masked_id, decoded, masked_weight, ls1, ls2, valid_length, num_masks) = parallel.get() ns_label_list.append(next_sentence_label) ns_pred_list.append(classified) mask_label_list.append(masked_id) mask_pred_list.append(decoded) mask_weight_list.append(masked_weight) local_num_masks += num_masks local_mlm_loss += ls1 running_num_tks += valid_length.sum() # pre fetch next batch try: next_data_batch = next(data_train_iter) except StopIteration: end_of_batch = True # update if (batch_num + 1) % accumulate == 0: running_mlm_loss += local_mlm_loss / local_num_masks if backend == 'horovod': hvd.allreduce_(local_num_masks, average=False, name='local_num_masks') elif backend == 'byteps': bps.byteps_push_pull(local_num_masks, is_average=False, name="local_num_masks", priority=0) # because byteps implicitly set scale /= num_workers fp16_trainer.step(local_num_masks * num_workers, max_norm=local_num_masks, num_ctxs=len(ctxs) * num_workers) local_num_masks, local_mlm_loss = 0, 0 # update metrics if args.no_compute_acc: for mask_pred_i in mask_pred_list: mask_pred_i.wait_to_read() else: nsp_metric.update(ns_label_list, ns_pred_list) mlm_metric.update(mask_label_list, mask_pred_list, mask_weight_list) # logging if (step_num + 1) % (args.log_interval) == 0 and (batch_num + 1) % accumulate == 0: if args.no_compute_acc: log_noacc(begin_time, running_num_tks, running_mlm_loss, 0, step_num, trainer, args.log_interval) else: log(begin_time, running_num_tks, running_mlm_loss / accumulate, running_nsp_loss / accumulate, step_num, mlm_metric, nsp_metric, trainer, args.log_interval) mlm_metric.reset_local() nsp_metric.reset_local() begin_time = time.time() running_mlm_loss = running_nsp_loss = running_num_tks = 0 # saving checkpoints if (step_num + 1) % args.ckpt_interval == 0 and (batch_num + 1) % accumulate == 0: # if is_master_node: # save_states(step_num, trainer, args.ckpt_dir, local_rank) # if local_rank == 0: # save_parameters(step_num, model.bert, args.ckpt_dir) if (step_num + 1) % args.eval_interval == 0 and data_eval: # eval data is always based on a fixed npz file. dataset_eval = get_pretrain_data_npz(data_eval, batch_size_eval, 1, False, 1, vocab) evaluate(dataset_eval, model, ctxs, args.log_interval, args.dtype, rank, num_workers) batch_num += 1 # if is_master_node: # save_states(step_num, trainer, args.ckpt_dir, local_rank) # if local_rank == 0: # save_parameters(step_num, model, args.ckpt_dir) mx.nd.waitall() train_end_time = time.time() logging.info('Train cost={:.1f}s'.format(train_end_time - train_begin_time))
def train_gluon():
    def evaluate(epoch):
        if not args.use_rec:
            return

        val_data.reset()
        acc_top1 = mx.metric.Accuracy()
        acc_top5 = mx.metric.TopKAccuracy(5)
        for _, batch in enumerate(val_data):
            data, label = get_data_label(batch, context)
            output = net(data.astype(args.dtype, copy=False))
            acc_top1.update([label], [output])
            acc_top5.update([label], [output])

        top1_name, top1_acc = acc_top1.get()
        top5_name, top5_acc = acc_top5.get()
        logging.info('Epoch[%d] Rank[%d]\tValidation-%s=%f\tValidation-%s=%f',
                     epoch, rank, top1_name, top1_acc, top5_name, top5_acc)

    # Hybridize and initialize model
    net.hybridize()
    net.initialize(initializer, ctx=context)

    # Horovod: fetch and broadcast parameters
    params = net.collect_params()
    if params is not None:
        hvd.broadcast_parameters(params, root_rank=0)

    # Create optimizer
    optimizer_params = {'wd': args.wd,
                        'momentum': args.momentum,
                        'lr_scheduler': lr_sched}
    if args.dtype == 'float16':
        optimizer_params['multi_precision'] = True
    opt = mx.optimizer.create('sgd', **optimizer_params)

    # Horovod: create DistributedTrainer, a subclass of gluon.Trainer
    trainer = hvd.DistributedTrainer(
        params, opt,
        gradient_predivide_factor=args.gradient_predivide_factor)

    # Create loss function and train metric
    loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
    metric = mx.metric.Accuracy()

    # Train model
    for epoch in range(args.num_epochs):
        tic = time.time()
        if args.use_rec:
            train_data.reset()
        metric.reset()

        btic = time.time()
        for nbatch, batch in enumerate(train_data, start=1):
            data, label = get_data_label(batch, context)
            with autograd.record():
                output = net(data.astype(args.dtype, copy=False))
                loss = loss_fn(output, label)
            loss.backward()
            trainer.step(batch_size)
            metric.update([label], [output])

            if args.log_interval and nbatch % args.log_interval == 0:
                name, acc = metric.get()
                logging.info('Epoch[%d] Rank[%d] Batch[%d]\t%s=%f\tlr=%f',
                             epoch, rank, nbatch, name, acc, trainer.learning_rate)
                if rank == 0:
                    batch_speed = num_workers * batch_size * args.log_interval / (time.time() - btic)
                    logging.info('Epoch[%d] Batch[%d]\tSpeed: %.2f samples/sec',
                                 epoch, nbatch, batch_speed)
                btic = time.time()

        # Report metrics
        elapsed = time.time() - tic
        _, acc = metric.get()
        logging.info('Epoch[%d] Rank[%d] Batch[%d]\tTime cost=%.2f\tTrain-accuracy=%f',
                     epoch, rank, nbatch, elapsed, acc)
        if rank == 0:
            epoch_speed = num_workers * batch_size * nbatch / elapsed
            logging.info('Epoch[%d]\tSpeed: %.2f samples/sec', epoch, epoch_speed)

        # Evaluate performance
        if args.eval_frequency and (epoch + 1) % args.eval_frequency == 0:
            evaluate(epoch)

        # Save model
        if args.save_frequency and (epoch + 1) % args.save_frequency == 0:
            net.export('%s-%d' % (args.model, rank), epoch=epoch)

    # Evaluate performance at the end of training
    evaluate(epoch)
                    'rescale_grad': 1.0 / args.batch_size}
opt = mx.optimizer.create('sgd', **optimizer_params)

# Horovod: wrap optimizer with DistributedOptimizer
opt = hvd.DistributedOptimizer(opt)

initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in",
                             magnitude=2)
model.bind(data_shapes=train_iter.provide_data,
           label_shapes=train_iter.provide_label)
model.init_params(initializer)

# Horovod: fetch and broadcast parameters
(arg_params, aux_params) = model.get_params()
if arg_params is not None:
    hvd.broadcast_parameters(arg_params, root_rank=0)
if aux_params is not None:
    hvd.broadcast_parameters(aux_params, root_rank=0)
model.set_params(arg_params=arg_params, aux_params=aux_params)

model.fit(train_iter,  # train data
          kvstore=None,  # no kvstore
          eval_data=val_iter,  # validation data
          optimizer=opt,  # use SGD to train
          eval_metric='acc',  # report accuracy during training
          batch_end_callback=mx.callback.Speedometer(args.batch_size),
          num_epoch=args.epochs)  # train for at most 10 dataset passes

# Step 5: evaluate model accuracy
acc = mx.metric.Accuracy()
model.score(val_iter, acc)
def train(ctx): if isinstance(ctx, mx.Context): ctx = [ctx] if opt.resume_params == '': net.initialize(mx.init.MSRAPrelu(), ctx=ctx) if opt.no_wd: for k, v in net.collect_params('.*beta|.*gamma|.*bias').items(): v.wd_mult = 0.0 hvd.broadcast_parameters(net.collect_params(), root_rank=0) # trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params) # trainer = hvd.DistributedTrainer( # net.collect_params(), # optimizer, # optimizer_params) if opt.trainer == 'sgd': trainer = SGDTrainer( net.collect_params(), optimizer, optimizer_params) elif opt.trainer == 'efsgd': trainer = EFSGDTrainerV1( net.collect_params(), 'EFSGDV1', optimizer_params, input_sparse_ratio=1./opt.input_sparse_1, output_sparse_ratio=1./opt.output_sparse_1, layer_sparse_ratio=1./opt.layer_sparse_1) elif opt.trainer == 'qsparselocalsgd': trainer = QSparseLocalSGDTrainerV1( net.collect_params(), optimizer, optimizer_params, input_sparse_ratio=1./opt.input_sparse_1, output_sparse_ratio=1./opt.output_sparse_1, layer_sparse_ratio=1./opt.layer_sparse_1, local_sgd_interval=opt.local_sgd_interval) elif opt.trainer == 'ersgd': trainer = ERSGDTrainerV2( net.collect_params(), optimizer, optimizer_params, input_sparse_ratio=1./opt.input_sparse_1, output_sparse_ratio=1./opt.output_sparse_1, layer_sparse_ratio=1./opt.layer_sparse_1) elif opt.trainer == 'partiallocalsgd': trainer = PartialLocalSGDTrainerV1( net.collect_params(), optimizer, optimizer_params, input_sparse_ratio=1./opt.input_sparse_1, output_sparse_ratio=1./opt.output_sparse_1, layer_sparse_ratio=1./opt.layer_sparse_1, local_sgd_interval=opt.local_sgd_interval) elif opt.trainer == 'ersgd2': trainer = ERSGD2TrainerV2( net.collect_params(), optimizer, optimizer_params, input_sparse_ratio_1=1./opt.input_sparse_1, output_sparse_ratio_1=1./opt.output_sparse_1, layer_sparse_ratio_1=1./opt.layer_sparse_1, input_sparse_ratio_2=1./opt.input_sparse_2, output_sparse_ratio_2=1./opt.output_sparse_2, layer_sparse_ratio_2=1./opt.layer_sparse_2, local_sgd_interval=opt.local_sgd_interval) else: trainer = SGDTrainer( net.collect_params(), optimizer, optimizer_params) if opt.resume_states != '': trainer.load_states(opt.resume_states) if opt.label_smoothing or opt.mixup: sparse_label_loss = False else: sparse_label_loss = True if distillation: L = gcv.loss.DistillationSoftmaxCrossEntropyLoss(temperature=opt.temperature, hard_weight=opt.hard_weight, sparse_label=sparse_label_loss) else: L = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=sparse_label_loss) best_val_score = 1 for epoch in range(opt.resume_epoch, opt.num_epochs): tic = time.time() if opt.use_rec: train_data.reset() # train_metric.reset() train_loss = 0 btic = time.time() # test speed if opt.test_speed > 0: n_repeats = opt.test_speed elif opt.test_speed == 0: n_repeats = 1 else: n_repeats = 0 for i, batch in enumerate(train_data): # test speed if n_repeats == 0 and not (i+1)%opt.log_interval: print('[Epoch %d] # batch: %d'%(epoch, i)) continue data, label = batch_fn(batch, ctx) for j in range(n_repeats): if opt.mixup: lam = np.random.beta(opt.mixup_alpha, opt.mixup_alpha) if epoch >= opt.num_epochs - opt.mixup_off_epoch: lam = 1 data = [lam*X + (1-lam)*X[::-1] for X in data] if opt.label_smoothing: eta = 0.1 else: eta = 0.0 label = mixup_transform(label, classes, lam, eta) elif opt.label_smoothing: hard_label = label label = smooth(label, classes) if distillation: teacher_prob = [nd.softmax(teacher(X.astype(opt.dtype, copy=False)) / opt.temperature) \ for X in data] with ag.record(): outputs = 
[net(X.astype(opt.dtype, copy=False)) for X in data] if distillation: loss = [L(yhat.astype('float32', copy=False), y.astype('float32', copy=False), p.astype('float32', copy=False)) for yhat, y, p in zip(outputs, label, teacher_prob)] else: loss = [L(yhat, y.astype(opt.dtype, copy=False)) for yhat, y in zip(outputs, label)] for l in loss: l.backward() trainer.step(batch_size) # if opt.mixup: # output_softmax = [nd.SoftmaxActivation(out.astype('float32', copy=False)) \ # for out in outputs] # train_metric.update(label, output_softmax) # else: # if opt.label_smoothing: # train_metric.update(hard_label, outputs) # else: # train_metric.update(label, outputs) step_loss = sum([l.sum().asscalar() for l in loss]) train_loss += step_loss if opt.log_interval and not (i+j+1)%opt.log_interval: # train_metric_name, train_metric_score = train_metric.get() if hvd.rank() == 0: # logger.info('Epoch[%d] Batch[%d] Speed: %f samples/sec %s=%f lr=%f comm=%f'%( # epoch, i, batch_size*hvd.size()*opt.log_interval/(time.time()-btic), # train_metric_name, train_metric_score, trainer.learning_rate, trainer._comm_counter/1e6)) # print('Epoch[%d] Batch[%d] Speed: %f samples/sec %s=%f lr=%f comm=%f'%( # epoch, i, batch_size*hvd.size()*opt.log_interval/(time.time()-btic), # train_metric_name, train_metric_score, trainer.learning_rate, trainer._comm_counter/1e6)) print('Epoch[%d] Batch[%d] Speed: %f samples/sec %s=%f lr=%f comm=%f'%( epoch, i, batch_size*hvd.size()*opt.log_interval/(time.time()-btic), 'loss', step_loss/batch_size, trainer.learning_rate, trainer._comm_counter/1e6)) btic = time.time() mx.nd.waitall() toc = time.time() if n_repeats == 0: allreduce_array_nd = mx.nd.array([i]) hvd.allreduce_(allreduce_array_nd, name='allreduce_array', average=True) mx.nd.waitall() print('[Epoch %d] # total batch: %d'%(epoch, i)) continue train_metric_name, train_metric_score = train_metric.get() throughput = int(batch_size * i /(toc - tic) * hvd.size()) train_loss /= (batch_size * i) if opt.trainer == 'ersgd' or opt.trainer == 'qsparselocalsgd' or opt.trainer == 'ersgd2' or opt.trainer == 'partiallocalsgd': allreduce_for_val = True else: allreduce_for_val = False if allreduce_for_val: trainer.pre_test() # err_train_tic = time.time() # err_top1_train, err_top5_train = test(ctx, train_data, val=False) err_val_tic = time.time() err_top1_val, err_top5_val = test(ctx, val_data, val=True) err_val_toc = time.time() if allreduce_for_val: trainer.post_test() mx.nd.waitall() # allreduce the results allreduce_array_nd = mx.nd.array([train_loss, err_top1_val, err_top5_val]) hvd.allreduce_(allreduce_array_nd, name='allreduce_array', average=True) allreduce_array_np = allreduce_array_nd.asnumpy() train_loss = np.asscalar(allreduce_array_np[0]) err_top1_val = np.asscalar(allreduce_array_np[1]) err_top5_val = np.asscalar(allreduce_array_np[2]) if hvd.rank() == 0: # logger.info('[Epoch %d] training: %s=%f'%(epoch, train_metric_name, train_metric_score)) logger.info('[Epoch %d] training: loss=%f'%(epoch, train_loss)) logger.info('[Epoch %d] speed: %d samples/sec training-time: %f comm: %f'%(epoch, throughput, toc-tic, trainer._comm_counter/1e6)) logger.info('[Epoch %d] validation: err-top1=%f err-top5=%f err-time=%f'%(epoch, err_top1_val, err_top5_val, err_val_toc - err_val_tic)) trainer._comm_counter = 0 if err_top1_val < best_val_score: best_val_score = err_top1_val # if hvd.local_rank() == 0: # net.save_parameters('%s/%.4f-imagenet-%s-%d-best.params'%(save_dir, best_val_score, model_name, epoch)) # 
trainer.save_states('%s/%.4f-imagenet-%s-%d-best.states'%(save_dir, best_val_score, model_name, epoch)) if save_frequency and save_dir and (epoch + 1) % save_frequency == 0: if hvd.local_rank() == 0: net.save_parameters('%s/imagenet-%s-%d.params'%(save_dir, model_name, epoch)) trainer.save_states('%s/imagenet-%s-%d.states'%(save_dir, model_name, epoch))
def train(net, train_data, val_data, eval_metric, ctx, args): """Training pipeline""" net.collect_params().reset_ctx(ctx) if args.no_wd: for k, v in net.collect_params('.*beta|.*gamma|.*bias').items(): v.wd_mult = 0.0 if args.label_smooth: net._target_generator._label_smooth = True if args.lr_decay_period > 0: lr_decay_epoch = list( range(args.lr_decay_period, args.epochs, args.lr_decay_period)) else: lr_decay_epoch = [int(i) for i in args.lr_decay_epoch.split(',')] lr_decay_epoch = [e - args.warmup_epochs for e in lr_decay_epoch] num_batches = args.num_samples // args.batch_size lr_scheduler = LRSequential([ LRScheduler('linear', base_lr=0, target_lr=args.lr, nepochs=args.warmup_epochs, iters_per_epoch=num_batches), LRScheduler(args.lr_mode, base_lr=args.lr, nepochs=args.epochs - args.warmup_epochs, iters_per_epoch=num_batches, step_epoch=lr_decay_epoch, step_factor=args.lr_decay, power=2), ]) if args.horovod: hvd.broadcast_parameters(net.collect_params(), root_rank=0) trainer = hvd.DistributedTrainer(net.collect_params(), 'sgd', { 'wd': args.wd, 'momentum': args.momentum, 'lr_scheduler': lr_scheduler }) else: trainer = gluon.Trainer( net.collect_params(), 'sgd', { 'wd': args.wd, 'momentum': args.momentum, 'lr_scheduler': lr_scheduler }, kvstore='local', update_on_kvstore=(False if args.amp else None)) if args.amp: amp.init_trainer(trainer) # targets sigmoid_ce = gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False) l1_loss = gluon.loss.L1Loss() # metrics obj_metrics = mx.metric.Loss('ObjLoss') center_metrics = mx.metric.Loss('BoxCenterLoss') scale_metrics = mx.metric.Loss('BoxScaleLoss') cls_metrics = mx.metric.Loss('ClassLoss') # set up logger logging.basicConfig() logger = logging.getLogger() logger.setLevel(logging.INFO) log_file_path = args.save_prefix + '_train.log' log_dir = os.path.dirname(log_file_path) if log_dir and not os.path.exists(log_dir): os.makedirs(log_dir) fh = logging.FileHandler(log_file_path) logger.addHandler(fh) logger.info(args) logger.info('Start training from [Epoch {}]'.format(args.start_epoch)) best_map = [0] for epoch in range(args.start_epoch, args.epochs): if args.mixup: # TODO(zhreshold): more elegant way to control mixup during runtime try: train_data._dataset.set_mixup(np.random.beta, 1.5, 1.5) except AttributeError: train_data._dataset._data.set_mixup(np.random.beta, 1.5, 1.5) if epoch >= args.epochs - args.no_mixup_epochs: try: train_data._dataset.set_mixup(None) except AttributeError: train_data._dataset._data.set_mixup(None) tic = time.time() btic = time.time() mx.nd.waitall() net.hybridize() for i, batch in enumerate(train_data): data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0) # objectness, center_targets, scale_targets, weights, class_targets fixed_targets = [ gluon.utils.split_and_load(batch[it], ctx_list=ctx, batch_axis=0) for it in range(1, 6) ] gt_boxes = gluon.utils.split_and_load(batch[6], ctx_list=ctx, batch_axis=0) sum_losses = [] obj_losses = [] center_losses = [] scale_losses = [] cls_losses = [] with autograd.record(): for ix, x in enumerate(data): obj_loss, center_loss, scale_loss, cls_loss = net( x, gt_boxes[ix], *[ft[ix] for ft in fixed_targets]) sum_losses.append(obj_loss + center_loss + scale_loss + cls_loss) obj_losses.append(obj_loss) center_losses.append(center_loss) scale_losses.append(scale_loss) cls_losses.append(cls_loss) if args.amp: with amp.scale_loss(sum_losses, trainer) as scaled_loss: autograd.backward(scaled_loss) else: autograd.backward(sum_losses) trainer.step(batch_size) if (not 
args.horovod or hvd.rank() == 0): obj_metrics.update(0, obj_losses) center_metrics.update(0, center_losses) scale_metrics.update(0, scale_losses) cls_metrics.update(0, cls_losses) if args.log_interval and not (i + 1) % args.log_interval: name1, loss1 = obj_metrics.get() name2, loss2 = center_metrics.get() name3, loss3 = scale_metrics.get() name4, loss4 = cls_metrics.get() logger.info( '[Epoch {}][Batch {}], LR: {:.2E}, Speed: {:.3f} samples/sec, {}={:.3f}, {}={:.3f}, {}={:.3f}, {}={:.3f}' .format(epoch, i, trainer.learning_rate, args.batch_size / (time.time() - btic), name1, loss1, name2, loss2, name3, loss3, name4, loss4)) btic = time.time() if (not args.horovod or hvd.rank() == 0): name1, loss1 = obj_metrics.get() name2, loss2 = center_metrics.get() name3, loss3 = scale_metrics.get() name4, loss4 = cls_metrics.get() logger.info( '[Epoch {}] Training cost: {:.3f}, {}={:.3f}, {}={:.3f}, {}={:.3f}, {}={:.3f}' .format(epoch, (time.time() - tic), name1, loss1, name2, loss2, name3, loss3, name4, loss4)) if not (epoch + 1) % args.val_interval: # consider reduce the frequency of validation to save time map_name, mean_ap = validate(net, val_data, ctx, eval_metric) val_msg = '\n'.join( ['{}={}'.format(k, v) for k, v in zip(map_name, mean_ap)]) logger.info('[Epoch {}] Validation: \n{}'.format( epoch, val_msg)) current_map = float(mean_ap[-1]) else: current_map = 0. save_params(net, best_map, current_map, epoch, args.save_interval, args.save_prefix)
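# --- Hedged sketch (not from the original scripts): the trainer-selection pattern
# shared by the pipelines above -- broadcast parameters from rank 0, then pick
# hvd.DistributedTrainer under Horovod or a local gluon.Trainer otherwise.
# `use_horovod` stands in for the scripts' args.horovod flag; hvd.init() is assumed.
from mxnet import gluon
import horovod.mxnet as hvd

def make_sgd_trainer(net, lr_scheduler, wd, momentum, use_horovod):
    opt_params = {'wd': wd, 'momentum': momentum, 'lr_scheduler': lr_scheduler}
    if use_horovod:
        # all workers start from rank-0 weights; step() then allreduces gradients
        hvd.broadcast_parameters(net.collect_params(), root_rank=0)
        return hvd.DistributedTrainer(net.collect_params(), 'sgd', opt_params)
    # single-process fallback, matching the scripts' kvstore='local' branch
    return gluon.Trainer(net.collect_params(), 'sgd', opt_params, kvstore='local')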
return rpn_loss1_metric, rpn_loss2_metric, rcnn_loss1_metric, rcnn_loss2_metric, \ rpn_acc_metric, rpn_l1_loss_metric, rcnn_acc_metric, rcnn_l1_loss_metric def train(net, train_data, val_data, eval_metric, batch_size, ctx, args): """Training pipeline""" kv = mx.kvstore.create(args.kv_store) net.collect_params().setattr('grad_req', 'null') net.collect_train_params().setattr('grad_req', 'write') optimizer_params = {'learning_rate': args.lr, 'wd': args.wd, 'momentum': args.momentum} if args.horovod: hvd.broadcast_parameters(net.collect_params(), root_rank=0) trainer = hvd.DistributedTrainer( net.collect_train_params(), # fix batchnorm, fix first stage, etc... 'sgd', optimizer_params) else: trainer = gluon.Trainer( net.collect_train_params(), # fix batchnorm, fix first stage, etc... 'sgd', optimizer_params, update_on_kvstore=(False if args.amp else None), kvstore=kv) if args.amp: amp.init_trainer(trainer)
def train(args): store, num_parts, rank, local_rank, is_master_node, ctx_l = init_comm( args.comm_backend, args.gpus) src_tokenizer = create_tokenizer(args.src_tokenizer, args.src_subword_model_path, args.src_vocab_path) tgt_tokenizer = create_tokenizer(args.tgt_tokenizer, args.tgt_subword_model_path, args.tgt_vocab_path) src_vocab = src_tokenizer.vocab tgt_vocab = tgt_tokenizer.vocab train_src_data, train_tgt_data = load_dataset_with_cache( args.train_src_corpus, args.train_tgt_corpus, src_tokenizer, tgt_tokenizer, args.overwrite_cache) dev_src_data, dev_tgt_data = load_dataset_with_cache( args.dev_src_corpus, args.dev_tgt_corpus, src_tokenizer, tgt_tokenizer, args.overwrite_cache) data_train = gluon.data.SimpleDataset([ (src_tokens, tgt_tokens, len(src_tokens), len(tgt_tokens), i) for i, (src_tokens, tgt_tokens) in enumerate(zip(train_src_data, train_tgt_data)) ]) data_val = gluon.data.SimpleDataset([ (src_tokens, tgt_tokens, len(src_tokens), len(tgt_tokens), i) for i, (src_tokens, tgt_tokens) in enumerate(zip(dev_src_data, dev_tgt_data)) ]) # Construct the model + loss function if args.cfg.endswith('.yml'): cfg = TransformerModel.get_cfg().clone_merge(args.cfg) else: cfg = TransformerModel.get_cfg(args.cfg) cfg.defrost() cfg.MODEL.src_vocab_size = len(src_vocab) cfg.MODEL.tgt_vocab_size = len(tgt_vocab) if args.fp16: raise NotImplementedError # cfg.MODEL.dtype = 'float16' cfg.freeze() model = TransformerModel.from_cfg(cfg) model.initialize(mx.init.Xavier(magnitude=args.magnitude), ctx=ctx_l) model.hybridize() if local_rank == 0: logging.info(model) with open(os.path.join(args.save_dir, 'config.yml'), 'w') as cfg_f: cfg_f.write(cfg.dump()) label_smooth_loss = LabelSmoothCrossEntropyLoss( num_labels=len(tgt_vocab), alpha=args.label_smooth_alpha, from_logits=False) label_smooth_loss.hybridize() rescale_loss = 100.0 if args.comm_backend == 'horovod': hvd.broadcast_parameters(model.collect_params(), root_rank=0) # Construct the trainer # TODO(sxjscience) Support AMP if args.lr is None: base_lr = 2.0 / math.sqrt(args.num_units) / math.sqrt( args.warmup_steps) else: base_lr = args.lr lr_scheduler = InverseSquareRootScheduler( warmup_steps=args.warmup_steps, base_lr=base_lr, warmup_init_lr=args.warmup_init_lr) trainer_settings = (model.collect_params(), 'adam', { 'learning_rate': args.lr, 'beta1': 0.9, 'beta2': 0.98, 'epsilon': 1e-9, 'lr_scheduler': lr_scheduler }) if args.comm_backend == 'horovod': trainer = hvd.DistributedTrainer(*trainer_settings) else: trainer = gluon.Trainer(*trainer_settings) # Load Data if args.sampler == 'BoundedBudgetSampler': train_batch_sampler = BoundedBudgetSampler( lengths=[(ele[2], ele[3]) for ele in data_train], max_num_tokens=args.max_num_tokens, max_num_sentences=args.max_num_sentences, seed=args.seed, num_parts=num_parts, part_index=rank) elif args.sampler == 'FixedBucketSampler': if args.comm_backend == 'horovod': raise NotImplementedError( 'FixedBucketSampler does not support horovod at present') if args.bucket_scheme == 'constant': bucket_scheme = ConstWidthBucket() elif args.bucket_scheme == 'linear': bucket_scheme = LinearWidthBucket() elif args.bucket_scheme == 'exp': bucket_scheme = ExpWidthBucket(bucket_len_step=1.2) else: raise NotImplementedError # TODO(sxjscience) Support auto-bucket-size tuning train_batch_sampler = FixedBucketSampler(lengths=[ (ele[2], ele[3]) for ele in data_train ], batch_size=args.batch_size, num_buckets=args.num_buckets, ratio=args.bucket_ratio, shuffle=True, use_average_length=True, bucket_scheme=bucket_scheme, 
seed=args.seed) else: raise NotImplementedError if local_rank == 0: logging.info(train_batch_sampler) batchify_fn = bf.Tuple(bf.Pad(), bf.Pad(), bf.Stack(), bf.Stack(), bf.Stack()) train_data_loader = gluon.data.DataLoader( data_train, batch_sampler=train_batch_sampler, batchify_fn=batchify_fn, num_workers=0) val_data_loader = gluon.data.DataLoader(data_val, batch_size=args.val_batch_size, batchify_fn=batchify_fn, num_workers=0, shuffle=False) for v in model.collect_params().values(): if v.grad_req != 'null': v.grad_req = 'add' model.zero_grad() model_averager = AverageSGDTracker(model.collect_params()) log_start_time = time.time() num_params, num_fixed_params = None, None # TODO(sxjscience) Add a log metric class accum_count = 0 loss_denom = 0 n_train_iters = 0 log_wc = 0 log_avg_loss = 0.0 log_loss_denom = 0 epoch_id = 0 while (args.epochs < 0 or epoch_id < args.epochs ): # when args.epochs < 0, the model will keep training n_epoch_train_iters = 0 processed_batch_num = 0 train_multi_data_loader = grouper(train_data_loader, len(ctx_l)) is_last_batch = False sample_data_l = next(train_multi_data_loader) while not is_last_batch: processed_batch_num += len(sample_data_l) loss_l = [] for sample_data, ctx in zip(sample_data_l, ctx_l): if sample_data is None: continue src_token_ids, tgt_token_ids, src_valid_length, tgt_valid_length, sample_ids = sample_data src_wc, tgt_wc, bs = src_valid_length.sum( ), tgt_valid_length.sum(), src_token_ids.shape[0] loss_denom += tgt_wc - bs log_loss_denom += tgt_wc - bs log_wc += src_wc + tgt_wc src_token_ids = src_token_ids.as_in_ctx(ctx) tgt_token_ids = tgt_token_ids.as_in_ctx(ctx) src_valid_length = src_valid_length.as_in_ctx(ctx) tgt_valid_length = tgt_valid_length.as_in_ctx(ctx) with mx.autograd.record(): tgt_pred = model(src_token_ids, src_valid_length, tgt_token_ids[:, :-1], tgt_valid_length - 1) tgt_labels = tgt_token_ids[:, 1:] loss = label_smooth_loss(tgt_pred, tgt_labels) loss = mx.npx.sequence_mask( loss, sequence_length=tgt_valid_length - 1, use_sequence_length=True, axis=1) loss_l.append(loss.sum() / rescale_loss) for l in loss_l: l.backward() accum_count += 1 try: sample_data_l = next(train_multi_data_loader) except StopIteration: is_last_batch = True if local_rank == 0 and num_params is None: num_params, num_fixed_params = count_parameters( model.collect_params()) logging.info( 'Total Number of Parameters (not-fixed/fixed): {}/{}'. 
format(num_params, num_fixed_params)) sum_loss = sum([l.as_in_ctx(mx.cpu()) for l in loss_l]) * rescale_loss log_avg_loss += sum_loss mx.npx.waitall() if accum_count == args.num_accumulated or is_last_batch: # Update the parameters n_train_iters += 1 n_epoch_train_iters += 1 trainer.step(loss_denom.asnumpy() / rescale_loss) accum_count = 0 loss_denom = 0 model.zero_grad() if (args.epochs > 0 and epoch_id >= args.epochs - args.num_averages) or \ (args.max_update > 0 and n_train_iters >= args.max_update - args.num_averages * args.save_interval_update): model_averager.step() if local_rank == 0 and \ (n_epoch_train_iters % args.log_interval == 0 or is_last_batch): log_end_time = time.time() log_wc = log_wc.asnumpy() wps = log_wc / (log_end_time - log_start_time) log_avg_loss = (log_avg_loss / log_loss_denom).asnumpy() logging.info( '[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, ' 'throughput={:.2f}K wps, wc={:.2f}K, LR={}'.format( epoch_id, processed_batch_num * num_parts, len(train_data_loader), log_avg_loss, np.exp(log_avg_loss), wps / 1000, log_wc / 1000, trainer.learning_rate)) log_start_time = time.time() log_avg_loss = 0 log_loss_denom = 0 log_wc = 0 if local_rank == 0 and \ (args.max_update > 0 and n_train_iters % args.save_interval_update == 0): model.save_parameters(os.path.join( args.save_dir, 'update{:d}.params'.format( n_train_iters // args.save_interval_update)), deduplicate=True) if args.max_update > 0 and n_train_iters >= args.max_update: break if local_rank == 0 and args.epochs > 0: model.save_parameters(os.path.join( args.save_dir, 'epoch{:d}.params'.format(epoch_id)), deduplicate=True) avg_valid_loss = validation(model, val_data_loader, ctx_l) logging.info('[Epoch {}] validation loss/ppl={:.4f}/{:.4f}'.format( epoch_id, avg_valid_loss, np.exp(avg_valid_loss))) if args.max_update > 0 and n_train_iters >= args.max_update: break epoch_id += 1 if args.num_averages > 0: model_averager.copy_back( model.collect_params()) # TODO(sxjscience) Rewrite using update model.save_parameters(os.path.join(args.save_dir, 'average.params'), deduplicate=True)
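# --- Hedged sketch (assumption, not from the source): the core of the
# gradient-accumulation scheme in the transformer loop above. Gradients are
# accumulated across several micro-batches with grad_req='add', and the step is
# rescaled by the number of target tokens that contributed to the loss.
import mxnet as mx

def accumulate_and_step(model, trainer, micro_batches, loss_fn, num_accumulated):
    # precondition: every trainable parameter has grad_req='add'
    token_count = 0
    for i, (inputs, labels, n_tokens) in enumerate(micro_batches):
        with mx.autograd.record():
            loss = loss_fn(model(inputs), labels).sum()
        loss.backward()                  # adds into the existing gradient buffers
        token_count += n_tokens
        if (i + 1) % num_accumulated == 0:
            trainer.step(token_count)    # normalize the update by accumulated tokens
            model.zero_grad()            # mandatory whenever grad_req='add'
            token_count = 0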
def train(net, train_data, val_data, eval_metric, ctx, args): import gluoncv as gcv gcv.utils.check_version("0.6.0") from gluoncv import data as gdata from gluoncv import utils as gutils from gluoncv.data.batchify import Pad, Stack, Tuple from gluoncv.data.dataloader import RandomTransformDataLoader from gluoncv.data.transforms.presets.yolo import ( YOLO3DefaultTrainTransform, YOLO3DefaultValTransform, ) from gluoncv.model_zoo import get_model from gluoncv.utils import LRScheduler, LRSequential from gluoncv.utils.metrics.coco_detection import COCODetectionMetric from gluoncv.utils.metrics.voc_detection import VOC07MApMetric """Training pipeline""" net.collect_params().reset_ctx(ctx) if args.no_wd: for k, v in net.collect_params(".*beta|.*gamma|.*bias").items(): v.wd_mult = 0.0 if args.label_smooth: net._target_generator._label_smooth = True if args.lr_decay_period > 0: lr_decay_epoch = list(range(args.lr_decay_period, args.epochs, args.lr_decay_period)) else: lr_decay_epoch = [int(i) for i in args.lr_decay_epoch.split(",")] lr_decay_epoch = [e - args.warmup_epochs for e in lr_decay_epoch] num_batches = args.num_samples // args.batch_size lr_scheduler = LRSequential( [ LRScheduler( "linear", base_lr=0, target_lr=args.lr, nepochs=args.warmup_epochs, iters_per_epoch=num_batches, ), LRScheduler( args.lr_mode, base_lr=args.lr, nepochs=args.epochs - args.warmup_epochs, iters_per_epoch=num_batches, step_epoch=lr_decay_epoch, step_factor=args.lr_decay, power=2, ), ] ) if args.horovod: hvd.broadcast_parameters(net.collect_params(), root_rank=0) trainer = hvd.DistributedTrainer( net.collect_params(), "sgd", {"wd": args.wd, "momentum": args.momentum, "lr_scheduler": lr_scheduler}, ) else: trainer = gluon.Trainer( net.collect_params(), "sgd", {"wd": args.wd, "momentum": args.momentum, "lr_scheduler": lr_scheduler}, kvstore="local", update_on_kvstore=(False if args.amp else None), ) if args.amp: amp.init_trainer(trainer) # targets sigmoid_ce = gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False) l1_loss = gluon.loss.L1Loss() # metrics obj_metrics = mx.metric.Loss("ObjLoss") center_metrics = mx.metric.Loss("BoxCenterLoss") scale_metrics = mx.metric.Loss("BoxScaleLoss") cls_metrics = mx.metric.Loss("ClassLoss") # set up logger logging.basicConfig() logger = logging.getLogger() logger.setLevel(logging.INFO) log_file_path = args.save_prefix + "_train.log" log_dir = os.path.dirname(log_file_path) if log_dir and not os.path.exists(log_dir): os.makedirs(log_dir) fh = logging.FileHandler(log_file_path) logger.addHandler(fh) logger.info(args) logger.info("Start training from [Epoch {}]".format(args.start_epoch)) best_map = [0] for epoch in range(args.start_epoch, args.num_epochs): if args.mixup: # TODO(zhreshold): more elegant way to control mixup during runtime try: train_data._dataset.set_mixup(np.random.beta, 1.5, 1.5) except AttributeError: train_data._dataset._data.set_mixup(np.random.beta, 1.5, 1.5) if epoch >= args.num_epochs - args.no_mixup_epochs: try: train_data._dataset.set_mixup(None) except AttributeError: train_data._dataset._data.set_mixup(None) tic = time.time() btic = time.time() mx.nd.waitall() net.hybridize() for i, batch in enumerate(train_data): data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0) # objectness, center_targets, scale_targets, weights, class_targets fixed_targets = [ gluon.utils.split_and_load(batch[it], ctx_list=ctx, batch_axis=0) for it in range(1, 6) ] gt_boxes = gluon.utils.split_and_load(batch[6], ctx_list=ctx, batch_axis=0) sum_losses = [] 
obj_losses = [] center_losses = [] scale_losses = [] cls_losses = [] with autograd.record(): for ix, x in enumerate(data): obj_loss, center_loss, scale_loss, cls_loss = net( x, gt_boxes[ix], *[ft[ix] for ft in fixed_targets] ) sum_losses.append(obj_loss + center_loss + scale_loss + cls_loss) obj_losses.append(obj_loss) center_losses.append(center_loss) scale_losses.append(scale_loss) cls_losses.append(cls_loss) if args.amp: with amp.scale_loss(sum_losses, trainer) as scaled_loss: autograd.backward(scaled_loss) else: autograd.backward(sum_losses) trainer.step(batch_size) if not args.horovod or hvd.rank() == 0: obj_metrics.update(0, obj_losses) center_metrics.update(0, center_losses) scale_metrics.update(0, scale_losses) cls_metrics.update(0, cls_losses) if args.log_interval and not (i + 1) % args.log_interval: name1, loss1 = obj_metrics.get() name2, loss2 = center_metrics.get() name3, loss3 = scale_metrics.get() name4, loss4 = cls_metrics.get() logger.info( "[Epoch {}][Batch {}], LR: {:.2E}, Speed: {:.3f} samples/sec, {}={:.3f}, {}={:.3f}, {}={:.3f}, {}={:.3f}".format( epoch, i, trainer.learning_rate, args.batch_size / (time.time() - btic), name1, loss1, name2, loss2, name3, loss3, name4, loss4, ) ) btic = time.time() if not args.horovod or hvd.rank() == 0: name1, loss1 = obj_metrics.get() name2, loss2 = center_metrics.get() name3, loss3 = scale_metrics.get() name4, loss4 = cls_metrics.get() logger.info( "[Epoch {}] Training cost: {:.3f}, {}={:.3f}, {}={:.3f}, {}={:.3f}, {}={:.3f}".format( epoch, (time.time() - tic), name1, loss1, name2, loss2, name3, loss3, name4, loss4, ) ) if not (epoch + 1) % args.val_interval: # consider reduce the frequency of validation to save time map_name, mean_ap = validate(net, val_data, ctx, eval_metric) val_msg = "\n".join(["{}={}".format(k, v) for k, v in zip(map_name, mean_ap)]) logger.info("[Epoch {}] Validation: \n{}".format(epoch, val_msg)) current_map = float(mean_ap[-1]) else: current_map = 0.0 save_params(net, best_map, current_map, epoch, args.save_interval, args.save_prefix) # save model net.set_nms(nms_thresh=0.45, nms_topk=400, post_nms=100) net(mx.nd.ones((1, 3, args.data_shape, args.data_shape), ctx=ctx[0])) net.export("%s/model" % os.environ["SM_MODEL_DIR"])
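# --- Hedged sketch (assumption): the export step that closes the SageMaker
# pipeline above, isolated. A hybridized Gluon block must run one forward pass
# so its symbolic graph is cached before export() can serialize it.
import os
import mxnet as mx
from gluoncv import model_zoo

def export_detector(model_dir, data_shape=416):
    net = model_zoo.get_model('yolo3_darknet53_coco', pretrained_base=False)
    net.initialize()
    net.set_nms(nms_thresh=0.45, nms_topk=400, post_nms=100)
    net.hybridize()
    net(mx.nd.ones((1, 3, data_shape, data_shape)))   # trace the graph once
    net.export(os.path.join(model_dir, 'model'))      # writes model-symbol.json / model-0000.params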
def train(): # Get model from GluonCV model zoo # https://gluon-cv.mxnet.io/model_zoo/index.html net = get_model(args.model, **kwargs) net.cast(args.dtype) # Create input symbol data = mx.sym.var('data') if args.dtype == 'float16': data = mx.sym.Cast(data=data, dtype=np.float16) net.cast(np.float16) # Create output symbol out = net(data) if args.dtype == 'float16': out = mx.sym.Cast(data=out, dtype=np.float32) softmax = mx.sym.SoftmaxOutput(out, name='softmax') if args.use_pretrained: arg_params = {} for x in net.collect_params().values(): x.reset_ctx(mx.cpu()) arg_params[x.name] = x.data() else: arg_params = None aux_params = None # Create model mod = mx.mod.Module(softmax, context=context) # Create optimizer optimizer_params = { 'wd': args.wd, 'momentum': args.momentum, 'rescale_grad': 1.0 / batch_size, 'lr_scheduler': lr_sched } if args.dtype == 'float16': optimizer_params['multi_precision'] = True opt = mx.optimizer.create('sgd', sym=out, **optimizer_params) # Horovod: wrap optimizer with DistributedOptimizer opt = hvd.DistributedOptimizer(opt) # Create initializer and initializer parameters initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2) mod.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label) mod.init_params(initializer, arg_params=arg_params, aux_params=aux_params) # Horovod: fetch and broadcast parameters (arg_params, aux_params) = mod.get_params() if arg_params is not None: hvd.broadcast_parameters(arg_params, root_rank=0) if aux_params is not None: hvd.broadcast_parameters(aux_params, root_rank=0) mod.set_params(arg_params=arg_params, aux_params=aux_params) # Setup validation data and callback during training eval_data = None if args.eval_epoch: eval_data = val_data batch_callback = None if args.log_interval > 0: batch_callback = mx.callback.Speedometer(batch_size, max(1, args.log_interval)) epoch_callback = None if args.save_frequency > 0: epoch_callback = mx.callback.do_checkpoint('%s-%d' % (args.model, rank), period=args.save_frequency) # Train model mod.fit(train_data, eval_data=eval_data, num_epoch=args.num_epochs, kvstore=None, batch_end_callback=batch_callback, epoch_end_callback=epoch_callback, optimizer=opt, optimizer_params=optimizer_params) # Evaluate performance if not using synthetic data if args.use_rec: acc_top1 = mx.metric.Accuracy() acc_top5 = mx.metric.TopKAccuracy(5) res = mod.score(val_data, [acc_top1, acc_top5]) for name, val in res: logging.info('Epoch[%d] Rank[%d] Validation-%s=%f', args.num_epochs - 1, rank, name, val)
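# --- Hedged sketch (assumption): the Horovod-specific pieces of the Module-API
# path above -- broadcast arg/aux params after init_params, then wrap the
# optimizer in hvd.DistributedOptimizer before handing it to mod.fit().
import mxnet as mx
import horovod.mxnet as hvd

def init_module_for_horovod(mod, train_data, batch_size, lr=0.1):
    mod.bind(data_shapes=train_data.provide_data,
             label_shapes=train_data.provide_label)
    mod.init_params(mx.init.Xavier(rnd_type='gaussian', factor_type='in', magnitude=2))
    arg_params, aux_params = mod.get_params()
    hvd.broadcast_parameters(arg_params, root_rank=0)
    hvd.broadcast_parameters(aux_params, root_rank=0)
    mod.set_params(arg_params=arg_params, aux_params=aux_params)
    # Module does not rescale gradients in step(), so bake it into the optimizer
    opt = mx.optimizer.create('sgd', learning_rate=lr, rescale_grad=1.0 / batch_size)
    return hvd.DistributedOptimizer(opt)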
def train(epochs, ctx): if isinstance(ctx, mx.Context): ctx = [ctx] net.initialize(mx.init.Xavier(), ctx=ctx) # if opt.print_tensor_shape and rank == 0: # print(net) train_dataset = gluon.data.vision.CIFAR100(train=True).transform_first(transform_train) train_data = gluon.data.DataLoader( train_dataset, sampler=SplitSampler(len(train_dataset), num_parts=num_workers, part_index=rank), batch_size=batch_size, last_batch='discard', num_workers=opt.num_workers) # val_dataset = gluon.data.vision.CIFAR100(train=False).transform_first(transform_test) # val_data = gluon.data.DataLoader( # val_dataset, # sampler=SplitSampler(len(val_dataset), num_parts=num_workers, part_index=rank), # batch_size=batch_size, num_workers=opt.num_workers) val_data = gluon.data.DataLoader( gluon.data.vision.CIFAR100(train=False).transform_first(transform_test), batch_size=batch_size, shuffle=False, num_workers=opt.num_workers) hvd.broadcast_parameters(net.collect_params(), root_rank=0) trainer = QSparseLocalSGDTrainerV1( net.collect_params(), 'nag', optimizer_params, input_sparse_ratio=1./opt.input_sparse, output_sparse_ratio=1./opt.output_sparse, layer_sparse_ratio=1./opt.layer_sparse, local_sgd_interval=opt.local_sgd_interval) # trainer = gluon.Trainer(net.collect_params(), optimizer, # {'learning_rate': opt.lr, 'wd': opt.wd, 'momentum': opt.momentum}) metric = mx.metric.Accuracy() train_metric = mx.metric.Accuracy() loss_fn = gluon.loss.SoftmaxCrossEntropyLoss() train_history = TrainingHistory(['training-error', 'validation-error']) iteration = 0 lr_decay_count = 0 best_val_score = 0 lr = opt.lr for epoch in range(epochs): tic = time.time() train_metric.reset() metric.reset() train_loss = 0 num_batch = len(train_data) alpha = 1 if epoch == lr_decay_epoch[lr_decay_count]: lr *= lr_decay trainer.set_learning_rate(lr) lr_decay_count += 1 for i, batch in enumerate(train_data): data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0) label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0) with ag.record(): output = [net(X) for X in data] loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)] for l in loss: l.backward() trainer.step(batch_size) train_loss += sum([l.sum().asscalar() for l in loss]) train_metric.update(label, output) name, acc = train_metric.get() iteration += 1 mx.nd.waitall() toc = time.time() train_loss /= batch_size * num_batch name, acc = train_metric.get() # name, val_acc = test(ctx, val_data) trainer.pre_test() name, val_acc = test(ctx, val_data) trainer.post_test() train_history.update([1-acc, 1-val_acc]) # train_history.plot(save_path='%s/%s_history.png'%(plot_path, model_name)) # allreduce the results allreduce_array_nd = mx.nd.array([train_loss, acc, val_acc]) hvd.allreduce_(allreduce_array_nd, name='allreduce_array', average=True) allreduce_array_np = allreduce_array_nd.asnumpy() train_loss = np.asscalar(allreduce_array_np[0]) acc = np.asscalar(allreduce_array_np[1]) val_acc = np.asscalar(allreduce_array_np[2]) if val_acc > best_val_score: best_val_score = val_acc # net.save_parameters('%s/%.4f-cifar-%s-%d-best.params'%(save_dir, best_val_score, model_name, epoch)) if rank == 0: logging.info('[Epoch %d] train=%f val=%f loss=%f comm=%.2f time: %f' % (epoch, acc, val_acc, train_loss, trainer._comm_counter/1e6, toc-tic)) if save_period and save_dir and (epoch + 1) % save_period == 0: net.save_parameters('%s/cifar10-%s-%d.params'%(save_dir, model_name, epoch)) trainer._comm_counter = 0. 
if rank == 0: if save_period and save_dir: net.save_parameters('%s/cifar10-%s-%d.params'%(save_dir, model_name, epochs-1))
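# --- Hedged sketch (assumption): averaging per-worker scalars with an in-place
# allreduce, as the CIFAR loop above does at the end of every epoch.
import mxnet as mx
import horovod.mxnet as hvd

def average_metrics(train_loss, train_acc, val_acc):
    buf = mx.nd.array([train_loss, train_acc, val_acc])
    hvd.allreduce_(buf, name='metric_buffer', average=True)   # in-place average across ranks
    avg = buf.asnumpy()
    return float(avg[0]), float(avg[1]), float(avg[2])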
def train(net, train_data, val_data, eval_metric, batch_size, ctx, logger, args): """Training pipeline""" args.kv_store = 'device' if (args.amp and 'nccl' in args.kv_store) else args.kv_store kv = mx.kvstore.create(args.kv_store) net.collect_params().setattr('grad_req', 'null') net.collect_train_params().setattr('grad_req', 'write') for k, v in net.collect_params('.*bias').items(): v.wd_mult = 0.0 optimizer_params = {'learning_rate': args.lr, 'wd': args.wd, 'momentum': args.momentum, } if args.clip_gradient > 0.0: optimizer_params['clip_gradient'] = args.clip_gradient if args.amp: optimizer_params['multi_precision'] = True if args.horovod: hvd.broadcast_parameters(net.collect_params(), root_rank=0) trainer = hvd.DistributedTrainer( net.collect_train_params(), # fix batchnorm, fix first stage, etc... 'sgd', optimizer_params ) else: trainer = gluon.Trainer( net.collect_train_params(), # fix batchnorm, fix first stage, etc... 'sgd', optimizer_params, update_on_kvstore=(False if args.amp else None), kvstore=kv) if args.amp: amp.init_trainer(trainer) # lr decay policy lr_decay = float(args.lr_decay) lr_steps = sorted([float(ls) for ls in args.lr_decay_epoch.split(',') if ls.strip()]) lr_warmup = float(args.lr_warmup) # avoid int division rpn_cls_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False) rpn_box_loss = mx.gluon.loss.HuberLoss(rho=args.rpn_smoothl1_rho) # == smoothl1 rcnn_cls_loss = mx.gluon.loss.SoftmaxCrossEntropyLoss() rcnn_box_loss = mx.gluon.loss.HuberLoss(rho=args.rcnn_smoothl1_rho) # == smoothl1 rcnn_mask_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False) metrics = [mx.metric.Loss('RPN_Conf'), mx.metric.Loss('RPN_SmoothL1'), mx.metric.Loss('RCNN_CrossEntropy'), mx.metric.Loss('RCNN_SmoothL1'), mx.metric.Loss('RCNN_Mask')] rpn_acc_metric = RPNAccMetric() rpn_bbox_metric = RPNL1LossMetric() rcnn_acc_metric = RCNNAccMetric() rcnn_bbox_metric = RCNNL1LossMetric() rcnn_mask_metric = MaskAccMetric() rcnn_fgmask_metric = MaskFGAccMetric() metrics2 = [rpn_acc_metric, rpn_bbox_metric, rcnn_acc_metric, rcnn_bbox_metric, rcnn_mask_metric, rcnn_fgmask_metric] async_eval_processes = [] logger.info(args) if args.verbose: logger.info('Trainable parameters:') logger.info(net.collect_train_params().keys()) logger.info('Start training from [Epoch {}]'.format(args.start_epoch)) best_map = [0] base_lr = trainer.learning_rate for epoch in range(args.start_epoch, args.epochs): rcnn_task = ForwardBackwardTask(net, trainer, rpn_cls_loss, rpn_box_loss, rcnn_cls_loss, rcnn_box_loss, rcnn_mask_loss, args.amp) executor = Parallel(args.executor_threads, rcnn_task) if not args.horovod else None if not args.disable_hybridization: net.hybridize(static_alloc=args.static_alloc) while lr_steps and epoch >= lr_steps[0]: new_lr = trainer.learning_rate * lr_decay lr_steps.pop(0) trainer.set_learning_rate(new_lr) logger.info("[Epoch {}] Set learning rate to {}".format(epoch, new_lr)) for metric in metrics: metric.reset() tic = time.time() btic = time.time() train_data_iter = iter(train_data) next_data_batch = next(train_data_iter) next_data_batch = split_and_load(next_data_batch, ctx_list=ctx) for i in range(len(train_data)): batch = next_data_batch if i + epoch * len(train_data) <= lr_warmup: # adjust based on real percentage new_lr = base_lr * get_lr_at_iter((i + epoch * len(train_data)) / lr_warmup, args.lr_warmup_factor) if new_lr != trainer.learning_rate: if i % args.log_interval == 0: logger.info('[Epoch {} Iteration {}] Set learning rate to {}' .format(epoch, i, new_lr)) 
trainer.set_learning_rate(new_lr) metric_losses = [[] for _ in metrics] add_losses = [[] for _ in metrics2] if executor is not None: for data in zip(*batch): executor.put(data) for j in range(len(ctx)): if executor is not None: result = executor.get() else: result = rcnn_task.forward_backward(list(zip(*batch))[0]) if (not args.horovod) or hvd.rank() == 0: for k in range(len(metric_losses)): metric_losses[k].append(result[k]) for k in range(len(add_losses)): add_losses[k].append(result[len(metric_losses) + k]) try: # prefetch next batch next_data_batch = next(train_data_iter) next_data_batch = split_and_load(next_data_batch, ctx_list=ctx) except StopIteration: pass for metric, record in zip(metrics, metric_losses): metric.update(0, record) for metric, records in zip(metrics2, add_losses): for pred in records: metric.update(pred[0], pred[1]) trainer.step(batch_size) if (not args.horovod or hvd.rank() == 0) and args.log_interval \ and not (i + 1) % args.log_interval: msg = ','.join(['{}={:.3f}'.format(*metric.get()) for metric in metrics + metrics2]) batch_speed = args.log_interval * args.batch_size / (time.time() - btic) speed.append(batch_speed) logger.info('[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}'.format( epoch, i, batch_speed, msg)) btic = time.time() if speed: avg_batch_speed = sum(speed) / len(speed) # validate and save params if (not args.horovod) or hvd.rank() == 0: msg = ','.join(['{}={:.3f}'.format(*metric.get()) for metric in metrics]) logger.info('[Epoch {}] Training cost: {:.3f}, Speed: {:.3f} samples/sec, {}'.format( epoch, (time.time() - tic), avg_batch_speed, msg)) if not (epoch + 1) % args.val_interval: # consider reduce the frequency of validation to save time validate(net, val_data, async_eval_processes, ctx, eval_metric, logger, epoch, best_map, args) elif (not args.horovod) or hvd.rank() == 0: current_map = 0. save_params(net, logger, best_map, current_map, epoch, args.save_interval, args.save_prefix) for thread in async_eval_processes: thread.join()
def train(data_train, data_eval, model): """Training function.""" # backend specific implementation param_dict = model.bert.collect_params() if backend == 'horovod': hvd.broadcast_parameters(param_dict, root_rank=0) mlm_metric = nlp.metric.MaskedAccuracy() nsp_metric = nlp.metric.MaskedAccuracy() mlm_metric.reset() nsp_metric.reset() logging.info('Creating distributed trainer...') lr = args.lr optim_params = {'learning_rate': lr, 'epsilon': 1e-6, 'wd': 0.01} if args.dtype == 'float16': optim_params['multi_precision'] = True dynamic_loss_scale = args.dtype == 'float16' if dynamic_loss_scale: loss_scale_param = { 'scale_window': 2000 / num_workers, 'init_scale': 2**10 } else: loss_scale_param = None # backend specific implementation if backend == 'horovod': trainer = hvd.DistributedTrainer(param_dict, args.optimizer, optim_params) else: trainer = mx.gluon.Trainer(param_dict, args.optimizer, optim_params, update_on_kvstore=False) fp16_trainer = FP16Trainer(trainer, dynamic_loss_scale=dynamic_loss_scale, loss_scaler_params=loss_scale_param) if args.start_step: state_path = os.path.join( args.ckpt_dir, '%07d.states.%02d' % (args.start_step, local_rank)) logging.info('Loading trainer state from %s', state_path) nlp.utils.load_states(trainer, state_path) accumulate = args.accumulate num_train_steps = args.num_steps warmup_ratio = args.warmup_ratio num_warmup_steps = int(num_train_steps * warmup_ratio) params = [p for p in param_dict.values() if p.grad_req != 'null'] # Do not apply weight decay on LayerNorm and bias terms for _, v in model.collect_params('.*beta|.*gamma|.*bias').items(): v.wd_mult = 0.0 if accumulate > 1: for p in params: p.grad_req = 'add' train_begin_time = time.time() begin_time = time.time() running_mlm_loss, running_nsp_loss = 0, 0 running_num_tks = 0 batch_num = 0 step_num = args.start_step if args.phase2: step_num -= args.phase1_num_steps logging.info('Training started') # create dummy data loader if needed parallel_model = DataParallelBERT(model, trainer=fp16_trainer) num_ctxes = len(ctxs) parallel = nlp.utils.Parallel(num_ctxes if num_ctxes > 1 else 0, parallel_model) while step_num < num_train_steps: data_train_iter = iter(data_train) end_of_batch = False next_data_batch = next(data_train_iter) while not end_of_batch: data_batch = next_data_batch if step_num >= num_train_steps: break if batch_num % accumulate == 0: step_num += 1 # update learning rate if step_num <= num_warmup_steps: new_lr = lr * step_num / num_warmup_steps else: offset = (num_train_steps - step_num) / (num_train_steps - num_warmup_steps) new_lr = lr * max(offset, 0) trainer.set_learning_rate(new_lr) if args.profile: profile(step_num, 10, 14, profile_name=args.profile + str(rank)) # load data data_list = list(split_and_load(data_batch, ctxs)) ns_label_list, ns_pred_list = [], [] mask_label_list, mask_pred_list, mask_weight_list = [], [], [] num_data = len(data_list) for i in range(num_data): parallel.put(data_list[i]) for _ in range(num_data): (next_sentence_label, classified, masked_id, decoded, masked_weight, ls1, ls2, valid_length) = parallel.get() ns_label_list.append(next_sentence_label) ns_pred_list.append(classified) mask_label_list.append(masked_id) mask_pred_list.append(decoded) mask_weight_list.append(masked_weight) running_mlm_loss += ls1.as_in_context(mx.cpu()) / len(ctxs) running_nsp_loss += ls2.as_in_context(mx.cpu()) / len(ctxs) running_num_tks += valid_length.sum().as_in_context(mx.cpu()) # pre fetch next batch try: next_data_batch = next(data_train_iter) except StopIteration: 
end_of_batch = True # update if (batch_num + 1) % accumulate == 0: fp16_trainer.step(1, max_norm=1.0 * num_workers) if accumulate > 1: param_dict.zero_grad() # update metrics if args.no_compute_acc: mask_pred_list[0].wait_to_read() else: nsp_metric.update(ns_label_list, ns_pred_list) mlm_metric.update(mask_label_list, mask_pred_list, mask_weight_list) # logging if step_num % (args.log_interval) == 0 and (batch_num + 1) % accumulate == 0: if args.no_compute_acc: log_noacc(begin_time, running_num_tks, running_mlm_loss / accumulate, running_nsp_loss / accumulate, step_num, trainer, args.log_interval) else: log(begin_time, running_num_tks, running_mlm_loss / accumulate, running_nsp_loss / accumulate, step_num, mlm_metric, nsp_metric, trainer, args.log_interval) mlm_metric.reset_local() nsp_metric.reset_local() begin_time = time.time() running_mlm_loss = running_nsp_loss = running_num_tks = 0 # saving checkpoints if step_num % args.ckpt_interval == 0 and (batch_num + 1) % accumulate == 0: if is_master_node: save_states(step_num, trainer, args.ckpt_dir, local_rank) if local_rank == 0: save_parameters(step_num, model.bert, args.ckpt_dir) if step_num % args.eval_interval == 0 and data_eval \ and (batch_num + 1) % accumulate == 0: # eval data is always based on a fixed npz file. dataset_eval = get_pretrain_data_npz(data_eval, batch_size_eval, 1, False, 1, vocab) evaluate(dataset_eval, model, ctxs, args.log_interval, args.dtype) batch_num += 1 if is_master_node: save_states(step_num, trainer, args.ckpt_dir, local_rank) if local_rank == 0: save_parameters(step_num, model.bert, args.ckpt_dir) mx.nd.waitall() train_end_time = time.time() logging.info('Train cost={:.1f}s'.format(train_end_time - train_begin_time))
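# --- Hedged sketch (assumption): the learning-rate schedule driven manually in
# the pretraining loop above -- linear warmup to the peak LR, then linear decay
# to zero at num_train_steps.
def warmup_linear_decay_lr(base_lr, step_num, num_warmup_steps, num_train_steps):
    if step_num <= num_warmup_steps:
        return base_lr * step_num / num_warmup_steps
    offset = (num_train_steps - step_num) / (num_train_steps - num_warmup_steps)
    return base_lr * max(offset, 0.0)

# usage inside the loop:
# trainer.set_learning_rate(warmup_linear_decay_lr(lr, step_num, num_warmup_steps, num_train_steps))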
def train_module(): # Create input symbol data = mx.sym.var('data') if args.dtype == 'float16': data = mx.sym.Cast(data=data, dtype=np.float16) net.cast(np.float16) # Create output symbol out = net(data) if args.dtype == 'float16': out = mx.sym.Cast(data=out, dtype=np.float32) softmax = mx.sym.SoftmaxOutput(out, name='softmax') # Create model mod = mx.mod.Module(softmax, context=context) # Initialize parameters if args.use_pretrained: arg_params = {} for x in net.collect_params().values(): x.reset_ctx(mx.cpu()) arg_params[x.name] = x.data() else: arg_params = None aux_params = None mod.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label) mod.init_params(initializer, arg_params=arg_params, aux_params=aux_params) # Horovod: fetch and broadcast parameters (arg_params, aux_params) = mod.get_params() if arg_params is not None: hvd.broadcast_parameters(arg_params, root_rank=0) if aux_params is not None: hvd.broadcast_parameters(aux_params, root_rank=0) mod.set_params(arg_params=arg_params, aux_params=aux_params) # Create optimizer # Note that when using Module API, we need to specify rescale_grad since # we create optimizer first and wrap it with DistributedOptimizer. For # Gluon API, it is handled in Trainer.step() function so there is no need # to specify rescale_grad (see above train_gluon() function). optimizer_params = { 'wd': args.wd, 'momentum': args.momentum, 'rescale_grad': 1.0 / batch_size, 'lr_scheduler': lr_sched } if args.dtype == 'float16': optimizer_params['multi_precision'] = True opt = mx.optimizer.create('sgd', **optimizer_params) # Horovod: wrap optimizer with DistributedOptimizer dist_opt = hvd.DistributedOptimizer( opt, gradient_predivide_factor=args.gradient_predivide_factor) # Setup validation data and callback during training eval_data = None if args.eval_epoch: eval_data = val_data batch_callback = None if args.log_interval > 0 and rank == 0: batch_callback = mx.callback.Speedometer(batch_size * num_workers, args.log_interval) epoch_callback = None if args.save_frequency > 0: epoch_callback = mx.callback.do_checkpoint('%s-%d' % (args.model, rank), period=args.save_frequency) # Train model mod.fit(train_data, eval_data=eval_data, num_epoch=args.num_epochs, kvstore=None, batch_end_callback=batch_callback, epoch_end_callback=epoch_callback, optimizer=dist_opt) # Evaluate performance if not using synthetic data if args.use_rec: acc_top1 = mx.metric.Accuracy() acc_top5 = mx.metric.TopKAccuracy(5) res = mod.score(val_data, [acc_top1, acc_top5]) for name, val in res: logging.info('Epoch[%d] Rank[%d] Validation-%s=%f', args.num_epochs - 1, rank, name, val)
def train(args): store, num_workers, rank, local_rank, is_master_node, ctx_l = init_comm( args.comm_backend, args.gpus) setup_logging(args, local_rank) cfg, tokenizer, qa_net, use_segmentation = \ get_network(args.model_name, ctx_l, args.classifier_dropout, args.param_checkpoint, args.backbone_path) logging.info('Prepare training data') train_features = get_squad_features(args, tokenizer, segment='train') dataset_processor = SquadDatasetProcessor( tokenizer=tokenizer, doc_stride=args.doc_stride, max_seq_length=args.max_seq_length, max_query_length=args.max_query_length) logging.info('Processing the Training data:') train_dataset, num_answer_mismatch, num_unreliable \ = dataset_processor.get_train(train_features, skip_unreliable=True) logging.info( 'Done! #Unreliable Span={} / #Mismatched Answer={} / #Total={}'.format( num_unreliable, num_answer_mismatch, len(train_features))) # Get dataset statistics num_impossible = 0 for sample in train_dataset: num_impossible += sample.is_impossible logging.info('Before Chunking, #Train/Is Impossible = {}/{}'.format( len(train_features), sum([ele.is_impossible for ele in train_features]))) logging.info('After Chunking, #Train Sample/Is Impossible = {}/{}'.format( len(train_dataset), num_impossible)) # Shuffle the dataset using a fixed seed across all workers rs = np.random.RandomState(args.pre_shuffle_seed) rs.shuffle(train_dataset) sampler = SplitSampler(len(train_dataset), num_parts=num_workers, part_index=rank, even_size=True) train_dataloader = mx.gluon.data.DataLoader( train_dataset, batchify_fn=dataset_processor.BatchifyFunction, batch_size=args.batch_size, num_workers=0, sampler=sampler) if 'electra' in args.model_name: # Froze parameters, does not work for albert model since parameters in all layers are shared if args.untunable_depth > 0: qa_net.backbone.frozen_params(args.untunable_depth) if args.layerwise_decay > 0: qa_net.backbone.apply_layerwise_decay(args.layerwise_decay) logging.info('Creating distributed trainer...') # Collect differentiable parameters param_dict = qa_net.collect_params() # Do not apply weight decay to all the LayerNorm and bias for _, v in qa_net.collect_params('.*beta|.*gamma|.*bias').items(): v.wd_mult = 0.0 params = [p for p in param_dict.values() if p.grad_req != 'null'] # Set grad_req if gradient accumulation is required num_accumulated = args.num_accumulated if num_accumulated > 1: logging.info( 'Using gradient accumulation. Effective global batch size = {}'. 
format(num_accumulated * args.batch_size * len(ctx_l) * num_workers)) for p in params: p.grad_req = 'add' # backend specific implementation if args.comm_backend == 'horovod': # Horovod: fetch and broadcast parameters hvd.broadcast_parameters(param_dict, root_rank=0) epoch_size = (len(train_dataloader) + len(ctx_l) - 1) // len(ctx_l) if args.num_train_steps is not None: num_train_steps = args.num_train_steps else: num_train_steps = int(args.epochs * epoch_size / args.num_accumulated) if args.warmup_steps is not None: warmup_steps = args.warmup_steps else: warmup_steps = int(num_train_steps * args.warmup_ratio) assert warmup_steps is not None, 'Must specify either warmup_steps or warmup_ratio' log_interval = args.log_interval save_interval = args.save_interval if args.save_interval is not None\ else epoch_size // args.num_accumulated logging.info( '#Total Training Steps={}, Warmup={}, Save Interval={}'.format( num_train_steps, warmup_steps, save_interval)) # set up optimization lr_scheduler = PolyScheduler(max_update=num_train_steps, base_lr=args.lr, warmup_begin_lr=0, pwr=1, final_lr=0, warmup_steps=warmup_steps, warmup_mode='linear') optimizer_params = { 'learning_rate': args.lr, 'wd': args.wd, 'lr_scheduler': lr_scheduler, } adam_betas = eval(args.adam_betas) if args.optimizer == 'adamw': optimizer_params.update({ 'beta1': adam_betas[0], 'beta2': adam_betas[1], 'epsilon': args.adam_epsilon, 'correct_bias': False, }) elif args.optimizer == 'adam': optimizer_params.update({ 'beta1': adam_betas[0], 'beta2': adam_betas[1], 'epsilon': args.adam_epsilon, }) if args.comm_backend == 'horovod': trainer = hvd.DistributedTrainer(param_dict, args.optimizer, optimizer_params) else: trainer = mx.gluon.Trainer(param_dict, args.optimizer, optimizer_params, update_on_kvstore=False) log_span_loss = 0 log_answerable_loss = 0 log_total_loss = 0 log_sample_num = 0 global_tic = time.time() tic = time.time() for step_num, batch_data in enumerate( grouper(repeat(train_dataloader), len(ctx_l) * num_accumulated)): for sample_l in grouper(batch_data, len(ctx_l)): loss_l = [] span_loss_l = [] answerable_loss_l = [] for sample, ctx in zip(sample_l, ctx_l): if sample is None: continue # Copy the data to device tokens = sample.data.as_in_ctx(ctx) log_sample_num += len(tokens) segment_ids = sample.segment_ids.as_in_ctx( ctx) if use_segmentation else None valid_length = sample.valid_length.as_in_ctx(ctx) p_mask = sample.masks.as_in_ctx(ctx) gt_start = sample.gt_start.as_in_ctx(ctx).astype(np.int32) gt_end = sample.gt_end.as_in_ctx(ctx).astype(np.int32) is_impossible = sample.is_impossible.as_in_ctx(ctx).astype( np.int32) batch_idx = mx.np.arange(tokens.shape[0], dtype=np.int32, ctx=ctx) p_mask = 1 - p_mask # In the network, we use 1 --> no_mask, 0 --> mask with mx.autograd.record(): start_logits, end_logits, answerable_logits \ = qa_net(tokens, segment_ids, valid_length, p_mask, gt_start) sel_start_logits = start_logits[batch_idx, gt_start] sel_end_logits = end_logits[batch_idx, gt_end] sel_answerable_logits = answerable_logits[batch_idx, is_impossible] span_loss = -0.5 * (sel_start_logits + sel_end_logits).mean() answerable_loss = -0.5 * sel_answerable_logits.mean() loss = span_loss + answerable_loss loss_l.append(loss) span_loss_l.append(span_loss) answerable_loss_l.append(answerable_loss) for loss in loss_l: loss.backward() # All Reduce the Step Loss log_span_loss += sum( [ele.as_in_ctx(ctx_l[0]) for ele in span_loss_l]).asnumpy() log_total_loss += sum([ele.as_in_ctx(ctx_l[0]) for ele in loss_l]).asnumpy() 
log_answerable_loss += sum([ ele.as_in_ctx(ctx_l[0]) for ele in answerable_loss_l ]).asnumpy() # update trainer.allreduce_grads() if args.max_grad_norm > 0: total_norm, ratio, is_finite = clip_grad_global_norm( params, args.max_grad_norm * num_workers) else: total_norm = grad_global_norm(params) if args.comm_backend == 'horovod': # Note that horovod.trainer._scale is default to num_workers, # thus trainer.update(1) will scale the gradients by 1./num_workers trainer.update(1, ignore_stale_grad=True) else: # gluon.trainer._scale is default to 1 trainer.update(num_workers, ignore_stale_grad=True) total_norm = total_norm / num_workers if args.num_accumulated > 1: # set grad to zero for gradient accumulation qa_net.zero_grad() # saving if local_rank == 0 and (step_num + 1) % save_interval == 0 or ( step_num + 1) >= num_train_steps: version_prefix = 'squad' + args.version ckpt_name = '{}_{}_{}.params'.format(args.model_name, version_prefix, (step_num + 1)) params_saved = os.path.join(args.output_dir, ckpt_name) qa_net.save_parameters(params_saved) ckpt_candidates = [ f for f in os.listdir(args.output_dir) if f.endswith('.params') ] # keep last `max_saved_ckpt` checkpoints if len(ckpt_candidates) > args.max_saved_ckpt: ckpt_candidates.sort(key=lambda ele: (len(ele), ele)) os.remove(os.path.join(args.output_dir, ckpt_candidates[0])) logging.info('Params saved in: {}'.format(params_saved)) # logging if (step_num + 1) % log_interval == 0: log_span_loss /= log_sample_num log_answerable_loss /= log_sample_num log_total_loss /= log_sample_num toc = time.time() logging.info( 'Step: {}/{}, Loss span/answer/total={:.4f}/{:.4f}/{:.4f},' ' LR={:.8f}, grad_norm={:.4f}. Time cost={:.2f}, Throughput={:.2f} samples/s' ' ETA={:.2f}h'.format( (step_num + 1), num_train_steps, log_span_loss, log_answerable_loss, log_total_loss, trainer.learning_rate, total_norm, toc - tic, log_sample_num / (toc - tic), (num_train_steps - (step_num + 1)) / ((step_num + 1) / (toc - global_tic)) / 3600)) tic = time.time() log_span_loss = 0 log_answerable_loss = 0 log_total_loss = 0 log_sample_num = 0 num_samples_per_update = 0 if (step_num + 1) >= num_train_steps: toc = time.time() logging.info('Finish training step: {} within {} hours'.format( step_num + 1, (toc - global_tic) / 3600)) break return params_saved
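# --- Hedged sketch (assumption; the global-norm clip is hand-rolled here instead
# of the gluonnlp helper used above): the manual update path -- allreduce the
# summed gradients, clip their global norm, then apply the optimizer with the
# appropriate rescaling. DistributedTrainer already divides by num_workers in
# update(1); a plain gluon.Trainer needs update(num_workers) instead.
import mxnet as mx

def clip_and_update(trainer, params, ctx, max_grad_norm, num_workers, use_horovod):
    trainer.allreduce_grads()          # requires update_on_kvstore=False
    grads = [p.grad(ctx) for p in params if p.grad_req != 'null']
    total_norm = mx.nd.sqrt(
        sum((g.astype('float32') ** 2).sum() for g in grads)).asscalar()
    max_norm = max_grad_norm * num_workers      # grads are still summed over workers here
    if total_norm > max_norm:
        for g in grads:
            g *= max_norm / total_norm
    trainer.update(1 if use_horovod else num_workers, ignore_stale_grad=True)
    return total_norm / num_workers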
def train(data_train, data_eval, model, nsp_loss, mlm_loss, vocab_size, ctx): """Training function.""" hvd.broadcast_parameters(model.collect_params(), root_rank=0) mlm_metric = nlp.metric.MaskedAccuracy() nsp_metric = nlp.metric.MaskedAccuracy() mlm_metric.reset() nsp_metric.reset() logging.debug('Creating distributed trainer...') lr = args.lr optim_params = {'learning_rate': lr, 'epsilon': 1e-6, 'wd': 0.01} if args.dtype == 'float16': optim_params['multi_precision'] = True dynamic_loss_scale = args.dtype == 'float16' if dynamic_loss_scale: loss_scale_param = {'scale_window': 2000 / num_workers} else: loss_scale_param = None trainer = hvd.DistributedTrainer(model.collect_params(), 'bertadam', optim_params) fp16_trainer = FP16Trainer(trainer, dynamic_loss_scale=dynamic_loss_scale, loss_scaler_params=loss_scale_param) if args.start_step: state_path = os.path.join(args.ckpt_dir, '%07d.states.%02d'%(args.start_step, local_rank)) logging.info('Loading trainer state from %s', state_path) nlp.utils.load_states(trainer, state_path) accumulate = args.accumulate num_train_steps = args.num_steps warmup_ratio = args.warmup_ratio num_warmup_steps = int(num_train_steps * warmup_ratio) params = [p for p in model.collect_params().values() if p.grad_req != 'null'] param_dict = model.collect_params() # Do not apply weight decay on LayerNorm and bias terms for _, v in model.collect_params('.*beta|.*gamma|.*bias').items(): v.wd_mult = 0.0 if accumulate > 1: for p in params: p.grad_req = 'add' train_begin_time = time.time() begin_time = time.time() running_mlm_loss, running_nsp_loss = 0, 0 running_num_tks = 0 batch_num = 0 step_num = args.start_step logging.debug('Training started') while step_num < num_train_steps: for _, dataloader in enumerate(data_train): if step_num >= num_train_steps: break # create dummy data loader if needed if args.dummy_data_len: target_shape = (args.batch_size, args.dummy_data_len) dataloader = get_dummy_dataloader(dataloader, target_shape) for _, data_batch in enumerate(dataloader): if step_num >= num_train_steps: break if batch_num % accumulate == 0: step_num += 1 # if accumulate > 1, grad_req is set to 'add', and zero_grad is required if accumulate > 1: param_dict.zero_grad() # update learning rate if step_num <= num_warmup_steps: new_lr = lr * step_num / num_warmup_steps else: offset = lr * step_num / num_train_steps new_lr = lr - offset trainer.set_learning_rate(new_lr) if args.profile: profile(step_num, 10, 14, profile_name=args.profile + str(rank)) # load data if args.use_avg_len: data_list = [[seq.as_in_context(context) for seq in shard] for context, shard in zip([ctx], data_batch)] else: data_list = list(split_and_load(data_batch, [ctx])) data = data_list[0] # forward with mx.autograd.record(): (ls, ns_label, classified, masked_id, decoded, \ masked_weight, ls1, ls2, valid_len) = forward(data, model, mlm_loss, nsp_loss, vocab_size, args.dtype) ls = ls / accumulate # backward if args.dtype == 'float16': fp16_trainer.backward(ls) else: ls.backward() running_mlm_loss += ls1.as_in_context(mx.cpu()) running_nsp_loss += ls2.as_in_context(mx.cpu()) running_num_tks += valid_len.sum().as_in_context(mx.cpu()) # update if (batch_num + 1) % accumulate == 0: # step() performs 3 things: # 1. allreduce gradients from all workers # 2. checking the global_norm of gradients and clip them if necessary # 3. 
averaging the gradients and apply updates fp16_trainer.step(1, max_norm=1*num_workers) nsp_metric.update([ns_label], [classified]) mlm_metric.update([masked_id], [decoded], [masked_weight]) # logging if (step_num + 1) % (args.log_interval) == 0 and (batch_num + 1) % accumulate == 0: log(begin_time, running_num_tks, running_mlm_loss / accumulate, running_nsp_loss / accumulate, step_num, mlm_metric, nsp_metric, trainer, args.log_interval) begin_time = time.time() running_mlm_loss = running_nsp_loss = running_num_tks = 0 mlm_metric.reset_local() nsp_metric.reset_local() # saving checkpoints if (step_num + 1) % args.ckpt_interval == 0 and (batch_num + 1) % accumulate == 0: if is_master_node: save_states(step_num, trainer, args.ckpt_dir, local_rank) if local_rank == 0: save_parameters(step_num, model, args.ckpt_dir) if data_eval: # eval data is always based on a fixed npz file. dataset_eval = get_pretrain_data_npz(data_eval, args.batch_size_eval, 1, False, False, 1) evaluate(dataset_eval, model, nsp_loss, mlm_loss, len(vocab), [ctx], args.log_interval, args.dtype) batch_num += 1 if is_master_node: save_states(step_num, trainer, args.ckpt_dir, local_rank) if local_rank == 0: save_parameters(step_num, model, args.ckpt_dir) mx.nd.waitall() train_end_time = time.time() logging.info('Train cost={:.1f}s'.format(train_end_time - train_begin_time))
def train_module(): # Create input symbol data = mx.sym.var('data') if args.dtype == 'float16': data = mx.sym.Cast(data=data, dtype=np.float16) net.cast(np.float16) # Create output symbol out = net(data) if args.dtype == 'float16': out = mx.sym.Cast(data=out, dtype=np.float32) softmax = mx.sym.SoftmaxOutput(out, name='softmax') # Create model mod = mx.mod.Module(softmax, context=context) # Initialize parameters if args.use_pretrained: arg_params = {} for x in net.collect_params().values(): x.reset_ctx(mx.cpu()) arg_params[x.name] = x.data() else: arg_params = None aux_params = None mod.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label) mod.init_params(initializer, arg_params=arg_params, aux_params=aux_params) # Horovod: fetch and broadcast parameters (arg_params, aux_params) = mod.get_params() if arg_params is not None: hvd.broadcast_parameters(arg_params, root_rank=0) if aux_params is not None: hvd.broadcast_parameters(aux_params, root_rank=0) mod.set_params(arg_params=arg_params, aux_params=aux_params) # Setup validation data and callback during training eval_data = None if args.eval_epoch: eval_data = val_data batch_callback = None if args.log_interval > 0 and rank == 0: batch_callback = mx.callback.Speedometer(batch_size * num_workers, args.log_interval) epoch_callback = None if args.save_frequency > 0: epoch_callback = mx.callback.do_checkpoint('%s-%d' % (args.model, rank), period=args.save_frequency) # Train model mod.fit(train_data, eval_data=eval_data, num_epoch=args.num_epochs, kvstore=None, batch_end_callback=batch_callback, epoch_end_callback=epoch_callback, optimizer=opt) # Evaluate performance if not using synthetic data if args.use_rec: acc_top1 = mx.metric.Accuracy() acc_top5 = mx.metric.TopKAccuracy(5) res = mod.score(val_data, [acc_top1, acc_top5]) for name, val in res: logging.info('Epoch[%d] Rank[%d] Validation-%s=%f', args.num_epochs - 1, rank, name, val)
def train(net, train_data, val_data, eval_metric, batch_size, ctx, args): """Training pipeline""" kv = mx.kvstore.create(args.kv_store) net.collect_params().setattr('grad_req', 'null') net.collect_train_params().setattr('grad_req', 'write') optimizer_params = { 'learning_rate': args.lr, 'wd': args.wd, 'momentum': args.momentum } if args.horovod: hvd.broadcast_parameters(net.collect_params(), root_rank=0) trainer = hvd.DistributedTrainer( net.collect_train_params( ), # fix batchnorm, fix first stage, etc... 'sgd', optimizer_params) else: trainer = gluon.Trainer( net.collect_train_params( ), # fix batchnorm, fix first stage, etc... 'sgd', optimizer_params, update_on_kvstore=(False if args.amp else None), kvstore=kv) if args.amp: amp.init_trainer(trainer) # lr decay policy lr_decay = float(args.lr_decay) lr_steps = sorted( [float(ls) for ls in args.lr_decay_epoch.split(',') if ls.strip()]) lr_warmup = float(args.lr_warmup) # avoid int division # TODO(zhreshold) losses? rpn_cls_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss( from_sigmoid=False) rpn_box_loss = mx.gluon.loss.HuberLoss(rho=1 / 9.) # == smoothl1 rcnn_cls_loss = mx.gluon.loss.SoftmaxCrossEntropyLoss() rcnn_box_loss = mx.gluon.loss.HuberLoss() # == smoothl1 metrics = [ mx.metric.Loss('RPN_Conf'), mx.metric.Loss('RPN_SmoothL1'), mx.metric.Loss('RCNN_CrossEntropy'), mx.metric.Loss('RCNN_SmoothL1'), ] rpn_acc_metric = RPNAccMetric() rpn_bbox_metric = RPNL1LossMetric() rcnn_acc_metric = RCNNAccMetric() rcnn_bbox_metric = RCNNL1LossMetric() metrics2 = [ rpn_acc_metric, rpn_bbox_metric, rcnn_acc_metric, rcnn_bbox_metric ] # set up logger logging.basicConfig() logger = logging.getLogger() logger.setLevel(logging.INFO) log_file_path = args.save_prefix + '_train.log' log_dir = os.path.dirname(log_file_path) if log_dir and not os.path.exists(log_dir): os.makedirs(log_dir) fh = logging.FileHandler(log_file_path) logger.addHandler(fh) logger.info(args) if args.verbose: logger.info('Trainable parameters:') logger.info(net.collect_train_params().keys()) logger.info('Start training from [Epoch {}]'.format(args.start_epoch)) best_map = [0] for epoch in range(args.start_epoch, args.epochs): mix_ratio = 1.0 if not args.disable_hybridization: net.hybridize(static_alloc=args.static_alloc) rcnn_task = ForwardBackwardTask(net, trainer, rpn_cls_loss, rpn_box_loss, rcnn_cls_loss, rcnn_box_loss, mix_ratio=1.0) executor = Parallel(args.executor_threads, rcnn_task) if not args.horovod else None if args.mixup: # TODO(zhreshold) only support evenly mixup now, target generator needs to be modified otherwise train_data._dataset._data.set_mixup(np.random.uniform, 0.5, 0.5) mix_ratio = 0.5 if epoch >= args.epochs - args.no_mixup_epochs: train_data._dataset._data.set_mixup(None) mix_ratio = 1.0 while lr_steps and epoch >= lr_steps[0]: new_lr = trainer.learning_rate * lr_decay lr_steps.pop(0) trainer.set_learning_rate(new_lr) logger.info("[Epoch {}] Set learning rate to {}".format( epoch, new_lr)) for metric in metrics: metric.reset() tic = time.time() btic = time.time() base_lr = trainer.learning_rate rcnn_task.mix_ratio = mix_ratio print(len(train_data)) for i, batch in enumerate(train_data): if epoch == 0 and i <= lr_warmup: # adjust based on real percentage new_lr = base_lr * get_lr_at_iter(i / lr_warmup, args.lr_warmup_factor) if new_lr != trainer.learning_rate: if i % args.log_interval == 0: logger.info( '[Epoch 0 Iteration {}] Set learning rate to {}'. 
format(i, new_lr)) trainer.set_learning_rate(new_lr) batch = split_and_load(batch, ctx_list=ctx) metric_losses = [[] for _ in metrics] add_losses = [[] for _ in metrics2] if executor is not None: for data in zip(*batch): executor.put(data) for j in range(len(ctx)): if executor is not None: result = executor.get() else: result = rcnn_task.forward_backward(list(zip(*batch))[0]) if (not args.horovod) or hvd.rank() == 0: for k in range(len(metric_losses)): metric_losses[k].append(result[k]) for k in range(len(add_losses)): add_losses[k].append(result[len(metric_losses) + k]) for metric, record in zip(metrics, metric_losses): metric.update(0, record) for metric, records in zip(metrics2, add_losses): for pred in records: metric.update(pred[0], pred[1]) trainer.step(batch_size) # update metrics if (not args.horovod or hvd.rank() == 0) and args.log_interval \ and not (i + 1) % args.log_interval: msg = ','.join([ '{}={:.3f}'.format(*metric.get()) for metric in metrics + metrics2 ]) logger.info( '[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}'. format( epoch, i, args.log_interval * args.batch_size / (time.time() - btic), msg)) btic = time.time() if (not args.horovod) or hvd.rank() == 0: msg = ','.join( ['{}={:.3f}'.format(*metric.get()) for metric in metrics]) logger.info('[Epoch {}] Training cost: {:.3f}, {}'.format( epoch, (time.time() - tic), msg)) if not (epoch + 1) % args.val_interval: # consider reduce the frequency of validation to save time map_name, mean_ap = validate(net, val_data, ctx, eval_metric, args) map_name_train, mean_ap_train = validate( net, train_data, ctx, eval_metric, args) if isinstance(map_name, list): val_msg = '\n'.join([ '{}={}'.format(k, v) for k, v in zip(map_name, mean_ap) ]) train_msg = '\n'.join([ '{}={}'.format(k, v) for k, v in zip(map_name_train, mean_ap_train) ]) current_map = float(mean_ap[-1]) else: val_msg = '{}={}'.format(map_name, mean_ap) train_msg = '{}={}'.format(map_name_train, mean_ap_train) current_map = mean_ap logger.info('[Epoch {}] Validation: {}'.format(epoch, val_msg)) logger.info('[Epoch {}] Train: {}'.format(epoch, train_msg)) else: current_map = 0. save_params(net, logger, best_map, current_map, epoch, args.save_interval, os.path.join(args.model_dir, 'fastrcnn')) executor.__del__()
def train_gluon(): if args.save_dir: save_dir = args.save_dir save_dir = os.path.expanduser(save_dir) makedirs(save_dir) else: save_dir = './' save_frequency = 0 def evaluate(epoch): acc_top1 = mx.metric.Accuracy() acc_top5 = mx.metric.TopKAccuracy(5) for _, batch in enumerate(val_data): data, label = val_batch_fn(batch, context) output = net(data.astype(args.dtype, copy=False)) acc_top1.update([label], [output]) acc_top5.update([label], [output]) top1_name, top1_acc = acc_top1.get() top5_name, top5_acc = acc_top5.get() if MPI is not None: comm = MPI.COMM_WORLD res1 = comm.gather(top1_acc, root=0) res2 = comm.gather(top5_acc, root=0) if rank == 0: if MPI is not None: #logging.info('MPI gather res1: {}'.format(res1)) top1_acc = sum(res1) / len(res1) top5_acc = sum(res2) / len(res2) logging.info( 'Epoch[%d] Rank[%d]\tValidation-%s=%f\tValidation-%s=%f', epoch, rank, top1_name, top1_acc, top5_name, top5_acc) # Hybridize and initialize model net.hybridize() #net.initialize(initializer, ctx=context) if args.resume_params is not '': net.load_parameters(args.resume_params, ctx=context) else: net.initialize(initializer, ctx=context) if args.no_wd: for k, v in net.collect_params('.*beta|.*gamma|.*bias').items(): v.wd_mult = 0.0 # Horovod: fetch and broadcast parameters params = net.collect_params() if params is not None: hvd.broadcast_parameters(params, root_rank=0) # Create optimizer optimizer = 'nag' optimizer_params = { 'wd': args.wd, 'momentum': args.momentum, 'lr_scheduler': lr_sched } if args.dtype == 'float16': optimizer_params['multi_precision'] = True opt = mx.optimizer.create(optimizer, **optimizer_params) # Horovod: create DistributedTrainer, a subclass of gluon.Trainer trainer = hvd.DistributedTrainer(params, opt) if args.resume_states is not '': trainer.load_states(args.resume_states) # Create loss function and train metric if args.label_smoothing or args.mixup: sparse_label_loss = False else: sparse_label_loss = True distillation = args.teacher is not None and args.hard_weight < 1.0 if distillation: teacher = get_model(args.teacher, pretrained=True, classes=num_classes, ctx=context) teacher.hybridize() teacher.cast(args.dtype) loss_fn = gcv.loss.DistillationSoftmaxCrossEntropyLoss( temperature=args.temperature, hard_weight=args.hard_weight, sparse_label=sparse_label_loss) if rank == 0: logging.info('Using Distillation') else: loss_fn = gluon.loss.SoftmaxCrossEntropyLoss( sparse_label=sparse_label_loss) if args.mixup: train_metric = mx.metric.RMSE() else: train_metric = mx.metric.Accuracy() def mixup_transform(label, classes, lam=1, eta=0.0): if isinstance(label, mx.nd.NDArray): label = [label] res = [] for l in label: y1 = l.one_hot(classes, on_value=1 - eta + eta / classes, off_value=eta / classes) y2 = l[::-1].one_hot(classes, on_value=1 - eta + eta / classes, off_value=eta / classes) res.append(lam * y1 + (1 - lam) * y2) return res def smooth(label, classes, eta=0.1): if isinstance(label, mx.NDArray): label = [label] smoothed = [] for l in label: res = l.one_hot(classes, on_value=1 - eta + eta / classes, off_value=eta / classes) smoothed.append(res) return smoothed # Train model for epoch in range(args.resume_epoch, args.num_epochs): drop_scheduler(epoch) tic = time.time() train_metric.reset() btic = time.time() for nbatch, batch in enumerate(train_data, start=1): data, label = train_batch_fn(batch, context) data, label = [data], [label] if args.mixup: lam = np.random.beta(args.mixup_alpha, args.mixup_alpha) if epoch >= args.num_epochs - args.mixup_off_epoch: lam = 1 data = [lam * 
X + (1 - lam) * X[::-1] for X in data] if args.label_smoothing: eta = 0.1 else: eta = 0.0 label = mixup_transform(label, num_classes, lam, eta) elif args.label_smoothing: hard_label = label label = smooth(label, num_classes) if distillation: teacher_prob = [mx.nd.softmax(teacher(X.astype(args.dtype, copy=False)) / args.temperature) \ for X in data] with autograd.record(): outputs = [net(X.astype(args.dtype, copy=False)) for X in data] if distillation: loss = [ loss_fn(yhat.astype('float32', copy=False), y.astype('float32', copy=False), p.astype('float32', copy=False)) for yhat, y, p in zip(outputs, label, teacher_prob) ] else: loss = [ loss_fn(yhat, y.astype(args.dtype, copy=False)) for yhat, y in zip(outputs, label) ] for l in loss: l.backward() trainer.step(batch_size) if args.mixup: output_softmax = [mx.nd.SoftmaxActivation(out.astype('float32', copy=False)) \ for out in outputs] train_metric.update(label, output_softmax) else: if args.label_smoothing: train_metric.update(hard_label, outputs) else: train_metric.update(label, outputs) if args.log_interval and nbatch % args.log_interval == 0: if rank == 0: logging.info('Epoch[%d] Batch[%d] Loss[%.3f]', epoch, nbatch, loss[0].mean().asnumpy()[0]) train_metric_name, train_metric_score = train_metric.get() logging.info('Epoch[%d] Rank[%d] Batch[%d]\t%s=%f\tlr=%f', epoch, rank, nbatch, train_metric_name, train_metric_score, trainer.learning_rate) #batch_speed = num_workers * batch_size * args.log_interval / (time.time() - btic) #logging.info('Epoch[%d] Batch[%d]\tSpeed: %.2f samples/sec', # epoch, nbatch, batch_speed) btic = time.time() # Report metrics elapsed = time.time() - tic _, acc = train_metric.get() if rank == 0: logging.info( 'Epoch[%d] Rank[%d] Batch[%d]\tTime cost=%.2f\tTrain-metric=%f', epoch, rank, nbatch, elapsed, acc) epoch_speed = num_workers * batch_size * nbatch / elapsed logging.info('Epoch[%d]\tSpeed: %.2f samples/sec', epoch, epoch_speed) # Evaluate performance if args.eval_frequency and (epoch + 1) % args.eval_frequency == 0: evaluate(epoch) # Save model if args.save_frequency and (epoch + 1) % args.save_frequency == 0: net.save_parameters('%s/imagenet-%s-%d.params' % (save_dir, args.model, epoch)) trainer.save_states('%s/imagenet-%s-%d.states' % (save_dir, args.model, epoch)) # Evaluate performance at the end of training evaluate(epoch) net.save_parameters('%s/imagenet-%s-%d.params' % (save_dir, args.model, args.num_epochs - 1)) trainer.save_states('%s/imagenet-%s-%d.states' % (save_dir, args.model, args.num_epochs - 1))
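# --- Hedged usage sketch (assumption): what the smooth()/mixup_transform()
# helpers above produce -- soft one-hot targets for the non-sparse
# SoftmaxCrossEntropyLoss (eta=0.1, three classes shown).
import mxnet as mx

labels = mx.nd.array([0, 2, 1])                                       # hard class ids
soft = labels.one_hot(3, on_value=1 - 0.1 + 0.1 / 3, off_value=0.1 / 3)
# each row sums to 1: ~0.933 on the true class, ~0.033 on the others
loss_fn = mx.gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=False)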