示例#1
0
def get_loss(root, split, net, recon_wei, choice):
    """Build the training loss (and optional per-output loss weights).

    Args:
        root: Dataset root; forwarded to ``load_class_weights`` for the
            class-weighted choices ('w_bce', 'w_mar').
        split: Dataset split; forwarded to ``load_class_weights`` together
            with ``root``.
        net: Network name. If it contains the substring ``'caps'`` the model
            is treated as having two outputs, ``out_seg`` and ``out_recon``.
        recon_wei: Loss weight for the ``out_recon`` output (only used when
            ``net`` contains ``'caps'``).
        choice: Loss identifier; one of 'w_bce', 'bce', 'dice', 'multi_dice',
            'w_mar', 'mar'.

    Returns:
        ``(loss, loss_weights)``. For 'caps' networks both are dicts keyed by
        output name (``out_seg`` gets the selected loss and weight 1.0,
        ``out_recon`` gets ``'mse'`` and weight ``recon_wei``); otherwise
        ``loss`` is the selected loss and ``loss_weights`` is ``None``.

    Raises:
        ValueError: if ``choice`` is not a recognised loss identifier.
    """
    if choice == 'w_bce':
        pos_class_weight = load_class_weights(root=root, split=split)
        loss = weighted_binary_crossentropy_loss(pos_class_weight)
    elif choice == 'bce':
        # Plain string identifier, presumably resolved by the training
        # framework (e.g. Keras) -- confirm.
        loss = 'binary_crossentropy'
    elif choice == 'dice':
        loss = dice_loss
    elif choice == 'multi_dice':
        loss = multiclass_dice_loss
    elif choice == 'w_mar':
        pos_class_weight = load_class_weights(root=root, split=split)
        loss = margin_loss(margin=0.4,
                           downweight=0.5,
                           pos_weight=pos_class_weight)
    elif choice == 'mar':
        loss = margin_loss(margin=0.4, downweight=0.5, pos_weight=1.0)
    else:
        # ValueError subclasses Exception, so existing broad handlers still
        # catch it; the message now names the offending value.
        raise ValueError("Unknown loss_type: {!r}".format(choice))

    if 'caps' in net:
        return ({'out_seg': loss, 'out_recon': 'mse'},
                {'out_seg': 1., 'out_recon': recon_wei})
    return loss, None
示例#2
0
def get_loss(root, split, net, choice):
    """Build the training loss selected by ``choice``.

    Args:
        root: Dataset root; forwarded to ``load_class_weights`` for the
            class-weighted choices ('w_bce', 'w_mar').
        split: Dataset split; forwarded to ``load_class_weights`` together
            with ``root``.
        net: Network name. Not used by this function.
        choice: Loss identifier; one of 'w_bce', 'bce', 'dice', 'w_mar',
            'mar'.

    Returns:
        ``(loss, None)`` -- the selected loss and no per-output loss weights.

    Raises:
        ValueError: if ``choice`` is not a recognised loss identifier.
    """
    if choice == 'w_bce':
        pos_class_weight = load_class_weights(root=root, split=split)
        loss = weighted_binary_crossentropy_loss(pos_class_weight)
    elif choice == 'bce':
        # Plain string identifier, presumably resolved by the training
        # framework (e.g. Keras) -- confirm.
        loss = 'binary_crossentropy'
    elif choice == 'dice':
        loss = dice_loss
    elif choice == 'w_mar':
        pos_class_weight = load_class_weights(root=root, split=split)
        loss = margin_loss(margin=0.4,
                           downweight=0.5,
                           pos_weight=pos_class_weight)
    elif choice == 'mar':
        loss = margin_loss(margin=0.4, downweight=0.5, pos_weight=1.0)
    else:
        # ValueError subclasses Exception, so existing broad handlers still
        # catch it; the message now names the offending value.
        raise ValueError("Unknown loss_type: {!r}".format(choice))

    return loss, None
def get_loss(root, split, net, recon_wei, choice):
    """Build the training loss selected by ``choice``.

    Args:
        root: Dataset root. Forwarded to ``load_class_weights`` for 'w_bce'
            and 'w_mar'; for 'w_spread' it also selects the class weights
            (``S_PRESENCES`` if it contains 'Spectralis', else
            ``C_PRESENCE``).
        split: Dataset split; forwarded to ``load_class_weights``.
        net: Network name. If it contains ``'caps'`` the loss is returned as
            a dict keyed by the ``out_seg`` output.
        recon_wei: Currently unused (reconstruction losses are disabled).
        choice: Loss identifier; one of 'w_bce', 'bce', 'dice', 'w_mar',
            'mar', 'cce', 'scce', 'spread', 'w_spread'.

    Returns:
        ``(loss, None)``. ``loss`` is ``{'out_seg': loss}`` for 'caps'
        networks and the bare loss otherwise; no per-output loss weights are
        returned.

    Raises:
        ValueError: if ``choice`` is not a recognised loss identifier.
    """
    if choice == 'w_bce':
        pos_class_weight = load_class_weights(root=root, split=split)
        loss = weighted_binary_crossentropy_loss(pos_class_weight)
    elif choice == 'bce':
        # Plain string identifiers ('binary_crossentropy', 'categorical_...')
        # are presumably resolved by the training framework -- confirm.
        loss = 'binary_crossentropy'
    elif choice == 'dice':
        loss = weighted_dice_loss(S_PRESENCE)
    elif choice == 'w_mar':
        pos_class_weight = load_class_weights(root=root, split=split)
        loss = margin_loss(margin=0.4,
                           downweight=0.5,
                           pos_weight=pos_class_weight)
    elif choice == 'mar':
        loss = margin_loss(margin=0.4, downweight=0.5, pos_weight=1.0)
    elif choice == 'cce':
        loss = 'categorical_crossentropy'
    elif choice == 'scce':
        loss = 'sparse_categorical_crossentropy'
    elif choice == 'spread':
        # EpochCounter.counter is evaluated here, when the loss is built.
        # NOTE(review): if it is a plain int this is a snapshot, not a live
        # epoch count -- confirm what spread_loss expects.
        loss = spread_loss(epoch_step=EpochCounter.counter)
    elif choice == 'w_spread':
        if 'Spectralis' in root:
            # NOTE(review): 'dice' uses S_PRESENCE (singular) while this uses
            # S_PRESENCES -- confirm both names exist and this is intended.
            weights = np.array(S_PRESENCES)
        else:
            weights = np.array(C_PRESENCE)
        loss = weighted_spread_loss(weights=weights,
                                    epoch_step=EpochCounter.counter)
    else:
        # ValueError subclasses Exception, so existing broad handlers still
        # catch it; the message now names the offending value.
        raise ValueError("Unknown loss_type: {!r}".format(choice))

    if 'caps' in net:
        # A previous variant also returned weighted-MSE losses for outputs
        # 'recon0'..'recon3' (weighted_mse_loss(S_PIXEL_MSE[i]), each weighted
        # by recon_wei). That is disabled: only the segmentation output gets
        # a loss, so recon_wei is unused.
        return {'out_seg': loss}, None
    return loss, None
示例#4
0
def get_loss(root, split, net, recon_wei, choice):
    """Build the training loss (and optional per-output loss weights).

    Args:
        root: Dataset root; forwarded to ``load_class_weights`` for the
            class-weighted choices ('w_bce', 'w_mar').
        split: Dataset split; forwarded to ``load_class_weights`` together
            with ``root``.
        net: Network name. If it contains the substring ``'caps'`` the model
            is treated as having two outputs, ``out_seg`` and ``out_recon``.
        recon_wei: Loss weight for the ``out_recon`` output (only used when
            ``net`` contains ``'caps'``).
        choice: Loss identifier; one of 'w_bce', 'bce', 'dice', 'w_mar',
            'mar'.

    Returns:
        ``(loss, loss_weights)``. For 'caps' networks both are dicts keyed by
        output name (``out_seg`` gets the selected loss and weight 1.0,
        ``out_recon`` gets ``'mse'`` and weight ``recon_wei``); otherwise
        ``loss`` is the selected loss and ``loss_weights`` is ``None``.

    Raises:
        ValueError: if ``choice`` is not a recognised loss identifier.
    """
    if choice == 'w_bce':
        pos_class_weight = load_class_weights(root=root, split=split)
        loss = weighted_binary_crossentropy_loss(pos_class_weight)
    elif choice == 'bce':
        # Plain string identifier, presumably resolved by the training
        # framework (e.g. Keras) -- confirm.
        loss = 'binary_crossentropy'
    elif choice == 'dice':
        loss = dice_loss
    elif choice == 'w_mar':
        pos_class_weight = load_class_weights(root=root, split=split)
        loss = margin_loss(margin=0.4, downweight=0.5, pos_weight=pos_class_weight)
    elif choice == 'mar':
        loss = margin_loss(margin=0.4, downweight=0.5, pos_weight=1.0)
    else:
        # ValueError subclasses Exception, so existing broad handlers still
        # catch it; the message now names the offending value.
        raise ValueError("Unknown loss_type: {!r}".format(choice))

    if 'caps' in net:
        return ({'out_seg': loss, 'out_recon': 'mse'},
                {'out_seg': 1., 'out_recon': recon_wei})
    return loss, None