def test_samples(self):
    """Render random 3D crops with ``matshow3d`` and compare against a stored reference PNG.

    Loads ``testing_data/anatomical.nii``, draws 10 random-size spatial crops with a fixed
    seed, tiles them with ``matshow3d`` and checks the returned matrix dtype and the saved
    figure against ``testing_data/matshow3d_patch_test.png``.
    """
    testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data")
    keys = "image"
    xforms = Compose(
        [
            LoadImaged(keys=keys),
            AddChanneld(keys=keys),
            ScaleIntensityd(keys=keys),
            RandSpatialCropSamplesd(keys=keys, roi_size=(8, 8, 5), random_size=True, num_samples=10),
        ]
    )
    image_path = os.path.join(testing_dir, "anatomical.nii")
    # Fixed seed so the random crops (and therefore the rendered figure) are reproducible.
    xforms.set_random_state(0)
    ims = xforms({keys: image_path})
    fig, mat = matshow3d(
        [im[keys] for im in ims], title=f"testing {keys}", figsize=(2, 2), frames_per_row=5, every_n=2, show=False
    )
    # assertEqual reports both dtypes on failure, unlike assertTrue(a == b).
    self.assertEqual(mat.dtype, np.float32)

    with tempfile.TemporaryDirectory() as tempdir:
        tempimg = f"{tempdir}/matshow3d_patch_test.png"
        fig.savefig(tempimg)
        comp = compare_images(f"{testing_dir}/matshow3d_patch_test.png", tempimg, 5e-2, in_decorator=True)
        if comp:
            # A non-empty result means the images differ beyond tolerance; a known RMS
            # difference is still accepted (the original note ties it to matplotlib 3.2.2).
            print("not none comp: ", comp)
            np.testing.assert_allclose(comp["rms"], 30.786983, atol=1e-3, rtol=1e-3)
        else:
            self.assertIsNone(comp, f"value of comp={comp}")  # None indicates test passed
class Lungs(Dataset, Randomizable):
    """Dataset returning per-slice 2D tensors and a resized 3D volume for one DICOM series.

    Each item ``i`` loads ``dicom_folders[i]``, builds a 3-channel image from three
    intensity windows (mediastinal, PE-specific, lung), runs the 2D augmentation and
    preprocessing pipelines slice by slice, and also produces a 3D volume passed through
    ``transform3d`` (scale intensity, resize to 160^3, to tensor).
    """

    def __init__(self, dicom_folders):
        self.dicom_folders = dicom_folders
        self.transforms = get_validation_augmentation()
        self.preprocessing = get_preprocessing(
            functools.partial(preprocess_input, **formatted_settings))
        self.transform3d = Compose(
            [ScaleIntensity(), Resize((160, 160, 160)), ToTensor()])
        # Seed applied to transform3d; drawn by randomize(). Starts as None so that
        # calling get() directly (without __getitem__) draws one instead of failing.
        self._seed = None

    def __len__(self):
        return len(self.dicom_folders)

    def randomize(self, data=None) -> None:
        """Draw a fresh uint32 seed for ``transform3d`` from this object's RandomState."""
        max_seed = np.iinfo(np.uint32).max + 1
        self._seed = self.R.randint(max_seed, dtype="uint32")

    def get(self, i):
        """Load item ``i``.

        Returns:
            ``(torch.from_numpy(images), names, img)`` where ``images`` stacks the
            per-slice 2D outputs, ``names`` holds the last three path components (before
            ``.dcm``) of each DICOM file, and ``img`` is the ``transform3d`` output.
        """
        image, files = load_dicom_array(self.dicom_folders[i])
        # Three intensity windows become the channels, stacked along axis 3.
        image_lung = np.expand_dims(window(image, WL=-600, WW=1500), axis=3)
        image_mediastinal = np.expand_dims(window(image, WL=40, WW=400), axis=3)
        image_pe_specific = np.expand_dims(window(image, WL=100, WW=700), axis=3)
        image = np.concatenate(
            [image_mediastinal, image_pe_specific, image_lung], axis=3)
        names = [row.split(".dcm")[0].split("/")[-3:] for row in files]

        # Iterates axis 0; presumably one slice per step — confirm against load_dicom_array.
        images = []
        for img in image:
            if self.transforms:
                img = self.transforms(image=img)['image']
            if self.preprocessing:
                img = self.preprocessing(image=img)['image']
            images.append(img)
        images = np.array(images)

        # NOTE(review): the meaning of axis 1 (reversed here) depends on the layout
        # produced by the preprocessing pipeline — confirm.
        img = images[:, ::-1].transpose(1, 2, 3, 0)
        if self.transform3d is not None:
            if isinstance(self.transform3d, Randomizable):
                if self._seed is None:
                    self.randomize()
                self.transform3d.set_random_state(seed=self._seed)
            img = apply_transform(self.transform3d, img)
        return torch.from_numpy(images), names, img

    def __getitem__(self, i):
        """Best-effort load: on any exception, print it and return ``(None, None, None)``."""
        self.randomize()
        try:
            return self.get(i)
        except Exception as e:
            print(e)
            return None, None, None
def test_random_compose(self):
    """Check that Compose seeds and re-randomizes its Randomizable children reproducibly."""

    class _Acc(Randomizable):
        # Class attribute default. ``self`` is not bound inside a class body; writing
        # ``self.rand`` here would reach the enclosing test method's ``self`` instead.
        rand = 0.0

        def randomize(self, data=None):
            self.rand = self.R.rand()

        def __call__(self, data):
            self.randomize()
            return self.rand + data

    c = Compose([_Acc(), _Acc()])
    self.assertNotAlmostEqual(c(0), c(0))
    c.set_random_state(123)
    self.assertAlmostEqual(c(1), 1.61381597)
    c.set_random_state(223)
    c.randomize()
    self.assertAlmostEqual(c(1), 1.90734751)
def test_random_compose(self):
    """Check that Compose seeds and re-randomizes its Randomizable children reproducibly."""

    class _Acc(Randomizable):
        # Class attribute default. ``self`` is not bound inside a class body; writing
        # ``self.rand`` here would reach the enclosing test method's ``self`` instead.
        rand = 0.0

        def randomize(self):
            self.rand = self.R.rand()

        def __call__(self, data):
            self.randomize()
            return self.rand + data

    c = Compose([_Acc(), _Acc()])
    self.assertNotAlmostEqual(c(0), c(0))
    c.set_random_state(123)
    self.assertAlmostEqual(c(1), 2.39293837)
    c.set_random_state(223)
    c.randomize()
    self.assertAlmostEqual(c(1), 2.57673391)
def test_data_loader(self):
    """Check reproducible random transforms with an explicit seed and under set_determinism."""
    xform_1 = Compose([_RandXform()])
    train_ds = Dataset([1], transform=xform_1)

    xform_1.set_random_state(123)
    out_1 = train_ds[0]
    self.assertAlmostEqual(out_1, 0.2045649)

    set_determinism(seed=123)
    try:
        train_loader = DataLoader(train_ds, num_workers=0)
        out_1 = next(iter(train_loader))
        self.assertAlmostEqual(out_1.cpu().item(), 0.84291356)

        if sys.platform != "win32":  # skip multi-worker tests on win32
            train_loader = DataLoader(train_ds, num_workers=1)
            out_1 = next(iter(train_loader))
            self.assertAlmostEqual(out_1.cpu().item(), 0.180814653)

            train_loader = DataLoader(train_ds, num_workers=2)
            out_1 = next(iter(train_loader))
            self.assertAlmostEqual(out_1.cpu().item(), 0.04293707)
    finally:
        # Undo the global seeding even when an assertion above fails, so it does not
        # leak into tests that run afterwards.
        set_determinism(None)
def run_training_test(root_dir, train_x, train_y, val_x, val_y, device="cuda:0", num_workers=10):
    """Train a 2D DenseNet-121 classifier on MedNIST-style PNGs and track validation AUC.

    Trains for 4 epochs (Adam, lr 1e-5, cross-entropy) and validates after every epoch.
    The weights with the best ROC AUC are saved to ``<root_dir>/best_metric_model.pth``.

    Args:
        root_dir: directory where the best checkpoint is written.
        train_x, train_y: training inputs/labels passed to ``MedNISTDataset``;
            the number of output classes is ``len(np.unique(train_y))``.
        val_x, val_y: validation inputs/labels.
        device: torch device string for the model and batches.
        num_workers: DataLoader worker count for both loaders.

    Returns:
        ``(epoch_loss_values, best_metric, best_metric_epoch)``.
    """
    monai.config.print_config()
    # define transforms for image and classification
    train_transforms = Compose(
        [
            LoadPNG(image_only=True),
            AddChannel(),
            ScaleIntensity(),
            RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),
            RandFlip(spatial_axis=0, prob=0.5),
            RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
            ToTensor(),
        ]
    )
    # Fixed seed for the random augmentations so training results are reproducible.
    train_transforms.set_random_state(1234)
    val_transforms = Compose([LoadPNG(image_only=True), AddChannel(), ScaleIntensity(), ToTensor()])

    # create train, val data loaders
    train_ds = MedNISTDataset(train_x, train_y, train_transforms)
    train_loader = DataLoader(train_ds, batch_size=300, shuffle=True, num_workers=num_workers)

    val_ds = MedNISTDataset(val_x, val_y, val_transforms)
    val_loader = DataLoader(val_ds, batch_size=300, num_workers=num_workers)

    model = densenet121(spatial_dims=2, in_channels=1, out_channels=len(np.unique(train_y))).to(device)
    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), 1e-5)
    epoch_num = 4
    val_interval = 1

    # start training validation
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    # NOTE(review): metric_values is filled below but never returned or logged.
    metric_values = list()
    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    for epoch in range(epoch_num):
        print("-" * 10)
        print(f"Epoch {epoch + 1}/{epoch_num}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        # NOTE(review): raises ZeroDivisionError if train_loader yields no batches.
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss:{epoch_loss:0.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                # Accumulate all validation logits/labels, then score them in one pass.
                y_pred = torch.tensor([], dtype=torch.float32, device=device)
                y = torch.tensor([], dtype=torch.long, device=device)
                for val_data in val_loader:
                    val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                    y_pred = torch.cat([y_pred, model(val_images)], dim=0)
                    y = torch.cat([y, val_labels], dim=0)
                auc_metric = compute_roc_auc(y_pred, y, to_onehot_y=True, softmax=True)
                metric_values.append(auc_metric)
                acc_value = torch.eq(y_pred.argmax(dim=1), y)
                acc_metric = acc_value.sum().item() / len(acc_value)
                # Checkpoint only when validation AUC improves.
                if auc_metric > best_metric:
                    best_metric = auc_metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), model_filename)
                    print("saved new best metric model")
                print(
                    f"current epoch {epoch +1} current AUC: {auc_metric:0.4f} "
                    f"current accuracy: {acc_metric:0.4f} best AUC: {best_metric:0.4f} at epoch {best_metric_epoch}"
                )
    print(f"train completed, best_metric: {best_metric:0.4f} at epoch: {best_metric_epoch}")
    return epoch_loss_values, best_metric, best_metric_epoch
def run_training_test(root_dir, device="cuda:0", cachedataset=0):
    """Train a 3D UNet on ``img*.nii.gz`` / ``seg*.nii.gz`` pairs and track validation Dice.

    The first 20 sorted image/seg pairs are used for training and the last 20 for
    validation. Training runs for 6 epochs with validation every 2 epochs using
    sliding-window inference. The best weights (by mean Dice) are saved to
    ``<root_dir>/best_metric_model.pth``, and scalars/images are logged to ``<root_dir>/runs``.

    Args:
        root_dir: directory holding the NIfTI files; also receives the checkpoint and logs.
        device: torch device string.
        cachedataset: 2 -> ``CacheDataset`` (cache_rate 0.8), 3 -> ``LMDBDataset``,
            any other value -> plain ``Dataset`` for training.

    Returns:
        ``(epoch_loss_values, best_metric, best_metric_epoch)``.
    """
    monai.config.print_config()
    images = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    train_files = [{"img": img, "seg": seg} for img, seg in zip(images[:20], segs[:20])]
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            # resampling with align_corners=True or dtype=float64 will generate
            # slightly different results between PyTorch 1.5 and 1.6
            Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd(keys="img"),
            RandCropByPosNegLabeld(
                keys=["img", "seg"], label_key="seg", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4
            ),
            RandRotate90d(keys=["img", "seg"], prob=0.8, spatial_axes=[0, 2]),
            ToTensord(keys=["img", "seg"]),
        ]
    )
    # Fixed seed for the random crops/rotations so training results are reproducible.
    train_transforms.set_random_state(1234)
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            # resampling with align_corners=True or dtype=float64 will generate
            # slightly different results between PyTorch 1.5 and 1.6
            Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd(keys="img"),
            ToTensord(keys=["img", "seg"]),
        ]
    )

    # create a training data loader
    if cachedataset == 2:
        train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.8)
    elif cachedataset == 3:
        train_ds = monai.data.LMDBDataset(data=train_files, transform=train_transforms)
    else:
        train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
    # Sigmoid then threshold, turning network logits into a binary mask before scoring.
    val_post_tran = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
    dice_metric = DiceMetric(include_background=True, reduction="mean")

    # create UNet, DiceLoss and Adam optimizer
    model = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 5e-4)

    # start a typical PyTorch training
    val_interval = 2
    best_metric, best_metric_epoch = -1, -1
    epoch_loss_values = list()
    # NOTE(review): metric_values is filled below but never returned.
    metric_values = list()
    writer = SummaryWriter(log_dir=os.path.join(root_dir, "runs"))
    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    for epoch in range(6):
        print("-" * 10)
        print(f"Epoch {epoch + 1}/{6}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["img"].to(device), batch_data["seg"].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss:{loss.item():0.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        # NOTE(review): raises ZeroDivisionError if train_loader yields no batches.
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch +1} average loss:{epoch_loss:0.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                metric_sum = 0.0
                metric_count = 0
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
                    sw_batch_size, roi_size = 4, (96, 96, 96)
                    val_outputs = val_post_tran(sliding_window_inference(val_images, roi_size, sw_batch_size, model))
                    # Weight each batch's mean Dice by its count of non-NaN entries.
                    value, not_nans = dice_metric(y_pred=val_outputs, y=val_labels)
                    metric_count += not_nans.item()
                    metric_sum += value.item() * not_nans.item()
                metric = metric_sum / metric_count
                metric_values.append(metric)
                # Checkpoint only when the validation Dice improves.
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), model_filename)
                    print("saved new best metric model")
                print(
                    f"current epoch {epoch +1} current mean dice: {metric:0.4f} "
                    f"best mean dice: {best_metric:0.4f} at epoch {best_metric_epoch}"
                )
                writer.add_scalar("val_mean_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")
    print(f"train completed, best_metric: {best_metric:0.4f} at epoch: {best_metric_epoch}")
    writer.close()
    return epoch_loss_values, best_metric, best_metric_epoch
'rv_lv_ratio_gte_1', # exam level "central_pe", "leftsided_pe", "rightsided_pe", "acute_and_chronic_pe", "chronic_pe" ] out_dim = len(target_cols) image_size = 100 val_transforms = Compose([ ScaleIntensity(), Resize((image_size, image_size, image_size)), ToTensor() ]) val_transforms.set_random_state(seed=42) def monai_preprocess(imgs512): imgs = imgs512[:, :, 43:-55, 43:-55] img_monai = imgs[int(imgs.shape[0] * 0.25):int(imgs.shape[0] * 0.75)] img_monai = np.transpose(img_monai, (1, 2, 3, 0)) img_monai = apply_transform(val_transforms, img_monai) img_monai = np.expand_dims(img_monai, axis=0) img_monai = torch.from_numpy(img_monai).cuda() return img_monai class MonaiModelTest(): def __init__(self, monai_model_file): self.monai_model = monai.networks.nets.densenet.densenet121(
def transform_and_copy(data, cahce_dir):
    """Write randomly augmented copies of every positive sample and return their index.

    On the first call, every entry of ``data`` with ``int(entry['_label']) == 1`` is
    loaded (image via ``LoadImage``, segmentation via ``nrrd.read``). The pair is pushed
    through the same seeded flip/affine pipeline, Gaussian noise is added to the image,
    and both are written under ``<cahce_dir>/copied_images``. The list of new sample
    dicts is cached in ``copied_images.npy``; later calls just reload that file.

    Args:
        data: iterable of sample dicts with ``'image'``, ``'seg'``, ``'label'`` and
            ``'_label'`` keys.
        cahce_dir: cache root directory (created if missing).

    Returns:
        The cached list as loaded by ``np.load(..., allow_pickle=True)`` (an object
        ndarray of dicts with ``'image'``, ``'seg'``, ``'label'``, ``'_label'``).
    """
    copy_dir = os.path.join(cahce_dir, 'copied_images')
    # makedirs also creates a missing cache root, where os.mkdir would raise.
    os.makedirs(copy_dir, exist_ok=True)
    copy_list_path = os.path.join(copy_dir, 'copied_images.npy')
    if not os.path.exists(copy_list_path):
        print("transforming and copying images...")
        imageLoader = LoadImage()
        to_copy_list = [x for x in data if int(x['_label']) == 1]
        # Augmented copies written per positive sample.
        mul = 1
        rand_x_flip = RandFlip(spatial_axis=0, prob=0.50)
        rand_y_flip = RandFlip(spatial_axis=1, prob=0.50)
        rand_z_flip = RandFlip(spatial_axis=2, prob=0.50)
        rand_affine = RandAffine(prob=1.0,
                                 rotate_range=(0, 0, np.pi / 10),
                                 shear_range=(0.12, 0.12, 0.0),
                                 translate_range=(0, 0, 0),
                                 scale_range=(0.12, 0.12, 0.0),
                                 padding_mode="zeros")
        rand_gaussian_noise = RandGaussianNoise(prob=0.5, mean=0.0, std=0.05)
        transform = Compose([
            AddChannel(),
            rand_x_flip,
            rand_y_flip,
            rand_z_flip,
            rand_affine,
            SqueezeDim(),
        ])
        copy_list = []
        n = len(to_copy_list)
        for idx, to_copy in enumerate(to_copy_list):
            print(f'Copying image {idx+1}/{n}', end="\r")
            image_file = to_copy['image']
            _image_file = replace_suffix(image_file, '.nii.gz', '')
            label = to_copy['label']
            _label = to_copy['_label']
            image_data, _ = imageLoader(image_file)
            seg_file = to_copy['seg']
            seg_data, _ = nrrd.read(seg_file)

            for copy_idx in range(mul):
                # Re-seeding before each call makes image and seg receive identical
                # flips/affine parameters, keeping the pair spatially aligned.
                rand_seed = np.random.randint(10**8)
                transform.set_random_state(seed=rand_seed)
                new_image_data = rand_gaussian_noise(
                    np.array(transform(image_data)))
                transform.set_random_state(seed=rand_seed)
                # NOTE(review): RandAffine is built with its default interpolation mode
                # (bilinear in MONAI, to confirm for this version), so label values in
                # seg_data may be blended; a mode="nearest" pipeline seeded the same way
                # may be intended.
                new_seg_data = np.array(transform(seg_data))

                image_basename = os.path.basename(_image_file)
                seg_basename = image_basename + f'_seg_{copy_idx}.nrrd'
                image_basename = image_basename + f'_{copy_idx}.nii.gz'

                new_image_file = os.path.join(copy_dir, image_basename)
                write_nifti(new_image_data, new_image_file, resample=False)
                new_seg_file = os.path.join(copy_dir, seg_basename)
                nrrd.write(new_seg_file, new_seg_data)
                copy_list.append({
                    'image': new_image_file,
                    'seg': new_seg_file,
                    'label': label,
                    '_label': _label
                })
        np.save(copy_list_path, copy_list)
        print("done transforming and copying!")

    # allow_pickle is needed for the object array of dicts; the file is one this
    # function wrote itself inside the cache directory.
    copy_list = np.load(copy_list_path, allow_pickle=True)
    return copy_list
def large_image_splitter(data, cache_dir, num_splits, only_label_one=False):
    """Append z-interleaved sub-volumes of large images to ``data``, caching them on disk.

    Images whose ``shape[2]`` exceeds 200 are split into ``shape[2] // 80`` interleaved
    sub-volumes (``image[:, :, k::count]``). Each image/seg split pair is passed through
    the same seeded random affine transform and written to ``<cache_dir>/split_images``.
    The index of all splits is saved to ``split_images.npy`` and reused on later calls.

    Args:
        data: list of sample dicts with ``'image'``, ``'seg'``, ``'label'`` and
            ``'_label'`` keys.
        cache_dir: directory under which ``split_images`` is created.
        num_splits: maximum number of splits inserted after each source sample.
        only_label_one: when ``True``, no splits are inserted after samples whose
            ``'_label'`` equals 0.

    Returns:
        A new list with every original sample, each followed by up to ``num_splits``
        of its split samples.
    """
    print("Splitting large images...")
    len_old = len(data)
    print("original data len:", len_old)
    split_images_dir = os.path.join(cache_dir, 'split_images')
    split_images = os.path.join(split_images_dir, 'split_images.npy')

    def _replace_in_data(split_images, num_splits):
        """Return ``data`` with up to ``num_splits`` splits inserted after each source."""
        # Index split records by source path once (O(n + m) instead of O(n * m));
        # setdefault keeps the first record per source, like a first-match scan.
        splits_by_source = {}
        for s in split_images:
            splits_by_source.setdefault(s['source'], s)
        limit = max(num_splits, 0)
        new_images = []
        for image in data:
            new_images.append(image)
            s = splits_by_source.get(image['image'])
            if s is None:
                continue
            if only_label_one is True and image['_label'] == 0:
                continue
            new_images.extend(s["splits"][:limit])
        return new_images

    if os.path.exists(split_images):
        # allow_pickle is needed for the object array of dicts written below.
        new_images = np.load(split_images, allow_pickle=True)
        out_data = _replace_in_data(new_images, num_splits)
    else:
        os.makedirs(split_images_dir, exist_ok=True)
        new_images = []
        imageLoader = LoadImage()
        # One pipeline for all splits: it is re-seeded before every call, so reusing
        # it gives the same outputs as rebuilding it per split.
        rand_affine = RandAffine(prob=1.0,
                                 rotate_range=(0, 0, np.pi / 16),
                                 shear_range=(0.07, 0.07, 0.0),
                                 translate_range=(0, 0, 0),
                                 scale_range=(0.07, 0.07, 0.0),
                                 padding_mode="zeros")
        transform = Compose([
            AddChannel(),
            rand_affine,
            SqueezeDim(),
        ])
        for image in data:
            image_data, _ = imageLoader(image["image"])
            seg_data, _ = nrrd.read(image['seg'])
            z_len = image_data.shape[2]
            if z_len > 200:
                count = z_len // 80
                print("splitting image:", image["image"], f"into {count} parts", "shape:", image_data.shape, end='\r')
                # Interleaved (strided) split along axis 2: split k takes every count-th slice from k.
                split_image_list = [image_data[:, :, idz::count] for idz in range(count)]
                split_seg_list = [seg_data[:, :, idz::count] for idz in range(count)]
                new_image = {'source': image["image"], 'splits': []}
                for i in range(count):
                    image_file = os.path.basename(replace_suffix(image["image"], '.nii.gz', ''))
                    image_file = os.path.join(split_images_dir, image_file + f'_{i}.nii.gz')
                    seg_file = os.path.basename(replace_suffix(image["seg"], '.nrrd', ''))
                    seg_file = os.path.join(split_images_dir, seg_file + f'_seg_{i}.nrrd')
                    split_image = np.array(split_image_list[i])
                    split_seg = np.array(split_seg_list[i], dtype=np.uint8)
                    # Same seed for image and seg so both get identical affine parameters.
                    rand_seed = np.random.randint(10**8)
                    transform.set_random_state(seed=rand_seed)
                    split_image = transform(split_image).detach().cpu().numpy()
                    transform.set_random_state(seed=rand_seed)
                    # NOTE(review): RandAffine uses its default interpolation mode here
                    # (bilinear in MONAI, to confirm), so uint8 labels may be blended.
                    split_seg = transform(split_seg).detach().cpu().numpy()
                    write_nifti(split_image, image_file, resample=False)
                    nrrd.write(seg_file, split_seg)
                    new_image['splits'].append({
                        'image': image_file,
                        'label': image['label'],
                        '_label': image['_label'],
                        'seg': seg_file,
                        'w': False
                    })
                new_images.append(new_image)
        np.save(split_images, new_images)
        out_data = _replace_in_data(new_images, num_splits)
    print("new data len:", len(out_data))
    return out_data
def run_training_test(root_dir, device=torch.device("cuda:0"), cachedataset=False):
    """Train a 3D UNet on ``img*.nii.gz`` / ``seg*.nii.gz`` pairs and track validation Dice.

    The first 20 sorted image/seg pairs train the model and the last 20 validate it.
    Training runs for 6 epochs with sliding-window validation every 2 epochs. The best
    weights (by mean Dice) go to ``<root_dir>/best_metric_model.pth``; scalars and
    images are logged to ``<root_dir>/runs``.

    Args:
        root_dir: directory holding the NIfTI files; also receives the checkpoint and logs.
        device: torch device for the model and batches.
        cachedataset: if truthy, train from ``CacheDataset`` (cache_rate 0.8),
            otherwise from a plain ``Dataset``.

    Returns:
        ``(epoch_loss_values, best_metric, best_metric_epoch)``.
    """
    monai.config.print_config()
    images = sorted(glob(os.path.join(root_dir, 'img*.nii.gz')))
    segs = sorted(glob(os.path.join(root_dir, 'seg*.nii.gz')))
    train_files = [{
        'img': img,
        'seg': seg
    } for img, seg in zip(images[:20], segs[:20])]
    val_files = [{
        'img': img,
        'seg': seg
    } for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    # NOTE(review): ScaleIntensityd is applied to 'seg' as well as 'img', so labels are
    # intensity-rescaled too — confirm this is intended.
    train_transforms = Compose([
        LoadNiftid(keys=['img', 'seg']),
        AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1),
        ScaleIntensityd(keys=['img', 'seg']),
        RandCropByPosNegLabeld(keys=['img', 'seg'], label_key='seg', size=[96, 96, 96], pos=1, neg=1, num_samples=4),
        RandRotate90d(keys=['img', 'seg'], prob=0.8, spatial_axes=[0, 2]),
        ToTensord(keys=['img', 'seg'])
    ])
    # Fixed seed for the random crops/rotations so training results are reproducible.
    train_transforms.set_random_state(1234)
    val_transforms = Compose([
        LoadNiftid(keys=['img', 'seg']),
        AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1),
        ScaleIntensityd(keys=['img', 'seg']),
        ToTensord(keys=['img', 'seg'])
    ])

    # create a training data loader
    if cachedataset:
        train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.8)
    else:
        train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    # list_data_collate is used to batch the num_samples crops produced per item.
    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4, collate_fn=list_data_collate,
                              pin_memory=torch.cuda.is_available())
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate,
                            pin_memory=torch.cuda.is_available())

    # create UNet, DiceLoss and Adam optimizer
    model = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(do_sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 5e-4)

    # start a typical PyTorch training
    val_interval = 2
    best_metric, best_metric_epoch = -1, -1
    epoch_loss_values = list()
    # NOTE(review): metric_values is filled below but never returned.
    metric_values = list()
    writer = SummaryWriter(log_dir=os.path.join(root_dir, 'runs'))
    model_filename = os.path.join(root_dir, 'best_metric_model.pth')
    for epoch in range(6):
        print('-' * 10)
        print('Epoch {}/{}'.format(epoch + 1, 6))
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data['img'].to(device), batch_data['seg'].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print("%d/%d, train_loss:%0.4f" % (step, epoch_len, loss.item()))
            writer.add_scalar('train_loss', loss.item(), epoch_len * epoch + step)
        # NOTE(review): raises ZeroDivisionError if train_loader yields no batches.
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print("epoch %d average loss:%0.4f" % (epoch + 1, epoch_loss))

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                metric_sum = 0.
                metric_count = 0
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images, val_labels = val_data['img'].to(device), val_data['seg'].to(device)
                    sw_batch_size, roi_size = 4, (96, 96, 96)
                    val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                    # Per-sample Dice with sigmoid applied inside the metric call.
                    value = compute_meandice(y_pred=val_outputs,
                                             y=val_labels,
                                             include_background=True,
                                             to_onehot_y=False,
                                             add_sigmoid=True)
                    metric_count += len(value)
                    metric_sum += value.sum().item()
                metric = metric_sum / metric_count
                metric_values.append(metric)
                # Checkpoint only when the validation Dice improves.
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), model_filename)
                    print('saved new best metric model')
                print("current epoch %d current mean dice: %0.4f best mean dice: %0.4f at epoch %d" %
                      (epoch + 1, metric, best_metric, best_metric_epoch))
                writer.add_scalar('val_mean_dice', metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag='image')
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag='label')
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag='output')
    print('train completed, best_metric: %0.4f at epoch: %d' % (best_metric, best_metric_epoch))
    writer.close()
    return epoch_loss_values, best_metric, best_metric_epoch