def _setup(self, config):
    """Initializes this trainer from a ``config`` dict.

    Steps, in order:
      1. chdir to ``config['cwd']`` and copy every config entry onto ``self``
         as an attribute (so e.g. ``config['join_tables']`` becomes
         ``self.join_tables``);
      2. start a W&B run (project ``config['__run']``);
      3. seed torch and numpy with 0;
      4. load the TPC-DS table(s) named by ``self.join_tables`` and build the
         training dataset (a join sampler for >1 table, a table dataset for 1);
      5. build the model, optimizer and learning-rate schedule;
      6. create a TBX logger, optionally load a checkpoint, and for
         multi-table joins load evaluation queries from ``self.queries_csv``.

    Args:
      config: dict of hyperparameters. Keys read directly here: 'cwd',
        '__gpu', '__cpu', '__run'. Attributes read after the setattr loop
        include join_tables, dataset, use_cols, use_data_parallel,
        use_transformer, optimizer, epochs, max_steps, lr_scheduler,
        checkpoint_to_load and queries_csv. ``self.logdir`` is presumably
        supplied by the Ray Tune Trainable base class (TODO confirm).
    """
    self.config = config
    print('NeuroCard config:')
    pprint.pprint(config)
    # Switch to the configured working directory first; relative paths used
    # below (and the W&B diff) are resolved against it.
    os.chdir(config['cwd'])
    # Expose every config entry as an attribute, e.g. self.join_tables.
    for k, v in config.items():
        setattr(self, k, v)

    if config['__gpu'] == 0:
        # CPU-only run: cap intra-op threads at the allotted CPU count.
        torch.set_num_threads(config['__cpu'])

    # W&B.
    # Calling wandb.init() after the os.chdir() above ensures that the Git
    # diff file (diff.patch) is relative to the directory containing this
    # file, rather than to Ray's package dir.
    wandb_project = config['__run']
    # Run name is the last component of self.logdir; a single trailing '/'
    # is stripped first so that basename() does not return ''.
    wandb.init(name=os.path.basename(
        self.logdir if self.logdir[-1] != '/' else self.logdir[:-1]),
               sync_tensorboard=True,
               config=config,
               project=wandb_project)

    self.epoch = 0

    if isinstance(self.join_tables, int):
        # Hack to support training single-model tables.
        # An int is used as an index into the sorted table names returned by
        # GetTDSLightJoinKeys(), and replaced by a one-element list.
        sorted_table_names = sorted(
            list(datasets.TPC_DS.GetTDSLightJoinKeys().keys()))
        self.join_tables = [sorted_table_names[self.join_tables]]

    # Try to make all the runs the same, except for input orderings.
    torch.manual_seed(0)
    np.random.seed(0)

    # Common attributes.
    self.loader = None
    self.join_spec = None
    join_iter_dataset = None
    table_primary_index = None

    # New datasets should be loaded here.
    assert self.dataset in ['tpcds']
    if self.dataset == 'tpcds':
        print('Training on Join({})'.format(self.join_tables))
        loaded_tables = []
        for t in self.join_tables:
            print('Loading', t)
            table = datasets.LoadTds(t, use_cols=self.use_cols)
            table.data.info()
            loaded_tables.append(table)
        if len(self.join_tables) > 1:
            # Multi-table: MakeSamplerDatasetLoader() returns the join spec,
            # the sampler-backed training dataset, its loader, and the table
            # object used to describe the join.
            join_spec, join_iter_dataset, loader, table = self.MakeSamplerDatasetLoader(
                loaded_tables)

            self.join_spec = join_spec
            self.train_data = join_iter_dataset
            self.loader = loader

            # NOTE(review): 'title' is hard-coded. list.index() raises
            # ValueError unless one of the loaded tables is literally named
            # 'title'. Confirm this is intended for the TPC-DS schema.
            table_primary_index = [t.name for t in loaded_tables].index('title')

            table.cardinality = datasets.TPC_DS.GetFullOuterCardinalityOrFail(
                self.join_tables)
            self.train_data.cardinality = table.cardinality

            print('rows in full join', table.cardinality,
                  'cols in full join', len(table.columns), 'cols:', table)
        else:
            # Train on a single table.
            table = loaded_tables[0]

    if self.dataset != 'tpcds' or len(self.join_tables) == 1:
        # Single-table path: wrap the table itself as the training dataset.
        table.data.info()
        self.train_data = self.MakeTableDataset(table)

    self.table = table
    # Provide true cardinalities in a file or implement an oracle CardEst.
    self.oracle = None
    self.table_bits = 0

    # A fixed ordering?
    self.fixed_ordering = self.MakeOrdering(table)

    model = self.MakeModel(self.table,
                           self.train_data,
                           table_primary_index=table_primary_index)

    # NOTE: ReportModel()'s returned value is the true model size in
    # megabytes, counting all *trainable* parameters. As an implementation
    # convenience, the saved ckpts on disk have a slightly bigger footprint
    # because they also store non-trainable constants (the masks in each
    # layer). Those can be deterministically reconstructed from RNG seeds
    # and so should not be counted as model size.
    self.mb = train_utils.ReportModel(model)

    if not isinstance(model, transformer.Transformer):
        print('applying train_utils.weight_init()')
        model.apply(train_utils.weight_init)
    self.model = model

    if self.use_data_parallel:
        self.model = DataParallelPassthrough(self.model)

    # Watches the unwrapped `model`, not the DataParallel wrapper.
    wandb.watch(model, log='all')

    if self.use_transformer:
        opt = torch.optim.Adam(
            list(model.parameters()),
            2e-4,
            # betas=(0.9, 0.98),  # B in Lingvo; in Trfmr paper.
            betas=(0.9, 0.997),  # A in Lingvo.
            eps=1e-9,
        )
    else:
        if self.optimizer == 'adam':
            opt = torch.optim.Adam(list(model.parameters()), 2e-4)
        else:
            print('Using Adagrad')
            opt = torch.optim.Adagrad(list(model.parameters()), 2e-4)
    print('Optimizer:', opt)
    self.opt = opt

    total_steps = self.epochs * self.max_steps
    # self.lr_scheduler arrives as a config string (or None). The torch
    # branches below replace it with the scheduler object itself.
    if self.lr_scheduler == 'CosineAnnealingLR':
        # Starts decaying to 0 immediately.
        self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            opt, total_steps)
    elif self.lr_scheduler == 'OneCycleLR':
        # Warms up to max_lr, then decays to ~0.
        self.lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
            opt, max_lr=2e-3, total_steps=total_steps)
    elif self.lr_scheduler is not None and self.lr_scheduler.startswith(
            'OneCycleLR-'):
        # Format 'OneCycleLR-<x>'. x is passed as pct_start, which PyTorch
        # interprets as a fraction of the cycle (e.g. 0.3), not a percentage.
        warmup_percentage = float(self.lr_scheduler.split('-')[-1])
        # Warms up to max_lr, then decays to ~0.
        self.lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
            opt,
            max_lr=2e-3,
            total_steps=total_steps,
            pct_start=warmup_percentage)
    elif self.lr_scheduler is not None and self.lr_scheduler.startswith(
            'wd_'):
        # Warmups and decays.
        # Format 'wd_<lr>_<warmup_fraction>'. Unlike the branches above,
        # self.lr_scheduler stays the config string here; the schedule is
        # stored in self.custom_lr_lambda instead.
        splits = self.lr_scheduler.split('_')
        assert len(splits) == 3, splits
        lr, warmup_fraction = float(splits[1]), float(splits[2])
        self.custom_lr_lambda = train_utils.get_cosine_learning_rate_fn(
            total_steps,
            learning_rate=lr,
            min_learning_rate_mult=1e-5,
            constant_fraction=0.,
            warmup_fraction=warmup_fraction)
    else:
        assert self.lr_scheduler is None, self.lr_scheduler

    self.tbx_logger = tune_logger.TBXLogger(self.config, self.logdir)

    if self.checkpoint_to_load:
        self.LoadCheckpoint()

    # Evaluation queries and their true cardinalities; loaded only for
    # multi-table joins.
    self.loaded_queries = None
    self.oracle_cards = None
    if self.dataset == 'tpcds' and len(self.join_tables) > 1:
        queries_job_format = utils.JobToQuery(self.queries_csv)
        self.loaded_queries, self.oracle_cards = utils.UnpackQueries(
            self.table, queries_job_format)

    if config['__gpu'] == 0:
        print('CUDA not available, using # cpu cores for intra-op:',
              torch.get_num_threads(), '; inter-op:',
              torch.get_num_interop_threads())
def main(argv):
    """Builds join templates, then reports selectivities or generates queries.

    Join templates (table set -> per-table join column) are read from
    FLAGS.tds_light_csv. Every column of every loaded table is then renamed
    to '<table>.<column>' so that nothing is pruned away during join sampling.

    If FLAGS.print_sel is set, FLAGS.output_csv ('#'-separated, no header;
    column 0 = comma-separated table names, column 3 = true cardinality) is
    read. The inner-join cardinality of each template is computed once and
    cached. Inner and outer selectivities are written to
    FLAGS.output_csv + '.sel'.

    Otherwise, query generation resumes: templates whose queries are already
    counted in FLAGS.output_csv (by line count) are skipped, and MakeQueries
    is called for the rest. It presumably appends to FLAGS.output_csv;
    TODO confirm.

    Raises:
      ValueError: if no join templates are found.
    """
    del argv  # Unused.
    # conn = pg.connect(FLAGS.db)
    # conn.set_session(autocommit=True)
    # cursor = conn.cursor()
    cursor = None

    tables = datasets.LoadImdb(use_cols=None)

    # Load all templates in the original JOB-light-format file.
    template_queries = utils.JobToQuery(FLAGS.tds_light_csv,
                                        use_alias_keys=False)
    tables_to_join_keys = {}
    for query in template_queries:
        key = MakeTablesKey(query[0])
        if key not in tables_to_join_keys:
            join_dict = query[1]
            # Disambiguate: title->id changed to title->title.id.
            # Only existing keys are reassigned, so mutating while iterating
            # over keys() is safe.
            for table_name in join_dict.keys():
                # TODO: only support a single join key
                join_key = next(iter(join_dict[table_name]))
                join_dict[table_name] = common.JoinTableAndColumnNames(
                    table_name, join_key, sep='.')
            tables_to_join_keys[key] = join_dict

    num_templates = len(tables_to_join_keys)
    if num_templates == 0:
        raise ValueError('No join templates found in {}'.format(
            FLAGS.tds_light_csv))
    num_queries_per_template = FLAGS.num_queries // num_templates
    logging.info('%d join templates', num_templates)

    rng = np.random.RandomState(1234)
    queries = []  # [(cols, ops, vals)]

    # Disambiguate to not prune away stuff during join sampling.
    for table in tables.values():
        for col in table.columns:
            col.name = common.JoinTableAndColumnNames(table.name,
                                                      col.name,
                                                      sep='.')
        table.data.columns = [col.name for col in table.columns]

    if FLAGS.print_sel:
        # Print true selectivities.
        df = pd.read_csv(FLAGS.output_csv, sep='#', header=None)
        assert len(df) == FLAGS.num_queries, (len(df), FLAGS.num_queries)

        inner = []
        true_inner_card_cache = {}

        for _, vs in df.iterrows():
            table_names, true_card = vs[0], vs[3]
            table_names = table_names.split(',')
            print('Template: {}\tTrue card: {}'.format(table_names, true_card))

            # JOB-light: contains 'full_name alias'.
            # JOB-light-ranges: just 'full_name'.
            if ' ' in table_names[0]:
                table_names = [n.split(' ')[0] for n in table_names]

            tables_in_templates = [tables[n] for n in table_names]
            key = MakeTablesKey(table_names)
            join_keys_list = tables_to_join_keys[key]

            if key not in true_inner_card_cache:
                # join_keys_list maps table name -> 'table.column' (built
                # above). Look each table up by name so the columns line up
                # with table_names. Iterating the dict directly would yield
                # table names, which contain no '.' to split on.
                join_key_columns = {
                    n: [join_keys_list[n].split('.', 1)[1]]
                    for n in table_names
                }
                join_spec = join_utils.get_join_spec({
                    "join_tables": table_names,
                    "join_keys": join_key_columns,
                    "join_root": "item",
                    "join_how": "inner",
                })
                ds = FactorizedSamplerIterDataset(
                    tables_in_templates,
                    join_spec,
                    # Was the undefined local `num_queries` (NameError).
                    sample_batch_size=FLAGS.num_queries,
                    disambiguate_column_names=False,
                    add_full_join_indicators=False,
                    add_full_join_fanouts=False)
                true_inner_card_cache[key] = ds.sampler.join_card
            inner.append(true_inner_card_cache[key])

        pd.DataFrame({
            'true_cards': df[3],
            'true_inner': inner,
            'inner_sel': df[3] * 1.0 / inner,
            'outer_sel': df[3] * 1.0 / TDS_LIGHT_OUTER_CARDINALITY
        }).to_csv(FLAGS.output_csv + '.sel', index=False)
        print('Done:', FLAGS.output_csv + '.sel')

    else:
        # Generate queries. Resume support: skip templates whose queries
        # are already counted in the output file.
        last_run_queries = file_len(FLAGS.output_csv) if os.path.exists(
            FLAGS.output_csv) else 0
        next_template_idx = last_run_queries // num_queries_per_template
        print('next_template_idx', next_template_idx)
        print(tables_to_join_keys.items())

        spark = StartSpark()
        for i, (tables_to_join,
                join_keys) in enumerate(tables_to_join_keys.items()):
            if i < next_template_idx:
                print('Skipping template:', tables_to_join)
                continue
            print('Template:', tables_to_join)
            # The last template absorbs the remainder so the total equals
            # FLAGS.num_queries.
            if i == num_templates - 1:
                num_queries_per_template += FLAGS.num_queries % num_templates
            # Generate num_queries_per_template.
            table_names = tables_to_join.split('-')

            tables_in_templates = [tables[n] for n in table_names]
            queries.extend(
                MakeQueries(spark, cursor, num_queries_per_template,
                            tables_in_templates, table_names, join_keys, rng))