Example #1
    def get_meta_mat(self, batch_meta):
        # Embed per-example metadata ids. Variable-length CATEGORIES lists
        # are flattened into one id tensor, with offsets marking where each
        # example's categories begin and end.
        batch_meta_ids = []
        batch_cat_ids = []
        batch_cat_offsets = [0]
        for meta in batch_meta:
            for l in self.META_LABELS:
                if l == "CATEGORIES":
                    batch_cat_ids += meta[l]
                    batch_cat_offsets.append(
                        len(meta[l]) + batch_cat_offsets[-1])
                else:
                    batch_meta_ids.append(meta[l])

        batch_meta_ids = cudify(self.args,
                                Variable(torch.LongTensor(batch_meta_ids)))
        batch_meta_embeds = self.meta_embed(batch_meta_ids).view(
            len(batch_meta),
            len(self.META_LABELS) - 1, -1)

        batch_cat_ids = cudify(self.args,
                               Variable(torch.LongTensor(batch_cat_ids)))
        batch_cat_embeds = self.meta_embed(batch_cat_ids)
        # average each example's category embeddings into a single vector
        batch_avg_cat_embeds = []
        for b_id in range(len(batch_meta)):
            s, e = batch_cat_offsets[b_id], batch_cat_offsets[b_id + 1]
            cat_embeds = batch_cat_embeds[s:e, :]
            avg_embed = cat_embeds.mean(dim=0).unsqueeze(0)
            batch_avg_cat_embeds.append(avg_embed)

        batch_avg_cat_embeds = torch.cat(batch_avg_cat_embeds).unsqueeze(1)
        full_meta = torch.cat([batch_avg_cat_embeds, batch_meta_embeds], dim=1)
        return full_meta
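
Every snippet in this collection calls a cudify helper that none of the excerpts define. Below is a minimal sketch of what it presumably does, assuming (per the args.gpu > -1 device check in Example #7) that args.gpu holds a CUDA device id, with -1 meaning CPU; treat it as a reconstruction, not the repository's actual code.

def cudify(args, x):
    # Hypothetical helper: move a tensor/Variable/module onto the requested
    # GPU, or leave it on the CPU when args.gpu is -1.
    return x.cuda(args.gpu) if args.gpu > -1 else x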
Example #2
    def __init__(self, args):
        self.args = args
        self.dim = self.args.hidden_size
        self.valences = None
        self.hs = None
        self.cs = None
        self.num_push = 0
        self.num_pop = 0

        self.zero_state = (
            cudify(self.args, Variable(torch.zeros(1, self.dim), requires_grad=False)),
            cudify(self.args, Variable(torch.zeros(1, self.dim), requires_grad=False)))
Example #3
    def __init__(self, h_s, c_s, args):
        self.states = list(zip(
            list(torch.split(h_s.squeeze(0), 1, 0)),
            list(torch.split(c_s.squeeze(0), 1, 0))
        ))

        self.args = args

        self.zero_state = (
            cudify(self.args, Variable(torch.zeros(1, self.args.hidden_size), requires_grad=False)),
            cudify(self.args, Variable(torch.zeros(1, self.args.hidden_size), requires_grad=False))
        )
Example #4
    def reduce(self, mass_remaining):
        mass_remaining = cudify(self.args, Variable(torch.FloatTensor([mass_remaining])))
        size = self.size()
        read_mask = cudify(self.args, Variable(torch.zeros(size, 1), requires_grad=False))
        idx = size - 1
        # Walk from the top of the soft stack, spending `mass_remaining`:
        # entries are skipped entirely while more than one unit of mass
        # remains, then the final unit of mass is read into `read_mask`.
        while mass_remaining.data[0] > 0.0 and idx >= 0:
            mass_remaining_data = mass_remaining.data[0]
            this_valence = self.valences[idx].data[0]
            if mass_remaining_data - this_valence >= 1.0:
                # entry lies entirely in the skipped region: consume it
                # without reading
                mass_coeff = self.valences[idx]
            elif mass_remaining_data > 1.0 and mass_remaining_data - this_valence < 1.0:
                # entry straddles the boundary: skip part, read the rest
                skip_mass = mass_remaining - 1.0
                mass_coeff = self.valences[idx] - skip_mass
                read_mask[idx] = mass_coeff
            else:
                # pure read: take whatever is left, capped by the valence
                mass_coeff = torch.min(torch.cat([self.valences[idx], mass_remaining]))
                read_mask[idx] = mass_coeff

            mass_remaining -= mass_coeff
            idx -= 1

        reduced_hs = torch.mul(read_mask, self.hs).sum(0, keepdim=True)
        reduced_cs = torch.mul(read_mask, self.cs).sum(0, keepdim=True)
        return reduced_hs, reduced_cs
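
The three branches above implement a soft pop over a continuous stack: walking from the top down, whole entries are skipped while more than one unit of mass remains, and the final unit of mass is read into read_mask. A minimal plain-Python sketch of the common case mass_remaining == 1.0, where only the last branch ever fires (the valences here are made up):

def soft_pop_mask(valences, mass=1.0):
    # Read `mass` units of stack from the top downward, capping each
    # entry's contribution at its valence (the final branch of reduce()).
    mask = [0.0] * len(valences)
    for idx in range(len(valences) - 1, -1, -1):
        if mass <= 0.0:
            break
        coeff = min(valences[idx], mass)
        mask[idx] = coeff
        mass -= coeff
    return mask

# valences run bottom -> top, so the pop mostly reads the topmost entry
print(soft_pop_mask([0.4, 0.9]))  # [0.1, 0.9]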
Example #5
    def build_biases(self, batch_meta):
        biases = []
        for meta in batch_meta:
            credit_vec = meta['CREDIT_VEC']
            credit_vec_sum = sum(credit_vec)

            for i in range(len(credit_vec)):
                if credit_vec_sum > 0:
                    # rescale the credit vector into a distribution
                    credit_vec[i] /= float(credit_vec_sum)
                else:
                    # all-zero vector: fall back to uniform mass (1/6 per slot)
                    credit_vec[i] = 1.0 / 6.0

            biases.append(
                cudify(
                    self.args,
                    Variable(torch.FloatTensor(credit_vec),
                             requires_grad=False)).unsqueeze(0))
        return torch.cat(biases)
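
The loop above either rescales the credit vector into a distribution or, when it sums to zero, falls back to a uniform 1/6 per slot (which implies six credit classes). A tiny worked example with made-up numbers:

credit_vec = [2.0, 1.0, 1.0, 0.0, 0.0, 0.0]
s = sum(credit_vec)
bias = ([v / s for v in credit_vec] if s > 0
        else [1.0 / 6.0] * len(credit_vec))
print(bias)  # [0.5, 0.25, 0.25, 0.0, 0.0, 0.0]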
Example #6
    def __init__(self, args, vocab):
        super(SNLIClassifier, self).__init__()

        padding_idx = vocab.stoi['<pad>']
        self.args = args
        self.embed = nn.Embedding(len(vocab.stoi),
                                  self.args.embed_dim,
                                  padding_idx=padding_idx)
        self.softmax = nn.Softmax()
        self.relu = nn.ReLU()

        self.layer_norm_mlp_input = LayerNormalization(
            4 * self.args.hidden_size)
        self.layer_norm_mlp1_hidden = LayerNormalization(self.args.snli_h_dim)
        self.layer_norm_mlp2_hidden = LayerNormalization(self.args.snli_h_dim)

        self.dropout = nn.Dropout(p=self.args.dropout_rate_classify)

        self.mlp1 = nn.Linear(4 * self.args.hidden_size, self.args.snli_h_dim)
        HeKaimingInitializer(self.mlp1.weight)
        self.mlp2 = nn.Linear(self.args.snli_h_dim, self.args.snli_h_dim)
        HeKaimingInitializer(self.mlp2.weight)

        self.output = nn.Linear(self.args.snli_h_dim, 3)
        HeKaimingInitializer(self.output.weight)
        self.spinn = SPINN(self.args)

        self.encoder = nn.LSTM(input_size=self.args.embed_dim,
                               hidden_size=self.args.hidden_size // 2,
                               batch_first=True,
                               bidirectional=False,
                               num_layers=1,
                               dropout=self.args.dropout_rate_input)
        self.init_lstm_state = cudify(
            args,
            Variable(torch.zeros(1, 1, self.args.hidden_size // 2),
                     requires_grad=False))
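
The classifier's forward pass is not among these excerpts. The 4 * hidden_size input to mlp1 suggests the standard SNLI matching vector [premise; hypothesis; difference; product], so a hypothetical sketch of how the layers above could compose might look like this (the ordering of norm, ReLU, and dropout is a guess):

    def classify(self, prem_h, hyp_h):
        # hypothetical method, not from the source repository
        feats = torch.cat(
            [prem_h, hyp_h, prem_h - hyp_h, prem_h * hyp_h], dim=1)
        x = self.dropout(self.layer_norm_mlp_input(feats))
        x = self.dropout(self.layer_norm_mlp1_hidden(self.relu(self.mlp1(x))))
        x = self.dropout(self.layer_norm_mlp2_hidden(self.relu(self.mlp2(x))))
        return self.softmax(self.output(x))  # probabilities over the 3 labels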
Example #7
def train(args):
    print("\nStarting...")
    sys.stdout.flush()
    label_names, (train_iter, dev_iter, test_iter,
                  inputs) = prepare_snli_batches(args)
    label_names = label_names[1:]  # don't count UNK
    num_labels = len(label_names)

    print("Prepared Dataset...\n")

    sys.stdout.flush()
    model = SNLIClassifier(args, inputs.vocab)
    model.set_weight(inputs.vocab.vectors.numpy())

    print("Instantiated Model...\n")

    sys.stdout.flush()
    model = cudify(args, model)
    loss = torch.nn.NLLLoss()
    optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=args.lr,
                           betas=(0.9, 0.999),
                           eps=1e-08)
    count_iter = 0
    train_iter.repeat = False

    step = 0
    teacher_prob = 1.0

    for epoch in range(args.epochs):
        epoch_interp = float(args.epochs - epoch) / float(args.epochs)
        args.teach_lambda = (epoch_interp * args.teach_lambda_init) + (
            (1.0 - epoch_interp) * args.teach_lambda_end)
        train_iter.init_epoch()
        cost = 0
        for batch_idx, batch in enumerate(train_iter):
            model.train()
            step += 1
            count_iter += batch.batch_size
            cost += train_batch(
                args, model, loss, optimizer, (batch.hypothesis.transpose(
                    0, 1), batch.hypothesis_transitions.t()),
                (batch.premise.transpose(0, 1), batch.premise_transitions.t()),
                batch.label - 1, step, teacher_prob)

            if count_iter >= args.eval_freq:
                correct, total = 0.0, 0.0
                count_iter = 0
                confusion_matrix = np.zeros([num_labels, num_labels])
                dev_iter.init_epoch()

                for dev_batch_idx, dev_batch in enumerate(dev_iter):
                    model.eval()
                    pred = predict(args, model,
                                   (dev_batch.hypothesis.transpose(0, 1),
                                    dev_batch.hypothesis_transitions.t()),
                                   (dev_batch.premise.transpose(0, 1),
                                    dev_batch.premise_transitions.t()))
                    if args.gpu > -1:
                        true_labels = dev_batch.label.data.cpu().numpy() - 1.0
                    else:
                        true_labels = dev_batch.label.data.numpy() - 1.0
                    for i in range(num_labels):
                        true_labels_by_cat = np.where(true_labels == i)[0]
                        pred_values_by_cat = pred[true_labels_by_cat]
                        num_labels_by_cat = len(true_labels_by_cat)
                        mass_so_far = 0
                        for j in range(num_labels - 1):
                            mass = len(
                                pred_values_by_cat[pred_values_by_cat == j])
                            confusion_matrix[i, j] += mass
                            mass_so_far += mass

                        confusion_matrix[i, num_labels - 1] += (
                            num_labels_by_cat - mass_so_far)

                    total += dev_batch.batch_size
                correct = np.trace(confusion_matrix)
                print("Accuracy for batch #%d, epoch #%d --> %.1f%%\n" %
                      (batch_idx, epoch, float(correct) / total * 100))
                true_label_counts = confusion_matrix.sum(axis=1)
                pred_label_counts = confusion_matrix.sum(axis=0).tolist()
                pred_label_counts = ([str(int(c)) for c in pred_label_counts]
                                     + ["--> guessed distribution"])
                print("\nConfusion matrix (x-axis is true labels)\n")
                # truncate names once for display; don't clobber label_names,
                # which is reused on every evaluation pass
                short_names = [n[:6] + '.' for n in label_names]
                print("\t" + "\t".join(short_names) + "\n")
                for i in range(num_labels):
                    print(short_names[i], end="")
                    for j in range(num_labels):
                        if true_label_counts[i] == 0:
                            perc = 0.0
                        else:
                            perc = confusion_matrix[i, j] / true_label_counts[i]
                        print("\t%.2f%%" % (perc * 100), end="")
                    print("\t(%d examples)\n" % true_label_counts[i])

                print("\t" + "\t".join(pred_label_counts))
                print("")
                sys.stdout.flush()

        teacher_prob *= args.force_decay
        print("Cost for Epoch #%d --> %.2f\n" % (epoch, cost))
        torch.save(model, '../weights/model_%d.pth' % epoch)
Example #8
    def one_valence(self):
        return cudify(self.args, Variable(torch.FloatTensor([1]), requires_grad=False))
Example #9
    def __init__(self, args):
        self.args = args
        self.states = []
        self.dim = args.hidden_size
        self.zero_state = (
            cudify(self.args, Variable(torch.zeros(1, self.dim), requires_grad=False)),
            cudify(self.args, Variable(torch.zeros(1, self.dim), requires_grad=False)))
Example #10
    def forward(self, sentence, transitions, num_ops, other_sent,
                teacher_prob):
        batch_size, sent_len, _ = sentence.size()
        out = self.word(sentence)  # batch, |sent|, h * 2

        # batch normalization and dropout
        if not self.args.no_batch_norm:
            out = out.transpose(1, 2).contiguous()
            # batch, h * 2, |sent| (normalizes batch * |sent| slices
            # for each feature)
            out = self.batch_norm1(out)
            out = out.transpose(1, 2)

        if self.args.dropout_rate_input > 0:
            out = self.dropout(out)  # batch, |sent|, h * 2

        # split into hidden and cell halves:
        # ((batch, |sent|, h), (batch, |sent|, h))
        (h_sent, c_sent) = torch.chunk(out, 2, 2)

        buffer_batch = [
            Buffer(h_s, c_s, self.args)
            for h_s, c_s in zip(list(torch.split(h_sent, 1, 0)),
                                list(torch.split(c_sent, 1, 0)))
        ]

        stack_batch = [create_stack(self.args) for _ in buffer_batch]

        if self.args.tracking:
            self.track.initialize_states(other_sent)
        else:
            assert transitions is not None

        if transitions is None:
            num_transitions = (2 * sent_len) - 1
        else:
            transitions_batch = [
                trans.squeeze(1)
                for trans in list(torch.split(transitions, 1, 1))
            ]
            num_transitions = len(transitions_batch)

        lstm_actions, true_actions = [], []

        for time_stamp in range(num_transitions):
            ops_left = num_transitions - time_stamp

            reduce_ids = []
            reduce_lh, reduce_lc = [], []
            reduce_rh, reduce_rc = [], []
            reduce_valences = []
            reduce_tracking_states = []
            teacher_valences = None
            if self.args.tracking:
                valences, tracking_state = self.update_tracker(
                    buffer_batch, stack_batch, batch_size)
                _, pred_trans = valences.max(dim=1)
                if self.training and self.args.teacher:
                    use_teacher = True  # TODO for now always use teacher - later --> random() < teacher_prob
                    if use_teacher and self.args.continuous_stack:
                        teacher_valences = cudify(
                            self.args,
                            Variable(torch.zeros(valences.size()),
                                     requires_grad=False))

                    temp_trans = transitions_batch[time_stamp]

                    for b_id in range(batch_size):
                        if temp_trans[b_id].data[0] > PAD:
                            true_actions.append(temp_trans[b_id])
                            lstm_actions.append(valences[b_id].unsqueeze(0))

                            if teacher_valences is not None:
                                teacher_valences[
                                    b_id, temp_trans[b_id].data[0]] = 1.0

                    temp_trans = temp_trans.data if use_teacher else pred_trans.data
                else:
                    temp_trans = pred_trans.data
            else:
                valences = None
                temp_trans = transitions_batch[time_stamp].data

            for b_id in range(batch_size):
                stack_size = stack_batch[b_id].size()
                buffer_size = buffer_batch[b_id].size()
                # this sentence is done!
                my_ops_left = num_ops[b_id] - time_stamp
                if my_ops_left <= 0:
                    # should coincide with teacher padding or else num_ops has a bug
                    if self.training and self.args.teacher:
                        assert temp_trans[b_id] == PAD
                    continue
                else:
                    act = temp_trans[b_id]

                    # ensures it's a valid act according to state of buffer, batch, and timestamp
                    # safe check actions if not using teacher forcing... or using teacher forcing but in evaluation
                    if self.args.tracking and (not self.args.teacher or
                                               (self.args.teacher
                                                and not self.training)):
                        act, act_ignored = self.resolve_action(
                            buffer_batch[b_id], stack_batch[b_id], buffer_size,
                            stack_size, act, time_stamp, my_ops_left)

                if self.args.tracking:
                    # use teacher valences over predicted valences
                    if teacher_valences is not None:
                        reduce_valence, shift_valence = teacher_valences[b_id]
                    else:
                        reduce_valence, shift_valence = valences[b_id]
                else:
                    reduce_valence, shift_valence = None, None

                no_action = True

                # 2 - REDUCE
                if act == REDUCE or (self.args.continuous_stack
                                     and not self.args.teacher
                                     and stack_size >= 2):
                    no_action = False
                    reduce_ids.append(b_id)

                    r = stack_batch[b_id].peek()
                    if not stack_batch[b_id].pop(reduce_valence):
                        print(sentence[b_id, :, :].sum(dim=1),
                              transitions[b_id, :])
                        raise Exception("Tried to pop from an empty list.")

                    l = stack_batch[b_id].peek()
                    if not stack_batch[b_id].pop(reduce_valence):
                        print(sentence[b_id, :, :].sum(dim=1),
                              transitions[b_id, :])
                        raise Exception("Tried to pop from an empty list.")

                    reduce_lh.append(l[0])
                    reduce_lc.append(l[1])
                    reduce_rh.append(r[0])
                    reduce_rc.append(r[1])

                    if self.args.tracking:
                        reduce_valences.append(reduce_valence)
                        reduce_tracking_states.append(
                            tracking_state[b_id].unsqueeze(0))

                # 3 - SHIFT
                if act == SHIFT or (self.args.continuous_stack
                                    and not self.args.teacher
                                    and buffer_size > 0):
                    no_action = False
                    word = buffer_batch[b_id].pop()
                    stack_batch[b_id].add(word, shift_valence, time_stamp)

                if no_action:
                    print("\n\nWarning: no action was taken; look for a bug! "
                          "Attempted action %d but it was denied." % act)

            if len(reduce_ids) > 0:
                h_lefts = torch.cat(reduce_lh)
                c_lefts = torch.cat(reduce_lc)
                h_rights = torch.cat(reduce_rh)
                c_rights = torch.cat(reduce_rc)

                if self.args.tracking:
                    e_out = torch.cat(reduce_tracking_states)
                    h_outs, c_outs = self.reduce((h_lefts, c_lefts),
                                                 (h_rights, c_rights), e_out)
                else:
                    h_outs, c_outs = self.reduce((h_lefts, c_lefts),
                                                 (h_rights, c_rights))

                for i, state in enumerate(zip(h_outs, c_outs)):
                    reduce_valence = (reduce_valences[i]
                                      if self.args.tracking else None)
                    stack_batch[reduce_ids[i]].add(state, reduce_valence)

        outputs = []
        for (i, stack) in enumerate(stack_batch):
            if not self.args.continuous_stack:
                assert stack.size() == 1, (
                    "Stack size is %d; should be 1" % stack.size())
            top_h = stack.peek()[0]
            outputs.append(top_h)

        if len(true_actions) > 0 and self.training:
            return torch.cat(outputs), torch.cat(true_actions), torch.log(
                torch.cat(lstm_actions))
        return torch.cat(outputs), None, None
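
In training mode with tracking and teacher signals enabled, forward returns the stack tops together with the gold transitions and the log of the tracker's predicted valences. A minimal sketch (variable names illustrative) of how the NLLLoss built in Example #7 could supervise those shift/reduce decisions:

outputs, true_actions, log_action_probs = spinn(sentence, transitions,
                                                num_ops, other_sent,
                                                teacher_prob)
if true_actions is not None:
    # log_action_probs is (num_decisions, 2); true_actions holds the matching
    # gold action ids, so NLLLoss applies directly as an auxiliary loss.
    action_loss = torch.nn.NLLLoss()(log_action_probs, true_actions)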