コード例 #1
0
    def fetcher_loop(self, data_queue, data_buffer, pin_memory=False):
        """Consume ``(idx, batch)`` pairs from ``data_queue`` into ``data_buffer``.

        Loops until a sentinel item with ``idx is None`` is received, then
        returns.  When ``pin_memory`` is True, every NDArray in both halves
        of the batch is copied into page-locked (pinned) CPU memory before
        being stored, so later host-to-GPU transfers can be asynchronous.

        Parameters
        ----------
        data_queue : queue-like object with a blocking ``get()`` that yields
            ``(idx, batch)``; ``batch`` is a pair of lists of NDArrays
            (assumed from the ``batch[0]`` / ``batch[1]`` comprehensions —
            confirm against the producer).
        data_buffer : mutable mapping; filled with ``data_buffer[idx] = batch``.
        pin_memory : bool, copy batches to ``mx.cpu_pinned()`` when True.
        """
        # NOTE(review): dropped an unused `t = time.time()` and dead
        # commented-out code from the original loop body.
        while True:
            idx, batch = data_queue.get()
            if idx is None:  # sentinel: producer is finished
                return
            if pin_memory:
                batch = ([d.as_in_context(mx.cpu_pinned()) for d in batch[0]],
                         [d.as_in_context(mx.cpu_pinned()) for d in batch[1]])
            data_buffer[idx] = batch
コード例 #2
0
ファイル: train_mask_rcnn.py プロジェクト: ygest/gluon-cv
def _stage_data(i, data, ctx_list, pinned_data_stage):
    """Stage *data* through cached pinned-CPU buffers keyed by *i*.

    On a pure-CPU context list the batch is returned untouched.  The first
    time slot *i* is seen, pinned copies of every array are created and
    cached in ``pinned_data_stage``; subsequent calls reuse those buffers
    (replacing any that are too small) and return views shaped like the
    incoming arrays with the data copied in.
    """
    if ctx_list[0].device_type == "cpu":
        # No GPU target: staging through pinned memory buys nothing.
        return data

    cached = pinned_data_stage.get(i)
    if cached is None:
        # First request for this slot: allocate and cache pinned copies.
        pinned_data_stage[i] = [d.as_in_context(mx.cpu_pinned()) for d in data]
        return pinned_data_stage[i]

    # Grow any cached buffer that cannot hold the incoming array.
    for idx, d in enumerate(data):
        if d.size > cached[idx].size:
            cached[idx] = d.as_in_context(mx.cpu_pinned())

    def _view_into(src, buf):
        # Flatten the buffer, take a src.size-element prefix, shape it like
        # src, and copy src into that view before returning it.
        flat = buf.reshape(shape=(buf.size,))
        view = flat[:src.size].reshape(shape=src.shape)
        src.copyto(view)
        return view

    return [_view_into(d, buf) for d, buf in zip(data, cached)]
コード例 #3
0
def is_pinned(input):
    """Return True when *input* lives in MXNet's page-locked CPU context."""
    pinned_ctx = mx.cpu_pinned()
    return input.context == pinned_ctx
コード例 #4
0
def train(hyperparameters, hosts, num_gpus, **kwargs):
    """Train the model with a negated dice loss and patience-based early stopping.

    Detects GPU availability at runtime (float16 with pinned host memory on
    GPU, float32 on CPU), optionally warm-starts from a pre-built model
    archive found via the ``SM_CHANNEL_MODEL`` environment variable, then
    runs the epoch loop with a validation pass per epoch.  ``hosts`` and
    ``**kwargs`` are accepted for interface compatibility but unused here.

    Parameters
    ----------
    hyperparameters : mapping with optional keys ``depth``, ``width``,
        ``learning_rate``, ``batch_size``, ``epochs``, ``warm_up``,
        ``patience`` (values may be strings; each is coerced on use).
    hosts : unused.
    num_gpus : number of GPUs to shard batches over when a GPU is present.
    """
    # Probe for a GPU by allocating a tiny array on gpu(0); MXNetError
    # means no usable device, so fall back to CPU in float32.
    try:
        _ = mx.nd.array([1], ctx=mx.gpu(0))
        ctx = [mx.gpu(i) for i in range(num_gpus)]
        print("using GPU")
        DTYPE = "float16"
        host_ctx = mx.cpu_pinned(0)  # pinned memory speeds host->GPU copies
    except mx.MXNetError:
        ctx = [mx.cpu()]
        print("using CPU")
        DTYPE = "float32"
        host_ctx = mx.cpu(0)

    model_dir = os.environ.get("SM_CHANNEL_MODEL")
    if model_dir:
        # Warm start: rebuild the network with the architecture recorded in
        # the archived hyperparameters, then load its weights.
        print("using prebuild model")
        shutil.unpack_archive("%s/model.tar.gz" % (model_dir), model_dir)
        with open('%s/hyperparameters.json' % (model_dir), 'r') as fp:
            saved_hyperparameters = json.load(fp)

        net = model(
            depth=int(saved_hyperparameters.get("depth", 2)),
            width=int(saved_hyperparameters.get("width", 3)),
        )
        # Saved parameters may be float16 or float32; try each in turn.
        try:
            print("trying to load float16")
            net.cast("float16")
            net.collect_params().load("%s/model-0000.params" % (model_dir),
                                      ctx)
        except Exception as e:
            print(e)
            print("trying to load float32")
            net.cast("float32")
            net.collect_params().load("%s/model-0000.params" % (model_dir),
                                      ctx)
        net.cast(DTYPE)
    else:
        print("building model from scratch")
        net = model(
            depth=int(hyperparameters.get("depth", 2)),
            width=int(hyperparameters.get("width", 3)),
        )
        net.cast(DTYPE)
    net.collect_params().initialize(mx.init.Xavier(), ctx=ctx)
    net.hybridize()
    print(net)

    dice = DiceLoss()
    dice.cast(DTYPE)
    dice.hybridize()

    # When fine-tuning a prebuilt model, train only the parameter subset
    # returned by collect_params_layers(2); multi_precision keeps float32
    # master weights when running in float16.
    trainer = gluon.Trainer(
        net.collect_params_layers(2) if model_dir else net.collect_params(),
        'adam', {
            "multi_precision": (DTYPE == 'float16'),
            'learning_rate': float(hyperparameters.get("learning_rate", .001))
        })
    train_iter, test_iter = get_data(int(hyperparameters.get("batch_size", 8)),
                                     DTYPE, host_ctx)

    # NOTE(review): removed an unused local
    # `Loss = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=False)` —
    # it was never referenced; the dice loss above is what is optimized.

    best = float("inf")
    warm_up = int(hyperparameters.get("warm_up", 30))
    patience = int(hyperparameters.get("patience", 10))
    wait = 0

    # BUG FIX: "epochs" was the only hyperparameter not coerced with int(),
    # so a string value (common when parameters arrive via JSON/env) would
    # crash range(); coerce it like the others.
    for e in range(int(hyperparameters.get("epochs", 11))):
        print("Epoch %s" % (e))
        val_loss = 0
        st = time.time()
        training_count = 0
        testing_count = 0
        training_loss = 0

        for batch in train_iter:
            batch_size = batch.data[0].shape[0]
            training_count += batch_size
            # Shard the batch (and flattened labels/masks) across devices.
            data = gluon.utils.split_and_load(batch.data[0].astype(DTYPE), ctx)
            label = gluon.utils.split_and_load(
                batch.label[0].astype(DTYPE).reshape((batch_size, -1)), ctx)
            mask = gluon.utils.split_and_load(
                batch.label[1].astype(DTYPE).reshape((batch_size, -1)), ctx)

            with autograd.record():
                output = [net(x) for x in data]
                # Dice is a similarity score, so minimize its negation.
                losses = [
                    -dice(x, y, z) for x, y, z in zip(output, label, mask)
                ]
            for loss in losses:
                loss.backward()
            trainer.step(batch_size)
            training_loss += sum(loss.sum().asscalar() for loss in losses)

        # Validation pass: same sharding, no gradient recording.
        for batch in test_iter:
            batch_size = batch.data[0].shape[0]
            testing_count += batch_size

            data = gluon.utils.split_and_load(batch.data[0].astype(DTYPE), ctx)
            label = gluon.utils.split_and_load(
                batch.label[0].astype(DTYPE).reshape((batch_size, -1)), ctx)
            mask = gluon.utils.split_and_load(
                batch.label[1].astype(DTYPE).reshape((batch_size, -1)), ctx)

            output = [net(x) for x in data]
            losses = [-dice(x, y, z) for x, y, z in zip(output, label, mask)]

            val_loss += sum(loss.sum().asscalar() for loss in losses)

        et = time.time()
        # Accumulated losses are negated dice sums, so re-negate for display.
        print("Hyperparameters: %s;" % (hyperparameters))
        print("Training loss: %s;" % (-training_loss / training_count))
        print("Testing loss: %s;" % (-val_loss / (testing_count)))
        print("Throughput=%2.2f;" % ((training_count + testing_count) /
                                     (et - st)))
        print("Validation Loss=%2.2f;" % val_loss)
        print("Best=%2.2f;" % best)

        # Early stopping: only start tracking best/wait after warm_up epochs.
        if e >= warm_up:
            if val_loss < best:
                print("best model: %s;" % (-val_loss / (testing_count)))
                save(net, hyperparameters)
                best = val_loss
                wait = 0
            else:
                wait += 1
        if wait > patience:
            print("stoping early")
            break
        train_iter.reset()
        test_iter.reset()