import logging
import os
import time

import tensorflow as tf

# Note: XVector, DeepSpeaker and the `ops` helper module are project-local;
# their exact import paths are not shown in this excerpt.


def multi_gpu(config, train_data):
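    """Train the x-vector model with synchronous data-parallel towers, one per GPU."""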
    logger = logging.getLogger(config.model_name + '_train')
    con = tf.ConfigProto(allow_soft_placement=True)
    con.gpu_options.allow_growth = True
    with tf.Session(config=con) as sess:
        with tf.device('/cpu:0'):
            learning_rate = config.lr
            opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
            logger.info('build model...')
            logger.info('build model on gpu tower...')
            tower_y, tower_losses, tower_grads, tower_output = [], [], [], []
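            # Build one model replica ("tower") per GPU; AUTO_REUSE makes every
            # tower share the same set of variables.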
            for gpu_id in range(config.n_gpu):
                with tf.device('/gpu:%d' % gpu_id):
                    logger.info('GPU:%d...' % gpu_id)
                    with tf.name_scope('tower_%d' % gpu_id):
                        with tf.variable_scope('gpu_variables', reuse=tf.AUTO_REUSE):
                            x, y = train_data.get_next()
                            x = tf.sparse_tensor_to_dense(x)
                            x = tf.reshape(x, [config.num_utt_per_class*config.num_classes_per_batch, -1, config.feature_dims])
                            y = tf.one_hot(y, depth=config.n_speaker)
                            y = tf.cast(y, dtype=tf.float32)
                            model = XVector(config)
                            output = model.inference(x)
                            tower_output.append(output)
                            loss = model.loss(y, output)
                            tower_losses.append(loss)
                            grads = opt.compute_gradients(loss)
                            tower_grads.append(grads)
                        logger.info('build model on gpu tower done.')
            logger.info('reduce model on cpu...')
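            # Average the per-tower losses and gradients on the CPU and apply a
            # single synchronized update to the shared variables.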
            aver_loss_op = tf.reduce_mean(tower_losses)
            apply_gradient_op = opt.apply_gradients(ops.average_gradients(tower_grads))
            tf.summary.scalar('loss', aver_loss_op)
            #all_output = tf.reshape(tf.stack(tower_output, 0), [-1, 512])
            logger.info('reduce model on cpu done.')
            logger.info('run train op...')
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            summary_op = tf.summary.merge_all()
            summary_writer = tf.summary.FileWriter('log', sess.graph)
            for epoch in range(config.max_step):
                start_time = time.time()
                avg_loss, log_flag = 0.0, 0
                logger.info('Epoch:%d, lr:%.4f, total_batch=%d' % (epoch, config.lr, config.batch_nums_per_epoch))
                for batch_idx in range(config.batch_nums_per_epoch):
                    _, _loss, summary_str = sess.run([apply_gradient_op, aver_loss_op,  summary_op])
                    avg_loss += _loss
                    log_flag += 1
                    if log_flag % 100 == 0 and log_flag != 0:
                        log_flag = 0
                        duration = time.time() - start_time
                        start_time = time.time()
                        logger.info('At %d batch, present batch loss is %.4f, %.2f batches/sec'%(batch_idx, _loss, 100.0*config.n_gpu/duration))
                    summary_writer.add_summary(summary_str, epoch*config.batch_nums_per_epoch+batch_idx)
                avg_loss /= config.batch_nums_per_epoch
                logger.info('Train average loss:%.4f' % (avg_loss))
                abs_save_path = os.path.abspath(os.path.join(config.save_path, config.model_name + ".ckpt"))
                saver.save(sess=sess, save_path=abs_save_path)
            logger.info('training done.')
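# `ops.average_gradients` is referenced above but not shown in this excerpt.
# A minimal sketch, assuming it follows the standard TensorFlow multi-tower
# pattern of averaging each variable's gradient across towers (dense gradients
# only):
def average_gradients(tower_grads):
    """tower_grads: list (one entry per tower) of lists of (grad, var) pairs."""
    averaged = []
    for grad_and_vars in zip(*tower_grads):
        # All entries refer to the same shared variable, one gradient per tower.
        grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
        mean_grad = tf.reduce_mean(tf.concat(grads, axis=0), axis=0)
        averaged.append((mean_grad, grad_and_vars[0][1]))
    return averaged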
Example #2
def multi_gpu(config, train_data, test=None, enroll=None):
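    """Same multi-GPU training loop for the DeepSpeaker model, collecting per-tower tensors in graph collections."""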
    tf.reset_default_graph()
    logger = logging.getLogger(config.model_name)
    con = tf.ConfigProto(allow_soft_placement=True)
    con.gpu_options.allow_growth = True
    with tf.Session(config=con) as sess:
        with tf.device('/cpu:0'):
            learning_rate = config.lr
            opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
            logger.info('build model...')
            logger.info('build model on gpu tower...')
            for gpu_id in range(config.n_gpu):
                with tf.device('/gpu:%d' % gpu_id):
                    logger.info('GPU:%d...' % gpu_id)
                    with tf.name_scope('tower_%d' % gpu_id):
                        with tf.variable_scope('gpu_variables', reuse=tf.AUTO_REUSE):
                            x, y = train_data.get_next()
                            model = DeepSpeaker(config, out_channel=[64, 128, 256, 512])
                            output = model.inference(x)
                            loss = model.loss(output, y)
                            grads = opt.compute_gradients(loss)
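                            # Stash the per-tower tensors in named graph collections
                            # ('tower_y', 'tower_losses', ...); they are read back
                            # below via tf.get_collection.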
                            ops.tower_to_collection(tower_y=y, tower_losses=loss, tower_grads=grads, tower_output=output)
                        logger.info('build model on gpu tower done.')
            logger.info('reduce model on cpu...')
            aver_loss_op = tf.reduce_mean(tf.get_collection('tower_losses'))
            apply_gradient_op = opt.apply_gradients(ops.average_gradients(tf.get_collection('tower_grads')))
            all_y = tf.reshape(tf.stack(tf.get_collection('tower_y'), 0), [-1, 1])
            all_output = tf.reshape(tf.stack(tf.get_collection('tower_output'), 0), [-1, 400])
            vectors = dict()
            logger.info('reduce model on cpu done.')
            logger.info('run train op...')
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            
            for epoch in range(config.max_step):
                start_time = time.time()
                avg_loss, log_flag = 0.0, 0
                logger.info('Epoch:%d, lr:%.4f, total_batch=%d' % (epoch, config.lr, config.batch_nums_per_epoch))
                
                for batch_idx in range(config.batch_nums_per_epoch):
                    _, _loss, batch_out = sess.run([apply_gradient_op, aver_loss_op, all_output])
                    avg_loss += _loss
                    log_flag += 1
                    if log_flag % 100 == 0 and log_flag != 0:
                        log_flag = 0
                        duration = time.time() - start_time
                        start_time = time.time()
                        logger.info('At %d batch, present batch loss is %.4f, %.2f batches/sec'%(batch_idx, _loss, 100.0/duration))

                avg_loss /= config.batch_nums_per_epoch
                logger.info('Train average loss:%.4f' % (avg_loss))

                abs_save_path = os.path.abspath(os.path.join(config.save_path, config.model_name + ".ckpt"))
                saver.save(sess=sess, save_path=abs_save_path)

            logger.info('training done.')
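# `ops.tower_to_collection` is also not shown in this excerpt. A minimal
# sketch, assuming it simply files each per-tower tensor into the graph
# collection of the same name as the keyword argument:
def tower_to_collection(**tower_tensors):
    for collection_name, tensor in tower_tensors.items():
        tf.add_to_collection(collection_name, tensor)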
    def train(self, train_data, valid=None):
        """Interface to train model.

        :param train_data: `tf.data.dataset`
        :param valid: dict, defaults to None. contain enroll and test data,
                            like {'t_x:0': [...], 'e_x:0': [...], 'e_y:0': ...}
        """
        logger = logging.getLogger('train')
        tf_config = tf.ConfigProto(allow_soft_placement=True)
        tf_config.gpu_options.allow_growth = True
        sess = tf.Session(config=tf_config)
        opt = tf.train.AdamOptimizer(learning_rate=self.config.lr)
        logger.info('Build model on %s tower...' %
                    ('cpu' if self.config.n_gpu == 0 else 'gpu'))
        tower_y, tower_losses, tower_grads, tower_output = [], [], [], []
        for gpu_id in range(self.config.n_gpu):
            with tf.device('/gpu:%d' % gpu_id):
                x, y = train_data.get_next()
                x = tf.reshape(x, [-1, 251, 64])
                y = tf.reshape(y, [-1, 1])
                output, vector = self.inference(x)
                tower_output.append(output)
                losses = self.loss(output, y)
                tower_losses.append(losses)
                grads = ops.clip_grad(opt.compute_gradients(losses), 3.0)
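                # Apply a 100x smaller update to the loss-layer scale/bias
                # variables ('loss_w', 'loss_b'), in the spirit of the GE2E
                # training recipe.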
                grads = [(0.01 * i, j) if (j.name == 'loss/loss_b:0'
                                           or j.name == 'loss/loss_w:0') else
                         (i, j) for i, j in grads]
                tower_grads.append(grads)
        # handle batch loss
        aver_loss_op = tf.reduce_mean(tower_losses)
        apply_gradient_op = opt.apply_gradients(
            ops.average_gradients(tower_grads))
        tf.summary.scalar('loss', aver_loss_op)

        # init
        emb = self.init_validation()
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        summary_op = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(
            os.path.join(self.config.save_path, 'graph'), sess.graph)
        log_flag = 0

        for epoch in range(self.config.max_step):
            logger.info(
                'Epoch:%d, lr:%.4f, total_batch=%d' %
                (epoch, self.config.lr, self.config.batch_nums_per_epoch))
            avg_loss = 0.0
            start_time = time.time()
            for batch_idx in range(self.config.batch_nums_per_epoch):
                _, _loss, summary_str = sess.run(
                    [apply_gradient_op, aver_loss_op, summary_op])
                avg_loss += _loss
                log_flag += 1
                if log_flag % 100 == 0 and log_flag != 0:
                    duration = time.time() - start_time
                    start_time = time.time()
                    logger.info(
                        'At %d batch, present batch loss is %.4f, %.2f batches/sec'
                        %
                        (batch_idx, _loss, 100 * self.config.n_gpu / duration))
                if log_flag % 600 == 0 and log_flag != 0:
                    test_x, test_y = valid['t_x'], valid['t_y']
                    enroll_x, enroll_y = valid['e_x'], valid['e_y']
                    acc, _ = self._validation(emb,
                                              test_x,
                                              test_y,
                                              enroll_x,
                                              enroll_y,
                                              sess,
                                              step=epoch)
                    logger.info('At %d epoch after %d batch, acc is %.6f' %
                                (epoch, batch_idx, acc))
                summary_writer.add_summary(
                    summary_str,
                    epoch * self.config.batch_nums_per_epoch + batch_idx)
            avg_loss /= self.config.batch_nums_per_epoch
            logger.info('Train average loss:%.4f' % avg_loss)
            abs_save_path = os.path.abspath(
                os.path.join(self.config.save_path, 'model',
                             self.config.model_name + ".ckpt"))
            saver.save(sess=sess, save_path=abs_save_path)
        logger.info('training done.')
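# `ops.clip_grad` (used in the tower loop above) is likewise not shown here.
# A minimal sketch, assuming it clips each gradient by norm and drops variables
# that received no gradient:
def clip_grad(grads_and_vars, clip_norm):
    return [(tf.clip_by_norm(grad, clip_norm), var)
            for grad, var in grads_and_vars if grad is not None]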