Example 1
An eager-mode (TF2) elastic training loop: a step-based schedule resizes the
cluster mid-training, and workers re-sync counters and model state after
every resize.
import tensorflow as tf

from kungfu.python import current_cluster_size, detached

# build_dataset, build_ops, training_step, sync_offsets, sync_model,
# resize_cluster and MNIST_DATA_SIZE are defined elsewhere in this example.

def train(args):
    # step -> new cluster size
    step_based_schedule = {
        100: 2,
        200: 3,
        300: 4,
        400: 2,
        500: 3,
        600: 1,
    }
    ds = build_dataset(args)
    model, loss, opt = build_ops(args)
    need_sync = True  # newly (re)started workers must first sync state
    total_samples = int(MNIST_DATA_SIZE * args.num_epochs)
    trained_samples = tf.Variable(0)
    global_step = tf.Variable(0)
    for local_step, (images, labels) in enumerate(ds):
        global_step.assign_add(1)
        # each worker consumes batch_size samples per step, so the whole
        # cluster advances by cluster_size * batch_size samples
        trained_samples.assign_add(current_cluster_size() * args.batch_size)
        loss_value = training_step(model, loss, opt, images, labels)
        if need_sync:
            # after a (re)start or a resize, align the progress counters
            # and the model/optimizer state across all workers
            sync_offsets([global_step, trained_samples])
            sync_model(model, opt)
            need_sync = False
        step = int(global_step)
        print('step: %d loss: %f' % (step, loss_value))
        if step in step_based_schedule:
            new_size = step_based_schedule[step]
            # resize_cluster returns True iff the membership changed,
            # in which case state must be re-synced on the next step
            need_sync = resize_cluster(new_size)
            if detached():
                # this worker was removed by the resize
                break

        if trained_samples >= total_samples:
            # all epochs' worth of samples have been processed
            break
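
The sync helpers above are defined elsewhere in the example. A minimal
sketch of what they could look like, assuming KungFu's TF2 all_reduce and
broadcast ops (everything else here is hypothetical):

from kungfu.tensorflow.ops import all_reduce, broadcast

def sync_offsets(variables):
    # take the max across peers, so new workers jump forward to the
    # counters of the furthest-progressed worker
    for v in variables:
        v.assign(all_reduce(v, op='max'))

def sync_model(model, opt):
    # broadcast weights and optimizer slots from one peer so every
    # worker resumes from identical state after a resize
    for v in model.variables + opt.variables():
        v.assign(broadcast(v))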
Example 2
The graph-mode (TF1) version: after each resize the workers adopt the
maximum global step via an all_reduce, and the new cluster size is fed to
a resize op from a step-based schedule.
import tensorflow as tf

from kungfu.python import detached
# assumed import path for the KungFu TensorFlow ops used below
from kungfu.tensorflow.ops import all_reduce, resize

# parse_args and build_fake_train_op are defined elsewhere in this example.

def main():
    # step -> new_size
    fake_schedule = {
        10: 2,
        20: 3,
        40: 4,
        50: 1,
    }
    args = parse_args()
    gs = tf.train.get_or_create_global_step()
    # adopt the max global step across peers so (re)joining workers
    # catch up to the furthest-progressed one
    sync_step_op = tf.assign(gs, all_reduce(gs, op='max'))
    inc_gs = tf.assign_add(gs, 1)
    new_size = tf.placeholder(dtype=tf.uint32)
    # running resize_op returns True iff the cluster membership changed
    resize_op = resize(new_size)
    train_op = build_fake_train_op(args.use_nccl)
    init = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init)
        need_sync = True  # always sync the step on (re)start
        while True:
            if need_sync:
                sess.run(sync_step_op)
                need_sync = False

            step = sess.run(gs)

            # BEGIN train
            vs = sess.run(train_op)
            print('step %d, result: %d' % (step, vs[0].sum()))
            # END train

            if step in fake_schedule:
                changed = sess.run(resize_op,
                                   feed_dict={new_size: fake_schedule[step]})
                if changed:
                    need_sync = True
                    if detached():
                        # this worker was removed by the resize
                        break
                else:
                    print('cluster not changed')

            next_gs = sess.run(inc_gs)
            print('finished %s' % (next_gs - 1))
            if next_gs >= args.max_step:
                break

    print('stopped')
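
build_fake_train_op is defined elsewhere. A minimal sketch under the
assumption that it simply all-reduces a dummy tensor (the use_nccl flag,
which would hypothetically select an NCCL-backed op, is ignored here):

import tensorflow as tf
from kungfu.tensorflow.ops import all_reduce

def build_fake_train_op(use_nccl):
    # one dummy all-reduce stands in for a real training step;
    # sess.run(train_op) then yields a list of numpy arrays,
    # matching the vs[0].sum() above
    x = tf.ones([10], dtype=tf.int32)
    return [all_reduce(x)]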
Example 3
The after_run method of a session-run hook that tracks samples trained
across an elastic cluster and stops the session when training is done or
the worker has been detached.
    def after_run(self, run_context, run_values):
        # current_cluster_size and detached come from kungfu.python; the
        # _policies, _trained_samples etc. attributes are set up elsewhere
        # in the hook
        sess = run_context.session
        bs = self.get_batch_size(sess)
        trained_samples = sess.run(self._trained_samples)
        # the cluster processes batch_size samples per worker per step
        trained_samples += bs * current_cluster_size()
        self._set_trained_samples(sess, trained_samples)
        self._trained_epochs = int(trained_samples / self._epoch_size)

        for policy in reversed(self._policies):
            policy.after_step(sess)

        if self._trained_epochs > self._last_trained_epochs:
            for policy in reversed(self._policies):
                policy.after_epoch(sess)
            # remember the last epoch boundary so after_epoch fires once
            self._last_trained_epochs = self._trained_epochs

        if trained_samples >= self._total_samples:
            # every requested sample has been processed
            run_context.request_stop()

        if detached():
            # this worker was removed from the cluster by a resize
            run_context.request_stop()
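
Such a hook would be attached to a monitored session; a hypothetical usage
sketch (ElasticHook is an assumed name for the surrounding class, and its
constructor arguments are elided):

import tensorflow as tf

hook = ElasticHook(...)  # assumed name and constructor
with tf.train.MonitoredTrainingSession(hooks=[hook]) as sess:
    while not sess.should_stop():
        # request_stop() in after_run ends this loop
        sess.run(train_op)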
Example 4
A minimal test of kungfu.python.detached.
def test_detached():
    from kungfu.python import detached
    # a worker that has not been removed by a resize is still attached
    assert not detached()
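
detached() only starts returning True after a resize removes the calling
worker from the cluster, which is why the training loops above check it
immediately after resize_cluster/resize and break out so the removed
worker can exit cleanly.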