def get_meta_mat(self, batch_meta):
    batch_meta_ids = []
    batch_cat_ids = []
    batch_cat_offsets = [0]
    for meta in batch_meta:
        for l in self.META_LABELS:
            if l == "CATEGORIES":
                # variable-length category lists are flattened into one id
                # list; offsets record where each example's slice starts/ends
                batch_cat_ids += meta[l]
                batch_cat_offsets.append(len(meta[l]) + batch_cat_offsets[-1])
            else:
                batch_meta_ids.append(meta[l])

    batch_meta_ids = cudify(self.args, Variable(torch.LongTensor(batch_meta_ids)))
    batch_meta_embeds = self.meta_embed(batch_meta_ids).view(
        len(batch_meta), len(self.META_LABELS) - 1, -1)

    batch_cat_ids = cudify(self.args, Variable(torch.LongTensor(batch_cat_ids)))
    batch_cat_embeds = self.meta_embed(batch_cat_ids)

    # mean-pool each example's category embeddings into a single vector
    batch_avg_cat_embeds = []
    for b_id in range(len(batch_meta)):
        s, e = batch_cat_offsets[b_id], batch_cat_offsets[b_id + 1]
        cat_embeds = batch_cat_embeds[s:e, :]
        avg_embed = cat_embeds.mean(dim=0).unsqueeze(0)
        batch_avg_cat_embeds.append(avg_embed)
    batch_avg_cat_embeds = torch.cat(batch_avg_cat_embeds).unsqueeze(1)

    full_meta = torch.cat([batch_avg_cat_embeds, batch_meta_embeds], dim=1)
    return full_meta
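# A minimal sketch of the batch_meta layout get_meta_mat assumes. Only
# "CATEGORIES" is confirmed by the code above; the other field names are
# hypothetical stand-ins for whatever self.META_LABELS contains. All values
# are ids into the meta_embed vocabulary; CATEGORIES is the one
# variable-length field.
example_batch_meta = [
    {"CATEGORIES": [3, 17, 42], "AUTHOR": 5, "SOURCE": 1},
    {"CATEGORIES": [8], "AUTHOR": 2, "SOURCE": 0},
]
# The category lists are flattened to [3, 17, 42, 8] with offsets [0, 3, 4],
# embedded, then mean-pooled back to one vector per example.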
def __init__(self, args):
    self.args = args
    self.dim = self.args.hidden_size
    self.valences = None
    self.hs = None
    self.cs = None
    self.num_push = 0
    self.num_pop = 0
    self.zero_state = (
        cudify(self.args, Variable(torch.zeros(1, self.dim), requires_grad=False)),
        cudify(self.args, Variable(torch.zeros(1, self.dim), requires_grad=False)))
def __init__(self, h_s, c_s, args):
    # pair up the per-token hidden and cell states so the buffer can be
    # popped one (h, c) tuple at a time
    self.states = list(zip(
        list(torch.split(h_s.squeeze(0), 1, 0)),
        list(torch.split(c_s.squeeze(0), 1, 0))))
    self.args = args
    self.zero_state = (
        cudify(self.args, Variable(torch.zeros(1, self.args.hidden_size), requires_grad=False)),
        cudify(self.args, Variable(torch.zeros(1, self.args.hidden_size), requires_grad=False)))
def reduce(self, mass_remaining):
    mass_remaining = cudify(self.args, Variable(torch.FloatTensor([mass_remaining])))
    size = self.size()
    read_mask = cudify(self.args, Variable(torch.zeros(size, 1), requires_grad=False))
    idx = size - 1

    # walk down from the top of the soft stack: the first unit of mass is
    # skipped, the next unit is read into read_mask
    while mass_remaining.data[0] > 0.0 and idx >= 0:
        mass_remaining_data = mass_remaining.data[0]
        this_valence = self.valences[idx].data[0]
        if mass_remaining_data - this_valence >= 1.0:
            # entirely inside the skipped region: consume but do not read
            mass_coeff = self.valences[idx]
        elif mass_remaining_data > 1.0 and mass_remaining_data - this_valence < 1.0:
            # straddles the skip boundary: read only the part below it
            skip_mass = mass_remaining - 1.0
            mass_coeff = self.valences[idx] - skip_mass
            read_mask[idx] = mass_coeff
        else:
            # fully inside the read region
            mass_coeff = torch.min(torch.cat([self.valences[idx], mass_remaining]))
            read_mask[idx] = mass_coeff
        mass_remaining -= mass_coeff
        idx -= 1

    reduced_hs = torch.mul(read_mask, self.hs).sum(0, keepdim=True)
    reduced_cs = torch.mul(read_mask, self.cs).sum(0, keepdim=True)
    return reduced_hs, reduced_cs
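# Worked example of the soft reduce above (values illustrative). With
# valences [1.0, 1.0] (index 1 on top) and mass_remaining = 2.0:
#   idx 1: 2.0 - 1.0 >= 1.0, so the whole entry falls in the skipped
#          region; it is consumed but not read (remaining -> 1.0)
#   idx 0: min(1.0, 1.0) = 1.0 is read into read_mask (remaining -> 0.0)
# The returned (h, c) is therefore the state one unit of mass below the
# top, i.e. a soft version of reading the second stack element.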
def build_biases(self, batch_meta):
    biases = []
    for meta in batch_meta:
        credit_vec = meta['CREDIT_VEC']
        credit_vec_sum = sum(credit_vec)
        for i in range(len(credit_vec)):
            if credit_vec_sum > 0:
                credit_vec[i] /= float(credit_vec_sum)
            else:
                # no credit mass at all: fall back to a uniform prior
                # (hard-coded for six classes)
                credit_vec[i] = 1.0 / 6.0
        biases.append(cudify(
            self.args,
            Variable(torch.FloatTensor(credit_vec), requires_grad=False)).unsqueeze(0))
    return torch.cat(biases)
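# Example of the CREDIT_VEC normalization above (values illustrative):
#   [2, 1, 1, 0, 0, 0] -> [0.5, 0.25, 0.25, 0.0, 0.0, 0.0]
#   [0, 0, 0, 0, 0, 0] -> uniform [1/6] * 6
# Note the 1.0 / 6.0 fallback hard-codes a six-way split, so it silently
# assumes len(credit_vec) == 6.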
def __init__(self, args, vocab): super(SNLIClassifier, self).__init__() padding_idx = vocab.stoi['<pad>'] self.args = args self.embed = nn.Embedding(len(vocab.stoi), self.args.embed_dim, padding_idx=padding_idx) self.softmax = nn.Softmax() self.relu = nn.ReLU() self.layer_norm_mlp_input = LayerNormalization(4 * self.args.hidden_size) self.layer_norm_mlp1_hidden = LayerNormalization(self.args.snli_h_dim) self.layer_norm_mlp2_hidden = LayerNormalization(self.args.snli_h_dim) self.dropout = nn.Dropout(p=self.args.dropout_rate_classify) self.mlp1 = nn.Linear(4 * self.args.hidden_size, self.args.snli_h_dim) HeKaimingInitializer(self.mlp1.weight) self.mlp2 = nn.Linear(self.args.snli_h_dim, self.args.snli_h_dim) HeKaimingInitializer(self.mlp2.weight) self.output = nn.Linear(self.args.snli_h_dim, 3) HeKaimingInitializer(self.output.weight) self.spinn = SPINN(self.args) self.encoder = nn.LSTM(input_size=self.args.embed_dim, hidden_size=self.args.hidden_size // 2, batch_first=True, bidirectional=False, num_layers=1, dropout=self.args.dropout_rate_input) self.init_lstm_state = cudify( args, Variable(torch.zeros(1, 1, self.args.hidden_size // 2), requires_grad=False))
def train(args):
    print("\nStarting...")
    sys.stdout.flush()

    label_names, (train_iter, dev_iter, test_iter, inputs) = prepare_snli_batches(args)
    label_names = label_names[1:]  # don't count UNK
    num_labels = len(label_names)
    print("Prepared Dataset...\n")
    sys.stdout.flush()

    model = SNLIClassifier(args, inputs.vocab)
    model.set_weight(inputs.vocab.vectors.numpy())
    print("Instantiated Model...\n")
    sys.stdout.flush()

    model = cudify(args, model)
    loss = torch.nn.NLLLoss()
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                           lr=args.lr, betas=(0.9, 0.999), eps=1e-08)

    count_iter = 0
    train_iter.repeat = False
    step = 0
    teacher_prob = 1.0

    for epoch in range(args.epochs):
        # anneal the teacher-forcing weight linearly from its initial to
        # its final value over the course of training
        epoch_interp = float(args.epochs - epoch) / float(args.epochs)
        args.teach_lambda = (epoch_interp * args.teach_lambda_init) + \
            ((1.0 - epoch_interp) * args.teach_lambda_end)
        train_iter.init_epoch()
        cost = 0

        for batch_idx, batch in enumerate(train_iter):
            model.train()
            step += 1
            count_iter += batch.batch_size
            cost += train_batch(
                args, model, loss, optimizer,
                (batch.hypothesis.transpose(0, 1), batch.hypothesis_transitions.t()),
                (batch.premise.transpose(0, 1), batch.premise_transitions.t()),
                batch.label - 1, step, teacher_prob)

            if count_iter >= args.eval_freq:
                correct, total = 0.0, 0.0
                count_iter = 0
                confusion_matrix = np.zeros([num_labels, num_labels])
                dev_iter.init_epoch()

                for dev_batch_idx, dev_batch in enumerate(dev_iter):
                    model.eval()
                    pred = predict(
                        args, model,
                        (dev_batch.hypothesis.transpose(0, 1),
                         dev_batch.hypothesis_transitions.t()),
                        (dev_batch.premise.transpose(0, 1),
                         dev_batch.premise_transitions.t()))
                    if args.gpu > -1:
                        true_labels = dev_batch.label.data.cpu().numpy() - 1.0
                    else:
                        true_labels = dev_batch.label.data.numpy() - 1.0

                    # accumulate the confusion matrix row by true label; any
                    # prediction mass not assigned to the first
                    # num_labels - 1 columns lands in the last one
                    for i in range(num_labels):
                        true_labels_by_cat = np.where(true_labels == i)[0]
                        pred_values_by_cat = pred[true_labels_by_cat]
                        num_labels_by_cat = len(true_labels_by_cat)
                        mass_so_far = 0
                        for j in range(num_labels - 1):
                            mass = len(pred_values_by_cat[pred_values_by_cat == j])
                            confusion_matrix[i, j] += mass
                            mass_so_far += mass
                        confusion_matrix[i, num_labels - 1] += \
                            num_labels_by_cat - mass_so_far
                    total += dev_batch.batch_size

                correct = np.trace(confusion_matrix)
                print("Accuracy for batch #%d, epoch #%d --> %.1f%%\n" %
                      (batch_idx, epoch, float(correct) / total * 100))

                true_label_counts = confusion_matrix.sum(axis=1)
                pred_label_counts = confusion_matrix.sum(axis=0).tolist()
                pred_label_counts = [str(int(c)) for c in pred_label_counts] + \
                    ["--> guessed distribution"]

                print("\nConfusion matrix (x-axis is true labels)\n")
                label_names = [n[0:6] + '.' for n in label_names]
                print("\t" + "\t".join(label_names) + "\n")
                for i in range(num_labels):
                    print(label_names[i], end="")
                    for j in range(num_labels):
                        if true_label_counts[i] == 0:
                            perc = 0.0
                        else:
                            perc = confusion_matrix[i, j] / true_label_counts[i]
                        print("\t%.2f%%" % (perc * 100), end="")
                    print("\t(%d examples)\n" % true_label_counts[i])
                print("\t" + "\t".join(pred_label_counts))
                print("")
                sys.stdout.flush()

        teacher_prob *= args.force_decay
        print("Cost for Epoch #%d --> %.2f\n" % (epoch, cost))
        torch.save(model, '../weights/model_%d.pth' % epoch)
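# Standalone sketch of the confusion-matrix bookkeeping used in the eval
# block above, assuming 1-D numpy arrays of integer class ids (this helper
# is illustrative and not part of the training code):
def build_confusion(true_labels, pred, num_labels):
    mat = np.zeros([num_labels, num_labels])
    for i in range(num_labels):
        rows = np.where(true_labels == i)[0]
        preds_for_i = pred[rows]
        for j in range(num_labels - 1):
            mat[i, j] += len(preds_for_i[preds_for_i == j])
        # as above, leftover mass lands in the last column
        mat[i, num_labels - 1] += len(rows) - mat[i, :num_labels - 1].sum()
    return mat
# accuracy is then np.trace(mat) / mat.sum(), matching the print above.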
def one_valence(self):
    return cudify(self.args, Variable(torch.FloatTensor([1]), requires_grad=False))
def __init__(self, args):
    self.args = args
    self.states = []
    self.dim = args.hidden_size
    self.zero_state = (
        cudify(self.args, Variable(torch.zeros(1, self.dim), requires_grad=False)),
        cudify(self.args, Variable(torch.zeros(1, self.dim), requires_grad=False)))
def forward(self, sentence, transitions, num_ops, other_sent, teacher_prob):
    batch_size, sent_len, _ = sentence.size()
    out = self.word(sentence)  # batch, |sent|, h * 2

    # batch normalization and dropout
    if not self.args.no_batch_norm:
        out = out.transpose(1, 2).contiguous()
        # batch, h * 2, |sent| (normalizes batch * |sent| slices per feature)
        out = self.batch_norm1(out)
        out = out.transpose(1, 2)

    if self.args.dropout_rate_input > 0:
        out = self.dropout(out)  # batch, |sent|, h * 2

    (h_sent, c_sent) = torch.chunk(out, 2, 2)  # ((batch, |sent|, h), (batch, |sent|, h))

    buffer_batch = [
        Buffer(h_s, c_s, self.args)
        for h_s, c_s in zip(list(torch.split(h_sent, 1, 0)),
                            list(torch.split(c_sent, 1, 0)))
    ]
    stack_batch = [create_stack(self.args) for _ in buffer_batch]

    if self.args.tracking:
        self.track.initialize_states(other_sent)
    else:
        assert transitions is not None

    if transitions is None:
        num_transitions = (2 * sent_len) - 1
    else:
        transitions_batch = [
            trans.squeeze(1) for trans in list(torch.split(transitions, 1, 1))
        ]
        num_transitions = len(transitions_batch)

    lstm_actions, true_actions = [], []

    for time_stamp in range(num_transitions):
        ops_left = num_transitions - time_stamp

        reduce_ids = []
        reduce_lh, reduce_lc = [], []
        reduce_rh, reduce_rc = [], []
        reduce_valences = []
        reduce_tracking_states = []
        teacher_valences = None

        if self.args.tracking:
            valences, tracking_state = self.update_tracker(
                buffer_batch, stack_batch, batch_size)
            _, pred_trans = valences.max(dim=1)

            if self.training and self.args.teacher:
                use_teacher = True  # TODO for now always use teacher - later --> random() < teacher_prob
                if use_teacher and self.args.continuous_stack:
                    teacher_valences = cudify(
                        self.args,
                        Variable(torch.zeros(valences.size()), requires_grad=False))

                temp_trans = transitions_batch[time_stamp]
                for b_id in range(batch_size):
                    if temp_trans[b_id].data[0] > PAD:
                        true_actions.append(temp_trans[b_id])
                        lstm_actions.append(valences[b_id].unsqueeze(0))
                        if teacher_valences is not None:
                            teacher_valences[b_id, temp_trans[b_id].data[0]] = 1.0
                temp_trans = temp_trans.data if use_teacher else pred_trans.data
            else:
                temp_trans = pred_trans.data
        else:
            valences = None
            temp_trans = transitions_batch[time_stamp].data

        for b_id in range(batch_size):
            stack_size, buffer_size = stack_batch[b_id].size(), buffer_batch[b_id].size()

            # this sentence is done!
            my_ops_left = num_ops[b_id] - time_stamp
            if my_ops_left <= 0:
                # should coincide with teacher padding or else num_ops has a bug
                if self.training and self.args.teacher:
                    assert temp_trans[b_id] == PAD
                continue
            else:
                act = temp_trans[b_id]
                # ensure it's a valid act given the state of the buffer and
                # stack at this timestamp; safe-check actions when not using
                # teacher forcing, or when using teacher forcing but in evaluation
                if self.args.tracking and (not self.args.teacher or
                                           (self.args.teacher and not self.training)):
                    act, act_ignored = self.resolve_action(
                        buffer_batch[b_id], stack_batch[b_id],
                        buffer_size, stack_size, act, time_stamp, my_ops_left)

            if self.args.tracking:
                # use teacher valences over predicted valences
                if teacher_valences is not None:
                    reduce_valence, shift_valence = teacher_valences[b_id]
                else:
                    reduce_valence, shift_valence = valences[b_id]
            else:
                reduce_valence, shift_valence = None, None

            no_action = True

            # 2 - REDUCE
            if act == REDUCE or (self.args.continuous_stack and
                                 not self.args.teacher and stack_size >= 2):
                no_action = False
                reduce_ids.append(b_id)

                r = stack_batch[b_id].peek()
                if not stack_batch[b_id].pop(reduce_valence):
                    print(sentence[b_id, :, :].sum(dim=1), transitions[b_id, :])
                    raise Exception("Tried to pop from an empty list.")

                l = stack_batch[b_id].peek()
                if not stack_batch[b_id].pop(reduce_valence):
                    print(sentence[b_id, :, :].sum(dim=1), transitions[b_id, :])
                    raise Exception("Tried to pop from an empty list.")

                reduce_lh.append(l[0])
                reduce_lc.append(l[1])
                reduce_rh.append(r[0])
                reduce_rc.append(r[1])

                if self.args.tracking:
                    reduce_valences.append(reduce_valence)
                    reduce_tracking_states.append(tracking_state[b_id].unsqueeze(0))

            # 3 - SHIFT
            if act == SHIFT or (self.args.continuous_stack and
                                not self.args.teacher and buffer_size > 0):
                no_action = False
                word = buffer_batch[b_id].pop()
                stack_batch[b_id].add(word, shift_valence, time_stamp)

            if no_action:
                print("\n\nWarning: Didn't choose an action. Look for a bug! "
                      "Attempted %d action but was denied!" % act)

        # batch all pending REDUCEs through the composition function at once
        if len(reduce_ids) > 0:
            h_lefts = torch.cat(reduce_lh)
            c_lefts = torch.cat(reduce_lc)
            h_rights = torch.cat(reduce_rh)
            c_rights = torch.cat(reduce_rc)

            if self.args.tracking:
                e_out = torch.cat(reduce_tracking_states)
                h_outs, c_outs = self.reduce((h_lefts, c_lefts),
                                             (h_rights, c_rights), e_out)
            else:
                h_outs, c_outs = self.reduce((h_lefts, c_lefts),
                                             (h_rights, c_rights))

            for i, state in enumerate(zip(h_outs, c_outs)):
                reduce_valence = reduce_valences[i] if self.args.tracking else None
                stack_batch[reduce_ids[i]].add(state, reduce_valence)

    outputs = []
    for (i, stack) in enumerate(stack_batch):
        if not self.args.continuous_stack:
            if not stack.size() == 1:
                print("Stack size is %d. Should be 1" % stack.size())
            assert stack.size() == 1
        top_h = stack.peek()[0]
        outputs.append(top_h)

    if len(true_actions) > 0 and self.training:
        return torch.cat(outputs), torch.cat(true_actions), torch.log(
            torch.cat(lstm_actions))
    return torch.cat(outputs), None, None
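# Hypothetical call sketch for forward (shapes inferred from the code
# above, so treat them as assumptions):
#   sentence:    (batch, |sent|, embed) token states fed to self.word
#   transitions: (batch, 2 * |sent| - 1) SHIFT/REDUCE/PAD ids, or None
#                when the tracker predicts its own actions
#   num_ops:     per-example count of real (non-PAD) transitions
# hiddens, true_acts, log_pred_acts = spinn(
#     sentence, transitions, num_ops, other_sent, teacher_prob)
# In training with a tracker, true_acts and log_pred_acts can feed an
# auxiliary action-prediction loss (e.g. NLLLoss); otherwise both are None.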