Code example #1
0
File: classifier.py — Project: yyht/daga
def main():
    """Entry point: set up run-scoped logging and a GPU context, then predict."""
    args = utils.build_predict_args(argparse.ArgumentParser())

    # Mirror log output to <output_file>.log for the duration of this run.
    handler = add_file_handler(
        logger, Path(args.output_file).with_suffix(".log")
    )

    utils.init_random(args.seed)

    # Bind both torch and flair to the selected GPU before predicting.
    with torch.cuda.device(args.gpuid):
        flair.device = torch.device(f"cuda:{args.gpuid}")
        predict(args)

    logger.removeHandler(handler)
Code example #2
0
File: eval_wrn.py — Project: yyht/HDGE
def sample_p_0(replay_buffer, bs, y=None):
    """Draw `bs` starting samples: fresh noise when the buffer is empty,
    otherwise buffer entries, each re-initialized with prob. reinit_freq."""
    if len(replay_buffer) == 0:
        # Nothing stored yet — every chain starts from pure noise.
        return init_random(bs), []

    # For class-conditional buffers each class owns a contiguous slice of
    # size len(buffer) // n_classes; unconditional buffers use the full range.
    if y is None:
        n_slots = len(replay_buffer)
        inds = torch.randint(0, n_slots, (bs,))
    else:
        n_slots = len(replay_buffer) // args.n_classes
        inds = y.cpu() * n_slots + torch.randint(0, n_slots, (bs,))
        assert not args.uncond, "Can't drawn conditional samples without giving me y"

    from_buffer = replay_buffer[inds]
    fresh = init_random(bs)
    # Per-sample Bernoulli mask, broadcast over (C, H, W).
    mask = (torch.rand(bs) < args.reinit_freq).float()[:, None, None, None]
    mixed = mask * fresh + (1 - mask) * from_buffer
    return mixed.to(device), inds
Code example #3
0
File: train_wrn.py — Project: yyht/HDGE
def get_model_and_buffer(args, sample_q):
    """Build the energy model and its replay buffer, optionally restoring
    both from the checkpoint at args.load_path."""
    use_contrast = (
        args.pxycontrast > 0 or args.pxyjointcontrast > 0 or args.pxcontrast > 0
    )
    if use_contrast:
        # Contrastive objectives use the hybrid model (extra K/T knobs).
        f = HYM(
            args.depth,
            args.width,
            args.norm,
            dropout_rate=args.dropout_rate,
            n_classes=args.n_classes,
            K=args.contrast_k,
            T=args.contrast_t,
        )
    else:
        model_cls = F if args.uncond else CCF
        f = model_cls(
            args.depth,
            args.width,
            args.norm,
            dropout_rate=args.dropout_rate,
            n_classes=args.n_classes,
        )

    if not args.uncond:
        # Conditional buffers are partitioned per class, so sizes must align.
        assert args.buffer_size % args.n_classes == 0, "Buffer size must be divisible by args.n_classes"

    if args.load_path is None:
        # Fresh buffer of random samples.
        replay_buffer = init_random(args, args.buffer_size)
    else:
        print(f"loading model from {args.load_path}")
        ckpt_dict = torch.load(args.load_path)
        f.load_state_dict(ckpt_dict["model_state_dict"])
        replay_buffer = ckpt_dict["replay_buffer"]

    f = f.to(device)
    return f, replay_buffer
Code example #4
0
File: train_classifier.py — Project: yyht/daga
def main():
    """Entry point: configure run-scoped logging and GPU, then train."""
    args = utils.build_train_args(argparse.ArgumentParser())

    # Mirror log output into the model directory for this run only.
    handler = add_file_handler(logger, Path(args.model_dir) / "training.log")
    logger.info(f"Args: {pformat(vars(args))}")

    utils.init_random(args.seed)

    # Bind both torch and flair to the selected GPU before training.
    with torch.cuda.device(args.gpuid):
        flair.device = torch.device(f"cuda:{args.gpuid}")
        train(args)

    logger.removeHandler(handler)
Code example #5
0
File: train.py — Project: yyht/daga
def main():
    """Main workflow"""
    # Parse CLI arguments for language-model training.
    args = utils.build_args(argparse.ArgumentParser())

    utils.init_logger(args.model_file)

    # Training requires a GPU; pin all work to the requested device.
    assert torch.cuda.is_available()
    torch.cuda.set_device(args.gpuid)

    utils.init_random(args.seed)

    utils.set_params(args)
    logger.info("Config:\n%s", pformat(vars(args)))

    # Build the dataset fields shared by all splits.
    fields = utils.build_fields()
    logger.info("Fields: %s", fields.keys())

    logger.info("Load %s", args.train_file)
    train_data = LMDataset(fields, args.train_file, args.sent_length_trunc)
    logger.info("Training sentences: %d", len(train_data))
    logger.info("Load %s", args.valid_file)
    val_data = LMDataset(fields, args.valid_file, args.sent_length_trunc)
    logger.info("Validation sentences: %d", len(val_data))

    # Vocabulary is built from the training split only.
    fields["sent"].build_vocab(train_data)

    train_iter = utils.build_dataset_iter(train_data, args)
    val_iter = utils.build_dataset_iter(val_data, args, train=False)

    # Optionally resume: restore early-stopping stats and override args
    # from the checkpoint; otherwise start from scratch.
    if args.resume and os.path.isfile(args.checkpoint_file):
        logger.info("Resume training")
        logger.info("Load checkpoint %s", args.checkpoint_file)
        # map_location keeps checkpoint tensors on CPU at load time.
        checkpoint = torch.load(args.checkpoint_file,
                                map_location=lambda storage, loc: storage)
        es_stats = checkpoint["es_stats"]
        args = utils.set_args(args, checkpoint)
    else:
        checkpoint = None
        es_stats = ESStatistics(args)

    # Model and optimizer both restore state when a checkpoint is given.
    model = utils.build_model(fields, args, checkpoint)
    logger.info("Model:\n%s", model)

    optimizer = utils.build_optimizer(model, args, checkpoint)

    try_train_val(fields, model, optimizer, train_iter, val_iter, es_stats,
                  args)
Code example #6
0
def main():
    """Entry point: validate column arguments, then train with file logging."""
    args = utils.build_train_args(argparse.ArgumentParser())
    # Training is meaningless without data columns; fail fast.
    if not args.data_columns:
        raise RuntimeError("Specify data_column (e.g., text ner)")

    # Mirror log output into the model directory for this run only.
    handler = add_file_handler(logger, Path(args.model_dir) / "training.log")
    logger.info(f"Args: {pformat(vars(args))}")

    utils.init_random(args.seed)

    # Bind both torch and flair to the selected GPU before training.
    with torch.cuda.device(args.gpuid):
        flair.device = torch.device(f"cuda:{args.gpuid}")
        train(args)

    logger.removeHandler(handler)
Code example #7
0
File: generate.py — Project: yyht/daga
def main():
    """Main workflow: load a trained LM and sample sentences from it.

    Reads the model file named by the generation args, draws
    ``num_sentences`` samples in batches of ``batch_size``, and writes
    them to ``args.out_file``.
    """
    args = utils.build_gen_args(argparse.ArgumentParser())

    utils.init_logger(args.out_file)
    logger.info("Config:\n%s", pformat(vars(args)))

    # Generation requires a GPU; pin all work to the requested device.
    assert torch.cuda.is_available()
    torch.cuda.set_device(args.gpuid)

    utils.init_random(args.seed)

    logger.info("Load parameters from '%s'", args.model_file)
    # map_location keeps checkpoint tensors on CPU at load time.
    params = torch.load(args.model_file, map_location=lambda storage, loc: storage)

    utils.set_params(params["args"])

    fields = utils.load_fields_from_vocab(params["vocab"])
    logger.info("Fields: %s", fields.keys())

    model = utils.build_test_model(fields, params)

    # Sample in batches; the final batch may be smaller than batch_size.
    num_batches = math.ceil(args.num_sentences / args.batch_size)
    samples = []
    with torch.no_grad():
        for i in range(num_batches):
            # Size of this batch, computed arithmetically instead of
            # materializing and slicing an index list.
            running_batch_size = min(
                args.batch_size, args.num_sentences - i * args.batch_size
            )
            samples.append(
                model.generate(
                    running_batch_size, args.max_sent_length, args.temperature
                )
            )
    samples = torch.cat(samples, 0)
    save(samples, fields, args.out_file)
Code example #8
0
            pickle.dump(x, fin)
        return 0


# Should be updated with the new way evaluation is performed or deleted
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Entry point of the application.")

    # (flag, type, required, default, help) — registered in this exact order
    # so the generated --help output is unchanged.
    for flag, typ, req, dflt, msg in [
        ("--learning-rate", float, False, 0.1,
         "Learning rate to pass to the optimizer"),
        ("--embeddings-path", str, True, None,
         "Path to the saved embeddings file"),
        ("--model-path", str, True, None, "Path to the model file"),
        ("--train-file", str, True, None,
         "Path to the file containing the data to train on"),
        ("--dev-file", str, True, None,
         "Path to the file containing the data to evaluate on"),
        ("--batch-size", int, False, 64,
         "Number of examples to consider per batch"),
        ("--num-epochs", int, False, 10, None),
        ("--hidden-size", int, False, 500, None),
        ("--weight-decay", float, False, 0.001, None),
    ]:
        parser.add_argument(flag, type=typ, required=req, default=dflt, help=msg)

    result = parser.parse_args()
    # Fixed seed so evaluation runs are reproducible.
    init_random(1)
    device = torch.device('cpu')
    # Restore the trained model and the embeddings it operates on.
    model = LSTMMultiply(300, result.hidden_size)
    model.load_state_dict(torch.load(result.model_path, map_location=device))
    sg_embeddings = torch.load(result.embeddings_path, map_location=device)
    evaluation = Evaluation(result.train_file, result.dev_file, model, sg_embeddings, 'vocab/vocab_250.json')

    print(evaluation.evaluate())
Code example #9
0
    model_config = json.load(open(config['model']))
    config = {**config, **model_config}
    result = vars(result)
    print(result)
    print(config)
    # Override based on cli arguments
    for update_param in ['learning_rate', 'batch_size', 'num_epochs', 'number_of_negative_examples', 'save_path', 'which_cuda', 'weight_decay', 'early_stopping', 'random_seed', 'pretrained_model', 'heldout_data', 'flip_right_sentence']:
        if result[update_param] is not None:
            config[update_param] = result[update_param]

    if result['hidden_size'] is not None:
        config['model']['attributes']['hidden_size'] = result['hidden_size']

    if result['model_type'] is not None:
        config['model']['name'] = result['model_type']

    if result['num_models'] is not None:
        config['model']['attributes']['num_models'] = result['num_models']
    print(config)
    print(f"Init random seed with {config['random_seed']}")
    init_random(seed=config['random_seed'])
    if result['pretrained_model'] is not None:
        print(f"Loading pretrained model from {result['pretrained_model']}")
    train_obj = MWETrain(config)
    if result['only_eval']:
        print("Only evaluate the pretrained model")
        train_obj.eval()
    else:
        print("Start training")
        train_obj.train()
Code example #10
0
File: test.py — Project: yyht/daga
def main():
    """Main workflow"""
    # Parse test-time CLI arguments.
    args = utils.build_test_args(argparse.ArgumentParser())

    # Log-file suffix records whether importance-weighted NLL is reported.
    suff = ".test"
    if args.report_iw_nll:
        # When more samples than one IW batch are requested, the total must
        # split evenly into iw_batch_size chunks (see the loop setup below).
        if (
            args.num_iw_samples > args.iw_batch_size
            and args.num_iw_samples % args.iw_batch_size != 0
        ):
            raise RuntimeError("Expected num_iw_samples divisible by iw_batch_size")
        suff = ".test.iw" + str(args.num_iw_samples)

    utils.init_logger(args.model_file + suff)
    logger.info("Config:\n%s", pformat(vars(args)))

    # Testing requires a GPU; pin all work to the requested device.
    assert torch.cuda.is_available()
    torch.cuda.set_device(args.gpuid)

    utils.init_random(args.seed)

    logger.info("Load parameters from '%s'", args.model_file)
    # map_location keeps checkpoint tensors on CPU at load time.
    params = torch.load(args.model_file, map_location=lambda storage, loc: storage)

    utils.set_params(params["args"])

    # Restore the vocabulary/fields that were saved with the model.
    fields = utils.load_fields_from_vocab(params["vocab"])
    logger.info("Fields: %s", fields.keys())

    model = utils.build_test_model(fields, params)
    logger.info("Model:\n%s", model)

    logger.info("Load %s", args.test_file)
    test_data = LMDataset(fields, args.test_file, args.sent_length_trunc)
    logger.info("Test sentences: %d", len(test_data))

    # Deterministic pass over the test set: no shuffle/repeat, sorting only
    # within each batch (presumably by length — confirm in OrderedIterator).
    test_iter = utils.OrderedIterator(
        dataset=test_data,
        batch_size=args.batch_size,
        device=params["args"].device,
        train=False,
        shuffle=False,
        repeat=False,
        sort=False,
        sort_within_batch=True,
    )

    # Models without an encoder cannot report the IW NLL; force it off.
    if model.encoder is None:
        args.report_iw_nll = False
        logger.info("Force report_iw_nll to False")

    start_time = time.time()
    logger.info("Start testing")
    if args.report_iw_nll:
        # Split the requested sample count into n_iw_iter chunks of at most
        # iw_batch_size; note this mutates args.num_iw_samples in place.
        if args.num_iw_samples <= args.iw_batch_size:
            n_iw_iter = 1
        else:
            n_iw_iter = args.num_iw_samples // args.iw_batch_size
            args.num_iw_samples = args.iw_batch_size

        test_stats = report_iw_nll(model, test_iter, n_iw_iter, args.num_iw_samples)
        logger.info(
            "Results: test nll %.2f | test ppl %.2f", test_stats.nll(), test_stats.ppl()
        )
    else:
        # Standard validation pass reports NLL, KL, and perplexity.
        test_stats = validate(model, test_iter)
        logger.info(
            "Results: test nll %.2f | test kl %.2f | test ppl %.2f",
            test_stats.nll(),
            test_stats.kl(),
            test_stats.ppl(),
        )

    logger.info("End of testing: time %.1f min", (time.time() - start_time) / 60)