예제 #1
0
def predict(logger, args):
    """Run a saved model over the test set and log its loss and metrics.

    The vocabulary is unpickled from ``<args.vocab_dir>/vocab.data``, the test
    files are wrapped in a ``BRCDataset`` (handed over as ``dev_files``), the
    network is rebuilt with ``rc_model.rc_model`` and its persistable
    variables are restored from ``args.load_dir``.  The pair returned by
    ``validation`` is then written to the log.

    Args:
        logger: Object providing ``info`` and ``error``.
        args: Options object.  Attributes read here: ``vocab_dir``,
            ``max_p_num``, ``max_p_len``, ``max_q_len``, ``testset``,
            ``hidden_size``, ``use_gpu`` and ``load_dir``.  ``args.devset``
            is overwritten with ``args.testset``.

    Returns:
        None.  When ``args.load_dir`` is falsy an error is logged and the
        function returns without evaluating anything.
    """
    logger.info('Load data_set and vocab...')
    vocab_path = os.path.join(args.vocab_dir, 'vocab.data')
    with open(vocab_path, 'rb') as vocab_file:
        vocab = pickle.load(vocab_file)
    logger.info('vocab size is {} and embed dim is {}'.format(
        vocab.size(), vocab.embed_dim))

    test_data = BRCDataset(
        args.max_p_num, args.max_p_len, args.max_q_len,
        dev_files=args.testset)
    logger.info('Converting text into ids...')
    # Only the vocabulary is attached here; convert_to_ids() is not called.
    test_data.set_vocab(vocab)
    logger.info('Initialize the model...')

    # Expose the test files under the dev-set option as well.
    # NOTE(review): presumably downstream code reads args.devset - confirm.
    args.devset = args.testset

    # Build the network inside fresh programs with a fresh name scope.
    test_program = fluid.Program()
    init_program = fluid.Program()
    with fluid.program_guard(test_program, init_program), \
            fluid.unique_name.guard():
        model_outputs = rc_model.rc_model(args.hidden_size, vocab, args)
        avg_cost, s_probs, e_probs, match, feed_order = model_outputs

        # Choose the device and how many of them validation may use.
        if args.use_gpu:
            place = fluid.CUDAPlace(0)
            dev_count = fluid.core.get_cuda_device_count()
        else:
            place = fluid.CPUPlace()
            cpu_default = multiprocessing.cpu_count()
            dev_count = int(os.environ.get('CPU_NUM', cpu_default))

        exe = Executor(place)
        if not args.load_dir:
            logger.error('No model file to load ...')
            return

        # Restore the trained parameters into the freshly built program.
        logger.info('load from {}'.format(args.load_dir))
        fluid.io.load_persistables(
            exe, args.load_dir, main_program=test_program)

        inference_program = test_program.clone(for_test=True)
        eval_loss, bleu_rouge = validation(
            inference_program, avg_cost, s_probs, e_probs, match, feed_order,
            place, dev_count, vocab, test_data, logger, args)

        logger.info('test eval loss {}'.format(eval_loss))
        logger.info('test eval result: {}'.format(bleu_rouge))
예제 #2
0
def train(logger, args):
    """Train the model, evaluating on the dev set and saving checkpoints.

    The vocabulary is unpickled from ``<args.vocab_dir>/vocab.data``, the
    train/dev files are loaded into a ``BRCDataset`` and converted to ids, the
    network is built with ``rc_model.rc_model`` and optimised for
    ``args.pass_num`` passes with a ``fluid.ParallelExecutor``.

    Args:
        logger: Object providing ``info``, ``warning`` and ``error``.
        args: Options object.  Attributes read here: ``vocab_dir``,
            ``max_p_num``, ``max_p_len``, ``max_q_len``, ``trainset``,
            ``devset``, ``use_gpu``, ``enable_ce``, ``random_seed``,
            ``hidden_size``, ``optim`` (``'sgd'``, ``'adam'`` or ``'rprop'``),
            ``learning_rate``, ``weight_decay``, ``load_dir``, ``pass_num``,
            ``batch_size``, ``log_interval``, ``dev_interval``,
            ``save_interval`` and ``save_dir``.  A non-positive interval
            disables the corresponding periodic action.

    Raises:
        SystemExit: With code -1 when ``args.optim`` is not supported.
    """
    logger.info('Load data_set and vocab...')
    # NOTE: unpickling can execute arbitrary code; vocab.data must come from a
    # trusted source.
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        if six.PY2:
            vocab = pickle.load(fin)
        else:
            vocab = pickle.load(fin, encoding='bytes')
        logger.info('vocab size is {} and embed dim is {}'.format(
            vocab.size(), vocab.embed_dim))
    brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
                          args.trainset, args.devset)
    logger.info('Converting text into ids...')
    brc_data.convert_to_ids(vocab)
    logger.info('Initialize the model...')

    # Device selection; `place` is reused for the executor, the feeder and
    # validation below.
    if not args.use_gpu:
        place = fluid.CPUPlace()
        dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
    else:
        place = fluid.CUDAPlace(0)
        dev_count = fluid.core.get_cuda_device_count()

    # build model
    main_program = fluid.Program()
    startup_prog = fluid.Program()
    if args.enable_ce:
        # Fixed seeds for the continuous-evaluation (CE) runs.
        main_program.random_seed = args.random_seed
        startup_prog.random_seed = args.random_seed
    with fluid.program_guard(main_program, startup_prog):
        with fluid.unique_name.guard():
            avg_cost, s_probs, e_probs, match, feed_order = rc_model.rc_model(
                args.hidden_size, vocab, args)
            # Clone before the optimizer is attached and use the clone as the
            # validation program.
            inference_program = main_program.clone(for_test=True)

            # build optimizer
            if args.optim == 'sgd':
                optimizer = fluid.optimizer.SGD(
                    learning_rate=args.learning_rate)
            elif args.optim == 'adam':
                optimizer = fluid.optimizer.Adam(
                    learning_rate=args.learning_rate)
            elif args.optim == 'rprop':
                optimizer = fluid.optimizer.RMSPropOptimizer(
                    learning_rate=args.learning_rate)
            else:
                logger.error('Unsupported optimizer: {}'.format(args.optim))
                # Same effect as exit(-1), without relying on the helper that
                # the `site` module injects into builtins.
                raise SystemExit(-1)

            # Objective: the model loss, plus an L2 penalty when requested.
            obj_func = avg_cost
            if args.weight_decay > 0.0:
                obj_func = avg_cost + args.weight_decay * l2_loss(main_program)
            optimizer.minimize(obj_func)

            # initialize parameters
            exe = Executor(place)
            if args.load_dir:
                logger.info('load from {}'.format(args.load_dir))
                fluid.io.load_persistables(exe,
                                           args.load_dir,
                                           main_program=main_program)
            else:
                exe.run(startup_prog)
                # Overwrite the freshly initialised embedding table with the
                # embeddings stored in the vocabulary.
                embedding_para = fluid.global_scope().find_var(
                    'embedding_para').get_tensor()
                embedding_para.set(vocab.embeddings.astype(np.float32), place)

            # prepare data
            feed_list = [
                main_program.global_block().var(var_name)
                for var_name in feed_order
            ]
            feeder = fluid.DataFeeder(feed_list, place)

            logger.info('Training the model...')
            parallel_executor = fluid.ParallelExecutor(
                main_program=main_program,
                use_cuda=bool(args.use_gpu),
                loss_name=avg_cost.name)
            print_para(main_program, parallel_executor, logger, args)

            # Loop invariants: the padding id, and shuffling (disabled for CE
            # runs).
            pad_id = vocab.get_id(vocab.pad_token)
            shuffle_train = not args.enable_ce

            def train_batches():
                return brc_data.gen_mini_batches(
                    'train', args.batch_size, pad_id, shuffle=shuffle_train)

            for pass_id in range(1, args.pass_num + 1):
                pass_start_time = time.time()
                train_reader = read_multiple(train_batches, dev_count)
                log_every_n_batch, n_batch_loss = args.log_interval, 0
                total_num, total_loss = 0, 0
                for batch_id, batch_list in enumerate(train_reader(), 1):
                    feed_data = batch_reader(batch_list, args)
                    fetch_outs = parallel_executor.run(
                        feed=list(feeder.feed_parallel(feed_data, dev_count)),
                        fetch_list=[obj_func.name],
                        return_numpy=False)
                    cost_train = np.array(fetch_outs[0]).mean()
                    total_num += args.batch_size * dev_count
                    n_batch_loss += cost_train
                    total_loss += cost_train * args.batch_size * dev_count

                    # CE runs are capped at 100 batches per pass.
                    if args.enable_ce and batch_id >= 100:
                        break
                    if log_every_n_batch > 0 and batch_id % log_every_n_batch == 0:
                        print_para(main_program, parallel_executor, logger,
                                   args)
                        logger.info(
                            'Average loss from batch {} to {} is {}'.format(
                                batch_id - log_every_n_batch + 1, batch_id,
                                "%.10f" % (n_batch_loss / log_every_n_batch)))
                        n_batch_loss = 0
                    if args.dev_interval > 0 and batch_id % args.dev_interval == 0:
                        if brc_data.dev_set is not None:
                            eval_loss, bleu_rouge = validation(
                                inference_program, avg_cost, s_probs, e_probs,
                                match, feed_order, place, dev_count, vocab,
                                brc_data, logger, args)
                            logger.info('Dev eval loss {}'.format(eval_loss))
                            logger.info(
                                'Dev eval result: {}'.format(bleu_rouge))
                pass_end_time = time.time()
                time_consumed = pass_end_time - pass_start_time
                logger.info('epoch: {0}, epoch_time_cost: {1:.2f}'.format(
                    pass_id, time_consumed))
                logger.info(
                    'Evaluating the model after epoch {}'.format(pass_id))
                if brc_data.dev_set is not None:
                    eval_loss, bleu_rouge = validation(inference_program,
                                                       avg_cost, s_probs,
                                                       e_probs, match,
                                                       feed_order, place,
                                                       dev_count, vocab,
                                                       brc_data, logger, args)
                    logger.info('Dev eval loss {}'.format(eval_loss))
                    logger.info('Dev eval result: {}'.format(bleu_rouge))
                else:
                    logger.warning(
                        'No dev set is loaded for evaluation in the dataset!')

                # A pass that produced no batches has no meaningful average;
                # report NaN instead of dividing by zero.
                if total_num > 0:
                    avg_train_loss = 1.0 * total_loss / total_num
                else:
                    logger.warning(
                        'No training batches were processed in epoch {}'.format(
                            pass_id))
                    avg_train_loss = float('nan')
                logger.info('Average train loss for epoch {} is {}'.format(
                    pass_id, "%.10f" % avg_train_loss))

                # Like the other intervals, a non-positive save_interval
                # disables checkpointing rather than dividing by zero.
                if args.save_interval > 0 and pass_id % args.save_interval == 0:
                    model_path = os.path.join(args.save_dir, str(pass_id))
                    if not os.path.isdir(model_path):
                        os.makedirs(model_path)

                    fluid.io.save_persistables(executor=exe,
                                               dirname=model_path,
                                               main_program=main_program)
                if args.enable_ce:  # For CE
                    print("kpis\ttrain_cost_card%d\t%f" %
                          (dev_count, avg_train_loss))
                    if brc_data.dev_set is not None:
                        print("kpis\ttest_cost_card%d\t%f" %
                              (dev_count, eval_loss))
                    print("kpis\ttrain_duration_card%d\t%f" %
                          (dev_count, time_consumed))