Example #1
def training_fn():
    import os
    import time

    import torch
    import horovod.torch as hvd
    from horovod.ray import ray_logger

    hvd.init()

    model = torch.nn.Sequential(torch.nn.Linear(2, 2))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    ray_logger.log({"started": True, "pid": os.getpid()})

    # Assumed value: in the original harness `iterations` is captured
    # from the enclosing scope.
    iterations = 10

    @hvd.elastic.run
    def train(state):
        # Resume from state.epoch so a restarted worker picks up where
        # the last commit left off.
        for state.epoch in range(state.epoch, iterations):
            ray_logger.log({"training": True, "pid": os.getpid()})
            time.sleep(0.1)
            state.commit()  # synchronization point; permits scale-up/scale-down
        ray_logger.log({"finished": True, "pid": os.getpid()})

    state = hvd.elastic.TorchState(model,
                                   optimizer,
                                   batch=0,
                                   epoch=0,
                                   commits=0,
                                   rendezvous=0)
    train(state)
    return True
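A minimal launch sketch for the function above, following the usage shown in Horovod's Ray documentation. It assumes a running Ray cluster; the `use_gpu` and `cpus_per_slot` values are illustrative, not prescribed by the example.

import ray
from horovod.ray import ElasticRayExecutor

ray.init(address="auto")  # connect to an existing Ray cluster

# Configure the elastic rendezvous, then start workers and run the
# training function on each of them.
settings = ElasticRayExecutor.create_settings(verbose=True)
executor = ElasticRayExecutor(settings, use_gpu=False, cpus_per_slot=1)
executor.start()
results = executor.run(training_fn)  # one return value per worker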
Example #2
import torch.nn.functional as F
import horovod.torch as hvd
from horovod.ray import ray_logger

# `args` (parsed command-line flags) and the `Metric` / `accuracy`
# helpers are assumed to be defined elsewhere in the training script,
# as in Horovod's elastic PyTorch examples (see the sketch after this
# function).
def train(state, train_loader):
    epoch = state.epoch
    batch_offset = state.batch

    state.model.train()
    state.train_sampler.set_epoch(epoch)
    train_loss = Metric('train_loss')
    train_accuracy = Metric('train_accuracy')

    for batch_idx, (data, target) in enumerate(train_loader):
        # Elastic Horovod: update the current batch index within this
        # epoch, then commit or check for host updates. Host checks are
        # skipped on commit batches, since commit() already performs one.
        state.batch = batch_offset + batch_idx
        if args.batches_per_commit > 0 and \
                state.batch % args.batches_per_commit == 0:
            state.commit()
        elif args.batches_per_host_check > 0 and \
                state.batch % args.batches_per_host_check == 0:
            state.check_host_updates()

        if args.cuda:
            data, target = data.cuda(), target.cuda()
        state.optimizer.zero_grad()

        output = state.model(data)
        train_accuracy.update(accuracy(output, target))

        loss = F.cross_entropy(output, target)
        train_loss.update(loss)
        loss.backward()
        state.optimizer.step()

        # Only log from the rank-0 worker to avoid duplicate output.
        if hvd.rank() == 0:
            ray_logger.log({
                "tqdm_mode": 'train',
                "train/loss": train_loss.avg.item(),
                "train/accuracy": 100. * train_accuracy.avg.item(),
                "total": len(train_loader),
                "epoch": epoch,
                "world_size": hvd.size()
            })
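The `Metric` and `accuracy` helpers referenced above are not part of this snippet. A sketch along the lines of Horovod's PyTorch examples, where metric values are averaged across workers with an allreduce:

import torch
import horovod.torch as hvd

def accuracy(output, target):
    # Fraction of samples where the argmax prediction matches the target.
    pred = output.max(1, keepdim=True)[1]
    return pred.eq(target.view_as(pred)).cpu().float().mean()

class Metric:
    """Running average of a scalar, averaged across workers via allreduce."""
    def __init__(self, name):
        self.name = name
        self.sum = torch.tensor(0.)
        self.n = torch.tensor(0.)

    def update(self, val):
        self.sum += hvd.allreduce(val.detach().cpu(), name=self.name)
        self.n += 1

    @property
    def avg(self):
        return self.sum / self.n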
Example #3
# The elastic training loop from Example #1, shown standalone. It relies
# on `iterations`, `ray_logger`, `os`, and `time` from the enclosing
# scope, and is meant to be wrapped with @hvd.elastic.run.
def train(state):
    for state.epoch in range(state.epoch, iterations):
        ray_logger.log({"training": True, "pid": os.getpid()})
        time.sleep(0.1)
        state.commit()  # synchronization point; permits scale-up/scale-down
    ray_logger.log({"finished": True, "pid": os.getpid()})