Example #1
    def run_online_evaluation(self, output, target):
        # Hard (argmax-based) per-class tp/fp/fn for the foreground
        # classes, accumulated for the running online dice estimate.
        with torch.no_grad():
            num_classes = output.shape[1]
            output_softmax = softmax_helper(output)  # softmax over the class axis
            output_seg = output_softmax.argmax(1)
            target = target[:, 0]  # drop the channel axis of the label map
            axes = tuple(range(1, len(target.shape)))
            tp_hard = torch.zeros((target.shape[0], num_classes - 1),
                                  device=output_seg.device)
            fp_hard = torch.zeros((target.shape[0], num_classes - 1),
                                  device=output_seg.device)
            fn_hard = torch.zeros((target.shape[0], num_classes - 1),
                                  device=output_seg.device)
            for c in range(1, num_classes):  # class 0 is background
                tp_hard[:, c - 1] = sum_tensor(
                    (output_seg == c).float() * (target == c).float(),
                    axes=axes)
                fp_hard[:, c - 1] = sum_tensor(
                    (output_seg == c).float() * (target != c).float(),
                    axes=axes)
                fn_hard[:, c - 1] = sum_tensor(
                    (output_seg != c).float() * (target == c).float(),
                    axes=axes)

            tp_hard = tp_hard.sum(0, keepdim=False).detach().cpu().numpy()
            fp_hard = fp_hard.sum(0, keepdim=False).detach().cpu().numpy()
            fn_hard = fn_hard.sum(0, keepdim=False).detach().cpu().numpy()

            self.online_eval_foreground_dc.append(
                list((2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8)))
            self.online_eval_tp.append(list(tp_hard))
            self.online_eval_fp.append(list(fp_hard))
            self.online_eval_fn.append(list(fn_hard))
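
Note: Examples #1 through #6 call a sum_tensor helper that is not shown on this page (the sum_tensor in Example #7, which takes a name argument, is a different TensorFlow-metric helper of the same name). A minimal sketch consistent with how the PyTorch examples call it (sum over a tuple of axes, optional keepdim) could look like the following; the upstream implementation may differ in details.

import numpy as np
import torch

def sum_tensor(inp, axes, keepdim=False):
    # Reduce over every axis in `axes`. When dimensions are dropped,
    # summing the highest axis first keeps the remaining indices valid.
    axes = np.unique(axes).astype(int)
    if keepdim:
        for ax in axes:
            inp = inp.sum(int(ax), keepdim=True)
    else:
        for ax in sorted(axes, reverse=True):
            inp = inp.sum(int(ax))
    return inp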
Example #2
    def forward(self, x, y):
        shp_x = x.shape

        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))

        # The dice computation must run whether or not a nonlinearity is
        # applied; keeping it inside the `if` (as originally written)
        # silently returns a loss of 0 when apply_nonlin is None.
        net_output = x
        if self.apply_nonlin is not None:
            net_output = self.apply_nonlin(x)  # (b, c, x, y(, z))

        gt_onehot = gt2onehot(net_output, y, axes)  # (b, c, x, y(, z))
        intersection = sum_tensor(net_output * gt_onehot,
                                  axes,
                                  keepdim=False)
        pred_o = sum_tensor(net_output**2, axes, keepdim=False)
        ground_o = sum_tensor(gt_onehot**2, axes, keepdim=False)
        dc = 2.0 * (intersection + self.smooth) / (ground_o + pred_o +
                                                   self.smooth)

        if not self.do_bg:
            # drop the background class (first channel)
            if self.batch_dice:
                dc = dc[1:]
            else:
                dc = dc[:, 1:]
        dc = dc.mean()

        return -dc
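
A quick smoke test of the same soft Dice computation, with the class and its helpers replaced by inline equivalents (random logits, softmax as the nonlinearity, batch_dice=True, do_bg=False, and a hypothetical smooth of 1e-5):

import torch

x = torch.randn(2, 3, 16, 16)            # (b, c, x, y) logits
y = torch.randint(0, 3, (2, 1, 16, 16))  # (b, 1, x, y) integer labels

net_output = torch.softmax(x, dim=1)     # apply_nonlin
y_onehot = torch.zeros_like(net_output).scatter_(1, y, 1)

axes = (0,) + tuple(range(2, x.ndim))    # batch_dice=True
intersection = (net_output * y_onehot).sum(axes)
denom = (net_output ** 2).sum(axes) + (y_onehot ** 2).sum(axes)
dc = 2.0 * (intersection + 1e-5) / (denom + 1e-5)
print(-dc[1:].mean())                    # do_bg=False: drop background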
Example #3
def get_tp_fp_fn(net_output, gt, axes=None, mask=None, square=False):
    """
    net_output must be (b, c, x, y(, z)))
    gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z))
    if mask is provided it must have shape (b, 1, x, y(, z)))
    :param net_output:
    :param gt:
    :param axes:
    :param mask: mask must be 1 for valid pixels and 0 for invalid pixels
    :param square: if True then fp, tp and fn will be squared before summation
    :return:
    """
    if axes is None:
        axes = tuple(range(2, len(net_output.size())))

    shp_x = net_output.shape
    shp_y = gt.shape

    with torch.no_grad():
        if len(shp_x) != len(shp_y):
            gt = gt.view((shp_y[0], 1, *shp_y[1:]))

        if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
            # if this is the case then gt is probably already a one hot encoding
            y_onehot = gt
        else:
            gt = gt.long()
            y_onehot = torch.zeros(shp_x)
            if net_output.device.type == "cuda":
                y_onehot = y_onehot.cuda(net_output.device.index)
            y_onehot.scatter_(1, gt, 1)

    tp = net_output * y_onehot
    fp = net_output * (1 - y_onehot)
    fn = (1 - net_output) * y_onehot

    if mask is not None:
        # zero out invalid pixels; the (b, 1, ...) mask broadcasts
        # across the class axis
        tp = tp * mask
        fp = fp * mask
        fn = fn * mask

    if square:
        tp = tp**2
        fp = fp**2
        fn = fn**2

    tp = sum_tensor(tp, axes, keepdim=False)
    fp = sum_tensor(fp, axes, keepdim=False)
    fn = sum_tensor(fn, axes, keepdim=False)

    return tp, fp, fn
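
A hypothetical call, passing softmax probabilities and an integer label map so the one-hot branch is exercised (assumes sum_tensor is in scope):

import torch

net_output = torch.softmax(torch.randn(2, 3, 8, 8), dim=1)  # (b, c, x, y)
gt = torch.randint(0, 3, (2, 8, 8))                         # (b, x, y) label map

tp, fp, fn = get_tp_fp_fn(net_output, gt)
print(tp.shape)  # torch.Size([2, 3]): one count per sample and class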
Example #4
    def forward(self, x, y, loss_mask=None):
        shp_x = x.shape
        shp_y = y.shape

        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))

        if len(shp_x) != len(shp_y):
            y = y.view((shp_y[0], 1, *shp_y[1:]))

        if all([i == j for i, j in zip(x.shape, y.shape)]):
            # if this is the case then gt is probably already a one hot encoding
            y_onehot = y
        else:
            gt = y.long()
            y_onehot = torch.zeros(shp_x)
            if x.device.type == "cuda":
                y_onehot = y_onehot.cuda(x.device.index)
            y_onehot.scatter_(1, gt, 1)

        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)

        if not self.do_bg:
            x = x[:, 1:]
            y_onehot = y_onehot[:, 1:]

        tp, fp, fn, _ = get_tp_fp_fn_tn(x, y_onehot, axes, loss_mask,
                                        self.square)

        # GDL weight computation, we use 1/V
        volumes = sum_tensor(
            y_onehot, axes) + 1e-6  # add some eps to prevent div by zero

        if self.square_volumes:
            volumes = volumes**2

        # apply weights
        tp = tp / volumes
        fp = fp / volumes
        fn = fn / volumes

        # sum over classes
        if self.batch_dice:
            axis = 0
        else:
            axis = 1

        tp = tp.sum(axis, keepdim=False)
        fp = fp.sum(axis, keepdim=False)
        fn = fn.sum(axis, keepdim=False)

        # compute dice
        dc = (2 * tp + self.smooth) / (2 * tp + fp + fn + self.smooth)

        dc = dc.mean()

        return -dc
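
The 1/V weighting is what makes this a generalized Dice loss: each class's counts are divided by that class's ground-truth volume, so small structures weigh as much as large ones. A toy illustration with made-up numbers:

import torch

volumes = torch.tensor([1000.0, 10.0]) + 1e-6  # large vs. small class
tp = torch.tensor([900.0, 9.0])                # 90% of each class recovered

# Unweighted, a sum over classes is dominated by the large class;
# after dividing by volume both classes contribute equally.
print(tp / volumes)  # tensor([0.9000, 0.9000])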
Example #5
    def forward(self, x, y, loss_mask=None):
        shp_x = x.shape
        shp_y = y.shape

        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))

        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)

        with torch.no_grad():
            if len(shp_x) != len(shp_y):
                y = y.view((shp_y[0], 1, *shp_y[1:]))

            if all([i == j for i, j in zip(x.shape, y.shape)]):
                # if this is the case then gt is probably already a one hot encoding
                y_onehot = y
            else:
                y = y.long()
                y_onehot = torch.zeros(shp_x)
                if x.device.type == "cuda":
                    y_onehot = y_onehot.cuda(x.device.index)
                y_onehot.scatter_(1, y, 1)  # in-place; zeros are already float

        intersect = x * y_onehot
        # values in the denominator get smoothed
        denominator = x**2 + y_onehot**2

        # aggregation was previously done in get_tp_fp_fn, but needs to be done here now (needs to be done after
        # squaring)
        intersect = sum_tensor(intersect, axes, False) + self.smooth
        denominator = sum_tensor(denominator, axes, False) + self.smooth

        dc = 2 * intersect / denominator

        if not self.do_bg:
            if self.batch_dice:
                dc = dc[1:]
            else:
                dc = dc[:, 1:]
        dc = dc.mean()

        return -dc
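
The one-hot conversion used here (and in Examples #3 and #4) relies on scatter_ writing a 1 into the channel given by the label at each pixel. A standalone demonstration on a 2x2 label map with three classes:

import torch

y = torch.tensor([[[0, 2], [1, 1]]]).unsqueeze(1)  # (b=1, 1, 2, 2) labels
y_onehot = torch.zeros((1, 3, 2, 2))
y_onehot.scatter_(1, y, 1)       # place a 1 at each pixel's label channel
print(y_onehot[0, :, 0, 0])      # tensor([1., 0., 0.]) -> label 0
print(y_onehot[0, :, 0, 1])      # tensor([0., 0., 1.]) -> label 2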
Example #6
    def run_online_evaluation(self, output, target):
        with torch.no_grad():
            # output and target arrive as lists here (deep supervision);
            # only the first (full-resolution) element is evaluated
            num_classes = output[0].shape[1]
            output_seg = output[0].argmax(1)
            target = target[0][:, 0]
            axes = tuple(range(1, len(target.shape)))
            tp_hard = torch.zeros((target.shape[0], num_classes - 1),
                                  device=output_seg.device)
            fp_hard = torch.zeros((target.shape[0], num_classes - 1),
                                  device=output_seg.device)
            fn_hard = torch.zeros((target.shape[0], num_classes - 1),
                                  device=output_seg.device)
            for c in range(1, num_classes):  # class 0 is background
                tp_hard[:, c - 1] = sum_tensor(
                    (output_seg == c).float() * (target == c).float(),
                    axes=axes)
                fp_hard[:, c - 1] = sum_tensor(
                    (output_seg == c).float() * (target != c).float(),
                    axes=axes)
                fn_hard[:, c - 1] = sum_tensor(
                    (output_seg != c).float() * (target == c).float(),
                    axes=axes)

            # add a leading axis so each rank contributes one row, then
            # gather the per-class counts from all DDP workers
            tp_hard = tp_hard.sum(0, keepdim=False)[None]
            fp_hard = fp_hard.sum(0, keepdim=False)[None]
            fn_hard = fn_hard.sum(0, keepdim=False)[None]

            tp_hard = awesome_allgather_function.apply(tp_hard)
            fp_hard = awesome_allgather_function.apply(fp_hard)
            fn_hard = awesome_allgather_function.apply(fn_hard)

        tp_hard = tp_hard.detach().cpu().numpy().sum(0)
        fp_hard = fp_hard.detach().cpu().numpy().sum(0)
        fn_hard = fn_hard.detach().cpu().numpy().sum(0)
        self.online_eval_foreground_dc.append(
            list((2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8)))
        self.online_eval_tp.append(list(tp_hard))
        self.online_eval_fp.append(list(fp_hard))
        self.online_eval_fn.append(list(fn_hard))
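
awesome_allgather_function is not defined on this page. Judging from how it is used (gathering same-shaped per-rank metric tensors along dim 0 under DDP), it plausibly behaves like a differentiable all_gather; a sketch under that assumption, requiring an initialized torch.distributed process group:

import torch
import torch.distributed as dist

class AllGather(torch.autograd.Function):
    # Gathers same-shaped tensors from all ranks along dim 0 and, on
    # the way back, routes each rank its own slice of the gradient.
    @staticmethod
    def forward(ctx, tensor):
        world_size = dist.get_world_size()
        gathered = [torch.empty_like(tensor) for _ in range(world_size)]
        dist.all_gather(gathered, tensor)
        return torch.cat(gathered, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        chunk = grad_output.shape[0] // dist.get_world_size()
        rank = dist.get_rank()
        return grad_output[rank * chunk:(rank + 1) * chunk]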
Example #7
def trainmodel(trainset,
               testset,
               all_labels,
               num_freqs,
               block_len,
               write_dir,
               config,
               split_i=None):

    num_categories = len(all_labels)
    shuffle_buffer_size = config.getint('Dataset', 'shuffle_buffer_size')
    n_epochs = config.getint('Dataset', 'n_epochs')
    batch_size = config.getint('Dataset', 'batch_size')

    tf.compat.v1.reset_default_graph()

    linear_to_mel_weight_matrix = tf.contrib.signal.linear_to_mel_weight_matrix(
        num_mel_bins=config.getint('Spectrogram', 'mel_bands'),
        num_spectrogram_bins=num_freqs,
        sample_rate=config.getint('Spectrogram', 'sample_rate'),
        lower_edge_hertz=config.getfloat('Spectrogram', 'mel_min'),
        upper_edge_hertz=config.getfloat('Spectrogram', 'mel_max'))

    # create datasets and iterators
    dataset_train = tf.data.TFRecordDataset([f[0] for f in trainset]). \
        apply(tf.data.experimental.shuffle_and_repeat(shuffle_buffer_size, n_epochs)). \
        map(lambda x: dataset_fft_to_mel_single(x, block_len, num_freqs, linear_to_mel_weight_matrix, all_labels, True), num_parallel_calls=4). \
        batch(batch_size). \
        prefetch(5)

    dataset_test = tf.data.TFRecordDataset([f[0] for f in testset]).\
        map(lambda x: dataset_fft_to_mel_multi(x, block_len, num_freqs, linear_to_mel_weight_matrix, all_labels), num_parallel_calls=4). \
        flat_map(lambda *x: tf.data.Dataset.from_tensor_slices(x)). \
        batch(batch_size). \
        prefetch(5)

    handle = tf.compat.v1.placeholder(tf.string,
                                      shape=[],
                                      name='iterator_handle')
    iterator = tf.compat.v1.data.Iterator.from_string_handle(
        handle, dataset_train.output_types, dataset_train.output_shapes)
    x, y = iterator.get_next(name='xinput')

    training_iterator = tf.compat.v1.data.make_initializable_iterator(
        dataset_train)
    test_iterator = tf.compat.v1.data.make_initializable_iterator(dataset_test)

    # create model
    with slim.arg_scope(
            residual_parameters(weight_decay=0.00004,
                                batch_norm_decay=0.9997,
                                batch_norm_epsilon=0.001,
                                reg=3,
                                elu=True)):
        logits, predictions = residual_model(x,
                                             num_categories,
                                             is_training=True,
                                             reuse=False,
                                             is_music_layer1=True,
                                             n_filt=4,
                                             n_res_lay=4,
                                             layer1_size=10,
                                             use_max_pool=[1, 2, 1, 2],
                                             use_atrous=False,
                                             reg=3,
                                             weight_decay=0.00004)
        test_logits, test_predictions = residual_model(
            x,
            num_categories,
            is_training=False,
            reuse=True,
            is_music_layer1=True,
            n_filt=4,
            n_res_lay=4,
            layer1_size=10,
            use_max_pool=[1, 2, 1, 2],
            use_atrous=False,
            reg=3,
            weight_decay=0.00004)

    size = lambda v: reduce(lambda x, y: x * y, v.get_shape().as_list())
    n = sum(size(v) for v in tf.compat.v1.trainable_variables())
    print("Total trainable parameters: ", n)

    one_hot_labels = slim.one_hot_encoding(y, num_categories)
    loss = tf.compat.v1.losses.softmax_cross_entropy(one_hot_labels, logits)

    total_loss = tf.compat.v1.losses.get_total_loss(
        add_regularization_losses=True)

    global_step = tf.compat.v1.train.get_or_create_global_step()

    learning_rate = tf.compat.v1.train.exponential_decay(0.1,
                                                         global_step,
                                                         500,
                                                         0.96,
                                                         staircase=True)
    optimizer = tf.compat.v1.train.GradientDescentOptimizer(
        learning_rate=learning_rate)

    train_op = slim.learning.create_train_op(total_loss,
                                             optimizer,
                                             global_step=global_step)

    predidx = tf.argmax(input=predictions, axis=1, name='predidx')
    train_measures, train_updates = slim.metrics.aggregate_metric_map({
        'train/Count':
        tf.contrib.metrics.count(y, name='train_c'),
        'train/Accuracy':
        tf.compat.v1.metrics.accuracy(y, predidx, name='train_a'),
        'train/PerClass':
        tf.compat.v1.metrics.mean_per_class_accuracy(y,
                                                     predidx,
                                                     num_categories,
                                                     name='train_ca'),
        'train/ConfusionMT':
        sum_tensor(tf.math.confusion_matrix(labels=y,
                                            predictions=predidx,
                                            num_classes=num_categories,
                                            name='train_cm'),
                   name='train_scm')
    })
    train_measures['train/LearningRate'] = learning_rate

    test_predidx = tf.argmax(input=test_predictions,
                             axis=1,
                             name='test_predidx')
    test_measures, test_updates = slim.metrics.aggregate_metric_map({
        'test/Count':
        tf.contrib.metrics.count(y, name='test_c'),
        'test/Accuracy':
        tf.compat.v1.metrics.accuracy(y, test_predidx, name='test_a'),
        'test/PerClass':
        tf.compat.v1.metrics.mean_per_class_accuracy(y,
                                                     test_predidx,
                                                     num_categories,
                                                     name='test_ca'),
        'test/ConfusionMT':
        sum_tensor(tf.math.confusion_matrix(labels=y,
                                            predictions=test_predidx,
                                            num_classes=num_categories,
                                            name='test_cm'),
                   name='test_scm')
    })

    all_m = dict(train_measures)
    all_m.update(test_measures)
    summary = add_summary_ops(all_m,
                              add_variables=True,
                              confmat_size=num_categories)

    config_proto = tf.compat.v1.ConfigProto()
    config_proto.gpu_options.allow_growth = True

    sum_dir = write_dir + ('' if split_i is None else '/split' + str(split_i))
    tf.io.gfile.makedirs(sum_dir)

    with tf.compat.v1.train.MonitoredTrainingSession(
            checkpoint_dir=sum_dir, config=config_proto) as sess:
        training_handle = sess.run(training_iterator.string_handle())
        test_handle = sess.run(test_iterator.string_handle())
        sess.run(training_iterator.initializer)
        for var in sess.graph.get_collection(
                tf.compat.v1.GraphKeys.METRIC_VARIABLES):
            sess.run(var.initializer)

        while not sess.should_stop():

            # do training
            [loss_val, pp,
             yy] = sess.run([train_op, predidx, y],
                            feed_dict={'iterator_handle:0': training_handle})
            # update training measures, re-feeding the predictions and
            # labels just fetched so the metric ops see the same batch
            # ('xinput:1' is the label output of the iterator)
            upd = sess.run(train_updates,
                           feed_dict={
                               'predidx:0': pp,
                               'xinput:1': yy,
                               handle: training_handle
                           })
            gs = sess.run(global_step)
            if gs % 500 == 0:
                print(gs, loss_val)
                print(upd['train/Accuracy'])
                if gs % 1000 == 0 or sess.should_stop():
                    # every 1000 iterations, do testing and update test measures
                    acc = test_model(sess, test_iterator, test_handle,
                                     test_predidx, y, test_updates)
                    # also reset training stats
                    for var in [
                            x for x in sess.graph.get_collection(
                                tf.compat.v1.GraphKeys.METRIC_VARIABLES)
                            if x.name.startswith('train')
                    ]:
                        sess.run(var.initializer)

    # `upd` holds the last *train* update dict at this point, so the
    # test metrics cannot be read from it; report the accuracy returned
    # by the final test_model call instead
    print('final test accuracy:', acc)

    return acc
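
The config keys read at the top of trainmodel imply an INI file with [Dataset] and [Spectrogram] sections. A hypothetical minimal configuration (the values are placeholders, not the author's settings):

import configparser

config = configparser.ConfigParser()
config.read_dict({
    'Dataset': {
        'shuffle_buffer_size': '1000',
        'n_epochs': '10',
        'batch_size': '32',
    },
    'Spectrogram': {
        'mel_bands': '64',
        'sample_rate': '22050',
        'mel_min': '125.0',
        'mel_max': '7500.0',
    },
})
# acc = trainmodel(trainset, testset, all_labels, num_freqs,
#                  block_len, write_dir, config, split_i=0)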