示例#1
0
def main(args):
    """Prepare data, train the MemN2N dialog model, or run inference.

    Args:
        args: Raw arguments handed to ``parse_args``. The parsed result is
            indexed with the keys ``prep_data``, ``task_id``, ``train``,
            ``epochs``, ``eval_interval``, ``log_file`` and ``infer``.

    Modes:
        * ``prep_data`` set: call ``prepare_data`` for tasks 1-6, then
          terminate the process through ``sys.exit()``.
        * ``train`` set: fit the model, and every ``eval_interval`` epochs
          evaluate it, append a line to the log file and save a checkpoint.
        * otherwise: restore the latest checkpoint when one exists and, if
          ``infer`` is set, answer queries typed on stdin until ``exit``.
    """
    # parse args
    args = parse_args(args)

    # prepare data
    if args['prep_data']:
        print('\n ====== preparing data ====== \n')
        for task in range(1, 7):
            print(' TASK #{}\n'.format(task))
            prepare_data(args, task_id=task)
        sys.exit()

    ##################################################################
    # Not preparing data: read data and metadata from pickled files.
    # NOTE(review): unpickling can execute arbitrary code; these files are
    # presumably the ones produced by the data-preparation step - only load
    # files from a trusted location.
    ##################################################################
    task_id = args['task_id']
    with open(P_DATA_DIR + str(task_id) + '.metadata.pkl', 'rb') as f:
        metadata = pkl.load(f)
    with open(P_DATA_DIR + str(task_id) + '.data.pkl', 'rb') as f:
        data_ = pkl.load(f)

    # read content of data and metadata
    candidates = data_['candidates']
    idx2candid = metadata['idx2candid']

    # read train, test, val data
    train, test, val = data_['train'], data_['test'], data_['val']

    # get more required information from metadata
    sentence_size = metadata['sentence_size']
    w2idx = metadata['w2idx']
    memory_size = metadata['memory_size']
    vocab_size = metadata['vocab_size']
    n_cand = metadata['n_cand']
    candidate_sentence_size = metadata['candidate_sentence_size']

    # vectorize the candidate responses
    candidates_vec = data_utils.vectorize_candidates(candidates, w2idx,
                                                     candidate_sentence_size)

    # create model - memn2n
    model = memn2n.MemN2NDialog(batch_size=BATCH_SIZE,
                                vocab_size=vocab_size,
                                candidates_size=n_cand,
                                sentence_size=sentence_size,
                                embedding_size=20,
                                candidates_vec=candidates_vec,
                                hops=3)

    train, val, test, batches = data_utils.get_batches(train,
                                                       val,
                                                       test,
                                                       metadata,
                                                       batch_size=BATCH_SIZE)

    # training runs here
    if args['train']:
        epochs = args['epochs']
        eval_interval = args['eval_interval']

        print('\n>>> Training started...\n')

        cost_total = 0.
        # Number of batch updates folded into cost_total since the last
        # report. The loop starts at epoch 0 but the first report happens at
        # epoch `eval_interval`, so the first window holds one extra epoch;
        # counting the updates keeps the logged mean cost correct.
        batches_since_eval = 0

        # write log to file; the context manager closes it even if a
        # training step raises
        with open('log/' + args['log_file'], 'w') as log_handle:
            for i in range(epochs + 1):
                for start, end in batches:
                    s = train['s'][start:end]
                    q = train['q'][start:end]
                    a = train['a'][start:end]
                    cost_total += model.batch_fit(s, q, a)
                    batches_since_eval += 1

                # evaluate every `eval_interval` epochs, skipping epoch 0
                if i % eval_interval == 0 and i:
                    train_preds = batch_predict(model,
                                                train['s'],
                                                train['q'],
                                                len(train['s']),
                                                batch_size=BATCH_SIZE)
                    val_preds = batch_predict(model,
                                              val['s'],
                                              val['q'],
                                              len(val['s']),
                                              batch_size=BATCH_SIZE)
                    train_acc = metrics.accuracy_score(np.array(train_preds),
                                                       train['a'])
                    val_acc = metrics.accuracy_score(val_preds, val['a'])
                    print(
                        'Epoch[{}] : <Accuracy>\n\ttraining : {} \n\tvalidation : {}'
                        .format(i, train_acc, val_acc))

                    # mean cost per batch update over the reporting window;
                    # max(..., 1) avoids dividing by zero with no batches
                    mean_cost = cost_total / max(batches_since_eval, 1)
                    log_handle.write('{} {} {} {}\n'.format(
                        i, train_acc, val_acc, mean_cost))
                    cost_total = 0.
                    batches_since_eval = 0

                    model.saver.save(
                        model._sess,
                        CKPT_DIR + '{}/memn2n_model.ckpt'.format(task_id),
                        global_step=i)
    # inference
    else:
        # restore checkpoint
        ckpt_dir = CKPT_DIR + str(task_id)
        ckpt = tf.train.get_checkpoint_state(ckpt_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print('\n>> restoring checkpoint from', ckpt.model_checkpoint_path)
            model.saver.restore(model._sess, ckpt.model_checkpoint_path)
        else:
            # make the fallback visible instead of silently continuing
            print('\n>> no checkpoint found in', ckpt_dir,
                  '- model weights were not restored')

        isess = InteractiveSession(model, idx2candid, w2idx, n_cand,
                                   memory_size)

        if args['infer']:
            while True:
                try:
                    query = input('>> ')
                except EOFError:
                    # end of input (e.g. Ctrl-D / closed pipe) ends the chat
                    break
                # 'exit' is a control word: stop without sending it to the
                # model
                if query == 'exit':
                    break
                print('>> ' + isess.reply(query))
示例#2
0
def main(args):
    # parse args
    args = parse_args(args)

    # prepare data
    if args['prep_data']:
        print('\n>> Preparing Data\n')
        for i in range(1,7):
            print(' TASK#{}\n'.format(i))
            prepare_data(args, task_id=i)
        sys.exit()

    # ELSE
    # read data and metadata from pickled files
    with open(P_DATA_DIR + str(args['task_id']) + '.metadata.pkl', 'rb') as f:
        metadata = pkl.load(f)
    with open(P_DATA_DIR + str(args['task_id']) + '.data.pkl', 'rb') as f:
        data_ = pkl.load(f)

    # read content of data and metadata
    candidates = data_['candidates']
    candid2idx, idx2candid = metadata['candid2idx'], metadata['idx2candid']

    # get train/test/val data
    train, test, val = data_['train'], data_['test'], data_['val']

    # gather more information from metadata
    sentence_size = metadata['sentence_size']
    w2idx = metadata['w2idx']
    idx2w = metadata['idx2w']
    memory_size = metadata['memory_size']
    vocab_size = metadata['vocab_size']
    n_cand = metadata['n_cand']
    candidate_sentence_size = metadata['candidate_sentence_size']

    # vectorize candidates
    candidates_vec = data_utils.vectorize_candidates(candidates, w2idx, candidate_sentence_size)

    ###
    # create model
    #model = model['memn2n']( # why?
    model = memn2n.MemN2NDialog(
                batch_size= BATCH_SIZE,
                vocab_size= vocab_size, 
                candidates_size= n_cand, 
                sentence_size= sentence_size, 
                embedding_size= 20, 
                candidates_vec= candidates_vec, 
                hops= 3
                )
    # gather data in batches
    train, val, test, batches = data_utils.get_batches(train, val, test, metadata, batch_size=BATCH_SIZE)

    if args['train']:
        # training starts here
        epochs = args['epochs']
        eval_interval = args['eval_interval']
        #
        # training and evaluation loop
        print('\n>> Training started!\n')
        # write log to file
        log_handle = open('log/' + args['log_file'], 'w')
        cost_total = 0.
        #best_validation_accuracy = 0.
        for i in range(epochs+1):

            for start, end in batches:
                s = train['s'][start:end]
                q = train['q'][start:end]
                a = train['a'][start:end]
                cost_total += model.batch_fit(s, q, a)
            
            if i%eval_interval == 0 and i:
                train_preds = batch_predict(model, train['s'], train['q'], len(train['s']), batch_size=BATCH_SIZE)
                val_preds = batch_predict(model, val['s'], val['q'], len(val['s']), batch_size=BATCH_SIZE)
                train_acc = metrics.accuracy_score(np.array(train_preds), train['a'])
                val_acc = metrics.accuracy_score(val_preds, val['a'])
                print('Epoch[{}] : <ACCURACY>\n\ttraining : {} \n\tvalidation : {}'.
                     format(i, train_acc, val_acc))
                log_handle.write('{} {} {} {}\n'.format(i, train_acc, val_acc, 
                    cost_total/(eval_interval*len(batches))))
                cost_total = 0. # empty cost
                #
                # save the best model, to disk
                #if val_acc > best_validation_accuracy:
                #best_validation_accuracy = val_acc
                model.saver.save(model._sess, CKPT_DIR + '{}/memn2n_model.ckpt'.format(args['task_id']), 
                        global_step=i)

            model.saver.save(model._sess, CKPT_DIR + '{}/memn2n_model.ckpt_final'.format(args['task_id']), 
                             global_step=i)        
        # close file
        log_handle.close()

    else: # inference
        ###

        print("launching interative dialog session")
        #print("model's session variables" + model._sess)
        ckpt = tf.train.get_checkpoint_state(CKPT_DIR + '{}'.format(args['task_id']))
        if ckpt and ckpt.model_checkpoint_path:
            print('\n>> restoring checkpoint from', ckpt.model_checkpoint_path)
            model.saver.restore(model._sess, ckpt.model_checkpoint_path)

        ds = interactive_session(task_id=args['task_id'])
        ds.interactive_dialog(task_id=args['task_id'])
        
        '''