Example #1

# NOTE: these examples are excerpts from a larger package; the standard-library
# and third-party imports below are added for completeness, while helpers such as
# get_data_and_args, get_model, transform, train_logreg, get_top_k_neuron_weights,
# finetune, generate_outputs, plot_logits, etc. are assumed to come from sibling modules.
import os
import sys
import time
import random
import pickle as pkl

import numpy as np
import torch
from tqdm import tqdm

def main():
    (train_data, val_data, test_data), tokenizer, args = get_data_and_args()
    model = get_model(args)

    save_root = '' if args.load is None else args.load
    save_root = save_root.replace('.current', '')
    save_root = os.path.splitext(save_root)[0]
    save_root += '_transfer'
    save_root = os.path.join(save_root, args.save_results)
    if not os.path.exists(save_root):
        os.makedirs(save_root)
    print('writing results to ' + save_root)

    # featurize train, val, test or use previously cached features if possible
    print('transforming train')
    if not (os.path.exists(os.path.join(save_root, 'trXt.npy'))
            and args.use_cached):
        trXt, trY = transform(model, train_data, args)
        np.save(os.path.join(save_root, 'trXt'), trXt)
        np.save(os.path.join(save_root, 'trY'), trY)
    else:
        trXt = np.load(os.path.join(save_root, 'trXt.npy'))
        trY = np.load(os.path.join(save_root, 'trY.npy'))
    vaXt, vaY = None, None
    if val_data is not None:
        print('transforming validation')
        if not (os.path.exists(os.path.join(save_root, 'vaXt.npy'))
                and args.use_cached):
            vaXt, vaY = transform(model, val_data, args)
            np.save(os.path.join(save_root, 'vaXt'), vaXt)
            np.save(os.path.join(save_root, 'vaY'), vaY)
        else:
            vaXt = np.load(os.path.join(save_root, 'vaXt.npy'))
            vaY = np.load(os.path.join(save_root, 'vaY.npy'))
    teXt, teY = None, None
    if test_data is not None:
        print('transforming test')
        if not (os.path.exists(os.path.join(save_root, 'teXt.npy'))
                and args.use_cached):
            teXt, teY = transform(model, test_data, args)
            np.save(os.path.join(save_root, 'teXt'), teXt)
            np.save(os.path.join(save_root, 'teY'), teY)
        else:
            teXt = np.load(os.path.join(save_root, 'teXt.npy'))
            teY = np.load(os.path.join(save_root, 'teY.npy'))

    # train logistic regression model of featurized text against labels
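    # train_logreg is assumed to fit a (regularized) logistic regression on the
    # extracted features and return the fitted model, its train/val/test scores,
    # predicted probabilities, the regularization coefficient it settled on, and
    # the number of nonzero weights (inferred from the values unpacked below).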
    start = time.time()
    metric = 'mcc' if args.mcc else 'acc'
    logreg_model, logreg_scores, logreg_probs, c, nnotzero = train_logreg(
        trXt,
        trY,
        vaXt,
        vaY,
        teXt,
        teY,
        max_iter=args.epochs,
        eval_test=not args.no_test_eval,
        seed=args.seed,
        report_metric=metric,
        threshold_metric=metric)
    end = time.time()
    elapsed_time = end - start

    with open(os.path.join(save_root, 'all_neurons_score.txt'), 'w') as f:
        f.write(str(logreg_scores))
    with open(os.path.join(save_root, 'all_neurons_probs.pkl'), 'wb') as f:
        pkl.dump(logreg_probs, f)
    with open(os.path.join(save_root, 'neurons.pkl'), 'wb') as f:
        pkl.dump(logreg_model.coef_, f)

    print('all neuron regression took %s seconds' % (str(elapsed_time)))
    print(', '.join([str(score) for score in logreg_scores]),
          'train, val, test accuracy for all neuron regression')
    print(str(c) + ' regularization coefficient used')
    print(str(nnotzero) + ' features used in all neuron regression\n')

    # save a sentiment classification pytorch model
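    # The saved checkpoint pairs the language-model encoder with the logistic
    # regression weights/bias as the classifier head; both an fp32 copy
    # (classifier.pt) and an fp16 copy (classifier.pt.16) are written below.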
    sd = {}
    if not args.fp16:
        clf_sd = {
            'weight': torch.from_numpy(logreg_model.coef_).float(),
            'bias': torch.from_numpy(logreg_model.intercept_).float()
        }
    else:
        clf_sd = {
            'weight': torch.from_numpy(logreg_model.coef_).half(),
            'bias': torch.from_numpy(logreg_model.intercept_).half()
        }
    sd['classifier'] = clf_sd
    model.float().cpu()
    sd['lm_encoder'] = model.state_dict()
    with open(os.path.join(save_root, 'classifier.pt'), 'wb') as f:
        torch.save(sd, f)
    model.half()
    sd['lm_encoder'] = model.state_dict()
    with open(os.path.join(save_root, 'classifier.pt.16'), 'wb') as f:
        torch.save(sd, f)

    # extract sentiment neuron indices
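    # get_top_k_neuron_weights presumably returns the indices of the k
    # logistic-regression weights with the largest magnitude -- the candidate
    # "sentiment neurons" (an assumption; a minimal sketch appears at the end
    # of this example).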
    sentiment_neurons = get_top_k_neuron_weights(logreg_model, args.neurons)
    print('using neuron(s) %s as features for regression' % (', '.join(
        [str(neuron) for neuron in list(sentiment_neurons.reshape(-1))])))

    # train logistic regression model of features corresponding to sentiment neuron indices against labels
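    # When args.drop_neurons is set, this first fit excludes the selected
    # neurons (to measure how much signal remains without them) and a second
    # fit below retrains on the selected neurons alone; otherwise this fit
    # itself uses only the selected neurons.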
    start = time.time()
    logreg_neuron_model, logreg_neuron_scores, logreg_neuron_probs, neuron_c, neuron_nnotzero = train_logreg(
        trXt,
        trY,
        vaXt,
        vaY,
        teXt,
        teY,
        max_iter=args.epochs,
        eval_test=not args.no_test_eval,
        seed=args.seed,
        neurons=sentiment_neurons,
        drop_neurons=args.drop_neurons,
        report_metric=metric,
        threshold_metric=metric)
    end = time.time()

    if args.drop_neurons:
        with open(os.path.join(save_root, 'dropped_neurons_score.txt'),
                  'w') as f:
            f.write(str(logreg_neuron_scores))

        with open(os.path.join(save_root, 'dropped_neurons_probs.pkl'),
                  'wb') as f:
            pkl.dump(logreg_neuron_probs, f)

        print('%d dropped neuron regression took %s seconds' %
              (args.neurons, str(end - start)))
        print(
            ', '.join([str(score) for score in logreg_neuron_scores]),
            'train, val, test accuracy for %d dropped neuron regression' %
            (args.neurons))
        print(str(neuron_c) + ' regularization coefficient used')

        start = time.time()
        logreg_neuron_model, logreg_neuron_scores, logreg_neuron_probs, neuron_c, neuron_nnotzero = train_logreg(
            trXt,
            trY,
            vaXt,
            vaY,
            teXt,
            teY,
            max_iter=args.epochs,
            eval_test=not args.no_test_eval,
            seed=args.seed,
            neurons=sentiment_neurons,
            report_metric=metric,
            threshold_metric=metric)
        end = time.time()

    print('%d neuron regression took %s seconds' %
          (args.neurons, str(end - start)))
    print(
        ', '.join([str(score) for score in logreg_neuron_scores]),
        'train, val, test accuracy for %d neuron regression' % (args.neurons))
    print(str(neuron_c) + ' regularization coefficient used')

    # log model accuracies, predicted probabilities, and weight/bias of regression model

    with open(os.path.join(save_root, 'all_neurons_score.txt'), 'w') as f:
        f.write(str(logreg_scores))

    with open(os.path.join(save_root, 'neurons_score.txt'), 'w') as f:
        f.write(str(logreg_neuron_scores))

    with open(os.path.join(save_root, 'all_neurons_probs.pkl'), 'wb') as f:
        pkl.dump(logreg_probs, f)

    with open(os.path.join(save_root, 'neurons_probs.pkl'), 'wb') as f:
        pkl.dump(logreg_neuron_probs, f)

    with open(os.path.join(save_root, 'neurons.pkl'), 'wb') as f:
        pkl.dump(logreg_model.coef_, f)

    with open(os.path.join(save_root, 'neuron_bias.pkl'), 'wb') as f:
        pkl.dump(logreg_model.intercept_, f)

    # Plot features
    use_feats, use_labels = teXt, teY
    if use_feats is None:
        use_feats, use_labels = vaXt, vaY
    if use_feats is None:
        use_feats, use_labels = trXt, trY
    try:
        plot_logits(save_root, use_feats, use_labels, sentiment_neurons)
    except Exception:
        print('no labels to plot logits for')

    plot_weight_contribs_and_save(logreg_model.coef_,
                                  os.path.join(save_root, 'weight_vis.png'))

    print('results successfully written to ' + save_root)
    if args.write_results == '':
        exit()

    def get_csv_writer(feats, top_neurons, all_proba, neuron_proba):
        """makes a generator to be used in data_utils.datasets.csv_dataset.write()"""
        header = ['prob w/ all', 'prob w/ %d neuron(s)' % (len(top_neurons), )]
        top_feats = feats[:, top_neurons]
        header += ['neuron %s' % (str(x), ) for x in top_neurons]

        yield header

        for i, _ in enumerate(top_feats):
            row = []
            row.append(all_proba[i])
            row.append(neuron_proba[i])
            row.extend(list(top_feats[i].reshape(-1)))
            yield row

    data, use_feats = test_data, teXt
    if use_feats is None:
        data, use_feats = val_data, vaXt
    if use_feats is None:
        data, use_feats = train_data, trXt
    csv_writer = get_csv_writer(use_feats, sentiment_neurons, logreg_probs[-1],
                                logreg_neuron_probs[-1])
    data.dataset.write(csv_writer, path=args.write_results)
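
# A minimal, hedged sketch (an assumption, not the original get_top_k_neuron_weights
# helper): pick the indices of the k logistic-regression coefficients with the
# largest magnitude -- the candidate "sentiment neurons".
def _top_k_neurons_sketch(coef, k=1):
    weights = np.abs(np.asarray(coef).reshape(-1))
    if k == 1:
        # a single argmax suffices for one neuron
        return np.array([int(weights.argmax())])
    # argpartition finds the k largest-|w| indices without a full sort
    return np.argpartition(weights, -k)[-k:]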


Example #2
def main():
    (train_data, val_data, test_data), tokenizer, args = get_data_and_args()
    # Print args for logging & reproduction (including defaults), so runs can be reproduced exactly.
    if test_data is None:
        test_data = val_data
    model, optim, LR = get_model_and_optim(args, train_data)

    # save_root = '' if args.load is None else args.load
    # save_root = save_root.replace('.current', '')
    # save_root = os.path.splitext(save_root)[0]
    # save_root += '_transfer'
    save_root = os.path.join('', args.model_version_name)
    if not os.path.exists(save_root):
        os.makedirs(save_root)
    print('writing results to ' + save_root)

    def clf_reg_loss(reg_penalty=.125, order=1):
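        # L1 penalty on the classifier head's weights (sum of absolute values
        # scaled by reg_penalty); the `order` argument is accepted but unused here.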
        loss = 0
        for p in model.classifier.parameters():
            loss += torch.abs(p).sum() * reg_penalty
        return loss

    reg_loss = clf_reg_loss
    init_params = list(model.lm_encoder.parameters())

    if args.use_logreg:

        def transform_for_logreg(model, data, args, desc='train'):
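            # Run the language model (in eval mode) over each batch and collect
            # the hidden state returned by transform() as the feature vector,
            # paired with the batch labels, for logistic regression.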
            if data is None:
                return None, None

            X_out = []
            Y_out = []
            for i, batch in tqdm(enumerate(data),
                                 total=len(data),
                                 unit="batch",
                                 desc=desc,
                                 position=0,
                                 ncols=100):
                text_batch, labels_batch, length_batch = get_supervised_batch(
                    batch,
                    args.cuda,
                    model,
                    args.max_seq_len,
                    args,
                    heads_per_class=args.heads_per_class)
                # if args.non_binary_cols:
                #     labels_batch = labels_batch[:,0]-labels_batch[:,1]+1
                _, (_, state) = transform(model, text_batch, labels_batch,
                                          length_batch, args)
                X_out.append(state.cpu().numpy())
                Y_out.append(labels_batch.cpu().numpy())
            X_out = np.concatenate(X_out)
            Y_out = np.concatenate(Y_out)
            return X_out, Y_out

        model.eval()
        trX, trY = transform_for_logreg(model, train_data, args, desc='train')
        vaX, vaY = transform_for_logreg(model, val_data, args, desc='val')
        teX, teY = transform_for_logreg(model, test_data, args, desc='test')

        logreg_model, logreg_scores, logreg_preds, c, nnotzero = train_logreg(
            trX,
            trY,
            vaX,
            vaY,
            teX,
            teY,
            eval_test=not args.no_test_eval,
            report_metric=args.report_metric,
            threshold_metric=args.threshold_metric,
            automatic_thresholding=args.automatic_thresholding,
            micro=args.micro)
        print(', '.join([str(score) for score in logreg_scores]),
              'train, val, test accuracy for all neuron regression')
        print(str(c) + ' regularization coefficient used')
        print(str(nnotzero) + ' features used in all neuron regression\n')
    else:
        best_vaY = 0
        vaT = []  # current "best thresholds" so we can get reasonable estimates on the training set
        for e in tqdm(range(args.epochs),
                      unit="epoch",
                      desc="epochs",
                      position=0,
                      ncols=100):
            if args.use_softmax:
                vaT = []
            save_outputs = False
            report_metrics = [
                'jacc', 'acc', 'mcc', 'f1', 'recall', 'precision', 'var'
            ] if args.all_metrics else [args.report_metric]
            print_str = ""
            trXt, trY, trC, _ = finetune(model,
                                         train_data,
                                         args,
                                         val_data=val_data,
                                         LR=LR,
                                         reg_loss=reg_loss,
                                         tqdm_desc='train',
                                         heads_per_class=args.heads_per_class,
                                         last_thresholds=vaT,
                                         threshold_validation=False)
            data_str_base = "Train Loss: {:4.2f} Train {:5s} (All): {:5.2f}, Train Class {:5s}: {}"
            for idx, m in enumerate(report_metrics):
                data_str = data_str_base.format(trXt, m, trY[idx] * 100, m,
                                                trC[idx])
                print_str += data_str + " " * max(0,
                                                  110 - len(data_str)) + "\n"

            vaXt, vaY = None, None
            if val_data is not None:
                vaXt, vaY, vaC, vaT = finetune(
                    model,
                    val_data,
                    args,
                    tqdm_desc='val',
                    heads_per_class=args.heads_per_class,
                    last_thresholds=vaT)
                # Use the command-line metric to decide which score counts as "best" performance.
                # NOTE: F1, MCC, and Jaccard are good measures here; accuracy is not, since the classes are heavily skewed.
                selection_metric = [
                    'jacc', 'acc', 'mcc', 'f1', 'recall', 'precision', 'var'
                ].index(args.threshold_metric)
                avg_Y = vaY[selection_metric]
                tqdm.write('avg ' + args.threshold_metric + ' metric ' +
                           str(avg_Y))
                if avg_Y > best_vaY:
                    save_outputs = True
                    best_vaY = avg_Y
                elif avg_Y == best_vaY and random.random() > 0.5:
                    save_outputs = True
                    best_vaY = avg_Y
                data_str_base = "Val   Loss: {:4.2f} Val   {:5s} (All): {:5.2f}, Val   Class {:5s}: {}"
                for idx, m in enumerate(report_metrics):
                    data_str = data_str_base.format(vaXt, m, vaY[idx] * 100, m,
                                                    vaC[idx])
                    print_str += data_str + " " * max(
                        0, 110 - len(data_str)) + "\n"
            tqdm.write(print_str[:-1])
            teXt, teY = None, None
            if test_data is not None:
                # Hardcode -- enable to always save outputs [regardless of metrics]
                # save_outputs = True
                if save_outputs:
                    tqdm.write('performing test eval')
                    try:
                        with torch.no_grad():
                            if not args.no_test_eval:
                                auto_thresholds = None
                                dual_thresholds = None
                                # NOTE -- we manually threshold to F1 [not necessarily good]
                                V_pred, V_label, V_std = generate_outputs(
                                    model, val_data, args)
                                if args.automatic_thresholding:
                                    if args.dual_thresh:
                                        # get dual threshold (do not call auto thresholds)
                                        # TODO: Handle multiple heads per class
                                        _, dual_thresholds = _neutral_threshold_two_output(
                                            V_pred.cpu().numpy(),
                                            V_label.cpu().numpy())
                                        model.set_thresholds(
                                            dual_thresholds,
                                            dual_threshold=args.dual_thresh
                                            and not args.joint_binary_train)
                                    else:
                                        # Use args.threshold_metric to choose which category to threshold on. F1 and Jaccard are good options
                                        # NOTE: For multiple heads per class, can threshold each head (default) or single threshold. Little difference once model converges.
                                        auto_thresholds = vaT
                                        # _, auto_thresholds, _, _ = _binary_threshold(V_pred.view(-1, int(model.out_dim/args.heads_per_class)).contiguous(), V_label.view(-1, int(model.out_dim/args.heads_per_class)).contiguous(),
                                        #     args.threshold_metric, args.micro, global_tweaks=args.global_tweaks)
                                        model.set_thresholds(
                                            auto_thresholds,
                                            args.double_thresh)
                                T_pred, T_label, T_std = generate_outputs(
                                    model, test_data, args, auto_thresholds)
                                if not args.use_softmax and int(
                                        model.out_dim /
                                        args.heads_per_class) > 1:
                                    keys = list(args.non_binary_cols)
                                    if args.dual_thresh:
                                        if len(keys) == len(dual_thresholds):
                                            tqdm.write(
                                                'Dual thresholds: %s' % str(
                                                    list(
                                                        zip(
                                                            keys,
                                                            dual_thresholds))))
                                        keys += ['neutral']
                                    else:
                                        tqdm.write(
                                            'Class thresholds: %s' % str(
                                                list(zip(
                                                    keys, auto_thresholds))))
                                elif args.use_softmax:
                                    keys = [
                                        str(m) for m in range(model.out_dim)
                                    ]
                                else:
                                    tqdm.write('Class threshold: %s' % str(
                                        [args.label_key, auto_thresholds[0]]))
                                    keys = ['']
                                info_dicts = [{
                                    'fp': 0,
                                    'tp': 0,
                                    'fn': 0,
                                    'tn': 0,
                                    'std': 0.,
                                    'metric': args.report_metric,
                                    'micro': True
                                } for k in keys]
                                # Perform dual thresholding here: add the neutral labels to T_label, threshold the existing predictions, and add neutral preds to T_pred.
                                if args.dual_thresh:
                                    if dual_thresholds is None:
                                        dual_thresholds = [.5, .5]

                                    def make_onehot_w_neutral(label):
                                        rtn = [0] * 3
                                        rtn[label] = 1
                                        return rtn

                                    def get_label(pos_neg):
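                                        # Map (positive, negative) probabilities to a 3-way label:
                                        # 0 = positive, 1 = negative, 2 = neutral (both or neither fire).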
                                        thresholded = [
                                            pos_neg[0] >= dual_thresholds[0],
                                            pos_neg[1] >= dual_thresholds[1]
                                        ]
                                        if thresholded[0] == thresholded[1]:
                                            return 2
                                        return thresholded.index(1)

                                    def get_new_std(std):
                                        return std[0], std[1], (std[0] +
                                                                std[1]) / 2

                                    new_labels = []
                                    new_preds = []
                                    T_std = torch.cat([
                                        T_std[:, :2],
                                        T_std[:, :2].mean(-1).view(-1, 1)
                                    ], -1).cpu().numpy()
                                    for j, lab in enumerate(T_label):
                                        pred = T_pred[j]
                                        new_preds.append(
                                            make_onehot_w_neutral(
                                                get_label(pred)))
                                        new_labels.append(
                                            make_onehot_w_neutral(
                                                get_label(lab)))
                                    T_pred = np.array(new_preds)
                                    T_label = np.array(new_labels)

                                # HACK: If dual threshold, hardcoded -- assume positive, negative and neutral -- in that order
                                # It's ok to train with other categories (after positive, neutral) as an auxiliary loss -- but they won't be scored at test time
                                if args.dual_thresh and args.joint_binary_train:
                                    keys = ['positive', 'negative', 'neutral']
                                    info_dicts = [{
                                        'fp': 0,
                                        'tp': 0,
                                        'fn': 0,
                                        'tn': 0,
                                        'std': 0.,
                                        'metric': args.report_metric,
                                        'micro': True
                                    } for k in keys]
                                for j, k in enumerate(keys):
                                    update_info_dict(info_dicts[j],
                                                     T_pred[:, j],
                                                     T_label[:, j],
                                                     std=T_std[:, j])
                                total_metrics, metric_strings = get_metric_report(
                                    info_dicts, args, keys)
                                test_str = ''
                                test_str_base = "Test  {:5s} (micro): {:5.2f}, Test  Class {:5s}: {}"
                                for idx, m in enumerate(report_metrics):
                                    data_str = test_str_base.format(
                                        m, total_metrics[idx] * 100, m,
                                        metric_strings[idx])
                                    test_str += data_str + " " * max(
                                        0, 110 - len(data_str)) + "\n"
                                tqdm.write(test_str[:-1])
                                # tqdm.write(str(total_metrics))
                                # tqdm.write('; '.join(metric_strings))
                            else:
                                V_pred, V_label, V_std = generate_outputs(
                                    model, val_data, args)

                                T_pred, T_label, T_std = generate_outputs(
                                    model, test_data, args)
                            val_path = os.path.join(save_root,
                                                    'val_results.txt')
                            tqdm.write(
                                'Saving validation prediction results of size %s to %s'
                                % (str(T_pred.shape[:]), val_path))
                            write_results(V_pred, V_label, val_path)

                            test_path = os.path.join(save_root,
                                                     'test_results.txt')
                            tqdm.write(
                                'Saving test prediction results of size %s to %s'
                                % (str(T_pred.shape[:]), test_path))
                            write_results(T_pred, T_label, test_path)
                    except KeyboardInterrupt:
                        pass
                else:
                    pass
            # Save the model, upon request
            if args.save_finetune and save_outputs:
                # Save the model if it is the best so far. Record the epoch number and the keys (what it predicts), plus an optional version number
                # TODO: Add key string to handle multiple runs?
                if args.non_binary_cols:
                    keys = list(args.non_binary_cols)
                else:
                    keys = [args.label_key]
                # Also save args
                args_save_path = os.path.join(save_root, 'args.txt')
                tqdm.write('Saving commandline to %s' % args_save_path)
                with open(args_save_path, 'w') as f:
                    f.write(' '.join(sys.argv[1:]))
                # Save and add thresholds to arguments for easy reloading of model config
                if not args.no_test_eval and args.automatic_thresholding:
                    thresh_save_path = os.path.join(
                        save_root, 'thresh' + '_ep' + str(e) + '.npy')
                    tqdm.write('Saving thresh to %s' % thresh_save_path)
                    if args.dual_thresh:
                        np.save(thresh_save_path,
                                list(zip(keys, dual_thresholds)))
                        args.thresholds = list(zip(keys, dual_thresholds))
                    else:
                        np.save(thresh_save_path,
                                list(zip(keys, auto_thresholds)))
                        args.thresholds = list(zip(keys, auto_thresholds))
                else:
                    args.thresholds = None
                args.classes = keys
                # Save the full model together with args so it can be restored
                clf_save_path = os.path.join(save_root,
                                             'model' + '_ep' + str(e) + '.clf')
                tqdm.write('Saving full classifier to %s' % clf_save_path)
                torch.save({
                    'sd': model.state_dict(),
                    'args': args
                }, clf_save_path)
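

# Assumed entry point, not shown in the excerpt, so the example can be invoked directly.
if __name__ == '__main__':
    main()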