def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    # Create model
    model = BtsModel(args)
    model.train()
    model.decoder.apply(weights_init_xavier)
    set_misc(model)

    num_params = sum([np.prod(p.size()) for p in model.parameters()])
    print("Total number of parameters: {}".format(num_params))

    num_params_update = sum([np.prod(p.shape) for p in model.parameters() if p.requires_grad])
    print("Total number of learning parameters: {}".format(num_params_update))

    if args.distributed:
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            args.batch_size = int(args.batch_size / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu], find_unused_parameters=True)
        else:
            model.cuda()
            model = torch.nn.parallel.DistributedDataParallel(
                model, find_unused_parameters=True)
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()

    if args.distributed:
        print("Model Initialized on GPU: {}".format(args.gpu))
    else:
        print("Model Initialized")

    global_step = 0
    best_eval_measures_lower_better = torch.zeros(6).cpu() + 1e3
    best_eval_measures_higher_better = torch.zeros(3).cpu()
    best_eval_steps = np.zeros(9, dtype=np.int32)

    # Training parameters
    optimizer = torch.optim.AdamW(
        [{'params': model.module.encoder.parameters(), 'weight_decay': args.weight_decay},
         {'params': model.module.decoder.parameters(), 'weight_decay': 0}],
        lr=args.learning_rate, eps=args.adam_eps)

    model_just_loaded = False
    if args.checkpoint_path != '':
        if os.path.isfile(args.checkpoint_path):
            print("Loading checkpoint '{}'".format(args.checkpoint_path))
            if args.gpu is None:
                checkpoint = torch.load(args.checkpoint_path)
            else:
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.checkpoint_path, map_location=loc)
            global_step = checkpoint['global_step']
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            try:
                best_eval_measures_higher_better = checkpoint['best_eval_measures_higher_better'].cpu()
                best_eval_measures_lower_better = checkpoint['best_eval_measures_lower_better'].cpu()
                best_eval_steps = checkpoint['best_eval_steps']
            except KeyError:
                print("Could not load values for online evaluation")

            print("Loaded checkpoint '{}' (global_step {})".format(
                args.checkpoint_path, checkpoint['global_step']))
        else:
            print("No checkpoint found at '{}'".format(args.checkpoint_path))
        model_just_loaded = True

    if args.retrain:
        global_step = 0

    cudnn.benchmark = True

    dataloader = BtsDataLoader(args)
    dataloader_eval = BtsDataLoader(args)

    # Logging
    if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                                                and args.rank % ngpus_per_node == 0):
        writer = SummaryWriter(os.path.join(args.log_directory, args.model_name, 'summaries'),
                               flush_secs=30)
        if args.do_online_eval:
            if args.eval_summary_directory != '':
                eval_summary_path = os.path.join(args.eval_summary_directory, args.model_name)
            else:
                eval_summary_path = os.path.join(args.log_directory, 'eval')
            eval_summary_writer = SummaryWriter(eval_summary_path, flush_secs=30)

    silog_criterion = silog_loss(variance_focus=args.variance_focus)

    start_time = time.time()
    duration = 0

    num_log_images = args.batch_size
    end_learning_rate = args.end_learning_rate if args.end_learning_rate != -1 else 0.1 * args.learning_rate

    var_sum = [var.sum() for var in model.parameters() if var.requires_grad]
    var_cnt = len(var_sum)
    var_sum = np.sum(var_sum)

    print("Initial variables' sum: {:.3f}, avg: {:.3f}".format(var_sum, var_sum / var_cnt))

    steps_per_epoch = len(dataloader.data)
    num_total_steps = args.num_epochs * steps_per_epoch
    epoch = global_step // steps_per_epoch

    while epoch < args.num_epochs:
        if args.distributed:
            dataloader.train_sampler.set_epoch(epoch)

        for step, sample_batched in enumerate(dataloader.data):
            optimizer.zero_grad()
            before_op_time = time.time()

            mask = torch.autograd.Variable(sample_batched['mask'].cuda(args.gpu, non_blocking=True))
            image = torch.autograd.Variable(sample_batched['image'].cuda(args.gpu, non_blocking=True))
            focal = torch.autograd.Variable(sample_batched['focal'].cuda(args.gpu, non_blocking=True))
            depth_gt = torch.autograd.Variable(sample_batched['depth'].cuda(args.gpu, non_blocking=True))

            lpg8x8, lpg4x4, lpg2x2, reduc1x1, depth_est = model(image, focal)

            # Valid-depth mask per dataset
            if args.dataset == 'nyu':
                mask = depth_gt > 0.1
            elif args.dataset == 'kitti':
                mask = depth_gt > 1.0

            # Keep the mask boolean on the same device as depth_est
            loss = silog_criterion.forward(depth_est, depth_gt, mask.to(torch.bool))
            loss.backward()

            # Polynomial decay (power 0.9) from learning_rate to end_learning_rate
            for param_group in optimizer.param_groups:
                current_lr = (args.learning_rate - end_learning_rate) * (
                    1 - global_step / num_total_steps)**0.9 + end_learning_rate
                param_group['lr'] = current_lr

            optimizer.step()

            if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                                                        and args.rank % ngpus_per_node == 0):
                print('[epoch][s/s_per_e/gs]: [{}][{}/{}/{}], lr: {:.12f}, loss: {:.12f}'.format(
                    epoch, step, steps_per_epoch, global_step, current_lr, loss))

            if np.isnan(loss.cpu().item()):
                print('NaN in loss occurred. Aborting training.')
                return -1

            duration += time.time() - before_op_time

            if global_step and global_step % args.log_freq == 0 and not model_just_loaded:
                var_sum = [var.sum() for var in model.parameters() if var.requires_grad]
                var_cnt = len(var_sum)
                var_sum = np.sum(var_sum)
                examples_per_sec = args.batch_size / duration * args.log_freq
                duration = 0
                time_sofar = (time.time() - start_time) / 3600
                training_time_left = (num_total_steps / global_step - 1.0) * time_sofar

                if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                                                            and args.rank % ngpus_per_node == 0):
                    print("{}".format(args.model_name))
                    print_string = 'GPU: {} | examples/s: {:4.2f} | loss: {:.5f} | var sum: {:.3f} avg: {:.3f} | time elapsed: {:.2f}h | time left: {:.2f}h'
                    print(print_string.format(args.gpu, examples_per_sec, loss, var_sum.item(),
                                              var_sum.item() / var_cnt, time_sofar, training_time_left))

                if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                                                            and args.rank % ngpus_per_node == 0):
                    writer.add_scalar('silog_loss', loss, global_step)
                    writer.add_scalar('learning_rate', current_lr, global_step)
                    writer.add_scalar('var average', var_sum.item() / var_cnt, global_step)

                    # Clamp near-zero ground truth so 1 / depth stays finite in the logged images
                    depth_gt = torch.where(depth_gt < 1e-3, depth_gt * 0 + 1e3, depth_gt)
                    for i in range(num_log_images):
                        writer.add_image('depth_gt/image/{}'.format(i),
                                         normalize_result(1 / depth_gt[i, :, :, :].data), global_step)
                        writer.add_image('depth_est/image/{}'.format(i),
                                         normalize_result(1 / depth_est[i, :, :, :].data), global_step)
                        writer.add_image('reduc1x1/image/{}'.format(i),
                                         normalize_result(1 / reduc1x1[i, :, :, :].data), global_step)
                        writer.add_image('lpg2x2/image/{}'.format(i),
                                         normalize_result(1 / lpg2x2[i, :, :, :].data), global_step)
                        writer.add_image('lpg4x4/image/{}'.format(i),
                                         normalize_result(1 / lpg4x4[i, :, :, :].data), global_step)
                        writer.add_image('lpg8x8/image/{}'.format(i),
                                         normalize_result(1 / lpg8x8[i, :, :, :].data), global_step)
                        writer.add_image('image/image/{}'.format(i),
                                         inv_normalize(image[i, :, :, :]).data, global_step)
                    writer.flush()

            if not args.do_online_eval and global_step and global_step % args.save_freq == 0:
                if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                                                            and args.rank % ngpus_per_node == 0):
                    checkpoint = {'global_step': global_step,
                                  'model': model.state_dict(),
                                  'optimizer': optimizer.state_dict()}
                    torch.save(checkpoint, args.log_directory + '/' + args.model_name +
                               '/model-{}'.format(global_step))

            if args.do_online_eval and global_step and global_step % args.eval_freq == 0 and not model_just_loaded:
                time.sleep(0.1)
                model.eval()
                eval_measures = online_eval(model, dataloader_eval, gpu, ngpus_per_node)
                if eval_measures is not None:
                    for i in range(9):
                        eval_summary_writer.add_scalar(eval_metrics[i], eval_measures[i].cpu(),
                                                       int(global_step))
                        measure = eval_measures[i]
                        is_best = False
                        # The first six metrics are better when lower, the last three when higher
                        if i < 6 and measure < best_eval_measures_lower_better[i]:
                            old_best = best_eval_measures_lower_better[i].item()
                            best_eval_measures_lower_better[i] = measure.item()
                            is_best = True
                        elif i >= 6 and measure > best_eval_measures_higher_better[i - 6]:
                            old_best = best_eval_measures_higher_better[i - 6].item()
                            best_eval_measures_higher_better[i - 6] = measure.item()
                            is_best = True
                        if is_best:
                            old_best_step = best_eval_steps[i]
                            old_best_name = '/model-{}-best_{}_{:.5f}'.format(
                                old_best_step, eval_metrics[i], old_best)
                            model_path = args.log_directory + '/' + args.model_name + old_best_name
                            if os.path.exists(model_path):
                                command = 'rm {}'.format(model_path)
                                os.system(command)
                            best_eval_steps[i] = global_step
                            model_save_name = '/model-{}-best_{}_{:.5f}'.format(
                                global_step, eval_metrics[i], measure)
                            print('New best for {}. Saving model: {}'.format(
                                eval_metrics[i], model_save_name))
                            checkpoint = {
                                'global_step': global_step,
                                'model': model.state_dict(),
                                'optimizer': optimizer.state_dict(),
                                'best_eval_measures_higher_better': best_eval_measures_higher_better,
                                'best_eval_measures_lower_better': best_eval_measures_lower_better,
                                'best_eval_steps': best_eval_steps
                            }
                            torch.save(checkpoint, args.log_directory + '/' +
                                       args.model_name + model_save_name)
                    eval_summary_writer.flush()
                model.train()
                block_print()
                set_misc(model)
                enable_print()

            model_just_loaded = False
            global_step += 1

        epoch += 1
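
# A minimal launcher sketch (not part of the original file) showing how a main_worker
# with this signature is typically driven: one spawned process per GPU when
# args.multiprocessing_distributed is set, otherwise a single direct call. The function
# name `launch_training` and the assumption that `args` already carries world_size,
# dist_url, gpu, etc. are illustrative only.
def launch_training(args):
    import torch
    import torch.multiprocessing as mp

    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # world_size becomes the total number of worker processes across all nodes;
        # mp.spawn passes the process index as the first argument (gpu) to main_worker
        args.world_size = ngpus_per_node * args.world_size
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # Single-process path; main_worker wraps the model in DataParallel itself
        main_worker(args.gpu, ngpus_per_node, args)
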
def test_images(params):
    """Test function."""
    model = BtsModel(params=args)
    model = torch.nn.DataParallel(model)

    checkpoint = torch.load(args.checkpoint_path)
    model.load_state_dict(checkpoint['model'])
    model.eval()
    model.cuda()

    # Apply the same normalization used for training
    loader_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    pred_depths = []
    save_name = args.save_name

    if not os.path.exists(save_name + 'raw'):
        os.mkdir(save_name + 'raw')

    with torch.no_grad():
        gt_file = open(args.media_path, "r")
        num_test_samples = len(gt_file.readlines())
        print(num_test_samples)

        # Per-sample error buffers, one set per maximum evaluation distance:
        # metrics[max_dist][name] is an array of length num_test_samples
        max_distance_list = [10, 25, 50, 100, 150, 200]
        metric_names = ['silog', 'log10', 'abs_rel', 'sq_rel', 'rms', 'log_rms', 'd1', 'd2', 'd3']
        metrics = {max_dist: {name: np.zeros(num_test_samples, np.float32) for name in metric_names}
                   for max_dist in max_distance_list}

        i = 0
        gt_file = open(args.media_path, "r")
        for x in gt_file:
            output = x.split()
            file_sub_path = output[0]
            focal = output[2]
            print(file_sub_path, focal)

            media_path = os.path.join("../train_val_split/val/rgb/", file_sub_path)
            save_dir = os.path.join(args.save_name, "raw", os.path.dirname(file_sub_path))
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)

            image = cv2.imread(media_path)
            height, width, _ = image.shape

            # Check divisibility by 32; round each side up to the next multiple if needed
            adjusted_height = lambda h: 32 * math.ceil(h / 32) if h % 32 != 0 else h
            adjusted_width = lambda w: 32 * math.ceil(w / 32) if w % 32 != 0 else w
            image_original = cv2.resize(image, (adjusted_width(width), adjusted_height(height)),
                                        interpolation=cv2.INTER_LANCZOS4)

            image = loader_transforms(image_original).float().cuda()
            image = image.unsqueeze(0)

            _, _, _, _, depth_est = model(image, focal)

            depth_est = depth_est.cpu().numpy().squeeze()
            depth_est_scaled = cv2.resize(depth_est, (1920, 1200), interpolation=cv2.INTER_LANCZOS4)

            # Save the prediction as a 16-bit PNG (metric depth * 256)
            filename_pred_png = os.path.join(
                save_dir, os.path.splitext(os.path.basename(media_path))[0] + ".png")
            pred_depth_scaled = depth_est_scaled * 256.0
            pred_depth_scaled = pred_depth_scaled.astype(np.uint16)
            cv2.imwrite(filename_pred_png, pred_depth_scaled, [cv2.IMWRITE_PNG_COMPRESSION, 0])

            gt_depth_path = os.path.join("../train_val_split/val/depth/", file_sub_path)
            depth = cv2.imread(gt_depth_path, -1)
            if depth is None:
                print('Missing: %s ' % gt_depth_path)
                missing_ids.add(i)
                continue

            gt_depth = depth.astype(np.float32) / 256.0
            gt_depth_copy = np.copy(gt_depth)
            gt_depth[gt_depth > 0] = 1  # binary validity mask

            pred_depth_maskd = np.where(gt_depth, depth_est_scaled, 0)

            for max_dist in max_distance_list:
                pred_depth_maskd[pred_depth_maskd < args.min_depth_eval] = args.min_depth_eval
                pred_depth_maskd[pred_depth_maskd > max_dist] = max_dist
                pred_depth_maskd[np.isinf(pred_depth_maskd)] = max_dist

                gt_depth_copy[np.isinf(gt_depth_copy)] = 0
                gt_depth_copy[np.isnan(gt_depth_copy)] = 0

                valid_mask = np.logical_and(gt_depth_copy > args.min_depth_eval,
                                            gt_depth_copy < max_dist)

                m = metrics[max_dist]
                (m['silog'][i], m['log10'][i], m['abs_rel'][i], m['sq_rel'][i], m['rms'][i],
                 m['log_rms'][i], m['d1'][i], m['d2'][i], m['d3'][i]) = compute_errors(
                     gt_depth_copy[valid_mask], pred_depth_maskd[valid_mask])

            m = metrics[200]
            print(m['silog'][i], m['log10'][i], m['abs_rel'][i], m['sq_rel'][i], m['rms'][i],
                  m['log_rms'][i], m['d1'][i], m['d2'][i], m['d3'][i])
            i += 1

    # Mean of every metric for each maximum evaluation distance
    eval_dump_dist = {}
    for max_dist in max_distance_list:
        m = metrics[max_dist]
        eval_dump_dist[max_dist] = {
            'd1': float(m['d1'].mean()),
            'd2': float(m['d2'].mean()),
            'd3': float(m['d3'].mean()),
            'abs_rel': float(m['abs_rel'].mean()),
            'sq_rel': float(m['sq_rel'].mean()),
            'rms': float(m['rms'].mean()),
            'log_rms': float(m['log_rms'].mean()),
            'silog': float(m['silog'].mean()),
            'log10': float(m['log10'].mean())
        }

    return eval_dump_dist
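
# Small illustrative helper (hypothetical; mirrors the adjusted_height / adjusted_width
# lambdas in test_images above): input sides are rounded up to the next multiple of 32
# before inference, e.g. 1200 -> 1216 while 1920 stays 1920.
def _round_up_to_multiple_of_32(size):
    import math
    return size if size % 32 == 0 else 32 * math.ceil(size / 32)
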
def test(params):
    """Test function."""
    args.mode = 'test'
    dataloader = BtsDataLoader(args, 'test')

    model = BtsModel(params=args)
    model = torch.nn.DataParallel(model)

    checkpoint = torch.load(args.checkpoint_path)
    model.load_state_dict(checkpoint['model'])
    model.eval()
    model.cuda()

    num_test_samples = get_num_lines(args.filenames_file)

    with open(args.filenames_file) as f:
        lines = f.readlines()

    # The list read from filenames_file is replaced by a hard-coded glob of label files
    import glob
    lines = glob.glob("/home/pebert/dataset/wireframe/train/*_label.npz")
    # lines.sort()

    print('now testing {} files with {}'.format(num_test_samples, args.checkpoint_path))

    pred_depths = []
    pred_8x8s = []
    pred_4x4s = []
    pred_2x2s = []
    pred_1x1s = []

    start_time = time.time()
    with torch.no_grad():
        for _, sample in enumerate(tqdm(dataloader.data)):
            # if _ > 300:
            #     break
            image = Variable(sample['image'].cuda())
            focal = []  # Variable(sample['focal'].cuda())

            # Predict
            lpg8x8, lpg4x4, lpg2x2, reduc1x1, depth_est = model(image, focal)
            pred_depths.append(depth_est.cpu().numpy().squeeze())
            pred_8x8s.append(lpg8x8[0].cpu().numpy().squeeze())
            pred_4x4s.append(lpg4x4[0].cpu().numpy().squeeze())
            pred_2x2s.append(lpg2x2[0].cpu().numpy().squeeze())
            pred_1x1s.append(reduc1x1[0].cpu().numpy().squeeze())

    elapsed_time = time.time() - start_time
    print('Elapsed time: %s' % str(elapsed_time))
    print('Done.')

    save_name = 'result_' + args.model_name

    print('Saving result pngs..')
    if not os.path.exists(os.path.dirname(save_name)):
        try:
            os.mkdir(save_name)
            os.mkdir(save_name + '/raw')
            os.mkdir(save_name + '/cmap')
            os.mkdir(save_name + '/rgb')
            os.mkdir(save_name + '/gt')
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise

    for s in tqdm(range(len(lines))):
        if args.dataset == 'kitti':
            date_drive = lines[s].split('/')[1]
            filename_pred_png = save_name + '/raw/' + date_drive + '_' + \
                lines[s].split()[0].split('/')[-1].replace('.jpg', '.png')
            filename_cmap_png = save_name + '/cmap/' + date_drive + '_' + \
                lines[s].split()[0].split('/')[-1].replace('.jpg', '.png')
            filename_image_png = save_name + '/rgb/' + date_drive + '_' + \
                lines[s].split()[0].split('/')[-1]
        elif args.dataset == 'kitti_benchmark':
            filename_pred_png = save_name + '/raw/' + \
                lines[s].split()[0].split('/')[-1].replace('.jpg', '.png')
            filename_cmap_png = save_name + '/cmap/' + \
                lines[s].split()[0].split('/')[-1].replace('.jpg', '.png')
            filename_image_png = save_name + '/rgb/' + lines[s].split()[0].split('/')[-1]
        else:
            scene_name = lines[s].split('/')[-1].split('_label')[0]  # lines[s].split()[0].split('/')[0]
            filename_pred_png = save_name + '/raw/' + scene_name + '_' + 'depth.png'
            filename_cmap_png = save_name + '/cmap/' + scene_name + '_' + 'depth.png'
            filename_gt_png = save_name + '/gt/' + scene_name + '_' + 'depth.png'
            filename_image_png = save_name + '/rgb/' + scene_name + '_' + 'depth'

        rgb_path = os.path.join(lines[s][:-10].replace("_a0", "").replace("_a1", "") + ".png")
        # os.path.join(args.data_path, './' + lines[s].split()[0])
        image = cv2.imread(rgb_path)

        # if args.dataset == 'nyu':
        #     gt_path = os.path.join(args.data_path, './' + lines[s].split()[1])
        #     gt = cv2.imread(gt_path, -1).astype(np.float32)  # / 1000.0  # Visualization purpose only
        #     gt[gt == 0] = np.amax(gt)

        pred_depth = pred_depths[s]
        pred_8x8 = pred_8x8s[s]
        pred_4x4 = pred_4x4s[s]
        pred_2x2 = pred_2x2s[s]
        pred_1x1 = pred_1x1s[s]

        # if args.dataset == 'kitti' or args.dataset == 'kitti_benchmark':
        #     pred_depth_scaled = pred_depth * 256.0
        # else:
        pred_depth_scaled = pred_depth / pred_depth.max() * 300.0
        pred_depth_scaled[pred_depth_scaled < 1e-3] = 1e-3
        pred_depth_scaled[pred_depth_scaled > 300] = 300
        pred_depth_scaled[np.isinf(pred_depth_scaled)] = 300
        pred_depth_scaled[np.isnan(pred_depth_scaled)] = 1e-3

        pred_depth_scaled = pred_depth_scaled.astype(np.float32)
        cv2.imwrite(filename_pred_png, pred_depth_scaled,
                    [cv2.IMWRITE_PNG_COMPRESSION, 0])  # '/home/pebert/bts/tmp2.png'

        if args.save_lpg:
            cv2.imwrite(filename_image_png, image[10:-1 - 9, 10:-1 - 9, :])
            if args.dataset == 'nyu':
                # NOTE: this branch expects `gt` from the commented-out loading block above
                plt.imsave(filename_gt_png, np.log10(gt[10:-1 - 9, 10:-1 - 9]), cmap='Greys')
                pred_depth_cropped = pred_depth[10:-1 - 9, 10:-1 - 9]
                plt.imsave(filename_cmap_png, np.log10(pred_depth_cropped), cmap='Greys')
                pred_8x8_cropped = pred_8x8[10:-1 - 9, 10:-1 - 9]
                filename_lpg_cmap_png = filename_cmap_png.replace('.png', '_8x8.png')
                plt.imsave(filename_lpg_cmap_png, np.log10(pred_8x8_cropped), cmap='Greys')
                pred_4x4_cropped = pred_4x4[10:-1 - 9, 10:-1 - 9]
                filename_lpg_cmap_png = filename_cmap_png.replace('.png', '_4x4.png')
                plt.imsave(filename_lpg_cmap_png, np.log10(pred_4x4_cropped), cmap='Greys')
                pred_2x2_cropped = pred_2x2[10:-1 - 9, 10:-1 - 9]
                filename_lpg_cmap_png = filename_cmap_png.replace('.png', '_2x2.png')
                plt.imsave(filename_lpg_cmap_png, np.log10(pred_2x2_cropped), cmap='Greys')
                pred_1x1_cropped = pred_1x1[10:-1 - 9, 10:-1 - 9]
                filename_lpg_cmap_png = filename_cmap_png.replace('.png', '_1x1.png')
                plt.imsave(filename_lpg_cmap_png, np.log10(pred_1x1_cropped), cmap='Greys')
            else:
                plt.imsave(filename_cmap_png, np.log10(pred_depth), cmap='Greys')
                filename_lpg_cmap_png = filename_cmap_png.replace('.png', '_8x8.png')
                plt.imsave(filename_lpg_cmap_png, np.log10(pred_8x8), cmap='Greys')
                filename_lpg_cmap_png = filename_cmap_png.replace('.png', '_4x4.png')
                plt.imsave(filename_lpg_cmap_png, np.log10(pred_4x4), cmap='Greys')
                filename_lpg_cmap_png = filename_cmap_png.replace('.png', '_2x2.png')
                plt.imsave(filename_lpg_cmap_png, np.log10(pred_2x2), cmap='Greys')
                filename_lpg_cmap_png = filename_cmap_png.replace('.png', '_1x1.png')
                plt.imsave(filename_lpg_cmap_png, np.log10(pred_1x1), cmap='Greys')

    return
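
# Illustrative round trip (hypothetical helper names) for the 16-bit depth PNG convention
# used in test_images above: metric depth is multiplied by 256 and stored as uint16, and
# ground truth is read back and divided by 256.
def _save_depth_png(path, depth_metres):
    import cv2
    import numpy as np
    depth_uint16 = (depth_metres * 256.0).astype(np.uint16)
    cv2.imwrite(path, depth_uint16, [cv2.IMWRITE_PNG_COMPRESSION, 0])


def _load_depth_png(path):
    import cv2
    import numpy as np
    return cv2.imread(path, -1).astype(np.float32) / 256.0
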
def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    # Create model
    model = BtsModel(args)
    model.train()
    model.decoder.apply(weights_init_xavier)

    if args.bn_no_track_stats:
        print("Disabling tracking running stats in batch norm layers")
        model.apply(bn_init_as_tf)

    if args.fix_first_conv_blocks:
        if 'resne' in args.encoder:
            fixing_layers = ['base_model.conv1', 'base_model.layer1.0',
                             'base_model.layer1.1', '.bn']
        else:
            fixing_layers = ['conv0', 'denseblock1.denselayer1',
                             'denseblock1.denselayer2', 'norm']
        print("Fixing first two conv blocks")
    elif args.fix_first_conv_block:
        if 'resne' in args.encoder:
            fixing_layers = ['base_model.conv1', 'base_model.layer1.0', '.bn']
        else:
            fixing_layers = ['conv0', 'denseblock1.denselayer1', 'norm']
        print("Fixing first conv block")
    else:
        if 'resne' in args.encoder:
            fixing_layers = ['base_model.conv1', '.bn']
        else:
            fixing_layers = ['conv0', 'norm']
        print("Fixing first conv layer")

    for name, child in model.named_children():
        if 'encoder' not in name:
            continue
        for name2, parameters in child.named_parameters():
            # print(name, name2)
            if any(x in name2 for x in fixing_layers):
                parameters.requires_grad = False

    num_params = sum([np.prod(p.size()) for p in model.parameters()])
    print("Total number of parameters: {}".format(num_params))

    num_params_update = sum([np.prod(p.shape) for p in model.parameters() if p.requires_grad])
    print("Total number of learning parameters: {}".format(num_params_update))

    if args.distributed:
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            args.batch_size = int(args.batch_size / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu], find_unused_parameters=True)
        else:
            model.cuda()
            model = torch.nn.parallel.DistributedDataParallel(
                model, find_unused_parameters=True)
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()

    if args.distributed:
        print("Model Initialized on GPU: {}".format(args.gpu))
    else:
        print("Model Initialized")

    global_step = 0

    # Training parameters
    optimizer = torch.optim.AdamW(
        [{'params': model.module.encoder.parameters(), 'weight_decay': args.weight_decay},
         {'params': model.module.decoder.parameters(), 'weight_decay': 0}],
        lr=args.learning_rate, eps=1e-6)
    # optimizer = torch.optim.AdamW(model.parameters(), weight_decay=args.weight_decay, lr=args.learning_rate, eps=1e-3)

    if args.checkpoint_path != '':
        if os.path.isfile(args.checkpoint_path):
            print("Loading checkpoint '{}'".format(args.checkpoint_path))
            if args.gpu is None:
                checkpoint = torch.load(args.checkpoint_path)
            else:
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.checkpoint_path, map_location=loc)
            global_step = checkpoint['global_step']
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("Loaded checkpoint '{}' (global_step {})".format(
                args.checkpoint_path, checkpoint['global_step']))
        else:
            print("No checkpoint found at '{}'".format(args.checkpoint_path))

    if args.retrain:
        global_step = 0

    cudnn.benchmark = True

    dataloader = BtsDataLoader(args)

    # Logging
    if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                                                and args.rank % ngpus_per_node == 0):
        writer = SummaryWriter(args.log_directory + '/' + args.model_name + '/summaries',
                               flush_secs=30)

    silog_criterion = silog_loss(variance_focus=args.variance_focus)

    start_time = time.time()
    duration = 0
    log_freq = 100
    save_freq = 500
    num_log_images = args.batch_size
    end_learning_rate = args.end_learning_rate if args.end_learning_rate != -1 else 0.1 * args.learning_rate

    var_sum = [var.sum() for var in model.parameters() if var.requires_grad]
    var_cnt = len(var_sum)
    var_sum = np.sum(var_sum)

    print("Initial variables' sum: {:.3f}, avg: {:.3f}".format(var_sum, var_sum / var_cnt))

    steps_per_epoch = len(dataloader.data)
    num_total_steps = args.num_epochs * steps_per_epoch
    epoch = global_step // steps_per_epoch

    while epoch < args.num_epochs:
        if args.distributed:
            dataloader.train_sampler.set_epoch(epoch)

        for step, sample_batched in enumerate(dataloader.data):
            optimizer.zero_grad()
            before_op_time = time.time()

            image = torch.autograd.Variable(sample_batched['image'].cuda(args.gpu, non_blocking=True))
            focal = torch.autograd.Variable(sample_batched['focal'].cuda(args.gpu, non_blocking=True))
            depth_gt = torch.autograd.Variable(sample_batched['depth'].cuda(args.gpu, non_blocking=True))

            lpg8x8, lpg4x4, lpg2x2, reduc1x1, depth_est = model(image, focal)

            if args.dataset == 'nyu':
                mask = depth_gt > 0.1
            else:
                mask = depth_gt > 1.0

            loss = silog_criterion.forward(depth_est, depth_gt, mask.to(torch.bool))
            loss.backward()

            for param_group in optimizer.param_groups:
                current_lr = (args.learning_rate - end_learning_rate) * (
                    1 - global_step / num_total_steps)**0.9 + end_learning_rate
                param_group['lr'] = current_lr

            optimizer.step()

            if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                                                        and args.rank % ngpus_per_node == 0):
                print('[epoch][s/s_per_e/gs]: [{}][{}/{}/{}], lr: {:.12f}, loss: {:.12f}'.format(
                    epoch, step, steps_per_epoch, global_step, current_lr, loss))

            duration += time.time() - before_op_time

            if global_step and global_step % log_freq == 0:
                var_sum = [var.sum() for var in model.parameters() if var.requires_grad]
                var_cnt = len(var_sum)
                var_sum = np.sum(var_sum)
                examples_per_sec = args.batch_size / duration * log_freq
                duration = 0
                time_sofar = (time.time() - start_time) / 3600
                training_time_left = (num_total_steps / global_step - 1.0) * time_sofar

                if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                                                            and args.rank % ngpus_per_node == 0):
                    print("{}".format(args.model_name))
                    print_string = 'GPU: {} | examples/s: {:4.2f} | loss: {:.5f} | var sum: {:.3f} avg: {:.3f} | time elapsed: {:.2f}h | time left: {:.2f}h'
                    print(print_string.format(args.gpu, examples_per_sec, loss, var_sum.item(),
                                              var_sum.item() / var_cnt, time_sofar, training_time_left))

                if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                                                            and args.rank % ngpus_per_node == 0):
                    writer.add_scalar('silog_loss', loss, global_step)
                    writer.add_scalar('learning_rate', current_lr, global_step)
                    writer.add_scalar('var average', var_sum.item() / var_cnt, global_step)

                    depth_gt = torch.where(depth_gt < 1e-3, depth_gt * 0 + 1e3, depth_gt)
                    for i in range(num_log_images):
                        writer.add_image('depth_gt/image/{}'.format(i),
                                         normalize_result(1 / depth_gt[i, :, :, :].data), global_step)
                        writer.add_image('depth_est/image/{}'.format(i),
                                         normalize_result(1 / depth_est[i, :, :, :].data), global_step)
                        writer.add_image('reduc1x1/image/{}'.format(i),
                                         normalize_result(1 / reduc1x1[i, :, :, :].data), global_step)
                        writer.add_image('lpg2x2/image/{}'.format(i),
                                         normalize_result(1 / lpg2x2[i, :, :, :].data), global_step)
                        writer.add_image('lpg4x4/image/{}'.format(i),
                                         normalize_result(1 / lpg4x4[i, :, :, :].data), global_step)
                        writer.add_image('lpg8x8/image/{}'.format(i),
                                         normalize_result(1 / lpg8x8[i, :, :, :].data), global_step)
                        writer.add_image('image/image/{}'.format(i),
                                         inv_normalize(image[i, :, :, :]).data, global_step)
                    writer.flush()

            if global_step and global_step % save_freq == 0:
                if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                                                            and args.rank % ngpus_per_node == 0):
                    checkpoint = {'global_step': global_step,
                                  'model': model.state_dict(),
                                  'optimizer': optimizer.state_dict()}
                    torch.save(checkpoint, args.log_directory + '/' + args.model_name +
                               '/model-{}'.format(global_step))

            global_step += 1

        epoch += 1
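
# Standalone sketch (not called by the training loops above) of the polynomial
# learning-rate decay that both main_worker variants apply every step; the name
# `poly_decay_lr` is illustrative only.
def poly_decay_lr(base_lr, end_lr, global_step, num_total_steps, power=0.9):
    # Starts at base_lr when global_step == 0 and reaches end_lr at the final step
    return (base_lr - end_lr) * (1 - global_step / num_total_steps) ** power + end_lr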