Example #1
    def test(self):
        self.batch_size = 1
        test_images = tf.placeholder(dtype=tf.float32,
                                     shape=[self.batch_size, self.IM_ROWS, self.IM_COLS, self.IM_DEPTH])

        test_gt = tf.placeholder(dtype=tf.float32,
                                 shape=[self.batch_size, self.IM_ROWS, self.IM_COLS, self.IM_DEPTH])

        u_test, _, _, _, _, _, _, _, _ = self.PD_v2(test_images, reuse=tf.AUTO_REUSE, training=False)

        saver = tf.train.Saver()

        with tf.name_scope('All_Metrics'):

            dice_coe_test = dice_coe(output=u_test,target=test_gt)
            iou_coe_test = iou_coe(output=u_test,target=test_gt)

        with tf.Session() as sess:

            sess.run(tf.global_variables_initializer())

            if self.restore:
                saver.restore(sess, self.checkpoint)

            t_f, v_f = self.create_list(data_path=self.data_dir)
            data_test = self.load_data_valid(data_path=self.data_dir, list=v_f)


            # accumulate per-sample Dice and IoU over the test set
            dice_test_list = []
            iou_test_list = []
            for test_iter in tqdm(range(self.num_test_samples)):
                test_i, test_m = next(data_test)

                dice_coe_test_value, iou_coe_test_value = sess.run(
                    [dice_coe_test, iou_coe_test],
                    feed_dict={test_images: test_i, test_gt: test_m})
                dice_test_list.append(dice_coe_test_value)
                iou_test_list.append(iou_coe_test_value)

            if not os.path.exists('Test'):
                os.makedirs('Test')

            np.savez('Test/Test_' + self.model_name + '.npz',
                     dice=dice_test_list,
                     iou=iou_test_list,
                     dice_avg=np.average(dice_test_list),
                     iou_avg=np.average(iou_test_list))


            data = np.load('Test/Test_'+self.model_name+'.npz')
            print(data['dice'])
            print(data['iou'])
            print(data['dice_avg'])
            print(data['iou_avg'])
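
The dice_coe and iou_coe helpers called in these examples are not listed on this page. As a rough guide only, a minimal soft-Dice / soft-IoU sketch in TensorFlow 1.x (hypothetical implementation, assuming 4-D binary masks of shape [batch, rows, cols, 1]) could look like this:

import tensorflow as tf

def dice_coe(output, target, eps=1e-5):
    # soft Dice: 2 * intersection / (|output| + |target|), averaged over the batch
    inter = tf.reduce_sum(output * target, axis=[1, 2, 3])
    sums = tf.reduce_sum(output, axis=[1, 2, 3]) + tf.reduce_sum(target, axis=[1, 2, 3])
    return tf.reduce_mean((2. * inter + eps) / (sums + eps))

def iou_coe(output, target, eps=1e-5):
    # soft IoU: intersection / union, averaged over the batch
    inter = tf.reduce_sum(output * target, axis=[1, 2, 3])
    union = tf.reduce_sum(output + target - output * target, axis=[1, 2, 3])
    return tf.reduce_mean((inter + eps) / (union + eps))
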
Example #2
def validation(trained_net, val_set, criterion, device, batch_size):
    '''
    used for validation during the training phase

    params trained_net: trained U-net
    params val_set: validation dataset
    params criterion: loss function
    params device: cpu or gpu
    params batch_size: validation batch size
    '''

    n_val = len(val_set)
    val_loader = val_set.load()

    tot = 0
    acc = 0

    with tqdm(total=n_val, desc='Validation round', unit='patch',
              leave=False) as pbar:
        with torch.no_grad():
            for i, sample in enumerate(val_loader):
                images, segs = sample['image'].to(
                    device=device), sample['seg'].to(device=device)

                preds = trained_net(images)
                val_loss = criterion(preds, segs)
                dice_score = dice_coe(preds.detach().cpu(),
                                      segs.detach().cpu())

                tot += val_loss.detach().item()
                acc += dice_score['avg']

                pbar.set_postfix(
                    **{
                        'validation loss (images)': val_loss.detach().item(),
                        'val_acc': dice_score['avg']
                    })
                pbar.update(images.shape[0])

    return tot / (np.ceil(n_val / batch_size)), acc / (np.ceil(
        n_val / batch_size))
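
In this PyTorch example dice_coe returns a dictionary with per-class scores and an 'avg' key (Example #4 reads 'BG', 'NET', 'ED' and 'ET' from it). A minimal sketch under those assumptions, with hypothetical tensor shapes (predictions N x C x D x H x W, integer label masks N x D x H x W), might be:

import torch

def dice_coe(preds, targets, labels=('BG', 'NET', 'ED', 'ET'), eps=1e-5):
    # hard Dice per class, computed on the argmax of the network output
    scores = {}
    pred_labels = preds.argmax(dim=1)
    for idx, name in enumerate(labels):
        p = (pred_labels == idx).float()
        t = (targets == idx).float()
        inter = (p * t).sum()
        scores[name] = ((2. * inter + eps) / (p.sum() + t.sum() + eps)).item()
    scores['avg'] = sum(scores[name] for name in labels) / len(labels)
    return scores
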
Example #3
def predict(trained_net,
            model_name,
            test_patient,
            root,
            crop_size=98,
            overlap_size=0,
            save_mask=True):
    ''' used for predicting image segmentation after training '''

    # load test image
    patient_name = os.path.basename(test_patient)
    modality_dir = os.listdir(test_patient)
    image = []
    for modality in modality_dir:
        if modality != patient_name + '_seg.nii.gz':
            path = os.path.join(test_patient, modality)
            img = sitk.GetArrayFromImage(sitk.ReadImage(path))
            image.append(img)
        else:
            path = os.path.join(test_patient, modality)
            seg = sitk.ReadImage(path)
            seg_arr = sitk.GetArrayFromImage(seg)

    image = np.stack(image)  # C*D*H*W

    # model inference
    trained_net.eval()
    image_shape = image.shape[-3:]
    crop_info = crop_index_gen(image_shape=image_shape,
                               crop_size=crop_size,
                               overlap_size=overlap_size)

    image_patches = image_crop(image, crop_info, norm=True, ToTensor=True)

    cropped_image_list = np.zeros_like(image_patches.cpu().numpy())

    with torch.no_grad():
        with tqdm(total=len(cropped_image_list),
                  desc='inference test image',
                  unit='patch') as pbar:
            for i, patch in enumerate(image_patches):
                patch = patch.unsqueeze(dim=0)
                preds = trained_net(patch)

                cropped_image_list[i, ...] = preds.squeeze(
                    0).detach().cpu().numpy()
                pbar.update(1)

    crop_index = crop_info['index_array']
    rebuild_four_channels = image_rebuild(crop_index, cropped_image_list)
    inferenced_mask = inference_output(rebuild_four_channels)

    # calculate DSC
    target = torch.from_numpy(seg_arr).unsqueeze(0)
    pred = torch.from_numpy(inferenced_mask).unsqueeze(0)
    pred = to_one_hot(pred)
    dsc = dice_coe(pred, target)
    print('DSC by label of this image is: ', dsc)

    # plot predicted segmentation
    plt.figure(figsize=(20, 10))
    ground_truth = seg_arr[image_shape[0] // 2]
    predicted = inferenced_mask[image_shape[0] // 2]
    image_list = [ground_truth, predicted]

    subtitles = ['ground truth', 'predicted']
    plt.subplots_adjust(wspace=0.3)

    for i in range(1, 3):
        ax = plt.subplot(1, 2, i)
        ax.set_title(subtitles[i - 1])
        sns.heatmap(image_list[i - 1],
                    vmin=0,
                    vmax=4,
                    xticklabels=False,
                    yticklabels=False,
                    square=True,
                    cmap='coolwarm',
                    cbar=True)

    # save prediction
    save_path = os.path.join(root, 'prediction_results', patient_name)
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    mask = sitk.GetImageFromArray(inferenced_mask.astype(np.int16))
    if save_mask:
        plt.savefig(
            os.path.join(
                save_path,
                '{}_2D_prediction_{}.png'.format(patient_name, model_name)))

        mask.CopyInformation(seg)
        sitk.WriteImage(
            mask,
            os.path.join(
                save_path,
                '{}_2D_prediction_{}.nii.gz'.format(patient_name, model_name)))
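
The to_one_hot helper used above (to turn the predicted integer mask into a one-hot volume before computing the Dice score) is also not listed. A minimal PyTorch sketch, assuming integer labels 0..C-1 in an N x D x H x W tensor, could be:

import torch
import torch.nn.functional as F

def to_one_hot(mask, num_classes=4):
    # N x D x H x W integer labels -> N x C x D x H x W one-hot float tensor
    one_hot = F.one_hot(mask.long(), num_classes=num_classes)  # N x D x H x W x C
    return one_hot.permute(0, 4, 1, 2, 3).float()
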
Example #4
def train(args):

    torch.cuda.manual_seed(1)
    torch.manual_seed(1)

    # user defined
    model_name = args.model_name
    model_loss_fn = args.loss_fn

    config_file = 'config.yaml'

    config = load_config(config_file)
    data_root = config['PATH']['data_root']
    labels = config['PARAMETERS']['labels']
    root_path = config['PATH']['root']
    model_dir = config['PATH']['model_path']
    best_dir = config['PATH']['best_model_path']

    data_class = config['PATH']['data_class']
    input_modalites = int(config['PARAMETERS']['input_modalites'])
    output_channels = int(config['PARAMETERS']['output_channels'])
    base_channel = int(config['PARAMETERS']['base_channels'])
    crop_size = int(config['PARAMETERS']['crop_size'])
    batch_size = int(config['PARAMETERS']['batch_size'])
    epochs = int(config['PARAMETERS']['epoch'])
    is_best = bool(config['PARAMETERS']['is_best'])
    is_resume = bool(config['PARAMETERS']['resume'])
    patience = int(config['PARAMETERS']['patience'])
    ignore_idx = int(config['PARAMETERS']['ignore_index'])
    early_stop_patience = int(config['PARAMETERS']['early_stop_patience'])

    # build up dirs
    model_path = os.path.join(root_path, model_dir)
    best_path = os.path.join(root_path, best_dir)
    intermediate_data_save = os.path.join(root_path, 'train_data', model_name)
    train_info_file = os.path.join(intermediate_data_save,
                                   '{}_train_info.json'.format(model_name))
    log_path = os.path.join(root_path, 'logfiles')

    if not os.path.exists(model_path):
        os.mkdir(model_path)
    if not os.path.exists(best_path):
        os.mkdir(best_path)
    if not os.path.exists(intermediate_data_save):
        os.makedirs(intermediate_data_save)
    if not os.path.exists(log_path):
        os.mkdir(log_path)

    log_name = model_name + '_' + config['PATH']['log_file']
    logger = logfile(os.path.join(log_path, log_name))
    logger.info('Dataset is loading ...')
    # split dataset
    dir_ = os.path.join(data_root, data_class)
    data_content = train_split(dir_)

    # load training set and validation set
    train_set = data_loader(data_content=data_content,
                            key='train',
                            form='LGG',
                            crop_size=crop_size,
                            batch_size=batch_size,
                            num_works=8)
    n_train = len(train_set)
    train_loader = train_set.load()

    val_set = data_loader(data_content=data_content,
                          key='val',
                          form='LGG',
                          crop_size=crop_size,
                          batch_size=batch_size,
                          num_works=8)

    logger.info('Dataset loading finished!')

    n_val = len(val_set)
    nb_batches = np.ceil(n_train / batch_size)
    n_total = n_train + n_val
    logger.info(
        '{} images will be used in total, {} for training and {} for validation'
        .format(n_total, n_train, n_val))

    net = init_U_Net(input_modalites, output_channels, base_channel)

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    if torch.cuda.device_count() > 1:
        logger.info('{} GPUs available.'.format(torch.cuda.device_count()))
        net = nn.DataParallel(net)

    net.to(device)

    if model_loss_fn == 'Dice':
        criterion = DiceLoss(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'CrossEntropy':
        criterion = CrossEntropyLoss(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'FocalLoss':
        criterion = FocalLoss(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'Dice_CE':
        criterion = Dice_CE(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'Dice_FL':
        criterion = Dice_FL(labels=labels, ignore_index=ignore_idx)
    else:
        raise NotImplementedError()

    optimizer = optim.Adam(net.parameters(), lr=1e-4, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     verbose=True,
                                                     patience=patience)

    net, optimizer = amp.initialize(net, optimizer, opt_level='O1')

    min_loss = float('Inf')
    early_stop_count = 0
    global_step = 0
    start_epoch = 0
    start_loss = 0
    train_info = {
        'train_loss': [],
        'val_loss': [],
        'BG_acc': [],
        'NET_acc': [],
        'ED_acc': [],
        'ET_acc': []
    }

    if is_resume:
        try:
            ckp_path = os.path.join(model_dir,
                                    '{}_model_ckp.pth.tar'.format(model_name))
            net, optimizer, scheduler, start_epoch, min_loss, start_loss = load_ckp(
                ckp_path, net, optimizer, scheduler)

            # open previous training records
            with open(train_info_file) as f:
                train_info = json.load(f)

            logger.info(
                'Training loss from last time is {}'.format(start_loss) +
                '\n' +
                'Mininum training loss from last time is {}'.format(min_loss))

        except Exception:
            logger.warning(
                'No checkpoint available, start training from scratch')

    # start training
    for epoch in range(start_epoch, epochs):

        # setup to train mode
        net.train()
        running_loss = 0
        dice_coeff_bg = 0
        dice_coeff_net = 0
        dice_coeff_ed = 0
        dice_coeff_et = 0

        logger.info('Training epoch {} will begin'.format(epoch + 1))

        with tqdm(total=n_train,
                  desc=f'Epoch {epoch+1}/{epochs}',
                  unit='patch') as pbar:

            for i, data in enumerate(train_loader, 0):
                images, segs = data['image'].to(device), data['seg'].to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                outputs = net(images)

                loss = criterion(outputs, segs)
                # loss.backward()
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()

                optimizer.step()

                # save the output at the beginning of each epoch to visualize it
                if i == 0:
                    in_images = images.detach().cpu().numpy()[:, 0, ...]
                    in_segs = segs.detach().cpu().numpy()
                    in_pred = outputs.detach().cpu().numpy()
                    heatmap_plot(image=in_images,
                                 mask=in_segs,
                                 pred=in_pred,
                                 name=model_name,
                                 epoch=epoch + 1)

                running_loss += loss.detach().item()
                dice_score = dice_coe(outputs.detach().cpu(),
                                      segs.detach().cpu())
                dice_coeff_bg += dice_score['BG']
                dice_coeff_ed += dice_score['ED']
                dice_coeff_et += dice_score['ET']
                dice_coeff_net += dice_score['NET']

                # show progress bar
                pbar.set_postfix(
                    **{
                        'Training loss': loss.detach().item(),
                        'Training (avg) accuracy': dice_score['avg']
                    })
                pbar.update(images.shape[0])

                global_step += 1
                if global_step % nb_batches == 0:
                    # run validation at the end of each epoch
                    net.eval()
                    val_loss, val_acc = validation(net, val_set, criterion,
                                                   device, batch_size)
                    # switch back to training mode for the following batches
                    net.train()

        train_info['train_loss'].append(running_loss / nb_batches)
        train_info['val_loss'].append(val_loss)
        train_info['BG_acc'].append(dice_coeff_bg / nb_batches)
        train_info['NET_acc'].append(dice_coeff_net / nb_batches)
        train_info['ED_acc'].append(dice_coeff_ed / nb_batches)
        train_info['ET_acc'].append(dice_coeff_et / nb_batches)

        # save the best trained model
        scheduler.step(running_loss / nb_batches)

        if min_loss > val_loss:
            min_loss = val_loss
            is_best = True
            early_stop_count = 0
        else:
            is_best = False
            early_stop_count += 1

        state = {
            'epoch': epoch + 1,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': running_loss / nb_batches,
            'min_loss': min_loss
        }
        verbose = save_ckp(state,
                           is_best,
                           early_stop_count=early_stop_count,
                           early_stop_patience=early_stop_patience,
                           save_model_dir=model_path,
                           best_dir=best_path,
                           name=model_name)

        logger.info('The average training loss for this epoch is {}'.format(
            running_loss / (np.ceil(n_train / batch_size))))
        logger.info(
            'Validation dice loss: {}; Validation (avg) accuracy: {}'.format(
                val_loss, val_acc))
        logger.info('The best validation loss till now is {}'.format(min_loss))

        # save the training info every epoch
        logger.info('Writing the training info into file ...')
        with open(train_info_file, 'w') as fp:
            json.dump(train_info, fp)

        loss_plot(train_info_file, name=model_name)

        if verbose:
            logger.info(
                'The validation loss has not improved for {} epochs, training will stop here.'
                .format(early_stop_patience))
            break

    logger.info('Finished training!')
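
save_ckp and load_ckp come from elsewhere in this project and are not shown here. A minimal checkpointing sketch that matches how they are called above (hypothetical file names, keyed on the state dictionary built in the training loop) might look like:

import os
import shutil
import torch

def save_ckp(state, is_best, early_stop_count, early_stop_patience,
             save_model_dir, best_dir, name):
    # always save the latest checkpoint; copy it aside when validation improves
    ckp_path = os.path.join(save_model_dir, '{}_model_ckp.pth.tar'.format(name))
    torch.save(state, ckp_path)
    if is_best:
        shutil.copyfile(
            ckp_path, os.path.join(best_dir, '{}_best_model.pth.tar'.format(name)))
    # tell the caller to stop when validation has not improved for too long
    return early_stop_count >= early_stop_patience

def load_ckp(ckp_path, net, optimizer, scheduler):
    # restore model, optimizer and scheduler state plus training bookkeeping
    ckp = torch.load(ckp_path)
    net.load_state_dict(ckp['model_state_dict'])
    optimizer.load_state_dict(ckp['optimizer_state_dict'])
    scheduler.load_state_dict(ckp['scheduler_state_dict'])
    return net, optimizer, scheduler, ckp['epoch'], ckp['min_loss'], ckp['loss']
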
Example #5
    def train(self):
        is_random = tf.placeholder(dtype=tf.bool,name='is_random')

        train_images = tf.placeholder(dtype=tf.float32,shape=[self.batch_size, self.IM_ROWS, self.IM_COLS, self.IM_DEPTH])
        validation_images = tf.placeholder(dtype=tf.float32,shape=[self.batch_size, self.IM_ROWS, self.IM_COLS, self.IM_DEPTH])
        test_images = tf.placeholder(dtype=tf.float32,shape=[self.batch_size, self.IM_ROWS, self.IM_COLS, self.IM_DEPTH])

        train_gt = tf.placeholder(dtype=tf.float32,shape=[self.batch_size, self.IM_ROWS, self.IM_COLS, self.IM_DEPTH])
        validation_gt = tf.placeholder(dtype=tf.float32,shape=[self.batch_size, self.IM_ROWS, self.IM_COLS, self.IM_DEPTH])
        test_gt = tf.placeholder(dtype=tf.float32,shape=[self.batch_size, self.IM_ROWS, self.IM_COLS, self.IM_DEPTH])

        train_pack = tf.concat([train_images,train_gt],axis=-1)
        train_pack = tf.cond(is_random,lambda:tf.image.rot90(train_pack), lambda: train_pack)
        train_pack = tf.image.random_flip_up_down(train_pack)
        train_pack = tf.image.random_flip_left_right(train_pack)
        # train_pack = tf.image.random_crop(train_pack)
        train_images_da = tf.expand_dims(train_pack[:,:,:,0],-1)
        train_gt_da = tf.expand_dims(train_pack[:,:,:,1],-1)

        u_train, td, tp, lda, alpha, delta, sigma, f, th = self.PD_v2(train_images_da, training=True)
        u_validation, _, _, _, _, delta_v, _, _, th_v = self.PD_v2(validation_images, reuse=True, training=False)
        u_test, _, _, _, _, _, _, _, _ = self.PD_v2(test_images, reuse=True, training=False)

        with tf.variable_scope('Regularization'):
            TV2D = tf.reduce_mean(self.TV2D(th))

        with tf.name_scope('Bayesian_Weights'):
            s1 = tf.Variable(tf.sqrt(0.5),name='s1')
            s2 = tf.Variable(0.1,name='s2')

            c1 = tf.div(1.,2.*(tf.square(s1)),name='c1')
            c2 = tf.div(1.,2.*(tf.square(s2)),name='c2')

        # train_gt_flat = tf.reshape(train_gt,[self.batch_size,-1])
        # u_train_flat = tf.reshape(u_train,[self.batch_size,-1])
        # train_gt_flat = tf.concat([train_gt_flat,1.0-train_gt_flat],axis=-1)
        # u_train_flat = tf.concat([u_train_flat,1.0-u_train_flat],axis=-1)
        #train_loss = tf.losses.softmax_cross_entropy(onehot_labels=train_gt_flat,logits=u_train_flat)
        # loss_l2 = tf.losses.mean_squared_error(labels=train_gt,predictions=u_train)
        # u_train = delta
        loss_l2 = tf.losses.mean_squared_error(labels=train_gt_da,predictions=u_train)
        U2 = tf.reduce_mean(tf.square(delta))
        train_loss =  c1*tf.reduce_mean(1.-dice_coe(output=u_train,target=train_gt_da)) + s1 + 0.001*U2
        # train_loss = c1*loss_l2 + s1 + 0.001*U2
        # u_train = tf.maximum(0.0, tf.minimum(1., u_train))


        # validation_loss = tf.losses.mean_squared_error(labels=validation_gt,predictions=u_validation)
        # u_validation = tf.maximum(0.0, tf.minimum(1., u_validation))
        validation_loss = 1.-dice_coe(output=u_validation,target=validation_gt)

        test_loss = 1.-dice_coe(output=u_test,target=test_gt)

        global_step = tf.Variable(0, trainable=False)
        var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

        optimizer = optimizer_opt(self,loss=train_loss,var_list=var_list,global_step=global_step,show_gradients=True)

        saver = tf.train.Saver()

        with tf.name_scope('All_Metrics'):
            dice_coe_t = dice_coe(output=u_train,target=train_gt)
            iou_coe_t = iou_coe(output=u_train,target=train_gt)

            dice_coe_v = dice_coe(output=u_validation,target=validation_gt)
            iou_coe_v = iou_coe(output=u_validation,target=validation_gt)

            dice_coe_test = dice_coe(output=u_test,target=test_gt)
            iou_coe_test = iou_coe(output=u_test,target=test_gt)
        
        # Summaries
        random_slice = tf.placeholder(dtype=tf.int32)

        with tf.name_scope('Training'):
            with tf.name_scope('Bayesian_weights'):
                tf.summary.scalar('c1',c1)
                # tf.summary.scalar('c2',c2)

            with tf.name_scope('Metrics'):
                tf.summary.scalar('Dice_coe',dice_coe_t)
                tf.summary.scalar('Iou_coe',iou_coe_t)

            with tf.name_scope('Images'):
                tf.summary.image('Input', train_images_da)

                out = u_train
                out = tf.cast(tf.reshape(out, (self.batch_size, self.IM_ROWS, self.IM_COLS, 1)), tf.float32)
                tf.summary.image('Output', out)

                sum_mask = train_gt_da
                sum_mask = tf.cast(tf.reshape(sum_mask, (self.batch_size, self.IM_ROWS, self.IM_COLS, 1)), tf.float32)
                tf.summary.image('GT', sum_mask)

                tf.summary.image('delta', delta)
                tf.summary.image('th', th)

            with tf.name_scope('Losses'):
                tf.summary.scalar('Loss',train_loss)
                tf.summary.scalar('TV2D_Reg',TV2D)
                tf.summary.scalar('U2_Reg',U2)

            with tf.name_scope('Parameters'):
                tf.summary.scalar('tp', tp)
                tf.summary.scalar('td', td)
                tf.summary.scalar('Lambda', lda[0,0,0,0])
                tf.summary.scalar('alpha', alpha[0,0,0,0])
                tf.summary.scalar('sigma', sigma)

        summary_train = tf.summary.merge_all()

        with tf.name_scope('Validation'):
            with tf.name_scope('Losses'):
                vl = tf.summary.scalar('Loss',validation_loss)
            with tf.name_scope('Metrics'):
                dv = tf.summary.scalar('Dice_coe',dice_coe_v)
                iv = tf.summary.scalar('Iou_coe',iou_coe_v)

            with tf.name_scope('Images'):
                inv = tf.summary.image('Input', validation_images)

                out = u_validation
                out = tf.cast(tf.reshape(out, (self.batch_size, self.IM_ROWS, self.IM_COLS, 1)), tf.float32)
                outv = tf.summary.image('Output', out)

                sum_mask = validation_gt
                sum_mask = tf.cast(tf.reshape(sum_mask, (self.batch_size, self.IM_ROWS, self.IM_COLS, 1)), tf.float32)
                gtv = tf.summary.image('GT', sum_mask)

                delv = tf.summary.image('delta', delta_v)
                thv = tf.summary.image('th', th_v)

        summary_validation = tf.summary.merge(inputs=[vl, iv, dv, outv, gtv, inv, delv, thv])

        with tf.name_scope('Test'):
            dice_test = tf.placeholder(dtype=tf.float32)
            iou_test = tf.placeholder(dtype=tf.float32)
            with tf.name_scope('Metrics'):
                d_test = tf.summary.scalar('Dice_coe',dice_test)
                i_test = tf.summary.scalar('Iou_coe',iou_test)


        summary_test = tf.summary.merge(inputs=[d_test, i_test])

        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        run_metadata = tf.RunMetadata()
        with tf.Session() as sess:
            summary_writer_t = tf.summary.FileWriter(self.log_dir+'/train', sess.graph,filename_suffix='train')
            summary_writer_v = tf.summary.FileWriter(self.log_dir+'/validation', sess.graph,filename_suffix='valid')
            summary_writer_test = tf.summary.FileWriter(self.log_dir+'/test', sess.graph,filename_suffix='test')

            sess.run(tf.global_variables_initializer())

            if self.restore:
                saver.restore(sess, self.checkpoint)
                counter = np.load('counter.npy')
                epocas = np.load('epocas.npy')
            else:
                counter = 0
                epocas = 0

            t_f, v_f = self.create_list(data_path=self.data_dir)

            data_train = self.load_data_train(data_path=self.data_dir, list=t_f)
            data_valid = self.load_data_valid(data_path=self.data_dir, list=v_f)
            # data_test = self.load_data_test(data_path=self.data_dir, list=v_f)
            for i in tqdm(range(self.training_iters-counter)):

                if (counter % self.epoch_iteration)==0 or counter==0:
                    print('Shuffle Data')
                    t_f = self.shuffle_data(t_f)
                    v_f = self.shuffle_data(v_f)
                    data_train = self.load_data_train(data_path=self.data_dir, list=t_f)
                    data_valid = self.load_data_valid(data_path=self.data_dir, list=v_f)

                if not((counter % (self.epoch_iteration/10))==0):
                    ti, tm = next(data_train)

                    _ = sess.run([optimizer],
                                 feed_dict={train_images: ti, train_gt: tm,
                                            is_random: np.random.randint(0, 2)})
                else:

                    ti, tm = next(data_train)
                    vi, vm = next(data_valid)

                    _, summary_str_t, summary_str_v, global_step_value = sess.run(
                        [optimizer, summary_train, summary_validation, global_step],
                        feed_dict={train_images: ti, train_gt: tm, validation_images: vi,
                                   validation_gt: vm, is_random: np.random.randint(0, 2)},
                        options=run_options, run_metadata=run_metadata)
                    summary_writer_t.add_run_metadata(run_metadata, 'step%d' % counter)
                    summary_writer_t.add_summary(summary_str_t, global_step_value)
                    summary_writer_t.flush()
                    summary_writer_v.add_run_metadata(run_metadata, 'step%d' % counter)
                    summary_writer_v.add_summary(summary_str_v, global_step_value)
                    summary_writer_v.flush()


                if (counter % self.epoch_iteration)==0 and not(counter==0):
                    epocas = epocas + 1
                    saver.save(sess, self.checkpoint)
                    np.save('counter.npy',counter)
                    np.save('epocas.npy',epocas)

                if self.is_test:
                    if (counter % (self.test_each_epoch*self.epoch_iteration))==0: # and not(counter==0):
                        dice_test_list = []
                        iou_test_list = []
                        test_list = np.load('./test_list.npy')
                        num_images = len(test_list[0])
                        list_flair = test_list[0]
                        list_gt = test_list[1]
                        for subject in tqdm(list_flair):
                            # vol_value = []
                            # vol_gt = []
                            print('one volume')
                            for test_iter in range(155):
                                test_i = np.load(self.data_dir+'validation'+subject+str(test_iter)+'.npy')
                                test_i = np.reshape(test_i,newshape=[1,self.IM_ROWS,self.IM_COLS,1])
                                test_m = np.load(self.data_dir+'validation'+'/gt'+subject[3:]+str(test_iter)+'.npy')
                                test_m = np.reshape(test_m,newshape=[1,self.IM_ROWS,self.IM_COLS,1])

                                slice_value = sess.run([u_test],
                                    feed_dict={test_images: test_i, test_gt: test_m}, options=run_options, run_metadata=run_metadata)

                                if test_iter == 0:
                                    vol_value = slice_value
                                    vol_gt = test_m
                                else:
                                    vol_value = np.concatenate([vol_value,slice_value],axis=-1)
                                    vol_gt = np.concatenate([vol_gt,test_m],axis=-1)

                                # vol_value.append(slice_value[0][0,:,:,:])
                                # vol_gt.append(test_m[0,:,:,:])

                            vol_value = np.reshape(vol_value,[self.IM_ROWS,self.IM_COLS,155])
                            vol_gt = np.reshape(vol_gt,[self.IM_ROWS,self.IM_COLS,155])
                            dice_coe_test_value = self.np_dice(vol_value,vol_gt)
                            iou_coe_test_value = self.np_iou(vol_value,vol_gt)
                            dice_test_list.append(dice_coe_test_value)
                            iou_test_list.append(iou_coe_test_value)

                        summary_str_test, global_step_value = sess.run(
                            [summary_test, global_step],
                            feed_dict={dice_test: np.mean(dice_test_list),
                                       iou_test: np.mean(iou_test_list)})
                        summary_writer_test.add_run_metadata(run_metadata, 'step%d' % epocas)
                        summary_writer_test.add_summary(summary_str_test, global_step_value)
                        summary_writer_test.flush()

                counter += 1
            sess.close()
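
self.np_dice and self.np_iou, used above for the volume-level test metrics, are not listed either. A minimal NumPy sketch (written as free functions rather than methods, assuming the predicted volume is thresholded at 0.5 against a binary ground truth) could look like:

import numpy as np

def np_dice(pred, gt, threshold=0.5, eps=1e-7):
    # hard Dice on thresholded volumes: 2 * intersection / (|pred| + |gt|)
    p = (pred > threshold).astype(np.float32)
    g = (gt > threshold).astype(np.float32)
    inter = np.sum(p * g)
    return (2. * inter + eps) / (np.sum(p) + np.sum(g) + eps)

def np_iou(pred, gt, threshold=0.5, eps=1e-7):
    # hard IoU on thresholded volumes: intersection / union
    p = (pred > threshold).astype(np.float32)
    g = (gt > threshold).astype(np.float32)
    inter = np.sum(p * g)
    union = np.sum(p) + np.sum(g) - inter
    return (inter + eps) / (union + eps)
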