예제 #1
0
def _str2bool(value):
    """Parse a command-line boolean flag value.

    ``type=bool`` is an argparse pitfall: ``bool('False')`` is ``True``
    because every non-empty string is truthy, so a flag such as
    ``--pre_all False`` could never be switched off. This parser maps the
    usual spellings to real booleans and rejects anything else.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if text in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    # argparse reports a ValueError raised by a type callable as a usage error.
    raise ValueError('expected a boolean value, got %r' % (value,))


parser.add_argument('--test_every_step', type=_str2bool, default=False,
                    help='test after every train step')
# When true, weights are restored from models/<model_name>/<best_model>
# further down in this script.
parser.add_argument('--pre_all', type=_str2bool, default=True,
                    help='pretrain the whole model')
# Forwarded to get_cur_model(..., pretrain=...).
parser.add_argument('--pre_part', type=_str2bool, default=True,
                    help='pretrain the pytorch pretrain_model, eg resnet50 and resnet18, used in resnet_50 and resnet_dense')


# Parse the command line, then derive the dataset locations. The raw images
# sit directly under data_path; the split folders live under
# <data_path>/<n_class>/.
args = parser.parse_args()
ori_data = os.path.join(args.data_path, 'img')
train_data, test_data, label_data = (
    os.path.join(args.data_path, str(args.n_class), leaf)
    for leaf in ('train_data', 'test_data', 'train_label')
)

# Instantiate the architecture selected by --model_id. The --pre_part flag is
# forwarded as `pretrain` (its help text says it concerns the pretrained
# backbone).
model_name = models_list[args.model_id]
model = get_cur_model(model_name, n_classes=args.n_class, pretrain=args.pre_part)


softmax_2d = nn.Softmax2d()

# Replicate the model across GPUs when more than one is visible.
# NOTE(review): this wraps even when --use_gpu is off; nn.DataParallel expects
# the module's parameters on the first CUDA device, so confirm multi-GPU hosts
# always run with use_gpu enabled.
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model)
if args.use_gpu:
    model.cuda()
if args.pre_all:
    # Restore previously trained weights from models/<model_name>/<best_model>.
    # map_location='cpu' lets a checkpoint that was saved from GPU tensors be
    # read on a CPU-only host; load_state_dict then copies each tensor into
    # the model's existing parameters on whatever device they live on.
    # NOTE(review): a DataParallel-wrapped model expects 'module.'-prefixed
    # keys, so the checkpoint must have been saved with the same wrapping.
    model_path = os.path.join('models', model_name, args.best_model)
    model.load_state_dict(torch.load(model_path, map_location='cpu'))

# Starting value for the running mean loss.
# NOTE(review): presumably compared against later losses; confirm 100 is
# larger than any loss this model can produce.
mean_loss = 100
예제 #2
0
    'pretrain the pytorch pretrain_model, eg resnet50 and resnet18, used in resnet_50 and resnet_dense'
)
def _str2bool(value):
    """Parse a command-line boolean flag value.

    ``type=bool`` is an argparse pitfall: ``bool('False')`` is ``True``
    because every non-empty string is truthy, so the flag could never be
    switched off from the command line. This parser maps the usual
    spellings to real booleans and rejects anything else.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if text in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    # argparse reports a ValueError raised by a type callable as a usage error.
    raise ValueError('expected a boolean value, got %r' % (value,))


parser.add_argument('--hard_example_train',
                    type=_str2bool,
                    default=False,
                    help='only train the hard example')

# Parse the command line and work out where each part of the dataset lives:
# raw images under <data_path>/img, splits under <data_path>/<n_class>/.
args = parser.parse_args()
ori_data = os.path.join(args.data_path, 'img')
train_data, test_data, label_data = (
    os.path.join(args.data_path, str(args.n_class), leaf)
    for leaf in ('train_data', 'test_data', 'train_label')
)

# Main network chosen by --model_id, plus the auxiliary 'Small' network
# built for the same number of classes.
model_name = models_list[args.model_id]
model = get_cur_model(model_name, n_classes=args.n_class, pretrain=args.pre_part)
small_model = get_cur_model('Small', n_classes=args.n_class)
# For resnet-based models, train the backbone (`model.resnet`) at its own
# learning rate (args.pre_lr) while every other parameter uses args.lr.
# NOTE(review): no optimizer is created here for other model names; confirm
# `optimizer` is defined elsewhere before it is used.
if 'resnet' in model_name:
    # Parameters are matched by object identity; a set gives O(1) lookups.
    params_mask = {id(p) for p in model.resnet.parameters()}
    # Materialize as lists: the previous `filter` objects were one-shot
    # iterators that became silently empty once the optimizer consumed them.
    base_params = [p for p in model.parameters() if id(p) not in params_mask]
    pre_params = [p for p in model.parameters() if id(p) in params_mask]
    optimizer = torch.optim.Adam(
        [
            {'params': base_params},
            {'params': pre_params, 'lr': args.pre_lr},
        ],
        lr=args.lr,
    )
예제 #3
0
# Dataset split directories under <data_path>/4/.
# NOTE(review): the '4' is hard-coded even though n_class is read from args
# below; confirm this script only targets the 4-class data.
train_data, test_data, label_data = (
    os.path.join(args.data_path, '4', leaf)
    for leaf in ('train_data', 'test_data', 'train_label')
)

# Hyper-parameters copied out of the parsed arguments. a_1, lambda_2 and
# lambda_3 feed the 'lambda_*' entries of the Options dict that follows.
a_1, lambda_2, lambda_3, e_ls = args.a_1, args.lambda_2, args.lambda_3, args.e_ls
option_ = 2
test_flag = True
n_class, gpu_num = args.n_class, args.gpu_num

from utils import get_cur_model
# NOTE(review): built with a literal 2 rather than n_class; confirm intended.
model = get_cur_model('level_set', 2)
LevelSetModel = LevelSet_CNN_RNN_STN()

# NOTE(review): meaning of this flag is not visible here; presumably a switch
# consumed later in the script.
RNNLevelSetModel = 1

Options={
'InnerAreaOption':1,
'UseLengthItemType':1,
'UseHigh_Hfuntion':0,
'isShownVisdom':0,
'lambda_1':a_1,
'lambda_2':lambda_2,
'lambda_3':lambda_3,
'lambda_shape':0.2,
'lambda_CNN':0.75,
'Lamda_RNN':1.0,
예제 #4
0
# Training options pulled from the 'Options' section of the config, read in
# the same order as the target names.
# NOTE(review): type=2 presumably selects a numeric parse in conf.read; confirm.
(GRU_Dimention, n_epochs, lr_decay,
 batch_size, img_size, random_seed) = (
    conf.read('Options', key, type=2)
    for key in ('GRU_Dimention', 'n_epochs', 'lr_decay',
                'batch_size', 'img_size', 'random_seed')
)

UseHigh_Hfuntion = conf.read('Options', 'UseHigh_Hfuntion', type=2)

# Raw images live under <data_path>/img; the splits sit under <data_path>/4/.
ori_data = os.path.join(data_path, 'img')
train_data, test_data, label_data = (
    os.path.join(data_path, '4', leaf)
    for leaf in ('train_data', 'test_data', 'train_label')
)
test_flag = True

model = get_cur_model('level_set', n_class)
LevelSetModel = LevelSet_CNN_RNN_STN()

RNNLevelSetModel = 1

Options = {
    'InnerAreaOption': InnerAreaOption,
    'UseLengthItemType': UseLengthItemType,
    'UseHigh_Hfuntion': UseHigh_Hfuntion,
    'isShownVisdom': isShownVisdom,
    'lambda_1': lambda_1,
    'lambda_2': lambda_2,
    'lambda_3': lambda_3,
    'lambda_shape': lambda_shape,
    'lambda_CNN': lambda_CNN,
    'Lamda_RNN': Lamda_RNN,