def test_byteps_push_pull(self):
    """Test that byteps_push_pull correctly handles 1D tensors."""
    dtypes = self.filter_supported_types(['float32'])
    dims = [1]
    ctx = self._current_context()
    count = 100
    shapes = [(), (17,)]
    for dtype, dim in itertools.product(dtypes, dims):
        # MXNet uses gpu_id as part of the seed, so to get identical seeds
        # we must set a context.
        mx.random.seed(10 + 10 * bps.rank(), ctx=ctx)
        tensor = mx.nd.random.uniform(-100, 100, shape=shapes[dim], ctx=ctx)
        tensor = tensor.astype(dtype)
        print("tensor before push_pull:", tensor)
        bps.byteps_declare_tensor("tensor_" + str(count))
        bps.byteps_push_pull(tensor, name="tensor_" + str(count))
        tensor.wait_to_read()
        print("tensor after push_pull:", tensor)
    print('test_byteps_push_pull passed')
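# The tests in this file call two helpers on the test class that are not
# shown here. The definitions below are a minimal sketch of what they are
# assumed to look like (names taken from the call sites; the actual BytePS
# test harness may differ):

def _current_context(self):
    # Pin each worker to its own GPU when one is available, otherwise
    # fall back to CPU.
    if mx.context.num_gpus() > 0:
        return mx.gpu(bps.local_rank())
    return mx.cpu()

def filter_supported_types(self, types):
    # Keep only dtypes the current device/build can exercise; as a
    # sketch, pass everything through unchanged.
    return list(types)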
def test_byteps_push_pull(self):
    """Test that byteps_push_pull correctly sums 1D, 2D, 3D tensors."""
    dtypes = ['float16', 'float32', 'float64']
    dims = [1, 2, 3]
    count = 0
    ctx = self._current_context()
    shapes = [(), (17,), (17, 17), (17, 17, 17)]
    for dtype, dim in itertools.product(dtypes, dims):
        # MXNet uses gpu_id as part of the seed, so to get identical seeds
        # we must set a context.
        mx.random.seed(10 + 10 * bps.rank(), ctx=ctx)
        tensor = mx.nd.random.uniform(-100, 100, shape=shapes[dim], ctx=ctx)
        tensor = tensor.astype(dtype)
        expected = tensor.asnumpy()
        bps.byteps_declare_tensor("tensor_" + str(count))
        bps.byteps_push_pull(tensor, name="tensor_" + str(count))
        tensor.wait_to_read()
        output = tensor.asnumpy()
        assert np.allclose(expected, output)
        count += 1
    print('test_byteps_push_pull passed')
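# These tests expect a full BytePS setup (scheduler, server, workers) so
# that bps.init() can connect. They are typically run under the BytePS
# launcher, e.g. (assumed invocation; the DMLC_* environment variables must
# already point at the scheduler):
#
#   bpslaunch python3 test_mxnet.py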
def test_byteps_push_pull_inplace(self):
    """Test that byteps_push_pull correctly sums 1D, 2D, 3D tensors."""
    size = bps.size()
    dtypes = self.filter_supported_types(
        ['int32', 'int64', 'float32', 'float64'])
    dims = [1, 2, 3]
    ctx = self._current_context()
    count = 200
    shapes = [(), (17,), (17, 17), (17, 17, 17)]
    for dtype, dim in itertools.product(dtypes, dims):
        mx.random.seed(1234, ctx=ctx)
        tensor = mx.nd.random.uniform(-100, 100, shape=shapes[dim], ctx=ctx)
        tensor = tensor.astype(dtype)
        multiplied = tensor.copy()
        bps.byteps_declare_tensor("tensor_" + str(count))
        bps.byteps_push_pull(tensor, name="tensor_" + str(count))
        max_difference = mx.nd.max(mx.nd.subtract(tensor, multiplied))
        count += 1

        # Threshold for floating point equality depends on number of
        # ranks, since we're comparing against precise multiplication.
        if size <= 3 or dtype in ['int32', 'int64']:
            threshold = 0
        elif size < 10:
            threshold = 1e-4
        elif size < 15:
            threshold = 5e-4
        else:
            break

        if max_difference > threshold:
            print("self", count, dtype, dim, max_difference, threshold)
            print("tensor", bps.rank(), tensor)
            print("multiplied", bps.rank(), multiplied)
        assert max_difference <= threshold, \
            'bps.byteps_push_pull produces incorrect results for self'
    print('test_byteps_push_pull_inplace passed')
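# For reference, a minimal standalone push-pull round trip outside the test
# harness (a sketch, assuming bps.init() has already been called). With
# is_average=False the result is the element-wise sum across workers, so on
# n workers a tensor of ones comes back as n * ones:

def push_pull_sum_example():
    tensor = mx.nd.ones((2, 2))
    bps.byteps_declare_tensor("example_sum")
    bps.byteps_push_pull(tensor, name="example_sum", is_average=False)
    tensor.wait_to_read()
    # On n workers, every element of `tensor` is now n.
    return tensor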
def train(epochs, ctx):
    if isinstance(ctx, mx.Context):
        ctx = [ctx]
    net.initialize(mx.init.Xavier(), ctx=ctx)

    train_data = gluon.data.DataLoader(
        gluon.data.vision.CIFAR100(train=True).shard(
            nworker, rank).transform_first(transform_train),
        batch_size=batch_size, shuffle=True,
        last_batch='discard', num_workers=num_workers)

    val_data = gluon.data.DataLoader(
        gluon.data.vision.CIFAR100(train=False).shard(
            nworker, rank).transform_first(transform_test),
        batch_size=batch_size, shuffle=False, num_workers=num_workers)

    params = net.collect_params()

    compression_params = {
        "compressor": opt.compressor,
        "ef": opt.ef,
        "momentum": opt.compress_momentum,
        "scaling": opt.onebit_scaling,
        "k": opt.k,
        "fp16": opt.fp16_pushpull,
    }

    optimizer_params = {
        'lr_scheduler': lr_scheduler,
        'wd': opt.wd,
        'momentum': opt.momentum,
    }

    trainer = bps.DistributedTrainer(params, optimizer, optimizer_params,
                                     compression_params=compression_params)
    metric = mx.metric.Accuracy()
    train_metric = mx.metric.Accuracy()
    loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

    iteration = 0
    best_val_score = 0
    bps.byteps_declare_tensor("acc")
    for epoch in range(epochs):
        tic = time.time()
        train_metric.reset()
        metric.reset()
        train_loss = 0
        num_batch = len(train_data)

        for i, batch in enumerate(train_data):
            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx,
                                              batch_axis=0)
            label = gluon.utils.split_and_load(batch[1], ctx_list=ctx,
                                               batch_axis=0)

            with ag.record():
                output = [net(X) for X in data]
                loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]

            for l in loss:
                l.backward()
            trainer.step(batch_size)

            train_loss += sum([l.sum().asscalar() for l in loss])
            train_metric.update(label, output)
            name, train_acc = train_metric.get()
            iteration += 1

        train_loss /= batch_size * num_batch
        name, train_acc = train_metric.get()
        throughput = int(batch_size * nworker * i / (time.time() - tic))

        logger.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f lr=%f' %
                    (epoch, throughput, time.time() - tic,
                     trainer.learning_rate))

        name, val_acc = test(ctx, val_data)
        acc = mx.nd.array([train_acc, val_acc], ctx=ctx[0])
        bps.byteps_push_pull(acc, name="acc", is_average=False)
        acc /= bps.size()
        train_acc, val_acc = acc[0].asscalar(), acc[1].asscalar()
        if bps.rank() == 0:
            logger.info('[Epoch %d] training: %s=%f' %
                        (epoch, name, train_acc))
            logger.info('[Epoch %d] validation: %s=%f' %
                        (epoch, name, val_acc))

        if val_acc > best_val_score:
            best_val_score = val_acc
            net.save_parameters('%s/%.4f-cifar-%s-%d-best.params' %
                                (save_dir, best_val_score, model_name, epoch))

        if save_period and save_dir and (epoch + 1) % save_period == 0:
            net.save_parameters('%s/cifar100-%s-%d.params' %
                                (save_dir, model_name, epoch))

    if save_period and save_dir:
        net.save_parameters('%s/cifar100-%s-%d.params' %
                            (save_dir, model_name, epochs - 1))
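# train() above references transform_train / transform_test defined elsewhere
# in the script. A minimal sketch of typical CIFAR-100 transforms using only
# mxnet.gluon.data.vision.transforms (assumed: the actual script may use a
# richer augmentation pipeline; the mean/std values are the commonly used
# CIFAR-100 channel statistics):
from mxnet.gluon.data.vision import transforms

transform_train = transforms.Compose([
    transforms.RandomFlipLeftRight(),
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4865, 0.4409),
                         (0.2673, 0.2564, 0.2762)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4865, 0.4409),
                         (0.2673, 0.2564, 0.2762)),
])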
def train(data_train, data_eval, model):
    """Training function."""
    # backend specific implementation
    param_dict = model.bert.collect_params()
    if backend == 'horovod':
        hvd.broadcast_parameters(param_dict, root_rank=0)

    mlm_metric = nlp.metric.MaskedAccuracy()
    nsp_metric = nlp.metric.MaskedAccuracy()
    mlm_metric.reset()
    nsp_metric.reset()

    logging.debug('Creating distributed trainer...')
    lr = args.lr
    optim_params = {'learning_rate': lr, 'epsilon': 1e-6, 'wd': 0.01}
    if args.dtype == 'float16':
        optim_params['multi_precision'] = True
    if args.optimizer == 'lamb':
        optim_params['bias_correction'] = True

    dynamic_loss_scale = args.dtype == 'float16'
    if dynamic_loss_scale:
        loss_scale_param = {'scale_window': 2000 / num_workers,
                            'init_scale': 1}
    else:
        loss_scale_param = None

    # backend specific implementation
    if backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, args.optimizer,
                                         optim_params)
    elif backend == 'byteps':
        trainer = bps.DistributedTrainer(param_dict, args.optimizer,
                                         optim_params)
    else:
        trainer = mx.gluon.Trainer(param_dict, args.optimizer, optim_params,
                                   update_on_kvstore=False)
    fp16_trainer = FP16Trainer(trainer, dynamic_loss_scale=dynamic_loss_scale,
                               loss_scaler_params=loss_scale_param)

    if args.start_step:
        state_path = os.path.join(args.ckpt_dir, '%07d.states.%02d' %
                                  (args.start_step, local_rank))
        logging.info('Loading trainer state from %s', state_path)
        nlp.utils.load_states(trainer, state_path)

    accumulate = args.accumulate
    num_train_steps = args.num_steps
    warmup_ratio = args.warmup_ratio
    num_warmup_steps = int(num_train_steps * warmup_ratio)
    params = [p for p in param_dict.values() if p.grad_req != 'null']

    # Do not apply weight decay on LayerNorm and bias terms
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    if accumulate > 1:
        for p in params:
            p.grad_req = 'add'

    train_begin_time = time.time()
    begin_time = time.time()
    running_mlm_loss, running_nsp_loss = 0, 0
    local_mlm_loss, local_num_masks = 0, mx.nd.array([0], ctx=ctxs[0])
    running_num_tks = 0
    batch_num = 0
    step_num = args.start_step

    logging.debug('Training started')
    logging.info('Generating the first batch of data, which may take a few minutes ...')

    # create dummy data loader if needed
    parallel_model = DataParallelBERT(model, trainer=fp16_trainer)
    num_ctxes = len(ctxs)
    parallel = nlp.utils.Parallel(num_ctxes if num_ctxes > 1 else 0,
                                  parallel_model)

    if backend == 'byteps':
        bps.byteps_declare_tensor("local_num_masks")
        bps.byteps_push_pull(local_num_masks, is_average=False,
                             name="local_num_masks", priority=0)
        logging.debug('Broadcast local_num_masks tensor')
        next_batch = next(iter(get_dummy_dataloader(
            batch_size, args.max_seq_length, args.max_predictions_per_seq)))
        data_list = list(split_and_load(next_batch, ctxs))
        parallel.put(data_list[0])
        parallel.get()
        trainer._init_params()

    while step_num < num_train_steps:
        data_train_iter = iter(data_train)
        end_of_batch = False
        next_data_batch = next(data_train_iter)
        while not end_of_batch:
            data_batch = next_data_batch
            if step_num >= num_train_steps:
                break
            if batch_num % accumulate == 0:
                step_num += 1
                # if accumulate > 1, grad_req is set to 'add', and
                # zero_grad is required
                if accumulate > 1:
                    param_dict.zero_grad()
                # update learning rate: linear warmup, then linear decay
                if step_num <= num_warmup_steps:
                    new_lr = lr * step_num / num_warmup_steps
                else:
                    offset = lr * step_num / num_train_steps
                    new_lr = lr - offset
                trainer.set_learning_rate(new_lr)
                if args.profile:
                    profile(step_num, 10, 14,
                            profile_name=args.profile + str(rank))
                if early_stop and step_num == 10:
                    mx.nd.waitall()
                    exit()

            # load data
            data_list = list(split_and_load(data_batch, ctxs))

            ns_label_list, ns_pred_list = [], []
            mask_label_list, mask_pred_list, mask_weight_list = [], [], []

            with mx.autograd.record():
                num_data = len(data_list)
                for i in range(num_data):
                    parallel.put(data_list[i])
                for _ in range(num_data):
                    (next_sentence_label, classified, masked_id,
                     decoded, masked_weight, ls1, ls2, valid_length,
                     num_masks) = parallel.get()
                    ns_label_list.append(next_sentence_label)
                    ns_pred_list.append(classified)
                    mask_label_list.append(masked_id)
                    mask_pred_list.append(decoded)
                    mask_weight_list.append(masked_weight)
                    local_num_masks += num_masks
                    local_mlm_loss += ls1
                    running_num_tks += valid_length.sum()

            # pre-fetch next batch
            try:
                next_data_batch = next(data_train_iter)
            except StopIteration:
                end_of_batch = True

            # update
            if (batch_num + 1) % accumulate == 0:
                running_mlm_loss += local_mlm_loss / local_num_masks
                if backend == 'horovod':
                    hvd.allreduce_(local_num_masks, average=False,
                                   name='local_num_masks')
                elif backend == 'byteps':
                    bps.byteps_push_pull(local_num_masks, is_average=False,
                                         name="local_num_masks", priority=0)
                # because byteps implicitly sets scale /= num_workers
                fp16_trainer.step(local_num_masks * num_workers,
                                  max_norm=local_num_masks,
                                  num_ctxs=len(ctxs) * num_workers)
                local_num_masks, local_mlm_loss = 0, 0

            # update metrics
            if args.no_compute_acc:
                for mask_pred_i in mask_pred_list:
                    mask_pred_i.wait_to_read()
            else:
                nsp_metric.update(ns_label_list, ns_pred_list)
                mlm_metric.update(mask_label_list, mask_pred_list,
                                  mask_weight_list)

            # logging
            if (step_num + 1) % args.log_interval == 0 and \
                    (batch_num + 1) % accumulate == 0:
                if args.no_compute_acc:
                    log_noacc(begin_time, running_num_tks, running_mlm_loss,
                              0, step_num, trainer, args.log_interval)
                else:
                    log(begin_time, running_num_tks,
                        running_mlm_loss / accumulate,
                        running_nsp_loss / accumulate, step_num,
                        mlm_metric, nsp_metric, trainer, args.log_interval)
                    mlm_metric.reset_local()
                    nsp_metric.reset_local()
                begin_time = time.time()
                running_mlm_loss = running_nsp_loss = running_num_tks = 0

            # saving checkpoints (disabled; kept for reference)
            if (step_num + 1) % args.ckpt_interval == 0 and \
                    (batch_num + 1) % accumulate == 0:
                # if is_master_node:
                #     save_states(step_num, trainer, args.ckpt_dir, local_rank)
                # if local_rank == 0:
                #     save_parameters(step_num, model.bert, args.ckpt_dir)
                pass

            if (step_num + 1) % args.eval_interval == 0 and data_eval:
                # eval data is always based on a fixed npz file.
                dataset_eval = get_pretrain_data_npz(data_eval,
                                                     batch_size_eval,
                                                     1, False, 1, vocab)
                evaluate(dataset_eval, model, ctxs, args.log_interval,
                         args.dtype, rank, num_workers)

            batch_num += 1

    # if is_master_node:
    #     save_states(step_num, trainer, args.ckpt_dir, local_rank)
    # if local_rank == 0:
    #     save_parameters(step_num, model, args.ckpt_dir)
    mx.nd.waitall()
    train_end_time = time.time()
    logging.info('Train cost={:.1f}s'.format(train_end_time -
                                             train_begin_time))
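# split_and_load() used in train() above shards each field of a batch across
# the available contexts. A minimal sketch (assumed; gluon-nlp's pretraining
# scripts ship their own version):
def split_and_load(arrs, ctxs):
    """Split each NDArray along axis 0 and place one shard per context,
    regrouped so each element holds all fields for one context."""
    loaded = [mx.gluon.utils.split_and_load(arr, ctxs, even_split=False)
              for arr in arrs]
    return list(zip(*loaded))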
            trainer.step(args.batch_size)
            metric.update([label], [output])

            if i % 100 == 0:
                name, acc = metric.get()
                logger.info('[Epoch %d Batch %d] Training: %s=%f' %
                            (epoch, i, name, acc))

        elapsed = time.time() - tic
        total_time += elapsed
        speed = train_size * num_workers / elapsed
        logger.info('Epoch[%d]\tSpeed=%.2f samples/s\tTime cost=%f',
                    epoch, speed, elapsed)

        # Evaluate model accuracy
        _, train_acc = metric.get()
        name, val_acc = evaluate(model, val_data, context)
        acc = mx.nd.array([train_acc, val_acc], ctx=context)
        bps.byteps_push_pull(acc, name="acc", is_average=False)
        acc /= bps.size()
        train_acc, val_acc = acc[0].asscalar(), acc[1].asscalar()
        if bps.rank() == 0:
            logger.info('Epoch[%d]\tTrain: %s=%f\tValidation: %s=%f',
                        epoch, name, train_acc, name, val_acc)

        if bps.rank() == 0 and epoch == args.epochs - 1:
            assert val_acc > 0.96, \
                "Achieved accuracy (%f) is lower than expected (0.96)" % val_acc

    logger.info("total time=%.2f", total_time)
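# evaluate() used above computes validation accuracy on a single context.
# A minimal sketch (assumed; the actual script defines its own):
def evaluate(model, data_iter, context):
    metric = mx.metric.Accuracy()
    for batch in data_iter:
        data = batch[0].as_in_context(context)
        label = batch[1].as_in_context(context)
        output = model(data)
        metric.update([label], [output])
    return metric.get()  # returns (name, accuracy)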