Exemplo n.º 1
0
    def build(self):
        """Instantiate the networks and, in 'train' mode, their optimizers.

        Creates ``linear_compress`` (input_size -> hidden_size), the
        ``Summarizer`` and the ``Discriminator`` on the GPU and groups them in
        ``self.model``.  When ``config.mode == 'train'`` it additionally
        creates three Adam optimizers - each of which also updates
        ``linear_compress`` - and a ``TensorboardWriter`` for
        ``config.log_dir``.
        """
        cfg = self.config
        hidden = cfg.hidden_size

        # Keep the construction order (compress, summarizer, discriminator)
        # so any random weight initialisation happens in the same sequence.
        self.linear_compress = nn.Linear(cfg.input_size, hidden).cuda()
        self.summarizer = Summarizer(input_size=hidden,
                                     hidden_size=hidden,
                                     num_layers=cfg.num_layers).cuda()
        self.discriminator = Discriminator(input_size=hidden,
                                           hidden_size=hidden,
                                           num_layers=cfg.num_layers).cuda()
        self.model = nn.ModuleList(
            [self.linear_compress, self.summarizer, self.discriminator])

        if cfg.mode != 'train':
            return

        compress_params = [*self.linear_compress.parameters()]
        vae = self.summarizer.vae

        # sLSTM + eLSTM (+ compression layer)
        self.s_e_optimizer = optim.Adam(
            [*self.summarizer.s_lstm.parameters(),
             *vae.e_lstm.parameters(),
             *compress_params],
            lr=cfg.lr)
        # dLSTM (+ compression layer)
        self.d_optimizer = optim.Adam(
            [*vae.d_lstm.parameters(), *compress_params],
            lr=cfg.lr)
        # discriminator (+ compression layer), with its own learning rate
        self.c_optimizer = optim.Adam(
            [*self.discriminator.parameters(), *compress_params],
            lr=cfg.discriminator_lr)

        self.writer = TensorboardWriter(str(cfg.log_dir))
Exemplo n.º 2
0
    def build(self):
        """Print device information, then build the networks and optimizers.

        cuDNN is disabled globally.  ``linear_compress``, the ``Summarizer``
        and the ``Discriminator`` are created with ``.cuda()`` (so a GPU is
        still required even though ``device`` falls back to CPU for the
        printed report) and grouped in ``self.model``.  In 'train' mode three
        Adam optimizers are created (each also updating ``linear_compress``),
        the model is printed and switched to training mode, and a
        ``TensorboardWriter`` for ``config.log_dir`` is created.
        """
        # Added by the author: disable cuDNN globally.
        torch.backends.cudnn.enabled = False

        # Added by the author: GPU information.
        USE_CUDA = torch.cuda.is_available()
        print(USE_CUDA)
        device = torch.device('cuda:0' if USE_CUDA else 'cpu')
        print('학습을 진행하는 기기:', device)
        # These CUDA queries raise on a machine without a usable GPU, so only
        # run them when CUDA is actually available.
        if USE_CUDA:
            print('cuda index:', torch.cuda.current_device())
            print('gpu 개수:', torch.cuda.device_count())
            print('graphic name:', torch.cuda.get_device_name())
        # setting device on GPU if available, else CPU
        print('Using device:', device)

        # Build Modules
        self.linear_compress = nn.Linear(self.config.input_size,
                                         self.config.hidden_size).cuda()
        self.summarizer = Summarizer(input_size=self.config.hidden_size,
                                     hidden_size=self.config.hidden_size,
                                     num_layers=self.config.num_layers).cuda()
        self.discriminator = Discriminator(
            input_size=self.config.hidden_size,
            hidden_size=self.config.hidden_size,
            num_layers=self.config.num_layers).cuda()
        self.model = nn.ModuleList(
            [self.linear_compress, self.summarizer, self.discriminator])

        if self.config.mode == 'train':
            # Build Optimizers (each one also updates linear_compress)
            self.s_e_optimizer = optim.Adam(
                list(self.summarizer.s_lstm.parameters()) +
                list(self.summarizer.vae.e_lstm.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.lr)
            self.d_optimizer = optim.Adam(
                list(self.summarizer.vae.d_lstm.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.lr)
            self.c_optimizer = optim.Adam(
                list(self.discriminator.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.discriminator_lr)
            print(self.model)
            self.model.train()
            # self.model.apply(apply_weight_norm)

            # Overview Parameters
            # Experiment: train only the VAE (freeze every non-VAE parameter)
            # print('Model Parameters')
            # for name, param in self.model.named_parameters():
            #     print('\t' + name + '\t', list(param.size()))
            #     if 'vae' not in name:
            #         param.requires_grad = False
            #     print('\t train: ' + '\t', param.requires_grad)

            # Tensorboard logging (active).
            self.writer = TensorboardWriter(self.config.log_dir)
Exemplo n.º 3
0
    def build(self):
        """Create the networks on the GPU and, in 'train' mode, set up training.

        Builds ``linear_compress`` (input_size -> hidden_size), the
        ``Summarizer`` and the ``Discriminator`` and groups them in
        ``self.model``.  In 'train' mode it also creates three Adam optimizers
        (each of which also updates ``linear_compress``), switches the model
        to training mode and creates a ``TensorboardWriter`` for
        ``config.log_dir``.
        """
        # Build Modules
        self.linear_compress = nn.Linear(self.config.input_size,
                                         self.config.hidden_size).cuda()
        self.summarizer = Summarizer(input_size=self.config.hidden_size,
                                     hidden_size=self.config.hidden_size,
                                     num_layers=self.config.num_layers).cuda()
        self.discriminator = Discriminator(
            input_size=self.config.hidden_size,
            hidden_size=self.config.hidden_size,
            num_layers=self.config.num_layers).cuda()
        self.model = nn.ModuleList(
            [self.linear_compress, self.summarizer, self.discriminator])

        if self.config.mode == 'train':
            # Build Optimizers
            self.s_e_optimizer = optim.Adam(
                list(self.summarizer.s_lstm.parameters()) +
                list(self.summarizer.vae.e_lstm.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.lr)
            self.d_optimizer = optim.Adam(
                list(self.summarizer.vae.d_lstm.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.lr)
            self.c_optimizer = optim.Adam(
                list(self.discriminator.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.discriminator_lr)

            self.model.train()
            # self.model.apply(apply_weight_norm)

            # Overview Parameters
            # print('Model Parameters')
            # for name, param in self.model.named_parameters():
            #     print('\t' + name + '\t', list(param.size()))

            # Tensorboard
            self.writer = TensorboardWriter(self.config.log_dir)
Exemplo n.º 4
0
class Solver(object):
    """Builds, trains and evaluates the AC-SUM-GAN model.

    Components (created by ``build``): a linear layer compressing input
    features to ``hidden_size``, a ``Summarizer`` (sLSTM frame scorer plus a
    VAE with eLSTM encoder / dLSTM decoder), a ``Discriminator`` (cLSTM) and
    an ``Actor``/``Critic`` pair that selects video fragments.
    """

    def __init__(self, config=None, train_loader=None, test_loader=None):
        """Class that Builds, Trains and Evaluates AC-SUM-GAN model.

        Args:
            config: run configuration; its attributes are read by ``build``,
                ``train`` and ``evaluate``.
            train_loader: iterable yielding
                ``(image_features, action_fragments)``.
            test_loader: iterable yielding
                ``(image_features, video_name, action_fragments)``.
        """
        self.config = config
        self.train_loader = train_loader
        self.test_loader = test_loader

    def build(self):
        """Create the networks on the GPU and, in 'train' mode, the optimizers.

        Optimizers (all Adam):
          * ``e_optimizer``  - VAE encoder (eLSTM), lr ``config.lr``
          * ``d_optimizer``  - VAE decoder (dLSTM), lr ``config.lr``
          * ``c_optimizer``  - discriminator + linear compression,
            lr ``config.discriminator_lr``
          * ``optimizerA_s`` - actor + sLSTM + linear compression,
            lr ``config.lr``
          * ``optimizerC``   - critic, lr ``config.lr``
        A ``TensorboardWriter`` logging to ``config.log_dir`` is also created.
        """
        # Build Modules
        self.linear_compress = nn.Linear(self.config.input_size,
                                         self.config.hidden_size).cuda()
        self.summarizer = Summarizer(input_size=self.config.hidden_size,
                                     hidden_size=self.config.hidden_size,
                                     num_layers=self.config.num_layers).cuda()
        self.discriminator = Discriminator(
            input_size=self.config.hidden_size,
            hidden_size=self.config.hidden_size,
            num_layers=self.config.num_layers).cuda()
        self.actor = Actor(state_size=self.config.action_state_size,
                           action_size=self.config.action_state_size).cuda()
        self.critic = Critic(state_size=self.config.action_state_size,
                             action_size=self.config.action_state_size).cuda()
        self.model = nn.ModuleList([
            self.linear_compress, self.summarizer, self.discriminator,
            self.actor, self.critic
        ])

        if self.config.mode == 'train':
            # Build Optimizers
            self.e_optimizer = optim.Adam(
                self.summarizer.vae.e_lstm.parameters(), lr=self.config.lr)
            self.d_optimizer = optim.Adam(
                self.summarizer.vae.d_lstm.parameters(), lr=self.config.lr)
            self.c_optimizer = optim.Adam(
                list(self.discriminator.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.discriminator_lr)
            self.optimizerA_s = optim.Adam(
                list(self.actor.parameters()) +
                list(self.summarizer.s_lstm.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.lr)
            self.optimizerC = optim.Adam(self.critic.parameters(),
                                         lr=self.config.lr)

            self.writer = TensorboardWriter(str(self.config.log_dir))

    def reconstruction_loss(self, h_origin, h_sum):
        """L2 loss between original-regenerated features at cLSTM's last hidden layer.

        Returns the L2 norm of ``h_origin - h_sum`` taken over all elements.
        """
        return torch.norm(h_origin - h_sum, p=2)

    def prior_loss(self, mu, log_variance):
        """KL( q(e|x) || N(0,1) ).

        Computes ``0.5 * sum(exp(log_var) + mu**2 - 1 - log_var)``.
        """
        return 0.5 * torch.sum(-1 + log_variance.exp() + mu.pow(2) -
                               log_variance)

    def sparsity_loss(self, scores):
        """Summary-Length Regularization.

        Absolute difference between the mean score and
        ``config.regularization_factor``.
        """
        return torch.abs(
            torch.mean(scores) - self.config.regularization_factor)

    # MSE between discriminator probabilities and the original/summary labels.
    criterion = nn.MSELoss()

    def AC(self, original_features, seq_len, action_fragments):
        """ Function that makes the actor's actions, in the training steps where the actor and critic components are not trained.

        The actor picks fragments for ``config.termination_point`` steps.
        Every frame starts with weight
        ``(action_state_size - termination_point) / action_state_size``;
        frames of a newly picked fragment are set to
        ``(termination_point - counter) / (action_state_size - counter) + 1``
        and that fragment's entry in the state is zeroed.

        Args:
            original_features: compressed features, [seq_len, 1, hidden_size].
            seq_len: number of frames.
            action_fragments: per-fragment inclusive ``[start, end]`` frame
                indices, accessed as ``action_fragments[i, 0]`` / ``[i, 1]``.

        Returns:
            ``(weighted_features, weighted_scores)``: the sLSTM frame scores
            multiplied by the action weights ([seq_len, 1]) and the features
            scaled by those weighted scores.
        """
        scores = self.summarizer.s_lstm(original_features)  # [seq_len, 1]

        # State = mean sLSTM score of each fragment.
        fragment_scores = np.zeros(
            self.config.action_state_size)  # [num_fragments, 1]
        for fragment in range(self.config.action_state_size):
            fragment_scores[fragment] = scores[action_fragments[
                fragment, 0]:action_fragments[fragment, 1] + 1].mean()
        state = fragment_scores

        previous_actions = [
        ]  # save all the actions (the selected fragments of each episode)
        reduction_factor = (
            self.config.action_state_size -
            self.config.termination_point) / self.config.action_state_size
        action_scores = (torch.ones(seq_len) * reduction_factor).cuda()
        action_fragment_scores = (torch.ones(
            self.config.action_state_size)).cuda()

        counter = 0
        for ACstep in range(self.config.termination_point):

            state = torch.FloatTensor(state).cuda()
            # select an action
            dist = self.actor(state)
            action = dist.sample(
            )  # returns a scalar between 0-action_state_size

            if action not in previous_actions:
                previous_actions.append(action)
                action_factor = (self.config.termination_point - counter) / (
                    self.config.action_state_size - counter) + 1

                action_scores[action_fragments[action,
                                               0]:action_fragments[action, 1] +
                              1] = action_factor
                action_fragment_scores[action] = 0

                counter = counter + 1

            # Already-selected fragments are masked out of the next state.
            next_state = state * action_fragment_scores
            next_state = next_state.cpu().detach().numpy()
            state = next_state

        weighted_scores = action_scores.unsqueeze(1) * scores
        weighted_features = weighted_scores.view(-1, 1, 1) * original_features

        return weighted_features, weighted_scores

    def train(self):
        """Train all components for ``config.n_epochs`` epochs.

        Each batch of ``config.batch_size`` videos goes through four phases,
        each accumulating gradients over the batch (losses are divided by
        ``batch_size``) before one clipped optimizer step: eLSTM, dLSTM,
        cLSTM, then sLSTM with the actor and critic.  At the end of each
        epoch the mean losses are logged to Tensorboard, ``self.model`` is
        saved to ``<save_dir>/epoch-<i>.pkl`` and ``evaluate`` is called.
        """
        step = 0
        for epoch_i in trange(self.config.n_epochs, desc='Epoch', ncols=80):
            self.model.train()
            recon_loss_init_history = []
            recon_loss_history = []
            sparsity_loss_history = []
            prior_loss_history = []
            g_loss_history = []
            e_loss_history = []
            d_loss_history = []
            c_original_loss_history = []
            c_summary_loss_history = []
            actor_loss_history = []
            critic_loss_history = []
            reward_history = []

            # Train in batches of as many videos as the batch_size
            num_batches = len(self.train_loader) // self.config.batch_size
            iterator = iter(self.train_loader)
            for batch in range(num_batches):
                list_image_features = []
                list_action_fragments = []

                print(f'batch: {batch}')

                # ---- Train eLSTM ----#
                if self.config.verbose:
                    tqdm.write('Training eLSTM...')
                self.e_optimizer.zero_grad()
                for video in range(self.config.batch_size):
                    image_features, action_fragments = next(iterator)

                    action_fragments = action_fragments.squeeze(0)
                    # [batch_size, seq_len, input_size] -> [seq_len, input_size]
                    image_features = image_features.view(
                        -1, self.config.input_size)

                    # Cache the batch so the later phases reuse the same videos.
                    list_image_features.append(image_features)
                    list_action_fragments.append(action_fragments)

                    # [seq_len, input_size]
                    image_features_ = Variable(image_features).cuda()
                    seq_len = image_features_.shape[0]

                    # [seq_len, 1, hidden_size]
                    original_features = self.linear_compress(
                        image_features_.detach()).unsqueeze(1)

                    weighted_features, _ = self.AC(
                        original_features, seq_len, action_fragments)
                    h_mu, h_log_variance, generated_features = self.summarizer.vae(
                        weighted_features)

                    h_origin, original_prob = self.discriminator(
                        original_features)
                    h_sum, sum_prob = self.discriminator(generated_features)

                    if self.config.verbose:
                        tqdm.write(
                            f'original_p: {original_prob.item():.3f}, summary_p: {sum_prob.item():.3f}'
                        )

                    reconstruction_loss = self.reconstruction_loss(
                        h_origin, h_sum)
                    prior_loss = self.prior_loss(h_mu, h_log_variance)

                    if self.config.verbose:
                        tqdm.write(
                            f'recon loss {reconstruction_loss.item():.3f}, prior loss: {prior_loss.item():.3f}'
                        )

                    e_loss = reconstruction_loss + prior_loss
                    e_loss = e_loss / self.config.batch_size
                    e_loss.backward()

                    prior_loss_history.append(prior_loss.data)
                    e_loss_history.append(e_loss.data)

                # Update e_lstm parameters every 'batch_size' iterations
                torch.nn.utils.clip_grad_norm_(
                    self.summarizer.vae.e_lstm.parameters(), self.config.clip)
                self.e_optimizer.step()

                #---- Train dLSTM (decoder/generator) ----#
                if self.config.verbose:
                    tqdm.write('Training dLSTM...')
                self.d_optimizer.zero_grad()
                for video in range(self.config.batch_size):
                    image_features = list_image_features[video]
                    action_fragments = list_action_fragments[video]

                    # [seq_len, input_size]
                    image_features_ = Variable(image_features).cuda()
                    seq_len = image_features_.shape[0]

                    # [seq_len, 1, hidden_size]
                    original_features = self.linear_compress(
                        image_features_.detach()).unsqueeze(1)

                    weighted_features, _ = self.AC(original_features, seq_len,
                                                   action_fragments)
                    h_mu, h_log_variance, generated_features = self.summarizer.vae(
                        weighted_features)

                    h_origin, original_prob = self.discriminator(
                        original_features)
                    h_sum, sum_prob = self.discriminator(generated_features)

                    if self.config.verbose:
                        tqdm.write(
                            f'original_p: {original_prob.item():.3f}, summary_p: {sum_prob.item():.3f}'
                        )

                    reconstruction_loss = self.reconstruction_loss(
                        h_origin, h_sum)
                    # Generator tries to make the summary look "original".
                    g_loss = self.criterion(sum_prob, original_label)

                    # Mean per-frame L2 distance between the compressed
                    # original and the generated features (logged only).
                    orig_features = original_features.squeeze(
                        1)  # [seq_len, hidden_size]
                    gen_features = generated_features.squeeze(1)
                    reconstruction_loss_init = torch.norm(
                        orig_features - gen_features, p=2, dim=1).mean()

                    if self.config.verbose:
                        tqdm.write(
                            f'recon loss {reconstruction_loss.item():.3f}, g loss: {g_loss.item():.3f}'
                        )

                    d_loss = reconstruction_loss + g_loss
                    d_loss = d_loss / self.config.batch_size
                    d_loss.backward()

                    recon_loss_init_history.append(
                        reconstruction_loss_init.data)
                    recon_loss_history.append(reconstruction_loss.data)
                    g_loss_history.append(g_loss.data)
                    d_loss_history.append(d_loss.data)

                # Update d_lstm parameters every 'batch_size' iterations
                torch.nn.utils.clip_grad_norm_(
                    self.summarizer.vae.d_lstm.parameters(), self.config.clip)
                self.d_optimizer.step()

                #---- Train cLSTM ----#
                if self.config.verbose:
                    tqdm.write('Training cLSTM...')
                self.c_optimizer.zero_grad()
                for video in range(self.config.batch_size):
                    image_features = list_image_features[video]
                    action_fragments = list_action_fragments[video]

                    # [seq_len, input_size]
                    image_features_ = Variable(image_features).cuda()
                    seq_len = image_features_.shape[0]

                    # Train with original loss
                    # [seq_len, 1, hidden_size]
                    original_features = self.linear_compress(
                        image_features_.detach()).unsqueeze(1)
                    h_origin, original_prob = self.discriminator(
                        original_features)
                    c_original_loss = self.criterion(original_prob,
                                                     original_label)
                    c_original_loss = c_original_loss / self.config.batch_size
                    c_original_loss.backward()

                    # Train with summary loss (generator output detached so
                    # only the discriminator side receives gradients here).
                    weighted_features, _ = self.AC(original_features, seq_len,
                                                   action_fragments)
                    h_mu, h_log_variance, generated_features = self.summarizer.vae(
                        weighted_features)
                    h_sum, sum_prob = self.discriminator(
                        generated_features.detach())
                    c_summary_loss = self.criterion(sum_prob, summary_label)
                    c_summary_loss = c_summary_loss / self.config.batch_size
                    c_summary_loss.backward()

                    if self.config.verbose:
                        tqdm.write(
                            f'original_p: {original_prob.item():.3f}, summary_p: {sum_prob.item():.3f}'
                        )

                    c_original_loss_history.append(c_original_loss.data)
                    c_summary_loss_history.append(c_summary_loss.data)

                # Update c_lstm parameters every 'batch_size' iterations
                torch.nn.utils.clip_grad_norm_(
                    list(self.discriminator.parameters()) +
                    list(self.linear_compress.parameters()), self.config.clip)
                self.c_optimizer.step()

                #---- Train sLSTM and actor-critic ----#
                if self.config.verbose:
                    tqdm.write('Training sLSTM, actor and critic...')
                self.optimizerA_s.zero_grad()
                self.optimizerC.zero_grad()
                for video in range(self.config.batch_size):
                    image_features = list_image_features[video]
                    action_fragments = list_action_fragments[video]

                    # [seq_len, input_size]
                    image_features_ = Variable(image_features).cuda()
                    seq_len = image_features_.shape[0]

                    # [seq_len, 1, hidden_size]
                    original_features = self.linear_compress(
                        image_features_.detach()).unsqueeze(1)
                    scores = self.summarizer.s_lstm(
                        original_features)  # [seq_len, 1]

                    fragment_scores = np.zeros(
                        self.config.action_state_size)  # [num_fragments, 1]
                    for fragment in range(self.config.action_state_size):
                        fragment_scores[fragment] = scores[action_fragments[
                            fragment,
                            0]:action_fragments[fragment, 1] + 1].mean()

                    state = fragment_scores  # [action_state_size, 1]

                    previous_actions = [
                    ]  # save all the actions (the selected fragments of each step)
                    reduction_factor = (self.config.action_state_size -
                                        self.config.termination_point
                                        ) / self.config.action_state_size
                    action_scores = (torch.ones(seq_len) *
                                     reduction_factor).cuda()
                    action_fragment_scores = (torch.ones(
                        self.config.action_state_size)).cuda()

                    log_probs = []
                    values = []
                    rewards = []
                    masks = []
                    entropy = 0

                    counter = 0
                    for ACstep in range(self.config.termination_point):
                        # select an action, get a value for the current state
                        state = torch.FloatTensor(
                            state).cuda()  # [action_state_size, 1]
                        dist, value = self.actor(state), self.critic(state)
                        action = dist.sample(
                        )  # returns a scalar between 0-action_state_size

                        if action in previous_actions:
                            # Re-selecting a fragment earns nothing.
                            reward = 0

                        else:

                            previous_actions.append(action)
                            action_factor = (
                                self.config.termination_point - counter
                            ) / (self.config.action_state_size - counter) + 1

                            action_scores[action_fragments[
                                action, 0]:action_fragments[action, 1] +
                                          1] = action_factor
                            action_fragment_scores[action] = 0

                            weighted_scores = action_scores.unsqueeze(
                                1) * scores
                            weighted_features = weighted_scores.view(
                                -1, 1, 1) * original_features

                            h_mu, h_log_variance, generated_features = self.summarizer.vae(
                                weighted_features)

                            h_origin, original_prob = self.discriminator(
                                original_features)
                            h_sum, sum_prob = self.discriminator(
                                generated_features)

                            if self.config.verbose:
                                tqdm.write(
                                    f'original_p: {original_prob.item():.3f}, summary_p: {sum_prob.item():.3f}'
                                )

                            rec_loss = self.reconstruction_loss(
                                h_origin, h_sum)
                            reward = 1 - rec_loss.item(
                            )  # the less the distance, the higher the reward
                            counter = counter + 1

                        next_state = state * action_fragment_scores
                        next_state = next_state.cpu().detach().numpy()

                        log_prob = dist.log_prob(action).unsqueeze(0)
                        entropy += dist.entropy().mean()

                        log_probs.append(log_prob)
                        values.append(value)
                        rewards.append(
                            torch.tensor([reward],
                                         dtype=torch.float,
                                         device=device))

                        # The last step terminates the episode.
                        if ACstep == self.config.termination_point - 1:
                            masks.append(
                                torch.tensor([0],
                                             dtype=torch.float,
                                             device=device))
                        else:
                            masks.append(
                                torch.tensor([1],
                                             dtype=torch.float,
                                             device=device))

                        state = next_state

                    next_state = torch.FloatTensor(next_state).to(device)
                    next_value = self.critic(next_state)
                    returns = compute_returns(next_value, rewards, masks)

                    log_probs = torch.cat(log_probs)
                    returns = torch.cat(returns).detach()
                    values = torch.cat(values)

                    advantage = returns - values

                    actor_loss = -((log_probs * advantage.detach()).mean() +
                                   (self.config.entropy_coef /
                                    self.config.termination_point) * entropy)
                    sparsity_loss = self.sparsity_loss(scores)
                    critic_loss = advantage.pow(2).mean()

                    actor_loss = actor_loss / self.config.batch_size
                    sparsity_loss = sparsity_loss / self.config.batch_size
                    critic_loss = critic_loss / self.config.batch_size
                    actor_loss.backward()
                    sparsity_loss.backward()
                    critic_loss.backward()

                    reward_mean = torch.mean(torch.stack(rewards))
                    reward_history.append(reward_mean)
                    # Store detached values, like every other history, so the
                    # autograd graph of each video is not kept alive until
                    # the end of the epoch.
                    actor_loss_history.append(actor_loss.data)
                    sparsity_loss_history.append(sparsity_loss.data)
                    critic_loss_history.append(critic_loss.data)

                    if self.config.verbose:
                        tqdm.write('Plotting...')

                    self.writer.update_loss(original_prob.data, step,
                                            'original_prob')
                    self.writer.update_loss(sum_prob.data, step, 'sum_prob')

                    step += 1

                # Update s_lstm, actor and critic parameters every 'batch_size' iterations
                torch.nn.utils.clip_grad_norm_(
                    list(self.actor.parameters()) +
                    list(self.linear_compress.parameters()) +
                    list(self.summarizer.s_lstm.parameters()) +
                    list(self.critic.parameters()), self.config.clip)
                self.optimizerA_s.step()
                self.optimizerC.step()

            recon_loss_init = torch.stack(recon_loss_init_history).mean()
            recon_loss = torch.stack(recon_loss_history).mean()
            prior_loss = torch.stack(prior_loss_history).mean()
            g_loss = torch.stack(g_loss_history).mean()
            e_loss = torch.stack(e_loss_history).mean()
            d_loss = torch.stack(d_loss_history).mean()
            c_original_loss = torch.stack(c_original_loss_history).mean()
            c_summary_loss = torch.stack(c_summary_loss_history).mean()
            sparsity_loss = torch.stack(sparsity_loss_history).mean()
            actor_loss = torch.stack(actor_loss_history).mean()
            critic_loss = torch.stack(critic_loss_history).mean()
            reward = torch.mean(torch.stack(reward_history))

            # Plot
            if self.config.verbose:
                tqdm.write('Plotting...')
            self.writer.update_loss(recon_loss_init, epoch_i,
                                    'recon_loss_init_epoch')
            self.writer.update_loss(recon_loss, epoch_i, 'recon_loss_epoch')
            self.writer.update_loss(prior_loss, epoch_i, 'prior_loss_epoch')
            self.writer.update_loss(g_loss, epoch_i, 'g_loss_epoch')
            self.writer.update_loss(e_loss, epoch_i, 'e_loss_epoch')
            self.writer.update_loss(d_loss, epoch_i, 'd_loss_epoch')
            self.writer.update_loss(c_original_loss, epoch_i,
                                    'c_original_loss_epoch')
            self.writer.update_loss(c_summary_loss, epoch_i,
                                    'c_summary_loss_epoch')
            self.writer.update_loss(sparsity_loss, epoch_i,
                                    'sparsity_loss_epoch')
            self.writer.update_loss(actor_loss, epoch_i, 'actor_loss_epoch')
            self.writer.update_loss(critic_loss, epoch_i, 'critic_loss_epoch')
            self.writer.update_loss(reward, epoch_i, 'reward_epoch')

            # Save parameters at checkpoint
            ckpt_path = str(self.config.save_dir) + f'/epoch-{epoch_i}.pkl'
            if self.config.verbose:
                tqdm.write(f'Save parameters at {ckpt_path}')
            torch.save(self.model.state_dict(), ckpt_path)

            self.evaluate(epoch_i)

    def evaluate(self, epoch_i):
        """Score every test video and write the scores to a JSON file.

        Per-frame scores come from ``AC`` (sLSTM scores re-weighted by the
        actor's selections).  The dict ``{video_name: [score, ...]}`` is
        written once, after all videos are scored, to
        ``<score_dir>/<video_type>_<epoch_i>.json``, whose permissions are
        then set to 0o777.
        """
        self.model.eval()

        out_dict = {}

        # Inference only: no autograd graph is needed for any step, including
        # the linear compression.
        with torch.no_grad():
            for image_features, video_name, action_fragments in tqdm(
                    self.test_loader, desc='Evaluate', ncols=80, leave=False):
                # [seq_len, batch_size=1, input_size)] -> [seq_len, input_size]
                image_features = image_features.view(-1, self.config.input_size)
                image_features_ = Variable(image_features).cuda()

                # [seq_len, 1, hidden_size]
                original_features = self.linear_compress(
                    image_features_.detach()).unsqueeze(1)
                seq_len = original_features.shape[0]

                _, scores = self.AC(original_features, seq_len,
                                    action_fragments)

                scores = scores.squeeze(1)
                out_dict[video_name] = scores.cpu().numpy().tolist()

        # Write the file once rather than rewriting it after every video.
        score_save_path = self.config.score_dir.joinpath(
            f'{self.config.video_type}_{epoch_i}.json')
        with open(score_save_path, 'w') as f:
            if self.config.verbose:
                tqdm.write(f'Saving score at {str(score_save_path)}.')
            json.dump(out_dict, f)
        score_save_path.chmod(0o777)
Exemplo n.º 5
0
class Solver(object):
    """Build, train and evaluate the SUM-GAN-sl video summarization model.

    ``self.model`` groups three modules:

    * ``linear_compress``, the input projection;
    * ``summarizer``, the sLSTM frame scorer plus a VAE with eLSTM/dLSTM;
    * ``discriminator``, the cLSTM.

    Every training batch updates them in three phases, each with its own
    Adam optimizer.
    """

    # Least-squares (MSE) adversarial criterion, shared by the generator
    # (dLSTM) and discriminator (cLSTM) phases.
    criterion = nn.MSELoss()

    def __init__(self, config=None, train_loader=None, test_loader=None):
        """Store the run configuration and the train/test data loaders.

        Args:
            config: Run configuration. Attributes read by this class include
                ``input_size``, ``hidden_size``, ``num_layers``, ``mode``,
                ``lr``, ``discriminator_lr``, ``clip``, ``n_epochs``,
                ``regularization_factor``, ``verbose``, ``log_dir``,
                ``save_dir``, ``score_dir`` and ``video_type``.
            train_loader: Iterable yielding one feature tensor per video.
            test_loader: Iterable yielding ``(features, video_name)`` pairs.
        """
        self.config = config
        self.train_loader = train_loader
        self.test_loader = test_loader

    def build(self):
        """Create the modules on the GPU and, in train mode, the optimizers.

        ``linear_compress`` is registered with all three optimizers, so every
        training phase updates it.
        """
        # Build Modules
        self.linear_compress = nn.Linear(self.config.input_size,
                                         self.config.hidden_size).cuda()
        self.summarizer = Summarizer(input_size=self.config.hidden_size,
                                     hidden_size=self.config.hidden_size,
                                     num_layers=self.config.num_layers).cuda()
        self.discriminator = Discriminator(
            input_size=self.config.hidden_size,
            hidden_size=self.config.hidden_size,
            num_layers=self.config.num_layers).cuda()
        self.model = nn.ModuleList(
            [self.linear_compress, self.summarizer, self.discriminator])

        if self.config.mode == 'train':
            # Phase 1: sLSTM + eLSTM (+ projection).
            self.s_e_optimizer = optim.Adam(
                list(self.summarizer.s_lstm.parameters()) +
                list(self.summarizer.vae.e_lstm.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.lr)
            # Phase 2: dLSTM generator (+ projection).
            self.d_optimizer = optim.Adam(
                list(self.summarizer.vae.d_lstm.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.lr)
            # Phase 3: cLSTM discriminator (+ projection).
            self.c_optimizer = optim.Adam(
                list(self.discriminator.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.discriminator_lr)

            self.writer = TensorboardWriter(str(self.config.log_dir))

    def reconstruction_loss(self, h_origin, h_sum):
        """L2 norm of the difference between the cLSTM last-hidden features
        of the original and the regenerated sequence."""
        return torch.norm(h_origin - h_sum, p=2)

    def prior_loss(self, mu, log_variance):
        """KL( N(mu, exp(log_variance)) || N(0, 1) ), summed over elements."""
        return 0.5 * torch.sum(-1 + log_variance.exp() + mu.pow(2) -
                               log_variance)

    def sparsity_loss(self, scores):
        """Summary-length regularization: |mean(scores) - regularization_factor|."""
        return torch.abs(
            torch.mean(scores) - self.config.regularization_factor)

    @staticmethod
    def _epoch_mean(history):
        """Return the mean of a list of scalar loss tensors, or ``None`` if empty.

        ``torch.stack`` raises on an empty list. That happens when no batch
        contributed to a history during the epoch, e.g. with an empty loader.
        """
        if not history:
            return None
        return torch.stack(history).mean()

    def train(self):
        """Run adversarial training for ``config.n_epochs`` epochs.

        Each batch runs three updates in order:

        1. sLSTM + eLSTM on reconstruction + prior + sparsity losses.
        2. dLSTM (generator) on reconstruction + MSE-adversarial loss, with
           summaries pushed towards the "original" label (1).
        3. cLSTM (discriminator) on original -> 1 and summary -> 0.

        Per-step losses and per-epoch means are logged to TensorBoard. After
        every epoch a checkpoint is saved under ``config.save_dir`` and
        :meth:`evaluate` is called.
        """
        step = 0
        for epoch_i in trange(self.config.n_epochs, desc='Epoch', ncols=80):
            s_e_loss_history = []
            d_loss_history = []
            c_original_loss_history = []
            c_summary_loss_history = []
            for batch_i, image_features in enumerate(
                    tqdm(self.train_loader,
                         desc='Batch',
                         ncols=80,
                         leave=False)):

                self.model.train()

                # Flatten to [seq_len, input_size]; presumably the loader
                # yields [batch_size=1, seq_len, input_size] — TODO confirm.
                image_features = image_features.view(-1,
                                                     self.config.input_size)
                # ``Variable`` is a no-op since PyTorch 0.4; move to GPU directly.
                image_features_ = image_features.cuda()

                #---- Train sLSTM, eLSTM ----#
                if self.config.verbose:
                    tqdm.write('\nTraining sLSTM and eLSTM...')

                # [seq_len, 1, hidden_size]
                original_features = self.linear_compress(
                    image_features_.detach()).unsqueeze(1)

                scores, h_mu, h_log_variance, generated_features = self.summarizer(
                    original_features)

                h_origin, original_prob = self.discriminator(original_features)
                h_sum, sum_prob = self.discriminator(generated_features)

                tqdm.write(
                    f'original_p: {original_prob.item():.3f}, summary_p: {sum_prob.item():.3f}'
                )

                reconstruction_loss = self.reconstruction_loss(h_origin, h_sum)
                prior_loss = self.prior_loss(h_mu, h_log_variance)
                sparsity_loss = self.sparsity_loss(scores)

                tqdm.write(
                    f'recon loss {reconstruction_loss.item():.3f}, prior loss: {prior_loss.item():.3f}, sparsity loss: {sparsity_loss.item():.3f}'
                )

                s_e_loss = reconstruction_loss + prior_loss + sparsity_loss

                self.s_e_optimizer.zero_grad()
                s_e_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.config.clip)
                self.s_e_optimizer.step()

                s_e_loss_history.append(s_e_loss.detach())

                #---- Train dLSTM (generator) ----#
                if self.config.verbose:
                    tqdm.write('Training dLSTM...')

                # [seq_len, 1, hidden_size]
                original_features = self.linear_compress(
                    image_features_.detach()).unsqueeze(1)

                scores, h_mu, h_log_variance, generated_features = self.summarizer(
                    original_features)

                h_origin, original_prob = self.discriminator(original_features)
                h_sum, sum_prob = self.discriminator(generated_features)

                tqdm.write(
                    f'original_p: {original_prob.item():.3f}, summary_p: {sum_prob.item():.3f}'
                )

                reconstruction_loss = self.reconstruction_loss(h_origin, h_sum)
                # The generator wants its summaries classified as "original"
                # (label 1). The target is built with the prediction's shape and
                # device, so MSELoss needs no broadcasting and no .cuda() call.
                g_loss = self.criterion(sum_prob, torch.ones_like(sum_prob))

                tqdm.write(
                    f'recon loss {reconstruction_loss.item():.3f}, g loss: {g_loss.item():.3f}'
                )

                d_loss = reconstruction_loss + g_loss

                self.d_optimizer.zero_grad()
                d_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.config.clip)
                self.d_optimizer.step()

                d_loss_history.append(d_loss.detach())

                #---- Train cLSTM ----#
                if self.config.verbose:
                    tqdm.write('Training cLSTM...')

                self.c_optimizer.zero_grad()

                # Original sequences should be classified as 1.
                # [seq_len, 1, hidden_size]
                original_features = self.linear_compress(
                    image_features_.detach()).unsqueeze(1)
                h_origin, original_prob = self.discriminator(original_features)
                c_original_loss = self.criterion(
                    original_prob, torch.ones_like(original_prob))
                c_original_loss.backward()

                # Summaries should be classified as 0. They are detached so this
                # loss only reaches the discriminator, not the summarizer.
                scores, h_mu, h_log_variance, generated_features = self.summarizer(
                    original_features)
                h_sum, sum_prob = self.discriminator(
                    generated_features.detach())
                c_summary_loss = self.criterion(sum_prob,
                                                torch.zeros_like(sum_prob))
                c_summary_loss.backward()

                tqdm.write(
                    f'original_p: {original_prob.item():.3f}, summary_p: {sum_prob.item():.3f}'
                )
                tqdm.write(
                    f'c original loss: {c_original_loss.item():.3f}, c summary loss: {c_summary_loss.item():.3f}'
                )

                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.config.clip)
                self.c_optimizer.step()

                c_original_loss_history.append(c_original_loss.detach())
                c_summary_loss_history.append(c_summary_loss.detach())

                if self.config.verbose:
                    tqdm.write('Plotting...')

                self.writer.update_loss(reconstruction_loss.detach(), step,
                                        'recon_loss')
                self.writer.update_loss(prior_loss.detach(), step,
                                        'prior_loss')
                self.writer.update_loss(sparsity_loss.detach(), step,
                                        'sparsity_loss')
                self.writer.update_loss(g_loss.detach(), step, 'gen_loss')

                self.writer.update_loss(original_prob.detach(), step,
                                        'original_prob')
                self.writer.update_loss(sum_prob.detach(), step, 'sum_prob')

                step += 1

            # Plot per-epoch means against the epoch index (the discriminator
            # curves were previously plotted against the global step).
            if self.config.verbose:
                tqdm.write('Plotting...')
            for history, tag in ((s_e_loss_history, 's_e_loss_epoch'),
                                 (d_loss_history, 'd_loss_epoch'),
                                 (c_original_loss_history, 'c_original_loss'),
                                 (c_summary_loss_history, 'c_summary_loss')):
                mean_loss = self._epoch_mean(history)
                if mean_loss is not None:
                    self.writer.update_loss(mean_loss, epoch_i, tag)

            # Save parameters at checkpoint
            ckpt_path = str(self.config.save_dir) + f'/epoch-{epoch_i}.pkl'
            tqdm.write(f'Save parameters at {ckpt_path}')
            torch.save(self.model.state_dict(), ckpt_path)

            self.evaluate(epoch_i)

    def evaluate(self, epoch_i):
        """Score every test video with the sLSTM and dump the scores to JSON.

        The scores of all test videos are collected into
        ``{video_name: scores_list}`` and written once, after the loop, to
        ``config.score_dir / f'{config.video_type}_{epoch_i}.json'``. The file
        is then chmod-ed to 0o777.
        """
        self.model.eval()

        out_dict = {}

        # Nothing in evaluation needs gradients, including the projection.
        with torch.no_grad():
            for video_tensor, video_name in tqdm(self.test_loader,
                                                 desc='Evaluate',
                                                 ncols=80,
                                                 leave=False):
                # Flatten to [seq_len, input_size].
                video_tensor = video_tensor.view(-1, self.config.input_size)
                video_feature = video_tensor.cuda()

                # [seq_len, 1, hidden_size]
                video_feature = self.linear_compress(video_feature).unsqueeze(1)

                # [seq_len]
                scores = self.summarizer.s_lstm(video_feature).squeeze(1)
                out_dict[video_name] = scores.cpu().numpy().tolist()

        # Write once after collecting every video, not on every iteration.
        score_save_path = self.config.score_dir.joinpath(
            f'{self.config.video_type}_{epoch_i}.json')
        with open(score_save_path, 'w') as f:
            tqdm.write(f'Saving score at {str(score_save_path)}.')
            json.dump(out_dict, f)
        score_save_path.chmod(0o777)

    def pretrain(self):
        """Placeholder; this model has no pretraining stage."""
        pass
class Solver(object):
    """Build, train and evaluate the SUM-GAN video summarization model.

    ``self.model`` groups three modules:

    * ``linear_compress``, the input projection;
    * ``summarizer``, the sLSTM frame scorer plus a VAE with eLSTM/dLSTM;
    * ``discriminator``, the cLSTM.

    The discriminator also sees a "uniform" summary produced by
    ``summarizer(..., uniform=True)``.
    """

    def __init__(self, config=None, train_loader=None, test_loader=None):
        """Store the run configuration and the train/test data loaders.

        Args:
            config: Run configuration. Attributes read by this class include
                ``input_size``, ``hidden_size``, ``num_layers``, ``mode``,
                ``lr``, ``discriminator_lr``, ``clip``, ``n_epochs``,
                ``summary_rate``, ``discriminator_slow_start``, ``verbose``,
                ``log_dir``, ``save_dir``, ``score_dir`` and ``video_type``.
            train_loader: Iterable yielding one feature tensor per video.
            test_loader: Iterable yielding ``(features, video_name)`` pairs.
        """
        self.config = config
        self.train_loader = train_loader
        self.test_loader = test_loader

    def build(self):
        """Create the modules on the GPU and, in train mode, the optimizers.

        ``linear_compress`` is registered with all three optimizers, so every
        training phase updates it.
        """
        # Build Modules
        self.linear_compress = nn.Linear(self.config.input_size,
                                         self.config.hidden_size).cuda()
        self.summarizer = Summarizer(input_size=self.config.hidden_size,
                                     hidden_size=self.config.hidden_size,
                                     num_layers=self.config.num_layers).cuda()
        self.discriminator = Discriminator(
            input_size=self.config.hidden_size,
            hidden_size=self.config.hidden_size,
            num_layers=self.config.num_layers).cuda()
        self.model = nn.ModuleList(
            [self.linear_compress, self.summarizer, self.discriminator])

        if self.config.mode == 'train':
            # Phase 1: sLSTM + eLSTM (+ projection).
            self.s_e_optimizer = optim.Adam(
                list(self.summarizer.s_lstm.parameters()) +
                list(self.summarizer.vae.e_lstm.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.lr)
            # Phase 2: dLSTM generator (+ projection).
            self.d_optimizer = optim.Adam(
                list(self.summarizer.vae.d_lstm.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.lr)
            # Phase 3: cLSTM discriminator (+ projection).
            self.c_optimizer = optim.Adam(
                list(self.discriminator.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.discriminator_lr)

            self.model.train()

            # Tensorboard
            self.writer = TensorboardWriter(self.config.log_dir)

    @staticmethod
    def freeze_model(module):
        """Disable gradient tracking for every parameter of ``module``."""
        for p in module.parameters():
            p.requires_grad = False

    def reconstruction_loss(self, h_origin, h_fake):
        """L2 norm of the difference between the cLSTM last-hidden features
        of the original and the regenerated sequence."""
        return torch.norm(h_origin - h_fake, p=2)

    def prior_loss(self, mu, log_variance):
        """KL( N(mu, exp(log_variance)) || N(0, 1) ), summed over elements."""
        return 0.5 * torch.sum(-1 + log_variance.exp() + mu.pow(2) -
                               log_variance)

    def sparsity_loss(self, scores):
        """Summary-length regularization: |mean(scores) - summary_rate|."""
        return torch.abs(torch.mean(scores) - self.config.summary_rate)

    def gan_loss(self, original_prob, fake_prob, uniform_prob):
        """Mean of log D(x) + log(1 - D(fake)) + log(1 - D(uniform)).

        The generator minimizes this value; the discriminator maximizes it,
        which :meth:`train` implements by minimizing its negation.
        """
        gan_loss = torch.mean(
            torch.log(original_prob) + torch.log(1 - fake_prob) +
            torch.log(1 - uniform_prob))  # Discriminate uniform score

        return gan_loss

    @staticmethod
    def _epoch_mean(history):
        """Return the mean of a list of scalar loss tensors, or ``None`` if empty.

        ``torch.stack`` raises on an empty list. ``c_loss_history`` stays empty
        for any epoch with at most ``discriminator_slow_start + 1`` batches.
        """
        if not history:
            return None
        return torch.stack(history).mean()

    @staticmethod
    def _prob_summary(original_prob, fake_prob, uniform_prob):
        """Format the three single-element discriminator outputs for logging."""
        return (f'original_p: {original_prob.item():.3f}, '
                f'fake_p: {fake_prob.item():.3f}, '
                f'uniform_p: {uniform_prob.item():.3f}')

    def train(self):
        """Run adversarial training for ``config.n_epochs`` epochs.

        Each kept batch runs:

        1. an sLSTM + eLSTM update on reconstruction + prior + sparsity losses;
        2. a dLSTM update on reconstruction + GAN loss;
        3. a cLSTM update maximizing the GAN loss. This phase only runs once
           ``batch_i`` exceeds ``config.discriminator_slow_start``.

        Videos longer than 10000 frames are skipped. Losses are logged to
        TensorBoard, a checkpoint is saved after every epoch, and
        :meth:`evaluate` is called before switching back to train mode.
        """
        step = 0
        for epoch_i in trange(self.config.n_epochs, desc='Epoch', ncols=80):
            s_e_loss_history = []
            d_loss_history = []
            c_loss_history = []
            for batch_i, image_features in enumerate(
                    tqdm(self.train_loader,
                         desc='Batch',
                         ncols=80,
                         leave=False)):

                # Skip very long videos (dim 1 is presumably seq_len).
                if image_features.size(1) > 10000:
                    continue

                # Flatten to [seq_len, input_size].
                image_features = image_features.view(-1,
                                                     self.config.input_size)
                # ``Variable`` is a no-op since PyTorch 0.4; move to GPU directly.
                image_features_ = image_features.cuda()

                #---- Train sLSTM, eLSTM ----#
                if self.config.verbose:
                    tqdm.write('\nTraining sLSTM and eLSTM...')

                # [seq_len, 1, hidden_size]
                original_features = self.linear_compress(
                    image_features_.detach()).unsqueeze(1)

                scores, h_mu, h_log_variance, generated_features = self.summarizer(
                    original_features)
                _, _, _, uniform_features = self.summarizer(original_features,
                                                            uniform=True)

                h_origin, original_prob = self.discriminator(original_features)
                h_fake, fake_prob = self.discriminator(generated_features)
                h_uniform, uniform_prob = self.discriminator(uniform_features)

                tqdm.write(
                    self._prob_summary(original_prob, fake_prob, uniform_prob))

                reconstruction_loss = self.reconstruction_loss(
                    h_origin, h_fake)
                prior_loss = self.prior_loss(h_mu, h_log_variance)
                sparsity_loss = self.sparsity_loss(scores)

                # ``.item()`` replaces ``.data[0]``, which raises IndexError on
                # 0-dim tensors in PyTorch >= 0.4.
                tqdm.write(
                    f'recon loss {reconstruction_loss.item():.3f}, prior loss: {prior_loss.item():.3f}, sparsity loss: {sparsity_loss.item():.3f}'
                )

                s_e_loss = reconstruction_loss + prior_loss + sparsity_loss

                self.s_e_optimizer.zero_grad()
                s_e_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.config.clip)
                self.s_e_optimizer.step()

                s_e_loss_history.append(s_e_loss.detach())

                #---- Train dLSTM ----#
                if self.config.verbose:
                    tqdm.write('Training dLSTM...')

                # [seq_len, 1, hidden_size]
                original_features = self.linear_compress(
                    image_features_.detach()).unsqueeze(1)

                scores, h_mu, h_log_variance, generated_features = self.summarizer(
                    original_features)
                _, _, _, uniform_features = self.summarizer(original_features,
                                                            uniform=True)

                h_origin, original_prob = self.discriminator(original_features)
                h_fake, fake_prob = self.discriminator(generated_features)
                h_uniform, uniform_prob = self.discriminator(uniform_features)

                tqdm.write(
                    self._prob_summary(original_prob, fake_prob, uniform_prob))

                reconstruction_loss = self.reconstruction_loss(
                    h_origin, h_fake)
                gan_loss = self.gan_loss(original_prob, fake_prob,
                                         uniform_prob)

                tqdm.write(
                    f'recon loss {reconstruction_loss.item():.3f}, gan loss: {gan_loss.item():.3f}'
                )

                d_loss = reconstruction_loss + gan_loss

                self.d_optimizer.zero_grad()
                d_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.config.clip)
                self.d_optimizer.step()

                d_loss_history.append(d_loss.detach())

                #---- Train cLSTM ----#
                if batch_i > self.config.discriminator_slow_start:
                    if self.config.verbose:
                        tqdm.write('Training cLSTM...')
                    # [seq_len, 1, hidden_size]
                    original_features = self.linear_compress(
                        image_features_.detach()).unsqueeze(1)

                    scores, h_mu, h_log_variance, generated_features = self.summarizer(
                        original_features)
                    _, _, _, uniform_features = self.summarizer(
                        original_features, uniform=True)

                    h_origin, original_prob = self.discriminator(
                        original_features)
                    h_fake, fake_prob = self.discriminator(generated_features)
                    h_uniform, uniform_prob = self.discriminator(
                        uniform_features)
                    tqdm.write(
                        self._prob_summary(original_prob, fake_prob,
                                           uniform_prob))

                    # The discriminator maximizes the GAN objective.
                    c_loss = -1 * self.gan_loss(original_prob, fake_prob,
                                                uniform_prob)

                    # Report this phase's own loss, not the generator's value.
                    tqdm.write(f'c loss: {c_loss.item():.3f}')

                    self.c_optimizer.zero_grad()
                    c_loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   self.config.clip)
                    self.c_optimizer.step()

                    c_loss_history.append(c_loss.detach())

                if self.config.verbose:
                    tqdm.write('Plotting...')

                self.writer.update_loss(reconstruction_loss.detach(), step,
                                        'recon_loss')
                self.writer.update_loss(prior_loss.detach(), step,
                                        'prior_loss')
                self.writer.update_loss(sparsity_loss.detach(), step,
                                        'sparsity_loss')
                self.writer.update_loss(gan_loss.detach(), step, 'gan_loss')

                self.writer.update_loss(original_prob.detach(), step,
                                        'original_prob')
                self.writer.update_loss(fake_prob.detach(), step, 'fake_prob')
                self.writer.update_loss(uniform_prob.detach(), step,
                                        'uniform_prob')

                step += 1

            # Plot per-epoch means. Empty histories are skipped instead of
            # crashing torch.stack (e.g. cLSTM never trained this epoch).
            if self.config.verbose:
                tqdm.write('Plotting...')
            for history, tag in ((s_e_loss_history, 's_e_loss_epoch'),
                                 (d_loss_history, 'd_loss_epoch'),
                                 (c_loss_history, 'c_loss_epoch')):
                mean_loss = self._epoch_mean(history)
                if mean_loss is not None:
                    self.writer.update_loss(mean_loss, epoch_i, tag)

            # Save parameters at checkpoint
            ckpt_path = str(self.config.save_dir) + f'_epoch-{epoch_i}.pkl'
            tqdm.write(f'Save parameters at {ckpt_path}')
            torch.save(self.model.state_dict(), ckpt_path)

            self.evaluate(epoch_i)

            # evaluate() switches to eval mode; restore training mode.
            self.model.train()

    def evaluate(self, epoch_i):
        """Score every test video with the sLSTM and dump the scores to JSON.

        The scores of all test videos are collected into
        ``{video_name: scores_list}`` and written once, after the loop, to
        ``config.score_dir / f'{config.video_type}_{epoch_i}.json'``. The file
        is then chmod-ed to 0o777.
        """
        self.model.eval()

        out_dict = {}

        # ``volatile=True`` no longer has any effect (PyTorch >= 0.4);
        # ``torch.no_grad()`` is its replacement.
        with torch.no_grad():
            for video_tensor, video_name in tqdm(self.test_loader,
                                                 desc='Evaluate',
                                                 ncols=80,
                                                 leave=False):
                # Flatten to [seq_len, input_size].
                video_tensor = video_tensor.view(-1, self.config.input_size)
                video_feature = video_tensor.cuda()

                # [seq_len, 1, hidden_size]
                video_feature = self.linear_compress(video_feature).unsqueeze(1)

                # [seq_len]
                scores = self.summarizer.s_lstm(video_feature).squeeze(1)

                # CUDA tensors must be copied to the CPU before numpy conversion.
                out_dict[video_name] = scores.cpu().numpy().tolist()

        # Write once after collecting every video, not on every iteration.
        score_save_path = self.config.score_dir.joinpath(
            f'{self.config.video_type}_{epoch_i}.json')
        with open(score_save_path, 'w') as f:
            tqdm.write(f'Saving score at {str(score_save_path)}.')
            json.dump(out_dict, f)
        score_save_path.chmod(0o777)

    def pretrain(self):
        """Placeholder; this model has no pretraining stage."""
        pass