def main():
    """Run a trained depth-completion model on the ``test_completion`` split.

    Reads the module-level ``args``. ``args.evaluate`` must name a checkpoint
    file; the ``args`` stored in that checkpoint replace the current ones,
    except for ``data_folder`` and ``val`` which are carried over from the
    current namespace. Progress is printed to stdout and ``iterate`` is
    called once over the whole split.

    Returns:
        None. Returns early, after printing a message, when no checkpoint
        was given or when the given path is not a file.
    """
    global args

    # This entry point only ever tests (the original forced ``is_eval = True``
    # "for testing", 2020/02/26), and the final ``iterate`` call needs
    # ``checkpoint['epoch']``. Without ``args.evaluate`` the original built
    # the model and loaders and then failed subscripting ``None``; stop here
    # with a message instead.
    if not args.evaluate:
        print("No checkpoint given to evaluate. Exit...")
        return
    if not os.path.isfile(args.evaluate):
        print("No model found at '{}'".format(args.evaluate))
        return

    # Alias (not a copy) of the current namespace, kept so that its fields
    # are still reachable after ``args`` is rebound below.
    args_new = args
    print("=> loading checkpoint '{}' ... ".format(args.evaluate), end='')
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    checkpoint = torch.load(args.evaluate, map_location=device)
    args = checkpoint['args']
    args.data_folder = args_new.data_folder
    args.val = args_new.val
    print("Completed.")

    print("=> creating model and optimizer ... ", end='')
    model = DepthCompletionNet(args).to(device)
    model_named_params = [
        p for _, p in model.named_parameters() if p.requires_grad
    ]
    optimizer = torch.optim.Adam(model_named_params,
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    print("completed.")

    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    print("=> checkpoint state loaded.")

    model = torch.nn.DataParallel(model)

    # Data loading code
    print("=> creating data loaders ... ")
    val_dataset = KittiDepth('test_completion', args)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,  # one sample per step for testing
        shuffle=False,
        num_workers=2,
        pin_memory=True)
    print("\t==> val_loader size:{}".format(len(val_loader)))

    # create backups and results folder
    logger = helper.logger(args)
    logger.best_result = checkpoint['best_result']
    print("=> logger created.")

    print("=> starting model test ...")
    iterate("test_completion", args, val_loader, model, None, logger,
            checkpoint['epoch'])
def get_kitti_dataloader(mode, dataset_name, setname, args):
    """
    Build a dataset object and a DataLoader wrapped around it.

    :param mode: 'train' or 'eval'; any other value raises ValueError
                 (after the dataset has already been constructed)
    :param dataset_name: 'ours', 'ours_20190318', 'vkitti', 'nuscenes' or
                         'kitti'; every other value falls back to KittiDepth
    :param setname: forwarded to VKittiDataset, KittiDataset and KittiDepth;
                    the 'ours*' and 'nuscenes' datasets use a fixed set name
    :param args: forwarded to the dataset; ``batch_size`` and ``workers``
                 are read from it in train mode
    :return: dataset, dataloader
    """
    dataset_dir = get_dataset_dir(dataset_name)

    if dataset_name in ('ours', 'ours_20190318'):
        # Both variants share the class and differ only in the fixed set name.
        if dataset_name == 'ours':
            fixed_setname = "f_c_1216_352"
        else:
            fixed_setname = "f_c_1216_352_20190318"
        dataset = OurDataset(base_dir=dataset_dir,
                             mode=mode,
                             setname=fixed_setname,
                             args=args)
    elif dataset_name == 'vkitti':
        dataset = VKittiDataset(base_dir=dataset_dir,
                                mode=mode,
                                setname=setname,
                                args=args)
    elif dataset_name == 'nuscenes':
        dataset = NuScenesDataset(base_dir=dataset_dir,
                                  mode=mode,
                                  setname="f_c_1216_352",
                                  args=args)
    elif dataset_name == 'kitti':
        dataset = KittiDataset(base_dir=dataset_dir,
                               mode=mode,
                               setname=setname,
                               args=args)
    else:
        dataset = KittiDepth(setname, args)

    if mode == 'train':
        loader_options = {
            'batch_size': args.batch_size,
            'shuffle': True,
            'num_workers': args.workers,
            'pin_memory': True,
            'sampler': None,
        }
    elif mode == 'eval':
        # batch size is fixed to 1 for validation
        loader_options = {
            'batch_size': 1,
            'shuffle': False,
            'num_workers': 2,
            'pin_memory': True,
        }
    else:
        raise ValueError("Unrecognized mode " + str(mode))

    dataloader = torch.utils.data.DataLoader(dataset, **loader_options)
    return dataset, dataloader
def main():
    """Entry point of the test script: evaluate, or train with the test set.

    Reads the module-level ``args``. The function refuses to run (prints a
    message and returns) when ``args.partial_train == 'yes'`` or when
    ``args.test != "yes"``.

    * ``args.evaluate`` set: load that checkpoint, take ``args`` from it
      (keeping ``data_folder``, ``val``, ``save_images`` and ``result`` from
      the current namespace), run ``iterate`` once and return.
    * ``args.resume`` set: load that checkpoint and continue from the epoch
      after the stored one.
    * neither set: train from ``args.start_epoch``.

    During training a checkpoint is saved after every epoch and the
    per-epoch ``checkpoint-<epoch>.pth.tar`` file is removed again.

    Returns:
        None.
    """
    global args

    if args.partial_train == 'yes':  # train on a part of the whole train set
        print(
            "Can't use partial train here. It is used only for test check. Exit..."
        )
        return

    if args.test != "yes":
        print(
            "This main should use only for testing, but test=yes wat not given. Exit..."
        )
        return

    print("Evaluating test set with main_test:")
    whole_ts = time.time()
    checkpoint = None
    is_eval = False

    # Alias (not a copy) of the command-line namespace. It is bound before
    # the branches because ``args_new.test`` is read further down on every
    # path; binding it only inside the evaluate/resume branches raised a
    # NameError on a fresh (no checkpoint) run.
    args_new = args

    if args.evaluate:  # test a finished model
        if os.path.isfile(args.evaluate):
            # end='' keeps "Completed." on the same output line
            print("=> loading finished model from '{}' ... ".format(
                args.evaluate),
                  end='')
            # NOTE: torch.load unpickles the file -- only load trusted files.
            checkpoint = torch.load(args.evaluate, map_location=device)
            args = checkpoint['args']
            args.data_folder = args_new.data_folder
            args.val = args_new.val
            args.save_images = args_new.save_images
            args.result = args_new.result
            is_eval = True
            print("Completed.")
        else:
            print("No model found at '{}'".format(args.evaluate))
            return
    elif args.resume:  # resume from a checkpoint
        if os.path.isfile(args.resume):
            print("=> loading checkpoint from '{}' ... ".format(args.resume),
                  end='')
            checkpoint = torch.load(args.resume, map_location=device)
            args.start_epoch = checkpoint['epoch'] + 1
            # args_new is args here, so these two keep the current values.
            args.data_folder = args_new.data_folder
            args.val = args_new.val
            print("Completed. Resuming from epoch {}.".format(
                checkpoint['epoch']))
        else:
            print("No checkpoint found at '{}'".format(args.resume))
            return

    print("=> creating model and optimizer ... ", end='')
    model = DepthCompletionNet(args).to(device)
    # named_parameters() yields (name, parameter) pairs; only the parameter
    # is needed, so the name is unpacked into "_".
    model_named_params = [
        p for _, p in model.named_parameters() if p.requires_grad
    ]
    optimizer = torch.optim.Adam(model_named_params,
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    print("completed.")

    if checkpoint is not None:
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> checkpoint state loaded.")

    # Wrapped after loading the state dict; checkpoints are saved from
    # ``model.module`` below, so their keys match the unwrapped model.
    model = torch.nn.DataParallel(model)

    # data loading code
    print("=> creating data loaders ... ")
    if not is_eval:
        train_dataset = KittiDepth('train', args)
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   sampler=None)
        print("\t==> train_loader size:{}".format(len(train_loader)))

    # The guard at the top only lets test == "yes" through, so the 'test'
    # split is what gets used; the 'val' branch is kept as a fallback.
    if args_new.test == "yes":
        val_dataset = KittiDepth('test', args)
        is_test = 'yes'
    else:
        val_dataset = KittiDepth('val', args)
        is_test = 'no'
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,  # set batch size to be 1 for validation
        shuffle=False,
        num_workers=2,
        pin_memory=True)
    print("\t==> val_loader size:{}".format(len(val_loader)))

    # create backups and results folder
    logger = helper.logger(args, is_test)
    if checkpoint is not None:
        logger.best_result = checkpoint['best_result']
    print("=> logger created.")

    # main code - run the NN
    if is_eval:
        print("=> starting model evaluation ...")
        iterate("val", args, val_loader, model, None, logger,
                checkpoint['epoch'])
        return

    print("=> starting model training ...")
    for epoch in range(args.start_epoch, args.epochs):
        print("=> start training epoch {}".format(epoch) +
              "/{}..".format(args.epochs))
        train_ts = time.time()
        iterate("train", args, train_loader, model, optimizer, logger,
                epoch)  # train for one epoch
        _, is_best = iterate("val", args, val_loader, model, None, logger,
                             epoch)  # evaluate on validation set
        helper.save_checkpoint({  # save checkpoint
            'epoch': epoch,
            'model': model.module.state_dict(),
            'best_result': logger.best_result,
            'optimizer': optimizer.state_dict(),
            'args': args,
        }, is_best, epoch, logger.output_directory)
        print("finish training epoch {}, time elapsed {:.2f} hours, \n".format(
            epoch, (time.time() - train_ts) / 3600))
        # delete last checkpoint because we have the best_model and we dont
        # need it
        last_checkpoint = os.path.join(
            logger.output_directory, 'checkpoint-' + str(epoch) + '.pth.tar')
        os.remove(last_checkpoint)

    print("finished model training, time elapsed {0:.2f} hours, \n".format(
        (time.time() - whole_ts) / 3600))
set_num, i + 1)) input_t = input_type data_in = '../data_new/phase_' + str( phase) + '/mini_set_' + str(set_num) pred_dir = '../data_new/phase_' + str( phase + 1) + '/mini_set_' + str( set_num) + '/predictions_tmp/NN' + str(i + 1) NN_arguments[i].data_folder = data_in NN_arguments[i].pred_dir = pred_dir NN_arguments[i].val = 'full' NN_arguments[i].use_d = 'd' in input_t NN_arguments[i].batch_size = predict_batch_size train_dataset = KittiDepth( 'val', NN_arguments[i] ) # we adjusted 'val-full' option for predicting on the train data train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=NN_arguments[i].batch_size, shuffle=False, num_workers=2, pin_memory=True) print("\t==> train_loader size:{}".format( len(train_loader))) print("=> starting predictions with args:\n {}".format( NN_arguments[i])) predict(NN_arguments[i], train_loader, models[i]) print("finished predictions\n") print(
def _seed_worker(worker_id):
    """Seed NumPy's global RNG in a DataLoader worker with the worker index."""
    np.random.seed(worker_id)


def create_data_loaders(data_path,
                        data_type='visim',
                        loader_type='val',
                        arch='',
                        sparsifier_type='uar',
                        num_samples=500,
                        modality='rgb-fd',
                        depth_divisor=1,
                        max_depth=-1,
                        max_gt_depth=-1,
                        batch_size=8,
                        workers=8):
    """Create a dataset and a DataLoader for it.

    Args:
        data_path: root folder of the data; must exist.
        data_type: 'kitti', 'visim' or 'visim_seq'.
        loader_type: 'val' or 'train'. It is also forwarded to the dataset
            (as ``split`` for kitti, ``type`` for the visim datasets).
        arch: network name; only checked for the substring 'resnet'.
        sparsifier_type: compared against ``UniformSampling.name`` and
            ``SimulatedStereo.name``; any other value means no sparsifier.
        num_samples: forwarded to the sparsifier.
        modality: forwarded to the visim datasets.
        depth_divisor: forwarded to the dataset.
        max_depth: sparsifier depth limit; a negative value means no limit.
        max_gt_depth: ground-truth depth limit for the visim datasets; a
            negative value means no limit.
        batch_size: batch size of the train loader (val always uses 1).
        workers: number of DataLoader worker processes.

    Returns:
        ``(loader, dataset)``. ``loader`` is ``None`` when ``loader_type``
        is neither 'val' nor 'train'.

    Raises:
        RuntimeError: if ``data_path`` does not exist or ``data_type`` is
            not one of the supported names.
    """
    # Data loading code
    print("=> creating data loaders ...")

    if not os.path.exists(data_path):
        raise RuntimeError('Data source does not exist:{}'.format(data_path))

    loader = None
    dataset = None

    # A negative limit means "unlimited".
    max_depth = max_depth if max_depth >= 0.0 else np.inf
    max_gt_depth = max_gt_depth if max_gt_depth >= 0.0 else np.inf

    # sparsifier is a class for generating random sparse depth input from
    # the ground truth
    sparsifier = None
    if sparsifier_type == UniformSampling.name:  # uar
        sparsifier = UniformSampling(num_samples=num_samples,
                                     max_depth=max_depth)
    elif sparsifier_type == SimulatedStereo.name:  # sim_stereo
        sparsifier = SimulatedStereo(num_samples=num_samples,
                                     max_depth=max_depth)

    if data_type == 'kitti':
        from dataloaders.kitti_loader import KittiDepth

        dataset = KittiDepth(data_path,
                             split=loader_type,
                             depth_divisor=depth_divisor)
    elif data_type == 'visim':
        from dataloaders.visim_dataloader import VISIMDataset

        dataset = VISIMDataset(data_path,
                               type=loader_type,
                               modality=modality,
                               sparsifier=sparsifier,
                               depth_divider=depth_divisor,
                               is_resnet=('resnet' in arch),
                               max_gt_depth=max_gt_depth)
    elif data_type == 'visim_seq':
        from dataloaders.visim_dataloader import VISIMSeqDataset

        dataset = VISIMSeqDataset(data_path,
                                  type=loader_type,
                                  modality=modality,
                                  sparsifier=sparsifier,
                                  depth_divider=depth_divisor,
                                  is_resnet=('resnet' in arch),
                                  max_gt_depth=max_gt_depth)
    else:
        raise RuntimeError(
            'data type not found. ' +
            'The dataset must be either of kitti, visim or visim_seq.')

    if loader_type == 'val':
        # set batch size to be 1 for validation
        loader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=workers,
                                             pin_memory=True)
        print("=> Val loader:{}".format(len(dataset)))
    elif loader_type == 'train':
        # worker_init_fn ensures different sampling patterns for each data
        # loading worker. A module-level function is used instead of a
        # lambda so it can be pickled when workers are spawned.
        loader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=workers,
                                             pin_memory=True,
                                             sampler=None,
                                             worker_init_fn=_seed_worker)
        print("=> Train loader:{}".format(len(dataset)))

    print("=> data loaders created.")
    return loader, dataset
def main():
    """Train DepthCompletionNet, or evaluate a stored checkpoint.

    Works on the module-level ``args``:

    * ``args.evaluate``: load that checkpoint, adopt the ``args`` stored in
      it (keeping the current ``data_folder``, ``val`` and ``result``), run
      one validation pass and return.
    * ``args.resume``: load that checkpoint and continue with the epoch
      following the stored one.
    * otherwise: train from ``args.start_epoch``.

    A checkpoint is saved after every training epoch. Returns None; a
    missing checkpoint file prints a message and returns early.
    """
    global args
    checkpoint = None
    is_eval = False
    carried_fields = ('data_folder', 'val', 'result')

    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            print("No model found at '{}'".format(args.evaluate))
            return
        cli_args = args
        print("=> loading checkpoint '{}' ... ".format(args.evaluate),
              end='')
        checkpoint = torch.load(args.evaluate, map_location=device)
        args = checkpoint['args']
        for field in carried_fields:
            setattr(args, field, getattr(cli_args, field))
        is_eval = True
        print("Completed.")
    elif args.resume:  # optionally resume from a checkpoint
        if not os.path.isfile(args.resume):
            print("No checkpoint found at '{}'".format(args.resume))
            return
        cli_args = args
        print("=> loading checkpoint '{}' ... ".format(args.resume), end='')
        checkpoint = torch.load(args.resume, map_location=device)
        args.start_epoch = checkpoint['epoch'] + 1
        # cli_args and args are the same object on this path, so this
        # re-assigns each field to its current value.
        for field in carried_fields:
            setattr(args, field, getattr(cli_args, field))
        print("Completed. Resuming from epoch {}.".format(
            checkpoint['epoch']))

    print("=> creating model and optimizer ... ", end='')
    model = DepthCompletionNet(args).to(device)
    trainable = [
        param for _name, param in model.named_parameters()
        if param.requires_grad
    ]
    optimizer = torch.optim.Adam(trainable,
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    print("completed.")

    if checkpoint is not None:
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> checkpoint state loaded.")

    model = torch.nn.DataParallel(model)

    # Data loading code
    print("=> creating data loaders ... ")
    if not is_eval:
        train_dataset = KittiDepth('train', args)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True,
            sampler=None)
        print("\t==> train_loader size:{}".format(len(train_loader)))

    val_dataset = KittiDepth('val', args)
    # validation always runs with a batch size of 1
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=2,
                                             pin_memory=True)
    print("\t==> val_loader size:{}".format(len(val_loader)))

    # create backups and results folder
    logger = helper.logger(args)
    if checkpoint is not None:
        logger.best_result = checkpoint['best_result']
    print("=> logger created.")

    if is_eval:
        print("=> starting model evaluation ...")
        _result, _is_best = iterate("val", args, val_loader, model, None,
                                    logger, checkpoint['epoch'])
        return

    # main loop
    print("=> starting main loop ...")
    for epoch in range(args.start_epoch, args.epochs):
        print("=> starting training epoch {} ..".format(epoch))
        # one pass over the training set ...
        iterate("train", args, train_loader, model, optimizer, logger, epoch)
        # ... followed by one pass over the validation set
        _result, is_best = iterate("val", args, val_loader, model, None,
                                   logger, epoch)
        state = {
            'epoch': epoch,
            'model': model.module.state_dict(),
            'best_result': logger.best_result,
            'optimizer': optimizer.state_dict(),
            'args': args,
        }
        helper.save_checkpoint(state, is_best, epoch,
                               logger.output_directory)
def main():
    """Train or evaluate a ResNet-encoder / depth-decoder pair.

    Works on the module-level ``args``. ``args.evaluate`` or ``args.resume``
    select a checkpoint file to load; a missing file prints a message and
    returns early. The "model" handed to ``iterate`` is the two-element list
    ``[encoder, decoder]``, each wrapped in ``DataParallel``.

    NOTE(review): the checkpoint is loaded, but the lines restoring the
    model, optimizer and ``logger.best_result`` from it are disabled in this
    variant, so only ``checkpoint['epoch']`` (and, when evaluating, the
    stored ``args``) is actually used. No checkpoint is saved after an epoch.
    """
    global args
    checkpoint = None
    is_eval = False

    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            print("No model found at '{}'".format(args.evaluate))
            return
        cli_args = args
        print("=> loading checkpoint '{}' ... ".format(args.evaluate),
              end='')
        checkpoint = torch.load(args.evaluate, map_location=device)
        args = checkpoint['args']
        args.data_folder = cli_args.data_folder
        args.val = cli_args.val
        is_eval = True
        print("Completed.")
    elif args.resume:  # optionally resume from a checkpoint
        if not os.path.isfile(args.resume):
            print("No checkpoint found at '{}'".format(args.resume))
            return
        cli_args = args
        print("=> loading checkpoint '{}' ... ".format(args.resume), end='')
        checkpoint = torch.load(args.resume, map_location=device)
        args.start_epoch = checkpoint['epoch'] + 1
        # cli_args is args on this path: both fields keep their values.
        args.data_folder = cli_args.data_folder
        args.val = cli_args.val
        print("Completed. Resuming from epoch {}.".format(
            checkpoint['epoch']))

    # ---------------------------------------------------------------- model
    print("=> creating model and optimizer ... ", end='')
    encoder = networks.ResnetEncoder(num_layers=18)
    encoder.to(device)
    parameters_to_train = list(encoder.parameters())

    decoder = networks.DepthDecoder(encoder.num_ch_enc)
    decoder.to(device)
    parameters_to_train = parameters_to_train + list(decoder.parameters())

    # Every parameter of both networks is optimised (no requires_grad
    # filter in this variant).
    optimizer = torch.optim.Adam(parameters_to_train,
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    model = [torch.nn.DataParallel(encoder)]
    model.append(torch.nn.DataParallel(decoder))
    print("completed.")

    # Data loading code
    print("=> creating data loaders ... ")
    if not is_eval:
        train_dataset = KittiDepth('train', args)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True,
            sampler=None)
        print("\t==> train_loader size:{}".format(len(train_loader)))

    val_dataset = KittiDepth('val', args)
    # validation runs with a batch size of 12 here (previously 1)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=12,
                                             shuffle=False,
                                             num_workers=2,
                                             pin_memory=True)
    print("\t==> val_loader size:{}".format(len(val_loader)))

    # create backups and results folder
    logger = helper.logger(args)
    print("=> logger created.")

    if is_eval:
        print("=> starting model evaluation ...")
        _result, _is_best = iterate("val", args, val_loader, model, None,
                                    logger, checkpoint['epoch'])
        return

    # main loop
    print("=> starting main loop ...")
    for epoch in range(args.start_epoch, args.epochs):
        print("=> starting training epoch {} ..".format(epoch))
        # train for one epoch, then evaluate on the validation set
        iterate("train", args, train_loader, model, optimizer, logger, epoch)
        _result, _is_best = iterate("val", args, val_loader, model, None,
                                    logger, epoch)
def main():
    """Train, validate or test an ENet / PENet model.

    Works on the module-level ``args``. The flow is:

    1. optionally load a checkpoint named by ``args.evaluate`` or
       ``args.resume``;
    2. pick the network class from ``args.network_model``,
       ``args.dilation_rate`` and whether this is an evaluation run;
    3. restore weights from the checkpoint, if one was loaded;
    4. if ``args.test`` is set, run the ``'test_completion'`` split and
       return;
    5. if evaluating, run one validation pass and return;
    6. otherwise build the optimizer and train from ``args.start_epoch`` to
       ``args.epochs``, validating and saving a checkpoint after each epoch.
    """
    global args
    checkpoint = None
    is_eval = False
    if args.evaluate:
        args_new = args  # alias of the same namespace, not a copy
        if os.path.isfile(args.evaluate):
            print("=> loading checkpoint '{}' ... ".format(args.evaluate),
                  end='')
            checkpoint = torch.load(args.evaluate, map_location=device)
            # The args stored in the checkpoint are deliberately not adopted
            # in this variant (assignment disabled):
            #args = checkpoint['args']
            args.start_epoch = checkpoint['epoch'] + 1
            # args_new is args, so these two keep the current values.
            args.data_folder = args_new.data_folder
            args.val = args_new.val
            is_eval = True
            print("Completed.")
        else:
            # NOTE(review): a missing file no longer aborts (return is
            # disabled); evaluation continues without loaded weights --
            # confirm this is intended.
            is_eval = True
            print("No model found at '{}'".format(args.evaluate))
            #return
    elif args.resume:  # optionally resume from a checkpoint
        args_new = args
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}' ... ".format(args.resume),
                  end='')
            checkpoint = torch.load(args.resume, map_location=device)
            args.start_epoch = checkpoint['epoch'] + 1
            args.data_folder = args_new.data_folder
            args.val = args_new.val
            print("Completed. Resuming from epoch {}.".format(
                checkpoint['epoch']))
        else:
            print("No checkpoint found at '{}'".format(args.resume))
            return

    print("=> creating model and optimizer ... ", end='')
    model = None
    penet_accelerated = False
    # Model selection: 'e' always uses ENet; otherwise training uses the
    # *_train classes for dilation rates 1 and 2, and evaluation uses the
    # plain classes. A dilation rate other than 1, 2 or 4 leaves model None.
    if (args.network_model == 'e'):
        model = ENet(args).to(device)
    elif (is_eval == False):
        if (args.dilation_rate == 1):
            model = PENet_C1_train(args).to(device)
        elif (args.dilation_rate == 2):
            model = PENet_C2_train(args).to(device)
        elif (args.dilation_rate == 4):
            model = PENet_C4(args).to(device)
            penet_accelerated = True
    else:
        if (args.dilation_rate == 1):
            model = PENet_C1(args).to(device)
            penet_accelerated = True
        elif (args.dilation_rate == 2):
            model = PENet_C2(args).to(device)
            penet_accelerated = True
        elif (args.dilation_rate == 4):
            model = PENet_C4(args).to(device)
            penet_accelerated = True

    if (penet_accelerated == True):
        # NOTE(review): requires_grad is assigned on the sub-module objects
        # themselves, not on their parameters -- confirm this freezes them
        # as intended (nn.Module.requires_grad_() is the documented API).
        model.encoder3.requires_grad = False
        model.encoder5.requires_grad = False
        model.encoder7.requires_grad = False

    # Parameter lists and optimizer are only filled in on the training path
    # further down.
    model_named_params = None
    model_bone_params = None
    model_new_params = None
    optimizer = None

    if checkpoint is not None:
        #print(checkpoint.keys())
        if (args.freeze_backbone == True):
            # With a frozen backbone the checkpoint is loaded into the
            # backbone sub-module only.
            model.backbone.load_state_dict(checkpoint['model'])
        else:
            # strict=False: key mismatches between checkpoint and model are
            # not treated as errors.
            model.load_state_dict(checkpoint['model'], strict=False)
        # Optimizer state is not restored (no optimizer exists yet here):
        #optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> checkpoint state loaded.")

    logger = helper.logger(args)
    if checkpoint is not None:
        logger.best_result = checkpoint['best_result']
        # Drop the reference; nothing below reads the checkpoint again.
        del checkpoint
    print("=> logger created.")

    test_dataset = None
    test_loader = None
    if (args.test):
        # Test run: one pass over 'test_completion', reported as epoch 0.
        test_dataset = KittiDepth('test_completion', args)
        test_loader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_size=1,
                                                  shuffle=False,
                                                  num_workers=1,
                                                  pin_memory=True)
        iterate("test_completion", args, test_loader, model, None, logger, 0)
        return

    val_dataset = KittiDepth('val', args)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        pin_memory=True)  # set batch size to be 1 for validation
    print("\t==> val_loader size:{}".format(len(val_loader)))

    if is_eval == True:
        # Evaluation only: switch off gradients for every parameter, run one
        # validation pass labelled with the epoch before start_epoch.
        for p in model.parameters():
            p.requires_grad = False

        result, is_best = iterate("val", args, val_loader, model, None, logger,
                                  args.start_epoch - 1)
        return

    # Optimizer construction. Three cases:
    #   * frozen backbone: backbone parameters are excluded via requires_grad;
    #   * 'pe' model: backbone parameters get lr / 10, the rest the full lr;
    #   * otherwise: a single group with all trainable parameters.
    if (args.freeze_backbone == True):
        for p in model.backbone.parameters():
            p.requires_grad = False
        model_named_params = [
            p for _, p in model.named_parameters() if p.requires_grad
        ]
        optimizer = torch.optim.Adam(model_named_params,
                                     lr=args.lr,
                                     weight_decay=args.weight_decay,
                                     betas=(0.9, 0.99))
    elif (args.network_model == 'pe'):
        model_bone_params = [
            p for _, p in model.backbone.named_parameters() if p.requires_grad
        ]
        model_new_params = [
            p for _, p in model.named_parameters() if p.requires_grad
        ]
        # Everything that is not a backbone parameter (set difference, so
        # the resulting order is not the module order).
        model_new_params = list(set(model_new_params) - set(model_bone_params))
        optimizer = torch.optim.Adam([{
            'params': model_bone_params,
            'lr': args.lr / 10
        }, {
            'params': model_new_params
        }],
                                     lr=args.lr,
                                     weight_decay=args.weight_decay,
                                     betas=(0.9, 0.99))
    else:
        model_named_params = [
            p for _, p in model.named_parameters() if p.requires_grad
        ]
        optimizer = torch.optim.Adam(model_named_params,
                                     lr=args.lr,
                                     weight_decay=args.weight_decay,
                                     betas=(0.9, 0.99))
    print("completed.")

    # From here on the network itself is reached through model.module.
    model = torch.nn.DataParallel(model)

    # Data loading code
    print("=> creating data loaders ... ")
    # is_eval is always False here (the evaluation path returned above).
    if not is_eval:
        train_dataset = KittiDepth('train', args)
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   sampler=None)
        print("\t==> train_loader size:{}".format(len(train_loader)))

    print("=> starting main loop ...")
    for epoch in range(args.start_epoch, args.epochs):
        print("=> starting training epoch {} ..".format(epoch))
        iterate("train", args, train_loader, model, optimizer, logger,
                epoch)  # train for one epoch

        # validation memory reset: gradients are switched off for the
        # validation pass ...
        for p in model.parameters():
            p.requires_grad = False
        result, is_best = iterate("val", args, val_loader, model, None, logger,
                                  epoch)  # evaluate on validation set

        # ... and switched back on afterwards, then the frozen parts are
        # re-frozen the same way as before training.
        for p in model.parameters():
            p.requires_grad = True
        if (args.freeze_backbone == True):
            for p in model.module.backbone.parameters():
                p.requires_grad = False
        if (penet_accelerated == True):
            # NOTE(review): same module-level attribute assignment as above.
            model.module.encoder3.requires_grad = False
            model.module.encoder5.requires_grad = False
            model.module.encoder7.requires_grad = False

        helper.save_checkpoint({  # save checkpoint
            'epoch': epoch,
            'model': model.module.state_dict(),
            'best_result': logger.best_result,
            'optimizer' : optimizer.state_dict(),
            'args' : args,
        }, is_best, epoch, logger.output_directory)
def main():
    """Train DepthCompletionNet (depth + intensity), or evaluate a checkpoint.

    Works on the module-level ``args``:

    * ``args.evaluate``: load that checkpoint, replace ``args`` with the
      ones stored in it, run one validation pass and return.
    * ``args.resume``: load that checkpoint and continue with the epoch
      following the stored one.
    * otherwise: train from ``args.start_epoch``.

    After every training epoch a checkpoint is saved and the depth and
    intensity RMSE / MAE of the validation pass are written to
    ``logger.writer``. Returns None; a missing checkpoint file prints a
    message and returns early.
    """
    global args
    checkpoint = None
    is_eval = False

    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            print("=> no model found at '{}'".format(args.evaluate))
            return
        print("=> loading checkpoint '{}'".format(args.evaluate))
        checkpoint = torch.load(args.evaluate)
        args = checkpoint['args']
        is_eval = True
        print("=> checkpoint loaded.")
    elif args.resume:  # optionally resume from a checkpoint
        if not os.path.isfile(args.resume):
            print("=> no checkpoint found at '{}'".format(args.resume))
            return
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch'] + 1
        print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))

    print("=> creating model and optimizer...")
    model = DepthCompletionNet(args).cuda()
    trainable = [
        param for _name, param in model.named_parameters()
        if param.requires_grad
    ]
    optimizer = torch.optim.Adam(trainable,
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    print("=> model and optimizer created.")

    if checkpoint is not None:
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> checkpoint state loaded.")

    model = torch.nn.DataParallel(model)
    print("=> model transferred to multi-GPU.")

    # Data loading code
    print("=> creating data loaders ...")
    if not is_eval:
        train_dataset = KittiDepth('train', args)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True,
            sampler=None)
    val_dataset = KittiDepth('val', args)
    # validation always runs with a batch size of 1
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=2,
                                             pin_memory=False)
    print("=> data loaders created.")

    # create backups and results folder
    logger = helper.logger(args)
    if checkpoint is not None:
        logger.best_result = checkpoint['best_result']
    print("=> logger created.")

    if is_eval:
        _depth, _intensity, _is_best = iterate("val", args, val_loader,
                                               model, None, logger,
                                               checkpoint['epoch'])
        return

    # main loop
    for epoch in range(args.start_epoch, args.epochs):
        print("=> starting training epoch {} ..".format(epoch))
        # train for one epoch
        iterate("train", args, train_loader, model, optimizer, logger, epoch)
        # evaluate on validation set
        result, result_intensity, is_best = iterate("val", args, val_loader,
                                                    model, None, logger,
                                                    epoch)

        state = {
            'epoch': epoch,
            'model': model.module.state_dict(),
            'best_result': logger.best_result,
            'optimizer': optimizer.state_dict(),
            'args': args,
        }
        helper.save_checkpoint(state, is_best, epoch,
                               logger.output_directory)

        # Per-epoch validation metrics (irmse is intentionally not logged).
        log_scalar = logger.writer.add_scalar
        log_scalar('eval/rmse_depth', result.rmse, epoch)
        log_scalar('eval/rmse_intensity', result_intensity.rmse, epoch)
        log_scalar('eval/mae_depth', result.mae, epoch)
        log_scalar('eval/mae_intensity', result_intensity.mae, epoch)
        # combined score: depth RMSE plus intensity RMSE weighted by args.wi
        log_scalar('eval/rmse_total',
                   result.rmse + args.wi * result_intensity.rmse, epoch)