Example #1
0
class Solver(object):
    """Build, train, evaluate and sample from a conversation model.

    Each batch produced by the data loaders is a triple
    ``(conversations, conversation_length, sentence_length)``:

    * ``conversations``: list of conversations, each a list of sentences,
      each a list of token ids;
    * ``conversation_length``: list of int (sentences per conversation);
    * ``sentence_length``: per-conversation list of sentence lengths.

    Within a conversation every sentence except the last is used as input
    and every sentence except the first as the corresponding target.
    """

    def __init__(self,
                 config,
                 train_data_loader,
                 eval_data_loader,
                 vocab,
                 is_train=True,
                 model=None):
        """Store the configuration, data loaders and vocabulary.

        Args:
            config: experiment configuration (model name, paths, optimizer,
                schedule and sampling settings are read from it).
            train_data_loader: iterable of training batches.
            eval_data_loader: iterable of evaluation batches.
            vocab: vocabulary object; its ``decode`` turns token ids into text.
            is_train: when True, ``build`` also creates the Tensorboard writer
                and the optimizer.
            model: optional pre-built model; when None ``build`` creates one.
        """
        self.config = config
        self.epoch_i = 0
        self.train_data_loader = train_data_loader
        self.eval_data_loader = eval_data_loader
        self.vocab = vocab
        self.is_train = is_train
        self.model = model

    @time_desc_decorator('Build Graph')
    def build(self, cuda=True):
        """Create/initialise the model, optionally restore a checkpoint and
        set up logging and the optimizer.

        Args:
            cuda: move the model to the GPU when CUDA is available.
        """
        if self.model is None:
            self.model = getattr(models, self.config.model)(self.config)

            # Orthogonal initialisation for recurrent (hidden-to-hidden)
            # weights and a positive gate bias, only for fresh training runs
            # (a checkpoint would overwrite them anyway).
            if self.config.mode == 'train' and self.config.checkpoint is None:
                print('Parameter initiailization')
                for name, param in self.model.named_parameters():
                    if 'weight_hh' in name:
                        print('\t' + name)
                        nn.init.orthogonal_(param)

                    # bias_hh is treated as three concatenated gate biases and
                    # only the middle slice is set to 2.0.
                    # NOTE(review): assumes GRU cells (3 gates); an LSTM
                    # bias_hh has 4 gate slices.
                    if 'bias_hh' in name:
                        print('\t' + name)
                        dim = int(param.size(0) / 3)
                        param.data[dim:2 * dim].fill_(2.0)

        if torch.cuda.is_available() and cuda:
            self.model.cuda()

        # Overview of parameters
        print('Model Parameters')
        for name, param in self.model.named_parameters():
            print('\t' + name + '\t', list(param.size()))

        if self.config.checkpoint:
            self.load_model(self.config.checkpoint)

        if self.is_train:
            self.writer = TensorboardWriter(self.config.logdir)
            self.optimizer = self.config.optimizer(
                filter(lambda p: p.requires_grad, self.model.parameters()),
                lr=self.config.learning_rate)

    def save_model(self, epoch):
        """Save the model ``state_dict`` to ``<save_path>/<epoch>.pkl``."""
        os.makedirs(self.config.save_path, exist_ok=True)
        ckpt_path = os.path.join(self.config.save_path, f'{epoch}.pkl')
        print(f'Save parameters to {ckpt_path}')
        torch.save(self.model.state_dict(), ckpt_path)

    def load_model(self, checkpoint):
        """Load parameters from ``checkpoint`` and resume the epoch counter.

        The epoch is parsed from the leading digits of the file name (the
        ``<epoch>.pkl`` scheme used by ``save_model``). If the name has no
        leading digits, the counter is left unchanged instead of failing on
        ``int('')``.
        """
        print(f'Load parameters from {checkpoint}')
        match = re.match(r'[0-9]+', os.path.basename(checkpoint))
        if match is not None:
            self.epoch_i = int(match.group(0))
        else:
            print(f'Could not parse an epoch from {checkpoint}; '
                  f'keeping epoch {self.epoch_i}')
        # Deserialise onto the CPU: load_state_dict copies into each
        # parameter's current device, so GPU checkpoints also load on
        # CPU-only machines.
        state_dict = torch.load(checkpoint, map_location='cpu')
        self.model.load_state_dict(state_dict)

    def write_summary(self, epoch_i):
        """Log every available statistic to Tensorboard at ``epoch_i + 1``.

        Statistics whose attribute was never set (or is None) are skipped.
        """
        # (attribute on self, Tensorboard tag), in logging order.
        summary_fields = (
            ('epoch_loss', 'train_loss'),
            ('epoch_recon_loss', 'train_recon_loss'),
            ('epoch_kl_div', 'train_kl_div'),
            ('kl_mult', 'kl_mult'),
            ('epoch_bow_loss', 'bow_loss'),
            ('validation_loss', 'validation_loss'),
            ('average_bleu', 'average_bleu'),
            ('average_sequences', 'average_sequences'),
            ('average_levenshteins', 'average_levenshteins'),
        )
        for attr, tag in summary_fields:
            value = getattr(self, attr, None)
            if value is not None:
                self.writer.update_loss(loss=value,
                                        step_i=epoch_i + 1,
                                        name=tag)

    @staticmethod
    def _prepare_batch(conversations, conversation_length, sentence_length):
        """Turn a batch into flattened (input, target) sentence tensors.

        Sentences ``0..k-2`` of each conversation become inputs and
        sentences ``1..k-1`` the targets; all conversations are flattened
        into one list of sentences.

        Returns:
            ``(input_sentences, target_sentences, input_sentence_length,
            target_sentence_length, input_conversation_length)``, each
            converted with ``to_var(torch.LongTensor(...))``.
        """
        input_sentences = [
            sent for conv in conversations for sent in conv[:-1]
        ]
        target_sentences = [
            sent for conv in conversations for sent in conv[1:]
        ]
        input_sentence_length = [
            l for len_list in sentence_length for l in len_list[:-1]
        ]
        target_sentence_length = [
            l for len_list in sentence_length for l in len_list[1:]
        ]
        input_conversation_length = [l - 1 for l in conversation_length]
        return tuple(
            to_var(torch.LongTensor(values))
            for values in (input_sentences, target_sentences,
                           input_sentence_length, target_sentence_length,
                           input_conversation_length))

    @time_desc_decorator('Training Start!')
    def train(self):
        """Train from epoch ``self.epoch_i`` up to ``config.n_epoch``.

        After each epoch the model may be checkpointed
        (``save_every_epoch``), is validated, and summaries may be written
        (``plot_every_epoch``). A final checkpoint is always saved.

        Returns:
            list of per-epoch training loss, summed over batches and divided
            by the number of target words.
        """
        epoch_loss_history = []
        for epoch_i in range(self.epoch_i, self.config.n_epoch):
            self.epoch_i = epoch_i
            batch_loss_history = []
            self.model.train()
            n_total_words = 0
            for batch_i, (conversations, conversation_length,
                          sentence_length) in enumerate(
                              tqdm(self.train_data_loader, ncols=80)):
                (input_sentences, target_sentences, input_sentence_length,
                 target_sentence_length,
                 input_conversation_length) = self._prepare_batch(
                     conversations, conversation_length, sentence_length)

                # reset gradient
                self.optimizer.zero_grad()

                sentence_logits = self.model(input_sentences,
                                             input_sentence_length,
                                             input_conversation_length,
                                             target_sentences,
                                             decode=False)

                batch_loss, n_words = masked_cross_entropy(
                    sentence_logits, target_sentences, target_sentence_length)

                # Read the scalars once (each .item() synchronises the device).
                loss_value = batch_loss.item()
                n_words_value = n_words.item()
                assert not isnan(loss_value)
                batch_loss_history.append(loss_value)
                n_total_words += n_words_value

                if batch_i % self.config.print_every == 0:
                    tqdm.write(
                        f'Epoch: {epoch_i+1}, iter {batch_i}: loss = {loss_value/ n_words_value:.3f}'
                    )

                # Back-propagation, gradient clipping, parameter update
                batch_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.config.clip)
                self.optimizer.step()

            epoch_loss = np.sum(batch_loss_history) / n_total_words
            epoch_loss_history.append(epoch_loss)
            self.epoch_loss = epoch_loss

            print(f'Epoch {epoch_i+1} loss average: {epoch_loss:.3f}')

            if epoch_i % self.config.save_every_epoch == 0:
                self.save_model(epoch_i + 1)

            print('\n<Validation>...')
            self.validation_loss = self.evaluate()

            if epoch_i % self.config.plot_every_epoch == 0:
                self.write_summary(epoch_i)

        self.save_model(self.config.n_epoch)

        return epoch_loss_history

    def generate_sentence(self, input_sentences, input_sentence_length,
                          input_conversation_length, target_sentences):
        """Decode responses for one batch, print them and append them to
        ``<save_path>/samples.txt`` under an ``<Epoch N>`` header."""
        self.model.eval()

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            generated_sentences = self.model(input_sentences,
                                             input_sentence_length,
                                             input_conversation_length,
                                             target_sentences,
                                             decode=True)

        with open(os.path.join(self.config.save_path, 'samples.txt'),
                  'a') as f:
            f.write(f'<Epoch {self.epoch_i}>\n\n')

            tqdm.write('\n<Samples>')
            for input_sent, target_sent, output_sent in zip(
                    input_sentences, target_sentences, generated_sentences):
                input_sent = self.vocab.decode(input_sent)
                target_sent = self.vocab.decode(target_sent)
                output_sent = '\n'.join(
                    [self.vocab.decode(sent) for sent in output_sent])
                s = '\n'.join([
                    'Input sentence: ' + input_sent,
                    'Ground truth: ' + target_sent,
                    'Generated response: ' + output_sent + '\n'
                ])
                f.write(s + '\n')
                print(s)
            print('')

    def _accumulate_loss(self, sample_first_batch=False):
        """Run the model over ``eval_data_loader`` and sum the masked loss.

        All forward passes run under ``torch.no_grad()``. The original code
        wrapped only the tensor construction, so every evaluation forward
        pass still built an autograd graph.

        Args:
            sample_first_batch: also call ``generate_sentence`` on batch 0.

        Returns:
            ``(loss per target word, total number of target words)``.
        """
        self.model.eval()
        batch_loss_history = []
        n_total_words = 0
        with torch.no_grad():
            for batch_i, (conversations, conversation_length,
                          sentence_length) in enumerate(
                              tqdm(self.eval_data_loader, ncols=80)):
                (input_sentences, target_sentences, input_sentence_length,
                 target_sentence_length,
                 input_conversation_length) = self._prepare_batch(
                     conversations, conversation_length, sentence_length)

                if sample_first_batch and batch_i == 0:
                    self.generate_sentence(input_sentences,
                                           input_sentence_length,
                                           input_conversation_length,
                                           target_sentences)

                sentence_logits = self.model(input_sentences,
                                             input_sentence_length,
                                             input_conversation_length,
                                             target_sentences)

                batch_loss, n_words = masked_cross_entropy(
                    sentence_logits, target_sentences, target_sentence_length)

                loss_value = batch_loss.item()
                assert not isnan(loss_value)
                batch_loss_history.append(loss_value)
                n_total_words += n_words.item()

        return np.sum(batch_loss_history) / n_total_words, n_total_words

    def evaluate(self):
        """Return the per-word validation loss and write sample outputs for
        the first evaluation batch."""
        epoch_loss, _ = self._accumulate_loss(sample_first_batch=True)
        print(f'Validation loss: {epoch_loss:.3f}\n')
        return epoch_loss

    def test(self):
        """Return word perplexity, ``exp(per-word loss)``, on the eval set.

        NOTE(review): since perplexity is ``exp`` of the loss, the value
        printed as "Bits per word" is presumably in nats.
        """
        epoch_loss, n_total_words = self._accumulate_loss()

        print(f'Number of words: {n_total_words}')
        print(f'Bits per word: {epoch_loss:.3f}')
        word_perplexity = np.exp(epoch_loss)

        print(f'Word perplexity : {word_perplexity:.3f}\n')

        return word_perplexity

    def embedding_metric(self):
        """Compute embedding-based response metrics on the eval set.

        For every eval conversation with at least
        ``n_context + n_sample_step`` sentences, the model generates
        ``n_sample_step`` sentences from the first ``n_context``. Each
        generated sentence is compared with the ground truth through word2vec
        vectors. Pairs where either side has no in-vocabulary word are
        dropped. Batches without any long-enough conversation are skipped.

        Returns:
            ``(average, extrema, greedy)`` means over all kept pairs.

        Raises:
            ValueError: if no conversation in the eval set is long enough.
        """
        word2vec = getattr(self, 'word2vec', None)
        if word2vec is None:
            print('Loading word2vec model')
            word2vec = gensim.models.KeyedVectors.load_word2vec_format(
                word2vec_path, binary=True)
            self.word2vec = word2vec  # cache for subsequent calls
        # NOTE(review): ``KeyedVectors.vocab`` only exists in gensim < 4
        # (gensim 4 renamed it to ``key_to_index``).
        keys = word2vec.vocab
        self.model.eval()
        n_context = self.config.n_context
        n_sample_step = self.config.n_sample_step
        min_length = n_context + n_sample_step
        metric_average_history = []
        metric_extrema_history = []
        metric_greedy_history = []
        n_sent = 0
        for conversations, conversation_length, sentence_length in tqdm(
                self.eval_data_loader, ncols=80):
            conv_indices = [
                i for i, conv in enumerate(conversations)
                if len(conv) >= min_length
            ]
            if not conv_indices:
                # Nothing to sample from; an empty context tensor would
                # break generation.
                continue
            context = [conversations[i][:n_context] for i in conv_indices]
            ground_truth = [
                conversations[i][n_context:min_length] for i in conv_indices
            ]
            context_sentence_length = [
                sentence_length[i][:n_context] for i in conv_indices
            ]

            with torch.no_grad():
                context = to_var(torch.LongTensor(context))
                context_sentence_length = to_var(
                    torch.LongTensor(context_sentence_length))
                samples = self.model.generate(context,
                                              context_sentence_length,
                                              n_context)
            samples = samples.data.cpu().numpy().tolist()

            # Decode to text and flatten conversations into sentence lists.
            samples = [
                self.vocab.decode(sent) for conv in samples for sent in conv
            ]
            ground_truth = [
                self.vocab.decode(sent) for conv in ground_truth
                for sent in conv
            ]

            # Map in-vocabulary words to their word2vec vectors.
            samples = [[word2vec[w] for w in sent.split() if w in keys]
                       for sent in samples]
            ground_truth = [[word2vec[w] for w in sent.split() if w in keys]
                            for sent in ground_truth]

            # Drop pairs where either side has no known word.
            pairs = [(s, g) for s, g in zip(samples, ground_truth) if s and g]
            samples = [s for s, _ in pairs]
            ground_truth = [g for _, g in pairs]
            n_sent += len(samples)

            # ``embedding_metric`` below is the module-level function, not
            # this method (method names are not in scope inside the body).
            metric_average_history.append(
                embedding_metric(samples, ground_truth, word2vec, 'average'))
            metric_extrema_history.append(
                embedding_metric(samples, ground_truth, word2vec, 'extrema'))
            metric_greedy_history.append(
                embedding_metric(samples, ground_truth, word2vec, 'greedy'))

        if not metric_average_history:
            raise ValueError(
                'embedding_metric: no evaluation conversation has at least '
                f'{min_length} sentences')

        epoch_average = np.mean(np.concatenate(metric_average_history), axis=0)
        epoch_extrema = np.mean(np.concatenate(metric_extrema_history), axis=0)
        epoch_greedy = np.mean(np.concatenate(metric_greedy_history), axis=0)

        print('n_sentences:', n_sent)
        print_str = f'Metrics - Average: {epoch_average:.3f}, Extrema: {epoch_extrema:.3f}, Greedy: {epoch_greedy:.3f}'
        print(print_str)
        print('\n')

        return epoch_average, epoch_extrema, epoch_greedy
Example #2
0
class Solver(object):
    def __init__(self,
                 config=None,
                 train_loader=None,
                 test_loader=None,
                 valid_loader=None):
        """Class that builds, trains and evaluates the SC-LSTM model.

        Args:
            config: run configuration. Required despite the ``None`` default:
                ``config.gpu`` is exported as ``CUDA_VISIBLE_DEVICES``.
            train_loader: iterable of ``(doc, kwd)`` training batches.
            test_loader: evaluation loader.
            valid_loader: validation loader. It was previously accepted but
                discarded; it is now stored as ``self.valid_loader``.
        """
        self.config = config
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.valid_loader = valid_loader
        os.environ["CUDA_VISIBLE_DEVICES"] = self.config.gpu
        # NOTE: pickle.load can execute arbitrary code; only load trusted,
        # locally produced vocabulary files. ``with`` closes the handles,
        # which the bare ``open(...)`` calls leaked.
        with open(p.word_vocab_pkl, 'rb') as f:
            self.vocab = pickle.load(f)
        with open(p.kwd_pkl, 'rb') as f:
            self.kvoc = pickle.load(f)
        self.i2w = dict(enumerate(self.vocab))  # index to vocab
        self.i2k = dict(enumerate(self.kvoc))  # index to keyword
        self.w2i = {w: i for i, w in self.i2w.items()}

    def build(self):
        """Create the SC-LSTM modules, loss, initial states, output file and
        optimizer.

        Side effects:
            * In ``train`` mode, every parameter of ``self.model`` is
              re-initialised (Xavier-uniform for tensors with >= 2 dims,
              zeros otherwise). The embedding is excluded because it is
              appended afterwards. This mode also creates a Tensorboard
              writer and a timestamped checkpoint directory.
            * In every mode, ``self.out_file`` is opened for writing.
        """
        # --- Word embedding, initialised from pre-trained vectors ---
        self.embedding = nn.Embedding(self.config.vocab_size,
                                      self.config.wemb_size,
                                      padding_idx=0)
        # NOTE: pickle.load can execute arbitrary code; only load trusted,
        # locally produced vector files.
        with open(p.word_vec_pkl, 'rb') as f:
            weights_matrix = torch.FloatTensor(pickle.load(f))
        if weights_matrix.shape != self.embedding.weight.shape:
            raise ValueError(
                'pre-trained embedding shape {} does not match '
                '(vocab_size, wemb_size) = {}'.format(
                    tuple(weights_matrix.shape),
                    tuple(self.embedding.weight.shape)))
        # ``nn.Embedding.from_pretrained`` is a classmethod that returns a NEW
        # module, so calling it on the instance silently discarded the
        # vectors. Copy them into the existing weight instead; this keeps
        # padding_idx=0.
        with torch.no_grad():
            self.embedding.weight.copy_(weights_matrix)
        self.embedding.weight.requires_grad = True

        # --- Keyword-detector projections, one independent Linear per layer.
        # ``num_layers * [nn.Linear(...)]`` repeated ONE module object and so
        # tied every layer's weights. State-dict keys are unchanged, but an
        # optimizer state saved with tied weights has fewer entries when
        # num_layers > 1.
        self.w_hr_fw = nn.ModuleList([
            nn.Linear(self.config.hidden_size,
                      self.config.kwd_size,
                      bias=False) for _ in range(self.config.num_layers)
        ])
        self.w_hr_bw = nn.ModuleList([
            nn.Linear(self.config.hidden_size,
                      self.config.kwd_size,
                      bias=False) for _ in range(self.config.num_layers)
        ])

        self.w_wr = nn.Linear(self.config.wemb_size,
                              self.config.kwd_size,
                              bias=False)
        # --- Output projections to vocabulary logits. No LogSoftmax here:
        # nn.CrossEntropyLoss applies it. The Sequential wrapper is kept so
        # checkpoint keys stay the same.
        self.w_ho_fw = nn.Sequential(
            nn.Linear(self.config.hidden_size * self.config.num_layers,
                      self.config.vocab_size))
        self.w_ho_bw = nn.Linear(
            self.config.hidden_size * self.config.num_layers,
            self.config.vocab_size)
        self.sc_rnn_fw = SCLSTM_MultiCell(self.config.num_layers,
                                          self.config.wemb_size,
                                          self.config.hidden_size,
                                          self.config.kwd_size,
                                          dropout=self.config.drop_rate)

        self.sc_rnn_bw = SCLSTM_MultiCell(self.config.num_layers,
                                          self.config.wemb_size,
                                          self.config.hidden_size,
                                          self.config.kwd_size,
                                          dropout=self.config.drop_rate)

        # The order defines the checkpoint keys (0..6, embedding appended as 7).
        self.model = nn.ModuleList([
            self.w_hr_fw, self.w_hr_bw, self.w_wr, self.w_ho_fw, self.w_ho_bw,
            self.sc_rnn_fw, self.sc_rnn_bw
        ])

        self.criterion = nn.CrossEntropyLoss(reduction='none')

        # Zero initial (h, c) states of shape
        # (num_layers, batch_size, hidden_size), reused for every batch.
        # NOTE(review): the fixed batch_size assumes every batch is full.
        self.hc_list_init = (torch.zeros(self.config.num_layers,
                                         self.config.batch_size,
                                         self.config.hidden_size),
                             torch.zeros(self.config.num_layers,
                                         self.config.batch_size,
                                         self.config.hidden_size))

        # --- Output dirs ---
        self.current_time = datetime.now().strftime('%b%d_%H-%M-%S')
        if self.config.mode == 'train':
            print('Init Model Parameters')
            for name, param in self.model.named_parameters():
                print('\t' + name + '\t', list(param.size()))
                if param.data.ndimension() >= 2:
                    nn.init.xavier_uniform_(param.data)
                else:
                    nn.init.zeros_(param.data)

            self.writer = TensorboardWriter(p.tb_dir + self.current_time)
            self.model.train()
            self.cp_dir = p.check_point.format(
                p.dataname, self.current_time)  # checkpoint dir
            os.makedirs(self.cp_dir, exist_ok=True)

        # --- Output file: kept open for the solver's lifetime ---
        self.out_file = open(
            p.out_result_dir.format(p.dataname, self.current_time), 'w')

        # Appended after initialisation so the pre-trained vectors survive.
        self.model.append(self.embedding)
        self.optimizer = optim.Adam(list(self.model.parameters()),
                                    lr=self.config.lr)
        print(self.model)

    def load_model(self, ep):
        """Restore model and optimizer state from ``chk_point_<ep>.pth``.

        The file is looked up in ``self.cp_dir`` in train mode and in
        ``config.resume_dir`` otherwise. The directory string is concatenated
        directly, so it must end with a path separator. On success this sets
        ``self.start_epoch``; a missing file only prints a message.
        """
        base_dir = (self.cp_dir if self.config.mode == 'train' else
                    self.config.resume_dir)
        ckpt_file = base_dir + 'chk_point_{}.pth'.format(ep)
        if not os.path.isfile(ckpt_file):
            print("=> no checkpoint found at '{}'".format(ckpt_file))
            return

        print("=> loading checkpoint '{}'".format(ckpt_file))
        if self.config.load_cpu:
            # Keep every storage on the CPU, whatever device it was saved from.
            checkpoint = torch.load(ckpt_file,
                                    map_location=lambda storage, loc: storage)
        else:
            checkpoint = torch.load(ckpt_file)  # original devices (GPU)
        self.start_epoch = checkpoint['epoch']
        self.model.load_state_dict(checkpoint['state_dict'])
        # Optimizer state is stored wrapped in a one-element list.
        self.optimizer.load_state_dict(checkpoint['optimizer'][0])

    def _zero_grads(self):
        """Clear the gradients of every parameter held by the optimizer."""
        optimizer = self.optimizer
        optimizer.zero_grad()

    def save_checkpoint(self, state, filename):
        """Serialise ``state`` (any picklable object) to ``filename`` with
        ``torch.save``."""
        target_path = filename
        torch.save(state, target_path)

    def get_norm_grad(self, module, norm_type=2):
        """Return the total gradient norm of ``module``'s parameters.

        Parameters without a gradient are skipped. The total is the
        ``norm_type``-norm of the per-parameter gradient norms, which equals
        the norm of all gradients concatenated (``float('inf')`` gives the
        largest absolute entry). This is the same definition that
        ``torch.nn.utils.clip_grad_norm_`` uses.

        Args:
            module: an ``nn.Module``.
            norm_type: order of the norm. It was previously ignored and the
                2-norm was always used; the default is unchanged.

        Returns:
            Detached 0-dim tensor; ``tensor(0.)`` when no parameter has a
            gradient (previously ``torch.sqrt(0)`` raised a TypeError).
        """
        grad_norms = [
            param.grad.detach().norm(norm_type)
            for param in module.parameters() if param.grad is not None
        ]
        if not grad_norms:
            return torch.tensor(0.)
        return torch.stack(grad_norms).norm(norm_type)

    def one_step_fw(self, w_t, y_t, hc_list, d_t, rnn_model, w_hr, w_ho):
        """Advance one SC-LSTM direction by a single time step.

        The keyword vector is first gated by a reading gate
        ``r_t = sigmoid(w_wr(w_t) + alpha * sum_l w_hr[l](h_{t-1}[l]))``
        and updated as ``d_t <- r_t * d_t``. It is then fed to the cell
        together with ``w_t``.

        Args:
            w_t: embedded input word for this step.
            y_t: target word ids; id 0 is masked out of the loss.
            hc_list: ``(h, c)`` pair; ``h`` is indexed by layer.
            d_t: keyword vector from the previous step.
            rnn_model: cell called as ``rnn_model(w_t, hc_list, d_t)`` and
                returning ``(flat_h, hc_list)``.
            w_hr: per-layer hidden-to-gate projections.
            w_ho: projection from ``flat_h`` to vocabulary logits.

        Returns:
            ``(llk_step, l1_step, pred, hc_list, d_t)``. ``llk_step`` is the
            masked per-example loss averaged over the whole batch (padding
            included in the denominator). ``l1_step`` is the batch mean of the
            L1 norm of the updated ``d_t``. The remaining values are the
            logits, the new states and the updated ``d_t``.
        """
        prev_h, _ = hc_list

        # Keyword detector (reading gate)
        layer_sum = sum([
            w_hr[layer](prev_h[layer])
            for layer in range(self.config.num_layers)
        ])
        gate = torch.sigmoid(self.w_wr(w_t) + self.config.alpha * layer_sum)
        d_t = gate * d_t
        flat_h, hc_list = rnn_model(w_t, hc_list, d_t)

        with torch.no_grad():
            mask = (y_t != 0).float()
            assert not torch.isnan(mask).any()
        pred = w_ho(flat_h)
        llk_step = (self.criterion(pred, y_t) * mask).mean()
        l1_step = d_t.abs().sum(dim=-1).mean()
        assert not torch.isnan(llk_step).any()
        assert not torch.isnan(l1_step).any()
        return llk_step, l1_step, pred, hc_list, d_t

    def train_epoch(self):
        """Train for one pass over ``train_loader``.

        For each ``(doc, kwd)`` batch, the forward (left-to-right) and
        backward (right-to-left) SC-LSTMs are unrolled. Each direction's loss
        is back-propagated, and one clipped optimizer step is taken on the
        accumulated gradients.

        Returns:
            ``(loss_list, l1_list, fw_list, bw_list)``: per-batch mean of the
            FW/BW NLL, mean of the FW/BW scaled L1 terms, FW NLL and BW NLL,
            as Python floats.
        """
        # build() switches to train mode, but other code (e.g. evaluation)
        # may have switched the model to eval mode since.
        self.model.train()
        loss_list = []
        l1_list = []
        fw_list, bw_list = [], []
        n_steps = p.MAX_DOC_LEN - 1
        for doc, kwd in tqdm(self.train_loader,
                             desc='Batch',
                             dynamic_ncols=True,
                             ascii=True):
            self._zero_grads()
            doc_emb = self.embedding(doc)  # get word-emb

            # --- FW stage: predict word t+1 from word t ---
            fw_loss, fw_l1 = self._unroll_direction(doc_emb, doc, kwd,
                                                    range(n_steps), 1,
                                                    self.sc_rnn_fw,
                                                    self.w_hr_fw,
                                                    self.w_ho_fw)
            # retain_graph: the embedding lookup is reused by the BW stage.
            (fw_loss + fw_l1).backward(retain_graph=True)

            # --- BW stage: predict word t-1 from word t ---
            torch.cuda.empty_cache()
            bw_loss, bw_l1 = self._unroll_direction(doc_emb, doc, kwd,
                                                    range(n_steps, 0, -1), -1,
                                                    self.sc_rnn_bw,
                                                    self.w_hr_bw,
                                                    self.w_ho_bw)
            # Last backward of the batch: the graph can be freed now.
            (bw_loss + bw_l1).backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           self.config.clip)
            self.optimizer.step()

            # --- tracking loss ---
            loss_list.append(0.5 * (fw_loss + bw_loss).item())
            l1_list.append(0.5 * (fw_l1 + bw_l1).item())
            fw_list.append(fw_loss.item())
            bw_list.append(bw_loss.item())

        return loss_list, l1_list, fw_list, bw_list

    def _unroll_direction(self, doc_emb, var_doc, var_kwd, time_steps,
                          target_offset, rnn_model, w_hr, w_ho):
        """Unroll one SC-LSTM direction over a batch.

        Args:
            doc_emb: embedded batch, indexed as ``doc_emb[:, t, :]``.
            var_doc: word ids, indexed as ``var_doc[:, t]``.
            var_kwd: initial keyword vector ``d_0``.
            time_steps: input positions ``t`` in processing order.
            target_offset: position of the predicted word relative to ``t``
                (``+1`` forward, ``-1`` backward).
            rnn_model, w_hr, w_ho: that direction's cell and projections.

        Returns:
            ``(sum of per-step NLL, config.eta * sum of per-step L1)``.
        """
        hc_list = self.hc_list_init
        d_t = var_kwd
        step_loss = []
        step_l1 = []
        for t in time_steps:
            llk_step, l1_step, _, hc_list, d_t = self.one_step_fw(
                doc_emb[:, t, :], var_doc[:, t + target_offset], hc_list, d_t,
                rnn_model, w_hr, w_ho)
            step_loss.append(llk_step)
            step_l1.append(l1_step)
        return sum(step_loss), sum(step_l1) * self.config.eta

    def train(self):
        """Run ``config.n_epoch`` training epochs.

        Every ``config.eval_rate`` epochs the solver optionally writes a
        checkpoint (``write_model``) and runs ``evaluate`` (``is_eval``).
        Each epoch's averages and the current gradient norm are printed, and
        the averages are sent to Tensorboard.
        """
        print('***Start training ...')
        epochs = tqdm(range(self.config.n_epoch),
                      desc='Epoch',
                      dynamic_ncols=True,
                      ascii=True)
        for epoch_i in epochs:
            loss_list, l1_list, fw_list, bw_list = self.train_epoch()
            epoch_no = epoch_i + 1

            if epoch_no % self.config.eval_rate == 0:
                if self.config.write_model:
                    state = {
                        'epoch': epoch_no,
                        'state_dict': self.model.state_dict(),
                        'total_loss': np.mean(loss_list),
                        'optimizer': [self.optimizer.state_dict()],
                    }
                    ckpt_name = self.cp_dir + 'chk_point_{}.pth'.format(
                        epoch_no)
                    self.save_checkpoint(state, filename=ckpt_name)

                if self.config.is_eval:
                    self.evaluate(epoch_no)

            summary = '\n***Ep-{} | Total_loss: {} [FW/BW {}/{}] | D-L1: {} | NORM: {}'
            print(
                summary.format(epoch_i, np.mean(loss_list), np.mean(fw_list),
                               np.mean(bw_list), np.mean(l1_list),
                               self.get_norm_grad(self.model)))

            for values, tag in ((loss_list, 'total_loss'),
                                (l1_list, 'l1_reg'),
                                (fw_list, 'fw_loss'),
                                (bw_list, 'bw_loss')):
                self.writer.update_loss(np.mean(values), epoch_i, tag)

    def gen_one_step(self, x, hc_list, d_t, rnn_model, w_hr, w_ho):
        """Advance one decoder (forward or backward) by a single token.

        Args:
            x: token ids for this step, turned into a ``torch.LongTensor`` and
                fed to ``self.embedding``.
            hc_list: two-element recurrent state; its first element is indexed
                per layer (``config.num_layers`` layers).
            d_t: control vector, gated by ``r_t`` before being passed to
                ``rnn_model``.
            rnn_model, w_hr, w_ho: the direction-specific modules, e.g.
                ``self.sc_rnn_fw`` / ``self.w_hr_fw`` / ``self.w_ho_fw``.

        Returns:
            ``(log_prob, hc_list, d_t)``: the log-softmax output as a squeezed
            numpy array, and the new recurrent state and gated control vector,
            detached and on the CPU.
        """
        # Run the whole step without autograd: every output is detached below,
        # so recording a graph here would only waste memory during decoding.
        with torch.no_grad():
            var_x = torch.LongTensor(x)
            d_t = d_t.detach()
            hc_list = self.to_gpu(hc_list)

            w_t = self.embedding(var_x)
            h_tm1, _ = hc_list
            res_hr = sum(w_hr[layer](h_tm1[layer])
                         for layer in range(self.config.num_layers))
            r_t = torch.sigmoid(self.w_wr(w_t) + self.config.alpha * res_hr)
            d_t = r_t * d_t
            flat_h, hc_list = rnn_model(w_t, hc_list, d_t)
            log_prob = nn.LogSoftmax(dim=-1)(w_ho(flat_h))
        return (log_prob.detach().cpu().numpy().squeeze(),
                self.to_cpu(hc_list), d_t.detach().cpu())

    def get_top_index(self, _prob):
        """Choose candidate token ids from a log-probability vector.

        Args:
            _prob: log-probabilities, e.g. the first output of ``gen_one_step``.

        Returns:
            When ``config.is_sample`` is set, ``config.beam_size`` distinct ids
            drawn without replacement from ``range(config.vocab_size)`` with
            probability proportional to ``exp(_prob)``. Otherwise all ids
            ordered by decreasing probability (along the last axis).
        """
        prob = np.exp(_prob)
        if self.config.is_sample:
            # exp() of float32 log-softmax output rarely sums to exactly 1 and
            # np.random.choice rejects p that is off by more than ~1e-8, so
            # renormalise in float64 (astype returns a fresh copy).
            p = prob.reshape(-1).astype(np.float64)
            p /= p.sum()
            return np.random.choice(self.config.vocab_size,
                                    self.config.beam_size,
                                    replace=False,
                                    p=p)
        return np.argsort(-prob)

    def to_cpu(self, _list):
        """Return a tuple holding each tensor detached from autograd and on the CPU."""
        detached = (tensor.detach() for tensor in _list)
        return tuple(t.cpu() for t in detached)

    def to_gpu(self, _list):
        """Wrap each tensor as a ``Variable`` with ``requires_grad=False``.

        NOTE(review): despite the name, no device transfer happens here; the
        tensors stay on whatever device they were already on.

        Returns:
            A tuple with one wrapped tensor per input element, in order.
        """
        wrapped = []
        for tensor in _list:
            wrapped.append(Variable(tensor, requires_grad=False))
        return tuple(wrapped)

    def rerank(self, beams, d_t):
        """Blend each beam's forward score with a backward-decoder score.

        Each beam is a tuple ``(score, words, ...)``. Its score is replaced by
        ``0.5 * (score + bw)``, where ``bw`` is the mean log-probability that the
        backward decoder (``sc_rnn_bw`` / ``w_hr_bw`` / ``w_ho_bw``) assigns to
        the reversed word list, starting from a zero state and ``d_t``. The
        remaining tuple elements are kept. ``beams`` is modified in place and
        also returned.

        NOTE(review): ``evaluate`` accumulates the forward score as an
        unnormalised sum of log-probs, while the backward score here is a
        per-token mean, so the two terms are on different scales. Confirm
        this weighting is intended.
        """
        def add_bw_score(w_list, d_t):
            # Mean backward log-likelihood of w_list read right-to-left.
            hidden_shape = (self.config.num_layers, 1, self.config.hidden_size)
            with torch.no_grad():
                hc_list = (torch.zeros(*hidden_shape),
                           torch.zeros(*hidden_shape))
            ids = [self.w2i[word] for word in reversed(w_list)]
            llk = 0.
            for cur, nxt in zip(ids[:-1], ids[1:]):
                log_prob, hc_list, d_t = self.gen_one_step([cur], hc_list, d_t,
                                                           self.sc_rnn_bw,
                                                           self.w_hr_bw,
                                                           self.w_ho_bw)
                llk += log_prob[nxt]
            return llk / (len(ids) - 1)

        for idx, beam in enumerate(beams):
            rest = tuple(beam[1:])
            blended = 0.5 * (beam[0] + add_bw_score(beam[1], d_t))
            beams[idx] = (blended,) + rest

        return beams

    def evaluate(self, epoch_i):
        """Decode one sentence per test item with beam search and log it.

        Loads the parameters for ``epoch_i`` (``load_model``), switches the
        model to eval mode and, for every item of ``self.test_loader``:

        1. skips it when its ``d_t`` sums to zero;
        2. runs ``config.gen_size`` steps of beam search of width
           ``config.beam_size`` with the forward decoder, starting from word
           id 1; candidates are ranked by summed log-probability divided by
           the number of generated words, and beams ending in id 2 (EOS) are
           carried over unchanged;
        3. re-scores the surviving beams with ``rerank`` and prints / writes
           the best one to ``self.out_file``.

        An exception while decoding one item is printed and that item is
        skipped. The model is switched back to train mode on exit, including
        when an exception escapes the loop.

        Args:
            epoch_i: passed to ``load_model`` and included in each output line.
        """
        self.load_model(epoch_i)
        self.model.eval()
        try:
            for doc_features in tqdm(self.test_loader,
                                     desc='Test',
                                     dynamic_ncols=True,
                                     ascii=True):
                _, d_t = doc_features
                try:
                    if torch.sum(d_t) == 0:
                        continue
                    #--- Gen 1st step ---
                    with torch.no_grad():
                        hc_list = (torch.zeros(self.config.num_layers, 1,
                                               self.config.hidden_size),
                                   torch.zeros(self.config.num_layers, 1,
                                               self.config.hidden_size))

                    b = (0.0, [self.i2w[1]], [1], hc_list, d_t)
                    _prob, hc_list, d_t = self.gen_one_step(
                        b[2], b[3], b[4], self.sc_rnn_fw, self.w_hr_fw,
                        self.w_ho_fw)
                    top_indices = self.get_top_index(_prob)
                    beams = []
                    for i in range(self.config.beam_size):
                        wordix = top_indices[i]
                        beams.append(
                            (b[0] + _prob[wordix], b[1] + [self.i2w[wordix]],
                             [wordix], hc_list, d_t))

                    #--- Gen the whole sentence ---
                    for _ in range(self.config.gen_size - 1):
                        beam_candidates = []
                        for b in beams:
                            # Beam already ended with EOS: keep it unchanged
                            # and skip the decoder step it would not use.
                            if b[2] == [2]:
                                beam_candidates.append(b)
                                continue
                            _prob, hc_list, d_t = self.gen_one_step(
                                b[2], b[3], b[4], self.sc_rnn_fw,
                                self.w_hr_fw, self.w_ho_fw)
                            top_indices = self.get_top_index(_prob)
                            for i in range(self.config.beam_size):
                                wordix = top_indices[i]
                                beam_candidates.append(
                                    (b[0] + _prob[wordix],
                                     b[1] + [self.i2w[wordix]], [wordix],
                                     hc_list, d_t))

                        # Rank by per-generated-word score, best first.
                        beam_candidates.sort(
                            key=lambda x: x[0] / (len(x[1]) - 1),
                            reverse=True)
                        beams = beam_candidates[:self.config.beam_size]

                    #--- RERANK beams ---
                    beams = self.rerank(beams, doc_features[1])
                    beams.sort(key=lambda x: x[0], reverse=True)

                    res = "[*]EP_{}_KW_[{}]_SENT_[{}]\n".format(
                        epoch_i, ' '.join([
                            self.i2k[int(j)] for j in torch.flatten(
                                torch.nonzero(doc_features[1][0])).numpy()
                        ]), ' '.join(beams[0][1]))
                    print(res)
                    self.out_file.write(res)
                    self.out_file.flush()
                except Exception as e:
                    # Best effort: report and move on to the next item.
                    print('Exception: ', str(e))
            # self.out_file is left open on purpose: evaluate() is called
            # again on later epochs and keeps appending to it.
        finally:
            self.model.train()
class Solver(object):
    """Base solver: builds a model from ``config``, manages checkpoints and
    Tensorboard summaries.

    ``train``, ``evaluate``, ``test`` and ``export_samples`` raise
    ``NotImplementedError`` and are meant to be provided by subclasses.
    """

    def __init__(self,
                 config,
                 train_data_loader,
                 eval_data_loader,
                 vocab,
                 is_train=True,
                 model=None):
        self.config = config
        self.epoch_i = 0
        self.train_data_loader = train_data_loader
        self.eval_data_loader = eval_data_loader
        self.vocab = vocab
        self.is_train = is_train
        self.model = model
        # Created by build() when is_train is set.
        self.writer = None
        self.optimizer = None
        # Per-epoch scalars read by write_summary(); None means "not logged".
        self.epoch_loss = None
        self.validation_loss = None

    def build(self, cuda=True):
        """Create the model if needed, initialise it and set up training state.

        When ``self.model`` is None, the class named ``config.model`` is looked
        up in ``models`` and instantiated with ``config``. The model is moved to
        the GPU when ``cuda`` is set and CUDA is available, and
        ``config.checkpoint`` (if any) is loaded. When ``is_train`` is set, a
        Tensorboard writer and an optimizer over trainable parameters are
        created.

        Args:
            cuda: move the model to the GPU when CUDA is available.
        """
        if self.model is None:
            self.model = getattr(models, self.config.model)(self.config)

            # Fresh training run only: orthogonal init for hidden-to-hidden
            # weights and bias 2.0 on the middle third of each bias_hh vector
            # (assumes the GRU gate layout reset/input/new; TODO confirm if
            # other RNN cell types are used).
            if self.config.mode == 'train' and self.config.checkpoint is None:
                print('Parameter initiailization')
                for name, param in self.model.named_parameters():
                    if 'weight_hh' in name:
                        print('\t' + name)
                        nn.init.orthogonal_(param)

                    if 'bias_hh' in name:
                        print('\t' + name)
                        dim = param.size(0) // 3
                        param.data[dim:2 * dim].fill_(2.0)

        if torch.cuda.is_available() and cuda:
            self.model.cuda()

        if self.config.checkpoint:
            self.load_model(self.config.checkpoint)

        if self.is_train:
            self.writer = TensorboardWriter(self.config.logdir)
            self.optimizer = self.config.optimizer(
                filter(lambda p: p.requires_grad, self.model.parameters()),
                lr=self.config.learning_rate)

    def save_model(self, epoch):
        """Save the model's state_dict to ``<config.save_path>/<epoch>.pkl``."""
        ckpt_path = os.path.join(self.config.save_path, f'{epoch}.pkl')
        print(f'Save parameters to {ckpt_path}')
        torch.save(self.model.state_dict(), ckpt_path)

    def load_model(self, checkpoint):
        """Load a state_dict written by ``save_model`` into ``self.model``.

        ``self.epoch_i`` is set from the leading digits of the file name, and a
        ``module.`` prefix (added by ``nn.DataParallel``) is stripped from keys.

        Raises:
            ValueError: if the file name does not start with an epoch number.
        """
        print(f'Load parameters from {checkpoint}')
        match = re.match(r"[0-9]+", os.path.basename(checkpoint))
        if match is None:
            raise ValueError(
                f'cannot parse an epoch number from checkpoint name '
                f'{checkpoint!r}; expected "<epoch>.pkl" as written by '
                f'save_model')
        self.epoch_i = int(match.group(0))
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only
        # hosts; load_state_dict copies values into the existing parameters,
        # so they stay on whatever device the model already lives on.
        chpt = torch.load(checkpoint, map_location='cpu')
        prefix = 'module.'
        new_state_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in chpt.items())
        self.model.load_state_dict(new_state_dict)

    def write_summary(self, epoch_i):
        """Log every per-epoch scalar that is set (not None) at step ``epoch_i + 1``."""
        # (attribute on self, Tensorboard tag), in logging order.
        summary_fields = (
            ('epoch_loss', 'train_loss'),
            ('epoch_recon_loss', 'train_recon_loss'),
            ('epoch_kl_div', 'train_kl_div'),
            ('kl_mult', 'kl_mult'),
            ('epoch_bow_loss', 'bow_loss'),
            ('validation_loss', 'validation_loss'),
        )
        step_i = epoch_i + 1
        for attr, tag in summary_fields:
            value = getattr(self, attr, None)
            if value is not None:
                self.writer.update_loss(loss=value, step_i=step_i, name=tag)

    def train(self):
        raise NotImplementedError

    def evaluate(self):
        raise NotImplementedError

    def test(self):
        raise NotImplementedError

    def export_samples(self, beam_size=5):
        raise NotImplementedError
Exemple #4
0
class Solver(object):
    def __init__(self, config=None, train_loader=None, test_loader=None):
        """Class that Builds, Trains and Evaluates AC-SUM-GAN model"""
        self.config = config
        self.train_loader = train_loader
        self.test_loader = test_loader

    def build(self):

        # Build Modules
        self.linear_compress = nn.Linear(self.config.input_size,
                                         self.config.hidden_size).cuda()
        self.summarizer = Summarizer(input_size=self.config.hidden_size,
                                     hidden_size=self.config.hidden_size,
                                     num_layers=self.config.num_layers).cuda()
        self.discriminator = Discriminator(
            input_size=self.config.hidden_size,
            hidden_size=self.config.hidden_size,
            num_layers=self.config.num_layers).cuda()
        self.actor = Actor(state_size=self.config.action_state_size,
                           action_size=self.config.action_state_size).cuda()
        self.critic = Critic(state_size=self.config.action_state_size,
                             action_size=self.config.action_state_size).cuda()
        self.model = nn.ModuleList([
            self.linear_compress, self.summarizer, self.discriminator,
            self.actor, self.critic
        ])

        if self.config.mode == 'train':
            # Build Optimizers
            self.e_optimizer = optim.Adam(
                self.summarizer.vae.e_lstm.parameters(), lr=self.config.lr)
            self.d_optimizer = optim.Adam(
                self.summarizer.vae.d_lstm.parameters(), lr=self.config.lr)
            self.c_optimizer = optim.Adam(
                list(self.discriminator.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.discriminator_lr)
            self.optimizerA_s = optim.Adam(
                list(self.actor.parameters()) +
                list(self.summarizer.s_lstm.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.lr)
            self.optimizerC = optim.Adam(self.critic.parameters(),
                                         lr=self.config.lr)

            self.writer = TensorboardWriter(str(self.config.log_dir))

    def reconstruction_loss(self, h_origin, h_sum):
        """L2 loss between original-regenerated features at cLSTM's last hidden layer"""

        return torch.norm(h_origin - h_sum, p=2)

    def prior_loss(self, mu, log_variance):
        """KL( q(e|x) || N(0,1) )"""
        return 0.5 * torch.sum(-1 + log_variance.exp() + mu.pow(2) -
                               log_variance)

    def sparsity_loss(self, scores):
        """Summary-Length Regularization"""

        return torch.abs(
            torch.mean(scores) - self.config.regularization_factor)

    criterion = nn.MSELoss()

    def AC(self, original_features, seq_len, action_fragments):
        """ Function that makes the actor's actions, in the training steps where the actor and critic components are not trained"""
        scores = self.summarizer.s_lstm(original_features)  # [seq_len, 1]

        fragment_scores = np.zeros(
            self.config.action_state_size)  # [num_fragments, 1]
        for fragment in range(self.config.action_state_size):
            fragment_scores[fragment] = scores[action_fragments[
                fragment, 0]:action_fragments[fragment, 1] + 1].mean()
        state = fragment_scores

        previous_actions = [
        ]  # save all the actions (the selected fragments of each episode)
        reduction_factor = (
            self.config.action_state_size -
            self.config.termination_point) / self.config.action_state_size
        action_scores = (torch.ones(seq_len) * reduction_factor).cuda()
        action_fragment_scores = (torch.ones(
            self.config.action_state_size)).cuda()

        counter = 0
        for ACstep in range(self.config.termination_point):

            state = torch.FloatTensor(state).cuda()
            # select an action
            dist = self.actor(state)
            action = dist.sample(
            )  # returns a scalar between 0-action_state_size

            if action not in previous_actions:
                previous_actions.append(action)
                action_factor = (self.config.termination_point - counter) / (
                    self.config.action_state_size - counter) + 1

                action_scores[action_fragments[action,
                                               0]:action_fragments[action, 1] +
                              1] = action_factor
                action_fragment_scores[action] = 0

                counter = counter + 1

            next_state = state * action_fragment_scores
            next_state = next_state.cpu().detach().numpy()
            state = next_state

        weighted_scores = action_scores.unsqueeze(1) * scores
        weighted_features = weighted_scores.view(-1, 1, 1) * original_features

        return weighted_features, weighted_scores

    def train(self):

        step = 0
        for epoch_i in trange(self.config.n_epochs, desc='Epoch', ncols=80):
            self.model.train()
            recon_loss_init_history = []
            recon_loss_history = []
            sparsity_loss_history = []
            prior_loss_history = []
            g_loss_history = []
            e_loss_history = []
            d_loss_history = []
            c_original_loss_history = []
            c_summary_loss_history = []
            actor_loss_history = []
            critic_loss_history = []
            reward_history = []

            # Train in batches of as many videos as the batch_size
            num_batches = int(len(self.train_loader) / self.config.batch_size)
            iterator = iter(self.train_loader)
            for batch in range(num_batches):
                list_image_features = []
                list_action_fragments = []

                print(f'batch: {batch}')

                # ---- Train eLSTM ----#
                if self.config.verbose:
                    tqdm.write('Training eLSTM...')
                self.e_optimizer.zero_grad()
                for video in range(self.config.batch_size):
                    image_features, action_fragments = next(iterator)

                    action_fragments = action_fragments.squeeze(0)
                    # [batch_size, seq_len, input_size]
                    # [seq_len, input_size]
                    image_features = image_features.view(
                        -1, self.config.input_size)

                    list_image_features.append(image_features)
                    list_action_fragments.append(action_fragments)

                    # [seq_len, input_size]
                    image_features_ = Variable(image_features).cuda()
                    seq_len = image_features_.shape[0]

                    # [seq_len, 1, hidden_size]
                    original_features = self.linear_compress(
                        image_features_.detach()).unsqueeze(1)

                    weighted_features, scores = self.AC(
                        original_features, seq_len, action_fragments)
                    h_mu, h_log_variance, generated_features = self.summarizer.vae(
                        weighted_features)

                    h_origin, original_prob = self.discriminator(
                        original_features)
                    h_sum, sum_prob = self.discriminator(generated_features)

                    if self.config.verbose:
                        tqdm.write(
                            f'original_p: {original_prob.item():.3f}, summary_p: {sum_prob.item():.3f}'
                        )

                    reconstruction_loss = self.reconstruction_loss(
                        h_origin, h_sum)
                    prior_loss = self.prior_loss(h_mu, h_log_variance)

                    tqdm.write(
                        f'recon loss {reconstruction_loss.item():.3f}, prior loss: {prior_loss.item():.3f}'
                    )

                    e_loss = reconstruction_loss + prior_loss
                    e_loss = e_loss / self.config.batch_size
                    e_loss.backward()

                    prior_loss_history.append(prior_loss.data)
                    e_loss_history.append(e_loss.data)

                # Update e_lstm parameters every 'batch_size' iterations
                torch.nn.utils.clip_grad_norm_(
                    self.summarizer.vae.e_lstm.parameters(), self.config.clip)
                self.e_optimizer.step()

                #---- Train dLSTM (decoder/generator) ----#
                if self.config.verbose:
                    tqdm.write('Training dLSTM...')
                self.d_optimizer.zero_grad()
                for video in range(self.config.batch_size):
                    image_features = list_image_features[video]
                    action_fragments = list_action_fragments[video]

                    # [seq_len, input_size]
                    image_features_ = Variable(image_features).cuda()
                    seq_len = image_features_.shape[0]

                    # [seq_len, 1, hidden_size]
                    original_features = self.linear_compress(
                        image_features_.detach()).unsqueeze(1)

                    weighted_features, _ = self.AC(original_features, seq_len,
                                                   action_fragments)
                    h_mu, h_log_variance, generated_features = self.summarizer.vae(
                        weighted_features)

                    h_origin, original_prob = self.discriminator(
                        original_features)
                    h_sum, sum_prob = self.discriminator(generated_features)

                    tqdm.write(
                        f'original_p: {original_prob.item():.3f}, summary_p: {sum_prob.item():.3f}'
                    )

                    reconstruction_loss = self.reconstruction_loss(
                        h_origin, h_sum)
                    g_loss = self.criterion(sum_prob, original_label)

                    orig_features = original_features.squeeze(
                        1)  # [seq_len, hidden_size]
                    gen_features = generated_features.squeeze(1)  #         >>
                    recon_losses = []
                    for frame_index in range(seq_len):
                        recon_losses.append(
                            self.reconstruction_loss(
                                orig_features[frame_index, :],
                                gen_features[frame_index, :]))
                    reconstruction_loss_init = torch.stack(recon_losses).mean()

                    if self.config.verbose:
                        tqdm.write(
                            f'recon loss {reconstruction_loss.item():.3f}, g loss: {g_loss.item():.3f}'
                        )

                    d_loss = reconstruction_loss + g_loss
                    d_loss = d_loss / self.config.batch_size
                    d_loss.backward()

                    recon_loss_init_history.append(
                        reconstruction_loss_init.data)
                    recon_loss_history.append(reconstruction_loss.data)
                    g_loss_history.append(g_loss.data)
                    d_loss_history.append(d_loss.data)

                # Update d_lstm parameters every 'batch_size' iterations
                torch.nn.utils.clip_grad_norm_(
                    self.summarizer.vae.d_lstm.parameters(), self.config.clip)
                self.d_optimizer.step()

                #---- Train cLSTM ----#
                if self.config.verbose:
                    tqdm.write('Training cLSTM...')
                self.c_optimizer.zero_grad()
                for video in range(self.config.batch_size):
                    image_features = list_image_features[video]
                    action_fragments = list_action_fragments[video]

                    # [seq_len, input_size]
                    image_features_ = Variable(image_features).cuda()
                    seq_len = image_features_.shape[0]

                    # Train with original loss
                    # [seq_len, 1, hidden_size]
                    original_features = self.linear_compress(
                        image_features_.detach()).unsqueeze(1)
                    h_origin, original_prob = self.discriminator(
                        original_features)
                    c_original_loss = self.criterion(original_prob,
                                                     original_label)
                    c_original_loss = c_original_loss / self.config.batch_size
                    c_original_loss.backward()

                    # Train with summary loss
                    weighted_features, _ = self.AC(original_features, seq_len,
                                                   action_fragments)
                    h_mu, h_log_variance, generated_features = self.summarizer.vae(
                        weighted_features)
                    h_sum, sum_prob = self.discriminator(
                        generated_features.detach())
                    c_summary_loss = self.criterion(sum_prob, summary_label)
                    c_summary_loss = c_summary_loss / self.config.batch_size
                    c_summary_loss.backward()

                    tqdm.write(
                        f'original_p: {original_prob.item():.3f}, summary_p: {sum_prob.item():.3f}'
                    )

                    c_original_loss_history.append(c_original_loss.data)
                    c_summary_loss_history.append(c_summary_loss.data)

                # Update c_lstm parameters every 'batch_size' iterations
                torch.nn.utils.clip_grad_norm_(
                    list(self.discriminator.parameters()) +
                    list(self.linear_compress.parameters()), self.config.clip)
                self.c_optimizer.step()

                #---- Train sLSTM and actor-critic ----#
                if self.config.verbose:
                    tqdm.write('Training sLSTM, actor and critic...')
                self.optimizerA_s.zero_grad()
                self.optimizerC.zero_grad()
                for video in range(self.config.batch_size):
                    image_features = list_image_features[video]
                    action_fragments = list_action_fragments[video]

                    # [seq_len, input_size]
                    image_features_ = Variable(image_features).cuda()
                    seq_len = image_features_.shape[0]

                    # [seq_len, 1, hidden_size]
                    original_features = self.linear_compress(
                        image_features_.detach()).unsqueeze(1)
                    scores = self.summarizer.s_lstm(
                        original_features)  # [seq_len, 1]

                    fragment_scores = np.zeros(
                        self.config.action_state_size)  # [num_fragments, 1]
                    for fragment in range(self.config.action_state_size):
                        fragment_scores[fragment] = scores[action_fragments[
                            fragment,
                            0]:action_fragments[fragment, 1] + 1].mean()

                    state = fragment_scores  # [action_state_size, 1]

                    previous_actions = [
                    ]  # save all the actions (the selected fragments of each step)
                    reduction_factor = (self.config.action_state_size -
                                        self.config.termination_point
                                        ) / self.config.action_state_size
                    action_scores = (torch.ones(seq_len) *
                                     reduction_factor).cuda()
                    action_fragment_scores = (torch.ones(
                        self.config.action_state_size)).cuda()

                    log_probs = []
                    values = []
                    rewards = []
                    masks = []
                    entropy = 0

                    counter = 0
                    for ACstep in range(self.config.termination_point):
                        # select an action, get a value for the current state
                        state = torch.FloatTensor(
                            state).cuda()  # [action_state_size, 1]
                        dist, value = self.actor(state), self.critic(state)
                        action = dist.sample(
                        )  # returns a scalar between 0-action_state_size

                        if action in previous_actions:

                            reward = 0

                        else:

                            previous_actions.append(action)
                            action_factor = (
                                self.config.termination_point - counter
                            ) / (self.config.action_state_size - counter) + 1

                            action_scores[action_fragments[
                                action, 0]:action_fragments[action, 1] +
                                          1] = action_factor
                            action_fragment_scores[action] = 0

                            weighted_scores = action_scores.unsqueeze(
                                1) * scores
                            weighted_features = weighted_scores.view(
                                -1, 1, 1) * original_features

                            h_mu, h_log_variance, generated_features = self.summarizer.vae(
                                weighted_features)

                            h_origin, original_prob = self.discriminator(
                                original_features)
                            h_sum, sum_prob = self.discriminator(
                                generated_features)

                            tqdm.write(
                                f'original_p: {original_prob.item():.3f}, summary_p: {sum_prob.item():.3f}'
                            )

                            rec_loss = self.reconstruction_loss(
                                h_origin, h_sum)
                            reward = 1 - rec_loss.item(
                            )  # the less the distance, the higher the reward
                            counter = counter + 1

                        next_state = state * action_fragment_scores
                        next_state = next_state.cpu().detach().numpy()

                        log_prob = dist.log_prob(action).unsqueeze(0)
                        entropy += dist.entropy().mean()

                        log_probs.append(log_prob)
                        values.append(value)
                        rewards.append(
                            torch.tensor([reward],
                                         dtype=torch.float,
                                         device=device))

                        if ACstep == self.config.termination_point - 1:
                            masks.append(
                                torch.tensor([0],
                                             dtype=torch.float,
                                             device=device))
                        else:
                            masks.append(
                                torch.tensor([1],
                                             dtype=torch.float,
                                             device=device))

                        state = next_state

                    next_state = torch.FloatTensor(next_state).to(device)
                    next_value = self.critic(next_state)
                    returns = compute_returns(next_value, rewards, masks)

                    log_probs = torch.cat(log_probs)
                    returns = torch.cat(returns).detach()
                    values = torch.cat(values)

                    advantage = returns - values

                    actor_loss = -((log_probs * advantage.detach()).mean() +
                                   (self.config.entropy_coef /
                                    self.config.termination_point) * entropy)
                    sparsity_loss = self.sparsity_loss(scores)
                    critic_loss = advantage.pow(2).mean()

                    actor_loss = actor_loss / self.config.batch_size
                    sparsity_loss = sparsity_loss / self.config.batch_size
                    critic_loss = critic_loss / self.config.batch_size
                    actor_loss.backward()
                    sparsity_loss.backward()
                    critic_loss.backward()

                    reward_mean = torch.mean(torch.stack(rewards))
                    reward_history.append(reward_mean)
                    actor_loss_history.append(actor_loss)
                    sparsity_loss_history.append(sparsity_loss)
                    critic_loss_history.append(critic_loss)

                    if self.config.verbose:
                        tqdm.write('Plotting...')

                    self.writer.update_loss(original_prob.data, step,
                                            'original_prob')
                    self.writer.update_loss(sum_prob.data, step, 'sum_prob')

                    step += 1

                # Update s_lstm, actor and critic parameters every 'batch_size' iterations
                torch.nn.utils.clip_grad_norm_(
                    list(self.actor.parameters()) +
                    list(self.linear_compress.parameters()) +
                    list(self.summarizer.s_lstm.parameters()) +
                    list(self.critic.parameters()), self.config.clip)
                self.optimizerA_s.step()
                self.optimizerC.step()

            recon_loss_init = torch.stack(recon_loss_init_history).mean()
            recon_loss = torch.stack(recon_loss_history).mean()
            prior_loss = torch.stack(prior_loss_history).mean()
            g_loss = torch.stack(g_loss_history).mean()
            e_loss = torch.stack(e_loss_history).mean()
            d_loss = torch.stack(d_loss_history).mean()
            c_original_loss = torch.stack(c_original_loss_history).mean()
            c_summary_loss = torch.stack(c_summary_loss_history).mean()
            sparsity_loss = torch.stack(sparsity_loss_history).mean()
            actor_loss = torch.stack(actor_loss_history).mean()
            critic_loss = torch.stack(critic_loss_history).mean()
            reward = torch.mean(torch.stack(reward_history))

            # Plot
            if self.config.verbose:
                tqdm.write('Plotting...')
            self.writer.update_loss(recon_loss_init, epoch_i,
                                    'recon_loss_init_epoch')
            self.writer.update_loss(recon_loss, epoch_i, 'recon_loss_epoch')
            self.writer.update_loss(prior_loss, epoch_i, 'prior_loss_epoch')
            self.writer.update_loss(g_loss, epoch_i, 'g_loss_epoch')
            self.writer.update_loss(e_loss, epoch_i, 'e_loss_epoch')
            self.writer.update_loss(d_loss, epoch_i, 'd_loss_epoch')
            self.writer.update_loss(c_original_loss, epoch_i,
                                    'c_original_loss_epoch')
            self.writer.update_loss(c_summary_loss, epoch_i,
                                    'c_summary_loss_epoch')
            self.writer.update_loss(sparsity_loss, epoch_i,
                                    'sparsity_loss_epoch')
            self.writer.update_loss(actor_loss, epoch_i, 'actor_loss_epoch')
            self.writer.update_loss(critic_loss, epoch_i, 'critic_loss_epoch')
            self.writer.update_loss(reward, epoch_i, 'reward_epoch')

            # Save parameters at checkpoint
            ckpt_path = str(self.config.save_dir) + f'/epoch-{epoch_i}.pkl'
            if self.config.verbose:
                tqdm.write(f'Save parameters at {ckpt_path}')
            torch.save(self.model.state_dict(), ckpt_path)

            self.evaluate(epoch_i)

    def evaluate(self, epoch_i):
        """Score every test video with ``self.AC`` and save the scores as JSON.

        For each ``(image_features, video_name, action_fragments)`` triple
        yielded by ``self.test_loader`` the features are flattened to
        ``[seq_len, input_size]``, compressed with ``self.linear_compress`` and
        passed to ``self.AC`` together with ``action_fragments``. The scores it
        returns are collected into ``{video_name: scores}``, which is written
        once to ``<score_dir>/<video_type>_<epoch_i>.json``.

        Args:
            epoch_i: epoch index; used only to name the output file.

        Note:
            The model is switched to eval mode and is not switched back here.
        """
        self.model.eval()

        out_dict = {}

        for image_features, video_name, action_fragments in tqdm(
                self.test_loader, desc='Evaluate', ncols=80, leave=False):
            # Flatten to [seq_len, input_size].
            image_features = image_features.view(-1, self.config.input_size)
            image_features_ = image_features.cuda()

            # Inference only: the feature compression also belongs inside
            # no_grad, otherwise an autograd graph is built for every video.
            with torch.no_grad():
                # [seq_len, 1, hidden_size]
                original_features = self.linear_compress(
                    image_features_).unsqueeze(1)
                seq_len = original_features.shape[0]

                _, scores = self.AC(original_features, seq_len,
                                    action_fragments)

            scores = scores.squeeze(1)
            out_dict[video_name] = scores.cpu().numpy().tolist()

        # Write the collected scores once, after every video is processed,
        # instead of rewriting the growing dict after each video.
        score_save_path = self.config.score_dir.joinpath(
            f'{self.config.video_type}_{epoch_i}.json')
        with open(score_save_path, 'w') as f:
            if self.config.verbose:
                tqdm.write(f'Saving score at {str(score_save_path)}.')
            json.dump(out_dict, f)
        score_save_path.chmod(0o777)
Exemple #5
0
class Solver(object):
    """Base solver that builds a model from ``config`` and manages checkpoints.

    ``train``, ``evaluate`` and ``test`` raise ``NotImplementedError`` here;
    ``write_summary`` logs ``self.epoch_loss`` (when set) and then raises
    ``NotImplementedError`` as well.
    """

    def __init__(self,
                 config,
                 train_data_loader,
                 eval_data_loader,
                 is_train=True,
                 model=None):
        """Store the configuration, data loaders and an optional prebuilt model.

        Args:
            config: configuration object. ``build`` reads ``model``,
                ``checkpoint``, ``logdir``, ``optimizer`` and
                ``learning_rate``; ``save_model`` reads ``save_path``.
            train_data_loader: training loader (only stored by this class).
            eval_data_loader: evaluation loader (only stored by this class).
            is_train: when true, ``build`` also creates the TensorBoard writer
                and the optimizer.
            model: optional already-constructed model; when ``None``, ``build``
                instantiates ``models.<config.model>(config)``.
        """
        self.config = config
        self.epoch_i = 0  # overwritten by load_model from the checkpoint name
        self.train_data_loader = train_data_loader
        self.eval_data_loader = eval_data_loader
        self.is_train = is_train
        self.model = model
        # Created by build() when is_train is true.
        self.writer = None
        self.optimizer = None
        # Bookkeeping slots; this base class only reads epoch_loss
        # (in write_summary).
        self.epoch_loss = None
        self.validation_loss = None
        self.true_scores = 0
        self.false_scores = 0
        self.eval_epoch_loss = 0

    def build(self, cuda=True):
        """Instantiate (if needed), place and optionally restore the model.

        Args:
            cuda: move the model to the GPU when one is available.
        """
        if self.model is None:
            self.model = getattr(models, self.config.model)(self.config)

        if torch.cuda.is_available() and cuda:
            self.model.cuda()

        if self.config.checkpoint:
            self.load_model(self.config.checkpoint)

        if self.is_train:
            self.writer = TensorboardWriter(self.config.logdir)
            # Only parameters that require gradients go to the optimizer.
            self.optimizer = self.config.optimizer(
                filter(lambda p: p.requires_grad, self.model.parameters()),
                lr=self.config.learning_rate)

    def save_model(self, epoch):
        """Save the model's state_dict to ``<config.save_path>/<epoch>.pkl``."""
        ckpt_path = os.path.join(self.config.save_path, f'{epoch}.pkl')
        print(f'Save parameters to {ckpt_path}')
        torch.save(self.model.state_dict(), ckpt_path)

    def load_model(self, checkpoint):
        """Load weights from ``checkpoint`` and restore ``self.epoch_i``.

        The epoch number is parsed from the leading digits of the file name,
        matching the ``<epoch>.pkl`` names written by ``save_model``.

        Raises:
            ValueError: if the checkpoint file name does not start with digits.
        """
        print(f'Load parameters from {checkpoint}')
        match = re.match(r'[0-9]+', os.path.basename(checkpoint))
        if match is None:
            raise ValueError(
                f'Cannot infer epoch from checkpoint name {checkpoint!r}; '
                'expected a file named like "<epoch>.pkl"')
        self.epoch_i = int(match.group(0))
        # Deserialize onto the CPU: load_state_dict then copies each tensor
        # onto the device the model's parameters already live on, so a
        # GPU-saved checkpoint can also be restored without a GPU.
        state_dict = torch.load(checkpoint, map_location='cpu')
        self.model.load_state_dict(state_dict)

    def write_summary(self, epoch_i):
        """Log ``self.epoch_loss`` as ``train_loss`` at step ``epoch_i + 1``.

        Always raises ``NotImplementedError`` afterwards; presumably subclasses
        override this with their full summary logic (TODO confirm).
        """
        epoch_loss = getattr(self, 'epoch_loss', None)
        if epoch_loss is not None:
            self.writer.update_loss(loss=epoch_loss,
                                    step_i=epoch_i + 1,
                                    name='train_loss')

        raise NotImplementedError

    def train(self):
        """Training loop; must be implemented by a subclass."""
        raise NotImplementedError

    def evaluate(self):
        """Evaluation loop; must be implemented by a subclass."""
        raise NotImplementedError

    def test(self):
        """Test loop; must be implemented by a subclass."""
        raise NotImplementedError
Exemple #6
0
class Solver(object):
    """Builds, trains and evaluates the SUM-GAN-sl model (GPU only)."""

    def __init__(self, config=None, train_loader=None, test_loader=None):
        """Class that Builds, Trains and Evaluates SUM-GAN-sl model"""
        self.config = config
        self.train_loader = train_loader
        self.test_loader = test_loader

    def build(self):
        """Create the modules on the GPU and, in train mode, optimizers/writer.

        ``linear_compress`` (input_size -> hidden_size), the ``Summarizer`` and
        the ``Discriminator`` are grouped in ``self.model`` so they are saved
        and switched between train/eval together. When
        ``config.mode == 'train'`` three Adam optimizers are created, each of
        which also updates ``linear_compress``:

        * ``s_e_optimizer``: ``summarizer.s_lstm`` + ``summarizer.vae.e_lstm``
          (lr ``config.lr``)
        * ``d_optimizer``: ``summarizer.vae.d_lstm`` (lr ``config.lr``)
        * ``c_optimizer``: the discriminator (lr ``config.discriminator_lr``)
        """
        # Build Modules
        self.linear_compress = nn.Linear(self.config.input_size,
                                         self.config.hidden_size).cuda()
        self.summarizer = Summarizer(input_size=self.config.hidden_size,
                                     hidden_size=self.config.hidden_size,
                                     num_layers=self.config.num_layers).cuda()
        self.discriminator = Discriminator(
            input_size=self.config.hidden_size,
            hidden_size=self.config.hidden_size,
            num_layers=self.config.num_layers).cuda()
        self.model = nn.ModuleList(
            [self.linear_compress, self.summarizer, self.discriminator])

        if self.config.mode == 'train':
            # Build Optimizers
            self.s_e_optimizer = optim.Adam(
                list(self.summarizer.s_lstm.parameters()) +
                list(self.summarizer.vae.e_lstm.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.lr)
            self.d_optimizer = optim.Adam(
                list(self.summarizer.vae.d_lstm.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.lr)
            self.c_optimizer = optim.Adam(
                list(self.discriminator.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.discriminator_lr)

            self.writer = TensorboardWriter(str(self.config.log_dir))

    def reconstruction_loss(self, h_origin, h_sum):
        """L2 loss between original-regenerated features at cLSTM's last hidden layer"""

        return torch.norm(h_origin - h_sum, p=2)

    def prior_loss(self, mu, log_variance):
        """KL( q(e|x) || N(0,1) )"""
        return 0.5 * torch.sum(-1 + log_variance.exp() + mu.pow(2) -
                               log_variance)

    def sparsity_loss(self, scores):
        """Summary-Length Regularization: |mean(scores) - regularization_factor|."""

        return torch.abs(
            torch.mean(scores) - self.config.regularization_factor)

    # MSE criterion used for the adversarial (generator/discriminator) losses.
    criterion = nn.MSELoss()

    def train(self):
        """Run ``config.n_epochs`` epochs of adversarial training.

        For each video in ``self.train_loader`` three updates are made, each
        on a fresh forward pass:

        1. ``s_e_optimizer`` on reconstruction + prior + sparsity loss;
        2. ``d_optimizer`` on reconstruction + MSE(sum_prob, original_label);
        3. ``c_optimizer`` on MSE(original_prob, original_label) plus
           MSE(prob of detached generated features, summary_label).

        Per-step and per-epoch values are logged to TensorBoard, the
        parameters are saved to ``<save_dir>/epoch-<i>.pkl`` and ``evaluate``
        runs after every epoch.
        """
        # NOTE(review): original_label / summary_label are not defined in this
        # class; presumably module-level target tensors elsewhere in the file
        # — confirm they exist.
        step = 0
        for epoch_i in trange(self.config.n_epochs, desc='Epoch', ncols=80):
            s_e_loss_history = []
            d_loss_history = []
            c_original_loss_history = []
            c_summary_loss_history = []
            for batch_i, image_features in enumerate(
                    tqdm(self.train_loader,
                         desc='Batch',
                         ncols=80,
                         leave=False)):

                self.model.train()

                # Flatten to [seq_len, input_size].
                image_features = image_features.view(-1,
                                                     self.config.input_size)
                image_features_ = image_features.cuda()

                #---- Train sLSTM, eLSTM ----#
                if self.config.verbose:
                    tqdm.write('\nTraining sLSTM and eLSTM...')

                # [seq_len, 1, hidden_size]
                original_features = self.linear_compress(
                    image_features_.detach()).unsqueeze(1)

                scores, h_mu, h_log_variance, generated_features = self.summarizer(
                    original_features)

                h_origin, original_prob = self.discriminator(original_features)
                h_sum, sum_prob = self.discriminator(generated_features)

                tqdm.write(
                    f'original_p: {original_prob.item():.3f}, summary_p: {sum_prob.item():.3f}'
                )

                reconstruction_loss = self.reconstruction_loss(h_origin, h_sum)
                prior_loss = self.prior_loss(h_mu, h_log_variance)
                sparsity_loss = self.sparsity_loss(scores)

                tqdm.write(
                    f'recon loss {reconstruction_loss.item():.3f}, prior loss: {prior_loss.item():.3f}, sparsity loss: {sparsity_loss.item():.3f}'
                )

                s_e_loss = reconstruction_loss + prior_loss + sparsity_loss

                self.s_e_optimizer.zero_grad()
                s_e_loss.backward()
                # Gradient clipping (clip_grad_norm without "_" is deprecated)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.config.clip)
                self.s_e_optimizer.step()

                s_e_loss_history.append(s_e_loss.data)

                #---- Train dLSTM (generator) ----#
                if self.config.verbose:
                    tqdm.write('Training dLSTM...')

                # [seq_len, 1, hidden_size]
                original_features = self.linear_compress(
                    image_features_.detach()).unsqueeze(1)

                scores, h_mu, h_log_variance, generated_features = self.summarizer(
                    original_features)

                h_origin, original_prob = self.discriminator(original_features)
                h_sum, sum_prob = self.discriminator(generated_features)

                tqdm.write(
                    f'original_p: {original_prob.item():.3f}, summary_p: {sum_prob.item():.3f}'
                )

                reconstruction_loss = self.reconstruction_loss(h_origin, h_sum)
                # Generator wants the summary to be scored like the original.
                g_loss = self.criterion(sum_prob, original_label)

                tqdm.write(
                    f'recon loss {reconstruction_loss.item():.3f}, g loss: {g_loss.item():.3f}'
                )

                d_loss = reconstruction_loss + g_loss

                self.d_optimizer.zero_grad()
                d_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.config.clip)
                self.d_optimizer.step()

                d_loss_history.append(d_loss.data)

                #---- Train cLSTM ----#
                if self.config.verbose:
                    tqdm.write('Training cLSTM...')

                self.c_optimizer.zero_grad()

                # Train with original loss
                # [seq_len, 1, hidden_size]
                original_features = self.linear_compress(
                    image_features_.detach()).unsqueeze(1)
                h_origin, original_prob = self.discriminator(original_features)
                c_original_loss = self.criterion(original_prob, original_label)
                c_original_loss.backward()

                # Train with summary loss; generated features are detached so
                # only the discriminator side receives this gradient.
                scores, h_mu, h_log_variance, generated_features = self.summarizer(
                    original_features)
                h_sum, sum_prob = self.discriminator(
                    generated_features.detach())
                c_summary_loss = self.criterion(sum_prob, summary_label)
                c_summary_loss.backward()

                tqdm.write(
                    f'original_p: {original_prob.item():.3f}, summary_p: {sum_prob.item():.3f}'
                )
                # Re-reports the generator loss computed in the dLSTM phase.
                tqdm.write(f'gen loss: {g_loss.item():.3f}')

                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.config.clip)
                self.c_optimizer.step()

                c_original_loss_history.append(c_original_loss.data)
                c_summary_loss_history.append(c_summary_loss.data)

                if self.config.verbose:
                    tqdm.write('Plotting...')

                self.writer.update_loss(reconstruction_loss.data, step,
                                        'recon_loss')
                self.writer.update_loss(prior_loss.data, step, 'prior_loss')
                self.writer.update_loss(sparsity_loss.data, step,
                                        'sparsity_loss')
                self.writer.update_loss(g_loss.data, step, 'gen_loss')

                self.writer.update_loss(original_prob.data, step,
                                        'original_prob')
                self.writer.update_loss(sum_prob.data, step, 'sum_prob')

                step += 1

            s_e_loss = torch.stack(s_e_loss_history).mean()
            d_loss = torch.stack(d_loss_history).mean()
            c_original_loss = torch.stack(c_original_loss_history).mean()
            c_summary_loss = torch.stack(c_summary_loss_history).mean()

            # Plot epoch averages against the epoch index, like the other
            # *_epoch series (these were previously logged against `step`).
            if self.config.verbose:
                tqdm.write('Plotting...')
            self.writer.update_loss(s_e_loss, epoch_i, 's_e_loss_epoch')
            self.writer.update_loss(d_loss, epoch_i, 'd_loss_epoch')
            self.writer.update_loss(c_original_loss, epoch_i,
                                    'c_original_loss_epoch')
            self.writer.update_loss(c_summary_loss, epoch_i,
                                    'c_summary_loss_epoch')

            # Save parameters at checkpoint
            ckpt_path = str(self.config.save_dir) + f'/epoch-{epoch_i}.pkl'
            tqdm.write(f'Save parameters at {ckpt_path}')
            torch.save(self.model.state_dict(), ckpt_path)

            self.evaluate(epoch_i)

    def evaluate(self, epoch_i):
        """Score every test video with ``summarizer.s_lstm`` and save as JSON.

        The ``{video_name: scores}`` dict is written once to
        ``<score_dir>/<video_type>_<epoch_i>.json``. The model is left in eval
        mode; ``train`` switches it back at the start of every batch.
        """
        self.model.eval()

        out_dict = {}

        for video_tensor, video_name in tqdm(self.test_loader,
                                             desc='Evaluate',
                                             ncols=80,
                                             leave=False):

            # Flatten to [seq_len, input_size].
            video_tensor = video_tensor.view(-1, self.config.input_size)
            video_feature = video_tensor.cuda()

            # Inference only: keep the compression inside no_grad as well so
            # no autograd graph is built.
            with torch.no_grad():
                # [seq_len, 1, hidden_size]
                video_feature = self.linear_compress(
                    video_feature).unsqueeze(1)
                scores = self.summarizer.s_lstm(video_feature).squeeze(1)

            out_dict[video_name] = scores.cpu().numpy().tolist()

        # Write once after all videos instead of rewriting after each one.
        score_save_path = self.config.score_dir.joinpath(
            f'{self.config.video_type}_{epoch_i}.json')
        with open(score_save_path, 'w') as f:
            tqdm.write(f'Saving score at {str(score_save_path)}.')
            json.dump(out_dict, f)
        score_save_path.chmod(0o777)

    def pretrain(self):
        """Placeholder; does nothing."""
        pass
class Solver(object):
    """Builds, trains and evaluates the SUM-GAN model (GPU only)."""

    def __init__(self, config=None, train_loader=None, test_loader=None):
        """Class that Builds, Trains and Evaluates SUM-GAN model"""
        self.config = config
        self.train_loader = train_loader
        self.test_loader = test_loader

    def build(self):
        """Create the modules on the GPU and, in train mode, optimizers/writer.

        ``linear_compress`` (input_size -> hidden_size), the ``Summarizer`` and
        the ``Discriminator`` are grouped in ``self.model``. When
        ``config.mode == 'train'`` three Adam optimizers are created, each of
        which also updates ``linear_compress``:

        * ``s_e_optimizer``: ``summarizer.s_lstm`` + ``summarizer.vae.e_lstm``
          (lr ``config.lr``)
        * ``d_optimizer``: ``summarizer.vae.d_lstm`` (lr ``config.lr``)
        * ``c_optimizer``: the discriminator (lr ``config.discriminator_lr``)

        The model is then put in train mode and a TensorBoard writer is
        created for ``config.log_dir``.
        """
        # Build Modules
        self.linear_compress = nn.Linear(self.config.input_size,
                                         self.config.hidden_size).cuda()
        self.summarizer = Summarizer(input_size=self.config.hidden_size,
                                     hidden_size=self.config.hidden_size,
                                     num_layers=self.config.num_layers).cuda()
        self.discriminator = Discriminator(
            input_size=self.config.hidden_size,
            hidden_size=self.config.hidden_size,
            num_layers=self.config.num_layers).cuda()
        self.model = nn.ModuleList(
            [self.linear_compress, self.summarizer, self.discriminator])

        if self.config.mode == 'train':
            # Build Optimizers
            self.s_e_optimizer = optim.Adam(
                list(self.summarizer.s_lstm.parameters()) +
                list(self.summarizer.vae.e_lstm.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.lr)
            self.d_optimizer = optim.Adam(
                list(self.summarizer.vae.d_lstm.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.lr)
            self.c_optimizer = optim.Adam(
                list(self.discriminator.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.discriminator_lr)

            self.model.train()

            # Tensorboard
            self.writer = TensorboardWriter(self.config.log_dir)

    @staticmethod
    def freeze_model(module):
        """Disable gradients for every parameter of ``module``."""
        for p in module.parameters():
            p.requires_grad = False

    def reconstruction_loss(self, h_origin, h_fake):
        """L2 loss between original-regenerated features at cLSTM's last hidden layer"""

        return torch.norm(h_origin - h_fake, p=2)

    def prior_loss(self, mu, log_variance):
        """KL( q(e|x) || N(0,1) )"""
        return 0.5 * torch.sum(-1 + log_variance.exp() + mu.pow(2) -
                               log_variance)

    def sparsity_loss(self, scores):
        """Summary-Length Regularization: |mean(scores) - summary_rate|."""

        return torch.abs(torch.mean(scores) - self.config.summary_rate)

    def gan_loss(self, original_prob, fake_prob, uniform_prob):
        """Typical GAN loss + Classify uniformly scored features.

        Returns mean(log p_orig + log(1 - p_fake) + log(1 - p_uniform)).
        """

        gan_loss = torch.mean(
            torch.log(original_prob) + torch.log(1 - fake_prob) +
            torch.log(1 - uniform_prob))  # Discriminate uniform score

        return gan_loss

    def train(self):
        """Run ``config.n_epochs`` epochs of adversarial training.

        For each video in ``self.train_loader`` (videos with more than 10000
        entries along dim 1 are skipped):

        1. ``s_e_optimizer`` on reconstruction + prior + sparsity loss;
        2. ``d_optimizer`` on reconstruction + ``gan_loss``;
        3. once ``batch_i > config.discriminator_slow_start``,
           ``c_optimizer`` on ``-gan_loss`` (the discriminator maximizes it).

        Per-step and per-epoch values are logged to TensorBoard, parameters
        are saved to ``<save_dir>_epoch-<i>.pkl`` and ``evaluate`` runs after
        every epoch.
        """
        step = 0
        for epoch_i in trange(self.config.n_epochs, desc='Epoch', ncols=80):
            s_e_loss_history = []
            d_loss_history = []
            c_loss_history = []
            for batch_i, image_features in enumerate(
                    tqdm(self.train_loader,
                         desc='Batch',
                         ncols=80,
                         leave=False)):

                # Skip very long inputs (size along dim 1 above 10000).
                if image_features.size(1) > 10000:
                    continue

                # Flatten to [seq_len, input_size].
                image_features = image_features.view(-1,
                                                     self.config.input_size)
                image_features_ = image_features.cuda()

                #---- Train sLSTM, eLSTM ----#
                if self.config.verbose:
                    tqdm.write('\nTraining sLSTM and eLSTM...')

                # [seq_len, 1, hidden_size]
                original_features = self.linear_compress(
                    image_features_.detach()).unsqueeze(1)

                scores, h_mu, h_log_variance, generated_features = self.summarizer(
                    original_features)
                _, _, _, uniform_features = self.summarizer(original_features,
                                                            uniform=True)

                h_origin, original_prob = self.discriminator(original_features)
                h_fake, fake_prob = self.discriminator(generated_features)
                h_uniform, uniform_prob = self.discriminator(uniform_features)

                # .item() instead of .data[0]: indexing a 0-dim tensor raises.
                tqdm.write(
                    f'original_p: {original_prob.item():.3f}, fake_p: {fake_prob.item():.3f}, uniform_p: {uniform_prob.item():.3f}'
                )

                reconstruction_loss = self.reconstruction_loss(
                    h_origin, h_fake)
                prior_loss = self.prior_loss(h_mu, h_log_variance)
                sparsity_loss = self.sparsity_loss(scores)

                tqdm.write(
                    f'recon loss {reconstruction_loss.item():.3f}, prior loss: {prior_loss.item():.3f}, sparsity loss: {sparsity_loss.item():.3f}'
                )

                s_e_loss = reconstruction_loss + prior_loss + sparsity_loss

                self.s_e_optimizer.zero_grad()
                s_e_loss.backward()
                # Gradient clipping (clip_grad_norm without "_" is deprecated)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.config.clip)
                self.s_e_optimizer.step()

                s_e_loss_history.append(s_e_loss.data)

                #---- Train dLSTM ----#
                if self.config.verbose:
                    tqdm.write('Training dLSTM...')

                # [seq_len, 1, hidden_size]
                original_features = self.linear_compress(
                    image_features_.detach()).unsqueeze(1)

                scores, h_mu, h_log_variance, generated_features = self.summarizer(
                    original_features)
                _, _, _, uniform_features = self.summarizer(original_features,
                                                            uniform=True)

                h_origin, original_prob = self.discriminator(original_features)
                h_fake, fake_prob = self.discriminator(generated_features)
                h_uniform, uniform_prob = self.discriminator(uniform_features)

                tqdm.write(
                    f'original_p: {original_prob.item():.3f}, fake_p: {fake_prob.item():.3f}, uniform_p: {uniform_prob.item():.3f}'
                )

                reconstruction_loss = self.reconstruction_loss(
                    h_origin, h_fake)
                gan_loss = self.gan_loss(original_prob, fake_prob,
                                         uniform_prob)

                tqdm.write(
                    f'recon loss {reconstruction_loss.item():.3f}, gan loss: {gan_loss.item():.3f}'
                )

                d_loss = reconstruction_loss + gan_loss

                self.d_optimizer.zero_grad()
                d_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.config.clip)
                self.d_optimizer.step()

                d_loss_history.append(d_loss.data)

                #---- Train cLSTM ----#
                if batch_i > self.config.discriminator_slow_start:
                    if self.config.verbose:
                        tqdm.write('Training cLSTM...')
                    # [seq_len, 1, hidden_size]
                    original_features = self.linear_compress(
                        image_features_.detach()).unsqueeze(1)

                    scores, h_mu, h_log_variance, generated_features = self.summarizer(
                        original_features)
                    _, _, _, uniform_features = self.summarizer(
                        original_features, uniform=True)

                    h_origin, original_prob = self.discriminator(
                        original_features)
                    h_fake, fake_prob = self.discriminator(generated_features)
                    h_uniform, uniform_prob = self.discriminator(
                        uniform_features)
                    tqdm.write(
                        f'original_p: {original_prob.item():.3f}, fake_p: {fake_prob.item():.3f}, uniform_p: {uniform_prob.item():.3f}'
                    )

                    # Maximization
                    c_loss = -1 * self.gan_loss(original_prob, fake_prob,
                                                uniform_prob)

                    # Report this phase's GAN loss (previously the stale value
                    # from the dLSTM phase was printed).
                    tqdm.write(f'gan loss: {-c_loss.item():.3f}')

                    self.c_optimizer.zero_grad()
                    c_loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   self.config.clip)
                    self.c_optimizer.step()

                    c_loss_history.append(c_loss.data)

                if self.config.verbose:
                    tqdm.write('Plotting...')

                self.writer.update_loss(reconstruction_loss.data, step,
                                        'recon_loss')
                self.writer.update_loss(prior_loss.data, step, 'prior_loss')
                self.writer.update_loss(sparsity_loss.data, step,
                                        'sparsity_loss')
                self.writer.update_loss(gan_loss.data, step, 'gan_loss')

                self.writer.update_loss(original_prob.data, step,
                                        'original_prob')
                self.writer.update_loss(fake_prob.data, step, 'fake_prob')
                self.writer.update_loss(uniform_prob.data, step,
                                        'uniform_prob')

                step += 1

            # Plot epoch averages. A history can be empty (c_loss_history when
            # no batch passed discriminator_slow_start, or any of them when
            # every video was skipped) and torch.stack([]) raises, so skip it.
            if self.config.verbose:
                tqdm.write('Plotting...')
            for name, history in (('s_e_loss_epoch', s_e_loss_history),
                                  ('d_loss_epoch', d_loss_history),
                                  ('c_loss_epoch', c_loss_history)):
                if history:
                    self.writer.update_loss(
                        torch.stack(history).mean(), epoch_i, name)

            # Save parameters at checkpoint
            ckpt_path = str(self.config.save_dir) + f'_epoch-{epoch_i}.pkl'
            tqdm.write(f'Save parameters at {ckpt_path}')
            torch.save(self.model.state_dict(), ckpt_path)

            self.evaluate(epoch_i)

            self.model.train()

    def evaluate(self, epoch_i):
        """Score every test video with ``summarizer.s_lstm`` and save as JSON.

        The ``{video_name: scores}`` dict is written once to
        ``<score_dir>/<video_type>_<epoch_i>.json``. Leaves the model in eval
        mode; ``train`` switches it back afterwards.
        """
        self.model.eval()

        out_dict = {}

        for video_tensor, video_name in tqdm(self.test_loader,
                                             desc='Evaluate',
                                             ncols=80,
                                             leave=False):

            # Flatten to [seq_len, input_size].
            video_tensor = video_tensor.view(-1, self.config.input_size)
            video_feature = video_tensor.cuda()

            # `Variable(..., volatile=True)` no longer disables autograd;
            # torch.no_grad() is the supported replacement.
            with torch.no_grad():
                # [seq_len, 1, hidden_size]
                video_feature = self.linear_compress(
                    video_feature).unsqueeze(1)
                scores = self.summarizer.s_lstm(video_feature).squeeze(1)

            # Move to CPU first: NumPy cannot read CUDA tensors directly.
            out_dict[video_name] = scores.cpu().numpy().tolist()

        # Write once after all videos instead of rewriting after each one.
        score_save_path = self.config.score_dir.joinpath(
            f'{self.config.video_type}_{epoch_i}.json')
        with open(score_save_path, 'w') as f:
            tqdm.write(f'Saving score at {str(score_save_path)}.')
            json.dump(out_dict, f)
        score_save_path.chmod(0o777)

    def pretrain(self):
        """Placeholder; does nothing."""
        pass
Exemple #8
0
class Solver(object):
    """Base training harness: builds model/optimizer and manages checkpoints.

    Subclasses are expected to implement ``train``, ``evaluate`` and ``test``.
    """

    def __init__(self,
                 config,
                 train_data_loader,
                 eval_data_loader,
                 is_train=True,
                 model=None):
        self.config = config
        # Current epoch; ``load_model`` overwrites it from the checkpoint name.
        self.epoch_i = 0
        self.train_data_loader = train_data_loader
        self.eval_data_loader = eval_data_loader
        self.is_train = is_train
        # Optional pre-built model; when None, ``build`` creates
        # ``models.<config.model>(config)``.
        self.model = model

    @time_desc_decorator('Build Graph')
    def build(self, cuda=True):
        """Create the model if needed, move it to GPU, and set up training.

        Args:
            cuda: Move the model to GPU when CUDA is available.
        """
        if self.model is None:
            self.model = getattr(models, self.config.model)(self.config)

            # Custom init only for fresh training runs, not when resuming.
            if self.config.mode == 'train' and self.config.checkpoint is None:
                self._init_recurrent_params()

        if torch.cuda.is_available() and cuda:
            self.model.cuda()

        print('Model Parameters')
        for name, param in self.model.named_parameters():
            print('\t' + name + '\t', list(param.size()))

        if self.config.checkpoint:
            self.load_model(self.config.checkpoint)

        if self.is_train:
            self.writer = TensorboardWriter(self.config.logdir)
            self.optimizer = self._build_optimizer()

    def _init_recurrent_params(self):
        """Orthogonally init ``weight_hh`` and set one ``bias_hh`` slice to 2.0."""
        print('Parameter initiailization')
        for name, param in self.model.named_parameters():
            if 'weight_hh' in name:
                print('\t' + name)
                nn.init.orthogonal_(param)

            if 'bias_hh' in name:
                print('\t' + name)
                # Splits bias_hh into three equal chunks and fills the middle
                # one. NOTE(review): this matches a GRU layout (3 * hidden);
                # an LSTM bias (4 * hidden) would not be split per gate here.
                dim = param.size(0) // 3
                param.data[dim:2 * dim].fill_(2.0)

    def _build_optimizer(self):
        """Return AdamW with weight-decay groups, or ``config.optimizer``."""
        if self.config.optimizer is None:
            # AdamW: exclude biases and LayerNorm weights from weight decay.
            no_decay = ['bias', 'LayerNorm.weight']
            decay_params = []
            no_decay_params = []
            for param_name, param in self.model.named_parameters():
                if any(nd in param_name for nd in no_decay):
                    no_decay_params.append(param)
                else:
                    decay_params.append(param)
            optimizer_grouped_parameters = [
                {'params': decay_params, 'weight_decay': 0.01},
                {'params': no_decay_params, 'weight_decay': 0.0},
            ]
            return AdamW(optimizer_grouped_parameters,
                         lr=self.config.learning_rate)
        return self.config.optimizer(
            filter(lambda p: p.requires_grad, self.model.parameters()),
            lr=self.config.learning_rate)

    def save_model(self, epoch):
        """Save the model's state_dict to ``<config.save_path>/<epoch>.pkl``."""
        # Create the target directory on first save so torch.save cannot fail
        # on a missing path (an empty save_path means the current directory).
        if self.config.save_path:
            os.makedirs(self.config.save_path, exist_ok=True)
        ckpt_path = os.path.join(self.config.save_path, f'{epoch}.pkl')
        print(f'Save parameters to {ckpt_path}')
        torch.save(self.model.state_dict(), ckpt_path)

    def load_model(self, checkpoint):
        """Load a state_dict from ``checkpoint`` and restore ``epoch_i``.

        The epoch is read from the leading digits of the file name, which is
        the format ``save_model`` writes (``<epoch>.pkl``). If the name has no
        leading digits, ``epoch_i`` is left unchanged instead of raising.
        """
        print(f'Load parameters from {checkpoint}')
        # '+' (not '*') so a name without leading digits yields no match
        # instead of an empty string that makes int() raise ValueError.
        match = re.match(r"[0-9]+", os.path.basename(checkpoint))
        if match is not None:
            self.epoch_i = int(match.group(0))
        # Load onto CPU so GPU-saved checkpoints also work on CPU-only hosts;
        # load_state_dict copies each tensor onto its parameter's device.
        state_dict = torch.load(checkpoint, map_location='cpu')
        self.model.load_state_dict(state_dict)

    def write_summary(self, epoch_i):
        """Log ``train_acc`` / ``validation_acc`` (when set) at step epoch_i + 1."""
        for metric_name in ('train_acc', 'validation_acc'):
            value = getattr(self, metric_name, None)
            if value is not None:
                self.writer.update_loss(loss=value,
                                        step_i=epoch_i + 1,
                                        name=metric_name)

    def train(self):
        """Run training; must be implemented by subclasses."""
        raise NotImplementedError

    def evaluate(self):
        """Run evaluation; must be implemented by subclasses."""
        raise NotImplementedError

    def test(self, is_print=True):
        """Run testing; must be implemented by subclasses."""
        raise NotImplementedError

    def _calc_accuracy(self, x, y):
        """Return the fraction of rows where ``argmax(x, dim=1) == y``.

        Assumes ``x`` holds per-class scores with classes on dim 1 and ``y``
        holds integer labels of matching batch size -- TODO confirm callers.
        An empty batch returns NaN (as the original numpy division did).
        """
        _, predictions = torch.max(x, 1)
        num_samples = predictions.size(0)
        if num_samples == 0:
            return float('nan')
        num_correct = (predictions == y).sum().item()
        return num_correct / num_samples