Ejemplo n.º 1
0
def test_batches(session):
    """Rows inserted with raw SQL are queried back as ``model.Batch`` objects."""
    insert_statements = (
        'INSERT INTO "batches" VALUES ("batch1", "sku1", 100, null)',
        'INSERT INTO "batches" VALUES ("batch2", "sku2", 200, "2011-04-11")',
    )
    for statement in insert_statements:
        session.execute(statement)

    first = model.Batch("batch1", "sku1", 100, eta=None)
    second = model.Batch("batch2", "sku2", 200, eta=date(2011, 4, 11))
    expected = [first, second]

    # Query through the ORM and compare against the hand-built objects.
    assert session.query(model.Batch).all() == expected
Ejemplo n.º 2
0
def test_retrieving_batches(session):
    """Rows inserted with raw SQL are queried back as ``model.Batch`` objects.

    The session is closed when the test finishes, whether or not the
    assertion holds.
    """
    try:
        session.execute(
            'INSERT INTO batches (reference, sku, _purchased_quantity, eta)'
            ' VALUES ("batch1", "sku1", 100, null)')
        session.execute(
            'INSERT INTO batches (reference, sku, _purchased_quantity, eta)'
            ' VALUES ("batch2", "sku2", 200, "2011-04-11")')
        expected = [
            model.Batch("batch1", "sku1", 100, eta=None),
            model.Batch("batch2", "sku2", 200, eta=date(2011, 4, 11)),
        ]

        assert session.query(model.Batch).all() == expected
    finally:
        # Close on the failure path too; previously a failing assertion
        # skipped the close. Rebinding the local name to None afterwards had
        # no effect outside this function, so it is dropped.
        session.close()
Ejemplo n.º 3
0
def test_returns_allocation():
    """``services.allocate`` returns ``"b1"`` when batch ``b1`` is the only one in the repository."""
    sku = "COMPLICATED-LAMP"
    order_line = model.OrderLine("o1", sku, 10)
    only_batch = model.Batch("b1", sku, 100, eta=None)
    fake_repo = FakeRepository([only_batch])

    assert services.allocate(order_line, fake_repo, FakeSession()) == "b1"
Ejemplo n.º 4
0
def test_commits():
    """After ``services.allocate`` runs, the fake session reports ``committed is True``."""
    sku = 'OMINOUS-MIRROR'
    order_line = model.OrderLine('o1', sku, 10)
    fake_repo = FakeRepository([model.Batch('b1', sku, 100, eta=None)])
    fake_session = FakeSession()

    services.allocate(order_line, fake_repo, fake_session)

    assert fake_session.committed is True
Ejemplo n.º 5
0
def test_error_for_invalid_sku():
    """Allocating a line whose sku matches no batch in the repository raises ``services.InvalidSku``."""
    order_line = model.OrderLine("o1", "NONEXISTENTSKU", 10)
    fake_repo = FakeRepository([model.Batch("b1", "AREALSKU", 100, eta=None)])

    # The error message must mention the offending sku.
    expecting_invalid_sku = pytest.raises(
        services.InvalidSku, match="Invalid sku NONEXISTENTSKU"
    )
    with expecting_invalid_sku:
        services.allocate(order_line, fake_repo, FakeSession())
Ejemplo n.º 6
0
def test_saving_allocations(session):
    """Committing a batch that has an allocated line writes one row to ``allocations``."""
    batch = model.Batch("batch1", "sku1", 100, eta=None)
    line = model.OrderLine("order1", "sku1", 10)
    batch.allocate(line)
    session.add(batch)
    session.commit()
    rows = list(session.execute('SELECT orderline_id, batch_id FROM "allocations"'))
    # The expected tuple follows the SELECT column order: orderline_id first,
    # then batch_id. The previous (batch.id, line.id) ordering only passed
    # when both ids happened to be equal.
    assert rows == [(line.id, batch.id)]
Ejemplo n.º 7
0
def test_saving_batches(session):
    """A batch added to the session and committed shows up as one row in ``batches``."""
    new_batch = model.Batch("batch1", "sku1", 100, eta=None)
    session.add(new_batch)
    session.commit()

    # Read the table back with raw SQL rather than through the ORM.
    select_batches = 'SELECT reference, sku, _purchased_quantity, eta FROM "batches"'
    saved_rows = list(session.execute(select_batches))

    assert saved_rows == [("batch1", "sku1", 100, None)]
Ejemplo n.º 8
0
def test_retrieving_batches(session):
    """Rows inserted with raw SQL are queried back as ``model.Batch`` objects."""
    session.execute(
        "INSERT INTO batches (reference, sku, _purchased_quantity, manufacture_date)"
        'VALUES("batch1", "sku1",100,null)'
    )
    session.execute(
        "INSERT INTO batches(reference, sku,_purchased_quantity, manufacture_date)"
        'VALUES ("batch2", "sku2", 200, "2021-04-11")'
    )
    expected = [
        model.Batch("batch1", "sku1", 100, manufacture_date=None),
        # Must match the date inserted above ("2021-04-11"); the previous
        # value, date(2021, 4, 4), disagreed with the inserted row.
        model.Batch("batch2", "sku2", 200, manufacture_date=date(2021, 4, 11)),
    ]
    assert session.query(model.Batch).all() == expected
Ejemplo n.º 9
0
def test_repository_can_save_a_batch(session: Session):
    """A batch added through ``SqlAlchemyRepository`` and committed is stored in ``batches``."""
    new_batch = model.Batch("batch1", "RUSTY-SOAPDISH", 100, eta=None)

    batch_repo = repository.SqlAlchemyRepository(session)
    batch_repo.add(new_batch)
    # The commit is issued here on the session, not by the repository.
    session.commit()

    select_batches = 'SELECT reference, sku, _purchased_quantity, eta FROM "batches"'
    stored_rows = list(session.execute(select_batches))
    assert stored_rows == [("batch1", "RUSTY-SOAPDISH", 100, None)]
Ejemplo n.º 10
0
def test_saving_batches(session):
    """A batch added to the session and committed shows up as one row in ``batches``.

    The session is closed when the test finishes, whether or not the
    assertion holds.
    """
    try:
        batch = model.Batch('batch1', 'sku1', 100, eta=None)
        session.add(batch)
        session.commit()
        rows = list(
            session.execute(
                'SELECT reference, sku, _purchased_quantity, eta FROM "batches"'))
        assert rows == [('batch1', 'sku1', 100, None)]
    finally:
        # Close on the failure path too; previously a failing assertion
        # skipped the close. Rebinding the local name to None afterwards had
        # no effect outside this function, so it is dropped.
        session.close()
def test_repository_can_save_a_batch(session):
    """A batch added through ``SqlAlchemyRepository`` and committed is stored in ``batches``."""
    new_batch = model.Batch("batch1", "RUSTY-SOAPDISH", 100, eta=None)
    batch_repo = repository.SqlAlchemyRepository(session)

    # repo.add() is the method under test here.
    batch_repo.add(new_batch)

    # The commit stays outside the repository: it is the caller's responsibility.
    session.commit()

    # Raw SQL is used to verify that the right data has been saved.
    select_batches = 'SELECT reference, sku, _purchased_quantity, eta FROM "batches"'
    stored_rows = list(session.execute(select_batches))

    assert stored_rows == [("batch1", "RUSTY-SOAPDISH", 100, None)]
Ejemplo n.º 12
0
def test_repository_can_retrieve_a_batch_with_allocations(session):
    """``SqlAlchemyRepository.get`` returns the requested batch together with its allocations."""
    # Arrange: one order line, two batches, and an allocation on the first batch.
    line_pk = insert_order_line(session)
    first_batch_pk = insert_batch(session, "batch1")
    insert_batch(session, "batch2")
    insert_allocation(session, line_pk, first_batch_pk)

    # Act
    retrieved = repository.SqlAlchemyRepository(session).get("batch1")

    # Assert
    expected = model.Batch("batch1", "GENERIC-SOFA", 100, eta=None)
    assert retrieved == expected
    for attribute in ("sku", "_purchased_quantity"):
        assert getattr(retrieved, attribute) == getattr(expected, attribute)
    expected_allocations = {model.OrderLine("order1", "GENERIC-SOFA", 12)}
    assert retrieved._allocations == expected_allocations
def test_repository_can_retrieve_a_batch_with_allocations(session):
    """``SqlAlchemyRepository.get`` returns the requested batch together with its allocations."""
    allocated_line_id = insert_order_line(session)
    allocated_batch_id = insert_batch(session, "batch1")
    insert_batch(session, "batch2")
    insert_allocation(session, allocated_line_id, allocated_batch_id)

    batch_repo = repository.SqlAlchemyRepository(session)
    retrieved = batch_repo.get("batch1")

    expected = model.Batch("batch1", "GENERIC-SOFA", 100, eta=None)

    # Batch.__eq__ only compares reference, so this checks that the types
    # match and that the reference is the same.
    assert retrieved == expected

    # The major attributes are therefore checked explicitly.
    assert (retrieved.sku, retrieved._purchased_quantity) == (
        expected.sku,
        expected._purchased_quantity,
    )

    # ._allocations is a Python set of OrderLine value objects.
    allocated_lines = {model.OrderLine("order1", "GENERIC-SOFA", 12)}
    assert retrieved._allocations == allocated_lines
Ejemplo n.º 14
0
def main(_):
    """Train ``model.Predicator`` and write train/test summaries once per epoch.

    Each epoch runs ``batch.iter_per_epoch`` training steps, then records a
    summary computed on the last training mini-batch and one computed on the
    full test set. A checkpoint is dumped every ``FLAGS.ckpt_interval``
    epochs. Both summary writers are closed on exit, including when training
    raises.

    Args:
        _: Unused positional argument.
    """
    ckpt_path = os.path.join(FLAGS.ckpt_dir, FLAGS.name)
    (train_x, train_y), (test_x, test_y) = preprocess.create_dataset()

    # NOTE(review): the third argument passed here is FLAGS.epoch; confirm
    # that this is the value model.Batch expects in that position.
    batch = model.Batch(train_x, train_y, FLAGS.epoch)

    print('start session')
    with tf.Session() as sess:
        predicator = model.Predicator(matrix_shape=[9, 8],
                                      num_time=7,
                                      out_time=7,
                                      kernels=[[5, 5], [5, 5], [5, 5], [5, 5],
                                               [5, 5]],
                                      depths=[256, 128, 128, 64, 32],
                                      learning_rate=FLAGS.learning_rate,
                                      beta1=FLAGS.beta1)

        train_path = os.path.join(FLAGS.summary_dir, FLAGS.name, 'train')
        test_path = os.path.join(FLAGS.summary_dir, FLAGS.name, 'test')

        train_writer = tf.summary.FileWriter(train_path, sess.graph)
        test_writer = tf.summary.FileWriter(test_path, sess.graph)

        try:
            print('start training')
            sess.run(tf.global_variables_initializer())
            for i in range(FLAGS.epoch):
                for _step in range(batch.iter_per_epoch):
                    batch_x, batch_y = batch()
                    predicator.train(sess, batch_x, batch_y)

                print(i, 'th epoch')
                # The training summary uses the last mini-batch of the epoch.
                summary = predicator.inference(sess, predicator.summary,
                                               batch_x, batch_y)
                train_writer.add_summary(summary, global_step=i)

                summary = predicator.inference(sess, predicator.summary,
                                               test_x, test_y)
                test_writer.add_summary(summary, global_step=i)

                if (i + 1) % FLAGS.ckpt_interval == 0:
                    predicator.dump(sess, ckpt_path, i)
        finally:
            # Flush pending events and release the event files even when
            # training is interrupted by an exception.
            train_writer.close()
            test_writer.close()
def main(_):
    """Train ``model.BasicModel`` and write a summary after every training step.

    A checkpoint is dumped every ``FLAGS.ckpt_interval`` epochs. The summary
    writer is closed on exit, including when training raises.

    Args:
        _: Unused positional argument.
    """
    # total_x, total_y, x_dim and y_dim are not defined in this function;
    # they must be provided by the enclosing module.
    ckpt_path = os.path.join(FLAGS.ckpt_dir, FLAGS.name)

    batch = model.Batch(total_x, total_y, 128)

    with tf.Session() as sess:
        basic_model = model.BasicModel(x_dim, y_dim, FLAGS.learning_rate, FLAGS.beta1)
        writer = tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, FLAGS.name), sess.graph)

        try:
            sess.run(tf.global_variables_initializer())
            for i in range(FLAGS.epoch):
                for n in range(batch.iter_per_epoch):
                    batch_x, batch_y = batch()
                    basic_model.train(sess, batch_x, batch_y)

                    summary = basic_model.inference(sess, basic_model.summary, batch_x, batch_y)
                    # Record a monotonically increasing step so successive
                    # summaries are not all written without a step value.
                    writer.add_summary(summary, global_step=i * batch.iter_per_epoch + n)

                if (i + 1) % FLAGS.ckpt_interval == 0:
                    basic_model.dump(sess, ckpt_path)
        finally:
            # Flush pending events and release the event file even when
            # training is interrupted by an exception.
            writer.close()