def do_train2(self):
    """Train from a pre-built batch cache for ``self.epochs`` epochs.

    Builds ``self.batch_cache`` on first call. Every ``checkpoint - 1``
    epochs (skipping epoch 0) runs validation, which may request an early
    stop. Records the mean loss per epoch and flushes metrics to file.

    Returns:
        True when all epochs complete; None on an early validation stop.
    """
    if not hasattr(self, 'batch_cache'):
        self.build_cache_for_train2()
    for epoch in range(self.epochs):
        self.log.critical('memory consumed : {}'.format(memory_consumed()))
        self.epoch = epoch
        # `epoch and ...` skips validating the untrained model at epoch 0
        if epoch and epoch % max(1, (self.checkpoint - 1)) == 0:
            if self.do_validate() == FLAGS.STOP_TRAINING:
                self.log.info('loss trend suggests to stop training')
                return
        self.train()
        losses = []
        for input_ in tqdm(self.batch_cache,
                           desc='Trainer.{}'.format(self.name())):
            self.optimizer.zero_grad()
            idxs, word, targets = input_
            output = self.__(self.forward(word), 'output')
            loss = self.loss_function(output, targets)
            # detach before caching: holding the live loss tensor would
            # retain every batch's autograd graph for the whole epoch,
            # growing memory with the number of batches
            losses.append(loss.detach())
            loss.backward()
            self.optimizer.step()
        epoch_loss = torch.stack(losses).mean()
        self.train_loss.append(epoch_loss.data.item())
        self.log.info('-- {} -- loss: {}\n'.format(epoch, epoch_loss))
        for m in self.metrics:
            m.write_to_file()
    return True
def do_train(self):
    """Run the main training loop over ``self.train_feed``.

    Validates every ``checkpoint - 1`` epochs (including epoch 0); the
    validator may signal an early stop. The mean batch loss of each epoch
    is appended to ``self.train_loss`` and metrics are written out.

    Returns:
        True after the final epoch; None when validation stops training.
    """
    for epoch in range(self.epochs):
        self.log.critical('memory consumed : {}'.format(memory_consumed()))
        self.epoch = epoch

        if not epoch % max(1, (self.checkpoint - 1)):
            if self.do_validate() == FLAGS.STOP_TRAINING:
                self.log.info('loss trend suggests to stop training')
                return

        self.train()
        batch_losses = []
        progress = tqdm(range(self.train_feed.num_batch),
                        desc='Trainer.{}'.format(self.name()))
        for _ in progress:
            self.optimizer.zero_grad()
            batch = self.train_feed.next_batch()
            idxs, inputs, targets = batch
            # the model consumes the whole batch tuple; the loss function
            # likewise receives the batch alongside the prediction
            prediction = self.forward(batch)
            batch_loss = self.loss_function(prediction, batch)
            batch_losses.append(batch_loss)
            batch_loss.backward()
            self.optimizer.step()

        mean_loss = torch.stack(batch_losses).mean()
        self.train_loss.append(mean_loss.data.item())
        self.log.info('-- {} -- loss: {}\n'.format(epoch, mean_loss))

        for m in self.metrics:
            m.write_to_file()

    return True
def do_train(self):
    """Teacher-forced training loop for a step-wise sequence model.

    Each batch is decoded token by token; at every step a coin flip
    decides whether the model's own prediction or the ground-truth token
    feeds the next step (50/50 teacher forcing). Per-batch losses are
    cached in ``self.train_loss`` and the cache is cleared each epoch.

    Returns:
        True after all epochs; None when validation signals a stop.
    """
    for epoch in range(self.epochs):
        self.log.critical('memory consumed : {}'.format(memory_consumed()))
        self.epoch = epoch
        # validates at epoch 0 as well (no `epoch and` guard here)
        if epoch % max(1, (self.checkpoint - 1)) == 0:
            if self.do_validate() == FLAGS.STOP_TRAINING:
                self.log.info('loss trend suggests to stop training')
                return
        self.train()
        # [model-fed steps, ground-truth-fed steps]
        teacher_force_count = [0, 0]
        for j in tqdm(range(self.train_feed.num_batch),
                      desc='Trainer.{}'.format(self.name())):
            self.optimizer.zero_grad()
            input_ = self.train_feed.next_batch()
            idxs, inputs, targets = input_
            # assumes inputs[0] is (batch, seq); transpose to (seq, batch)
            # — TODO confirm against the feed
            sequence = inputs[0].transpose(0, 1)
            _, batch_size = sequence.size()
            state = self.initial_hidden(batch_size)
            loss = 0
            output = sequence[0]
            for ti in range(1, sequence.size(0) - 1):
                # forward returns a (logits, state) pair; the loss is
                # computed on the pair before unpacking — presumably
                # loss_function indexes into it; verify against its impl
                output = self.forward(output, state)
                loss += self.loss_function(ti, output, input_)
                output, state = output
                if random.random() > 0.5:
                    # free-running: feed back the argmax prediction
                    output = output.max(1)[1]
                    teacher_force_count[0] += 1
                else:
                    # teacher forcing feeds sequence[ti + 1] — looks
                    # one-ahead of step ti; NOTE(review): confirm this
                    # offset matches loss_function's target indexing
                    output = sequence[ti + 1]
                    teacher_force_count[1] += 1
            loss.backward()
            self.train_loss.cache(loss.data.item())
            self.optimizer.step()
        self.log.info(
            'teacher_force_count: {}'.format(teacher_force_count))
        self.log.info('-- {} -- loss: {}\n'.format(
            epoch, self.train_loss.epoch_cache))
        self.train_loss.clear_cache()
        for m in self.metrics:
            m.write_to_file()
    return True
def do_train(self):
    """Train a free-running (greedy-decode) sequence model.

    Decodes each sequence step by step, always feeding the argmax of the
    previous step's logits back into the model, accumulating the loss
    against the shifted targets. Validates every ``checkpoint - 1``
    epochs (skipping epoch 0).

    Returns:
        True after all epochs; None when validation stops training.
    """
    for epoch in range(self.epochs):
        self.log.critical('memory consumed : {}'.format(memory_consumed()))
        self.epoch = epoch
        if epoch and epoch % max(1, (self.checkpoint - 1)) == 0:
            if self.do_validate() == FLAGS.STOP_TRAINING:
                self.log.info('loss trend suggests to stop training')
                return
        self.train()
        losses = []
        for j in tqdm(range(self.train_feed.num_batch),
                      desc='Trainer.{}'.format(self.name())):
            self.optimizer.zero_grad()
            input_ = self.train_feed.next_batch()
            idxs, seq, targets = input_
            seq_size, batch_size = seq.size()
            loss = 0
            output = self.__(seq[0], 'output')
            state = self.__(self.init_hidden(batch_size), 'init_hidden')
            for index in range(seq_size - 1):
                output, state = self.__(self.forward(output, state),
                                        'output, state')
                # predict step index+1 from the output of step index
                loss += self.loss_function(output, targets[index + 1])
                # greedy: next input is this step's argmax prediction
                output = self.__(output.max(1)[1], 'output')
            # detach: caching the live loss would retain each batch's
            # autograd graph for the whole epoch
            losses.append(loss.detach())
            loss.backward()
            self.optimizer.step()
        epoch_loss = torch.stack(losses).mean()
        self.train_loss.append(epoch_loss.data.item())
        self.log.info('-- {} -- loss: {}\n'.format(epoch, epoch_loss))
        for m in self.metrics:
            m.write_to_file()
    return True
def do_train(self):
    """Train an attention-style encoder/decoder with memory diagnostics.

    Encodes each word, attends over ``self.keys``/``self.values`` via
    batched matrix products to build the decoder's initial state, then
    decodes greedily against the targets. tracemalloc snapshots are
    dumped every 100000 batches to hunt memory growth.

    Returns:
        True after all epochs; None when validation stops training.
    """
    for epoch in range(self.epochs):
        self.log.critical('memory consumed : {}'.format(memory_consumed()))
        self.epoch = epoch
        if epoch and epoch % max(1, (self.checkpoint - 1)) == 0:
            if self.do_validate() == FLAGS.STOP_TRAINING:
                self.log.info('loss trend suggests to stop training')
                return
        self.train()
        losses = []
        # start tracing only once; re-starting every epoch is redundant
        if not tracemalloc.is_tracing():
            tracemalloc.start()
        for j in tqdm(range(self.train_feed.num_batch),
                      desc='Trainer.{}'.format(self.name())):
            self.optimizer.zero_grad()
            input_ = self.train_feed.next_batch()
            idxs, word, targets = input_
            loss = 0
            encoded_info = self.__(self.encode(word), 'encoded_info')
            # broadcast the shared key/value banks across the batch
            keys = self.__(self.keys.transpose(0, 1), 'keys')
            keys = self.__(
                keys.expand([encoded_info.size(0), *keys.size()]), 'keys')
            inner_product = self.__(
                torch.bmm(
                    encoded_info.unsqueeze(1),  # final encoder state
                    keys),
                'inner_product')
            values = self.__(self.values, 'values')
            values = self.__(
                values.expand([inner_product.size(0), *values.size()]),
                'values')
            weighted_sum = self.__(torch.bmm(inner_product, values),
                                   'weighted_sum')
            weighted_sum = self.__(weighted_sum.squeeze(1), 'weighted_sum')
            # TODO: make the same change in do_[predict|validate]
            tseq_len, batch_size = targets.size()
            state = self.__(
                (weighted_sum, self.init_hidden(batch_size).squeeze(0)),
                'decoder initial state')
            prev_output = self.__(
                self.sos_token.expand([encoded_info.size(0)]), 'sos_token')
            for i in range(targets.size(0)):
                output = self.decode(prev_output, state)
                loss += self.loss_function(output, targets[i])
                # greedy decode: feed back the argmax token ids
                prev_output = output.max(1)[1].long()
            # detach: caching live losses retains every batch's autograd
            # graph across the epoch — the very leak this function's
            # tracemalloc instrumentation was chasing
            losses.append(loss.detach())
            loss.backward()
            self.optimizer.step()
            del input_
            if j and not j % 100000:
                malloc_snap = tracemalloc.take_snapshot()
                display_tracemalloc_top(malloc_snap, limit=100)
        epoch_loss = torch.stack(losses).mean()
        self.train_loss.append(epoch_loss.data.item())
        self.log.info('-- {} -- loss: {}\n'.format(epoch, epoch_loss))
        for m in self.metrics:
            m.write_to_file()
    return True
def do_train(self):
    """Scheduled-teacher-forcing training over pretrain and train feeds.

    Starts with ``teacher_forcing_ratio = 1`` (always feed ground truth)
    and decays it slightly each epoch. Each checkpoint also runs a
    randomized beam-search prediction before validating.

    Returns:
        True after all epochs; None when validation signals a stop.
    """
    self.teacher_forcing_ratio = 1
    for epoch in range(self.epochs):
        self.log.critical('memory consumed : {}'.format(memory_consumed()))
        self.epoch = epoch
        if epoch % max(1, (self.checkpoint - 1)) == 0:
            # sample a random decode length / beam width per checkpoint
            length = random.randint(5, 10)
            beam_width = random.randint(5, 50)
            self.do_predict(length=length, beam_width=beam_width)
            if self.do_validate() == FLAGS.STOP_TRAINING:
                self.log.info('loss trend suggests to stop training')
                return
        self.train()
        # [model-fed steps, ground-truth-fed steps], shared by closure
        teacher_force_count = [0, 0]

        def train_on_feed(feed):
            # One full pass over `feed`; returns the mean batch loss.
            losses = []
            feed.reset_offset()
            for j in tqdm(range(feed.num_batch),
                          desc='Trainer.{}'.format(self.name())):
                self.optimizer.zero_grad()
                input_ = feed.next_batch()
                idxs, (gender, sequence), targets = input_
                sequence = sequence.transpose(0, 1)
                seq_size, batch_size = sequence.size()
                state = self.initial_hidden(batch_size)
                loss = 0
                output = sequence[0]
                # normalized position signal in [0, 1] per timestep
                positions = LongVar(self.config, np.linspace(0, 1, seq_size))
                for ti in range(1, sequence.size(0) - 1):
                    # forward returns a (logits, state) pair; the loss is
                    # computed before unpacking — presumably loss_function
                    # indexes into the pair; verify against its impl
                    output = self.forward(gender, positions[ti], output, state)
                    loss += self.loss_function(ti, output, input_)
                    output, state = output
                    if random.random() > self.teacher_forcing_ratio:
                        output = output.max(1)[1]
                        teacher_force_count[0] += 1
                    else:
                        # NOTE(review): feeds sequence[ti+1], one ahead of
                        # step ti — confirm the offset is intended
                        output = sequence[ti+1]
                        teacher_force_count[1] += 1
                losses.append(loss)
                loss.backward()
                self.optimizer.step()
            return torch.stack(losses).mean()

        for i in range(config.HPCONFIG.pretrain_count):
            loss = train_on_feed(self.pretrain_feed)
        for i in range(config.HPCONFIG.train_count):
            loss = train_on_feed(self.train_feed)
        # decays by only 0.1 in total across all epochs — presumably
        # intentional (very slow schedule); confirm with the author
        self.teacher_forcing_ratio -= 0.1/self.epochs
        # NOTE(review): `loss` holds only the last feed pass's mean; if
        # both counts are 0 this raises NameError
        self.train_loss.append(loss.data.item())
        self.log.info('teacher_force_count: {}'.format(teacher_force_count))
        self.log.info('-- {} -- loss: {}\n'.format(epoch, self.train_loss))
        for m in self.metrics:
            m.write_to_file()
    return True
def do_train(self):
    """Train a gender-conditioned encoder/decoder over both feeds.

    Each epoch runs ``pretrain_count`` passes over the pretrain feed and
    ``train_count`` passes over the train feed; only the last pass's mean
    loss is recorded. Checkpoints run prediction and validation, which
    may stop training early.

    Returns:
        True after all epochs; None when validation signals a stop.
    """
    for epoch in range(self.epochs):
        self.log.critical('memory consumed : {}'.format(memory_consumed()))
        self.epoch = epoch
        if epoch % max(1, (self.checkpoint - 1)) == 0:
            self.do_predict()
            if self.do_validate() == FLAGS.STOP_TRAINING:
                self.log.info('loss trend suggests to stop training')
                return
        self.train()

        def train_on_feed(feed):
            # One full pass over `feed`; returns the mean batch loss.
            losses = []
            feed.reset_offset()
            for j in tqdm(range(feed.num_batch),
                          desc='Trainer.{}'.format(self.name())):
                self.optimizer.zero_grad()
                input_ = feed.next_batch()
                idxs, (gender, seq), target = input_
                seq_size, batch_size = seq.size()
                # computed but unused below — candidate for removal
                pad_mask = (seq > 0).float()
                hidden_states, (hidden, cell_state) = self.__(
                    self.encode_sequence(seq), 'encoded_outpus')
                loss = 0
                outputs = []
                target_size, batch_size = target.size()
                # TODO: target[0] should not be used; will throw an error
                # when used without a GO token from the batch pipeline
                output = self.__(target[0], 'hidden')
                state = self.__((hidden, cell_state), 'init_hidden')
                gender_embedding = self.gender_embed(gender)
                for index in range(target_size - 1):
                    output, state = self.__(
                        self.decode(hidden_states, output, state,
                                    gender_embedding),
                        'output, state')
                    # predict step index+1 from the output of step index
                    loss += self.loss_function(output, target[index + 1])
                    # greedy: next decoder input is the argmax prediction
                    output = self.__(output.max(1)[1], 'output')
                    outputs.append(output)
                losses.append(loss)
                loss.backward()
                self.optimizer.step()
            return torch.stack(losses).mean()

        for i in range(config.HPCONFIG.pretrain_count):
            loss = train_on_feed(self.pretrain_feed)
        for i in range(config.HPCONFIG.train_count):
            loss = train_on_feed(self.train_feed)
        # NOTE(review): `loss` is only the last pass's mean; NameError if
        # both counts are 0
        self.log.info('-- {} -- loss: {}\n'.format(epoch, loss.item()))
        self.train_loss.append(loss.data.item())
        self.log.info('-- {} -- best loss: {}\n'.format(
            epoch, self.best_model[0]))
        for m in self.metrics:
            m.write_to_file()
    return True