import argparse
import pickle
from model import NLG
from data_engine import DataEngine
from text_token import _UNK, _PAD, _BOS, _EOS
import torch
import torch.nn as nn
import numpy as np
import os
from utils import print_config, add_path
from model_utils import get_embeddings
from argument import define_arguments
from utils import get_time

# Parse command-line arguments and resolve dataset/model file paths.
_, args = define_arguments()
args = add_path(args)
if args.verbose_level > 0:
    print_config(args)

# Prefer GPU execution when CUDA is available.
use_cuda = torch.cuda.is_available()

# Build the training-set data engine from the parsed arguments.
# NOTE(review): this constructor call is truncated in the visible chunk —
# the remaining keyword arguments and closing parenthesis are off-screen.
train_data_engine = DataEngine(
    data_dir=args.data_dir,
    dataset=args.dataset,
    save_path=args.train_data_file,
    vocab_path=args.vocab_file,
    is_spacy=args.is_spacy,
    is_lemma=args.is_lemma,
    fold_attr=args.fold_attr,
    use_punct=args.use_punct,
    vocab_size=args.vocab_size,
from argument import define_arguments

# Parse arguments in "script" mode: list-valued flags become shell loop
# variables so a single invocation expands into a grid of train.py runs,
# emitted to stdout as a shell script.
parser, args = define_arguments(script=True)

# Collapse single-element lists to scalars; remember which args stay lists
# (these become nested shell "for" loops).
list_arg = set()
for arg in vars(args):
    value = getattr(args, arg)
    # isinstance (not `type(...) == list`) is the idiomatic check and also
    # accepts list subclasses.
    if isinstance(value, list):
        if len(value) == 1:
            setattr(args, arg, value[0])
        else:
            list_arg.add(arg)

# Emit one "for <arg> in <values>; do" per remaining list-valued argument,
# indenting each nesting level by one tab.
loop_cnt = 0
for arg in vars(args):
    if arg in list_arg:
        loop_cnt += 1
        attrs = getattr(args, arg)
        print("{}for {} in {}; do".format(
            "\t"*(loop_cnt-1), arg, ' '.join(map(str, attrs))))

# Innermost body: the train.py invocation, one flag per continuation line.
print("{}python3 train.py \\".format("\t"*loop_cnt))
for arg in vars(args):
    if arg in list_arg:
        # Reference the shell loop variable, e.g. `--lr ${lr}`.
        print("{}--{} {} \\".format("\t"*loop_cnt, arg,
                                    "${{{}}}".format(arg)))
    else:
        # Only emit flags whose value differs from the parser's default,
        # keeping the generated command line minimal.
        if getattr(args, arg) != parser.get_default(arg):
            print("{}--{} {} \\".format(
                "\t"*loop_cnt, arg, getattr(args, arg)))
# NOTE(review): no matching "done" lines are printed in this chunk — confirm
# they are emitted elsewhere, otherwise the generated shell script is
# syntactically incomplete.