def main():
    """Build the model, optionally restore a checkpoint, and generate test-set results.

    Reads the module-level ``cfg`` object; no arguments are taken.

    Raises:
        ValueError: if the test-set data loader cannot be built (the original
            exception is chained as ``__cause__``).
    """
    # Build model.
    model = model_builder.build_model(cfg=cfg)

    # Restore weights only when resuming. The checkpoint is mapped to CPU so it
    # loads regardless of the device it was saved from.
    if cfg.GENERAL.RESUME:
        ckpt = torch.load(cfg.MODEL.PATH2CKPT, map_location=torch.device("cpu"))
        with utils.log_info(msg="Load pre-trained model.", level="INFO", state=True):
            model.load_state_dict(ckpt["model"])

    # Set device.
    model, device = utils.set_device(model, cfg.GENERAL.GPU)

    # Guard only the loader construction: errors raised while generating are
    # genuine failures and must surface unchanged, not be reported as a
    # data-loader problem.
    try:
        test_data_loader = data_loader.build_data_loader(cfg, cfg.DATA.DATASET, "test")
    except Exception as err:
        utils.notify("Can not build data loader for test set.", level="ERROR")
        raise ValueError("Can not build data loader for test set.") from err

    generate(cfg=cfg, model=model, data_loader=test_data_loader, device=device)
def build_model(cfg, logger=None):
    """Instantiate ``Model`` from the configuration inside a logged section.

    Args:
        cfg: Configuration object passed unchanged to ``Model``.
        logger: Optional logger forwarded to ``utils.log_info``.

    Returns:
        The freshly constructed ``Model`` instance.
    """
    log_section = utils.log_info(
        msg="Build model from configurations.",
        level="INFO",
        state=True,
        logger=logger,
    )
    with log_section:
        built = Model(cfg)
    return built
def generate(
    cfg,
    model: torch.nn.Module,
    data_loader: torch.utils.data.DataLoader,
    device: torch.device,
    logger=None,
    *args,
    **kwargs,
):
    """Run ``model`` over ``data_loader`` and save every output as a PNG.

    Each output ``output[i]`` is written to
    ``<cfg.SAVE.DIR>/results/<data["img_idx"][i]>.png`` via ``utils.save_image``
    using ``cfg.DATA.MEAN`` / ``cfg.DATA.NORM``. Failed saves are reported with
    ``utils.notify`` and do not stop the run.

    Args:
        cfg: Configuration object (reads ``SAVE.DIR``, ``DATA.MEAN``, ``DATA.NORM``).
        model: Network to run; switched to eval mode here.
        data_loader: Yields batches as dicts containing at least ``"img_idx"``.
        device: Device passed to ``utils.inference``.
        logger: Optional logger forwarded to ``utils.log_info``.
    """
    model.eval()
    save_dir = os.path.join(cfg.SAVE.DIR, "results")
    with utils.log_info(msg="Generate results", level="INFO", state=True, logger=logger):
        # Create the output directory once; exist_ok avoids the exists()/makedirs() race.
        os.makedirs(save_dir, exist_ok=True)
        # No backward pass happens here, so skip building the autograd graph.
        # The tqdm context guarantees the bar is closed even if a batch fails.
        with tqdm(total=len(data_loader), dynamic_ncols=True) as pbar, torch.no_grad():
            for data in data_loader:
                output, *_ = utils.inference(model=model, data=data, device=device)
                for i in range(output.shape[0]):
                    path2file = os.path.join(save_dir, data["img_idx"][i] + ".png")
                    succeed = utils.save_image(
                        output[i].detach().cpu().numpy(),
                        cfg.DATA.MEAN,
                        cfg.DATA.NORM,
                        path2file,
                    )
                    if not succeed:
                        utils.notify("Cannot save image to {}".format(path2file))
                pbar.update()
def generate(
    cfg,
    model: torch.nn.Module,
    data_loader: torch.utils.data.DataLoader,
    device: torch.device,
    phase,
    logger=None,
    *args,
    **kwargs,
):
    """Run ``model`` over ``data_loader``, save outputs, and log the mean runtime per image.

    Each output ``output[i]`` is written to
    ``<cfg.SAVE.DIR>/<phase>/<data["img_idx"][i]>_g.png`` via ``utils.save_image``.
    Failed saves are logged and do not stop the run.

    Args:
        cfg: Configuration object (reads ``SAVE.DIR``, ``DATA.MEAN``, ``DATA.NORM``).
        model: Network to run; switched to eval mode here.
        data_loader: Yields batches as dicts containing at least ``"img_idx"``.
        device: Device passed to ``utils.inference``.
        phase: Sub-directory name under ``cfg.SAVE.DIR`` for the outputs.
        logger: Optional logger; ``print`` is used when it is ``None``.
    """
    model.eval()
    # Prepare to log info.
    log_info = print if logger is None else logger.log_info
    save_dir = os.path.join(cfg.SAVE.DIR, phase)
    total_time = 0.0
    num_images = 0
    # Read data and generate results.
    with utils.log_info(msg="Generate results", level="INFO", state=True, logger=logger):
        # Create the output directory once; exist_ok avoids the exists()/makedirs() race.
        os.makedirs(save_dir, exist_ok=True)
        # Forward-only pass: no autograd graph needed.
        with tqdm(total=len(data_loader), dynamic_ncols=True) as pbar, torch.no_grad():
            for data in data_loader:
                start_time = time.perf_counter()
                output = utils.inference(model=model, data=data, device=device)
                # CUDA kernels are asynchronous; wait for them so the timing
                # covers the computation, not just the kernel launches.
                if getattr(device, "type", None) == "cuda":
                    torch.cuda.synchronize(device)
                total_time += time.perf_counter() - start_time
                num_images += output.shape[0]

                for i in range(output.shape[0]):
                    path2file = os.path.join(save_dir, data["img_idx"][i] + "_g.png")
                    succeed = utils.save_image(
                        output[i].detach().cpu().numpy(),
                        cfg.DATA.MEAN,
                        cfg.DATA.NORM,
                        path2file,
                    )
                    if not succeed:
                        log_info("Cannot save image to {}".format(path2file))
                pbar.update()
        # Average over images rather than batches, as the message states.
        # Skip the line entirely for an empty loader instead of dividing by zero.
        if num_images:
            log_info("Runtime per image: {:<5} seconds.".format(round(total_time / num_images, 4)))
def train_one_epoch(
    epoch: int,
    cfg,
    model: torch.nn.Module,
    data_loader: torch.utils.data.DataLoader,
    device: torch.device,
    loss_fn,
    optimizer: torch.optim.Optimizer,
    lr_scheduler,
    metrics_logger,
    logger=None,
    *args,
    **kwargs,
):
    """Train ``model`` for one epoch and log the epoch's mean metrics.

    For every batch the loss is back-propagated and an optimizer step is taken.
    The loss and the metrics from ``utils.cal_and_record_metrics`` are recorded
    in ``metrics_logger`` under ``"train"`` / ``epoch``.

    Args:
        epoch: Current epoch number (used for logging and metric keys).
        cfg: Configuration object (not read directly here).
        model: Network to train; switched to train mode here.
        data_loader: Yields batches as dicts containing at least ``"target"``.
        device: Device passed to ``utils.inference_and_cal_loss``.
        loss_fn: Loss passed to ``utils.inference_and_cal_loss``.
        optimizer: Optimizer stepped once per batch.
        lr_scheduler: Scheduler stepped once per call of this function.
        metrics_logger: Collects per-batch values; must provide ``record`` and ``mean``.
        logger: Optional logger; ``print`` is used when it is ``None``.
    """
    model.train()
    log_info = print if logger is None else logger.log_info
    loss_sum = 0.0
    with utils.log_info(
        msg="TRAIN at epoch: {}, lr: {:<5}".format(
            str(epoch).zfill(3), optimizer.param_groups[0]["lr"]),
        level="INFO", state=True, logger=logger,
    ):
        with tqdm(total=len(data_loader), dynamic_ncols=True) as pbar:
            for num_batches, data in enumerate(data_loader, start=1):
                optimizer.zero_grad()
                out, loss = utils.inference_and_cal_loss(
                    model=model, data=data, loss_fn=loss_fn, device=device)
                loss.backward()
                optimizer.step()

                # One device->host sync per step. A running sum replaces
                # re-summing the whole loss history on every batch.
                cur_loss = loss.detach().cpu().item()
                loss_sum += cur_loss
                metrics_logger.record("train", epoch, "loss", cur_loss)
                utils.cal_and_record_metrics(
                    "train", epoch, out.detach().cpu(), data["target"],
                    metrics_logger, logger=logger)

                pbar.set_description(
                    "Epoch: {:<3}, avg loss: {:<5}, cur loss: {:<5}".format(
                        epoch, round(loss_sum / num_batches, 5), round(cur_loss, 5)))
                pbar.update()
        # NOTE(review): the collapsed original placed this call between
        # pbar.update() and pbar.close(); it is kept at epoch level here.
        # Confirm the scheduler is not meant to step per batch.
        lr_scheduler.step()
        mean_metrics = metrics_logger.mean("train", epoch)
        log_info("SSIM: {:<5}, PSNR: {:<5}, MAE: {:<5}, Loss: {:<5}".format(
            mean_metrics["SSIM"], mean_metrics["PSNR"], mean_metrics["MAE"], mean_metrics["loss"],
        ))
def evaluate(
    epoch: int,
    cfg,
    model: torch.nn.Module,
    data_loader: torch.utils.data.DataLoader,
    device: torch.device,
    loss_fn,
    metrics_logger,
    phase="valid",
    logger=None,
    save=False,
    *args,
    **kwargs,
):
    """Evaluate ``model`` on ``data_loader`` and log runtime and mean metrics.

    For every batch the loss and the metrics from ``utils.cal_and_record_metrics``
    are recorded in ``metrics_logger`` under ``phase`` / ``epoch``. When ``save``
    is true, each output is also written to
    ``<cfg.SAVE.DIR>/<phase>/<data["img_idx"][i]>_g.png``.

    Args:
        epoch: Current epoch number (used for logging and metric keys).
        cfg: Configuration object (reads ``SAVE.DIR``, ``DATA.MEAN``, ``DATA.NORM``).
        model: Network to evaluate; switched to eval mode here.
        data_loader: Yields batches as dicts containing ``"target"`` (and
            ``"img_idx"`` when ``save`` is true).
        device: Device passed to ``utils.inference_and_cal_loss``.
        loss_fn: Loss passed to ``utils.inference_and_cal_loss``.
        metrics_logger: Collects per-batch values; must provide ``record`` and ``mean``.
        phase: Metric key and output sub-directory name.
        logger: Optional logger; ``print`` is used when it is ``None``.
        save: Whether to write outputs to disk.
    """
    model.eval()
    # Prepare to log info.
    log_info = print if logger is None else logger.log_info
    save_dir = os.path.join(cfg.SAVE.DIR, phase)
    loss_sum = 0.0
    total_time = 0.0
    num_images = 0
    # Read data and evaluate and record info.
    with utils.log_info(msg="{} at epoch: {}".format(phase.upper(), str(epoch).zfill(3)), level="INFO", state=True, logger=logger):
        if save:
            # Create once; exist_ok avoids the exists()/makedirs() race.
            os.makedirs(save_dir, exist_ok=True)
        # Evaluation never back-propagates, so skip building the autograd graph.
        with tqdm(total=len(data_loader), dynamic_ncols=True) as pbar, torch.no_grad():
            for num_batches, data in enumerate(data_loader, start=1):
                start_time = time.perf_counter()
                out, loss = utils.inference_and_cal_loss(model=model, data=data, loss_fn=loss_fn, device=device)
                # CUDA kernels are asynchronous; wait for them so the timing is meaningful.
                if getattr(device, "type", None) == "cuda":
                    torch.cuda.synchronize(device)
                total_time += time.perf_counter() - start_time
                num_images += out.shape[0]

                # One device->host sync per batch, reused below.
                cur_loss = loss.detach().cpu().item()
                loss_sum += cur_loss
                if save:
                    # Save results to directory.
                    for i in range(out.shape[0]):
                        path2file = os.path.join(save_dir, data["img_idx"][i] + "_g.png")
                        succeed = utils.save_image(
                            out[i].detach().cpu().numpy(), cfg.DATA.MEAN, cfg.DATA.NORM, path2file)
                        if not succeed:
                            log_info("Cannot save image to {}".format(path2file))
                metrics_logger.record(phase, epoch, "loss", cur_loss)
                utils.cal_and_record_metrics(
                    phase, epoch, out.detach().cpu(), data["target"], metrics_logger, logger=logger)
                pbar.set_description("Epoch: {:<3}, avg loss: {:<5}, cur loss: {:<5}".format(
                    epoch, round(loss_sum / num_batches, 5), round(cur_loss, 5)))
                pbar.update()
        # Average over images (not batches), and skip when the loader was empty.
        if num_images:
            log_info("Runtime per image: {:<5} seconds.".format(round(total_time / num_images, 4)))
        mean_metrics = metrics_logger.mean(phase, epoch)
        log_info("SSIM: {:<5}, PSNR: {:<5}, MAE: {:<5}, Loss: {:<5}".format(
            mean_metrics["SSIM"],
            mean_metrics["PSNR"],
            mean_metrics["MAE"],
            mean_metrics["loss"],
        ))
r""" Author: Yiqun Chen Docs: Test modules not model. """ import os, sys sys.path.append(os.path.join(sys.path[0], "..")) sys.path.append(os.path.join(os.getcwd(), "src")) from utils import utils @utils.log_info_wrapper(msg="Another start info", logger=None) def test(): print("Hello World!") if __name__ == "__main__": with utils.log_info(msg="Start test", level="INFO", state=True, logger=None): test()