Example #1
def train_test(tester, arch, enable_gpu):
    pin_memory = enable_gpu
    dataloaders, class_to_idx = model_helper.get_dataloders(
        data_dir, enable_gpu, num_workers, pin_memory)

    model, optimizer, criterion = model_helper.create_model(
        arch, learning_rate, hidden_units, class_to_idx)
    if enable_gpu:
        model.cuda()
    else:
        torch.set_num_threads(num_cpu_threads)

    epochs = gpu_epochs if enable_gpu else cpu_epochs

    model_helper.train(model, criterion, optimizer, epochs,
                       dataloaders['training'], dataloaders['validation'],
                       enable_gpu)

    checkpoint_dir = testing_dir + ('/gpu' if enable_gpu else '/cpu')

    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    checkpoint = checkpoint_dir + '/' + arch + '_checkpoint.pth'

    model_helper.save_checkpoint(checkpoint, model, optimizer, arch,
                                 learning_rate, hidden_units, epochs)
Example #2
def main():
    start_time = time()
    in_args = get_input_args()
    use_gpu = torch.cuda.is_available() and in_args.gpu

    print("Training on {} using {}".format("GPU" if use_gpu else "CPU",
                                           in_args.arch))

    print(
        "Architecture:{}, Learning rate:{}, Hidden Units:{}, Epochs:{}".format(
            in_args.arch, in_args.learning_rate, in_args.hidden_units,
            in_args.epochs))

    dataloaders, class_to_idx = model_helper.get_dataloders(in_args.data_dir)

    model, optimizer, criterion = model_helper.create_model(
        in_args.arch, in_args.learning_rate, in_args.hidden_units,
        class_to_idx)

    if use_gpu:
        model.cuda()
        criterion.cuda()
    else:
        torch.set_num_threads(in_args.num_threads)

    model_helper.train(model, criterion, optimizer, in_args.epochs,
                       dataloaders['training'], dataloaders['validation'],
                       use_gpu)

    if in_args.save_dir:
        if not os.path.exists(in_args.save_dir):
            os.makedirs(in_args.save_dir)

        file_path = in_args.save_dir + '/' + in_args.arch + '_checkpoint.pth'
    else:
        file_path = in_args.arch + '_checkpoint.pth'

    model_helper.save_checkpoint(file_path, model, optimizer, in_args.arch,
                                 in_args.learning_rate, in_args.hidden_units,
                                 in_args.epochs)

    test_loss, accuracy = model_helper.validate(model, criterion,
                                                dataloaders['testing'],
                                                use_gpu)
    print("Test Accuracy: {:.3f}".format(accuracy))

    end_time = time()
    utility.print_elapsed_time(end_time - start_time)
Example #3
def inference(ckpt, input_file, output_file, hparams):
    """
    Generate predictions.
    Args:
        ckpt        - the model checkpoint file to use
        input_file  - the input file with the inference data
        output_file - the name of the file to save the predictions to
        hparams     - hyperparameters for this model
    """

    infer_tuple = create_model(hparams, tf.contrib.learn.ModeKeys.INFER)
    infer_tuple.model.saver.restore(infer_tuple.session, ckpt)
    infer_tuple.session.run([infer_tuple.iterator.initializer])

    outputs = []
    while True:
        try:
            sample, output = infer_tuple.model.infer(infer_tuple.session)
            outputs.append((sample, output))
        except tf.errors.OutOfRangeError:
            print(" - Done -")
            break

    return outputs
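A minimal usage sketch for inference (the checkpoint path, file names, and the create_hparams helper are placeholders; note that output_file is written by the caller, since the function itself only returns the (sample, output) pairs):

hparams = create_hparams()  # hypothetical helper; hparams normally comes from flag parsing
outputs = inference('ckpts/model.ckpt-12000', 'infer_input.txt', 'predictions.txt', hparams)
with open('predictions.txt', 'w') as out:
    for sample, output in outputs:
        out.write('{}\t{}\n'.format(sample, output))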
Example #4
def main(args):
    # Ensure training, testing, and manip are not all turned off
    assert (args.train or args.test or args.manip),\
        'Cannot have train, test, and manip all set to 0, Nothing to do.'

    # Load the training, validation, and testing data
    try:
        train_list, val_list, test_list = load_data(args.data_root_dir,
                                                    args.split_num)
    except IndexError:
        # Create the training and test splits if not found
        split_data(args.data_root_dir, num_splits=4)
        train_list, val_list, test_list = load_data(args.data_root_dir,
                                                    args.split_num)

    # Get image properties from first image. Assume they are all the same.
    img_shape = sitk.GetArrayFromImage(
        sitk.ReadImage(join(args.data_root_dir, 'imgs',
                            train_list[0][0]))).shape
    net_input_shape = (img_shape[0], img_shape[1], args.slices)
    print(net_input_shape)
    # Create the model for training/testing/manipulation
    model_list = create_model(args=args, input_shape=net_input_shape)
    print_summary(model=model_list[0], positions=[.38, .65, .75, 1.])

    args.output_name = args.save_prefix +\
        'split-' + str(args.split_num) + \
        '-batch-' + str(args.batch_size) + \
        '_shuff-' + str(args.shuffle_data) + \
        '_aug-' + str(args.aug_data) + \
        '_loss-' + str(args.loss) + \
        '_strid-' + str(args.stride) + \
        '_lr-' + str(args.initial_lr) + \
        '_recon-' + str(args.recon_wei)

    args.time = time

    args.check_dir = join(args.data_root_dir, 'saved_models', args.net)
    try:
        makedirs(args.check_dir)
    except FileExistsError:
        pass

    args.log_dir = join(args.data_root_dir, 'logs', args.net)
    try:
        makedirs(args.log_dir)
    except FileExistsError:
        pass

    args.tf_log_dir = join(args.log_dir, 'tf_logs')
    try:
        makedirs(args.tf_log_dir)
    except FileExistsError:
        pass

    args.output_dir = join(args.data_root_dir, 'plots', args.net)
    try:
        makedirs(args.output_dir)
    except FileExistsError:
        pass

    if args.train:
        from train import train
        # Run training
        train(args, train_list, val_list, model_list[0], net_input_shape)

    if args.test:
        from test import test
        # Run testing
        test(args, test_list, model_list, net_input_shape)
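For reference, each try/except FileExistsError block above collapses to a single call on Python 3.2+ (the form used in several of the later examples):

makedirs(args.check_dir, exist_ok=True)  # no error if the directory already exists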
Example #5
def main():

    # get the input arguments
    in_arg = get_input_args()
    print_input_arguments(in_arg)

    # load the train dataset (to get the class-to-idx mapping) and the dataloaders for training, validation and test
    train_data, trainloader, validationloader, testloader = load_data(
        in_arg.data)

    # Load the category label to category name mapping
    with open('cat_to_name.json', 'r') as f:
        cat_to_name = json.load(f)

    # Load the model
    model = create_model(in_arg.arch)

    # Freeze parameters so we don't backprop through them
    for param in model.parameters():
        param.requires_grad = False

    # Get the updated classifier based on the hidden layers parsed
    classifier, input_size = create_classifier(model, in_arg.units)
    model.classifier = classifier
    print_model(model)

    # Choose the device based on the "GPU" input
    if in_arg.gpu:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    else:
        device = 'cpu'
    print(f"Device used: {device}")

    # Start training the model
    model, criterion, optimizer, train_loss_toprint, valid_loss_toprint, valid_accuracy_toprint = train_model(
        model, device, in_arg.epochs, trainloader, in_arg.lr, validationloader)

    if in_arg.print_graph:
        model_validation(device, model, testloader, criterion,
                         train_loss_toprint, valid_loss_toprint,
                         valid_accuracy_toprint)

    checkpoint = {
        'input_size': input_size,
        'output_size': 102,
        'arch': in_arg.arch,
        'learning_rate': in_arg.lr,
        'batch_size': 64,
        'classifier': classifier,
        'epochs': in_arg.epochs,
        'optimizer': optimizer.state_dict(),
        'state_dict': model.state_dict(),
        'class_to_idx': train_data.class_to_idx,
        'train_loss_toprint': train_loss_toprint,
        'test_loss_toprint': valid_loss_toprint,
        'test_accuracy_toprint': valid_accuracy_toprint,
        'criterion': criterion
    }

    if in_arg.save_dir:
        os.makedirs(in_arg.save_dir, exist_ok=True)
        torch.save(checkpoint, in_arg.save_dir + '/checkpoint.pth')
    else:
        torch.save(checkpoint, 'checkpoint.pth')
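A sketch (not part of the original script) of how the checkpoint saved above could be restored, reusing the same create_model helper and the keys written into the dict:

import torch

def load_checkpoint(path):
    checkpoint = torch.load(path, map_location='cpu')
    model = create_model(checkpoint['arch'])          # same helper as in main()
    model.classifier = checkpoint['classifier']
    model.load_state_dict(checkpoint['state_dict'])
    model.class_to_idx = checkpoint['class_to_idx']
    return model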
Example #6
def main():
    input_arguments = argument_parser()
    print("Chosen learning rate is {}, hidden units are {} and epochs are {}".
          format(input_arguments.learning_rate, input_arguments.hidden_units,
                 input_arguments.epochs))

    batch_size = 64

    gpu_check = torch.cuda.is_available() and input_arguments.gpu
    if gpu_check:
        print("GPU Device available.")
    else:
        warnings.warn(
            "No GPU found. Please use a GPU to train your neural network.")

    print("Data loading started.")
    train_dir = input_arguments.data_dir + '/train'
    valid_dir = input_arguments.data_dir + '/valid'
    test_dir = input_arguments.data_dir + '/test'

    data_transforms = {
        'training_sets':
        transforms.Compose([
            transforms.RandomRotation(30),
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'validation_sets':
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'testing_sets':
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    }

    # Load the datasets with ImageFolder
    image_datasets = {
        'training_sets':
        datasets.ImageFolder(train_dir,
                             transform=data_transforms['training_sets']),
        'validation_sets':
        datasets.ImageFolder(valid_dir,
                             transform=data_transforms['validation_sets']),
        'testing_sets':
        datasets.ImageFolder(test_dir,
                             transform=data_transforms['testing_sets'])
    }

    # Using the image datasets and the transforms, define the dataloaders
    dataloaders = {
        'training_sets':
        torch.utils.data.DataLoader(image_datasets['training_sets'],
                                    batch_size,
                                    shuffle=True),
        'validation_sets':
        torch.utils.data.DataLoader(image_datasets['validation_sets'],
                                    batch_size,
                                    shuffle=True),
        'testing_sets':
        torch.utils.data.DataLoader(image_datasets['testing_sets'],
                                    batch_size,
                                    shuffle=True)
    }
    print("Data loading completed. Model creation in-progress, please wait.")
    model, optimizer, criterion = model_helper.create_model(
        input_arguments.arch, input_arguments.hidden_units,
        input_arguments.learning_rate,
        image_datasets['training_sets'].class_to_idx)

    print("Model creation completed. Moving to GPU if available, please wait.")
    if gpu_check:
        model.cuda()
        criterion.cuda()

    print("Training started, please wait it might take upto 5 mins.")
    model_helper.train(model, criterion, optimizer, Input_aruguments.epochs,
                       dataloaders['training_sets'],
                       dataloaders['validation_sets'], gpu_check)
    print("Training completed. Saving checkpoints, please wait.")
    model_helper.save_checkpoint(model, optimizer, batch_size,
                                 input_arguments.learning_rate,
                                 input_arguments.arch,
                                 input_arguments.hidden_units,
                                 input_arguments.epochs)
    print("Saving checkpoints complete. Validating model, please wait.")
    test_loss, accuracy = model_helper.validate(model, criterion,
                                                dataloaders['testing_sets'],
                                                gpu_check)
    print("Validation Accuracy: {:.3f}".format(accuracy))
    image_path = 'flower_data/test/66/image_05582.jpg'
    print("Predication for: {}".format(image_path))
    probs, classes = model_helper.predict(image_path, model, gpu_check)
    print(probs)
    print(classes)
Example #7
def main(args):
    # Directory to save images for using flow_from_directory
    args.output_name = 'split-' + str(args.split_num) + '_nclass-' + str(args.num_classes) + \
                       '_batch-' + str(args.batch_size) +  '_shuff-' + str(args.shuffle_data) + \
                       '_aug-' + str(args.aug_data) + '_loss-' + str(args.loss) + '_lr-' + str(args.initial_lr) + \
                       '_reconwei-' + str(args.recon_wei) + '_attrwei-' + str(args.attr_wei) + \
                       '_r1-' + str(args.routings1) + '_r2-' + str(args.routings2)
    args.time = time

    # Create all the output directories
    args.check_dir = os.path.join(args.data_root_dir, 'saved_models',
                                  args.exp_name, args.net)
    safe_mkdir(args.check_dir)

    args.log_dir = os.path.join(args.data_root_dir, 'logs', args.exp_name,
                                args.net)
    safe_mkdir(args.log_dir)

    args.tf_log_dir = os.path.join(args.log_dir, 'tf_logs', args.time)
    safe_mkdir(args.tf_log_dir)

    args.output_dir = os.path.join(args.data_root_dir, 'plots', args.exp_name,
                                   args.net)
    safe_mkdir(args.output_dir)

    args.img_aug_dir = os.path.join(args.data_root_dir, 'logs', 'aug_imgs')
    safe_mkdir(args.img_aug_dir)

    # Load the training, validation, and testing data
    train_imgs, train_masks, train_labels, val_imgs, val_masks, val_labels, test_imgs, test_masks, test_labels = \
        load_data(root=args.data_root_dir, split=args.split_num,
                  k_folds=args.k_fold_splits, val_split=args.val_split)
    print(
        'Found {} 3D nodule images for training, {} for validation, and {} for testing.'
        ''.format(len(train_imgs), len(val_imgs), len(test_imgs)))

    # Resize images to args.resize_shape
    print('Resizing training images to {}.'.format(args.resize_shape))
    train_imgs, train_masks, train_labels = resize_data(
        train_imgs, train_masks, train_labels, args.resize_shape)
    print('Resizing validation images to {}.'.format(args.resize_shape))
    val_imgs, val_masks, val_labels = resize_data(val_imgs, val_masks,
                                                  val_labels,
                                                  args.resize_shape)
    print('Resizing testing images to {}.'.format(args.resize_shape))
    test_imgs, test_masks, test_labels = resize_data(test_imgs, test_masks,
                                                     test_labels,
                                                     args.resize_shape)

    # Create the model
    model_list = create_model(args=args, input_shape=args.resize_shape + [1])
    model_list[0].summary()

    # Run the chosen functions
    if args.train:
        from train import train
        print('-' * 98, '\nRunning Training\n', '-' * 98)
        train(args=args,
              u_model=model_list[0],
              train_samples=(train_imgs, train_masks, train_labels),
              val_samples=(val_imgs, val_masks, val_labels))

    if args.test:
        from test import test
        print('-' * 98, '\nRunning Testing\n', '-' * 98)
        if args.net.find('caps') != -1:
            test(args=args,
                 u_model=model_list[1],
                 test_samples=(test_imgs, test_masks, test_labels))
        else:
            test(args=args,
                 u_model=model_list[0],
                 test_samples=(test_imgs, test_masks, test_labels))

    if args.manip and args.net.find('caps') != -1:
        from manip import manip
        print('-' * 98, '\nRunning Manipulate\n', '-' * 98)
        manip(args=args,
              u_model=model_list[2],
              test_samples=(test_imgs, test_masks, test_labels))

    print('Done.')
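safe_mkdir is a project helper that is not shown in this example; a plausible definition (an assumption, not the project's actual code) is:

import os

def safe_mkdir(path):
    # Create path (and any missing parents) without failing if it already exists
    os.makedirs(path, exist_ok=True)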
Example #8
def main():
    start_time = time()

    in_args = get_input_args()

    # Check for GPU
    use_gpu = torch.cuda.is_available() and in_args.gpu

    # Print parameter information
    if use_gpu:
        print("Training on GPU{}".format(
            " with pinned memory" if in_args.pin_memory else "."))
    else:
        print("Training on CPU using {} threads.".format(in_args.num_threads))

    print("Architecture:{}, Learning rate:{}, Hidden Units:{}, Epochs:{}".format(
        in_args.arch, in_args.learning_rate, in_args.hidden_units, in_args.epochs))

    # Get dataloaders for training
    dataloaders, class_to_idx = model_helper.get_dataloders(in_args.data_dir,
                                                            use_gpu,
                                                            in_args.num_workers,
                                                            in_args.pin_memory)

    # Create model
    model, optimizer, criterion = model_helper.create_model(in_args.arch,
                                                            in_args.learning_rate,
                                                            in_args.hidden_units,
                                                            class_to_idx)

    # Move tensors to GPU if available
    if use_gpu:
        model.cuda()
        criterion.cuda()
    else:
        torch.set_num_threads(in_args.num_threads)

    # Train the network
    model_helper.train(model,
                       criterion,
                       optimizer,
                       in_args.epochs,
                       dataloaders['training'],
                       dataloaders['validation'],
                       use_gpu)

    # Save trained model
    if in_args.save_dir:

        # Create save directory if required
        if not os.path.exists(in_args.save_dir):
            os.makedirs(in_args.save_dir)

        # Save checkpoint in save directory
        file_path = in_args.save_dir + '/' + in_args.arch + '_checkpoint.pth'
    else:
        # Save checkpoint in current directory
        file_path = in_args.arch + '_checkpoint.pth'

    model_helper.save_checkpoint(file_path,
                                 model,
                                 optimizer,
                                 in_args.arch,
                                 in_args.learning_rate,
                                 in_args.hidden_units,
                                 in_args.epochs)

    # Get prediction accuracy using test dataset
    test_loss, accuracy = model_helper.validate(
        model, criterion, dataloaders['testing'], use_gpu)
    print("Testing Accuracy: {:.3f}".format(accuracy))

    # Computes overall runtime in seconds & prints it in hh:mm:ss format
    end_time = time()
    utility.print_elapsed_time(end_time - start_time)
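utility.print_elapsed_time is not shown in these examples; a minimal sketch matching the hh:mm:ss behavior described in the comment above (an assumption, not the project's actual code):

def print_elapsed_time(total_seconds):
    # Convert a duration in seconds to hh:mm:ss and print it
    hours, rem = divmod(int(total_seconds), 3600)
    minutes, seconds = divmod(rem, 60)
    print("Elapsed time: {:02d}:{:02d}:{:02d}".format(hours, minutes, seconds))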
Example #9
def train(hparams):
    """Build and train the model as specified in hparams"""

    ckptsdir = str(Path(hparams.modeldir, "ckpts"))

    # build training and eval graphs
    train_tuple = create_model(hparams, tf.contrib.learn.ModeKeys.TRAIN)
    eval_tuple = create_model(hparams, tf.contrib.learn.ModeKeys.EVAL)

    with train_tuple.graph.as_default():
        initializer = tf.global_variables_initializer()
        train_tables_initializer = tf.tables_initializer()

    with eval_tuple.graph.as_default():
        local_initializer = tf.local_variables_initializer()
        eval_tables_initializer = tf.tables_initializer()

    # Summary writers
    summary_writer = tf.summary.FileWriter(hparams.modeldir,
                                           train_tuple.graph,
                                           max_queue=25,
                                           flush_secs=30)

    if hparams.saved is not None:
        # load checkpoint
        train_tuple.model.saver.restore(train_tuple.session, hparams.saved)
    else:
        train_tuple.session.run([initializer])

    start_time = process_time()
    # initialize the training dataset
    train_tuple.session.run([train_tables_initializer])
    train_tuple.session.run([train_tuple.iterator.initializer])
    # initialize the eval table only once
    eval_tuple.session.run([eval_tables_initializer])
    # finalize the graph
    train_tuple.graph.finalize()

    profile_next_step = False
    profiled = False
    # Train until the dataset throws an error (at the end of num_epochs)
    step_time = []  # accumulated per-step timings (reset every 100 steps below)
    while True:
        try:
            curr_time = process_time()
            # Profiling disabled; re-enable with: if not profiled and profile_next_step:
            if False:
                print("Running training step with profiling")
                # run profiling
                _, train_loss, global_step, _, summary, metadata = train_tuple.model.\
                        train_with_profile(train_tuple.session, summary_writer)
                # write the metadata out to a chrome trace file
                trace = timeline.Timeline(step_stats=metadata.step_stats)
                with open(hparams.modeldir + "/timeline.ctf.json",
                          "w") as tracefile:
                    tracefile.write(trace.generate_chrome_trace_format())
                profile_next_step = False
                profiled = True
            else:
                _, train_loss, global_step, _, summary = train_tuple.model.train(
                    train_tuple.session)
            step_time.append(process_time() - curr_time)

            # write train summaries
            if global_step == 1:
                summary_writer.add_summary(summary, global_step)
            if global_step % 15 == 0:
                summary_writer.add_summary(summary, global_step)
                print("Step: %d, Training Loss: %f, Avg Sec/Step: %2.2f" %
                      (global_step, train_loss, np.mean(step_time)))

            if global_step % 100 == 0:
                step_time = []
                profile_next_step = True
                # Do one evaluation
                checkpoint_path = train_tuple.model.saver.save(
                    train_tuple.session,
                    ckptsdir + "/ckpt",
                    global_step=global_step)
                print(checkpoint_path)
                eval_tuple.model.saver.restore(eval_tuple.session,
                                               checkpoint_path)
                eval_tuple.session.run(
                    [eval_tuple.iterator.initializer, local_initializer])
                while True:
                    try:
                        eval_loss, eval_acc, eval_summary, _ = eval_tuple.model.eval(
                            eval_tuple.session)
                        # summary_writer.add_summary(summary, global_step)
                    except tf.errors.OutOfRangeError:
                        print("Step: %d, Eval Loss: %f, Eval Accuracy: %f" %
                              (global_step, eval_loss, eval_acc))
                        summary_writer.add_summary(eval_summary, global_step)
                        break

        except tf.errors.OutOfRangeError:
            print("- End of Trainig -")
            break

    # End of training
    summary_writer.close()
    print("Total Training Time: %4.2f" % (process_time() - start_time))
Example #10
def main(args):
    #tf.enable_eager_execution()

    #sess = K.get_session()
    #sess = tf_debug.LocalCLIDebugWrapperSession(sess)
    #K.set_session(sess)

    # Ensure training, testing, and manip are not all turned off
    assert (args.train or args.test or args.manip), 'Cannot have train, test, and manip all set to 0, Nothing to do.'

    # Load the training, validation, and testing data
    try:
        train_list, val_list, test_list = load_data(args.data_root_dir, args.split_num)
    except Exception:
        # Create the training and test splits if not found
        split_data(args.data_root_dir, num_splits=4)
        train_list, val_list, test_list = load_data(args.data_root_dir, args.split_num)

    # Get image properties from first image. Assume they are all the same.
    img_shape = sitk.GetArrayFromImage(sitk.ReadImage(join(args.data_root_dir, 'imgs', train_list[0][0]))).shape
    net_input_shape = (img_shape[1], img_shape[2], args.slices)

    if args.pytorch:
        from capsnet_pytorch import CapsNetBasic, CapsNetR3
        model = CapsNetR3() #CapsNetBasic()
    else:
        # Create the model for training/testing/manipulation
        model_list = create_model(args=args, input_shape=net_input_shape)
        print_summary(model=model_list[0], positions=[.38, .65, .75, 1.])

    args.output_name = 'split-' + str(args.split_num) + '_batch-' + str(args.batch_size) + \
                       '_shuff-' + str(args.shuffle_data) + '_aug-' + str(args.aug_data) + \
                       '_loss-' + str(args.loss) + '_slic-' + str(args.slices) + \
                       '_sub-' + str(args.subsamp) + '_strid-' + str(args.stride) + \
                       '_lr-' + str(args.initial_lr) + '_recon-' + str(args.recon_wei)
    args.time = time

    args.check_dir = join(args.data_root_dir, 'saved_models', args.net)
    makedirs(args.check_dir, exist_ok=True)

    args.log_dir = join(args.data_root_dir, 'logs', args.net)
    makedirs(args.log_dir, exist_ok=True)

    args.tf_log_dir = join(args.log_dir, 'tf_logs')
    makedirs(args.tf_log_dir, exist_ok=True)

    args.output_dir = join(args.data_root_dir, 'plots', args.net)
    makedirs(args.output_dir, exist_ok=True)

    if args.pytorch:
        if args.train:
            from train_pytorch import train
            # Run training
            train(args, model, train_list, net_input_shape)
            #train(args, train_list, val_list, model_list[0], net_input_shape)
        
        
    else:
        if args.train:
            from train import train
            # Run training
            train(args, train_list, val_list, model_list[0], net_input_shape)
    
        if args.test:
            from test import test
            # Run testing
            test(args, test_list, model_list, net_input_shape)
    
        if args.manip:
            from manip import manip
            # Run manipulation of segcaps
            manip(args, test_list, model_list, net_input_shape)
Example #11
def main():
    start_time = time()

    in_args = get_input_args()

    use_gpu = torch.cuda.is_available() and in_args.gpu

    print("Training on {} using {}".format("GPU" if use_gpu else "CPU",
                                           in_args.arch))

    print("Learning rate:{}, Hidden Units:{}, Epochs:{}".format(
        in_args.learning_rate, in_args.hidden_units, in_args.epochs))

    if not os.path.exists(in_args.save_dir):
        os.makedirs(in_args.save_dir)

    training_dir = in_args.data_dir + '/train'
    validation_dir = in_args.data_dir + '/valid'
    testing_dir = in_args.data_dir + '/test'

    data_transforms = {
        'training':
        transforms.Compose([
            transforms.RandomRotation(30),
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'validation':
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'testing':
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    }

    dirs = {
        'training': training_dir,
        'validation': validation_dir,
        'testing': testing_dir
    }

    image_datasets = {
        x: datasets.ImageFolder(dirs[x], transform=data_transforms[x])
        for x in ['training', 'validation', 'testing']
    }

    dataloaders = {
        x: torch.utils.data.DataLoader(image_datasets[x],
                                       batch_size=64,
                                       shuffle=True)
        for x in ['training', 'validation', 'testing']
    }

    model, optimizer, criterion = model_helper.create_model(
        in_args.arch, in_args.hidden_units, in_args.learning_rate,
        image_datasets['training'].class_to_idx)

    if use_gpu:
        model.cuda()
        criterion.cuda()

    model_helper.train(model, criterion, optimizer, in_args.epochs,
                       dataloaders['training'], dataloaders['validation'],
                       use_gpu)

    file_path = in_args.save_dir + '/' + in_args.arch + \
        '_epoch' + str(in_args.epochs) + '.pth'

    model_helper.save_checkpoint(file_path, model, optimizer, in_args.arch,
                                 in_args.hidden_units, in_args.epochs)

    test_loss, accuracy = model_helper.validate(model, criterion,
                                                dataloaders['testing'],
                                                use_gpu)
    print("Post load Validation Accuracy: {:.3f}".format(accuracy))
    image_path = 'flowers/test/28/image_05230.jpg'
    print("Predication for: {}".format(image_path))
    probs, classes = model_helper.predict(image_path, model, use_gpu)
    print(probs)
    print(classes)

    end_time = time()
    utility.print_elapsed_time(end_time - start_time)
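A note on the Normalize() calls that recur in these PyTorch examples: the constants are the standard ImageNet channel statistics that torchvision's pretrained models expect:

IMAGENET_MEAN = [0.485, 0.456, 0.406]  # per-channel mean (RGB)
IMAGENET_STD = [0.229, 0.224, 0.225]   # per-channel standard deviation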
Example #12
def main(args):
    # Ensure training, testing, and manip are not all turned off
    assert (
        args.train or args.test or args.manip
    ), 'Cannot have train, test, and manip all set to 0, Nothing to do.'

    experiment = Experiment(project_name="general", workspace="simek13")

    # Load the training, validation, and testing data
    dataset = Dataset(args.data_root_dir)
    if any(x in args.data_root_dir for x in ['Spectralis', 'Cirrus']):
        train_list, val_list, test_list = dataset.load_volumes()
        # Get image properties from first image. Assume they are all the same.
        img_shape = sitk.GetArrayFromImage(
            sitk.ReadImage(
                join(args.data_root_dir, 'train', train_list[0][0], 'images',
                     train_list[0][1]))).shape
    else:
        data = dataset.load_data()
        train_list, val_list = train_test_split(data, test_size=0.2)
        img_shape = sitk.GetArrayFromImage(
            sitk.ReadImage(join(args.data_root_dir, 'images',
                                data[0][0]))).shape

    if any(x in args.data_root_dir for x in ['Spectralis', 'Cirrus']):
        net_input_shape = (int(img_shape[0] / 4), int(img_shape[1] / 4), 1)
    else:
        net_input_shape = (int(img_shape[0] / 8), int(img_shape[1] / 8), 1)

    # Create the model for training/testing/manipulation
    model_list = create_model(args=args, input_shape=net_input_shape)
    model_list[0].summary()

    args.output_name = 'batch-' + str(args.batch_size) + '_shuff-' + str(args.shuffle_data) + \
                       '_aug-' + str(args.aug_data) + \
                       '_loss-' + str(args.loss) + '_strid-' + str(args.stride) + \
                       '_lr-' + str(args.initial_lr)
    args.time = time

    args.check_dir = join(args.data_root_dir, 'saved_models', args.net)
    makedirs(args.check_dir, exist_ok=True)

    args.log_dir = join(args.data_root_dir, 'logs', args.net)
    makedirs(args.log_dir, exist_ok=True)

    args.tf_log_dir = join(args.log_dir, 'tf_logs')
    makedirs(args.tf_log_dir, exist_ok=True)

    args.output_dir = join(args.data_root_dir, 'plots', args.net)
    makedirs(args.output_dir, exist_ok=True)

    if args.train:
        from train import train
        # Run training
        train(args, train_list, val_list, model_list[0], net_input_shape)
    if args.test:
        from test import test
        # Run testing
        if any(x in args.data_root_dir for x in ['Spectralis', 'Cirrus']):
            test(args, test_list, model_list, net_input_shape)
        else:
            test(args, val_list, model_list, net_input_shape)
Example #13
def main(args):
    # Ensure training and testing are not both turned off
    assert (args.train or args.test), \
        'Cannot have train and test both set to 0; nothing to do.'

    # Load the data
    images, ground_truth = read_and_process_data(args.data_root_dir, args.size)

    images_train_val, images_test, g_t_train_val, g_t_test = generate_train_test(
        images, ground_truth, random_num=random.randint(1, 1001))
    images_train, images_val, g_t_train, g_t_val = generate_train_val(
        images_train_val, g_t_train_val, random_num=random.randint(1, 1001))

    show = 0
    if show:
        show_image(images_train, g_t_train)
        show_image(images_val, g_t_val)

    input_shape = (images[0].shape[0], images[0].shape[1], 1)

    analyze = 1
    if analyze:
        analyze_data(g_t_train, g_t_val, g_t_test, ground_truth)
        analyze_img_data(images_train, images_val, images_test, images)

    if args.num_pat != 0:
        # Keep only the first num_pat patients for training
        images_train = images_train[:args.num_pat]
        g_t_train = g_t_train[:args.num_pat]

    # Create the model for training and testing
    # model_list = [0] train_model, [1] eval_model
    model_list = create_model(args=args, input_shape=input_shape)

    print_summary(model=model_list[0], positions=[.38, .65, .75, 1.])

    args.output_name = 'split-' + str(args.split_num) + '_batch-' + str(args.batch_size) + \
                       '_shuff-' + str(args.shuffle_data) + '_aug-' + str(args.aug_data) + \
                       '_loss-' + str(args.loss) + '_slic-' + str(args.slices) + \
                       '_sub-' + str(args.subsamp) + '_strid-' + str(args.stride) + \
                       '_lr-' + str(args.initial_lr) + '_recon-' + str(args.recon_wei)
    args.time = time

    args.check_dir = join(
        r'D:\Engineering Physics\Skripsi\Program\Ischemic Stroke Segmentation',
        'saved_models')
    makedirs(args.check_dir, exist_ok=True)

    args.log_dir = join(
        r'D:\Engineering Physics\Skripsi\Program\Ischemic Stroke Segmentation',
        'logs')
    makedirs(args.log_dir, exist_ok=True)

    args.tf_log_dir = join(
        r'D:\Engineering Physics\Skripsi\Program\Ischemic Stroke Segmentation',
        'tf_logs')
    makedirs(args.tf_log_dir, exist_ok=True)

    args.output_dir = join(
        r'D:\Engineering Physics\Skripsi\Program\Ischemic Stroke Segmentation',
        'plots', args.net)
    makedirs(args.output_dir, exist_ok=True)

    if args.train:
        from train import train
        # Run training
        train(args, images_train, images_val, g_t_train, g_t_val,
              model_list[0], input_shape)

    if args.test:
        from test import test
        # Run testing
        test(args, [images[6]], [ground_truth[6]], model_list, input_shape)
Example #14
def main(args):
    image_shape = (128, 128)
    # Ensure training, testing, and manip are not all turned off
    assert (args.train or args.test or args.manip), 'Cannot have train, test, and manip all set to 0, Nothing to do.'

    # Load the training, validation, and testing data
    try:
        train_list, val_list, test_list = load_data(args.data_root_dir, args.split_num)
    except Exception:
        # Create the training and test splits if not found
        split_data(args.data_root_dir, num_splits=4)
        train_list, val_list, test_list = load_data(args.data_root_dir, args.split_num)

    # Get image properties from first image. Assume they are all the same.
    im = cv2.imread(join(args.data_root_dir, 'imgs', train_list[0][0]))
    im = cv2.resize(im, image_shape)
    gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
    gray = np.reshape(gray, (1, gray.shape[0], gray.shape[1]))
    img_shape = gray.shape
    net_input_shape = (img_shape[1], img_shape[2], args.slices)

    # Create the model for training/testing/manipulation
    model_list = create_model(args=args, input_shape=net_input_shape, class_num=6)
    print_summary(model=model_list[0], positions=[.38, .65, .75, 1.])
    args.output_name = 'split-' + str(args.split_num) + '_batch-' + str(args.batch_size) + \
                       '_shuff-' + str(args.shuffle_data) + '_aug-' + str(args.aug_data) + \
                       '_loss-' + str(args.loss) + '_slic-' + str(args.slices) + \
                       '_sub-' + str(args.subsamp) + '_strid-' + str(args.stride) + \
                       '_lr-' + str(args.initial_lr) + '_recon-' + str(args.recon_wei)
    args.time = time

    args.check_dir = join(args.data_root_dir, 'saved_models', args.net)
    makedirs(args.check_dir, exist_ok=True)

    args.log_dir = join(args.data_root_dir, 'logs', args.net)
    makedirs(args.log_dir, exist_ok=True)

    args.tf_log_dir = join(args.log_dir, 'tf_logs')
    makedirs(args.tf_log_dir, exist_ok=True)

    args.output_dir = join(args.data_root_dir, 'plots', args.net)
    makedirs(args.output_dir, exist_ok=True)

    if args.train:
        print('training')
        from train import train
        # Run training
        while ([] in train_list):
            train_list.remove([])
        CUDA_VISIBLE_DEVICES = 1  # NOTE: a plain assignment has no effect; see the note after this example
        train(args, train_list, val_list, model_list[0], net_input_shape)

    if args.test:
        from test import test
        # Run testing
        while ([] in test_list):
            test_list.remove([])
        for i in range(39,40):
            args.weights_path = f'saved_models/{args.net}/model' + str(i) + '.hdf5'
            test(args, test_list, model_list, net_input_shape, i)

    if args.manip:
        from manip import manip
        # Run manipulation of segcaps
        manip(args, test_list, model_list, net_input_shape)
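As noted in the comment above, selecting a GPU requires setting an environment variable before TensorFlow initializes, not assigning a Python variable; a sketch of the usual fix, placed at the very top of the script:

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'  # must run before TensorFlow/Keras touches the GPU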
Example #15
File: main.py  Project: legendhua/SegCaps
def main(args):
    # Ensure training, testing, and manip are not all turned off
    assert (args.train or args.test or args.manip), 'Cannot have train, test, and manip all set to 0, Nothing to do.'

    # Load the training, validation, and testing data
    try:
        train_list, val_list, test_list = load_data(args.data_root_dir, args.split_num)
    except Exception:
        # Create the training and test splits if not found
        split_data(args.data_root_dir, num_splits=4)
        train_list, val_list, test_list = load_data(args.data_root_dir, args.split_num)

    # Get image properties from first image. Assume they are all the same.
    img_shape = sitk.GetArrayFromImage(sitk.ReadImage(join(args.data_root_dir, 'imgs', train_list[0][0]))).shape
    net_input_shape = (img_shape[1], img_shape[2], args.slices)

    # Create the model for training/testing/manipulation
    model_list = create_model(args=args, input_shape=net_input_shape)
    print_summary(model=model_list[0], positions=[.38, .65, .75, 1.])

    args.output_name = 'split-' + str(args.split_num) + '_batch-' + str(args.batch_size) + \
                       '_shuff-' + str(args.shuffle_data) + '_aug-' + str(args.aug_data) + \
                       '_loss-' + str(args.loss) + '_slic-' + str(args.slices) + \
                       '_sub-' + str(args.subsamp) + '_strid-' + str(args.stride) + \
                       '_lr-' + str(args.initial_lr) + '_recon-' + str(args.recon_wei)
    args.time = time

    args.check_dir = join(args.data_root_dir, 'saved_models', args.net)
    makedirs(args.check_dir, exist_ok=True)

    args.log_dir = join(args.data_root_dir, 'logs', args.net)
    makedirs(args.log_dir, exist_ok=True)

    args.tf_log_dir = join(args.log_dir, 'tf_logs')
    makedirs(args.tf_log_dir, exist_ok=True)

    args.output_dir = join(args.data_root_dir, 'plots', args.net)
    makedirs(args.output_dir, exist_ok=True)

    if args.train:
        from train import train
        # Run training
        train(args, train_list, val_list, model_list[0], net_input_shape)

    if args.test:
        from test import test
        # Run testing
        test(args, test_list, model_list, net_input_shape)

    if args.manip:
        from manip import manip
        # Run manipulation of segcaps
        manip(args, test_list, model_list, net_input_shape)
Example #16
def main(args):
    # Ensure training, testing, and manip are not all turned off
    assert (
        args.train or args.test or args.manip
    ), 'Cannot have train, test, and manip all set to 0, Nothing to do.'

    # Load the training, validation, and testing data
    try:
        train_list, val_list, test_list = load_data(args.data_root_dir,
                                                    args.split_num)
    except Exception:
        # Create the training and test splits if not found
        logging.info('No existing training, validation, or test split files '
                     'found; generating them.')
        split_data(args.data_root_dir, num_splits=4)
        train_list, val_list, test_list = load_data(args.data_root_dir,
                                                    args.split_num)

    # Get image properties from first image. Assume they are all the same.
    logging.info('Read image files...%s' %
                 (join(args.data_root_dir, 'imgs', train_list[0][0])))
    # Get image shape from the first image.
    image = sitk.GetArrayFromImage(
        sitk.ReadImage(join(args.data_root_dir, 'imgs', train_list[0][0])))
    img_shape = image.shape  #(500,500,4)
    if args.dataset == 'luna16':
        net_input_shape = (img_shape[1], img_shape[2], args.slices)
    else:
        args.slices = 1
        img_shape = (RESOLUTION, RESOLUTION, img_shape[2])
        net_input_shape = (img_shape[0], img_shape[1], args.slices)

    # Create the model for training/testing/manipulation
    model_list = create_model(args=args, input_shape=net_input_shape)
    print_summary(model=model_list[0], positions=[.38, .65, .75, 1.])

    args.output_name = 'split-' + str(args.split_num) + '_batch-' + str(args.batch_size) + \
                       '_shuff-' + str(args.shuffle_data) + '_aug-' + str(args.aug_data) + \
                       '_loss-' + str(args.loss) + '_slic-' + str(args.slices) + \
                       '_sub-' + str(args.subsamp) + '_strid-' + str(args.stride) + \
                       '_lr-' + str(args.initial_lr) + '_recon-' + str(args.recon_wei)

    #     args.output_name = 'sh-' + str(args.shuffle_data) + '_a-' + str(args.aug_data)

    args.time = time

    args.check_dir = join(args.data_root_dir, 'saved_models', args.net)
    makedirs(args.check_dir, exist_ok=True)

    args.log_dir = join(args.data_root_dir, 'logs', args.net)
    makedirs(args.log_dir, exist_ok=True)

    args.tf_log_dir = join(args.log_dir, 'tf_logs')
    makedirs(args.tf_log_dir, exist_ok=True)

    args.output_dir = join(args.data_root_dir, 'plots', args.net)
    makedirs(args.output_dir, exist_ok=True)

    if args.train:
        from train import train
        # Run training
        train(args, train_list, val_list, model_list[0], net_input_shape)

    if args.test:
        from test import test
        # Run testing
        test(args, test_list, model_list, net_input_shape)

    if args.manip:
        from manip import manip
        # Run manipulation of segcaps
        manip(args, test_list, model_list, net_input_shape)
Example #17
def main(args):
    # Set the dictionary of possible experiments
    if args.experiment == 0:
        args.exp_name = 'HPvsA'
    elif args.experiment == 1:
        args.exp_name = 'HPvsA_SSA'
    elif args.experiment == 2:
        args.exp_name = 'HPvsSSA'
    else:
        raise Exception('Experiment number undefined.')

    # Directory to save images for using flow_from_directory
    args.output_name = 'split-' + str(args.split_num) + '_batch-' + str(args.batch_size) + \
                       '_shuff-' + str(args.shuffle_data) + '_aug-' + str(args.aug_data) + \
                       '_lr-' + str(args.initial_lr) + '_recon-' + str(args.recon_wei) + \
                       '_cpre-' + str(args.use_custom_pretrained) + \
                       '_r1-' + str(args.routings1) + '_r2-' + str(args.routings2)
    args.time = time

    args.img_dir = os.path.join(args.root_dir, 'experiment_splits',
                                args.exp_name,
                                'split_{}'.format(args.split_num))
    safe_mkdir(args.img_dir)

    # Create all the output directories
    args.check_dir = os.path.join(args.root_dir, 'saved_models', args.exp_name,
                                  args.net)
    safe_mkdir(args.check_dir)

    args.log_dir = os.path.join(args.root_dir, 'logs', args.exp_name, args.net)
    safe_mkdir(args.log_dir)

    args.img_aug_dir = os.path.join(args.root_dir, 'logs', args.exp_name,
                                    args.net, 'aug_imgs')
    safe_mkdir(args.img_aug_dir)

    args.tf_log_dir = os.path.join(args.log_dir, 'tf_logs', args.time)
    safe_mkdir(args.tf_log_dir)

    args.output_dir = os.path.join(args.root_dir, 'plots', args.exp_name,
                                   args.net)
    safe_mkdir(args.output_dir)

    # Set net input to (None, None, 3) to allow for variable size color inputs
    net_input_shape = [None, None, 3]
    args.crop_shape = [args.crop_hei, args.crop_wid]
    args.resize_shape = [args.resize_hei, args.resize_wid]

    if args.create_images or args.only_create_images:
        # Load the training, validation, and testing data
        train_list, val_list, test_list = load_data(root=args.root_dir,
                                                    exp_name=args.exp_name,
                                                    exp=args.experiment,
                                                    split=args.split_num,
                                                    k_folds=args.k_fold_splits,
                                                    val_split=args.val_split)
        print(
            'Found {} patients for training, {} for validation, and {} for testing. Note: For patients with more '
            'than one polyp of the same type, all images for the type are placed into either the training or testing '
            'set together.'.format(len(train_list), len(val_list),
                                   len(test_list)))

        # Split data for flow_from_directory
        train_samples, train_shape, val_samples, val_shape, test_samples, test_shape = \
            split_data_for_flow(root=args.root_dir, out_dir=args.img_dir, exp_name=args.exp_name,
                                resize_option=args.form_batches, resize_shape=args.resize_shape,
                                train_list=train_list, val_list=val_list, test_list=test_list)
    else:
        train_imgs = glob(os.path.join(args.img_dir, 'train', '*', '*.jpg'))
        assert train_imgs, 'No images found. Please set --create_images to 1 to check your --data_root_path.'
        train_shape = list(cv2.imread(train_imgs[0]).shape[:2])
        train_samples = len(train_imgs)
        val_samples = len(glob(os.path.join(args.img_dir, 'val', '*',
                                            '*.jpg')))
        test_samples = len(
            glob(os.path.join(args.img_dir, 'test', '*', '*.jpg')))

    if args.only_create_images:
        print('Finished creating images, exiting.')
        exit(0)

    if args.resize_shape[0] is not None:
        train_shape[0] = args.resize_shape[0]
    if args.resize_shape[1] is not None:
        train_shape[1] = args.resize_shape[1]

    # Assume 6 downsamplings: round each spatial dim down to a multiple of 2**6 = 64
    train_shape = val_shape = test_shape = (train_shape[0] // (2**6) * (2**6),
                                            train_shape[1] // (2**6) * (2**6))
    net_input_shape = (train_shape[0], train_shape[1], net_input_shape[2])
    model_list = create_model(args=args, input_shape=net_input_shape)
    model_list[0].summary()

    # Run the chosen functions
    if args.train:
        from train import train
        # Run training
        train(args=args,
              u_model=model_list[0],
              train_samples=train_samples,
              val_samples=val_samples,
              train_shape=train_shape,
              val_shape=val_shape)

    if args.test:
        from test import test
        # Run testing
        test_model = (model_list[1]
                      if args.net.find('caps') != -1 else model_list[0])
        test(args=args,
             u_model=test_model,
             val_samples=val_samples,
             val_shape=val_shape,
             test_samples=test_samples,
             test_shape=test_shape)

    if args.manip and args.net.find('caps') != -1:
        from manip import manip
        # Run manipulation of d-caps
        manip(args, test_list, model_list[2])

    if args.pred:
        try:
            with open(
                    os.path.join(args.root_dir, 'split_lists',
                                 'pred_split_' + str(args.split_num) + '.csv'),
                    'r') as f:
                reader = csv.reader(f)
                pred_list = list(reader)
            from predict import predict
            predict(args, pred_list, model_list, net_input_shape)
        except Exception as e:
            print(e)
            print(
                'Unable to load prediction list inside main.py, skipping Predict.'
            )
Example #18
# Script for training the model. Training and validation images should be in subfolders of the directories specified in the paths below.
# Each subfolder name represents the name of the class.
train_dir = "D:\Data\TRAIN"
validation_dir = "D:\Data\TEST"

# saving paths
trained_model_path = "model/test2-L2/model_arh.json"
trained_model_weights_path = "model/test2-L2/"

train_batchsize = 20
val_batchsize = 5
target_image_size = 128
num_of_epochs = 16

# change architecture in model helper if needed
model = model_helper.create_model(n_classes=25)

# TODO load all accuracy and loss infos from previous training if loading weights
#model.load_weights('model_w_102.h5')

# Show a summary of the model. Check the number of trainable parameters
#model.summary()

train_datagen = ImageDataGenerator(rescale=1. / 255,
                                   rotation_range=20,
                                   width_shift_range=0.25,
                                   height_shift_range=0.25,
                                   horizontal_flip=True,
                                   fill_mode='nearest')

train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(target_image_size, target_image_size),
    batch_size=train_batchsize,
    class_mode='categorical')  # arguments assumed; the original snippet ends mid-call
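The snippet ends mid-script; a sketch of how it would typically continue with Keras 2's generator API (the generator and training-call arguments are assumptions based on the settings above):

validation_datagen = ImageDataGenerator(rescale=1. / 255)  # no augmentation for validation
validation_generator = validation_datagen.flow_from_directory(
    validation_dir,
    target_size=(target_image_size, target_image_size),
    batch_size=val_batchsize,
    class_mode='categorical',
    shuffle=False)

history = model.fit_generator(
    train_generator,
    steps_per_epoch=train_generator.samples // train_batchsize,
    epochs=num_of_epochs,
    validation_data=validation_generator,
    validation_steps=validation_generator.samples // val_batchsize)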
Example #19
def main(args):
    # Ensure training, testing, and manip are not all turned off
    assert (
        args.train or args.test or args.manip
    ), 'Cannot have train, test, and manip all set to 0, Nothing to do.'

    # Load the training, validation, and testing data
    try:
        train_list, val_list, test_list = load_data(args.data_root_dir,
                                                    args.split_num)
    except Exception:
        # Create the training and test splits if not found
        split_data(args.data_root_dir, num_splits=args.num_splits)
        train_list, val_list, test_list = load_data(args.data_root_dir,
                                                    args.split_num)

    # Get image properties from first image. Assume they are all the same.
    print(train_list)
    img_shape = sitk.GetArrayFromImage(
        sitk.ReadImage(join(args.data_root_dir, 'imgs',
                            train_list[0][0]))).shape
    print("Shape: " + str(img_shape))

    args.modalities = 1
    net_input_shape = (img_shape[1], img_shape[2],
                       args.slices * args.modalities)
    if args.dataset == 'brats':
        args.modalities = 4
        net_input_shape = (256, 256, args.slices * args.modalities)

    # Create the model for training/testing/manipulation
    model_list = create_model(args=args, input_shape=net_input_shape)
    print_summary(model=model_list[0], positions=[.38, .65, .75, 1.])

    args.output_name = 'split-' + str(args.split_num) + '_batch-' + str(args.batch_size) + \
                       '_shuff-' + str(args.shuffle_data) + '_aug-' + str(args.aug_data) + \
                       '_loss-' + str(args.loss) + '_slic-' + str(args.slices) + \
                       '_sub-' + str(args.subsamp) + '_strid-' + str(args.stride) + \
                       '_lr-' + str(args.initial_lr) + '_recon-' + str(args.recon_wei)
    args.time = time

    args.check_dir = join(args.data_root_dir, 'saved_models', args.net)
    makedirs(args.check_dir, exist_ok=True)

    args.log_dir = join(args.data_root_dir, 'logs', args.net)
    makedirs(args.log_dir, exist_ok=True)

    args.tf_log_dir = join(args.log_dir, 'tf_logs')
    makedirs(args.tf_log_dir, exist_ok=True)

    args.output_dir = join(args.data_root_dir, 'plots', args.net)
    makedirs(args.output_dir, exist_ok=True)

    if args.train:
        from train import train
        # Run training
        train(args,
              train_list,
              val_list,
              model_list[0],
              net_input_shape,
              num_output_classes=args.out_classes)
        args.weights_path = ''

    if args.test:
        if args.dataset == 'luna':
            from test import test
            test(args, test_list, model_list, net_input_shape)
        else:
            from test_multiclass import test
            test(args, test_list, model_list, net_input_shape)

    if args.manip:
        from manip import manip
        # Run manipulation of segcaps
        manip(args, test_list, model_list, net_input_shape)
Example #20
def main(args):
    # Ensure training, testing, and manip are not all turned off
    assert (args.train or args.test or args.manip), 'Cannot have train, test, and manip all set to 0, Nothing to do.'

    # Load the training, validation, and testing data
    try:
        train_list, val_list, test_list = load_data(args.data_root_dir, args.split_num)
    except Exception:
        # Create the training and test splits if not found
        split_data(args.data_root_dir, num_splits=4)
        train_list, val_list, test_list = load_data(args.data_root_dir, args.split_num)

    # Get image properties from first image. Assume they are all the same.
    #img_shape = sitk.GetArrayFromImage(sitk.ReadImage(join(args.data_root_dir, 'imgs', train_list[0][0]))).shape
    im = cv2.imread(join(args.data_root_dir, 'imgs', train_list[0][0]))
    gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
    #im = cv2.imread(join(img_path, img_name), -1)
    #gray = im
    gray = np.reshape(gray, (1, gray.shape[0], gray.shape[1]))
    img_shape = gray.shape
    net_input_shape = (512, 512, args.slices)

    # Create the model for training/testing/manipulation
    model_list = create_model(args=args, input_shape=net_input_shape)
    print_summary(model=model_list[0], positions=[.38, .65, .75, 1.])
    model_list[0].load_weights('liver_data3/saved_models_original/segcapsr3/model5.hdf5')
    #print('model20')
    #model_list[0].load_weights('C:/Users/212673708/Documents/TAU/semesterB/DeepLearningInMedicalImaging/modelLiver.hdf5')
    args.output_name = 'split-' + str(args.split_num) + '_batch-' + str(args.batch_size) + \
                       '_shuff-' + str(args.shuffle_data) + '_aug-' + str(args.aug_data) + \
                       '_loss-' + str(args.loss) + '_slic-' + str(args.slices) + \
                       '_sub-' + str(args.subsamp) + '_strid-' + str(args.stride) + \
                       '_lr-' + str(args.initial_lr) + '_recon-' + str(args.recon_wei)
    args.time = time

    args.check_dir = join(args.data_root_dir, 'saved_models', args.net)
    makedirs(args.check_dir, exist_ok=True)

    args.log_dir = join(args.data_root_dir, 'logs', args.net)
    makedirs(args.log_dir, exist_ok=True)

    args.tf_log_dir = join(args.log_dir, 'tf_logs')
    makedirs(args.tf_log_dir, exist_ok=True)

    args.output_dir = join(args.data_root_dir, 'plots', args.net)
    makedirs(args.output_dir, exist_ok=True)

    if args.train:
        print('training')
        from train import train
        # Run training
        while ([] in train_list):
            train_list.remove([])
        CUDA_VISIBLE_DEVICES = 1  # NOTE: a plain assignment has no effect; set the environment variable before TensorFlow initializes instead
        train(args, train_list, val_list, model_list[0], net_input_shape)

    if args.test:
        from test import test
        # Run testing
        while ([] in test_list):
            test_list.remove([])
        i = 5
        args.weights_path = 'saved_models_original/segcapsr3/model' + str(i) + '.hdf5'
        test(args, test_list, model_list, net_input_shape, i)

    if args.manip:
        from manip import manip
        # Run manipulation of segcaps
        manip(args, test_list, model_list, net_input_shape)
Example #21
def main(args):
    # Ensure training, testing, and manip are not all turned off
    assert (args.train or args.test or args.pred), \
        'Cannot have train, test, and pred all set to 0, Nothing to do.'

    # Set the output name to the experiment name, testing split, and all relevant hyperparameters for easy reference.
    args.output_name = 'exp-{}_split-{}_batch-{}_shuff-{}_aug-{}_loss-{}_slic-{}_sub-{}_strid-{}_lr-{}_' \
                       'dpre-{}_fbase-{}_dis-{}'.format(
                        args.exp_name, args.split_num, args.batch_size, args.shuffle_data, args.aug_data, args.loss,
                        args.slices, args.subsamp, args.stride, args.initial_lr, args.use_default_pretrained,
                        args.freeze_base_weights, args.disentangled)
    args.time = time

    # Set the name for the fusion strategy
    if args.disentangled:
        fusion_type = 'disen'
    else:
        fusion_type = 'early'

    # Create the directory for saving model checkpoints
    args.check_dir = os.path.join(args.data_root_dir, 'saved_models', args.exp_name, args.net + '_' + fusion_type)
    os.makedirs(args.check_dir, exist_ok=True)

    # Create the directory for saving csv logs
    args.log_dir = os.path.join(args.data_root_dir, 'logs', args.exp_name, args.net + '_' + fusion_type)
    os.makedirs(args.log_dir, exist_ok=True)

    # Create the directory for saving TF logs
    args.tf_log_dir = os.path.join(args.log_dir, 'tf_logs')
    os.makedirs(args.tf_log_dir, exist_ok=True)

    # Create the directory for saving train/test plots
    args.output_dir = os.path.join(args.data_root_dir, 'plots', args.exp_name, args.net + '_' + fusion_type, args.time)
    os.makedirs(args.output_dir, exist_ok=True)

    # Load images for this split
    all_imgs_list = []
    if (args.train or args.test):
        train_list, val_list, test_list = load_data(root=args.data_root_dir, mod_dirs=args.modality_dir_list,
                                                    exp_name=args.exp_name, split=args.split_num,
                                                    k_folds=args.k_fold_splits, val_split=args.val_split,
                                                    rand_seed=args.rand_seed)

        # Print the images selected for validation
        all_imgs_list = all_imgs_list + list(train_list) + list(val_list) + list(test_list)
        print('\nFound a total of {} images with labels.'.format(len(all_imgs_list)))
        print('\t{} images for training.'.format(len(train_list)))
        print('\t{} images for validation.'.format(len(val_list)))
        print('\t{} images for testing.'.format(len(test_list)))
        print('\nRandomly selected validation images:')
        print(val_list)
        print('\n')

    # If the user selected to do predictions (no GT), load those prediction images
    if args.pred:
        with open(os.path.join(args.data_root_dir, 'split_lists', args.exp_name,
                               'pred_split_{}.csv'.format(args.split_num)), 'r') as f:
            reader = csv.reader(f)
            pred_list = list(reader)
        all_imgs_list = all_imgs_list + list(pred_list)

    # This creates all images up front instead of dynamically during training. Beneficial if paying for GPU hours.
    if args.create_all_imgs:
        print('-' * 98, '\nCreating all images... This will take some time.\n', '-' * 98)
        for img_pairs in tqdm(all_imgs_list):
            _, _ = convert_data_to_numpy(root_path=args.data_root_dir, img_names=img_pairs,
                                         mod_dirs=args.modality_dir_list, exp_name=args.exp_name,
                                         no_masks=False, overwrite=True)

    # Set the network input shape depending on using a 2D or 3D network.
    if args.net.find('3d') != -1 or args.net.find('inflated') != -1:
        net_input_shape = [None, None, args.slices, 3 * (len(args.modality_dir_list.split(',')) - 1)]
    else:
        net_input_shape = [None, None, 3 * (len(args.modality_dir_list.split(',')) - 1)]

    # Create the model for training/testing/manipulation
    train_shape = list(sitk.ReadImage(os.path.join(args.data_root_dir, all_imgs_list[0][0])).GetSize()[:-1])  # list() so the dims can be overwritten below
    args.resize_shape = [args.resize_hei, args.resize_wid]
    if args.resize_shape[0] is not None:
        train_shape[0] = args.resize_shape[0]
    if args.resize_shape[1] is not None:
        train_shape[1] = args.resize_shape[1]

    # Update the network input shape based on our data size
    train_shape = (train_shape[0] // (2 ** 6) * (2 ** 6), train_shape[1] // (2 ** 6) * (2 ** 6))  # Assume 6 downsamples
    net_input_shape[0] = train_shape[0]
    net_input_shape[1] = train_shape[1]
    model = create_model(args=args, input_shape=net_input_shape)

    # Print the model summary (try except is for Keras version compatibility)
    try:
        from keras.utils import print_summary
        print_summary(model=model, positions=[.38, .69, .8, 1.])
    except Exception:
        model.summary()

    if args.train:
        # Run training
        print('-' * 98, '\nRunning Training\n', '-' * 98)
        from train_class import train
        train(args, train_list, val_list, model, net_input_shape)

    if args.test:
        # Run testing
        print('-' * 98, '\nRunning Testing\n', '-' * 98)
        from test_class import test
        test(args, test_list, model, net_input_shape)

    if args.pred:
        # Run prediction on new data
        print('-' * 98, '\nRunning Prediction\n', '-' * 98)
        from predict_class import predict
        predict(args, pred_list, model, net_input_shape)