def setup_arch_flags(cls):
    """Register trainer-level flags on top of the parent class's flags.

    Calls the parent ``setup_arch_flags`` first, then defines flags for the
    epoch budget, the lr_decay schedule (milestones, gamma), the device,
    print frequency, ADMM inner-loop count, worker count, batch size,
    in-training visualisation and per-patient grouping for 3D dice.
    """
    super().setup_arch_flags()
    flags.DEFINE_integer('max_epoch', default=200, help='number of max_epoch')
    # Parameters of the lr_decay schedule.
    flags.DEFINE_multi_integer(
        'milestones',
        default=[10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 140, 160],
        help='milestones for lr_decay')
    flags.DEFINE_float('gamma', default=0.85, help='gamma for lr_decay')
    flags.DEFINE_string('device', default='cpu', help='cpu or cuda?')
    flags.DEFINE_integer('printfreq', default=5, help='how many output for an epoch')
    flags.DEFINE_integer('num_admm_innerloop', default=2, help='number of admm inner loops')
    flags.DEFINE_integer('num_workers', default=1, help='number of data loader workers')
    # NOTE(review): a consistency check seen alongside this code asserts
    # batch_size == 1 — confirm before raising this default.
    flags.DEFINE_integer('batch_size', default=1, help='batch size')
    flags.DEFINE_boolean('vis_during_training', default=False,
                         help='matplotlib plot image during training')
    flags.DEFINE_boolean('group', default=False, help='group patient to perform 3D dice')
def setup_arch_flags(cls):
    """Setup the arch_hparams.

    Defines the base flags: optimizer hyperparameters (weight_decay, lr,
    amsgrad), the number of optimisation inner-loop steps, the network
    architecture name, the number of classes, the method name, and whether
    negative examples are ignored during training.
    """
    # absl's FloatParser converts numeric defaults to float; 0.0 is explicit.
    flags.DEFINE_float('weight_decay', default=0.0,
                       help='weight decay (L2 penalty) of the optimizer')
    flags.DEFINE_float('lr', default=0.001, help='learning rate')
    flags.DEFINE_boolean('amsgrad', default=False, help='amsgrad')
    flags.DEFINE_integer('optim_inner_loop_num', default=5, help='optim_inner_loop_num')
    flags.DEFINE_string('arch', default='enet', help='arch_name')
    flags.DEFINE_integer('num_classes', default=2, help='num of classes')
    flags.DEFINE_string('method', default='admm_gc_size', help='method name')
    flags.DEFINE_boolean('ignore_negative', default=False,
                         help='ignore negative examples in the training')
def setup_arch_flags(cls):
    """Extend the parent's flags with the size-constraint options.

    Order of registration: the parent class's flags, then
    ``individual_size_constraint`` and ``eps``, then the two global bounds,
    and finally whatever ``SizeConstraint.setup_arch_flag()`` registers.
    """
    super().setup_arch_flags()

    flags.DEFINE_boolean(
        'individual_size_constraint',
        default=True,
        help='Individual size constraint for each input image',
    )
    flags.DEFINE_float('eps', default=0.001, help='Individual size eps')

    # Global bounds; per their help text they apply when
    # individual_size_constraint is False.
    global_bound_specs = (
        ('global_upbound', 2000,
         'global upper bound if individual_size_constraint is False'),
        ('global_lowbound', 20,
         'global lower bound if individual_size_constraint is False'),
    )
    for bound_name, bound_default, bound_help in global_bound_specs:
        flags.DEFINE_integer(bound_name, default=bound_default, help=bound_help)

    # NOTE(review): singular `setup_arch_flag` differs from this method's
    # name — confirm SizeConstraint really defines it under that name.
    SizeConstraint.setup_arch_flag()
# The leading `assert` is the final check of check_consistance (whose header
# is outside this view): it requires hparams['batch_size'] == 1.
# NOTE(review): asserts are stripped under `python -O`; consider raising
# ValueError instead.
#
# run(argv): entry point handed to app.run. Discards argv, snapshots all flag
# values into the dict `hparams`, validates them with check_consistance,
# builds the train/val datasets, extracts architecture hparams using
# AdmmGCSize.arch_hparam_keys, builds the network (get_arch), the method
# selected by hparams['method'] (get_method), and the loss selected by
# hparams['loss'] (get_loss_fn), then trains via ADMM_Trainer.start_training().
# NOTE(review): arch keys come from AdmmGCSize regardless of --method — confirm
# this is intended when ADMM_reg_size_inequality flags are registered.
#
# __main__: seeds only torch's RNG (seed 41; no other RNG is seeded here),
# defines dataset/loss flags, and registers the flags of
# ADMM_reg_size_inequality and ADMM_Trainer. The alternative method classes
# are left commented out. It then hands control to app.run(run).
assert hparams['batch_size']==1,hparams['batch_size'] def run(argv): del argv hparams = flags.FLAGS.flag_values_dict() check_consistance(hparams) train_dataset, val_dataset = build_datasets(hparams) arch_hparams = extract_from_big_dict(hparams, AdmmGCSize.arch_hparam_keys) torchnet = get_arch(arch_hparams['arch'], arch_hparams) admm = get_method(hparams['method'], torchnet, **hparams) criterion = get_loss_fn(hparams['loss']) trainer = ADMM_Trainer(admm, [train_dataset, val_dataset], criterion, hparams) trainer.start_training() if __name__ == '__main__': torch.manual_seed(41) flags.DEFINE_string('dataroot', default='cardiac', help='the name of the dataset') flags.DEFINE_boolean('data_aug', default=False, help='data_augmentation') flags.DEFINE_string('loss',default='partial_ce',help='loss used in admm loop') flags.DEFINE_boolean('data_equ',default=False, help='data equalization') # AdmmSize.setup_arch_flags() # AdmmGCSize.setup_arch_flags() # ADMM_size_inequality.setup_arch_flags() ADMM_reg_size_inequality.setup_arch_flags() ADMM_Trainer.setup_arch_flags() app.run(run)