Example #1
            # NOTE: this excerpt begins part-way through the options dictionary that is
            # (presumably) assigned to the `trainer_config` used below; the opening of
            # the configuration is omitted here.
            'num_pipeline_micro_batches': num_pipeline_steps,
            'sliced_schema': pipeline_schema,
            'sliced_axes': sliced_axes,
            'sliced_tensor_names': ['x', 'target', 'output'],
            # Define the pipeline stage partition by specifying cut points.
            # A 2-stage partition: a single cut placed at the tensor named "12".
            'pipeline_cut_info_string': '12'
        },
        'allreduce_post_accumulation': True
    }
})
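
# Illustrative sketch (not part of the original example): roughly speaking,
# 'num_pipeline_micro_batches' asks the runtime to split each per-step batch of
# the tensors listed in 'sliced_tensor_names' along the axes given in
# 'sliced_axes', so that pipeline stages can overlap work on the resulting
# micro-batches. The snippet below only mimics that slicing with plain PyTorch;
# the tensor name, shape, and micro-batch count are invented for illustration.
import torch

example_num_micro_batches = 2
example_batch = torch.randn(4, 2)                # a stand-in for one 'x' batch
example_micro_batches = torch.chunk(
    example_batch, example_num_micro_batches, dim=0)
assert all(mb.shape == (2, 2) for mb in example_micro_batches)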

# Build the trainer; `schema`, `adam_config`, and `apply_loss` (the model
# description, optimizer configuration, and loss function) are defined in the
# omitted part of the example.
trainer = ORTTrainer(model, schema, adam_config, apply_loss, trainer_config)

# Run a few training steps and record the loss reported at each step.
loss_history = []
for i in range(5):
    l, p = trainer.train_step(x.to(cuda_device), y.to(cuda_device))
    loss_history.append(l)

# Valid ranks are [0, 1, 2, 3].
# [0, 2] forms the 2-stage pipeline in the 1st data parallel group.
# [1, 3] forms the 2-stage pipeline in the 2nd data parallel group.
last_pipeline_stage_ranks = [2, 3]

# The loss values computed at the last pipeline stages. Note that intermediate
# stages may not have valid loss values, so we don't check them.
expected_loss_history = [0.8660, 1.1219, 1.6610, 1.2641, 1.0162]
if rank in last_pipeline_stage_ranks:
    for result, expected in zip(loss_history, expected_loss_history):
        assert torch.allclose(result.cpu(), torch.tensor([expected]), rtol=1e-03)
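
# Illustrative sketch (not part of the original example): one way to reproduce
# the rank layout described in the comments above. The mapping formula is an
# assumption consistent with those comments (2 pipeline stages x 2 data-parallel
# groups), not code taken from ONNX Runtime itself.
assumed_pipeline_parallel_size = 2
assumed_data_parallel_size = 2
for r in range(assumed_pipeline_parallel_size * assumed_data_parallel_size):
    stage = r // assumed_data_parallel_size    # ranks 0,1 -> stage 0; ranks 2,3 -> stage 1
    dp_group = r % assumed_data_parallel_size  # ranks 0,2 -> group 0; ranks 1,3 -> group 1
    print(f"rank {r}: pipeline stage {stage}, data-parallel group {dp_group}")
# The last pipeline stage therefore owns ranks [2, 3], matching last_pipeline_stage_ranks above.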

Example #2

# NOTE: this excerpt also omits its preamble (the imports and the definitions of
# `model`, `inputs_description`, `model_output`, and `inputs_values`).
model_desc = {
    "inputs": inputs_description,  # list of (name, shape) pairs
    "outputs": [(model_output.name, [], True)],  # (name, shape, is_loss)
}
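
# Hypothetical illustration (not part of the original example): with a single
# input named 'x' of shape [4, 2] and a scalar loss output named 'loss', the
# description above could expand to the following (names and shapes invented):
example_model_desc = {
    "inputs": [("x", [4, 2])],
    "outputs": [("loss", [], True)],  # empty shape -> scalar; True -> this is the loss output
}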

options = {'device': {'id': 'cpu'}}

trainer = ORTTrainer(model,
                     model_desc,
                     optim_config=SGDConfig(lr=0.01),
                     loss_fn=None,  # the loss is produced by the model itself (is_loss=True above)
                     options=ORTTrainerOptions(options))

# Run (and time) a single training step.
start = time.time()
loss = trainer.train_step(inputs_values)
end = time.time()

# Retrieve the updated initializers (trained weights) from the underlying
# training session and flatten each tensor into a plain Python list.
updated_initializers = trainer._training_session.get_state()

for name in updated_initializers:
    updated_initializers[name] = updated_initializers[name].flatten().tolist()

training_output = {
    # the loss as a plain Python number
    'loss': loss.numpy().flatten().tolist()[0],
    # the updated initializers as {'name': [0., 0.231, ...]}
    'updated_initializers': updated_initializers
}

# Emit the results as a single JSON document on stdout.
sys.stdout.write(json.dumps(training_output))
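
# Illustrative sketch (not part of the original example): whatever consumes this
# script's stdout (for example a parent process capturing it) can recover the
# results with json.loads. The round-trip below runs in-process purely to show
# the shape of the parsed data.
parsed = json.loads(json.dumps(training_output))
assert isinstance(parsed['loss'], float)
assert all(isinstance(v, list) for v in parsed['updated_initializers'].values())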