class Solver(object):
    """Trainer/evaluator for a hierarchical (HRED-style) conversation model.

    Drives the full lifecycle: builds the model named by ``config.model``,
    trains with teacher forcing on flattened (input, target) sentence pairs,
    validates each epoch, computes test-set word perplexity, and reports
    word2vec-based embedding similarity metrics.
    """

    def __init__(self, config, train_data_loader, eval_data_loader, vocab,
                 is_train=True, model=None):
        # `model` may be injected (e.g. for evaluation-only use); otherwise it
        # is created lazily by build() from `config.model`.
        self.config = config
        self.epoch_i = 0  # resume epoch; overwritten by load_model()
        self.train_data_loader = train_data_loader
        self.eval_data_loader = eval_data_loader
        self.vocab = vocab
        self.is_train = is_train
        self.model = model

    @time_desc_decorator('Build Graph')
    def build(self, cuda=True):
        """Instantiate the model, initialize recurrent weights, optionally
        restore a checkpoint, and (when training) create the TensorBoard
        writer and optimizer.

        Args:
            cuda: move the model to GPU when CUDA is available.
        """
        if self.model is None:
            self.model = getattr(models, self.config.model)(self.config)

            # orthogonal initialiation for hidden weights
            # input gate bias for GRUs
            if self.config.mode == 'train' and self.config.checkpoint is None:
                print('Parameter initiailization')
                for name, param in self.model.named_parameters():
                    if 'weight_hh' in name:
                        print('\t' + name)
                        nn.init.orthogonal_(param)

                    # bias_hh is concatenation of reset, input, new gates
                    # only set the input gate bias to 2.0
                    if 'bias_hh' in name:
                        print('\t' + name)
                        dim = int(param.size(0) / 3)
                        param.data[dim:2 * dim].fill_(2.0)

        if torch.cuda.is_available() and cuda:
            self.model.cuda()

        # Overview Parameters
        print('Model Parameters')
        for name, param in self.model.named_parameters():
            print('\t' + name + '\t', list(param.size()))

        if self.config.checkpoint:
            self.load_model(self.config.checkpoint)

        if self.is_train:
            self.writer = TensorboardWriter(self.config.logdir)
            # Only parameters with requires_grad=True are optimized, so
            # frozen parameters are skipped automatically.
            self.optimizer = self.config.optimizer(
                filter(lambda p: p.requires_grad, self.model.parameters()),
                lr=self.config.learning_rate)

    def save_model(self, epoch):
        """Save parameters to checkpoint ``<save_path>/<epoch>.pkl``."""
        ckpt_path = os.path.join(self.config.save_path, f'{epoch}.pkl')
        print(f'Save parameters to {ckpt_path}')
        torch.save(self.model.state_dict(), ckpt_path)

    def load_model(self, checkpoint):
        """Load parameters from checkpoint.

        The epoch counter is recovered from the leading digits of the
        checkpoint filename (matching the ``<epoch>.pkl`` convention used by
        save_model).
        """
        print(f'Load parameters from {checkpoint}')
        epoch = re.match(r"[0-9]*", os.path.basename(checkpoint)).group(0)
        self.epoch_i = int(epoch)
        self.model.load_state_dict(torch.load(checkpoint))

    def write_summary(self, epoch_i):
        """Push every tracked scalar that has been set on `self` so far to
        TensorBoard, logged against step ``epoch_i + 1``.

        Each metric attribute is optional (set by train()/evaluate() or by
        subclasses); missing ones are silently skipped via getattr defaults.
        """
        epoch_loss = getattr(self, 'epoch_loss', None)
        if epoch_loss is not None:
            self.writer.update_loss(loss=epoch_loss,
                                    step_i=epoch_i + 1,
                                    name='train_loss')
        epoch_recon_loss = getattr(self, 'epoch_recon_loss', None)
        if epoch_recon_loss is not None:
            self.writer.update_loss(loss=epoch_recon_loss,
                                    step_i=epoch_i + 1,
                                    name='train_recon_loss')
        epoch_kl_div = getattr(self, 'epoch_kl_div', None)
        if epoch_kl_div is not None:
            self.writer.update_loss(loss=epoch_kl_div,
                                    step_i=epoch_i + 1,
                                    name='train_kl_div')
        kl_mult = getattr(self, 'kl_mult', None)
        if kl_mult is not None:
            self.writer.update_loss(loss=kl_mult,
                                    step_i=epoch_i + 1,
                                    name='kl_mult')
        epoch_bow_loss = getattr(self, 'epoch_bow_loss', None)
        if epoch_bow_loss is not None:
            self.writer.update_loss(loss=epoch_bow_loss,
                                    step_i=epoch_i + 1,
                                    name='bow_loss')
        validation_loss = getattr(self, 'validation_loss', None)
        if validation_loss is not None:
            self.writer.update_loss(loss=validation_loss,
                                    step_i=epoch_i + 1,
                                    name='validation_loss')
        average_bleu = getattr(self, "average_bleu", None)
        if average_bleu is not None:
            self.writer.update_loss(loss=average_bleu,
                                    step_i=epoch_i + 1,
                                    name='average_bleu')
        average_sequences = getattr(self, "average_sequences", None)
        if average_sequences is not None:
            self.writer.update_loss(loss=average_sequences,
                                    step_i=epoch_i + 1,
                                    name='average_sequences')
        average_levenshteins = getattr(self, "average_levenshteins", None)
        if average_levenshteins is not None:
            self.writer.update_loss(loss=average_levenshteins,
                                    step_i=epoch_i + 1,
                                    name='average_levenshteins')

    @time_desc_decorator('Training Start!')
    def train(self):
        """Run the full training loop from self.epoch_i to config.n_epoch.

        Returns:
            List of per-epoch average losses (loss per word).
        """
        epoch_loss_history = []
        for epoch_i in range(self.epoch_i, self.config.n_epoch):
            self.epoch_i = epoch_i
            batch_loss_history = []
            self.model.train()
            n_total_words = 0
            for batch_i, (conversations, conversation_length, sentence_length) in enumerate(
                    tqdm(self.train_data_loader, ncols=80)):
                # conversations: (batch_size) list of conversations
                #   conversation: list of sentences
                #   sentence: list of tokens
                # conversation_length: list of int
                # sentence_length: (batch_size) list of conversation list of sentence_lengths

                # Teacher forcing: inputs are every sentence but the last,
                # targets are every sentence but the first (shifted by one).
                input_conversations = [conv[:-1] for conv in conversations]
                target_conversations = [conv[1:] for conv in conversations]

                # flatten input and target conversations
                input_sentences = [
                    sent for conv in input_conversations for sent in conv
                ]
                target_sentences = [
                    sent for conv in target_conversations for sent in conv
                ]
                input_sentence_length = [
                    l for len_list in sentence_length for l in len_list[:-1]
                ]
                target_sentence_length = [
                    l for len_list in sentence_length for l in len_list[1:]
                ]
                input_conversation_length = [
                    l - 1 for l in conversation_length
                ]

                input_sentences = to_var(torch.LongTensor(input_sentences))
                target_sentences = to_var(torch.LongTensor(target_sentences))
                input_sentence_length = to_var(
                    torch.LongTensor(input_sentence_length))
                target_sentence_length = to_var(
                    torch.LongTensor(target_sentence_length))
                input_conversation_length = to_var(
                    torch.LongTensor(input_conversation_length))

                # reset gradient
                self.optimizer.zero_grad()

                sentence_logits = self.model(input_sentences,
                                             input_sentence_length,
                                             input_conversation_length,
                                             target_sentences,
                                             decode=False)

                # Summed (not averaged) cross entropy plus the word count,
                # so per-word loss can be computed over the whole epoch.
                batch_loss, n_words = masked_cross_entropy(
                    sentence_logits, target_sentences, target_sentence_length)

                assert not isnan(batch_loss.item())
                batch_loss_history.append(batch_loss.item())
                n_total_words += n_words.item()

                if batch_i % self.config.print_every == 0:
                    tqdm.write(
                        f'Epoch: {epoch_i+1}, iter {batch_i}: loss = {batch_loss.item()/ n_words.item():.3f}'
                    )

                # Back-propagation
                batch_loss.backward()

                # Gradient cliping
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.config.clip)

                # Run optimizer
                self.optimizer.step()

            # Per-word epoch loss (batch losses are word-summed).
            epoch_loss = np.sum(batch_loss_history) / n_total_words
            epoch_loss_history.append(epoch_loss)
            self.epoch_loss = epoch_loss

            print_str = f'Epoch {epoch_i+1} loss average: {epoch_loss:.3f}'
            print(print_str)

            if epoch_i % self.config.save_every_epoch == 0:
                self.save_model(epoch_i + 1)

            print('\n<Validation>...')
            self.validation_loss = self.evaluate()

            if epoch_i % self.config.plot_every_epoch == 0:
                self.write_summary(epoch_i)

        self.save_model(self.config.n_epoch)

        return epoch_loss_history

    def generate_sentence(self, input_sentences, input_sentence_length,
                          input_conversation_length, target_sentences):
        """Decode responses for a batch and append input/target/generated
        triples to ``<save_path>/samples.txt`` (also echoed to stdout)."""
        self.model.eval()

        # [batch_size, max_seq_len, vocab_size]
        generated_sentences = self.model(input_sentences,
                                         input_sentence_length,
                                         input_conversation_length,
                                         target_sentences,
                                         decode=True)

        # write output to file
        with open(os.path.join(self.config.save_path, 'samples.txt'),
                  'a') as f:
            f.write(f'<Epoch {self.epoch_i}>\n\n')

            tqdm.write('\n<Samples>')
            for input_sent, target_sent, output_sent in zip(
                    input_sentences, target_sentences, generated_sentences):
                input_sent = self.vocab.decode(input_sent)
                target_sent = self.vocab.decode(target_sent)
                # A sample may contain several decoded candidates; show one
                # per line.
                output_sent = '\n'.join(
                    [self.vocab.decode(sent) for sent in output_sent])
                s = '\n'.join([
                    'Input sentence: ' + input_sent,
                    'Ground truth: ' + target_sent,
                    'Generated response: ' + output_sent + '\n'
                ])
                f.write(s + '\n')
                print(s)
            print('')

    def evaluate(self):
        """Compute per-word validation loss over eval_data_loader; generates
        samples from the first batch as a side effect.

        Returns:
            Per-word validation loss (float).
        """
        self.model.eval()
        batch_loss_history = []
        n_total_words = 0
        for batch_i, (conversations, conversation_length, sentence_length) in enumerate(
                tqdm(self.eval_data_loader, ncols=80)):
            # conversations: (batch_size) list of conversations
            #   conversation: list of sentences
            #   sentence: list of tokens
            # conversation_length: list of int
            # sentence_length: (batch_size) list of conversation list of sentence_lengths
            input_conversations = [conv[:-1] for conv in conversations]
            target_conversations = [conv[1:] for conv in conversations]

            # flatten input and target conversations
            input_sentences = [
                sent for conv in input_conversations for sent in conv
            ]
            target_sentences = [
                sent for conv in target_conversations for sent in conv
            ]
            input_sentence_length = [
                l for len_list in sentence_length for l in len_list[:-1]
            ]
            target_sentence_length = [
                l for len_list in sentence_length for l in len_list[1:]
            ]
            input_conversation_length = [l - 1 for l in conversation_length]

            # Tensor construction only — no autograd graph needed here.
            with torch.no_grad():
                input_sentences = to_var(torch.LongTensor(input_sentences))
                target_sentences = to_var(torch.LongTensor(target_sentences))
                input_sentence_length = to_var(
                    torch.LongTensor(input_sentence_length))
                target_sentence_length = to_var(
                    torch.LongTensor(target_sentence_length))
                input_conversation_length = to_var(
                    torch.LongTensor(input_conversation_length))

            # NOTE(review): sample generation is done once per evaluation,
            # on the first batch only.
            if batch_i == 0:
                self.generate_sentence(input_sentences, input_sentence_length,
                                       input_conversation_length,
                                       target_sentences)

            sentence_logits = self.model(input_sentences,
                                         input_sentence_length,
                                         input_conversation_length,
                                         target_sentences)

            batch_loss, n_words = masked_cross_entropy(sentence_logits,
                                                       target_sentences,
                                                       target_sentence_length)

            assert not isnan(batch_loss.item())
            batch_loss_history.append(batch_loss.item())
            n_total_words += n_words.item()

        epoch_loss = np.sum(batch_loss_history) / n_total_words

        print_str = f'Validation loss: {epoch_loss:.3f}\n'
        print(print_str)

        return epoch_loss

    def test(self):
        """Compute test-set bits-per-word and report word perplexity.

        Returns:
            Word perplexity, exp(per-word loss).
        """
        self.model.eval()
        batch_loss_history = []
        n_total_words = 0
        for batch_i, (conversations, conversation_length, sentence_length) in enumerate(
                tqdm(self.eval_data_loader, ncols=80)):
            # conversations: (batch_size) list of conversations
            #   conversation: list of sentences
            #   sentence: list of tokens
            # conversation_length: list of int
            # sentence_length: (batch_size) list of conversation list of sentence_lengths
            input_conversations = [conv[:-1] for conv in conversations]
            target_conversations = [conv[1:] for conv in conversations]

            # flatten input and target conversations
            input_sentences = [
                sent for conv in input_conversations for sent in conv
            ]
            target_sentences = [
                sent for conv in target_conversations for sent in conv
            ]
            input_sentence_length = [
                l for len_list in sentence_length for l in len_list[:-1]
            ]
            target_sentence_length = [
                l for len_list in sentence_length for l in len_list[1:]
            ]
            input_conversation_length = [l - 1 for l in conversation_length]

            with torch.no_grad():
                input_sentences = to_var(torch.LongTensor(input_sentences))
                target_sentences = to_var(torch.LongTensor(target_sentences))
                input_sentence_length = to_var(
                    torch.LongTensor(input_sentence_length))
                target_sentence_length = to_var(
                    torch.LongTensor(target_sentence_length))
                input_conversation_length = to_var(
                    torch.LongTensor(input_conversation_length))

            sentence_logits = self.model(input_sentences,
                                         input_sentence_length,
                                         input_conversation_length,
                                         target_sentences)

            batch_loss, n_words = masked_cross_entropy(sentence_logits,
                                                       target_sentences,
                                                       target_sentence_length)

            assert not isnan(batch_loss.item())
            batch_loss_history.append(batch_loss.item())
            n_total_words += n_words.item()

        epoch_loss = np.sum(batch_loss_history) / n_total_words

        print(f'Number of words: {n_total_words}')
        print(f'Bits per word: {epoch_loss:.3f}')
        word_perplexity = np.exp(epoch_loss)

        print_str = f'Word perplexity : {word_perplexity:.3f}\n'
        print(print_str)

        return word_perplexity

    def embedding_metric(self):
        """Generate continuations for eval conversations and score them
        against ground truth with word2vec embedding metrics.

        Uses the first `n_context` sentences of each sufficiently long
        conversation as context, generates `n_sample_step` responses, and
        computes average/extrema/greedy embedding similarity.

        Returns:
            Tuple (epoch_average, epoch_extrema, epoch_greedy).
        """
        # Lazy-load and cache the word2vec model on first use.
        word2vec = getattr(self, 'word2vec', None)
        if word2vec is None:
            print('Loading word2vec model')
            word2vec = gensim.models.KeyedVectors.load_word2vec_format(
                word2vec_path, binary=True)
            self.word2vec = word2vec
        keys = word2vec.vocab
        self.model.eval()
        n_context = self.config.n_context
        n_sample_step = self.config.n_sample_step
        metric_average_history = []
        metric_extrema_history = []
        metric_greedy_history = []
        context_history = []
        sample_history = []
        n_sent = 0
        n_conv = 0  # NOTE(review): incremented nowhere; kept for compatibility
        for batch_i, (conversations, conversation_length, sentence_length) \
                in enumerate(tqdm(self.eval_data_loader, ncols=80)):
            # conversations: (batch_size) list of conversations
            #   conversation: list of sentences
            #   sentence: list of tokens
            # conversation_length: list of int
            # sentence_length: (batch_size) list of conversation list of sentence_lengths

            # Keep only conversations long enough for context + continuation.
            conv_indices = [
                i for i in range(len(conversations))
                if len(conversations[i]) >= n_context + n_sample_step
            ]
            context = [
                c for i in conv_indices
                for c in [conversations[i][:n_context]]
            ]
            ground_truth = [
                c for i in conv_indices
                for c in [conversations[i][n_context:n_context + n_sample_step]]
            ]
            sentence_length = [
                c for i in conv_indices
                for c in [sentence_length[i][:n_context]]
            ]

            with torch.no_grad():
                context = to_var(torch.LongTensor(context))
                sentence_length = to_var(torch.LongTensor(sentence_length))

            samples = self.model.generate(context, sentence_length, n_context)

            context = context.data.cpu().numpy().tolist()
            samples = samples.data.cpu().numpy().tolist()
            context_history.append(context)
            sample_history.append(samples)

            samples = [[self.vocab.decode(sent) for sent in c]
                       for c in samples]
            ground_truth = [[self.vocab.decode(sent) for sent in c]
                            for c in ground_truth]

            samples = [sent for c in samples for sent in c]
            ground_truth = [sent for c in ground_truth for sent in c]

            # Map each sentence to its list of word vectors, dropping OOV words.
            samples = [[word2vec[s] for s in sent.split() if s in keys]
                       for sent in samples]
            ground_truth = [[word2vec[s] for s in sent.split() if s in keys]
                            for sent in ground_truth]

            # Keep only pairs where both sides still have at least one vector.
            indices = [
                i for i, s, g in zip(range(len(samples)), samples, ground_truth)
                if s != [] and g != []
            ]
            samples = [samples[i] for i in indices]
            ground_truth = [ground_truth[i] for i in indices]
            n = len(samples)
            n_sent += n

            metric_average = embedding_metric(samples, ground_truth, word2vec,
                                              'average')
            metric_extrema = embedding_metric(samples, ground_truth, word2vec,
                                              'extrema')
            metric_greedy = embedding_metric(samples, ground_truth, word2vec,
                                             'greedy')
            metric_average_history.append(metric_average)
            metric_extrema_history.append(metric_extrema)
            metric_greedy_history.append(metric_greedy)

        epoch_average = np.mean(np.concatenate(metric_average_history), axis=0)
        epoch_extrema = np.mean(np.concatenate(metric_extrema_history), axis=0)
        epoch_greedy = np.mean(np.concatenate(metric_greedy_history), axis=0)

        print('n_sentences:', n_sent)
        print_str = f'Metrics - Average: {epoch_average:.3f}, Extrema: {epoch_extrema:.3f}, Greedy: {epoch_greedy:.3f}'
        print(print_str)
        print('\n')

        return epoch_average, epoch_extrema, epoch_greedy
class Solver(object):
    """Builds, trains and evaluates a bidirectional SC-LSTM keyword-to-text
    generator: a forward and a backward semantically-conditioned LSTM share a
    keyword-detector gate; generation uses beam search with backward-model
    reranking."""

    def __init__(self, config=None, train_loader=None, test_loader=None,
                 valid_loader=None):
        """Class that Builds, Trains and Evaluates SCLSTM model"""
        # NOTE(review): `valid_loader` is accepted but never stored/used.
        self.config = config
        self.train_loader = train_loader
        self.test_loader = test_loader
        os.environ["CUDA_VISIBLE_DEVICES"] = self.config.gpu
        self.vocab = pickle.load(open(p.word_vocab_pkl, 'rb'))
        self.kvoc = pickle.load(open(p.kwd_pkl, 'rb'))
        self.i2w = {i: w for i, w in enumerate(self.vocab)}  # index to vocab
        self.i2k = {i: k for i, k in enumerate(self.kvoc)}  # index to keyword
        self.w2i = {w: i for i, w in self.i2w.items()}  # vocab to index

    def build(self):
        """Construct embedding, gate projections, forward/backward SC-LSTM
        cells, loss, optimizer, and (in train mode) init weights and output
        dirs."""
        # Build Modules
        self.embedding = nn.Embedding(self.config.vocab_size,
                                      self.config.wemb_size,
                                      padding_idx=0)
        if True:
            weights_matrix = torch.FloatTensor(
                pickle.load(open(p.word_vec_pkl, 'rb')))
            # NOTE(review): nn.Embedding.from_pretrained is a classmethod that
            # RETURNS a new module; calling it on the instance like this does
            # not load `weights_matrix` into self.embedding — confirm intent
            # (likely: self.embedding = nn.Embedding.from_pretrained(...)).
            self.embedding.from_pretrained(weights_matrix, freeze=False)
            self.embedding.weight.requires_grad = True
        # NOTE(review): `num_layers * [nn.Linear(...)]` repeats the SAME
        # Linear object, so all layers share one weight matrix — confirm this
        # sharing is intended.
        self.w_hr_fw = nn.ModuleList(self.config.num_layers * [
            nn.Linear(
                self.config.hidden_size, self.config.kwd_size, bias=False)
        ])
        self.w_hr_bw = nn.ModuleList(self.config.num_layers * [
            nn.Linear(
                self.config.hidden_size, self.config.kwd_size, bias=False)
        ])
        self.w_wr = nn.Linear(self.config.wemb_size,
                              self.config.kwd_size,
                              bias=False)
        self.w_ho_fw = nn.Sequential(
            nn.Linear(self.config.hidden_size * self.config.num_layers,
                      self.config.vocab_size),
            # nn.LogSoftmax(dim=-1)  # deliberately omitted: outputs raw logits
        )
        self.w_ho_bw = nn.Linear(
            self.config.hidden_size * self.config.num_layers,
            self.config.vocab_size)
        self.sc_rnn_fw = SCLSTM_MultiCell(self.config.num_layers,
                                          self.config.wemb_size,
                                          self.config.hidden_size,
                                          self.config.kwd_size,
                                          dropout=self.config.drop_rate)
        self.sc_rnn_bw = SCLSTM_MultiCell(self.config.num_layers,
                                          self.config.wemb_size,
                                          self.config.hidden_size,
                                          self.config.kwd_size,
                                          dropout=self.config.drop_rate)
        self.model = nn.ModuleList([
            self.w_hr_fw, self.w_hr_bw, self.w_wr, self.w_ho_fw, self.w_ho_bw,
            self.sc_rnn_fw, self.sc_rnn_bw
        ])
        # Per-token losses (reduction='none') so padding can be masked out.
        self.criterion = nn.CrossEntropyLoss(reduction='none')
        with torch.no_grad():
            # Zero (h, c) initial state reused by every training batch.
            self.hc_list_init = (Variable(torch.zeros(self.config.num_layers,
                                                      self.config.batch_size,
                                                      self.config.hidden_size),
                                          requires_grad=False),
                                 Variable(torch.zeros(self.config.num_layers,
                                                      self.config.batch_size,
                                                      self.config.hidden_size),
                                          requires_grad=False))

        #--- Init dirs for output ---
        self.current_time = datetime.now().strftime('%b%d_%H-%M-%S')
        if self.config.mode == 'train':
            # Overview Parameters
            print('Init Model Parameters')
            for name, param in self.model.named_parameters():
                print('\t' + name + '\t', list(param.size()))
                if param.data.ndimension() >= 2:
                    nn.init.xavier_uniform_(param.data)
                else:
                    nn.init.zeros_(param.data)

            # Tensorboard
            self.writer = TensorboardWriter(p.tb_dir + self.current_time)
            self.model.train()

            # create dir
            self.cp_dir = p.check_point.format(
                p.dataname, self.current_time)  # checkpoint dir
            os.makedirs(self.cp_dir)

        #--- Setup output file ---
        # NOTE(review): nesting reconstructed — assumed at method level so
        # evaluation output works outside train mode; confirm against VCS.
        self.out_file = open(
            p.out_result_dir.format(p.dataname, self.current_time), 'w')

        # Add emb-layer so its weights are trained/saved with the rest.
        self.model.append(self.embedding)

        # Build Optimizers
        self.optimizer = optim.Adam(list(self.model.parameters()),
                                    lr=self.config.lr)
        print(self.model)

    def load_model(self, ep):
        """Restore model/optimizer state from checkpoint of epoch `ep`.

        Reads from cp_dir in train mode, else from config.resume_dir.
        """
        _fname = (self.cp_dir if self.config.mode == 'train' else
                  self.config.resume_dir) + 'chk_point_{}.pth'.format(ep)
        if os.path.isfile(_fname):
            print("=> loading checkpoint '{}'".format(_fname))
            if self.config.load_cpu:
                checkpoint = torch.load(
                    _fname,
                    map_location=lambda storage, loc: storage)  # load into cpu-mode
            else:
                checkpoint = torch.load(_fname)  # gpu-mode
            self.start_epoch = checkpoint['epoch']
            self.model.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'][0])
        else:
            print("=> no checkpoint found at '{}'".format(_fname))

    def _zero_grads(self):
        """Clear accumulated gradients before a new batch."""
        self.optimizer.zero_grad()

    def save_checkpoint(self, state, filename):
        """Serialize a checkpoint dict to `filename`."""
        torch.save(state, filename)

    def get_norm_grad(self, module, norm_type=2):
        """Return the global L2 norm of all gradients in `module`.

        NOTE(review): `norm_type` is accepted but the sum is hard-coded to
        squares (L2) regardless of its value.
        """
        total_norm = 0
        for name, param in module.named_parameters():
            if param.grad is not None:
                total_norm += torch.sum(torch.pow(param.grad.view(-1), 2))
        return torch.sqrt(total_norm).data

    def one_step_fw(self, w_t, y_t, hc_list, d_t, rnn_model, w_hr, w_ho):
        """One teacher-forced step of an SC-LSTM (used for both directions).

        Args:
            w_t: current word embedding input.
            y_t: target token ids for this step (0 treated as padding).
            hc_list: (h, c) state tuple from the previous step.
            d_t: current keyword-vector state.
            rnn_model, w_hr, w_ho: direction-specific cell / gate / output heads.

        Returns:
            (llk_step, l1_step, pred, hc_list, d_t): masked mean CE loss,
            L1 penalty on the keyword vector, logits, new state, new d_t.
        """
        h_tm1, _ = hc_list
        #--- Keyword detector ---
        # Gate r_t decides how much of each keyword remains to be expressed.
        res_hr = sum(
            [w_hr[l](h_tm1[l]) for l in range(self.config.num_layers)])
        r_t = torch.sigmoid(self.w_wr(w_t) + self.config.alpha * res_hr)
        d_t = r_t * d_t
        flat_h, hc_list = rnn_model(w_t, hc_list, d_t)
        with torch.no_grad():
            # Mask out padding targets (id 0) from the loss.
            mask = Variable((y_t != 0).float(), requires_grad=False)
            assert not torch.isnan(mask).any()
        pred = w_ho(flat_h)
        llk_step = torch.mean(self.criterion(pred, y_t) * mask)
        # Encourage the keyword vector to drain to zero (keywords consumed).
        l1_step = torch.mean(torch.sum(torch.abs(d_t), dim=-1))
        assert not torch.isnan(llk_step).any()
        assert not torch.isnan(l1_step).any()
        return llk_step, l1_step, pred, hc_list, d_t

    def train_epoch(self):
        """Train one epoch: forward pass over the document left-to-right, then
        right-to-left, backpropagating each direction's loss separately.

        Returns:
            (loss_list, l1_list, fw_list, bw_list) per-batch loss traces.
        """
        loss_list = []
        l1_list = []
        fw_list, bw_list = [], []
        for batch_i, doc_features in enumerate(
                tqdm(self.train_loader,
                     desc='Batch',
                     dynamic_ncols=True,
                     ascii=True)):
            self._zero_grads()
            doc, kwd = doc_features
            with torch.no_grad():
                var_doc = Variable(doc, requires_grad=False)
                var_kwd = Variable(kwd, requires_grad=False)
            doc_emb = self.embedding(var_doc)  # get word-emb

            #--- Word generation ---
            step_loss = []
            step_l1 = []
            #--- FW Stage ---
            hc_list = self.hc_list_init
            d_t = var_kwd
            for t in range(p.MAX_DOC_LEN - 1):
                w_t = doc_emb[:, t, :]
                y_t = var_doc[:, t + 1]  # next-word target
                llk_step, l1_step, pred, hc_list, d_t = self.one_step_fw(
                    w_t, y_t, hc_list, d_t, self.sc_rnn_fw, self.w_hr_fw,
                    self.w_ho_fw)
                # NOTE(review): p_pred/w_pred are computed but unused (debug
                # leftover).
                p_pred, w_pred = torch.max(nn.LogSoftmax(dim=-1)(pred),
                                           dim=-1)
                step_loss.append(llk_step)
                step_l1.append(l1_step)
            fw_loss = sum(step_loss)
            fw_l1 = sum(step_l1) * self.config.eta
            batch_loss = fw_loss + fw_l1
            # retain_graph: doc_emb is reused by the backward stage below.
            batch_loss.backward(retain_graph=True)

            #--- BW Stage ---
            torch.cuda.empty_cache()
            step_loss = []
            step_l1 = []
            hc_list = self.hc_list_init
            d_t = var_kwd
            for t in range(p.MAX_DOC_LEN - 1, 0, -1):
                w_t = doc_emb[:, t, :]
                y_t = var_doc[:, t - 1]  # previous-word target
                llk_step, l1_step, pred, hc_list, d_t = self.one_step_fw(
                    w_t, y_t, hc_list, d_t, self.sc_rnn_bw, self.w_hr_bw,
                    self.w_ho_bw)
                step_loss.append(llk_step)
                step_l1.append(l1_step)
            bw_loss = sum(step_loss)
            bw_l1 = sum(step_l1) * self.config.eta

            #--- BW for learning ---
            batch_loss = bw_loss + bw_l1
            batch_loss.backward(retain_graph=True)
            # Single optimizer step applies accumulated FW + BW gradients.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           self.config.clip)
            self.optimizer.step()

            #--- tracking loss ---
            loss_list.append(0.5 * (fw_loss + bw_loss).cpu().data.numpy())
            l1_list.append(0.5 * (fw_l1 + bw_l1).cpu().data.numpy())
            fw_list.append(fw_loss.cpu().data.numpy())
            bw_list.append(bw_loss.cpu().data.numpy())
        return loss_list, l1_list, fw_list, bw_list

    def train(self):
        """Outer training loop: run epochs, checkpoint/evaluate every
        `eval_rate` epochs, and log running means to TensorBoard."""
        print('***Start training ...')
        for epoch_i in tqdm(range(self.config.n_epoch),
                            desc='Epoch',
                            dynamic_ncols=True,
                            ascii=True):
            loss_list, l1_list, fw_list, bw_list = self.train_epoch()

            # Save parameters at checkpoint
            if (epoch_i + 1) % self.config.eval_rate == 0:
                #--- Dump model ---
                if self.config.write_model:
                    # save model
                    self.save_checkpoint(
                        {
                            'epoch': epoch_i + 1,
                            'state_dict': self.model.state_dict(),
                            'total_loss': np.mean(loss_list),
                            'optimizer': [self.optimizer.state_dict()],
                        },
                        filename=self.cp_dir +
                        'chk_point_{}.pth'.format(epoch_i + 1))
                #--- Eval each step ---
                if self.config.is_eval:
                    self.evaluate(epoch_i + 1)

            print(
                '\n***Ep-{} | Total_loss: {} [FW/BW {}/{}] | D-L1: {} | NORM: {}'
                .format(epoch_i, np.mean(loss_list), np.mean(fw_list),
                        np.mean(bw_list), np.mean(l1_list),
                        self.get_norm_grad(self.model)))
            self.writer.update_loss(np.mean(loss_list), epoch_i, 'total_loss')
            self.writer.update_loss(np.mean(l1_list), epoch_i, 'l1_reg')
            self.writer.update_loss(np.mean(fw_list), epoch_i, 'fw_loss')
            self.writer.update_loss(np.mean(bw_list), epoch_i, 'bw_loss')

    def gen_one_step(self, x, hc_list, d_t, rnn_model, w_hr, w_ho):
        """One generation step: feed token ids `x`, return log-prob row(s),
        the new (h, c) state moved to CPU, and the updated keyword vector."""
        with torch.no_grad():
            var_x = Variable(torch.LongTensor(x), requires_grad=False)
            d_t = Variable(d_t, requires_grad=False)
            hc_list = self.to_gpu(hc_list)
        w_t = self.embedding(var_x)
        h_tm1, _ = hc_list
        # Same keyword-detector gating as the training step.
        res_hr = sum(
            [w_hr[l](h_tm1[l]) for l in range(self.config.num_layers)])
        r_t = torch.sigmoid(self.w_wr(w_t) + self.config.alpha * res_hr)
        d_t = r_t * d_t
        flat_h, hc_list = rnn_model(w_t, hc_list, d_t)
        _prob = nn.LogSoftmax(dim=-1)(w_ho(flat_h))
        return _prob.detach().cpu().numpy().squeeze(), self.to_cpu(
            hc_list), d_t.detach().cpu()

    def get_top_index(self, _prob):
        """Pick candidate token indices from a log-prob vector: sampled
        without replacement when config.is_sample, else greedy argsort."""
        # [b, vocab]
        _prob = np.exp(_prob)
        if self.config.is_sample:
            top_indices = np.random.choice(self.config.vocab_size,
                                           self.config.beam_size,
                                           replace=False,
                                           p=_prob.reshape(-1))
        else:
            top_indices = np.argsort(-_prob)
        return top_indices

    def to_cpu(self, _list):
        """Detach a state tuple and move it to CPU."""
        return tuple([m.detach().cpu() for m in _list])

    def to_gpu(self, _list):
        """Wrap a state tuple in no-grad Variables (naming suggests a GPU
        move, but no .cuda() call is made here)."""
        return tuple([Variable(m, requires_grad=False) for m in _list])

    def rerank(self, beams, d_t):
        """Rescore beams by averaging each forward score with the backward
        model's per-token log-likelihood of the reversed word sequence."""

        def add_bw_score(w_list, d_t):
            with torch.no_grad():
                hc_list = (torch.zeros(self.config.num_layers, 1,
                                       self.config.hidden_size),
                           torch.zeros(self.config.num_layers, 1,
                                       self.config.hidden_size))
            w_list = [self.w2i[w] for w in w_list[::-1]]
            llk = 0.
            for i, w in enumerate(w_list[:-1]):
                _prob, hc_list, d_t = self.gen_one_step([w], hc_list, d_t,
                                                        self.sc_rnn_bw,
                                                        self.w_hr_bw,
                                                        self.w_ho_bw)
                llk += _prob[w_list[i + 1]]
            # Length-normalized backward log-likelihood.
            return llk / (len(w_list) - 1)

        for i, b in enumerate(beams):
            # beam tuple: (score, words, last_token_ids, hc_state, d_t)
            beams[i] = tuple([0.5 * (b[0] + add_bw_score(b[1], d_t))]) + tuple(
                b[1:])
        return beams

    def evaluate(self, epoch_i):
        """Generate one sentence per test keyword set with beam search,
        rerank with the backward model, and write the best hypothesis to
        the output file."""
        #--- load model ---
        self.load_model(epoch_i)
        self.model.eval()
        for r_id, doc_features in enumerate(
                tqdm(self.test_loader,
                     desc='Test',
                     dynamic_ncols=True,
                     ascii=True)):
            _, d_t = doc_features
            try:
                # Skip items with no active keywords.
                if torch.sum(d_t) == 0:
                    continue
                #--- Gen 1st step ---
                with torch.no_grad():
                    hc_list = (torch.zeros(self.config.num_layers, 1,
                                           self.config.hidden_size),
                               torch.zeros(self.config.num_layers, 1,
                                           self.config.hidden_size))
                # Beam tuple: (score, words, last_token_ids, hc_state, d_t);
                # token id 1 is used as start-of-sentence here.
                b = (0.0, [self.i2w[1]], [1], hc_list, d_t)
                _prob, hc_list, d_t = self.gen_one_step(
                    b[2], b[3], b[4], self.sc_rnn_fw, self.w_hr_fw,
                    self.w_ho_fw)
                top_indices = self.get_top_index(_prob)
                beam_candidates = []
                for i in range(self.config.beam_size):
                    wordix = top_indices[i]
                    beam_candidates.append(
                        (b[0] + _prob[wordix], b[1] + [self.i2w[wordix]],
                         [wordix], hc_list, d_t))
                #--- Gen the whole sentence ---
                beams = beam_candidates[:self.config.beam_size]
                for t in range(self.config.gen_size - 1):
                    beam_candidates = []
                    for b in beams:
                        _prob, hc_list, d_t = self.gen_one_step(
                            b[2], b[3], b[4], self.sc_rnn_fw, self.w_hr_fw,
                            self.w_ho_fw)
                        top_indices = self.get_top_index(_prob)
                        for i in range(self.config.beam_size):
                            #--- already EOS --- (token id 2)
                            if b[2] == [2]:
                                beam_candidates.append(b)
                                break
                            wordix = top_indices[i]
                            beam_candidates.append(
                                (b[0] + _prob[wordix],
                                 b[1] + [self.i2w[wordix]], [wordix], hc_list,
                                 d_t))
                    # Rank by length-normalized score.
                    beam_candidates.sort(key=lambda x: x[0] / (len(x[1]) - 1),
                                         reverse=True)  # decreasing order
                    beams = beam_candidates[:self.config.
                                            beam_size]  # truncate to get new beams
                #--- RERANK beams ---
                beams = self.rerank(beams, doc_features[1])
                beams.sort(key=lambda x: x[0], reverse=True)
                res = "[*]EP_{}_KW_[{}]_SENT_[{}]\n".format(
                    epoch_i, ' '.join([
                        self.i2k[int(j)] for j in torch.flatten(
                            torch.nonzero(doc_features[1][0])).numpy()
                    ]), ' '.join(beams[0][1]))
                print(res)
                self.out_file.write(res)
                self.out_file.flush()
            except Exception as e:
                # Best-effort: one bad item must not stop the whole eval run.
                print('Exception: ', str(e))
                pass
        self.model.train()
class Solver(object):
    """Base solver skeleton: owns config/data/model state, handles weight
    initialization, checkpoint save/restore (including DataParallel
    'module.'-prefixed checkpoints) and TensorBoard summaries. Subclasses
    supply the actual train/evaluate/test/export loops."""

    def __init__(self, config, train_data_loader, eval_data_loader, vocab,
                 is_train=True, model=None):
        self.config = config
        self.epoch_i = 0
        self.train_data_loader = train_data_loader
        self.eval_data_loader = eval_data_loader
        self.vocab = vocab
        self.is_train = is_train
        self.model = model
        # Populated later by build() / the training loop.
        self.writer = None
        self.optimizer = None
        self.epoch_loss = None
        self.validation_loss = None

    def build(self, cuda=True):
        """Create the model if needed, initialize recurrent weights for a
        fresh training run, optionally restore a checkpoint, and set up the
        writer/optimizer when training."""
        if self.model is None:
            self.model = getattr(models, self.config.model)(self.config)

            fresh_run = (self.config.mode == 'train'
                         and self.config.checkpoint is None)
            if fresh_run:
                print('Parameter initiailization')
                for name, param in self.model.named_parameters():
                    if 'weight_hh' in name:
                        print('\t' + name)
                        nn.init.orthogonal_(param)
                    if 'bias_hh' in name:
                        # bias_hh concatenates the three GRU gate biases;
                        # push the middle (input/update) gate bias to 2.0.
                        print('\t' + name)
                        dim = int(param.size(0) / 3)
                        param.data[dim:2 * dim].fill_(2.0)

        if torch.cuda.is_available() and cuda:
            self.model.cuda()

        if self.config.checkpoint:
            self.load_model(self.config.checkpoint)

        if self.is_train:
            self.writer = TensorboardWriter(self.config.logdir)
            trainable = filter(lambda p: p.requires_grad,
                               self.model.parameters())
            self.optimizer = self.config.optimizer(
                trainable, lr=self.config.learning_rate)

    def save_model(self, epoch):
        """Write the model state dict to ``<save_path>/<epoch>.pkl``."""
        ckpt_path = os.path.join(self.config.save_path, f'{epoch}.pkl')
        print(f'Save parameters to {ckpt_path}')
        torch.save(self.model.state_dict(), ckpt_path)

    def load_model(self, checkpoint):
        """Restore model weights from `checkpoint`, recovering the epoch
        counter from the filename's leading digits and stripping any
        DataParallel 'module.' key prefix."""
        print(f'Load parameters from {checkpoint}')
        epoch = re.match(r"[0-9]*", os.path.basename(checkpoint)).group(0)
        self.epoch_i = int(epoch)
        chpt = torch.load(checkpoint)
        new_state_dict = OrderedDict(
            (key[7:] if key.startswith("module.") else key, tensor)
            for key, tensor in chpt.items())
        self.model.load_state_dict(new_state_dict)

    def write_summary(self, epoch_i):
        """Send each tracked scalar that has been set on `self` to
        TensorBoard, logged at step ``epoch_i + 1``; unset metrics are
        skipped."""
        tracked = (
            ('epoch_loss', 'train_loss'),
            ('epoch_recon_loss', 'train_recon_loss'),
            ('epoch_kl_div', 'train_kl_div'),
            ('kl_mult', 'kl_mult'),
            ('epoch_bow_loss', 'bow_loss'),
            ('validation_loss', 'validation_loss'),
        )
        for attr, tag in tracked:
            value = getattr(self, attr, None)
            if value is not None:
                self.writer.update_loss(loss=value,
                                        step_i=epoch_i + 1,
                                        name=tag)

    def train(self):
        """Training loop — must be provided by a subclass."""
        raise NotImplementedError

    def evaluate(self):
        """Validation loop — must be provided by a subclass."""
        raise NotImplementedError

    def test(self):
        """Test loop — must be provided by a subclass."""
        raise NotImplementedError

    def export_samples(self, beam_size=5):
        """Sample export — must be provided by a subclass."""
        raise NotImplementedError
class Solver(object):
    def __init__(self, config=None, train_loader=None, test_loader=None):
        """Class that Builds, Trains and Evaluates the AC-SUM-GAN model."""
        self.config = config
        self.train_loader = train_loader
        self.test_loader = test_loader

    def build(self):
        """Instantiate all sub-modules on the GPU and, in 'train' mode,
        their optimizers and the Tensorboard writer."""
        # Build Modules
        self.linear_compress = nn.Linear(self.config.input_size,
                                         self.config.hidden_size).cuda()
        self.summarizer = Summarizer(input_size=self.config.hidden_size,
                                     hidden_size=self.config.hidden_size,
                                     num_layers=self.config.num_layers).cuda()
        self.discriminator = Discriminator(
            input_size=self.config.hidden_size,
            hidden_size=self.config.hidden_size,
            num_layers=self.config.num_layers).cuda()
        self.actor = Actor(state_size=self.config.action_state_size,
                           action_size=self.config.action_state_size).cuda()
        self.critic = Critic(state_size=self.config.action_state_size,
                             action_size=self.config.action_state_size).cuda()
        # ModuleList so a single state_dict covers every sub-module.
        self.model = nn.ModuleList([
            self.linear_compress, self.summarizer, self.discriminator,
            self.actor, self.critic
        ])
        if self.config.mode == 'train':
            # Build Optimizers (one per adversarially-trained component).
            self.e_optimizer = optim.Adam(
                self.summarizer.vae.e_lstm.parameters(), lr=self.config.lr)
            self.d_optimizer = optim.Adam(
                self.summarizer.vae.d_lstm.parameters(), lr=self.config.lr)
            self.c_optimizer = optim.Adam(
                list(self.discriminator.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.discriminator_lr)
            self.optimizerA_s = optim.Adam(
                list(self.actor.parameters()) +
                list(self.summarizer.s_lstm.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.lr)
            self.optimizerC = optim.Adam(self.critic.parameters(),
                                         lr=self.config.lr)
            self.writer = TensorboardWriter(str(self.config.log_dir))

    def reconstruction_loss(self, h_origin, h_sum):
        """L2 loss between original-regenerated features at cLSTM's last hidden layer."""
        return torch.norm(h_origin - h_sum, p=2)

    def prior_loss(self, mu, log_variance):
        """KL( q(e|x) || N(0,1) )"""
        return 0.5 * torch.sum(-1 + log_variance.exp() + mu.pow(2) -
                               log_variance)

    def sparsity_loss(self, scores):
        """Summary-Length Regularization: keep the mean frame score near the
        configured regularization factor."""
        return torch.abs(
            torch.mean(scores) - self.config.regularization_factor)

    # Shared MSE criterion (class attribute, shared by all instances).
    criterion = nn.MSELoss()

    def AC(self, original_features, seq_len, action_fragments):
        """Function that makes the actor's actions, in the training steps where
        the actor and critic components are not trained.

        Returns (weighted_features, weighted_scores) where per-frame scores are
        boosted for the fragments the actor selected and damped elsewhere.
        """
        scores = self.summarizer.s_lstm(original_features)  # [seq_len, 1]
        # Average the frame scores within each fragment to form the state.
        fragment_scores = np.zeros(
            self.config.action_state_size)  # [num_fragments, 1]
        for fragment in range(self.config.action_state_size):
            fragment_scores[fragment] = scores[action_fragments[
                fragment, 0]:action_fragments[fragment, 1] + 1].mean()
        state = fragment_scores
        previous_actions = [
        ]  # save all the actions (the selected fragments of each episode)
        # Baseline damping applied to every frame that is never selected.
        reduction_factor = (
            self.config.action_state_size -
            self.config.termination_point) / self.config.action_state_size
        action_scores = (torch.ones(seq_len) * reduction_factor).cuda()
        action_fragment_scores = (torch.ones(
            self.config.action_state_size)).cuda()
        counter = 0
        for ACstep in range(self.config.termination_point):
            state = torch.FloatTensor(state).cuda()
            # select an action
            dist = self.actor(state)
            action = dist.sample(
            )  # returns a scalar between 0-action_state_size
            if action not in previous_actions:
                # First time this fragment is picked: boost its frame scores.
                previous_actions.append(action)
                action_factor = (self.config.termination_point - counter) / (
                    self.config.action_state_size - counter) + 1
                action_scores[action_fragments[action,
                                               0]:action_fragments[action, 1] +
                              1] = action_factor
                action_fragment_scores[action] = 0
            counter = counter + 1
            # Zero out the chosen fragment in the state so it is not re-picked.
            next_state = state * action_fragment_scores
            next_state = next_state.cpu().detach().numpy()
            state = next_state
        weighted_scores = action_scores.unsqueeze(1) * scores
        weighted_features = weighted_scores.view(-1, 1, 1) * original_features
        return weighted_features, weighted_scores

    def train(self):
        """Adversarial training loop: per batch, train eLSTM, dLSTM, cLSTM
        (discriminator) and finally sLSTM + actor-critic, then log epoch
        averages and checkpoint.

        NOTE(review): `original_label`, `summary_label`, `device` and
        `compute_returns` are not defined in this chunk — presumably
        module-level globals; verify against the rest of the file.
        """
        step = 0
        for epoch_i in trange(self.config.n_epochs, desc='Epoch', ncols=80):
            self.model.train()
            recon_loss_init_history = []
            recon_loss_history = []
            sparsity_loss_history = []
            prior_loss_history = []
            g_loss_history = []
            e_loss_history = []
            d_loss_history = []
            c_original_loss_history = []
            c_summary_loss_history = []
            actor_loss_history = []
            critic_loss_history = []
            reward_history = []

            # Train in batches of as many videos as the batch_size
            num_batches = int(len(self.train_loader) / self.config.batch_size)
            iterator = iter(self.train_loader)
            for batch in range(num_batches):
                list_image_features = []
                list_action_fragments = []
                print(f'batch: {batch}')

                # ---- Train eLSTM ----#
                if self.config.verbose:
                    tqdm.write('Training eLSTM...')
                self.e_optimizer.zero_grad()
                for video in range(self.config.batch_size):
                    image_features, action_fragments = next(iterator)
                    action_fragments = action_fragments.squeeze(0)
                    # [batch_size, seq_len, input_size] -> [seq_len, input_size]
                    image_features = image_features.view(
                        -1, self.config.input_size)
                    # Cache per-video tensors for the later training phases.
                    list_image_features.append(image_features)
                    list_action_fragments.append(action_fragments)
                    # [seq_len, input_size]
                    image_features_ = Variable(image_features).cuda()
                    seq_len = image_features_.shape[0]
                    # [seq_len, 1, hidden_size]
                    original_features = self.linear_compress(
                        image_features_.detach()).unsqueeze(1)
                    weighted_features, scores = self.AC(
                        original_features, seq_len, action_fragments)
                    h_mu, h_log_variance, generated_features = self.summarizer.vae(
                        weighted_features)
                    h_origin, original_prob = self.discriminator(
                        original_features)
                    h_sum, sum_prob = self.discriminator(generated_features)
                    if self.config.verbose:
                        tqdm.write(
                            f'original_p: {original_prob.item():.3f}, summary_p: {sum_prob.item():.3f}'
                        )
                    reconstruction_loss = self.reconstruction_loss(
                        h_origin, h_sum)
                    prior_loss = self.prior_loss(h_mu, h_log_variance)
                    tqdm.write(
                        f'recon loss {reconstruction_loss.item():.3f}, prior loss: {prior_loss.item():.3f}'
                    )
                    e_loss = reconstruction_loss + prior_loss
                    # Normalize so gradients accumulate to a batch average.
                    e_loss = e_loss / self.config.batch_size
                    e_loss.backward()
                    prior_loss_history.append(prior_loss.data)
                    e_loss_history.append(e_loss.data)
                # Update e_lstm parameters every 'batch_size' iterations
                torch.nn.utils.clip_grad_norm_(
                    self.summarizer.vae.e_lstm.parameters(), self.config.clip)
                self.e_optimizer.step()

                #---- Train dLSTM (decoder/generator) ----#
                if self.config.verbose:
                    tqdm.write('Training dLSTM...')
                self.d_optimizer.zero_grad()
                for video in range(self.config.batch_size):
                    image_features = list_image_features[video]
                    action_fragments = list_action_fragments[video]
                    # [seq_len, input_size]
                    image_features_ = Variable(image_features).cuda()
                    seq_len = image_features_.shape[0]
                    # [seq_len, 1, hidden_size]
                    original_features = self.linear_compress(
                        image_features_.detach()).unsqueeze(1)
                    weighted_features, _ = self.AC(original_features, seq_len,
                                                   action_fragments)
                    h_mu, h_log_variance, generated_features = self.summarizer.vae(
                        weighted_features)
                    h_origin, original_prob = self.discriminator(
                        original_features)
                    h_sum, sum_prob = self.discriminator(generated_features)
                    tqdm.write(
                        f'original_p: {original_prob.item():.3f}, summary_p: {sum_prob.item():.3f}'
                    )
                    reconstruction_loss = self.reconstruction_loss(
                        h_origin, h_sum)
                    # NOTE(review): original_label is defined elsewhere — verify.
                    g_loss = self.criterion(sum_prob, original_label)
                    orig_features = original_features.squeeze(
                        1)  # [seq_len, hidden_size]
                    gen_features = generated_features.squeeze(1)
                    # Per-frame reconstruction, logged only (not backpropped).
                    recon_losses = []
                    for frame_index in range(seq_len):
                        recon_losses.append(
                            self.reconstruction_loss(
                                orig_features[frame_index, :],
                                gen_features[frame_index, :]))
                    reconstruction_loss_init = torch.stack(
                        recon_losses).mean()
                    if self.config.verbose:
                        tqdm.write(
                            f'recon loss {reconstruction_loss.item():.3f}, g loss: {g_loss.item():.3f}'
                        )
                    d_loss = reconstruction_loss + g_loss
                    d_loss = d_loss / self.config.batch_size
                    d_loss.backward()
                    recon_loss_init_history.append(
                        reconstruction_loss_init.data)
                    recon_loss_history.append(reconstruction_loss.data)
                    g_loss_history.append(g_loss.data)
                    d_loss_history.append(d_loss.data)
                # Update d_lstm parameters every 'batch_size' iterations
                torch.nn.utils.clip_grad_norm_(
                    self.summarizer.vae.d_lstm.parameters(), self.config.clip)
                self.d_optimizer.step()

                #---- Train cLSTM ----#
                if self.config.verbose:
                    tqdm.write('Training cLSTM...')
                self.c_optimizer.zero_grad()
                for video in range(self.config.batch_size):
                    image_features = list_image_features[video]
                    action_fragments = list_action_fragments[video]
                    # [seq_len, input_size]
                    image_features_ = Variable(image_features).cuda()
                    seq_len = image_features_.shape[0]
                    # Train with original loss
                    # [seq_len, 1, hidden_size]
                    original_features = self.linear_compress(
                        image_features_.detach()).unsqueeze(1)
                    h_origin, original_prob = self.discriminator(
                        original_features)
                    c_original_loss = self.criterion(original_prob,
                                                     original_label)
                    c_original_loss = c_original_loss / self.config.batch_size
                    c_original_loss.backward()
                    # Train with summary loss
                    weighted_features, _ = self.AC(original_features, seq_len,
                                                   action_fragments)
                    h_mu, h_log_variance, generated_features = self.summarizer.vae(
                        weighted_features)
                    # detach(): do not backprop through the generator here.
                    h_sum, sum_prob = self.discriminator(
                        generated_features.detach())
                    c_summary_loss = self.criterion(sum_prob, summary_label)
                    c_summary_loss = c_summary_loss / self.config.batch_size
                    c_summary_loss.backward()
                    tqdm.write(
                        f'original_p: {original_prob.item():.3f}, summary_p: {sum_prob.item():.3f}'
                    )
                    c_original_loss_history.append(c_original_loss.data)
                    c_summary_loss_history.append(c_summary_loss.data)
                # Update c_lstm parameters every 'batch_size' iterations
                torch.nn.utils.clip_grad_norm_(
                    list(self.discriminator.parameters()) +
                    list(self.linear_compress.parameters()), self.config.clip)
                self.c_optimizer.step()

                #---- Train sLSTM and actor-critic ----#
                if self.config.verbose:
                    tqdm.write('Training sLSTM, actor and critic...')
                self.optimizerA_s.zero_grad()
                self.optimizerC.zero_grad()
                for video in range(self.config.batch_size):
                    image_features = list_image_features[video]
                    action_fragments = list_action_fragments[video]
                    # [seq_len, input_size]
                    image_features_ = Variable(image_features).cuda()
                    seq_len = image_features_.shape[0]
                    # [seq_len, 1, hidden_size]
                    original_features = self.linear_compress(
                        image_features_.detach()).unsqueeze(1)
                    scores = self.summarizer.s_lstm(
                        original_features)  # [seq_len, 1]
                    fragment_scores = np.zeros(
                        self.config.action_state_size)  # [num_fragments, 1]
                    for fragment in range(self.config.action_state_size):
                        fragment_scores[fragment] = scores[action_fragments[
                            fragment, 0]:action_fragments[fragment, 1] +
                                                           1].mean()
                    state = fragment_scores  # [action_state_size, 1]
                    previous_actions = [
                    ]  # save all the actions (the selected fragments of each step)
                    reduction_factor = (self.config.action_state_size -
                                        self.config.termination_point
                                        ) / self.config.action_state_size
                    action_scores = (torch.ones(seq_len) *
                                     reduction_factor).cuda()
                    action_fragment_scores = (torch.ones(
                        self.config.action_state_size)).cuda()

                    log_probs = []
                    values = []
                    rewards = []
                    masks = []
                    entropy = 0
                    counter = 0
                    # Roll out one episode of fragment selections.
                    for ACstep in range(self.config.termination_point):
                        # select an action, get a value for the current state
                        state = torch.FloatTensor(
                            state).cuda()  # [action_state_size, 1]
                        dist, value = self.actor(state), self.critic(state)
                        action = dist.sample(
                        )  # returns a scalar between 0-action_state_size
                        if action in previous_actions:
                            # Repeated pick earns nothing.
                            reward = 0
                        else:
                            previous_actions.append(action)
                            action_factor = (
                                self.config.termination_point - counter
                            ) / (self.config.action_state_size - counter) + 1
                            action_scores[action_fragments[
                                action, 0]:action_fragments[action, 1] +
                                          1] = action_factor
                            action_fragment_scores[action] = 0

                            weighted_scores = action_scores.unsqueeze(
                                1) * scores
                            weighted_features = weighted_scores.view(
                                -1, 1, 1) * original_features
                            h_mu, h_log_variance, generated_features = self.summarizer.vae(
                                weighted_features)
                            h_origin, original_prob = self.discriminator(
                                original_features)
                            h_sum, sum_prob = self.discriminator(
                                generated_features)
                            tqdm.write(
                                f'original_p: {original_prob.item():.3f}, summary_p: {sum_prob.item():.3f}'
                            )
                            rec_loss = self.reconstruction_loss(
                                h_origin, h_sum)
                            reward = 1 - rec_loss.item(
                            )  # the less the distance, the higher the reward
                        counter = counter + 1
                        next_state = state * action_fragment_scores
                        next_state = next_state.cpu().detach().numpy()

                        log_prob = dist.log_prob(action).unsqueeze(0)
                        entropy += dist.entropy().mean()
                        log_probs.append(log_prob)
                        values.append(value)
                        rewards.append(
                            torch.tensor([reward],
                                         dtype=torch.float,
                                         device=device))
                        # mask=0 terminates the return bootstrap on the last step.
                        if ACstep == self.config.termination_point - 1:
                            masks.append(
                                torch.tensor([0],
                                             dtype=torch.float,
                                             device=device))
                        else:
                            masks.append(
                                torch.tensor([1],
                                             dtype=torch.float,
                                             device=device))
                        state = next_state

                    next_state = torch.FloatTensor(next_state).to(device)
                    next_value = self.critic(next_state)
                    returns = compute_returns(next_value, rewards, masks)

                    log_probs = torch.cat(log_probs)
                    returns = torch.cat(returns).detach()
                    values = torch.cat(values)
                    advantage = returns - values

                    # Policy-gradient loss with entropy regularization.
                    actor_loss = -((log_probs * advantage.detach()).mean() +
                                   (self.config.entropy_coef /
                                    self.config.termination_point) * entropy)
                    sparsity_loss = self.sparsity_loss(scores)
                    critic_loss = advantage.pow(2).mean()

                    actor_loss = actor_loss / self.config.batch_size
                    sparsity_loss = sparsity_loss / self.config.batch_size
                    critic_loss = critic_loss / self.config.batch_size
                    actor_loss.backward()
                    sparsity_loss.backward()
                    critic_loss.backward()

                    reward_mean = torch.mean(torch.stack(rewards))
                    reward_history.append(reward_mean)
                    actor_loss_history.append(actor_loss)
                    sparsity_loss_history.append(sparsity_loss)
                    critic_loss_history.append(critic_loss)

                    if self.config.verbose:
                        tqdm.write('Plotting...')
                    self.writer.update_loss(original_prob.data, step,
                                            'original_prob')
                    self.writer.update_loss(sum_prob.data, step, 'sum_prob')
                    step += 1
                # Update s_lstm, actor and critic parameters every 'batch_size' iterations
                torch.nn.utils.clip_grad_norm_(
                    list(self.actor.parameters()) +
                    list(self.linear_compress.parameters()) +
                    list(self.summarizer.s_lstm.parameters()) +
                    list(self.critic.parameters()), self.config.clip)
                self.optimizerA_s.step()
                self.optimizerC.step()

            # Epoch averages of every tracked quantity.
            recon_loss_init = torch.stack(recon_loss_init_history).mean()
            recon_loss = torch.stack(recon_loss_history).mean()
            prior_loss = torch.stack(prior_loss_history).mean()
            g_loss = torch.stack(g_loss_history).mean()
            e_loss = torch.stack(e_loss_history).mean()
            d_loss = torch.stack(d_loss_history).mean()
            c_original_loss = torch.stack(c_original_loss_history).mean()
            c_summary_loss = torch.stack(c_summary_loss_history).mean()
            sparsity_loss = torch.stack(sparsity_loss_history).mean()
            actor_loss = torch.stack(actor_loss_history).mean()
            critic_loss = torch.stack(critic_loss_history).mean()
            reward = torch.mean(torch.stack(reward_history))

            # Plot
            if self.config.verbose:
                tqdm.write('Plotting...')
            self.writer.update_loss(recon_loss_init, epoch_i,
                                    'recon_loss_init_epoch')
            self.writer.update_loss(recon_loss, epoch_i, 'recon_loss_epoch')
            self.writer.update_loss(prior_loss, epoch_i, 'prior_loss_epoch')
            self.writer.update_loss(g_loss, epoch_i, 'g_loss_epoch')
            self.writer.update_loss(e_loss, epoch_i, 'e_loss_epoch')
            self.writer.update_loss(d_loss, epoch_i, 'd_loss_epoch')
            self.writer.update_loss(c_original_loss, epoch_i,
                                    'c_original_loss_epoch')
            self.writer.update_loss(c_summary_loss, epoch_i,
                                    'c_summary_loss_epoch')
            self.writer.update_loss(sparsity_loss, epoch_i,
                                    'sparsity_loss_epoch')
            self.writer.update_loss(actor_loss, epoch_i, 'actor_loss_epoch')
            self.writer.update_loss(critic_loss, epoch_i, 'critic_loss_epoch')
            self.writer.update_loss(reward, epoch_i, 'reward_epoch')

            # Save parameters at checkpoint
            ckpt_path = str(self.config.save_dir) + f'/epoch-{epoch_i}.pkl'
            if self.config.verbose:
                tqdm.write(f'Save parameters at {ckpt_path}')
            torch.save(self.model.state_dict(), ckpt_path)

            self.evaluate(epoch_i)

    def evaluate(self, epoch_i):
        """Score every test video with the trained pipeline and dump the
        per-video frame scores to a JSON file for this epoch."""
        self.model.eval()
        out_dict = {}
        for image_features, video_name, action_fragments in tqdm(
                self.test_loader, desc='Evaluate', ncols=80, leave=False):
            # [seq_len, batch_size=1, input_size] -> [seq_len, input_size]
            image_features = image_features.view(-1, self.config.input_size)
            image_features_ = Variable(image_features).cuda()
            # [seq_len, 1, hidden_size]
            original_features = self.linear_compress(
                image_features_.detach()).unsqueeze(1)
            seq_len = original_features.shape[0]
            with torch.no_grad():
                _, scores = self.AC(original_features, seq_len,
                                    action_fragments)
                scores = scores.squeeze(1)
                scores = scores.cpu().numpy().tolist()
                out_dict[video_name] = scores
            # Rewritten after every video so partial results survive a crash.
            score_save_path = self.config.score_dir.joinpath(
                f'{self.config.video_type}_{epoch_i}.json')
            with open(score_save_path, 'w') as f:
                if self.config.verbose:
                    tqdm.write(f'Saving score at {str(score_save_path)}.')
                json.dump(out_dict, f)
            score_save_path.chmod(0o777)
class Solver(object):
    """Base solver skeleton: holds data loaders, model and checkpoint I/O;
    train/evaluate/test are left for subclasses to implement."""

    def __init__(self, config, train_data_loader, eval_data_loader, is_train=True, model=None):
        self.config = config
        self.epoch_i = 0  # epoch to resume from; updated by load_model()
        self.train_data_loader = train_data_loader
        self.eval_data_loader = eval_data_loader
        self.is_train = is_train
        self.model = model
        # Filled in later by build() / training code.
        self.writer = None
        self.optimizer = None
        self.epoch_loss = None
        self.validation_loss = None
        # Evaluation accumulators.
        self.true_scores = 0
        self.false_scores = 0
        self.eval_epoch_loss = 0

    def build(self, cuda=True):
        """Instantiate the model (unless injected), move it to GPU when
        available, restore a checkpoint if configured, and set up the
        writer/optimizer for training."""
        if self.model is None:
            # Model class is looked up by name in the project's models module.
            self.model = getattr(models, self.config.model)(self.config)

        if torch.cuda.is_available() and cuda:
            self.model.cuda()

        if self.config.checkpoint:
            self.load_model(self.config.checkpoint)

        if self.is_train:
            self.writer = TensorboardWriter(self.config.logdir)
            # Only parameters that require gradients are optimized.
            self.optimizer = self.config.optimizer(
                filter(lambda p: p.requires_grad, self.model.parameters()),
                lr=self.config.learning_rate)

    def save_model(self, epoch):
        """Serialize the model's state_dict to `<save_path>/<epoch>.pkl`."""
        ckpt_path = os.path.join(self.config.save_path, f'{epoch}.pkl')
        print(f'Save parameters to {ckpt_path}')
        torch.save(self.model.state_dict(), ckpt_path)

    def load_model(self, checkpoint):
        """Restore model weights from `checkpoint`; the leading digits of the
        file name are interpreted as the epoch to resume from."""
        print(f'Load parameters from {checkpoint}')
        epoch = re.match(r"[0-9]*", os.path.basename(checkpoint)).group(0)
        self.epoch_i = int(epoch)
        self.model.load_state_dict(torch.load(checkpoint))

    def write_summary(self, epoch_i):
        """Log the epoch's training loss to Tensorboard."""
        epoch_loss = getattr(self, 'epoch_loss', None)
        if epoch_loss is not None:
            self.writer.update_loss(loss=epoch_loss,
                                    step_i=epoch_i + 1,
                                    name='train_loss')
        # NOTE(review): raises unconditionally after logging — looks like a
        # placeholder marking this summary as incomplete; confirm whether the
        # raise was meant to be reached or removed.
        raise NotImplementedError

    def train(self):
        raise NotImplementedError

    def evaluate(self):
        raise NotImplementedError

    def test(self):
        raise NotImplementedError
class Solver(object):
    def __init__(self, config=None, train_loader=None, test_loader=None):
        """Class that Builds, Trains and Evaluates the SUM-GAN-sl model."""
        self.config = config
        self.train_loader = train_loader
        self.test_loader = test_loader

    def build(self):
        """Instantiate the compressor/summarizer/discriminator on the GPU
        and, in 'train' mode, their optimizers and the Tensorboard writer."""
        # Build Modules
        self.linear_compress = nn.Linear(self.config.input_size,
                                         self.config.hidden_size).cuda()
        self.summarizer = Summarizer(input_size=self.config.hidden_size,
                                     hidden_size=self.config.hidden_size,
                                     num_layers=self.config.num_layers).cuda()
        self.discriminator = Discriminator(
            input_size=self.config.hidden_size,
            hidden_size=self.config.hidden_size,
            num_layers=self.config.num_layers).cuda()
        self.model = nn.ModuleList(
            [self.linear_compress, self.summarizer, self.discriminator])
        if self.config.mode == 'train':
            # Build Optimizers (the compressor is shared by all three).
            self.s_e_optimizer = optim.Adam(
                list(self.summarizer.s_lstm.parameters()) +
                list(self.summarizer.vae.e_lstm.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.lr)
            self.d_optimizer = optim.Adam(
                list(self.summarizer.vae.d_lstm.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.lr)
            self.c_optimizer = optim.Adam(
                list(self.discriminator.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.discriminator_lr)
            self.writer = TensorboardWriter(str(self.config.log_dir))

    def reconstruction_loss(self, h_origin, h_sum):
        """L2 loss between original-regenerated features at cLSTM's last hidden layer."""
        return torch.norm(h_origin - h_sum, p=2)

    def prior_loss(self, mu, log_variance):
        """KL( q(e|x) || N(0,1) )"""
        return 0.5 * torch.sum(-1 + log_variance.exp() + mu.pow(2) -
                               log_variance)

    def sparsity_loss(self, scores):
        """Summary-Length Regularization: keep the mean frame score near the
        configured regularization factor."""
        return torch.abs(
            torch.mean(scores) - self.config.regularization_factor)

    # Shared MSE criterion (class attribute, shared by all instances).
    criterion = nn.MSELoss()

    def train(self):
        """Adversarial training loop: per video, train sLSTM+eLSTM, then
        dLSTM (generator), then cLSTM (discriminator); log epoch averages
        and checkpoint.

        NOTE(review): `original_label` and `summary_label` are not defined in
        this chunk — presumably module-level target tensors; verify.
        NOTE(review): `clip_grad_norm` (no underscore) is the deprecated
        pre-1.0 PyTorch API.
        """
        step = 0
        for epoch_i in trange(self.config.n_epochs, desc='Epoch', ncols=80):
            s_e_loss_history = []
            d_loss_history = []
            c_original_loss_history = []
            c_summary_loss_history = []
            for batch_i, image_features in enumerate(
                    tqdm(self.train_loader,
                         desc='Batch',
                         ncols=80,
                         leave=False)):
                self.model.train()
                # [batch_size=1, seq_len, 1024] -> [seq_len, 1024]
                image_features = image_features.view(-1,
                                                     self.config.input_size)
                # [seq_len, 1024]
                image_features_ = Variable(image_features).cuda()

                #---- Train sLSTM, eLSTM ----#
                if self.config.verbose:
                    tqdm.write('\nTraining sLSTM and eLSTM...')
                # [seq_len, 1, hidden_size]
                original_features = self.linear_compress(
                    image_features_.detach()).unsqueeze(1)
                scores, h_mu, h_log_variance, generated_features = self.summarizer(
                    original_features)
                h_origin, original_prob = self.discriminator(
                    original_features)
                h_sum, sum_prob = self.discriminator(generated_features)
                tqdm.write(
                    f'original_p: {original_prob.item():.3f}, summary_p: {sum_prob.item():.3f}'
                )
                reconstruction_loss = self.reconstruction_loss(
                    h_origin, h_sum)
                prior_loss = self.prior_loss(h_mu, h_log_variance)
                sparsity_loss = self.sparsity_loss(scores)
                tqdm.write(
                    f'recon loss {reconstruction_loss.item():.3f}, prior loss: {prior_loss.item():.3f}, sparsity loss: {sparsity_loss.item():.3f}'
                )
                s_e_loss = reconstruction_loss + prior_loss + sparsity_loss
                self.s_e_optimizer.zero_grad()
                s_e_loss.backward()
                # Gradient cliping
                torch.nn.utils.clip_grad_norm(self.model.parameters(),
                                              self.config.clip)
                self.s_e_optimizer.step()
                s_e_loss_history.append(s_e_loss.data)

                #---- Train dLSTM (generator) ----#
                if self.config.verbose:
                    tqdm.write('Training dLSTM...')
                # [seq_len, 1, hidden_size]
                original_features = self.linear_compress(
                    image_features_.detach()).unsqueeze(1)
                scores, h_mu, h_log_variance, generated_features = self.summarizer(
                    original_features)
                h_origin, original_prob = self.discriminator(
                    original_features)
                h_sum, sum_prob = self.discriminator(generated_features)
                tqdm.write(
                    f'original_p: {original_prob.item():.3f}, summary_p: {sum_prob.item():.3f}'
                )
                reconstruction_loss = self.reconstruction_loss(
                    h_origin, h_sum)
                # Generator tries to make the summary look 'original'.
                g_loss = self.criterion(sum_prob, original_label)
                tqdm.write(
                    f'recon loss {reconstruction_loss.item():.3f}, g loss: {g_loss.item():.3f}'
                )
                d_loss = reconstruction_loss + g_loss
                self.d_optimizer.zero_grad()
                d_loss.backward()
                # Gradient cliping
                torch.nn.utils.clip_grad_norm(self.model.parameters(),
                                              self.config.clip)
                self.d_optimizer.step()
                d_loss_history.append(d_loss.data)

                #---- Train cLSTM ----#
                if self.config.verbose:
                    tqdm.write('Training cLSTM...')
                self.c_optimizer.zero_grad()
                # Train with original loss
                # [seq_len, 1, hidden_size]
                original_features = self.linear_compress(
                    image_features_.detach()).unsqueeze(1)
                h_origin, original_prob = self.discriminator(
                    original_features)
                c_original_loss = self.criterion(original_prob,
                                                 original_label)
                c_original_loss.backward()
                # Train with summary loss
                scores, h_mu, h_log_variance, generated_features = self.summarizer(
                    original_features)
                # detach(): the generator is not updated in this phase.
                h_sum, sum_prob = self.discriminator(
                    generated_features.detach())
                c_summary_loss = self.criterion(sum_prob, summary_label)
                c_summary_loss.backward()
                tqdm.write(
                    f'original_p: {original_prob.item():.3f}, summary_p: {sum_prob.item():.3f}'
                )
                tqdm.write(f'gen loss: {g_loss.item():.3f}')
                # Gradient cliping
                torch.nn.utils.clip_grad_norm(self.model.parameters(),
                                              self.config.clip)
                self.c_optimizer.step()
                c_original_loss_history.append(c_original_loss.data)
                c_summary_loss_history.append(c_summary_loss.data)

                if self.config.verbose:
                    tqdm.write('Plotting...')
                self.writer.update_loss(reconstruction_loss.data, step,
                                        'recon_loss')
                self.writer.update_loss(prior_loss.data, step, 'prior_loss')
                self.writer.update_loss(sparsity_loss.data, step,
                                        'sparsity_loss')
                self.writer.update_loss(g_loss.data, step, 'gen_loss')
                self.writer.update_loss(original_prob.data, step,
                                        'original_prob')
                self.writer.update_loss(sum_prob.data, step, 'sum_prob')
                step += 1

            # Epoch averages.
            s_e_loss = torch.stack(s_e_loss_history).mean()
            d_loss = torch.stack(d_loss_history).mean()
            c_original_loss = torch.stack(c_original_loss_history).mean()
            c_summary_loss = torch.stack(c_summary_loss_history).mean()

            # Plot
            if self.config.verbose:
                tqdm.write('Plotting...')
            self.writer.update_loss(s_e_loss, epoch_i, 's_e_loss_epoch')
            self.writer.update_loss(d_loss, epoch_i, 'd_loss_epoch')
            self.writer.update_loss(c_original_loss, step, 'c_original_loss')
            self.writer.update_loss(c_summary_loss, step, 'c_summary_loss')

            # Save parameters at checkpoint
            ckpt_path = str(self.config.save_dir) + f'/epoch-{epoch_i}.pkl'
            tqdm.write(f'Save parameters at {ckpt_path}')
            torch.save(self.model.state_dict(), ckpt_path)

            self.evaluate(epoch_i)

    def evaluate(self, epoch_i):
        """Score every test video with the selector LSTM and dump per-video
        frame scores to a JSON file for this epoch."""
        self.model.eval()
        out_dict = {}
        for video_tensor, video_name in tqdm(self.test_loader,
                                             desc='Evaluate',
                                             ncols=80,
                                             leave=False):
            # [seq_len, batch=1, 1024] -> [seq_len, 1024]
            video_tensor = video_tensor.view(-1, self.config.input_size)
            video_feature = Variable(video_tensor).cuda()
            # [seq_len, 1, hidden_size]
            video_feature = self.linear_compress(
                video_feature.detach()).unsqueeze(1)
            # [seq_len]
            with torch.no_grad():
                scores = self.summarizer.s_lstm(video_feature).squeeze(1)
                scores = scores.cpu().numpy().tolist()
                out_dict[video_name] = scores
            # Rewritten after every video so partial results survive a crash.
            score_save_path = self.config.score_dir.joinpath(
                f'{self.config.video_type}_{epoch_i}.json')
            with open(score_save_path, 'w') as f:
                tqdm.write(f'Saving score at {str(score_save_path)}.')
                json.dump(out_dict, f)
            score_save_path.chmod(0o777)

    def pretrain(self):
        # Intentionally a no-op for this model variant.
        pass
class Solver(object):
    def __init__(self, config=None, train_loader=None, test_loader=None):
        """Class that Builds, Trains and Evaluates the SUM-GAN model."""
        self.config = config
        self.train_loader = train_loader
        self.test_loader = test_loader

    def build(self):
        """Instantiate the compressor/summarizer/discriminator on the GPU
        and, in 'train' mode, their optimizers; always create the writer."""
        # Build Modules
        self.linear_compress = nn.Linear(self.config.input_size,
                                         self.config.hidden_size).cuda()
        self.summarizer = Summarizer(input_size=self.config.hidden_size,
                                     hidden_size=self.config.hidden_size,
                                     num_layers=self.config.num_layers).cuda()
        self.discriminator = Discriminator(
            input_size=self.config.hidden_size,
            hidden_size=self.config.hidden_size,
            num_layers=self.config.num_layers).cuda()
        self.model = nn.ModuleList(
            [self.linear_compress, self.summarizer, self.discriminator])
        if self.config.mode == 'train':
            # Build Optimizers (the compressor is shared by all three).
            self.s_e_optimizer = optim.Adam(
                list(self.summarizer.s_lstm.parameters()) +
                list(self.summarizer.vae.e_lstm.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.lr)
            self.d_optimizer = optim.Adam(
                list(self.summarizer.vae.d_lstm.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.lr)
            self.c_optimizer = optim.Adam(
                list(self.discriminator.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.discriminator_lr)

            self.model.train()
            # self.model.apply(apply_weight_norm)

        # Overview Parameters
        # print('Model Parameters')
        # for name, param in self.model.named_parameters():
        #     print('\t' + name + '\t', list(param.size()))

        # Tensorboard
        self.writer = TensorboardWriter(self.config.log_dir)

    @staticmethod
    def freeze_model(module):
        """Disable gradient computation for all parameters of `module`."""
        for p in module.parameters():
            p.requires_grad = False

    def reconstruction_loss(self, h_origin, h_fake):
        """L2 loss between original-regenerated features at cLSTM's last hidden layer."""
        return torch.norm(h_origin - h_fake, p=2)

    def prior_loss(self, mu, log_variance):
        """KL( q(e|x) || N(0,1) )"""
        return 0.5 * torch.sum(-1 + log_variance.exp() + mu.pow(2) -
                               log_variance)

    def sparsity_loss(self, scores):
        """Summary-Length Regularization: keep the mean frame score near the
        configured summary rate."""
        return torch.abs(torch.mean(scores) - self.config.summary_rate)

    def gan_loss(self, original_prob, fake_prob, uniform_prob):
        """Typical GAN loss + Classify uniformly scored features."""
        gan_loss = torch.mean(
            torch.log(original_prob) + torch.log(1 - fake_prob) +
            torch.log(1 - uniform_prob))  # Discriminate uniform score
        return gan_loss

    def train(self):
        """Adversarial training loop: per video, train sLSTM+eLSTM, then
        dLSTM, then (after a warm-up) the cLSTM discriminator.

        NOTE(review): uses legacy pre-0.4 PyTorch idioms (`.data[0]`,
        `clip_grad_norm` without underscore) — behavior on modern PyTorch
        should be verified.
        NOTE(review): if `discriminator_slow_start` is never exceeded,
        `c_loss_history` stays empty and `torch.stack([])` raises.
        """
        step = 0
        for epoch_i in trange(self.config.n_epochs, desc='Epoch', ncols=80):
            s_e_loss_history = []
            d_loss_history = []
            c_loss_history = []
            for batch_i, image_features in enumerate(
                    tqdm(self.train_loader,
                         desc='Batch',
                         ncols=80,
                         leave=False)):

                # Skip excessively long videos.
                if image_features.size(1) > 10000:
                    continue

                # [batch_size=1, seq_len, 2048] -> [seq_len, 2048]
                image_features = image_features.view(-1,
                                                     self.config.input_size)
                # [seq_len, 2048]
                image_features_ = Variable(image_features).cuda()

                #---- Train sLSTM, eLSTM ----#
                if self.config.verbose:
                    tqdm.write('\nTraining sLSTM and eLSTM...')
                # [seq_len, 1, hidden_size]
                original_features = self.linear_compress(
                    image_features_.detach()).unsqueeze(1)
                scores, h_mu, h_log_variance, generated_features = self.summarizer(
                    original_features)
                # Summary produced from uniform scores, used as a GAN negative.
                _, _, _, uniform_features = self.summarizer(original_features,
                                                            uniform=True)
                h_origin, original_prob = self.discriminator(
                    original_features)
                h_fake, fake_prob = self.discriminator(generated_features)
                h_uniform, uniform_prob = self.discriminator(uniform_features)
                tqdm.write(
                    f'original_p: {original_prob.data[0]:.3f}, fake_p: {fake_prob.data[0]:.3f}, uniform_p: {uniform_prob.data[0]:.3f}'
                )
                reconstruction_loss = self.reconstruction_loss(
                    h_origin, h_fake)
                prior_loss = self.prior_loss(h_mu, h_log_variance)
                sparsity_loss = self.sparsity_loss(scores)
                tqdm.write(
                    f'recon loss {reconstruction_loss.data[0]:.3f}, prior loss: {prior_loss.data[0]:.3f}, sparsity loss: {sparsity_loss.data[0]:.3f}'
                )
                s_e_loss = reconstruction_loss + prior_loss + sparsity_loss
                self.s_e_optimizer.zero_grad()
                s_e_loss.backward()  # retain_graph=True)
                # Gradient cliping
                torch.nn.utils.clip_grad_norm(self.model.parameters(),
                                              self.config.clip)
                self.s_e_optimizer.step()
                s_e_loss_history.append(s_e_loss.data)

                #---- Train dLSTM ----#
                if self.config.verbose:
                    tqdm.write('Training dLSTM...')
                # [seq_len, 1, hidden_size]
                original_features = self.linear_compress(
                    image_features_.detach()).unsqueeze(1)
                scores, h_mu, h_log_variance, generated_features = self.summarizer(
                    original_features)
                _, _, _, uniform_features = self.summarizer(original_features,
                                                            uniform=True)
                h_origin, original_prob = self.discriminator(
                    original_features)
                h_fake, fake_prob = self.discriminator(generated_features)
                h_uniform, uniform_prob = self.discriminator(uniform_features)
                tqdm.write(
                    f'original_p: {original_prob.data[0]:.3f}, fake_p: {fake_prob.data[0]:.3f}, uniform_p: {uniform_prob.data[0]:.3f}'
                )
                reconstruction_loss = self.reconstruction_loss(
                    h_origin, h_fake)
                gan_loss = self.gan_loss(original_prob, fake_prob,
                                         uniform_prob)
                tqdm.write(
                    f'recon loss {reconstruction_loss.data[0]:.3f}, gan loss: {gan_loss.data[0]:.3f}'
                )
                d_loss = reconstruction_loss + gan_loss
                self.d_optimizer.zero_grad()
                d_loss.backward()  # retain_graph=True)
                # Gradient cliping
                torch.nn.utils.clip_grad_norm(self.model.parameters(),
                                              self.config.clip)
                self.d_optimizer.step()
                d_loss_history.append(d_loss.data)

                #---- Train cLSTM ----#
                # Discriminator only starts training after a warm-up period.
                if batch_i > self.config.discriminator_slow_start:
                    if self.config.verbose:
                        tqdm.write('Training cLSTM...')
                    # [seq_len, 1, hidden_size]
                    original_features = self.linear_compress(
                        image_features_.detach()).unsqueeze(1)
                    scores, h_mu, h_log_variance, generated_features = self.summarizer(
                        original_features)
                    _, _, _, uniform_features = self.summarizer(
                        original_features, uniform=True)
                    h_origin, original_prob = self.discriminator(
                        original_features)
                    h_fake, fake_prob = self.discriminator(generated_features)
                    h_uniform, uniform_prob = self.discriminator(
                        uniform_features)
                    tqdm.write(
                        f'original_p: {original_prob.data[0]:.3f}, fake_p: {fake_prob.data[0]:.3f}, uniform_p: {uniform_prob.data[0]:.3f}'
                    )
                    # Maximization
                    c_loss = -1 * self.gan_loss(original_prob, fake_prob,
                                                uniform_prob)
                    tqdm.write(f'gan loss: {gan_loss.data[0]:.3f}')
                    self.c_optimizer.zero_grad()
                    c_loss.backward()
                    # Gradient cliping
                    torch.nn.utils.clip_grad_norm(self.model.parameters(),
                                                  self.config.clip)
                    self.c_optimizer.step()
                    c_loss_history.append(c_loss.data)

                if self.config.verbose:
                    tqdm.write('Plotting...')
                self.writer.update_loss(reconstruction_loss.data, step,
                                        'recon_loss')
                self.writer.update_loss(prior_loss.data, step, 'prior_loss')
                self.writer.update_loss(sparsity_loss.data, step,
                                        'sparsity_loss')
                self.writer.update_loss(gan_loss.data, step, 'gan_loss')
                # self.writer.update_loss(s_e_loss.data, step, 's_e_loss')
                # self.writer.update_loss(d_loss.data, step, 'd_loss')
                # self.writer.update_loss(c_loss.data, step, 'c_loss')
                self.writer.update_loss(original_prob.data, step,
                                        'original_prob')
                self.writer.update_loss(fake_prob.data, step, 'fake_prob')
                self.writer.update_loss(uniform_prob.data, step,
                                        'uniform_prob')
                step += 1

            # Epoch averages.
            s_e_loss = torch.stack(s_e_loss_history).mean()
            d_loss = torch.stack(d_loss_history).mean()
            c_loss = torch.stack(c_loss_history).mean()

            # Plot
            if self.config.verbose:
                tqdm.write('Plotting...')
            self.writer.update_loss(s_e_loss, epoch_i, 's_e_loss_epoch')
            self.writer.update_loss(d_loss, epoch_i, 'd_loss_epoch')
            self.writer.update_loss(c_loss, epoch_i, 'c_loss_epoch')

            # Save parameters at checkpoint
            ckpt_path = str(self.config.save_dir) + f'_epoch-{epoch_i}.pkl'
            tqdm.write(f'Save parameters at {ckpt_path}')
            torch.save(self.model.state_dict(), ckpt_path)

            self.evaluate(epoch_i)
            self.model.train()

    def evaluate(self, epoch_i):
        """Score every test video with the selector LSTM and dump per-video
        frame scores to a JSON file for this epoch."""
        # checkpoint = self.config.ckpt_path
        # print(f'Load parameters from {checkpoint}')
        # self.model.load_state_dict(torch.load(checkpoint))
        self.model.eval()
        out_dict = {}
        for video_tensor, video_name in tqdm(self.test_loader,
                                             desc='Evaluate',
                                             ncols=80,
                                             leave=False):
            # [seq_len, batch=1, 2048] -> [seq_len, 2048]
            video_tensor = video_tensor.view(-1, self.config.input_size)
            # NOTE(review): volatile=True is the removed pre-0.4 inference
            # flag — the modern equivalent is torch.no_grad().
            video_feature = Variable(video_tensor, volatile=True).cuda()
            # [seq_len, 1, hidden_size]
            video_feature = self.linear_compress(
                video_feature.detach()).unsqueeze(1)
            # [seq_len]
            scores = self.summarizer.s_lstm(video_feature).squeeze(1)
            scores = np.array(scores.data).tolist()
            out_dict[video_name] = scores
            # Rewritten after every video so partial results survive a crash.
            score_save_path = self.config.score_dir.joinpath(
                f'{self.config.video_type}_{epoch_i}.json')
            with open(score_save_path, 'w') as f:
                tqdm.write(f'Saving score at {str(score_save_path)}.')
                json.dump(out_dict, f)
            score_save_path.chmod(0o777)

    def pretrain(self):
        # Intentionally a no-op for this model variant.
        pass
class Solver(object):
    """Base trainer that wires a config, data loaders and a model together.

    Subclasses are expected to implement `train`, `evaluate` and `test`;
    this base class provides model construction/initialization, optimizer
    setup, checkpoint save/load and TensorBoard summary helpers.
    """

    def __init__(self, config, train_data_loader, eval_data_loader,
                 is_train=True, model=None):
        self.config = config
        self.epoch_i = 0
        self.train_data_loader = train_data_loader
        self.eval_data_loader = eval_data_loader
        self.is_train = is_train
        self.model = model

    @time_desc_decorator('Build Graph')
    def build(self, cuda=True):
        """Instantiate the model (if not injected), initialize weights,
        move it to GPU when available, and build writer + optimizer for
        training runs.

        cuda: when True and CUDA is available, the model is moved to GPU.
        """
        if self.model is None:
            self.model = getattr(models, self.config.model)(self.config)

            # Orthogonal initialization for recurrent hidden weights and a
            # positive bias on one GRU gate — only for a fresh training run
            # (a checkpoint would overwrite these values anyway).
            if self.config.mode == 'train' and self.config.checkpoint is None:
                print('Parameter initiailization')
                for name, param in self.model.named_parameters():
                    if 'weight_hh' in name:
                        print('\t' + name)
                        nn.init.orthogonal_(param)

                    # bias_hh is the concatenation of reset, input and new
                    # gate biases; set only the middle (input-gate) third
                    # to 2.0 so GRUs initially tend to remember.
                    if 'bias_hh' in name:
                        print('\t' + name)
                        dim = param.size(0) // 3
                        param.data[dim:2 * dim].fill_(2.0)

        if torch.cuda.is_available() and cuda:
            self.model.cuda()

        # Overview of model parameters.
        print('Model Parameters')
        for name, param in self.model.named_parameters():
            print('\t' + name + '\t', list(param.size()))

        if self.config.checkpoint:
            self.load_model(self.config.checkpoint)

        if self.is_train:
            self.writer = TensorboardWriter(self.config.logdir)
            if self.config.optimizer is None:
                # Default to AdamW with weight decay disabled for biases
                # and LayerNorm weights (standard transformer recipe).
                no_decay = ['bias', 'LayerNorm.weight']
                optimizer_grouped_parameters = [{
                    'params': [
                        p for n, p in self.model.named_parameters()
                        if not any(nd in n for nd in no_decay)
                    ],
                    'weight_decay': 0.01
                }, {
                    'params': [
                        p for n, p in self.model.named_parameters()
                        if any(nd in n for nd in no_decay)
                    ],
                    'weight_decay': 0.0
                }]
                self.optimizer = AdamW(optimizer_grouped_parameters,
                                       lr=self.config.learning_rate)
            else:
                # User-supplied optimizer class; only trainable params.
                self.optimizer = self.config.optimizer(
                    filter(lambda p: p.requires_grad,
                           self.model.parameters()),
                    lr=self.config.learning_rate)

    def save_model(self, epoch):
        """Save model parameters to `<save_path>/<epoch>.pkl`."""
        ckpt_path = os.path.join(self.config.save_path, f'{epoch}.pkl')
        print(f'Save parameters to {ckpt_path}')
        torch.save(self.model.state_dict(), ckpt_path)

    def load_model(self, checkpoint):
        """Load parameters from `checkpoint` and resume the epoch counter.

        The checkpoint basename is expected to start with the epoch
        number (e.g. '12.pkl' resumes at epoch 12).
        """
        print(f'Load parameters from {checkpoint}')
        epoch = re.match(r"[0-9]*", os.path.basename(checkpoint)).group(0)
        self.epoch_i = int(epoch)
        self.model.load_state_dict(torch.load(checkpoint))

    def write_summary(self, epoch_i):
        """Push train/validation accuracy (when set by a subclass) to
        TensorBoard at step `epoch_i + 1`."""
        train_acc = getattr(self, 'train_acc', None)
        if train_acc is not None:
            self.writer.update_loss(loss=train_acc,
                                    step_i=epoch_i + 1,
                                    name='train_acc')

        validation_acc = getattr(self, 'validation_acc', None)
        if validation_acc is not None:
            self.writer.update_loss(loss=validation_acc,
                                    step_i=epoch_i + 1,
                                    name='validation_acc')

    def train(self):
        raise NotImplementedError

    def evaluate(self):
        raise NotImplementedError

    def test(self, is_print=True):
        raise NotImplementedError

    def _calc_accuracy(self, x, y):
        """Return the fraction of rows in logits `x` whose argmax over
        dim 1 equals the label tensor `y` (numpy scalar in [0, 1])."""
        _, max_indices = torch.max(x, 1)
        return (max_indices == y).sum().data.cpu().numpy() / max_indices.size(0)