Пример #1
0
    def run_batch(self, sess, inps, istrn=True):
        """Run a single ``sess.run`` step and gather the named scalar metrics.

        Args:
            sess: session used to evaluate the fetches.
            inps: sequence of values zipped, in order, onto the labelled
                placeholders (``inp_l``, ``tgt_l``, ``msk_l``, ``y``) followed
                by the unlabelled ones (``inp_u``, ``tgt_u``, ``msk_u``).
            istrn: when truthy, also runs ``self.train_op`` and fetches the
                full set of training metrics; otherwise only the labelled
                prediction loss is fetched.

        Returns:
            ``(res_dict, res_str, summary)``: metric name -> fetched value,
            its ``res_to_string`` rendering, and the fetched ``self.merged``.
        """
        placeholders = [
            self.inp_l_plh,
            self.tgt_l_plh,
            self.msk_l_plh,
            self.y_plh,
            self.inp_u_plh,
            self.tgt_u_plh,
            self.msk_u_plh,
        ]

        if istrn:
            named_fetches = [
                ('elbo_l', self.elbo_loss_l),
                ('ppl_l', self.ppl_l),
                ('kl_l', self.kl_loss_l),
                ('pred_l', self.predict_loss_l),
                ('loss_u', self.loss_u),
                ('train_u', self.train_unlabel),
                ('klw', self.kl_w),
            ]
            feed_values = inps
            trailing_ops = [self.train_op]
        else:
            named_fetches = [('pred_l', self.predict_loss_l)]
            # The labelled inputs are fed a second time (minus their last
            # element) so the unlabelled placeholders also receive values.
            # NOTE(review): presumably inps holds exactly the 4 labelled
            # tensors here, so inps[:-1] drops the label y -- confirm.
            feed_values = inps + inps[:-1]
            trailing_ops = []

        feed_dict = dict(zip(placeholders, feed_values))
        feed_dict[self.is_training] = True if istrn else False

        fetches = [self.merged] + [op for _, op in named_fetches] + trailing_ops
        outputs = sess.run(fetches, feed_dict)

        # outputs[0] is the summary; metric values follow in fetch order.
        res_dict = {
            name: outputs[pos]
            for pos, (name, _) in enumerate(named_fetches, start=1)
        }
        res_str = res_to_string(res_dict)
        return res_dict, res_str, outputs[0]
Пример #2
0
def validate(valid_dset, model, sess):
    """Evaluate ``model`` on every batch of ``valid_dset`` and summarise it.

    Each batch goes through ``model.run_batch(..., istrn=False)``. That call
    must return a 3-tuple whose first item is the metric dict. The dicts are
    combined with ``average_res`` and rendered with ``res_to_string``.

    Returns:
        The string form of the averaged metrics.
    """
    batch_metrics = [
        metrics
        for metrics, _res_str, _summary in (
            model.run_batch(sess, batch, istrn=False) for batch in valid_dset
        )
    ]
    return res_to_string(average_res(batch_metrics))
Пример #3
0
    def run_batch(self, sess, inps, istrn=True):
        """Run one ``sess.run`` step, returning scalar metrics and raw outputs.

        Args:
            sess: session used to evaluate the fetches.
            inps: values zipped, in order, onto ``input_plh``, ``mask_plh`` and
                ``label_plh``.
            istrn: fed to ``is_training_plh`` as-is. When truthy, the step
                also runs ``self.train_op`` and reports the training metric set.

        Returns:
            ``(res_dict, res_str, summary, [gate_weights, pred])``: metric
            name -> value, its ``res_to_string`` form, the fetched
            ``self.merged``, and a list of the fetched ``self.gate_weights``
            and ``self.pred``.
        """
        placeholders = [self.input_plh, self.mask_plh, self.label_plh]

        if istrn:
            named_fetches = [
                ('loss_total', self.loss_total),
                ('loss', self.loss),
                ('reg_l1', self.loss_reg_l1),
                ('reg_diff', self.loss_reg_diff),
                ('reg_sharp', self.loss_reg_sharp),
                ('reg_frobenius', self.loss_reg_frobenius),
                ('acc', self.accuracy),
            ]
            trailing_ops = [self.train_op]
        else:
            # NOTE(review): evaluation reports 'loss_reg_*' keys and omits
            # the Frobenius term, while training reports 'reg_*' keys. Kept
            # as-is; confirm whether the mismatch is intended.
            named_fetches = [
                ('loss_total', self.loss_total),
                ('loss', self.loss),
                ('loss_reg_l1', self.loss_reg_l1),
                ('loss_reg_diff', self.loss_reg_diff),
                ('loss_reg_sharp', self.loss_reg_sharp),
                ('acc', self.accuracy),
            ]
            trailing_ops = []

        feed_dict = dict(zip(placeholders, inps))
        feed_dict[self.is_training_plh] = istrn

        nonscalar_ops = [self.merged, self.gate_weights, self.pred]
        n_nonscalar = len(nonscalar_ops)
        fetches = nonscalar_ops + [op for _, op in named_fetches] + trailing_ops
        outputs = sess.run(fetches, feed_dict)

        # Scalar metrics sit right after the non-scalar fetches.
        res_dict = {
            name: outputs[n_nonscalar + offset]
            for offset, (name, _) in enumerate(named_fetches)
        }
        res_str = res_to_string(res_dict)
        return res_dict, res_str, outputs[0], outputs[1:n_nonscalar]
Пример #4
0
def run(args, model, sess, label_dset, unlabel_dset, valid_dset, test_dset,
        explogger):
    """Semi-supervised training loop with periodic logging, evaluation and saving.

    Each epoch is one pass over ``unlabel_dset``. Every unlabelled batch is
    concatenated with a labelled batch drawn via ``get_batch(label_dset)`` and
    fed to ``model.run_batch(..., istrn=True, args=args)``.

    Every ``args.show_every`` batches, the latest summary goes to TensorBoard
    and the average of the most recent 200 batch metric dicts is logged.
    Every ``args.validate_every`` batches (unless -1), ``valid_dset`` and then
    ``test_dset`` are evaluated. Every ``args.save_every`` batches (unless -1),
    a checkpoint is saved.

    The TensorBoard writer is always closed on exit, so buffered summaries
    are flushed even if training raises.
    """
    batch_cnt = 0
    res_list = []

    # init tensorboard writer
    tb_writer = tf.summary.FileWriter(os.path.join(args.save_dir, 'train'),
                                      sess.graph)
    try:
        t_time = time.time()
        for _ in range(args.max_epoch):
            for batch_u in unlabel_dset:
                batch_cnt += 1
                batch_l = get_batch(label_dset)
                # Time spent producing the batches since the previous step ended.
                gen_time = time.time() - t_time
                t_time = time.time()
                res_dict, res_str, summary = model.run_batch(sess,
                                                             batch_l + batch_u,
                                                             istrn=True,
                                                             args=args)
                run_time = time.time() - t_time
                res_dict.update({'run_time': run_time, 'gen_time': gen_time})
                res_list.append(res_dict)
                # Keep only a rolling window of the latest 200 batches.
                res_list = res_list[-200:]

                if batch_cnt % args.show_every == 0:
                    tb_writer.add_summary(summary, batch_cnt)
                    out_str = res_to_string(average_res(res_list))
                    explogger.message(out_str, True)

                if args.validate_every != -1 and batch_cnt % args.validate_every == 0:
                    out_str = 'VALIDATE:' + validate(valid_dset, model, sess, args)
                    explogger.message(out_str, True)
                    out_str = 'TEST:' + validate(test_dset, model, sess, args)
                    explogger.message(out_str, True)

                if args.save_every != -1 and batch_cnt % args.save_every == 0:
                    save_fn = os.path.join(args.save_dir, args.log_prefix)
                    explogger.message("Saving checkpoint model: {} ......".format(
                        args.save_dir))
                    model.saver.save(sess,
                                     save_fn,
                                     write_meta_graph=False,
                                     global_step=batch_cnt)

                # Restart the clock so the next gen_time excludes logging/eval.
                t_time = time.time()
    finally:
        tb_writer.close()
Пример #5
0
def train_and_validate(args, model, sess, train_dset, valid_dset, test_dset,
                       explogger, vocab, class_map):
    """Supervised training loop with periodic logging, evaluation and saving.

    Each epoch iterates ``threaded_generator(train_dset, 200)`` and runs
    ``model.run_batch(..., istrn=True)`` on every batch.

    Every ``args.show_every`` batches, the latest summary goes to TensorBoard
    and the average of the most recent 200 batch metric dicts is logged.
    Every ``args.validate_every`` batches (unless -1), ``validate`` runs on
    ``valid_dset`` and then ``test_dset``. Every ``args.save_every`` batches
    (unless -1), a checkpoint is saved.

    The TensorBoard writer is always closed on exit, so buffered summaries
    are flushed even if training raises.
    """
    batch_cnt = 0
    res_list = []

    # init tensorboard writer
    tb_writer = tf.summary.FileWriter(os.path.join(args.save_dir, 'train'),
                                      sess.graph)
    try:
        t_time = time.time()
        for _ in range(args.max_epoch):
            for batch in threaded_generator(train_dset, 200):
                batch_cnt += 1
                # Time spent producing the batch since the previous step ended.
                gen_time = time.time() - t_time
                t_time = time.time()
                res_dict, res_str, summary, _gate_weights = model.run_batch(
                    sess, batch, istrn=True)
                run_time = time.time() - t_time
                res_dict.update({'run_time': run_time, 'gen_time': gen_time})
                res_list.append(res_dict)
                # Keep only a rolling window of the latest 200 batches.
                res_list = res_list[-200:]

                if batch_cnt % args.show_every == 0:
                    tb_writer.add_summary(summary, batch_cnt)
                    out_str = res_to_string(average_res(res_list))
                    explogger.message(out_str, True)

                if args.validate_every != -1 and batch_cnt % args.validate_every == 0:
                    # NOTE(review): if validate() writes its logs to fixed file
                    # names under args.save_dir, the TEST call overwrites the
                    # VALIDATE logs -- confirm.
                    out_str = validate(valid_dset, model, sess, args, vocab,
                                       class_map)
                    explogger.message('VALIDATE: ' + out_str, True)
                    out_str = validate(test_dset, model, sess, args, vocab,
                                       class_map)
                    explogger.message('TEST: ' + out_str, True)

                if args.save_every != -1 and batch_cnt % args.save_every == 0:
                    save_fn = os.path.join(args.save_dir, args.log_prefix)
                    explogger.message("Saving checkpoint model: {} ......".format(
                        args.save_dir))
                    model.saver.save(sess,
                                     save_fn,
                                     write_meta_graph=False,
                                     global_step=batch_cnt)

                # Restart the clock so the next gen_time excludes logging/eval.
                t_time = time.time()
    finally:
        tb_writer.close()
Пример #6
0
def validate(valid_dset, model, sess, args, vocab, class_map, log_prefix=''):
    """Evaluate the model on a dataset and dump diagnostic logs.

    Each batch from ``threaded_generator(valid_dset, 200)`` goes through
    ``model.run_batch(..., istrn=False)``. Three files are (re)written under
    ``args.save_dir``, each truncated on every call:

    * ``gate.log`` -- one ``token<TAB>gate<TAB>`` line per token of every
      sample, with a blank line between samples.
    * ``predict.log`` -- a ``pred<TAB>target<TAB>txt`` header, then one line
      per sample. The text is the sample's tokens, skipping id 0, joined
      without separators.
    * ``confusion_matrix.log`` -- the confusion matrix over the sorted
      ``class_map`` values.

    Args:
        valid_dset: dataset passed to ``threaded_generator``. Batch item 0
            holds token ids and item 2 holds target class ids; other items are
            only forwarded to ``run_batch``.
        model: object whose ``run_batch`` returns
            ``(res_dict, res_str, summary, [gate_weights, pred])``.
        sess: session forwarded to ``run_batch``.
        args: must provide ``save_dir``.
        vocab: mapping token id -> token string.
        class_map: mapping class id -> class label.
        log_prefix: optional prefix for the three log file names. The default
            ``''`` keeps the original names, so calls for different datasets
            can avoid overwriting each other's logs.

    Returns:
        The string form of the batch-averaged metrics.
    """
    res_list = []
    list_predict = []
    list_target = []

    def _log_path(name):
        return os.path.join(args.save_dir, log_prefix + name)

    threaded_it = threaded_generator(valid_dset, 200)
    # Context managers guarantee the logs are flushed and closed even if a
    # batch raises. The vocabulary text is written as UTF-8 explicitly.
    with open(_log_path('gate.log'), 'w', encoding='utf-8') as wf, \
            open(_log_path('predict.log'), 'w', encoding='utf-8') as rf:
        rf.write('pred\ttarget\ttxt\n')
        for batch in threaded_it:
            res_dict, _res_str, _summary, [gate_weights, pred_ids] = \
                model.run_batch(sess, batch, istrn=False)
            res_list.append(res_dict)

            tokens = batch[0]
            batch_size = tokens.shape[0]
            for sidx in range(batch_size):
                for idx, w in enumerate(tokens[sidx]):
                    wf.write(vocab[w] + '\t')
                    wf.write('{0:.3f}\t'.format(gate_weights[sidx][idx][0]) + '\n')
                wf.write('\n')

            pred = [class_map[idx] for idx in pred_ids]
            tgt = [class_map[idx] for idx in batch[2]]
            list_predict.extend(pred)
            list_target.extend(tgt)
            # Token id 0 is skipped when rebuilding the text.
            inp = [[vocab[idx] for idx in s if idx != 0] for s in tokens]

            for i in range(batch_size):
                rf.write(str(pred[i]) + '\t')
                rf.write(str(tgt[i]) + '\t')
                rf.write(''.join(inp[i]) + '\n')

    labels = sorted(class_map.values())
    # `labels` is keyword-only in current scikit-learn; passing it by name
    # also works with older releases.
    confusion = confusion_matrix(list_target, list_predict, labels=labels)
    outstr = utils.confusion_matrix_to_string(confusion, labels)
    with open(_log_path('confusion_matrix.log'), 'w', encoding='utf-8') as cf:
        cf.write(outstr)

    return res_to_string(average_res(res_list))
Пример #7
0
def run(args, model, sess, train_dset, valid_dset, test_dset, explogger,
        max_samples=20):
    """Debug helper: print the gate weights for the first training batch.

    Despite sharing the training-loop signature, this function does not
    train. It loads the vocabulary and creates the TensorBoard writer, which
    records ``sess.graph``. It then evaluates ``model.weights`` on the first
    batch of ``train_dset``, feeding ``is_training=True``, and prints:

    * the raw array and its shape;
    * for up to ``max_samples`` rows, one ``token<TAB>weight<TAB>`` line per
      token, with a blank line after each row.

    It then returns. If ``train_dset`` yields nothing, nothing is printed.
    ``valid_dset``, ``test_dset`` and ``explogger`` are accepted only for
    signature compatibility.

    Args:
        max_samples: maximum number of rows to print (default 20). The count
            is capped at the batch size, so small batches no longer raise
            IndexError.
    """
    # NOTE(review): pickle.load can execute arbitrary code from the file; only
    # point args.vocab_path at a trusted vocabulary file.
    with open(args.vocab_path, 'rb') as f:
        vocab = pkl.load(f)
        # Invert token -> id into id -> token for lookups by token id.
        vocab = {int(vocab[k]): k for k in vocab}

    # init tensorboard writer
    tb_writer = tf.summary.FileWriter(os.path.join(args.save_dir, 'train'),
                                      sess.graph)
    try:
        for _ in range(args.max_epoch):
            for batch in train_dset:
                plhs = [model.input_plh,
                        model.mask_plh,
                        model.label_plh,
                        model.is_training]
                gates = sess.run(model.weights,
                                 dict(zip(plhs, list(batch) + [True])))
                print(gates)
                print(gates.shape)

                tokens = batch[0]
                n_rows = min(max_samples, len(tokens))
                for sidx in range(n_rows):
                    for idx, w in enumerate(tokens[sidx]):
                        print(vocab[w] + '\t', end='')
                        print('{0:.3f}\t'.format(gates[sidx][idx][0]))
                    print()
                # Only the first batch is inspected.
                return
    finally:
        tb_writer.close()