    def train_one_epoch(self):
        sum_loss = 0.0
        mrr = 0.0

        # train process
        batches = batch_iter(self.L, self.batch_size, 0, self.lookup, 'f', 'g')
        batch_id = 0
        for batch in batches:
            pos, neg = batch
            if len(pos['f']) != len(pos['g']) or len(neg['f']) != len(
                    neg['g']):
                self.logger.info(
                    'The input label file is malformed (mismatched pair lengths).')
                continue
            batch_size = len(pos['f'])
            feed_dict = {
                self.pos_inputs['f']: self.X[pos['f'], :],
                self.pos_inputs['g']: self.Y[pos['g'], :],
                self.cur_batch_size: batch_size
            }
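            # self.X / self.Y are assumed to hold precomputed node embeddings
            # for the two networks, one row per node, and cur_batch_size is
            # assumed to be a scalar placeholder used to normalize the loss.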
            _, cur_loss = self.sess.run([self.train_op, self.loss], feed_dict)

            sum_loss += cur_loss
            batch_id += 1

        # valid process
        valid_size = 0
        if self.valid:
            valid = valid_iter(self.L, self.valid_sample_size, self.lookup,
                               'f', 'g')
            if len(valid['f']) != len(valid['g']):
                self.logger.info(
                    'The input label file is malformed (mismatched pair lengths).')
                return
            valid_size = len(valid['f'])
            feed_dict = {
                self.valid_inputs['f']: self.X[valid['f'], :],
                self.valid_inputs['g']: self.Y[valid['g'], :]
            }
            valid_dist = self.sess.run(self.dot_dist, feed_dict)

            mrr = 0.0
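            # Layout assumption: valid_dist has shape
            # [valid_size, 1 + n_candidates], where column 0 holds the true
            # pair's distance and the rest hold candidates; the true pair's
            # rank is one plus the number of candidates at least as close.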
            for i in range(valid_size):
                fst_dist = valid_dist[i][0]
                pos = 1
                for k in range(1, len(valid_dist[i])):
                    if fst_dist >= valid_dist[i][k]:
                        pos += 1
                mrr += 1. / pos
            self.logger.info(
                'Epoch={}, mean loss={!s}, validation mrr={}'.format(
                    self.cur_epoch, sum_loss / (batch_id + 1e-8),
                    mrr / (valid_size + 1e-8)))
        else:
            self.logger.info('Epoch={}, mean loss={!s}'.format(
                self.cur_epoch, sum_loss / (batch_id + 1e-8)))
        self.cur_epoch += 1

        return sum_loss / (batch_id + 1e-8), mrr / (valid_size + 1e-8)
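
These loops lean on `batch_iter` and `valid_iter` helpers from the surrounding project. A minimal sketch of the dict-style contract this first variant consumes, assuming `L` is a list of labeled (f-node, g-node) anchor pairs and `lookup` maps node names to embedding rows; names and internals here are illustrative, not the project's actual implementation:

import random

def batch_iter(L, batch_size, neg_ratio, lookup, key_a, key_b):
    # Yield (pos, neg) dicts of embedding-row indices per mini-batch.
    pairs = list(L)
    random.shuffle(pairs)
    rows = list(lookup.values())
    for start in range(0, len(pairs), batch_size):
        chunk = pairs[start:start + batch_size]
        pos = {key_a: [lookup[a] for a, _ in chunk],
               key_b: [lookup[b] for _, b in chunk]}
        # Uniform random negatives; the real sampler may be smarter.
        n_neg = int(len(chunk) * neg_ratio)
        neg = {key_a: [random.choice(rows) for _ in range(n_neg)],
               key_b: [random.choice(rows) for _ in range(n_neg)]}
        yield pos, neg

def valid_iter(L, sample_size, lookup, key_a, key_b):
    # Sample held-out anchor pairs for validation.
    sample = random.sample(list(L), min(sample_size, len(L)))
    return {key_a: [lookup[a] for a, _ in sample],
            key_b: [lookup[b] for _, b in sample]}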
Example #2
File: pale_mlp.py   Project: Allen517/dcnh
    def train_one_epoch(self):
        sum_loss = 0.0

        # train process
        batches = batch_iter(self.L, self.batch_size, 0,
                             self.lookup_f, self.lookup_g, 'f', 'g')
        batch_id = 0
        for batch in batches:
            pos_f, pos_g, neg_f, neg_g = batch
            if len(pos_f) != len(pos_g):
                self.logger.info(
                    'The input label file is malformed (mismatched pair lengths).')
                continue
            batch_size = len(pos_f)
            feed_dict = {
                self.pos_f_inputs: self.X[pos_f, :],
                self.pos_g_inputs: self.Y[pos_g, :],
                self.cur_batch_size: batch_size
            }
            _, cur_loss = self.sess.run([self.train_op, self.loss], feed_dict)

            sum_loss += cur_loss
            batch_id += 1
        # valid process
        valid_f, valid_g = valid_iter(self.L, self.valid_sample_size,
                                      self.lookup_f, self.lookup_g, 'f', 'g')
        if len(valid_f) != len(valid_g):
            self.logger.info(
                'The input label file is malformed (mismatched pair lengths).')
            return
        valid_size = len(valid_f)
        feed_dict = {
            self.valid_f_inputs: self.X[valid_f, :],
            self.valid_g_inputs: self.Y[valid_g, :]
        }
        valid_dist = self.sess.run(self.dot_dist, feed_dict)
        # valid_dist = self.sess.run(self.hamming_dist,feed_dict)
        mrr = 0.0
        for i in range(valid_size):
            fst_dist = valid_dist[i][0]
            pos = 1
            for k in range(1, len(valid_dist[i])):
                if fst_dist >= valid_dist[i][k]:
                    pos += 1
            mrr += 1. / pos
        self.logger.info('Epoch={}, mean loss={!s}, mrr={}'.format(
            self.cur_epoch, sum_loss / (batch_id + 1e-8),
            mrr / (valid_size + 1e-8)))
        self.cur_epoch += 1
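
The rank-counting loop above reappears verbatim in every variant. A vectorized NumPy equivalent under the same layout assumption (column 0 is the true pair, smaller distance is better); this helper is illustrative and not part of the project:

import numpy as np

def mean_reciprocal_rank(dist):
    # dist: [n, 1 + k] array whose first column is the true pair's distance.
    # Rank of the true pair = 1 + number of candidates at least as close.
    ranks = 1 + np.sum(dist[:, 1:] <= dist[:, :1], axis=1)
    return float(np.mean(1.0 / ranks))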
Example #3
    def train_one_epoch(self):
        sum_loss = 0.0
        mrr = 0.0
        valid_size = 0  # referenced in the return even when validation is skipped

        # train process
        batches = batch_iter(self.L, self.batch_size, self.neg_ratio,
                             self.lookup, 'src', 'end')

        batch_id = 0
        for batch in batches:
            # training the process from source network to end network
            pos, neg = batch
            if len(pos['src']) != len(pos['end']) or len(
                    neg['src']) != len(neg['end']):
                self.logger.info(
                    'The input label file is malformed (mismatched pair lengths).')
                continue
            batch_size = len(pos['src'])
            feed_dict = {
                self.inputs_pos['src']: self.F[pos['src'], :],
                self.inputs_pos['end']: self.G[pos['end'], :],
                self.inputs_neg['src']: self.F[neg['src'], :],
                self.inputs_neg['end']: self.G[neg['end'], :],
                self.cur_batch_size: batch_size
            }
            _, cur_loss = self.sess.run([self.train_op, self.loss], feed_dict)

            sum_loss += cur_loss
            batch_id += 1

        if self.valid:
            # valid process
            valid = valid_iter(self.L, self.valid_sample_size, self.lookup,
                               'src', 'end')
            if len(valid['src']) != len(valid['end']):
                self.logger.info(
                    'The input label file is malformed (mismatched pair lengths).')
                return
            valid_size = len(valid['src'])
            feed_dict = {
                self.inputs_val['src']: self.F[valid['src'], :],
                self.inputs_val['end']: self.G[valid['end'], :],
            }
            # valid_dist = self.sess.run(self.dot_dist,feed_dict)
            valid_dist = self.sess.run(self.hamming_dist, feed_dict)
            mrr = 0.0
            for i in range(valid_size):
                fst_dist = valid_dist[i][0]
                pos = 1
                for k in range(1, len(valid_dist[i])):
                    if fst_dist >= valid_dist[i][k]:
                        pos += 1
                mrr += 1. / pos
            self.logger.info('Epoch={}, mean loss={!s}, mrr={}'.format(
                self.cur_epoch, sum_loss / (batch_id + 1e-8),
                mrr / (valid_size + 1e-8)))
        else:
            self.logger.info('Epoch={}, mean loss={!s}'.format(
                self.cur_epoch, sum_loss / (batch_id + 1e-8)))

        self.cur_epoch += 1

        return sum_loss / (batch_id + 1e-8), mrr / (valid_size + 1e-8)
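
All four variants feed TF1 placeholders and read back a per-anchor distance matrix (`dot_dist` or `hamming_dist`). A minimal sketch of graph pieces consistent with those reads, assuming `valid_iter` stacks candidates per anchor with the true pair first; shapes, names, and the binarization step are illustrative, not the project's actual graph:

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

dim = 128     # embedding width (illustrative)
n_cand = 10   # candidates per anchor, true pair in slot 0 (illustrative)

valid_f = tf.placeholder(tf.float32, [None, dim])          # anchor embeddings
valid_g = tf.placeholder(tf.float32, [None, n_cand, dim])  # candidates per anchor

# Negated dot product as a distance, so smaller means closer; this is the
# convention the rank-counting loops rely on.
dot_dist = -tf.reduce_sum(tf.expand_dims(valid_f, 1) * valid_g, axis=2)

# A Hamming-style distance over sign-binarized codes, one plausible reading
# of the hamming_dist branch used by variants #3 and #4.
codes_f = tf.sign(tf.expand_dims(valid_f, 1))
codes_g = tf.sign(valid_g)
hamming_dist = tf.reduce_sum(
    tf.cast(tf.not_equal(codes_f, codes_g), tf.float32), axis=2)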
Example #4
    def train_one_epoch(self):
        sum_loss = 0.0

        # train process
        batches_f2g = list(batch_iter(self.L, self.batch_size, self.neg_ratio,
                                      self.lookup_f, self.lookup_g, 'f', 'g'))
        batches_g2f = list(batch_iter(self.L, self.batch_size, self.neg_ratio,
                                      self.lookup_g, self.lookup_f, 'g', 'f'))
        n_batches = min(len(batches_f2g), len(batches_g2f))
        batch_id = 0
        for i in range(n_batches):
            # training the process from network f to network g
            pos_src_f2g, pos_obj_f2g, neg_src_f2g, neg_obj_f2g = batches_f2g[i]
            if len(pos_src_f2g) != len(pos_obj_f2g) or len(
                    neg_src_f2g) != len(neg_obj_f2g):
                self.logger.info(
                    'The input label file is malformed (mismatched pair lengths).')
                continue
            batch_size_f2g = len(pos_src_f2g)
            feed_dict = {
                self.pos_src_inputs: self.F[pos_src_f2g, :],
                self.pos_obj_inputs: self.G[pos_obj_f2g, :],
                self.neg_src_inputs: self.F[neg_src_f2g, :],
                self.neg_obj_inputs: self.G[neg_obj_f2g, :],
                self.cur_batch_size: batch_size_f2g
            }
            _, cur_loss_f2g = self.sess.run([self.train_op_f2g, self.loss_f2g],
                                            feed_dict)

            sum_loss += cur_loss_f2g

            # training the process from network g to network f
            pos_src_g2f, pos_obj_g2f, neg_src_g2f, neg_obj_g2f = batches_g2f[i]
            if len(pos_src_g2f) != len(pos_obj_g2f) or len(
                    neg_src_g2f) != len(neg_obj_g2f):
                self.logger.info(
                    'The input label file is malformed (mismatched pair lengths).')
                continue
            batch_size_g2f = len(pos_src_g2f)
            feed_dict = {
                self.pos_src_inputs: self.G[pos_src_g2f, :],
                self.pos_obj_inputs: self.F[pos_obj_g2f, :],
                self.neg_src_inputs: self.G[neg_src_g2f, :],
                self.neg_obj_inputs: self.F[neg_obj_g2f, :],
                self.cur_batch_size: batch_size_g2f
            }
            _, cur_loss_g2f = self.sess.run([self.train_op_g2f, self.loss_g2f],
                                            feed_dict)

            sum_loss += cur_loss_g2f

            batch_id += 1

        # valid process
        valid_f, valid_g = valid_iter(self.L, self.valid_sample_size,
                                      self.lookup_f, self.lookup_g, 'f', 'g')
        if len(valid_f) != len(valid_g):
            self.logger.info(
                'The input label file is malformed (mismatched pair lengths).')
            return
        valid_size = len(valid_f)
        feed_dict = {
            self.valid_f_inputs: self.F[valid_f, :],
            self.valid_g_inputs: self.G[valid_g, :],
        }
        # valid_dist = self.sess.run(self.dot_dist,feed_dict)
        valid_dist = self.sess.run(self.hamming_dist, feed_dict)
        mrr = 0.0
        for i in range(valid_size):
            fst_dist = valid_dist[i][0]
            pos = 1
            for k in range(1, len(valid_dist[i])):
                if fst_dist >= valid_dist[i][k]:
                    pos += 1
            mrr += 1. / pos
        self.logger.info('Epoch={}, mean loss={!s}, mrr={}'.format(
            self.cur_epoch, sum_loss / (batch_id + 1e-8) / 2,
            mrr / (valid_size + 1e-8)))
        self.cur_epoch += 1
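
A typical way to drive any of these variants is a loop with early stopping on validation MRR. The sketch below is illustrative, not part of the project, and assumes a model whose train_one_epoch returns (mean_loss, mrr) as in variants #1 and #3:

def train(model, n_epochs=100, patience=10):
    # Stop when validation MRR has not improved for `patience` epochs.
    best_mrr, stale = 0.0, 0
    for _ in range(n_epochs):
        result = model.train_one_epoch()
        if result is None:   # malformed validation batch; skip this epoch
            continue
        _, mrr = result
        if mrr > best_mrr:
            best_mrr, stale = mrr, 0
        else:
            stale += 1
            if stale >= patience:
                break
    return best_mrr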