Example #1
    def __iter__(self):
        # sent_stream is an iterator
        sent_stream = self.get_sent_stream()

        for batch in self.stream_iterator(sent_stream):
            bptt = min(max([len(wp) for wp in batch[0]]), self.maxlen)
            yield encap_batch(batch, bptt, self.device)
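
The iterator above groups a sentence stream into batches and hands each one to encap_batch, capping the batch length at self.maxlen. A minimal, self-contained sketch of that stream-to-batch pattern (stream_batches below is illustrative and not a function from this project):

from itertools import islice

def stream_batches(sent_stream, batch_size):
    """Illustrative only: yield lists of `batch_size` sentences from an iterator."""
    stream = iter(sent_stream)
    while True:
        batch = list(islice(stream, batch_size))
        if not batch:
            return
        yield batch

# Toy usage: sentences are lists of token ids.
sents = [[1, 2, 3], [4, 5], [6, 7, 8, 9], [10]]
for batch in stream_batches(sents, batch_size=2):
    bptt = max(len(s) for s in batch)   # same truncation rule as above, without a maxlen cap
    print(bptt, batch)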
Example #2
    def forward(self, online_batch, stats, skip_optim=False):
        mems = tuple()
        # Evaluate online batch.
        if not self.eval_online_fit:
            stats['online_loss'] += 0
            stats['online_word_count'] += 1e-20
            stats['online_token_count'] += 1e-20
        else:
            # Evaluate online batch.
            total_loss, token_len, word_len = self.forward_eval(
                encap_batch(online_batch,
                            max([len(_) for _ in online_batch[0]]),
                            self.device), )

            stats['online_loss'] += total_loss.float().sum().item()
            stats['online_word_count'] += word_len.float().sum().item()
            stats['online_token_count'] += token_len.float().sum().item()

        # Update buffer: add online data stream into replay buffer
        retired_batch = self.online_buffer.add_batch(*online_batch)
        if retired_batch is not None:
            self.replay_buffer.add_batch(*retired_batch)

        # Start training only after replay buffer has data
        if self.replay_buffer.sample_batch(self.replay_bsz,
                                           self.device) is None:
            return

        # * Optimize for k step on current dataset
        for k in range(self.max_k_steps):
            # * Sample a batch of data from replay buffer
            replay_encodes = self.replay_buffer.sample_batch(
                self.replay_bsz,
                self.device,
            )

            # reset optimizer gradients
            self.optimizer.zero_grad()

            data, target, users, token_len, word_len, weights = replay_encodes[:6]
            raw_loss = self.para_model(data, target, users, *mems)[0]
            per_sample_loss = raw_loss.float().sum(dim=0, keepdim=True) \
                            / token_len.type_as(raw_loss)

            total_loss = (per_sample_loss * weights).mean()
            total_loss.backward()
            clip_grad_norm_(self.model.parameters(), self.clip_value)

            self.optimizer.step()

            # * Update global stats
            stats['num_updates'] += 1
            stats['total_loss'] += raw_loss.float().sum().item()
            stats['total_word_count'] += word_len.float().sum().item()
            stats['total_token_count'] += token_len.float().sum().item()
Example #3
    def forward(self, online_batch, stats, skip_optim=False):
        mems = tuple()
        # Evaluate online batch.
        if not self.eval_online_fit:
            stats['online_loss']        += 0
            stats['online_word_count']  += 1e-20
            stats['online_token_count'] += 1e-20
        else:
            # Evaluate online batch.
            total_loss, token_len, word_len = self.forward_eval(
                encap_batch(online_batch, max([ len(_) for _ in online_batch[0] ]), self.device),
            )

            stats['online_loss']        += total_loss.float().sum().item()
            stats['online_word_count']  += word_len.float().sum().item()
            stats['online_token_count'] += token_len.float().sum().item()

        # Add the online batch into the online buffer
        self.online_buffer.add_batch(*online_batch)
        # * Optimize for k step on current dataset
        for k in range(self.max_k_steps):
            # * Sample a batch of data from replay buffer
            online_encodes = self.online_buffer.sample_batch(
                self.online_bsz, self.device,
            )
            data, target, users, token_len, word_len = online_encodes[:5]

            # reset optimizer gradients
            self.optimizer.zero_grad()

            # decapsulate batch
            raw_loss = self.para_model(data, target, users, *mems)[0]
            per_sample_loss = raw_loss.float().sum(dim=0, keepdim=True) \
                            / token_len.type_as(raw_loss)

            total_loss = per_sample_loss.mean()
            total_loss.backward()
            clip_grad_norm_(self.model.parameters(), self.clip_value)

            self.optimizer.step()

            # * Update global stats
            stats['num_updates']        += 1
            stats['total_loss']         += raw_loss.float().sum().item()
            stats['total_word_count']   += word_len.float().sum().item()
            stats['total_token_count']  += token_len.float().sum().item()
Example #4
    def get_batch(self, batch_id):
        num_batch = len(self)
        assert batch_id < num_batch, '[Fatal error]'
        start_id = batch_id * self.bsz
        end_id = (batch_id + 1) * self.bsz

        wordpieces = [None] * self.bsz
        wordends = [None] * self.bsz
        users = [None] * self.bsz
        for i, data_id in enumerate(range(start_id, end_id)):
            sent, users[i] = self.get_data(data_id)
            wordpieces[i], wordends[i] = self.vocab.encode_as_ids(
                sent, sample=self.subword_augment)

        bptt = min(max([len(wp) for wp in wordpieces]), self.maxlen)

        return encap_batch((wordpieces, wordends, users), bptt, self.device)
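
Every example on this page passes a (wordpieces, wordends, users) tuple, a bptt length, and a device into encap_batch. Its implementation is not shown here; the helper below is only a hedged sketch of the padding-and-tensorization step such a function presumably performs (pad_to_bptt and the (bptt, batch) layout are assumptions, not the project's API):

import torch

def pad_to_bptt(wordpieces, bptt, device, pad_id=0):
    """Illustrative only: truncate/pad ID lists to length `bptt` and stack as (bptt, batch)."""
    rows = []
    for wp in wordpieces:
        wp = wp[:bptt]
        rows.append(wp + [pad_id] * (bptt - len(wp)))
    # (bptt, batch_size) layout, consistent with the dim=1 concatenation in the forward() examples
    return torch.tensor(rows, dtype=torch.long, device=device).t().contiguous()

# Toy usage
ids = pad_to_bptt([[5, 6, 7], [8, 9]], bptt=4, device='cpu')
print(ids.shape)  # torch.Size([4, 2])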
Example #5
File: base.py Project: marcelomata/congrad
    def sample_batch(self, batch_size, device, bptt=None):
        if len(self) == 0: return None

        batch = [[], [], []]
        data_weights = torch.FloatTensor(1, batch_size)
        data_ids = torch.LongTensor(1, batch_size)
        for i in range(batch_size):
            wp, wd, user, data_weight, data_id = self.sample_data()
            batch[0].append(wp)
            batch[1].append(wd)
            batch[2].append(user)
            data_weights[:, i] = data_weight
            data_ids[:, i] = data_id

        bptt = max([len(_) for _ in batch[0]]) if bptt is None else bptt
        bptt = min(bptt, self.max_seqlen)
        encoded_batch = encap_batch(batch, bptt, device)
        return encoded_batch + (data_weights.to(device), data_ids)
Example #6
    def forward(self, online_batch, stats, skip_optim=False):
        mems = tuple()
        if not self.eval_online_fit:
            stats['online_loss'] += 0
            stats['online_word_count'] += 1e-20
            stats['online_token_count'] += 1e-20
        else:
            # Evaluate online batch.
            total_loss, token_len, word_len = self.forward_eval(
                encap_batch(online_batch,
                            max([len(_) for _ in online_batch[0]]),
                            self.device), )

            stats['online_loss'] += total_loss.float().sum().item()
            stats['online_word_count'] += word_len.float().sum().item()
            stats['online_token_count'] += token_len.float().sum().item()

        # Update buffer: add online data stream into replay buffer
        retired_batch = self.online_buffer.add_batch(*online_batch)
        if retired_batch is not None:
            self.replay_buffer.add_batch(*retired_batch)

        # Start training only after replay buffer has data
        if self.replay_buffer.sample_batch(self.replay_bsz,
                                           self.device) is None:
            return

        # Sample a cached (validation) batch for model selection
        online_encodes = self.online_buffer.sample_batch(
            self.online_bsz,
            self.device,
        )

        # Snapshot Current Model
        if self.allow_zero_step:
            val_loss, _, _ = self.forward_eval(online_encodes)
            self.take_snapshot(0, val_loss.sum().item())

        # * Optimize for k step on current dataset
        for k in range(1, self.max_k_steps + 1):
            # * Sample a batch of data from replay buffer
            replay_encodes = self.replay_buffer.sample_batch(
                self.replay_bsz,
                self.device,
            )

            # decapsulate batch
            data, target, users, token_len, word_len = replay_encodes[:5]
            raw_loss = self.para_model(data, target, users, *mems)[0]
            replay_psl = raw_loss.sum(dim=0, keepdim=True) \
                       / token_len.type_as(raw_loss)

            replay_loss = replay_psl.mean()

            self.optimizer.zero_grad()
            replay_loss.backward()
            clip_grad_norm_(self.model.parameters(), self.clip_value)
            self.optimizer.step()

            # Take snapshot
            val_loss, _, _ = self.forward_eval(online_encodes)
            self.take_snapshot(k, val_loss.sum().item())

            # * Update global stats
            stats['num_updates'] += 1
            stats['total_loss'] += raw_loss.sum().float().item()
            stats['total_word_count'] += word_len.sum().float().item()
            stats['total_token_count'] += token_len.sum().float().item()

        # choose best snapshot based on current online measure
        best_k = self.resume_snapshot()
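
Examples #6, #7, and #9 select the best of the k updates by snapshotting the model after every step and then calling resume_snapshot. The real take_snapshot/resume_snapshot code is not included on this page; the class below is a hedged sketch of that keep-the-best-checkpoint pattern (BestSnapshot is a hypothetical name):

import copy
import torch.nn as nn

class BestSnapshot:
    """Illustrative only: remember the state dict with the lowest validation loss."""
    def __init__(self):
        self.best_k, self.best_loss, self.best_state = None, float('inf'), None

    def take(self, k, val_loss, model):
        if val_loss < self.best_loss:
            self.best_k, self.best_loss = k, val_loss
            self.best_state = copy.deepcopy(model.state_dict())

    def resume(self, model):
        if self.best_state is not None:
            model.load_state_dict(self.best_state)
        return self.best_k

# Toy usage
model = nn.Linear(4, 2)
snap = BestSnapshot()
snap.take(0, 1.3, model)
snap.take(1, 0.9, model)
print(snap.resume(model))  # 1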
Example #7
    def forward(self, online_batch, stats, skip_optim=False):
        mems = tuple()
        if not self.eval_online_fit:
            stats['online_loss'] += 0
            stats['online_word_count'] += 1e-20
            stats['online_token_count'] += 1e-20
        else:
            # Evaluate online batch.
            total_loss, token_len, word_len = self.forward_eval(
                encap_batch(online_batch,
                            max([len(_) for _ in online_batch[0]]),
                            self.device), )

            stats['online_loss'] += total_loss.float().sum().item()
            stats['online_word_count'] += word_len.float().sum().item()
            stats['online_token_count'] += token_len.float().sum().item()

        # Update buffer: add online data stream into replay buffer
        retired_batch = self.online_buffer.add_batch(*online_batch)
        if retired_batch is not None:
            self.replay_buffer.add_batch(*retired_batch)

        # Start training only after replay buffer has data
        if self.replay_buffer.sample_batch(self.replay_bsz,
                                           self.device) is None:
            return

        mems = tuple()
        model_params = list(
            itertools.chain(*[
                self.model.word_emb.parameters(),
                self.model.layers.parameters(),
                self.model.mtl_layers.parameters(),
            ]))

        # Sample a cached (validation) batch for model selection
        val_encodes = self.online_buffer.sample_batch(
            self.online_bsz,
            self.device,
        )
        # Snapshot Current Model
        if self.allow_zero_step:
            val_loss, _, _ = self.forward_eval(val_encodes)
            self.take_snapshot(0, val_loss.sum().item())

        retired_bsz = self.replay_bsz // 2
        # * Optimize for k step on current dataset
        for k in range(1, self.max_k_steps + 1):
            # Optimization on Replay Data
            replay_encodes = self.replay_buffer.sample_batch(
                retired_bsz,
                self.device,
            )

            data, target, users, token_len, word_len = replay_encodes[:5]
            total_loss = self.para_model(data, target, users, *mems)[0]

            replay_psl = total_loss.sum(
                dim=0, keepdim=True) / token_len.type_as(total_loss)
            replay_loss = replay_psl.mean()

            # (*) Extract Replay Gradients
            self.optimizer.zero_grad()
            replay_loss.backward()
            extract_grads(model_params, self.replay_grad, self.grad_dims)

            selected_data = np.random.permutation(len(
                retired_batch[0]))[:retired_bsz].tolist()
            _retired_batch = [[
                _rbdata[selected_id] for selected_id in selected_data
            ] for _rbdata in retired_batch]
            online_encodes = encap_batch(
                _retired_batch, max([len(_) for _ in _retired_batch[0]]),
                self.device)

            # decapsulate batch
            data, target, users, token_len, word_len = online_encodes[:5]
            total_loss = self.para_model(data, target, users, *mems)[0]

            online_psl = total_loss.sum(
                dim=0, keepdim=True) / token_len.type_as(total_loss)
            online_loss = online_psl.mean()

            # (*) Extract Online Gradients
            self.optimizer.zero_grad()
            online_loss.backward()
            extract_grads(model_params, self.online_grad, self.grad_dims)

            online_replay_dp = torch.mm(self.online_grad.view(1, -1),
                                        self.replay_grad.view(-1, 1))
            if online_replay_dp.item() < 0:
                # constraint violation, next batch
                # reproject gradient when the constraint is violated
                replay_replay_dp = torch.mm(self.replay_grad.view(1, -1),
                                            self.replay_grad.view(-1, 1))
                self.online_grad.copy_(
                    self.online_grad -
                    (online_replay_dp.item() / replay_replay_dp.item()) * self.replay_grad)
                overwrite_grad(model_params, self.online_grad, self.grad_dims)

            # clip gradient norm
            clip_grad_norm_(self.model.parameters(), self.clip_value)
            self.optimizer.step()

            # Take snapshot
            val_loss, _, _ = self.forward_eval(val_encodes)
            self.take_snapshot(k, val_loss.sum().item())

            # * Update global stats
            stats['num_updates'] += 1
            stats['total_loss'] += total_loss.sum().float().item()
            stats['total_word_count'] += word_len.sum().float().item()
            stats['total_token_count'] += token_len.sum().float().item()

        # choose best snapshot based on current online measure
        best_k = self.resume_snapshot()
Example #8
File: agem.py Project: marcelomata/congrad
    def forward(self, online_batch, stats, skip_optim=False):
        # Update buffer: add online data stream into replay buffer
        retired_batch = self.online_buffer.add_batch(*online_batch)
        if retired_batch is not None:
            self.replay_buffer.add_batch(*retired_batch)

        # Start training only after replay buffer has data
        if self.replay_buffer.sample_batch(self.replay_bsz,
                                           self.device) is None:
            return

        # Skip optimization if in resume mode
        if skip_optim: return

        if not self.eval_online_fit:
            stats['online_loss'] += 0
            stats['online_word_count'] += 1e-20
            stats['online_token_count'] += 1e-20
        else:
            # Evaluate online batch.
            total_loss, token_len, word_len = self.forward_eval(
                encap_batch(online_batch,
                            max([len(_) for _ in online_batch[0]]),
                            self.device), )

            stats['online_loss'] += total_loss.float().sum().item()
            stats['online_word_count'] += word_len.float().sum().item()
            stats['online_token_count'] += token_len.float().sum().item()

        mems = tuple()
        model_params = list(
            itertools.chain(*[
                self.model.word_emb.parameters(),
                self.model.layers.parameters(),
                self.model.mtl_layers.parameters(),
            ]))

        # * Optimize for k step on current dataset
        for k in range(self.max_k_steps):
            # Optimization on Replay Data
            replay_encodes = self.replay_buffer.sample_batch(
                self.replay_bsz,
                self.device,
            )

            data, target, users, token_len, word_len = replay_encodes[:5]
            total_loss = self.para_model(data, target, users, *mems)[0]

            replay_psl = total_loss.sum(
                dim=0, keepdim=True) / token_len.type_as(total_loss)
            replay_loss = replay_psl.mean()

            # (*) Extract Replay Gradients
            self.optimizer.zero_grad()
            replay_loss.backward()
            extract_grads(model_params, self.replay_grad, self.grad_dims)

            # * Sample a batch of data from replay buffer
            online_encodes = self.online_buffer.sample_batch(
                self.online_bsz,
                self.device,
            )

            # decapsulate batch
            data, target, users, token_len, word_len = online_encodes[:5]
            total_loss = self.para_model(data, target, users, *mems)[0]

            online_psl = total_loss.sum(
                dim=0, keepdim=True) / token_len.type_as(total_loss)
            online_loss = online_psl.mean()

            # (*) Extract Online Gradients
            self.optimizer.zero_grad()
            online_loss.backward()
            extract_grads(model_params, self.online_grad, self.grad_dims)

            online_replay_dp = torch.mm(self.online_grad.view(1, -1),
                                        self.replay_grad.view(-1, 1))
            if online_replay_dp.item() < 0:
                # constraint violation, next batch
                # reproject gradient when the constraint is violated
                replay_replay_dp = torch.mm(self.replay_grad.view(1, -1),
                                            self.replay_grad.view(-1, 1))
                self.online_grad.copy_(
                    self.online_grad -
                    (online_replay_dp.item() / replay_replay_dp.item()) * self.replay_grad)
                overwrite_grad(model_params, self.online_grad, self.grad_dims)

            # clip gradient norm
            clip_grad_norm_(self.model.parameters(), self.clip_value)
            self.optimizer.step()

            # * Update global stats
            stats['num_updates'] += 1
            stats['total_loss'] += total_loss.sum().float().item()
            stats['total_word_count'] += word_len.sum().float().item()
            stats['total_token_count'] += token_len.sum().float().item()
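
The projection in agem.py is the standard A-GEM rule: when the online gradient g conflicts with the replay (reference) gradient g_ref, i.e. g·g_ref < 0, g is replaced by g − (g·g_ref / g_ref·g_ref)·g_ref so that the constraint becomes non-negative. A self-contained sketch of that projection on flattened gradient vectors (function name is illustrative):

import torch

def agem_project(g, g_ref):
    """Project g so that it no longer increases the replay loss (A-GEM rule)."""
    dot = torch.dot(g, g_ref)
    if dot.item() < 0:
        g = g - (dot / torch.dot(g_ref, g_ref)) * g_ref
    return g

# Toy check: after projection the constraint g·g_ref >= 0 holds (up to numerics).
g = torch.tensor([1.0, -2.0, 0.5])
g_ref = torch.tensor([0.5, 1.0, 1.0])
print(torch.dot(agem_project(g, g_ref), g_ref).item())  # ~0.0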
Example #9
    def forward(self, online_batch, stats, skip_optim=False):
        # Update buffer: add online data stream into replay buffer
        retired_batch = self.online_buffer.add_batch(*online_batch)
        if retired_batch is not None:
            self.replay_buffer.add_batch(*retired_batch)

        if skip_optim: return

        # Start training only after replay buffer has data
        if self.replay_buffer.sample_batch(self.replay_bsz,
                                           self.device) is None:
            return

        mems = tuple()
        if not self.eval_online_fit:
            stats['online_loss'] += 0
            stats['online_word_count'] += 1e-20
            stats['online_token_count'] += 1e-20
        else:
            # Evaluate online batch.
            total_loss, token_len, word_len = self.forward_eval(
                encap_batch(online_batch,
                            max([len(_) for _ in online_batch[0]]),
                            self.device), )

            stats['online_loss'] += total_loss.float().sum().item()
            stats['online_word_count'] += word_len.float().sum().item()
            stats['online_token_count'] += token_len.float().sum().item()

        # Sample a cached (validation) batch for model selection
        online_encodes = self.online_buffer.sample_batch(
            self.online_bsz,
            self.device,
        )

        # Snapshot Current Model
        if self.allow_zero_step:
            val_loss, _, _ = self.forward_eval(online_encodes)
            self.take_snapshot(0, val_loss.sum().item())

        retired_bsz = self.replay_bsz // 2
        # * Optimize for k step on current dataset
        for k in range(1, self.max_k_steps + 1):
            if len(retired_batch[0]):
                selected_data = np.random.permutation(len(
                    retired_batch[0]))[:retired_bsz].tolist()
                _retired_batch = [[_rbdata[selected_id] for selected_id in selected_data]
                                  for _rbdata in retired_batch]
                retired_encodes = encap_batch(
                    _retired_batch, max([len(_) for _ in _retired_batch[0]]),
                    self.device)

                # * Sample a batch of data from replay buffer
                replay_encodes = self.replay_buffer.sample_batch(
                    self.replay_bsz - retired_encodes[0].size(1),
                    self.device,
                    retired_encodes[0].size(0),
                )

                batch_data = []
                for online, replay in zip(
                        retired_encodes,
                        replay_encodes[:len(retired_encodes)]):
                    batch_data.append(torch.cat([online, replay], dim=1))
            else:
                # print('Triggered')
                # * Sample a batch of data from replay buffer
                batch_data = self.replay_buffer.sample_batch(
                    self.replay_bsz,
                    self.device,
                )

            # decapsulate batch
            data, target, users, token_len, word_len = batch_data[:5]
            raw_loss = self.para_model(data, target, users, *mems)[0]
            replay_psl = raw_loss.sum(dim=0, keepdim=True) \
                       / token_len.type_as(raw_loss)

            replay_loss = replay_psl.mean()

            self.optimizer.zero_grad()
            replay_loss.backward()
            clip_grad_norm_(self.model.parameters(), self.clip_value)
            self.optimizer.step()

            # Take snapshot
            val_loss, _, _ = self.forward_eval(online_encodes)
            self.take_snapshot(k, val_loss.sum().item())

            # * Update global stats
            stats['num_updates'] += 1
            stats['total_loss'] += raw_loss.sum().float().item()
            stats['total_word_count'] += word_len.sum().float().item()
            stats['total_token_count'] += token_len.sum().float().item()

        # choose best snapshot based on current online measure
        best_k = self.resume_snapshot()
Example #10
File: mr.py Project: marcelomata/congrad
    def forward(self, online_batch, stats, skip_optim=False):
        mems = tuple()
        if not self.eval_online_fit:
            stats['online_loss'] += 0
            stats['online_word_count'] += 1e-20
            stats['online_token_count'] += 1e-20
        else:
            # Evaluate online batch.
            total_loss, token_len, word_len = self.forward_eval(
                encap_batch(online_batch,
                            max([len(_) for _ in online_batch[0]]),
                            self.device), )

            stats['online_loss'] += total_loss.float().sum().item()
            stats['online_word_count'] += word_len.float().sum().item()
            stats['online_token_count'] += token_len.float().sum().item()

        # Update buffer: add online data stream into replay buffer
        retired_batch = self.online_buffer.add_batch(*online_batch)
        if retired_batch is not None:
            self.replay_buffer.add_batch(*retired_batch)

        # Start training only after replay buffer has data
        if self.replay_buffer.sample_batch(self.replay_bsz,
                                           self.device) is None:
            return

        # * Optimize for k step on current dataset
        for k in range(self.max_k_steps):
            # * Sample a batch of data from replay buffer
            online_encodes = self.online_buffer.sample_batch(
                self.online_bsz,
                self.device,
            )
            replay_encodes = self.replay_buffer.sample_batch(
                self.replay_bsz,
                self.device,
                online_encodes[0].size(0),
            )

            # * Use the union of online and replay data for a mini-batch
            batch_data = []
            for online, replay in zip(online_encodes,
                                      replay_encodes[:len(online_encodes)]):
                batch_data.append(torch.cat([online, replay], dim=1))

            # reset optimizer gradients
            self.optimizer.zero_grad()

            # decapsulate batch
            data, target, users, token_len, word_len, weights, _ = batch_data
            raw_loss = self.para_model(data, target, users, *mems)[0]
            per_sample_loss = raw_loss.float().sum(dim=0, keepdim=True) \
                            / token_len.type_as(raw_loss)

            total_loss = (per_sample_loss * weights).mean()
            total_loss.backward()
            clip_grad_norm_(self.model.parameters(), self.clip_value)

            self.optimizer.step()

            # * Postprocess: prioritized experience replay sub-routines
            if isinstance(self.replay_buffer, PrioritizedBuffer):
                indices = replay_encodes[-1].cpu().numpy()
                priorities = per_sample_loss[:, self.online_bsz:].detach().cpu().numpy() + 1e-6
                self.replay_buffer.update_priorities(indices, priorities)

            # * Update global stats
            stats['num_updates'] += 1
            stats['total_loss'] += raw_loss.float().sum().item()
            stats['total_word_count'] += word_len.float().sum().item()
            stats['total_token_count'] += token_len.float().sum().item()
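
Example #10 feeds the per-sample replay losses (plus a small epsilon) back into update_priorities, which is the usual prioritized-experience-replay loop. The project's PrioritizedBuffer internals are not shown on this page; the function below only sketches the standard proportional sampling scheme from Schaul et al. (2016) under assumed names, and the actual buffer may differ:

import numpy as np

def sample_proportional(priorities, batch_size, alpha=0.6, beta=0.4, rng=None):
    """Illustrative proportional PER sampling: returns (indices, importance_weights)."""
    rng = rng or np.random.default_rng()
    probs = priorities ** alpha
    probs = probs / probs.sum()
    indices = rng.choice(len(priorities), size=batch_size, p=probs)
    weights = (len(priorities) * probs[indices]) ** (-beta)
    weights = weights / weights.max()          # normalize as in Schaul et al. (2016)
    return indices, weights

# Toy usage: higher-loss samples are drawn more often.
prio = np.array([0.1, 0.5, 2.0, 0.05])
idx, w = sample_proportional(prio, batch_size=3)
print(idx, w)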