Example #1
    def filter_indices_by_size(
        self, indices, dataset, max_positions=None, ignore_invalid_inputs=False
    ):
        """
        Filter examples that are too large

        Args:
            indices (np.array): original array of sample indices
            dataset (~fairseq.data.FairseqDataset): dataset to batch
            max_positions (optional): max sentence length supported by the
                model (default: None).
            ignore_invalid_inputs (bool, optional): don't raise Exception for
                sentences that are too long (default: False).
        Returns:
            np.array: array of filtered sample indices
        """
        indices, ignored = dataset.filter_indices_by_size(indices, max_positions)
        if len(ignored) > 0:
            if not ignore_invalid_inputs:
                raise Exception(
                    (
                        "Size of sample #{} is invalid (={}) since max_positions={}, "
                        "skip this example with --skip-invalid-size-inputs-valid-test"
                    ).format(ignored[0], dataset.size(ignored[0]), max_positions)
                )
            LOGGER.warning(
                (
                    "{:,} samples have invalid sizes and will be skipped, "
                    "max_positions={}, first few sample ids={}"
                ).format(len(ignored), max_positions, ignored[:10])
            )
        return indices
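A minimal, self-contained sketch of the dataset contract the method above relies on; ToyDataset is hypothetical and only mirrors the filter_indices_by_size / size interface assumed here.

import numpy as np

class ToyDataset:
    """Hypothetical stand-in for a fairseq-style dataset (not part of the original code)."""

    def __init__(self, sizes):
        self.sizes = np.array(sizes)

    def size(self, index):
        return int(self.sizes[index])

    def filter_indices_by_size(self, indices, max_positions):
        # keep samples no longer than max_positions, report the rest as ignored
        keep = indices[self.sizes[indices] <= max_positions]
        ignored = indices[self.sizes[indices] > max_positions].tolist()
        return keep, ignored

dataset = ToyDataset(sizes=[12, 900, 34, 2048, 77])
indices = np.arange(len(dataset.sizes))
kept, ignored = dataset.filter_indices_by_size(indices, max_positions=1024)
print(kept.tolist())   # [0, 1, 2, 4]
print(ignored)         # [3], i.e. the sample whose size 2048 exceeds max_positions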
Example #2
def gnn_tensorize(datapoint, source_dictionary, edge_types):
    tensorized_data = TensorizedGraphData(
        adjacency_lists=list(__iterate_edge_types(datapoint, edge_types)),
        node_tensorized_data=[
            # enforce_not_None(self.__node_embedding_model.tensorize(ni))
            source_dictionary.index(ni) for ni in datapoint.node_information
        ],
        reference_nodes={
            n: np.array(refs, dtype=np.int32)
            for n, refs in datapoint.reference_nodes.items()
        },
        num_nodes=len(datapoint.node_information),
    )

    if tensorized_data.num_nodes > 80000:
        LOGGER.warning("Dropping graph with %s nodes." %
                       tensorized_data.num_nodes)
        return None

    num_edges = sum(len(adj) for adj in tensorized_data.adjacency_lists)
    if num_edges > 100000:
        LOGGER.warning("Dropping graph with %s edges." % num_edges)
        return None

    return tensorized_data
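A toy illustration of the node/edge caps applied above; the adjacency lists below are made up and stand in for one (src, dst) array per edge type.

import numpy as np

# Hypothetical adjacency lists: one array of (src, dst) pairs per edge type.
adjacency_lists = [
    np.array([[0, 1], [1, 2]], dtype=np.int32),  # edge type 0
    np.array([[2, 0]], dtype=np.int32),          # edge type 1
]
num_nodes = 3
num_edges = sum(len(adj) for adj in adjacency_lists)

# Same thresholds as in gnn_tensorize: drop graphs that are too large to tensorize.
keep_graph = num_nodes <= 80000 and num_edges <= 100000
print(num_edges, keep_graph)  # 3 True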
Example #3
    def _log_oom(self, exc):
        msg = "OOM: Ran out of memory with exception: {}".format(exc)
        LOGGER.warning(msg)
        if torch.cuda.is_available() and hasattr(torch.cuda, "memory_summary"):
            for device_idx in range(torch.cuda.device_count()):
                LOGGER.warning(torch.cuda.memory_summary(device=device_idx))
        sys.stderr.flush()
Example #4
def check_alignment(alignment, src_len, tgt_len):
    if alignment is None or len(alignment) == 0:
        return False
    if alignment[:, 0].max().item() >= src_len - 1 or alignment[:, 1].max().item() >= tgt_len - 1:
        LOGGER.warning("alignment size mismatch found, skipping alignment!")
        return False
    return True
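A quick usage sketch with a made-up alignment tensor (rows of (src_pos, tgt_pos) pairs), assuming check_alignment and LOGGER from the snippet above are in scope.

import torch

alignment = torch.tensor([[0, 0], [1, 2], [2, 1]])  # hypothetical word alignment

print(check_alignment(alignment, src_len=5, tgt_len=4))  # True: all positions < len - 1
print(check_alignment(alignment, src_len=3, tgt_len=4))  # False: src position 2 == src_len - 1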
Example #5
    def valid_step(self, sample, raise_oom=False):
        """Do forward pass in evaluation mode."""
        if self._dummy_batch == "DUMMY":
            self._dummy_batch = sample

        with torch.no_grad():
            self.model.eval()
            self.criterion.eval()

            sample = self._prepare_sample(sample)
            if sample is None:
                sample = self._prepare_sample(self._dummy_batch)
                is_dummy_batch = True
            else:
                is_dummy_batch = False

            try:
                _loss, sample_size, logging_output = self.task.valid_step(
                    sample, self.model, self.criterion)
            except RuntimeError as e:
                if "out of memory" in str(e):
                    self._log_oom(e)
                    if not raise_oom:
                        LOGGER.warning(
                            "ran out of memory in validation step, retrying batch"
                        )
                        for p in self.model.parameters():
                            if p.grad is not None:
                                p.grad = None  # free some memory
                        if self.cuda:
                            torch.cuda.empty_cache()
                        return self.valid_step(sample, raise_oom=True)
                raise e

            logging_outputs = [logging_output]
            if is_dummy_batch:
                sample_size *= 0  # multiply by 0 to preserve device

        # gather logging outputs from all replicas
        if self.args['distributed_training']['distributed_world_size'] > 1:
            logging_outputs, (sample_size, ) = self._aggregate_logging_outputs(
                logging_outputs,
                sample_size,
                ignore=is_dummy_batch,
            )
            if 'bleu' in logging_outputs[0]:
                logging_outputs[0]['bleu'] /= self.args[
                    'distributed_training']['distributed_world_size']
            if 'rouge_l' in logging_outputs[0]:
                logging_outputs[0]['rouge_l'] /= self.args[
                    'distributed_training']['distributed_world_size']
            if 'meteor' in logging_outputs[0]:
                logging_outputs[0]['meteor'] /= self.args[
                    'distributed_training']['distributed_world_size']
        # log validation stats
        logging_output = self._reduce_and_log_stats(logging_outputs,
                                                    sample_size)
        return logging_output
Example #6
def verify_checkpoint_directory(save_dir: str) -> None:
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    temp_file_path = os.path.join(save_dir, "dummy")
    try:
        with open(temp_file_path, "w"):
            pass
    except OSError as e:
        LOGGER.warning(
            "Unable to access checkpoint save directory: {}".format(save_dir))
        raise e
    else:
        os.remove(temp_file_path)
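A brief usage sketch, assuming verify_checkpoint_directory above is in scope and using a hypothetical path: the probe is run before training so an unwritable save_dir fails fast instead of after the first epoch.

import os
import tempfile

save_dir = os.path.join(tempfile.gettempdir(), "ncc_checkpoints")  # hypothetical path
verify_checkpoint_directory(save_dir)  # creates the directory and checks it is writable
print(os.path.isdir(save_dir))         # True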
Example #7
def download(name):
    if name in BPE_MODEL_ARCHIVE_MAP:
        url = BPE_MODEL_ARCHIVE_MAP[name]
        LOGGER.info(f"Download {name} BPE model from {url}")
        out_file = os.path.join(__BPE_DIR__, f"{name}.tar.gz")
        gdown.download(url=url, output=out_file)
        try:
            with tarfile.open(out_file) as reader:
                reader.extractall(__BPE_DIR__)
            os.remove(out_file)
        except tarfile.ExtractError as err:
            LOGGER.error(__BPE_DIR__)
            LOGGER.warning(f"{name}.tar.gz is corrupted, please contact us.")
    else:
        raise FileExistsError(f"No {name}.tar.gz in the server. Please build your own BPE models. " \
                              f"Once they are built, you can upload them into the server.")
Example #8
    def _apply(self, module, inp, x, backward):
        if torch.is_tensor(x):
            if isinstance(inp, tuple) and len(inp) > 0:
                inp = inp[0]
            err = self._detect(x, module.__module_name, backward)
            if err is not None:
                if torch.is_tensor(inp) and not backward:
                    err += (
                        f" input max: {inp.max().item()}, input min: {inp.min().item()}"
                    )

                has_printed_attr = 'has_printed_b' if backward else 'has_printed_f'
                LOGGER.warning(err)
                setattr(self, has_printed_attr, True)
        elif isinstance(x, dict):
            for v in x.values():
                self._apply(module, inp, v, backward)
        elif isinstance(x, list) or isinstance(x, tuple):
            for v in x:
                self._apply(module, inp, v, backward)
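The `_detect` helper is defined elsewhere in this class; a minimal stand-in (an assumption, not the original implementation) that returns a warning string for NaN/Inf tensors could look like this.

import torch

def detect_nan_inf(tensor, module_name, backward):
    # Minimal stand-in for the _detect helper assumed above: return a message if the
    # tensor contains NaN or Inf values, otherwise None.
    kind = "gradient" if backward else "output"
    if torch.isnan(tensor).any():
        return f"NaN detected in {kind} of {module_name}"
    if torch.isinf(tensor).any():
        return f"Inf detected in {kind} of {module_name}"
    return None

print(detect_nan_inf(torch.tensor([1.0, float("nan")]), "decoder.layers.0", backward=False))
print(detect_nan_inf(torch.tensor([1.0, 2.0]), "decoder.layers.0", backward=False))  # None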
Example #9
    def __init__(self, SO_FILE, LANGUAGE, to_lower=False, operators_file=None):
        self.parser = Parser()
        try:
            assert PathManager.exists(SO_FILE), \
                f"{SO_FILE} does not exist; automatically downloading TreeSitter parse file {LANGUAGE}.so."
        except AssertionError as err:
            LOGGER.warning(err)
            from ncc.hub.tree_sitter.download import download
            download(LANGUAGE)

        if LANGUAGE == 'csharp':
            LANGUAGE = 'c_sharp'
        self.parser.set_language(Language(SO_FILE, LANGUAGE))
        self.LANGUAGE = LANGUAGE
        self.to_lower = to_lower

        if operators_file is None:
            operators_file = os.path.join(os.path.dirname(__file__),
                                          'operators.json')
        with open(operators_file, 'r') as reader:
            self.operators = json_io.json_load(reader)
Example #10
    def __init__(self, args, optimizer):
        super().__init__(args, optimizer)
        if len(args['optimization']['lr']) > 1:
            raise ValueError(
                'Cannot use a fixed learning rate schedule with cosine.'
                ' Consider --lr-scheduler=fixed instead.')

        self.min_lr = max(args['optimization'].get('min_lr', 0), 0)
        self.max_lr = args['optimization'].get('max_lr',
                                               args['optimization']['lr'][0])

        self.warmup_init_lr = args['optimization'].get('warmup_init_lr', 0)
        warmup_end_lr = args['optimization'].get('warmup_end_lr', self.max_lr)

        assert self.max_lr > self.min_lr, 'max_lr must be greater than min_lr'

        self.t_mult = args['optimization'].get('t_mult', 1.)

        if 'lr_period_updates' not in args['optimization']:
            LOGGER.warning(
                'lr_period_updates has not been set; defaulting to epoch_num * batch_num.'
            )
            self.period = -1
        else:
            self.period = args['optimization']['lr_period_updates']

        if args['optimization']['warmup_updates'] > 0:
            # linearly warm up for the first warmup_updates steps
            self.lr_step = \
                (warmup_end_lr - self.warmup_init_lr) / args['optimization']['warmup_updates']
        else:
            self.lr_step = 1

        self.warmup_updates = args['optimization']['warmup_updates']
        self.lr_shrink = args['optimization'].get('lr_shrink', 0.1)

        # initial learning rate
        self.lr = self.warmup_init_lr
        self.optimizer.set_lr(self.lr)
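A worked example of the linear warmup arithmetic above, with made-up hyperparameters; the per-update application itself happens in the scheduler's step_update (not shown here).

warmup_init_lr, warmup_end_lr, warmup_updates = 1e-7, 5e-4, 4000
lr_step = (warmup_end_lr - warmup_init_lr) / warmup_updates  # ~1.25e-7 added per update

# During warmup, update t uses lr = warmup_init_lr + t * lr_step,
# reaching warmup_end_lr exactly at t = warmup_updates.
for t in (0, 2000, 4000):
    print(t, warmup_init_lr + t * lr_step)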
Example #11
    def _setup_optimizer(self):
        no_decay = ['bias', 'LayerNorm.weight']
        params = [
            {
                'params': [
                    p for n, p in self.model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                'weight_decay': 0.01,
            },
            {
                'params': [
                    p for n, p in self.model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                'weight_decay': 0.0,
            },
        ]

        if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7:
            LOGGER.info(
                "NOTE: your device may support faster training with --fp16")
        self._optimizer = optimizers.setup_optimizer(self.args, params)

        if self.args['optimization']['use_bmuf']:
            self._optimizer = optimizers.NccBMUF(self.args, self._optimizer)

        # We should initialize the learning rate scheduler immediately after
        # building the optimizer, so that the initial learning rate is set.
        self._lr_scheduler = lr_schedulers.build_lr_scheduler(
            self.args, self.optimizer)
        if getattr(self._lr_scheduler, 'period', None) == -1:
            import math
            self._lr_scheduler.period = \
                self.args['optimization']['max_epoch'] * \
                math.ceil(len(self.task.dataset('train')) / self.args['dataset']['max_sentences'])
            LOGGER.warning('Update period of {} has not been set; defaulting to {} updates.'.format(
                self.lr_scheduler.__class__.__name__, self._lr_scheduler.period))
        self._lr_scheduler.step_update(0)
Example #12
    def train_step(self, samples, raise_oom=False):
        """Do forward, backward and parameter update."""
        if self._dummy_batch == "DUMMY":
            self._dummy_batch = samples[0]

        self._set_seed()
        self.model.train()
        self.criterion.train()
        self.zero_grad()

        metrics.log_start_time("train_wall", priority=800, round=0)

        # forward and backward pass
        logging_outputs, sample_size, ooms = [], 0, 0
        for i, sample in enumerate(samples):
            sample, is_dummy_batch = self._prepare_sample(sample)

            def maybe_no_sync():
                """
                Whenever *samples* contains more than one mini-batch, we
                want to accumulate gradients locally and only call
                all-reduce in the last backwards pass.
                """
                if (self.args['distributed_training']['distributed_world_size']
                        > 1 and hasattr(self.model, "no_sync")
                        and i < len(samples) - 1):
                    return self.model.no_sync()
                else:
                    return contextlib.ExitStack()  # dummy contextmanager

            try:
                with maybe_no_sync():
                    # forward and backward
                    loss, sample_size_i, logging_output = self.task.train_step(
                        sample=sample,
                        model=self.model,
                        criterion=self.criterion,
                        optimizer=self.optimizer,
                        update_num=self.get_num_updates(),
                        ignore_grad=is_dummy_batch,
                    )
                    del loss

                logging_outputs.append(logging_output)
                sample_size += sample_size_i

                # emptying the CUDA cache after the first step can
                # reduce the chance of OOM
                if self.cuda and self.get_num_updates() == 0:
                    torch.cuda.empty_cache()
            except RuntimeError as e:
                if "out of memory" in str(e):
                    self._log_oom(e)
                    if raise_oom:
                        raise e
                    LOGGER.warning(
                        "attempting to recover from OOM in forward/backward pass"
                    )
                    ooms += 1
                    self.zero_grad()
                    if self.cuda:
                        torch.cuda.empty_cache()
                    if self.args['distributed_training'][
                            'distributed_world_size'] == 1:
                        return None
                else:
                    raise e

        if is_dummy_batch:
            if torch.is_tensor(sample_size):
                sample_size.zero_()
            else:
                sample_size *= 0.0  # multiply by 0 to preserve device

        if torch.is_tensor(sample_size):
            sample_size = sample_size.float()
        else:
            sample_size = float(sample_size)

        # gather logging outputs from all replicas
        if self._sync_stats():
            logging_outputs, (sample_size,
                              ooms) = self._aggregate_logging_outputs(
                                  logging_outputs,
                                  sample_size,
                                  ooms,
                                  ignore=is_dummy_batch,
                              )

        overflow = False
        try:
            with torch.autograd.profiler.record_function("reduce-grads"):
                # reduce gradients across workers
                self.optimizer.all_reduce_grads(self.model)
                if utils.has_parameters(self.criterion):
                    self.optimizer.all_reduce_grads(self.criterion)

            with torch.autograd.profiler.record_function("multiply-grads"):
                # multiply gradients by (data_parallel_size / sample_size) since
                # DDP normalizes by the number of data parallel workers for
                # improved fp16 precision.
                # Thus we get (sum_of_gradients / sample_size) at the end.
                # In case of fp16, this step also undoes loss scaling.
                # (Debugging note: Some optimizers perform this scaling on the
                # fly, so inspecting model.parameters() or optimizer.params may
                # still show the original, unscaled gradients.)
                num = (
                    self.args['distributed_training']['distributed_world_size']
                    if not self.args['optimization']['use_bmuf']
                    or self._sync_stats() else 1)
                self.optimizer.multiply_grads(num / (sample_size or 1.0))

            with torch.autograd.profiler.record_function("clip-grads"):
                # clip grads
                grad_norm = self.clip_grad_norm(
                    self.args['optimization']['clip_norm'])

            # check that grad norms are consistent across workers
            if not self.args['optimization']['use_bmuf']:
                self._check_grad_norms(grad_norm)
            if not torch.isfinite(grad_norm).all():
                if self.args['common'].get('amp', False):
                    overflow = True
                else:
                    raise FloatingPointError("gradients are NaN/Inf")

            with torch.autograd.profiler.record_function("optimizer"):
                # take an optimization step
                self.optimizer.step()
                if self.args['common'].get('amp', False) and overflow:
                    if self._amp_retries == self.args['common'][
                            'amp_batch_retries']:
                        LOGGER.info("AMP: skipping this batch.")
                        self._amp_retries = 0
                    else:
                        self._amp_retries += 1
                        return self.train_step(
                            samples,
                            raise_oom)  # recursion to feed in same batch

        except FloatingPointError:
            # re-run the forward and backward pass with hooks attached to print
            # out where it fails
            self.zero_grad()
            with NanDetector(self._model):
                for _, sample in enumerate(samples):
                    sample, _ = self._prepare_sample(sample)
                    self.task.train_step(sample,
                                         self.model,
                                         self.criterion,
                                         self.optimizer,
                                         self.get_num_updates(),
                                         ignore_grad=False)
            raise
        except OverflowError as e:
            overflow = True
            LOGGER.info(
                f"NOTE: gradient overflow detected, ignoring gradient, {str(e)}"
            )
            grad_norm = torch.tensor(0.0).cuda()
            self.zero_grad()
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                LOGGER.error("OOM during optimization, irrecoverable")
            raise e

        logging_output = None
        if not overflow:
            self.set_num_updates(self.get_num_updates() + 1)

            if self.cuda and self.cuda_env is not None:
                # log minimum free memory over the iteration
                gb_used = torch.cuda.max_memory_allocated(
                ) / 1024 / 1024 / 1024
                torch.cuda.reset_peak_memory_stats()
                gb_free = self.cuda_env.total_memory_in_GB - gb_used
                metrics.log_scalar("gb_free",
                                   gb_free,
                                   priority=1500,
                                   round=1,
                                   weight=0)

            # log stats
            logging_output = self._reduce_and_log_stats(
                logging_outputs,
                sample_size,
                grad_norm,
            )

            # clear CUDA cache to reduce memory fragmentation
            if (self.cuda and self.args['common']['empty_cache_freq'] > 0
                    and ((self.get_num_updates() +
                          self.args['common']['empty_cache_freq'] - 1) %
                         self.args['common']['empty_cache_freq']) == 0):
                torch.cuda.empty_cache()

        if self.args['common']['fp16'] or self.args['common'].get(
                'amp', False):
            metrics.log_scalar(
                "loss_scale",
                (self.optimizer.scaler.loss_scale
                 if self.args['common']['fp16'] else
                 self.optimizer.scaler.get_scale()),
                priority=700,
                round=4,
                weight=0,
            )

        metrics.log_stop_time("train_wall")
        return logging_output
Example #13
import json
from collections import OrderedDict

from ncc import LOGGER

try:
    from third_party.pycocoevalcap.bleu import corpus_bleu
    from third_party.pycocoevalcap.rouge import Rouge
    from third_party.pycocoevalcap.meteor import Meteor
except ImportError as err:
    LOGGER.warning(err)
    from third_party.download import download

    download('pycocoevalcap')

from .smoothed_bleu import compute_smoothed_bleu


def eval_accuracies(hypotheses,
                    references,
                    sources=None,
                    filename=None,
                    mode='dev',
                    smoothed_blue=False):
    """An unofficial evalutation helper.
     Arguments:
        hypotheses: A mapping from instance id to predicted sequences.
        references: A mapping from instance id to ground truth sequences.
        copy_info: Map of id --> copy information.
        sources: Map of id --> input text sequence.
        filename:
Example #14
# -*- coding: utf-8 -*-

try:
    import dgl
    import networkx as nx
except ImportError:
    from ncc import LOGGER

    LOGGER.warning(
        "Please install dgl by referring to https://www.dgl.ai/pages/start.html"
    )

import numpy as np
import torch

from ncc.data.constants import MAX_SUBTOKEN_LEN


def build_graph(tree_dict,
                dictionary,
                tree_leaf_subtoken=1,
                DGLGraph_PAD_WORD=-1) -> dgl.DGLGraph:
    # Leaf nodes store the split subtokens; if a token cannot be split, it stays a single token.
    # In the .pt training data, a leaf-node token is stored as ["a_hu", ["a", "hu"]].
    # (1) When tree_leaf_subtoken is 1, this function converts only the subtokens to word ids,
    #     i.e. it stores [id of "a", id of "hu"], e.g. [23, 179].
    #     For an unsplittable token the .pt data looks like ["imq", ["imq", PAD_WORD]],
    #     which is converted here to [id of "imq", id of codesum.PAD_WORD], e.g. [258, 0].
    #     The padded length is determined by the maximum number of subtokens per token
    #     across the train/val/test datasets.
    # (2) When tree_leaf_subtoken is 0, this function uses the unsplit token to get the word id,
    #     e.g. "a_hu" itself is mapped to a word id.
    nx_graph = nx.DiGraph()
Example #15
def single_main(args, init_distributed=False):
    assert args['dataset']['max_tokens'] is not None or args['dataset']['max_sentences'] is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'
    metrics.reset()

    # 0. Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args['common']['cpu']:
        torch.cuda.set_device(args['distributed_training']['device_id'])
    set_seed.set_seed(args['common']['seed'])
    if init_distributed:
        args['distributed_training'][
            'distributed_rank'] = distributed_utils.distributed_init(args)

    # Verify checkpoint directory
    if distributed_utils.is_master(args):
        save_dir = args['checkpoint']['save_dir']
        checkpoint_utils.verify_checkpoint_directory(save_dir)
        if not PathManager.is_empty(os.path.join(save_dir, '*.pt')):
            LOGGER.warning(f"{save_dir} contains checkpoint files.")
        PathManager.rm(os.path.join(
            save_dir, '*.pt'))  # this code will remove pre-trained models

    # 1. Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # 2. Load valid dataset (we load training data below, based on the latest checkpoint)
    task.load_dataset(args['dataset']['valid_subset'], combine=False, epoch=1)

    # 3. Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    LOGGER.info(model)
    LOGGER.info('model {}, criterion {}'.format(args['model']['arch'],
                                                criterion.__class__.__name__))
    LOGGER.info('num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # 4. Build trainer
    trainer = Trainer(args, task, model, criterion)
    LOGGER.info('training on {} GPUs'.format(
        args['distributed_training']['distributed_world_size']))
    LOGGER.info(
        'max tokens per GPU = {} and max sentences per GPU = {}'.format(
            args['dataset']['max_tokens'],
            args['dataset']['max_sentences'],
        ))

    # 5. Load the latest checkpoint if one is available and restore the corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args,
                                                              trainer,
                                                              combine=False)

    # 6. Train until the learning rate gets too small
    max_epoch = args['optimization']['max_epoch'] or math.inf
    max_update = args['optimization']['max_update'] or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()
    valid_subsets = args['dataset']['valid_subset'].split(',')
    while (lr > args['optimization']['min_lr']
           and epoch_itr.next_epoch_idx <= max_epoch
           and trainer.get_num_updates() < max_update):
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if not args['dataset']['disable_validation'] and epoch_itr.epoch % args[
                'dataset']['validate_interval'] == 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)
        else:
            valid_losses = [None]

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args['checkpoint']['save_interval'] == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        # early stop
        if should_stop_early(args, valid_losses[0]):
            LOGGER.info(
                'early stop since valid performance hasn\'t improved for last {} runs'
                .format(args['checkpoint']['patience']))
            break

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            combine=False,  # TODO to be checked
            # sharded data: get train iterator for next epoch
            load_dataset=(os.pathsep in args['task']['data']),
        )

    train_meter.stop()
    LOGGER.info('done training in {:.1f} seconds'.format(train_meter.sum))
Example #16
    def train_step(self, samples, raise_oom=False):
        """Do forward, backward and parameter update."""
        if self._dummy_batch == "DUMMY":
            self._dummy_batch = samples[0]

        self._set_seed()
        self.model.train()
        self.criterion.train()
        # self.zero_grad()
        metrics.log_start_time("train_wall", priority=800, round=0)

        # forward and backward pass
        logging_outputs, sample_size, ooms = [], 0, 0
        for i, sample in enumerate(samples):
            sample = self._prepare_sample(sample)
            if sample is None:
                # when sample is None, run forward/backward on a dummy batch
                # and ignore the resulting gradients
                sample = self._prepare_sample(self._dummy_batch)
                is_dummy_batch = True
            else:
                is_dummy_batch = False

            # def maybe_no_sync():
            #     """
            #     Whenever *samples* contains more than one mini-batch, we
            #     want to accumulate gradients locally and only call
            #     all-reduce in the last backwards pass.
            #     """
            #     if (
            #         self.args['distributed_training']['distributed_world_size'] > 1
            #         and hasattr(self.model, "no_sync")
            #         and i < len(samples) - 1
            #     ):
            #         return self.model.no_sync()
            #     else:
            #         return contextlib.ExitStack()  # dummy contextmanager

            try:
                # with maybe_no_sync():
                #     # forward and backward
                #     loss, sample_size_i, logging_output = self.task.train_step(
                #         sample=sample,
                #         model=self.model,
                #         criterion=self.criterion,
                #         optimizer=self.optimizer,
                #         update_num=self.get_num_updates(),
                #         ignore_grad=is_dummy_batch,
                #     )
                #     del loss
                self.model.train()
                self.optimizer.zero_grad()
                self.set_num_updates(self.get_num_updates())
                loss, sample_size, logging_output = self.criterion(
                    self.model, sample)
                # if ignore_grad:
                #     loss *= 0
                print('loss: ', loss.item())
                # optimizer.backward(loss)
                loss.backward()
                self.optimizer.step()
                self.lr_scheduler.step()
                # logging_outputs.append(logging_output)
                # sample_size += sample_size_i

                # emptying the CUDA cache after the first step can
                # reduce the chance of OOM
                # if self.cuda and self.get_num_updates() == 0:
                #     torch.cuda.empty_cache()
            except RuntimeError as e:
                if "out of memory" in str(e):
                    self._log_oom(e)
                    if raise_oom:
                        raise e
                    LOGGER.warning(
                        "attempting to recover from OOM in forward/backward pass"
                    )
                    ooms += 1
                    self.zero_grad()
                    if self.cuda:
                        torch.cuda.empty_cache()
                    if self.args['distributed_training'][
                            'distributed_world_size'] == 1:
                        return None
                else:
                    raise e