コード例 #1
0
ファイル: mldock_gnn.py プロジェクト: Matrix-Groups/pharml
def build_optimizer(args,loss_op,num_train_items):
    """Build the Adam optimizer (optionally Horovod-distributed) and its step op.

    Args:
        args: parsed arguments; reads ``lr_init``, ``use_clr``, ``batch_size``,
            ``epochs``, ``hvd`` and ``use_fp16``.
        loss_op: scalar loss tensor to minimize.
        num_train_items: number of training samples, used to size the
            cyclic-LR step when ``args.use_clr`` is set.

    Returns:
        ``(optimizer, step_op)``; ``step_op`` runs one update and increments
        the global step.
    """
    banner_print("Building optimizer.")
    global_step = tf.train.get_or_create_global_step()
    # Base learning rate is scaled linearly with the number of ranks.
    lr_init = float(RANKS*args.lr_init)
    if args.use_clr:
        import clr
        if RANK == 0:
            print("Using Cyclic LR with initial LR: ", lr_init)
        steps_per_epoch = num_train_items/args.batch_size
        max_steps = args.epochs*steps_per_epoch
        lr_decay = tf.train.exponential_decay(lr_init, global_step, max_steps*10, 0.9, staircase=False)
        learning_rate = clr.cyclic_learning_rate(global_step=global_step, learning_rate=lr_decay, max_lr=100*lr_decay,
                                                 step_size=2*steps_per_epoch, mode='triangular', gamma=.999)
    else:
        learning_rate = lr_init
        if RANK == 0:
            print("Using constant LR: ", learning_rate)
    # Cast BEFORE constructing the optimizer. TF1's AdamOptimizer keeps its rate
    # in `_lr`, so assigning `_learning_rate` after construction (and, with
    # Horovod, on the wrapper rather than the inner Adam) never took effect.
    learning_rate = tf.cast(learning_rate, tf.float32)
    optimizer = tf.train.AdamOptimizer(learning_rate)
    if args.hvd:
        import horovod.tensorflow as hvd
        compression = hvd.Compression.fp16 if args.use_fp16 else hvd.Compression.none
        optimizer = hvd.DistributedOptimizer(optimizer, use_locking=False, compression=compression, op=hvd.Average)
    # Kept so callers that read this attribute off the returned optimizer still work.
    optimizer._learning_rate = learning_rate
    step_op = optimizer.minimize(loss_op,global_step)
    return optimizer, step_op
コード例 #2
0
 def test_triangular2(self):
     """The 'triangular2' cyclic LR must agree with the NumPy reference."""
     step, lr, max_lr, step_size = 5, 0.01, 0.1, 20.
     schedule_args = (step, lr, max_lr, step_size)
     cyclic_lr = learning_rate_decay.cyclic_learning_rate(
         *schedule_args, mode='triangular2')
     expected = self.np_cyclic_learning_rate(*schedule_args,
                                             mode='triangular2')
     actual = self.evaluate(cyclic_lr)
     self.assertAllClose(actual, expected, 1e-6)
コード例 #3
0
ファイル: opt.py プロジェクト: zhoufeng-coder/MGADAE
    def __init__(self, model, preds, labels, lr, num_u, num_v,
                 association_nam):
        # Class-imbalance bookkeeping: all (u, v) pairs vs. the non-associated ones.
        n_pairs = num_u * num_v
        n_negative = n_pairs - association_nam
        # `norm` rescales the mean loss; `pos_weight` up-weights positive pairs.
        norm = n_pairs / float(n_negative * 2)
        pos_weight = float(n_negative) / (association_nam)

        global_step = tf.Variable(0, trainable=False)
        # Exponentially-damped cyclic schedule oscillating between lr/10 and lr.
        schedule = cyclic_learning_rate(global_step=global_step,
                                        learning_rate=lr * 0.1,
                                        max_lr=lr,
                                        mode='exp_range',
                                        gamma=.995)
        self.optimizer = tf.train.AdamOptimizer(learning_rate=schedule)

        weighted_xent = tf.nn.weighted_cross_entropy_with_logits(
            logits=preds, targets=labels, pos_weight=pos_weight)
        self.cost = norm * tf.reduce_mean(weighted_xent)

        self.opt_op = self.optimizer.minimize(self.cost,
                                              global_step=global_step)
        self.grads_vars = self.optimizer.compute_gradients(self.cost)
コード例 #4
0
# Three-layer sigmoid network (21 -> 15 -> 15 -> 3) trained with a cyclic LR.
# NOTE(review): x_, y_, tf and clr are assumed to be defined/imported earlier
# in the file — confirm.
Theta1 = tf.Variable(tf.random_uniform([21, 15], -1, 1), name="Theta1")
Theta2 = tf.Variable(tf.random_uniform([15, 15], -1, 1), name="Theta2")
Theta3 = tf.Variable(tf.random_uniform([15, 3], -1, 1), name="Theta3")

Bias1 = tf.Variable(tf.zeros([15]), name="Bias1")
Bias2 = tf.Variable(tf.zeros([15]), name="Bias2")
Bias3 = tf.Variable(tf.zeros([3]), name="Bias3")

A2 = tf.sigmoid(tf.matmul(x_, Theta1) + Bias1)
A3 = tf.sigmoid(tf.matmul(A2, Theta2) + Bias2)
A4 = tf.sigmoid(tf.matmul(A3, Theta3) + Bias3)

# Clip before taking logs: once a sigmoid output saturates to exactly 0 or 1,
# tf.log returns -inf and the binary cross-entropy (and its gradients) become NaN.
LOG_EPS = 1e-7
A4_clipped = tf.clip_by_value(A4, LOG_EPS, 1.0 - LOG_EPS)
cost = tf.reduce_mean(((y_ * tf.log(A4_clipped)) + ((1 - y_) * tf.log(1.0 - A4_clipped))) * -1)

gs = tf.train.create_global_step()
# Pass the global step to minimize() so it is incremented on every update;
# without it `gs` stays at 0 and the cyclic learning rate never moves.
optimizer = tf.train.AdamOptimizer(learning_rate=clr.cyclic_learning_rate(
    global_step=gs, mode='triangular2')).minimize(cost, global_step=gs)

DATA_PATH = r'C:\Users\dwije\Downloads\data_normalized.xlsx'
# Open the workbook once and load all four sheets in a single call
# (returns a dict keyed by sheet name).
sheets = pd.read_excel(DATA_PATH,
                       sheet_name=['upinput_norm', 'upout', 'test_norm', 'testout2'])
input_X = sheets['upinput_norm']
input_Y = sheets['upout']  # Loading the output of size m x 1
test_input = sheets['test_norm']
test_out = sheets['testout2'].to_numpy().flatten()
print(test_out.shape)

init = tf.global_variables_initializer()
sess = tf.Session()

writer = tf.summary.FileWriter("./logs/xor_logs", sess.graph)
コード例 #5
0
ファイル: experiment.py プロジェクト: DerekChia/mltagger
def run_experiment(config_path):
    """Train, select and evaluate an MLTModel as configured by ``config_path``.

    Loads train/dev/test data from the config paths. Trains for
    ``config["epochs"]`` epochs with a cyclic learning rate and checkpoints the
    model whenever the dev ``model_selector`` metric improves. Then restores
    the best checkpoint, optionally saves the model, and runs each
    ``:``-separated test file through ``process_sentences``.

    Args:
        config_path: path of the experiment config. ``config_path + ".model"``
            is used for the temporary best-model checkpoint files.

    Raises:
        ValueError: if the dev cost becomes NaN or Inf.
    """
    config = parse_config("config", config_path)
    temp_model_path = config_path + ".model"
    if "random_seed" in config:
        random.seed(config["random_seed"])
        numpy.random.seed(config["random_seed"])

    data_train, data_dev, data_test = None, None, None
    if config["path_train"] is not None and len(config["path_train"]) > 0:
        data_train = read_input_files(config["path_train"], config["max_train_sent_length"])
    if config["path_dev"] is not None and len(config["path_dev"]) > 0:
        data_dev = read_input_files(config["path_dev"])
    if config["path_test"] is not None and len(config["path_test"]) > 0:
        data_test = []
        for path_test in config["path_test"].strip().split(":"):
            data_test += read_input_files(path_test)

    model = MLTModel(config)
    model.build_vocabs(data_train, data_dev, data_test, config["preload_vectors"])
    model.construct_network()
    model.initialize_session()
    if config["preload_vectors"] is not None:
        model.preload_word_embeddings(config["preload_vectors"])

    merged_summary = tf.summary.merge_all()
    writer = tf.summary.FileWriter("output", model.session.graph)

    print("parameter_count: " + str(model.get_parameter_count()))
    print("parameter_count_without_word_embeddings: " + str(model.get_parameter_count_without_word_embeddings()))

    # Give `epoch` a value even when no training happens (no training data or
    # zero epochs). The test loop below passes it to process_sentences and
    # previously raised NameError in that case.
    epoch = 0
    if data_train is not None:
        model_selector = config["model_selector"].split(":")[0]
        model_selector_type = config["model_selector"].split(":")[1]
        best_selector_value = 0.0
        best_epoch = -1

        # Build the cyclic-LR op once and feed the epoch number each time.
        # Calling clr.cyclic_learning_rate() inside the loop added new ops to
        # the graph on every epoch.
        epoch_placeholder = tf.placeholder(tf.int32, shape=[])
        cyclic_lr_op = clr.cyclic_learning_rate(epoch_placeholder, learning_rate=0.8, max_lr=1.2, mode='triangular2')

        sess = tf.Session()
        try:
            for epoch in range(config["epochs"]):
                print("EPOCH: " + str(epoch))

                learningrate = sess.run(cyclic_lr_op, feed_dict={epoch_placeholder: epoch})

                print("current_learningrate: " + str(learningrate))

                random.shuffle(data_train)

                results_train = process_sentences(writer, merged_summary, epoch, data_train, model, is_training=True, learningrate=learningrate, config=config, name="train")

                if data_dev is not None:
                    # NOTE(review): this call (and the test call below) passes
                    # `epoch` first, while the train call above passes
                    # `writer, merged_summary, epoch` first. Only one form can
                    # match process_sentences' signature — verify.
                    results_dev = process_sentences(epoch, data_dev, model, is_training=False, learningrate=0.0, config=config, name="dev")

                    if math.isnan(results_dev["dev_cost_sum"]) or math.isinf(results_dev["dev_cost_sum"]):
                        raise ValueError("Cost is NaN or Inf. Exiting.")

                    if (epoch == 0 or (model_selector_type == "high" and results_dev[model_selector] > best_selector_value)
                                   or (model_selector_type == "low" and results_dev[model_selector] < best_selector_value)):
                        best_epoch = epoch
                        best_selector_value = results_dev[model_selector]
                        model.saver.save(model.session, temp_model_path, latest_filename=os.path.basename(temp_model_path)+".checkpoint")
                    print("best_epoch: " + str(best_epoch))

                    if config["stop_if_no_improvement_for_epochs"] > 0 and (epoch - best_epoch) >= config["stop_if_no_improvement_for_epochs"]:
                        break

                    if (epoch - best_epoch) > 3:
                        # NOTE(review): this has no effect. `learningrate` is
                        # recomputed from the cyclic schedule at the start of
                        # every epoch.
                        learningrate *= config["learningrate_decay"]

                while config["garbage_collection"] == True and gc.collect() > 0:
                    pass
        finally:
            # The helper session only evaluates the LR schedule; release it.
            sess.close()

        if data_dev is not None and best_epoch >= 0:
            # Load the best model so far, then delete its temporary checkpoint files.
            model.saver.restore(model.session, temp_model_path)
            for suffix in (".checkpoint", ".data-00000-of-00001", ".index", ".meta"):
                os.remove(temp_model_path + suffix)

    if config["save"] is not None and len(config["save"]) > 0:
        model.save(config["save"])

    # Same emptiness check as the loader above. An empty path string would
    # otherwise be split into [""] and passed to read_input_files.
    if config["path_test"] is not None and len(config["path_test"]) > 0:
        for i, path_test in enumerate(config["path_test"].strip().split(":")):
            data_test = read_input_files(path_test)
            results_test = process_sentences(epoch, data_test, model, is_training=False, learningrate=0.0, config=config, name="test"+str(i))

    writer.close()
def train(input_train, output_train, strain_train, input_validate, output_validate, strain_validate):
      print
      if GENERATE_DATA:
            pass
            # print("Generating Input Data")
            # input_train = np.zeros((NUM_SAMPLES_TRAIN, NUM_STRANDS, NUM_MARKERS, NUM_AMINO_ACIDS))
            # for strain_counter in range(NUM_SAMPLES_TRAIN):
            #       for strand_counter in range(NUM_STRANDS):
            #             for snp_counter in range(NUM_MARKERS):
            #                   rand_aa = random.randint(0, 19)
            #                   input_train[strain_counter][strand_counter][snp_counter][rand_aa] = 1.0
            #
            # input_validate = np.zeros((NUM_SAMPLES_VALIDATE, NUM_STRANDS, NUM_MARKERS, NUM_AMINO_ACIDS))
            # for strain_counter in range(NUM_SAMPLES_VALIDATE):
            #       for strand_counter in range(NUM_STRANDS):
            #             for snp_counter in range(NUM_MARKERS):
            #                   rand_aa = random.randint(0, 19)
            #                   input_validate[strain_counter][strand_counter][snp_counter][rand_aa] = 1.0
            #
            # print("Generating Output Data")
            # output_train = np.random.rand(NUM_SAMPLES_TRAIN, NUM_TRAITS) * 5
            # output_validate = np.random.rand(NUM_SAMPLES_VALIDATE, NUM_TRAITS) * 5

      epoch = tf.Variable(0, trainable=False, dtype=tf.int32)
      epoch_add_op = epoch.assign(epoch + 1)
      loaded_epoch = tf.placeholder(dtype=tf.int32)
      epoch_load_op = epoch.assign(loaded_epoch)


    #  learning_rate = tf.Variable(LEARN_RATE, trainable=False, dtype=tf.float32)
      learning_rate = cyclic_learning_rate(global_step=epoch, mode='triangular2')


      keep_prob_dense = tf.placeholder(dtype=tf.float32)
      L2_param = tf.constant(L2_WEIGHT, dtype=tf.float32)

      #input_variable = tf.placeholder(dtype=tf.float32, shape=[None, 2, NUM_MARKERS, 20])

      input_variable = tf.placeholder(dtype=tf.float32, shape=[None, NUM_PCA])
      output_variable = tf.placeholder(dtype=tf.float32, shape=[None, NUM_TRAITS])
      output_prediction_variable = model_residual(input_variable, keep_prob_dense)
     # cost = cost_function_sse(output_prediction_variable, output_variable)

      scaling = tf.placeholder(dtype=tf.float32, shape=[None, NUM_TRAITS])
      cost = cost_function_mse(output_prediction_variable, output_variable, scaling)
      total_parameters = 0
      L2_cost = 0
      counter = 0
      for variable in tf.trainable_variables():
            shape = variable.get_shape()
            L2_cost += tf.nn.l2_loss(variable)
            variable_parameters = 1
            for dim in shape:
                  variable_parameters *= dim.value
            total_parameters += variable_parameters
            counter += 1
      print "Num Variables = ", counter, "\tTotal parameters = ", total_parameters
      print

      L2_cost = tf.multiply(L2_param, L2_cost)
      cost = tf.add(cost, L2_cost)
      optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, epsilon=1.0e-8)
      grads_and_vars = optimizer.compute_gradients(cost)
      if CLIP > 0:
            grads_and_vars = [(tf.clip_by_value(grad, -1 * CLIP, CLIP), var) for grad, var in grads_and_vars]
      if NOISE > 0:
            grads_and_vars = [(i[0] + tf.random_normal(shape=tf.shape(i[0]), mean=0, stddev=NOISE), i[1]) for i in
                              grads_and_vars]
      train_op = optimizer.apply_gradients(grads_and_vars)
      saver = tf.train.Saver()
      print("Running Session")
      with tf.Session() as sess:
            if LOAD_TRAINING_MODEL:
                  print("Reading model parameters from %s" % PATH_SAVE_TRAIN)
                  saver.restore(sess, tf.train.latest_checkpoint(PATH_SAVE_TRAIN))
                  tag = PATH_SAVE + "loss_summary.txt"
                  loss_file = open(tag)
                  previous_best = np.inf
                  counter = 1
                  for line in loss_file:
                        line = line.split()
                        lv = float(line[3])
                        if lv < previous_best:
                              previous_best = lv
                        counter += 1
                  loss_file.close()
                  sess.run(epoch_load_op, feed_dict={loaded_epoch: counter})
                  print "current epoch =", epoch.eval(session=sess)
                  print "best validation error =", previous_best
            else:
                  sess.run(tf.global_variables_initializer())
                  previous_best = np.inf
            best_accuracy_total = 0.0
            if TRAIT_TARGET == "CANNABINOIDS" or TRAIT_TARGET == "TERPENES" or TRAIT_TARGET == "ALL":
                  scale_train = np.array([[SCALE] + [1]*(NUM_TRAITS-1), ] * BATCH_SIZE)
                  scale_val = np.array([[SCALE] + [1]*(NUM_TRAITS-1), ] * len(input_validate))

            elif TRAIT_TARGET == "SPECIFIC":
                  scale_train = np.array([[SCALE] , ] * BATCH_SIZE)
                  scale_val = np.array([[SCALE] , ] * len(input_validate))

            while True:
                  sess.run(epoch_add_op)
                  loss_train = 0
                  loss_train_l2 = 0

                  shuffle_in_unison_3(input_train, output_train, strain_train)
                  for batch_num in xrange(0, len(input_train), BATCH_SIZE):
                        if batch_num + BATCH_SIZE > len(input_train):
                              break
                        input_train_batch = input_train[batch_num:batch_num + BATCH_SIZE]
                        output_train_batch = output_train[batch_num:batch_num + BATCH_SIZE]

                        _, loss, loss_l2, output_pred_train = sess.run(
                              [train_op, cost, L2_cost, output_prediction_variable],
                              feed_dict={output_variable: output_train_batch,
                                         input_variable: input_train_batch,
                                         scaling: scale_train,
                                         keep_prob_dense: KEEP_PROB_DENSE})
                        loss_train += loss
                        loss_train_l2 += loss_l2
                  loss_train -= loss_train_l2

                  loss_validate, loss_validate_l2, output_pred_validate = sess.run(
                        [cost, L2_cost, output_prediction_variable],
                        feed_dict={output_variable: output_validate,
                                   input_variable: input_validate,
                                   scaling: scale_val,
                                   keep_prob_dense: 1.0})
                  loss_validate -= loss_validate_l2
                  print epoch.eval(session=sess), "\t", loss_train, "\t", loss_validate
                  tag = PATH_SAVE_TRAIN + "best_train_loss_model"
                  saver.save(sess, tag)
                  if loss_validate < previous_best:
                        total_percent_real = np.sum(output_validate, axis=1)
                        for i in range(len(total_percent_real)):
                              if total_percent_real[i] == 0.0:
                                    total_percent_real[i] = 0.00001

                        total_percent_pred = np.sum(output_pred_validate, axis=1)
                        diff_total = abs(total_percent_real - total_percent_pred)
                        accuracy_total = (1.0 - diff_total / total_percent_real) * 100
                        tot = sum(accuracy_total)
                        accuracy_mean_total = tot / len(accuracy_total)
                        if accuracy_mean_total > best_accuracy_total:
                              best_accuracy_total = accuracy_mean_total
                        previous_best = loss_validate
                        tag = PATH_SAVE_VAL + "best_val_loss_model"
                        saver.save(sess, tag)
                  if epoch.eval(session=sess) % 25 == 0:
                        # diff = abs(output_validate - output_pred_validate)
                        # accuracy = (1.0 - diff / output_validate) * 100
                        # print accuracy
                        # exit()
                        total_percent_real = np.sum(output_validate, axis=1)
                        total_percent_pred = np.sum(output_pred_validate, axis=1)
                        diff_total = abs(total_percent_real - total_percent_pred)
                        accuracy_total = (1.0 - diff_total / total_percent_real) * 100
                        tot = sum(accuracy_total)
                        accuracy_mean_total = tot / len(accuracy_total)
                        if accuracy_mean_total > best_accuracy_total:
                              best_accuracy_total = accuracy_mean_total
                        print "\t\t", best_accuracy_total, "\t", accuracy_mean_total
                  if epoch.eval(session=sess) % EPOCH_CHECK == 0:
                        tag = PATH_SAVE + "loss_summary.txt"
                        string = str(loss_train) + " " + str(loss_validate) + "\n"
                        with open(tag, "a") as myfile:
                              myfile.write(string)
                  if epoch.eval(session=sess) == EPOCH_LIMIT:
                        break
            string = str(TRAIT_TARGET) + "\t" + str(BATCH_SIZE) + "\t" + str(NUM_LAYERS_DENSE) + "\t" + str(
                  NUM_NODES) + "\t" + str(LEARN_RATE) + "\t" + str(L2_WEIGHT) + "\t" + str(
                  epoch.eval(session=sess)) + "\t" + str(TEN_FOLD_VALIDATION_COUNTER) + "\t" + str(
                  best_accuracy) + "\t" + str(DENSE_BATCH_NORM_FLAG) + "\t" + str(KEEP_PROB_DENSE) + "\n"
            with open(PATH_SAVE_FINAL, "a") as myfile:
                  myfile.write(string)
def model_fn(features, labels, mode, params):
    """Estimator model function for the 3-D ResNet classifier.

    Args:
        features: dict whose ``'x'`` entry is the input patch batch.
        labels: dict whose ``'y'`` entry holds class indices (one-hot encoded
            below with depth NUM_CLASSES).
        mode: a ``tf.estimator.ModeKeys`` value.
        params: dict; ``params["opt"]`` selects ``'adam'``, ``'momentum'`` or
            ``'rmsprop'``.

    Returns:
        A ``tf.estimator.EstimatorSpec`` for ``mode``.

    Raises:
        ValueError: if ``LR_MODE`` or ``params["opt"]`` is not recognised
            (previously this surfaced as an unbound-local error).
    """
    with tf.variable_scope('model', reuse=tf.AUTO_REUSE):

        # Model definition (resnet_3D)
        model_output_ops = resnet_3d(
            inputs=features[
                'x'],  # Input: Concatenated Patches (dimensions=128cube,channel=61-63; defined by 'patch_size')
            num_res_units=2,
            num_classes=NUM_CLASSES,
            filters=(16, 32, 64, 128, 256),
            strides=((1, 1, 1), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2)),
            mode=mode,
            activation=tf.nn.relu6,
            kernel_initializer=tf.initializers.variance_scaling(
                distribution='uniform'),
            bias_initializer=tf.zeros_initializer(),
            kernel_regularizer=tf.contrib.layers.l2_regularizer(1e-3))

        # Prediction mode: no loss/optimizer needed.
        if mode == tf.estimator.ModeKeys.PREDICT:
            return tf.estimator.EstimatorSpec(
                mode=mode,
                predictions=model_output_ops,
                export_outputs={
                    'out': tf.estimator.export.PredictOutput(model_output_ops)
                })

        # Loss function
        # NOTE(review): regularization losses from `kernel_regularizer` are not
        # added to `loss` here (e.g. via tf.losses.get_regularization_loss()),
        # so the L2 penalty may be inert — confirm intent.
        one_hot_labels = tf.reshape(tf.one_hot(labels['y'], depth=NUM_CLASSES),
                                    [-1, NUM_CLASSES])
        loss = tf.losses.softmax_cross_entropy(
            onehot_labels=one_hot_labels, logits=model_output_ops['logits'])

        global_step = tf.train.get_global_step()

        # Learning rate
        if LR_MODE == 'eLR':
            # Exponential decay: learning_rate * decay_rate ^ floor(global_step / decay_steps)
            learning_rate = tf.train.exponential_decay(
                eLR_INITIAL,
                global_step,
                decay_steps=eLRDECAY_STEPS,
                decay_rate=eLRDECAY_RATE,
                staircase=True)
        elif LR_MODE == 'CLR':
            # Cyclic learning rate
            # >> cycle = floor( 1 + global_step / ( 2 * step_size ) )
            # >>     x = abs( global_step / step_size - 2 * cycle + 1 )
            # >>   clr = learning_rate + ( max_lr - learning_rate ) * max( 0 , 1 - x )
            learning_rate = cyclic_learning_rate(global_step=global_step,
                                                 learning_rate=CLR_MINLR,
                                                 max_lr=CLR_MAXLR,
                                                 step_size=CLR_STEPSIZE,
                                                 gamma=CLR_GAMMA,
                                                 mode=CLR_MODE)
        else:
            raise ValueError("Unsupported LR_MODE: %r" % (LR_MODE,))

        # Optimizer (every choice is wrapped in a TowerOptimizer)
        opt_name = params["opt"]
        if opt_name == 'adam':
            base_optimiser = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                                    epsilon=1e-5)
        elif opt_name == 'momentum':
            base_optimiser = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                                        momentum=0.9)
        elif opt_name == 'rmsprop':
            base_optimiser = tf.train.RMSPropOptimizer(learning_rate=learning_rate,
                                                       momentum=0.9)
        else:
            raise ValueError("Unsupported optimizer: %r" % (opt_name,))
        optimiser = tf.contrib.estimator.TowerOptimizer(base_optimiser)

        # Run UPDATE_OPS (e.g. batch-norm statistics) before each training step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimiser.minimize(loss, global_step=global_step)

        # Custom image summaries (TensorBoard): one 2-D slice of the first patch.
        my_image_summaries = {}
        my_image_summaries['CT_Patch'] = features['x'][0, 32, :, :, 0]

        expected_output_size = [1, PATCH, PATCH, 1]  # [B, W, H, C]
        for name, image in my_image_summaries.items():
            tf.summary.image(name, tf.reshape(image, expected_output_size))

        # Track metrics. The "auc" entry previously recomputed precision.
        # NOTE(review): tf.metrics.auc expects predictions in [0, 1]; feeding
        # hard predictions 'y_' assumes binary labels, and probabilities would
        # give a finer-grained AUC if the model exposes them — confirm.
        eval_metric_ops = {
            "accuracy": tf.metrics.accuracy(labels['y'], model_output_ops['y_']),
            "precision": tf.metrics.precision(labels['y'], model_output_ops['y_']),
            "auc": tf.metrics.auc(labels['y'], model_output_ops['y_'])
        }

        # Return EstimatorSpec object
        return tf.estimator.EstimatorSpec(mode=mode,
                                          predictions=model_output_ops,
                                          loss=loss,
                                          train_op=train_op,
                                          eval_metric_ops=eval_metric_ops)
コード例 #8
0
ファイル: test_train_lv.py プロジェクト: ahsanshahenshah/lba
def train_test(FLAGS):
    """Train the semi-supervised model described by FLAGS and return the loss.

    Loads labeled images from ``tools.<FLAGS.dataset>``. Unlabeled images come
    from ``tools.<FLAGS.target_dataset>`` when set, otherwise from the source
    dataset's 'unlabeled' split. Both sets are subsampled per FLAGS. The
    model is built on a fresh ``tf.Graph``: a logit loss on labeled
    embeddings, plus ``model.add_semisup_loss`` on labeled/unlabeled
    embeddings when ``FLAGS.unsup_samples != 0``. Training then runs with
    ``slim.learning.train``.

    Args:
        FLAGS: parsed flag object; all dataset, sampling, loss-envelope,
            learning-rate, logging and checkpoint settings are read from it.

    Returns:
        The value returned by ``slim.learning.train`` (presumably the final
        training loss — confirm against the slim documentation).
    """

    dataset_tools = import_module('tools.' + FLAGS.dataset)
    train_images, train_labels = dataset_tools.get_data('train')
    if FLAGS.target_dataset is not None:
        target_dataset_tools = import_module('tools.' + FLAGS.target_dataset)
        train_images_unlabeled, train_images_label = target_dataset_tools.get_data(
            FLAGS.target_dataset_split)
    else:
        train_images_unlabeled, train_images_label = dataset_tools.get_data(
            'unlabeled')

    architecture = getattr(semisup.architectures, FLAGS.architecture)

    num_labels = dataset_tools.NUM_LABELS
    image_shape = dataset_tools.IMAGE_SHAPE

    # Sample labeled training subset (sup_per_class per class; sup_seed == -1 means unseeded).
    seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else None
    sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                           FLAGS.sup_per_class, num_labels,
                                           seed)

    # Sample unlabeled training subset (unsup_samples == -1 keeps all of them).
    if FLAGS.unsup_samples > -1:
        num_unlabeled = len(train_images_unlabeled)
        assert FLAGS.unsup_samples <= num_unlabeled, (
            'Chose more unlabeled samples ({})'
            ' than there are in the '
            'unlabeled batch ({}).'.format(FLAGS.unsup_samples, num_unlabeled))
        # TODO: make sample selections per class (marked done)
        #unsup_by_label = semisup.sample_by_label(train_images_unlabeled, train_images_label,
        #                                       FLAGS.unsup_samples/num_labels+num_labels, num_labels,
        #                                       seed)

        # Draw a fixed-size subset without replacement, using the same seed as above.
        rng = np.random.RandomState(seed=seed)
        train_images_unlabeled = train_images_unlabeled[rng.choice(
            num_unlabeled, FLAGS.unsup_samples, False)]

    graph = tf.Graph()
    with graph.as_default():
        with tf.device(
                tf.train.replica_device_setter(FLAGS.ps_tasks,
                                               merge_devices=True)):

            # Set up inputs.
            t_unsup_images = semisup.create_input(train_images_unlabeled, None,
                                                  FLAGS.unsup_batch_size)
            t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
                sup_by_label, FLAGS.sup_per_batch)

            #print(t_sup_images.shape)
            #with tf.Session() as sess: print (t_sup_images.eval().shape)
            # Optionally keep only the leading rows of the labeled batch; the size is
            # sup_per_batch * (num_labels - remove_classes).
            if FLAGS.remove_classes:
                t_sup_images = tf.slice(t_sup_images, [
                    0, 0, 0, 0
                ], [FLAGS.sup_per_batch *
                    (num_labels - FLAGS.remove_classes)] + image_shape)

            # Resize if necessary.
            if FLAGS.new_size > 0:
                new_shape = [FLAGS.new_size, FLAGS.new_size, image_shape[-1]]
            else:
                new_shape = None

            # Apply augmentation
            if FLAGS.augmentation:
                # TODO(haeusser) generalize augmentation
                # Despite its name, this does not invert: it jitters brightness,
                # contrast, hue and saturation. With probability 0.5 it returns
                # the jittered image, otherwise the original (both via tf.abs).
                # NOTE(review): f2 returns the un-cast input while f1 returns
                # float32; tf.cond needs matching dtypes, so this presumably
                # assumes inputs1 is already float32 — confirm.
                def _random_invert(inputs1, _):
                    inputs = tf.cast(inputs1, tf.float32)
                    inputs = tf.image.adjust_brightness(
                        inputs, tf.random_uniform((1, 1), 0.0, 0.5))
                    inputs = tf.image.random_contrast(inputs, 0.3, 1)
                    # inputs = tf.image.per_image_standardization(inputs)
                    inputs = tf.image.random_hue(inputs, 0.05)
                    inputs = tf.image.random_saturation(inputs, 0.5, 1.1)

                    def f1():
                        return tf.abs(inputs)  #annotations

                    def f2():
                        return tf.abs(inputs1)

                    return tf.cond(tf.less(tf.random_uniform([], 0.0, 1), 0.5),
                                   f1, f2)

                augmentation_function = _random_invert
            else:
                augmentation_function = None

            # Create function that defines the network.
            model_function = partial(
                architecture,
                new_shape=new_shape,
                img_shape=image_shape,
                augmentation_function=augmentation_function,
                batch_norm_decay=FLAGS.batch_norm_decay,
                emb_size=FLAGS.emb_size)

            # Set up semisup model.
            model = semisup.SemisupModel(model_function, num_labels,
                                         image_shape)

            # Compute embeddings and logits.
            t_sup_emb = model.image_to_embedding(t_sup_images)

            t_sup_logit = model.embedding_to_logit(t_sup_emb)

            # Add losses.
            if FLAGS.unsup_samples != 0:
                t_unsup_emb = model.image_to_embedding(t_unsup_images)
                # Visit-weight envelope settings fall back to the walker settings when -1.
                visit_weight_envelope_steps = (
                    FLAGS.walker_weight_envelope_steps
                    if FLAGS.visit_weight_envelope_steps == -1 else
                    FLAGS.visit_weight_envelope_steps)
                visit_weight_envelope_delay = (
                    FLAGS.walker_weight_envelope_delay
                    if FLAGS.visit_weight_envelope_delay == -1 else
                    FLAGS.visit_weight_envelope_delay)
                visit_weight = apply_envelope(
                    type=FLAGS.visit_weight_envelope,
                    step=model.step,
                    final_weight=FLAGS.visit_weight,
                    growing_steps=visit_weight_envelope_steps,
                    delay=visit_weight_envelope_delay)
                walker_weight = apply_envelope(
                    type=FLAGS.walker_weight_envelope,
                    step=model.step,
                    final_weight=FLAGS.walker_weight,
                    growing_steps=FLAGS.walker_weight_envelope_steps,  # pylint:disable=line-too-long
                    delay=FLAGS.walker_weight_envelope_delay)
                tf.summary.scalar('Weights_Visit', visit_weight)
                tf.summary.scalar('Weights_Walker', walker_weight)

                model.add_semisup_loss(t_sup_emb,
                                       t_unsup_emb,
                                       t_sup_labels,
                                       visit_weight=visit_weight,
                                       walker_weight=walker_weight)

            model.add_logit_loss(t_sup_logit,
                                 t_sup_labels,
                                 weight=FLAGS.logit_weight)

            # Set up learning rate:
            #   None   -> staircase exponential decay, floored at minimum_learning_rate
            #   'exp2' -> elementwise max of an exp_range cyclic schedule (up to
            #             maximum_learning_rate) and a triangular one (up to learning_rate)
            #   other  -> triangular cyclic schedule, floored at minimum_learning_rate
            # NOTE(review): `gamma` is presumably ignored in 'triangular' mode — confirm.
            if FLAGS.learning_rate_type is None:
                t_learning_rate = tf.maximum(
                    tf.train.exponential_decay(FLAGS.learning_rate,
                                               model.step,
                                               FLAGS.decay_steps,
                                               FLAGS.decay_factor,
                                               staircase=True),
                    FLAGS.minimum_learning_rate)
            elif FLAGS.learning_rate_type == 'exp2':
                t_learning_rate = tf.maximum(
                    cyclic_learning_rate(model.step,
                                         FLAGS.minimum_learning_rate,
                                         FLAGS.maximum_learning_rate,
                                         FLAGS.learning_rate_cycle_step,
                                         mode='exp_range',
                                         gamma=0.9999),
                    cyclic_learning_rate(model.step,
                                         FLAGS.minimum_learning_rate,
                                         FLAGS.learning_rate,
                                         FLAGS.learning_rate_cycle_step,
                                         mode='triangular',
                                         gamma=0.9994))

            else:
                t_learning_rate = tf.maximum(
                    cyclic_learning_rate(model.step,
                                         FLAGS.minimum_learning_rate,
                                         FLAGS.learning_rate,
                                         FLAGS.learning_rate_cycle_step,
                                         mode='triangular',
                                         gamma=0.9994),
                    FLAGS.minimum_learning_rate)

            # Create training operation and start the actual training loop.
            train_op = model.create_train_op(t_learning_rate)

            # Let the GPU allocator grow on demand instead of reserving all memory.
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            # config.log_device_placement = True

            saver = tf_saver.Saver(max_to_keep=FLAGS.max_checkpoints,
                                   keep_checkpoint_every_n_hours=FLAGS.
                                   keep_checkpoint_every_n_hours)  # pylint:disable=line-too-long

            # Task 0 is the chief; other tasks start with a delay of 20 steps per task index.
            final_loss = slim.learning.train(
                train_op,
                logdir=FLAGS.logdir + '/train',
                save_summaries_secs=FLAGS.save_summaries_secs,
                save_interval_secs=FLAGS.save_interval_secs,
                master=FLAGS.master,
                is_chief=(FLAGS.task == 0),
                startup_delay_steps=(FLAGS.task * 20),
                log_every_n_steps=FLAGS.log_every_n_steps,
                session_config=config,
                trace_every_n_steps=1000,
                saver=saver,
                number_of_steps=FLAGS.max_steps,
                #session_wrapper=tf_debug.LocalCLIDebugWrapperSession
            )

            print(final_loss)
    return final_loss