def main():
    resume = True
    path = 'data/NYU_DEPTH'
    batch_size = 16
    epochs = 10000
    device = torch.device('cuda:0')
    print_every = 5
    # exp_name = 'resnet18_nodropout_new'
    exp_name = 'only_depth'
    # exp_name = 'normal_internel'
    # exp_name = 'sep'
    lr = 1e-5
    weight_decay = 0.0005
    log_dir = os.path.join('logs', exp_name)
    model_dir = os.path.join('checkpoints', exp_name)
    val_every = 16
    save_every = 16

    # tensorboard: remove the old log directory if not resuming
    if not resume:
        if os.path.exists(log_dir):
            shutil.rmtree(log_dir)
        os.makedirs(log_dir)
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    tb = SummaryWriter(log_dir)
    tb.add_custom_scalars({
        'metrics': {
            'thres_1.25': ['Multiline', ['thres_1.25/train', 'thres_1.25/test']],
            'thres_1.25_2': ['Multiline', ['thres_1.25_2/train', 'thres_1.25_2/test']],
            'thres_1.25_3': ['Multiline', ['thres_1.25_3/train', 'thres_1.25_3/test']],
            'ard': ['Multiline', ['ard/train', 'ard/test']],
            'srd': ['Multiline', ['srd/train', 'srd/test']],
            'rmse_linear': ['Multiline', ['rmse_linear/train', 'rmse_linear/test']],
            'rmse_log': ['Multiline', ['rmse_log/train', 'rmse_log/test']],
            'rmse_log_invariant': ['Multiline', ['rmse_log_invariant/train', 'rmse_log_invariant/test']],
        }
    })

    # data loaders
    dataset = NYUDepth(path, 'train')
    dataloader = DataLoader(dataset, batch_size, shuffle=True, num_workers=4)
    dataset_test = NYUDepth(path, 'test')
    dataloader_test = DataLoader(dataset_test, batch_size, shuffle=True, num_workers=4)

    # load model
    model = FCRN(True)
    model = model.to(device)

    # optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    start_epoch = 0
    if resume:
        model_path = os.path.join(model_dir, 'model.pth')
        if os.path.exists(model_path):
            print('Loading checkpoint from {}...'.format(model_path))
            # load model and optimizer state
            checkpoint = torch.load(model_path, map_location='cpu')
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            start_epoch = checkpoint['epoch']
            print('Model loaded.')
        else:
            print('No checkpoint found. Training from scratch.')

    # training
    metric_logger = MetricLogger()
    end = time.perf_counter()
    max_iters = epochs * len(dataloader)

    def normal_loss(pred, normal, conf):
        """
        :param pred: (B, 3, H, W)
        :param normal: (B, 3, H, W)
        :param conf: (B, 1, H, W)
        """
        dot_prod = (pred * normal).sum(dim=1)
        # confidence-weighted loss, (B,)
        batch_loss = ((1 - dot_prod) * conf[:, 0]).sum(1).sum(1)
        # normalize by the total confidence, (B,)
        batch_loss /= conf[:, 0].sum(1).sum(1)
        return batch_loss.mean()

    def consistency_loss(pred, cloud, normal, conf):
        """
        :param pred: (B, 1, H, W)
        :param normal: (B, 3, H, W)
        :param cloud: (B, 3, H, W)
        :param conf: (B, 1, H, W)
        """
        B, _, _, _ = normal.size()
        normal = normal.detach()
        cloud = cloud.clone()
        cloud[:, 2:3, :, :] = pred
        # algorithm: a 7x7 difference kernel (center-minus-neighborhood)
        kernel = torch.ones((1, 1, 7, 7), device=pred.device)
        kernel = -kernel
        kernel[0, 0, 3, 3] = 48
        cloud_0 = cloud[:, 0:1]
        cloud_1 = cloud[:, 1:2]
        cloud_2 = cloud[:, 2:3]
        diff_0 = F.conv2d(cloud_0, kernel, padding=6, dilation=2)
        diff_1 = F.conv2d(cloud_1, kernel, padding=6, dilation=2)
        diff_2 = F.conv2d(cloud_2, kernel, padding=6, dilation=2)
        # (B, 3, H, W)
        diff = torch.cat((diff_0, diff_1, diff_2), dim=1)
        # normalize
        diff = F.normalize(diff, dim=1)
        # (B, 1, H, W)
        dot_prod = (diff * normal).sum(dim=1, keepdim=True)
        # confidence-weighted mean over the image
        dot_prod = torch.abs(dot_prod.view(B, -1))
        conf = conf.view(B, -1)
        loss = (dot_prod * conf).sum(1) / conf.sum(1)
        # mean over the batch
        return loss.mean()

    def criterion(depth_pred, normal_pred, depth, normal, cloud, conf):
        mse_loss = F.mse_loss(depth_pred, depth)
        consis_loss = consistency_loss(depth_pred, cloud, normal_pred, conf)
        norm_loss = normal_loss(normal_pred, normal, conf)
        consis_loss = torch.zeros_like(norm_loss)
        # current experiment ('only_depth'): train on the depth MSE only
        return mse_loss, mse_loss, mse_loss
        # return mse_loss, consis_loss, norm_loss
        # return norm_loss, norm_loss, norm_loss

    print('Start training')
    for epoch in range(start_epoch, epochs):
        # train
        model.train()
        for i, data in enumerate(dataloader):
            start = end
            i += 1
            data = [x.to(device) for x in data]
            image, depth, normal, conf, cloud = data
            depth_pred, normal_pred = model(image)
            mse_loss, consis_loss, norm_loss = criterion(depth_pred, normal_pred, depth, normal, cloud, conf)
            loss = mse_loss + consis_loss + norm_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # bookkeeping
            end = time.perf_counter()
            metric_logger.update(loss=loss.item())
            metric_logger.update(mse_loss=mse_loss.item())
            metric_logger.update(norm_loss=norm_loss.item())
            metric_logger.update(consis_loss=consis_loss.item())
            metric_logger.update(batch_time=end - start)

            if i % print_every == 0:
                # compute eta; the global step starts from 1
                global_step = epoch * len(dataloader) + i
                seconds = (max_iters - global_step) * metric_logger['batch_time'].global_avg
                eta = datetime.timedelta(seconds=int(seconds))
                # display: eta, epoch, iteration, loss, batch_time
                display_dict = {
                    'eta': eta,
                    'epoch': epoch,
                    'iter': i,
                    'loss': metric_logger['loss'].median,
                    'batch_time': metric_logger['batch_time'].median
                }
                display_str = [
                    'eta: {eta}s',
                    'epoch: {epoch}',
                    'iter: {iter}',
                    'loss: {loss:.4f}',
                    'batch_time: {batch_time:.4f}s',
                ]
                print(', '.join(display_str).format(**display_dict))

                # tensorboard
                min_depth = depth[0].min()
                max_depth = depth[0].max() * 1.25
                depth = (depth[0] - min_depth) / (max_depth - min_depth)
                depth_pred = (depth_pred[0] - min_depth) / (max_depth - min_depth)
                depth_pred = torch.clamp(depth_pred, min=0.0, max=1.0)
                normal = (normal[0] + 1) / 2
                normal_pred = (normal_pred[0] + 1) / 2
                conf = conf[0]
                tb.add_scalar('train/loss', metric_logger['loss'].median, global_step)
                tb.add_scalar('train/mse_loss', metric_logger['mse_loss'].median, global_step)
                tb.add_scalar('train/consis_loss', metric_logger['consis_loss'].median, global_step)
                tb.add_scalar('train/norm_loss', metric_logger['norm_loss'].median, global_step)
                tb.add_image('train/depth', depth, global_step)
                tb.add_image('train/normal', normal, global_step)
                tb.add_image('train/depth_pred', depth_pred, global_step)
                tb.add_image('train/normal_pred', normal_pred, global_step)
                tb.add_image('train/conf', conf, global_step)
                tb.add_image('train/image', image[0], global_step)

        if epoch % val_every == 0 and epoch != 0:
            # validate every val_every epochs
            validate(dataloader, model, device, tb, epoch, 'train')
            validate(dataloader_test, model, device, tb, epoch, 'test')

        if epoch % save_every == 0 and epoch != 0:
            to_save = {
                'optimizer': optimizer.state_dict(),
                'model': model.state_dict(),
                'epoch': epoch,
            }
            torch.save(to_save, os.path.join(model_dir, 'model.pth'))
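# ---------------------------------------------------------------------------
# The script above assumes a preamble of imports and project-local modules
# that is not part of this excerpt. The sketch below shows what that preamble
# might look like; the project-local module paths (dataset, model, utils) are
# assumptions and are left commented out.
# ---------------------------------------------------------------------------
# import os
# import shutil
# import time
# import datetime
#
# import torch
# import torch.nn.functional as F
# import torch.optim as optim
# from torch.utils.data import DataLoader
# from torch.utils.tensorboard import SummaryWriter  # or tensorboardX
#
# from dataset import NYUDepth                # assumed location of the dataset class
# from model import FCRN                      # assumed location of the model
# from utils import MetricLogger, validate    # assumed location of the helpers
#
# if __name__ == '__main__':
#     main()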
def main():
    batch_size = 16
    data_path = './data/nyu_depth_v2_labeled.mat'
    learning_rate = 1.0e-4
    momentum = 0.9
    weight_decay = 0.0005
    num_epochs = 100

    # 1. Load data
    train_lists, val_lists, test_lists = load_split()
    print("Loading data...")
    train_loader = torch.utils.data.DataLoader(NyuDepthLoader(data_path, train_lists),
                                               batch_size=batch_size, shuffle=False, drop_last=True)
    val_loader = torch.utils.data.DataLoader(NyuDepthLoader(data_path, val_lists),
                                             batch_size=batch_size, shuffle=True, drop_last=True)
    test_loader = torch.utils.data.DataLoader(NyuDepthLoader(data_path, test_lists),
                                              batch_size=batch_size, shuffle=True, drop_last=True)
    print(train_loader)

    # 2. Load model
    print("Loading model...")
    model = FCRN(batch_size)
    # load the official weights converted from TensorFlow
    model.load_state_dict(load_weights(model, weights_file, dtype))

    # optionally resume a previously trained model
    resume_from_file = False
    resume_file = './model/model_300.pth'
    if resume_from_file:
        if os.path.isfile(resume_file):
            checkpoint = torch.load(resume_file)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            print("loaded checkpoint '{}' (epoch {})".format(resume_file, checkpoint['epoch']))
        else:
            print("no checkpoint found at '{}'".format(resume_file))
    model = model.cuda()

    # 3. Loss
    # official MSE
    # loss_fn = torch.nn.MSELoss()
    # custom MSE
    # loss_fn = loss_mse()
    # the loss from the paper: the reverse Huber (berHu)
    loss_fn = loss_huber()
    print("loss_fn set...")

    # 4. Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    print("optimizer set...")

    # 5. Train
    best_val_err = 1.0e-4  # note: with this threshold the checkpoint is effectively only saved at the final epoch
    start_epoch = 0

    for epoch in range(num_epochs):
        print('Starting train epoch %d / %d' % (start_epoch + epoch + 1, num_epochs + start_epoch))
        model.train()
        running_loss = 0
        count = 0
        epoch_loss = 0
        for input, depth in train_loader:
            input_var = Variable(input.type(dtype))
            depth_var = Variable(depth.type(dtype))

            output = model(input_var)
            loss = loss_fn(output, depth_var)
            print('loss: %f' % loss.data.cpu().item())
            count += 1
            running_loss += loss.data.cpu().numpy()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        epoch_loss = running_loss / count
        print('epoch loss:', epoch_loss)

        # validate
        model.eval()
        num_correct, num_samples = 0, 0
        loss_local = 0
        with torch.no_grad():
            for input, depth in val_loader:
                input_var = Variable(input.type(dtype))
                depth_var = Variable(depth.type(dtype))

                output = model(input_var)
                if num_epochs == epoch + 1:
                    # see the data loader for how the saved test images are produced
                    # input_rgb_image = input_var[0].data.permute(1, 2, 0).cpu().numpy().astype(np.uint8)
                    input_rgb_image = input[0].data.permute(1, 2, 0)
                    input_gt_depth_image = depth_var[0][0].data.cpu().numpy().astype(np.float32)
                    pred_depth_image = output[0].data.squeeze().cpu().numpy().astype(np.float32)

                    input_gt_depth_image /= np.max(input_gt_depth_image)
                    pred_depth_image /= np.max(pred_depth_image)

                    plot.imsave('./result/input_rgb_epoch_{}.png'.format(start_epoch + epoch + 1),
                                input_rgb_image)
                    plot.imsave('./result/gt_depth_epoch_{}.png'.format(start_epoch + epoch + 1),
                                input_gt_depth_image, cmap="viridis")
                    plot.imsave('./result/pred_depth_epoch_{}.png'.format(start_epoch + epoch + 1),
                                pred_depth_image, cmap="viridis")

                loss_local += loss_fn(output, depth_var)
                num_samples += 1

        err = float(loss_local) / num_samples
        print('val_error: %f' % err)

        if err < best_val_err or epoch == num_epochs - 1:
            best_val_err = err
            torch.save({
                'epoch': start_epoch + epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, './model/model_' + str(start_epoch + epoch + 1) + '.pth')

        if epoch % 10 == 0:
            learning_rate = learning_rate * 0.8
            # propagate the decayed rate to the optimizer; reassigning the local
            # variable alone has no effect on the Adam instance created above
            for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate
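# ---------------------------------------------------------------------------
# `loss_huber()` above is the project's own implementation of the reverse
# Huber (berHu) loss used in the FCRN paper (Laina et al., 2016); its source
# is not part of this excerpt. The class below is a minimal sketch of one
# common formulation, not the original code: L1 for small residuals and a
# scaled L2 branch for large ones, with the threshold c set to 1/5 of the
# maximum absolute residual in the batch.
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn


class BerHuLossSketch(nn.Module):
    """Reverse Huber (berHu) loss sketch."""

    def forward(self, pred, target):
        diff = torch.abs(target - pred)
        # threshold: 20% of the largest absolute residual in the batch
        c = 0.2 * diff.max()
        # |x| where |x| <= c, (x^2 + c^2) / (2c) otherwise
        l2_branch = (diff ** 2 + c ** 2) / (2 * c + 1e-8)
        loss = torch.where(diff <= c, diff, l2_branch)
        return loss.mean()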
def main():
    batch_size = args.batch_size
    data_path = 'nyu_depth_v2_labeled.mat'
    learning_rate = args.lr  # 1.0e-4  # 1.0e-5
    momentum = 0.9
    weight_decay = 0.0005
    num_epochs = args.epochs
    step_size = args.step_size
    step_gamma = args.step_gamma
    resume_from_file = False
    isDataAug = args.data_aug
    max_depth = 1000

    # 1. Load data
    train_lists, val_lists, test_lists = load_split()
    print("Loading data......")
    train_loader = torch.utils.data.DataLoader(NyuDepthLoader(data_path, train_lists),
                                               batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(NyuDepthLoader(data_path, val_lists),
                                             batch_size=batch_size, shuffle=True, drop_last=True)
    test_loader = torch.utils.data.DataLoader(NyuDepthLoader(data_path, test_lists),
                                              batch_size=batch_size, shuffle=False, drop_last=True)
    print(train_loader)

    # 2. Set up the model
    print("Set the model......")
    model = FCRN(batch_size)
    resnet = torchvision.models.resnet50()

    # resume a partially trained model
    # resnet.load_state_dict(torch.load('/home/xpfly/nets/ResNet/resnet50-19c8e357.pth'))
    # print("resnet50 params loaded.")

    # model.load_state_dict(load_weights(model, weights_file, dtype))

    model = model.cuda()

    # 3. Loss
    # loss_fn = torch.nn.MSELoss().cuda()
    if args.loss_type == "berhu":
        loss_fn = criteria.berHuLoss().cuda()
        print("berhu loss_fn set.")
    elif args.loss_type == "L1":
        loss_fn = criteria.MaskedL1Loss().cuda()
        print("L1 loss_fn set.")
    elif args.loss_type == "mse":
        loss_fn = criteria.MaskedMSELoss().cuda()
        print("MSE loss_fn set.")
    elif args.loss_type == "ssim":
        loss_fn = criteria.SsimLoss().cuda()
        print("Ssim loss_fn set.")
    elif args.loss_type == "three":
        loss_fn = criteria.Ssim_grad_L1().cuda()
        print("SSIM+L1+Grad loss_fn set.")

    # 5. Train
    best_val_err = 1.0e3

    # validate before training
    model.eval()
    num_correct, num_samples = 0, 0
    loss_local = 0
    with torch.no_grad():
        for input, depth in val_loader:
            input_var = Variable(input.type(dtype))
            depth_var = Variable(depth.type(dtype))

            output = model(input_var)

            input_rgb_image = input_var[0].data.permute(1, 2, 0).cpu().numpy().astype(np.uint8)
            input_gt_depth_image = depth_var[0][0].data.cpu().numpy().astype(np.float32)
            pred_depth_image = output[0].data.squeeze().cpu().numpy().astype(np.float32)

            input_gt_depth_image /= np.max(input_gt_depth_image)
            pred_depth_image /= np.max(pred_depth_image)

            plot.imsave('./result/input_rgb_epoch_0.png', input_rgb_image)
            plot.imsave('./result/gt_depth_epoch_0.png', input_gt_depth_image, cmap="viridis")
            plot.imsave('pred_depth_epoch_0.png', pred_depth_image, cmap="viridis")

            # depth_var = depth_var[:, 0, :, :]
            # loss_fn_local = torch.nn.MSELoss()

            loss_local += loss_fn(output, depth_var)
            num_samples += 1

    err = float(loss_local) / num_samples
    print('val_error before train:', err)

    start_epoch = 0
    resume_file = 'checkpoint.pth.tar'
    if resume_from_file:
        if os.path.isfile(resume_file):
            print("=> loading checkpoint '{}'".format(resume_file))
            checkpoint = torch.load(resume_file)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(resume_file, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(resume_file))

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = StepLR(optimizer, step_size=step_size, gamma=step_gamma)  # may change to other values

    for epoch in range(num_epochs):
        # 4. Optimizer (alternatives)
        # optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
        # optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum,
        #                             weight_decay=weight_decay)
        print("optimizer set.")

        print('Starting train epoch %d / %d' % (start_epoch + epoch + 1, num_epochs))
        model.train()
        running_loss = 0
        count = 0
        epoch_loss = 0
        # for i, (input, depth) in enumerate(train_loader):
        for input, depth in train_loader:
            print("depth", depth)
            if isDataAug:
                depth = depth * 1000
                depth = torch.clamp(depth, 10, 1000)
                depth = max_depth / depth

            input_var = Variable(input.type(dtype))  # Variable wrapper for autograd (legacy API)
            depth_var = Variable(depth.type(dtype))  # Variable wrapper for autograd (legacy API)
            # print("depth_var", depth_var)

            output = model(input_var)
            loss = loss_fn(output, depth_var)
            print('loss:', loss.data.cpu())
            count += 1
            running_loss += loss.data.cpu().numpy()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        epoch_loss = running_loss / count
        print('epoch loss:', epoch_loss)

        # validate
        model.eval()
        num_correct, num_samples = 0, 0
        loss_local = 0
        with torch.no_grad():
            for input, depth in val_loader:
                if isDataAug:
                    depth = depth * 1000
                    depth = torch.clamp(depth, 10, 1000)
                    depth = max_depth / depth

                input_var = Variable(input.type(dtype))
                depth_var = Variable(depth.type(dtype))

                output = model(input_var)

                input_rgb_image = input_var[0].data.permute(1, 2, 0).cpu().numpy().astype(np.uint8)
                input_gt_depth_image = depth_var[0][0].data.cpu().numpy().astype(np.float32)
                pred_depth_image = output[0].data.squeeze().cpu().numpy().astype(np.float32)

                # normalization
                input_gt_depth_image /= np.max(input_gt_depth_image)
                pred_depth_image /= np.max(pred_depth_image)

                plot.imsave('./result/input_rgb_epoch_{}.png'.format(start_epoch + epoch + 1),
                            input_rgb_image)
                plot.imsave('./result/gt_depth_epoch_{}.png'.format(start_epoch + epoch + 1),
                            input_gt_depth_image, cmap="viridis")
                plot.imsave('./result/pred_depth_epoch_{}.png'.format(start_epoch + epoch + 1),
                            pred_depth_image, cmap="viridis")

                # depth_var = depth_var[:, 0, :, :]
                # loss_fn_local = torch.nn.MSELoss()

                loss_local += loss_fn(output, depth_var)
                num_samples += 1

        if epoch % 10 == 9:
            PATH = args.loss_type + '.pth'
            torch.save(model.state_dict(), PATH)

        err = float(loss_local) / num_samples
        print('val_error:', err)

        if err < best_val_err:
            best_val_err = err
            torch.save({
                'epoch': start_epoch + epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, 'checkpoint.pth.tar')

        scheduler.step()
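# ---------------------------------------------------------------------------
# This variant reads its hyper-parameters from a global `args` object that is
# not defined in the excerpt. Below is a minimal argparse sketch that would
# supply the fields the script actually uses; the default values are
# assumptions, not values taken from the original code.
# ---------------------------------------------------------------------------
import argparse

parser = argparse.ArgumentParser(description='FCRN depth training (sketch)')
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--lr', type=float, default=1.0e-4)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--step_size', type=int, default=10)      # StepLR period in epochs
parser.add_argument('--step_gamma', type=float, default=0.1)  # StepLR decay factor
parser.add_argument('--data_aug', action='store_true')        # enables the depth rescaling branch
parser.add_argument('--loss_type', type=str, default='berhu',
                    choices=['berhu', 'L1', 'mse', 'ssim', 'three'])
args = parser.parse_args()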
def main():
    batch_size = 32
    data_path = 'augmented_dataset.mat'
    learning_rate = 1.0e-5
    momentum = 0.9
    weight_decay = 0.0005
    num_epochs = 50
    resume_from_file = False

    # 1. Load data
    train_lists, val_lists, test_lists = load_split()
    print("Loading data......")
    train_loader = torch.utils.data.DataLoader(NyuDepthLoader(data_path, train_lists),
                                               batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(NyuDepthLoader(data_path, val_lists),
                                             batch_size=batch_size, shuffle=True, drop_last=True)
    test_loader = torch.utils.data.DataLoader(NyuDepthLoader(data_path, test_lists),
                                              batch_size=batch_size, shuffle=True, drop_last=True)
    print(train_loader)

    # 2. Load model
    print("Loading model......")
    model = FCRN(batch_size)

    # resnet = torchvision.models.resnet50(pretrained=True)
    resnet = torchvision.models.resnet50()
    resnet.load_state_dict(torch.load('/content/model/resnet50-19c8e357.pth'))
    # resnet.load_state_dict(torch.load('/home/xpfly/nets/ResNet/resnet50-19c8e357.pth'))
    print("resnet50 loaded.")
    resnet50_pretrained_dict = resnet.state_dict()

    # -----------------------------------------------------------
    # model.load_state_dict(load_weights(model, weights_file, dtype))
    chkp = torch.load('our_checkpoint.pth.tar')
    model.load_state_dict(chkp['state_dict'])
    # -----------------------------------------------------------

    """
    print('\nresnet50 keys:\n')
    for key, value in resnet50_pretrained_dict.items():
        print(key, value.size())
    """

    # model_dict = model.state_dict()
    """
    print('\nmodel keys:\n')
    for key, value in model_dict.items():
        print(key, value.size())
    print("resnet50.dict loaded.")
    """

    # load pretrained weights
    # resnet50_pretrained_dict = {k: v for k, v in resnet50_pretrained_dict.items() if k in model_dict}
    print("resnet50_pretrained_dict loaded.")
    """
    print('\nresnet50_pretrained keys:\n')
    for key, value in resnet50_pretrained_dict.items():
        print(key, value.size())
    """

    # model_dict.update(resnet50_pretrained_dict)
    print("model_dict updated.")
    """
    print('\nupdated model dict keys:\n')
    for key, value in model_dict.items():
        print(key, value.size())
    """

    # model.load_state_dict(model_dict)
    print("model_dict loaded.")
    model = model.cuda()

    # 3. Loss
    loss_fn = torch.nn.MSELoss().cuda()
    print("loss_fn set.")

    # 5. Train
    best_val_err = 1.0e3

    # validate before training
    model.eval()
    num_correct, num_samples = 0, 0
    loss_local = 0
    with torch.no_grad():
        for input, depth in val_loader:
            input_var = Variable(input.type(dtype))
            depth_var = Variable(depth.type(dtype))

            output = model(input_var)

            input_rgb_image = input_var[0].data.permute(1, 2, 0).cpu().numpy().astype(np.uint8)
            input_gt_depth_image = depth_var[0][0].data.cpu().numpy().astype(np.float32)
            pred_depth_image = output[0].data.squeeze().cpu().numpy().astype(np.float32)

            input_gt_depth_image /= np.max(input_gt_depth_image)
            pred_depth_image /= np.max(pred_depth_image)

            plot.imsave('input_rgb_epoch_0.png', input_rgb_image)
            plot.imsave('gt_depth_epoch_0.png', input_gt_depth_image, cmap="viridis")
            plot.imsave('pred_depth_epoch_0.png', pred_depth_image, cmap="viridis")

            # depth_var = depth_var[:, 0, :, :]
            # loss_fn_local = torch.nn.MSELoss()

            loss_local += loss_fn(output, depth_var)
            num_samples += 1

    err = float(loss_local) / num_samples
    print('val_error before train:', err)

    start_epoch = 0
    resume_file = 'checkpoint.pth.tar'
    if resume_from_file:
        if os.path.isfile(resume_file):
            print("=> loading checkpoint '{}'".format(resume_file))
            checkpoint = torch.load(resume_file)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(resume_file, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(resume_file))

    for epoch in range(num_epochs):
        # 4. Optimizer (recreated each epoch, so the decayed learning rate below takes effect)
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        # optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
        # optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum,
        #                             weight_decay=weight_decay)
        print("optimizer set.")

        print('Starting train epoch %d / %d' % (start_epoch + epoch + 1, num_epochs))
        model.train()
        running_loss = 0
        count = 0
        epoch_loss = 0
        # for i, (input, depth) in enumerate(train_loader):
        for input, depth in train_loader:
            # input, depth = data
            # input_var = input.cuda()
            # depth_var = depth.cuda()
            input_var = Variable(input.type(dtype))
            depth_var = Variable(depth.type(dtype))

            output = model(input_var)
            loss = loss_fn(output, depth_var)
            print('loss:', loss.data.cpu())
            count += 1
            running_loss += loss.data.cpu().numpy()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        epoch_loss = running_loss / count
        print('epoch loss:', epoch_loss)

        # validate
        model.eval()
        num_correct, num_samples = 0, 0
        loss_local = 0
        with torch.no_grad():
            for input, depth in val_loader:
                input_var = Variable(input.type(dtype))
                depth_var = Variable(depth.type(dtype))

                output = model(input_var)

                input_rgb_image = input_var[0].data.permute(1, 2, 0).cpu().numpy().astype(np.uint8)
                input_gt_depth_image = depth_var[0][0].data.cpu().numpy().astype(np.float32)
                pred_depth_image = output[0].data.squeeze().cpu().numpy().astype(np.float32)

                input_gt_depth_image /= np.max(input_gt_depth_image)
                pred_depth_image /= np.max(pred_depth_image)

                plot.imsave('input_rgb_epoch_{}.png'.format(start_epoch + epoch + 1), input_rgb_image)
                plot.imsave('gt_depth_epoch_{}.png'.format(start_epoch + epoch + 1),
                            input_gt_depth_image, cmap="viridis")
                plot.imsave('pred_depth_epoch_{}.png'.format(start_epoch + epoch + 1),
                            pred_depth_image, cmap="viridis")

                # depth_var = depth_var[:, 0, :, :]
                # loss_fn_local = torch.nn.MSELoss()

                loss_local += loss_fn(output, depth_var)
                num_samples += 1

        err = float(loss_local) / num_samples
        print('val_error:', err)

        if err < best_val_err:
            best_val_err = err
            torch.save({
                'epoch': start_epoch + epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, 'checkpoint.pth.tar')

        if epoch % 10 == 0:
            learning_rate = learning_rate * 0.6
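# ---------------------------------------------------------------------------
# The three NyuDepthLoader-based variants above rely on a shared preamble
# (dtype, plot, np, Variable, the project loader, ...) that is not part of
# this excerpt. The sketch below shows what it likely contains; the
# project-local module names and the weights filename are assumptions.
# ---------------------------------------------------------------------------
import os
import numpy as np
import torch
import torchvision
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
import matplotlib.pyplot as plot

# from loader import NyuDepthLoader, load_split   # assumed project modules
# from weights import load_weights
# from fcrn import FCRN
# import criteria

dtype = torch.cuda.FloatTensor          # tensors are cast to this type before the forward pass
weights_file = 'NYU_ResNet-UpProj.npy'  # assumed name of the converted TensorFlow weights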