Example #1
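Note: the snippets on this page are excerpts from larger TensorFlow 1.x training scripts. They assume module-level imports roughly like the sketch below (not every example uses all of them; `utils` is a project-local helper module, not a pip package):

import math
import os
import sys
import time

import numpy as np
import tensorflow as tf
from scipy import misc                         # misc.imsave, Example #8
from tensorflow.python.client import timeline  # chrome traces, Example #4

import utils
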
def train(args, sess, epoch, learning_rate_placeholder,
          phase_train_placeholder, global_step, loss, train_op, summary_op,
          summary_writer, learning_rate_schedule_file):
    batch_number = 0

    if args.learning_rate > 0.0:
        lr = args.learning_rate
    else:
        lr = utils.get_learning_rate_from_file(learning_rate_schedule_file,
                                               epoch)
    while batch_number < args.epoch_size:
        print('Running forward pass on sampled images: ', end='')
        feed_dict = {
            learning_rate_placeholder: lr,
            phase_train_placeholder: True
        }
        start_time = time.time()
        total_err, reg_err, _, step = sess.run(
            [loss['total_loss'], loss['total_reg'], train_op, global_step],
            feed_dict=feed_dict)
        duration = time.time() - start_time
        print(
            'Epoch: [%d][%d/%d]\tTime %.3f\tTotal Loss %2.3f\tReg Loss %2.3f, lr %2.5f'
            % (epoch, batch_number + 1, args.epoch_size, duration, total_err,
               reg_err, lr))

        batch_number += 1
    return step
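
Every example resolves the learning rate either from args.learning_rate or from a schedule file via utils.get_learning_rate_from_file. That helper is not shown on this page; here is a minimal sketch, assuming a plain-text schedule of `epoch: learning_rate` lines where a negative rate means "stop training" (Example #6 checks for lr < 0):

def get_learning_rate_from_file(filename, epoch):
    """Return the rate of the latest schedule entry with entry epoch <= epoch (sketch)."""
    learning_rate = -1.0
    with open(filename, 'r') as f:
        for line in f:
            line = line.split('#', 1)[0].strip()  # drop comments and blank lines
            if not line:
                continue
            e, lr = line.split(':')
            if int(e) <= epoch:
                learning_rate = float(lr)
            else:
                break
    return learning_rate
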
Example #2
def train_online(args, sess, epoch, learning_rate_placeholder,
                 phase_train_placeholder, global_step, loss, train_op,
                 summary_op, summary_writer, learning_rate_schedule_file):
    batch_number = 0

    if args.learning_rate > 0.0:
        lr = args.learning_rate
    else:
        lr = utils.get_learning_rate_from_file(learning_rate_schedule_file,
                                               epoch)
    while batch_number < args.epoch_size:

        print('Running forward pass on sampled images: ', end='')
        feed_dict = {
            learning_rate_placeholder: lr,
            phase_train_placeholder: True
        }
        start_time = time.time()
        triplet_err, total_err, _, step = sess.run(
            [loss['triplet_loss'], loss['total_loss'], train_op, global_step],
            feed_dict=feed_dict)
        duration = time.time() - start_time
        print(
            'Epoch: [%d][%d/%d]\tTime %.3f\tTriplet Loss %2.3f Total Loss %2.3f lr %2.5f'
            % (epoch, batch_number + 1, args.epoch_size, duration, triplet_err,
               total_err, lr))
        batch_number += 1
    return step
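
Examples #2 and #8 fetch loss['triplet_loss'] but build it elsewhere. As a point of reference, here is a minimal sketch of the standard FaceNet-style triplet loss in TF1 (an assumption about the definition, not code from these scripts), for embeddings already split into anchor/positive/negative batches:

def triplet_loss(anchor, positive, negative, alpha=0.2):
    """mean(max(0, ||a-p||^2 - ||a-n||^2 + alpha)) over the batch (sketch)."""
    pos_dist = tf.reduce_sum(tf.square(anchor - positive), axis=1)
    neg_dist = tf.reduce_sum(tf.square(anchor - negative), axis=1)
    return tf.reduce_mean(tf.maximum(pos_dist - neg_dist + alpha, 0.0))
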
Example #3
def train(args, sess, epoch, num_gpus, debug_info, learning_rate_placeholder,
          phase_train_placeholder, global_step, loss, train_op, summary_op,
          summary_writer, learning_rate_schedule_file):
    batch_number = 0

    if args.learning_rate > 0.0:
        lr = args.learning_rate
    else:
        lr = utils.get_learning_rate_from_file(learning_rate_schedule_file,
                                               epoch)
    while batch_number < args.epoch_size:
        print('Running forward pass on sampled images: ', end='')
        feed_dict = {
            learning_rate_placeholder: lr,
            phase_train_placeholder: True
        }
        start_time = time.time()
        total_err, softmax_err, dist_err, th_err, _, step = sess.run(
            [
                loss['total_loss'], loss['total_cross'], loss['total_dist'],
                loss['total_th'], train_op, global_step
            ],
            feed_dict=feed_dict)
        duration = time.time() - start_time
        print(
            'Epoch: [%d][%d/%d]\tTime %.3f\tTotal Loss %2.3f\tSoftmax Loss %2.3f, Dist Loss %2.3f, Th Loss %2.3f, lr %2.5f'
            % (epoch, batch_number + 1, args.epoch_size, duration, total_err,
               softmax_err, dist_err, th_err, lr))
        # Add validation loss and accuracy to summary
        summary = tf.Summary()
        #pylint: disable=maybe-no-member
        summary.value.add(tag='time/selection', simple_value=duration)
        summary_writer.add_summary(summary, step)
        batch_number += 1
    return step
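
Note on the manual summary above: building a tf.Summary protobuf in Python and passing it to summary_writer.add_summary records a scalar to TensorBoard without adding any summary op to the graph, which is convenient for values (like the measured step duration) that only exist on the Python side.
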
Example #4
def train(args, sess, epoch, images_placeholder, labels_placeholder,
          data_reader, run_metadata, run_options, learning_rate_placeholder,
          global_step, loss, train_op, learning_rate_schedule_file):
    batch_number = 0

    if args.learning_rate > 0.0:
        lr = args.learning_rate
    else:
        lr = utils.get_learning_rate_from_file(learning_rate_schedule_file,
                                               epoch)
    while batch_number < args.epoch_size:
        start_time = time.time()
        images, labels = data_reader.next_batch(args.image_height,
                                                args.image_width)
        print("Load {} images cost time: {}".format(images.shape[0],
                                                    time.time() - start_time))
        feed_dict = {
            learning_rate_placeholder: lr,
            images_placeholder: images,
            labels_placeholder: labels
        }
        start_time = time.time()
        _ = sess.run(train_op,
                     feed_dict=feed_dict,
                     options=run_options,
                     run_metadata=run_metadata)
        tl = timeline.Timeline(run_metadata.step_stats)
        ctf = tl.generate_chrome_trace_format()
        # Make sure the output directory exists before writing the trace
        os.makedirs('json_placeholder', exist_ok=True)
        with open('json_placeholder/tl-{}.json'.format(batch_number),
                  'w') as wd:
            wd.write(ctf)
        duration = time.time() - start_time
        print('Epoch: [%d][%d/%d]\tTime %.3f' %
              (epoch, batch_number + 1, args.epoch_size, duration))

        batch_number += 1
    return batch_number
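
The per-batch JSON files written above are Chrome trace files: open chrome://tracing and load one to inspect per-op timing and device placement. Since tracing adds overhead and a file is written on every batch here, this pattern suits short profiling runs rather than production training.
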
Example #5
def train(args, sess, epoch, images_placeholder, labels_placeholder,
          data_reader, debug, learning_rate_placeholder, global_step, loss,
          train_op, learning_rate_schedule_file):
    batch_number = 0

    if args.learning_rate > 0.0:
        lr = args.learning_rate
    else:
        lr = utils.get_learning_rate_from_file(learning_rate_schedule_file,
                                               epoch)
    while batch_number < args.epoch_size:
        print('Running forward pass on sampled images: ', end='')
        images, labels = data_reader.next_batch(args.image_height,
                                                args.image_width)
        feed_dict = {
            learning_rate_placeholder: lr,
            images_placeholder: images,
            labels_placeholder: labels
        }
        start_time = time.time()
        total_err, softmax_err, _, step = sess.run(
            [loss['total_loss'], loss['softmax_loss'], train_op, global_step],
            feed_dict=feed_dict)
        duration = time.time() - start_time
        print(
            'Epoch: [%d][%d/%d]\tTime %.3f\tTotal Loss %2.3f\tSoftmax Loss %2.3f, lr %2.5f'
            % (epoch, batch_number + 1, args.epoch_size, duration, total_err,
               softmax_err, lr))

        batch_number += 1
    return step
Example #6
def main(args):

    src_path, _ = os.path.split(os.path.realpath(__file__))

    # Create result directory
    res_name = utils.gettime()
    res_dir = os.path.join(src_path, 'results', res_name)
    os.makedirs(res_dir, exist_ok=True)

    log_filename = os.path.join(res_dir, 'log.h5')
    model_filename = os.path.join(res_dir, res_name)

    # Store some git revision info in a text file in the log directory
    utils.store_revision_info(src_path, res_dir, ' '.join(sys.argv))

    # Store parameters in an HDF5 file
    utils.store_hdf(os.path.join(res_dir, 'parameters.h5'), vars(args))

    # Copy learning rate schedule file to result directory
    learning_rate_schedule = utils.copy_learning_rate_schedule_file(
        args.learning_rate_schedule, res_dir)

    with tf.Session() as sess:

        tf.set_random_seed(args.seed)
        np.random.seed(args.seed)

        filelist = ['train_%03d.pkl' % i for i in range(200)]
        dataset = create_dataset(filelist,
                                 args.data_dir,
                                 buffer_size=20000,
                                 batch_size=args.batch_size,
                                 total_seq_length=args.nrof_init_time_steps +
                                 args.seq_length)

        # Create an iterator over the dataset
        iterator = dataset.make_one_shot_iterator()
        obs, action = iterator.get_next()

        is_pdt_ph = tf.placeholder(tf.bool, [None, args.seq_length])
        is_pdt = create_transition_type_matrix(args.batch_size,
                                               args.seq_length,
                                               args.training_scheme)

        with tf.variable_scope('env_model'):
            env_model = EnvModel(is_pdt_ph,
                                 obs,
                                 action,
                                 1,
                                 model_type=args.model_type,
                                 nrof_time_steps=args.seq_length,
                                 nrof_free_nats=args.nrof_free_nats)

        reg_loss = tf.reduce_mean(env_model.regularization_loss)
        rec_loss = tf.reduce_mean(env_model.reconstruction_loss)
        loss = reg_loss + rec_loss

        global_step = tf.Variable(0, name='global_step', trainable=False)
        learning_rate_ph = tf.placeholder(tf.float32, ())
        train_op = tf.train.AdamOptimizer(learning_rate_ph).minimize(
            loss, global_step=global_step)

        saver = tf.train.Saver()

        sess.run(tf.global_variables_initializer())

        stat = {
            'loss': np.zeros((args.max_nrof_steps, ), np.float32),
            'rec_loss': np.zeros((args.max_nrof_steps, ), np.float32),
            'reg_loss': np.zeros((args.max_nrof_steps, ), np.float32),
            'learning_rate': np.zeros((args.max_nrof_steps, ), np.float32),
        }

        try:
            print('Started training')
            rec_loss_tot, reg_loss_tot, loss_tot = (0.0, 0.0, 0.0)
            lr = None
            t = time.time()
            for i in range(1, args.max_nrof_steps + 1):
                if not lr or i % 100 == 0:
                    lr = utils.get_learning_rate_from_file(
                        learning_rate_schedule, i)
                    if lr < 0:
                        break
                stat['learning_rate'][i - 1] = lr
                _, rec_loss_, reg_loss_, loss_ = sess.run(
                    [train_op, rec_loss, reg_loss, loss],
                    feed_dict={
                        is_pdt_ph: is_pdt,
                        learning_rate_ph: lr
                    })
                stat['loss'][i - 1], stat['rec_loss'][i - 1], stat['reg_loss'][
                    i - 1] = loss_, rec_loss_, reg_loss_
                rec_loss_tot += rec_loss_
                reg_loss_tot += reg_loss_
                loss_tot += loss_
                if i % 10 == 0:
                    print(
                        'step: %-5d  time: %-12.3f  lr: %-12.6f  rec_loss: %-12.1f  reg_loss: %-12.1f  loss: %-12.1f'
                        % (i, time.time() - t, lr, rec_loss_tot / 10,
                           reg_loss_tot / 10, loss_tot / 10))
                    rec_loss_tot, reg_loss_tot, loss_tot = (0.0, 0.0, 0.0)
                    t = time.time()
                if i % 5000 == 0:
                    saver.save(sess, model_filename, i)
                if i % 100 == 0:
                    utils.store_hdf(log_filename, stat)

        except tf.errors.OutOfRangeError:
            pass

        print("Saving model...")
        saver.save(sess, model_filename, i)

        print('Done!')
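
Example #6 queries the schedule every 100 steps and stops once it returns a negative rate. With the get_learning_rate_from_file sketch above, a compatible schedule file could look like this (hypothetical values; note the function is passed the step counter i here, not an epoch):

# step: learning_rate  (a negative rate stops training)
0:     0.001
10000: 0.0001
20000: 0.00001
30000: -1
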
Example #7
def train(args, sess, epoch, batch_number,
          learning_rate_placeholder, phase_train_placeholder, batch_size_placeholder,
          image_batch_plh, label_batch_plh, step,
          loss, train_op, summary_op, summary_writer, reg_losses, learning_rate_schedule_file,
          stat, cross_entropy_mean, accuracy,
          learning_rate, prelogits, prelogits_center_loss, prelogits_norm,
          prelogits_hist_max, dataset_train,):

    if args.learning_rate > 0.0:
        lr = args.learning_rate
    else:
        lr = utils.get_learning_rate_from_file(learning_rate_schedule_file, epoch)

    if lr <= 0:
        return False

    # Training loop
    train_time = 0
    while batch_number < args.epoch_size:
        start_time = time.time()

        # Process a batch of data for facenet
        image_batch, label_batch = dataset_train.batch()

        # Compute loss and perform training step for facenet
        feed_dict = {learning_rate_placeholder: lr, phase_train_placeholder: True,
                     batch_size_placeholder: args.batch_size,
                     label_batch_plh: label_batch, image_batch_plh: image_batch}

        tensor_list = [loss, train_op, step, reg_losses, prelogits, cross_entropy_mean, learning_rate, prelogits_norm,
                       accuracy, prelogits_center_loss]

        if batch_number % 2 == 0:
            loss_, _, step_, reg_losses_, prelogits_, cross_entropy_mean_, \
            lr_, prelogits_norm_, accuracy_, center_loss_, summary_str = \
                sess.run(tensor_list + [summary_op], feed_dict=feed_dict)
            summary_writer.add_summary(summary_str, global_step=step_)
        else:
            loss_, _, step_, reg_losses_, prelogits_, cross_entropy_mean_, lr_, prelogits_norm_, accuracy_, center_loss_ = sess.run(
                tensor_list, feed_dict=feed_dict)

        duration = time.time() - start_time
        stat['loss'][step_ - 1] = loss_
        stat['center_loss'][step_ - 1] = center_loss_
        stat['reg_loss'][step_ - 1] = np.sum(reg_losses_)
        stat['xent_loss'][step_ - 1] = cross_entropy_mean_
        stat['prelogits_norm'][step_ - 1] = prelogits_norm_
        stat['learning_rate'][epoch - 1] = lr_
        stat['accuracy'][step_ - 1] = accuracy_
        stat['prelogits_hist'][epoch - 1, :] += np.histogram(
            np.minimum(np.abs(prelogits_), prelogits_hist_max),
            bins=1000, range=(0.0, prelogits_hist_max))[0]

        print('Epoch: [%d][%d/%d]\tTime %.3f\tLoss %2.3f\tXent %2.3f\tRegLoss %2.3f\tAccuracy %2.3f'
              '\tLr %2.5f\tCl %2.3f' %
              (epoch, batch_number + 1, args.epoch_size, duration, loss_, cross_entropy_mean_, np.sum(reg_losses_),
               accuracy_, lr_, center_loss_))
        batch_number += 1
        train_time += duration
    # Add validation loss and accuracy to summary
    summary = tf.Summary()
    # pylint: disable=maybe-no-member
    summary.value.add(tag='time/total', simple_value=train_time)
    summary_writer.add_summary(summary, global_step=step_)
    return True
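
Example #7 fills stat arrays indexed by global step and by epoch, but allocates them in the caller. A plausible initialization, mirroring the np.zeros pattern of Example #6 and assuming a hypothetical args.max_nrof_epochs (args.epoch_size appears above; 1000 matches the bins= of the histogram):

nrof_steps = args.max_nrof_epochs * args.epoch_size
stat = {
    'loss': np.zeros((nrof_steps,), np.float32),
    'center_loss': np.zeros((nrof_steps,), np.float32),
    'reg_loss': np.zeros((nrof_steps,), np.float32),
    'xent_loss': np.zeros((nrof_steps,), np.float32),
    'prelogits_norm': np.zeros((nrof_steps,), np.float32),
    'accuracy': np.zeros((nrof_steps,), np.float32),
    'learning_rate': np.zeros((args.max_nrof_epochs,), np.float32),
    'prelogits_hist': np.zeros((args.max_nrof_epochs, 1000), np.float32),
}
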
Example #8
def train_simi_online(args, sess, epoch, num_gpus, embeddings_gather,
                      batch_label, images, batch_image_split,
                      learning_rate_placeholder, learning_rate,
                      phase_train_placeholder, global_step, pos_d, neg_d,
                      triplet_handle, loss, train_op, summary_op,
                      summary_writer, learning_rate_schedule_file):
    if args.learning_rate > 0.0:
        lr = args.learning_rate
    else:
        lr = utils.get_learning_rate_from_file(learning_rate_schedule_file,
                                               epoch)
    batch_number = 0
    while batch_number < args.epoch_size:
        # Sample people randomly from the dataset
        embeddings_list = []
        labels_list = []
        images_list = []
        f_time = time.time()
        for i in range(args.scale):
            embeddings_np, labels_np, images_np = sess.run(
                [embeddings_gather, batch_label, images],
                feed_dict={
                    phase_train_placeholder: False,
                    learning_rate_placeholder: lr
                })
            embeddings_list.append(embeddings_np)
            labels_list.append(labels_np)
            images_list.append(images_np)
        embeddings_all = np.vstack(embeddings_list)
        labels_all = np.hstack(labels_list)
        images_all = np.vstack(images_list)
        print('forward time: {}'.format(time.time() - f_time))
        f_time = time.time()
        triplet_pairs = sess.run(triplet_handle['pair'],
                                 feed_dict={
                                     triplet_handle['embeddings']:
                                     embeddings_all,
                                     triplet_handle['labels']: labels_all
                                 })
        print('tf op select triplet time: {}'.format(time.time() - f_time))
        triplet_images_size = len(triplet_pairs)
        if args.show_triplet:
            show_images = (images_all * 128. + 127.5) / 255.
            save_dir = 'rm/{}_{}'.format(epoch, batch_number)
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            for i in range(triplet_images_size // 3):
                start_i = i * 3
                image_a = show_images[triplet_pairs[start_i]]
                image_p = show_images[triplet_pairs[start_i + 1]]
                image_n = show_images[triplet_pairs[start_i + 2]]
                image_apn = np.concatenate([image_a, image_p, image_n], axis=1)
                to_name = '{}/{}.jpg'.format(save_dir, i)
                misc.imsave(to_name, image_apn)
        total_batch_size = args.num_gpus * args.people_per_batch * args.images_per_person
        nrof_batches = int(
            math.ceil(1.0 * (triplet_images_size // (args.num_gpus * 3)) *
                      args.num_gpus * 3 / total_batch_size))
        if nrof_batches == 0:
            print('continue forward')
            continue
        for i in range(nrof_batches):
            start_index = i * total_batch_size
            end_index = min((i + 1) * total_batch_size,
                            (triplet_images_size //
                             (args.num_gpus * 3)) * args.num_gpus * 3)
            select_triplet_pairs = triplet_pairs[start_index:end_index]
            select_images = images_all[select_triplet_pairs]
            print('triplet pairs: {}/{}'.format(end_index // 3,
                                                triplet_images_size // 3))

            print('Running forward pass on sampled images: ', end='')
            feed_dict = {
                phase_train_placeholder: False,
                images: select_images,
                learning_rate_placeholder: lr
            }
            start_time = time.time()
            triplet_err, total_err, _, step, lr, _, pos_np, neg_np = sess.run(
                [
                    loss['triplet_loss'], loss['total_loss'], train_op,
                    global_step, learning_rate, summary_op, pos_d, neg_d
                ],
                feed_dict=feed_dict)
            duration = time.time() - start_time
            print(
                'Epoch: [%d][%d/%d]\tTime %.3f\tTriplet Loss %2.3f Total Loss %2.3f lr %2.5f, pos_d %2.5f, neg_d %2.5f'
                % (epoch, batch_number + 1, args.epoch_size, duration,
                   triplet_err, total_err, lr, pos_np, neg_np))
        # Add validation loss and accuracy to summary
        summary = tf.Summary()
        #pylint: disable=maybe-no-member
        summary.value.add(tag='time/selection', simple_value=duration)
        summary.value.add(tag='loss/triploss', simple_value=triplet_err)
        summary.value.add(tag='loss/total', simple_value=total_err)
        summary.value.add(tag='learning_rate/lr', simple_value=lr)
        summary_writer.add_summary(summary, step)

        batch_number += 1
    return step
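
The triplet_handle['pair'] op above is defined elsewhere (typically semi-hard triplet mining); what the surrounding code relies on is only its output layout: a flat index array [a0, p0, n0, a1, p1, n1, ...] into the stacked forward-pass results, which is why every loop steps through it three indices at a time. A naive NumPy sketch that produces the same layout (hypothetical; random negatives instead of semi-hard ones):

def select_triplets_naive(embeddings, labels):
    """Return flat [a, p, n, a, p, n, ...] indices into embeddings/labels."""
    # embeddings is unused in this naive version; semi-hard mining would
    # use it to pick negatives close to the anchor.
    pairs = []
    for a in range(len(labels)):
        same = np.where(labels == labels[a])[0]
        diff = np.where(labels != labels[a])[0]
        for p in same[same > a]:          # each same-label pair once
            if len(diff) == 0:
                continue
            n = np.random.choice(diff)    # random negative (not semi-hard)
            pairs.extend([a, p, n])
    return np.asarray(pairs, dtype=np.int64)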