Example No. 1
def bilm_predict():
    options, ckpt_file = load_options_latest_checkpoint('dump/bilm_pretrain')
    if 'char_cnn' in options:
        max_word_length = options['char_cnn']['max_characters_per_token']
    else:
        max_word_length = None

    kwargs = {
        'test': True,
        'shuffle_on_load': False,
    }

    vocab = load_vocab('dump/bilm_pretrain/vocab-2016-09-10.txt',
                       max_word_length)
    test_prefix = '../../deps/bilm-tf/tests/fixtures/train/data.txt'

    if options.get('bidirectional'):
        data = BidirectionalLMDataset(test_prefix, vocab, **kwargs)
    else:
        data = LMDataset(test_prefix, vocab, **kwargs)

    perplexity = test(options, ckpt_file, data, batch_size=1)
    return perplexity
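All of these snippets rely on the bilm-tf utilities. A preamble matching the project's bin/run_test.py (and assumed by the test-style examples throughout) looks like this:

# Shared preamble assumed by the test-style examples in this listing;
# these names come from the bilm-tf package.
from bilm.training import test, load_options_latest_checkpoint, load_vocab
from bilm.data import LMDataset, BidirectionalLMDataset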
Example No. 2
def main(args):
    options, ckpt_file = load_options_latest_checkpoint(args.save_dir)

    # load the vocab
    if 'char_cnn' in options:
        max_word_length = options['char_cnn']['max_characters_per_token']
    else:
        max_word_length = None
    vocab = load_vocab(args.vocab_file, max_word_length)

    test_prefix = args.test_prefix

    kwargs = {
        'test': True,
        'shuffle_on_load': False,
    }

    permute_number = options.get('permute_number', 4)

    if options.get('bidirectional'):
        data = BidirectionalLMDataset(test_prefix, vocab, **kwargs)
    elif options.get('multidirectional'):
        data = MultidirectionalLMDataset(test_prefix, vocab, permute_number, **kwargs)
    else:
        data = LMDataset(test_prefix, vocab, **kwargs)

    test(options, ckpt_file, data, batch_size=args.batch_size, permute_number=permute_number)
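For context, these main(args) entry points are driven by an argparse parser. A minimal sketch follows; the flag names mirror the attributes accessed above, but the parser itself is hypothetical and not part of the original example:

import argparse

if __name__ == '__main__':
    # hypothetical driver; flag names mirror the attributes used by main(args)
    parser = argparse.ArgumentParser(description='Compute test perplexity')
    parser.add_argument('--save_dir', help='Location of checkpoint files')
    parser.add_argument('--vocab_file', help='Vocabulary file')
    parser.add_argument('--test_prefix', help='Prefix for test files')
    parser.add_argument('--batch_size', type=int, default=256)
    main(parser.parse_args())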
Example No. 3
def main(args):
    print(args)
    print('-' * 100)
    print('Loading models and options...')
    options, ckpt_file = load_options_latest_checkpoint(args.save_dir)
    print('Loading vocabulary...')
    # load the vocab
    if 'char_cnn' in options:
        max_word_length = options['char_cnn']['max_characters_per_token']
    else:
        max_word_length = None
    vocab = load_vocab(args.vocab_file, max_word_length)

    shards = glob(args.test_prefix)
    shards.sort()
    # print(shards)
    kwargs = {
        'test': True,
        'shuffle_on_load': False,
    }
    print('Building dataset...')
    datasets = []
    for shard in shards:
        if options.get('bidirectional'):
            datasets.append(BidirectionalLMDataset(shard, vocab, **kwargs))
        else:
            datasets.append(LMDataset(shard, vocab, **kwargs))

    print('Predicting...')
    tag(options, ckpt_file, shards, datasets, batch_size=args.batch_size)
    print('-' * 100)
    print('done.')
Example No. 4
def main(args):
    options, ckpt_file = load_options_latest_checkpoint(args.save_dir)

    # load the vocab
    if 'char_cnn' in options:
        max_word_length = options['char_cnn']['max_characters_per_token']
    else:
        max_word_length = None
    # Standard call: vocab = load_vocab(args.vocab_file, max_word_length)
    # This variant also loads a stroke vocabulary, with max word length fixed at 50.
    vocab = load_vocab(args.vocab_file, args.stroke_vocab_file, 50)

    test_prefix = args.test_prefix

    kwargs = {
        'test': True,
        'shuffle_on_load': False,
    }

    if options.get('bidirectional'):
        data = BidirectionalLMDataset(test_prefix, vocab, **kwargs)
    else:
        data = LMDataset(test_prefix, vocab, **kwargs)

    test(options, ckpt_file, data, batch_size=args.batch_size)
Example No. 5
def main(args):

    if args.gpu is not None:
        n_gpus = len(args.gpu)
        set_gpu(args.gpu)
    else:
        n_gpus = 0

    options, ckpt_file = load_options_latest_checkpoint(args.save_dir)

    # load the vocab
    if 'char_cnn' in options:
        max_word_length = options['char_cnn']['max_characters_per_token']
    else:
        max_word_length = None
    vocab = load_vocab(args.vocab_files,
                       max_word_length=max_word_length,
                       polyglot=True)

    test_prefix = args.test_prefix

    kwargs = {
        'test': True,
        'shuffle_on_load': False,
    }

    data = BidirectionalPolyglotLMDataset(test_prefix, vocab, **kwargs)

    test(options, ckpt_file, data, batch_size=args.batch_size)
Example No. 6
def main(args):
    options, ckpt_file = load_options_latest_checkpoint(args.save_dir)

    if 'char_cnn' in options:
        max_word_length = options['char_cnn']['max_characters_per_token']
    else:
        max_word_length = None
    vocab = load_vocab(args.vocab_file, max_word_length)

    prefix = args.train_prefix

    kwargs = {
        'test': False,
        'shuffle_on_load': True,
    }

    if options.get('bidirectional'):
        data = BidirectionalLMDataset(prefix, vocab, **kwargs)
    else:
        data = LMDataset(prefix, vocab, **kwargs)

    tf_save_dir = args.save_dir
    tf_log_dir = args.save_dir

    # set optional inputs
    if args.n_train_tokens > 0:
        options['n_train_tokens'] = args.n_train_tokens
    if args.n_epochs > 0:
        options['n_epochs'] = args.n_epochs
    if args.batch_size > 0:
        options['batch_size'] = args.batch_size

    train(options, data, args.n_gpus, tf_save_dir, tf_log_dir,
          restart_ckpt_file=ckpt_file)
Example No. 7
def main(args):
    config = Config(args)
    options, ckpt_file = load_options_latest_checkpoint(config.save_path)

    # load the vocab
    if 'char_cnn' in options:
        max_word_length = options['char_cnn']['max_characters_per_token']
        vocab = UnicodeCharsVocabularyPad(args.vocab_file, max_word_length)
    else:
        ## Not tested yet
        vocab = VocabularyPad(args.vocab_file)

    test_path = 'data/Selectional_Restrictions/Pylkkanen2007_processed.txt'
    # test_path = 'data/Selectional_Restrictions/Warren2015_processed.txt'
    # test_path = 'data/CSR/WSC_sent.txt'

    with open(test_path) as f:
        sents = [l.rstrip() for l in f.readlines()]
    num_per_group = 2 if 'WSC' in test_path else 3
    positions = _get_changed_positions(sents, num_per_group)
    data = SentenceDataset(test_path,
                           vocab,
                           test=True,
                           shuffle_on_load=False,
                           tokenizer=nltk.word_tokenize)

    # if options.get('bidirectional'):
    #     data = BidirectionalLMDataset(test_prefix, vocab, **kwargs)
    # else:
    #     data = LMDataset(test_prefix, vocab, **kwargs)

    all_losses, all_lengths = test(options,
                                   ckpt_file,
                                   data,
                                   batch_size=args.batch_size)

    # Full score
    print('Full probability results')
    scores = all_losses.sum(axis=1) / all_lengths
    scores = np.array(scores).reshape(-1, num_per_group)
    res = scores.argmax(axis=1)
    for i in range(num_per_group):
        print(sum(res == i) / len(res))

    # Partial score
    print('Partial probability results')
    seq_mask = sequence_mask(np.array(positions) + 1, options['unroll_steps'])
    partial_losses = seq_mask * all_losses
    loss_mask = partial_losses > 0
    scores = partial_losses.sum(axis=1) / loss_mask.sum(axis=1)
    scores = np.array(scores).reshape(-1, num_per_group)
    res = scores.argmax(axis=1)
    for i in range(num_per_group):
        print(sum(res == i) / len(res))

Example No. 8
def main(args):
    ckpt_file = None
    if os.path.exists(os.path.join(args.save_dir, 'options.json')):
        # only ckpt_file is reused here; options are redefined below
        options, ckpt_file = load_options_latest_checkpoint(args.save_dir)

    # load the vocab
    vocab = load_vocab(args.vocab_file, 50)

    # define the options
    batch_size = 128  # batch size for each GPU
    n_gpus = args.n_gpus

    # number of tokens in training data (this is for the 1B Word Benchmark)
    n_train_tokens = 768648884

    options = {
        'bidirectional': True,

        'char_cnn': {
            'activation': 'relu',
            'embedding': {'dim': 16},
            'filters': [[1, 32], [2, 32], [3, 64], [4, 128],
                        [5, 256], [6, 512], [7, 1024]],
            'max_characters_per_token': 50,
            'n_characters': 261,
            'n_highway': 2,
        },

        'dropout': 0.1,

        'lstm': {
            'cell_clip': 3,
            'dim': 4096,
            'n_layers': 2,
            'proj_clip': 3,
            'projection_dim': 512,
            'use_skip_connections': True,
        },

        'all_clip_norm_val': 10.0,

        'n_epochs': 10,
        'n_train_tokens': n_train_tokens,
        'batch_size': batch_size,
        'n_tokens_vocab': vocab.size,
        'unroll_steps': 20,
        'n_negative_samples_batch': 8192,
    }

    prefix = args.train_prefix
    data = BidirectionalLMDataset(prefix, vocab, test=False, shuffle_on_load=True)

    tf_save_dir = args.save_dir
    tf_log_dir = args.save_dir
    train(options, data, n_gpus, tf_save_dir, tf_log_dir,
          restart_ckpt_file=ckpt_file)
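The options above are the standard ELMo configuration. The training length they imply follows the relation noted in the comments of Example No. 20, batch_no = n_epochs*n_train_tokens/(batch_size*unroll_steps*n_gpus). A quick check with the values above, assuming 2 GPUs (n_gpus is taken from args in the example):

# sanity check of the training length implied by the options above;
# n_gpus = 2 is an assumption for illustration
n_epochs, n_train_tokens = 10, 768648884
batch_size, unroll_steps, n_gpus = 128, 20, 2
n_batches = n_epochs * n_train_tokens / (batch_size * unroll_steps * n_gpus)
print(int(n_batches))  # ~1.5 million batches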
Example No. 9
def top_level(args):
    options, ckpt_file = load_options_latest_checkpoint(args.save_dir)

    if 'char_cnn' in options:
        max_word_length = options['char_cnn']['max_characters_per_token']
    else:
        max_word_length = None

    vocab = load_vocab(os.path.join(args.save_dir, "vocabs.txt"),
                       max_word_length)

    tf_save_dir = args.save_dir
    tf_log_dir = args.save_dir

    # set optional inputs to override the options.json
    if args.n_epochs > 0:
        options['n_epochs'] = args.n_epochs
    if args.batch_size > 0:
        options['batch_size'] = args.batch_size
    if args.n_gpus > 0:
        n_gpus = args.n_gpus
    else:
        n_gpus = options['n_gpus']

    # load train_prefixes
    if args.train_prefix_paths is not None:
        with open(args.train_prefix_paths, "r") as fd:
            train_prefixes = fd.read().split('\n')
        train_prefixes = [f for f in train_prefixes if f != ""]
        options['train_prefix_paths'] = train_prefixes
        start = 0
    else:
        train_prefixes = options['train_prefix_paths']
        start = options['milestone']

    if start >= len(train_prefixes):
        print(
            "WARNING: Finish all train_prefix_paths. Reset milestone in options."
        )
        sys.exit(0)

    # loop all train_prefix_paths
    milestone = start
    for train_prefix in train_prefixes[start:]:
        prefix = train_prefix + '/*'

        if args.n_train_tokens > 0:
            options['n_train_tokens'] = args.n_train_tokens
        else:
            options['n_train_tokens'] = get_tokens_count(prefix)

        resume(options, prefix, vocab, n_gpus, tf_save_dir, tf_log_dir,
               ckpt_file)
        milestone += 1
        options['milestone'] = milestone
        save_options(options, os.path.join(args.save_dir, "options.json"))
Example No. 10
    def test_train_bilm_chars(self):
        vocab, data, options = self._get_vocab_data_options(True, True)
        train(options, data, 1, self.tmp_dir, self.tmp_dir)

        # now test
        tf.reset_default_graph()
        options, ckpt_file = load_options_latest_checkpoint(self.tmp_dir)
        data_test, vocab_test = self._get_data(True, True, True)
        perplexity = test(options, ckpt_file, data_test, batch_size=1)
        self.assertTrue(perplexity < 20.0)
Example No. 11
def main(args):

    if args.gpu is not None:
        if ',' in args.gpu:
            args.gpu = args.gpu.split(',')
        n_gpus = len(args.gpu)
        set_gpu(args.gpu)
    else:
        n_gpus = 0

    options, ckpt_file = load_options_latest_checkpoint(args.save_dir)

    if 'char_cnn' in options:
        max_word_length = options['char_cnn']['max_characters_per_token']
    else:
        max_word_length = None
    # booleanize up front so `polyglot` is always bound
    polyglot = 'polyglot' in options or args.polyglot
    vocab = load_vocab(args.vocab_files,
                       max_word_length=max_word_length,
                       polyglot=polyglot)

    prefix = args.train_prefix

    kwargs = {
        'test': False,
        'shuffle_on_load': True,
    }

    if options.get('bidirectional'):
        if 'polyglot' in options or args.polyglot:
            data = BidirectionalPolyglotLMDataset(prefix, vocab, **kwargs)
        else:
            data = BidirectionalLMDataset(prefix, vocab, **kwargs)
    else:
        data = LMDataset(prefix, vocab, **kwargs)

    tf_save_dir = args.save_dir
    tf_log_dir = args.save_dir

    # set optional inputs
    if args.n_train_tokens > 0:
        options['n_train_tokens'] = args.n_train_tokens
    if args.n_epochs > 0:
        options['n_epochs'] = args.n_epochs
    if args.batch_size > 0:
        options['batch_size'] = args.batch_size

    train(options,
          data,
          None,
          args.n_gpus,
          tf_save_dir,
          tf_log_dir,
          restart_ckpt_file=ckpt_file)
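set_gpu, used in Examples No. 5 and No. 11, is not defined in these listings. A minimal sketch consistent with how it is called (a device id string, or a list after splitting on commas) might be:

import os

def set_gpu(gpu):
    # hypothetical helper: restrict TensorFlow to the requested device(s)
    # by setting CUDA_VISIBLE_DEVICES before the session is created
    if isinstance(gpu, (list, tuple)):
        gpu = ','.join(str(g) for g in gpu)
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)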
Example No. 12
    def test_train_skip_connections(self):
        bidirectional = True
        use_chars = False
        vocab, data, options = self._get_vocab_data_options(
            bidirectional, use_chars)
        options['lstm']['use_skip_connections'] = True
        train(options, data, 1, self.tmp_dir, self.tmp_dir)

        # now test
        tf.reset_default_graph()
        options, ckpt_file = load_options_latest_checkpoint(self.tmp_dir)
        data_test, vocab_test = self._get_data(bidirectional, use_chars, test=True)
        perplexity = test(options, ckpt_file, data_test, batch_size=1)
        self.assertTrue(perplexity < 20.0)
Example No. 13
    def test_train_shared_softmax_embedding(self):
        bidirectional = True
        use_chars = False

        vocab, data, options = self._get_vocab_data_options(
            bidirectional, use_chars, share_embedding_softmax=True)
        train(options, data, 1, self.tmp_dir, self.tmp_dir)

        # now test
        tf.reset_default_graph()
        options, ckpt_file = load_options_latest_checkpoint(self.tmp_dir)
        data_test, vocab_test = self._get_data(
            bidirectional, use_chars, test=True)
        perplexity = test(options, ckpt_file, data_test, batch_size=1)
        self.assertTrue(perplexity < 20.0)
Example No. 14
def special_words(args):
    config = Config(args)
    options, ckpt_file = load_options_latest_checkpoint(config.save_path)

    # load the vocab
    if 'char_cnn' in options:
        max_word_length = options['char_cnn']['max_characters_per_token']
        vocab = UnicodeCharsVocabularyPad(args.vocab_file, max_word_length)
    else:
        ## Not tested yet
        vocab = VocabularyPad(args.vocab_file)

    #test_path = 'data/Selectional_Restrictions/Pylkkanen2007_processed.txt'
    test_path = 'data/Selectional_Restrictions/Warren2015_processed.txt'

    # assumption: build the dataset as in Example No. 7's SentenceDataset usage
    data = SentenceDataset(test_path, vocab, test=True,
                           shuffle_on_load=False,
                           tokenizer=nltk.word_tokenize)

    all_losses, all_lengths = test(options, ckpt_file, data,
                                   batch_size=args.batch_size)
    return all_losses, all_lengths
Example No. 15
def main(args):
    options, ckpt_file = load_options_latest_checkpoint(args.save_dir)

    # load the vocab
    if 'char_cnn' in options:
        max_word_length = options['char_cnn']['max_characters_per_token']
    else:
        max_word_length = None
    vocab = load_vocab(args.vocab_file, max_word_length)

    test_prefix = args.test_prefix

    kwargs = {
        'test': True,
        'shuffle_on_load': False,
    }

    if options.get('bidirectional'):
        data = BidirectionalLMDataset(test_prefix, vocab, **kwargs)
    else:
        data = LMDataset(test_prefix, vocab, **kwargs)

    test(options, ckpt_file, data, batch_size=args.batch_size)
Example No. 16
def main(args):
    ent_num = 14541
    test_path = "/home/why2011btv/research/OpenKE/benchmarks/FB15K237/test2id.txt"
    with open(test_path, 'r') as f:
        lines = f.readlines()
        # the first line of an OpenKE id file holds the triplet count
        triplet_num = len(lines) - 1
        print("triplet_num:", triplet_num)
        test_set = np.zeros([triplet_num, 3], np.int32)
        i = 0
        for line in lines:
            a = line.split(' ')
            if len(a) > 1 and i < triplet_num:
                # reorder (head, tail, relation) to (head, relation, tail),
                # offsetting relation ids past the entity id range
                test_set[i][0] = int(a[0])
                test_set[i][1] = int(a[2]) + ent_num
                test_set[i][2] = int(a[1])
                i += 1

    options, ckpt_file = load_options_latest_checkpoint(args.save_dir)
    data = MYDataset(test_set)

    perplexity = test(options, ckpt_file, data, batch_size=2)
    return perplexity
Example No. 17
def top_level(args):
    options, ckpt_file = load_options_latest_checkpoint(args.save_dir)
    vocab_file = os.path.join(args.save_dir, 'vocabs.txt')

    # load the vocab
    if 'char_cnn' in options:
        max_word_length = options['char_cnn']['max_characters_per_token']
    else:
        max_word_length = None
    vocab = load_vocab(vocab_file, max_word_length)

    test_prefix = args.test_prefix

    kwargs = {
        'test': True,
        'shuffle_on_load': False,
    }

    if options.get('bidirectional'):
        data = BidirectionalLMDataset(test_prefix, vocab, **kwargs)
    else:
        data = LMDataset(test_prefix, vocab, **kwargs)

    test(options, ckpt_file, data, batch_size=args.batch_size)
Example No. 18
    config.gpu_options.visible_device_list = '0,1'  # expose GPUs 0 and 1

    ## replaced tf.ConfigProto with tf.compat.v1.ConfigProto
    print("====== Num GPUs Available: ",
          len(tf.config.experimental.list_physical_devices('GPU')))

    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

print('=== Loading pre-trained %s language model ...' % args.model.upper())
## size of context for rnn-based LM
k = 2
res_perplexities2 = []

if args.model == 'elmo':
    ## loading the ELMo model just once
    options, ckpt_file = load_options_latest_checkpoint(args.modelfile)
    if 'char_cnn' in options:
        max_word_length = options['char_cnn']['max_characters_per_token']
    else:
        max_word_length = None
    vocab = load_vocab(args.modelfile + '/vocab.txt', max_word_length)

    kwargs = {
        'test': True,
        'shuffle_on_load': False,
    }

    model = None

## a simple LSTM that casts next-word prediction as a classification task
## (choose from all the words in the vocabulary)
elif args.model == 'rnn':
Example No. 19
def main(args):
    is_load, load_path, save_path, budget = cuhk_prototype_tuner_v2.preprocess(
        t_id, params, args.save_dir)

    vocab = load_vocab(args.vocab_file, 50)

    batch_size = int(params['batch_size'])

    gpus_index_list = list(
        map(int, os.environ["CUDA_VISIBLE_DEVICES"].split(',')))
    n_gpus = len(gpus_index_list)

    n_train_tokens = 768648884

    sess_config = tf.compat.v1.ConfigProto(
        allow_soft_placement=True,
        inter_op_parallelism_threads=int(
            params['inter_op_parallelism_threads']),
        intra_op_parallelism_threads=int(
            params['intra_op_parallelism_threads']),
        graph_options=tf.compat.v1.GraphOptions(
            infer_shapes=params['infer_shapes'],
            place_pruned_graph=params['place_pruned_graph'],
            enable_bfloat16_sendrecv=params['enable_bfloat16_sendrecv'],
            optimizer_options=tf.compat.v1.OptimizerOptions(
                do_common_subexpression_elimination=params[
                    'do_common_subexpression_elimination'],
                max_folded_constant_in_bytes=int(
                    params['max_folded_constant']),
                do_function_inlining=params['do_function_inlining'],
                global_jit_level=params['global_jit_level'])))

    options = {
        'bidirectional': True,
        'char_cnn': {
            'activation': 'relu',
            'embedding': {'dim': 16},
            'filters': [[1, 32], [2, 32], [3, 64], [4, 128], [5, 256],
                        [6, 512], [7, 1024]],
            'max_characters_per_token': 50,
            'n_characters': 261,
            'n_highway': 2
        },
        'dropout': 0.1,
        'lstm': {
            'cell_clip': 3,
            'dim': 4096,
            'n_layers': 2,
            'proj_clip': 3,
            'projection_dim': 512,
            'use_skip_connections': True
        },
        'all_clip_norm_val': 10.0,
        'n_epochs': int(budget),  # NNI modification
        'n_train_tokens': n_train_tokens,
        'batch_size': batch_size,
        'n_tokens_vocab': vocab.size,
        'unroll_steps': 20,
        'n_negative_samples_batch': 8192,
    }
    prefix = args.train_prefix
    data = BidirectionalLMDataset(prefix,
                                  vocab,
                                  test=False,
                                  shuffle_on_load=True)
    tf_save_dir = save_path
    tf_log_dir = save_path
    if not os.path.exists(tf_save_dir):
        os.makedirs(tf_save_dir)

    if params['tf_gpu_thread_mode'] in ["global", "gpu_private", "gpu_shared"]:
        os.environ['TF_GPU_THREAD_MODE'] = params['tf_gpu_thread_mode']
    if is_load:
        load_file = os.path.join(load_path, 'model.ckpt')
        start = time.time()
        final_perplexity = train(options,
                                 data,
                                 n_gpus,
                                 gpus_index_list,
                                 tf_save_dir,
                                 tf_log_dir,
                                 sess_config,
                                 restart_ckpt_file=load_file)
        end = time.time()
        shutil.rmtree(load_path)
    else:
        start = time.time()
        final_perplexity = train(options, data, n_gpus, gpus_index_list,
                                 tf_save_dir, tf_log_dir, sess_config)
        end = time.time()
    spent_time = (end - start) / 3600.0
    if args.test_prefix != '':
        options, ckpt_file = load_options_latest_checkpoint(tf_save_dir)
        kwargs = {
            'test': True,
            'shuffle_on_load': False,
        }
        test_data = BidirectionalLMDataset(args.test_prefix, vocab, **kwargs)
        final_perplexity = test(options, ckpt_file, test_data, batch_size=128)
    report_dict = {'runtime': spent_time, 'default': final_perplexity}
    nni.report_final_result(report_dict)
Example No. 20
def top_level(args):
    if not os.path.isdir(args.save_dir):
        os.makedirs(args.save_dir)

    # define the options
    if args.config_file is None:
        args.config_file = os.path.join(current_dir,
                                        "resources/default_config.json")
    options = load_options(args.config_file)

    # load train_prefixes
    with open(args.train_prefix_paths, "r") as fd:
        train_prefixes = fd.read().split('\n')
    train_prefixes = [f for f in train_prefixes if f != ""]
    options['train_prefix_paths'] = train_prefixes

    # load the vocab
    vocab = load_vocab(args.vocab_file, 50)

    # number of tokens in training data (768648884 for the 1B Word Benchmark)
    # batch_no = n_epochs*n_train_tokens/(batch_size*unroll_steps*n_gpus)
    # e.g. 25600 tokens => 100 batches; example filtered corpus: 1330337 tokens
    if args.n_train_tokens is None:
        options['n_train_tokens'] = get_tokens_count(args.train_prefix)
    else:
        options['n_train_tokens'] = args.n_train_tokens

    options['n_tokens_vocab'] = vocab.size
    options['milestone'] = 0
    os.system("cp %s %s/vocabs.txt" % (args.vocab_file, args.save_dir))

    n_gpus = options['n_gpus']
    tf_save_dir = args.save_dir
    tf_log_dir = args.save_dir

    prefix = train_prefixes[0] + '/*'
    data = BidirectionalLMDataset(prefix,
                                  vocab,
                                  test=False,
                                  shuffle_on_load=True)

    print("options:", options)
    train(options, data, n_gpus, tf_save_dir, tf_log_dir)
    options['milestone'] = 1
    save_options(options, os.path.join(args.save_dir, "options.json"))

    if len(train_prefixes) == 1:
        return

    options, ckpt_file = load_options_latest_checkpoint(args.save_dir)

    # loop all train_prefix_paths
    milestone = 1
    for train_prefix in train_prefixes[1:]:
        prefix = train_prefix + '/*'

        if args.n_train_tokens:  # falls back to counting when None or 0
            options['n_train_tokens'] = args.n_train_tokens
        else:
            options['n_train_tokens'] = get_tokens_count(prefix)

        restarter.resume(options, prefix, vocab, n_gpus, tf_save_dir,
                         tf_log_dir, ckpt_file)
        milestone += 1
        options['milestone'] = milestone
        save_options(options, os.path.join(args.save_dir, "options.json"))
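get_tokens_count, used in Examples No. 9 and No. 20, is likewise not shown in these listings. Given that it receives a glob pattern such as train_prefix + '/*', one plausible sketch counts whitespace-separated tokens across the matched files:

from glob import glob

def get_tokens_count(prefix):
    # hypothetical helper: count whitespace-separated tokens in all
    # files matching the glob pattern
    total = 0
    for path in glob(prefix):
        with open(path) as f:
            for line in f:
                total += len(line.split())
    return total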