def main():
    """Create the model and start the training.
  """
    # Read CL arguments and snapshot the arguments into text file.
    args = get_arguments()
    utils.general.snapshot_arg(args)

    # The segmentation network is stride 8 by default.
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)
    innet_size = (int(math.ceil(h / 8)), int(math.ceil(w / 8)))
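    # e.g. a 480x480 input yields a 60x60 network output at stride 8.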

    # Initialize the random seed.
    tf.set_random_seed(args.random_seed)

    # Create queue coordinator.
    coord = tf.train.Coordinator()

    # Current training step.
    step_ph = tf.placeholder(dtype=tf.float32, shape=())

    # Load the reader.
    with tf.device('/cpu:0'):
        with tf.name_scope('create_inputs'):
            reader = ImageReader(args.data_dir, args.data_list, input_size,
                                 args.random_scale, args.random_mirror,
                                 args.random_crop, args.ignore_label, IMG_MEAN)

            image_batch, label_batch = reader.dequeue(args.batch_size)

    # Allocate data evenly to each gpu.
    images_mgpu = nn_mgpu.split(image_batch, args.num_gpu)
    labels_mgpu = nn_mgpu.split(label_batch, args.num_gpu)

    # Create network and output predictions.
    outputs_mgpu = model(images_mgpu, args.num_classes, args.is_training,
                         args.use_global_status)

    # Grab variable names which should be restored from checkpoints.
    restore_var = [
        v for v in tf.global_variables()
        if 'block5' not in v.name or not args.not_restore_classifier
    ]

    # Collect losses from each gpu.
    mean_losses = []
    mean_l2_losses = []
    for outputs, lab in zip(outputs_mgpu, labels_mgpu):
        with tf.device(lab.device):
            # Shrink labels to the size of the network output.
            lab = tf.cast(lab, dtype=tf.float32)
            lab = tf.image.resize_nearest_neighbor(lab,
                                                   innet_size,
                                                   name='label_shrink')
            lab = tf.reshape(lab, [-1])

            # Ignore pixels whose label is not a valid class id
            # (i.e. >= args.num_classes, such as the ignore label).
            not_ignore_pixel = tf.less_equal(lab, args.num_classes - 1)

            # Extract the indices of pixels where gradients are propagated.
            pixel_inds = tf.squeeze(tf.where(not_ignore_pixel), 1)
            lab_gather = tf.to_int32(tf.gather(lab, pixel_inds))

            # Define softmax loss.
            for i, out in enumerate(outputs):
                # Get mini-batch size on each GPU device.
                n = out.get_shape().as_list()[0]

                # Flatten predictions.
                out = tf.reshape(out, [-1, args.num_classes])
                out_gather = tf.gather(out, pixel_inds)
                loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=out_gather, labels=lab_gather)
                loss = tf.reduce_mean(loss)
                loss *= float(n) / float(args.batch_size)
                mean_losses.append(loss)

            # Define weight regularization loss.
            w = args.weight_decay
            l2_losses = [
                w * tf.nn.l2_loss(v) for v in tf.trainable_variables()
                if 'weights' in v.name
            ]
            l2_loss = tf.add_n(l2_losses) / float(args.num_gpu)
            mean_l2_losses.append(l2_loss)

    # Sum all loss terms.
    mean_seg_loss = tf.add_n(mean_losses)
    mean_l2_loss = tf.add_n(mean_l2_losses)
    reduced_loss = mean_seg_loss + mean_l2_loss

    # Grab variable names which are used for training.
    all_trainable = tf.trainable_variables()
    fc_trainable = [v for v in all_trainable if 'block5' in v.name]  # lr*10
    base_trainable = [v for v in all_trainable
                      if 'block5' not in v.name]  # lr*1

    # Compute gradients per iteration.
    grads = tf.gradients(reduced_loss,
                         base_trainable + fc_trainable,
                         colocate_gradients_with_ops=True)
    grads_base = grads[:len(base_trainable)]
    grads_fc = grads[len(base_trainable):]

    # Define optimisation parameters.
    base_lr = tf.constant(args.learning_rate)
    learning_rate = tf.scalar_mul(
        base_lr, tf.pow((1 - step_ph / args.num_steps), args.power))
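    # Polynomial decay: lr = base_lr * (1 - step/num_steps)**power; e.g. with
    # power = 0.9 the rate is ~0.54 * base_lr halfway through and 0 at the end.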

    opt_base = tf.train.MomentumOptimizer(learning_rate * 1.0, args.momentum)
    opt_fc = tf.train.MomentumOptimizer(learning_rate * 10.0, args.momentum)

    # Define tensorflow operations which apply gradients to update variables.
    train_op_base = opt_base.apply_gradients(zip(grads_base, base_trainable))
    train_op_fc = opt_fc.apply_gradients(zip(grads_fc, fc_trainable))
    train_op = tf.group(train_op_base, train_op_fc)
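    # Both update ops run each step; the classifier ('block5') variables simply
    # see a 10x higher learning rate than the backbone.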

    # Process for visualisation.
    with tf.device('/cpu:0'):
        # Image summary for input image, ground-truth label and prediction.
        cat_output = tf.concat([o[-1] for o in outputs_mgpu], axis=0)
        output_vis = tf.image.resize_nearest_neighbor(
            cat_output, tf.shape(image_batch)[1:3])
        output_vis = tf.argmax(output_vis, axis=3)
        output_vis = tf.expand_dims(output_vis, axis=3)
        output_vis = tf.cast(output_vis, dtype=tf.uint8)

        labels_vis = tf.cast(label_batch, dtype=tf.uint8)

        in_summary = tf.py_func(utils.general.inv_preprocess,
                                [image_batch, IMG_MEAN], tf.uint8)
        gt_summary = tf.py_func(utils.general.decode_labels,
                                [labels_vis, args.num_classes], tf.uint8)
        out_summary = tf.py_func(utils.general.decode_labels,
                                 [output_vis, args.num_classes], tf.uint8)
        # Concatenate image summaries in a row.
        total_summary = tf.summary.image(
            'images',
            tf.concat(axis=2, values=[in_summary, gt_summary, out_summary]),
            max_outputs=args.batch_size)

        # Scalar summary for different loss terms.
        seg_loss_summary = tf.summary.scalar('seg_loss', mean_seg_loss)
        total_summary = tf.summary.merge_all()

        summary_writer = tf.summary.FileWriter(args.snapshot_dir,
                                               graph=tf.get_default_graph())

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()

    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

    # Load variables if the checkpoint is provided.
    if args.restore_from is not None:
        loader = tf.train.Saver(var_list=restore_var)
        load(loader, sess, args.restore_from)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    # Iterate over training steps.
    pbar = tqdm(range(args.num_steps))
    for step in pbar:
        start_time = time.time()
        feed_dict = {step_ph: step}

        step_loss = 0
        for it in range(args.iter_size):
            # Update summary periodically.
            if it == args.iter_size - 1 and step % args.update_tb_every == 0:
                sess_outs = [reduced_loss, total_summary, train_op]
                loss_value, summary, _ = sess.run(sess_outs,
                                                  feed_dict=feed_dict)
                summary_writer.add_summary(summary, step)
            else:
                sess_outs = [reduced_loss, train_op]
                loss_value, _ = sess.run(sess_outs, feed_dict=feed_dict)

            step_loss += loss_value

        step_loss /= args.iter_size

        lr = sess.run(learning_rate, feed_dict=feed_dict)

        # Save trained model periodically.
        if step % args.save_pred_every == 0 and step > 0:
            save(saver, sess, args.snapshot_dir, step)

        duration = time.time() - start_time
        desc = 'loss = {:.3f}, lr = {:.6f}'.format(step_loss, lr)
        pbar.set_description(desc)

    coord.request_stop()
    coord.join(threads)
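

# A minimal, self-contained sketch (an assumed example, not part of the
# original script) of the ignore-label masking used above: flatten the labels,
# keep only pixels whose label is a valid class id, and gather the logits at
# those indices.
def _masked_gather_example(num_classes=3, ignore_label=255):
    """Illustrative only; mirrors the tf.less_equal/tf.where/tf.gather above."""
    import numpy as np
    labels = np.array([0, 2, ignore_label, 1])
    logits = np.zeros((labels.size, num_classes))
    keep = np.where(labels <= num_classes - 1)[0]  # drop the ignore pixel
    return logits[keep], labels[keep]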


def main():
    """Create the model and start the training."""

    # Read CL arguments and snapshot the arguments into text file.
    args = get_arguments()
    utils.general.snapshot_arg(args)

    # The segmentation network is stride 8 by default.
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)
    innet_size = (int(math.ceil(h / 8)), int(math.ceil(w / 8)))

    # Initialize the random seed.
    tf.set_random_seed(args.random_seed)

    # Create queue coordinator.
    coord = tf.train.Coordinator()

    # Current training step.
    step_ph = tf.placeholder(dtype=tf.float32, shape=())

    # Load the reader.
    with tf.device('/cpu:0'):
        with tf.name_scope('create_inputs'):
            reader = SegSortImageReader(args.data_dir, args.data_list,
                                        input_size, args.random_scale,
                                        args.random_mirror, args.random_crop,
                                        args.ignore_label, IMG_MEAN,
                                        args.num_clusters)

            image_batch, label_batch, cluster_label_batch, loc_feature_batch = (
                reader.dequeue(args.batch_size))

    # Allocate data evenly to all but the last gpu (which computes the loss).
    images_mgpu = nn_mgpu.split(image_batch, args.num_gpu - 1)

    # Create network and output predictions.
    # `model` resolves to pspnet_resnet101.
    outputs_mgpu = model(images_mgpu, args.embedding_dim, args.is_training,
                         args.use_global_status)

    # Grab variable names which should be restored from checkpoints.
    # 'block5' holds the classifier head; it is skipped when
    # args.not_restore_classifier is set.
    restore_var = [
        v for v in tf.global_variables()
        if 'block5' not in v.name or not args.not_restore_classifier
    ]

    # Shrink labels to the size of the network output.
    labels = tf.image.resize_nearest_neighbor(label_batch,
                                              innet_size,
                                              name='label_shrink')
    # Cluster labels partition each image into regions for the SegSort loss.
    cluster_labels = tf.image.resize_nearest_neighbor(cluster_label_batch,
                                                      innet_size)
    loc_features = tf.image.resize_nearest_neighbor(loc_feature_batch,
                                                    innet_size)

    # Collect embedding from each gpu.
    # The last gpu is reserved for computing the loss.
    with tf.device('/gpu:{:d}'.format(args.num_gpu - 1)):
        embedding_list = [outputs[0] for outputs in outputs_mgpu]
        embedding = tf.concat(embedding_list, axis=0)

        # Define SegSort loss.
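        # (Assumed summary) SegSort clusters pixel embeddings into segments via
        # spherical k-means and applies a softmax loss over segment prototypes;
        # see train_utils.add_segsort_loss for the actual implementation.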
        seg_losses = train_utils.add_segsort_loss(
            embedding, labels, args.embedding_dim, args.ignore_label,
            args.concentration, cluster_labels, args.num_clusters,
            args.kmeans_iterations, args.num_banks, loc_features)

        # Define weight regularization loss.
        w = args.weight_decay
        l2_losses = [
            w * tf.nn.l2_loss(v) for v in tf.trainable_variables()
            if 'weights' in v.name
        ]
        mean_l2_loss = tf.add_n(l2_losses)

        # Sum all loss terms.
        mean_seg_loss = seg_losses
        reduced_loss = mean_seg_loss + mean_l2_loss

    # Grab variable names which are used for training.
    # 'block5' is the PSPNet head; it is trained with 10x the base rate below.
    all_trainable = tf.trainable_variables()
    fc_trainable = [v for v in all_trainable if 'block5' in v.name]  # lr*10
    base_trainable = [v for v in all_trainable
                      if 'block5' not in v.name]  # lr*1

    # Compute gradients per iteration.
    grads = tf.gradients(reduced_loss,
                         base_trainable + fc_trainable,
                         colocate_gradients_with_ops=True)
    grads_base = grads[:len(base_trainable)]
    grads_fc = grads[len(base_trainable):]

    # Define optimisation parameters.
    base_lr = tf.constant(args.learning_rate)
    learning_rate = tf.scalar_mul(
        base_lr, tf.pow((1 - step_ph / args.num_steps), args.power))

    opt_base = tf.train.MomentumOptimizer(learning_rate * 1.0, args.momentum)
    opt_fc = tf.train.MomentumOptimizer(learning_rate * 10.0, args.momentum)

    # Define tensorflow operations which apply gradients to update variables.
    train_op_base = opt_base.apply_gradients(zip(grads_base, base_trainable))
    train_op_fc = opt_fc.apply_gradients(zip(grads_fc, fc_trainable))
    train_op = tf.group(train_op_base, train_op_fc)

    # Process for visualisation.
    with tf.device('/cpu:0'):
        # Image summary for input image, ground-truth label and prediction.
        cat_output = tf.concat([o[-1] for o in outputs_mgpu], axis=0)
        output_vis = tf.image.resize_nearest_neighbor(
            cat_output, tf.shape(image_batch)[1:3])
        output_vis = tf.argmax(output_vis, axis=3)
        output_vis = tf.expand_dims(output_vis, axis=3)
        output_vis = tf.cast(output_vis, dtype=tf.uint8)

        labels_vis = tf.cast(label_batch, dtype=tf.uint8)

        in_summary = tf.py_func(utils.general.inv_preprocess,
                                [image_batch, IMG_MEAN], tf.uint8)
        gt_summary = tf.py_func(utils.general.decode_labels,
                                [labels_vis, args.num_classes], tf.uint8)
        out_summary = tf.py_func(utils.general.decode_labels,
                                 [output_vis, args.num_classes], tf.uint8)
        # Concatenate image summaries in a row.
        total_summary = tf.summary.image(
            'images',
            tf.concat(axis=2, values=[in_summary, gt_summary, out_summary]),
            max_outputs=args.batch_size)

        # Scalar summary for different loss terms.
        seg_loss_summary = tf.summary.scalar('seg_loss', mean_seg_loss)
        total_summary = tf.summary.merge_all()

        summary_writer = tf.summary.FileWriter(args.snapshot_dir,
                                               graph=tf.get_default_graph())

    # Set up tf session and initialize variables.
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()

    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

    # Load variables if the checkpoint is provided.
    if args.restore_from is not None:
        loader = tf.train.Saver(var_list=restore_var)
        load(loader, sess, args.restore_from)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    # Iterate over training steps.
    pbar = tqdm(range(args.num_steps))
    for step in pbar:
        start_time = time.time()
        feed_dict = {step_ph: step}

        step_loss = 0
        for it in range(args.iter_size):
            # Update summary periodically.
            if it == args.iter_size - 1 and step % args.update_tb_every == 0:
                sess_outs = [reduced_loss, total_summary, train_op]
                loss_value, summary, _ = sess.run(sess_outs,
                                                  feed_dict=feed_dict)
                summary_writer.add_summary(summary, step)
            else:
                sess_outs = [reduced_loss, train_op]
                loss_value, _ = sess.run(sess_outs, feed_dict=feed_dict)

            step_loss += loss_value

        step_loss /= args.iter_size

        lr = sess.run(learning_rate, feed_dict=feed_dict)

        # Save trained model periodically.
        if step % args.save_pred_every == 0 and step > 0:
            save(saver, sess, args.snapshot_dir, step)

        duration = time.time() - start_time
        desc = 'loss = {:.3f}, lr = {:.6f}'.format(step_loss, lr)
        pbar.set_description(desc)

    coord.request_stop()
    coord.join(threads)
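

# A minimal sketch of multi-gpu batch splitting (an assumption about how
# nn_mgpu.split behaves: an even split along the batch axis).
def _split_batch_example(batch, num_gpu):
    """Illustrative only; not the project's nn_mgpu implementation."""
    import tensorflow as tf
    return tf.split(batch, num_or_size_splits=num_gpu, axis=0)

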
def main():
    """Create the model and start the training."""
    print("IMG_NET")

    # Read CL arguments and snapshot the arguments into text file.
    args = get_arguments()
    utils.general.snapshot_arg(args)

    # The segmentation network is stride 8 by default.
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)
    innet_size = (int(math.ceil(h / 8)), int(math.ceil(w / 8)))

    # Initialize the random seed.
    tf.set_random_seed(args.random_seed)

    # Create queue coordinator.
    coord = tf.train.Coordinator()

    # Current training step.
    step_ph = tf.placeholder(dtype=tf.float32, shape=())

    reader = ImageNetReader(os.path.join(args.data_dir, "train"),
                            args.batch_size, h, 10, False)

    # ImageNet train split: 1,281,167 images.
    # 20019 batches * 64 = 1,281,216 (the last batch is partial).

    # Set up input placeholders.
    image_batch = tf.placeholder(tf.float32, [args.batch_size, w, h, 3])
    labels_batch = tf.placeholder(tf.int32, [args.batch_size])
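    # NB: TF image tensors are [batch, height, width, channels]; the (w, h)
    # order above only matches if the crops are square (h == w).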

    # Allocate data evenly to each gpu.
    images_mgpu = nn_mgpu.split(image_batch, args.num_gpu)  # all gpus get data

    # Create network and output predictions.
    # `model` resolves to pspnet_resnet101, run with is_training=False.
    outputs_mgpu = model(images_mgpu, args.embedding_dim, False,
                         args.use_global_status)

    # Grab variable names which should be restored from checkpoints.
    # TODO: double-check this restore list.
    restore_var = [
        v for v in tf.global_variables() if 'crop_image_batch' not in v.name
    ]

    # Collect embedding from each gpu.
    # The last gpu also hosts the classifier head and loss below.
    with tf.device('/gpu:{:d}'.format(args.num_gpu - 1)):
        embedding_list = [outputs[0] for outputs in outputs_mgpu]
        # Shape: [batch, input/8, input/8, embedding_dim].
        embedding = tf.concat(embedding_list, axis=0)

        with tf.variable_scope("imagenet_classify"):
            # TODO: replace the conv stack below with a single max/average
            # pooling, 30x30xemb -> 1x1xemb (kernel 30, stride 30, 'valid').
            conv_1 = tf.layers.conv2d(
                inputs=embedding,
                filters=args.embedding_dim,
                kernel_size=5,
                strides=(2, 2),
                padding="same",
                activation=tf.nn.relu)  # [batch, 30, 30, emb_size]
            conv_2 = tf.layers.conv2d(
                inputs=conv_1,
                filters=args.embedding_dim,
                kernel_size=5,
                strides=(2, 2),
                padding="same",
                activation=tf.nn.relu)  # [batch, 15, 15, emb_size]

            y_out = tf.layers.flatten(conv_2)
            y_out = tf.layers.dense(y_out, args.num_classes)  # raw logits

        # TODO: try weight decay on the classifier weights later.
        # Weight regularization loss (currently disabled):
        # w = args.weight_decay
        # l2_losses = [w * tf.nn.l2_loss(v) for v in tf.trainable_variables()
        #              if 'weights' in v.name]
        # mean_l2_loss = tf.add_n(l2_losses)

        # Define loss terms.
        # NB: only the classifier head is trained here; the embedding network
        # itself is frozen.
        classify_loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(
                logits=y_out,
                labels=tf.one_hot(labels_batch, args.num_classes)))
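        # Equivalent to tf.nn.sparse_softmax_cross_entropy_with_logits applied
        # to the integer labels_batch directly.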
        # TODO: check how CPC implements its linear classifier.

        # mean_seg_loss = seg_losses
        # reduced_loss = mean_seg_loss + mean_l2_loss

    # Training accuracy over the current batch.
    interim = tf.cast(
        tf.equal(tf.cast(tf.argmax(y_out, axis=1), tf.int32), labels_batch),
        tf.float32)
    train_acc = tf.reduce_mean(interim)

    # Grab the classifier-head variables; only these are trained below.
    imgnet_trainable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                         'imagenet_classify')

    # Define optimisation parameters.
    base_lr = tf.constant(args.learning_rate)
    learning_rate = tf.scalar_mul(
        base_lr, tf.pow((1 - step_ph / args.num_steps), args.power))

    opt_imgnet = tf.train.MomentumOptimizer(learning_rate, args.momentum)
    # TODO: check the original training script; consider alternative learning
    # rate policies (step decay, exponential decay).

    # Define tensorflow train op to minimize loss over the classifier head only.
    train_op = opt_imgnet.minimize(classify_loss, var_list=imgnet_trainable)

    # Process for visualisation.
    # with tf.device('/cpu:0'):
    #   # Image summary for input image, ground-truth label and prediction.
    #   cat_output = tf.concat([o[-1] for o in outputs_mgpu], axis=0)
    #   output_vis = tf.image.resize_nearest_neighbor(
    #       cat_output, tf.shape(image_batch)[1:3,])
    #   output_vis = tf.argmax(output_vis, axis=3)
    #   output_vis = tf.expand_dims(output_vis, dim=3)
    #   output_vis = tf.cast(output_vis, dtype=tf.uint8)

    #   labels_vis = tf.cast(label_batch, dtype=tf.uint8)

    #   in_summary = tf.py_func(
    #       utils.general.inv_preprocess,
    #       [image_batch, IMG_MEAN],
    #       tf.uint8)
    #   gt_summary = tf.py_func(
    #       utils.general.decode_labels,
    #       [labels_vis, args.num_classes],
    #       tf.uint8)
    #   out_summary = tf.py_func(
    #       utils.general.decode_labels,
    #       [output_vis, args.num_classes],
    #       tf.uint8)
    #   # Concatenate image summaries in a row.
    #   total_summary = tf.summary.image(
    #       'images',
    #       tf.concat(axis=2, values=[in_summary, gt_summary, out_summary]),
    #       max_outputs=args.batch_size)

    #   # Scalar summary for different loss terms.
    #   seg_loss_summary = tf.summary.scalar(
    #       'seg_loss', mean_seg_loss)
    #   total_summary = tf.summary.merge_all()

    #   summary_writer = tf.summary.FileWriter(
    #       args.snapshot_dir,
    #       graph=tf.get_default_graph())

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()

    sess.run(init)
    sess.run(tf.local_variables_initializer())

    # Load weights.
    loader = tf.train.Saver(var_list=restore_var)
    if args.restore_from is not None:
        load(loader, sess, args.restore_from)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

    # Iterate over training steps.
    pbar = tqdm(range(args.num_steps))
    for step in pbar:
        start_time = time.time()

        # TODO: verify that the reader's dequeue does not run dry.
        img_np, labels_truth = reader.dequeue()

        feed_dict = {
            step_ph: step,
            image_batch: img_np,
            labels_batch: labels_truth
        }

        step_loss = 0
        for it in range(args.iter_size):
            # Update summary periodically (summary writing currently disabled).
            if it == args.iter_size - 1 and step % args.update_tb_every == 0:
                sess_outs = [classify_loss, train_op]
                loss_value, _ = sess.run(sess_outs, feed_dict=feed_dict)
                # summary_writer.add_summary(summary, step)
            else:
                sess_outs = [classify_loss, train_op, train_acc]
                loss_value, _, train_acc_v = sess.run(sess_outs,
                                                      feed_dict=feed_dict)
                print('train_acc:', train_acc_v)

            step_loss += loss_value

        step_loss /= args.iter_size

        lr = sess.run(learning_rate, feed_dict=feed_dict)

        # Save trained model periodically.
        if step % args.save_pred_every == 0 and step > 0:
            save(saver, sess, args.snapshot_dir, step)

        duration = time.time() - start_time
        desc = 'loss = {:.3f}, lr = {:.6f}'.format(step_loss, lr)
        pbar.set_description(desc)

    coord.request_stop()
    coord.join(threads)
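

# A minimal sketch (assumed, following the pooling TODO in the classifier head
# above) of the simpler probe: global average pooling over the spatial
# embedding map followed by a single dense layer producing class logits.
def _pooled_head_example(embedding, num_classes):
    """Illustrative only; not wired into the training graph above."""
    import tensorflow as tf
    pooled = tf.reduce_mean(embedding, axis=[1, 2])  # [batch, emb_dim]
    return tf.layers.dense(pooled, num_classes)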


def main():
    """Create the model and extract embeddings."""
    print("IMG_NET extracting embeddings")

    # Read CL arguments and snapshot the arguments into text file.
    args = get_arguments()
    utils.general.snapshot_arg(args)
    global curr_class_name

    # The segmentation network is stride 8 by default.
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)
    innet_size = (int(math.ceil(h / 8)), int(math.ceil(w / 8)))

    # Initialize the random seed.
    tf.set_random_seed(args.random_seed)

    # Create queue coordinator.
    coord = tf.train.Coordinator()

    # Current training step.
    step_ph = tf.placeholder(dtype=tf.float32, shape=())

    reader = ImageNetReader(os.path.join(args.data_dir, "train"),
                            args.batch_size, h, args.num_loading_workers,
                            False)

    # 20019 batches * 64 = 1,281,216 images (the last batch is partial).

    # Set up input placeholders.
    image_batch = tf.placeholder(tf.float32, [args.batch_size, w, h, 3])
    labels_batch = tf.placeholder(tf.int32, [args.batch_size])

    # Allocate data evenly to each gpu.
    images_mgpu = nn_mgpu.split(image_batch, args.num_gpu)

    # Create network and output predictions.
    # `model` resolves to pspnet_resnet101, run with is_training=False.
    outputs_mgpu = model(images_mgpu, args.embedding_dim, False,
                         args.use_global_status)

    # Grab variable names which should be restored from checkpoints.
    # TODO: double-check this restore list.
    restore_var = [
        v for v in tf.global_variables() if 'crop_image_batch' not in v.name
    ]

    # Collect embedding from each gpu.
    with tf.device('/gpu:{:d}'.format(args.num_gpu - 1)):
        embedding_list = [outputs[0] for outputs in outputs_mgpu]
        # Shape: [batch, input/8, input/8, embedding_dim].
        embedding = tf.concat(embedding_list, axis=0)

    # TODO: check how CPC implements its linear classifier and compare with
    # the original training script.

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()

    sess.run(init)
    sess.run(tf.local_variables_initializer())

    # Load weights.
    loader = tf.train.Saver(var_list=restore_var)
    if args.restore_from is not None:
        print("Loading restore:", args.restore_from)
        load(loader, sess, args.restore_from)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    # Iterate over training steps.
    pbar = tqdm(range(reader.num_batches))
    print('num batches:', reader.num_batches)  # ~1,281,167 images in total

    train_save_dir = os.path.join(args.save_dir, str(args.embedding_dim),
                                  "train")
    curr_class_name_tmp = curr_class_name
    for step in pbar:
        start_time = time.time()

        img_np, labels_truth = reader.dequeue()

        # Handle the final, short batch.
        last_batch_dim = None
        if img_np.shape[0] < args.batch_size:
            img_np, labels_truth, last_batch_dim = handle_last_batch(
                img_np, labels_truth, args.batch_size)

        timeA = time.time() - start_time
        start_time = time.time()

        emb_list = sess.run(embedding, feed_dict={image_batch: img_np})

        timeB = time.time() - start_time
        start_time = time.time()

        save_numpy_to_dir(train_save_dir, emb_list, labels_truth,
                          reader.get_idx_to_class(), last_batch_dim)

        timeC = time.time() - start_time
        start_time = time.time()

        if curr_class_name_tmp != curr_class_name:
            curr_class_name_tmp = curr_class_name
            print(curr_class_name)

        print('load/forward/save times:', timeA, timeB, timeC)

        duration = time.time() - start_time
        # desc = 'loss = {:.3f}, lr = {:.6f}'.format(step_loss, lr)
        # pbar.set_description(desc)

    coord.request_stop()
    coord.join(threads)
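

# A minimal sketch of assumed semantics for the handle_last_batch helper used
# above (defined elsewhere): pad a short final batch up to batch_size by
# repeating the last element, and return the original length so the padded
# entries can be dropped after inference.
def _handle_last_batch_example(img_np, labels, batch_size):
    import numpy as np
    last_batch_dim = img_np.shape[0]
    pad = batch_size - last_batch_dim
    img_np = np.concatenate([img_np] + [img_np[-1:]] * pad, axis=0)
    labels = np.concatenate([labels] + [labels[-1:]] * pad, axis=0)
    return img_np, labels, last_batch_dim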