def test_convnet(ctx):
    """Build a three-stage stencil/maxpool network on all-ones data and log its output shape.

    The image array is created with a tile hint whose last two entries are
    ``util.divup(64, sqrt(ctx.num_workers))``; each weight array is created
    with ``tile_hint=ONE_TILE``.  Every stage applies ``stencil.stencil``
    (third argument ``2``) followed by ``stencil.maxpool``.

    NOTE(review): the third argument to ``stencil.stencil`` is presumably a
    stride -- confirm against the stencil module.
    """
    tile_edge = util.divup(64, sqrt(ctx.num_workers))
    layer = expr.eager(
        expr.ones((N_IMGS,) + IMG_SIZE,
                  tile_hint=(N_IMGS, N_COLORS, tile_edge, tile_edge)))

    # Second weight dimension per stage: N_COLORS for the first stage, then
    # N_FILTERS for the two stages that consume a previous stage's output.
    for in_channels in (N_COLORS, N_FILTERS, N_FILTERS):
        weights = expr.eager(
            expr.ones((N_FILTERS, in_channels) + FILTER_SIZE,
                      tile_hint=ONE_TILE))
        convolved = stencil.stencil(layer, weights, 2)
        layer = stencil.maxpool(convolved)

    util.log_info(layer.shape)
def _():
    """Chain three stencil/maxpool stages over ``images`` and force the result.

    ``images``, ``w1``, ``w2`` and ``w3`` are not defined here; they are read
    from the surrounding scope.  The final pooled expression is passed to
    ``expr.force``.
    """
    layer = images
    for weights in (w1, w2, w3):
        convolved = stencil.stencil(layer, weights, 2)
        layer = stencil.maxpool(convolved)
    expr.force(layer)
def _():
    """Chain three stencil/maxpool stages over ``images`` and evaluate the result.

    ``images``, ``w1``, ``w2`` and ``w3`` are not defined here; they are read
    from the surrounding scope.  The final pooled expression is evaluated via
    its own ``evaluate()`` method.
    """
    stage = stencil.maxpool(stencil.stencil(images, w1, 2))
    stage = stencil.maxpool(stencil.stencil(stage, w2, 2))
    stage = stencil.maxpool(stencil.stencil(stage, w3, 2))
    stage.evaluate()
def test_convnet(ctx):
    """Run all-ones images through three stencil + maxpool stages and log the final shape.

    Images are created eagerly with shape ``(N_IMGS,) + IMG_SIZE`` and a tile
    hint derived from ``ctx.num_workers``.  Each stage creates its own
    all-ones weight array with ``tile_hint=ONE_TILE``.  The shape of the last
    pooled expression is reported through ``util.log_info``.
    """
    tile_edge = util.divup(64, sqrt(ctx.num_workers))
    image_shape = (N_IMGS,) + IMG_SIZE
    image_tiling = (N_IMGS, N_COLORS, tile_edge, tile_edge)
    images = expr.eager(expr.ones(image_shape, tile_hint=image_tiling))

    def conv_pool(source, in_channels):
        # One stage: fresh all-ones weights, stencil with third argument 2,
        # then maxpool of the stencil output.
        kernel_shape = (N_FILTERS, in_channels) + FILTER_SIZE
        kernel = expr.eager(expr.ones(kernel_shape, tile_hint=ONE_TILE))
        return stencil.maxpool(stencil.stencil(source, kernel, 2))

    first = conv_pool(images, N_COLORS)
    second = conv_pool(first, N_FILTERS)
    third = conv_pool(second, N_FILTERS)
    util.log_info(third.shape)