示例#1
0
 def _track_per_iteration(self, output, labels, loss):
     """Append one line of per-iteration training statistics to train_iter.txt.

     Args:
         output: Sequence produced by the forward pass for one batch. Its
             last element is an iterable of norms; the remaining elements
             are similarity tensors -- either exactly two, or (presumably)
             four, in which case consecutive pairs (0, 1) and (2, 3) are
             averaged. NOTE(review): layout inferred from the indexing
             below; confirm against the model's forward.
         labels: Unused here; accepted for interface compatibility with
             the caller.
         loss: Scalar loss tensor, or None (then logged as None).

     The mean norm and the two summed similarities are divided by the size
     of the first dimension of the first similarity tensor (presumably the
     batch size -- TODO confirm). The record ``(loss, norm, cross, div)`` is
     appended through ``_write_progress`` to
     ``<results_dir>/train_iter.txt``.
     """
     norms = output[-1]
     sim_outputs = output[:-1]

     # Mean norm as a plain Python float (sum starts at 0, like a manual loop).
     norm = float(sum(norms, 0) / len(norms))

     similarities = np.zeros(2)
     for i in range(2):
         if len(sim_outputs) != 2:
             # Average the pair (2*i, 2*i + 1) of similarity outputs.
             similarity = (sim_outputs[2 * i] + sim_outputs[2 * i + 1]) / 2
         else:
             similarity = sim_outputs[i]
         if i == 0:
             # Normaliser taken from the first similarity tensor's leading dim.
             length = similarity.size()[0]
         similarities[i] += (
             torch.sum(similarity, dim=0).detach().cpu().numpy())

     if loss is not None:
         loss = float(loss.detach().cpu())
     norm /= length
     similarities /= length
     result = [('loss', loss), ('norm', norm), ('cross', similarities[0]),
               ('div', similarities[1])]
     # Same target file (<results_dir>/train_iter.txt) and the same
     # print_iterable formatting the inline open/write/close code used.
     self._write_progress('train_iter', result)
示例#2
0
 def _training(self, split_batch=1):
     """Run one training epoch over ``self.dataloader``.

     For every batch: zero the gradients, run the forward pass, compute and
     back-propagate the loss, step the optimizer, then update the tracker
     and the per-iteration log. Afterwards the tracker result, prefixed with
     the epoch number and the wall-clock runtime, is printed and appended to
     ``train.txt`` via ``_write_progress``.

     Args:
         split_batch: Currently unused (the mini-batch splitting path is
             disabled); kept so existing callers keep working.
     """
     t0 = time()
     self.net.train(mode=True)
     iterator = iter(self.dataloader)
     while True:
         data, labels = self._get_data(iterator)
         if data is None:  # a None batch ends the epoch
             break
         self.optimizer.zero_grad()
         output = self._forward(data)
         loss = self._get_loss(output, labels)
         loss.backward()
         self.optimizer.step()
         self.tracker.update(output, labels, loss)
         self._track_per_iteration(output, labels, loss)
     result = self.tracker.result()
     runtime = time() - t0
     result = [('epoch', self.epoch), ('runtime', runtime)] + result
     print('train ' +
           utils.print_iterable(result, max_digits=self.max_digits))
     self._write_progress('train', result)
示例#3
0
 def _write_progress(self, name, result):
     """Append ``result`` as one line to ``<results_dir>/<name>.txt``.

     The line is produced by ``utils.print_iterable`` with ``delimiter=' '``,
     ``print_keys=False`` and ``self.max_digits``, followed by a newline.
     """
     path = os.path.join(self.results_dir, '%s.txt' % name)
     line = utils.print_iterable(result,
                                 delimiter=' ',
                                 max_digits=self.max_digits,
                                 print_keys=False)
     # `with` closes the handle even if a write raises.
     with open(path, 'a') as f:
         f.write(line)
         f.write('\n')
示例#4
0
 def _testing(self, split_batch=1):
     """Evaluate the network on ``self.dataloader`` and log the result.

     The network is put in eval mode and all batches are processed under
     ``torch.no_grad()``, so no autograd graph is kept during evaluation.
     When ``split_batch != 1`` each batch is split by ``split_data_labels``
     and the first ``split_batch`` data chunks are forwarded separately.
     Their detached outputs are concatenated, and the loss is then computed
     against the full batch labels. The tracker result, prefixed with epoch
     and runtime, is printed between separator lines and appended to
     ``test.txt`` via ``_write_progress``.

     Args:
         split_batch: Number of chunks each batch is forwarded in
             (1 = forward the whole batch at once).
     """
     t0 = time()
     self.net.train(mode=False)
     iterator = iter(self.dataloader)
     with torch.no_grad():
         while True:
             data, labels = self._get_data(iterator)
             if data is None:  # a None batch ends the pass
                 break
             if split_batch != 1:
                 # Only split when needed; the label split is not used
                 # because the loss is computed on the concatenated output.
                 data_split, _ = split_data_labels(data, labels, split_batch)
                 output = torch.cat([self._forward(data_split[i]).detach()
                                     for i in range(split_batch)])
             else:
                 output = self._forward(data)
             loss = self._get_loss(output, labels)
             self.tracker.update(output, labels, loss)
     result = self.tracker.result()
     runtime = time() - t0
     result = [('epoch', self.epoch), ('runtime', runtime)] + result
     separator = '-------------------------------------------------------------------'
     print(separator)
     print('test ' +
           utils.print_iterable(result, max_digits=self.max_digits))
     print(separator)
     self._write_progress('test', result)
示例#5
0
 def _print_infos(self):
     """Print ``self.list_infos`` to stdout.

     Formatting is done by ``utils.print_iterable`` with a newline delimiter.
     """
     print(utils.print_iterable(self.list_infos, delimiter='\n'))
示例#6
0
 def _write_infos(self):
     """Append ``self.list_infos`` to ``<results_dir>/infos.txt``.

     The text is produced by ``utils.print_iterable`` with a newline
     delimiter. No newline is written after it, so a later append to the
     same file continues on the same line.
     """
     path = os.path.join(self.results_dir, 'infos.txt')
     text = utils.print_iterable(self.list_infos, delimiter='\n')
     # `with` closes the handle even if the write raises.
     with open(path, 'a') as f:
         f.write(text)