Пример #1
0
    def run_online_evaluation(self, output, target):
        """Record hard per-class TP/FP/FN counts and foreground Dice for one batch.

        The hard prediction is the argmax over dimension 1 of
        ``softmax_helper(output)``; the reference labels are read from
        ``target[:, 0]``. Class 0 is skipped, so every statistic has
        ``num_classes - 1`` entries, where ``num_classes`` is
        ``output.shape[1]``.

        :param output: network output with the class dimension at index 1
        :param target: label tensor; index 0 along dimension 1 is used as the
            label map
        :return: None. One list per statistic is appended to
            ``self.online_eval_foreground_dc``, ``self.online_eval_tp``,
            ``self.online_eval_fp`` and ``self.online_eval_fn``.
        """
        with torch.no_grad():
            num_classes = output.shape[1]
            output_seg = softmax_helper(output).argmax(1)
            target = target[:, 0]
            axes = tuple(range(1, len(target.shape)))

            # Allocate the accumulators directly on the prediction's device
            # instead of creating them on the host and moving them with
            # `.to(device.index)` (the index is None for CPU tensors).
            stats_shape = (target.shape[0], num_classes - 1)
            device = output_seg.device
            tp_hard = torch.zeros(stats_shape, device=device)
            fp_hard = torch.zeros(stats_shape, device=device)
            fn_hard = torch.zeros(stats_shape, device=device)

            for c in range(1, num_classes):
                # 0/1 masks, computed once per class and reused for all three
                # statistics; `1 - mask` is the exact complement.
                pred_is_c = (output_seg == c).float()
                ref_is_c = (target == c).float()
                tp_hard[:, c - 1] = sum_tensor(pred_is_c * ref_is_c,
                                               axes=axes)
                fp_hard[:, c - 1] = sum_tensor(pred_is_c * (1 - ref_is_c),
                                               axes=axes)
                fn_hard[:, c - 1] = sum_tensor((1 - pred_is_c) * ref_is_c,
                                               axes=axes)

            # Collapse the batch dimension and move the counts to numpy.
            tp_hard = tp_hard.sum(0, keepdim=False).detach().cpu().numpy()
            fp_hard = fp_hard.sum(0, keepdim=False).detach().cpu().numpy()
            fn_hard = fn_hard.sum(0, keepdim=False).detach().cpu().numpy()

            # The 1e-8 keeps the ratio finite for classes with no counts.
            self.online_eval_foreground_dc.append(
                list((2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8)))
            self.online_eval_tp.append(list(tp_hard))
            self.online_eval_fp.append(list(fp_hard))
            self.online_eval_fn.append(list(fn_hard))
Пример #2
0
    def forward(self, x, y):
        """Return the negative mean Dice score with squared terms in the denominator.

        Per reduced entry the score is
        ``2 * (sum(p * g) + smooth) / (sum(g**2) + sum(p**2) + smooth)``,
        where ``p`` is the (optionally non-linearly transformed) network
        output and ``g`` is the one-hot ground truth from ``gt2onehot``.

        Previously the whole computation was skipped when
        ``self.apply_nonlin`` was ``None`` and the method returned the Python
        int ``0`` instead of a loss tensor. The input is now used unchanged
        in that case.

        :param x: network output, class dimension at index 1
        :param y: ground truth passed to ``gt2onehot``
        :return: scalar tensor, the negated mean Dice
        """
        shp_x = x.shape

        # Reduce over the spatial dimensions, and over the batch dimension
        # as well when batch_dice is set.
        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))

        if self.apply_nonlin is not None:
            net_output = self.apply_nonlin(x)  # (b,c,x,y,z)
        else:
            net_output = x

        gt_onehot = gt2onehot(net_output, y, axes)  # (b,c,x,y,z)
        intersection = sum_tensor(net_output * gt_onehot, axes, keepdim=False)
        pred_o = sum_tensor(net_output**2, axes, keepdim=False)
        ground_o = sum_tensor(gt_onehot**2, axes, keepdim=False)
        dc = 2.0 * (intersection + self.smooth) / (ground_o + pred_o +
                                                   self.smooth)

        if not self.do_bg:
            # Drop class 0; with batch_dice the batch dimension has been
            # reduced away, so the class dimension is dimension 0.
            if self.batch_dice:
                dc = dc[1:]
            else:
                dc = dc[:, 1:]

        return -dc.mean()
Пример #3
0
def get_tp_fp_fn(net_output, gt, axes=None, mask=None, square=False):
    """
    Compute soft true positives, false positives and false negatives.

    net_output must be (b, c, x, y(, z)))
    gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z))
    if mask is provided it must have shape (b, 1, x, y(, z)))
    :param net_output: network output, class dimension at index 1
    :param gt: label map or one hot encoding (see above)
    :param axes: axes to sum over; defaults to every axis from index 2 on
    :param mask: mask must be 1 for valid pixels and 0 for invalid pixels
    :param square: if True then fp, tp and fn will be squared before summation
    :return: tuple (tp, fp, fn), each reduced over ``axes`` by ``sum_tensor``
    """
    if axes is None:
        axes = tuple(range(2, len(net_output.size())))

    shp_x = net_output.shape
    shp_y = gt.shape

    with torch.no_grad():
        if len(shp_x) != len(shp_y):
            # label map without a channel axis: insert one at index 1
            gt = gt.view((shp_y[0], 1, *shp_y[1:]))

        if all(i == j for i, j in zip(net_output.shape, gt.shape)):
            # if this is the case then gt is probably already a one hot encoding
            y_onehot = gt
        else:
            # Build the one hot encoding directly on net_output's device so
            # it works for any device type, not only CUDA, and skips the
            # host allocation followed by a copy.
            y_onehot = torch.zeros(shp_x, device=net_output.device)
            y_onehot.scatter_(1, gt.long(), 1)

    tp = net_output * y_onehot
    fp = net_output * (1 - y_onehot)
    fn = (1 - net_output) * y_onehot

    if mask is not None:
        # Keep channel 0 of the mask as a singleton class axis and let
        # broadcasting apply it to every class channel at once.
        valid = mask[:, :1]
        tp = tp * valid
        fp = fp * valid
        fn = fn * valid

    if square:
        tp = tp**2
        fp = fp**2
        fn = fn**2

    tp = sum_tensor(tp, axes, keepdim=False)
    fp = sum_tensor(fp, axes, keepdim=False)
    fn = sum_tensor(fn, axes, keepdim=False)

    return tp, fp, fn
Пример #4
0
    def forward(self, x, y, loss_mask=None):
        """Return the negative volume-weighted (GDL) Dice of ``x`` against ``y``.

        tp/fp/fn from ``get_tp_fp_fn_tn`` are divided by the per-class ground
        truth volume (optionally squared), summed over the class dimension
        and combined as
        ``(2 * tp + smooth) / (2 * tp + fp + fn + smooth)``.

        :param x: network output, class dimension at index 1
        :param y: label map (with or without a channel axis) or a tensor with
            the same shape as ``x``, which is then used as the one hot
            encoding directly
        :param loss_mask: optional mask forwarded to ``get_tp_fp_fn_tn``
        :return: scalar tensor, the negated mean Dice
        """
        shp_x = x.shape
        shp_y = y.shape

        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))

        if len(shp_x) != len(shp_y):
            # label map without a channel axis: insert one at index 1
            y = y.view((shp_y[0], 1, *shp_y[1:]))

        if all(i == j for i, j in zip(x.shape, y.shape)):
            # if this is the case then gt is probably already a one hot encoding
            y_onehot = y
        else:
            # Allocate on x's device directly so any device type is handled,
            # not only CUDA.
            y_onehot = torch.zeros(shp_x, device=x.device)
            y_onehot.scatter_(1, y.long(), 1)

        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)

        if not self.do_bg:
            # drop class 0 from both prediction and reference
            x = x[:, 1:]
            y_onehot = y_onehot[:, 1:]

        tp, fp, fn, _ = get_tp_fp_fn_tn(x, y_onehot, axes, loss_mask,
                                        self.square)

        # GDL weight computation, we use 1/V
        volumes = sum_tensor(
            y_onehot, axes) + 1e-6  # add some eps to prevent div by zero

        if self.square_volumes:
            volumes = volumes**2

        # apply weights
        tp = tp / volumes
        fp = fp / volumes
        fn = fn / volumes

        # sum over classes: with batch_dice the batch dimension is already
        # reduced, so the class dimension is dimension 0, otherwise 1
        axis = 0 if self.batch_dice else 1

        tp = tp.sum(axis, keepdim=False)
        fp = fp.sum(axis, keepdim=False)
        fn = fn.sum(axis, keepdim=False)

        # compute dice
        dc = (2 * tp + self.smooth) / (2 * tp + fp + fn + self.smooth)

        return -dc.mean()
Пример #5
0
    def forward(self, x, y, loss_mask=None):
        """Return the negative mean Dice score with squared terms in the denominator.

        Per reduced entry the score is
        ``2 * (sum(x * g) + smooth) / (sum(x**2 + g**2) + smooth)``, where
        ``g`` is the one hot ground truth.

        :param x: network output, class dimension at index 1
        :param y: label map (with or without a channel axis) or a tensor with
            the same shape as ``x``, which is then used as the one hot
            encoding directly
        :param loss_mask: accepted for interface compatibility; not used by
            this method
        :return: scalar tensor, the negated mean Dice
        """
        shp_x = x.shape
        shp_y = y.shape

        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))

        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)

        with torch.no_grad():
            if len(shp_x) != len(shp_y):
                # label map without a channel axis: insert one at index 1
                y = y.view((shp_y[0], 1, *shp_y[1:]))

            if all(i == j for i, j in zip(x.shape, y.shape)):
                # if this is the case then gt is probably already a one hot encoding
                y_onehot = y
            else:
                # Allocate on x's device directly so any device type is
                # handled, not only CUDA. scatter_ works in place, so the
                # former trailing `.float()` (whose result was discarded)
                # is gone.
                y_onehot = torch.zeros(shp_x, device=x.device)
                y_onehot.scatter_(1, y.long(), 1)

        intersect = x * y_onehot
        # values in the denominator get smoothed
        denominator = x**2 + y_onehot**2

        # aggregation was previously done in get_tp_fp_fn, but needs to be done here now (needs to be done after
        # squaring)
        intersect = sum_tensor(intersect, axes, False) + self.smooth
        denominator = sum_tensor(denominator, axes, False) + self.smooth

        dc = 2 * intersect / denominator

        if not self.do_bg:
            # Drop class 0; with batch_dice the batch dimension has been
            # reduced away, so the class dimension is dimension 0.
            if self.batch_dice:
                dc = dc[1:]
            else:
                dc = dc[:, 1:]

        return -dc.mean()
Пример #6
0
    def run_online_evaluation(self, output, target):
        """Record hard per-class TP/FP/FN counts and foreground Dice for one batch.

        Only the first element of ``output`` and of ``target`` is evaluated.
        The hard prediction is ``output[0].argmax(1)``; the reference labels
        are ``target[0][:, 0]``. Class 0 is skipped. The batch-summed counts
        are passed through ``awesome_allgather_function`` and then summed
        over the gathered leading dimension.

        :param output: sequence whose first element is the network output
            with the class dimension at index 1
        :param target: sequence whose first element is the label tensor;
            index 0 along its dimension 1 is used as the label map
        :return: None. One list per statistic is appended to
            ``self.online_eval_foreground_dc``, ``self.online_eval_tp``,
            ``self.online_eval_fp`` and ``self.online_eval_fn``.
        """
        with torch.no_grad():
            num_classes = output[0].shape[1]
            output_seg = output[0].argmax(1)
            target = target[0][:, 0]
            axes = tuple(range(1, len(target.shape)))

            # Allocate the accumulators directly on the prediction's device
            # instead of creating them on the host and moving them with
            # `.to(device.index)` (the index is None for CPU tensors).
            stats_shape = (target.shape[0], num_classes - 1)
            device = output_seg.device
            tp_hard = torch.zeros(stats_shape, device=device)
            fp_hard = torch.zeros(stats_shape, device=device)
            fn_hard = torch.zeros(stats_shape, device=device)

            for c in range(1, num_classes):
                # 0/1 masks, computed once per class and reused for all three
                # statistics; `1 - mask` is the exact complement.
                pred_is_c = (output_seg == c).float()
                ref_is_c = (target == c).float()
                tp_hard[:, c - 1] = sum_tensor(pred_is_c * ref_is_c,
                                               axes=axes)
                fp_hard[:, c - 1] = sum_tensor(pred_is_c * (1 - ref_is_c),
                                               axes=axes)
                fn_hard[:, c - 1] = sum_tensor((1 - pred_is_c) * ref_is_c,
                                               axes=axes)

            # Sum over the batch and add a leading axis of size 1 before
            # handing the counts to the all-gather function.
            tp_hard = tp_hard.sum(0, keepdim=False)[None]
            fp_hard = fp_hard.sum(0, keepdim=False)[None]
            fn_hard = fn_hard.sum(0, keepdim=False)[None]

            tp_hard = awesome_allgather_function.apply(tp_hard)
            fp_hard = awesome_allgather_function.apply(fp_hard)
            fn_hard = awesome_allgather_function.apply(fn_hard)

        # Collapse the gathered leading dimension on the host.
        tp_hard = tp_hard.detach().cpu().numpy().sum(0)
        fp_hard = fp_hard.detach().cpu().numpy().sum(0)
        fn_hard = fn_hard.detach().cpu().numpy().sum(0)

        # The 1e-8 keeps the ratio finite for classes with no counts.
        self.online_eval_foreground_dc.append(
            list((2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8)))
        self.online_eval_tp.append(list(tp_hard))
        self.online_eval_fp.append(list(fp_hard))
        self.online_eval_fn.append(list(fn_hard))
Пример #7
0
def trainmodel(trainset,
               testset,
               all_labels,
               num_freqs,
               block_len,
               write_dir,
               config,
               split_i=None):
    """Build the TF1 graph for the residual classifier and run the training loop.

    The function resets the default graph, creates train/test TFRecord input
    pipelines behind one feedable iterator handle, instantiates the residual
    model twice (training mode and inference mode with ``reuse=True``),
    sets up a softmax cross-entropy loss plus regularization losses, plain
    gradient descent with a staircase exponentially decaying learning rate
    (start 0.1, factor 0.96 every 500 steps), streaming train/test metrics
    and summaries, and then trains inside a ``MonitoredTrainingSession``
    until the session signals stop.

    :param trainset: iterable of items whose element ``[0]`` is passed to
        ``TFRecordDataset`` (presumably a TFRecord file path - confirm)
    :param testset: same structure as ``trainset``, used for evaluation
    :param all_labels: label collection; its length is the number of output
        categories and it is forwarded to the dataset mapping functions
    :param num_freqs: number of spectrogram bins fed to the mel weight matrix
    :param block_len: forwarded to the dataset mapping functions
    :param write_dir: base directory for checkpoints/summaries
    :param config: object providing ``getint``/``getfloat`` with the
        sections 'Dataset' and 'Spectrogram' (presumably a ConfigParser)
    :param split_i: optional split index; when given, output goes to
        ``write_dir + '/split' + str(split_i)``
    :return: the value returned by the most recent ``test_model`` call
    """

    num_categories = len(all_labels)
    shuffle_buffer_size = config.getint('Dataset', 'shuffle_buffer_size')
    n_epochs = config.getint('Dataset', 'n_epochs')
    batch_size = config.getint('Dataset', 'batch_size')

    tf.compat.v1.reset_default_graph()

    # Weight matrix for converting linear-frequency spectrogram bins to mel
    # bands; handed to both dataset mapping functions below.
    linear_to_mel_weight_matrix = tf.contrib.signal.linear_to_mel_weight_matrix(
        num_mel_bins=config.getint('Spectrogram', 'mel_bands'),
        num_spectrogram_bins=num_freqs,
        sample_rate=config.getint('Spectrogram', 'sample_rate'),
        lower_edge_hertz=config.getfloat('Spectrogram', 'mel_min'),
        upper_edge_hertz=config.getfloat('Spectrogram', 'mel_max'))

    # create datasets and iterators
    # Training data is shuffled and repeated for n_epochs; the test pipeline
    # flattens the multiple slices produced per record into single examples.
    dataset_train = tf.data.TFRecordDataset([f[0] for f in trainset]). \
        apply(tf.data.experimental.shuffle_and_repeat(shuffle_buffer_size, n_epochs)). \
        map(lambda x: dataset_fft_to_mel_single(x, block_len, num_freqs, linear_to_mel_weight_matrix, all_labels, True), num_parallel_calls=4). \
        batch(batch_size). \
        prefetch(5)

    dataset_test = tf.data.TFRecordDataset([f[0] for f in testset]).\
        map(lambda x: dataset_fft_to_mel_multi(x, block_len, num_freqs, linear_to_mel_weight_matrix, all_labels), num_parallel_calls=4). \
        flat_map(lambda *x: tf.data.Dataset.from_tensor_slices(x)). \
        batch(batch_size). \
        prefetch(5)

    # Feedable iterator: the string handle fed at run time selects whether
    # x, y come from the training or the test iterator.
    handle = tf.compat.v1.placeholder(tf.string,
                                      shape=[],
                                      name='iterator_handle')
    iterator = tf.compat.v1.data.Iterator.from_string_handle(
        handle, dataset_train.output_types, dataset_train.output_shapes)
    x, y = iterator.get_next(name='xinput')

    training_iterator = tf.compat.v1.data.make_initializable_iterator(
        dataset_train)
    test_iterator = tf.compat.v1.data.make_initializable_iterator(dataset_test)

    # create model
    # The second instantiation uses reuse=True and is_training=False, i.e.
    # the same variables in inference mode.
    with slim.arg_scope(
            residual_parameters(weight_decay=0.00004,
                                batch_norm_decay=0.9997,
                                batch_norm_epsilon=0.001,
                                reg=3,
                                elu=True)):
        logits, predictions = residual_model(x,
                                             num_categories,
                                             is_training=True,
                                             reuse=False,
                                             is_music_layer1=True,
                                             n_filt=4,
                                             n_res_lay=4,
                                             layer1_size=10,
                                             use_max_pool=[1, 2, 1, 2],
                                             use_atrous=False,
                                             reg=3,
                                             weight_decay=0.00004)
        test_logits, test_predictions = residual_model(
            x,
            num_categories,
            is_training=False,
            reuse=True,
            is_music_layer1=True,
            n_filt=4,
            n_res_lay=4,
            layer1_size=10,
            use_max_pool=[1, 2, 1, 2],
            use_atrous=False,
            reg=3,
            weight_decay=0.00004)

    # Count trainable parameters: product of each variable's static shape.
    size = lambda v: reduce(lambda x, y: x * y, v.get_shape().as_list())
    n = sum(size(v) for v in tf.compat.v1.trainable_variables())
    print("Total trainable parameters: ", n)

    one_hot_labels = slim.one_hot_encoding(y, num_categories)
    loss = tf.compat.v1.losses.softmax_cross_entropy(one_hot_labels, logits)

    # Cross-entropy (registered in the losses collection above) plus the
    # regularization losses.
    total_loss = tf.compat.v1.losses.get_total_loss(
        add_regularization_losses=True)

    global_step = tf.compat.v1.train.get_or_create_global_step()

    # Learning rate: 0.1, multiplied by 0.96 every 500 steps (staircase).
    learning_rate = tf.compat.v1.train.exponential_decay(0.1,
                                                         global_step,
                                                         500,
                                                         0.96,
                                                         staircase=True)
    optimizer = tf.compat.v1.train.GradientDescentOptimizer(
        learning_rate=learning_rate)

    train_op = slim.learning.create_train_op(total_loss,
                                             optimizer,
                                             global_step=global_step)

    # Streaming training metrics, keyed by their summary names.
    predidx = tf.argmax(input=predictions, axis=1, name='predidx')
    train_measures, train_updates = slim.metrics.aggregate_metric_map({
        'train/Count':
        tf.contrib.metrics.count(y, name='train_c'),
        'train/Accuracy':
        tf.compat.v1.metrics.accuracy(y, predidx, name='train_a'),
        'train/PerClass':
        tf.compat.v1.metrics.mean_per_class_accuracy(y,
                                                     predidx,
                                                     num_categories,
                                                     name='train_ca'),
        'train/ConfusionMT':
        sum_tensor(tf.math.confusion_matrix(labels=y,
                                            predictions=predidx,
                                            num_classes=num_categories,
                                            name='train_cm'),
                   name='train_scm')
    })
    train_measures['train/LearningRate'] = learning_rate

    # Same set of metrics for the inference-mode model.
    test_predidx = tf.argmax(input=test_predictions,
                             axis=1,
                             name='test_predidx')
    test_measures, test_updates = slim.metrics.aggregate_metric_map({
        'test/Count':
        tf.contrib.metrics.count(y, name='test_c'),
        'test/Accuracy':
        tf.compat.v1.metrics.accuracy(y, test_predidx, name='test_a'),
        'test/PerClass':
        tf.compat.v1.metrics.mean_per_class_accuracy(y,
                                                     test_predidx,
                                                     num_categories,
                                                     name='test_ca'),
        'test/ConfusionMT':
        sum_tensor(tf.math.confusion_matrix(labels=y,
                                            predictions=test_predidx,
                                            num_classes=num_categories,
                                            name='test_cm'),
                   name='test_scm')
    })

    # Merge train and test measures into one dict for the summary ops.
    all_m = dict(train_measures)
    all_m.update(test_measures)
    summary = add_summary_ops(all_m,
                              add_variables=True,
                              confmat_size=num_categories)

    config_proto = tf.compat.v1.ConfigProto()
    config_proto.gpu_options.allow_growth = True

    # Checkpoint/summary directory, optionally suffixed with the split index.
    sum_dir = write_dir + ('' if split_i is None else '/split' + str(split_i))
    tf.io.gfile.makedirs(sum_dir)

    with tf.compat.v1.train.MonitoredTrainingSession(
            checkpoint_dir=sum_dir, config=config_proto) as sess:
        training_handle = sess.run(training_iterator.string_handle())
        test_handle = sess.run(test_iterator.string_handle())
        sess.run(training_iterator.initializer)
        # Explicitly initialize every metric variable.
        for var in sess.graph.get_collection(
                tf.compat.v1.GraphKeys.METRIC_VARIABLES):
            sess.run(var.initializer)

        while not sess.should_stop():

            # do training
            # NOTE: this rebinds the Python name `loss` from the loss tensor
            # defined above to the value returned for train_op.
            [loss, pp,
             yy] = sess.run([train_op, predidx, y],
                            feed_dict={'iterator_handle:0': training_handle})
            # update training measures
            # The predictions and labels just fetched are fed back by tensor
            # name so the metric update uses the same batch; 'xinput:1' is
            # presumably the second output (y) of the get_next op named
            # 'xinput' - TODO confirm.
            upd = sess.run(train_updates,
                           feed_dict={
                               'predidx:0': pp,
                               'xinput:1': yy,
                               handle: training_handle
                           })
            gs = sess.run(global_step)
            if gs % 500 == 0:
                print(gs, loss)
                print(upd['train/Accuracy'])
                if gs % 1000 == 0 or sess.should_stop():
                    # every 1000 iterations, do testing and update test measures
                    acc = test_model(sess, test_iterator, test_handle,
                                     test_predidx, y, test_updates)
                    # also reset training stats
                    for var in [
                            x for x in sess.graph.get_collection(
                                tf.compat.v1.GraphKeys.METRIC_VARIABLES)
                            if x.name.startswith('train')
                    ]:
                        sess.run(var.initializer)

    # NOTE(review): `upd` is the result of running `train_updates`, whose
    # keys as built above are all 'train/...'; these 'test/...' lookups look
    # like they would raise KeyError - confirm and fix.
    print(upd['test/Accuracy'])
    print(upd['test/PerClass'])
    print(upd['test/ConfusionMT'])

    # NOTE(review): `acc` is only assigned inside the `gs % 1000` branch; if
    # that branch never runs before the session stops, it is unbound here.
    return acc