Esempio n. 1
0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            time.sleep(0.5)
            loss.backward()
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

            if ftlib.skip_allreduce:
                logging.info("skip allreduce")
                optimizer.step()
                continue
            else:
                res = ftlib.wait_weights_ready(model)
            if res == FTAllReduceStatus.NO_NEED:
                logging.critical(
                    "cannot use average_gradient when there is no need")
                exit(2)
            if res == FTAllReduceStatus.SUCCESS:
                logging.info("average succeed")
                optimizer.step()
            if res == FTAllReduceStatus.ABORT:
                logging.info("average failed, abort")
                continue
        scheduler.step()

    logging.info("terminate!")
Esempio n. 2
0
    epochs = 1
    # NOTE(review): presumably a loader yielding 10 dummy batches; confirm.
    dl = dummy_dataloader(10)

    # initialize the fault-tolerant library with consensus and framework options
    ftlib = BasicFTLib()
    ftlib.init(consensus='shared_storage', framework='dummy_NCCL')

    for _ in range(epochs):
        for batch in dl:
            # Placeholder local forward/backward pass.
            dummy_forward()
            dummy_backward()

            # When ftlib reports the allreduce can be skipped, apply the
            # local update directly and move on to the next batch.
            if ftlib.skip_allreduce:
                logging.info("skip allreduce")
                dummy_update()
                continue

            # Wait for ftlib to report the outcome of the collective step.
            res = ftlib.wait_weights_ready()
            if res == FTAllReduceStatus.NO_NEED:
                # skip_allreduce was False, so NO_NEED here is treated as an
                # inconsistent state and the process exits with code 2.
                logging.critical(
                    "cannot use average_gradient when there is no need")
                # `exit()` is injected by the `site` module for interactive
                # use and is absent under `python -S`. SystemExit is always
                # available.
                raise SystemExit(2)
            elif res == FTAllReduceStatus.SUCCESS:
                logging.info("average succeed")
                dummy_update()
            elif res == FTAllReduceStatus.ABORT:
                # The collective step failed: drop this batch's update.
                logging.info("average failed, abort")
            # NOTE(review): any other status falls through without an
            # update; confirm this is intended.

    logging.info("terminate!")