def _run(FLAGS, model_cls):
    """Train `model_cls` across the full grid of gridworld configurations.

    Builds hyper-parameters from the command-line FLAGS, then sweeps the
    cartesian product of std / drop / direction configs, gridworld data
    seeds, and dataset shuffling seeds, running one `Trainer` per
    combination.

    Args:
        FLAGS: parsed command-line flags (bs, stp, dims, info_dims,
            att_dims, epochs, ckpt, heads, selfatt_dims, sz).
        model_cls: model class instantiated once per run.

    Raises:
        ValueError: if the dimension flags are mutually inconsistent.
    """
    logger = logging.getLogger('Trainer_%s' % model_cls.__name__)
    logger.setLevel(logging.INFO)
    # Guard against stacking duplicate handlers when _run is invoked more
    # than once for the same model class (each duplicate would repeat
    # every log line).
    if not logger.handlers:
        formatter = logging.Formatter('[%(asctime)s] ## %(message)s')
        for handler in (logging.FileHandler('%s.log' % model_cls.__name__),
                        logging.StreamHandler()):
            handler.setLevel(logging.INFO)
            handler.setFormatter(formatter)
            logger.addHandler(handler)

    hparams = tf.contrib.training.HParams(**COMMON_HPARAMS.values())
    # Override the shared defaults with the values supplied on the CLI.
    overrides = {
        'batch_size': FLAGS.bs,
        'n_steps': FLAGS.stp,
        'n_dims': FLAGS.dims,
        'n_info_dims': FLAGS.info_dims,
        'n_att_dims': FLAGS.att_dims,
        'max_epochs': FLAGS.epochs,
        'checkpoint': FLAGS.ckpt,
        'n_heads': FLAGS.heads,
        'n_selfatt_dims': FLAGS.selfatt_dims,
    }
    for name, value in overrides.items():
        hparams.set_hparam(name, value)

    # Validate with explicit raises: `assert` is stripped under `python -O`,
    # silently skipping these consistency checks.
    if hparams.n_dims != hparams.n_info_dims + hparams.n_att_dims:
        raise ValueError("`n_dims` should be equal to the sum of "
                         "`n_info_dims` and `n_att_dims`")
    if hparams.n_dims != hparams.n_heads * hparams.n_selfatt_dims:
        raise ValueError("`n_dims` should be equal to the product of "
                         "`n_heads` and `n_selfatt_dims`")

    name_size = 'SZ%d-STP%d' % (FLAGS.sz, FLAGS.stp)
    config_size = Config(size=FLAGS.sz, max_steps=FLAGS.stp)

    # NOTE: the original used `dict.iteritems()`, which only exists on
    # Python 2; `.items()` works on both 2 and 3.
    for name_std, config_std in CONFIG_STDS.items():
        for name_drop, config_drop in CONFIG_DROPS.items():
            for name_direction, config_direction in CONFIG_DIRECTIONS.items():
                config = Config()
                config.add('base', 'base', CONFIG_BASE)
                config.add('size', name_size, config_size)
                config.add('direction', name_direction, config_direction)
                config.add('drop', name_drop, config_drop)
                config.add('std', name_std, config_std)
                gridworld = GridWorld(name=config.get_name(),
                                      **config.get_kwargs())

                for seed in GRIDWORLD_SEEDS:
                    # One dataset directory per (config, data seed) pair.
                    data_dir = '%s-SEED%d' % (config.get_name(), seed)
                    gridworld.load(data_dir,
                                   seed=seed,
                                   splitting_seed=SPLITTING_SEED)

                    dataset_name = config.get_name()
                    for shuffling_seed in SHUFFLING_SEEDS:
                        dataset = Dataset(dataset_name,
                                          os.path.join(BASE_DIR, data_dir),
                                          shuffling_seed=shuffling_seed)
                        model = model_cls(dataset,
                                          hparams,
                                          gridworld,
                                          seed=MODEL_SEED)
                        Trainer(model, logger)()
# ----- Example 2 (示例 2) -----
def main(args):
    """Train a DQN-style MLP Q-network on a saved GridWorld environment.

    Runs epsilon-greedy rollouts into a replay memory, updates the online
    network from sampled batches, periodically syncs the target network,
    and anneals both the learning rate and epsilon. Writes the training
    args, the trained weights, and a CSV of per-iteration stats to
    `args.output_dir`. Ctrl-C stops training early but still saves.

    Args:
        args: parsed CLI namespace (env, use_cuda, hidden_dims, base_lr,
            lr_decay, lr_step, base_epsilon, min_epsilon, eps_decay,
            eps_step, mem_size, max_iter, max_t, batch_size, batch_count,
            discount, freeze_period, output_dir).
    """
    env = GridWorld.load(args.env)

    dev = torch.device(
        "cuda" if args.use_cuda and torch.cuda.is_available() else "cpu")
    q_net = build_MLP(2, *args.hidden_dims, 4)
    target_net = build_MLP(2, *args.hidden_dims, 4)
    # Target network starts as an exact copy of the online network.
    target_net.load_state_dict(q_net.state_dict())
    q_net.to(dev)
    target_net.to(dev)

    optim = torch.optim.SGD(q_net.parameters(), lr=args.base_lr)
    if args.lr_decay is not None:
        lr_sched = torch.optim.lr_scheduler.StepLR(optim, args.lr_step,
                                                   args.lr_decay)
    epsilon = args.base_epsilon

    memory = ReplayMemory(maxlen=args.mem_size)

    avg_cumul = None
    avg_success = None
    avg_loss = None
    AVG_R = 0.05  # EMA smoothing factor for the progress-bar stats.
    stats = []
    it = 0  # keep `it` defined for the save path even if the loop never runs
    try:
        with tqdm.trange(args.max_iter) as progress:
            for it in progress:
                trajectory, cumul, success = sample_trajectory(
                    env, lambda z: epsilon_greedy(z, q_net, epsilon),
                    args.max_t)
                memory.extend(trajectory)

                # Average the loss over the number of batches actually
                # drawn. (The original divided by the zero-based enumerate
                # index — off by one — and crashed with a NameError when
                # no batch was yielded.)
                loss = 0
                n_batches = 0
                for batch in sample_batch(memory, args.batch_size,
                                          args.batch_count, dev):
                    loss += update_weights(q_net, target_net, optim,
                                           *batch[:-1], args.discount)
                    n_batches += 1
                if n_batches > 0:
                    loss /= n_batches

                # Exponential moving averages, seeded with the first value.
                avg_cumul = cumul if avg_cumul is None else (
                    1 - AVG_R) * avg_cumul + AVG_R * cumul
                avg_success = success if avg_success is None else (
                    1 - AVG_R) * avg_success + AVG_R * success
                avg_loss = loss if avg_loss is None else (
                    1 - AVG_R) * avg_loss + AVG_R * loss
                lr = optim.param_groups[0]["lr"]
                progress.set_postfix(cumul=avg_cumul,
                                     success=avg_success,
                                     loss=avg_loss,
                                     lr=lr,
                                     eps=epsilon)
                stats.append(
                    (it, avg_cumul, avg_success, avg_loss, lr, epsilon))

                # Periodic hard sync of the target network.
                if it % args.freeze_period == args.freeze_period - 1:
                    target_net.load_state_dict(q_net.state_dict())
                if args.lr_decay is not None:
                    lr_sched.step()
                # Epsilon schedule: linear anneal by default, or
                # multiplicative decay every `eps_step` iterations.
                if args.eps_decay is None:
                    epsilon = (args.base_epsilon - args.min_epsilon) * (
                        1 - it / args.max_iter) + args.min_epsilon
                elif it % args.eps_step == args.eps_step - 1:
                    epsilon = max(epsilon * args.eps_decay, args.min_epsilon)
    except KeyboardInterrupt:
        pass  # allow Ctrl-C to stop training early; still save below

    os.makedirs(args.output_dir, exist_ok=True)
    with open(os.path.join(args.output_dir, "training_args.json"), 'w') as f:
        json.dump(vars(args), f, indent=4)
    torch.save(q_net.state_dict(),
               os.path.join(args.output_dir, "trained_mlp_{}.pkl".format(it)))
    with open(os.path.join(args.output_dir, "training_stats.csv"), 'w') as f:
        for it_stat in stats:
            f.write(', '.join(str(s) for s in it_stat))
            f.write('\n')
# ----- Example 3 (示例 3) -----
#!/usr/bin/env python3

from gridworld import GridWorld
import sys

if __name__ == "__main__":
    # Interactive GridWorld driver: load the environment given on the
    # command line and let the user steer it with WASD keys, then dump
    # the cumulative reward and the full trajectory.
    world = GridWorld.load(sys.argv[1])

    # WASD -> direction; any other key maps to None.
    key_to_action = {
        'w': GridWorld.Direction.NORTH,
        'a': GridWorld.Direction.WEST,
        's': GridWorld.Direction.SOUTH,
        'd': GridWorld.Direction.EAST,
    }

    state = world.reset()
    finished = False
    steps = []
    total_reward = 0
    while not finished:
        print(world)
        action = key_to_action.get(input("z = {} > ".format(state)))
        next_state, reward, finished = world.step(action)
        total_reward += reward
        steps.append((state, action, reward, next_state, finished))
        state = next_state
    print("Cumul = {}".format(total_reward))
    print('\n'.join(str(record) for record in steps))