예제 #1
0
파일: test_pos.py 프로젝트: philip30/chainn
    def test_read_train(self):
        """Check the vocabularies and id sequences produced for tagged training sentences."""
        sentences = ["I_NNP am_VBZ Philip_NNP", "I_NNP am_VBZ student_NN"]
        X, Y, data = load_pos_train_data(sentences)
        data = list(data)

        # Expected vocabularies. The bare lookups are kept for their effect:
        # presumably indexing a Vocabulary registers an unseen token
        # (TODO confirm against the Vocabulary implementation).
        x_exp = Vocabulary()
        y_exp = Vocabulary(unk=False)
        for word in ("I", "am"):
            x_exp[word]
        for tag in ("NNP", "VBZ", "NNP", "NN"):
            y_exp[tag]

        self.assertVocEqual(X, x_exp)
        self.assertVocEqual(Y, y_exp)

        # Expected data: the third word of each sentence ("Philip", "student")
        # is expected to come back as the unknown-word id.
        word_exp = [
            [x_exp["I"], x_exp["am"], x_exp.unk_id()]
            for _ in range(2)
        ]
        label_exp = [
            [y_exp[tag] for tag in ("NNP", "VBZ", "NNP")],
            [y_exp[tag] for tag in ("NNP", "VBZ", "NN")],
        ]

        # Each expected sample is a (word ids, label ids) tuple.
        data_exp = list(zip(word_exp, label_exp))

        self.assertEqual(data, data_exp)
예제 #2
0
parser.add_argument(
    "--init_model",
    type=str,
    help="Init the training weights with saved model.",
)
parser.add_argument(
    "--model",
    type=str,
    choices=["lstm"],
    default="lstm",
    help="Type of model being trained.",
)
parser.add_argument(
    "--unk_cut",
    type=int,
    default=1,
    help="Threshold for words in corpora to be treated as unknown.",
)
parser.add_argument(
    "--dropout",
    type=positive_decimal,
    default=0.2,
    help="Dropout ratio for LSTM.",
)
parser.add_argument(
    "--seed",
    type=int,
    default=0,
    help="Seed for RNG. 0 for totally random seed.",
)
args = parser.parse_args()

# When CPU use is requested, overwrite the GPU setting with -1.
if args.use_cpu:
    args.gpu = -1

""" Training """
# Build the trainer from the parsed RNG seed (args.seed) and GPU setting (args.gpu).
trainer   = ParallelTrainer(args.seed, args.gpu)

# data
# The tagged corpus is read from standard input; --unk_cut is forwarded as
# the loader's cut threshold.
UF.trace("Loading corpus + dictionary")
X, Y, data = load_pos_train_data(sys.stdin, cut_threshold=args.unk_cut)
# Materialise the batches into a list up front.
# NOTE(review): presumably so the batches can be iterated once per epoch - confirm.
data       = list(batch_generator(data, (X, Y), args.batch))
# Report the sizes of the two objects returned by the loader (X and Y).
UF.trace("INPUT size:", len(X))
UF.trace("LABEL size:", len(Y))
UF.trace("Data loaded.")

""" Setup model """
UF.trace("Setting up classifier")
opt = optimizers.Adam()
# The classifier receives the parsed arguments, both objects returned by the
# data loader (X, Y), the optimizer and the GPU setting; output collection
# follows the verbose flag.
model = ParallelTextClassifier(
    args,
    X,
    Y,
    opt,
    args.gpu,
    activation=F.relu,
    collect_output=args.verbose,
)

""" Training Callback """
def onEpochStart(epoch):
    """Trace the start of an epoch, shown as ``epoch + 1``."""
    epoch_number = epoch + 1
    UF.trace("Starting Epoch", epoch_number)

def report(output, src, trg, trained, epoch):