예제 #1
0
    # Fine-tuning loop: n_iter passes over the training data.
    # NOTE(review): n_updates, n_epochs, the placeholders (X_train, M_train,
    # Y_train) and the ops (clf_loss, train) are defined outside this snippet.
    # The loop index `i` is not used inside the loop body.
    for i in range(n_iter):
        # Shuffle trX, trM and trYt in a single call with a shared
        # random_state. This presumably applies one permutation to all three
        # so rows stay aligned (e.g. sklearn.utils.shuffle) — confirm.
        # Then iterate minibatches of size n_batch_train. truncate=True
        # presumably drops a trailing partial batch — confirm against
        # iter_data.
        for xmb, mmb, ymb in iter_data(*shuffle(trX,
                                                trM,
                                                trYt,
                                                random_state=np.random),
                                       n_batch=n_batch_train,
                                       truncate=True,
                                       verbose=True):
            # One training step: fetch clf_loss and run `train` (presumably
            # the optimizer update op) on this batch. The second result is
            # discarded. `cost` is not read later in this snippet.
            cost, _ = sess.run([clf_loss, train], {
                X_train: xmb,
                M_train: mmb,
                Y_train: ymb
            })
            n_updates += 1
            # First epoch only: also call log() on a doubling schedule of
            # update counts (1000, 2000, ..., 32000).
            if n_updates in [1000, 2000, 4000, 8000, 16000, 32000
                             ] and n_epochs == 0:
                log()
        n_epochs += 1
        # Call log() once at the end of every epoch.
        log()
    if submit:
        # Load the parameter list stored at save_dir/desc/best_params.jl and
        # assign each entry to the matching variable in `params`.
        # NOTE(review): zip stops at the shorter sequence, so a length
        # mismatch between `params` and the saved list would go unnoticed.
        sess.run([
            p.assign(ip) for p, ip in zip(
                params,
                joblib.load(os.path.join(save_dir, desc, 'best_params.jl')))
        ])
        predict()
        if analysis:
            # Pass the ROCStories submission file path and the
            # rocstories.jsonl log path to rocstories_analysis.
            rocstories_analysis(data_dir,
                                os.path.join(submission_dir, 'ROCStories.tsv'),
                                os.path.join(log_dir, 'rocstories.jsonl'))
예제 #2
0
파일: train.py 프로젝트: souravsingh/models
    # Reset the training counters. They are presumably read and updated by
    # run_epoch() and log(), which are defined outside this snippet — confirm.
    n_updates = 0
    n_epochs = 0
    # For every dataset except 'stsb', the training targets are trY as-is.
    # NOTE(review): for 'stsb', trYt must be assigned elsewhere (not visible).
    if dataset != 'stsb':
        trYt = trY
    if submit:
        # Save the current (initial) model to best_params before training.
        # This is presumably so the load_npz call after training always finds
        # a file, even if no improved model is saved later — confirm against
        # log().
        path = os.path.join(save_dir, desc, 'best_params')
        chainer.serializers.save_npz(make_path(path), model)
    # NOTE(review): if log() compares against best_score, it must see this
    # binding through a global/nonlocal declaration — confirm.
    best_score = 0
    for i in range(n_iter):
        print("running epoch", i)
        run_epoch()
        n_epochs += 1
        # Call log() after every epoch.
        log()
    if submit:
        # Load the best_params checkpoint back into `model` before predicting.
        path = os.path.join(save_dir, desc, 'best_params')
        chainer.serializers.load_npz(make_path(path), model)
        predict()
        if analysis:
            # Pass the submission file (filenames[dataset]) and the log file
            # log_dir/<desc>.jsonl to the dataset-specific analysis function.
            # Only 'rocstories' and 'sst' are handled; any other dataset
            # raises NotImplementedError.
            if dataset == 'rocstories':
                rocstories_analysis(
                    data_dir, os.path.join(
                        submission_dir, filenames[dataset]), os.path.join(
                            log_dir, '{}.jsonl'.format(desc)))
            elif dataset == 'sst':
                sst_analysis(
                    data_dir, os.path.join(
                        submission_dir, filenames[dataset]), os.path.join(
                            log_dir, '{}.jsonl'.format(desc)))
            else:
                raise NotImplementedError
예제 #3
0
파일: train.py 프로젝트: GAIMJKP/models-2
        clf_head.to_gpu()

    # Reset the training counters. They are presumably read and updated by
    # run_epoch() and log(), which are defined outside this snippet — confirm.
    n_updates = 0
    n_epochs = 0
    # For every dataset except 'stsb', the training targets are trY as-is.
    # NOTE(review): for 'stsb', trYt must be assigned elsewhere (not visible).
    if dataset != 'stsb':
        trYt = trY
    if submit:
        # Save the current (initial) model to best_params before training.
        # This is presumably so the load_npz call after training always finds
        # a file, even if no improved model is saved later — confirm against
        # log().
        path = os.path.join(save_dir, desc, 'best_params')
        chainer.serializers.save_npz(make_path(path), model)
    # NOTE(review): if log() compares against best_score, it must see this
    # binding through a global/nonlocal declaration — confirm.
    best_score = 0
    for i in range(n_iter):
        print("running epoch", i)
        run_epoch()
        n_epochs += 1
        # Call log() after every epoch.
        log()
    if submit:
        # Load the best_params checkpoint back into `model` before predicting.
        path = os.path.join(save_dir, desc, 'best_params')
        chainer.serializers.load_npz(make_path(path), model)
        predict()
        if analysis:
            # Pass the submission file (filenames[dataset]) and the log file
            # log_dir/<desc>.jsonl to the dataset-specific analysis function.
            # Only 'rocstories' and 'sst' are handled; any other dataset
            # raises NotImplementedError.
            if dataset == 'rocstories':
                rocstories_analysis(
                    data_dir, os.path.join(submission_dir, filenames[dataset]),
                    os.path.join(log_dir, '{}.jsonl'.format(desc)))
            elif dataset == 'sst':
                sst_analysis(data_dir,
                             os.path.join(submission_dir, filenames[dataset]),
                             os.path.join(log_dir, '{}.jsonl'.format(desc)))
            else:
                raise NotImplementedError