        '[Epoch {}][Iter {}/{}] validation loss/ppl={:.4f}/{:.4f}, '
        'SacreBLEU={}, Detok SacreBLEU={}'.format(
            epoch_id, train_iter, total_train_iters, avg_val_loss,
            np.exp(avg_val_loss), raw_sacrebleu_out.score, detok_sacrebleu_out.score))
    writer.add_scalar('valid_loss', avg_val_loss, train_iter)
    writer.add_scalar('valid_bleu', raw_sacrebleu_out.score, train_iter)
    if args.num_averages > 0:
        model_averager.copy_back(param_dict)  # TODO(sxjscience) Rewrite using update
        model.save_parameters(os.path.join(args.save_dir, 'average.params'),
                              deduplicate=True)


if __name__ == '__main__':
    os.environ['MXNET_GPU_MEM_POOL_TYPE'] = 'Round'
    parser = get_parser()
    args = parser.parse_args()
    if args.max_update > 0:
        args.epochs = -1
    np.random.seed(args.seed)
    mx.random.seed(args.seed)
    random.seed(args.seed)
    if args.fp16:
        # Initialize AMP when training in fp16
        from mxnet import amp
        amp.init()
    train(args)
def verify_backbone_fp16(model_cls, cfg, ctx, inputs,
                         atol=1E-2, rtol=1E-2, check_amp=True):
    """Test whether the backbone model can run in float16: its outputs should be
    comparable to those of the float32 model, and (optionally) a full AMP training
    step should run without error.

    Parameters
    ----------
    model_cls
        The modeling class
    cfg
        The configuration
    ctx
        The context
    inputs
        The input tensors of the model. They will be fed to both the fp32 and the
        fp16 models.
    atol
        The absolute tolerance
    rtol
        The relative tolerance
    check_amp
        Whether to check the AMP process. You will need to ensure that there is no
        randomness in the model when it is turned on.
    """
    model_fp32 = model_cls.from_cfg(cfg, dtype='float32')
    model_fp32.initialize(ctx=ctx)
    model_fp32.hybridize()
    # Check forward
    fp32_inputs = move_to_ctx(inputs, ctx=ctx)
    outputs_fp32 = model_fp32(*fp32_inputs)
    mx.npx.waitall()
    # Check forward of fp16
    model_fp16 = model_cls.from_cfg(cfg, dtype='float16')
    model_fp16.share_parameters(model_fp32.collect_params())
    model_fp16.cast('float16')
    model_fp16.hybridize()
    for param in model_fp16.collect_params().values():
        assert param.dtype == 'float16'
    fp16_inputs = move_to_ctx(_cast_nested_to_fp16(inputs), ctx=ctx)
    outputs_fp16 = model_fp16(*fp16_inputs)
    mx.npx.waitall()
    _match_struct_output(outputs_fp16, outputs_fp32, atol=atol, rtol=rtol)
    if check_amp:
        from mxnet import amp
        amp.init()
        # Reconstruct the fp32 model
        model_fp32 = model_cls.from_cfg(cfg, dtype='float32')
        model_fp32.initialize(ctx=ctx)
        model_fp32.hybridize()
        trainer = mx.gluon.Trainer(model_fp32.collect_params(), 'adam',
                                   {'learning_rate': 1E-3, 'wd': 1E-4,
                                    'multi_precision': True},
                                   update_on_kvstore=False)
        amp.init_trainer(trainer)
        with mx.autograd.record():
            outputs_amp = model_fp32(*fp32_inputs)
            if not isinstance(outputs_amp, (tuple, list)):
                loss = outputs_amp.mean()
            else:
                loss = sum([ele.mean() for ele in outputs_amp])
            with amp.scale_loss(loss, trainer) as scaled_loss:
                mx.autograd.backward(scaled_loss)
        trainer.step(1)
        mx.npx.waitall()
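# A minimal usage sketch (an assumption, not part of the original file): run the check
# above on a BERT backbone. The model name 'google_en_uncased_bert_base', the dropout
# config keys, the helper name _example_verify_bert_fp16, and the toy input shapes are
# illustrative only; any backbone whose forward signature is
# (input_ids, token_types, valid_length) should work the same way.
def _example_verify_bert_fp16():
    import mxnet as mx
    import numpy as np
    from gluonnlp.models import get_backbone
    mx.npx.set_np()
    model_cls, cfg, tokenizer, _, _ = get_backbone('google_en_uncased_bert_base')
    cfg.defrost()
    cfg.MODEL.hidden_dropout_prob = 0.0      # remove randomness so fp16/fp32 outputs match
    cfg.MODEL.attention_dropout_prob = 0.0
    cfg.freeze()
    batch_size, seq_length = 2, 16
    input_ids = mx.np.random.randint(0, cfg.MODEL.vocab_size,
                                     (batch_size, seq_length), dtype=np.int32)
    token_types = mx.np.zeros((batch_size, seq_length), dtype=np.int32)
    valid_length = mx.np.full((batch_size,), seq_length, dtype=np.int32)
    verify_backbone_fp16(model_cls, cfg, mx.gpu(0),
                         inputs=[input_ids, token_types, valid_length])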
def _train_speed_memory(self, model_name: str, batch_size: int,
                        sequence_length: int) -> Tuple[float, Memory]:
    if self._use_fp16:
        from mxnet import amp
        amp.init()
    if self._use_gpu:
        ctx = mxnet.gpu()
    else:
        ctx = mxnet.cpu()
    model_cls, cfg, tokenizer, backbone_param_path, _ = get_backbone(model_name)
    cfg.defrost()
    cfg.MODEL.layout = self._layout
    if model_cls.__name__ not in ['BartModel']:
        cfg.MODEL.compute_layout = self._compute_layout
    cfg.freeze()
    if model_cls.__name__ in ['BartModel']:
        model = model_cls.from_cfg(cfg, extract_feature=True)
    else:
        model = model_cls.from_cfg(cfg)
    model.load_parameters(backbone_param_path, ctx=ctx)
    model.hybridize(static_alloc=True)
    vocab_size = cfg.MODEL.vocab_size
    if hasattr(cfg.MODEL, 'units'):
        out_units = cfg.MODEL.units
    else:
        out_units = cfg.MODEL.DECODER.units
    if self._layout == 'NT':
        input_ids = mxnet.np.random.randint(0, vocab_size, (batch_size, sequence_length),
                                            dtype=np.int32, ctx=ctx)
        token_types = mxnet.np.zeros((batch_size, sequence_length), dtype=np.int32, ctx=ctx)
        valid_length = mxnet.np.full((batch_size,), sequence_length, dtype=np.int32, ctx=ctx)
        contextual_embedding_ograd = mxnet.np.random.normal(
            0, 1, (batch_size, sequence_length, out_units), dtype=np.float32, ctx=ctx)
        pooled_out_ograd = mxnet.np.random.normal(
            0, 1, (batch_size, out_units), dtype=np.float32, ctx=ctx)
    elif self._layout == 'TN':
        input_ids = mxnet.np.random.randint(0, vocab_size, (sequence_length, batch_size),
                                            dtype=np.int32, ctx=ctx)
        token_types = mxnet.np.zeros((sequence_length, batch_size), dtype=np.int32, ctx=ctx)
        valid_length = mxnet.np.full((batch_size,), sequence_length, dtype=np.int32, ctx=ctx)
        contextual_embedding_ograd = mxnet.np.random.normal(
            0, 1, (sequence_length, batch_size, out_units), dtype=np.float32, ctx=ctx)
        pooled_out_ograd = mxnet.np.random.normal(
            0, 1, (batch_size, out_units), dtype=np.float32, ctx=ctx)
    else:
        raise NotImplementedError
    if model_cls.__name__ in ['BertModel', 'AlbertModel', 'ElectraModel', 'MobileBertModel']:
        def train_step():
            with mxnet.autograd.record():
                contextual_embedding, pooled_out = model(input_ids, token_types, valid_length)
                # We'd like to set the head gradient of contextual_embedding to
                # contextual_embedding_ograd and the head gradient of pooled_out to
                # pooled_out_ograd. Thus, we simply compute the two Hadamard products
                # and sum up the results.
                fake_loss = mxnet.np.sum(contextual_embedding * contextual_embedding_ograd)\
                            + mxnet.np.sum(pooled_out * pooled_out_ograd)
                fake_loss.backward()
            mxnet.npx.waitall()
    elif model_cls.__name__ in ['BartModel']:
        def train_step():
            with mxnet.autograd.record():
                contextual_embedding, pooled_out = model(input_ids, valid_length,
                                                         input_ids, valid_length)
                fake_loss = (contextual_embedding * contextual_embedding_ograd).sum() \
                            + (pooled_out * pooled_out_ograd).sum()
                fake_loss.backward()
            mxnet.npx.waitall()
    else:
        raise NotImplementedError
    timeit.repeat(train_step, repeat=1, number=5)  # warmup
    mxnet.npx.waitall()
    runtimes = timeit.repeat(train_step, repeat=self._repeat, number=3)
    mxnet.npx.waitall()
    ctx.empty_cache()
    mxnet.npx.waitall()
    # Profile memory
    if self._use_gpu:
        nvml.nvmlInit()
        train_step()
        mxnet.npx.waitall()
        handle = nvml.nvmlDeviceGetHandleByIndex(self._device_idx)
        meminfo = nvml.nvmlDeviceGetMemoryInfo(handle)
        max_bytes_in_use = meminfo.used
        memory = Memory(max_bytes_in_use)
        # shutdown nvml
        nvml.nvmlShutdown()
    else:
        # cpu
        memory_bytes = measure_peak_memory_cpu(train_step)
        memory = Memory(memory_bytes) if isinstance(memory_bytes, int) else memory_bytes
    return float(np.min(runtimes) / 3.0), memory
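# Standalone sketch (an assumption, not from the original file) of the "fake loss" trick
# used in train_step above: backpropagating through sum(y * ograd) produces the same
# parameter gradients as calling y.backward(ograd), because d(sum(y * ograd))/dy == ograd.
# The helper name _example_head_gradient_trick is hypothetical.
def _example_head_gradient_trick():
    import mxnet as mx
    mx.npx.set_np()
    x = mx.np.random.normal(0, 1, (2, 3))
    x.attach_grad()
    ograd = mx.np.random.normal(0, 1, (2, 3))
    with mx.autograd.record():
        y = x * x                          # stand-in for the backbone forward pass
        fake_loss = (y * ograd).sum()      # Hadamard product + sum sets the head gradient
    fake_loss.backward()
    # x.grad now equals 2 * x * ograd, i.e. the chain rule applied with head gradient ograd
    assert float(mx.np.abs(x.grad - 2 * x * ograd).max()) < 1e-5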
def amp_init():
    # Initialize MXNet AMP (automatic mixed precision); typically called once,
    # early in the program, before the model is constructed.
    amp.init()