Beispiel #1
0
    def test(self, key):
        """Evaluate on dataset split *key*: teacher-forcing then free-run.

        Merges both metric dicts into one result, logs the scalar entries
        and writes post/resp/gen triples to <out_dir>/<name>_<key>.txt.
        """
        args = self.param.args
        dm = self.param.volatile.dm

        # Pass 1: teacher forcing (likelihood-style metrics).
        dm.restart(key, args.batch_size, shuffle=False)
        metric1 = dm.get_teacher_forcing_metric()

        while True:
            # needhash=True presumably attaches a data hash for the metric's
            # bookkeeping -- TODO confirm in get_next_batch.
            incoming = self.get_next_batch(dm,
                                           key,
                                           restart=False,
                                           needhash=True)
            if incoming is None:
                break
            incoming.args = Storage()
            with torch.no_grad():
                self.net.forward(incoming)
                gen_prob = nn.functional.log_softmax(incoming.gen.w, -1)
            data = incoming.data
            # Metrics expect batch-major numpy arrays; model outputs are
            # length-major, hence the transposes.
            data.resp = incoming.data.resp.detach().cpu().numpy().transpose(
                1, 0)
            data.resp_length = incoming.data.resp_length
            data.gen_prob = gen_prob.detach().cpu().numpy().transpose(1, 0, 2)
            metric1.forward(data)
        res = metric1.close()

        # Pass 2: free-run generation (inference metrics).
        dm.restart(key, args.batch_size, shuffle=False)
        metric2 = dm.get_inference_metric()
        while True:
            incoming = self.get_next_batch(dm, key, restart=False)
            if incoming is None:
                break
            incoming.args = Storage()
            with torch.no_grad():
                self.net.detail_forward(incoming)
            data = incoming.data
            data.resp = incoming.data.resp.detach().cpu().numpy().transpose(
                1, 0)
            data.post = incoming.data.post.detach().cpu().numpy().transpose(
                1, 0)
            data.gen = incoming.gen.w_o.detach().cpu().numpy().transpose(1, 0)
            metric2.forward(data)
        res.update(metric2.close())

        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)
        filename = args.out_dir + "/%s_%s.txt" % (args.name, key)

        with open(filename, 'w') as f:
            logging.info("%s Test Result:", key)
            # NOTE(review): this loop shadows the `key` parameter (unused
            # after this point, so harmless but confusing).
            for key, value in res.items():
                # The bytes check looks like a py2 leftover; a sibling block
                # tests for str -- confirm which type the metrics emit.
                if isinstance(value, float) or isinstance(value, bytes):
                    logging.info("\t{}:\t{}".format(key, value))
                    f.write("{}:\t{}\n".format(key, value))
            for i in range(len(res['post'])):
                f.write("post:\t%s\n" % " ".join(res['post'][i]))
                f.write("resp:\t%s\n" % " ".join(res['resp'][i]))
                f.write("gen:\t%s\n" % " ".join(res['gen'][i]))
            f.flush()
        logging.info("result output to %s.", filename)
Beispiel #2
0
    def forward(self, incoming):
        """Teacher-forced multi-turn forward pass; sets ``incoming.result.loss``.

        Iterates over the turn dimension of ``incoming.data.post``, running
        the embed/encode/connect/generate sub-modules once per turn; the
        sub-modules accumulate word_loss/atten_loss into ``incoming.result``.
        Raises FloatingPointError if the final loss is NaN.
        """
        incoming.result = Storage()
        # word_loss starts as None; presumably genNetwork assigns/accumulates
        # it during the loop -- TODO confirm (None + x would raise otherwise).
        incoming.result.word_loss = None
        incoming.state = state = Storage()
        incoming.statistic = statistic = Storage()
        statistic.batch_num = incoming.data.post.shape[0]
        statistic.sen_num = 0
        statistic.sen_loss = []

        # post appears to be [batch, turn, ...]: iterate over turns.
        state.last = incoming.data.post.shape[1]

        for i in range(state.last):
            state.num = i
            self.embLayer.forward(incoming)
            self.postEncoder.forward(incoming)
            self.wikiEncoder.forward(incoming)
            if not self.args.disentangle:
                self.connectLayer.forward(incoming)
            else:
                self.connectLayer.forward_disentangle(incoming)
            self.genNetwork.forward(incoming)

        incoming.result.loss = incoming.result.word_loss + incoming.result.atten_loss

        # Fail loudly on NaN loss instead of silently diverging.
        if torch.isnan(incoming.result.loss).detach().cpu().numpy() > 0:
            logging.info("Nan detected")
            logging.info(incoming.result)
            raise FloatingPointError("Nan detected")
Beispiel #3
0
    def evaluate(self, key):
        """Compute averaged metrics on split *key* plus sample generations.

        Returns a Storage containing the mean of every per-batch result
        entry, ``show_str`` samples for the batches listed in
        ``args.show_sample``, and ``perplexity_avg_on_batch``.
        """
        args = self.param.args
        dm = self.param.volatile.dm

        dm.restart(key, args.batch_size, shuffle=False)

        # Full pass over the split collecting per-batch result Storages.
        result_arr = []
        while True:
            incoming = self.get_next_batch(dm, key, restart=False)
            if incoming is None:
                break
            incoming.args = Storage()

            with torch.no_grad():
                self.net.forward(incoming)
            result_arr.append(incoming.result)

        # Re-fetch selected batches in detail (free-run) mode for inspection.
        detail_arr = Storage()
        for i in args.show_sample:
            index = [i * args.batch_size + j for j in range(args.batch_size)]
            incoming = self.get_select_batch(dm, key, index)
            incoming.args = Storage()
            with torch.no_grad():
                self.net.detail_forward(incoming)
            detail_arr["show_str%d" % i] = incoming.result.show_str

        # Average every metric across batches; assumes at least one batch
        # was produced (result_arr[0] raises IndexError otherwise).
        detail_arr.update(
            {key: get_mean(result_arr, key)
             for key in result_arr[0]})
        # NOTE(review): exp(mean word loss); a size-weighted average would
        # differ if the last batch is smaller -- confirm intended.
        detail_arr.perplexity_avg_on_batch = np.exp(detail_arr.word_loss)
        return detail_arr
Beispiel #4
0
    def detail_forward(self, incoming):
        """Free-run decoding plus human-readable post/resp/golden strings.

        Runs ``self.freerun`` on the encoded post, then converts generated,
        golden and post token ids back to text via the dataloader and joins
        them into ``incoming.result.show_str``.
        """
        inp = Storage()
        batch_size = inp.batch_size = incoming.data.batch_size
        inp.init_h = incoming.conn.init_h
        inp.post = incoming.hidden.h
        inp.post_length = incoming.data.post_length
        inp.embLayer = incoming.resp.embLayer
        inp.dm = self.param.volatile.dm
        inp.max_sent_length = self.args.max_sent_length

        incoming.gen = gen = Storage()
        self.freerun(inp, gen)

        # Tensors appear length-major: column i holds sample i's token ids.
        dm = self.param.volatile.dm
        w_o = gen.w_o.detach().cpu().numpy()
        incoming.result.resp_str = resp_str = \
          [" ".join(dm.convert_ids_to_tokens(w_o[:, i].tolist())) for i in range(batch_size)]
        incoming.result.golden_str = golden_str = \
          [" ".join(dm.convert_ids_to_tokens(incoming.data.resp[:, i].detach().cpu().numpy().tolist()))\
          for i in range(batch_size)]
        incoming.result.post_str = post_str = \
          [" ".join(dm.convert_ids_to_tokens(incoming.data.post[:, i].detach().cpu().numpy().tolist()))\
          for i in range(batch_size)]
        # One "post/resp/golden" paragraph per sample, newline-separated.
        incoming.result.show_str = "\n".join(["post: " + a + "\n" + "resp: " + b + "\n" + \
          "golden: " + c + "\n" \
          for a, b, c in zip(post_str, resp_str, golden_str)])
Beispiel #5
0
 def _preprocess_batch(self, data):
     """Wrap a raw batch in Storage and move the sentences to the device.

     The sentence matrix is transposed to length-major layout
     (length * batch_size) before conversion to a LongTensor.
     """
     batch = Storage()
     batch.data = fields = Storage(data)
     fields.batch_size = fields.sent.shape[0]
     fields.sent = cuda(torch.LongTensor(fields.sent.transpose(1, 0)))
     return batch
Beispiel #6
0
def _preprocess_batch(data):
    """Wrap a raw batch in Storage and move post/resp to the device.

    Both arrays are transposed to length-major layout (length * batch_size).
    """
    fields = Storage(data)
    batch = Storage()
    batch.data = fields
    fields.batch_size = fields.post.shape[0]
    fields.post = cuda(torch.LongTensor(fields.post.transpose(1, 0)))
    fields.resp = cuda(torch.LongTensor(fields.resp.transpose(1, 0)))
    return batch
Beispiel #7
0
    def evaluate(self, key):
        """Compute mean loss and accuracy on split *key*, plus show-sample strings."""
        args = self.param.args
        dm = self.param.volatile.dm

        dm.restart(key, args.batch_size, shuffle=False)

        # Full pass over the split collecting per-batch result Storages.
        result_arr = []
        while True:
            incoming = self.get_next_batch(dm, key, restart=False)
            if incoming is None:
                break
            incoming.args = Storage()

            with torch.no_grad():
                self.net.forward(incoming)
            result_arr.append(incoming.result)

        # Re-fetch selected batches for human-readable samples.
        detail_arr = Storage()
        for i in args.show_sample:
            index = [i * args.batch_size + j for j in range(args.batch_size)]
            incoming = self.get_select_batch(dm, key, index)
            incoming.args = Storage()
            with torch.no_grad():
                self.net.forward(incoming)
            detail_arr["show_str%d" % i] = incoming.result.show_str

        # NOTE(review): label_key/prediction_key imply each result holds
        # label/prediction entries -- confirm against get_accuracy's contract.
        detail_arr.update({'loss':get_mean(result_arr, 'loss'), \
         'accuracy':get_accuracy(result_arr, label_key='label', prediction_key='prediction')})
        return detail_arr
Beispiel #8
0
 def detail_forward(self, incoming):
     """Embed the full post and the wiki entry for the active turn.

     Exposes the embedding layer itself on ``incoming.resp`` so later
     stages can embed generated tokens on the fly.
     """
     turn = incoming.state.num
     emb = self.embLayer
     incoming.post = Storage()
     incoming.resp = Storage()
     incoming.wiki = Storage()
     incoming.post.embedding = emb(LongTensor(incoming.data.post))
     incoming.wiki.embedding = emb(incoming.data.wiki[:, turn])
     incoming.resp.embLayer = emb
Beispiel #9
0
 def _preprocess_batch(self, data):
     """Wrap a raw batch in Storage, move sentences to the device and
     build a 0/1 attention mask from the per-sentence lengths.
     """
     batch = Storage()
     batch.data = fields = Storage(data)
     fields.batch_size = fields.sent.shape[0]
     # NOTE(review): original comment said "length * batch_size" although
     # shape[0] is used as the batch size -- layout unconfirmed.
     fields.sent = cuda(torch.LongTensor(fields.sent))
     fields.sent_attnmask = zeros(*fields.sent.shape)
     for row, sent_len in enumerate(fields.sent_length):
         fields.sent_attnmask[row, :sent_len] = 1
     return batch
Beispiel #10
0
def main(args, load_exclude_set, restoreCallback):
    """Entry point: configure logging/CUDA, build the dataloader and word
    vectors (optionally cached), then train/test/load a TransformerLM.

    Args:
        args: option Storage with mode flags and hyper-parameters.
        load_exclude_set: parameter names excluded when restoring a model.
        restoreCallback: callback invoked after restoring a checkpoint.

    Returns:
        The constructed model when ``args.mode == "load"``, else None.

    Raises:
        ValueError: if ``args.mode`` is unrecognized.
    """
    # NOTE(review): filename=0 is falsy, so logging falls back to stderr;
    # looks intentional, but a real path would log to a file -- confirm.
    logging.basicConfig(\
     filename=0,\
     level=logging.DEBUG,\
     format='%(asctime)s %(filename)s[line:%(lineno)d] %(message)s',\
     datefmt='%H:%M:%S')

    if args.debug:
        debug()
    logging.info(json.dumps(args, indent=2))

    cuda_init(0, args.cuda)

    volatile = Storage()
    volatile.load_exclude_set = load_exclude_set
    volatile.restoreCallback = restoreCallback

    data_class = LanguageGeneration
    data_arg = Storage()
    data_arg.file_id = args.dataid
    data_arg.tokenizer = args.tokenizer
    data_arg.max_sent_length = args.max_sent_length
    data_arg.convert_to_lower_letter = args.convert_to_lower_letter
    data_arg.min_frequent_vocab_times = args.min_frequent_vocab_times
    data_arg.min_rare_vocab_times = args.min_rare_vocab_times
    wordvec_class = GeneralWordVector

    def load_dataset(data_arg, wvpath, embedding_size):
        # Build the dataloader and an embedding matrix aligned with its
        # frequent vocabulary.
        wv = wordvec_class(wvpath)
        dm = data_class(**data_arg)
        return dm, wv.load_matrix(embedding_size, dm.frequent_vocab_list)

    if args.cache:
        dm, volatile.wordvec = try_cache(
            load_dataset, (data_arg, args.wvpath, args.embedding_size),
            args.cache_dir, data_class.__name__ + "_" + wordvec_class.__name__)
    else:
        dm, volatile.wordvec = load_dataset(data_arg, args.wvpath,
                                            args.embedding_size)

    volatile.dm = dm

    param = Storage()
    param.args = args
    param.volatile = volatile

    model = TransformerLM(param)
    if args.mode == "train":
        model.train_process()
    elif args.mode == "test":
        test_res = model.test_process()

        # BUG FIX: use a context manager so the result file is flushed and
        # closed (the original json.dump(..., open(...)) leaked the handle).
        with open("./result.json", "w") as result_file:
            json.dump(test_res, result_file)
    elif args.mode == "load":
        return model
    else:
        raise ValueError("Unknown mode")
Beispiel #11
0
 def _preprocess_batch(self, data):
     """Wrap a raw batch in Storage and move resp/atten/wiki to the device.

     ``post`` is intentionally left untouched (its conversion was disabled
     in the original); only resp, atten and wiki become LongTensors.
     """
     batch = Storage()
     batch.data = fields = Storage(data)
     fields.batch_size = fields.post.shape[0]
     fields.resp = cuda(torch.LongTensor(fields.resp))
     fields.atten = cuda(torch.LongTensor(fields.atten))
     fields.wiki = cuda(torch.LongTensor(fields.wiki))
     return batch
Beispiel #12
0
	def forward(self, incoming):
		"""Embed post and response token ids.

		input: ``incoming.data`` (post/resp id tensors)
		output: ``incoming.post`` / ``incoming.resp`` embeddings, plus the
		embedding layer itself for downstream decoding.
		"""
		emb = self.embLayer
		post = incoming.post = Storage()
		resp = incoming.resp = Storage()
		post.embedding = emb(incoming.data.post)
		resp.embedding = emb(incoming.data.resp)
		resp.embLayer = emb
Beispiel #13
0
    def detail_forward(self, incoming):
        """Free-run multi-turn rollout: each turn's generated response is
        appended to the context before decoding the next turn.

        Populates ``incoming.acc``/``result``/``state``/``statistic`` and
        mutates ``incoming.data.post``/``post_length`` per turn.
        """
        incoming.acc = Storage()
        incoming.acc.pred = []
        incoming.acc.label = []
        incoming.acc.prob = []

        incoming.result = Storage()
        incoming.state = state = Storage()
        incoming.statistic = statistic = Storage()
        statistic.batch_num = incoming.data.post.shape[0]
        statistic.sen_loss = []
        statistic.sen_num = 0

        state.last = incoming.data.post.shape[1]

        # post, resp: [batch_size, turn_length, sent_length]

        def pad_post(posts):
            '''
            Right-pad variable-length posts into one int matrix.

            :param posts: list, [batch, turn_length, sent_length]
            :return: (padded matrix [batch, max_len], lengths [batch])
            '''
            post_length = np.array(list(map(len, posts)), dtype=int)
            post = np.zeros((len(post_length), np.max(post_length)), dtype=int)
            for j in range(len(post_length)):
                post[j, :len(posts[j])] = posts[j]
            return post, post_length

        dm = self.param.volatile.dm
        ori_post = incoming.data.post.tolist(
        )  # [batch_size, turn_length, sent_length]
        # Trim padding from each post and re-terminate with <eos>, capped at
        # the dataloader's max sentence length.
        ori_post = [[(dm.trim(post) + [dm.eos_id])[:dm._max_sent_length]
                     for post in posts] for posts in ori_post]
        new_post = [each[0] for each in ori_post]
        for i in range(state.last):
            state.num = i
            incoming.data.post, incoming.data.post_length = pad_post(new_post)

            self.embLayer.detail_forward(incoming)
            self.postEncoder.detail_forward(incoming)
            self.wikiEncoder.forward(incoming)
            if not self.args.disentangle:
                self.connectLayer.detail_forward(incoming)
            else:
                self.connectLayer.detail_forward_disentangle(incoming)
            self.genNetwork.detail_forward(incoming)

            # Build the next turn's context: previous post + generated
            # response (wrapped in <go>...<eos>) + next post, truncated to
            # the most recent _max_context_length tokens.
            if i < state.last - 1:
                gen_resp = incoming.state.w_o_all[-1].transpose(
                    0, 1).cpu().tolist()  # [batch, sent_length]
                new_post = []
                for j, gr in enumerate(gen_resp):
                    new_post.append(
                        (ori_post[j][i] + ([dm.go_id] + dm.trim(gr) +
                                           [dm.eos_id])[:dm._max_sent_length] +
                         ori_post[j][i + 1])[-dm._max_context_length:])
Beispiel #14
0
def main(args, load_exclude_set, restoreCallback):
    """Entry point: configure logging/CUDA, build the dialog dataloader and
    word vectors (optionally cached), then train or test a Seq2seq model.

    Args:
        args: option Storage with mode flags and hyper-parameters.
        load_exclude_set: parameter names excluded when restoring a model.
        restoreCallback: callback invoked after restoring a checkpoint.

    Raises:
        ValueError: if ``args.mode`` is not "train" or "test".
    """
    # NOTE(review): filename=0 is falsy, so logging falls back to stderr;
    # looks intentional, but a real path would log to a file -- confirm.
    logging.basicConfig(\
     filename=0,\
     level=logging.DEBUG,\
     format='%(asctime)s %(filename)s[line:%(lineno)d] %(message)s',\
     datefmt='%H:%M:%S')

    if args.debug:
        debug()
    logging.info(json.dumps(args, indent=2))

    cuda_init(0, args.cuda)

    volatile = Storage()
    volatile.load_exclude_set = load_exclude_set
    volatile.restoreCallback = restoreCallback

    data_class = SingleTurnDialog.load_class(args.dataset)
    data_arg = Storage()
    data_arg.file_id = args.datapath
    wordvec_class = WordVector.load_class(args.wvclass)
    if wordvec_class is None:
        wordvec_class = Glove

    def load_dataset(data_arg, wvpath, embedding_size):
        # Build the dataloader and word vectors for its vocabulary.
        wv = wordvec_class(wvpath)
        dm = data_class(**data_arg)
        return dm, wv.load(embedding_size, dm.vocab_list)

    if args.cache:
        dm, volatile.wordvec = try_cache(
            load_dataset, (data_arg, args.wvpath, args.embedding_size),
            args.cache_dir, data_class.__name__ + "_" + wordvec_class.__name__)
    else:
        dm, volatile.wordvec = load_dataset(data_arg, args.wvpath,
                                            args.embedding_size)

    volatile.dm = dm

    param = Storage()
    param.args = args
    param.volatile = volatile

    model = Seq2seq(param)
    if args.mode == "train":
        model.train_process()
    elif args.mode == "test":
        test_res = model.test_process()

        # Metrics may come back as bytes (py2-style); stringify for JSON.
        for key, val in test_res.items():
            if isinstance(val, bytes):
                test_res[key] = str(val)
        # BUG FIX: use a context manager so the result file is flushed and
        # closed (the original json.dump(..., open(...)) leaked the handle).
        with open("./result.json", "w") as result_file:
            json.dump(test_res, result_file)
    else:
        raise ValueError("Unknown mode")
Beispiel #15
0
def main(args, load_exclude_set, restoreCallback):
    """Entry point: configure logging/CUDA, build a GPT-2-tokenized
    dataloader (optionally cached), then train or test a GPT2LM model.

    Args:
        args: option Storage with mode flags and hyper-parameters.
        load_exclude_set: parameter names excluded when restoring a model.
        restoreCallback: callback invoked after restoring a checkpoint.

    Raises:
        ValueError: if ``args.mode`` is not "train" or "test".
    """
    # NOTE(review): filename=0 is falsy, so logging falls back to stderr;
    # looks intentional, but a real path would log to a file -- confirm.
    logging.basicConfig(\
     filename=0,\
     level=logging.DEBUG,\
     format='%(asctime)s %(filename)s[line:%(lineno)d] %(message)s',\
     datefmt='%H:%M:%S')

    if args.debug:
        debug()
    logging.info(json.dumps(args, indent=2))

    cuda_init(0, args.cuda)

    volatile = Storage()
    volatile.load_exclude_set = load_exclude_set
    volatile.restoreCallback = restoreCallback

    data_class = LanguageGeneration
    data_arg = Storage()
    data_arg.file_id = args.dataid
    data_arg.max_sent_length = args.max_sent_length
    data_arg.convert_to_lower_letter = args.convert_to_lower_letter
    data_arg.pretrained = args.pretrained
    data_arg.tokenizer = args.pretrained_model

    def load_dataset(data_arg):
        # Replace the tokenizer *name* with an instantiated GPT-2 tokenizer
        # on a copy, leaving the caller's data_arg untouched.
        tokenizer = PretrainedTokenizer(
            GPT2Tokenizer.from_pretrained(data_arg.tokenizer))
        new_arg = Storage(data_arg.copy())
        new_arg.tokenizer = tokenizer
        dm = data_class(**new_arg)
        return dm

    if args.cache:
        dm = try_cache(load_dataset, (data_arg, ), args.cache_dir,
                       data_class.__name__)
    else:
        dm = load_dataset(data_arg)

    volatile.dm = dm

    param = Storage()
    param.args = args
    param.volatile = volatile

    model = GPT2LM(param)
    if args.mode == "train":
        model.train_process()
    elif args.mode == "test":
        test_res = model.test_process()

        # BUG FIX: use a context manager so the result file is flushed and
        # closed (the original json.dump(..., open(...)) leaked the handle).
        with open("./result.json", "w") as result_file:
            json.dump(test_res, result_file)
    else:
        raise ValueError("Unknown mode")
Beispiel #16
0
    def test(self, key):
        """Evaluate on split *key*: teacher-forcing then free-run metrics.

        Writes scalar metrics and post/resp/gen triples to
        <out_dir>/<name>_<key>.txt and returns the scalar/str entries.
        """
        args = self.param.args
        dm = self.param.volatile.dm

        # Pass 1: teacher forcing (sampling_proba=1 forces ground truth).
        metric1 = dm.get_teacher_forcing_metric()
        batch_num, batches = self.get_batches(dm, key)
        logging.info("eval teacher-forcing")
        for incoming in tqdm.tqdm(batches, total=batch_num):
            incoming.args = Storage()
            incoming.args.sampling_proba = 1.
            with torch.no_grad():
                self.net.forward(incoming)
                gen_log_prob = nn.functional.log_softmax(
                    incoming.gen.w_pro, -1)
            data = incoming.data
            data.resp_allvocabs = LongTensor(incoming.data.resp_allvocabs)
            data.resp_length = incoming.data.resp_length
            # Transpose to batch-major for the metric.
            data.gen_log_prob = gen_log_prob.transpose(1, 0)
            metric1.forward(data)
        res = metric1.close()

        # Pass 2: free-run generation (inference metrics).
        metric2 = dm.get_inference_metric()
        batch_num, batches = self.get_batches(dm, key)
        logging.info("eval free-run")
        for incoming in tqdm.tqdm(batches, total=batch_num):
            incoming.args = Storage()
            with torch.no_grad():
                self.net.detail_forward(incoming)
            data = incoming.data
            data.gen = incoming.gen.w_o.detach().cpu().numpy().transpose(1, 0)
            metric2.forward(data)
        res.update(metric2.close())

        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)
        filename = args.out_dir + "/%s_%s.txt" % (args.name, key)

        with codecs.open(filename, 'w', encoding='utf8') as f:
            logging.info("%s Test Result:", key)
            # NOTE(review): this loop shadows the `key` parameter (unused
            # after this point, so harmless but confusing).
            for key, value in res.items():
                if isinstance(value, float) or isinstance(value, str):
                    logging.info("\t{}:\t{}".format(key, value))
                    f.write("{}:\t{}\n".format(key, value))
            for i in range(len(res['post'])):
                f.write("post:\t%s\n" % " ".join(res['post'][i]))
                f.write("resp:\t%s\n" % " ".join(res['resp'][i]))
                f.write("gen:\t%s\n" % " ".join(res['gen'][i]))
            f.flush()
        logging.info("result output to %s.", filename)
        # Only JSON-friendly scalar/str entries are returned to the caller.
        return {
            key: val
            for key, val in res.items() if isinstance(val, (str, int, float))
        }
Beispiel #17
0
def main(args, load_exclude_set, restoreCallback):
    """Entry point: configure logging/CUDA, build a BERT-tokenized dialog
    dataloader plus word vectors (optionally cached), then train or test a
    Seq2seq model.

    Raises:
        ValueError: if ``args.mode`` is not "train" or "test".
    """
    # NOTE(review): filename=0 is falsy, so logging falls back to stderr;
    # looks intentional, but a real path would log to a file -- confirm.
    logging.basicConfig(\
     filename=0,\
     level=logging.DEBUG,\
     format='%(asctime)s %(filename)s[line:%(lineno)d] %(message)s',\
     datefmt='%H:%M:%S')

    if args.debug:
        debug()
    logging.info(json.dumps(args, indent=2))

    cuda_init(0, args.cuda)

    volatile = Storage()
    volatile.load_exclude_set = load_exclude_set
    volatile.restoreCallback = restoreCallback

    data_class = SingleTurnDialog.load_class(args.dataset)
    data_arg = Storage()
    # "#OpenSubtitles" selects the sub-dataset inside the data path.
    data_arg.file_id = args.datapath + "#OpenSubtitles"
    data_arg.tokenizer = PretrainedTokenizer(
        BertTokenizer.from_pretrained(args.bert_vocab))
    data_arg.pretrained = "bert"
    wordvec_class = WordVector.load_class(args.wvclass)
    if wordvec_class is None:
        wordvec_class = Glove

    def load_dataset(data_arg, wvpath, embedding_size):
        # Build the dataloader and an embedding matrix aligned with its
        # frequent vocabulary.
        wv = wordvec_class(wvpath)
        dm = data_class(**data_arg)
        return dm, wv.load_matrix(embedding_size, dm.frequent_vocab_list)

    if args.cache:
        dm, volatile.wordvec = try_cache(
            load_dataset, (data_arg, args.wvpath, args.embedding_size),
            args.cache_dir, data_class.__name__ + "_" + wordvec_class.__name__)
    else:
        dm, volatile.wordvec = load_dataset(data_arg, args.wvpath,
                                            args.embedding_size)

    volatile.dm = dm

    param = Storage()
    param.args = args
    param.volatile = volatile

    model = Seq2seq(param)
    if args.mode == "train":
        model.train_process()
    elif args.mode == "test":
        model.test_process()
    else:
        raise ValueError("Unknown mode")
	def _preprocess_batch(self, data):
		"""Wrap a raw batch in Storage and move post/resp to the device.

		Both matrices are transposed to length-major layout
		(length * batch_size).
		"""
		batch = Storage()
		batch.data = fields = Storage(data)
		fields.batch_size = fields.post.shape[0]
		fields.post = cuda(torch.LongTensor(fields.post.transpose(1, 0)))
		fields.resp = cuda(torch.LongTensor(fields.resp.transpose(1, 0)))
		return batch
Beispiel #19
0
	def forward(self, incoming):
		"""Teacher-forced response decoding; sets word_loss and perplexity."""
		decoder_input = Storage()
		decoder_input.resp_length = incoming.data.resp_length
		decoder_input.embedding = incoming.resp.embedding
		decoder_input.init_h = incoming.conn.init_h

		outputs = Storage()
		incoming.gen = outputs
		self.teacherForcing(decoder_input, outputs)

		# Targets are the response shifted left by one (the leading token is
		# never predicted), so effective lengths are resp_length - 1.
		logits_flat = flattenSequence(outputs.w, incoming.data.resp_length-1)
		targets_flat = flattenSequence(incoming.data.resp[1:], incoming.data.resp_length-1)
		incoming.result.word_loss = self.lossCE(logits_flat, targets_flat)
		incoming.result.perplexity = torch.exp(incoming.result.word_loss)
    def forward(self, incoming):
        """Teacher-forced sentence decoding; sets word_loss and perplexity."""
        decoder_input = Storage()
        decoder_input.length = incoming.data.sent_length
        decoder_input.embedding = incoming.sent.embedding

        outputs = Storage()
        incoming.gen = outputs
        self.teacherForcing(decoder_input, outputs)

        # Targets are the sentence shifted left by one (skip the leading
        # token), so effective lengths are sent_length - 1.
        logits_flat = flattenSequence(outputs.w, incoming.data.sent_length - 1)
        targets_flat = flattenSequence(
            incoming.data.sent.transpose(0, 1)[1:],
            incoming.data.sent_length - 1)
        incoming.result.word_loss = self.lossCE(logits_flat, targets_flat)
        incoming.result.perplexity = torch.exp(incoming.result.word_loss)
Beispiel #21
0
 def structuring(self, name, result, nodes=None):
     """Recursively chain raw_data entries into a nested Storage.

     Follows ``name`` through ``self.raw_data``, building ``result.name`` /
     ``result.sub`` links; an unknown name terminates the chain and the
     remaining unvisited entries are collected into ``result.args``.

     Args:
         name: current node name to resolve.
         result: Storage populated in place.
         nodes: names already visited; a fresh list per call by default.
     """
     # BUG FIX: the original used a mutable default ``nodes=[]``, which is
     # shared across calls -- visited names leaked between invocations.
     if nodes is None:
         nodes = []
     if name in self.raw_data:
         nodes.append(name)
         result.name = name
         if len(self.raw_data) > len(nodes):
             result.sub = Storage()
             self.structuring(self.raw_data[name], result.sub, nodes)
         else:
             result.sub = Storage(name=self.raw_data[name])
     else:
         result.name = name
         result.args = Storage()
         for dname in self.raw_data:
             if dname not in nodes:
                 result.args[dname] = self.raw_data[dname]
Beispiel #22
0
	def forward(self, incoming):
		"""Encode the post with BERT under no_grad and store dropped-out states.

		``hidden.h`` holds the last BERT layer (dropout applied);
		``hidden.h_n`` is the dropped-out first position of that layer.
		"""
		hidden = Storage()
		incoming.hidden = hidden
		with torch.no_grad():
			layers, _ = self.bert_exclude(incoming.data.post_bert)
		last_layer = layers[-1]  # [length, batch, hidden]
		hidden.h_n = self.drop(last_layer[0])
		hidden.h = self.drop(last_layer)
Beispiel #23
0
 def forward(self, incoming):
     """Embed post/resp/wiki for the current turn, with dropout applied.

     input: ``incoming.data`` (per-turn id tensors)
     output: ``incoming.post`` / ``incoming.resp`` / ``incoming.wiki``
     embeddings, plus the raw embedding layer on ``incoming.resp``.
     """
     turn = incoming.state.num
     emb = self.embLayer
     post = incoming.post = Storage()
     post.embedding = self.drop(emb(LongTensor(incoming.data.post[:, turn])))
     resp = incoming.resp = Storage()
     resp.embedding = self.drop(emb(incoming.data.resp[:, turn]))
     wiki = incoming.wiki = Storage()
     wiki.embedding = self.drop(emb(incoming.data.wiki[:, turn]))
     resp.embLayer = emb
Beispiel #24
0
	def detail_forward(self, incoming):
		"""Run the free-run pipeline: embed, encode, connect, then generate
		in detail (inference) mode instead of teacher forcing."""
		incoming.result = Storage()
		pipeline = (self.embLayer.forward, self.postEncoder.forward,
			self.connectLayer.forward, self.genNetwork.detail_forward)
		for stage in pipeline:
			stage(incoming)
Beispiel #25
0
def test_process(opt):
    """Build data, model and (optionally) a restored checkpoint, then test.

    Args:
        opt: option namespace; mutated in place (cuda flag, d_word_vec,
            batch_size, vocab sizes, pad indices, n_position).
    """
    opt.cuda = not opt.no_cuda
    opt.d_word_vec = opt.d_model
    opt.batch_size = opt.b

    device = torch.device('cuda' if opt.cuda else 'cpu')

    data_class = SingleTurnDialog.load_class('OpenSubtitles')
    data_arg = Storage()
    data_arg.file_id = opt.datapath
    data_arg.min_vocab_times = 20

    def load_dataset(data_arg, wvpath, embedding_size):
        # wvpath/embedding_size are unused here but kept for interface
        # parity with sibling loaders; only the dataloader is built.
        dm = data_class(**data_arg)
        return dm

    opt.n_position = 100
    dm = load_dataset(data_arg, None, opt.n_position)

    opt.n_src_vocab = dm.valid_vocab_len
    opt.n_trg_vocab = dm.valid_vocab_len
    opt.n_vocab_size = dm.valid_vocab_len
    opt.src_pad_idx = 0
    opt.trg_pad_idx = 0
    opt.pad_idx = 0

    model = transformer_model(opt, device).to(device)

    # Idiom fix (PEP 8): compare to None with "is not", and drop the
    # redundant parentheses.
    if opt.restore is not None:
        checkpoint = torch.load(opt.restore)
        model.load_state_dict(checkpoint['net'])

    dl = cotk.dataloader.OpenSubtitles(
        opt.datapath, min_vocab_times=data_arg.min_vocab_times)
    test(model, dm, device, opt, dl)
Beispiel #26
0
    def test(self, key):
        """Evaluate classification accuracy on split *key* and write the
        scalar metrics to <out_dir>/<name>_<key>.txt.
        """
        args = self.param.args
        dm = self.param.volatile.dm

        metric = dm.get_accuracy_metric()
        batch_num, batches = self.get_batches(dm, key)
        logging.info("eval accuracy")
        for incoming in tqdm.tqdm(batches, total=batch_num):
            incoming.args = Storage()
            with torch.no_grad():
                self.net.forward(incoming)
            data = incoming.data
            # BUG FIX: the original referenced the undefined name
            # ``imcoming`` here, which raised NameError at runtime.
            data.prediction = incoming.result.prediction
            data.label = incoming.data.label
            metric.forward(data)
        res = metric.close()

        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)
        filename = args.out_dir + "/%s_%s.txt" % (args.name, key)

        with open(filename, 'w') as f:
            logging.info("%s Test Result:", key)
            # Loop variable renamed from ``key`` to avoid shadowing the
            # method parameter (sibling methods share this smell).
            for metric_name, value in res.items():
                # bytes check looks like a py2 leftover; sibling code tests
                # for str -- confirm which type the metric emits.
                if isinstance(value, float) or isinstance(value, bytes):
                    logging.info("\t{}:\t{}".format(metric_name, value))
                    f.write("{}:\t{}\n".format(metric_name, value))
            f.flush()
        logging.info("result output to %s.", filename)
Beispiel #27
0
    def train(self, batch_num):
        """Run *batch_num* classifier training iterations with gradient
        accumulation every ``args.batch_num_per_gradient`` batches.
        """
        args = self.param.args
        dm = self.param.volatile.dm
        datakey = 'train'

        for i in range(batch_num):
            self.now_batch += 1
            incoming = self.get_next_batch(dm, datakey)
            incoming.args = Storage()

            # NOTE(review): zero_grad fires on the same iterations as
            # optimizer.step below -- i.e. grads are zeroed *before* the
            # forward of every k-th batch. Matches sibling blocks; confirm
            # this is the intended accumulation scheme.
            if (i + 1) % args.batch_num_per_gradient == 0:
                self.zero_grad()
            self.net.forward(incoming)

            loss = incoming.result.loss
            # Batch accuracy from exact label/prediction matches.
            accuracy = np.mean(
                (incoming.result.label == incoming.result.prediction
                 ).float().detach().cpu().numpy())
            detail_arr = storage_to_list(incoming.result)
            detail_arr.update({'accuracy_on_batch': accuracy})
            self.trainSummary(self.now_batch, detail_arr)
            logging.info("batch %d : classification loss=%f, batch accuracy=%f", \
             self.now_batch, loss.detach().cpu().numpy(), accuracy)

            loss.backward()

            # Clip and step once per accumulation window.
            if (i + 1) % args.batch_num_per_gradient == 0:
                nn.utils.clip_grad_norm_(self.net.parameters(), args.grad_clip)
                self.optimizer.step()
Beispiel #28
0
    def train(self, batch_num, total_step_counter):
        """Run *batch_num* scheduled-sampling training iterations.

        sampling_proba decays from 1 toward 0 via inverse sigmoid decay of
        the global step counter. Returns the updated counter.
        """
        args = self.param.args
        dm = self.param.volatile.dm
        datakey = 'train'

        for i in range(batch_num):
            self.now_batch += 1
            incoming = self.get_next_batch(dm, datakey)
            incoming.args = Storage()
            # Probability of feeding the model's own sample instead of the
            # ground truth grows with training progress.
            incoming.args.sampling_proba = 1. - \
               inverse_sigmoid_decay(args.decay_factor, total_step_counter)

            # Grads are zeroed at every k-th batch, matching the step below.
            if (i + 1) % args.batch_num_per_gradient == 0:
                self.zero_grad()
            self.net.forward(incoming)

            loss = incoming.result.loss
            self.trainSummary(self.now_batch, storage_to_list(incoming.result))
            logging.info("batch %d : gen loss=%f", self.now_batch,
                         loss.detach().cpu().numpy())

            loss.backward()

            # Clip and step once per accumulation window.
            if (i + 1) % args.batch_num_per_gradient == 0:
                nn.utils.clip_grad_norm_(self.net.parameters(), args.grad_clip)
                self.optimizer.step()

            total_step_counter += 1

        return total_step_counter
Beispiel #29
0
 def load_dataset(data_arg):
     """Build the dataloader described by *data_arg*, replacing the
     tokenizer name with an instantiated GPT-2 tokenizer on a copy so the
     caller's arguments stay untouched."""
     arg_copy = Storage(data_arg.copy())
     arg_copy.tokenizer = PretrainedTokenizer(
         GPT2Tokenizer.from_pretrained(data_arg.tokenizer))
     return data_class(**arg_copy)
    def train(self, batch_num):
        """Run *batch_num* training iterations, exposing the current epoch
        to the network via ``incoming.now_epoch``.
        """
        args = self.param.args
        dm = self.param.volatile.dm
        datakey = 'train'

        for i in range(batch_num):
            self.now_batch += 1
            incoming = self.get_next_batch(dm, datakey)
            incoming.args = Storage()
            incoming.now_epoch = self.now_epoch

            # Grads are zeroed at every k-th batch, matching the step below.
            if (i + 1) % args.batch_num_per_gradient == 0:
                self.zero_grad()
            self.net.forward(incoming)

            loss = incoming.result.loss
            self.trainSummary(self.now_batch, storage_to_list(incoming.result))
            logging.info("batch %d : gen loss=%f", self.now_batch,
                         loss.detach().cpu().numpy())

            loss.backward()

            # Clip and step once per accumulation window.
            if (i + 1) % args.batch_num_per_gradient == 0:
                nn.utils.clip_grad_norm_(self.net.parameters(), args.grad_clip)
                self.optimizer.step()