def check_model_on_dataloader():
    """Smoke-test a UNet on every batch of the generated-data training loader.

    Builds an 11-input-channel, 16-class ``UNet`` on GPU 0 in eval mode, runs
    it without gradient tracking over the training loader returned by
    ``get_dataloaders_generated_data`` and prints, per batch, a progress line
    followed by the shapes of the inputs, labels, both model outputs and the
    argmax (over dim 1) of the first sample's prediction. Returns nothing.
    """
    net = UNet(input_channels=11, num_classes=16)
    net.eval()
    net.cuda(device=0)
    dataloaders = get_dataloaders_generated_data(
        generated_data_path='generated_dataset',
        save_data_path='pickled_generated_datalist.pkl',
        block_size=256,
        model_input_size=64,
        batch_size=128,
        num_workers=8)
    with torch.no_grad():
        # Only the training split is inspected; validation/test are ignored.
        train_loader, _, _ = dataloaders
        for batch_number, batch in enumerate(train_loader, start=1):
            inputs, targets = batch['input'], batch['label']
            inputs = inputs.cuda(device=0)
            progress = '-> on batch {}/{}, {}'.format(
                batch_number, len(train_loader), inputs.size())
            print(progress)
            raw_output, prediction = net(inputs)
            first_argmax = torch.argmax(prediction, dim=1)[0, :, :]
            print(inputs.shape, targets.shape, raw_output.shape,
                  prediction.shape, first_argmax.shape)
# Example #2
# 0
def check_model_on_dataloader():
    """Run a UNet over the generated training data and print tensor shapes.

    A sanity check: an eval-mode ``UNet`` (11 input channels, 16 classes) is
    moved to GPU 0 and fed every training batch from
    ``get_dataloaders_generated_data`` under ``torch.no_grad()``. For each
    batch it prints a progress line, then the shapes of the input, label,
    both model outputs and the argmax map of the batch's first sample.
    """
    model = UNet(input_channels=11, num_classes=16)
    model.eval()
    model.cuda(device=0)

    # Earlier variants of this check built the loaders either from the raw ESA
    # landcover rasters via ``get_dataloaders`` or from absolute paths under a
    # personal home directory; the relative paths below are the ones in use.
    loader_kwargs = dict(generated_data_path='generated_dataset',
                         save_data_path='pickled_generated_datalist.pkl',
                         block_size=256,
                         model_input_size=64,
                         batch_size=128,
                         num_workers=8)
    loaders = get_dataloaders_generated_data(**loader_kwargs)

    with torch.no_grad():
        train_dataloader, _val_dataloader, _test_dataloader = loaders
        for step, sample in enumerate(train_dataloader):
            examples = sample['input']
            labels = sample['label']
            examples = examples.cuda(device=0)
            total = len(train_dataloader)
            print('-> on batch {}/{}, {}'.format(step + 1, total,
                                                 examples.size()))
            out_tensor, prediction = model(examples)
            predicted_classes = torch.argmax(prediction, dim=1)
            print(examples.shape, labels.shape, out_tensor.shape,
                  prediction.shape, predicted_classes[0, :, :].shape)
def train_net(model, generated_data_path, input_dim, workers, pre_model, save_data, save_dir, sum_dir,
              batch_size, lr, epochs, log_after, cuda, device):
    """Train ``model`` with ``nn.BCELoss`` and Adam on the generated dataset.

    Checkpoint handling: ``model-<n>.pt`` is written to ``save_dir`` at the
    start of every epoch (holding the weights produced so far) and once more
    after the last epoch, so the result of the final epoch is kept as well.
    An existing checkpoint file is never overwritten. After each save, the
    checkpoint six numbers older is deleted. After every epoch the model is
    evaluated on the validation loader via ``eval_net``, which logs through
    the same ``SummaryWriter`` used for the training scalars.

    Args:
        model: network whose ``forward`` returns ``(output, probabilities)``;
            the second element is fed to ``nn.BCELoss`` (so presumably it is
            already in [0, 1] -- TODO confirm).
        generated_data_path, save_data, input_dim, batch_size, workers:
            forwarded to ``get_dataloaders_generated_data``.
        pre_model: ``-1`` to start from scratch, otherwise the number of the
            checkpoint ``model-<pre_model>.pt`` in ``save_dir`` to resume from.
        save_dir: checkpoint directory (created if missing).
        sum_dir: created if missing; not otherwise used here.
        lr: initial learning rate; it decays exponentially to 5e-5 over
            ``epochs``.
        epochs: number of training epochs.
        log_after: print a progress line every ``log_after`` batches.
        cuda, device: whether / which GPU to use for model and tensors.
    """
    if cuda:
        print('log: Using GPU')
        model.cuda(device=device)

    if pre_model == -1:
        model_number = 0
        print('log: No trained model passed. Starting from scratch...')
    else:
        model_number = pre_model
        model_path = os.path.join(save_dir, 'model-{}.pt'.format(pre_model))
        model.load_state_dict(torch.load(model_path), strict=False)
        print('log: Resuming from model {} ...'.format(model_path))
    ###############################################################################

    if not os.path.exists(save_dir):
        os.mkdir(save_dir)
    if not os.path.exists(sum_dir):
        os.mkdir(sum_dir)

    # define loss and optimizer
    optimizer = Adam(model.parameters(), lr=lr)
    crossentropy_criterion = nn.BCELoss()

    # Exponential decay from ``lr`` to ``lr_final`` over all epochs; guard
    # ``epochs == 0``, which would otherwise divide by zero.
    lr_final = 5e-5
    LR_decay = (lr_final / lr) ** (1. / epochs) if epochs > 0 else 1.0
    scheduler = lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=LR_decay)

    loaders = get_dataloaders_generated_data(generated_data_path=generated_data_path, save_data_path=save_data,
                                             model_input_size=input_dim, batch_size=batch_size, num_classes=2,
                                             one_hot=True, num_workers=workers)
    writer = SummaryWriter()

    train_loader, val_dataloader, test_loader = loaders

    def save_checkpoint(number):
        # Write model-<number>.pt unless it already exists (e.g. the
        # checkpoint we resumed from), then delete the one six numbers older.
        path = os.path.join(save_dir, 'model-{}.pt'.format(number))
        if os.path.exists(path):
            return
        torch.save(model.state_dict(), path)
        print('log: saved {}'.format(path))
        stale = os.path.join(save_dir, 'model-{}.pt'.format(number - 6))
        if os.path.exists(stale):
            os.remove(stale)
            print('log: removed {}'.format(stale))

    # training loop
    for k in range(epochs):
        net_loss = []
        total_correct, total_examples = 0, 0
        # Weights at the start of this epoch (initial / resumed / previous epoch).
        save_checkpoint(model_number + k)

        for idx, data in enumerate(train_loader):
            model.train()
            model.zero_grad()
            test_x, label = data['input'], data['label']
            test_x = test_x.cuda(device=device) if cuda else test_x
            label = label.cuda(device=device) if cuda else label
            out_x, logits = model.forward(test_x)
            pred = torch.argmax(logits, dim=1)
            not_one_hot_target = torch.argmax(label, dim=1)
            loss = crossentropy_criterion(logits, label.float())
            loss.backward()
            clip_grad_norm_(model.parameters(), 0.05)
            optimizer.step()
            numerator = float((pred == not_one_hot_target).sum().item())
            denominator = float(pred.numel())
            total_correct += numerator
            total_examples += denominator

            if idx % log_after == 0 and idx > 0:
                accuracy = numerator * 100 / denominator
                print('{}. ({}/{}) output size = {}, loss = {}, '
                      'accuracy = {}/{} = {:.2f}%, (lr = {})'.format(k, idx, len(train_loader), out_x.size(),
                                                                     loss.item(), numerator, denominator, accuracy,
                                                                     optimizer.param_groups[0]['lr']))
            net_loss.append(loss.item())

        # this should be done at the end of epoch only
        scheduler.step()  # to dynamically change the learning rate
        mean_accuracy = total_correct * 100 / total_examples if total_examples else 0.0
        mean_loss = np.asarray(net_loss).mean()
        writer.add_scalar(tag='train loss', scalar_value=mean_loss, global_step=k)
        writer.add_scalar(tag='train over_all accuracy', scalar_value=mean_accuracy, global_step=k)
        print('####################################')
        print('LOG: epoch {} -> total loss = {:.5f}, total accuracy = {:.5f}%'.format(k, mean_loss, mean_accuracy))
        print('####################################')

        # validate model; eval_net logs its scalars through ``writer``, so the
        # real writer must be passed (``None`` makes it call None.add_scalar).
        print('log: Evaluating now...')
        eval_net(model=model, criterion=crossentropy_criterion, val_loader=val_dataloader, cuda=cuda, device=device,
                 writer=writer, batch_size=batch_size, global_step=k)

    # Saves above happen at the *start* of each epoch, so persist the weights
    # produced by the final epoch explicitly.
    save_checkpoint(model_number + epochs)
    writer.close()
def _eval_net_batch(model, criterion, data, cuda, device):
    """Evaluate one batch; return its loss and pixel-accuracy counts.

    ``data`` is a loader batch with ``'input'`` and ``'label'`` tensors. The
    label is reduced with ``argmax(dim=1)``, so it is presumably one-hot with
    classes along dim 1 (TODO confirm). ``model.forward`` must return
    ``(output, probabilities)``; the second element is scored with
    ``criterion`` against the float label.

    Returns:
        ``(loss_value, correct_pixels, total_pixels, pred, target)`` where
        ``pred`` and ``target`` are per-pixel class-index tensors.
    """
    test_x, label = data['input'], data['label']
    if cuda:
        test_x = test_x.cuda(device=device)
        label = label.cuda(device=device)
    _, softmaxed = model.forward(test_x)
    pred = torch.argmax(softmaxed, dim=1)
    target = torch.argmax(label, dim=1)
    loss = criterion(softmaxed, label.float())
    correct = float((pred == target).sum().item())
    total = float(pred.numel())
    return loss.item(), correct, total, pred, target


def _eval_net_validation(model, kwargs):
    """Training-time evaluation over ``kwargs['val_loader']``; return accuracy in %."""
    val_loader = kwargs['val_loader']
    writer = kwargs['writer']
    global_step = kwargs['global_step']
    criterion = kwargs['criterion']
    total_examples, total_correct, net_loss = 0.0, 0.0, []
    for data in val_loader:
        loss, correct, total, _, _ = _eval_net_batch(
            model, criterion, data, kwargs['cuda'], kwargs['device'])
        total_correct += correct
        total_examples += total
        net_loss.append(loss)
    mean_accuracy = total_correct * 100 / total_examples if total_examples else 0.0
    mean_loss = np.asarray(net_loss).mean()
    # Callers may pass ``writer=None`` to skip TensorBoard logging; only log
    # when a real writer was supplied.
    if writer is not None:
        writer.add_scalar(tag='val. loss', scalar_value=mean_loss, global_step=global_step)
        writer.add_scalar(tag='val. over_all accuracy', scalar_value=mean_accuracy, global_step=global_step)
    print('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$')
    print('LOG: validation:: total loss = {:.5f}, total accuracy = {:.5f}%'.format(mean_loss, mean_accuracy))
    print('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$')
    return mean_accuracy


def _eval_net_test(model, kwargs):
    """Stand-alone evaluation of checkpoint ``model-<pre_model>.pt``; return accuracy in %.

    Loads the checkpoint from ``kwargs['save_dir']``, builds loaders from the
    generated data, accumulates BCE loss, pixel accuracy and 3-class confusion
    matrices, prints the totals and pickles the normalized / un-normalized
    confusion matrices to ``normalized.pkl`` / ``un_normalized.pkl`` in the
    current working directory.
    """
    cuda, device = kwargs['cuda'], kwargs['device']
    num_classes = 3
    pre_model = kwargs['pre_model']
    un_confusion_meter = tnt.meter.ConfusionMeter(num_classes, normalized=False)
    confusion_meter = tnt.meter.ConfusionMeter(num_classes, normalized=True)

    model_path = os.path.join(kwargs['save_dir'], 'model-{}.pt'.format(pre_model))
    model.load_state_dict(torch.load(model_path), strict=False)
    print('log: resumed model {} successfully!'.format(pre_model))

    crossentropy_criterion = nn.BCELoss()
    loaders = get_dataloaders_generated_data(generated_data_path=kwargs['generated_data_path'],
                                             save_data_path=kwargs['save_data'],
                                             model_input_size=kwargs['input_dim'],
                                             batch_size=kwargs['batch_size'],
                                             one_hot=True,
                                             num_workers=kwargs['workers'])
    # NOTE(review): the loaders were historically unpacked as
    # (train, test, empty), yet the metrics below are computed on the *first*
    # loader and reported as "test". Confirm this split is the intended one.
    eval_loader, _test_loader, _empty_loader = loaders

    net_loss = []
    total_correct, total_examples = 0.0, 0.0
    for idx, data in enumerate(eval_loader):
        loss, correct, total, pred, target = _eval_net_batch(
            model, crossentropy_criterion, data, cuda, device)
        total_correct += correct
        total_examples += total
        net_loss.append(loss)
        un_confusion_meter.add(predicted=pred.view(-1), target=target.view(-1))
        confusion_meter.add(predicted=pred.view(-1), target=target.view(-1))
        if idx % 10 == 0:
            print('log: on {}'.format(idx))

    mean_accuracy = total_correct * 100 / total_examples if total_examples else 0.0
    mean_loss = np.asarray(net_loss).mean()
    print('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$')
    print('log: test:: total loss = {:.5f}, total accuracy = {:.5f}%'.format(mean_loss, mean_accuracy))
    print('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$')

    with open('normalized.pkl', 'wb') as this:
        pkl.dump(confusion_meter.value(), this, protocol=pkl.HIGHEST_PROTOCOL)
    with open('un_normalized.pkl', 'wb') as this:
        pkl.dump(un_confusion_meter.value(), this, protocol=pkl.HIGHEST_PROTOCOL)
    return mean_accuracy


def eval_net(**kwargs):
    """Evaluate ``kwargs['model']``; the mode is chosen by the presence of ``writer``.

    Always read: ``model``, ``cuda``, ``device``. The model is put in eval
    mode (and on ``device`` when ``cuda``), and everything runs under
    ``torch.no_grad()`` since no gradients are needed.

    * ``writer`` key present (training-time validation): also reads
      ``val_loader``, ``criterion``, ``global_step``. Logs validation loss and
      accuracy to ``writer`` when it is not ``None`` and prints them.
    * ``writer`` key absent (stand-alone test): also reads ``pre_model``,
      ``save_dir``, ``generated_data_path``, ``save_data``, ``input_dim``,
      ``batch_size``, ``workers``; loads the checkpoint, prints metrics and
      pickles confusion matrices.

    Returns:
        The overall pixel accuracy in percent (0.0 if no pixels were seen).
    """
    cuda = kwargs['cuda']
    device = kwargs['device']
    model = kwargs['model']
    model.eval()
    if cuda:
        model.cuda(device=device)
    with torch.no_grad():
        if 'writer' in kwargs:
            return _eval_net_validation(model, kwargs)
        return _eval_net_test(model, kwargs)
# Example #5
# 0
def _district_from_identifier(identifier):
    """Return the district name encoded in a sample path.

    The district is the file name's stem up to the first underscore; the
    prefixes ``upper`` and ``lower`` are expanded to ``upper dir`` /
    ``lower dir`` (e.g. ``'a/b/upper_3.tif'`` -> ``'upper dir'``).
    """
    image_name = identifier.split('/')[-1].split('.')[0]
    district_name = image_name.split('_')[0]
    if district_name in ('upper', 'lower'):
        district_name += ' dir'
    return district_name


def _stack_label_chunks(chunks):
    """Concatenate 1-D label arrays into one float64 array (empty if none)."""
    if not chunks:
        return np.array([])
    return np.concatenate(chunks).astype(np.float64)


def generate_error_maps(**kwargs):
    """Evaluate a saved checkpoint on the test split and print detailed metrics.

    Keyword arguments read: ``model``, ``classes`` (class names), ``cuda``,
    ``device``, ``pre_model`` (checkpoint file name inside ``save_dir``),
    ``batch_size``, ``save_dir``, ``error_maps_path``,
    ``generated_data_path``, ``data_split_lists``, ``bands`` (``x - 1`` is
    used as channel index, so presumably 1-based), ``input_dim``, ``workers``.

    Label index 0 is treated as "no label": those pixels are excluded from
    accuracy and confusion statistics, and the remaining label indices are
    shifted down by one before being compared with the predictions. Metrics
    are printed for the whole test set, per district, and for all districts
    except "upper dir" and "chitral".

    The directory ``<error_maps_path>/<pre_model up to its first '.'>`` is
    created, but writing error-map images into it is currently disabled.
    """
    model = kwargs['model']
    classes = kwargs['classes']
    num_classes = len(classes)
    cuda = kwargs['cuda']
    device = kwargs['device']
    model.eval()
    # Districts left out of the "special" aggregate metrics reported at the end.
    excluded_districts = ('upper dir', 'chitral')
    # Label chunks are collected in lists and concatenated once after the loop
    # instead of re-copying an ever-growing array on every iteration.
    all_pred_chunks, all_truth_chunks = [], []
    kept_pred_chunks, kept_truth_chunks = [], []
    if cuda:
        model.cuda(device=device)
    pre_model = kwargs['pre_model']
    batch_size = kwargs['batch_size']
    un_confusion_meter = tnt.meter.ConfusionMeter(num_classes,
                                                  normalized=False)
    confusion_meter = tnt.meter.ConfusionMeter(num_classes, normalized=True)
    model_path = os.path.join(kwargs['save_dir'], pre_model)
    model.load_state_dict(torch.load(model_path, map_location='cpu'),
                          strict=False)
    print('[LOG] Resumed model {} successfully!'.format(pre_model))
    destination_path = os.path.join(kwargs['error_maps_path'],
                                    pre_model.split('.')[0])
    # makedirs also creates a missing ``error_maps_path`` parent directory.
    os.makedirs(destination_path, exist_ok=True)
    weights = torch.Tensor([10, 10])  # per-class loss weights
    weights = weights.cuda(device=device) if cuda else weights
    focal_criterion = FocalLoss2d(weight=weights)
    loaders = get_dataloaders_generated_data(
        generated_data_path=kwargs['generated_data_path'],
        data_split_lists_path=kwargs['data_split_lists'],
        bands=kwargs['bands'],
        model_input_size=kwargs['input_dim'],
        num_classes=num_classes + 1,
        train_split=0.8,
        one_hot=True,
        batch_size=batch_size,
        num_workers=kwargs['workers'])
    net_loss = list()
    train_dataloader, val_dataloader, test_dataloader = loaders
    total_correct, total_examples = 0, 0
    print("[LOG] Evaluating performance on test data...")
    bands_for_testing = [x - 1 for x in kwargs['bands']]
    accuracy_per_district = defaultdict(lambda: [0, 0])
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for idx, data in enumerate(test_dataloader):
            test_x = data['input']
            label = data['label']
            sample_identifiers = data['sample_identifier']
            test_x = test_x.cuda(device=device) if cuda else test_x
            label = label.cuda(device=device) if cuda else label
            _, softmaxed = model.forward(test_x[:, bands_for_testing, :, :])
            pred = torch.argmax(softmaxed, dim=1)
            not_one_hot_target = torch.argmax(label, dim=1)

            # Per-image accuracy, accumulated per district (label 0 excluded).
            for i in range(not_one_hot_target.shape[0]):
                district_name = _district_from_identifier(
                    sample_identifiers[0][i])
                image_target = not_one_hot_target[i].reshape(-1)
                image_valid = image_target != 0
                district_valid_label = image_target[image_valid] - 1
                district_valid_pred = pred[i].reshape(-1)[image_valid]
                accuracy_per_district[district_name][0] += (
                    district_valid_pred == district_valid_label).sum().item()
                accuracy_per_district[district_name][1] += float(
                    district_valid_pred.numel())
                if district_name not in excluded_districts:
                    kept_pred_chunks.append(district_valid_pred.cpu().numpy())
                    kept_truth_chunks.append(
                        district_valid_label.cpu().numpy())
            # NOTE: plotting of per-image error maps into ``destination_path``
            # is disabled in this version.

            # Map the "no label" index 0 onto 1 (so it ends up as class 0),
            # then shift every label down by one for the loss.
            target_for_loss = not_one_hot_target.clone()
            target_for_loss[target_for_loss == 0] = 1
            target_for_loss -= 1
            loss = focal_criterion(softmaxed, target_for_loss)

            label_valid_indices = not_one_hot_target.view(-1) != 0
            # '-1' converts the Forest / Non-Forest labels from 1, 2 to 0, 1.
            valid_label = not_one_hot_target.view(-1)[label_valid_indices] - 1
            valid_pred = pred.view(-1)[label_valid_indices]
            total_correct += float((valid_pred == valid_label).sum().item())
            total_examples += float(valid_pred.numel())
            net_loss.append(loss.item())
            un_confusion_meter.add(predicted=valid_pred, target=valid_label)
            confusion_meter.add(predicted=valid_pred, target=valid_label)
            all_pred_chunks.append(valid_pred.cpu().numpy())
            all_truth_chunks.append(valid_label.cpu().numpy())
            if idx % 10 == 0:
                print('log: on test sample: {}/{}'.format(
                    idx, len(test_dataloader)))

    all_ground_truth = _stack_label_chunks(all_truth_chunks)
    all_predictions = _stack_label_chunks(all_pred_chunks)
    kept_ground_truth = _stack_label_chunks(kept_truth_chunks)
    kept_predictions = _stack_label_chunks(kept_pred_chunks)

    mean_accuracy = (total_correct * 100 / total_examples
                     if total_examples else float('nan'))
    mean_loss = np.asarray(net_loss).mean()
    print('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$')
    print('log: test:: total loss = {:.5f}, total accuracy = {:.5f}%'.format(
        mean_loss, mean_accuracy))
    print('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$')
    print('---> Confusion Matrix:')
    print(confusion_meter.value())
    confusion = confusion_matrix(all_ground_truth, all_predictions)
    print('Confusion Matrix from Scikit-Learn\n')
    print(confusion)
    print('\nClassification Report\n')
    print(
        classification_report(all_ground_truth,
                              all_predictions,
                              target_names=classes))
    print('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$')
    # Get per district test scores without Upper Dir and Chitral districts
    print('[LOG] Per District Test Accuracies')
    print(accuracy_per_district)
    numerator_sum, denominator_sum = 0, 0
    for idx, (this_district, (true, total)) in enumerate(
            accuracy_per_district.items(), 1):
        # A district whose images contain only "no label" pixels has total 0.
        percent = 100 * true / total if total else float('nan')
        print("{}: {} -> {}/{} = {:.2f}%".format(idx, this_district, true,
                                                 total, percent))
        if this_district not in excluded_districts:
            numerator_sum += true
            denominator_sum += total
        else:
            print("[LOG] Skipping {} district for performance testing".format(
                this_district))
    net_percent = (100 * numerator_sum / denominator_sum
                   if denominator_sum else float('nan'))
    print("[LOG] Net Test Accuracy Without Chitral and Upper Dir: {:.2f}%".
          format(net_percent))
    # NOTE: the ConfusionMeter covers all districts; only the scikit-learn
    # matrix and report below are restricted to the kept districts.
    print('---> Confusion Matrix:')
    print(confusion_meter.value())
    confusion = confusion_matrix(kept_ground_truth, kept_predictions)
    print('Confusion Matrix from Scikit-Learn\n')
    print(confusion)
    print('\nClassification Report\n')
    print(
        classification_report(kept_ground_truth,
                              kept_predictions,
                              target_names=classes))
    print('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$')
# Example #6
# 0
def train_net(model, model_topology, generated_data_path, input_dim, bands,
              classes, workers, pre_model, data_split_lists, save_dir, sum_dir,
              error_maps_path, batch_size, lr, epochs, log_after, cuda,
              device):
    """Train ``model`` with focal loss and RMSprop on the selected bands.

    Each epoch trains on the training loader and steps the exponential LR
    schedule (``lr`` decays to ``lr / 10`` over ``epochs``). It then
    evaluates on the validation loader via ``eval_net`` and saves a
    checkpoint ``model_<n>_topology<topology>_lr<initial lr>_bands<#bands>.pt``
    in ``save_dir``. Evaluating and saving *after* training means the weights
    of the final epoch are evaluated and persisted too.

    Args:
        model: network whose ``forward`` returns ``(output, logits)``.
        model_topology: tag embedded in checkpoint file names.
        bands: band indices; ``x - 1`` is used as the input channel index,
            so presumably 1-based.
        classes: class names; loaders are built with ``len(classes) + 1``
            classes.
        pre_model: ``'None'`` to start from scratch, otherwise a checkpoint
            file name in ``save_dir`` whose second ``_``-separated field is
            the model number to continue counting from.
        save_dir, sum_dir: created if missing (``sum_dir`` is otherwise unused).
        error_maps_path: accepted for interface compatibility; unused here.
        generated_data_path, data_split_lists, input_dim, batch_size,
        workers: forwarded to ``get_dataloaders_generated_data``.
        lr, epochs, log_after, cuda, device: optimisation / logging / device
            settings.
    """
    if cuda:
        print('log: Using GPU')
        model.cuda(device=device)
    ###############################################################################
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)
    if not os.path.exists(sum_dir):
        os.mkdir(sum_dir)
    # define loss and optimizer
    optimizer = RMSprop(model.parameters(), lr=lr)
    # save our initial learning rate (used in checkpoint names)
    lr_initial = lr
    weights = torch.Tensor([10, 10])  # per-class loss weights
    weights = weights.cuda(device=device) if cuda else weights
    focal_criterion = FocalLoss2d(weight=weights)
    lr_final = lr / 10  # get to one tenth of the starting rate
    # Guard ``epochs == 0``, which would otherwise divide by zero.
    LR_decay = (lr_final / lr)**(1. / epochs) if epochs > 0 else 1.0
    scheduler = lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=LR_decay)
    loaders = get_dataloaders_generated_data(
        generated_data_path=generated_data_path,
        data_split_lists_path=data_split_lists,
        model_input_size=input_dim,
        bands=bands,
        batch_size=batch_size,
        num_classes=len(classes) + 1,
        train_split=0.8,
        one_hot=True,
        num_workers=workers)
    train_loader, val_dataloader, test_loader = loaders
    best_evaluation = 0.0

    def evaluate(step):
        # eval_net reads both ``classes`` and ``bands`` from its kwargs, so
        # every call must pass them (omitting them raises KeyError).
        return eval_net(model=model,
                        classes=classes,
                        bands=bands,
                        criterion=focal_criterion,
                        val_loader=val_dataloader,
                        cuda=cuda,
                        device=device,
                        writer=None,
                        batch_size=batch_size,
                        step=step)

    ################################################################
    if pre_model == 'None':
        model_number = 0
        print('log: No trained model passed. Starting from scratch...')
    else:
        model_path = os.path.join(save_dir, pre_model)
        model_number = int(pre_model.split('/')[-1].split('_')[1])
        model.load_state_dict(torch.load(model_path), strict=False)
        print('log: Resuming from model {} ...'.format(model_path))
        print('log: Evaluating now...')
        best_evaluation = evaluate(step=0)
        print('LOG: Starting with best evaluation accuracy: {:.3f}%'.format(
            best_evaluation))
    ##########################################################################
    # training loop
    bands_for_training = [x - 1 for x in bands]
    for k in range(epochs):
        net_loss = []
        total_correct, total_examples = 0, 0
        for idx, data in enumerate(train_loader):
            model.train()
            model.zero_grad()
            # get the required bands for training
            test_x, label = data['input'], data['label']
            test_x = test_x.cuda(device=device) if cuda else test_x
            label = label.cuda(device=device) if cuda else label
            out_x, logits = model.forward(test_x[:, bands_for_training, :, :])
            pred = torch.argmax(logits, dim=1)
            not_one_hot_target = torch.argmax(label, dim=1)
            # NOTE(review): unlike the evaluation code, the "no label" index 0
            # is not removed/shifted here, so targets span 0..len(classes)
            # while ``weights`` has 2 entries -- confirm this is intended.
            loss = focal_criterion(logits, not_one_hot_target)
            loss.backward()
            clip_grad_norm_(model.parameters(), 0.05)
            optimizer.step()
            numerator = float((pred == not_one_hot_target).sum().item())
            denominator = float(pred.numel())
            total_correct += numerator
            total_examples += denominator
            if idx % log_after == 0 and idx > 0:
                accuracy = numerator * 100 / denominator
                print(
                    '{}. ({}/{}) input size= {}, output size = {}, loss = {}, accuracy = {}/{} = {:.2f}%'
                    .format(k, idx, len(train_loader), test_x.size(),
                            out_x.size(), loss.item(), numerator, denominator,
                            accuracy))
            net_loss.append(loss.item())
        # this should be done at the end of epoch only
        scheduler.step()  # to dynamically change the learning rate
        mean_accuracy = (total_correct * 100 / total_examples
                         if total_examples else 0.0)
        mean_loss = np.asarray(net_loss).mean()
        print('####################################')
        print('LOG: epoch {} -> total loss = {:.5f}, total accuracy = {:.5f}%'.
              format(k, mean_loss, mean_accuracy))
        print('####################################')

        # Evaluate and checkpoint the weights this epoch just produced; every
        # epoch's model is saved (not only the best one).
        print('log: Evaluating now...')
        evaluation = evaluate(step=k)
        model_number += 1
        model_path = os.path.join(
            save_dir, 'model_{}_topology{}_lr{}_bands{}.pt'.format(
                model_number, model_topology, lr_initial, len(bands)))
        torch.save(model.state_dict(), model_path)
        print('log: Saved {}'.format(model_path))
        if evaluation is not None and evaluation > best_evaluation:
            best_evaluation = evaluation
            print('LOG: New best evaluation accuracy: {:.3f}%'.format(
                best_evaluation))
# Example #7
# 0
def eval_net(**kwargs):
    """Evaluate a segmentation model, excluding NULL pixels, and report metrics.

    The mode is selected by the presence of a ``writer`` key in *kwargs*:

    * training-time validation (``writer`` present): evaluates on
      ``kwargs['val_loader']`` using ``kwargs['criterion']`` as the loss;
    * stand-alone testing (no ``writer``): loads the checkpoint
      ``save_dir/pre_model``, builds the data loaders, evaluates on the test
      split and additionally dumps the normalized / un-normalized confusion
      matrices to ``normalized.pkl`` / ``un_normalized.pkl`` in the current
      working directory.

    Both modes print loss, accuracy, the normalized confusion matrix, the
    scikit-learn confusion matrix and a classification report, and return the
    mean pixel accuracy in percent.

    Pixels whose target class index is 0 are treated as NULL and are excluded
    from the accuracy and confusion statistics; the remaining targets are
    shifted down by one (1, 2 -> 0, 1) to line up with the predicted indices.
    Labels are expected one-hot along dim 1 (the class index is recovered
    with ``argmax``).

    Required kwargs in both modes: ``model``, ``classes`` (class names, also
    the report's ``target_names``), ``cuda``, ``device`` and ``bands``
    (1-based indices of the input channels fed to the model). Testing mode
    additionally reads ``pre_model``, ``save_dir``, ``batch_size``,
    ``generated_data_path``, ``data_split_lists``, ``input_dim`` and
    ``workers``.

    Raises:
        ZeroDivisionError: if the evaluated loader yields no non-NULL pixel.
    """
    model = kwargs['model']
    classes = kwargs['classes']
    num_classes = len(classes)
    cuda = kwargs['cuda']
    device = kwargs['device']
    model.eval()
    if cuda:
        model.cuda(device=device)
    # 'bands' is 1-based; convert to 0-based channel indices for slicing.
    bands_for_testing = [x - 1 for x in kwargs['bands']]

    def _evaluate(loader, criterion, un_confusion_meter, confusion_meter,
                  log_progress):
        """Run the model over *loader* and update both confusion meters.

        Returns (total_correct, total_examples, net_loss, all_predictions,
        all_ground_truth); the last two are 1-D numpy arrays holding the
        non-NULL predicted and target class indices of every batch.
        """
        total_correct, total_examples, net_loss = 0, 0, []
        prediction_chunks, ground_truth_chunks = [], []
        # Inference only: do not build the autograd graph (saves memory/time).
        with torch.no_grad():
            for idx, data in enumerate(loader):
                test_x, label = data['input'], data['label']
                test_x = test_x.cuda(device=device) if cuda else test_x
                label = label.cuda(device=device) if cuda else label
                out_x, softmaxed = model.forward(
                    test_x[:, bands_for_testing, :, :])
                pred = torch.argmax(softmaxed, dim=1)
                not_one_hot_target = torch.argmax(label, dim=1)
                # For the loss, NULL (0) targets are folded into class 1 and
                # everything is then shifted down by one, so targets 0 and 1
                # both become 0 and every other target k becomes k - 1.
                target_for_loss = not_one_hot_target.clone()
                target_for_loss[target_for_loss == 0] = 1
                target_for_loss -= 1
                loss = criterion(softmaxed, target_for_loss)
                # Drop NULL pixels from the statistics and convert the
                # Forest / Non-Forest labels from 1, 2 to 0, 1.
                valid_indices = (not_one_hot_target.view(-1) != 0)
                valid_label = not_one_hot_target.view(-1)[valid_indices] - 1
                valid_pred = pred.view(-1)[valid_indices]
                total_correct += float((valid_pred == valid_label).sum().item())
                total_examples += float(valid_pred.numel())
                net_loss.append(loss.item())
                un_confusion_meter.add(predicted=valid_pred.view(-1),
                                       target=valid_label.view(-1))
                confusion_meter.add(predicted=valid_pred.view(-1),
                                    target=valid_label.view(-1))
                prediction_chunks.append(valid_pred.view(-1).cpu().numpy())
                ground_truth_chunks.append(valid_label.view(-1).cpu().numpy())
                if log_progress and idx % 10 == 0:
                    print('log: on test sample: {}/{}'.format(
                        idx, len(loader)))
        # Concatenate once instead of re-allocating a growing array per batch,
        # which was quadratic in the number of evaluated pixels.
        all_predictions = (np.concatenate(prediction_chunks)
                           if prediction_chunks else np.array([]))
        all_ground_truth = (np.concatenate(ground_truth_chunks)
                            if ground_truth_chunks else np.array([]))
        return (total_correct, total_examples, net_loss, all_predictions,
                all_ground_truth)

    if 'writer' in kwargs:
        # Evaluation at training time on the caller-supplied loader.
        un_confusion_meter = tnt.meter.ConfusionMeter(num_classes,
                                                      normalized=False)
        confusion_meter = tnt.meter.ConfusionMeter(num_classes,
                                                   normalized=True)
        (total_correct, total_examples, net_loss, all_predictions,
         all_ground_truth) = _evaluate(kwargs['val_loader'],
                                       kwargs['criterion'],
                                       un_confusion_meter, confusion_meter,
                                       log_progress=False)
        mean_accuracy = total_correct * 100 / total_examples
        mean_loss = np.asarray(net_loss).mean()
        print('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$')
        print(
            'LOG: validation: total loss = {:.5f}, total accuracy = ({}/{}) = {:.5f}%'
            .format(mean_loss, total_correct, total_examples, mean_accuracy))
        print('Log: Confusion matrix')
        print(confusion_meter.value())
        confusion = confusion_matrix(all_ground_truth, all_predictions)
        print('Confusion Matrix from Scikit-Learn\n')
        print(confusion)
        print('\nClassification Report\n')
        print(
            classification_report(all_ground_truth,
                                  all_predictions,
                                  target_names=classes))
        print('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$')
        return mean_accuracy

    # Stand-alone evaluation of a saved checkpoint on the test split.
    pre_model = kwargs['pre_model']
    batch_size = kwargs['batch_size']
    un_confusion_meter = tnt.meter.ConfusionMeter(num_classes,
                                                  normalized=False)
    confusion_meter = tnt.meter.ConfusionMeter(num_classes, normalized=True)
    model_path = os.path.join(kwargs['save_dir'], pre_model)
    model.load_state_dict(torch.load(model_path, map_location='cpu'),
                          strict=False)
    print('log: resumed model {} successfully!'.format(pre_model))
    # Equal weights for two classes (only used for the reported loss).
    weights = torch.Tensor([10, 10])
    weights = weights.cuda(device=device) if cuda else weights
    focal_criterion = FocalLoss2d(weight=weights)
    loaders = get_dataloaders_generated_data(
        generated_data_path=kwargs['generated_data_path'],
        data_split_lists_path=kwargs['data_split_lists'],
        bands=kwargs['bands'],
        model_input_size=kwargs['input_dim'],
        num_classes=num_classes,
        train_split=0.8,
        one_hot=True,
        batch_size=batch_size,
        num_workers=kwargs['workers'])
    _, _, test_dataloader = loaders
    print("(LOG): Evaluating performance on test data...")
    (total_correct, total_examples, net_loss, all_predictions,
     all_ground_truth) = _evaluate(test_dataloader, focal_criterion,
                                   un_confusion_meter, confusion_meter,
                                   log_progress=True)
    mean_accuracy = total_correct * 100 / total_examples
    mean_loss = np.asarray(net_loss).mean()
    print('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$')
    print(
        'log: test:: total loss = {:.5f}, total accuracy = {:.5f}%'.format(
            mean_loss, mean_accuracy))
    print('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$')
    print('---> Confusion Matrix:')
    print(confusion_meter.value())
    confusion = confusion_matrix(all_ground_truth, all_predictions)
    print('Confusion Matrix from Scikit-Learn\n')
    print(confusion)
    print('\nClassification Report\n')
    print(
        classification_report(all_ground_truth,
                              all_predictions,
                              target_names=classes))
    with open('normalized.pkl', 'wb') as this:
        pkl.dump(confusion_meter.value(),
                 this,
                 protocol=pkl.HIGHEST_PROTOCOL)
    with open('un_normalized.pkl', 'wb') as this:
        pkl.dump(un_confusion_meter.value(),
                 this,
                 protocol=pkl.HIGHEST_PROTOCOL)
    return mean_accuracy
def train_net(model, generated_data_path, input_dim, workers, pre_model, save_data, save_dir, sum_dir, batch_size,
              lr, epochs, log_after, cuda, device):
    """Train a two-class segmentation model and keep its best checkpoints.

    The model is optimized with RMSprop and a focal loss, with gradient-norm
    clipping at 0.05; the learning rate decays exponentially from ``lr`` to
    5e-5 over ``epochs`` epochs. After each epoch the model is validated with
    ``eval_net``. Whenever the returned validation accuracy beats the best so
    far, the weights are written to ``save_dir/model-<n>.pt`` and
    ``model-<n-6>.pt`` is deleted if present.

    Args:
        model: network whose ``forward`` returns a pair; the second element is
            used as the logits for the loss and for predictions.
        generated_data_path, save_data, input_dim, batch_size, workers: passed
            through to ``get_dataloaders_generated_data``.
        pre_model: checkpoint number to resume from, or -1 to start fresh.
        save_dir: checkpoint directory (created, with parents, if missing).
        sum_dir: summary directory (created if missing; not written here).
        lr: initial learning rate.
        epochs: number of training epochs.
        log_after: print a progress line every ``log_after`` batches.
        cuda, device: whether to use the GPU and which device index.
    """
    if cuda:
        print('log: Using GPU')
        model.cuda(device=device)

    # makedirs also creates missing parent directories.
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(sum_dir, exist_ok=True)

    # define loss and optimizer
    optimizer = RMSprop(model.parameters(), lr=lr)
    weights = torch.Tensor([1, 1])  # equal weights for the two classes
    weights = weights.cuda(device=device) if cuda else weights
    focal_criterion = FocalLoss2d(weight=weights)

    # Per-epoch decay factor chosen so the learning rate reaches lr_final
    # after `epochs` scheduler steps.
    lr_final = 5e-5
    LR_decay = (lr_final / lr) ** (1. / epochs)
    scheduler = lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=LR_decay)

    loaders = get_dataloaders_generated_data(generated_data_path=generated_data_path, save_data_path=save_data,
                                             model_input_size=input_dim, batch_size=batch_size, num_classes=2,
                                             train_split=0.8, one_hot=True, num_workers=workers, max_label=1)
    train_loader, val_dataloader, _ = loaders

    best_evaluation = 0.0
    if pre_model == -1:
        model_number = 0
        print('log: No trained model passed. Starting from scratch...')
    else:
        model_number = pre_model
        model_path = os.path.join(save_dir, 'model-{}.pt'.format(pre_model))
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only runs;
        # load_state_dict then copies the tensors onto the model's device.
        model.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)
        print('log: Resuming from model {} ...'.format(model_path))
        print('log: Evaluating now...')
        best_evaluation = eval_net(model=model, criterion=focal_criterion, val_loader=val_dataloader,
                                   cuda=cuda, device=device, writer=None, batch_size=batch_size, step=0)
        print('LOG: Starting with best evaluation accuracy: {:.3f}%'.format(best_evaluation))

    # training loop
    for k in range(epochs):
        net_loss = []
        total_correct, total_examples = 0, 0
        # eval_net switches the model to eval mode, so re-enable training
        # mode at the start of every epoch.
        model.train()
        for idx, data in enumerate(train_loader):
            model.zero_grad()
            test_x, label = data['input'], data['label']
            test_x = test_x.cuda(device=device) if cuda else test_x
            label = label.cuda(device=device) if cuda else label
            out_x, logits = model.forward(test_x)
            pred = torch.argmax(logits, dim=1)
            # Labels are one-hot along dim 1; recover the class indices.
            not_one_hot_target = torch.argmax(label, dim=1)
            loss = focal_criterion(logits, not_one_hot_target)
            loss.backward()
            clip_grad_norm_(model.parameters(), 0.05)
            optimizer.step()
            numerator = float((pred == not_one_hot_target).sum().item())
            denominator = float(pred.numel())
            total_correct += numerator
            total_examples += denominator

            if idx % log_after == 0 and idx > 0:
                accuracy = numerator * 100 / denominator
                print('{}. ({}/{}) output size = {}, loss = {}, '
                      'accuracy = {}/{} = {:.2f}%'.format(k, idx, len(train_loader), out_x.size(), loss.item(),
                                                          numerator, denominator, accuracy))
            net_loss.append(loss.item())

        # the learning rate is decayed once per epoch
        scheduler.step()
        mean_accuracy = total_correct*100/total_examples
        mean_loss = np.asarray(net_loss).mean()
        print('####################################')
        print('LOG: epoch {} -> total loss = {:.5f}, total accuracy = {:.5f}%'.format(k, mean_loss, mean_accuracy))
        print('####################################')

        # validate model
        print('log: Evaluating now...')
        eval_accuracy = eval_net(model=model, criterion=focal_criterion, val_loader=val_dataloader,
                                 cuda=cuda, device=device, writer=None, batch_size=batch_size, step=k)

        # save best performing models only
        if eval_accuracy > best_evaluation:
            best_evaluation = eval_accuracy
            model_number += 1
            model_path = os.path.join(save_dir, 'model-{}.pt'.format(model_number))
            torch.save(model.state_dict(), model_path)
            print('log: Saved best performing {}'.format(model_path))

            # keep only the most recent checkpoints on disk
            del_this = os.path.join(save_dir, 'model-{}.pt'.format(model_number-6))
            if os.path.exists(del_this):
                os.remove(del_this)
                print('log: Removed {}'.format(del_this))
def eval_net(**kwargs):
    """Evaluate a two-class segmentation model and report pixel accuracy.

    The mode is selected by the presence of a ``writer`` key in *kwargs*:

    * training-time validation (``writer`` present): evaluates on
      ``kwargs['val_loader']`` using ``kwargs['criterion']``, prints loss,
      accuracy and the normalized confusion matrix;
    * stand-alone testing (no ``writer``): loads
      ``save_dir/model-<pre_model>.pt``, builds the data loaders, evaluates on
      the second loader returned, prints the results and dumps the normalized
      / un-normalized confusion matrices to ``normalized.pkl`` /
      ``un_normalized.pkl`` in the current working directory.

    Both modes return the mean pixel accuracy in percent. Labels are expected
    one-hot along dim 1; the class index is recovered with ``argmax``.

    Required kwargs: ``model``, ``cuda``, ``device``. Testing mode also reads
    ``pre_model``, ``batch_size``, ``save_dir``, ``generated_data_path``,
    ``save_data``, ``input_dim`` and ``workers``.

    Raises:
        ZeroDivisionError: if the evaluated loader yields no pixels.
    """
    cuda = kwargs['cuda']
    device = kwargs['device']
    model = kwargs['model']
    model.eval()
    if cuda:
        model.cuda(device=device)
    num_classes = 2

    def _evaluate(loader, criterion, un_confusion_meter, confusion_meter, log_progress):
        """Run the model over *loader*, updating both confusion meters.

        Returns (mean_accuracy, mean_loss, total_correct, total_examples).
        """
        total_correct, total_examples, net_loss = 0, 0, []
        # Inference only: do not build the autograd graph (saves memory/time).
        with torch.no_grad():
            for idx, data in enumerate(loader):
                test_x, label = data['input'], data['label']
                test_x = test_x.cuda(device=device) if cuda else test_x
                label = label.cuda(device=device) if cuda else label
                out_x, softmaxed = model.forward(test_x)
                pred = torch.argmax(softmaxed, dim=1)
                # Labels are already binary (forest is 0, non-forest is 1), so
                # only the one-hot encoding has to be undone.
                not_one_hot_target = torch.argmax(label, dim=1)
                loss = criterion(softmaxed, not_one_hot_target)
                total_correct += float((pred == not_one_hot_target).sum().item())
                total_examples += float(pred.numel())
                net_loss.append(loss.item())
                un_confusion_meter.add(predicted=pred.view(-1), target=not_one_hot_target.view(-1))
                confusion_meter.add(predicted=pred.view(-1), target=not_one_hot_target.view(-1))
                if log_progress and idx % 10 == 0:
                    print('log: on {}'.format(idx))
        mean_accuracy = total_correct*100/total_examples
        mean_loss = np.asarray(net_loss).mean()
        return mean_accuracy, mean_loss, total_correct, total_examples

    un_confusion_meter = tnt.meter.ConfusionMeter(num_classes, normalized=False)
    confusion_meter = tnt.meter.ConfusionMeter(num_classes, normalized=True)

    if 'writer' in kwargs:
        # Evaluation at training time on the caller-supplied loader.
        mean_accuracy, mean_loss, total_correct, total_examples = _evaluate(
            kwargs['val_loader'], kwargs['criterion'], un_confusion_meter, confusion_meter,
            log_progress=False)
        print('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$')
        print('LOG: validation: total loss = {:.5f}, total accuracy = ({}/{}) = {:.5f}%'.format(mean_loss,
                                                                                                total_correct,
                                                                                                total_examples,
                                                                                                mean_accuracy))
        print('Log: Confusion matrix')
        print(confusion_meter.value())
        print('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$')
        return mean_accuracy

    # Stand-alone evaluation of a saved checkpoint.
    pre_model = kwargs['pre_model']
    batch_size = kwargs['batch_size']
    model_path = os.path.join(kwargs['save_dir'], 'model-{}.pt'.format(pre_model))
    model.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)
    print('log: resumed model {} successfully!'.format(pre_model))
    weights = torch.Tensor([1, 1])  # equal weights for the two classes
    weights = weights.cuda(device=device) if cuda else weights
    focal_criterion = FocalLoss2d(weight=weights)
    loaders = get_dataloaders_generated_data(generated_data_path=kwargs['generated_data_path'],
                                             save_data_path=kwargs['save_data'],
                                             model_input_size=kwargs['input_dim'],
                                             batch_size=batch_size,
                                             one_hot=True,
                                             num_workers=kwargs['workers'],
                                             max_label=num_classes)
    # NOTE(review): no train_split is passed and the second loader is treated
    # as the test set -- confirm against get_dataloaders_generated_data.
    _, test_loader, _ = loaders
    mean_accuracy, mean_loss, _, _ = _evaluate(
        test_loader, focal_criterion, un_confusion_meter, confusion_meter, log_progress=True)
    print('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$')
    print('log: test:: total loss = {:.5f}, total accuracy = {:.5f}%'.format(mean_loss, mean_accuracy))
    print('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$')
    print('---> Confusion Matrix:')
    print(confusion_meter.value())
    with open('normalized.pkl', 'wb') as this:
        pkl.dump(confusion_meter.value(), this, protocol=pkl.HIGHEST_PROTOCOL)
    with open('un_normalized.pkl', 'wb') as this:
        pkl.dump(un_confusion_meter.value(), this, protocol=pkl.HIGHEST_PROTOCOL)
    return mean_accuracy
def train_net(model, generated_data_path, images, labels, block_size,
              input_dim, workers, pre_model, save_data, save_dir, sum_dir,
              batch_size, lr, epochs, log_after, cuda, device):
    """Train a segmentation model with focal loss, checkpointing every epoch.

    Uses RMSprop with gradient-norm clipping at 0.05 and an exponential
    learning-rate decay from ``lr`` to 3e-7 over ``epochs`` epochs. A
    checkpoint ``save_dir/model-<n>.pt`` is written at the start of every
    epoch and once more after the last one; each new checkpoint deletes the
    one six numbers older. After each epoch the model is validated with
    ``eval_net``.

    Args:
        model: network whose ``forward`` returns a pair; the second element is
            used for predictions.
        generated_data_path, save_data, input_dim, batch_size, workers: passed
            through to ``get_dataloaders_generated_data``.
        images, labels, block_size: unused; kept for interface compatibility.
        pre_model: checkpoint number to resume from, or -1 to start fresh.
        save_dir: checkpoint directory (created, with parents, if missing).
        sum_dir: summary directory (created if missing; not written here).
        lr: initial learning rate.
        epochs: number of training epochs.
        log_after: print a progress line every ``log_after`` batches.
        cuda, device: whether to use the GPU and which device index.
    """
    if cuda:
        print('log: Using GPU')
        model.cuda(device=device)
    # define loss and optimizer
    optimizer = RMSprop(model.parameters(), lr=lr)
    criterion = FocalLoss2d()
    # Per-epoch decay factor chosen so the learning rate reaches lr_final
    # after `epochs` scheduler steps.
    lr_final = 0.0000003
    LR_decay = (lr_final / lr)**(1. / epochs)
    scheduler = lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=LR_decay)

    loaders = get_dataloaders_generated_data(
        generated_data_path=generated_data_path,
        save_data_path=save_data,
        model_input_size=input_dim,
        batch_size=batch_size,
        train_split=0.8,
        num_workers=workers,
        max_label=16)
    train_loader, val_dataloader, _ = loaders

    # makedirs also creates missing parent directories.
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(sum_dir, exist_ok=True)

    def _save_checkpoint(number):
        """Save weights as model-<number>.pt (if absent) and drop model-<number-6>.pt."""
        model_path = os.path.join(save_dir, 'model-{}.pt'.format(number))
        if not os.path.exists(model_path):
            torch.save(model.state_dict(), model_path)
            print('log: saved {}'.format(model_path))
            # keep only the most recent checkpoints on disk
            del_this = os.path.join(save_dir,
                                    'model-{}.pt'.format(number - 6))
            if os.path.exists(del_this):
                os.remove(del_this)
                print('log: removed {}'.format(del_this))

    if pre_model == -1:
        model_number = 0
        print('log: No trained model passed. Starting from scratch...')
    else:
        model_number = pre_model
        model_path = os.path.join(save_dir, 'model-{}.pt'.format(pre_model))
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only runs;
        # load_state_dict then copies the tensors onto the model's device.
        model.load_state_dict(torch.load(model_path, map_location='cpu'))
        print('log: Resuming from model {} ...'.format(model_path))
    ###############################################################################
    # training loop
    for k in range(epochs):
        net_loss = []
        net_accuracy = []
        # Checkpoint model_number + k holds the weights at the start of epoch k.
        _save_checkpoint(model_number + k)
        # eval_net switches the model to eval mode, so re-enable training
        # mode at the start of every epoch.
        model.train()
        for idx, data in enumerate(train_loader):
            test_x, label = data['input'], data['label']
            test_x = test_x.cuda(device=device) if cuda else test_x
            label = label.cuda(device=device) if cuda else label
            out_x, logits = model.forward(test_x)
            pred = torch.argmax(logits, dim=1)
            # NOTE(review): the loss is computed on the first model output
            # while predictions come from the second -- confirm intended.
            loss = criterion(out_x, label)
            # Labels are compared directly with the predicted class indices.
            numerator = (pred == label).sum().item()
            # Count the compared pixels directly (valid for non-square inputs).
            denominator = float(pred.numel())
            accuracy = float(numerator) * 100 / denominator
            if idx % log_after == 0 and idx > 0:
                print('{}. ({}/{}) output size = {}, loss = {}, '
                      'accuracy = {}/{} = {:.2f}%'.format(
                          k, idx, len(train_loader), out_x.size(), loss.item(),
                          numerator, denominator, accuracy))
            #################################
            # three steps for backprop
            model.zero_grad()
            loss.backward()
            # perform gradient clipping between loss backward and optimizer step
            clip_grad_norm_(model.parameters(), 0.05)
            optimizer.step()
            net_accuracy.append(accuracy)
            net_loss.append(loss.item())
            #################################

        # the learning rate is decayed once per epoch
        scheduler.step()
        mean_accuracy = np.asarray(net_accuracy).mean()
        mean_loss = np.asarray(net_loss).mean()
        print('####################################')
        print('LOG: epoch {} -> total loss = {:.5f}, total accuracy = {:.5f}%'.
              format(k, mean_loss, mean_accuracy))
        print('####################################')

        # validate model
        print('log: Evaluating now...')
        eval_net(model=model,
                 criterion=criterion,
                 val_loader=val_dataloader,
                 cuda=cuda,
                 device=device,
                 writer=None,
                 batch_size=batch_size,
                 step=k)

    # Checkpoints are taken at the start of each epoch, so the weights produced
    # by the final epoch must be saved explicitly or they are lost.
    _save_checkpoint(model_number + epochs)