Code Example #1
def parallel_train(training_dataset):
    import horovod.tensorflow as hvd

    hvd.init()  # Horovod

    ds = training_dataset.shuffle(buffer_size=4096)
    ds = ds.shard(num_shards=hvd.size(), index=hvd.rank())
    ds = ds.repeat(n_epoch)
    ds = ds.map(_map_fn, num_parallel_calls=4)
    ds = ds.batch(batch_size)
    ds = ds.prefetch(buffer_size=1)

    iterator = ds.make_one_shot_iterator()
    one_element = iterator.get_next()
    net, total_loss, log_tensors = make_model(*one_element,
                                              is_train=True,
                                              reuse=False)
    x_ = net.img  # net input
    last_conf = net.last_conf  # net output
    last_paf = net.last_paf  # net output
    confs_ = net.confs  # GT
    pafs_ = net.pafs  # GT
    mask = net.m1  # mask1, GT
    # net.m2 = m2                 # mask2, GT
    stage_losses = net.stage_losses
    l2_loss = net.l2_loss

    global_step = tf.Variable(1, trainable=False)
    # scaled_lr = lr_init * hvd.size()  # Horovod: scale the learning rate linearly
    scaled_lr = lr_init  # Linear scaling rule is not working in openpose training.
    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(scaled_lr, trainable=False)

    opt = tf.train.MomentumOptimizer(lr_v, 0.9)
    opt = hvd.DistributedOptimizer(opt)  # Horovod
    train_op = opt.minimize(total_loss, global_step=global_step)
    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=False)

    config.gpu_options.allow_growth = True  # Horovod
    config.gpu_options.visible_device_list = str(hvd.local_rank())  # Horovod

    # Add variable initializer.
    init = tf.global_variables_initializer()

    # Horovod: broadcast initial variable states from rank 0 to all other processes.
    # This is necessary to ensure consistent initialization of all workers when
    # training is started with random weights or restored from a checkpoint.
    bcast = hvd.broadcast_global_variables(0)  # Horovod

    # Horovod: adjust number of steps based on number of GPUs.
    global n_step, lr_decay_every_step
    n_step = n_step // hvd.size() + 1  # Horovod
    lr_decay_every_step = lr_decay_every_step // hvd.size() + 1  # Horovod

    # Start training
    with tf.Session(config=config) as sess:
        init.run()
        bcast.run()  # Horovod
        print('Worker{}: Initialized'.format(hvd.rank()))
        print(
            'Worker{}: Start - n_step: {} batch_size: {} lr_init: {} lr_decay_every_step: {}'
            .format(hvd.rank(), n_step, batch_size, lr_init,
                    lr_decay_every_step))

        # restore pre-trained weights
        try:
            # tl.files.load_and_assign_npz(sess, os.path.join(model_path, 'pose.npz'), net)
            tl.files.load_and_assign_npz_dict(sess=sess,
                                              name=os.path.join(
                                                  model_path, 'pose.npz'))
        except Exception:
            print("no pre-trained model")

        # train until the end
        while True:
            step = sess.run(global_step)
            if step == n_step:
                break

            tic = time.time()
            if step != 0 and (step % lr_decay_every_step == 0):
                new_lr_decay = lr_decay_factor**(step // lr_decay_every_step)
                sess.run(tf.assign(lr_v, scaled_lr * new_lr_decay))

            [_, _loss, _stage_losses, _l2, conf_result, paf_result] = \
                sess.run([train_op, total_loss, stage_losses, l2_loss, last_conf, last_paf])

            # tstring = time.strftime('%d-%m %H:%M:%S', time.localtime(time.time()))
            lr = sess.run(lr_v)
            print(
                'Worker{}: Total Loss at iteration {} / {} is: {} Learning rate {:10e} l2_loss {:10e} Took: {}s'
                .format(hvd.rank(), step, n_step, _loss, lr, _l2,
                        time.time() - tic))
            for ix, ll in enumerate(_stage_losses):
                print('Worker{}:'.format(hvd.rank()), 'Network#', ix,
                      'For Branch', ix % 2 + 1, 'Loss:', ll)

            # save intermediate results and model
            if hvd.rank() == 0:  # Horovod
                if (step != 0) and (step % save_interval == 0):
                    # save some results
                    [
                        img_out, confs_ground, pafs_ground, conf_result,
                        paf_result, mask_out
                    ] = sess.run(
                        [x_, confs_, pafs_, last_conf, last_paf, mask])
                    draw_results(img_out, confs_ground, conf_result,
                                 pafs_ground, paf_result, mask_out,
                                 'train_%d_' % step)

                    # save model
                    # tl.files.save_npz(
                    #    net.all_params, os.path.join(model_path, 'pose' + str(step) + '.npz'), sess=sess)
                    # tl.files.save_npz(net.all_params, os.path.join(model_path, 'pose.npz'), sess=sess)
                    tl.files.save_npz_dict(net.all_params,
                                           os.path.join(
                                               model_path,
                                               'pose' + str(step) + '.npz'),
                                           sess=sess)
                    tl.files.save_npz_dict(net.all_params,
                                           os.path.join(
                                               model_path, 'pose.npz'),
                                           sess=sess)
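The Horovod variant above expects one training process per GPU. A minimal launch sketch, assuming a hypothetical train.py entry point in which make_training_dataset stands in for the project's dataset builder (neither name appears in the snippet above):

# Hypothetical driver; make_training_dataset is a placeholder, not the
# project's API.
if __name__ == '__main__':
    training_dataset = make_training_dataset()
    parallel_train(training_dataset)

# One process per GPU, e.g.:
#   horovodrun -np 4 python train.py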
Code Example #2
def single_train(training_dataset):
    ds = training_dataset.shuffle(buffer_size=4096)  # shuffle before loading images
    ds = ds.repeat(n_epoch)
    ds = ds.map(_map_fn,
                num_parallel_calls=multiprocessing.cpu_count() // 2)  # decouple the heavy map_fn
    ds = ds.batch(batch_size)  # TODO: consider using tf.contrib.map_and_batch
    ds = ds.prefetch(2)
    iterator = ds.make_one_shot_iterator()
    one_element = iterator.get_next()
    net, total_loss, log_tensors = make_model(*one_element,
                                              is_train=True,
                                              reuse=False)
    x_ = net.img  # net input
    last_conf = net.last_conf  # net output
    last_paf = net.last_paf  # net output
    confs_ = net.confs  # GT
    pafs_ = net.pafs  # GT
    mask = net.m1  # mask1, GT
    # net.m2 = m2                 # mask2, GT
    stage_losses = net.stage_losses
    l2_loss = net.l2_loss

    global_step = tf.Variable(1, trainable=False)
    print(
        'Start - n_step: {} batch_size: {} lr_init: {} lr_decay_every_step: {}'
        .format(n_step, batch_size, lr_init, lr_decay_every_step))
    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)

    opt = tf.train.MomentumOptimizer(lr_v, 0.9)
    train_op = opt.minimize(total_loss, global_step=global_step)
    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=False)

    # start training
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        # restore pre-trained weights
        try:
            # tl.files.load_and_assign_npz(sess, os.path.join(model_path, 'pose.npz'), net)
            tl.files.load_and_assign_npz_dict(sess=sess,
                                              name=os.path.join(
                                                  model_path, 'pose.npz'))
        except Exception:
            print("no pre-trained model")

        # train until the end
        sess.run(tf.assign(lr_v, lr_init))
        while True:
            tic = time.time()
            step = sess.run(global_step)
            if step != 0 and (step % lr_decay_every_step == 0):
                new_lr_decay = lr_decay_factor**(step // lr_decay_every_step)
                sess.run(tf.assign(lr_v, lr_init * new_lr_decay))

            [_, _loss, _stage_losses, _l2, conf_result, paf_result] = \
                sess.run([train_op, total_loss, stage_losses, l2_loss, last_conf, last_paf])

            # tstring = time.strftime('%d-%m %H:%M:%S', time.localtime(time.time()))
            lr = sess.run(lr_v)
            print(
                'Total Loss at iteration {} / {} is: {} Learning rate {:10e} l2_loss {:10e} Took: {}s'
                .format(step, n_step, _loss, lr, _l2,
                        time.time() - tic))
            for ix, ll in enumerate(_stage_losses):
                print('Network#', ix, 'For Branch', ix % 2 + 1, 'Loss:', ll)

            # save intermediate results and model
            if (step != 0) and (step % save_interval == 0):
                # save some results
                [
                    img_out, confs_ground, pafs_ground, conf_result,
                    paf_result, mask_out
                ] = sess.run([x_, confs_, pafs_, last_conf, last_paf, mask])
                draw_results(img_out, confs_ground, conf_result, pafs_ground,
                             paf_result, mask_out, 'train_%d_' % step)
                # save model
                # tl.files.save_npz(
                #    net.all_params, os.path.join(model_path, 'pose' + str(step) + '.npz'), sess=sess)
                # tl.files.save_npz(net.all_params, os.path.join(model_path, 'pose.npz'), sess=sess)
                tl.files.save_npz_dict(net.all_params,
                                       os.path.join(
                                           model_path,
                                           'pose' + str(step) + '.npz'),
                                       sess=sess)
                tl.files.save_npz_dict(net.all_params,
                                       os.path.join(model_path, 'pose.npz'),
                                       sess=sess)
            if step == n_step:  # training finished
                break
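The TODO next to ds.batch refers to the fused map-and-batch transform in TF 1.x. A minimal sketch of that alternative pipeline, assuming the same _map_fn and batch_size globals; depending on the TF 1.x minor version the transform lives under tf.contrib.data or tf.data.experimental:

# Sketch only: fuses the map and batch stages of the pipeline above.
ds = training_dataset.shuffle(buffer_size=4096)
ds = ds.repeat(n_epoch)
ds = ds.apply(
    tf.data.experimental.map_and_batch(
        _map_fn,
        batch_size,
        num_parallel_calls=multiprocessing.cpu_count() // 2))
ds = ds.prefetch(2)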
Code Example #3
File: train_parallel.py Project: Aki57/openpose-plus
def parallel_train(training_dataset):
    global_step = tf.Variable(1, trainable=False)
    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)
    opt = tf.train.MomentumOptimizer(lr_v, 0.9)

    tower_grads = []
    for i, gpu_id in enumerate(range(node_num)):  # one tower per local GPU
        with tf.device('/gpu:%d' % gpu_id):
            with tf.name_scope('model_%d' % gpu_id) as scope:
                ds = training_dataset.shard(num_shards=node_num, index=i)
                ds = ds.repeat(n_epoch)
                ds = ds.shuffle(buffer_size=4096)
                ds = ds.map(_map_fn, num_parallel_calls=4)
                ds = ds.batch(batch_size)
                ds = ds.prefetch(buffer_size=1)

                iterator = ds.make_one_shot_iterator()
                one_element = iterator.get_next()
                net, total_loss, log_tensors = make_model(*one_element, is_train=True, reuse=False)
                x_ = net.img  # net input
                last_conf = net.last_conf  # net output
                last_paf = net.last_paf  # net output
                confs_ = net.confs  # GT
                pafs_ = net.pafs  # GT
                mask = net.m1  # mask1, GT
                # net.m2 = m2                 # mask2, GT
                stage_losses = net.stage_losses
                l2_loss = net.l2_loss

                grads = opt.compute_gradients(total_loss)
                tower_grads.append(grads)

    def average_gradients(tower_grads):
        average_grads = []
        for grad_and_vars in zip(*tower_grads):
            grads = []
            for g, _ in grad_and_vars:
                expanded_g = tf.expand_dims(g, 0)
                grads.append(expanded_g)

            grad = tf.concat(grads, 0)
            grad = tf.reduce_mean(grad, 0)

            v = grad_and_vars[0][1]
            grad_and_var = (grad, v)
            average_grads.append(grad_and_var)

        return average_grads

    grads = average_gradients(tower_grads)
    apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

    summary_op = tf.summary.merge_all()

    train_op = apply_gradient_op  # apply the averaged tower gradients
    config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)

    config.gpu_options.per_process_gpu_memory_fraction = 0.5
    config.gpu_options.allow_growth = True
    config.gpu_options.visible_device_list = ",".join(str(i) for i in range(node_num))  # expose all tower GPUs

    # Add variable initializer.
    init = tf.global_variables_initializer()

    # Start training
    with tf.Session(config=config) as sess:
        init.run()
        print('Worker{}: Initialized'.format(hvd.rank()))
        print('Worker{}: Start - n_step: {} batch_size: {} lr_init: {} lr_decay_interval: {}'.format(
            hvd.rank(), n_step, batch_size, lr_init, lr_decay_interval))

        # restore pre-trained weights
        try:
            # tl.files.load_and_assign_npz(sess, os.path.join(model_path, 'pose.npz'), net)
            tl.files.load_and_assign_npz_dict(sess=sess, name=os.path.join(model_path, 'pose.npz'))
        except Exception:
            print("no pre-trained model")

        # train until the end
        while True:
            step = sess.run(global_step)
            if step == n_step:
                break

            tic = time.time()
            if step != 0 and (step % lr_decay_interval == 0):
                new_lr_decay = lr_decay_factor**(step // lr_decay_interval)
                sess.run(tf.assign(lr_v, lr_init * new_lr_decay))

            [_, _loss, _stage_losses, _l2, conf_result, paf_result] = \
                sess.run([train_op, total_loss, stage_losses, l2_loss, last_conf, last_paf])

            # tstring = time.strftime('%d-%m %H:%M:%S', time.localtime(time.time()))
            lr = sess.run(lr_v)
            print(
                'Worker{}: Total Loss at iteration {} / {} is: {} Learning rate {:10e} l2_loss {:10e} Took: {}s'.format(
                    hvd.rank(), step, n_step, _loss, lr, _l2,
                    time.time() - tic))
            for ix, ll in enumerate(_stage_losses):
                print('Worker{}:'.format(hvd.rank()), 'Network#', ix, 'For Branch', ix % 2 + 1, 'Loss:', ll)

            # save intermediate results and model
            if hvd.rank() == 0:  # Horovod
                if (step != 0) and (step % save_interval == 0):
                    # save some results
                    [img_out, confs_ground, pafs_ground, conf_result, paf_result,
                     mask_out] = sess.run([x_, confs_, pafs_, last_conf, last_paf, mask])
                    draw_results(img_out, confs_ground, conf_result, pafs_ground, paf_result, mask_out,
                                 'train_%d_' % step)

                    # save model
                    tl.files.save_npz_dict(
                        net.all_params, os.path.join(model_path, 'pose' + str(step) + '.npz'), sess=sess)
                    tl.files.save_npz_dict(net.all_params, os.path.join(model_path, 'pose.npz'), sess=sess)
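All of these snippets rely on module-level globals (node_num, n_epoch, batch_size, n_step, lr_init, lr_decay_factor, lr_decay_interval, save_interval, model_path) and helpers (_map_fn, make_model, draw_results) defined elsewhere in the project. A minimal, hypothetical driver for the tower variant above, with placeholder values only:

# Hypothetical driver; all values are illustrative placeholders, not the
# project's defaults.
node_num = 2                    # number of local GPUs to build towers on
n_epoch, batch_size = 80, 8
n_step, save_interval = 100000, 5000
lr_init, lr_decay_factor, lr_decay_interval = 4e-5, 0.333, 10000
model_path = 'models'

if __name__ == '__main__':
    training_dataset = make_training_dataset()  # placeholder builder
    parallel_train(training_dataset)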
Code Example #4
def parallel_train(training_dataset, kungfu_option):
    from kungfu import current_cluster_size, current_rank
    from kungfu.tensorflow.optimizers import SynchronousSGDOptimizer, SynchronousAveragingOptimizer, PairAveragingOptimizer

    ds = training_dataset.shuffle(buffer_size=4096)
    ds = ds.shard(num_shards=current_cluster_size(), index=current_rank())
    ds = ds.repeat(n_epoch)
    ds = ds.map(_map_fn, num_parallel_calls=4)
    ds = ds.batch(batch_size)
    ds = ds.prefetch(buffer_size=1)

    iterator = ds.make_one_shot_iterator()
    one_element = iterator.get_next()
    net, total_loss, log_tensors = make_model(*one_element,
                                              is_train=True,
                                              reuse=False)
    x_ = net.img  # net input
    last_conf = net.last_conf  # net output
    last_paf = net.last_paf  # net output
    confs_ = net.confs  # GT
    pafs_ = net.pafs  # GT
    mask = net.m1  # mask1, GT
    # net.m2 = m2                 # mask2, GT
    stage_losses = net.stage_losses
    l2_loss = net.l2_loss

    global_step = tf.Variable(1, trainable=False)
    # scaled_lr = lr_init * current_cluster_size()  # Horovod: scale the learning rate linearly
    scaled_lr = lr_init  # Linear scaling rule is not working in openpose training.
    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(scaled_lr, trainable=False)

    opt = tf.train.MomentumOptimizer(lr_v, 0.9)

    # KungFu
    if kungfu_option == 'sync-sgd':
        opt = SynchronousSGDOptimizer(opt)
    elif kungfu_option == 'async-sgd':
        opt = PairAveragingOptimizer(opt)
    elif kungfu_option == 'sma':
        opt = SynchronousAveragingOptimizer(opt)
    else:
        raise RuntimeError('Unknown distributed training optimizer.')

    train_op = opt.minimize(total_loss, global_step=global_step)
    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=False)
    config.gpu_options.allow_growth = True

    # Add variable initializer.
    init = tf.global_variables_initializer()

    # KungFu
    from kungfu.tensorflow.initializer import BroadcastGlobalVariablesOp
    bcast = BroadcastGlobalVariablesOp()

    global n_step, lr_decay_every_step
    n_step = n_step // current_cluster_size() + 1  # KungFu
    lr_decay_every_step = lr_decay_every_step // current_cluster_size() + 1  # KungFu

    # Start training
    with tf.Session(config=config) as sess:
        init.run()
        bcast.run()  # KungFu
        print('Worker{}: Initialized'.format(current_rank()))
        print(
            'Worker{}: Start - n_step: {} batch_size: {} lr_init: {} lr_decay_every_step: {}'
            .format(current_rank(), n_step, batch_size, lr_init,
                    lr_decay_every_step))

        # restore pre-trained weights
        try:
            # tl.files.load_and_assign_npz(sess, os.path.join(model_path, 'pose.npz'), net)
            tl.files.load_and_assign_npz_dict(sess=sess,
                                              name=os.path.join(
                                                  model_path, 'pose.npz'))
        except Exception:
            print("no pre-trained model")

        # train until the end
        while True:
            step = sess.run(global_step)
            if step == n_step:
                break

            tic = time.time()
            if step != 0 and (step % lr_decay_every_step == 0):
                new_lr_decay = lr_decay_factor**(step // lr_decay_every_step)
                sess.run(tf.assign(lr_v, scaled_lr * new_lr_decay))

            [_, _loss, _stage_losses, _l2, conf_result, paf_result] = \
                sess.run([train_op, total_loss, stage_losses, l2_loss, last_conf, last_paf])

            # tstring = time.strftime('%d-%m %H:%M:%S', time.localtime(time.time()))
            lr = sess.run(lr_v)
            print(
                'Worker{}: Total Loss at iteration {} / {} is: {} Learning rate {:10e} l2_loss {:10e} Took: {}s'
                .format(current_rank(), step, n_step, _loss, lr, _l2,
                        time.time() - tic))
            for ix, ll in enumerate(_stage_losses):
                print('Worker{}:'.format(current_rank()), 'Network#', ix,
                      'For Branch', ix % 2 + 1, 'Loss:', ll)

            # save intermediate results and model
            if current_rank() == 0:  # KungFu
                if (step != 0) and (step % save_interval == 0):
                    # save some results
                    [
                        img_out, confs_ground, pafs_ground, conf_result,
                        paf_result, mask_out
                    ] = sess.run(
                        [x_, confs_, pafs_, last_conf, last_paf, mask])
                    draw_results(img_out, confs_ground, conf_result,
                                 pafs_ground, paf_result, mask_out,
                                 'train_%d_' % step)

                    # save model
                    # tl.files.save_npz(
                    #    net.all_params, os.path.join(model_path, 'pose' + str(step) + '.npz'), sess=sess)
                    # tl.files.save_npz(net.all_params, os.path.join(model_path, 'pose.npz'), sess=sess)
                    tl.files.save_npz_dict(net.all_params,
                                           os.path.join(
                                               model_path,
                                               'pose' + str(step) + '.npz'),
                                           sess=sess)
                    tl.files.save_npz_dict(net.all_params,
                                           os.path.join(
                                               model_path, 'pose.npz'),
                                           sess=sess)
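The KungFu variant chooses its distributed optimizer from the kungfu_option argument. A minimal launch sketch, assuming a hypothetical train.py entry point with a --kf-optimizer flag (the flag name, script, and make_training_dataset builder are placeholders, not the project's CLI):

# Hypothetical entry point; argument parsing is a sketch only.
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--kf-optimizer', default='sync-sgd',
                        choices=['sync-sgd', 'async-sgd', 'sma'])
    args = parser.parse_args()
    training_dataset = make_training_dataset()  # placeholder builder
    parallel_train(training_dataset, args.kf_optimizer)

# One process per GPU, e.g.:
#   kungfu-run -np 4 python3 train.py --kf-optimizer sync-sgd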
Code Example #5
def train(training_dataset, epoch, n_step):
    ds = training_dataset.shuffle(buffer_size=4096)  # shuffle before loading images
    ds = ds.map(_map_fn,
                num_parallel_calls=multiprocessing.cpu_count() // 2)  # decouple the heavy map_fn
    ds = ds.batch(batch_size)
    ds = ds.prefetch(2)
    iterator = ds.make_one_shot_iterator()
    one_element = iterator.get_next()
    base_net, base_loss = make_model(*one_element, is_train=True, reuse=False)
    x_2d_ = base_net.input  # base_net input
    last_conf = base_net.last_conf  # base_net output
    last_paf = base_net.last_paf  # base_net output
    confs_ = base_net.confs  # GT
    pafs_ = base_net.pafs  # GT
    stage_losses = base_net.stage_losses
    l2_loss = base_net.l2_loss

    new_lr_decay = lr_decay_factor**((epoch - 1) * n_step // lr_decay_interval)
    print(
        'Start - epoch: {} n_step: {} batch_size: {} lr_init: {} lr_decay_interval: {}'
        .format(epoch, n_step, batch_size, lr_init * new_lr_decay,
                lr_decay_interval))

    lr_v = tf.Variable(lr_init * new_lr_decay,
                       trainable=False,
                       name='learning_rate')
    global_step = tf.Variable(1, trainable=False)
    train_op = tf.train.MomentumOptimizer(lr_v, 0.9).minimize(
        base_loss, global_step=global_step)

    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=False)

    # start training
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        # restore pre-trained weights
        try:
            tl.files.load_and_assign_npz_dict(sess=sess,
                                              name=os.path.join(
                                                  model_path,
                                                  'openposenet.npz'))
        except Exception:
            print("no pre-trained model")

        # train until the end
        sess.run(tf.assign(lr_v, lr_init * new_lr_decay))  # keep the decayed rate for this epoch
        while True:
            tic = time.time()
            step = sess.run(global_step)
            if step != 0 and (((epoch - 1) * n_step + step) % lr_decay_interval
                              == 0):
                new_lr_decay = lr_decay_factor**((
                    (epoch - 1) * n_step + step) // lr_decay_interval)
                sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
                print('lr decay to {}'.format(lr_init * new_lr_decay))

            [_, _loss, _stage_losses,
             _l2] = sess.run([train_op, base_loss, stage_losses, l2_loss])

            # tstring = time.strftime('%d-%m %H:%M:%S', time.localtime(time.time()))
            lr = sess.run(lr_v)
            print(
                'Training Loss at iteration {} / {} is: {} Learning rate {:10e} l2_loss {:10e} Took: {}s'
                .format(step, n_step, _loss, lr, _l2,
                        time.time() - tic))
            for ix, ll in enumerate(_stage_losses):
                print('Network#', ix, 'For Branch', ix % 2 + 1, 'Loss:', ll)

            # save intermediate results and model
            if (step != 0) and (step % save_interval == 0):
                # save some results
                [img_out, confs_ground, pafs_ground, conf_result, paf_result
                 ] = sess.run([x_2d_, confs_, pafs_, last_conf, last_paf])
                draw_results(img_out[:, :, :, :3], confs_ground, conf_result,
                             pafs_ground, paf_result, None, 'train_%d_' % step)
                # save model
                tl.files.save_npz_dict(base_net.all_params,
                                       os.path.join(
                                           model_path,
                                           'openposenet-' + str(epoch) + '-' +
                                           str(step) + '.npz'),
                                       sess=sess)
                tl.files.save_npz_dict(base_net.all_params,
                                       os.path.join(model_path,
                                                    'openposenet.npz'),
                                       sess=sess)
            # training finished
            if step == n_step:
                tl.files.save_npz_dict(
                    base_net.all_params,
                    os.path.join(model_path,
                                 'openposenet-' + str(epoch) + '.npz'),
                    sess=sess)
                tl.files.save_npz_dict(base_net.all_params,
                                       os.path.join(model_path,
                                                    'openposenet.npz'),
                                       sess=sess)
                break
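The decay schedule above is stepwise: the learning rate is multiplied by lr_decay_factor once every lr_decay_interval global iterations, counted across epochs. A small worked example with assumed values (not the project's defaults):

# Assumed values for illustration only.
lr_init, lr_decay_factor, lr_decay_interval = 4e-5, 0.5, 10000
epoch, n_step, step = 3, 5000, 2500

global_iter = (epoch - 1) * n_step + step                            # 12500
new_lr_decay = lr_decay_factor**(global_iter // lr_decay_interval)   # 0.5**1
print(lr_init * new_lr_decay)                                        # 2e-05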
Code Example #6
File: train.py Project: kaplansinan/openpose-plus
def single_train(training_dataset):
    ds = training_dataset.shuffle(
        buffer_size=config.TRAIN.shuffel_buffer_size)  # shuffle before loading images
    ds = ds.repeat(n_epoch)
    ds = ds.map(_map_fn,
                num_parallel_calls=max(1, multiprocessing.cpu_count() - 1))  # decouple the heavy map_fn
    ds = ds.batch(config.TRAIN.batch_size)  # TODO: consider using tf.contrib.map_and_batch
    ds = ds.prefetch(buffer_size=3)
    iterator = ds.make_one_shot_iterator()
    one_element = iterator.get_next()
    net, total_loss, log_tensors = make_model(
        *one_element,
        is_train=True,
        train_bn=config.TRAIN.train_batch_norm,
        reuse=False)
    x_ = net.img  # net input
    last_conf = net.last_conf  # net output
    last_paf = net.last_paf  # net output
    confs_ = net.confs  # GT
    pafs_ = net.pafs  # GT
    mask = net.m1  # mask1, GT
    # net.m2 = m2                 # mask2, GT
    stage_losses = net.stage_losses
    l2_loss = net.l2_loss
    temp_tot_loss, temp_l2_loss = (float("inf"), float("inf"))

    global_step = tf.Variable(1, trainable=False)
    print(
        'Start - n_step: {} batch_size: {} lr_init: {} lr_decay_every_step: {}'
        .format(n_step, config.TRAIN.batch_size, lr_init, lr_decay_every_step))
    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)

    opt = tf.train.MomentumOptimizer(lr_v, 0.9)
    train_op = opt.minimize(total_loss, global_step=global_step)
    tfconfig = tf.ConfigProto(allow_soft_placement=True,
                              log_device_placement=False)

    # start training
    with tf.Session(config=tfconfig) as sess:
        sess.run(tf.global_variables_initializer())

        if tensorboard:

            if not os.path.exists('summaries'):
                os.mkdir('summaries')

            summ_writer = tf.summary.FileWriter(
                os.path.join('summaries',
                             'run' + str(datetime.datetime.now())), sess.graph)

            tf.summary.scalar('total_loss', total_loss)
            tf.summary.scalar('l2_loss', l2_loss)

            for ix, ll in enumerate(stage_losses):
                tf.summary.scalar('stage{}_loss'.format(ix), ll)
            merge = tf.summary.merge_all()

        # restore pre-trained weights
        try:
            # tl.files.load_and_assign_npz(sess, os.path.join(model_path, 'pose.npz'), net)
            tl.files.load_and_assign_npz_dict(sess=sess,
                                              name=os.path.join(
                                                  model_path,
                                                  config.MODEL.model_file))
        except Exception:
            print("no pre-trained model")

        if config.MODEL.initial_weights:
            # restore pre-trained mobilenet weights
            try:
                tl.files.load_and_assign_npz_dict(
                    sess=sess,
                    name=os.path.join(model_path,
                                      config.MODEL.initial_weights_file))
            except Exception:
                print("no pre-trained mobilenet model")

        # train until the end
        sess.run(tf.assign(lr_v, lr_init))
        while True:
            tic = time.time()
            step = sess.run(global_step)
            if step != 0 and (step % lr_decay_every_step == 0):
                new_lr_decay = lr_decay_factor**(step // lr_decay_every_step)
                sess.run(tf.assign(lr_v, lr_init * new_lr_decay))

            # TODO Test images
            #print("save test image")
            #[img_out, confs_ground, pafs_ground, conf_result, paf_result,
            #     mask_out] = sess.run([x_, confs_, pafs_, last_conf, last_paf, mask])
            #draw_results(img_out, confs_ground, conf_result, pafs_ground, paf_result, mask_out, 'train_test_%d_' % step)

            [_, _loss, _stage_losses, _l2, conf_result, paf_result] = \
                sess.run([train_op, total_loss, stage_losses, l2_loss, last_conf, last_paf])

            # tstring = time.strftime('%d-%m %H:%M:%S', time.localtime(time.time()))
            lr = sess.run(lr_v)
            print(
                'Total Loss at iteration {} / {} is: {} Learning rate {:10e} l2_loss {:10e} Took: {}s'
                .format(step, n_step, _loss, lr, _l2,
                        time.time() - tic))
            for ix, ll in enumerate(_stage_losses):
                print('Network#', ix, 'For Branch', ix % 2 + 1, 'Loss:', ll)

            if temp_tot_loss > _loss:
                temp_tot_loss = _loss
                # save some results
                [
                    img_out, confs_ground, pafs_ground, conf_result,
                    paf_result, mask_out
                ] = sess.run([x_, confs_, pafs_, last_conf, last_paf, mask])
                draw_results(img_out, confs_ground, conf_result, pafs_ground,
                             paf_result, mask_out,
                             'train_best_{}_{}_'.format(step, _loss))

                tl.files.save_npz_dict(
                    net.all_params,
                    os.path.join(model_path,
                                 'pose_best_{}_{}_'.format(step, _loss)),
                    sess=sess)

            if tensorboard:
                summ = sess.run(merge)
                summ_writer.add_summary(summ, step)

            # save intermediate results and model
            if (step != 0) and (step % save_interval == 0):
                # save some results
                [
                    img_out, confs_ground, pafs_ground, conf_result,
                    paf_result, mask_out
                ] = sess.run([x_, confs_, pafs_, last_conf, last_paf, mask])
                draw_results(img_out, confs_ground, conf_result, pafs_ground,
                             paf_result, mask_out, 'train_%d_' % step)

                # save model
                tl.files.save_npz_dict(net.all_params,
                                       os.path.join(
                                           model_path,
                                           'pose' + str(step) + '.npz'),
                                       sess=sess)
                tl.files.save_npz_dict(net.all_params,
                                       os.path.join(model_path,
                                                    config.MODEL.model_file),
                                       sess=sess)
            if step == n_step:  # training finished
                break
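This last variant reads its hyperparameters from a config object (config.TRAIN.*, config.MODEL.*) rather than from bare globals. A minimal, hypothetical stand-in for that object, covering only the attributes referenced above with placeholder values (the project's real config may differ):

# Hypothetical config stub; field names match the attributes used above,
# values are placeholders.
from easydict import EasyDict

config = EasyDict()
config.TRAIN = EasyDict(shuffel_buffer_size=4096, batch_size=8,
                        train_batch_norm=False)
config.MODEL = EasyDict(model_file='pose.npz', initial_weights=False,
                        initial_weights_file='mobilenet.npz')

When the tensorboard flag is enabled, scalar summaries are written under summaries/run<timestamp> and can be inspected with the standard TensorBoard CLI (tensorboard --logdir summaries).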