Example #1
0
def train(conf):
    logger = util.Logger(conf)
    if not os.path.exists(conf.checkpoint_dir):
        os.makedirs(conf.checkpoint_dir)

    model_name = conf.model_name
    dataset_name = "ClassificationDataset"
    collate_name = "FastTextCollator" if model_name == "FastText" \
        else "ClassificationCollator"
    train_data_loader, validate_data_loader, test_data_loader = \
        get_data_loader(dataset_name, collate_name, conf)
    empty_dataset = globals()[dataset_name](conf, [], mode="train")
    model = get_classification_model(model_name, empty_dataset, conf)
    loss_fn = globals()["ClassificationLoss"](label_size=len(
        empty_dataset.label_map),
                                              loss_type=conf.train.loss_type)
    optimizer = get_optimizer(conf, model)
    evaluator = cEvaluator(conf.eval.dir)
    trainer = globals()["ClassificationTrainer"](empty_dataset.label_map,
                                                 logger, evaluator, conf,
                                                 loss_fn)

    best_epoch = -1
    best_performance = 0
    model_file_prefix = conf.checkpoint_dir + "/" + model_name
    for epoch in range(conf.train.start_epoch,
                       conf.train.start_epoch + conf.train.num_epochs):
        start_time = time.time()
        trainer.train(train_data_loader, model, optimizer, "Train", epoch)
        trainer.eval(train_data_loader, model, optimizer, "Train", epoch)
        performance = trainer.eval(validate_data_loader, model, optimizer,
                                   "Validate", epoch)
        trainer.eval(test_data_loader, model, optimizer, "Test", epoch)
        if performance > best_performance:  # record the best model
            best_epoch = epoch
            best_performance = performance
        save_checkpoint(
            {
                'epoch': epoch,
                'model_name': model_name,
                'state_dict': model.state_dict(),
                'best_performance': best_performance,
                'optimizer': optimizer.state_dict(),
            }, model_file_prefix)
        time_used = time.time() - start_time
        logger.info("Epoch %d cost time: %d second" % (epoch, time_used))

    # best model on validation set
    best_epoch_file_name = model_file_prefix + "_" + str(best_epoch)
    best_file_name = model_file_prefix + "_best"
    shutil.copyfile(best_epoch_file_name, best_file_name)

    load_checkpoint(model_file_prefix + "_" + str(best_epoch), conf, model,
                    optimizer)
    trainer.eval(test_data_loader, model, optimizer, "Best test", best_epoch)
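
# --- Helper sketch (not part of the example above) ---------------------------
# save_checkpoint / load_checkpoint are assumed by the training loop but not
# shown here. Below is a minimal sketch based only on how they are called
# above; the project's actual helpers may differ.
import torch

def save_checkpoint(state, file_prefix):
    # One file per epoch, e.g. "<prefix>_<epoch>", matching the best-epoch
    # lookup above.
    torch.save(state, file_prefix + "_" + str(state["epoch"]))

def load_checkpoint(file_path, conf, model, optimizer):
    # conf is unused in this sketch; kept only to match the call signature.
    checkpoint = torch.load(file_path, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    return checkpoint["epoch"], checkpoint["best_performance"]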
Example #2
0
def kfold_eval(conf):
    logger = util.Logger(conf)
    model_name = conf.model_name
    dataset_name = "ClassificationDataset"
    collate_name = "FastTextCollator" if model_name == "FastText" \
        else "ClassificationCollator"

    test_dataset = globals()[dataset_name](conf, conf.data.test_json_files)
    collate_fn = globals()[collate_name](conf, len(test_dataset.label_map))
    test_data_loader = DataLoader(
        test_dataset, batch_size=conf.eval.batch_size, shuffle=False,
        num_workers=conf.data.num_worker, collate_fn=collate_fn,
        pin_memory=True)

    empty_dataset = globals()[dataset_name](conf, [])
    model = get_classification_model(model_name, empty_dataset, conf)
    optimizer = get_optimizer(conf, model)
    load_checkpoint(conf.eval.model_dir, conf, model, optimizer)
    model.eval()
    predict_probs = []
    standard_labels = []
    evaluator = cEvaluator(conf.eval.dir)
    for batch in test_data_loader:
        with torch.no_grad():
            logits = model(batch)
        result = torch.sigmoid(logits).cpu().tolist()
        predict_probs.extend(result)
        standard_labels.extend(batch[ClassificationDataset.DOC_LABEL_LIST])

    # ==================== EVALUATION API ====================
    y_test, predictions = [], []

    print(standard_labels)
    for i, j in zip(standard_labels, predict_probs):
        y_test.append(i)
        predictions.append(j)

    pred, actual = take_values(predictions, y_test, conf.eval.threshold,
                               conf.eval.top_k)
    print(pred)
    actual = np.array(actual)
    pred = np.array(pred)

    evaluation_measures = {
        "Accuracy": accuracy(actual, pred),
        "Precision": precision(actual, pred),
        "Recall": recall(actual, pred),
        "F1 score": f1_scor(actual, pred),
        "Hamming Loss": hammingLoss(actual, pred),
        "f-1 Macro": macroF1(actual, pred),
        "f-1 Micro": microF1(actual, pred),
        "averagePrecision": averagePrecision(actual, pred)
    }
    return evaluation_measures
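
# --- Helper sketch (not part of the example above) ---------------------------
# take_values is not defined in this example. A minimal sketch of what the
# call above seems to assume: turn per-document probabilities into binary
# prediction vectors (threshold plus top_k) and label-id lists into multi-hot
# vectors. The project's actual helper may behave differently.
import numpy as np

def take_values(predictions, y_test, threshold, top_k):
    label_size = len(predictions[0])
    pred, actual = [], []
    for probs, label_ids in zip(predictions, y_test):
        probs = np.asarray(probs)
        pred_vec = (probs >= threshold).astype(int)
        # Also keep the top_k highest-scoring labels.
        pred_vec[np.argsort(-probs)[:top_k]] = 1
        actual_vec = np.zeros(label_size, dtype=int)
        actual_vec[list(label_ids)] = 1
        pred.append(pred_vec)
        actual.append(actual_vec)
    return pred, actual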
Example #3
0
def eval(conf):
    logger = util.Logger(conf)
    model_name = conf.model_name
    dataset_name = "ClassificationDataset"
    collate_name = "FastTextCollator" if model_name == "FastText" \
        else "ClassificationCollator"

    test_dataset = globals()[dataset_name](conf, conf.data.test_json_files)
    collate_fn = globals()[collate_name](conf, len(test_dataset.label_map))
    test_data_loader = DataLoader(test_dataset,
                                  batch_size=conf.eval.batch_size,
                                  shuffle=False,
                                  num_workers=conf.data.num_worker,
                                  collate_fn=collate_fn,
                                  pin_memory=True)

    empty_dataset = globals()[dataset_name](conf, [])
    model = get_classification_model(model_name, empty_dataset, conf)
    optimizer = get_optimizer(conf, model)
    load_checkpoint(conf.eval.model_dir, conf, model, optimizer)
    model.eval()
    is_multi = False
    if conf.task_info.label_type == ClassificationType.MULTI_LABEL:
        is_multi = True
    predict_probs = []
    standard_labels = []
    total_loss = 0.
    evaluator = cEvaluator(conf.eval.dir)
    for batch in test_data_loader:
        with torch.no_grad():
            logits = model(batch)
        if not is_multi:
            result = torch.nn.functional.softmax(logits, dim=1).cpu().tolist()
        else:
            result = torch.sigmoid(logits).cpu().tolist()
        predict_probs.extend(result)
        standard_labels.extend(batch[ClassificationDataset.DOC_LABEL_LIST])
    total_loss = total_loss / len(predict_probs)
    (_, precision_list, recall_list, fscore_list, right_list,
     predict_list, standard_list) = \
        evaluator.evaluate(
            predict_probs, standard_label_ids=standard_labels, label_map=empty_dataset.label_map,
            threshold=conf.eval.threshold, top_k=conf.eval.top_k,
            is_flat=conf.eval.is_flat, is_multi=is_multi)
    logger.warn(
        "Performance is precision: %f, "
        "recall: %f, fscore: %f, right: %d, predict: %d, standard: %d." %
        (precision_list[0][cEvaluator.MICRO_AVERAGE],
         recall_list[0][cEvaluator.MICRO_AVERAGE],
         fscore_list[0][cEvaluator.MICRO_AVERAGE],
         right_list[0][cEvaluator.MICRO_AVERAGE],
         predict_list[0][cEvaluator.MICRO_AVERAGE],
         standard_list[0][cEvaluator.MICRO_AVERAGE]))
    evaluator.save()
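
# --- Reference sketch (not part of the example above) -------------------------
# The log line above reports micro-averaged counts. For reference, the usual
# definitions (right = true positives, predict = predicted positives,
# standard = gold positives); this mirrors the standard formulas, not the
# evaluator's internal code.
def micro_prf(right, predict, standard):
    precision = right / predict if predict else 0.0
    recall = right / standard if standard else 0.0
    fscore = (2 * precision * recall / (precision + recall)
              if precision + recall else 0.0)
    return precision, recall, fscore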
Example #4
0
def train(conf):
    model_name = conf.model_name
    logger = util.Logger(conf)
    if conf.task_info.weak_pretrain:
        logger.info("Batch Size: " + str(conf.train.batch_size) +
                    " Pretrain Num Epoch: " +
                    str(conf.train.pretrain_num_epochs))
    else:
        logger.info("Batch Size: " + str(conf.train.batch_size))

    if conf.task_info.weak_pretrain and conf.task_info.weak_data_augmentation:
        model_teacher = get_classification_model(model_name, empty_dataset,
                                                 conf)
        if conf.model_name != "BERT":
            optimizer_teacher = get_optimizer(conf, model_teacher)
        else:
            optimizer_teacher = AdamW(model_teacher.parameters(),
                                      lr=5e-2,
                                      eps=1e-2)
        # optimizer_teacher: optimizer for teacher model

    model_target = get_classification_model(model_name, empty_dataset, conf)
    loss_fn = globals()["ClassificationLoss"](label_size=len(
        empty_dataset.label_map),
                                              loss_type=conf.train.loss_type)

    if conf.task_info.weak_pretrain:
        if conf.model_name != "BERT":
            optimizer_weak = get_optimizer(conf, model_target)
        else:
            optimizer_weak = AdamW(model_target.parameters(),
                                   lr=5e-2,
                                   eps=1e-2)
        # optimizer_weak: optimizer for target model pretraining stage
    if conf.model_name != "BERT":
        optimizer_target = get_optimizer(conf, model_target)
    else:
        optimizer_target = AdamW(model_target.parameters(), lr=5e-2, eps=1e-2)
    # optimizer_target: optimizer for target model fine-tuning stage
    evaluator = cEvaluator(conf.eval.dir)

    trainer_target = globals()["ClassificationTrainer"](
        empty_dataset.label_map, logger, evaluator, conf, loss_fn)
    # trainer_target: trainer for target model on fine-tuning stage
    if conf.task_info.weak_pretrain:
        trainer_weak = globals()["ClassificationTrainer"](
            empty_dataset.label_map, logger, evaluator, conf, loss_fn)
        # trainer_weak: trainer for target model on pretraining stage
        if conf.task_info.weak_data_augmentation:
            trainer_teacher = globals()["ClassificationTrainer"](
                empty_dataset.label_map, logger, evaluator, conf, loss_fn)
            # trainer_teacher: trainer for teacher model

    if conf.task_info.weak_data_augmentation:
        best_epoch = -1
        best_performance = 0
        model_file_prefix = conf.checkpoint_dir + "/" + model_name + "_teacher"

        logger.info("Training Teacher Model on Labeled Data")
        for epoch in range(conf.train.start_epoch,
                           conf.train.start_epoch + conf.train.num_epochs):
            start_time = time.time()
            trainer_teacher.train(train_data_loader, model_teacher,
                                  optimizer_teacher, "Train", epoch)
            trainer_teacher.eval(train_data_loader, model_teacher,
                                 optimizer_teacher, "Train", epoch)
            performance = trainer_teacher.eval(validate_data_loader,
                                               model_teacher,
                                               optimizer_teacher, "Validate",
                                               epoch)
            trainer_teacher.eval(test_data_loader, model_teacher,
                                 optimizer_teacher, "Test", epoch)

            if performance > best_performance:  # record the best model
                best_epoch = epoch
                best_performance = performance
                temp_model = model_teacher
                save_checkpoint(
                    {
                        'epoch': epoch,
                        'model_name': model_name,
                        'state_dict': model_teacher.state_dict(),
                        'best_performance': best_performance,
                        'optimizer': optimizer_teacher.state_dict(),
                    }, model_file_prefix)

            time_used = time.time() - start_time
            logger.info("Epoch %d cost time: %d second" % (epoch, time_used))
    best_epoch = -1
    best_performance = 0
    if conf.task_info.weak_pretrain:
        if conf.task_info.weak_data_augmentation:
            unlabeled_data_train_data_loader = select_unlabeled_data(
                temp_model, unlabeled_train_data_loader,
                len(trainer_weak.label_map), conf)

        logger.info("Pretraining on Weak Supervision Data")
        for epoch in range(
                conf.train.start_epoch,
                conf.train.start_epoch + conf.train.pretrain_num_epochs):
            start_time = time.time()
            trainer_weak.train(unlabeled_train_data_loader, model_target,
                               optimizer_weak, "Train", epoch)
            trainer_weak.eval(unlabeled_train_data_loader, model_target,
                              optimizer_weak, "Train", epoch)
            performance = trainer_weak.eval(validate_data_loader, model_target,
                                            optimizer_weak, "Validate", epoch)
            trainer_weak.eval(test_data_loader, model_target, optimizer_weak,
                              "Test", epoch)

            if performance > best_performance:  # record the best model
                temp_model = model_target
            time_used = time.time() - start_time
            logger.info("Epoch %d cost time: %d second" % (epoch, time_used))
        model_target = temp_model

    logger.info("Fine-tuning on Labeled Data")

    best_epoch = -1
    best_performance = 0
    if conf.task_info.weak_pretrain:
        if conf.task_info.weak_data_augmentation:
            model_file_prefix = (
                conf.checkpoint_dir + "/" + model_name +
                "-Augmentation-" + conf.task_info.Augmentation_Method +
                "-Pretrain" + str(conf.train.pretrain_num_epochs) +
                "-Batch" + str(conf.train.batch_size))
        else:
            model_file_prefix = (
                conf.checkpoint_dir + "/" + model_name +
                "-WeakSupervision" +
                "-Pretrain" + str(conf.train.pretrain_num_epochs) +
                "-Batch" + str(conf.train.batch_size))
    else:
        model_file_prefix = (
            conf.checkpoint_dir + "/" + model_name +
            "-Batch" + str(conf.train.batch_size))
    for epoch in range(conf.train.start_epoch,
                       conf.train.start_epoch + conf.train.num_epochs):
        start_time = time.time()
        trainer_target.train(train_data_loader, model_target, optimizer_target,
                             "Train", epoch)
        trainer_target.eval(train_data_loader, model_target, optimizer_target,
                            "Train", epoch)
        performance = trainer_target.eval(validate_data_loader, model_target,
                                          optimizer_target, "Validate", epoch)
        trainer_target.eval(test_data_loader, model_target, optimizer_target,
                            "Test", epoch)
        if performance > best_performance:  # record the best model
            best_epoch = epoch
            best_performance = performance
            temp_model = model_target
            save_checkpoint(
                {
                    'epoch': epoch,
                    'model_name': model_name,
                    'state_dict': model_target.state_dict(),
                    'best_performance': best_performance,
                    'optimizer': optimizer_target.state_dict(),
                }, model_file_prefix)
        time_used = time.time() - start_time
        logger.info("Epoch %d cost time: %d second" % (epoch, time_used))

    logger.info("The Best Performance on Validation Data and Test Data")
    #best_epoch_file_name = model_file_prefix + "_" + str(best_epoch)
    #best_file_name = model_file_prefix + "_best"
    #shutil.copyfile(best_epoch_file_name, best_file_name)
    #load_checkpoint(model_file_prefix + "_" + str(best_epoch), conf, model,
    #                optimizer)
    model = temp_model
    trainer_target.eval(train_data_loader, model, optimizer_target,
                        "Best Train", best_epoch)
    trainer_target.eval(validate_data_loader, model, optimizer_target,
                        "Best Validate", best_epoch)
    trainer_target.eval(test_data_loader, model, optimizer_target, "Best Test",
                        best_epoch)
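
# --- Helper sketch (not part of the example above) ---------------------------
# select_unlabeled_data (called during the pretraining stage above) is not
# shown in this example. A minimal, hedged sketch of the idea: pseudo-label
# the unlabeled loader with the trained teacher and keep only confident
# batches. The "confidence" threshold is a hypothetical parameter; the
# project's actual helper may work quite differently.
import torch

def select_unlabeled_data(teacher_model, unlabeled_loader, label_size, conf,
                          confidence=0.9):
    # label_size is unused in this sketch; kept to match the call signature.
    teacher_model.eval()
    kept = []
    with torch.no_grad():
        for batch in unlabeled_loader:
            probs = torch.sigmoid(teacher_model(batch))
            # Keep batches whose most confident prediction clears the bar.
            if probs.max().item() >= confidence:
                kept.append(batch)
    return kept  # iterated like a data loader by trainer_weak.train(...)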
Example #5
0
def eval(conf):
    logger = util.Logger(conf)
    model_name = conf.model_name
    dataset_name = "ClassificationDataset"
    collate_name = "FastTextCollator" if model_name == "FastText" \
        else "ClassificationCollator"

    test_dataset = globals()[dataset_name](conf, conf.data.test_json_files)
    collate_fn = globals()[collate_name](conf, len(test_dataset.label_map))
    test_data_loader = DataLoader(
        test_dataset, batch_size=conf.eval.batch_size, shuffle=False,
        num_workers=conf.data.num_worker, collate_fn=collate_fn,
        pin_memory=True)

    empty_dataset = globals()[dataset_name](conf, [])
    model = get_classification_model(model_name, empty_dataset, conf)
    optimizer = get_optimizer(conf, model)
    load_checkpoint(conf.eval.model_dir, conf, model, optimizer)
    model.eval()
    is_multi = False
    if conf.task_info.label_type == ClassificationType.MULTI_LABEL:
        is_multi = True
    predict_probs = []
    standard_labels = []
    evaluator = cEvaluator(conf.eval.dir)
    for batch in test_data_loader:
        with torch.no_grad():
            logits = model(batch)
        if not is_multi:
            result = torch.nn.functional.softmax(logits, dim=1)
        else:
            result = torch.sigmoid(logits)
        result = result.detach().cpu().tolist()
        predict_probs.extend(result)
        standard_labels.extend(batch[ClassificationDataset.DOC_LABEL_LIST])
    if conf.eval.is_flat:
        (_, precision_list, recall_list, fscore_list, right_list,
         predict_list, standard_list, pak_dict, rak_dict, rpak_dict, ndcgak_dict) = \
            evaluator.evaluate(
                predict_probs, standard_label_ids=standard_labels, label_map=empty_dataset.label_map,
                threshold=conf.eval.threshold, top_k=conf.eval.top_k,
                is_flat=conf.eval.is_flat, is_multi=is_multi,
                debug_file_name=conf.eval.debug_file_name,
                is_label_split=conf.data.generate_label_group,
                label_split_json_file=os.path.join(conf.data.dict_dir,
                                                   "{}.json".format(ClassificationDataset.DOC_LABEL_GROUP)),
                instance_remove=conf.eval.instance_remove
            )
        sup_message = ""
        for i in range(1, conf.eval.top_k + 1):
            for group in pak_dict[i]:
                sup_message += "Precision at {} of {} group: {}, ".format(i, group, pak_dict[i][group])
                sup_message += "Recall at {} of {} group: {}, ".format(i, group, rak_dict[i][group])
                sup_message += "R-Precision at {} of {} group: {}, ".format(i, group, rpak_dict[i][group])
                sup_message += "nDCG at {} of {} group: {}, ".format(i, group, ndcgak_dict[i][group])

        message = "Performance is precision: {}, recall: {}, fscore: {}, " + \
                  "macro-fscore: {}, right: {}, predict: {}, standard: {}, "
        logger.warn(message.format(
            precision_list[0][cEvaluator.MICRO_AVERAGE],
            recall_list[0][cEvaluator.MICRO_AVERAGE],
            fscore_list[0][cEvaluator.MICRO_AVERAGE],
            fscore_list[0][cEvaluator.MACRO_AVERAGE],
            right_list[0][cEvaluator.MICRO_AVERAGE],
            predict_list[0][cEvaluator.MICRO_AVERAGE],
            standard_list[0][cEvaluator.MICRO_AVERAGE]) +
            sup_message)
    else:
        (_, precision_list, recall_list, fscore_list, right_list,
         predict_list, standard_list) = \
            evaluator.evaluate(
                predict_probs, standard_label_ids=standard_labels, label_map=empty_dataset.label_map,
                threshold=conf.eval.threshold, top_k=conf.eval.top_k,
                is_flat=conf.eval.is_flat, is_multi=is_multi,
                is_label_split=conf.data.generate_label_group,
                label_split_json_file=os.path.join(conf.data.dict_dir,
                                                   "{}.json".format(ClassificationDataset.DOC_LABEL_GROUP))
            )
        logger.warn(
            "Performance is precision: %f, "
            "recall: %f, fscore: %f, right: %d, predict: %d, standard: %d." % (
                precision_list[0][cEvaluator.MICRO_AVERAGE],
                recall_list[0][cEvaluator.MICRO_AVERAGE],
                fscore_list[0][cEvaluator.MICRO_AVERAGE],
                right_list[0][cEvaluator.MICRO_AVERAGE],
                predict_list[0][cEvaluator.MICRO_AVERAGE],
                standard_list[0][cEvaluator.MICRO_AVERAGE]))
    evaluator.save()
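
# --- Reference sketches (not part of the example above) -----------------------
# The per-group sup_message above reports precision@k and nDCG@k. Minimal
# sketches of the standard definitions (not the evaluator's internal code):
import math

def precision_at_k(ranked_label_ids, gold_label_ids, k):
    gold = set(gold_label_ids)
    return sum(1 for l in ranked_label_ids[:k] if l in gold) / k

def ndcg_at_k(ranked_label_ids, gold_label_ids, k):
    gold = set(gold_label_ids)
    dcg = sum(1.0 / math.log2(i + 2)
              for i, l in enumerate(ranked_label_ids[:k]) if l in gold)
    idcg = sum(1.0 / math.log2(i + 2) for i in range(min(k, len(gold))))
    return dcg / idcg if idcg else 0.0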
Example #6
0
  def model_fn(features, labels, mode, params=None):
    """Build model and optimizer."""
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    # Check training mode.
    if FLAGS.train_mode == 'pretrain':
      num_transforms = 1
      if FLAGS.use_td_loss:
        num_transforms += 1
      if FLAGS.use_bu_loss:
        num_transforms += 1

      if FLAGS.fine_tune_after_block > -1:
        raise ValueError('Does not support layer freezing during pretraining, '
                         'should set fine_tune_after_block<=-1 for safety.')
    elif FLAGS.train_mode == 'finetune':
      num_transforms = 1
    else:
      raise ValueError('Unknown train_mode {}'.format(FLAGS.train_mode))
    
    # Split channels, and optionally apply extra batched augmentation.
    features_list = tf.split(
        features, num_or_size_splits=num_transforms, axis=-1)
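    # Note: the input pipeline presumably stacks the augmented views along the
    # channel axis, so splitting on the last axis yields one (bsz, h, w, c)
    # tensor per transform (they are re-concatenated along the batch axis
    # further below).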
    
    if FLAGS.use_td_loss:
      target_images = features_list[-1]
      features_list = features_list[:-1]
      # transforms
      thetas_list = tf.split(
        labels['thetas'], num_or_size_splits=num_transforms, axis=-1)
      if FLAGS.train_mode == 'pretrain':  # Fix for fine-tuning/eval
        thetas = tf.concat(thetas_list[:-1], 0)
    else:
      target_images = features_list
    

    if FLAGS.use_blur and is_training and FLAGS.train_mode == 'pretrain':
      features_list, sigmas = data_util.batch_random_blur(
          features_list, FLAGS.image_size, FLAGS.image_size)
      if FLAGS.use_td_loss: 
        sigmas = tf.concat(sigmas, 0)
        thetas = tf.concat([thetas, sigmas[:,None]], 1) 
    else:
      if FLAGS.use_td_loss:
        sigmas = tf.zeros_like(thetas[:,0])
        thetas = tf.concat([thetas, sigmas[:,None]], 1) 
        # thetas = tf.zeros([target_images.get_shape().as_list()[0], 11]) 

    features = tf.concat(features_list, 0)  # (num_transforms * bsz, h, w, c)
    
    # Base network forward pass.
    with tf.variable_scope('base_model'):
      if FLAGS.train_mode == 'finetune':
        if FLAGS.fine_tune_after_block >= 4:
          # Finetuning just the supervised (linear) head will not update BN stats.
          model_train_mode = False
        else:
          # Finetuning earlier blocks still updates BN stats.
          model_train_mode = is_training
      else:
        if FLAGS.use_td_loss:
          viz_features = features
          features = (features, thetas)
        else:
          viz_features = features

        # Pretraining will update BN stats.
        model_train_mode = is_training

      outputs = model(features, is_training=model_train_mode)
      
    # Add head and loss.
    if FLAGS.train_mode == 'pretrain':
      tpu_context = params['context'] if 'context' in params else None
      
      if FLAGS.use_td_loss and isinstance(outputs, tuple):
        hiddens, reconstruction, metric_hidden_r, metric_hidden_t = outputs
      else:
        hiddens = outputs
        reconstruction = features

      if FLAGS.use_td_loss:
        with tf.name_scope('td_loss'):
          if FLAGS.td_loss=='attractive':
            td_loss, logits_td_con, labels_td_con = obj_lib.td_attractive_loss(
              reconstruction=metric_hidden_r,
              target=metric_hidden_t,
              temperature=FLAGS.temperature,
              tpu_context=tpu_context if is_training else None)
            logits_td_con = tf.zeros([params['batch_size'], params['batch_size']])
            labels_td_con = tf.zeros([params['batch_size'], params['batch_size']])
          elif FLAGS.td_loss=='attractive_repulsive':
            td_loss, logits_td_con, labels_td_con = obj_lib.td_attractive_repulsive_loss(
              reconstruction=metric_hidden_r,
              target=metric_hidden_t,
              temperature=FLAGS.temperature,
              tpu_context=tpu_context if is_training else None)
          else:
            raise NotImplementedError("Error at TD loss {}".format(FLAGS.td_loss))
      else:
        # No TD loss
        logits_td_con = tf.zeros([params['batch_size'], params['batch_size']])
        labels_td_con = tf.zeros([params['batch_size'], params['batch_size']])
        td_loss = 0.
      hiddens_proj = model_util.projection_head(hiddens, is_training)

      if FLAGS.use_bu_loss:
        with tf.name_scope('bu_loss'):
          if FLAGS.bu_loss=='attractive':
            bu_loss, logits_bu_con, labels_bu_con = obj_lib.attractive_loss(
              hiddens_proj,
              temperature=FLAGS.temperature,
              hidden_norm=FLAGS.hidden_norm)
            logits_bu_con = tf.zeros([params['batch_size'], params['batch_size']])
            labels_bu_con = tf.zeros([params['batch_size'], params['batch_size']])

          elif FLAGS.bu_loss=='attractive_repulsive':
            bu_loss, logits_bu_con, labels_bu_con = obj_lib.attractive_repulsive_loss(
              hiddens_proj,
              hidden_norm=FLAGS.hidden_norm,
              temperature=FLAGS.temperature,
              tpu_context=tpu_context if is_training else None)  
          else:
            raise NotImplementedError('Unknown loss')
      else:
        # No BU loss
        logits_bu_con = tf.zeros([params['batch_size'], params['batch_size']])
        labels_bu_con = tf.zeros([params['batch_size'], params['batch_size']])
        bu_loss = 0.
      logits_sup = tf.zeros([params['batch_size'], num_classes])

    else:
      # contrast_loss = tf.zeros([])
      td_loss = tf.zeros([])
      bu_loss = tf.zeros([])
      logits_td_con = tf.zeros([params['batch_size'], 10])
      labels_td_con = tf.zeros([params['batch_size'], 10])
      logits_bu_con = tf.zeros([params['batch_size'], 10])
      labels_bu_con = tf.zeros([params['batch_size'], 10])
      hiddens = outputs
      hiddens = model_util.projection_head(hiddens, is_training)
      logits_sup = model_util.supervised_head(
          hiddens, num_classes, is_training)
      sup_loss = obj_lib.supervised_loss(
          labels=labels['labels'],
          logits=logits_sup,
          weights=labels['mask'])

    # Add weight decay to loss, for non-LARS optimizers.
    model_util.add_weight_decay(adjust_per_optimizer=True)
    
    # reg_loss = tf.losses.get_regularization_losses()

    
    if FLAGS.train_mode == 'pretrain':
      print(bu_loss)
      print(td_loss)
      loss = tf.add_n([td_loss * FLAGS.td_loss_weight,
                       bu_loss * FLAGS.bu_loss_weight] +
                      tf.losses.get_regularization_losses())
    else:
      loss = tf.add_n([sup_loss] + tf.losses.get_regularization_losses())
           
    # loss = tf.losses.get_total_loss()

    if FLAGS.train_mode == 'pretrain':
      variables_to_train = tf.trainable_variables()
    else:
      collection_prefix = 'trainable_variables_inblock_'
      variables_to_train = []
      for j in range(FLAGS.fine_tune_after_block + 1, 6):
        variables_to_train += tf.get_collection(collection_prefix + str(j))
      assert variables_to_train, 'variables_to_train shouldn\'t be empty!'

    tf.logging.info('===============Variables to train (begin)===============')
    tf.logging.info(variables_to_train)
    tf.logging.info('================Variables to train (end)================')

    learning_rate = model_util.learning_rate_schedule(
        FLAGS.learning_rate, num_train_examples)

    if is_training:
      
      if FLAGS.train_summary_steps > 0:
        # Compute stats for the summary.
        prob_bu_con = tf.nn.softmax(logits_bu_con)
        entropy_bu_con = - tf.reduce_mean(
            tf.reduce_sum(prob_bu_con * tf.math.log(prob_bu_con + 1e-8), -1))
        prob_td_con = tf.nn.softmax(logits_td_con)
        entropy_td_con = - tf.reduce_mean(
            tf.reduce_sum(prob_td_con * tf.math.log(prob_td_con + 1e-8), -1))

        contrast_bu_acc = tf.equal(
            tf.argmax(labels_bu_con, 1), tf.argmax(logits_bu_con, axis=1))
        contrast_bu_acc = tf.reduce_mean(tf.cast(contrast_bu_acc, tf.float32))
        contrast_td_acc = tf.equal(
            tf.argmax(labels_td_con, 1), tf.argmax(logits_td_con, axis=1))
        contrast_td_acc = tf.reduce_mean(tf.cast(contrast_td_acc, tf.float32))
        
        label_acc = tf.equal(
            tf.argmax(labels['labels'], 1), tf.argmax(logits_sup, axis=1))
        label_acc = tf.reduce_mean(tf.cast(label_acc, tf.float32))
        

        def host_call_fn(gs, g_l, bu_l, td_l, c_bu_a, c_td_a, l_a, c_e_bu, c_e_td, lr, tar_im, viz_f, rec_im):
          gs = gs[0]
          with tf2.summary.create_file_writer(
              FLAGS.model_dir,
              max_queue=FLAGS.checkpoint_steps).as_default():
            with tf2.summary.record_if(True):
              tf2.summary.scalar(
                  'total_loss',
                  g_l[0],
                  step=gs)
                  
              tf2.summary.scalar(
                  'train_bottomup_loss',
                  bu_l[0],
                  step=gs)

              tf2.summary.scalar(
                  'train_topdown_loss',
                  td_l[0],
                  step=gs)
              
              tf2.summary.scalar(
                  'train_bottomup_acc',
                  c_bu_a[0],
                  step=gs)
              tf2.summary.scalar(
                  'train_topdown_acc',
                  c_td_a[0],
                  step=gs)
              
              tf2.summary.scalar(
                  'train_label_accuracy',
                  l_a[0],
                  step=gs)
              
              tf2.summary.scalar(
                  'contrast_bu_entropy',
                  c_e_bu[0],
                  step=gs)
              tf2.summary.scalar(
                  'contrast_td_entropy',
                  c_e_td[0],
                  step=gs)
              
              tf2.summary.scalar(
                  'learning_rate', lr[0],
                  step=gs)

              # print("Images")
              # print(target_images)
              # print("Features")
              # print(viz_features)
              # print("Reconstruction")
              # print(reconstruction)
              tf2.summary.image(
                  'Images',
                  tar_im[0],
                  step=gs)
              tf2.summary.image(
                  'Transformed images',
                  viz_f[0],
                  step=gs)
              tf2.summary.image(
                  'Reconstructed images',
                  rec_im[0],
                  step=gs)

            return tf.summary.all_v2_summary_ops()


        n_images = 4
        if isinstance(target_images, list):
          target_images = target_images[0]
        image_shape = target_images.get_shape().as_list()

        tar_im = tf.reshape(tf.cast(target_images[:n_images], tf.float32), [1, n_images] + image_shape[1:])
        viz_f = tf.reshape(tf.cast(viz_features[:n_images], tf.float32), [1, n_images] + image_shape[1:])
        rec_im = tf.reshape(tf.cast(reconstruction[:n_images], tf.float32), [1, n_images] + image_shape[1:])
        
        gs = tf.reshape(tf.train.get_global_step(), [1])
        
        g_l = tf.reshape(loss, [1])

        bu_l = tf.reshape(bu_loss, [1])
        td_l = tf.reshape(td_loss, [1])

        c_bu_a = tf.reshape(contrast_bu_acc, [1])
        c_td_a = tf.reshape(contrast_td_acc, [1])
        
        l_a = tf.reshape(label_acc, [1])
        c_e_bu = tf.reshape(entropy_bu_con, [1])
        c_e_td = tf.reshape(entropy_td_con, [1])
        
        lr = tf.reshape(learning_rate, [1])
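        # Each tensor passed to host_call must have a leading (batch)
        # dimension, hence the reshapes to [1] above; host_call_fn then
        # indexes [0] to recover the scalar values.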
        
        host_call = (host_call_fn, [gs, g_l, bu_l, td_l, c_bu_a, c_td_a, l_a, c_e_bu, c_e_td, lr, tar_im, viz_f, rec_im])
        
      else:
        host_call=None

      optimizer = model_util.get_optimizer(learning_rate)
      control_deps = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      # if FLAGS.train_summary_steps > 0:
      #   control_deps.extend(tf.summary.all_v2_summary_ops())
      with tf.control_dependencies(control_deps):
        train_op = optimizer.minimize(
            loss, global_step=tf.train.get_or_create_global_step(),
            var_list=variables_to_train)
      
      
      if FLAGS.checkpoint:
        def scaffold_fn():
          """Scaffold function to restore non-logits vars from checkpoint."""
          tf.logging.info('*'*180)
          tf.logging.info('Initializing from checkpoint %s'%FLAGS.checkpoint)
          tf.logging.info('*'*180)

          tf.train.init_from_checkpoint(
              FLAGS.checkpoint,
              {v.op.name: v.op.name
               for v in tf.global_variables(FLAGS.variable_schema)})

          if FLAGS.zero_init_logits_layer:
            # Init op that initializes output layer parameters to zeros.
            output_layer_parameters = [
                var for var in tf.trainable_variables() if var.name.startswith(
                    'head_supervised')]
            tf.logging.info('Initializing output layer parameters %s to zero',
                            [x.op.name for x in output_layer_parameters])
            with tf.control_dependencies([tf.global_variables_initializer()]):
              init_op = tf.group([
                  tf.assign(x, tf.zeros_like(x))
                  for x in output_layer_parameters])
            return tf.train.Scaffold(init_op=init_op)
          else:
            return tf.train.Scaffold()
      else:
        scaffold_fn = None

      return tf.estimator.tpu.TPUEstimatorSpec(
          mode=mode, 
          train_op=train_op, 
          loss=loss, 
          scaffold_fn=scaffold_fn, 
          host_call=host_call
          )

    else:

      def metric_fn(logits_sup, labels_sup, logits_bu_con, labels_bu_con, 
                    logits_td_con, labels_td_con, mask,
                    **kws):
        """Inner metric function."""
        metrics = {k: tf.metrics.mean(v, weights=mask)
                   for k, v in kws.items()}
        metrics['label_top_1_accuracy'] = tf.metrics.accuracy(
            tf.argmax(labels_sup, 1), tf.argmax(logits_sup, axis=1),
            weights=mask)
        metrics['label_top_5_accuracy'] = tf.metrics.recall_at_k(
            tf.argmax(labels_sup, 1), logits_sup, k=5, weights=mask)
        
        metrics['bottomup_top_1_accuracy'] = tf.metrics.accuracy(
            tf.argmax(labels_bu_con, 1), tf.argmax(logits_bu_con, axis=1),
            weights=mask)
        # metrics['bottomup_top_5_accuracy'] = tf.metrics.recall_at_k(
        #     tf.argmax(labels_bu_con, 1), logits_bu_con, k=5, weights=mask)

        metrics['topdown_top_1_accuracy'] = tf.metrics.accuracy(
            tf.argmax(labels_td_con, 1), tf.argmax(logits_td_con, axis=1),
            weights=mask)
        # metrics['topdown_top_5_accuracy'] = tf.metrics.recall_at_k(
        #     tf.argmax(labels_td_con, 1), logits_td_con, k=5, weights=mask)
        return metrics

      metrics = {
          'logits_sup': logits_sup,
          'labels_sup': labels['labels'],
          'logits_bu_con': logits_bu_con,
          'logits_td_con': logits_td_con,
          'labels_bu_con': labels_bu_con,
          'labels_td_con': labels_td_con,
          'mask': labels['mask'],
          'td_loss': tf.fill((params['batch_size'],), td_loss),
          'bu_loss': tf.fill((params['batch_size'],), bu_loss),
          'regularization_loss': tf.fill((params['batch_size'],),
                                         tf.losses.get_regularization_loss()),
      }

      return tf.estimator.tpu.TPUEstimatorSpec(
          mode=mode,
          loss=loss,
          eval_metrics=(metric_fn, metrics),
          host_call=None,
          scaffold_fn=None)
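
  # Usage sketch (hedged): a model_fn like this is typically returned by its
  # builder and handed to tf.estimator.tpu.TPUEstimator, roughly:
  #   estimator = tf.estimator.tpu.TPUEstimator(
  #       model_fn=model_fn, config=run_config, use_tpu=FLAGS.use_tpu,
  #       train_batch_size=FLAGS.train_batch_size,
  #       eval_batch_size=FLAGS.eval_batch_size)
  # run_config and the batch-size flags here are assumptions, not taken from
  # the code above.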