import datetime
import os
import pickle

import torch
import torch.nn as nn
from tqdm import tqdm

# Project-internal helpers; the module paths below are assumptions --
# adjust them to match this repository's actual layout.
from hierarchical_rnn import HierachicalRNN
from utils import mini_batches, save


def train(self):
    params = self.params
    batches = mini_batches(X_added_code=self.added_code,
                           X_removed_code=self.removed_code,
                           Y=self.msgs,
                           mini_batch_size=params.batch_size)
    msg_labels_shape = self.msgs.shape
    self._clean_unused_variables()

    # Each run gets its own timestamped checkpoint directory.
    params.save_dir = os.path.join(
        params.save_dir, datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))

    params.vocab_code = len(self.dict_code)
    # A 1-D label array means a single binary label per commit; otherwise
    # the second dimension is the number of classes.
    if len(msg_labels_shape) == 1:
        params.class_num = 1
    else:
        params.class_num = msg_labels_shape[1]

    # Device configuration
    params.device = torch.device('cuda' if params.cuda else 'cpu')

    model = HierachicalRNN(args=params)
    if params.cuda:
        model = model.cuda()

    # NOTE: the learning rate is taken from params.l2_reg_lambda,
    # mirroring the original parameterization.
    optimizer = torch.optim.Adam(model.parameters(), lr=params.l2_reg_lambda)
    criterion = nn.BCEWithLogitsLoss()

    for epoch in range(1, params.num_epochs + 1):
        total_loss = 0
        for batch in tqdm(batches):
            # Reset the hidden states of the hierarchical attention model
            # at the start of every batch.
            state_word = model.init_hidden_word()
            state_sent = model.init_hidden_sent()
            state_hunk = model.init_hidden_hunk()

            pad_added_code, pad_removed_code, labels = batch
            # Build the label tensor on the configured device so training
            # also works without a GPU (torch.cuda.FloatTensor fails on
            # CPU-only machines).
            labels = torch.tensor(labels, dtype=torch.float,
                                  device=params.device)

            optimizer.zero_grad()
            predict = model(pad_added_code, pad_removed_code,
                            state_hunk, state_sent, state_word)
            loss = criterion(predict, labels)
            loss.backward()
            total_loss += loss.item()
            optimizer.step()

        print('Training: Epoch %i / %i -- Total loss: %f'
              % (epoch, params.num_epochs, total_loss))
        save(model, params.save_dir, 'epoch', epoch)
def train_model(data, params):
    pad_added_code, pad_removed_code, pad_msg_labels, dict_msg, dict_code = data
    batches = mini_batches(X_added_code=pad_added_code,
                           X_removed_code=pad_removed_code,
                           Y=pad_msg_labels,
                           mini_batch_size=params.batch_size)

    # Honor the --no_cuda flag: only use CUDA when it is both requested
    # and available.
    params.cuda = (not params.no_cuda) and torch.cuda.is_available()
    del params.no_cuda

    # Each run gets its own timestamped checkpoint directory.
    params.save_dir = os.path.join(
        params.save_dir, datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))

    params.vocab_code = len(dict_code)
    if len(pad_msg_labels.shape) == 1:
        params.class_num = 1
    else:
        params.class_num = pad_msg_labels.shape[1]

    # Device configuration -- derived from params.cuda rather than from
    # torch.cuda.is_available() alone, so --no_cuda is respected.
    params.device = torch.device('cuda' if params.cuda else 'cpu')

    model = HierachicalRNN(args=params)
    if params.cuda:
        model = model.cuda()

    # NOTE: as in train(), the learning rate is taken from
    # params.l2_reg_lambda, mirroring the original parameterization.
    optimizer = torch.optim.Adam(model.parameters(), lr=params.l2_reg_lambda)
    criterion = nn.BCEWithLogitsLoss()

    for epoch in range(1, params.num_epochs + 1):
        total_loss = 0
        for batch in tqdm(batches):
            # Reset the hidden states of the hierarchical attention model
            # at the start of every batch.
            state_word = model.init_hidden_word()
            state_sent = model.init_hidden_sent()
            state_hunk = model.init_hidden_hunk()

            pad_added_code, pad_removed_code, labels = batch
            labels = torch.tensor(labels, dtype=torch.float,
                                  device=params.device)

            optimizer.zero_grad()
            predict = model(pad_added_code, pad_removed_code,
                            state_hunk, state_sent, state_word)
            loss = criterion(predict, labels)
            loss.backward()
            # Accumulate a Python float; accumulating the loss tensor
            # itself would keep the whole autograd graph alive.
            total_loss += loss.item()
            optimizer.step()

        print('Training: Epoch %i / %i -- Total loss: %f'
              % (epoch, params.num_epochs, total_loss))
        save(model, params.save_dir, 'epoch', epoch)
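
# --- Usage sketch (illustrative only) ---
# A minimal sketch of how train_model() might be driven, assuming the
# preprocessed data was pickled as the 5-tuple unpacked above and that
# `params` is an argparse-style namespace. The file name and every flag
# below are hypothetical placeholders, not part of this module; the model
# will also need its own hyperparameter flags (embedding sizes, etc.).
#
#   import argparse
#
#   parser = argparse.ArgumentParser()
#   parser.add_argument('--batch_size', type=int, default=32)
#   parser.add_argument('--num_epochs', type=int, default=25)
#   parser.add_argument('--l2_reg_lambda', type=float, default=1e-5)
#   parser.add_argument('--save_dir', type=str, default='snapshot')
#   parser.add_argument('--no_cuda', action='store_true')
#   params = parser.parse_args()
#
#   with open('preprocessed_commits.pkl', 'rb') as f:
#       data = pickle.load(f)  # (added, removed, labels, dict_msg, dict_code)
#   train_model(data=data, params=params)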
def predict(self):
    params = self.params
    batches = mini_batches(X_added_code=self.added_code,
                           X_removed_code=self.removed_code,
                           Y=self.msgs,
                           mini_batch_size=params.batch_size,
                           shuffled=False)
    msg_labels_shape = self.msgs.shape

    params.vocab_code = len(self.dict_code)
    if len(msg_labels_shape) == 1:
        params.class_num = 1
    else:
        params.class_num = msg_labels_shape[1]

    # Device configuration
    params.device = torch.device('cuda' if params.cuda else 'cpu')

    model = HierachicalRNN(args=params)
    # map_location lets a checkpoint trained on GPU be loaded on a
    # CPU-only machine as well.
    model.load_state_dict(torch.load(params.load_model,
                                     map_location=params.device))
    if params.cuda:
        model = model.cuda()
    # eval mode (batchnorm uses moving mean/variance instead of
    # mini-batch mean/variance)
    model.eval()

    commit_ftrs = list()
    with torch.no_grad():
        for batch in tqdm(batches):
            # Reset the hidden states of the hierarchical attention model.
            state_word = model.init_hidden_word()
            state_sent = model.init_hidden_sent()
            state_hunk = model.init_hidden_hunk()

            # The labels are not needed for feature extraction.
            pad_added_code, pad_removed_code, _labels = batch
            commit_ftr = model.forward_commit_embeds_diff(
                pad_added_code, pad_removed_code,
                state_hunk, state_sent, state_word)
            commit_ftrs.append(commit_ftr)

    commit_ftrs = torch.cat(commit_ftrs).detach().cpu().numpy()
    with open(params.name, 'wb') as f:
        pickle.dump(commit_ftrs, f)
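
# --- Loading the extracted features (illustrative only) ---
# predict() serializes the concatenated per-batch commit embeddings as a
# numpy array under whatever path was passed as params.name; the file
# name below is a placeholder.
#
#   with open('commit_features.pkl', 'rb') as f:
#       commit_ftrs = pickle.load(f)
#   print(commit_ftrs.shape)  # one embedding row per commit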