def train(cfg, writer, logger): # Setup seeds for reproducing torch.manual_seed(cfg.get('seed', 1337)) torch.cuda.manual_seed(cfg.get('seed', 1337)) np.random.seed(cfg.get('seed', 1337)) random.seed(cfg.get('seed', 1337)) # Setup device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Setup Augmentations augmentations = cfg['training'].get('augmentations', None) data_aug = get_composed_augmentations(augmentations) # Setup Dataloader data_loader = get_loader(cfg['data']['dataset'], cfg['task']) data_path = cfg['data']['path'] t_loader = data_loader( data_path, is_transform=True, split=cfg['data']['train_split'], img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']), img_norm=cfg['data']['img_norm'], # version = cfg['data']['version'], augmentations=data_aug) v_loader = data_loader( data_path, is_transform=True, split=cfg['data']['val_split'], img_norm=cfg['data']['img_norm'], # version=cfg['data']['version'], img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']), ) trainloader = data.DataLoader(t_loader, batch_size=cfg['training']['batch_size'], num_workers=cfg['training']['n_workers'], shuffle=True) valloader = data.DataLoader(v_loader, batch_size=cfg['training']['batch_size'], num_workers=cfg['training']['n_workers']) # Setup Metrics if cfg['task'] == "seg": n_classes = t_loader.n_classes running_metrics_val = runningScoreSeg(n_classes) elif cfg['task'] == "depth": n_classes = 0 running_metrics_val = runningScoreDepth() else: raise NotImplementedError('Task {} not implemented'.format( cfg['task'])) # Setup Model model = get_model(cfg['model'], cfg['task'], n_classes).to(device) # Setup optimizer, lr_scheduler and loss function optimizer_cls = get_optimizer(cfg) optimizer_params = { k: v for k, v in cfg['training']['optimizer'].items() if k != 'name' } optimizer = optimizer_cls(model.parameters(), **optimizer_params) logger.info("Using optimizer {}".format(optimizer)) scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule']) loss_fn = get_loss_function(cfg) logger.info("Using loss {}".format(loss_fn)) start_iter = 0 if cfg['training']['resume'] is not None: if os.path.isfile(cfg['training']['resume']): logger.info( "Loading model and optimizer from checkpoint '{}'".format( cfg['training']['resume'])) checkpoint = torch.load(cfg['training']['resume']) # checkpoint = torch.load(cfg['training']['resume'], map_location=lambda storage, loc: storage) # load model trained on gpu on cpu model.load_state_dict(checkpoint["model_state"]) optimizer.load_state_dict(checkpoint["optimizer_state"]) scheduler.load_state_dict(checkpoint["scheduler_state"]) # start_iter = checkpoint["epoch"] logger.info("Loaded checkpoint '{}' (iter {})".format( cfg['training']['resume'], checkpoint["epoch"])) else: logger.info("No checkpoint found at '{}'".format( cfg['training']['resume'])) val_loss_meter = averageMeter() time_meter = averageMeter() best_iou = -100.0 best_rel = 100.0 # i = start_iter i = 0 flag = True while i <= cfg['training']['train_iters'] and flag: print(len(trainloader)) for (images, labels, img_path) in trainloader: start_ts = time.time() # return current time stamp scheduler.step() model.train() # set model to training mode images = images.to(device) labels = labels.to(device) optimizer.zero_grad() #clear earlier gradients outputs = model(images) if cfg['model']['arch'] == "dispnet" and cfg['task'] == "depth": outputs = 1 / outputs loss = loss_fn(input=outputs, target=labels) # compute loss loss.backward() # backpropagation loss optimizer.step() # 
optimizer parameter update time_meter.update(time.time() - start_ts) if (i + 1) % cfg['training']['print_interval'] == 0: fmt_str = "Iter [{:d}/{:d}] Loss: {:.4f} Time/Image: {:.4f}" print_str = fmt_str.format( i + 1, cfg['training']['train_iters'], loss.item(), time_meter.val / cfg['training']['batch_size']) print(print_str) logger.info(print_str) writer.add_scalar('loss/train_loss', loss.item(), i + 1) time_meter.reset() if (i + 1) % cfg['training']['val_interval'] == 0 or ( i + 1) == cfg['training']['train_iters']: model.eval() with torch.no_grad(): for i_val, (images_val, labels_val, img_path_val) in tqdm(enumerate(valloader)): images_val = images_val.to(device) labels_val = labels_val.to(device) outputs = model( images_val ) # [batch_size, n_classes, height, width] if cfg['model']['arch'] == "dispnet" and cfg[ 'task'] == "depth": outputs = 1 / outputs val_loss = loss_fn(input=outputs, target=labels_val ) # mean pixelwise loss in a batch if cfg['task'] == "seg": pred = outputs.data.max(1)[1].cpu().numpy( ) # [batch_size, height, width] gt = labels_val.data.cpu().numpy( ) # [batch_size, height, width] elif cfg['task'] == "depth": pred = outputs.squeeze(1).data.cpu().numpy() gt = labels_val.data.squeeze(1).cpu().numpy() else: raise NotImplementedError( 'Task {} not implemented'.format(cfg['task'])) running_metrics_val.update(gt=gt, pred=pred) val_loss_meter.update(val_loss.item()) writer.add_scalar('loss/val_loss', val_loss_meter.avg, i + 1) logger.info("Iter %d val_loss: %.4f" % (i + 1, val_loss_meter.avg)) print("Iter %d val_loss: %.4f" % (i + 1, val_loss_meter.avg)) # output scores if cfg['task'] == "seg": score, class_iou = running_metrics_val.get_scores() for k, v in score.items(): print(k, v) sys.stdout.flush() logger.info('{}: {}'.format(k, v)) writer.add_scalar('val_metrics/{}'.format(k), v, i + 1) for k, v in class_iou.items(): logger.info('{}: {}'.format(k, v)) writer.add_scalar('val_metrics/cls_{}'.format(k), v, i + 1) elif cfg['task'] == "depth": val_result = running_metrics_val.get_scores() for k, v in val_result.items(): print(k, v) logger.info('{}: {}'.format(k, v)) writer.add_scalar('val_metrics/{}'.format(k), v, i + 1) else: raise NotImplementedError('Task {} not implemented'.format( cfg['task'])) val_loss_meter.reset() running_metrics_val.reset() save_model = False if cfg['task'] == "seg": if score["Mean IoU : \t"] >= best_iou: best_iou = score["Mean IoU : \t"] save_model = True state = { "epoch": i + 1, "model_state": model.state_dict(), "optimizer_state": optimizer.state_dict(), "scheduler_state": scheduler.state_dict(), "best_iou": best_iou, } if cfg['task'] == "depth": if val_result["abs rel : \t"] <= best_rel: best_rel = val_result["abs rel : \t"] save_model = True state = { "epoch": i + 1, "model_state": model.state_dict(), "optimizer_state": optimizer.state_dict(), "scheduler_state": scheduler.state_dict(), "best_rel": best_rel, } if save_model: save_path = os.path.join( writer.file_writer.get_logdir(), "{}_{}_best_model.pkl".format(cfg['model']['arch'], cfg['data']['dataset'])) torch.save(state, save_path) if (i + 1) == cfg['training']['train_iters']: flag = False break i += 1
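
# Both loops above rely on an `averageMeter` utility for loss and timing
# statistics. A minimal sketch of such a running-average meter is shown below,
# assuming the usual val/avg/sum/count interface; the repo's own class may
# track additional fields.
class RunningAverageMeter:
    """Tracks the latest value, running sum, count and average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count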
def train(cfg, writer, logger, run_id): # Setup seeds torch.manual_seed(cfg.get('seed', 1337)) torch.cuda.manual_seed(cfg.get('seed', 1337)) np.random.seed(cfg.get('seed', 1337)) random.seed(cfg.get('seed', 1337)) # torch.backends.cudnn.deterministic = True # torch.backends.cudnn.benchmark = False torch.backends.cudnn.benchmark = True os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Setup device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Setup Augmentations augmentations = cfg['training'].get('augmentations', None) data_aug = get_composed_augmentations(augmentations) # Setup Dataloader data_loader = get_loader(cfg['data']['dataset']) data_path = cfg['data']['path'] logger.info("Using dataset: {}".format(data_path)) t_loader = data_loader(data_path, is_transform=True, split=cfg['data']['train_split'], img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']), augmentations=data_aug) v_loader = data_loader( data_path, is_transform=True, split=cfg['data']['val_split'], img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']), ) n_classes = t_loader.n_classes trainloader = data.DataLoader(t_loader, batch_size=cfg['training']['batch_size'], num_workers=cfg['training']['n_workers'], shuffle=True) valloader = data.DataLoader(v_loader, batch_size=cfg['training']['batch_size'], num_workers=cfg['training']['n_workers']) # Setup Metrics running_metrics_val = runningScore(n_classes) # Setup Model # model = get_model(cfg['model'], n_classes).to(device) model = get_model(cfg['model'], n_classes) logger.info("Using Model: {}".format(cfg['model']['arch'])) # model=apex.parallel.convert_syncbn_model(model) model = model.to(device) # a=range(torch.cuda.device_count()) # model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count())) # model = torch.nn.DataParallel(model, device_ids=[0,1]) # model = encoding.parallel.DataParallelModel(model, device_ids=[0, 1]) # Setup optimizer, lr_scheduler and loss function optimizer_cls = get_optimizer(cfg) optimizer_params = { k: v for k, v in cfg['training']['optimizer'].items() if k != 'name' } optimizer = optimizer_cls(model.parameters(), **optimizer_params) # optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0) logger.info("Using optimizer {}".format(optimizer)) scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule']) # optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0) loss_fn = get_loss_function(cfg) # loss_fn== encoding.parallel.DataParallelCriterion(loss_fn, device_ids=[0, 1]) logger.info("Using loss {}".format(loss_fn)) start_iter = 0 if cfg['training']['resume'] is not None: if os.path.isfile(cfg['training']['resume']): logger.info( "Loading model and optimizer from checkpoint '{}'".format( cfg['training']['resume'])) checkpoint = torch.load(cfg['training']['resume']) model.load_state_dict(checkpoint["model_state"]) optimizer.load_state_dict(checkpoint["optimizer_state"]) scheduler.load_state_dict(checkpoint["scheduler_state"]) start_iter = checkpoint["epoch"] logger.info("Loaded checkpoint '{}' (iter {})".format( cfg['training']['resume'], checkpoint["epoch"])) else: logger.info("No checkpoint found at '{}'".format( cfg['training']['resume'])) val_loss_meter = averageMeter() time_meter = averageMeter() time_meter_val = averageMeter() best_iou = -100.0 i = start_iter flag = True train_data_len = t_loader.__len__() batch_size = cfg['training']['batch_size'] epoch = cfg['training']['train_epoch'] train_iter = int(np.ceil(train_data_len / batch_size) * epoch) val_rlt_f1 = [] 
val_rlt_IoU = [] best_f1_till_now = 0 best_IoU_till_now = 0 while i <= train_iter and flag: for (images, labels) in trainloader: i += 1 start_ts = time.time() scheduler.step() model.train() images = images.to(device) labels = labels.to(device) optimizer.zero_grad() outputs = model(images) loss = loss_fn(input=outputs, target=labels) loss.backward() # optimizer.backward(loss) optimizer.step() time_meter.update(time.time() - start_ts) ### add by Sprit time_meter_val.update(time.time() - start_ts) if (i + 1) % cfg['training']['print_interval'] == 0: fmt_str = "Iter [{:d}/{:d}] Loss: {:.4f} Time/Image: {:.4f}" print_str = fmt_str.format( i + 1, train_iter, loss.item(), time_meter.avg / cfg['training']['batch_size']) print(print_str) logger.info(print_str) writer.add_scalar('loss/train_loss', loss.item(), i + 1) time_meter.reset() if (i + 1) % cfg['training']['val_interval'] == 0 or \ (i + 1) == train_iter: model.eval() with torch.no_grad(): for i_val, (images_val, labels_val) in tqdm(enumerate(valloader)): images_val = images_val.to(device) labels_val = labels_val.to(device) outputs = model(images_val) # val_loss = loss_fn(input=outputs, target=labels_val) pred = outputs.data.max(1)[1].cpu().numpy() gt = labels_val.data.cpu().numpy() running_metrics_val.update(gt, pred) # val_loss_meter.update(val_loss.item()) # writer.add_scalar('loss/val_loss', val_loss_meter.avg, i+1) # logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg)) score, class_iou = running_metrics_val.get_scores() for k, v in score.items(): print(k, v) logger.info('{}: {}'.format(k, v)) # writer.add_scalar('val_metrics/{}'.format(k), v, i+1) for k, v in class_iou.items(): logger.info('{}: {}'.format(k, v)) # writer.add_scalar('val_metrics/cls_{}'.format(k), v, i+1) # val_loss_meter.reset() running_metrics_val.reset() ### add by Sprit avg_f1 = score["Mean F1 : \t"] avg_IoU = score["Mean IoU : \t"] val_rlt_f1.append(avg_f1) val_rlt_IoU.append(score["Mean IoU : \t"]) if avg_f1 >= best_f1_till_now: best_f1_till_now = avg_f1 correspond_iou = score["Mean IoU : \t"] best_epoch_till_now = i + 1 print("\nBest F1 till now = ", best_f1_till_now) print("Correspond IoU= ", correspond_iou) print("Best F1 Iter till now= ", best_epoch_till_now) if avg_IoU >= best_IoU_till_now: best_IoU_till_now = avg_IoU correspond_f1 = score["Mean F1 : \t"] correspond_acc = score["Overall Acc: \t"] best_epoch_till_now = i + 1 print("Best IoU till now = ", best_IoU_till_now) print("Correspond F1= ", correspond_f1) print("Correspond OA= ", correspond_acc) print("Best IoU Iter till now= ", best_epoch_till_now) ### add by Sprit iter_time = time_meter_val.avg time_meter_val.reset() remain_time = iter_time * (train_iter - i) m, s = divmod(remain_time, 60) h, m = divmod(m, 60) if s != 0: train_time = "Remain training time = %d hours %d minutes %d seconds \n" % ( h, m, s) else: train_time = "Remain training time : Training completed.\n" print(train_time) if score["Mean IoU : \t"] >= best_iou: best_iou = score["Mean IoU : \t"] state = { "epoch": i + 1, "model_state": model.state_dict(), "optimizer_state": optimizer.state_dict(), "scheduler_state": scheduler.state_dict(), "best_iou": best_iou, } save_path = os.path.join( writer.file_writer.get_logdir(), "{}_{}_best_model.pkl".format(cfg['model']['arch'], cfg['data']['dataset'])) torch.save(state, save_path) if (i + 1) == train_iter: flag = False break my_pt.csv_out(run_id, data_path, cfg['model']['arch'], epoch, val_rlt_f1, cfg['training']['val_interval']) my_pt.csv_out(run_id, data_path, cfg['model']['arch'], 
epoch, val_rlt_IoU, cfg['training']['val_interval'])
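
# The loop above derives its iteration budget from the dataset size, batch
# size and epoch count, and estimates the remaining wall-clock time from the
# measured time per iteration. A minimal sketch of that bookkeeping follows;
# the helper names are hypothetical (the script does this inline).
import math


def epochs_to_iters(num_samples, batch_size, num_epochs):
    """Number of optimizer steps needed for `num_epochs` passes over the data."""
    return int(math.ceil(num_samples / batch_size) * num_epochs)


def format_remaining_time(seconds_per_iter, iters_left):
    """Human-readable estimate of the remaining training time."""
    remain = seconds_per_iter * iters_left
    m, s = divmod(remain, 60)
    h, m = divmod(m, 60)
    return "Remaining training time = %d hours %d minutes %d seconds" % (h, m, s)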
def train(cfg, writer, logger): # Setup seeds torch.manual_seed(cfg.get('seed', 1337)) torch.cuda.manual_seed(cfg.get('seed', 1337)) np.random.seed(cfg.get('seed', 1337)) random.seed(cfg.get('seed', 1337)) # Setup device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Setup Augmentations augmentations = cfg['training'].get('augmentations', None) data_aug = get_composed_augmentations(augmentations) # Setup Dataloader # data_loader = get_loader(cfg['data']['dataset']) # data_path = cfg['data']['path'] # # t_loader = data_loader( # data_path, # is_transform=True, # split=cfg['data']['train_split'], # img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']), # augmentations=data_aug) # # v_loader = data_loader( # data_path, # is_transform=True, # split=cfg['data']['val_split'], # img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),) # # n_classes = t_loader.n_classes # trainloader = data.DataLoader(t_loader, # batch_size=cfg['training']['batch_size'], # num_workers=cfg['training']['n_workers'], # shuffle=True) # # valloader = data.DataLoader(v_loader, # batch_size=cfg['training']['batch_size'], # num_workers=cfg['training']['n_workers']) paths = { 'masks': './satellitedata/patchohio_train/gt/', 'images': './satellitedata/patchohio_train/rgb', 'nirs': './satellitedata/patchohio_train/nir', 'swirs': './satellitedata/patchohio_train/swir', 'vhs': './satellitedata/patchohio_train/vh', 'vvs': './satellitedata/patchohio_train/vv', 'redes': './satellitedata/patchohio_train/rede', 'ndvis': './satellitedata/patchohio_train/ndvi', } valpaths = { 'masks': './satellitedata/patchohio_val/gt/', 'images': './satellitedata/patchohio_val/rgb', 'nirs': './satellitedata/patchohio_val/nir', 'swirs': './satellitedata/patchohio_val/swir', 'vhs': './satellitedata/patchohio_val/vh', 'vvs': './satellitedata/patchohio_val/vv', 'redes': './satellitedata/patchohio_val/rede', 'ndvis': './satellitedata/patchohio_val/ndvi', } n_classes = 3 train_img_paths = [pth for pth in os.listdir(paths['images']) if ('_01_' not in pth) and ('_25_' not in pth)] val_img_paths = [pth for pth in os.listdir(valpaths['images']) if ('_01_' not in pth) and ('_25_' not in pth)] ntrain = len(train_img_paths) nval = len(val_img_paths) train_idx = [i for i in range(ntrain)] val_idx = [i for i in range(nval)] trainds = ImageProvider(MultibandImageType, paths, image_suffix='.png') valds = ImageProvider(MultibandImageType, valpaths, image_suffix='.png') config_path = 'crop_pspnet_config.json' with open(config_path, 'r') as f: mycfg = json.load(f) train_data_path = './satellitedata/' print('train_data_path: {}'.format(train_data_path)) dataset_path, train_dir = os.path.split(train_data_path) print('dataset_path: {}'.format(dataset_path) + ', train_dir: {}'.format(train_dir)) mycfg['dataset_path'] = dataset_path config = Config(**mycfg) config = update_config(config, num_channels=12, nb_epoch=50) #dataset_train = TrainDataset(trainds, train_idx, config, transforms=augment_flips_color) dataset_train = TrainDataset(trainds, train_idx, config, 1) dataset_val = TrainDataset(valds, val_idx, config, 1) trainloader = data.DataLoader(dataset_train, batch_size=cfg['training']['batch_size'], num_workers=cfg['training']['n_workers'], shuffle=True) valloader = data.DataLoader(dataset_val, batch_size=cfg['training']['batch_size'], num_workers=cfg['training']['n_workers'], shuffle=False) # Setup Metrics running_metrics_train = runningScore(n_classes) running_metrics_val = runningScore(n_classes) k = 0 nbackground = 0 ncorn = 0 
#ncotton = 0 #nrice = 0 nsoybean = 0 for indata in trainloader: k += 1 gt = indata['seg_label'].data.cpu().numpy() nbackground += (gt == 0).sum() ncorn += (gt == 1).sum() #ncotton += (gt == 2).sum() #nrice += (gt == 3).sum() nsoybean += (gt == 2).sum() print('k = {}'.format(k)) print('nbackgraound: {}'.format(nbackground)) print('ncorn: {}'.format(ncorn)) #print('ncotton: {}'.format(ncotton)) #print('nrice: {}'.format(nrice)) print('nsoybean: {}'.format(nsoybean)) wgts = [1.0, 1.0*nbackground/ncorn, 1.0*nbackground/nsoybean] total_wgts = sum(wgts) wgt_background = wgts[0]/total_wgts wgt_corn = wgts[1]/total_wgts #wgt_cotton = wgts[2]/total_wgts #wgt_rice = wgts[3]/total_wgts wgt_soybean = wgts[2]/total_wgts weights = torch.autograd.Variable(torch.cuda.FloatTensor([wgt_background, wgt_corn, wgt_soybean])) #weights = torch.autograd.Variable(torch.cuda.FloatTensor([1.0, 1.0, 1.0])) # Setup Model model = get_model(cfg['model'], n_classes).to(device) model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count())) # Setup optimizer, lr_scheduler and loss function optimizer_cls = get_optimizer(cfg) optimizer_params = {k:v for k, v in cfg['training']['optimizer'].items() if k != 'name'} optimizer = optimizer_cls(model.parameters(), **optimizer_params) logger.info("Using optimizer {}".format(optimizer)) scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule']) loss_fn = get_loss_function(cfg) logger.info("Using loss {}".format(loss_fn)) start_iter = 0 if cfg['training']['resume'] is not None: if os.path.isfile(cfg['training']['resume']): logger.info( "Loading model and optimizer from checkpoint '{}'".format(cfg['training']['resume']) ) checkpoint = torch.load(cfg['training']['resume']) model.load_state_dict(checkpoint["model_state"]) optimizer.load_state_dict(checkpoint["optimizer_state"]) scheduler.load_state_dict(checkpoint["scheduler_state"]) start_iter = checkpoint["epoch"] logger.info( "Loaded checkpoint '{}' (iter {})".format( cfg['training']['resume'], checkpoint["epoch"] ) ) else: logger.info("No checkpoint found at '{}'".format(cfg['training']['resume'])) val_loss_meter = averageMeter() time_meter = averageMeter() best_iou = -100.0 i = start_iter flag = True while i <= cfg['training']['train_iters'] and flag: for inputdata in trainloader: i += 1 start_ts = time.time() scheduler.step() model.train() images = inputdata['img_data'] labels = inputdata['seg_label'] #print('images.size: {}'.format(images.size())) #print('labels.size: {}'.format(labels.size())) images = images.to(device) labels = labels.to(device) optimizer.zero_grad() outputs = model(images) #print('outputs.size: {}'.format(outputs[1].size())) #print('labels.size: {}'.format(labels.size())) loss = loss_fn(input=outputs[1], target=labels, weight=weights) loss.backward() optimizer.step() time_meter.update(time.time() - start_ts) if (i + 1) % cfg['training']['print_interval'] == 0: fmt_str = "Iter [{:d}/{:d}] Loss: {:.4f} Time/Image: {:.4f}" print_str = fmt_str.format(i + 1, cfg['training']['train_iters'], loss.item(), time_meter.avg / cfg['training']['batch_size']) print(print_str) logger.info(print_str) writer.add_scalar('loss/train_loss', loss.item(), i+1) time_meter.reset() if (i + 1) % cfg['training']['val_interval'] == 0 or \ (i + 1) == cfg['training']['train_iters']: model.eval() with torch.no_grad(): for inputdata in valloader: images_val = inputdata['img_data'] labels_val = inputdata['seg_label'] images_val = images_val.to(device) labels_val = labels_val.to(device) outputs = model(images_val) 
val_loss = loss_fn(input=outputs, target=labels_val) pred = outputs.data.max(1)[1].cpu().numpy() gt = labels_val.data.cpu().numpy() running_metrics_val.update(gt, pred) val_loss_meter.update(val_loss.item()) writer.add_scalar('loss/val_loss', val_loss_meter.avg, i+1) logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg)) score, class_iou = running_metrics_val.get_scores() for k, v in score.items(): print(k, v) logger.info('{}: {}'.format(k, v)) writer.add_scalar('val_metrics/{}'.format(k), v, i+1) for k, v in class_iou.items(): logger.info('{}: {}'.format(k, v)) writer.add_scalar('val_metrics/cls_{}'.format(k), v, i+1) val_loss_meter.reset() running_metrics_val.reset() if score["Mean IoU : \t"] >= best_iou: best_iou = score["Mean IoU : \t"] state = { "epoch": i + 1, "model_state": model.state_dict(), "optimizer_state": optimizer.state_dict(), "scheduler_state": scheduler.state_dict(), "best_iou": best_iou, } save_path = os.path.join(writer.file_writer.get_logdir(), "{}_{}_best_model.pkl".format( cfg['model']['arch'], cfg['data']['dataset'])) torch.save(state, save_path) if (i + 1) == cfg['training']['train_iters']: flag = False break
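
# The crop-segmentation loop above weights the cross-entropy loss by inverse
# class frequency: pixel counts per class are accumulated over the training
# set and turned into normalized weights. A minimal sketch of that computation,
# with a hypothetical helper name (the script does it inline for the
# background/corn/soybean classes and moves the result to CUDA):
import numpy as np
import torch


def inverse_frequency_weights(class_pixel_counts):
    """Weights proportional to count[0] / count[c], normalized to sum to 1."""
    counts = np.asarray(class_pixel_counts, dtype=np.float64)
    wgts = counts[0] / counts  # reference class 0 (background) gets weight 1
    wgts = wgts / wgts.sum()
    return torch.tensor(wgts, dtype=torch.float32)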
def train(cfg, writer, logger, args): # cfg # Setup seeds torch.manual_seed(cfg.get("seed", 1337)) torch.cuda.manual_seed(cfg.get("seed", 1337)) np.random.seed(cfg.get("seed", 1337)) random.seed(cfg.get("seed", 1337)) # Setup device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Setup Augmentations augmentations = cfg["training"].get("augmentations", None) data_aug = get_composed_augmentations(augmentations) # Setup Dataloader data_loader = get_loader(cfg["data"]["dataset"]) data_path = cfg["data"]["path"] t_loader = data_loader( data_path, is_transform=True, split=cfg["data"]["train_split"], img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]), augmentations=data_aug, ) v_loader = data_loader( data_path, is_transform=True, split=cfg["data"]["val_split"], img_size=(1024, 2048), ) n_classes = t_loader.n_classes trainloader = data.DataLoader( t_loader, batch_size=cfg["training"]["batch_size"], num_workers=cfg["training"]["n_workers"], shuffle=True, ) valloader = data.DataLoader(v_loader, batch_size=cfg["training"]["batch_size"], num_workers=cfg["training"]["n_workers"]) # Setup Metrics running_metrics_val = runningScore(n_classes) # Setup Model model = FASSDNet(n_classes=19, alpha=args.alpha).to(device) total_params = sum(p.numel() for p in model.parameters()) print('Parameters:', total_params) model.apply(weights_init) # Non-strict ImageNet pre-train pretrained_path = 'weights/imagenet_weights.pth' checkpoint = torch.load(pretrained_path) q = 1 model_dict = {} state_dict = model.state_dict() # print('================== Weights orig: ', model.base[1].conv.weight[0][0][0]) for k, v in checkpoint.items(): if q == 1: # print("===> Key of checkpoint: ", k) # print("===> Value of checkpoint: ", v[0][0][0]) if ('base.' + k in state_dict): # print("============> CONTAINS KEY...") # print("===> Value of the key: ", state_dict['base.'+k][0][0][0]) pass else: # print("============> DOES NOT CONTAIN KEY...") pass q = 0 if ('base.' + k in state_dict) and (state_dict['base.' + k].shape == checkpoint[k].shape): model_dict['base.' 
+ k] = v state_dict.update(model_dict) # Updated weights with ImageNet pretraining model.load_state_dict(state_dict) # print('================== Weights loaded: ', model.base[0].conv.weight[0][0][0]) # Multi-gpu model model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count())) # Setup optimizer, lr_scheduler and loss function optimizer_cls = get_optimizer(cfg) optimizer_params = { k: v for k, v in cfg["training"]["optimizer"].items() if k != "name" } optimizer = optimizer_cls(model.parameters(), **optimizer_params) print("Using optimizer {}".format(optimizer)) scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"]) loss_fn = get_loss_function(cfg) print("Using loss {}".format(loss_fn)) start_iter = 0 if cfg["training"]["resume"] is not None: if os.path.isfile(cfg["training"]["resume"]): print_str = "Finetuning model from '{}'".format( cfg["training"]["finetune"]) if logger is not None: logger.info(print_str) print(print_str) checkpoint = torch.load(cfg["training"]["resume"]) model.load_state_dict(checkpoint["model_state"]) optimizer.load_state_dict(checkpoint["optimizer_state"]) scheduler.load_state_dict(checkpoint["scheduler_state"]) start_iter = checkpoint["epoch"] print_str = "Loaded checkpoint '{}' (iter {})".format( cfg["training"]["resume"], checkpoint["epoch"]) print(print_str) if logger is not None: logger.info(print_str) else: print_str = "No checkpoint found at '{}'".format( cfg["training"]["resume"]) print(print_str) if logger is not None: logger.info(print_str) if cfg["training"]["finetune"] is not None: if os.path.isfile(cfg["training"]["finetune"]): logger.info( "Loading model and optimizer from checkpoint '{}'".format( cfg["training"]["finetune"])) checkpoint = torch.load(cfg["training"]["finetune"]) model.load_state_dict(checkpoint["model_state"]) val_loss_meter = averageMeter() time_meter = averageMeter() best_iou = -100.0 i = start_iter flag = True loss_all = 0 loss_n = 0 sys.stdout.flush() while i <= cfg["training"]["train_iters"] and flag: for (images, labels, _) in trainloader: i += 1 start_ts = time.time() scheduler.step() model.train() images = images.to(device) labels = labels.to(device) optimizer.zero_grad() outputs = model(images) loss = loss_fn(input=outputs, target=labels) loss.backward() optimizer.step() c_lr = scheduler.get_lr() time_meter.update(time.time() - start_ts) loss_all += loss.item() loss_n += 1 if (i + 1) % cfg["training"]["print_interval"] == 0: fmt_str = "Iter [{:d}/{:d}] Loss: {:.4f} Time/Image: {:.4f} lr={:.6f}" print_str = fmt_str.format( i + 1, cfg["training"]["train_iters"], loss_all / loss_n, time_meter.avg / cfg["training"]["batch_size"], c_lr[0], ) print(print_str) if logger is not None: logger.info(print_str) writer.add_scalar("loss/train_loss", loss.item(), i + 1) time_meter.reset() if (i + 1) % cfg["training"]["val_interval"] == 0 or ( i + 1) == cfg["training"]["train_iters"]: torch.cuda.empty_cache() model.eval() loss_all = 0 loss_n = 0 with torch.no_grad(): # for i_val, (images_val, labels_val, _) in tqdm(enumerate(valloader)): for i_val, (images_val, labels_val, _) in enumerate(valloader): images_val = images_val.to(device) labels_val = labels_val.to(device) outputs = model(images_val) val_loss = loss_fn(input=outputs, target=labels_val) pred = outputs.data.max(1)[1].cpu().numpy() gt = labels_val.data.cpu().numpy() running_metrics_val.update(gt, pred) val_loss_meter.update(val_loss.item()) writer.add_scalar("loss/val_loss", val_loss_meter.avg, i + 1) print_str = "Iter %d Val Loss: %.4f" % (i + 1, 
val_loss_meter.avg) if logger is not None: logger.info(print_str) print(print_str) score, class_iou = running_metrics_val.get_scores() for k, v in score.items(): print_str = "{}: {}".format(k, v) if logger is not None: logger.info(print_str) print(print_str) writer.add_scalar("val_metrics/{}".format(k), v, i + 1) for k, v in class_iou.items(): print_str = "{}: {}".format(k, v) if logger is not None: logger.info(print_str) print(print_str) writer.add_scalar("val_metrics/cls_{}".format(k), v, i + 1) val_loss_meter.reset() running_metrics_val.reset() state = { "epoch": i + 1, "model_state": model.state_dict(), "optimizer_state": optimizer.state_dict(), "scheduler_state": scheduler.state_dict(), } save_path = os.path.join( writer.file_writer.get_logdir(), "{}_{}_checkpoint.pkl".format(cfg["model"]["arch"], cfg["data"]["dataset"]), ) torch.save(state, save_path) if score["Mean IoU : \t"] >= best_iou: # Save best model (mIoU) best_iou = score["Mean IoU : \t"] state = { "epoch": i + 1, "model_state": model.state_dict(), "best_iou": best_iou, } save_path = os.path.join( writer.file_writer.get_logdir(), "{}_{}_best_model.pkl".format(cfg["model"]["arch"], cfg["data"]["dataset"]), ) torch.save(state, save_path) torch.cuda.empty_cache() if (i + 1) == cfg["training"]["train_iters"]: flag = False break sys.stdout.flush() # Added
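
# The FASSDNet script above initializes its backbone from ImageNet weights
# non-strictly: a checkpoint entry is copied only if the 'base.'-prefixed key
# exists in the model and the tensor shapes agree. A minimal sketch of that
# filtering, with a hypothetical helper name (the script does it inline):
def load_partial_state_dict(model, checkpoint, prefix="base."):
    """Copy matching checkpoint tensors into `model`, skipping mismatches."""
    state_dict = model.state_dict()
    matched = {
        prefix + k: v
        for k, v in checkpoint.items()
        if prefix + k in state_dict and state_dict[prefix + k].shape == v.shape
    }
    state_dict.update(matched)
    model.load_state_dict(state_dict)
    return len(matched)  # number of transferred tensors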
def train(cfg, logger, logdir):
    # Setup seeds
    init_seed(11733, en_cudnn=False)

    # Setup Augmentations
    train_augmentations = cfg["training"].get("train_augmentations", None)
    t_data_aug = get_composed_augmentations(train_augmentations)
    val_augmentations = cfg["validating"].get("val_augmentations", None)
    v_data_aug = get_composed_augmentations(val_augmentations)

    # Setup Dataloader
    path_n = cfg["model"]["path_num"]
    data_loader = get_loader(cfg["data"]["dataset"])
    data_path = cfg["data"]["path"]

    t_loader = data_loader(data_path, split=cfg["data"]["train_split"],
                           augmentations=t_data_aug, path_num=path_n)
    v_loader = data_loader(data_path, split=cfg["data"]["val_split"],
                           augmentations=v_data_aug, path_num=path_n)

    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg["training"]["batch_size"],
                                  num_workers=cfg["training"]["n_workers"],
                                  shuffle=True,
                                  drop_last=True)
    valloader = data.DataLoader(v_loader,
                                batch_size=cfg["validating"]["batch_size"],
                                num_workers=cfg["validating"]["n_workers"])
    logger.info("Using training setting {}".format(cfg["training"]))

    # Setup Metrics
    running_metrics_val = runningScore(t_loader.n_classes)

    # Setup Model and Loss
    loss_fn = get_loss_function(cfg["training"])
    teacher = get_model(cfg["teacher"], t_loader.n_classes)
    model = get_model(cfg["model"], t_loader.n_classes, loss_fn,
                      cfg["training"]["resume"], teacher)
    logger.info("Using loss {}".format(loss_fn))

    # Setup optimizer
    optimizer = get_optimizer(cfg["training"], model)

    # Setup Multi-GPU
    model = DataParallelModel(model).cuda()

    # Initialize training params
    cnt_iter = 0
    best_iou = 0.0
    time_meter = averageMeter()

    while cnt_iter <= cfg["training"]["train_iters"]:
        for (f_img, labels) in trainloader:
            cnt_iter += 1
            model.train()
            optimizer.zero_grad()

            start_ts = time.time()
            outputs = model(f_img, labels, pos_id=cnt_iter % path_n)
            seg_loss = gather(outputs, 0)
            seg_loss = torch.mean(seg_loss)
            seg_loss.backward()
            time_meter.update(time.time() - start_ts)
            optimizer.step()

            if (cnt_iter + 1) % cfg["training"]["print_interval"] == 0:
                fmt_str = "Iter [{:d}/{:d}] Loss: {:.4f} Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    cnt_iter + 1,
                    cfg["training"]["train_iters"],
                    seg_loss.item(),
                    time_meter.avg / cfg["training"]["batch_size"],
                )
                print(print_str)
                logger.info(print_str)
                time_meter.reset()

            if (cnt_iter + 1) % cfg["training"]["val_interval"] == 0 or \
                    (cnt_iter + 1) == cfg["training"]["train_iters"]:
                model.eval()
                with torch.no_grad():
                    for i_val, (f_img_val, labels_val) in tqdm(enumerate(valloader)):
                        outputs = model(f_img_val, pos_id=i_val % path_n)
                        outputs = gather(outputs, 0, dim=0)
                        pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()
                        running_metrics_val.update(gt, pred)

                score, class_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print(k, v)
                    logger.info("{}: {}".format(k, v))
                for k, v in class_iou.items():
                    logger.info("{}: {}".format(k, v))
                running_metrics_val.reset()

                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": cnt_iter + 1,
                        "model_state": clean_state_dict(model.module.state_dict(), 'teacher'),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(
                        logdir,
                        "{}_{}_best_model.pkl".format(cfg["model"]["arch"],
                                                      cfg["data"]["dataset"]),
                    )
                    torch.save(state, save_path)
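
# In the loop above the criterion lives inside the model's forward pass (a
# teacher network is handed to `get_model`), and the per-GPU losses are merged
# with `gather` before averaging. The actual criterion is not shown here; the
# sketch below is only an assumed example of a distillation-style segmentation
# loss combining hard-label cross-entropy with a soft teacher target.
import torch
import torch.nn.functional as F


def distillation_seg_loss(student_logits, teacher_logits, labels,
                          temperature=1.0, alpha=0.5, ignore_index=250):
    """Cross-entropy on labels plus KL divergence to the teacher's soft output."""
    ce = F.cross_entropy(student_logits, labels, ignore_index=ignore_index)
    t = temperature
    kd = F.kl_div(
        F.log_softmax(student_logits / t, dim=1),
        F.softmax(teacher_logits / t, dim=1),
        reduction="batchmean",
    ) * (t * t)
    return (1.0 - alpha) * ce + alpha * kd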
def train(cfg, writer, logger): # Setup random seeds torch.manual_seed(cfg.get('seed', 1860)) torch.cuda.manual_seed(cfg.get('seed', 1860)) np.random.seed(cfg.get('seed', 1860)) random.seed(cfg.get('seed', 1860)) # Setup device if cfg["device"]["use_gpu"]: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if not torch.cuda.is_available(): logger.warning("CUDA not available, using CPU instead!") else: device = torch.device("cpu") # Setup augmentations augmentations = cfg['training'].get('augmentations', None) data_aug = get_composed_augmentations(augmentations) if "rcrop" in augmentations.keys(): data_aug_val = get_composed_augmentations( {"rcrop": augmentations["rcrop"]}) # Setup dataloader data_loader = get_loader(cfg['data']['dataset']) data_path = cfg['data']['path'] if 'depth_scaling' not in cfg['data'].keys(): cfg['data']['depth_scaling'] = None if 'max_depth' not in cfg['data'].keys(): logger.warning( "Key d_max not found in configuration file! Using default value") cfg['data']['max_depth'] = 256 if 'min_depth' not in cfg['data'].keys(): logger.warning( "Key d_min not found in configuration file! Using default value") cfg['data']['min_depth'] = 1 t_loader = data_loader(data_path, is_transform=True, split=cfg['data']['train_split'], img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']), augmentations=data_aug, depth_scaling=cfg['data']['depth_scaling'], n_bins=cfg['data']['depth_bins'], max_depth=cfg['data']['max_depth'], min_depth=cfg['data']['min_depth']) v_loader = data_loader(data_path, is_transform=True, split=cfg['data']['val_split'], img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']), augmentations=data_aug_val, depth_scaling=cfg['data']['depth_scaling'], n_bins=cfg['data']['depth_bins'], max_depth=cfg['data']['max_depth'], min_depth=cfg['data']['min_depth']) trainloader = data.DataLoader(t_loader, batch_size=cfg['training']['batch_size'], num_workers=cfg['training']['n_workers'], shuffle=True, drop_last=True) valloader = data.DataLoader(v_loader, batch_size=cfg['validation']['batch_size'], num_workers=cfg['validation']['n_workers'], shuffle=True, drop_last=True) # Check selected tasks if sum(cfg["data"]["tasks"].values()) > 1: logger.info("Running multi-task training with config: {}".format( cfg["data"]["tasks"])) # Get output dimension of the network's final layer n_classes_d_cls = None if cfg["data"]["tasks"]["d_cls"]: n_classes_d_cls = t_loader.n_classes_d_cls # Setup metrics for validation if cfg["data"]["tasks"]["d_cls"]: running_metrics_val_d_cls = runningScore(n_classes_d_cls) if cfg["data"]["tasks"]["d_reg"]: running_metrics_val_d_reg = running_side_score() # Setup model model = get_model(cfg['model'], cfg["data"]["tasks"], n_classes_d_cls=n_classes_d_cls).to(device) # model = d_regResNet().to(device) # Setup multi-GPU support n_gpus = torch.cuda.device_count() if n_gpus > 1: logger.info("Running multi-gpu training on {} GPUs".format(n_gpus)) model = torch.nn.DataParallel(model, device_ids=range(n_gpus)) # Setup multi-task loss task_weights = {} update_weights = True if \ cfg["training"]["task_weight_policy"] == 'update' else False for task, weight in cfg["training"]["task_weight_init"].items(): task_weights[task] = torch.tensor(weight).float() task_weights[task] = task_weights[task].to(device) task_weights[task] = task_weights[task].requires_grad_(update_weights) logger.info("Task weights were initialized with {}".format( cfg["training"]["task_weight_init"])) # Setup optimizer and lr_scheduler optimizer_cls = get_optimizer(cfg) 
optimizer_params = { k: v for k, v in cfg['training']['optimizer'].items() if k != 'name' } objective_params = list(model.parameters()) + list(task_weights.values()) optimizer = optimizer_cls(objective_params, **optimizer_params) logger.info("Using optimizer {}".format(optimizer)) scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule']) logger.info("Using learning-rate scheduler {}".format(scheduler)) # Setup task-specific loss functions # logger.debug("setting loss functions") loss_fns = {} for task, selected in cfg["data"]["tasks"].items(): if selected: logger.info("Task " + task + " was selected for training.") loss_fn = get_loss_function(cfg, task) logger.info("Using loss function {} for task {}".format( loss_fn, task)) loss_fns[task] = loss_fn # Load weights from old checkpoint if set # logger.debug("checking for resume checkpoint") start_iter = 0 if cfg['training']['resume'] is not None: if os.path.isfile(cfg['training']['resume']): logger.info( "Loading model and optimizer from checkpoint '{}'".format( cfg['training']['resume'])) logger.info("Loading file...") checkpoint = torch.load(cfg['training']['resume'], map_location="cpu") logger.info("Loading model...") model.load_state_dict(checkpoint["model_state"]) model.to("cpu") model.to(device) logger.info("Restoring task weights...") task_weights = checkpoint["task_weights"] for task, state in task_weights.items(): # task_weights[task] = state.to(device) task_weights[task] = torch.tensor(state.data).float() task_weights[task] = task_weights[task].to(device) task_weights[task] = task_weights[task].requires_grad_( update_weights) logger.info("Loading scheduler...") scheduler.load_state_dict(checkpoint["scheduler_state"]) # scheduler.to("cpu") start_iter = checkpoint["iteration"] # Add loaded parameters to optimizer # NOTE task_weights will not update otherwise! logger.info("Loading optimizer...") optimizer_cls = get_optimizer(cfg) objective_params = list(model.parameters()) + \ list(task_weights.values()) optimizer = optimizer_cls(objective_params, **optimizer_params) optimizer.load_state_dict(checkpoint["optimizer_state"]) # for state in optimizer.state.values(): # for k, v in state.items(): # if torch.is_tensor(v): # state[k] = v.to(device) logger.info("Loaded checkpoint '{}' (iter {})".format( cfg['training']['resume'], checkpoint["iteration"])) else: logger.error( "No checkpoint found at '{}'. 
Re-initializing params!".format( cfg['training']['resume'])) # Initialize meters for various metrics # logger.debug("initializing metrics") val_loss_meter = averageMeter() time_meter = averageMeter() # Setup other utility variables i = start_iter flag = True timer_training_start = time.time() logger.info("Starting training phase...") logger.debug("model device cuda?") logger.debug(next(model.parameters()).is_cuda) logger.debug("d_reg weight device:") logger.debug(task_weights["d_reg"].device) logger.debug("cls weight device:") logger.debug(task_weights["d_cls"].device) while i <= cfg['training']['train_iters'] and flag: for (images, labels) in trainloader: start_ts = time.time() scheduler.step() model.train() # Forward pass # logger.debug("sending images to device") images = images.to(device) optimizer.zero_grad() # logger.debug("forward pass") outputs = model(images) # Clip predicted depth to min/max # logger.debug("clamping outputs") if cfg["data"]["tasks"]["d_reg"]: if cfg["data"]["depth_scaling"] is not None: if cfg["data"]["depth_scaling"] == "clip": logger.warning("Using deprecated clip function!") outputs["d_reg"] = torch.clamp( outputs["d_reg"], 0, cfg["data"]["max_depth"]) # Calculate single-task losses # logger.debug("calculate loss") st_loss = {} for task, loss_fn in loss_fns.items(): labels[task] = labels[task].to(device) st_loss[task] = loss_fn(input=outputs[task], target=labels[task]) # Calculate multi-task loss # logger.debug("calculate mt loss") mt_loss = 0 if len(st_loss) > 1: for task, loss in st_loss.items(): s = task_weights[task] # s := log(sigma^2) r = s * 0.5 # regularization term if task in ["d_cls"]: w = torch.exp(-s) # weighting (class.) elif task in ["d_reg"]: w = 0.5 * torch.exp(-s) # weighting (regr.) else: raise ValueError("Weighting not implemented!") mt_loss += loss * w + r else: mt_loss = list(st_loss.values())[0] # Backward pass # logger.debug("backward pass") mt_loss.backward() # logger.debug("update weights") optimizer.step() time_meter.update(time.time() - start_ts) # Output current training status # logger.debug("write log") if i == 0 or (i + 1) % cfg['training']['print_interval'] == 0: pad = str(len(str(cfg['training']['train_iters']))) print_str = ("Training Iteration: [{:>" + pad + "d}/{:d}]" + " Loss: {:>14.4f}" + " Time/Image: {:>7.4f}").format( i + 1, cfg['training']['train_iters'], mt_loss.item(), time_meter.avg / cfg['training']['batch_size']) logger.info(print_str) # Add training status to summaries writer.add_scalar('learning_rate', scheduler.get_lr()[0], i + 1) writer.add_scalar('batch_size', cfg['training']['batch_size'], i + 1) writer.add_scalar('loss/train_loss', mt_loss.item(), i + 1) for task, loss in st_loss.items(): writer.add_scalar("loss/single_task/" + task, loss, i + 1) for task, weight in task_weights.items(): writer.add_scalar("task_weights/" + task, weight, i + 1) time_meter.reset() # Add latest input image to summaries train_input = images[0].cpu().numpy()[::-1, :, :] writer.add_image("training/input", train_input, i + 1) # Add d_cls predictions and gt for latest sample to summaries if cfg["data"]["tasks"]["d_cls"]: train_pred = outputs["d_cls"].detach().cpu().numpy().max( 0)[1].astype(np.uint8) # train_pred = np.array(outputs["d_cls"][0].data.max(0)[1], # dtype=np.uint8) train_pred = t_loader.decode_segmap(train_pred) train_pred = torch.tensor(np.rollaxis(train_pred, 2, 0)) writer.add_image("training/d_cls/prediction", train_pred, i + 1) train_gt = t_loader.decode_segmap( labels["d_cls"][0].data.cpu().numpy()) train_gt = 
torch.tensor(np.rollaxis(train_gt, 2, 0)) writer.add_image("training/d_cls/label", train_gt, i + 1) # Add d_reg predictions and gt for latest sample to summaries if cfg["data"]["tasks"]["d_reg"]: train_pred = outputs["d_reg"][0] train_pred = np.array(train_pred.data.cpu().numpy()) train_pred = t_loader.visualize_depths( t_loader.restore_metric_depths(train_pred)) writer.add_image("training/d_reg/prediction", train_pred, i + 1) train_gt = labels["d_reg"][0].data.cpu().numpy() train_gt = t_loader.visualize_depths( t_loader.restore_metric_depths(train_gt)) if len(train_gt.shape) < 3: train_gt = np.expand_dims(train_gt, axis=0) writer.add_image("training/d_reg/label", train_gt, i + 1) # Run mid-training validation if (i + 1) % cfg['training']['val_interval'] == 0: # or (i + 1) == cfg['training']['train_iters']: # Output current status # logger.debug("Training phase took " + str(timedelta(seconds=time.time() - timer_training_start))) timer_validation_start = time.time() logger.info("Validating model at training iteration" + " {}...".format(i + 1)) # Evaluate validation set model.eval() with torch.no_grad(): i_val = 0 pbar = tqdm(total=len(valloader), unit="batch") for (images_val, labels_val) in valloader: # Forward pass images_val = images_val.to(device) outputs_val = model(images_val) # Clip predicted depth to min/max if cfg["data"]["tasks"]["d_reg"]: if cfg["data"]["depth_scaling"] is None: logger.warning( "Using deprecated clip function!") outputs_val["d_reg"] = torch.clamp( outputs_val["d_reg"], 0, cfg["data"]["max_depth"]) else: outputs_val["d_reg"] = torch.clamp( outputs_val["d_reg"], 0, 1) # Calculate single-task losses st_loss_val = {} for task, loss_fn in loss_fns.items(): labels_val[task] = labels_val[task].to(device) st_loss_val[task] = loss_fn( input=outputs_val[task], target=labels_val[task]) # Calculate multi-task loss mt_loss_val = 0 if len(st_loss) > 1: for task, loss_val in st_loss_val.items(): s = task_weights[task] r = s * 0.5 if task in ["d_cls"]: w = torch.exp(-s) elif task in ["d_reg"]: w = 0.5 * torch.exp(-s) else: raise ValueError( "Weighting not implemented!") mt_loss_val += loss_val * w + r else: mt_loss_val = list(st_loss.values())[0] # Accumulate metrics for summaries val_loss_meter.update(mt_loss_val.item()) if cfg["data"]["tasks"]["d_cls"]: running_metrics_val_d_cls.update( labels_val["d_cls"].data.cpu().numpy(), outputs_val["d_cls"].data.cpu().numpy().argmax( 1)) if cfg["data"]["tasks"]["d_reg"]: running_metrics_val_d_reg.update( v_loader.restore_metric_depths( outputs_val["d_reg"].data.cpu().numpy()), v_loader.restore_metric_depths( labels_val["d_reg"].data.cpu().numpy())) # Update progressbar i_val += 1 pbar.update() # Stop validation early if max_iter key is set if "max_iter" in cfg["validation"].keys() and \ i_val >= cfg["validation"]["max_iter"]: logger.warning("Stopped validation early " + "because max_iter was reached") break # Add sample input images from latest batch to summaries num_img_samples_val = min(len(images_val), NUM_IMG_SAMPLES) for cur_s in range(0, num_img_samples_val): val_input = images_val[cur_s].cpu().numpy()[::-1, :, :] writer.add_image( "validation_sample_" + str(cur_s + 1) + "/input", val_input, i + 1) # Add predictions/ground-truth for d_cls to summaries if cfg["data"]["tasks"]["d_cls"]: val_pred = outputs_val["d_cls"][cur_s].data.max(0)[1] val_pred = np.array(val_pred, dtype=np.uint8) val_pred = t_loader.decode_segmap(val_pred) val_pred = torch.tensor(np.rollaxis(val_pred, 2, 0)) writer.add_image( "validation_sample_" + str(cur_s + 
1) + "/prediction_d_cls", val_pred, i + 1) val_gt = t_loader.decode_segmap( labels_val["d_cls"][cur_s].data.cpu().numpy()) val_gt = torch.tensor(np.rollaxis(val_gt, 2, 0)) writer.add_image( "validation_sample_" + str(cur_s + 1) + "/label_d_cls", val_gt, i + 1) # Add predictions/ground-truth for d_reg to summaries if cfg["data"]["tasks"]["d_reg"]: val_pred = outputs_val["d_reg"][cur_s].cpu().numpy() val_pred = v_loader.visualize_depths( v_loader.restore_metric_depths(val_pred)) writer.add_image( "validation_sample_" + str(cur_s + 1) + "/prediction_d_reg", val_pred, i + 1) val_gt = labels_val["d_reg"][cur_s].data.cpu().numpy() val_gt = v_loader.visualize_depths( v_loader.restore_metric_depths(val_gt)) if len(val_gt.shape) < 3: val_gt = np.expand_dims(val_gt, axis=0) writer.add_image( "validation_sample_" + str(cur_s + 1) + "/label_d_reg", val_gt, i + 1) # Add evaluation metrics for d_cls predictions to summaries if cfg["data"]["tasks"]["d_cls"]: score, class_iou = running_metrics_val_d_cls.get_scores() for k, v in score.items(): writer.add_scalar( 'validation/d_cls_metrics/{}'.format(k[:-3]), v, i + 1) for k, v in class_iou.items(): writer.add_scalar( 'validation/d_cls_metrics/class_{}'.format(k), v, i + 1) running_metrics_val_d_cls.reset() # Add evaluation metrics for d_reg predictions to summaries if cfg["data"]["tasks"]["d_reg"]: writer.add_scalar('validation/d_reg_metrics/rel', running_metrics_val_d_reg.rel, i + 1) running_metrics_val_d_reg.reset() # Add validation loss to summaries writer.add_scalar('loss/val_loss', val_loss_meter.avg, i + 1) # Output current status logger.info( ("Validation Loss at Iteration {}: " + "{:>14.4f}").format( i + 1, val_loss_meter.avg)) val_loss_meter.reset() # logger.debug("Validation phase took {}".format(timedelta(seconds=time.time() - timer_validation_start))) timer_training_start = time.time() # Close progressbar pbar.close() # Save checkpoint if (i + 1) % cfg['training']['checkpoint_interval'] == 0 or \ (i + 1) == cfg['training']['train_iters'] or \ i == 0: state = { "iteration": i + 1, "model_state": model.state_dict(), "task_weights": task_weights, "optimizer_state": optimizer.state_dict(), "scheduler_state": scheduler.state_dict() } save_path = os.path.join( writer.file_writer.get_logdir(), "{}_{}_checkpoint_iter_".format(cfg['model']['arch'], cfg['data']['dataset']) + str(i + 1) + ".pkl") torch.save(state, save_path) logger.info("Saved checkpoint at iteration {} to: {}".format( i + 1, save_path)) # Stop training if current iteration == max iterations if (i + 1) == cfg['training']['train_iters']: flag = False break i += 1
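
# The multi-task loop above weights each task loss with a learnable
# log-variance s = log(sigma^2): classification losses are scaled by exp(-s),
# regression losses by 0.5 * exp(-s), and 0.5 * s is added as a regularizer.
# A minimal sketch of that combination, assuming the same convention (the
# helper name is hypothetical; the script computes it inline):
import torch


def uncertainty_weighted_loss(single_task_losses, task_weights,
                              regression_tasks=("d_reg",)):
    """Combine per-task losses using learned homoscedastic-uncertainty weights."""
    total = 0.0
    for task, loss in single_task_losses.items():
        s = task_weights[task]                                  # s := log(sigma^2)
        w = 0.5 * torch.exp(-s) if task in regression_tasks else torch.exp(-s)
        total = total + loss * w + 0.5 * s                      # weighted loss + regularizer
    return total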
def train(cfg, writer, logger): # Setup seeds torch.manual_seed(cfg.get('seed', 1337)) torch.cuda.manual_seed(cfg.get('seed', 1337)) np.random.seed(cfg.get('seed', 1337)) random.seed(cfg.get('seed', 1337)) # Setup device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Setup Augmentations augmentations = cfg['training'].get('augmentations', None) data_aug = get_composed_augmentations(augmentations) # Setup Dataloader data_loader = get_loader(cfg['data']['dataset']) data_path = cfg['data']['path'] if not 'fold' in cfg['data'].keys(): cfg['data']['fold'] = None t_loader = data_loader( data_path, is_transform=True, split=cfg['data']['train_split'], img_size=[cfg['data']['img_rows'], cfg['data']['img_cols']], augmentations=data_aug, fold=cfg['data']['fold'], n_classes=cfg['data']['n_classes']) v_loader = data_loader( data_path, is_transform=True, split=cfg['data']['val_split'], img_size=[cfg['data']['img_rows'], cfg['data']['img_cols']], fold=cfg['data']['fold'], n_classes=cfg['data']['n_classes']) n_classes = t_loader.n_classes trainloader = data.DataLoader(t_loader, batch_size=cfg['training']['batch_size'], num_workers=cfg['training']['n_workers'], shuffle=True) valloader = data.DataLoader(v_loader, batch_size=1, num_workers=cfg['training']['n_workers']) logger.info("Training on fold {}".format(cfg['data']['fold'])) # Setup Metrics running_metrics_val = runningScore(n_classes) # Setup Model model = get_model(cfg['model'], n_classes).to(device) if args.model_path != "fcn8s_pascal_1_26.pkl": # Default Value state = convert_state_dict(torch.load(args.model_path)["model_state"]) if cfg['model']['use_scale']: model = load_my_state_dict(model, state) model.freeze_weights_extractor() else: model.load_state_dict(state) model.freeze_weights_extractor() model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count())) # Setup optimizer, lr_scheduler and loss function optimizer_cls = get_optimizer(cfg) optimizer_params = {k:v for k, v in cfg['training']['optimizer'].items() if k != 'name'} optimizer = optimizer_cls(model.parameters(), **optimizer_params) logger.info("Using optimizer {}".format(optimizer)) scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule']) loss_fn = get_loss_function(cfg) logger.info("Using loss {}".format(loss_fn)) start_iter = 0 if cfg['training']['resume'] is not None: if os.path.isfile(cfg['training']['resume']): logger.info( "Loading model and optimizer from checkpoint '{}'".format(cfg['training']['resume']) ) checkpoint = torch.load(cfg['training']['resume']) model.load_state_dict(checkpoint["model_state"]) optimizer.load_state_dict(checkpoint["optimizer_state"]) scheduler.load_state_dict(checkpoint["scheduler_state"]) start_iter = checkpoint["epoch"] logger.info( "Loaded checkpoint '{}' (iter {})".format( cfg['training']['resume'], checkpoint["epoch"] ) ) else: logger.info("No checkpoint found at '{}'".format(cfg['training']['resume'])) val_loss_meter = averageMeter() time_meter = averageMeter() best_iou = -100.0 i = start_iter flag = True while i <= cfg['training']['train_iters'] and flag: for (images, labels) in trainloader: # import matplotlib.pyplot as plt # plt.figure(1);plt.imshow(np.transpose(images[0], (1,2,0)));plt.figure(2); plt.imshow(labels[0]); plt.show() i += 1 start_ts = time.time() scheduler.step() model.train() images = images.to(device) labels = labels.to(device) optimizer.zero_grad() outputs = model(images) loss = loss_fn(input=outputs, target=labels) loss.backward() optimizer.step() 
time_meter.update(time.time() - start_ts) if (i + 1) % cfg['training']['print_interval'] == 0: fmt_str = "Iter [{:d}/{:d}] Loss: {:.4f} Time/Image: {:.4f}" print_str = fmt_str.format(i + 1, cfg['training']['train_iters'], loss.item(), time_meter.avg / cfg['training']['batch_size']) print(print_str) logger.info(print_str) writer.add_scalar('loss/train_loss', loss.item(), i+1) time_meter.reset() if (i + 1) == cfg['training']['train_iters']: flag = False state = { "epoch": i + 1, "model_state": model.state_dict(), "optimizer_state": optimizer.state_dict(), "scheduler_state": scheduler.state_dict(), "best_iou": best_iou, } save_path = os.path.join(writer.file_writer.get_logdir(), "{}_{}_best_model.pkl".format( cfg['model']['arch'], cfg['data']['dataset'])) torch.save(state, save_path) break
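
# The fine-tuning script above calls `model.freeze_weights_extractor()` after
# loading pretrained weights. A minimal sketch of what such a freeze typically
# does, assuming it simply disables gradients on the backbone (the actual
# method is defined on the model class and may differ):
def freeze_module(module):
    """Freeze all parameters of a submodule, e.g. a pretrained feature extractor."""
    for p in module.parameters():
        p.requires_grad_(False)
    module.eval()  # also freeze BatchNorm statistics / disable Dropout
    return module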
def train(cfg, writer, logger): # Setup seeds torch.manual_seed(cfg.get("seed", 1337)) torch.cuda.manual_seed(cfg.get("seed", 1337)) np.random.seed(cfg.get("seed", 1337)) random.seed(cfg.get("seed", 1337)) # Setup device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Setup Augmentations augmentations = cfg["training"].get("augmentations", None) data_aug = get_composed_augmentations(augmentations) # Setup Dataloader data_loader = get_loader(cfg["data"]["dataloader_type"]) data_root = cfg["data"]["data_root"] presentation_root = cfg["data"]["presentation_root"] t_loader = data_loader( data_root=data_root, presentation_root=presentation_root, is_transform=True, img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]), augmentations=data_aug, ) v_loader = data_loader(data_root=data_root, presentation_root=presentation_root, is_transform=True, img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]), augmentations=data_aug, test_mode=True) n_classes = t_loader.n_classes trainloader = data.DataLoader( t_loader, batch_size=cfg["training"]["batch_size"], num_workers=cfg["training"]["n_workers"], shuffle=False, ) valloader = data.DataLoader(v_loader, batch_size=cfg["training"]["batch_size"], num_workers=cfg["training"]["n_workers"], shuffle=False) # Setup Metrics # running_metrics_train = runningScore(n_classes) running_metrics_val = runningScore(n_classes) # Setup Model model = get_model(cfg["model"], n_classes, defaultParams).to(device) #model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count())) # Setup optimizer, lr_scheduler and loss function optimizer_cls = get_optimizer(cfg) optimizer_params = { k: v for k, v in cfg["training"]["optimizer"].items() if k != "name" } optimizer = optimizer_cls(model.parameters(), **optimizer_params) logger.info("Using optimizer {}".format(optimizer)) scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"]) loss_fn = get_loss_function(cfg) logger.info("Using loss {}".format(loss_fn)) start_iter = 0 if cfg["training"]["resume"] is not None: if os.path.isfile(cfg["training"]["resume"]): logger.info( "Loading model and optimizer from checkpoint '{}'".format( cfg["training"]["resume"])) checkpoint = torch.load(cfg["training"]["resume"]) model.load_state_dict(checkpoint["model_state"]) optimizer.load_state_dict(checkpoint["optimizer_state"]) scheduler.load_state_dict(checkpoint["scheduler_state"]) start_iter = checkpoint["epoch"] logger.info("Loaded checkpoint '{}' (iter {})".format( cfg["training"]["resume"], checkpoint["epoch"])) else: logger.info("No checkpoint found at '{}'".format( cfg["training"]["resume"])) model.load_pretrained_weights(cfg["training"]["saved_model_path"]) # train_loss_meter = averageMeter() val_loss_meter = averageMeter() time_meter = averageMeter() best_iou = -100.0 i = start_iter while i <= cfg["training"]["num_presentations"]: # # # TRAINING PHASE # # # i += 1 start_ts = time.time() trainloader.dataset.random_select() hebb = model.initialZeroHebb().to(device) for idx, (images, labels) in enumerate( trainloader, 1): # get a single training presentation images = images.to(device) labels = labels.to(device) if idx <= 5: model.eval() with torch.no_grad(): outputs, hebb = model(images, labels, hebb, device, test_mode=False) else: scheduler.step() model.train() optimizer.zero_grad() outputs, hebb = model(images, labels, hebb, device, test_mode=True) loss = loss_fn(input=outputs, target=labels) loss.backward() optimizer.step() time_meter.update(time.time() - start_ts) # -> time 
taken per presentation if (i + 1) % cfg["training"]["print_interval"] == 0: fmt_str = "Pres [{:d}/{:d}] Loss: {:.4f} Time/Pres: {:.4f}" print_str = fmt_str.format( i + 1, cfg["training"]["num_presentations"], loss.item(), time_meter.avg / cfg["training"]["batch_size"], ) print(print_str) logger.info(print_str) writer.add_scalar("loss/test_loss", loss.item(), i + 1) time_meter.reset() # # # TEST PHASE # # # if ((i + 1) % cfg["training"]["test_interval"] == 0 or (i + 1) == cfg["training"]["num_presentations"]): training_state_dict = model.state_dict( ) # saving the training state of the model valloader.dataset.random_select() hebb = model.initialZeroHebb().to(device) for idx, (images_val, labels_val) in enumerate( valloader, 1): # get a single test presentation images_val = images_val.to(device) labels_val = labels_val.to(device) if idx <= 5: model.eval() with torch.no_grad(): outputs, hebb = model(images_val, labels_val, hebb, device, test_mode=False) else: model.train() optimizer.zero_grad() outputs, hebb = model(images_val, labels_val, hebb, device, test_mode=True) loss = loss_fn(input=outputs, target=labels_val) loss.backward() optimizer.step() pred = outputs.data.max(1)[1].cpu().numpy() gt = labels_val.data.cpu().numpy() running_metrics_val.update(gt, pred) val_loss_meter.update(loss.item()) model.load_state_dict( training_state_dict) # revert back to training parameters writer.add_scalar("loss/val_loss", val_loss_meter.avg, i + 1) logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg)) score, class_iou = running_metrics_val.get_scores() for k, v in score.items(): print(k, v) logger.info("{}: {}".format(k, v)) writer.add_scalar("val_metrics/{}".format(k), v, i + 1) for k, v in class_iou.items(): logger.info("{}: {}".format(k, v)) writer.add_scalar("val_metrics/cls_{}".format(k), v, i + 1) val_loss_meter.reset() running_metrics_val.reset() if score["Mean IoU : \t"] >= best_iou: best_iou = score["Mean IoU : \t"] state = { "epoch": i + 1, "model_state": model.state_dict(), "optimizer_state": optimizer.state_dict(), "scheduler_state": scheduler.state_dict(), "best_iou": best_iou, } save_path = os.path.join( writer.file_writer.get_logdir(), "{}_{}_best_model.pkl".format( cfg["model"]["arch"], cfg["data"]["dataloader_type"]), ) torch.save(state, save_path) if (i + 1) == cfg["training"]["num_presentations"]: break
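
# The test phase above updates the plastic (Hebbian) model on validation
# presentations and then restores the pre-test parameters from a saved
# state dict. A minimal sketch of that snapshot-and-restore pattern, with a
# hypothetical helper (the loop above keeps `training_state_dict` inline):
import copy


def run_episode_and_restore(model, episode_fn):
    """Snapshot parameters, run an episode that may modify them, then restore."""
    snapshot = copy.deepcopy(model.state_dict())  # independent copy of all tensors
    result = episode_fn(model)                    # episode may call optimizer.step()
    model.load_state_dict(snapshot)               # revert to the pre-episode weights
    return result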
def train(cfg, writer, logger): # Setup random seeds to a determinated value for reproduction # seed = 1337 # torch.manual_seed(seed) # torch.cuda.manual_seed(seed) # np.random.seed(seed) # random.seed(seed) # np.random.default_rng(seed) # Setup Augmentations augmentations = cfg.train.augment logger.info(f'using augments: {augmentations}') data_aug = get_composed_augmentations(augmentations) # Setup Dataloader data_loader = get_loader(cfg.data.dataloader) data_path = cfg.data.path logger.info("Using dataset: {}".format(data_path)) t_loader = data_loader( data_path, # transform=None, # time_shuffle = cfg.data.time_shuffle, # to_tensor=False, data_format = cfg.data.format, split=cfg.data.train_split, norm = cfg.data.norm, augments=data_aug ) v_loader = data_loader( data_path, # transform=None, # time_shuffle = cfg.data.time_shuffle, # to_tensor=False, data_format = cfg.data.format, split=cfg.data.val_split, ) train_data_len = len(t_loader) logger.info(f'num of train samples: {train_data_len} \nnum of val samples: {len(v_loader)}') batch_size = cfg.train.batch_size epoch = cfg.train.epoch train_iter = int(np.ceil(train_data_len / batch_size) * epoch) logger.info(f'total train iter: {train_iter}') trainloader = data.DataLoader(t_loader, batch_size=batch_size, num_workers=cfg.train.n_workers, shuffle=True, persistent_workers=True, drop_last=True) valloader = data.DataLoader(v_loader, batch_size=10, # persis num_workers=cfg.train.n_workers,) # Setup Model device = f'cuda:{cfg.gpu[0]}' model = get_model(cfg.model, 2).to(device) input_size = (cfg.model.input_nbr, 512, 512) logger.info(f"Using Model: {cfg.model.arch}") # logger.info(f'model summary: {summary(model, input_size=(input_size, input_size), is_complex=True)}') model = torch.nn.DataParallel(model, device_ids=cfg.gpu) #自动多卡运行,这个好用 # Setup optimizer, lr_scheduler and loss function optimizer_cls = get_optimizer(cfg) optimizer_params = {k:v for k, v in vars(cfg.train.optimizer).items() if k not in ('name', 'wrap')} optimizer = optimizer_cls(model.parameters(), **optimizer_params) logger.info("Using optimizer {}".format(optimizer)) if hasattr(cfg.train.optimizer, 'warp') and cfg.train.optimizer.wrap=='lars': optimizer = LARS(optimizer=optimizer) logger.info(f'warp optimizer with {cfg.train.optimizer.wrap}') scheduler = get_scheduler(optimizer, cfg.train.lr) loss_fn = get_loss_function(cfg) logger.info(f"Using loss ,{str(cfg.train.loss)}") # load checkpoints val_cls_1_acc = 0 best_cls_1_acc_now = 0 best_cls_1_acc_iter_now = 0 val_macro_OA = 0 best_macro_OA_now = 0 best_macro_OA_iter_now = 0 start_iter = 0 if cfg.train.resume is not None: if os.path.isfile(cfg.train.resume): logger.info( "Loading model and optimizer from checkpoint '{}'".format(cfg.train.resume) ) # load model state checkpoint = torch.load(cfg.train.resume) model.load_state_dict(checkpoint["model_state"]) optimizer.load_state_dict(checkpoint["optimizer_state"]) scheduler.load_state_dict(checkpoint["scheduler_state"]) # best_cls_1_acc_now = checkpoint["best_cls_1_acc_now"] # best_cls_1_acc_iter_now = checkpoint["best_cls_1_acc_iter_now"] start_iter = checkpoint["epoch"] logger.info( "Loaded checkpoint '{}' (iter {})".format( cfg.train.resume, checkpoint["epoch"] ) ) # copy tensorboard files resume_src_dir = osp.split(cfg.train.resume)[0] # shutil.copytree(resume_src_dir, writer.get_logdir()) for file in os.listdir(resume_src_dir): if not ('.log' in file or '.yml' in file or '_last_model' in file): # if 'events.out.tfevents' in file: resume_dst_dir = writer.get_logdir() 
fu.copy(osp.join(resume_src_dir, file), resume_dst_dir, ) else: logger.info("No checkpoint found at '{}'".format(cfg.train.resume)) # Setup Metrics running_metrics_val = runningScore(2) runing_metrics_train = runningScore(2) val_loss_meter = averageMeter() train_time_meter = averageMeter() # train it = start_iter train_start_time = time.time() train_val_start_time = time.time() model.train() while it < train_iter: for (file_a, file_b, label, mask) in trainloader: it += 1 file_a = file_a.to(device) file_b = file_b.to(device) label = label.to(device) mask = mask.to(device) optimizer.zero_grad() # print(f'dtype: {file_a.dtype}') outputs = model(file_a, file_b) loss = loss_fn(input=outputs, target=label, mask=mask) loss.backward() # print('conv11: ', model.conv11.weight.grad, model.conv11.weight.grad.shape) # print('conv21: ', model.conv21.weight.grad, model.conv21.weight.grad.shape) # print('conv31: ', model.conv31.weight.grad, model.conv31.weight.grad.shape) # In PyTorch 1.1.0 and later, you should call `optimizer.step()` before `lr_scheduler.step()` optimizer.step() scheduler.step() # record the acc of the minibatch pred = outputs.max(1)[1].cpu().numpy() runing_metrics_train.update(label.cpu().numpy(), pred, mask.cpu().numpy()) train_time_meter.update(time.time() - train_start_time) if it % cfg.train.print_interval == 0: # acc of the samples between print_interval score, _ = runing_metrics_train.get_scores() train_cls_0_acc, train_cls_1_acc = score['Acc'] fmt_str = "Iter [{:d}/{:d}] train Loss: {:.4f} Time/Image: {:.4f},\n0:{:.4f}\n1:{:.4f}" print_str = fmt_str.format(it, train_iter, loss.item(), #extracts the loss’s value as a Python float. train_time_meter.avg / cfg.train.batch_size,train_cls_0_acc, train_cls_1_acc) runing_metrics_train.reset() train_time_meter.reset() logger.info(print_str) writer.add_scalar('loss/train_loss', loss.item(), it) writer.add_scalars('metrics/train', {'cls_0':train_cls_0_acc, 'cls_1':train_cls_1_acc}, it) # writer.add_scalar('train_metrics/acc/cls_0', train_cls_0_acc, it) # writer.add_scalar('train_metrics/acc/cls_1', train_cls_1_acc, it) if it % cfg.train.val_interval == 0 or \ it == train_iter: val_start_time = time.time() model.eval() # change behavior like drop out with torch.no_grad(): # disable autograd, save memory usage for (file_a_val, file_b_val, label_val, mask_val) in valloader: file_a_val = file_a_val.to(device) file_b_val = file_b_val.to(device) outputs = model(file_a_val, file_b_val) # tensor.max() returns the maximum value and its indices pred = outputs.max(1)[1].cpu().numpy() running_metrics_val.update(label_val.numpy(), pred, mask_val.numpy()) label_val = label_val.to(device) mask_val = mask_val.to(device) val_loss = loss_fn(input=outputs, target=label_val, mask=mask_val) val_loss_meter.update(val_loss.item()) score, _ = running_metrics_val.get_scores() val_cls_0_acc, val_cls_1_acc = score['Acc'] writer.add_scalar('loss/val_loss', val_loss_meter.avg, it) logger.info(f"Iter [{it}/{train_iter}], val Loss: {val_loss_meter.avg:.4f} Time/Image: {(time.time()-val_start_time)/len(v_loader):.4f}\n0: {val_cls_0_acc:.4f}\n1:{val_cls_1_acc:.4f}") # lr_now = optimizer.param_groups[0]['lr'] # logger.info(f'lr: {lr_now}') # writer.add_scalar('lr', lr_now, it+1) logger.info('0: {:.4f}\n1:{:.4f}'.format(val_cls_0_acc, val_cls_1_acc)) writer.add_scalars('metrics/val', {'cls_0':val_cls_0_acc, 'cls_1':val_cls_1_acc}, it) # writer.add_scalar('val_metrics/acc/cls_0', val_cls_0_acc, it) # writer.add_scalar('val_metrics/acc/cls_1', val_cls_1_acc, it) 
val_loss_meter.reset() running_metrics_val.reset() # OA=score["Overall_Acc"] val_macro_OA = (val_cls_0_acc+val_cls_1_acc)/2 if val_macro_OA >= best_macro_OA_now and it>200: best_macro_OA_now = val_macro_OA best_macro_OA_iter_now = it state = { "epoch": it, "model_state": model.state_dict(), "optimizer_state": optimizer.state_dict(), "scheduler_state": scheduler.state_dict(), "best_macro_OA_now": best_macro_OA_now, 'best_macro_OA_iter_now':best_macro_OA_iter_now, } save_path = os.path.join(writer.file_writer.get_logdir(), "{}_{}_best_model.pkl".format(cfg.model.arch,cfg.data.dataloader)) torch.save(state, save_path) logger.info("best OA now = %.8f" % (best_macro_OA_now)) logger.info("best OA iter now= %d" % (best_macro_OA_iter_now)) train_val_time = time.time() - train_val_start_time remain_time = train_val_time * (train_iter-it) / it m, s = divmod(remain_time, 60) h, m = divmod(m, 60) if s != 0: train_time = "Remain train time = %d hours %d minutes %d seconds \n" % (h, m, s) else: train_time = "Remain train time : train completed.\n" logger.info(train_time) model.train() train_start_time = time.time() logger.info("best OA now = %.8f" % (best_macro_OA_now)) logger.info("best OA iter now= %d" % (best_macro_OA_iter_now)) state = { "epoch": it, "model_state": model.state_dict(), "optimizer_state": optimizer.state_dict(), "scheduler_state": scheduler.state_dict(), "best_macro_OA_now": best_macro_OA_now, 'best_macro_OA_iter_now':best_macro_OA_iter_now, } save_path = os.path.join(writer.file_writer.get_logdir(), "{}_{}_last_model.pkl".format(cfg.model.arch, cfg.data.dataloader)) torch.save(state, save_path)
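# The change-detection loop above calls loss_fn(input=outputs, target=label, mask=mask),
# i.e. the loss is averaged over valid pixels only. A minimal sketch of a loss with that
# signature, assuming `mask` is 1 for valid pixels and 0 for ignored ones (the actual
# loss returned by get_loss_function may differ):
import torch
import torch.nn.functional as F

def masked_cross_entropy(input, target, mask):
    """Pixel-wise cross-entropy reduced only over pixels where mask == 1.

    input:  (N, C, H, W) logits
    target: (N, H, W) integer class labels
    mask:   (N, H, W) validity mask
    """
    per_pixel = F.cross_entropy(input, target, reduction="none")   # (N, H, W)
    mask = mask.float()
    return (per_pixel * mask).sum() / mask.sum().clamp(min=1.0)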
def train(cfg, writer, logger): # Setup seeds init_seed(11733, en_cudnn=False) # Setup Augmentations train_augmentations = cfg["training"].get("train_augmentations", None) t_data_aug = get_composed_augmentations(train_augmentations) val_augmentations = cfg["validating"].get("val_augmentations", None) v_data_aug = get_composed_augmentations(val_augmentations) # Setup Dataloader data_loader = get_loader(cfg["data"]["dataset"]) t_loader = data_loader(cfg=cfg["data"], mode='train', augmentations=t_data_aug) v_loader = data_loader(cfg=cfg["data"], mode='val', augmentations=v_data_aug) trainloader = data.DataLoader(t_loader, batch_size=cfg["training"]["batch_size"], num_workers=cfg["training"]["n_workers"], shuffle=True, drop_last=True) valloader = data.DataLoader(v_loader, batch_size=cfg["validating"]["batch_size"], num_workers=cfg["validating"]["n_workers"]) logger.info("Using training seting {}".format(cfg["training"])) # Setup Metrics running_metrics_val = runningScore(t_loader.n_classes, t_loader.unseen_classes) model_state = torch.load( './runs/deeplabv3p_ade_25unseen/84253/deeplabv3p_ade20k_best_model.pkl' ) running_metrics_val.confusion_matrix = model_state['results'] score, a_iou = running_metrics_val.get_scores() pdb.set_trace() # Setup Model and Loss loss_fn = get_loss_function(cfg["training"]) logger.info("Using loss {}".format(loss_fn)) model = get_model(cfg["model"], t_loader.n_classes, loss_fn=loss_fn) # Setup optimizer optimizer = get_optimizer(cfg["training"], model) # Initialize training param start_iter = 0 best_iou = -100.0 # Resume from checkpoint if cfg["training"]["resume"] is not None: if os.path.isfile(cfg["training"]["resume"]): logger.info("Resuming training from checkpoint '{}'".format( cfg["training"]["resume"])) model_state = torch.load(cfg["training"]["resume"])["model_state"] model.load_state_dict(model_state) else: logger.info("No checkpoint found at '{}'".format( cfg["training"]["resume"])) # Setup Multi-GPU if torch.cuda.is_available(): model = model.cuda() # DataParallelModel(model).cuda() logger.info("Model initialized on GPUs.") time_meter = averageMeter() i = start_iter embd = t_loader.embeddings ignr_idx = t_loader.ignore_index embds = embd.cuda() while i <= cfg["training"]["train_iters"]: for (images, labels) in trainloader: images = images.cuda() labels = labels.cuda() i += 1 model.train() optimizer.zero_grad() start_ts = time.time() loss_sum = model(images, labels, embds, ignr_idx) if loss_sum == 0: # Ignore samples contain unseen cat continue # To enable non-transductive learning, set transductive=0 in the config loss_sum.backward() time_meter.update(time.time() - start_ts) optimizer.step() if (i + 1) % cfg["training"]["print_interval"] == 0: fmt_str = "Iter [{:d}/{:d}] Loss: {:.4f} Time/Image: {:.4f}" print_str = fmt_str.format( i + 1, cfg["training"]["train_iters"], loss_sum.item(), time_meter.avg / cfg["training"]["batch_size"], ) print(print_str) logger.info(print_str) writer.add_scalar("loss/train_loss", loss_sum.item(), i + 1) time_meter.reset() if (i + 1) % cfg["training"]["val_interval"] == 0 or ( i + 1) == cfg["training"]["train_iters"]: model.eval() with torch.no_grad(): for i_val, (images_val, labels_val) in tqdm(enumerate(valloader)): images_val = images_val.cuda() labels_val = labels_val.cuda() outputs = model(images_val, labels_val, embds, ignr_idx) # outputs = gather(outputs, 0, dim=0) running_metrics_val.update(outputs) score, a_iou = running_metrics_val.get_scores() for k, v in score.items(): print("{}: {}".format(k, v)) 
logger.info("{}: {}".format(k, v)) #writer.add_scalar("val_metrics/{}".format(k), v, i + 1) #for k, v in class_iou.items(): # logger.info("{}: {}".format(k, v)) # writer.add_scalar("val_metrics/cls_{}".format(k), v, i + 1) if a_iou >= best_iou: best_iou = a_iou state = { "epoch": i + 1, "model_state": model.state_dict(), "best_iou": best_iou, "results": running_metrics_val.confusion_matrix } save_path = os.path.join( writer.file_writer.get_logdir(), "{}_{}_best_model.pkl".format(cfg["model"]["arch"], cfg["data"]["dataset"]), ) torch.save(state, save_path) running_metrics_val.reset()
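# The zero-shot validation block above stores the raw confusion matrix in the checkpoint
# ("results") and derives its scores from it. A small sketch of how per-class IoU and an
# aggregate IoU come out of such an (n_classes x n_classes) matrix, in the spirit of
# runningScore.get_scores() (not the repository's exact implementation):
import numpy as np

def scores_from_confusion(conf):
    """conf[i, j] counts pixels of true class i predicted as class j."""
    tp = np.diag(conf).astype(np.float64)
    fp = conf.sum(axis=0) - tp
    fn = conf.sum(axis=1) - tp
    iou = tp / np.maximum(tp + fp + fn, 1)   # avoid division by zero for absent classes
    return iou, float(np.nanmean(iou))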
def train(cfg, writer, logger): # Setup Augmentations augmentations = cfg.train.augment logger.info(f'using augments: {augmentations}') data_aug = get_composed_augmentations(augmentations) # Setup Dataloader data_loader = get_loader(cfg.data.dataloader) data_path = cfg.data.path logger.info("data path: {}".format(data_path)) t_loader = data_loader( data_path, data_format=cfg.data.format, norm=cfg.data.norm, split='train', split_root=cfg.data.split, augments=data_aug, logger=logger, log=cfg.data.log, ENL=cfg.data.ENL, ) v_loader = data_loader( data_path, data_format=cfg.data.format, split='val', log=cfg.data.log, split_root=cfg.data.split, logger=logger, ENL=cfg.data.ENL, ) train_data_len = len(t_loader) logger.info( f'num of train samples: {train_data_len} \nnum of val samples: {len(v_loader)}' ) batch_size = cfg.train.batch_size epoch = cfg.train.epoch train_iter = int(np.ceil(train_data_len / batch_size) * epoch) logger.info(f'total train iter: {train_iter}') trainloader = data.DataLoader(t_loader, batch_size=batch_size, num_workers=cfg.train.n_workers, shuffle=True, persistent_workers=True, drop_last=True) valloader = data.DataLoader( v_loader, batch_size=cfg.test.batch_size, # persis num_workers=cfg.train.n_workers, ) # Setup Model device = f'cuda:{cfg.train.gpu[0]}' model = get_model(cfg.model).to(device) input_size = (cfg.model.in_channels, 512, 512) logger.info(f"Using Model: {cfg.model.arch}") # logger.info(f'model summary: {summary(model, input_size=(input_size, input_size), is_complex=False)}') model = torch.nn.DataParallel(model, device_ids=cfg.gpu) #自动多卡运行,这个好用 # Setup optimizer, lr_scheduler and loss function optimizer_cls = get_optimizer(cfg) optimizer_params = { k: v for k, v in vars(cfg.train.optimizer).items() if k not in ('name', 'wrap') } optimizer = optimizer_cls(model.parameters(), **optimizer_params) logger.info("Using optimizer {}".format(optimizer)) if hasattr(cfg.train.optimizer, 'wrap') and cfg.train.optimizer.wrap == 'lars': optimizer = LARS(optimizer=optimizer) logger.info(f'warp optimizer with {cfg.train.optimizer.wrap}') scheduler = get_scheduler(optimizer, cfg.train.lr) # loss_fn = get_loss_function(cfg) # logger.info(f"Using loss ,{str(cfg.train.loss)}") # load checkpoints val_cls_1_acc = 0 best_cls_1_acc_now = 0 best_cls_1_acc_iter_now = 0 val_macro_OA = 0 best_macro_OA_now = 0 best_macro_OA_iter_now = 0 start_iter = 0 if cfg.train.resume is not None: if os.path.isfile(cfg.train.resume): logger.info( "Loading model and optimizer from checkpoint '{}'".format( cfg.train.resume)) # load model state checkpoint = torch.load(cfg.train.resume) model.load_state_dict(checkpoint["model_state"]) optimizer.load_state_dict(checkpoint["optimizer_state"]) scheduler.load_state_dict(checkpoint["scheduler_state"]) # best_cls_1_acc_now = checkpoint["best_cls_1_acc_now"] # best_cls_1_acc_iter_now = checkpoint["best_cls_1_acc_iter_now"] start_iter = checkpoint["epoch"] logger.info("Loaded checkpoint '{}' (iter {})".format( cfg.train.resume, checkpoint["epoch"])) # copy tensorboard files resume_src_dir = osp.split(cfg.train.resume)[0] # shutil.copytree(resume_src_dir, writer.get_logdir()) for file in os.listdir(resume_src_dir): if not ('.log' in file or '.yml' in file or '_last_model' in file): # if 'events.out.tfevents' in file: resume_dst_dir = writer.get_logdir() fu.copy( osp.join(resume_src_dir, file), resume_dst_dir, ) else: logger.info("No checkpoint found at '{}'".format(cfg.train.resume)) data_range = 255 if cfg.data.log: data_range = np.log(data_range) # data_range /= 
350 # Setup Metrics running_metrics_val = runningScore(2) runing_metrics_train = runningScore(2) val_loss_meter = averageMeter() train_time_meter = averageMeter() train_loss_meter = averageMeter() val_psnr_meter = averageMeter() val_ssim_meter = averageMeter() # train it = start_iter train_start_time = time.time() train_val_start_time = time.time() model.train() while it < train_iter: for clean, noisy, _ in trainloader: it += 1 noisy = noisy.to(device, dtype=torch.float32) # noisy /= 350 mask1, mask2 = rand_pool.generate_mask_pair(noisy) noisy_sub1 = rand_pool.generate_subimages(noisy, mask1) noisy_sub2 = rand_pool.generate_subimages(noisy, mask2) # preparing for the regularization term with torch.no_grad(): noisy_denoised = model(noisy) noisy_sub1_denoised = rand_pool.generate_subimages( noisy_denoised, mask1) noisy_sub2_denoised = rand_pool.generate_subimages( noisy_denoised, mask2) # print(rand_pool.operation_seed_counter) # for ii, param in enumerate(model.parameters()): # if torch.sum(torch.isnan(param.data)): # print(f'{ii}: nan parameters') # calculating the loss noisy_output = model(noisy_sub1) noisy_target = noisy_sub2 if cfg.train.loss.gamma.const: gamma = cfg.train.loss.gamma.base else: gamma = it / train_iter * cfg.train.loss.gamma.base diff = noisy_output - noisy_target exp_diff = noisy_sub1_denoised - noisy_sub2_denoised loss1 = torch.mean(diff**2) loss2 = gamma * torch.mean((diff - exp_diff)**2) loss_all = loss1 + loss2 # loss1 = noisy_output - noisy_target # loss2 = torch.exp(noisy_target - noisy_output) # loss_all = torch.mean(loss1 + loss2) loss_all.backward() # In PyTorch 1.1.0 and later, you should call `optimizer.step()` before `lr_scheduler.step()` optimizer.step() scheduler.step() # record the loss of the minibatch train_loss_meter.update(loss_all) train_time_meter.update(time.time() - train_start_time) writer.add_scalar('lr', optimizer.param_groups[0]['lr'], it) if it % 1000 == 0: writer.add_histogram('hist/pred', noisy_denoised, it) writer.add_histogram('hist/noisy', noisy, it) if cfg.data.simulate: writer.add_histogram('hist/clean', clean, it) if cfg.data.simulate: pass # print interval if it % cfg.train.print_interval == 0: terminal_info = f"Iter [{it:d}/{train_iter:d}] \ train Loss: {train_loss_meter.avg:.4f} \ Time/Image: {train_time_meter.avg / cfg.train.batch_size:.4f}" logger.info(terminal_info) writer.add_scalar('loss/train_loss', train_loss_meter.avg, it) if cfg.data.simulate: pass runing_metrics_train.reset() train_time_meter.reset() train_loss_meter.reset() # val interval if it % cfg.train.val_interval == 0 or \ it == train_iter: val_start_time = time.time() model.eval() with torch.no_grad(): for clean, noisy, _ in valloader: # noisy /= 350 # clean /= 350 noisy = noisy.to(device, dtype=torch.float32) noisy_denoised = model(noisy) if cfg.data.simulate: clean = clean.to(device, dtype=torch.float32) psnr = piq.psnr(clean, noisy_denoised, data_range=data_range) ssim = piq.ssim(clean, noisy_denoised, data_range=data_range) val_psnr_meter.update(psnr) val_ssim_meter.update(ssim) val_loss = torch.mean((noisy_denoised - noisy)**2) val_loss_meter.update(val_loss) writer.add_scalar('loss/val_loss', val_loss_meter.avg, it) logger.info( f"Iter [{it}/{train_iter}], val Loss: {val_loss_meter.avg:.4f} Time/Image: {(time.time()-val_start_time)/len(v_loader):.4f}" ) val_loss_meter.reset() running_metrics_val.reset() if cfg.data.simulate: writer.add_scalars('metrics/val', { 'psnr': val_psnr_meter.avg, 'ssim': val_ssim_meter.avg }, it) logger.info( f'psnr: 
{val_psnr_meter.avg},\tssim: {val_ssim_meter.avg}' ) val_psnr_meter.reset() val_ssim_meter.reset() train_val_time = time.time() - train_val_start_time remain_time = train_val_time * (train_iter - it) / it m, s = divmod(remain_time, 60) h, m = divmod(m, 60) if s != 0: train_time = "Remain train time = %d hours %d minutes %d seconds \n" % ( h, m, s) else: train_time = "Remain train time : train completed.\n" logger.info(train_time) model.train() # save model if it % (train_iter / cfg.train.epoch * 10) == 0: ep = int(it / (train_iter / cfg.train.epoch)) state = { "epoch": it, "model_state": model.state_dict(), "optimizer_state": optimizer.state_dict(), "scheduler_state": scheduler.state_dict(), } save_path = osp.join(writer.file_writer.get_logdir(), f"{ep}.pkl") torch.save(state, save_path) logger.info(f'saved model state dict at {save_path}') train_start_time = time.time()
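# The self-supervised denoising objective above pairs two sub-sampled views of the same
# noisy image: a reconstruction term between the prediction on one view and the other
# view, plus a regulariser (scaled by gamma, optionally ramped over training) comparing
# that gap against the frozen denoiser's own gap. The same arithmetic packaged as a
# function (names are illustrative):
import torch

def paired_subimage_loss(pred_sub1, noisy_sub2, denoised_sub1, denoised_sub2, gamma):
    diff = pred_sub1 - noisy_sub2              # trainable prediction vs. paired noisy target
    exp_diff = denoised_sub1 - denoised_sub2   # same pair through the no-grad denoiser
    loss_rec = torch.mean(diff ** 2)
    loss_reg = gamma * torch.mean((diff - exp_diff) ** 2)
    return loss_rec + loss_reg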
def test(cfg, logger, run_id): # Setup Augmentations augmentations = cfg.test.augments logger.info(f'using augments: {augmentations}') data_aug = get_composed_augmentations(augmentations) # Setup Dataloader data_loader = get_loader(cfg.data.dataloader) data_path = cfg.data.path data_loader = data_loader( data_path, data_format=cfg.data.format, norm = cfg.data.norm, split=cfg.test.dataset, split_root = cfg.data.split, log = cfg.data.log, augments=data_aug, logger=logger, ENL = cfg.data.ENL, ) run_id = osp.join(run_id, cfg.test.dataset) # os.mkdir(run_id) logger.info("data path: {}".format(data_path)) logger.info(f'num of {cfg.test.dataset} set samples: {len(data_loader)}') loader = data.DataLoader(data_loader, batch_size=cfg.test.batch_size, num_workers=cfg.test.n_workers, shuffle=False, persistent_workers=True, drop_last=False, ) # Setup Model device = f'cuda:{cfg.gpu[0]}' model = get_model(cfg.model).to(device) input_size = (cfg.model.in_channels, 512, 512) logger.info(f'using model: {cfg.model.arch}') model = torch.nn.DataParallel(model, device_ids=cfg.gpu) # load model params if osp.isfile(cfg.test.pth): logger.info("Loading model from checkpoint '{}'".format(cfg.test.pth)) # load model state checkpoint = torch.load(cfg.test.pth) model.load_state_dict(checkpoint["model_state"]) else: raise FileNotFoundError(f'{cfg.test.pth} file not found') # Setup Metrics running_metrics_val = runningScore(2) running_metrics_train = runningScore(2) metrics = runningScore(2) test_psnr_meter = averageMeter() test_ssim_meter = averageMeter() img_cnt = 0 data_range = 255 if cfg.data.log: data_range = np.log(data_range) # test model.eval() with torch.no_grad(): for clean, noisy, files_path in loader: noisy = noisy.to(device, dtype=torch.float32) noisy_denoised = model(noisy) psnr = [] ssim = [] if cfg.data.simulate: clean = clean.to(device, dtype=torch.float32) for ii in range(9): psnr.append(piq.psnr(noisy_denoised[:, ii, :, :], clean[:, ii, :, :], data_range=data_range).cpu()) ssim.append(piq.ssim(noisy_denoised[:, ii, :, :], clean[:, ii, :, :], data_range=data_range).cpu()) print(f'{ii}: PSNR: {psnr[ii]}\n\tSSIM: {ssim[ii]}') print('\n') test_psnr_meter.update(np.array(psnr).mean(), n=clean.shape[0]) test_ssim_meter.update(np.array(ssim).mean(), n=clean.shape[0]) if cfg.data.simulate: logger.info(f'overall psnr: {test_psnr_meter.avg}, ssim: {test_ssim_meter.avg}') logger.info(f'\ndone')
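# The test loop above reports PSNR/SSIM through piq only when clean references exist
# (cfg.data.simulate). For reference, PSNR itself reduces to a log of the MSE; a
# plain-torch sketch equivalent in spirit to the piq.psnr calls, with data_range matching
# the value computed above (255, or log(255) when cfg.data.log):
import torch

def psnr_db(pred, target, data_range=255.0):
    """Peak signal-to-noise ratio in dB over a whole batch."""
    mse = torch.mean((pred - target) ** 2)
    return 10.0 * torch.log10(data_range ** 2 / mse)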
def validate(cfg, model_nontree, model_tree, loss_fn, device, root): val_loss_meter_nontree = averageMeter() if cfg['training']['use_hierarchy']: val_loss_meter_level0_nontree = averageMeter() val_loss_meter_level1_nontree = averageMeter() val_loss_meter_level2_nontree = averageMeter() val_loss_meter_level3_nontree = averageMeter() val_loss_meter_tree = averageMeter() if cfg['training']['use_hierarchy']: val_loss_meter_level0_tree = averageMeter() val_loss_meter_level1_tree = averageMeter() val_loss_meter_level2_tree = averageMeter() val_loss_meter_level3_tree = averageMeter() if torch.cuda.is_available(): data_path = cfg['data']['server_path'] else: data_path = cfg['data']['path'] data_loader = get_loader(cfg['data']['dataset']) augmentations = cfg['training'].get('augmentations', None) data_aug = get_composed_augmentations(augmentations) v_loader = data_loader(data_path, is_transform=True, split=cfg['data']['val_split'], img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']), augmentations=data_aug) n_classes = v_loader.n_classes valloader = data.DataLoader(v_loader, batch_size=cfg['training']['batch_size'], num_workers=cfg['training']['n_workers']) # Setup Metrics running_metrics_val_nontree = runningScore(n_classes) running_metrics_val_tree = runningScore(n_classes) model_nontree.eval() model_tree.eval() with torch.no_grad(): print("validation loop") for i_val, (images_val, labels_val) in tqdm(enumerate(valloader)): images_val = images_val.to(device) labels_val = labels_val.to(device) outputs_nontree = model_nontree(images_val) outputs_tree = model_tree(images_val) if cfg['training']['use_tree_loss']: val_loss_nontree = loss_fn( input=outputs_nontree, target=labels_val, root=root, use_hierarchy=cfg['training']['use_hierarchy']) else: val_loss_nontree = loss_fn(input=outputs_nontree, target=labels_val) if cfg['training']['use_tree_loss']: val_loss_tree = loss_fn( input=outputs_tree, target=labels_val, root=root, use_hierarchy=cfg['training']['use_hierarchy']) else: val_loss_tree = loss_fn(input=outputs_tree, target=labels_val) # Using standard max prob based classification pred_nontree = outputs_nontree.data.max(1)[1].cpu().numpy() pred_tree = outputs_tree.data.max(1)[1].cpu().numpy() gt = labels_val.data.cpu().numpy() running_metrics_val_nontree.update( gt, pred_nontree) # updates confusion matrix running_metrics_val_tree.update(gt, pred_tree) if cfg['training']['use_tree_loss']: val_loss_meter_nontree.update( val_loss_nontree[1][0]) # take the 1st level else: val_loss_meter_nontree.update(val_loss_nontree.item()) if cfg['training']['use_tree_loss']: val_loss_meter_tree.update(val_loss_tree[0].item()) else: val_loss_meter_tree.update(val_loss_tree.item()) if cfg['training']['use_hierarchy']: val_loss_meter_level0_nontree.update(val_loss_nontree[1][0]) val_loss_meter_level1_nontree.update(val_loss_nontree[1][1]) val_loss_meter_level2_nontree.update(val_loss_nontree[1][2]) val_loss_meter_level3_nontree.update(val_loss_nontree[1][3]) if cfg['training']['use_hierarchy']: val_loss_meter_level0_tree.update(val_loss_tree[1][0]) val_loss_meter_level1_tree.update(val_loss_tree[1][1]) val_loss_meter_level2_tree.update(val_loss_tree[1][2]) val_loss_meter_level3_tree.update(val_loss_tree[1][3]) if i_val == 1: break score_nontree, class_iou_nontree = running_metrics_val_nontree.get_scores( ) score_tree, class_iou_tree = running_metrics_val_tree.get_scores() ### VISUALISE METRICS AND LOSSES HERE val_loss_meter_nontree.reset() running_metrics_val_nontree.reset() val_loss_meter_tree.reset() 
running_metrics_val_tree.reset() if cfg['training']['use_hierarchy']: val_loss_meter_level0_nontree.reset() val_loss_meter_level1_nontree.reset() val_loss_meter_level2_nontree.reset() val_loss_meter_level3_nontree.reset() if cfg['training']['use_hierarchy']: val_loss_meter_level0_tree.reset() val_loss_meter_level1_tree.reset() val_loss_meter_level2_tree.reset() val_loss_meter_level3_tree.reset()
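# validate() above leans on averageMeter instances for every tracked loss. A minimal
# sketch of such a running-average meter, consistent with how .update()/.val/.avg/.reset()
# are used throughout this file (the repository's own class may differ in details):
class averageMeter:
    def __init__(self):
        self.reset()

    def reset(self):
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0

    def update(self, val, n=1):
        # val is the latest scalar (e.g. a batch loss); n weights it (e.g. batch size).
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count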
def train(cfg, writer, logger, args): # Setup seeds torch.manual_seed(cfg.get('seed', 1337)) torch.cuda.manual_seed(cfg.get('seed', 1337)) np.random.seed(cfg.get('seed', 1337)) random.seed(cfg.get('seed', 1337)) # Setup device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # device = torch.device('cuda') # Setup Augmentations # augmentations = cfg['training'].get('augmentations', None) if cfg['data']['dataset'] in ['cityscapes']: augmentations = cfg['training'].get( 'augmentations', { 'brightness': 63. / 255., 'saturation': 0.5, 'contrast': 0.8, 'hflip': 0.5, 'rotate': 10, 'rscalecropsquare': 713, }) # augmentations = cfg['training'].get('augmentations', # {'rotate': 10, 'hflip': 0.5, 'rscalecrop': 512, 'gaussian': 0.5}) data_aug = get_composed_augmentations(augmentations) # Setup Dataloader data_loader = get_loader(cfg['data']['dataset']) data_path = cfg['data']['path'] t_loader = data_loader(data_path, is_transform=True, split=cfg['data']['train_split'], img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']), augmentations=data_aug) v_loader = data_loader( data_path, is_transform=True, split=cfg['data']['val_split'], img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']), ) n_classes = t_loader.n_classes trainloader = data.DataLoader(t_loader, batch_size=cfg['training']['batch_size'], num_workers=cfg['training']['n_workers'], shuffle=True) valloader = data.DataLoader(v_loader, batch_size=cfg['training']['batch_size'], num_workers=cfg['training']['n_workers']) # Setup Metrics running_metrics_val = runningScore(n_classes) # Setup Model model = get_model(cfg['model'], n_classes, args).to(device) model.apply(weights_init) print('sleep for 5 seconds') time.sleep(5) model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count())) # model = torch.nn.DataParallel(model, device_ids=(0, 1)) print(model.device_ids) # Setup optimizer, lr_scheduler and loss function optimizer_cls = get_optimizer(cfg) optimizer_params = { k: v for k, v in cfg['training']['optimizer'].items() if k != 'name' } optimizer = optimizer_cls(model.parameters(), **optimizer_params) logger.info("Using optimizer {}".format(optimizer)) scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule']) loss_fn = get_loss_function(cfg) logger.info("Using loss {}".format(loss_fn)) if 'multi_step' in cfg['training']['loss']['name']: my_loss_fn = loss_fn( scale_weight=cfg['training']['loss']['scale_weight'], n_inp=2, weight=None, reduction='sum', bkargs=args) else: my_loss_fn = loss_fn(weight=None, reduction='sum', bkargs=args) start_iter = 0 if cfg['training']['resume'] is not None: if os.path.isfile(cfg['training']['resume']): logger.info( "Loading model and optimizer from checkpoint '{}'".format( cfg['training']['resume'])) checkpoint = torch.load(cfg['training']['resume']) model.load_state_dict(checkpoint["model_state"]) optimizer.load_state_dict(checkpoint["optimizer_state"]) scheduler.load_state_dict(checkpoint["scheduler_state"]) start_iter = checkpoint["epoch"] logger.info("Loaded checkpoint '{}' (iter {})".format( cfg['training']['resume'], checkpoint["epoch"])) else: logger.info("No checkpoint found at '{}'".format( cfg['training']['resume'])) val_loss_meter = averageMeter() time_meter = averageMeter() best_iou = -100.0 i = start_iter flag = True while i <= cfg['training']['train_iters'] and flag: for (images, labels) in trainloader: i += 1 start_ts = time.time() scheduler.step() model.train() images = images.to(device) labels = labels.to(device) optimizer.zero_grad() 
outputs = model(images) loss = my_loss_fn(myinput=outputs, target=labels) loss.backward() optimizer.step() # gpu_profile(frame=sys._getframe(), event='line', arg=None) time_meter.update(time.time() - start_ts) if (i + 1) % cfg['training']['print_interval'] == 0: fmt_str = "Iter [{:d}/{:d}] Loss: {:.4f} Time/Image: {:.4f}" print_str = fmt_str.format( i + 1, cfg['training']['train_iters'], loss.item(), time_meter.avg / cfg['training']['batch_size']) print(print_str) logger.info(print_str) writer.add_scalar('loss/train_loss', loss.item(), i + 1) time_meter.reset() if (i + 1) % cfg['training']['val_interval'] == 0 or \ (i + 1) == cfg['training']['train_iters']: model.eval() with torch.no_grad(): for i_val, (images_val, labels_val) in tqdm(enumerate(valloader)): images_val = images_val.to(device) labels_val = labels_val.to(device) outputs = model(images_val) val_loss = my_loss_fn(myinput=outputs, target=labels_val) pred = outputs.data.max(1)[1].cpu().numpy() gt = labels_val.data.cpu().numpy() running_metrics_val.update(gt, pred) val_loss_meter.update(val_loss.item()) writer.add_scalar('loss/val_loss', val_loss_meter.avg, i + 1) logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg)) score, class_iou = running_metrics_val.get_scores() for k, v in score.items(): print(k, v) logger.info('{}: {}'.format(k, v)) writer.add_scalar('val_metrics/{}'.format(k), v, i + 1) for k, v in class_iou.items(): logger.info('{}: {}'.format(k, v)) writer.add_scalar('val_metrics/cls_{}'.format(k), v, i + 1) val_loss_meter.reset() running_metrics_val.reset() if score["Mean IoU : \t"] >= best_iou: best_iou = score["Mean IoU : \t"] state = { "epoch": i + 1, "model_state": model.state_dict(), "optimizer_state": optimizer.state_dict(), "scheduler_state": scheduler.state_dict(), "best_iou": best_iou, } save_path = os.path.join( writer.file_writer.get_logdir(), "{}_{}_best_model.pkl".format(cfg['model']['arch'], cfg['data']['dataset'])) torch.save(state, save_path) if (i + 1) == cfg['training']['train_iters']: flag = False break
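# This variant applies model.apply(weights_init) before wrapping the model in
# DataParallel. A common sketch of such an init hook, assuming Kaiming initialisation
# for conv/linear layers and unit-gamma batch norm (the repository's weights_init may
# differ):
import torch.nn as nn

def weights_init(m):
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)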
params['features.' + str(layer_idx) + '.weight'] = w

net.train()
augmentations = Compose([RandomRotate(5), RandomHorizontallyFlip()])
train_dataset = tusimpleLoader('/home/tejus/Downloads/train_set/',
                               split="train",
                               augmentations=augmentations)
trainloader = torch.utils.data.DataLoader(train_dataset,
                                          batch_size=TRAIN_BATCH,
                                          shuffle=True,
                                          num_workers=TRAIN_BATCH,
                                          pin_memory=True)
val_dataset = tusimpleLoader('/home/tejus/Downloads/train_set/',
                             split="val",
                             augmentations=None)
valloader = torch.utils.data.DataLoader(val_dataset,
                                        batch_size=VAL_BATCH,
                                        shuffle=True,
                                        num_workers=VAL_BATCH,
                                        pin_memory=True)

running_metrics_val = runningScore(2)
best_val_loss = math.inf
val_loss = 0
ctr = 0
best_iou = -100
val_loss_meter = averageMeter()
time_meter = averageMeter()

for EPOCHS in range(50):
    # Training
    net.train()
    running_loss = 0
    for i, data in enumerate(trainloader):
        start_ts = time.time()
        imgs, labels = data
        imgs, labels = imgs.to(device), labels.to(device)
        out = net(imgs)
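        # --- illustrative continuation (not part of the original snippet) ---
        # The snippet stops right after the forward pass; the usual next steps,
        # mirroring the other training loops in this file, are sketched below.
        # `criterion` and `optimizer` are hypothetical and assumed to have been
        # constructed earlier.
        loss = criterion(out, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        time_meter.update(time.time() - start_ts)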
def train(cfg, writer, logger, start_iter=0, model_only=False, gpu=-1, save_dir=None): # Setup seeds and config torch.manual_seed(cfg.get("seed", 1337)) torch.cuda.manual_seed(cfg.get("seed", 1337)) np.random.seed(cfg.get("seed", 1337)) random.seed(cfg.get("seed", 1337)) # Setup device if gpu == -1: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") else: device = torch.device("cuda:%d" %gpu if torch.cuda.is_available() else "cpu") # Setup Augmentations augmentations = cfg["training"].get("augmentations", None) if cfg["data"]["dataset"] == "softmax_cityscapes_convention": data_aug = get_composed_augmentations_softmax(augmentations) else: data_aug = get_composed_augmentations(augmentations) # Setup Dataloader data_loader = get_loader(cfg["data"]["dataset"]) data_path = cfg["data"]["path"] t_loader = data_loader( data_path, config = cfg["data"], is_transform=True, split=cfg["data"]["train_split"], img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]), augmentations=data_aug, ) v_loader = data_loader( data_path, config = cfg["data"], is_transform=True, split=cfg["data"]["val_split"], img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]), ) sampler = None if "sampling" in cfg["data"]: sampler = data.WeightedRandomSampler( weights = get_sampling_weights(t_loader, cfg["data"]["sampling"]), num_samples = len(t_loader), replacement = True ) n_classes = t_loader.n_classes trainloader = data.DataLoader( t_loader, batch_size=cfg["training"]["batch_size"], num_workers=cfg["training"]["n_workers"], sampler=sampler, shuffle=sampler==None, ) valloader = data.DataLoader( v_loader, batch_size=cfg["training"]["batch_size"], num_workers=cfg["training"]["n_workers"] ) # Setup Metrics running_metrics_val = {"seg": runningScoreSeg(n_classes)} if "classifiers" in cfg["data"]: for name, classes in cfg["data"]["classifiers"].items(): running_metrics_val[name] = runningScoreClassifier( len(classes) ) if "bin_classifiers" in cfg["data"]: for name, classes in cfg["data"]["bin_classifiers"].items(): running_metrics_val[name] = runningScoreClassifier(2) # Setup Model model = get_model(cfg["model"], n_classes).to(device) total_params = sum(p.numel() for p in model.parameters()) print( 'Parameters:',total_params ) if gpu == -1: model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count())) else: model = torch.nn.DataParallel(model, device_ids=[gpu]) model.apply(weights_init) pretrained_path='weights/hardnet_petite_base.pth' weights = torch.load(pretrained_path) model.module.base.load_state_dict(weights) # Setup optimizer, lr_scheduler and loss function optimizer_cls = get_optimizer(cfg) optimizer_params = {k: v for k, v in cfg["training"]["optimizer"].items() if k != "name"} optimizer = optimizer_cls(model.parameters(), **optimizer_params) print("Using optimizer {}".format(optimizer)) scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"]) loss_dict = get_loss_function(cfg, device) if cfg["training"]["resume"] is not None: if os.path.isfile(cfg["training"]["resume"]): logger.info( "Loading model and optimizer from checkpoint '{}'".format(cfg["training"]["resume"]) ) checkpoint = torch.load(cfg["training"]["resume"], map_location=device) model.load_state_dict(checkpoint["model_state"], strict=False) if not model_only: optimizer.load_state_dict(checkpoint["optimizer_state"]) scheduler.load_state_dict(checkpoint["scheduler_state"]) start_iter = checkpoint["epoch"] logger.info( "Loaded checkpoint '{}' (iter {})".format( cfg["training"]["resume"], 
checkpoint["epoch"] ) ) else: logger.info("No checkpoint found at '{}'".format(cfg["training"]["resume"])) if cfg["training"]["finetune"] is not None: if os.path.isfile(cfg["training"]["finetune"]): logger.info( "Loading model and optimizer from checkpoint '{}'".format(cfg["training"]["finetune"]) ) checkpoint = torch.load(cfg["training"]["finetune"]) model.load_state_dict(checkpoint["model_state"]) val_loss_meter = averageMeter() time_meter = averageMeter() best_iou = -100.0 i = start_iter flag = True loss_all = 0 loss_n = 0 while i <= cfg["training"]["train_iters"] and flag: for (images, label_dict, _) in trainloader: i += 1 start_ts = time.time() scheduler.step() model.train() images = images.to(device) optimizer.zero_grad() output_dict = model(images) loss = compute_loss( # considers key names in loss_dict and output_dict loss_dict, images, label_dict, output_dict, device, t_loader ) loss.backward() # backprops sum of loss tensors, frozen components will have no grad_fn optimizer.step() c_lr = scheduler.get_lr() if i%1000 == 0: # log images, seg ground truths, predictions pred_array = output_dict["seg"].data.max(1)[1].cpu().numpy() gt_array = label_dict["seg"].data.cpu().numpy() softmax_gt_array = None if "softmax" in label_dict: softmax_gt_array = label_dict["softmax"].data.max(1)[1].cpu().numpy() write_images_to_board(t_loader, images, gt_array, pred_array, i, name = 'train', softmax_gt = softmax_gt_array) if save_dir is not None: image_array = images.data.cpu().numpy().transpose(0, 2, 3, 1) write_images_to_dir(t_loader, image_array, gt_array, pred_array, i, save_dir, name = 'train', softmax_gt = softmax_gt_array) time_meter.update(time.time() - start_ts) loss_all += loss.item() loss_n += 1 if (i + 1) % cfg["training"]["print_interval"] == 0: fmt_str = "Iter [{:d}/{:d}] Loss: {:.4f} Time/Image: {:.4f} lr={:.6f}" print_str = fmt_str.format( i + 1, cfg["training"]["train_iters"], loss_all / loss_n, time_meter.avg / cfg["training"]["batch_size"], c_lr[0], ) print(print_str) logger.info(print_str) writer.add_scalar("loss/train_loss", loss.item(), i + 1) time_meter.reset() if (i + 1) % cfg["training"]["val_interval"] == 0 or (i + 1) == cfg["training"][ "train_iters" ]: torch.cuda.empty_cache() model.eval() # set batchnorm and dropouts to work in eval mode loss_all = 0 loss_n = 0 with torch.no_grad(): # Deactivate torch autograd engine, less memusage for i_val, (images_val, label_dict_val, _) in tqdm(enumerate(valloader)): images_val = images_val.to(device) output_dict = model(images_val) val_loss = compute_loss( loss_dict, images_val, label_dict_val, output_dict, device, v_loader ) val_loss_meter.update(val_loss.item()) for name, metrics in running_metrics_val.items(): gt_array = label_dict_val[name].data.cpu().numpy() if name+'_loss' in cfg['training'] and cfg['training'][name+'_loss']['name'] == 'l1': # for binary classification pred_array = output_dict[name].data.cpu().numpy() pred_array = np.sign(pred_array) pred_array[pred_array == -1] = 0 gt_array[gt_array == -1] = 0 else: pred_array = output_dict[name].data.max(1)[1].cpu().numpy() metrics.update(gt_array, pred_array) softmax_gt_array = None # log validation images pred_array = output_dict["seg"].data.max(1)[1].cpu().numpy() gt_array = label_dict_val["seg"].data.cpu().numpy() if "softmax" in label_dict_val: softmax_gt_array = label_dict_val["softmax"].data.max(1)[1].cpu().numpy() write_images_to_board(v_loader, images_val, gt_array, pred_array, i, 'validation', softmax_gt = softmax_gt_array) if save_dir is not None: images_val = 
images_val.cpu().numpy().transpose(0, 2, 3, 1) write_images_to_dir(v_loader, images_val, gt_array, pred_array, i, save_dir, name='validation', softmax_gt = softmax_gt_array) logger.info("Iter %d Val Loss: %.4f" % (i + 1, val_loss_meter.avg)) writer.add_scalar("loss/val_loss", val_loss_meter.avg, i + 1) for name, metrics in running_metrics_val.items(): overall, classwise = metrics.get_scores() for k, v in overall.items(): logger.info("{}_{}: {}".format(name, k, v)) writer.add_scalar("val_metrics/{}_{}".format(name, k), v, i + 1) if k == cfg["training"]["save_metric"]: curr_performance = v for metric_name, metric in classwise.items(): for k, v in metric.items(): logger.info("{}_{}_{}: {}".format(name, metric_name, k, v)) writer.add_scalar("val_metrics/{}_{}_{}".format(name, metric_name, k), v, i + 1) metrics.reset() state = { "epoch": i + 1, "model_state": model.state_dict(), "optimizer_state": optimizer.state_dict(), "scheduler_state": scheduler.state_dict(), } save_path = os.path.join( writer.file_writer.get_logdir(), "{}_{}_checkpoint.pkl".format(cfg["model"]["arch"], cfg["data"]["dataset"]), ) torch.save(state, save_path) if curr_performance >= best_iou: best_iou = curr_performance state = { "epoch": i + 1, "model_state": model.state_dict(), "best_iou": best_iou, } save_path = os.path.join( writer.file_writer.get_logdir(), "{}_{}_best_model.pkl".format(cfg["model"]["arch"], cfg["data"]["dataset"]), ) torch.save(state, save_path) torch.cuda.empty_cache() if (i + 1) == cfg["training"]["train_iters"]: flag = False break
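# When cfg["data"]["sampling"] is set, the loader above draws training samples with a
# WeightedRandomSampler instead of shuffling. A sketch of building such a sampler from
# per-sample class ids so that rare classes are over-sampled (the repository's
# get_sampling_weights may weight samples differently):
import numpy as np
import torch
from torch.utils import data

def make_balancing_sampler(sample_class_ids):
    """sample_class_ids: one dominant class id per training sample."""
    counts = np.bincount(sample_class_ids)
    per_class_weight = 1.0 / np.maximum(counts, 1)    # inverse class frequency
    weights = per_class_weight[sample_class_ids]      # one weight per sample
    return data.WeightedRandomSampler(
        weights=torch.as_tensor(weights, dtype=torch.double),
        num_samples=len(sample_class_ids),
        replacement=True,
    )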
def test(cfg, areaname): # Setup seeds torch.manual_seed(cfg.get('seed', 1337)) torch.cuda.manual_seed(cfg.get('seed', 1337)) np.random.seed(cfg.get('seed', 1337)) random.seed(cfg.get('seed', 1337)) # Setup device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Setup Augmentations augmentations = cfg['training'].get('augmentations', None) data_aug = get_composed_augmentations(augmentations) # Setup Dataloader # data_loader = get_loader(cfg['data']['dataset']) # data_path = cfg['data']['path'] # # t_loader = data_loader( # data_path, # is_transform=True, # split=cfg['data']['train_split'], # img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']), # augmentations=data_aug) # # v_loader = data_loader( # data_path, # is_transform=True, # split=cfg['data']['val_split'], # img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),) # # n_classes = t_loader.n_classes # trainloader = data.DataLoader(t_loader, # batch_size=cfg['training']['batch_size'], # num_workers=cfg['training']['n_workers'], # shuffle=True) # # valloader = data.DataLoader(v_loader, # batch_size=cfg['training']['batch_size'], # num_workers=cfg['training']['n_workers']) datapath = '/home/chengjjang/Projects/deepres/SatelliteData/{}/'.format( areaname) paths = { 'masks': '{}/patch{}_train/gt'.format(datapath, areaname), 'images': '{}/patch{}_train/rgb'.format(datapath, areaname), 'nirs': '{}/patch{}_train/nir'.format(datapath, areaname), 'swirs': '{}/patch{}_train/swir'.format(datapath, areaname), 'vhs': '{}/patch{}_train/vh'.format(datapath, areaname), 'vvs': '{}/patch{}_train/vv'.format(datapath, areaname), 'redes': '{}/patch{}_train/rede'.format(datapath, areaname), 'ndvis': '{}/patch{}_train/ndvi'.format(datapath, areaname), } valpaths = { 'masks': '{}/patch{}_val/gt'.format(datapath, areaname), 'images': '{}/patch{}_val/rgb'.format(datapath, areaname), 'nirs': '{}/patch{}_val/nir'.format(datapath, areaname), 'swirs': '{}/patch{}_val/swir'.format(datapath, areaname), 'vhs': '{}/patch{}_val/vh'.format(datapath, areaname), 'vvs': '{}/patch{}_val/vv'.format(datapath, areaname), 'redes': '{}/patch{}_val/rede'.format(datapath, areaname), 'ndvis': '{}/patch{}_val/ndvi'.format(datapath, areaname), } n_classes = 3 train_img_paths = [ pth for pth in os.listdir(paths['images']) if ('_01_' not in pth) and ('_25_' not in pth) ] val_img_paths = [ pth for pth in os.listdir(valpaths['images']) if ('_01_' not in pth) and ('_25_' not in pth) ] ntrain = len(train_img_paths) nval = len(val_img_paths) train_idx = [i for i in range(ntrain)] val_idx = [i for i in range(nval)] train_idx = [i for i in range(ntrain)] val_idx = [i for i in range(nval)] trainds = ImageProvider(MultibandImageType, paths, image_suffix='.png') valds = ImageProvider(MultibandImageType, valpaths, image_suffix='.png') print('valds.im_names: {}'.format(valds.im_names)) config_path = 'crop_pspnet_config.json' with open(config_path, 'r') as f: mycfg = json.load(f) train_data_path = '{}/patch{}_train'.format(datapath, areaname) dataset_path, train_dir = os.path.split(train_data_path) mycfg['dataset_path'] = dataset_path config = Config(**mycfg) config = update_config(config, num_channels=12, nb_epoch=50) #dataset_train = TrainDataset(trainds, train_idx, config, transforms=augment_flips_color) dataset_train = TrainDataset(trainds, train_idx, config, 1) dataset_val = ValDataset(valds, val_idx, config, 1) trainloader = data.DataLoader(dataset_train, batch_size=cfg['training']['batch_size'], num_workers=cfg['training']['n_workers'], shuffle=True) 
valloader = data.DataLoader(dataset_val, batch_size=cfg['training']['batch_size'], num_workers=cfg['training']['n_workers'], shuffle=False) # Setup Metrics running_metrics_train = runningScore(n_classes) running_metrics_val = runningScore(n_classes) nbackground = 1116403140 ncorn = 44080178 nsoybean = 316698122 print('nbackgraound: {}'.format(nbackground)) print('ncorn: {}'.format(ncorn)) print('nsoybean: {}'.format(nsoybean)) wgts = [1.0, 1.0 * nbackground / ncorn, 1.0 * nbackground / nsoybean] total_wgts = sum(wgts) wgt_background = wgts[0] / total_wgts wgt_corn = wgts[1] / total_wgts wgt_soybean = wgts[2] / total_wgts weights = torch.autograd.Variable( torch.cuda.FloatTensor([wgt_background, wgt_corn, wgt_soybean])) # Setup Model model = get_model(cfg['model'], n_classes).to(device) model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count())) # Setup optimizer, lr_scheduler and loss function optimizer_cls = get_optimizer(cfg) optimizer_params = { k: v for k, v in cfg['training']['optimizer'].items() if k != 'name' } optimizer = optimizer_cls(model.parameters(), **optimizer_params) scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule']) loss_fn = get_loss_function(cfg) start_iter = 0 runpath = '/home/chengjjang/arisia/CropPSPNet/runs/pspnet_crop_{}'.format( areaname) modelpath = glob.glob('{}/*/*_best_model.pkl'.format(runpath))[0] print('modelpath: {}'.format(modelpath)) checkpoint = torch.load(modelpath) model.load_state_dict(checkpoint["model_state"]) val_loss_meter = averageMeter() time_meter = averageMeter() best_iou = -100.0 respath = '{}_results_val'.format(areaname) os.makedirs(respath, exist_ok=True) model.eval() with torch.no_grad(): for inputdata in valloader: imname_val = inputdata['img_name'] images_val = inputdata['img_data'] labels_val = inputdata['seg_label'] images_val = images_val.to(device) labels_val = labels_val.to(device) print('imname_val: {}'.format(imname_val)) outputs = model(images_val) val_loss = loss_fn(input=outputs, target=labels_val) pred = outputs.data.max(1)[1].cpu().numpy() gt = labels_val.data.cpu().numpy() dname = imname_val[0].split('.png')[0] np.save('{}/pred'.format(respath) + dname + '.npy', pred) np.save('{}/gt'.format(respath) + dname + '.npy', gt) np.save('{}/output'.format(respath) + dname + '.npy', outputs.data.cpu().numpy()) running_metrics_val.update(gt, pred) val_loss_meter.update(val_loss.item()) #writer.add_scalar('loss/val_loss', val_loss_meter.avg, i+1) #logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg)) print('Test loss: {}'.format(val_loss_meter.avg)) score, class_iou = running_metrics_val.get_scores() for k, v in score.items(): print('val_metrics, {}: {}'.format(k, v)) for k, v in class_iou.items(): print('val_metrics, {}: {}'.format(k, v)) val_loss_meter.reset() running_metrics_val.reset()
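# The class weights above are inverse pixel frequencies with background as the reference
# class, normalised to sum to one. The same arithmetic as a small helper (sketch):
import torch

def inverse_frequency_weights(pixel_counts):
    """pixel_counts: per-class pixel totals; the first entry is the reference class."""
    ref = float(pixel_counts[0])
    wgts = [ref / c for c in pixel_counts]
    total = sum(wgts)
    return torch.tensor([w / total for w in wgts])

# e.g. inverse_frequency_weights([1116403140, 44080178, 316698122])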
def train(cfg, writer, logger): # Setup seeds torch.manual_seed(cfg.get("seed", 1337)) torch.cuda.manual_seed(cfg.get("seed", 1337)) np.random.seed(cfg.get("seed", 1337)) random.seed(cfg.get("seed", 1337)) # Setup device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(device) # Setup Augmentations augmentations = cfg["training"].get("augmentations", None) data_aug = get_composed_augmentations(augmentations) # Setup Dataloader data_loader = cityscapesLoader data_path = cfg["data"]["path"] t_loader = data_loader( data_path, is_transform=True, split=cfg["data"]["train_split"], img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]), augmentations=data_aug, ) v_loader = data_loader( data_path, is_transform=True, split=cfg["data"]["val_split"], img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]), ) n_classes = t_loader.n_classes trainloader = data.DataLoader( t_loader, batch_size=cfg["training"]["batch_size"], num_workers=cfg["training"]["n_workers"], shuffle=True, ) valloader = data.DataLoader(v_loader, batch_size=cfg["training"]["batch_size"], num_workers=cfg["training"]["n_workers"]) # Setup Metrics running_metrics_val = runningScore(n_classes) # Setup Model model = get_model(cfg["model"], n_classes).to(device) model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count())) # Setup optimizer, lr_scheduler and loss function optimizer_cls = get_optimizer(cfg) optimizer_params = { k: v for k, v in cfg["training"]["optimizer"].items() if k != "name" } optimizer = optimizer_cls(model.parameters(), **optimizer_params) logger.info("Using optimizer {}".format(optimizer)) scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"]) loss_fn = get_loss_function(cfg) logger.info("Using loss {}".format(loss_fn)) start_iter = 0 if cfg["training"]["resume"] is not None: if os.path.isfile(cfg["training"]["resume"]): logger.info( "Loading model and optimizer from checkpoint '{}'".format( cfg["training"]["resume"])) checkpoint = torch.load(cfg["training"]["resume"]) model.load_state_dict(checkpoint["model_state"]) optimizer.load_state_dict(checkpoint["optimizer_state"]) scheduler.load_state_dict(checkpoint["scheduler_state"]) start_iter = checkpoint["epoch"] logger.info("Loaded checkpoint '{}' (iter {})".format( cfg["training"]["resume"], checkpoint["epoch"])) else: logger.info("No checkpoint found at '{}'".format( cfg["training"]["resume"])) val_loss_meter = averageMeter() # get loss_seg meter and also loss_dep meter val_loss_meter = averageMeter() # loss_seg_meter = averageMeter() # loss_dep_meter = averageMeter() time_meter = averageMeter() acc_result_total = averageMeter() acc_result_correct = averageMeter() best_iou = -100.0 i = start_iter flag = True while i <= cfg["training"]["train_iters"] and flag: for (images, masks, depths) in trainloader: i += 1 start_ts = time.time() scheduler.step() model.train() images = images.to(device) depths = depths.to(device) # print(images.shape) optimizer.zero_grad() outputs = model(images).squeeze(1) # ----------------------------------------------------------------- # add depth loss # ----------------------------------------------------------------- # MSE loss # loss_dep = F.mse_loss(input=outputs[:, -1,:,:], target=depths, reduction='mean') # ----------------------------------------------------------------- # Berhu loss; loss_dep = loss loss = berhu_loss_function(prediction=outputs, target=depths) masks = masks.type(torch.cuda.ByteTensor) loss = torch.sum(loss[masks]) / torch.sum(masks) # 
----------------------------------------------------------------- loss.backward() optimizer.step() time_meter.update(time.time() - start_ts) if (i + 1) % cfg["training"]["print_interval"] == 0: fmt_str = "Iter [{:d}/{:d}] loss_dep: {:.4f} Time/Image: {:.4f}" print_str = fmt_str.format( i + 1, cfg["training"]["train_iters"], loss.item(), time_meter.avg / cfg["training"]["batch_size"]) print(print_str) logger.info(print_str) writer.add_scalar("loss/train_loss", loss.item(), i + 1) time_meter.reset() if (i + 1) % cfg["training"]["val_interval"] == 0 or ( i + 1) == cfg["training"]["train_iters"]: model.eval() with torch.no_grad(): for i_val, (images_val, masks_val, depths_val) in enumerate(valloader): images_val = images_val.to(device) # add depth to device depths_val = depths_val.to(device) outputs = model(images_val).squeeze(1) # depths_val = depths_val.data.resize_(depths_val.size(0), outputs.size(2), outputs.size(3)) # ----------------------------------------------------------------- # berhu loss function val_loss = berhu_loss_function(prediction=outputs, target=depths_val) masks_val = masks_val.type(torch.cuda.ByteTensor) val_loss = val_loss.type(torch.cuda.ByteTensor) print('val_loss1 is', val_loss) val_loss = torch.sum( val_loss[masks_val]) / torch.sum(masks_val) print('val_loss2 is', val_loss) # ----------------------------------------------------------------- # Update val_loss_meter.update(val_loss.item()) outputs = outputs.cpu().numpy() depths_val = depths_val.cpu().numpy() masks_val = masks_val.cpu().numpy() # depths_val = depths_val.type(torch.cuda.FloatTensor) # outputs = outputs.type(torch.cuda.FloatTensor) # ----------------------------------------------------------------- # Try the following against error: # RuntimeWarning: invalid value encountered in double_scalars: acc = np.diag(hist).sum() / hist.sum() # Similar error: https://github.com/meetshah1995/pytorch-semseg/issues/118 acc_1 = outputs / depths_val acc_2 = 1 / acc_1 acc_threshold = np.maximum(acc_1, acc_2) acc_result_total.update(np.sum(masks_val)) acc_result_correct.update( np.sum( np.logical_and(acc_threshold < 1.25, masks_val))) print("Iter {:d}, val_loss {:.4f}".format( i + 1, val_loss_meter.avg)) writer.add_scalar("loss/val_loss", val_loss_meter.avg, i + 1) logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg)) acc_result = float(acc_result_correct.sum) / float( acc_result_total.sum) print("Iter {:d}, acc_1.25 {:.4f}".format(i + 1, acc_result)) logger.info("Iter %d acc_1.25: %.4f" % (i + 1, acc_result)) # ----------------------------------------------------------------- score, class_iou = running_metrics_val.get_scores() for k, v in score.items(): print(k, v) logger.info("{}: {}".format(k, v)) writer.add_scalar("val_metrics/{}".format(k), v, i + 1) for k, v in class_iou.items(): logger.info("{}: {}".format(k, v)) writer.add_scalar("val_metrics/cls_{}".format(k), v, i + 1) val_loss_meter.reset() acc_result_total.reset() acc_result_correct.reset() running_metrics_val.reset() if score["Mean IoU : \t"] >= best_iou: best_iou = score["Mean IoU : \t"] state = { "epoch": i + 1, "model_state": model.state_dict(), "optimizer_state": optimizer.state_dict(), "scheduler_state": scheduler.state_dict(), "best_iou": best_iou, } save_path = os.path.join( writer.file_writer.get_logdir(), "{}_{}_best_model.pkl".format(cfg["model"]["arch"], cfg["data"]["dataset"]), ) torch.save(state, save_path) # insert print function to see if the losses are correct if (i + 1) == cfg["training"]["train_iters"]: flag = False break
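# The depth head above is trained with a BerHu (reverse Huber) loss that is masked to
# valid pixels before reduction. A standard per-pixel formulation (the repository's
# berhu_loss_function may differ in its threshold choice); returning an unreduced map
# lets the caller apply the mask, as in torch.sum(loss[masks]) / torch.sum(masks):
import torch

def berhu_loss(prediction, target):
    diff = torch.abs(prediction - target)
    c = torch.clamp(0.2 * diff.max(), min=1e-6)    # threshold: 20% of the largest residual
    quadratic = (diff ** 2 + c ** 2) / (2 * c)     # quadratic branch beyond the threshold
    return torch.where(diff <= c, diff, quadratic)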
def train(cfg, writer, logger): triplet_mode = True if args.triplet_mode == 'yes' else False # Setup dataset split before setting up the seed for random if cfg['data']['dataset'] == 'thigh': # data_split_info = init_data_split(cfg['data']['path'], cfg['data'].get('split_ratio', 0), cfg['data'].get('compound', False)) # fly jenelia dataset' subject_names = [ f"MSTHIGH_{i:02d}" for i in range(3, 16) if i != 8 and i != 13 ] elif cfg['data']['dataset'] == 'femur': subject_names = [ f"MSTHIGH_{i:02d}" for i in range(3, 16) if i != 8 and i != 13 ] # femur_data_split(cfg['data']['path'], subject_names, ratio=cfg['data']['split_ratio']) # Setup seeds torch.manual_seed(cfg.get('seed', 1337)) torch.cuda.manual_seed(cfg.get('seed', 1337)) np.random.seed(cfg.get('seed', 1337)) random.seed(cfg.get('seed', 1337)) # Setup device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") log('Using loss : {}'.format(cfg['training']['loss']['name'])) # Setup Augmentations augmentations = cfg['training'].get( 'augmentations', None) # if no augmentation => default None data_aug = get_composed_augmentations(augmentations) # Setup Dataloader data_loader = get_loader(cfg['data']['dataset']) data_path = cfg['data']['path'] triplet_save_path = cfg['data']['triplet_save_path'] t_loader = data_loader(data_path, triplet_save_path, split=cfg['data']['train_split'], augmentations=data_aug, n_classes=cfg['training'].get('n_classes', 2), patch_size=cfg['data']['patch_size'], triplet_mode=triplet_mode) # # If using validation, uncomment this block # v_loader = data_loader( # data_path, # split=cfg['data']['val_split'], # data_split_info=data_split_info, # n_classe=cfg['training'].get('n_classes', 1)) n_classes = t_loader.n_classes log('n_classes is: {}'.format(n_classes)) trainloader = data.DataLoader(t_loader, batch_size=cfg['training']['batch_size'], num_workers=cfg['training']['n_workers'], shuffle=False) print('trainloader len: ', len(trainloader)) # Setup Metrics running_metrics_val = runningScore( n_classes) # a confusion matrix is created # Setup Model model = get_model(cfg['model'], n_classes) if triplet_mode: model_triplet = get_model(cfg['model_triplet'], n_classes) model = model.to(device) if triplet_mode: model_triplet = model_triplet.to(device) model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count())) count_parameters(model, verbose=False) if triplet_mode: model_triplet = torch.nn.DataParallel(model_triplet, device_ids=range( torch.cuda.device_count())) # print(range(torch.cuda.device_count())) count_parameters(model_triplet, verbose=False) # Setup optimizer, lr_scheduler and loss function optimizer_cls = get_optimizer(cfg) optimizer_params = { k: v for k, v in cfg['training']['optimizer'].items() if k != 'name' } optimizer = optimizer_cls(model.parameters(), **optimizer_params) logger.info("Using optimizer {}".format(optimizer)) scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule']) min_loss = None start_iter = 0 val_loss_meter = averageMeter() time_meter = averageMeter() i_train_iter = start_iter display('Training from {}th iteration\n'.format(i_train_iter)) while i_train_iter < cfg['training']['train_iters']: i_batch_idx = 0 train_iter_start_time = time.time() averageLoss = 0 # training for (img, lbl, A, P, N, A_lbl, P_lbl, N_lbl) in trainloader: start_ts = time.time() model.train() if triplet_mode: model_triplet.train() optimizer.zero_grad() img = img.to(device) lbl = lbl.to(device) pred = model(img) if triplet_mode: invalid_in_batch_triplet = False for 
batch_i in range(len(A[0])): if torch.eq(torch.count_nonzero(A[batch_i]), torch.tensor(0)): invalid_in_batch_triplet = True print("found invalid batch, index", i_batch_idx) break if triplet_mode: loss_triplet = 0.0 if not invalid_in_batch_triplet: A = A.to(device) P = P.to(device) N = N.to(device) A_lbl = A_lbl.to(device) P_lbl = P_lbl.to(device) N_lbl = N_lbl.to(device) (A_p, P_p, N_p), (A_embed, P_embed, N_embed) = model_triplet(A, P, N) dist_ap = F.pairwise_distance(A_embed, P_embed, 2) dist_an = F.pairwise_distance(A_embed, N_embed, 2) # -1 means, dist_ap should be less than dist_an target = torch.FloatTensor(dist_ap.size()).fill_(-1) target = target.to(device) loss_triplet = nn.MarginRankingLoss(margin=1.0)(dist_ap, dist_an, target) loss_mse_added = nn.MSELoss()(A_p, A_lbl) + nn.MSELoss()( P_p, P_lbl) + nn.MSELoss()(N_p, N_lbl) loss_dice = dice_loss()(pred, lbl) if triplet_mode: loss = float(args.dice_weight) * loss_dice + float( args.triplet_weight) * loss_triplet else: loss = float(args.dice_weight) * loss_dice # print('loss_dice match: ', loss_dice.item()) # print('loss_triplet match: ', loss_triplet.item()) # print('loss match: ', loss.item()) averageLoss += loss.item() loss.backward() # print('{} optim: {}'.format(i, optimizer.param_groups[0]['lr'])) optimizer.step() # print('{} scheduler: {}'.format(i, scheduler.get_lr()[0])) scheduler.step() time_meter.update(time.time() - start_ts) print_per_batch_check = True if cfg['training'][ 'print_interval_per_batch'] else i_batch_idx + 1 == len( trainloader) if (i_train_iter + 1) % cfg['training'][ 'print_interval'] == 0 and print_per_batch_check: fmt_str = "Iter [{:d}/{:d}] Loss: {:.4f} Time/Image: {:.4f}" print_str = fmt_str.format( i_train_iter + 1, cfg['training']['train_iters'], loss.item(), time_meter.avg / cfg['training']['batch_size']) display(print_str) writer.add_scalar('loss/train_loss', loss.item(), i_train_iter + 1) time_meter.reset() i_batch_idx += 1 time_for_one_iteration = time.time() - train_iter_start_time display( 'EntireTime for {}th training iteration: {} EntireTime/Image: {}'. format( i_train_iter + 1, time_converter(time_for_one_iteration), time_converter( time_for_one_iteration / (len(trainloader) * cfg['training']['batch_size'])))) averageLoss /= (len(trainloader) * cfg['training']['batch_size']) # validation validation_check = (i_train_iter + 1) % cfg['training']['val_interval'] == 0 or \ (i_train_iter + 1) == cfg['training']['train_iters'] if not validation_check: print('no validation check') else: ''' This IF-CHECK is used to update the best model ''' log('Validation: average loss for current iteration is: {}'.format( averageLoss)) if min_loss is None: min_loss = averageLoss if averageLoss <= min_loss: min_loss = averageLoss state = { "epoch": i_train_iter + 1, "model_state": model.state_dict(), # "model_triplet_state": model_triplet.state_dict() if triplet_mode else None, "optimizer_state": optimizer.state_dict(), "scheduler_state": scheduler.state_dict(), "min_loss": min_loss } save_path = os.path.join( writer.file_writer.get_logdir(), "{}_{}_model_best.pkl".format(cfg['model']['arch'], cfg['data']['dataset'])) print('save_path is: ' + save_path) torch.save(state, save_path) # model_count += 1 i_train_iter += 1
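# -----------------------------------------------------------------
# Illustration only: a minimal sketch of the triplet ranking term used in
# the triplet_mode branch above, assuming `embed_a`, `embed_p`, `embed_n`
# are (B, D) embeddings from the triplet branch. A target of -1 tells
# MarginRankingLoss that the first distance (anchor-positive) should be
# the smaller one, matching the comment in the loop above.
import torch
import torch.nn as nn
import torch.nn.functional as F


def triplet_ranking_loss(embed_a, embed_p, embed_n, margin=1.0):
    dist_ap = F.pairwise_distance(embed_a, embed_p, p=2)  # (B,)
    dist_an = F.pairwise_distance(embed_a, embed_n, p=2)  # (B,)
    target = torch.full_like(dist_ap, -1.0)  # want dist_ap < dist_an
    return nn.MarginRankingLoss(margin=margin)(dist_ap, dist_an, target)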
def train(cfg, writer, logger_old, args): # Setup seeds torch.manual_seed(cfg.get('seed', 1337)) torch.cuda.manual_seed(cfg.get('seed', 1337)) np.random.seed(cfg.get('seed', 1337)) random.seed(cfg.get('seed', 1337)) # Setup device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Setup Augmentations augmentations = cfg['training'].get('augmentations', None) data_aug = get_composed_augmentations(augmentations) # Setup Dataloader data_loader = get_loader(cfg['data']['dataset']) data_path = cfg['data']['path'] if isinstance(cfg['training']['loss']['superpixels'], int): use_superpixels = True cfg['data']['train_split'] = 'train_super' cfg['data']['val_split'] = 'val_super' setup_superpixels(cfg['training']['loss']['superpixels']) elif cfg['training']['loss']['superpixels'] is not None: raise Exception( "cfg['training']['loss']['superpixels'] is of the wrong type") else: use_superpixels = False t_loader = data_loader(data_path, is_transform=True, split=cfg['data']['train_split'], superpixels=cfg['training']['loss']['superpixels'], img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']), augmentations=data_aug) v_loader = data_loader( data_path, is_transform=True, split=cfg['data']['val_split'], superpixels=cfg['training']['loss']['superpixels'], img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']), ) n_classes = t_loader.n_classes trainloader = data.DataLoader(t_loader, batch_size=cfg['training']['batch_size'], num_workers=cfg['training']['n_workers'], shuffle=True) valloader = data.DataLoader(v_loader, batch_size=cfg['training']['batch_size'], num_workers=cfg['training']['n_workers']) # Setup Metrics running_metrics_val = runningScore(n_classes) running_metrics_train = runningScore(n_classes) # Setup Model model = get_model(cfg['model'], n_classes).to(device) model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count())) # Setup optimizer, lr_scheduler and loss function optimizer_cls = get_optimizer(cfg) optimizer_params = { k: v for k, v in cfg['training']['optimizer'].items() if k != 'name' } optimizer = optimizer_cls(model.parameters(), **optimizer_params) logger_old.info("Using optimizer {}".format(optimizer)) scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule']) loss_fn = get_loss_function(cfg) logger_old.info("Using loss {}".format(loss_fn)) start_iter = 0 if cfg['training']['resume'] is not None: if os.path.isfile(cfg['training']['resume']): logger_old.info( "Loading model and optimizer from checkpoint '{}'".format( cfg['training']['resume'])) checkpoint = torch.load(cfg['training']['resume']) model.load_state_dict(checkpoint["model_state"]) optimizer.load_state_dict(checkpoint["optimizer_state"]) scheduler.load_state_dict(checkpoint["scheduler_state"]) start_iter = checkpoint["epoch"] logger_old.info("Loaded checkpoint '{}' (iter {})".format( cfg['training']['resume'], checkpoint["epoch"])) else: logger_old.info("No checkpoint found at '{}'".format( cfg['training']['resume'])) val_loss_meter = averageMeter() train_loss_meter = averageMeter() time_meter = averageMeter() train_len = t_loader.train_len val_static = 0 best_iou = -100.0 i = start_iter j = 0 flag = True # Prepare logging xp_name = cfg['model']['arch'] + '_' + \ cfg['training']['loss']['name'] + '_' + args.name xp = logger.Experiment(xp_name, use_visdom=True, visdom_opts={ 'server': 'http://localhost', 'port': 8098 }, time_indexing=False, xlabel='Epoch') # log the hyperparameters of the experiment xp.log_config(flatten(cfg)) # create parent metric for training 
metrics (easier interface) xp.ParentWrapper(tag='train', name='parent', children=(xp.AvgMetric(name="loss"), xp.AvgMetric(name='acc'), xp.AvgMetric(name='acccls'), xp.AvgMetric(name='fwavacc'), xp.AvgMetric(name='meaniu'))) xp.ParentWrapper(tag='val', name='parent', children=(xp.AvgMetric(name="loss"), xp.AvgMetric(name='acc'), xp.AvgMetric(name='acccls'), xp.AvgMetric(name='fwavacc'), xp.AvgMetric(name='meaniu'))) best_loss = xp.BestMetric(tag='val-best', name='loss', mode='min') best_acc = xp.BestMetric(tag='val-best', name='acc') best_acccls = xp.BestMetric(tag='val-best', name='acccls') best_fwavacc = xp.BestMetric(tag='val-best', name='fwavacc') best_meaniu = xp.BestMetric(tag='val-best', name='meaniu') xp.plotter.set_win_opts(name="loss", opts={'title': 'Loss'}) xp.plotter.set_win_opts(name="acc", opts={'title': 'Micro-Average'}) xp.plotter.set_win_opts(name="acccls", opts={'title': 'Macro-Average'}) xp.plotter.set_win_opts(name="fwavacc", opts={'title': 'FreqW Accuracy'}) xp.plotter.set_win_opts(name="meaniu", opts={'title': 'Mean IoU'}) it_per_step = cfg['training']['acc_batch_size'] eff_batch_size = cfg['training']['batch_size'] * it_per_step while i <= train_len * (cfg['training']['epochs']) and flag: for (images, labels, labels_s, masks) in trainloader: i += 1 j += 1 start_ts = time.time() scheduler.step() model.train() images = images.to(device) labels = labels.to(device) labels_s = labels_s.to(device) masks = masks.to(device) outputs = model(images) if use_superpixels: outputs_s, labels_s, sizes = convert_to_superpixels( outputs, labels_s, masks) loss = loss_fn(input=outputs_s, target=labels_s, size=sizes) outputs = convert_to_pixels(outputs_s, outputs, masks) else: loss = loss_fn(input=outputs, target=labels) # accumulate train metrics during train pred = outputs.data.max(1)[1].cpu().numpy() gt = labels.data.cpu().numpy() running_metrics_train.update(gt, pred) train_loss_meter.update(loss.item()) if args.evaluate: decoded = t_loader.decode_segmap(np.squeeze(pred, axis=0)) misc.imsave("./{}.png".format(i), decoded) image_save = np.transpose( np.squeeze(images.data.cpu().numpy(), axis=0), (1, 2, 0)) misc.imsave("./{}.jpg".format(i), image_save) # accumulate gradients based on the accumulation batch size if i % it_per_step == 1 or it_per_step == 1: optimizer.zero_grad() grad_rescaling = torch.tensor(1. 
/ it_per_step).type_as(loss) loss.backward(grad_rescaling) if (i + 1) % it_per_step == 1 or it_per_step == 1: optimizer.step() optimizer.zero_grad() time_meter.update(time.time() - start_ts) # training logs if (j + 1) % (cfg['training']['print_interval'] * it_per_step) == 0: fmt_str = "Epoch [{}/{}] Iter [{}/{:d}] Loss: {:.4f} Time/Image: {:.4f}" total_iter = int(train_len / eff_batch_size) total_epoch = int(cfg['training']['epochs']) current_epoch = ceil((i + 1) / train_len) current_iter = int((j + 1) / it_per_step) print_str = fmt_str.format( current_epoch, total_epoch, current_iter, total_iter, loss.item(), time_meter.avg / cfg['training']['batch_size']) print(print_str) logger_old.info(print_str) writer.add_scalar('loss/train_loss', loss.item(), i + 1) time_meter.reset() # end of epoch evaluation if (i + 1) % train_len == 0 or \ (i + 1) == train_len * (cfg['training']['epochs']): optimizer.step() optimizer.zero_grad() model.eval() with torch.no_grad(): for i_val, (images_val, labels_val, labels_val_s, masks_val) in tqdm(enumerate(valloader)): images_val = images_val.to(device) labels_val = labels_val.to(device) labels_val_s = labels_val_s.to(device) masks_val = masks_val.to(device) outputs = model(images_val) if use_superpixels: outputs_s, labels_val_s, sizes_val = convert_to_superpixels( outputs, labels_val_s, masks_val) val_loss = loss_fn(input=outputs_s, target=labels_val_s, size=sizes_val) outputs = convert_to_pixels( outputs_s, outputs, masks_val) else: val_loss = loss_fn(input=outputs, target=labels_val) pred = outputs.data.max(1)[1].cpu().numpy() gt = labels_val.data.cpu().numpy() running_metrics_val.update(gt, pred) val_loss_meter.update(val_loss.item()) writer.add_scalar('loss/val_loss', val_loss_meter.avg, i + 1) writer.add_scalar('loss/train_loss', train_loss_meter.avg, i + 1) logger_old.info("Epoch %d Val Loss: %.4f" % (int( (i + 1) / train_len), val_loss_meter.avg)) logger_old.info("Epoch %d Train Loss: %.4f" % (int( (i + 1) / train_len), train_loss_meter.avg)) score, class_iou = running_metrics_train.get_scores() print("Training metrics:") for k, v in score.items(): print(k, v) logger_old.info('{}: {}'.format(k, v)) writer.add_scalar('train_metrics/{}'.format(k), v, i + 1) for k, v in class_iou.items(): logger_old.info('{}: {}'.format(k, v)) writer.add_scalar('train_metrics/cls_{}'.format(k), v, i + 1) xp.Parent_Train.update(loss=train_loss_meter.avg, acc=score['Overall Acc: \t'], acccls=score['Mean Acc : \t'], fwavacc=score['FreqW Acc : \t'], meaniu=score['Mean IoU : \t']) score, class_iou = running_metrics_val.get_scores() print("Validation metrics:") for k, v in score.items(): print(k, v) logger_old.info('{}: {}'.format(k, v)) writer.add_scalar('val_metrics/{}'.format(k), v, i + 1) for k, v in class_iou.items(): logger_old.info('{}: {}'.format(k, v)) writer.add_scalar('val_metrics/cls_{}'.format(k), v, i + 1) xp.Parent_Val.update(loss=val_loss_meter.avg, acc=score['Overall Acc: \t'], acccls=score['Mean Acc : \t'], fwavacc=score['FreqW Acc : \t'], meaniu=score['Mean IoU : \t']) xp.Parent_Val.log_and_reset() xp.Parent_Train.log_and_reset() best_loss.update(xp.loss_val).log() best_acc.update(xp.acc_val).log() best_acccls.update(xp.acccls_val).log() best_fwavacc.update(xp.fwavacc_val).log() best_meaniu.update(xp.meaniu_val).log() visdir = os.path.join('runs', cfg['training']['loss']['name'], args.name, 'plots.json') xp.to_json(visdir) val_loss_meter.reset() train_loss_meter.reset() running_metrics_val.reset() running_metrics_train.reset() j = 0 if score["Mean IoU : \t"] >= 
best_iou:
                    val_static = 0
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg['model']['arch'],
                                                      cfg['data']['dataset']))
                    torch.save(state, save_path)
                else:
                    val_static += 1

            if (i + 1) == train_len * (cfg['training']['epochs']) or val_static == 10:
                flag = False
                break

    return best_iou
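# -----------------------------------------------------------------
# Illustration only: a simplified version of the gradient-accumulation
# pattern used above (acc_batch_size batches per optimizer step), written
# with a 0-based counter. `model`, `loader`, `loss_fn`, `optimizer` and
# `accum_steps` are generic stand-ins, not the exact objects built above.
def train_with_accumulation(model, loader, loss_fn, optimizer, accum_steps):
    model.train()
    optimizer.zero_grad()
    for step, (images, labels) in enumerate(loader):
        # rescale so the accumulated gradient matches one large batch
        loss = loss_fn(input=model(images), target=labels) / accum_steps
        loss.backward()
        if (step + 1) % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()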
def eval(cfg, writer, logger, logdir): # Setup seeds #torch.manual_seed(cfg.get("seed", 1337)) #torch.cuda.manual_seed(cfg.get("seed", 1337)) #np.random.seed(cfg.get("seed", 1337)) #random.seed(cfg.get("seed", 1337)) # Setup device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Setup Augmentations augmentations = cfg["training"].get("augmentations", None) data_aug = get_composed_augmentations(augmentations) # Setup Dataloader data_loader = get_loader(cfg["data"]["dataloader_type"]) data_root = cfg["data"]["data_root"] presentation_root = cfg["data"]["presentation_root"] v_loader = data_loader(data_root=data_root, presentation_root=presentation_root, is_transform=True, img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]), augmentations=data_aug, test_mode=True) n_classes = v_loader.n_classes valloader = data.DataLoader(v_loader, batch_size=cfg["training"]["batch_size"], num_workers=cfg["training"]["n_workers"], shuffle=False) # Setup Metrics # running_metrics_train = runningScore(n_classes) running_metrics_val = runningScore(n_classes) # Setup Model model = get_model(cfg["model"], n_classes, defaultParams).to(device) #model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count())) # Setup optimizer, lr_scheduler and loss function optimizer_cls = get_optimizer(cfg) optimizer_params = { k: v for k, v in cfg["training"]["optimizer"].items() if k != "name" } optimizer = optimizer_cls(model.parameters(), **optimizer_params) logger.info("Using optimizer {}".format(optimizer)) scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"]) loss_fn = get_loss_function(cfg) logger.info("Using loss {}".format(loss_fn)) start_iter = 0 if cfg["training"]["resume"] is not None: if os.path.isfile(cfg["training"]["resume"]): logger.info( "Loading model and optimizer from checkpoint '{}'".format( cfg["training"]["resume"])) checkpoint = torch.load(cfg["training"]["resume"]) model.load_state_dict(checkpoint["model_state"]) optimizer.load_state_dict(checkpoint["optimizer_state"]) scheduler.load_state_dict(checkpoint["scheduler_state"]) start_iter = checkpoint["epoch"] logger.info("Loaded checkpoint '{}' (iter {})".format( cfg["training"]["resume"], checkpoint["epoch"])) else: logger.info("No checkpoint found at '{}'".format( cfg["training"]["resume"])) # train_loss_meter = averageMeter() val_loss_meter = averageMeter() time_meter = averageMeter() best_iou = -100.0 i = 0 pres_results = [ ] # a final list of all <image, label, output> of all presentations img_list = [] while i < cfg["training"]["num_presentations"]: # # # TESTING PHASE # # # i += 1 training_state_dict = model.state_dict() hebb = model.initialZeroHebb().to(device) valloader.dataset.random_select() start_ts = time.time() for idx, (images_val, labels_val) in enumerate( valloader, 1): # get a single test presentation img = torchvision.utils.make_grid(images_val).numpy() img = np.transpose(img, (1, 2, 0)) img = img[:, :, ::-1] img_list.append(img) pres_results.append(decode_segmap(labels_val.numpy())) images_val = images_val.to(device) labels_val = labels_val.to(device) if idx <= 5: model.eval() with torch.no_grad(): outputs, hebb = model(images_val, labels_val, hebb, device, test_mode=False) else: model.train() optimizer.zero_grad() outputs, hebb = model(images_val, labels_val, hebb, device, test_mode=True) loss = loss_fn(input=outputs, target=labels_val) loss.backward() optimizer.step() pred = outputs.data.max(1)[1].cpu().numpy() gt = labels_val.data.cpu().numpy() 
running_metrics_val.update(gt, pred) val_loss_meter.update(loss.item()) # Turning the image, label, and output into plottable formats '''img = torchvision.utils.make_grid(images_val.cpu()).numpy() img = np.transpose(img, (1, 2, 0)) img = img[:, :, ::-1] print("img.shape",img.shape) print("gt.shape and type",gt.shape, gt.dtype) print("pred.shape and type",pred.shape, pred.dtype)''' cla, cnt = np.unique(pred, return_counts=True) print("Unique classes predicted = {}, counts = {}".format( cla, cnt)) #pres_results.append(img) #pres_results.append(decode_segmap(gt)) pres_results.append(decode_segmap(pred)) time_meter.update(time.time() - start_ts) # -> time taken per presentation model.load_state_dict( training_state_dict) # revert back to training parameters # Display presentations stats fmt_str = "Pres [{:d}/{:d}] Loss: {:.4f} Time/Pres: {:.4f}" print_str = fmt_str.format( i + 1, cfg["training"]["num_presentations"], loss.item(), time_meter.avg / cfg["training"]["batch_size"], ) print(print_str) logger.info(print_str) writer.add_scalar("loss/test_loss", loss.item(), i + 1) time_meter.reset() writer.add_scalar("loss/val_loss", val_loss_meter.avg, i + 1) logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg)) # Display presentation metrics score, class_iou = running_metrics_val.get_scores() for k, v in score.items(): print(k, v) logger.info("{}: {}".format(k, v)) writer.add_scalar("val_metrics/{}".format(k), v, i + 1) #for k, v in class_iou.items(): # logger.info("{}: {}".format(k, v)) # writer.add_scalar("val_metrics/cls_{}".format(k), v, i + 1) val_loss_meter.reset() running_metrics_val.reset() # save presentations to a png image file save_presentations(pres_results=pres_results, num_pres=cfg["training"]["num_presentations"], num_col=7, logdir=logdir, name="pre_results.png") save_presentations(pres_results=img_list, num_pres=cfg["training"]["num_presentations"], num_col=6, logdir=logdir, name="img_list.png")
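# -----------------------------------------------------------------
# Illustration only: a minimal decode_segmap-style helper that turns an
# integer label map (H, W) into an RGB image, as used when assembling the
# presentation grids above. The palette here is arbitrary; the real
# colours come from the dataset's own decode_segmap.
import numpy as np


def decode_segmap_example(label_map, n_classes):
    rng = np.random.RandomState(0)  # fixed but arbitrary palette
    palette = rng.randint(0, 255, size=(n_classes, 3), dtype=np.uint8)
    return palette[label_map.astype(np.int64) % n_classes]  # (H, W, 3)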
def train(cfg, writer, logger): # Setup seeds torch.manual_seed(cfg.get("seed", 1337)) torch.cuda.manual_seed(cfg.get("seed", 1337)) np.random.seed(cfg.get("seed", 1337)) random.seed(cfg.get("seed", 1337)) # Setup device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(device) # Setup Augmentations augmentations = cfg["training"].get("augmentations", None) data_aug = get_composed_augmentations(augmentations) # Setup Dataloader data_loader = get_loader(cfg["data"]["dataset"]) data_path = cfg["data"]["path"] t_loader = data_loader( data_path, is_transform=True, split=cfg["data"]["train_split"], img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]), augmentations=data_aug, n_classes=20, ) v_loader = data_loader( data_path, is_transform=True, split=cfg["data"]["val_split"], img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]), ) n_classes = t_loader.n_classes trainloader = data.DataLoader( t_loader, batch_size=cfg["training"]["batch_size"], num_workers=cfg["training"]["n_workers"], shuffle=True, ) valloader = data.DataLoader(v_loader, batch_size=cfg["training"]["batch_size"], num_workers=cfg["training"]["n_workers"]) # ----------------------------------------------------------------- # Setup Metrics (substract one class) running_metrics_val = runningScore(n_classes - 1) # Setup Model model = get_model(cfg["model"], n_classes).to(device) model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count())) # Setup optimizer, lr_scheduler and loss function optimizer_cls = get_optimizer(cfg) optimizer_params = { k: v for k, v in cfg["training"]["optimizer"].items() if k != "name" } optimizer = optimizer_cls(model.parameters(), **optimizer_params) logger.info("Using optimizer {}".format(optimizer)) scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"]) loss_fn = get_loss_function(cfg) logger.info("Using loss {}".format(loss_fn)) start_iter = 0 if cfg["training"]["resume"] is not None: if os.path.isfile(cfg["training"]["resume"]): logger.info( "Loading model and optimizer from checkpoint '{}'".format( cfg["training"]["resume"])) checkpoint = torch.load(cfg["training"]["resume"]) model.load_state_dict(checkpoint["model_state"]) optimizer.load_state_dict(checkpoint["optimizer_state"]) scheduler.load_state_dict(checkpoint["scheduler_state"]) start_iter = checkpoint["epoch"] logger.info("Loaded checkpoint '{}' (iter {})".format( cfg["training"]["resume"], checkpoint["epoch"])) else: logger.info("No checkpoint found at '{}'".format( cfg["training"]["resume"])) val_loss_meter = averageMeter() # get loss_seg meter and also loss_dep meter loss_seg_meter = averageMeter() loss_dep_meter = averageMeter() time_meter = averageMeter() best_iou = -100.0 i = start_iter flag = True while i <= cfg["training"]["train_iters"] and flag: for (images, labels, masks, depths) in trainloader: i += 1 start_ts = time.time() scheduler.step() model.train() images = images.to(device) labels = labels.to(device) depths = depths.to(device) #print(images.shape) optimizer.zero_grad() outputs = model(images) #print('depths size: ', depths.size()) #print('output shape: ', outputs.shape) loss_seg = loss_fn(input=outputs[:, :-1, :, :], target=labels) # ----------------------------------------------------------------- # add depth loss # ----------------------------------------------------------------- # MSE loss # loss_dep = F.mse_loss(input=outputs[:, -1,:,:], target=depths, reduction='mean') # ----------------------------------------------------------------- # Berhu loss 
loss_dep = berhu_loss_function(prediction=outputs[:, -1, :, :], target=depths) #loss_dep = loss_dep.type(torch.cuda.ByteTensor) masks = masks.type(torch.cuda.ByteTensor) loss_dep = torch.sum(loss_dep[masks]) / torch.sum(masks) print('loss depth', loss_dep) loss = loss_dep + loss_seg # ----------------------------------------------------------------- loss.backward() optimizer.step() time_meter.update(time.time() - start_ts) if (i + 1) % cfg["training"]["print_interval"] == 0: fmt_str = "Iter [{:d}/{:d}] loss_seg: {:.4f} loss_dep: {:.4f} overall loss: {:.4f} Time/Image: {:.4f}" print_str = fmt_str.format( i + 1, cfg["training"]["train_iters"], loss_seg.item(), loss_dep.item(), loss.item(), time_meter.avg / cfg["training"]["batch_size"]) print(print_str) logger.info(print_str) writer.add_scalar("loss/train_loss", loss.item(), i + 1) time_meter.reset() if (i + 1) % cfg["training"]["val_interval"] == 0 or ( i + 1) == cfg["training"]["train_iters"]: model.eval() with torch.no_grad(): for i_val, (images_val, labels_val, masks_val, depths_val) in tqdm(enumerate(valloader)): images_val = images_val.to(device) labels_val = labels_val.to(device) print('images_val shape', images_val.size()) # add depth to device depths_val = depths_val.to(device) outputs = model(images_val) #depths_val = depths_val.data.resize_(depths_val.size(0), outputs.size(2), outputs.size(3)) # ----------------------------------------------------------------- # loss function for segmentation print('output shape', outputs.size()) val_loss_seg = loss_fn(input=outputs[:, :-1, :, :], target=labels_val) # ----------------------------------------------------------------- # MSE loss # val_loss_dep = F.mse_loss(input=outputs[:, -1, :, :], target=depths_val, reduction='mean') # ----------------------------------------------------------------- # berhu loss function val_loss_dep = berhu_loss_function( prediction=outputs[:, -1, :, :], target=depths_val) val_loss_dep = val_loss_dep.type(torch.cuda.ByteTensor) masks_val = masks_val.type(torch.cuda.ByteTensor) val_loss_dep = torch.sum( val_loss_dep[masks_val]) / torch.sum(masks_val) val_loss = loss_dep + loss_seg # ----------------------------------------------------------------- prediction = outputs[:, :-1, :, :] prediction = prediction.data.max(1)[1].cpu().numpy() gt = labels_val.data.cpu().numpy() # adapt metrics to seg and dep running_metrics_val.update(gt, prediction) loss_seg_meter.update(val_loss_seg.item()) loss_dep_meter.update(val_loss_dep.item()) # ----------------------------------------------------------------- # get rid of val_loss_meter # val_loss_meter.update(val_loss.item()) # writer.add_scalar("loss/val_loss", val_loss_meter.avg, i + 1) # logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg)) # ----------------------------------------------------------------- score, class_iou = running_metrics_val.get_scores() for k, v in score.items(): print(k, v) logger.info("{}: {}".format(k, v)) writer.add_scalar("val_metrics/{}".format(k), v, i + 1) for k, v in class_iou.items(): logger.info("{}: {}".format(k, v)) writer.add_scalar("val_metrics/cls_{}".format(k), v, i + 1) print("Segmentation loss is {}".format(loss_seg_meter.avg)) logger.info("Segmentation loss is {}".format( loss_seg_meter.avg)) #writer.add_scalar("Segmentation loss is {}".format(loss_seg_meter.avg), i + 1) print("Depth loss is {}".format(loss_dep_meter.avg)) logger.info("Depth loss is {}".format(loss_dep_meter.avg)) #writer.add_scalar("Depth loss is {}".format(loss_dep_meter.avg), i + 1) val_loss_meter.reset() 
                loss_seg_meter.reset()
                loss_dep_meter.reset()
                running_metrics_val.reset()

                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg["model"]["arch"],
                                                      cfg["data"]["dataset"]),
                    )
                    torch.save(state, save_path)

            # insert print function to see if the losses are correct
            if (i + 1) == cfg["training"]["train_iters"]:
                flag = False
                break
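# -----------------------------------------------------------------
# Illustration only: one common formulation of the BerHu (reverse Huber)
# loss with a masked mean reduction, kept entirely in floating point.
# This sketches what `berhu_loss_function` plus the mask-weighted sum
# above compute; the repository's implementation may differ, e.g. in the
# choice of the threshold c.
import torch


def masked_berhu_loss(prediction, target, mask):
    mask = mask.bool()
    diff = (prediction - target).abs()
    c = (0.2 * diff[mask].max()).clamp(min=1e-6)  # adaptive threshold
    l2 = (diff ** 2 + c ** 2) / (2 * c)
    per_pixel = torch.where(diff <= c, diff, l2)
    return per_pixel[mask].sum() / mask.sum().clamp(min=1)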
def train(cfg, writer, logger): # Setup dataset split before setting up the seed for random if cfg['data']['dataset'] == 'miccai2008': split_info = init_data_split_miccai2008( cfg['data']['path']) # miccai2008 dataset elif cfg['data']['dataset'] == 'sasha': split_info = init_data_split_sasha( cfg['data']['path']) # miccai2008 dataset # Setup seeds torch.manual_seed(cfg.get('seed', 1337)) torch.cuda.manual_seed(cfg.get('seed', 1337)) np.random.seed(cfg.get('seed', 1337)) random.seed(cfg.get('seed', 1337)) # Setup device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Setup Cross Entropy Weight weight = prep_class_val_weights(cfg['training']['cross_entropy_ratio']) # Setup Augmentations augmentations = cfg['training'].get('augmentations', None) print(('augmentations_cfg:', augmentations)) data_aug = get_composed_augmentations3d(augmentations) # Setup Dataloader data_loader = get_loader(cfg['data']['dataset']) data_path = cfg['data']['path'] t_loader = data_loader(data_path, is_transform=True, split=cfg['data']['train_split'], img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']), augmentations=data_aug, split_info=split_info, patch_size=cfg['training']['patch_size'], mods=cfg['data']['mods']) v_loader = data_loader(data_path, is_transform=True, split=cfg['data']['val_split'], img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']), split_info=split_info, patch_size=cfg['training']['patch_size'], mods=cfg['data']['mods']) n_classes = t_loader.n_classes trainloader = data.DataLoader(t_loader, batch_size=cfg['training']['batch_size'], num_workers=cfg['training']['n_workers'], shuffle=False) valloader = data.DataLoader(v_loader, batch_size=cfg['training']['batch_size'], num_workers=cfg['training']['n_workers']) # Setup Metrics running_metrics_val = runningScore(n_classes) # Setup Model model = get_model(cfg['model'], n_classes).to(device) model.apply(weights_init) params = sum([ np.prod(p.size()) for p in filter(lambda p: p.requires_grad, model.parameters()) ]) / 1e6 print('NumOfParams:{}M'.format(params)) model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count())) # Setup optimizer, lr_scheduler and loss function optimizer_cls = get_optimizer(cfg) optimizer_params = { k: v for k, v in cfg['training']['optimizer'].items() if k != 'name' } optimizer = optimizer_cls(model.parameters(), **optimizer_params) logger.info("Using optimizer {}".format(optimizer)) scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule']) loss_fn = get_loss_function(cfg) logger.info("Using loss {}".format(loss_fn)) softmax_function = nn.Softmax(dim=1) start_iter = 0 if cfg['training']['resume'] is not None: if os.path.isfile(cfg['training']['resume']): logger.info( "Loading model and optimizer from checkpoint '{}'".format( cfg['training']['resume'])) checkpoint = torch.load(cfg['training']['resume']) model.load_state_dict(checkpoint["model_state"]) optimizer.load_state_dict(checkpoint["optimizer_state"]) scheduler.load_state_dict(checkpoint["scheduler_state"]) start_iter = checkpoint["epoch"] logger.info("Loaded checkpoint '{}' (iter {})".format( cfg['training']['resume'], checkpoint["epoch"])) else: logger.info("No checkpoint found at '{}'".format( cfg['training']['resume'])) val_loss_meter = averageMeter() time_meter = averageMeter() best_iou = -100.0 i_train_iter = start_iter display('Training from {}th iteration\n'.format(i_train_iter)) while i_train_iter < cfg['training']['train_iters']: i_batch_idx = 0 train_iter_start_time = time.time() for (images, 
labels, case_index_list) in trainloader: start_ts_network = time.time() scheduler.step() model.train() images = images.to(device) labels = labels.to(device) optimizer.zero_grad() outputs_FM = model(images) #print('Unique on labels:{}'.format(np.unique(labels.data.cpu().numpy()))) #[0, 1] #print('Unique on outputs:{}'.format(np.unique(outputs_FM.data.cpu().numpy()))) #[-1.15, +0.39] log('TrainIter=> images.size():{} labels.size():{} | outputs.size():{}' .format(images.size(), labels.size(), outputs_FM.size())) loss = cfg['training']['loss_balance_ratio'] * loss_fn( input=outputs_FM, target=labels, weight=weight, size_average=cfg['training']['loss']['size_average'] ) #Input:FM, Softmax is built with crossentropy loss fucntion loss.backward() optimizer.step() time_meter.update(time.time() - start_ts_network) print_per_batch_check = True if cfg['training'][ 'print_interval_per_batch'] else i_batch_idx + 1 == len( trainloader) if (i_train_iter + 1) % cfg['training'][ 'print_interval'] == 0 and print_per_batch_check: fmt_str = "Iter [{:d}/{:d}::{:d}/{:d}] [Loss: {:.4f}] NetworkTime/Image: {:.4f}" print_str = fmt_str.format( i_train_iter + 1, cfg['training']['train_iters'], i_batch_idx + 1, len(trainloader), loss.item(), time_meter.avg / cfg['training']['batch_size']) display(print_str) writer.add_scalar('loss/train_loss', loss.item(), i_train_iter + 1) time_meter.reset() i_batch_idx += 1 entire_time_all_cases = time.time() - train_iter_start_time display( 'EntireTime for {}th training iteration: {:.4f} EntireTime/Image: {:.4f}' .format( i_train_iter + 1, entire_time_all_cases, entire_time_all_cases / (len(trainloader) * cfg['training']['batch_size']))) validation_check = (i_train_iter + 1) % cfg['training']['val_interval'] == 0 or \ (i_train_iter + 1) == cfg['training']['train_iters'] if not validation_check: print('') else: model.eval() with torch.no_grad(): for i_val, (images_val, labels_val, case_index_list_val) in enumerate(valloader): images_val = images_val.to(device) labels_val = labels_val.to(device) outputs_FM_val = model(images_val) log( 'ValIter=> images_val.size():{} labels_val.size():{} | outputs.size():{}' .format(images_val.size(), labels_val.size(), outputs_FM_val.size()) ) #Input:FM, Softmax is built with crossentropy loss fucntion val_loss = cfg['training']['loss_balance_ratio'] * loss_fn( input=outputs_FM_val, target=labels_val, weight=weight, size_average=cfg['training']['loss']['size_average']) outputs_CLASS_val = outputs_FM_val.data.max(1)[1] outputs_PROB_val = softmax_function(outputs_FM_val.data) outputs_lesionPROB_val = outputs_PROB_val[:, 1, :, :, :] running_metrics_val.update(labels_val.data.cpu().numpy(), outputs_CLASS_val.cpu().numpy()) val_loss_meter.update(val_loss.item()) ''' This FOR-LOOP is used to visualize validation data via tensorboard It would take 3s roughly. 
''' for batch_identifier_index, case_index in enumerate( case_index_list_val): tensor_grid = [] image_val = images_val[ batch_identifier_index, :, :, :, :].float( ) #torch.Size([3, 160, 160, 160]) label_val = labels_val[ batch_identifier_index, :, :, :].float( ) #torch.Size([160, 160, 160]) output_lesionFM_val = outputs_FM_val[ batch_identifier_index, 1, :, :, :].float() #torch.Size([160, 160, 160]) output_nonlesFM_val = outputs_FM_val[ batch_identifier_index, 0, :, :, :].float() #torch.Size([160, 160, 160]) output_CLASS_val = outputs_CLASS_val[ batch_identifier_index, :, :, :].float( ) #torch.Size([160, 160, 160]) output_lesionPROB_val = outputs_lesionPROB_val[ batch_identifier_index, :, :, :].float( ) #torch.Size([160, 160, 160]) for z_index in range(images_val.size()[-1]): label_slice = label_val[:, :, z_index] output_CLASS_slice = output_CLASS_val[:, :, z_index] if label_slice.sum( ) == 0 and output_CLASS_slice.sum() == 0: continue image_slice = image_val[:, :, :, z_index] output_nonlesFM_slice = output_nonlesFM_val[:, :, z_index] output_lesionFM_slice = output_lesionFM_val[:, :, z_index] output_lesionPROB_slice = output_lesionPROB_val[:, :, z_index] label_slice = F.pad(label_slice.unsqueeze_(0), (0, 0, 0, 0, 1, 1)) output_CLASS_slice = F.pad( output_CLASS_slice.unsqueeze_(0), (0, 0, 0, 0, 2, 0)) output_nonlesFM_slice = output_nonlesFM_slice.unsqueeze_( 0).repeat(3, 1, 1) output_lesionFM_slice = output_lesionFM_slice.unsqueeze_( 0).repeat(3, 1, 1) output_lesionPROB_slice = output_lesionPROB_slice.unsqueeze_( 0).repeat(3, 1, 1) slice_list = [ image_slice, output_nonlesFM_slice, output_lesionFM_slice, output_lesionPROB_slice, output_CLASS_slice, label_slice ] #slice_list = [image_slice, output_lesionFM_slice, output_lesionPROB_slice, output_CLASS_slice, label_slice] slice_grid = make_grid(slice_list, padding=20) tensor_grid.append(slice_grid) if len(tensor_grid) == 0: continue tensorboard_image_tensor = make_grid( tensor_grid, nrow=int(math.sqrt(len(tensor_grid) / 6)) + 1, padding=0).permute(1, 2, 0).cpu().numpy() writer.add_image(case_index, tensorboard_image_tensor, i_train_iter + 1) writer.add_scalar('loss/val_loss', val_loss_meter.avg, i_train_iter + 1) logger.info("Iter %d Loss_total: %.4f" % (i_train_iter + 1, val_loss_meter.avg)) ''' This CODE-BLOCK is used to calculate and update the evaluation matrcs ''' score, class_iou = running_metrics_val.get_scores() print( '\x1b[1;32;44mValidationDataLoaded-EXPINDEX={}'.format(run_id)) for k, v in score.items(): print(k, v) logger.info('{}: {}'.format(k, v)) if isinstance(v, list): continue writer.add_scalar('val_metrics/{}'.format(k), v, i_train_iter + 1) for k, v in class_iou.items(): print('IOU:cls_{}:{}'.format(k, v)) logger.info('{}: {}'.format(k, v)) writer.add_scalar('val_metrics/cls_{}'.format(k), v, i_train_iter + 1) print('\x1b[0m\n') val_loss_meter.reset() running_metrics_val.reset() ''' This IF-CHECK is used to update the best model ''' if score["Mean IoU : \t"] >= best_iou: #if score["Patch DICE AVER: \t"] >= best_iou: #best_iou = score["Patch DICE AVER: \t"] best_iou = score["Mean IoU : \t"] state = { "epoch": i_train_iter + 1, "model_state": model.state_dict(), "optimizer_state": optimizer.state_dict(), "scheduler_state": scheduler.state_dict(), "best_iou": best_iou, } save_path = os.path.join( writer.file_writer.get_logdir(), "{}_{}_best_model.pkl".format(cfg['model']['arch'], cfg['data']['dataset'])) torch.save(state, save_path) i_train_iter += 1
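# -----------------------------------------------------------------
# Illustration only: a compact version of the slice-grid visualisation
# above, assuming `pred_vol` and `label_vol` are (D, H, W) CPU tensors.
# Slices that are empty in both prediction and label are skipped, as in
# the validation loop; the result can be passed to writer.add_image.
import torch
from torchvision.utils import make_grid


def volume_to_grid(pred_vol, label_vol, padding=20):
    slices = []
    for z in range(pred_vol.size(0)):
        if pred_vol[z].sum() == 0 and label_vol[z].sum() == 0:
            continue  # skip slices with nothing to show
        pair = torch.stack([pred_vol[z], label_vol[z]]).unsqueeze(1).float()
        slices.append(make_grid(pair, nrow=2, padding=padding))
    if not slices:
        return None
    return make_grid(slices, nrow=8, padding=0)  # (3, H', W') image tensor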
def train(cfg, writer, logger): # Setup seeds # torch.manual_seed(cfg.get("seed", 1337)) # torch.cuda.manual_seed(cfg.get("seed", 1337)) # np.random.seed(cfg.get("seed", 1337)) # random.seed(cfg.get("seed", 1337)) # Setup device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Setup Augmentations augmentations = cfg["training"].get("augmentations", None) data_aug = get_composed_augmentations(augmentations) # Setup Dataloader data_loader = get_loader(cfg["data"]["dataset"]) data_path = cfg["data"]["path"] t_loader = data_loader( data_path, is_transform=True, split=cfg["data"]["train_split"], img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]), augmentations=data_aug, ) v_loader = data_loader( data_path, is_transform=True, split=cfg["data"]["val_split"], img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]), ) n_classes = t_loader.n_classes trainloader = data.DataLoader( t_loader, batch_size=cfg["training"]["batch_size"], num_workers=cfg["training"]["n_workers"], shuffle=True, ) valloader = data.DataLoader(v_loader, batch_size=cfg["training"]["batch_size"], num_workers=cfg["training"]["n_workers"]) # Setup Metrics running_metrics_val = runningScore(n_classes) # Setup Model model = get_model(cfg["model"], n_classes).to(device) # model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count())) # Setup optimizer, lr_scheduler and loss function optimizer_cls = get_optimizer(cfg) optimizer_params = { k: v for k, v in cfg["training"]["optimizer"].items() if k != "name" } optimizer = optimizer_cls(model.parameters(), **optimizer_params) logger.info("Using optimizer {}".format(optimizer)) scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"]) loss_fn = get_loss_function(cfg) logger.info("Using loss {}".format(loss_fn)) start_iter = 0 if cfg["training"]["resume"] is not None: if os.path.isfile(cfg["training"]["resume"]): logger.info( "Loading model and optimizer from checkpoint '{}'".format( cfg["training"]["resume"])) checkpoint = torch.load(cfg["training"]["resume"]) if not args.load_weight_only: model = DataParallel_withLoss(model, loss_fn) model.load_state_dict(checkpoint["model_state"]) if not args.not_load_optimizer: optimizer.load_state_dict(checkpoint["optimizer_state"]) # !!! # checkpoint["scheduler_state"]['last_epoch'] = -1 # scheduler.load_state_dict(checkpoint["scheduler_state"]) # start_iter = checkpoint["epoch"] start_iter = 0 # import ipdb # ipdb.set_trace() logger.info("Loaded checkpoint '{}' (iter {})".format( cfg["training"]["resume"], checkpoint["epoch"])) else: pretrained_dict = convert_state_dict(checkpoint["model_state"]) model_dict = model.state_dict() # 1. filter out unnecessary keys pretrained_dict = { k: v for k, v in pretrained_dict.items() if k in model_dict } # 2. overwrite entries in the existing state dict model_dict.update(pretrained_dict) # 3. 
load the new state dict model.load_state_dict(model_dict) model = DataParallel_withLoss(model, loss_fn) # import ipdb # ipdb.set_trace() # start_iter = -1 logger.info( "Loaded checkpoint '{}' (iter unknown, from pretrained icnet model)" .format(cfg["training"]["resume"])) else: logger.info("No checkpoint found at '{}'".format( cfg["training"]["resume"])) val_loss_meter = averageMeter() time_meter = averageMeter() best_iou = -100.0 i = start_iter flag = True while i <= cfg["training"]["train_iters"] and flag: for (images, labels, inst_labels) in trainloader: start_ts = time.time() scheduler.step() model.train() images = images.to(device) labels = labels.to(device) inst_labels = inst_labels.to(device) optimizer.zero_grad() loss, _, aux_info = model(labels, inst_labels, images, return_aux_info=True) loss = loss.sum() loss_sem = aux_info[0].sum() loss_inst = aux_info[1].sum() # loss = loss_fn(input=outputs, target=labels) loss.backward() optimizer.step() time_meter.update(time.time() - start_ts) if (i + 1) % cfg["training"]["print_interval"] == 0: fmt_str = "Iter [{:d}/{:d}] Loss: {:.4f} (Sem:{:.4f}/Inst:{:.4f}) LR:{:.5f} Time/Image: {:.4f}" print_str = fmt_str.format( i + 1, cfg["training"]["train_iters"], loss.item(), loss_sem.item(), loss_inst.item(), scheduler.get_lr()[0], time_meter.avg / cfg["training"]["batch_size"], ) # print(print_str) logger.info(print_str) writer.add_scalar("loss/train_loss", loss.item(), i + 1) time_meter.reset() if (i + 1) % cfg["training"]["val_interval"] == 0 or ( i + 1) == cfg["training"]["train_iters"]: model.eval() with torch.no_grad(): for i_val, (images_val, labels_val, inst_labels_val) in tqdm(enumerate(valloader)): images_val = images_val.to(device) labels_val = labels_val.to(device) inst_labels_val = inst_labels_val.to(device) # outputs = model(images_val) # val_loss = loss_fn(input=outputs, target=labels_val) val_loss, (outputs, outputs_inst) = model( labels_val, inst_labels_val, images_val) val_loss = val_loss.sum() pred = outputs.data.max(1)[1].cpu().numpy() gt = labels_val.data.cpu().numpy() running_metrics_val.update(gt, pred) val_loss_meter.update(val_loss.item()) writer.add_scalar("loss/val_loss", val_loss_meter.avg, i + 1) logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg)) score, class_iou = running_metrics_val.get_scores() for k, v in score.items(): print(k, v) logger.info("{}: {}".format(k, v)) writer.add_scalar("val_metrics/{}".format(k), v, i + 1) for k, v in class_iou.items(): logger.info("{}: {}".format(k, v)) writer.add_scalar("val_metrics/cls_{}".format(k), v, i + 1) val_loss_meter.reset() running_metrics_val.reset() if score["Mean IoU : \t"] >= best_iou: best_iou = score["Mean IoU : \t"] state = { "epoch": i + 1, "model_state": model.state_dict(), "optimizer_state": optimizer.state_dict(), "scheduler_state": scheduler.state_dict(), "best_iou": best_iou, } save_path = os.path.join( writer.file_writer.get_logdir(), "{}_{}_best_model.pkl".format(cfg["model"]["arch"], cfg["data"]["dataset"]), ) torch.save(state, save_path) if (i + 1) % cfg["training"]["save_interval"] == 0 or ( i + 1) == cfg["training"]["train_iters"]: state = { "epoch": i + 1, "model_state": model.state_dict(), "optimizer_state": optimizer.state_dict(), "scheduler_state": scheduler.state_dict(), "best_iou": best_iou, } save_path = os.path.join( writer.file_writer.get_logdir(), "{}_{}_{:05d}_model.pkl".format(cfg["model"]["arch"], cfg["data"]["dataset"], i + 1), ) torch.save(state, save_path) if (i + 1) == cfg["training"]["train_iters"]: flag = False break i += 1
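# -----------------------------------------------------------------
# Illustration only: the general pattern behind a "DataParallel with loss"
# wrapper, computing the loss inside forward() so every GPU reduces its own
# chunk and only small tensors are gathered. The repository's
# DataParallel_withLoss helper may differ (it also returns aux_info and
# takes labels first).
import torch
import torch.nn as nn


class ModelWithLoss(nn.Module):
    def __init__(self, model, loss_fn):
        super().__init__()
        self.model = model
        self.loss_fn = loss_fn

    def forward(self, images, labels):
        outputs = self.model(images)
        loss = self.loss_fn(input=outputs, target=labels)
        # keep a batch dimension so .sum() aggregates the per-GPU losses
        return loss.unsqueeze(0), outputs


# wrapped = torch.nn.DataParallel(ModelWithLoss(model, loss_fn))
# loss, outputs = wrapped(images, labels); loss = loss.sum()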
def train(cfg, writer, logger): # Setup dataset split before setting up the seed for random if cfg['data']['dataset'] == 'thigh': # data_split_info = init_data_split(cfg['data']['path'], cfg['data'].get('split_ratio', 0), cfg['data'].get('compound', False)) # fly jenelia dataset' subject_names = [ f"MSTHIGH_{i:02d}" for i in range(3, 16) if i != 8 and i != 13 ] elif cfg['data']['dataset'] == 'femur': subject_names = [ f"MSTHIGH_{i:02d}" for i in range(3, 16) if i != 8 and i != 13 ] # femur_data_split(cfg['data']['path'], subject_names, ratio=cfg['data']['split_ratio']) # Setup seeds torch.manual_seed(cfg.get('seed', 1337)) torch.cuda.manual_seed(cfg.get('seed', 1337)) np.random.seed(cfg.get('seed', 1337)) random.seed(cfg.get('seed', 1337)) # Setup device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Setup Cross Entropy Weight weight = None if cfg['training']['loss'].get('name', None) != 'regression_l1': if 'cross_entropy_ratio' in cfg['training']: weight = prep_class_val_weights( cfg['training']['cross_entropy_ratio']) log('Using loss : {}'.format(cfg['training']['loss']['name'])) # Setup Augmentations augmentations = cfg['training'].get( 'augmentations', None) # if no augmentation => default None data_aug = get_composed_augmentations(augmentations) # Setup Dataloader data_loader = get_loader(cfg['data']['dataset']) data_path = cfg['data']['path'] t_loader = data_loader(data_path, split=cfg['data']['train_split'], augmentations=data_aug, n_classes=cfg['training'].get('n_classes', 2)) # # If using validation, uncomment this block # v_loader = data_loader( # data_path, # split=cfg['data']['val_split'], # data_split_info=data_split_info, # n_classe=cfg['training'].get('n_classes', 1)) n_classes = t_loader.n_classes log('n_classes is: {}'.format(n_classes)) trainloader = data.DataLoader(t_loader, batch_size=cfg['training']['batch_size'], num_workers=cfg['training']['n_workers'], shuffle=False) print('trainloader len: ', len(trainloader)) # Setup Metrics running_metrics_val = runningScore( n_classes) # a confusion matrix is created # Setup Model model = get_model(cfg['model'], n_classes) model = model.to(device) model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count())) # print(range(torch.cuda.device_count())) count_parameters(model, verbose=False) # Setup optimizer, lr_scheduler and loss function optimizer_cls = get_optimizer(cfg) optimizer_params = { k: v for k, v in cfg['training']['optimizer'].items() if k != 'name' } optimizer = optimizer_cls(model.parameters(), **optimizer_params) logger.info("Using optimizer {}".format(optimizer)) scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule']) loss_fn = get_loss_function(cfg) logger.info("Using loss {}".format(loss_fn)) softmax_function = nn.Softmax(dim=1) # model_count = 0 min_loss = None start_iter = 0 if cfg['training']['resume'] is not None: log('resume saved model') if os.path.isfile(cfg['training']['resume']): display("Loading model and optimizer from checkpoint '{}'".format( cfg['training']['resume'])) checkpoint = torch.load(cfg['training']['resume']) model.load_state_dict(checkpoint["model_state"]) optimizer.load_state_dict(checkpoint["optimizer_state"]) scheduler.load_state_dict(checkpoint["scheduler_state"]) start_iter = checkpoint["epoch"] min_loss = checkpoint["min_loss"] display("Loaded checkpoint '{}' (iter {})".format( cfg['training']['resume'], checkpoint["epoch"])) else: display("No checkpoint found at '{}'".format( cfg['training']['resume'])) log('no saved model 
found') val_loss_meter = averageMeter() time_meter = averageMeter() # if cfg['training']['loss']['name'] == 'dice': # loss_fn = dice_loss() i_train_iter = start_iter display('Training from {}th iteration\n'.format(i_train_iter)) while i_train_iter < cfg['training']['train_iters']: i_batch_idx = 0 train_iter_start_time = time.time() averageLoss = 0 # training for (images, labels) in trainloader: start_ts = time.time() model.train() # images = images.cuda() # labels = labels.cuda() optimizer.zero_grad() # anchor_imgs, pos_imgs, neg_imgs = random_crop_triplet(images, labels) images = images.to(device) labels = labels.to(device) # outputs = model(images) # print(outputs.shape, labels.shape) if cfg['training']['loss']['name'] in ['dice']: outputs = model(images) # print(outputs.unique) loss = loss_fn(outputs, labels) # print('loss match: ', loss, loss.item()) averageLoss += loss.item() else: hard_loss = loss_fn( input=outputs, target=labels, weight=weight, size_average=cfg['training']['loss']['size_average']) loss = hard_loss averageLoss += loss loss.backward() # print('{} optim: {}'.format(i, optimizer.param_groups[0]['lr'])) optimizer.step() # print('{} scheduler: {}'.format(i, scheduler.get_lr()[0])) scheduler.step() time_meter.update(time.time() - start_ts) print_per_batch_check = True if cfg['training'][ 'print_interval_per_batch'] else i_batch_idx + 1 == len( trainloader) if (i_train_iter + 1) % cfg['training'][ 'print_interval'] == 0 and print_per_batch_check: fmt_str = "Iter [{:d}/{:d}] Loss: {:.4f} Time/Image: {:.4f}" print_str = fmt_str.format( i_train_iter + 1, cfg['training']['train_iters'], loss.item(), time_meter.avg / cfg['training']['batch_size']) display(print_str) writer.add_scalar('loss/train_loss', loss.item(), i_train_iter + 1) time_meter.reset() i_batch_idx += 1 time_for_one_iteration = time.time() - train_iter_start_time display( 'EntireTime for {}th training iteration: {} EntireTime/Image: {}'. format( i_train_iter + 1, time_converter(time_for_one_iteration), time_converter( time_for_one_iteration / (len(trainloader) * cfg['training']['batch_size'])))) averageLoss /= (len(trainloader) * cfg['training']['batch_size']) print(averageLoss) # validation validation_check = (i_train_iter + 1) % cfg['training']['val_interval'] == 0 or \ (i_train_iter + 1) == cfg['training']['train_iters'] if not validation_check: print('no validation check') else: ''' This IF-CHECK is used to update the best model ''' log('Validation: average loss for current iteration is: {}'.format( averageLoss)) if min_loss is None: min_loss = averageLoss if averageLoss <= min_loss: min_loss = averageLoss state = { "epoch": i_train_iter + 1, "model_state": model.state_dict(), "optimizer_state": optimizer.state_dict(), "scheduler_state": scheduler.state_dict(), "min_loss": min_loss } # if cfg['training']['cp_save_path'] is None: save_path = os.path.join( writer.file_writer.get_logdir(), "{}_{}_model_best.pkl".format(cfg['model']['arch'], cfg['data']['dataset'])) # else: # save_path = os.path.join(cfg['training']['cp_save_path'], writer.file_writer.get_logdir(), # "{}_{}_model_best.pkl".format( # cfg['model']['arch'], # cfg['data']['dataset'])) print('save_path is: ' + save_path) torch.save(state, save_path) # model_count += 1 i_train_iter += 1
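# -----------------------------------------------------------------
# Illustration only: one common soft Dice loss for binary volumes,
# assuming `outputs` holds probabilities in [0, 1] and `labels` holds
# {0, 1} values of the same shape. The repository's dice_loss class may
# differ (multi-class handling, smoothing constant, etc.).
import torch


def soft_dice_loss(outputs, labels, eps=1e-6):
    outputs = outputs.contiguous().view(outputs.size(0), -1)
    labels = labels.contiguous().view(labels.size(0), -1).float()
    intersection = (outputs * labels).sum(dim=1)
    denominator = outputs.sum(dim=1) + labels.sum(dim=1)
    dice = (2 * intersection + eps) / (denominator + eps)
    return 1 - dice.mean()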
def train(cfg, writer, logger, run_id): # Setup random seeds torch.manual_seed(cfg.get('seed', 1337)) torch.cuda.manual_seed(cfg.get('seed', 1337)) np.random.seed(cfg.get('seed', 1337)) random.seed(cfg.get('seed', 1337)) torch.backends.cudnn.benchmark = True # Setup Augmentations augmentations = cfg['train'].get('augmentations', None) data_aug = get_composed_augmentations(augmentations) # Setup Dataloader data_loader = get_loader(cfg['data']['dataloader']) data_path = cfg['data']['path'] logger.info("Using dataset: {}".format(data_path)) t_loader = data_loader(data_path, transform=None, split=cfg['data']['train_split'], augmentations=data_aug) v_loader = data_loader( data_path, transform=None, split=cfg['data']['val_split'], ) logger.info( f'num of train samples: {len(t_loader)} \nnum of val samples: {len(v_loader)}' ) train_data_len = len(t_loader) batch_size = cfg['train']['batch_size'] epoch = cfg['train']['train_epoch'] train_iter = int(np.ceil(train_data_len / batch_size) * epoch) logger.info(f'total train iter: {train_iter}') n_classes = t_loader.n_classes trainloader = data.DataLoader(t_loader, batch_size=cfg['train']['batch_size'], num_workers=cfg['train']['n_workers'], shuffle=True, drop_last=True) valloader = data.DataLoader(v_loader, batch_size=cfg['train']['batch_size'], num_workers=cfg['train']['n_workers']) # Setup Model model = get_model(cfg['model'], n_classes) logger.info("Using Model: {}".format(cfg['model']['arch'])) device = f'cuda:{cuda_idx[0]}' model = model.to(device) model = torch.nn.DataParallel(model, device_ids=cuda_idx) #自动多卡运行,这个好用 # Setup optimizer, lr_scheduler and loss function optimizer_cls = get_optimizer(cfg) optimizer_params = { k: v for k, v in cfg['train']['optimizer'].items() if k != 'name' } optimizer = optimizer_cls(model.parameters(), **optimizer_params) logger.info("Using optimizer {}".format(optimizer)) scheduler = get_scheduler(optimizer, cfg['train']['lr_schedule']) loss_fn = get_loss_function(cfg) # logger.info("Using loss {}".format(loss_fn)) # set checkpoints start_iter = 0 if cfg['train']['resume'] is not None: if os.path.isfile(cfg['train']['resume']): logger.info( "Loading model and optimizer from checkpoint '{}'".format( cfg['train']['resume'])) checkpoint = torch.load(cfg['train']['resume']) model.load_state_dict(checkpoint["model_state"]) optimizer.load_state_dict(checkpoint["optimizer_state"]) scheduler.load_state_dict(checkpoint["scheduler_state"]) start_iter = checkpoint["epoch"] logger.info("Loaded checkpoint '{}' (iter {})".format( cfg['train']['resume'], checkpoint["epoch"])) else: logger.info("No checkpoint found at '{}'".format( cfg['train']['resume'])) # Setup Metrics running_metrics_val = runningScore(n_classes) val_loss_meter = averageMeter() train_time_meter = averageMeter() time_meter_val = averageMeter() best_iou = 0 flag = True val_rlt_f1 = [] val_rlt_OA = [] best_f1_till_now = 0 best_OA_till_now = 0 best_fwIoU_now = 0 best_fwIoU_iter_till_now = 0 # train it = start_iter model.train() while it <= train_iter and flag: for (file_a, file_b, label, mask) in trainloader: it += 1 start_ts = time.time() file_a = file_a.to(device) file_b = file_b.to(device) label = label.to(device) mask = mask.to(device) optimizer.zero_grad() outputs = model(file_a, file_b) loss = loss_fn(input=outputs, target=label, mask=mask) loss.backward() # print('conv11: ', model.conv11.weight.grad, model.conv11.weight.grad.shape) # print('conv21: ', model.conv21.weight.grad, model.conv21.weight.grad.shape) # print('conv31: ', model.conv31.weight.grad, 
model.conv31.weight.grad.shape) # In PyTorch 1.1.0 and later, you should call `optimizer.step()` before `lr_scheduler.step()` optimizer.step() scheduler.step() train_time_meter.update(time.time() - start_ts) time_meter_val.update(time.time() - start_ts) if (it + 1) % cfg['train']['print_interval'] == 0: fmt_str = "train:\nIter [{:d}/{:d}] Loss: {:.4f} Time/Image: {:.4f}" print_str = fmt_str.format( it + 1, train_iter, loss.item(), #extracts the loss’s value as a Python float. train_time_meter.avg / cfg['train']['batch_size']) train_time_meter.reset() logger.info(print_str) writer.add_scalar('loss/train_loss', loss.item(), it + 1) if (it + 1) % cfg['train']['val_interval'] == 0 or \ (it + 1) == train_iter: model.eval() # change behavior like drop out with torch.no_grad(): # disable autograd, save memory usage for (file_a_val, file_b_val, label_val, mask_val) in valloader: file_a_val = file_a_val.to(device) file_b_val = file_b_val.to(device) outputs = model(file_a_val, file_b_val) # tensor.max with return the maximum value and its indices pred = outputs.max(1)[1].cpu().numpy() gt = label_val.numpy() running_metrics_val.update(gt, pred, mask_val) label_val = label_val.to(device) mask_val = mask_val.to(device) val_loss = loss_fn(input=outputs, target=label_val, mask=mask_val) val_loss_meter.update(val_loss.item()) lr_now = optimizer.param_groups[0]['lr'] logger.info(f'lr: {lr_now}') # writer.add_scalar('lr', lr_now, it+1) writer.add_scalar('loss/val_loss', val_loss_meter.avg, it + 1) logger.info("Iter %d, val Loss: %.4f" % (it + 1, val_loss_meter.avg)) score, class_iou = running_metrics_val.get_scores() # for k, v in score.items(): # logger.info('{}: {}'.format(k, v)) # writer.add_scalar('val_metrics/{}'.format(k), v, it+1) for k, v in class_iou.items(): logger.info('{}: {}'.format(k, v)) writer.add_scalar('val_metrics/cls_{}'.format(k), v, it + 1) val_loss_meter.reset() running_metrics_val.reset() avg_f1 = score["Mean_F1"] OA = score["Overall_Acc"] fw_IoU = score["FreqW_IoU"] val_rlt_f1.append(avg_f1) val_rlt_OA.append(OA) if fw_IoU >= best_fwIoU_now and it > 200: best_fwIoU_now = fw_IoU correspond_meanIou = score["Mean_IoU"] best_fwIoU_iter_till_now = it + 1 state = { "epoch": it + 1, "model_state": model.state_dict(), "optimizer_state": optimizer.state_dict(), "scheduler_state": scheduler.state_dict(), "best_fwIoU": best_fwIoU_now, } save_path = os.path.join( writer.file_writer.get_logdir(), "{}_{}_best_model.pkl".format( cfg['model']['arch'], cfg['data']['dataloader'])) torch.save(state, save_path) logger.info("best_fwIoU_now = %.8f" % (best_fwIoU_now)) logger.info("Best fwIoU Iter till now= %d" % (best_fwIoU_iter_till_now)) iter_time = time_meter_val.avg time_meter_val.reset() remain_time = iter_time * (train_iter - it) m, s = divmod(remain_time, 60) h, m = divmod(m, 60) if s != 0: train_time = "Remain train time = %d hours %d minutes %d seconds \n" % ( h, m, s) else: train_time = "Remain train time : train completed.\n" print(train_time) model.train() if (it + 1) == train_iter: flag = False logger.info("Use the Sar_seg_band3,val_interval: 30") break logger.info("best_fwIoU_now = %.8f" % (best_fwIoU_now)) logger.info("Best fwIoU Iter till now= %d" % (best_fwIoU_iter_till_now)) state = { "epoch": it + 1, "model_state": model.state_dict(), "optimizer_state": optimizer.state_dict(), "scheduler_state": scheduler.state_dict(), "best_fwIoU": best_fwIoU_now, } save_path = os.path.join( writer.file_writer.get_logdir(), "{}_{}_last_model.pkl".format(cfg['model']['arch'], 
                                      cfg['data']['dataloader']))
    torch.save(state, save_path)
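# -----------------------------------------------------------------
# Illustration only: how the frequency-weighted IoU used for model
# selection above ("FreqW_IoU") is typically derived from an
# n_classes x n_classes confusion matrix; treat this as a sketch of what
# runningScore.get_scores() reports, not its exact implementation.
import numpy as np


def freq_weighted_iou(hist):
    with np.errstate(divide='ignore', invalid='ignore'):
        iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
    freq = hist.sum(axis=1) / hist.sum()
    return float((freq[freq > 0] * iu[freq > 0]).sum())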
def train(cfg, writer, logger):
    # Setup dataset split before setting up the seed for random
    data_split_info = init_data_split(cfg['data']['path'],
                                      cfg['data'].get('split_ratio', 0),
                                      cfg['data'].get('compound', False))  # fly jenelia dataset

    # Setup seeds
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Cross Entropy Weight
    if cfg['training']['loss']['name'] != 'regression_l1':
        weight = prep_class_val_weights(cfg['training']['cross_entropy_ratio'])
    else:
        weight = None
    log('Using loss : {}'.format(cfg['training']['loss']['name']))

    # Setup Augmentations
    augmentations = cfg['training'].get('augmentations', None)  # if no augmentation => default None
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg['data']['dataset'])
    data_path = cfg['data']['path']
    patch_size = [para for axis, para in cfg['training']['patch_size'].items()]

    t_loader = data_loader(data_path,
                           split=cfg['data']['train_split'],
                           augmentations=data_aug,
                           data_split_info=data_split_info,
                           patch_size=patch_size,
                           allow_empty_patch=cfg['training'].get('allow_empty_patch', True),
                           n_classes=cfg['training'].get('n_classes', 1))
    # v_loader = data_loader(data_path,
    #                        split=cfg['data']['val_split'],
    #                        data_split_info=data_split_info,
    #                        patch_size=patch_size,
    #                        n_classe=cfg['training'].get('n_classes', 1))

    n_classes = t_loader.n_classes
    log('n_classes is: {}'.format(n_classes))
    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg['training']['batch_size'],
                                  num_workers=cfg['training']['n_workers'],
                                  shuffle=False)
    # valloader = data.DataLoader(v_loader,
    #                             batch_size=cfg['training']['batch_size'],
    #                             num_workers=cfg['training']['n_workers'])

    # Setup Metrics
    running_metrics_val = runningScore(n_classes)  # a confusion matrix is created

    # Setup Model
    model = get_model(cfg['model'], n_classes).to(device)
    model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))

    # if cfg['training'].get('pretrained_model', None) is not None:
    #     log('Load pretrained model: {}'.format(cfg['training'].get('pretrained_model', None)))
    #     pretrainedModel = torch.load(cfg['training'].get('pretrained_model', None))
    #     my_dict = model.state_dict()
    #     x = my_dict.copy()
    #     pretrained_dict = pretrainedModel['model_state']
    #     # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in my_dict}
    #     my_dict.update(pretrained_dict)
    #     y = my_dict.copy()
    #     shared_items = {k: x[k] for k in x if k in y and torch.equal(x[k], y[k])}
    #     if len(shared_items) == len(my_dict):
    #         exit(1)

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {k: v for k, v in cfg['training']['optimizer'].items() if k != 'name'}
    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule'])
    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))
    softmax_function = nn.Softmax(dim=1)

    # model_count = 0
    min_loss = None
    start_iter = 0
    if cfg['training']['resume'] is not None:
        log('resume saved model')
        if os.path.isfile(cfg['training']['resume']):
            display("Loading model and optimizer from checkpoint '{}'".format(cfg['training']['resume']))
            checkpoint = torch.load(cfg['training']['resume'])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            min_loss = checkpoint["min_loss"]
            display("Loaded checkpoint '{}' (iter {})".format(cfg['training']['resume'], checkpoint["epoch"]))
        else:
            display("No checkpoint found at '{}'".format(cfg['training']['resume']))
            log('no saved model found')

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    i_train_iter = start_iter
    display('Training from {}th iteration\n'.format(i_train_iter))
    while i_train_iter < cfg['training']['train_iters']:
        i_batch_idx = 0
        train_iter_start_time = time.time()
        averageLoss = 0

        # training
        for (images, labels) in trainloader:
            start_ts = time.time()
            scheduler.step()
            model.train()
            images = images.to(device)
            labels = labels.to(device)
            # mean = images[0]

            soft_loss = -1
            mediate_average_loss = -1

            optimizer.zero_grad()
            if cfg['model']['arch'] == 'unet3dreg' or cfg['model']['arch'] == 'unet3d':
                outputs = model(images)
            else:
                outputs, myconv1_copy, myconv3_copy, myup2_copy, myup1_copy = model(images)

            if cfg['training'].get('task', 'regression') == 'regression':
                loss = nn.L1Loss()
                hard_loss = loss(outputs, labels)
            else:
                hard_loss = loss_fn(input=outputs, target=labels, weight=weight,
                                    size_average=cfg['training']['loss']['size_average'])

            if cfg['training'].get('fed_by_teacher', False):
                # Setup Teacher Model
                model_file_name = cfg['training'].get('pretrained_model', None)
                model_name = {'arch': model_file_name.split('/')[-1].split('_')[0]}
                teacher_model = get_model(model_name, n_classes)
                pretrainedModel = torch.load(cfg['training'].get('pretrained_model', None))
                teacher_state = convert_state_dict(pretrainedModel["model_state"])  # maybe in this way it can take multiple images???
                teacher_model.load_state_dict(teacher_state)
                teacher_model.eval()
                teacher_model.to(device)

                outputs_teacher, conv1_copy, conv3_copy, up2_copy, up1_copy = teacher_model(images)
                outputs_teacher = autograd.Variable(outputs_teacher, requires_grad=False)
                conv1_copy = autograd.Variable(conv1_copy, requires_grad=False)
                conv3_copy = autograd.Variable(conv3_copy, requires_grad=False)
                up2_copy = autograd.Variable(up2_copy, requires_grad=False)
                up1_copy = autograd.Variable(up1_copy, requires_grad=False)

                soft_loss = loss(outputs, outputs_teacher)
                # loss_hard_soft = 0.8 * hard_loss + 0.1 * soft_loss
                loss_hard_soft = hard_loss + 0.1 * soft_loss
                if cfg['training'].get('fed_by_intermediate', False):
                    mediate1_loss = loss(myconv1_copy, conv1_copy)
                    mediate2_loss = loss(myconv3_copy, conv3_copy)
                    mediate3_loss = loss(myup2_copy, up2_copy)
                    mediate4_loss = loss(myup1_copy, up1_copy)
                    mediate_average_loss = (mediate1_loss + mediate2_loss + mediate3_loss + mediate4_loss) / 4
                    log('mediate1_loss: {}, mediate2_loss: {}, mediate3_loss: {}, mediate4_loss: {}'
                        .format(mediate1_loss, mediate2_loss, mediate3_loss, mediate4_loss))
                    loss = loss_hard_soft + 0.1 * mediate_average_loss
                else:
                    loss = 0.9 * hard_loss + 0.1 * soft_loss
            elif cfg['training'].get('fed_by_intermediate', False):
                # Setup Teacher Model
                model_file_name = cfg['training'].get('pretrained_model', None)
                model_name = {'arch': model_file_name.split('/')[-1].split('_')[0]}
                teacher_model = get_model(model_name, n_classes)
                pretrainedModel = torch.load(cfg['training'].get('pretrained_model', None))
                teacher_state = convert_state_dict(pretrainedModel["model_state"])  # maybe in this way it can take multiple images???
                teacher_model.load_state_dict(teacher_state)
                teacher_model.eval()
                teacher_model.to(device)

                outputs_teacher, conv1_copy, conv3_copy, up2_copy, up1_copy = teacher_model(images)
                outputs_teacher = autograd.Variable(outputs_teacher, requires_grad=False)
                conv1_copy = autograd.Variable(conv1_copy, requires_grad=False)
                conv3_copy = autograd.Variable(conv3_copy, requires_grad=False)
                up2_copy = autograd.Variable(up2_copy, requires_grad=False)
                up1_copy = autograd.Variable(up1_copy, requires_grad=False)

                mediate1_loss = loss(myconv1_copy, conv1_copy)
                mediate2_loss = loss(myconv3_copy, conv3_copy)
                mediate3_loss = loss(myup2_copy, up2_copy)
                mediate4_loss = loss(myup1_copy, up1_copy)
                mediate_average_loss = (mediate1_loss + mediate2_loss + mediate3_loss + mediate4_loss) / 4
                log('mediate1_loss: {}, mediate2_loss: {}, mediate3_loss: {}, mediate4_loss: {}'
                    .format(mediate1_loss, mediate2_loss, mediate3_loss, mediate4_loss))
                loss = 0.9 * hard_loss + 0.1 * mediate_average_loss
            else:
                loss = hard_loss

            log('==> hard loss: {} soft loss: {} mediate loss: {}'.format(hard_loss, soft_loss, mediate_average_loss))
            averageLoss += loss
            loss.backward()
            optimizer.step()

            time_meter.update(time.time() - start_ts)

            print_per_batch_check = True if cfg['training']['print_interval_per_batch'] \
                else i_batch_idx + 1 == len(trainloader)
            if (i_train_iter + 1) % cfg['training']['print_interval'] == 0 and print_per_batch_check:
                fmt_str = "Iter [{:d}/{:d}] Loss: {:.4f} Time/Image: {:.4f}"
                print_str = fmt_str.format(i_train_iter + 1,
                                           cfg['training']['train_iters'],
                                           loss.item(),
                                           time_meter.avg / cfg['training']['batch_size'])
                display(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), i_train_iter + 1)
                time_meter.reset()

            i_batch_idx += 1

        time_for_one_iteration = time.time() - train_iter_start_time
        display('EntireTime for {}th training iteration: {} EntireTime/Image: {}'.format(
            i_train_iter + 1, time_converter(time_for_one_iteration),
            time_converter(time_for_one_iteration / (len(trainloader) * cfg['training']['batch_size']))))

        averageLoss /= (len(trainloader) * cfg['training']['batch_size'])

        # validation
        validation_check = (i_train_iter + 1) % cfg['training']['val_interval'] == 0 or \
                           (i_train_iter + 1) == cfg['training']['train_iters']
        if not validation_check:
            print('no validation check')
        else:
            '''
            This IF-CHECK is used to update the best model
            '''
            log('Validation: average loss for current iteration is: {}'.format(averageLoss))
            if min_loss is None:
                min_loss = averageLoss
            if averageLoss <= min_loss:
                min_loss = averageLoss
                state = {
                    "epoch": i_train_iter + 1,
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "scheduler_state": scheduler.state_dict(),
                    "min_loss": min_loss,
                }
                save_path = os.path.join(os.getcwd(), writer.file_writer.get_logdir(),
                                         "{}_{}_model_best.pkl".format(cfg['model']['arch'],
                                                                       cfg['data']['dataset']))
                print('save_path is: ' + save_path)
                # with open('/home/heng/Research/isbi/log_final_experiment.txt', 'a') as f:  # to change!!!!!
                #     id = cfg['id']
                #     f.write(str(id) + ':' + save_path + '\n')
                torch.save(state, save_path)

            # if score["Mean IoU : \t"] >= best_iou:
            #     best_iou = score["Mean IoU : \t"]
            #     state = {
            #         "epoch": i_train_iter + 1,
            #         "model_state": model.state_dict(),
            #         "optimizer_state": optimizer.state_dict(),
            #         "scheduler_state": scheduler.state_dict(),
            #         "best_iou": best_iou,
            #     }
            #     save_path = os.path.join(writer.file_writer.get_logdir(),
            #                              "{}_{}_best_model.pkl".format(cfg['model']['arch'],
            #                                                            cfg['data']['dataset']))
            #     torch.save(state, save_path)
            # model_count += 1

        i_train_iter += 1

    with open('/home/heng/Research/isbi/log_final_experiment_flyJanelia.txt', 'a') as f:  # to change!!!!!
        id = cfg['id']
        f.write(str(id) + ':' + save_path + '\n')
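# Illustrative sketch of how the loop above combines its loss terms when both `fed_by_teacher`
# and `fed_by_intermediate` are enabled: an L1 "hard" loss against the labels, an L1 "soft"
# loss against the teacher outputs, and the mean of four L1 losses between matching
# student/teacher feature maps, weighted 1.0 / 0.1 / 0.1 as in the code above. The function
# name and the random stand-in tensors in the usage comment are assumptions for this example.
import torch
import torch.nn as nn


def distillation_loss_sketch(student_out, teacher_out, labels,
                             student_feats, teacher_feats,
                             soft_w=0.1, mediate_w=0.1):
    l1 = nn.L1Loss()
    hard_loss = l1(student_out, labels)                      # supervised term
    soft_loss = l1(student_out, teacher_out.detach())        # match teacher predictions
    mediate_losses = [l1(s, t.detach()) for s, t in zip(student_feats, teacher_feats)]
    mediate_average_loss = sum(mediate_losses) / len(mediate_losses)
    return hard_loss + soft_w * soft_loss + mediate_w * mediate_average_loss

# e.g. with dummy data:
# out = torch.rand(2, 1, 8, 8); teacher = torch.rand_like(out); y = torch.rand_like(out)
# feats_s = [torch.rand(2, 4, 8, 8) for _ in range(4)]
# feats_t = [torch.rand(2, 4, 8, 8) for _ in range(4)]
# print(distillation_loss_sketch(out, teacher, y, feats_s, feats_t))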
def test(cfg, logger, run_id):
    # Setup Augmentations
    augmentations = cfg.test.augments
    logger.info(f'using augments: {augmentations}')
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg.data.dataloader)
    data_path = cfg.data.path
    data_loader = data_loader(
        data_path,
        data_format=cfg.data.format,
        norm=cfg.data.norm,
        split=cfg.test.dataset,
        split_root=cfg.data.split,
        log=cfg.data.log,
        augments=data_aug,
        logger=logger,
        ENL=cfg.data.ENL,
    )

    run_id = osp.join(run_id, cfg.test.dataset)
    os.mkdir(run_id)
    logger.info("data path: {}".format(data_path))
    logger.info(f'num of {cfg.test.dataset} set samples: {len(data_loader)}')

    loader = data.DataLoader(
        data_loader,
        batch_size=cfg.test.batch_size,
        num_workers=cfg.test.n_workers,
        shuffle=False,
        persistent_workers=True,
        drop_last=False,
    )

    # Setup Model
    device = f'cuda:{cfg.gpu[0]}'
    model = get_model(cfg.model).to(device)
    input_size = (cfg.model.in_channels, 512, 512)
    logger.info(f'using model: {cfg.model.arch}')
    model = torch.nn.DataParallel(model, device_ids=cfg.gpu)

    # load model params
    if osp.isfile(cfg.test.pth):
        logger.info("Loading model from checkpoint '{}'".format(cfg.test.pth))
        # load model state
        checkpoint = torch.load(cfg.test.pth)
        model.load_state_dict(checkpoint["model_state"])
    else:
        raise FileNotFoundError(f'{cfg.test.pth} file not found')

    # Setup Metrics
    running_metrics_val = runningScore(2)
    running_metrics_train = runningScore(2)
    metrics = runningScore(2)
    test_psnr_meter = averageMeter()
    test_ssim_meter = averageMeter()

    img_cnt = 0
    data_range = 255
    if cfg.data.log:
        data_range = np.log(data_range)

    # test
    model.eval()
    with torch.no_grad():
        for clean, noisy, files_path in loader:
            noisy = noisy.to(device, dtype=torch.float32)
            noisy_denoised = model(noisy)

            psnr = []
            ssim = []
            if cfg.data.simulate:
                clean = clean.to(device, dtype=torch.float32)
                for ii in range(clean.shape[0]):
                    psnr.append(piq.psnr(noisy_denoised[ii, ...], clean[ii, ...], data_range=data_range).cpu())
                    ssim.append(piq.ssim(noisy_denoised[ii, ...], clean[ii, ...], data_range=data_range).cpu())
                test_psnr_meter.update(np.array(psnr).mean(), n=clean.shape[0])
                test_ssim_meter.update(np.array(ssim).mean(), n=clean.shape[0])

            noisy = data_loader.Hoekman_recover_to_C3(noisy)
            clean = data_loader.Hoekman_recover_to_C3(clean)
            noisy_denoised = data_loader.Hoekman_recover_to_C3(noisy_denoised)

            # save images
            for ii in range(clean.shape[0]):
                file_path = files_path[ii][29:]
                file_path = file_path.replace(r'/', '_')
                file_ori = noisy[ii, ...]
                file_clean = clean[ii, ...]
                file_denoise = noisy_denoised[ii, ...]

                print('clean')
                pauli_clean = (psr.rgb_by_c3(file_clean, 'sinclair', is_print=True) * 255).astype(np.uint8)
                print('noisy')
                pauli_ori = (psr.rgb_by_c3(file_ori, 'sinclair', is_print=True) * 255).astype(np.uint8)
                print('denoise')
                pauli_denoise = (psr.rgb_by_c3(file_denoise, 'sinclair', is_print=True) * 255).astype(np.uint8)

                path_ori = osp.join(run_id, file_path)
                path_denoise = osp.join(run_id, file_path)
                path_clean = osp.join(run_id, file_path)
                if cfg.data.simulate:
                    metric_str = f'_{psnr[ii].item():.3f}_{ssim[ii].item():.3f}'
                    path_ori += metric_str
                    path_denoise += metric_str
                    path_clean += metric_str
                path_ori += '-ori.png'
                path_denoise += '-denoise.png'
                path_clean += '-clean.png'
                cv2.imwrite(path_ori, pauli_ori)
                cv2.imwrite(path_denoise, pauli_denoise)
                cv2.imwrite(path_clean, pauli_clean)

    if cfg.data.simulate:
        logger.info(f'overall psnr: {test_psnr_meter.avg}, ssim: {test_ssim_meter.avg}')
    logger.info('\ndone')
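# Minimal sketch of the per-image quality metrics used in the test loop above, assuming the
# `piq` package is installed. It evaluates PSNR/SSIM on batched 4-D tensors in [0, 1] with
# data_range=1.0, whereas the loop above works in the data's own range (255, or log(255)
# when cfg.data.log is set). The tensor names and shapes here are placeholders only.
import torch
import piq

denoised_example = torch.rand(4, 1, 64, 64)   # stand-in for the model output
clean_example = torch.rand(4, 1, 64, 64)      # stand-in for the reference image
psnr_per_image = [piq.psnr(denoised_example[i:i + 1], clean_example[i:i + 1], data_range=1.0)
                  for i in range(clean_example.shape[0])]
ssim_per_image = [piq.ssim(denoised_example[i:i + 1], clean_example[i:i + 1], data_range=1.0)
                  for i in range(clean_example.shape[0])]
print(torch.stack(psnr_per_image).mean().item(), torch.stack(ssim_per_image).mean().item())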
def train(cfg, writer, logger):
    # Setup seeds
    torch.manual_seed(cfg.get("seed", 1337))
    torch.cuda.manual_seed(cfg.get("seed", 1337))
    np.random.seed(cfg.get("seed", 1337))
    random.seed(cfg.get("seed", 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    augmentations = cfg["training"].get("augmentations", None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg["data"]["dataset"])
    data_path = cfg["data"]["path"]

    t_loader = data_loader(
        data_path,
        sbd_path=cfg["data"]["sbd_path"],
        is_transform=True,
        split=cfg["data"]["train_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
        augmentations=data_aug,
    )
    v_loader = data_loader(
        data_path,
        sbd_path=cfg["data"]["sbd_path"],
        is_transform=True,
        split=cfg["data"]["val_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
    )

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(
        t_loader,
        batch_size=cfg["training"]["batch_size"],
        num_workers=cfg["training"]["n_workers"],
        shuffle=True,
    )
    valloader = data.DataLoader(
        v_loader,
        batch_size=cfg["training"]["batch_size"],
        num_workers=cfg["training"]["n_workers"],
    )

    # Setup Metrics
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    model = get_model(cfg["model"], n_classes).to(device)
    model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {k: v for k, v in cfg["training"]["optimizer"].items() if k != "name"}
    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"])
    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg["training"]["resume"] is not None:
        if os.path.isfile(cfg["training"]["resume"]):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(cfg["training"]["resume"])
            )
            checkpoint = torch.load(cfg["training"]["resume"])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info(
                "Loaded checkpoint '{}' (iter {})".format(cfg["training"]["resume"], checkpoint["epoch"])
            )
        else:
            logger.info("No checkpoint found at '{}'".format(cfg["training"]["resume"]))

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    i = start_iter
    flag = True

    while i <= cfg["training"]["train_iters"] and flag:
        for (images, labels) in trainloader:
            i += 1
            start_ts = time.time()
            scheduler.step()
            model.train()
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            loss = loss_fn(input=outputs, target=labels)
            loss.backward()
            optimizer.step()

            time_meter.update(time.time() - start_ts)

            if (i + 1) % cfg["training"]["print_interval"] == 0:
                fmt_str = "Iter [{:d}/{:d}] Loss: {:.4f} Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    i + 1,
                    cfg["training"]["train_iters"],
                    loss.item(),
                    time_meter.avg / cfg["training"]["batch_size"],
                )
                print(print_str)
                logger.info(print_str)
                writer.add_scalar("loss/train_loss", loss.item(), i + 1)
                time_meter.reset()

            if (i + 1) % cfg["training"]["val_interval"] == 0 or (i + 1) == cfg["training"]["train_iters"]:
                model.eval()
                with torch.no_grad():
                    for i_val, (images_val, labels_val) in tqdm(enumerate(valloader)):
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)

                        outputs = model(images_val)
                        val_loss = loss_fn(input=outputs, target=labels_val)

                        pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()

                        running_metrics_val.update(gt, pred)
                        val_loss_meter.update(val_loss.item())

                writer.add_scalar("loss/val_loss", val_loss_meter.avg, i + 1)
                logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))

                score, class_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print(k, v)
                    logger.info("{}: {}".format(k, v))
                    writer.add_scalar("val_metrics/{}".format(k), v, i + 1)

                for k, v in class_iou.items():
                    logger.info("{}: {}".format(k, v))
                    writer.add_scalar("val_metrics/cls_{}".format(k), v, i + 1)

                val_loss_meter.reset()
                running_metrics_val.reset()

                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg["model"]["arch"], cfg["data"]["dataset"]),
                    )
                    torch.save(state, save_path)

            if (i + 1) == cfg["training"]["train_iters"]:
                flag = False
                break
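# Minimal sketch of the mean-IoU bookkeeping that `runningScore` performs for the validation
# block above (assumption: it accumulates an n_classes x n_classes confusion matrix from the
# `gt`/`pred` label maps and derives per-class IoU from it). This standalone version uses
# numpy only; the function names are illustrative, not the project's API.
import numpy as np


def update_confusion(conf, gt, pred, n_classes):
    # count (gt, pred) label pairs for all valid pixels
    mask = (gt >= 0) & (gt < n_classes)
    idx = n_classes * gt[mask].astype(int) + pred[mask]
    conf += np.bincount(idx, minlength=n_classes ** 2).reshape(n_classes, n_classes)
    return conf


def mean_iou(conf):
    inter = np.diag(conf)
    union = conf.sum(axis=0) + conf.sum(axis=1) - inter
    iou = inter / np.maximum(union, 1)
    return iou.mean(), iou

# e.g.: conf = np.zeros((n_classes, n_classes)); conf = update_confusion(conf, gt, pred, n_classes)
#       miou, per_class = mean_iou(conf)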
def train(cfg, writer, logger, args):
    # Setup seeds
    torch.manual_seed(cfg.get('seed', RNG_SEED))
    torch.cuda.manual_seed(cfg.get('seed', RNG_SEED))
    np.random.seed(cfg.get('seed', RNG_SEED))
    random.seed(cfg.get('seed', RNG_SEED))

    # Setup device
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device = torch.device(args.device)

    # Setup Augmentations
    # augmentations = cfg['training'].get('augmentations', None)
    if cfg['data']['dataset'] in ['cityscapes']:
        augmentations = cfg['training'].get('augmentations',
                                            {'brightness': 63. / 255.,
                                             'saturation': 0.5,
                                             'contrast': 0.8,
                                             'hflip': 0.5,
                                             'rotate': 10,
                                             'rscalecropsquare': 704,  # 640, # 672, # 704,
                                             })
    elif cfg['data']['dataset'] in ['drive']:
        augmentations = cfg['training'].get('augmentations',
                                            {'brightness': 63. / 255.,
                                             'saturation': 0.5,
                                             'contrast': 0.8,
                                             'hflip': 0.5,
                                             'rotate': 180,
                                             'rscalecropsquare': 576,
                                             })
        # augmentations = cfg['training'].get('augmentations',
        #                                     {'rotate': 10, 'hflip': 0.5, 'rscalecrop': 512, 'gaussian': 0.5})
    else:
        augmentations = cfg['training'].get('augmentations', {'rotate': 10, 'hflip': 0.5})
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg['data']['dataset'])
    data_path = cfg['data']['path']

    t_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg['data']['train_split'],
        img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),
        augmentations=data_aug)
    v_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg['data']['val_split'],
        img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']))

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg['training']['batch_size'],
                                  num_workers=cfg['training']['n_workers'],
                                  shuffle=True)
    valloader = data.DataLoader(v_loader,
                                batch_size=cfg['training']['batch_size'],
                                num_workers=cfg['training']['n_workers'])

    # Setup Metrics
    running_metrics_val = runningScore(n_classes, cfg['data']['void_class'] > 0)

    # Setup Model
    print('trying device {}'.format(device))
    model = get_model(cfg['model'], n_classes, args)  # .to(device)
    if cfg['model']['arch'] not in ['unetvgg16', 'unetvgg16gn', 'druvgg16', 'unetresnet50', 'unetresnet50bn',
                                    'druresnet50', 'druresnet50bn', 'druresnet50syncedbn']:
        model.apply(weights_init)
    else:
        init_model(model)
    model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
    # if cfg['model']['arch'] in ['druresnet50syncedbn']:
    #     print('using synchronized batch normalization')
    #     time.sleep(5)
    #     patch_replication_callback(model)
    model = model.cuda()
    # model = torch.nn.DataParallel(model, device_ids=(3, 2))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {k: v for k, v in cfg['training']['optimizer'].items() if k != 'name'}
    if cfg['model']['arch'] in ['unetvgg16', 'unetvgg16gn', 'druvgg16', 'druresnet50', 'druresnet50bn',
                                'druresnet50syncedbn']:
        # pretrained backbone (paramGroup1) gets a 10x smaller learning rate than the rest
        optimizer = optimizer_cls([
            {'params': model.module.paramGroup1.parameters(), 'lr': optimizer_params['lr'] / 10},
            {'params': model.module.paramGroup2.parameters()},
        ], **optimizer_params)
    else:
        optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.warning(f"Model parameters in total: {sum(p.numel() for p in model.parameters())}")
    logger.warning(f"Trainable parameters in total: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule'])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg['training']['resume'] is not None:
        if os.path.isfile(cfg['training']['resume']):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(cfg['training']['resume'])
            )
            checkpoint = torch.load(cfg['training']['resume'])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info(
                "Loaded checkpoint '{}' (iter {})".format(cfg['training']['resume'], checkpoint["epoch"])
            )
        else:
            logger.info("No checkpoint found at '{}'".format(cfg['training']['resume']))

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    i = start_iter
    flag = True

    # zero out the void class in the loss weights, if one is configured
    weight = torch.ones(n_classes)
    if cfg['data'].get('void_class'):
        if cfg['data'].get('void_class') >= 0:
            weight[cfg['data'].get('void_class')] = 0.
    weight = weight.to(device)
    logger.info("Set the prediction weights as {}".format(weight))

    while i <= cfg['training']['train_iters'] and flag:
        for (images, labels) in trainloader:
            i += 1
            start_ts = time.time()
            scheduler.step()
            model.train()
            # for param_group in optimizer.param_groups:
            #     print(param_group['lr'])
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            if cfg['model']['arch'] in ['reclast']:
                h0 = torch.ones([images.shape[0], args.hidden_size, images.shape[2], images.shape[3]],
                                dtype=torch.float32)
                h0 = h0.to(device)
                outputs = model(images, h0)
            elif cfg['model']['arch'] in ['recmid']:
                W, H = images.shape[2], images.shape[3]
                w = int(np.floor(np.floor(np.floor(W / 2) / 2) / 2) / 2)
                h = int(np.floor(np.floor(np.floor(H / 2) / 2) / 2) / 2)
                h0 = torch.ones([images.shape[0], args.hidden_size, w, h], dtype=torch.float32)
                h0 = h0.to(device)
                outputs = model(images, h0)
            elif cfg['model']['arch'] in ['dru', 'sru']:
                W, H = images.shape[2], images.shape[3]
                w = int(np.floor(np.floor(np.floor(W / 2) / 2) / 2) / 2)
                h = int(np.floor(np.floor(np.floor(H / 2) / 2) / 2) / 2)
                h0 = torch.ones([images.shape[0], args.hidden_size, w, h], dtype=torch.float32)
                h0 = h0.to(device)
                s0 = torch.ones([images.shape[0], n_classes, W, H], dtype=torch.float32)
                s0 = s0.to(device)
                outputs = model(images, h0, s0)
            elif cfg['model']['arch'] in ['druvgg16', 'druresnet50', 'druresnet50bn', 'druresnet50syncedbn']:
                W, H = images.shape[2], images.shape[3]
                w, h = int(W / 2 ** 4), int(H / 2 ** 4)
                if cfg['model']['arch'] in ['druresnet50', 'druresnet50bn', 'druresnet50syncedbn']:
                    w, h = int(W / 2 ** 5), int(H / 2 ** 5)
                h0 = torch.ones([images.shape[0], args.hidden_size, w, h], dtype=torch.float32, device=device)
                s0 = torch.zeros([images.shape[0], n_classes, W, H], dtype=torch.float32, device=device)
                outputs = model(images, h0, s0)
            else:
                outputs = model(images)

            loss = loss_fn(input=outputs, target=labels, weight=weight, bkargs=args)

            loss.backward()
            # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
            # if use_grad_clip(cfg['model']['arch']):
            # # if cfg['model']['arch'] in ['rcnn', 'rcnn2', 'rcnn3']:
            # if use_grad_clip(cfg['model']['arch']):
            nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()

            time_meter.update(time.time() - start_ts)

            if (i + 1) % cfg['training']['print_interval'] == 0:
                fmt_str = "Iter [{:d}/{:d}] Loss: {:.4f} Time/Image: {:.4f}"
                print_str = fmt_str.format(i + 1,
                                           cfg['training']['train_iters'],
                                           loss.item(),
                                           time_meter.avg / cfg['training']['batch_size'])
                # print(print_str)
                logger.info(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), i + 1)
                time_meter.reset()

            if (i + 1) % cfg['training']['val_interval'] == 0 or \
                    (i + 1) == cfg['training']['train_iters']:
                torch.backends.cudnn.benchmark = False
                model.eval()
                with torch.no_grad():
                    for i_val, (images_val, labels_val) in tqdm(enumerate(valloader)):
                        if args.benchmark:
                            if i_val > 10:
                                break
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)

                        if cfg['model']['arch'] in ['reclast']:
                            h0 = torch.ones([images_val.shape[0], args.hidden_size,
                                             images_val.shape[2], images_val.shape[3]], dtype=torch.float32)
                            h0 = h0.to(device)
                            outputs = model(images_val, h0)
                        elif cfg['model']['arch'] in ['recmid']:
                            W, H = images_val.shape[2], images_val.shape[3]
                            w = int(np.floor(np.floor(np.floor(W / 2) / 2) / 2) / 2)
                            h = int(np.floor(np.floor(np.floor(H / 2) / 2) / 2) / 2)
                            h0 = torch.ones([images_val.shape[0], args.hidden_size, w, h], dtype=torch.float32)
                            h0 = h0.to(device)
                            outputs = model(images_val, h0)
                        elif cfg['model']['arch'] in ['dru', 'sru']:
                            W, H = images_val.shape[2], images_val.shape[3]
                            w = int(np.floor(np.floor(np.floor(W / 2) / 2) / 2) / 2)
                            h = int(np.floor(np.floor(np.floor(H / 2) / 2) / 2) / 2)
                            h0 = torch.ones([images_val.shape[0], args.hidden_size, w, h], dtype=torch.float32)
                            h0 = h0.to(device)
                            s0 = torch.ones([images_val.shape[0], n_classes, W, H], dtype=torch.float32)
                            s0 = s0.to(device)
                            outputs = model(images_val, h0, s0)
                        elif cfg['model']['arch'] in ['druvgg16', 'druresnet50', 'druresnet50bn',
                                                      'druresnet50syncedbn']:
                            W, H = images_val.shape[2], images_val.shape[3]
                            w, h = int(W / 2 ** 4), int(H / 2 ** 4)
                            if cfg['model']['arch'] in ['druresnet50', 'druresnet50bn', 'druresnet50syncedbn']:
                                w, h = int(W / 2 ** 5), int(H / 2 ** 5)
                            h0 = torch.ones([images_val.shape[0], args.hidden_size, w, h], dtype=torch.float32)
                            h0 = h0.to(device)
                            s0 = torch.zeros([images_val.shape[0], n_classes, W, H], dtype=torch.float32)
                            s0 = s0.to(device)
                            outputs = model(images_val, h0, s0)
                        else:
                            outputs = model(images_val)

                        val_loss = loss_fn(input=outputs, target=labels_val, bkargs=args)

                        if cfg['training']['loss']['name'] in ['multi_step_cross_entropy']:
                            pred = outputs[-1].data.max(1)[1].cpu().numpy()
                        else:
                            pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()
                        logger.debug('pred shape: {}\t ground-truth shape: {}'.format(pred.shape, gt.shape))
                        # IPython.embed()
                        running_metrics_val.update(gt, pred)
                        val_loss_meter.update(val_loss.item())
                # assert i_val > 0, "Validation dataset is empty for no reason."

                torch.backends.cudnn.benchmark = True

                writer.add_scalar('loss/val_loss', val_loss_meter.avg, i + 1)
                logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))
                # IPython.embed()

                score, class_iou, _ = running_metrics_val.get_scores()
                for k, v in score.items():
                    # print(k, v)
                    logger.info('{}: {}'.format(k, v))
                    writer.add_scalar('val_metrics/{}'.format(k), v, i + 1)

                for k, v in class_iou.items():
                    logger.info('{}: {}'.format(k, v))
                    writer.add_scalar('val_metrics/cls_{}'.format(k), v, i + 1)

                val_loss_meter.reset()
                running_metrics_val.reset()

                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(writer.file_writer.get_logdir(), best_model_path(cfg))
                    torch.save(state, save_path)

            if (i + 1) == cfg['training']['train_iters']:
                flag = False
                save_path = os.path.join(writer.file_writer.get_logdir(),
                                         "{}_{}_final_model.pkl".format(cfg['model']['arch'],
                                                                        cfg['data']['dataset']))
                torch.save(state, save_path)
                break
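# Illustrative sketch of the recurrent-state initialisation used in the loop above: the hidden
# state h0 is sized to the encoder's bottleneck resolution (input / 2**4 for the VGG16 variants,
# input / 2**5 for the ResNet50 variants) and the segmentation state s0 to the full input
# resolution. The function name and the example values of `hidden_size`, `n_classes`, and the
# image shape are assumptions for this sketch, not part of the original training code.
import torch


def init_recurrent_states(images, hidden_size, n_classes, downscale_power=4, device="cpu"):
    N, _, H, W = images.shape
    h, w = H // 2 ** downscale_power, W // 2 ** downscale_power
    h0 = torch.ones(N, hidden_size, h, w, dtype=torch.float32, device=device)    # bottleneck-sized hidden state
    s0 = torch.zeros(N, n_classes, H, W, dtype=torch.float32, device=device)     # full-resolution segmentation state
    return h0, s0

# e.g. for a ResNet50-backed DRU variant (stride 2**5):
# h0, s0 = init_recurrent_states(torch.rand(2, 3, 512, 512), hidden_size=32, n_classes=2, downscale_power=5)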