def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    # Create model
    model = BtsModel(args)
    model.train()
    model.decoder.apply(weights_init_xavier)
    set_misc(model)

    num_params = sum([np.prod(p.size()) for p in model.parameters()])
    print("Total number of parameters: {}".format(num_params))
    num_params_update = sum(
        [np.prod(p.shape) for p in model.parameters() if p.requires_grad])
    print("Total number of learning parameters: {}".format(num_params_update))

    if args.distributed:
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            args.batch_size = int(args.batch_size / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu], find_unused_parameters=True)
        else:
            model.cuda()
            model = torch.nn.parallel.DistributedDataParallel(
                model, find_unused_parameters=True)
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()

    if args.distributed:
        print("Model Initialized on GPU: {}".format(args.gpu))
    else:
        print("Model Initialized")

    global_step = 0
    best_eval_measures_lower_better = torch.zeros(6).cpu() + 1e3
    best_eval_measures_higher_better = torch.zeros(3).cpu()
    best_eval_steps = np.zeros(9, dtype=np.int32)

    # Training parameters
    optimizer = torch.optim.AdamW([{
        'params': model.module.encoder.parameters(),
        'weight_decay': args.weight_decay
    }, {
        'params': model.module.decoder.parameters(),
        'weight_decay': 0
    }],
                                  lr=args.learning_rate,
                                  eps=args.adam_eps)

    model_just_loaded = False
    if args.checkpoint_path != '':
        if os.path.isfile(args.checkpoint_path):
            print("Loading checkpoint '{}'".format(args.checkpoint_path))
            if args.gpu is None:
                checkpoint = torch.load(args.checkpoint_path)
            else:
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.checkpoint_path, map_location=loc)
            global_step = checkpoint['global_step']
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            try:
                best_eval_measures_higher_better = checkpoint[
                    'best_eval_measures_higher_better'].cpu()
                best_eval_measures_lower_better = checkpoint[
                    'best_eval_measures_lower_better'].cpu()
                best_eval_steps = checkpoint['best_eval_steps']
            except KeyError:
                print("Could not load values for online evaluation")
            print("Loaded checkpoint '{}' (global_step {})".format(
                args.checkpoint_path, checkpoint['global_step']))
        else:
            print("No checkpoint found at '{}'".format(args.checkpoint_path))
        model_just_loaded = True

    if args.retrain:
        global_step = 0

    cudnn.benchmark = True

    dataloader = BtsDataLoader(args)
    dataloader_eval = BtsDataLoader(args)

    # Logging
    if not args.multiprocessing_distributed or (
            args.multiprocessing_distributed
            and args.rank % ngpus_per_node == 0):
        writer = SummaryWriter(os.path.join(args.log_directory,
                                            args.model_name, 'summaries'),
                               flush_secs=30)
        if args.do_online_eval:
            if args.eval_summary_directory != '':
                eval_summary_path = os.path.join(args.eval_summary_directory,
                                                 args.model_name)
            else:
                eval_summary_path = os.path.join(args.log_directory, 'eval')
            eval_summary_writer = SummaryWriter(eval_summary_path,
                                                flush_secs=30)

    silog_criterion = silog_loss(variance_focus=args.variance_focus)

    start_time = time.time()
    duration = 0

    num_log_images = args.batch_size
    end_learning_rate = args.end_learning_rate if args.end_learning_rate != -1 else 0.1 * args.learning_rate

    var_sum = [var.sum() for var in model.parameters() if var.requires_grad]
    var_cnt = len(var_sum)
    var_sum = np.sum(var_sum)

    print("Initial variables' sum: {:.3f}, avg: {:.3f}".format(
        var_sum, var_sum / var_cnt))

    steps_per_epoch = len(dataloader.data)
    num_total_steps = args.num_epochs * steps_per_epoch
    epoch = global_step // steps_per_epoch

    while epoch < args.num_epochs:
        if args.distributed:
            dataloader.train_sampler.set_epoch(epoch)

        for step, sample_batched in enumerate(dataloader.data):
            optimizer.zero_grad()
            before_op_time = time.time()

            mask = torch.autograd.Variable(sample_batched['mask'].cuda(
                args.gpu, non_blocking=True))
            image = torch.autograd.Variable(sample_batched['image'].cuda(
                args.gpu, non_blocking=True))
            focal = torch.autograd.Variable(sample_batched['focal'].cuda(
                args.gpu, non_blocking=True))
            depth_gt = torch.autograd.Variable(sample_batched['depth'].cuda(
                args.gpu, non_blocking=True))

            lpg8x8, lpg4x4, lpg2x2, reduc1x1, depth_est = model(image, focal)

            # Valid-depth mask: pixels at or below the dataset's minimum depth are ignored
            if args.dataset == 'nyu':
                mask = depth_gt > 0.1
            elif args.dataset == 'kitti':
                mask = depth_gt > 1.0

            loss = silog_criterion.forward(depth_est, depth_gt,
                                           mask.to(torch.bool))
            loss.backward()

            # Polynomial decay (power 0.9) from learning_rate to end_learning_rate
            for param_group in optimizer.param_groups:
                current_lr = (args.learning_rate - end_learning_rate) * (
                    1 - global_step / num_total_steps)**0.9 + end_learning_rate
                param_group['lr'] = current_lr

            optimizer.step()

            if not args.multiprocessing_distributed or (
                    args.multiprocessing_distributed
                    and args.rank % ngpus_per_node == 0):
                print(
                    '[epoch][s/s_per_e/gs]: [{}][{}/{}/{}], lr: {:.12f}, loss: {:.12f}'
                    .format(epoch, step, steps_per_epoch, global_step,
                            current_lr, loss))
                if np.isnan(loss.cpu().item()):
                    print('NaN in loss occurred. Aborting training.')
                    return -1
            duration += time.time() - before_op_time
            if global_step and global_step % args.log_freq == 0 and not model_just_loaded:
                var_sum = [
                    var.sum() for var in model.parameters()
                    if var.requires_grad
                ]
                var_cnt = len(var_sum)
                var_sum = np.sum(var_sum)
                examples_per_sec = args.batch_size / duration * args.log_freq
                duration = 0
                time_sofar = (time.time() - start_time) / 3600
                training_time_left = (num_total_steps / global_step -
                                      1.0) * time_sofar
                if not args.multiprocessing_distributed or (
                        args.multiprocessing_distributed
                        and args.rank % ngpus_per_node == 0):
                    print("{}".format(args.model_name))
                    print_string = 'GPU: {} | examples/s: {:4.2f} | loss: {:.5f} | var sum: {:.3f} avg: {:.3f} | time elapsed: {:.2f}h | time left: {:.2f}h'
                    print(
                        print_string.format(args.gpu, examples_per_sec, loss,
                                            var_sum.item(),
                                            var_sum.item() / var_cnt,
                                            time_sofar, training_time_left))

                if not args.multiprocessing_distributed or (
                        args.multiprocessing_distributed
                        and args.rank % ngpus_per_node == 0):
                    writer.add_scalar('silog_loss', loss, global_step)
                    writer.add_scalar('learning_rate', current_lr, global_step)
                    writer.add_scalar('var average',
                                      var_sum.item() / var_cnt, global_step)
                    depth_gt = torch.where(depth_gt < 1e-3,
                                           depth_gt * 0 + 1e3, depth_gt)
                    for i in range(num_log_images):
                        writer.add_image(
                            'depth_gt/image/{}'.format(i),
                            normalize_result(1 / depth_gt[i, :, :, :].data),
                            global_step)
                        writer.add_image(
                            'depth_est/image/{}'.format(i),
                            normalize_result(1 / depth_est[i, :, :, :].data),
                            global_step)
                        writer.add_image(
                            'reduc1x1/image/{}'.format(i),
                            normalize_result(1 / reduc1x1[i, :, :, :].data),
                            global_step)
                        writer.add_image(
                            'lpg2x2/image/{}'.format(i),
                            normalize_result(1 / lpg2x2[i, :, :, :].data),
                            global_step)
                        writer.add_image(
                            'lpg4x4/image/{}'.format(i),
                            normalize_result(1 / lpg4x4[i, :, :, :].data),
                            global_step)
                        writer.add_image(
                            'lpg8x8/image/{}'.format(i),
                            normalize_result(1 / lpg8x8[i, :, :, :].data),
                            global_step)
                        writer.add_image('image/image/{}'.format(i),
                                         inv_normalize(image[i, :, :, :]).data,
                                         global_step)
                    writer.flush()

            if not args.do_online_eval and global_step and global_step % args.save_freq == 0:
                if not args.multiprocessing_distributed or (
                        args.multiprocessing_distributed
                        and args.rank % ngpus_per_node == 0):
                    checkpoint = {
                        'global_step': global_step,
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict()
                    }
                    torch.save(
                        checkpoint, args.log_directory + '/' +
                        args.model_name + '/model-{}'.format(global_step))

            if args.do_online_eval and global_step and global_step % args.eval_freq == 0 and not model_just_loaded:
                time.sleep(0.1)
                model.eval()
                eval_measures = online_eval(model, dataloader_eval, gpu,
                                            ngpus_per_node)
                if eval_measures is not None:
                    for i in range(9):
                        eval_summary_writer.add_scalar(eval_metrics[i],
                                                       eval_measures[i].cpu(),
                                                       int(global_step))
                        measure = eval_measures[i]
                        is_best = False
                        if i < 6 and measure < best_eval_measures_lower_better[i]:
                            old_best = best_eval_measures_lower_better[i].item()
                            best_eval_measures_lower_better[i] = measure.item()
                            is_best = True
                        elif i >= 6 and measure > best_eval_measures_higher_better[i - 6]:
                            old_best = best_eval_measures_higher_better[i - 6].item()
                            best_eval_measures_higher_better[i - 6] = measure.item()
                            is_best = True
                        if is_best:
                            old_best_step = best_eval_steps[i]
                            old_best_name = '/model-{}-best_{}_{:.5f}'.format(
                                old_best_step, eval_metrics[i], old_best)
                            model_path = args.log_directory + '/' + args.model_name + old_best_name
                            if os.path.exists(model_path):
                                command = 'rm {}'.format(model_path)
                                os.system(command)
                            best_eval_steps[i] = global_step
                            model_save_name = '/model-{}-best_{}_{:.5f}'.format(
                                global_step, eval_metrics[i], measure)
                            print('New best for {}. Saving model: {}'.format(
                                eval_metrics[i], model_save_name))
                            checkpoint = {
                                'global_step': global_step,
                                'model': model.state_dict(),
                                'optimizer': optimizer.state_dict(),
                                'best_eval_measures_higher_better':
                                    best_eval_measures_higher_better,
                                'best_eval_measures_lower_better':
                                    best_eval_measures_lower_better,
                                'best_eval_steps': best_eval_steps
                            }
                            torch.save(
                                checkpoint, args.log_directory + '/' +
                                args.model_name + model_save_name)
                    eval_summary_writer.flush()
                model.train()
                block_print()
                set_misc(model)
                enable_print()

            model_just_loaded = False
            global_step += 1

        epoch += 1
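# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original script: the training loop above
# adjusts the learning rate in-line with a polynomial decay of power 0.9, going
# from args.learning_rate at step 0 down to end_learning_rate at the final
# step. Factoring the schedule out (the helper name below is hypothetical)
# makes it easier to inspect or unit-test in isolation.
# ---------------------------------------------------------------------------
def poly_decay_lr(base_lr, end_lr, step, total_steps, power=0.9):
    """Polynomial decay: returns base_lr at step 0 and end_lr at total_steps."""
    return (base_lr - end_lr) * (1 - step / total_steps) ** power + end_lr

# Example: with base_lr=1e-4 and the default end_lr of 0.1 * base_lr = 1e-5,
# the learning rate halfway through training is roughly 5.8e-5:
# poly_decay_lr(1e-4, 1e-5, step=50, total_steps=100)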
def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    # Create model
    model = BtsModel(args)
    model.train()
    model.decoder.apply(weights_init_xavier)

    if args.bn_no_track_stats:
        print("Disabling tracking running stats in batch norm layers")
        model.apply(bn_init_as_tf)

    if args.fix_first_conv_blocks:
        if 'resne' in args.encoder:
            fixing_layers = [
                'base_model.conv1', 'base_model.layer1.0',
                'base_model.layer1.1', '.bn'
            ]
        else:
            fixing_layers = [
                'conv0', 'denseblock1.denselayer1',
                'denseblock1.denselayer2', 'norm'
            ]
        print("Fixing first two conv blocks")
    elif args.fix_first_conv_block:
        if 'resne' in args.encoder:
            fixing_layers = ['base_model.conv1', 'base_model.layer1.0', '.bn']
        else:
            fixing_layers = ['conv0', 'denseblock1.denselayer1', 'norm']
        print("Fixing first conv block")
    else:
        if 'resne' in args.encoder:
            fixing_layers = ['base_model.conv1', '.bn']
        else:
            fixing_layers = ['conv0', 'norm']
        print("Fixing first conv layer")

    for name, child in model.named_children():
        if not 'encoder' in name:
            continue
        for name2, parameters in child.named_parameters():
            # print(name, name2)
            if any(x in name2 for x in fixing_layers):
                parameters.requires_grad = False

    num_params = sum([np.prod(p.size()) for p in model.parameters()])
    print("Total number of parameters: {}".format(num_params))
    num_params_update = sum(
        [np.prod(p.shape) for p in model.parameters() if p.requires_grad])
    print("Total number of learning parameters: {}".format(num_params_update))

    if args.distributed:
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            args.batch_size = int(args.batch_size / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu], find_unused_parameters=True)
        else:
            model.cuda()
            model = torch.nn.parallel.DistributedDataParallel(
                model, find_unused_parameters=True)
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()

    if args.distributed:
        print("Model Initialized on GPU: {}".format(args.gpu))
    else:
        print("Model Initialized")

    global_step = 0

    # Training parameters
    optimizer = torch.optim.AdamW([{
        'params': model.module.encoder.parameters(),
        'weight_decay': args.weight_decay
    }, {
        'params': model.module.decoder.parameters(),
        'weight_decay': 0
    }],
                                  lr=args.learning_rate,
                                  eps=1e-6)
    # optimizer = torch.optim.AdamW(model.parameters(), weight_decay=args.weight_decay, lr=args.learning_rate, eps=1e-3)

    if args.checkpoint_path != '':
        if os.path.isfile(args.checkpoint_path):
            print("Loading checkpoint '{}'".format(args.checkpoint_path))
            if args.gpu is None:
                checkpoint = torch.load(args.checkpoint_path)
            else:
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.checkpoint_path, map_location=loc)
            global_step = checkpoint['global_step']
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("Loaded checkpoint '{}' (global_step {})".format(
                args.checkpoint_path, checkpoint['global_step']))
        else:
            print("No checkpoint found at '{}'".format(args.checkpoint_path))

    if args.retrain:
        global_step = 0

    cudnn.benchmark = True

    dataloader = BtsDataLoader(args)

    # Logging
    if not args.multiprocessing_distributed or (
            args.multiprocessing_distributed
            and args.rank % ngpus_per_node == 0):
        writer = SummaryWriter(args.log_directory + '/' + args.model_name +
                               '/summaries',
                               flush_secs=30)

    silog_criterion = silog_loss(variance_focus=args.variance_focus)

    start_time = time.time()
    duration = 0

    log_freq = 100
    save_freq = 500
    num_log_images = args.batch_size
    end_learning_rate = args.end_learning_rate if args.end_learning_rate != -1 else 0.1 * args.learning_rate

    var_sum = [var.sum() for var in model.parameters() if var.requires_grad]
    var_cnt = len(var_sum)
    var_sum = np.sum(var_sum)

    print("Initial variables' sum: {:.3f}, avg: {:.3f}".format(
        var_sum, var_sum / var_cnt))

    steps_per_epoch = len(dataloader.data)
    num_total_steps = args.num_epochs * steps_per_epoch
    epoch = global_step // steps_per_epoch

    while epoch < args.num_epochs:
        if args.distributed:
            dataloader.train_sampler.set_epoch(epoch)

        for step, sample_batched in enumerate(dataloader.data):
            optimizer.zero_grad()
            before_op_time = time.time()

            image = torch.autograd.Variable(sample_batched['image'].cuda(
                args.gpu, non_blocking=True))
            focal = torch.autograd.Variable(sample_batched['focal'].cuda(
                args.gpu, non_blocking=True))
            depth_gt = torch.autograd.Variable(sample_batched['depth'].cuda(
                args.gpu, non_blocking=True))

            lpg8x8, lpg4x4, lpg2x2, reduc1x1, depth_est = model(image, focal)

            if args.dataset == 'nyu':
                mask = depth_gt > 0.1
            else:
                mask = depth_gt > 1.0

            loss = silog_criterion.forward(depth_est, depth_gt,
                                           mask.to(torch.bool))
            loss.backward()

            for param_group in optimizer.param_groups:
                current_lr = (args.learning_rate - end_learning_rate) * (
                    1 - global_step / num_total_steps)**0.9 + end_learning_rate
                param_group['lr'] = current_lr

            optimizer.step()

            if not args.multiprocessing_distributed or (
                    args.multiprocessing_distributed
                    and args.rank % ngpus_per_node == 0):
                print(
                    '[epoch][s/s_per_e/gs]: [{}][{}/{}/{}], lr: {:.12f}, loss: {:.12f}'
                    .format(epoch, step, steps_per_epoch, global_step,
                            current_lr, loss))

            duration += time.time() - before_op_time
            if global_step and global_step % log_freq == 0:
                var_sum = [
                    var.sum() for var in model.parameters()
                    if var.requires_grad
                ]
                var_cnt = len(var_sum)
                var_sum = np.sum(var_sum)
                examples_per_sec = args.batch_size / duration * log_freq
                duration = 0
                time_sofar = (time.time() - start_time) / 3600
                training_time_left = (num_total_steps / global_step -
                                      1.0) * time_sofar
                if not args.multiprocessing_distributed or (
                        args.multiprocessing_distributed
                        and args.rank % ngpus_per_node == 0):
                    print("{}".format(args.model_name))
                    print_string = 'GPU: {} | examples/s: {:4.2f} | loss: {:.5f} | var sum: {:.3f} avg: {:.3f} | time elapsed: {:.2f}h | time left: {:.2f}h'
                    print(
                        print_string.format(args.gpu, examples_per_sec, loss,
                                            var_sum.item(),
                                            var_sum.item() / var_cnt,
                                            time_sofar, training_time_left))

                if not args.multiprocessing_distributed or (
                        args.multiprocessing_distributed
                        and args.rank % ngpus_per_node == 0):
                    writer.add_scalar('silog_loss', loss, global_step)
                    writer.add_scalar('learning_rate', current_lr, global_step)
                    writer.add_scalar('var average',
                                      var_sum.item() / var_cnt, global_step)
                    depth_gt = torch.where(depth_gt < 1e-3,
                                           depth_gt * 0 + 1e3, depth_gt)
                    for i in range(num_log_images):
                        writer.add_image(
                            'depth_gt/image/{}'.format(i),
                            normalize_result(1 / depth_gt[i, :, :, :].data),
                            global_step)
                        writer.add_image(
                            'depth_est/image/{}'.format(i),
                            normalize_result(1 / depth_est[i, :, :, :].data),
                            global_step)
                        writer.add_image(
                            'reduc1x1/image/{}'.format(i),
                            normalize_result(1 / reduc1x1[i, :, :, :].data),
                            global_step)
                        writer.add_image(
                            'lpg2x2/image/{}'.format(i),
                            normalize_result(1 / lpg2x2[i, :, :, :].data),
                            global_step)
                        writer.add_image(
                            'lpg4x4/image/{}'.format(i),
                            normalize_result(1 / lpg4x4[i, :, :, :].data),
                            global_step)
                        writer.add_image(
                            'lpg8x8/image/{}'.format(i),
                            normalize_result(1 / lpg8x8[i, :, :, :].data),
                            global_step)
                        writer.add_image('image/image/{}'.format(i),
                                         inv_normalize(image[i, :, :, :]).data,
                                         global_step)
                    writer.flush()

            if global_step and global_step % save_freq == 0:
                if not args.multiprocessing_distributed or (
                        args.multiprocessing_distributed
                        and args.rank % ngpus_per_node == 0):
                    checkpoint = {
                        'global_step': global_step,
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict()
                    }
                    torch.save(
                        checkpoint, args.log_directory + '/' +
                        args.model_name + '/model-{}'.format(global_step))

            global_step += 1

        epoch += 1
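# ---------------------------------------------------------------------------
# Launch sketch (illustrative only, not part of the original file): main_worker
# above follows the standard torch.multiprocessing pattern, where one training
# process is spawned per GPU and each process receives its GPU index as the
# first argument. The wrapper name launch_training is an assumption for this
# example; the real script's entry point and argument parsing may differ.
# ---------------------------------------------------------------------------
def launch_training(args):
    import torch.multiprocessing as mp

    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # world_size becomes the total number of processes across all nodes
        args.world_size = ngpus_per_node * args.world_size
        # spawn one process per GPU; each call receives (gpu, ngpus_per_node, args)
        mp.spawn(main_worker,
                 nprocs=ngpus_per_node,
                 args=(ngpus_per_node, args))
    else:
        # single-process path: run main_worker directly on the chosen GPU
        main_worker(args.gpu, ngpus_per_node, args)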