def __init__(self):
    """Build an HRNet-W48 feature extractor initialized from ImageNet weights.

    NOTE(review): parsing command-line arguments inside __init__ couples this
    module to sys.argv; constructing it from a script with unrelated CLI
    flags will fail or misbehave — consider accepting a config object
    instead. TODO confirm this is only ever built from the matching script.
    """
    super(FeatureExtractionHRNetW48, self).__init__()
    parser = argparse.ArgumentParser(description='Train classification network')
    parser.add_argument('--cfg',
                        help='experiment configure file name',
                        # required=True,
                        type=str,
                        default='cls_hrnet_w48_sgd_lr5e-2_wd1e-4_bs32_x100.yaml' )
    parser.add_argument('--modelDir', help='model directory', type=str, default='')
    parser.add_argument('--logDir', help='log directory', type=str, default='')
    parser.add_argument('--dataDir', help='data directory', type=str, default='')
    parser.add_argument('--testModel', help='testModel', type=str,
                        default='hrnetv2_w48_imagenet_pretrained.pth')
    args = parser.parse_args()
    update_config(config, args)
    self.HRNet = get_cls_net(config)
    # hard-coded checkpoint path duplicates the --testModel default; the CLI
    # value is effectively ignored here — TODO confirm intended
    self.HRNet.init_weights(pretrained = 'hrnetv2_w48_imagenet_pretrained.pth')
def parse_args():
    """Read the face-mask-overlay CLI options, push them into the global
    config via update_config, and hand back the parsed namespace."""
    cli = argparse.ArgumentParser(description="Face Mask Overlay")
    cli.add_argument(
        "--cfg",
        type=str,
        required=True,
        help="experiment configuration filename",
    )
    cli.add_argument(
        "--landmark_model",
        type=str,
        required=True,
        help="path to model for landmarks exctraction",
    )
    cli.add_argument(
        "--detector_model",
        type=str,
        default="detection/face_detector.prototxt",
        help="path to detector model",
    )
    cli.add_argument(
        "--detector_weights",
        type=str,
        default="detection/face_detector.caffemodel",
        help="path to detector weights",
    )
    cli.add_argument(
        "--mask_image",
        type=str,
        required=True,
        help="path to a .png file with a mask",
    )
    cli.add_argument("--device", default="cpu", help="Device to inference on")
    parsed = cli.parse_args()
    update_config(config, parsed)
    return parsed
def parse_args():
    """Parse the single required '--cfg' option and apply it to the config."""
    cli = argparse.ArgumentParser(description='Train Face Alignment')
    cli.add_argument('--cfg', type=str, required=True,
                     help='experiment configuration filename')
    parsed = cli.parse_args()
    update_config(config, parsed)
    return parsed
def get_args():
    """Parse '--cfg' plus any trailing config overrides and sync them into
    the global config object."""
    cli = argparse.ArgumentParser(description='Tissue_test')
    cli.add_argument('--cfg', type=str, help='path of config file')
    # REMAINDER sweeps every leftover command-line token into args.opts
    cli.add_argument('opts', nargs=argparse.REMAINDER, type=None,
                     help='modify some configs')
    parsed = cli.parse_args()
    update_config(config, parsed)
    return parsed
def get_args():
    """Parse command-line arguments and apply them to the global config.

    Returns:
        argparse.Namespace: the parsed arguments — the same object that was
        passed to ``update_config``.
    """
    parser = argparse.ArgumentParser(description='ECDP_NCIC')
    parser.add_argument('--cfg', type=str, help='path of config name')
    parser.add_argument('opts', type=None, nargs=argparse.REMAINDER,
                        help='modify some default cfgs')
    args = parser.parse_args()
    update_config(config, args)
    # BUG FIX: previously returned parser.parse_args() a second time, which
    # re-parsed sys.argv and returned a *different* namespace object from
    # the one update_config received; return the original namespace instead.
    return args
def get_args():
    """Parse command-line arguments for MIL_TISSUE and update the global config.

    Returns:
        argparse.Namespace: parsed arguments, with ``opts`` holding every
        leftover command-line token as a list.
    """
    parser = argparse.ArgumentParser(description='MIL_TISSUE')
    parser.add_argument('--cfg', type=str, help='experiment configure file name')
    # FIX: the original had a stray line-continuation backslash followed by an
    # inline comment after this call, which is a SyntaxError in Python; the
    # comment (translated from Chinese) said: "wrap all remaining
    # command-line arguments into a list" — that is what REMAINDER does.
    parser.add_argument('opts', type=None, nargs=argparse.REMAINDER,
                        help='modify the options using command-line')
    args = parser.parse_args()
    update_config(config, args)
    return args
def generate_kpts(video_name, smooth=False):
    """Run person detection + 2D pose estimation over every frame of a video.

    Args:
        video_name: path to the input video file.
        smooth: when True, pass predictions through smooth_filter.

    Returns:
        np.ndarray of per-frame keypoints for the first detected person only.

    NOTE(review): requires a CUDA device (pose_model.cuda()); frames where
    detection/preprocessing fails are skipped, so the output may hold fewer
    entries than the video has frames — TODO confirm callers tolerate that.
    """
    human_model = yolo_model()
    args = get_args()
    update_config(cfg, args)
    cam = cv2.VideoCapture(video_name)
    video_length = int(cam.get(cv2.CAP_PROP_FRAME_COUNT))
    # # ret_val, input_image = cam.read()
    # # Video writer
    # fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    # input_fps = cam.get(cv2.CAP_PROP_FPS)
    pose_model = model_load(cfg)
    pose_model.cuda()
    # collect keypoints coordinate
    kpts_result = []
    for i in tqdm(range(video_length)):
        ret_val, input_image = cam.read()
        try:
            bboxs, scores = yolo_det(input_image, human_model)
            # bbox is coordinate location
            inputs, origin_img, center, scale = preprocess(
                input_image, bboxs, scores, cfg)
        except Exception as e:
            # best-effort: log the failure and move on to the next frame
            print(e)
            continue
        with torch.no_grad():
            # compute output heatmap; channel flip suggests BGR->RGB — TODO confirm
            inputs = inputs[:, [2, 1, 0]]
            output = pose_model(inputs.cuda())
            # compute coordinate
            preds, maxvals = get_final_preds(cfg,
                                             output.clone().cpu().numpy(),
                                             np.asarray(center),
                                             np.asarray(scale))
        # if len(preds) != 1:
        #     print('here')
        if smooth:
            # smooth and fine-tune coordinates
            preds = smooth_filter(preds)
        # 3D video pose (only support single human)
        kpts_result.append(preds[0])
    result = np.array(kpts_result)
    return result
def post(self):
    """Persist configuration values submitted from the /config form.

    Parses the raw urlencoded POST body by hand, copies known keys into the
    shared config dict, collapses the scan_methods_* checkboxes into a
    comma-separated string, and writes the result to CHECK_CONF_FILE.

    NOTE(review): ``urllib.unquote`` is Python 2 API (``urllib.parse.unquote``
    on Python 3) — confirm the target interpreter.
    """
    scan_methods = []
    # conf_all aliases (not copies) the shared config dict, so the
    # assignments below mutate common.conf in place
    conf_all = common.conf
    for i in self.request.body.decode().split("&"):
        # each i is one 'key=value' pair; split only on the first '='
        para = secure.clear(urllib.unquote(i.split("=", 1)[0]))
        value = secure.clear(urllib.unquote(i.split("=", 1)[1]))
        if para in conf_all.keys():
            conf_all[para] = value
        elif "scan_methods" in para:
            # checkbox names look like 'scan_methods_get'; keep the suffix
            scan_methods.append(para[para.rindex("_") + 1:].upper())
    conf_all["scan_methods"] = ",".join(scan_methods)
    update_config(conf_all, CHECK_CONF_FILE)
    return self.write(out.alert("Success!", "/config"))
def build_model(args):
    """Prepare output directories and the evaluation dataloader for a test run.

    Derives a log directory name from the calibration path and model name,
    creates image/uv/normal-image output folders, builds the dataset class
    named by cfg.DATASET.DATASET, buffers it, and iterates it once.

    NOTE(review): despite the name, no network is constructed or returned —
    the model-building section is commented out; confirm intent.
    """
    print('Load config...')
    update_config(cfg, args)
    # print("Setup Log ...")
    # encode calibration name, resolution and checkpoint id into the log path
    log_dir = cfg.TEST.CALIB_DIR.split('/')
    log_dir = os.path.join(
        cfg.TEST.MODEL_DIR, cfg.TRAIN.EXP_NAME + '_' +
        cfg.TEST.CALIB_NAME[:-4] + '_resol_' +
        str(cfg.DATASET.OUTPUT_SIZE[0]) + '_' + log_dir[-2] + '_' +
        log_dir[-1].split('_')[0] + '_' + log_dir[-1].split('_')[1] + '_' +
        cfg.TEST.MODEL_NAME.split('-')[-1].split('.')[0])
    cond_mkdir(log_dir)
    # suffix the image folder with the frame range when more than one frame
    if len(cfg.TEST.FRAME_RANGE) > 1:
        save_dir_img_ext = cfg.TEST.SAVE_FOLDER + '_' + str(
            cfg.TEST.FRAME_RANGE[0]) + '_' + str(cfg.TEST.FRAME_RANGE[-1])
    else:
        save_dir_img_ext = cfg.TEST.SAVE_FOLDER
    save_dir_img = os.path.join(log_dir, save_dir_img_ext)
    save_dir_uv = save_dir_img + '_uv'
    save_dir_nimg = save_dir_img + '_nimg'
    cond_mkdir(save_dir_img)
    cond_mkdir(save_dir_uv)
    cond_mkdir(save_dir_nimg)
    print("Build dataloader ...")
    # NOTE(review): eval() on a config-provided class name — trusted configs only
    view_dataset = eval(cfg.DATASET.DATASET)(cfg, isTrain=False)
    print("*" * 100)
    # print('Build Network...')
    # model_net = eval(cfg.MODEL.NAME)(cfg, isTrain=False)
    # model = model_net
    # model.setup(cfg)
    # print('Loading Model...')
    # checkpoint_path = cfg.TEST.MODEL_PATH
    # if os.path.exists(checkpoint_path):
    #     pass
    # elif os.path.exists(os.path.join(cfg.TEST.MODEL_DIR, checkpoint_path)):
    #     checkpoint_path = os.path.join(cfg.TEST.MODEL_DIR,checkpoint_path)
    print('Start buffering data for inference...')
    view_dataloader = DataLoader(view_dataset,
                                 batch_size=cfg.TEST.BATCH_SIZE,
                                 shuffle=False,
                                 num_workers=cfg.WORKERS)
    view_dataset.buffer_all()
    i = 0
    for view_data in view_dataloader:
        # NOTE(review): view_data.__len__() is the *batch*'s length, not the
        # loader's — the printed denominator looks wrong; verify
        print(str(i) + '/' + str(view_data.__len__()))
        i += 1
def main():
    """Evaluate a pose-estimation model on the validation set.

    Loads the config from CLI args, builds the model named in
    cfg.MODEL.NAME, restores weights (either cfg.TEST.MODEL_FILE or the
    run's model_best.pth), then runs validate() over the validation loader.
    """
    args = parse_args()
    update_config(cfg, args)
    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'valid')
    logger.info(pprint.pformat(args))
    logger.info(cfg)
    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED
    # NOTE(review): eval() on a config-provided name — trusted configs only
    model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(
        cfg, is_train=False
    )
    if cfg.TEST.MODEL_FILE != '':
        print(cfg.TEST.MODEL_FILE)
        logger.info('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
        # strict=False tolerates missing/extra keys in the checkpoint
        model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=False)
    else:
        model_state_file = os.path.join(
            final_output_dir, 'model_best.pth'
        )
        logger.info('=> loading model from {}'.format(model_state_file))
        model.load_state_dict(torch.load(model_state_file))
    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()
    # define loss function (criterion) and optimizer
    criterion = JointsMSELoss(
        use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT
    ).cuda()
    # Data loading code
    valid_dataset = get_valid_dataset(cfg)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=cfg.TEST.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=True
    )
    # evaluate on validation set
    validate(cfg, valid_loader, valid_dataset, model, criterion,
             final_output_dir)
def main():
    """Evaluate a pose network on the test split defined by the config.

    Builds the dataset inline with ImageNet normalization and falls back to
    the run's final_state.pth when cfg.TEST.MODEL_FILE is empty.
    """
    # Config
    args = parse_args()
    update_config(cfg, args)
    # Build logger
    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'valid')
    logger.info(pprint.pformat(args))
    logger.info(cfg)
    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED
    # Build model
    # NOTE(review): eval() on a config-provided name — trusted configs only
    model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(cfg,
                                                               is_train=False)
    if cfg.TEST.MODEL_FILE:
        logger.info('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
        # strict=False tolerates missing/extra keys in the checkpoint
        model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=False)
    else:
        model_state_file = os.path.join(final_output_dir, 'final_state.pth')
        logger.info('=> loading model from {}'.format(model_state_file))
        model.load_state_dict(torch.load(model_state_file))
    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()
    # Build loss
    criterion = JointsMSELoss(
        use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT).cuda()
    # Build dataloader (ImageNet channel statistics)
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    valid_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False,
        transforms.Compose([transforms.ToTensor(), normalize]))
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=cfg.TEST.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=True)
    # Evaluation
    validate(cfg, valid_loader, valid_dataset, model, criterion,
             final_output_dir, tb_log_dir)
def __init__(self, save_path):
    """Set up detector/classifier placeholders, inference thresholds, label
    tables, and load the path configuration file."""
    # models are created lazily later; start with empty slots
    self.yolo = None
    self.hrnet = None
    self.small_class_lable = None
    # bind the shared config object, then overlay it with the path config
    self.config = config
    config_file = "./config/con_path.yaml"
    update_config(self.config, config_file)
    self.save_path = save_path
    # YOLO inference settings
    self.yolo_size = 416
    self.yolo_conf_thres = 0.85
    self.yolo_nms_thres = 0.4
    # coarse garment categories and their fine-grained counterparts
    self.big_label = ['upclothes', 'downclothes', 'dress', 'shoes', 'cap']
    self.small_label = ['boots', 'baseballcap', 'dcoat', 'fshirt', 'fshoes',
                        'gallus', 'hat', 'highheels', 'jumpsuit', 'lcoat',
                        'ldress', 'lshoes', 'lskirt', 'mshirt', 'pants',
                        'polo', 'scoat', 'sdress', 'shirt', 'shorts',
                        'sshoes', 'sskirt', 'suit-coat', 'suit-pants',
                        'sweater', 'tshirt', 'wsweater']
def parse_args():
    """Parse segmentation-training options and fold them into the config."""
    cli = argparse.ArgumentParser(description='Train segmentation network')
    cli.add_argument('--cfg', required=True, type=str,
                     help='experiment configure file name')
    # REMAINDER collects every leftover token into args.opts
    cli.add_argument('opts', default=None, nargs=argparse.REMAINDER,
                     help="Modify config options using the command-line")
    parsed = cli.parse_args()
    update_config(config, parsed)
    return parsed
def post(self):
    """Handle the scan-config form: optionally save checkbox states as the
    new defaults, then sync the in-memory used_pocs list with the form.

    NOTE(review): nesting reconstructed from flattened source — the
    used_pocs sync loop is assumed to run on every POST, not only when
    'save_default' is present; confirm against the original file.
    """
    if "save_default" in self.request.arguments:
        # serialize every checkbox to the literal strings "True"/"False"
        pocs_default = dict()
        for p in self.request.body_arguments:
            if self.get_argument(p) == "true":
                pocs_default[p] = "True"
            else:
                pocs_default[p] = "False"
        update_config(pocs_default, DEFAULT_POCS_PATH)
    # toggle membership of each poc in the shared used_pocs list
    for p in common.all_pocs:
        if ((self.get_argument(p) == "true")
                and (p not in common.used_pocs)):
            common.used_pocs.append(p)
        elif ((self.get_argument(p) == "false")
                and (p in common.used_pocs)):
            common.used_pocs.remove(p)
    return self.write(out.jump("/scan_config"))
def getKptsFromImage(human_model, pose_model, image, smooth=None):
    """Detect people in one image and return 2D keypoints for the first one.

    Args:
        human_model: YOLO person detector.
        pose_model: pose network (inputs are moved to CUDA before the call).
        image: input image array.
        smooth: unused in this function — presumably kept for API symmetry
            with the video path; verify before removing.

    Returns:
        Keypoint array for the first detected person (preds[0]).
    """
    args = get_args()
    update_config(cfg, args)
    bboxs, scores = yolo_det(image, human_model)
    # bbox is coordinate location
    inputs, origin_img, center, scale = preprocess(image, bboxs, scores, cfg)
    with torch.no_grad():
        # compute output heatmap; channel flip suggests BGR->RGB — TODO confirm
        inputs = inputs[:, [2, 1, 0]]
        output = pose_model(inputs.cuda())
        # compute coordinate
        preds, maxvals = get_final_preds(
            cfg, output.clone().cpu().numpy(), np.asarray(center),
            np.asarray(scale))
    # 3D video pose (only support single human)
    return preds[0]
def parse_args():
    """Parse alignment options; the config is only applied when the given
    cfg path actually exists on disk."""
    cli = argparse.ArgumentParser(description='Train Face Alignment')
    cli.add_argument('--cfg', type=str, default='None',
                     help='configuration filename')
    cli.add_argument('--model-file', type=str,
                     default='./output/FFL3/HRNet-106-Points/onlyMask.pth',
                     help='model parameters')
    parsed = cli.parse_args()
    if os.path.exists(parsed.cfg):
        update_config(config, parsed)
    return parsed
def parse_args():
    """Parse experiment-name and config options; the config file is only
    applied when it exists on disk."""
    cli = argparse.ArgumentParser(description='Train Face Alignment')
    cli.add_argument('--experiment_name', type=str,
                     default='HRNet-106-Points', help='experiment name')
    cli.add_argument('--cfg', type=str, default='None',
                     help='configuration filename')
    parsed = cli.parse_args()
    if os.path.exists(parsed.cfg):
        update_config(config, parsed)
    return parsed
def parse_args():
    """Parse export-time options (config, weights, optional ONNX target)."""
    cli = argparse.ArgumentParser(description='Train Face Alignment')
    cli.add_argument('--cfg', type=str, required=True,
                     help='experiment configuration filename')
    cli.add_argument('--model-file', type=str, required=True,
                     help='model parameters')
    cli.add_argument('--onnx-export', type=str, default='',
                     help="convert model to onnx")
    parsed = cli.parse_args()
    update_config(config, parsed)
    return parsed
def parse_args():
    """Parse evaluation options, including which facial region to score."""
    cli = argparse.ArgumentParser(description='Train Face Alignment')
    cli.add_argument('--cfg', type=str,
                     default='./experiments/300w/face_alignment_300w_hrnet_w18.yaml',
                     help='experiment configuration filename')
    cli.add_argument('--model-file', type=str,
                     default='./models/HR18-DSM-old.pth',
                     help='model parameters')
    cli.add_argument("--aim", type=str, default="all",
                     choices=["eye", 'mouth', 'all'])
    parsed = cli.parse_args()
    update_config(config, parsed)
    return parsed
def get(self):
    """Render the /config page; when 'restore' is requested, first reset
    the active configuration from the defaults file."""
    if "restore" in self.request.arguments:
        try:
            with open(DEFAULT_CONF_FILE, 'r') as handler:
                default_configuration = json.loads(handler.read())
            update_config(default_configuration, CHECK_CONF_FILE)
        except Exception as e:
            # restore failed: log it and rewrite the current config instead
            logger.error("Fail to restore the default configuration.%s" %
                         str(e))
            update_config(common.conf, CHECK_CONF_FILE)
        else:
            common.conf = default_configuration
            logger.success("Restored default configuration.")
    # pre-compute the 'checked' attribute for each HTTP-method checkbox
    scan_methods = {"GET": "", "POST": "", "DELETE": "", "PUT": ""}
    options = common.conf["scan_methods"].split(",")
    for m in options:
        if m.upper() in scan_methods:
            # NOTE(review): membership tests m.upper() but the assignment
            # uses m — a lowercase entry would add a new key instead of
            # checking the uppercase one; verify the config always stores
            # uppercase method names (the POST handler does .upper())
            scan_methods[m] = "checked"
    return self.render("config.html",
                       config=common.conf,
                       scan_methods=scan_methods)
from dataset import FileLoader import glob, os from torch.utils.data import DataLoader import torch from model.HRNet import HighResolutionNet from model import SemanticSegmentationNet from lib.config import config, update_config import torch.backends.cudnn as cudnn import model update_config(config, "hrnet_config.yaml") def worker_init_fn(worker_id): # ! to make the seed chain reproducible, must use the torch random, not numpy # the torch rng from main thread will regenerate a base seed, which is then # copied into the dataloader each time it created (i.e start of each epoch) # then dataloader with this seed will spawn worker, now we reseed the worker worker_info = torch.utils.data.get_worker_info() # to make it more random, simply switch torch.randint to np.randint worker_seed = torch.randint(0, 2**32, (1, ))[0].cpu().item() + worker_id # print('Loader Worker %d Uses RNG Seed: %d' % (worker_id, worker_seed)) # retrieve the dataset copied into this worker process # then set the random seed for each augmentation worker_info.dataset.setup_augmentor(worker_id, worker_seed) return batch_size = {"train": 1, "valid": 1} tr_file_list = glob.glob(
from lib.utils.paf_to_pose import paf_to_pose_cpp parser = argparse.ArgumentParser() parser.add_argument('--cfg', help='experiment configure file name', default='./experiments/vgg19_368x368_sgd.yaml', type=str) parser.add_argument('--weight', type=str, default='../ckpts/openpose.pth') parser.add_argument('opts', help="Modify config options using the command-line", default=None, nargs=argparse.REMAINDER) args = parser.parse_args() # update config file update_config(cfg, args) ''' MS COCO annotation order: 0: nose 1: l eye 2: r eye 3: l ear 4: r ear 5: l shoulder 6: r shoulder 7: l elbow 8: r elbow 9: l wrist 10: r wrist 11: l hip 12: r hip 13: l knee 14: r knee 15: l ankle 16: r ankle The order in this work: (0-'nose' 1-'neck' 2-'right_shoulder' 3-'right_elbow' 4-'right_wrist' 5-'left_shoulder' 6-'left_elbow' 7-'left_wrist' 8-'right_hip' 9-'right_knee' 10-'right_ankle' 11-'left_hip' 12-'left_knee' 13-'left_ankle' 14-'right_eye' 15-'left_eye' 16-'right_ear' 17-'left_ear' )
def main(args):
    """Run OpenPose-style 2D pose inference over folders of frames and save
    per-frame keypoints as OpenPose-format JSON files.

    Args:
        args: namespace with openpose_dir, weight, image_folders (comma
            separated), all_frames, plus whatever update_config consumes.
    """
    sys.path.append(
        args.openpose_dir)  # In case calling from an external script
    from lib.network.rtpose_vgg import get_model
    from lib.network.rtpose_vgg import use_vgg
    from lib.network import im_transform
    from evaluate.coco_eval import get_outputs, handle_paf_and_heat
    from lib.utils.common import Human, BodyPart, CocoPart, CocoColors, CocoPairsRender, draw_humans
    from lib.utils.paf_to_pose import paf_to_pose_cpp
    from lib.config import cfg, update_config
    update_config(cfg, args)
    model = get_model('vgg19')
    model = torch.nn.DataParallel(model).cuda()
    use_vgg(model)
    # model.load_state_dict(torch.load(args.weight))
    checkpoint = torch.load(args.weight)
    epoch = checkpoint['epoch']
    best_loss = checkpoint['best_loss']
    state_dict = checkpoint['state_dict']
    # state_dict = {key.replace("module.",""):value for key, value in state_dict.items()}  # Remove "module." from vgg keys
    model.load_state_dict(state_dict)
    # optimizer.load_state_dict(checkpoint['optimizer'])
    print("=> loaded checkpoint '{}' (epoch {})".format(args.weight, epoch))
    model.float()
    model.eval()
    image_folders = args.image_folders.split(',')
    for i, image_folder in enumerate(image_folders):
        print(
            f"\nProcessing {i} of {len(image_folders)}: {' '.join(image_folder.split('/')[-4:-2])}"
        )
        if args.all_frames:  # Split video and run inference on all frames
            output_dir = os.path.join(os.path.dirname(image_folder),
                                      'predictions', 'pose2d',
                                      'openpose_pytorch_ft_all')
            os.makedirs(output_dir, exist_ok=True)
            video_path = os.path.join(
                image_folder,
                'scan_video.avi')  # break up video and run on all frames
            temp_folder = image_folder.split('/')[-3] + '_openpose'
            image_folder = os.path.join(
                '/tmp', f'{temp_folder}')  # Overwrite image_folder
            os.makedirs(image_folder, exist_ok=True)
            split_video(video_path, image_folder)
        else:  # Just use GT-annotated frames
            output_dir = os.path.join(os.path.dirname(image_folder),
                                      'predictions', 'pose2d',
                                      'openpose_pytorch_ft')
            os.makedirs(output_dir, exist_ok=True)
        img_mask = os.path.join(image_folder, '??????.png')
        img_names = glob(img_mask)
        for img_name in img_names:
            image_file_path = img_name
            oriImg = cv2.imread(image_file_path)  # B,G,R order
            shape_dst = np.min(oriImg.shape[0:2])
            with torch.no_grad():
                paf, heatmap, im_scale = get_outputs(oriImg, model, 'rtpose')
            humans = paf_to_pose_cpp(heatmap, paf, cfg)
            # Save joints in OpenPose format
            image_h, image_w = oriImg.shape[:2]
            people = []
            # NOTE(review): this inner 'i' shadows the outer folder index;
            # harmless only because the outer 'i' is not read again before
            # the next folder iteration — worth renaming
            for i, human in enumerate(humans):
                keypoints = []
                for j in range(18):
                    if j == 8:
                        keypoints.extend([
                            0, 0, 0
                        ])  # Add extra joint (midhip) to correspond to body_25
                    if j not in human.body_parts.keys():
                        keypoints.extend([0, 0, 0])
                    else:
                        body_part = human.body_parts[j]
                        keypoints.extend([
                            body_part.x * image_w, body_part.y * image_h,
                            body_part.score
                        ])
                person = {"person_id": [i - 1], "pose_keypoints_2d": keypoints}
                people.append(person)
            people_dict = {"people": people}
            _, filename = os.path.split(image_file_path)
            name, _ = os.path.splitext(filename)
            frame_id = int(name)
            with open(
                    os.path.join(output_dir,
                                 f"scan_video_{frame_id:012}_keypoints.json"),
                    'w') as outfile:
                json.dump(people_dict, outfile)
        if args.all_frames:
            shutil.rmtree(image_folder)  # Delete image_folder
def main():
    """Distributed training entry point for the neural-rendering models.

    Loads the config, optionally restores a checkpoint, builds the DDP
    model, loss, scheduler and dataloaders, then runs the epoch/iteration
    loop with periodic TensorBoard logging, checkpointing and validation.

    NOTE(review): block nesting (lr_scheduler.step(), checkpointing and the
    validation block treated as per-epoch) was reconstructed from flattened
    source — verify against the original file.
    """
    print('Load config...')
    args = parse_args()
    update_config(cfg, args)
    print("Setup Log ...")
    log_dir, iter_init, epoch_begin, checkpoint_path = create_logger(
        cfg, args.cfg)
    print(args)
    print(cfg)
    print("*" * 100)
    print('Set gpus...' + str(cfg.GPUS)[1:-1])
    print(' Batch size: ' + str(cfg.TRAIN.BATCH_SIZE))
    if not cfg.GPUS == 'None':
        os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.GPUS)[1:-1]
    # import pytorch after set cuda
    import torch
    from torch.utils.data import DataLoader
    from tensorboardX import SummaryWriter
    from lib.models import metric
    from lib.models.render_net import RenderNet
    from lib.models.feature_net import FeatureNet
    from lib.models.merge_net import MergeNet
    from lib.models.feature_pair_net import FeaturePairNet
    from lib.engine.loss import MultiLoss
    import torch.distributed as dist
    from torch.utils.data.distributed import DistributedSampler
    # from utils.encoding import DataParallelModel
    # from utils.encoding import DataParallelCriterion
    from lib.dataset.DomeViewDataset import DomeViewDataset
    from lib.dataset.DPViewDataset import DPViewDataset
    from lib.utils.model import save_checkpoint
    # device = torch.device('cuda: 2'+ str(cfg.GPUS[-1]))
    print("*" * 100)
    print("Build dataloader ...")
    # NOTE(review): eval() on config-provided class names — trusted configs only
    view_dataset = eval(cfg.DATASET.DATASET)(cfg=cfg, is_train=True)
    if cfg.TRAIN.VAL_FREQ > 0:
        print("Build val dataloader ...")
        view_val_dataset = eval(cfg.DATASET.DATASET)(cfg=cfg, is_train=False)
    print("*" * 100)
    print('Build Network...')
    gpu_count = torch.cuda.device_count()
    dist.init_process_group(backend='nccl',
                            init_method=cfg.DIST_URL,
                            world_size=cfg.WORLD_SIZE,
                            rank=cfg.RANK)
    model_net = eval(cfg.MODEL.NAME)(cfg)
    if gpu_count > 1:
        model_net.cuda()
        model = torch.nn.parallel.DistributedDataParallel(
            model_net,
            device_ids=list(cfg.GPUS)
            #, find_unused_parameters = True
        )
    elif gpu_count == 1:
        model = model_net.cuda()
    optimizerG = torch.optim.Adam(model_net.parameters(), lr=cfg.TRAIN.LR)
    if checkpoint_path:
        # resume: restore iteration/epoch counters, optimizer and weights
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        iter_init = checkpoint['iter']
        epoch_begin = checkpoint['epoch']
        optimizerG.load_state_dict(checkpoint['optimizer'])
        model_net.load_state_dict(checkpoint['state_dict'])
        print(' Load checkpoint path from %s' % (checkpoint_path))
    # Loss
    criterion = MultiLoss(cfg)
    criterion.cuda()
    # Optimizer
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizerG,
        cfg.TRAIN.LR_STEP,
        cfg.TRAIN.LR_FACTOR,
        last_epoch=epoch_begin - 1)
    print('Start buffering data for training...')
    view_dataloader = DataLoader(
        view_dataset,
        batch_size=cfg.TRAIN.BATCH_SIZE * gpu_count,
        # shuffle = cfg.TRAIN.SHUFFLE,
        pin_memory=True,
        num_workers=cfg.WORKERS,
        sampler=DistributedSampler(view_dataset))
    view_dataset.buffer_all()
    if cfg.TRAIN.VAL_FREQ > 0:
        print('Start buffering data for validation...')
        view_val_dataloader = DataLoader(
            view_val_dataset,
            batch_size=cfg.TRAIN.BATCH_SIZE,
            shuffle=False,
            num_workers=cfg.WORKERS,
            sampler=DistributedSampler(view_val_dataset))
        view_val_dataset.buffer_all()
    writer = SummaryWriter(log_dir)
    # Activate some model parts
    if cfg.DATASET.DATASET == 'realdome_cx':
        view_data = view_dataset.read_view(0)
        cur_obj_path = view_data['obj_path']
        frame_idx = view_data['f_idx']
        obj_data = view_dataset.objs[frame_idx]
        model_net.init_rasterizer(obj_data, view_dataset.global_RT)
    if type(model_net) == MergeNet:
        imgs, uv_maps = view_dataset.get_all_view()
        model_net.init_all_atlas(imgs, uv_maps)
    print('Begin training... Log in ' + log_dir)
    model.train()
    # model_net.set_mode(is_train = True)
    # model = DataParallelModel(model_net)
    # model.cuda()
    start = time.time()
    iter = iter_init
    for epoch in range(epoch_begin, cfg.TRAIN.END_EPOCH + 1):
        for view_data in view_dataloader:
            # ####################################################################
            # # all put into forward ?
            # # img_gt = view_trgt['img'].cuda()
            # # if cfg.DATASET.DATASET == 'DomeViewDataset':
            # #     uv_map, alpha_map, cur_obj_path = model.module.project_with_rasterizer(cur_obj_path, view_dataset.objs, view_trgt)
            # #     ROI = view_trgt['ROI'].cuda()
            # # elif cfg.DATASET.DATASET == 'DomeViewDataset':
            # #     uv_map = view_trgt['uv_map'].cuda()
            # #     alpha_map = view_trgt['mask'][:,None,:,:].cuda()
            # #     ROI = None
            # outputs = model.forward(uv_map = uv_map,
            #                         img_gt = img_gt,
            #                         alpha_map = alpha_map,
            #                         ROI = ROI)
            outputs = model.forward(view_data, is_train=True)
            img_gt = view_data['img'].cuda()
            alpha_map = view_data['mask'][:, None, :, :].cuda()
            uv_map = view_data['uv_map'].cuda()
            ROI = None
            # mask predictions outside the alpha map / ROI
            if alpha_map is not None:
                outputs = outputs * alpha_map
            if ROI is not None:
                outputs = outputs * ROI
            # ignore loss outside alpha_map and ROI
            if alpha_map is not None:
                img_gt = img_gt * alpha_map
            if ROI is not None:
                img_gt = img_gt * ROI
            # Loss
            loss_g = criterion(outputs, img_gt)
            # loss_g = criterion_parall(outputs, img_gt)
            # loss_rn = criterion.loss_rgb
            # loss_rn_hsv = criterion.loss_hsv
            # loss_atlas = criterion.loss_atlas
            optimizerG.zero_grad()
            loss_g.backward()
            optimizerG.step()
            # chcek gradiant
            if iter == iter_init:
                print('Checking gradiant in first iteration')
                for name, param in model.named_parameters():
                    if param.grad is None:
                        print(name, True if param.grad is not None else False)
            if type(outputs) == list:
                for iP in range(len(outputs)):
                    outputs[iP] = outputs[iP].cuda()
                outputs = torch.cat(outputs, dim=0)
            # get output images
            outputs_img = outputs[:, 0:3, :, :]
            neural_img = outputs[:, 3:6, :, :]
            aligned_uv = None
            # aligned_uv = outputs[:, -2:, : ,:]
            atlas = model_net.get_atalas()
            # Metrics
            log_time = datetime.datetime.now().strftime(
                '%m/%d') + '_' + datetime.datetime.now().strftime('%H:%M:%S')
            with torch.no_grad():
                err_metrics_batch_i = metric.compute_err_metrics_batch(
                    outputs_img * 255.0,
                    img_gt * 255.0,
                    alpha_map,
                    compute_ssim=False)
            # vis
            if not iter % cfg.LOG.PRINT_FREQ:
                img_ref = view_data['img_ref'].cuda()
                vis.writer_add_image(writer, iter, epoch, img_gt, outputs_img,
                                     neural_img, uv_map, aligned_uv, atlas,
                                     img_ref)
            # Log
            loss_list = criterion.loss_list()
            end = time.time()
            iter_time = end - start
            vis.writer_add_scalar(writer, iter, epoch, err_metrics_batch_i,
                                  loss_list, log_time, iter_time)
            iter += 1
            start = time.time()
        lr_scheduler.step()
        if iter % cfg.LOG.CHECKPOINT_FREQ == 0 and iter != 0:
            final_output_dir = os.path.join(
                log_dir, 'model_epoch_%d_iter_%s_.pth' % (epoch, iter))
            is_best_model = False
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'iter': iter + 1,
                    'model': cfg.MODEL.NAME,
                    'state_dict': model_net.state_dict(),
                    'atlas': atlas,
                    'optimizer': optimizerG.state_dict(),
                    # 'best_state_dict': model.module.state_dict(),
                    # 'perf': perf_indicator,
                },
                is_best_model,
                final_output_dir)
        # scipy.io.savemat('/data/NFS/new_disk/chenxin/relightable-nr/data/densepose_cx/logs/dnr/tmp/neural_img_epoch_%d_iter_%s_.npy'% (epoch, iter),
        #                  {"neural_tex": model_net.neural_tex.cpu().clone().detach().numpy()})  # model_net
        # validation
        if cfg.TRAIN.VAL_FREQ > 0:
            if not epoch % cfg.TRAIN.VAL_FREQ:
                print('Begin validation...')
                start_val = time.time()
                with torch.no_grad():
                    # error metrics
                    metric_val = {
                        'mae_valid': [],
                        'mse_valid': [],
                        'psnr_valid': [],
                        'ssim_valid': []
                    }
                    loss_list_val = {
                        'Loss': [],
                        'rgb': [],
                        'hsv': [],
                        'atlas': []
                    }
                    val_iter = 0
                    for view_val_trgt in view_val_dataloader:
                        img_gt = view_val_trgt['img'].cuda()
                        # alpha_map = None
                        # ROI = None
                        # img_gt = view_val_trgt['img'].cuda()
                        # # if cfg.DATASET.DATASET == 'DomeViewDataset':
                        # #     uv_map, alpha_map, cur_obj_path = model.module.project_with_rasterizer(cur_obj_path, view_dataset.objs, view_trgt)
                        # #     ROI = view_trgt['ROI'].cuda()
                        # # elif cfg.DATASET.DATASET == 'DomeViewDataset':
                        # #     uv_map = view_val_trgt['uv_map'].cuda()
                        # #     alpha_map = view_val_trgt['mask'][:,None,:,:].cuda()
                        # #     ROI = None
                        # NOTE(review): forwards view_data (the last *training*
                        # batch), not view_val_trgt — looks like a bug; also
                        # alpha_map/uv_map still hold training-batch values here
                        outputs = model.forward(view_data, is_train=False)
                        outputs_img = outputs[:, 0:3, :, :]
                        neural_img = outputs[:, 3:6, :, :]
                        aligned_uv = outputs[:, -2:, :, :]
                        # ignore loss outside alpha_map and ROI
                        if alpha_map is not None:
                            img_gt = img_gt * alpha_map
                            outputs = outputs * alpha_map
                        if ROI is not None:
                            img_gt = img_gt * ROI
                            outputs = outputs * ROI
                        # Metrics
                        loss_val = criterion(outputs, img_gt)
                        loss_list_val_batch = criterion.loss_list()
                        metric_val_batch = metric.compute_err_metrics_batch(
                            outputs_img * 255.0,
                            img_gt * 255.0,
                            alpha_map,
                            compute_ssim=True)
                        batch_size = outputs_img.shape[0]
                        for i in range(batch_size):
                            for key in list(metric_val.keys()):
                                if key in metric_val_batch.keys():
                                    metric_val[key].append(
                                        metric_val_batch[key][i])
                        for key, val in loss_list_val_batch.items():
                            loss_list_val[key].append(val)
                        if val_iter == 0:
                            iter_id = epoch
                            vis.writer_add_image(writer,
                                                 iter_id,
                                                 epoch,
                                                 img_gt,
                                                 outputs_img,
                                                 neural_img,
                                                 uv_map,
                                                 aligned_uv,
                                                 atlas=None,
                                                 ex_name='Val_')
                        val_iter = val_iter + 1
                    # mean error
                    for key in list(metric_val.keys()):
                        if metric_val[key]:
                            metric_val[key] = np.vstack(metric_val[key])
                            metric_val[key + '_mean'] = metric_val[key].mean()
                        else:
                            metric_val[key + '_mean'] = np.nan
                    for key in loss_list_val.keys():
                        loss_list_val[key] = torch.tensor(
                            loss_list_val[key]).mean()
                    # vis
                    end_val = time.time()
                    val_time = end_val - start_val
                    log_time = datetime.datetime.now().strftime(
                        '%m/%d') + '_' + datetime.datetime.now(
                        ).strftime('%H:%M:%S')
                    iter_id = epoch
                    vis.writer_add_scalar(writer,
                                          iter_id,
                                          epoch,
                                          metric_val,
                                          loss=loss_list_val,
                                          log_time=log_time,
                                          iter_time=val_time,
                                          ex_name='Val')
def main():
    """GAN (Pix2Pix-style) training loop with an auxiliary free-viewpoint
    (FVV) dataset optimized inside the same iteration.

    NOTE(review): reconstructed from flattened source; the checkpoint save
    is assumed to sit inside the iteration loop (iter-based frequency) —
    verify against the original file.
    """
    print('Load config...')
    args = parse_args()
    update_config(cfg, args)
    print("Setup Log ...")
    log_dir, iter_init, epoch_begin, checkpoint_path = create_logger(
        cfg, args.cfg)
    save_img_dir = os.path.join(log_dir, 'images_trainning')
    # NOTE(review): os.mkdir raises if the directory already exists —
    # os.makedirs(..., exist_ok=True) would be safer; confirm intent
    os.mkdir(save_img_dir)
    print(args)
    print(cfg)
    print("*" * 100)
    print('Set gpus...' + str(cfg.GPUS)[1:-1])
    print(' Batch size: ' + str(cfg.TRAIN.BATCH_SIZE))
    # if not cfg.GPUS == 'None':
    #     os.environ["CUDA_VISIBLE_DEVICES"]=str(cfg.GPUS)[1:-1]
    os.environ["CUDA_VISIBLE_DEVICES"] = '1'
    # import pytorch after set cuda
    import torch
    from torch.utils.data import DataLoader
    from tensorboardX import SummaryWriter
    from lib.models import metric
    from lib.models.render_net import RenderNet
    from lib.models.feature_net import FeatureNet
    from lib.models.merge_net import MergeNet
    from lib.models.feature_pair_net import FeaturePairNet
    from lib.models.gan_net import Pix2PixModel
    from lib.engine.loss import MultiLoss
    import torch.distributed as dist
    from torch.utils.data.distributed import DistributedSampler
    # from utils.encoding import DataParallelModel
    # from utils.encoding import DataParallelCriterion
    from lib.dataset.DomeViewDataset import DomeViewDataset
    from lib.dataset.DomeViewDatasetFVV import DomeViewDatasetFVV
    from lib.dataset.DPViewDataset import DPViewDataset
    from lib.utils.model import save_checkpoint
    # device = torch.device('cuda: 2'+ str(cfg.GPUS[-1]))
    print("*" * 100)
    # NOTE(review): the next four expressions are evaluated but their
    # results are discarded — they only have value if printed/logged
    torch.__version__  # Get PyTorch and CUDA version
    torch.cuda.is_available()  # Check that CUDA works
    torch.cuda.device_count()  # Check how many CUDA capable devices you have
    # Print device human readable names
    torch.cuda.get_device_name(0)
    print("Build dataloader ...")
    # NOTE(review): eval() on config-provided class names — trusted configs only
    view_dataset = eval(cfg.DATASET.DATASET)(cfg=cfg, isTrain=True)
    viewFVV_dataset = eval(cfg.DATASET_FVV.DATASET)(cfg=cfg, isTrain=True)
    if cfg.TRAIN.VAL_FREQ > 0:
        print("Build val dataloader ...")
        view_val_dataset = eval(cfg.DATASET.DATASET)(cfg=cfg, isTrain=False)
    print("*" * 100)
    print('Build Network...')
    # gpu_count = torch.cuda.device_count()
    gpu_count = len(cfg.GPUS)
    dist.init_process_group(backend='nccl',
                            init_method=cfg.DIST_URL,
                            world_size=cfg.WORLD_SIZE,
                            rank=cfg.RANK)
    model_net = eval(cfg.MODEL.NAME)(cfg)
    model = model_net
    model.setup(cfg)
    if checkpoint_path:
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        # iter_init = checkpoint['iter']
        # epoch_begin = checkpoint['epoch']
        # to-do try directly load optimizer state_dict with same load_state_dict
        model_net.load_optimizer_state_dict(checkpoint['optimizer'])
        model_net.load_state_dict(checkpoint['state_dict'])
        print(' Load checkpoint path from %s' % (checkpoint_path))
    print('Start buffering data for training...')
    view_dataloader = DataLoader(view_dataset,
                                 batch_size=cfg.TRAIN.BATCH_SIZE * gpu_count,
                                 shuffle=cfg.TRAIN.SHUFFLE,
                                 pin_memory=True,
                                 num_workers=cfg.WORKERS)
    viewFVV_dataloader = DataLoader(viewFVV_dataset,
                                    batch_size=cfg.TRAIN.BATCH_SIZE *
                                    gpu_count,
                                    shuffle=False,
                                    num_workers=0)
    # view_dataloader = DataLoader(view_dataset,
    #                              batch_size = cfg.TRAIN.BATCH_SIZE * gpu_count,
    #                              # shuffle = cfg.TRAIN.SHUFFLE,
    #                              pin_memory=True,
    #                              num_workers = cfg.WORKERS,
    #                              sampler=DistributedSampler(view_dataset))
    view_dataset.buffer_all()
    if cfg.TRAIN.VAL_FREQ > 0:
        print('Start buffering data for validation...')
        view_val_dataloader = DataLoader(
            view_val_dataset,
            batch_size=cfg.TRAIN.BATCH_SIZE,
            shuffle=False,
            num_workers=cfg.WORKERS,
            sampler=DistributedSampler(view_val_dataset))
        view_val_dataset.buffer_all()
    writer = SummaryWriter(log_dir)
    print('Begin training... Log in ' + log_dir)
    model.train()
    start = time.time()
    iter = iter_init
    for epoch in range(epoch_begin, cfg.TRAIN.END_EPOCH + 1):
        model.update_learning_rate()
        viewFVV_dataset.refresh()
        for view_data, viewFVV_data in zip(view_dataloader,
                                           viewFVV_dataloader):
            # one optimizer step on the supervised view batch
            model.optimize_parameters(view_data)
            img_gt = view_data['img']
            alpha_map = view_data['mask'][:, None, :, :]
            uv_map = view_data['uv_map']
            # chcek gradiant
            if iter == iter_init:
                print('Checking gradiant in first iteration')
                for name, param in model.named_parameters():
                    if param.grad is None:
                        print(name, True if param.grad is not None else False)
            outputs = model_net.get_current_results()
            loss_list = model_net.get_current_losses()
            # split rendered RGB and mask channels, then apply the mask
            outputs_img = outputs['rs'][:, 0:3, :, :].clone().detach().cpu()
            outputs_mask = outputs['rs'][:, 3:4, :, :].clone().detach().cpu()
            outputs_img *= outputs_mask
            outputs['img_rs'] = outputs_img
            outputs['mask_rs'] = outputs_mask
            # neural_img = outputs['nimg_rs'].clone().detach().cpu()
            # Metrics
            log_time = datetime.datetime.now().strftime(
                '%m/%d') + '_' + datetime.datetime.now().strftime('%H:%M:%S')
            with torch.no_grad():
                err_metrics_batch_i = metric.compute_err_metrics_batch(
                    outputs_img * 255.0,
                    img_gt * 255.0,
                    alpha_map,
                    compute_ssim=False)
            # sythnesis views: second optimizer step on the FVV batch
            model.optimize_parameters(viewFVV_data)
            loss_list['G_views'] = float(model_net.loss_G_Multi)
            outputs_views = model_net.get_current_results()
            outputs_views_img = outputs_views['rs'][:, 0:3, :, :].clone(
            ).detach().cpu()
            outputs['nimg_rs_view'] = outputs_views_img
            outputs['img_rs_view'] = outputs_views_img
            outputs['uv_map_view'] = viewFVV_data['uv_map']
            # vis
            if not iter % cfg.LOG.PRINT_FREQ:
                vis.writer_add_image_gan(writer,
                                         iter,
                                         epoch,
                                         inputs=view_data,
                                         results=outputs,
                                         save_folder=save_img_dir)
            # Log
            end = time.time()
            iter_time = end - start
            # vis.writer_add_scalar(writer, iter, epoch, err_metrics_batch_i, loss_list, log_time, iter_time)
            vis.writer_add_scalar_gan(writer, iter, epoch,
                                      err_metrics_batch_i, loss_list,
                                      log_time, iter_time)
            iter += 1
            start = time.time()
            if iter % cfg.LOG.CHECKPOINT_FREQ == 0:
                final_output_dir = os.path.join(
                    log_dir, 'model_epoch_%d_iter_%s_.pth' % (epoch, iter))
                is_best_model = False
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'iter': iter + 1,
                        'model': cfg.MODEL.NAME,
                        'state_dict': model_net.state_dict(),
                        # 'atlas': atlas,
                        'optimizer': model_net.optimizer_state_dict()
                        # 'best_state_dict': model.module.state_dict(),
                        # 'perf': perf_indicator,
                    }, is_best_model, final_output_dir)
def main():
    """Run pose estimation over a video file or webcam and write an annotated video.

    Per-frame pipeline: person detection (YOLO) -> optical-flow propagation of
    the previous frame's boxes/keypoints -> HRNet pose inference -> averaging
    with the flow-propagated keypoints -> keypoint drawing + video output.
    """
    args = parse_args()
    update_config(cfg, args)

    if not args.camera:
        # handle video
        cam = cv2.VideoCapture(args.video_input)
        video_length = int(cam.get(cv2.CAP_PROP_FRAME_COUNT))
    else:
        cam = cv2.VideoCapture(0)
        video_length = 30000  # webcam has no frame count; run "long enough"

    ret_val, input_image = cam.read()
    resize_W = 640
    resize_H = 384
    input_image = cv2.resize(input_image, (resize_W, resize_H))

    # Video writer
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    input_fps = cam.get(cv2.CAP_PROP_FPS)
    out = cv2.VideoWriter(args.video_output, fourcc, input_fps,
                          (input_image.shape[1], input_image.shape[0]))

    #### load optical flow model
    flow_model = load_model()

    #### load pose-hrnet MODEL
    pose_model = model_load(cfg)
    pose_model.cuda()

    first_frame = 1
    # flow_boxs / flow_kpts carry the previous frame's boxes/keypoints
    # propagated by optical flow; the int 0 is a sentinel for "not ready yet".
    flow_boxs = 0
    flow_kpts = 0

    for i in tqdm(range(video_length - 1)):
        ret_val, input_image = cam.read()
        input_image = cv2.resize(input_image, (resize_W, resize_H))

        if first_frame == 0:
            try:
                t0 = ckpt_time()
                flow_result = flow_net(pre_image, input_image, flow_model)
                flow_boxs, flow_kpts = flow_propagation(keypoints, flow_result)
                _, t1 = ckpt_time(t0, 1)
            except Exception as e:
                # flow estimation failed for this frame pair; skip the frame
                print(e)
                continue

        pre_image = input_image
        first_frame = 0

        try:
            bboxs, scores = yolo_det(input_image, human_model)  # bbox is coordinate location
            if isinstance(flow_boxs, int):
                # no flow-propagated boxes yet: use the raw detector output
                inputs, origin_img, center, scale = preprocess(input_image, bboxs, scores, cfg)
            else:
                # flow_boxs = (flow_boxs + bboxs) /2
                inputs, origin_img, center, scale = preprocess(input_image, flow_boxs, scores, cfg)
        except Exception:
            # FIX: was a bare `except:` which also swallowed KeyboardInterrupt.
            # Detection/preprocess failed (e.g. no person): emit the raw frame.
            out.write(input_image)
            cv2.namedWindow("enhanced", 0)
            cv2.resizeWindow("enhanced", 1080, 720)
            cv2.imshow('enhanced', input_image)
            cv2.waitKey(2)
            continue

        with torch.no_grad():
            # compute output heatmap (channel swap BGR -> RGB first)
            inputs = inputs[:, [2, 1, 0]]
            output = pose_model(inputs.cuda())
            # compute coordinate
            preds, maxvals = get_final_preds(
                cfg, output.clone().cpu().numpy(), np.asarray(center), np.asarray(scale))

        # average network prediction with flow-propagated keypoints
        if not isinstance(flow_boxs, int):
            preds = (preds + flow_kpts) / 2

        origin_img = np.zeros(origin_img.shape, np.uint8)
        image = plot_keypoint(origin_img, preds, maxvals, 0.1)
        out.write(image)
        keypoints = np.concatenate((preds, maxvals), 2)

        if args.display:
            ########### set display window size
            cv2.namedWindow("enhanced", cv2.WINDOW_GUI_NORMAL)
            cv2.resizeWindow("enhanced", 1920, 1080)
            cv2.imshow('enhanced', image)
            cv2.waitKey(1)
def main():
    """Inference entry point: render one image per test view and build a GIF.

    GPU selection is done through CUDA_VISIBLE_DEVICES *before* torch is
    imported, which is why the torch-related imports live inside this
    function rather than at module level.
    """
    print('Load config...')
    args = parse_args()
    update_config(cfg, args)
    # print("Setup Log ...")
    # Output directory name is assembled from pieces of the calibration path,
    # the output resolution and the checkpoint file name.
    log_dir = cfg.TEST.CALIB_DIR.split('/')
    log_dir = os.path.join(cfg.TEST.MODEL_DIR, cfg.TRAIN.EXP_NAME+'_'+cfg.TEST.CALIB_NAME[:-4]+'_resol_'+str(cfg.DATASET.OUTPUT_SIZE[0])+'_'+log_dir[-2]+'_'+ log_dir[-1].split('_')[0] + '_' + log_dir[-1].split('_')[1] + '_' + cfg.TEST.MODEL_NAME.split('-')[-1].split('.')[0])
    cond_mkdir(log_dir)
    # Embed the rendered frame range in the image folder name when a range is set.
    if len(cfg.TEST.FRAME_RANGE) > 1:
        save_dir_img_ext = cfg.TEST.SAVE_FOLDER + '_' + str(cfg.TEST.FRAME_RANGE[0]) +'_'+ str(cfg.TEST.FRAME_RANGE[-1])
    else:
        save_dir_img_ext = cfg.TEST.SAVE_FOLDER
    save_dir_img = os.path.join(log_dir, save_dir_img_ext)
    cond_mkdir(save_dir_img)
    # log_dir, iter, checkpoint_path = create_logger(cfg, args.cfg)
    # print(args)
    # print(cfg)
    # print("*" * 100)
    print('Set gpus...' + str(cfg.GPUS)[1:-1])
    print(' Batch size: '+ str(cfg.TEST.BATCH_SIZE))
    # Restrict visible GPUs before importing torch so device ordinals match cfg.
    if not cfg.GPUS == 'None':
        os.environ["CUDA_VISIBLE_DEVICES"]=str(cfg.GPUS)[1:-1]
    # import pytorch after set cuda
    import torch
    import torchvision
    from torch.utils.data import DataLoader
    from tensorboardX import SummaryWriter
    from lib.models import metric
    from lib.models.render_net import RenderNet
    from lib.models.feature_net import FeatureNet
    from lib.models.merge_net import MergeNet
    from utils.encoding import DataParallelModel
    from lib.dataset.DomeViewDataset import DomeViewDataset
    from lib.dataset.DPViewDataset import DPViewDataset
    # device = torch.device('cuda: 2')
    # device = torch.device('cuda: '+ str(cfg.GPUS[-1]))
    # print("*" * 100)
    print("Build dataloader ...")
    # dataset for training views
    # NOTE(review): dataset and model classes are resolved by eval() on config
    # strings -- cfg values must name classes imported above.
    view_dataset = eval(cfg.DATASET.DATASET)(cfg = cfg, isTrain=False)
    print("*" * 100)
    print('Build Network...')
    model_net = eval(cfg.MODEL.NAME)(cfg)
    print('Loading Model...')
    # Accept either an absolute checkpoint path or one relative to MODEL_DIR.
    checkpoint_path = cfg.TEST.MODEL_PATH
    if os.path.exists(checkpoint_path):
        pass
    elif os.path.exists(os.path.join(cfg.TEST.MODEL_DIR, checkpoint_path)):
        checkpoint_path = os.path.join(cfg.TEST.MODEL_DIR,checkpoint_path)
    model_net.load_checkpoint(checkpoint_path)
    # NOTE(review): the state dict is loaded again below via torch.load after
    # load_checkpoint() -- looks redundant; confirm load_checkpoint semantics.
    if checkpoint_path:
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        iter_init = checkpoint['iter']
        epoch_begin = checkpoint['epoch']
        model_net.load_state_dict(checkpoint['state_dict'])
        print(' Load checkpoint path from %s'%(checkpoint_path))
    print('Start buffering data for inference...')
    view_dataloader = DataLoader(view_dataset, batch_size = cfg.TEST.BATCH_SIZE, shuffle = False, num_workers = cfg.WORKERS)
    view_dataset.buffer_all()
    # Init Rasterizer
    if cfg.DATASET.DATASET == 'realdome_cx':
        view_data = view_dataset.read_view(0)
        cur_obj_path = view_data['obj_path']
        frame_idx = view_data['f_idx']
        obj_data = view_dataset.objs[frame_idx]
        model_net.init_rasterizer(obj_data, view_dataset.global_RT)
    # model_net.set_parallel(cfg.GPUS)
    # model_net.set_mode(is_train = True)
    model = model_net
    # model = DataParallelModel(model_net)
    model.cuda()
    # NOTE(review): model is left in train() mode during inference -- confirm
    # this is intended (affects e.g. batch-norm statistics).
    model.train()
    print('Begin inference...')
    inter = 0  # running output-image counter, used for file naming
    with torch.no_grad():
        for view_trgt in view_dataloader:
            start = time.time()
            ROI = None
            img_gt = None
            # get image
            # if cfg.DATASET.DATASET == 'realdome_cx':
            #     uv_map, alpha_map, cur_obj_path = model.module.project_with_rasterizer(cur_obj_path, view_dataset.objs, view_trgt)
            # elif cfg.DATASET.DATASET == 'densepose':
            uv_map = view_trgt['uv_map'].cuda()
            # alpha_map = view_trgt['mask'][:,None,:,:].cuda()
            outputs = model.forward(uv_map = uv_map, img_gt = img_gt)
            # Channels 3:6 are kept as the neural image; channels 0:3 are the
            # rendered RGB.  neural_img is currently unused below.
            neural_img = outputs[:, 3:6, : ,:].clamp(min = 0., max = 1.)
            outputs = outputs[:, 0:3, : ,:]
            # Defensive branch: gather list-valued outputs onto the GPU and
            # concatenate along the batch dimension.
            if type(outputs) == list:
                for iP in range(len(outputs)):
                    # outputs[iP] = outputs[iP].to(device)
                    outputs[iP] = outputs[iP].cuda()
                outputs = torch.cat(outputs, dim = 0)
            # img_max_val = 2.0
            # outputs = (outputs * 0.5 + 0.5) * img_max_val # map to [0, img_max_val]
            # if alpha_map:
            #     outputs = outputs * alpha_map
            # save each image in the batch as a zero-padded PNG (RGB -> BGR for cv2)
            for batch_idx in range(0, outputs.shape[0]):
                cv2.imwrite(os.path.join(save_dir_img, str(inter).zfill(5) + '.png'), outputs[batch_idx, :].permute((1, 2, 0)).cpu().detach().numpy()[:, :, ::-1] * 255.)
                inter = inter + 1
            end = time.time()
            print("View %07d t_total %0.4f" % (inter, end - start))
    make_gif(save_dir_img, save_dir_img+'.gif')
def main():
    """Training entry point.

    Reads the experiment config, builds model/dataloaders/criterion/optimizer,
    optionally resumes from ``checkpoint.pth``, then runs the train/validate
    loop, checkpointing and finally saving ``final_state.pth``.
    """
    # set all the configurations
    args = parse_args()
    update_config(cfg, args)

    # set the logger, tb_log_dir means tensorboard logdir
    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'train')
    logger.info(pprint.pformat(args))
    logger.info(cfg)

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    # build up model
    model = get_net(cfg)
    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()

    # Data loading
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, True,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))
    valid_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=cfg.TRAIN.SHUFFLE,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=cfg.TEST.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY)

    # define loss function (criterion) and optimizer
    criterion = get_loss(cfg).cuda()
    optimizer = get_optimizer(cfg, model)

    # load checkpoint model
    best_perf = 0.0
    best_model = False
    last_epoch = -1
    begin_epoch = cfg.TRAIN.BEGIN_EPOCH
    checkpoint_file = os.path.join(final_output_dir, 'checkpoint.pth')
    if cfg.AUTO_RESUME and os.path.exists(checkpoint_file):
        logger.info("=> loading checkpoint '{}'".format(checkpoint_file))
        checkpoint = torch.load(checkpoint_file)
        begin_epoch = checkpoint['epoch']
        best_perf = checkpoint['perf']
        last_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger.info("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file, checkpoint['epoch']))

    # FIX: create the scheduler *after* a possible resume so that last_epoch
    # reflects the restored epoch; the original built it before the resume
    # block, so the LR schedule silently restarted from scratch on resume.
    # (Same ordering as the sibling HRNet training script in this repo.)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, cfg.TRAIN.LR_STEP, cfg.TRAIN.LR_FACTOR,
        last_epoch=last_epoch)

    # training
    for epoch in range(begin_epoch + 1, cfg.TRAIN.END_EPOCH + 1):
        # train for one epoch
        train(cfg, train_loader, model, criterion, optimizer, epoch,
              writer_dict)

        lr_scheduler.step()

        # evaluate on validation set.
        # FIX: the loop's last epoch is END_EPOCH, so the final-epoch check
        # must compare against END_EPOCH (the original `== END_EPOCH + 1`
        # could never be true).
        if epoch % cfg.TRAIN.VAL_FREQ == 0 or epoch == cfg.TRAIN.END_EPOCH:
            perf_indicator = validate(cfg, valid_loader, valid_dataset,
                                      model, criterion, final_output_dir,
                                      tb_log_dir, writer_dict)
            if perf_indicator >= best_perf:
                best_perf = perf_indicator
                best_model = True
            else:
                best_model = False

            # save checkpoint model and best model
            logger.info('=> saving checkpoint to {}'.format(final_output_dir))
            save_checkpoint(
                {
                    'epoch': epoch,
                    'model': cfg.MODEL.NAME,
                    'state_dict': model.state_dict(),
                    'best_state_dict': model.module.state_dict(),
                    'perf': perf_indicator,
                    'optimizer': optimizer.state_dict(),
                }, best_model, final_output_dir)

    # save final model
    final_model_state_file = os.path.join(final_output_dir, 'final_state.pth')
    logger.info(
        '=> saving final model state to {}'.format(final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
def main():
    """End-to-end training of the neural texture + rendering network.

    Builds the (optional) rasterizer, texture mapper and rendering net from
    cfg, optionally resumes from a checkpoint, then alternates training
    iterations with periodic tensorboard logging, validation passes and
    checkpointing.
    """
    print('Load config...')
    args = parse_args()
    update_config(cfg, args)
    # cfg.defrost()
    # cfg.RANK = args.ranka
    # cfg.freeze()

    # device allocation
    print('Set device...')
    #print(cfg.GPUS)
    #os.environ["CUDA_VISIBLE_DEVICES"] = cfg.GPUS
    #device = torch.device('cuda')
    torch.cuda.set_device(cfg.GPUS[0])
    device = torch.device('cuda:' + str(cfg.GPUS[0]))

    print("Build dataloader ...")
    # load texture (optional initialization texture)
    if cfg.DATASET.TEX_PATH:
        texture_init = cv2.cvtColor(cv2.imread(cfg.DATASET.TEX_PATH),
                                    cv2.COLOR_BGR2RGB)
        texture_init_resize = cv2.resize(
            texture_init,
            (cfg.MODEL.TEX_MAPPER.NUM_SIZE, cfg.MODEL.TEX_MAPPER.NUM_SIZE),
            interpolation=cv2.INTER_AREA).astype(np.float32) / 255.0
        texture_init_use = torch.from_numpy(texture_init_resize).to(device)

    # dataset for training views
    view_dataset = dataio.ViewDataset(
        cfg=cfg,
        root_dir=cfg.DATASET.ROOT,
        calib_path=cfg.DATASET.CALIB_PATH,
        calib_format=cfg.DATASET.CALIB_FORMAT,
        sampling_pattern=cfg.TRAIN.SAMPLING_PATTERN,
        precomp_high_dir=cfg.DATASET.PRECOMP_DIR,
        precomp_low_dir=cfg.DATASET.PRECOMP_DIR,
        preset_uv_path=cfg.DATASET.UV_PATH,
    )
    # dataset for validation views
    view_val_dataset = dataio.ViewDataset(
        cfg=cfg,
        root_dir=cfg.DATASET.ROOT,
        calib_path=cfg.DATASET.CALIB_PATH,
        calib_format=cfg.DATASET.CALIB_FORMAT,
        sampling_pattern=cfg.TRAIN.SAMPLING_PATTERN_VAL,
        precomp_high_dir=cfg.DATASET.PRECOMP_DIR,
        precomp_low_dir=cfg.DATASET.PRECOMP_DIR,
    )
    num_view_val = len(view_val_dataset)

    print('Build Network...')
    # Rasterizer (only built when uv/alpha maps are not precomputed)
    cur_obj_path = ''
    if not cfg.DATASET.LOAD_PRECOMPUTE:
        view_data = view_dataset.read_view(0)
        cur_obj_path = view_data['obj_path']
        frame_idx = view_data['f_idx']
        obj_data = view_dataset.objs[frame_idx]
        rasterizer = network.Rasterizer(
            cfg,
            obj_fp=cur_obj_path,
            img_size=cfg.DATASET.OUTPUT_SIZE[0],
            camera_mode=cfg.DATASET.CAM_MODE,
            obj_data=obj_data,
            # preset_uv_path = cfg.DATASET.UV_PATH,
            global_RT=view_dataset.global_RT)
    # texture mapper
    texture_mapper = network.TextureMapper(
        texture_size=cfg.MODEL.TEX_MAPPER.NUM_SIZE,
        texture_num_ch=cfg.MODEL.TEX_MAPPER.NUM_CHANNELS,
        mipmap_level=cfg.MODEL.TEX_MAPPER.MIPMAP_LEVEL,
        apply_sh=cfg.MODEL.TEX_MAPPER.SH_BASIS)
    # render net
    render_net = network.RenderingNet(
        nf0=cfg.MODEL.RENDER_NET.NF0,
        in_channels=cfg.MODEL.TEX_MAPPER.NUM_CHANNELS,
        out_channels=3,
        num_down_unet=5,
        use_gcn=False)
    # interpolater
    interpolater = network.Interpolater()
    # L1 loss
    criterionL1 = nn.L1Loss(reduction='mean').to(device)
    # Optimizer
    optimizerG = torch.optim.Adam(
        list(texture_mapper.parameters()) + list(render_net.parameters()),
        lr=cfg.TRAIN.LR)

    print('Loading Model...')
    iter = 0  # global iteration counter (kept from original; shadows builtin)
    dir_name = os.path.join(
        datetime.datetime.now().strftime('%m-%d') + '_' +
        datetime.datetime.now().strftime('%H-%M-%S') + '_' +
        cfg.TRAIN.SAMPLING_PATTERN + '_' +
        cfg.DATASET.ROOT.strip('/').split('/')[-1])
    # FIX: was `cfg.TRAIN.EXP_NAME is not ''` -- identity comparison with a
    # literal (SyntaxWarning since Python 3.8); value comparison is intended.
    if cfg.TRAIN.EXP_NAME != '':
        dir_name += '_' + cfg.TRAIN.EXP_NAME
    if cfg.AUTO_RESUME:
        checkpoint_path = ''
        if cfg.TRAIN.RESUME and cfg.TRAIN.CHECKPOINT:
            checkpoint_path = cfg.TRAIN.CHECKPOINT
            dir_name = cfg.TRAIN.CHECKPOINT_DIR
            # checkpoint file name encodes epoch and iter as its digit fields
            nums = [
                int(s) for s in cfg.TRAIN.CHECKPOINT_NAME.split('_')
                if s.isdigit()
            ]
            cfg.defrost()
            cfg.TRAIN.BEGIN_EPOCH = nums[0] + 1
            cfg.freeze()
            iter = nums[1] + 1
        elif cfg.MODEL.PRETRAINED:
            # NOTE(review): gated on cfg.MODEL.PRETRAINED but reads
            # cfg.MODEL.PRETRAIN -- confirm both keys exist in the config.
            checkpoint_path = cfg.MODEL.PRETRAIN
        if checkpoint_path:
            print(' Checkpoint_path : %s' % (checkpoint_path))
            util.custom_load([texture_mapper, render_net],
                             ['texture_mapper', 'render_net'],
                             checkpoint_path)
        else:
            print(' Not load params. ')

    # move to device
    texture_mapper.to(device)
    render_net.to(device)
    interpolater.to(device)
    # NOTE(review): rasterizer is only defined when LOAD_PRECOMPUTE is off;
    # this line would raise NameError otherwise (behavior unchanged).
    rasterizer.to(device)

    # keep handles to the unwrapped modules for checkpointing
    texture_mapper_module = texture_mapper
    render_net_module = render_net
    # use multi-GPU
    if len(cfg.GPUS) > 1:
        texture_mapper = nn.DataParallel(texture_mapper, device_ids=cfg.GPUS)
        render_net = nn.DataParallel(render_net, device_ids=cfg.GPUS)
        interpolater = nn.DataParallel(interpolater, device_ids=cfg.GPUS)
        rasterizer = nn.DataParallel(rasterizer, device_ids=cfg.GPUS)
        rasterizer = rasterizer.module
    # set to training mode
    texture_mapper.train()
    render_net.train()
    interpolater.train()
    rasterizer.eval()  # not train now

    part_list = [texture_mapper_module, render_net_module]  # collect all networks
    part_name_list = ['texture_mapper', 'render_net']
    print("*" * 100)
    print("Number of generator parameters:")
    cfg.defrost()
    cfg.MODEL.TEX_MAPPER.NUM_PARAMS = util.print_network(texture_mapper).item()
    cfg.MODEL.RENDER_NET.NUM_PARAMS = util.print_network(render_net).item()
    cfg.freeze()
    print("*" * 100)

    print("Setup Log ...")
    log_dir = os.path.join(cfg.LOG.LOGGING_ROOT, dir_name)
    data_util.cond_mkdir(log_dir)
    val_out_dir = os.path.join(log_dir, 'val_out')
    val_gt_dir = os.path.join(log_dir, 'val_gt')
    val_err_dir = os.path.join(log_dir, 'val_err')
    data_util.cond_mkdir(val_out_dir)
    data_util.cond_mkdir(val_gt_dir)
    data_util.cond_mkdir(val_err_dir)
    util.custom_copy(args.cfg, os.path.join(log_dir, cfg.LOG.CFG_NAME))

    print('Start buffering data for training and validation...')
    view_dataloader = DataLoader(view_dataset,
                                 batch_size=cfg.TRAIN.BATCH_SIZE,
                                 shuffle=cfg.TRAIN.SHUFFLE,
                                 num_workers=8)
    view_val_dataloader = DataLoader(view_val_dataset,
                                     batch_size=cfg.TRAIN.BATCH_SIZE,
                                     shuffle=False,
                                     num_workers=8)
    #view_dataset.buffer_all()
    #view_val_dataset.buffer_all()

    # Save all command line arguments into a txt file in the logging directory for later referene.
    writer = SummaryWriter(log_dir)
    # iter = cfg.TRAIN.BEGIN_EPOCH * len(view_dataset) # pre model is batch-1

    print('Begin training...')
    # init value
    val_log_batch_id = 0
    first_val = True
    img_h, img_w = cfg.DATASET.OUTPUT_SIZE
    for epoch in range(cfg.TRAIN.BEGIN_EPOCH, cfg.TRAIN.END_EPOCH):
        for view_trgt in view_dataloader:
            start = time.time()

            # get image
            img_gt = []
            img_gt.append(view_trgt[0]['img_gt'].to(device))
            ROI = view_trgt[0]['ROI'].to(device)

            # get uvmap alpha
            uv_map = []
            alpha_map = []
            if not cfg.DATASET.LOAD_PRECOMPUTE:
                # raster module: rasterize uv/alpha maps per batch element
                frame_idxs = view_trgt[0]['f_idx'].numpy()
                for batch_idx, frame_idx in enumerate(frame_idxs):
                    obj_path = view_trgt[0]['obj_path'][batch_idx]
                    if cur_obj_path != obj_path:
                        cur_obj_path = obj_path
                        obj_data = view_dataset.objs[frame_idx]
                        rasterizer.update_vs(obj_data['v_attr'])
                    proj = view_trgt[0]['proj'].to(device)[batch_idx, ...]
                    pose = view_trgt[0]['pose'].to(device)[batch_idx, ...]
                    dist_coeffs = view_trgt[0]['dist_coeffs'].to(device)[
                        batch_idx, ...]
                    uv_map_single, alpha_map_single, _, _, _, _, _, _, _, _, _, _, _, _ = \
                        rasterizer(proj = proj[None, ...],
                                   pose = pose[None, ...],
                                   dist_coeffs = dist_coeffs[None, ...],
                                   offset = None,
                                   scale = None,
                                   )
                    uv_map.append(uv_map_single[0, ...].clone().detach())
                    alpha_map.append(alpha_map_single[0, ...].clone().detach())
                # fix alpha map
                uv_map = torch.stack(uv_map, dim=0)
                alpha_map = torch.stack(alpha_map, dim=0)[:, None, :, :]
                # alpha_map = alpha_map * torch.tensor(img_gt[0][:,0,:,:][:,None,:,:] <= (2.0 * 255)).permute(0,2,1,3).to(alpha_map.dtype).to(alpha_map.device)

                # check per iter image (debug dumps)
                for batch_idx, frame_idx in enumerate(frame_idxs):
                    if cfg.DEBUG.SAVE_TRANSFORMED_IMG:
                        save_dir_img_gt = './Debug/image_mask'
                        save_path_img_gt = os.path.join(
                            save_dir_img_gt,
                            '%06d_%03d.png' % (iter, frame_idx))
                        cv2.imwrite(
                            save_path_img_gt,
                            cv2.cvtColor(
                                img_gt[0][batch_idx, ...].cpu().detach().numpy(
                                ).transpose(1, 2, 0) * 255.0,
                                cv2.COLOR_RGB2BGR))
                        #cv2.imwrite(os.path.join(save_dir_img_gt, '%03d_'%frame_idx + img_fn), cv2.cvtColor(img_gt*255.0, cv2.COLOR_BGR2RGB))
                        print(' Save img: ' + save_path_img_gt)
                    if cfg.DEBUG.SAVE_TRANSFORMED_MASK:
                        save_alpha_map = alpha_map.permute(
                            0, 2, 3, 1).cpu().detach().numpy()
                        save_dir_mask = './Debug/image_mask'
                        save_path_mask = os.path.join(
                            save_dir_mask,
                            '%06d_%03d_mask.png' % (iter, frame_idx))
                        cv2.imwrite(save_path_mask,
                                    save_alpha_map[batch_idx, ...] * 255.0)
                        print(' Save mask: ' + save_path_mask)
            else:
                # get precomputed view data
                uv_map = view_trgt[0]['uv_map'].to(device)  # [N, H, W, 2]
                # sh_basis_map = view_trgt[0]['sh_basis_map'].to(device) # [N, H, W, 9]
                alpha_map = view_trgt[0]['alpha_map'][:, None, :, :].to(
                    device)  # [N, 1, H, W]

            # sample texture
            # neural_img = texture_mapper(uv_map, sh_basis_map)
            neural_img = texture_mapper(uv_map)

            # rendering net
            outputs = render_net(neural_img, None)
            img_max_val = 2.0
            outputs = (outputs * 0.5 + 0.5) * img_max_val  # map to [0, img_max_val]
            if type(outputs) is not list:
                outputs = [outputs]

            # # We don't enforce a loss on the outermost 5 pixels to alleviate boundary errors, also weight loss by alpha
            # alpha_map_central = alpha_map[:, :, 5:-5, 5:-5]
            # for i in range(len(view_trgt)):
            #     outputs[i] = outputs[i][:, :, 5:-5, 5:-5] * alpha_map_central
            #     img_gt[i] = img_gt[i][:, :, 5:-5, 5:-5] * alpha_map_central

            # ignore loss outside ROI
            for i in range(len(view_trgt)):
                outputs[i] = outputs[i] * ROI * alpha_map
                img_gt[i] = img_gt[i] * ROI * alpha_map

            # loss on final image
            loss_rn = list()
            for idx in range(len(view_trgt)):
                loss_rn.append(
                    criterionL1(outputs[idx].contiguous().view(-1).float(),
                                img_gt[idx].contiguous().view(-1).float()))
            loss_rn = torch.stack(loss_rn, dim=0).mean()

            # total loss for generator
            loss_g = loss_rn

            optimizerG.zero_grad()
            loss_g.backward()
            optimizerG.step()

            # error metrics
            with torch.no_grad():
                err_metrics_batch_i = metric.compute_err_metrics_batch(
                    outputs[0] * 255.0,
                    img_gt[0] * 255.0,
                    alpha_map,
                    compute_ssim=False)
                # err_metrics_batch_i = metric.compute_err_metrics_batch(outputs[0] * 255.0, img_gt[0] * 255.0, alpha_map_central, compute_ssim = False)

            # tensorboard scalar logs of training data
            writer.add_scalar("loss_g", loss_g, iter)
            writer.add_scalar("loss_rn", loss_rn, iter)
            writer.add_scalar("final_mae_valid",
                              err_metrics_batch_i['mae_valid_mean'], iter)
            writer.add_scalar("final_psnr_valid",
                              err_metrics_batch_i['psnr_valid_mean'], iter)

            end = time.time()
            print(
                "Iter %07d Epoch %03d loss_g %0.4f mae_valid %0.4f psnr_valid %0.4f t_total %0.4f"
                % (iter, epoch, loss_g, err_metrics_batch_i['mae_valid_mean'],
                   err_metrics_batch_i['psnr_valid_mean'], end - start))

            # tensorboard figure logs of training data
            if not iter % cfg.LOG.PRINT_FREQ:
                output_final_vs_gt = []
                for i in range(len(view_trgt)):
                    output_final_vs_gt.append(outputs[i].clamp(min=0., max=1.))
                    output_final_vs_gt.append(img_gt[i].clamp(min=0., max=1.))
                    output_final_vs_gt.append(
                        (outputs[i] - img_gt[i]).abs().clamp(min=0., max=1.))
                output_final_vs_gt = torch.cat(output_final_vs_gt, dim=0)
                raster_uv_maps = torch.cat(
                    (
                        uv_map.permute(0, 3, 1, 2),  # N H W 2 -> N 2 H W
                        torch.zeros(uv_map.shape[0],
                                    1,
                                    img_h,
                                    img_w,
                                    dtype=uv_map.dtype,
                                    device=uv_map.device)),
                    dim=1)
                writer.add_image(
                    "raster_uv_vis",
                    torchvision.utils.make_grid(
                        raster_uv_maps,
                        nrow=raster_uv_maps[0].shape[0],
                        range=(0, 1),
                        scale_each=False,
                        normalize=False).cpu().detach().numpy()
                    [::-1, :, :],  # uv0 -> 0vu (rgb)
                    iter)
                writer.add_image(
                    "output_final_vs_gt",
                    torchvision.utils.make_grid(
                        output_final_vs_gt,
                        nrow=outputs[0].shape[0],  # 3
                        range=(0, 1),
                        scale_each=False,
                        normalize=False).cpu().detach().numpy(),
                    iter)

            # validation
            if not iter % cfg.TRAIN.VAL_FREQ:
                start_val = time.time()
                with torch.no_grad():
                    # error metrics accumulators
                    err_metrics_val = {}
                    err_metrics_val['mae_valid'] = []
                    err_metrics_val['mse_valid'] = []
                    err_metrics_val['psnr_valid'] = []
                    err_metrics_val['ssim_valid'] = []
                    # loop over batches
                    batch_id = 0
                    for view_val_trgt in view_val_dataloader:
                        start_val_i = time.time()

                        # get image
                        img_gt = []
                        img_gt.append(view_val_trgt[0]['img_gt'].to(device))
                        ROI = view_val_trgt[0]['ROI'].to(device)

                        # get uvmap alpha
                        uv_map = []
                        alpha_map = []
                        if not cfg.DATASET.LOAD_PRECOMPUTE:
                            # build raster module
                            frame_idxs = view_val_trgt[0]['f_idx'].numpy()
                            for batch_idx, frame_idx in enumerate(frame_idxs):
                                obj_path = view_val_trgt[0]['obj_path'][
                                    batch_idx]
                                if cur_obj_path != obj_path:
                                    cur_obj_path = obj_path
                                    obj_data = view_val_dataset.objs[frame_idx]
                                    rasterizer.update_vs(obj_data['v_attr'])
                                proj = view_val_trgt[0]['proj'].to(device)[
                                    batch_idx, ...]
                                pose = view_val_trgt[0]['pose'].to(device)[
                                    batch_idx, ...]
                                dist_coeffs = view_val_trgt[0][
                                    'dist_coeffs'].to(device)[batch_idx, ...]
                                uv_map_single, alpha_map_single, _, _, _, _, _, _, _, _, _, _, _, _ = \
                                    rasterizer(proj = proj[None, ...],
                                               pose = pose[None, ...],
                                               dist_coeffs = dist_coeffs[None, ...],
                                               offset = None,
                                               scale = None,
                                               )
                                uv_map.append(
                                    uv_map_single[0, ...].clone().detach())
                                alpha_map.append(
                                    alpha_map_single[0, ...].clone().detach())
                            # fix alpha map
                            uv_map = torch.stack(uv_map, dim=0)
                            alpha_map = torch.stack(alpha_map,
                                                    dim=0)[:, None, :, :]
                            # alpha_map = alpha_map * torch.tensor(img_gt[0][:,0,:,:][:,None,:,:] <= (2.0 * 255)).permute(0,2,1,3).to(alpha_map.dtype).to(alpha_map.device)
                        else:
                            uv_map = view_val_trgt[0]['uv_map'].to(
                                device)  # [N, H, W, 2]
                            # sh_basis_map = view_val_trgt[0]['sh_basis_map'].to(device) # [N, H, W, 9]
                            alpha_map = view_val_trgt[0][
                                'alpha_map'][:, None, :, :].to(
                                    device)  # [N, 1, H, W]

                        view_idx = view_val_trgt[0]['idx']
                        num_view = len(view_val_trgt)
                        img_gt = []
                        for i in range(num_view):
                            img_gt.append(
                                view_val_trgt[i]['img_gt'].to(device))

                        # sample texture
                        # neural_img = texture_mapper(uv_map, sh_basis_map)
                        neural_img = texture_mapper(uv_map)

                        # rendering net
                        outputs = render_net(neural_img, None)
                        img_max_val = 2.0
                        outputs = (outputs * 0.5 + 0.5
                                   ) * img_max_val  # map to [0, img_max_val]
                        if type(outputs) is not list:
                            outputs = [outputs]

                        # apply alpha and ROI
                        for i in range(num_view):
                            outputs[i] = outputs[i] * alpha_map * ROI
                            img_gt[i] = img_gt[i] * alpha_map * ROI

                        # tensorboard figure logs of validation data
                        if batch_id == val_log_batch_id:
                            output_final_vs_gt = []
                            for i in range(num_view):
                                output_final_vs_gt.append(outputs[i].clamp(
                                    min=0., max=1.))
                                output_final_vs_gt.append(img_gt[i].clamp(
                                    min=0., max=1.))
                                output_final_vs_gt.append(
                                    (outputs[i] - img_gt[i]).abs().clamp(
                                        min=0., max=1.))
                            output_final_vs_gt = torch.cat(output_final_vs_gt,
                                                           dim=0)
                            writer.add_image(
                                "output_final_vs_gt_val",
                                torchvision.utils.make_grid(
                                    output_final_vs_gt,
                                    nrow=outputs[0].shape[0],  # 3
                                    range=(0, 1),
                                    scale_each=False,
                                    normalize=False).cpu().detach().numpy(),
                                iter)

                        # error metrics
                        err_metrics_batch_i_final = metric.compute_err_metrics_batch(
                            outputs[0] * 255.0,
                            img_gt[0] * 255.0,
                            alpha_map,
                            compute_ssim=True)
                        batch_size = view_idx.shape[0]
                        for i in range(batch_size):
                            for key in list(err_metrics_val.keys()):
                                if key in err_metrics_batch_i_final.keys():
                                    err_metrics_val[key].append(
                                        err_metrics_batch_i_final[key][i])

                        # save images (output, error map, and once the GT)
                        for i in range(batch_size):
                            cv2.imwrite(
                                os.path.join(
                                    val_out_dir,
                                    str(iter).zfill(8) + '_' +
                                    str(view_idx[i].cpu().detach().numpy(
                                    )).zfill(5) + '.png'),
                                outputs[0][i, :].permute(
                                    (1, 2,
                                     0)).cpu().detach().numpy()[:, :, ::-1] *
                                255.)
                            cv2.imwrite(
                                os.path.join(
                                    val_err_dir,
                                    str(iter).zfill(8) + '_' +
                                    str(view_idx[i].cpu().detach().numpy(
                                    )).zfill(5) + '.png'),
                                (outputs[0] - img_gt[0]).abs().clamp(
                                    min=0., max=1.)[i, :].permute(
                                        (1, 2,
                                         0)).cpu().detach().numpy()[:, :, ::-1]
                                * 255.)
                            if first_val:
                                cv2.imwrite(
                                    os.path.join(
                                        val_gt_dir,
                                        str(view_idx[i].cpu().detach().numpy()
                                            ).zfill(5) + '.png'),
                                    img_gt[0][i, :].permute(
                                        (1, 2, 0)).cpu().detach().numpy()
                                    [:, :, ::-1] * 255.)

                        end_val_i = time.time()
                        print(
                            "Val batch %03d mae_valid %0.4f psnr_valid %0.4f ssim_valid %0.4f t_total %0.4f"
                            % (batch_id,
                               err_metrics_batch_i_final['mae_valid_mean'],
                               err_metrics_batch_i_final['psnr_valid_mean'],
                               err_metrics_batch_i_final['ssim_valid_mean'],
                               end_val_i - start_val_i))

                        batch_id += 1

                    for key in list(err_metrics_val.keys()):
                        if err_metrics_val[key]:
                            err_metrics_val[key] = np.vstack(
                                err_metrics_val[key])
                            err_metrics_val[
                                key + '_mean'] = err_metrics_val[key].mean()
                        else:
                            err_metrics_val[key + '_mean'] = np.nan

                    # tensorboard scalar logs of validation data
                    writer.add_scalar("final_mae_valid_val",
                                      err_metrics_val['mae_valid_mean'], iter)
                    writer.add_scalar("final_psnr_valid_val",
                                      err_metrics_val['psnr_valid_mean'], iter)
                    writer.add_scalar("final_ssim_valid_val",
                                      err_metrics_val['ssim_valid_mean'], iter)

                    first_val = False
                    val_log_batch_id = (val_log_batch_id + 1) % batch_id

                    end_val = time.time()
                    print(
                        "Val mae_valid %0.4f psnr_valid %0.4f ssim_valid %0.4f t_total %0.4f"
                        % (err_metrics_val['mae_valid_mean'],
                           err_metrics_val['psnr_valid_mean'],
                           err_metrics_val['ssim_valid_mean'],
                           end_val - start_val))

            iter += 1

            if iter % cfg.LOG.CHECKPOINT_FREQ == 0:
                util.custom_save(
                    os.path.join(
                        log_dir,
                        'model_epoch_%d_iter_%s_.pth' % (epoch, iter)),
                    part_list, part_name_list)

    # final checkpoint after the epoch loop completes
    util.custom_save(
        os.path.join(log_dir, 'model_epoch_%d_iter_%s_.pth' % (epoch, iter)),
        part_list, part_name_list)
def main():
    """HRNet pose-estimation training: config -> model -> data -> train/val loop."""
    args = parse_args()
    update_config(cfg, args)

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    # model class is resolved from the config string, e.g. models.<name>.get_pose_net
    model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(cfg,
                                                               is_train=True)

    # copy model file (snapshot the network source next to the run outputs)
    this_dir = os.path.dirname(__file__)
    shutil.copy2(
        os.path.join(this_dir, '../lib/models', cfg.MODEL.NAME + '.py'),
        final_output_dir)
    # logger.info(pprint.pformat(model))

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    # dummy input used only to log the model summary
    dump_input = torch.rand(
        (1, 3, cfg.MODEL.IMAGE_SIZE[1], cfg.MODEL.IMAGE_SIZE[0]))
    # writer_dict['writer'].add_graph(model, (dump_input))

    logger.info(get_model_summary(model, dump_input))

    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()

    # define loss function (criterion) and optimizer
    criterion = JointsMSELoss(
        use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT).cuda()

    # Data loading code (ImageNet normalization)
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TRAIN_SET, True,
        transforms.Compose([transforms.ToTensor(), normalize]))
    valid_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False,
        transforms.Compose([transforms.ToTensor(), normalize]))

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=cfg.TRAIN.SHUFFLE,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=cfg.TEST.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY)

    best_perf = 0.0
    best_model = False
    last_epoch = -1
    optimizer = get_optimizer(cfg, model)
    begin_epoch = cfg.TRAIN.BEGIN_EPOCH
    checkpoint_file = os.path.join(final_output_dir, 'checkpoint.pth')

    # optionally resume epoch counter, best metric, model and optimizer state
    if cfg.AUTO_RESUME and os.path.exists(checkpoint_file):
        logger.info("=> loading checkpoint '{}'".format(checkpoint_file))
        checkpoint = torch.load(checkpoint_file)
        begin_epoch = checkpoint['epoch']
        best_perf = checkpoint['perf']
        last_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])

        optimizer.load_state_dict(checkpoint['optimizer'])
        logger.info("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file, checkpoint['epoch']))

    # scheduler is created after the resume so last_epoch reflects the restored epoch
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, cfg.TRAIN.LR_STEP, cfg.TRAIN.LR_FACTOR,
        last_epoch=last_epoch)

    for epoch in range(begin_epoch, cfg.TRAIN.END_EPOCH):
        # train for one epoch
        train(cfg, train_loader, model, criterion, optimizer, epoch,
              final_output_dir, tb_log_dir, writer_dict)
        lr_scheduler.step()

        # evaluate on validation set
        perf_indicator = validate(cfg, valid_loader, valid_dataset, model,
                                  criterion, final_output_dir, tb_log_dir,
                                  writer_dict)

        # track the best-performing model seen so far
        if perf_indicator >= best_perf:
            best_perf = perf_indicator
            best_model = True
        else:
            best_model = False

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': cfg.MODEL.NAME,
                'state_dict': model.state_dict(),
                'best_state_dict': model.module.state_dict(),
                'perf': perf_indicator,
                'optimizer': optimizer.state_dict(),
            }, best_model, final_output_dir)

    final_model_state_file = os.path.join(final_output_dir, 'final_state.pth')
    logger.info(
        '=> saving final model state to {}'.format(final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()