Example #1
    def train_one_epoch(self, epoch, accum_iter, train_loader, **kwargs):
        self.model.train()

        average_meter_set = AverageMeterSet()
        num_instance = 0
        tqdm_dataloader = tqdm(train_loader) if not self.pilot else train_loader

        for batch_idx, batch in enumerate(tqdm_dataloader):
            if self.pilot and batch_idx >= self.pilot_batch_cnt:
                # print('Break training due to pilot mode')
                break
            batch_size = next(iter(batch.values())).size(0)
            batch = {k: v.to(self.device) for k, v in batch.items()}
            num_instance += batch_size

            self.optimizer.zero_grad()
            loss = self.calculate_loss(batch)
            if isinstance(loss, tuple):
                loss, extra_info = loss
                for k, v in extra_info.items():
                    average_meter_set.update(k, v)
            loss.backward()

            if self.clip_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_grad_norm)

            self.optimizer.step()

            average_meter_set.update('loss', loss.item())
            if not self.pilot:
                tqdm_dataloader.set_description(
                    'Epoch {}, loss {:.3f} '.format(epoch, average_meter_set['loss'].avg))

            accum_iter += batch_size

            if self._needs_to_log(accum_iter):
                if not self.pilot:
                    tqdm_dataloader.set_description('Logging')
                log_data = {
                    # 'state_dict': (self._create_state_dict()),
                    'epoch': epoch,
                    'accum_iter': accum_iter,
                }
                log_data.update(average_meter_set.averages())
                log_data.update(kwargs)
                self.log_extra_train_info(log_data)
                self.logger_service.log_train(log_data)

        log_data = {
            # 'state_dict': (self._create_state_dict()),
            'epoch': epoch,
            'accum_iter': accum_iter,
            'num_train_instance': num_instance,
        }
        log_data.update(average_meter_set.averages())
        log_data.update(kwargs)
        self.log_extra_train_info(log_data)
        self.logger_service.log_train(log_data)
        return accum_iter
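
`train_one_epoch` above relies on an `AverageMeterSet` helper that is not shown in the snippet. A minimal sketch consistent with how it is used (update(name, value), indexing a meter to read .avg, and averages() returning a plain dict) might look like the following; the method names mirror the calls in the example, but the implementation details are an assumption, not the project's actual code.

class AverageMeter:
    # Tracks a running sum and count so that .avg is the mean of all updates.
    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, value, n=1):
        self.sum += value * n
        self.count += n

    @property
    def avg(self):
        return self.sum / self.count if self.count else 0.0


class AverageMeterSet:
    # A dict of AverageMeters keyed by metric name ('loss', 'NDCG@10', ...).
    def __init__(self):
        self.meters = {}

    def __getitem__(self, name):
        return self.meters[name]

    def update(self, name, value, n=1):
        if name not in self.meters:
            self.meters[name] = AverageMeter()
        self.meters[name].update(value, n)

    def averages(self):
        # Returns e.g. {'loss': 0.87, 'NDCG@10': 0.42}, which is merged into log_data above.
        return {name: meter.avg for name, meter in self.meters.items()}
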
Example #2
    def validate(self, epoch, accum_iter, mode, doLog=True, **kwargs):
        if mode == 'val':
            loader = self.val_loader
        elif mode == 'test':
            loader = self.test_loader
        else:
            raise ValueError

        self.model.eval()

        average_meter_set = AverageMeterSet()
        num_instance = 0

        with torch.no_grad():
            tqdm_dataloader = tqdm(loader) if not self.pilot else loader
            for batch_idx, batch in enumerate(tqdm_dataloader):
                if self.pilot and batch_idx >= self.pilot_batch_cnt:
                    # print('Break validation due to pilot mode')
                    break
                batch = {k: v.to(self.device) for k, v in batch.items()}
                batch_size = next(iter(batch.values())).size(0)
                num_instance += batch_size

                metrics = self.calculate_metrics(batch)

                for k, v in metrics.items():
                    average_meter_set.update(k, v)
                if not self.pilot:
                    description_metrics = ['NDCG@%d' % k for k in self.metric_ks[:3]] +\
                                          ['Recall@%d' % k for k in self.metric_ks[:3]]
                    description = '{}: '.format(mode.capitalize()) + ', '.join(
                        s + ' {:.3f}' for s in description_metrics)
                    description = description.replace('NDCG', 'N').replace(
                        'Recall', 'R')
                    description = description.format(
                        *(average_meter_set[k].avg
                          for k in description_metrics))
                    tqdm_dataloader.set_description(description)

            log_data = {
                'state_dict': (self._create_state_dict(epoch, accum_iter)),
                'epoch': epoch,
                'accum_iter': accum_iter,
                'num_eval_instance': num_instance,
            }
            log_data.update(average_meter_set.averages())
            log_data.update(kwargs)
            if doLog:
                if mode == 'val':
                    self.logger_service.log_val(log_data)
                elif mode == 'test':
                    self.logger_service.log_test(log_data)
                else:
                    raise ValueError
        return log_data
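
The `calculate_metrics` hook is expected to return a dict keyed by names like 'NDCG@10' and 'Recall@10', since those are exactly the keys the progress-bar description reads back from the meter set. Below is a hedged sketch of such a metric computation for the sampled-candidate setup (scores over `batch['candidates']`, one-hot `batch['labels']`); the function name and exact normalization are illustrative assumptions, not necessarily what this trainer's subclass implements.

import torch


def recalls_and_ndcgs_for_ks(scores, labels, ks=(1, 5, 10)):
    # scores: [batch, num_candidates] float, higher = better.
    # labels: [batch, num_candidates] float, 1.0 at relevant items, 0.0 elsewhere.
    metrics = {}
    answer_count = labels.sum(1)                      # relevant items per row
    ranked = scores.argsort(dim=1, descending=True)   # candidate indices sorted by score
    for k in ks:
        cut = ranked[:, :k]
        hits = labels.gather(1, cut)                  # [batch, k] of 0/1 hits in the top-k
        # Recall@k: fraction of relevant items that appear in the top-k.
        metrics['Recall@%d' % k] = (hits.sum(1) / answer_count.clamp(min=1)).mean().item()
        # NDCG@k: log-discounted gain of the hits, normalized by the ideal ordering.
        positions = torch.arange(1, k + 1, dtype=torch.float, device=scores.device)
        weights = 1.0 / torch.log2(positions + 1)
        dcg = (hits * weights).sum(1)
        idcg = torch.stack([weights[:min(int(n), k)].sum() for n in answer_count])
        metrics['NDCG@%d' % k] = (dcg / idcg.clamp(min=1e-8)).mean().item()
    return metrics

In `validate`, `calculate_metrics(batch)` would apply something like this to the model's scores for `batch['candidates']` against `batch['labels']`.
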
Example #3
    def validate(self, epoch, accum_iter, mode, doLog=True, **kwargs):
        print(' ')
        print('meantime/trainers/base.py :: AbstractTrainer.validate')

        ### My Code Start ###
        # Buffer and dtype for the per-batch prediction dump used by the commented-out code below.
        my_final_result = -1 * torch.ones(1, 205)
        my_dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
        my_final_result = my_final_result.to(self.device)
        ### My Code End ###

        if mode == 'val':
            loader = self.val_loader
        elif mode == 'test':
            loader = self.test_loader
        else:
            raise ValueError

        self.model.eval()

        average_meter_set = AverageMeterSet()
        num_instance = 0

        with torch.no_grad():
            tqdm_dataloader = tqdm(loader) if not self.pilot else loader
            for batch_idx, batch in enumerate(tqdm_dataloader):
                if self.pilot and batch_idx >= self.pilot_batch_cnt:
                    # print('Break validation due to pilot mode')
                    break
                batch = {k: v.to(self.device) for k, v in batch.items()}
                batch_size = next(iter(batch.values())).size(0)
                num_instance += batch_size

                metrics = self.calculate_metrics(batch)
                '''
                print(' ')
                print(' ')
                print('batch idx is')
                print(batch_idx)

                print('batch : tokens,   [Batch_size x seq_len]')
                print(batch['tokens'])
                print('batch : candidates,   [Batch_size x (1 + 100) candidates]')
                print(batch['candidates'])
                print('batch : labels,   [Batch_size x (1 + 100) labels]')
                print(batch['labels'])
                ###### MY CODE ######
                #print('epoch is') # 20201214
                #print(epoch)
                #print('batch is') ##### My code 20201119
                #print(batch)
                #print('true answer is')
                #print(batch['candidates'][:,0])
                MY_SCORES, MY_LABELS, MY_CUT, MY_HITS = self.NEW_CODE_PRINT_PREDICTION(batch) ##### My code 20201119
                my_len = len(MY_CUT)
                print("MY_SCORES is,   [Batch_size x (1 + 100)]")
                print(MY_SCORES) ##### My code 20201119
                print(' ')
                #print("MY_LABELS")
                #print(MY_LABELS) ##### My code 20201119
                print("MY_CUT(prediction) is,   [Batch_size x 1]")
                print(MY_CUT) ##### My code 20201119
                print(' ')
                print("MY_HITS is,   [Batch_size x 1]")
                print(MY_HITS) ##### My code 20201119
                print(' ')
                #print('MY_SCORES shape')
                #print(MY_SCORES.shape)
                #print(' ')
                #print('MY_LABELS shape')
                #print(MY_LABELS.shape)
                #print(' ')
                #print('MY_CUT shape')
                #print(MY_CUT.shape)
                #print('MY_HITS.shape')
                #print(MY_HITS.shape)
                '''
                #my_epoch = epoch
                #my_batch_idx = batch_idx
                #my_batch_token = batch['tokens']
                #my_batch_candidate = batch['candidates']

                #my_batch_score = MY_SCORES
                #my_batch_cut = MY_CUT
                #my_hit = MY_HITS

                #print('true answer is')
                #print(batch['candidates'][:,0])

                #my_epoch1 = torch.Tensor([my_epoch]*batch_size).reshape(batch_size,1)
                #batch_idx1 = torch.Tensor([my_batch_idx]*batch_size).reshape(batch_size,1)
                #batch_idx2 = torch.Tensor(range(batch_size)).reshape(batch_size,1)
                #my_batch_token = my_batch_token.to(self.device)
                #my_candi = batch['candidates'][:,0]
                #my_candi = my_candi.to(self.device)
                #my_cut = MY_CUT
                #my_cut = my_cut.to(self.device)

                #my_epoch1 = my_epoch1.type(my_dtype)
                #batch_idx1 = batch_idx1.type(my_dtype)
                #batch_idx2 = batch_idx2.type(my_dtype)
                #my_batch_token = my_batch_token.type(my_dtype)
                #my_candi = my_candi.type(my_dtype).reshape(batch_size,1)
                #my_hit = my_hit.type(my_dtype)
                #my_cut = my_cut.type(my_dtype)

                #print('###')
                #print('my batch token shape')
                #print(my_batch_token.shape)
                #print(my_candi.shape)
                #print(my_hit.shape)
                #print('batch_idx1')
                #print(batch_idx1)
                #print(batch_idx2)
                #print('my_epoch')
                #print(my_epoch)

                #my_epoch_result = torch.cat([my_epoch1, batch_idx1, batch_idx2, my_batch_token, my_candi, my_cut], 1)

                #my_final_result = torch.cat([my_final_result, my_epoch_result], 0)
                ###### MY CODE ######

                for k, v in metrics.items():
                    average_meter_set.update(k, v)
                if not self.pilot:
                    description_metrics = ['NDCG@%d' % k for k in self.metric_ks[:3]] +\
                                          ['Recall@%d' % k for k in self.metric_ks[:3]]
                    description = '{}: '.format(mode.capitalize()) + ', '.join(
                        s + ' {:.3f}' for s in description_metrics)
                    description = description.replace('NDCG', 'N').replace(
                        'Recall', 'R')
                    description = description.format(
                        *(average_meter_set[k].avg
                          for k in description_metrics))
                    tqdm_dataloader.set_description(description)

            log_data = {
                'state_dict': (self._create_state_dict(epoch, accum_iter)),
                'epoch': epoch,
                'accum_iter': accum_iter,
                'num_eval_instance': num_instance,
            }
            log_data.update(average_meter_set.averages())
            log_data.update(kwargs)
            if doLog:
                if mode == 'val':
                    self.logger_service.log_val(log_data)
                elif mode == 'test':
                    self.logger_service.log_test(log_data)
                else:
                    raise ValueError

        ###### MY CODE ######
        #ts = time.time()
        #my_final_result = my_final_result.cpu()
        #my_final_result_np = my_final_result.numpy()
        #my_final_result_df = pd.DataFrame(my_final_result_np)
        #FILENAME = 'my_final_result' + mode + str(epoch) + 'time' + str(ts) + '_' +  '.csv'
        #my_final_result_df.to_csv(FILENAME)
        ###### MY CODE ######

        return log_data
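
The commented-out '###### MY CODE ######' blocks in this example sketch the idea of concatenating per-batch prediction tensors into `my_final_result` and dumping them to a timestamped CSV after validation. Below is a cleaner sketch of that idea as a standalone helper; the function name, column layout, and filename pattern are illustrative assumptions rather than the original author's code.

import time

import pandas as pd
import torch


def dump_prediction_rows(rows, mode, epoch):
    # rows: list of 2-D float tensors, one per batch, all with the same number of
    # columns (e.g. epoch, batch index, row index, input tokens, true item, predicted item).
    if not rows:
        return None
    result = torch.cat(rows, dim=0).cpu().numpy()
    df = pd.DataFrame(result)
    filename = 'my_final_result_{}_{}_time{}.csv'.format(mode, epoch, time.time())
    df.to_csv(filename, index=False)
    return filename

Collecting per-batch rows in a Python list and concatenating once at the end also avoids the repeated `torch.cat([my_final_result, ...], 0)` calls hinted at in the commented code, which reallocate the whole buffer on every batch.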