Example #1
def get_model(model_name):
    model = None
    if model_name == 'vgg16':
        from models.vgg16 import Vgg16GAP
        model = Vgg16GAP(name="vgg16")
        return model

    if model_name == 'unet':
        from models.unet import UNet
        model = UNet()
        return model

    if model_name == 'deeplab':
        from models.deeplab import DeepLab
        model = DeepLab(name="deeplab")
        return model

    if model_name == 'affinitynet':
        from models.aff_net import AffNet
        model = AffNet(name="affinitynet")
        return model

    if model_name == 'wasscam':
        from models.wass import WASS
        model = WASS()
        return model

    raise ValueError("Model '{}' has no implementation".format(model_name))
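
A minimal usage sketch for the factory above (hypothetical, assuming the models package is importable); with the fix above, unregistered names raise ValueError:

model = get_model('unet')
print(type(model).__name__)   # -> UNet

try:
    get_model('resnet')       # not a registered name
except ValueError as err:
    print(err)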
Example #2
def testNetwork(images_folder, labels_folder, dictionary, target_classes, dataset_train,
                network_filename, output_folder):
    """
    Load a network and test it on the test dataset.
    :param network_filename: Full name of the network to load (PATH+name)
    """

    # TEST DATASET
    datasetTest = CoralsDataset(images_folder, labels_folder, dictionary, target_classes)
    datasetTest.disableAugumentation()

    datasetTest.num_classes = dataset_train.num_classes
    datasetTest.weights = dataset_train.weights
    datasetTest.dataset_average = dataset_train.dataset_average
    datasetTest.dict_target = dataset_train.dict_target

    output_classes = dataset_train.num_classes

    batchSize = 4
    dataloaderTest = DataLoader(datasetTest, batch_size=batchSize, shuffle=False, num_workers=0, drop_last=True,
                            pin_memory=True)

    # DEEPLAB V3+
    net = DeepLab(backbone='resnet', output_stride=16, num_classes=output_classes)
    net.load_state_dict(torch.load(network_filename))
    print("Weights loaded.")

    metrics_test, loss = evaluateNetwork(datasetTest, dataloaderTest, "NONE", None, [0.0], 0.0, 0.0, 0.0, 0, 0, 0,
                                         output_classes, net, True, output_folder)
    metrics_filename = network_filename[:len(network_filename) - 4] + "-test-metrics.txt"
    saveMetrics(metrics_test, metrics_filename)
    print("***** TEST FINISHED *****")

    return metrics_test
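
The slice network_filename[:len(network_filename) - 4] assumes a four-character extension such as ".pth". A sketch of a more robust variant using os.path.splitext (reusing network_filename from above):

import os

# Safe for extensions of any length, not just four characters.
metrics_filename = os.path.splitext(network_filename)[0] + "-test-metrics.txt"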
Example #3
    def create_PredNet(self):
        ss = DeepLab(
            num_classes=19,
            backbone=self.cfg_model['basenet']['version'],
            output_stride=16,
            bn=self.cfg_model['bn'],
            freeze_bn=True,
        ).cuda()
        ss.eval()
        return ss
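
Since create_PredNet returns a frozen, eval-mode network on the GPU, inference would normally be wrapped in torch.no_grad(). A hypothetical sketch (the trainer instance and the input shape are assumptions):

import torch

net = trainer.create_PredNet()   # assumed instance of the class defining the method
with torch.no_grad():
    logits = net(torch.rand(1, 3, 512, 1024).cuda())  # (1, 19, H', W') class scores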
Example #4
def main():
    parser = argparse.ArgumentParser(description='training mnist')
    parser.add_argument('--gpu',
                        '-g',
                        default=-1,
                        type=int,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--batchsize',
                        '-b',
                        type=int,
                        default=8,
                        help='Number of images in each mini-batch')
    parser.add_argument('--load_model',
                        '-lm',
                        type=str,
                        default=None,
                        help='Path of the model object to load')

    args = parser.parse_args()

    backbone = 'mobilenet'
    model = ModifiedClassifier(
        DeepLab(n_class=13, task='semantic', backbone=backbone))

    if args.load_model is not None:
        serializers.load_npz(args.load_model, model)
    else:
        print('You need to specify the path of the model object')
        sys.exit()

    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        model.to_gpu()

    dir_path = './dataset/2D-3D-S'
    test_data = Stanford2D3DS(dir_path,
                              'semantic',
                              area='5a',
                              train=False,
                              n_data=100)
    test_iter = iterators.MultiprocessIterator(test_data,
                                               args.batchsize,
                                               repeat=False,
                                               shuffle=False)

    label_list = list(test_data.label_dict.keys())[1:]
    evaluator = ModifiedEvaluator(test_iter,
                                  model,
                                  label_names=label_list,
                                  device=args.gpu)
    observation = evaluator()

    for k, v in observation.items():
        print(k, v)
Example #5
    def _load_classifier(self, modelName):

        models_dir = "models/"

        network_name = os.path.join(models_dir, modelName)

        classifier_pocillopora = DeepLab(backbone='resnet', output_stride=16, num_classes=self.nclasses)
        classifier_pocillopora.load_state_dict(torch.load(network_name))

        classifier_pocillopora.eval()

        return classifier_pocillopora
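
torch.load(network_name) restores tensors to the device they were saved from, so a checkpoint saved on GPU fails to deserialize on a CPU-only machine. A sketch of the same load with map_location (assuming CPU inference is acceptable):

import torch

state = torch.load(network_name, map_location='cpu')  # force CPU deserialization
classifier_pocillopora.load_state_dict(state)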
Example #6
def init_deeplab():
    """Creates Deeplab model """
    model_dir = tempfile.mkdtemp()
    tf.io.gfile.makedirs(model_dir)

    download_path = os.path.join(model_dir, TARBALL_NAME)
    print('Downloading model...')

    urllib.request.urlretrieve(DOWNLOAD_URL_PREFIX + MODEL_URLS, download_path)
    print('Loading DeepLab model...')

    deeplab = DeepLab(download_path)
    print('Done')

    return deeplab
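
TARBALL_NAME, DOWNLOAD_URL_PREFIX and MODEL_URLS are module-level constants not shown in this snippet. Hypothetical values in the style of the official TensorFlow DeepLab demo (the exact tarball name is an assumption; substitute the checkpoint you need):

# Hypothetical constants assumed by init_deeplab()
DOWNLOAD_URL_PREFIX = 'http://download.tensorflow.org/models/'
MODEL_URLS = 'deeplabv3_mnv2_pascal_train_aug_2018_01_29.tar.gz'
TARBALL_NAME = 'deeplab_model.tar.gz'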
Example #7
def get_model():
    if settings.MODEL_NAME == 'deeplab':
        model = DeepLab(output_stride=settings.STRIDE,
                        num_classes=settings.NCLASS)
    elif settings.MODEL_NAME == 'pspnet':
        model = PSPNet(output_stride=settings.STRIDE,
                       num_classes=settings.NCLASS)
    elif settings.MODEL_NAME == 'ann':
        model = ANNNet(output_stride=settings.STRIDE,
                       num_classes=settings.NCLASS)
    elif settings.MODEL_NAME == 'ocnet':
        model = OCNet(output_stride=settings.STRIDE,
                      num_classes=settings.NCLASS)
    elif settings.MODEL_NAME == 'danet':
        model = DANet(output_stride=settings.STRIDE,
                      num_classes=settings.NCLASS)
    elif settings.MODEL_NAME == 'ocrnet':
        model = OCRNet(output_stride=settings.STRIDE,
                       num_classes=settings.NCLASS)
    else:
        raise ValueError('Unknown model name: {}'.format(settings.MODEL_NAME))
    return model
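
A hypothetical settings module consumed by get_model above (the names match the code; the values are assumptions for illustration):

# settings.py (hypothetical values)
MODEL_NAME = 'deeplab'   # one of: deeplab, pspnet, ann, ocnet, danet, ocrnet
STRIDE = 16              # output stride passed to every architecture
NCLASS = 21              # number of segmentation classes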
Example #8
def trainingNetwork(images_folder_train, labels_folder_train,
                    images_folder_val, labels_folder_val, dictionary,
                    target_classes, num_classes, save_network_as,
                    classifier_name, epochs, batch_sz, batch_mult,
                    learning_rate, L2_penalty, validation_frequency,
                    flagShuffle, experiment_name, progress):

    ##### DATA #####

    # setup the training dataset
    datasetTrain = CoralsDataset(images_folder_train, labels_folder_train,
                                 dictionary, target_classes, num_classes)

    print("Dataset setup..", end='')
    datasetTrain.computeAverage()
    datasetTrain.computeWeights()
    target_classes = datasetTrain.dict_target
    print("done.")

    datasetTrain.enableAugumentation()

    datasetVal = CoralsDataset(images_folder_val, labels_folder_val,
                               dictionary, target_classes, num_classes)
    datasetVal.dataset_average = datasetTrain.dataset_average
    datasetVal.weights = datasetTrain.weights

    # AUGMENTATION IS NOT APPLIED TO THE VALIDATION SET
    datasetVal.disableAugumentation()

    # setup the data loader
    dataloaderTrain = DataLoader(datasetTrain,
                                 batch_size=batch_sz,
                                 shuffle=flagShuffle,
                                 num_workers=0,
                                 drop_last=True,
                                 pin_memory=True)

    validation_batch_size = 4
    dataloaderVal = DataLoader(datasetVal,
                               batch_size=validation_batch_size,
                               shuffle=False,
                               num_workers=0,
                               drop_last=True,
                               pin_memory=True)

    training_images_number = len(datasetTrain.images_names)
    validation_images_number = len(datasetVal.images_names)

    ###### SETUP THE NETWORK #####
    net = DeepLab(backbone='resnet',
                  output_stride=16,
                  num_classes=datasetTrain.num_classes)
    models_dir = "models/"
    network_name = os.path.join(models_dir, "deeplab-resnet.pth.tar")
    state = torch.load(network_name)
    # RE-INITIALIZE THE CLASSIFICATION LAYER WITH THE RIGHT NUMBER OF CLASSES, DON'T LOAD ITS PRE-TRAINED WEIGHTS
    new_dictionary = state['state_dict']
    del new_dictionary['decoder.last_conv.8.weight']
    del new_dictionary['decoder.last_conv.8.bias']
    net.load_state_dict(state['state_dict'], strict=False)
    print("NETWORK USED: DEEPLAB V3+")

    USE_CUDA = torch.cuda.is_available()
    device = torch.device("cuda" if USE_CUDA else "cpu")
    net.to(device)

    # LOSS
    weights = datasetTrain.weights
    class_weights = torch.FloatTensor(weights).to(device)
    lossfn = nn.CrossEntropyLoss(weight=class_weights, ignore_index=-1)

    # OPTIMIZER
    # optimizer = optim.SGD(net.parameters(), lr=learning_rate, weight_decay=0.0002, momentum=0.9)
    optimizer = optim.Adam(net.parameters(),
                           lr=learning_rate,
                           weight_decay=L2_penalty)

    ##### TRAINING LOOP #####

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     patience=2,
                                                     verbose=True)

    best_accuracy = 0.0
    best_jaccard_score = 0.0

    print("Training Network")
    for epoch in range(epochs):  # loop over the dataset multiple times

        txt = "Epoch " + str(epoch + 1) + "/" + str(epochs)
        progress.setMessage(txt)
        progress.setProgress((100.0 * epoch) / epochs)
        QApplication.processEvents()

        net.train()
        optimizer.zero_grad()
        running_loss = 0.0
        for i, minibatch in enumerate(dataloaderTrain):
            # get the inputs
            images_batch = minibatch['image']
            labels_batch = minibatch['labels']

            if USE_CUDA:
                images_batch = images_batch.to(device)
                labels_batch = labels_batch.to(device)

            # forward+loss+backward
            outputs = net(images_batch)
            loss = lossfn(outputs, labels_batch)
            loss.backward()

            # TO AVOID MEMORY TROUBLE, UPDATE WEIGHTS EVERY BATCH_SZ x BATCH_MULT
            if (i + 1) % batch_mult == 0:
                optimizer.step()
                optimizer.zero_grad()

            print(epoch, i, loss.item())
            running_loss += loss.item()

        print("Epoch: %d , Running loss = %f" % (epoch, running_loss))

        ### VALIDATION ###
        if epoch > 0 and (epoch + 1) % validation_frequency == 0:

            print("RUNNING VALIDATION.. ", end='')

            # datasetVal.weights are the same of datasetTrain
            metrics_val, mean_loss_val = evaluateNetwork(
                dataloaderVal,
                datasetVal.weights,
                datasetVal.num_classes,
                net,
                flagTrainingDataset=False)
            accuracy = metrics_val['Accuracy']
            jaccard_score = metrics_val['JaccardScore']

            scheduler.step(mean_loss_val)

            metrics_train, mean_loss_train = evaluateNetwork(
                dataloaderTrain,
                datasetTrain.weights,
                datasetTrain.num_classes,
                net,
                flagTrainingDataset=True)
            accuracy_training = metrics_train['Accuracy']
            jaccard_training = metrics_train['JaccardScore']

            if jaccard_score > best_jaccard_score:

                best_accuracy = accuracy
                best_jaccard_score = jaccard_score
                torch.save(net.state_dict(), save_network_as)
                # performance of the best accuracy network on the validation dataset
                metrics_filename = save_network_as[:len(save_network_as) -
                                                   4] + "-val-metrics.txt"
                saveMetrics(metrics_val, metrics_filename)
                metrics_filename = save_network_as[:len(save_network_as) -
                                                   4] + "-train-metrics.txt"
                saveMetrics(metrics_train, metrics_filename)

            print("-> CURRENT BEST ACCURACY ", best_accuracy)

    print("***** TRAINING FINISHED *****")

    return datasetTrain
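
The batch_mult logic above is plain gradient accumulation: gradients from batch_mult mini-batches are summed before a single optimizer step, simulating an effective batch size of batch_sz * batch_mult. Isolated as a minimal sketch (the loader yielding (image, label) pairs is a placeholder):

import torch

def accumulate_train(net, loader, lossfn, optimizer, batch_mult, device):
    """One epoch with a weight update every batch_mult mini-batches."""
    net.train()
    optimizer.zero_grad()
    for i, (images, labels) in enumerate(loader):
        outputs = net(images.to(device))
        # dividing by batch_mult keeps gradient magnitudes comparable to one
        # large batch; the loop above accumulates unscaled losses instead
        loss = lossfn(outputs, labels.to(device)) / batch_mult
        loss.backward()
        if (i + 1) % batch_mult == 0:
            optimizer.step()
            optimizer.zero_grad()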
Example #9
def train_net(args):
    torch.manual_seed(7)
    np.random.seed(7)
    checkpoint = args.checkpoint
    start_epoch = 0
    best_loss = float('inf')
    writer = SummaryWriter()
    epochs_since_improvement = 0

    # Initialize / load checkpoint
    if checkpoint is None:
        model = DeepLab(backbone='mobilenet', output_stride=16, num_classes=1)
        model = nn.DataParallel(model)

        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        model = checkpoint['model'].module
        model = nn.DataParallel(model)
        optimizer = checkpoint['optimizer']

    logger = get_logger()

    # Move to GPU, if available
    model = model.to(device)

    # Custom dataloaders
    train_dataset = DIMDataset('train')
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                               num_workers=num_workers)
    # valid_dataset = DIMDataset('valid')
    # valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False,
    #                                            num_workers=num_workers)

    # scheduler = MultiStepLR(optimizer, milestones=[10, 20], gamma=0.1)

    # Epochs
    for epoch in range(start_epoch, args.end_epoch):
        # scheduler.step(epoch)

        # One epoch's training
        train_loss = train(train_loader=train_loader,
                           model=model,
                           optimizer=optimizer,
                           epoch=epoch,
                           logger=logger)
        effective_lr = get_learning_rate(optimizer)
        print('Current effective learning rate: {}\n'.format(effective_lr))

        writer.add_scalar('model/train_loss', train_loss, epoch)

        # One epoch's validation
        # valid_loss = valid(valid_loader=valid_loader,
        #                    model=model,
        #                    logger=logger)
        #
        # writer.add_scalar('Valid_Loss', valid_loss, epoch)

        # One epoch's test
        sad_loss, mse_loss = test(model)
        writer.add_scalar('model/sad_loss', sad_loss, epoch)
        writer.add_scalar('model/mse_loss', mse_loss, epoch)

        # Print status
        status = 'Test: SAD {:.4f} MSE {:.4f}\n'.format(sad_loss, mse_loss)
        logger.info(status)

        # Check if there was an improvement
        is_best = mse_loss < best_loss
        best_loss = min(mse_loss, best_loss)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(epoch, epochs_since_improvement, model, optimizer, best_loss, is_best)
Example #10
def train_net(args):
    torch.manual_seed(7)
    np.random.seed(7)
    checkpoint = args.checkpoint
    start_epoch = 0
    best_loss = float('inf')
    writer = SummaryWriter()
    epochs_since_improvement = 0

    # Initialize / load checkpoint
    if checkpoint is None:
        model = DeepLab(backbone='mobilenet',
                        output_stride=16,
                        num_classes=num_classes)
        model = nn.DataParallel(model)

        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     betas=(0.9, 0.99),
                                     weight_decay=args.weight_decay)

    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        model = checkpoint['model']
        optimizer = checkpoint['optimizer']

    logger = get_logger()

    # Move to GPU, if available
    model = model.to(device)

    # Custom dataloaders
    train_dataset = MICDataset('train')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=num_workers)
    valid_dataset = MICDataset('val')
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=num_workers)

    # Epochs
    for epoch in range(start_epoch, args.end_epoch):
        # One epoch's training
        train_loss, train_acc = train(train_loader=train_loader,
                                      model=model,
                                      optimizer=optimizer,
                                      epoch=epoch,
                                      logger=logger)
        lr = get_learning_rate(optimizer)
        print('Current effective learning rate: {}\n'.format(lr))

        writer.add_scalar('model/train_loss', train_loss, epoch)
        writer.add_scalar('model/train_acc', train_acc, epoch)

        # One epoch's validation
        valid_loss, valid_acc = valid(valid_loader=valid_loader,
                                      model=model,
                                      logger=logger)

        writer.add_scalar('model/valid_loss', valid_loss, epoch)
        writer.add_scalar('model/valid_acc', valid_acc, epoch)

        # Check if there was an improvement
        is_best = valid_loss < best_loss
        best_loss = min(valid_loss, best_loss)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(epoch, epochs_since_improvement, model, optimizer,
                        best_loss, is_best)
Example #11
import torch
from torchscope import scope

from config import num_classes
from models.deeplab import DeepLab

if __name__ == "__main__":
    model = DeepLab(backbone='mobilenet',
                    output_stride=16,
                    num_classes=num_classes)
    model.eval()
    input = torch.rand(1, 1, 256, 256)
    output = model(input)
    print(output.size())
    scope(model, (1, 256, 256))

    # model = models.segmentation.deeplabv3_resnet101(pretrained=True, num_classes=num_classes)
    # model.eval()
    # input = torch.rand(1, 3, 256, 256)
    # output = model(input)['out']
    # print(output.size())
    # scope(model, (3, 256, 256))
Example #12
    windowing_params['image_window_shape'] = [1, 572, 572]
    windowing_params['label_window_shape'] = [1, 388, 388]
    model = UNet(n_classes=n_classes, padding=False,
                 up_mode='upconv').to(device)
    channels = 1
    backbone = ""

elif experiment == "DeepLab":
    # Fix the windowing parameters
    windowing_params = dict()
    windowing_params['image_window_shape'] = [1, 572, 572]
    windowing_params['label_window_shape'] = [1, 572, 572]
    windowing_params['window_spacing'] = [1, 572, 572]
    windowing_params['random_windowing'] = False
    model = DeepLab(num_classes=n_classes,
                    backbone=backbone,
                    freeze_bn=False,
                    sync_bn=False)
    channels = 3

# Load the data
eval_image = b3d.load(eval_dataset_dir, eval_data_file)
eval_image = (eval_image - eval_image.min()) * 255 / (eval_image.max() -
                                                      eval_image.min())
eval_image -= np.mean(eval_image)
eval_image /= np.std(eval_image)

# Specify data type when loading label data
eval_label = b3d.load(eval_dataset_dir, eval_label_file, data_type=np.int32)

print(f"Shape of Image: {eval_image.shape}")
print(f"Shape of Labels: {eval_label.shape}")
Example #13
    if (args.model_type_G == "unet4"):
        model_G = UNet4(n_in_channels=args.n_in_channels,
                        n_out_channels=args.n_classes,
                        n_fmaps=64).to(device)
    elif (args.model_type_G == "unet4_resnet"):
        model_G = UNet4ResNet34(n_in_channels=args.n_in_channels,
                                n_out_channels=args.n_classes,
                                n_fmaps=64,
                                pretrained=True).to(device)
    elif (args.model_type_G == "unet_fgvc6"):
        model_G = UNetFGVC6(n_channels=args.n_in_channels,
                            n_classes=args.n_classes).to(device)
    elif (args.model_type_G == "deeplab_v3"):
        model_G = DeepLab(backbone='resnet',
                          n_in_channels=args.n_in_channels,
                          output_stride=16,
                          num_classes=args.n_classes,
                          pretrained_backbone=True).to(device)
    else:
        raise NotImplementedError()

    if (args.debug):
        print("model_G :\n", model_G)

    # Load the model
    if not args.load_checkpoints_path_G == '' and os.path.exists(
            args.load_checkpoints_path_G):
        load_checkpoint(model_G, device, args.load_checkpoints_path_G)

    #================================
    # optimizer_G settings
Example #14
def main():
    parser = argparse.ArgumentParser(description='training mnist')
    parser.add_argument('--gpu', '-g', default=-1, type=int,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--epoch', '-e', type=int, default=100,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--batchsize', '-b', type=int, default=8,
                        help='Number of images in each mini-batch')
    parser.add_argument('--seed', '-s', type=int, default=0,
                        help='Random seed')
    parser.add_argument('--report_trigger', '-rt', type=str, default='1e',
                        help='Interval for reporting (e.g. 100i; default: 1e)')
    parser.add_argument('--save_trigger', '-st', type=str, default='1e',
                        help='Interval for saving the model (e.g. 100i; default: 1e)')
    parser.add_argument('--load_model', '-lm', type=str, default=None,
                        help='Path of the model object to load')
    parser.add_argument('--load_optimizer', '-lo', type=str, default=None,
                        help='Path of the optimizer object to load')
    args = parser.parse_args()

    start_time = datetime.now()
    save_dir = Path('output/{}'.format(start_time.strftime('%Y%m%d_%H%M')))

    random.seed(args.seed)
    np.random.seed(args.seed)
    cupy.random.seed(args.seed)

    backbone = 'mobilenet'
    model = ModifiedClassifier(DeepLab(n_class=13, task='semantic', backbone=backbone), lossfun=F.softmax_cross_entropy)
    if args.load_model is not None:
        serializers.load_npz(args.load_model, model)

    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        model.to_gpu()

    optimizer = optimizers.Adam(alpha=1e-3)
    optimizer.setup(model)
    if args.load_optimizer is not None:
        serializers.load_npz(args.load_optimizer, optimizer)

    dir_path = './dataset/2D-3D-S/'
    augmentations = {'mirror': 0.5, 'flip': 0.5}
    train_data = Stanford2D3DS(dir_path, 'semantic', area='1 2 3 4', train=True)
    train_data.set_augmentations(crop=513, augmentations=augmentations)
    valid_data = Stanford2D3DS(dir_path, 'semantic', area='6', train=False, n_data=100)
    valid_data.set_augmentations(crop=513)

    train_iter = iterators.MultiprocessIterator(train_data, args.batchsize, n_processes=1)
    valid_iter = iterators.MultiprocessIterator(valid_data, args.batchsize, repeat=False, shuffle=False, n_processes=1)

    updater = StandardUpdater(train_iter, optimizer, device=args.gpu)
    trainer = Trainer(updater, (args.epoch, 'epoch'), out=save_dir)

    label_list = list(valid_data.label_dict.keys())[1:]
    report_trigger = (int(args.report_trigger[:-1]), 'iteration' if args.report_trigger[-1] == 'i' else 'epoch')
    trainer.extend(extensions.LogReport(trigger=report_trigger))
    trainer.extend(ModifiedEvaluator(valid_iter, model, label_names=label_list,
                                     device=args.gpu), name='val', trigger=report_trigger)

    trainer.extend(extensions.PrintReport(['epoch', 'iteration', 'main/loss', 'main/acc', 'val/main/loss',
                                           'val/main/acc', 'val/main/mean_class_acc', 'val/main/miou',
                                           'elapsed_time']), trigger=report_trigger)

    trainer.extend(extensions.PlotReport(['main/loss', 'val/main/loss'], x_key=report_trigger[1],
                                         marker='.', file_name='loss.png', trigger=report_trigger))
    trainer.extend(extensions.PlotReport(['main/acc', 'val/main/acc'], x_key=report_trigger[1],
                                         marker='.', file_name='accuracy.png', trigger=report_trigger))
    class_accuracy_report = ['val/main/mean_class_acc']
    class_accuracy_report.extend(['val/main/class_acc/{}'.format(label) for label in label_list])
    class_iou_report = ['val/main/miou']
    class_iou_report.extend(['val/main/iou/{}'.format(label) for label in label_list])
    trainer.extend(extensions.PlotReport(class_accuracy_report, x_key=report_trigger[1],
                                         marker='.', file_name='class_accuracy.png', trigger=report_trigger))
    trainer.extend(extensions.PlotReport(class_iou_report, x_key=report_trigger[1],
                                         marker='.', file_name='class_iou.png', trigger=report_trigger))

    save_trigger = (int(args.save_trigger[:-1]), 'iteration' if args.save_trigger[-1] == 'i' else 'epoch')
    trainer.extend(extensions.snapshot_object(model, filename='model_{0}-{{.updater.{0}}}.npz'
                                              .format(save_trigger[1])), trigger=save_trigger)
    trainer.extend(extensions.snapshot_object(optimizer, filename='optimizer_{0}-{{.updater.{0}}}.npz'
                                              .format(save_trigger[1])), trigger=save_trigger)

    if save_dir.exists():
        shutil.rmtree(save_dir)
    save_dir.mkdir()
    (save_dir / 'training_details').mkdir()

    # Write parameters text
    with open(save_dir / 'training_details/train_params.txt', 'w') as f:
        f.write('model: {}(backbone: {})\n'.format(model.predictor.__class__.__name__, backbone))
        f.write('n_epoch: {}\n'.format(args.epoch))
        f.write('batch_size: {}\n'.format(args.batchsize))
        f.write('n_data_train: {}\n'.format(len(train_data)))
        f.write('n_data_val: {}\n'.format(len(valid_data)))
        f.write('seed: {}\n'.format(args.seed))
        if len(augmentations) > 0:
            f.write('[augmentation]\n')
            for process in augmentations:
                f.write('  {}: {}\n'.format(process, augmentations[process]))

    trainer.run()
Example #15
    def __init__(self, cfg, writer, logger):
        # super(CustomModel, self).__init__()
        self.cfg = cfg
        self.writer = writer
        self.class_numbers = 19
        self.logger = logger
        cfg_model = cfg['model']
        self.cfg_model = cfg_model
        self.best_iou = -100
        self.iter = 0
        self.nets = []
        self.split_gpu = 0
        self.default_gpu = cfg['model']['default_gpu']
        self.PredNet_Dir = None
        self.valid_classes = cfg['training']['valid_classes']
        self.G_train = True
        self.objective_vectors = np.zeros([19, 256])
        self.objective_vectors_num = np.zeros([19])
        self.objective_vectors_dis = np.zeros([19, 19])
        self.class_threshold = np.full([19], 0.95)
        self.metrics = CustomMetrics(self.class_numbers)
        self.cls_feature_weight = cfg['training']['cls_feature_weight']

        bn = cfg_model['bn']
        if bn == 'sync_bn':
            BatchNorm = SynchronizedBatchNorm2d
        # elif bn == 'sync_abn':
        #     BatchNorm = InPlaceABNSync
        elif bn == 'bn':
            BatchNorm = nn.BatchNorm2d
        # elif bn == 'abn':
        #     BatchNorm = InPlaceABN
        elif bn == 'gn':
            BatchNorm = nn.GroupNorm
        else:
            raise NotImplementedError(
                'batch norm choice {} is not implemented'.format(bn))
        self.PredNet = DeepLab(
            num_classes=19,
            backbone=cfg_model['basenet']['version'],
            output_stride=16,
            bn=cfg_model['bn'],
            freeze_bn=True,
        ).cuda()
        self.load_PredNet(cfg, writer, logger, dir=None, net=self.PredNet)
        self.PredNet_DP = self.init_device(self.PredNet,
                                           gpu_id=self.default_gpu,
                                           whether_DP=True)
        self.PredNet.eval()
        self.PredNet_num = 0

        self.BaseNet = DeepLab(
            num_classes=19,
            backbone=cfg_model['basenet']['version'],
            output_stride=16,
            bn=cfg_model['bn'],
            freeze_bn=False,
        )

        logger.info('the backbone is {}'.format(
            cfg_model['basenet']['version']))

        self.BaseNet_DP = self.init_device(self.BaseNet,
                                           gpu_id=self.default_gpu,
                                           whether_DP=True)
        self.nets.extend([self.BaseNet])
        self.nets_DP = [self.BaseNet_DP]

        self.optimizers = []
        self.schedulers = []
        # optimizer_cls = get_optimizer(cfg)
        optimizer_cls = torch.optim.SGD
        optimizer_params = {
            k: v
            for k, v in cfg['training']['optimizer'].items() if k != 'name'
        }
        # optimizer_cls_D = torch.optim.SGD
        # optimizer_params_D = {k:v for k, v in cfg['training']['optimizer_D'].items()
        #                     if k != 'name'}
        self.BaseOpti = optimizer_cls(self.BaseNet.parameters(),
                                      **optimizer_params)
        self.optimizers.extend([self.BaseOpti])

        self.BaseSchedule = get_scheduler(self.BaseOpti,
                                          cfg['training']['lr_schedule'])
        self.schedulers.extend([self.BaseSchedule])
        self.setup(cfg, writer, logger)

        self.adv_source_label = 0
        self.adv_target_label = 1
        self.bceloss = nn.BCEWithLogitsLoss(reduction='mean')  # size_average is deprecated
        self.loss_fn = get_loss_function(cfg)
        self.mseloss = nn.MSELoss()
        self.l1loss = nn.L1Loss()
        self.smoothloss = nn.SmoothL1Loss()
        self.triplet_loss = nn.TripletMarginLoss()
Example #16
def main():

    # define and parse arguments
    parser = argparse.ArgumentParser()

    # general
    parser.add_argument('--experiment_name',
                        type=str,
                        default="experiment",
                        help="experiment name. will be used in the path names \
                             for log- and savefiles")
    parser.add_argument('--seed',
                        type=int,
                        default=None,
                        help='fixes random seed and sets model to \
                              the potentially faster cuDNN deterministic mode \
                              (default: non-deterministic mode)')
    parser.add_argument('--val_freq',
                        type=int,
                        default=1000,
                        help='validation will be run every val_freq \
                        batches/optimization steps during training')
    parser.add_argument('--save_freq',
                        type=int,
                        default=1000,
                        help='training state will be saved every save_freq \
                        batches/optimization steps during training')
    parser.add_argument('--log_freq',
                        type=int,
                        default=100,
                        help='tensorboard logs will be written every log_freq \
                              number of batches/optimization steps')

    # input/output
    parser.add_argument('--use_s2hr',
                        action='store_true',
                        default=False,
                        help='use sentinel-2 high-resolution (10 m) bands')
    parser.add_argument('--use_s2mr',
                        action='store_true',
                        default=False,
                        help='use sentinel-2 medium-resolution (20 m) bands')
    parser.add_argument('--use_s2lr',
                        action='store_true',
                        default=False,
                        help='use sentinel-2 low-resolution (60 m) bands')
    parser.add_argument('--use_s1',
                        action='store_true',
                        default=False,
                        help='use sentinel-1 data')
    parser.add_argument('--no_savanna',
                        action='store_true',
                        default=False,
                        help='ignore class savanna')

    # training hyperparameters
    parser.add_argument('--lr',
                        type=float,
                        default=0.01,
                        help='learning rate (default: 1e-2)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.9,
                        help='momentum (default: 0.9), only used for deeplab')
    parser.add_argument('--weight_decay',
                        type=float,
                        default=5e-4,
                        help='weight-decay (default: 5e-4)')
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='batch size for training and validation \
                              (default: 32)')
    parser.add_argument('--workers',
                        type=int,
                        default=4,
                        help='number of workers for dataloading (default: 4)')
    parser.add_argument('--max_epochs',
                        type=int,
                        default=100,
                        help='number of training epochs (default: 100)')

    # network
    parser.add_argument('--model',
                        type=str,
                        choices=['deeplab', 'unet'],
                        default='deeplab',
                        help="network architecture (default: deeplab)")

    # deeplab-specific
    parser.add_argument('--pretrained_backbone',
                        action='store_true',
                        default=False,
                        help='initialize ResNet-101 backbone with ImageNet \
                              pre-trained weights')
    parser.add_argument('--out_stride',
                        type=int,
                        choices=[8, 16],
                        default=16,
                        help='network output stride (default: 16)')

    # data
    parser.add_argument('--data_dir_train',
                        type=str,
                        default=None,
                        help='path to training dataset')
    parser.add_argument(
        '--dataset_val',
        type=str,
        default="sen12ms_holdout",
        choices=['sen12ms_holdout', 'dfc2020_val', 'dfc2020_test'],
        help='dataset to use for validation (default: \
                             sen12ms_holdout)')
    parser.add_argument('--data_dir_val',
                        type=str,
                        default=None,
                        help='path to validation dataset')
    parser.add_argument('--log_dir',
                        type=str,
                        default=None,
                        help='path to dir for tensorboard logs \
                              (default runs/CURRENT_DATETIME_HOSTNAME)')

    args = parser.parse_args()
    print("=" * 20, "CONFIG", "=" * 20)
    for arg in vars(args):
        print('{0:20}  {1}'.format(arg, getattr(args, arg)))
    print()

    # fix seeds and set pytorch to deterministic mode
    if args.seed is not None:
        torch.manual_seed(args.seed)
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # set flags for GPU processing if available
    if torch.cuda.is_available():
        args.use_gpu = True
        if torch.cuda.device_count() > 1:
            raise NotImplementedError("multi-gpu training not implemented! " +
                                      "try to run script as: " +
                                      "CUDA_VISIBLE_DEVICES=0 train.py")
    else:
        args.use_gpu = False

    # load datasets
    train_set = SEN12MS(args.data_dir_train,
                        subset="train",
                        no_savanna=args.no_savanna,
                        use_s2hr=args.use_s2hr,
                        use_s2mr=args.use_s2mr,
                        use_s2lr=args.use_s2lr,
                        use_s1=args.use_s1)
    n_classes = train_set.n_classes
    n_inputs = train_set.n_inputs
    if args.dataset_val == "sen12ms_holdout":
        val_set = SEN12MS(args.data_dir_train,
                          subset="holdout",
                          no_savanna=args.no_savanna,
                          use_s2hr=args.use_s2hr,
                          use_s2mr=args.use_s2mr,
                          use_s2lr=args.use_s2lr,
                          use_s1=args.use_s1)
    else:
        dfc2020_subset = args.dataset_val.split("_")[-1]
        val_set = DFC2020(args.data_dir_val,
                          subset=dfc2020_subset,
                          no_savanna=args.no_savanna,
                          use_s2hr=args.use_s2hr,
                          use_s2mr=args.use_s2mr,
                          use_s2lr=args.use_s2lr,
                          use_s1=args.use_s1)

    # set up dataloaders
    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=True,
                              drop_last=False)
    val_loader = DataLoader(val_set,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=True,
                            drop_last=False)

    # set up network
    if args.model == "deeplab":
        model = DeepLab(num_classes=n_classes,
                        backbone='resnet',
                        pretrained_backbone=args.pretrained_backbone,
                        output_stride=args.out_stride,
                        sync_bn=False,
                        freeze_bn=False,
                        n_in=n_inputs)
    else:
        model = UNet(n_classes=n_classes, n_channels=n_inputs)

    if args.use_gpu:
        model = model.cuda()

    # define loss function
    loss_fn = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')

    # set up optimizer
    if args.model == "deeplab":
        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': args.lr * 10
        }]
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.RMSprop(model.parameters(),
                                        lr=args.lr,
                                        weight_decay=args.weight_decay)

    # set up tensorboard logging
    if args.log_dir is None:
        args.log_dir = "logs"
    writer = SummaryWriter(
        log_dir=os.path.join(args.log_dir, args.experiment_name))

    # create checkpoint dir
    args.checkpoint_dir = os.path.join(args.log_dir, args.experiment_name,
                                       "checkpoints")
    os.makedirs(args.checkpoint_dir, exist_ok=True)

    # save config
    pkl.dump(args, open(os.path.join(args.checkpoint_dir, "args.pkl"), "wb"))

    # train network
    step = 0
    trainer = ModelTrainer(args)
    for epoch in range(args.max_epochs):
        print("=" * 20, "EPOCH", epoch + 1, "/", str(args.max_epochs),
              "=" * 20)

        # run training for one epoch
        model, step = trainer.train(model,
                                    train_loader,
                                    val_loader,
                                    loss_fn,
                                    optimizer,
                                    writer,
                                    step=step)

    # export final set of weights
    trainer.export_model(model, args.checkpoint_dir, name="final")
Example #17
import matplotlib as mpl
mpl.use('TkAgg')

# basic settings
num_epochs = 30
print_iter = 25
batch_size = 2
vis_result = True
validation_ratio = 0.05
startlr = 1.25e-3
boundary_flag = True

# model load
# model = PSPNet(n_classes=30, n_blocks=[3, 4, 6, 3], pyramids=[6, 3, 2, 1]).cuda()
if boundary_flag:
    model = DeepLab(backbone='xception', output_stride=16,
                    num_classes=24 + 2).cuda()
else:
    model = DeepLab(backbone='xception', output_stride=16,
                    num_classes=24).cuda()
# load a fully serialized checkpoint, replacing the freshly built model
model = torch.load("./ckpt/20.pth")
model_name = model.__class__.__name__

# dataset load
aug = Compose(
    [
        HorizontalFlip(0.5),
        OneOf([
            MotionBlur(p=0.2),
            MedianBlur(blur_limit=3, p=0.1),
            ISONoise(p=0.3),
            Blur(blur_limit=3, p=0.1),
Example #18
    def __init__(self, config):

        self.config = config
        self.best_pred = 0.0

        # Define Saver
        self.saver = Saver(config)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.config['training']['tensorboard']['log_dir'])
        self.writer = self.summary.create_summary()
        
        self.train_loader, self.val_loader, self.test_loader, self.nclass = initialize_data_loader(config)
        
        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=self.config['network']['backbone'],
                        output_stride=self.config['image']['out_stride'],
                        sync_bn=self.config['network']['sync_bn'],
                        freeze_bn=self.config['network']['freeze_bn'])

        train_params = [{'params': model.get_1x_lr_params(), 'lr': self.config['training']['lr']},
                        {'params': model.get_10x_lr_params(), 'lr': self.config['training']['lr'] * 10}]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=self.config['training']['momentum'],
                                    weight_decay=self.config['training']['weight_decay'], nesterov=self.config['training']['nesterov'])

        # Define Criterion
        # whether to use class balanced weights
        if self.config['training']['use_balanced_weights']:
            classes_weights_path = os.path.join(self.config['dataset']['base_path'], self.config['dataset']['dataset_name'] + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(self.config, self.config['dataset']['dataset_name'], self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None

        self.criterion = SegmentationLosses(weight=weight, cuda=self.config['network']['use_cuda']).build_loss(mode=self.config['training']['loss_type'])
        self.model, self.optimizer = model, optimizer
        
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(self.config['training']['lr_scheduler'], self.config['training']['lr'],
                                            self.config['training']['epochs'], len(self.train_loader))


        # Using cuda
        if self.config['network']['use_cuda']:
            self.model = torch.nn.DataParallel(self.model)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint

        if self.config['training']['weights_initialization']['use_pretrained_weights']:
            if not os.path.isfile(self.config['training']['weights_initialization']['restore_from']):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(self.config['training']['weights_initialization']['restore_from']))

            if self.config['network']['use_cuda']:
                checkpoint = torch.load(self.config['training']['weights_initialization']['restore_from'])
            else:
                checkpoint = torch.load(self.config['training']['weights_initialization']['restore_from'], map_location={'cuda:0': 'cpu'})

            self.config['training']['start_epoch'] = checkpoint['epoch']

            self.model.load_state_dict(checkpoint['state_dict'])

#            if not self.config['ft']:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(self.config['training']['weights_initialization']['restore_from'], checkpoint['epoch']))
Example #19
def trainingNetwork(images_folder_train, labels_folder_train, images_folder_val, labels_folder_val,
                    dictionary, target_classes, output_classes, save_network_as, classifier_name,
                    epochs, batch_sz, batch_mult, learning_rate, L2_penalty, validation_frequency, loss_to_use,
                    epochs_switch, epochs_transition, tversky_alpha, tversky_gamma, optimiz,
                    flag_shuffle, flag_training_accuracy, progress):

    ##### DATA #####

    # setup the training dataset
    datasetTrain = CoralsDataset(images_folder_train, labels_folder_train, dictionary, target_classes)

    print("Dataset setup..", end='')
    datasetTrain.computeAverage()
    datasetTrain.computeWeights()
    print(datasetTrain.dict_target)
    print(datasetTrain.weights)
    freq = 1.0 / datasetTrain.weights
    print(freq)
    print("done.")

    save_classifier_as = save_network_as.replace(".net", ".json")

    datasetTrain.enableAugumentation()

    datasetVal = CoralsDataset(images_folder_val, labels_folder_val, dictionary, target_classes)
    datasetVal.dataset_average = datasetTrain.dataset_average
    datasetVal.weights = datasetTrain.weights

    # AUGMENTATION IS NOT APPLIED TO THE VALIDATION SET
    datasetVal.disableAugumentation()

    # setup the data loader
    dataloaderTrain = DataLoader(datasetTrain, batch_size=batch_sz, shuffle=flag_shuffle, num_workers=0, drop_last=True,
                                 pin_memory=True)

    validation_batch_size = 4
    dataloaderVal = DataLoader(datasetVal, batch_size=validation_batch_size, shuffle=False, num_workers=0, drop_last=True,
                                 pin_memory=True)

    training_images_number = len(datasetTrain.images_names)
    validation_images_number = len(datasetVal.images_names)

    print("NETWORK USED: DEEPLAB V3+")

    if os.path.exists(save_network_as):
        net = DeepLab(backbone='resnet', output_stride=16, num_classes=output_classes)
        net.load_state_dict(torch.load(save_network_as))
        print("Checkpoint loaded.")
    else:
        ###### SETUP THE NETWORK #####
        net = DeepLab(backbone='resnet', output_stride=16, num_classes=output_classes)
        state = torch.load("models/deeplab-resnet.pth.tar")
        # RE-INITIALIZE THE CLASSIFICATION LAYER WITH THE RIGHT NUMBER OF CLASSES, DON'T LOAD ITS PRE-TRAINED WEIGHTS
        new_dictionary = state['state_dict']
        del new_dictionary['decoder.last_conv.8.weight']
        del new_dictionary['decoder.last_conv.8.bias']
        net.load_state_dict(state['state_dict'], strict=False)

    # OPTIMIZER
    if optimiz == "SGD":
        optimizer = optim.SGD(net.parameters(), lr=learning_rate, weight_decay=L2_penalty, momentum=0.9)
    elif optimiz == "ADAM":
        optimizer = optim.Adam(net.parameters(), lr=learning_rate, weight_decay=L2_penalty)

    USE_CUDA = torch.cuda.is_available()
    device = torch.device("cuda" if USE_CUDA else "cpu")
    net.to(device)

    ##### TRAINING LOOP #####

    reduce_lr_patience = 2
    if loss_to_use == "DICE+BOUNDARY":
        reduce_lr_patience = 200
        print("patience increased !")

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=reduce_lr_patience, verbose=True)

    best_accuracy = 0.0
    best_jaccard_score = 0.0

    # Cross-entropy loss
    weights = datasetTrain.weights
    class_weights = torch.FloatTensor(weights).to(device)
    CEloss = nn.CrossEntropyLoss(weight=class_weights, ignore_index=-1)

    # weights for GENERALIZED DICE LOSS (GDL)
    freq = 1.0 / datasetTrain.weights[1:]
    w = 1.0 / (freq * freq)
    w = w / w.sum() + 0.00001
    w_for_GDL = torch.from_numpy(w)
    w_for_GDL = w_for_GDL.to(device)

    # Focal Tversky loss
    focal_tversky_gamma = torch.tensor(tversky_gamma)
    focal_tversky_gamma = focal_tversky_gamma.to(device)

    tversky_loss_alpha = torch.tensor(tversky_alpha)
    tversky_loss_beta = torch.tensor(1.0 - tversky_alpha)
    tversky_loss_alpha = tversky_loss_alpha.to(device)
    tversky_loss_beta = tversky_loss_beta.to(device)

    print("Training Network")
    num_iter = 0
    total_iter = epochs * int(len(datasetTrain) / dataloaderTrain.batch_size)
    for epoch in range(epochs):

        net.train()
        optimizer.zero_grad()

        loss_values = []
        for i, minibatch in enumerate(dataloaderTrain):

            txt = "Training - Iterations " + str(num_iter + 1) + "/" + str(total_iter)
            progress.setMessage(txt)
            progress.setProgress((100.0 * num_iter) / total_iter)
            QApplication.processEvents()
            num_iter += 1

            # get the inputs
            images_batch = minibatch['image']
            labels_batch = minibatch['labels']

            if USE_CUDA:
                images_batch = images_batch.to(device)
                labels_batch = labels_batch.to(device)

            # forward+loss+backward
            outputs = net(images_batch)

            loss = computeLoss(loss_to_use, CEloss, w_for_GDL, tversky_loss_alpha, tversky_loss_beta, focal_tversky_gamma,
                               epoch, epochs_switch, epochs_transition, labels_batch, outputs)

            loss.backward()

            # TO AVOID MEMORY TROUBLE, UPDATE WEIGHTS EVERY BATCH_SZ x BATCH_MULT
            if (i + 1) % batch_mult == 0:
                optimizer.step()
                optimizer.zero_grad()

            print(epoch, i, loss.item())
            loss_values.append(loss.item())

        mean_loss_train = sum(loss_values) / len(loss_values)
        print("Epoch: %d , Mean loss = %f" % (epoch, mean_loss_train))

        ### VALIDATION ###
        if epoch > 0 and (epoch+1) % validation_frequency == 0:

            print("RUNNING VALIDATION.. ", end='')

            metrics_val, mean_loss_val = evaluateNetwork(datasetVal, dataloaderVal, loss_to_use, CEloss, w_for_GDL,
                                                         tversky_loss_alpha, tversky_loss_beta, focal_tversky_gamma,
                                                         epoch, epochs_switch, epochs_transition,
                                                         output_classes, net, flag_compute_mIoU=False)
            accuracy = metrics_val['Accuracy']
            jaccard_score = metrics_val['JaccardScore']

            scheduler.step(mean_loss_val)

            accuracy_training = 0.0
            jaccard_training = 0.0

            if flag_training_accuracy is True:
                metrics_train, mean_loss_train = evaluateNetwork(datasetTrain, dataloaderTrain, loss_to_use, CEloss, w_for_GDL,
                                                                 tversky_loss_alpha, tversky_loss_beta, focal_tversky_gamma,
                                                                 epoch, epochs_switch, epochs_transition,
                                                                 output_classes, net, flag_compute_mIoU=False)
                accuracy_training = metrics_train['Accuracy']
                jaccard_training = metrics_train['JaccardScore']

            #if jaccard_score > best_jaccard_score:
            if accuracy > best_accuracy:
                best_accuracy = accuracy
                best_jaccard_score = jaccard_score
                torch.save(net.state_dict(), save_network_as)
                # performance of the best accuracy network on the validation dataset
                metrics_filename = save_network_as[:len(save_network_as) - 4] + "-val-metrics.txt"
                saveMetrics(metrics_val, metrics_filename)

            print("-> CURRENT BEST ACCURACY ", best_accuracy)

    # main loop ended
    torch.cuda.empty_cache()
    del net
    net = None

    print("***** TRAINING FINISHED *****")
    print("BEST ACCURACY REACHED ON THE VALIDATION SET: %.3f " % best_accuracy)

    return datasetTrain