def runModel(config, data_dictionary, data_statistics, train_test_folds):
    program_start_time = time()

    # assign all program arguments to local variables
    with open(config['model']['path']) as handle:
        ModelDict = json.loads(handle.read())

    # check if station and grid time invariant features should be used and set the list of desired parameters
    if not ('grid_time_invariant' in ModelDict and ModelDict['grid_time_invariant']):
        config['grid_time_invariant_parameters'] = []
    if not ('station_time_invariant' in ModelDict and ModelDict['station_time_invariant']):
        config['station_parameters'] = []

    # update general static model information
    experiment_info = config
    experiment_info['model'] = ModelDict
    experiment_info['code_commit'] = ModelUtils.get_git_revision_short_hash()

    # if needed, load time invariant features
    with open("%s/%s/grid_size_%s/time_invariant_data_per_station.pkl" % (
            config['input_source'], config['preprocessing'], config['original_grid_size']), "rb") as input_file:
        time_invariant_data = pkl.load(input_file)

    # initialize feature scaling function for each feature
    featureScaleFunctions = DataUtils.getFeatureScaleFunctions(ModelUtils.ParamNormalizationDict, data_statistics)

    # get optimizer config
    optimizer_config = config['optimizer']

    # generate output path for experiment information
    setting_string = '%s_grid_%s_bs_%s_tf_%s_optim_%s_lr_%s_sl_%s' % (
        config['model']['name'], config['grid_size'], config['batch_size'], config['test_fraction'],
        optimizer_config['algorithm'], optimizer_config['learning_rate'], config['slice_size'])
    output_path = '%s/%s' % (config['experiment_path'], setting_string)
    if not os.path.exists(output_path):
        os.makedirs(output_path)

    # time for the set-up until the first run
    experiment_info['set_up_time'] = time() - program_start_time
    print('[Time]: Set-up %s' % strftime("%H:%M:%S", gmtime(experiment_info['set_up_time'])))
    sys.stdout.flush()

    # initialize statistics
    error_statistics = None
    run_times = None
    skip_statistics = None
    if 'per_station_rmse' in config:
        error_per_station_statistics = None

    # keep used learning rates
    experiment_info['scheduled_learning_rates'] = []

    # cross-validation
    for run in range(config['runs']):
        # loggers for tensorboardX
        train_logger = Logger(output_path + '/logs/run_%s/train' % run)
        test_logger = Logger(output_path + '/logs/run_%s/test' % run)

        print('[Run %s] Cross-validation test fold %s' % (str(run + 1), str(run + 1)))

        # take the right preprocessed train/test data set for the current run
        train_fold, test_fold = train_test_folds[run]

        # initialize best epoch test error
        best_epoch_test_rmse = float("inf")

        # use a different data loader if we want to train a 3-nearest-neighbor model approach
        if "knn" in ModelDict:
            # initialize train and test dataloaders
            trainset = DataLoaders.CosmoData3NNData(
                config=config, station_data_dict=data_dictionary, files=train_fold,
                featureScaling=featureScaleFunctions, time_invariant_data=time_invariant_data)
            trainloader = DataLoader(trainset, batch_size=config['batch_size'], shuffle=True,
                                     num_workers=config['n_loaders'], collate_fn=DataLoaders.collate_fn)
            testset = DataLoaders.CosmoData3NNData(
                config=config, station_data_dict=data_dictionary, files=test_fold,
                featureScaling=featureScaleFunctions, time_invariant_data=time_invariant_data)
            testloader = DataLoader(testset, batch_size=config['batch_size'], shuffle=True,
                                    num_workers=config['n_loaders'], collate_fn=DataLoaders.collate_fn)
        else:
            # initialize train and test dataloaders
            trainset = DataLoaders.CosmoDataGridData(
                config=config, station_data_dict=data_dictionary, files=train_fold,
                featureScaling=featureScaleFunctions, time_invariant_data=time_invariant_data)
            trainloader = DataLoader(trainset, batch_size=config['batch_size'], shuffle=True,
                                     num_workers=config['n_loaders'], collate_fn=DataLoaders.collate_fn)
            testset = DataLoaders.CosmoDataGridData(
                config=config, station_data_dict=data_dictionary, files=test_fold,
                featureScaling=featureScaleFunctions, time_invariant_data=time_invariant_data)
            testloader = DataLoader(testset, batch_size=config['batch_size'], shuffle=True,
                                    num_workers=config['n_loaders'], collate_fn=DataLoaders.collate_fn)

        # initialize network, optimizer and loss function
        net = Baseline.model_factory(ModelDict, trainset.n_parameters, trainset.n_grid_time_invariant_parameters,
                                     config['grid_size'], config['prediction_times'])
        # store class name
        experiment_info['model_class'] = net.__class__.__name__

        if torch.cuda.device_count() > 1:
            net = nn.DataParallel(net)
        if torch.cuda.is_available():
            net.cuda()

        # load number of train and test samples
        n_train_samples, n_test_samples = len(train_fold), len(test_fold)

        optimizer, scheduler = ModelUtils.initializeOptimizer(optimizer_config, net)
        criterion = nn.MSELoss()

        # keep number of processed samples over all epochs for tensorboard
        processed_train_samples_global = 0
        processed_test_samples_global = 0

        # start learning
        for epoch in range(config['epochs']):
            epoch_train_time = np.zeros((5,))
            epoch_start_time = time()
            print('Epoch: ' + str(epoch + 1) + '\n------------------------------------------------------------')

            # adapt learning rate and store information in experiment attributes
            if scheduler is not None:
                scheduler.step()
                if run == 0:
                    experiment_info['scheduled_learning_rates'] += scheduler.get_lr()
                print('Using learning rate %s' % str(scheduler.get_lr()))

            # TRAINING
            # initialize variables for epoch statistics
            LABELS, MODELoutputs, COSMOoutputs = None, None, None
            processed_train_samples = 0
            net.train(True)
            train_start_time = time()
            # reference point for the "waiting on data" timing of the first batch
            time_end = time()

            # loop over complete train set
            for i, data in enumerate(trainloader, 0):
                time_start = time()
                try:
                    # get training batch, e.g. label, COSMO-1 output and time invariant features for the station
                    DATA = data
                    # DATA only has length 4 if we do not use the station time invariant features
                    if len(DATA) == 4:
                        Blabel, Bip2d, BTimeData, init_station_temp = DATA
                        station_time_inv_input = None
                    elif len(DATA) == 5:
                        Blabel, Bip2d, BTimeData, StationTimeInv, init_station_temp = DATA
                        station_time_inv_input = ModelUtils.getVariable(StationTimeInv).float()
                    else:
                        raise Exception('Unknown data format for training...')
                    input = ModelUtils.getVariable(Bip2d).float()
                    time_data = ModelUtils.getVariable(BTimeData).float()
                    target = ModelUtils.getVariable(Blabel).float()
                except TypeError:
                    # when the batch size is small, it can happen that all labels have been corrupted and
                    # collate_fn therefore returns an empty list
                    print('Value error...')
                    continue

                time_after_data_preparation = time()

                processed_train_samples += len(Blabel)

                optimizer.zero_grad()
                out = net(input, time_data, station_time_inv_input)
                time_after_forward_pass = time()
                loss = criterion(out, target)
                loss.backward()
                optimizer.step()
                time_after_backward_pass = time()

                if LABELS is None:
                    LABELS = Blabel.data
                    MODELoutputs = out.data
                    COSMOoutputs = init_station_temp[2].data
                else:
                    LABELS = np.vstack((LABELS, Blabel.data))
                    MODELoutputs = np.vstack((MODELoutputs, out.data))
                    COSMOoutputs = np.vstack((COSMOoutputs, init_station_temp[2].data))
                time_after_label_stack = time()

                if (i + 1) % 64 == 0:
                    print('Sample: %s \t Loss: %s' % (processed_train_samples, float(np.sqrt(loss.data))))

                    # ============ TensorBoard logging ============#
                    # (1) Log the scalar values
                    info = {
                        setting_string: np.sqrt(loss.item()),
                    }
                    for tag, value in info.items():
                        train_logger.scalar_summary(tag, value,
                                                    processed_train_samples_global + processed_train_samples)

                    # (2) Log values and gradients of the parameters (histogram)
                    for tag, value in net.named_parameters():
                        tag = tag.replace('.', '/')
                        train_logger.histo_summary(tag, ModelUtils.to_np(value), i + 1)
                        train_logger.histo_summary(tag + '/grad', ModelUtils.to_np(value.grad), i + 1)

                epoch_train_time += np.array((time_start - time_end,
                                              time_after_data_preparation - time_start,
                                              time_after_forward_pass - time_after_data_preparation,
                                              time_after_backward_pass - time_after_forward_pass,
                                              time_after_label_stack - time_after_backward_pass))
                time_end = time()

            # calculate error statistic of current epoch
            diff_model = MODELoutputs - LABELS
            diff_cosmo = COSMOoutputs - LABELS
            epoch_train_rmse_model = np.apply_along_axis(func1d=ModelUtils.rmse, arr=diff_model, axis=0)
            epoch_train_rmse_cosmo = np.apply_along_axis(func1d=ModelUtils.rmse, arr=diff_cosmo, axis=0)

            # update global processed samples
            processed_train_samples_global += processed_train_samples

            if np.isnan(epoch_train_rmse_model).any():
                print("Learning rate too large resulted in NaN-error while training. Stopped training...")
                return

            # print epoch training times
            print('Timing: Waiting on data=%s, Data Preparation=%s, '
                  'Forward Pass=%s, Backward Pass=%s, Data Stacking=%s'
                  % tuple(list(epoch_train_time / len(epoch_train_time))))

            # RMSE of epoch
            print('Train/test statistic for epoch: %s' % str(epoch + 1))
            print('Train RMSE COSMO: ', ", ".join(["T=%s: %s" % (idx, epoch_train_rmse_cosmo[idx])
                                                   for idx in range(len(epoch_train_rmse_cosmo))]))
            print('Train RMSE Model: ', ", ".join(["T=%s: %s" % (idx, epoch_train_rmse_model[idx])
                                                   for idx in range(len(epoch_train_rmse_model))]))
            sys.stdout.flush()

            train_time = time() - train_start_time

            # TESTING
            test_start_time = time()
            LABELS, MODELoutputs, COSMOoutputs, STATION = None, None, None, None
            processed_test_samples = 0
            net.eval()
            for i, data in enumerate(testloader, 0):
                try:
                    # get test batch, e.g. label, COSMO-1 output and time invariant features for the station
                    DATA = data
                    # DATA only has length 4 if we do not use the station time invariant features
                    if len(DATA) == 4:
                        Blabel, Bip2d, BTimeData, init_station_temp = DATA
                        station_time_inv_input = None
                    elif len(DATA) == 5:
                        Blabel, Bip2d, BTimeData, StationTimeInv, init_station_temp = DATA
                        station_time_inv_input = ModelUtils.getVariable(StationTimeInv).float()
                    else:
                        raise Exception('Unknown data format for testing...')
                    input = ModelUtils.getVariable(Bip2d).float()
                    time_data = ModelUtils.getVariable(BTimeData).float()
                    target = ModelUtils.getVariable(Blabel).float()
                except TypeError:
                    # when the batch size is small, it can happen that all labels have been corrupted and
                    # collate_fn therefore returns an empty list
                    print('Value error...')
                    continue

                processed_test_samples += len(Blabel)

                out = net(input, time_data, station_time_inv_input)
                loss = criterion(out, target)

                if LABELS is None:
                    LABELS = Blabel.data
                    MODELoutputs = out.data
                    COSMOoutputs = init_station_temp[2].data
                    STATION = init_station_temp[1].data
                else:
                    LABELS = np.vstack((LABELS, Blabel.data))
                    MODELoutputs = np.vstack((MODELoutputs, out.data))
                    COSMOoutputs = np.vstack((COSMOoutputs, init_station_temp[2].data))
                    STATION = np.hstack((STATION, init_station_temp[1].data))

                if (i + 1) % 16 == 0:
                    # ============ TensorBoard logging ============#
                    # (1) Log the scalar values
                    info = {
                        setting_string: np.sqrt(loss.item()),
                    }
                    for tag, value in info.items():
                        test_logger.scalar_summary(tag, value,
                                                   processed_test_samples_global + processed_test_samples)

            # calculate error statistic of current epoch
            diff_model = MODELoutputs - LABELS
            diff_cosmo = COSMOoutputs - LABELS

            # rmse
            epoch_test_rmse_model = np.apply_along_axis(func1d=ModelUtils.rmse, arr=diff_model, axis=0)
            epoch_test_rmse_cosmo = np.apply_along_axis(func1d=ModelUtils.rmse, arr=diff_cosmo, axis=0)
            overall_test_rmse_model = ModelUtils.rmse(diff_model)
            overall_test_rmse_cosmo = ModelUtils.rmse(diff_cosmo)

            # mae
            epoch_test_mae_model = np.apply_along_axis(func1d=ModelUtils.mae, arr=diff_model, axis=0)
            epoch_test_mae_cosmo = np.apply_along_axis(func1d=ModelUtils.mae, arr=diff_cosmo, axis=0)
            overall_test_mae_model = ModelUtils.mae(diff_model)
            overall_test_mae_cosmo = ModelUtils.mae(diff_cosmo)

            # calculate per-station rmse if desired (especially for the K-fold station generalization experiment)
            if "per_station_rmse" in config:
                max_station_id = 1435
                squared_errors_per_epoch = np.array((np.square(diff_model), np.square(diff_cosmo))).squeeze()
                # the highest station index in the data is 1435, thus we need at least 1436 entries,
                # which we can then access by station id
                test_samples_per_station = np.bincount(STATION, minlength=max_station_id + 1)
                model_squared_error_per_station = np.bincount(STATION, weights=squared_errors_per_epoch[0],
                                                              minlength=max_station_id + 1)
                cosmo_squared_error_per_station = np.bincount(STATION, weights=squared_errors_per_epoch[1],
                                                              minlength=max_station_id + 1)
                # set division by zero/NaN warning to 'ignore'
                np.seterr(divide='ignore', invalid='ignore')
                # calculate rmse per station
                rmse_per_station = np.vstack(
                    (np.sqrt(np.divide(model_squared_error_per_station, test_samples_per_station)),
                     np.sqrt(np.divide(cosmo_squared_error_per_station, test_samples_per_station)))).T
                # set division by zero/NaN warning back to 'warn'
                np.seterr(divide='warn', invalid='warn')

            # update global processed samples
            processed_test_samples_global += processed_test_samples

            # RMSE of epoch
            print('Test RMSE COSMO: ', ", ".join(["T=%s: %s" % (idx, epoch_test_rmse_cosmo[idx])
                                                  for idx in range(len(epoch_test_rmse_cosmo))]),
                  " (Overall: %s)" % overall_test_rmse_cosmo)
            print('Test RMSE Model: ', ", ".join(["T=%s: %s" % (idx, epoch_test_rmse_model[idx])
                                                  for idx in range(len(epoch_test_rmse_model))]),
                  " (Overall: %s)" % overall_test_rmse_model)

            # MAE of epoch
            print('Test MAE COSMO: ', ", ".join(["T=%s: %s" % (idx, epoch_test_mae_cosmo[idx])
                                                 for idx in range(len(epoch_test_mae_cosmo))]),
                  " (Overall: %s)" % overall_test_mae_cosmo)
            print('Test MAE Model: ', ", ".join(["T=%s: %s" % (idx, epoch_test_mae_model[idx])
                                                 for idx in range(len(epoch_test_mae_model))]),
                  " (Overall: %s)" % overall_test_mae_model)
            sys.stdout.flush()

            test_time = time() - test_start_time

            # time for epoch
            epoch_time = time() - epoch_start_time

            # update error statistics
            error_statistics = ModelUtils.updateErrorStatistic(
                error_statistics,
                np.array([epoch_train_rmse_model, epoch_test_rmse_model])[None, None, ...],
                run, epoch, config['prediction_times'])
            # update run time statistic
            run_times = ModelUtils.updateRuntimeStatistic(
                run_times, np.array([epoch_time, train_time, test_time])[None, None, ...], run, epoch)
            # update skip statistic
            skip_statistics = ModelUtils.updateSkipStatistic(
                skip_statistics,
                np.array([n_train_samples, processed_train_samples,
                          n_test_samples, processed_test_samples])[None, None, ...],
                run, epoch)
            # update per-station rmse data array over runs if desired
            # (especially for the K-fold station generalization experiment)
            if "per_station_rmse" in config:
                error_per_station_statistics = ModelUtils.updatePerStationErrorStatistic(
                    error_per_station_statistics, rmse_per_station, run, epoch, np.arange(max_station_id + 1))

            # store the model if it was the best so far
            is_best = overall_test_rmse_model <= best_epoch_test_rmse
            best_epoch_test_rmse = min(overall_test_rmse_model, best_epoch_test_rmse)
            ModelUtils.save_checkpoint({
                'epoch': epoch,
                'run': run,
                'arch': net.__class__.__name__,
                'state_dict': net.state_dict(),
                'overall_test_rmse': overall_test_rmse_model,
                'lead_test_rmse': overall_test_rmse_model,
                'best_epoch_test_rmse': best_epoch_test_rmse,
                'optimizer': optimizer.state_dict(),
            }, is_best, output_path + '/stored_models/run_%s' % run)

            # flush output to see progress
            sys.stdout.flush()

    # update statistics dict
    ModelUtils.get_model_details(experiment_info, net, optimizer, criterion)

    # complete program runtime
    experiment_info['program_runtime'] = time() - program_start_time

    # generate data set of all experiment statistics and additional information
    experiment_statistic = xr.Dataset({
        'error_statistic': error_statistics,
        'run_time_statistic': run_times,
        'samples_statistic': skip_statistics}).assign_attrs(experiment_info)

    # dump experiment statistic
    with open(output_path + '/experiment_statistic.pkl', 'wb') as handle:
        pkl.dump(experiment_statistic, handle, protocol=pkl.HIGHEST_PROTOCOL)

    if 'per_station_rmse' in config:
        # dump per-station rmse statistic
        with open(output_path + '/rmse_per_station.pkl', 'wb') as handle:
            pkl.dump(error_per_station_statistics, handle, protocol=pkl.HIGHEST_PROTOCOL)

    # print program execution time
    m, s = divmod(experiment_info['program_runtime'], 60)
    h, m = divmod(m, 60)
    print('Experiment has successfully finished in %dh %02dmin %02ds' % (h, m, s))
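# Standalone sketch (not part of the original pipeline) of the per-station RMSE trick
# used above: np.bincount with `weights=` accumulates squared errors per station id,
# and a plain bincount counts test samples per station. All values below are made-up
# toy data, only meant to show the mechanics.
def _per_station_rmse_demo():
    import numpy as np
    station_ids = np.array([0, 2, 2, 5, 5, 5])               # station id of each test sample
    squared_err = np.array([1.0, 4.0, 0.0, 1.0, 9.0, 4.0])   # squared model error per sample
    max_station_id = 5
    counts = np.bincount(station_ids, minlength=max_station_id + 1)
    sq_err_sums = np.bincount(station_ids, weights=squared_err, minlength=max_station_id + 1)
    with np.errstate(divide='ignore', invalid='ignore'):
        rmse_per_station = np.sqrt(sq_err_sums / counts)     # NaN for stations without test samples
    return rmse_per_station                                  # [1., nan, 1.414..., nan, nan, 2.160...]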
def train(args):
    start_t = time.time()
    params = get_train_options()
    params["exp_name"] = args.exp_name
    params["patch_num_point"] = 1024
    params["batch_size"] = args.batch_size
    params['use_gan'] = args.use_gan

    if args.debug:
        params["nepoch"] = 2
        params["model_save_interval"] = 3
        params['model_vis_interval'] = 3

    log_dir = os.path.join(params["model_save_dir"], args.exp_name)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    tb_logger = Logger(log_dir)

    trainloader = PUNET_Dataset(h5_file_path=params["dataset_dir"], split_dir=params['train_split'])
    # print(params["dataset_dir"])
    num_workers = 4
    train_data_loader = data.DataLoader(dataset=trainloader,
                                        batch_size=params["batch_size"],
                                        shuffle=True,
                                        num_workers=num_workers,
                                        pin_memory=True,
                                        drop_last=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    G_model = Generator(params)
    G_model.apply(xavier_init)
    G_model = torch.nn.DataParallel(G_model).to(device)

    D_model = Discriminator(params, in_channels=3)
    D_model.apply(xavier_init)
    D_model = torch.nn.DataParallel(D_model).to(device)

    G_model.train()
    D_model.train()

    optimizer_D = Adam(D_model.parameters(), lr=params["lr_D"], betas=(0.9, 0.999))
    optimizer_G = Adam(G_model.parameters(), lr=params["lr_G"], betas=(0.9, 0.999))

    D_scheduler = MultiStepLR(optimizer_D, [50, 80], gamma=0.2)
    G_scheduler = MultiStepLR(optimizer_G, [50, 80], gamma=0.2)

    Loss_fn = Loss()

    print("preparation time is %fs" % (time.time() - start_t))
    iter = 0
    for e in range(params["nepoch"]):
        D_scheduler.step()
        G_scheduler.step()
        for batch_id, (input_data, gt_data, radius_data) in enumerate(train_data_loader):
            optimizer_G.zero_grad()
            optimizer_D.zero_grad()

            input_data = input_data[:, :, 0:3].permute(0, 2, 1).float().cuda()
            gt_data = gt_data[:, :, 0:3].permute(0, 2, 1).float().cuda()

            start_t_batch = time.time()
            output_point_cloud = G_model(input_data)

            repulsion_loss = Loss_fn.get_repulsion_loss(output_point_cloud.permute(0, 2, 1))
            uniform_loss = Loss_fn.get_uniform_loss(output_point_cloud.permute(0, 2, 1))
            # print(output_point_cloud.shape, gt_data.shape)
            emd_loss = Loss_fn.get_emd_loss(output_point_cloud.permute(0, 2, 1), gt_data.permute(0, 2, 1))

            if params['use_gan']:
                # discriminator update on fake (detached) and real point clouds
                fake_pred = D_model(output_point_cloud.detach())
                d_loss_fake = Loss_fn.get_discriminator_loss_single(fake_pred, label=False)
                d_loss_fake.backward()
                optimizer_D.step()

                real_pred = D_model(gt_data.detach())
                d_loss_real = Loss_fn.get_discriminator_loss_single(real_pred, label=True)
                d_loss_real.backward()
                optimizer_D.step()

                d_loss = d_loss_real + d_loss_fake

                # generator adversarial loss
                fake_pred = D_model(output_point_cloud)
                g_loss = Loss_fn.get_generator_loss(fake_pred)

                # print(repulsion_loss, uniform_loss, emd_loss)
                total_G_loss = params['uniform_w'] * uniform_loss + params['emd_w'] * emd_loss + \
                    repulsion_loss * params['repulsion_w'] + g_loss * params['gan_w']
            else:
                # total_G_loss = params['uniform_w'] * uniform_loss + params['emd_w'] * emd_loss + \
                #     repulsion_loss * params['repulsion_w']
                total_G_loss = params['emd_w'] * emd_loss + \
                    repulsion_loss * params['repulsion_w']
                # total_G_loss = emd_loss

            total_G_loss.backward()
            optimizer_G.step()

            current_lr_D = optimizer_D.state_dict()['param_groups'][0]['lr']
            current_lr_G = optimizer_G.state_dict()['param_groups'][0]['lr']

            tb_logger.scalar_summary('repulsion_loss', repulsion_loss.item(), iter)
            tb_logger.scalar_summary('uniform_loss', uniform_loss.item(), iter)
            tb_logger.scalar_summary('emd_loss', emd_loss.item(), iter)
            if params['use_gan']:
                tb_logger.scalar_summary('d_loss', d_loss.item(), iter)
                tb_logger.scalar_summary('g_loss', g_loss.item(), iter)
            tb_logger.scalar_summary('lr_D', current_lr_D, iter)
            tb_logger.scalar_summary('lr_G', current_lr_G, iter)

            msg = "{:0>8},{}:{}, [{}/{}], {}: {},{}:{}".format(
                str(datetime.timedelta(seconds=round(time.time() - start_t))),
                "epoch", e,
                batch_id + 1, len(train_data_loader),
                "total_G_loss", total_G_loss.item(),
                "iter time", (time.time() - start_t_batch))
            print(msg)
            iter += 1

        if (e + 1) % params['model_save_interval'] == 0 and e > 0:
            model_save_dir = os.path.join(params['model_save_dir'], params['exp_name'])
            if not os.path.exists(model_save_dir):
                os.makedirs(model_save_dir)
            D_ckpt_model_filename = "D_iter_%d.pth" % (e)
            G_ckpt_model_filename = "G_iter_%d.pth" % (e)
            D_model_save_path = os.path.join(model_save_dir, D_ckpt_model_filename)
            G_model_save_path = os.path.join(model_save_dir, G_ckpt_model_filename)
            torch.save(D_model.module.state_dict(), D_model_save_path)
            torch.save(G_model.module.state_dict(), G_model_save_path)
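# Minimal sketch of a command-line entry point for train(args), assuming the argparse
# flags implied by the attribute accesses above (exp_name, batch_size, use_gan, debug).
# Defaults and help texts are illustrative, not taken from the repository.
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="Train the point-cloud upsampling generator")
    parser.add_argument("--exp_name", type=str, required=True, help="experiment / checkpoint folder name")
    parser.add_argument("--batch_size", type=int, default=16, help="training batch size")
    parser.add_argument("--use_gan", action="store_true", help="enable the adversarial loss branch")
    parser.add_argument("--debug", action="store_true", help="short run with frequent saving/visualization")
    args = parser.parse_args()
    train(args)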
def train(args):
    start_t = time.time()
    params = get_train_options()
    params["exp_name"] = args.exp_name
    params["patch_num_point"] = 1024
    params["batch_size"] = args.batch_size
    params['use_gan'] = args.use_gan

    if args.debug:
        params["nepoch"] = 2
        params["model_save_interval"] = 3
        params['model_vis_interval'] = 3

    log_dir = os.path.join(params["model_save_dir"], args.exp_name)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    tb_logger = Logger(log_dir)

    trainloader = PUNET_Dataset(h5_file_path=params["dataset_dir"])
    # print(params["dataset_dir"])
    num_workers = 4
    train_data_loader = data.DataLoader(dataset=trainloader,
                                        batch_size=params["batch_size"],
                                        shuffle=True,
                                        num_workers=num_workers,
                                        pin_memory=True,
                                        drop_last=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    G_model = Generator_recon(params)
    G_model.apply(xavier_init)
    G_model = torch.nn.DataParallel(G_model).to(device)
    D_model = torch.nn.DataParallel(Discriminator(params, in_channels=3)).to(device)

    G_model.train()
    D_model.train()

    optimizer_D = Adam(D_model.parameters(), lr=params["lr_D"], betas=(0.9, 0.999))
    optimizer_G = Adam(G_model.parameters(), lr=params["lr_G"], betas=(0.9, 0.999))

    D_scheduler = MultiStepLR(optimizer_D, [50, 80], gamma=0.2)
    G_scheduler = MultiStepLR(optimizer_G, [50, 80], gamma=0.2)

    Loss_fn = Loss()

    print("preparation time is %fs" % (time.time() - start_t))
    iter = 0
    for e in range(params["nepoch"]):
        D_scheduler.step()
        G_scheduler.step()
        for batch_id, (input_data, gt_data, radius_data) in enumerate(train_data_loader):
            optimizer_G.zero_grad()
            optimizer_D.zero_grad()

            input_data = input_data[:, :, 0:3].permute(0, 2, 1).float().cuda()
            gt_data = gt_data[:, :, 0:3].permute(0, 2, 1).float().cuda()

            start_t_batch = time.time()
            output_point_cloud = G_model(input_data)

            emd_loss = Loss_fn.get_emd_loss(output_point_cloud.permute(0, 2, 1), input_data.permute(0, 2, 1))
            total_G_loss = emd_loss
            total_G_loss.backward()
            optimizer_G.step()

            current_lr_D = optimizer_D.state_dict()['param_groups'][0]['lr']
            current_lr_G = optimizer_G.state_dict()['param_groups'][0]['lr']

            tb_logger.scalar_summary('emd_loss', emd_loss.item(), iter)
            tb_logger.scalar_summary('lr_D', current_lr_D, iter)
            tb_logger.scalar_summary('lr_G', current_lr_G, iter)

            msg = "{:0>8},{}:{}, [{}/{}], {}: {},{}:{}".format(
                str(datetime.timedelta(seconds=round(time.time() - start_t))),
                "epoch", e,
                batch_id + 1, len(train_data_loader),
                "total_G_loss", total_G_loss.item(),
                "iter time", (time.time() - start_t_batch))
            print(msg)

            if iter % params['model_save_interval'] == 0 and iter > 0:
                model_save_dir = os.path.join(params['model_save_dir'], params['exp_name'])
                if not os.path.exists(model_save_dir):
                    os.makedirs(model_save_dir)
                D_ckpt_model_filename = "D_iter_%d.pth" % (iter)
                G_ckpt_model_filename = "G_iter_%d.pth" % (iter)
                D_model_save_path = os.path.join(model_save_dir, D_ckpt_model_filename)
                G_model_save_path = os.path.join(model_save_dir, G_ckpt_model_filename)
                torch.save(D_model.module.state_dict(), D_model_save_path)
                torch.save(G_model.module.state_dict(), G_model_save_path)

            if iter % params['model_vis_interval'] == 0 and iter > 0:
                np_pcd = output_point_cloud.permute(0, 2, 1)[0].detach().cpu().numpy()
                # print(np_pcd.shape)
                img = (np.array(visualize_point_cloud(np_pcd)) * 255).astype(np.uint8)
                tb_logger.image_summary("images", img[np.newaxis, :], iter)

                gt_pcd = gt_data.permute(0, 2, 1)[0].detach().cpu().numpy()
                # print(gt_pcd.shape)
                gt_img = (np.array(visualize_point_cloud(gt_pcd)) * 255).astype(np.uint8)
                tb_logger.image_summary("gt", gt_img[np.newaxis, :], iter)

                input_pcd = input_data.permute(0, 2, 1)[0].detach().cpu().numpy()
                input_img = (np.array(visualize_point_cloud(input_pcd)) * 255).astype(np.uint8)
                tb_logger.image_summary("input", input_img[np.newaxis, :], iter)

            iter += 1
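# visualize_point_cloud() is imported from elsewhere in the project and is not shown in
# this section. Below is a minimal matplotlib-based stand-in, written under the assumption
# that the real function renders an (N, 3) array to an RGB float image in [0, 1], which is
# what the (img * 255).astype(np.uint8) conversion above expects. Hypothetical helper name.
def visualize_point_cloud_sketch(points, elev=30, azim=45):
    import numpy as np
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers the 3d projection on older matplotlib)

    fig = plt.figure(figsize=(4, 4), dpi=100)
    ax = fig.add_subplot(111, projection="3d")
    ax.scatter(points[:, 0], points[:, 1], points[:, 2], s=2)
    ax.view_init(elev=elev, azim=azim)
    ax.set_axis_off()
    fig.canvas.draw()
    img = np.asarray(fig.canvas.buffer_rgba())[:, :, :3] / 255.0  # RGB float image in [0, 1]
    plt.close(fig)
    return img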
def train(args):
    start_t = time.time()
    params = get_train_options()
    params["exp_name"] = args.exp_name
    params["patch_num_point"] = 256
    params["batch_size"] = args.batch_size

    if args.debug:
        params["nepoch"] = 2
        params["model_save_interval"] = 3
        params['model_vis_interval'] = 3

    log_dir = os.path.join(params["model_save_dir"], args.exp_name)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    tb_logger = Logger(log_dir)

    # trainloader = PUNET_Dataset(h5_file_path=params["dataset_dir"], split_dir=params['train_split'])
    trainloader = PUGAN_Dataset(h5_file_path=params["dataset_dir"], npoint=256)
    num_workers = 4
    train_data_loader = data.DataLoader(dataset=trainloader,
                                        batch_size=params["batch_size"],
                                        shuffle=True,
                                        num_workers=num_workers,
                                        pin_memory=True,
                                        drop_last=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    ##########################################
    # Initialize generator and discriminator #
    ##########################################
    G_AB = Generator(params)
    G_AB.apply(xavier_init)
    G_AB = torch.nn.DataParallel(G_AB).to(device)

    G_BA = Downsampler(params)
    G_BA.apply(xavier_init)
    G_BA = torch.nn.DataParallel(G_BA).to(device)

    D_A = Discriminator(params, in_channels=3)
    D_A.apply(xavier_init)
    D_A = torch.nn.DataParallel(D_A).to(device)

    D_B = Discriminator(params, in_channels=3)
    D_B.apply(xavier_init)
    D_B = torch.nn.DataParallel(D_B).to(device)

    ##########################################
    # Optimizers and learning rate schedulers #
    ##########################################
    optimizer_D_A = Adam(D_A.parameters(), lr=params["lr_D_A"], betas=(0.9, 0.999))
    optimizer_D_B = Adam(D_B.parameters(), lr=params["lr_D_B"], betas=(0.9, 0.999))
    optimizer_G_AB = Adam(G_AB.parameters(), lr=params["lr_G_AB"], betas=(0.9, 0.999))
    optimizer_G_BA = Adam(G_BA.parameters(), lr=params["lr_G_BA"], betas=(0.9, 0.999))

    D_A_scheduler = MultiStepLR(optimizer_D_A, [50, 80], gamma=0.2)
    G_AB_scheduler = MultiStepLR(optimizer_G_AB, [50, 80], gamma=0.2)
    D_B_scheduler = MultiStepLR(optimizer_D_B, [50, 80], gamma=0.2)
    G_BA_scheduler = MultiStepLR(optimizer_G_BA, [50, 80], gamma=0.2)

    Loss_fn = Loss()

    print("preparation time is %fs" % (time.time() - start_t))
    iter = 0
    for e in range(params["nepoch"]):
        for batch_id, (input_data, gt_data, radius_data) in enumerate(train_data_loader):
            G_AB.train()
            G_BA.train()
            D_A.train()
            D_B.train()

            optimizer_G_AB.zero_grad()
            optimizer_D_A.zero_grad()
            optimizer_G_BA.zero_grad()
            optimizer_D_B.zero_grad()

            input_data = input_data[:, :, 0:3].permute(0, 2, 1).float().cuda()
            gt_data = gt_data[:, :, 0:3].permute(0, 2, 1).float().cuda()

            start_t_batch = time.time()
            output_point_cloud_high = G_AB(input_data)
            output_point_cloud_low = G_BA(gt_data)

            #####################################
            #               Loss                #
            #####################################
            repulsion_loss_AB = Loss_fn.get_repulsion_loss(output_point_cloud_high.permute(0, 2, 1))
            uniform_loss_AB = Loss_fn.get_uniform_loss(output_point_cloud_high.permute(0, 2, 1))
            repulsion_loss_BA = Loss_fn.get_repulsion_loss(output_point_cloud_low.permute(0, 2, 1))
            uniform_loss_BA = Loss_fn.get_uniform_loss(output_point_cloud_low.permute(0, 2, 1))
            emd_loss_AB = Loss_fn.get_emd_loss(output_point_cloud_high.permute(0, 2, 1), gt_data.permute(0, 2, 1))
            # emd_loss_BA = Loss_fn.get_emd_loss(output_point_cloud_low.permute(0, 2, 1), input_data.permute(0, 2, 1))

            # Cycle loss
            recov_A = G_BA(output_point_cloud_high)
            ABA_repul_loss = Loss_fn.get_repulsion_loss(recov_A.permute(0, 2, 1))
            ABA_uniform_loss = Loss_fn.get_uniform_loss(recov_A.permute(0, 2, 1))

            recov_B = G_AB(output_point_cloud_low)
            BAB_repul_loss = Loss_fn.get_repulsion_loss(recov_B.permute(0, 2, 1))
            BAB_uniform_loss = Loss_fn.get_uniform_loss(recov_B.permute(0, 2, 1))
            BAB_emd_loss = Loss_fn.get_emd_loss(recov_B.permute(0, 2, 1), gt_data.permute(0, 2, 1))

            # G_AB loss
            fake_pred_B = D_A(output_point_cloud_high.detach())
            g_AB_loss = Loss_fn.get_generator_loss(fake_pred_B)
            total_G_AB_loss = g_AB_loss * params['gan_w_AB'] + BAB_repul_loss * params['repulsion_w_AB'] + \
                BAB_uniform_loss * params['uniform_w_AB'] + BAB_emd_loss * params['emd_w_AB'] + \
                params['uniform_w_AB'] * uniform_loss_AB + params['emd_w_AB'] * emd_loss_AB + \
                repulsion_loss_AB * params['repulsion_w_AB']
            total_G_AB_loss.backward()
            optimizer_G_AB.step()

            # G_BA loss
            fake_pred_A = D_B(output_point_cloud_low.detach())
            g_BA_loss = Loss_fn.get_generator_loss(fake_pred_A)
            total_G_BA_loss = g_BA_loss * params['gan_w_BA'] + ABA_repul_loss * params['repulsion_w_BA'] + \
                repulsion_loss_BA * params['repulsion_w_BA']
            # ABA_uniform_loss * params['uniform_w_BA'] + \
            # params['uniform_w_BA'] * uniform_loss_BA + \
            total_G_BA_loss.backward()
            optimizer_G_BA.step()

            # Discriminator A loss
            # fake_A_buffer / fake_B_buffer are assumed to be module-level replay buffers
            # (CycleGAN-style) defined outside this function; they are not shown in this section.
            fake_B_ = fake_A_buffer.push_and_pop(output_point_cloud_high)
            fake_pred_B = D_A(fake_B_.detach())
            d_A_loss_fake = Loss_fn.get_discriminator_loss_single(fake_pred_B, label=False)
            real_pred_B = D_A(gt_data.detach())
            d_A_loss_real = Loss_fn.get_discriminator_loss_single(real_pred_B, label=True)
            d_A_loss = d_A_loss_real + d_A_loss_fake
            d_A_loss.backward()
            optimizer_D_A.step()

            # Discriminator B loss
            fake_A_ = fake_B_buffer.push_and_pop(output_point_cloud_low)
            fake_pred_A = D_B(fake_A_.detach())
            d_B_loss_fake = Loss_fn.get_discriminator_loss_single(fake_pred_A, label=False)
            real_pred_A = D_B(input_data.detach())
            d_B_loss_real = Loss_fn.get_discriminator_loss_single(real_pred_A, label=True)
            d_B_loss = d_B_loss_real + d_B_loss_fake
            d_B_loss.backward()
            optimizer_D_B.step()

            # Current learning rates
            current_lr_D_A = optimizer_D_A.state_dict()['param_groups'][0]['lr']
            current_lr_G_AB = optimizer_G_AB.state_dict()['param_groups'][0]['lr']
            current_lr_D_B = optimizer_D_B.state_dict()['param_groups'][0]['lr']
            current_lr_G_BA = optimizer_G_BA.state_dict()['param_groups'][0]['lr']

            # tb_logger.scalar_summary('repulsion_loss_AB', repulsion_loss_AB.item(), iter)
            # tb_logger.scalar_summary('uniform_loss_AB', uniform_loss_AB.item(), iter)
            # tb_logger.scalar_summary('repulsion_loss_BA', repulsion_loss_BA.item(), iter)
            # tb_logger.scalar_summary('uniform_loss_BA', uniform_loss_BA.item(), iter)
            # tb_logger.scalar_summary('emd_loss_AB', emd_loss_AB.item(), iter)
            tb_logger.scalar_summary('d_A_loss', d_A_loss.item(), iter)
            tb_logger.scalar_summary('g_AB_loss', g_AB_loss.item(), iter)
            tb_logger.scalar_summary('Total_G_AB_loss', total_G_AB_loss.item(), iter)
            tb_logger.scalar_summary('lr_D_A', current_lr_D_A, iter)
            tb_logger.scalar_summary('lr_G_AB', current_lr_G_AB, iter)
            tb_logger.scalar_summary('d_B_loss', d_B_loss.item(), iter)
            tb_logger.scalar_summary('g_BA_loss', g_BA_loss.item(), iter)
            tb_logger.scalar_summary('Total_G_BA_loss', total_G_BA_loss.item(), iter)
            tb_logger.scalar_summary('lr_D_B', current_lr_D_B, iter)
            tb_logger.scalar_summary('lr_G_BA', current_lr_G_BA, iter)

            msg = "{:0>8},{}:{}, [{}/{}], {}: {}, {}: {}, {}:{}, {}: {},{}: {}".format(
                str(datetime.timedelta(seconds=round(time.time() - start_t))),
                "epoch", e + 1,
                batch_id + 1, len(train_data_loader),
                "total_G_AB_loss", total_G_AB_loss.item(),
                "total_G_BA_loss", total_G_BA_loss.item(),
                "iter time", (time.time() - start_t_batch),
                "d_A_loss", d_A_loss.item(),
                "d_B_loss", d_B_loss.item())
            print(msg)
            iter += 1

        D_A_scheduler.step()
        G_AB_scheduler.step()
        D_B_scheduler.step()
        G_BA_scheduler.step()

        if (e + 1) % params['model_save_interval'] == 0 and e > 0:
            model_save_dir = os.path.join(params['model_save_dir'], params['exp_name'])
            if not os.path.exists(model_save_dir):
                os.makedirs(model_save_dir)
            D_A_ckpt_model_filename = "D_A_iter_%d.pth" % (e + 1)
            G_AB_ckpt_model_filename = "G_AB_iter_%d.pth" % (e + 1)
            D_A_model_save_path = os.path.join(model_save_dir, D_A_ckpt_model_filename)
            G_AB_model_save_path = os.path.join(model_save_dir, G_AB_ckpt_model_filename)
            D_B_ckpt_model_filename = "D_B_iter_%d.pth" % (e + 1)
            G_BA_ckpt_model_filename = "G_BA_iter_%d.pth" % (e + 1)
            model_ckpt_model_filename = "Cyclegan_iter_%d.pth" % (e + 1)
            D_B_model_save_path = os.path.join(model_save_dir, D_B_ckpt_model_filename)
            G_BA_model_save_path = os.path.join(model_save_dir, G_BA_ckpt_model_filename)
            model_all_path = os.path.join(model_save_dir, model_ckpt_model_filename)
            torch.save(
                {
                    'G_AB_state_dict': G_AB.module.state_dict(),
                    'G_BA_state_dict': G_BA.module.state_dict(),
                    'D_A_state_dict': D_A.module.state_dict(),
                    'D_B_state_dict': D_B.module.state_dict(),
                    'optimizer_G_AB_state_dict': optimizer_G_AB.state_dict(),
                    'optimizer_G_BA_state_dict': optimizer_G_BA.state_dict(),
                    'optimizer_D_A_state_dict': optimizer_D_A.state_dict(),
                    'optimizer_D_B_state_dict': optimizer_D_B.state_dict()
                }, model_all_path)
            torch.save(D_A.module.state_dict(), D_A_model_save_path)
            torch.save(G_AB.module.state_dict(), G_AB_model_save_path)
            torch.save(D_B.module.state_dict(), D_B_model_save_path)
            torch.save(G_BA.module.state_dict(), G_BA_model_save_path)