def main(argv):
    """Train the domain-adaptation method selected by FLAGS.

    Creates the model/log directories, writes the run config, loads the
    datasets and builds the method. If the method is not trainable it returns
    right after writing the config. Otherwise it runs the training loop with
    periodic metrics, evaluation, checkpointing and plots, and finally calls
    ``file_utils.write_finished(log_dir)``.

    Args:
        argv: command-line arguments (unused; settings come from FLAGS).

    Raises:
        ValueError: if ``--uid`` is the empty string.
    """
    # Allow running multiple at once
    set_gpu_memory(FLAGS.gpumem)

    # Figure out the log and model directory filenames. This is an explicit
    # check rather than an assert so it is not stripped under ``python -O``.
    if FLAGS.uid == "":
        raise ValueError("uid cannot be an empty string")
    model_dir, log_dir = get_directory_names()

    # exist_ok=True: when several runs start at once, another process may
    # create the directory between an exists() check and makedirs().
    os.makedirs(model_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    # Write config file about what dataset we're using, sources, target, etc.
    file_utils.write_config_from_args(log_dir)

    # Load datasets
    source_datasets, target_dataset = load_datasets.load_da(
        FLAGS.dataset, FLAGS.sources, FLAGS.target, test=FLAGS.test)

    # Need to know which iteration for learning rate schedule
    global_step = tf.Variable(0, name="global_step", trainable=False)

    # Load the method, model, etc.
    method = methods.get_method(
        FLAGS.method,
        source_datasets=source_datasets,
        target_dataset=target_dataset,
        model_name=FLAGS.model,
        global_step=global_step,
        total_steps=FLAGS.steps,
        ensemble_size=FLAGS.ensemble,
        moving_average=FLAGS.moving_average,
        share_most_weights=FLAGS.share_most_weights)

    # Check that this method is supposed to be trainable. If not, we're done.
    # (Basically, we just wanted to write the config file for non-trainable
    # models.)
    if not method.trainable:
        print("Method not trainable. Exiting now.")
        return

    # Checkpoints -- global_step is saved too so training can resume from it
    checkpoint = tf.train.Checkpoint(
        global_step=global_step, **method.checkpoint_variables)
    checkpoint_manager = CheckpointManager(checkpoint, model_dir, log_dir)
    checkpoint_manager.restore_latest()

    # Metrics
    has_target_domain = target_dataset is not None
    metrics = Metrics(log_dir, method, source_datasets, target_dataset,
        has_target_domain)

    # Start training
    #
    # TODO maybe eventually rewrite this in the more-standard Keras way
    # See: https://www.tensorflow.org/guide/keras/train_and_evaluate
    for i in range(int(global_step), FLAGS.steps + 1):
        t = time.time()
        data_sources, data_target = method.train_step()
        global_step.assign_add(1)
        t = time.time() - t

        if FLAGS.time_training:
            print(int(global_step), t, sep=",")
            continue  # skip evaluation, checkpointing, etc. when timing

        if i % 1000 == 0:
            print("step %d took %f seconds" % (int(global_step), t))
            sys.stdout.flush()  # otherwise waits till the end to flush on Kamiak

        # Metrics on training/validation data (0 disables)
        if FLAGS.log_train_steps != 0 and i % FLAGS.log_train_steps == 0:
            metrics.train(data_sources, data_target, global_step, t)

        # Evaluate every log_val_steps but also at the last step
        validation_accuracy_source = None
        validation_accuracy_target = None
        if (FLAGS.log_val_steps != 0 and i % FLAGS.log_val_steps == 0) \
                or i == FLAGS.steps:
            validation_accuracy_source, validation_accuracy_target \
                = metrics.test(global_step)
            print(validation_accuracy_source, validation_accuracy_target)

        # Checkpoints -- Save either if at the right model step or if we found
        # a new validation accuracy. If this is better than the previous best
        # model, we need to make a new checkpoint so we can restore from this
        # step with the best accuracy. global_step was already incremented,
        # so global_step - 1 is the step index just trained (i).
        if (FLAGS.model_steps != 0 and i % FLAGS.model_steps == 0) \
                or validation_accuracy_source is not None:
            checkpoint_manager.save(int(global_step - 1),
                validation_accuracy_source, validation_accuracy_target)

        # Plots
        if FLAGS.log_plots_steps != 0 and i % FLAGS.log_plots_steps == 0:
            metrics.plots(global_step)

    # We're done -- used for hyperparameter tuning
    file_utils.write_finished(log_dir)
def process_model(log_dir, model_dir, config, gpumem, multi_gpu):
    """Evaluate one trained model on its train/test data and gather results.

    Args:
        log_dir: log directory of the run (given to CheckpointManager/Metrics).
        model_dir: directory holding the run's checkpoints.
        config: mapping of run settings; must contain the keys "dataset",
            "method", "model", "sources", "target", "moving_average",
            "ensemble" and "share_most_weights".
        gpumem: forwarded to setup_gpu_for_process.
        multi_gpu: forwarded to setup_gpu_for_process.

    Returns:
        ``(log_dir, model_dir, config, results, max_accuracy_step,
        max_accuracy)``. If the checkpoint chosen by ``--selection`` was not
        found, ``results`` is ``{}`` and the last two entries are None. For
        methods with no checkpoint variables, nothing is restored and the last
        two entries are None.

    Raises:
        NotImplementedError: if the method has checkpoint variables and
            ``--selection`` is not "last", "best_source" or "best_target".
    """
    setup_gpu_for_process(gpumem, multi_gpu)

    # Pull the run settings out of the config
    (dataset_name, method_name, model_name, sources, target,
        moving_average, ensemble_size, share_most_weights) = (
        config[key] for key in (
            "dataset", "method", "model", "sources", "target",
            "moving_average", "ensemble", "share_most_weights"))

    # Load datasets
    source_datasets, target_dataset = load_datasets.load_da(
        dataset_name, sources, target, test=FLAGS.test)

    # global_step / total_steps only matter for training, so any value works
    method = methods.get_method(
        method_name,
        source_datasets=source_datasets,
        target_dataset=target_dataset,
        model_name=model_name,
        global_step=1,
        total_steps=1,
        moving_average=moving_average,
        ensemble_size=ensemble_size,
        share_most_weights=share_most_weights,
    )

    # Defaults for a method with nothing to restore: evaluate it as-is
    max_accuracy_step = None
    max_accuracy = None
    found = True

    if method.checkpoint_variables:
        checkpoint = tf.train.Checkpoint(**method.checkpoint_variables)
        manager = CheckpointManager(checkpoint, model_dir, log_dir)
        selection = FLAGS.selection

        if selection == "last":
            manager.restore_latest()
            max_accuracy_step = manager.latest_step()
            # max_accuracy stays None: not meaningful for the last checkpoint
            found = manager.found_last
        elif selection == "best_source":
            manager.restore_best_source()
            max_accuracy_step = manager.best_step_source()
            max_accuracy = manager.best_validation_source
            found = manager.found_best_source
        elif selection == "best_target":
            manager.restore_best_target()
            max_accuracy_step = manager.best_step_target()
            max_accuracy = manager.best_validation_target
            found = manager.found_best_target
        else:
            raise NotImplementedError("unknown --selection argument")

    metrics = Metrics(log_dir, method, source_datasets, target_dataset,
        target_dataset is not None)

    # Nothing to evaluate if the requested checkpoint doesn't exist
    if not found:
        return log_dir, model_dir, config, {}, None, None

    # Evaluate on both datasets
    metrics.train_eval()
    metrics.test(evaluation=True)

    return (log_dir, model_dir, config, metrics.results(),
        max_accuracy_step, max_accuracy)
def process_model(log_dir, model_dir, source, target, model_name, method_name,
        gpumem, multi_gpu):
    """Evaluate a model on the train/test data and compute the results.

    Args:
        log_dir: log directory of the run (given to CheckpointManager/Metrics).
        model_dir: directory holding the run's checkpoints.
        source: source dataset name passed to load_datasets.load_da.
        target: target dataset name; must be non-empty.
        model_name: model name passed to DomainAdaptationModel.
        method_name: method name. "pseudo"/"instance" together with
            --target_classifier also evaluate the target classifier.
        gpumem: GPU memory setting passed to set_gpu_memory.
        multi_gpu: if true, restrict this process to the GPU selected by its
            pool id via CUDA_VISIBLE_DEVICES.

    Returns:
        A 13-tuple ``(log_dir, source, target, model_name, method_name,
        s_train, t_train, s_test, t_test, target_s_train, target_s_test,
        target_t_train, target_t_test)``. If no checkpoint was found, the
        tuple has the same shape with all eight accuracies set to None.
        Target-classifier accuracies are None unless the method uses one.

    Raises:
        NotImplementedError: if ``target`` is empty (source-only evaluation
            is not supported).
    """
    # We need to do this in the process since otherwise TF can't access cuDNN
    # for some reason. But, we only need to do this the first time we create
    # the process. It'll error on any subsequent calls (since the pool re-uses
    # processes).
    try:
        set_gpu_memory(gpumem)
    except RuntimeError:
        pass  # Ignore: "RuntimeError: GPU options must be set at program startup"

    # Get what GPU to run this on, otherwise it'll default to whatever the
    # first one is
    if multi_gpu:
        # Get all GPUs SLURM gave to us and what process in the pool this is
        available_gpus = get_gpus()
        pool_id = get_pool_id()

        # Pick which one based on pool id
        gpu = available_gpus[pool_id]

        # Only let TensorFlow see this GPU. I tried tf.device, but somehow
        # each process still put some stuff into memory on every GPU.
        # NOTE(review): CUDA_VISIBLE_DEVICES generally only takes effect if
        # set before TensorFlow initializes its GPUs -- confirm that
        # set_gpu_memory above hasn't already done so.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)

    # Load datasets
    if target == "":
        raise NotImplementedError("currently don't support only source")

    source_dataset, target_dataset = load_datasets.load_da(source, target,
        test=True)
    assert source_dataset.num_classes == target_dataset.num_classes, \
        "Adapting from source to target with different classes not supported"

    # Evaluation datasets if we have the dataset
    source_dataset_train = source_dataset.train_evaluation
    target_dataset_train = target_dataset.train_evaluation \
        if target_dataset is not None else None
    source_dataset_test = source_dataset.test_evaluation
    target_dataset_test = target_dataset.test_evaluation \
        if target_dataset is not None else None

    # Information about domains
    num_classes = source_dataset.num_classes

    # Build our model
    # Note: {global,num}_step are for training, so it doesn't matter what
    # we set them to here
    global_step = 1
    num_steps = 1
    model = DomainAdaptationModel(num_classes, model_name, global_step,
        num_steps)

    # Does this method use a target classifier?
    has_target_classifier = method_name in ["pseudo", "instance"] \
        and FLAGS.target_classifier

    # Load model from checkpoint
    checkpoint = tf.train.Checkpoint(model=model)
    checkpoint_manager = CheckpointManager(checkpoint, model_dir, log_dir,
        target=has_target_classifier)

    if FLAGS.last:
        checkpoint_manager.restore_latest()
        max_accuracy_step = checkpoint_manager.latest_step()
        max_accuracy = 0  # We don't really care...
    else:
        checkpoint_manager.restore_best(FLAGS.best_target)
        max_accuracy_step = checkpoint_manager.best_step(FLAGS.best_target)

        if has_target_classifier and FLAGS.best_target:
            max_accuracy = checkpoint_manager.best_target_validation
        else:
            max_accuracy = checkpoint_manager.best_validation

    # Print which step we're loading the model for
    print(",".join([log_dir, source, target, method_name, model_name,
        str(max_accuracy_step), str(max_accuracy)]))

    # If not found, give up -- but keep the same 13-tuple shape as the
    # success path (including log_dir) so callers can unpack uniformly.
    if not checkpoint_manager.found:
        return (log_dir, source, target, model_name, method_name) + (None,) * 8

    # Metrics
    have_target_domain = target_dataset is not None
    metrics = Metrics(log_dir, source_dataset, None, None, have_target_domain,
        target_classifier=has_target_classifier,
        enable_compile=False)

    # Evaluate on both datasets
    metrics.train(model, source_dataset_train, target_dataset_train,
        evaluation=True)
    metrics.test(model, source_dataset_test, target_dataset_test,
        evaluation=True)

    # Get results
    results = metrics.results()
    s_train = results["accuracy_task/source/training"]
    s_test = results["accuracy_task/source/validation"]

    target_s_train = None
    target_s_test = None
    target_t_train = None
    target_t_test = None

    if has_target_classifier:
        target_s_train = results["accuracy_target/source/training"]
        target_s_test = results["accuracy_target/source/validation"]

    if target_dataset is not None:
        t_train = results["accuracy_task/target/training"]
        t_test = results["accuracy_task/target/validation"]

        if has_target_classifier:
            target_t_train = results["accuracy_target/target/training"]
            target_t_test = results["accuracy_target/target/validation"]
    else:
        t_train = None
        t_test = None

    return log_dir, source, target, model_name, method_name, \
        s_train, t_train, s_test, t_test, \
        target_s_train, target_s_test, target_t_train, target_t_test
def main(argv):
    """Train a domain-adaptation model (GRL, GAN-style, or no adaptation).

    The training variant is chosen by --method and --use_grl. The "pseudo"
    and "instance" methods additionally train a target classifier with
    t_opt. Metrics are logged every --log_train_steps, evaluation runs every
    --log_val_steps, and checkpoints are saved every --model_steps or
    whenever an evaluation ran. A value of 0 for any of these disables it.
    Finally calls ``write_finished(log_dir)``.

    Args:
        argv: command-line arguments (unused; settings come from FLAGS).

    Raises:
        NotImplementedError: if --target is empty (source-only training is
            not supported), or if --method=instance is combined with
            --use_grl (instance weights are only produced by train_step_gan).
    """
    # Allow running multiple at once
    set_gpu_memory(FLAGS.gpumem)

    # The "instance" method trains on instance_weights, which only the
    # GAN-style step (train_step_gan) returns. With use_grl they would never
    # be assigned, so fail up front instead of mid-training.
    if FLAGS.method == "instance" and FLAGS.use_grl:
        raise NotImplementedError(
            "method 'instance' is not supported with --use_grl")

    # Figure out the log and model directory filenames
    model_dir, log_dir = get_directory_names()

    # exist_ok=True: when several runs start at once, another process may
    # create the directory between an exists() check and makedirs().
    os.makedirs(model_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    # We adapt for any of the methods other than "none" (no adaptation)
    adapt = FLAGS.method != "none"

    # For adaptation, we'll be concatenating together half source and half
    # target data, so to keep the batch_size about the same, we'll cut it in half
    train_batch = FLAGS.train_batch

    if adapt and FLAGS.use_grl:
        train_batch = train_batch // 2

    # Input training data
    #
    # Note: "It is worth noting that only the training sets of the small image
    # datasets were used during training; the test sets used for reporting
    # scores only." (self-ensembling) -- so, only use *_test for evaluation.
    # However, for now we'll use 1000 random target test samples for the
    # validation dataset (as is common).
    if FLAGS.target == "":
        raise NotImplementedError("currently don't support only source")

    source_dataset, target_dataset = load_datasets.load_da(
        FLAGS.source, FLAGS.target, test=FLAGS.test, train_batch=train_batch)
    assert source_dataset.num_classes == target_dataset.num_classes, \
        "Adapting from source to target with different classes not supported"

    # Iterator and evaluation datasets if we have the dataset
    source_iter = iter(source_dataset.train)
    source_dataset_eval = source_dataset.test_evaluation
    target_iter = iter(target_dataset.train) \
        if target_dataset is not None else None
    target_dataset_eval = target_dataset.test_evaluation \
        if target_dataset is not None else None

    # Information about domains
    num_classes = source_dataset.num_classes

    # Loss functions
    task_loss = models.make_task_loss(adapt and FLAGS.use_grl)
    domain_loss = models.make_domain_loss(adapt)
    weighted_task_loss = models.make_weighted_loss()

    # We need to know where we are in training for the GRL lambda schedule
    global_step = tf.Variable(0, name="global_step", trainable=False)

    # Build our model
    model = models.DomainAdaptationModel(num_classes, FLAGS.model,
        global_step, FLAGS.steps, use_grl=FLAGS.use_grl)

    # Optimizers
    opt = tf.keras.optimizers.Adam(FLAGS.lr)
    d_opt = tf.keras.optimizers.Adam(FLAGS.lr * FLAGS.lr_domain_mult)

    # For GAN-like training (train_step_gan), we'll weight by the GRL schedule
    # to make it more equivalent to when use_grl=True.
    grl_schedule = models.DannGrlSchedule(FLAGS.steps)

    # Target classifier optimizer if target_classifier, otherwise the optimizer
    # for the task-classifier when running on pseudo-labeled data
    has_target_classifier = FLAGS.method in ["pseudo", "instance"]
    t_opt = tf.keras.optimizers.Adam(FLAGS.lr * FLAGS.lr_target_mult)

    # Checkpoints
    checkpoint = tf.train.Checkpoint(
        global_step=global_step, opt=opt, d_opt=d_opt, t_opt=t_opt,
        model=model)
    checkpoint_manager = CheckpointManager(checkpoint, model_dir, log_dir,
        target=has_target_classifier)
    checkpoint_manager.restore_latest()

    # Metrics
    has_target_domain = target_dataset is not None
    metrics = Metrics(log_dir, source_dataset, task_loss, domain_loss,
        has_target_domain, has_target_classifier,
        enable_compile=FLAGS.compile_metrics)

    # Start training
    for i in range(int(global_step), FLAGS.steps + 1):
        # Get data for this iteration
        data_a = next(source_iter)
        data_b = next(target_iter) if target_iter is not None else None

        t = time.time()
        step_args = (data_a, data_b, model, opt, d_opt, task_loss, domain_loss)

        if adapt and FLAGS.use_grl:
            train_step_grl(*step_args)
        elif adapt:
            instance_weights = train_step_gan(*step_args, grl_schedule,
                global_step)
        else:
            train_step_none(*step_args)

        if FLAGS.method == "pseudo":
            # We'll ignore the real labels, so just get the data
            x, _ = data_b

            # Pseudo-label target data
            if FLAGS.use_domain_confidence:
                task_y_pred, weights = pseudo_label_domain(x, model)
            else:
                task_y_pred, weights = pseudo_label_task(x, model)

            # Create new data with same input by pseudo-labels not true labels
            data_b_pseudo = (x, task_y_pred)

            # Train target classifier on pseudo-labeled data, weighted
            # by probability that it's source data (i.e. higher confidence)
            train_step_target(data_b_pseudo, weights, model, t_opt,
                weighted_task_loss)
        elif FLAGS.method == "instance":
            # Train target classifier on source data, but weighted
            # by probability that it's target data
            train_step_target(data_a, instance_weights, model, t_opt,
                weighted_task_loss)

        global_step.assign_add(1)
        t = time.time() - t

        if i % 10 == 0:
            logging.info("step %d took %f seconds", int(global_step), t)

        # Metrics on training/validation data (0 disables)
        if FLAGS.log_train_steps != 0 and i % FLAGS.log_train_steps == 0:
            metrics.train(model, data_a, data_b, global_step, t)

        # Reset both accuracies every step so a checkpoint saved on a
        # non-evaluation step never sees a stale (or unbound) target accuracy.
        validation_accuracy = None
        target_validation_accuracy = None

        if FLAGS.log_val_steps != 0 and i % FLAGS.log_val_steps == 0:
            validation_accuracy, target_validation_accuracy = metrics.test(
                model, source_dataset_eval, target_dataset_eval, global_step)

        # Checkpoints -- Save either if at the right model step or if we found
        # a new validation accuracy. If this is better than the previous best
        # model, we need to make a new checkpoint so we can restore from this
        # step with the best accuracy. global_step was already incremented,
        # so global_step - 1 is the step index just trained (i).
        if (FLAGS.model_steps != 0 and i % FLAGS.model_steps == 0) \
                or validation_accuracy is not None:
            checkpoint_manager.save(int(global_step - 1), validation_accuracy,
                target_validation_accuracy)

    # We're done -- used for hyperparameter tuning
    write_finished(log_dir)