Example #1
def train(logger):
    """
    perform the training routine for a given fold. saves plots and selected parameters to the experiment dir
    specified in the configs.
    """
    logger.info(
        'performing training in {}D over fold {} on experiment {} with model {}'
        .format(cf.dim, cf.fold, cf.exp_dir, cf.model))

    net = model.net(cf, logger).cuda()
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=cf.learning_rate[0],
                                 weight_decay=cf.weight_decay)
    model_selector = utils.ModelSelector(cf, logger)
    train_evaluator = Evaluator(cf, logger, mode='train')
    val_evaluator = Evaluator(cf, logger, mode=cf.val_mode)

    starting_epoch = 1

    # prepare monitoring
    monitor_metrics, TrainingPlot = utils.prepare_monitoring(cf)

    if cf.resume_to_checkpoint:
        starting_epoch, monitor_metrics = utils.load_checkpoint(
            cf.resume_to_checkpoint, net, optimizer)
        logger.info('resumed to checkpoint {} at epoch {}'.format(
            cf.resume_to_checkpoint, starting_epoch))

    logger.info('loading dataset and initializing batch generators...')
    batch_gen = data_loader.get_train_generators(cf, logger)

    for epoch in range(starting_epoch, cf.num_epochs + 1):

        logger.info('starting training epoch {}'.format(epoch))
        for param_group in optimizer.param_groups:
            param_group['lr'] = cf.learning_rate[epoch - 1]

        start_time = time.time()

        net.train()
        train_results_list = []

        for bix in range(cf.num_train_batches):
            batch = next(batch_gen['train'])
            tic_fw = time.time()
            results_dict = net.train_forward(batch)
            tic_bw = time.time()
            optimizer.zero_grad()
            results_dict['torch_loss'].backward()
            optimizer.step()
            logger.info(
                'tr. batch {0}/{1} (ep. {2}) fw {3:.3f}s / bw {4:.3f}s / total {5:.3f}s || '
                .format(bix + 1, cf.num_train_batches, epoch, tic_bw - tic_fw,
                        time.time() - tic_bw,
                        time.time() - tic_fw) + results_dict['logger_string'])
            train_results_list.append([results_dict['boxes'], batch['pid']])
            monitor_metrics['train']['monitor_values'][epoch].append(
                results_dict['monitor_values'])

        _, monitor_metrics['train'] = train_evaluator.evaluate_predictions(
            train_results_list, monitor_metrics['train'])
        train_time = time.time() - start_time

        logger.info('starting validation in mode {}.'.format(cf.val_mode))
        with torch.no_grad():
            net.eval()
            if cf.do_validation:
                val_results_list = []
                val_predictor = Predictor(cf, net, logger, mode='val')
                for _ in range(batch_gen['n_val']):
                    batch = next(batch_gen[cf.val_mode])
                    if cf.val_mode == 'val_patient':
                        results_dict = val_predictor.predict_patient(batch)
                    elif cf.val_mode == 'val_sampling':
                        results_dict = net.train_forward(batch,
                                                         is_validation=True)
                    val_results_list.append(
                        [results_dict['boxes'], batch['pid']])
                    monitor_metrics['val']['monitor_values'][epoch].append(
                        results_dict['monitor_values'])

                _, monitor_metrics['val'] = val_evaluator.evaluate_predictions(
                    val_results_list, monitor_metrics['val'])
                model_selector.run_model_selection(net, optimizer,
                                                   monitor_metrics, epoch)

            # update monitoring and prediction plots
            TrainingPlot.update_and_save(monitor_metrics, epoch)
            epoch_time = time.time() - start_time
            logger.info(
                'trained epoch {}: took {} sec. ({} train / {} val)'.format(
                    epoch, epoch_time, train_time, epoch_time - train_time))
            # draw one extra sampling batch purely for the prediction plot
            batch = next(batch_gen['val_sampling'])
            results_dict = net.train_forward(batch, is_validation=True)
            logger.info('plotting predictions from validation sampling.')
            plot_batch_prediction(batch, results_dict, cf)
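Note that the loop reads the per-epoch learning rate as cf.learning_rate[epoch - 1], so cf.learning_rate must be a sequence with one entry per training epoch. A minimal sketch of how such a schedule could be built; the field values and the decay factor below are illustrative assumptions, not taken from the example:

import numpy as np

# hypothetical config values, for illustration only
num_epochs = 100
initial_lr = 1e-4

# one learning rate per epoch, here an exponential decay; inside the
# training loop the example indexes this list as cf.learning_rate[epoch - 1]
learning_rate = list(initial_lr * np.power(0.985, np.arange(num_epochs)))
assert len(learning_rate) == num_epochs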
Example #2
def train(fold):
    """
    perform the training routine for a given fold. saves plots and selected parameters to the experiment dir
    specified in the configs.
    """

    # set up experiment dirs and copy config file
    utils.prep_exp(cf)
    logger = utils.get_logger(cf.exp_dir)
    logger.info(
        'performing training in {d}D over fold {f} on experiment {e}'.format(
            d=cf.dim, f=fold, e=cf.exp_dir))
    logger.info('initializing tensorflow graph...')

    tf.reset_default_graph()
    x = tf.placeholder('float', shape=cf.network_input_shape)
    y = tf.placeholder('float', shape=cf.network_output_shape)
    learning_rate = tf.Variable(cf.learning_rate)
    logits = model.create_UNet(x,
                               cf.n_features_root,
                               cf.n_classes,
                               dim=cf.dim,
                               logger=logger)
    loss = utils._get_loss(logits, y, cf.n_classes, cf.loss_name,
                           cf.class_weights, cf.dim)
    predicter = tf.nn.softmax(logits)
    dice_per_class = utils.get_dice_per_class(logits, y, dim=cf.dim)
    optimizer = tf.train.AdamOptimizer(
        learning_rate=learning_rate).minimize(loss)
    saver = tf.train.Saver()

    # prepare monitoring
    metrics = {}
    metrics['train'] = {
        'loss': [None],
        'dices': np.zeros(shape=(1, cf.n_classes))
    }  # TODO: verify; placeholder first entries are appended to per epoch
    metrics['val'] = {
        'loss': [None],
        'dices': np.zeros(shape=(1, cf.n_classes))
    }
    best_metrics = {
        'loss': [10, 0],  # [best val loss so far, epoch]; 10 is a high initial value
        'dices': np.zeros(shape=(cf.n_classes + 1, 2))
    }
    file_name = cf.plot_dir + '/monitor_{}.png'.format(fold)
    TrainingPlot = TrainingPlot_2Panel(cf.n_epochs, file_name,
                                       cf.experiment_name, cf.class_dict)

    logger.info('initializing batch generators...')
    batch_gen = data_loader.get_train_generators(cf, fold)

    logger.info('starting training...')
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch in range(cf.n_epochs):

            start_time = time.time()

            # perform training steps
            train_loss_running_mean = 0.
            train_dices_running_batch_mean = np.zeros(shape=(1, cf.n_classes))
            for _ in range(cf.n_train_batches):
                batch = next(batch_gen['train'])
                train_loss, train_dices, _ = sess.run(
                    (loss, dice_per_class, optimizer),
                    feed_dict={
                        x: batch['data'],
                        y: batch['seg']
                    })
                train_loss_running_mean += train_loss / cf.n_train_batches
                train_dices_running_batch_mean += train_dices / cf.n_train_batches
            metrics['train']['loss'].append(train_loss_running_mean)
            metrics['train']['dices'] = np.append(
                metrics['train']['dices'],
                train_dices_running_batch_mean,
                axis=0)

            # perform validation
            val_loss_running_mean = 0.
            val_dices_running_batch_mean = np.zeros(shape=(1, cf.n_classes))
            for _ in range(cf.n_val_batches):
                batch = next(batch_gen['val'])
                val_loss, val_dices = sess.run((loss, dice_per_class),
                                               feed_dict={
                                                   x: batch['data'],
                                                   y: batch['seg']
                                               })
                val_loss_running_mean += val_loss / cf.n_val_batches
                val_dices_running_batch_mean[0] += val_dices / cf.n_val_batches
            metrics['val']['loss'].append(val_loss_running_mean)
            metrics['val']['dices'] = np.append(metrics['val']['dices'],
                                                val_dices_running_batch_mean,
                                                axis=0)

            # evaluate epoch
            val_loss = metrics['val']['loss'][-1]
            val_dices = metrics['val']['dices'][-1]
            if val_loss < best_metrics['loss'][0]:
                best_metrics['loss'][0] = val_loss
                best_metrics['loss'][1] = epoch
            for cl in range(cf.n_classes):
                if val_dices[cl] > best_metrics['dices'][cl][0]:
                    best_metrics['dices'][cl][0] = val_dices[cl]
                    best_metrics['dices'][cl][1] = epoch

            # selection criterion is the averaged dice over both foreground classes
            fg_dice = np.mean(val_dices[1:])
            if fg_dice > best_metrics['dices'][cf.n_classes][0]:
                best_metrics['dices'][cf.n_classes][0] = fg_dice
                best_metrics['dices'][cf.n_classes][1] = epoch
                saver.save(sess,
                           os.path.join(cf.exp_dir, 'params_{}'.format(fold)))

            # update monitoring and prediction plots
            TrainingPlot.update_and_save(metrics, best_metrics)
            batch = next(batch_gen['val'])
            # argmax over the softmax output gives the predicted label per voxel
            prediction = np.argmax(
                sess.run(predicter, feed_dict={x: batch['data']}), axis=-1)
            outfile = cf.plot_dir + '/pred_example_{}.png'.format(
                fold)  # change fold -> epoch in the filename to keep plots from all epochs
            plot_batch_prediction(batch['data'],
                                  batch['seg'],
                                  prediction,
                                  cf.n_classes,
                                  outfile,
                                  dim=cf.dim)
            logger.info(
                'trained epoch {e}: val_loss {l}, val_dices: {d}, took {t} sec.'
                .format(e=epoch,
                        l=np.round(val_loss, 3),
                        d=val_dices,
                        t=np.round(time.time() - start_time, 0)))
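The selection logic above saves parameters whenever the mean dice over the foreground classes improves; row cf.n_classes of best_metrics['dices'] holds that running best together with its epoch. The same bookkeeping as a self-contained sketch, with hypothetical values:

import numpy as np

n_classes = 3  # hypothetical: background plus two foreground classes
# rows 0..n_classes-1: best per-class dice and its epoch;
# row n_classes: best foreground-mean dice and its epoch
best_dices = np.zeros(shape=(n_classes + 1, 2))

def is_new_best(val_dices, epoch):
    """Return True if the foreground-mean dice improved, i.e. the
    caller should save a checkpoint for this epoch."""
    fg_dice = np.mean(val_dices[1:])  # skip background class 0
    if fg_dice > best_dices[n_classes][0]:
        best_dices[n_classes] = [fg_dice, epoch]
        return True
    return False

print(is_new_best(np.array([0.99, 0.70, 0.64]), epoch=3))  # True on first improvement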
Example #3
def train(logger):
    """
    perform the training routine for a given fold. saves plots and selected parameters to the experiment dir
    specified in the configs.
    """
    logger.info(
        'performing training in {}D over fold {} on experiment {} with model {}'
        .format(cf.dim, cf.fold, cf.exp_dir, cf.model))

    writer = SummaryWriter(os.path.join(cf.exp_dir, 'tensorboard'))

    net = model.net(cf, logger).cuda()

    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=cf.initial_learning_rate,
                                 weight_decay=cf.weight_decay)

    model_selector = utils.ModelSelector(cf, logger)
    train_evaluator = Evaluator(cf, logger, mode='train')
    val_evaluator = Evaluator(cf, logger, mode=cf.val_mode)  # cf.val_mode, e.g. 'val_sampling'

    starting_epoch = 1

    # prepare monitoring
    if cf.resume_to_checkpoint:  # default: False
        last_checkpoint_path = cf.resume_to_checkpoint + 'last_checkpoint/'
        best_epoch = np.load(last_checkpoint_path + 'epoch_ranking.npy')[0]
        with open(last_checkpoint_path + 'monitor_metrics.pickle', 'rb') as f:
            monitor_metrics = pickle.load(f)
        starting_epoch = utils.load_checkpoint(last_checkpoint_path, net, optimizer)
        logger.info('resumed to checkpoint {} at epoch {}'.format(
            cf.resume_to_checkpoint, starting_epoch))
        num_batch = starting_epoch * cf.num_train_batches + 1
        num_val = starting_epoch * cf.num_val_batches + 1
    else:
        monitor_metrics = utils.prepare_monitoring(cf)
        num_batch = 0  # global step counters for tensorboard logging
        num_val = 0
    logger.info('loading dataset and initializing batch generators...')
    batch_gen = data_loader.get_train_generators(cf, logger)
    best_train_recall, best_val_recall = 0, 0
    lr_now = cf.initial_learning_rate
    for epoch in range(starting_epoch, cf.num_epochs + 1):

        logger.info('starting training epoch {}'.format(epoch))
        for param_group in optimizer.param_groups:
            lr_next = utils.learning_rate_decreasing(cf, epoch, lr_now, mode='step')
            print('lr: {} -> {}'.format(lr_now, lr_next))
            param_group['lr'] = lr_next
            lr_now = lr_next

        start_time = time.time()

        net.train()
        train_results_list = []  # (boxes, pid) pairs collected over this epoch
        train_results_list_seg = []

        for bix in range(cf.num_train_batches):  # e.g. 200
            num_batch += 1
            batch = next(batch_gen['train'])
            # batch keys: data, seg, pid, class_target, bb_target, roi_masks, roi_labels
            # binarize roi labels: any positive class -> 1, everything else -> -1
            for ii, i in enumerate(batch['roi_labels']):
                if i[0] > 0:
                    batch['roi_labels'][ii] = [1]
                else:
                    batch['roi_labels'][ii] = [-1]

            tic_fw = time.time()
            results_dict = net.train_forward(batch)
            tic_bw = time.time()

            optimizer.zero_grad()
            results_dict['torch_loss'].backward()  # total loss
            optimizer.step()

            if num_batch % cf.show_train_images == 0:
                fig = plot_batch_prediction(batch, results_dict, cf, 'train')
                writer.add_figure('/Train/results', fig, num_batch)
                fig.clear()
            print('model', cf.exp_dir.split('/')[-2])
            logger.info(
                'tr. batch {0}/{1} (ep. {2}) fw {3:.3f}s / bw {4:.3f}s / total {5:.3f}s || '
                .format(bix + 1, cf.num_train_batches, epoch, tic_bw - tic_fw,
                        time.time() - tic_bw,
                        time.time() - tic_fw))

            # (per-loss tensorboard scalar logging is disabled in this variant;
            # Example #5 below shows the active writer.add_scalar calls)

            train_results_list.append([results_dict['boxes'],
                                       batch['pid']])  # ground truth and detections per pid
            monitor_metrics['train']['monitor_values'][epoch].append(
                results_dict['monitor_losses'])

        count_train = train_evaluator.evaluate_predictions(train_results_list,
                                                           epoch,
                                                           cf,
                                                           flag='train')
        # count_train: [tp_patient, tp_roi, fp_roi, total_num]
        precision = count_train[0] / (count_train[0] + count_train[2] + 0.01)  # +0.01 guards against /0
        recall = count_train[0] / count_train[3]
        print('tp_patient {}, tp_roi {}, fp_roi {}, total_num {}'.format(
            count_train[0], count_train[1], count_train[2], count_train[3]))
        print('precision:{}, recall:{}'.format(precision, recall))
        monitor_metrics['train']['train_recall'].append(recall)
        monitor_metrics['train']['train_percision'].append(precision)
        writer.add_scalar('Train/train_precision', precision, epoch)
        writer.add_scalar('Train/train_recall', recall, epoch)
        train_time = time.time() - start_time

        logger.info('starting validation in mode {}.'.format(cf.val_mode))
        with torch.no_grad():
            net.eval()
            if cf.do_validation:
                val_results_list = []
                val_predictor = Predictor(cf, net, logger, mode='val')
                dice_val_seg, dice_val_mask, dice_val_fusion = [], [], []
                for _ in range(batch_gen['n_val']):  # e.g. 50
                    num_val += 1
                    batch = next(batch_gen[cf.val_mode])
                    print('eval', batch['pid'])
                    # binarize roi labels as in training
                    for ii, i in enumerate(batch['roi_labels']):
                        if i[0] > 0:
                            batch['roi_labels'][ii] = [1]
                        else:
                            batch['roi_labels'][ii] = [-1]
                    if cf.val_mode == 'val_patient':
                        results_dict = val_predictor.predict_patient(
                            batch)  #result of one patient
                    elif cf.val_mode == 'val_sampling':
                        results_dict = net.train_forward(batch,
                                                         is_validation=True)
                    if num_val % cf.show_val_images == 0:
                        fig = plot_batch_prediction(batch, results_dict, cf,
                                                    cf.val_mode)
                        writer.add_figure('Val/results', fig, num_val)
                        fig.clear()

                    # compute dice for vnet
                    this_batch_seg_label = torch.FloatTensor(
                        mutils.get_one_hot_encoding(
                            batch['seg'], cf.num_seg_classes + 1)).cuda()
                    if cf.fusion_feature_method == 'after':
                        this_batch_dice_seg = mutils.dice_val(
                            results_dict['seg_logits'], this_batch_seg_label)
                    else:
                        this_batch_dice_seg = mutils.dice_val(
                            F.softmax(results_dict['seg_logits'], dim=1),
                            this_batch_seg_label)
                    dice_val_seg.append(this_batch_dice_seg)
                    # compute dice for mask
                    if cf.fusion_feature_method == 'after':
                        this_batch_dice_mask = mutils.dice_val(
                            results_dict['seg_preds'], this_batch_seg_label)
                    else:
                        this_batch_dice_mask = mutils.dice_val(
                            F.softmax(results_dict['seg_preds'], dim=1),
                            this_batch_seg_label)
                    dice_val_mask.append(this_batch_dice_mask)
                    # compute dice for fusion
                    if cf.fusion_feature_method == 'after':
                        this_batch_dice_fusion = mutils.dice_val(
                            results_dict['fusion_map'], this_batch_seg_label)
                    else:
                        this_batch_dice_fusion = mutils.dice_val(
                            F.softmax(results_dict['fusion_map'], dim=1),
                            this_batch_seg_label)
                    dice_val_fusion.append(this_batch_dice_fusion)

                    val_results_list.append(
                        [results_dict['boxes'], batch['pid']])
                    monitor_metrics['val']['monitor_values'][epoch].append(
                        results_dict['monitor_values'])

                count_val = val_evaluator.evaluate_predictions(
                    val_results_list, epoch, cf, flag='val')
                precision = count_val[0] / (count_val[0] + count_val[2] + 0.01)
                recall = count_val[0] / (count_val[3])
                print(
                    'tp_patient {}, tp_roi {}, fp_roi {}, total_num {}'.format(
                        count_val[0], count_val[1], count_val[2],
                        count_val[3]))
                print('precision:{}, recall:{}'.format(precision, recall))
                val_dice_seg = sum(dice_val_seg) / float(len(dice_val_seg))
                val_dice_mask = sum(dice_val_mask) / float(len(dice_val_mask))
                val_dice_fusion = sum(dice_val_fusion) / float(
                    len(dice_val_fusion))
                monitor_metrics['val']['val_recall'].append(recall)
                monitor_metrics['val']['val_precision'].append(precision)
                monitor_metrics['val']['val_dice_seg'].append(val_dice_seg)
                monitor_metrics['val']['val_dice_mask'].append(val_dice_mask)
                monitor_metrics['val']['val_dice_fusion'].append(
                    val_dice_fusion)

                writer.add_scalar('Val/val_precision', precision, epoch)
                writer.add_scalar('Val/val_recall', recall, epoch)
                writer.add_scalar('Val/val_dice_seg', val_dice_seg, epoch)
                writer.add_scalar('Val/val_dice_mask', val_dice_mask, epoch)
                writer.add_scalar('Val/val_dice_fusion', val_dice_fusion,
                                  epoch)
                model_selector.run_model_selection(net, optimizer,
                                                   monitor_metrics, epoch)

            # log epoch timing (monitoring scalars are written to tensorboard above)
            epoch_time = time.time() - start_time
            logger.info(
                'trained epoch {}: took {} sec. ({} train / {} val)'.format(
                    epoch, epoch_time, train_time, epoch_time - train_time))
    writer.close()
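utils.learning_rate_decreasing is not shown in this example; given how it is called (current rate in, next rate out, once per epoch), a plausible step-decay sketch could look as follows. The cf.lr_decay_every and cf.lr_decay_factor config fields are assumptions introduced here for illustration:

def learning_rate_decreasing(cf, epoch, lr_now, mode='step'):
    """Return the learning rate to use for this epoch.

    In 'step' mode the current rate is multiplied by a fixed factor
    every cf.lr_decay_every epochs and kept unchanged otherwise.
    """
    if mode == 'step' and epoch % cf.lr_decay_every == 0:
        return lr_now * cf.lr_decay_factor  # e.g. 0.5
    return lr_now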
Example #4
def test(folds):
    """
    create and evaluate predictions for the held-out test set. Predictions can be averaged over several models
    trained during cross validation (specified by "folds"). Predictions are further averaged over 4 different input
    orientations obtained by mirroring (test-time data augmentation). The averaged softmax predictions for each patient
    are saved out for further statistics and plotted to the cf.test_dir and the final dice scores per class are printed.
    """
    logger = utils.get_logger(cf.exp_dir)
    logger.info(
        'performing testing in {d}D over fold(s) {f} on experiment {e}'.format(
            d=cf.dim, f=folds, e=cf.exp_dir))
    logger.info('initializing tensorflow graph...')
    tf.reset_default_graph()
    x = tf.placeholder('float', shape=cf.network_input_shape)
    logits = model.create_UNet(x,
                               cf.n_features_root,
                               cf.n_classes,
                               dim=cf.dim,
                               logger=logger)
    predicter = tf.nn.softmax(logits)
    saver = tf.train.Saver()
    logger.info('initializing test generator...')
    test_data_dict = data_loader.get_test_generator(cf)
    pred_dict = {key: [] for key in test_data_dict.keys()}

    logger.info('starting testing...')
    with tf.Session() as sess:
        for fold in folds:
            sess.run(tf.global_variables_initializer())
            saver.restore(sess,
                          os.path.join(cf.exp_dir, 'params_{}'.format(fold)))

            for pid in test_data_dict.keys():
                patient_fold_prediction = []
                test_arr = test_data_dict[pid]['data']
                # orientation 1: the original input
                patient_fold_prediction.append(
                    sess.run(predicter, feed_dict={x: test_arr}))

                # orientation 2: flipped along one spatial axis
                test_arr = np.flip(test_data_dict[pid]['data'],
                                   axis=cf.dim - 1)
                patient_fold_prediction.append(
                    np.flip(sess.run(predicter, feed_dict={x: test_arr}),
                            axis=cf.dim - 1))

                # orientation 3: flipped along the other spatial axis
                test_arr = np.flip(test_data_dict[pid]['data'], axis=cf.dim)
                patient_fold_prediction.append(
                    np.flip(sess.run(predicter, feed_dict={x: test_arr}),
                            axis=cf.dim))

                # orientation 4: flipped along both axes; each prediction is
                # flipped back before averaging over the four orientations
                test_arr = np.flip(np.flip(test_data_dict[pid]['data'],
                                           axis=cf.dim - 1),
                                   axis=cf.dim)
                patient_fold_prediction.append(
                    np.flip(np.flip(sess.run(predicter,
                                             feed_dict={x: test_arr}),
                                    axis=cf.dim - 1),
                            axis=cf.dim))
                pred_dict[pid].append(
                    np.mean(np.array(patient_fold_prediction), axis=0))

    logger.info('evaluating averaged predictions...')
    final_dices = []
    for pid in test_data_dict.keys():
        final_pred_soft = np.mean(np.array(pred_dict[pid]), axis=0)
        final_pred_correct = np.argmax(final_pred_soft, axis=-1)
        seg = test_data_dict[pid]['seg']
        avg_dices = utils.numpy_volume_dice_per_class(
            utils.get_one_hot_prediction(final_pred_correct, cf.n_classes),
            seg)
        final_dices.append(avg_dices)
        logger.info('avg dices for patient {p} over {a} preds: {d}'.format(
            p=pid, a=len(pred_dict[pid]), d=avg_dices))
        np.save(os.path.join(cf.test_dir, '{}_pred_final.npy'.format(pid)),
                np.concatenate((final_pred_soft[np.newaxis], seg[np.newaxis])))
        plot_batch_prediction(test_data_dict[pid]['data'],
                              seg,
                              final_pred_correct,
                              cf.n_classes,
                              os.path.join(cf.test_dir,
                                           '{}_pred_final.png'.format(pid)),
                              dim=cf.dim)

    logger.info('final dices mean: {}'.format(np.mean(final_dices, axis=0)))
    logger.info('final dices std: {}'.format(np.std(final_dices, axis=0)))
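The mirroring block in test() repeats one pattern four times: flip the input along a subset of the spatial axes, predict, flip the prediction back, then average over all orientations. A compact numpy sketch of that pattern, with predict_fn standing in for the sess.run(predicter, ...) call above:

from itertools import chain, combinations

import numpy as np

def tta_mirror_predict(predict_fn, data, spatial_axes):
    """Average predictions over every mirror combination of spatial_axes.

    predict_fn maps an input array to a prediction with the same spatial
    layout. For two axes this yields the four orientations used above.
    """
    axis_subsets = chain.from_iterable(
        combinations(spatial_axes, r) for r in range(len(spatial_axes) + 1))
    preds = []
    for axes in axis_subsets:
        flipped = np.flip(data, axis=axes) if axes else data
        pred = predict_fn(flipped)
        # undo the flip so all predictions are aligned before averaging
        preds.append(np.flip(pred, axis=axes) if axes else pred)
    return np.mean(np.array(preds), axis=0)

# usage sketch: tta_mirror_predict(lambda a: a, np.ones((1, 8, 8, 2)), (1, 2))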
Example #5
def train(logger):
    """
    perform the training routine for a given fold. saves plots and selected parameters to the experiment dir
    specified in the configs.
    """
    logger.info('performing training in {}D over fold {} on experiment {} with model {}'.format(
        cf.dim, cf.fold, cf.exp_dir, cf.model))
    
    writer = SummaryWriter(os.path.join(cf.exp_dir, 'tensorboard'))

    net = model.net(cf, logger).cuda()
    optimizer = torch.optim.Adam(net.parameters(), lr=cf.learning_rate[0], weight_decay=cf.weight_decay)
    model_selector = utils.ModelSelector(cf, logger)
    train_evaluator = Evaluator(cf, logger, mode='train')
    val_evaluator = Evaluator(cf, logger, mode=cf.val_mode)  # cf.val_mode, e.g. 'val_sampling'

    starting_epoch = 1

    # prepare monitoring
    if cf.resume_to_checkpoint:  # default: False
        best_epoch = np.load(cf.resume_to_checkpoint + 'epoch_ranking.npy')[0]
        with open(cf.resume_to_checkpoint + 'monitor_metrics.pickle', 'rb') as f:
            monitor_metrics = pickle.load(f)
        starting_epoch = utils.load_checkpoint(cf.resume_to_checkpoint, net, optimizer)
        logger.info('resumed to checkpoint {} at epoch {}'.format(cf.resume_to_checkpoint, starting_epoch))
        num_batch = starting_epoch * cf.num_train_batches + 1
        num_val = starting_epoch * cf.num_val_batches + 1
    else:
        monitor_metrics = utils.prepare_monitoring(cf)
        num_batch = 0  # global step counters for tensorboard logging
        num_val = 0
    logger.info('loading dataset and initializing batch generators...')
    batch_gen = data_loader.get_train_generators(cf, logger)
    best_train_recall, best_val_recall = 0, 0
    for epoch in range(starting_epoch, cf.num_epochs + 1):

        logger.info('starting training epoch {}'.format(epoch))
        for param_group in optimizer.param_groups:
            param_group['lr'] = cf.learning_rate[epoch - 1]

        start_time = time.time()

        net.train()
        train_results_list = []  # (boxes, pid) pairs collected over this epoch

        for bix in range(cf.num_train_batches):  # e.g. 200
            num_batch += 1
            batch = next(batch_gen['train'])
            # batch keys: data, seg, pid, class_target, bb_target, roi_masks, roi_labels
            # binarize roi labels: any positive class -> 1, everything else -> -1
            for ii, i in enumerate(batch['roi_labels']):
                if i[0] > 0:
                    batch['roi_labels'][ii] = [1]
                else:
                    batch['roi_labels'][ii] = [-1]

            tic_fw = time.time()
            results_dict = net.train_forward(batch)
            tic_bw = time.time()

            optimizer.zero_grad()
            results_dict['torch_loss'].backward()  # total loss
            optimizer.step()

            if num_batch % cf.show_train_images == 0:
                fig = plot_batch_prediction(batch, results_dict, cf, 'train')
                writer.add_figure('/Train/results', fig, num_batch)
                fig.clear()
            logger.info('tr. batch {0}/{1} (ep. {2}) fw {3:.3f}s / bw {4:.3f}s / total {5:.3f}s || '
                        .format(bix + 1, cf.num_train_batches, epoch, tic_bw - tic_fw,
                                time.time() - tic_bw, time.time() - tic_fw) + results_dict['logger_string'])
            writer.add_scalar('Train/total_loss', results_dict['torch_loss'].item(), num_batch)
            writer.add_scalar('Train/rpn_class_loss', results_dict['monitor_losses']['rpn_class_loss'], num_batch)
            writer.add_scalar('Train/rpn_bbox_loss', results_dict['monitor_losses']['rpn_bbox_loss'], num_batch)
            writer.add_scalar('Train/mrcnn_class_loss', results_dict['monitor_losses']['mrcnn_class_loss'], num_batch)
            writer.add_scalar('Train/mrcnn_bbox_loss', results_dict['monitor_losses']['mrcnn_bbox_loss'], num_batch)
            if 'mrcnn' in cf.model_path:
                writer.add_scalar('Train/mrcnn_mask_loss', results_dict['monitor_losses']['mrcnn_mask_loss'], num_batch)
            if 'ufrcnn' in cf.model_path:
                writer.add_scalar('Train/seg_dice_loss', results_dict['monitor_losses']['seg_loss_dice'], num_batch)
            train_results_list.append([results_dict['boxes'], batch['pid']])  # ground truth and detections per pid
            monitor_metrics['train']['monitor_values'][epoch].append(results_dict['monitor_values'])

        count_train = train_evaluator.evaluate_predictions(train_results_list, epoch, cf, flag='train')
        # count_train: [tp_patient, tp_roi, fp_roi, total_num]
        print('tp_patient {}, tp_roi {}, fp_roi {}, total_num {}'.format(count_train[0], count_train[1], count_train[2], count_train[3]))

        precision = count_train[0] / (count_train[0] + count_train[2] + 0.01)  # +0.01 guards against /0
        recall = count_train[0] / count_train[3]
        print('precision:{}, recall:{}'.format(precision, recall))
        monitor_metrics['train']['train_recall'].append(recall)
        monitor_metrics['train']['train_percision'].append(precision)
        writer.add_scalar('Train/train_precision', precision, epoch)
        writer.add_scalar('Train/train_recall', recall, epoch)

        train_time = time.time() - start_time
        print('*' * 50 + ' finished epoch {}'.format(epoch))

        logger.info('starting validation in mode {}.'.format(cf.val_mode))
        with torch.no_grad():
            net.eval()
            if cf.do_validation:
                val_results_list = []
                val_predictor = Predictor(cf, net, logger, mode='val')
                dice_val = []
                for _ in range(batch_gen['n_val']):  # e.g. 50
                    num_val += 1
                    batch = next(batch_gen[cf.val_mode])
                    # binarize roi labels as in training
                    for ii, i in enumerate(batch['roi_labels']):
                        if i[0] > 0:
                            batch['roi_labels'][ii] = [1]
                        else:
                            batch['roi_labels'][ii] = [-1]
                    if cf.val_mode == 'val_patient':
                        results_dict = val_predictor.predict_patient(batch)
                    elif cf.val_mode == 'val_sampling':
                        results_dict = net.train_forward(batch, is_validation=True)
                        if num_val % cf.show_val_images == 0:
                            fig = plot_batch_prediction(batch, results_dict, cf, 'val')
                            writer.add_figure('Val/results', fig, num_val)
                            fig.clear()

                    this_batch_seg_label = torch.FloatTensor(mutils.get_one_hot_encoding(batch['seg'], cf.num_seg_classes)).cuda()
                    dice_loss_fn = DiceLoss()
                    # soft dice = 1 - dice loss on the softmaxed segmentation logits
                    dice = 1 - dice_loss_fn(F.softmax(results_dict['seg_logits'], dim=1), this_batch_seg_label)
                    dice_val.append(dice)
                    val_results_list.append([results_dict['boxes'], batch['pid']])
                    monitor_metrics['val']['monitor_values'][epoch].append(results_dict['monitor_values'])

                count_val = val_evaluator.evaluate_predictions(val_results_list, epoch, cf, flag='val')
                print('tp_patient {}, tp_roi {}, fp_roi {}, total_num {}'.format(count_val[0], count_val[1], count_val[2], count_val[3]))
                precision = count_val[0] / (count_val[0] + count_val[2] + 0.01)
                recall = count_val[0] / count_val[3]
                print('precision:{}, recall:{}'.format(precision, recall))
                monitor_metrics['val']['val_recall'].append(recall)
                monitor_metrics['val']['val_percision'].append(precision)
                writer.add_scalar('Val/val_precision', precision, epoch)
                writer.add_scalar('Val/val_recall', recall, epoch)
                writer.add_scalar('Val/val_dice', sum(dice_val) / float(len(dice_val)), epoch)
                model_selector.run_model_selection(net, optimizer, monitor_metrics, epoch)

            # log epoch timing (monitoring scalars are written to tensorboard above)
            epoch_time = time.time() - start_time
            logger.info('trained epoch {}: took {} sec. ({} train / {} val)'.format(
                epoch, epoch_time, train_time, epoch_time-train_time))
    writer.close()
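DiceLoss is imported from elsewhere in this example, and the validation dice is recovered as 1 - DiceLoss()(softmax(logits), one_hot_target). A common soft-dice formulation consistent with that usage, as a sketch (the exact reduction used in the original class is an assumption):

import torch
import torch.nn as nn

class DiceLoss(nn.Module):
    """Soft dice loss over one-hot targets; 1 - forward(...) is the mean dice."""

    def __init__(self, eps=1e-5):
        super().__init__()
        self.eps = eps

    def forward(self, probs, one_hot_target):
        # probs and one_hot_target: (batch, classes, *spatial_dims)
        spatial = tuple(range(2, probs.dim()))
        intersection = (probs * one_hot_target).sum(dim=spatial)
        denominator = probs.sum(dim=spatial) + one_hot_target.sum(dim=spatial)
        dice_per_class = (2. * intersection + self.eps) / (denominator + self.eps)
        return 1. - dice_per_class.mean()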