Code example #1
File: home.py Project: GAIMJKP/CSDGAN
def continue_training():
    if 'index' in request.form.keys():  # Entering page for the first time
        runs = db.query_all_runs(session['user_id'])
        session['run_id'] = int(runs[int(request.form['index']) - 1]['id'])
        session['title'] = runs[int(request.form['index']) - 1]['title']
        return render_template('home/continue_training.html',
                               title=session['title'])

    if 'cancel' in request.form.keys():
        return redirect(url_for('index'))

    if 'train' in request.form.keys():
        db.query_clear_prior_retraining(
            run_id=session['run_id']
        )  # run_id/status_id combination is primary key in status table
        db.query_incr_retrains(run_id=session['run_id'])
        db.query_set_status(run_id=session['run_id'],
                            status_id=cs.STATUS_DICT['Retraining kicked off'])
        retrain = current_app.task_queue.enqueue(
            'CSDGAN.pipeline.train.retrain.retrain',
            args=(session['run_id'], g.user['username'], session['title'],
                  int(request.form['num_epochs'])),
            job_timeout=-1)
        db.query_update_train_id(run_id=session['run_id'],
                                 train_id=retrain.get_id())
        logger.info('User #{} ({}) continued training Run #{} ({})'.format(
            g.user['id'], g.user['username'], session['run_id'],
            session['title']))
        return redirect(url_for('index'))
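The `current_app.task_queue` used above is built elsewhere in the app; a minimal standalone sketch of the same enqueue, assuming the queue is a Redis-backed rq queue (consistent with the redis-cli health check in code example #6 below). The run id, username, title, and epoch count are placeholder values.

from redis import Redis
from rq import Queue

queue = Queue(connection=Redis())
retrain_job = queue.enqueue('CSDGAN.pipeline.train.retrain.retrain',
                            args=(42, 'alice', 'my_run', 50),  # hypothetical values
                            job_timeout=-1)  # no timeout; training can run long
print(retrain_job.get_id())  # the id stored via db.query_update_train_id above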
Code example #2
File: generate_image_data.py Project: GAIMJKP/CSDGAN
def generate_image_data(run_id, username, title, aug=None):
    """
    Loads an Image CGAN created by train_image_model.py. Generates data based on user specifications in pre-built gen_dict.pkl.
    :param aug: Whether this is part of the standard run or generating additional data
    """
    if aug is None:
        run_id = str(run_id)
        db.query_verify_live_run(run_id=run_id)

        cu.setup_run_logger(name='gen_func', username=username, title=title)
        logger = logging.getLogger('gen_func')

    try:
        if aug is None:
            db.query_set_status(run_id=run_id, status_id=cs.STATUS_DICT['Generating data'])

        # Check for objects created by train_image_model.py
        run_dir = os.path.join(cs.RUN_FOLDER, username, title)
        assert os.path.exists(os.path.join(run_dir, 'CGAN.pkl')), "CGAN object not found"
        if aug:
            gen_dict_path = os.path.join(run_dir, cs.GEN_DICT_NAME + ' Additional Data ' + str(aug) + '.pkl')
        else:
            gen_dict_path = os.path.join(run_dir, cs.GEN_DICT_NAME + '.pkl')

        assert os.path.exists(gen_dict_path), "gen_dict object not found"

        # Load in CGAN and gen_dict
        CGAN = cu.get_CGAN(username=username, title=title)

        with open(gen_dict_path, 'rb') as f:
            gen_dict = pkl.load(f)

        if aug is None:
            logger.info('Successfully loaded in CGAN. Generating data...')

        # Generate and output data
        folder_name = title + ('' if aug is None else ' Additional Data ' + str(aug))
        output_path = os.path.join(cs.OUTPUT_FOLDER, username, title, folder_name)
        os.makedirs(output_path, exist_ok=True)

        for i, (dep_class, size) in enumerate(gen_dict.items()):
            if size > 0:
                class_path = os.path.join(output_path, dep_class)
                os.makedirs(class_path, exist_ok=True)
                stratify = np.eye(CGAN.nc)[i]
                CGAN.gen_data(size=size, path=class_path, stratify=stratify, label=dep_class)

        _ = shutil.make_archive(output_path, 'zip', output_path)
        shutil.rmtree(output_path)

        if aug is None:
            db.query_set_status(run_id=run_id, status_id=cs.STATUS_DICT['Complete'])
            logger.info('Successfully completed generate_image_data function. Run complete.')

    except Exception as e:
        if aug is None:
            db.query_set_status(run_id=run_id, status_id=cs.STATUS_DICT['Error'])
            logger.exception('Error: %s', e)
        raise Exception("Intentionally failing process after broadly catching an exception. "
                        "Logs describing this error can be found in the run's specific logs file.")
Code example #3
File: retrain.py Project: GAIMJKP/CSDGAN
def retrain(run_id, username, title, num_epochs):
    """
    Continues training a CGAN (tabular or image) for a specified number of epochs
    """
    run_id = str(run_id)
    db.query_verify_live_run(run_id=run_id)

    cu.setup_run_logger(name='train_func', username=username, title=title)
    cu.setup_run_logger(name='train_info', username=username, title=title, filename='train_log')
    logger = logging.getLogger('train_func')

    try:
        run_dir = os.path.join(cs.RUN_FOLDER, username, title)

        # Load in prior trained GAN
        CGAN = cu.get_CGAN(username=username, title=title)

        # Train
        logger.info('Beginning retraining...')
        db.query_set_status(run_id=run_id, status_id=cs.STATUS_DICT['Retrain 0/4'])

        if type(CGAN).__name__ == 'TabularCGAN':
            CGAN.train_gan(num_epochs=num_epochs,
                           cadence=cs.TABULAR_DEFAULT_CADENCE,
                           print_freq=cs.TABULAR_DEFAULT_PRINT_FREQ,
                           eval_freq=cs.TABULAR_DEFAULT_EVAL_FREQ,
                           run_id=run_id,
                           logger=logging.getLogger('train_info'),
                           retrain=True)
        elif type(CGAN).__name__ == 'ImageCGAN':
            CGAN.train_gan(num_epochs=num_epochs,
                           print_freq=cs.IMAGE_DEFAULT_PRINT_FREQ,
                           eval_freq=cs.IMAGE_DEFAULT_EVAL_FREQ,
                           run_id=run_id,
                           logger=logging.getLogger('train_info'),
                           retrain=True)
        else:
            raise Exception('Invalid CGAN class object loaded')

        logger = logging.getLogger('train_func')
        logger.info('Successfully retrained CGAN. Loading and saving best model...')

        # Load best-performing GAN and pickle CGAN to main directory
        CGAN.load_netG(best=True)

        with open(os.path.join(run_dir, 'CGAN.pkl'), 'wb') as f:
            pkl.dump(CGAN, f)

        db.query_set_status(run_id=run_id, status_id=cs.STATUS_DICT['Retraining Complete'])
        logger.info('Successfully completed retrain function.')

    except Exception as e:
        db.query_set_status(run_id=run_id, status_id=cs.STATUS_DICT['Error'])
        logger.exception('Error: %s', e)
        raise Exception("Intentionally failing process after broadly catching an exception. "
                        "Logs describing this error can be found in the run's specific logs file.")
Code example #4
def make_tabular_dataset(run_id, username, title, dep_var, cont_inputs, int_inputs, test_size):
    """
    The data set must be contained in a flat file, with the continuous, categorical, integer, and dependent variables specified.
    How to deal with missing data should also be specified (stretch goal).
    """
    run_id = str(run_id)
    db.query_verify_live_run(run_id=run_id)

    cu.setup_run_logger(name='dataset_func', username=username, title=title)
    logger = logging.getLogger('dataset_func')

    try:
        db.query_set_status(run_id=run_id, status_id=cs.STATUS_DICT['Preprocessing data'])

        # Check existence of run directory
        run_dir = os.path.join(cs.RUN_FOLDER, username, title)
        assert os.path.exists(run_dir), "Run directory does not exist"

        # Perform various checks and load in data
        path = os.path.join(cs.UPLOAD_FOLDER, run_id)
        file = os.listdir(path)[0]
        assert os.path.splitext(file)[1] in {'.txt', '.csv', '.zip'}, "Path is not zip or flat file"
        if os.path.splitext(file)[1] == '.zip':
            logger.info('Tabular file contained in zip. Unzipping...')
            with ZipFile(os.path.join(path, file), 'r') as zip_ref:
                zip_ref.extractall(run_dir)

            unzipped_path = os.path.join(run_dir, os.path.splitext(file)[0])

            if os.path.isdir(unzipped_path):
                assert os.path.exists(unzipped_path), \
                    "Flat file in zip not named the same as zip file"
                unzipped_file = os.listdir(unzipped_path)[0]
                assert os.path.splitext(unzipped_file)[1] in {'.txt', '.csv'}, \
                    "Flat file in zip should be .txt or .csv"
                data = pd.read_csv(os.path.join(unzipped_path, unzipped_file), header=0)
            else:
                unzipped_file = [file for file in os.listdir(run_dir) if file not in ['gen_dict.pkl', 'run_log.log']][0]  # Expected entries
                assert os.path.splitext(unzipped_file)[1] in {'.txt', '.csv'}, \
                    "Flat file in zip should be .txt or .csv"
                data = pd.read_csv(os.path.join(run_dir, unzipped_file), header=0)
        else:
            logger.info('Tabular file not contained in zip.')
            data = pd.read_csv(os.path.join(path, file), header=0)

        # Initialize data set object
        dataset = TabularDataset(df=data,
                                 dep_var=dep_var,
                                 cont_inputs=cont_inputs,
                                 int_inputs=int_inputs,
                                 test_size=test_size)
        logger.info('TabularDataset successfully created. Pickling and exiting.')

        # Pickle relevant objects
        with open(os.path.join(run_dir, "dataset.pkl"), "wb") as f:
            pkl.dump(dataset, f)

    except Exception as e:
        db.query_set_status(run_id=run_id, status_id=cs.STATUS_DICT['Error'])
        logger.exception('Error: %s', e)
        raise Exception("Intentionally failing process after broadly catching an exception. "
                        "Logs describing this error can be found in the run's specific logs file.")
Code example #5
    def train_gan(self,
                  num_epochs,
                  print_freq,
                  eval_freq=None,
                  run_id=None,
                  logger=None,
                  retrain=False):
        """
        Primary method for training
        :param num_epochs: Desired number of epochs to train for
        :param print_freq: How frequently to print out training statistics (i.e., freq of 5 will result in information being printed every 5 epochs)
        :param eval_freq: How frequently to evaluate with netE. If None, no evaluation will occur. Evaluation takes a significant amount of time.
        :param run_id: If not None, will update database as it progresses through training in quarter increments.
        :param logger: Logger to be used for logging training progress. Must exist if run_id is not None.
        :param retrain: Whether model is being retrained
        """
        assert not run_id or logger, "Must pass a logger if run_id is passed"

        total_epochs = self.epoch + num_epochs

        if run_id:
            checkpoints = [int(num_epochs * i / 4) for i in range(1, 4)]

        if self.label_noise_linear_anneal:
            self.ln_rate = self.label_noise / num_epochs

        if self.discrim_noise_linear_anneal:
            self.dn_rate = self.discrim_noise / num_epochs

        uu.train_log_print(run_id=run_id,
                           logger=logger,
                           statement="Beginning training")
        og_start_time = time.time()
        start_time = time.time()

        for epoch in range(num_epochs):
            for x, y in self.train_gen:
                y = torch.eye(self.nc)[y] if len(y.shape) == 1 else y
                x, y = x.to(self.device), y.to(self.device)
                self.train_one_step(x, y)

            self.next_epoch()

            if self.epoch % print_freq == 0 or (self.epoch == num_epochs):
                uu.train_log_print(run_id=run_id,
                                   logger=logger,
                                   statement="Time: %ds" %
                                   (time.time() - start_time))
                start_time = time.time()

                self.print_progress(total_epochs=total_epochs,
                                    run_id=run_id,
                                    logger=logger)

            if eval_freq is not None:
                if self.epoch % eval_freq == 0 or (self.epoch == num_epochs):
                    self.init_fake_gen()
                    self.test_model(train_gen=self.fake_train_gen,
                                    val_gen=self.fake_val_gen)
                    uu.train_log_print(
                        run_id=run_id,
                        logger=logger,
                        statement="Epoch: %d\tEvaluator Score: %.4f" %
                        (self.epoch, np.max(self.stored_acc[-1])))

            if run_id:
                if self.epoch in checkpoints:
                    db.query_verify_live_run(run_id=run_id)
                    logger.info('Checkpoint reached.')
                    status_id = 'Train ' + str(
                        checkpoints.index(self.epoch) + 1) + '/4'
                    status_id = status_id.replace(
                        'Train', 'Retrain') if retrain else status_id
                    db.query_set_status(run_id=run_id,
                                        status_id=cs.STATUS_DICT[status_id])

        uu.train_log_print(run_id=run_id,
                           logger=logger,
                           statement="Total training time: %ds" %
                           (time.time() - og_start_time))
        uu.train_log_print(run_id=run_id,
                           logger=logger,
                           statement="Training complete")
Code example #6
File: create.py Project: GAIMJKP/CSDGAN
def success():
    if request.method == 'POST':
        if 'cancel' in request.form:
            db.clean_run(run_id=session['run_id'])
            return redirect(url_for('index'))

        cmd = 'redis-cli ' + (
            '-h redis-server ' if cs.DOCKERIZED else
            '') + 'ping'  # Check to make sure redis server is up
        if os.system(cmd) != 0:
            db.query_set_status(run_id=session['run_id'],
                                status_id=cs.STATUS_DICT['Error'])
            e = 'Redis server is not set up to handle requests.'
            logger.exception('Error: %s', e)
            raise NameError('Error: ' + e)

        db.query_set_status(run_id=session['run_id'],
                            status_id=cs.STATUS_DICT['Training kicked off'])
        if session['format'] == 'Tabular':
            # Load advanced settings (or defaults)
            bs = session['tabular_batch_size'] if session[
                'advanced_options'] else cs.TABULAR_DEFAULT_BATCH_SIZE
            tabular_init_params = session['tabular_init_params'] if session[
                'advanced_options'] else cs.TABULAR_CGAN_INIT_PARAMS
            tabular_eval_freq = session['tabular_eval_freq'] if session[
                'advanced_options'] else cs.TABULAR_DEFAULT_EVAL_FREQ
            tabular_eval_params = session['tabular_eval_params'] if session[
                'advanced_options'] else cs.TABULAR_EVAL_PARAM_GRID
            tabular_eval_folds = session['tabular_eval_folds'] if session[
                'advanced_options'] else cs.TABULAR_EVAL_FOLDS
            tabular_test_size = session['tabular_test_size'] if session[
                'advanced_options'] else cs.TABULAR_DEFAULT_TEST_SIZE

            # Commence tabular run
            make_dataset = current_app.task_queue.enqueue(
                'CSDGAN.pipeline.data.make_tabular_dataset.make_tabular_dataset',
                args=(session['run_id'], g.user['username'], session['title'],
                      session['dep_var'], session['cont_inputs'],
                      session['int_inputs'], tabular_test_size))
            train_model = current_app.task_queue.enqueue(
                'CSDGAN.pipeline.train.train_tabular_model.train_tabular_model',
                args=(session['run_id'], g.user['username'], session['title'],
                      session['num_epochs'], bs, tabular_init_params,
                      tabular_eval_freq, tabular_eval_params,
                      tabular_eval_folds),
                depends_on=make_dataset,
                job_timeout=-1)
            generate_data = current_app.task_queue.enqueue(
                'CSDGAN.pipeline.generate.generate_tabular_data.generate_tabular_data',
                args=(session['run_id'], g.user['username'], session['title']),
                depends_on=train_model)
        else:  # Image
            # Load advanced settings (or defaults)
            image_init_params = session['image_init_params'] if session[
                'advanced_options'] else cs.IMAGE_CGAN_INIT_PARAMS
            image_eval_freq = session['image_eval_freq'] if session[
                'advanced_options'] else cs.IMAGE_DEFAULT_EVAL_FREQ

            # Commence image run
            make_dataset = current_app.task_queue.enqueue(
                'CSDGAN.pipeline.data.make_image_dataset.make_image_dataset',
                args=(session['run_id'], g.user['username'], session['title'],
                      session['folder'], session['bs'], session['x_dim'],
                      session['splits']))
            train_model = current_app.task_queue.enqueue(
                'CSDGAN.pipeline.train.train_image_model.train_image_model',
                args=(session['run_id'], g.user['username'], session['title'],
                      session['num_epochs'], session['bs'], session['nc'],
                      session['num_channels'], image_init_params,
                      image_eval_freq),
                depends_on=make_dataset,
                job_timeout=-1)
            generate_data = current_app.task_queue.enqueue(
                'CSDGAN.pipeline.generate.generate_image_data.generate_image_data',
                args=(session['run_id'], g.user['username'], session['title']),
                depends_on=train_model)
        db.query_add_job_ids(run_id=session['run_id'],
                             data_id=make_dataset.get_id(),
                             train_id=train_model.get_id(),
                             generate_id=generate_data.get_id())
        logger.info('User #{} ({}) kicked off a {} Run #{} ({})'.format(
            g.user['id'], g.user['username'], session['format'],
            session['run_id'], session['title']))
        return redirect(url_for('index'))

    return render_template('create/success.html', title=session['title'])
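The three enqueue calls form a dependency chain: rq only starts a job once its depends_on job has finished successfully. A stripped-down sketch of the same pattern, assuming a Redis-backed rq queue; the args tuples use placeholder values, and the empty dicts stand in for the real init/eval parameter grids.

from redis import Redis
from rq import Queue

queue = Queue(connection=Redis())
make_dataset = queue.enqueue(
    'CSDGAN.pipeline.data.make_tabular_dataset.make_tabular_dataset',
    args=(42, 'alice', 'my_run', 'label', ['age', 'income'], ['age'], 0.2))
train_model = queue.enqueue(
    'CSDGAN.pipeline.train.train_tabular_model.train_tabular_model',
    args=(42, 'alice', 'my_run', 100, 64, {}, 5, {}, 3),  # placeholder params
    depends_on=make_dataset,  # waits for the dataset job to succeed
    job_timeout=-1)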
Code example #7
def train_image_model(run_id, username, title, num_epochs, bs, nc,
                      num_channels, image_init_params, image_eval_freq):
    """
    Trains an Image CGAN on the data preprocessed by make_image_dataset.py. Loads best generator and pickles CGAN for predictions.
    """
    run_id = str(run_id)
    db.query_verify_live_run(run_id=run_id)

    cu.setup_run_logger(name='train_func', username=username, title=title)
    cu.setup_run_logger(name='train_info',
                        username=username,
                        title=title,
                        filename='train_log')
    logger = logging.getLogger('train_func')

    try:
        # Check for objects created by make_image_dataset.py
        run_dir = os.path.join(cs.RUN_FOLDER, username, title)
        le, ohe, train_gen, val_gen, test_gen = cu.get_image_dataset(
            username=username, title=title)

        device = torch.device("cuda:0" if (
            torch.cuda.is_available()) else "cpu")

        CGAN = ImageCGAN(train_gen=train_gen,
                         val_gen=val_gen,
                         test_gen=test_gen,
                         device=device,
                         nc=nc,
                         num_channels=num_channels,
                         path=run_dir,
                         le=le,
                         ohe=ohe,
                         fake_bs=bs,
                         **image_init_params)

        # Benchmark and store
        logger.info(
            'Successfully instantiated CGAN object. Beginning benchmarking...')
        db.query_set_status(run_id=run_id,
                            status_id=cs.STATUS_DICT['Benchmarking'])

        benchmark, real_netE = CGAN.eval_on_real_data(
            num_epochs=image_init_params['eval_num_epochs'],
            es=image_init_params['early_stopping_patience'])

        db.query_update_benchmark(run_id=run_id, benchmark=benchmark)

        with open(os.path.join(run_dir, 'real_netE.pkl'), 'wb') as f:
            pkl.dump(real_netE, f)

        # Train
        logger.info('Successfully completed benchmark. Beginning training...')
        db.query_set_status(run_id=run_id,
                            status_id=cs.STATUS_DICT['Train 0/4'])
        CGAN.train_gan(num_epochs=num_epochs,
                       print_freq=cs.IMAGE_DEFAULT_PRINT_FREQ,
                       eval_freq=image_eval_freq,
                       run_id=run_id,
                       logger=logging.getLogger('train_info'))

        logger = logging.getLogger('train_func')
        logger.info(
            'Successfully trained CGAN. Loading and saving best model...')

        # Load best-performing GAN and pickle CGAN to main directory
        CGAN.load_netG(best=True)

        with open(os.path.join(run_dir, 'CGAN.pkl'), 'wb') as f:
            pkl.dump(CGAN, f)

        logger.info('Successfully completed train_image_model function.')

    except Exception as e:
        db.query_set_status(run_id=run_id, status_id=cs.STATUS_DICT['Error'])
        logger.exception('Error: %s', e)
        raise Exception(
            "Intentionally failing process after broadly catching an exception. "
            "Logs describing this error can be found in the run's specific logs file."
        )
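Once CGAN.pkl has been written, the trained model can be reloaded and asked for more data directly; a minimal sketch using the same gen_data signature seen in generate_image_data above (the paths and label are hypothetical):

import os
import pickle as pkl
import numpy as np

run_dir = os.path.join('runs', 'alice', 'my_run')  # placeholder layout
with open(os.path.join(run_dir, 'CGAN.pkl'), 'rb') as f:
    CGAN = pkl.load(f)

# One-hot stratify vector for class index 0, as in generate_image_data
stratify = np.eye(CGAN.nc)[0]
CGAN.gen_data(size=100, path='out/class_a', stratify=stratify, label='class_a')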
Code example #8
File: make_image_dataset.py Project: GAIMJKP/CSDGAN
def make_image_dataset(run_id,
                       username,
                       title,
                       folder,
                       bs,
                       x_dim=None,
                       splits=None):
    """
    The image data set must be a single zip, with all images sharing a label grouped in a folder named after that label.
    Images should either all be the same size, or a target image size should be provided (all images will be cropped to that size).
    Assumes that file has been pre-unzipped and checked by the create.py functions/related util functions
    This file accomplishes the following:
        1. Accepts a desired image size (optional, else first image dim will be used), batch size, and train/val/test splits
        2. Splits data into train/val/test splits via stratified sampling and moves into corresponding folders
        3. Deletes original unzipped images
        4. Pickles label encoder, one hot encoder, resulting image size, and all three generators
    """
    run_id = str(run_id)
    db.query_verify_live_run(run_id=run_id)

    cu.setup_run_logger(name='dataset_func', username=username, title=title)
    logger = logging.getLogger('dataset_func')

    try:
        db.query_set_status(run_id=run_id,
                            status_id=cs.STATUS_DICT['Preprocessing data'])

        # Check existence of run directory
        run_dir = os.path.join(cs.RUN_FOLDER, username, title)
        assert os.path.exists(run_dir), "Run directory does not exist"

        unzipped_path = os.path.join(run_dir, folder)
        assert os.path.exists(unzipped_path), "Unzipped path does not exist"

        # Load and preprocess data
        import_gen = cuidl.import_dataset(path=unzipped_path,
                                          bs=bs,
                                          shuffle=False,
                                          incl_paths=True)

        splits = [float(num) for num in splits]
        le, ohe, x_dim = cuidl.preprocess_imported_dataset(
            path=unzipped_path,
            import_gen=import_gen,
            splits=splits,
            x_dim=x_dim)

        logger.info(
            'Data successfully imported and preprocessed. Splitting into train/val/test...'
        )

        # Create data loader for each component of data set
        train_gen = cuidl.import_dataset(os.path.join(unzipped_path, 'train'),
                                         bs=bs,
                                         shuffle=True,
                                         incl_paths=False)
        val_gen = cuidl.import_dataset(os.path.join(unzipped_path, 'val'),
                                       bs=bs,
                                       shuffle=True,
                                       incl_paths=False)
        test_gen = cuidl.import_dataset(os.path.join(unzipped_path, 'test'),
                                        bs=bs,
                                        shuffle=True,
                                        incl_paths=False)

        logger.info(
            'Data successfully split into train/val/test. Pickling and exiting.'
        )

        # Pickle relevant objects
        with open(os.path.join(run_dir, "le.pkl"), "wb") as f:
            pkl.dump(le, f)

        with open(os.path.join(run_dir, "ohe.pkl"), "wb") as f:
            pkl.dump(ohe, f)

        with open(os.path.join(run_dir, "train_gen.pkl"), "wb") as f:
            pkl.dump(train_gen, f)

        with open(os.path.join(run_dir, "val_gen.pkl"), "wb") as f:
            pkl.dump(val_gen, f)

        with open(os.path.join(run_dir, "test_gen.pkl"), "wb") as f:
            pkl.dump(test_gen, f)

    except Exception as e:
        db.query_set_status(run_id=run_id, status_id=cs.STATUS_DICT['Error'])
        logger.exception('Error: %s', e)
        raise Exception(
            "Intentionally failing process after broadly catching an exception. "
            "Logs describing this error can be found in the run's specific logs file."
        )
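A hypothetical call for a pre-unzipped folder of labeled image subfolders, with a 70/15/15 split and the image size inferred from the first image (x_dim=None, per the docstring):

from CSDGAN.pipeline.data.make_image_dataset import make_image_dataset

make_image_dataset(run_id=42,
                   username='alice',
                   title='my_run',
                   folder='images',  # subfolder of the run directory
                   bs=64,
                   x_dim=None,       # use the first image's dimensions
                   splits=[0.7, 0.15, 0.15])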
Code example #9
def train_tabular_model(run_id, username, title, num_epochs, bs,
                        tabular_init_params, tabular_eval_freq,
                        tabular_eval_params, tabular_eval_folds):
    """
    Trains a Tabular CGAN on the data preprocessed by make_tabular_dataset.py. Loads best generator and pickles CGAN for predictions.
    """
    run_id = str(run_id)
    db.query_verify_live_run(run_id=run_id)

    cu.setup_run_logger(name='train_func', username=username, title=title)
    cu.setup_run_logger(name='train_info',
                        username=username,
                        title=title,
                        filename='train_log')
    logger = logging.getLogger('train_func')

    try:
        # Check for objects created by make_tabular_dataset.py
        run_dir = os.path.join(cs.RUN_FOLDER, username, title)
        dataset = cu.get_tabular_dataset(username=username, title=title)

        device = torch.device("cuda:0" if (
            torch.cuda.is_available()) else "cpu")

        if len(pkl.dumps(dataset, -1)) < cs.TABULAR_MEM_THRESHOLD:
            dataset.to_dev(device)

        data_gen = data.DataLoader(dataset,
                                   batch_size=bs,
                                   shuffle=True,
                                   num_workers=0)

        CGAN = TabularCGAN(
            data_gen=data_gen,
            device=device,
            path=run_dir,
            seed=None,
            eval_param_grid=tabular_eval_params,
            eval_folds=tabular_eval_folds,
            test_ranges=[dataset.x_train.shape[0] * 2**x for x in range(5)],
            eval_stratify=dataset.eval_stratify,
            nc=len(dataset.labels_list),
            **tabular_init_params)

        # Benchmark and store
        logger.info(
            'Successfully instantiated CGAN object. Beginning benchmarking...')
        db.query_set_status(run_id=run_id,
                            status_id=cs.STATUS_DICT['Benchmarking'])
        benchmark = uu.train_test_logistic_reg(
            x_train=CGAN.data_gen.dataset.x_train.cpu().detach().numpy(),
            y_train=CGAN.data_gen.dataset.y_train.cpu().detach().numpy(),
            x_test=CGAN.data_gen.dataset.x_test.cpu().detach().numpy(),
            y_test=CGAN.data_gen.dataset.y_test.cpu().detach().numpy(),
            param_grid=tabular_eval_params,
            cv=tabular_eval_folds,
            labels_list=dataset.labels_list,
            verbose=False)
        db.query_update_benchmark(run_id=run_id, benchmark=benchmark)

        # Train
        logger.info('Successfully completed benchmark. Beginning training...')
        db.query_set_status(run_id=run_id,
                            status_id=cs.STATUS_DICT['Train 0/4'])
        CGAN.train_gan(num_epochs=num_epochs,
                       cadence=cs.TABULAR_DEFAULT_CADENCE,
                       print_freq=cs.TABULAR_DEFAULT_PRINT_FREQ,
                       eval_freq=tabular_eval_freq,
                       run_id=run_id,
                       logger=logging.getLogger('train_info'))

        logger = logging.getLogger('train_func')
        logger.info(
            'Successfully trained CGAN. Loading and saving best model...')

        # Load best-performing GAN and pickle CGAN to main directory
        CGAN.load_netG(best=True)

        with open(os.path.join(run_dir, 'CGAN.pkl'), 'wb') as f:
            pkl.dump(CGAN, f)

        logger.info('Successfully completed train_tabular_model function.')

    except Exception as e:
        db.query_set_status(run_id=run_id, status_id=cs.STATUS_DICT['Error'])
        logger.exception('Error: %s', e)
        raise Exception(
            "Intentionally failing process after broadly catching an exception. "
            "Logs describing this error can be found in the run's specific logs file."
        )
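The len(pkl.dumps(dataset, -1)) check above estimates the dataset's memory footprint from its pickled byte length before deciding whether to move it onto the GPU; a standalone illustration of the idea (the threshold value is an assumption, standing in for cs.TABULAR_MEM_THRESHOLD):

import pickle as pkl

MEM_THRESHOLD = 1 * 1024 ** 3  # hypothetical 1 GiB cap

payload = list(range(1_000_000))
approx_bytes = len(pkl.dumps(payload, -1))  # protocol -1 = highest available
if approx_bytes < MEM_THRESHOLD:
    print('fits under the threshold; safe to move to the device')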