def load_data():
  """Loads the dataset selected by `FLAGS.data_source`.

  Returns:
    Whatever the matching loader returns for the configured dataset.

  Raises:
    ValueError: If `FLAGS.data_source` is not a recognized source name.
  """
  source = FLAGS.data_source
  if source == 'tensorflow_datasets':
    return load_data_tf_datasets(FLAGS.dataset_name,
                                 FLAGS.target_num_train_per_class,
                                 FLAGS.target_num_val, FLAGS.seed)
  if source == 'realistic_ssl':
    return load_data_realistic_ssl(FLAGS.dataset_name,
                                   FLAGS.filename_preprocessed_data,
                                   FLAGS.label_map_path)
  if source == 'planetoid':
    return load_data_planetoid(FLAGS.dataset_name,
                               FLAGS.preprocessed_data_dir,
                               row_normalize=False)
  raise ValueError('Unsupported dataset source name: %s' % FLAGS.data_source)
def _build_model_name():
  """Builds a model-name string encoding the run's hyperparameters.

  The name is used to keep summaries, checkpoints and self-labeled data of
  different configurations in separate directories.

  Returns:
    A string identifier derived from the command-line FLAGS.
  """
  model_name = FLAGS.model_cls
  # Hidden-layer sizes are only part of the name for MLP architectures.
  model_name += ('_' + FLAGS.hidden_cls) if FLAGS.model_cls == 'mlp' else ''
  model_name += '-' + FLAGS.model_agr
  model_name += ('_' + FLAGS.hidden_agr) if FLAGS.model_agr == 'mlp' else ''
  model_name += '-aggr_' + FLAGS.aggregation_agr_inputs
  model_name += ('_' + FLAGS.hidden_aggreg) if FLAGS.hidden_aggreg else ''
  model_name += (
      '-add_%d-conf_%.2f-iterCls_%d-iterAgr_%d-batchCls_%d' %
      (FLAGS.num_samples_to_label, FLAGS.min_confidence_new_label,
       FLAGS.max_num_iter_cls, FLAGS.max_num_iter_agr, FLAGS.batch_size_cls))
  # Weight-decay terms are omitted from the name when disabled (falsy).
  model_name += (('-wdecayCls_%.4f' % FLAGS.weight_decay_cls)
                 if FLAGS.weight_decay_cls else '')
  model_name += (('-wdecayAgr_%.4f' % FLAGS.weight_decay_agr)
                 if FLAGS.weight_decay_agr else '')
  model_name += '-LL_%s_LU_%s_UU_%s' % (str(FLAGS.reg_weight_ll),
                                        str(FLAGS.reg_weight_lu),
                                        str(FLAGS.reg_weight_uu))
  model_name += '-perfAgr' if FLAGS.use_perfect_agreement else ''
  model_name += '-perfCls' if FLAGS.use_perfect_classifier else ''
  model_name += '-keepProp' if FLAGS.keep_label_proportions else ''
  model_name += '-PenNegAgr' if FLAGS.penalize_neg_agr else ''
  model_name += '-VAT' if FLAGS.reg_weight_vat > 0 else ''
  model_name += 'ENT' if FLAGS.reg_weight_vat > 0 and FLAGS.use_ent_min else ''
  model_name += '-transd' if not FLAGS.inductive else ''
  model_name += '-L2' if FLAGS.use_l2_cls else '-CE'
  model_name += '-graph' if FLAGS.use_graph else '-noGraph'
  model_name += '-rowNorm' if FLAGS.row_normalize else ''
  model_name += '-seed_' + str(FLAGS.seed)
  model_name += FLAGS.experiment_suffix
  return model_name


def main(argv):
  """Loads data, builds the GAM co-training models, and trains them.

  All configuration comes from command-line FLAGS.

  Args:
    argv: Command-line arguments as passed by `app.run`.

  Raises:
    app.UsageError: If extra positional command-line arguments are given.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  if FLAGS.logging_config:
    print('Setting logging configuration: ', FLAGS.logging_config)
    config.fileConfig(FLAGS.logging_config)

  # Set random seeds for reproducibility.
  np.random.seed(FLAGS.seed)
  tf.set_random_seed(FLAGS.seed)

  ############################################################################
  #                                  DATA                                    #
  ############################################################################
  # Load data.
  # NOTE(review): this always uses the planetoid loader; the `load_data`
  # dispatcher (which honors FLAGS.data_source) is not used here -- confirm
  # this is intended.
  data = load_data_planetoid(
      name=FLAGS.dataset_name,
      path=FLAGS.data_path,
      row_normalize=FLAGS.row_normalize)

  # Potentially add noisy edges. This can be used to assess the robustness of
  # GAM to noisy edges. See `Robustness` section of our paper.
  if FLAGS.target_ratio_correct:
    data = add_noisy_edges(data, FLAGS.target_ratio_correct)

  ############################################################################
  #                             PREPARE OUTPUTS                              #
  ############################################################################
  model_name = _build_model_name()
  logging.info('Model name: %s', model_name)

  # Create directories for model checkpoints, summaries, and self-labeled
  # data backup.
  # NOTE(review): summary_dir is not created here; presumably the summary
  # writer creates it on first write -- confirm.
  summary_dir = os.path.join(FLAGS.output_dir, 'summaries',
                             FLAGS.dataset_name, model_name)
  checkpoints_dir = os.path.join(FLAGS.output_dir, 'checkpoints',
                                 FLAGS.dataset_name, model_name)
  data_dir = os.path.join(FLAGS.data_output_dir, 'data_checkpoints',
                          FLAGS.dataset_name, model_name)
  # `exist_ok=True` avoids the check-then-create race of an explicit
  # `os.path.exists` guard.
  os.makedirs(checkpoints_dir, exist_ok=True)
  os.makedirs(data_dir, exist_ok=True)

  ############################################################################
  #                               MODEL SETUP                                #
  ############################################################################
  # Create classification model.
  model_cls = get_model_cls(
      model_name=FLAGS.model_cls,
      data=data,
      dataset_name=FLAGS.dataset_name,
      hidden=FLAGS.hidden_cls)

  # Create agreement model.
  model_agr = get_model_agr(
      model_name=FLAGS.model_agr,
      dataset_name=FLAGS.dataset_name,
      hidden_aggreg=FLAGS.hidden_aggreg,
      aggregation_agr_inputs=FLAGS.aggregation_agr_inputs,
      hidden=FLAGS.hidden_agr)

  # Assemble the co-training trainer from the two models and all
  # optimization / regularization / logging hyperparameters.
  trainer = TrainerCotraining(
      model_cls=model_cls,
      model_agr=model_agr,
      max_num_iter_cotrain=FLAGS.max_num_iter_cotrain,
      min_num_iter_cls=FLAGS.min_num_iter_cls,
      max_num_iter_cls=FLAGS.max_num_iter_cls,
      num_iter_after_best_val_cls=FLAGS.num_iter_after_best_val_cls,
      min_num_iter_agr=FLAGS.min_num_iter_agr,
      max_num_iter_agr=FLAGS.max_num_iter_agr,
      num_iter_after_best_val_agr=FLAGS.num_iter_after_best_val_agr,
      num_samples_to_label=FLAGS.num_samples_to_label,
      min_confidence_new_label=FLAGS.min_confidence_new_label,
      keep_label_proportions=FLAGS.keep_label_proportions,
      num_warm_up_iter_agr=FLAGS.num_warm_up_iter_agr,
      optimizer=tf.train.AdamOptimizer,
      gradient_clip=FLAGS.gradient_clip,
      batch_size_agr=FLAGS.batch_size_agr,
      batch_size_cls=FLAGS.batch_size_cls,
      learning_rate_cls=FLAGS.learning_rate_cls,
      learning_rate_agr=FLAGS.learning_rate_agr,
      enable_summaries=True,
      enable_summaries_per_model=True,
      summary_dir=summary_dir,
      summary_step_cls=FLAGS.summary_step_cls,
      summary_step_agr=FLAGS.summary_step_agr,
      logging_step_cls=FLAGS.logging_step_cls,
      logging_step_agr=FLAGS.logging_step_agr,
      eval_step_cls=FLAGS.eval_step_cls,
      eval_step_agr=FLAGS.eval_step_agr,
      checkpoints_dir=checkpoints_dir,
      checkpoints_step=1,
      data_dir=data_dir,
      abs_loss_chg_tol=1e-10,
      rel_loss_chg_tol=1e-7,
      loss_chg_iter_below_tol=30,
      use_perfect_agr=FLAGS.use_perfect_agreement,
      use_perfect_cls=FLAGS.use_perfect_classifier,
      warm_start_cls=FLAGS.warm_start_cls,
      warm_start_agr=FLAGS.warm_start_agr,
      ratio_valid_agr=FLAGS.ratio_valid_agr,
      max_samples_valid_agr=FLAGS.max_samples_valid_agr,
      weight_decay_cls=FLAGS.weight_decay_cls,
      weight_decay_schedule_cls=FLAGS.weight_decay_schedule_cls,
      weight_decay_schedule_agr=FLAGS.weight_decay_schedule_agr,
      weight_decay_agr=FLAGS.weight_decay_agr,
      reg_weight_ll=FLAGS.reg_weight_ll,
      reg_weight_lu=FLAGS.reg_weight_lu,
      reg_weight_uu=FLAGS.reg_weight_uu,
      num_pairs_reg=FLAGS.num_pairs_reg,
      reg_weight_vat=FLAGS.reg_weight_vat,
      use_ent_min=FLAGS.use_ent_min,
      penalize_neg_agr=FLAGS.penalize_neg_agr,
      use_l2_cls=FLAGS.use_l2_cls,
      first_iter_original=FLAGS.first_iter_original,
      inductive=FLAGS.inductive,
      seed=FLAGS.seed,
      eval_acc_pred_by_agr=FLAGS.eval_acc_pred_by_agr,
      num_neighbors_pred_by_agr=FLAGS.num_neighbors_pred_by_agr,
      lr_decay_rate_cls=FLAGS.lr_decay_rate_cls,
      lr_decay_steps_cls=FLAGS.lr_decay_steps_cls,
      lr_decay_rate_agr=FLAGS.lr_decay_rate_agr,
      lr_decay_steps_agr=FLAGS.lr_decay_steps_agr,
      load_from_checkpoint=FLAGS.load_from_checkpoint,
      use_graph=FLAGS.use_graph,
      always_agree=FLAGS.always_agree,
      add_negative_edges_agr=FLAGS.add_negative_edges_agr)

  ############################################################################
  #                                  TRAIN                                   #
  ############################################################################
  trainer.train(data)