def train(args, train_list, val_list, u_model, net_input_shape):
    # Compile the loaded model
    model, loss_weights = compile_model(args=args, train_list=train_list, net_input_shape=net_input_shape,
                                        uncomp_model=u_model)

    # Load pre-trained weights
    if args.custom_weights_path != '':
        try:
            model.load_weights(args.custom_weights_path)
        except Exception as e:
            print(e)
            print('!!! Failed to load weights file. Training without pre-trained weights. !!!')

    # Set the callbacks
    callbacks = get_callbacks(args)

    # Training the network
    history = model.fit_generator(
        generate_train_batches(root_path=args.data_root_dir, train_list=train_list, net_shape=net_input_shape,
                               mod_dirs=args.modality_dir_list, exp_name=args.exp_name, net=args.net,
                               MIP_choices=args.MIP_choices, n_class=args.num_classes, batchSize=args.batch_size,
                               numSlices=args.slices, subSampAmt=args.subsamp, stride=args.stride,
                               shuff=args.shuffle_data, aug_data=args.aug_data),
        max_queue_size=40, workers=4, use_multiprocessing=False,
        steps_per_epoch=int(np.ceil(len(train_list)/args.batch_size*12)), # 12 = avg. number of batches each volume yields in the train generator
        validation_data=generate_val_batches(root_path=args.data_root_dir, val_list=val_list, net_shape=net_input_shape,
                                             mod_dirs=args.modality_dir_list, exp_name=args.exp_name, net=args.net,
                                             MIP_choices=args.MIP_choices, n_class=args.num_classes,
                                             batchSize=args.batch_size, numSlices=args.slices, subSampAmt=0,
                                             stride=args.stride, shuff=args.shuffle_data),
        validation_steps=int(np.ceil(len(val_list)/args.batch_size)),
        epochs=args.epochs, class_weight=loss_weights, callbacks=callbacks, verbose=args.verbose)

    # Plot the training data collected
    plot_training(history, args.net, args.num_classes, args.output_dir, args.output_name, args.time)
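
A quick check on the steps_per_epoch arithmetic above: because Python evaluates left to right, the expression is ceil((len(train_list) / batch_size) * 12), not ceil(len(train_list) / batch_size) * 12. A minimal sketch with hypothetical numbers (150 volumes, batch size 4; neither value is from the repo):

import numpy as np

train_list_len, batch_size = 150, 4  # hypothetical values, not from the repo
steps = int(np.ceil(train_list_len / batch_size * 12))
print(steps)  # 450 == ceil(37.5 * 12); note ceil(37.5) * 12 would give 456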
Example #2
def train(args, train_list, val_list, u_model, net_input_shape):
    # Compile the loaded model
    num_epoch = 400
    model = compile_model(args=args, net_input_shape=net_input_shape, uncomp_model=u_model)
    loss_vec = np.zeros((num_epoch, 1))
    dice_hard_vec = np.zeros((num_epoch, 1))
    val_loss_vec = np.zeros((num_epoch, 1))
    val_dice_hard_vec = np.zeros((num_epoch, 1))
    val_out_seg_loss_vec = np.zeros((num_epoch, 1))
    out_seg_loss_vec = np.zeros((num_epoch, 1))
    val_out_recon_loss_vec = np.zeros((num_epoch, 1))
    out_recon_loss_vec = np.zeros((num_epoch, 1))

    # Train one epoch at a time so the callbacks can be refreshed each pass
    for i in range(num_epoch):
        print(i)
        callbacks = get_callbacks(args, i)

        # Training the network
        train_batches = generate_train_batches(args.data_root_dir, train_list, net_input_shape, net=args.net, image_shape=net_input_shape,
                                               batchSize=args.batch_size, numSlices=args.slices,
                                               subSampAmt=args.subsamp,
                                               stride=args.stride, shuff=args.shuffle_data, aug_data=args.aug_data)
        val_data = generate_val_batches(args.data_root_dir, val_list, net_input_shape, net=args.net,
                                        batchSize=args.batch_size, numSlices=args.slices, subSampAmt=0,
                                        stride=20, shuff=args.shuffle_data)
        history = model.fit_generator(train_batches.it, max_queue_size=40, workers=4, use_multiprocessing=False,
                                      steps_per_epoch=247, validation_data=val_data.it, validation_steps=200,
                                      callbacks=callbacks, verbose=1)

        # Record the metrics collected this epoch
        if args.net == 'segcapsr3':
            print(history.history.keys())
            loss_vec[i] = history.history['loss'][0]
            dice_hard_vec[i] = history.history['out_seg_dice_hard'][0]
            val_loss_vec[i] = history.history['val_loss'][0]
            val_dice_hard_vec[i] = history.history['val_out_seg_dice_hard'][0]
            val_out_seg_loss_vec[i] = history.history['val_out_seg_loss'][0]
            out_seg_loss_vec[i] = history.history['out_seg_loss'][0]
            val_out_recon_loss_vec[i] = history.history['val_out_recon_loss'][0]
            out_recon_loss_vec[i] = history.history['out_recon_loss'][0]
            with open(join(args.output_dir, '_errors_' + str(i) + '.txt'), "w+") as f:
                f.writelines(
                    [str(loss_vec), str(dice_hard_vec), str(out_seg_loss_vec), str(out_recon_loss_vec),
                     str(val_loss_vec), str(val_dice_hard_vec), str(val_out_seg_loss_vec),
                     str(val_out_recon_loss_vec)])
        else:
            loss_vec[i] = history.history['loss'][0]
            dice_hard_vec[i] = history.history['dice_hard'][0]
            val_loss_vec[i] = history.history['val_loss'][0]
            val_dice_hard_vec[i] = history.history['val_dice_hard'][0]
            with open(join(args.output_dir, '_errors_' + str(i) + '.txt'), "w+") as f:
                f.writelines([str(loss_vec), str(dice_hard_vec), str(val_loss_vec), str(val_dice_hard_vec)])
    plot_training(loss_vec, dice_hard_vec, val_loss_vec, val_dice_hard_vec, args, num_epoch)
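
The manual '_errors_<i>.txt' bookkeeping above re-serializes every metric vector on each epoch. A hedged alternative, not from the original repo, is Keras's built-in CSVLogger callback, which appends one row of all history metrics per epoch:

from os.path import join
from keras.callbacks import CSVLogger

# Minimal sketch: model, args, num_epoch, and the generators are assumed
# from the example above; CSVLogger replaces the per-epoch text-file writes.
csv_logger = CSVLogger(join(args.output_dir, 'training_log.csv'), append=True)
history = model.fit_generator(train_batches.it, steps_per_epoch=247,
                              epochs=num_epoch, validation_data=val_data.it,
                              validation_steps=200,
                              callbacks=get_callbacks(args, 0) + [csv_logger],
                              verbose=1)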
Example #3
def train(args, train_list, val_list, u_model, net_input_shape):
    # Compile the loaded model
    model = compile_model(args=args,
                          net_input_shape=net_input_shape,
                          uncomp_model=u_model)
    weights_path = args.weights_path
    if args.retrain == 1:
        print('\nRetrain model from weights_path=%s' % (weights_path))
        model.load_weights(weights_path)
    # Set the callbacks
    callbacks = get_callbacks(args)

    # Training the network
    history = model.fit(
        generate_train_batches(args.data_root_dir,
                               train_list,
                               net_input_shape,
                               net=args.net,
                               batchSize=args.batch_size,
                               numSlices=args.slices,
                               subSampAmt=args.subsamp,
                               stride=args.stride,
                               shuff=args.shuffle_data,
                               aug_data=args.aug_data),
        max_queue_size=40,
        workers=4,
        use_multiprocessing=False,
        steps_per_epoch=args.steps,
        validation_data=generate_val_batches(args.data_root_dir,
                                             val_list,
                                             net_input_shape,
                                             net=args.net,
                                             batchSize=args.batch_size,
                                             numSlices=args.slices,
                                             subSampAmt=0,
                                             stride=20,
                                             shuff=args.shuffle_data),
        validation_steps=250,  # Set validation stride larger to see more of the data.
        epochs=args.epochs,
        callbacks=callbacks,
        verbose=1)

    # Plot the training data collected
    plot_training(history, args)
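
None of these excerpts show get_callbacks. As a hedged sketch, a helper like this usually bundles standard Keras callbacks; the monitor choices, patience values, and the args.check_dir attribute here are assumptions, not the original code:

from os.path import join
from keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau

def get_callbacks_sketch(args):
    # Save the best weights seen so far, judged by validation loss.
    # args.check_dir is a hypothetical attribute for illustration.
    checkpoint = ModelCheckpoint(join(args.check_dir, args.output_name + '.hdf5'),
                                 monitor='val_loss', save_best_only=True,
                                 save_weights_only=True, verbose=1)
    # Stop if validation loss stalls for 25 epochs.
    early_stop = EarlyStopping(monitor='val_loss', patience=25, verbose=1)
    # Halve the learning rate when validation loss plateaus.
    lr_reduce = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=10,
                                  verbose=1)
    return [checkpoint, early_stop, lr_reduce]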
Example #4
def train(args, train_list, val_list, u_model, net_input_shape):
    # Compile the loaded model
    model = compile_model(args=args,
                          net_input_shape=net_input_shape,
                          uncomp_model=u_model)
    # Set the callbacks
    callbacks = get_callbacks(args)
    # print(callbacks)
    train_batches = generate_train_batches(args.data_root_dir,
                                           train_list,
                                           net_input_shape,
                                           net=args.net,
                                           batchSize=args.batch_size,
                                           numSlices=args.slices,
                                           subSampAmt=args.subsamp,
                                           stride=args.stride,
                                           shuff=args.shuffle_data,
                                           aug_data=args.aug_data)

    val_batches = generate_val_batches(args.data_root_dir,
                                       val_list,
                                       net_input_shape,
                                       net=args.net,
                                       batchSize=args.batch_size,
                                       numSlices=args.slices,
                                       subSampAmt=0,
                                       stride=20,
                                       shuff=args.shuffle_data)
    print("train_batches..", train_batches)
    # Training the network
    history = model.fit_generator(
        train_batches,
        max_queue_size=40,
        workers=4,
        use_multiprocessing=False,
        steps_per_epoch=10000,
        validation_data=val_batches,
        validation_steps=500,  # Set validation stride larger to see more of the data.
        epochs=200,
        callbacks=callbacks,
        verbose=1)

    # Plot the training data collected
    plot_training(history, args)
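
generate_train_batches itself is also not shown in any excerpt. The contract fit_generator relies on is only that the argument is a Python generator yielding (inputs, targets) batches forever; Keras draws steps_per_epoch batches per epoch. A minimal sketch of that shape, with placeholder loading logic rather than the real volume slicing:

import numpy as np

def batch_generator_sketch(file_list, net_input_shape, batchSize=1):
    # Loop forever; fit_generator stops pulling after steps_per_epoch batches.
    while True:
        for start in range(0, len(file_list), batchSize):
            chunk = file_list[start:start + batchSize]
            # Placeholder arrays; the real generator reads, slices, and
            # augments image volumes from the data root directory.
            x = np.zeros((len(chunk),) + tuple(net_input_shape), dtype=np.float32)
            y = np.zeros((len(chunk),) + tuple(net_input_shape), dtype=np.float32)
            yield x, y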
Example #5
def train(args, train_list, val_list, u_model, net_input_shape):
    # Compile the loaded model
    model = compile_model(args=args, net_input_shape=net_input_shape, uncomp_model=u_model)
    # Set the callbacks
    callbacks = get_callbacks(args)

    # Training the network
    history = model.fit_generator(
        generate_train_batches(args.data_root_dir, train_list, net_input_shape, net=args.net,
                               batchSize=args.batch_size, numSlices=args.slices, subSampAmt=args.subsamp,
                               stride=args.stride, shuff=args.shuffle_data, aug_data=args.aug_data),
        max_queue_size=40, workers=4, use_multiprocessing=False,
        steps_per_epoch=10000,
        validation_data=generate_val_batches(args.data_root_dir, val_list, net_input_shape, net=args.net,
                                             batchSize=args.batch_size,  numSlices=args.slices, subSampAmt=0,
                                             stride=20, shuff=args.shuffle_data),
        validation_steps=500, # Set validation stride larger to see more of the data.
        epochs=200,
        callbacks=callbacks,
        verbose=1)

    # Plot the training data collected
    plot_training(history, args)
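
Model.fit_generator, used in most examples here, has been deprecated since TensorFlow 2.1: Model.fit now accepts the same generators and the max_queue_size/workers/use_multiprocessing keywords directly. Example #6 below already uses that form; the migration is mechanical:

# Deprecated spelling (standalone Keras / early TF 2.x):
history = model.fit_generator(train_gen, steps_per_epoch=10000, epochs=200,
                              validation_data=val_gen, validation_steps=500,
                              callbacks=callbacks, verbose=1)

# Equivalent call on TF 2.1+; train_gen and val_gen are the same generators.
history = model.fit(train_gen, steps_per_epoch=10000, epochs=200,
                    validation_data=val_gen, validation_steps=500,
                    callbacks=callbacks, verbose=1)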
Example #6
def train(args, train_list, val_list, u_model, net_input_shape):
    # Compile the loaded model
    model = compile_model(args=args,
                          net_input_shape=net_input_shape,
                          uncomp_model=u_model)
    # Set the callbacks
    callbacks = get_callbacks(args)

    # Training the network
    history = model.fit(
        generate_train_batches(args.data_root_dir,
                               train_list,
                               net_input_shape,
                               net=args.net,
                               batch_size=args.batch_size,
                               shuff=args.shuffle_data,
                               aug_data=args.aug_data),
        max_queue_size=40,
        workers=0,
        use_multiprocessing=False,
        steps_per_epoch=ceil(len(train_list) / args.batch_size),
        validation_data=generate_val_batches(args.data_root_dir,
                                             val_list,
                                             net_input_shape,
                                             net=args.net,
                                             batch_size=args.batch_size,
                                             shuff=args.shuffle_data),
        validation_steps=ceil(len(val_list) / args.batch_size),  # one full pass over the validation set
        epochs=50,
        callbacks=callbacks,
        verbose=1)

    # Plot the training data collected
    plot_training(history, args)
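
Example #6 runs with workers=0, so every batch is produced on the main thread. A hedged alternative, not part of the original repo, is to wrap the generator in tf.data (TF 2.4+ for output_signature) and let prefetching overlap loading with training; the float32 dtypes and the matching x/y shapes here are assumptions:

import tensorflow as tf
from math import ceil

# Minimal sketch: model, args, train_list, net_input_shape, callbacks, and
# generate_train_batches are assumed from the example above.
sig = tf.TensorSpec(shape=(None,) + tuple(net_input_shape), dtype=tf.float32)
train_ds = tf.data.Dataset.from_generator(
    lambda: generate_train_batches(args.data_root_dir, train_list,
                                   net_input_shape, net=args.net,
                                   batch_size=args.batch_size,
                                   shuff=args.shuffle_data,
                                   aug_data=args.aug_data),
    output_signature=(sig, sig)).prefetch(tf.data.AUTOTUNE)
history = model.fit(train_ds,
                    steps_per_epoch=ceil(len(train_list) / args.batch_size),
                    epochs=50, callbacks=callbacks, verbose=1)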
Example #7
def train(args, model, train_list, net_input_shape):
    """
    # Compile the loaded model
    model = compile_model(args=args, net_input_shape=net_input_shape, uncomp_model=u_model)
    # Set the callbacks
    callbacks = get_callbacks(args)

    # Training the network
    history = model.fit_generator(
        generate_train_batches(args.data_root_dir, train_list, net_input_shape, net=args.net,
                               batchSize=args.batch_size, numSlices=args.slices, subSampAmt=args.subsamp,
                               stride=args.stride, shuff=args.shuffle_data, aug_data=args.aug_data),
        max_queue_size=40, workers=4, use_multiprocessing=False,
        steps_per_epoch=10000,
        validation_data=generate_val_batches(args.data_root_dir, val_list, net_input_shape, net=args.net,
                                             batchSize=args.batch_size,  numSlices=args.slices, subSampAmt=0,
                                             stride=20, shuff=args.shuffle_data),
        validation_steps=500, # Set validation stride larger to see more of the data.
        epochs=200,
        callbacks=callbacks,
        verbose=1)
    """

    model.cuda()

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.initial_lr,
                                 betas=(0.99, 0.999))
    # Note: this scheduler is created but never stepped in the loop below;
    # ReduceLROnPlateau expects scheduler.step(metric) once per epoch.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode='max',
                                                           factor=0.05,
                                                           patience=5,
                                                           verbose=True)

    loss, loss_weighting = get_loss(root=args.data_root_dir,
                                    split=args.split_num,
                                    net=args.net,
                                    recon_wei=args.recon_wei,
                                    choice=args.loss)

    recon_loss = nn.MSELoss(reduction='sum')

    fit_generator = generate_train_batches(args.data_root_dir,
                                           train_list,
                                           net_input_shape,
                                           net=args.net,
                                           batchSize=args.batch_size,
                                           numSlices=args.slices,
                                           subSampAmt=args.subsamp,
                                           stride=args.stride,
                                           shuff=args.shuffle_data,
                                           aug_data=args.aug_data)

    factor = 1.

    for i, batch in enumerate(fit_generator):
        # Convert each numpy batch from NHWC to NCHW for PyTorch.
        x = torch.from_numpy(batch[0][0]).float().permute(0, 3, 1, 2)
        x = Variable(x).cuda()  # Variable is a no-op wrapper in PyTorch >= 0.4
        x1 = torch.from_numpy(batch[0][1]).float().permute(0, 3, 1, 2)
        x1 = Variable(x1).cuda()
        y = torch.from_numpy(batch[1][0]).float().permute(0, 3, 1, 2)
        y = y.cuda()
        label_recon = torch.from_numpy(batch[1][1]).float().permute(0, 3, 1, 2)
        label_recon = label_recon.cuda()

        optimizer.zero_grad()

        out_seg, out_recon = model(x, x1)

        loss_a = loss['out_seg'](y, out_seg)

        if (i % 100) == 0:
            save_image(x, 'x.jpg')
            save_image(y, 'y.jpg')
            save_image(label_recon, 'label_recon.jpg')
            save_image(out_seg, 'out_seg.jpg')
            save_image(loss_a, 'loss_a.jpg')
            save_image(out_recon, 'out_recon.jpg')

        loss_a = loss_a.sum()
        loss_b = recon_loss(out_recon, label_recon)
        total_loss = loss_a + factor * loss_b

        total_loss.backward()
        optimizer.step()

        print("batch:", i, "loss a:", loss_a.data.item(), "loss b:",
              loss_b.data.item(), "check:",
              out_recon.sum().data.item())
        """