def train(conf):
    logger = util.Logger(conf)
    if not os.path.exists(conf.checkpoint_dir):
        os.makedirs(conf.checkpoint_dir)

    model_name = conf.model_name
    dataset_name = "ClassificationDataset"
    collate_name = "FastTextCollator" if model_name == "FastText" \
        else "ClassificationCollator"
    train_data_loader, validate_data_loader, test_data_loader = \
        get_data_loader(dataset_name, collate_name, conf)
    empty_dataset = globals()[dataset_name](conf, [], mode="train")
    model = get_classification_model(model_name, empty_dataset, conf)
    loss_fn = globals()["ClassificationLoss"](
        label_size=len(empty_dataset.label_map),
        loss_type=conf.train.loss_type)
    optimizer = get_optimizer(conf, model)
    evaluator = cEvaluator(conf.eval.dir)
    trainer = globals()["ClassificationTrainer"](
        empty_dataset.label_map, logger, evaluator, conf, loss_fn)

    best_epoch = -1
    best_performance = 0
    model_file_prefix = conf.checkpoint_dir + "/" + model_name
    for epoch in range(conf.train.start_epoch,
                       conf.train.start_epoch + conf.train.num_epochs):
        start_time = time.time()
        trainer.train(train_data_loader, model, optimizer, "Train", epoch)
        trainer.eval(train_data_loader, model, optimizer, "Train", epoch)
        performance = trainer.eval(
            validate_data_loader, model, optimizer, "Validate", epoch)
        trainer.eval(test_data_loader, model, optimizer, "Test", epoch)
        if performance > best_performance:  # record the best model
            best_epoch = epoch
            best_performance = performance
        save_checkpoint({
            'epoch': epoch,
            'model_name': model_name,
            'state_dict': model.state_dict(),
            'best_performance': best_performance,
            'optimizer': optimizer.state_dict(),
        }, model_file_prefix)
        time_used = time.time() - start_time
        logger.info("Epoch %d cost time: %d second" % (epoch, time_used))

    # best model on validation set
    best_epoch_file_name = model_file_prefix + "_" + str(best_epoch)
    best_file_name = model_file_prefix + "_best"
    shutil.copyfile(best_epoch_file_name, best_file_name)

    load_checkpoint(model_file_prefix + "_" + str(best_epoch), conf, model,
                    optimizer)
    trainer.eval(test_data_loader, model, optimizer, "Best test", best_epoch)
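For context, training is typically kicked off from a small entry point. A minimal usage sketch, assuming a Config helper that wraps the JSON configuration (the class and the config path are assumptions, not shown above):

import sys

from config import Config  # assumed config loader

if __name__ == "__main__":
    config = Config(config_file=sys.argv[1])  # e.g. conf/train.json
    train(config)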
def kfold_eval(conf):
    logger = util.Logger(conf)
    model_name = conf.model_name
    dataset_name = "ClassificationDataset"
    collate_name = "FastTextCollator" if model_name == "FastText" \
        else "ClassificationCollator"
    test_dataset = globals()[dataset_name](conf, conf.data.test_json_files)
    collate_fn = globals()[collate_name](conf, len(test_dataset.label_map))
    test_data_loader = DataLoader(
        test_dataset, batch_size=conf.eval.batch_size, shuffle=False,
        num_workers=conf.data.num_worker, collate_fn=collate_fn,
        pin_memory=True)

    empty_dataset = globals()[dataset_name](conf, [])
    model = get_classification_model(model_name, empty_dataset, conf)
    optimizer = get_optimizer(conf, model)
    load_checkpoint(conf.eval.model_dir, conf, model, optimizer)
    model.eval()

    predict_probs = []
    standard_labels = []
    evaluator = cEvaluator(conf.eval.dir)
    for batch in test_data_loader:
        logits = model(batch)
        result = torch.sigmoid(logits).cpu().tolist()
        predict_probs.extend(result)
        standard_labels.extend(batch[ClassificationDataset.DOC_LABEL_LIST])

    # Evaluation API: binarize the predictions and compute the metric set.
    y_test = list(standard_labels)
    predictions = list(predict_probs)
    pred, actual = take_values(predictions, y_test,
                               conf.eval.threshold, conf.eval.top_k)
    actual = np.array(actual)
    pred = np.array(pred)
    evaluation_measures = {
        "Accuracy": accuracy(actual, pred),
        "Precision": precision(actual, pred),
        "Recall": recall(actual, pred),
        "F1 score": f1_scor(actual, pred),
        "Hamming Loss": hammingLoss(actual, pred),
        "f-1 Macro": macroF1(actual, pred),
        "f-1 Micro": microF1(actual, pred),
        "averagePrecision": averagePrecision(actual, pred),
    }
    return evaluation_measures
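kfold_eval relies on a take_values helper that is not shown in this snippet. A hypothetical sketch of its contract, assuming it binarizes per-label probabilities with a threshold and a top-k cap, and multi-hot-encodes the gold label-id lists:

import numpy as np

def take_values_sketch(probs, label_ids, threshold, top_k):
    """Hypothetical stand-in for take_values(): turn probabilities into
    0/1 prediction vectors (at most top_k labels clearing the threshold)
    and gold label-id lists into multi-hot vectors."""
    num_labels = len(probs[0])
    pred, actual = [], []
    for doc_probs, doc_labels in zip(probs, label_ids):
        # keep the top_k highest-probability labels that clear the threshold
        order = np.argsort(doc_probs)[::-1][:top_k]
        row = np.zeros(num_labels, dtype=int)
        for idx in order:
            if doc_probs[idx] >= threshold:
                row[idx] = 1
        pred.append(row)
        gold = np.zeros(num_labels, dtype=int)
        gold[list(doc_labels)] = 1
        actual.append(gold)
    return pred, actual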
def eval(conf):
    logger = util.Logger(conf)
    model_name = conf.model_name
    dataset_name = "ClassificationDataset"
    collate_name = "FastTextCollator" if model_name == "FastText" \
        else "ClassificationCollator"
    test_dataset = globals()[dataset_name](conf, conf.data.test_json_files)
    collate_fn = globals()[collate_name](conf, len(test_dataset.label_map))
    test_data_loader = DataLoader(
        test_dataset, batch_size=conf.eval.batch_size, shuffle=False,
        num_workers=conf.data.num_worker, collate_fn=collate_fn,
        pin_memory=True)

    empty_dataset = globals()[dataset_name](conf, [])
    model = get_classification_model(model_name, empty_dataset, conf)
    # get_optimizer expects the model itself, matching its use in train().
    optimizer = get_optimizer(conf, model)
    load_checkpoint(conf.eval.model_dir, conf, model, optimizer)
    model.eval()

    is_multi = conf.task_info.label_type == ClassificationType.MULTI_LABEL
    predict_probs = []
    standard_labels = []
    evaluator = cEvaluator(conf.eval.dir)
    for batch in test_data_loader:
        logits = model(batch)
        if not is_multi:
            result = torch.nn.functional.softmax(logits, dim=1).cpu().tolist()
        else:
            result = torch.sigmoid(logits).cpu().tolist()
        predict_probs.extend(result)
        standard_labels.extend(batch[ClassificationDataset.DOC_LABEL_LIST])

    (_, precision_list, recall_list, fscore_list, right_list,
     predict_list, standard_list) = evaluator.evaluate(
        predict_probs, standard_label_ids=standard_labels,
        label_map=empty_dataset.label_map, threshold=conf.eval.threshold,
        top_k=conf.eval.top_k, is_flat=conf.eval.is_flat, is_multi=is_multi)
    logger.warn(
        "Performance is precision: %f, "
        "recall: %f, fscore: %f, right: %d, predict: %d, standard: %d." % (
            precision_list[0][cEvaluator.MICRO_AVERAGE],
            recall_list[0][cEvaluator.MICRO_AVERAGE],
            fscore_list[0][cEvaluator.MICRO_AVERAGE],
            right_list[0][cEvaluator.MICRO_AVERAGE],
            predict_list[0][cEvaluator.MICRO_AVERAGE],
            standard_list[0][cEvaluator.MICRO_AVERAGE]))
    evaluator.save()
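For reference, the logged micro-averaged scores are tied together by the right/predict/standard counts that appear in the same message. A quick worked sketch (the counts are illustrative, not from any real run):

# How the logged counts relate to the micro-averaged scores; cEvaluator
# computes these internally.
right, predict, standard = 900, 1000, 1100   # illustrative counts
micro_precision = right / predict            # 0.90
micro_recall = right / standard              # ~0.818
micro_f1 = (2 * micro_precision * micro_recall
            / (micro_precision + micro_recall))  # ~0.857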
def train(conf):
    model_name = conf.model_name
    logger = util.Logger(conf)
    if conf.task_info.weak_pretrain:
        logger.info("Batch Size: " + str(conf.train.batch_size) +
                    " Pretrain Num Epoch: " +
                    str(conf.train.pretrain_num_epochs))
    else:
        logger.info("Batch Size: " + str(conf.train.batch_size))

    # Data setup mirrors the standard train() above. The weak-supervision
    # loader (unlabeled_train_data_loader) is assumed to be provided by the
    # project's data pipeline; its construction is not part of this snippet.
    dataset_name = "ClassificationDataset"
    collate_name = "FastTextCollator" if model_name == "FastText" \
        else "ClassificationCollator"
    train_data_loader, validate_data_loader, test_data_loader = \
        get_data_loader(dataset_name, collate_name, conf)
    empty_dataset = globals()[dataset_name](conf, [], mode="train")

    if conf.task_info.weak_pretrain and conf.task_info.weak_data_augmentation:
        model_teacher = get_classification_model(model_name, empty_dataset,
                                                 conf)
        if conf.model_name != "BERT":
            optimizer_teacher = get_optimizer(conf, model_teacher)
        else:
            # optimizer_teacher: optimizer for the teacher model
            optimizer_teacher = AdamW(model_teacher.parameters(),
                                      lr=5e-2, eps=1e-2)

    model_target = get_classification_model(model_name, empty_dataset, conf)
    loss_fn = globals()["ClassificationLoss"](
        label_size=len(empty_dataset.label_map),
        loss_type=conf.train.loss_type)
    if conf.task_info.weak_pretrain:
        if conf.model_name != "BERT":
            optimizer_weak = get_optimizer(conf, model_target)
        else:
            # optimizer_weak: optimizer for the target model's pretraining stage
            optimizer_weak = AdamW(model_target.parameters(),
                                   lr=5e-2, eps=1e-2)
    if conf.model_name != "BERT":
        optimizer_target = get_optimizer(conf, model_target)
    else:
        # optimizer_target: optimizer for the target model's fine-tuning stage
        optimizer_target = AdamW(model_target.parameters(), lr=5e-2, eps=1e-2)

    evaluator = cEvaluator(conf.eval.dir)
    # trainer_target: trainer for the target model on the fine-tuning stage
    trainer_target = globals()["ClassificationTrainer"](
        empty_dataset.label_map, logger, evaluator, conf, loss_fn)
    if conf.task_info.weak_pretrain:
        # trainer_weak: trainer for the target model on the pretraining stage
        trainer_weak = globals()["ClassificationTrainer"](
            empty_dataset.label_map, logger, evaluator, conf, loss_fn)
        if conf.task_info.weak_data_augmentation:
            # trainer_teacher: trainer for the teacher model
            trainer_teacher = globals()["ClassificationTrainer"](
                empty_dataset.label_map, logger, evaluator, conf, loss_fn)

    if conf.task_info.weak_data_augmentation:
        best_epoch = -1
        best_performance = 0
        model_file_prefix = conf.checkpoint_dir + "/" + model_name + "_teacher"
        logger.info("Training Teacher Model on Labeled Data")
        for epoch in range(conf.train.start_epoch,
                           conf.train.start_epoch + conf.train.num_epochs):
            start_time = time.time()
            trainer_teacher.train(train_data_loader, model_teacher,
                                  optimizer_teacher, "Train", epoch)
            trainer_teacher.eval(train_data_loader, model_teacher,
                                 optimizer_teacher, "Train", epoch)
            performance = trainer_teacher.eval(
                validate_data_loader, model_teacher, optimizer_teacher,
                "Validate", epoch)
            trainer_teacher.eval(test_data_loader, model_teacher,
                                 optimizer_teacher, "Test", epoch)
            if performance > best_performance:  # record the best model
                best_epoch = epoch
                best_performance = performance
                temp_model = model_teacher
            save_checkpoint({
                'epoch': epoch,
                'model_name': model_name,
                'state_dict': model_teacher.state_dict(),
                'best_performance': best_performance,
                'optimizer': optimizer_teacher.state_dict(),
            }, model_file_prefix)
            time_used = time.time() - start_time
            logger.info("Epoch %d cost time: %d second" % (epoch, time_used))

    best_epoch = -1
    best_performance = 0
    if conf.task_info.weak_pretrain:
        if conf.task_info.weak_data_augmentation:
            # Keep only the unlabeled examples the teacher labels confidently.
            unlabeled_train_data_loader = select_unlabeled_data(
                temp_model, unlabeled_train_data_loader,
                len(trainer_weak.label_map), conf)
        logger.info("Pretraining on Weak Supervision Data")
        for epoch in range(conf.train.start_epoch,
                           conf.train.start_epoch +
                           conf.train.pretrain_num_epochs):
            start_time = time.time()
            trainer_weak.train(unlabeled_train_data_loader, model_target,
                               optimizer_weak, "Train", epoch)
            trainer_weak.eval(unlabeled_train_data_loader, model_target,
                              optimizer_weak, "Train", epoch)
            performance = trainer_weak.eval(
                validate_data_loader, model_target, optimizer_weak,
                "Validate", epoch)
            trainer_weak.eval(test_data_loader, model_target, optimizer_weak,
                              "Test", epoch)
            if performance > best_performance:  # record the best model
                best_performance = performance
                temp_model = model_target
            time_used = time.time() - start_time
            logger.info("Epoch %d cost time: %d second" % (epoch, time_used))
        model_target = temp_model

    logger.info("Fine-tuning on Labeled Data")
    best_epoch = -1
    best_performance = 0
    if conf.task_info.weak_pretrain:
        if conf.task_info.weak_data_augmentation:
            model_file_prefix = (
                conf.checkpoint_dir + "/" + model_name + "-Augmentation-" +
                conf.task_info.Augmentation_Method + "-Pretrain" +
                str(conf.train.pretrain_num_epochs) + "-Batch" +
                str(conf.train.batch_size))
        else:
            model_file_prefix = (
                conf.checkpoint_dir + "/" + model_name + "-WeakSupervision" +
                "-Pretrain" + str(conf.train.pretrain_num_epochs) + "-Batch" +
                str(conf.train.batch_size))
    else:
        model_file_prefix = (conf.checkpoint_dir + "/" + model_name +
                             "-Batch" + str(conf.train.batch_size))

    for epoch in range(conf.train.start_epoch,
                       conf.train.start_epoch + conf.train.num_epochs):
        start_time = time.time()
        trainer_target.train(train_data_loader, model_target,
                             optimizer_target, "Train", epoch)
        trainer_target.eval(train_data_loader, model_target,
                            optimizer_target, "Train", epoch)
        performance = trainer_target.eval(
            validate_data_loader, model_target, optimizer_target,
            "Validate", epoch)
        trainer_target.eval(test_data_loader, model_target, optimizer_target,
                            "Test", epoch)
        if performance > best_performance:  # record the best model
            best_epoch = epoch
            best_performance = performance
            temp_model = model_target
        save_checkpoint({
            'epoch': epoch,
            'model_name': model_name,
            'state_dict': model_target.state_dict(),
            'best_performance': best_performance,
            'optimizer': optimizer_target.state_dict(),
        }, model_file_prefix)
        time_used = time.time() - start_time
        logger.info("Epoch %d cost time: %d second" % (epoch, time_used))

    logger.info("The Best Performance on Validation Data and Test Data")
    model = temp_model
    trainer_target.eval(train_data_loader, model, optimizer_target,
                        "Best Train", best_epoch)
    trainer_target.eval(validate_data_loader, model, optimizer_target,
                        "Best Validate", best_epoch)
    trainer_target.eval(test_data_loader, model, optimizer_target,
                        "Best Test", best_epoch)
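The branches above are driven by a handful of config fields. A hedged sketch of the relevant config slice, with field names taken from the code and values that are purely illustrative:

# Illustrative slice of the config consumed by this trainer (values are
# examples, not defaults from the original project).
conf_sketch = {
    "model_name": "TextCNN",
    "checkpoint_dir": "checkpoint_dir",
    "task_info": {
        "weak_pretrain": True,           # enable the weak-supervision pretraining stage
        "weak_data_augmentation": True,  # enable teacher-model data selection
        "Augmentation_Method": "back_translation",  # only used in checkpoint names here
    },
    "train": {
        "start_epoch": 1,
        "num_epochs": 5,
        "pretrain_num_epochs": 3,
        "batch_size": 64,
    },
}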
def eval(conf):
    logger = util.Logger(conf)
    model_name = conf.model_name
    dataset_name = "ClassificationDataset"
    collate_name = "FastTextCollator" if model_name == "FastText" \
        else "ClassificationCollator"
    test_dataset = globals()[dataset_name](conf, conf.data.test_json_files)
    collate_fn = globals()[collate_name](conf, len(test_dataset.label_map))
    test_data_loader = DataLoader(
        test_dataset, batch_size=conf.eval.batch_size, shuffle=False,
        num_workers=conf.data.num_worker, collate_fn=collate_fn,
        pin_memory=True)

    empty_dataset = globals()[dataset_name](conf, [])
    model = get_classification_model(model_name, empty_dataset, conf)
    optimizer = get_optimizer(conf, model)
    load_checkpoint(conf.eval.model_dir, conf, model, optimizer)
    model.eval()

    is_multi = conf.task_info.label_type == ClassificationType.MULTI_LABEL
    predict_probs = []
    standard_labels = []
    evaluator = cEvaluator(conf.eval.dir)
    for batch in test_data_loader:
        with torch.no_grad():
            logits = model(batch)
            if not is_multi:
                result = torch.nn.functional.softmax(logits, dim=1)
            else:
                result = torch.sigmoid(logits)
            result = result.detach().cpu().tolist()
        predict_probs.extend(result)
        standard_labels.extend(batch[ClassificationDataset.DOC_LABEL_LIST])

    label_split_json_file = os.path.join(
        conf.data.dict_dir,
        "{}.json".format(ClassificationDataset.DOC_LABEL_GROUP))
    if conf.eval.is_flat:
        (_, precision_list, recall_list, fscore_list, right_list,
         predict_list, standard_list, pak_dict, rak_dict, rpak_dict,
         ndcgak_dict) = evaluator.evaluate(
            predict_probs, standard_label_ids=standard_labels,
            label_map=empty_dataset.label_map,
            threshold=conf.eval.threshold, top_k=conf.eval.top_k,
            is_flat=conf.eval.is_flat, is_multi=is_multi,
            debug_file_name=conf.eval.debug_file_name,
            is_label_split=conf.data.generate_label_group,
            label_split_json_file=label_split_json_file,
            instance_remove=conf.eval.instance_remove)
        sup_message = ""
        for i in range(1, conf.eval.top_k + 1):
            for group in pak_dict[i]:
                sup_message += "Precision at {} of {} group: {}, ".format(
                    i, group, pak_dict[i][group])
                sup_message += "Recall at {} of {} group: {}, ".format(
                    i, group, rak_dict[i][group])
                sup_message += "R-Precision at {} of {} group: {}, ".format(
                    i, group, rpak_dict[i][group])
                sup_message += "nDCG at {} of {} group: {}, ".format(
                    i, group, ndcgak_dict[i][group])
        message = ("Performance is precision: {}, recall: {}, fscore: {}, "
                   "macro-fscore: {}, right: {}, predict: {}, standard: {}, ")
        logger.warn(message.format(
            precision_list[0][cEvaluator.MICRO_AVERAGE],
            recall_list[0][cEvaluator.MICRO_AVERAGE],
            fscore_list[0][cEvaluator.MICRO_AVERAGE],
            fscore_list[0][cEvaluator.MACRO_AVERAGE],
            right_list[0][cEvaluator.MICRO_AVERAGE],
            predict_list[0][cEvaluator.MICRO_AVERAGE],
            standard_list[0][cEvaluator.MICRO_AVERAGE]) + sup_message)
    else:
        (_, precision_list, recall_list, fscore_list, right_list,
         predict_list, standard_list) = evaluator.evaluate(
            predict_probs, standard_label_ids=standard_labels,
            label_map=empty_dataset.label_map,
            threshold=conf.eval.threshold, top_k=conf.eval.top_k,
            is_flat=conf.eval.is_flat, is_multi=is_multi,
            is_label_split=conf.data.generate_label_group,
            label_split_json_file=label_split_json_file)
        logger.warn(
            "Performance is precision: %f, "
            "recall: %f, fscore: %f, right: %d, predict: %d, standard: %d." % (
                precision_list[0][cEvaluator.MICRO_AVERAGE],
                recall_list[0][cEvaluator.MICRO_AVERAGE],
                fscore_list[0][cEvaluator.MICRO_AVERAGE],
                right_list[0][cEvaluator.MICRO_AVERAGE],
                predict_list[0][cEvaluator.MICRO_AVERAGE],
                standard_list[0][cEvaluator.MICRO_AVERAGE]))
    evaluator.save()
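The flat-evaluation branch reports per-group Precision@k, Recall@k, R-Precision@k, and nDCG@k. As a reference for how such rank metrics are conventionally defined, a sketch (not cEvaluator's internal code):

import math

def precision_at_k(ranked_labels, gold, k):
    """Fraction of the top-k ranked labels that are in the gold set."""
    return sum(1 for l in ranked_labels[:k] if l in gold) / k

def ndcg_at_k(ranked_labels, gold, k):
    """Binary-relevance nDCG@k over a ranked label list."""
    dcg = sum(1.0 / math.log2(i + 2)
              for i, l in enumerate(ranked_labels[:k]) if l in gold)
    ideal = sum(1.0 / math.log2(i + 2) for i in range(min(k, len(gold))))
    return dcg / ideal if ideal > 0 else 0.0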
def model_fn(features, labels, mode, params=None):
  """Build model and optimizer."""
  # model, num_classes and num_train_examples are assumed to be defined at
  # module level, as in the surrounding training script.
  is_training = mode == tf.estimator.ModeKeys.TRAIN

  # Check training mode.
  if FLAGS.train_mode == 'pretrain':
    num_transforms = 1
    if FLAGS.use_td_loss:
      num_transforms += 1
    if FLAGS.use_bu_loss:
      num_transforms += 1
    if FLAGS.fine_tune_after_block > -1:
      raise ValueError('Does not support layer freezing during pretraining, '
                       'should set fine_tune_after_block<=-1 for safety.')
  elif FLAGS.train_mode == 'finetune':
    num_transforms = 1
  else:
    raise ValueError('Unknown train_mode {}'.format(FLAGS.train_mode))

  # Split channels, and optionally apply extra batched augmentation.
  features_list = tf.split(
      features, num_or_size_splits=num_transforms, axis=-1)
  if FLAGS.use_td_loss:
    target_images = features_list[-1]
    features_list = features_list[:-1]
    # Transform parameters (thetas) for the top-down pathway.
    thetas_list = tf.split(
        labels['thetas'], num_or_size_splits=num_transforms, axis=-1)
    if FLAGS.train_mode == 'pretrain':  # fix for fine-tuning/eval
      thetas = tf.concat(thetas_list[:-1], 0)
  else:
    target_images = features_list

  if FLAGS.use_blur and is_training and FLAGS.train_mode == 'pretrain':
    features_list, sigmas = data_util.batch_random_blur(
        features_list, FLAGS.image_size, FLAGS.image_size)
    if FLAGS.use_td_loss:
      sigmas = tf.concat(sigmas, 0)
      thetas = tf.concat([thetas, sigmas[:, None]], 1)
  else:
    if FLAGS.use_td_loss:
      sigmas = tf.zeros_like(thetas[:, 0])
      thetas = tf.concat([thetas, sigmas[:, None]], 1)
      # thetas = tf.zeros([target_images.get_shape().as_list()[0], 11])
  features = tf.concat(features_list, 0)  # (num_transforms * bsz, h, w, c)

  # Base network forward pass.
  with tf.variable_scope('base_model'):
    if FLAGS.train_mode == 'finetune' and FLAGS.fine_tune_after_block >= 4:
      # Fine-tuning just the supervised (linear) head will not update BN stats.
      model_train_mode = False
    else:
      if FLAGS.use_td_loss:
        viz_features = features
        features = (features, thetas)
      else:
        viz_features = features
      # Pretraining, or fine-tuning anything else, will update BN stats.
      model_train_mode = is_training
    outputs = model(features, is_training=model_train_mode)

  # Add head and loss.
  if FLAGS.train_mode == 'pretrain':
    tpu_context = params['context'] if 'context' in params else None
    if FLAGS.use_td_loss and isinstance(outputs, tuple):
      hiddens, reconstruction, metric_hidden_r, metric_hidden_t = outputs
    else:
      hiddens = outputs
      reconstruction = features
    if FLAGS.use_td_loss:
      with tf.name_scope('td_loss'):
        if FLAGS.td_loss == 'attractive':
          td_loss, logits_td_con, labels_td_con = obj_lib.td_attractive_loss(
              reconstruction=metric_hidden_r,
              target=metric_hidden_t,
              temperature=FLAGS.temperature,
              tpu_context=tpu_context if is_training else None)
          # The attractive loss carries no contrastive logits/labels;
          # overwrite them with fixed-shape zero placeholders.
          logits_td_con = tf.zeros([params['batch_size'],
                                    params['batch_size']])
          labels_td_con = tf.zeros([params['batch_size'],
                                    params['batch_size']])
        elif FLAGS.td_loss == 'attractive_repulsive':
          td_loss, logits_td_con, labels_td_con = (
              obj_lib.td_attractive_repulsive_loss(
                  reconstruction=metric_hidden_r,
                  target=metric_hidden_t,
                  temperature=FLAGS.temperature,
                  tpu_context=tpu_context if is_training else None))
        else:
          raise NotImplementedError(
              'Error at TD loss {}'.format(FLAGS.td_loss))
    else:
      # No TD loss: zero placeholders.
      logits_td_con = tf.zeros([params['batch_size'], params['batch_size']])
      labels_td_con = tf.zeros([params['batch_size'], params['batch_size']])
      td_loss = 0.
    # Bottom-up (BU) contrastive head on the pooled hiddens.
    hiddens_proj = model_util.projection_head(hiddens, is_training)
    if FLAGS.use_bu_loss:
      with tf.name_scope('bu_loss'):
        if FLAGS.bu_loss == 'attractive':
          bu_loss, logits_bu_con, labels_bu_con = obj_lib.attractive_loss(
              hiddens_proj,
              temperature=FLAGS.temperature,
              hidden_norm=FLAGS.hidden_norm)
          # As in the TD branch, the attractive loss carries no contrastive
          # logits/labels; overwrite them with zero placeholders.
          logits_bu_con = tf.zeros([params['batch_size'],
                                    params['batch_size']])
          labels_bu_con = tf.zeros([params['batch_size'],
                                    params['batch_size']])
        elif FLAGS.bu_loss == 'attractive_repulsive':
          bu_loss, logits_bu_con, labels_bu_con = (
              obj_lib.attractive_repulsive_loss(
                  hiddens_proj,
                  hidden_norm=FLAGS.hidden_norm,
                  temperature=FLAGS.temperature,
                  tpu_context=tpu_context if is_training else None))
        else:
          raise NotImplementedError('Unknown loss')
    else:
      # No BU loss: zero placeholders.
      logits_bu_con = tf.zeros([params['batch_size'], params['batch_size']])
      labels_bu_con = tf.zeros([params['batch_size'], params['batch_size']])
      bu_loss = 0.
    logits_sup = tf.zeros([params['batch_size'], num_classes])
  else:
    # Fine-tuning: no contrastive losses; zero placeholders keep the metric
    # plumbing uniform.
    td_loss = tf.zeros([])
    bu_loss = tf.zeros([])
    logits_td_con = tf.zeros([params['batch_size'], 10])
    labels_td_con = tf.zeros([params['batch_size'], 10])
    logits_bu_con = tf.zeros([params['batch_size'], 10])
    labels_bu_con = tf.zeros([params['batch_size'], 10])
    hiddens = outputs
    hiddens = model_util.projection_head(hiddens, is_training)
    logits_sup = model_util.supervised_head(hiddens, num_classes, is_training)
    sup_loss = obj_lib.supervised_loss(
        labels=labels['labels'], logits=logits_sup, weights=labels['mask'])

  # Add weight decay to loss, for non-LARS optimizers.
  model_util.add_weight_decay(adjust_per_optimizer=True)

  if FLAGS.train_mode == 'pretrain':
    loss = tf.add_n([td_loss * FLAGS.td_loss_weight,
                     bu_loss * FLAGS.bu_loss_weight] +
                    tf.losses.get_regularization_losses())
  else:
    loss = tf.add_n([sup_loss] + tf.losses.get_regularization_losses())

  if FLAGS.train_mode == 'pretrain':
    variables_to_train = tf.trainable_variables()
  else:
    collection_prefix = 'trainable_variables_inblock_'
    variables_to_train = []
    for j in range(FLAGS.fine_tune_after_block + 1, 6):
      variables_to_train += tf.get_collection(collection_prefix + str(j))
    assert variables_to_train, 'variables_to_train shouldn\'t be empty!'

  tf.logging.info('===============Variables to train (begin)===============')
  tf.logging.info(variables_to_train)
  tf.logging.info('================Variables to train (end)================')

  learning_rate = model_util.learning_rate_schedule(
      FLAGS.learning_rate, num_train_examples)

  if is_training:
    if FLAGS.train_summary_steps > 0:
      # Compute stats for the summary.
      prob_bu_con = tf.nn.softmax(logits_bu_con)
      entropy_bu_con = -tf.reduce_mean(
          tf.reduce_sum(prob_bu_con * tf.math.log(prob_bu_con + 1e-8), -1))
      prob_td_con = tf.nn.softmax(logits_td_con)
      entropy_td_con = -tf.reduce_mean(
          tf.reduce_sum(prob_td_con * tf.math.log(prob_td_con + 1e-8), -1))
      contrast_bu_acc = tf.equal(
          tf.argmax(labels_bu_con, 1), tf.argmax(logits_bu_con, axis=1))
      contrast_bu_acc = tf.reduce_mean(tf.cast(contrast_bu_acc, tf.float32))
      contrast_td_acc = tf.equal(
          tf.argmax(labels_td_con, 1), tf.argmax(logits_td_con, axis=1))
      contrast_td_acc = tf.reduce_mean(tf.cast(contrast_td_acc, tf.float32))
      label_acc = tf.equal(
          tf.argmax(labels['labels'], 1), tf.argmax(logits_sup, axis=1))
      label_acc = tf.reduce_mean(tf.cast(label_acc, tf.float32))

      def host_call_fn(gs, g_l, bu_l, td_l, c_bu_a, c_td_a, l_a, c_e_bu,
                       c_e_td, lr, tar_im, viz_f, rec_im):
        gs = gs[0]
        with tf2.summary.create_file_writer(
            FLAGS.model_dir, max_queue=FLAGS.checkpoint_steps).as_default():
          with tf2.summary.record_if(True):
            tf2.summary.scalar('total_loss', g_l[0], step=gs)
            tf2.summary.scalar('train_bottomup_loss', bu_l[0], step=gs)
            tf2.summary.scalar('train_topdown_loss', td_l[0], step=gs)
            tf2.summary.scalar('train_bottomup_acc', c_bu_a[0], step=gs)
            tf2.summary.scalar('train_topdown_acc', c_td_a[0], step=gs)
            tf2.summary.scalar('train_label_accuracy', l_a[0], step=gs)
            tf2.summary.scalar('contrast_bu_entropy', c_e_bu[0], step=gs)
            tf2.summary.scalar('contrast_td_entropy', c_e_td[0], step=gs)
            tf2.summary.scalar('learning_rate', lr[0], step=gs)
            tf2.summary.image('Images', tar_im[0], step=gs)
            tf2.summary.image('Transformed images', viz_f[0], step=gs)
            tf2.summary.image('Reconstructed images', rec_im[0], step=gs)
        return tf.summary.all_v2_summary_ops()

      # Batch the first n_images of each image tensor for the host call.
      n_images = 4
      if isinstance(target_images, list):
        target_images = target_images[0]
      image_shape = target_images.get_shape().as_list()
      tar_im = tf.reshape(tf.cast(target_images[:n_images], tf.float32),
                          [1, n_images] + image_shape[1:])
      viz_f = tf.reshape(tf.cast(viz_features[:n_images], tf.float32),
                         [1, n_images] + image_shape[1:])
      rec_im = tf.reshape(tf.cast(reconstruction[:n_images], tf.float32),
                          [1, n_images] + image_shape[1:])
      gs = tf.reshape(tf.train.get_global_step(), [1])
      g_l = tf.reshape(loss, [1])
      bu_l = tf.reshape(bu_loss, [1])
      td_l = tf.reshape(td_loss, [1])
      c_bu_a = tf.reshape(contrast_bu_acc, [1])
      c_td_a = tf.reshape(contrast_td_acc, [1])
      l_a = tf.reshape(label_acc, [1])
      c_e_bu = tf.reshape(entropy_bu_con, [1])
      c_e_td = tf.reshape(entropy_td_con, [1])
      lr = tf.reshape(learning_rate, [1])
      host_call = (host_call_fn,
                   [gs, g_l, bu_l, td_l, c_bu_a, c_td_a, l_a, c_e_bu, c_e_td,
                    lr, tar_im, viz_f, rec_im])
    else:
      host_call = None

    optimizer = model_util.get_optimizer(learning_rate)
    control_deps = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(control_deps):
      train_op = optimizer.minimize(
          loss,
          global_step=tf.train.get_or_create_global_step(),
          var_list=variables_to_train)

    if FLAGS.checkpoint:
      def scaffold_fn():
        """Scaffold function to restore non-logits vars from checkpoint."""
        tf.logging.info('*' * 180)
        tf.logging.info('Initializing from checkpoint %s' % FLAGS.checkpoint)
        tf.logging.info('*' * 180)
        tf.train.init_from_checkpoint(
            FLAGS.checkpoint,
            {v.op.name: v.op.name
             for v in tf.global_variables(FLAGS.variable_schema)})
        if FLAGS.zero_init_logits_layer:
          # Init op that initializes output layer parameters to zeros.
          output_layer_parameters = [
              var for var in tf.trainable_variables()
              if var.name.startswith('head_supervised')]
          tf.logging.info('Initializing output layer parameters %s to zero',
                          [x.op.name for x in output_layer_parameters])
          with tf.control_dependencies([tf.global_variables_initializer()]):
            init_op = tf.group([
                tf.assign(x, tf.zeros_like(x))
                for x in output_layer_parameters])
          return tf.train.Scaffold(init_op=init_op)
        else:
          return tf.train.Scaffold()
    else:
      scaffold_fn = None

    return tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        train_op=train_op,
        loss=loss,
        scaffold_fn=scaffold_fn,
        host_call=host_call)
  else:

    def metric_fn(logits_sup, labels_sup, logits_bu_con, labels_bu_con,
                  logits_td_con, labels_td_con, mask, **kws):
      """Inner metric function."""
      metrics = {
          k: tf.metrics.mean(v, weights=mask) for k, v in kws.items()}
      metrics['label_top_1_accuracy'] = tf.metrics.accuracy(
          tf.argmax(labels_sup, 1), tf.argmax(logits_sup, axis=1),
          weights=mask)
      metrics['label_top_5_accuracy'] = tf.metrics.recall_at_k(
          tf.argmax(labels_sup, 1), logits_sup, k=5, weights=mask)
      metrics['bottomup_top_1_accuracy'] = tf.metrics.accuracy(
          tf.argmax(labels_bu_con, 1), tf.argmax(logits_bu_con, axis=1),
          weights=mask)
      metrics['topdown_top_1_accuracy'] = tf.metrics.accuracy(
          tf.argmax(labels_td_con, 1), tf.argmax(logits_td_con, axis=1),
          weights=mask)
      return metrics

    metrics = {
        'logits_sup': logits_sup,
        'labels_sup': labels['labels'],
        'logits_bu_con': logits_bu_con,
        'logits_td_con': logits_td_con,
        'labels_bu_con': labels_bu_con,
        'labels_td_con': labels_td_con,
        'mask': labels['mask'],
        'td_loss': tf.fill((params['batch_size'],), td_loss),
        'bu_loss': tf.fill((params['batch_size'],), bu_loss),
        'regularization_loss': tf.fill(
            (params['batch_size'],), tf.losses.get_regularization_loss()),
    }
    return tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        eval_metrics=(metric_fn, metrics),
        host_call=None,
        scaffold_fn=None)
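A model_fn of this shape is handed to a TPUEstimator (TF 1.x). A hedged wiring sketch, assuming the flags shown plus a build_input_fn helper and total_train_steps value that are not part of the file above:

# Minimal wiring sketch for the TPUEstimator API; values and the input-fn
# helper are assumptions, not taken from the file above.
run_config = tf.estimator.tpu.RunConfig(
    model_dir=FLAGS.model_dir,
    tpu_config=tf.estimator.tpu.TPUConfig(
        iterations_per_loop=FLAGS.checkpoint_steps))
estimator = tf.estimator.tpu.TPUEstimator(
    model_fn=model_fn,
    use_tpu=FLAGS.use_tpu,
    config=run_config,
    train_batch_size=FLAGS.train_batch_size,
    eval_batch_size=FLAGS.eval_batch_size)
estimator.train(input_fn=build_input_fn(is_training=True),
                max_steps=total_train_steps)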