def main():
    """Assemble a LeNet5 pipeline and run it through ``ems.Trainer``.

    Relies on module-level ``os``, ``args``, ``parser`` and ``ems`` names that
    are defined outside this function.

    * dataset   -- ``ems.create_mnist_dataset`` when ``args.data_path``
      exists on the filesystem, otherwise ``ems.create_mocker_dataset()``;
      the chosen dataset is then passed to ``ems.print_ds_info``.
    * network   -- ``ems.LeNet5()``.
    * loss      -- ``nn.SoftmaxCrossEntropyWithLogits`` with ``sparse=True``
      and ``reduction="mean"``.
    * optimizer -- ``nn.Momentum`` over the network's trainable parameters,
      learning rate 0.01, momentum 0.9.

    The trainer is created as ``ems.Trainer('lenet', parser)``, configured
    through its chained ``set_*`` calls with a single ``ems.TimeMonitor``
    callback, and ``run()`` is invoked; its result is discarded.
    """
    from mindspore import nn

    # No data on disk -> use the mock dataset instead.
    if not os.path.exists(args.data_path):
        ds = ems.create_mocker_dataset()
    else:
        # NOTE(review): path and batch size come from ``args`` while the eval
        # flag comes from ``parser`` -- confirm this mix is intended.
        ds = ems.create_mnist_dataset(
            args.data_path,
            is_train=(not parser.do_eval),
            batch_size=args.batch_size,
        )
    ems.print_ds_info(ds)

    net = ems.LeNet5()
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    opt = nn.Momentum(net.trainable_params(), 0.01, 0.9)

    (
        ems.Trainer('lenet', parser)
        .set_dataset(ds)
        .set_network(net)
        .set_loss_fn(criterion)
        .set_optimizer(opt)
        .set_callbacks([ems.TimeMonitor()])
        .run()
    )
def test_03_ds_print_func():
    """Smoke-test the dataset printing helpers on ``create_dataset()`` output.

    Each helper call is preceded by a ``>>> <helper name>:`` banner so the
    outputs can be told apart. No assertions are made: the test fails only if
    one of the calls raises.
    """
    ds = create_dataset()
    steps = (
        (">>> print_ds_info:", print_ds_info),
        (">>> print_ds_performance:", print_ds_performance),
        (">>> print_ds_data:", print_ds_data),
    )
    for banner, show in steps:
        print(banner)
        show(ds)
self.samples.append([ np.random.rand(3, 4, 5).astype(np.float32), np.random.randint(10, size=()).astype(np.int32) ]) def __len__(self): return len(self.samples) # define dataset ds = de.GeneratorDataset(source=BaseDataset(), column_names=['image', 'label'], num_shards=num_shards, shard_id=shard_id) # map ops ds = ds.map(input_columns=["image"], operations=lambda img: img, num_parallel_workers=8) # batch & repeat ds = ds.batch(batch_size=batch_size, drop_remainder=False) ds = ds.repeat(count=1) return ds if __name__ == '__main__': dataset = create_dataset() ems.print_ds_info(dataset) ems.print_ds_data(dataset)