示例#1
0
    def validate(self, epoch, accum_iter, mode, doLog=True, **kwargs):
        """Evaluate the model on the validation or test loader.

        Args:
            epoch: epoch number stored in the returned record.
            accum_iter: training-sample counter stored in the returned record.
            mode: ``'val'`` to iterate ``self.val_loader`` and log through
                ``logger_service.log_val``, or ``'test'`` to iterate
                ``self.test_loader`` and log through ``logger_service.log_test``.
            doLog: when True, hand the record to the logger service.
            **kwargs: extra fields merged into the record last.

        Returns:
            dict with ``state_dict`` (from ``self._create_state_dict``),
            ``epoch``, ``accum_iter``, ``num_eval_instance`` (number of
            evaluated samples), the running metric averages of the values
            returned by ``self.calculate_metrics``, and ``kwargs``.

        Raises:
            ValueError: if ``mode`` is neither ``'val'`` nor ``'test'``.
        """
        if mode == 'val':
            loader = self.val_loader
        elif mode == 'test':
            loader = self.test_loader
        else:
            raise ValueError("mode must be 'val' or 'test', got %r" % (mode,))

        self.model.eval()

        average_meter_set = AverageMeterSet()
        num_instance = 0

        show_progress = not self.pilot
        if show_progress:
            # The progress-bar label depends only on mode and metric_ks, so
            # build the metric list and format template once, not per batch.
            description_metrics = ['NDCG@%d' % k for k in self.metric_ks[:3]] + \
                                  ['Recall@%d' % k for k in self.metric_ks[:3]]
            description_template = (
                '{}: '.format(mode.capitalize()) +
                ', '.join(s + ' {:.3f}' for s in description_metrics)
            ).replace('NDCG', 'N').replace('Recall', 'R')

        with torch.no_grad():
            tqdm_dataloader = tqdm(loader) if show_progress else loader
            for batch_idx, batch in enumerate(tqdm_dataloader):
                # Pilot mode evaluates only the first pilot_batch_cnt batches.
                if not show_progress and batch_idx >= self.pilot_batch_cnt:
                    break
                batch = {k: v.to(self.device) for k, v in batch.items()}
                num_instance += next(iter(batch.values())).size(0)

                metrics = self.calculate_metrics(batch)
                for k, v in metrics.items():
                    average_meter_set.update(k, v)

                if show_progress:
                    tqdm_dataloader.set_description(description_template.format(
                        *(average_meter_set[k].avg for k in description_metrics)))

            log_data = {
                'state_dict': (self._create_state_dict(epoch, accum_iter)),
                'epoch': epoch,
                'accum_iter': accum_iter,
                'num_eval_instance': num_instance,
            }
            log_data.update(average_meter_set.averages())
            log_data.update(kwargs)
            if doLog:
                # mode was validated above, so it is either 'val' or 'test'.
                if mode == 'val':
                    self.logger_service.log_val(log_data)
                else:
                    self.logger_service.log_test(log_data)
        return log_data
示例#2
0
    def train_one_epoch(self, epoch, accum_iter, train_loader, **kwargs):
        """Run one optimisation pass over ``train_loader``.

        For every batch the tensors are moved to ``self.device``,
        ``self.calculate_loss`` is evaluated (it may return either a loss or
        a ``(loss, extra_info_dict)`` pair), gradients are back-propagated,
        optionally clipped to ``self.clip_grad_norm``, and the optimizer steps.
        Running averages of the loss and of any extra info are sent to
        ``self.logger_service.log_train`` whenever
        ``self._needs_to_log(accum_iter)`` is true, and once more after the
        pass (that final record also carries ``num_train_instance``).

        In pilot mode (``self.pilot``) no progress bar is shown and at most
        ``self.pilot_batch_cnt`` batches are processed.

        Args:
            epoch: current epoch number, included in every log record.
            accum_iter: running count of training samples processed so far.
            train_loader: iterable of dict batches of tensors.
            **kwargs: extra fields merged last into every log record.

        Returns:
            ``accum_iter`` increased by the size of every processed batch.
        """
        self.model.train()

        meters = AverageMeterSet()
        samples_seen = 0
        if self.pilot:
            progress = train_loader
        else:
            progress = tqdm(train_loader)

        def publish(current_iter, **fixed_fields):
            # Record layout: epoch, accum_iter, any fixed fields, then the
            # running averages, then the caller's kwargs.
            record = {'epoch': epoch, 'accum_iter': current_iter, **fixed_fields}
            record.update(meters.averages())
            record.update(kwargs)
            self.log_extra_train_info(record)
            self.logger_service.log_train(record)

        for step, raw_batch in enumerate(progress):
            if self.pilot and step >= self.pilot_batch_cnt:
                break
            n_samples = next(iter(raw_batch.values())).size(0)
            batch = {name: tensor.to(self.device) for name, tensor in raw_batch.items()}
            samples_seen += n_samples

            self.optimizer.zero_grad()
            outcome = self.calculate_loss(batch)
            if isinstance(outcome, tuple):
                loss, extra_info = outcome
                for name, value in extra_info.items():
                    meters.update(name, value)
            else:
                loss = outcome
            loss.backward()

            if self.clip_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_grad_norm)

            self.optimizer.step()

            meters.update('loss', loss.item())
            if not self.pilot:
                progress.set_description(
                    'Epoch {}, loss {:.3f} '.format(epoch, meters['loss'].avg))

            accum_iter += n_samples

            if self._needs_to_log(accum_iter):
                if not self.pilot:
                    progress.set_description('Logging')
                publish(accum_iter)

        publish(accum_iter, num_train_instance=samples_seen)
        return accum_iter
示例#3
0
    def tobigs_test(self, epoch, accum_iter, mode, doLog=True, **kwargs):
        """Run a one-off demo prediction for a single hand-written item sequence.

        The prediction mode is chosen from attributes set on the trainer:

        * ``'middle'`` -- when ``self.input_middle_seq`` is not None. The model
          input is ``input_middle_seq`` items, a mask token, then the
          ``input_middle_target`` items. The top ``int(self.input_middle_num[0])``
          items whose type is NOT '학사' / '학사_복전' are printed and their
          names written to ``pred_middle_name.txt`` (one per line).
        * ``'future'`` -- otherwise, when ``self.input_future_seq`` is not None.
          The input is ``input_future_seq`` followed by a trailing mask token.
          The best item whose type is '중소기업' / '스타트업' / '대기업' is printed
          and its name written to ``pred_future_name.txt``.

        Item names are mapped to ids through ``dataset.pkl``'s ``'smap'`` entry
        and item types are looked up in ``type_dict.pkl``; both are read from
        the current working directory.

        Args:
            epoch, accum_iter, doLog, **kwargs: accepted for signature
                compatibility with ``validate``; not used.
            mode: only ``'tobigs_test'`` has a visible effect (prints a marker).

        Returns:
            ``torch.max(masked_scores, 1)`` -- the (values, indices) pair over the
            item axis of the type-masked scores.

        Raises:
            ValueError: if neither ``input_middle_seq`` nor ``input_future_seq``
                is set, or the sequence does not fit the model's input length.
        """
        import pickle

        mask_token = 1692  # token id used for the '#MASK#' position

        if mode == 'tobigs_test':
            print('tobigs test')

        input_middle_seq = self.input_middle_seq
        input_middle_target = self.input_middle_target
        input_future_seq = self.input_future_seq
        if input_middle_seq is not None:
            prediction_mode = 'middle'
        elif input_future_seq is not None:
            prediction_mode = 'future'
        else:
            raise ValueError(
                'tobigs_test needs self.input_middle_seq or self.input_future_seq')

        # NOTE(review): pickle.load can execute arbitrary code; dataset.pkl must
        # be a trusted, locally produced file.
        with open('dataset.pkl', 'rb') as fp:
            data_new = pickle.load(fp)
        item_true_name_dict = data_new['smap']  # item name -> item id
        inv_map = {v: k for k, v in item_true_name_dict.items()}  # item id -> name
        item_type_pickle = pd.read_pickle('type_dict.pkl')  # item id -> type label

        self.model.eval()

        for _ in range(3):
            print('##################')
            print(' ')

        if prediction_mode == 'middle':
            top_k = int(self.input_middle_num[0])

            print('input_middle_seq is')
            print(' ')
            print(input_middle_seq)
            print(' ')
            print('input_middle_target is')
            print(' ')
            print(input_middle_target)

            batch = self._tobigs_build_batch(
                [int(item_true_name_dict[name]) for name in input_middle_seq],
                [int(item_true_name_dict[name]) for name in input_middle_target],
                mask_token)
            allowed_ids = [k for k, v in item_type_pickle.items()
                           if v not in ['학사', '학사_복전']]

            with torch.no_grad():
                scores_all, valid_scores = self._tobigs_masked_scores(batch, allowed_ids)
                _, top_indices = torch.topk(valid_scores, k=top_k, dim=1)
            my_indices = top_indices.cpu().numpy()[0]

            print('new my indices is')
            print(' ')
            print(my_indices)

            pred_item = [key + 1 for key in my_indices]  # score column j -> item id j + 1
            pred_item_type = [item_type_pickle[key] for key in pred_item]
            pred_item_name = [inv_map[key] for key in pred_item]

            print(' ')
            print(' ')
            print('prediction softmax : all item is')
            print(' ')
            print(scores_all)
            print(' ')
            print('shape is')
            print(scores_all.shape)
            print(' ')
            print('pred_item(all) is')
            print(' ')
            print(pred_item)
            print(' ')
            print('pred_item : type is')
            print(' ')
            print(pred_item_type)
            print(' ')
            print('pred_item : name is')
            print(' ')
            print(pred_item_name)
            print(' ')
            print('The End')
            print(' ')

            with open('pred_middle_name.txt', 'w') as f:
                f.writelines(name + '\n' for name in pred_item_name)

            return torch.max(valid_scores, 1)

        # prediction_mode == 'future'
        print('input_future_seq is')
        print(' ')
        print(input_future_seq)

        batch = self._tobigs_build_batch(
            [int(item_true_name_dict[name]) for name in input_future_seq],
            [],
            mask_token)
        allowed_ids = [k for k, v in item_type_pickle.items()
                       if v in ['중소기업', '스타트업', '대기업']]

        with torch.no_grad():
            _, valid_scores = self._tobigs_masked_scores(batch, allowed_ids)
            _, best_indices = torch.max(valid_scores, 1)

        dummy_tokens = batch['tokens']
        print(self._tobigs_decode_tokens(dummy_tokens[0], inv_map, mask_token))
        print(self._tobigs_decode_tokens(dummy_tokens[1], inv_map, mask_token))
        print(' ')
        print(dummy_tokens)
        print(' ')
        print(batch['labels'])

        pred_item = [key + 1 for key in best_indices.cpu().numpy()]  # column j -> id j + 1
        pred_item_type = [item_type_pickle[key] for key in pred_item]
        pred_item_name = [inv_map[key] for key in pred_item]

        print(' ')
        print('pred_item(career) is')
        print(' ')
        print(pred_item[0])
        print(' ')
        print('pred_item : type is')
        print(' ')
        print(pred_item_type[0])
        print(' ')
        print('pred_item : name is')
        print(' ')
        print(pred_item_name[0])

        with open('pred_future_name.txt', 'w') as f:
            f.write(pred_item_name[0])

        print(' ')
        print('The End')
        print(' ')

        return torch.max(valid_scores, 1)

    @staticmethod
    def _tobigs_build_batch(context_ids, suffix_ids, mask_token=1692, max_len=42):
        """Build a two-row demo batch: zero padding, context ids, mask token, suffix ids.

        Returns:
            dict with ``'tokens'`` (shape ``(2, max_len)``, both rows identical)
            and ``'labels'`` (1.0 where a token equals ``mask_token``, else 0.0).

        Raises:
            ValueError: if the ids plus the mask token exceed ``max_len``.
        """
        used = len(context_ids) + 1 + len(suffix_ids)
        if used > max_len:
            raise ValueError(
                'input sequence too long: %d tokens do not fit in %d' % (used, max_len))
        row = torch.tensor(
            [0] * (max_len - used) + list(context_ids) + [mask_token] + list(suffix_ids))
        tokens = torch.stack([row, row])
        labels = tokens * (tokens == mask_token) / mask_token
        return {'tokens': tokens, 'labels': labels}

    def _tobigs_masked_scores(self, batch, allowed_ids):
        """Score ``batch`` with ``self.calculate_loss2`` and zero out disallowed items.

        Scores are first shifted to be strictly positive (``+ |min| + 0.01``),
        so zeroed-out items always rank below every allowed item. Score column
        ``j`` corresponds to item id ``j + 1``; only ids in ``allowed_ids`` with
        ``1 <= id < n_columns`` are kept, and the last column is always zeroed.

        Returns:
            ``(shifted_scores, masked_scores)``.
        """
        _, scores, _ = self.calculate_loss2(batch)
        scores = scores + abs(scores.min()) + 0.01
        allowed = set(allowed_ids)  # O(1) membership instead of a list scan
        n_columns = scores.size(-1)
        keep = [1.0 if i in allowed else 0.0 for i in range(1, n_columns)] + [0.0]
        # Build the mask on the scores' device rather than assuming CUDA.
        mask = torch.tensor(keep, dtype=scores.dtype, device=scores.device)
        return scores, scores * mask

    @staticmethod
    def _tobigs_decode_tokens(row, id_to_name, mask_token=1692):
        """Translate one row of token ids to names: 0 -> 'blank', mask -> '#MASK#'."""
        names = []
        for key in row.tolist():
            key = int(key)
            if key == 0:
                names.append('blank')
            elif key == mask_token:
                names.append('#MASK#')
            else:
                names.append(id_to_name[key])
        return names
示例#4
0
    def validate(self, epoch, accum_iter, mode, doLog=True, **kwargs):
        """Evaluate the model on the validation or test loader.

        Prints a short trace header, then iterates the selected loader under
        ``torch.no_grad()`` and accumulates the metrics returned by
        ``self.calculate_metrics`` for every batch.

        Args:
            epoch: epoch number stored in the returned record.
            accum_iter: training-sample counter stored in the returned record.
            mode: ``'val'`` to iterate ``self.val_loader`` and log through
                ``logger_service.log_val``, or ``'test'`` to iterate
                ``self.test_loader`` and log through ``logger_service.log_test``.
            doLog: when True, hand the record to the logger service.
            **kwargs: extra fields merged into the record last.

        Returns:
            dict with ``state_dict`` (from ``self._create_state_dict``),
            ``epoch``, ``accum_iter``, ``num_eval_instance`` (number of
            evaluated samples), the running metric averages, and ``kwargs``.

        Raises:
            ValueError: if ``mode`` is neither ``'val'`` nor ``'test'``.
        """
        print(' ')
        print('meantime / trainers / base.py / AbstractTrainer.validate is')

        if mode == 'val':
            loader = self.val_loader
        elif mode == 'test':
            loader = self.test_loader
        else:
            raise ValueError("mode must be 'val' or 'test', got %r" % (mode,))

        self.model.eval()

        average_meter_set = AverageMeterSet()
        num_instance = 0

        show_progress = not self.pilot
        if show_progress:
            # The progress-bar label depends only on mode and metric_ks, so
            # build the metric list and format template once, not per batch.
            description_metrics = ['NDCG@%d' % k for k in self.metric_ks[:3]] + \
                                  ['Recall@%d' % k for k in self.metric_ks[:3]]
            description_template = (
                '{}: '.format(mode.capitalize()) +
                ', '.join(s + ' {:.3f}' for s in description_metrics)
            ).replace('NDCG', 'N').replace('Recall', 'R')

        with torch.no_grad():
            tqdm_dataloader = tqdm(loader) if show_progress else loader
            for batch_idx, batch in enumerate(tqdm_dataloader):
                # Pilot mode evaluates only the first pilot_batch_cnt batches.
                if not show_progress and batch_idx >= self.pilot_batch_cnt:
                    break
                batch = {k: v.to(self.device) for k, v in batch.items()}
                num_instance += next(iter(batch.values())).size(0)

                metrics = self.calculate_metrics(batch)
                for k, v in metrics.items():
                    average_meter_set.update(k, v)

                if show_progress:
                    tqdm_dataloader.set_description(description_template.format(
                        *(average_meter_set[k].avg for k in description_metrics)))

            log_data = {
                'state_dict': (self._create_state_dict(epoch, accum_iter)),
                'epoch': epoch,
                'accum_iter': accum_iter,
                'num_eval_instance': num_instance,
            }
            log_data.update(average_meter_set.averages())
            log_data.update(kwargs)
            if doLog:
                # mode was validated above, so it is either 'val' or 'test'.
                if mode == 'val':
                    self.logger_service.log_val(log_data)
                else:
                    self.logger_service.log_test(log_data)

        return log_data