コード例 #1
0
ファイル: model.py プロジェクト: klb3713/cw_word_embedding
 def __init__(self):
     """Load the vocabulary, create the parameters and zero the train_* counters."""
     vocabulary.load_vocabulary()
     self.parameters = Parameters()

     # All train_* accumulators start from zero; targets are assigned
     # left to right, i.e. in the same order as separate statements would.
     self.train_loss = self.train_err = 0
     self.train_lossnonzero = self.train_cnt = 0
コード例 #2
0
    def __init__(self):
        """Load the vocabulary, set up parameters and counters, build the train function."""
        vocabulary.load_vocabulary()
        self.parameters = Parameters()

        # Every train_* accumulator starts at zero.
        for counter_name in ('train_loss', 'train_err',
                             'train_lossnonzero', 'train_cnt'):
            setattr(self, counter_name, 0)

        self.COMPILE_MODE = theano.compile.Mode('c|py', 'fast_run')

        # Built last, after the attributes above exist (presumably
        # _get_train_function reads some of them; verify in its definition).
        self.train_function = self._get_train_function()
コード例 #3
0
ファイル: model_gpu.py プロジェクト: chagge/cw_word_embedding
    def __init__(self):
        """Initialise vocabulary, parameters, zeroed train_* counters and the train function."""
        vocabulary.load_vocabulary()
        self.parameters = Parameters()

        # Reset all train_* accumulators in one statement (assigned left to right).
        (self.train_loss,
         self.train_err,
         self.train_lossnonzero,
         self.train_cnt) = 0, 0, 0, 0

        # Theano compile mode created with 'c|py' and 'fast_run' as its two
        # positional arguments.
        self.COMPILE_MODE = theano.compile.Mode('c|py', 'fast_run')

        # Created after every attribute above has been set.
        self.train_function = self._get_train_function()
コード例 #4
0
def build_samples():
    """Write one fixed-size window of word ids per line of ``config.SAMPLE_FILE``.

    Every line of ``config.TRAIN_FILE`` is split on whitespace, each token is
    passed through ``getNormalWord`` and mapped to an id with
    ``vocabulary.id``.  For each position whose word is not the unknown word,
    a window of ``2 * (config.WINDOW_SIZE // 2) + 1`` ids centred on that
    position is built, padded with the padding id where it overruns the
    sentence.  The window is written (space-separated ids, one window per
    line) only when the number of unknown, symbol and padding ids it contains
    is at most ``config.WINDOW_SIZE // 2``.

    The vocabulary is loaded first if ``vocabulary.length()`` is falsy.
    """
    if not vocabulary.length():
        vocabulary.load_vocabulary()

    # Floor division: the value is used as a slice bound and as a list
    # repeat count, so it must be an int.  Plain ``/`` yields a float on
    # Python 3 and makes the slicing below raise TypeError.
    half_window = config.WINDOW_SIZE // 2

    # Looked up once instead of once per input line.
    # NOTE(review): assumes vocabulary.id is a pure lookup - confirm.
    padding_id = vocabulary.id(config.PADDING_WORD)
    unknown_id = vocabulary.id(config.UNKNOWN_WORD)
    symbol_id = vocabulary.id(config.SYMBOL_WORD)
    filler_ids = (unknown_id, symbol_id, padding_id)

    # Context managers close both files even if a lookup or write raises.
    # Nested ``with`` blocks are used so the code also runs on old
    # interpreters that lack the multi-manager form.
    with open(config.SAMPLE_FILE, 'w') as sample_file:
        with open(config.TRAIN_FILE, 'r') as train_file:
            for line in train_file:
                # str.split() with no argument already drops the trailing
                # newline and never yields empty tokens.
                words = [getNormalWord(word) for word in line.split()]
                word_ids = [vocabulary.id(word) for word in words]

                for index, word_id in enumerate(word_ids):
                    if word_id == unknown_id:
                        continue

                    window = _context_window(
                        word_ids, index, half_window, padding_id)

                    # Each filler id is counted separately, so ids that
                    # happen to be equal are counted once per role.
                    filler_count = sum(
                        window.count(filler) for filler in filler_ids)
                    if filler_count <= half_window:
                        sample_file.write(
                            ' '.join(str(token_id) for token_id in window)
                            + '\n')


def _context_window(word_ids, index, half_window, padding_id):
    """Return the ids around ``index``, padded with ``padding_id`` at the edges.

    The result always has ``2 * half_window + 1`` entries: ``half_window``
    ids on each side of ``word_ids[index]``, with ``padding_id`` standing in
    for positions before the start or past the end of ``word_ids``.
    """
    left_pad = max(0, half_window - index)
    right_pad = max(0, index + half_window + 1 - len(word_ids))
    start = max(0, index - half_window)
    # A slice end beyond the list is clamped, so no explicit upper bound
    # check is needed for the right edge.
    return ([padding_id] * left_pad
            + word_ids[start:index + half_window + 1]
            + [padding_id] * right_pad)
コード例 #5
0
def build_samples():
    """Generate the training windows file from the training corpus.

    Reads ``config.TRAIN_FILE`` line by line, normalises each
    whitespace-separated token with ``getNormalWord`` and converts it to an
    id via ``vocabulary.id``.  For every position whose id is not the
    unknown-word id, a window of ``half_window`` ids on each side
    (``half_window = config.WINDOW_SIZE // 2``) is assembled, using the
    padding id for positions outside the sentence.  A window is written to
    ``config.SAMPLE_FILE`` as space-separated ids on its own line only if it
    contains at most ``half_window`` unknown, symbol and padding ids in total.

    Loads the vocabulary first when ``vocabulary.length()`` is falsy.
    """
    if not vocabulary.length():
        vocabulary.load_vocabulary()

    # ``//`` keeps this an int on Python 3 as well; with ``/`` the value
    # becomes a float there and the slice below raises TypeError.
    half_window = config.WINDOW_SIZE // 2

    # Special ids do not depend on the current line, so fetch them once.
    # NOTE(review): assumes vocabulary.id has no side effects - confirm.
    padding_id = vocabulary.id(config.PADDING_WORD)
    unknown_id = vocabulary.id(config.UNKNOWN_WORD)
    symbol_id = vocabulary.id(config.SYMBOL_WORD)

    # ``with`` guarantees both files are closed, including on errors; the
    # input file used to be left for the garbage collector to close.
    with open(config.SAMPLE_FILE, 'w') as sample_file:
        with open(config.TRAIN_FILE, 'r') as train_file:
            for line in train_file:
                # split() without arguments strips the newline and skips
                # empty tokens on its own.
                words = [getNormalWord(word) for word in line.split()]
                word_ids = [vocabulary.id(word) for word in words]
                sent_length = len(word_ids)

                for index, word_id in enumerate(word_ids):
                    if word_id == unknown_id:
                        continue

                    # Padding needed on each side; zero when the window
                    # lies entirely inside the sentence.
                    missing_left = max(0, half_window - index)
                    missing_right = max(
                        0, index + half_window + 1 - sent_length)

                    window = [padding_id] * missing_left
                    # The slice end is clamped to the sentence length.
                    window.extend(
                        word_ids[max(0, index - half_window):
                                 index + half_window + 1])
                    window.extend([padding_id] * missing_right)

                    uninformative = (window.count(unknown_id)
                                     + window.count(symbol_id)
                                     + window.count(padding_id))
                    if uninformative <= half_window:
                        sample_file.write(
                            ' '.join([str(token_id) for token_id in window])
                            + '\n')
コード例 #6
0
def initialize_model(cfg):
    """
    Create a fresh model plus the optimizer, vocabulary and stats that go with it.

    The vocabulary is loaded from `cfg.vocab_file` limited by
    `cfg.vocab_size`; the model is wrapped in `nn.DataParallel` when more
    than one CUDA device is reported, and is then moved to `DEVICE`.

    :param cfg: the `Config` instance used to create the model
    :returns: (model, optimizer, vocab, stats, cfg) where `stats` is a
              `defaultdict(int)` that already holds "model_identifier"
    """
    print("Initializing a new model...")

    vocab = load_vocabulary(cfg.vocab_file, cfg.vocab_size)
    model = Seq2Seq(vocab, cfg)

    gpu_count = torch.cuda.device_count()
    if gpu_count > 1:
        print("%d GPUs available - wrapping in DataParallel" % gpu_count)
        model = nn.DataParallel(model)

    model.to(DEVICE)

    # The optimizer is created after the (possibly wrapped) model has been
    # moved to DEVICE and receives that model's parameters.
    optimizer = initialize_optimizer(model.parameters(), cfg)
    stats = defaultdict(int, model_identifier=get_model_identifier())
    return model, optimizer, vocab, stats, cfg