Code example #1
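The listing assumes the imports sketched below: standard PyTorch/NumPy plus project-local helpers whose module paths are assumptions (a TextGAN-style layout), not confirmed by the source.

import numpy as np
import torch
import torch.nn as nn

import config as cfg  # assumed project configuration module
from metrics.bleu import BLEU
from metrics.clas_acc import ACC
from metrics.nll import NLL
from metrics.ppl import PPL
from utils.cat_data_loader import CatClasDataIter
from utils.data_loader import GenDataIter
from utils.helpers import Signal, create_logger, get_fixed_temperature
from utils.text_process import load_dict, tensor_to_tokens, write_tokens
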
class BasicInstructor:
    def __init__(self, opt):
        self.log = create_logger(__name__,
                                 silent=False,
                                 to_disk=True,
                                 log_file=cfg.log_filename if cfg.if_test else
                                 [cfg.log_filename, cfg.save_root + 'log.txt'])
        self.sig = Signal(cfg.signal_file)
        self.opt = opt
        self.show_config()

        self.clas = None

        # load dictionary
        self.word2idx_dict, self.idx2word_dict = load_dict(cfg.dataset)

        # Dataloader
        try:
            self.train_data = GenDataIter(cfg.train_data)
            self.test_data = GenDataIter(cfg.test_data, if_test_data=True)
        except Exception:
            # Single-corpus data files may be absent (e.g. category-only runs)
            pass

        try:
            self.train_data_list = [
                GenDataIter(cfg.cat_train_data.format(i))
                for i in range(cfg.k_label)
            ]
            self.test_data_list = [
                GenDataIter(cfg.cat_test_data.format(i), if_test_data=True)
                for i in range(cfg.k_label)
            ]
            self.clas_data_list = [
                GenDataIter(cfg.cat_test_data.format(str(i)),
                            if_test_data=True) for i in range(cfg.k_label)
            ]

            self.train_samples_list = [
                self.train_data_list[i].target for i in range(cfg.k_label)
            ]
            self.clas_samples_list = [
                self.clas_data_list[i].target for i in range(cfg.k_label)
            ]
        except Exception:
            # Category-wise data files may be absent (e.g. single-corpus runs)
            pass

        # Criterion
        self.mle_criterion = nn.NLLLoss()
        self.dis_criterion = nn.CrossEntropyLoss()
        self.clas_criterion = nn.CrossEntropyLoss()

        # Optimizer
        self.clas_opt = None

        # Metrics
        self.bleu = BLEU('BLEU', gram=[2, 3, 4, 5], if_use=cfg.use_bleu)
        self.nll_gen = NLL('NLL_gen', if_use=cfg.use_nll_gen, gpu=cfg.CUDA)
        self.nll_div = NLL('NLL_div', if_use=cfg.use_nll_div, gpu=cfg.CUDA)
        self.self_bleu = BLEU('Self-BLEU',
                              gram=[2, 3, 4],
                              if_use=cfg.use_self_bleu)
        self.clas_acc = ACC(if_use=cfg.use_clas_acc)
        self.ppl = PPL(self.train_data,
                       self.test_data,
                       n_gram=5,
                       if_use=cfg.use_ppl)
        self.all_metrics = [
            self.bleu, self.nll_gen, self.nll_div, self.self_bleu, self.ppl
        ]

    def _run(self):
        print('Nothing to run in Basic Instructor!')

    def _test(self):
        pass

    def init_model(self):
        if cfg.dis_pretrain:
            self.log.info('Load pre-trained discriminator: {}'.format(
                cfg.pretrained_dis_path))
            self.dis.load_state_dict(torch.load(cfg.pretrained_dis_path))
        if cfg.gen_pretrain:
            self.log.info('Load MLE pre-trained generator: {}'.format(
                cfg.pretrained_gen_path))
            self.gen.load_state_dict(torch.load(cfg.pretrained_gen_path))

        if cfg.CUDA:
            self.gen = self.gen.cuda()
            self.dis = self.dis.cuda()

    def train_gen_epoch(self, model, data_loader, criterion, optimizer):
        total_loss = 0
        for i, data in enumerate(data_loader):
            inp, target = data['input'], data['target']
            if cfg.CUDA:
                inp, target = inp.cuda(), target.cuda()

            hidden = model.init_hidden(data_loader.batch_size)
            pred = model.forward(inp, hidden)
            loss = criterion(pred, target.view(-1))
            self.optimize(optimizer, loss, model)
            total_loss += loss.item()
        return total_loss / len(data_loader)

    def train_dis_epoch(self, model, data_loader, criterion, optimizer):
        total_loss = 0
        total_acc = 0
        total_num = 0
        for i, data in enumerate(data_loader):
            inp, target = data['input'], data['target']
            if cfg.CUDA:
                inp, target = inp.cuda(), target.cuda()

            pred = model.forward(inp)
            loss = criterion(pred, target)
            self.optimize(optimizer, loss, model)

            total_loss += loss.item()
            total_acc += torch.sum((pred.argmax(dim=-1) == target)).item()
            total_num += inp.size(0)

        total_loss /= len(data_loader)
        total_acc /= total_num
        return total_loss, total_acc

    def train_classifier(self, epochs):
        """
        Classifier for calculating the classification accuracy metric of category text generation.

        Note: the train and test data for the classifier is opposite to the generator.
        Because the classifier is to calculate the classification accuracy of the generated samples
        where are trained on self.train_samples_list.

        Since there's no test data in synthetic data (oracle data), the synthetic data experiments
        doesn't need a classifier.
        """
        import copy

        # Prepare data for the classifier: it trains on the test split
        # (clas_samples_list) and is evaluated on the train split, i.e. the
        # splits are swapped relative to the generator (see the docstring note)
        clas_data = CatClasDataIter(self.clas_samples_list)
        eval_clas_data = CatClasDataIter(self.train_samples_list)

        max_acc = 0
        best_clas = None
        for epoch in range(epochs):
            c_loss, c_acc = self.train_dis_epoch(self.clas, clas_data.loader,
                                                 self.clas_criterion,
                                                 self.clas_opt)
            _, eval_acc = self.eval_dis(self.clas, eval_clas_data.loader,
                                        self.clas_criterion)
            if eval_acc > max_acc:
                best_clas = copy.deepcopy(
                    self.clas.state_dict())  # save the best classifier
                max_acc = eval_acc
            self.log.info(
                '[PRE-CLAS] epoch %d: c_loss = %.4f, c_acc = %.4f, eval_acc = %.4f, max_eval_acc = %.4f',
                epoch, c_loss, c_acc, eval_acc, max_acc)
        self.clas.load_state_dict(
            copy.deepcopy(best_clas))  # Reload the best classifier

    @staticmethod
    def eval_dis(model, data_loader, criterion):
        total_loss = 0
        total_acc = 0
        total_num = 0
        with torch.no_grad():
            for i, data in enumerate(data_loader):
                inp, target = data['input'], data['target']
                if cfg.CUDA:
                    inp, target = inp.cuda(), target.cuda()

                pred = model.forward(inp)
                loss = criterion(pred, target)
                total_loss += loss.item()
                total_acc += torch.sum((pred.argmax(dim=-1) == target)).item()
                total_num += inp.size(0)
            total_loss /= len(data_loader)
            total_acc /= total_num
        return total_loss, total_acc

    @staticmethod
    def optimize_multi(opts, losses):
        for i, (opt, loss) in enumerate(zip(opts, losses)):
            opt.zero_grad()
            # Retain the graph for every backward pass except the last one
            loss.backward(retain_graph=(i < len(opts) - 1))
            opt.step()

    @staticmethod
    def optimize(opt, loss, model=None, retain_graph=False):
        opt.zero_grad()
        loss.backward(retain_graph=retain_graph)
        if model is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip_norm)
        opt.step()

    def show_config(self):
        self.log.info(100 * '=')
        self.log.info('> training arguments:')
        for arg in vars(self.opt):
            self.log.info('>>> {0}: {1}'.format(arg, getattr(self.opt, arg)))
        self.log.info(100 * '=')

    def cal_metrics(self, fmt_str=False):
        """
        Calculate metrics
        :param fmt_str: if return format string for logging
        """
        with torch.no_grad():
            # Prepare data for evaluation
            eval_samples = self.gen.sample(cfg.samples_num, 4 * cfg.batch_size)
            gen_data = GenDataIter(eval_samples)
            gen_tokens = tensor_to_tokens(eval_samples, self.idx2word_dict)
            gen_tokens_s = tensor_to_tokens(self.gen.sample(200, 200),
                                            self.idx2word_dict)

            # Reset metrics
            self.bleu.reset(test_text=gen_tokens,
                            real_text=self.test_data.tokens)
            self.nll_gen.reset(self.gen, self.train_data.loader)
            self.nll_div.reset(self.gen, gen_data.loader)
            self.self_bleu.reset(test_text=gen_tokens_s, real_text=gen_tokens)
            self.ppl.reset(gen_tokens)

        if fmt_str:
            return ', '.join([
                '%s = %s' % (metric.get_name(), metric.get_score())
                for metric in self.all_metrics
            ])
        else:
            return [metric.get_score() for metric in self.all_metrics]

    def cal_metrics_with_label(self, label_i):
        assert isinstance(label_i, int), 'missing label'

        with torch.no_grad():
            # Prepare data for evaluation
            eval_samples = self.gen.sample(cfg.samples_num,
                                           8 * cfg.batch_size,
                                           label_i=label_i)
            gen_data = GenDataIter(eval_samples)
            gen_tokens = tensor_to_tokens(eval_samples, self.idx2word_dict)
            gen_tokens_s = tensor_to_tokens(
                self.gen.sample(200, 200, label_i=label_i), self.idx2word_dict)
            clas_data = CatClasDataIter([eval_samples], label_i)

            # Reset metrics
            self.bleu.reset(test_text=gen_tokens,
                            real_text=self.test_data_list[label_i].tokens)
            self.nll_gen.reset(self.gen, self.train_data_list[label_i].loader,
                               label_i)
            self.nll_div.reset(self.gen, gen_data.loader, label_i)
            self.self_bleu.reset(test_text=gen_tokens_s, real_text=gen_tokens)
            self.clas_acc.reset(self.clas, clas_data.loader)
            self.ppl.reset(gen_tokens)

        return [metric.get_score() for metric in self.all_metrics]

    def comb_metrics(self, fmt_str=False):
        all_scores = [
            self.cal_metrics_with_label(label_i)
            for label_i in range(cfg.k_label)
        ]
        all_scores = np.array(
            all_scores).T.tolist()  # each row for each metric

        if fmt_str:
            return ', '.join([
                '%s = %s' % (metric.get_name(), score)
                for (metric, score) in zip(self.all_metrics, all_scores)
            ])
        return all_scores

    def _save(self, phase, epoch):
        """Save model state dict and generator's samples"""
        if phase != 'ADV':
            torch.save(
                self.gen.state_dict(),
                cfg.save_model_root + 'gen_{}_{:05d}.pt'.format(phase, epoch))
        save_sample_path = cfg.save_samples_root + 'samples_{}_{}_{:05d}.txt'.format(
            phase, cfg.samples_num, epoch)
        # Keep the sample count consistent with cfg.samples_num in the file name
        samples = self.gen.sample(cfg.samples_num, cfg.batch_size)
        write_tokens(save_sample_path,
                     tensor_to_tokens(samples, self.idx2word_dict))

    def update_temperature(self, i, N):
        """Anneal the generator's softmax temperature at step i of N total steps,
        following the cfg.temp_adpt schedule toward cfg.temperature."""
        self.gen.temperature.data = torch.Tensor(
            [get_fixed_temperature(cfg.temperature, i, N, cfg.temp_adpt)])
        if cfg.CUDA:
            self.gen.temperature.data = self.gen.temperature.data.cuda()
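
A minimal sketch of how code example #1 is meant to be used: a concrete instructor subclasses BasicInstructor, supplies its own generator/discriminator, and overrides _run. The subclass name, the Generator/Discriminator classes, and the epoch count are illustrative assumptions, not part of the source.

class MLEGenInstructor(BasicInstructor):
    """Hypothetical subclass: wires up concrete models, then trains with MLE."""

    def __init__(self, opt):
        super(MLEGenInstructor, self).__init__(opt)
        self.gen = Generator()  # hypothetical generator class from the project
        self.dis = Discriminator()  # hypothetical discriminator class
        self.gen_opt = torch.optim.Adam(self.gen.parameters(), lr=1e-3)

    def _run(self):
        self.init_model()  # load pre-trained weights, move models to GPU
        for epoch in range(10):
            loss = self.train_gen_epoch(self.gen, self.train_data.loader,
                                        self.mle_criterion, self.gen_opt)
            self.log.info('epoch %d: mle_loss = %.4f, %s',
                          epoch, loss, self.cal_metrics(fmt_str=True))
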
Code example #2
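The listing assumes TensorFlow 1.x and the imports sketched below; the project-local helper modules and their paths are assumptions, not confirmed by the source.

import os
import random
import time

import numpy as np
import tensorflow as tf

import config as cfg  # assumed project configuration module
from data_utils import (batch_generator, batch_seq_len, decode_text, idx_to_text,
                        load_dict, load_training_data, load_val_data,
                        make_embedding_matrix, remove_sent_pad)
from metrics import BLEU
from seq2seq_modules import decoder_module, encoder_module
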
class Model:
    def __init__(self, encoder_layer_num, decoder_layer_num, hidden_dim,
                 batch_size, learning_rate, dropout, init_train=True):
        self.encoder_layer_num = encoder_layer_num
        self.decoder_layer_num = decoder_layer_num
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.dropout = dropout
        self.init_train = init_train
        # ----- fixed settings taken from cfg -----
        self.vocab_size = cfg.vocab_size
        self.max_length = cfg.max_length
        self.embedding_matrix = make_embedding_matrix(cfg.all_captions)
        self.SOS_token = cfg.SOS_token
        self.EOS_token = cfg.EOS_token
        self.idx2word_dict = load_dict()
        #----------------------
        
        self.bleu = BLEU('BLEU', gram=[2, 3, 4, 5])
              
        if init_train:
            self._init_train()
            (train_week_stock, train_month_stock, t_month_stock,
             train_input_cap_vector, train_output_cap_vector) = load_training_data()
            self.train_data = batch_generator(train_week_stock, train_month_stock,
                                              t_month_stock, train_input_cap_vector,
                                              train_output_cap_vector, self.batch_size)
            self.total_iter = len(train_input_cap_vector)

            self._init_eval()
            (val_week_stock, val_month_stock, val_t_month_stock,
             val_input_cap_vector, val_output_cap_vector) = load_val_data()
            self.val_data = batch_generator(val_week_stock, val_month_stock,
                                            val_t_month_stock, val_input_cap_vector,
                                            val_output_cap_vector, self.batch_size)
            self.val_total_iter = len(val_input_cap_vector)
            
    # Let GPU memory allocation grow on demand
    def gpu_session_config(self):
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        return config

    def _init_train(self):
        self.train_graph = tf.Graph()
        with self.train_graph.as_default():
            with tf.variable_scope('encoder_input'):
                self.week_input = tf.placeholder(tf.float64, shape=[None, 7], name='week_input')
                self.month_input = tf.placeholder(tf.float64, shape=[None, 28], name='month_input')
                self.t_month_input = tf.placeholder(tf.float64, shape=[None, 84], name='t_month_input')

            with tf.variable_scope("decoder_input"):
                self.decoder_input = tf.placeholder(tf.int32, [None, self.max_length], name='input')
                self.decoder_target = tf.placeholder(tf.int32, [None, self.max_length], name='target')
                self.decoder_targets_length = tf.placeholder(tf.int32, shape=[self.batch_size], name='targets_length')
                
            encoded_output, encoded_state = encoder_module(self.week_input,
                                                         self.month_input,
                                                         self.t_month_input,
                                                         self.encoder_layer_num,
                                                         self.decoder_layer_num,
                                                         self.hidden_dim)

            decoder_output, decoder_state = decoder_module(encoded_state,
                                                          encoded_output,
                                                          self.decoder_input,
                                                          self.decoder_targets_length,
                                                          self.embedding_matrix,
                                                          self.decoder_layer_num,
                                                          self.hidden_dim,
                                                          self.max_length,
                                                          self.vocab_size,
                                                          self.batch_size,
                                                          self.dropout,
                                                          self.SOS_token, 
                                                          self.EOS_token, 
                                                          train = True)

            self.logits = decoder_output.rnn_output
            # training output (sampled token ids)
            self.sample_id = decoder_output.sample_id
            
            self._init_optimizer()
            
            self.train_init = tf.global_variables_initializer()
            self.train_saver = tf.train.Saver()
        self.train_session = tf.Session(graph=self.train_graph, config=self.gpu_session_config())
        

    def _init_optimizer(self):
        # Mask padded positions so they don't contribute to the loss
        mask = tf.cast(tf.sequence_mask(self.decoder_targets_length, self.max_length), tf.float64)
        self.loss = tf.contrib.seq2seq.sequence_loss(logits=self.logits,
                                                     targets=self.decoder_target,
                                                     weights=mask,
                                                     average_across_timesteps=True,
                                                     average_across_batch=True)
        #tf.summary.scalar('loss', self.loss)
        #self.summary_op = tf.summary.merge_all()
 
        params = tf.trainable_variables()
        gradients = tf.gradients(self.loss, params)
        clipped_gradients, _ = tf.clip_by_global_norm(gradients, 5.0)
        self.optimizer = tf.train.AdamOptimizer(self.learning_rate).apply_gradients(zip(clipped_gradients, params))
        

    # Computed per mini-batch
    def cal_metrics(self, infer_text, real_text):
        self.bleu.reset(infer_text=infer_text, real_text=real_text)
        return self.bleu.get_score()

    # BLEU metric and greedy/beam-search inference graph
    def _init_eval(self):
        self.eval_graph = tf.Graph()
        with self.eval_graph.as_default():
            self.eval_week_input = tf.placeholder(tf.float64, shape=[None, 7])
            self.eval_month_input = tf.placeholder(tf.float64, shape=[None, 28])
            self.eval_t_month_input = tf.placeholder(tf.float64, shape=[None, 84])
            self.eval_decoder_targets_length = tf.placeholder(tf.int32, shape=[self.batch_size])
            eval_encoded_output, eval_encoded_state = encoder_module(self.eval_week_input,
                                                                     self.eval_month_input,
                                                                     self.eval_t_month_input,
                                                                     self.encoder_layer_num,
                                                                     self.decoder_layer_num,
                                                                     self.hidden_dim)
                
            self.eval_decoder_output, eval_decoder_state = decoder_module(eval_encoded_state,
                                                                          eval_encoded_output,
                                                                          None,
                                                                          self.eval_decoder_targets_length,
                                                                          self.embedding_matrix,
                                                                          self.decoder_layer_num,
                                                                          self.hidden_dim,
                                                                          self.max_length,
                                                                          self.vocab_size,
                                                                          self.batch_size,
                                                                          self.dropout,
                                                                          self.SOS_token, 
                                                                          self.EOS_token, 
                                                                          train = False)
            
            # Beam-search output has shape [batch, time, beam_width]; beam 0 is the best
            self.predicted_ids = tf.identity(self.eval_decoder_output.predicted_ids)
            self.eval_saver = tf.train.Saver()
        self.eval_session = tf.Session(graph=self.eval_graph, config=self.gpu_session_config())
            
    def train_epoch(self, epochs):
        if not self.init_train:
            raise RuntimeError('Train graph is not initialized')
        with self.train_graph.as_default():
            if os.path.isfile(cfg.save_path + '.meta'):
                print("##########################")
                print('#     Model restore..    #')
                print("##########################")
                self.train_saver.restore(self.train_session, cfg.save_path)
            else:
                self.train_session.run(self.train_init)
            start_time = time.time()
            for e in range(epochs):
                for step in range(self.total_iter // self.batch_size):
                    data = next(self.train_data)
                    week_stock = data['week_stock']
                    month_stock = data['month_stock']
                    t_month_stock = data['t_month_stock']
                    decoder_input = data['decoder_input']
                    decoder_target = data['decoder_target']
                    batch_seq = batch_seq_len(data['decoder_target'])
                    _, loss, sample_id = self.train_session.run(
                        [self.optimizer, self.loss, self.sample_id],
                        feed_dict={self.week_input: week_stock,
                                   self.month_input: month_stock,
                                   self.t_month_input: t_month_stock,
                                   self.decoder_input: decoder_input,
                                   self.decoder_target: decoder_target,
                                   self.decoder_targets_length: batch_seq})
                end = time.time()
                print('epoch: {}|{}  minibatch loss: {:.6f}   Time: {:.1f} min'.format(
                    e + 1, epochs, loss, (end - start_time) / 60))
                
                if e % 50 == 0:
                    self.train_saver.save(self.train_session, cfg.save_path)
                    # Pick a random sample id and print its training target/output text
                    sid = random.randint(0, self.batch_size - 1)
                    target_text = decode_text(decoder_target[sid], self.idx2word_dict)
                    output_text = decode_text(sample_id[sid], self.idx2word_dict)
                    print('============ training sample text =============')
                    print('training_target :' + target_text)
                    print('training_output :' + output_text)
                    print('===============================================')
                    self.eval()

    def eval(self):
        with self.eval_graph.as_default():
            self.eval_saver.restore(self.eval_session, cfg.save_path)
            all_bleu = [0] * 4  # running sums for BLEU-2..5 over validation batches
            for step in range(self.val_total_iter // self.batch_size):
                data = next(self.val_data)
                week_stock = data['week_stock']
                month_stock = data['month_stock']
                t_month_stock = data['t_month_stock']
                batch_seq = batch_seq_len(data['decoder_target'])
                #beam search_output
                beam_output = self.eval_session.run([self.predicted_ids], 
                                                    feed_dict = {self.eval_week_input : week_stock,
                                                                 self.eval_month_input : month_stock,
                                                                 self.eval_t_month_input : t_month_stock,
                                                                 self.eval_decoder_targets_length : batch_seq
                                                                })   
                
                target_text = idx_to_text(data['decoder_input'][:, 1:], self.idx2word_dict)
                target_text = remove_sent_pad(target_text)

                beam_output = np.squeeze(np.array(beam_output), axis=0)
                output_text = idx_to_text(beam_output[:, :, 0], self.idx2word_dict)
                # Inferred text first, reference text second (matches the cal_metrics signature)
                bleu_score = self.cal_metrics(output_text, target_text)

                for idx,score in enumerate(bleu_score):
                    all_bleu[idx] += score
            print('================ BLEU score ================')
            # Report the accumulated per-batch scores averaged over the validation set
            num_batches = max(1, self.val_total_iter // self.batch_size)
            for idx, bleu in enumerate(all_bleu):  # grams 2, 3, 4, 5
                print('BLEU-{} : {}'.format(idx + 2, bleu / num_batches))
            sid = random.randint(0, self.batch_size - 1)
            target_text = decode_text(data['decoder_target'][sid], self.idx2word_dict)
            output_text = decode_text(beam_output[sid, :, 0], self.idx2word_dict)
            print('============= Beam search text =============')
            print('infer_target : ' + target_text)
            print('beam_search  : ' + output_text)
            print('============================================')
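
A minimal usage sketch for code example #2; the hyperparameter values are illustrative assumptions, not tuned settings from the source.

if __name__ == '__main__':
    model = Model(encoder_layer_num=2,
                  decoder_layer_num=2,
                  hidden_dim=256,
                  batch_size=64,
                  learning_rate=1e-3,
                  dropout=0.3,
                  init_train=True)
    # Trains for 500 epochs; checkpoints and runs eval() every 50 epochs.
    model.train_epoch(500)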