def train(args):
    # step -> new cluster size
    step_based_schedule = {
        100: 2,
        200: 3,
        300: 4,
        400: 2,
        500: 3,
        600: 1,
    }

    ds = build_dataset(args)
    model, loss, opt = build_ops(args)

    need_sync = True
    total_samples = int(MNIST_DATA_SIZE * args.num_epochs)
    trained_samples = tf.Variable(0)
    global_step = tf.Variable(0)

    for local_step, (images, labels) in enumerate(ds):
        global_step.assign_add(1)
        # every worker in the cluster consumes one batch per step
        trained_samples.assign_add(current_cluster_size() * args.batch_size)

        loss_value = training_step(model, loss, opt, images, labels)

        if need_sync:
            # align the progress counters and the model state across workers
            sync_offsets([global_step, trained_samples])
            sync_model(model, opt)
            need_sync = False

        step = int(global_step)
        print('step: %d loss: %f' % (step, loss_value))

        if step in step_based_schedule:
            new_size = step_based_schedule[step]
            # resize_cluster reports whether the cluster changed,
            # in which case the workers must re-sync
            need_sync = resize_cluster(new_size)
            if detached():
                # this worker was removed from the new cluster
                break

        if trained_samples >= total_samples:
            break
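
# The loop above assumes a training_step helper defined elsewhere in the
# example. A minimal sketch of what it might look like (hypothetical; the
# real example defines its own version) is one optimizer update under a
# gradient tape:
def training_step(model, loss, opt, images, labels):
    with tf.GradientTape() as tape:
        logits = model(images, training=True)
        loss_value = loss(labels, logits)
    grads = tape.gradient(loss_value, model.trainable_variables)
    opt.apply_gradients(zip(grads, model.trainable_variables))
    return loss_value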
def main():
    # step -> new_size
    fake_schedule = {
        10: 2,
        20: 3,
        40: 4,
        50: 1,
    }

    args = parse_args()

    gs = tf.train.get_or_create_global_step()
    # after a resize, adopt the largest global step in the cluster
    sync_step_op = tf.assign(gs, all_reduce(gs, op='max'))
    inc_gs = tf.assign_add(gs, 1)

    new_size = tf.placeholder(dtype=tf.uint32)
    resize_op = resize(new_size)

    train_op = build_fake_train_op(args.use_nccl)
    init = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init)
        need_sync = True
        while True:
            if need_sync:
                sess.run(sync_step_op)
                need_sync = False
            step = sess.run(gs)

            # BEGIN train
            vs = sess.run(train_op)
            print('step %d, result: %d' % (step, vs[0].sum()))
            # END train

            if step in fake_schedule:
                changed = sess.run(resize_op,
                                   feed_dict={new_size: fake_schedule[step]})
                if changed:
                    need_sync = True
                    if detached():
                        # this worker is not part of the new cluster
                        break
                else:
                    print('cluster not changed')
                # every entry in the schedule picks a different size,
                # so a no-op resize indicates a bug
                assert changed

            next_gs = sess.run(inc_gs)
            print('finished %s' % (next_gs - 1))
            if next_gs >= args.max_step:
                break
    print('stopped')
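
# build_fake_train_op is defined elsewhere in the benchmark. A minimal
# sketch, assuming it simply all-reduces a list of fake gradient tensors
# (the sizes here are hypothetical, and the real version would select an
# NCCL-based all-reduce when use_nccl is set), could look like this:
def build_fake_train_op(use_nccl):
    sizes = [1 << 10, 1 << 16]  # fake gradient sizes
    xs = [tf.Variable(tf.ones([n], dtype=tf.int32)) for n in sizes]
    return [all_reduce(x) for x in xs]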
def after_run(self, run_context, run_values):
    sess = run_context.session

    # account for the samples trained by the whole cluster in this step
    bs = self.get_batch_size(sess)
    trained_samples = sess.run(self._trained_samples)
    trained_samples += bs * current_cluster_size()
    self._set_trained_samples(sess, trained_samples)
    self._trained_epochs = int(trained_samples / self._epoch_size)

    for policy in reversed(self._policies):
        policy.after_step(sess)

    if self._trained_epochs > self._last_trained_epochs:
        for policy in reversed(self._policies):
            policy.after_epoch(sess)
        # record the epoch boundary so after_epoch fires once per epoch
        self._last_trained_epochs = self._trained_epochs

    if trained_samples >= self._total_samples:
        run_context.request_stop()

    if detached():
        # this worker has been removed from the cluster
        run_context.request_stop()
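
# The hook drives a list of policy objects. A minimal sketch of the
# interface they are assumed to implement (hypothetical; real policies
# would define more callbacks and carry their own state):
class BasePolicy:
    def after_step(self, sess):
        # called once per training step, e.g. to adapt the batch size
        pass

    def after_epoch(self, sess):
        # called once per trained epoch, e.g. to trigger a cluster resize
        pass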
def test_detached():
    from kungfu.python import detached
    # a freshly started worker is attached to the cluster
    assert not detached()
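
# A companion check along the same lines (hypothetical; assumes the test
# runs inside a KungFu job, where at least this worker is present):
def test_cluster_size():
    from kungfu.python import current_cluster_size
    assert current_cluster_size() >= 1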