Exemple #1
0
    def validate(self, val_loader, epoch=0):
        """Validate the model on samples the teacher model is confident about.

        For each batch, ``self.kd_model`` scores the images. Only samples
        whose highest softmax probability is at least ``self.threshold`` are
        passed through ``self.model.module`` and contribute to the metrics;
        batches with no such sample are skipped.

        Side effects: steps ``self.scheduler`` with the averaged loss and,
        when that loss does not exceed ``self.min_val_loss``, saves
        ``self.model.state_dict()`` to ``self.save_path`` and updates
        ``self.min_val_loss``.

        Args:
            val_loader: Iterable of ``(images, labels)`` batches that also
                supports ``len()``.
            epoch: Unused; kept so existing callers keep working.

        Returns:
            Tuple ``(val_loss, total_d_acc, total_f1, total_recall,
            total_precision, total_cm)``; every entry is the running sum
            divided by ``len(val_loader)``.
        """
        self.model.eval()
        val_loss = 0.0
        total_recall = 0.0
        total_precision = 0.0
        total_f1 = 0.0
        total_cm = 0
        total_d_acc = 0.0

        with torch.no_grad():
            for images, labels in val_loader:
                # NOTE(review): this is the batch size *before* filtering, so
                # the per-batch accuracy and confusion matrix below are
                # normalised by it rather than by the kept-sample count --
                # confirm that this is intended.
                batch_size = images.size(0)
                images = images.to(self.device)
                labels = labels.to(self.device)
                teacher_scores = self.kd_model(images)
                confidence, _ = F.softmax(teacher_scores, dim=-1).max(dim=-1)

                # Keep only the samples the teacher is confident about.
                index = confidence >= self.threshold
                if not index.any().item():
                    continue

                images = images[index]
                labels = labels[index]
                teacher_scores = teacher_scores[index]
                diseases = self.model.module(images)

                loss1, _, _ = self.criterion(diseases, labels,
                                             teacher_scores)
                val_loss += loss1.item()

                # Evaluation of accuracy, P, R, F1
                d_pred = F.log_softmax(diseases, dim=-1).argmax(dim=-1)
                total_d_acc += (d_pred.eq(labels).sum().item() /
                                batch_size)
                _, recall, precision, f1 = accuracy_recall_precision_f1(
                    d_pred, labels)

                total_recall += np.mean(recall)
                total_precision += np.mean(precision)
                total_f1 += np.mean(f1)

                cm = calculate_confusion_matrix(d_pred, labels)
                try:
                    total_cm += (cm / batch_size)
                except Exception:
                    # Best effort: a matrix that cannot be accumulated is
                    # reported and skipped. ``Exception`` (instead of a bare
                    # ``except``) lets KeyboardInterrupt/SystemExit through.
                    print("error occured for this CM")
                    print(cm / batch_size)

        # NOTE(review): sums are averaged over *all* batches, including the
        # ones skipped above because no sample passed the threshold.
        num_batches = len(val_loader)
        val_loss = val_loss / num_batches
        total_d_acc = total_d_acc / num_batches
        total_f1 = total_f1 / num_batches
        total_precision = total_precision / num_batches
        total_recall = total_recall / num_batches
        total_cm = total_cm / num_batches

        self.scheduler.step(val_loss)
        if val_loss <= self.min_val_loss:
            torch.save(self.model.state_dict(), self.save_path)
            self.min_val_loss = val_loss

        return (val_loss, total_d_acc, total_f1, total_recall, total_precision,
                total_cm)
Exemple #2
0
    def validate(self, val_loader, epoch=0):
        """Validate the multi-task model (disease, fine disease, text).

        For each batch the model returns disease logits, fine-grained disease
        logits and text predictions. The three losses are stacked and the
        entries selected by ``self.tasks`` are summed into the validation
        loss. Accuracy, precision/recall/F1, a confusion matrix, top-k scores
        (overall and per disease label 0-3) and BLEU are accumulated per
        batch and divided by ``len(val_loader)`` at the end.

        Side effects: steps ``self.scheduler`` with the averaged loss and,
        when that loss does not exceed ``self.min_val_loss``, saves
        ``self.model.state_dict()`` to ``self.save_path`` and updates
        ``self.min_val_loss``.

        Args:
            val_loader: Iterable of 5-tuples whose last four items are
                ``(images, labels, f_labels, text)``; the first item is
                ignored. Must support ``len()``.
            epoch: Unused; kept so existing callers keep working.

        Returns:
            Tuple ``(val_loss, total_d_acc, total_acc, bleu, total_f1,
            total_recall, total_precision, sent_gt, sent_pred, total_topk,
            per_disease_topk, per_disease_bleu, total_cm)``. ``sent_gt`` and
            ``sent_pred`` are the values ``compute_bleu`` returned for the
            last batch. ``val_loss`` is accumulated from tensor sums and is
            therefore returned as a tensor.
        """
        self.model.eval()
        val_loss = 0.0
        total_acc = 0.0
        total_recall = 0.0
        total_precision = 0.0
        total_f1 = 0.0
        total_cm = 0
        total_d_acc = 0.0
        bleu = 0.0

        k_vals = [1, 2, 3, 4, 5]
        disease_ids = [0, 1, 2, 3]
        total_topk = {k: 0.0 for k in k_vals}
        per_disease_topk = defaultdict(lambda: {str(k): 0.0 for k in k_vals})
        per_disease_bleu = defaultdict(list)
        with torch.no_grad():
            for _, images, labels, f_labels, text in val_loader:
                batch_size = images.size(0)
                images = images.to(self.device)
                labels = labels.to(self.device)
                f_labels = f_labels.to(self.device)
                text = text.to(self.device)
                diseases, fine_diseases, text_pred = self.model(images, text)
                loss1 = self.criterion(diseases, labels)
                loss2 = self.criterion(fine_diseases, f_labels)
                # Position k of the prediction is scored against token k + 1
                # of ``text``.
                text_loss = 0.0
                for k in range(text_pred.size(1)):
                    text_loss += self.criterion(text_pred[:, k].squeeze(),
                                                text[:, k + 1].squeeze())

                # Only the losses selected by ``self.tasks`` count.
                val_loss += torch.stack(
                    (loss1, loss2, text_loss))[self.tasks].sum()

                preds = F.log_softmax(fine_diseases, dim=-1)
                pred = preds.argmax(dim=-1)
                d_pred = F.log_softmax(diseases, dim=-1).argmax(dim=-1)

                # Evaluation of P, R, F1, CM, BLEU
                total_acc += (pred.eq(f_labels).sum().item() / batch_size)
                total_d_acc += (d_pred.eq(labels).sum().item() / batch_size)
                _, recall, precision, f1 = accuracy_recall_precision_f1(
                    d_pred, labels)
                cm = calculate_confusion_matrix(d_pred, labels)
                try:
                    total_cm += (cm / batch_size)
                except Exception:
                    # Best effort: a matrix that cannot be accumulated is
                    # reported and skipped. ``Exception`` (instead of a bare
                    # ``except``) lets KeyboardInterrupt/SystemExit through.
                    print("Error occured for this CM")
                    print(cm / batch_size)

                # Top-k evaluation
                for k in k_vals:
                    total_topk[k] += compute_topk(preds, f_labels, k)
                    for d in disease_ids:
                        mask = labels.eq(d)
                        if mask.sum() > 0:
                            per_disease_topk[d][str(k)] += compute_topk(
                                preds[mask], f_labels[mask], k)

                total_recall += np.mean(recall)
                total_precision += np.mean(precision)
                total_f1 += np.mean(f1)

                # BLEU on the arg-max decoded text against tokens 1..end.
                text_tokens = torch.argmax(F.log_softmax(text_pred, dim=-1),
                                           dim=-1)
                text1 = text[:, 1:].squeeze().tolist()
                preds1 = text_tokens.tolist()
                t_bleu, sent_gt, sent_pred = compute_bleu(
                    self.lang, text1, preds1, labels, per_disease_bleu)
                bleu += t_bleu

        num_batches = len(val_loader)
        bleu = bleu / num_batches
        val_loss = val_loss / num_batches
        total_acc = total_acc / num_batches
        total_d_acc = total_d_acc / num_batches
        total_f1 = total_f1 / num_batches
        total_precision = total_precision / num_batches
        total_recall = total_recall / num_batches
        total_cm = total_cm / num_batches

        self.scheduler.step(val_loss)
        if val_loss <= self.min_val_loss:
            torch.save(self.model.state_dict(), self.save_path)
            self.min_val_loss = val_loss

        for d in per_disease_bleu:
            per_disease_bleu[d] = np.mean(per_disease_bleu[d])

        total_topk = {str(k): total_topk[k] / num_batches for k in k_vals}
        # NOTE(review): per-disease sums are divided by the total number of
        # batches, not by the number of batches that contained the disease.
        for d in disease_ids:
            for k in k_vals:
                per_disease_topk[d][str(
                    k)] = per_disease_topk[d][str(k)] / num_batches

        return (val_loss, total_d_acc, total_acc, bleu, total_f1, total_recall,
                total_precision, sent_gt, sent_pred, total_topk,
                per_disease_topk, per_disease_bleu, total_cm)
Exemple #3
0
def main(args):
    """Train and/or test a ResNet-50 classifier on the ISIC dataset.

    Reads the following attributes from ``args``: ``learning_rate``,
    ``checkpoints``, ``name``, ``train``, ``test``, ``dataset``,
    ``batch_size`` and ``epochs``.

    Side effects: prints all settings, creates ``saved_models`` and
    ``logs/scalars/``, writes TensorBoard summaries, and (when training)
    saves the weights to ``saved_models/<name>.h5`` whenever the validation
    accuracy improves.

    Args:
        args: Namespace-like object holding the settings listed above.
    """
    # Print settings
    for k, v in vars(args).items():
        print(f'{k}: {v}')

    display_step = 5
    num_classes = 8
    size = (224, 224, 3)  # size of images

    # Runtime initialization will not allocate all memory on GPU
    physical_devices = tf.config.list_physical_devices('GPU')
    if physical_devices:
        try:
            tf.config.experimental.set_memory_growth(physical_devices[0], True)
        except (ValueError, RuntimeError):
            # Invalid device or cannot modify virtual devices once
            # initialized. Memory growth is only an optimisation, so carry on;
            # any other error is no longer silently swallowed.
            pass
    # Create checkpoints dir
    os.makedirs('saved_models', exist_ok=True)

    optimizer = optimizers.SGD(learning_rate=args.learning_rate, momentum=0.9)
    loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=False)

    model = models.resnet50(input_shape=size,
                            num_classes=num_classes,
                            classifier_activation='softmax')
    model.build(input_shape=size)
    model.summary()

    if args.checkpoints:
        if os.path.exists(args.checkpoints):
            print(f'Loading checkpoints: {args.checkpoints}')
            model.load_weights(args.checkpoints)
        else:
            print(f'Checkpoints `{args.checkpoints}` not found',
                  file=sys.stderr)

    os.makedirs("logs/scalars/", exist_ok=True)
    logdir = "logs/scalars/" + datetime.now().strftime(
        "%Y%m%d-%H%M%S") + f"-{args.name}"
    summary_writer = tf.summary.create_file_writer(logdir)

    if args.train:
        # Same augs as C++
        train_aug = iaa.Sequential([
            iaa.Resize(size=size[:-1], interpolation='cubic'),
            iaa.Fliplr(p=0.5),
            iaa.Flipud(p=0.5),
            iaa.Rotate(rotate=(-180, 180)),
            iaa.AdditivePoissonNoise(lam=(0, 10)),
            iaa.GammaContrast(gamma=(.8, 1.5)),
            iaa.GaussianBlur(sigma=(.0, .8)),
            iaa.CoarseDropout(p=(.02, .1),
                              size_percent=(0.02, 0.05),
                              per_channel=0.5),
        ])

        val_aug = iaa.Sequential(
            [iaa.Resize(size=size[:-1], interpolation='cubic')])

        training_dataset = ISICClassification(args.dataset, 'training',
                                              args.batch_size, train_aug)
        training_tfdata = training_dataset.map_samples(args.epochs)
        training_iter = iter(training_tfdata)

        validation_dataset = ISICClassification(args.dataset,
                                                'validation',
                                                args.batch_size,
                                                val_aug,
                                                shuffle=False)
        validation_tfdata = validation_dataset.map_samples(args.epochs)
        validation_iter = iter(validation_tfdata)

        train_loss = tf.keras.metrics.Mean(name='train_loss')
        train_metric = tf.keras.metrics.SparseCategoricalAccuracy(
            name='train_accuracy')
        val_metric = tf.keras.metrics.SparseCategoricalAccuracy(
            name='val_accuracy')

        best_accuracy = 0.
        for e in range(1, args.epochs + 1):
            train_loss.reset_states()
            train_metric.reset_states()
            val_metric.reset_states()

            total_preds = []
            total_labels = []
            # NOTE(review): range(1, len(...)) runs len(...) - 1 steps per
            # epoch; confirm against ISICClassification.__len__ whether the
            # last batch is meant to be left out.
            for step in range(1, len(training_dataset)):
                images, labels = next(training_iter)

                # Run the optimization to update W and b values
                with tf.GradientTape() as tape:
                    pred = model(images)
                    loss = loss_fn(labels, pred)
                total_preds.append(pred)
                total_labels.append(labels)

                gradients = tape.gradient(loss, model.trainable_variables)

                # Update W and b following gradients
                optimizer.apply_gradients(
                    zip(gradients, model.trainable_variables))

                # Log loss and metric
                train_loss.update_state(loss)
                train_metric.update_state(labels, pred)
                if step % display_step == 0:
                    print(
                        "\rTraining {:d}/{:d} (batch {:d}/{:d}) - Loss: {:.4f} - Accuracy: {:.4f}"
                        .format(e, args.epochs, step, len(training_dataset),
                                train_loss.result(), train_metric.result()),
                        end="",
                        flush=True)

            cm = utils.calculate_confusion_matrix(
                tf.concat(total_labels, axis=0), tf.concat(total_preds,
                                                           axis=0))
            with summary_writer.as_default():
                tf.summary.scalar('loss/' + train_loss.name,
                                  train_loss.result(),
                                  step=e - 1)
                tf.summary.scalar('accuracy/' + train_metric.name,
                                  train_metric.result(),
                                  step=e - 1)
                tf.summary.image("cm/training_cm", cm, step=e)

            total_preds = []
            total_labels = []

            # Do validation
            print("\nValidation {:d}/{:d}".format(e, args.epochs),
                  end="",
                  flush=True)
            # NOTE(review): same len(...) - 1 step count as the training loop.
            for step in range(1, len(validation_dataset)):
                images, labels = next(validation_iter)
                pred = model(images)
                val_metric.update_state(labels, pred)
                total_preds.append(pred)
                total_labels.append(labels)

            cm = utils.calculate_confusion_matrix(
                tf.concat(total_labels, axis=0), tf.concat(total_preds,
                                                           axis=0))
            with summary_writer.as_default():
                tf.summary.scalar('accuracy/' + val_metric.name,
                                  val_metric.result(),
                                  step=e - 1)
                tf.summary.image("cm/validation_cm", cm, step=e)

            # Compute accuracy and save checkpoints
            accuracy = val_metric.result()
            print(" - Accuracy: {:.4f}".format(accuracy), flush=True)

            if accuracy > best_accuracy:
                print(
                    f"Saving checkpoints (accuracy: {accuracy:.4f} > {best_accuracy:.4f})",
                    flush=True)
                best_accuracy = accuracy
                model.save_weights(f'saved_models/{args.name}.h5')

    if args.test:
        # Test model on test set
        test_aug = iaa.Sequential(
            [iaa.Resize(size=size[:-1], interpolation='cubic')])
        test_dataset = ISICClassification(args.dataset, 'test',
                                          args.batch_size, test_aug)
        test_tfdata = test_dataset.map_samples(1)
        tensorboard = keras.callbacks.TensorBoard(log_dir=logdir)
        results = model.evaluate(test_tfdata,
                                 verbose=1,
                                 callbacks=[tensorboard])
        print("Test set loss and accuracy:", results)
def val_epoch(epoch,
              data_loader,
              model,
              criterion,
              device,
              logger,
              tb_writer=None,
              distributed=False,
              conf_mtx_dict=None):
    """Run one validation epoch and return the average loss.

    Puts ``model`` in eval mode, iterates ``data_loader`` without gradients,
    and tracks loss and accuracy (weighted by ``inputs.size(0)``) as well as
    batch/data timing. Progress is printed for every batch. The true and
    predicted labels of all batches are collected and turned into a
    confusion matrix once the loop has finished.

    Args:
        epoch: Epoch number; used for printing, logging and as the key under
            which the confusion matrix is stored.
        data_loader: Iterable of ``(inputs, targets)`` batches supporting
            ``len()``.
        model: Model to evaluate; called as ``model(inputs)``.
        criterion: Loss callable, called as ``criterion(outputs, targets)``.
        device: Device that ``targets`` (and the all-reduce buffers) are
            moved to. ``inputs`` are passed to the model as-is.
        logger: Object with a ``log(dict)`` method, or ``None``.
        tb_writer: Object with ``add_scalar(tag, value, step)``, or ``None``.
        distributed: When true, loss/accuracy sums and counts are summed
            across processes with ``dist.all_reduce`` before averaging.
        conf_mtx_dict: Optional dict that receives
            ``conf_mtx_dict[epoch] = confusion_matrix``. Defaults to ``None``
            (a throwaway dict) rather than a shared mutable default.

    Returns:
        The average validation loss (``losses.avg``).
    """
    if conf_mtx_dict is None:
        # A fresh dict per call avoids the shared-mutable-default pitfall.
        conf_mtx_dict = {}

    print('validation at epoch {}'.format(epoch))

    model.eval()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    accuracies = AverageMeter()

    # Added for 231n
    all_y_true = []
    all_y_pred = []

    end_time = time.time()

    with torch.no_grad():
        for i, (inputs, targets) in enumerate(data_loader):
            data_time.update(time.time() - end_time)

            targets = targets.to(device, non_blocking=True)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            acc = calculate_accuracy(outputs, targets)

            # Added for 231n
            y_true, y_pred = calc_ytrue_ypred(outputs, targets)
            all_y_true.extend(y_true)
            all_y_pred.extend(y_pred)

            losses.update(loss.item(), inputs.size(0))
            accuracies.update(acc, inputs.size(0))

            batch_time.update(time.time() - end_time)
            end_time = time.time()

            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc {acc.val:.3f} ({acc.avg:.3f})'.format(
                      epoch,
                      i + 1,
                      len(data_loader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses,
                      acc=accuracies))

        # Added for 231n
        conf_mtx = calculate_confusion_matrix(all_y_true, all_y_pred)
        conf_mtx_dict[epoch] = conf_mtx

    if distributed:
        # Sum the per-process sums and counts so that the averages below are
        # taken over all processes.
        loss_sum = torch.tensor([losses.sum],
                                dtype=torch.float32,
                                device=device)
        loss_count = torch.tensor([losses.count],
                                  dtype=torch.float32,
                                  device=device)
        acc_sum = torch.tensor([accuracies.sum],
                               dtype=torch.float32,
                               device=device)
        acc_count = torch.tensor([accuracies.count],
                                 dtype=torch.float32,
                                 device=device)

        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(loss_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(acc_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(acc_count, op=dist.ReduceOp.SUM)

        losses.avg = loss_sum.item() / loss_count.item()
        accuracies.avg = acc_sum.item() / acc_count.item()

    if logger is not None:
        logger.log({'epoch': epoch, 'loss': losses.avg, 'acc': accuracies.avg})

    if tb_writer is not None:
        tb_writer.add_scalar('val/loss', losses.avg, epoch)
        tb_writer.add_scalar('val/acc', accuracies.avg, epoch)

    return losses.avg
def evaluate_model_features(model, optimizer, loss_fn, dataloader, device, use_bert):
    """Evaluate a model over a dataloader and return averaged metrics.

    Args:
        model: Model to evaluate (torch.nn.Module). Called as
            ``model(input_ids, segment_ids, input_mask, labels=None)`` when
            ``use_bert`` is true, otherwise as ``model(X, features)``.
        optimizer: Unused; kept for interface compatibility.
        loss_fn: Loss function computing the loss of a batch from
            ``(y_pred, y_target)``.
        dataloader: Yields batches of tensors:
            ``(input_ids, input_mask, segment_ids, label_ids)`` when
            ``use_bert`` is true, otherwise ``(X, features, y_target)``.
        device: Device every tensor of a batch is moved to.
        use_bert: Selects the batch layout and model call described above.

    Returns:
        Dict with keys ``'loss'``, ``'accuracy'``, ``'recall'``,
        ``'precision'``, ``'f1'`` (per-batch values averaged over
        ``len(dataloader)`` and rounded to 2 decimals) and ``'cm'`` (the
        summed 2x2 confusion matrix).
    """

    #Metrics
    epoch_loss = 0
    epoch_accuracy = 0
    # These must be arrays, not lists: ``list += sequence`` extends the list
    # instead of adding element-wise, and ``list / int`` raises TypeError.
    epoch_recall = np.zeros(2)
    epoch_precision = np.zeros(2)
    epoch_f1 = np.zeros(2)
    cm = np.zeros((2, 2))

    #Set model in evaluate mode
    model.eval()

    with torch.no_grad():

        for batch in tqdm(dataloader, desc="Iteration"):

            #Step 0: Get batch
            batch = tuple(t.to(device) for t in batch)

            #Step 1: Compute the forward pass of the model (model output)
            if use_bert:
                input_ids, input_mask, segment_ids, label_ids = batch
                y_pred = model(input_ids, segment_ids, input_mask, labels=None)
                y_target = label_ids
            else:
                X, features, y_target = batch
                # ``torch.autograd.Variable`` is a deprecated no-op wrapper;
                # casting the tensor directly is equivalent.
                y_target = y_target.long()
                y_pred = model(X, features)

            #Step 2: Compute the loss
            loss = loss_fn(y_pred, y_target)
            epoch_loss += loss.item()

            #Compute other metrics
            accuracy, recall, precision, f1 = accuracy_recall_precision_f1(y_pred, y_target)

            epoch_accuracy += accuracy
            epoch_recall += np.asarray(recall, dtype=float)
            epoch_precision += np.asarray(precision, dtype=float)
            epoch_f1 += np.asarray(f1, dtype=float)

            #Compute confusion metrics
            cm += calculate_confusion_matrix(y_pred, y_target)

        #Evaluation results
        num_batches = len(dataloader)
        results = {
            'loss': np.round(epoch_loss / num_batches, 2),
            'accuracy': np.round(float(epoch_accuracy / num_batches), 2),
            'recall': np.round(epoch_recall / num_batches, 2),
            'precision': np.round(epoch_precision / num_batches, 2),
            'f1': np.round(epoch_f1 / num_batches, 2),
            'cm': cm
        }

    return results