Example #1
def main(config, resume, epoch_logs):
    torch.manual_seed(42)
    val_logger = Logger()
    train_logger = Logger()

    # Optionally restore logger state from a previous run's epoch logs.
    if epoch_logs:
        with open(epoch_logs) as f:
            epoch_logs = json.load(f)
        train_logger.entries = epoch_logs["train_results"]
        val_logger.entries = epoch_logs["val_results"]
    del epoch_logs  # the loaded entries now live on the loggers

    # DATA LOADERS
    config['train_supervised']['n_labeled_examples'] = config[
        'n_labeled_examples']
    config['train_unsupervised']['n_labeled_examples'] = config[
        'n_labeled_examples']
    config['train_unsupervised']['use_weak_lables'] = config['use_weak_lables']
    supervised_loader = dataloaders.VOC(config['train_supervised'])
    unsupervised_loader = dataloaders.VOC(config['train_unsupervised'])
    val_loader = dataloaders.VOC(config['val_loader'])
    iter_per_epoch = len(unsupervised_loader)

    # SUPERVISED LOSS
    if config['model']['sup_loss'] == 'CE':
        sup_loss = CE_loss
    else:
        sup_loss = abCE_loss(iters_per_epoch=iter_per_epoch,
                             epochs=config['trainer']['epochs'],
                             num_classes=val_loader.dataset.num_classes)

    # MODEL
    rampup_ends = int(config['ramp_up'] * config['trainer']['epochs'])
    cons_w_unsup = consistency_weight(final_w=config['unsupervised_w'],
                                      iters_per_epoch=len(unsupervised_loader),
                                      rampup_ends=rampup_ends)

    model = models.CCT(num_classes=val_loader.dataset.num_classes,
                       conf=config['model'],
                       cutmix_conf=config["cutmix"],
                       sup_loss=sup_loss,
                       cons_w_unsup=cons_w_unsup,
                       weakly_loss_w=config['weakly_loss_w'],
                       use_weak_lables=config['use_weak_lables'],
                       ignore_index=val_loader.dataset.ignore_index)
    print(f'\n{model}\n')

    # TRAINING
    trainer = Trainer(model=model,
                      resume=resume,
                      config=config,
                      supervised_loader=supervised_loader,
                      unsupervised_loader=unsupervised_loader,
                      val_loader=val_loader,
                      iter_per_epoch=iter_per_epoch,
                      val_logger=val_logger,
                      train_logger=train_logger)

    trainer.train()
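
Every example builds a consistency_weight schedule, but its implementation is not shown. Below is a minimal sketch of such a schedule, assuming the exponential ramp-up of Laine & Aila (2017) and a __call__(epoch, curr_iter) interface; both are assumptions, not the project's confirmed API.

import math

# Sketch only: ramps the unsupervised-loss weight from 0 to final_w over
# the first rampup_ends epochs; the real project class may differ.
class consistency_weight:
    def __init__(self, final_w, iters_per_epoch, rampup_ends):
        self.final_w = final_w
        self.iters_per_epoch = iters_per_epoch
        self.rampup_length = rampup_ends * iters_per_epoch  # in iterations

    def __call__(self, epoch, curr_iter):
        step = epoch * self.iters_per_epoch + curr_iter
        if self.rampup_length == 0 or step >= self.rampup_length:
            return self.final_w
        t = step / self.rampup_length
        # exp(-5 * (1 - t)^2) rises smoothly from ~0 to 1 as t goes 0 -> 1
        return self.final_w * math.exp(-5.0 * (1.0 - t) ** 2)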
Example #2
def main(config, resume):
    torch.manual_seed(42)
    train_logger = Logger()

    # DATA LOADERS
    config['train_supervised']['n_labeled_examples'] = config[
        'n_labeled_examples']
    config['train_unsupervised']['n_labeled_examples'] = config[
        'n_labeled_examples']
    config['train_unsupervised']['use_weak_lables'] = config['use_weak_lables']
    supervised_loader = dataloaders.TISSUE(config['train_supervised'])
    unsupervised_loader = dataloaders.TISSUE(config['train_unsupervised'])
    val_loader = dataloaders.TISSUE(config['val_loader'])
    iter_per_epoch = len(unsupervised_loader)

    # SUPERVISED LOSS
    if config['model']['sup_loss'] == 'CE':
        sup_loss = CE_loss
    elif config['model']['sup_loss'] == 'FL':
        alpha = get_alpha(supervised_loader)  # calculate class occurrences
        sup_loss = FocalLoss(apply_nonlin=softmax_helper,
                             alpha=alpha,
                             gamma=2,
                             smooth=1e-5)
    else:
        sup_loss = abCE_loss(iters_per_epoch=iter_per_epoch,
                             epochs=config['trainer']['epochs'],
                             num_classes=val_loader.dataset.num_classes)

    # MODEL
    rampup_ends = int(config['ramp_up'] * config['trainer']['epochs'])
    cons_w_unsup = consistency_weight(final_w=config['unsupervised_w'],
                                      iters_per_epoch=len(unsupervised_loader),
                                      rampup_ends=rampup_ends)

    model = models.CCT(num_classes=val_loader.dataset.num_classes,
                       conf=config['model'],
                       sup_loss=sup_loss,
                       cons_w_unsup=cons_w_unsup,
                       weakly_loss_w=config['weakly_loss_w'],
                       use_weak_lables=config['use_weak_lables'],
                       ignore_index=val_loader.dataset.ignore_index)
    print(f'\n{model}\n')

    # TRAINING
    trainer = Trainer(model=model,
                      resume=resume,
                      config=config,
                      supervised_loader=supervised_loader,
                      unsupervised_loader=unsupervised_loader,
                      val_loader=val_loader,
                      iter_per_epoch=iter_per_epoch,
                      train_logger=train_logger)

    trainer.train()
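
Example #2 derives focal-loss alpha weights from the labeled data via get_alpha, whose body is not shown. Here is a minimal sketch that counts per-class pixels and returns inverse-frequency weights; that the loader yields (image, label) batches of integer class maps, and the ignore value, are assumptions.

import torch

def get_alpha(supervised_loader, ignore_index=255):
    # Sketch: accumulate per-class pixel counts across the labeled set.
    counts = torch.zeros(0)
    for _, labels in supervised_loader:  # assumed (image, label) batches
        labels = labels[labels != ignore_index].long().flatten()
        batch = torch.bincount(labels).float()
        if batch.numel() > counts.numel():
            batch[:counts.numel()] += counts
            counts = batch
        else:
            counts[:batch.numel()] += batch
    freq = counts / counts.sum()
    return (1.0 / (freq + 1e-8)).tolist()  # rarer classes get larger alpha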
Example #3
File: train.py  Project: saramsv/TCT
def main(config, resume):
    torch.manual_seed(42)
    train_logger = Logger()

    # DATA LOADERS
    # The VOC loaders used in the other examples are replaced here by a
    # custom loader (CUS_loader) that also serves an image-sequence split.

    supervised_loader = dataloaders.CUS_loader(config['train_supervised'])
    unsupervised_loader = dataloaders.CUS_loader(config['train_unsupervised'])
    sequence_loader = dataloaders.CUS_loader(
        config['train_unsupervised_sequence'])
    val_loader = dataloaders.CUS_loader(config['val_loader'])

    iter_per_epoch = len(unsupervised_loader)

    # SUPERVISED LOSS
    if config['model']['sup_loss'] == 'CE':
        sup_loss = CE_loss
    else:
        sup_loss = abCE_loss(iters_per_epoch=iter_per_epoch,
                             epochs=config['trainer']['epochs'],
                             num_classes=val_loader.dataset.num_classes)

    # MODEL
    rampup_ends = int(config['ramp_up'] * config['trainer']['epochs'])
    cons_w_unsup = consistency_weight(final_w=config['unsupervised_w'],
                                      iters_per_epoch=len(unsupervised_loader),
                                      rampup_ends=rampup_ends)

    model = models.CCT(num_classes=val_loader.dataset.num_classes,
                       conf=config['model'],
                       sup_loss=sup_loss,
                       cons_w_unsup=cons_w_unsup,
                       weakly_loss_w=config['weakly_loss_w'],
                       use_weak_lables=config['use_weak_lables'],
                       ignore_index=val_loader.dataset.ignore_index)
    print(f'\n{model}\n')

    # TRAINING
    trainer = Trainer(model=model,
                      resume=resume,
                      config=config,
                      supervised_loader=supervised_loader,
                      unsupervised_loader=unsupervised_loader,
                      sequence_loader=sequence_loader,
                      val_loader=val_loader,
                      iter_per_epoch=iter_per_epoch,
                      train_logger=train_logger)

    trainer.train()
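
Example #3 is attributed to train.py in saramsv/TCT. Scripts like these are usually driven by a small CLI entry point that loads the JSON config and forwards it to main; a minimal sketch follows (the flag names and default config path are assumptions):

import argparse
import json

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Semi-supervised training')
    parser.add_argument('-c', '--config', default='configs/config.json',
                        type=str, help='path to the JSON config file')
    parser.add_argument('-r', '--resume', default=None, type=str,
                        help='checkpoint to resume training from')
    args = parser.parse_args()

    with open(args.config) as f:
        config = json.load(f)
    main(config, args.resume)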
Example #4
def main(config, resume):
    torch.manual_seed(42)
    train_logger = Logger()

    # DATA LOADERS
    config['train_supervised']['n_labeled_examples'] = config['n_labeled_examples']
    config['train_unsupervised']['n_labeled_examples'] = config['n_labeled_examples']
    config['train_unsupervised']['use_weak_lables'] = config['use_weak_lables']
    supervised_loader = dataloaders.VOC(config['train_supervised'])
    unsupervised_loader = dataloaders.VOC(config['train_unsupervised'])
    val_loader = dataloaders.VOC(config['val_loader'])
    iter_per_epoch = len(unsupervised_loader)

    # SUPERVISED LOSS
    if config['model']['sup_loss'] == 'CE':
        sup_loss = CE_loss
    elif config['model']['sup_loss'] == 'FL':
        # Placeholder per-class pixel counts from the original comment;
        # in practice compute them from the training set (see the
        # get_alpha sketch after Example #2).
        pixelcount = [44502000, 49407, 1279000, 969250]
        sup_loss = FocalLoss(apply_nonlin=softmax_helper, alpha=pixelcount,
                             gamma=2, smooth=1e-5)
    else:
        sup_loss = abCE_loss(iters_per_epoch=iter_per_epoch,
                             epochs=config['trainer']['epochs'],
                             num_classes=val_loader.dataset.num_classes)

    # MODEL
    rampup_ends = int(config['ramp_up'] * config['trainer']['epochs'])
    cons_w_unsup = consistency_weight(final_w=config['unsupervised_w'],
                                      iters_per_epoch=len(unsupervised_loader),
                                      rampup_ends=rampup_ends)

    model = models.CCT(num_classes=val_loader.dataset.num_classes,
                       conf=config['model'],
                       sup_loss=sup_loss,
                       cons_w_unsup=cons_w_unsup,
                       weakly_loss_w=config['weakly_loss_w'],
                       use_weak_lables=config['use_weak_lables'],
                       ignore_index=val_loader.dataset.ignore_index)
    print(f'\n{model}\n')

    # TRAINING
    trainer = Trainer(
        model=model,
        resume=resume,
        config=config,
        supervised_loader=supervised_loader,
        unsupervised_loader=unsupervised_loader,
        val_loader=val_loader,
        iter_per_epoch=iter_per_epoch,
        train_logger=train_logger)

    trainer.train()
Example #5
def main(config, resume, site):
    torch.manual_seed(42)
    train_logger = Logger()

    # DATA LOADERS
    config['train_unsupervised']['use_weak_lables'] = config['use_weak_lables']
    supervised_loader = dataloaders.Prostate(site, config['train_supervised'])
    unsupervised_loader = dataloaders.Prostate(site, config['train_unsupervised'])
    val_loader = dataloaders.Prostate(site, config['val_loader'])
    iter_per_epoch = len(unsupervised_loader)

    # SUPERVISED LOSS
    if config['model']['sup_loss'] == 'CE':
        sup_loss = CE_loss
    else:
        sup_loss = abCE_loss(iters_per_epoch=iter_per_epoch, epochs=config['trainer']['epochs'],
                             num_classes=val_loader.dataset.num_classes)

    # MODEL
    rampup_ends = int(config['ramp_up'] * config['trainer']['epochs'])
    cons_w_unsup = consistency_weight(final_w=config['unsupervised_w'], iters_per_epoch=len(unsupervised_loader),
                                      rampup_ends=rampup_ends)

    # `encoder` and `Num_classes` are assumed to be defined at module scope.
    model = models.CCT(encoder, num_classes=Num_classes, conf=config['model'],
                       sup_loss=sup_loss, cons_w_unsup=cons_w_unsup,
                       weakly_loss_w=config['weakly_loss_w'],
                       use_weak_lables=config['use_weak_lables'])
    model.float()
    print(f'\n{model}\n')

    # TRAINING
    trainer = Trainer(
        model=model,
        resume=resume,
        config=config,
        supervised_loader=supervised_loader,
        unsupervised_loader=unsupervised_loader,
        val_loader=val_loader,
        iter_per_epoch=iter_per_epoch,
        train_logger=train_logger)

    trainer.train()
Example #6
def create_model(config, encoder, Num_classes, iter_per_epoch):
    """Build a CCT model around `encoder` with the configured supervised
    loss and unsupervised consistency-weight schedule."""
    # SUPERVISED LOSS
    if config['model']['sup_loss'] == 'CE':
        sup_loss = CE_loss
    else:
        sup_loss = abCE_loss(iters_per_epoch=iter_per_epoch,
                             epochs=config['trainer']['epochs'],
                             num_classes=Num_classes)

    # MODEL
    rampup_ends = int(config['ramp_up'] * config['trainer']['epochs'])
    cons_w_unsup = consistency_weight(final_w=config['unsupervised_w'],
                                      iters_per_epoch=iter_per_epoch,
                                      rampup_ends=rampup_ends)

    model = models.CCT(encoder,
                       num_classes=Num_classes,
                       conf=config['model'],
                       sup_loss=sup_loss,
                       cons_w_unsup=cons_w_unsup,
                       weakly_loss_w=config['weakly_loss_w'],
                       use_weak_lables=config['use_weak_lables'])

    return model.float()
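
Example #6 factors model construction out of main. A hypothetical call site is sketched below; the encoder constructor, class count, iteration count, and config path are placeholders, not names taken from the examples above.

import json

with open('configs/config.json') as f:
    config = json.load(f)

# Placeholder encoder and sizes; both depend on the surrounding project.
encoder = models.Encoder(pretrained=True)
model = create_model(config, encoder, Num_classes=2, iter_per_epoch=100)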