Beispiel #1
0
    def __init__(self, args, task, model, criterion):
        """Set up the trainer's device, model/criterion placement and state.

        Args:
            args: parsed command-line namespace; this constructor reads
                ``cpu``, ``fp16`` and ``distributed_world_size``.
            task: the task object driving training.
            model: the model to train (cast to half precision if ``args.fp16``
                and moved to the selected device).
            criterion: the loss criterion (treated the same way as ``model``).
        """
        self.args = args
        self.task = task

        # Train on the GPU only when one is available and --cpu was not given.
        use_gpu = torch.cuda.is_available() and not args.cpu
        self.cuda = use_gpu
        self.device = torch.device('cuda' if use_gpu else 'cpu')

        # Optional fp16 cast first, then placement on the chosen device.
        crit, net = criterion, model
        if args.fp16:
            crit, net = crit.half(), net.half()
        self._criterion = crit.to(device=self.device)
        self._model = net.to(device=self.device)

        # The string "DUMMY" is the placeholder for "no dummy batch captured yet".
        self._dummy_batch = "DUMMY"
        self._lr_scheduler = None
        self._num_updates = 0
        self._optim_history = None
        self._optimizer = None
        self._warn_once = set()
        self._wrapped_criterion = None
        self._wrapped_model = None

        # Gradient-norm buffer with one slot per worker; only allocated for
        # multi-worker GPU runs.
        multi_gpu = self.cuda and args.distributed_world_size > 1
        self._grad_norm_buf = (
            torch.cuda.DoubleTensor(args.distributed_world_size)
            if multi_gpu else None
        )

        metrics.log_start_time("wall", priority=790, round=0)
Beispiel #2
0
    def __init__(self, args, task, model, criterion, quantizer=None):
        """Create a trainer whose model is executed through ONNX Runtime.

        The model and criterion are fused with ``bart_model_with_loss`` and the
        result is wrapped by ``ort_supplement.create_ort_trainer``.

        Args:
            args: parsed command-line namespace; reads ``cpu`` and, if
                present, ``tpu``.
            task: the task object driving training.
            model: the PyTorch model; the unwrapped module is kept in
                ``self._ptmodel``.
            criterion: the loss criterion.
            quantizer: optional quantizer. It is stored on the trainer and,
                when given, linked back to it via ``set_trainer`` (previously
                this argument was accepted but silently ignored).
        """
        self.args = args
        self.task = task

        self.tpu = getattr(args, 'tpu', False)
        self.cuda = torch.cuda.is_available() and not args.cpu and not self.tpu
        if self.cuda:
            self.device = torch.device('cuda')
        elif self.tpu:
            self.device = utils.get_tpu_device(args)
        else:
            self.device = torch.device('cpu')

        # The device returned here is only used to build the ORT trainer
        # below; it is separate from ``self.device``.
        device = ort_supplement.setup_onnxruntime_with_mpi(args)

        # Keep the raw PyTorch model, then fuse model + criterion and hand the
        # combined module to ONNX Runtime.
        self._criterion = criterion
        self._ptmodel = model
        self._model = model
        self._model = bart_model_with_loss(self._model, criterion)
        self._model = ort_supplement.create_ort_trainer(args, device, self._model)

        self._dummy_batch = "DUMMY"  # indicates we don't have a dummy batch at first
        self._lr_scheduler = None
        self._num_updates = 0
        self._num_xla_compiles = 0  # for TPUs
        self._optim_history = None
        self._optimizer = None
        self._warn_once = set()
        self._wrapped_criterion = None
        self._wrapped_model = None

        # TODO(myleott): support tpu
        if self.cuda and self.data_parallel_world_size > 1:
            self._grad_norm_buf = torch.cuda.DoubleTensor(self.data_parallel_world_size)
        else:
            self._grad_norm_buf = None

        # Store the quantizer so later code can rely on ``self.quantizer``
        # existing, matching the other trainer constructors.
        self.quantizer = quantizer
        if self.quantizer is not None:
            self.quantizer.set_trainer(self)

        metrics.log_start_time("wall", priority=790, round=0)
Beispiel #3
0
    def train_step(self, samples, raise_oom=False):
        """Do forward, backward and parameter update.

        Runs forward/backward on every mini-batch in ``samples`` (accumulating
        gradients locally and all-reducing only on the last one), then
        reduces, normalizes and clips the gradients and takes one optimizer
        step.

        Args:
            samples: list of mini-batches. A batch for which
                ``_prepare_sample`` returns ``None`` is replaced by the cached
                dummy batch and its gradients are ignored.
            raise_oom: re-raise out-of-memory errors from forward/backward
                instead of trying to recover.

        Returns:
            The reduced logging output. ``None`` after an OOM on a
            single-worker run, or after a gradient overflow (unless the
            distributed wrapper is SlowMo). On TPU an empty dict is returned
            except every ``log_interval`` updates.

        Raises:
            FloatingPointError: gradients are NaN/Inf (the last batch is first
                re-run under ``NanDetector`` to locate the failure).
            RuntimeError: OOM during the optimizer step, or any other
                runtime error.
        """
        self._set_seed()
        self.model.train()
        self.criterion.train()
        self.zero_grad()

        metrics.log_start_time("train_wall", priority=800, round=0)

        # forward and backward pass
        logging_outputs, sample_size, ooms = [], 0, 0
        for i, sample in enumerate(samples):
            sample = self._prepare_sample(sample)
            if sample is None:
                # when sample is None, run forward/backward on a dummy batch
                # and ignore the resulting gradients
                sample = self._prepare_sample(self._dummy_batch)
                is_dummy_batch = True
            else:
                # Cache the first real batch as the dummy batch. Accept both
                # the None sentinel (used by the cfg-based constructor) and the
                # legacy "DUMMY" sentinel.
                if self._dummy_batch is None or self._dummy_batch == "DUMMY":
                    self._dummy_batch = sample
                is_dummy_batch = False

            def maybe_no_sync():
                """
                Whenever *samples* contains more than one mini-batch, we
                want to accumulate gradients locally and only call
                all-reduce in the last backwards pass.
                """
                if (self.data_parallel_world_size > 1
                        and hasattr(self.model, "no_sync")
                        and i < len(samples) - 1):
                    return self.model.no_sync()
                else:
                    return contextlib.ExitStack()  # dummy contextmanager

            try:
                with maybe_no_sync():
                    # forward and backward
                    loss, sample_size_i, logging_output = self.task.train_step(
                        sample=sample,
                        model=self.model,
                        criterion=self.criterion,
                        optimizer=self.optimizer,
                        update_num=self.get_num_updates(),
                        ignore_grad=is_dummy_batch,
                    )
                    del loss

                logging_outputs.append(logging_output)
                sample_size += sample_size_i

                # emptying the CUDA cache after the first step can
                # reduce the chance of OOM
                if self.cuda and self.get_num_updates() == 0:
                    torch.cuda.empty_cache()
            except RuntimeError as e:
                if "out of memory" in str(e):
                    self._log_oom(e)
                    if raise_oom:
                        raise e
                    logger.warning(
                        "attempting to recover from OOM in forward/backward pass"
                    )
                    ooms += 1
                    self.zero_grad()
                    if self.cuda:
                        torch.cuda.empty_cache()
                    if self.cfg.distributed_training.distributed_world_size == 1:
                        return None
                else:
                    raise e

            if self.tpu and i < len(samples) - 1:
                # tpu-comment: every XLA operation before marking step is
                # appended to the IR graph, and processing too many batches
                # before marking step can lead to OOM errors.
                # To handle gradient accumulation use case, we explicitly
                # mark step here for every forward pass without a backward pass
                import torch_xla.core.xla_model as xm

                xm.mark_step()

        if is_dummy_batch:
            if torch.is_tensor(sample_size):
                sample_size.zero_()
            else:
                sample_size *= 0.0

        if torch.is_tensor(sample_size):
            sample_size = sample_size.float()
        else:
            sample_size = float(sample_size)

        # gather logging outputs from all replicas
        if self._sync_stats():
            train_time = self._local_cumulative_training_time()
            logging_outputs, (
                sample_size,
                ooms,
                total_train_time,
            ) = self._aggregate_logging_outputs(
                logging_outputs,
                sample_size,
                ooms,
                train_time,
                ignore=is_dummy_batch,
            )
            self._cumulative_training_time = (total_train_time /
                                              self.data_parallel_world_size)

        overflow = False
        try:
            with torch.autograd.profiler.record_function("reduce-grads"):
                self.optimizer.all_reduce_grads(self.model)
                if utils.has_parameters(self.criterion):
                    self.optimizer.all_reduce_grads(self.criterion)

            with torch.autograd.profiler.record_function("multiply-grads"):
                # multiply gradients by (data_parallel_size / sample_size) since
                # DDP already normalizes by the number of data parallel workers.
                # Thus we get (sum_of_gradients / sample_size) at the end.
                if not self.cfg.optimization.use_bmuf:
                    # sample_size is 0 when every batch was a dummy batch; the
                    # gradients are zero then, so divide by 1 instead of
                    # raising ZeroDivisionError / producing inf gradients.
                    self.optimizer.multiply_grads(
                        self.data_parallel_world_size / (sample_size or 1.0))
                elif sample_size > 0:  # BMUF needs to check sample size
                    num = self.data_parallel_world_size if self._sync_stats(
                    ) else 1
                    self.optimizer.multiply_grads(num / sample_size)

            with torch.autograd.profiler.record_function("clip-grads"):
                # clip grads
                grad_norm = self.clip_grad_norm(
                    self.cfg.optimization.clip_norm)

            # check that grad norms are consistent across workers
            # on tpu check tensor is slow
            if not self.tpu:
                if (not self.cfg.optimization.use_bmuf
                        and self.cfg.distributed_training.distributed_wrapper
                        != "SlowMo"):
                    self._check_grad_norms(grad_norm)
                if not torch.isfinite(grad_norm).all():
                    # check local gradnorm single GPU case, trigger NanDetector
                    raise FloatingPointError("gradients are Nan/Inf")

            with torch.autograd.profiler.record_function("optimizer"):
                # take an optimization step
                self.optimizer.step()

        except FloatingPointError:
            # re-run the forward and backward pass with hooks attached to print
            # out where it fails
            with NanDetector(self.get_model()):
                self.task.train_step(
                    sample,
                    self.model,
                    self.criterion,
                    self.optimizer,
                    self.get_num_updates(),
                    ignore_grad=False,
                )
            raise
        except OverflowError as e:
            overflow = True
            logger.info("NOTE: overflow detected, " + str(e))
            # Allocate on the trainer's device so this also works on CPU/TPU
            # (identical to ``.cuda()`` on CUDA).
            grad_norm = torch.tensor(0.0, device=self.device)
            self.zero_grad()
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                logger.error("OOM during optimization, irrecoverable")
            raise e

        # Some distributed wrappers (e.g., SlowMo) need access to the optimizer after the step
        if hasattr(self.model, "perform_additional_optimizer_actions"):
            if hasattr(self.optimizer, "fp32_params"):
                self.model.perform_additional_optimizer_actions(
                    self.optimizer.optimizer, self.optimizer.fp32_params)
            else:
                self.model.perform_additional_optimizer_actions(
                    self.optimizer.optimizer)

        logging_output = None
        if (not overflow or self.cfg.distributed_training.distributed_wrapper
                == "SlowMo"):
            self.set_num_updates(self.get_num_updates() + 1)

            if self.tpu:
                # mark step on TPUs
                import torch_xla.core.xla_model as xm

                xm.mark_step()

                # only log stats every log_interval steps
                # this causes wps to be misreported when log_interval > 1
                logging_output = {}
                if self.get_num_updates() % self.cfg.common.log_interval == 0:
                    # log memory usage
                    mem_info = xm.get_memory_info(self.device)
                    gb_free = mem_info["kb_free"] / 1024 / 1024
                    gb_total = mem_info["kb_total"] / 1024 / 1024
                    metrics.log_scalar(
                        "gb_free",
                        gb_free,
                        priority=1500,
                        round=1,
                        weight=0,
                    )
                    metrics.log_scalar(
                        "gb_total",
                        gb_total,
                        priority=1600,
                        round=1,
                        weight=0,
                    )

                    logging_output = self._reduce_and_log_stats(
                        logging_outputs,
                        sample_size,
                        grad_norm,
                    )

                # log whenever there's an XLA compilation, since these
                # slow down training and may indicate opportunities for
                # optimization
                self._check_xla_compilation()
            else:
                # log stats
                logging_output = self._reduce_and_log_stats(
                    logging_outputs,
                    sample_size,
                    grad_norm,
                )

                # clear CUDA cache to reduce memory fragmentation
                if (self.cuda and self.cfg.common.empty_cache_freq > 0
                        and ((self.get_num_updates() +
                              self.cfg.common.empty_cache_freq - 1) %
                             self.cfg.common.empty_cache_freq) == 0):
                    torch.cuda.empty_cache()

        if self.cfg.common.fp16:
            metrics.log_scalar(
                "loss_scale",
                self.optimizer.scaler.loss_scale,
                priority=700,
                round=4,
                weight=0,
            )

        metrics.log_stop_time("train_wall")
        return logging_output
Beispiel #4
0
    def __init__(self,
                 cfg: FairseqConfig,
                 task,
                 model,
                 criterion,
                 quantizer=None):
        """Create a trainer bound to ``cfg``, ``task``, ``model`` and ``criterion``.

        The model and criterion are cast to fp16/bf16 when requested and,
        unless pipeline model parallelism is enabled, moved to the selected
        device (CUDA, TPU or CPU). Shared parameters are re-tied afterwards.

        Args:
            cfg: trainer configuration. A legacy ``argparse.Namespace`` is
                still accepted and converted to OmegaConf with a warning.
            task: the task object driving training.
            model: the model to train.
            criterion: the loss criterion.
            quantizer: optional quantizer; when given it is linked to this
                trainer via ``set_trainer``.
        """
        if isinstance(cfg, Namespace):
            logger.warning(
                "argparse.Namespace configuration is deprecated! Automatically converting to OmegaConf"
            )
            cfg = convert_namespace_to_omegaconf(cfg)

        self.cfg = cfg
        self.task = task

        # Catalog tied parameters before any dtype/device transfer.
        shared_params = _catalog_shared_params(model)

        common_cfg = cfg.common
        dist_cfg = cfg.distributed_training

        self.tpu = common_cfg.tpu
        self.cuda = (torch.cuda.is_available()
                     and not common_cfg.cpu
                     and not self.tpu)
        if self.cuda:
            self.device = torch.device("cuda")
        elif self.tpu:
            self.device = utils.get_tpu_device()
        else:
            self.device = torch.device("cpu")

        # Optional reduced-precision cast; fp16 takes precedence over bf16.
        crit, mdl = criterion, model
        if common_cfg.fp16:
            crit, mdl = crit.half(), mdl.half()
        elif common_cfg.bf16:
            crit = crit.to(dtype=torch.bfloat16)
            mdl = mdl.to(dtype=torch.bfloat16)

        # With pipeline model parallelism the modules are not moved here.
        pipeline_parallel = dist_cfg.pipeline_model_parallel
        if not pipeline_parallel:
            crit = crit.to(device=self.device)
            mdl = mdl.to(device=self.device)
        self._criterion = crit
        self._model = mdl

        self.pipeline_model_parallel = pipeline_parallel
        # Device of the last pipeline stage (CUDA + pipeline parallelism only).
        self.last_device = (
            torch.device(dist_cfg.pipeline_devices[-1])
            if self.cuda and pipeline_parallel
            else None
        )

        # Re-tie shared parameters after the device/dtype transfer.
        for source_path, *alias_paths in shared_params:
            ref = _get_module_by_path(self._model, source_path)
            for path in alias_paths:
                logger.info("detected shared parameter: {} <- {}".format(
                    source_path, path))
                _set_module_by_path(self._model, path, ref)

        self._dummy_batch = None  # indicates we don't have a dummy batch at first
        self._lr_scheduler = None
        self._num_updates = 0
        self._num_xla_compiles = 0  # for TPUs
        self._optim_history = None
        self._optimizer = None
        self._warn_once = set()
        self._wrapped_criterion = None
        self._wrapped_model = None

        # TODO(myleott): support tpu
        self._grad_norm_buf = None
        if self.cuda and self.data_parallel_world_size > 1:
            self._grad_norm_buf = torch.cuda.DoubleTensor(
                self.data_parallel_world_size)

        self.quantizer = quantizer
        if quantizer is not None:
            quantizer.set_trainer(self)

        # Collect the CUDA environment of every worker; rank 0 prints it.
        if not self.cuda:
            self.cuda_env = None
            self.cuda_env_arr = None
        else:
            self.cuda_env = utils.CudaEnvironment()
            self.cuda_env_arr = (
                distributed_utils.all_gather_list(
                    self.cuda_env, group=distributed_utils.get_global_group())
                if self.data_parallel_world_size > 1
                else [self.cuda_env]
            )
            if self.data_parallel_rank == 0:
                utils.CudaEnvironment.pretty_print_cuda_env_list(
                    self.cuda_env_arr)

        metrics.log_start_time("wall", priority=790, round=0)

        self._start_time = time.time()
        self._previous_training_time = 0
        self._cumulative_training_time = None
Beispiel #5
0
    def train_step(self, samples, raise_oom=False):
        """Do forward, backward and parameter update.

        Runs forward/backward on every mini-batch in ``samples`` (accumulating
        gradients locally and syncing only on the last one), normalizes the
        gradients by the total sample size, clips them and takes one
        optimizer step.

        Args:
            samples: non-empty list of mini-batches. The first one seeds the
                dummy batch. A batch for which ``_prepare_sample`` returns
                ``None`` is replaced by the dummy batch and its gradients are
                ignored.
            raise_oom: re-raise out-of-memory errors from forward/backward
                instead of trying to recover.

        Returns:
            The reduced logging output, or ``None`` if every worker ran out of
            memory on every batch or a gradient overflow was detected.

        Raises:
            FloatingPointError: propagated from the update. The last batch is
                first re-run under ``NanDetector`` to locate the failure.
            RuntimeError: OOM during the optimizer step, or any other
                runtime error.
        """
        if self._dummy_batch == "DUMMY":
            self._dummy_batch = samples[0]

        self._set_seed()
        self.model.train()
        self.criterion.train()
        self.zero_grad()

        metrics.log_start_time("train_wall", priority=800, round=0)

        # forward and backward pass
        logging_outputs, sample_size, ooms = [], 0, 0
        for i, sample in enumerate(samples):
            sample = self._prepare_sample(sample)
            if sample is None:
                # when sample is None, run forward/backward on a dummy batch
                # and ignore the resulting gradients
                sample = self._prepare_sample(self._dummy_batch)
                is_dummy_batch = True
            else:
                is_dummy_batch = False

            def maybe_no_sync():
                """
                Whenever *samples* contains more than one mini-batch, we
                want to accumulate gradients locally and only call
                all-reduce in the last backwards pass.
                """
                if (self.args.distributed_world_size > 1
                        and hasattr(self.model, "no_sync")
                        and i < len(samples) - 1):
                    return self.model.no_sync()
                else:
                    return contextlib.ExitStack()  # dummy contextmanager

            try:
                with maybe_no_sync():
                    # forward and backward
                    loss, sample_size_i, logging_output = self.task.train_step(
                        sample=sample,
                        model=self.model,
                        criterion=self.criterion,
                        optimizer=self.optimizer,
                        update_num=self.get_num_updates(),
                        ignore_grad=is_dummy_batch,
                    )
                    del loss

                logging_outputs.append(logging_output)
                # dummy batches contribute no samples
                if not is_dummy_batch:
                    sample_size += sample_size_i

                # emptying the CUDA cache after the first step can
                # reduce the chance of OOM
                if self.cuda and self.get_num_updates() == 0:
                    torch.cuda.empty_cache()
            except RuntimeError as e:
                if "out of memory" in str(e):
                    self._log_oom(e)
                    if raise_oom:
                        raise e
                    logger.warning(
                        "attempting to recover from OOM in forward/backward pass"
                    )
                    ooms += 1
                    self.zero_grad()
                else:
                    raise e

        # gather logging outputs from all replicas
        if self._sync_stats():
            logging_outputs, (sample_size,
                              ooms) = self._aggregate_logging_outputs(
                                  logging_outputs,
                                  sample_size,
                                  ooms,
                                  ignore=is_dummy_batch,
                              )

        metrics.log_scalar("oom", ooms, len(samples), priority=600, round=3)
        if ooms == self.args.distributed_world_size * len(samples):
            logger.warning("OOM in all workers, skipping update")
            self.zero_grad()
            return None

        try:
            # normalize grads by sample size
            if sample_size > 0:
                if self._sync_stats():
                    # multiply gradients by (# GPUs / sample_size) since DDP
                    # already normalizes by the number of GPUs. Thus we get
                    # (sum_of_gradients / sample_size).
                    self.optimizer.multiply_grads(
                        self.args.distributed_world_size / sample_size)
                else:
                    self.optimizer.multiply_grads(1 / sample_size)

            # clip grads
            grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)

            # check that grad norms are consistent across workers
            if not self.args.use_bmuf:
                self._check_grad_norms(grad_norm)

            # take an optimization step
            self.optimizer.step()
            self.set_num_updates(self.get_num_updates() + 1)

            # log stats
            logging_output = self._reduce_and_log_stats(
                logging_outputs, sample_size)
            metrics.log_speed("ups",
                              1.,
                              ignore_first=10,
                              priority=100,
                              round=2)
            metrics.log_scalar("gnorm",
                               utils.item(grad_norm),
                               priority=400,
                               round=3)
            metrics.log_scalar(
                "clip",
                100 if grad_norm > self.args.clip_norm > 0 else 0,
                priority=500,
                round=1,
            )

            # clear CUDA cache to reduce memory fragmentation
            if (self.args.empty_cache_freq > 0 and
                ((self.get_num_updates() + self.args.empty_cache_freq - 1) %
                 self.args.empty_cache_freq) == 0
                    and torch.cuda.is_available() and not self.args.cpu):
                torch.cuda.empty_cache()
        except FloatingPointError:
            # re-run the forward and backward pass with hooks attached to print out where it fails.
            # update_num must be passed here as well (as in the main forward/backward
            # call); otherwise this diagnostic call raises TypeError and hides the error.
            with NanDetector(self.model):
                self.task.train_step(sample,
                                     self.model,
                                     self.criterion,
                                     self.optimizer,
                                     self.get_num_updates(),
                                     ignore_grad=False)
            raise
        except OverflowError as e:
            logger.info("NOTE: overflow detected, " + str(e))
            self.zero_grad()
            logging_output = None
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                logger.error("OOM during optimization, irrecoverable")
            raise e

        if self.args.fp16:
            metrics.log_scalar("loss_scale",
                               self.optimizer.scaler.loss_scale,
                               priority=700,
                               round=0)

        metrics.log_stop_time("train_wall")

        return logging_output
Beispiel #6
0
    def __init__(self, args, task, model, criterion, quantizer=None):
        """Create a trainer owning the model, criterion and training state.

        Args:
            args: parsed command-line namespace; reads ``cpu``, ``fp16``,
                ``bf16`` and, if present, ``tpu``.
            task: the task object driving training.
            model: model to train; cast and moved onto the selected device.
            criterion: loss criterion; cast and moved alongside the model.
            quantizer: optional quantizer; when given it is linked to this
                trainer via ``set_trainer``.
        """
        self.args = args
        self.task = task

        # Catalog tied parameters so they can be re-tied after the transfer.
        shared_params = _catalog_shared_params(model)

        on_tpu = getattr(args, 'tpu', False)
        self.tpu = on_tpu
        self.cuda = torch.cuda.is_available() and not args.cpu and not on_tpu
        if self.cuda:
            self.device = torch.device('cuda')
        elif on_tpu:
            self.device = utils.get_tpu_device(args)
        else:
            self.device = torch.device('cpu')

        # On TPU the model is first sent with the XLA helper; then the optional
        # fp16/bf16 cast is applied and both modules are placed on the device.
        crit, net = criterion, model
        if on_tpu:
            import torch_xla.core.xla_model as xm
            net = xm.send_cpu_data_to_device(net, self.device)
        if args.fp16:
            crit, net = crit.half(), net.half()
        elif args.bf16:
            crit, net = crit.to(dtype=torch.bfloat16), net.to(dtype=torch.bfloat16)
        self._criterion = crit.to(device=self.device)
        self._model = net.to(device=self.device)

        # Re-tie shared parameters after the device/dtype transfer.
        for shared_param in shared_params:
            head, tail = shared_param[0], shared_param[1:]
            ref = _get_module_by_path(self._model, head)
            for path in tail:
                logger.info(
                    'detected shared parameter: {} <- {}'.format(head, path)
                )
                _set_module_by_path(self._model, path, ref)

        self._dummy_batch = "DUMMY"  # indicates we don't have a dummy batch at first
        self._lr_scheduler = None
        self._num_updates = 0
        self._num_xla_compiles = 0  # for TPUs
        self._optim_history = None
        self._optimizer = None
        self._warn_once = set()
        self._wrapped_criterion = None
        self._wrapped_model = None

        # TODO(myleott): support tpu
        multi_gpu = self.cuda and self.data_parallel_world_size > 1
        self._grad_norm_buf = (
            torch.cuda.DoubleTensor(self.data_parallel_world_size)
            if multi_gpu else None
        )

        self.quantizer = quantizer
        if quantizer is not None:
            quantizer.set_trainer(self)

        # Collect the CUDA environment of every worker; rank 0 prints it.
        self.cuda_env = None
        self.cuda_env_arr = None
        if self.cuda:
            env = utils.CudaEnvironment()
            self.cuda_env = env
            if self.data_parallel_world_size > 1:
                self.cuda_env_arr = distributed_utils.all_gather_list(env)
            else:
                self.cuda_env_arr = [env]
            if self.data_parallel_rank == 0:
                utils.CudaEnvironment.pretty_print_cuda_env_list(self.cuda_env_arr)

        metrics.log_start_time("wall", priority=790, round=0)

        self._start_time = time.time()
        self._previous_training_time = 0
        self._cumulative_training_time = None
Beispiel #7
0
    def train_step(self, samples, raise_oom=False):
        """Do forward, backward and parameter update."""
        if self._dummy_batch == "DUMMY":
            self._dummy_batch = samples[0]

        self._set_seed()
        self.model.train()
        self.criterion.train()
        self.zero_grad()

        metrics.log_start_time("train_wall", priority=800, round=0)

        # forward and backward pass
        logging_outputs, sample_size, ooms = [], 0, 0
        for i, sample in enumerate(samples):
            sample = self._prepare_sample(sample)
            if sample is None:
                # when sample is None, run forward/backward on a dummy batch
                # and ignore the resulting gradients
                sample = self._prepare_sample(self._dummy_batch)
                is_dummy_batch = True
            else:
                is_dummy_batch = False

            def maybe_no_sync():
                """
                Whenever *samples* contains more than one mini-batch, we
                want to accumulate gradients locally and only call
                all-reduce in the last backwards pass.
                """
                if (self.data_parallel_world_size > 1
                        and hasattr(self.model, "no_sync")
                        and i < len(samples) - 1):
                    return self.model.no_sync()
                else:
                    return contextlib.ExitStack()  # dummy contextmanager

            try:
                with maybe_no_sync():
                    # forward and backward
                    loss, sample_size_i, logging_output = self.task.train_step(
                        sample=sample,
                        model=self.model,
                        criterion=self.criterion,
                        optimizer=self.optimizer,
                        update_num=self.get_num_updates(),
                        ignore_grad=is_dummy_batch,
                    )
                    del loss

                logging_outputs.append(logging_output)
                sample_size += sample_size_i

                # emptying the CUDA cache after the first step can
                # reduce the chance of OOM
                if self.cuda and self.get_num_updates() == 0:
                    torch.cuda.empty_cache()
            except RuntimeError as e:
                if "out of memory" in str(e):
                    self._log_oom(e)
                    if raise_oom:
                        raise e
                    logger.warning(
                        "attempting to recover from OOM in forward/backward pass"
                    )
                    ooms += 1
                    self.zero_grad()
                else:
                    raise e

            if self.tpu and i < len(samples) - 1:
                # tpu-comment: every XLA operation before marking step is
                # appended to the IR graph, and processing too many batches
                # before marking step can lead to OOM errors.
                # To handle gradient accumulation use case, we explicitly
                # mark step here for every forward pass without a backward pass
                import torch_xla.core.xla_model as xm
                xm.mark_step()

        if is_dummy_batch:
            if torch.is_tensor(sample_size):
                sample_size.zero_()
            else:
                sample_size *= 0.

        if torch.is_tensor(sample_size):
            sample_size = sample_size.float()
        else:
            sample_size = float(sample_size)

        # gather logging outputs from all replicas
        if self._sync_stats():
            logging_outputs, (sample_size,
                              ooms) = self._aggregate_logging_outputs(
                                  logging_outputs,
                                  sample_size,
                                  ooms,
                                  ignore=is_dummy_batch,
                              )

        overflow = False
        try:
            if self.tpu and self.data_parallel_world_size > 1:
                import torch_xla.core.xla_model as xm
                gradients = xm._fetch_gradients(self.optimizer.optimizer)
                xm.all_reduce('sum',
                              gradients,
                              scale=1.0 / self.data_parallel_world_size)

            # multiply gradients by (# GPUs / sample_size) since DDP
            # already normalizes by the number of GPUs. Thus we get
            # (sum_of_gradients / sample_size).
            if not self.args.use_bmuf:
                self.optimizer.multiply_grads(self.data_parallel_world_size /
                                              sample_size)
            elif sample_size > 0:  # BMUF needs to check sample size
                num = self.data_parallel_world_size if self._sync_stats(
                ) else 1
                self.optimizer.multiply_grads(num / sample_size)

            # clip grads
            grad_norm = self.clip_grad_norm(self.args.clip_norm)

            # check that grad norms are consistent across workers
            if (not self.args.use_bmuf
                    and self.args.distributed_wrapper != 'SlowMo'
                    and not self.tpu):
                self._check_grad_norms(grad_norm)

            # take an optimization step
            self.optimizer.step()
        except FloatingPointError:
            # re-run the forward and backward pass with hooks attached to print out where it fails
            with NanDetector(self.model):
                self.task.train_step(sample,
                                     self.model,
                                     self.criterion,
                                     self.optimizer,
                                     self.get_num_updates(),
                                     ignore_grad=False)
            raise
        except OverflowError as e:
            overflow = True
            logger.info("NOTE: overflow detected, " + str(e))
            grad_norm = torch.tensor(0.).cuda()
            self.zero_grad()
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                logger.error("OOM during optimization, irrecoverable")
            raise e

        # Some distributed wrappers (e.g., SlowMo) need access to the optimizer after the step
        if hasattr(self.model, 'perform_additional_optimizer_actions'):
            if hasattr(self.optimizer, 'fp32_params'):
                self.model.perform_additional_optimizer_actions(
                    self.optimizer.optimizer, self.optimizer.fp32_params)
            else:
                self.model.perform_additional_optimizer_actions(
                    self.optimizer.optimizer)

        if not overflow or self.args.distributed_wrapper == 'SlowMo':
            self.set_num_updates(self.get_num_updates() + 1)

            if self.tpu:
                # mark step on TPUs
                import torch_xla.core.xla_model as xm
                xm.mark_step()

                # only log stats every log_interval steps
                # this causes wps to be misreported when log_interval > 1
                logging_output = {}
                if self.get_num_updates() % self.args.log_interval == 0:
                    logging_output = self._reduce_and_log_stats(
                        logging_outputs,
                        sample_size,
                        grad_norm,
                    )

                # log whenever there's an XLA compilation, since these
                # slow down training and may indicate opportunities for
                # optimization
                self._check_xla_compilation()
            else:
                # log stats
                logging_output = self._reduce_and_log_stats(
                    logging_outputs,
                    sample_size,
                    grad_norm,
                )

                # clear CUDA cache to reduce memory fragmentation
                if (self.cuda and self.args.empty_cache_freq > 0 and (
                    (self.get_num_updates() + self.args.empty_cache_freq - 1) %
                        self.args.empty_cache_freq) == 0):
                    torch.cuda.empty_cache()

        if self.args.fp16:
            metrics.log_scalar("loss_scale",
                               self.optimizer.scaler.loss_scale,
                               priority=700,
                               round=0)

        metrics.log_stop_time("train_wall")

        return logging_output
Beispiel #8
0
    def train_step(self, samples, raise_oom=False):
        """Do forward, backward and parameter update.

        Besides the usual forward/backward, gradient rescaling, clipping and
        optimizer step, this variant accumulates per-batch "lambda" statistics
        via ``self.task.collect_lambda_stats`` together with the sentence
        count. When stats are synced, it sums both across workers with
        ``torch.distributed.all_reduce``. It then hands them to
        ``self.task.distributed_update_lambda`` before the optimizer step.

        Args:
            samples: list of mini-batches accumulated into a single update.
                A mini-batch for which ``_prepare_sample`` returns ``None`` is
                replaced by the cached dummy batch, and its gradients are
                ignored.
            raise_oom: re-raise "out of memory" RuntimeErrors from the
                forward/backward pass instead of skipping the mini-batch.

        Returns:
            The output of ``_reduce_and_log_stats``, or ``None`` if an
            ``OverflowError`` was raised during the update.

        Raises:
            FloatingPointError: re-raised after re-running the last sample
                under ``NanDetector``.
            RuntimeError: on OOM during optimization, or any non-OOM runtime
                error.
        """
        # remember the first batch so it can stand in as a dummy batch later
        if self._dummy_batch == "DUMMY":
            self._dummy_batch = samples[0]

        self._set_seed()
        self.model.train()
        self.criterion.train()
        self.zero_grad()

        metrics.log_start_time("train_wall", priority=800, round=0)

        # forward and backward pass
        logging_outputs, sample_size, ooms = [], 0, 0

        # added by Junxian
        # Running sums for the lambda update. lambda_stats_sum starts as the
        # int 0 and becomes whatever collect_lambda_stats returns (presumably a
        # tensor, since it is all-reduced below).
        # NOTE(review): if every micro-batch OOMs, lambda_stats_sum stays the
        # int 0 and the all_reduce below would fail on this rank only.
        lambda_stats_sum = 0
        nsentences = 0
        for i, sample in enumerate(samples):
            sample = self._prepare_sample(sample)
            if sample is None:
                # when sample is None, run forward/backward on a dummy batch
                # and ignore the resulting gradients
                sample = self._prepare_sample(self._dummy_batch)
                is_dummy_batch = True
            else:
                is_dummy_batch = False

            def maybe_no_sync():
                """
                Whenever *samples* contains more than one mini-batch, we
                want to accumulate gradients locally and only call
                all-reduce in the last backwards pass.
                """
                if (self.args.distributed_world_size > 1
                        and hasattr(self.model, "no_sync")
                        and i < len(samples) - 1):
                    return self.model.no_sync()
                else:
                    return contextlib.ExitStack()  # dummy contextmanager

            try:
                with maybe_no_sync():
                    if is_dummy_batch:
                        print("dummy batch!")

                    # try:
                    # forward and backward
                    loss, sample_size_i, logging_output = self.task.train_step(
                        sample=sample,
                        model=self.model,
                        criterion=self.criterion,
                        optimizer=self.optimizer,
                        update_num=self.get_num_updates(),
                        ignore_grad=is_dummy_batch,
                    )
                    del loss

                    logging_outputs.append(logging_output)
                    sample_size += sample_size_i
                    # NOTE(review): dummy batches are still counted here (and in
                    # lambda_stats_sum below), although their gradients are
                    # ignored and their sample_size is zeroed later -- confirm
                    # this is intended.
                    nsentences += logging_output['nsentences']
                    # except:
                    #     pass

                    # Added by Junxian: manually update lambda (support distributed training)
                    # Collected without autograd so it adds nothing to the graph.
                    with torch.no_grad():
                        lambda_stats_sum_i = self.task.collect_lambda_stats(
                            self.model, sample)
                        lambda_stats_sum = lambda_stats_sum + lambda_stats_sum_i

                # emptying the CUDA cache after the first step can
                # reduce the chance of OOM
                if self.cuda and self.get_num_updates() == 0:
                    torch.cuda.empty_cache()
            except RuntimeError as e:
                if "out of memory" in str(e):
                    self._log_oom(e)
                    if raise_oom:
                        raise e
                    logger.warning(
                        "attempting to recover from OOM in forward/backward pass"
                    )
                    ooms += 1
                    self.zero_grad()
                else:
                    raise e

        # only the *last* micro-batch's dummy flag decides whether the whole
        # accumulated sample_size is zeroed
        if is_dummy_batch:
            if torch.is_tensor(sample_size):
                sample_size.zero_()
            else:
                sample_size *= 0.0

        if torch.is_tensor(sample_size):
            sample_size = sample_size.float()
        else:
            sample_size = float(sample_size)

        # gather logging outputs from all replicas
        if self._sync_stats():
            logging_outputs, (sample_size,
                              ooms) = self._aggregate_logging_outputs(
                                  logging_outputs,
                                  sample_size,
                                  ooms,
                                  ignore=is_dummy_batch,
                              )

        try:
            # multiply gradients by (# GPUs / sample_size) since DDP
            # already normalizes by the number of GPUs. Thus we get
            # (sum_of_gradients / sample_size).
            if not self.args.use_bmuf:
                self.optimizer.multiply_grads(
                    self.args.distributed_world_size / sample_size)
            elif sample_size > 0:  # BMUF needs to check sample size
                num = self.args.distributed_world_size if self._sync_stats(
                ) else 1
                self.optimizer.multiply_grads(num / sample_size)

            # clip grads
            grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)

            # Added by Junxian: manually update lambda (support distributed training)
            with torch.no_grad():
                # Collective calls: every rank must reach these in the same
                # order, otherwise the job hangs.
                if self._sync_stats():
                    torch.distributed.all_reduce(
                        lambda_stats_sum, op=torch.distributed.ReduceOp.SUM)
                    # print('nsentences {} per gpu'.format(nsentences))
                    nsentences_t = torch.tensor(nsentences, device=self.device)
                    torch.distributed.all_reduce(
                        nsentences_t, op=torch.distributed.ReduceOp.SUM)
                    nsentences = nsentences_t.item()

                    # print('nsentences total {}'.format(nsentences))

                # TODO(junxian): is_dummy_batch might be different across GPUs and would
                # potentially cause lambda mismatch among GPUs
                self.task.distributed_update_lambda(
                    model=self.model,
                    lambda_stats_sum=lambda_stats_sum,
                    nsentences=nsentences,
                    update_num=self.get_num_updates(),
                    ignore_grad=is_dummy_batch)

            # check that grad norms are consistent across workers
            if not self.args.use_bmuf:
                self._check_grad_norms(grad_norm)

                # added by Junxian to check some manually updated params
                # Reuses the grad-norm consistency check to verify that the
                # largest lambda value agrees across workers.
                # self._check_grad_norms(torch.tensor([self.model._lambda_t], device=self.device))
                self._check_grad_norms(self.model.get_lambda().max())

            # take an optimization step
            self.optimizer.step()
            self.set_num_updates(self.get_num_updates() + 1)

            # log stats
            logging_output = self._reduce_and_log_stats(
                logging_outputs,
                sample_size,
                grad_norm,
            )

            # clear CUDA cache to reduce memory fragmentation
            if (self.args.empty_cache_freq > 0 and
                ((self.get_num_updates() + self.args.empty_cache_freq - 1) %
                 self.args.empty_cache_freq) == 0
                    and torch.cuda.is_available() and not self.args.cpu):
                torch.cuda.empty_cache()
        except FloatingPointError:
            # re-run the forward and backward pass with hooks attached to print out where it fails
            with NanDetector(self.model):
                self.task.train_step(sample,
                                     self.model,
                                     self.criterion,
                                     self.optimizer,
                                     self.get_num_updates(),
                                     ignore_grad=False)
            raise
        except OverflowError as e:
            # skip this update entirely; callers receive None
            logger.info("NOTE: overflow detected, " + str(e))
            self.zero_grad()
            logging_output = None
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                logger.error("OOM during optimization, irrecoverable")
            raise e

        if self.args.fp16:
            metrics.log_scalar("loss_scale",
                               self.optimizer.scaler.loss_scale,
                               priority=700,
                               round=0)

        metrics.log_stop_time("train_wall")

        return logging_output
Beispiel #9
0
    def train_step(self, samples, raise_oom=False):
        """Do forward, backward and parameter update.

        Runs forward/backward on every mini-batch in *samples*. It accumulates
        gradients locally and only all-reduces on the last one when the model
        has ``no_sync``. It then rescales and clips the gradients, takes one
        optimizer step, and logs the aggregated statistics. Optional optimizer
        diagnostics (``adaptive_lrs``, ``update_size``, ``valid_ratio``,
        ``adaptive_beta``, ``var_adapt``) are included when the unwrapped
        optimizer exposes them.

        Args:
            samples: list of mini-batches accumulated into a single update.
                A mini-batch for which ``_prepare_sample`` returns ``None`` is
                replaced by the cached dummy batch, and its gradients are
                ignored.
            raise_oom: re-raise "out of memory" RuntimeErrors from the
                forward/backward pass instead of skipping the mini-batch.

        Returns:
            The output of ``_reduce_and_log_stats``, or ``None`` when the
            update was skipped because of an ``OverflowError``. With the
            SlowMo wrapper, the update is still counted and logged.

        Raises:
            FloatingPointError: re-raised after re-running the last sample
                under ``NanDetector``.
            RuntimeError: on OOM during optimization, or any non-OOM runtime
                error.
        """
        # remember the first batch so it can stand in as a dummy batch later
        if self._dummy_batch == "DUMMY":
            self._dummy_batch = samples[0]

        self._set_seed()
        self.model.train()
        self.criterion.train()
        self.zero_grad()

        metrics.log_start_time("train_wall", priority=800, round=0)

        # forward and backward pass
        logging_outputs, sample_size, ooms = [], 0, 0
        for i, sample in enumerate(samples):
            sample = self._prepare_sample(sample)
            if sample is None:
                # when sample is None, run forward/backward on a dummy batch
                # and ignore the resulting gradients
                sample = self._prepare_sample(self._dummy_batch)
                is_dummy_batch = True
            else:
                is_dummy_batch = False

            def maybe_no_sync():
                """
                Whenever *samples* contains more than one mini-batch, we
                want to accumulate gradients locally and only call
                all-reduce in the last backwards pass.
                """
                if (
                    self.data_parallel_world_size > 1
                    and hasattr(self.model, "no_sync")
                    and i < len(samples) - 1
                ):
                    return self.model.no_sync()
                else:
                    return contextlib.ExitStack()  # dummy contextmanager

            try:
                with maybe_no_sync():
                    # forward and backward
                    loss, sample_size_i, logging_output = self.task.train_step(
                        sample=sample,
                        model=self.model,
                        criterion=self.criterion,
                        optimizer=self.optimizer,
                        update_num=self.get_num_updates(),
                        ignore_grad=is_dummy_batch,
                    )
                    del loss

                logging_outputs.append(logging_output)
                sample_size += sample_size_i

                # emptying the CUDA cache after the first step can
                # reduce the chance of OOM
                if self.cuda and self.get_num_updates() == 0:
                    torch.cuda.empty_cache()
            except RuntimeError as e:
                if "out of memory" in str(e):
                    self._log_oom(e)
                    if raise_oom:
                        raise e
                    logger.warning(
                        "attempting to recover from OOM in forward/backward pass"
                    )
                    ooms += 1
                    self.zero_grad()
                else:
                    raise e

        if torch.is_tensor(sample_size):
            sample_size = sample_size.float()
        else:
            sample_size = float(sample_size)

        if is_dummy_batch:
            sample_size *= 0.  # multiply by 0 to preserve device

        # gather logging outputs from all replicas
        if self._sync_stats():
            logging_outputs, (sample_size, ooms) = self._aggregate_logging_outputs(
                logging_outputs, sample_size, ooms, ignore=is_dummy_batch,
            )

        overflow = False
        try:
            # multiply gradients by (# GPUs / sample_size) since DDP
            # already normalizes by the number of GPUs. Thus we get
            # (sum_of_gradients / sample_size).
            if not self.args.use_bmuf:
                multiplier = self.data_parallel_world_size
                self.optimizer.multiply_grads(
                    multiplier / sample_size
                )
            elif sample_size > 0:  # BMUF needs to check sample size
                num = self.data_parallel_world_size if self._sync_stats() else 1
                self.optimizer.multiply_grads(num / sample_size)

            # clip grads
            grad_norm = self.clip_grad_norm(self.args.clip_norm)

            # check that grad norms are consistent across workers
            if not self.args.use_bmuf and self.args.distributed_wrapper != 'SlowMo':
                self._check_grad_norms(grad_norm)

            # take an optimization step
            self.optimizer.step()

        except FloatingPointError:
            # re-run the forward and backward pass with hooks attached to print out where it fails
            with NanDetector(self.model):
                self.task.train_step(
                    sample, self.model, self.criterion, self.optimizer, self.get_num_updates(),
                    ignore_grad=False
                )
            raise
        except OverflowError as e:
            overflow = True
            logger.info("NOTE: overflow detected, " + str(e))
            # zero norm reported when the update is skipped; only move it to
            # the GPU when training on CUDA (``.cuda()`` fails on CPU-only runs)
            grad_norm = torch.tensor(0.)
            if self.cuda:
                grad_norm = grad_norm.cuda()
            self.zero_grad()
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                logger.error("OOM during optimization, irrecoverable")
            raise e

        # Some distributed wrappers (e.g., SlowMo) need access to the optimizer after the step
        if hasattr(self.model, 'perform_additional_optimizer_actions'):
            if hasattr(self.optimizer, 'fp32_params'):
                self.model.perform_additional_optimizer_actions(self.optimizer.optimizer, self.optimizer.fp32_params)
            else:
                self.model.perform_additional_optimizer_actions(self.optimizer.optimizer)

        # Must be reset here: otherwise a skipped (overflowed) update would
        # return the stale, unreduced per-micro-batch ``logging_output`` left
        # over from the forward loop (or hit UnboundLocalError if every
        # micro-batch OOM'd) instead of signalling the skip with None.
        logging_output = None
        if not overflow or self.args.distributed_wrapper == 'SlowMo':
            self.set_num_updates(self.get_num_updates() + 1)

            # Unwrap to the underlying optimizer: prefer the inner optimizer of
            # an ``fp32_optimizer`` wrapper (presumably the FP16 case), else
            # use ``self.optimizer._optimizer`` directly.
            try:
                opt_ = self.optimizer.fp32_optimizer._optimizer
            except AttributeError:
                opt_ = self.optimizer._optimizer

            # diagnostics not computed by this trainer; forwarded as None
            gvar = None
            gvar_diff = None
            xstd = None
            ams_mom = None
            acc_ratio = None
            real_var = None
            real_var_diff = None

            adam_mom2 = None
            if self.args.optimizer == "varscale_sgd":
                # mean of the 'g_sq_est' state of the first parameter of the
                # first param group (looked up only when actually needed)
                states = opt_.state[opt_.param_groups[0]['params'][0]]
                adam_mom2 = torch.mean(states['g_sq_est']).item()

            if getattr(opt_, "adaptive_lrs", None) is not None:
                lr_min, lr_max, lr_median = opt_.adaptive_lrs
            else:
                lr_min, lr_max, lr_median = None, None, None

            if getattr(opt_, "update_size", None) is not None:
                update_min, update_max, update_median = opt_.update_size
            else:
                update_min, update_max, update_median = None, None, None

            valid_ratio = getattr(opt_, "valid_ratio", None)

            ad_beta = getattr(opt_, "adaptive_beta", None)

            var_adapt = getattr(opt_, "var_adapt", None)

            # log stats
            logging_output = self._reduce_and_log_stats(
                logging_outputs, sample_size, grad_norm, gvar=gvar, adam_mom2=adam_mom2,
                gvar_diff=gvar_diff, xstd=xstd, ams_mom=ams_mom, acc_ratio=acc_ratio, real_var=real_var,
                real_var_diff=real_var_diff, ad_beta=ad_beta, lr_min=lr_min, lr_max=lr_max, lr_median=lr_median,
                update_min=update_min, update_max=update_max, update_median=update_median, valid_ratio=valid_ratio,
                var_adapt=var_adapt
            )

        # clear CUDA cache to reduce memory fragmentation
        if (
            self.args.empty_cache_freq > 0
            and (
                (self.get_num_updates() + self.args.empty_cache_freq - 1)
                % self.args.empty_cache_freq
            ) == 0
            and torch.cuda.is_available()
            and not self.args.cpu
        ):
            torch.cuda.empty_cache()

        if self.args.fp16:
            metrics.log_scalar("loss_scale", self.optimizer.scaler.loss_scale, priority=700, round=0)

        metrics.log_stop_time("train_wall")

        return logging_output
Beispiel #10
0
    def train_step(self, samples, raise_oom=False):
        """Do forward, backward and parameter update.

        Besides the usual accumulate / rescale / clip / step sequence, this
        variant tracks the ratio between the norms of the ``fc2.weight``
        tensors of decoder layers 0 and 5.

        * Before each step (from the second update on), the current gradient
          norm ratio is scaled by ``(1 - beta3**n) / self.ratio``. If the
          result falls outside ``[0.75, 1.5]`` during the first 4000 updates,
          ``self.optimizer._optimizer.decay`` is set to True for that single
          step. It is reset to False afterwards; presumably the custom
          optimizer reads this flag in ``step()``.
        * ``self.ratio`` is an EMA (``beta3 = 0.99``) of the gradient-norm
          ratio. After the very first update it is additionally updated from
          the norms of the optimizer's ``exp_avg`` state. It must already
          exist on ``self``; its initial value is set elsewhere.

        Steps whose layer-5 norm is zero or NaN do not update the ratio (for
        example, presumably a dummy batch whose gradients are ignored);
        previously these raised ZeroDivisionError or poisoned ``self.ratio``.

        Args:
            samples: list of mini-batches accumulated into a single update.
                A mini-batch for which ``_prepare_sample`` returns ``None`` is
                replaced by the cached dummy batch, and its gradients are
                ignored.
            raise_oom: re-raise "out of memory" RuntimeErrors from the
                forward/backward pass instead of skipping the mini-batch.

        Returns:
            The output of ``_reduce_and_log_stats``, or ``None`` if an
            ``OverflowError`` was raised during the update.

        Raises:
            FloatingPointError: re-raised after re-running the last sample
                under ``NanDetector``.
            RuntimeError: on OOM during optimization, or any non-OOM runtime
                error.
            KeyError: if the model has no decoder layer 0 / layer 5
                ``fc2.weight`` parameter.
        """
        # remember the first batch so it can stand in as a dummy batch later
        if self._dummy_batch == "DUMMY":
            self._dummy_batch = samples[0]

        self._set_seed()
        self.model.train()
        self.criterion.train()
        self.zero_grad()

        metrics.log_start_time("train_wall", priority=800, round=0)

        # forward and backward pass
        logging_outputs, sample_size, ooms = [], 0, 0
        for i, sample in enumerate(samples):
            sample = self._prepare_sample(sample)
            if sample is None:
                # when sample is None, run forward/backward on a dummy batch
                # and ignore the resulting gradients
                sample = self._prepare_sample(self._dummy_batch)
                is_dummy_batch = True
            else:
                is_dummy_batch = False

            def maybe_no_sync():
                """
                Whenever *samples* contains more than one mini-batch, we
                want to accumulate gradients locally and only call
                all-reduce in the last backwards pass.
                """
                if (self.data_parallel_world_size > 1
                        and hasattr(self.model, "no_sync")
                        and i < len(samples) - 1):
                    return self.model.no_sync()
                else:
                    return contextlib.ExitStack()  # dummy contextmanager

            try:
                with maybe_no_sync():
                    # forward and backward
                    loss, sample_size_i, logging_output = self.task.train_step(
                        sample=sample,
                        model=self.model,
                        criterion=self.criterion,
                        optimizer=self.optimizer,
                        update_num=self.get_num_updates(),
                        ignore_grad=is_dummy_batch,
                    )
                    del loss

                logging_outputs.append(logging_output)
                sample_size += sample_size_i

                # emptying the CUDA cache after the first step can
                # reduce the chance of OOM
                if self.cuda and self.get_num_updates() == 0:
                    torch.cuda.empty_cache()
            except RuntimeError as e:
                if "out of memory" in str(e):
                    self._log_oom(e)
                    if raise_oom:
                        raise e
                    logger.warning(
                        "attempting to recover from OOM in forward/backward pass"
                    )
                    ooms += 1
                    self.zero_grad()
                else:
                    raise e

        if torch.is_tensor(sample_size):
            sample_size = sample_size.float()
        else:
            sample_size = float(sample_size)

        if is_dummy_batch:
            sample_size *= 0.  # multiply by 0 to preserve device

        # gather logging outputs from all replicas
        if self._sync_stats():
            logging_outputs, (sample_size,
                              ooms) = self._aggregate_logging_outputs(
                                  logging_outputs,
                                  sample_size,
                                  ooms,
                                  ignore=is_dummy_batch,
                              )

        def fc2_norms(use_exp_avg):
            """Return norms of the decoder layer-0 and layer-5 ``fc2.weight``.

            Uses the parameter gradients, or the optimizer's ``exp_avg`` state
            when *use_exp_avg* is true. Parameters are matched by substring of
            their name; the last match for each layer wins.
            """
            norms = {}
            for name, param in self.model.named_parameters():
                if 'decoder.layers.0.fc2.weight' in name:
                    layer = 0
                elif 'decoder.layers.5.fc2.weight' in name:
                    layer = 5
                else:
                    continue
                if use_exp_avg:
                    tensor = self.optimizer._optimizer.state[param]['exp_avg']
                else:
                    tensor = param.grad
                norms[layer] = tensor.data.float().norm()
            missing = [layer for layer in (0, 5) if layer not in norms]
            if missing:
                raise KeyError(
                    'no decoder.layers.{}.fc2.weight parameter found'.format(
                        missing[0]))
            return norms[0], norms[5]

        try:
            # multiply gradients by (# GPUs / sample_size) since DDP
            # already normalizes by the number of GPUs. Thus we get
            # (sum_of_gradients / sample_size).
            if not self.args.use_bmuf:
                multiplier = self.data_parallel_world_size
                self.optimizer.multiply_grads(multiplier / sample_size)
            elif sample_size > 0:  # BMUF needs to check sample size
                num = self.data_parallel_world_size if self._sync_stats(
                ) else 1
                self.optimizer.multiply_grads(num / sample_size)

            # clip grads
            grad_norm = self.clip_grad_norm(self.args.clip_norm)

            # check that grad norms are consistent across workers
            if not self.args.use_bmuf:
                self._check_grad_norms(grad_norm)

            beta3 = 0.99
            num_updates = self.get_num_updates()
            if num_updates > 0:
                layer0, layer5 = fc2_norms(use_exp_avg=False)
                denom = layer5.item()
                # a zero or NaN layer-5 norm carries no signal and would
                # divide by zero / poison the running ratio, so skip it
                if denom > 0:
                    current_ratio = layer0.item() / denom
                    if self.ratio > 0:
                        decay_ratio = current_ratio * (
                            1 - beta3**num_updates) / self.ratio
                        if ((decay_ratio > 1.5 or decay_ratio < 0.75)
                                and num_updates < 4000):
                            # one-shot flag, cleared right after step() below
                            self.optimizer._optimizer.decay = True
                    self.ratio = beta3 * self.ratio + (1 - beta3) * current_ratio

            # take an optimization step; clear the decay flag even if the
            # step raises, so it cannot leak into the next update
            try:
                self.optimizer.step()
                self.set_num_updates(self.get_num_updates() + 1)
            finally:
                self.optimizer._optimizer.decay = False

            # after the very first update, fold the exp_avg norm ratio into
            # the running ratio
            if self.get_num_updates() == 1:
                layer0, layer5 = fc2_norms(use_exp_avg=True)
                denom = layer5.item()
                if denom > 0:
                    current_ratio = layer0.item() / denom
                    self.ratio = beta3 * self.ratio + (1 - beta3) * current_ratio

            # log stats
            logging_output = self._reduce_and_log_stats(
                logging_outputs,
                sample_size,
                grad_norm,
            )

            # clear CUDA cache to reduce memory fragmentation
            if (self.args.empty_cache_freq > 0 and
                ((self.get_num_updates() + self.args.empty_cache_freq - 1) %
                 self.args.empty_cache_freq) == 0
                    and torch.cuda.is_available() and not self.args.cpu):
                torch.cuda.empty_cache()
        except FloatingPointError:
            # re-run the forward and backward pass with hooks attached to print out where it fails
            with NanDetector(self.model):
                self.task.train_step(sample,
                                     self.model,
                                     self.criterion,
                                     self.optimizer,
                                     self.get_num_updates(),
                                     ignore_grad=False)
            raise
        except OverflowError as e:
            # skip this update entirely; callers receive None
            logger.info("NOTE: overflow detected, " + str(e))
            self.zero_grad()
            logging_output = None
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                logger.error("OOM during optimization, irrecoverable")
            raise e

        if self.args.fp16:
            metrics.log_scalar("loss_scale",
                               self.optimizer.scaler.loss_scale,
                               priority=700,
                               round=0)

        metrics.log_stop_time("train_wall")

        return logging_output
Beispiel #11
0
    def train_step(self, samples, raise_oom=False):
        """Do forward, backward and parameter update."""
        if self._dummy_batch == "DUMMY":
            self._dummy_batch = samples[0]

        self._set_seed()
        #self.model.train()
        #self.criterion.train()
        #self.zero_grad()
        #print(len(samples))
        #print('In ORT Train Step')
        metrics.log_start_time("train_wall", priority=800, round=0)

        # forward and backward pass
        logging_outputs, sample_size, ooms = [], 0, 0
        for i, sample in enumerate(samples):
            sample = self._prepare_sample(sample)

            if sample is not None:
                #print ('Token Ids: ', sample['id'])
                '''
                net_input = sample['net_input']
                src_tokens = net_input['src_tokens']
                print('ORT_TRAIN_STEP: src_tokens size: {}'.format(src_tokens.size()))
                src_lengths = net_input['src_lengths']
                prev_output_tokens = net_input['prev_output_tokens']
                target = sample['target']
                target = target.view(-1)
                print('ORT_TRAIN_STEP: src_lengths size: {}'.format(src_lengths.size()))
                print('ORT_TRAIN_STEP: prev_output_tokens size: {}'.format(prev_output_tokens.size()))
                print('ORT_TRAIN_STEP: target size: {}'.format(target.size()))
                if (src_lengths.size(0) != 3):
                    print('src_lengths incorrect size', src_lengths.size(0))
                    sample = None
                '''
            if sample is None:
                # when sample is None, run forward/backward on a dummy batch
                # and ignore the resulting gradients
                sample = self._prepare_sample(self._dummy_batch)
                is_dummy_batch = True
            else:
                sample = self._prepare_sample(sample)
                is_dummy_batch = False

            #for key, value in sample.items():
                #print('Sample key: {}'.format(key))
            '''
            # Visualize model
            model_desc = ort_supplement.bart_model_description(self.args)
        
            # example: {input0:{0:'batch'}, input1:{0:'batch'}}
            dynamic_axes = {}
            for input in model_desc.inputs_:
                symbolic_axis = {}
                for i, axis in enumerate(input.shape_):
                    if isinstance(axis, str):
                        symbolic_axis[i] = axis
                if len(symbolic_axis):
                    dynamic_axes[input.name_] = symbolic_axis

            for output in model_desc.outputs_:
                symbolic_axis = {}
                for i, axis in enumerate(output.shape_):
                    if isinstance(axis, str):
                        symbolic_axis[i] = axis
                if len(symbolic_axis):
                    dynamic_axes[output.name_] = symbolic_axis

            net_input = sample['net_input']
            src_tokens = net_input['src_tokens']
            src_lengths = net_input['src_lengths']
            prev_output_tokens = net_input['prev_output_tokens']
            target = sample['target']
            target = target.view(-1)
            src_tokens.cpu()
            src_lengths.cpu()
            prev_output_tokens.cpu()
            target.cpu()
            #self._model.cuda()
            input_names = [input.name_ for input in model_desc.inputs_]
            output_names = [output.name_ for output in model_desc.outputs_]

            self._model.eval()
            with torch.no_grad():
                sample_outputs = self._model(src_tokens, src_lengths, prev_output_tokens, target)
            if isinstance(sample_outputs, torch.Tensor):
                sample_outputs = [sample_outputs]
            for sample_output, output_desc in zip(sample_outputs, model_desc.outputs_):
                output_desc.dtype_ = sample_output.dtype
            self._model.train()
            import io
            f = io.BytesIO()

            # Other export options to use(this is for backward compatibility).
            other_export_options = {}
            other_export_options['training'] = True

            torch.onnx._export(self._model, tuple([src_tokens, src_lengths, prev_output_tokens, target]), f,
                    input_names=input_names,
                    output_names=output_names,
                    opset_version=12,
                    dynamic_axes=dynamic_axes,
                    _retain_param_name=True,
                    example_outputs=tuple(sample_outputs),
                    do_constant_folding=False,
                    **other_export_options)
            '''
            def maybe_no_sync():
                """
                Pick the context manager that wraps this mini-batch's
                forward/backward pass.

                When ``self.data_parallel_world_size > 1``, the model exposes
                ``no_sync``, and the current mini-batch (index ``i`` of
                ``samples``) is not the last one, return ``self.model.no_sync()``.
                That accumulates gradients locally so all-reduce only happens
                on the final backward pass. Otherwise return an empty
                ``contextlib.ExitStack()``, which acts as a no-op context
                manager.
                """
                # `i` and `samples` come from the enclosing loop; this closure
                # is called immediately in the same iteration, so late binding
                # does not matter here.
                defer_grad_sync = (
                    self.data_parallel_world_size > 1
                    and hasattr(self.model, "no_sync")
                    and i < len(samples) - 1
                )
                if not defer_grad_sync:
                    return contextlib.ExitStack()  # dummy contextmanager
                return self.model.no_sync()

            try:
                with maybe_no_sync():
                    # forward and backward
                    loss, sample_size_i, logging_output = ort_supplement.ort_train_step(
                        self.args,
                        update_num=self.get_num_updates(),
                        model=self.model,
                        sample=sample,
                    )
                    del loss

                logging_outputs.append(logging_output)
                sample_size += sample_size_i

                # emptying the CUDA cache after the first step can
                # reduce the chance of OOM
                if self.cuda and self.get_num_updates() == 0:
                    torch.cuda.empty_cache()
            except RuntimeError as e:
                if "out of memory" in str(e):
                    self._log_oom(e)
                    if raise_oom:
                        raise e
                    logger.warning(
                        "attempting to recover from OOM in forward/backward pass"
                    )
                    ooms += 1
                    self.zero_grad()
                else:
                    raise e

        if is_dummy_batch:
            if torch.is_tensor(sample_size):
                sample_size.zero_()
            else:
                sample_size *= 0.

        if torch.is_tensor(sample_size):
            sample_size = sample_size.float()
        else:
            sample_size = float(sample_size)

        # gather logging outputs from all replicas
        if self._sync_stats():
            logging_outputs, (sample_size, ooms) = self._aggregate_logging_outputs(
                logging_outputs, sample_size, ooms, ignore=is_dummy_batch,
            )

        overflow = False
        '''
        try:
            # multiply gradients by (# GPUs / sample_size) since DDP
            # already normalizes by the number of GPUs. Thus we get
            # (sum_of_gradients / sample_size).
            if not self.args.use_bmuf:
                self.optimizer.multiply_grads(self.data_parallel_world_size / sample_size)
            elif sample_size > 0:  # BMUF needs to check sample size
                num = self.data_parallel_world_size if self._sync_stats() else 1
                self.optimizer.multiply_grads(num / sample_size)

            # clip grads
            grad_norm = self.clip_grad_norm(self.args.clip_norm)

            # check that grad norms are consistent across workers
            if (
                not self.args.use_bmuf
                and self.args.distributed_wrapper != 'SlowMo'
                and not self.tpu
            ):
                self._check_grad_norms(grad_norm)

            # take an optimization step
            self.optimizer.step()
        except FloatingPointError:
            # re-run the forward and backward pass with hooks attached to print
            # out where it fails
            with NanDetector(self.model):
                self.task.train_step(
                    sample, self.model, self.criterion, self.optimizer, self.get_num_updates(),
                    ignore_grad=False
                )
            raise
        except OverflowError as e:
            overflow = True
            logger.info("NOTE: overflow detected, " + str(e))
            grad_norm = torch.tensor(0.).cuda()
            self.zero_grad()
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                logger.error("OOM during optimization, irrecoverable")
            raise e

        # Some distributed wrappers (e.g., SlowMo) need access to the optimizer after the step
        if hasattr(self.model, 'perform_additional_optimizer_actions'):
            if hasattr(self.optimizer, 'fp32_params'):
                self.model.perform_additional_optimizer_actions(self.optimizer.optimizer, self.optimizer.fp32_params)
            else:
                self.model.perform_additional_optimizer_actions(self.optimizer.optimizer)
        '''

        if not overflow or self.args.distributed_wrapper == 'SlowMo':
            self.set_num_updates(self.get_num_updates() + 1)

            # log stats
            logging_output = self._reduce_and_log_stats(logging_outputs, sample_size)
            '''
            # clear CUDA cache to reduce memory fragmentation
            if (
                self.cuda
                and self.args.empty_cache_freq > 0
                and (
                    (self.get_num_updates() + self.args.empty_cache_freq - 1)
                    % self.args.empty_cache_freq
                ) == 0
            ):
                torch.cuda.empty_cache()
            '''
        #if self.args.fp16:
            #metrics.log_scalar("loss_scale", self.optimizer.scaler.loss_scale, priority=700, round=0)

        metrics.log_stop_time("train_wall")

        return logging_output