def main(data_path, batch_size, lr, n_epoch, kernel_size, filter_start,
         sh_degree, depth, n_side, rf_name, wm, gm, csf,
         loss_fn_intensity, loss_fn_non_negativity, loss_fn_sparsity,
         sigma_sparsity, intensity_weight, nn_fodf_weight, sparsity_weight,
         save_path, save_every, normalize, load_state):
    """Train a model.

    Args:
        data_path (str): Data path
        batch_size (int): Batch size
        lr (float): Learning rate
        n_epoch (int): Number of training epochs
        kernel_size (int): Kernel size
        filter_start (int): Number of output features of the first convolution layer
        sh_degree (int): Spherical harmonic degree of the fODF
        depth (int): Graph subsample depth
        n_side (int): Resolution of the Healpix map
        rf_name (str): Response function algorithm name
        wm (float): Use white matter
        gm (float): Use gray matter
        csf (float): Use CSF
        loss_fn_intensity (str): Name of the intensity loss
        loss_fn_non_negativity (str): Name of the non-negativity loss
        loss_fn_sparsity (str): Name of the sparsity loss
        intensity_weight (float): Weight of the intensity loss
        nn_fodf_weight (float): Weight of the non-negativity loss
        sparsity_weight (float): Weight of the sparsity loss
        save_path (str): Save path
        save_every (int): Frequency at which to save the model
        normalize (bool): Normalize the fODFs
        load_state (str): Load a pre-trained network
    """
    # Load the shell and the graph samplings
    shellSampling = ShellSampling(f'{data_path}/bvecs.bvecs', f'{data_path}/bvals.bvals',
                                  sh_degree=sh_degree, max_sh_degree=8)
    graphSampling = HealpixSampling(n_side, depth, sh_degree=sh_degree)

    # Load the image and the mask
    dataset = DMRIDataset(f'{data_path}/features.nii', f'{data_path}/mask.nii')
    dataloader_train = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)
    n_batch = len(dataloader_train)

    # Load the polar filter used for the deconvolution
    polar_filter_equi, polar_filter_inva = load_response_function(
        f'{data_path}/response_functions/{rf_name}',
        wm=wm, gm=gm, csf=csf, max_degree=sh_degree,
        n_shell=len(shellSampling.shell_values))

    # Create the deconvolution model
    model = Model(polar_filter_equi, polar_filter_inva, shellSampling,
                  graphSampling, filter_start, kernel_size, normalize)
    if load_state:
        print(load_state)
        model.load_state_dict(torch.load(load_state), strict=False)

    # Load the model on the GPU
    model = model.to(DEVICE)
    torch.save(model.state_dict(), os.path.join(save_path, 'history', 'epoch_0.pth'))

    # Losses
    intensity_criterion = Loss(loss_fn_intensity)
    non_negativity_criterion = Loss(loss_fn_non_negativity)
    sparsity_criterion = Loss(loss_fn_sparsity, sigma_sparsity)

    # Create the dense interpolation used for the non-negativity and the sparsity losses
    denseGrid_interpolate = ComputeSignal(torch.Tensor(graphSampling.sampling.SH2S))
    denseGrid_interpolate = denseGrid_interpolate.to(DEVICE)

    # Optimizer/Scheduler
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = ReduceLROnPlateau(optimizer, threshold=0.01, factor=0.1,
                                  patience=3, verbose=True)
    save_loss = {}
    save_loss['train'] = {}
    writer = SummaryWriter(log_dir=os.path.join(data_path, 'result', 'run',
                                                save_path.split('/')[-1]))
    tb_j = 0

    # Training loop
    for epoch in range(n_epoch):
        # TRAIN
        model.train()

        # Initialize the losses to save and plot
        loss_intensity_ = 0
        loss_sparsity_ = 0
        loss_non_negativity_fodf_ = 0

        # Train on batch.
        for batch, data in enumerate(dataloader_train):
            # Delete all previous gradients
            optimizer.zero_grad()
            to_print = ''

            # Load the data on the DEVICE
            input = data['input'].to(DEVICE)
            output = data['output'].to(DEVICE)
            mask = data['mask'].to(DEVICE)

            x_reconstructed, x_deconvolved_equi_shc, x_deconvolved_inva_shc = model(input)

            ###########################################################################
            # Loss
            ###########################################################################
            # Intensity loss
            loss_intensity = intensity_criterion(x_reconstructed, output, mask)
            loss_intensity_ += loss_intensity.item()
            loss = intensity_weight * loss_intensity
            to_print += ', Intensity: {0:.10f}'.format(loss_intensity.item())

            if x_deconvolved_equi_shc is not None:
                x_deconvolved_equi = denseGrid_interpolate(x_deconvolved_equi_shc)

                ###########################################################################
                # Sparsity loss
                equi_sparse = torch.zeros(x_deconvolved_equi.shape).to(DEVICE)
                loss_sparsity = sparsity_criterion(x_deconvolved_equi, equi_sparse, mask)
                loss_sparsity_ += loss_sparsity.item()
                loss += sparsity_weight * loss_sparsity
                to_print += ', Equi Sparsity: {0:.10f}'.format(loss_sparsity.item())

                ###########################################################################
                # Non-negativity loss
                fodf_neg = torch.min(x_deconvolved_equi, torch.zeros_like(x_deconvolved_equi))
                fodf_neg_zeros = torch.zeros(fodf_neg.shape).to(DEVICE)
                loss_non_negativity_fodf = non_negativity_criterion(fodf_neg, fodf_neg_zeros, mask)
                loss_non_negativity_fodf_ += loss_non_negativity_fodf.item()
                loss += nn_fodf_weight * loss_non_negativity_fodf
                to_print += ', Equi NN: {0:.10f}'.format(loss_non_negativity_fodf.item())

                ###########################################################################
                # Partial volume regularizer
                regularizer_equi = 0.00001 * 1 / torch.mean(
                    x_deconvolved_equi_shc[mask == 1][:, :, 0]) * np.sqrt(4 * np.pi)
                loss += regularizer_equi
                to_print += ', Equi regularizer: {0:.10f}'.format(regularizer_equi.item())

            if x_deconvolved_inva_shc is not None:
                ###########################################################################
                # Partial volume regularizer
                regularizer_inva = 0.00001 * 1 / torch.mean(
                    x_deconvolved_inva_shc[mask == 1][:, :, 0]) * np.sqrt(4 * np.pi)
                loss += regularizer_inva
                to_print += ', Inva regularizer: {0:.10f}'.format(regularizer_inva.item())

            ###########################################################################
            # Tensorboard
            tb_j += 1
            writer.add_scalar('Batch/train_intensity', loss_intensity.item(), tb_j)
            writer.add_scalar('Batch/train_sparsity', loss_sparsity.item(), tb_j)
            writer.add_scalar('Batch/train_nn', loss_non_negativity_fodf.item(), tb_j)
            writer.add_scalar('Batch/train_total', loss.item(), tb_j)

            ###########################################################################
            # Print the batch losses
            to_print = 'Epoch [{0}/{1}], Iter [{2}/{3}]: Loss: {4:.10f}'.format(
                epoch + 1, n_epoch, batch + 1, n_batch, loss.item()) + to_print
            print(to_print, end="\r")

            ###########################################################################
            # Loss backward
            loss.backward()
            optimizer.step()

            if (batch + 1) % 500 == 0:
                torch.save(
                    model.state_dict(),
                    os.path.join(save_path, 'history', 'epoch_{0}.pth'.format(epoch + 1)))

        ###########################################################################
        # Save and print the mean losses for the epoch
        print("")
        to_print = ''
        loss_ = 0

        # Mean results of the last epoch
        save_loss['train'][epoch] = {}

        save_loss['train'][epoch]['loss_intensity'] = loss_intensity_ / n_batch
        save_loss['train'][epoch]['weight_loss_intensity'] = intensity_weight
        loss_ += intensity_weight * loss_intensity_
        to_print += ', Intensity: {0:.10f}'.format(loss_intensity_ / n_batch)

        save_loss['train'][epoch]['loss_sparsity'] = loss_sparsity_ / n_batch
        save_loss['train'][epoch]['weight_loss_sparsity'] = sparsity_weight
        loss_ += sparsity_weight * loss_sparsity_
        to_print += ', Sparsity: {0:.10f}'.format(loss_sparsity_ / n_batch)

        save_loss['train'][epoch]['loss_non_negativity_fodf'] = loss_non_negativity_fodf_ / n_batch
        save_loss['train'][epoch]['weight_loss_non_negativity_fodf'] = nn_fodf_weight
        loss_ += nn_fodf_weight * loss_non_negativity_fodf_
        to_print += ', WM fODF NN: {0:.10f}'.format(loss_non_negativity_fodf_ / n_batch)

        save_loss['train'][epoch]['loss'] = loss_ / n_batch
        to_print = 'Epoch [{0}/{1}], Train Loss: {2:.10f}'.format(
            epoch + 1, n_epoch, loss_ / n_batch) + to_print
        print(to_print)

        writer.add_scalar('Epoch/train_intensity', loss_intensity_ / n_batch, epoch)
        writer.add_scalar('Epoch/train_sparsity', loss_sparsity_ / n_batch, epoch)
        writer.add_scalar('Epoch/train_nn', loss_non_negativity_fodf_ / n_batch, epoch)
        writer.add_scalar('Epoch/train_total', loss_ / n_batch, epoch)

        ###########################################################################
        # VALIDATION
        scheduler.step(loss_ / n_batch)
        if epoch == 0:
            min_loss = loss_
            epochs_no_improve = 0
            n_epochs_stop = 5
            early_stop = False
        elif loss_ < min_loss * 0.999:
            epochs_no_improve = 0
            min_loss = loss_
        else:
            epochs_no_improve += 1
        if epoch > 5 and epochs_no_improve == n_epochs_stop:
            print('Early stopping!')
            early_stop = True

        ###########################################################################
        # Save the loss and the model
        with open(os.path.join(save_path, 'history', 'loss.pkl'), 'wb') as f:
            pickle.dump(save_loss, f)
        if (epoch + 1) % save_every == 0:
            torch.save(
                model.state_dict(),
                os.path.join(save_path, 'history', 'epoch_{0}.pth'.format(epoch + 1)))
        if early_stop:
            print("Stopped")
            break
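
# A minimal usage sketch of the training entry point above, assuming illustrative
# values for every argument: the paths, loss-function names, and weights below are
# placeholders, not a recommended or original configuration.
if __name__ == '__main__':
    main(data_path='data/subject01', batch_size=32, lr=1e-2, n_epoch=50,
         kernel_size=3, filter_start=8, sh_degree=8, depth=5, n_side=16,
         rf_name='dhollander', wm=1, gm=1, csf=1,
         loss_fn_intensity='L2', loss_fn_non_negativity='L2',
         loss_fn_sparsity='cauchy', sigma_sparsity=1e-5,
         intensity_weight=1.0, nn_fodf_weight=1.0, sparsity_weight=1e-4,
         save_path='results/run01', save_every=10, normalize=True,
         load_state=None)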
def __init__(self, batch_size=32, optimizer_name="Adam", lr=1e-3, weight_decay=1e-5,
             epochs=200, model_name="model01", gpu_ids=None, resume=None, tqdm=None):
    """
    Args:
        batch_size = (int) batch size for training and validation
        optimizer_name = (string) name of the optimizer (e.g. "Adam")
        lr = (float) learning rate of the optimizer
        weight_decay = (float) weight decay of the optimizer
        epochs = (int) number of training epochs
        model_name = (string) name of the training model; also used as the folder name
        gpu_ids = (list) list of GPU ids (e.g. gpu_ids=[0, 1]); use the CPU if None
        resume = (dict) dict of resume settings
            (resume = {"checkpoint_path": PATH_of_checkpoint, "fine_tuning": True or False});
            learn from scratch if None
        tqdm = (tqdm object) progress bar object; set your own tqdm; no progress bar if None
    """
    # Set params
    self.batch_size = batch_size
    self.epochs = epochs
    self.start_epoch = 0
    self.use_cuda = (gpu_ids is not None) and torch.cuda.is_available()
    self.tqdm = tqdm
    self.use_tqdm = tqdm is not None

    # ------------------------- #
    # Define Utils. (No need to Change.)
    """
    These are Project Modules. You may not have to change these.
        Saver: Save model weights.                          / <utils.saver.Saver()>
        TensorboardSummary: Write tensorboard files.        / <utils.summaries.TensorboardSummary()>
        Evaluator: Calculate some metrics (e.g. Accuracy).  / <utils.metrics.Evaluator()>
    """
    ## ***Define Saver***
    self.saver = Saver(model_name, lr, epochs)
    self.saver.save_experiment_config()

    ## ***Define Tensorboard Summary***
    self.summary = TensorboardSummary(self.saver.experiment_dir)
    self.writer = self.summary.create_summary()

    # ------------------------- #
    # Define Training components. (You have to Change!)
    """
    These are important settings for training. You have to change these.
        make_data_loader: This creates some <DataLoader>s.  / <dataloader.__init__>
        Modeling: You have to define your Model.            / <modeling.modeling.Modeling()>
        Evaluator: You have to define the Evaluator.        / <utils.metrics.Evaluator()>
        Optimizer: You have to define the Optimizer.        / <utils.optimizer.Optimizer()>
        Loss: You have to define the Loss function.         / <utils.loss.Loss()>
    """
    ## ***Define Dataloader***
    self.train_loader, self.val_loader, self.test_loader, self.num_classes = make_data_loader(batch_size)

    ## ***Define Your Model***
    self.model = Modeling(self.num_classes)

    ## ***Define Evaluator***
    self.evaluator = Evaluator(self.num_classes)

    ## ***Define Optimizer***
    self.optimizer = Optimizer(self.model.parameters(), optimizer_name=optimizer_name,
                               lr=lr, weight_decay=weight_decay)

    ## ***Define Loss***
    self.criterion = Loss()

    # ------------------------- #
    # Some settings
    """
    You don't have to touch the code below.
        Using cuda: Enable cuda if you want.
        Resuming checkpoint: You can resume training if you want.
    """
    ## ***Using cuda***
    if self.use_cuda:
        self.model = torch.nn.DataParallel(self.model, device_ids=gpu_ids).cuda()

    ## ***Resuming checkpoint***
    """You can ignore the code below."""
    self.best_pred = 0.0
    if resume is not None:
        if not os.path.isfile(resume["checkpoint_path"]):
            raise RuntimeError("=> no checkpoint found at '{}'".format(resume["checkpoint_path"]))
        checkpoint = torch.load(resume["checkpoint_path"])
        self.start_epoch = checkpoint['epoch']
        if self.use_cuda:
            self.model.module.load_state_dict(checkpoint['state_dict'])
        else:
            self.model.load_state_dict(checkpoint['state_dict'])
        if resume["fine_tuning"]:
            # resume the optimizer parameters if running fine tuning
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.start_epoch = 0
        self.best_pred = checkpoint['best_pred']
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(resume["checkpoint_path"], checkpoint['epoch']))
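
# A minimal construction sketch for the constructor above, assuming the enclosing
# class is named `Trainer` and that a checkpoint exists at the given path; both
# names are placeholders, not part of the original module, so the call is left
# commented out.
# trainer = Trainer(batch_size=64, optimizer_name="Adam", lr=1e-3,
#                   weight_decay=1e-5, epochs=100, model_name="model01",
#                   gpu_ids=[0],
#                   resume={"checkpoint_path": "experiments/model01/checkpoint.pth",
#                           "fine_tuning": False},
#                   tqdm=None)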
def train(train_img_path, train_gt_path, pths_path, batch_size, lr, num_workers,
          epoch_iter, interval, checkpoint, eval_interval, test_img_path, submit_path):
    file_num = len(os.listdir(train_img_path))
    trainset = custom_dataset(train_img_path, train_gt_path)
    train_loader = data.DataLoader(trainset, batch_size=batch_size,
                                   shuffle=True, num_workers=num_workers, drop_last=True)

    criterion = Loss()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = EAST(pretrained=False)
    if checkpoint:
        model.load_state_dict(torch.load(checkpoint))

    data_parallel = False
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        # model = DataParallelModel(model)
        data_parallel = True
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=0)
    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[epoch_iter // 2], gamma=0.1)

    whole_number = epoch_iter * (len(trainset) / batch_size)
    print("epoch size:%d" % (epoch_iter))
    print("batch size:%d" % (batch_size))
    print("data number:%d" % (len(trainset)))

    all_loss = []
    current_i = 0
    for epoch in range(epoch_iter):
        model.train()
        epoch_loss = 0
        epoch_time = time.time()
        for i, (img, gt_score, gt_geo, ignored_map, _) in enumerate(train_loader):
            current_i += 1
            start_time = time.time()
            img, gt_score, gt_geo, ignored_map = img.to(device), gt_score.to(device), \
                gt_geo.to(device), ignored_map.to(device)
            pred_score, pred_geo = model(img)
            loss = criterion(gt_score, pred_score, gt_geo, pred_geo, ignored_map)

            epoch_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_now = scheduler.get_last_lr()
            progress_bar(40, loss.item(), current_i, whole_number, lr_now[0])

        scheduler.step()
        print('epoch_loss is {:.8f}, epoch_time is {:.8f}'.format(
            epoch_loss / int(file_num / batch_size), time.time() - epoch_time))
        all_loss.append(epoch_loss / int(file_num / batch_size))
        print(time.asctime(time.localtime(time.time())))
        plt.plot(all_loss)
        plt.savefig('loss_landscape.png')
        plt.close()
        print('=' * 50)

        if (epoch + 1) % interval == 0:
            state_dict = model.module.state_dict() if data_parallel else model.state_dict()
            torch.save(state_dict,
                       os.path.join(pths_path, 'model_epoch_{}.pth'.format(epoch + 1)))
            with open(os.path.join(pths_path, 'loss.pkl'), 'wb') as output:
                pkl.dump(all_loss, output)
from utils.loss import Loss
from utils.dataset import Yolo_dataset
from config import cfg
from utils.test_mAP import evaluate

if __name__ == "__main__":
    scaler = torch.cuda.amp.GradScaler()
    torch.backends.cudnn.benchmark = True
    anchors = np.array(cfg.anchors).reshape([-1, 2])
    model = YOLOv4(cfg.cfgfile).to(cfg.device)
    # model.load_weights('yolov4.weights')
    # model.load_state_dict(torch.load("weights/Epoch50-Total_Loss2.2437.pth"))
    vis = visdom.Visdom(env='YOLOv4')

    yolo_losses = []
    for i in range(3):
        yolo_losses.append(Loss(i))

    optimizer = optim.Adam(model.parameters(), cfg.lr)
    if cfg.Cosine_lr:
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5, eta_min=1e-5)
    else:
        lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.96)

    train_dataset = Yolo_dataset(train=True)
    val_dataset = Yolo_dataset(train=False)
    train_loader = DataLoader(train_dataset, shuffle=True,
def main():
    global args
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    __normalize = {'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225)}
    TrainImgLoader = torch.utils.data.DataLoader(
        DA(args.datapath, split='train', normalize=__normalize),
        batch_size=args.train_bsize, shuffle=True, num_workers=1, drop_last=False)

    ValImgLoader = torch.utils.data.DataLoader(
        DA(args.datapath, split='val', normalize=__normalize),
        batch_size=args.test_bsize, shuffle=False, num_workers=1, drop_last=False)

    TestImgLoader = torch.utils.data.DataLoader(
        DA(args.datapath, split='test', normalize=__normalize),
        batch_size=args.test_bsize, shuffle=False, num_workers=1, drop_last=False)

    if not os.path.isdir(args.save_path):
        os.makedirs(os.path.join(args.save_path, 'train'))
        os.makedirs(os.path.join(args.save_path, 'test'))
        os.makedirs(os.path.join(args.save_path, 'val'))

    log = logger.setup_logger(args.save_path + '/training.log')
    writer = logger.setup_tensorboard(args.save_path)
    for key, value in sorted(vars(args).items()):
        log.info(str(key) + ':' + str(value))

    model = StereoNet(k=args.stages - 1, r=args.stages - 1, maxdisp=args.maxdisp)
    model = nn.DataParallel(model).cuda()
    model.apply(weights_init)

    criterion = Loss(args)
    optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=args.stepsize, gamma=args.gamma)
    log.info('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    args.start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume):
            log.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            log.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            log.info("=> no checkpoint found at '{}'".format(args.resume))
            log.info("=> will start from scratch.")
    else:
        log.info("Not Resume")

    start_full_time = time.time()
    for epoch in range(args.start_epoch, args.epoch):
        log.info('This is {}-th epoch'.format(epoch))
        train(TrainImgLoader, model, criterion, optimizer, log, writer, epoch)
        test(ValImgLoader, model, log, writer, 'val', epoch)
        savefilename = args.save_path + '/checkpoint.pth'
        torch.save(
            {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }, savefilename)
        scheduler.step()  # adjusts the learning rate

    test(TestImgLoader, model, log, writer, 'test', epoch)
    log.info('full training time = {:.2f} Hours'.format(
        (time.time() - start_full_time) / 3600))
            schp.save_schp_checkpoint(
                {
                    'state_dict': schp_model.state_dict(),
                    'cycle_n': cycle_n,
                },
                False,
                "checkpoints",
                filename=f'schp_{opts.model}_{opts.dataset}_cycle{cycle_n}_checkpoint.pth')
            # schp.save_schp_checkpoint({
            #     'state_dict': schp_model.state_dict(),
            #     'cycle_n': cycle_n,
            # }, False, '/content/drive/MyDrive/',
            #     filename=f'schp_{opts.model}_{opts.dataset}_checkpoint.pth')

        torch.cuda.empty_cache()
        criterion.end_log(len(train_loader))


if __name__ == '__main__':
    opts = get_argparser().parse_args(args=[])
    if 'ACE2P' in opts.model:
        opts.loss_type = 'SCP'
        opts.use_mixup = False

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_ids
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    opts.device = device
    print("Device: %s" % device)

    criterion = Loss(opts)
    main(criterion)
    criterion.plot_loss('/content/drive/MyDrive/', len(criterion.log))
def add_loss(self, decay=0.999):
    self.loss = Loss(decay)
def main(args, logger, summary):
    cudnn.enabled = True    # Enable benchmark mode in cudnn: the built-in
    cudnn.benchmark = True  # auto-tuner finds the best algorithm for our hardware

    seed = random.randint(1, 10000)
    logger.info('======>random seed {}'.format(seed))
    random.seed(seed)        # python random seed
    np.random.seed(seed)     # numpy random seed
    torch.manual_seed(seed)  # cpu random seed

    # train_set = VaiHinGen(root=args.root, split='trainl', outer_size=2*args.image_size, centre_size=args.image_size)
    # test_set = VaiHinGen(root=args.root, split='testl', outer_size=2*args.image_size, centre_size=args.image_size)
    train_set = SkmtDataSet(args, split='train')
    val_set = SkmtDataSet(args, split='val')
    kwargs = {'num_workers': args.workers, 'pin_memory': True}
    train_loader = DataLoader(train_set, batch_size=args.batch_size, drop_last=True,
                              shuffle=False, **kwargs)
    test_loader = DataLoader(val_set, batch_size=1, drop_last=True, shuffle=False, **kwargs)

    logger.info('======> building network')
    # set model
    model = build_skmtnet(backbone='resnet50', auxiliary_head=args.auxiliary,
                          trunk_head='deeplab', num_classes=args.num_classes,
                          output_stride=16)

    logger.info("======> computing network parameters")
    total_paramters = netParams(model)
    logger.info("the number of parameters: " + str(total_paramters))

    # setup optimizer
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9,
                          weight_decay=args.weight_decay)

    # setup save directory
    args.savedir = (args.savedir + '/' + args.model + 'bs' + str(args.batch_size)
                    + 'gpu' + str(args.gpus) + '/')
    if not os.path.exists(args.savedir):
        os.makedirs(args.savedir)

    # setup optimization criterion
    criterion = Loss(args)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)  # set the random seed for all GPUs
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
        model = nn.DataParallel(model).cuda()
        criterion = criterion.cuda()

    start_epoch = 0
    best_epoch = 0
    best_overall = 0.
    best_mIoU = 0.
    best_F1 = 0.
    trainer = Trainer(args=args, dataloader=train_loader, model=model, optimizer=optimizer,
                      criterion=criterion, logger=logger, summary=summary)
    tester = Tester(args=args, dataloader=test_loader, model=model,
                    criterion=criterion, logger=logger, summary=summary)
    writer = summary.create_summary()

    for epoch in range(start_epoch, args.max_epochs):
        trainer.train_one_epoch(epoch, writer)

        if (epoch % args.show_val_interval == 0):
            score, class_iou, class_acc, class_F1 = tester.test_one_epoch(epoch, writer)

            logger.info('======>Now print overall info:')
            for k, v in score.items():
                logger.info('======>{0:^18} {1:^10}'.format(k, v))

            logger.info('======>Now print class acc')
            for k, v in class_acc.items():
                print('{}: {:.5f}'.format(k, v))
                logger.info('======>{0:^18} {1:^10}'.format(k, v))

            logger.info('======>Now print class iou')
            for k, v in class_iou.items():
                print('{}: {:.5f}'.format(k, v))
                logger.info('======>{0:^18} {1:^10}'.format(k, v))

            logger.info('======>Now print class_F1')
            for k, v in class_F1.items():
                logger.info('======>{0:^18} {1:^10}'.format(k, v))

            if score["Mean IoU(8) : \t"] > best_mIoU:
                best_mIoU = score["Mean IoU(8) : \t"]

            if score["Overall Acc : \t"] > best_overall:
                best_overall = score["Overall Acc : \t"]
                # save the model with the best overall Acc
                model_file_name = args.savedir + '/best_model.pth'
                torch.save(model.state_dict(), model_file_name)
                best_epoch = epoch

            if score["Mean F1 : \t"] > best_F1:
                best_F1 = score["Mean F1 : \t"]

            logger.info("======>best mean IoU:{}".format(best_mIoU))
            logger.info("======>best overall : {}".format(best_overall))
            logger.info("======>best F1: {}".format(best_F1))
            logger.info("======>best epoch: {}".format(best_epoch))

    # save the model
    model_file_name = args.savedir + '/model.pth'
    state = {"epoch": epoch + 1, "model": model.state_dict()}
    logger.info('======> Now beginning to save model.')
    torch.save(state, model_file_name)
    logger.info('======> Save done.')