Ejemplo n.º 1
0
def weighted_sigmoid_rule(nb_epoch: int, steps: int = 10, **kwargs):
    """The probability to apply each augmentation increases or decreases following a sigmoid curve.

    The "sup" curve follows ``sigmoid_rampdown`` and the "cot" / "diff" curves
    follow ``sigmoid_rampup``. Each curve is multiplied by its lambda weight,
    then the three are normalized so that at every step they sum to 1.

    :param nb_epoch: The number of epochs the rule will be effective.
    :param steps: The number of times the probabilities will be updated.
    :param lambda_sup_max: Keyword, default 1. Weight of the Lsup loss.
    :param lambda_cot_max: Keyword, default 1. Weight of the Lcot loss.
    :param lambda_diff_max: Keyword, default 1. Weight of the Ldiff loss.
    :return: ``(sup_steps, cot_steps, diff_steps)``, three arrays of length
        ``steps`` holding the normalized probabilities at each update.

    Any other keyword argument is ignored, so a misspelled lambda name has no
    effect. No guard is applied when the three weighted values sum to zero at
    some step (e.g. every lambda is 0); that step then yields non-finite values.
    """
    lsm = kwargs.get("lambda_sup_max", 1)
    lcm = kwargs.get("lambda_cot_max", 1)
    ldm = kwargs.get("lambda_diff_max", 1)

    # Epoch positions, evenly spaced over [0, nb_epoch], at which the
    # probabilities are sampled.
    epochs = np.linspace(0, nb_epoch, steps)

    sup_steps = np.asarray(
        [lsm * sigmoid_rampdown(x, nb_epoch) for x in epochs])
    cot_steps = np.asarray(
        [lcm * sigmoid_rampup(x, nb_epoch) for x in epochs])
    diff_steps = np.asarray(
        [ldm * sigmoid_rampup(x, nb_epoch) for x in epochs])

    # Normalize element-wise in one vectorized pass. True division also keeps
    # the result floating-point even if the inputs happened to be integers
    # (the previous in-place `/=` on an integer array would truncate).
    total = sup_steps + cot_steps + diff_steps
    return sup_steps / total, cot_steps / total, diff_steps / total
Ejemplo n.º 2
0
def get_current_consistency_weight(final_consistency_weight, epoch,
                                   step_in_epoch, total_steps_in_epoch):
    """Return the consistency-loss weight at the current training position.

    Consistency ramp-up from https://arxiv.org/abs/1610.02242.

    The position is measured in fractional epochs since
    ``args.consistency_rampup_starts``. ``final_consistency_weight`` is
    multiplied by ``ramps.sigmoid_rampup(position, length)``, where ``length``
    is ``args.consistency_rampup_ends - args.consistency_rampup_starts``.

    :param final_consistency_weight: Weight that the ramp-up factor scales.
    :param epoch: Current epoch.
    :param step_in_epoch: Index of the current step within the epoch.
    :param total_steps_in_epoch: Number of steps per epoch (must be non-zero,
        it is used as a divisor).
    """
    ramp_start = args.consistency_rampup_starts
    # Fractional epoch counted from the beginning of the ramp-up window.
    position = (epoch - ramp_start) + step_in_epoch / total_steps_in_epoch
    ramp_length = args.consistency_rampup_ends - args.consistency_rampup_starts
    return final_consistency_weight * ramps.sigmoid_rampup(position, ramp_length)
Ejemplo n.º 3
0
 def get_current_consistency_weight(self, epoch):
     """Update ``self.consistency_weight`` for ``epoch`` and return it.

     Consistency ramp-up from https://arxiv.org/abs/1610.02242.

     Before ``self.config.t_start`` the weight is 0. From then on it is
     ``ramps.sigmoid_rampup(epoch, self.config.c_rampup)``.

     The value is still stored on ``self.consistency_weight`` as before. It is
     now also returned, so a caller using the ``get_`` name gets the value
     instead of ``None``.

     :param epoch: Current epoch.
     :return: The new value of ``self.consistency_weight``.
     """
     if epoch < self.config.t_start:
         self.consistency_weight = 0.
     else:
         # NOTE(review): the ramp-up is evaluated at the absolute ``epoch``
         # rather than at ``epoch - self.config.t_start``; confirm this offset
         # is intended.
         self.consistency_weight = ramps.sigmoid_rampup(
             epoch, self.config.c_rampup)
     return self.consistency_weight
Ejemplo n.º 4
0
def sigmoid_rule(nb_epoch: int, steps: int = 10, **kwargs):
    """The probability to apply each augmentation increases or decreases following a sigmoid curve.

    The "sup" curve follows ``sigmoid_rampdown`` and the "cot" / "diff" curves
    follow ``sigmoid_rampup``. The three are normalized so that at every step
    they sum to 1.

    :param nb_epoch: The number of epochs the rule will be effective.
    :param steps: The number of times the probabilities will be updated.
    :return: ``(sup_steps, cot_steps, diff_steps)``, three arrays of length
        ``steps`` holding the normalized probabilities at each update.

    Extra keyword arguments are accepted and ignored.
    """
    # Epoch positions, evenly spaced over [0, nb_epoch], at which the
    # probabilities are sampled.
    epochs = np.linspace(0, nb_epoch, steps)

    sup_steps = np.asarray([sigmoid_rampdown(x, nb_epoch) for x in epochs])
    cot_steps = np.asarray([sigmoid_rampup(x, nb_epoch) for x in epochs])
    diff_steps = np.asarray([sigmoid_rampup(x, nb_epoch) for x in epochs])

    # Normalize element-wise in one vectorized pass. True division keeps the
    # result floating-point regardless of the input dtype.
    total = sup_steps + cot_steps + diff_steps
    return sup_steps / total, cot_steps / total, diff_steps / total
Ejemplo n.º 5
0
def get_current_consistency_weight(epoch):
    """Return the consistency-loss weight for ``epoch``.

    Consistency ramp-up from https://arxiv.org/abs/1610.02242: the base weight
    ``FLAGS.consistency_weight`` is multiplied by
    ``ramps.sigmoid_rampup(epoch, FLAGS.consistency_rampup)``.
    """
    base_weight = FLAGS.consistency_weight
    rampup_factor = ramps.sigmoid_rampup(epoch, FLAGS.consistency_rampup)
    return base_weight * rampup_factor
Ejemplo n.º 6
0
def get_current_consistency_weight(epoch):
    """Return ``args.consistency`` scaled by the sigmoid ramp-up at ``epoch``.

    The scaling factor is ``ramps.sigmoid_rampup(epoch, args.consistency_rampup)``.
    """
    scale = args.consistency
    factor = ramps.sigmoid_rampup(epoch, args.consistency_rampup)
    return scale * factor
Ejemplo n.º 7
0
def get_current_consistency_weight(args, epoch, step_in_epoch, total_steps_in_epoch):
    """Return ``args.lmbda`` scaled by the sigmoid ramp-up at the current step.

    The position is measured in fractional epochs since
    ``args.consistency_rampup_starts``. The ramp-up length is
    ``args.consistency_rampup_ends - args.consistency_rampup_starts``.

    :param args: Namespace providing ``consistency_rampup_starts``,
        ``consistency_rampup_ends`` and ``lmbda``.
    :param epoch: Current epoch.
    :param step_in_epoch: Index of the current step within the epoch.
    :param total_steps_in_epoch: Number of steps per epoch (used as a divisor,
        so it must be non-zero).
    """
    start = args.consistency_rampup_starts
    # Fractional epoch counted from the beginning of the ramp-up window.
    position = (epoch - start) + step_in_epoch / total_steps_in_epoch
    window = args.consistency_rampup_ends - args.consistency_rampup_starts
    return args.lmbda * ramps.sigmoid_rampup(position, window)