Code example #1
0
File: skipgram.py  Project: vanangamudi/tamil-lm2
    def do_train2(self):
        """Run the training loop over the pre-built batch cache.

        Builds the batch cache lazily on first call, then trains for
        ``self.epochs`` epochs.  Every (checkpoint - 1) epochs (skipping
        epoch 0) it validates and returns early (None) when validation
        reports FLAGS.STOP_TRAINING.

        Returns:
            True when all epochs complete; None on early stop.
        """
        if not hasattr(self, 'batch_cache'):
            self.build_cache_for_train2()

        for epoch in range(self.epochs):
            self.log.critical('memory consumed : {}'.format(memory_consumed()))
            self.epoch = epoch
            # Validate periodically, but never on epoch 0.
            if epoch and epoch % max(1, (self.checkpoint - 1)) == 0:
                #self.do_predict()
                if self.do_validate() == FLAGS.STOP_TRAINING:
                    self.log.info('loss trend suggests to stop training')
                    return

            self.train()
            losses = []
            for input_ in tqdm(self.batch_cache, desc='Trainer.{}'.format(self.name())):
                self.optimizer.zero_grad()
                idxs, word, targets = input_
                output = self.__(self.forward(word), 'output')
                loss   = self.loss_function(output, targets)

                # Detach before caching: appending the live tensor would keep
                # every batch's autograd graph alive for the whole epoch.
                losses.append(loss.detach())
                loss.backward()
                self.optimizer.step()

            epoch_loss = torch.stack(losses).mean()
            self.train_loss.append(epoch_loss.data.item())

            self.log.info('-- {} -- loss: {}\n'.format(epoch, epoch_loss))
            for m in self.metrics:
                m.write_to_file()

        return True
Code example #2
0
    def do_train(self):
        """Run the basic epoch/batch training loop over the train feed.

        Validates every (checkpoint - 1) epochs (including epoch 0) and
        returns early (None) when validation reports FLAGS.STOP_TRAINING.

        Returns:
            True when all epochs complete; None on early stop.
        """
        for epoch in range(self.epochs):
            self.log.critical('memory consumed : {}'.format(memory_consumed()))
            self.epoch = epoch
            if epoch % max(1, (self.checkpoint - 1)) == 0:
                #self.do_predict()
                if self.do_validate() == FLAGS.STOP_TRAINING:
                    self.log.info('loss trend suggests to stop training')
                    return

            self.train()
            losses = []
            for j in tqdm(range(self.train_feed.num_batch),
                          desc='Trainer.{}'.format(self.name())):
                self.optimizer.zero_grad()
                input_ = self.train_feed.next_batch()
                idxs, inputs, targets = input_

                output = self.forward(input_)
                # NOTE(review): loss is computed against the whole input_
                # tuple rather than `targets` — presumably loss_function
                # unpacks it itself; confirm against its definition.
                loss = self.loss_function(output, input_)
                #print(loss.data.cpu().numpy())
                # Detach before caching: appending the live tensor would keep
                # every batch's autograd graph alive for the whole epoch.
                losses.append(loss.detach())
                loss.backward()
                self.optimizer.step()

            epoch_loss = torch.stack(losses).mean()
            self.train_loss.append(epoch_loss.data.item())

            self.log.info('-- {} -- loss: {}\n'.format(epoch, epoch_loss))
            for m in self.metrics:
                m.write_to_file()

        return True
Code example #3
0
File: lm.py  Project: vanangamudi/tamil-lm
    def do_train(self):
        """Train the language model with 50/50 scheduled teacher forcing.

        Runs for ``self.epochs`` epochs; every (checkpoint - 1) epochs it
        validates and returns early (None) when validation reports
        FLAGS.STOP_TRAINING.

        Returns:
            True when all epochs complete; None on early stop.
        """
        for epoch in range(self.epochs):
            self.log.critical('memory consumed : {}'.format(memory_consumed()))
            self.epoch = epoch
            if epoch % max(1, (self.checkpoint - 1)) == 0:
                if self.do_validate() == FLAGS.STOP_TRAINING:
                    self.log.info('loss trend suggests to stop training')
                    return

            self.train()
            # Counts of [model-fed, ground-truth-fed] decoding steps.
            teacher_force_count = [0, 0]
            for j in tqdm(range(self.train_feed.num_batch),
                          desc='Trainer.{}'.format(self.name())):
                self.optimizer.zero_grad()
                input_ = self.train_feed.next_batch()
                idxs, inputs, targets = input_
                # assumes inputs[0] is (batch, seq); transpose to (seq, batch)
                # so sequence[t] is one timestep — TODO confirm against feed.
                sequence = inputs[0].transpose(0, 1)
                _, batch_size = sequence.size()

                state = self.initial_hidden(batch_size)
                loss = 0
                output = sequence[0]
                for ti in range(1, sequence.size(0) - 1):
                    output = self.forward(output, state)
                    # NOTE(review): loss is computed on the raw forward()
                    # result BEFORE it is unpacked below — presumably
                    # loss_function expects the (output, state) pair; confirm.
                    loss += self.loss_function(ti, output, input_)
                    output, state = output

                    # 50/50 scheduled sampling: feed back either the model's
                    # argmax prediction or the ground-truth next token.
                    if random.random() > 0.5:
                        output = output.max(1)[1]
                        teacher_force_count[0] += 1
                    else:
                        output = sequence[ti + 1]
                        teacher_force_count[1] += 1

                loss.backward()
                # Only the scalar loss value is cached, so no autograd graphs
                # accumulate across batches here.
                self.train_loss.cache(loss.data.item())
                self.optimizer.step()

            self.log.info(
                'teacher_force_count: {}'.format(teacher_force_count))

            self.log.info('-- {} -- loss: {}\n'.format(
                epoch, self.train_loss.epoch_cache))
            self.train_loss.clear_cache()

            for m in self.metrics:
                m.write_to_file()

        return True
Code example #4
0
    def do_train(self):
        """Train the sequence decoder, feeding back its own predictions.

        Decodes each batch token-by-token from ``seq[0]``, scoring each step
        against ``targets``.  Validates every (checkpoint - 1) epochs
        (skipping epoch 0) and stops early on FLAGS.STOP_TRAINING.

        Returns:
            True when all epochs complete; None on early stop.
        """
        for epoch in range(self.epochs):
            self.log.critical('memory consumed : {}'.format(memory_consumed()))
            self.epoch = epoch
            if epoch and epoch % max(1, (self.checkpoint - 1)) == 0:
                #self.do_predict()
                if self.do_validate() == FLAGS.STOP_TRAINING:
                    self.log.info('loss trend suggests to stop training')
                    return

            self.train()
            losses = []
            for j in tqdm(range(self.train_feed.num_batch),
                          desc='Trainer.{}'.format(self.name())):
                self.optimizer.zero_grad()
                input_ = self.train_feed.next_batch()
                idxs, seq, targets = input_

                seq_size, batch_size = seq.size()

                loss = 0
                outputs = []
                output = self.__(seq[0], 'output')
                state = self.__(self.init_hidden(batch_size), 'init_hidden')
                for index in range(seq_size - 1):
                    output, state = self.__(self.forward(output, state),
                                            'output, state')
                    loss += self.loss_function(output, targets[index + 1])
                    # Feed the argmax prediction back as the next input.
                    output = self.__(output.max(1)[1], 'output')
                    outputs.append(output)

                # Detach before caching: appending the live tensor would keep
                # every batch's autograd graph alive for the whole epoch.
                losses.append(loss.detach())
                loss.backward()
                self.optimizer.step()

            epoch_loss = torch.stack(losses).mean()
            self.train_loss.append(epoch_loss.data.item())

            self.log.info('-- {} -- loss: {}\n'.format(epoch, epoch_loss))
            for m in self.metrics:
                m.write_to_file()

        return True
Code example #5
0
    def do_train(self):
        """Train the key-value attention encoder/decoder model.

        Each batch: encode the input word, attend over the learned key/value
        memory (inner product with keys, weighted sum of values), then decode
        greedily against ``targets``.  Validates every (checkpoint - 1)
        epochs (skipping epoch 0), stops early on FLAGS.STOP_TRAINING, and
        dumps tracemalloc snapshots periodically to hunt memory growth.

        Returns:
            True when all epochs complete; None on early stop.
        """
        for epoch in range(self.epochs):

            self.log.critical('memory consumed : {}'.format(memory_consumed()))
            self.epoch = epoch
            if epoch and epoch % max(1, (self.checkpoint - 1)) == 0:
                #self.do_predict()
                if self.do_validate() == FLAGS.STOP_TRAINING:
                    self.log.info('loss trend suggests to stop training')
                    return

            self.train()
            losses = []
            # NOTE(review): start() is re-invoked every epoch; harmless when
            # already tracing, but could be hoisted above the epoch loop.
            tracemalloc.start()
            for j in tqdm(range(self.train_feed.num_batch),
                          desc='Trainer.{}'.format(self.name())):
                self.optimizer.zero_grad()
                input_ = self.train_feed.next_batch()
                idxs, word, targets = input_

                loss = 0
                encoded_info = self.__(self.encode(word), 'encoded_info')

                # Attention scores: batched inner product of the encoded
                # state against every key, broadcast over the batch.
                keys = self.__(self.keys.transpose(0, 1), 'keys')
                keys = self.__(
                    keys.expand([encoded_info.size(0), *keys.size()]), 'keys')
                inner_product = self.__(
                    torch.bmm(
                        encoded_info.unsqueeze(1),  #final state
                        keys),
                    'inner_product')

                values = self.__(self.values, 'values')
                values = self.__(
                    values.expand([inner_product.size(0), *values.size()]),
                    'values')

                weighted_sum = self.__(torch.bmm(inner_product, values),
                                       'weighted_sum')
                weighted_sum = self.__(weighted_sum.squeeze(1), 'weighted_sum')

                #make the same chane in do_[predict|validate]
                tseq_len, batch_size = targets.size()
                state = self.__(
                    (weighted_sum, self.init_hidden(batch_size).squeeze(0)),
                    'decoder initial state')
                #state = self.__( (encoded_info, state[1].squeeze(0)), 'decoder initial state')
                prev_output = self.__(
                    self.sos_token.expand([encoded_info.size(0)]), 'sos_token')

                # Greedy decode: always feed back the argmax token.
                for i in range(targets.size(0)):
                    output = self.decode(prev_output, state)
                    loss += self.loss_function(output, targets[i])
                    prev_output = output.max(1)[1].long()

                # Detach before caching: appending the live tensor would keep
                # every batch's autograd graph alive for the whole epoch.
                losses.append(loss.detach())
                loss.backward()
                self.optimizer.step()

                del input_  #, keys, values

                if j and not j % 100000:
                    malloc_snap = tracemalloc.take_snapshot()
                    display_tracemalloc_top(malloc_snap, limit=100)

            epoch_loss = torch.stack(losses).mean()
            self.train_loss.append(epoch_loss.data.item())

            self.log.info('-- {} -- loss: {}\n'.format(epoch, epoch_loss))
            for m in self.metrics:
                m.write_to_file()

        return True
Code example #6
0
File: lm.py  Project: indicnlp/tamil-name-gen
    def do_train(self):
        """Train with scheduled teacher forcing over pretrain and train feeds.

        The teacher-forcing ratio starts at 1 and decays slightly after each
        pass over the train feed.  Every (checkpoint - 1) epochs a random
        prediction sample is generated and validation may stop training early
        (returns None on FLAGS.STOP_TRAINING).

        Returns:
            True when all epochs complete; None on early stop.
        """
        self.teacher_forcing_ratio = 1
        for epoch in range(self.epochs):

            self.log.critical('memory consumed : {}'.format(memory_consumed()))
            self.epoch = epoch
            if epoch % max(1, (self.checkpoint - 1)) == 0:
                length = random.randint(5, 10)
                beam_width = random.randint(5, 50)
                self.do_predict(length=length, beam_width=beam_width)
                if self.do_validate() == FLAGS.STOP_TRAINING:
                    self.log.info('loss trend suggests to stop training')
                    return

            self.train()
            # Counts of [model-fed, ground-truth-fed] decoding steps.
            teacher_force_count = [0, 0]

            def train_on_feed(feed):
                # One full pass over `feed`; returns the mean batch loss.
                losses = []
                feed.reset_offset()
                for j in tqdm(range(feed.num_batch), desc='Trainer.{}'.format(self.name())):
                    self.optimizer.zero_grad()
                    input_ = feed.next_batch()
                    idxs, (gender, sequence), targets = input_
                    sequence = sequence.transpose(0,1)
                    seq_size, batch_size = sequence.size()

                    state = self.initial_hidden(batch_size)
                    loss = 0
                    output = sequence[0]
                    positions = LongVar(self.config, np.linspace(0, 1, seq_size))
                    for ti in range(1, sequence.size(0) - 1):
                        output = self.forward(gender, positions[ti], output, state)
                        # NOTE(review): loss is computed on the raw forward()
                        # result BEFORE it is unpacked below — presumably
                        # loss_function expects the (output, state) pair.
                        loss += self.loss_function(ti, output, input_)
                        output, state = output

                        if random.random() > self.teacher_forcing_ratio:
                            output = output.max(1)[1]
                            teacher_force_count[0] += 1
                        else:
                            output = sequence[ti+1]
                            teacher_force_count[1] += 1

                    # Detach before caching: keeping the live loss tensor
                    # would retain every batch's autograd graph for the pass.
                    losses.append(loss.detach())
                    loss.backward()
                    self.optimizer.step()

                return torch.stack(losses).mean()

            for i in range(config.HPCONFIG.pretrain_count):
                loss = train_on_feed(self.pretrain_feed)

            for i in range(config.HPCONFIG.train_count):
                loss = train_on_feed(self.train_feed)
                self.teacher_forcing_ratio -= 0.1/self.epochs

            # NOTE(review): if both pretrain_count and train_count are 0 this
            # raises NameError on `loss` — confirm configs guarantee a pass.
            self.train_loss.append(loss.data.item())
            self.log.info('teacher_force_count: {}'.format(teacher_force_count))

            self.log.info('-- {} -- loss: {}\n'.format(epoch, self.train_loss))

            for m in self.metrics:
                m.write_to_file()

        return True
Code example #7
0
    def do_train(self):
        """Train the gender-conditioned seq2seq generator.

        Each epoch: predict + validate (early-stop on FLAGS.STOP_TRAINING),
        then run HPCONFIG.pretrain_count passes over the pretrain feed and
        HPCONFIG.train_count passes over the train feed, teacher-forcing the
        decoder from ``target`` at every step.

        Returns:
            True when all epochs complete; None on early stop.
        """

        for epoch in range(self.epochs):
            self.log.critical('memory consumed : {}'.format(memory_consumed()))
            self.epoch = epoch
            if epoch % max(1, (self.checkpoint - 1)) == 0:
                self.do_predict()
                if self.do_validate() == FLAGS.STOP_TRAINING:
                    self.log.info('loss trend suggests to stop training')
                    return

            self.train()

            def train_on_feed(feed):
                # One full pass over `feed`; returns the mean batch loss.
                losses = []
                feed.reset_offset()

                for j in tqdm(range(feed.num_batch),
                              desc='Trainer.{}'.format(self.name())):
                    self.optimizer.zero_grad()
                    input_ = feed.next_batch()
                    idxs, (gender, seq), target = input_

                    hidden_states, (hidden, cell_state) = self.__(
                        self.encode_sequence(seq), 'encoded_outpus')

                    loss = 0
                    outputs = []
                    target_size, batch_size = target.size()
                    #TODO: target[0] should not be used. will throw error when used without GO token from batchip
                    output = self.__(target[0], 'hidden')
                    state = self.__((hidden, cell_state), 'init_hidden')
                    gender_embedding = self.gender_embed(gender)
                    for index in range(target_size - 1):
                        output, state = self.__(
                            self.decode(hidden_states, output, state,
                                        gender_embedding), 'output, state')
                        loss += self.loss_function(output, target[index + 1])
                        # Argmax prediction is recorded but the next input
                        # comes from it as well (fed back below via `output`).
                        output = self.__(output.max(1)[1], 'output')
                        outputs.append(output)

                    # Detach before caching: keeping the live loss tensor
                    # would retain every batch's autograd graph for the pass.
                    losses.append(loss.detach())
                    loss.backward()
                    self.optimizer.step()

                return torch.stack(losses).mean()

            for i in range(config.HPCONFIG.pretrain_count):
                loss = train_on_feed(self.pretrain_feed)

            for i in range(config.HPCONFIG.train_count):
                loss = train_on_feed(self.train_feed)
                self.log.info('-- {} -- loss: {}\n'.format(epoch, loss.item()))

            # NOTE(review): only the last pass's mean loss is recorded; if
            # both feed counts are 0 this raises NameError on `loss`.
            self.train_loss.append(loss.data.item())
            self.log.info('-- {} -- best loss: {}\n'.format(
                epoch, self.best_model[0]))

            for m in self.metrics:
                m.write_to_file()

        return True