def fetcher_loop(self, data_queue, data_buffer, pin_memory=False):
    """Move batches from ``data_queue`` into ``data_buffer`` until told to stop.

    Every item taken from the queue is unpacked as ``(idx, batch)``. An item
    whose ``idx`` is ``None`` is the stop sentinel and ends the loop; any
    other batch is stored under ``data_buffer[idx]``.

    Args:
        self: Not read by this function.
        data_queue: Object whose ``get()`` returns ``(idx, batch)`` pairs.
        data_buffer: Object supporting ``data_buffer[idx] = batch``.
        pin_memory: When true, every element of ``batch[0]`` and ``batch[1]``
            is replaced by ``element.as_in_context(mx.cpu_pinned())`` and the
            batch is stored as a 2-tuple of lists.

    Returns:
        None, once the stop sentinel has been received.
    """
    while True:
        idx, batch = data_queue.get()
        if idx is None:
            return
        if pin_memory:
            pinned_ctx = mx.cpu_pinned()
            batch = (
                [d.as_in_context(pinned_ctx) for d in batch[0]],
                [d.as_in_context(pinned_ctx) for d in batch[1]],
            )
        data_buffer[idx] = batch
def _stage_data(i, data, ctx_list, pinned_data_stage): def _get_chunk(data, storage): s = storage.reshape(shape=(storage.size,)) s = s[:data.size] s = s.reshape(shape=data.shape) data.copyto(s) return s if ctx_list[0].device_type == "cpu": return data if i not in pinned_data_stage: pinned_data_stage[i] = [d.as_in_context(mx.cpu_pinned()) for d in data] return pinned_data_stage[i] storage = pinned_data_stage[i] for j in range(len(storage)): if data[j].size > storage[j].size: storage[j] = data[j].as_in_context(mx.cpu_pinned()) return [_get_chunk(d, s) for d, s in zip(data, storage)]
def is_pinned(input):
    """Return the result of comparing ``input.context`` to ``mx.cpu_pinned()``."""
    ctx = input.context
    return ctx == mx.cpu_pinned()
def train(hyperparameters, hosts, num_gpus, **kwargs):
    """Train the model with a Dice loss, saving the best one after warm-up.

    Args:
        hyperparameters: Mapping of settings. Keys read here: ``depth``,
            ``width``, ``learning_rate``, ``batch_size``, ``warm_up``,
            ``patience`` and ``epochs``. Each value is converted with
            ``int``/``float``, so string values are accepted.
        hosts: Not used by this function.
        num_gpus: Number of GPU contexts to create when a GPU is usable.
        **kwargs: Not used by this function.

    Returns:
        None. Progress is printed; the model is written through ``save``.
    """
    # Probe for a usable GPU by allocating a tiny array on gpu(0); fall back
    # to a single CPU context (and float32) if MXNet raises.
    try:
        _ = mx.nd.array([1], ctx=mx.gpu(0))
        ctx = [mx.gpu(i) for i in range(num_gpus)]
        print("using GPU")
        dtype = "float16"
        host_ctx = mx.cpu_pinned(0)
    except mx.MXNetError:
        ctx = [mx.cpu()]
        print("using CPU")
        dtype = "float32"
        host_ctx = mx.cpu(0)

    model_dir = os.environ.get("SM_CHANNEL_MODEL")
    if model_dir:
        print("using prebuild model")
        shutil.unpack_archive("%s/model.tar.gz" % (model_dir), model_dir)
        with open('%s/hyperparameters.json' % (model_dir), 'r') as fp:
            saved_hyperparameters = json.load(fp)
        # The architecture comes from the saved hyperparameters, not from
        # the ones passed to this run.
        net = model(
            depth=int(saved_hyperparameters.get("depth", 2)),
            width=int(saved_hyperparameters.get("width", 3)),
        )
        params_path = "%s/model-0000.params" % (model_dir)
        # Try float16 first and fall back to float32 on any failure.
        try:
            print("trying to load float16")
            net.cast("float16")
            net.collect_params().load(params_path, ctx)
        except Exception as err:
            print(err)
            print("trying to load float32")
            net.cast("float32")
            net.collect_params().load(params_path, ctx)
        net.cast(dtype)
    else:
        print("building model from scratch")
        net = model(
            depth=int(hyperparameters.get("depth", 2)),
            width=int(hyperparameters.get("width", 3)),
        )
        net.cast(dtype)
        net.collect_params().initialize(mx.init.Xavier(), ctx=ctx)
    net.hybridize()
    print(net)

    dice = DiceLoss()
    dice.cast(dtype)
    dice.hybridize()

    # With a prebuilt model only the parameters returned by
    # collect_params_layers(2) are handed to the trainer.
    trainer = gluon.Trainer(
        net.collect_params_layers(2) if model_dir else net.collect_params(),
        'adam',
        {
            "multi_precision": (dtype == 'float16'),
            'learning_rate': float(hyperparameters.get("learning_rate", .001)),
        })
    train_iter, test_iter = get_data(
        int(hyperparameters.get("batch_size", 8)), dtype, host_ctx)

    def _split_batch(batch):
        """Cast one batch to ``dtype`` and split data/label/mask across ``ctx``."""
        batch_size = batch.data[0].shape[0]
        data = gluon.utils.split_and_load(batch.data[0].astype(dtype), ctx)
        label = gluon.utils.split_and_load(
            batch.label[0].astype(dtype).reshape((batch_size, -1)), ctx)
        mask = gluon.utils.split_and_load(
            batch.label[1].astype(dtype).reshape((batch_size, -1)), ctx)
        return batch_size, data, label, mask

    best = float("inf")
    warm_up = int(hyperparameters.get("warm_up", 30))
    patience = int(hyperparameters.get("patience", 10))
    # Cast like every other hyperparameter so a string value works in range().
    epochs = int(hyperparameters.get("epochs", 11))
    wait = 0
    # NOTE(review): save() is only reached when epoch >= warm_up, so with the
    # defaults (epochs=11, warm_up=30) no model is ever saved -- confirm intent.
    for epoch in range(epochs):
        print("Epoch %s" % (epoch))
        val_loss = 0
        st = time.time()
        training_count = 0
        testing_count = 0
        training_loss = 0

        for batch in train_iter:
            batch_size, data, label, mask = _split_batch(batch)
            training_count += batch_size
            with autograd.record():
                output = [net(x) for x in data]
                # The negated dice value is what gets minimised.
                losses = [
                    -dice(x, y, z) for x, y, z in zip(output, label, mask)
                ]
            for loss in losses:
                loss.backward()
            trainer.step(batch_size)
            training_loss += sum(loss.sum().asscalar() for loss in losses)

        for batch in test_iter:
            batch_size, data, label, mask = _split_batch(batch)
            testing_count += batch_size
            output = [net(x) for x in data]
            losses = [-dice(x, y, z) for x, y, z in zip(output, label, mask)]
            val_loss += sum(loss.sum().asscalar() for loss in losses)

        et = time.time()
        print("Hyperparameters: %s;" % (hyperparameters))
        print("Training loss: %s;" % (-training_loss / training_count))
        print("Testing loss: %s;" % (-val_loss / (testing_count)))
        print("Throughput=%2.2f;" % ((training_count + testing_count) / (et - st)))
        print("Validation Loss=%2.2f;" % val_loss)
        print("Best=%2.2f;" % best)

        # Early stopping only starts counting once warm-up is over.
        if epoch >= warm_up:
            if val_loss < best:
                print("best model: %s;" % (-val_loss / (testing_count)))
                save(net, hyperparameters)
                best = val_loss
                wait = 0
            else:
                wait += 1
            if wait > patience:
                print("stoping early")
                break

        train_iter.reset()
        test_iter.reset()