def main():
    """Run training from the "train" configuration.

    Loads the training and validation datasets, builds the model from the
    training vocabulary, then calls ``Trainer.train`` once per epoch for
    epochs ``1 .. config.epochs`` and dumps the trainer's model to
    ``config.save_model``.
    """
    config = Config("train", training=True)
    trace(config)
    torch.backends.cudnn.benchmark = True

    # Both splits go through the same loader; only the dataset, the batch
    # size and the prefix string differ.
    train_data = load_dataset(
        config.train_dataset,
        config.train_batch_size,
        config,
        prefix="Training:",
    )
    valid_data = load_dataset(
        config.valid_dataset,
        config.valid_batch_size,
        config,
        prefix="Validation:",
    )

    # The model is built from the vocabulary exposed by the training data.
    train_vocab = train_data.get_vocab()
    model = model_factory(config, config.checkpoint, *train_vocab)
    if config.verbose:
        trace(model)

    # The trainer receives the target vocabulary and its padding index.
    target_vocab = train_data.trg_vocab
    trainer = Trainer(model, target_vocab, target_vocab.padding_idx, config)

    first_epoch = 1
    for epoch in range(first_epoch, config.epochs + 1):
        trainer.train(
            epoch,
            config.epochs,
            train_data,
            valid_data,
            train_data.num_batches,
        )
        # NOTE(review): the source formatting was collapsed, so it is not
        # certain whether this dump runs once per epoch or once after the
        # loop; it is kept per-epoch here -- confirm against the original.
        dump_checkpoint(trainer.model, config.save_model)
def main():
    """Entry point for checkpoint ensembling.

    Builds a model from the training vocabulary, replaces its weights with
    the ``'model'`` entry of ``CheckPoint(config.checkpoint).state_dict``,
    and dumps the model to ``config.save_model`` with ``".ensemble"`` passed
    as the third argument of ``dump_checkpoint``.
    """
    config = Config("ensemble", training=True)
    trace(config)
    torch.backends.cudnn.benchmark = True

    # Within this function the training set is only used for its vocabulary.
    train_data = load_dataset(
        config.train_dataset,
        config.train_batch_size,
        config,
        prefix="Training:",
    )

    # Build model.
    vocab = train_data.get_vocab()
    model = model_factory(config, config.checkpoint, *vocab)

    # Overwrite the freshly built model's parameters with the checkpoint's.
    checkpoint = CheckPoint(config.checkpoint)
    model_state = checkpoint.state_dict['model']
    model.load_state_dict(model_state)

    dump_checkpoint(model, config.save_model, ".ensemble")
def main():
    """Translate the test dataset, write predictions and report metrics.

    For every translation the first entry of ``trans.preds`` is written as a
    space-joined line to ``config.output + ".pred.txt"`` and collected
    together with ``trans.gold``; the collected lists are then passed to
    ``report_bleu`` and ``report_rouge``.

    The prediction file is managed by a ``with`` block so that it is flushed
    and closed before the metrics are reported, and also when translation
    raises part-way through.
    """
    config = Config("translate", training=False)
    if config.verbose:
        trace(config)
    torch.backends.cudnn.benchmark = True

    test_data = load_dataset(
        config.test_dataset,
        config.test_batch_size,
        config,
        prefix="Translate:",
    )
    vocab = test_data.get_vocab()

    # Statistics
    counter = count(1)
    pred_list = []
    gold_list = []

    # The output file is still opened before the model is built (as in the
    # original ordering), so an unwritable output path fails early.
    with codecs.open(config.output + ".pred.txt", 'w', 'utf-8') as pred_file:
        # Build model.
        model = model_factory(config, config.checkpoint, *vocab)
        translator = BatchTranslator(
            model, config, test_data.src_vocab, test_data.trg_vocab)

        for batch in tqdm(iter(test_data), total=test_data.num_batches):
            for trans in translator.translate(batch):
                if config.verbose:
                    # The sentence counter only advances in verbose mode,
                    # where it numbers the pretty-printed translations.
                    sent_number = next(counter)
                    trace(trans.pprint(sent_number))
                if config.plot_attn:
                    plot_attn(trans.src, trans.preds[0], trans.attns[0].cpu())
                pred_file.write(" ".join(trans.preds[0]) + "\n")
                pred_list.append(trans.preds[0])
                gold_list.append(trans.gold)

    report_bleu(gold_list, pred_list)
    report_rouge(gold_list, pred_list)
def data():
    """Compute route suggestions for the current plan and store them.

    Pulls the collections from the database, builds a ``ServerUtilities``
    instance from the active routes and the historical routes kept by
    ``restrict_history_by_date`` with a cutoff six months before today,
    strips routes either by location or by route depending on
    ``plan['byLocation']``, distributes the resulting suggestions among the
    setters and writes them back with ``update_plan``.

    Returns:
        The literal string ``'Donezors!'``.
    """
    config = Config('math')
    # TODO
    # Get plan, get locations, get setters
    # Pull the goal data
    # Return routes according to loss and novelty.
    # request.args.get('data')
    db = connect()

    # Pull the route and plan data.
    routes, fields, grades, plan, goals, settings = get_collections(db)
    historical_routes, active_routes = separate_routes(routes)

    # Historical routes are restricted using a cutoff six months back.
    cutoff_date = date.today() - relativedelta(months=+6)
    recent_history = restrict_history_by_date(historical_routes, cutoff_date)

    # Last ``config.total_routes`` historical routes.
    # NOTE(review): this slice is computed but never used below -- confirm
    # whether it was meant to replace the six-month window.
    N_historical_routes = historical_routes[-config.total_routes:]

    # The utilities object is built from the active and recent historical
    # routes and then given the goals.
    utils = ServerUtilities(active_routes, recent_history, fields, config)
    utils.convert_goals(goals)

    setter_attributes = get_setter_attributes(plan['setters'], utils)
    (max_setting, setting_time, setter_nicknames, relative_time,
     setting_mask, num_grades, max_grade, grade_index) = setter_attributes

    # Set max grade available to set.
    utils.update_max_grade(max_grade)

    # Two setting styles are handled: by location and by route.
    # NOTE(review): the original carried a commented-out check on
    # plan['discipline'] (Bouldering vs. roped climbing); no discipline
    # check is performed here.
    if plan['byLocation']:
        print('by location')
        locations = plan['locations']
        # Routes at the chosen locations, as returned by return_stripped;
        # these drive the config update before the locations are stripped.
        stripping_routes = return_stripped(active_routes, locations)
        update_config(config, goals, settings, stripping_routes)
        utils.bulk_strip_by_location(locations)
    else:
        print('by route')
        planned_routes = plan['routes']
        update_config(config, goals, settings, planned_routes)
        utils.bulk_strip_by_route(planned_routes)

    # Both styles end by asking the utilities for suggestions.
    suggested_routes, readable_routes = utils.return_suggestions()

    # Distribute the routes among the setters and persist the result.
    distributed_routes = distribute_routes(
        suggested_routes,
        readable_routes,
        max_setting,
        setting_time,
        setter_nicknames,
        relative_time,
        setting_mask,
        num_grades,
        grade_index,
    )
    update_plan(db, distributed_routes)
    return 'Donezors!'