def run(options):
    logger = get_logger()
    experiment_logger = ExperimentLogger()

    train_dataset, validation_dataset = get_train_and_validation(options)
    train_iterator = get_train_iterator(options, train_dataset)
    validation_iterator = get_validation_iterator(options, validation_dataset)

    embeddings = train_dataset['embeddings']

    logger.info('Initializing model.')
    trainer = build_net(options, embeddings, validation_iterator)

    logger.info('Model:')
    for name, p in trainer.net.named_parameters():
        logger.info('{} {}'.format(name, p.shape))

    if options.save_init:
        logger.info('Saving model (init).')
        trainer.save_model(os.path.join(options.experiment_path, 'model_init.pt'))

    if options.parse_only:
        run_parse(options, train_iterator, trainer, validation_iterator)
        sys.exit()

    run_train(options, train_iterator, trainer, validation_iterator)
def run(options):
    logger = get_logger()

    validation_dataset = get_validation_dataset(options)
    validation_iterator = get_validation_iterator(options, validation_dataset)
    word2idx = validation_dataset['word2idx']
    embeddings = validation_dataset['embeddings']
    idx2word = {v: k for k, v in word2idx.items()}

    logger.info('Initializing model.')
    trainer = build_net(options, embeddings, validation_iterator)

    # Parse
    diora = trainer.net.diora

    ## Monkey patch parsing specific methods.
    override_init_with_batch(diora)
    override_inside_hook(diora)

    ## Turn off outside pass.
    trainer.net.diora.outside = False

    ## Eval mode.
    trainer.net.eval()

    return trainer, word2idx
def run_parse(options, train_iterator, trainer, validation_iterator):
    logger = get_logger()

    validation_dataset = get_validation_dataset(options)
    validation_iterator = get_validation_iterator(options, validation_dataset)
    word2idx = validation_dataset['word2idx']
    embeddings = validation_dataset['embeddings']
    idx2word = {v: k for k, v in word2idx.items()}

    logger.info('Initializing model.')
    trainer = build_net(options, embeddings, validation_iterator)

    # Parse
    diora = trainer.net.diora

    ## Turn off outside pass.
    trainer.net.diora.outside = False

    ## Eval mode.
    trainer.net.eval()

    ## Topk predictor.
    parse_predictor = CKY(net=diora, word2idx=word2idx)

    batches = validation_iterator.get_iterator(random_seed=options.seed)

    logger.info('Beginning to parse.')

    with torch.no_grad():
        for i, batch_map in enumerate(batches):
            sentences = batch_map['sentences']
            batch_size = sentences.shape[0]
            length = sentences.shape[1]

            # Rather than skipping, just log the trees (they are trivially easy to find).
            if length <= 2:
                for j in range(batch_size):
                    example_id = batch_map['example_ids'][j]
                    tokens = sentences[j].tolist()
                    words = [idx2word[idx] for idx in tokens]
                    if length == 2:
                        o = dict(example_id=example_id, tree=(words[0], words[1]))
                    elif length == 1:
                        o = dict(example_id=example_id, tree=words[0])
                    print(json.dumps(o))
                continue

            _ = trainer.step(batch_map, train=False, compute_loss=False)

            trees = parse_predictor.parse_batch(batch_map)

            for ii, tr in enumerate(trees):
                example_id = batch_map['example_ids'][ii]
                s = [idx2word[idx] for idx in sentences[ii].tolist()]
                tr = replace_leaves(tr, s)
                o = dict(example_id=example_id, tree=tr)
                print(json.dumps(o))
def build_net(options, embeddings, batch_iterator=None):
    from diora.net.trainer import build_net

    trainer = build_net(options, embeddings, batch_iterator, random_seed=options.seed)

    logger = get_logger()
    logger.info('# of params = {}'.format(count_params(trainer.net)))

    return trainer
def configure(options):
    # Configure output paths for this experiment.
    configure_experiment(options.experiment_path, rank=options.local_rank)

    # Get logger.
    logger = get_logger()

    # Print flags.
    logger.info(stringify_flags(options))
    save_flags(options, options.experiment_path)
def __init__(self, data_source, batch_size, include_partial=False, rng=None,
             maxlen=None, length_to_size=None):
    self.data_source = data_source
    self.active = False
    if rng is None:
        rng = np.random.RandomState(seed=11)
    self.rng = rng
    self.batch_size = batch_size
    self.maxlen = maxlen
    self.include_partial = include_partial
    self.length_to_size = length_to_size
    self._batch_size_cache = {0: self.batch_size}
    self.logger = get_logger()
def context_insensitive_elmo(weights_path, options_path, word2idx, cuda=False, cache_dir=None):
    logger = get_logger()

    vocab = [w for w, i in sorted(word2idx.items(), key=lambda x: x[1])]

    validate_word2idx(word2idx)

    if cache_dir is not None:
        key = hash_vocab(vocab)
        cache_path = os.path.join(cache_dir, 'elmo_{}.npy'.format(key))
        if os.path.exists(cache_path):
            logger.info('Loading cached elmo vectors: {}'.format(cache_path))
            return load_elmo_cache(cache_path)

    if cuda:
        device = 0
    else:
        device = -1

    batch_size = 256
    nbatches = len(vocab) // batch_size + 1

    logger.info('Begin caching vectors. nbatches={} device={}'.format(nbatches, device))
    logger.info('Initialize ELMo Model.')

    # TODO: Does not support padding.
    elmo = ElmoEmbedder(options_file=options_path, weight_file=weights_path, cuda_device=device)

    vec_lst = []
    for i in tqdm(range(nbatches), desc='elmo'):
        start = i * batch_size
        batch = vocab[start:start + batch_size]
        if len(batch) == 0:
            continue
        vec = elmo.embed_sentence(batch)
        vec_lst.append(vec)

    vectors = np.concatenate([x[0] for x in vec_lst], axis=0)

    if cache_dir is not None:
        logger.info('Saving cached elmo vectors: {}'.format(cache_path))
        save_elmo_cache(cache_path, vectors)

    return vectors
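# Hypothetical usage sketch (not part of the repo): the file paths, the toy
# vocabulary, and the helper name below are assumptions for illustration only.
# It shows how context_insensitive_elmo might be called to build, and then
# cache, a per-word-type ELMo embedding matrix.
def _example_elmo_usage():
    word2idx = {'<pad>': 0, '<unk>': 1, 'the': 2, 'cat': 3}  # assumed toy vocab
    vectors = context_insensitive_elmo(
        weights_path='elmo_weights.hdf5',   # assumed local path
        options_path='elmo_options.json',   # assumed local path
        word2idx=word2idx,
        cuda=False,
        cache_dir='./elmo_cache')           # a second call loads from this cache
    # One row per vocabulary entry, ordered by index.
    assert vectors.shape[0] == len(word2idx)
    return vectors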
def run(options):
    logger = get_logger()

    validation_dataset = get_validation_dataset(options)
    validation_iterator = get_validation_iterator(options, validation_dataset)
    word2idx = validation_dataset['word2idx']
    embeddings = validation_dataset['embeddings']
    idx2word = {v: k for k, v in word2idx.items()}

    logger.info('Initializing model.')
    trainer = build_net(options, embeddings, validation_iterator)

    # Parse
    diora = trainer.net.encoder

    ## Monkey patch parsing specific methods.
    override_init_with_batch(diora)
    override_inside_hook(diora)

    ## Outside pass is left enabled here (the usual "turn off" line is commented out).
    # trainer.net.encoder.outside = False

    ## Eval mode.
    trainer.net.eval()

    ## Parse predictor.
    parse_predictor = CKY(net=diora, word2idx=word2idx)

    batches = validation_iterator.get_iterator(random_seed=options.seed)

    output_path1 = os.path.abspath(os.path.join(options.experiment_path, 'parse_mnli1.jsonl'))
    output_path2 = os.path.abspath(os.path.join(options.experiment_path, 'parse_mnli2.jsonl'))

    logger.info('Beginning.')
    logger.info('Writing output to = {}'.format(output_path1))
    logger.info('Writing output to = {}'.format(output_path2))

    # Note: both parses are currently written to a single file (output_path1).
    f = open(output_path1, 'w')

    with torch.no_grad():
        for i, batch_map in tqdm(enumerate(batches)):
            sentences1 = batch_map['sentences_1']
            sentences2 = batch_map['sentences_2']
            batch_size = sentences1.shape[0]
            length = sentences1.shape[1]

            # Skip very short sentences.
            if length <= 2:
                continue

            _ = trainer.step(batch_map, train=False, compute_loss=False)

            trees1 = parse_predictor.parse_batch(sentences1)
            trees2 = parse_predictor.parse_batch(sentences2)

            for ii, (tr1, tr2) in enumerate(zip(trees1, trees2)):
                example_id = batch_map['example_ids'][ii]
                s1 = [idx2word[idx] for idx in sentences1[ii].tolist()]
                s2 = [idx2word[idx] for idx in sentences2[ii].tolist()]
                tr1 = replace_leaves(tr1, s1)
                tr2 = replace_leaves(tr2, s2)
                if options.postprocess:
                    tr1 = postprocess(tr1, s1)
                    tr2 = postprocess(tr2, s2)
                o = collections.OrderedDict(example_id=example_id, sentence1=tr1, sentence2=tr2)
                f.write(json.dumps(o) + '\n')

    f.close()
def run(options):
    logger = get_logger()

    validation_dataset = get_validation_dataset(options)
    validation_iterator = get_validation_iterator(options, validation_dataset)
    word2idx = validation_dataset['word2idx']
    embeddings = validation_dataset['embeddings']
    idx2word = {v: k for k, v in word2idx.items()}

    logger.info('Initializing model.')
    trainer = build_net(options, embeddings, validation_iterator)
    diora = trainer.net.diora

    tree_helper = TreeHelper(diora, word2idx)
    tree_helper.init(options)

    csv_helper = CSVHelper()

    ## Eval mode.
    trainer.net.eval()

    batches = validation_iterator.get_iterator(random_seed=options.seed)

    meta_output_path = os.path.abspath(os.path.join(options.experiment_path, 'vectors.csv'))
    vec_output_path = os.path.abspath(os.path.join(options.experiment_path, 'vectors.npy'))

    logger.info('Beginning.')
    logger.info('Writing vectors to = {}'.format(vec_output_path))
    logger.info('Writing metadata to = {}'.format(meta_output_path))

    f_csv = open(meta_output_path, 'w')
    f_vec = open(vec_output_path, 'ab')
    csv_helper.write_header(f_csv)

    with torch.no_grad():
        for i, batch_map in tqdm(enumerate(batches)):
            sentences = batch_map['sentences']
            batch_size = sentences.shape[0]
            length = sentences.shape[1]

            # Skip very short sentences.
            if length <= 2:
                continue

            _ = trainer.step(batch_map, train=False, compute_loss=False)

            if options.parse_mode == 'all-spans':
                for ii in range(batch_size):
                    example_id = batch_map['example_ids'][ii]
                    for level in range(length):
                        size = level + 1
                        for pos in range(length - level):
                            # metadata
                            csv_helper.write_row(f_csv, collections.OrderedDict(
                                example_id=example_id,
                                position=str(pos),
                                size=str(size)))
                inside_vectors = diora.inside_h.view(-1, options.hidden_dim)
                outside_vectors = diora.outside_h.view(-1, options.hidden_dim)
            else:
                trees, spans = tree_helper.get_trees_for_batch(batch_map, options)

                batch_index = []
                cell_index = []
                offset_cache = diora.index.get_offset(length)

                for ii, sp_lst in enumerate(spans):
                    example_id = batch_map['example_ids'][ii]
                    for pos, size in sp_lst:
                        # metadata
                        csv_helper.write_row(f_csv, collections.OrderedDict(
                            example_id=example_id,
                            position=str(pos),
                            size=str(size)))
                        # for vectors
                        level = size - 1
                        cell = offset_cache[level] + pos
                        batch_index.append(ii)
                        cell_index.append(cell)

                inside_vectors = diora.inside_h[batch_index, cell_index]
                assert inside_vectors.shape == (len(batch_index), options.hidden_dim)
                outside_vectors = diora.outside_h[batch_index, cell_index]
                assert outside_vectors.shape == (len(batch_index), options.hidden_dim)

            vectors = np.concatenate([inside_vectors, outside_vectors], axis=1)
            np.savetxt(f_vec, vectors)

    f_csv.close()
    f_vec.close()
def run(options):
    logger = get_logger()

    validation_dataset = get_validation_dataset(options)
    validation_iterator = get_validation_iterator(options, validation_dataset)
    word2idx = validation_dataset['word2idx']
    embeddings = validation_dataset['embeddings']
    idx2word = {v: k for k, v in word2idx.items()}

    logger.info('Initializing model.')
    trainer = build_net(options, embeddings, validation_iterator)

    # Parse
    diora = trainer.net.diora

    ## Monkey patch parsing specific methods.
    override_init_with_batch(diora)
    override_inside_hook(diora)

    ## Turn off outside pass.
    trainer.net.diora.outside = False

    ## Eval mode.
    trainer.net.eval()

    ## Parse predictor.
    parse_predictor = CKY(net=diora, word2idx=word2idx)

    batches = validation_iterator.get_iterator(random_seed=options.seed)

    output_path = os.path.abspath(os.path.join(options.experiment_path, 'parse.jsonl'))

    logger.info('Beginning.')
    logger.info('Writing output to = {}'.format(output_path))

    f = open(output_path, 'w')

    with torch.no_grad():
        for i, batch_map in tqdm(enumerate(batches)):
            sentences = batch_map['sentences']
            batch_size = sentences.shape[0]
            length = sentences.shape[1]

            # Skip very short sentences.
            if length <= 2:
                continue

            _ = trainer.step(batch_map, train=False, compute_loss=False)

            trees = parse_predictor.parse_batch(batch_map)

            for ii, tr in enumerate(trees):
                example_id = batch_map['example_ids'][ii]
                s = [idx2word[idx] for idx in sentences[ii].tolist()]
                tr = replace_leaves(tr, s)
                if options.postprocess:
                    tr = postprocess(tr, s)
                o = collections.OrderedDict(example_id=example_id, tree=tr)
                f.write(json.dumps(o) + '\n')

    f.close()
def run_train(options, train_iterator, trainer, validation_iterator):
    logger = get_logger()
    experiment_logger = ExperimentLogger()

    logger.info('Running train.')

    seeds = generate_seeds(options.max_epoch, options.seed)

    step = 0

    for epoch, seed in zip(range(options.max_epoch), seeds):
        # --- Train --- #

        seed = seeds[epoch]

        logger.info('epoch={} seed={}'.format(epoch, seed))

        def myiterator():
            it = train_iterator.get_iterator(random_seed=seed)
            count = 0
            for batch_map in it:
                # TODO: Skip short examples (optionally).
                if batch_map['length'] <= 2:
                    continue
                yield count, batch_map
                count += 1

        for batch_idx, batch_map in myiterator():
            if options.finetune and step >= options.finetune_after:
                trainer.freeze_diora()

            result = trainer.step(batch_map)

            experiment_logger.record(result)

            if step % options.log_every_batch == 0:
                experiment_logger.log_batch(epoch, step, batch_idx, batch_size=options.batch_size)

            # -- Periodic Checkpoints -- #

            if not options.multigpu or options.local_rank == 0:
                if step % options.save_latest == 0 and step >= options.save_after:
                    logger.info('Saving model (periodic).')
                    trainer.save_model(
                        os.path.join(options.experiment_path, 'model_periodic.pt'))
                    save_experiment(
                        os.path.join(options.experiment_path, 'experiment_periodic.json'), step)

                if step % options.save_distinct == 0 and step >= options.save_after:
                    logger.info('Saving model (distinct).')
                    trainer.save_model(
                        os.path.join(options.experiment_path, 'model.step_{}.pt'.format(step)))
                    save_experiment(
                        os.path.join(options.experiment_path, 'experiment.step_{}.json'.format(step)), step)

            del result

            step += 1

        experiment_logger.log_epoch(epoch, step)

        if options.max_step is not None and step >= options.max_step:
            logger.info('Max-Step={} Quitting.'.format(options.max_step))
            sys.exit()
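# Illustrative sketch only (not necessarily the repo's generate_seeds): one
# deterministic way to derive a distinct shuffling seed per epoch from the base
# seed, matching how run_train consumes `seeds` above. The function name is an
# assumption made for illustration.
def _generate_seeds_sketch(n_epochs, seed):
    import numpy as np  # local import keeps the sketch self-contained
    rng = np.random.RandomState(seed)
    # One integer seed per epoch; reproducible given the same base seed.
    return rng.randint(0, 2 ** 16, size=n_epochs).tolist()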
def __init__(self):
    super(ExperimentLogger, self).__init__()
    self.logger = get_logger()
    self.A = None
    self.c = Counter()
def read_glove(filename, word2idx):
    """
    Two cases:

    1. The word2idx has already been filtered according to the embedding vocabulary.
    2. The word2idx is derived solely from the raw text data.

    """
    logger = get_logger()

    glove_vocab = set()
    size = None

    validate_word2idx(word2idx)

    logger.info('Reading Glove Vocab.')

    with open(filename) as f:
        for i, line in enumerate(f):
            word, vec = line.split(' ', 1)
            glove_vocab.add(word)
            if i == 0:
                size = len(vec.strip().split(' '))

    new_vocab = set.intersection(set(word2idx.keys()), glove_vocab)
    new_vocab.discard(PADDING_TOKEN)
    new_vocab.discard(UNK_TOKEN)

    if word2idx.get(EXISTING_VOCAB_TOKEN, None) == 2:
        new_word2idx = word2idx.copy()
        logger.info('Using existing vocab mapping.')
    else:
        new_word2idx = OrderedDict()
        new_word2idx[PADDING_TOKEN] = len(new_word2idx)
        new_word2idx[UNK_TOKEN] = len(new_word2idx)
        new_word2idx[EXISTING_VOCAB_TOKEN] = len(new_word2idx)
        for w, _ in word2idx.items():
            if w in new_word2idx:
                continue
            new_word2idx[w] = len(new_word2idx)
        logger.info('Creating new mapping.')

    logger.info('glove-vocab-size={} vocab-size={} intersection-size={} (-{})'.format(
        len(glove_vocab), len(word2idx), len(new_vocab), len(word2idx) - len(new_vocab)))

    embeddings = np.zeros((len(new_word2idx), size), dtype=np.float32)

    logger.info('Reading Glove Embeddings.')

    with open(filename) as f:
        for line in f:
            word, vec = line.strip().split(' ', 1)

            # Compare by equality, not identity, against the special tokens.
            if word == PADDING_TOKEN or word == UNK_TOKEN:
                continue

            if word in new_vocab and word not in new_word2idx:
                raise ValueError

            if word not in new_word2idx:
                continue

            word_id = new_word2idx[word]
            vec = np.fromstring(vec, dtype=float, sep=' ')
            embeddings[word_id] = vec

    validate_word2idx(new_word2idx)

    return embeddings, new_word2idx
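# Hypothetical usage sketch (the file path and toy vocabulary are assumptions,
# and the helper name is invented for illustration): read_glove intersects the
# corpus vocabulary with the GloVe vocabulary and returns an embedding matrix
# aligned with the returned word2idx mapping.
def _example_read_glove_usage():
    word2idx = OrderedDict([
        (PADDING_TOKEN, 0), (UNK_TOKEN, 1), ('the', 2), ('cat', 3)])
    embeddings, new_word2idx = read_glove('glove.840B.300d.txt', word2idx)  # assumed path
    # One embedding row per entry in the (possibly re-built) mapping.
    assert embeddings.shape[0] == len(new_word2idx)
    return embeddings, new_word2idx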
def __init__(self, reader):
    super(ReaderManager, self).__init__()
    self.reader = reader
    self.logger = get_logger()
def run(options):
    logger = get_logger()

    validation_dataset = get_validation_dataset(options)
    validation_iterator = get_validation_iterator(options, validation_dataset)
    word2idx = validation_dataset['word2idx']
    embeddings = validation_dataset['embeddings']
    idx2word = {v: k for k, v in word2idx.items()}

    logger.info('Initializing model.')
    trainer = build_net(options, embeddings, validation_iterator)
    diora = trainer.net.diora

    # 1. Get all relevant phrase vectors.

    dtype = {
        'example_ids': 'list',
        'labels': 'list',
        'positions': 'list',
        'sizes': 'list',
        'phrases': 'list',
        'inside': 'torch',
        'outside': 'torch',
    }
    batch_recorder = BatchRecorder(dtype=dtype)

    ## Eval mode.
    trainer.net.eval()

    batches = validation_iterator.get_iterator(random_seed=options.seed)

    logger.info('Beginning to embed phrases.')

    with torch.no_grad():
        for i, batch_map in enumerate(batches):
            sentences = batch_map['sentences']
            batch_size = sentences.shape[0]
            length = sentences.shape[1]

            # Skip very short examples.
            if length <= 2:
                continue

            _ = trainer.step(batch_map, train=False, compute_loss=False)

            entity_labels = batch_map['entity_labels']
            if len(entity_labels) == 0:
                continue
            try:
                batch_index, positions, sizes, labels = get_cell_index(entity_labels)
            except Exception:
                continue

            # Skip short phrases.
            batch_index = [x for x, y in zip(batch_index, sizes) if y >= 2]
            positions = [x for x, y in zip(positions, sizes) if y >= 2]
            labels = [x for x, y in zip(labels, sizes) if y >= 2]
            sizes = [y for y in sizes if y >= 2]

            cell_index = (batch_index, positions, sizes)

            batch_result = {}
            batch_result['example_ids'] = [batch_map['example_ids'][idx] for idx in cell_index[0]]
            batch_result['labels'] = labels
            batch_result['positions'] = cell_index[1]
            batch_result['sizes'] = cell_index[2]
            batch_result['phrases'] = get_many_phrases(sentences, *cell_index)
            batch_result['inside'] = get_many_cells(diora, diora.inside_h, *cell_index)
            batch_result['outside'] = get_many_cells(diora, diora.outside_h, *cell_index)

            batch_recorder.record(**batch_result)

    result = batch_recorder.get_flattened_result()

    # 2. Build an index of nearest neighbors.

    vectors = np.concatenate([result['inside'], result['outside']], axis=1)
    normalize_L2(vectors)

    index = Index(dim=vectors.shape[1])
    index.add(vectors)
    index.cache(vectors, options.k_candidates)

    # 3. Print a summary.

    example_ids = result['example_ids']
    phrases = result['phrases']
    labels = result['labels']

    assert len(example_ids) == len(phrases)
    assert len(example_ids) == vectors.shape[0]

    def stringify(phrase):
        return ' '.join([idx2word[idx] for idx in phrase])

    prec_1 = []
    prec_10 = []
    prec_100 = []
    for i in range(vectors.shape[0]):
        topk = []
        corr_lab = 0
        for j, score in index.topk(i, options.k_candidates):
            # Skip same example.
            if example_ids[i] == example_ids[j]:
                continue
            # Skip string match.
            if phrases[i] == phrases[j]:
                continue
            topk.append((j, score))
            corr_lab += 1. * (labels[i] == labels[j])
            if len(topk) == 1:
                prec_1.append(corr_lab)
            elif len(topk) == 10:
                prec_10.append(corr_lab)
            elif len(topk) == 100:
                prec_100.append(corr_lab)
            if len(topk) == options.k_top:
                break
        assert len(topk) == options.k_top, 'Did not find enough valid candidates.'

        # Print.
        # print('[query] example_id={} phrase={} lab={}'.format(
        #     example_ids[i], stringify(phrases[i]), labels[i]))
        # for rank, (j, score) in enumerate(topk[:2]):
        #     print('rank={} score={:.3f} example_id={} phrase={} lab={}'.format(
        #         rank, score, example_ids[j], stringify(phrases[j]), labels[j]))

    print(np.mean(prec_1), np.mean(prec_10) / 10)
def run_train(options, train_iterator, trainer, validation_iterator):
    logger = get_logger()
    experiment_logger = ExperimentLogger()

    logger.info('Running train.')

    seeds = generate_seeds(options.max_epoch, options.seed)

    step = 0

    # Added: helpers for the in-training parsing evaluation.
    idx2word = {v: k for k, v in train_iterator.word2idx.items()}
    parse_predictor = CKY(net=trainer.net.diora, word2idx=train_iterator.word2idx)

    for epoch, seed in zip(range(options.max_epoch), seeds):
        # --- Train --- #

        # Added: running span-level precision/recall counters for this epoch.
        precision = 0
        recall = 0
        total_len = 0
        count_des = 0

        seed = seeds[epoch]

        logger.info('epoch={} seed={}'.format(epoch, seed))

        def myiterator():
            it = train_iterator.get_iterator(random_seed=seed)
            count = 0
            for batch_map in it:
                # TODO: Skip short examples (optionally).
                if batch_map['length'] <= 2:
                    continue
                yield count, batch_map
                count += 1

        for batch_idx, batch_map in myiterator():
            if options.finetune and step >= options.finetune_after:
                trainer.freeze_diora()

            result = trainer.step(batch_map)

            # Added: parse the current batch and score it against the gold trees.
            trainer.net.eval()
            sentences = batch_map['sentences']
            trees = parse_predictor.parse_batch(batch_map)
            o_list = []
            for ii, tr in enumerate(trees):
                example_id = batch_map['example_ids'][ii]
                s = [idx2word[idx] for idx in sentences[ii].tolist()]
                tr = replace_leaves(tr, s)
                o = dict(example_id=example_id, tree=tr)
                o_list.append(o['tree'])

                if isinstance(batch_map['parse_tree'][ii], str):
                    parse_tree_tuple = str_to_tuple(batch_map['parse_tree'][ii])
                else:
                    parse_tree_tuple = batch_map['parse_tree'][ii]

                o_spans = tree_to_spans(o['tree'])
                batch_spans = tree_to_spans(parse_tree_tuple[0])
                p, r, t = precision_and_recall(batch_spans, o_spans)
                precision += p
                recall += r
                total_len += t
                # print((2 * precision * recall) / (total_len * (precision + recall)))
            trainer.net.train()

            experiment_logger.record(result)

            if step % options.log_every_batch == 0:
                experiment_logger.log_batch(epoch, step, batch_idx, batch_size=options.batch_size)

            # -- Periodic Checkpoints -- #

            if not options.multigpu or options.local_rank == 0:
                if step % options.save_latest == 0 and step >= options.save_after:
                    logger.info('Saving model (periodic).')
                    trainer.save_model(
                        os.path.join(options.experiment_path, 'model_periodic.pt'))
                    save_experiment(
                        os.path.join(options.experiment_path, 'experiment_periodic.json'), step)

                if step % options.save_distinct == 0 and step >= options.save_after:
                    logger.info('Saving model (distinct).')
                    trainer.save_model(
                        os.path.join(options.experiment_path, 'model.step_{}.pt'.format(step)))
                    save_experiment(
                        os.path.join(options.experiment_path, 'experiment.step_{}.json'.format(step)), step)

            del result

            step += 1

        # Added: report the accumulated counts for this epoch.
        print(precision, recall, total_len)
        print(precision / total_len, recall / total_len)
        print(count_des)

        experiment_logger.log_epoch(epoch, step)

        if options.max_step is not None and step >= options.max_step:
            logger.info('Max-Step={} Quitting.'.format(options.max_step))
            sys.exit()
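# Illustrative sketch only (not the repo's tree_to_spans): one way to enumerate
# (position, size) spans from a binary tree of nested tuples, which is the
# representation the CKY parse_predictor above appears to produce. The exact
# span convention (e.g. whether the root span is counted) is an assumption.
def _tree_to_spans_sketch(tree):
    spans = []

    def visit(node, pos):
        # Leaves (single tokens) cover one position and contribute no span.
        if not isinstance(node, (tuple, list)):
            return 1
        size = 0
        for child in node:
            size += visit(child, pos + size)
        spans.append((pos, size))
        return size

    visit(tree, 0)
    return spans

# Example: _tree_to_spans_sketch((('the', 'cat'), 'sat')) -> [(0, 2), (0, 3)]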
def run(options):
    logger = get_logger()

    validation_dataset = get_validation_dataset(options)
    validation_iterator = get_validation_iterator(options, validation_dataset)
    word2idx = validation_dataset['word2idx']
    embeddings = validation_dataset['embeddings']
    idx2word = {v: k for k, v in word2idx.items()}

    logger.info('Initializing model.')
    trainer = build_net(options, embeddings, validation_iterator)
    diora = trainer.net.diora

    # 1. Get all relevant phrase vectors.

    dtype = {
        'example_ids': 'list',
        'labels': 'list',
        'positions': 'list',
        'sizes': 'list',
        'phrases': 'list',
        'inside': 'torch',
        'outside': 'torch',
    }
    batch_recorder = BatchRecorder(dtype=dtype)

    # Eval mode.
    trainer.net.eval()

    batches = validation_iterator.get_iterator(random_seed=options.seed)

    logger.info('Beginning to embed phrases.')

    strings = []

    with torch.no_grad():
        for i, batch_map in enumerate(batches):
            sentences = batch_map['sentences']
            length = sentences.shape[1]

            # Skips very short examples.
            if length <= 2:
                continue

            strings.extend([
                "".join([idx2word[idx] for idx in x]) for x in sentences.numpy()
            ])

            trainer.step(batch_map, train=False, compute_loss=False)

            batch_result = {}
            batch_result['inside'] = diora.inside_h[:, -1]
            batch_result['outside'] = diora.outside_h[:, -1]

            batch_recorder.record(**batch_result)

    result = batch_recorder.get_flattened_result()

    # 2. Build an index of nearest neighbors.

    vectors = np.concatenate([result['inside'], result['outside']], axis=1)
    print(len(strings), vectors.shape)

    r = Reach(vectors, strings)

    for s in strings:
        print(s)
        print(r.most_similar(s))
def build_net(options, embeddings=None, batch_iterator=None, random_seed=None):
    logger = get_logger()

    lr = options.lr
    size = options.hidden_dim
    k_neg = options.k_neg
    margin = options.margin
    normalize = options.normalize
    input_dim = embeddings.shape[1]
    cuda = options.cuda
    rank = options.local_rank
    ngpus = 1

    if cuda and options.multigpu:
        ngpus = torch.cuda.device_count()
        os.environ['MASTER_ADDR'] = options.master_addr
        os.environ['MASTER_PORT'] = options.master_port
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

    # Embed
    embedding_layer = nn.Embedding.from_pretrained(torch.from_numpy(embeddings), freeze=True)
    embed = Embed(embedding_layer, input_size=input_dim, size=size)

    # Diora
    if options.arch == 'treelstm':
        diora = DioraTreeLSTM(size, outside=True, normalize=normalize, compress=False)
    elif options.arch == 'mlp':
        diora = DioraMLP(size, outside=True, normalize=normalize, compress=False)
    elif options.arch == 'mlp-shared':
        diora = DioraMLPShared(size, outside=True, normalize=normalize, compress=False)
    else:
        raise ValueError('Unknown architecture: {}'.format(options.arch))

    # Loss
    loss_funcs = get_loss_funcs(options, batch_iterator, embedding_layer)

    # Net
    net = Net(embed, diora, loss_funcs=loss_funcs)

    # Load model.
    if options.load_model_path is not None:
        logger.info('Loading model: {}'.format(options.load_model_path))
        Trainer.load_model(net, options.load_model_path)

    # CUDA-support
    if cuda:
        if options.multigpu:
            torch.cuda.set_device(options.local_rank)
        net.cuda()
        diora.cuda()

    if cuda and options.multigpu:
        net = torch.nn.parallel.DistributedDataParallel(
            net, device_ids=[rank], output_device=rank)

    # Trainer
    trainer = Trainer(net, k_neg=k_neg, ngpus=ngpus, cuda=cuda)
    trainer.rank = rank
    trainer.experiment_name = options.experiment_name  # for multigpu cleanup
    trainer.init_optimizer(optim.Adam, dict(lr=lr, betas=(0.9, 0.999), eps=1e-8))

    return trainer
def __init__(self, net, word2idx):
    super(ParsePredictor, self).__init__()
    self.net = net
    self.word2idx = word2idx
    self.idx2word = {v: k for k, v in word2idx.items()}
    self.logger = get_logger()