Code example #1
    def test_add_from_counter_n_most_values_with_default(self):
        vocab = Vocabulary()
        new_token_to_id = Counter({"a": 3, "b": 1, "c": 4, "d": 1})
        vocab.add_from_counter(
            "token_to_id", new_token_to_id, n_most_values=4, add_values=["<SOS>", "<EOS>"],
        )
        self.assertDictEqual(vocab.token_to_id, {"<SOS>": 0, "<EOS>": 1, "c": 2, "a": 3})
Code example #2
def load_datasets(train_dataset_params: dict, validation_dataset_params: dict):
    # load PyTorch ``Dataset`` objects for the train & validation sets
    train_dataset = TwitterDataset(**train_dataset_params)
    validation_dataset = TwitterDataset(**validation_dataset_params)

    # use tokens and tags in the training set to create `Vocabulary` objects
    token_vocab = Vocabulary(train_dataset.get_tokens_list(),
                             add_unk_token=True)
    tag_vocab = Vocabulary(train_dataset.get_tags_list())

    # add `Vocabulary` objects to datasets for tokens/tags to ID mapping
    train_dataset.set_vocab(token_vocab, tag_vocab)
    validation_dataset.set_vocab(token_vocab, tag_vocab)

    return train_dataset, validation_dataset
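
A hypothetical call site for load_datasets might look like the following; the TwitterDataset keyword arguments (here a single csv_path) are placeholders and depend on the actual dataset class.

# Hypothetical usage; the TwitterDataset arguments are assumptions.
train_params = {"csv_path": "data/twitter_train.csv"}
validation_params = {"csv_path": "data/twitter_valid.csv"}
train_dataset, validation_dataset = load_datasets(train_params, validation_params)
# Both datasets now share the token/tag Vocabulary objects built from the training split.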
Code example #3
def train(dataset_name: str,
          model_name: str,
          expt_dir: str,
          data_folder: str,
          num_workers: int = 0,
          is_test: bool = False,
          resume_from_checkpoint: str = None):
    seed_everything(SEED)
    dataset_main_folder = data_folder
    vocab = Vocabulary.load(join(dataset_main_folder, "vocabulary.pkl"))

    if model_name == "code2seq":
        config_function = get_code2seq_test_config if is_test else get_code2seq_default_config
        config = config_function(dataset_main_folder)
        model = Code2Seq(config, vocab, num_workers)
        model.half()
    #elif model_name == "code2class":
    #	config_function = get_code2class_test_config if is_test else get_code2class_default_config
    #	config = config_function(dataset_main_folder)
    #	model = Code2Class(config, vocab, num_workers)
    else:
        raise ValueError(f"Model {model_name} is not supported")

    # define logger
    wandb_logger = WandbLogger(project=f"{model_name}-{dataset_name}",
                               log_model=True,
                               offline=True)
    wandb_logger.watch(model)
    # define model checkpoint callback
    model_checkpoint_callback = ModelCheckpoint(
        filepath=join(expt_dir, "{epoch:02d}-{val_loss:.4f}"),
        period=config.hyperparams.save_every_epoch,
        save_top_k=3,
    )
    # define early stopping callback
    early_stopping_callback = EarlyStopping(
        patience=config.hyperparams.patience, verbose=True, mode="min")
    # use gpu if it exists
    gpu = 1 if torch.cuda.is_available() else None
    # define learning rate logger
    lr_logger = LearningRateLogger()
    trainer = Trainer(
        max_epochs=20,
        gradient_clip_val=config.hyperparams.clip_norm,
        deterministic=True,
        check_val_every_n_epoch=config.hyperparams.val_every_epoch,
        row_log_interval=config.hyperparams.log_every_epoch,
        logger=wandb_logger,
        checkpoint_callback=model_checkpoint_callback,
        early_stop_callback=early_stopping_callback,
        resume_from_checkpoint=resume_from_checkpoint,
        gpus=gpu,
        callbacks=[lr_logger],
        reload_dataloaders_every_epoch=True,
    )
    trainer.fit(model)
    trainer.save_checkpoint(join(expt_dir, 'Latest.ckpt'))

    trainer.test()
Code example #4
    def build_data_loader(self):
        self.vocabulary = Vocabulary()
        self.config.vocab_size = self.vocabulary.get_num_words()
        assert self.config.num_classes == self.vocabulary.get_num_topics(
        ), "class number doesn't match topic number"

        self.train_loader = TopicClassDataLoader(
            vocabulary=self.vocabulary,
            split='train',
            batch_size=self.config.batch_size,
            num_workers=self.config.num_workers)

        self.valid_loader = TopicClassDataLoader(
            vocabulary=self.vocabulary,
            split='valid',
            batch_size=1,
            num_workers=self.config.num_workers)
Code example #5
def main(config):
    print("loading the best model...")
    vocabulary = Vocabulary()
    vocab_size = vocabulary.get_num_words()

    model = TopicClassCNN(vocab_size=vocab_size,
                          emb_size=config.emb_size,
                          dropout=config.dropout,
                          kernel_sizes=config.kernel_sizes,
                          num_feat_maps=config.num_feat_maps,
                          num_classes=config.num_classes)

    checkpoint = torch.load(config.checkpoint_dir)
    model.load_state_dict(checkpoint['state_dict'])

    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    model = model.to(device)

    filename = '/home/haochen/Projects/cnn-topic-classification/data/topicclass_valid.txt'

    acc, confusion_matrix, results = eval_validation_set(
        filename, vocabulary.w2i, vocabulary.t2i, model, device)

    print("\n {}".format(confusion_matrix))
    print("val acc: {0:.4f}".format(acc * 100))

    model_name = os.path.split(config.checkpoint_dir)[-1]
    txt_name = os.path.join(
        os.path.split(os.path.split(config.checkpoint_dir)[0])[0], 'results',
        'val_acc_{}.txt'.format(model_name))
    result_str = "val acc: {0:.4f}".format(acc * 100)
    with open(txt_name, 'a') as f:
        result_str += '\n'
        f.write(result_str)

    dev_txt_name = os.path.join(
        os.path.split(os.path.split(config.checkpoint_dir)[0])[0], 'results',
        'dev_result.txt')
    with open(dev_txt_name, 'a') as f:
        for result in results:
            f.write(result + '\n')
Code example #6
def _vocab_from_counters(
    config: PreprocessingConfig, token_counter: Counter, target_counter: Counter, type_counter: Counter
) -> Vocabulary:
    vocab = Vocabulary()
    names_additional_tokens = [SOS, EOS, PAD, UNK] if config.wrap_name else [PAD, UNK]
    vocab.add_from_counter("token_to_id", token_counter, config.subtoken_vocab_max_size, names_additional_tokens)
    target_additional_tokens = [SOS, EOS, PAD, UNK] if config.wrap_target else [PAD, UNK]
    vocab.add_from_counter("label_to_id", target_counter, config.target_vocab_max_size, target_additional_tokens)
    paths_additional_tokens = [SOS, EOS, PAD, UNK] if config.wrap_path else [PAD, UNK]
    vocab.add_from_counter("type_to_id", type_counter, -1, paths_additional_tokens)
    return vocab
Code example #7
    def __init__(self, tf_session, level, train_text_path, max_vocab_size,
                 neurons_per_layer, num_layers, batch_size, num_timesteps,
                 save_dir):
        """Creates a new language model without training it.

        Args:
            tf_session (tf.Session): The session to run the TF Variable initializer in.
            level (str): The level for tokenizing the text - either "char" or "word".
            train_text_path (str): Path to the .txt file containing the training text.
            max_vocab_size (int): Maximum size of the vocabulary to use to translate tokens into
                ids. If the text contains a lower number of distinct tokens than the max_vocab_size,
                the vocabulary will be smaller than max_vocab_size.
            neurons_per_layer (int): Number of neurons / units per layer.
            num_layers (int): Number of LSTM layers.
            batch_size (int): The batch size to use for training and evaluation.
            num_timesteps (int): The number of time steps to unroll the LSTM for. The
                back-propagation / training will only go num_timesteps steps into the past. A higher
                number makes the LSTM remember information from tokens that it read earlier.
            save_dir (str): Path to the directory where the model's learned parameters, the
                vocabulary and training statistics will be saved. Use None to disable saving
                anything.
        """
        self.train_text_path = train_text_path
        self._batch_size = batch_size
        self.num_timesteps = num_timesteps
        self.save_dir = save_dir
        self.saved = False

        if self.save_dir and not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)

        # Build the vocabulary to determine the actual vocabulary size
        self.vocab = Vocabulary.load_or_create(save_dir, train_text_path,
                                               max_vocab_size, level)
        if self.save_dir:
            self.vocab.save_to_dir(save_dir)

        num_params = get_num_params(self.vocab.get_size(), num_layers,
                                    neurons_per_layer)
        print('vocab_size=%d' % self.vocab.get_size())
        print('num_layers=%d' % num_layers)
        print('neurons_per_layer=%d' % neurons_per_layer)
        print('num_params=%d' % num_params)

        # Reload or Create a new TF model
        self.tf_model = GeneratingLSTM(vocab_size=self.vocab.get_size(),
                                       neurons_per_layer=neurons_per_layer,
                                       num_layers=num_layers,
                                       max_batch_size=batch_size)

        if save_dir and restore_possible(save_dir):
            ckpt = tf.train.latest_checkpoint(save_dir)
            self.tf_model.saver.restore(tf_session, ckpt)
        else:
            tf_session.run(tf.global_variables_initializer())
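
Since the docstring spells out every constructor argument, a hedged instantiation sketch may help; the enclosing class name (LanguageModel here) and all concrete values are assumptions, not taken from the project.

# Hypothetical instantiation; the class name "LanguageModel" and all values are assumptions.
import tensorflow as tf

with tf.Session() as sess:  # TF1-style session, matching the tf_session argument above
    lm = LanguageModel(
        tf_session=sess,
        level="word",                      # or "char"
        train_text_path="data/train.txt",  # placeholder path
        max_vocab_size=20000,
        neurons_per_layer=512,
        num_layers=2,
        batch_size=64,
        num_timesteps=35,                  # unroll length for truncated BPTT
        save_dir="checkpoints/lm",         # or None to disable saving
    )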
Code example #8
def build_vocab(instances, min_word_cnt=5):
    vocab = Vocabulary()
    all_words = [word for sent in instances for word in sent]
    full_vocab = Counter(
        all_words).most_common()  # [('a', 5), ('b', 4), ('c', 3)]
    print('[Info] Original Vocabulary size =', len(full_vocab))

    for item in full_vocab:
        if item[1] >= min_word_cnt:
            vocab.add_word(item[0])
        else:
            break

    print('[Info] Trimmed vocabulary size = {},'.format(len(vocab)),
          'each with minimum occurrence = {}'.format(min_word_cnt))

    print(
        "[Info] Ignored word count = {}".format(len(full_vocab) - len(vocab)))

    return vocab
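
A short usage sketch, assuming build_vocab above is in scope and that Vocabulary provides the add_word and __len__ the function relies on:

# Hypothetical input: a list of already-tokenized sentences.
instances = [
    ["the", "cat", "sat", "on", "the", "mat"],
    ["the", "dog", "sat"],
]
vocab = build_vocab(instances, min_word_cnt=2)  # keeps words that occur at least twice

Note that because most_common() returns words in descending frequency order, the break (rather than a continue) is safe: once one word falls below the threshold, all remaining words do as well.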
Code example #9
def get_test_data(df_path):
    df = pd.read_csv(df_path)
    test_df = df[df['split'] == 'test']

    img_ids = test_df.file_name.unique()

    test_dict = {}
    for img_id in img_ids:
        list_tokens = []
        for sent in test_df[test_df['file_name'] == img_id]['caption'].values:
            list_tokens.append(Vocabulary.tokenize_en(sent))

        test_dict[img_id] = list_tokens

    return test_dict
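
From the code one can read off that the CSV is expected to contain at least file_name, caption and split columns; a hypothetical call might be:

# Hypothetical usage; the CSV path is a placeholder.
test_captions = get_test_data("data/captions_split.csv")
for img_id, token_lists in test_captions.items():
    print(img_id, len(token_lists), "reference captions")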
Code example #10
def preprocess(problem: str, data: str, is_vocab_collected: bool, n_jobs: int):
    # Collect vocabulary from train holdout if needed
    if problem not in _config_switcher:
        raise ValueError(f"Unknown problem ({problem}) passed")
    config_function = _config_switcher[problem]
    config = config_function(data)

    vocab_path = path.join(DATA_FOLDER, config.dataset_name, "vocabulary.pkl")
    if path.exists(vocab_path):
        vocab = Vocabulary.load(vocab_path)
    else:
        vocab = collect_vocabulary(config) if is_vocab_collected else convert_vocabulary(config)
        vocab.dump(vocab_path)
    for holdout in ["train", "val", "test"]:
        convert_holdout(holdout, vocab, config, n_jobs)
Code example #11
File: code2seq.py Project: zeta1999/code2seq
    def __init__(self, config: Code2SeqConfig, vocab: Vocabulary,
                 num_workers: int):
        super().__init__(config.hyperparams, vocab, num_workers)
        self.save_hyperparameters()
        if SOS not in vocab.label_to_id:
            vocab.label_to_id[SOS] = len(vocab.label_to_id)
        encoder_config = config.encoder_config
        decoder_config = config.decoder_config
        self.encoder = PathEncoder(
            encoder_config,
            decoder_config.decoder_size,
            len(vocab.token_to_id),
            vocab.token_to_id[PAD],
            len(vocab.type_to_id),
            vocab.type_to_id[PAD],
        )
        self.decoder = PathDecoder(decoder_config, len(vocab.label_to_id),
                                   vocab.label_to_id[SOS],
                                   vocab.label_to_id[PAD])
Code example #12
def preprocess(problem: str, data: str, is_vocab_collected: bool, n_jobs: int,
               data_folder: str, just_test: bool, test_name: str):
    # Collect vocabulary from train holdout if needed
    if problem not in _config_switcher:
        raise ValueError(f"Unknown problem ({problem}) passed")
    config_function = _config_switcher[problem]
    config = config_function(data)

    vocab_path = path.join(data_folder, "vocabulary.pkl")
    if path.exists(vocab_path):
        vocab = Vocabulary.load(vocab_path)
    else:
        vocab = collect_vocabulary(
            config, data_folder) if is_vocab_collected else convert_vocabulary(
                config, data_folder)
        vocab.dump(vocab_path)

    split = ["train", "val", "test"]
    if just_test:
        split = ["test"]
    for holdout in split:
        convert_holdout(holdout, vocab, config, n_jobs, data_folder, test_name)
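
A hypothetical invocation of this variant; the problem key and all paths are assumptions, since the valid keys are defined by _config_switcher elsewhere in the project.

# Hypothetical call; "code2seq" and the folder layout are placeholders.
preprocess(problem="code2seq", data="java-small", is_vocab_collected=False, n_jobs=4,
           data_folder="data/java-small", just_test=True, test_name="test")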
Code example #13
def main(args):
    global verbose, encoding
    verbose = args.verbose
    encoding = args.encoding
    assert args.poly_degree >= 1, '--degree must be a positive integer'
    poly_degree = args.poly_degree

    gpu = args.gpu
    if gpu >= 0:
        cuda.check_cuda_available()
        if verbose:
            logger.info('Use GPU {}'.format(gpu))
        cuda.get_device_from_id(gpu).use()

    df = read_dataset(args.path_input, args.flag_has_header)

    # agg = df.groupby('fact_en')['twa'].mean()
    # invalid_facts = set(agg[(agg == 1.0)|(agg == 0.0)].index)
    # if verbose:
    #     logger.info('Invalid facts: {}'.format(len(invalid_facts)))
    # df = df[~df['fact_en'].isin(invalid_facts)]
    # if verbose:
    #     logger.info('Remained {} lines'.format(len(df)))

    # Load vocabulary
    if verbose:
        logger.info('Load vocabulary')
    rel2id = Vocabulary()
    rel2id.read_from_file(args.path_rels)
    fact2id = Vocabulary()
    fact2id.read_from_list(np.unique(get_values(df, 'fact')))
    ja2id = Vocabulary()
    ja2id.read_from_list(np.unique(get_values(df, 'fact_ja')))
    en2id = Vocabulary()
    en2id.read_from_list(np.unique(get_values(df, 'fact_en')))

    df.index = df['fact']
    df.loc[:, 'fact'] = replace_by_dic(df['fact'], fact2id).astype(np.int32)
    df.loc[:, 'fact_ja'] = replace_by_dic(df['fact_ja'], ja2id).astype(np.int32)
    df.loc[:, 'fact_en'] = replace_by_dic(df['fact_en'], en2id).astype(np.int32)
    df.loc[:, 'rel'] = replace_by_dic(df['rel'], rel2id).astype(np.int32)

    en2ja = {en: set(df[df['fact_en'] == en]['fact'].unique())
             for en in sorted(df['fact_en'].unique())}
    idx2vec = get_idx2vec(df, poly_degree=poly_degree)
    if gpu >= 0:
        idx2vec = cuda.to_gpu(idx2vec)

    ss = df.drop_duplicates('fact_en')
    itr = FactIterator(ss, len(ss), ja2id, en2id, train=False, evaluate=True,
                       repeat=False, poly_degree=poly_degree)

    # Define a model
    model_type = args.model.lower()
    dim_in = len(COL_BASIC_FEATURES)
    rel_size = len(rel2id)
    if model_type.startswith('linear'):
        ensembler = LinearEnsembler(dim_in, rel_size, use_gpu=(gpu >= 0),
                                    poly_degree=poly_degree,
                                    flag_unifw=args.flag_unifw,
                                    verbose=verbose)
    elif model_type.startswith('mlp'):
        options = args.model.split(':')
        params = {}
        if len(options) > 1:
            params['dim_hid'] = int(options[1])
        if len(options) > 2:
            params['activation'] = options[2]
        ensembler = MLPEnsembler(
            dim_in, rel_size, use_gpu=(gpu >= 0),
            poly_degree=poly_degree, flag_unifw=args.flag_unifw,
            verbose=verbose, **params)
    else:
        raise ValueError('Invalid --model: {}'.format(model_type))

    ensembler.add_persistent('_mu', None)
    ensembler.add_persistent('_sigma', None)
    # load a trained model
    chainer.serializers.load_npz(args.path_model, ensembler)
    if ensembler._mu is not None:
        logger.info('standardize vectors: True')
        itr.standardize_vectors(mu=ensembler._mu, sigma=ensembler._sigma)
        idx2vec = standardize_vectors(idx2vec, ensembler._mu, ensembler._sigma)
    else:
        logger.info('standardize vectors: False')

    model = Classifier(ensembler, en2ja, idx2vec)

    # calculate probabilities for testing set
    buff = []
    for i, (rels, _, en_indices) in enumerate(itr, start=1):
        if i % 500 == 0:
            logger.info('Evaluating: {}'.format(i))
        buff.append((model(rels, en_indices), en_indices))
    scores = list(chain.from_iterable(t[0] for t in buff))

    if verbose:
        logger.info('Output results to ' + args.path_output)
    with open(args.path_output, 'w') as f:
        header = '\t'.join(['rel', 'start', 'end', 'start_en', 'end_en',
                            'score', 'label'])
        f.write(header + '\n')
        for row in sorted(scores, key=lambda t: t[2], reverse=True):
            idx_fact, idx_en, score = row
            fact = fact2id.id2word[idx_fact]
            fact_ja, fact_en = fact.split('@@@')
            rel, start_en, end_en = fact_en.split('|||')
            rel, start_ja, end_ja = fact_ja.split('|||')
            try:
                label = df.loc[fact, 'label']
            except KeyError:
                label = df.loc[fact, 'twa']
            f.write('{}\t{}\t{}\t{}\t{}\t{}\t{}\n'.format(
                rel, start_ja, end_ja, start_en, end_en, score, label))
Code example #14
    return paded_sentences, topics


class TopicClassDataLoader(DataLoader):
    def __init__(self, vocabulary, split, batch_size, num_workers=0):
        self.dataset = TopicClassDataset(vocabulary.get_dataset(split))
        self.batch_size = batch_size
        self.num_workers = num_workers

        super(TopicClassDataLoader, self).__init__(
            dataset=self.dataset,
            batch_size=self.batch_size,
            shuffle=(split == 'train'),
            collate_fn=lambda batch: pad_collate_fn(batch, padding_value=0),
            num_workers=self.num_workers)


if __name__ == '__main__':
    from dataset import Vocabulary
    vocab = Vocabulary()

    data_loader = TopicClassDataLoader(vocab,
                                       split='train',
                                       batch_size=16,
                                       num_workers=0)
    for i, (sent, topic, lengths, masks) in enumerate(data_loader):
        # print(sent.size)
        print(sent)
        print(masks)
        print(topic)
Code example #15
from vae_model import LSTMLM
from vae_model import VAE

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

train_data = preprocess(
    open('./02-21.10way.clean', 'r', encoding='utf-8').read().splitlines())
print('+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-')
print('Number of sentences in training set: {}'.format(len(train_data)))
val_data = preprocess(
    open('./22.auto.clean', 'r', encoding='utf-8').read().splitlines())
print('Number of sentences in validation set: {}'.format(len(val_data)))
test_data = preprocess(
    open('./23.auto.clean', 'r', encoding='utf-8').read().splitlines())
print('Number of sentences in testing set: {}'.format(len(test_data)))
vocab = Vocabulary()
for sentence in train_data:
    for word in sentence:
        vocab.count_token(word)
vocab.build()  # build the dictionary
vocab_size = len(vocab.w2i)
print('Vocabulary size: {}'.format(vocab_size))
print('+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-')

# print('train data head', train_data[0:5])
# print('train data tail', train_data[-5:])
# print('vocab 0,1', vocab.i2w[0], vocab.i2w[1]) #0=unk, 1=pad, 4=SOS, 6=EOS
# exit()


# prepare_example: keep the data as numpy arrays so that batched copies can be made during validation
Code example #16
def get_word_embedding_plot_of_sentence(
    sequence,
    word_embedding,
    vocabulary_file_path,
    metric='cosine',
    num_nearby_words=40,
    plot_width=600,
    plot_height=600,
):
    """
    For a given embedding and a sequence of a sentence,
    1) form a sentence vector by summing over the word vectors of the sentence,
    2) find nearby words of the sentence vector, and
    3) show the sentence vector & nearby word vectors using t-SNE.
    """
    vocabulary = Vocabulary(file_path=vocabulary_file_path)
    if metric == 'cosine':
        tsne_metric = sklearn.metrics.pairwise.cosine_distances
        get_distance = get_cosine_distance
    elif metric == 'euclidean':
        tsne_metric = sklearn.metrics.pairwise.euclidean_distances
        get_distance = get_euclidean_distance
    else:
        raise ValueError('Unknown metric: {}'.format(metric))

    vocabulary_size, embedding_size = word_embedding.shape

    sentence_vector = get_sentence_vector(sequence, word_embedding)
    distances, nearby_word_ids, nearby_word_vectors = get_nearby_word_vectors(
        sentence_vector[np.newaxis, :],
        word_embedding,
        num_nearby_words,
        metric=metric,
    )

    concatenated_vectors = np.concatenate((sentence_vector[np.newaxis, :],
                                           nearby_word_vectors.reshape(
                                               (-1, embedding_size))), )
    tsne = TSNE(metric=tsne_metric)
    transformed_vectors = tsne.fit_transform(concatenated_vectors)

    bokeh_figure = figure(
        tools='reset,box_zoom,pan,wheel_zoom,save,tap',
        plot_width=plot_width,
        plot_height=plot_height,
        title='Word Embedding',
    )

    tsne_sentence_vector = transformed_vectors[0]
    tsne_nearby_word_vectors = transformed_vectors[1:]

    sentence_words = [
        vocabulary.get_word_of_id(word_id) for word_id in sequence
    ]
    data_dict = {
        'x': [tsne_sentence_vector[0]],
        'y': [tsne_sentence_vector[1]],
        'color': ['navy'],
        'text': ['SENTENCE'],
        'distance': [0],
    }
    cds = ColumnDataSource(data_dict)
    bokeh_figure.circle_cross(
        x='x',
        y='y',
        size=10,
        color='color',
        source=cds,
    )
    labels = LabelSet(
        x='x',
        y='y',
        text='text',
        level='glyph',
        x_offset=5,
        y_offset=5,
        source=cds,
        render_mode='canvas',
    )
    bokeh_figure.add_layout(labels)

    nearby_words = [
        vocabulary.get_word_of_id(word_id) for word_id in nearby_word_ids[0]
    ]
    data_dict = {
        'x': tsne_nearby_word_vectors[:, 0],
        'y': tsne_nearby_word_vectors[:, 1],
        'color': ['olive'] * num_nearby_words,
        'text': nearby_words,
        'distance': distances[0],
    }
    cds = ColumnDataSource(data_dict)
    bokeh_figure.circle(
        x='x',
        y='y',
        size=5,
        color='color',
        source=cds,
    )
    labels = LabelSet(
        x='x',
        y='y',
        text='text',
        level='glyph',
        x_offset=5,
        y_offset=5,
        source=cds,
        render_mode='canvas',
    )
    bokeh_figure.add_layout(labels)

    return bokeh_figure
Code example #17
File: train.py Project: notani/CLKP-MTKBC
def main(args):
    global verbose
    verbose = args.verbose

    assert args.poly_degree >= 1, '--degree must be a positive integer'
    poly_degree = args.poly_degree

    if verbose:
        report_params(args)

    gpu = args.gpu
    if gpu >= 0:
        cuda.check_cuda_available()
        if verbose:
            logger.info('Use GPU {}'.format(gpu))
        cuda.get_device_from_id(gpu).use()

    set_random_seed(0, use_gpu=(gpu >= 0))

    n_epochs = args.n_epochs
    batch_size = args.batch_size

    # Dataset
    dfs = {}
    dfs['train'] = read_dataset(path.join(args.dir_in, args.filename_train))
    dfs['devel'] = read_dataset(path.join(args.dir_in, args.filename_devel))

    # Load relation vocabulary
    rel2id = Vocabulary()
    rel2id.read_from_file(args.path_rels)

    # Load concept vocabulary
    if verbose:
        logger.info('Load vocabulary')
    fact2id = Vocabulary()
    fact2id.read_from_list(np.unique(get_values(list(dfs.values()), 'fact')))
    ja2id = Vocabulary()
    ja2id.read_from_list(np.unique(get_values(list(dfs.values()), 'fact_ja')))
    en2id = Vocabulary()
    en2id.read_from_list(np.unique(get_values(list(dfs.values()), 'fact_en')))

    if verbose:
        logger.info('Replace facts with indices')
    for col in dfs.keys():
        dfs[col].loc[:, 'fact'] = replace_by_dic(dfs[col]['fact'],
                                                 fact2id).astype(np.int32)
        dfs[col].loc[:, 'fact_ja'] = replace_by_dic(dfs[col]['fact_ja'],
                                                    ja2id).astype(np.int32)
        dfs[col].loc[:, 'fact_en'] = replace_by_dic(dfs[col]['fact_en'],
                                                    en2id).astype(np.int32)
        dfs[col].loc[:, 'rel'] = replace_by_dic(dfs[col]['rel'],
                                                rel2id).astype(np.int32)
    label2fact = {
        i: set(
            np.concatenate(
                [df[df['twa'] == i]['fact'].unique() for df in dfs.values()]))
        for i in [0, 1]
    }
    en2ja = {
        en: set(df[df['fact_en'] == en]['fact'].unique())
        for df in dfs.values() for en in sorted(df['fact_en'].unique())
    }
    idx2vec = get_idx2vec(list(dfs.values()), poly_degree=poly_degree)

    n_facts = len(fact2id)
    n_en = len(en2id)
    n_ja = len(ja2id)
    assert n_facts + 1 == len(
        idx2vec), '{}[n_facts] != {}[len(idx2vec)]'.format(
            n_facts + 1, len(idx2vec))

    if verbose:
        logger.info('Alignment: {}'.format(n_facts))
        logger.info('En: {}'.format(n_en))
        logger.info('Ja: {}'.format(n_ja))
        logger.info('Train: {}'.format(len(dfs['train'])))
        logger.info('Devel: {}'.format(len(dfs['devel'])))

    model_type = args.model.lower()
    dim_in = len(COL_BASIC_FEATURES)
    rel_size = len(rel2id)
    if model_type.startswith('linear'):
        ensembler = LinearEnsembler(dim_in,
                                    rel_size,
                                    use_gpu=(gpu >= 0),
                                    poly_degree=poly_degree,
                                    flag_unifw=args.flag_unifw,
                                    verbose=verbose)
    elif model_type.startswith('mlp'):
        options = args.model.split(':')
        params = {}
        if len(options) > 1:
            params['dim_hid'] = int(options[1])
        if len(options) > 2:
            params['activation'] = options[2]
        ensembler = MLPEnsembler(dim_in,
                                 rel_size,
                                 use_gpu=(gpu >= 0),
                                 poly_degree=poly_degree,
                                 flag_unifw=args.flag_unifw,
                                 verbose=verbose,
                                 **params)
    else:
        raise ValueError('Invalid --model: {}'.format(model_type))

    # Set up a dataset iterator
    train_iter = FactIterator(dfs['train'],
                              args.batch_size,
                              ja2id,
                              en2id,
                              train=True,
                              repeat=True,
                              poly_degree=poly_degree)
    # Only keep positive examples in development set
    df = dfs['devel'][dfs['devel']['twa'] == 1].drop_duplicates('fact_en')

    # Set batch size
    batch_size = find_greatest_divisor(len(df))
    if batch_size == 1 and len(df) <= 10**4:
        batch_size = len(df)
    if verbose:
        logger.info('Devel batch size = {}'.format(batch_size))
    devel_iter = FactIterator(df,
                              batch_size,
                              ja2id,
                              en2id,
                              train=False,
                              repeat=False,
                              poly_degree=poly_degree)

    # Standardize vectors
    if args.flag_standardize:
        mu, sigma = train_iter.standardize_vectors()
        devel_iter.standardize_vectors(mu=mu, sigma=sigma)
        idx2vec = standardize_vectors(idx2vec, mu, sigma)
    else:
        mu, sigma = None, None

    if gpu >= 0:
        idx2vec = cuda.to_gpu(idx2vec)

    # Set up a model
    model = Classifier(ensembler,
                       label2fact,
                       en2ja,
                       idx2vec,
                       margin=args.margin,
                       lam=args.lam)

    if gpu >= 0:
        model.to_gpu(device=gpu)

    # Set up an optimizer
    optimizer = optimizers.AdaGrad(lr=args.lr)
    optimizer.setup(model)

    # Set up a trainer
    updater = Updater(train_iter, optimizer, device=gpu)
    trainer = training.Trainer(updater, (n_epochs, 'epoch'), out=args.dir_out)

    # evaluate development set
    evaluator = Evaluator(devel_iter, model, device=gpu)
    trainer.extend(evaluator)

    # Write out a log
    trainer.extend(extensions.LogReport())
    # Display training status
    trainer.extend(
        extensions.PrintReport([
            'epoch', 'main/loss', 'validation/main/loss',
            'validation/main/meanrank', 'validation/main/mrr', 'elapsed_time'
        ]))

    if args.save:
        trainer.extend(extensions.snapshot(), trigger=(args.n_epochs, 'epoch'))
        trainer.extend(extensions.snapshot_object(
            ensembler, 'model_iter_{.updater.iteration}'),
                       trigger=(1, 'epoch'))

    # Launch training process
    trainer.run()

    # Report the best score
    (epoch, score) = evaluator.get_best_score()
    if verbose:
        logger.info('Best score: {} (epoch={})'.format(score, epoch))

    # Clean the output directory
    if args.save:
        save_best_model(args.dir_out, ensembler, mu=mu, sigma=sigma)

    del dfs
    del fact2id
    del ja2id
    del en2id

    return score
Code example #18
def get_word_embedding_plot_of_nearby_words(
    image_vector,
    sequence,
    word_embedding,
    vocabulary,
    metric='cosine',
    num_nearby_words=5,
    use_pca=False,
    plot_width=600,
    plot_height=600,
):
    """
    For a given embedding, and an embedding vector of an image,
    and a sequence of its caption sentence, 1) form a sentence vector
    by summing over the word vectors of the sentence,
    2) find nearby words of the image vector, sentence vector, and word vectors
    in the sentence, and 3) show all the word vectors using t-SNE.
    """
    # `vocabulary` is passed in directly, so there is no vocabulary file to load here.
    if metric == 'cosine':
        tsne_metric = sklearn.metrics.pairwise.cosine_distances
        get_distance = get_cosine_distance
    elif metric == 'euclidean':
        tsne_metric = sklearn.metrics.pairwise.euclidean_distances
        get_distance = get_euclidean_distance
    else:
        raise ValueError('Unknown metric: {}'.format(metric))

    vocabulary_size, embedding_size = word_embedding.shape

    sentence_vector = get_sentence_vector(sequence, word_embedding)
    vectors = np.concatenate(
        (image_vector[np.newaxis, :], sentence_vector[np.newaxis, :],
         word_embedding[sequence, :]))
    distances, nearby_word_ids, nearby_word_vectors = get_nearby_word_vectors(
        vectors,
        word_embedding,
        num_nearby_words,
        metric=metric,
    )

    # NOTE: concatenated_vectors = [
    #   image_vector,
    #   sentence_vector,
    #   word_in_sentence_vectors,
    #   flattened_nearby_word_vectors,
    # ]
    concatenated_vectors = np.concatenate(
        (vectors, nearby_word_vectors.reshape((-1, embedding_size))), )
    tsne = TSNE(metric=tsne_metric)
    transformed_vectors = tsne.fit_transform(concatenated_vectors)
    image_vector_distance = get_distance(
        image_vector,
        sentence_vector,
    )

    hover = HoverTool(tooltips=[
        ('distance_from', '@source'),
        ('distance', '@distance'),
    ])
    bokeh_figure = figure(
        tools='reset,box_zoom,pan,wheel_zoom,save,tap',
        plot_width=plot_width,
        plot_height=plot_height,
        title='Word Embedding',
    )
    bokeh_figure.add_tools(hover)

    num_source_vectors = 2 + len(sequence)
    tsne_image_vector = transformed_vectors[0]
    tsne_sentence_vector = transformed_vectors[1]
    tsne_sentence_word_vectors = transformed_vectors[2:num_source_vectors]
    tsne_nearby_word_vectors = (
        transformed_vectors[num_source_vectors:].reshape(
            (2 + len(sequence), num_nearby_words, -1)))

    sentence_words = [
        vocabulary.get_word_of_id(word_id) for word_id in sequence
    ]
    source_words = ['IMAGE', 'SENTENCE'] + sentence_words
    data_dict = {
        'x': transformed_vectors[:num_source_vectors, 0],
        'y': transformed_vectors[:num_source_vectors, 1],
        'color': ['navy'] * num_source_vectors,
        'source': source_words,
        'text': source_words,
        'distance': [0] * num_source_vectors,
    }
    cds = ColumnDataSource(data_dict)
    bokeh_figure.circle_cross(
        x='x',
        y='y',
        size=5,
        color='color',
        source=cds,
    )
    labels = LabelSet(
        x='x',
        y='y',
        text='text',
        level='glyph',
        x_offset=5,
        y_offset=5,
        source=cds,
        render_mode='canvas',
    )
    bokeh_figure.add_layout(labels)

    nearby_words = [[
        vocabulary.get_word_of_id(word_id) for word_id in word_ids
    ] for word_ids in nearby_word_ids]
    for i_src_vec in range(num_source_vectors):
        # nearby-word vectors for source i start after the num_source_vectors source rows
        start = num_source_vectors + i_src_vec * num_nearby_words
        stop = start + num_nearby_words
        target_vectors = transformed_vectors[start:stop, :]
        data_dict = {
            'x': target_vectors[:, 0],
            'y': target_vectors[:, 1],
            'color': ['olive'] * num_nearby_words,
            'source': [source_words[i_src_vec]] * num_nearby_words,
            'text': nearby_words[i_src_vec],
            'distance': distances[i_src_vec],
        }
        cds = ColumnDataSource(data_dict)
        bokeh_figure.circle(
            x='x',
            y='y',
            size=5,
            color='color',
            source=cds,
        )
        labels = LabelSet(
            x='x',
            y='y',
            text='text',
            level='glyph',
            x_offset=5,
            y_offset=5,
            source=cds,
            render_mode='canvas',
        )
        bokeh_figure.add_layout(labels)

    return bokeh_figure
Code example #19
def get_word_embedding_plot(
    sequence,
    tsne_file_path,
    vocabulary_file_path=None,
    vocabulary=None,
    notebook=False,
    plot_width=600,
    plot_height=600,
):
    """
    Plot
    """
    if vocabulary_file_path is not None:
        vocabulary = Vocabulary(file_path=vocabulary_file_path)
    elif vocabulary is None:
        raise ValueError('Either vocabulary_file_path or vocabulary '
                         'should be provided.')
    if tsne_file_path is None:
        raise ValueError('tsne_file_path should be provided.')

    sequence = np.unique(sequence)
    hover = HoverTool(tooltips=[
        ('word', '@word'),
    ])
    bokeh_figure = figure(
        tools='reset,box_zoom,pan,wheel_zoom,save,tap',
        plot_width=plot_width,
        plot_height=plot_height,
        title=('Click on legend entries to hide the corresponding data.'),
    )
    bokeh_figure.add_tools(hover)

    word_embedding = load_tsne_of_word_embedding(tsne_file_path)

    sentence_words = [
        vocabulary.get_word_of_id(word_id) for word_id in sequence
    ]
    sentence_words_data_dict = {
        'x': word_embedding[sequence, 0],
        'y': word_embedding[sequence, 1],
        'color': ['navy'] * len(sequence),
        'word': sentence_words,
    }
    sentence_words_data_source = ColumnDataSource(sentence_words_data_dict)
    bokeh_figure.circle_cross(
        x='x',
        y='y',
        size=10,
        color='color',
        fill_alpha=0.5,
        muted_alpha=0.2,
        legend='sentence words',
        source=sentence_words_data_source,
    )
    labels = LabelSet(
        x='x',
        y='y',
        text='word',
        level='glyph',
        x_offset=5,
        y_offset=5,
        render_mode='canvas',
        source=sentence_words_data_source,
    )
    bokeh_figure.add_layout(labels)

    other_word_ids = np.delete(
        np.array(range(vocabulary.get_size())),
        sequence,
    )
    other_words = [
        vocabulary.get_word_of_id(word_id) for word_id in other_word_ids
    ]
    other_words_data_dict = {
        'x': word_embedding[other_word_ids, 0],
        'y': word_embedding[other_word_ids, 1],
        'word': other_words,
    }
    other_words_data_source = ColumnDataSource(other_words_data_dict)
    bokeh_figure.circle(
        x='x',
        y='y',
        size=10,
        color='gray',
        fill_alpha=0.1,
        line_alpha=0.1,
        muted_alpha=0.05,
        muted_color='gray',
        legend='other words',
        source=other_words_data_source,
    )

    bokeh_figure.legend.location = 'top_left'
    bokeh_figure.legend.click_policy = 'mute'

    if notebook:
        return bokeh_figure
    else:
        script, div = components(bokeh_figure)
        return {
            'script': script,
            'div': div,
        }
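
A hedged usage sketch: the file paths and the sequence of word ids below are placeholders, and notebook=True returns a Bokeh figure for inline display (otherwise a script/div pair for HTML embedding is returned).

# Hypothetical notebook usage.
from bokeh.io import output_notebook, show

output_notebook()
fig = get_word_embedding_plot(
    sequence=[5, 12, 42],                      # word ids of one sentence (placeholder)
    tsne_file_path="tsne_word_embedding.npy",  # placeholder path
    vocabulary_file_path="vocabulary.txt",     # placeholder path
    notebook=True,
)
show(fig)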
Code example #20
class Trainer(object):
    def __init__(self, config):
        self.logger = logging.getLogger("Training")
        self.config = config
        self.start_epoch = 1
        self.monitor = self.config.monitor
        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = np.inf if self.mnt_mode == 'min' else -np.inf
            self.early_stop = self.config.early_stop
        self.prepare_device()

        self.logger.info("Creating tensorboard writer...")
        self.writer = TensorboardWriter(log_dir=self.config.summary_dir,
                                        logger=self.logger,
                                        enabled=True)

        self.logger.info("Creating data loaders...")
        self.build_data_loader()

        self.logger.info("Creating model architecture...")
        self.build_model()

        self.logger.info("Creating optimizers...")
        self.build_optimizer()

        self.logger.info("Creating losses...")
        self.build_loss()

        self.logger.info("Creating metric trackers...")
        self.build_metrics()

        self.logger.info("Creating checkpoints...")
        self.load_checkpoint(self.config.checkpoint, self.config.resume_epoch)

        self.logger.info("Check parallelism...")
        self.parallelism()

    def build_model(self):
        self.model = TopicClassCNN(
            vocab_size=self.config.vocab_size,  # add 1 for <pad>
            emb_size=self.config.emb_size,
            n_layers=self.config.n_layers,
            attn_heads=self.config.attn_heads,
            dropout=self.config.dropout,
            num_classes=self.config.num_classes)

        # load pretrained bert model
        # checkpoint = '/home/haochen/Projects/cnn-topic-classification/pretrained_bert/uncased_L-12_H-768_A-12/bert_model_sent.pth'
        # state_dict = torch.load(checkpoint)
        # self.model.load_state_dict(state_dict)

    def build_data_loader(self):
        self.vocabulary = Vocabulary()
        self.config.vocab_size = 30522  # BERT uncased WordPiece vocabulary size
        assert self.config.vocab_size == 30522, "vocabulary size does not match the BERT config"
        assert self.config.num_classes == self.vocabulary.get_num_topics(
        ), "class number doesn't match topic number"

        self.train_loader = TopicClassDataLoader(
            vocabulary=self.vocabulary,
            split='train',
            batch_size=self.config.batch_size,
            num_workers=self.config.num_workers)

        self.valid_loader = TopicClassDataLoader(
            vocabulary=self.vocabulary,
            split='valid',
            batch_size=self.config.batch_size,
            num_workers=self.config.num_workers)

    def build_optimizer(self):
        self.optimizer = optim.AdamW(self.model.parameters(),
                                     self.config.lr,
                                     weight_decay=self.config.weight_decay)
        self.lr_scheduler = optim.lr_scheduler.ExponentialLR(
            self.optimizer, gamma=self.config.exp_lr_gamma)

    def build_loss(self):
        self.cls_loss = nn.CrossEntropyLoss()
        self.cls_loss.to(self.device)

    def build_metrics(self):
        loss_tags = ['loss', 'acc_1', 'acc_5']
        self.train_metrics = MetricTracker(*loss_tags, writer=self.writer)
        self.val_metrics = MetricTracker(*loss_tags, writer=self.writer)

    def train(self):
        not_improved_count = 0
        self.logger.info("Starting training...")

        for epoch in range(self.start_epoch, self.config.epochs + 1):
            result = self.train_epoch(epoch)

            # save logged informations into log dict
            log = {'epoch': epoch}
            log.update(result)

            # print logged information to the screen
            for key, value in log.items():
                self.logger.info('    {:15s}: {}'.format(str(key), value))

            # evaluate model performance according to configured metric, save best checkpoint as model_best
            best = False
            if self.mnt_mode != 'off':
                try:
                    # check whether model performance improved or not, according to specified metric(mnt_metric)
                    improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
                               (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
                except KeyError:
                    self.logger.warning(
                        "Warning: Metric '{}' is not found. "
                        "Model performance monitoring is disabled.".format(
                            self.mnt_metric))
                    self.mnt_mode = 'off'
                    improved = False

                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1

                if not_improved_count > self.early_stop:
                    self.logger.info(
                        "Validation performance didn\'t improve for {} epochs. "
                        "Training stops.".format(self.early_stop))
                    break

            self.save_checkpoint('latest', save_best=False)
            if epoch % self.config.save_period == 0:
                self.save_checkpoint(epoch, save_best=best)

            self.config.resume_epoch = epoch

    def train_epoch(self, epoch):
        self.model.train()
        for batch_idx, (sentences, gt_topics) in enumerate(self.train_loader):
            sentences = sentences.to(self.device)
            gt_topics = gt_topics.to(self.device)

            pred_topics = self.model(sentences)
            loss = self.cls_loss(pred_topics, gt_topics)
            acc = accuracy(pred_topics, gt_topics)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # add loss summary when update generator to save memory
            self.writer.set_step(
                (epoch - 1) * len(self.train_loader) + batch_idx, mode='train')
            self.train_metrics.update('loss', loss.item())
            self.train_metrics.update('acc_1', acc[0].item())
            self.train_metrics.update('acc_5', acc[1].item())

            # log on console
            if batch_idx % self.config.summary_step == 0:
                self.logger.info(
                    'Train Epoch: {} {} Loss: {:.4f}, Acc_1: {:.2f}, Acc_5: {:.2f}'
                    .format(epoch, self._progress(batch_idx), loss.item(),
                            acc[0].item(), acc[1].item()))

        self.lr_scheduler.step()
        log = self.train_metrics.result()
        val_log = self.valid_epoch(epoch)
        log.update(**{'val_' + k: v for k, v in val_log.items()})
        return log

    def valid_epoch(self, epoch):
        self.model.eval()
        val_loss = []
        val_acc = []
        val_acc5 = []
        with torch.no_grad():
            for batch_idx, (sentences,
                            gt_topics) in enumerate(self.valid_loader):
                sentences, gt_topics = sentences.to(self.device), gt_topics.to(
                    self.device)

                pred_topics = self.model(sentences)
                loss = self.cls_loss(pred_topics, gt_topics)
                acc = accuracy(pred_topics, gt_topics)

                val_loss.append(loss.item())
                val_acc.append(acc[0].item())
                val_acc5.append(acc[1].item())

        self.writer.set_step(epoch, mode='val')
        self.val_metrics.update('loss', np.mean(val_loss))
        self.val_metrics.update('acc_1', np.mean(val_acc))
        self.val_metrics.update('acc_5', np.mean(val_acc5))

        return self.val_metrics.result()

    def save_checkpoint(self, epoch, save_best):
        state = {
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict()
        }
        filename = 'epoch_{}.pth'.format(epoch)
        torch.save(state, os.path.join(self.config.checkpoint_dir, filename))
        if save_best:
            best_path = os.path.join(self.config.checkpoint_dir,
                                     'model_best.pth')
            torch.save(state, best_path)
            self.logger.info("Saving current best: model_best.pth ...")

    def load_checkpoint(self, checkpoint_dir=None, epoch=None):
        if checkpoint_dir is None:
            self.logger.info("Training from scratch...")
            self.model.to(self.device)
            self.start_epoch = 1
            return

        self.logger.info(
            "Loading checkpoints from {}...".format(checkpoint_dir))
        self.start_epoch = epoch + 1
        self.logger.info("Continuing training from epoch {}...".format(epoch))
        filename = 'epoch_{}.pth'.format(epoch)
        checkpoint = torch.load(os.path.join(checkpoint_dir, filename))
        model_to_load = {
            k.replace('module.', ''): v
            for k, v in checkpoint['state_dict'].items()
        }
        self.model.load_state_dict(model_to_load)
        self.model.to(self.device)
        if self.config.mode == 'train':
            self.optimizer.load_state_dict(checkpoint['optimizer'])

    def prepare_device(self):
        self.cuda = torch.cuda.is_available()
        if self.cuda:
            self.device = torch.device("cuda:0")
            self.logger.info("Training will be conducted on GPU")
        else:
            self.device = torch.device("cpu")
            self.logger.info("Training will be conducted on CPU")

        n_gpu = torch.cuda.device_count()
        n_gpu_use = self.config.ngpu

        if n_gpu_use > 0 and n_gpu == 0:
            self.logger.warning(
                "Warning: There's no GPU available on this machine, "
                "training will be performed on CPU.")
            n_gpu_use = 0
            self.config.ngpu = n_gpu_use
        if n_gpu_use > n_gpu:
            self.logger.warning(
                "Warning: The number of GPUs configured to use is {}, but only {} are available "
                "on this machine.".format(n_gpu_use, n_gpu))
            n_gpu_use = n_gpu
        self.device_ids = list(range(n_gpu_use))

    def parallelism(self):
        if len(self.device_ids) > 1:
            self.logger.info("Using {} GPUs...".format(len(self.device_ids)))
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.device_ids)
        else:
            if self.cuda:
                self.logger.info(
                    "Using a single GPU; the model will not be parallelized...")
            else:
                self.logger.info("Using CPU...")

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        current = batch_idx
        total = len(self.train_loader)
        return base.format(current, total, 100.0 * current / total)
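
For reference, the sketch below lists the config attributes this Trainer actually reads; the concrete values are placeholders in a hedged example, not the project's defaults.

# Hypothetical minimal config covering every attribute accessed above.
from types import SimpleNamespace

config = SimpleNamespace(
    monitor="min val_loss", early_stop=10,
    summary_dir="runs/topic", checkpoint_dir="checkpoints/topic",
    checkpoint=None, resume_epoch=None, mode="train",
    ngpu=1, num_workers=4, batch_size=32,
    vocab_size=None,          # overwritten to 30522 in build_data_loader
    emb_size=256, n_layers=4, attn_heads=8, dropout=0.1,
    num_classes=16,           # placeholder; must equal Vocabulary().get_num_topics()
    lr=1e-3, weight_decay=1e-4, exp_lr_gamma=0.95,
    epochs=20, save_period=1, summary_step=50,
)
trainer = Trainer(config)
trainer.train()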
Code example #21
    def test_add_from_counter_raise_error(self):
        vocab = Vocabulary()
        values_counter = Counter({"a": 3, "b": 1, "c": 4, "d": 1})
        with self.assertRaises(ValueError):
            vocab.add_from_counter("unknown_field", values_counter)
Code example #22
    def test_add_from_counter_all_values(self):
        vocab = Vocabulary()
        new_token_to_id = Counter({"a": 3, "b": 1, "c": 4, "d": 1})
        vocab.add_from_counter("token_to_id", new_token_to_id)
        self.assertDictEqual(vocab.token_to_id, {"c": 0, "a": 1, "b": 2, "d": 3})
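
Taken together with examples #1 and #21 above, these tests appear to pin down the expected behaviour of add_from_counter: extra tokens (if any) come first, the remaining slots are filled with the most frequent counter entries, n_most_values caps the total size (a negative or omitted value meaning keep everything), and an unknown field name raises ValueError. The following is a minimal self-contained sketch consistent with those tests, not the project's actual implementation.

from collections import Counter


class Vocabulary:
    """Illustrative sketch only; the real class in the examples above may differ."""

    def __init__(self):
        self.token_to_id = {}
        self.label_to_id = {}
        self.type_to_id = {}

    def add_from_counter(self, field, counter, n_most_values=-1, add_values=None):
        if field not in ("token_to_id", "label_to_id", "type_to_id"):
            raise ValueError(f"Unknown vocabulary field: {field}")
        add_values = list(add_values or [])
        if n_most_values < 0:
            kept = [token for token, _ in counter.most_common()]
        else:
            kept = [token for token, _ in counter.most_common(n_most_values - len(add_values))]
        setattr(self, field, {token: i for i, token in enumerate(add_values + kept)})


# Reproduces the expectation of example #1.
vocab = Vocabulary()
vocab.add_from_counter("token_to_id", Counter({"a": 3, "b": 1, "c": 4, "d": 1}),
                       n_most_values=4, add_values=["<SOS>", "<EOS>"])
assert vocab.token_to_id == {"<SOS>": 0, "<EOS>": 1, "c": 2, "a": 3}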
Code example #23
def main():
    parameters = {}
    embedding = vocab = None
    if args.glove is not None:
        vocab, embedding = Vocabulary.from_glove(args.glove)
        embed_size = args.glove
        word = True
    else:
        embed_size = args.embed_size
        word = args.word

    dataset = Dataset(train_file, test_file, args.dict, use_cuda, word, vocab)
    xtrain, ytrain = dataset.xtrain, dataset.ytrain
    xtest, ytest = dataset.xtest, dataset.ytest
    alphabet_size = len(dataset.vocab)
    hidden_size = args.hidden
    num_train, num_test = len(xtrain), len(xtest)
    iters = args.iters
    log_interval = args.log_interval
    epochs = iters / num_train

    print('Training size: %d ' % num_train)
    print('Test size: %d' % num_test)
    print('Alphabet size: %d' % alphabet_size)
    print('Num epochs: %s' % epochs)
    print('Device = %s' % device)

    num_classes = 2

    parameters['iterations'] = iters
    parameters['epochs'] = epochs
    parameters['log interval'] = log_interval
    parameters['algorithm'] = 'lstm'
    parameters['train set'] = train_file
    parameters['test set'] = test_file
    parameters['dictionary file'] = args.dict
    parameters['cuda'] = use_cuda
    parameters['bidir'] = bidir
    parameters['layers'] = n_layers
    parameters['embedding size'] = embed_size
    parameters['learning rate'] = lr
    parameters['hidden size'] = hidden_size
    parameters['optimizer'] = args.opt
    parameters['loss function'] = 'cross entropy'
    parameters['class weights'] = '[{}, {}]'.format(args.neg_weight,
                                                    args.pos_weight)

    model = BetterLSTM(input_size=alphabet_size,
                       embedding_size=embed_size,
                       hidden_size=hidden_size,
                       output_size=num_classes,
                       n_layers=n_layers,
                       bidir=bidir,
                       embedding=embedding).to(device)

    if args.opt == 'adagrad':
        opt = optim.Adagrad(model.parameters(), lr=lr)
    elif args.opt == 'adam':
        opt = optim.Adam(model.parameters(), lr=lr)
    else:
        opt = optim.SGD(model.parameters(), lr=lr)

    class_weights = torch.FloatTensor([args.neg_weight,
                                       args.pos_weight]).to(device)

    loss_function = F.cross_entropy

    hist = History()
    interval_loss = 0

    for i in trange(1, iters + 1):
        opt.zero_grad()
        # sample training set
        rand = random.randint(0, num_train - 1)
        x, y = xtrain[rand], ytrain[rand]
        x = torch.unsqueeze(
            x, dim=1)  # (seqlen, sigma) --> (seqlen, batch=1, sigma)
        h0, c0 = model.init_hidden(batch=1)
        y_pred = model(x, h0, c0)
        loss = loss_function(y_pred, y, class_weights)
        loss.backward()
        opt.step()

        interval_loss += loss.item()

        if i % log_interval == 0:
            avg_loss = interval_loss / log_interval
            hist.add_loss(i, avg_loss)
            interval_loss = 0

            train_eval = Evaluation(model, dataset.xtrain, dataset.ytrain)
            test_eval = Evaluation(model, dataset.xtest, dataset.ytest)

            hist.add_acc(i, train_eval.accuracy, test_eval.accuracy)
            hist.add_auc(i, train_eval.auc, test_eval.auc)
            hist.add_roc_info(test_eval.increasing_fprs,
                              test_eval.increasing_tprs)
            summary = ("Iter {}\ntrain acc = {}\ntest acc = {}\n"
                       "TPR/sensitivity/recall = {}\nTNR/specificity = {}\n"
                       "train loss = {}\nAUROC = {}".format(
                           i, train_eval.accuracy, test_eval.accuracy,
                           test_eval.tpr, test_eval.tnr, avg_loss,
                           test_eval.auc))
            print(summary)
            print("Confusion:")
            test_eval.show_confusion()

    final_train_eval = Evaluation(model, dataset.xtrain, dataset.ytrain)
    final_test_eval = Evaluation(model, dataset.xtest, dataset.ytest)

    summary = ("Final Eval:\ntrain acc = {}\ntest acc = {}\n"
               "TPR/sensitivity/recall = {}\nTNR/specificity = {}"
               "\nAUROC = {}".format(final_train_eval.accuracy,
                                     final_test_eval.accuracy,
                                     final_test_eval.tpr, final_test_eval.tnr,
                                     final_test_eval.auc))
    print(summary)

    final_test_eval.show_confusion()

    parameters['accuracy test'] = final_test_eval.accuracy
    parameters['accuracy train'] = final_train_eval.accuracy
    parameters['auc test'] = final_test_eval.auc
    parameters['auc train'] = final_train_eval.auc
    parameters['TPR/sensitivity/recall'] = final_test_eval.tpr
    parameters['TNR/specificity'] = final_test_eval.tnr

    # if output_directory specified, write data for future viewing
    # otherwise, it'll be discarded
    if args.output_directory is not None:
        path = args.output_directory
        if os.path.exists(path):
            shutil.rmtree(path)
        os.makedirs(path)
        hist.save_data(path)
        summary_file = os.path.join(path, 'about.txt')

        with open(summary_file, 'w+') as f:
            now = datetime.datetime.now()
            f.write(now.strftime('%Y-%m-%d %H:%M') + '\n')
            f.write('command_used: python ' + ' '.join(sys.argv) + '\n')
            f.write('(may have included CUDA_VISIBLE_DEVICES=x first)\n\n')
            for key, value in sorted(parameters.items()):
                f.write(key + ': ' + str(value) + '\n')

        hist.plot_acc(show=False, path=path)
        hist.plot_loss(show=False, path=path)
        hist.plot_auc(show=False, path=path)
        hist.plot_roc(show=False, path=path)
        final_test_eval.plot_confusion(show=False, path=path)

        print("Saved results to " + path)

    if args.show_graphs:
        hist.plot_acc(show=True)
        hist.plot_loss(show=True)
        hist.plot_auc(show=True)
        hist.plot_roc(show=True)
        final_test_eval.plot_confusion(show=True)