Example #1
@classmethod
def setup_arch_flags(cls):
    super().setup_arch_flags()
    flags.DEFINE_integer('max_epoch',
                         default=200,
                         help='maximum number of training epochs')
    flags.DEFINE_multi_integer('milestones',
                               default=[
                                   10, 20, 30, 40, 50, 60, 70, 80, 90, 100,
                                   110, 120, 140, 160
                               ],
                               help='milestones for lr decay')
    flags.DEFINE_float('gamma', default=0.85, help='gamma for lr decay')
    flags.DEFINE_string('device', default='cpu', help='cpu or cuda?')
    flags.DEFINE_integer('printfreq',
                         default=5,
                         help='how many progress reports per epoch')
    flags.DEFINE_integer('num_admm_innerloop',
                         default=2,
                         help='number of ADMM inner-loop iterations')
    flags.DEFINE_integer('num_workers',
                         default=1,
                         help='number of dataloader workers')
    flags.DEFINE_integer('batch_size',
                         default=1,
                         help='batch size')
    flags.DEFINE_boolean('vis_during_training',
                         default=False,
                         help='plot images with matplotlib during training')
    flags.DEFINE_boolean('group',
                         default=False,
                         help='group slices by patient to compute 3D dice')
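For reference, flags declared with absl's flags.DEFINE_* in these hooks all land in the global flags.FLAGS registry and become readable as attributes once app.run() has parsed the command line. A minimal, self-contained sketch (the flag names here are illustrative, not taken from the project above):

from absl import app, flags

FLAGS = flags.FLAGS

flags.DEFINE_integer('max_epoch', default=200, help='maximum number of epochs')
flags.DEFINE_string('device', default='cpu', help='cpu or cuda?')


def main(argv):
    del argv  # unused, as in Example #4 below
    # after parsing, every registered flag is an attribute of FLAGS
    print(FLAGS.max_epoch, FLAGS.device)


if __name__ == '__main__':
    app.run(main)

Invoked with, e.g., --device=cuda --max_epoch=50, this prints 50 cuda.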
Example #2
@classmethod
def setup_arch_flags(cls):
    """Set up the arch_hparams."""
    flags.DEFINE_float('weight_decay',
                       default=0,
                       help='weight decay (L2 regularization)')
    flags.DEFINE_float('lr', default=0.001, help='learning rate')
    flags.DEFINE_boolean('amsgrad', default=False,
                         help='use the AMSGrad variant of Adam')
    flags.DEFINE_integer('optim_inner_loop_num',
                         default=5,
                         help='number of optimizer inner-loop iterations')
    flags.DEFINE_string('arch', default='enet', help='architecture name')
    flags.DEFINE_integer('num_classes', default=2, help='number of classes')
    flags.DEFINE_string('method', default='admm_gc_size', help='method name')
    flags.DEFINE_boolean('ignore_negative',
                         default=False,
                         help='ignore negative examples during training')
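Example #2's flags map one-to-one onto torch.optim.Adam's keyword arguments, which is presumably how they are consumed; a hedged sketch under that assumption (build_optimizer and model are placeholders, not project code):

import torch
from absl import flags

FLAGS = flags.FLAGS


def build_optimizer(model):
    # Adam accepts lr, weight_decay and amsgrad keyword arguments,
    # mirroring the flags defined above
    return torch.optim.Adam(model.parameters(),
                            lr=FLAGS.lr,
                            weight_decay=FLAGS.weight_decay,
                            amsgrad=FLAGS.amsgrad)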
Example #3
@classmethod
def setup_arch_flags(cls):
    super().setup_arch_flags()
    flags.DEFINE_boolean(
        'individual_size_constraint',
        default=True,
        help='individual size constraint for each input image')
    flags.DEFINE_float('eps', default=0.001, help='individual size eps')
    flags.DEFINE_integer(
        'global_upbound',
        default=2000,
        help='global upper bound if individual_size_constraint is False')
    flags.DEFINE_integer(
        'global_lowbound',
        default=20,
        help='global lower bound if individual_size_constraint is False')
    SizeConstraint.setup_arch_flag()
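One caveat with chaining these hooks through super(): absl registers every flag globally, so defining the same flag name twice (or calling the same setup method twice) raises flags.DuplicateFlagError. That is presumably why Example #4 below enables only one of the class setup calls and leaves the others commented out. A small guard sketch (define_integer_once is a hypothetical helper, not part of the project):

from absl import flags


def define_integer_once(name, default, help_str):
    # FlagValues supports membership tests, so we can skip
    # re-registering a flag that a parent class already defined
    if name not in flags.FLAGS:
        flags.DEFINE_integer(name, default=default, help=help_str)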
Example #4
import torch
from absl import app, flags

# helpers such as build_datasets, extract_from_big_dict, get_arch, get_method,
# get_loss_fn and the ADMM classes come from the project's own modules


def check_consistance(hparams):
    # the pipeline only supports a batch size of 1
    assert hparams['batch_size'] == 1, hparams['batch_size']

def run(argv):
    del argv

    hparams = flags.FLAGS.flag_values_dict()
    check_consistance(hparams)
    train_dataset, val_dataset = build_datasets(hparams)

    arch_hparams = extract_from_big_dict(hparams, AdmmGCSize.arch_hparam_keys)
    torchnet = get_arch(arch_hparams['arch'], arch_hparams)

    admm = get_method(hparams['method'], torchnet, **hparams)
    criterion = get_loss_fn(hparams['loss'])
    trainer = ADMM_Trainer(admm, [train_dataset, val_dataset], criterion, hparams)
    trainer.start_training()


if __name__ == '__main__':
    torch.manual_seed(41)
    flags.DEFINE_string('dataroot', default='cardiac', help='the name of the dataset')
    flags.DEFINE_boolean('data_aug', default=False, help='data_augmentation')
    flags.DEFINE_string('loss', default='partial_ce', help='loss used in admm loop')
    flags.DEFINE_boolean('data_equ', default=False, help='data equalization')
    # AdmmSize.setup_arch_flags()
    # AdmmGCSize.setup_arch_flags()
    # ADMM_size_inequality.setup_arch_flags()
    ADMM_reg_size_inequality.setup_arch_flags()
    ADMM_Trainer.setup_arch_flags()
    app.run(run)
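extract_from_big_dict is a project helper whose body isn't shown here; from its call site it plausibly filters the full flag dict down to the keys a class declares in arch_hparam_keys. A hypothetical reconstruction under that assumption:

def extract_from_big_dict(big_dict, keys):
    # hypothetical sketch: keep only the hyperparameters whose key
    # appears in `keys` (e.g. AdmmGCSize.arch_hparam_keys)
    return {k: v for k, v in big_dict.items() if k in keys}

Any of the flags defined above can then be overridden on the command line, e.g. --loss=partial_ce --data_aug --max_epoch=100.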