def check_wso():
    from rsna19.data import dataset
    import albumentations.pytorch
    from torch.utils.data import DataLoader
    import matplotlib.pyplot as plt

    # WSO: the window settings optimisation module, defined or imported at module level.
    wso = WSO()

    dataset_valid = dataset.IntracranialDataset(
        csv_file='5fold.csv',
        folds=[0],
        preprocess_func=albumentations.pytorch.ToTensorV2(),
    )

    batch_size = 2
    data_loader = DataLoader(dataset_valid,
                             shuffle=False,
                             num_workers=16,
                             batch_size=batch_size)

    for data in data_loader:
        img = data['image'].float().cpu()
        windowed_img = wso(img).detach().numpy()

        fig, ax = plt.subplots(4, 1)
        for batch in range(batch_size):
            for j in range(4):
                # for k in range(4):
                ax[j].imshow(windowed_img[batch, j], cmap='gray')
        plt.show()
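
# Example invocation (a sketch): check_wso() takes no arguments and simply plots
# the learned-window channels for validation fold 0, so it can be run directly,
# e.g. from a REPL or a __main__ guard:
#
#     check_wso()
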
def check_heatmap(model_name, fold, epoch, run=None):
    model_str = build_model_str(model_name, fold, run)
    model_info = MODELS[model_name]

    checkpoints_dir = f'{BaseConfig.checkpoints_dir}/{model_str}'
    print('\n', model_name, '\n')

    model = model_info.factory(**model_info.args)
    model = model.cpu()

    dataset_valid = dataset.IntracranialDataset(
        csv_file='5fold.csv',
        folds=[fold],
        preprocess_func=albumentations.pytorch.ToTensorV2(),
        **model_info.dataset_args)

    model.eval()
    checkpoint = torch.load(f'{checkpoints_dir}/{epoch:03}.pt')
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.cpu()

    batch_size = 1
    data_loader = DataLoader(dataset_valid,
                             shuffle=False,
                             num_workers=16,
                             batch_size=batch_size)

    data_iter = tqdm(enumerate(data_loader), total=len(data_loader))
    for iter_num, data in data_iter:
        img = data['image'].float().cpu()
        labels = data['labels'].detach().numpy()

        with torch.set_grad_enabled(False):
            pred2d, heatmap, pred = model(img,
                                          output_heatmap=True,
                                          output_per_pixel=True)
            heatmap *= np.prod(heatmap.shape[1:])
            pred2d = (pred2d[0]).detach().cpu().numpy() * 0.1

        fig, ax = plt.subplots(2, 4)
        for i in range(batch_size):
            print(labels[i], torch.sigmoid(pred[i]))
            ax[0, 0].imshow(img[i, 0].cpu().detach().numpy(), cmap='gray')
            ax[0, 1].imshow(heatmap[i, 0].cpu().detach().numpy(), cmap='gray')
            ax[0, 2].imshow(pred2d[0], cmap='gray', vmin=0, vmax=1)
            ax[0, 3].imshow(pred2d[1], cmap='gray', vmin=0, vmax=1)
            ax[1, 0].imshow(pred2d[2], cmap='gray', vmin=0, vmax=1)
            ax[1, 1].imshow(pred2d[3], cmap='gray', vmin=0, vmax=1)
            ax[1, 2].imshow(pred2d[4], cmap='gray', vmin=0, vmax=1)
            ax[1, 3].imshow(pred2d[5], cmap='gray', vmin=0, vmax=1)
        plt.show()
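
# Example invocation (a sketch; the model name and epoch below are placeholders,
# not entries taken from MODELS in this repository):
#
#     check_heatmap('some_registered_model', fold=0, epoch=6)
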
def check_windows(model_name, fold, epoch, run=None):
    model_str = build_model_str(model_name, fold, run)
    model_info = MODELS[model_name]

    checkpoints_dir = f'{BaseConfig.checkpoints_dir}/{model_str}'
    print('\n', model_name, '\n')

    model = model_info.factory(**model_info.args)
    model = model.cpu()

    dataset_valid = dataset.IntracranialDataset(
        csv_file='5fold.csv',
        folds=[fold],
        preprocess_func=albumentations.pytorch.ToTensorV2(),
        **model_info.dataset_args)

    model.eval()
    checkpoint = torch.load(f'{checkpoints_dir}/{epoch:03}.pt')
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.cpu()

    w = model.windows_conv.weight.detach().cpu().numpy().flatten()
    b = model.windows_conv.bias.detach().cpu().numpy()
    print(w, b)

    # Print each learned window roughly converted back to HU (centre +- scale),
    # assuming the input intensities are HU scaled by 1/1000.
    for wi, bi in zip(w, b):
        print(f'{-int(bi / wi * 1000)} +- {int(abs(1000 / wi))}')

    batch_size = 1
    data_loader = DataLoader(dataset_valid,
                             shuffle=False,
                             num_workers=16,
                             batch_size=batch_size)

    data_iter = tqdm(enumerate(data_loader), total=len(data_loader))
    for iter_num, data in data_iter:
        img = data['image'].float().cpu()
        labels = data['labels'].detach().numpy()

        with torch.set_grad_enabled(False):
            windowed_img = model.windows_conv(img)
            windowed_img = F.relu6(windowed_img).cpu().numpy()

        fig, ax = plt.subplots(4, 4)
        for batch in range(batch_size):
            print(labels[batch], data['path'][batch])
            for j in range(4):
                for k in range(4):
                    ax[j, k].imshow(windowed_img[batch, j * 4 + k], cmap='gray')
        plt.show()
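
# Example invocation (a sketch; placeholder model name and epoch). Useful for
# eyeballing the 16 learned window channels produced by model.windows_conv:
#
#     check_windows('some_registered_model', fold=0, epoch=6)
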
def predict(model_name, fold, epoch, is_test, df_out_path, mode='normal', run=None):
    model_str = build_model_str(model_name, fold, run)
    model_info = MODELS[model_name]

    checkpoints_dir = f'{BaseConfig.checkpoints_dir}/{model_str}'
    print('\n', model_name, '\n')

    model = model_info.factory(**model_info.args)
    model.output_segmentation = False

    # Optional test-time augmentation, selected via substrings of the mode string.
    preprocess_func = []
    if 'h_flip' in mode:
        preprocess_func.append(albumentations.HorizontalFlip(always_apply=True))
    if 'v_flip' in mode:
        preprocess_func.append(albumentations.VerticalFlip(always_apply=True))
    if 'rot90' in mode:
        preprocess_func.append(Rotate90(always_apply=True))

    dataset_valid = dataset.IntracranialDataset(
        csv_file='test2.csv' if is_test else '5fold.csv',
        folds=[fold],
        preprocess_func=albumentations.Compose(preprocess_func),
        return_labels=not is_test,
        is_test=is_test,
        **{
            **model_info.dataset_args,
            "add_segmentation_masks": False,
            "segmentation_oversample": 1
        })

    model.eval()
    print(f'load {checkpoints_dir}/{epoch:03}.pt')
    checkpoint = torch.load(f'{checkpoints_dir}/{epoch:03}.pt')
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.cuda()

    data_loader = DataLoader(dataset_valid,
                             shuffle=False,
                             num_workers=8,
                             batch_size=model_info.batch_size * 2)

    all_paths = []
    all_study_id = []
    all_slice_num = []
    all_gt = []
    all_pred = []

    data_iter = tqdm(enumerate(data_loader), total=len(data_loader))
    for iter_num, batch in data_iter:
        with torch.set_grad_enabled(False):
            y_hat = torch.sigmoid(model(batch['image'].float().cuda()))
            all_pred.append(y_hat.cpu().numpy())
            all_paths.extend(batch['path'])
            all_study_id.extend(batch['study_id'])
            all_slice_num.extend(batch['slice_num'].cpu().numpy())

            if not is_test:
                y = batch['labels']
                all_gt.append(y.numpy())

    pred_columns = [
        'pred_epidural', 'pred_intraparenchymal', 'pred_intraventricular',
        'pred_subarachnoid', 'pred_subdural', 'pred_any'
    ]
    gt_columns = [
        'gt_epidural', 'gt_intraparenchymal', 'gt_intraventricular',
        'gt_subarachnoid', 'gt_subdural', 'gt_any'
    ]

    if is_test:
        all_pred = np.concatenate(all_pred)
        df = pd.DataFrame(all_pred, columns=pred_columns)
    else:
        all_pred = np.concatenate(all_pred)
        all_gt = np.concatenate(all_gt)
        df = pd.DataFrame(np.hstack((all_gt, all_pred)),
                          columns=gt_columns + pred_columns)

    df = pd.concat((df,
                    pd.DataFrame({
                        'path': all_paths,
                        'study_id': all_study_id,
                        'slice_num': all_slice_num
                    })),
                   axis=1)
    df.to_csv(df_out_path, index=False)
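
# Example invocations (a sketch; model name, epoch and output paths are
# placeholders). The mode string switches on test-time augmentation: 'h_flip',
# 'v_flip' and 'rot90' each add the corresponding deterministic transform,
# while 'normal' leaves the input unchanged:
#
#     predict('some_registered_model', fold=0, epoch=6, is_test=False,
#             df_out_path='pred/oof_fold0.csv', mode='normal')
#     predict('some_registered_model', fold=0, epoch=6, is_test=True,
#             df_out_path='pred/test_fold0_hflip.csv', mode='h_flip')
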
def train(model_name, fold, run=None, resume_epoch=-1, use_apex=False):
    model_str = build_model_str(model_name, fold, run)
    model_info = MODELS[model_name]

    checkpoints_dir = f'{BaseConfig.checkpoints_dir}/{model_str}'
    tensorboard_dir = f'{BaseConfig.tensorboard_dir}/{model_str}'
    oof_dir = f'{BaseConfig.oof_dir}/{model_str}'
    os.makedirs(checkpoints_dir, exist_ok=True)
    os.makedirs(tensorboard_dir, exist_ok=True)
    os.makedirs(oof_dir, exist_ok=True)
    print('\n', model_name, '\n')

    logger = SummaryWriter(log_dir=tensorboard_dir)

    model = model_info.factory(**model_info.args)
    model = model.cuda()

    # try:
    #     torchsummary.summary(model, (4, 512, 512))
    #     print('\n', model_name, '\n')
    # except:
    #     raise
    #     pass

    # model = torch.nn.DataParallel(model).cuda()
    model = model.cuda()

    augmentations = [
        albumentations.ShiftScaleRotate(shift_limit=16. / 256,
                                        scale_limit=0.05,
                                        rotate_limit=30,
                                        interpolation=cv2.INTER_LINEAR,
                                        border_mode=cv2.BORDER_REPLICATE,
                                        p=0.80),
    ]
    if model_info.use_vflip:
        augmentations += [albumentations.Flip(), albumentations.RandomRotate90()]
    else:
        augmentations += [albumentations.HorizontalFlip()]

    dataset_train = dataset.IntracranialDataset(
        csv_file='5fold-test-rev3.csv',
        folds=[f for f in range(BaseConfig.nb_folds) if f != fold],
        preprocess_func=albumentations.Compose(augmentations),
        **model_info.dataset_args)
    dataset_valid = dataset.IntracranialDataset(
        csv_file='5fold-test-rev3.csv',
        folds=[fold],
        preprocess_func=None,
        **model_info.dataset_args)

    data_loaders = {
        'train': DataLoader(dataset_train,
                            num_workers=8,
                            shuffle=True,
                            batch_size=model_info.batch_size),
        'val': DataLoader(dataset_valid,
                          shuffle=False,
                          num_workers=8,
                          batch_size=model_info.batch_size)
    }

    # Extra loaders that yield single-slice inputs, used for the first
    # single_slice_steps epochs before switching to the full N-slice input.
    if model_info.single_slice_steps > 0:
        augmentations = [
            albumentations.ShiftScaleRotate(shift_limit=16. / 256,
                                            scale_limit=0.05,
                                            rotate_limit=30,
                                            interpolation=cv2.INTER_LINEAR,
                                            border_mode=cv2.BORDER_REPLICATE,
                                            p=0.80),
        ]
        if model_info.use_vflip:
            augmentations += [albumentations.Flip(), albumentations.RandomRotate90()]
        else:
            augmentations += [albumentations.HorizontalFlip()]

        dataset_train_1_slice = dataset.IntracranialDataset(
            csv_file='5fold-test-rev3.csv',
            folds=[f for f in range(BaseConfig.nb_folds) if f != fold],
            preprocess_func=albumentations.Compose(augmentations),
            **{
                **model_info.dataset_args,
                "num_slices": 1
            })
        dataset_valid_1_slice = dataset.IntracranialDataset(
            csv_file='5fold-test-rev3.csv',
            folds=[fold],
            preprocess_func=None,
            **{
                **model_info.dataset_args,
                "num_slices": 1
            })

        data_loaders['train_1_slice'] = DataLoader(
            dataset_train_1_slice,
            num_workers=8,
            shuffle=True,
            batch_size=model_info.batch_size * 2)
        data_loaders['val_1_slice'] = DataLoader(
            dataset_valid_1_slice,
            shuffle=False,
            num_workers=8,
            batch_size=model_info.batch_size * 2)

    model.train()

    # Per-class BCE weights; the last ('any') class is weighted 2x.
    class_weights = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 2.0]).cuda()

    def criterium(y_pred, y_true):
        return F.binary_cross_entropy_with_logits(
            y_pred, y_true, class_weights.repeat(y_pred.shape[0], 1))

    # fit the new layers first:
    if resume_epoch == -1 and model_info.is_pretrained:
        model.train()
        model.freeze_encoder()
        data_loader = data_loaders.get('train_1_slice', data_loaders['train'])
        pre_fit_steps = 40000 // model_info.batch_size
        data_iter = tqdm(enumerate(data_loader), total=pre_fit_steps)
        epoch_loss = []
        initial_optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

        for iter_num, data in data_iter:
            if iter_num > pre_fit_steps:
                break
            with torch.set_grad_enabled(True):
                img = data['image'].float().cuda()
                labels = data['labels'].cuda()
                pred = model(img)
                loss = criterium(pred, labels)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 100.0)
                initial_optimizer.step()
                initial_optimizer.zero_grad()
                epoch_loss.append(float(loss))

                data_iter.set_description(
                    f'Loss: Running {np.mean(epoch_loss[-500:]):1.4f} Avg {np.mean(epoch_loss):1.4f}'
                )
        model.unfreeze_encoder()

    optimizer = radam.RAdam(model.parameters(), lr=model_info.initial_lr)
    if use_apex:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O2')

    milestones = [5, 10, 16]
    if model_info.optimiser_milestones:
        milestones = model_info.optimiser_milestones
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=milestones,
                                               gamma=0.2)

    print(
        f'Num training images: {len(dataset_train)} validation images: {len(dataset_valid)}'
    )

    if resume_epoch > -1:
        checkpoint = torch.load(f'{checkpoints_dir}/{resume_epoch:03}.pt')
        print('load', f'{checkpoints_dir}/{resume_epoch:03}.pt')
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if 'amp' in checkpoint:
            amp.load_state_dict(checkpoint['amp'])

    for epoch_num in range(resume_epoch + 1, 7):
        for phase in ['train', 'val']:
            model.train(phase == 'train')
            epoch_loss = []
            epoch_labels = []
            epoch_predictions = []
            epoch_sample_paths = []

            if 'on_epoch' in model.__dir__():
                model.on_epoch(epoch_num)

            if epoch_num < model_info.single_slice_steps:
                data_loader = data_loaders[phase + '_1_slice']
                print("use 1 slice input")
            else:
                data_loader = data_loaders[phase]
                print("use N slices input")

            # if epoch_num == model_info.single_slice_steps:
            #     print("train only conv slices/fn layers")
            #     model.module.freeze_encoder_full()
            #
            # if epoch_num == model_info.single_slice_steps + 1:
            #     print("train all")
            #     model.module.unfreeze_encoder()
            #
            # if -1 < model_info.freeze_bn_step <= epoch_num:
            #     print("freeze bn")
            #     model.module.freeze_bn()

            data_iter = tqdm(enumerate(data_loader),
                             total=len(data_loader),
                             ncols=200)
            for iter_num, data in data_iter:
                img = data['image'].float().cuda()
                labels = data['labels'].float().cuda()

                with torch.set_grad_enabled(phase == 'train'):
                    # if epoch_num == model_info.single_slice_steps and phase == 'train':
                    #     with torch.set_grad_enabled(False):
                    #         model_x = model(img, output_before_combine_slices=True)
                    #     with torch.set_grad_enabled(True):
                    #         pred = model(model_x.detach(), train_last_layers_only=True)
                    # else:
                    pred = model(img)
                    loss = criterium(pred, labels)

                    if phase == 'train':
                        if use_apex:
                            with amp.scale_loss(
                                    loss / model_info.accumulation_steps,
                                    optimizer) as scaled_loss:
                                scaled_loss.backward()
                        else:
                            (loss / model_info.accumulation_steps).backward()

                        if (iter_num + 1) % model_info.accumulation_steps == 0:
                            # if not use_apex:
                            #     torch.nn.utils.clip_grad_norm_(model.parameters(), 32.0)
                            optimizer.step()
                            optimizer.zero_grad()

                    epoch_loss.append(float(loss))
                    epoch_labels.append(labels.detach().cpu().numpy())
                    epoch_predictions.append(
                        torch.sigmoid(pred).detach().cpu().numpy())
                    epoch_sample_paths += data['path']

                data_iter.set_description(
                    f'{epoch_num} Loss: Running {np.mean(epoch_loss[-1000:]):1.4f} Avg {np.mean(epoch_loss):1.4f}'
                )

            logger.add_scalar(f'loss_{phase}', np.mean(epoch_loss), epoch_num)
            logger.add_scalar('lr', optimizer.param_groups[0]['lr'],
                              epoch_num)  # scheduler.get_lr()[0]
            try:
                epoch_labels = np.row_stack(epoch_labels)
                epoch_predictions = np.row_stack(epoch_predictions)
                print(epoch_labels.shape, epoch_predictions.shape)
                log_metrics(logger=logger,
                            phase=phase,
                            epoch_num=epoch_num,
                            y=epoch_labels,
                            y_hat=epoch_predictions)
            except Exception:
                pass
            logger.flush()

            if phase == 'val':
                scheduler.step(epoch=epoch_num)
                torch.save(
                    {
                        'epoch': epoch_num,
                        'sample_paths': epoch_sample_paths,
                        'epoch_labels': epoch_labels,
                        'epoch_predictions': epoch_predictions,
                    }, f'{oof_dir}/{epoch_num:03}.pt')
            else:
                # print(f'{checkpoints_dir}/{epoch_num:03}.pt')
                torch.save(
                    {
                        'epoch': epoch_num,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'amp': amp.state_dict()
                    }, f'{checkpoints_dir}/{epoch_num:03}.pt')
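
# Example invocations (a sketch; the model name is a placeholder and must match a
# key in MODELS). resume_epoch=-1 starts from scratch, including the warm-up of
# the newly added layers for pretrained encoders; a non-negative value reloads
# that epoch's checkpoint. use_apex=True enables mixed precision via NVIDIA apex amp:
#
#     train('some_registered_model', fold=0)
#     train('some_registered_model', fold=2, resume_epoch=3, use_apex=True)
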
# Variant of train() with an auxiliary segmentation head: the model returns
# (classification logits, segmentation map) pairs, and a masked BCE term is
# added for the samples that come with segmentation labels.
def train(model_name, fold, run=None, resume_epoch=-1):
    model_str = build_model_str(model_name, fold, run)
    model_info = MODELS[model_name]

    checkpoints_dir = f'{BaseConfig.checkpoints_dir}/{model_str}'
    tensorboard_dir = f'{BaseConfig.tensorboard_dir}/{model_str}'
    oof_dir = f'{BaseConfig.oof_dir}/{model_str}'
    os.makedirs(checkpoints_dir, exist_ok=True)
    os.makedirs(tensorboard_dir, exist_ok=True)
    os.makedirs(oof_dir, exist_ok=True)
    print('\n', model_name, '\n')

    logger = SummaryWriter(log_dir=tensorboard_dir)

    model = model_info.factory(**model_info.args)
    model = model.cuda()

    # try:
    #     torchsummary.summary(model, (4, 512, 512))
    #     print('\n', model_name, '\n')
    # except:
    #     raise
    #     pass

    model = torch.nn.DataParallel(model).cuda()
    model = model.cuda()

    dataset_train = dataset.IntracranialDataset(
        csv_file='5fold-rev3.csv',
        folds=[f for f in range(BaseConfig.nb_folds) if f != fold],
        preprocess_func=albumentations.Compose([
            albumentations.ShiftScaleRotate(shift_limit=16. / 256,
                                            scale_limit=0.05,
                                            rotate_limit=30,
                                            interpolation=cv2.INTER_LINEAR,
                                            border_mode=cv2.BORDER_REPLICATE,
                                            p=0.7),
            albumentations.Flip(),
            albumentations.RandomRotate90(),
        ]),
        **{
            **model_info.dataset_args,
            "segmentation_oversample": 1
        })
    dataset_valid = dataset.IntracranialDataset(
        csv_file='5fold.csv',
        folds=[fold],
        preprocess_func=None,
        **{
            **model_info.dataset_args,
            "segmentation_oversample": 1
        })

    data_loaders = {
        'train': DataLoader(dataset_train,
                            num_workers=8,
                            shuffle=True,
                            batch_size=model_info.batch_size),
        'val': DataLoader(dataset_valid,
                          shuffle=False,
                          num_workers=8,
                          batch_size=model_info.batch_size)
    }

    dataset_train_1_slice = None
    if model_info.single_slice_steps > 0:
        dataset_train_1_slice = dataset.IntracranialDataset(
            csv_file='5fold-rev3.csv',
            folds=[f for f in range(BaseConfig.nb_folds) if f != fold],
            preprocess_func=albumentations.Compose([
                albumentations.ShiftScaleRotate(shift_limit=16. / 256,
                                                scale_limit=0.05,
                                                rotate_limit=30,
                                                interpolation=cv2.INTER_LINEAR,
                                                border_mode=cv2.BORDER_REPLICATE,
                                                p=0.75),
                albumentations.Flip(),
                albumentations.RandomRotate90()
            ]),
            **{
                **model_info.dataset_args,
                "num_slices": 1
            })
        dataset_valid_1_slice = dataset.IntracranialDataset(
            csv_file='5fold.csv',
            folds=[fold],
            preprocess_func=None,
            **{
                **model_info.dataset_args,
                "num_slices": 1,
                "segmentation_oversample": 1
            })

        data_loaders['train_1_slice'] = DataLoader(
            dataset_train_1_slice,
            num_workers=8,
            shuffle=True,
            batch_size=model_info.batch_size * 2)
        data_loaders['val_1_slice'] = DataLoader(
            dataset_valid_1_slice,
            shuffle=False,
            num_workers=8,
            batch_size=model_info.batch_size * 2)

    model.train()

    optimizer = radam.RAdam(model.parameters(), lr=model_info.initial_lr)

    milestones = [5, 10, 16]
    if model_info.optimiser_milestones:
        milestones = model_info.optimiser_milestones
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=milestones,
                                               gamma=0.2)

    print(
        f'Num training images: {len(dataset_train)} validation images: {len(dataset_valid)}'
    )

    if resume_epoch > -1:
        checkpoint = torch.load(f'{checkpoints_dir}/{resume_epoch:03}.pt')
        print('load', f'{checkpoints_dir}/{resume_epoch:03}.pt')
        model.module.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    class_weights = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 2.0]).cuda()

    def criterium(y_pred, y_true):
        return F.binary_cross_entropy_with_logits(
            y_pred, y_true, class_weights.repeat(y_pred.shape[0], 1))

    # Mask loss is only computed on samples that actually have a segmentation mask.
    def criterium_mask(y_pred, y_true, have_segmentation):
        if not max(have_segmentation):
            return 0
        return F.binary_cross_entropy(y_pred[have_segmentation],
                                      y_true[have_segmentation]) * 10

    # criterium = nn.BCEWithLogitsLoss()

    # fit new layers first:
    if resume_epoch == -1 and model_info.is_pretrained:
        model.train()
        model.module.freeze_encoder()
        data_loader = data_loaders.get('train_1_slice', data_loaders['train'])
        pre_fit_steps = 50000 // model_info.batch_size
        data_iter = tqdm(enumerate(data_loader), total=pre_fit_steps)
        epoch_loss = []
        epoch_loss_mask = []
        initial_optimizer = radam.RAdam(model.parameters(), lr=1e-4)

        for iter_num, data in data_iter:
            if iter_num > pre_fit_steps:
                break
            with torch.set_grad_enabled(True):
                img = data['image'].float().cuda()
                labels = data['labels'].cuda()
                segmentation_labels = data['seg'].cuda()
                have_segmentation = data['have_segmentation']
                have_any_segmentation = max(have_segmentation)

                pred, segmentation = model(img)
                loss_cls = criterium(pred, labels)
                loss_mask = criterium_mask(segmentation,
                                           F.max_pool2d(segmentation_labels, 4),
                                           have_segmentation)
                (loss_cls + loss_mask).backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 100.0)
                initial_optimizer.step()
                initial_optimizer.zero_grad()
                epoch_loss.append(float(loss_cls))
                if have_any_segmentation:
                    epoch_loss_mask.append(float(loss_mask))

                data_iter.set_description(
                    f'Loss: Running {np.mean(epoch_loss[-500:]):1.4f} Avg {np.mean(epoch_loss):1.4f}' +
                    f' Running mask {np.mean(epoch_loss_mask[-500:]):1.4f} Mask {np.mean(epoch_loss_mask):1.4f}'
                )
        model.module.unfreeze_encoder()

    for epoch_num in range(resume_epoch + 1, 8):
        if epoch_num > 3 and dataset_train_1_slice is not None:
            dataset_train_1_slice.segmentation_oversample = 1

        for phase in ['train', 'val']:
            model.train(phase == 'train')
            epoch_loss = []
            epoch_loss_mask = []
            epoch_labels = []
            epoch_predictions = []
            epoch_sample_paths = []

            if 'on_epoch' in model.module.__dir__():
                model.module.on_epoch(epoch_num)

            if epoch_num < model_info.single_slice_steps:
                data_loader = data_loaders[phase + '_1_slice']
                print("use 1 slice input")
            else:
                data_loader = data_loaders[phase]
                print("use N slices input")

            # if epoch_num == model_info.single_slice_steps:
            #     print("train only conv slices/fn layers")
            #     model.module.freeze_encoder_full()
            #
            # if epoch_num == model_info.single_slice_steps + 1:
            #     print("train all")
            #     model.module.unfreeze_encoder()
            #
            # if -1 < model_info.freeze_bn_step <= epoch_num:
            #     print("freeze bn")
            #     model.module.freeze_bn()

            data_iter = tqdm(enumerate(data_loader), total=len(data_loader))
            for iter_num, data in data_iter:
                img = data['image'].float().cuda()
                labels = data['labels'].float().cuda()
                segmentation_labels = data['seg'].cuda()
                have_segmentation = data['have_segmentation']
                have_any_segmentation = max(have_segmentation)

                with torch.set_grad_enabled(phase == 'train'):
                    pred, segmentation = model(img)
                    loss_cls = criterium(pred, labels)
                    loss_mask = criterium_mask(segmentation,
                                               F.max_pool2d(segmentation_labels, 4),
                                               have_segmentation)

                    if phase == 'train':
                        ((loss_cls + loss_mask) /
                         model_info.accumulation_steps).backward()

                        if (iter_num + 1) % model_info.accumulation_steps == 0:
                            torch.nn.utils.clip_grad_norm_(model.parameters(), 16.0)
                            optimizer.step()
                            optimizer.zero_grad()

                    epoch_loss.append(float(loss_cls))
                    if have_any_segmentation:
                        epoch_loss_mask.append(float(loss_mask))

                    epoch_labels.append(labels.detach().cpu().numpy())
                    epoch_predictions.append(
                        torch.sigmoid(pred).detach().cpu().numpy())
                    epoch_sample_paths += data['path']

                data_iter.set_description(
                    f'Loss: Running {np.mean(epoch_loss[-500:]):1.4f} Avg {np.mean(epoch_loss):1.4f}' +
                    f' Running mask {np.mean(epoch_loss_mask[-500:]):1.4f} Mask {np.mean(epoch_loss_mask):1.4f}'
                )

            epoch_labels = np.row_stack(epoch_labels)
            epoch_predictions = np.row_stack(epoch_predictions)

            logger.add_scalar(f'loss_{phase}', np.mean(epoch_loss), epoch_num)
            logger.add_scalar(f'loss_mask_{phase}', np.mean(epoch_loss_mask),
                              epoch_num)
            logger.add_scalar('lr', optimizer.param_groups[0]['lr'],
                              epoch_num)  # scheduler.get_lr()[0]
            try:
                log_metrics(logger=logger,
                            phase=phase,
                            epoch_num=epoch_num,
                            y=epoch_labels,
                            y_hat=epoch_predictions)
            except Exception:
                pass
            logger.flush()

            if phase == 'val':
                scheduler.step(epoch=epoch_num)
                torch.save(
                    {
                        'epoch': epoch_num,
                        'sample_paths': epoch_sample_paths,
                        'epoch_labels': epoch_labels,
                        'epoch_predictions': epoch_predictions,
                    }, f'{oof_dir}/{epoch_num:03}.pt')
            else:
                # print(f'{checkpoints_dir}/{epoch_num:03}.pt')
                torch.save(
                    {
                        'epoch': epoch_num,
                        'model_state_dict': model.module.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                    }, f'{checkpoints_dir}/{epoch_num:03}.pt')
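
# Example invocation of the segmentation-aware variant (a sketch; placeholder
# model name). The model factory registered under this name is expected to
# return (classification logits, segmentation map) pairs; checkpoints and logs
# go to the same BaseConfig directories as the plain classifier training:
#
#     train('some_registered_segmentation_model', fold=0, resume_epoch=-1)
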