예제 #1
0
def main():
    """Smoke-test the training loop: run GLAD and GCE for a couple of
    epochs on a small slice of the dev split, then evaluate.

    Intended as a quick sanity check, not a full training run: each epoch
    is cut off after a few batches.
    """
    args = get_args()
    # BUG FIX: the epoch loop below reads ``args.epochs``; the original
    # assigned ``args.epoch`` (singular), so the override was silently
    # ignored and the default epoch count was used instead.
    args.epochs = 2
    args.batch_size = 32
    datasets, ontology, vocab, E = load_dataset()
    ptrs, seed_ptrs, num_turns = datasets["dev"].get_turn_ptrs(
        1, 10, sample_mode="singlepass")
    for model in [GLAD(args, ontology, vocab), GCE(args, ontology, vocab)]:
        if model.optimizer is None:
            model.set_optimizer()
        iteration = 0
        for epoch in range(args.epochs):
            logging.info('starting epoch {}'.format(epoch))
            for batch, batch_labels in datasets["dev"].batch(
                    batch_size=args.batch_size, ptrs=ptrs, shuffle=True):
                iteration += 1
                model.zero_grad()
                loss, scores = model.forward(batch,
                                             batch_labels,
                                             training=True)
                loss.backward()
                model.optimizer.step()
                # Early cut-off: only a few batches per epoch are needed to
                # verify the forward/backward/step cycle works.
                if iteration > 2:
                    iteration = 0
                    break
        dev_out = model.run_eval(datasets["dev"], args)
예제 #2
0
def main(cmd=None, stdout=True):
    """Run finetuning experiment for fixed seed.

    Args:
        cmd: Optional argument list forwarded to ``get_args``; ``None``
            means parse from the process command line.
        stdout: When True, mirror the file log to standard output.

    Raises:
        ValueError: If ``args.model`` is not a known architecture, or if
            no seed checkpoint can be loaded.
    """

    # Initialize system (requires a CUDA-capable device).
    args = get_args(cmd)
    assert torch.cuda.is_available()
    torch.cuda.set_device(args.device)

    # Initialize logging: one log file per configuration, named after the
    # hyperparameters so runs are distinguishable.
    model_id = ("Finetune {}, seed size {}, epochs {}, labels {}, "
                "batch size {}, lr {}").format(args.model, args.seed_size,
                                               args.epochs, args.label_budget,
                                               args.batch_size, args.lr)
    logging.basicConfig(filename="{}/{}.txt".format(args.dout, model_id),
                        format='%(asctime)s %(levelname)-8s %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S',
                        level=logging.INFO)
    if stdout:
        logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logger = Experiment(comet_ml_key, project_name="ActiveDialogue")
    logger.set_name(model_id)
    logger.log_parameters(vars(args))

    # Select model and environment
    if args.model == "glad":
        model_arch = GLAD
    elif args.model == "gce":
        model_arch = GCE
    else:
        # BUG FIX: an unknown model previously fell through and raised an
        # opaque NameError on the next line; fail fast with a clear message.
        raise ValueError("Unknown model: {}".format(args.model))
    env = DSTEnv(load_dataset, model_arch, args)

    # Load seed if need-be; finetuning requires a trained seed checkpoint.
    if not env.load('seed'):
        raise ValueError("No loaded seed.")

    # Initialize evaluation: record the pre-finetuning metrics at step 0.
    best_metrics = env.metrics(True)
    for k, v in best_metrics.items():
        logger.log_metric(k, v, step=0)
    logging.info("Initial metrics: %s", best_metrics)

    # Finetune: label the entire pool, then fit epoch-by-epoch and track
    # the best value of the stopping metric.
    env.label_all()
    for epoch in range(1, args.epochs + 1):
        logging.info('Starting fit epoch %d.', epoch)
        env.fit()
        metrics = env.metrics(True)
        logging.info("Epoch metrics: %s", metrics)
        for k, v in metrics.items():
            logger.log_metric(k, v, step=epoch)
        if best_metrics is None or metrics[args.stop] > best_metrics[
                args.stop]:
            # NOTE(review): only the in-memory best is updated here; nothing
            # is persisted to disk — confirm whether an env.save call was
            # intended alongside this message.
            logging.info("Saving best!")
            best_metrics = metrics
예제 #3
0
def main(cmd=None, stdout=True):
    """Run an active-learning loop over PartialEnv with the configured
    labeling strategy, then do a final fit.

    Args:
        cmd: Optional argument list forwarded to ``get_args``.
        stdout: Forwarded to ``get_args`` (logging is always mirrored to
            stdout below).

    Raises:
        ValueError: If ``args.model`` or ``args.strategy`` is unknown.
    """
    args = get_args(cmd, stdout)

    # Unique run identifier encoding all hyperparameters; used for log and
    # checkpoint filenames.
    model_id = "seed_{}_strat_{}_noise_fn_{}_noise_fp_{}_num_passes_{}_seed_size_{}_model_{}_batch_size_{}_gamma_{}_label_budget_{}_epochs_{}".format(
        args.seed, args.strategy, args.noise_fn, args.noise_fp, args.num_passes, args.seed_size, args.model, args.batch_size, args.gamma, args.label_budget, args.epochs)

    logging.basicConfig(
        filename="{}/{}.txt".format(args.dout, model_id),
        format='%(asctime)s %(levelname)-8s %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO)
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))

    logger = Experiment(comet_ml_key, project_name="ActiveDialogue")
    logger.log_parameters(vars(args))

    if args.model == "glad":
        model_arch = GLAD
    elif args.model == "gce":
        model_arch = GCE
    else:
        # BUG FIX: an unknown model previously fell through and raised an
        # opaque NameError when model_arch was used; fail fast instead.
        raise ValueError("Unknown model: {}".format(args.model))

    env = PartialEnv(load_dataset, model_arch, args)
    # Train (or load) the seed model before the active-learning loop.
    if args.seed_size:
        with logger.train():
            if not env.load('seed'):
                logging.info("No loaded seed. Training now.")
                env.seed_fit(args.seed_epochs, prefix="seed")
                logging.info("Seed completed.")
            else:
                logging.info("Loaded seed.")
                if args.force_seed:
                    logging.info("Training seed regardless.")
                    env.seed_fit(args.seed_epochs, prefix="seed")
        env.load('seed')

    # Threshold-based strategies ("entropy", "bald") are wrapped below;
    # the remaining strategies are dispatched inline in the loop.
    use_strategy = False
    if args.strategy == "entropy":
        use_strategy = True
        strategy = partial_entropy
    elif args.strategy == "bald":
        use_strategy = True
        strategy = partial_bald

    if use_strategy:
        if args.threshold_strategy == "fixed":
            strategy = FixedThresholdStrategy(strategy, args, True)
        elif args.threshold_strategy == "variable":
            strategy = VariableThresholdStrategy(strategy, args, True)
        elif args.threshold_strategy == "randomvariable":
            strategy = StochasticVariableThresholdStrategy(
                strategy, args, True)

    ended = False
    i = 0

    # Log pre-loop metrics at epoch 0.
    initial_metrics = env.metrics(True)
    logger.log_current_epoch(i)
    logging.info("Initial metrics: {}".format(initial_metrics))
    for k, v in initial_metrics.items():
        logger.log_metric(k, v)

    with logger.train():
        while not ended:
            i += 1

            # Observe environment state
            logger.log_current_epoch(i)

            if env.can_label:
                # Obtain label request from strategy. BALD needs multiple
                # stochastic forward passes; the others need only one.
                obs, preds = env.observe(20 if args.strategy ==
                                         "bald" else 1)
                if args.strategy != "bald":
                    preds = preds[0]
                if args.strategy == "aggressive":
                    label_request = aggressive(preds)
                elif args.strategy == "random":
                    label_request = random(preds)
                elif args.strategy == "passive":
                    label_request = passive(preds)
                elif use_strategy:
                    label_request = strategy.observe(preds)
                else:
                    raise ValueError()

                # Label solicitation
                labeled = env.label(label_request)
                if use_strategy:
                    # Feed back (requested count, total count) so threshold
                    # strategies can adapt their request rate.
                    strategy.update(
                        sum([
                            np.sum(s.flatten())
                            for s in label_request.values()
                        ]),
                        sum([
                            np.sum(np.ones_like(s).flatten())
                            for s in label_request.values()
                        ]))
            else:
                break

            # Environment stepping
            ended = env.step()
            # Fit every al_batch of items, then reload the best checkpoint.
            best = env.fit(prefix=model_id, reset_model=True)
            for k, v in best.items():
                logger.log_metric(k, v)
            env.load(prefix=model_id)

    # Final fit on everything labeled so far.
    final_metrics = env.fit(epochs=args.final_epochs,
                            prefix="final_fit_" + model_id,
                            reset_model=True)
    for k, v in final_metrics.items():
        logger.log_metric("Final " + k, v)
        logging.info("Final " + k + ": " + str(v))
    logging.info("Run finished.")
예제 #4
0
def main():
    """Step-by-step debugging walkthrough of PartialEnv with a
    fixed-threshold entropy strategy, logging every intermediate value.
    """
    args = get_args()

    # Tiny configuration so every phase (observe/label/step/fit/eval)
    # executes quickly.
    args.seed_size = 1
    args.label_budget = 2
    args.num_passes = 1
    args.al_batch = 2
    args.fit_items = 3
    args.batch_size = 2
    args.comp_batch_size = 2
    args.eval_period = 1
    args.recency_bias = 0
    args.seed = 911
    args.seed_epochs = 3
    args.epochs = 1

    env = PartialEnv(load_dataset, GLAD, args)
    strategy = FixedThresholdStrategy(partial_entropy, args, True)
    logging.info("Seed indices")
    logging.info(env._support_ptrs)
    logging.info("Stream indices")
    logging.info(env._ptrs)
    logging.info("\n")
    ended = False
    i = 0
    while not ended:
        i += 1
        logging.info("Environment observation now.")
        raw_obs, preds = env.observe(1)
        pred = preds[0]
        # BUG FIX: these calls passed the value as a positional logging arg
        # with no %s in the message, which makes the logging module raise a
        # formatting error instead of printing the value. Use lazy %s
        # formatting throughout.
        logging.info("Current idx %s", env._current_idx)
        logging.info("Current ptrs %s", env.current_ptrs)
        logging.info("Raw observation:")
        logging.info([d.to_dict() for d in raw_obs])
        logging.info("pred:")
        logging.info(pred)
        true_labels = env.leak_labels()
        logging.info("True labels:")
        logging.info(true_labels)

        logging.info("\n")
        requested_label = strategy.observe(pred)
        logging.info("Requested label: %s", requested_label)
        logging.info("Environment label request now.")

        if env.can_label:
            label_success = env.label(requested_label)
            logging.info("Label success: %s", label_success)
            logging.info("Support ptrs: %s", env._support_ptrs)

        logging.info("\n")
        logging.info("Environment stepping now.")
        ended = env.step()
        logging.info("Ended: %s", ended)

        logging.info("\n")
        logging.info("Fitting environment now.")
        env.fit()
        logging.info("Reporting metrics now.")
        for k, v in env.metrics(i % args.eval_period == 0).items():
            logging.info("\t{}: {}".format(k, v))

        logging.info("\n")
        logging.info("\n")
        logging.info("\n")
예제 #5
0
def main():
    """Run an active-learning loop over BagEnv with the configured
    labeling strategy, then do a long final fit.

    Raises:
        ValueError: If ``args.model`` is unknown, or the configured
            strategy cannot produce a label request.
    """
    args = get_args()
    logger = Experiment(comet_ml_key, project_name="ActiveDialogue")
    logger.log_parameters(vars(args))

    if args.model == "glad":
        model_arch = GLAD
    elif args.model == "gce":
        model_arch = GCE
    else:
        # BUG FIX: an unknown model previously fell through and raised an
        # opaque NameError when model_arch was used; fail fast instead.
        raise ValueError("Unknown model: {}".format(args.model))

    env = BagEnv(load_dataset, model_arch, args, logger)
    # Train (or load) the seed model before the active-learning loop.
    if args.seed_size:
        with logger.train():
            if not env.load_seed():
                logging.debug("No loaded seed. Training now.")
                env.seed_fit(args.seed_epochs, prefix="seed")
                logging.debug("Seed completed.")
            else:
                logging.debug("Loaded seed.")
                if args.force_seed:
                    logging.debug("Training seed regardless.")
                    env.seed_fit(args.seed_epochs, prefix="seed")
        env.load_seed()
        logging.debug("Current seed metrics: {}".format(env.metrics(True)))

    # Threshold-based strategies ("lc", "bald") are wrapped below; the
    # remaining strategies are dispatched inline in the loop.
    use_strategy = False
    if args.strategy == "lc":
        use_strategy = True
        strategy = lc_singlet
    elif args.strategy == "bald":
        use_strategy = True
        strategy = bald_singlet

    if use_strategy:
        if args.threshold_strategy == "fixed":
            strategy = FixedThresholdStrategy(strategy, args)
        elif args.threshold_strategy == "variable":
            strategy = VariableThresholdStrategy(strategy, args)
        elif args.threshold_strategy == "randomvariable":
            strategy = StochasticVariableThresholdStrategy(strategy, args)

    ended = False
    i = 0
    while not ended:
        i += 1

        # Observe environment state
        logger.log_current_epoch(i)

        # Allow up to label_timeout labeling attempts per step.
        for _ in range(args.label_timeout):
            if env.can_label:
                # Obtain label request from strategy
                obs, preds = env.observe()
                if args.strategy == "epsiloncheat":
                    label_request = epsilon_cheat(preds, env.leak_labels())
                elif args.strategy == "randomsinglets":
                    label_request = random_singlets(preds)
                elif args.strategy == "passive":
                    label_request = passive(preds)
                elif use_strategy:
                    label_request = strategy.observe(preds)
                else:
                    raise ValueError()

                # Label solicitation
                labeled = env.label(label_request)
                if use_strategy:
                    # Feed back (requested count, total count) so threshold
                    # strategies can adapt their request rate.
                    strategy.update(
                        np.sum(label_request.flatten()),
                        np.sum(np.ones_like(label_request.flatten())))

        # Environment stepping
        ended = env.step()
        # Fit every al_batch of items
        env.fit()

    # BUG FIX: the original passed the result as a positional logging arg
    # with no %s in the message, which makes the logging module raise a
    # formatting error instead of printing the value.
    logging.debug("Final fit: %s", env.seed_fit(100, "final_fit", True))