Example #1
def main(args):

    #network = importlib.import_module(args.model_def)

    subdir = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S')
    log_dir = os.path.join(os.path.expanduser(args.logs_base_dir), subdir)
    if not os.path.isdir(
            log_dir):  # Create the log directory if it doesn't exist
        os.makedirs(log_dir)
    model_dir = os.path.join(os.path.expanduser(args.models_base_dir), subdir)
    if not os.path.isdir(
            model_dir):  # Create the model directory if it doesn't exist
        os.makedirs(model_dir)

    # Write arguments to a text file
    utils.write_arguments_to_file(args, os.path.join(log_dir, 'arguments.txt'))

    # Store some git revision info in a text file in the log directory
    src_path, _ = os.path.split(os.path.realpath(__file__))
    utils.store_revision_info(src_path, log_dir, ' '.join(sys.argv))

    np.random.seed(seed=args.seed)

    train_set = utils.get_dataset(args.data_dir)
    #train_set = facenet.dataset_from_list2(args.data_dir,'dataset/casia_maxpy_mtcnnpy_182',error_classes=[],drop_key='AsianStarCropBig_YES')
    nrof_classes = len(train_set)
    print('nrof_classes: ', nrof_classes)
    image_list, label_list = utils.get_image_paths_and_labels(train_set)
    image_list = np.array(image_list)
    label_list = np.array(label_list, dtype=np.int32)

    dataset_size = len(image_list)
    single_batch_size = args.people_per_batch * args.images_per_person
    indices = list(range(dataset_size))  # materialize as a list so np.random.shuffle can permute it in place
    np.random.shuffle(indices)

    def _sample_people_softmax(x):
        global softmax_ind  # module-level cursor into `indices` (see the sketch after this example)
        if softmax_ind >= dataset_size:
            np.random.shuffle(indices)
            softmax_ind = 0
        true_num_batch = min(single_batch_size, dataset_size - softmax_ind)

        sample_paths = image_list[indices[softmax_ind:softmax_ind +
                                          true_num_batch]]
        sample_labels = label_list[indices[softmax_ind:softmax_ind +
                                           true_num_batch]]

        softmax_ind += true_num_batch

        return (np.array(sample_paths), np.array(sample_labels,
                                                 dtype=np.int32))

    def _sample_people(x):
        '''Sample people with tf.data so that map transforms and prefetching can be used.'''

        image_paths, num_per_class = sample_people(
            train_set, args.people_per_batch * (args.num_gpus - 1),
            args.images_per_person)
        labels = []
        for i in range(len(num_per_class)):
            labels.extend([i] * num_per_class[i])
        return (np.array(image_paths), np.array(labels, dtype=np.int32))

    def _parse_function(filename, label):
        file_contents = tf.read_file(filename)
        image = tf.image.decode_image(file_contents, channels=3)
        #image = tf.image.decode_jpeg(file_contents, channels=3)
        print(image.shape)

        if args.random_crop:
            print('use random crop')
            image = tf.random_crop(image,
                                   [args.image_size, args.image_size, 3])
        else:
            print('Not use random crop')
            #image.set_shape((args.image_size, args.image_size, 3))
            image.set_shape((None, None, 3))
            image = tf.image.resize_images(image,
                                           size=(args.image_height,
                                                 args.image_width))
            #print(image.shape)
        if args.random_flip:
            image = tf.image.random_flip_left_right(image)

        #pylint: disable=no-member
        #image.set_shape((args.image_size, args.image_size, 3))
        image.set_shape((args.image_height, args.image_width, 3))
        if debug:
            image = tf.cast(image, tf.float32)
        else:
            image = tf.image.per_image_standardization(image)
        return image, label

    #train_set = facenet.dataset_from_list(args.data_dir,'dataset/ms_mp',keys=['MultiPics'])
    #train_set = facenet.dataset_from_list(args.data_dir,'dataset/ms_mp')
    gpus = [0, 1]
    #gpus = [0]

    print('Model directory: %s' % model_dir)
    print('Log directory: %s' % log_dir)
    if args.pretrained_model:
        print('Pre-trained model: %s' %
              os.path.expanduser(args.pretrained_model))

    with tf.Graph().as_default():
        tf.set_random_seed(args.seed)
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Placeholder for the learning rate
        learning_rate_placeholder = tf.placeholder(tf.float32,
                                                   name='learning_rate')

        phase_train_placeholder = tf.placeholder(tf.bool, name='phase_train')

        # The image batches are produced by the tf.data pipeline below.
        with tf.device("/cpu:0"):

            softmax_dataset = tf_data.Dataset.range(args.epoch_size *
                                                    args.max_nrof_epochs * 100)
            softmax_dataset = softmax_dataset.map(lambda x: tf.py_func(
                _sample_people_softmax, [x], [tf.string, tf.int32]))
            softmax_dataset = softmax_dataset.flat_map(_from_tensor_slices)
            softmax_dataset = softmax_dataset.map(_parse_function,
                                                  num_threads=8,
                                                  output_buffer_size=2000)
            softmax_dataset = softmax_dataset.batch(args.num_gpus *
                                                    single_batch_size)
            softmax_iterator = softmax_dataset.make_initializable_iterator()
            softmax_next_element = softmax_iterator.get_next()
            softmax_next_element[0].set_shape(
                (args.num_gpus * single_batch_size, args.image_height,
                 args.image_width, 3))
            softmax_next_element[1].set_shape(args.num_gpus *
                                              single_batch_size)
            batch_image_split = tf.split(softmax_next_element[0],
                                         args.num_gpus)
            batch_label_split = tf.split(softmax_next_element[1],
                                         args.num_gpus)

        learning_rate = tf.train.exponential_decay(
            learning_rate_placeholder,
            global_step,
            args.learning_rate_decay_epochs * args.epoch_size,
            args.learning_rate_decay_factor,
            staircase=True)
        tf.summary.scalar('learning_rate', learning_rate)
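        # With staircase=True, the decayed rate is
        #   learning_rate * decay_factor ** floor(global_step / (decay_epochs * epoch_size)),
        # i.e. the rate drops by `learning_rate_decay_factor` once every
        # `learning_rate_decay_epochs` epochs.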

        print('Using optimizer: {}'.format(args.optimizer))
        if args.optimizer == 'ADAGRAD':
            opt = tf.train.AdagradOptimizer(learning_rate)
        elif args.optimizer == 'MOM':
            opt = tf.train.MomentumOptimizer(learning_rate, 0.9)
        else:
            raise ValueError('Unsupported optimizer: {}'.format(args.optimizer))
        tower_losses = []
        tower_cross = []
        tower_dist = []
        tower_th = []
        for i in range(args.num_gpus):
            with tf.device("/gpu:" + str(i)):
                with tf.name_scope("tower_" + str(i)) as scope:
                    with slim.arg_scope([slim.model_variable, slim.variable],
                                        device="/cpu:0"):
                        with tf.variable_scope(
                                tf.get_variable_scope()) as var_scope:
                            reuse = i > 0  # reuse variables on all towers after the first
                            #with slim.arg_scope(resnet_v2.resnet_arg_scope(args.weight_decay)):
                            #prelogits, end_points = resnet_v2.resnet_v2_50(batch_image_split[i],is_training=True,
                            #        output_stride=16,num_classes=args.embedding_size,reuse=reuse)
                            #prelogits, end_points = network.inference(batch_image_split[i], args.keep_probability,
                            #    phase_train=phase_train_placeholder, bottleneck_layer_size=args.embedding_size,
                            #    weight_decay=args.weight_decay, reuse=reuse)
                            if args.network == 'slim_sphere':
                                prelogits = network.infer(batch_image_split[i])
                            elif args.network == 'densenet':
                                with slim.arg_scope(
                                        densenet.densenet_arg_scope(
                                            args.weight_decay)):
                                    #prelogits, endpoints = densenet.densenet_small(batch_image_split[i],num_classes=args.embedding_size,is_training=True,reuse=reuse)
                                    prelogits, endpoints = densenet.densenet_small_middle(
                                        batch_image_split[i],
                                        num_classes=args.embedding_size,
                                        is_training=True,
                                        reuse=reuse)
                                    prelogits = tf.squeeze(prelogits,
                                                           axis=[1, 2])

                            #prelogits = slim.batch_norm(prelogits, is_training=True, decay=0.997,epsilon=1e-5,scale=True,updates_collections=tf.GraphKeys.UPDATE_OPS,reuse=reuse,scope='softmax_bn')
                            if args.loss_type == 'softmax':
                                cross_entropy_mean = utils.softmax_loss(
                                    prelogits, batch_label_split[i],
                                    len(train_set), args.weight_decay, reuse)
                                regularization_losses = tf.get_collection(
                                    tf.GraphKeys.REGULARIZATION_LOSSES)
                                #loss = cross_entropy_mean + args.weight_decay*tf.add_n(regularization_losses)
                                loss = cross_entropy_mean + tf.add_n(
                                    regularization_losses)
                                tower_dist.append(0)
                                tower_cross.append(cross_entropy_mean)
                                tower_th.append(0)
                                tower_losses.append(loss)
                            elif args.loss_type == 'scatter' or args.loss_type == 'coco':
                                label_reshape = tf.reshape(
                                    batch_label_split[i], [single_batch_size])
                                label_reshape = tf.cast(
                                    label_reshape, tf.int64)
                                if args.loss_type == 'scatter':
                                    scatter_loss, _ = utils.weight_scatter_speed(
                                        prelogits,
                                        label_reshape,
                                        len(train_set),
                                        reuse,
                                        weight=args.weight,
                                        scale=args.scale)
                                else:
                                    scatter_loss, _ = utils.coco_loss(
                                        prelogits,
                                        label_reshape,
                                        len(train_set),
                                        reuse,
                                        alpha=args.alpha,
                                        scale=args.scale)
                                regularization_losses = tf.get_collection(
                                    tf.GraphKeys.REGULARIZATION_LOSSES)
                                loss = (scatter_loss['loss_total'] +
                                        args.weight_decay *
                                        tf.add_n(regularization_losses))
                                tower_dist.append(scatter_loss['loss_dist'])
                                tower_cross.append(0)
                                tower_th.append(scatter_loss['loss_th'])

                                tower_losses.append(loss)

                            #loss = tf.add_n([cross_entropy_mean] + regularization_losses, name='total_loss')
                            tf.get_variable_scope().reuse_variables()
        total_loss = tf.reduce_mean(tower_losses)
        total_cross = tf.reduce_mean(tower_cross)
        total_dist = tf.reduce_mean(tower_dist)
        total_th = tf.reduce_mean(tower_th)
        losses = {}
        losses['total_loss'] = total_loss
        losses['total_cross'] = total_cross
        losses['total_dist'] = total_dist
        losses['total_th'] = total_th
        debug_info = {}
        debug_info['logits'] = prelogits
        #debug_info['end_points'] = end_points
        debug_info['batch_image_split'] = batch_image_split
        debug_info['batch_label_split'] = batch_label_split
        #debug_info['endpoints'] = endpoints

        grads = opt.compute_gradients(total_loss,
                                      tf.trainable_variables(),
                                      colocate_gradients_with_ops=True)
        apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
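        # Grouping the gradient application under a control dependency on
        # UPDATE_OPS makes batch-norm moving statistics update on every step.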
        with tf.control_dependencies(update_ops):
            train_op = tf.group(apply_gradient_op)

        save_vars = [
            var for var in tf.global_variables()
            if 'Adagrad' not in var.name and 'global_step' not in var.name
        ]
        check_nan = tf.add_check_numerics_ops()
        debug_info['check_nan'] = check_nan

        #saver = tf.train.Saver(tf.trainable_variables(), max_to_keep=3)
        saver = tf.train.Saver(save_vars, max_to_keep=3)

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.summary.merge_all()

        # Start running operations on the Graph.
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=args.gpu_memory_fraction)
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
                                                allow_soft_placement=True))

        # Initialize variables
        sess.run(tf.global_variables_initializer(),
                 feed_dict={phase_train_placeholder: True})
        sess.run(tf.local_variables_initializer(),
                 feed_dict={phase_train_placeholder: True})

        #sess.run(iterator.initializer)
        sess.run(softmax_iterator.initializer)

        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(coord=coord, sess=sess)

        with sess.as_default():
            #pdb.set_trace()

            if args.pretrained_model:
                print('Restoring pretrained model: %s' % args.pretrained_model)
                saver.restore(sess, os.path.expanduser(args.pretrained_model))

            # Training and validation loop
            epoch = 0
            while epoch < args.max_nrof_epochs:
                step = sess.run(global_step, feed_dict=None)
                epoch = step // args.epoch_size
                if debug:
                    debug_train(args, sess, train_set, epoch,
                                image_batch_gather, enqueue_op,
                                batch_size_placeholder, image_batch_split,
                                image_paths_split, num_per_class_split,
                                image_paths_placeholder,
                                image_paths_split_placeholder,
                                labels_placeholder, labels_batch,
                                num_per_class_placeholder,
                                num_per_class_split_placeholder, len(gpus))
                # Train for one epoch
                train(args, sess, epoch, len(gpus), debug_info,
                      learning_rate_placeholder, phase_train_placeholder,
                      global_step, losses, train_op, summary_op,
                      summary_writer, args.learning_rate_schedule_file)

                # Save variables and the metagraph if it doesn't exist already
                save_variables_and_metagraph(sess, saver, summary_writer,
                                             model_dir, subdir, step)

                # Evaluate on LFW
    return model_dir
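
Note: Example #1 references `softmax_ind` and `_from_tensor_slices` without defining them. A minimal sketch of the assumed module-level definitions, inferred from how they are used (hypothetical, not the original repo's code; written against the modern tf.data alias):

import tensorflow as tf

softmax_ind = 0  # cursor into the shuffled `indices`, advanced per sampled batch

def _from_tensor_slices(tensors_x, tensors_y):
    # Unbatch the (paths, labels) pair returned by tf.py_func so the
    # downstream map/batch stages see one (path, label) element at a time.
    return tf.data.Dataset.from_tensor_slices((tensors_x, tensors_y))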
Example #2
def main(args):

    #network = importlib.import_module(args.model_def)

    subdir = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S')
    log_dir = os.path.join(os.path.expanduser(args.logs_base_dir), subdir)
    if not os.path.isdir(
            log_dir):  # Create the log directory if it doesn't exist
        os.makedirs(log_dir)
    model_dir = os.path.join(os.path.expanduser(args.models_base_dir), subdir)
    if not os.path.isdir(
            model_dir):  # Create the model directory if it doesn't exist
        os.makedirs(model_dir)

    # Write arguments to a text file
    utils.write_arguments_to_file(args, os.path.join(log_dir, 'arguments.txt'))

    # Store some git revision info in a text file in the log directory
    src_path, _ = os.path.split(os.path.realpath(__file__))
    utils.store_revision_info(src_path, log_dir, ' '.join(sys.argv))

    np.random.seed(seed=args.seed)

    train_set = utils.get_dataset(args.data_dir)
    nrof_classes = len(train_set)
    print('nrof_classes: ', nrof_classes)
    image_list, label_list = utils.get_image_paths_and_labels(train_set)
    image_list = np.array(image_list)
    print('total images: {}'.format(len(image_list)))
    label_list = np.array(label_list, dtype=np.int32)

    dataset_size = len(image_list)
    data_reader = DataGenerator(image_list, label_list, args.batch_size)

    print('Model directory: %s' % model_dir)
    print('Log directory: %s' % log_dir)
    if args.pretrained_model:
        print('Pre-trained model: %s' %
              os.path.expanduser(args.pretrained_model))

    with tf.Graph().as_default():
        tf.set_random_seed(args.seed)
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Placeholder for the learning rate
        learning_rate_placeholder = tf.placeholder(tf.float32,
                                                   name='learning_rate')
        images_placeholder = tf.placeholder(tf.float32, [None, 112, 96, 3],
                                            name='images_placeholder')
        labels_placeholder = tf.placeholder(tf.int32, [None],
                                            name='labels_placeholder')

        phase_train_placeholder = tf.placeholder(tf.bool, name='phase_train')

        learning_rate = tf.train.exponential_decay(
            learning_rate_placeholder,
            global_step,
            args.learning_rate_decay_epochs * args.epoch_size,
            args.learning_rate_decay_factor,
            staircase=True)
        tf.summary.scalar('learning_rate', learning_rate)

        print('Using optimizer: {}'.format(args.optimizer))
        if args.optimizer == 'ADAGRAD':
            opt = tf.train.AdagradOptimizer(learning_rate)
        elif args.optimizer == 'MOM':
            opt = tf.train.MomentumOptimizer(learning_rate, 0.9)
        else:
            raise ValueError('Unsupported optimizer: {}'.format(args.optimizer))

        if args.network == 'sphere_network':
            prelogits = network.infer(images_placeholder)
        else:
            raise Exception('Unsupported network: {}'.format(args.network))

        if args.loss_type == 'softmax':
            cross_entropy_mean = utils.softmax_loss(prelogits,
                                                    labels_placeholder,
                                                    len(train_set),
                                                    args.weight_decay, False)
            regularization_losses = tf.get_collection(
                tf.GraphKeys.REGULARIZATION_LOSSES)
            #loss = cross_entropy_mean + args.weight_decay*tf.add_n(regularization_losses)
            loss = cross_entropy_mean + args.weight_decay * tf.add_n(
                regularization_losses)
            #loss = cross_entropy_mean
        else:
            raise Exception('Unsupported loss type: {}'.format(args.loss_type))

        #loss = tf.add_n([cross_entropy_mean] + regularization_losses, name='total_loss')
        losses = {}
        losses['total_loss'] = loss
        losses['softmax_loss'] = cross_entropy_mean
        debug_info = {}
        debug_info['prelogits'] = prelogits

        grads = opt.compute_gradients(loss, tf.trainable_variables())
        train_op = opt.apply_gradients(grads, global_step=global_step)

        #save_vars = [var for var in tf.global_variables() if 'Adagrad' not in var.name and 'global_step' not in var.name]
        save_vars = tf.global_variables()

        #saver = tf.train.Saver(tf.trainable_variables(), max_to_keep=3)
        saver = tf.train.Saver(save_vars, max_to_keep=3)

        # Build the summary operation based on the TF collection of Summaries.

        # Start running operations on the Graph.
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=args.gpu_memory_fraction)
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
                                                allow_soft_placement=True))

        # Initialize variables
        sess.run(tf.global_variables_initializer(),
                 feed_dict={phase_train_placeholder: True})
        sess.run(tf.local_variables_initializer(),
                 feed_dict={phase_train_placeholder: True})

        with sess.as_default():
            #pdb.set_trace()

            if args.pretrained_model:
                print('Restoring pretrained model: %s' % args.pretrained_model)
                saver.restore(sess, os.path.expanduser(args.pretrained_model))

            # Training and validation loop
            epoch = 0
            while epoch < args.max_nrof_epochs:
                step = sess.run(global_step, feed_dict=None)
                epoch = step // args.epoch_size

                # Train for one epoch
                train(args, sess, epoch, images_placeholder,
                      labels_placeholder, data_reader, debug,
                      learning_rate_placeholder, global_step, losses, train_op,
                      args.learning_rate_schedule_file)

                # Save a checkpoint for this epoch (the metagraph is not written)
                model_dir = args.models_base_dir
                checkpoint_path = os.path.join(model_dir,
                                               'model-%s.ckpt' % 'softmax')
                saver.save(sess,
                           checkpoint_path,
                           global_step=step,
                           write_meta_graph=False)

                # Evaluate on LFW
    return model_dir
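
Note: Example #2 feeds batches through placeholders instead of tf.data; its `train` function and the `DataGenerator` class are not shown. A minimal sketch of one feed-dict training step under that scheme (hypothetical; `next_batch` is an assumed DataGenerator method):

def train_one_step(sess, data_reader, images_placeholder, labels_placeholder,
                   learning_rate_placeholder, lr, train_op, losses):
    # Pull one (images, labels) batch and run a single optimization step.
    images, labels = data_reader.next_batch()  # assumed DataGenerator API
    _, total_loss = sess.run(
        [train_op, losses['total_loss']],
        feed_dict={images_placeholder: images,
                   labels_placeholder: labels,
                   learning_rate_placeholder: lr})
    return total_loss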
Example #3
def main(args):

    src_path, _ = os.path.split(os.path.realpath(__file__))

    # Create result directory
    res_name = utils.gettime()
    res_dir = os.path.join(src_path, 'results', res_name)
    os.makedirs(res_dir, exist_ok=True)

    log_filename = os.path.join(res_dir, 'log.h5')
    model_filename = os.path.join(res_dir, res_name)

    # Store some git revision info in a text file in the log directory
    utils.store_revision_info(src_path, res_dir, ' '.join(sys.argv))

    # Store parameters in an HDF5 file
    utils.store_hdf(os.path.join(res_dir, 'parameters.h5'), vars(args))

    # Copy learning rate schedule file to result directory
    learning_rate_schedule = utils.copy_learning_rate_schedule_file(
        args.learning_rate_schedule, res_dir)

    with tf.Session() as sess:

        tf.set_random_seed(args.seed)
        np.random.seed(args.seed)

        filelist = ['train_%03d.pkl' % i for i in range(200)]
        dataset = create_dataset(filelist,
                                 args.data_dir,
                                 buffer_size=20000,
                                 batch_size=args.batch_size,
                                 total_seq_length=args.nrof_init_time_steps +
                                 args.seq_length)

        # Create an iterator over the dataset
        iterator = dataset.make_one_shot_iterator()
        obs, action = iterator.get_next()

        is_pdt_ph = tf.placeholder(tf.bool, [None, args.seq_length])
        is_pdt = create_transition_type_matrix(args.batch_size,
                                               args.seq_length,
                                               args.training_scheme)

        with tf.variable_scope('env_model'):
            env_model = EnvModel(is_pdt_ph,
                                 obs,
                                 action,
                                 1,
                                 model_type=args.model_type,
                                 nrof_time_steps=args.seq_length,
                                 nrof_free_nats=args.nrof_free_nats)

        reg_loss = tf.reduce_mean(env_model.regularization_loss)
        rec_loss = tf.reduce_mean(env_model.reconstruction_loss)
        loss = reg_loss + rec_loss

        global_step = tf.Variable(0, name='global_step', trainable=False)
        learning_rate_ph = tf.placeholder(tf.float32, ())
        train_op = tf.train.AdamOptimizer(learning_rate_ph).minimize(
            loss, global_step=global_step)

        saver = tf.train.Saver()

        sess.run(tf.global_variables_initializer())

        stat = {
            'loss': np.zeros((args.max_nrof_steps, ), np.float32),
            'rec_loss': np.zeros((args.max_nrof_steps, ), np.float32),
            'reg_loss': np.zeros((args.max_nrof_steps, ), np.float32),
            'learning_rate': np.zeros((args.max_nrof_steps, ), np.float32),
        }

        try:
            print('Started training')
            rec_loss_tot, reg_loss_tot, loss_tot = (0.0, 0.0, 0.0)
            lr = None
            t = time.time()
            for i in range(1, args.max_nrof_steps + 1):
                if not lr or i % 100 == 0:
                    lr = utils.get_learning_rate_from_file(
                        learning_rate_schedule, i)
                    if lr < 0:
                        break
                stat['learning_rate'][i - 1] = lr
                _, rec_loss_, reg_loss_, loss_ = sess.run(
                    [train_op, rec_loss, reg_loss, loss],
                    feed_dict={
                        is_pdt_ph: is_pdt,
                        learning_rate_ph: lr
                    })
                stat['loss'][i - 1], stat['rec_loss'][i - 1], stat['reg_loss'][
                    i - 1] = loss_, rec_loss_, reg_loss_
                rec_loss_tot += rec_loss_
                reg_loss_tot += reg_loss_
                loss_tot += loss_
                if i % 10 == 0:
                    print(
                        'step: %-5d  time: %-12.3f  lr: %-12.6f  rec_loss: %-12.1f  reg_loss: %-12.1f  loss: %-12.1f'
                        % (i, time.time() - t, lr, rec_loss_tot / 10,
                           reg_loss_tot / 10, loss_tot / 10))
                    rec_loss_tot, reg_loss_tot, loss_tot = (0.0, 0.0, 0.0)
                    t = time.time()
                if i % 5000 == 0 and i > 0:
                    saver.save(sess, model_filename, i)
                if i % 100 == 0:
                    utils.store_hdf(log_filename, stat)

        except tf.errors.OutOfRangeError:
            pass

        print("Saving model...")
        saver.save(sess, model_filename, i)

        print('Done!')
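
Note: Example #3 polls `utils.get_learning_rate_from_file` every 100 steps and stops when the returned rate is negative. A plausible minimal implementation, assuming the common "boundary: rate" schedule-file convention (the repo's actual helper is not shown):

def get_learning_rate_from_file(filename, step):
    # Each non-comment line is "step_boundary: learning_rate"; the rate of the
    # first boundary >= step applies. A negative rate tells the caller to stop.
    with open(filename, 'r') as f:
        for line in f:
            line = line.split('#', 1)[0].strip()
            if not line:
                continue
            boundary, rate = line.split(':')
            if step <= int(boundary):
                return float(rate)
    return -1.0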
Example #4
def main(args):
    network = importlib.import_module(args.model_def)
    image_size = (args.image_size, args.image_size)

    subdir = datetime.strftime(datetime.now(), '%Y%m%d')
    log_dir = os.path.join(os.path.expanduser(args.logs_base_dir), subdir)
    if not os.path.isdir(log_dir):  # Create the log directory if it doesn't exist
        os.makedirs(log_dir)
    model_dir = os.path.join(os.path.expanduser(args.models_base_dir), subdir)
    if not os.path.isdir(model_dir):  # Create the model directory if it doesn't exist
        os.makedirs(model_dir)

    stat_file_name = os.path.join(log_dir, 'stat.h5')

    # Write arguments to a text file
    utils.write_arguments_to_file(args, os.path.join(log_dir, 'arguments.txt'))

    # Store some git revision info in a text file in the log directory
    src_path, _ = os.path.split(os.path.realpath(__file__))
    utils.store_revision_info(src_path, log_dir, ' '.join(sys.argv))

    np.random.seed(seed=args.seed)
    random.seed(args.seed)
    dataset = utils.get_dataset(args.data_dir)
    # print(dataset[1].image_paths)
    if args.filter_filename:
        dataset = filter_dataset(dataset, os.path.expanduser(args.filter_filename),
                                 args.filter_percentile, args.filter_min_nrof_images_per_class)

    if args.validation_set_split_ratio > 0.0:
        train_set, val_set = utils.split_dataset(dataset, args.validation_set_split_ratio,
                                                   args.min_nrof_val_images_per_class, 'SPLIT_IMAGES')
    else:
        train_set, val_set = dataset, []

    nrof_classes = len(train_set)

    print('Model directory: %s' % model_dir)
    print('Log directory: %s' % log_dir)
    pretrained_model = None
    if args.pretrained_model:
        pretrained_model = os.path.expanduser(args.pretrained_model)
        print('Pre-trained model: %s' % pretrained_model)

    with tf.Graph().as_default():
        tf.set_random_seed(args.seed)
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Get a list of image paths and their labels
        image_list, label_list = utils.get_image_paths_and_labels(train_set)
        assert len(image_list) > 0, 'The training set should not be empty'

        val_image_list, val_label_list = utils.get_image_paths_and_labels(val_set)

        learning_rate_placeholder = tf.placeholder(tf.float32, name='learning_rate')
        batch_size_placeholder = tf.placeholder(tf.int32, name='batch_size')
        phase_train_placeholder = tf.placeholder(tf.bool, name='phase_train')
        image_paths_placeholder = tf.placeholder(tf.string, shape=(None, 1), name='image_paths')
        labels_placeholder = tf.placeholder(tf.int32, shape=(None, 1), name='labels')
        control_placeholder = tf.placeholder(tf.int32, shape=(None, 1), name='control')

        image_batch_plh = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='image_batch_p')
        label_batch_plh = tf.placeholder(tf.int32, name='label_batch_p')

        print('Number of classes in training set: %d' % nrof_classes)
        print('Number of examples in training set: %d' % len(image_list))

        print('Number of classes in validation set: %d' % len(val_set))
        print('Number of examples in validation set: %d' % len(val_image_list))

        print('Building training graph')

        # Build the inference graph
        # prelogits, _ = efficientnet_builder.build_model_base(image_batch_plh, 'efficientnet-b2', training=True)
        prelogits, _ = network.inference(image_batch_plh, args.keep_probability, image_size,
                                       phase_train=phase_train_placeholder, bottleneck_layer_size=args.embedding_size,
                                         weight_decay=args.weight_decay)
        logits = slim.fully_connected(prelogits, len(train_set), activation_fn=None,
                                      weights_initializer=slim.initializers.xavier_initializer(),
                                      weights_regularizer=slim.l2_regularizer(args.weight_decay),
                                      scope='Logits', reuse=False)

        embeddings = tf.nn.l2_normalize(prelogits, 1, 1e-10, name='embeddings')

        # Norm for the prelogits
        eps = 1e-4
        prelogits_norm = tf.reduce_mean(tf.norm(tf.abs(prelogits) + eps, ord=args.prelogits_norm_p, axis=1))
        tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, prelogits_norm * args.prelogits_norm_loss_factor)

        # Add center loss
        prelogits_center_loss, _ = utils.center_loss(prelogits, label_batch_plh, args.center_loss_alfa, nrof_classes)
        tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, prelogits_center_loss * args.center_loss_factor)

        learning_rate = tf.train.exponential_decay(learning_rate_placeholder, global_step,
                                                   args.learning_rate_decay_epochs * args.epoch_size,
                                                   args.learning_rate_decay_factor, staircase=True)
        tf.summary.scalar('learning_rate', learning_rate)

        # Calculate the average cross entropy loss across the batch
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=label_batch_plh, logits=logits, name='cross_entropy_per_example')
        cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
        tf.add_to_collection('losses', cross_entropy_mean)

        correct_prediction = tf.cast(tf.equal(tf.argmax(logits, 1), tf.cast(label_batch_plh, tf.int64)), tf.float32)
        accuracy = tf.reduce_mean(correct_prediction, name='accuracy')

        # Calculate the total losses
        regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        total_loss = tf.add_n([cross_entropy_mean] + regularization_losses, name='total_loss')

        # Separate the facenet variables from the smaug ones
        facenet_global_vars = tf.global_variables()

        # Build a Graph that trains the model with one batch of examples and updates the model parameters
        train_op = utils.train(total_loss, global_step, args.optimizer,
                                 learning_rate, args.moving_average_decay, facenet_global_vars, args.log_histograms)

        # Create a saver
        facenet_saver_vars = tf.trainable_variables()
        facenet_saver_vars.append(global_step)
        saver = tf.train.Saver(facenet_saver_vars, max_to_keep=10)

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.summary.merge_all()

        # Create session
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)

        # Create normal pipeline
        dataset_train = LabeledImageData(image_list, label_list, sess, batch_size=args.batch_size, shuffle=True,
                                         use_flip=True, use_black_patches=True, use_crop=True)
        dataset_val = LabeledImageDataRaw(val_image_list, val_label_list, sess, batch_size=args.val_batch_size,
                                          shuffle=False)

        # Start running operations on the Graph. Change to tf.compat in newer versions of tf.
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

        with sess.as_default():
            if pretrained_model:
                print('Restoring pretrained model: %s' % pretrained_model)
                ckpt_dir_or_file = tf.train.latest_checkpoint(pretrained_model)
                saver.restore(sess, ckpt_dir_or_file)

            # Training and validation loop
            print('Running training')
            nrof_steps = args.max_nrof_epochs * args.epoch_size
            nrof_val_samples = int(math.ceil(
                args.max_nrof_epochs / args.validate_every_n_epochs))  # Validate every validate_every_n_epochs as well as in the last epoch
            stat = {
                'loss': np.zeros((nrof_steps,), np.float32),
                'center_loss': np.zeros((nrof_steps,), np.float32),
                'reg_loss': np.zeros((nrof_steps,), np.float32),
                'xent_loss': np.zeros((nrof_steps,), np.float32),
                'prelogits_norm': np.zeros((nrof_steps,), np.float32),
                'accuracy': np.zeros((nrof_steps,), np.float32),
                'val_loss': np.zeros((nrof_val_samples,), np.float32),
                'val_xent_loss': np.zeros((nrof_val_samples,), np.float32),
                'val_accuracy': np.zeros((nrof_val_samples,), np.float32),
                'lfw_accuracy': np.zeros((args.max_nrof_epochs,), np.float32),
                'lfw_valrate': np.zeros((args.max_nrof_epochs,), np.float32),
                'learning_rate': np.zeros((args.max_nrof_epochs,), np.float32),
                'time_train': np.zeros((args.max_nrof_epochs,), np.float32),
                'time_validate': np.zeros((args.max_nrof_epochs,), np.float32),
                'time_evaluate': np.zeros((args.max_nrof_epochs,), np.float32),
                'prelogits_hist': np.zeros((args.max_nrof_epochs, 1000), np.float32),
                'smaug_alpha_loss': np.zeros((nrof_steps,), np.float32),
                'smaug_total_loss': np.zeros((nrof_steps,), np.float32)
            }
            global_step_ = sess.run(global_step)
            start_epoch = 1 + global_step_ // args.epoch_size
            batch_number = global_step_ % args.epoch_size
            biggest_acc = 0.0
            for epoch in range(start_epoch, args.max_nrof_epochs + 1):
                step = sess.run(global_step, feed_dict=None)
                # Train for one epoch
                t = time.time()
                cont = train(args, sess, epoch, batch_number,
                             learning_rate_placeholder, phase_train_placeholder, batch_size_placeholder,
                             image_batch_plh, label_batch_plh, global_step,
                             total_loss, train_op, summary_op, summary_writer, regularization_losses,
                             args.learning_rate_schedule_file,
                             stat, cross_entropy_mean, accuracy, learning_rate, prelogits, prelogits_center_loss,
                             prelogits_norm, args.prelogits_hist_max, dataset_train)
                stat['time_train'][epoch - 1] = time.time() - t
                print("------------------Accuracy-----------------" + str(stat['val_accuracy']))
                if not cont:
                    break

                t = time.time()
                if len(val_image_list) > 0 and ((epoch - 1) % args.validate_every_n_epochs == args.validate_every_n_epochs - 1 or epoch == args.max_nrof_epochs):
                    validate(args, sess, epoch, val_label_list, phase_train_placeholder, batch_size_placeholder,
                             stat, total_loss, cross_entropy_mean, accuracy, args.validate_every_n_epochs,
                             image_batch_plh, label_batch_plh, dataset_val)
                stat['time_validate'][epoch - 1] = time.time() - t

                cur_val_acc = get_val_acc(epoch, stat, args.validate_every_n_epochs)

                # Save variables and the metagraph if it doesn't exist already
                save_variables_and_metagraph(sess, saver, summary_writer, model_dir, subdir, epoch, args.save_every,
                                             cur_val_acc, biggest_acc, args.save_best)

                biggest_acc = update_biggest_acc(biggest_acc, cur_val_acc)

                print('Saving statistics')
                with h5py.File(stat_file_name, 'w') as f:
                    for key, value in stat.items():
                        f.create_dataset(key, data=value)

    return model_dir
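
Note: Example #4 calls `get_val_acc` and `update_biggest_acc`, which are not shown. Minimal sketches consistent with how the `stat` arrays are indexed (assumptions, not the original helpers):

def get_val_acc(epoch, stat, validate_every_n_epochs):
    # stat['val_accuracy'] holds one entry per validation run, so the entry
    # for this epoch lives at (epoch - 1) // validate_every_n_epochs.
    return float(stat['val_accuracy'][(epoch - 1) // validate_every_n_epochs])

def update_biggest_acc(biggest_acc, cur_val_acc):
    # Track the best validation accuracy seen so far.
    return max(biggest_acc, cur_val_acc)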
Example #5
def main(args):

    src_path,_ = os.path.split(os.path.realpath(__file__))

    subdir = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S')
    res_dir = os.path.join(os.path.expanduser(args.output_base_dir), subdir)
    if not os.path.isdir(res_dir):  # Create the log directory if it doesn't exist
        os.makedirs(res_dir)
        
    # Store some git revision info in a text file in the log directory
    utils.store_revision_info(src_path, res_dir, ' '.join(sys.argv))
    
    # Store parameters in an HDF5 file
    utils.store_hdf(os.path.join(res_dir, 'parameters.h5'), vars(args))
    
    # Create statistics object
    stat_filename = os.path.join(res_dir, 'stat.h5')
    stat = utils.Stat(stat_filename)

    with tf.Graph().as_default():
        tf.compat.v1.random.set_random_seed(args.seed)
        np.random.seed(args.seed)
    
        ###########################################
        """             Load Data               """
        ###########################################
        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
        nrof_train_batches = int(np.ceil(x_train.shape[0] / args.batch_size))
        nrof_test_batches = int(np.ceil(x_test.shape[0] / args.batch_size))
        input_dims = (x_train.shape[1], x_train.shape[2], 1)
        train_iterator = create_dataset(x_train, y_train, args.batch_size)
        test_iterator = create_dataset(x_test, y_test, args.batch_size)
        xtrain, ytrain = train_iterator.get_next() #@UnusedVariable
        xtest, ytest = test_iterator.get_next() #@UnusedVariable
    
        ###########################################
        """        Build Model Graphs           """
        ###########################################
        with tf.compat.v1.variable_scope("vae"):
    
            warmup_temp = tf.compat.v1.placeholder(tf.float32, shape=(), name="warmup_temp")
    
            if args.model_type=='VAE':
                m = VAE(input_dims, args.learning_rate, warmup_temp, to_list(args.nrof_stochastic_units), to_list(args.nrof_mlp_units))
            elif args.model_type=='LVAE':
                m = LVAE(input_dims, args.learning_rate, warmup_temp, to_list(args.nrof_stochastic_units), to_list(args.nrof_mlp_units))
            else:
                raise ValueError('Invalid model type')
            print('Building train graph...')
            train_op, train_o, train_dbg = m.build_graph(xtrain, is_training=True)

            print('Building evaluation graph...')
            _, eval_o, eval_dbg = m.build_graph(xtest, is_training=False) #@UnusedVariable

        init_op = tf.compat.v1.global_variables_initializer()
    
        sess  = tf.compat.v1.InteractiveSession()
        sess.run(init_op)
        sess.run(train_iterator.initializer)
        sess.run(test_iterator.initializer)
    
        print('... start training')
        for epoch in range(1, args.nrof_epochs+1):
    
            # Get warm-up temperature
            temp = get_warmup_temp(epoch, args.nrof_warmup_epochs)
    
            o_list = []
            start_time = time.time()
            for _ in range(nrof_train_batches):
                feed_dict = {warmup_temp: temp}
                o, dbg, _ = sess.run([train_o, train_dbg, train_op], feed_dict=feed_dict) #@UnusedVariable
                o_list += [ flatten(o) ]
                
            o_mean = mean(o_list)
            stat.add(add_prefix('train_', o_mean))
                
    
            print(' epoch: %5d  time: %6.3f   temp: %10.3f  elbo: %10.3f   log p(x): %10.3f   log p(z): %8.3f | %8.3f  log q(z): %8.3f | %8.3f  KL(q(z|x)||p(z)): %8.3f | %8.3f' % \
                  (epoch, time.time()-start_time, temp, o_mean['elbo'], o_mean['log_px'], o_mean['log_pz_0'], o_mean['log_pz_1'], o_mean['log_qz_0'], o_mean['log_qz_1'], o_mean['kl_0'], o_mean['kl_1'] ))
            
            # Evaluate every n epochs
            if epoch % args.eval_every_n_epochs == 0:
                o_list = []
                start_time = time.time()
                for _ in range(nrof_test_batches):
                    feed_dict = {warmup_temp: 1.0}
                    o, dbg = sess.run([eval_o, eval_dbg], feed_dict=feed_dict) #@UnusedVariable
                    o_list += [ flatten(o) ]
        
                o_mean = mean(o_list)
                stat.add(add_prefix('eval_', o_mean))
                if args.display_eval:
                    print('*epoch: %5d  time: %6.3f   temp: %10.3f  elbo: %10.3f   log p(x): %10.3f   log p(z): %8.3f | %8.3f  log q(z): %8.3f | %8.3f  KL(q(z|x)||p(z)): %8.3f | %8.3f' % \
                          (epoch, time.time()-start_time, 1.0, o_mean['elbo'], o_mean['log_px'], o_mean['log_pz_0'], o_mean['log_pz_1'], o_mean['log_qz_0'], o_mean['log_qz_1'], o_mean['kl_0'], o_mean['kl_1'] ))

            # Store statistics
            stat.store()
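
Note: Example #5 anneals a warm-up temperature for the KL term but does not show `get_warmup_temp`. A common linear warm-up schedule it might implement (an assumption, not the original code):

def get_warmup_temp(epoch, nrof_warmup_epochs):
    # Linearly anneal the KL weight from ~0 up to 1.0 over the warm-up epochs.
    if nrof_warmup_epochs <= 0:
        return 1.0
    return min(1.0, epoch / float(nrof_warmup_epochs))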
Example #6
def main_train(args):

    subdir = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S')
    log_dir = os.path.join(os.path.expanduser(args.logs_base_dir), subdir)
    if not os.path.isdir(
            log_dir):  # Create the log directory if it doesn't exist
        os.makedirs(log_dir)
    model_dir = os.path.join(os.path.expanduser(args.models_base_dir), subdir)
    if not os.path.isdir(
            model_dir):  # Create the model directory if it doesn't exist
        os.makedirs(model_dir)
    # Write arguments to a text file
    utils.write_arguments_to_file(args, os.path.join(log_dir, 'arguments.txt'))

    # Store some git revision info in a text file in the log directory
    src_path, _ = os.path.split(os.path.realpath(__file__))
    utils.store_revision_info(src_path, log_dir, ' '.join(sys.argv))

    np.random.seed(seed=args.seed)
    train_set = utils.dataset_from_list(
        args.train_data_dir, args.train_list_dir)  # class objects in a list

    #----------------------class definition-------------------------------------
    '''
    class ImageClass():
    "Stores the paths to images for a given class"
    def __init__(self, name, image_paths):
        self.name = name
        self.image_paths = image_paths
  
    def __str__(self):
        return self.name + ', ' + str(len(self.image_paths)) + ' images'
  
    def __len__(self):
        return len(self.image_paths)
    '''

    nrof_classes = len(train_set)
    print('nrof_classes: ', nrof_classes)
    image_list, label_list = utils.get_image_paths_and_labels(train_set)
    print('total images: ', len(image_list))  # each label is a scalar class index
    image_list = np.array(image_list)
    label_list = np.array(label_list, dtype=np.int32)
    dataset_size = len(image_list)
    single_batch_size = args.class_per_batch * args.images_per_class
    indices = list(range(dataset_size))
    np.random.shuffle(indices)

    def _sample_people_softmax(x):  # load the next batch of images and labels
        global softmax_ind
        if softmax_ind >= dataset_size:
            np.random.shuffle(indices)
            softmax_ind = 0
        true_num_batch = min(single_batch_size, dataset_size - softmax_ind)

        sample_paths = image_list[indices[softmax_ind:softmax_ind +
                                          true_num_batch]]
        sample_images = []

        for item in sample_paths:
            sample_images.append(np.load(str(item)))
            #print(item)
        #print(type(sample_paths[0]))
        sample_labels = label_list[indices[softmax_ind:softmax_ind +
                                           true_num_batch]]
        softmax_ind += true_num_batch
        return (np.expand_dims(np.array(sample_images, dtype=np.float32),
                               axis=4), np.array(sample_labels,
                                                 dtype=np.int32))

    print('Model directory: %s' % model_dir)
    print('Log directory: %s' % log_dir)
    if args.pretrained_model:
        print('Pre-trained model: %s' %
              os.path.expanduser(args.pretrained_model))

    with tf.Graph().as_default():
        tf.set_random_seed(args.seed)
        global_step = tf.Variable(0, trainable=False, name='global_step')
        # Placeholder for the learning rate
        learning_rate_placeholder = tf.placeholder(tf.float32,
                                                   name='learning_rate')
        phase_train_placeholder = tf.placeholder(tf.bool, name='phase_train')
        # The image batches are produced by the tf.data pipeline below.
        with tf.device("/cpu:0"):

            softmax_dataset = tf.data.Dataset.range(args.epoch_size *
                                                    args.max_nrof_epochs)
            softmax_dataset = softmax_dataset.map(lambda x: tf.py_func(
                _sample_people_softmax, [x], [tf.float32, tf.int32]))
            softmax_dataset = softmax_dataset.flat_map(_from_tensor_slices)
            softmax_dataset = softmax_dataset.batch(single_batch_size)
            softmax_iterator = softmax_dataset.make_initializable_iterator()
            softmax_next_element = softmax_iterator.get_next()
            softmax_next_element[0].set_shape(
                (single_batch_size, args.image_height, args.image_width,
                 args.image_width, 1))
            softmax_next_element[1].set_shape(single_batch_size)
            batch_image_split = softmax_next_element[0]
            # batch_image_split = tf.expand_dims(batch_image_split, axis = 4)
            batch_label_split = softmax_next_element[1]

        learning_rate = tf.train.exponential_decay(
            learning_rate_placeholder,
            global_step,
            args.learning_rate_decay_epochs * args.epoch_size,
            args.learning_rate_decay_factor,
            staircase=True)
        tf.summary.scalar('learning_rate', learning_rate)

        print('Using optimizer: {}'.format(args.optimizer))
        if args.optimizer == 'ADAGRAD':
            opt = tf.train.AdagradOptimizer(learning_rate)
        elif args.optimizer == 'SGD':
            opt = tf.train.GradientDescentOptimizer(learning_rate)
        elif args.optimizer == 'MOM':
            opt = tf.train.MomentumOptimizer(learning_rate, 0.9)
        elif args.optimizer == 'ADAM':
            opt = tf.train.AdamOptimizer(learning_rate,
                                         beta1=0.9,
                                         beta2=0.999,
                                         epsilon=0.1)
        else:
            raise Exception("Not supported optimizer: {}".format(
                args.optimizer))

        losses = {}
        with slim.arg_scope([slim.model_variable, slim.variable],
                            device="/cpu:0"):
            with tf.variable_scope(tf.get_variable_scope()) as var_scope:
                reuse = False

                if args.network == 'sphere_network':

                    prelogits = network.infer(batch_image_split,
                                              args.embedding_size)
                else:
                    raise Exception("Not supported network: {}".format(
                        args.network))

                if args.fc_bn:
                    prelogits = slim.batch_norm(prelogits, is_training=True, decay=0.997,epsilon=1e-5,scale=True,\
                        updates_collections=tf.GraphKeys.UPDATE_OPS,reuse=reuse,scope='softmax_bn')

                if args.loss_type == 'softmax':
                    cross_entropy_mean = utils.softmax_loss(
                        prelogits, batch_label_split, len(train_set), 1.0,
                        reuse)
                    regularization_losses = tf.get_collection(
                        tf.GraphKeys.REGULARIZATION_LOSSES)
                    loss = cross_entropy_mean + args.weight_decay * tf.add_n(
                        regularization_losses)
                    print('************************' +
                          ' Computing the softmax loss')
                    losses['total_loss'] = cross_entropy_mean
                    losses['total_reg'] = args.weight_decay * tf.add_n(
                        regularization_losses)

                elif args.loss_type == 'lmcl':
                    label_reshape = tf.reshape(batch_label_split,
                                               [single_batch_size])
                    label_reshape = tf.cast(label_reshape, tf.int64)
                    coco_loss = utils.cos_loss(prelogits,
                                               label_reshape,
                                               len(train_set),
                                               reuse,
                                               alpha=args.alpha,
                                               scale=args.scale)
                    regularization_losses = tf.get_collection(
                        tf.GraphKeys.REGULARIZATION_LOSSES)
                    loss = coco_loss + args.weight_decay * tf.add_n(
                        regularization_losses)
                    print('************************' +
                          ' Computing the lmcl loss')
                    losses['total_loss'] = coco_loss
                    losses['total_reg'] = args.weight_decay * tf.add_n(
                        regularization_losses)

                elif args.loss_type == 'center':
                    # center loss; reshape the labels first (label_reshape is
                    # otherwise undefined in this branch)
                    label_reshape = tf.reshape(batch_label_split,
                                               [single_batch_size])
                    label_reshape = tf.cast(label_reshape, tf.int64)
                    center_loss, centers, centers_update_op = get_center_loss(prelogits, label_reshape, args.center_loss_alfa, \
                        args.num_class_train)
                    regularization_losses = tf.get_collection(
                        tf.GraphKeys.REGULARIZATION_LOSSES)
                    loss = center_loss + args.weight_decay * tf.add_n(
                        regularization_losses)
                    print('************************' +
                          ' Computing the center loss')
                    losses['total_loss'] = center_loss
                    losses['total_reg'] = args.weight_decay * tf.add_n(
                        regularization_losses)

                elif args.loss_type == 'lmccl':
                    cross_entropy_mean = utils.softmax_loss(
                        prelogits, batch_label_split, len(train_set), 1.0,
                        reuse)
                    label_reshape = tf.reshape(batch_label_split,
                                               [single_batch_size])
                    label_reshape = tf.cast(label_reshape, tf.int64)
                    coco_loss = utils.cos_loss(prelogits,
                                               label_reshape,
                                               len(train_set),
                                               reuse,
                                               alpha=args.alpha,
                                               scale=args.scale)
                    center_loss, centers, centers_update_op = get_center_loss(prelogits, label_reshape, args.center_loss_alfa, \
                        args.num_class_train)
                    regularization_losses = tf.get_collection(
                        tf.GraphKeys.REGULARIZATION_LOSSES)
                    reg_loss = args.weight_decay * tf.add_n(
                        regularization_losses)
                    loss = coco_loss + reg_loss + args.center_weighting * center_loss + cross_entropy_mean
                    losses['total_loss_center'] = args.center_weighting * center_loss
                    losses['total_loss_lmcl'] = coco_loss
                    losses['total_loss_softmax'] = cross_entropy_mean
                    losses['total_reg'] = reg_loss

        grads = opt.compute_gradients(loss,
                                      tf.trainable_variables(),
                                      colocate_gradients_with_ops=True)
        apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

        # used for updating the centers in the center loss.
        if args.loss_type == 'lmccl' or args.loss_type == 'center':
            with tf.control_dependencies([centers_update_op]):
                with tf.control_dependencies(update_ops):
                    train_op = tf.group(apply_gradient_op)
        else:
            with tf.control_dependencies(update_ops):
                train_op = tf.group(apply_gradient_op)

        save_vars = [
            var for var in tf.global_variables()
            if 'Adagrad' not in var.name and 'global_step' not in var.name
        ]
        saver = tf.train.Saver(save_vars, max_to_keep=3)

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.summary.merge_all()
        # Start running operations on the Graph.
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=args.gpu_memory_fraction)
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
                                                allow_soft_placement=True))

        # Initialize variables
        sess.run(tf.global_variables_initializer(),
                 feed_dict={phase_train_placeholder: True})
        sess.run(tf.local_variables_initializer(),
                 feed_dict={phase_train_placeholder: True})

        #sess.run(iterator.initializer)
        sess.run(softmax_iterator.initializer)
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(coord=coord, sess=sess)

        with sess.as_default():
            if args.pretrained_model:
                print('Restoring pretrained model: %s' % args.pretrained_model)
                saver.restore(sess, os.path.expanduser(args.pretrained_model))

            # Training and validation loop
            epoch = 0
            while epoch < args.max_nrof_epochs:
                step = sess.run(global_step, feed_dict=None)
                epoch = step // args.epoch_size
                if debug:
                    debug_train(args, sess, train_set, epoch, image_batch_gather,\
                     enqueue_op,batch_size_placeholder, image_batch_split,image_paths_split,num_per_class_split,
                            image_paths_placeholder,image_paths_split_placeholder, labels_placeholder, labels_batch,\
                             num_per_class_placeholder,num_per_class_split_placeholder,len(gpus))
                # Train for one epoch
                if args.loss_type == 'lmccl' or args.loss_type == 'center':
                    train_contain_center(args, sess, epoch,
                                         learning_rate_placeholder,
                                         phase_train_placeholder, global_step,
                                         losses, train_op, summary_op,
                                         summary_writer, '', centers_update_op)
                else:
                    train(args, sess, epoch, learning_rate_placeholder,
                          phase_train_placeholder, global_step, losses,
                          train_op, summary_op, summary_writer, '')
                # Save variables and the metagraph if it doesn't exist already
                save_variables_and_metagraph(sess, saver, summary_writer,
                                             model_dir, subdir, step)
    return model_dir
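
Note: Example #6 relies on `get_center_loss` for the 'center' and 'lmccl' losses but does not show it. A minimal sketch following the widely used FaceNet-style center loss (Wen et al., ECCV 2016); the repo's actual helper may differ:

import tensorflow as tf

def get_center_loss(features, labels, alfa, nrof_classes):
    # Pull each embedding toward its class center; centers live in a
    # non-trainable variable updated by a moving average controlled by alfa.
    nrof_features = features.get_shape()[1]
    centers = tf.get_variable('centers', [nrof_classes, nrof_features],
                              dtype=tf.float32,
                              initializer=tf.constant_initializer(0),
                              trainable=False)
    labels = tf.reshape(labels, [-1])
    centers_batch = tf.gather(centers, labels)
    diff = (1 - alfa) * (centers_batch - features)
    centers_update_op = tf.scatter_sub(centers, labels, diff)
    loss = tf.reduce_mean(tf.square(features - centers_batch))
    return loss, centers, centers_update_op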