コード例 #1
0
ファイル: common.py プロジェクト: zerogerc/rnn-autocomplete
def run_model(model: nn.Module, iter_data, hidden, batch_size):
    """Run one forward step of *model* on a batch taken from *iter_data*.

    :param model: network called as ``model(input, hidden, forget_vector=...)``;
        when *hidden* is ``None`` it is initialised via
        ``model.init_hidden(batch_size=batch_size)``.
    :param iter_data: pair ``((input, target), forget_vector)``.
    :param hidden: recurrent state from the previous step, or ``None`` to
        start from a freshly initialised state.
    :param batch_size: expected batch size; must equal
        ``forget_vector.size()[0]``.
    :return: ``(prediction, target, hidden)``, where *target* has been moved
        to the best available device.
    :raises ValueError: if the forget vector's first dimension differs from
        *batch_size*.
    """
    (n_input, n_target), forget_vector = iter_data
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, which would silently disable this validation.
    if forget_vector.size()[0] != batch_size:
        raise ValueError(
            'forget_vector has {} rows, expected batch_size={}'.format(
                forget_vector.size()[0], batch_size))

    # Resolve the device once and put every tensor the model consumes on it;
    # ``.to`` is a no-op when a tensor already lives on that device.
    device = get_best_device()
    n_input = n_input.to(device)
    n_target = n_target.to(device)
    forget_vector = forget_vector.to(device)

    if hidden is None:
        hidden = model.init_hidden(batch_size=batch_size)

    prediction, hidden = model(n_input, hidden, forget_vector=forget_vector)

    return prediction, n_target, hidden
コード例 #2
0
    def _reinit_dropout_mask(self, batch_size):
        """Resample ``self.dropout_mask`` as a Bernoulli keep-mask.

        The resulting mask has shape ``(batch_size, self.hidden_size)``. Each
        entry is 1 with probability ``1 - self.dropout`` and 0 otherwise.

        The previous mask tensor is reused as scratch space when its shape
        already matches. Otherwise a new buffer is allocated: float32 on the
        best device on the first call, or the old mask's device/dtype when
        only the batch size changed.

        :param batch_size: number of rows the new mask must have.
        """
        mask = self.dropout_mask
        if mask is None:
            tensor = torch.zeros(batch_size,
                                 self.hidden_size,
                                 dtype=torch.float32,
                                 device=get_best_device())
        elif tuple(mask.shape) != (batch_size, self.hidden_size):
            # The batch size changed since the last call (e.g. a smaller final
            # batch). Previously the old-shaped mask was reused and
            # ``batch_size`` was silently ignored.
            tensor = mask.new_zeros((batch_size, self.hidden_size))
        else:
            tensor = mask

        # Fill with the keep probability 1 - dropout (e.g. dropout=0.25 ->
        # each entry is 1 with probability 0.75), then sample the mask.
        self.dropout_mask = torch.bernoulli(tensor.fill_(1 - self.dropout))
コード例 #3
0
ファイル: common.py プロジェクト: zerogerc/rnn-autocomplete
    def __init__(self, args):
        """Assemble the trainer's components by calling the ``create_*`` hooks.

        Each component is produced by a ``self.create_<name>(args)`` factory
        method defined elsewhere in the class. The calls run top to bottom in
        a fixed order, so a later factory may rely on attributes set by an
        earlier one. Do not reorder them.

        :param args: configuration object forwarded unchanged to every
            factory and to ``load_model``.
        """
        # Create the model and move it to the best available device before
        # load_model runs and before any optimizer is built.
        self.model = self.create_model(args).to(get_best_device())
        # NOTE(review): presumably restores saved weights into self.model when
        # args asks for it — confirm in load_model.
        self.load_model(args)

        # Presumably these read self.model / self.optimizers, which is why
        # they come after model creation — TODO confirm in the create_* hooks.
        self.optimizers = self.create_optimizers(args)
        self.schedulers = self.create_schedulers(args)
        self.criterion = self.create_criterion(args)

        self.data_generator = self.create_data_generator(args)

        self.train_routine = self.create_train_routine(args)
        self.validation_routine = self.create_validation_routine(args)
        self.train_metrics = self.create_train_metrics(args)
        self.eval_metrics = self.create_eval_metrics(args)
        # Plotting backend name is hard-coded here, not taken from args.
        self.plotter = 'tensorboard'
コード例 #4
0
ファイル: data.py プロジェクト: zerogerc/rnn-autocomplete
    def _read_file(self, file_path, limit=100000, label='Data'):
        """Read one JSON token list per line of *file_path* into data chunks.

        Each line is decoded with ``json.loads``. The result is turned into a
        ``torch.LongTensor`` (so it is presumably a list of integer token ids)
        on the best available device and wrapped in a ``TokensDataChunk``.

        :param file_path: path of the file to read, opened with ``ENCODING``.
        :param limit: maximum number of lines to read; ``None`` reads the
            whole file. Also used as the progress-bar total.
        :param label: name shown in the ``Reading ...`` message.
        :return: the chunks whose ``size()`` is at least ``self.seq_len``.
        """
        print('Reading {} ... '.format(label))
        data = []
        # ``with`` guarantees the handle is closed, including when the limit
        # triggers the early ``break`` or a line fails to parse.
        with open(file=file_path, mode='r', encoding=ENCODING) as f:
            for line_no, line in enumerate(tqdm(f, total=limit), start=1):
                tokens = json.loads(line)
                one_hot = torch.LongTensor(tokens).to(get_best_device())

                data.append(TokensDataChunk(one_hot_tensor=one_hot))

                if (limit is not None) and (line_no == limit):
                    break

        # Chunks shorter than one training sequence are unusable; drop them.
        return [chunk for chunk in data if chunk.size() >= self.seq_len]
コード例 #5
0
    def __init__(self, pool: DataChunksPool, seq_len, batch_size):
        """Create one ``DataBucket`` per batch row, all drawing from *pool*.

        ``self.forget_vector`` is a ``(batch_size, 1)`` float32 tensor on the
        best available device. Bucket ``i`` receives an ``on_new_chunk``
        callback that sets row ``i`` of that vector to 0.

        :param pool: chunk pool shared by all buckets.
        :param seq_len: sequence length passed to every bucket.
        :param batch_size: number of buckets, one per batch row.
        """
        self.pool = pool
        self.seq_len = seq_len
        self.batch_size = batch_size
        self.buckets = []

        # ``torch.FloatTensor(batch_size, 1)`` returns *uninitialised* memory,
        # so any row not yet zeroed by a callback held garbage. Start from a
        # deterministic 1 instead. Presumably 1 means "keep state", since the
        # callbacks write 0 — TODO confirm against the model's use of it.
        self.forget_vector = torch.ones(batch_size, 1,
                                        dtype=torch.float32,
                                        device=get_best_device())

        def forget(x):
            self.forget_vector[x] = 0

        def get_forget(x):
            # Factory binds ``x`` now; a bare ``lambda: forget(i)`` inside the
            # loop would late-bind and every bucket would use the last ``i``.
            return lambda: forget(x)

        for i in range(self.batch_size):
            self.buckets.append(
                DataBucket(pool=self.pool,
                           seq_len=self.seq_len,
                           on_new_chunk=get_forget(i)))
コード例 #6
0
ファイル: data.py プロジェクト: zerogerc/rnn-autocomplete
    def __init__(self, file_train, file_eval, seq_len, number_of_seq=20, limit=None):
        """Load the training and/or evaluation programs for this data source.

        When *file_train* is given, the programs read from it are split 80/20
        into ``self.train_data`` and ``self.validation_data``. Validation
        programs whose ``non_terminals_chunk.size()`` exceeds 30000 are
        discarded. When *file_eval* is given, ``self.eval_data`` and
        ``self.eval_tails`` are filled from it. Attributes for a file passed as
        ``None`` are not set at all.

        :param file_train: path of the training file, or ``None`` to skip.
        :param file_eval: path of the evaluation file, or ``None`` to skip.
        :param seq_len: sequence length stored on the instance.
        :param number_of_seq: value stored as ``self.number_of_seq``.
        :param limit: forwarded as ``limit`` to every ``_read_programs`` call.
        """
        super().__init__()
        self.device = get_best_device()
        self.seq_len = seq_len
        self.number_of_seq = number_of_seq

        if file_train is not None:
            programs = self._read_programs(file_train, total=100000, limit=limit)
            train_part, validation_part = split_train_validation(
                programs, split_coefficient=0.8
            )
            self.train_data = train_part
            # Keep only validation programs with at most 30000 non-terminals.
            self.validation_data = [
                program
                for program in validation_part
                if program.non_terminals_chunk.size() <= 30000
            ]

            train_count = len(self.train_data)
            validation_count = len(self.validation_data)
            print('Train size: {}, Validation size: {}'.format(train_count, validation_count))

        if file_eval is not None:
            eval_data, eval_tails = self._read_programs(
                file_eval, total=50000, limit=limit, count_tails=True, lim_30k=True
            )
            self.eval_data = eval_data
            self.eval_tails = eval_tails
コード例 #7
0
 def init_hidden(self, batch_size):
     """Reset ``self.context_buffer`` for a batch of *batch_size* rows.

     Allocates a ``(batch_size, 2 * self.seq_len, self.hidden_size)`` float32
     tensor on the best available device and wraps it in a ``LastKBuffer``
     with ``window_len=self.seq_len``. Returns ``None``; the state lives in
     ``self.context_buffer``.

     :param batch_size: number of rows in the new buffer.
     """
     # ``torch.FloatTensor(*sizes)`` returns uninitialised memory. If any slot
     # is read before being written, that garbage (possibly NaN/inf) would
     # leak into the model, so start from zeros instead.
     b_matrix = torch.zeros(batch_size, 2 * self.seq_len, self.hidden_size,
                            dtype=torch.float32,
                            device=get_best_device())
     self.context_buffer = LastKBuffer(window_len=self.seq_len,
                                       buffer=b_matrix)
コード例 #8
0
ファイル: data.py プロジェクト: zerogerc/rnn-autocomplete
    def __init__(self, one_hot_tensor):
        """Hold *one_hot_tensor* on the best available device.

        ``self.seq_len`` starts out as ``None``.
        """
        super().__init__()

        target_device = get_best_device()
        self.one_hot_tensor = one_hot_tensor.to(target_device)
        self.seq_len = None