Example #1
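The snippets in these examples omit their module-level imports. A reasonable assumption, given the docstring cross-references and the `tuseq.sequential_nll` and `np.arange` calls, is roughly the following (the alias `tuseq` for `sequential.train_utils_sequential` is inferred, not confirmed by the source):

import numpy as np
import torch

# Presumably the module referenced in the docstrings:
import sequential.train_utils_sequential as tuseq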
def get_loss_func(config, device, logger, ewc_loss=False):
    """Get a function handle that can be used as task loss function.

    Note, this function makes use of function
    :func:`sequential.train_utils_sequential.sequential_nll`.

    Since PoS tagging is a classification task, this function implements a
    sequential cross-entropy loss.

    Args:
        config (argparse.Namespace): The command line arguments.
        device: Torch device (cpu or gpu).
        logger: Console (and file) logger.
        ewc_loss (bool): Whether the loss is determined for task training or
            to compute Fisher elements via EWC. Note, based on the user
            configuration, the loss computation might be different.

    Returns:
        (func): A function handle as described by argument ``custom_nll``
        of function :func:`utils.ewc_regularizer.compute_fisher`, if option
        ``pass_ids=True``.

        Note:
            This loss **sums** the NLL across the batch dimension. A proper
            scaling wrt other loss terms during training would require a
            multiplication of the loss with a factor :math:`N/B`, where
            :math:`N` is the training set size and :math:`B` is the mini-batch
            size.
    """
    if hasattr(config, 'ts_weighting') or \
            hasattr(config, 'ts_weighting_fisher'):
        raise NotImplementedError(
            'The PoS tagging dataset has a fixed loss ' +
            'weighting scheme, which is not configurable.')

    ce_loss = tuseq.sequential_nll(loss_type='ce', reduction='sum')

    sample_loss_func = lambda Y, T, tsf, beta: ce_loss(
        Y, T, None, None, None, ts_factors=tsf, beta=beta)

    # Unfortunately, we can't just use the above loss function, since we need
    # to respect the different sequence lengths.
    # We therefore create a custom time step weighting mask per sample in a
    # given batch.
    def task_loss_func(Y, T, data, allowed_outputs, empirical_fisher,
                       batch_ids):
        # Build batch specific timestep mask.
        tsf = torch.zeros(T.shape[0], T.shape[1]).to(T.device)

        seq_lengths = data.get_out_seq_lengths(batch_ids)

        for i in range(batch_ids.size):
            sl = int(seq_lengths[i])

            tsf[:sl, i] = 1

        return sample_loss_func(Y, T, tsf, None)

    return task_loss_func
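A hypothetical usage sketch of the returned handle, including the N/B rescaling mentioned in the docstring note; `config`, `device`, `logger`, `data`, `net_out`, `targets`, `batch_ids` and `num_train_samples` are placeholders and not part of the snippet above:

# Hypothetical usage; `net_out`/`targets` have shape [max_timesteps, batch, classes].
loss_func = get_loss_func(config, device, logger)
nll = loss_func(net_out, targets, data, None, None, batch_ids)  # summed over batch
# Rescale by N/B (training set size over mini-batch size) before combining the
# NLL with regularizers that are defined on the whole training set.
task_loss = (num_train_samples / batch_ids.size) * nll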
Example #2
def get_copy_loss_func(config, device, logger, ewc_loss=False):
    """Get a function handle that can be used as task loss function.

    Note, this function makes use of function
    :func:`sequential.train_utils_sequential.sequential_nll`.

    We use the Binary Cross Entropy loss, since our desired outputs should
    always be 0s or 1s. This function can be used to do multi-label binary
    classification, which is what we are interested in with the copy task,
    since several output units should be active at any given time.

    Args:
        config (argparse.Namespace): The command line arguments.
        device: Torch device (cpu or gpu).
        logger: Console (and file) logger.
        ewc_loss (bool): Whether the loss is determined for task training or
            to compute Fisher elements via EWC. Note, based on the user
            configuration, the loss computation might be different.

    Returns:
        (func): A function handle as described by argument ``custom_nll``
        of function :func:`utils.ewc_regularizer.compute_fisher`, if option
        ``pass_ids=True``.

        Note:
            This loss **sums** the NLL across the batch dimension. A proper
            scaling wrt other loss terms during training would require a
            multiplication of the loss with a factor :math:`N/B`, where
            :math:`N` is the training set size and :math:`B` is the mini-batch
            size.
    """
    if hasattr(config, 'ts_weighting') or \
            hasattr(config, 'ts_weighting_fisher'):
        raise NotImplementedError(
            'The copy task dataset has a fixed loss ' +
            'weighting scheme, which is not configurable.')

    bce_loss = tuseq.sequential_nll(loss_type='bce', reduction='sum')

    sample_loss_func = lambda Y, T, tsf, beta: bce_loss(
        Y, T, None, None, None, ts_factors=tsf, beta=beta)

    # Unfortunately, we can't just use the above loss function, since we need
    # to respect the different sequence lengths.
    # We therefore create a custom time step weighting mask per sample in a
    # given batch.
    def task_loss_func(Y, T, data, allowed_outputs, empirical_fisher,
                       batch_ids):
        # Build batch specific timestep mask.
        tsf = torch.zeros(T.shape[0], T.shape[1]).to(T.device)

        pat_starts, pat_lengths = data.get_out_pattern_bounds(batch_ids)

        for i in range(batch_ids.size):
            ps = pat_starts[i]
            pe = ps + pat_lengths[i]

            tsf[ps:pe, i] = 1

            # Note, the `[i]` is necessary to avoid losing the batch dimension.
            #loss += sample_loss_func(out_logits[s_start:s_end, [i], :],
            #    targets[s_start:s_end, [i], :], None, None)

        return sample_loss_func(Y, T, tsf, None)

    return task_loss_func
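For illustration, a small self-contained sketch of the time-step mask that `task_loss_func` builds from the output-pattern bounds, with toy numbers in place of `data.get_out_pattern_bounds`:

import torch

max_ts, batch_size = 6, 3
pat_starts = [2, 3, 1]    # toy values instead of data.get_out_pattern_bounds
pat_lengths = [3, 2, 4]

tsf = torch.zeros(max_ts, batch_size)
for i, (ps, pl) in enumerate(zip(pat_starts, pat_lengths)):
    tsf[ps:ps + pl, i] = 1  # only output-pattern timesteps enter the BCE loss
print(tsf)
# tensor([[0., 0., 0.],
#         [0., 0., 1.],
#         [1., 0., 1.],
#         [1., 1., 1.],
#         [1., 1., 1.],
#         [0., 0., 0.]])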
Example #3

def get_loss_func(config, device, logger, ewc_loss=False):
    """Get a function handle that can be used as task loss function.

    Note, this function makes use of function
    :func:`sequential.train_utils_sequential.sequential_nll`.

    Args:
        config (argparse.Namespace): The command line arguments.
        device: Torch device (cpu or gpu).
        logger: Console (and file) logger.
        ewc_loss (bool): Whether the loss is determined for task training or
            to compute Fisher elements via EWC. Note, based on the user
            configuration, the loss computation might be different.

    Returns:
        (func): A function handle as described by argument ``custom_nll``
        of function :func:`utils.ewc_regularizer.compute_fisher`, if option
        ``pass_ids=True``.

        Note:
            This loss **sums** the NLL across the batch dimension. A proper
            scaling wrt other loss terms during training would require a
            multiplication of the loss with a factor :math:`N/B`, where
            :math:`N` is the training set size and :math:`B` is the mini-batch
            size.
    """
    # Log-likelihoods of timesteps are usually just summed. Here, the user
    # can change this to a weighted sum.
    if not ewc_loss:
        ts_weighting = config.ts_weighting
    else:
        ts_weighting = config.ts_weighting_fisher

    purpose = 'Fisher' if ewc_loss else 'loss'
    if ts_weighting == 'none':
        logger.debug(
            'Considering the NLL of all timesteps (including padded ' +
            'ones) for %s computation.' % purpose)
    elif ts_weighting == 'unpadded':
        logger.debug('Considering the NLL of all unpadded timesteps for ' +
                     '%s computation.' % purpose)
    elif ts_weighting == 'last':
        logger.debug('Considering the NLL of last unpadded timestep for ' +
                     '%s computation.' % purpose)
    elif ts_weighting == 'last_ten_percent':
        logger.debug('Considering the NLL of last 10% of unpadded timesteps ' +
                     'for %s computation.' % purpose)
    else:
        assert ts_weighting == 'discount'
        logger.debug('Weighting the NLL of the later timesteps more than ' +
                     'the NLL of earlier timesteps for %s computation.' \
                     % purpose)

    ce_loss = tuseq.sequential_nll(loss_type='ce', reduction='sum')

    # Unfortunately, we can't just use the above loss function, since we need
    # to respect the different sequence lengths.
    # We therefore create a custom time step weighting mask per sample in a
    # given batch.
    def task_loss_func(Y, T, data, allowed_outputs, empirical_fisher,
                       batch_ids):
        # Build batch specific timestep mask.
        ts_factors = torch.zeros(T.shape[0], T.shape[1]).to(T.device)

        seq_lengths = data.get_out_seq_lengths(batch_ids)

        if ts_weighting == 'none':
            ts_factors = None
        elif ts_weighting == 'unpadded':
            for i, sl in enumerate(seq_lengths):
                ts_factors[:sl, i] = 1
        elif ts_weighting == 'last':
            ts_factors[seq_lengths - 1, np.arange(seq_lengths.size)] = 1
        elif ts_weighting == 'last_ten_percent':
            for i, sl in enumerate(seq_lengths):
                sl_10 = sl // 10
                ts_factors[(sl - sl_10):sl, i] = 1
        else:
            assert ts_weighting == 'discount'
            gamma = 1.
            discount = 0.9
            max_num_ts = Y.shape[0]
            dc_factors = torch.zeros(max_num_ts)
            for tt in range(max_num_ts - 1, -1, -1):
                dc_factors[tt] = gamma
                gamma *= discount
            for i, sl in enumerate(seq_lengths):
                ts_factors[:sl, i] = dc_factors[-sl:]

        # FIXME What is a good way of normalizing weights?
        # The timestep factors should be normalized such that the final
        # NLL strength corresponds to the original one. But what is the
        # original one? Either the one that only takes the last timestep
        # into account (hence, `ts_factors` should sum to 1) or the one that
        # takes all unpadded timesteps into account (hence, `ts_factors` should
        # sum to `seq_lengths`).
        # Since there is only one label per sample, I decided that only 1
        # timestep counts, the last one.
        if ts_factors is not None:
            ts_factors /= ts_factors.sum(dim=0)[None, :]

        return ce_loss(Y,
                       T,
                       None,
                       None,
                       None,
                       ts_factors=ts_factors,
                       beta=None)

    return task_loss_func
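A self-contained sketch of the 'discount' branch for a single sample of unpadded length `sl`, including the sum-to-one normalization applied at the end of `task_loss_func` (toy numbers, no dataset handler required):

import torch

max_num_ts, sl = 8, 5
gamma, discount = 1., 0.9

dc_factors = torch.zeros(max_num_ts)
for tt in range(max_num_ts - 1, -1, -1):    # last timestep gets weight 1
    dc_factors[tt] = gamma
    gamma *= discount

ts_factors = dc_factors[-sl:]               # weights for the sl unpadded timesteps
ts_factors = ts_factors / ts_factors.sum()  # normalize so the weights sum to 1
print(ts_factors)  # later timesteps receive the largest weights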
Example #4
def get_loss_func(config, device, logger, ewc_loss=False):
    """Get a function handle that can be used as task loss function.

    Note, this function makes use of function
    :func:`sequential.train_utils_sequential.sequential_nll`.

    Args:
        config (argparse.Namespace): The command line arguments.
        device: Torch device (cpu or gpu).
        logger: Console (and file) logger.
        ewc_loss (bool): Whether the loss is determined for task training or
            to compute Fisher elements via EWC. Note, based on the user
            configuration, the loss computation might be different.

    Returns:
        (func): A function handle as described by argument ``custom_nll``
        of function :func:`utils.ewc_regularizer.compute_fisher`, if option
        ``pass_ids=True``.

        Note:
            This loss **sums** the NLL across the batch dimension. A proper
            scaling wrt other loss terms during training would require a
            multiplication of the loss with a factor :math:`N/B`, where
            :math:`N` is the training set size and :math:`B` is the mini-batch
            size.
    """
    # Log-likelihoods of timesteps are usually just summed. Here, the user
    # can change this to a weighted sum.
    if not ewc_loss:
        ts_weighting = config.ts_weighting
    else:
        ts_weighting = config.ts_weighting_fisher

    # Note, there is no padding applied in this dataset.
    purpose = 'Fisher' if ewc_loss else 'loss'
    if ts_weighting == 'none' or ts_weighting == 'unpadded':
        logger.debug(
            'Considering the NLL of all timesteps for %s computation.' %
            purpose)
    elif ts_weighting == 'last':
        logger.debug('Considering the NLL of last timestep for ' +
                     '%s computation.' % purpose)
    elif ts_weighting == 'last_ten_percent':
        logger.debug('Considering the NLL of last 10% of timesteps ' +
                     'for %s computation.' % purpose)
    else:
        assert ts_weighting == 'discount'
        logger.debug('Weighting the NLL of the later timesteps more than ' +
                     'the NLL of earlier timesteps for %s computation.' \
                     % purpose)

    ce_loss = tuseq.sequential_nll(loss_type='ce', reduction='sum')

    # Build batch specific timestep mask.
    # Note, all samples have the same sequence length.
    seq_length = 10
    ts_factors = torch.zeros(seq_length, 1).to(device)

    # FIXME We can compute the weightings outside of this function, since
    # they are static for all batches (no padding).
    if ts_weighting == 'none' or ts_weighting == 'unpadded':
        ts_factors = None
    elif ts_weighting == 'last':
        ts_factors[-1, :] = 1
    elif ts_weighting == 'last_ten_percent':
        sl_10 = seq_length // 10
        ts_factors[-sl_10:, :] = 1
    else:
        assert ts_weighting == 'discount'
        gamma = 1.
        discount = 0.9
        for tt in range(seq_length - 1, -1, -1):
            ts_factors[tt, 0] = gamma
            gamma *= discount

    # FIXME What is a good way of normalizing weights?
    # The timestep factors should be normalized such that the final
    # NLL strength corresponds to the original one. But what is the
    # original one? Either the one that only takes the last timestep
    # into account (hence, `ts_factors` should sum to 1) or the one that
    # takes all timesteps into account (hence, `ts_factors` should
    # sum to `seq_length`).
    # Since there is only one label per sample, I decided that only 1
    # timestep counts, the last one.
    if ts_factors is not None:
        ts_factors /= ts_factors.sum()

    # The loss function must additionally accept `batch_ids`, even though we
    # don't use them here, since all sequences have the same length.
    # Note, `dh`, `ao` and `ef` are also unused by `ce_loss`; they are only part
    # of the signature to conform to the common interface.
    loss_func = lambda Y, T, dh, ao, ef, _: ce_loss(
        Y, T, None, None, None, ts_factors=ts_factors, beta=None)

    return loss_func
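A hypothetical usage sketch; `config` is assumed to define `ts_weighting` (and `ts_weighting_fisher` when `ewc_loss=True`), while `device`, `logger`, `Y` and `T` are placeholders, with `Y`/`T` being network outputs and targets of shape `[10, batch_size, num_classes]`:

loss_func = get_loss_func(config, device, logger)
# The data handler, allowed outputs, Fisher flag and batch ids are ignored by
# this loss; the positional slots only exist to match the common interface.
nll = loss_func(Y, T, None, None, None, None)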