예제 #1
0
    def test_byteps_push_pull(self):
        """Smoke-test byteps_push_pull on a 1-D tensor.

        For every supported dtype (only ``float32`` is requested) a random
        17-element tensor is created on the current context, pushed through
        ``bps.byteps_push_pull`` and printed before and after the call.  No
        assertion is made on the values; the test passes as long as the call
        completes.
        """
        dtypes = self.filter_supported_types(['float32'])
        dims = [1]
        ctx = self._current_context()
        count = 100
        # shapes[d] is the shape used for a d-dimensional tensor.  The
        # trailing comma matters: ``(17)`` is the int 17, not a 1-tuple.
        shapes = [(), (17,)]
        for dtype, dim in itertools.product(dtypes, dims):
            # MXNet uses gpu_id as part of the seed, so to get identical seeds
            # we must set a context.
            mx.random.seed(10 + 10 * bps.rank(), ctx=ctx)
            tensor = mx.nd.random.uniform(-100,
                                          100,
                                          shape=shapes[dim],
                                          ctx=ctx)
            tensor = tensor.astype(dtype)

            name = "tensor_" + str(count)
            print("tensor before push_pull:", tensor)
            bps.byteps_declare_tensor(name)
            bps.byteps_push_pull(tensor, name=name)
            # Block until the asynchronous push_pull has written the result.
            tensor.wait_to_read()
            print("tensor after push_pull:", tensor)
            # Keep tensor names unique if more dtypes/dims are ever added.
            count += 1

        print('test_byteps_push_pull passed')
예제 #2
0
    def test_byteps_push_pull(self):
        """Test that byteps_push_pull returns the input for 1D, 2D, 3D tensors.

        For each dtype/dimension combination a random tensor is created,
        snapshotted to NumPy, pushed through ``bps.byteps_push_pull`` and
        compared against the snapshot with ``np.allclose``.

        NOTE(review): the seed depends on ``bps.rank()``, so equality of input
        and output presumably relies on a single-worker run -- confirm.
        """
        dtypes = ['float16', 'float32', 'float64']
        dims = [1, 2, 3]
        count = 0
        ctx = self._current_context()
        # shapes[d] is the shape used for a d-dimensional tensor.  The
        # trailing comma matters: ``(17)`` is the int 17, not a 1-tuple.
        shapes = [(), (17,), (17, 17), (17, 17, 17)]
        for dtype, dim in itertools.product(dtypes, dims):
            # MXNet uses gpu_id as part of the seed, so to get identical seeds
            # we must set a context.
            mx.random.seed(10 + 10 * bps.rank(), ctx=ctx)
            tensor = mx.nd.random.uniform(-100,
                                          100,
                                          shape=shapes[dim],
                                          ctx=ctx)
            tensor = tensor.astype(dtype)
            # Snapshot the values before the in-place push_pull overwrites
            # them (named ``before`` to avoid shadowing the builtin ``input``).
            before = tensor.asnumpy()

            name = "tensor_" + str(count)
            bps.byteps_declare_tensor(name)
            bps.byteps_push_pull(tensor, name=name)
            # Block until the asynchronous push_pull has written the result.
            tensor.wait_to_read()
            after = tensor.asnumpy()
            assert np.allclose(before, after)
            count += 1

        print('test_byteps_push_pull passed')
예제 #3
0
    def test_byteps_push_pull_inplace(self):
        """Test that in-place byteps_push_pull preserves 1D, 2D, 3D tensors.

        Every rank uses the same fixed seed, so all ranks are expected to draw
        the same values.  Each tensor is copied, pushed/pulled in place, and
        the largest absolute element-wise deviation from the copy must not
        exceed a threshold chosen from the number of ranks and the dtype.
        """
        size = bps.size()
        dtypes = self.filter_supported_types(
            ['int32', 'int64', 'float32', 'float64'])
        dims = [1, 2, 3]
        ctx = self._current_context()
        count = 200
        # shapes[d] is the shape used for a d-dimensional tensor.  The
        # trailing comma matters: ``(17)`` is the int 17, not a 1-tuple.
        shapes = [(), (17,), (17, 17), (17, 17, 17)]
        for dtype, dim in itertools.product(dtypes, dims):
            mx.random.seed(1234, ctx=ctx)
            tensor = mx.nd.random.uniform(-100,
                                          100,
                                          shape=shapes[dim],
                                          ctx=ctx)
            tensor = tensor.astype(dtype)
            # Reference copy taken before the in-place push_pull.
            multiplied = tensor.copy()
            name = "tensor_" + str(count)
            bps.byteps_declare_tensor(name)
            bps.byteps_push_pull(tensor, name=name)
            # Use the absolute deviation: a signed maximum would miss results
            # that come back smaller than the reference.
            max_difference = mx.nd.max(
                mx.nd.abs(mx.nd.subtract(tensor, multiplied)))
            count += 1

            # Integer dtypes and small jobs must match exactly; larger jobs
            # get a tolerance for floating point error that grows with the
            # number of ranks.
            if size <= 3 or dtype in ['int32', 'int64']:
                threshold = 0
            elif size < 10:
                threshold = 1e-4
            elif size < 15:
                threshold = 5e-4
            else:
                # No tolerance is defined for 15 or more ranks; this stops
                # checking all remaining dtype/dim combinations.
                break

            if max_difference > threshold:
                print("self", count, dtype, dim, max_difference, threshold)
                print("tensor", bps.rank(), tensor)
                print("multiplied", bps.rank(), multiplied)
            assert max_difference <= threshold, (
                'bps.byteps_push_pull produces incorrect results for self')

        print('test_byteps_push_pull_inplace passed')
예제 #4
0
    def train(epochs, ctx):
        """Train ``net`` on this worker's CIFAR-100 shard with BytePS.

        Relies on names from the enclosing scope: ``net``, ``opt``,
        ``optimizer``, ``lr_scheduler``, ``nworker``, ``rank``,
        ``batch_size``, ``num_workers``, ``transform_train``,
        ``transform_test``, ``test``, ``logger``, ``save_dir``,
        ``save_period`` and ``model_name``.

        After every epoch the training and validation accuracies are summed
        across workers with ``bps.byteps_push_pull`` and divided by
        ``bps.size()``; the averaged validation accuracy decides whether a
        "best" checkpoint is written.

        Args:
            epochs: Number of passes over the training shard.
            ctx: A single ``mx.Context`` or a list of contexts to train on.
        """
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        net.initialize(mx.init.Xavier(), ctx=ctx)

        # Each worker only sees its own shard of the train/validation sets.
        train_data = gluon.data.DataLoader(gluon.data.vision.CIFAR100(
            train=True).shard(nworker, rank).transform_first(transform_train),
                                           batch_size=batch_size,
                                           shuffle=True,
                                           last_batch='discard',
                                           num_workers=num_workers)

        val_data = gluon.data.DataLoader(gluon.data.vision.CIFAR100(
            train=False).shard(nworker, rank).transform_first(transform_test),
                                         batch_size=batch_size,
                                         shuffle=False,
                                         num_workers=num_workers)

        params = net.collect_params()

        compression_params = {
            "compressor": opt.compressor,
            "ef": opt.ef,
            "momentum": opt.compress_momentum,
            "scaling": opt.onebit_scaling,
            "k": opt.k,
            "fp16": opt.fp16_pushpull
        }

        optimizer_params = {
            'lr_scheduler': lr_scheduler,
            'wd': opt.wd,
            'momentum': opt.momentum
        }

        trainer = bps.DistributedTrainer(params,
                                         optimizer,
                                         optimizer_params,
                                         compression_params=compression_params)
        metric = mx.metric.Accuracy()
        train_metric = mx.metric.Accuracy()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

        best_val_score = 0
        # The name must be declared once before it is used in push_pull.
        bps.byteps_declare_tensor("acc")
        for epoch in range(epochs):
            tic = time.time()
            train_metric.reset()
            metric.reset()
            train_loss = 0
            num_batch = len(train_data)

            for batch in train_data:
                data = gluon.utils.split_and_load(batch[0],
                                                  ctx_list=ctx,
                                                  batch_axis=0)
                label = gluon.utils.split_and_load(batch[1],
                                                   ctx_list=ctx,
                                                   batch_axis=0)

                with ag.record():
                    output = [net(X) for X in data]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
                for device_loss in loss:
                    device_loss.backward()
                trainer.step(batch_size)
                train_loss += sum(
                    device_loss.sum().asscalar() for device_loss in loss)

                train_metric.update(label, output)

            train_loss /= batch_size * num_batch
            name, train_acc = train_metric.get()
            elapsed = time.time() - tic
            # Count every processed batch (num_batch), not the index of the
            # last one, so the reported speed is not one batch short.
            throughput = int(batch_size * nworker * num_batch / elapsed)

            logger.info(
                '[Epoch %d] speed: %d samples/sec\ttime cost: %f lr=%f' %
                (epoch, throughput, elapsed, trainer.learning_rate))

            name, val_acc = test(ctx, val_data)
            # Sum the per-worker accuracies across workers, then average.
            acc = mx.nd.array([train_acc, val_acc], ctx=ctx[0])
            bps.byteps_push_pull(acc, name="acc", is_average=False)
            acc /= bps.size()
            train_acc, val_acc = acc[0].asscalar(), acc[1].asscalar()
            if bps.rank() == 0:
                logger.info('[Epoch %d] training: %s=%f' %
                            (epoch, name, train_acc))
                logger.info('[Epoch %d] validation: %s=%f' %
                            (epoch, name, val_acc))

            if val_acc > best_val_score:
                best_val_score = val_acc
                net.save_parameters(
                    '%s/%.4f-cifar-%s-%d-best.params' %
                    (save_dir, best_val_score, model_name, epoch))

            if save_period and save_dir and (epoch + 1) % save_period == 0:
                net.save_parameters('%s/cifar100-%s-%d.params' %
                                    (save_dir, model_name, epoch))

        if save_period and save_dir:
            net.save_parameters('%s/cifar100-%s-%d.params' %
                                (save_dir, model_name, epochs - 1))
예제 #5
0
def train(data_train, data_eval, model):
    """Run the pre-training loop until ``args.num_steps`` steps are done.

    The distributed backend (``'horovod'``, ``'byteps'`` or a plain Gluon
    trainer) is chosen from the module-level ``backend``.  Gradients are
    accumulated over ``args.accumulate`` batches per optimizer step, and the
    learning rate is updated manually at the start of every step.

    Relies on module-level names such as ``args``, ``backend``, ``ctxs``,
    ``num_workers``, ``rank``, ``local_rank``, ``early_stop``, ``batch_size``,
    ``batch_size_eval`` and ``vocab``.

    Args:
        data_train: Iterable of training batches; a fresh iterator is created
            from it each time the previous one is exhausted.
        data_eval: Evaluation data handed to ``get_pretrain_data_npz``;
            evaluation is skipped when it is falsy.
        model: Model exposing ``bert`` and ``collect_params``.
    """
    # backend specific implementation
    param_dict = model.bert.collect_params()
    if backend == 'horovod':
        hvd.broadcast_parameters(param_dict, root_rank=0)

    mlm_metric = nlp.metric.MaskedAccuracy()
    nsp_metric = nlp.metric.MaskedAccuracy()
    mlm_metric.reset()
    nsp_metric.reset()

    logging.debug('Creating distributed trainer...')
    lr = args.lr
    optim_params = {'learning_rate': lr, 'epsilon': 1e-6, 'wd': 0.01}
    if args.dtype == 'float16':
        optim_params['multi_precision'] = True
    if args.optimizer == 'lamb':
        optim_params['bias_correction'] = True

    # Dynamic loss scaling is only enabled for float16 training.
    dynamic_loss_scale = args.dtype == 'float16'
    if dynamic_loss_scale:
        loss_scale_param = {'scale_window': 2000 / num_workers, 'init_scale': 1}
    else:
        loss_scale_param = None

    # backend specific implementation
    if backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, args.optimizer, optim_params)
    elif backend == 'byteps':
        trainer = bps.DistributedTrainer(param_dict, args.optimizer, optim_params)
    else:
        trainer = mx.gluon.Trainer(param_dict, args.optimizer, optim_params,
                                   update_on_kvstore=False)
    fp16_trainer = FP16Trainer(trainer, dynamic_loss_scale=dynamic_loss_scale,
                               loss_scaler_params=loss_scale_param)

    # Resume trainer state when starting from a non-zero step.
    if args.start_step:
        state_path = os.path.join(args.ckpt_dir, '%07d.states.%02d'%(args.start_step, local_rank))
        logging.info('Loading trainer state from %s', state_path)
        nlp.utils.load_states(trainer, state_path)

    accumulate = args.accumulate
    num_train_steps = args.num_steps
    warmup_ratio = args.warmup_ratio
    num_warmup_steps = int(num_train_steps * warmup_ratio)
    params = [p for p in param_dict.values() if p.grad_req != 'null']

    # Do not apply weight decay on LayerNorm and bias terms
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    # With gradient accumulation, gradients are summed across batches and
    # cleared explicitly at the start of each step (see zero_grad below).
    if accumulate > 1:
        for p in params:
            p.grad_req = 'add'

    train_begin_time = time.time()
    begin_time = time.time()
    running_mlm_loss, running_nsp_loss = 0, 0
    local_mlm_loss, local_num_masks = 0, mx.nd.array([0], ctx=ctxs[0])
    running_num_tks = 0
    batch_num = 0
    step_num = args.start_step

    logging.debug('Training started')
    logging.info('Generating the first batch of data, which may take a few minutes ...')

    # create dummy data loader if needed
    parallel_model = DataParallelBERT(model, trainer=fp16_trainer)
    num_ctxes = len(ctxs)
    parallel = nlp.utils.Parallel(num_ctxes if num_ctxes > 1 else 0, parallel_model)

    # BytePS only: declare and push_pull the mask counter once, run one dummy
    # batch through the parallel model, then initialize the trainer params.
    # NOTE(review): presumably a warm-up so BytePS sees every tensor before
    # real training starts -- confirm.
    if backend == 'byteps':
        bps.byteps_declare_tensor("local_num_masks")
        bps.byteps_push_pull(local_num_masks, is_average=False, name="local_num_masks", priority=0)
        logging.debug('Broadcast local_num_masks tensor')
        next_batch = next(iter(get_dummy_dataloader(batch_size, args.max_seq_length, args.max_predictions_per_seq)))
        data_list = list(split_and_load(next_batch, ctxs))
        parallel.put(data_list[0])
        parallel.get()
        trainer._init_params()

    # Outer loop: one pass over data_train per iteration, repeated until the
    # requested number of steps has been reached.
    while step_num < num_train_steps:

        data_train_iter = iter(data_train)
        end_of_batch = False
        next_data_batch = next(data_train_iter)
        while not end_of_batch:
            data_batch = next_data_batch
            if step_num >= num_train_steps:
                break
            # A new optimizer step begins on the first batch of each
            # accumulation window.
            if batch_num % accumulate == 0:
                step_num += 1
                # if accumulate > 1, grad_req is set to 'add', and zero_grad is required
                if accumulate > 1:
                    param_dict.zero_grad()
                # update learning rate: linear warm-up to lr, afterwards
                # lr reduced by lr * step_num / num_train_steps
                if step_num <= num_warmup_steps:
                    new_lr = lr * step_num / num_warmup_steps
                else:
                    offset = lr * step_num / num_train_steps
                    new_lr = lr - offset
                trainer.set_learning_rate(new_lr)
                if args.profile:
                    profile(step_num, 10, 14, profile_name=args.profile + str(rank))
                # With early_stop set, the process exits once step 10 starts.
                if early_stop and step_num == 10:
                    mx.nd.waitall()
                    exit()

            # load data
            data_list = list(split_and_load(data_batch, ctxs))

            ns_label_list, ns_pred_list = [], []
            mask_label_list, mask_pred_list, mask_weight_list = [], [], []

            with mx.autograd.record():
                num_data = len(data_list)
                # Queue every shard first, then collect one result per shard.
                for i in range(num_data):
                    parallel.put(data_list[i])
                for _ in range(num_data):
                    (next_sentence_label, classified, masked_id,
                     decoded, masked_weight, ls1, ls2, valid_length, num_masks) = parallel.get()
                    ns_label_list.append(next_sentence_label)
                    ns_pred_list.append(classified)
                    mask_label_list.append(masked_id)
                    mask_pred_list.append(decoded)
                    mask_weight_list.append(masked_weight)
                    local_num_masks += num_masks
                    local_mlm_loss += ls1
                    running_num_tks += valid_length.sum()
            # pre fetch next batch; running out of data ends this pass
            try:
                next_data_batch = next(data_train_iter)
            except StopIteration:
                end_of_batch = True

            # update: only on the last batch of each accumulation window
            if (batch_num + 1) % accumulate == 0:
                running_mlm_loss += local_mlm_loss / local_num_masks
                # Sum the mask count across workers (in place).
                if backend == 'horovod':
                    hvd.allreduce_(local_num_masks, average=False, name='local_num_masks')
                elif backend == 'byteps':
                    bps.byteps_push_pull(local_num_masks, is_average=False,
                                         name="local_num_masks", priority=0)
                # because byteps implicitly set scale /= num_workers
                fp16_trainer.step(local_num_masks * num_workers, max_norm=local_num_masks,
                                  num_ctxs=len(ctxs) * num_workers)
                local_num_masks, local_mlm_loss = 0, 0
            # update metrics
            if args.no_compute_acc:
                # No metric update: only wait for the predictions to finish.
                for mask_pred_i in mask_pred_list:
                    mask_pred_i.wait_to_read()
            else:
                nsp_metric.update(ns_label_list, ns_pred_list)
                mlm_metric.update(mask_label_list, mask_pred_list, mask_weight_list)

            # logging
            if (step_num + 1) % (args.log_interval) == 0 and (batch_num + 1) % accumulate == 0:
                if args.no_compute_acc:
                    log_noacc(begin_time, running_num_tks, running_mlm_loss,
                              0, step_num, trainer, args.log_interval)
                else:
                    log(begin_time, running_num_tks, running_mlm_loss / accumulate,
                        running_nsp_loss / accumulate, step_num, mlm_metric, nsp_metric,
                        trainer, args.log_interval)
                    mlm_metric.reset_local()
                    nsp_metric.reset_local()
                begin_time = time.time()
                running_mlm_loss = running_nsp_loss = running_num_tks = 0

            # saving checkpoints
            # NOTE(review): checkpoint saving is commented out below, so this
            # branch currently only triggers evaluation.
            if (step_num + 1) % args.ckpt_interval == 0 and (batch_num + 1) % accumulate == 0:
#                if is_master_node:
#                    save_states(step_num, trainer, args.ckpt_dir, local_rank)
#                    if local_rank == 0:
#                        save_parameters(step_num, model.bert, args.ckpt_dir)
                if (step_num + 1) % args.eval_interval == 0 and data_eval:
                    # eval data is always based on a fixed npz file.
                    dataset_eval = get_pretrain_data_npz(data_eval, batch_size_eval,
                                                         1, False, 1, vocab)
                    evaluate(dataset_eval, model, ctxs, args.log_interval, args.dtype, rank, num_workers)

            batch_num += 1

#    if is_master_node:
#        save_states(step_num, trainer, args.ckpt_dir, local_rank)
#        if local_rank == 0:
#            save_parameters(step_num, model, args.ckpt_dir)
    # Wait for all outstanding MXNet operations before measuring total time.
    mx.nd.waitall()
    train_end_time = time.time()
    logging.info('Train cost={:.1f}s'.format(train_end_time - train_begin_time))
예제 #6
0
    "k": args.k,
    "fp16": args.fp16_pushpull
}

# Distributed SGD trainer; gradient compression is configured through
# compression_params.
trainer = bps.DistributedTrainer(params,
                                 "sgd",
                                 optimizer_params,
                                 compression_params=compression_params)

# Create loss function and train metric
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
metric = mx.metric.Accuracy()

total_time = 0
# Train model
# Declare the tensor name "acc" with BytePS before the loop.
# NOTE(review): presumably used later to push_pull the accuracy -- that code
# is not visible here.
bps.byteps_declare_tensor("acc")
for epoch in range(args.epochs):
    tic = time.time()
    # Start each epoch with a fresh accuracy metric.
    metric.reset()
    for i, batch in enumerate(train_data):
        # Move the batch's data and labels onto the training context.
        data = batch[0].as_in_context(context)
        label = batch[1].as_in_context(context)

        # Record the forward pass so gradients can be computed.
        with autograd.record():
            output = model(data)
            loss = loss_fn(output, label)

        loss.backward()
        trainer.step(args.batch_size)
        metric.update([label], [output])