'[Epoch {}][Iter {}/{}] validation loss/ppl={:.4f}/{:.4f}, '
                    'SacreBLEU={}, Detok SacreBLEU={}'.format(
                        epoch_id, train_iter, total_train_iters, avg_val_loss,
                        np.exp(avg_val_loss), raw_sacrebleu_out.score,
                        detok_sacrebleu_out.score))
                writer.add_scalar('valid_loss', avg_val_loss, train_iter)
                writer.add_scalar('valid_bleu', raw_sacrebleu_out.score,
                                  train_iter)

    if args.num_averages > 0:
        model_averager.copy_back(
            param_dict)  # TODO(sxjscience) Rewrite using update
        model.save_parameters(os.path.join(args.save_dir, 'average.params'),
                              deduplicate=True)


if __name__ == '__main__':
    os.environ['MXNET_GPU_MEM_POOL_TYPE'] = 'Round'
    parser = get_parser()
    args = parser.parse_args()
    if args.max_update > 0:
        args.epochs = -1
    np.random.seed(args.seed)
    mx.random.seed(args.seed)
    random.seed(args.seed)
    if args.fp16:
        # Initialize amp if it's fp16 training
        from mxnet import amp
        amp.init()
    train(args)
Example #2
def verify_backbone_fp16(model_cls,
                         cfg,
                         ctx,
                         inputs,
                         atol=1E-2,
                         rtol=1E-2,
                         check_amp=True):
    """Test whether the backbone model has the comparable parameter gradient +

    Parameters
    ----------
    model_cls
        The modeling class
    cfg
        The configuration
    ctx
        The context
    inputs
        The input tensors of the model. They will be moved to the given context
        and cast to float16 for the fp16 forward check.
    atol
        The absolute tolerance
    rtol
        The relative tolerance
    check_amp
        Whether to check the AMP process. You will need to ensure that there is no
        randomness in the model when it is turned on.

    """
    model_fp32 = model_cls.from_cfg(cfg, dtype='float32')
    model_fp32.initialize(ctx=ctx)
    model_fp32.hybridize()
    # Check forward
    fp32_inputs = move_to_ctx(inputs, ctx=ctx)
    outputs_fp32 = model_fp32(*fp32_inputs)
    mx.npx.waitall()
    # Check forward of fp16
    model_fp16 = model_cls.from_cfg(cfg, dtype='float16')
    model_fp16.share_parameters(model_fp32.collect_params())
    model_fp16.cast('float16')
    model_fp16.hybridize()
    for param in model_fp16.collect_params().values():
        assert param.dtype == 'float16'
    fp16_inputs = move_to_ctx(_cast_nested_to_fp16(inputs), ctx=ctx)
    outputs_fp16 = model_fp16(*fp16_inputs)
    mx.npx.waitall()
    _match_struct_output(outputs_fp16, outputs_fp32, atol=atol, rtol=rtol)
    if check_amp:
        from mxnet import amp
        amp.init()
        # Reconstruct the fp32 model
        model_fp32 = model_cls.from_cfg(cfg, dtype='float32')
        model_fp32.initialize(ctx=ctx)
        model_fp32.hybridize()
        trainer = mx.gluon.Trainer(model_fp32.collect_params(),
                                   'adam', {
                                       'learning_rate': 1E-3,
                                       'wd': 1E-4,
                                       'multi_precision': True
                                   },
                                   update_on_kvstore=False)
        amp.init_trainer(trainer)
        with mx.autograd.record():
            outputs_amp = model_fp32(*fp32_inputs)
            if not isinstance(outputs_amp, (tuple, list)):
                loss = outputs_amp.mean()
            else:
                loss = sum([ele.mean() for ele in outputs_amp])
            with amp.scale_loss(loss, trainer) as scaled_loss:
                mx.autograd.backward(scaled_loss)
        trainer.step(1)
        mx.npx.waitall()
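# A hypothetical usage sketch for verify_backbone_fp16 above. The backbone
# class, its config helper, and the input shapes are assumptions for
# illustration (a GluonNLP-style BERT backbone); they are not part of the
# original example.
import mxnet as mx
import numpy as np
from gluonnlp.models.bert import BertModel   # assumed import path

mx.npx.set_np()                               # enable the NumPy-compatible interface
cfg = BertModel.get_cfg()                     # assumed: default backbone configuration
ctx = mx.gpu(0)                               # the float16 check needs a GPU context
batch_size, seq_length = 2, 16
inputs = (
    mx.np.random.randint(0, cfg.MODEL.vocab_size, (batch_size, seq_length), dtype=np.int32),
    mx.np.zeros((batch_size, seq_length), dtype=np.int32),      # token types
    mx.np.full((batch_size,), seq_length, dtype=np.int32),      # valid lengths
)
verify_backbone_fp16(BertModel, cfg, ctx, inputs, check_amp=False)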
Example #3
    def _train_speed_memory(self, model_name: str, batch_size: int, sequence_length: int)\
            -> Tuple[float, Memory]:
        if self._use_fp16:
            from mxnet import amp
            amp.init()

        if self._use_gpu:
            ctx = mxnet.gpu()
        else:
            ctx = mxnet.cpu()
        model_cls, cfg, tokenizer, backbone_param_path, _ = get_backbone(model_name)
        cfg.defrost()
        cfg.MODEL.layout = self._layout
        if model_cls.__name__ not in ['BartModel']:
            cfg.MODEL.compute_layout = self._compute_layout
        cfg.freeze()
        if model_cls.__name__ in ['BartModel']:
            model = model_cls.from_cfg(cfg, extract_feature=True)
        else:
            model = model_cls.from_cfg(cfg)
        model.load_parameters(backbone_param_path, ctx=ctx)
        model.hybridize(static_alloc=True)
        vocab_size = cfg.MODEL.vocab_size
        if hasattr(cfg.MODEL, 'units'):
            out_units = cfg.MODEL.units
        else:
            out_units = cfg.MODEL.DECODER.units
        if self._layout == 'NT':
            input_ids = mxnet.np.random.randint(0, vocab_size, (batch_size, sequence_length),
                                                dtype=np.int32, ctx=ctx)
            token_types = mxnet.np.zeros((batch_size, sequence_length), dtype=np.int32, ctx=ctx)
            valid_length = mxnet.np.full((batch_size,), sequence_length,
                                         dtype=np.int32, ctx=ctx)
            contextual_embedding_ograd = mxnet.np.random.normal(
                0, 1, (batch_size, sequence_length, out_units),
                dtype=np.float32, ctx=ctx)
            pooled_out_ograd = mxnet.np.random.normal(
                0, 1, (batch_size, out_units), dtype=np.float32, ctx=ctx)
        elif self._layout == 'TN':
            input_ids = mxnet.np.random.randint(0, vocab_size, (sequence_length, batch_size),
                                                dtype=np.int32, ctx=ctx)
            token_types = mxnet.np.zeros((sequence_length, batch_size), dtype=np.int32, ctx=ctx)
            valid_length = mxnet.np.full((batch_size,), sequence_length,
                                         dtype=np.int32, ctx=ctx)
            contextual_embedding_ograd = mxnet.np.random.normal(
                0, 1, (sequence_length, batch_size, out_units),
                dtype=np.float32, ctx=ctx)
            pooled_out_ograd = mxnet.np.random.normal(0, 1, (batch_size, out_units),
                                                      dtype=np.float32,
                                                      ctx=ctx)
        else:
            raise NotImplementedError
        if model_cls.__name__ in ['BertModel', 'AlbertModel', 'ElectraModel', 'MobileBertModel']:
            def train_step():
                with mxnet.autograd.record():
                    contextual_embedding, pooled_out = model(input_ids, token_types, valid_length)
                    # We'd like to set the head gradient of
                    # contextual_embedding to contextual_embedding_ograd
                    # and the head gradient of pooled_out to pooled_out_ograd
                    # Thus, we simply do two Hadamard products and sum up the
                    # results (a standalone sketch of this trick follows this example).
                    fake_loss = mxnet.np.sum(contextual_embedding * contextual_embedding_ograd)\
                                + mxnet.np.sum(pooled_out * pooled_out_ograd)
                    fake_loss.backward()
                mxnet.npx.waitall()
        elif model_cls.__name__ in ['BartModel']:
            def train_step():
                with mxnet.autograd.record():
                    contextual_embedding, pooled_out = model(input_ids, valid_length,
                                                             input_ids, valid_length)
                    fake_loss = (contextual_embedding * contextual_embedding_ograd).sum() \
                                + (pooled_out * pooled_out_ograd).sum()
                    fake_loss.backward()
                mxnet.npx.waitall()
        else:
            raise NotImplementedError
        timeit.repeat(train_step, repeat=1, number=5)
        mxnet.npx.waitall()
        runtimes = timeit.repeat(train_step, repeat=self._repeat, number=3)
        mxnet.npx.waitall()
        ctx.empty_cache()
        mxnet.npx.waitall()
        # Profile memory
        if self._use_gpu:
            nvml.nvmlInit()
            train_step()
            mxnet.npx.waitall()
            handle = nvml.nvmlDeviceGetHandleByIndex(self._device_idx)
            meminfo = nvml.nvmlDeviceGetMemoryInfo(handle)
            max_bytes_in_use = meminfo.used
            memory = Memory(max_bytes_in_use)
            # shutdown nvml
            nvml.nvmlShutdown()
        else:
            # cpu
            memory_bytes = measure_peak_memory_cpu(train_step)
            memory = Memory(memory_bytes) if isinstance(memory_bytes, int) else memory_bytes
        return float(np.min(runtimes) / 3.0), memory
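# A minimal standalone sketch (not part of the original example) of the
# fake-loss trick used in train_step above: backpropagating sum(y * g)
# assigns y the head gradient g, so a full backward pass can be timed
# without defining a real loss.
import mxnet as mx

mx.npx.set_np()
x = mx.np.random.normal(0, 1, (4, 8))
x.attach_grad()
ograd = mx.np.random.normal(0, 1, (4, 8))       # desired head gradient for y
with mx.autograd.record():
    y = 2.0 * x                                  # stand-in for a model forward pass
    fake_loss = mx.np.sum(y * ograd)
fake_loss.backward()
# d(fake_loss)/dy == ograd, so x.grad == 2.0 * ograd here.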
Example #4
def amp_init():
    from mxnet import amp  # local import so mxnet.amp is only required when used
    amp.init()
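# A minimal runnable sketch (with a placeholder one-layer model, not from the
# original code) showing the AMP call order used throughout these examples:
# amp.init() before building the network, amp.init_trainer() on the trainer,
# and amp.scale_loss() around backward.
import mxnet as mx
from mxnet import amp, gluon

mx.npx.set_np()
amp_init()                                       # same as amp.init()
ctx = mx.gpu(0)                                  # float16 AMP targets GPU contexts
net = gluon.nn.Dense(4)                          # placeholder model
net.initialize(ctx=ctx)
trainer = gluon.Trainer(net.collect_params(), 'adam',
                        {'learning_rate': 1E-3, 'multi_precision': True},
                        update_on_kvstore=False)
amp.init_trainer(trainer)
x = mx.np.random.normal(0, 1, (2, 8), ctx=ctx)
with mx.autograd.record():
    loss = net(x).mean()
    with amp.scale_loss(loss, trainer) as scaled_loss:
        mx.autograd.backward(scaled_loss)
trainer.step(1)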