def cal_mae(img_root, model_param_path):
    '''
    Calculate and print the MAE (mean absolute count error) of the test data.

    img_root: the root of test image data, passed to create_test_dataloader.
    model_param_path: the path of the saved CSRNet parameters (state dict).
    '''
    # Fall back to the CPU when no GPU is present instead of crashing.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = CSRNet()
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    model.load_state_dict(torch.load(model_param_path, map_location=device))
    model.to(device)
    dataloader = create_test_dataloader(img_root)
    # Other functions in this file iterate create_test_dataloader's result
    # directly, i.e. it is presumably already a DataLoader. Wrapping a
    # DataLoader in another DataLoader fails (a DataLoader is not indexable),
    # so only wrap when a plain dataset was returned.
    if not isinstance(dataloader, torch.utils.data.DataLoader):
        dataloader = torch.utils.data.DataLoader(dataloader, batch_size=1, shuffle=False)
    model.eval()
    mae = 0
    with torch.no_grad():
        for data in tqdm(dataloader):
            image = data['image'].to(device)
            gt_densitymap = data['densitymap'].to(device)
            # forward propagation
            et_dmap = model(image)
            # Absolute difference between predicted and ground-truth counts.
            mae += abs(et_dmap.data.sum() - gt_densitymap.data.sum()).item()
            del image, gt_densitymap, et_dmap
    num_batches = len(dataloader)
    if num_batches == 0:
        print("model_param_path:" + model_param_path + " mae: no test data")
        return
    print("model_param_path:" + model_param_path + " mae:" + str(mae / num_batches))
def count2(img_root, model_param_path):
    """Save a jet-coloured density map and print the estimated count for every test image.

    Runs on the CPU. For the n-th image (1-based) the map is written to
    ``img_root + "/test_data/result/<n>_dmp.jpg"``.

    img_root: root passed to create_test_dataloader; also the base of the output folder.
    model_param_path: path of the saved CSRNet state dict.
    """
    import os

    device = torch.device("cpu")
    model = CSRNet()
    # map_location: a GPU-saved checkpoint otherwise fails to load on the CPU.
    model.load_state_dict(torch.load(model_param_path, map_location=device))
    model.to(device)
    model.eval()
    test_dataloader = create_test_dataloader(img_root)  # dataloader
    result_dir = img_root + "/test_data/result/"
    os.makedirs(result_dir, exist_ok=True)
    with torch.no_grad():
        for i, data in enumerate(tqdm(test_dataloader, ncols=50)):
            image = data['image'].to(device)
            et_dmp = model(image).detach().numpy()
            # First sample of the batch, first channel -> a 2-D map.
            et_dmp = et_dmp[0][0]
            count = np.sum(et_dmp)
            fig = plt.figure()
            plt.axis("off")
            plt.imshow(et_dmp, cmap=CM.jet)
            # Remove axis ticks.
            plt.gca().xaxis.set_major_locator(plt.NullLocator())
            plt.gca().yaxis.set_major_locator(plt.NullLocator())
            # Remove the border around the saved image.
            plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0, wspace=0)
            plt.margins(0, 0)
            plt.savefig(result_dir + str(i + 1) + "_dmp" + ".jpg")
            # Close the figure so one figure per image does not accumulate in memory.
            plt.close(fig)
            print(str(i + 1) + "_" + "renshu:", count)
def cal_mae(img_root, gt_dmap_root, model_param_path):
    '''
    Print the mean absolute counting error of a CSRNet checkpoint on the test set.

    NOTE: img_root and gt_dmap_root are accepted for interface compatibility
    but are not read; the test data comes from cfg.dataset_root and the
    device from cfg.device.

    model_param_path: path of the saved CSRNet state dict.
    '''
    net = CSRNet()
    net.load_state_dict(torch.load(model_param_path, map_location=cfg.device))
    net.to(cfg.device)
    net.eval()

    loader = create_test_dataloader(cfg.dataset_root)
    total_abs_err = 0
    with torch.no_grad():
        for batch in tqdm(loader):
            img = batch['image'].to(cfg.device)
            gt = batch['densitymap'].to(cfg.device)
            pred = net(img).detach()
            # Absolute gap between the predicted and ground-truth head counts.
            total_abs_err += abs(pred.data.sum() - gt.data.sum()).item()
            # Drop references, then release cached GPU memory.
            del batch, img, gt, pred
            torch.cuda.empty_cache()
    print("model_param_path:" + model_param_path + " mae:"
          + str(total_abs_err / len(loader)))
def _save_borderless_image(data, save_path, cmap=None):
    """Render ``data`` with imshow into a fresh figure without axes or margins, save it, and close it."""
    fig = plt.figure()
    plt.imshow(data, cmap=cmap)
    plt.axis('off')
    # Remove axis ticks.
    plt.gca().xaxis.set_major_locator(plt.NullLocator())
    plt.gca().yaxis.set_major_locator(plt.NullLocator())
    # Remove the border around the saved image.
    plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0, wspace=0)
    plt.margins(0, 0)
    plt.savefig(save_path, bbox_inches='tight', dpi=100, pad_inches=-0.04)
    # Close so repeated calls do not draw over (and leak) a shared figure.
    plt.close(fig)


def one_count(img_path, model_param_path):
    """Estimate the count of a single image and save visualisations.

    Writes three JPEGs under ``./data_test/test_data/result/<filenum>/``:
    the source image (``_src``), the estimated density map in the jet
    colormap (``_et``), and a 0.2/0.8 weighted blend of the two (``_overlap``).

    img_path: path of the input image.
    model_param_path: path of the saved CSRNet state dict.
    Returns: the estimated count (sum of the density map).
    """
    filename = os.path.basename(img_path)
    filenum = filename.split('.')[0]
    save_path = './data_test/test_data/result/' + filenum
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    img_src_save_path = save_path + '/' + filenum + '_src' + '.jpg'
    img_et_save_path = img_src_save_path.replace('src', 'et')
    img_overlap_save_path = img_src_save_path.replace('src', 'overlap')

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = CSRNet()
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    model.load_state_dict(torch.load(model_param_path, map_location=device))
    model.to(device)
    model.eval()

    img_src = open_img(img_path)
    img = open_img(img_path)
    img_trans = Compose([ToTensor(), Normalize(mean=[0.5, 0.5, 0.5], std=[0.225, 0.225, 0.225])])
    img = img_trans(img)
    # Add the batch dimension and move the input to the model's device.
    img = torch.unsqueeze(img, dim=0).float().to(device)
    with torch.no_grad():
        et_dmap = model(img)
    # .cpu() is required before .numpy() when the model runs on CUDA.
    et_dmap = et_dmap.detach().cpu().numpy()
    et_dmap = et_dmap[0][0]
    people_num = np.sum(et_dmap)
    print(filenum + '_num:', '\t', people_num)

    _save_borderless_image(img_src, img_src_save_path)
    _save_borderless_image(et_dmap, img_et_save_path, cmap=CM.jet)

    # Blend the two saved renderings; the density map is resized to the source size.
    img_src = cv2.imread(img_src_save_path)
    img_et = cv2.imread(img_et_save_path)
    img_et = cv2.resize(img_et, (img_src.shape[1], img_src.shape[0]))
    img_overlap = cv2.addWeighted(img_src, 0.2, img_et, 0.8, 0)
    cv2.imwrite(img_overlap_save_path, img_overlap)
    return people_num
def estimate_density_map_no_gt(img_root, gt_dmap_root, model_param_path, index):
    '''
    Export the estimated density map and input image for every image of the extra test set.

    NOTE: img_root, gt_dmap_root and index are accepted for interface
    compatibility but are not read; data comes from cfg.dataset_root via
    create_test_extra_dataloader, and every image is exported.

    model_param_path: path of the saved CSRNet state dict.

    Files are written to ``export_images_extra/<idx>_<pred>_<actual>_{etdm,image}.png``.
    <actual> is the placeholder 999 because this set has no ground truth.
    '''
    import os

    image_export_folder = 'export_images_extra'
    os.makedirs(image_export_folder, exist_ok=True)
    model = CSRNet()
    model.load_state_dict(torch.load(model_param_path, map_location=cfg.device))
    model.to(cfg.device)
    test_dataloader = create_test_extra_dataloader(cfg.dataset_root)  # dataloader
    model.eval()
    with torch.no_grad():
        for i, data in enumerate(tqdm(test_dataloader)):
            image = data['image'].to(cfg.device)
            # forward propagation
            et_densitymap = model(image).detach()
            pred_count = et_densitymap.data.sum().cpu()
            # No ground truth is available for this set; 999 is a placeholder.
            actual_count = 999
            et_densitymap = et_densitymap.squeeze(0).squeeze(0).cpu().numpy()
            # The image is exported still normalized (not denormalized).
            image = image[0].cpu()
            print(et_densitymap.shape)
            prefix = "{}/{}_{}_{}_".format(image_export_folder,
                                           str(i).zfill(3),
                                           str(int(pred_count)),
                                           str(int(actual_count)))
            # et is the estimated density. A fresh figure per plot prevents
            # images from piling up on the same axes across iterations.
            fig = plt.figure()
            plt.imshow(et_densitymap, cmap=CM.jet)
            plt.savefig(prefix + 'etdm.png')
            plt.close(fig)
            # image
            fig = plt.figure()
            plt.imshow(image.permute(1, 2, 0))
            plt.savefig(prefix + 'image.png')
            plt.close(fig)
            # clear mem
            del i, data, image, et_densitymap, pred_count, actual_count
            torch.cuda.empty_cache()
def count1(img_root, model_param_path):
    """Estimate the count of every batch in the test set.

    img_root: root passed to create_test_dataloader.
    model_param_path: path of the saved CSRNet state dict.
    Returns: list of str, one per batch: the summed density map of the batch's
        first sample, formatted with two decimals. (Previously the value was
        computed and discarded; returning it is backward-compatible.)
    """
    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = CSRNet()
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    model.load_state_dict(torch.load(model_param_path, map_location=device))
    model.to(device)
    model.eval()
    test_dataloader = create_test_dataloader(img_root)  # dataloader
    counts = []
    with torch.no_grad():
        # Iterate with a progress bar.
        for data in tqdm(test_dataloader, ncols=50):
            image = data['image'].to(device)
            et_densitymap = model(image).detach()
            counts.append('%.2f' % (et_densitymap[0].cpu().sum()))
    return counts
def main():
    """Show the predicted density map, the ground-truth density map and the source image.

    Command-line arguments read: args.modelPath (OneFlow checkpoint),
    args.picPath (input image), args.picDensity (HDF5 file with a
    "density" dataset holding the ground-truth density map).
    """
    transform = ST.Compose(
        [
            ST.ToNumpyForVal(),
            ST.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    global args
    args = parser.parse_args()
    model = CSRNet()
    model = model.to("cuda")
    checkpoint = flow.load(args.modelPath)
    model.load_state_dict(checkpoint)
    model.eval()

    img = transform(Image.open(args.picPath).convert("RGB"))
    img = flow.Tensor(img)
    img = img.to("cuda")
    with flow.no_grad():
        output = model(img.unsqueeze(0))
    print("Predicted Count : ", int(output.detach().to("cpu").sum().numpy()))
    # Drop the batch and channel dimensions to get a 2-D map.
    pred_map = output.view(output.shape[2], output.shape[3]).numpy()
    plt.title("Predicted Count")
    plt.imshow(pred_map, cmap=c.jet)
    plt.show()

    # Context manager so the HDF5 file handle is closed.
    with h5py.File(args.picDensity, "r") as gt_file:
        gt_map = np.asarray(gt_file["density"])
    plt.title("Original Count")
    plt.imshow(gt_map, cmap=c.jet)
    # NOTE(review): "+ 1" presumably compensates for int() truncation;
    # round() may be what was intended. Kept to preserve existing output.
    print("Original Count : ", int(np.sum(gt_map)) + 1)
    plt.show()

    print("Original Image")
    plt.title("Original Image")
    plt.imshow(plt.imread(args.picPath))
    plt.show()
def __init__(self):
    """Build the saver, optional TensorBoard writer, data loaders, model, loss,
    optimizer and LR scheduler from the global ``opt`` configuration."""
    self.best_pred = 1e6
    # Define Saver
    self.saver = Saver(opt)
    self.saver.save_experiment_config()
    # visualize
    if opt.visualize:
        # Define Tensorboard Summary
        # NOTE(review): confirm self.writer is only used when opt.visualize is set.
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

    # Dataset dataloader
    self.train_dataset = SHTDataset(opt.train_dir, train=True)
    self.train_loader = DataLoader(self.train_dataset,
                                   num_workers=opt.workers,
                                   shuffle=True,
                                   batch_size=opt.batch_size)  # must be 1
    self.test_dataset = SHTDataset(opt.test_dir, train=False)
    # Per the original note, test images differ in size, so they cannot be
    # stacked into one batch: pin batch_size to 1 instead of opt.batch_size.
    self.test_loader = torch.utils.data.DataLoader(
        self.test_dataset, shuffle=False, batch_size=1
    )

    torch.cuda.manual_seed(opt.seed)

    model = CSRNet()
    self.model = model.to(opt.device)
    if opt.resume:
        if os.path.isfile(opt.pre):
            print("=> loading checkpoint '{}'".format(opt.pre))
            checkpoint = torch.load(opt.pre)
            opt.start_epoch = checkpoint['epoch']
            self.best_pred = checkpoint['best_pred']
            self.model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                opt.pre, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(opt.pre))

    if opt.use_mulgpu:
        self.model = torch.nn.DataParallel(self.model, device_ids=opt.gpu_id)

    self.criterion = nn.MSELoss(reduction='mean').to(opt.device)
    self.optimizer = torch.optim.SGD(self.model.parameters(), opt.lr,
                                     momentum=opt.momentum,
                                     weight_decay=opt.decay)

    # Define lr scheduler: milestones are fractions (opt.steps) of opt.epochs.
    self.scheduler = lr_scheduler.MultiStepLR(
        self.optimizer,
        milestones=[round(opt.epochs * x) for x in opt.steps],
        gamma=opt.scales)
    # Resume the schedule at the checkpoint's epoch.
    self.scheduler.last_epoch = opt.start_epoch - 1
def count3(img_root, model_param_path):
    """Log every test image and its estimated density map to TensorBoard.

    For batch i, the first image (denormalized) is logged under
    ``"<i>/img:"``. Its density map, scaled by its maximum, is logged under
    ``"<i>/dmp_count:<count>"``.

    img_root: root passed to create_test_dataloader.
    model_param_path: path of the saved CSRNet state dict.
    """
    writer = SummaryWriter()
    try:
        device = torch.device(
            "cuda") if torch.cuda.is_available() else torch.device("cpu")
        model = CSRNet()
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        model.load_state_dict(torch.load(model_param_path, map_location=device))
        model.to(device)
        model.eval()
        test_dataloader = create_test_dataloader(img_root)  # dataloader
        with torch.no_grad():
            for i, data in enumerate(tqdm(test_dataloader, ncols=50)):
                image = data['image'].to(device)
                et_densitymap = model(image).detach()
                count = str('%.2f' % (et_densitymap[0].cpu().sum()))
                writer.add_image(str(i) + '/img:', denormalize(image[0].cpu()))
                dmap = et_densitymap[0]
                peak = torch.max(dmap)
                # Scale to [0, 1] for display; an all-zero map would give NaN.
                if peak > 0:
                    dmap = dmap / peak
                writer.add_image(str(i) + "/dmp_count:" + count, dmap)
                print(str(i + 1) + "_img count success")
    finally:
        # Flush pending events to disk and release the event file.
        writer.close()
def cal_mae(img_root, gt_dmap_root, model_param_path):
    '''
    Calculate and print the MAE (mean absolute count error) of the test data.

    NOTE: img_root and gt_dmap_root are not read; test data is loaded from
    Config().dataset_root and the device is Config().device.

    model_param_path: path of the saved CSRNet state dict.
    '''
    cfg = Config()
    device = cfg.device
    model = CSRNet()
    # map_location makes the same code work on GPU and CPU. Previously the
    # CPU variant had to be swapped in by hand.
    model.load_state_dict(torch.load(model_param_path, map_location=device))
    model.to(device)
    # Data loader is built from the config's dataset root.
    dataloader = create_test_dataloader(cfg.dataset_root)
    model.eval()
    mae = 0
    with torch.no_grad():
        for data in tqdm(dataloader):
            img = data['image'].to(device)
            gt_dmap = data['densitymap'].to(device)
            # forward propagation
            et_dmap = model(img)
            mae += abs(et_dmap.data.sum() - gt_dmap.data.sum()).item()
            del img, gt_dmap, et_dmap
    num_batches = len(dataloader)
    if num_batches == 0:
        print("model_param_path:" + model_param_path + " mae: no test data")
        return
    print("model_param_path:" + model_param_path + " mae:" + str(mae / num_batches))
def main():
    """Evaluate a OneFlow CSRNet checkpoint on ShanghaiTech test images and print the MAE.

    args.picSrc selects "part_A_test" or "part_B_test"; args.modelPath is the
    checkpoint. Each ground-truth density map is read from the HDF5 file
    obtained by replacing ".jpg" with ".h5" and "images" with "ground_truth"
    in the image path (dataset "density").
    """
    transform = ST.Compose([
        ST.ToNumpyForVal(),
        ST.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    global args
    args = parser.parse_args()
    root = "./dataset/"
    test_dirs = {
        "part_A_test": os.path.join(root, "part_A_final/test_data", "images"),
        "part_B_test": os.path.join(root, "part_B_final/test_data", "images"),
    }
    path_sets = [test_dirs[args.picSrc]] if args.picSrc in test_dirs else []
    img_paths = [
        img_path
        for path in path_sets
        for img_path in glob.glob(os.path.join(path, "*.jpg"))
    ]
    if not img_paths:
        # Previously this crashed with ZeroDivisionError when averaging.
        print("no test images found for picSrc={}".format(args.picSrc))
        return

    model = CSRNet()
    model = model.to("cuda")
    checkpoint = flow.load(args.modelPath)
    model.load_state_dict(checkpoint)
    model.eval()

    MAE = []
    for img_path in img_paths:
        img = transform(Image.open(img_path).convert("RGB"))
        img = np.asarray(img).astype(np.float32)
        img = flow.Tensor(img, dtype=flow.float32, device="cuda")
        gt_path = img_path.replace(".jpg", ".h5").replace("images", "ground_truth")
        # Context manager so each HDF5 file handle is closed.
        with h5py.File(gt_path, "r") as gt_file:
            groundtruth = np.asarray(gt_file["density"])
        with flow.no_grad():
            output = model(img.unsqueeze(0))
        MAE.append(abs(output.sum().numpy() - np.sum(groundtruth)))
    avg_MAE = sum(MAE) / len(MAE)
    print("test result: MAE:{:.2f}".format(avg_MAE))
def main():
    """Train CSRNet on the images in args.train_path, validating on args.test_path.

    Hyper-parameters are hard-coded below and override the parsed arguments.
    Uses a module-level ``device`` that is not defined in this function
    (assumed to be defined elsewhere in the file — TODO confirm).
    """
    global args, best_prec1
    best_prec1 = 1e6
    args = parser.parse_args()
    args.original_lr = 1e-7
    args.lr = 1e-7
    args.batch_size = 1
    args.momentum = 0.95
    args.decay = 5 * 1e-4
    args.start_epoch = 0
    args.epochs = 400
    args.steps = [-1, 1, 100, 150]
    args.scales = [1, 1, 1, 1]
    args.workers = 4
    args.seed = time.time()
    args.print_freq = 30

    train_list = [
        os.path.join(args.train_path, i) for i in os.listdir(args.train_path)
    ]
    print(train_list)
    # Validation files live in args.test_path. They were previously joined
    # with args.train_path, which produced paths to the wrong directory.
    val_list = [
        os.path.join(args.test_path, j) for j in os.listdir(args.test_path)
    ]

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    torch.cuda.manual_seed(args.seed)

    model = CSRNet()
    model = model.to(device)
    # reduction='sum' is the non-deprecated equivalent of size_average=False.
    criterion = nn.MSELoss(reduction='sum').to(device)
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.decay)

    if args.pre:
        if os.path.isfile(args.pre):
            print("=> loading checkpoint '{}'".format(args.pre))
            checkpoint = torch.load(args.pre)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.pre, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.pre))

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)
        train(train_list, model, criterion, optimizer, epoch)
        prec1 = validate(val_list, model, criterion)
        # Lower MAE is better.
        is_best = prec1 < best_prec1
        best_prec1 = min(prec1, best_prec1)
        print(' * best MAE {mae:.3f} '.format(mae=best_prec1))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.pre,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.task)
def main():
    """Train CSRNet (OneFlow) on the image lists in args.train_json / args.test_json.

    Hyper-parameters are hard-coded below and override the parsed arguments.
    Checkpoints are written via save_checkpoint under args.modelPath.
    """
    global args, best_prec1
    best_prec1 = 1e6
    args = parser.parse_args()
    args.original_lr = 1e-7
    args.lr = 1e-7
    args.batch_size = 1
    args.momentum = 0.95
    args.decay = 5 * 1e-4
    args.start_epoch = 0
    args.epochs = 400
    args.steps = [-1, 1, 100, 150]
    args.scales = [1, 1, 1, 1]
    args.workers = 0
    args.seed = time.time()
    args.print_freq = 30
    with open(args.train_json, "r") as outfile:
        train_list = json.load(outfile)
    with open(args.test_json, "r") as outfile:
        val_list = json.load(outfile)
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    model = CSRNet()
    model = model.to("cuda")
    criterion = nn.MSELoss(reduction="sum").to("cuda")
    optimizer = flow.optim.SGD(model.parameters(),
                               args.lr,
                               momentum=args.momentum,
                               weight_decay=args.decay)
    if args.pre:
        if os.path.isfile(args.pre):
            print("=> loading checkpoint '{}'".format(args.pre))
            checkpoint = flow.load(args.pre)
            args.start_epoch = checkpoint["epoch"]
            best_prec1 = checkpoint["best_prec1"]
            model.load_state_dict(checkpoint["state_dict"])
            # The checkpoints saved below carry no "optimizer" entry, so
            # unconditional loading raised KeyError when resuming from them.
            if "optimizer" in checkpoint:
                optimizer.load_state_dict(checkpoint["optimizer"])
            else:
                print("=> checkpoint has no optimizer state; optimizer starts fresh")
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.pre, checkpoint["epoch"]))
        else:
            print("=> no checkpoint found at '{}'".format(args.pre))
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)
        train(train_list, model, criterion, optimizer, epoch)
        prec1 = validate(val_list, model, criterion)
        # Lower MAE is better.
        is_best = prec1 < best_prec1
        best_prec1 = min(prec1, best_prec1)
        print(" * best MAE {mae:.3f} ".format(mae=best_prec1))
        save_checkpoint(
            {
                "epoch": epoch + 1,
                "arch": args.pre,
                "state_dict": model.state_dict(),
                "best_prec1": best_prec1,
            },
            is_best,
            str(epoch + 1),
            args.modelPath,
        )
def _read_video_frames(video_path, step=10):
    """Return every ``step``-th frame of the video at ``video_path`` as RGB PIL images."""
    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fc = 0
        while fc < frame_count:
            ret, im = cap.read()
            if not ret:
                # A failed read returns im=None; stop instead of processing it.
                break
            if fc % step == 0:
                # OpenCV returns BGR; reverse the channel axis to get RGB.
                rgb = np.ascontiguousarray(im[:, :, ::-1])
                frames.append(Image.fromarray(rgb.astype('uint8'), 'RGB'))
            fc += 1
    finally:
        cap.release()
    return frames


def count(path):
    """
    Evaluate the number of larvae present in the input.

    The input is either an image or a video. An image is evaluated once. For
    a video, every 10th frame is evaluated separately and the per-frame
    results are averaged.

    :param path: local path to an image or video, or a URL (containing "http") of a video
    :return: count (the model's first output; its mean when there is more than one frame)
    :raises ValueError: if no frame could be read from the video
    """
    # Define the device (processor) type.
    if torch.cuda.is_available():
        device = torch.device('cuda')
        print('Current procesor is GPU')
    else:
        device = torch.device('cpu')
        print('Current procesor is CPU')

    # Define the model to use for calculations.
    model = CSRNet()
    # map_location lets GPU-saved weights load on a CPU-only machine.
    model.load_state_dict(torch.load('model_wgts.pth', map_location=device))
    model.to(device)
    model.eval()

    # Load the image or video.
    try:
        # convert('RGB') so RGBA/greyscale images match the 3-channel normalization.
        im_list = [Image.open(path).convert('RGB')]
    except OSError:
        video_path = path
        if 'http' in path:
            # wget.download returns the path of the file it wrote.
            video_path = wget.download(path, out='videos')
        im_list = _read_video_frames(video_path)
    if not im_list:
        raise ValueError('no frames could be read from {}'.format(path))

    # Disable gradients.
    with torch.no_grad():
        # Prepare data for the model.
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        transform_eval = T.Compose([
            T.Resize(255, interpolation=Image.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean, std)
        ])
        # Tensor.to() is not in-place: keep the returned tensor.
        model_input = torch.stack([transform_eval(im) for im in im_list]).to(device)
        results, densities = model(model_input)
        if len(results) > 1:
            results = results.mean()
    return results
def main():
    """Train CSRNet on the samples listed in args.train_csv, validating on args.test_csv.

    Hyper-parameters are hard-coded below and override the parsed arguments.
    Each checkpoint also stores the per-epoch validation MAE history.
    """
    global args, best_prec1
    best_prec1 = 1e6
    args = parser.parse_args()
    args.original_lr = 1e-5
    args.lr = 1e-5
    args.batch_size = 1
    args.momentum = 0.95
    args.decay = 5 * 1e-4
    args.start_epoch = 0
    args.epochs = 100
    args.steps = [-1, 20, 40, 60]
    args.scales = [1, 0.1, 0.1, 0.1]
    args.workers = 4
    args.seed = time.time()
    args.print_freq = 30
    csv_train_path = args.train_csv
    csv_test_path = args.test_csv
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    model = CSRNet()
    model = model.to(device)
    # reduction='sum' is the non-deprecated equivalent of size_average=False.
    criterion = nn.MSELoss(reduction='sum').to(device)
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.decay)
    precs = []
    if args.pre:
        if os.path.isfile(args.pre):
            print("=> loading checkpoint '{}'".format(args.pre))
            checkpoint = torch.load(args.pre)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # Continue the saved MAE history instead of restarting it on resume.
            precs = list(checkpoint.get('MAE_history', []))
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.pre, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.pre))
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)
        train(csv_train_path, model, criterion, optimizer, epoch)
        prec1 = validate(csv_test_path, model, criterion)
        precs.append(prec1)
        # Lower MAE is better.
        is_best = prec1 < best_prec1
        best_prec1 = min(prec1, best_prec1)
        print(' * best MAE {mae:.3f} '.format(mae=best_prec1))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.pre,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
                'MAE_history': precs
            }, is_best, args.task)
def main():
    """Train CSRNet, re-drawing a train/validation split of args.train_path every epoch.

    Hyper-parameters are hard-coded below and override the parsed arguments
    (args.batch_size comes from the command line).
    """
    global args, best_prec1
    best_prec1 = 1e6
    args = parser.parse_args()
    print(args)
    args.original_lr = 1e-7
    args.lr = 1e-7
    args.momentum = 0.95
    args.decay = 5 * 1e-4
    args.start_epoch = 0
    args.epochs = 400
    args.steps = [-1, 1, 100, 150]
    args.scales = [1, 1, 1, 1]
    args.workers = 4
    args.seed = time.time()
    args.print_freq = 30
    # The test list is not used here; only the training list is split below.
    train_list, _ = getTrainAndTestListFromPath(args.train_path, args.test_path)
    splitRatio = 0.8
    print('batch size is ', args.batch_size)
    print('cuda available? {}'.format(torch.cuda.is_available()))
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')

    model = CSRNet()
    model = model.to(device)
    # reduction='sum' is the non-deprecated equivalent of size_average=False.
    criterion = nn.MSELoss(reduction='sum').to(device)
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.decay)

    if args.pre:
        if os.path.isfile(args.pre):
            print("=> loading checkpoint '{}'".format(args.pre))
            checkpoint = torch.load(args.pre)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.pre, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.pre))

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)
        # NOTE(review): the split is re-drawn every epoch, so "best MAE" compares
        # different validation subsets — confirm this is intended.
        subsetTrain, subsetValid = getTrainAndValidateList(train_list, splitRatio)
        train(subsetTrain, model, criterion, optimizer, epoch, device)
        prec1 = validate(subsetValid, model, criterion, device)
        # Lower MAE is better.
        is_best = prec1 < best_prec1
        best_prec1 = min(prec1, best_prec1)
        print(' * best MAE {mae:.3f} '
              .format(mae=best_prec1))
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.pre,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
        }, is_best, args.task)