Example #1
 def _poll(self):
     """Poll the completion of the tensor's backward or push-pull from a FIFO event_queue"""
     while True:
         p, handle, ctx = self._event_queue.get()
         if p is None:
             self._logger.debug("poller exits.")
             break
         # Check whether the push-pull is finished. If so, start updating parameters.
         if handle is not None and poll(handle):
             output = synchronize(handle)
             p.grad.set_(self._compression.decompress(output, ctx))
             self._logger.debug("{} {} finished push-pull".format(
                 self._desc, self._parameter_names[id(p)]))
             self._push_pull_delay[p] = self.backward_passes_per_step
             # So far ByteScheduler only supports SGD, Adam and RMSprop optimizers in torch
             if isinstance(self._opt, torch.optim.SGD):
                 self._sgd(p)
             elif isinstance(self._opt, torch.optim.Adam):
                 self._adam(p)
             elif isinstance(self._opt, torch.optim.RMSprop):
                 self._rmsprop(p)
             else:
                 raise ValueError(
                     "Invalid optimizer! ByteScheduler only supports SGD, Adam and RMSprop."
                 )
             self._zero_one_grad(p)
             # notify that the update is complete and the parameter is ready for forward propagation
             if p in self._locks:
                 self._locks[p].release()
         else:
             self._event_queue.put((p, handle, ctx))
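The poller above drains a FIFO of (parameter, handle, context) tuples, re-enqueuing any entry whose push-pull has not finished, and exits when it sees a None parameter. The sketch below shows one way the surrounding class might start, feed, and stop such a poller; the thread handling, the submit() helper, and the (None, None, None) sentinel are assumptions for illustration, not necessarily ByteScheduler's exact code.

import queue
import threading

class PollerLifecycleSketch:
    """Minimal sketch of the event-queue lifecycle assumed by _poll above."""

    def __init__(self, poll_fn):
        # hypothetical poll_fn(handle) -> bool, standing in for byteps' poll()
        self._poll_fn = poll_fn
        self._event_queue = queue.Queue()
        self._poller = threading.Thread(target=self._poll)
        self._poller.start()

    def submit(self, p, handle, ctx):
        # called right after an asynchronous push-pull is launched for p.grad
        self._event_queue.put((p, handle, ctx))

    def _poll(self):
        while True:
            p, handle, ctx = self._event_queue.get()
            if p is None:
                break                                 # sentinel: poller exits
            if handle is not None and self._poll_fn(handle):
                pass                                  # real code updates p here
            else:
                self._event_queue.put((p, handle, ctx))

    def stop(self):
        # the sentinel mirrors the `if p is None: break` branch in _poll
        self._event_queue.put((None, None, None))
        self._poller.join()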
Example #2
def benchmark(tensor, average, name):
    if not args.no_wait and bps.rank() == 0:
        time.sleep(0.01)
    start = time.time()
    handle = push_pull_async_inplace(tensor, average, name)
    while True:
        if poll(handle):
            synchronize(handle)
            break
    end = time.time()
    return (end - start) * 1000
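benchmark() above launches one asynchronous push-pull on `tensor`, spins on poll() until it completes, and returns the elapsed time in milliseconds. Below is a usage sketch, assuming the module already does `import byteps.torch as bps` and that `push_pull_async_inplace`, `poll`, `synchronize` and `time` are available as the function requires; the `--no-wait` flag, tensor sizes, and tensor names are illustrative.

import argparse
import time
import torch
import byteps.torch as bps
from byteps.torch import push_pull_async_inplace, poll, synchronize

parser = argparse.ArgumentParser()
parser.add_argument('--no-wait', action='store_true', default=False)
args = parser.parse_args()

bps.init()
torch.cuda.set_device(bps.local_rank())

# time one push-pull per tensor size using benchmark() defined above
for num_elems in (2 ** 10, 2 ** 16, 2 ** 20):
    data = torch.rand(num_elems).cuda()
    ms = benchmark(data, average=True, name='benchmark.%d' % num_elems)
    if bps.rank() == 0:
        print('%d floats: %.3f ms' % (num_elems, ms))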
Example #3
def benchmark(tensor, average, name):
    if not args.no_wait and hvd.rank() == 0:
        # let other workers submit allreduce request first
        time.sleep(0.01)
    start = time.time()
    # do not use allreduce_() as it polls every 1ms
    handle = push_pull_async_inplace(tensor, average, name)
    while True:
        if poll(handle):
            synchronize(handle)
            break
    end = time.time()
    return (end - start) * 1000
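Examples #2 and #3 busy-wait on poll() so the measurement is not quantized by a fixed sleep interval, which is what the comment about allreduce_() polling every 1 ms points at. For comparison, a blocking variant that simply waits on the handle could look like the sketch below (same assumed imports as above); it gives up the fine-grained polling in exchange for a simpler body.

def benchmark_blocking(tensor, average, name):
    # like benchmark() above, but lets synchronize() block until the
    # push-pull completes instead of spinning on poll(); the timing
    # resolution then depends on how synchronize() waits internally
    start = time.time()
    handle = push_pull_async_inplace(tensor, average, name)
    synchronize(handle)
    end = time.time()
    return (end - start) * 1000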
Example #4
 def _try_to_synchronize(self, p):
     handle, ctx = self._handles[p]
     if poll(handle):
         output = synchronize(handle)
         self._push_pull_delay[p] = self.backward_passes_per_step
         if self._is_tensor_instance:
             fp16_p = self._fp32_to_fp16_map.get(p.__hash__())
         else:
             fp16_p = self._fp32_to_fp16_map.get(p)
         fp16_p.grad.set_(self._compression.decompress(output, ctx))
         p.grad.data.copy_(fp16_p.grad.data)
         p.grad.data = p.grad.data / (self.loss_scale * size())
         self._step_one_param(p)
         fp16_p.data.copy_(p.data)
         self._handles.pop(p)
         return True
     else:
         return False
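_try_to_synchronize() above returns True once a parameter's push-pull has finished and both its FP16 and FP32 copies have been updated, and False if the handle is still pending. A caller could drain every outstanding handle with a loop like the sketch below; the method name _wait_for_all is hypothetical, and the only assumption about self._handles is what the code above already shows, namely a dict keyed by parameter from which completed entries are popped.

 def _wait_for_all(self):
     # keep polling until every outstanding push-pull handle has completed;
     # _try_to_synchronize removes an entry from self._handles only when
     # poll(handle) reports that the push-pull is done
     while len(self._handles) > 0:
         for p in list(self._handles.keys()):
             self._try_to_synchronize(p)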
Example #5
def main():
    # os.system('shutdown -c')  # cancel previous shutdown command
    log.console(args)
    tb.log('sizes/world', bps.size())

    # need to index validation directory before we start counting the time
    dataloader.sort_ar(args.data + '/validation')

    # if args.distributed:
    # log.console('Distributed initializing process group')
    torch.cuda.set_device(bps.local_rank())
    print(f'cuda device set to {bps.local_rank()}')
    log.console("cuda initialized (rank=%d)" % (bps.local_rank()))
    # dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=bps.size())
    log.console("Distributed: success (%d/%d)" % (bps.rank(), bps.size()))

    log.console("Loading model (rank=%d)" % (bps.rank()))
    model = resnet.resnet50(bn0=args.init_bn0).cuda()

    # reuse the validate tensor
    global validate_tensor, dist_validate_tensor
    validate_tensor = torch.tensor([0, 0, 0, 0]).float().cuda()
    dist_validate_tensor = torch.tensor([0, 0, 0, 0, 0]).float().cuda()

    if args.fp16: model = network_to_half(model)
    best_top5 = 93  # only save models above 93% top-5; otherwise a checkpoint would be written every time the metric improves

    global model_params, master_params
    if args.fp16: model_params, master_params = prep_param_lists(model)
    else: model_params = master_params = model.parameters()

    optim_params, name_list = experimental_utils.bnwd_optim_params(
        model, model_params, master_params) if args.no_bn_wd else master_params

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(
        optim_params,
        0,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )  # start with 0 lr. Scheduler will change this later

    named_param = []
    for p in optim_params:
        tensors = p['params']
        for tensor in tensors:
            named_param.append(tensor)

    # create bps_param: a list of (name, tensor) tuples
    bps_param = []
    for i, tensor in enumerate(named_param):
        name = name_list[i]
        bps_param.append((name, tensor))

    # wrap with byteps optimizer
    optimizer = DistributedOptimizer(
        optimizer,
        named_parameters=bps_param,
        backward_passes_per_step=args.batches_per_pushpull,
        half=True,
        model=model,
        fp16_params=model_params,
        fp32_params=master_params,
        loss_scale=args.loss_scale)

    if args.resume:
        checkpoint = torch.load(
            args.resume,
            map_location=lambda storage, loc: storage.cuda(args.local_rank))
        model.load_state_dict(checkpoint['state_dict'])
        args.start_epoch = checkpoint['epoch']
        best_top5 = checkpoint['best_top5']
        optimizer.load_state_dict(checkpoint['optimizer'])

    log.console(
        "Creating data loaders (this could take up to 10 minutes if volume needs to be warmed up)"
    )
    num_machines = (bps.size() - 1) // 8 + 1
    assert (num_machines in schedules)
    phases = schedules[num_machines]
    dm = DataManager([copy.deepcopy(p) for p in phases if 'bs' in p])
    scheduler = Scheduler(optimizer,
                          [copy.deepcopy(p) for p in phases if 'lr' in p])

    # BytePS: broadcast parameters & optimizer state.
    broadcast_parameters([(name, p.detach()) for name, p in bps_param],
                         root_rank=0)
    broadcast_optimizer_state(optimizer, root_rank=0)

    start_time = datetime.now()  # start the clock only after everything is loaded
    if args.evaluate:
        return validate(dm.val_dl, model, criterion, 0, start_time)

    if args.distributed:
        log.console('Global Barrier: Syncing machines before training')
        tensor = torch.tensor([1.0]).float().cuda()
        barrier_handler = push_pull_async_inplace(tensor,
                                                  average=True,
                                                  name="init.barrier")
        while True:
            if poll(barrier_handler):
                synchronize(barrier_handler)
                break
        # broadcast the validation tensors
        log.console('Broadcasting validate tensor')
        barrier_handler = push_pull_async_inplace(validate_tensor,
                                                  average=True,
                                                  name="validation_tensor")
        while True:
            if poll(barrier_handler):
                synchronize(barrier_handler)
                break
        barrier_handler = push_pull_async_inplace(
            dist_validate_tensor,
            average=True,
            name="distributed_validation_tensor")
        while True:
            if poll(barrier_handler):
                synchronize(barrier_handler)
                break

    log.event("~~epoch\thours\ttop1\ttop5\n")
    for epoch in range(args.start_epoch, scheduler.tot_epochs):
        dm.set_epoch(epoch)

        train(dm.trn_dl, model, criterion, optimizer, scheduler, epoch)
        top1, top5 = validate(dm.val_dl, model, criterion, epoch, start_time)

        time_diff = (datetime.now() - start_time).total_seconds() / 3600.0
        log.event(f'~~{epoch}\t{time_diff:.5f}\t\t{top1:.3f}\t\t{top5:.3f}\n')

        is_best = top5 > best_top5
        best_top5 = max(top5, best_top5)
        if args.local_rank == 0:
            if is_best:
                save_checkpoint(epoch,
                                model,
                                best_top5,
                                optimizer,
                                is_best=True,
                                filename='model_best.pth.tar')
            phase = dm.get_phase(epoch)
            if phase:
                save_checkpoint(
                    epoch,
                    model,
                    best_top5,
                    optimizer,
                    filename=f'sz{phase["bs"]}_checkpoint.pth.tar')
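Most of main() above is FP16 and scheduling machinery around a fairly standard BytePS setup: initialize BytePS, pin the GPU, wrap the optimizer, and broadcast parameters and optimizer state from rank 0 before training. The skeleton below shows just that core path using the stock byteps.torch.DistributedOptimizer signature; the half/model/fp16_params/fp32_params/loss_scale arguments in the example come from its own customized wrapper, so treat this as an outline rather than a drop-in replacement, and note that the tiny model and synthetic data are placeholders.

import torch
import torch.nn as nn
import byteps.torch as bps

bps.init()
torch.cuda.set_device(bps.local_rank())

model = nn.Linear(1024, 10).cuda()            # stand-in for resnet50
criterion = nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# wrap the optimizer so backward passes trigger push-pull on the gradients
optimizer = bps.DistributedOptimizer(
    optimizer, named_parameters=model.named_parameters())

# start every worker from rank 0's parameters and optimizer state
bps.broadcast_parameters(model.state_dict(), root_rank=0)
bps.broadcast_optimizer_state(optimizer, root_rank=0)

for step in range(10):                        # placeholder training loop
    data = torch.randn(32, 1024).cuda()
    target = torch.randint(0, 10, (32,)).cuda()
    optimizer.zero_grad()
    loss = criterion(model(data), target)
    loss.backward()
    optimizer.step()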