コード例 #1
0
        batch_input = batch[:-1]

        # forward
        cls_predict = model.forward(*batch_input)
        batch_acc, batch_eq_num = evaluate_acc(cls_predict, cls_truth)

        batch_num = cls_truth.shape[0]
        eq_num += batch_eq_num
        all_num += batch_num

    acc = eq_num / all_num
    return acc


if __name__ == '__main__':
    # Command-line entry point: parse options, build per-slot input/output
    # path infixes, initialise logging, then delegate to main().
    parser = argparse.ArgumentParser()
    parser.add_argument('--in', dest='in_infix', type=str, default='default', help='input path infix')
    parser.add_argument('--out', type=str, default='default', help='output path infix')
    # The slot value is appended to BOTH the input and output infixes below,
    # so its help text must not describe it as an output-only infix.
    parser.add_argument('--slot', type=str, default='directed_by',
                        help='slot name, appended as a path segment to both the input and output path infixes')
    parser.add_argument('--train', action='store_true', default=False, help='enable train step')
    parser.add_argument('--test', action='store_true', default=False, help='enable test step')
    parser.add_argument('--gpuid', type=int, default=None, help='gpuid')
    args = parser.parse_args()

    # Each slot is kept under its own path segment, e.g. 'default/directed_by'.
    in_infix = args.in_infix + '/' + args.slot
    out_infix = args.out + '/' + args.slot

    init_logging(out_infix=out_infix)
    main('config/game_config.yaml', in_infix=in_infix, out_infix=out_infix, slot=args.slot,
         is_train=args.train, is_test=args.test, gpuid=args.gpuid)
コード例 #2
0
ファイル: classification.py プロジェクト: laddie132/LW-PT
                        type=str,
                        default='config.yaml',
                        help='config path')
    # Remaining CLI options: data path infixes, step toggles and GPU selection.
    parser.add_argument('-in', dest='in_infix', type=str, default='default', help='input data_path infix')
    parser.add_argument('-out', type=str, default='default', help='output data_path infix')
    parser.add_argument('-train', action='store_true', default=False, help='enable train step')
    parser.add_argument('-test', action='store_true', default=False, help='enable test step')
    parser.add_argument('-gpuid', type=int, default=None, help='gpuid')
    args = parser.parse_args()

    # Initialise logging with the output infix, then hand the parsed options
    # to main(): config path, input infix and output infix are positional.
    init_logging(out_infix=args.out)
    main(args.config, args.in_infix, args.out,
         is_train=args.train, is_test=args.test, gpuid=args.gpuid)
コード例 #3
0
ファイル: model_tranfer.py プロジェクト: laddie132/LW-PT
__author__ = "Han"
__email__ = "*****@*****.**"

import os
import sys
sys.path.append(os.getcwd())

import argparse
import torch
import logging
from collections import OrderedDict
from models import *
from utils.config import init_logging, read_config

# Configure logging as soon as this module is imported; init_logging comes
# from utils.config and is called here with no arguments (its defaults).
init_logging()
# Module-level logger for this file.
# NOTE(review): transform() below logs via the root `logging` module
# (logging.info) rather than this logger — confirm which is intended.
logger = logging.getLogger(__name__)


def transform(pre_model_path, tar_model_path, cur_model):
    pre_weight = torch.load(pre_model_path,
                            map_location=lambda storage, loc: storage)
    pre_keys = pre_weight.keys()
    pre_value = pre_weight.values()

    cur_weight = cur_model.state_dict()
    del cur_weight['model.embedding_layer.weight']
    cur_keys = cur_weight.keys()

    assert len(pre_keys) == len(cur_keys)
    logging.info('pre-keys: ' + str(pre_keys))