sampling_strategy=None) valid_evaluator = Evaluator( dis_model, valid_studies, '../data/lumbar_train51_annotation.json', num_rep=5, max_dist=6, ) step_per_batch = len(train_dataloader) optimizer = torch.optim.AdamW(dis_model.parameters(), lr=1e-5) max_step = 40 * step_per_batch fit_result = torch_utils.fit( dis_model, train_data=train_dataloader, valid_data=None, optimizer=optimizer, max_step=max_step, loss=NullLoss(), metrics=[valid_evaluator.metric], is_higher_better=True, evaluate_per_steps=step_per_batch, evaluate_fn=valid_evaluator, ) dis_model.kp_model = None torch.save(dis_model.cpu().state_dict(), '../models/pretrained.dis_model_v3') print('task completed, {} seconds used'.format(time.time() - start_time))
prob_rotate=1, max_angel=180, num_rep=20, size=[512, 512], pin_memory=False) valid_dataloader = KeyPointDataLoader(valid_images, valid_spacings, valid_annotation, batch_size=1, num_workers=5, num_rep=20, size=[512, 512], pin_memory=False) optimizer = torch.optim.AdamW(kp_model_v2.parameters(), lr=1e-5) max_step = 50 * len(train_dataloader) result_v2 = torch_utils.fit( kp_model_v2, train_dataloader, valid_dataloader, optimizer, max_step, NullLoss(), [KeyPointAcc(6)], is_higher_better=True, evaluate_per_steps=len(train_dataloader), # checkpoint_dir='models', ) torch.save(kp_model_v2.cpu().state_dict(), 'models/pretrained.kp_model_v2') print('task completed, {} seconds used'.format(time.time() - start_time))