def main(argv):
    """Trains a GAM co-training model on a Planetoid graph dataset.

    Loads the dataset named by the flags, assembles a descriptive model name
    from the hyperparameter flags, prepares the output directories, builds the
    classification and agreement models, and runs co-training to completion.

    Args:
        argv: Positional command-line arguments; nothing is expected beyond
            the program name.

    Raises:
        app.UsageError: If extra positional command-line arguments are given.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    if FLAGS.logging_config:
        # Deliberately `print`, not `logging`: the logging configuration is
        # only installed on the next line.
        print('Setting logging configuration: ', FLAGS.logging_config)
        config.fileConfig(FLAGS.logging_config)

    # Seed both numpy and TensorFlow for reproducibility.
    np.random.seed(FLAGS.seed)
    tf.set_random_seed(FLAGS.seed)

    ############################################################################
    #                               DATA                                       #
    ############################################################################
    # Load data.
    data = load_data_planetoid(name=FLAGS.dataset_name,
                               path=FLAGS.data_path,
                               row_normalize=FLAGS.row_normalize)

    # Potentially add noisy edges. This can be used to assess the robustness of
    # GAM to noisy edges. See `Robustness` section of our paper.
    if FLAGS.target_ratio_correct:
        data = add_noisy_edges(data, FLAGS.target_ratio_correct)

    ############################################################################
    #                            PREPARE OUTPUTS                               #
    ############################################################################
    # Put together parameters to create a model name. The name encodes every
    # hyperparameter that affects the run so that distinct configurations get
    # distinct checkpoint/summary directories.
    model_name = FLAGS.model_cls
    model_name += ('_' + FLAGS.hidden_cls) if FLAGS.model_cls == 'mlp' else ''
    model_name += '-' + FLAGS.model_agr
    model_name += ('_' + FLAGS.hidden_agr) if FLAGS.model_agr == 'mlp' else ''
    model_name += '-aggr_' + FLAGS.aggregation_agr_inputs
    model_name += ('_' + FLAGS.hidden_aggreg) if FLAGS.hidden_aggreg else ''
    model_name += (
        '-add_%d-conf_%.2f-iterCls_%d-iterAgr_%d-batchCls_%d' %
        (FLAGS.num_samples_to_label, FLAGS.min_confidence_new_label,
         FLAGS.max_num_iter_cls, FLAGS.max_num_iter_agr, FLAGS.batch_size_cls))
    model_name += (('-wdecayCls_%.4f' %
                    FLAGS.weight_decay_cls) if FLAGS.weight_decay_cls else '')
    model_name += (('-wdecayAgr_%.4f' %
                    FLAGS.weight_decay_agr) if FLAGS.weight_decay_agr else '')
    model_name += '-LL_%s_LU_%s_UU_%s' % (str(
        FLAGS.reg_weight_ll), str(FLAGS.reg_weight_lu), str(
            FLAGS.reg_weight_uu))
    model_name += '-perfAgr' if FLAGS.use_perfect_agreement else ''
    model_name += '-perfCls' if FLAGS.use_perfect_classifier else ''
    model_name += '-keepProp' if FLAGS.keep_label_proportions else ''
    model_name += '-PenNegAgr' if FLAGS.penalize_neg_agr else ''
    model_name += '-VAT' if FLAGS.reg_weight_vat > 0 else ''
    model_name += 'ENT' if FLAGS.reg_weight_vat > 0 and FLAGS.use_ent_min else ''
    model_name += '-transd' if not FLAGS.inductive else ''
    model_name += '-L2' if FLAGS.use_l2_cls else '-CE'
    model_name += '-graph' if FLAGS.use_graph else '-noGraph'
    model_name += '-rowNorm' if FLAGS.row_normalize else ''
    model_name += '-seed_' + str(FLAGS.seed)
    model_name += FLAGS.experiment_suffix
    logging.info('Model name: %s', model_name)

    # Create directories for model checkpoints, summaries, and
    # self-labeled data backup.
    summary_dir = os.path.join(FLAGS.output_dir, 'summaries',
                               FLAGS.dataset_name, model_name)
    checkpoints_dir = os.path.join(FLAGS.output_dir, 'checkpoints',
                                   FLAGS.dataset_name, model_name)
    data_dir = os.path.join(FLAGS.data_output_dir, 'data_checkpoints',
                            FLAGS.dataset_name, model_name)
    # exist_ok=True avoids the check-then-create race of the former
    # `if not os.path.exists(...): os.makedirs(...)` pattern when several jobs
    # share an output root.
    # NOTE(review): summary_dir is not created here; presumably the summary
    # writer creates it — confirm.
    os.makedirs(checkpoints_dir, exist_ok=True)
    os.makedirs(data_dir, exist_ok=True)

    ############################################################################
    #                            MODEL SETUP                                   #
    ############################################################################
    # Create classification model.
    model_cls = get_model_cls(model_name=FLAGS.model_cls,
                              data=data,
                              dataset_name=FLAGS.dataset_name,
                              hidden=FLAGS.hidden_cls)

    # Create agreement model.
    model_agr = get_model_agr(
        model_name=FLAGS.model_agr,
        dataset_name=FLAGS.dataset_name,
        hidden_aggreg=FLAGS.hidden_aggreg,
        aggregation_agr_inputs=FLAGS.aggregation_agr_inputs,
        hidden=FLAGS.hidden_agr)

    # Build the co-training trainer that alternates between the classifier
    # and the agreement model, self-labeling samples between iterations.
    trainer = TrainerCotraining(
        model_cls=model_cls,
        model_agr=model_agr,
        max_num_iter_cotrain=FLAGS.max_num_iter_cotrain,
        min_num_iter_cls=FLAGS.min_num_iter_cls,
        max_num_iter_cls=FLAGS.max_num_iter_cls,
        num_iter_after_best_val_cls=FLAGS.num_iter_after_best_val_cls,
        min_num_iter_agr=FLAGS.min_num_iter_agr,
        max_num_iter_agr=FLAGS.max_num_iter_agr,
        num_iter_after_best_val_agr=FLAGS.num_iter_after_best_val_agr,
        num_samples_to_label=FLAGS.num_samples_to_label,
        min_confidence_new_label=FLAGS.min_confidence_new_label,
        keep_label_proportions=FLAGS.keep_label_proportions,
        num_warm_up_iter_agr=FLAGS.num_warm_up_iter_agr,
        optimizer=tf.train.AdamOptimizer,
        gradient_clip=FLAGS.gradient_clip,
        batch_size_agr=FLAGS.batch_size_agr,
        batch_size_cls=FLAGS.batch_size_cls,
        learning_rate_cls=FLAGS.learning_rate_cls,
        learning_rate_agr=FLAGS.learning_rate_agr,
        enable_summaries=True,
        enable_summaries_per_model=True,
        summary_dir=summary_dir,
        summary_step_cls=FLAGS.summary_step_cls,
        summary_step_agr=FLAGS.summary_step_agr,
        logging_step_cls=FLAGS.logging_step_cls,
        logging_step_agr=FLAGS.logging_step_agr,
        eval_step_cls=FLAGS.eval_step_cls,
        eval_step_agr=FLAGS.eval_step_agr,
        checkpoints_dir=checkpoints_dir,
        checkpoints_step=1,
        data_dir=data_dir,
        abs_loss_chg_tol=1e-10,
        rel_loss_chg_tol=1e-7,
        loss_chg_iter_below_tol=30,
        use_perfect_agr=FLAGS.use_perfect_agreement,
        use_perfect_cls=FLAGS.use_perfect_classifier,
        warm_start_cls=FLAGS.warm_start_cls,
        warm_start_agr=FLAGS.warm_start_agr,
        ratio_valid_agr=FLAGS.ratio_valid_agr,
        max_samples_valid_agr=FLAGS.max_samples_valid_agr,
        weight_decay_cls=FLAGS.weight_decay_cls,
        weight_decay_schedule_cls=FLAGS.weight_decay_schedule_cls,
        weight_decay_schedule_agr=FLAGS.weight_decay_schedule_agr,
        weight_decay_agr=FLAGS.weight_decay_agr,
        reg_weight_ll=FLAGS.reg_weight_ll,
        reg_weight_lu=FLAGS.reg_weight_lu,
        reg_weight_uu=FLAGS.reg_weight_uu,
        num_pairs_reg=FLAGS.num_pairs_reg,
        reg_weight_vat=FLAGS.reg_weight_vat,
        use_ent_min=FLAGS.use_ent_min,
        penalize_neg_agr=FLAGS.penalize_neg_agr,
        use_l2_cls=FLAGS.use_l2_cls,
        first_iter_original=FLAGS.first_iter_original,
        inductive=FLAGS.inductive,
        seed=FLAGS.seed,
        eval_acc_pred_by_agr=FLAGS.eval_acc_pred_by_agr,
        num_neighbors_pred_by_agr=FLAGS.num_neighbors_pred_by_agr,
        lr_decay_rate_cls=FLAGS.lr_decay_rate_cls,
        lr_decay_steps_cls=FLAGS.lr_decay_steps_cls,
        lr_decay_rate_agr=FLAGS.lr_decay_rate_agr,
        lr_decay_steps_agr=FLAGS.lr_decay_steps_agr,
        load_from_checkpoint=FLAGS.load_from_checkpoint,
        use_graph=FLAGS.use_graph,
        always_agree=FLAGS.always_agree,
        add_negative_edges_agr=FLAGS.add_negative_edges_agr)

    ############################################################################
    #                            TRAIN                                         #
    ############################################################################
    trainer.train(data)
# Example 2
def main(argv):
    """Trains a GAM co-training model on a (possibly preprocessed) dataset.

    Loads the data (from a preprocessed pickle when requested, otherwise from
    scratch, optionally saving the preprocessed result), assembles a
    descriptive model name from the hyperparameter flags, prepares output
    directories, builds the classification and agreement models, and runs
    co-training to completion.

    Args:
        argv: Positional command-line arguments; nothing is expected beyond
            the program name.

    Raises:
        app.UsageError: If extra positional command-line arguments are given.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    if FLAGS.logging_config:
        # Deliberately `print`, not `logging`: the logging configuration is
        # only installed on the next line.
        print('Setting logging configuration: ', FLAGS.logging_config)
        config.fileConfig(FLAGS.logging_config)

    # Seed both numpy and TensorFlow for reproducibility.
    np.random.seed(FLAGS.seed)
    tf.set_random_seed(FLAGS.seed)

    ############################################################################
    #                               DATA                                       #
    ############################################################################
    # Potentially create a folder where to save the preprocessed data.
    # exist_ok=True avoids the check-then-create race of the former
    # `if not os.path.exists(...): os.makedirs(...)` pattern.
    os.makedirs(FLAGS.data_output_dir, exist_ok=True)

    # Load and potentially preprocess data.
    if FLAGS.load_preprocessed:
        logging.info('Loading preprocessed data...')
        path = os.path.join(FLAGS.data_output_dir,
                            FLAGS.filename_preprocessed_data)
        data = Dataset.load_from_pickle(path)
    else:
        data = load_data()
        if FLAGS.save_preprocessed:
            # NOTE(review): this asserts `output_dir` but the path below is
            # built from `data_output_dir` — confirm which flag is intended.
            # Also, `assert` is stripped under `python -O`.
            assert FLAGS.output_dir
            path = os.path.join(FLAGS.data_output_dir,
                                FLAGS.filename_preprocessed_data)
            data.save_to_pickle(path)
            logging.info('Preprocessed data saved to %s.', path)

    ############################################################################
    #                            PREPARE OUTPUTS                               #
    ############################################################################
    # Put together parameters to create a model name. The name encodes every
    # hyperparameter that affects the run so that distinct configurations get
    # distinct checkpoint/summary directories.
    model_name = FLAGS.model_cls
    model_name += ('_' + FLAGS.hidden_cls) if FLAGS.model_cls == 'mlp' else ''
    model_name += '-' + FLAGS.model_agr
    model_name += ('_' + FLAGS.hidden_agr) if FLAGS.model_agr == 'mlp' else ''
    model_name += '-aggr_' + FLAGS.aggregation_agr_inputs
    model_name += ('_' + FLAGS.hidden_aggreg) if FLAGS.hidden_aggreg else ''
    model_name += (
        '-add_%d-conf_%.2f-iterCls_%d-iterAgr_%d-batchCls_%d' %
        (FLAGS.num_samples_to_label, FLAGS.min_confidence_new_label,
         FLAGS.max_num_iter_cls, FLAGS.max_num_iter_agr, FLAGS.batch_size_cls))
    model_name += (('-wdecayCls_%.4f' %
                    FLAGS.weight_decay_cls) if FLAGS.weight_decay_cls else '')
    model_name += (('-wdecayAgr_%.4f' %
                    FLAGS.weight_decay_agr) if FLAGS.weight_decay_agr else '')
    model_name += '-LL_%s_LU_%s_UU_%s' % (str(
        FLAGS.reg_weight_ll), str(FLAGS.reg_weight_lu), str(
            FLAGS.reg_weight_uu))
    model_name += '-perfAgr' if FLAGS.use_perfect_agreement else ''
    model_name += '-perfCls' if FLAGS.use_perfect_classifier else ''
    model_name += '-keepProp' if FLAGS.keep_label_proportions else ''
    model_name += '-PenNegAgr' if FLAGS.penalize_neg_agr else ''
    model_name += '-transd' if not FLAGS.inductive else ''
    model_name += '-L2' if FLAGS.use_l2_cls else '-CE'
    model_name += '-seed_' + str(FLAGS.seed)
    model_name += FLAGS.experiment_suffix
    logging.info('Model name: %s', model_name)

    # Create directories for model checkpoints, summaries, and
    # self-labeled data backup.
    summary_dir = os.path.join(FLAGS.output_dir, 'summaries',
                               FLAGS.dataset_name, model_name)
    checkpoints_dir = os.path.join(FLAGS.output_dir, 'checkpoints',
                                   FLAGS.dataset_name, model_name)
    data_dir = os.path.join(FLAGS.data_output_dir, 'data_checkpoints',
                            FLAGS.dataset_name, model_name)
    # NOTE(review): summary_dir is not created here; presumably the summary
    # writer creates it — confirm.
    os.makedirs(checkpoints_dir, exist_ok=True)
    os.makedirs(data_dir, exist_ok=True)

    ############################################################################
    #                            MODEL SETUP                                   #
    ############################################################################
    # Select the model based on the provided FLAGS.
    model_cls = get_model_cls(model_name=FLAGS.model_cls,
                              data=data,
                              dataset_name=FLAGS.dataset_name,
                              hidden=FLAGS.hidden_cls)

    # Create agreement model.
    model_agr = get_model_agr(
        model_name=FLAGS.model_agr,
        dataset_name=FLAGS.dataset_name,
        hidden_aggreg=FLAGS.hidden_aggreg,
        aggregation_agr_inputs=FLAGS.aggregation_agr_inputs,
        hidden=FLAGS.hidden_agr)

    # Build the co-training trainer that alternates between the classifier
    # and the agreement model, self-labeling samples between iterations.
    trainer = TrainerCotraining(
        model_cls=model_cls,
        model_agr=model_agr,
        max_num_iter_cotrain=FLAGS.max_num_iter_cotrain,
        min_num_iter_cls=FLAGS.min_num_iter_cls,
        max_num_iter_cls=FLAGS.max_num_iter_cls,
        num_iter_after_best_val_cls=FLAGS.num_iter_after_best_val_cls,
        min_num_iter_agr=FLAGS.min_num_iter_agr,
        max_num_iter_agr=FLAGS.max_num_iter_agr,
        num_iter_after_best_val_agr=FLAGS.num_iter_after_best_val_agr,
        num_samples_to_label=FLAGS.num_samples_to_label,
        min_confidence_new_label=FLAGS.min_confidence_new_label,
        keep_label_proportions=FLAGS.keep_label_proportions,
        num_warm_up_iter_agr=FLAGS.num_warm_up_iter_agr,
        optimizer=tf.train.AdamOptimizer,
        gradient_clip=FLAGS.gradient_clip,
        batch_size_agr=FLAGS.batch_size_agr,
        batch_size_cls=FLAGS.batch_size_cls,
        learning_rate_cls=FLAGS.learning_rate_cls,
        learning_rate_agr=FLAGS.learning_rate_agr,
        enable_summaries=True,
        enable_summaries_per_model=True,
        summary_dir=summary_dir,
        summary_step_cls=FLAGS.summary_step_cls,
        summary_step_agr=FLAGS.summary_step_agr,
        logging_step_cls=FLAGS.logging_step_cls,
        logging_step_agr=FLAGS.logging_step_agr,
        eval_step_cls=FLAGS.eval_step_cls,
        eval_step_agr=FLAGS.eval_step_agr,
        checkpoints_dir=checkpoints_dir,
        checkpoints_step=1,
        data_dir=data_dir,
        abs_loss_chg_tol=1e-10,
        rel_loss_chg_tol=1e-7,
        loss_chg_iter_below_tol=30,
        use_perfect_agr=FLAGS.use_perfect_agreement,
        use_perfect_cls=FLAGS.use_perfect_classifier,
        warm_start_cls=FLAGS.warm_start_cls,
        warm_start_agr=FLAGS.warm_start_agr,
        ratio_valid_agr=FLAGS.ratio_valid_agr,
        max_samples_valid_agr=FLAGS.max_samples_valid_agr,
        weight_decay_cls=FLAGS.weight_decay_cls,
        weight_decay_schedule_cls=FLAGS.weight_decay_schedule_cls,
        weight_decay_schedule_agr=FLAGS.weight_decay_schedule_agr,
        weight_decay_agr=FLAGS.weight_decay_agr,
        reg_weight_ll=FLAGS.reg_weight_ll,
        reg_weight_lu=FLAGS.reg_weight_lu,
        reg_weight_uu=FLAGS.reg_weight_uu,
        reg_weight_vat=FLAGS.reg_weight_vat,
        use_ent_min=FLAGS.use_ent_min,
        num_pairs_reg=FLAGS.num_pairs_reg,
        penalize_neg_agr=FLAGS.penalize_neg_agr,
        use_l2_cls=FLAGS.use_l2_cls,
        first_iter_original=FLAGS.first_iter_original,
        inductive=FLAGS.inductive,
        seed=FLAGS.seed,
        eval_acc_pred_by_agr=FLAGS.eval_acc_pred_by_agr,
        num_neighbors_pred_by_agr=FLAGS.num_neighbors_pred_by_agr,
        lr_decay_rate_cls=FLAGS.lr_decay_rate_cls,
        lr_decay_steps_cls=FLAGS.lr_decay_steps_cls,
        lr_decay_rate_agr=FLAGS.lr_decay_rate_agr,
        lr_decay_steps_agr=FLAGS.lr_decay_steps_agr,
        load_from_checkpoint=FLAGS.load_from_checkpoint)

    ############################################################################
    #                            TRAIN                                         #
    ############################################################################
    trainer.train(data)