Beispiel #1
0
def train():
    """Train the CIFAR-100 model for FLAGS.max_steps steps.

    Builds the input pipeline, model, loss and training op in a fresh graph,
    then runs the training op in a plain ``tf.Session`` loop. Every 10 steps
    it prints the loss and throughput; every 1000 steps and at the final step
    it saves a checkpoint to ``FLAGS.train_dir/model.ckpt``.

    Raises:
        RuntimeError: if the loss becomes NaN (the model diverged).
    """
    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False)

        images, labels = cifar100.distorted_inputs()

        logits = cifar100.inference(images)

        loss = cifar100.loss(logits, labels)

        train_op = cifar100.train(loss, global_step)

        saver = tf.train.Saver(tf.all_variables())

        init = tf.initialize_all_variables()

        config = tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)
        # The context manager closes the session even if training raises.
        with tf.Session(config=config) as sess:
            sess.run(init)

            # A coordinator lets the queue-runner threads be stopped and
            # joined cleanly before the session is closed.
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            try:
                for step in range(FLAGS.max_steps):
                    start_time = time.time()
                    _, loss_value = sess.run([train_op, loss])
                    duration = time.time() - start_time

                    # Explicit check instead of `assert`, which `python -O`
                    # strips out.
                    if np.isnan(loss_value):
                        raise RuntimeError('Model diverged with loss = NaN')

                    if step % 10 == 0:
                        num_examples_per_step = FLAGS.batch_size
                        examples_per_sec = num_examples_per_step / duration
                        sec_per_batch = float(duration)

                        format_str = (
                            '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                            'sec/batch)')
                        print(format_str % (datetime.now(), step, loss_value,
                                            examples_per_sec, sec_per_batch))

                    # Save the model checkpoint periodically and after the
                    # last step.
                    if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
                        checkpoint_path = os.path.join(FLAGS.train_dir,
                                                       'model.ckpt')
                        saver.save(sess, checkpoint_path, global_step=step)
            finally:
                coord.request_stop()
                coord.join(threads)
Beispiel #2
0
def train():
    """Train CIFAR-100 for a number of steps.

    Uses a ``MonitoredTrainingSession`` that checkpoints to
    ``FLAGS.train_dir``, stops at ``FLAGS.max_steps``, aborts on a NaN loss and
    logs throughput every ``FLAGS.log_frequency`` steps. After training it
    prints the wall-clock time spent and the most recently logged loss.
    """
    with tf.Graph().as_default():
        global_step = tf.contrib.framework.get_or_create_global_step()

        # Get images and labels for CIFAR-100.
        # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
        # GPU and resulting in a slow down.
        with tf.device('/cpu:0'):
            images, labels = cifar100.distorted_inputs()

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = cifar100.inference(images)

        # Calculate loss.
        loss = cifar100.loss(logits, labels)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cifar100.train(loss, global_step)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime, and remembers the last logged loss."""

            def begin(self):
                self._step = -1
                self._start_time = time.time()
                # Initialised here so last_loss() never raises AttributeError
                # if the session stops before any step is run and logged.
                self._last_loss = None

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(loss)  # Asks for loss value.

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    self._last_loss = loss_value
                    examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

            def last_loss(self):
                """Return the loss seen at the most recent logging step, or None."""
                return self._last_loss

        loghook = _LoggerHook()
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(loss), loghook
                ],
                config=tf.ConfigProto(
                    log_device_placement=FLAGS.log_device_placement)) as mon_sess:

            t1 = time.time()
            while not mon_sess.should_stop():
                mon_sess.run(train_op)
            t2 = time.time()

            print('spent %f seconds to train %d step' %
                  (t2 - t1, FLAGS.max_steps))
            last_loss = loghook.last_loss()
            if last_loss is None:
                print('last loss value: n/a')
            else:
                print('last loss value: %.2f ' % last_loss)
def train():
    """Train the CIFAR-100 model inside a MonitoredTrainingSession.

    Checkpoints are restored from / saved to FLAGS.train_dir every 10000
    steps; training stops at FLAGS.max_steps or on a NaN loss, and progress is
    printed every FLAGS.log_frequency steps.
    """
    with tf.Graph().as_default():
        global_step = tf.train.get_or_create_global_step()

        # Keep the input pipeline on the CPU.
        with tf.device('/cpu:0'):
            # NOTE(review): the other examples call `distorted_inputs`; confirm
            # that this project's cifar100 module really defines
            # `destorted_inputs` (spelling).
            images, labels = cifar100.destorted_inputs()

        # Build the model. Its logits are compared against `labels` by the
        # loss (described by the original author as a cross-entropy).
        logits = cifar100.inference(images)
        loss = cifar100.loss(logits, labels)
        train_op = cifar100.train(loss, global_step)

        class _LoggerHook(tf.train.SessionRunHook):
            """Print loss and throughput every FLAGS.log_frequency steps."""

            def begin(self):
                self._start_time = time.time()
                self._step = -1

            def before_run(self, run_context):
                self._step += 1
                # Ask the session to also fetch the loss on this run.
                return tf.train.SessionRunArgs(loss)

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency != 0:
                    return

                now = time.time()
                elapsed = now - self._start_time
                self._start_time = now

                steps = FLAGS.log_frequency
                examples_per_sec = steps * FLAGS.batch_size / elapsed
                sec_per_batch = float(elapsed / steps)

                template = ('%s: step %d, loss = %.2f (%.1f examples/sec; '
                            '%.3f sec/batch)')
                print(template % (datetime.now(), self._step,
                                  run_values.results, examples_per_sec,
                                  sec_per_batch))

        config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
        gpu_options = config.gpu_options
        # Use the BFC (best-fit with coalescing) GPU memory allocator.
        gpu_options.allocator_type = 'BFC'
        # Let this process use at most 50% of the GPU's memory ...
        gpu_options.per_process_gpu_memory_fraction = 0.5
        # ... and allocate it on demand rather than all at once.
        gpu_options.allow_growth = True

        stop_hook = tf.train.StopAtStepHook(last_step=FLAGS.max_steps)
        # Raises an error and stops training if the loss becomes NaN.
        nan_hook = tf.train.NanTensorHook(loss)
        hooks = [stop_hook, nan_hook, _LoggerHook()]

        # MonitoredTrainingSession handles initialisation/restoring and
        # checkpointing (it can also drive distributed training).
        # checkpoint_dir is used both to restore and to save state; saving is
        # step-based (every 10000 steps), not time-based.
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                save_checkpoint_secs=None,
                save_checkpoint_steps=10000,
                hooks=hooks,
                config=config) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(train_op)
Beispiel #4
0
def train():
    """Train CIFAR-100 with periodic local/global weight averaging across workers.

    Each process joins the TF cluster described by FLAGS.ps_hosts and
    FLAGS.worker_hosts.

    * A "ps" task only serves variables and never runs the training code.
    * A worker trains its own local copy of the model on the CIFAR-100
      training set (loaded through ``dset.cifar100.load_data``).
    * Every FLAGS.tau steps the worker moves the global copy (placed on the
      ps) towards its local weights by ``alpha * (w - global_w)``. It then
      subtracts that same delta from its local weights after the step's
      training update.
    * Every FLAGS.log_frequency steps the global weights are evaluated on the
      test set with ``predt``. The local weights are stashed in scratch
      variables during the evaluation and restored afterwards.

    Relies on module-level names not visible here: FLAGS, alpha, EPOCH_SIZE,
    normalize, predt, dset, cifar100.
    """
    # Construct the cluster and start this task's server.
    ps_spec = FLAGS.ps_hosts.split(",")
    worker_spec = FLAGS.worker_hosts.split(",")

    # Get the number of workers.
    num_workers = len(worker_spec)

    cluster = tf.train.ClusterSpec({"ps": ps_spec, "worker": worker_spec})

    server = tf.train.Server(cluster,
                             job_name=FLAGS.job_name,
                             task_index=FLAGS.task_index)
    if FLAGS.job_name == "ps":
        # server.join() blocks; the explicit return makes it clear that
        # parameter servers never run the training code below.
        server.join()
        return
    # Only workers reach this point; worker 0 is the chief.
    is_chief = FLAGS.task_index == 0

    # Alternative placement (use instead of the tf.device line below):
    # with tf.Graph().as_default(), tf.device(tf.train.replica_device_setter(
    #     worker_device="/job:worker/task:%d" % FLAGS.task_index,
    #     cluster=cluster)):
    with tf.device("job:worker/task:%d" % FLAGS.task_index):
        global_step = tf.contrib.framework.get_or_create_global_step()

        # Load CIFAR-100 (fine labels) into memory and normalise the images.
        (x_train, y_train_orl), (x_test, y_test_orl) = dset.cifar100.load_data(
            label_mode='fine')
        x_train = x_train.astype('float32')
        x_test = x_test.astype('float32')
        x_train, x_test = normalize(x_train, x_test)

        # Labels are flattened to 1-D int32 vectors to match the `y`
        # placeholder.
        y_train_orl = y_train_orl.astype('int32')
        y_test_orl = y_test_orl.astype('int32')
        y_train_flt = y_train_orl.ravel()
        y_test_flt = y_test_orl.ravel()

        x = tf.placeholder(tf.float32, shape=(FLAGS.batch_size, 32, 32, 3))
        y = tf.placeholder(tf.int32, shape=(FLAGS.batch_size, ))

        # Build a Graph that computes the logits predictions from the
        # inference model; it also returns the list of model variables that
        # are averaged with the global copy.
        logits, local_var_list = cifar100.inference(x)

        # Calculate loss.
        loss = cifar100.loss(logits, y)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cifar100.train(loss, global_step)

        # Scratch copies of the local weights, used to stash them while the
        # global weights are evaluated.
        tmp_var_list = [
            tf.Variable(tf.zeros(var.shape), name="tmp_var" + str(var_index))
            for var_index, var in enumerate(local_var_list, start=1)
        ]

        # Non-chief workers create their init op here, before the global
        # (ps) variables exist, so it only covers variables created so far.
        if not is_chief:
            init_op = tf.global_variables_initializer()
        else:
            init_op = None

        # Global variables region: everything here lives on the first ps task.
        with tf.device("/job:ps/replica:0/task:0/cpu:0"):
            # Counter used as a start-up barrier between workers.
            finished = tf.get_variable("worker_finished", [],
                                       tf.int32,
                                       tf.zeros_initializer(tf.int32),
                                       trainable=False)
            with finished.graph.colocate_with(finished):
                finish_op = finished.assign_add(1, use_locking=True)

            global_var_list = [
                tf.Variable(tf.zeros(var.shape), name="glo_var" + str(var_index))
                for var_index, var in enumerate(local_var_list, start=1)
            ]

        def assign_global_vars():  # assign local vars' values to global vars
            return [
                gvar.assign(lvar)
                for (gvar, lvar) in zip(global_var_list, local_var_list)
            ]

        def assign_local_vars():  # assign global vars' values to local vars
            return [
                lvar.assign(gvar)
                for (gvar, lvar) in zip(global_var_list, local_var_list)
            ]

        def assign_tmp_vars():  # assign local vars' values to tmp vars
            return [
                tvar.assign(lvar)
                for (tvar, lvar) in zip(tmp_var_list, local_var_list)
            ]

        def assign_local_vars_from_tmp():  # assign tmp vars' values to local vars
            return [
                lvar.assign(tvar)
                for (tvar, lvar) in zip(tmp_var_list, local_var_list)
            ]

        def update_before_train(alpha, w, global_w):
            # global_w += alpha * (w - global_w); also return the delta tensor.
            varib = alpha * (w - global_w)
            gvar_op = global_w.assign(global_w + varib)
            return gvar_op, varib

        def update_after_train(w, vab):
            # w -= vab, where vab is the delta fed in from update_before_train.
            return w.assign(w - vab)

        assign_list_local = assign_local_vars()
        assign_list_global = assign_global_vars()
        assign_list_loc2tmp = assign_tmp_vars()
        assign_list_tmp2loc = assign_local_vars_from_tmp()

        # (global update op, delta tensor) for each variable.
        before_op_tuple_list = [
            update_before_train(alpha, lvar, gvar)
            for (gvar, lvar) in zip(global_var_list, local_var_list)
        ]
        # (local update op, delta placeholder) for each variable.
        after_op_tuple_list = []
        vbholder_list = []
        for var in local_var_list:
            holder = tf.placeholder("float", var.shape)
            vbholder_list.append(holder)
            after_op_tuple_list.append((update_after_train(var, holder), holder))

        # The chief's init op is created last so it also covers the global
        # (ps) variables.
        if is_chief:
            init_op = tf.global_variables_initializer()

        # NOTE(review): every worker is told it is the chief (is_chief=True);
        # the real flag is deliberately not passed here.
        sv = tf.train.Supervisor(
            is_chief=True,  #is_chief,
            logdir=FLAGS.train_dir,
            init_op=init_op,
            recovery_wait_secs=1)

        sess_config = tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement,
            device_filters=[
                "/job:ps", "/job:worker/task:%d" % FLAGS.task_index
            ])

        # The chief worker (task_index==0) session will prepare the session,
        # while the remaining workers will wait for the preparation to complete.
        if is_chief:
            print("Worker %d: Initializing session..." % FLAGS.task_index)
        else:
            print("Worker %d: Waiting for session to be initialized..." %
                  FLAGS.task_index)

        sess = sv.prepare_or_wait_for_session(server.target,
                                              config=sess_config)

        # The chief seeds the global weights from its local ones. Then every
        # worker increments the barrier counter and waits until all
        # num_workers have arrived before copying the global weights locally.
        if is_chief:
            sess.run(assign_list_global)
        barrier_finished = sess.run(finish_op)
        print("barrier_finished:", barrier_finished)
        while barrier_finished < num_workers:
            time.sleep(1)
            barrier_finished = sess.run(finished)
        sess.run(assign_list_local)
        print("Worker %d: Session initialization complete." % FLAGS.task_index)

        # Step-tracing scaffolding; only used by the commented-out timeline
        # code in the loop below.
        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        run_metadata = tf.RunMetadata()

        # Wall-clock time spent in training steps only. Each evaluation
        # interval is excluded because time_begin is reset after it.
        training_time = 0.0
        time_begin = time.time()
        # Chrome-trace output file; with tracing disabled it is only created.
        with open('tl_dist.json', 'w') as f:
            for step in range(FLAGS.max_steps):
                offset = (step * FLAGS.batch_size) % (EPOCH_SIZE -
                                                      FLAGS.batch_size)
                x_data = x_train[offset:(offset + FLAGS.batch_size), ...]
                y_data_flt = y_train_flt[offset:(offset + FLAGS.batch_size)]

                if step % FLAGS.log_frequency == 0:
                    time_step = time.time()
                    steps_time = time_step - time_begin
                    training_time += steps_time
                    print("step:", step, " steps time:", steps_time, end='  ')
                    # Evaluate the global weights: stash local -> tmp,
                    # load global -> local, predict, restore tmp -> local.
                    sess.run(assign_list_loc2tmp)
                    sess.run(assign_list_local)
                    predt(sess, x_test, y_test_flt, logits, x, y)
                    sess.run(assign_list_tmp2loc)
                    time_begin = time.time()
                if step % FLAGS.tau == 0 and step > 0:  # update global weights
                    # Push: global += alpha * (local - global); keep each delta.
                    thevarib_list = []
                    for (gvar_op, varib) in before_op_tuple_list:
                        _, thevarib = sess.run([gvar_op, varib])
                        thevarib_list.append(thevarib)

                    sess.run(train_op, feed_dict={x: x_data, y: y_data_flt})

                    # Pull: local -= delta, after this step's training update.
                    for (lvar_op, thevaribHolder), thevarib in zip(
                            after_op_tuple_list, thevarib_list):
                        sess.run(lvar_op, feed_dict={thevaribHolder: thevarib})
                else:
                    sess.run(train_op, feed_dict={x: x_data, y: y_data_flt})
                    # To trace a step, pass options=run_options and
                    # run_metadata=run_metadata to sess.run above, then:
                    #   tl = timeline.Timeline(run_metadata.step_stats)
                    #   f.write(tl.generate_chrome_trace_format())
        time_end = time.time()
        training_time += time_end - time_begin
        print("Training elapsed time: %f s" % training_time)
        # Final evaluation with the global weights.
        sess.run(assign_list_local)
        predt(sess, x_test, y_test_flt, logits, x, y)
Beispiel #5
0
def train():
    """Train CIFAR-100 with two model heads (A and B), alternating between them.

    ``cifar100.inference`` is expected to return two sets of logits; each gets
    its own loss and train op. Each loop iteration runs train_opA and then
    train_opB. The session stops at FLAGS.max_steps and aborts if either loss
    becomes NaN. The graph is written to ``tb-logs/`` for TensorBoard.

    The logger hook counts every ``mon_sess.run`` call (A and B separately).
    Every FLAGS.log_frequency of those calls it logs loss A both to stdout and
    to ``output_data/output_<timestamp>.txt``.
    """
    output_path = 'output_data/output_' + str(time.time()) + '.txt'
    # `with` guarantees the log file is closed even if training raises
    # (e.g. NanTensorHook on a NaN loss).
    with open(output_path, 'w') as output, tf.Graph().as_default():
        global_step = tf.train.get_or_create_global_step()

        # Get images and labels for CIFAR-100.
        # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
        # GPU and resulting in a slow down.
        with tf.device('/cpu:0'):
            images, labels = cifar100.distorted_inputs()

        # Build a Graph that computes the two logits predictions from the
        # inference model.
        logitsA, logitsB = cifar100.inference(images)

        # Calculate loss.
        lossA = cifar100.loss(logitsA, labels)
        lossB = cifar100.loss(logitsB, labels)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters. Both ops share global_step.
        # NOTE(review): if cifar100.train increments global_step, A and B runs
        # both count towards FLAGS.max_steps — confirm this is intended.
        train_opA = cifar100.train(lossA, global_step)
        train_opB = cifar100.train(lossB, global_step)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss A and runtime to stdout and to the output file."""

            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(lossA)  # Asks for loss value.

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                                  'sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))
                    # NOTE(review): the string already ends in '\n' and print
                    # adds another, so entries are separated by blank lines.
                    print((str(self._step) + '\t' +
                           str(loss_value) + '\n'), file=output)

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                       tf.train.NanTensorHook(lossA),
                       tf.train.NanTensorHook(lossB),
                       _LoggerHook()],
                config=tf.ConfigProto(
                    log_device_placement=FLAGS.log_device_placement)) as mon_sess:

            file_writer = tf.summary.FileWriter('tb-logs/', mon_sess.graph)
            try:
                while not mon_sess.should_stop():
                    print("stepA")
                    mon_sess.run(train_opA)
                    print("stepB")
                    mon_sess.run(train_opB)
            finally:
                # Flush pending events (including the graph) and release the
                # writer; the original code never closed it.
                file_writer.close()