def test_dataset_iter_ms():
    """Construct a sink-mode DatasetHelper with loop sink disabled.

    CHECK_ENVS is switched off only for the duration of init("hccl") so the
    test can run without the full distributed environment configured.
    """
    GlobalComm.CHECK_ENVS = False
    init("hccl")
    GlobalComm.CHECK_ENVS = True
    context.set_context(enable_loop_sink=False)
    ds = get_dataset(32)
    DatasetHelper(ds, dataset_sink_mode=True, sink_size=10)
def test_dataset_iter_normal():
    """Non-sink iteration yields every batch: two passes over the dataset
    (reset between them) must produce 6 iterations in total."""
    ds = get_dataset(32)
    helper = DatasetHelper(ds, dataset_sink_mode=False)
    total = 0
    for _epoch in range(2):
        for _batch in helper:
            total += 1
        ds.reset()
    assert total == 6
def test_dataset_iter_ge():
    """In sink mode each epoch is delivered as a single sink-sized step,
    so two epochs iterate exactly twice."""
    init("hccl")
    helper = DatasetHelper(get_dataset(32), dataset_sink_mode=True, sink_size=10)
    steps = 0
    for _epoch in range(2):
        for _step in helper:
            steps += 1
    assert steps == 2
def test_dataset_iter_ge():
    """Sink-mode iteration count check (one step per epoch), with
    CHECK_ENVS disabled only around init("hccl")."""
    GlobalComm.CHECK_ENVS = False
    init("hccl")
    GlobalComm.CHECK_ENVS = True
    helper = DatasetHelper(get_dataset(32), dataset_sink_mode=True, sink_size=10)
    steps = 0
    for _epoch in range(2):
        for _step in helper:
            steps += 1
    assert steps == 2
def test_dataset_iter_ms_loop_sink():
    """With loop sink enabled the helper yields empty tuples (data is fed
    on-device, not through Python), one step per epoch."""
    init("hccl")
    context.set_context(enable_loop_sink=True)
    helper = DatasetHelper(get_dataset(32), dataset_sink_mode=True, sink_size=10)
    steps = 0
    for _epoch in range(2):
        for batch in helper:
            steps += 1
            # NOTE(review): assertion placement inside the inner loop is
            # inferred from the one-line original — confirm against upstream.
            assert batch == ()
    assert steps == 2
def train_process(self, epoch, train_dataset, mini_steps=None):
    """Training process. The data would be passed to network directly.

    Runs `epoch` passes over `train_dataset` in non-sink mode, accumulating
    gradients over `mini_steps` micro-batches before applying the optimizer,
    then saves a checkpoint of the forward/backward network.

    Args:
        epoch: number of training epochs to run.
        train_dataset: dataset object; must support `reset()` between epochs.
        mini_steps: number of micro-batches to accumulate per optimizer
            step. Defaults to 1 (no accumulation) when not given — the
            previous default of None crashed with a TypeError in the
            modulo expression below.
    """
    if mini_steps is None:
        mini_steps = 1
    dataset_helper = DatasetHelper(train_dataset, dataset_sink_mode=False, epoch_num=epoch)
    for i in range(epoch):
        step = 0
        for k, next_element in enumerate(dataset_helper):
            # Forward + backward accumulates gradients; optimizer runs
            # only every mini_steps batches.
            loss = self._train_forward_backward(*next_element)
            if (k + 1) % mini_steps == 0:
                step += 1
                print("epoch:", i + 1, "step:", step, "loss is ", loss)
                self._train_optim()
                self._train_clear()  # clear accumulated gradients
        train_dataset.reset()
    save_checkpoint(self._train_forward_backward, "gradient_accumulation.ckpt")
def test_dataset_helper_sink_size_negative():
    """A negative sink_size (other than the -1 sentinel) must be rejected."""
    ds = get_dataset(32)
    with pytest.raises(ValueError):
        DatasetHelper(ds, dataset_sink_mode=True, sink_size=-2)
def test_dataset_helper_sink_size_float():
    """sink_size must be an int; a float raises TypeError."""
    ds = get_dataset(32)
    with pytest.raises(TypeError):
        DatasetHelper(ds, dataset_sink_mode=True, sink_size=1.0)
def test_dataset_helper_dataset_sink_mode_int():
    """dataset_sink_mode must be a bool; an int raises TypeError."""
    ds = get_dataset(32)
    with pytest.raises(TypeError):
        DatasetHelper(ds, dataset_sink_mode=1)
def test_dataset_iter_ms():
    """Sink-mode DatasetHelper construction succeeds with loop sink off."""
    init("hccl")
    context.set_context(enable_loop_sink=False)
    ds = get_dataset(32)
    DatasetHelper(ds, dataset_sink_mode=True, sink_size=10)
if __name__ == "__main__": num_data, batch_size, repeat_size = 1600, 16, 1 lr, momentum = 0.005, 0.9 network = LinearNet() net_loss = nn.loss.MSELoss() net_opt = nn.Momentum(network.trainable_params(), lr, momentum) net = WithLossCell(network, net_loss) net = TrainOneStepCell(net, net_opt) ds_train = create_dataset(num_data, batch_size=batch_size, repeat_size=repeat_size) dataset_helper = DatasetHelper(ds_train, dataset_sink_mode=False, sink_size=100, epoch_num=10) # dataset_sink_mode is not supported in CPU device # dataset_helper = DatasetHelper(ds_train, dataset_sink_mode=True, sink_size=100, epoch_num=10) # net = connect_network_with_dataset(net, dataset_helper) network.set_train() print("============== Starting Training ==============") epoch = 2 # a customized training loop for step in range(epoch): for inputs in dataset_helper: output = net(*inputs) print("epoch: {0}/{1}, losses: {2}".format(step + 1, epoch,