Exemple #1
0
def main(config, resume, epoch_logs):
    """Train a CCT model on VOC, optionally restoring per-epoch log history.

    Args:
        config: Parsed training configuration (nested dict). Mutated in place:
            ``n_labeled_examples`` is copied into the ``train_supervised`` and
            ``train_unsupervised`` sections, and ``use_weak_lables`` into
            ``train_unsupervised``.
        resume: Passed through unchanged to ``Trainer``.
        epoch_logs: Optional path to a JSON file with ``"train_results"`` and
            ``"val_results"`` keys; when truthy, those entries seed the
            train/validation loggers so history continues across restarts.
    """
    torch.manual_seed(42)
    val_logger = Logger()
    train_logger = Logger()

    if epoch_logs:
        # Use a context manager so the log file handle is closed promptly.
        with open(epoch_logs, encoding='utf-8') as fh:
            saved_logs = json.load(fh)
        train_logger.entries = saved_logs["train_results"]
        val_logger.entries = saved_logs["val_results"]

    # DATA LOADERS
    config['train_supervised']['n_labeled_examples'] = config[
        'n_labeled_examples']
    config['train_unsupervised']['n_labeled_examples'] = config[
        'n_labeled_examples']
    config['train_unsupervised']['use_weak_lables'] = config['use_weak_lables']
    supervised_loader = dataloaders.VOC(config['train_supervised'])
    unsupervised_loader = dataloaders.VOC(config['train_unsupervised'])
    val_loader = dataloaders.VOC(config['val_loader'])
    # One "epoch" is defined by the number of unsupervised batches.
    iter_per_epoch = len(unsupervised_loader)

    # SUPERVISED LOSS
    if config['model']['sup_loss'] == 'CE':
        sup_loss = CE_loss
    else:
        sup_loss = abCE_loss(iters_per_epoch=iter_per_epoch,
                             epochs=config['trainer']['epochs'],
                             num_classes=val_loader.dataset.num_classes)

    # MODEL
    # The unsupervised consistency weight ramps up over the first
    # ``ramp_up`` fraction of the training epochs.
    rampup_ends = int(config['ramp_up'] * config['trainer']['epochs'])
    cons_w_unsup = consistency_weight(final_w=config['unsupervised_w'],
                                      iters_per_epoch=iter_per_epoch,
                                      rampup_ends=rampup_ends)

    model = models.CCT(num_classes=val_loader.dataset.num_classes,
                       conf=config['model'],
                       cutmix_conf=config["cutmix"],
                       sup_loss=sup_loss,
                       cons_w_unsup=cons_w_unsup,
                       weakly_loss_w=config['weakly_loss_w'],
                       use_weak_lables=config['use_weak_lables'],
                       ignore_index=val_loader.dataset.ignore_index)
    print(f'\n{model}\n')

    # TRAINING
    trainer = Trainer(model=model,
                      resume=resume,
                      config=config,
                      supervised_loader=supervised_loader,
                      unsupervised_loader=unsupervised_loader,
                      val_loader=val_loader,
                      iter_per_epoch=iter_per_epoch,
                      val_logger=val_logger,
                      train_logger=train_logger)

    trainer.train()
Exemple #2
0
def main():
    """Run multi-scale CCT inference over a folder of images and save colorized masks.

    Reads command-line arguments via ``parse_arguments()``:
      * ``args.config`` - path to the JSON training config (required).
      * ``args.images`` - input passed to ``testDataset``.
      * ``args.model``  - checkpoint containing a ``'state_dict'`` entry.
      * ``args.save``   - output directory; created if it does not exist.

    For every image, predictions at scales 0.5-1.5 are combined by
    ``multi_scale_predict``; the per-pixel argmax is colorized with the VOC
    palette and written as ``<image_id>.png`` into ``args.save``.

    Raises:
        ValueError: if no config path was supplied.
    """
    args = parse_arguments()

    # CONFIG
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if not args.config:
        raise ValueError('A config file path must be provided.')
    with open(args.config, encoding='utf-8') as fh:
        config = json.load(fh)
    scales = [0.5, 0.75, 1.0, 1.25, 1.5]

    # DATA
    testdataset = testDataset(args.images)
    loader = DataLoader(testdataset,
                        batch_size=1,
                        shuffle=False,
                        num_workers=1)
    num_classes = config['num_classes']
    palette = get_voc_pallete(num_classes)

    # MODEL
    # Inference uses the plain supervised branch only.
    config['model']['supervised'] = True
    config['model']['semi'] = False
    model = models.CCT(num_classes=num_classes,
                       conf=config['model'],
                       testing=True)
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(args.model)
    # Wrap before loading so keys saved from a DataParallel model (``module.``
    # prefix) match.
    model = torch.nn.DataParallel(model)
    try:
        model.load_state_dict(checkpoint['state_dict'], strict=True)
    except RuntimeError as e:
        # Missing/unexpected keys: report them, then load what does match.
        print(f'Some modules are missing: {e}')
        model.load_state_dict(checkpoint['state_dict'], strict=False)
    model.eval()
    model.cuda()

    # exist_ok avoids the check-then-create race.
    os.makedirs(args.save, exist_ok=True)

    # LOOP OVER THE DATA
    for image, image_id in tqdm(loader, ncols=100):
        image = image.cuda()

        # PREDICT
        with torch.no_grad():
            output = multi_scale_predict(model, image, scales, num_classes)
        # Presumably ``output`` is (num_classes, H, W); argmax over axis 0 picks
        # the class per pixel - TODO confirm against multi_scale_predict.
        prediction = np.asarray(np.argmax(output, axis=0), dtype=np.uint8)

        # SAVE RESULTS
        prediction_im = colorize_mask(prediction, palette)
        prediction_im.save(os.path.join(args.save, image_id[0] + '.png'))
Exemple #3
0
def main(config, resume):
    """Train a CCT segmentation model on the TISSUE dataset.

    Args:
        config: Parsed training configuration (nested dict). Mutated in place:
            the top-level ``n_labeled_examples`` is copied into the
            ``train_supervised`` and ``train_unsupervised`` sections, and
            ``use_weak_lables`` into ``train_unsupervised``.
        resume: Passed through unchanged to ``Trainer``.

    The supervised loss is chosen by ``config['model']['sup_loss']``:
    ``'CE'`` -> ``CE_loss``; ``'FL'`` -> ``FocalLoss`` weighted by
    ``get_alpha`` over the labeled loader; anything else -> ``abCE_loss``.
    """
    torch.manual_seed(42)
    train_logger = Logger()

    # Push shared settings down into each loader's own config section.
    n_labeled = config['n_labeled_examples']
    for section in ('train_supervised', 'train_unsupervised'):
        config[section]['n_labeled_examples'] = n_labeled
    config['train_unsupervised']['use_weak_lables'] = config['use_weak_lables']

    labeled_dl = dataloaders.TISSUE(config['train_supervised'])
    unlabeled_dl = dataloaders.TISSUE(config['train_unsupervised'])
    valid_dl = dataloaders.TISSUE(config['val_loader'])
    # Epoch length is measured in unlabeled batches.
    iter_per_epoch = len(unlabeled_dl)

    loss_name = config['model']['sup_loss']
    if loss_name == 'CE':
        sup_loss = CE_loss
    elif loss_name == 'FL':
        # Class weights come from class occurrences in the labeled data.
        sup_loss = FocalLoss(apply_nonlin=softmax_helper,
                             alpha=get_alpha(labeled_dl),
                             gamma=2,
                             smooth=1e-5)
    else:
        sup_loss = abCE_loss(iters_per_epoch=iter_per_epoch,
                             epochs=config['trainer']['epochs'],
                             num_classes=valid_dl.dataset.num_classes)

    # Consistency weight ramps up over the first ``ramp_up`` fraction of epochs.
    ramp_end_epoch = int(config['ramp_up'] * config['trainer']['epochs'])
    unsup_weight = consistency_weight(final_w=config['unsupervised_w'],
                                      iters_per_epoch=iter_per_epoch,
                                      rampup_ends=ramp_end_epoch)

    model = models.CCT(
        num_classes=valid_dl.dataset.num_classes,
        conf=config['model'],
        sup_loss=sup_loss,
        cons_w_unsup=unsup_weight,
        weakly_loss_w=config['weakly_loss_w'],
        use_weak_lables=config['use_weak_lables'],
        ignore_index=valid_dl.dataset.ignore_index,
    )
    print(f'\n{model}\n')

    Trainer(
        model=model,
        resume=resume,
        config=config,
        supervised_loader=labeled_dl,
        unsupervised_loader=unlabeled_dl,
        val_loader=valid_dl,
        iter_per_epoch=iter_per_epoch,
        train_logger=train_logger,
    ).train()
Exemple #4
0
def main(config, resume):
    """Train a CCT model using custom (``CUS_loader``) datasets plus a sequence loader.

    Args:
        config: Parsed training configuration (nested dict). Each of
            ``train_supervised``, ``train_unsupervised``,
            ``train_unsupervised_sequence`` and ``val_loader`` is handed to
            its own ``dataloaders.CUS_loader``. Unlike the VOC variants, no
            top-level settings are copied into the loader sections here.
        resume: Passed through unchanged to ``Trainer``.
    """
    torch.manual_seed(42)
    train_logger = Logger()

    # DATA LOADERS
    supervised_loader = dataloaders.CUS_loader(config['train_supervised'])
    unsupervised_loader = dataloaders.CUS_loader(config['train_unsupervised'])
    sequence_loader = dataloaders.CUS_loader(
        config['train_unsupervised_sequence'])
    val_loader = dataloaders.CUS_loader(config['val_loader'])

    # One "epoch" is defined by the number of unsupervised batches.
    iter_per_epoch = len(unsupervised_loader)

    # SUPERVISED LOSS
    if config['model']['sup_loss'] == 'CE':
        sup_loss = CE_loss
    else:
        sup_loss = abCE_loss(iters_per_epoch=iter_per_epoch,
                             epochs=config['trainer']['epochs'],
                             num_classes=val_loader.dataset.num_classes)

    # MODEL
    # The unsupervised consistency weight ramps up over the first
    # ``ramp_up`` fraction of the training epochs.
    rampup_ends = int(config['ramp_up'] * config['trainer']['epochs'])
    cons_w_unsup = consistency_weight(final_w=config['unsupervised_w'],
                                      iters_per_epoch=iter_per_epoch,
                                      rampup_ends=rampup_ends)

    model = models.CCT(num_classes=val_loader.dataset.num_classes,
                       conf=config['model'],
                       sup_loss=sup_loss,
                       cons_w_unsup=cons_w_unsup,
                       weakly_loss_w=config['weakly_loss_w'],
                       use_weak_lables=config['use_weak_lables'],
                       ignore_index=val_loader.dataset.ignore_index)
    print(f'\n{model}\n')

    # TRAINING
    trainer = Trainer(model=model,
                      resume=resume,
                      config=config,
                      supervised_loader=supervised_loader,
                      unsupervised_loader=unsupervised_loader,
                      sequence_loader=sequence_loader,
                      val_loader=val_loader,
                      iter_per_epoch=iter_per_epoch,
                      train_logger=train_logger)

    trainer.train()
Exemple #5
0
def main(config, resume):
    """Train a CCT model on VOC with a configurable supervised loss.

    Args:
        config: Parsed training configuration (nested dict). Mutated in place:
            ``n_labeled_examples`` is copied into the ``train_supervised`` and
            ``train_unsupervised`` sections, and ``use_weak_lables`` into
            ``train_unsupervised``. When ``config['model']['sup_loss'] == 'FL'``,
            the top-level ``config['pixelcount']`` must hold the focal-loss
            class weights (e.g. one pixel count per class).
        resume: Passed through unchanged to ``Trainer``.

    Raises:
        ValueError: if focal loss is selected but ``pixelcount`` is not configured.
    """
    torch.manual_seed(42)
    train_logger = Logger()

    # DATA LOADERS
    config['train_supervised']['n_labeled_examples'] = config['n_labeled_examples']
    config['train_unsupervised']['n_labeled_examples'] = config['n_labeled_examples']
    config['train_unsupervised']['use_weak_lables'] = config['use_weak_lables']
    supervised_loader = dataloaders.VOC(config['train_supervised'])
    unsupervised_loader = dataloaders.VOC(config['train_unsupervised'])
    val_loader = dataloaders.VOC(config['val_loader'])
    # One "epoch" is defined by the number of unsupervised batches.
    iter_per_epoch = len(unsupervised_loader)

    # SUPERVISED LOSS
    if config['model']['sup_loss'] == 'CE':
        sup_loss = CE_loss
    elif config['model']['sup_loss'] == 'FL':
        # Focal loss needs per-class weights, e.g. per-class pixel counts such
        # as [44502000, 49407, 1279000, 969250]; they must come from the config.
        pixelcount = config.get('pixelcount')
        if pixelcount is None:
            raise ValueError(
                "sup_loss 'FL' requires config['pixelcount'] "
                "(one pixel count per class)")
        sup_loss = FocalLoss(apply_nonlin=softmax_helper,
                             alpha=pixelcount,
                             gamma=2,
                             smooth=1e-5)
    else:
        sup_loss = abCE_loss(iters_per_epoch=iter_per_epoch,
                             epochs=config['trainer']['epochs'],
                             num_classes=val_loader.dataset.num_classes)

    # MODEL
    # The unsupervised consistency weight ramps up over the first
    # ``ramp_up`` fraction of the training epochs.
    rampup_ends = int(config['ramp_up'] * config['trainer']['epochs'])
    cons_w_unsup = consistency_weight(final_w=config['unsupervised_w'],
                                      iters_per_epoch=iter_per_epoch,
                                      rampup_ends=rampup_ends)

    model = models.CCT(num_classes=val_loader.dataset.num_classes,
                       conf=config['model'],
                       sup_loss=sup_loss,
                       cons_w_unsup=cons_w_unsup,
                       weakly_loss_w=config['weakly_loss_w'],
                       use_weak_lables=config['use_weak_lables'],
                       ignore_index=val_loader.dataset.ignore_index)
    print(f'\n{model}\n')

    # TRAINING
    trainer = Trainer(
        model=model,
        resume=resume,
        config=config,
        supervised_loader=supervised_loader,
        unsupervised_loader=unsupervised_loader,
        val_loader=val_loader,
        iter_per_epoch=iter_per_epoch,
        train_logger=train_logger)

    trainer.train()
Exemple #6
0
def main(config, resume, site):
    """Train a CCT model on the Prostate dataset for one acquisition ``site``.

    Args:
        config: Parsed training configuration (nested dict). Mutated in place:
            ``use_weak_lables`` is copied into ``config['train_unsupervised']``.
        resume: Passed through unchanged to ``Trainer``.
        site: Forwarded as the first argument to every ``dataloaders.Prostate``.
    """
    torch.manual_seed(42)
    train_logger = Logger()

    # DATA LOADERS
    config['train_unsupervised']['use_weak_lables'] = config['use_weak_lables']
    supervised_loader = dataloaders.Prostate(site, config['train_supervised'])
    unsupervised_loader = dataloaders.Prostate(site, config['train_unsupervised'])
    val_loader = dataloaders.Prostate(site, config['val_loader'])
    # One "epoch" is defined by the number of unsupervised batches.
    iter_per_epoch = len(unsupervised_loader)
    # Single source of truth for the class count, shared by loss and model.
    num_classes = val_loader.dataset.num_classes

    # SUPERVISED LOSS
    if config['model']['sup_loss'] == 'CE':
        sup_loss = CE_loss
    else:
        sup_loss = abCE_loss(iters_per_epoch=iter_per_epoch,
                             epochs=config['trainer']['epochs'],
                             num_classes=num_classes)

    # MODEL
    # The unsupervised consistency weight ramps up over the first
    # ``ramp_up`` fraction of the training epochs.
    rampup_ends = int(config['ramp_up'] * config['trainer']['epochs'])
    cons_w_unsup = consistency_weight(final_w=config['unsupervised_w'],
                                      iters_per_epoch=iter_per_epoch,
                                      rampup_ends=rampup_ends)
    # NOTE(review): ``encoder`` is not defined in this function; it is assumed
    # to be a module-level name defined elsewhere in the file - confirm.
    model = models.CCT(encoder,
                       num_classes=num_classes,
                       conf=config['model'],
                       sup_loss=sup_loss,
                       cons_w_unsup=cons_w_unsup,
                       weakly_loss_w=config['weakly_loss_w'],
                       use_weak_lables=config['use_weak_lables'])
    model.float()
    print(f'\n{model}\n')

    # TRAINING
    trainer = Trainer(
        model=model,
        resume=resume,
        config=config,
        supervised_loader=supervised_loader,
        unsupervised_loader=unsupervised_loader,
        val_loader=val_loader,
        iter_per_epoch=iter_per_epoch,
        train_logger=train_logger)

    trainer.train()
Exemple #7
0
def create_model(config, encoder, Num_classes, iter_per_epoch):
    """Assemble a CCT model with its supervised loss and consistency schedule.

    Args:
        config: Parsed training configuration (nested dict); reads
            ``model``, ``trainer.epochs``, ``ramp_up``, ``unsupervised_w``,
            ``weakly_loss_w`` and ``use_weak_lables``.
        encoder: Passed as the first positional argument to ``models.CCT``.
        Num_classes: Number of segmentation classes.
        iter_per_epoch: Iterations per epoch, used by ``abCE_loss`` and the
            consistency-weight schedule.

    Returns:
        The constructed model after ``.float()`` has been applied.
    """
    model_cfg = config['model']

    # 'CE' selects plain cross-entropy; any other value selects abCE_loss.
    if model_cfg['sup_loss'] != 'CE':
        sup_loss = abCE_loss(iters_per_epoch=iter_per_epoch,
                             epochs=config['trainer']['epochs'],
                             num_classes=Num_classes)
    else:
        sup_loss = CE_loss

    # Consistency weight ramps up over the first ``ramp_up`` fraction of epochs.
    ramp_end_epoch = int(config['ramp_up'] * config['trainer']['epochs'])
    unsup_weight = consistency_weight(final_w=config['unsupervised_w'],
                                      iters_per_epoch=iter_per_epoch,
                                      rampup_ends=ramp_end_epoch)

    cct = models.CCT(
        encoder,
        num_classes=Num_classes,
        conf=model_cfg,
        sup_loss=sup_loss,
        cons_w_unsup=unsup_weight,
        weakly_loss_w=config['weakly_loss_w'],
        use_weak_lables=config['use_weak_lables'],
    )
    return cct.float()