Example #1
def exec_pool(args):
    log_level = logging.DEBUG if args.debug else logging.INFO
    if args.console:
        utils.setup_console_logger(log_level)
    else:
        utils.setup_file_logger(args.data_dir, 'luby', log_level)

    id_keys = ['k', 'n', 'c', 'delta']
    id_val = [str(vars(args)[key]) for key in id_keys]
    saver = utils.Saver(args.data_dir,
                        list(zip(['type'] + id_keys, ['luby'] + id_val)))
    log = logging.getLogger('.'.join(id_val))

    k, n, arr = args.k, args.n, []
    omega = get_soliton(k, args.c, args.delta)

    def callback(cb_args):
        sim_id, num_sym = cb_args
        log.info('sim_id=%d, num_sym=%d' % (sim_id, num_sym))
        arr.append(num_sym)
        saver.add_all({'arr': arr})

    pool = Pool(processes=args.pool)
    results = [
        pool.apply_async(simulate_cw, (x, omega, n), callback=callback)
        for x in range(args.count)
    ]
    for r in results:
        r.wait()
    log.info('Finished all!')
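
For context, exec_pool only reads a handful of attributes from args. The sketch below shows one way to invoke it; the attribute names come from the function body above, while the concrete values (and building args by hand with argparse.Namespace) are illustrative only.

import argparse

# Hypothetical values; only the attribute names come from exec_pool above.
args = argparse.Namespace(
    debug=False,       # log at INFO rather than DEBUG
    console=True,      # log to the console instead of a file under data_dir
    data_dir='data',   # where utils.Saver writes its output
    k=100, n=120,      # k feeds get_soliton, n feeds simulate_cw
    c=0.1, delta=0.5,  # soliton distribution tuning constants
    pool=4,            # number of worker processes in the Pool
    count=1000,        # number of simulations dispatched to the pool
)
exec_pool(args)
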
Example #2
def test(args):
    model = models[args.channel]
    dec_fac = getattr(model, args.decoder)
    id_keys = ['channel', 'code', 'decoder', 'codeword',
               'min_wec'] + dec_fac.id_keys
    id_val = [vars(args)[key] for key in id_keys]
    log = logging.getLogger('.'.join(utils.strl(id_val)))
    code = codes.get_code(args.code)
    code_n = code.get_n()
    x = code.parity_mtx[0] * 0 + args.codeword  # add 1 or 0
    min_wec = args.min_wec
    saver = utils.Saver(args.data_dir,
                        list(zip(['type'] + id_keys, ['simulation'] + id_val)))

    for param in args.params:
        log.info('Starting parameter: %f' % param)

        channel = model.Channel(param)
        decoder = dec_fac(param, code, **vars(args))
        tot, wec, wer, bec, ber = 0, 0, 0., 0, 0.
        start_time = time.time()

        def log_status():
            keys = ['tot', 'wec', 'wer', 'bec', 'ber']
            vals = [int(tot), int(wec), float(wer), int(bec), float(ber)]
            log.info(', '.join(('%s:%s' % (key.upper(), val)
                                for key, val in zip(keys, vals))))
            if hasattr(decoder, 'stats'):
                keys.append('dec')
                vals.append(decoder.stats())
            saver.add(param, OrderedDict(zip(keys, vals)))

        while wec < min_wec:
            if args.codeword == -1:
                x = code.cb[np.random.choice(code.cb.shape[0], 1)[0]]
            y = channel.send(x)
            x_hat = decoder.decode(y)
            errors = (x != x_hat).sum()  # bit errors in this word
            wec += errors > 0            # word-error count
            bec += errors                # bit-error count
            tot += 1
            wer, ber = wec / tot, bec / (tot * code_n)
            if time.time() - start_time > args.log_freq:
                start_time = time.time()
                log_status()

        log_status()
    log.info('Done!')
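
At its core, test() is a Monte-Carlo error-rate estimate: codewords are sent through the channel until min_wec word errors are observed, after which WER = wec/tot and BER = bec/(tot*code_n). Below is a self-contained sketch of the same counting loop for a toy 3-bit repetition code over a binary symmetric channel; it uses none of the project's models/codes/utils modules and only illustrates the loop.

import numpy as np

def estimate_wer_ber(p, min_wec=100, n=3):
    x = np.zeros(n, dtype=int)                     # all-zero codeword
    tot = wec = bec = 0
    rng = np.random.default_rng(0)
    while wec < min_wec:
        flips = rng.random(n) < p                  # BSC: flip each bit with prob. p
        y = x ^ flips                              # received word
        x_hat = np.full(n, int(y.sum() > n // 2))  # majority-vote decoder
        errors = (x != x_hat).sum()
        wec += errors > 0                          # word errors
        bec += errors                              # bit errors
        tot += 1
    return wec / tot, bec / (tot * n)              # WER, BER

print(estimate_wer_ber(0.1))
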
Example #3
def main(args):
    log = logging.getLogger()
    errors = ['wer', 'ber']
    file_list = utils.get_data_file_list(args.data_dir)
    matches = {err: [] for err in errors}
    pattern = re.compile('^' + args.prefix + '_[0-9]+$')
    log.info('regex to match: ' + pattern.pattern)

    src_list = []
    for file_name in file_list:
        data = utils.load_json(os.path.join(args.data_dir, file_name))
        if is_valid(data, args, pattern):
            log.info('found match: %s' % file_name)
            src_list.append(data['code'])
            for err in errors:
                matches[err].append(data[err])

    avg = {}
    for err in errors:
        ll = {}
        for inst in matches[err]:
            for point in inst:
                ll.setdefault(point, []).append(inst[point])

        for point in ll:
            val = ll[point]
            ll[point] = sum(val) / float(len(val))

        avg[err] = ll

    id_keys = ('type', 'channel', 'prefix', 'decoder')
    id_val = ('stats', args.channel, args.prefix, args.decoder)
    saver = utils.Saver(args.data_dir, list(zip(id_keys, id_val)))
    saver.add_meta('sources', src_list)
    saver.add_all(avg)
    log.info('Done!')
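
The averaging step above assumes every matched file stores 'wer' and 'ber' as mappings from channel parameter to error rate, and it takes the plain mean across runs for each parameter. A minimal sketch of the same computation on hypothetical data:

# Hypothetical data shaped like matches['wer'] above:
# one dict per matched run, mapping channel parameter -> error rate.
matches_wer = [
    {'0.05': 0.12, '0.10': 0.30},
    {'0.05': 0.10, '0.10': 0.28},
]

avg_wer = {}
for inst in matches_wer:
    for point, val in inst.items():
        avg_wer.setdefault(point, []).append(val)
avg_wer = {point: sum(vals) / len(vals) for point, vals in avg_wer.items()}
print(avg_wer)  # roughly {'0.05': 0.11, '0.10': 0.29}
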
Example #4
def train(train_file,
          test_file=None,
          format='tree',
          embed_file=None,
          n_epoch=20,
          batch_size=20,
          lr=0.001,
          limit=-1,
          l2_lambda=0.0,
          grad_clip=5.0,
          encoder_input=('char', 'postag'),
          model_config=None,
          device=-1,
          save_dir=None,
          seed=None,
          cache_dir='',
          refresh_cache=False,
          bert_model=0,
          bert_dir=''):
    if seed is not None:
        utils.set_random_seed(seed, device)
    logger = logging.getLogger()
    # logger.configure(filename='log.txt', logdir=save_dir)
    assert isinstance(logger, logging.AppLogger)
    if model_config is None:
        model_config = {}
    model_config['bert_model'] = bert_model
    model_config['bert_dir'] = bert_dir

    os.makedirs(save_dir, exist_ok=True)

    read_genia = format == 'genia'
    loader = dataset.DataLoader.build(
        postag_embed_size=model_config.get('postag_embed_size', 50),
        char_embed_size=model_config.get('char_embed_size', 10),
        word_embed_file=embed_file,
        filter_coord=(not read_genia),
        refresh_cache=refresh_cache,
        format=format,
        cache_options=dict(dir=cache_dir, mkdir=True, logger=logger),
        extra_ids=(git.hash(), ))

    use_external_postags = not read_genia
    cont_embed_file_ext = _get_cont_embed_file_ext(encoder_input)
    use_cont_embed = cont_embed_file_ext is not None

    train_dataset = loader.load_with_external_resources(
        train_file,
        train=True,
        bucketing=False,
        size=None if limit < 0 else limit,
        refresh_cache=refresh_cache,
        use_external_postags=use_external_postags,
        use_contextualized_embed=use_cont_embed,
        contextualized_embed_file_ext=cont_embed_file_ext)
    logging.info('{} samples loaded for training'.format(len(train_dataset)))
    test_dataset = None
    if test_file is not None:
        test_dataset = loader.load_with_external_resources(
            test_file,
            train=False,
            bucketing=False,
            size=None if limit < 0 else limit // 10,
            refresh_cache=refresh_cache,
            use_external_postags=use_external_postags,
            use_contextualized_embed=use_cont_embed,
            contextualized_embed_file_ext=cont_embed_file_ext)
        logging.info('{} samples loaded for validation'.format(
            len(test_dataset)))

    builder = models.CoordSolverBuilder(loader,
                                        inputs=encoder_input,
                                        **model_config)
    logger.info("{}".format(builder))
    model = builder.build()
    logger.trace("Model: {}".format(model))
    if device >= 0:
        chainer.cuda.get_device_from_id(device).use()
        model.to_gpu(device)

    if bert_model == 1:
        optimizer = chainer.optimizers.AdamW(alpha=lr)
        optimizer.setup(model)
        # optimizer.add_hook(chainer.optimizer.GradientClipping(1.))
    else:
        optimizer = chainer.optimizers.AdamW(alpha=lr,
                                             beta1=0.9,
                                             beta2=0.999,
                                             eps=1e-08)
        optimizer.setup(model)
        if l2_lambda > 0.0:
            optimizer.add_hook(chainer.optimizer.WeightDecay(l2_lambda))
        if grad_clip > 0.0:
            optimizer.add_hook(chainer.optimizer.GradientClipping(grad_clip))

    def _report(y, t):
        values = {}
        model.compute_accuracy(y, t)
        for k, v in model.result.items():
            if 'loss' in k:
                values[k] = float(chainer.cuda.to_cpu(v.data))
            elif 'accuracy' in k:
                values[k] = v
        training.report(values)

    trainer = training.Trainer(optimizer, model, loss_func=model.compute_loss)
    trainer.configure(utils.training_config)
    trainer.add_listener(
        training.listeners.ProgressBar(lambda n: tqdm(total=n)), priority=200)
    trainer.add_hook(training.BATCH_END,
                     lambda data: _report(data['ys'], data['ts']))
    if test_dataset:
        parser = parsers.build_parser(loader, model)
        evaluator = eval_module.Evaluator(parser,
                                          logger=logging,
                                          report_details=False)
        trainer.add_listener(evaluator)

    if bert_model == 2:
        num_train_steps = 20000 * 5 / 20
        num_warmup_steps = 10000 / 20
        learning_rate = 2e-5
        # learning rate (eta) scheduling in Adam
        lr_decay_init = learning_rate * \
            (num_train_steps - num_warmup_steps) / num_train_steps
        trainer.add_hook(
            training.BATCH_END,
            extensions.LinearShift(  # decay
                'eta', (lr_decay_init, 0.),
                (num_warmup_steps, num_train_steps),
                optimizer=optimizer))
        trainer.add_hook(
            training.BATCH_END,
            extensions.WarmupShift(  # warmup
                'eta',
                0.,
                num_warmup_steps,
                learning_rate,
                optimizer=optimizer))

    if save_dir is not None:
        accessid = logging.getLogger().accessid
        date = logging.getLogger().accesstime.strftime('%Y%m%d')
        # metric = 'whole' if isinstance(model, models.Teranishi17) else 'inner'
        metric = 'exact'
        trainer.add_listener(
            utils.Saver(
                model,
                basename="{}-{}".format(date, accessid),
                context=dict(App.context, builder=builder),
                directory=save_dir,
                logger=logger,
                save_best=True,
                evaluate=(lambda _: evaluator.get_overall_score(metric))))

    trainer.fit(train_dataset, test_dataset, n_epoch, batch_size)
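
A minimal call sketch for train(); the file paths below are hypothetical, and running it assumes the project's dataset, models, training and utils modules (plus Chainer and the custom logging setup) are importable.

# Hypothetical paths; all keyword names come from the signature above.
train('data/coord_train.txt',
      test_file='data/coord_dev.txt',
      format='tree',
      n_epoch=20,
      batch_size=20,
      device=-1,       # CPU; set a GPU id >= 0 to use chainer.cuda
      save_dir='out',  # checkpoints are written here via utils.Saver
      seed=1)
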
Example #5
with tpu_graph.as_default():
    init = tf.global_variables_initializer()
    tpu_init = tf.tpu.initialize_system()
    tpu_shutdown = tf.tpu.shutdown_system()
    tpu_session.run([init, tpu_init])
    tf.keras.backend.set_session(tpu_session)
    encoder.load_weights("model/encoder_weights.h5")
    decoder.load_weights("model/decoder_weights.h5")
with local_graph.as_default():
    init = tf.global_variables_initializer()
    local_session.run(init)

# saver
var_list = get_weights() + optimizer.variables()
with tpu_graph.as_default():
    tpu_saver = utils.Saver(var_list, max_keep=100)
var_list = {v.name: v for v in saveable}
with local_graph.as_default():
    local_saver = tf.train.Saver(var_list, max_to_keep=100)

# restore
check_point_num = 0
path = checkpoint_path + "checkpoint_" + str(check_point_num) + ".pkl"
tpu_saver.restore(tpu_session, path)
path = checkpoint_path + "checkpoint_" + str(check_point_num) + ".ckpt"
local_saver.restore(path)

# lr
with tpu_graph.as_default():
    tpu_session.run(lr.assign(1e-5))