def allreduce_params(self):
    for i, param in enumerate(self._params):
        if param.grad_req != 'null':
            hvd.allreduce_(param.list_data()[0], average=True,
                           name=str(i), priority=-i)
def test_horovod_allreduce_inplace(self):
    """Test that the allreduce correctly sums 1D, 2D, 3D tensors."""
    hvd.init()
    size = hvd.size()
    dtypes = self.filter_supported_types(
        ['int32', 'int64', 'float32', 'float64'])
    dims = [1, 2, 3]
    ctx = self._current_context()
    count = 0
    shapes = [(), (17), (17, 17), (17, 17, 17)]
    for dtype, dim in itertools.product(dtypes, dims):
        mx.random.seed(1234, ctx=ctx)
        tensor = mx.nd.random.uniform(-100, 100, shape=shapes[dim], ctx=ctx)
        tensor = tensor.astype(dtype)
        multiplied = tensor * size
        hvd.allreduce_(tensor, average=False, name=str(count))
        count += 1
        # Threshold for floating point equality depends on number of
        # ranks, since we're comparing against precise multiplication.
        if size <= 3 or dtype in ['int32', 'int64']:
            threshold = 0
        elif size < 10:
            threshold = 1e-4
        elif size < 15:
            threshold = 5e-4
        else:
            break
        assert almost_equal(tensor.asnumpy(), multiplied.asnumpy(), atol=threshold), \
            f'hvd.allreduce produces incorrect results for self: {hvd.rank()} {count} {dtype} {dim}'
def pre_test(self):
    for i, param in enumerate(self._params):
        if param.grad_req != 'null':
            self._params_cache[i][:] = param.list_data()[0]
            hvd.allreduce_(param.list_data()[0], average=True,
                           name=str(i), priority=-i)
def _allreduce_params(self):
    for i, param in enumerate(self._params):
        if param.grad_req != 'null':
            hvd.allreduce_(param.list_data()[0], average=True,
                           name=str(i), priority=-i)
            # communication counter
            self._comm_counter += param.list_data()[0].size * 2
def _do_allreduce(self, index, grad):
    if hvd.size() == 1:
        return
    if isinstance(index, (tuple, list)):
        for i in range(len(index)):
            hvd.allreduce_(grad[i], average=False,
                           name=self._prefix + str(index[i]), priority=-i)
    else:
        hvd.allreduce_(grad, average=False, name=self._prefix + str(index))
def allreduce_states(self):
    for i, param in reversed(list(enumerate(self._params))):
        if param.grad_req != 'null':
            state_array = self._updaters[0].states[i][1]
            idx = i + len(self._params)
            if param._stype == 'default':
                hvd.allreduce_(state_array, average=True, name=str(idx),
                               priority=i - len(self._params) * 2)
                self._updaters[0].states[i][0][:] = state_array
            else:
                raise ValueError(
                    "Cannot pull row_sparse parameters for local SGD")
def allgather(self, tensor, name, shape, dtype, context):
    """Implement in-place AllGather using AllReduce."""
    assert isinstance(tensor, nd.NDArray), type(tensor)
    assert isinstance(name, str), type(name)
    assert isinstance(shape, tuple), type(shape)
    assert isinstance(dtype, str), type(dtype)
    assert isinstance(context, mx.context.Context), type(context)

    total_tensor = self.get_ndarray(context=context, name=name,
                                    shape=shape, dtype=dtype)
    # Resetting the buffer before the all-reduce is essential: only this
    # rank's slice may be non-zero so that the sum acts as a gather.
    total_tensor[:] = 0
    total_tensor[self.rank * self.batch_size:
                 self.rank * self.batch_size + self.batch_size] = tensor
    hvd.allreduce_(total_tensor, average=False)  # all-reduce in place
    return total_tensor
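The snippet above emulates an all-gather with a single allreduce: each rank writes its slice into a zero-initialized buffer sized for all ranks, and the elementwise sum across ranks reproduces the concatenation. A minimal standalone sketch of the same trick; the shapes and the `local`/`gathered` names are illustrative, not taken from the snippet, and it assumes `hvd.init()` succeeds under a Horovod launcher:

# Sketch: all-gather emulated with an allreduce sum (illustrative names).
import mxnet as mx
import horovod.mxnet as hvd

hvd.init()
rows_per_rank = 4
# Each rank's local chunk; filled with the rank id so the result is easy to inspect.
local = mx.nd.full((rows_per_rank, 8), hvd.rank(), dtype='float32')

# Zero buffer covering every rank's slice; only our own slice is non-zero,
# so summing the buffers across ranks concatenates the chunks along axis 0.
gathered = mx.nd.zeros((rows_per_rank * hvd.size(), 8), dtype='float32')
gathered[hvd.rank() * rows_per_rank:(hvd.rank() + 1) * rows_per_rank] = local
hvd.allreduce_(gathered, average=False)  # in-place sum == all-gather result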
def allreduce_params(self):
    """For each parameter, reduce the parameters from different contexts.

    Should be called after `autograd.backward()`, outside of `record()`
    scope, and before `trainer.update()`.

    For normal parameter updates, `step()` should be used, which internally
    calls `allreduce_grads()` and then `update()`. However, if you need to
    get the reduced gradients to perform certain transformation, such as in
    gradient clipping, then you may want to manually call `allreduce_grads()`
    and `update()` separately.
    """
    for i, param in enumerate(self._params):
        if param.grad_req != 'null':
            hvd.allreduce_(param.list_data()[0], average=True,
                           name=str(i), priority=-i)
            for j in range(1, len(param.list_data())):
                param.list_data()[0].copyto(param.list_data()[j])
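A hedged sketch of how a method like `allreduce_params` is typically driven from a local-SGD style loop: several purely local updates, then one parameter-averaging round. The `net`, `loss_fn`, `train_data`, `batch_size`, `sync_interval`, and `trainer` names are assumptions for illustration, not the repository's actual training script:

# Illustrative local-SGD loop (assumed names; not the repository's script).
import mxnet as mx

for i, (data, label) in enumerate(train_data):
    with mx.autograd.record():
        loss = loss_fn(net(data), label)
    loss.backward()
    trainer.step(batch_size)            # assumed to be a purely local update here
    if (i + 1) % sync_interval == 0:
        trainer.allreduce_params()      # average parameters across workers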
def test_horovod_allreduce_inplace(self):
    """Test that the allreduce correctly sums 1D, 2D, 3D tensors."""
    hvd.init()
    size = hvd.size()
    dtypes = self.filter_supported_types(
        ['int32', 'int64', 'float32', 'float64'])
    dims = [1, 2, 3]
    ctx = self._current_context()
    count = 0
    shapes = [(), (17), (17, 17), (17, 17, 17)]
    for dtype, dim in itertools.product(dtypes, dims):
        mx.random.seed(1234, ctx=ctx)
        tensor = mx.nd.random.uniform(-100, 100, shape=shapes[dim], ctx=ctx)
        tensor = tensor.astype(dtype)
        multiplied = tensor * size
        hvd.allreduce_(tensor, average=False, name=str(count))
        max_difference = mx.nd.max(mx.nd.subtract(tensor, multiplied))
        count += 1
        # Threshold for floating point equality depends on number of
        # ranks, since we're comparing against precise multiplication.
        if size <= 3 or dtype in ['int32', 'int64']:
            threshold = 0
        elif size < 10:
            threshold = 1e-4
        elif size < 15:
            threshold = 5e-4
        else:
            break
        if max_difference > threshold:
            print("self", count, dtype, dim, max_difference, threshold)
            print("tensor", hvd.rank(), tensor)
            print("multiplied", hvd.rank(), multiplied)
        assert max_difference <= threshold, \
            'hvd.allreduce produces incorrect results for self'
def train(ctx):
    if isinstance(ctx, mx.Context):
        ctx = [ctx]
    if opt.resume_params == '':
        net.initialize(mx.init.MSRAPrelu(), ctx=ctx)
    if opt.no_wd:
        for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
            v.wd_mult = 0.0

    hvd.broadcast_parameters(net.collect_params(), root_rank=0)

    # trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params)
    # trainer = hvd.DistributedTrainer(
    #     net.collect_params(),
    #     optimizer,
    #     optimizer_params)
    if opt.trainer == 'sgd':
        trainer = SGDTrainer(
            net.collect_params(), optimizer, optimizer_params)
    elif opt.trainer == 'efsgd':
        trainer = EFSGDTrainerV1(
            net.collect_params(), 'EFSGDV1', optimizer_params,
            input_sparse_ratio=1./opt.input_sparse_1,
            output_sparse_ratio=1./opt.output_sparse_1,
            layer_sparse_ratio=1./opt.layer_sparse_1)
    elif opt.trainer == 'qsparselocalsgd':
        trainer = QSparseLocalSGDTrainerV1(
            net.collect_params(), optimizer, optimizer_params,
            input_sparse_ratio=1./opt.input_sparse_1,
            output_sparse_ratio=1./opt.output_sparse_1,
            layer_sparse_ratio=1./opt.layer_sparse_1,
            local_sgd_interval=opt.local_sgd_interval)
    elif opt.trainer == 'ersgd':
        trainer = ERSGDTrainerV2(
            net.collect_params(), optimizer, optimizer_params,
            input_sparse_ratio=1./opt.input_sparse_1,
            output_sparse_ratio=1./opt.output_sparse_1,
            layer_sparse_ratio=1./opt.layer_sparse_1)
    elif opt.trainer == 'partiallocalsgd':
        trainer = PartialLocalSGDTrainerV1(
            net.collect_params(), optimizer, optimizer_params,
            input_sparse_ratio=1./opt.input_sparse_1,
            output_sparse_ratio=1./opt.output_sparse_1,
            layer_sparse_ratio=1./opt.layer_sparse_1,
            local_sgd_interval=opt.local_sgd_interval)
    elif opt.trainer == 'ersgd2':
        trainer = ERSGD2TrainerV2(
            net.collect_params(), optimizer, optimizer_params,
            input_sparse_ratio_1=1./opt.input_sparse_1,
            output_sparse_ratio_1=1./opt.output_sparse_1,
            layer_sparse_ratio_1=1./opt.layer_sparse_1,
            input_sparse_ratio_2=1./opt.input_sparse_2,
            output_sparse_ratio_2=1./opt.output_sparse_2,
            layer_sparse_ratio_2=1./opt.layer_sparse_2,
            local_sgd_interval=opt.local_sgd_interval)
    else:
        trainer = SGDTrainer(
            net.collect_params(), optimizer, optimizer_params)

    if opt.resume_states != '':
        trainer.load_states(opt.resume_states)

    if opt.label_smoothing or opt.mixup:
        sparse_label_loss = False
    else:
        sparse_label_loss = True

    if distillation:
        L = gcv.loss.DistillationSoftmaxCrossEntropyLoss(temperature=opt.temperature,
                                                         hard_weight=opt.hard_weight,
                                                         sparse_label=sparse_label_loss)
    else:
        L = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=sparse_label_loss)

    best_val_score = 1

    for epoch in range(opt.resume_epoch, opt.num_epochs):
        tic = time.time()
        if opt.use_rec:
            train_data.reset()
        # train_metric.reset()
        train_loss = 0
        btic = time.time()

        # test speed
        if opt.test_speed > 0:
            n_repeats = opt.test_speed
        elif opt.test_speed == 0:
            n_repeats = 1
        else:
            n_repeats = 0

        for i, batch in enumerate(train_data):
            # test speed
            if n_repeats == 0 and not (i+1)%opt.log_interval:
                print('[Epoch %d] # batch: %d'%(epoch, i))
                continue
            data, label = batch_fn(batch, ctx)
            for j in range(n_repeats):
                if opt.mixup:
                    lam = np.random.beta(opt.mixup_alpha, opt.mixup_alpha)
                    if epoch >= opt.num_epochs - opt.mixup_off_epoch:
                        lam = 1
                    data = [lam*X + (1-lam)*X[::-1] for X in data]
                    if opt.label_smoothing:
                        eta = 0.1
                    else:
                        eta = 0.0
                    label = mixup_transform(label, classes, lam, eta)
                elif opt.label_smoothing:
                    hard_label = label
                    label = smooth(label, classes)

                if distillation:
                    teacher_prob = [nd.softmax(teacher(X.astype(opt.dtype, copy=False)) / opt.temperature)
                                    for X in data]

                with ag.record():
                    outputs = [net(X.astype(opt.dtype, copy=False)) for X in data]
                    if distillation:
                        loss = [L(yhat.astype('float32', copy=False),
                                  y.astype('float32', copy=False),
                                  p.astype('float32', copy=False))
                                for yhat, y, p in zip(outputs, label, teacher_prob)]
                    else:
                        loss = [L(yhat, y.astype(opt.dtype, copy=False))
                                for yhat, y in zip(outputs, label)]
                for l in loss:
                    l.backward()
                trainer.step(batch_size)

                # if opt.mixup:
                #     output_softmax = [nd.SoftmaxActivation(out.astype('float32', copy=False)) \
                #                       for out in outputs]
                #     train_metric.update(label, output_softmax)
                # else:
                #     if opt.label_smoothing:
                #         train_metric.update(hard_label, outputs)
                #     else:
                #         train_metric.update(label, outputs)

                step_loss = sum([l.sum().asscalar() for l in loss])
                train_loss += step_loss

                if opt.log_interval and not (i+j+1)%opt.log_interval:
                    # train_metric_name, train_metric_score = train_metric.get()
                    if hvd.rank() == 0:
                        # logger.info('Epoch[%d] Batch[%d] Speed: %f samples/sec %s=%f lr=%f comm=%f'%(
                        #     epoch, i, batch_size*hvd.size()*opt.log_interval/(time.time()-btic),
                        #     train_metric_name, train_metric_score, trainer.learning_rate, trainer._comm_counter/1e6))
                        # print('Epoch[%d] Batch[%d] Speed: %f samples/sec %s=%f lr=%f comm=%f'%(
                        #     epoch, i, batch_size*hvd.size()*opt.log_interval/(time.time()-btic),
                        #     train_metric_name, train_metric_score, trainer.learning_rate, trainer._comm_counter/1e6))
                        print('Epoch[%d] Batch[%d] Speed: %f samples/sec %s=%f lr=%f comm=%f'%(
                            epoch, i, batch_size*hvd.size()*opt.log_interval/(time.time()-btic),
                            'loss', step_loss/batch_size, trainer.learning_rate, trainer._comm_counter/1e6))
                    btic = time.time()

        mx.nd.waitall()
        toc = time.time()

        if n_repeats == 0:
            allreduce_array_nd = mx.nd.array([i])
            hvd.allreduce_(allreduce_array_nd, name='allreduce_array', average=True)
            mx.nd.waitall()
            print('[Epoch %d] # total batch: %d'%(epoch, i))
            continue

        train_metric_name, train_metric_score = train_metric.get()
        throughput = int(batch_size * i /(toc - tic) * hvd.size())
        train_loss /= (batch_size * i)

        if opt.trainer == 'ersgd' or opt.trainer == 'qsparselocalsgd' \
                or opt.trainer == 'ersgd2' or opt.trainer == 'partiallocalsgd':
            allreduce_for_val = True
        else:
            allreduce_for_val = False

        if allreduce_for_val:
            trainer.pre_test()
        # err_train_tic = time.time()
        # err_top1_train, err_top5_train = test(ctx, train_data, val=False)
        err_val_tic = time.time()
        err_top1_val, err_top5_val = test(ctx, val_data, val=True)
        err_val_toc = time.time()
        if allreduce_for_val:
            trainer.post_test()
        mx.nd.waitall()

        # allreduce the results
        allreduce_array_nd = mx.nd.array([train_loss, err_top1_val, err_top5_val])
        hvd.allreduce_(allreduce_array_nd, name='allreduce_array', average=True)
        allreduce_array_np = allreduce_array_nd.asnumpy()
        train_loss = np.asscalar(allreduce_array_np[0])
        err_top1_val = np.asscalar(allreduce_array_np[1])
        err_top5_val = np.asscalar(allreduce_array_np[2])

        if hvd.rank() == 0:
            # logger.info('[Epoch %d] training: %s=%f'%(epoch, train_metric_name, train_metric_score))
            logger.info('[Epoch %d] training: loss=%f'%(epoch, train_loss))
            logger.info('[Epoch %d] speed: %d samples/sec training-time: %f comm: %f'%(
                epoch, throughput, toc-tic, trainer._comm_counter/1e6))
            logger.info('[Epoch %d] validation: err-top1=%f err-top5=%f err-time=%f'%(
                epoch, err_top1_val, err_top5_val, err_val_toc - err_val_tic))
        trainer._comm_counter = 0

        if err_top1_val < best_val_score:
            best_val_score = err_top1_val
            # if hvd.local_rank() == 0:
            #     net.save_parameters('%s/%.4f-imagenet-%s-%d-best.params'%(save_dir, best_val_score, model_name, epoch))
            #     trainer.save_states('%s/%.4f-imagenet-%s-%d-best.states'%(save_dir, best_val_score, model_name, epoch))

        if save_frequency and save_dir and (epoch + 1) % save_frequency == 0:
            if hvd.local_rank() == 0:
                net.save_parameters('%s/imagenet-%s-%d.params'%(save_dir, model_name, epoch))
                trainer.save_states('%s/imagenet-%s-%d.states'%(save_dir, model_name, epoch))
def test_gluon_trainer(self):
    """Test using horovod allreduce in MXNet Gluon trainer."""
    from mxnet import gluon
    from mxnet.gluon import Block, nn, HybridBlock

    hvd.init()
    rank = hvd.rank()
    np.random.seed(1000 + 10 * rank)
    mx.random.seed(1000 + 10 * rank)
    ctx = mx.gpu(rank)

    def gen_random_dataset(batch_size=64, dim=32, min_len=20, max_len=100,
                           size=1000):
        for _ in range(size):
            length = np.random.randint(min_len, max_len + 1)
            rand_src = mx.nd.random.normal(0, 1, (length, dim))
            rand_dst = mx.nd.random.normal(0, 1, (length, dim))
            yield rand_src, rand_dst

    class SimpleNet(HybridBlock):
        def __init__(self, layer_num=6, **kwargs):
            super(SimpleNet, self).__init__(**kwargs)
            self._layer_num = layer_num
            with self.name_scope():
                self.ln_l = nn.HybridSequential()
                self.dense_l = nn.HybridSequential()
                for i in range(layer_num):
                    self.dense_l.add(
                        nn.Dense(units=32 + layer_num - 1 - i, flatten=False))
                    self.ln_l.add(nn.LayerNorm())

        def hybrid_forward(self, F, data):
            """
            Parameters
            ----------
            data : Shape (batch_size, seq_len, fea_dim)

            Returns
            -------
            out : Shape (batch_size, seq_len, fea_dim)
            """
            for i in range(self._layer_num):
                data = self.ln_l[i](data)
                data = self.dense_l[i](data)
            return data

    net = SimpleNet()
    net.initialize(ctx=ctx)
    net.hybridize(static_alloc=True)
    params = net.collect_params()

    cnt = 0
    lr = 1E-4
    trainer = gluon.Trainer(params, 'adam', {'learning_rate': lr},
                            update_on_kvstore=False)
    data_gen = gen_random_dataset()
    for (src_data, dst_data) in data_gen:
        src_data = src_data.as_in_context(ctx).astype(np.float32)
        dst_data = dst_data.as_in_context(ctx).astype(np.float32)
        with mx.autograd.record():
            pred = net(src_data)
            loss = mx.nd.abs(pred - dst_data).mean()
            loss.backward()
        # Begin to update the parameter
        trainer.step(1.0)
        cnt += 1
        l = loss.asscalar()
        if cnt >= 10:
            for key, param in params.items():
                hvd.allreduce_(param.list_data()[0])
            cnt = 0
def pushpull(self, key, value, out=None, priority=0):
    """ Performs allreduce on a single tensor or a list of tensor objects

    This function performs in-place summation of the input tensor over
    all the processes.

    The name `pushpull` is a generic term. In Horovod, its action is
    implemented via ring allreduce. Each operation is identified by the
    'key'; if `key` is not provided, an incremented auto-generated name
    is used. The tensor type and shape must be the same on all processes
    for a given name. The reduction will not start until all processes
    are ready to send and receive the tensor.

    Parameters
    ----------
    key : str, int, or sequence of str or int
        Keys used to uniquely tag an operation.
    value : NDArray
        Tensor value on one process to be summed. If `out` is not
        specified, the `value` will be modified in-place.
    out : NDArray
        Output tensor after allreduce. If not specified, the input tensor
        `value` will be modified in-place.
    priority : int, optional
        The priority of the operation. Higher priority operations are
        likely to be executed before other actions.

    Examples
    --------
    >>> # perform in-place allreduce on tensor a
    >>> shape = (2, 3)
    >>> nworker = kv.num_workers # assume there are 8 processes
    >>> a = mx.nd.ones(shape)
    >>> kv.pushpull('1', a)
    >>> print(a.asnumpy())
    [[ 8.  8.  8.]
     [ 8.  8.  8.]]

    >>> # perform allreduce on tensor a and output to b
    >>> a = mx.nd.ones(shape)
    >>> kv.pushpull('2', a, out=b)
    >>> print(b.asnumpy())
    [[ 8.  8.  8.]
     [ 8.  8.  8.]]
    """
    import horovod.mxnet as hvd

    if out is None:
        value = value if isinstance(value, list) else [value]
        for v in value:
            hvd.allreduce_(v, average=False, name=str(key),
                           priority=priority)
    else:
        out = out if isinstance(out, list) else [out]
        value = value if isinstance(value, list) else [value]
        for o, v in zip(out, value):
            o[:] = hvd.allreduce(v, average=False, name=str(key),
                                 priority=priority)
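The two Horovod calls used in `pushpull` differ only in where the result lands: `hvd.allreduce_` sums into the input tensor in place, while `hvd.allreduce` returns a fresh array, which is why the `out` branch copies the result with `o[:] = ...`. A small sketch of the contrast; it assumes `hvd.init()` has been called under a Horovod launcher:

# Sketch: in-place vs. out-of-place allreduce in horovod.mxnet.
import mxnet as mx
import horovod.mxnet as hvd

hvd.init()
x = mx.nd.ones((2, 3))
y = hvd.allreduce(x, average=False, name='out_of_place')   # returns a new array
hvd.allreduce_(x, average=False, name='in_place')          # modifies x itself
# After both calls, x and y each hold the elementwise sum over all workers.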
def backward_sample(self, total_feature, label):
    this_rank_classes = int(self.memory_bank.num_sample)
    local_index, unique_sorted_global_label = self.memory_bank.sample(label)

    # Get local index
    _mapping_dict = {}
    local_sampled_class = local_index + self.rank * self.memory_bank.num_local
    global_label_set = set(unique_sorted_global_label)
    for idx, absolute_label in enumerate(local_sampled_class):
        if absolute_label in global_label_set:
            _mapping_dict[absolute_label] = idx + self.rank * self.memory_bank.num_sample

    label_list = list(label.asnumpy())
    mapping_label = []
    for i in range(len(label_list)):
        absolute_label = label_list[i]
        if absolute_label in _mapping_dict.keys():
            mapping_label.append(_mapping_dict[absolute_label])
        else:
            mapping_label.append(-1)
    mapping_label = nd.array(mapping_label, dtype=np.int32)

    # Get weight
    local_index = nd.array(local_index)
    local_index = self.get_ndarray2(self.gpu, "local_index", local_index)
    sample_weight, sample_weight_mom = self.memory_bank.get(local_index)

    # Sync to gpu
    if self.memory_bank.gpu:
        _data = self.get_ndarray2(self.gpu, "data_%d" % self.rank, total_feature)
        _weight = self.get_ndarray2(self.gpu, 'weight_%d' % self.rank, sample_weight)
        _weight_mom = self.get_ndarray2(self.gpu, 'weight_mom_%d' % self.rank,
                                        sample_weight_mom)
    else:
        _data = self.get_ndarray2(self.gpu, "data_%d" % self.rank, total_feature)
        _weight = self.get_ndarray2(self.gpu, 'weight_%d' % self.rank, sample_weight)
        _weight_mom = self.get_ndarray2(self.gpu, 'weight_mom_%d' % self.rank,
                                        sample_weight_mom)

    # Attach grad
    _data.attach_grad()
    _weight.attach_grad()

    # Convert label
    _label = self.get_ndarray2(self.gpu, 'mapping_label_%d' % self.rank, mapping_label)
    _label = _label - int(self.rank * self.memory_bank.num_sample)
    _fc7, _one_hot = self.fc7_model.forward(_data, _weight,
                                            mapping_label=_label,
                                            depth=this_rank_classes)

    # Sync max
    max_fc7 = nd.max(_fc7, axis=1, keepdims=True)
    max_fc7 = nd.reshape(max_fc7, -1)
    total_max_fc7 = self.get_ndarray(context=self.gpu, name='total_max_fc7',
                                     shape=(max_fc7.shape[0], self.size),
                                     dtype='float32')
    total_max_fc7[:] = 0
    total_max_fc7[:, self.rank] = max_fc7
    hvd.allreduce_(total_max_fc7, average=False)

    global_max_fc7 = self.get_ndarray(context=self.gpu, name='global_max_fc7',
                                      shape=(max_fc7.shape[0], 1),
                                      dtype='float32')
    nd.max(total_max_fc7, axis=1, keepdims=True, out=global_max_fc7)

    # Calculate exp(logits)
    _fc7_grad = nd.broadcast_sub(_fc7, global_max_fc7)
    _fc7_grad = nd.exp(_fc7_grad)

    # Calculate sum
    sum_fc7 = nd.sum(_fc7_grad, axis=1, keepdims=True)
    global_sum_fc7 = hvd.allreduce(sum_fc7, average=False)

    # Calculate grad
    _fc7_grad = nd.broadcast_div(_fc7_grad, global_sum_fc7)

    # Calculate loss
    tmp = _fc7_grad * _one_hot
    tmp = nd.sum(tmp, axis=1, keepdims=True)
    tmp = self.get_ndarray2(self.gpu, 'ctx_loss', tmp)
    tmp = hvd.allreduce(tmp, average=False)
    global_loss = -nd.mean(nd.log(tmp + 1e-30))

    _fc7_grad = _fc7_grad - _one_hot

    # Backward
    _fc7.backward(out_grad=_fc7_grad)

    # Update center
    _weight_grad = _weight.grad
    self.memory_optimizer.update(weight=_weight, grad=_weight_grad,
                                 state=_weight_mom,
                                 learning_rate=self.memory_lr)

    if self.memory_bank.gpu:
        self.memory_bank.set(index=local_index, updated_weight=_weight,
                             updated_weight_mom=_weight_mom)
    else:
        self.memory_bank.set(index=local_index,
                             updated_weight=self.get_ndarray2(
                                 mx.cpu(), "cpu_weight_%d" % self.rank, _weight),
                             updated_weight_mom=self.get_ndarray2(
                                 mx.cpu(), "cpu_weight_mom_%d" % self.rank, _weight_mom))
    return _data.grad, global_loss
def backward(self, total_feature, label):
    memory_bank = self.memory_bank
    assert memory_bank.num_local == memory_bank.num_sample, "pass"

    _data = self.get_ndarray2(self.gpu, "data_%d" % self.rank, total_feature)

    # Attach grad
    _data.attach_grad()
    memory_bank.weight.attach_grad()

    # Convert label
    _label = self.get_ndarray2(self.gpu, 'label_%d' % self.rank, label)
    _label = _label - int(self.rank * memory_bank.num_local)
    _fc7, _one_hot = self.fc7_model.forward(_data, memory_bank.weight,
                                            mapping_label=_label,
                                            depth=memory_bank.num_local)

    # Sync max
    max_fc7 = nd.max(_fc7, axis=1, keepdims=True)
    max_fc7 = nd.reshape(max_fc7, -1)
    total_max_fc7 = self.get_ndarray(context=self.gpu, name='total_max_fc7',
                                     shape=(max_fc7.shape[0], self.size),
                                     dtype='float32')
    total_max_fc7[:] = 0
    total_max_fc7[:, self.rank] = max_fc7
    hvd.allreduce_(total_max_fc7, average=False)

    global_max_fc7 = self.get_ndarray(context=self.gpu, name='global_max_fc7',
                                      shape=(max_fc7.shape[0], 1),
                                      dtype='float32')
    nd.max(total_max_fc7, axis=1, keepdims=True, out=global_max_fc7)

    # Calculate exp(logits)
    _fc7_grad = nd.broadcast_sub(_fc7, global_max_fc7)
    _fc7_grad = nd.exp(_fc7_grad)

    # Calculate sum
    sum_fc7 = nd.sum(_fc7_grad, axis=1, keepdims=True)
    global_sum_fc7 = hvd.allreduce(sum_fc7, average=False)

    # Calculate prob
    _fc7_grad = nd.broadcast_div(_fc7_grad, global_sum_fc7)

    # Calculate loss
    tmp = _fc7_grad * _one_hot
    tmp = nd.sum(tmp, axis=1, keepdims=True)
    tmp = self.get_ndarray2(self.gpu, 'ctx_loss', tmp)
    tmp = hvd.allreduce(tmp, average=False)
    global_loss = -nd.mean(nd.log(tmp + 1e-30))

    # Calculate fc7 grad
    _fc7_grad = _fc7_grad - _one_hot

    # Backward
    _fc7.backward(out_grad=_fc7_grad)

    # Update center
    _weight_grad = memory_bank.weight.grad
    self.memory_optimizer.update(weight=memory_bank.weight, grad=_weight_grad,
                                 state=memory_bank.weight_mom,
                                 learning_rate=self.memory_lr)
    return _data.grad, global_loss
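Both `backward` variants above compute a softmax whose class dimension is sharded across ranks: each rank exponentiates its local logits after subtracting a globally allreduced row maximum, then normalizes by a globally allreduced row sum. The identity they rely on can be checked in a single process; the sketch below is plain NumPy with two in-memory "shards" standing in for the per-rank tensors and for `hvd.allreduce_`:

# Single-process check of the sharded, numerically stable softmax identity.
import numpy as np

logits = np.random.randn(3, 8).astype(np.float32)   # 3 samples, 8 classes
shard0, shard1 = logits[:, :4], logits[:, 4:]        # each "rank" holds 4 classes

# What the allreduce of per-rank maxima and sums would produce:
global_max = np.maximum(shard0.max(axis=1, keepdims=True),
                        shard1.max(axis=1, keepdims=True))
exp0 = np.exp(shard0 - global_max)
exp1 = np.exp(shard1 - global_max)
global_sum = exp0.sum(axis=1, keepdims=True) + exp1.sum(axis=1, keepdims=True)

# Per-shard probabilities stitched together equal the full-row softmax.
probs_sharded = np.concatenate([exp0 / global_sum, exp1 / global_sum], axis=1)
probs_full = np.exp(logits - logits.max(axis=1, keepdims=True))
probs_full /= probs_full.sum(axis=1, keepdims=True)
assert np.allclose(probs_sharded, probs_full, atol=1e-6)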
def train(data_train, data_eval, model):
    """Training function."""
    # backend specific implementation
    param_dict = model.bert.collect_params()
    if backend == 'horovod':
        hvd.broadcast_parameters(param_dict, root_rank=0)

    mlm_metric = nlp.metric.MaskedAccuracy()
    nsp_metric = nlp.metric.MaskedAccuracy()
    mlm_metric.reset()
    nsp_metric.reset()

    logging.debug('Creating distributed trainer...')
    lr = args.lr
    optim_params = {'learning_rate': lr, 'epsilon': 1e-6, 'wd': 0.01}
    if args.dtype == 'float16':
        optim_params['multi_precision'] = True
    if args.optimizer == 'lamb':
        optim_params['bias_correction'] = True

    dynamic_loss_scale = args.dtype == 'float16'
    if dynamic_loss_scale:
        loss_scale_param = {'scale_window': 2000 / num_workers, 'init_scale': 1}
    else:
        loss_scale_param = None

    # backend specific implementation
    if backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, args.optimizer, optim_params)
    elif backend == 'byteps':
        trainer = bps.DistributedTrainer(param_dict, args.optimizer, optim_params)
    else:
        trainer = mx.gluon.Trainer(param_dict, args.optimizer, optim_params,
                                   update_on_kvstore=False)
    fp16_trainer = FP16Trainer(trainer, dynamic_loss_scale=dynamic_loss_scale,
                               loss_scaler_params=loss_scale_param)

    if args.start_step:
        state_path = os.path.join(args.ckpt_dir,
                                  '%07d.states.%02d'%(args.start_step, local_rank))
        logging.info('Loading trainer state from %s', state_path)
        nlp.utils.load_states(trainer, state_path)

    accumulate = args.accumulate
    num_train_steps = args.num_steps
    warmup_ratio = args.warmup_ratio
    num_warmup_steps = int(num_train_steps * warmup_ratio)
    params = [p for p in param_dict.values() if p.grad_req != 'null']

    # Do not apply weight decay on LayerNorm and bias terms
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    if accumulate > 1:
        for p in params:
            p.grad_req = 'add'

    train_begin_time = time.time()
    begin_time = time.time()
    running_mlm_loss, running_nsp_loss = 0, 0
    local_mlm_loss, local_num_masks = 0, mx.nd.array([0], ctx=ctxs[0])
    running_num_tks = 0
    batch_num = 0
    step_num = args.start_step

    logging.debug('Training started')
    logging.info('Generating the first batch of data, which may take a few minutes ...')

    # create dummy data loader if needed
    parallel_model = DataParallelBERT(model, trainer=fp16_trainer)
    num_ctxes = len(ctxs)
    parallel = nlp.utils.Parallel(num_ctxes if num_ctxes > 1 else 0, parallel_model)

    if backend == 'byteps':
        bps.byteps_declare_tensor("local_num_masks")
        bps.byteps_push_pull(local_num_masks, is_average=False,
                             name="local_num_masks", priority=0)
        logging.debug('Broadcast local_num_masks tensor')
        next_batch = next(iter(get_dummy_dataloader(batch_size, args.max_seq_length,
                                                    args.max_predictions_per_seq)))
        data_list = list(split_and_load(next_batch, ctxs))
        parallel.put(data_list[0])
        parallel.get()
        trainer._init_params()

    while step_num < num_train_steps:
        data_train_iter = iter(data_train)
        end_of_batch = False
        next_data_batch = next(data_train_iter)
        while not end_of_batch:
            data_batch = next_data_batch
            if step_num >= num_train_steps:
                break
            if batch_num % accumulate == 0:
                step_num += 1
                # if accumulate > 1, grad_req is set to 'add', and zero_grad is required
                if accumulate > 1:
                    param_dict.zero_grad()
                # update learning rate
                if step_num <= num_warmup_steps:
                    new_lr = lr * step_num / num_warmup_steps
                else:
                    offset = lr * step_num / num_train_steps
                    new_lr = lr - offset
                trainer.set_learning_rate(new_lr)
                if args.profile:
                    profile(step_num, 10, 14, profile_name=args.profile + str(rank))
                if early_stop and step_num == 10:
                    mx.nd.waitall()
                    exit()

            # load data
            data_list = list(split_and_load(data_batch, ctxs))

            ns_label_list, ns_pred_list = [], []
            mask_label_list, mask_pred_list, mask_weight_list = [], [], []

            with mx.autograd.record():
                num_data = len(data_list)
                for i in range(num_data):
                    parallel.put(data_list[i])
                for _ in range(num_data):
                    (next_sentence_label, classified, masked_id,
                     decoded, masked_weight, ls1, ls2, valid_length,
                     num_masks) = parallel.get()
                    ns_label_list.append(next_sentence_label)
                    ns_pred_list.append(classified)
                    mask_label_list.append(masked_id)
                    mask_pred_list.append(decoded)
                    mask_weight_list.append(masked_weight)
                    local_num_masks += num_masks
                    local_mlm_loss += ls1
                    running_num_tks += valid_length.sum()

            # pre fetch next batch
            try:
                next_data_batch = next(data_train_iter)
            except StopIteration:
                end_of_batch = True

            # update
            if (batch_num + 1) % accumulate == 0:
                running_mlm_loss += local_mlm_loss / local_num_masks
                if backend == 'horovod':
                    hvd.allreduce_(local_num_masks, average=False,
                                   name='local_num_masks')
                elif backend == 'byteps':
                    bps.byteps_push_pull(local_num_masks, is_average=False,
                                         name="local_num_masks", priority=0)
                # because byteps implicitly set scale /= num_workers
                fp16_trainer.step(local_num_masks * num_workers,
                                  max_norm=local_num_masks,
                                  num_ctxs=len(ctxs) * num_workers)
                local_num_masks, local_mlm_loss = 0, 0

            # update metrics
            if args.no_compute_acc:
                for mask_pred_i in mask_pred_list:
                    mask_pred_i.wait_to_read()
            else:
                nsp_metric.update(ns_label_list, ns_pred_list)
                mlm_metric.update(mask_label_list, mask_pred_list, mask_weight_list)

            # logging
            if (step_num + 1) % (args.log_interval) == 0 and (batch_num + 1) % accumulate == 0:
                if args.no_compute_acc:
                    log_noacc(begin_time, running_num_tks, running_mlm_loss,
                              0, step_num, trainer, args.log_interval)
                else:
                    log(begin_time, running_num_tks, running_mlm_loss / accumulate,
                        running_nsp_loss / accumulate, step_num, mlm_metric,
                        nsp_metric, trainer, args.log_interval)
                    mlm_metric.reset_local()
                    nsp_metric.reset_local()
                begin_time = time.time()
                running_mlm_loss = running_nsp_loss = running_num_tks = 0

            # saving checkpoints
            if (step_num + 1) % args.ckpt_interval == 0 and (batch_num + 1) % accumulate == 0:
                # if is_master_node:
                #     save_states(step_num, trainer, args.ckpt_dir, local_rank)
                #     if local_rank == 0:
                #         save_parameters(step_num, model.bert, args.ckpt_dir)
                if (step_num + 1) % args.eval_interval == 0 and data_eval:
                    # eval data is always based on a fixed npz file.
                    dataset_eval = get_pretrain_data_npz(data_eval, batch_size_eval,
                                                         1, False, 1, vocab)
                    evaluate(dataset_eval, model, ctxs, args.log_interval,
                             args.dtype, rank, num_workers)

            batch_num += 1

    # if is_master_node:
    #     save_states(step_num, trainer, args.ckpt_dir, local_rank)
    #     if local_rank == 0:
    #         save_parameters(step_num, model, args.ckpt_dir)
    mx.nd.waitall()
    train_end_time = time.time()
    logging.info('Train cost={:.1f}s'.format(train_end_time - train_begin_time))
def train(epochs, ctx):
    if isinstance(ctx, mx.Context):
        ctx = [ctx]
    net.initialize(mx.init.Xavier(), ctx=ctx)
    # if opt.print_tensor_shape and rank == 0:
    #     print(net)

    train_dataset = gluon.data.vision.CIFAR100(train=True).transform_first(transform_train)
    train_data = gluon.data.DataLoader(
        train_dataset,
        sampler=SplitSampler(len(train_dataset), num_parts=num_workers, part_index=rank),
        batch_size=batch_size,
        last_batch='discard',
        num_workers=opt.num_workers)

    # val_dataset = gluon.data.vision.CIFAR100(train=False).transform_first(transform_test)
    # val_data = gluon.data.DataLoader(
    #     val_dataset,
    #     sampler=SplitSampler(len(val_dataset), num_parts=num_workers, part_index=rank),
    #     batch_size=batch_size, num_workers=opt.num_workers)
    val_data = gluon.data.DataLoader(
        gluon.data.vision.CIFAR100(train=False).transform_first(transform_test),
        batch_size=batch_size, shuffle=False, num_workers=opt.num_workers)

    hvd.broadcast_parameters(net.collect_params(), root_rank=0)

    trainer = QSparseLocalSGDTrainerV1(
        net.collect_params(), 'nag', optimizer_params,
        input_sparse_ratio=1./opt.input_sparse,
        output_sparse_ratio=1./opt.output_sparse,
        layer_sparse_ratio=1./opt.layer_sparse,
        local_sgd_interval=opt.local_sgd_interval)
    # trainer = gluon.Trainer(net.collect_params(), optimizer,
    #                         {'learning_rate': opt.lr, 'wd': opt.wd, 'momentum': opt.momentum})

    metric = mx.metric.Accuracy()
    train_metric = mx.metric.Accuracy()
    loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
    train_history = TrainingHistory(['training-error', 'validation-error'])

    iteration = 0
    lr_decay_count = 0
    best_val_score = 0
    lr = opt.lr

    for epoch in range(epochs):
        tic = time.time()
        train_metric.reset()
        metric.reset()
        train_loss = 0
        num_batch = len(train_data)
        alpha = 1

        if epoch == lr_decay_epoch[lr_decay_count]:
            lr *= lr_decay
            trainer.set_learning_rate(lr)
            lr_decay_count += 1

        for i, batch in enumerate(train_data):
            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
            with ag.record():
                output = [net(X) for X in data]
                loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
            for l in loss:
                l.backward()
            trainer.step(batch_size)
            train_loss += sum([l.sum().asscalar() for l in loss])
            train_metric.update(label, output)
            name, acc = train_metric.get()
            iteration += 1

        mx.nd.waitall()
        toc = time.time()

        train_loss /= batch_size * num_batch
        name, acc = train_metric.get()
        # name, val_acc = test(ctx, val_data)
        trainer.pre_test()
        name, val_acc = test(ctx, val_data)
        trainer.post_test()
        train_history.update([1-acc, 1-val_acc])
        # train_history.plot(save_path='%s/%s_history.png'%(plot_path, model_name))

        # allreduce the results
        allreduce_array_nd = mx.nd.array([train_loss, acc, val_acc])
        hvd.allreduce_(allreduce_array_nd, name='allreduce_array', average=True)
        allreduce_array_np = allreduce_array_nd.asnumpy()
        train_loss = np.asscalar(allreduce_array_np[0])
        acc = np.asscalar(allreduce_array_np[1])
        val_acc = np.asscalar(allreduce_array_np[2])

        if val_acc > best_val_score:
            best_val_score = val_acc
            # net.save_parameters('%s/%.4f-cifar-%s-%d-best.params'%(save_dir, best_val_score, model_name, epoch))

        if rank == 0:
            logging.info('[Epoch %d] train=%f val=%f loss=%f comm=%.2f time: %f' %
                         (epoch, acc, val_acc, train_loss, trainer._comm_counter/1e6, toc-tic))
            if save_period and save_dir and (epoch + 1) % save_period == 0:
                net.save_parameters('%s/cifar10-%s-%d.params'%(save_dir, model_name, epoch))

        trainer._comm_counter = 0.

    if rank == 0:
        if save_period and save_dir:
            net.save_parameters('%s/cifar10-%s-%d.params'%(save_dir, model_name, epochs-1))
def train():
    """Training loop for language model."""
    # logging.info(model)
    from_epoch = 0
    model.initialize(mx.init.Xavier(factor_type='out'), ctx=ctx)
    trainer_params = {'learning_rate': args.lr, 'wd': 0, 'eps': args.eps}
    # trainer = gluon.Trainer(model.collect_params(), args.optimizer, trainer_params)
    # fully sync at the beginning
    trainer = DistributedHierLocalHVDTrainer(model.collect_params(), args.optimizer,
                                             trainer_params, local_sgd_interval=0)
    trainer._optimizer._full_sync = True

    if args.from_epoch:
        from_epoch = args.from_epoch
        checkpoint_name = '%s.%s' % (args.save, format(from_epoch - 1, '02d'))
        model.load_parameters(checkpoint_name)
        trainer.load_states('%s.state' % args.save)
        logging.info('Loaded parameters from checkpoint %s' % (checkpoint_name))

    hvd.broadcast_parameters(model.collect_params(), root_rank=0)

    model.hybridize(static_alloc=True, static_shape=True)
    encoder_params = model.encoder.collect_params().values()
    embedding_params = list(model.embedding.collect_params().values())

    step_num = 0
    lr = args.lr
    current_lr = lr
    epoch = from_epoch
    start_epoch_time = time.time()
    start_log_interval_time = time.time()
    nbatch = 0

    while epoch < args.epochs:
        sys.stdout.flush()
        total_L = 0.0
        hidden = model.begin_state(batch_size=args.batch_size, func=mx.nd.zeros, ctx=ctx)
        has_next = True
        train_data_iter = iter(train_data)
        data, target, mask, sample = next(train_data_iter)

        while has_next:
            nbatch += 1
            step_num += 1
            if step_num <= args.warmup_steps:
                new_lr = lr * step_num / args.warmup_steps
                trainer.set_learning_rate(new_lr)
                current_lr = new_lr
            if step_num == args.warmup_steps + 1:
                trainer._local_sgd_interval = args.local_sgd_interval
                trainer._optimizer._full_sync = False
                trainer.init_states()

            hidden = detach(hidden)
            with autograd.record():
                output, hidden, new_target = model(data, target, hidden, sample)
                output = output.reshape((-3, -1))
                new_target = new_target.reshape((-1, ))
                ls = loss(output, new_target) * mask.reshape((-1, ))
                ls = ls / args.batch_size
                ls.backward()

            # prefetch the next batch of data
            try:
                data, target, mask, sample = next(train_data_iter)
            except StopIteration:
                has_next = False

            # rescale embedding grad
            x = embedding_params[0].grad(ctx)
            x[:] *= args.batch_size
            encoder_grad = [p.grad(ctx) for p in encoder_params]
            # perform gradient clipping per ctx
            gluon.utils.clip_global_norm(encoder_grad, args.clip)

            trainer.step(1)

            ls_sum = mx.nd.sum(ls)
            total_L += ls_sum / args.bptt
            # total_L += mx.nd.sum(ls).asscalar() / args.bptt

            if nbatch % args.log_interval == 0:
                hvd.allreduce_(total_L, average=True, name='ls', priority=-9999)
                cur_L = total_L.asscalar() / args.log_interval
                ppl = math.exp(cur_L) if cur_L < 100 else float('inf')
                if rank == 0:
                    logging.info(
                        '[Epoch %d Batch %d] loss %.2f, ppl %.2f, '
                        'throughput %.2f samples/s, lr %.4f'
                        % (epoch, nbatch, cur_L, ppl,
                           train_batch_size * num_workers * args.log_interval
                           / (time.time() - start_log_interval_time),
                           current_lr))
                total_L = 0.0
                start_log_interval_time = time.time()
                sys.stdout.flush()

            if nbatch == num_batches_per_epoch:
                end_epoch_time = time.time()
                logging.info('Epoch %d took %.2f seconds.'
                             % (epoch, end_epoch_time - start_epoch_time))
                mx.nd.waitall()
                checkpoint_name = '%s.%s' % (args.save, format(epoch, '02d'))
                if local_rank == 0:
                    model.save_parameters(checkpoint_name)
                if local_rank == 1:
                    trainer.save_states('%s.state' % args.save)
                nbatch = 0
                start_epoch_time = time.time()
                epoch += 1
                if epoch == args.epochs:
                    break