Example #1
def plot_train_forward(slices=None):
    with torch.no_grad():
        batch = next(val_gen)
        results_dict = net.train_forward(batch, is_validation=True)  # seg preds are already integer predictions

        out_file = os.path.join(anal_dir, "straight_val_inference_fold_{}".format(str(cf.fold)))
        plg.view_batch(cf, batch, res_dict=results_dict, show_info=False, legend=True,
                       out_file=out_file, slices=slices)
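
A minimal usage sketch for the helper above, assuming the enclosing analysis script has already set up net, val_gen, cf, anal_dir, and plg; the slice indices are purely illustrative:

net.eval()  # make layers like dropout/batch norm behave deterministically
plot_train_forward()                     # default slice selection
plot_train_forward(slices=[12, 24, 36])  # plot specific slices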
Example #2
def plot_forward(pid, slices=None):
    with torch.no_grad():
        batch = batch_gen['test'].generate_train_batch(pid=pid)
        results_dict = net.test_forward(batch)  # seg preds are raw seg_logits here; take the argmax over the class axis

        if 'seg_preds' in results_dict:
            results_dict['seg_preds'] = np.argmax(results_dict['seg_preds'], axis=1)[:, np.newaxis]

        out_file = os.path.join(anal_dir, "straight_inference_fold_{}_pid_{}".format(str(cf.fold), pid))
        plg.view_batch(cf, batch, res_dict=results_dict, show_info=False, legend=True, show_gt_labels=True,
                       out_file=out_file, sample_picks=slices)
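
The argmax step above turns raw segmentation logits into integer class maps while keeping a singleton channel axis. A standalone illustration of just that step; the shapes are assumptions for demonstration, not taken from the source:

import numpy as np

seg_logits = np.random.randn(2, 3, 64, 64)                # (batch, classes, y, x)
seg_preds = np.argmax(seg_logits, axis=1)[:, np.newaxis]  # -> (2, 1, 64, 64), int class ids
print(seg_preds.shape)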
Example #3
def plot_merged_boxes(results_list, pid, plot_mods=False, show_seg_ids="all", show_info=True, show_gt_boxes=True,
                      s_picks=None, vol_slice_picks=None, score_thres=None):
    """Plot the merged (consolidated) box predictions for a single patient.

    :param results_list: list of tuples (results_dict, pid) as produced by the predictor pipeline.
    :param pid: patient id selecting the entry of results_list to plot.
    :param score_thres: box score threshold passed on to plg.view_batch; also encoded in the output file name.
    """
    # pick the results of the requested patient (raises IndexError if pid is not in results_list)
    results_dict = [res_dict for (res_dict, pid_) in results_list if pid_ == pid][0]
    # seg preds are discarded in the predictor pipeline
    # del results_dict['seg_preds']

    batch = batch_gen['test'].generate_train_batch(pid=pid)
    out_file = os.path.join(anal_dir, "merged_boxes_fold_{}_pid_{}_thres_{}.png".format(
        str(cf.fold), pid, str(score_thres).replace(".", "_")))

    utils.save_obj({'res_dict': results_dict, 'batch': batch},
                   os.path.join(anal_dir, "bytes_merged_boxes_fold_{}_pid_{}".format(str(cf.fold), pid)))

    plg.view_batch(cf, batch, res_dict=results_dict, show_info=show_info, legend=False, sample_picks=s_picks,
                   show_seg_pred=True, show_seg_ids=show_seg_ids, show_gt_boxes=show_gt_boxes,
                   box_score_thres=score_thres, vol_slice_picks=vol_slice_picks, show_gt_labels=True,
                   plot_mods=plot_mods, out_file=out_file, has_colorchannels=cf.has_colorchannels, dpi=600)
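
Because the function also pickles the plotted data via utils.save_obj, the bundle can be reloaded for offline re-plotting. A sketch, assuming utils.load_obj is the matching pickle-based loader in this code base:

saved = utils.load_obj(os.path.join(anal_dir, "bytes_merged_boxes_fold_{}_pid_{}".format(str(cf.fold), pid)))
results_dict, batch = saved['res_dict'], saved['batch']
plg.view_batch(cf, batch, res_dict=results_dict, show_info=True, legend=False)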
Example #4
    # cf.data_dir = "experiments/dev_data"

    cf.exp_dir = "experiments/dev/"
    cf.plot_dir = cf.exp_dir + "plots"
    os.makedirs(cf.exp_dir, exist_ok=True)
    cf.fold = 0
    logger = utils.get_logger(cf.exp_dir)
    gens = get_train_generators(cf, logger)
    train_loader = gens['train']
    times = {}  # collect wall-clock timings of the individual steps
    for i in range(0):  # range(0): training-batch preview is disabled; raise the bound to produce batches
        stime = time.time()
        print("producing training batch nr ", i)
        ex_batch = next(train_loader)
        times["train_batch"] = time.time() - stime
        plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exbatch_{}.png".format(i),
                       show_gt_labels=True, vmin=0, show_info=False)


    val_loader = gens['val_sampling']
    stime = time.time()
    for i in range(1):
        ex_batch = next(val_loader)
        times["val_batch"] = time.time() - stime
        stime = time.time()
        #"experiments/dev/dev_exvalbatch_{}.png"
        plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exvalbatch_{}.png".format(i), show_gt_labels=True, vmin=0, show_info=True)
        times["val_plot"] = time.time() - stime
    test_loader = get_test_generator(cf, logger)["test"]
    stime = time.time()
    ex_batch = test_loader.generate_train_batch(pid=None)
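
The excerpt ends right after the test batch is requested. A plausible continuation under the same timing pattern as the train and val blocks above (an assumption, not part of the source) would record and report the collected timings:

    times["test_batch"] = time.time() - stime  # elapsed time for one test batch
    print("recorded timings:")
    for name, secs in times.items():
        print("{}: {:.3f}s".format(name, secs))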
Example #5
def train(cf, logger):
    """
    performs the training routine for a given fold. saves plots and selected parameters to the experiment dir
    specified in the configs. logs to file and tensorboard.
    """
    logger.info('performing training in {}D over fold {} on experiment {} with model {}'.format(
        cf.dim, cf.fold, cf.exp_dir, cf.model))
    logger.time("train_val")

    # -------------- inits and settings -----------------
    net = model.net(cf, logger).cuda()
    if cf.optimizer == "ADAM":
        optimizer = torch.optim.Adam(net.parameters(), lr=cf.learning_rate[0], weight_decay=cf.weight_decay)
    elif cf.optimizer == "SGD":
        optimizer = torch.optim.SGD(net.parameters(), lr=cf.learning_rate[0], weight_decay=cf.weight_decay, momentum=0.3)
    if cf.dynamic_lr_scheduling:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode=cf.scheduling_mode, factor=cf.lr_decay_factor,
                                                                    patience=cf.scheduling_patience)
    model_selector = utils.ModelSelector(cf, logger)

    starting_epoch = 1
    if cf.resume_from_checkpoint:
        starting_epoch = utils.load_checkpoint(cf.resume_from_checkpoint, net, optimizer)
        logger.info('resumed from checkpoint {} at epoch {}'.format(cf.resume_from_checkpoint, starting_epoch))

    # prepare monitoring
    monitor_metrics = utils.prepare_monitoring(cf)

    logger.info('loading dataset and initializing batch generators...')
    batch_gen = data_loader.get_train_generators(cf, logger)

    # -------------- training -----------------
    for epoch in range(starting_epoch, cf.num_epochs + 1):

        logger.info('starting training epoch {}/{}'.format(epoch, cf.num_epochs))
        logger.time("train_epoch")

        net.train()

        train_results_list = []
        train_evaluator = Evaluator(cf, logger, mode='train')

        for i in range(cf.num_train_batches):
            logger.time("train_batch_loadfw")
            batch = next(batch_gen['train'])
            batch_gen['train'].generator.stats['roi_counts'] += batch['roi_counts']
            batch_gen['train'].generator.stats['empty_samples_count'] += batch['empty_samples_count']

            logger.time("train_batch_loadfw")
            logger.time("train_batch_netfw")
            results_dict = net.train_forward(batch)
            logger.time("train_batch_netfw")
            logger.time("train_batch_bw")
            optimizer.zero_grad()
            results_dict['torch_loss'].backward()
            if cf.clip_norm:
                torch.nn.utils.clip_grad_norm_(net.parameters(), cf.clip_norm, norm_type=2) # gradient clipping
            optimizer.step()
            train_results_list.append(({k:v for k,v in results_dict.items() if k != "seg_preds"}, batch["pid"])) # slim res dict
            if not cf.server_env:
                print("\rFinished training batch " +
                      "{}/{} in {:.1f}s ({:.2f}/{:.2f} forw load/net, {:.2f} backw).".format(i+1, cf.num_train_batches,
                                                                                             logger.get_time("train_batch_loadfw")+
                                                                                             logger.get_time("train_batch_netfw")
                                                                                             +logger.time("train_batch_bw"),
                                                                                             logger.get_time("train_batch_loadfw",reset=True),
                                                                                             logger.get_time("train_batch_netfw", reset=True),
                                                                                             logger.get_time("train_batch_bw", reset=True)), end="", flush=True)
        print()

        #--------------- train eval ----------------
        if (epoch - 1) % cf.plot_frequency == 0:
            # view an example batch
            logger.time("train_plot")
            plg.view_batch(cf, batch, results_dict, has_colorchannels=cf.has_colorchannels, show_gt_labels=True,
                           out_file=os.path.join(cf.plot_dir, 'batch_example_train_{}.png'.format(cf.fold)))
            logger.info("generated train-example plot in {:.2f}s".format(logger.time("train_plot")))


        logger.time("evals")
        _, monitor_metrics['train'] = train_evaluator.evaluate_predictions(train_results_list, monitor_metrics['train'])
        logger.time("evals")
        logger.time("train_epoch", toggle=False)
        del train_results_list

        #----------- validation ------------
        logger.info('starting validation in mode {}.'.format(cf.val_mode))
        logger.time("val_epoch")
        with torch.no_grad():
            net.eval()
            val_results_list = []
            val_evaluator = Evaluator(cf, logger, mode=cf.val_mode)
            val_predictor = Predictor(cf, net, logger, mode='val')

            for i in range(batch_gen['n_val']):
                logger.time("val_batch")
                batch = next(batch_gen[cf.val_mode])
                if cf.val_mode == 'val_patient':
                    results_dict = val_predictor.predict_patient(batch)
                elif cf.val_mode == 'val_sampling':
                    results_dict = net.train_forward(batch, is_validation=True)
                val_results_list.append([results_dict, batch["pid"]])
                if not cf.server_env:
                    print("\rFinished validation {} {}/{} in {:.1f}s.".format('patient' if cf.val_mode=='val_patient' else 'batch',
                                                                              i + 1, batch_gen['n_val'],
                                                                              logger.time("val_batch")), end="", flush=True)
            print()

            #------------ val eval -------------
            if (epoch - 1) % cf.plot_frequency == 0:
                logger.time("val_plot")
                plg.view_batch(cf, batch, results_dict, has_colorchannels=cf.has_colorchannels, show_gt_labels=True,
                               out_file=os.path.join(cf.plot_dir, 'batch_example_val_{}.png'.format(cf.fold)))
                logger.info("generated val plot in {:.2f}s".format(logger.time("val_plot")))

            logger.time("evals")
            _, monitor_metrics['val'] = val_evaluator.evaluate_predictions(val_results_list, monitor_metrics['val'])

            model_selector.run_model_selection(net, optimizer, monitor_metrics, epoch)
            del val_results_list
            #----------- monitoring -------------
            monitor_metrics.update({"lr": 
                {str(g) : group['lr'] for (g, group) in enumerate(optimizer.param_groups)}})
            logger.metrics2tboard(monitor_metrics, global_step=epoch)
            logger.time("evals")

            logger.info('finished epoch {}/{}, took {:.2f}s. train total: {:.2f}s, average: {:.2f}s. val total: {:.2f}s, average: {:.2f}s.'.format(
                epoch, cf.num_epochs, logger.get_time("train_epoch")+logger.time("val_epoch"), logger.get_time("train_epoch"),
                logger.get_time("train_epoch", reset=True)/cf.num_train_batches, logger.get_time("val_epoch"),
                logger.get_time("val_epoch", reset=True)/batch_gen["n_val"]))
            logger.info("time for evals: {:.2f}s".format(logger.get_time("evals", reset=True)))

        #-------------- scheduling -----------------
        if not cf.dynamic_lr_scheduling:
            for param_group in optimizer.param_groups:
                param_group['lr'] = cf.learning_rate[epoch-1]
        else:
            scheduler.step(monitor_metrics["val"][cf.scheduling_criterion][-1])

    logger.time("train_val")
    logger.info("Training and validating over {} epochs took {}".format(cf.num_epochs, logger.get_time("train_val", format="hms", reset=True)))
    batch_gen['train'].generator.print_stats(logger, plot=True)
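
A minimal driver sketch for train(); cf is assumed to be a fully prepared configuration object (normally built by the repo's config utilities), and utils.get_logger is used as in Example #4:

if __name__ == "__main__":
    cf.fold = 0                            # train a single fold
    logger = utils.get_logger(cf.exp_dir)  # file/tensorboard logger, as in Example #4
    train(cf, logger)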