Example #1
    def test_byteps_trainer_param_order(self):
        size = bps.size()
        dtypes = self.filter_supported_types(['float32'])
        dims = [1]
        ctx = self._current_context()
        net = mx.gluon.nn.Sequential()
        # layers may be added in a different order on each worker
        layers = {'ones_': 1, 'zeros_': 0}
        for name, init in layers.items():
            net.add(
                mx.gluon.nn.Dense(10,
                                  in_units=10,
                                  weight_initializer=mx.init.Constant(init),
                                  use_bias=False,
                                  prefix=name))
        params = net.collect_params()
        net.initialize()
        trainer = bps.DistributedTrainer(params, 'sgd')
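        # _init_params() synchronizes the initial parameter values across workers
        # (the broadcast whose result is checked below)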
        trainer._init_params()
        # check the result of bps_broadcast
        for name, init in layers.items():
            weight = params[name + 'weight'].data()[0].asnumpy()
            expected = np.full(shape=weight.shape,
                               fill_value=init,
                               dtype=weight.dtype)
            assert np.array_equal(weight, expected), (weight, expected)

        print('test_byteps_trainer_param_order passed')
Example #2
    def train(epochs, ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        net.initialize(mx.init.Xavier(), ctx=ctx)

        train_data = gluon.data.DataLoader(gluon.data.vision.CIFAR100(
            train=True).shard(nworker, rank).transform_first(transform_train),
                                           batch_size=batch_size,
                                           shuffle=True,
                                           last_batch='discard',
                                           num_workers=num_workers)

        val_data = gluon.data.DataLoader(gluon.data.vision.CIFAR100(
            train=False).shard(nworker, rank).transform_first(transform_test),
                                         batch_size=batch_size,
                                         shuffle=False,
                                         num_workers=num_workers)

        params = net.collect_params()

        compression_params = {
            "compressor": opt.compressor,
            "ef": opt.ef,
            "momentum": opt.compress_momentum,
            "scaling": opt.onebit_scaling,
            "k": opt.k,
            "fp16": opt.fp16_pushpull
        }

        optimizer_params = {
            'lr_scheduler': lr_scheduler,
            'wd': opt.wd,
            'momentum': opt.momentum
        }

        trainer = bps.DistributedTrainer(params,
                                         optimizer,
                                         optimizer_params,
                                         compression_params=compression_params)
        metric = mx.metric.Accuracy()
        train_metric = mx.metric.Accuracy()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

        iteration = 0
        best_val_score = 0
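        # declare the "acc" tensor to BytePS before it is push_pulled in the loop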
        bps.byteps_declare_tensor("acc")
        for epoch in range(epochs):
            tic = time.time()
            train_metric.reset()
            metric.reset()
            train_loss = 0
            num_batch = len(train_data)

            for i, batch in enumerate(train_data):
                data = gluon.utils.split_and_load(batch[0],
                                                  ctx_list=ctx,
                                                  batch_axis=0)
                label = gluon.utils.split_and_load(batch[1],
                                                   ctx_list=ctx,
                                                   batch_axis=0)

                with ag.record():
                    output = [net(X) for X in data]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
                for l in loss:
                    l.backward()
                trainer.step(batch_size)
                train_loss += sum([l.sum().asscalar() for l in loss])

                train_metric.update(label, output)
                name, train_acc = train_metric.get()
                iteration += 1

            train_loss /= batch_size * num_batch
            name, train_acc = train_metric.get()
            throughput = int(batch_size * nworker * i / (time.time() - tic))

            logger.info(
                '[Epoch %d] speed: %d samples/sec\ttime cost: %f lr=%f' %
                (epoch, throughput, time.time() - tic, trainer.learning_rate))

            name, val_acc = test(ctx, val_data)
            acc = mx.nd.array([train_acc, val_acc], ctx=ctx[0])
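            # sum the accuracy metrics across all workers, then average below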
            bps.byteps_push_pull(acc, name="acc", is_average=False)
            acc /= bps.size()
            train_acc, val_acc = acc[0].asscalar(), acc[1].asscalar()
            if bps.rank() == 0:
                logger.info('[Epoch %d] training: %s=%f' %
                            (epoch, name, train_acc))
                logger.info('[Epoch %d] validation: %s=%f' %
                            (epoch, name, val_acc))

            if val_acc > best_val_score:
                best_val_score = val_acc
                net.save_parameters(
                    '%s/%.4f-cifar-%s-%d-best.params' %
                    (save_dir, best_val_score, model_name, epoch))

            if save_period and save_dir and (epoch + 1) % save_period == 0:
                net.save_parameters('%s/cifar100-%s-%d.params' %
                                    (save_dir, model_name, epoch))

        if save_period and save_dir:
            net.save_parameters('%s/cifar100-%s-%d.params' %
                                (save_dir, model_name, epochs - 1))
Example #3
    def test_topk(self, k):
        ctx = mx.gpu(0)
        net = get_model("resnet18_v2")
        net.initialize(mx.init.Xavier(), ctx=ctx)
        net.summary(nd.ones((1, 3, 224, 224), ctx=ctx))

        # hyper-params
        batch_size = 32
        optimizer_params = {'momentum': 0, 'wd': 0,
                            'learning_rate': 0.01}

        compression_params = {
            "compressor": "topk",
            "k": k,
        }

        trainer = bps.DistributedTrainer(net.collect_params(),
                                         "sgd",
                                         optimizer_params,
                                         compression_params=compression_params)

        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

        train_data = fake_data(batch_size=batch_size)

        params = {}

        for i, param in enumerate(trainer._params):
            if param.grad_req != 'null':
                params[i] = param._data[0].asnumpy()

        for it, batch in tqdm(enumerate(train_data)):
            data = batch[0].as_in_context(ctx)
            label = batch[1].as_in_context(ctx)

            with autograd.record():
                output = net(data)
                loss = loss_fn(output, label)

            loss.backward()

            gs = {}
            xs = {}

            for i, param in enumerate(trainer._params):
                if param.grad_req != 'null':
                    gs[i] = param._grad[0].asnumpy()
                    xs[i] = param._data[0].asnumpy()

            trainer.step(batch_size)

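            # reference computation: compress the averaged gradient with topk()
            # and apply plain SGD to a local copy of the parameters, which is
            # compared with the trainer's result after the loop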
            for i, param in enumerate(trainer._params):
                if param.grad_req != "null":
                    g = gs[i] / (batch_size * bps.size())
                    c = topk(g, k)

                    cs = topk(c, k)
                    c = cs

                    params[i] -= optimizer_params["learning_rate"] * c

        cnt = 0
        tot = 0
        for i, param in enumerate(trainer._params):
            if param.grad_req != "null":
                x = param._data[0].asnumpy()
                tot += len(x.flatten())
                if not np.allclose(params[i], x, atol=np.finfo(np.float32).eps):
                    diff = np.abs(x.flatten() - params[i].flatten())
                    idx = np.where(diff > np.finfo(np.float32).eps)
                    cnt += len(idx[0])

        assert cnt == 0, "false/tot=%d/%d=%f" % (cnt, tot, cnt/tot)
Example #4
def train(data_train, data_eval, model):
    """Training function."""
    # backend specific implementation
    param_dict = model.bert.collect_params()
    if backend == 'horovod':
        hvd.broadcast_parameters(param_dict, root_rank=0)

    mlm_metric = nlp.metric.MaskedAccuracy()
    nsp_metric = nlp.metric.MaskedAccuracy()
    mlm_metric.reset()
    nsp_metric.reset()

    logging.debug('Creating distributed trainer...')
    lr = args.lr
    optim_params = {'learning_rate': lr, 'epsilon': 1e-6, 'wd': 0.01}
    if args.dtype == 'float16':
        optim_params['multi_precision'] = True
    if args.optimizer == 'lamb':
        optim_params['bias_correction'] = True

    dynamic_loss_scale = args.dtype == 'float16'
    if dynamic_loss_scale:
        loss_scale_param = {'scale_window': 2000 / num_workers, 'init_scale': 1}
    else:
        loss_scale_param = None

    # backend specific implementation
    if backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, args.optimizer, optim_params)
    elif backend == 'byteps':
        trainer = bps.DistributedTrainer(param_dict, args.optimizer, optim_params)
    else:
        trainer = mx.gluon.Trainer(param_dict, args.optimizer, optim_params,
                                   update_on_kvstore=False)
    fp16_trainer = FP16Trainer(trainer, dynamic_loss_scale=dynamic_loss_scale,
                               loss_scaler_params=loss_scale_param)

    if args.start_step:
        state_path = os.path.join(args.ckpt_dir, '%07d.states.%02d'%(args.start_step, local_rank))
        logging.info('Loading trainer state from %s', state_path)
        nlp.utils.load_states(trainer, state_path)

    accumulate = args.accumulate
    num_train_steps = args.num_steps
    warmup_ratio = args.warmup_ratio
    num_warmup_steps = int(num_train_steps * warmup_ratio)
    params = [p for p in param_dict.values() if p.grad_req != 'null']

    # Do not apply weight decay on LayerNorm and bias terms
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    if accumulate > 1:
        for p in params:
            p.grad_req = 'add'

    train_begin_time = time.time()
    begin_time = time.time()
    running_mlm_loss, running_nsp_loss = 0, 0
    local_mlm_loss, local_num_masks = 0, mx.nd.array([0], ctx=ctxs[0])
    running_num_tks = 0
    batch_num = 0
    step_num = args.start_step

    logging.debug('Training started')
    logging.info('Generating the first batch of data, which may take a few minutes ...')

    # create dummy data loader if needed
    parallel_model = DataParallelBERT(model, trainer=fp16_trainer)
    num_ctxes = len(ctxs)
    parallel = nlp.utils.Parallel(num_ctxes if num_ctxes > 1 else 0, parallel_model)

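    # BytePS: push_pull a dummy tensor and run one dummy batch so that all
    # parameters are created before _init_params() synchronizes them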
    if backend == 'byteps':
        bps.byteps_declare_tensor("local_num_masks")
        bps.byteps_push_pull(local_num_masks, is_average=False, name="local_num_masks", priority=0)
        logging.debug('Broadcast local_num_masks tensor')
        next_batch = next(iter(get_dummy_dataloader(batch_size, args.max_seq_length, args.max_predictions_per_seq)))
        data_list = list(split_and_load(next_batch, ctxs))
        parallel.put(data_list[0])
        parallel.get()
        trainer._init_params()

    while step_num < num_train_steps:

        data_train_iter = iter(data_train)
        end_of_batch = False
        next_data_batch = next(data_train_iter)
        while not end_of_batch:
            data_batch = next_data_batch
            if step_num >= num_train_steps:
                break
            if batch_num % accumulate == 0:
                step_num += 1
                # if accumulate > 1, grad_req is set to 'add', and zero_grad is required
                if accumulate > 1:
                    param_dict.zero_grad()
                # update learning rate
                if step_num <= num_warmup_steps:
                    new_lr = lr * step_num / num_warmup_steps
                else:
                    offset = lr * step_num / num_train_steps
                    new_lr = lr - offset
                trainer.set_learning_rate(new_lr)
                if args.profile:
                    profile(step_num, 10, 14, profile_name=args.profile + str(rank))
                if early_stop and step_num == 10:
                    mx.nd.waitall()
                    exit()

            # load data
            data_list = list(split_and_load(data_batch, ctxs))

            ns_label_list, ns_pred_list = [], []
            mask_label_list, mask_pred_list, mask_weight_list = [], [], []

            with mx.autograd.record():
                num_data = len(data_list)
                for i in range(num_data):
                    parallel.put(data_list[i])
                for _ in range(num_data):
                    (next_sentence_label, classified, masked_id,
                     decoded, masked_weight, ls1, ls2, valid_length, num_masks) = parallel.get()
                    ns_label_list.append(next_sentence_label)
                    ns_pred_list.append(classified)
                    mask_label_list.append(masked_id)
                    mask_pred_list.append(decoded)
                    mask_weight_list.append(masked_weight)
                    local_num_masks += num_masks
                    local_mlm_loss += ls1
                    running_num_tks += valid_length.sum()
            # prefetch the next batch
            try:
                next_data_batch = next(data_train_iter)
            except StopIteration:
                end_of_batch = True

            # update
            if (batch_num + 1) % accumulate == 0:
                running_mlm_loss += local_mlm_loss / local_num_masks
                if backend == 'horovod':
                    hvd.allreduce_(local_num_masks, average=False, name='local_num_masks')
                elif backend == 'byteps':
                    bps.byteps_push_pull(local_num_masks, is_average=False,
                                         name="local_num_masks", priority=0)
                # because BytePS implicitly sets scale /= num_workers
                fp16_trainer.step(local_num_masks * num_workers, max_norm=local_num_masks,
                                  num_ctxs=len(ctxs) * num_workers)
                local_num_masks, local_mlm_loss = 0, 0
            # update metrics
            if args.no_compute_acc:
                for mask_pred_i in mask_pred_list:
                    mask_pred_i.wait_to_read()
            else:
                nsp_metric.update(ns_label_list, ns_pred_list)
                mlm_metric.update(mask_label_list, mask_pred_list, mask_weight_list)

            # logging
            if (step_num + 1) % (args.log_interval) == 0 and (batch_num + 1) % accumulate == 0:
                if args.no_compute_acc:
                    log_noacc(begin_time, running_num_tks, running_mlm_loss,
                              0, step_num, trainer, args.log_interval)
                else:
                    log(begin_time, running_num_tks, running_mlm_loss / accumulate,
                        running_nsp_loss / accumulate, step_num, mlm_metric, nsp_metric,
                        trainer, args.log_interval)
                    mlm_metric.reset_local()
                    nsp_metric.reset_local()
                begin_time = time.time()
                running_mlm_loss = running_nsp_loss = running_num_tks = 0

            # saving checkpoints
            if (step_num + 1) % args.ckpt_interval == 0 and (batch_num + 1) % accumulate == 0:
#                if is_master_node:
#                    save_states(step_num, trainer, args.ckpt_dir, local_rank)
#                    if local_rank == 0:
#                        save_parameters(step_num, model.bert, args.ckpt_dir)
                if (step_num + 1) % args.eval_interval == 0 and data_eval:
                    # eval data is always based on a fixed npz file.
                    dataset_eval = get_pretrain_data_npz(data_eval, batch_size_eval,
                                                         1, False, 1, vocab)
                    evaluate(dataset_eval, model, ctxs, args.log_interval, args.dtype, rank, num_workers)

            batch_num += 1

#    if is_master_node:
#        save_states(step_num, trainer, args.ckpt_dir, local_rank)
#        if local_rank == 0:
#            save_parameters(step_num, model, args.ckpt_dir)
    mx.nd.waitall()
    train_end_time = time.time()
    logging.info('Train cost={:.1f}s'.format(train_end_time - train_begin_time))
Example #5
    def train(ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        if opt.resume_params == '':
            net.initialize(mx.init.MSRAPrelu(), ctx=ctx)

        if opt.no_wd:
            for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
                v.wd_mult = 0.0

        compression_params = {
            "compressor": opt.compressor,
            "ef": opt.ef,
            "momentum": opt.compress_momentum,
            "scaling": opt.onebit_scaling,
            "k": opt.k
        }

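        # gradient compression options (compressor, ef, momentum, scaling, k)
        # are passed to DistributedTrainer via compression_params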
        trainer = bps.DistributedTrainer(net.collect_params(),
                                         optimizer,
                                         optimizer_params,
                                         compression_params=compression_params)

        if opt.resume_states != '':
            trainer.load_states(opt.resume_states)

        if opt.label_smoothing or opt.mixup:
            sparse_label_loss = False
        else:
            sparse_label_loss = True
        if distillation:
            L = gcv.loss.DistillationSoftmaxCrossEntropyLoss(
                temperature=opt.temperature,
                hard_weight=opt.hard_weight,
                sparse_label=sparse_label_loss)
        else:
            L = gluon.loss.SoftmaxCrossEntropyLoss(
                sparse_label=sparse_label_loss)

        best_val_score = 1

        # bps.byteps_declare_tensor("acc")
        for epoch in range(opt.resume_epoch, opt.num_epochs):
            tic = time.time()
            if opt.use_rec:
                train_data.reset()
            train_metric.reset()
            btic = time.time()

            for i, batch in enumerate(train_data):
                data, label = batch_fn(batch, ctx)

                if opt.mixup:
                    lam = np.random.beta(opt.mixup_alpha, opt.mixup_alpha)
                    if epoch >= opt.num_epochs - opt.mixup_off_epoch:
                        lam = 1
                    data = [lam * X + (1 - lam) * X[::-1] for X in data]

                    if opt.label_smoothing:
                        eta = 0.1
                    else:
                        eta = 0.0
                    label = mixup_transform(label, classes, lam, eta)

                elif opt.label_smoothing:
                    hard_label = label
                    label = smooth(label, classes)

                if distillation:
                    teacher_prob = [
                        nd.softmax(
                            teacher(X.astype(opt.dtype, copy=False)) /
                            opt.temperature) for X in data
                    ]

                with ag.record():
                    outputs = [
                        net(X.astype(opt.dtype, copy=False)) for X in data
                    ]
                    if distillation:
                        loss = [
                            L(yhat.astype('float32', copy=False),
                              y.astype('float32', copy=False),
                              p.astype('float32', copy=False))
                            for yhat, y, p in zip(outputs, label, teacher_prob)
                        ]
                    else:
                        loss = [
                            L(yhat, y.astype(opt.dtype, copy=False))
                            for yhat, y in zip(outputs, label)
                        ]
                for l in loss:
                    l.backward()
                trainer.step(batch_size)

                if opt.mixup:
                    output_softmax = [
                        nd.SoftmaxActivation(out.astype('float32', copy=False))
                        for out in outputs
                    ]
                    train_metric.update(label, output_softmax)
                else:
                    if opt.label_smoothing:
                        train_metric.update(hard_label, outputs)
                    else:
                        train_metric.update(label, outputs)

                if opt.log_interval and not (i + 1) % opt.log_interval:
                    train_metric_name, train_metric_score = train_metric.get()
                    logger.info(
                        'Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\tlr=%f\ttime=%f'
                        % (epoch, i, batch_size * nworker * opt.log_interval /
                           (time.time() - btic), train_metric_name,
                           train_metric_score, trainer.learning_rate,
                           time.time() - btic))
                    btic = time.time()

            train_metric_name, train_metric_score = train_metric.get()
            throughput = int(batch_size * nworker * i / (time.time() - tic))

            logger.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f' %
                        (epoch, throughput, time.time() - tic))

            err_top1_val, err_top5_val = test(ctx, val_data)

            # acc = mx.nd.array([train_metric_score, err_top1_val, err_top5_val],
            #                   ctx=ctx[0])
            # bps.byteps_push_pull(acc, name="acc", is_average=False)
            # acc /= bps.size()
            # train_metric_score, err_top1_val, err_top5_val = acc[0].asscalar(
            # ), acc[1].asscalar(), acc[2].asscalar()

            # if bps.rank() == 0:
            logger.info('[Epoch %d] training: %s=%f' %
                        (epoch, train_metric_name, train_metric_score))
            logger.info('[Epoch %d] validation: err-top1=%f err-top5=%f' %
                        (epoch, err_top1_val, err_top5_val))

            if err_top1_val < best_val_score:
                best_val_score = err_top1_val
                net.save_parameters(
                    '%s/%.4f-imagenet-%s-%d-best.params' %
                    (save_dir, best_val_score, model_name, epoch))
                trainer.save_states(
                    '%s/%.4f-imagenet-%s-%d-best.states' %
                    (save_dir, best_val_score, model_name, epoch))

            if save_frequency and save_dir and (epoch +
                                                1) % save_frequency == 0:
                net.save_parameters('%s/imagenet-%s-%d.params' %
                                    (save_dir, model_name, epoch))
                trainer.save_states('%s/imagenet-%s-%d.states' %
                                    (save_dir, model_name, epoch))

        if save_frequency and save_dir:
            net.save_parameters('%s/imagenet-%s-%d.params' %
                                (save_dir, model_name, opt.num_epochs - 1))
            trainer.save_states('%s/imagenet-%s-%d.states' %
                                (save_dir, model_name, opt.num_epochs - 1))
Example #6
def train(args):
    _, num_workers, rank, local_rank, is_master_node, ctx_l = init_comm(
        args.comm_backend, args.gpus)
    level = logging.DEBUG if args.verbose else logging.INFO
    logging_config(
        args.ckpt_dir,
        name='pretrain_bert_' + str(rank),  # avoid race
        level=level,
        console=(local_rank == 0))
    logging.info(args)
    logging.debug('Random seed set to {}'.format(args.seed))
    set_seed(args.seed)
    logging.info('Training info: num_buckets: {}, '
                 'num_workers: {}, rank: {}'.format(args.num_buckets,
                                                    num_workers, rank))
    cfg, tokenizer, model = get_pretraining_model(args.model_name, ctx_l)
    if args.start_step:
        logging.info('Restart training from {}'.format(args.start_step))
        parameters_option(args.start_step, model, args.ckpt_dir, 'Loading',
                          ctx_l)
    else:
        model.initialize(ctx=ctx_l)
    model.hybridize()

    if args.raw:
        get_dataset_fn = functools.partial(
            get_pretrain_data_text,
            max_seq_length=args.max_seq_length,
            short_seq_prob=args.short_seq_prob,
            masked_lm_prob=args.masked_lm_prob,
            max_predictions_per_seq=args.max_predictions_per_seq,
            whole_word_mask=args.whole_word_mask,
            random_next_sentence=args.random_next_sentence,
            tokenizer=tokenizer,
            circle_length=args.circle_length,
            repeat=args.repeat,
            dataset_cached=args.dataset_cached,
            num_max_dataset_cached=args.num_max_dataset_cached)
    else:
        get_dataset_fn = get_pretrain_data_npz

    data_train = get_dataset_fn(args.data,
                                args.batch_size,
                                shuffle=True,
                                num_buckets=args.num_buckets,
                                vocab=tokenizer.vocab,
                                num_parts=num_workers,
                                part_idx=rank,
                                num_dataset_workers=args.num_dataset_workers,
                                num_batch_workers=args.num_batch_workers)

    param_dict = model.collect_params()
    # Do not apply weight decay to LayerNorm and bias parameters
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    # Set grad_req if gradient accumulation is required
    params = [p for p in param_dict.values() if p.grad_req != 'null']
    num_accumulated = args.num_accumulated
    if num_accumulated > 1:
        logging.info(
            'Using gradient accumulation. Effective global batch size = {}'.
            format(num_accumulated * args.batch_size * len(ctx_l) *
                   num_workers))
        for p in params:
            p.grad_req = 'add'

    num_steps = args.num_steps
    warmup_steps = int(num_steps * args.warmup_ratio)
    log_interval = args.log_interval
    save_interval = args.ckpt_interval
    logging.info(
        '#Total Training Steps={}, Warmup Steps={}, Save Interval={}'.format(
            num_steps, warmup_steps, save_interval))
    optimizer_params = {'learning_rate': args.lr, 'wd': args.wd}
    if args.optimizer == 'adamw':
        optimizer_params.update({
            'beta1': 0.9,
            'beta2': 0.999,
            'epsilon': 1e-6,
            'correct_bias': False,
        })
    if args.comm_backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, args.optimizer,
                                         optimizer_params)
    elif args.comm_backend == 'byteps':
        trainer = bps.DistributedTrainer(param_dict, args.optimizer,
                                         optimizer_params)
    else:
        trainer = mx.gluon.Trainer(param_dict,
                                   args.optimizer,
                                   optimizer_params,
                                   update_on_kvstore=False)
    if args.start_step:
        logging.info('Restart training from {}'.format(args.start_step))
        states_option(args.start_step, trainer, args.ckpt_dir, local_rank,
                      'Loading')

    # backend specific implementation
    if args.comm_backend == 'byteps':
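        # BytePS: synchronize the initial parameters across workers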
        trainer._init_params()
    if args.comm_backend == 'horovod':
        # Horovod: fetch and broadcast parameters
        hvd.broadcast_parameters(param_dict, root_rank=0)

    # prepare the loss function
    nsp_loss_fn = mx.gluon.loss.SoftmaxCELoss()
    mlm_loss_fn = mx.gluon.loss.SoftmaxCELoss()
    nsp_loss_fn.hybridize()
    mlm_loss_fn.hybridize()

    mlm_metric = MaskedAccuracy()
    nsp_metric = MaskedAccuracy()
    mlm_metric.reset()
    nsp_metric.reset()

    step_num = args.start_step
    if args.phase2:
        step_num -= args.phase1_num_steps

    running_mlm_loss, running_nsp_loss = 0., 0.
    running_num_tks = 0

    train_start_time = time.time()
    tic = time.time()
    # start training
    train_loop_dataloader = grouper(repeat(data_train), len(ctx_l))
    while step_num < num_steps:
        step_num += 1
        for _ in range(num_accumulated):
            sample_l = next(train_loop_dataloader)
            mlm_loss_l = []
            nsp_loss_l = []
            loss_l = []
            ns_label_list, ns_pred_list = [], []
            mask_label_list, mask_pred_list, mask_weight_list = [], [], []
            for sample, ctx in zip(sample_l, ctx_l):
                # prepare data
                (input_id, masked_id, masked_position, masked_weight, \
                    next_sentence_label, segment_id, valid_length) = sample
                input_id = input_id.as_in_ctx(ctx)
                masked_id = masked_id.as_in_ctx(ctx)
                masked_position = masked_position.as_in_ctx(ctx)
                masked_weight = masked_weight.as_in_ctx(ctx)
                next_sentence_label = next_sentence_label.as_in_ctx(ctx)
                segment_id = segment_id.as_in_ctx(ctx)
                valid_length = valid_length.as_in_ctx(ctx)

                with mx.autograd.record():
                    _, _, nsp_score, mlm_scores = model(
                        input_id, segment_id, valid_length, masked_position)
                    denominator = (masked_weight.sum() +
                                   1e-8) * num_accumulated * len(ctx_l)
                    mlm_scores_r = mx.npx.reshape(mlm_scores, (-5, -1))
                    masked_id_r = masked_id.reshape((-1, ))
                    mlm_loss = mlm_loss_fn(mlm_scores_r, masked_id_r,
                                           masked_weight.reshape(
                                               (-1, 1))).sum() / denominator
                    denominator = num_accumulated * len(ctx_l)
                    nsp_loss = nsp_loss_fn(
                        nsp_score, next_sentence_label).mean() / denominator
                    mlm_loss_l.append(mlm_loss)
                    nsp_loss_l.append(nsp_loss)
                    loss_l.append(mlm_loss + nsp_loss)
                    mask_label_list.append(masked_id_r)
                    mask_pred_list.append(mlm_scores_r)
                    mask_weight_list.append(masked_weight.reshape((-1, )))
                    ns_label_list.append(next_sentence_label)
                    ns_pred_list.append(nsp_score)

                running_num_tks += valid_length.sum().as_in_ctx(mx.cpu())

            for loss in loss_l:
                loss.backward()

            running_mlm_loss += sum([
                ele.as_in_ctx(mx.cpu()) for ele in mlm_loss_l
            ]).asnumpy().item()
            running_nsp_loss += sum([
                ele.as_in_ctx(mx.cpu()) for ele in nsp_loss_l
            ]).asnumpy().item()
            mlm_metric.update(mask_label_list, mask_pred_list,
                              mask_weight_list)
            nsp_metric.update(ns_label_list, ns_pred_list)
        # update
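        # aggregate gradients across workers first so they can be clipped
        # before the optimizer update below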
        trainer.allreduce_grads()

        total_norm, ratio, is_finite = clip_grad_global_norm(
            params, args.max_grad_norm * num_workers)
        total_norm = total_norm / num_workers

        # update learning rate
        scheduled_lr = args.lr
        if step_num <= warmup_steps:
            scheduled_lr *= step_num / warmup_steps
        else:
            offset = (num_steps - step_num) / (num_steps - warmup_steps)
            scheduled_lr *= max(offset, 0)
        trainer.set_learning_rate(scheduled_lr)

        if args.comm_backend == 'horovod' or args.comm_backend == 'byteps':
            # Note that horovod.trainer._scale defaults to num_workers,
            # thus trainer.update(1) will scale the gradients by 1./num_workers.
            # *num_workers* of Horovod is the number of GPUs.
            trainer.update(1, ignore_stale_grad=True)
        else:
            # gluon.trainer._scale defaults to 1.
            # *num_workers* of Trainer is the number of machines.
            trainer.update(num_workers, ignore_stale_grad=True)

        if num_accumulated > 1:
            # set grad to zero for gradient accumulation
            model.zero_grad()

        # saving
        if step_num % save_interval == 0 or step_num >= num_steps:
            states_option(step_num, trainer, args.ckpt_dir, local_rank,
                          'Saving')
            if local_rank == 0:
                parameters_option(step_num, model, args.ckpt_dir, 'Saving')
        # logging
        if step_num % log_interval == 0:
            running_mlm_loss /= log_interval
            running_nsp_loss /= log_interval
            toc = time.time()
            logging.info(
                '[step {}], Loss mlm/nsp={:.5f}/{:.3f}, Acc mlm/nsp={:.3f}/{:.3f}, '
                ' LR={:.7f}, grad_norm={:.4f}. Time cost={:.2f} s,'
                ' Throughput={:.1f}K tks/s, ETA={:.2f} h'.format(
                    step_num, running_mlm_loss, running_nsp_loss,
                    mlm_metric.get()[1],
                    nsp_metric.get()[1], trainer.learning_rate, total_norm,
                    toc - tic,
                    running_num_tks.asnumpy().item() / (toc - tic) / 1000,
                    (num_steps - step_num) /
                    (step_num / (toc - train_start_time)) / 3600))
            mlm_metric.reset()
            nsp_metric.reset()
            tic = time.time()

            running_mlm_loss = 0
            running_nsp_loss = 0
            running_num_tks = 0

    logging.info('Finish training step: %d', step_num)

    mx.npx.waitall()
    train_end_time = time.time()
    logging.info('Train cost={:.1f} s'.format(train_end_time -
                                              train_start_time))

    if local_rank == 0:
        model_name = args.model_name.replace('google', 'gluon')
        save_dir = os.path.join(args.ckpt_dir, model_name)
        final_save(model, save_dir, tokenizer, cfg)
Example #7
model.initialize(mx.init.MSRAPrelu(), ctx=context)
# if bps.rank() == 0:
model.summary(nd.ones((1, 1, 28, 28), ctx=mx.gpu(bps.local_rank())))
model.hybridize()

# BytePS: fetch and broadcast parameters
params = model.collect_params()
if params is not None:
    bps.broadcast_parameters(params, root_rank=0)

# BytePS: create DistributedTrainer, a subclass of gluon.Trainer
optimizer_params = {
    'momentum': args.momentum,
    'learning_rate': args.lr * num_workers
}
trainer = bps.DistributedTrainer(params, "sgd", optimizer_params)

# Create loss function and train metric
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
metric = mx.metric.Accuracy()

# Train model
for epoch in range(args.epochs):
    tic = time.time()
    metric.reset()
    for i, batch in enumerate(train_data):
        data = batch[0].as_in_context(context)
        label = batch[1].as_in_context(context)

        with autograd.record():
            output = model(data)
Example #8
optimizer_params = {
    'momentum': args.momentum,
    'wd': args.wd,
    'learning_rate': args.lr * num_workers
}

compression_params = {
    "compressor": args.compressor,
    "ef": args.ef,
    "momentum": args.compress_momentum,
    "scaling": args.scaling,
    "k": args.k,
    "fp16": args.fp16_pushpull
}

trainer = bps.DistributedTrainer(params,
                                 "sgd",
                                 optimizer_params,
                                 compression_params=compression_params)

# Create loss function and train metric
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
metric = mx.metric.Accuracy()

total_time = 0
# Train model
bps.byteps_declare_tensor("acc")
for epoch in range(args.epochs):
    tic = time.time()
    metric.reset()
    for i, batch in enumerate(train_data):
        data = batch[0].as_in_context(context)
        label = batch[1].as_in_context(context)
Example #9
    def test_dithering(self, k, ptype, ntype, seed):
        ctx = mx.gpu(0)
        net = get_model("resnet18_v2")
        net.initialize(mx.init.Xavier(), ctx=ctx)
        net.summary(nd.ones((1, 3, 224, 224), ctx=ctx))

        # hyper-params
        batch_size = 32
        optimizer_params = {'momentum': 0, 'wd': 0, 'learning_rate': 0.01}

        compression_params = {
            "compressor": "dithering",
            "k": k,
            "partition": ptype,
            "normalize": ntype,
            "seed": seed
        }
        print(compression_params)

        trainer = bps.DistributedTrainer(net.collect_params(),
                                         "sgd",
                                         optimizer_params,
                                         compression_params=compression_params)

        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

        train_data = fake_data(batch_size=batch_size)

        params = {}
        rngs = {}
        rngs_s = {}

        for i, param in enumerate(trainer._params):
            if param.grad_req != 'null':
                params[i] = param._data[0].asnumpy()
                rngs[i] = np.array([seed, seed], dtype=np.uint64)
                rngs_s[i] = np.array([seed, seed], dtype=np.uint64)

        for it, batch in tqdm(enumerate(train_data)):
            data = batch[0].as_in_context(ctx)
            label = batch[1].as_in_context(ctx)

            with autograd.record():
                output = net(data)
                loss = loss_fn(output, label)

            loss.backward()

            gs = {}
            xs = {}

            for i, param in enumerate(trainer._params):
                if param.grad_req != 'null':
                    gs[i] = param._grad[0].asnumpy()
                    xs[i] = param._data[0].asnumpy()

            trainer.step(batch_size)

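            # reference computation: redo the dithering compression with the same
            # seeds and apply plain SGD to a local copy of the parameters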
            for i, param in enumerate(trainer._params):
                if param.grad_req != "null":
                    g = gs[i] / (batch_size * bps.size())
                    c = dithering(g, k, rngs[i], ptype, ntype)

                    cs = dithering(c, k, rngs_s[i], ptype, ntype)
                    c = cs

                    params[i] -= optimizer_params["learning_rate"] * c

                    np_g = c.flatten()
                    mx_g = param._grad[0].asnumpy().flatten()
                    if not np.allclose(
                            np_g, mx_g, atol=np.finfo(np.float32).eps):
                        diff = np.abs(np_g - mx_g)
                        print("np", np_g)
                        print("mx", mx_g)
                        print("diff", diff)
                        print("max diff", np.max(diff))
                        idx = np.nonzero(diff > 1e-5)
                        print("idx", idx, np_g[idx], mx_g[idx])
                        input()

        cnt = 0
        tot = 0
        for i, param in enumerate(trainer._params):
            if param.grad_req != "null":
                x = param._data[0].asnumpy()
                tot += len(x.flatten())
                if not np.allclose(params[i], x, atol=np.finfo(
                        np.float32).eps):
                    diff = np.abs(x.flatten() - params[i].flatten())
                    idx = np.where(diff > np.finfo(np.float32).eps)
                    cnt += len(idx[0])

        assert cnt == 0, "false/tot=%d/%d=%f" % (cnt, tot, cnt / tot)
Example #10
loss.initialize(ctx=ctx)
loss.hybridize(static_alloc=True)

# trainer
num_batches = train_size // batch_size
train_params = net.collect_params()
train_params.update(loss.params)
lr_scheduler = LRScheduler(mode="cosine",
                           base_lr=lr,
                           target_lr=1e-8,
                           iters_per_epoch=num_batches,
                           nepochs=epochs)
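# BytePS: DistributedTrainer is a drop-in replacement for gluon.Trainer that
# aggregates gradients across workers with push_pull during each step()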
trainer = bps.DistributedTrainer(
    train_params, 'sgd', {
        'momentum': momentum,
        'wd': wd,
        "lr_scheduler": lr_scheduler,
        "multi_precision": True,
        "learning_rate": lr
    })

# metrics
loss_mtc, acc_mtc = mx.metric.Loss(), mx.metric.Accuracy()
tic = time.time()
btic = time.time()

# train loop
for epoch in range(epochs):
    for i, batch in enumerate(train_iter):
        it = epoch * num_batches + i
        data = batch[0].data[0]
        label = batch[0].label[0]
Example #11
                      target_lr=args.lr,
                      iters_per_epoch=num_batches)
cosine = LRScheduler(mode="cosine",
                     base_lr=args.lr,
                     target_lr=1e-8,
                     iters_per_epoch=num_batches,
                     nepochs=epochs,
                     offset=args.warmup_epochs * num_batches)
lr_scheduler = LRSequential([warm_up, cosine])

params = net.collect_params()

trainer = bps.DistributedTrainer(
    params, 'nag', {
        'learning_rate': args.lr,
        'momentum': args.momentum,
        'wd': args.wd,
        "multi_precision": True,
        "lr_scheduler": lr_scheduler
    })

cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=not use_mix_up)

# set log output
train_mode = 'MixUP' if use_mix_up else 'Vanilla'
logger = logging.getLogger('TRAIN')
if rank == 0:
    logger.setLevel("INFO")
    logger.addHandler(logging.StreamHandler())
    logger.addHandler(
        logging.FileHandler(
            os.path.join(