Example #1
def runModel(config, data_dictionary, data_statistics, train_test_folds):
    program_start_time = time()

    # load the model architecture definition from the JSON file given in the config
    with open(config['model']['path']) as handle:
        ModelDict = json.loads(handle.read())

    # check if station and grid time invariant features should be used and set the list of desired parameters
    if not ModelDict.get('grid_time_invariant'):
        config['grid_time_invariant_parameters'] = []
    if not ModelDict.get('station_time_invariant'):
        config['station_parameters'] = []

    # update general static model information
    experiment_info = config
    experiment_info['model'] = ModelDict
    experiment_info['code_commit'] = ModelUtils.get_git_revision_short_hash()


    # load the time-invariant features per station
    with open("%s/%s/grid_size_%s/time_invariant_data_per_station.pkl" % (config['input_source'], config['preprocessing'], config['original_grid_size']), "rb") as input_file:
        time_invariant_data = pkl.load(input_file)


    # initialize feature scaling function for each feature
    featureScaleFunctions = DataUtils.getFeatureScaleFunctions(ModelUtils.ParamNormalizationDict, data_statistics)

    # get optimizer config
    optimizer_config = config['optimizer']

    # generate output path for experiment information
    setting_string = '%s_grid_%s_bs_%s_tf_%s_optim_%s_lr_%s_sl_%s' % (
        config['model']['name'], config['grid_size'], config['batch_size'], config['test_fraction'], optimizer_config['algorithm'], optimizer_config['learning_rate'], config['slice_size'])
    output_path = '%s/%s' % (config['experiment_path'], setting_string)
    if not os.path.exists(output_path):
        os.makedirs(output_path)

    # time for the set up until first run
    experiment_info['set_up_time'] = time() - program_start_time
    print('[Time]: Set-up %s' % strftime("%H:%M:%S", gmtime(experiment_info['set_up_time'])))
    sys.stdout.flush()

    # initialize statistics
    error_statistics = None
    run_times = None
    skip_statistics = None
    if 'per_station_rmse' in config:
        error_per_station_statistics = None

    # keep used learning rates
    experiment_info['scheduled_learning_rates'] = []

    # cross validation
    for run in range(config['runs']):
        # loggers for tensorboardX
        train_logger = Logger(output_path + '/logs/run_%s/train' % run)
        test_logger = Logger(output_path + '/logs/run_%s/test' % run)

        print('[Run %s] Cross-validation test fold %s' % (str(run + 1), str(run + 1)))

        # take the right preprocessed train/test data set for the current run
        train_fold, test_fold = train_test_folds[run]

        # initialize best epoch test error
        best_epoch_test_rmse = float("inf")

        # use a different data loader if we want to train a k-NN (3NN) model
        if "knn" in ModelDict:
            # initialize train and test dataloaders
            trainset = DataLoaders.CosmoData3NNData(
                config=config,
                station_data_dict=data_dictionary,
                files=train_fold,
                featureScaling=featureScaleFunctions,
                time_invariant_data=time_invariant_data)
            trainloader = DataLoader(trainset, batch_size=config['batch_size'], shuffle=True,
                                     num_workers=config['n_loaders'], collate_fn=DataLoaders.collate_fn)

            testset = DataLoaders.CosmoData3NNData(
                config=config,
                station_data_dict=data_dictionary,
                files=test_fold,
                featureScaling=featureScaleFunctions,
                time_invariant_data=time_invariant_data)
            testloader = DataLoader(testset, batch_size=config['batch_size'], shuffle=True,
                                    num_workers=config['n_loaders'], collate_fn=DataLoaders.collate_fn)
        else:
            # initialize train and test dataloaders
            trainset = DataLoaders.CosmoDataGridData(
                config=config,
                station_data_dict=data_dictionary,
                files=train_fold,
                featureScaling=featureScaleFunctions,
                time_invariant_data=time_invariant_data)
            trainloader = DataLoader(trainset, batch_size=config['batch_size'], shuffle=True,
                                     num_workers=config['n_loaders'], collate_fn=DataLoaders.collate_fn)

            testset = DataLoaders.CosmoDataGridData(
                config=config,
                station_data_dict=data_dictionary,
                files=test_fold,
                featureScaling=featureScaleFunctions,
                time_invariant_data=time_invariant_data)
            testloader = DataLoader(testset, batch_size=config['batch_size'], shuffle=True,
                                    num_workers=config['n_loaders'], collate_fn=DataLoaders.collate_fn)

        # initialize network, optimizer and loss function
        net = Baseline.model_factory(ModelDict, trainset.n_parameters, trainset.n_grid_time_invariant_parameters,
                                     config['grid_size'], config['prediction_times'])
        # store class name
        experiment_info['model_class'] = net.__class__.__name__

        if torch.cuda.device_count() > 1:
            net = nn.DataParallel(net)

        if torch.cuda.is_available():
            net.cuda()

        # load number of train and test samples
        n_train_samples, n_test_samples = len(train_fold), len(test_fold)

        optimizer, scheduler = ModelUtils.initializeOptimizer(optimizer_config, net)
        criterion = nn.MSELoss()

        # keep number of processed samples over all epochs for tensorboard
        processed_train_samples_global = 0
        processed_test_samples_global = 0

        # start learning
        for epoch in range(config['epochs']):
            epoch_train_time = np.zeros((5,))
            epoch_start_time = time()
            print('Epoch: ' + str(epoch + 1) + '\n------------------------------------------------------------')

            # adapt learning rate and store information in experiment attributes
            if scheduler is not None:
                scheduler.step()
                if run == 0: experiment_info['scheduled_learning_rates'] += scheduler.get_lr()
                print('Using learning rate %s' % str(scheduler.get_lr()))

            # TRAINING
            # initialize variables for epoch statistics
            LABELS, MODELoutputs, COSMOoutputs = None, None, None
            processed_train_samples = 0
            net.train(True)

            train_start_time = time()
            # loop over complete train set
            for i, data in enumerate(trainloader, 0):
                time_start = time()
                try:
                    # get the training batch, i.e. label, COSMO-1 output and time-invariant features for the station
                    DATA = data
                    # DATA has only length 4 if we do not use the station time invariant features
                    if len(DATA) == 4:
                        Blabel, Bip2d, BTimeData, init_station_temp = DATA
                        station_time_inv_input = None
                    elif len(DATA) == 5:
                        Blabel, Bip2d, BTimeData, StationTimeInv, init_station_temp = DATA
                        station_time_inv_input = ModelUtils.getVariable(StationTimeInv).float()
                    else:
                        raise Exception('Unknown data format for training...')
                    input = ModelUtils.getVariable(Bip2d).float()
                    time_data = ModelUtils.getVariable(BTimeData).float()
                    target = ModelUtils.getVariable(Blabel).float()

                except TypeError:
                    # with a small batch size it can happen that all labels in a batch are corrupted,
                    # in which case collate_fn returns an empty list
                    print('Skipping corrupted batch...')
                    continue
                time_after_data_preparation = time()

                processed_train_samples += len(Blabel)

                optimizer.zero_grad()
                out = net(input, time_data, station_time_inv_input)
                time_after_forward_pass = time()
                loss = criterion(out, target)
                loss.backward()
                optimizer.step()
                time_after_backward_pass = time()

                if LABELS is None:
                    LABELS = Blabel.data
                    MODELoutputs = out.data
                    COSMOoutputs = init_station_temp[2].data
                else:
                    LABELS = np.vstack((LABELS, Blabel.data))
                    MODELoutputs = np.vstack((MODELoutputs, out.data))
                    COSMOoutputs = np.vstack((COSMOoutputs, init_station_temp[2].data))

                time_after_label_stack = time()

                if (i + 1) % 64 == 0:

                    print('Sample: %s \t Loss: %s' % (processed_train_samples, float(np.sqrt(loss.data))))

                    # ============ TensorBoard logging ============#
                    # (1) Log the scalar values
                    info = {
                        setting_string: np.sqrt(loss.item()),
                    }

                    for tag, value in info.items():
                        train_logger.scalar_summary(tag, value, processed_train_samples_global + processed_train_samples)

                    # (2) Log values and gradients of the parameters (histogram)
                    for tag, value in net.named_parameters():
                        tag = tag.replace('.', '/')
                        train_logger.histo_summary(tag, ModelUtils.to_np(value), i + 1)
                        train_logger.histo_summary(tag + '/grad', ModelUtils.to_np(value.grad), i + 1)

                    epoch_train_time += np.array((time_start - time_end,
                                                  time_after_data_preparation - time_start,
                                                  time_after_forward_pass - time_after_data_preparation,
                                                  time_after_backward_pass - time_after_forward_pass,
                                                  time_after_label_stack - time_after_backward_pass))

                time_end = time()

            # calculate error statistic of current epoch
            diff_model = MODELoutputs - LABELS
            diff_cosmo = COSMOoutputs - LABELS
            epoch_train_rmse_model = np.apply_along_axis(func1d=ModelUtils.rmse, arr=diff_model, axis=0)
            epoch_train_rmse_cosmo = np.apply_along_axis(func1d=ModelUtils.rmse, arr=diff_cosmo, axis=0)


            # update global processed samples
            processed_train_samples_global += processed_train_samples

            if np.isnan(epoch_train_rmse_model).any():
                print("Learning rate too large: NaN encountered during training. Stopping training...")
                return
            # print epoch training times
            print('Timing: Waiting on data=%s, Data Preparation=%s, '
                  'Forward Pass=%s, Backward Pass=%s, Data Stacking=%s' % tuple(list(epoch_train_time / len(epoch_train_time))))

            # RMSE of epoch
            print('Train/test statistic for epoch: %s' % str(epoch + 1))
            print('Train RMSE COSMO: ' , ", ".join(["T=%s: %s" % (idx, epoch_train_rmse_cosmo[idx]) for idx in range(len(epoch_train_rmse_cosmo))]))
            print('Train RMSE Model: ' , ", ".join(["T=%s: %s" % (idx, epoch_train_rmse_model[idx]) for idx in range(len(epoch_train_rmse_model))]))
            sys.stdout.flush()

            train_time = time() - train_start_time

            # TESTING
            test_start_time = time()

            LABELS, MODELoutputs, COSMOoutputs, STATION = None, None, None, None
            processed_test_samples = 0
            net.eval()
            for i, data in enumerate(testloader, 0):
                try:
                    # get the test batch, i.e. label, COSMO-1 output and time-invariant features for the station
                    DATA = data
                    # DATA has only length 4 if we do not use the station time invariant features
                    if len(DATA) == 4:
                        Blabel, Bip2d, BTimeData, init_station_temp = DATA
                        station_time_inv_input = None
                    elif len(DATA) == 5:
                        Blabel, Bip2d, BTimeData, StationTimeInv, init_station_temp = DATA
                        station_time_inv_input = ModelUtils.getVariable(StationTimeInv).float()
                    else:
                        raise Exception('Unknown data format for testing...')
                    input = ModelUtils.getVariable(Bip2d).float()
                    time_data = ModelUtils.getVariable(BTimeData).float()
                    target = ModelUtils.getVariable(Blabel).float()

                except TypeError:
                    # with a small batch size it can happen that all labels in a batch are corrupted,
                    # in which case collate_fn returns an empty list
                    print('Skipping corrupted batch...')
                    continue

                processed_test_samples += len(Blabel)

                out = net(input, time_data, station_time_inv_input)
                loss = criterion(out, target)

                if LABELS is None:
                    LABELS = Blabel.data
                    MODELoutputs = out.data
                    COSMOoutputs = init_station_temp[2].data
                    STATION = init_station_temp[1].data
                else:
                    LABELS = np.vstack((LABELS, Blabel.data))
                    MODELoutputs = np.vstack((MODELoutputs, out.data))
                    COSMOoutputs = np.vstack((COSMOoutputs, init_station_temp[2].data))
                    STATION = np.hstack((STATION, init_station_temp[1].data))

                if (i + 1) % 16 == 0:
                    # ============ TensorBoard logging ============#
                    # (1) Log the scalar values
                    info = {
                        setting_string: np.sqrt(loss.item()),
                    }

                    for tag, value in info.items():
                        test_logger.scalar_summary(tag, value, processed_test_samples_global + processed_test_samples)

            # calculate error statistic of current epoch
            diff_model = MODELoutputs - LABELS
            diff_cosmo = COSMOoutputs - LABELS

            # rmse
            epoch_test_rmse_model = np.apply_along_axis(func1d=ModelUtils.rmse, arr=diff_model, axis=0)
            epoch_test_rmse_cosmo = np.apply_along_axis(func1d=ModelUtils.rmse, arr=diff_cosmo, axis=0)
            overall_test_rmse_model = ModelUtils.rmse(diff_model)
            overall_test_rmse_cosmo = ModelUtils.rmse(diff_cosmo)

            # mae
            epoch_test_mae_model = np.apply_along_axis(func1d=ModelUtils.mae, arr=diff_model, axis=0)
            epoch_test_mae_cosmo = np.apply_along_axis(func1d=ModelUtils.mae, arr=diff_cosmo, axis=0)
            overall_test_mae_model = ModelUtils.mae(diff_model)
            overall_test_mae_cosmo = ModelUtils.mae(diff_cosmo)

            # calculate per-station RMSE if desired (especially for the K-fold station generalization experiment)
            if "per_station_rmse" in config:
                max_station_id = 1435

                squared_errors_per_epoch = np.array((np.square(diff_model), np.square(diff_cosmo))).squeeze()

                # the highest station index is 1435, so we need max_station_id + 1 entries, which we can
                # index by station id
                test_samples_per_station = np.bincount(STATION, minlength=max_station_id+1)
                model_squared_error_per_station = np.bincount(STATION, weights=squared_errors_per_epoch[0], minlength=max_station_id+1)
                cosmo_squared_error_per_station = np.bincount(STATION, weights=squared_errors_per_epoch[1], minlength=max_station_id+1)

                # set division by zero/NaN warning to 'ignore'
                np.seterr(divide='ignore', invalid='ignore')

                # calculate rmse per station
                rmse_per_station = np.vstack((np.sqrt(np.divide(model_squared_error_per_station, test_samples_per_station)),
                                              np.sqrt(np.divide(cosmo_squared_error_per_station, test_samples_per_station)))).T

                # set division by zero/NaN warning to 'warn'
                np.seterr(divide='warn', invalid='warn')

            # update global processed samples
            processed_test_samples_global += processed_test_samples

            # RMSE of epoch
            print('Test RMSE COSMO: ', ", ".join(
                ["T=%s: %s" % (idx, epoch_test_rmse_cosmo[idx]) for idx in range(len(epoch_test_rmse_cosmo))]),
                  " (Overall: %s)" % overall_test_rmse_cosmo)
            print('Test RMSE Model: ', ", ".join(["T=%s: %s" % (idx, epoch_test_rmse_model[idx]) for idx in range(len(epoch_test_rmse_model))]),
                  " (Overall: %s)" % overall_test_rmse_model)
            # MAE of epoch
            print('Test MAE COSMO: ', ", ".join(
                ["T=%s: %s" % (idx, epoch_test_mae_cosmo[idx]) for idx in range(len(epoch_test_mae_cosmo))]),
                  " (Overall: %s)" % overall_test_mae_cosmo)
            print('Test MAE Model: ', ", ".join(["T=%s: %s" % (idx, epoch_test_mae_model[idx]) for idx in range(len(epoch_test_mae_model))]),
                  " (Overall: %s)" % overall_test_mae_model)

            sys.stdout.flush()

            test_time = time() - test_start_time

            # time for epoch
            epoch_time = time() - epoch_start_time

            # update error statistics
            error_statistics = ModelUtils.updateErrorStatistic(error_statistics,
                                                               np.array([epoch_train_rmse_model, epoch_test_rmse_model])[None, None, ...],
                                                               run, epoch, config['prediction_times'])
            # update run times statistic
            run_times = ModelUtils.updateRuntimeStatistic(run_times, np.array([epoch_time, train_time, test_time])[None, None, ...],
                                                          run, epoch)
            # update skip statistic
            skip_statistics = ModelUtils.updateSkipStatistic(skip_statistics,
                                                             np.array([n_train_samples, processed_train_samples,
                                                                       n_test_samples, processed_test_samples])[None, None, ...],
                                                             run, epoch)

            # update per-station RMSE data array over runs if desired (especially for the K-fold station generalization experiment)
            if "per_station_rmse" in config:
                error_per_station_statistics = ModelUtils.updatePerStationErrorStatistic(error_per_station_statistics, rmse_per_station, run, epoch, np.arange(max_station_id+1))

            # store the model if it was the best so far
            is_best = overall_test_rmse_model <= best_epoch_test_rmse
            best_epoch_test_rmse = min(overall_test_rmse_model, best_epoch_test_rmse)
            ModelUtils.save_checkpoint({
                'epoch': epoch,
                'run': run,
                'arch': net.__class__.__name__,
                'state_dict': net.state_dict(),
                'overall_test_rmse': overall_test_rmse_model,
                'lead_test_rmse' : overall_test_rmse_model,
                'best_epoch_test_rmse': best_epoch_test_rmse,
                'optimizer': optimizer.state_dict(),
            }, is_best, output_path + '/stored_models/run_%s' % run)

            # flush output to see progress
            sys.stdout.flush()

    # update statistics dict
    ModelUtils.get_model_details(experiment_info, net, optimizer, criterion)

    # complete program runtime
    experiment_info['program_runtime'] = time() - program_start_time

    # generate data set of all experiment statistics and additional information
    experiment_statistic = xr.Dataset({
        'error_statistic' : error_statistics,
        'run_time_statistic': run_times,
        'samples_statistic' : skip_statistics}).assign_attrs(experiment_info)

    # dump experiment statistic
    with open(output_path + '/experiment_statistic.pkl', 'wb') as handle:
        pkl.dump(experiment_statistic, handle, protocol=pkl.HIGHEST_PROTOCOL)

    if 'per_station_rmse' in config:
        # dump experiment statistic
        with open(output_path + '/rmse_per_station.pkl', 'wb') as handle:
            pkl.dump(error_per_station_statistics, handle, protocol=pkl.HIGHEST_PROTOCOL)

    # print program execution time
    m, s = divmod(experiment_info['program_runtime'], 60)
    h, m = divmod(m, 60)
    print('Experiment has successfully finished in %dh %02dmin %02ds' % (h, m, s))
Example #2
def train(args):
    start_t = time.time()
    params = get_train_options()
    params["exp_name"] = args.exp_name
    params["patch_num_point"] = 1024
    params["batch_size"] = args.batch_size
    params['use_gan'] = args.use_gan

    if args.debug:
        params["nepoch"] = 2
        params["model_save_interval"] = 3
        params['model_vis_interval'] = 3

    log_dir = os.path.join(params["model_save_dir"], args.exp_name)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    tb_logger = Logger(log_dir)

    trainloader = PUNET_Dataset(h5_file_path=params["dataset_dir"],
                                split_dir=params['train_split'])
    #print(params["dataset_dir"])
    num_workers = 4
    train_data_loader = data.DataLoader(dataset=trainloader,
                                        batch_size=params["batch_size"],
                                        shuffle=True,
                                        num_workers=num_workers,
                                        pin_memory=True,
                                        drop_last=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    G_model = Generator(params)
    G_model.apply(xavier_init)
    G_model = torch.nn.DataParallel(G_model).to(device)
    D_model = Discriminator(params, in_channels=3)
    D_model.apply(xavier_init)
    D_model = torch.nn.DataParallel(D_model).to(device)

    G_model.train()
    D_model.train()

    optimizer_D = Adam(D_model.parameters(),
                       lr=params["lr_D"],
                       betas=(0.9, 0.999))
    optimizer_G = Adam(G_model.parameters(),
                       lr=params["lr_G"],
                       betas=(0.9, 0.999))

    D_scheduler = MultiStepLR(optimizer_D, [50, 80], gamma=0.2)
    G_scheduler = MultiStepLR(optimizer_G, [50, 80], gamma=0.2)

    Loss_fn = Loss()

    print("preparation time is %fs" % (time.time() - start_t))
    iter = 0
    for e in range(params["nepoch"]):
        D_scheduler.step()
        G_scheduler.step()
        for batch_id, (input_data, gt_data,
                       radius_data) in enumerate(train_data_loader):
            optimizer_G.zero_grad()
            optimizer_D.zero_grad()

            input_data = input_data[:, :, 0:3].permute(0, 2, 1).float().cuda()
            gt_data = gt_data[:, :, 0:3].permute(0, 2, 1).float().cuda()

            start_t_batch = time.time()
            output_point_cloud = G_model(input_data)

            repulsion_loss = Loss_fn.get_repulsion_loss(
                output_point_cloud.permute(0, 2, 1))
            uniform_loss = Loss_fn.get_uniform_loss(
                output_point_cloud.permute(0, 2, 1))
            #print(output_point_cloud.shape,gt_data.shape)
            emd_loss = Loss_fn.get_emd_loss(
                output_point_cloud.permute(0, 2, 1), gt_data.permute(0, 2, 1))

            if params['use_gan']:
                fake_pred = D_model(output_point_cloud.detach())
                d_loss_fake = Loss_fn.get_discriminator_loss_single(
                    fake_pred, label=False)
                d_loss_fake.backward()
                optimizer_D.step()

                real_pred = D_model(gt_data.detach())
                d_loss_real = Loss_fn.get_discriminator_loss_single(real_pred,
                                                                    label=True)
                d_loss_real.backward()
                optimizer_D.step()

                d_loss = d_loss_real + d_loss_fake

                fake_pred = D_model(output_point_cloud)
                g_loss = Loss_fn.get_generator_loss(fake_pred)

                #print(repulsion_loss,uniform_loss,emd_loss)
                total_G_loss = params['uniform_w'] * uniform_loss + params['emd_w'] * emd_loss + \
                               repulsion_loss * params['repulsion_w'] + g_loss * params['gan_w']
            else:
                #total_G_loss = params['uniform_w'] * uniform_loss + params['emd_w'] * emd_loss + \
                #               repulsion_loss * params['repulsion_w']
                total_G_loss = params['emd_w'] * emd_loss + \
                               repulsion_loss * params['repulsion_w']

            #total_G_loss=emd_loss
            total_G_loss.backward()
            optimizer_G.step()

            current_lr_D = optimizer_D.state_dict()['param_groups'][0]['lr']
            current_lr_G = optimizer_G.state_dict()['param_groups'][0]['lr']

            tb_logger.scalar_summary('repulsion_loss', repulsion_loss.item(),
                                     iter)
            tb_logger.scalar_summary('uniform_loss', uniform_loss.item(), iter)
            tb_logger.scalar_summary('emd_loss', emd_loss.item(), iter)
            if params['use_gan']:
                tb_logger.scalar_summary('d_loss', d_loss.item(), iter)
                tb_logger.scalar_summary('g_loss', g_loss.item(), iter)
            tb_logger.scalar_summary('lr_D', current_lr_D, iter)
            tb_logger.scalar_summary('lr_G', current_lr_G, iter)

            msg = "{:0>8},{}:{}, [{}/{}], {}: {},{}:{}".format(
                str(datetime.timedelta(seconds=round(time.time() - start_t))),
                "epoch", e, batch_id + 1, len(train_data_loader),
                "total_G_loss", total_G_loss.item(), "iter time",
                (time.time() - start_t_batch))
            print(msg)

            iter += 1
        if (e + 1) % params['model_save_interval'] == 0 and e > 0:
            model_save_dir = os.path.join(params['model_save_dir'],
                                          params['exp_name'])
            if not os.path.exists(model_save_dir):
                os.makedirs(model_save_dir)
            D_ckpt_model_filename = "D_iter_%d.pth" % (e)
            G_ckpt_model_filename = "G_iter_%d.pth" % (e)
            D_model_save_path = os.path.join(model_save_dir,
                                             D_ckpt_model_filename)
            G_model_save_path = os.path.join(model_save_dir,
                                             G_ckpt_model_filename)
            torch.save(D_model.module.state_dict(), D_model_save_path)
            torch.save(G_model.module.state_dict(), G_model_save_path)
Example #3
def train(args):
    start_t = time.time()
    params = get_train_options()
    params["exp_name"] = args.exp_name
    params["patch_num_point"] = 1024
    params["batch_size"] = args.batch_size
    params['use_gan'] = args.use_gan

    if args.debug:
        params["nepoch"] = 2
        params["model_save_interval"] = 3
        params['model_vis_interval'] = 3

    log_dir = os.path.join(params["model_save_dir"], args.exp_name)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    tb_logger = Logger(log_dir)

    trainloader = PUNET_Dataset(h5_file_path=params["dataset_dir"])
    # print(params["dataset_dir"])
    num_workers = 4
    train_data_loader = data.DataLoader(dataset=trainloader,
                                        batch_size=params["batch_size"],
                                        shuffle=True,
                                        num_workers=num_workers,
                                        pin_memory=True,
                                        drop_last=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    G_model = Generator_recon(params)
    G_model.apply(xavier_init)
    G_model = torch.nn.DataParallel(G_model).to(device)
    D_model = torch.nn.DataParallel(Discriminator(params,
                                                  in_channels=3)).to(device)

    G_model.train()
    D_model.train()

    optimizer_D = Adam(D_model.parameters(),
                       lr=params["lr_D"],
                       betas=(0.9, 0.999))
    optimizer_G = Adam(G_model.parameters(),
                       lr=params["lr_G"],
                       betas=(0.9, 0.999))

    D_scheduler = MultiStepLR(optimizer_D, [50, 80], gamma=0.2)
    G_scheduler = MultiStepLR(optimizer_G, [50, 80], gamma=0.2)

    Loss_fn = Loss()

    print("preparation time is %fs" % (time.time() - start_t))
    iter = 0
    for e in range(params["nepoch"]):
        D_scheduler.step()
        G_scheduler.step()
        for batch_id, (input_data, gt_data,
                       radius_data) in enumerate(train_data_loader):
            optimizer_G.zero_grad()
            optimizer_D.zero_grad()

            input_data = input_data[:, :, 0:3].permute(0, 2, 1).float().cuda()
            gt_data = gt_data[:, :, 0:3].permute(0, 2, 1).float().cuda()

            start_t_batch = time.time()
            output_point_cloud = G_model(input_data)

            emd_loss = Loss_fn.get_emd_loss(
                output_point_cloud.permute(0, 2, 1),
                input_data.permute(0, 2, 1))

            total_G_loss = emd_loss
            total_G_loss.backward()
            optimizer_G.step()

            current_lr_D = optimizer_D.state_dict()['param_groups'][0]['lr']
            current_lr_G = optimizer_G.state_dict()['param_groups'][0]['lr']

            tb_logger.scalar_summary('emd_loss', emd_loss.item(), iter)
            tb_logger.scalar_summary('lr_D', current_lr_D, iter)
            tb_logger.scalar_summary('lr_G', current_lr_G, iter)

            msg = "{:0>8},{}:{}, [{}/{}], {}: {},{}:{}".format(
                str(datetime.timedelta(seconds=round(time.time() - start_t))),
                "epoch", e, batch_id + 1, len(train_data_loader),
                "total_G_loss", total_G_loss.item(), "iter time",
                (time.time() - start_t_batch))
            print(msg)

            if iter % params['model_save_interval'] == 0 and iter > 0:
                model_save_dir = os.path.join(params['model_save_dir'],
                                              params['exp_name'])
                if not os.path.exists(model_save_dir):
                    os.makedirs(model_save_dir)
                D_ckpt_model_filename = "D_iter_%d.pth" % (iter)
                G_ckpt_model_filename = "G_iter_%d.pth" % (iter)
                D_model_save_path = os.path.join(model_save_dir,
                                                 D_ckpt_model_filename)
                G_model_save_path = os.path.join(model_save_dir,
                                                 G_ckpt_model_filename)
                torch.save(D_model.module.state_dict(), D_model_save_path)
                torch.save(G_model.module.state_dict(), G_model_save_path)

            if iter % params['model_vis_interval'] == 0 and iter > 0:
                np_pcd = output_point_cloud.permute(
                    0, 2, 1)[0].detach().cpu().numpy()
                # print(np_pcd.shape)
                img = (np.array(visualize_point_cloud(np_pcd)) * 255).astype(
                    np.uint8)
                tb_logger.image_summary("images", img[np.newaxis, :], iter)

                gt_pcd = gt_data.permute(0, 2, 1)[0].detach().cpu().numpy()
                # print(gt_pcd.shape)
                gt_img = (np.array(visualize_point_cloud(gt_pcd)) *
                          255).astype(np.uint8)
                tb_logger.image_summary("gt", gt_img[np.newaxis, :], iter)

                input_pcd = input_data.permute(0, 2,
                                               1)[0].detach().cpu().numpy()
                input_img = (np.array(visualize_point_cloud(input_pcd)) *
                             255).astype(np.uint8)
                tb_logger.image_summary("input", input_img[np.newaxis, :],
                                        iter)
            iter += 1
Example #4
def train(args):
    start_t = time.time()
    params = get_train_options()
    params["exp_name"] = args.exp_name
    params["patch_num_point"] = 256
    params["batch_size"] = args.batch_size

    if args.debug:
        params["nepoch"] = 2
        params["model_save_interval"] = 3
        params['model_vis_interval'] = 3

    log_dir = os.path.join(params["model_save_dir"], args.exp_name)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    tb_logger = Logger(log_dir)

    #trainloader=PUNET_Dataset(h5_file_path=params["dataset_dir"],split_dir=params['train_split'])
    trainloader = PUGAN_Dataset(h5_file_path=params["dataset_dir"], npoint=256)
    num_workers = 4
    train_data_loader = data.DataLoader(dataset=trainloader,
                                        batch_size=params["batch_size"],
                                        shuffle=True,
                                        num_workers=num_workers,
                                        pin_memory=True,
                                        drop_last=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    ##########################################
    # Initialize generator and discriminator #
    ##########################################
    G_AB = Generator(params)
    G_AB.apply(xavier_init)
    G_AB = torch.nn.DataParallel(G_AB).to(device)

    G_BA = Downsampler(params)
    G_BA.apply(xavier_init)
    G_BA = torch.nn.DataParallel(G_BA).to(device)

    D_A = Discriminator(params, in_channels=3)
    D_A.apply(xavier_init)
    D_A = torch.nn.DataParallel(D_A).to(device)

    D_B = Discriminator(params, in_channels=3)
    D_B.apply(xavier_init)
    D_B = torch.nn.DataParallel(D_B).to(device)

    ########################################
    #Optimizers and Learning Rate scheduler#
    ########################################

    optimizer_D_A = Adam(D_A.parameters(),
                         lr=params["lr_D_A"],
                         betas=(0.9, 0.999))
    optimizer_D_B = Adam(D_B.parameters(),
                         lr=params["lr_D_B"],
                         betas=(0.9, 0.999))

    optimizer_G_AB = Adam(G_AB.parameters(),
                          lr=params["lr_G_AB"],
                          betas=(0.9, 0.999))
    optimizer_G_BA = Adam(G_BA.parameters(),
                          lr=params["lr_G_BA"],
                          betas=(0.9, 0.999))

    D_A_scheduler = MultiStepLR(optimizer_D_A, [50, 80], gamma=0.2)
    G_AB_scheduler = MultiStepLR(optimizer_G_AB, [50, 80], gamma=0.2)
    D_B_scheduler = MultiStepLR(optimizer_D_B, [50, 80], gamma=0.2)
    G_BA_scheduler = MultiStepLR(optimizer_G_BA, [50, 80], gamma=0.2)

    Loss_fn = Loss()

    print("preparation time is %fs" % (time.time() - start_t))
    iter = 0
    for e in range(params["nepoch"]):

        for batch_id, (input_data, gt_data,
                       radius_data) in enumerate(train_data_loader):

            G_AB.train()
            G_BA.train()
            D_A.train()
            D_B.train()

            optimizer_G_AB.zero_grad()
            optimizer_D_A.zero_grad()
            optimizer_G_BA.zero_grad()
            optimizer_D_B.zero_grad()

            input_data = input_data[:, :, 0:3].permute(0, 2, 1).float().cuda()
            gt_data = gt_data[:, :, 0:3].permute(0, 2, 1).float().cuda()
            start_t_batch = time.time()

            output_point_cloud_high = G_AB(input_data)
            output_point_cloud_low = G_BA(gt_data)

            #####################################
            #               Loss                #
            #####################################
            repulsion_loss_AB = Loss_fn.get_repulsion_loss(
                output_point_cloud_high.permute(0, 2, 1))
            uniform_loss_AB = Loss_fn.get_uniform_loss(
                output_point_cloud_high.permute(0, 2, 1))
            repulsion_loss_BA = Loss_fn.get_repulsion_loss(
                output_point_cloud_low.permute(0, 2, 1))
            uniform_loss_BA = Loss_fn.get_uniform_loss(
                output_point_cloud_low.permute(0, 2, 1))
            emd_loss_AB = Loss_fn.get_emd_loss(
                output_point_cloud_high.permute(0, 2, 1),
                gt_data.permute(0, 2, 1))
            #emd_loss_BA = Loss_fn.get_emd_loss(output_point_cloud_low.permute(0, 2, 1), input_data.permute(0, 2, 1))

            #Cycle Loss
            recov_A = G_BA(output_point_cloud_high)
            ABA_repul_loss = Loss_fn.get_repulsion_loss(
                recov_A.permute(0, 2, 1))
            ABA_uniform_loss = Loss_fn.get_uniform_loss(
                recov_A.permute(0, 2, 1))

            recov_B = G_AB(output_point_cloud_low)
            BAB_repul_loss = Loss_fn.get_repulsion_loss(
                recov_B.permute(0, 2, 1))
            BAB_uniform_loss = Loss_fn.get_uniform_loss(
                recov_B.permute(0, 2, 1))
            BAB_emd_loss = Loss_fn.get_emd_loss(recov_B.permute(0, 2, 1),
                                                gt_data.permute(0, 2, 1))

            #G_AB loss
            fake_pred_B = D_A(output_point_cloud_high.detach())
            g_AB_loss = Loss_fn.get_generator_loss(fake_pred_B)
            total_G_AB_loss = g_AB_loss * params['gan_w_AB'] + BAB_repul_loss * params['repulsion_w_AB'] + \
                              BAB_uniform_loss * params['uniform_w_AB'] + BAB_emd_loss * params['emd_w_AB'] + \
                              uniform_loss_AB * params['uniform_w_AB'] + emd_loss_AB * params['emd_w_AB'] + \
                              repulsion_loss_AB * params['repulsion_w_AB']

            total_G_AB_loss.backward()
            optimizer_G_AB.step()

            #G_BA loss
            fake_pred_A = D_B(output_point_cloud_low.detach())
            g_BA_loss = Loss_fn.get_generator_loss(fake_pred_A)
            total_G_BA_loss = g_BA_loss * params['gan_w_BA'] + ABA_repul_loss * params['repulsion_w_BA'] + \
                              repulsion_loss_BA * params['repulsion_w_BA']
            # ABA_uniform_loss * params['uniform_w_BA'] + \
            # params['uniform_w_BA'] * uniform_loss_BA + \

            total_G_BA_loss.backward()
            optimizer_G_BA.step()

            # Discriminator A loss (fake_A_buffer / fake_B_buffer are replay buffers assumed to be defined at module level)
            fake_B_ = fake_A_buffer.push_and_pop(output_point_cloud_high)
            fake_pred_B = D_A(fake_B_.detach())
            d_A_loss_fake = Loss_fn.get_discriminator_loss_single(fake_pred_B,
                                                                  label=False)

            real_pred_B = D_A(gt_data.detach())
            d_A_loss_real = Loss_fn.get_discriminator_loss_single(real_pred_B,
                                                                  label=True)

            d_A_loss = d_A_loss_real + d_A_loss_fake
            d_A_loss.backward()
            optimizer_D_A.step()

            #Discriminator B loss
            fake_A_ = fake_B_buffer.push_and_pop(output_point_cloud_low)
            fake_pred_A = D_B(fake_A_.detach())
            d_B_loss_fake = Loss_fn.get_discriminator_loss_single(fake_pred_A,
                                                                  label=False)

            real_pred_A = D_B(input_data.detach())
            d_B_loss_real = Loss_fn.get_discriminator_loss_single(real_pred_A,
                                                                  label=True)
            d_B_loss = d_B_loss_real + d_B_loss_fake
            d_B_loss.backward()
            optimizer_D_B.step()

            # read the current learning rates
            current_lr_D_A = optimizer_D_A.state_dict(
            )['param_groups'][0]['lr']
            current_lr_G_AB = optimizer_G_AB.state_dict(
            )['param_groups'][0]['lr']
            current_lr_D_B = optimizer_D_B.state_dict(
            )['param_groups'][0]['lr']
            current_lr_G_BA = optimizer_G_BA.state_dict(
            )['param_groups'][0]['lr']

            # tb_logger.scalar_summary('repulsion_loss_AB', repulsion_loss_AB.item(), iter)
            # tb_logger.scalar_summary('uniform_loss_AB', uniform_loss_AB.item(), iter)
            # tb_logger.scalar_summary('repulsion_loss_BA', repulsion_loss_BA.item(), iter)
            # tb_logger.scalar_summary('uniform_loss_BA', uniform_loss_BA.item(), iter)
            # tb_logger.scalar_summary('emd_loss_AB', emd_loss_AB.item(), iter)

            tb_logger.scalar_summary('d_A_loss', d_A_loss.item(), iter)
            tb_logger.scalar_summary('g_AB_loss', g_AB_loss.item(), iter)
            tb_logger.scalar_summary('Total_G_AB_loss', total_G_AB_loss.item(),
                                     iter)
            tb_logger.scalar_summary('lr_D_A', current_lr_D_A, iter)
            tb_logger.scalar_summary('lr_G_AB', current_lr_G_AB, iter)
            tb_logger.scalar_summary('d_B_loss', d_B_loss.item(), iter)
            tb_logger.scalar_summary('g_BA_loss', g_BA_loss.item(), iter)
            tb_logger.scalar_summary('Total_G_BA_loss', total_G_BA_loss.item(),
                                     iter)
            tb_logger.scalar_summary('lr_D_B', current_lr_D_B, iter)
            tb_logger.scalar_summary('lr_G_BA', current_lr_G_BA, iter)

            msg = "{:0>8},{}:{}, [{}/{}], {}: {}, {}: {}, {}:{}, {}: {},{}: {}".format(
                str(datetime.timedelta(seconds=round(time.time() - start_t))),
                "epoch", e + 1, batch_id + 1, len(train_data_loader),
                "total_G_AB_loss", total_G_AB_loss.item(), "total_G_BA_loss",
                total_G_BA_loss.item(),
                "iter time", (time.time() - start_t_batch), "d_A_loss",
                d_A_loss.item(), "d_B_loss", d_B_loss.item())
            print(msg)

            iter += 1

        D_A_scheduler.step()
        G_AB_scheduler.step()
        D_B_scheduler.step()
        G_BA_scheduler.step()

        if (e + 1) % params['model_save_interval'] == 0 and e > 0:
            model_save_dir = os.path.join(params['model_save_dir'],
                                          params['exp_name'])
            if not os.path.exists(model_save_dir):
                os.makedirs(model_save_dir)
            D_A_ckpt_model_filename = "D_A_iter_%d.pth" % (e + 1)
            G_AB_ckpt_model_filename = "G_AB_iter_%d.pth" % (e + 1)
            D_A_model_save_path = os.path.join(model_save_dir,
                                               D_A_ckpt_model_filename)
            G_AB_model_save_path = os.path.join(model_save_dir,
                                                G_AB_ckpt_model_filename)
            D_B_ckpt_model_filename = "D_B_iter_%d.pth" % (e + 1)
            G_BA_ckpt_model_filename = "G_BA_iter_%d.pth" % (e + 1)
            model_ckpt_model_filename = "Cyclegan_iter_%d.pth" % (e + 1)
            D_B_model_save_path = os.path.join(model_save_dir,
                                               D_B_ckpt_model_filename)
            G_BA_model_save_path = os.path.join(model_save_dir,
                                                G_BA_ckpt_model_filename)
            model_all_path = os.path.join(model_save_dir,
                                          model_ckpt_model_filename)
            torch.save(
                {
                    'G_AB_state_dict': G_AB.module.state_dict(),
                    'G_BA_state_dict': G_BA.module.state_dict(),
                    'D_A_state_dict': D_A.module.state_dict(),
                    'D_B_state_dict': D_B.module.state_dict(),
                    'optimizer_G_AB_state_dict': optimizer_G_AB.state_dict(),
                    'optimizer_G_BA_state_dict': optimizer_G_BA.state_dict(),
                    'optimizer_D_A_state_dict': optimizer_D_A.state_dict(),
                    'optimizer_D_B_state_dict': optimizer_D_B.state_dict()
                }, model_all_path)
            torch.save(D_A.module.state_dict(), D_A_model_save_path)
            torch.save(G_AB.module.state_dict(), G_AB_model_save_path)
            torch.save(D_B.module.state_dict(), D_B_model_save_path)
            torch.save(G_BA.module.state_dict(), G_BA_model_save_path)