Example #1
def main():
    with tf.Graph().as_default():
        args = parser.parse_args()  
        batch_size = args.batch_size 
        content_image_filenames = list(absoluteFilePaths(args.content_image_dir))
        style_image_filenames = list(absoluteFilePaths(args.style_image_dir))
        
        content_dataset = tf.data.Dataset.from_tensor_slices(tf.constant(content_image_filenames))
        content_dataset = content_dataset.map(read_image, num_parallel_calls=4)
        content_dataset = content_dataset.map(resize_content_image, num_parallel_calls=4)
        content_dataset = content_dataset.shuffle(1000) 
        content_dataset = content_dataset.batch(batch_size)
        content_dataset = content_dataset.prefetch(1)
        content_iterator = content_dataset.make_one_shot_iterator()
        content_batch = content_iterator.get_next()
        
        style_dataset = tf.data.Dataset.from_tensor_slices(style_image_filenames)
        style_dataset = style_dataset.map(read_image, num_parallel_calls=4)
        style_dataset = style_dataset.map(augment_image, num_parallel_calls=4)
        style_dataset = style_dataset.shuffle(1000) 
        style_dataset = style_dataset.batch(batch_size)
        style_dataset = style_dataset.prefetch(1)
        style_iterator = style_dataset.make_one_shot_iterator()
        style_batch = style_iterator.get_next()

        with slim.arg_scope(mobilenet_v2.training_scope(is_training=False)):
            with tf.name_scope("content_endpoints"):
                _, content_endpoints = mobilenet_v2.mobilenet(tf.image.resize_images(content_batch, [224, 224]))
            with tf.name_scope("style_input_endpoints"):
                _, style_input_endpoints = mobilenet_v2.mobilenet(tf.image.resize_images(style_batch, [224, 224]))
            
            style_params = model.style_prediction_network(style_batch, style_input_endpoints["layer_18/output"])
            stylized_image = model.style_transformer_network(content_batch, style_params)

            with tf.name_scope("stylized_image_endpoints"):
                _, stylized_image_endpoints = mobilenet_v2.mobilenet(tf.image.resize_images(stylized_image, [224, 224]))
        loss = losses.total_loss(CONTENT_WEIGHT, content_batch, STYLE_WEIGHT, style_batch, stylized_image, TV_WEIGHT) 
        
        ema = tf.train.ExponentialMovingAverage(0.999)
        ema_vars = ema.variables_to_restore()
        saver = tf.train.Saver(ema_vars)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer()) 
            saver.restore(sess, args.mobile_net)
            loss = sess.run(loss)
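Example #1 relies on a few helper functions (absoluteFilePaths, read_image, resize_content_image) that are defined elsewhere in its repository. A minimal sketch of what they might look like, assuming JPEG inputs decoded to float32 RGB in [0, 1] and an illustrative 256x256 content resolution:

import os
import tensorflow as tf

def absoluteFilePaths(directory):
    # Yield the absolute path of every file found below `directory`.
    for root, _, filenames in os.walk(directory):
        for filename in filenames:
            yield os.path.abspath(os.path.join(root, filename))

def read_image(filename):
    # Read and decode one image file into a float32 RGB tensor in [0, 1].
    image = tf.image.decode_jpeg(tf.read_file(filename), channels=3)
    return tf.image.convert_image_dtype(image, tf.float32)

def resize_content_image(image):
    # Resize the content image to a fixed square resolution (assumed 256x256 here).
    return tf.image.resize_images(image, [256, 256])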
Example #2
    def train_loop(self):
        with tf.GradientTape() as tape:
            pred = self.model(self.feature_tensor)

            # get the embedding of the nodes
            Z = self.model.getZ()
            self.Z_np = Z.numpy()

            self.X2_np = self.model.getX2().numpy()

            # calculate the loss
            loss = total_loss(self.y_actual, pred, self.F_tensor,
                              self.S_tensor, Z, self.gamma, self.eta)

            # get the gradients
            grad = tape.gradient(loss, self.model.trainable_variables)

        # update the weights of the model using the previously computed gradients
        self.optimizer.apply_gradients(
            zip(grad, self.model.trainable_variables))

        return pred[0], pred[1]
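train_loop here is a method on a trainer object that already holds the model, the optimizer, and the pre-built tensors (self.feature_tensor, self.y_actual, self.F_tensor, self.S_tensor). A hypothetical driver loop, with an illustrative epoch count and logging interval, might look like this:

# `trainer` is assumed to be an initialized instance of the class defining train_loop.
for epoch in range(200):
    pred_a, pred_b = trainer.train_loop()
    if epoch % 20 == 0:
        # Z_np caches the current node embedding as a NumPy array after each step.
        print("epoch", epoch, "embedding shape:", trainer.Z_np.shape)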
Example #3
def train():
    """Training"""
    opt = FLAGS
    
    tf.logging.info("Build CleanNet...")
    batch_size = opt.batch_size_sup + opt.batch_size_unsup
    model = CleanNet(opt.num_ref, opt.img_dim, opt.embed_norm, opt.dropout_rate, opt.weight_decay)

    # phi_s: class embedding (batch_size, embed_size)
    # v_q: query image feature (batch_size, img_dim)
    # phi_q: query embedding (batch_size, embed_size)
    # v_qr: reconstructed query image feature (batch_size, img_dim)
    phi_s, v_q, phi_q, v_qr = model.forward(is_training=True)
    
    # verification labels
    vlabel = tf.placeholder(tf.float32, shape=(None,), name="vlabel")
    
    # verification flags indicating a sample is for supervised(1) or unsupervised(0) training
    vflag = tf.placeholder(tf.float32, shape=(None,), name="vflag")
    
    cos_sim = similarity(phi_s, phi_q)

    acc = accuracy(vlabel[:opt.batch_size_sup], cos_sim[:opt.batch_size_sup], threshold=0.1, scope="train_acc")
    val_acc = accuracy(vlabel, cos_sim, threshold=opt.val_sim_thres, scope="val_acc_at_{}".format(opt.val_sim_thres))
    tf.summary.scalar('train/accuracy', acc)
    
    objective_loss = tf.reduce_mean(total_loss(vlabel, cos_sim, phi_s, v_q, phi_q, v_qr, vflag, opt.neg_weight, beta=0.1, gamma=0.1))
    tf.summary.scalar('train/objective_loss', objective_loss)
    regularization_loss = tf.reduce_sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
    tf.summary.scalar('train/regularization_loss', regularization_loss)
    loss = objective_loss + regularization_loss
    tf.summary.scalar('train/loss', loss)

    lr = tf.train.exponential_decay(opt.learning_rate, model.global_step, opt.lr_update, opt.lr_decay, staircase=True)
    tf.summary.scalar('train/lr', lr)
    merged = tf.summary.merge_all()

    optimizer = tf.train.MomentumOptimizer(lr, opt.momentum)
    train_op = optimizer.minimize(loss, global_step=model.global_step)

    tf.logging.info("Get data batcher...")
    supervised_data = data_provider_factory.get_data_batcher('trainval', 'train', opt)
    val_data = data_provider_factory.get_data_batcher('trainval', 'val', opt)
    if opt.batch_size_unsup > 0:
        unsupervised_data = data_provider_factory.get_data_batcher('trainval', 'unverified', opt)

    saver = tf.train.Saver()
    init_op = tf.global_variables_initializer()
    
    with tf.Session() as sess:
        
        train_summary_writer = tf.summary.FileWriter(opt.log_dir + '/train', sess.graph)
        val_summary_writer = tf.summary.FileWriter(opt.log_dir + '/val')

        cur_step = 0
        best_avg_val_acc = 0.0
        sess.run(init_op)

        # recover from latest checkpoint and run validation if available
        ckpt = tf.train.get_checkpoint_state(opt.checkpoint_dir)
        if ckpt:
            saver.restore(sess, ckpt.model_checkpoint_path)
            saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths)
            cur_step, avg_val_acc = validation(sess, model, loss, val_acc, vlabel, vflag, opt.val_batch_size, val_data, val_summary_writer)
            best_avg_val_acc = avg_val_acc
            tf.logging.info("Recover model at global step = %d.", cur_step)
        else:
            tf.logging.info("Training from scratch.")

        while cur_step < opt.n_step:
            # data for supervised training
            _, batch_vlabel, batch_q, batch_vflag, batch_ref = supervised_data.get_batch(batch_size=opt.batch_size_sup)

            # data for unsupervised training
            if opt.batch_size_unsup > 0:
                # ubatch_vlabel_u is a dummy zero tensor since unsupervised samples don't have verification labels
                _, ubatch_vlabel_u, ubatch_q, ubatch_vflag, ubatch_ref = unsupervised_data.get_batch(batch_size=opt.batch_size_unsup)

                # concatenate supervised and unsupervised training data
                batch_vlabel = np.concatenate([batch_vlabel, ubatch_vlabel_u], axis=0)
                batch_q = np.concatenate([batch_q, ubatch_q], axis=0)
                batch_vflag = np.concatenate([batch_vflag, ubatch_vflag], axis=0)
                batch_ref = np.concatenate([batch_ref, ubatch_ref], axis=0)

            _, cur_step, cur_loss, cur_acc, summary = sess.run([train_op, model.global_step, loss, acc, merged], 
                   feed_dict={model.reference: batch_ref, 
                              model.query: batch_q, 
                              vlabel: batch_vlabel,
                              vflag: batch_vflag})

            train_summary_writer.add_summary(summary, cur_step)

            if cur_step % opt.log_interval == 0:
                tf.logging.info('step {}: train/loss = {}, train/acc = {}'.format(cur_step, cur_loss, cur_acc))
            if cur_step % opt.val_interval == 0 and cur_step != 0:
                _, avg_val_acc = validation(sess, model, loss, val_acc, vlabel, vflag, opt.val_batch_size, val_data, val_summary_writer)
                if not os.path.exists(opt.checkpoint_dir):
                    os.mkdir(opt.checkpoint_dir)
                save_path = saver.save(sess, opt.checkpoint_dir)
                print("Model saved in path: %s" % save_path)
                if avg_val_acc > best_avg_val_acc:
                    best_avg_val_acc = avg_val_acc
                    model_path = os.path.join(save_path, "checkpoint")
                    best_model_path = os.path.join(save_path, "best_model_{}".format(cur_step))
                    shutil.copy(model_path, best_model_path)
                    print("Best model saved in path: %s" % best_model_path)
Example #4
def train():
    """Training"""
    opt = FLAGS

    tf.logging.info("Build CleanNet...")
    batch_size = opt.batch_size_sup + opt.batch_size_unsup
    model = CleanNet(opt.num_ref, opt.img_dim, opt.embed_norm,
                     opt.dropout_rate, opt.weight_decay)

    # phi_s: class embedding (batch_size, embed_size)
    # v_q: query image feature (batch_size, img_dim)
    # phi_q: query embedding (batch_size, embed_size)
    # v_qr: reconstructed query image feature (batch_size, img_dim)
    phi_s, v_q, phi_q, v_qr = model.forward(is_training=True)

    # verification labels
    vlabel = tf.placeholder(tf.float32, shape=(None, ), name="vlabel")

    # verification flags indicating a sample is for supervised(1) or unsupervised(0) training
    vflag = tf.placeholder(tf.float32, shape=(None, ), name="vflag")

    cos_sim = similarity(phi_s, phi_q)

    acc = accuracy(vlabel[:opt.batch_size_sup],
                   cos_sim[:opt.batch_size_sup],
                   threshold=0.1,
                   scope="train_acc")
    val_acc = accuracy(vlabel,
                       cos_sim,
                       threshold=opt.val_sim_thres,
                       scope="val_acc_at_{}".format(opt.val_sim_thres))
    tf.summary.scalar('train/accuracy', acc)

    objective_loss = tf.reduce_mean(
        total_loss(vlabel,
                   cos_sim,
                   phi_s,
                   v_q,
                   phi_q,
                   v_qr,
                   vflag,
                   opt.neg_weight,
                   beta=0.1,
                   gamma=0.1))
    tf.summary.scalar('train/objective_loss', objective_loss)
    regularization_loss = tf.reduce_sum(
        tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
    tf.summary.scalar('train/regularization_loss', regularization_loss)
    loss = objective_loss + regularization_loss
    tf.summary.scalar('train/loss', loss)

    lr = tf.train.exponential_decay(opt.learning_rate,
                                    model.global_step,
                                    opt.lr_update,
                                    opt.lr_decay,
                                    staircase=True)
    tf.summary.scalar('train/lr', lr)
    merged = tf.summary.merge_all()

    optimizer = tf.train.MomentumOptimizer(lr, opt.momentum)
    train_op = optimizer.minimize(loss, global_step=model.global_step)

    tf.logging.info("Get data batcher...")
    supervised_data = data_provider_factory.get_data_batcher(
        'trainval', 'train', opt)
    val_data = data_provider_factory.get_data_batcher('trainval', 'val', opt)
    if opt.batch_size_unsup > 0:
        unsupervised_data = data_provider_factory.get_data_batcher(
            'trainval', 'unverified', opt)

    saver = tf.train.Saver()
    init_op = tf.global_variables_initializer()

    with tf.Session() as sess:

        train_summary_writer = tf.summary.FileWriter(opt.log_dir + '/train',
                                                     sess.graph)
        val_summary_writer = tf.summary.FileWriter(opt.log_dir + '/val')

        cur_step = 0
        best_avg_val_acc = 0.0
        sess.run(init_op)

        # recover from latest checkpoint and run validation if available
        ckpt = tf.train.get_checkpoint_state(opt.checkpoint_dir)
        if ckpt:
            saver.restore(sess, ckpt.model_checkpoint_path)
            saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths)
            cur_step, avg_val_acc = validation(sess, model, loss, val_acc,
                                               vlabel, vflag,
                                               opt.val_batch_size, val_data,
                                               val_summary_writer)
            best_avg_val_acc = avg_val_acc
            tf.logging.info("Recover model at global step = %d.", cur_step)
        else:
            tf.logging.info("Training from scratch.")

        while cur_step < opt.n_step:
            # data for supervised training
            _, batch_vlabel, batch_q, batch_vflag, batch_ref = supervised_data.get_batch(
                batch_size=opt.batch_size_sup)

            # data for unsupervised training
            if opt.batch_size_unsup > 0:
                # ubatch_vlabel_u is a dummy zero tensor since unsupervised samples don't have verification labels
                _, ubatch_vlabel_u, ubatch_q, ubatch_vflag, ubatch_ref = unsupervised_data.get_batch(
                    batch_size=opt.batch_size_unsup)

                # concatenate supervised and unsupervised training data
                batch_vlabel = np.concatenate([batch_vlabel, ubatch_vlabel_u],
                                              axis=0)
                batch_q = np.concatenate([batch_q, ubatch_q], axis=0)
                batch_vflag = np.concatenate([batch_vflag, ubatch_vflag],
                                             axis=0)
                batch_ref = np.concatenate([batch_ref, ubatch_ref], axis=0)

            _, cur_step, cur_loss, cur_acc, summary = sess.run(
                [train_op, model.global_step, loss, acc, merged],
                feed_dict={
                    model.reference: batch_ref,
                    model.query: batch_q,
                    vlabel: batch_vlabel,
                    vflag: batch_vflag
                })

            train_summary_writer.add_summary(summary, cur_step)

            if cur_step % opt.log_interval == 0:
                tf.logging.info(
                    'step {}: train/loss = {}, train/acc = {}'.format(
                        cur_step, cur_loss, cur_acc))
            if cur_step % opt.val_interval == 0 and cur_step != 0:
                _, avg_val_acc = validation(sess, model, loss, val_acc, vlabel,
                                            vflag, opt.val_batch_size,
                                            val_data, val_summary_writer)
                if not os.path.exists(opt.checkpoint_dir):
                    os.mkdir(opt.checkpoint_dir)
                save_path = saver.save(sess, opt.checkpoint_dir)
                print("Model saved in path: %s" % save_path)
                if avg_val_acc > best_avg_val_acc:
                    best_avg_val_acc = avg_val_acc
                    model_path = os.path.join(save_path, "checkpoint")
                    best_model_path = os.path.join(
                        save_path, "best_model_{}".format(cur_step))
                    shutil.copy(model_path, best_model_path)
                    print("Best model saved in path: %s" % best_model_path)
Example #5
def train():
    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False)

        ltoday, mtoday, htoday, tomorrow, _, _, _, _, _ = rec.data_inputs(
            FLAGS.train_input_path, FLAGS.train_batch_size, conf.shape_dict,
            30, False, False)
        predictions, _, _, _ = cnn_branches.cnn_with_branch(
            ltoday, mtoday, htoday, conf.HEIGHT * conf.HIGH_WIDTH,
            FLAGS.train_batch_size)
        reality = tf.reshape(tomorrow, predictions.get_shape())
        mse = losses.mse_loss(predictions, reality)
        loss = losses.total_loss(predictions, reality, losses.main_loss)
        train_step = ut.train(loss, global_step,
                              conf.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN)
        saver = tf.train.Saver(tf.global_variables())
        summary_op = tf.summary.merge_all()

        init = tf.global_variables_initializer()
        coord = tf.train.Coordinator()
        sess = tf.Session()
        #tf_debug.add_debug_tensor_watch(sess,'l_conv1')
        #sess = tf_debug.LocalCLIDebugWrapperSession(sess,)
        sess.run(init)

        tf.train.start_queue_runners(sess=sess, coord=coord)

        summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

        loss_list = []
        mse_list = []
        total_loss_list = []

        for step in range(FLAGS.epoch *
                          conf.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN + 1):
            start_time = time.time()
            _, loss_val, mse_loss = sess.run([train_step, loss, mse])
            duration = time.time() - start_time

            assert not np.isnan(loss_val), 'Model diverged with loss = NaN'
            loss_list.append(loss_val)
            mse_list.append(mse_loss)

            if step % conf.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN == 0:
                num_examples_per_step = FLAGS.train_batch_size
                examples_per_sec = 0  #num_examples_per_step / duration
                sec_per_batch = float(duration)
                average_loss_value = np.mean(loss_list)
                average_mse_value = np.mean(mse_list)
                total_loss_list.append(average_loss_value)
                loss_list.clear()
                format_str = (
                    '%s: epoch %d, loss = %.4f , mse = %.4f (%.1f examples/sec; %.3f '
                    'sec/batch)')
                print(
                    format_str %
                    (datetime.now(), step /
                     conf.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN, average_loss_value,
                     average_mse_value, examples_per_sec, sec_per_batch))
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)
            if step % (conf.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * 10 + 1) == 0:
                checkpoint_path = os.path.join(FLAGS.checkpoint_dir,
                                               'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)

        matlab.save_matrix(FLAGS.train_dir + 'cnn_branch_loss.mat',
                           total_loss_list, 'cnn_branch_loss')
Example #6
def train():
    """Train CIFAR-10 for a number of steps."""

    with tf.device('/cpu:0'), tf.name_scope('input'):
        global_step = tf.train.get_or_create_global_step()

        print("Loading CIFAR-10 Data")
        cifar10 = data.Cifar10()

        images_placeholder = tf.placeholder(
            cifar10.train_images.dtype,
            (None, params.IMAGE_SIZE, params.IMAGE_SIZE, params.CHANNELS),
            name='images_placeholder')
        labels_placeholder = tf.placeholder(cifar10.train_labels.dtype,
                                            (None, ),
                                            name='labels_placeholder')

        train_data_dict = {
            images_placeholder: cifar10.train_images,
            labels_placeholder: cifar10.train_labels
        }
        test_data_dict = {
            images_placeholder: cifar10.test_images,
            labels_placeholder: cifar10.test_labels
        }

        training_dataset = tf.data.Dataset.from_tensor_slices(
            (images_placeholder, labels_placeholder))
        training_dataset = training_dataset.prefetch(params.SHUFFLE_BUFFER)
        training_dataset = training_dataset.map(
            data.randomization_function, num_parallel_calls=params.NUM_THREADS)
        training_dataset = training_dataset.shuffle(
            buffer_size=params.SHUFFLE_BUFFER)
        training_dataset = training_dataset.batch(params.BATCH_SIZE)
        training_dataset = training_dataset.repeat()
        training_dataset = training_dataset.prefetch(
            params.TRAIN_OUTPUT_BUFFER)

        validation_dataset = tf.data.Dataset.from_tensor_slices(
            (images_placeholder, labels_placeholder))
        validation_dataset = validation_dataset.map(
            data.standardization_function,
            num_parallel_calls=params.NUM_THREADS)
        validation_dataset = validation_dataset.batch(params.BATCH_SIZE)
        validation_dataset = validation_dataset.prefetch(
            params.VALIDATION_OUTPUT_BUFFER)

        iterator = tf.contrib.data.Iterator.from_structure(
            training_dataset.output_types, training_dataset.output_shapes)
        next_element = iterator.get_next()
        training_init_op = iterator.make_initializer(training_dataset)
        validation_init_op = iterator.make_initializer(validation_dataset)

        training_placeholder = tf.placeholder_with_default(
            False, (), name='training_placeholder')

    print("Building TensorFlow Graph")

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = net.inference(next_element[0], training=training_placeholder)

    # Calculate loss.
    total_loss = losses.total_loss(logits, next_element[1])

    with tf.name_scope('accuracy'):
        correct = tf.nn.in_top_k(logits, next_element[1], 1)
        number_correct = tf.reduce_sum(tf.cast(correct, tf.int32))

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = create_train_op(total_loss, global_step)

    init = tf.global_variables_initializer()

    print("Starting TensorFlow Session")

    saver = tf.train.Saver()
    tf_file_writer = tf.summary.FileWriter(params.TRAIN_DIR,
                                           tf.get_default_graph())
    merged_summary = tf.summary.merge_all()
    csv_file_dir = os.path.join(params.TRAIN_DIR, 'log.csv')
    with tf.Session() as sess, open(csv_file_dir, 'w', newline='') as log_file:
        log_writer = csv.writer(log_file)
        log_writer.writerow(["Step", "Train Error", "Test Error", "Step Time"])

        print("Initializing Global Variables")
        init.run()

        print("Training in Progress")
        while global_step.eval() < params.TRAIN_STEPS:
            # Run a number of training steps set by params.LOG_FREQUENCY
            training_init_op.run(feed_dict=(train_data_dict))
            start_time = time.perf_counter()
            for _ in range(0, params.LOG_FREQUENCY):
                batch_loss, summary_str = sess.run(
                    [train_op, total_loss, merged_summary],
                    feed_dict={training_placeholder: True})[1:]
            end_time = time.perf_counter()
            average_time_per_step = (end_time -
                                     start_time) / params.LOG_FREQUENCY

            # Write a summary of the last training batch for TensorBoard
            tf_file_writer.add_summary(summary_str, global_step.eval())

            # Calculate error rate based on the full train set.
            validation_init_op.run(feed_dict=train_data_dict)
            total_correct = 0
            n_train_validation_steps = math.ceil(params.NUM_TRAIN_EXAMPLES /
                                                 params.BATCH_SIZE)
            for _ in range(0, n_train_validation_steps):
                total_correct += number_correct.eval()
            train_error_rate = 1.0 - total_correct / params.NUM_TRAIN_EXAMPLES

            # Calculate error rate based on the full test set.
            validation_init_op.run(feed_dict=test_data_dict)
            total_correct = 0
            n_test_validation_steps = math.ceil(params.NUM_TEST_EXAMPLES /
                                                params.BATCH_SIZE)
            for _ in range(0, n_test_validation_steps):
                total_correct += number_correct.eval()
            test_error_rate = 1.0 - total_correct / params.NUM_TEST_EXAMPLES

            print("Step:", global_step.eval())
            print("  Train Set Error Rate:", train_error_rate)
            print("  Test Set Error Rate:", test_error_rate)
            print("  Average Training Time per Step:", average_time_per_step)
            log_writer.writerow([
                global_step.eval(), train_error_rate, test_error_rate,
                average_time_per_step
            ])
        saver.save(sess, os.path.join(params.TRAIN_DIR, "model.ckpt"))
    tf_file_writer.close()
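Example #6 maps data.randomization_function over the training pipeline and data.standardization_function over the validation pipeline, but neither function is shown. A minimal sketch of the usual CIFAR-10 preprocessing they might implement (random pad-and-crop plus horizontal flip for training, per-image standardization for both), assuming 32x32x3 inputs:

import tensorflow as tf

def standardization_function(image, label):
    # Scale each image to zero mean and unit variance.
    image = tf.image.per_image_standardization(tf.cast(image, tf.float32))
    return image, label

def randomization_function(image, label):
    # Pad to 40x40, take a random 32x32 crop, flip horizontally at random,
    # then standardize the result.
    image = tf.cast(image, tf.float32)
    image = tf.image.resize_image_with_crop_or_pad(image, 40, 40)
    image = tf.random_crop(image, [32, 32, 3])
    image = tf.image.random_flip_left_right(image)
    image = tf.image.per_image_standardization(image)
    return image, label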