Example #1
import os
import pickle

import numpy as np

# `datasets`, `metrics`, and `evaluate` are project-local helpers.

def train_non_dl_model(model, model_name, dataloaders, args, X, y):
    # Note: the X and y arguments are overwritten by datasets.get_Xy inside
    # the loop below, so the values passed in are never used.
    results = {'train_acc': [], 'train_f1': [], 'val_acc': [], 'val_f1': [], 'test_acc': [], 'test_f1': []}
    splits = ['test'] if args.eval_on_test else ['train', 'val', 'test']
    for rep in range(args.num_repeat):
        for split in splits:
            dl = dataloaders[split]
            X, y = datasets.get_Xy(dl, args.country)

            if split == 'train':
                model.fit(X, y)

                # save the fitted model
                with open(os.path.join(args.save_dir, args.name + "_pkl"), "wb") as output_file:
                    pickle.dump(model, output_file)

            preds = model.predict(X)
            _, cm, accuracy, _ = evaluate(model_name, preds, y, args.country, reduction='avg')
            f1 = metrics.get_f1score(cm, avg=True)

            print(f'{split} accuracy: {accuracy}, {split} f1-score: {f1}')
            results[f'{split}_acc'].append(accuracy)
            results[f'{split}_f1'].append(f1)
            print(f'{split} cm: {cm}')
            print(f'{split} per class f1 scores: {metrics.get_f1score(cm, avg=False)}')

    # Print the header once, not once per split
    print('\n------------------------\nOverall Results:\n')
    for split in splits:
        print(f"{split} accuracy: {np.mean(results[f'{split}_acc'])} +/- {np.std(results[f'{split}_acc'])}")
        print(f"{split} f1-score: {np.mean(results[f'{split}_f1'])} +/- {np.std(results[f'{split}_f1'])}")
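
A minimal invocation sketch for the function above. The model choice, the argument values, and the `dataloaders` dict are assumptions for illustration, not part of the original code.

# Hypothetical usage sketch; argument values and the model choice are
# illustrative assumptions, and `dataloaders` must come from the project.
from argparse import Namespace
from sklearn.ensemble import RandomForestClassifier

args = Namespace(num_repeat=3, eval_on_test=False, country='ghana',
                 save_dir='./results', name='rf_baseline')
model = RandomForestClassifier(n_estimators=100)
# X and y are overwritten inside the function, so placeholders suffice here
train_non_dl_model(model, 'random_forest', dataloaders, args, X=None, y=None)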
Example #2
import numpy as np
import torch

def evaluate_split(model, model_name, split_loader, device, loss_weight, weight_scale, gamma, num_classes, country, var_length):
    total_loss = 0
    total_pixels = 0
    total_cm = np.zeros((num_classes, num_classes), dtype=int)
    loss_fn = loss_fns.get_loss_fn(model_name)
    for inputs, targets, cloudmasks, hres_inputs in split_loader:
        with torch.no_grad():
            # torch.Tensor.to is not in-place, so the result must be reassigned
            if not var_length:
                inputs = inputs.to(device)
            else:
                for sat in inputs:
                    if "length" not in sat:
                        inputs[sat] = inputs[sat].to(device)
            targets = targets.to(device)
            if hres_inputs is not None:
                hres_inputs = hres_inputs.to(device)

            preds = model(inputs, hres_inputs) if model_name in MULTI_RES_MODELS else model(inputs)   
            batch_loss, batch_cm, _, num_pixels, confidence = evaluate(model_name, preds, targets, country, loss_fn=loss_fn, reduction="sum", loss_weight=loss_weight, weight_scale=weight_scale, gamma=gamma)
            total_loss += batch_loss.item()
            total_pixels += num_pixels
            total_cm += batch_cm

    f1_avg = metrics.get_f1score(total_cm, avg=True)
    acc_avg = np.trace(total_cm) / np.sum(total_cm)
    return total_loss / total_pixels, f1_avg, acc_avg 
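
A sketch of how evaluate_split might be called on a validation loader. The model name, class count, and hyperparameter values are illustrative assumptions.

# Hypothetical call; model, dataloaders, and all hyperparameter values
# are assumptions for illustration.
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.eval()  # evaluate_split disables gradients but does not set eval mode
val_loss, val_f1, val_acc = evaluate_split(
    model, 'fcn_crnn', dataloaders['val'], device,
    loss_weight=None, weight_scale=1.0, gamma=None,
    num_classes=5, country='ghana', var_length=False)
print(f'val loss: {val_loss:.4f}, f1: {val_f1:.4f}, acc: {val_acc:.4f}')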
Example #3
import os

import numpy as np

def classification_report(all_metrics, split, epoch_num, country, save_dir):
    if country in ['ghana', 'southsudan', 'tanzania', 'germany']:
        class_names = CROPS[country]
    else:
        raise ValueError(
            f"Country {country} not supported in classification_report")

    # Initialize so the report below never hits a NameError
    loss_epoch, acc_epoch = None, None
    if all_metrics[f'{split}_loss'] is not None:
        loss_epoch = all_metrics[f'{split}_loss'] / all_metrics[f'{split}_pix']
    if all_metrics[f'{split}_correct'] is not None:
        acc_epoch = all_metrics[f'{split}_correct'] / all_metrics[f'{split}_pix']

    cm = all_metrics[f'{split}_cm']
    total = np.sum(cm)
    observed_accuracy = np.sum(cm.diagonal()) / total
    # Expected accuracy for Cohen's kappa: sum over classes of
    # (column marginal * row marginal), normalized by total**2
    expected_accuracy = np.sum(cm.sum(axis=0) * cm.sum(axis=1)) / total**2
    kappa = (observed_accuracy - expected_accuracy) / (1 - expected_accuracy)

    fname = os.path.join(save_dir, split + '_classification_report.txt')
    with open(fname, 'a') as f:
        f.write('Country:\n ' + country + '\n\n')
        f.write('Epoch number:\n ' + str(epoch_num) + '\n\n')
        f.write('Split:\n ' + split + '\n\n')
        f.write('Epoch Loss:\n ' + str(loss_epoch) + '\n\n')
        f.write('Epoch Accuracy:\n ' + str(acc_epoch) + '\n\n')
        f.write('Observed Accuracy:\n ' + str(observed_accuracy) + '\n\n')
        f.write('Epoch f1:\n ' + str(metrics.get_f1score(cm, avg=True)) + '\n\n')
        f.write('Kappa coefficient:\n ' + str(kappa) + '\n\n')
        f.write('Per class accuracies:\n ' +
                str(cm.diagonal() / cm.sum(axis=1)) + '\n\n')
        f.write('Per class f1 scores:\n ' +
                str(metrics.get_f1score(cm, avg=False)) + '\n\n')
        f.write('Crop Class Names:\n ' + str(class_names) + '\n\n')
        f.write('Confusion Matrix:\n ' + str(cm) + '\n\n')
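
The kappa computation above can be checked in isolation. Here is a toy example with a 2x2 confusion matrix (rows = ground truth, columns = predictions, an assumed layout; the counts are made up):

import numpy as np

# Toy confusion matrix; values are invented for illustration
cm = np.array([[50, 10],
               [5, 35]])
n = cm.sum()
observed = np.trace(cm) / n                                # 0.85
expected = np.sum(cm.sum(axis=0) * cm.sum(axis=1)) / n**2  # 0.51
kappa = (observed - expected) / (1 - expected)             # ~0.694
print(observed, expected, kappa)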
Example #4
import os

import torch
from tqdm import tqdm

def train_dl_model(model, model_name, dataloaders, args):
    splits = ['train', 'val'] if not args.eval_on_test else ['test']
    
    if args.clip_val:
        clip_val = sum(p.numel() for p in model.parameters() if p.requires_grad) // 20000
        print('clip value: ', clip_val)

    # set up information lists for visdom    
    vis_logger = visualize.VisdomLogger(args.env_name, model_name, args.country, splits)
    loss_fn = loss_fns.get_loss_fn(model_name)
    optimizer = loss_fns.get_optimizer(model.parameters(), args.optimizer, args.lr, args.momentum, args.weight_decay)
    best_val_f1 = 0
    
    for i in range(args.epochs if not args.eval_on_test else 1):
        print('Epoch: {}'.format(i))
        
        vis_logger.reset_epoch_data()
        
        for split in splits:
            dl = dataloaders[split]
            # Bug fix: the original compared `split == ['train']` (a string
            # against a list), which is always False, so the model was never
            # put into training mode
            if split == 'train':
                model.train()
            else:
                model.eval()
            # TODO: figure out how to pack inputs from dataloader together in the case of variable length sequences
            for inputs, targets, cloudmasks, hres_inputs in tqdm(dl):
                with torch.set_grad_enabled(split == 'train'):
                    # torch.Tensor.to is not in-place; reassign the result
                    if not args.var_length:
                        inputs = inputs.to(args.device)
                        if hres_inputs is not None:
                            hres_inputs = hres_inputs.to(args.device)
                    else:
                        for sat in inputs:
                            if "length" not in sat:
                                inputs[sat] = inputs[sat].to(args.device)
                    targets = targets.to(args.device)
                    preds = model(inputs, hres_inputs) if model_name in MULTI_RES_MODELS else model(inputs)
                    loss, cm_cur, total_correct, num_pixels, confidence = evaluate(model_name, preds, targets, args.country, loss_fn=loss_fn, 
                                              reduction="sum", loss_weight=args.loss_weight, weight_scale=args.weight_scale, gamma=args.gamma)
 
                    if split == 'train' and loss is not None:
                        # Update weights only when the batch contains valid pixels
                        optimizer.zero_grad()
                        loss.backward()
                        if args.clip_val:
                            # clip_grad_norm_ helps prevent exploding gradients in RNNs / LSTMs
                            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_val)
                        optimizer.step()

                        # Global L2 norm of the gradients, for logging
                        total_norm = 0
                        for p in model.parameters():
                            if p.grad is not None:
                                total_norm += p.grad.data.norm(2).item() ** 2
                        gradnorm = total_norm ** 0.5

                        vis_logger.update_progress('train', 'gradnorm', gradnorm)
                    
                    if cm_cur is not None:
                        # Update metrics only when the batch contains valid pixels
                        vis_logger.update_epoch_all(split, cm_cur, loss, total_correct, num_pixels)
                
                vis_logger.record_batch(inputs, cloudmasks, targets, preds, confidence, 
                                        NUM_CLASSES[args.country], split, 
                                        args.include_doy, args.use_s1, args.use_s2, 
                                        model_name, args.time_slice, var_length=args.var_length)

            if split == 'test':
                vis_logger.record_epoch(split, i, args.country, save=False, save_dir=os.path.join(args.save_dir, args.name + "_best_dir"))
            else:
                vis_logger.record_epoch(split, i, args.country)

            if split == 'val':
                val_f1 = metrics.get_f1score(vis_logger.epoch_data['val_cm'], avg=True)                 

                if val_f1 > best_val_f1:
                    torch.save(model.state_dict(), os.path.join(args.save_dir, args.name + "_best"))
                    best_val_f1 = val_f1
                    if args.save_best: 
                        # TODO: Ideally, this would save any batch except the last one so that the saved images
                        #  are not only the remainder from the last batch 
                        vis_logger.record_batch(inputs, cloudmasks, targets, preds, confidence, 
                                                NUM_CLASSES[args.country], split, 
                                                args.include_doy, args.use_s1, args.use_s2, 
                                                model_name, args.time_slice, save=True, var_length=args.var_length, 
                                                save_dir=os.path.join(args.save_dir, args.name + "_best_dir"))

                        vis_logger.record_epoch(split, i, args.country, save=True, 
                                              save_dir=os.path.join(args.save_dir, args.name + "_best_dir"))               

                        vis_logger.record_epoch('train', i, args.country, save=True, 
                                              save_dir=os.path.join(args.save_dir, args.name + "_best_dir"))               
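
The clip-then-step pattern and the gradient-norm bookkeeping from the loop above, reduced to a self-contained sketch on a toy model. All names and values here are illustrative, not from the project.

import torch
import torch.nn as nn

model = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

out, _ = model(torch.randn(4, 10, 8))
loss = out.pow(2).mean()

optimizer.zero_grad()
loss.backward()
# Heuristic from train_dl_model: the clip value scales with parameter count;
# max(..., 1) guards against a zero clip value on tiny models
clip_val = sum(p.numel() for p in model.parameters() if p.requires_grad) // 20000
torch.nn.utils.clip_grad_norm_(model.parameters(), max(clip_val, 1))
optimizer.step()

# Global L2 norm of the gradients, as logged to visdom above
total_norm = sum(p.grad.data.norm(2).item() ** 2
                 for p in model.parameters() if p.grad is not None) ** 0.5
print(total_norm)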
Example #5
    def record_epoch(self,
                     split,
                     epoch_num,
                     country,
                     save=False,
                     save_dir=None):
        """ Record values for epoch in visdom
        """
        if country in ['ghana', 'southsudan', 'tanzania', 'germany']:
            class_names = CROPS[country]
        else:
            raise ValueError(
                f"Country {country} not supported in visualize.py, record_epoch"
            )

        if self.epoch_data[f'{split}_loss'] is not None:
            loss_epoch = self.epoch_data[f'{split}_loss'] / self.epoch_data[
                f'{split}_pix']
        if self.epoch_data[f'{split}_correct'] is not None:
            acc_epoch = self.epoch_data[f'{split}_correct'] / self.epoch_data[
                f'{split}_pix']

        # Don't append when saving; the values have already been appended.
        if not save:
            self.progress_data[f'{split}_loss'].append(loss_epoch)
            self.progress_data[f'{split}_acc'].append(acc_epoch)
            self.progress_data[f'{split}_f1'].append(
                metrics.get_f1score(self.epoch_data[f'{split}_cm'], avg=True))

            if self.progress_data[f'{split}_classf1'] is None:
                self.progress_data[f'{split}_classf1'] = metrics.get_f1score(
                    self.epoch_data[f'{split}_cm'], avg=False)
                self.progress_data[f'{split}_classf1'] = np.vstack(
                    self.progress_data[f'{split}_classf1']).T
            else:
                self.progress_data[f'{split}_classf1'] = np.vstack(
                    (self.progress_data[f'{split}_classf1'],
                     metrics.get_f1score(self.epoch_data[f'{split}_cm'],
                                         avg=False)))

        for cur_metric in ['loss', 'acc', 'f1']:
            visdom_plot_metric(cur_metric, split, f'{split} {cur_metric}',
                               'Epoch', cur_metric, self.progress_data,
                               self.vis)
            if save or split == 'test':
                save_dir = save_dir.replace(" ", "").replace(":", "")
                os.makedirs(save_dir, exist_ok=True)
                visdom_save_metric(cur_metric, split, f'{split}{cur_metric}',
                                   'Epoch', cur_metric, self.progress_data,
                                   save_dir)

        visdom_plot_many_metrics('classf1', split,
                                 f'{split}_per_class_f1-score', 'Epoch',
                                 'per class f1-score', class_names,
                                 self.progress_data, self.vis)

        fig = util.plot_confusion_matrix(
            self.epoch_data[f'{split}_cm'],
            class_names,
            normalize=True,
            title='{} confusion matrix, epoch {}'.format(split, epoch_num),
            cmap=plt.cm.Blues)

        self.vis.matplot(fig, win=f'{split} CM')
        if save or split == 'test':
            visdom_save_many_metrics('classf1', split, f'{split}_per_class_f1',
                                     'Epoch', 'per class f1-score',
                                     class_names, self.progress_data, save_dir)
            fig.savefig(os.path.join(save_dir, f'{split}_cm.png'))
            classification_report(self.epoch_data, split, epoch_num, country,
                                  save_dir)
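
The per-class f1 bookkeeping in record_epoch grows a 2-D history with one row per epoch and one column per class. A toy numpy demonstration of the same stacking (all values are made up):

import numpy as np

epoch1 = np.array([0.80, 0.65, 0.72])   # per-class f1 after epoch 1
epoch2 = np.array([0.83, 0.70, 0.75])   # per-class f1 after epoch 2

history = np.vstack(epoch1).T           # first epoch: shape (1, 3)
history = np.vstack((history, epoch2))  # later epochs: shape (2, 3)
print(history.shape)                    # (2, 3)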