Example #1
    def _init_trainer(self):
        if self.last_train is None:
            raise RuntimeError('Cannot init trainer without knowing the size of training data')
        if isinstance(self.last_train, pd.DataFrame):
            train_size = len(self.last_train)
        elif isinstance(self.last_train, int):
            train_size = self.last_train
        else:
            raise ValueError("Unknown type of self.last_train: {}".format(type(self.last_train)))

        if self._cfg.train.lr_decay_period > 0:
            lr_decay_epoch = list(range(self._cfg.train.lr_decay_period,
                                        self._cfg.train.epochs,
                                        self._cfg.train.lr_decay_period))
        else:
            lr_decay_epoch = [int(i) for i in self._cfg.train.lr_decay_epoch]
        lr_decay_epoch = [e - self._cfg.train.warmup_epochs for e in lr_decay_epoch]
        num_batches = train_size // self._cfg.train.batch_size
        lr_scheduler = LRSequential([
            LRScheduler('linear', base_lr=0, target_lr=self._cfg.train.lr,
                        nepochs=self._cfg.train.warmup_epochs, iters_per_epoch=num_batches),
            LRScheduler(self._cfg.train.lr_mode, base_lr=self._cfg.train.lr,
                        nepochs=self._cfg.train.epochs - self._cfg.train.warmup_epochs,
                        iters_per_epoch=num_batches,
                        step_epoch=lr_decay_epoch,
                        step_factor=self._cfg.train.lr_decay, power=2),
        ])

        if self._cfg.horovod:
            hvd.broadcast_parameters(self.net.collect_params(), root_rank=0)
            self.trainer = hvd.DistributedTrainer(
                self.net.collect_params(), 'sgd',
                {'wd': self._cfg.train.wd, 'momentum': self._cfg.train.momentum, 'lr_scheduler': lr_scheduler})
        else:
            self.trainer = gluon.Trainer(
                self.net.collect_params(), 'sgd',
                {'wd': self._cfg.train.wd, 'momentum': self._cfg.train.momentum, 'lr_scheduler': lr_scheduler},
                kvstore='local', update_on_kvstore=(False if self._cfg.yolo3.amp else None))

        if self._cfg.yolo3.amp:
            amp.init_trainer(self.trainer)
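The snippet above chains a linear warmup phase and a decay phase into one schedule with GluonCV's LRSequential and hands it to the trainer through the optimizer's lr_scheduler option. Below is a minimal, self-contained sketch of the same composition, assuming the gluoncv.utils schedulers used above; the network, epoch counts, and learning rates are placeholder values.

from mxnet import gluon
from gluoncv.utils import LRScheduler, LRSequential

num_batches = 500                      # hypothetical iterations per epoch
warmup_epochs, total_epochs = 2, 20

schedule = LRSequential([
    # linear warmup from 0 to the base learning rate
    LRScheduler('linear', base_lr=0, target_lr=0.01,
                nepochs=warmup_epochs, iters_per_epoch=num_batches),
    # step decay afterwards (warmup epochs are subtracted, as in the code above)
    LRScheduler('step', base_lr=0.01,
                nepochs=total_epochs - warmup_epochs,
                iters_per_epoch=num_batches,
                step_epoch=[10 - warmup_epochs, 15 - warmup_epochs],
                step_factor=0.1),
])

net = gluon.nn.Dense(10)
net.initialize()
trainer = gluon.Trainer(net.collect_params(), 'sgd',
                        {'momentum': 0.9, 'wd': 1e-4, 'lr_scheduler': schedule})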
Example #2
    def _init_trainer(self):
        if self._cfg.horovod:
            hvd.broadcast_parameters(self.net.collect_params(), root_rank=0)
            self.trainer = hvd.DistributedTrainer(
                self.net.collect_params(), 'sgd', {
                    'learning_rate': self._cfg.train.lr,
                    'wd': self._cfg.train.wd,
                    'momentum': self._cfg.train.momentum
                })
        else:
            self.trainer = gluon.Trainer(
                self.net.collect_params(),
                'sgd', {
                    'learning_rate': self._cfg.train.lr,
                    'wd': self._cfg.train.wd,
                    'momentum': self._cfg.train.momentum
                },
                update_on_kvstore=(False if self._cfg.ssd.amp else None))

        if self._cfg.ssd.amp:
            amp.init_trainer(self.trainer)
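In both snippets above, the non-Horovod trainer is created with update_on_kvstore=False whenever AMP is enabled, and amp.init_trainer is called afterwards to attach dynamic loss scaling. A rough sketch of that wiring outside the estimator classes might look as follows (MXNet 1.x contrib AMP, a single GPU assumed; the tiny network and hyper-parameters are placeholders):

import mxnet as mx
from mxnet import gluon, autograd
from mxnet.contrib import amp

amp.init()                                      # patch operators for mixed precision first
ctx = mx.gpu(0)

net = gluon.nn.Dense(10)
net.initialize(ctx=ctx)
trainer = gluon.Trainer(net.collect_params(), 'sgd',
                        {'learning_rate': 0.01, 'momentum': 0.9},
                        update_on_kvstore=False)  # AMP expects the trainer to own the update
amp.init_trainer(trainer)                       # enable dynamic loss scaling

loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
x = mx.nd.random.uniform(shape=(4, 8), ctx=ctx)
y = mx.nd.zeros((4,), ctx=ctx)
with autograd.record():
    loss = loss_fn(net(x), y)
    with amp.scale_loss(loss, trainer) as scaled_loss:
        autograd.backward(scaled_loss)
trainer.step(4)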
Example #3
    def _init_trainer(self):
        kv_store_type = 'device' if (self._cfg.faster_rcnn.amp and 'nccl' in self._cfg.kv_store) \
            else self._cfg.kv_store
        kv = mx.kvstore.create(kv_store_type)
        optimizer_params = {'learning_rate': self._cfg.train.lr, 'wd': self._cfg.train.wd,
                            'momentum': self._cfg.train.momentum}
        if self._cfg.faster_rcnn.amp:
            optimizer_params['multi_precision'] = True
        if self._cfg.horovod:
            hvd.broadcast_parameters(self.net.collect_params(), root_rank=0)
            self.trainer = hvd.DistributedTrainer(
                self.net.collect_train_params(),  # fix batchnorm, fix first stage, etc...
                'sgd',
                optimizer_params)
        else:
            self.trainer = gluon.Trainer(
                self.net.collect_train_params(),  # fix batchnorm, fix first stage, etc...
                'sgd',
                optimizer_params,
                update_on_kvstore=(False if self._cfg.faster_rcnn.amp else None), kvstore=kv)

        if self._cfg.faster_rcnn.amp:
            amp.init_trainer(self.trainer)
Example #4
def train(args):
    _, num_workers, rank, local_rank, is_master_node, ctx_l = init_comm(
        args.comm_backend, args.gpus)
    level = logging.DEBUG if args.verbose else logging.INFO
    logging_config(
        args.ckpt_dir,
        name='pretrain_bert_' + str(rank),  # avoid race
        level=level,
        console=(local_rank == 0))
    logging.info(args)
    logging.debug('Random seed set to {}'.format(args.seed))
    set_seed(args.seed)
    logging.info('Training info: num_buckets: {}, '
                 'num_workers: {}, rank: {}'.format(args.num_buckets,
                                                    num_workers, rank))
    cfg, tokenizer, model = get_pretraining_model(args.model_name, ctx_l)
    if args.start_step:
        logging.info('Restart training from {}'.format(args.start_step))
        parameters_option(args.start_step, model, args.ckpt_dir, 'Loading',
                          ctx_l)
    else:
        model.initialize(ctx=ctx_l)
    model.hybridize()

    if args.raw:
        get_dataset_fn = functools.partial(
            get_pretrain_data_text,
            max_seq_length=args.max_seq_length,
            short_seq_prob=args.short_seq_prob,
            masked_lm_prob=args.masked_lm_prob,
            max_predictions_per_seq=args.max_predictions_per_seq,
            whole_word_mask=args.whole_word_mask,
            random_next_sentence=args.random_next_sentence,
            tokenizer=tokenizer,
            circle_length=args.circle_length,
            repeat=args.repeat,
            dataset_cached=args.dataset_cached,
            num_max_dataset_cached=args.num_max_dataset_cached)
    else:
        get_dataset_fn = get_pretrain_data_npz

    data_train = get_dataset_fn(args.data,
                                args.batch_size,
                                shuffle=True,
                                num_buckets=args.num_buckets,
                                vocab=tokenizer.vocab,
                                num_parts=num_workers,
                                part_idx=rank,
                                num_dataset_workers=args.num_dataset_workers,
                                num_batch_workers=args.num_batch_workers)

    param_dict = model.collect_params()
    # Do not apply weight decay to LayerNorm parameters and bias terms
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    # Set grad_req if gradient accumulation is required
    params = [p for p in param_dict.values() if p.grad_req != 'null']
    num_accumulated = args.num_accumulated
    if num_accumulated > 1:
        logging.info(
            'Using gradient accumulation. Effective global batch size = {}'.
            format(num_accumulated * args.batch_size * len(ctx_l) *
                   num_workers))
        for p in params:
            p.grad_req = 'add'

    num_steps = args.num_steps
    warmup_steps = int(num_steps * args.warmup_ratio)
    log_interval = args.log_interval
    save_interval = args.ckpt_interval
    logging.info(
        '#Total Training Steps={}, Warmup Steps={}, Save Interval={}'.format(
            num_steps, warmup_steps, save_interval))
    optimizer_params = {'learning_rate': args.lr, 'wd': args.wd}
    if args.optimizer == 'adamw':
        optimizer_params.update({
            'beta1': 0.9,
            'beta2': 0.999,
            'epsilon': 1e-6,
            'correct_bias': False,
        })
    if args.comm_backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, args.optimizer,
                                         optimizer_params)
    elif args.comm_backend == 'byteps':
        trainer = bps.DistributedTrainer(param_dict, args.optimizer,
                                         optimizer_params)
    else:
        trainer = mx.gluon.Trainer(param_dict,
                                   args.optimizer,
                                   optimizer_params,
                                   update_on_kvstore=False)
    if args.start_step:
        logging.info('Restart training from {}'.format(args.start_step))
        states_option(args.start_step, trainer, args.ckpt_dir, local_rank,
                      'Loading')

    # backend specific implementation
    if args.comm_backend == 'byteps':
        trainer._init_params()
    if args.comm_backend == 'horovod':
        # Horovod: fetch and broadcast parameters
        hvd.broadcast_parameters(param_dict, root_rank=0)

    # prepare the loss function
    nsp_loss_fn = mx.gluon.loss.SoftmaxCELoss()
    mlm_loss_fn = mx.gluon.loss.SoftmaxCELoss()
    nsp_loss_fn.hybridize()
    mlm_loss_fn.hybridize()

    mlm_metric = MaskedAccuracy()
    nsp_metric = MaskedAccuracy()
    mlm_metric.reset()
    nsp_metric.reset()

    step_num = args.start_step
    if args.phase2:
        step_num -= args.phase1_num_steps

    running_mlm_loss, running_nsp_loss = 0., 0.
    running_num_tks = 0

    train_start_time = time.time()
    tic = time.time()
    # start training
    train_loop_dataloader = grouper(repeat(data_train), len(ctx_l))
    while step_num < num_steps:
        step_num += 1
        for _ in range(num_accumulated):
            sample_l = next(train_loop_dataloader)
            mlm_loss_l = []
            nsp_loss_l = []
            loss_l = []
            ns_label_list, ns_pred_list = [], []
            mask_label_list, mask_pred_list, mask_weight_list = [], [], []
            for sample, ctx in zip(sample_l, ctx_l):
                # prepare data
                (input_id, masked_id, masked_position, masked_weight, \
                    next_sentence_label, segment_id, valid_length) = sample
                input_id = input_id.as_in_ctx(ctx)
                masked_id = masked_id.as_in_ctx(ctx)
                masked_position = masked_position.as_in_ctx(ctx)
                masked_weight = masked_weight.as_in_ctx(ctx)
                next_sentence_label = next_sentence_label.as_in_ctx(ctx)
                segment_id = segment_id.as_in_ctx(ctx)
                valid_length = valid_length.as_in_ctx(ctx)

                with mx.autograd.record():
                    _, _, nsp_score, mlm_scores = model(
                        input_id, segment_id, valid_length, masked_position)
                    denominator = (masked_weight.sum() +
                                   1e-8) * num_accumulated * len(ctx_l)
                    mlm_scores_r = mx.npx.reshape(mlm_scores, (-5, -1))
                    masked_id_r = masked_id.reshape((-1, ))
                    mlm_loss = mlm_loss_fn(mlm_scores_r, masked_id_r,
                                           masked_weight.reshape(
                                               (-1, 1))).sum() / denominator
                    denominator = num_accumulated * len(ctx_l)
                    nsp_loss = nsp_loss_fn(
                        nsp_score, next_sentence_label).mean() / denominator
                    mlm_loss_l.append(mlm_loss)
                    nsp_loss_l.append(nsp_loss)
                    loss_l.append(mlm_loss + nsp_loss)
                    mask_label_list.append(masked_id_r)
                    mask_pred_list.append(mlm_scores_r)
                    mask_weight_list.append(masked_weight.reshape((-1, )))
                    ns_label_list.append(next_sentence_label)
                    ns_pred_list.append(nsp_score)

                running_num_tks += valid_length.sum().as_in_ctx(mx.cpu())

            for loss in loss_l:
                loss.backward()

            running_mlm_loss += sum([
                ele.as_in_ctx(mx.cpu()) for ele in mlm_loss_l
            ]).asnumpy().item()
            running_nsp_loss += sum([
                ele.as_in_ctx(mx.cpu()) for ele in nsp_loss_l
            ]).asnumpy().item()
            mlm_metric.update(mask_label_list, mask_pred_list,
                              mask_weight_list)
            nsp_metric.update(ns_label_list, ns_pred_list)
        # update
        trainer.allreduce_grads()

        total_norm, ratio, is_finite = clip_grad_global_norm(
            params, args.max_grad_norm * num_workers)
        total_norm = total_norm / num_workers

        # update learning rate
        scheduled_lr = args.lr
        if step_num <= warmup_steps:
            scheduled_lr *= step_num / warmup_steps
        else:
            offset = (num_steps - step_num) / (num_steps - warmup_steps)
            scheduled_lr *= max(offset, 0)
        trainer.set_learning_rate(scheduled_lr)

        if args.comm_backend == 'horovod' or args.comm_backend == 'byteps':
            # Note that horovod.trainer._scale defaults to num_workers,
            # thus trainer.update(1) will scale the gradients by 1./num_workers.
            # *num_workers* of Horovod is the number of GPUs.
            trainer.update(1, ignore_stale_grad=True)
        else:
            # gluon.trainer._scale defaults to 1.
            # *num_workers* of Trainer is the number of machines.
            trainer.update(num_workers, ignore_stale_grad=True)

        if num_accumulated > 1:
            # set grad to zero for gradient accumulation
            model.zero_grad()

        # saving
        if step_num % save_interval == 0 or step_num >= num_steps:
            states_option(step_num, trainer, args.ckpt_dir, local_rank,
                          'Saving')
            if local_rank == 0:
                parameters_option(step_num, model, args.ckpt_dir, 'Saving')
        # logging
        if step_num % log_interval == 0:
            running_mlm_loss /= log_interval
            running_nsp_loss /= log_interval
            toc = time.time()
            logging.info(
                '[step {}], Loss mlm/nsp={:.5f}/{:.3f}, Acc mlm/nsp={:.3f}/{:.3f}, '
                ' LR={:.7f}, grad_norm={:.4f}. Time cost={:.2f} s,'
                ' Throughput={:.1f}K tks/s, ETA={:.2f} h'.format(
                    step_num, running_mlm_loss, running_nsp_loss,
                    mlm_metric.get()[1],
                    nsp_metric.get()[1], trainer.learning_rate, total_norm,
                    toc - tic,
                    running_num_tks.asnumpy().item() / (toc - tic) / 1000,
                    (num_steps - step_num) /
                    (step_num / (toc - train_start_time)) / 3600))
            mlm_metric.reset()
            nsp_metric.reset()
            tic = time.time()

            running_mlm_loss = 0
            running_nsp_loss = 0
            running_num_tks = 0

    logging.info('Finish training step: %d', step_num)

    mx.npx.waitall()
    train_end_time = time.time()
    logging.info('Train cost={:.1f} s'.format(train_end_time -
                                              train_start_time))

    if local_rank == 0:
        model_name = args.model_name.replace('google', 'gluon')
        save_dir = os.path.join(args.ckpt_dir, model_name)
        final_save(model, save_dir, tokenizer, cfg)
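The loop above updates the learning rate by hand rather than through an lr_scheduler object: linear warmup to args.lr over warmup_steps, then linear decay towards zero at num_steps, clipped at zero. The same schedule as a standalone helper (pure Python; the names are placeholders):

def warmup_linear_decay(step, base_lr, warmup_steps, total_steps):
    """Learning rate at `step`, matching the manual schedule in the loop above."""
    if step <= warmup_steps:
        return base_lr * step / warmup_steps
    # linear decay towards zero, clipped so the rate never goes negative
    return base_lr * max((total_steps - step) / (total_steps - warmup_steps), 0)

# e.g. warmup_linear_decay(4500, 1e-4, 9000, 900000) == 5e-05 (halfway through warmup)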
Example #5
def train():
    """Training function."""
    segment = 'train'  #if not args.debug else 'dev'
    log.info('Loading %s data...', segment)
    if version_2:
        train_data = SQuAD(segment, version='2.0')
    else:
        train_data = SQuAD(segment, version='1.1')
    if args.debug:
        sampled_data = [train_data[i] for i in range(0, 10000)]
        train_data = mx.gluon.data.SimpleDataset(sampled_data)
    log.info('Number of records in Train data:{}'.format(len(train_data)))
    train_data_transform = preprocess_dataset(
        tokenizer,
        train_data,
        max_seq_length=max_seq_length,
        doc_stride=doc_stride,
        max_query_length=max_query_length,
        input_features=True)

    log.info('The number of examples after preprocessing:{}'.format(
        len(train_data_transform)))

    sampler = nlp.data.SplitSampler(len(train_data_transform),
                                    num_parts=size,
                                    part_index=rank,
                                    even_size=True)
    num_train_examples = len(sampler)
    train_dataloader = mx.gluon.data.DataLoader(train_data_transform,
                                                batchify_fn=batchify_fn,
                                                batch_size=batch_size,
                                                num_workers=4,
                                                sampler=sampler)

    log.info('Start Training')

    optimizer_params = {'learning_rate': lr}
    param_dict = net.collect_params()
    if args.comm_backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, optimizer,
                                         optimizer_params)
    else:
        trainer = mx.gluon.Trainer(param_dict,
                                   optimizer,
                                   optimizer_params,
                                   update_on_kvstore=False)
    if args.dtype == 'float16':
        amp.init_trainer(trainer)

    step_size = batch_size * accumulate if accumulate else batch_size
    num_train_steps = int(num_train_examples / step_size * args.epochs)
    if args.training_steps:
        num_train_steps = args.training_steps

    num_warmup_steps = int(num_train_steps * warmup_ratio)

    def set_new_lr(step_num, batch_id):
        """set new learning rate"""
        # set grad to zero for gradient accumulation
        if accumulate:
            if batch_id % accumulate == 0:
                step_num += 1
        else:
            step_num += 1
        # learning rate schedule
        # Note: linear warmup followed by linear decay; for step_num >= num_warmup_steps,
        # new_lr = lr * (1 - (step_num - num_warmup_steps) / (num_train_steps - num_warmup_steps))
        if step_num < num_warmup_steps:
            new_lr = lr * step_num / num_warmup_steps
        else:
            offset = (step_num - num_warmup_steps) * lr / \
                (num_train_steps - num_warmup_steps)
            new_lr = lr - offset
        trainer.set_learning_rate(new_lr)
        return step_num

    # Do not apply weight decay on LayerNorm and bias terms
    for _, v in net.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    # Collect differentiable parameters
    params = [p for p in param_dict.values() if p.grad_req != 'null']

    # Set grad_req if gradient accumulation is required
    if accumulate:
        for p in params:
            p.grad_req = 'add'
    net.collect_params().zero_grad()

    epoch_tic = time.time()

    total_num = 0
    log_num = 0
    batch_id = 0
    step_loss = 0.0
    tic = time.time()
    step_num = 0

    tic = time.time()
    while step_num < num_train_steps:
        for _, data in enumerate(train_dataloader):
            # set new lr
            step_num = set_new_lr(step_num, batch_id)
            # forward and backward
            _, inputs, token_types, valid_length, start_label, end_label = data
            num_labels = len(inputs)
            log_num += num_labels
            total_num += num_labels

            with mx.autograd.record():
                out = net(inputs.as_in_context(ctx),
                          token_types.as_in_context(ctx),
                          valid_length.as_in_context(ctx).astype('float32'))

                loss = loss_function(out, [
                    start_label.as_in_context(ctx).astype('float32'),
                    end_label.as_in_context(ctx).astype('float32')
                ]).sum() / num_labels

                if accumulate:
                    loss = loss / accumulate
                if args.dtype == 'float16':
                    with amp.scale_loss(loss, trainer) as l:
                        mx.autograd.backward(l)
                        norm_clip = 1.0 * size * trainer._amp_loss_scaler.loss_scale
                else:
                    mx.autograd.backward(loss)
                    norm_clip = 1.0 * size

            # update
            if not accumulate or (batch_id + 1) % accumulate == 0:
                trainer.allreduce_grads()
                nlp.utils.clip_grad_global_norm(params, norm_clip)
                trainer.update(1)
                if accumulate:
                    param_dict.zero_grad()

            if args.comm_backend == 'horovod':
                step_loss += hvd.allreduce(loss, average=True).asscalar()
            else:
                step_loss += loss.asscalar()

            if (batch_id + 1) % log_interval == 0:
                toc = time.time()
                log.info('Batch: {}/{}, Loss={:.4f}, lr={:.7f} '
                         'Throughput={:.2f} samples/s'.format(
                             batch_id % len(train_dataloader),
                             len(train_dataloader), step_loss / log_interval,
                             trainer.learning_rate, log_num / (toc - tic)))
                tic = time.time()
                step_loss = 0.0
                log_num = 0

            if step_num >= num_train_steps:
                break
            batch_id += 1

        log.info('Finish training step: %d', step_num)
        epoch_toc = time.time()
        log.info('Time cost={:.2f} s, Throughput={:.2f} samples/s'.format(
            epoch_toc - epoch_tic, total_num / (epoch_toc - epoch_tic)))

    if rank == 0:
        net.save_parameters(os.path.join(output_dir, 'net.params'))
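When accumulate is set, the function above switches every parameter's grad_req to 'add', scales the loss by 1/accumulate, and only reduces, clips, and applies gradients every accumulate batches before zeroing them again. A minimal self-contained sketch of that accumulation pattern (toy model and random data; only the wiring mirrors the code above):

import mxnet as mx
from mxnet import gluon, autograd

accumulate = 4
net = gluon.nn.Dense(2)
net.initialize()
params = [p for p in net.collect_params().values() if p.grad_req != 'null']
for p in params:
    p.grad_req = 'add'                          # accumulate gradients instead of overwriting
trainer = gluon.Trainer(net.collect_params(), 'sgd',
                        {'learning_rate': 0.1}, update_on_kvstore=False)
loss_fn = gluon.loss.L2Loss()

for batch_id in range(8):
    x = mx.nd.random.uniform(shape=(4, 8))
    y = mx.nd.zeros((4, 2))
    with autograd.record():
        loss = loss_fn(net(x), y).mean() / accumulate   # keep the effective scale constant
    loss.backward()
    if (batch_id + 1) % accumulate == 0:
        trainer.allreduce_grads()
        trainer.update(1)                       # apply the summed gradients once
        net.collect_params().zero_grad()        # reset for the next accumulation window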
Example #6
    def train(ctx):
        if opt.resume_params == '':
            net.initialize(mx.init.MSRAPrelu(), ctx=ctx)

        if opt.no_wd:
            for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
                v.wd_mult = 0.0

        # Horovod: fetch and broadcast parameters
        params = net.collect_params()
        if params is not None:
            hvd.broadcast_parameters(params, root_rank=0)

        trainer = hvd.DistributedTrainer(params, optimizer, optimizer_params)
        if opt.resume_states != '':
            trainer.load_states(opt.resume_states)

        if opt.label_smoothing or opt.mixup:
            sparse_label_loss = False
        else:
            sparse_label_loss = True
        if distillation:
            L = gcv.loss.DistillationSoftmaxCrossEntropyLoss(
                temperature=opt.temperature,
                hard_weight=opt.hard_weight,
                sparse_label=sparse_label_loss)
        else:
            L = gluon.loss.SoftmaxCrossEntropyLoss(
                sparse_label=sparse_label_loss)

        best_val_score = 1

        for epoch in range(opt.resume_epoch, opt.num_epochs):
            tic = time.time()
            if opt.use_rec:
                train_data.reset()
            train_metric.reset()
            btic = time.time()

            for i, batch in enumerate(train_data):
                data, label = batch_fn(batch, ctx)

                if opt.mixup:
                    lam = np.random.beta(opt.mixup_alpha, opt.mixup_alpha)
                    if epoch >= opt.num_epochs - opt.mixup_off_epoch:
                        lam = 1
                    data = [lam * X + (1 - lam) * X[::-1] for X in data]

                    if opt.label_smoothing:
                        eta = 0.1
                    else:
                        eta = 0.0
                    label = mixup_transform(label, classes, lam, eta)

                elif opt.label_smoothing:
                    hard_label = label
                    label = smooth(label, classes)

                with ag.record():
                    outputs = [
                        net(X.astype(opt.dtype, copy=False)) for X in data
                    ]
                    loss = [
                        L(yhat, y.astype(opt.dtype, copy=False))
                        for yhat, y in zip(outputs, label)
                    ]
                for l in loss:
                    l.backward()
                trainer.step(batch_size)

                if opt.label_smoothing:
                    train_metric.update(hard_label, outputs)
                else:
                    train_metric.update(label, outputs)

                if opt.log_interval and (i + 1) % opt.log_interval == 0:
                    train_metric_name, train_metric_score = train_metric.get()
                    if rank == 0:
                        logger.info(
                            'Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\tlr=%f'
                            % (epoch, i,
                               num_gpus * batch_size * opt.log_interval /
                               (time.time() - btic), train_metric_name,
                               train_metric_score, trainer.learning_rate))
                    btic = time.time()

            elapsed = time.time() - tic
            train_metric_name, train_metric_score = train_metric.get()
            logger.info(
                'Epoch[%d] Rank [%d] Batch[%d]\tTraining: %s=%f' %
                (epoch, rank, i, train_metric_name, train_metric_score))

            if opt.eval_frequency and (epoch + 1) % opt.eval_frequency == 0:
                evaluate(epoch)

            if rank == 0:
                throughput = int(num_gpus * batch_size * i / elapsed)
                logger.info('Epoch [%d] Speed: %d samples/sec\ttime cost: %f' %
                            (epoch, throughput, elapsed))
            #if err_top1_val < best_val_score:
            #    best_val_score = err_top1_val
            #    net.save_parameters('%s/%.4f-imagenet-%s-%d-best.params'%(save_dir, best_val_score, model_name, epoch))
            #    trainer.save_states('%s/%.4f-imagenet-%s-%d-best.states'%(save_dir, best_val_score, model_name, epoch))

            #if save_frequency and save_dir and (epoch + 1) % save_frequency == 0:
            #    net.save_parameters('%s/imagenet-%s-%d.params'%(save_dir, model_name, epoch))
            #    trainer.save_states('%s/imagenet-%s-%d.states'%(save_dir, model_name, epoch))

        #if save_frequency and save_dir:
        #    net.save_parameters('%s/imagenet-%s-%d.params'%(save_dir, model_name, opt.num_epochs-1))
        #    trainer.save_states('%s/imagenet-%s-%d.states'%(save_dir, model_name, opt.num_epochs-1))

        # Evaluate performance at the end of training
        evaluate(epoch)
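The mixup branch above draws a Beta(alpha, alpha) coefficient, blends each batch with its reversed copy, and mixes the (optionally smoothed) one-hot labels with the same coefficient. A small NumPy sketch of that transform; the helpers below are hypothetical stand-ins for GluonCV's mixup_transform and smooth, whose exact smoothing convention may differ:

import numpy as np

def one_hot_smooth(labels, classes, eta=0.0):
    """One-hot labels with optional label smoothing."""
    out = np.full((len(labels), classes), eta / classes, dtype=np.float32)
    out[np.arange(len(labels)), labels] += 1.0 - eta
    return out

def mixup_batch(data, labels, classes, alpha=0.2, eta=0.0):
    """Blend a batch with its reversed copy, as in the training loop above."""
    lam = np.random.beta(alpha, alpha)
    mixed_data = lam * data + (1 - lam) * data[::-1]
    y = one_hot_smooth(labels, classes, eta)
    mixed_labels = lam * y + (1 - lam) * y[::-1]
    return mixed_data, mixed_labels

# data: (batch, C, H, W) float32 array, labels: (batch,) int array
# mixed_x, mixed_y = mixup_batch(data, labels, classes=1000, alpha=0.2, eta=0.1)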
Example #7
def train(data_train, data_eval, model, nsp_loss, mlm_loss, vocab_size, ctx):
    """Training function."""
    hvd.broadcast_parameters(model.collect_params(), root_rank=0)

    mlm_metric = nlp.metric.MaskedAccuracy()
    nsp_metric = nlp.metric.MaskedAccuracy()
    mlm_metric.reset()
    nsp_metric.reset()

    logging.debug('Creating distributed trainer...')
    lr = args.lr
    optim_params = {'learning_rate': lr, 'epsilon': 1e-6, 'wd': 0.01}
    if args.dtype == 'float16':
        optim_params['multi_precision'] = True

    dynamic_loss_scale = args.dtype == 'float16'
    if dynamic_loss_scale:
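        # scale_window ~ number of consecutive overflow-free iterations before the
        # loss scale is raised again; here the default window is divided across workers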
        loss_scale_param = {'scale_window': 2000 / num_workers}
    else:
        loss_scale_param = None
    trainer = hvd.DistributedTrainer(model.collect_params(), 'bertadam', optim_params)
    fp16_trainer = FP16Trainer(trainer, dynamic_loss_scale=dynamic_loss_scale,
                               loss_scaler_params=loss_scale_param)

    if args.start_step:
        state_path = os.path.join(args.ckpt_dir, '%07d.states.%02d'%(args.start_step, local_rank))
        logging.info('Loading trainer state from %s', state_path)
        nlp.utils.load_states(trainer, state_path)

    accumulate = args.accumulate
    num_train_steps = args.num_steps
    warmup_ratio = args.warmup_ratio
    num_warmup_steps = int(num_train_steps * warmup_ratio)
    params = [p for p in model.collect_params().values() if p.grad_req != 'null']
    param_dict = model.collect_params()

    # Do not apply weight decay on LayerNorm and bias terms
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    if accumulate > 1:
        for p in params:
            p.grad_req = 'add'

    train_begin_time = time.time()
    begin_time = time.time()
    running_mlm_loss, running_nsp_loss = 0, 0
    running_num_tks = 0
    batch_num = 0
    step_num = args.start_step

    logging.debug('Training started')
    while step_num < num_train_steps:
        for _, dataloader in enumerate(data_train):
            if step_num >= num_train_steps:
                break

            # create dummy data loader if needed
            if args.dummy_data_len:
                target_shape = (args.batch_size, args.dummy_data_len)
                dataloader = get_dummy_dataloader(dataloader, target_shape)

            for _, data_batch in enumerate(dataloader):
                if step_num >= num_train_steps:
                    break
                if batch_num % accumulate == 0:
                    step_num += 1
                    # if accumulate > 1, grad_req is set to 'add', and zero_grad is required
                    if accumulate > 1:
                        param_dict.zero_grad()
                    # update learning rate
                    if step_num <= num_warmup_steps:
                        new_lr = lr * step_num / num_warmup_steps
                    else:
                        offset = lr * step_num / num_train_steps
                        new_lr = lr - offset
                    trainer.set_learning_rate(new_lr)
                    if args.profile:
                        profile(step_num, 10, 14, profile_name=args.profile + str(rank))

                # load data
                if args.use_avg_len:
                    data_list = [[seq.as_in_context(context) for seq in shard]
                                 for context, shard in zip([ctx], data_batch)]
                else:
                    data_list = list(split_and_load(data_batch, [ctx]))
                data = data_list[0]

                # forward
                with mx.autograd.record():
                    (ls, ns_label, classified, masked_id, decoded, \
                     masked_weight, ls1, ls2, valid_len) = forward(data, model, mlm_loss,
                                                                   nsp_loss, vocab_size, args.dtype)
                    ls = ls / accumulate
                    # backward
                    if args.dtype == 'float16':
                        fp16_trainer.backward(ls)
                    else:
                        ls.backward()

                running_mlm_loss += ls1.as_in_context(mx.cpu())
                running_nsp_loss += ls2.as_in_context(mx.cpu())
                running_num_tks += valid_len.sum().as_in_context(mx.cpu())

                # update
                if (batch_num + 1) % accumulate == 0:
                    # step() performs 3 things:
                    # 1. allreduce gradients from all workers
                    # 2. checking the global_norm of gradients and clip them if necessary
                    # 3. averaging the gradients and apply updates
                    fp16_trainer.step(1, max_norm=1*num_workers)

                nsp_metric.update([ns_label], [classified])
                mlm_metric.update([masked_id], [decoded], [masked_weight])

                # logging
                if (step_num + 1) % (args.log_interval) == 0 and (batch_num + 1) % accumulate == 0:
                    log(begin_time, running_num_tks, running_mlm_loss / accumulate,
                        running_nsp_loss / accumulate, step_num, mlm_metric, nsp_metric,
                        trainer, args.log_interval)
                    begin_time = time.time()
                    running_mlm_loss = running_nsp_loss = running_num_tks = 0
                    mlm_metric.reset_local()
                    nsp_metric.reset_local()

                # saving checkpoints
                if (step_num + 1) % args.ckpt_interval == 0 and (batch_num + 1) % accumulate == 0:
                    if is_master_node:
                        save_states(step_num, trainer, args.ckpt_dir, local_rank)
                        if local_rank == 0:
                            save_parameters(step_num, model, args.ckpt_dir)
                    if data_eval:
                        # eval data is always based on a fixed npz file.
                        dataset_eval = get_pretrain_data_npz(data_eval, args.batch_size_eval, 1,
                                                             False, False, 1)
                        evaluate(dataset_eval, model, nsp_loss, mlm_loss, len(vocab), [ctx],
                                 args.log_interval, args.dtype)

                batch_num += 1

    if is_master_node:
        save_states(step_num, trainer, args.ckpt_dir, local_rank)
        if local_rank == 0:
            save_parameters(step_num, model, args.ckpt_dir)
    mx.nd.waitall()
    train_end_time = time.time()
    logging.info('Train cost={:.1f}s'.format(train_end_time - train_begin_time))
Example #8
def train_gluon():
    def evaluate(epoch):
        if not args.use_rec:
            return

        val_data.reset()
        acc_top1 = mx.metric.Accuracy()
        acc_top5 = mx.metric.TopKAccuracy(5)
        for _, batch in enumerate(val_data):
            data, label = get_data_label(batch, context)
            output = net(data.astype(args.dtype, copy=False))
            acc_top1.update([label], [output])
            acc_top5.update([label], [output])

        top1_name, top1_acc = acc_top1.get()
        top5_name, top5_acc = acc_top5.get()
        logging.info('Epoch[%d] Rank[%d]\tValidation-%s=%f\tValidation-%s=%f',
                     epoch, rank, top1_name, top1_acc, top5_name, top5_acc)

    # Hybridize and initialize model
    net.hybridize()
    net.initialize(initializer, ctx=context)

    # Horovod: fetch and broadcast parameters
    params = net.collect_params()
    if params is not None:
        hvd.broadcast_parameters(params, root_rank=0)

    # Create optimizer
    optimizer_params = {
        'wd': args.wd,
        'momentum': args.momentum,
        'lr_scheduler': lr_sched
    }
    if args.dtype == 'float16':
        optimizer_params['multi_precision'] = True
    opt = mx.optimizer.create('sgd', **optimizer_params)

    # Horovod: create DistributedTrainer, a subclass of gluon.Trainer
    trainer = hvd.DistributedTrainer(
        params, opt, gradient_predivide_factor=args.gradient_predivide_factor)

    # Create loss function and train metric
    loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
    metric = mx.metric.Accuracy()

    # Train model
    for epoch in range(args.num_epochs):
        tic = time.time()
        if args.use_rec:
            train_data.reset()
        metric.reset()

        btic = time.time()
        for nbatch, batch in enumerate(train_data, start=1):
            data, label = get_data_label(batch, context)
            with autograd.record():
                output = net(data.astype(args.dtype, copy=False))
                loss = loss_fn(output, label)
            loss.backward()
            trainer.step(batch_size)

            metric.update([label], [output])
            if args.log_interval and nbatch % args.log_interval == 0:
                name, acc = metric.get()
                logging.info('Epoch[%d] Rank[%d] Batch[%d]\t%s=%f\tlr=%f',
                             epoch, rank, nbatch, name, acc,
                             trainer.learning_rate)
                if rank == 0:
                    batch_speed = num_workers * batch_size * args.log_interval / (
                        time.time() - btic)
                    logging.info(
                        'Epoch[%d] Batch[%d]\tSpeed: %.2f samples/sec', epoch,
                        nbatch, batch_speed)
                btic = time.time()

        # Report metrics
        elapsed = time.time() - tic
        _, acc = metric.get()
        logging.info(
            'Epoch[%d] Rank[%d] Batch[%d]\tTime cost=%.2f\tTrain-accuracy=%f',
            epoch, rank, nbatch, elapsed, acc)
        if rank == 0:
            epoch_speed = num_workers * batch_size * nbatch / elapsed
            logging.info('Epoch[%d]\tSpeed: %.2f samples/sec', epoch,
                         epoch_speed)

        # Evaluate performance
        if args.eval_frequency and (epoch + 1) % args.eval_frequency == 0:
            evaluate(epoch)

        # Save model
        if args.save_frequency and (epoch + 1) % args.save_frequency == 0:
            net.export('%s-%d' % (args.model, rank), epoch=epoch)

    # Evaluate performance at the end of training
    evaluate(epoch)
Example #9
def train(args):
    _, num_parts, rank, local_rank, _, ctx_l = init_comm(
        args.comm_backend, args.gpus)
    if args.comm_backend == 'horovod':
        logging_config(
            args.save_dir,
            name=f'train_transformer_rank{rank}_local{local_rank}_{num_parts}',
            console=(rank == 0))
        logging.info(args)
    else:
        logging_config(args.save_dir, name='train_transformer', console=True)
        logging.info(args)
    use_amp = args.fp16
    if use_amp:
        from mxnet import amp
    src_tokenizer = create_tokenizer(args.src_tokenizer,
                                     args.src_subword_model_path,
                                     args.src_vocab_path)
    tgt_tokenizer = create_tokenizer(args.tgt_tokenizer,
                                     args.tgt_subword_model_path,
                                     args.tgt_vocab_path)
    base_tgt_tokenizer = MosesTokenizer(args.tgt_lang)
    src_vocab = src_tokenizer.vocab
    tgt_vocab = tgt_tokenizer.vocab
    train_src_data, train_tgt_data = load_dataset_with_cache(
        args.train_src_corpus,
        args.train_tgt_corpus,
        src_tokenizer,
        tgt_tokenizer,
        args.overwrite_cache,
        local_rank,
        max_src_length=args.max_src_length,
        max_tgt_length=args.max_tgt_length,
        pretokenized=not args.tokenize)
    dev_src_data, dev_tgt_data = load_dataset_with_cache(
        args.dev_src_corpus,
        args.dev_tgt_corpus,
        src_tokenizer,
        tgt_tokenizer,
        args.overwrite_cache,
        local_rank,
        pretokenized=not args.tokenize)
    tgt_detok_sentences = []
    tgt_raw_sentences = []
    with open(args.dev_tgt_corpus, 'r') as in_f:
        for line in in_f:
            tgt_detok_sentences.append(
                base_tgt_tokenizer.decode(
                    tgt_tokenizer.decode(line.split()).split()))
    with open(args.dev_tgt_raw_corpus, 'r') as in_f:
        for line in in_f:
            tgt_raw_sentences.append(line.strip())
    data_train = gluon.data.SimpleDataset([
        (src_tokens, tgt_tokens, len(src_tokens), len(tgt_tokens), i)
        for i, (src_tokens,
                tgt_tokens) in enumerate(zip(train_src_data, train_tgt_data))
    ])
    val_samples = [
        (src_tokens, tgt_tokens, len(src_tokens), len(tgt_tokens), i)
        for i, (src_tokens,
                tgt_tokens) in enumerate(zip(dev_src_data, dev_tgt_data))
    ]
    if args.comm_backend == 'horovod':
        slice_begin = rank * (len(val_samples) // num_parts)
        slice_end = min((rank + 1) * (len(val_samples) // num_parts),
                        len(val_samples))
        data_val = gluon.data.SimpleDataset(val_samples[slice_begin:slice_end])
    else:
        data_val = gluon.data.SimpleDataset(val_samples)
    # Construct the model + loss function
    if args.cfg.endswith('.yml'):
        cfg = TransformerModel.get_cfg().clone_merge(args.cfg)
    else:
        cfg = TransformerModel.get_cfg(args.cfg)
    cfg.defrost()
    cfg.MODEL.src_vocab_size = len(src_vocab)
    cfg.MODEL.tgt_vocab_size = len(tgt_vocab)
    cfg.MODEL.layout = 'TN'
    cfg.freeze()
    model = TransformerModel.from_cfg(cfg)
    model.initialize(mx.init.Xavier(magnitude=args.magnitude), ctx=ctx_l)
    model.hybridize()
    for v in model.collect_params().values():
        if v.grad_req != 'null':
            v.grad_req = 'add'
    # Do not apply weight decay to LayerNorm parameters and bias terms
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    param_dict = deduplicate_param_dict(model.collect_params())

    inference_model = TransformerInference(model=model)
    inference_model.hybridize()
    if local_rank == 0:
        logging.info(model)
    with open(os.path.join(args.save_dir, 'config.yml'), 'w') as cfg_f:
        cfg_f.write(cfg.dump())
    label_smooth_loss = LabelSmoothCrossEntropyLoss(
        num_labels=len(tgt_vocab),
        alpha=args.label_smooth_alpha,
        from_logits=False)
    label_smooth_loss.hybridize()

    # Construct the beam search sampler
    scorer = BeamSearchScorer(alpha=args.lp_alpha,
                              K=args.lp_k,
                              from_logits=False)
    beam_search_sampler = BeamSearchSampler(beam_size=args.beam_size,
                                            decoder=inference_model,
                                            vocab_size=len(tgt_vocab),
                                            eos_id=tgt_vocab.eos_id,
                                            scorer=scorer,
                                            stochastic=False,
                                            max_length_a=args.max_length_a,
                                            max_length_b=args.max_length_b)

    logging.info(beam_search_sampler)
    if args.comm_backend == 'horovod':
        hvd.broadcast_parameters(param_dict, root_rank=0)

    # Construct the trainer
    if args.lr is None:
        base_lr = 2.0 / math.sqrt(args.num_units) / math.sqrt(
            args.warmup_steps)
    else:
        base_lr = args.lr
    lr_scheduler = InverseSquareRootScheduler(
        warmup_steps=args.warmup_steps,
        base_lr=base_lr,
        warmup_init_lr=args.warmup_init_lr)
    optimizer_params = {
        'learning_rate': args.lr,
        'beta1': 0.9,
        'beta2': 0.997,
        'epsilon': 1e-9,
        'lr_scheduler': lr_scheduler,
        'wd': args.wd
    }
    user_provided_optimizer_params = json.loads(args.optimizer_params)
    optimizer_params.update(user_provided_optimizer_params)

    if args.fp16:
        optimizer_params.update({'multi_precision': True})
    if args.comm_backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, args.optimizer,
                                         optimizer_params)
    else:
        trainer = gluon.Trainer(param_dict,
                                args.optimizer,
                                optimizer_params,
                                update_on_kvstore=False)
    # Load Data
    if args.sampler == 'BoundedBudgetSampler':
        train_batch_sampler = BoundedBudgetSampler(
            lengths=[(ele[2], ele[3]) for ele in data_train],
            max_num_tokens=args.max_num_tokens,
            max_num_sentences=args.max_num_sentences,
            shuffle=True,
            seed=args.seed)
    elif args.sampler == 'FixedBucketSampler':
        if args.comm_backend == 'horovod':
            raise NotImplementedError(
                'FixedBucketSampler does not support horovod at present')

        if args.bucket_scheme == 'constant':
            bucket_scheme = ConstWidthBucket()
        elif args.bucket_scheme == 'linear':
            bucket_scheme = LinearWidthBucket()
        elif args.bucket_scheme == 'exp':
            bucket_scheme = ExpWidthBucket(bucket_len_step=1.2)
        else:
            raise NotImplementedError
        # TODO(sxjscience) Support auto-bucket-size tuning
        train_batch_sampler = FixedBucketSampler(lengths=[
            (ele[2], ele[3]) for ele in data_train
        ],
                                                 batch_size=args.batch_size,
                                                 num_buckets=args.num_buckets,
                                                 ratio=args.bucket_ratio,
                                                 shuffle=True,
                                                 use_average_length=True,
                                                 bucket_scheme=bucket_scheme,
                                                 seed=args.seed)
    else:
        raise NotImplementedError

    num_updates_per_epoch = int(
        math.ceil(
            len(train_batch_sampler) /
            (num_parts * len(ctx_l) * args.num_accumulated)))
    # Convert the batch sampler to multiple shards
    if num_parts > 1:
        train_batch_sampler = ShardedIterator(train_batch_sampler,
                                              num_parts=num_parts,
                                              part_index=rank,
                                              even_size=True,
                                              seed=args.seed + 1000 * rank)

    logging.info(train_batch_sampler)

    batchify_fn = bf.Tuple(bf.Pad(), bf.Pad(), bf.Stack(), bf.Stack(),
                           bf.Stack())
    train_data_loader = gluon.data.DataLoader(
        data_train,
        batch_sampler=train_batch_sampler,
        batchify_fn=batchify_fn,
        num_workers=0)
    val_data_loader = gluon.data.DataLoader(data_val,
                                            batch_size=args.val_batch_size,
                                            batchify_fn=batchify_fn,
                                            num_workers=0,
                                            shuffle=False)
    params = [p for p in param_dict.values() if p.grad_req != 'null']
    model_averager = AverageSGDTracker(param_dict)
    log_start_time = time.time()
    num_params, num_fixed_params = None, None

    # TODO(sxjscience) Add a log metric class
    log_avg_loss_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l]
    # Maintain the denominator of the loss.
    log_avg_loss_denom_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l]
    log_wc_l = [mx.np.array(0, dtype=np.int64, ctx=ctx) for ctx in ctx_l]
    log_tgt_wc_l = [mx.np.array(0, dtype=np.int64, ctx=ctx) for ctx in ctx_l]
    log_avg_grad_norm = 0
    log_iter_num = 0

    if local_rank == 0:
        writer = SummaryWriter(
            logdir=os.path.join(args.save_dir, 'tensorboard'))
    if use_amp:
        amp.init_trainer(trainer)
    train_multi_data_loader = grouper(repeat(train_data_loader), len(ctx_l))
    # when args.epochs < 0, the model will keep training
    if args.epochs < 0:
        if args.max_update > 0:
            total_train_iters = args.max_update
            if args.num_averages > 0:
                assert args.num_averages <= total_train_iters // args.save_interval_update
                avg_start_iter = (
                    total_train_iters // args.save_interval_update -
                    args.num_averages) * args.save_interval_update
            else:
                avg_start_iter = -1
        else:
            total_train_iters = np.inf
            avg_start_iter = -1
    else:
        total_train_iters = args.epochs * num_updates_per_epoch
        if args.num_averages > 0:
            assert args.num_averages <= args.epochs
            avg_start_iter = (args.epochs -
                              args.num_averages) * num_updates_per_epoch
        else:
            avg_start_iter = -1

    # Here, we are manually setting up the scale to 1.0 because
    # in horovod, the scale can be the number of workers:
    # See the code here: https://github.com/horovod/horovod/blob/125115583b7029196e2ec530decd4209459d5479/horovod/mxnet/__init__.py#L141
    # Since we will need to use the dynamic scaling in amp, we will manually call amp.unscale().
    # A scale that is larger than 1.0 can be problematic in this case.
    trainer._scale = 1.0
    if args.max_num_tokens > 0:
        const_scale = args.max_num_tokens
    else:
        const_scale = 100

    train_start_time = time.time()

    for train_iter in range(total_train_iters):
        model.zero_grad()
        loss_denom_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l]
        for i in range(args.num_accumulated):
            loss_l = []
            sample_data_l = next(train_multi_data_loader)
            for j, (sample_data, ctx) in enumerate(zip(sample_data_l, ctx_l)):
                src_token_ids, tgt_token_ids, src_valid_length,\
                tgt_valid_length, sample_ids = sample_data
                src_token_ids = src_token_ids.as_in_ctx(ctx)
                tgt_token_ids = tgt_token_ids.as_in_ctx(ctx)
                src_valid_length = src_valid_length.as_in_ctx(ctx)
                tgt_valid_length = tgt_valid_length.as_in_ctx(ctx)
                src_wc, tgt_wc, bs = src_valid_length.sum(), \
                                     tgt_valid_length.sum(), src_token_ids.shape[0]
                log_wc_l[j] += src_wc + tgt_wc
                log_tgt_wc_l[j] += tgt_wc
                token_count = (tgt_valid_length - 1).sum()
                loss_denom_l[j] += token_count / const_scale
                log_avg_loss_denom_l[j] += token_count / const_scale
                with mx.autograd.record():
                    if model.layout == 'NT':
                        tgt_pred = model(src_token_ids, src_valid_length,
                                         tgt_token_ids[:, :-1],
                                         tgt_valid_length - 1)
                        tgt_labels = tgt_token_ids[:, 1:]
                        loss = label_smooth_loss(tgt_pred, tgt_labels)
                        loss = mx.npx.sequence_mask(
                            loss,
                            sequence_length=tgt_valid_length - 1,
                            use_sequence_length=True,
                            axis=1)
                        loss = loss.sum() / const_scale
                        loss_l.append(loss)
                    elif model.layout == 'TN':
                        tgt_pred = model(src_token_ids.T, src_valid_length,
                                         tgt_token_ids.T[:-1, :],
                                         tgt_valid_length - 1)
                        tgt_labels = tgt_token_ids.T[1:, :]
                        loss = label_smooth_loss(tgt_pred, tgt_labels)
                        loss = mx.npx.sequence_mask(
                            loss,
                            sequence_length=tgt_valid_length - 1,
                            use_sequence_length=True,
                            axis=0)
                        loss = loss.sum() / const_scale
                        loss_l.append(loss)
                log_avg_loss_l[j] += loss
            if use_amp:
                with mx.autograd.record():
                    with amp.scale_loss(loss_l, trainer) as amp_loss_l:
                        for loss in amp_loss_l:
                            loss.backward()
            else:
                with mx.autograd.record():
                    for loss in loss_l:
                        loss.backward()

        # Print the total number of parameters
        if local_rank == 0 and num_params is None:
            num_params, num_fixed_params = count_parameters(param_dict)
            logging.info(
                'Total Number of Parameters (not-fixed/fixed): {}/{}'.format(
                    num_params, num_fixed_params))
        # All-Reduce the gradient
        trainer.allreduce_grads()
        if args.comm_backend == 'horovod':
            # All-Reduce the loss denominator
            assert len(loss_denom_l) == 1
            loss_denom = hvd.allreduce(loss_denom_l[0],
                                       average=False).asnumpy()
        else:
            loss_denom = sum([ele.asnumpy() for ele in loss_denom_l])
        if use_amp:
            # We need to first unscale the gradient and then perform allreduce.
            grad_scale = trainer.amp_loss_scale * loss_denom
        else:
            grad_scale = loss_denom
        if args.max_grad_norm is not None:
            total_norm, ratio, is_finite\
                = clip_grad_global_norm(params, args.max_grad_norm * grad_scale)
            total_norm = total_norm / grad_scale
        else:
            total_norm = grad_global_norm(params)
            total_norm = total_norm / grad_scale
        log_avg_grad_norm += total_norm
        log_iter_num += 1

        trainer.update(loss_denom, ignore_stale_grad=True)

        if avg_start_iter > 0 and train_iter >= avg_start_iter:
            model_averager.step()

        if ((train_iter + 1) % args.log_interval == 0
                or train_iter + 1 == total_train_iters):
            if args.comm_backend == 'horovod':
                # Use allreduce to get the total number of tokens and loss
                log_wc = hvd.allreduce(log_wc_l[0], average=False).asnumpy()
                log_tgt_wc = hvd.allreduce(log_tgt_wc_l[0],
                                           average=False).asnumpy()
                log_avg_loss = hvd.allreduce(log_avg_loss_l[0] /
                                             log_avg_loss_denom_l[0],
                                             average=True)
                log_avg_loss = log_avg_loss.asnumpy()
            else:
                log_wc = sum([ele.asnumpy() for ele in log_wc_l])
                log_tgt_wc = sum([ele.asnumpy() for ele in log_tgt_wc_l])
                log_avg_loss =\
                    sum([log_avg_loss_l[i].asnumpy() / log_avg_loss_denom_l[i].asnumpy()
                         for i in range(len(log_avg_loss_l))]) / len(log_avg_loss_l)
            log_avg_grad_norm = log_avg_grad_norm / log_iter_num
            log_end_time = time.time()
            wps = log_wc / (log_end_time - log_start_time)
            epoch_id = train_iter // num_updates_per_epoch
            logging.info(
                '[Epoch {} Iter {}/{}, Overall {}/{}] loss={:.4f}, ppl={:.4f}, '
                'throughput={:.2f}K wps, total wc={:.2f}K, wpb={:.2f}K,'
                ' LR={}, gnorm={:.4f}, ETA={:.2f}h'.format(
                    epoch_id, train_iter % num_updates_per_epoch + 1,
                    num_updates_per_epoch,
                    train_iter + 1, total_train_iters, log_avg_loss,
                    np.exp(log_avg_loss), wps / 1000, log_wc / 1000,
                    log_tgt_wc / 1000 / log_iter_num, trainer.learning_rate,
                    log_avg_grad_norm,
                    (log_end_time - train_start_time) / (train_iter + 1) *
                    (total_train_iters - train_iter - 1) / 3600))
            if local_rank == 0:
                writer.add_scalar('throughput_wps', wps, train_iter)
                writer.add_scalar('train_loss', log_avg_loss, train_iter)
                writer.add_scalar('lr', trainer.learning_rate, train_iter)
                writer.add_scalar('grad_norm', log_avg_grad_norm, train_iter)
            # Reinitialize the log variables
            log_start_time = time.time()
            log_avg_loss_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l]
            log_avg_loss_denom_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l]
            log_avg_grad_norm = 0
            log_iter_num = 0
            log_wc_l = [
                mx.np.array(0, dtype=np.int64, ctx=ctx) for ctx in ctx_l
            ]
            log_tgt_wc_l = [
                mx.np.array(0, dtype=np.int64, ctx=ctx) for ctx in ctx_l
            ]

        if (args.max_update > 0 and (train_iter + 1) % args.save_interval_update == 0) \
            or ((train_iter + 1) % num_updates_per_epoch == 0) \
            or train_iter + 1 == total_train_iters:
            epoch_id = (train_iter + 1) // num_updates_per_epoch
            if local_rank == 0:
                if args.max_update <= 0:
                    model.save_parameters(os.path.join(
                        args.save_dir, 'epoch{}.params'.format(epoch_id)),
                                          deduplicate=True)
                else:
                    model.save_parameters(os.path.join(
                        args.save_dir, 'iter{}.params'.format(train_iter + 1)),
                                          deduplicate=True)

            avg_val_loss, ntokens, pred_sentences, pred_lengths, sentence_ids\
                = validation(model, val_data_loader, inference_model, beam_search_sampler,
                             tgt_tokenizer, ctx_l)
            if args.comm_backend == 'horovod':
                flatten_pred_sentences = np.concatenate(pred_sentences, axis=0)
                all_val_loss = hvd.allgather(
                    mx.np.array([avg_val_loss * ntokens],
                                dtype=np.float32,
                                ctx=ctx_l[0]))
                all_ntokens = hvd.allgather(
                    mx.np.array([ntokens], dtype=np.int64, ctx=ctx_l[0]))
                flatten_pred_sentences = hvd.allgather(
                    mx.np.array(flatten_pred_sentences,
                                dtype=np.int32,
                                ctx=ctx_l[0]))
                pred_lengths = hvd.allgather(
                    mx.np.array(pred_lengths, dtype=np.int64, ctx=ctx_l[0]))
                sentence_ids = hvd.allgather(
                    mx.np.array(sentence_ids, dtype=np.int64, ctx=ctx_l[0]))
                avg_val_loss = all_val_loss.asnumpy().sum(
                ) / all_ntokens.asnumpy().sum()
                flatten_pred_sentences = flatten_pred_sentences.asnumpy()
                pred_lengths = pred_lengths.asnumpy()
                sentence_ids = sentence_ids.asnumpy()
                pred_sentences = [None for _ in range(len(sentence_ids))]
                ptr = 0
                assert sentence_ids.min() == 0 and sentence_ids.max(
                ) == len(sentence_ids) - 1
                for sentence_id, length in zip(sentence_ids, pred_lengths):
                    pred_sentences[sentence_id] = flatten_pred_sentences[ptr:(
                        ptr + length)]
                    ptr += length
            if local_rank == 0:
                # Perform detokenization
                pred_sentences_bpe_decode = []
                pred_sentences_raw = []
                for sentence in pred_sentences:
                    bpe_decode_sentence = tgt_tokenizer.decode(
                        sentence.tolist())
                    raw_sentence = base_tgt_tokenizer.decode(
                        bpe_decode_sentence.split())
                    pred_sentences_bpe_decode.append(bpe_decode_sentence)
                    pred_sentences_raw.append(raw_sentence)
                detok_sacrebleu_out = sacrebleu.corpus_bleu(
                    sys_stream=pred_sentences_bpe_decode,
                    ref_streams=[tgt_detok_sentences])
                raw_sacrebleu_out = sacrebleu.corpus_bleu(
                    sys_stream=pred_sentences_raw,
                    ref_streams=[tgt_raw_sentences])
                with open(
                        os.path.join(args.save_dir,
                                     f'epoch{epoch_id}_dev_prediction.txt'),
                        'w') as of:
                    for line in pred_sentences_raw:
                        of.write(line + '\n')
                logging.info(
                    '[Epoch {}][Iter {}/{}] validation loss/ppl={:.4f}/{:.4f}, '
                    'SacreBLEU={}, Detok SacreBLEU={}'.format(
                        epoch_id, train_iter, total_train_iters, avg_val_loss,
                        np.exp(avg_val_loss), raw_sacrebleu_out.score,
                        detok_sacrebleu_out.score))
                writer.add_scalar('valid_loss', avg_val_loss, train_iter)
                writer.add_scalar('valid_bleu', raw_sacrebleu_out.score,
                                  train_iter)

    if args.num_averages > 0:
        model_averager.copy_back(
            param_dict)  # TODO(sxjscience) Rewrite using update
        model.save_parameters(os.path.join(args.save_dir, 'average.params'),
                              deduplicate=True)
Beispiel #10
0
def train_gluon():
    if args.save_dir:
        save_dir = args.save_dir
        save_dir = os.path.expanduser(save_dir)
        makedirs(save_dir)
    else:
        save_dir = './'
        save_frequency = 0

    def evaluate(epoch):
        acc_top1 = mx.metric.Accuracy()
        acc_top5 = mx.metric.TopKAccuracy(5)
        for _, batch in enumerate(val_data):
            data, label = val_batch_fn(batch, context)
            output = net(data.astype(args.dtype, copy=False))
            acc_top1.update([label], [output])
            acc_top5.update([label], [output])

        top1_name, top1_acc = acc_top1.get()
        top5_name, top5_acc = acc_top5.get()
        if MPI is not None:
            comm = MPI.COMM_WORLD
            res1 = comm.gather(top1_acc, root=0)
            res2 = comm.gather(top5_acc, root=0)
        if rank == 0:
            if MPI is not None:
                #logging.info('MPI gather res1: {}'.format(res1))
                top1_acc = sum(res1) / len(res1)
                top5_acc = sum(res2) / len(res2)
            logging.info(
                'Epoch[%d] Rank[%d]\tValidation-%s=%f\tValidation-%s=%f',
                epoch, rank, top1_name, top1_acc, top5_name, top5_acc)

    # Hybridize and initialize model
    net.hybridize()
    #net.initialize(initializer, ctx=context)
    if args.resume_params != '':
        net.load_parameters(args.resume_params, ctx=context)
    else:
        net.initialize(initializer, ctx=context)

    if args.no_wd:
        for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
            v.wd_mult = 0.0

    # Horovod: fetch and broadcast parameters
    params = net.collect_params()
    if params is not None:
        hvd.broadcast_parameters(params, root_rank=0)

    # Create optimizer
    optimizer = 'nag'
    optimizer_params = {
        'wd': args.wd,
        'momentum': args.momentum,
        'lr_scheduler': lr_sched
    }
    if args.dtype == 'float16':
        optimizer_params['multi_precision'] = True
    opt = mx.optimizer.create(optimizer, **optimizer_params)

    # Horovod: create DistributedTrainer, a subclass of gluon.Trainer
    trainer = hvd.DistributedTrainer(params, opt)
    if args.resume_states != '':
        trainer.load_states(args.resume_states)

    # Create loss function and train metric
    if args.label_smoothing or args.mixup:
        sparse_label_loss = False
    else:
        sparse_label_loss = True

    distillation = args.teacher is not None and args.hard_weight < 1.0
    if distillation:
        teacher = get_model(args.teacher,
                            pretrained=True,
                            classes=num_classes,
                            ctx=context)
        teacher.hybridize()
        teacher.cast(args.dtype)
        loss_fn = gcv.loss.DistillationSoftmaxCrossEntropyLoss(
            temperature=args.temperature,
            hard_weight=args.hard_weight,
            sparse_label=sparse_label_loss)
        if rank == 0:
            logging.info('Using Distillation')
    else:
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss(
            sparse_label=sparse_label_loss)
    if args.mixup:
        train_metric = mx.metric.RMSE()
    else:
        train_metric = mx.metric.Accuracy()

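    # mixup_transform builds soft labels as a lam-weighted mix of the one-hot labels
    # of the batch and of the reversed batch; eta additionally applies label smoothing.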
    def mixup_transform(label, classes, lam=1, eta=0.0):
        if isinstance(label, mx.nd.NDArray):
            label = [label]
        res = []
        for l in label:
            y1 = l.one_hot(classes,
                           on_value=1 - eta + eta / classes,
                           off_value=eta / classes)
            y2 = l[::-1].one_hot(classes,
                                 on_value=1 - eta + eta / classes,
                                 off_value=eta / classes)
            res.append(lam * y1 + (1 - lam) * y2)
        return res

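    # smooth applies standard label smoothing: the true class keeps 1 - eta + eta/classes,
    # every other class receives eta/classes.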
    def smooth(label, classes, eta=0.1):
        if isinstance(label, mx.nd.NDArray):
            label = [label]
        smoothed = []
        for l in label:
            res = l.one_hot(classes,
                            on_value=1 - eta + eta / classes,
                            off_value=eta / classes)
            smoothed.append(res)
        return smoothed

    # Train model
    for epoch in range(args.resume_epoch, args.num_epochs):
        drop_scheduler(epoch)
        tic = time.time()
        train_metric.reset()

        btic = time.time()
        for nbatch, batch in enumerate(train_data, start=1):
            data, label = train_batch_fn(batch, context)
            data, label = [data], [label]
            if args.mixup:
                lam = np.random.beta(args.mixup_alpha, args.mixup_alpha)
                if epoch >= args.num_epochs - args.mixup_off_epoch:
                    lam = 1
                data = [lam * X + (1 - lam) * X[::-1] for X in data]

                if args.label_smoothing:
                    eta = 0.1
                else:
                    eta = 0.0
                label = mixup_transform(label, num_classes, lam, eta)

            elif args.label_smoothing:
                hard_label = label
                label = smooth(label, num_classes)

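            # With distillation enabled, precompute the temperature-scaled teacher
            # probabilities outside of autograd.record().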
            if distillation:
                teacher_prob = [mx.nd.softmax(teacher(X.astype(args.dtype, copy=False)) / args.temperature) \
                                for X in data]

            with autograd.record():
                outputs = [net(X.astype(args.dtype, copy=False)) for X in data]
                if distillation:
                    loss = [
                        loss_fn(yhat.astype('float32', copy=False),
                                y.astype('float32', copy=False),
                                p.astype('float32', copy=False))
                        for yhat, y, p in zip(outputs, label, teacher_prob)
                    ]
                else:
                    loss = [
                        loss_fn(yhat, y.astype(args.dtype, copy=False))
                        for yhat, y in zip(outputs, label)
                    ]
            for l in loss:
                l.backward()
            trainer.step(batch_size)

            if args.mixup:
                output_softmax = [mx.nd.SoftmaxActivation(out.astype('float32', copy=False)) \
                                  for out in outputs]
                train_metric.update(label, output_softmax)
            else:
                if args.label_smoothing:
                    train_metric.update(hard_label, outputs)
                else:
                    train_metric.update(label, outputs)

            if args.log_interval and nbatch % args.log_interval == 0:
                if rank == 0:
                    logging.info('Epoch[%d] Batch[%d] Loss[%.3f]', epoch,
                                 nbatch, loss[0].mean().asnumpy()[0])

                    train_metric_name, train_metric_score = train_metric.get()
                    logging.info('Epoch[%d] Rank[%d] Batch[%d]\t%s=%f\tlr=%f',
                                 epoch, rank, nbatch, train_metric_name,
                                 train_metric_score, trainer.learning_rate)
                    #batch_speed = num_workers * batch_size * args.log_interval / (time.time() - btic)
                    #logging.info('Epoch[%d] Batch[%d]\tSpeed: %.2f samples/sec',
                    #             epoch, nbatch, batch_speed)
                btic = time.time()

        # Report metrics
        elapsed = time.time() - tic
        _, acc = train_metric.get()
        if rank == 0:
            logging.info(
                'Epoch[%d] Rank[%d] Batch[%d]\tTime cost=%.2f\tTrain-metric=%f',
                epoch, rank, nbatch, elapsed, acc)
            epoch_speed = num_workers * batch_size * nbatch / elapsed
            logging.info('Epoch[%d]\tSpeed: %.2f samples/sec', epoch,
                         epoch_speed)

        # Evaluate performance
        if args.eval_frequency and (epoch + 1) % args.eval_frequency == 0:
            evaluate(epoch)

        # Save model
        if args.save_frequency and (epoch + 1) % args.save_frequency == 0:
            net.save_parameters('%s/imagenet-%s-%d.params' %
                                (save_dir, args.model, epoch))
            trainer.save_states('%s/imagenet-%s-%d.states' %
                                (save_dir, args.model, epoch))

    # Evaluate performance at the end of training
    evaluate(epoch)

    net.save_parameters('%s/imagenet-%s-%d.params' %
                        (save_dir, args.model, args.num_epochs - 1))
    trainer.save_states('%s/imagenet-%s-%d.states' %
                        (save_dir, args.model, args.num_epochs - 1))
Beispiel #11
0
def model_fit(args, net, train_data, eval_metric, optimizer, optimizer_params,
              lr_scheduler, eval_data, global_metrics, kvstore, kv,
              begin_epoch, num_epoch, run_epoch, model_prefix):
    if not isinstance(eval_metric, mx.metric.EvalMetric):
        eval_metric = mx.metric.create(eval_metric)
    loss_metric = ScalarMetric()

    if 'horovod' in kvstore:
        trainer = hvd.DistributedTrainer(net.collect_params(), optimizer,
                                         optimizer_params)
    else:
        trainer = gluon.Trainer(net.collect_params(),
                                optimizer,
                                optimizer_params,
                                kvstore=kv,
                                update_on_kvstore=False)

    if args.amp:
        amp.init_trainer(trainer)

    sparse_label_loss = (args.label_smoothing == 0 and args.mixup == 0)
    loss = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=sparse_label_loss)
    loss.hybridize(static_shape=True, static_alloc=True)

    local_batch_size = train_data.batch_size
    total_batch_size = local_batch_size * train_data._num_gpus * (
        hvd.size() if 'horovod' in kvstore else 1)
    durations = []

    epoch_size = get_epoch_size(args, kv)
    run_epoch = num_epoch if (run_epoch == -1) else (begin_epoch + run_epoch)

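    # transform_data optionally applies mixup (Beta-distributed per-sample mixing of the
    # images and their smoothed labels) and/or plain label smoothing to a batch.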
    def transform_data(images, labels):
        if args.mixup != 0:
            coeffs = mx.nd.array(
                np.random.beta(args.mixup, args.mixup,
                               size=images.shape[0])).as_in_context(
                                   images.context)
            image_coeffs = coeffs.astype(images.dtype, copy=False).reshape(
                *coeffs.shape, 1, 1, 1)
            ret_images = image_coeffs * images + (1 -
                                                  image_coeffs) * images[::-1]

            ret_labels = label_smoothing(labels, args.num_classes,
                                         args.label_smoothing)
            label_coeffs = coeffs.reshape(*coeffs.shape, 1)
            ret_labels = label_coeffs * ret_labels + (
                1 - label_coeffs) * ret_labels[::-1]
        else:
            ret_images = images
            if not sparse_label_loss:
                ret_labels = label_smoothing(labels, args.num_classes,
                                             args.label_smoothing)
            else:
                ret_labels = labels

        return ret_images, ret_labels

    i = -1
    best_accuracy = -1
    for epoch in range(begin_epoch, min(run_epoch, num_epoch)):
        tic = time.time()
        btic = time.time()
        etic = time.time()

        train_data.reset()
        eval_metric.reset()
        loss_metric.reset()

        logging.info('Starting epoch {}'.format(epoch))
        outputs = []
        for i, batches in enumerate(train_data):
            # synchronize to previous iteration
            #for o in outputs:
            #    o.wait_to_read()

            trainer.set_learning_rate(lr_scheduler(epoch + i / epoch_size))

            data = [b.data[0] for b in batches]
            label = [
                b.label[0].as_in_context(b.data[0].context) for b in batches
            ]
            orig_label = label

            data, label = zip(*starmap(transform_data, zip(data, label)))

            outputs = []
            Ls = []
            with ag.record():
                for x, y in zip(data, label):
                    z = net(x)
                    L = loss(z, y)
                    # store the loss and do backward after we have done forward
                    # on all GPUs for better speed on multiple GPUs.
                    Ls.append(L)
                    outputs.append(z)

                if args.amp:
                    with amp.scale_loss(Ls, trainer) as scaled_loss:
                        ag.backward(scaled_loss)
                else:
                    ag.backward(Ls)

            if 'horovod' in kvstore:
                trainer.step(local_batch_size)
            else:
                trainer.step(total_batch_size)

            loss_metric.update(..., np.mean([l.asnumpy() for l in Ls]).item())

            if args.disp_batches and not (i + 1) % args.disp_batches:
                dllogger_it_data = {
                    'train.loss':
                    loss_metric.get()[1],
                    'train.ips':
                    args.disp_batches * total_batch_size /
                    (time.time() - btic),
                    'train.lr':
                    trainer.learning_rate
                }
                dllogger.log((epoch, i), data=dllogger_it_data)

                loss_metric.reset_local()
                btic = time.time()

            durations.append(time.time() - tic)
            tic = time.time()

        durations = durations[min(len(durations) // 10, 100):]
        dllogger_epoch_data = {
            'train.loss': loss_metric.get_global()[1],
            'train.ips': total_batch_size / np.mean(durations)
        }
        if args.mode == 'train_val':
            logging.info('Validating epoch {}'.format(epoch))
            score, duration_stats, _ = model_score(args, net, eval_data,
                                                   eval_metric, kvstore)

            dllogger_epoch_data.update(
                starmap(lambda key, val: ('val.{}'.format(key), val),
                        zip(*score)))
            dllogger_epoch_data.update(
                starmap(lambda key, val: ('val.{}'.format(key), val),
                        duration_stats.items()))

            score = dict(zip(*score))
            accuracy = score.get('accuracy', -1)
            save_checkpoint(net, epoch, accuracy, best_accuracy, model_prefix,
                            args.save_frequency, kvstore)
            best_accuracy = max(best_accuracy, accuracy)
        global_metrics.update_dict(dllogger_epoch_data)
        dllogger.log(step=(epoch, ), data=dllogger_epoch_data)
Beispiel #12
0
finetune_net = get_model(model_name, pretrained=True)
with finetune_net.name_scope():
    finetune_net.output = nn.Dense(classes)
finetune_net.output.initialize(init.Xavier(), ctx=ctx)
finetune_net.collect_params().reset_ctx(ctx)
finetune_net.hybridize()

optimizer_parameters = {'learning_rate': lr, 'wd': wd}
optimizer = mx.optimizer.NAG(momentum=momentum, **optimizer_parameters)

model_parameters = finetune_net.collect_params()
if model_parameters is not None:
    hvd.broadcast_parameters(model_parameters, root_rank=0)

trainer = hvd.DistributedTrainer(model_parameters, optimizer)

metric = mx.metric.Accuracy()
L = gluon.loss.SoftmaxCrossEntropyLoss()


# we define an evaluation function for validation and testing
def test(net, val_data, ctx):
    metric = mx.metric.Accuracy()
    for i, batch in enumerate(val_data):
        data = gluon.utils.split_and_load(batch[0],
                                          ctx_list=ctx,
                                          batch_axis=0,
                                          even_split=False)
        label = gluon.utils.split_and_load(batch[1],
                                           ctx_list=ctx,
                                           batch_axis=0,
                                           even_split=False)
Beispiel #13
0
def train(args):
    store, num_workers, rank, local_rank, is_master_node, ctx_l = init_comm(
        args.comm_backend, args.gpus)
    task = get_task(args.task_name)
    #setup_logging(args, local_rank)
    level = logging.INFO
    detail_dir = os.path.join(args.output_dir, args.task_name)
    if not os.path.exists(detail_dir):
        os.mkdir(detail_dir)
    logging_config(
        detail_dir,
        name='train_{}_{}_'.format(args.task_name, args.model_name) +
        str(rank),  # avoid race
        level=level,
        console=(local_rank == 0))
    logging.info(args)
    cfg, tokenizer, classify_net, use_segmentation = \
        get_network(args.model_name, ctx_l,
                    args.param_checkpoint,
                    args.backbone_path,
                    task)
    logging.info('Prepare training data')
    train_data, _ = get_task_data(args, tokenizer, segment='train', task=task)
    train_batchify = bf.Group(bf.Group(bf.Pad(), bf.Pad(), bf.Stack()),
                              bf.Stack())

    epoch_num_updates = len(train_data) // args.batch_size
    max_update = epoch_num_updates * args.epochs
    warmup_steps = int(np.ceil(max_update * args.warmup_ratio))

    dataloader = DataLoader(train_data,
                            batch_size=args.batch_size,
                            batchify_fn=train_batchify,
                            num_workers=4,
                            shuffle=True)
    dataloader = grouper(repeat(dataloader), len(ctx_l))

    param_dict = classify_net.collect_params()
    # Do not apply weight decay to all the LayerNorm and bias
    for _, v in classify_net.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    # Set grad_req if gradient accumulation is required
    params = [p for p in param_dict.values() if p.grad_req != 'null']
    num_accumulated = args.num_accumulated
    if num_accumulated > 1:
        logging.info(
            'Using gradient accumulation. Effective global batch size = {}'.
            format(num_accumulated * args.batch_size * len(ctx_l) *
                   num_workers))
        for p in params:
            p.grad_req = 'add'

    if args.comm_backend == 'horovod':
        # Horovod: fetch and broadcast parameters
        hvd.broadcast_parameters(param_dict, root_rank=0)

    lr_scheduler = PolyScheduler(max_update=max_update,
                                 base_lr=args.lr,
                                 warmup_begin_lr=0.0,
                                 pwr=1,
                                 final_lr=0.0,
                                 warmup_steps=warmup_steps,
                                 warmup_mode='linear')
    optimizer_params = {
        'learning_rate': args.lr,
        'wd': args.wd,
        'lr_scheduler': lr_scheduler
    }
    if args.comm_backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, args.optimizer,
                                         optimizer_params)
    else:
        trainer = mx.gluon.Trainer(classify_net.collect_params(), 'adamw',
                                   optimizer_params)

    if args.task_name == 'sts':
        loss_function = gluon.loss.L2Loss()
    else:
        loss_function = gluon.loss.SoftmaxCELoss()

    # prepare logging state
    log_loss = 0
    log_gnorm = 0
    log_step = 0
    if args.log_interval > 0:
        log_interval = args.log_interval
    else:
        log_interval = int(epoch_num_updates * 0.5)

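    # Training loop: every update draws one shard per device, averages the loss across
    # devices, clips the global gradient norm and applies a single optimizer step.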
    for i in range(max_update):
        sample_l = next(dataloader)
        loss_l = []
        for sample, ctx in zip(sample_l, ctx_l):
            (token_ids, token_types, valid_length), label = sample
            # Move to the corresponding context
            token_ids = mx.np.array(token_ids, ctx=ctx)
            token_types = mx.np.array(token_types, ctx=ctx)
            valid_length = mx.np.array(valid_length, ctx=ctx)
            label = mx.np.array(label, ctx=ctx)
            with mx.autograd.record():
                scores = classify_net(token_ids, token_types, valid_length)
                loss = loss_function(scores, label).mean() / len(ctx_l)
                loss_l.append(loss)
        for loss in loss_l:
            loss.backward()
        trainer.allreduce_grads()
        # Begin Norm Clipping
        total_norm, ratio, is_finite = clip_grad_global_norm(
            params, args.max_grad_norm)
        trainer.update(1.0)
        step_loss = sum([loss.asnumpy() for loss in loss_l])
        log_loss += step_loss
        log_gnorm += total_norm
        log_step += 1
        if log_step >= log_interval or i == max_update - 1:
            logging.info(
                '[Iter {} / {}] avg {} = {:.2f}, avg gradient norm = {:.2f}'.
                format(i + 1, max_update, 'nll', log_loss / log_step,
                       log_gnorm / log_step))
            log_loss = 0
            log_gnorm = 0
            log_step = 0
        if local_rank == 0 and (i == max_update - 1 or i %
                                (max_update // args.epochs) == 0 and i > 0):
            ckpt_name = '{}_{}_{}.params'.format(args.model_name,
                                                 args.task_name, (i + 1))

            params_saved = os.path.join(detail_dir, ckpt_name)
            classify_net.save_parameters(params_saved)
            logging.info('Params saved in: {}'.format(params_saved))
Beispiel #14
0
def train(net, train_data, val_data, eval_metric, ctx, args):
    """Training pipeline"""
    net.collect_params().setattr('grad_req', 'null')
    net.collect_train_params().setattr('grad_req', 'write')
    for k, v in net.collect_params('.*beta|.*bias').items():
        v.wd_mult = 0.0

    if args.horovod:
        hvd.broadcast_parameters(net.collect_params(), root_rank=0)
        trainer = hvd.DistributedTrainer(
                        net.collect_train_params(), # fix batchnorm, fix first stage, etc...
                        'sgd',
                        {'learning_rate': args.lr, 'wd': args.wd, 'momentum': args.momentum})
    else:
        trainer = gluon.Trainer(
                    net.collect_train_params(), # fix batchnorm, fix first stage, etc...
                    'sgd',
                    {'learning_rate': args.lr, 'wd': args.wd, 'momentum': args.momentum},
                    update_on_kvstore=(False if args.amp else None))

    if args.amp:
        amp.init_trainer(trainer)

    # lr decay policy
    lr_decay = float(args.lr_decay)
    lr_steps = sorted([float(ls) for ls in args.lr_decay_epoch.split(',') if ls.strip()])
    lr_warmup = float(args.lr_warmup)  # avoid int division

    rpn_cls_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
    rpn_box_loss = mx.gluon.loss.HuberLoss(rho=1 / 9.)  # == smoothl1
    rcnn_cls_loss = mx.gluon.loss.SoftmaxCrossEntropyLoss()
    rcnn_box_loss = mx.gluon.loss.HuberLoss()  # == smoothl1
    rcnn_mask_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
    metrics = [mx.metric.Loss('RPN_Conf'),
               mx.metric.Loss('RPN_SmoothL1'),
               mx.metric.Loss('RCNN_CrossEntropy'),
               mx.metric.Loss('RCNN_SmoothL1'),
               mx.metric.Loss('RCNN_Mask')]

    rpn_acc_metric = RPNAccMetric()
    rpn_bbox_metric = RPNL1LossMetric()
    rcnn_acc_metric = RCNNAccMetric()
    rcnn_bbox_metric = RCNNL1LossMetric()
    rcnn_mask_metric = MaskAccMetric()
    rcnn_fgmask_metric = MaskFGAccMetric()
    metrics2 = [rpn_acc_metric, rpn_bbox_metric,
                rcnn_acc_metric, rcnn_bbox_metric,
                rcnn_mask_metric, rcnn_fgmask_metric]

    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    log_file_path = args.save_prefix + '_train.log'
    log_dir = os.path.dirname(log_file_path)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)
    fh = logging.FileHandler(log_file_path)
    logger.addHandler(fh)
    logger.info(args)
    if args.verbose:
        logger.info('Trainable parameters:')
        logger.info(net.collect_train_params().keys())
    logger.info('Start training from [Epoch {}]'.format(args.start_epoch))
    best_map = [0]
    for epoch in range(args.start_epoch, args.epochs):
        while lr_steps and epoch >= lr_steps[0]:
            new_lr = trainer.learning_rate * lr_decay
            lr_steps.pop(0)
            trainer.set_learning_rate(new_lr)
            logger.info("[Epoch {}] Set learning rate to {}".format(epoch, new_lr))
        for metric in metrics:
            metric.reset()
        tic = time.time()
        btic = time.time()
        if not args.disable_hybridization:
            net.hybridize(static_alloc=args.static_alloc)
        base_lr = trainer.learning_rate
        for i, batch in enumerate(train_data):
            if epoch == 0 and i <= lr_warmup:
                # adjust based on real percentage
                new_lr = base_lr * get_lr_at_iter(i / lr_warmup)
                if new_lr != trainer.learning_rate:
                    if i % args.log_interval == 0:
                        logger.info(
                            '[Epoch 0 Iteration {}] Set learning rate to {}'.format(i, new_lr))
                    trainer.set_learning_rate(new_lr)
            batch = split_and_load(batch, ctx_list=ctx)
            batch_size = len(batch[0])
            losses = []
            metric_losses = [[] for _ in metrics]
            add_losses = [[] for _ in metrics2]
            with autograd.record():
                for data, label, gt_mask, rpn_cls_targets, rpn_box_targets, rpn_box_masks in zip(
                        *batch):
                    gt_label = label[:, :, 4:5]
                    gt_box = label[:, :, :4]
                    cls_pred, box_pred, mask_pred, roi, samples, matches, rpn_score, rpn_box, anchors = net(
                        data, gt_box)
                    # losses of rpn
                    rpn_score = rpn_score.squeeze(axis=-1)
                    num_rpn_pos = (rpn_cls_targets >= 0).sum()
                    rpn_loss1 = rpn_cls_loss(rpn_score, rpn_cls_targets,
                                             rpn_cls_targets >= 0) * rpn_cls_targets.size / num_rpn_pos
                    rpn_loss2 = rpn_box_loss(rpn_box, rpn_box_targets,
                                             rpn_box_masks) * rpn_box.size / num_rpn_pos
                    # rpn overall loss, use sum rather than average
                    rpn_loss = rpn_loss1 + rpn_loss2
                    # generate targets for rcnn
                    cls_targets, box_targets, box_masks = net.target_generator(roi, samples,
                                                                               matches, gt_label,
                                                                               gt_box)
                    # losses of rcnn
                    num_rcnn_pos = (cls_targets >= 0).sum()
                    rcnn_loss1 = rcnn_cls_loss(cls_pred, cls_targets,
                                               cls_targets >= 0) * cls_targets.size / \
                                 cls_targets.shape[0] / num_rcnn_pos
                    rcnn_loss2 = rcnn_box_loss(box_pred, box_targets, box_masks) * box_pred.size / \
                                 box_pred.shape[0] / num_rcnn_pos
                    rcnn_loss = rcnn_loss1 + rcnn_loss2
                    # generate targets for mask
                    mask_targets, mask_masks = net.mask_target(roi, gt_mask, matches, cls_targets)
                    # loss of mask
                    mask_loss = rcnn_mask_loss(mask_pred, mask_targets, mask_masks) * \
                                mask_targets.size / mask_targets.shape[0] / mask_masks.sum()
                    # overall losses
                    losses.append(rpn_loss.sum() + rcnn_loss.sum() + mask_loss.sum())
                    if (not args.horovod or hvd.rank() == 0):
                        metric_losses[0].append(rpn_loss1.sum())
                        metric_losses[1].append(rpn_loss2.sum())
                        metric_losses[2].append(rcnn_loss1.sum())
                        metric_losses[3].append(rcnn_loss2.sum())
                        metric_losses[4].append(mask_loss.sum())
                        add_losses[0].append([[rpn_cls_targets, rpn_cls_targets >= 0], [rpn_score]])
                        add_losses[1].append([[rpn_box_targets, rpn_box_masks], [rpn_box]])
                        add_losses[2].append([[cls_targets], [cls_pred]])
                        add_losses[3].append([[box_targets, box_masks], [box_pred]])
                        add_losses[4].append([[mask_targets, mask_masks], [mask_pred]])
                        add_losses[5].append([[mask_targets, mask_masks], [mask_pred]])
                if args.amp:
                    with amp.scale_loss(losses, trainer) as scaled_losses:
                        autograd.backward(scaled_losses)
                else:
                    autograd.backward(losses)
                if (not args.horovod or hvd.rank() == 0):
                    for metric, record in zip(metrics, metric_losses):
                        metric.update(0, record)
                    for metric, records in zip(metrics2, add_losses):
                        for pred in records:
                            metric.update(pred[0], pred[1])
            trainer.step(batch_size)
            # log training metrics
            if (not args.horovod or hvd.rank() == 0) and args.log_interval and not (i + 1) % args.log_interval:
                msg = ','.join(['{}={:.3f}'.format(*metric.get()) for metric in metrics + metrics2])
                logger.info('[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}'.format(
                    epoch, i, args.log_interval * args.batch_size / (time.time() - btic), msg))
                btic = time.time()
        # validate and save params
        if (not args.horovod or hvd.rank() == 0):
            msg = ','.join(['{}={:.3f}'.format(*metric.get()) for metric in metrics])
            logger.info('[Epoch {}] Training cost: {:.3f}, {}'.format(
                epoch, (time.time() - tic), msg))
            if not (epoch + 1) % args.val_interval:
                # consider reduce the frequency of validation to save time
                map_name, mean_ap = validate(net, val_data, ctx, eval_metric, args)
                val_msg = '\n'.join(['{}={}'.format(k, v) for k, v in zip(map_name, mean_ap)])
                logger.info('[Epoch {}] Validation: \n{}'.format(epoch, val_msg))
                current_map = float(mean_ap[-1])
            else:
                current_map = 0.
            save_params(net, logger, best_map, current_map, epoch, args.save_interval, args.save_prefix)
Beispiel #15
0
def train(net, train_data, val_data, eval_metric, ctx, args):
    net.collect_params().reset_ctx(ctx)
    if args.horovod:
        hvd.broadcast_parameters(net.collect_params(), root_rank=0)
        trainer = hvd.DistributedTrainer(net.collect_params(), 'sgd', {
            'learning_rate': args.lr,
            'wd': args.wd,
            'momentum': args.momentum
        })
    else:
        trainer = gluon.Trainer(
            net.collect_params(),
            'sgd', {
                'learning_rate': args.lr,
                'wd': args.wd,
                'momentum': args.momentum
            },
            update_on_kvstore=(False if args.amp else None))

    if args.amp:
        amp.init_trainer(trainer)

    lr_decay = float(args.lr_decay)
    lr_steps = sorted(
        [float(ls) for ls in args.lr_decay_epoch.split(',') if ls.strip()])

    mbox_loss = gcv.loss.SSDMultiBoxLoss(
    )  # compute loss in entire batch across devices
    ce_metric = mx.metric.Loss('CrossEntropy')  # tracks the classification loss
    smoothl1_metric = mx.metric.Loss('SmoothL1')  # tracks the box-offset loss

    # logger set_up
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    log_file_path = args.save_prefix + '_train.log'
    log_dir = os.path.dirname(log_file_path)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)
    fh = logging.FileHandler(log_file_path)
    logger.addHandler(fh)
    logger.info(args)
    logger.info("Start training from [Epoch {}]".format(args.start_epoch))
    best_map = [0]

    for epoch in range(args.start_epoch, args.epochs):

        # reset (decay) the learning rate at the scheduled epochs
        while lr_steps and epoch >= lr_steps[0]:
            new_lr = trainer.learning_rate * lr_decay
            lr_steps.pop(0)
            trainer.set_learning_rate(new_lr)
            logger.info('[Epoch {}] Set learning rate to {}'.format(
                epoch, new_lr))

        ce_metric.reset()
        smoothl1_metric.reset()
        tic = time.time()  # time spent on one epoch
        btic = time.time()  # time spent on each batch
        net.hybridize(static_alloc=True, static_shape=True)

        for i, batch in enumerate(train_data):
            # get the data, class targets and box targets
            """if args.dali:
                data = [d.data[0] for d in batch]
                box_targets = [d.label[0] for d in batch]
                cls_targets = [nd.cast(d.label[1], dtype='float32') for d in batch]

            else:"""
            data = gluon.utils.split_and_load(batch[0],
                                              ctx_list=ctx,
                                              batch_axis=0)
            cls_targets = gluon.utils.split_and_load(batch[1],
                                                     ctx_list=ctx,
                                                     batch_axis=0)
            box_targets = gluon.utils.split_and_load(batch[2],
                                                     ctx_list=ctx,
                                                     batch_axis=0)
            """
            x, y: y本就包含着该图片之前所有生成的锚框target的box位置信息即box_targets(batch_size, N, 4), N是锚框的个数
            以及每个锚框对应的类别即cls_targets(batch_size, N)

             """
            with autograd.record():
                cls_preds, box_preds, _ = net(data)
                # cls_preds: (batch_size, num_anchors, num_cls + 1)
                sum_loss, cls_loss, box_loss = mbox_loss(
                    cls_preds, box_preds, cls_targets, box_targets)
                # the loss is computed over the entire batch across devices,
                # i.e. divided by the sum of the number of positive targets in the batch
                # what are the shapes of sum_loss, cls_loss, box_loss?
                if args.amp:
                    with amp.scale_loss(sum_loss, trainer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    sum_loss.backward()
            trainer.step(1)
            # since we have already normalized the loss, we don't want to normalize
            # by batch-size anymore

            if (not args.horovod or hvd.rank() == 0):
                local_batch_size = int(args.batch_size //
                                       (hvd.size() if args.horovod else 1))
                ce_metric.update(0, [l * local_batch_size for l in cls_loss])
                smoothl1_metric.update(
                    0, [l * local_batch_size for l in box_loss])
                # ce_metric and smoothl1_metric are multiplied by local_batch_size
                # to get the loss per image:
                # ce_metric.get() and smoothl1_metric.get() divide by batch_size internally,
                # so multiply by batch_size first, otherwise the value becomes
                # loss / num_anchors / (batch_size * batch_size)

                if args.log_interval and not (i + 1) % args.log_interval:
                    # log once every args.log_interval batches
                    name1, loss1 = ce_metric.get()
                    name2, loss2 = smoothl1_metric.get()
                    logger.info(
                        '[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}={:.3f}, {}={:.3f}'
                        .format(epoch, i,
                                args.batch_size / (time.time() - btic), name1,
                                loss1, name2, loss2))
                btic = time.time()

        if (not args.horovod) or hvd.rank() == 0:
            name1, loss1 = ce_metric.get()
            name2, loss2 = smoothl1_metric.get()
            logger.info(
                '[Epoch {}] Training cost: {:.3f}, {}={:.3f}, {}={:.3f}'.
                format(epoch, (time.time() - tic), name1, loss1, name2, loss2))
            if (epoch % args.val_interval
                    == 0) or (args.save_interval
                              and epoch % args.save_interval == 0):
                # every args.val_interval or args.save_interval epochs,
                # evaluate on the validation set to obtain current_map
                map_name, mean_ap = validate(net, val_data, ctx, eval_metric)
                val_msg = "\n".join('{}={}'.format(k, v)
                                    for k, v in zip(map_name, mean_ap))
                logger.info('[Epoch {}] Validation: \n{}'.format(
                    epoch, val_msg))
                current_map = float(mean_ap[-1])  # the last entry of mean_ap is the mAP
            else:
                current_map = 0
            save_params(net, best_map, current_map, epoch, args.save_interval,
                        args.save_prefix)
Beispiel #16
0
def main(args):
    # Function to get mnist iterator given a rank
    def get_mnist_iterator(rank):
        data_dir = "data-%d" % rank
        if not os.path.isdir(data_dir):
            os.makedirs(data_dir)
        zip_file_path = download('http://data.mxnet.io/mxnet/data/mnist.zip',
                                 dirname=data_dir)
        with zipfile.ZipFile(zip_file_path) as zf:
            zf.extractall(data_dir)

        input_shape = (1, 28, 28)
        batch_size = args.batch_size

        train_iter = mx.io.MNISTIter(
            image="%s/train-images-idx3-ubyte" % data_dir,
            label="%s/train-labels-idx1-ubyte" % data_dir,
            input_shape=input_shape,
            batch_size=batch_size,
            shuffle=True,
            flat=False,
            num_parts=hvd.size(),
            part_index=hvd.rank())

        val_iter = mx.io.MNISTIter(
            image="%s/t10k-images-idx3-ubyte" % data_dir,
            label="%s/t10k-labels-idx1-ubyte" % data_dir,
            input_shape=input_shape,
            batch_size=batch_size,
            flat=False,
        )

        return train_iter, val_iter

    # Function to define neural network
    def conv_nets():
        net = gluon.nn.HybridSequential()
        net.add(gluon.nn.Conv2D(channels=20, kernel_size=5, activation='relu'))
        net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))
        net.add(gluon.nn.Conv2D(channels=50, kernel_size=5, activation='relu'))
        net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))
        net.add(gluon.nn.Flatten())
        net.add(gluon.nn.Dense(512, activation="relu"))
        net.add(gluon.nn.Dense(10))
        return net

    # Function to evaluate accuracy for a model
    def evaluate(model, data_iter, context):
        data_iter.reset()
        metric = mx.gluon.metric.Accuracy()
        for _, batch in enumerate(data_iter):
            data = batch.data[0].as_in_context(context)
            label = batch.label[0].as_in_context(context)
            output = model(data.astype(args.dtype, copy=False))
            metric.update([label], [output])

        return metric.get()

    # Initialize Horovod
    hvd.init()

    # Horovod: pin context to local rank
    context = mx.cpu(hvd.local_rank()) if args.no_cuda else mx.gpu(
        hvd.local_rank())
    num_workers = hvd.size()

    # Load training and validation data
    train_data, val_data = get_mnist_iterator(hvd.rank())

    # Build model
    model = conv_nets()
    model.cast(args.dtype)
    model.hybridize()

    # Create optimizer
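    # Horovod: the base learning rate is scaled linearly by the number of workers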
    optimizer_params = {
        'momentum': args.momentum,
        'learning_rate': args.lr * hvd.size()
    }
    opt = mx.optimizer.create('sgd', **optimizer_params)

    # Initialize parameters
    initializer = mx.init.Xavier(rnd_type='gaussian',
                                 factor_type="in",
                                 magnitude=2)
    model.initialize(initializer, ctx=context)

    # Horovod: fetch and broadcast parameters
    params = model.collect_params()
    if params is not None:
        hvd.broadcast_parameters(params, root_rank=0)

    # Horovod: create DistributedTrainer, a subclass of gluon.Trainer
    trainer = hvd.DistributedTrainer(params, opt)

    # Create loss function and train metric
    loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
    metric = mx.gluon.metric.Accuracy()

    # Train model
    for epoch in range(args.epochs):
        tic = time.time()
        train_data.reset()
        metric.reset()
        for nbatch, batch in enumerate(train_data, start=1):
            data = batch.data[0].as_in_context(context)
            label = batch.label[0].as_in_context(context)
            with autograd.record():
                output = model(data.astype(args.dtype, copy=False))
                loss = loss_fn(output, label)
            loss.backward()
            trainer.step(args.batch_size)
            metric.update([label], [output])

            if nbatch % 100 == 0:
                name, acc = metric.get()
                logging.info('[Epoch %d Batch %d] Training: %s=%f' %
                             (epoch, nbatch, name, acc))

        if hvd.rank() == 0:
            elapsed = time.time() - tic
            speed = nbatch * args.batch_size * hvd.size() / elapsed
            logging.info('Epoch[%d]\tSpeed=%.2f samples/s\tTime cost=%f',
                         epoch, speed, elapsed)

        # Evaluate model accuracy
        _, train_acc = metric.get()
        name, val_acc = evaluate(model, val_data, context)
        if hvd.rank() == 0:
            logging.info('Epoch[%d]\tTrain: %s=%f\tValidation: %s=%f', epoch,
                         name, train_acc, name, val_acc)

        if hvd.rank() == 0 and epoch == args.epochs - 1:
            assert val_acc > 0.96, "Achieved accuracy (%f) is lower than expected\
                                    (0.96)" % val_acc
Beispiel #17
0
def train(net, train_data, val_data, eval_metric, ctx, args):
    """Training pipeline"""
    net.collect_params().reset_ctx(ctx)
    if args.no_wd:
        for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
            v.wd_mult = 0.0

    if args.label_smooth:
        net._target_generator._label_smooth = True

    if args.lr_decay_period > 0:
        lr_decay_epoch = list(
            range(args.lr_decay_period, args.epochs, args.lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in args.lr_decay_epoch.split(',')]
    lr_decay_epoch = [e - args.warmup_epochs for e in lr_decay_epoch]
    num_batches = args.num_samples // args.batch_size
    lr_scheduler = LRSequential([
        LRScheduler('linear',
                    base_lr=0,
                    target_lr=args.lr,
                    nepochs=args.warmup_epochs,
                    iters_per_epoch=num_batches),
        LRScheduler(args.lr_mode,
                    base_lr=args.lr,
                    nepochs=args.epochs - args.warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_decay_epoch,
                    step_factor=args.lr_decay,
                    power=2),
    ])

    if args.horovod:
        hvd.broadcast_parameters(net.collect_params(), root_rank=0)
        trainer = hvd.DistributedTrainer(net.collect_params(), 'sgd', {
            'wd': args.wd,
            'momentum': args.momentum,
            'lr_scheduler': lr_scheduler
        })
    else:
        trainer = gluon.Trainer(
            net.collect_params(),
            'sgd', {
                'wd': args.wd,
                'momentum': args.momentum,
                'lr_scheduler': lr_scheduler
            },
            kvstore='local',
            update_on_kvstore=(False if args.amp else None))

    if args.amp:
        amp.init_trainer(trainer)

    # targets
    sigmoid_ce = gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
    l1_loss = gluon.loss.L1Loss()
    """
    # metrics
    obj_metrics = mx.metric.Loss('ObjLoss')
    center_metrics = mx.metric.Loss('BoxCenterLoss')
    scale_metrics = mx.metric.Loss('BoxScaleLoss')
    cls_metrics = mx.metric.Loss('ClassLoss')   

    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    log_file_path = args.save_prefix + '_train.log'
    log_dir = os.path.dirname(log_file_path)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)
    fh = logging.FileHandler(log_file_path)
    logger.addHandler(fh)
    logger.info(args)
    logger.info('Start training from [Epoch {}]'.format(args.start_epoch))
    best_map = [0]
    for epoch in range(args.start_epoch, args.epochs):
        if args.mixup:
            # TODO(zhreshold): more elegant way to control mixup during runtime
            try:
                train_data._dataset.set_mixup(np.random.beta, 1.5, 1.5)
            except AttributeError:
                train_data._dataset._data.set_mixup(np.random.beta, 1.5, 1.5)
            if epoch >= args.epochs - args.no_mixup_epochs:
                try:
                    train_data._dataset.set_mixup(None)
                except AttributeError:
                    train_data._dataset._data.set_mixup(None)

        tic = time.time()
        btic = time.time()
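        # wait for all outstanding asynchronous operations before starting the epoch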
        mx.nd.waitall()
        net.hybridize()
        for i, batch in enumerate(train_data):
            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
            # objectness, center_targets, scale_targets, weights, class_targets
            fixed_targets = [gluon.utils.split_and_load(batch[it], ctx_list=ctx, batch_axis=0) for it in range(1, 6)]
            gt_boxes = gluon.utils.split_and_load(batch[6], ctx_list=ctx, batch_axis=0)
            sum_losses = []
            obj_losses = []
            center_losses = []
            scale_losses = []
            cls_losses = []
            with autograd.record():
                for ix, x in enumerate(data):
                    obj_loss, center_loss, scale_loss, cls_loss = net(x, gt_boxes[ix], *[ft[ix] for ft in fixed_targets])
                    sum_losses.append(obj_loss + center_loss + scale_loss + cls_loss)
                    obj_losses.append(obj_loss)
                    center_losses.append(center_loss)
                    scale_losses.append(scale_loss)
                    cls_losses.append(cls_loss)
                if args.amp:
                    with amp.scale_loss(sum_losses, trainer) as scaled_loss:
                        autograd.backward(scaled_loss)
                else:
                    autograd.backward(sum_losses)
            trainer.step(batch_size)
            if (not args.horovod or hvd.rank() == 0):
                obj_metrics.update(0, obj_losses)
                center_metrics.update(0, center_losses)
                scale_metrics.update(0, scale_losses)
                cls_metrics.update(0, cls_losses)
                if args.log_interval and not (i + 1) % args.log_interval:
                    name1, loss1 = obj_metrics.get()
                    name2, loss2 = center_metrics.get()
                    name3, loss3 = scale_metrics.get()
                    name4, loss4 = cls_metrics.get()
                    logger.info('[Epoch {}][Batch {}], LR: {:.2E}, Speed: {:.3f} samples/sec, {}={:.3f}, {}={:.3f}, {}={:.3f}, {}={:.3f}'.format(
                        epoch, i, trainer.learning_rate, args.batch_size/(time.time()-btic), name1, loss1, name2, loss2, name3, loss3, name4, loss4))
                btic = time.time()

        if (not args.horovod or hvd.rank() == 0):
            name1, loss1 = obj_metrics.get()
            name2, loss2 = center_metrics.get()
            name3, loss3 = scale_metrics.get()
            name4, loss4 = cls_metrics.get()
            logger.info('[Epoch {}] Training cost: {:.3f}, {}={:.3f}, {}={:.3f}, {}={:.3f}, {}={:.3f}'.format(
                epoch, (time.time()-tic), name1, loss1, name2, loss2, name3, loss3, name4, loss4))
            if not (epoch + 1) % args.val_interval:
                # consider reduce the frequency of validation to save time
                map_name, mean_ap = validate(net, val_data, ctx, eval_metric)
                val_msg = '\n'.join(['{}={}'.format(k, v) for k, v in zip(map_name, mean_ap)])
                logger.info('[Epoch {}] Validation: \n{}'.format(epoch, val_msg))
                current_map = float(mean_ap[-1])
            else:
                current_map = 0.
            save_params(net, best_map, current_map, epoch, args.save_interval, args.save_prefix)

if __name__ == '__main__':
    args = parse_args()

    if args.amp:
        amp.init()

    if args.horovod:
        if hvd is None:
            raise SystemExit("Horovod not found, please check if you installed it correctly.")
        hvd.init()

    # fix seed for mxnet, numpy and python builtin random generator.
    gutils.random.seed(args.seed)

    # training contexts
    # check gpu or cpu
    if args.horovod:
        ctx = [mx.gpu(hvd.local_rank())]
    else:
        ctx = [mx.gpu(int(i)) for i in args.gpus.split(',') if i.strip()]
        if ctx:
            print("Using GPU")
        else:
            ctx = [mx.cpu()]


    # network
    net_name = '_'.join(('yolo3', args.network, args.dataset))
    
    args.save_prefix += net_name
    # use sync bn if specified
    if args.syncbn and len(ctx) > 1:
        net = get_model(net_name, pretrained_base=True, norm_layer=gluon.contrib.nn.SyncBatchNorm,
                        norm_kwargs={'num_devices': len(ctx)})
        async_net = get_model(net_name, pretrained_base=True)  # used by cpu worker
    else:
        net = get_model(net_name, pretrained_base=True)
        async_net = net
    if args.resume.strip():
        net.load_parameters(args.resume.strip())
        async_net.load_parameters(args.resume.strip())
    else:
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            net.initialize()
            async_net.initialize()


    # training data
    batch_size = (args.batch_size // hvd.size()) if args.horovod else args.batch_size
    train_dataset, val_dataset, eval_metric = get_dataset(args.dataset, args)
    train_data, val_data = get_dataloader(
        async_net, train_dataset, val_dataset, args.data_shape, batch_size, args.num_workers, args)

    start_time = time.time()

    # training
    train(net, train_data, val_data, eval_metric, ctx, args)

    elapsed_time = time.time() - start_time

    print("time to train: ", elapsed_time)

    file_name = "yolo3_mobilenet1.0_voc_last.params"
    net.save_parameters(file_name)

    
    """

    net = net.getmodel("yolo3_mobilenet1.0_voc")
    net.load_parameters("yolo3_mobilenet1.0_voc_best.params")

    im_fname = mx.utils.download(
        'https://raw.githubusercontent.com/zhreshold/' +
        'mxnet-ssd/master/data/demo/dog.jpg',
        path='dog.jpg')
    x, img = mx.data.transforms.presets.yolo.load_test(im_fname, short=512)
    print('Shape of pre-processed image:', x.shape)
    class_IDs, scores, bounding_boxs = net(x)

    ax = mx.utils.viz.plot_bbox(img,
                                bounding_boxs[0],
                                scores[0],
                                class_IDs[0],
                                class_names=net.classes)
    plt.show()
    """
Beispiel #18
0
def main():

    # Function to get VOC data iterators given a rank
    def get_voc_iterator(rank, num_workers, net, num_shards):
        data_dir = "data-%d" % rank
        try:
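            # Prefer the VOC tarballs staged in the given S3 bucket; fall back to a
            # regular download if they are not available.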
            s3_client = boto3.client('s3')
            for file in [
                    'VOCtrainval_06-Nov-2007.tar', 'VOCtest_06-Nov-2007.tar',
                    'VOCtrainval_11-May-2012.tar'
            ]:
                s3_client.download_file(args.s3bucket, f'voc_tars/{file}',
                                        f'/opt/ml/code/{file}')
                with tarfile.open(f'/opt/ml/code/{file}') as tar:
                    tar.extractall(path=data_dir)
        except:
            print('downloading from source')
            download_voc(data_dir)

        input_shape = (1, 256, 256, 3)
        batch_size = args.batch_size

        # might want to replace with mx.io.ImageDetRecordIter, this means you need data in RecordIO format
        #         train_iter = mx.io.MNISTIter(
        #             image="%s/train-images-idx3-ubyte" % data_dir,
        #             label="%s/train-labels-idx1-ubyte" % data_dir,
        #             input_shape=input_shape,
        #             batch_size=batch_size,
        #             shuffle=True,
        #             flat=False,
        #             num_parts=hvd.size(),
        #             part_index=hvd.rank()
        #         )

        train_dataset = gdata.VOCDetection(
            root=f'/opt/ml/code/data-{rank}/VOCdevkit/',
            splits=[(2007, 'trainval'), (2012, 'trainval')])
        val_dataset = gdata.VOCDetection(
            root=f'/opt/ml/code/data-{rank}/VOCdevkit/',
            splits=[(2007, 'test')])
        val_metric = VOC07MApMetric(iou_thresh=0.5,
                                    class_names=val_dataset.classes)
        im_aspect_ratio = [1.] * len(train_dataset)
        train_bfn = FasterRCNNTrainBatchify(net)
        train_sampler = gluoncv.nn.sampler.SplitSortedBucketSampler(
            im_aspect_ratio,
            batch_size,
            num_parts=hvd.size() if args.horovod else 1,
            part_index=hvd.rank() if args.horovod else 0,
            shuffle=True)
        # had issue with multi_stage=True
        train_iter = mx.gluon.data.DataLoader(train_dataset.transform(
            FasterRCNNDefaultTrainTransform(net.short,
                                            net.max_size,
                                            net,
                                            ashape=net.ashape,
                                            multi_stage=False)),
                                              batch_sampler=train_sampler,
                                              batchify_fn=train_bfn,
                                              num_workers=num_workers)

        val_bfn = Tuple(*[Append() for _ in range(3)])
        short = net.short[-1] if isinstance(net.short,
                                            (tuple, list)) else net.short
        # validation use 1 sample per device
        val_iter = mx.gluon.data.DataLoader(val_dataset.transform(
            FasterRCNNDefaultValTransform(short, net.max_size)),
                                            num_shards,
                                            False,
                                            batchify_fn=val_bfn,
                                            last_batch='keep',
                                            num_workers=num_workers)

        return train_iter, val_iter, val_metric

    # Build the detection network from the GluonCV model zoo
    def conv_nets(model_name):
        net = model_zoo.get_model(model_name, pretrained_base=False)
        return net

    def evaluate(net, val_data, ctx, eval_metric, args):
        """Test on validation dataset."""
        clipper = gcv.nn.bbox.BBoxClipToImage()
        eval_metric.reset()
        if not args.disable_hybridization:
            # input format is different than training, thus rehybridization is needed.
            net.hybridize(static_alloc=args.static_alloc)
        for batch in val_data:
            batch = split_and_load(batch, ctx_list=ctx)
            det_bboxes = []
            det_ids = []
            det_scores = []
            gt_bboxes = []
            gt_ids = []
            gt_difficults = []
            for x, y, im_scale in zip(*batch):
                # get prediction results
                ids, scores, bboxes = net(x)
                det_ids.append(ids)
                det_scores.append(scores)
                # clip to image size
                det_bboxes.append(clipper(bboxes, x))
                # rescale to original resolution
                im_scale = im_scale.reshape((-1)).asscalar()
                det_bboxes[-1] *= im_scale
                # split ground truths
                gt_ids.append(y.slice_axis(axis=-1, begin=4, end=5))
                gt_bboxes.append(y.slice_axis(axis=-1, begin=0, end=4))
                gt_bboxes[-1] *= im_scale
                gt_difficults.append(
                    y.slice_axis(axis=-1, begin=5, end=6
                                 ) if y.shape[-1] > 5 else None)

            # update metric
            for det_bbox, det_id, det_score, gt_bbox, gt_id, gt_diff in zip(
                    det_bboxes, det_ids, det_scores, gt_bboxes, gt_ids,
                    gt_difficults):
                eval_metric.update(det_bbox, det_id, det_score, gt_bbox, gt_id,
                                   gt_diff)
        return eval_metric.get()

    # Initialize Horovod
    hvd.init()

    # Horovod: pin context to local rank
    if args.horovod:
        ctx = [mx.gpu(hvd.local_rank())]
    else:
        ctx = [mx.gpu(int(i)) for i in args.gpus.split(',') if i.strip()]
        ctx = ctx if ctx else [mx.cpu()]
    context = mx.cpu(hvd.local_rank()) if args.no_cuda else mx.gpu(
        hvd.local_rank())
    num_workers = hvd.size()

    # Build model
    model = conv_nets(args.model_name)
    model.cast(args.dtype)
    model.hybridize()

    # Initialize parameters
    initializer = mx.init.Xavier(rnd_type='gaussian',
                                 factor_type="in",
                                 magnitude=2)
    model.initialize(initializer, ctx=context)

    # Create optimizer
    optimizer_params = {
        'momentum': args.momentum,
        'learning_rate': args.lr * hvd.size()
    }
    opt = mx.optimizer.create('sgd', **optimizer_params)

    # Load training and validation data
    train_data, val_data, eval_metric = get_voc_iterator(hvd.rank(), num_workers,
                                                         model, len(ctx))

    # Horovod: fetch and broadcast parameters
    params = model.collect_params()
    if params is not None:
        hvd.broadcast_parameters(params, root_rank=0)

    # Horovod: create DistributedTrainer, a subclass of gluon.Trainer
    trainer = hvd.DistributedTrainer(params, opt)

    # Create loss function and train metric
    loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
    # adding in new loss functions
    rpn_cls_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(
        from_sigmoid=False)
    rpn_box_loss = mx.gluon.loss.HuberLoss(
        rho=args.rpn_smoothl1_rho)  # == smoothl1
    rcnn_cls_loss = mx.gluon.loss.SoftmaxCrossEntropyLoss()
    rcnn_box_loss = mx.gluon.loss.HuberLoss(
        rho=args.rcnn_smoothl1_rho)  # == smoothl1
    metrics = [
        mx.metric.Loss('RPN_Conf'),
        mx.metric.Loss('RPN_SmoothL1'),
        mx.metric.Loss('RCNN_CrossEntropy'),
        mx.metric.Loss('RCNN_SmoothL1'),
    ]

    rpn_acc_metric = RPNAccMetric()
    rpn_bbox_metric = RPNL1LossMetric()
    rcnn_acc_metric = RCNNAccMetric()
    rcnn_bbox_metric = RCNNL1LossMetric()
    metrics2 = [
        rpn_acc_metric, rpn_bbox_metric, rcnn_acc_metric, rcnn_bbox_metric
    ]

    metric = mx.metric.Accuracy()

    # Global training timing
    if hvd.rank() == 0:
        global_tic = time.time()

    # Train model


#     for epoch in range(args.epochs):
#         tic = time.time()
#         train_data.reset()
#         metric.reset()
#         for nbatch, batch in enumerate(train_data, start=1):
#             data = batch.data[0].as_in_context(context)
#             label = batch.label[0].as_in_context(context)
#             with autograd.record():
#                 output = model(data.astype(args.dtype, copy=False))
#                 loss = loss_fn(output, label)
#             loss.backward()
#             trainer.step(args.batch_size)
#             metric.update([label], [output])

#             if nbatch % 100 == 0:
#                 name, acc = metric.get()
#                 logging.info('[Epoch %d Batch %d] Training: %s=%f' %
#                              (epoch, nbatch, name, acc))

#         if hvd.rank() == 0:
#             elapsed = time.time() - tic
#             speed = nbatch * args.batch_size * hvd.size() / elapsed
#             logging.info('Epoch[%d]\tSpeed=%.2f samples/s\tTime cost=%f',
#                          epoch, speed, elapsed)

#         # Evaluate model accuracy
#         _, train_acc = metric.get()
#         name, val_acc = evaluate(model, val_data, context)
#         if hvd.rank() == 0:
#             logging.info('Epoch[%d]\tTrain: %s=%f\tValidation: %s=%f', epoch, name,
#                          train_acc, name, val_acc)

#     if hvd.rank()==0:
#         global_training_time =time.time() - global_tic
#         print("Global elpased time on training:{}".format(global_training_time))
#         device = context.device_type + str(num_workers)

# train loop adapted from train_faster_rcnn.py
    best_map = [0]
    # lr decay policy; compute once before the loop so the step decay is applied
    # only when an epoch boundary in lr_steps is actually crossed
    lr_decay = float(args.lr_decay)
    lr_steps = sorted(
        [float(ls) for ls in args.lr_decay_epoch.split(',') if ls.strip()])
    lr_warmup = float(args.lr_warmup)  # avoid int division
    for epoch in range(args.epochs):
        # this simplifies dealing with all of the loss functions
        rcnn_task = ForwardBackwardTask(model,
                                        trainer,
                                        rpn_cls_loss,
                                        rpn_box_loss,
                                        rcnn_cls_loss,
                                        rcnn_box_loss,
                                        mix_ratio=1.0,
                                        amp_enabled=args.amp)
        executor = Parallel(args.executor_threads,
                            rcnn_task) if not args.horovod else None
        mix_ratio = 1.0
        if not args.disable_hybridization:
            model.hybridize(static_alloc=args.static_alloc)
        if args.mixup:
            # TODO(zhreshold) only support evenly mixup now, target generator needs to be modified otherwise
            train_data._dataset._data.set_mixup(np.random.uniform, 0.5, 0.5)
            mix_ratio = 0.5
            if epoch >= args.epochs - args.no_mixup_epochs:
                train_data._dataset._data.set_mixup(None)
                mix_ratio = 1.0
        while lr_steps and epoch >= lr_steps[0]:
            new_lr = trainer.learning_rate * lr_decay
            lr_steps.pop(0)
            trainer.set_learning_rate(new_lr)
            logger.info("[Epoch {}] Set learning rate to {}".format(
                epoch, new_lr))
        for metric in metrics:
            metric.reset()
        tic = time.time()
        btic = time.time()
        base_lr = trainer.learning_rate
        rcnn_task.mix_ratio = mix_ratio
        for i, batch in enumerate(train_data):
            if epoch == 0 and i <= lr_warmup and lr_warmup > 0:
                # warm up: scale the base lr by the fraction of warmup iterations completed
                new_lr = base_lr * get_lr_at_iter(i / lr_warmup,
                                                  args.lr_warmup_factor)
                if new_lr != trainer.learning_rate:
                    if i % args.log_interval == 0:
                        logger.info(
                            '[Epoch 0 Iteration {}] Set learning rate to {}'.
                            format(i, new_lr))
                    trainer.set_learning_rate(new_lr)
            batch = split_and_load(
                batch, ctx_list=ctx
            )  # does split and load function, creates a batch per device
            metric_losses = [[] for _ in metrics]
            add_losses = [[] for _ in metrics2]
            if executor is not None:
                for data in zip(*batch):
                    executor.put(data)
            for j in range(len(ctx)):
                if executor is not None:
                    result = executor.get()
                else:
                    result = rcnn_task.forward_backward(list(zip(*batch))[0])
                if (not args.horovod) or hvd.rank() == 0:
                    for k in range(len(metric_losses)):
                        metric_losses[k].append(result[k])
                    for k in range(len(add_losses)):
                        add_losses[k].append(result[len(metric_losses) + k])
            for metric, record in zip(metrics, metric_losses):
                metric.update(0, record)
            for metric, records in zip(metrics2, add_losses):
                for pred in records:
                    metric.update(pred[0], pred[1])
            trainer.step(args.batch_size)

            # update metrics
            if (not args.horovod or hvd.rank() == 0) and args.log_interval \
                    and not (i + 1) % args.log_interval:
                msg = ','.join([
                    '{}={:.3f}'.format(*metric.get())
                    for metric in metrics + metrics2
                ])
                logger.info(
                    '[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}'.
                    format(
                        epoch, i, args.log_interval * args.batch_size /
                        (time.time() - btic), msg))
                btic = time.time()

        if (not args.horovod) or hvd.rank() == 0:
            msg = ','.join(
                ['{}={:.3f}'.format(*metric.get()) for metric in metrics])
            logger.info('[Epoch {}] Training cost: {:.3f}, {}'.format(
                epoch, (time.time() - tic), msg))
            if not (epoch + 1) % args.val_interval:
                # consider reduce the frequency of validation to save time
                map_name, mean_ap = validate(model, val_data, ctx, eval_metric,
                                             args)
                val_msg = '\n'.join(
                    ['{}={}'.format(k, v) for k, v in zip(map_name, mean_ap)])
                logger.info('[Epoch {}] Validation: \n{}'.format(
                    epoch, val_msg))
                current_map = float(mean_ap[-1])
            else:
                current_map = 0.
            save_params(model, logger, best_map, current_map, epoch,
                        args.save_interval, args.save_prefix)
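
The warmup branch in the loop above scales the base learning rate through get_lr_at_iter, a small helper defined in the GluonCV detection training scripts. A hedged re-implementation of that linear ramp, with illustrative base_lr, warmup length, and warmup factor rather than the values carried in args:

def lr_at_warmup_iter(alpha, lr_warmup_factor=1.0 / 3.0):
    # Linearly interpolate from base_lr * lr_warmup_factor (alpha=0) to base_lr (alpha=1).
    return lr_warmup_factor * (1 - alpha) + alpha

base_lr = 0.01       # illustrative; stands in for trainer.learning_rate at epoch start
lr_warmup = 500.0    # illustrative; stands in for args.lr_warmup
for it in (0, 125, 250, 500):
    print('iter {:4d}: lr = {:.5f}'.format(it, base_lr * lr_at_warmup_iter(it / lr_warmup)))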
def train(data_train, data_eval, model):
    """Training function."""
    # backend specific implementation
    param_dict = model.bert.collect_params()
    if backend == 'horovod':
        hvd.broadcast_parameters(param_dict, root_rank=0)

    mlm_metric = nlp.metric.MaskedAccuracy()
    nsp_metric = nlp.metric.MaskedAccuracy()
    mlm_metric.reset()
    nsp_metric.reset()

    logging.info('Creating distributed trainer...')
    lr = args.lr
    optim_params = {'learning_rate': lr, 'epsilon': 1e-6, 'wd': 0.01}
    if args.dtype == 'float16':
        optim_params['multi_precision'] = True

    dynamic_loss_scale = args.dtype == 'float16'
    if dynamic_loss_scale:
        loss_scale_param = {
            'scale_window': 2000 / num_workers,
            'init_scale': 2**10
        }
    else:
        loss_scale_param = None

    # backend specific implementation
    if backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, args.optimizer,
                                         optim_params)
    else:
        trainer = mx.gluon.Trainer(param_dict,
                                   args.optimizer,
                                   optim_params,
                                   update_on_kvstore=False)
    fp16_trainer = FP16Trainer(trainer,
                               dynamic_loss_scale=dynamic_loss_scale,
                               loss_scaler_params=loss_scale_param)

    if args.start_step:
        state_path = os.path.join(
            args.ckpt_dir, '%07d.states.%02d' % (args.start_step, local_rank))
        logging.info('Loading trainer state from %s', state_path)
        nlp.utils.load_states(trainer, state_path)

    accumulate = args.accumulate
    num_train_steps = args.num_steps
    warmup_ratio = args.warmup_ratio
    num_warmup_steps = int(num_train_steps * warmup_ratio)
    params = [p for p in param_dict.values() if p.grad_req != 'null']

    # Do not apply weight decay on LayerNorm and bias terms
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    if accumulate > 1:
        for p in params:
            p.grad_req = 'add'

    train_begin_time = time.time()
    begin_time = time.time()
    running_mlm_loss, running_nsp_loss = 0, 0
    running_num_tks = 0
    batch_num = 0
    step_num = args.start_step

    if args.phase2:
        step_num -= args.phase1_num_steps

    logging.info('Training started')

    # set up the data-parallel wrapper that dispatches sub-batches to each context
    parallel_model = DataParallelBERT(model, trainer=fp16_trainer)
    num_ctxes = len(ctxs)
    parallel = nlp.utils.Parallel(num_ctxes if num_ctxes > 1 else 0,
                                  parallel_model)

    while step_num < num_train_steps:

        data_train_iter = iter(data_train)
        end_of_batch = False
        next_data_batch = next(data_train_iter)
        while not end_of_batch:
            data_batch = next_data_batch
            if step_num >= num_train_steps:
                break
            if batch_num % accumulate == 0:
                step_num += 1
                # update learning rate
                if step_num <= num_warmup_steps:
                    new_lr = lr * step_num / num_warmup_steps
                else:
                    offset = (num_train_steps - step_num) / (num_train_steps -
                                                             num_warmup_steps)
                    new_lr = lr * max(offset, 0)
                trainer.set_learning_rate(new_lr)
                if args.profile:
                    profile(step_num,
                            10,
                            14,
                            profile_name=args.profile + str(rank))

            # load data
            data_list = list(split_and_load(data_batch, ctxs))

            ns_label_list, ns_pred_list = [], []
            mask_label_list, mask_pred_list, mask_weight_list = [], [], []

            num_data = len(data_list)
            for i in range(num_data):
                parallel.put(data_list[i])
            for _ in range(num_data):
                (next_sentence_label, classified, masked_id, decoded,
                 masked_weight, ls1, ls2, valid_length) = parallel.get()
                ns_label_list.append(next_sentence_label)
                ns_pred_list.append(classified)
                mask_label_list.append(masked_id)
                mask_pred_list.append(decoded)
                mask_weight_list.append(masked_weight)
                running_mlm_loss += ls1.as_in_context(mx.cpu()) / len(ctxs)
                running_nsp_loss += ls2.as_in_context(mx.cpu()) / len(ctxs)
                running_num_tks += valid_length.sum().as_in_context(mx.cpu())
            # pre fetch next batch
            try:
                next_data_batch = next(data_train_iter)
            except StopIteration:
                end_of_batch = True

            # update
            if (batch_num + 1) % accumulate == 0:
                fp16_trainer.step(1, max_norm=1.0 * num_workers)
                if accumulate > 1:
                    param_dict.zero_grad()
            # update metrics
            if args.no_compute_acc:
                mask_pred_list[0].wait_to_read()
            else:
                nsp_metric.update(ns_label_list, ns_pred_list)
                mlm_metric.update(mask_label_list, mask_pred_list,
                                  mask_weight_list)

            # logging
            if step_num % (args.log_interval) == 0 and (batch_num +
                                                        1) % accumulate == 0:
                if args.no_compute_acc:
                    log_noacc(begin_time, running_num_tks,
                              running_mlm_loss / accumulate,
                              running_nsp_loss / accumulate, step_num, trainer,
                              args.log_interval)
                else:
                    log(begin_time, running_num_tks,
                        running_mlm_loss / accumulate,
                        running_nsp_loss / accumulate, step_num, mlm_metric,
                        nsp_metric, trainer, args.log_interval)
                    mlm_metric.reset_local()
                    nsp_metric.reset_local()
                begin_time = time.time()
                running_mlm_loss = running_nsp_loss = running_num_tks = 0

            # saving checkpoints
            if step_num % args.ckpt_interval == 0 and (batch_num +
                                                       1) % accumulate == 0:
                if is_master_node:
                    save_states(step_num, trainer, args.ckpt_dir, local_rank)
                    if local_rank == 0:
                        save_parameters(step_num, model.bert, args.ckpt_dir)
            if step_num % args.eval_interval == 0 and data_eval \
                    and (batch_num + 1) % accumulate == 0:
                # eval data is always based on a fixed npz file.
                dataset_eval = get_pretrain_data_npz(data_eval,
                                                     batch_size_eval, 1, False,
                                                     1, vocab)
                evaluate(dataset_eval, model, ctxs, args.log_interval,
                         args.dtype)

            batch_num += 1

    if is_master_node:
        save_states(step_num, trainer, args.ckpt_dir, local_rank)
        if local_rank == 0:
            save_parameters(step_num, model.bert, args.ckpt_dir)
    mx.nd.waitall()
    train_end_time = time.time()
    logging.info('Train cost={:.1f}s'.format(train_end_time -
                                             train_begin_time))
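
The pretraining function above accumulates gradients over `accumulate` micro-batches by setting grad_req='add' and zeroing the gradients after each optimizer step. A self-contained sketch of that pattern with a toy model and random data (all sizes illustrative):

import mxnet as mx
from mxnet import autograd, gluon

net = gluon.nn.Dense(1, in_units=3)
net.initialize()
params = net.collect_params()
trainer = gluon.Trainer(params, 'sgd', {'learning_rate': 0.1})
loss_fn = gluon.loss.L2Loss()

accumulate = 4
for p in params.values():
    p.grad_req = 'add'                 # gradients are summed across micro-batches

for batch_num in range(8):
    x = mx.nd.random.uniform(shape=(2, 3))
    y = mx.nd.random.uniform(shape=(2, 1))
    with autograd.record():
        loss = loss_fn(net(x), y)
    loss.backward()
    if (batch_num + 1) % accumulate == 0:
        trainer.step(2 * accumulate)   # one update per effective batch of 2 * accumulate samples
        params.zero_grad()             # reset the accumulated gradients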
Beispiel #20
0
def train(args):
    store, num_workers, rank, local_rank, is_master_node, ctx_l = init_comm(
        args.comm_backend, args.gpus)
    setup_logging(args, local_rank)
    cfg, tokenizer, qa_net, use_segmentation = \
        get_network(args.model_name, ctx_l,
                    args.classifier_dropout,
                    args.param_checkpoint,
                    args.backbone_path)

    logging.info('Prepare training data')
    train_features = get_squad_features(args, tokenizer, segment='train')
    dataset_processor = SquadDatasetProcessor(
        tokenizer=tokenizer,
        doc_stride=args.doc_stride,
        max_seq_length=args.max_seq_length,
        max_query_length=args.max_query_length)
    logging.info('Processing the Training data:')
    train_dataset, num_answer_mismatch, num_unreliable \
        = dataset_processor.get_train(train_features, skip_unreliable=True)
    logging.info(
        'Done! #Unreliable Span={} / #Mismatched Answer={} / #Total={}'.format(
            num_unreliable, num_answer_mismatch, len(train_features)))

    # Get dataset statistics
    num_impossible = 0
    for sample in train_dataset:
        num_impossible += sample.is_impossible
    logging.info('Before Chunking, #Train/Is Impossible = {}/{}'.format(
        len(train_features),
        sum([ele.is_impossible for ele in train_features])))
    logging.info('After Chunking, #Train Sample/Is Impossible = {}/{}'.format(
        len(train_dataset), num_impossible))

    # Shuffle the dataset using a fixed seed across all workers
    rs = np.random.RandomState(args.pre_shuffle_seed)
    rs.shuffle(train_dataset)
    sampler = SplitSampler(len(train_dataset),
                           num_parts=num_workers,
                           part_index=rank,
                           even_size=True)
    train_dataloader = mx.gluon.data.DataLoader(
        train_dataset,
        batchify_fn=dataset_processor.BatchifyFunction,
        batch_size=args.batch_size,
        num_workers=0,
        sampler=sampler)
    if 'electra' in args.model_name:
        # Froze parameters, does not work for albert model since parameters in all layers are shared
        if args.untunable_depth > 0:
            qa_net.backbone.frozen_params(args.untunable_depth)
        if args.layerwise_decay > 0:
            qa_net.backbone.apply_layerwise_decay(args.layerwise_decay)

    logging.info('Creating distributed trainer...')
    # Collect differentiable parameters
    param_dict = qa_net.collect_params()
    # Do not apply weight decay to all the LayerNorm and bias
    for _, v in qa_net.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    params = [p for p in param_dict.values() if p.grad_req != 'null']
    # Set grad_req if gradient accumulation is required
    num_accumulated = args.num_accumulated
    if num_accumulated > 1:
        logging.info(
            'Using gradient accumulation. Effective global batch size = {}'.
            format(num_accumulated * args.batch_size * len(ctx_l) *
                   num_workers))
        for p in params:
            p.grad_req = 'add'
    # backend specific implementation
    if args.comm_backend == 'horovod':
        # Horovod: fetch and broadcast parameters
        hvd.broadcast_parameters(param_dict, root_rank=0)

    epoch_size = (len(train_dataloader) + len(ctx_l) - 1) // len(ctx_l)
    if args.num_train_steps is not None:
        num_train_steps = args.num_train_steps
    else:
        num_train_steps = int(args.epochs * epoch_size / args.num_accumulated)
    if args.warmup_steps is not None:
        warmup_steps = args.warmup_steps
    else:
        warmup_steps = int(num_train_steps * args.warmup_ratio)
    assert warmup_steps is not None, 'Must specify either warmup_steps or warmup_ratio'
    log_interval = args.log_interval
    save_interval = args.save_interval if args.save_interval is not None\
        else epoch_size // args.num_accumulated
    logging.info(
        '#Total Training Steps={}, Warmup={}, Save Interval={}'.format(
            num_train_steps, warmup_steps, save_interval))

    # set up optimization
    lr_scheduler = PolyScheduler(max_update=num_train_steps,
                                 base_lr=args.lr,
                                 warmup_begin_lr=0,
                                 pwr=1,
                                 final_lr=0,
                                 warmup_steps=warmup_steps,
                                 warmup_mode='linear')
    optimizer_params = {
        'learning_rate': args.lr,
        'wd': args.wd,
        'lr_scheduler': lr_scheduler,
    }
    adam_betas = eval(args.adam_betas)
    if args.optimizer == 'adamw':
        optimizer_params.update({
            'beta1': adam_betas[0],
            'beta2': adam_betas[1],
            'epsilon': args.adam_epsilon,
            'correct_bias': False,
        })
    elif args.optimizer == 'adam':
        optimizer_params.update({
            'beta1': adam_betas[0],
            'beta2': adam_betas[1],
            'epsilon': args.adam_epsilon,
        })
    if args.comm_backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, args.optimizer,
                                         optimizer_params)
    else:
        trainer = mx.gluon.Trainer(param_dict,
                                   args.optimizer,
                                   optimizer_params,
                                   update_on_kvstore=False)

    log_span_loss = 0
    log_answerable_loss = 0
    log_total_loss = 0
    log_sample_num = 0

    global_tic = time.time()
    tic = time.time()
    for step_num, batch_data in enumerate(
            grouper(repeat(train_dataloader),
                    len(ctx_l) * num_accumulated)):
        for sample_l in grouper(batch_data, len(ctx_l)):
            loss_l = []
            span_loss_l = []
            answerable_loss_l = []
            for sample, ctx in zip(sample_l, ctx_l):
                if sample is None:
                    continue
                # Copy the data to device
                tokens = sample.data.as_in_ctx(ctx)
                log_sample_num += len(tokens)
                segment_ids = sample.segment_ids.as_in_ctx(
                    ctx) if use_segmentation else None
                valid_length = sample.valid_length.as_in_ctx(ctx)
                p_mask = sample.masks.as_in_ctx(ctx)
                gt_start = sample.gt_start.as_in_ctx(ctx).astype(np.int32)
                gt_end = sample.gt_end.as_in_ctx(ctx).astype(np.int32)
                is_impossible = sample.is_impossible.as_in_ctx(ctx).astype(
                    np.int32)
                batch_idx = mx.np.arange(tokens.shape[0],
                                         dtype=np.int32,
                                         ctx=ctx)
                p_mask = 1 - p_mask  # In the network, we use 1 --> no_mask, 0 --> mask
                with mx.autograd.record():
                    start_logits, end_logits, answerable_logits \
                        = qa_net(tokens, segment_ids, valid_length, p_mask, gt_start)
                    sel_start_logits = start_logits[batch_idx, gt_start]
                    sel_end_logits = end_logits[batch_idx, gt_end]
                    sel_answerable_logits = answerable_logits[batch_idx,
                                                              is_impossible]
                    span_loss = -0.5 * (sel_start_logits +
                                        sel_end_logits).mean()
                    answerable_loss = -0.5 * sel_answerable_logits.mean()
                    loss = span_loss + answerable_loss
                    loss_l.append(loss)
                    span_loss_l.append(span_loss)
                    answerable_loss_l.append(answerable_loss)

            for loss in loss_l:
                loss.backward()
            # All Reduce the Step Loss
            log_span_loss += sum(
                [ele.as_in_ctx(ctx_l[0]) for ele in span_loss_l]).asnumpy()
            log_total_loss += sum([ele.as_in_ctx(ctx_l[0])
                                   for ele in loss_l]).asnumpy()
            log_answerable_loss += sum([
                ele.as_in_ctx(ctx_l[0]) for ele in answerable_loss_l
            ]).asnumpy()
        # update
        trainer.allreduce_grads()

        if args.max_grad_norm > 0:
            total_norm, ratio, is_finite = clip_grad_global_norm(
                params, args.max_grad_norm * num_workers)
        else:
            total_norm = grad_global_norm(params)

        if args.comm_backend == 'horovod':
            # Note that horovod.trainer._scale is default to num_workers,
            # thus trainer.update(1) will scale the gradients by 1./num_workers
            trainer.update(1, ignore_stale_grad=True)
        else:
            # gluon.trainer._scale is default to 1
            trainer.update(num_workers, ignore_stale_grad=True)

        total_norm = total_norm / num_workers
        if args.num_accumulated > 1:
            # set grad to zero for gradient accumulation
            qa_net.zero_grad()

        # saving
        if local_rank == 0 and (step_num + 1) % save_interval == 0 or (
                step_num + 1) >= num_train_steps:
            version_prefix = 'squad' + args.version
            ckpt_name = '{}_{}_{}.params'.format(args.model_name,
                                                 version_prefix,
                                                 (step_num + 1))
            params_saved = os.path.join(args.output_dir, ckpt_name)
            qa_net.save_parameters(params_saved)
            ckpt_candidates = [
                f for f in os.listdir(args.output_dir) if f.endswith('.params')
            ]
            # keep last `max_saved_ckpt` checkpoints
            if len(ckpt_candidates) > args.max_saved_ckpt:
                ckpt_candidates.sort(key=lambda ele: (len(ele), ele))
                os.remove(os.path.join(args.output_dir, ckpt_candidates[0]))
            logging.info('Params saved in: {}'.format(params_saved))

        # logging
        if (step_num + 1) % log_interval == 0:
            log_span_loss /= log_sample_num
            log_answerable_loss /= log_sample_num
            log_total_loss /= log_sample_num
            toc = time.time()
            logging.info(
                'Step: {}/{}, Loss span/answer/total={:.4f}/{:.4f}/{:.4f},'
                ' LR={:.8f}, grad_norm={:.4f}. Time cost={:.2f}, Throughput={:.2f} samples/s'
                ' ETA={:.2f}h'.format(
                    (step_num + 1), num_train_steps, log_span_loss,
                    log_answerable_loss, log_total_loss, trainer.learning_rate,
                    total_norm, toc - tic, log_sample_num / (toc - tic),
                    (num_train_steps - (step_num + 1)) /
                    ((step_num + 1) / (toc - global_tic)) / 3600))
            tic = time.time()
            log_span_loss = 0
            log_answerable_loss = 0
            log_total_loss = 0
            log_sample_num = 0
            num_samples_per_update = 0

        if (step_num + 1) >= num_train_steps:
            toc = time.time()
            logging.info('Finish training step: {} within {} hours'.format(
                step_num + 1, (toc - global_tic) / 3600))
            break

    return params_saved
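
The SQuAD fine-tuning above drives the learning rate with a PolyScheduler that warms up linearly and then decays to zero. A small sketch to visualize that schedule, assuming PolyScheduler here is mxnet.lr_scheduler.PolyScheduler (which matches the constructor arguments used above); the step counts and base_lr are illustrative:

import mxnet as mx

num_train_steps, warmup_steps, base_lr = 1000, 100, 3e-5   # illustrative values
lr_scheduler = mx.lr_scheduler.PolyScheduler(max_update=num_train_steps,
                                             base_lr=base_lr,
                                             warmup_begin_lr=0,
                                             pwr=1,
                                             final_lr=0,
                                             warmup_steps=warmup_steps,
                                             warmup_mode='linear')
for step in (1, 50, 100, 500, 1000):
    print('step {:4d}: lr = {:.2e}'.format(step, lr_scheduler(step)))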
Beispiel #21
0
optimizer_params = {'momentum': args.momentum,
                    'learning_rate': args.lr * hvd.size()}
opt = mx.optimizer.create('sgd', **optimizer_params)

# Initialize parameters
initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in",
                             magnitude=2)
model.initialize(initializer, ctx=context)

# Horovod: fetch and broadcast parameters
params = model.collect_params()
if params is not None:
    hvd.broadcast_parameters(params, root_rank=0)

# Horovod: create DistributedTrainer, a subclass of gluon.Trainer
trainer = hvd.DistributedTrainer(params, opt,
                                 gradient_predivide_factor=args.gradient_predivide_factor)

# Create loss function and train metric
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
metric = mx.metric.Accuracy()

# Train model
for epoch in range(args.epochs):
    tic = time.time()
    train_data.reset()
    metric.reset()
    for nbatch, batch in enumerate(train_data, start=1):
        data = batch.data[0].as_in_context(context)
        label = batch.label[0].as_in_context(context)
        with autograd.record():
            output = model(data.astype(args.dtype, copy=False))
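
Beispiel #21 is cut off mid-loop. A compact, hedged skeleton of the same Horovod pattern it demonstrates (pin a device per process, broadcast the initial parameters from rank 0, wrap the optimizer in a DistributedTrainer, and scale the learning rate by the number of workers); the model and data below are toy placeholders, and running it across workers still requires horovodrun/mpirun:

import mxnet as mx
from mxnet import autograd, gluon
import horovod.mxnet as hvd

hvd.init()
ctx = mx.gpu(hvd.local_rank()) if mx.context.num_gpus() > 0 else mx.cpu()

net = gluon.nn.Dense(10, in_units=32)
net.initialize(mx.init.Xavier(rnd_type='gaussian', factor_type='in', magnitude=2), ctx=ctx)

params = net.collect_params()
hvd.broadcast_parameters(params, root_rank=0)     # every worker starts from rank 0's weights

opt = mx.optimizer.create('sgd', learning_rate=0.01 * hvd.size(), momentum=0.9)
trainer = hvd.DistributedTrainer(params, opt)     # allreduces gradients across workers

loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
data = mx.nd.random.uniform(shape=(8, 32), ctx=ctx)
label = mx.nd.zeros((8,), ctx=ctx)
with autograd.record():
    loss = loss_fn(net(data), label)
loss.backward()
trainer.step(8)                                   # normalize by the per-worker batch size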
Beispiel #22
0
def train(data_train, data_eval, model):
    """Training function."""
    # backend specific implementation
    param_dict = model.bert.collect_params()
    if backend == 'horovod':
        hvd.broadcast_parameters(param_dict, root_rank=0)

    mlm_metric = nlp.metric.MaskedAccuracy()
    nsp_metric = nlp.metric.MaskedAccuracy()
    mlm_metric.reset()
    nsp_metric.reset()

    logging.debug('Creating distributed trainer...')
    lr = args.lr
    optim_params = {'learning_rate': lr, 'epsilon': 1e-6, 'wd': 0.01}
    if args.dtype == 'float16':
        optim_params['multi_precision'] = True
    if args.optimizer == 'lamb':
        optim_params['bias_correction'] = True

    dynamic_loss_scale = args.dtype == 'float16'
    if dynamic_loss_scale:
        loss_scale_param = {'scale_window': 2000 / num_workers, 'init_scale': 1}
    else:
        loss_scale_param = None

    # backend specific implementation
    if backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, args.optimizer, optim_params)
    elif backend == 'byteps':
        trainer = bps.DistributedTrainer(param_dict, args.optimizer, optim_params)
    else:
        trainer = mx.gluon.Trainer(param_dict, args.optimizer, optim_params,
                                   update_on_kvstore=False)
    fp16_trainer = FP16Trainer(trainer, dynamic_loss_scale=dynamic_loss_scale,
                               loss_scaler_params=loss_scale_param)

    if args.start_step:
        state_path = os.path.join(args.ckpt_dir, '%07d.states.%02d'%(args.start_step, local_rank))
        logging.info('Loading trainer state from %s', state_path)
        nlp.utils.load_states(trainer, state_path)

    accumulate = args.accumulate
    num_train_steps = args.num_steps
    warmup_ratio = args.warmup_ratio
    num_warmup_steps = int(num_train_steps * warmup_ratio)
    params = [p for p in param_dict.values() if p.grad_req != 'null']

    # Do not apply weight decay on LayerNorm and bias terms
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    if accumulate > 1:
        for p in params:
            p.grad_req = 'add'

    train_begin_time = time.time()
    begin_time = time.time()
    running_mlm_loss, running_nsp_loss = 0, 0
    local_mlm_loss, local_num_masks = 0, mx.nd.array([0], ctx=ctxs[0])
    running_num_tks = 0
    batch_num = 0
    step_num = args.start_step

    logging.debug('Training started')
    logging.info('Generating the first batch of data, which may take a few minutes ...')

    # create dummy data loader if needed
    parallel_model = DataParallelBERT(model, trainer=fp16_trainer)
    num_ctxes = len(ctxs)
    parallel = nlp.utils.Parallel(num_ctxes if num_ctxes > 1 else 0, parallel_model)

    if backend == 'byteps':
        bps.byteps_declare_tensor("local_num_masks")
        bps.byteps_push_pull(local_num_masks, is_average=False, name="local_num_masks", priority=0)
        logging.debug('Broadcast local_num_masks tensor')
        next_batch = next(iter(get_dummy_dataloader(batch_size, args.max_seq_length, args.max_predictions_per_seq)))
        data_list = list(split_and_load(next_batch, ctxs))
        parallel.put(data_list[0])
        parallel.get()
        trainer._init_params()

    while step_num < num_train_steps:

        data_train_iter = iter(data_train)
        end_of_batch = False
        next_data_batch = next(data_train_iter)
        while not end_of_batch:
            data_batch = next_data_batch
            if step_num >= num_train_steps:
                break
            if batch_num % accumulate == 0:
                step_num += 1
                # if accumulate > 1, grad_req is set to 'add', and zero_grad is required
                if accumulate > 1:
                    param_dict.zero_grad()
                # update learning rate
                if step_num <= num_warmup_steps:
                    new_lr = lr * step_num / num_warmup_steps
                else:
                    offset = lr * step_num / num_train_steps
                    new_lr = lr - offset
                trainer.set_learning_rate(new_lr)
                if args.profile:
                    profile(step_num, 10, 14, profile_name=args.profile + str(rank))
                if early_stop and step_num == 10:
                    mx.nd.waitall()
                    exit()

            # load data
            data_list = list(split_and_load(data_batch, ctxs))

            ns_label_list, ns_pred_list = [], []
            mask_label_list, mask_pred_list, mask_weight_list = [], [], []

            with mx.autograd.record():
                num_data = len(data_list)
                for i in range(num_data):
                    parallel.put(data_list[i])
                for _ in range(num_data):
                    (next_sentence_label, classified, masked_id,
                     decoded, masked_weight, ls1, ls2, valid_length, num_masks) = parallel.get()
                    ns_label_list.append(next_sentence_label)
                    ns_pred_list.append(classified)
                    mask_label_list.append(masked_id)
                    mask_pred_list.append(decoded)
                    mask_weight_list.append(masked_weight)
                    local_num_masks += num_masks
                    local_mlm_loss += ls1
                    running_num_tks += valid_length.sum()
            # pre fetch next batch
            try:
                next_data_batch = next(data_train_iter)
            except StopIteration:
                end_of_batch = True

            # update
            if (batch_num + 1) % accumulate == 0:
                running_mlm_loss += local_mlm_loss / local_num_masks
                if backend == 'horovod':
                    hvd.allreduce_(local_num_masks, average=False, name='local_num_masks')
                elif backend == 'byteps':
                    bps.byteps_push_pull(local_num_masks, is_average=False,
                                         name="local_num_masks", priority=0)
                # because byteps implicitly set scale /= num_workers
                fp16_trainer.step(local_num_masks * num_workers, max_norm=local_num_masks,
                                  num_ctxs=len(ctxs) * num_workers)
                local_num_masks, local_mlm_loss = 0, 0
            # update metrics
            if args.no_compute_acc:
                for mask_pred_i in mask_pred_list:
                    mask_pred_i.wait_to_read()
            else:
                nsp_metric.update(ns_label_list, ns_pred_list)
                mlm_metric.update(mask_label_list, mask_pred_list, mask_weight_list)

            # logging
            if (step_num + 1) % (args.log_interval) == 0 and (batch_num + 1) % accumulate == 0:
                if args.no_compute_acc:
                    log_noacc(begin_time, running_num_tks, running_mlm_loss,
                              0, step_num, trainer, args.log_interval)
                else:
                    log(begin_time, running_num_tks, running_mlm_loss / accumulate,
                        running_nsp_loss / accumulate, step_num, mlm_metric, nsp_metric,
                        trainer, args.log_interval)
                    mlm_metric.reset_local()
                    nsp_metric.reset_local()
                begin_time = time.time()
                running_mlm_loss = running_nsp_loss = running_num_tks = 0

            # saving checkpoints (the actual save calls are disabled in this example)
            if (step_num + 1) % args.ckpt_interval == 0 and (batch_num + 1) % accumulate == 0:
                pass
                # if is_master_node:
                #     save_states(step_num, trainer, args.ckpt_dir, local_rank)
                #     if local_rank == 0:
                #         save_parameters(step_num, model.bert, args.ckpt_dir)
            if (step_num + 1) % args.eval_interval == 0 and data_eval \
                    and (batch_num + 1) % accumulate == 0:
                # eval data is always based on a fixed npz file.
                dataset_eval = get_pretrain_data_npz(data_eval, batch_size_eval,
                                                     1, False, 1, vocab)
                evaluate(dataset_eval, model, ctxs, args.log_interval, args.dtype, rank, num_workers)

            batch_num += 1

#    if is_master_node:
#        save_states(step_num, trainer, args.ckpt_dir, local_rank)
#        if local_rank == 0:
#            save_parameters(step_num, model, args.ckpt_dir)
    mx.nd.waitall()
    train_end_time = time.time()
    logging.info('Train cost={:.1f}s'.format(train_end_time - train_begin_time))
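
Both BERT pretraining examples exclude LayerNorm parameters and biases from weight decay by setting wd_mult = 0.0 on the matching parameters. A minimal sketch of that selection on a toy network (the network itself is illustrative):

import mxnet as mx
from mxnet import gluon

net = gluon.nn.HybridSequential()
net.add(gluon.nn.Dense(16, in_units=8),
        gluon.nn.LayerNorm(),
        gluon.nn.Dense(2, in_units=16))
net.initialize()

# collect_params accepts a regex; gamma/beta are the LayerNorm scale/offset parameters.
for _, v in net.collect_params('.*beta|.*gamma|.*bias').items():
    v.wd_mult = 0.0

for name, p in net.collect_params().items():
    print('{:24s} wd_mult={}'.format(name, p.wd_mult))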
Beispiel #23
0
def train(net, train_data, val_data, eval_metric, ctx, args):
    """Training pipeline"""
    net.collect_params().reset_ctx(ctx)

    if args.horovod:
        hvd.broadcast_parameters(net.collect_params(), root_rank=0)
        trainer = hvd.DistributedTrainer(
                        net.collect_params(), 'sgd',
                        {'learning_rate': args.lr, 'wd': args.wd, 'momentum': args.momentum})
    else:
        trainer = gluon.Trainer(
                    net.collect_params(), 'sgd',
                    {'learning_rate': args.lr, 'wd': args.wd, 'momentum': args.momentum},
                    update_on_kvstore=(False if args.amp else None))

    if args.amp:
        amp.init_trainer(trainer)

    # lr decay policy
    lr_decay = float(args.lr_decay)
    lr_steps = sorted([float(ls) for ls in args.lr_decay_epoch.split(',') if ls.strip()])

    mbox_loss = gcv.loss.SSDMultiBoxLoss()
    ce_metric = mx.metric.Loss('CrossEntropy')
    smoothl1_metric = mx.metric.Loss('SmoothL1')

    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    log_file_path = args.save_prefix + '_train.log'
    log_dir = os.path.dirname(log_file_path)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)
    fh = logging.FileHandler(log_file_path)
    logger.addHandler(fh)
    logger.info(args)
    logger.info('Start training from [Epoch {}]'.format(args.start_epoch))
    best_map = [0]

    for epoch in range(args.start_epoch, args.epochs):
        while lr_steps and epoch >= lr_steps[0]:
            new_lr = trainer.learning_rate * lr_decay
            lr_steps.pop(0)
            trainer.set_learning_rate(new_lr)
            logger.info("[Epoch {}] Set learning rate to {}".format(epoch, new_lr))
        ce_metric.reset()
        smoothl1_metric.reset()
        tic = time.time()
        btic = time.time()
        net.hybridize(static_alloc=True, static_shape=True)

        for i, batch in enumerate(train_data):
            if args.dali:
                # dali iterator returns a mxnet.io.DataBatch
                data = [d.data[0] for d in batch]
                box_targets = [d.label[0] for d in batch]
                cls_targets = [nd.cast(d.label[1], dtype='float32') for d in batch]

            else:
                data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
                cls_targets = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
                box_targets = gluon.utils.split_and_load(batch[2], ctx_list=ctx, batch_axis=0)

            with autograd.record():
                cls_preds = []
                box_preds = []
                for x in data:
                    cls_pred, box_pred, _ = net(x)
                    cls_preds.append(cls_pred)
                    box_preds.append(box_pred)
                sum_loss, cls_loss, box_loss = mbox_loss(
                    cls_preds, box_preds, cls_targets, box_targets)
                if args.amp:
                    with amp.scale_loss(sum_loss, trainer) as scaled_loss:
                        autograd.backward(scaled_loss)
                else:
                    autograd.backward(sum_loss)
            # since we have already normalized the loss, we don't want to normalize
            # by batch-size anymore
            trainer.step(1)

            if (not args.horovod or hvd.rank() == 0):
                local_batch_size = int(args.batch_size // (hvd.size() if args.horovod else 1))
                ce_metric.update(0, [l * local_batch_size for l in cls_loss])
                smoothl1_metric.update(0, [l * local_batch_size for l in box_loss])
                if args.log_interval and not (i + 1) % args.log_interval:
                    name1, loss1 = ce_metric.get()
                    name2, loss2 = smoothl1_metric.get()
                    logger.info('[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}={:.3f}, {}={:.3f}'.format(
                        epoch, i, args.batch_size/(time.time()-btic), name1, loss1, name2, loss2))
                btic = time.time()

        if (not args.horovod or hvd.rank() == 0):
            name1, loss1 = ce_metric.get()
            name2, loss2 = smoothl1_metric.get()
            logger.info('[Epoch {}] Training cost: {:.3f}, {}={:.3f}, {}={:.3f}'.format(
                epoch, (time.time()-tic), name1, loss1, name2, loss2))
            if (epoch % args.val_interval == 0) or (args.save_interval and epoch % args.save_interval == 0):
                # consider reduce the frequency of validation to save time
                map_name, mean_ap = validate(net, val_data, ctx, eval_metric)
                val_msg = '\n'.join(['{}={}'.format(k, v) for k, v in zip(map_name, mean_ap)])
                logger.info('[Epoch {}] Validation: \n{}'.format(epoch, val_msg))
                current_map = float(mean_ap[-1])
            else:
                current_map = 0.
            save_params(net, best_map, current_map, epoch, args.save_interval, args.save_prefix)
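
The SSD loop above implements step decay by popping entries off lr_steps whenever the current epoch passes a decay boundary. The same logic in isolation, with illustrative values in place of args.lr, args.lr_decay, and the parsed args.lr_decay_epoch:

lr = 0.001                 # stands in for trainer.learning_rate
lr_decay = 0.1             # stands in for args.lr_decay
lr_steps = [160.0, 200.0]  # stands in for the parsed args.lr_decay_epoch

for epoch in range(0, 241, 20):
    while lr_steps and epoch >= lr_steps[0]:
        lr *= lr_decay
        lr_steps.pop(0)
    print('epoch {:3d}: lr = {}'.format(epoch, lr))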
def train(net, train_data, val_data, eval_metric, ctx, args):
    """Training pipeline"""
    import gluoncv as gcv

    gcv.utils.check_version("0.6.0")
    from gluoncv import data as gdata
    from gluoncv import utils as gutils
    from gluoncv.data.batchify import Pad, Stack, Tuple
    from gluoncv.data.dataloader import RandomTransformDataLoader
    from gluoncv.data.transforms.presets.yolo import (
        YOLO3DefaultTrainTransform,
        YOLO3DefaultValTransform,
    )
    from gluoncv.model_zoo import get_model
    from gluoncv.utils import LRScheduler, LRSequential
    from gluoncv.utils.metrics.coco_detection import COCODetectionMetric
    from gluoncv.utils.metrics.voc_detection import VOC07MApMetric

    net.collect_params().reset_ctx(ctx)
    if args.no_wd:
        for k, v in net.collect_params(".*beta|.*gamma|.*bias").items():
            v.wd_mult = 0.0

    if args.label_smooth:
        net._target_generator._label_smooth = True

    if args.lr_decay_period > 0:
        lr_decay_epoch = list(range(args.lr_decay_period, args.epochs, args.lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in args.lr_decay_epoch.split(",")]
    lr_decay_epoch = [e - args.warmup_epochs for e in lr_decay_epoch]
    num_batches = args.num_samples // args.batch_size
    lr_scheduler = LRSequential(
        [
            LRScheduler(
                "linear",
                base_lr=0,
                target_lr=args.lr,
                nepochs=args.warmup_epochs,
                iters_per_epoch=num_batches,
            ),
            LRScheduler(
                args.lr_mode,
                base_lr=args.lr,
                nepochs=args.epochs - args.warmup_epochs,
                iters_per_epoch=num_batches,
                step_epoch=lr_decay_epoch,
                step_factor=args.lr_decay,
                power=2,
            ),
        ]
    )

    if args.horovod:
        hvd.broadcast_parameters(net.collect_params(), root_rank=0)
        trainer = hvd.DistributedTrainer(
            net.collect_params(),
            "sgd",
            {"wd": args.wd, "momentum": args.momentum, "lr_scheduler": lr_scheduler},
        )
    else:
        trainer = gluon.Trainer(
            net.collect_params(),
            "sgd",
            {"wd": args.wd, "momentum": args.momentum, "lr_scheduler": lr_scheduler},
            kvstore="local",
            update_on_kvstore=(False if args.amp else None),
        )

    if args.amp:
        amp.init_trainer(trainer)

    # targets
    sigmoid_ce = gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
    l1_loss = gluon.loss.L1Loss()

    # metrics
    obj_metrics = mx.metric.Loss("ObjLoss")
    center_metrics = mx.metric.Loss("BoxCenterLoss")
    scale_metrics = mx.metric.Loss("BoxScaleLoss")
    cls_metrics = mx.metric.Loss("ClassLoss")

    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    log_file_path = args.save_prefix + "_train.log"
    log_dir = os.path.dirname(log_file_path)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)
    fh = logging.FileHandler(log_file_path)
    logger.addHandler(fh)
    logger.info(args)
    logger.info("Start training from [Epoch {}]".format(args.start_epoch))
    best_map = [0]
    for epoch in range(args.start_epoch, args.num_epochs):
        if args.mixup:
            # TODO(zhreshold): more elegant way to control mixup during runtime
            try:
                train_data._dataset.set_mixup(np.random.beta, 1.5, 1.5)
            except AttributeError:
                train_data._dataset._data.set_mixup(np.random.beta, 1.5, 1.5)
            if epoch >= args.num_epochs - args.no_mixup_epochs:
                try:
                    train_data._dataset.set_mixup(None)
                except AttributeError:
                    train_data._dataset._data.set_mixup(None)

        tic = time.time()
        btic = time.time()
        mx.nd.waitall()
        net.hybridize()
        for i, batch in enumerate(train_data):
            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
            # objectness, center_targets, scale_targets, weights, class_targets
            fixed_targets = [
                gluon.utils.split_and_load(batch[it], ctx_list=ctx, batch_axis=0)
                for it in range(1, 6)
            ]
            gt_boxes = gluon.utils.split_and_load(batch[6], ctx_list=ctx, batch_axis=0)
            sum_losses = []
            obj_losses = []
            center_losses = []
            scale_losses = []
            cls_losses = []
            with autograd.record():
                for ix, x in enumerate(data):
                    obj_loss, center_loss, scale_loss, cls_loss = net(
                        x, gt_boxes[ix], *[ft[ix] for ft in fixed_targets]
                    )
                    sum_losses.append(obj_loss + center_loss + scale_loss + cls_loss)
                    obj_losses.append(obj_loss)
                    center_losses.append(center_loss)
                    scale_losses.append(scale_loss)
                    cls_losses.append(cls_loss)
                if args.amp:
                    with amp.scale_loss(sum_losses, trainer) as scaled_loss:
                        autograd.backward(scaled_loss)
                else:
                    autograd.backward(sum_losses)
            trainer.step(batch_size)
            if not args.horovod or hvd.rank() == 0:
                obj_metrics.update(0, obj_losses)
                center_metrics.update(0, center_losses)
                scale_metrics.update(0, scale_losses)
                cls_metrics.update(0, cls_losses)
                if args.log_interval and not (i + 1) % args.log_interval:
                    name1, loss1 = obj_metrics.get()
                    name2, loss2 = center_metrics.get()
                    name3, loss3 = scale_metrics.get()
                    name4, loss4 = cls_metrics.get()
                    logger.info(
                        "[Epoch {}][Batch {}], LR: {:.2E}, Speed: {:.3f} samples/sec, {}={:.3f}, {}={:.3f}, {}={:.3f}, {}={:.3f}".format(
                            epoch,
                            i,
                            trainer.learning_rate,
                            args.batch_size / (time.time() - btic),
                            name1,
                            loss1,
                            name2,
                            loss2,
                            name3,
                            loss3,
                            name4,
                            loss4,
                        )
                    )
                btic = time.time()

        if not args.horovod or hvd.rank() == 0:
            name1, loss1 = obj_metrics.get()
            name2, loss2 = center_metrics.get()
            name3, loss3 = scale_metrics.get()
            name4, loss4 = cls_metrics.get()
            logger.info(
                "[Epoch {}] Training cost: {:.3f}, {}={:.3f}, {}={:.3f}, {}={:.3f}, {}={:.3f}".format(
                    epoch,
                    (time.time() - tic),
                    name1,
                    loss1,
                    name2,
                    loss2,
                    name3,
                    loss3,
                    name4,
                    loss4,
                )
            )
            if not (epoch + 1) % args.val_interval:
                # consider reduce the frequency of validation to save time
                map_name, mean_ap = validate(net, val_data, ctx, eval_metric)
                val_msg = "\n".join(["{}={}".format(k, v) for k, v in zip(map_name, mean_ap)])
                logger.info("[Epoch {}] Validation: \n{}".format(epoch, val_msg))
                current_map = float(mean_ap[-1])
            else:
                current_map = 0.0
            save_params(net, best_map, current_map, epoch, args.save_interval, args.save_prefix)

    # save model
    net.set_nms(nms_thresh=0.45, nms_topk=400, post_nms=100)
    net(mx.nd.ones((1, 3, args.data_shape, args.data_shape), ctx=ctx[0]))
    net.export("%s/model" % os.environ["SM_MODEL_DIR"])
Beispiel #25
0
def train(net, train_data, val_data, eval_metric, batch_size, ctx, logger, args):
    """Training pipeline"""
    args.kv_store = 'device' if (args.amp and 'nccl' in args.kv_store) else args.kv_store
    kv = mx.kvstore.create(args.kv_store)
    net.collect_params().setattr('grad_req', 'null')
    net.collect_train_params().setattr('grad_req', 'write')
    for k, v in net.collect_params('.*bias').items():
        v.wd_mult = 0.0
    optimizer_params = {'learning_rate': args.lr, 'wd': args.wd, 'momentum': args.momentum, }
    if args.clip_gradient > 0.0:
        optimizer_params['clip_gradient'] = args.clip_gradient
    if args.amp:
        optimizer_params['multi_precision'] = True
    if args.horovod:
        hvd.broadcast_parameters(net.collect_params(), root_rank=0)
        trainer = hvd.DistributedTrainer(
            net.collect_train_params(),  # fix batchnorm, fix first stage, etc...
            'sgd',
            optimizer_params
        )
    else:
        trainer = gluon.Trainer(
            net.collect_train_params(),  # fix batchnorm, fix first stage, etc...
            'sgd',
            optimizer_params,
            update_on_kvstore=(False if args.amp else None),
            kvstore=kv)

    if args.amp:
        amp.init_trainer(trainer)

    # lr decay policy
    lr_decay = float(args.lr_decay)
    lr_steps = sorted([float(ls) for ls in args.lr_decay_epoch.split(',') if ls.strip()])
    lr_warmup = float(args.lr_warmup)  # avoid int division

    rpn_cls_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
    rpn_box_loss = mx.gluon.loss.HuberLoss(rho=args.rpn_smoothl1_rho)  # == smoothl1
    rcnn_cls_loss = mx.gluon.loss.SoftmaxCrossEntropyLoss()
    rcnn_box_loss = mx.gluon.loss.HuberLoss(rho=args.rcnn_smoothl1_rho)  # == smoothl1
    rcnn_mask_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
    metrics = [mx.metric.Loss('RPN_Conf'),
               mx.metric.Loss('RPN_SmoothL1'),
               mx.metric.Loss('RCNN_CrossEntropy'),
               mx.metric.Loss('RCNN_SmoothL1'),
               mx.metric.Loss('RCNN_Mask')]

    rpn_acc_metric = RPNAccMetric()
    rpn_bbox_metric = RPNL1LossMetric()
    rcnn_acc_metric = RCNNAccMetric()
    rcnn_bbox_metric = RCNNL1LossMetric()
    rcnn_mask_metric = MaskAccMetric()
    rcnn_fgmask_metric = MaskFGAccMetric()
    metrics2 = [rpn_acc_metric, rpn_bbox_metric,
                rcnn_acc_metric, rcnn_bbox_metric,
                rcnn_mask_metric, rcnn_fgmask_metric]
    async_eval_processes = []
    logger.info(args)

    if args.verbose:
        logger.info('Trainable parameters:')
        logger.info(net.collect_train_params().keys())
    logger.info('Start training from [Epoch {}]'.format(args.start_epoch))
    best_map = [0]
    base_lr = trainer.learning_rate
    for epoch in range(args.start_epoch, args.epochs):
        rcnn_task = ForwardBackwardTask(net, trainer, rpn_cls_loss, rpn_box_loss, rcnn_cls_loss,
                                        rcnn_box_loss, rcnn_mask_loss, args.amp)
        executor = Parallel(args.executor_threads, rcnn_task) if not args.horovod else None
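        # Parallel dispatches per-GPU forward/backward in worker threads; under Horovod each process drives a single GPU, so no executor is needed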
        if not args.disable_hybridization:
            net.hybridize(static_alloc=args.static_alloc)
        while lr_steps and epoch >= lr_steps[0]:
            new_lr = trainer.learning_rate * lr_decay
            lr_steps.pop(0)
            trainer.set_learning_rate(new_lr)
            logger.info("[Epoch {}] Set learning rate to {}".format(epoch, new_lr))
        for metric in metrics:
            metric.reset()
        tic = time.time()
        btic = time.time()
        speed = []  # per-epoch batch-throughput samples, averaged at the end of the epoch
        train_data_iter = iter(train_data)
        next_data_batch = next(train_data_iter)
        next_data_batch = split_and_load(next_data_batch, ctx_list=ctx)
        for i in range(len(train_data)):
            batch = next_data_batch
            if i + epoch * len(train_data) <= lr_warmup:
                # adjust based on real percentage
                new_lr = base_lr * get_lr_at_iter((i + epoch * len(train_data)) / lr_warmup,
                                                  args.lr_warmup_factor)
                if new_lr != trainer.learning_rate:
                    if i % args.log_interval == 0:
                        logger.info('[Epoch {} Iteration {}] Set learning rate to {}'
                                    .format(epoch, i, new_lr))
                    trainer.set_learning_rate(new_lr)
            metric_losses = [[] for _ in metrics]
            add_losses = [[] for _ in metrics2]
            if executor is not None:
                for data in zip(*batch):
                    executor.put(data)
            for j in range(len(ctx)):
                if executor is not None:
                    result = executor.get()
                else:
                    result = rcnn_task.forward_backward(list(zip(*batch))[0])
                if (not args.horovod) or hvd.rank() == 0:
                    for k in range(len(metric_losses)):
                        metric_losses[k].append(result[k])
                    for k in range(len(add_losses)):
                        add_losses[k].append(result[len(metric_losses) + k])
            try:
                # prefetch next batch
                next_data_batch = next(train_data_iter)
                next_data_batch = split_and_load(next_data_batch, ctx_list=ctx)
            except StopIteration:
                pass

            for metric, record in zip(metrics, metric_losses):
                metric.update(0, record)
            for metric, records in zip(metrics2, add_losses):
                for pred in records:
                    metric.update(pred[0], pred[1])
            trainer.step(batch_size)
            if (not args.horovod or hvd.rank() == 0) and args.log_interval \
                    and not (i + 1) % args.log_interval:
                msg = ','.join(['{}={:.3f}'.format(*metric.get()) for metric in metrics + metrics2])
                batch_speed = args.log_interval * args.batch_size / (time.time() - btic)
                speed.append(batch_speed)
                logger.info('[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}'.format(
                    epoch, i, batch_speed, msg))
                btic = time.time()
        avg_batch_speed = sum(speed) / len(speed) if speed else 0.0
        # validate and save params
        if (not args.horovod) or hvd.rank() == 0:
            msg = ','.join(['{}={:.3f}'.format(*metric.get()) for metric in metrics])
            logger.info('[Epoch {}] Training cost: {:.3f}, Speed: {:.3f} samples/sec, {}'.format(
                epoch, (time.time() - tic), avg_batch_speed, msg))
        if not (epoch + 1) % args.val_interval:
            # consider reducing the frequency of validation to save time
            validate(net, val_data, async_eval_processes, ctx, eval_metric, logger, epoch, best_map,
                     args)
        elif (not args.horovod) or hvd.rank() == 0:
            current_map = 0.
            save_params(net, logger, best_map, current_map, epoch, args.save_interval,
                        args.save_prefix)
    for thread in async_eval_processes:
        thread.join()
Example #26
def train(args):
    store, num_parts, rank, local_rank, is_master_node, ctx_l = init_comm(
        args.comm_backend, args.gpus)
    src_tokenizer = create_tokenizer(args.src_tokenizer,
                                     args.src_subword_model_path,
                                     args.src_vocab_path)
    tgt_tokenizer = create_tokenizer(args.tgt_tokenizer,
                                     args.tgt_subword_model_path,
                                     args.tgt_vocab_path)
    src_vocab = src_tokenizer.vocab
    tgt_vocab = tgt_tokenizer.vocab
    train_src_data, train_tgt_data = load_dataset_with_cache(
        args.train_src_corpus, args.train_tgt_corpus, src_tokenizer,
        tgt_tokenizer, args.overwrite_cache)
    dev_src_data, dev_tgt_data = load_dataset_with_cache(
        args.dev_src_corpus, args.dev_tgt_corpus, src_tokenizer, tgt_tokenizer,
        args.overwrite_cache)
    data_train = gluon.data.SimpleDataset([
        (src_tokens, tgt_tokens, len(src_tokens), len(tgt_tokens), i)
        for i, (src_tokens,
                tgt_tokens) in enumerate(zip(train_src_data, train_tgt_data))
    ])
    data_val = gluon.data.SimpleDataset([
        (src_tokens, tgt_tokens, len(src_tokens), len(tgt_tokens), i)
        for i, (src_tokens,
                tgt_tokens) in enumerate(zip(dev_src_data, dev_tgt_data))
    ])
    # Construct the model + loss function
    if args.cfg.endswith('.yml'):
        cfg = TransformerModel.get_cfg().clone_merge(args.cfg)
    else:
        cfg = TransformerModel.get_cfg(args.cfg)
    cfg.defrost()
    cfg.MODEL.src_vocab_size = len(src_vocab)
    cfg.MODEL.tgt_vocab_size = len(tgt_vocab)
    if args.fp16:
        raise NotImplementedError
        # cfg.MODEL.dtype = 'float16'
    cfg.freeze()
    model = TransformerModel.from_cfg(cfg)
    model.initialize(mx.init.Xavier(magnitude=args.magnitude), ctx=ctx_l)
    model.hybridize()
    if local_rank == 0:
        logging.info(model)
    with open(os.path.join(args.save_dir, 'config.yml'), 'w') as cfg_f:
        cfg_f.write(cfg.dump())
    label_smooth_loss = LabelSmoothCrossEntropyLoss(
        num_labels=len(tgt_vocab),
        alpha=args.label_smooth_alpha,
        from_logits=False)
    label_smooth_loss.hybridize()
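    # per-sample losses are divided by this constant before backward; the same factor is folded
    # back into trainer.step() below, presumably to keep accumulated gradients numerically small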
    rescale_loss = 100.0

    if args.comm_backend == 'horovod':
        hvd.broadcast_parameters(model.collect_params(), root_rank=0)

    # Construct the trainer
    # TODO(sxjscience) Support AMP
    if args.lr is None:
        base_lr = 2.0 / math.sqrt(args.num_units) / math.sqrt(
            args.warmup_steps)
    else:
        base_lr = args.lr
    lr_scheduler = InverseSquareRootScheduler(
        warmup_steps=args.warmup_steps,
        base_lr=base_lr,
        warmup_init_lr=args.warmup_init_lr)
    trainer_settings = (model.collect_params(), 'adam', {
        'learning_rate': args.lr,
        'beta1': 0.9,
        'beta2': 0.98,
        'epsilon': 1e-9,
        'lr_scheduler': lr_scheduler
    })
    if args.comm_backend == 'horovod':
        trainer = hvd.DistributedTrainer(*trainer_settings)
    else:
        trainer = gluon.Trainer(*trainer_settings)
    # Load Data
    if args.sampler == 'BoundedBudgetSampler':
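        # pack batches up to a token/sentence budget and shard them across workers via num_parts/part_index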
        train_batch_sampler = BoundedBudgetSampler(
            lengths=[(ele[2], ele[3]) for ele in data_train],
            max_num_tokens=args.max_num_tokens,
            max_num_sentences=args.max_num_sentences,
            seed=args.seed,
            num_parts=num_parts,
            part_index=rank)
    elif args.sampler == 'FixedBucketSampler':
        if args.comm_backend == 'horovod':
            raise NotImplementedError(
                'FixedBucketSampler does not support horovod at present')

        if args.bucket_scheme == 'constant':
            bucket_scheme = ConstWidthBucket()
        elif args.bucket_scheme == 'linear':
            bucket_scheme = LinearWidthBucket()
        elif args.bucket_scheme == 'exp':
            bucket_scheme = ExpWidthBucket(bucket_len_step=1.2)
        else:
            raise NotImplementedError
        # TODO(sxjscience) Support auto-bucket-size tuning
        train_batch_sampler = FixedBucketSampler(lengths=[
            (ele[2], ele[3]) for ele in data_train
        ],
                                                 batch_size=args.batch_size,
                                                 num_buckets=args.num_buckets,
                                                 ratio=args.bucket_ratio,
                                                 shuffle=True,
                                                 use_average_length=True,
                                                 bucket_scheme=bucket_scheme,
                                                 seed=args.seed)
    else:
        raise NotImplementedError

    if local_rank == 0:
        logging.info(train_batch_sampler)

    batchify_fn = bf.Tuple(bf.Pad(), bf.Pad(), bf.Stack(), bf.Stack(),
                           bf.Stack())
    train_data_loader = gluon.data.DataLoader(
        data_train,
        batch_sampler=train_batch_sampler,
        batchify_fn=batchify_fn,
        num_workers=0)

    val_data_loader = gluon.data.DataLoader(data_val,
                                            batch_size=args.val_batch_size,
                                            batchify_fn=batchify_fn,
                                            num_workers=0,
                                            shuffle=False)
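    # accumulate gradients across micro-batches: grad_req='add' sums gradients until the explicit zero_grad() after each optimizer step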
    for v in model.collect_params().values():
        if v.grad_req != 'null':
            v.grad_req = 'add'
    model.zero_grad()
    model_averager = AverageSGDTracker(model.collect_params())
    log_start_time = time.time()
    num_params, num_fixed_params = None, None
    # TODO(sxjscience) Add a log metric class
    accum_count = 0
    loss_denom = 0
    n_train_iters = 0
    log_wc = 0
    log_avg_loss = 0.0
    log_loss_denom = 0
    epoch_id = 0
    # when args.epochs < 0, the model will keep training
    while args.epochs < 0 or epoch_id < args.epochs:
        n_epoch_train_iters = 0
        processed_batch_num = 0
        train_multi_data_loader = grouper(train_data_loader, len(ctx_l))
        is_last_batch = False
        sample_data_l = next(train_multi_data_loader)
        while not is_last_batch:
            processed_batch_num += len(sample_data_l)
            loss_l = []
            for sample_data, ctx in zip(sample_data_l, ctx_l):
                if sample_data is None:
                    continue
                src_token_ids, tgt_token_ids, src_valid_length, tgt_valid_length, sample_ids = sample_data
                src_wc, tgt_wc, bs = (src_valid_length.sum(),
                                      tgt_valid_length.sum(),
                                      src_token_ids.shape[0])
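                # the loss denominator counts predicted target positions: each sentence contributes
                # valid_length - 1 tokens because the targets are shifted by one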
                loss_denom += tgt_wc - bs
                log_loss_denom += tgt_wc - bs
                log_wc += src_wc + tgt_wc
                src_token_ids = src_token_ids.as_in_ctx(ctx)
                tgt_token_ids = tgt_token_ids.as_in_ctx(ctx)
                src_valid_length = src_valid_length.as_in_ctx(ctx)
                tgt_valid_length = tgt_valid_length.as_in_ctx(ctx)
                with mx.autograd.record():
                    tgt_pred = model(src_token_ids, src_valid_length,
                                     tgt_token_ids[:, :-1],
                                     tgt_valid_length - 1)
                    tgt_labels = tgt_token_ids[:, 1:]
                    loss = label_smooth_loss(tgt_pred, tgt_labels)
                    loss = mx.npx.sequence_mask(
                        loss,
                        sequence_length=tgt_valid_length - 1,
                        use_sequence_length=True,
                        axis=1)
                    loss_l.append(loss.sum() / rescale_loss)
            for l in loss_l:
                l.backward()
            accum_count += 1
            try:
                sample_data_l = next(train_multi_data_loader)
            except StopIteration:
                is_last_batch = True
            if local_rank == 0 and num_params is None:
                num_params, num_fixed_params = count_parameters(
                    model.collect_params())
                logging.info(
                    'Total Number of Parameters (not-fixed/fixed): {}/{}'.
                    format(num_params, num_fixed_params))
            sum_loss = sum([l.as_in_ctx(mx.cpu())
                            for l in loss_l]) * rescale_loss
            log_avg_loss += sum_loss
            mx.npx.waitall()
            if accum_count == args.num_accumulated or is_last_batch:
                # Update the parameters
                n_train_iters += 1
                n_epoch_train_iters += 1
                trainer.step(loss_denom.asnumpy() / rescale_loss)
                accum_count = 0
                loss_denom = 0
                model.zero_grad()
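                # only average parameters over the last num_averages epochs
                # (or the last num_averages * save_interval_update updates)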
                if (args.epochs > 0 and epoch_id >= args.epochs - args.num_averages) or \
                   (args.max_update > 0 and n_train_iters >= args.max_update - args.num_averages * args.save_interval_update):
                    model_averager.step()
                if local_rank == 0 and \
                   (n_epoch_train_iters % args.log_interval == 0 or is_last_batch):
                    log_end_time = time.time()
                    log_wc = log_wc.asnumpy()
                    wps = log_wc / (log_end_time - log_start_time)
                    log_avg_loss = (log_avg_loss / log_loss_denom).asnumpy()
                    logging.info(
                        '[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, '
                        'throughput={:.2f}K wps, wc={:.2f}K, LR={}'.format(
                            epoch_id, processed_batch_num * num_parts,
                            len(train_data_loader), log_avg_loss,
                            np.exp(log_avg_loss), wps / 1000, log_wc / 1000,
                            trainer.learning_rate))
                    log_start_time = time.time()
                    log_avg_loss = 0
                    log_loss_denom = 0
                    log_wc = 0
                if local_rank == 0 and \
                   (args.max_update > 0 and n_train_iters % args.save_interval_update == 0):
                    model.save_parameters(os.path.join(
                        args.save_dir, 'update{:d}.params'.format(
                            n_train_iters // args.save_interval_update)),
                                          deduplicate=True)
                if args.max_update > 0 and n_train_iters >= args.max_update:
                    break
        if local_rank == 0 and args.epochs > 0:
            model.save_parameters(os.path.join(
                args.save_dir, 'epoch{:d}.params'.format(epoch_id)),
                                  deduplicate=True)
        avg_valid_loss = validation(model, val_data_loader, ctx_l)
        logging.info('[Epoch {}] validation loss/ppl={:.4f}/{:.4f}'.format(
            epoch_id, avg_valid_loss, np.exp(avg_valid_loss)))

        if args.max_update > 0 and n_train_iters >= args.max_update:
            break
        epoch_id += 1

    if args.num_averages > 0:
        model_averager.copy_back(
            model.collect_params())  # TODO(sxjscience) Rewrite using update
        model.save_parameters(os.path.join(args.save_dir, 'average.params'),
                              deduplicate=True)
Example #27
def train(net, train_data, val_data, eval_metric, ctx, args):
    """Training pipeline"""
    net.collect_params().reset_ctx(ctx)
    if args.no_wd:
        for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
            v.wd_mult = 0.0

    if args.label_smooth:
        net._target_generator._label_smooth = True

    if args.lr_decay_period > 0:
        lr_decay_epoch = list(
            range(args.lr_decay_period, args.epochs, args.lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in args.lr_decay_epoch.split(',')]
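    # shift decay epochs left by the warmup length: the step scheduler below only spans the post-warmup portion of training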
    lr_decay_epoch = [e - args.warmup_epochs for e in lr_decay_epoch]
    num_batches = args.num_samples // args.batch_size
    lr_scheduler = LRSequential([
        LRScheduler('linear',
                    base_lr=0,
                    target_lr=args.lr,
                    nepochs=args.warmup_epochs,
                    iters_per_epoch=num_batches),
        LRScheduler(args.lr_mode,
                    base_lr=args.lr,
                    nepochs=args.epochs - args.warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_decay_epoch,
                    step_factor=args.lr_decay,
                    power=2),
    ])

    if args.horovod:
        hvd.broadcast_parameters(net.collect_params(), root_rank=0)
        trainer = hvd.DistributedTrainer(net.collect_params(), 'sgd', {
            'wd': args.wd,
            'momentum': args.momentum,
            'lr_scheduler': lr_scheduler
        })
    else:
        trainer = gluon.Trainer(
            net.collect_params(),
            'sgd', {
                'wd': args.wd,
                'momentum': args.momentum,
                'lr_scheduler': lr_scheduler
            },
            kvstore='local',
            update_on_kvstore=(False if args.amp else None))

    if args.amp:
        amp.init_trainer(trainer)

    # targets
    sigmoid_ce = gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
    l1_loss = gluon.loss.L1Loss()
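    # note: these loss objects are not referenced below; the YOLOv3 block returns its own
    # objectness/center/scale/class losses when targets are passed in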

    # metrics
    obj_metrics = mx.metric.Loss('ObjLoss')
    center_metrics = mx.metric.Loss('BoxCenterLoss')
    scale_metrics = mx.metric.Loss('BoxScaleLoss')
    cls_metrics = mx.metric.Loss('ClassLoss')

    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    log_file_path = args.save_prefix + '_train.log'
    log_dir = os.path.dirname(log_file_path)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)
    fh = logging.FileHandler(log_file_path)
    logger.addHandler(fh)
    logger.info(args)
    logger.info('Start training from [Epoch {}]'.format(args.start_epoch))
    best_map = [0]
    for epoch in range(args.start_epoch, args.epochs):
        if args.mixup:
            # TODO(zhreshold): more elegant way to control mixup during runtime
            try:
                train_data._dataset.set_mixup(np.random.beta, 1.5, 1.5)
            except AttributeError:
                train_data._dataset._data.set_mixup(np.random.beta, 1.5, 1.5)
            if epoch >= args.epochs - args.no_mixup_epochs:
                try:
                    train_data._dataset.set_mixup(None)
                except AttributeError:
                    train_data._dataset._data.set_mixup(None)

        tic = time.time()
        btic = time.time()
        mx.nd.waitall()
        net.hybridize()
        for i, batch in enumerate(train_data):
            data = gluon.utils.split_and_load(batch[0],
                                              ctx_list=ctx,
                                              batch_axis=0)
            # objectness, center_targets, scale_targets, weights, class_targets
            fixed_targets = [
                gluon.utils.split_and_load(batch[it],
                                           ctx_list=ctx,
                                           batch_axis=0) for it in range(1, 6)
            ]
            gt_boxes = gluon.utils.split_and_load(batch[6],
                                                  ctx_list=ctx,
                                                  batch_axis=0)
            sum_losses = []
            obj_losses = []
            center_losses = []
            scale_losses = []
            cls_losses = []
            with autograd.record():
                for ix, x in enumerate(data):
                    obj_loss, center_loss, scale_loss, cls_loss = net(
                        x, gt_boxes[ix], *[ft[ix] for ft in fixed_targets])
                    sum_losses.append(obj_loss + center_loss + scale_loss +
                                      cls_loss)
                    obj_losses.append(obj_loss)
                    center_losses.append(center_loss)
                    scale_losses.append(scale_loss)
                    cls_losses.append(cls_loss)
                if args.amp:
                    with amp.scale_loss(sum_losses, trainer) as scaled_loss:
                        autograd.backward(scaled_loss)
                else:
                    autograd.backward(sum_losses)
            trainer.step(batch_size)  # batch_size is presumably defined at module scope in the original script; it is not passed into train()
            if (not args.horovod or hvd.rank() == 0):
                obj_metrics.update(0, obj_losses)
                center_metrics.update(0, center_losses)
                scale_metrics.update(0, scale_losses)
                cls_metrics.update(0, cls_losses)
                if args.log_interval and not (i + 1) % args.log_interval:
                    name1, loss1 = obj_metrics.get()
                    name2, loss2 = center_metrics.get()
                    name3, loss3 = scale_metrics.get()
                    name4, loss4 = cls_metrics.get()
                    logger.info(
                        '[Epoch {}][Batch {}], LR: {:.2E}, Speed: {:.3f} samples/sec, {}={:.3f}, {}={:.3f}, {}={:.3f}, {}={:.3f}'
                        .format(epoch, i, trainer.learning_rate,
                                args.batch_size / (time.time() - btic), name1,
                                loss1, name2, loss2, name3, loss3, name4,
                                loss4))
                btic = time.time()

        if (not args.horovod or hvd.rank() == 0):
            name1, loss1 = obj_metrics.get()
            name2, loss2 = center_metrics.get()
            name3, loss3 = scale_metrics.get()
            name4, loss4 = cls_metrics.get()
            logger.info(
                '[Epoch {}] Training cost: {:.3f}, {}={:.3f}, {}={:.3f}, {}={:.3f}, {}={:.3f}'
                .format(epoch, (time.time() - tic), name1, loss1, name2, loss2,
                        name3, loss3, name4, loss4))
            if not (epoch + 1) % args.val_interval:
                # consider reducing the frequency of validation to save time
                map_name, mean_ap = validate(net, val_data, ctx, eval_metric)
                val_msg = '\n'.join(
                    ['{}={}'.format(k, v) for k, v in zip(map_name, mean_ap)])
                logger.info('[Epoch {}] Validation: \n{}'.format(
                    epoch, val_msg))
                current_map = float(mean_ap[-1])
            else:
                current_map = 0.
            save_params(net, best_map, current_map, epoch, args.save_interval,
                        args.save_prefix)
Example #28

def train(net, train_data, val_data, eval_metric, batch_size, ctx, args):
    """Training pipeline"""
    kv = mx.kvstore.create(args.kv_store)
    net.collect_params().setattr('grad_req', 'null')
    net.collect_train_params().setattr('grad_req', 'write')
    optimizer_params = {'learning_rate': args.lr, 'wd': args.wd, 'momentum': args.momentum}
    if args.horovod:
        hvd.broadcast_parameters(net.collect_params(), root_rank=0)
        trainer = hvd.DistributedTrainer(
            net.collect_train_params(),  # fix batchnorm, fix first stage, etc...
            'sgd',
            optimizer_params)
    else:
        trainer = gluon.Trainer(
            net.collect_train_params(),  # fix batchnorm, fix first stage, etc...
            'sgd',
            optimizer_params,
            update_on_kvstore=(False if args.amp else None), kvstore=kv)

    if args.amp:
        amp.init_trainer(trainer)


    # lr decay policy
    lr_decay = float(args.lr_decay)
    lr_steps = sorted([float(ls) for ls in args.lr_decay_epoch.split(',') if ls.strip()])
Example #29
def train(net, train_data, val_data, eval_metric, batch_size, ctx, args):
    """Training pipeline"""
    kv = mx.kvstore.create(args.kv_store)
    net.collect_params().setattr('grad_req', 'null')
    net.collect_train_params().setattr('grad_req', 'write')
    optimizer_params = {
        'learning_rate': args.lr,
        'wd': args.wd,
        'momentum': args.momentum
    }
    if args.horovod:
        hvd.broadcast_parameters(net.collect_params(), root_rank=0)
        trainer = hvd.DistributedTrainer(
            net.collect_train_params(),  # fix batchnorm, fix first stage, etc...
            'sgd',
            optimizer_params)
    else:
        trainer = gluon.Trainer(
            net.collect_train_params(),  # fix batchnorm, fix first stage, etc...
            'sgd',
            optimizer_params,
            update_on_kvstore=(False if args.amp else None),
            kvstore=kv)

    if args.amp:
        amp.init_trainer(trainer)

    # lr decay policy
    lr_decay = float(args.lr_decay)
    lr_steps = sorted(
        [float(ls) for ls in args.lr_decay_epoch.split(',') if ls.strip()])
    lr_warmup = float(args.lr_warmup)  # avoid int division

    # TODO(zhreshold) losses?
    rpn_cls_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(
        from_sigmoid=False)
    rpn_box_loss = mx.gluon.loss.HuberLoss(rho=1 / 9.)  # == smoothl1
    rcnn_cls_loss = mx.gluon.loss.SoftmaxCrossEntropyLoss()
    rcnn_box_loss = mx.gluon.loss.HuberLoss()  # == smoothl1
    metrics = [
        mx.metric.Loss('RPN_Conf'),
        mx.metric.Loss('RPN_SmoothL1'),
        mx.metric.Loss('RCNN_CrossEntropy'),
        mx.metric.Loss('RCNN_SmoothL1'),
    ]

    rpn_acc_metric = RPNAccMetric()
    rpn_bbox_metric = RPNL1LossMetric()
    rcnn_acc_metric = RCNNAccMetric()
    rcnn_bbox_metric = RCNNL1LossMetric()
    metrics2 = [
        rpn_acc_metric, rpn_bbox_metric, rcnn_acc_metric, rcnn_bbox_metric
    ]

    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    log_file_path = args.save_prefix + '_train.log'
    log_dir = os.path.dirname(log_file_path)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)
    fh = logging.FileHandler(log_file_path)
    logger.addHandler(fh)
    logger.info(args)
    if args.verbose:
        logger.info('Trainable parameters:')
        logger.info(net.collect_train_params().keys())
    logger.info('Start training from [Epoch {}]'.format(args.start_epoch))
    best_map = [0]
    for epoch in range(args.start_epoch, args.epochs):
        mix_ratio = 1.0
        if not args.disable_hybridization:
            net.hybridize(static_alloc=args.static_alloc)
        rcnn_task = ForwardBackwardTask(net,
                                        trainer,
                                        rpn_cls_loss,
                                        rpn_box_loss,
                                        rcnn_cls_loss,
                                        rcnn_box_loss,
                                        mix_ratio=1.0)
        executor = Parallel(args.executor_threads,
                            rcnn_task) if not args.horovod else None
        if args.mixup:
            # TODO(zhreshold) only even mixup is supported for now; the target generator needs to be modified otherwise
            train_data._dataset._data.set_mixup(np.random.uniform, 0.5, 0.5)
            mix_ratio = 0.5
            if epoch >= args.epochs - args.no_mixup_epochs:
                train_data._dataset._data.set_mixup(None)
                mix_ratio = 1.0
        while lr_steps and epoch >= lr_steps[0]:
            new_lr = trainer.learning_rate * lr_decay
            lr_steps.pop(0)
            trainer.set_learning_rate(new_lr)
            logger.info("[Epoch {}] Set learning rate to {}".format(
                epoch, new_lr))
        for metric in metrics:
            metric.reset()
        tic = time.time()
        btic = time.time()
        base_lr = trainer.learning_rate
        rcnn_task.mix_ratio = mix_ratio
        logger.info('Number of training batches: {}'.format(len(train_data)))
        for i, batch in enumerate(train_data):
            if epoch == 0 and i <= lr_warmup:
                # adjust based on real percentage
                new_lr = base_lr * get_lr_at_iter(i / lr_warmup,
                                                  args.lr_warmup_factor)
                if new_lr != trainer.learning_rate:
                    if i % args.log_interval == 0:
                        logger.info(
                            '[Epoch 0 Iteration {}] Set learning rate to {}'.
                            format(i, new_lr))
                    trainer.set_learning_rate(new_lr)
            batch = split_and_load(batch, ctx_list=ctx)
            metric_losses = [[] for _ in metrics]
            add_losses = [[] for _ in metrics2]
            if executor is not None:
                for data in zip(*batch):
                    executor.put(data)
            for j in range(len(ctx)):
                if executor is not None:
                    result = executor.get()
                else:
                    result = rcnn_task.forward_backward(list(zip(*batch))[0])
                if (not args.horovod) or hvd.rank() == 0:
                    for k in range(len(metric_losses)):
                        metric_losses[k].append(result[k])
                    for k in range(len(add_losses)):
                        add_losses[k].append(result[len(metric_losses) + k])
            for metric, record in zip(metrics, metric_losses):
                metric.update(0, record)
            for metric, records in zip(metrics2, add_losses):
                for pred in records:
                    metric.update(pred[0], pred[1])
            trainer.step(batch_size)

            # update metrics
            if (not args.horovod or hvd.rank() == 0) and args.log_interval \
                    and not (i + 1) % args.log_interval:
                msg = ','.join([
                    '{}={:.3f}'.format(*metric.get())
                    for metric in metrics + metrics2
                ])
                logger.info(
                    '[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}'.
                    format(
                        epoch, i, args.log_interval * args.batch_size /
                        (time.time() - btic), msg))
                btic = time.time()

        if (not args.horovod) or hvd.rank() == 0:
            msg = ','.join(
                ['{}={:.3f}'.format(*metric.get()) for metric in metrics])
            logger.info('[Epoch {}] Training cost: {:.3f}, {}'.format(
                epoch, (time.time() - tic), msg))
            if not (epoch + 1) % args.val_interval:
                # consider reducing the frequency of validation to save time
                map_name, mean_ap = validate(net, val_data, ctx, eval_metric,
                                             args)
                map_name_train, mean_ap_train = validate(
                    net, train_data, ctx, eval_metric, args)
                if isinstance(map_name, list):
                    val_msg = '\n'.join([
                        '{}={}'.format(k, v)
                        for k, v in zip(map_name, mean_ap)
                    ])
                    train_msg = '\n'.join([
                        '{}={}'.format(k, v)
                        for k, v in zip(map_name_train, mean_ap_train)
                    ])
                    current_map = float(mean_ap[-1])
                else:
                    val_msg = '{}={}'.format(map_name, mean_ap)
                    train_msg = '{}={}'.format(map_name_train, mean_ap_train)
                    current_map = mean_ap
                logger.info('[Epoch {}] Validation: {}'.format(epoch, val_msg))
                logger.info('[Epoch {}] Train: {}'.format(epoch, train_msg))
            else:
                current_map = 0.
            save_params(net, logger, best_map, current_map, epoch,
                        args.save_interval,
                        os.path.join(args.model_dir, 'fastrcnn'))
        if executor is not None:
            executor.__del__()  # explicitly release the Parallel worker threads (executor is None when using Horovod)
Example #30
}
opt = mx.optimizer.create('sgd', **optimizer_params)

# Initialize parameters
initializer = mx.init.Xavier(rnd_type='gaussian',
                             factor_type="in",
                             magnitude=2)
model.initialize(initializer, ctx=context)

# Horovod: fetch and broadcast parameters
params = model.collect_params()
if params is not None:
    hvd.broadcast_parameters(params, root_rank=0)

# Horovod: create DistributedTrainer, a subclass of gluon.Trainer
trainer = hvd.DistributedTrainer(params, opt)

# Create loss function and train metric
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
metric = mx.gluon.metric.Accuracy()

# Train model
for epoch in range(args.epochs):
    tic = time.time()
    train_data.reset()
    metric.reset()
    for nbatch, batch in enumerate(train_data, start=1):
        data = batch.data[0].as_in_context(context)
        label = batch.label[0].as_in_context(context)
        with autograd.record():
            output = model(data.astype(args.dtype, copy=False))