Example #1
def broadcast_parameters(params, root_rank):
    """
    Broadcasts the parameters from root rank to all other processes.
    Typical usage is to broadcast the `model.state_dict()`,
    `model.named_parameters()`, or `model.parameters()`.
    Arguments:
        params: One of the following:
            - list of parameters to broadcast
            - dict of parameters to broadcast
        root_rank: The rank of the process from which parameters will be
                   broadcasted to all other processes.
    """
    if isinstance(params, dict):
        params = sorted(params.items())
    elif isinstance(params, list):
        # support both named_parameters() and regular parameters()
        params = [p if isinstance(p, tuple) else (None, p) for p in params]
    else:
        raise ValueError('invalid params of type: %s' % type(params))

    # Run synchronous broadcasts.
    for name, p in params:
        # Broadcast is implemented as push + pull in BytePS
        # To make it a real broadcast, we set the tensors on non-root ranks to all zeros.
        if rank() != root_rank:
            p.fill_(0)
        # Remember to disable averaging because we are doing broadcast
        handle = byteps_push_pull(p, average=False, name="Parameter." + name)
        synchronize(handle)
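
A minimal usage sketch, assuming the usual `byteps.torch` entry points (`bps.init()`, `bps.local_rank()`, `broadcast_parameters`, `broadcast_optimizer_state`) and a placeholder model; broadcasting the `state_dict()` covers both parameters and buffers:

import torch
import byteps.torch as bps

bps.init()
torch.cuda.set_device(bps.local_rank())

model = torch.nn.Linear(10, 10).cuda()  # placeholder model for illustration
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Make rank 0 the single source of truth for the initial weights and optimizer state.
bps.broadcast_parameters(model.state_dict(), root_rank=0)
bps.broadcast_optimizer_state(optimizer, root_rank=0)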
Example #2
def benchmark(tensor, average, name):
    if not args.no_wait and bps.rank() == 0:
        time.sleep(0.01)
    start = time.time()
    handle = push_pull_async_inplace(tensor, average, name)
    while True:
        if poll(handle):
            synchronize(handle)
            break
    end = time.time()
    return (end - start) * 1000
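
A sketch of how benchmark() might be driven; the tensor size, iteration count, and the `args.no_wait` stand-in are assumptions layered on top of the snippet above, and `time`, `push_pull_async_inplace`, `poll`, and `synchronize` are assumed to already be imported in the same module, as the snippet requires:

import argparse

import torch
import byteps.torch as bps

bps.init()
args = argparse.Namespace(no_wait=True)  # stand-in for the argparse flags benchmark() reads
tensor = torch.rand(1024 * 1024).cuda()  # ~4 MB float32 payload (arbitrary size)

durations = [benchmark(tensor, average=True, name='benchmark.tensor') for _ in range(100)]
print('avg push-pull latency: %.3f ms' % (sum(durations) / len(durations)))

Reusing one name for the same buffer each iteration mirrors what training does, since BytePS keys tensors by name.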
Example #3
def benchmark(tensor, average, name):
    if not args.no_wait and hvd.rank() == 0:
        # let other workers submit allreduce request first
        time.sleep(0.01)
    start = time.time()
    # do not use allreduce_() as it polls every 1ms
    handle = push_pull_async_inplace(tensor, average, name)
    while True:
        if poll(handle):
            synchronize(handle)
            break
    end = time.time()
    return (end - start) * 1000
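
The poll-then-synchronize loop above busy-waits so the measurement is not distorted by a fixed polling interval (hence the comment about allreduce_() polling every 1 ms). When that precision is not needed, a blocking variant is simpler; a sketch, assuming the same `time` and push-pull imports as above:

def benchmark_blocking(tensor, average, name):
    start = time.time()
    handle = push_pull_async_inplace(tensor, average, name)
    synchronize(handle)  # blocks until the push-pull has completed
    return (time.time() - start) * 1000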
Example #4
    def _poll(self):
        """Poll the completion of the tensor's backward or push-pull from a FIFO event_queue."""
        while True:
            p, handle, ctx = self._event_queue.get()
            if p is None:
                self._logger.debug("poller exits.")
                break
            # Check whether the push-pull is finished. If so, start updating parameters.
            if handle is not None and poll(handle):
                output = synchronize(handle)
                p.grad.set_(self._compression.decompress(output, ctx))
                self._logger.debug("{} {} finished push-pull".format(
                    self._desc, self._parameter_names[id(p)]))
                self._push_pull_delay[p] = self.backward_passes_per_step
                # So far ByteScheduler only supports SGD, Adam and RMSprop optimizers in torch.
                if isinstance(self._opt, torch.optim.SGD):
                    self._sgd(p)
                elif isinstance(self._opt, torch.optim.Adam):
                    self._adam(p)
                elif isinstance(self._opt, torch.optim.RMSprop):
                    self._rmsprop(p)
                else:
                    raise ValueError(
                        "Invalid optimizer! ByteScheduler only supports SGD, Adam and RMSprop."
                    )
                self._zero_one_grad(p)
                # Notify update completion; the parameter is ready for forward propagation.
                if p in self._locks:
                    self._locks[p].release()
            else:
                self._event_queue.put((p, handle, ctx))
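
A sketch of how such a poller could be wired up; the queue, the daemon thread, and the (None, None, None) shutdown sentinel are inferred from the loop above, not taken from ByteScheduler's actual setup code. The mixin is meant to be combined with the class that defines _poll():

import queue
import threading

class PollerWiring(object):
    """Hypothetical wiring for the _poll() loop above."""

    def _start_poller(self):
        self._event_queue = queue.Queue()
        self._poller = threading.Thread(target=self._poll, daemon=True)
        self._poller.start()

    def _enqueue(self, p, handle, ctx):
        # called from a backward hook once the push-pull for p has been issued
        self._event_queue.put((p, handle, ctx))

    def _stop_poller(self):
        # the None sentinel makes _poll() log "poller exits." and break out of its loop
        self._event_queue.put((None, None, None))
        self._poller.join()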
Example #5
    def synchronize(self):
        #missing_p = self._requires_update - set(self._handles.keys())
        #for p in missing_p:
        #    handle, ctx = self._push_pull_grad_async(p)
        #    self._handles[p] = (handle, ctx)

        #for p, value in self._handles.items():
        #    handle, ctx = value
        #    if handle is None:
        #        handle, ctx = self._push_pull_grad_async(p)
        #        self._handles[p] = (handle, ctx)
        self.set_backward_passes_per_step(self.backward_passes_per_step)
        for d_p, (handle, ctx) in self._handles.items():
            if handle is None:
                continue
            output = synchronize(handle)
            if not self._enable_async:
                d_p.set_(self._compression.decompress(output, ctx))
        if self._tensor_fusion_threshold > 0: 
            for merged_p, (handle, ctx) in self._handles.items():
                if handle is None:
                    continue
                new_name = self._merged_parameter_names.get(merged_p)
                tensors = self._pull_from_buffer(new_name, merged_p)
                for n in tensors:
                    p = self._named_parameters.get(n)
                    p.grad.set_(tensors[n].data)
        self._handles.clear()
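
Example #5's fusion branch relies on _pull_from_buffer() to slice each original gradient back out of the merged tensor. One possible shape for that helper, matching the two-argument call above; the offset bookkeeping (self._merged_offsets and its (offset, numel, shape) entries) is purely hypothetical:

    def _pull_from_buffer(self, merged_name, merged_p):
        """Return {original_name: tensor_view} for every gradient packed into merged_p."""
        tensors = {}
        for name, (offset, numel, shape) in self._merged_offsets[merged_name].items():
            tensors[name] = merged_p.data[offset:offset + numel].view(shape)
        return tensors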
Example #6
    def synchronize(self):
        # missing_p = self._requires_update - set(self._handles.keys())
        # for p in missing_p:
        #     handle, ctx = self._push_pull_grad_async(p)
        #     self._handles[p] = (handle, ctx)

        # for p, value in self._handles.items():
        #     handle, ctx = value
        #     if handle is None:
        #         handle, ctx = self._push_pull_grad_async(p)
        #         self._handles[p] = (handle, ctx)
        # for p, (handle, _) in self._handles.items():
        #     output = synchronize(handle)
        #     self._push_pull_delay[p] = self.backward_passes_per_step
        #     if not self._enable_async:
        #         p.grad.set_(self._compression.decompress(output, ctx))
        # handle, ctx = self._push_pull_grad_async(self.whole_gradient)
        print("=======================================")
        print("Length of self.whole_gradient: %d" % len(self.whole_gradient))
        tensor_compressed, ctx = self._compression.compress(self.whole_gradient)
        handle = byteps_push_pull(tensor_compressed, average=True, name="Whole.Gradient")
        output = synchronize(handle)
        output = self._compression.decompress(output, ctx)
        # for param_group in self.param_groups:
        #     for p in param_group['params']:
        #         if p.requires_grad:
        for name, p in self._named_parameters.items():
            d_p = self._pull_from_buffer(name).view(p.data.shape)
            p.grad.set_(d_p)
        self.whole_gradient = torch.tensor([]).cuda()
        self._merged_parameter_offsets = {}
        self._merged_parameter_index = {}
        self._handles.clear()
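
These synchronize() variants assume a compression object exposing compress()/decompress(). A minimal FP16 compressor written in that style (a sketch of the interface, not BytePS's shipped implementation):

import torch

class FP16Compressor(object):
    """Send gradients as float16 on the wire and restore the original dtype afterwards."""

    @staticmethod
    def compress(tensor):
        ctx = tensor.dtype  # remember the original dtype for decompress()
        if tensor.dtype.is_floating_point:
            tensor = tensor.type(torch.float16)
        return tensor, ctx

    @staticmethod
    def decompress(tensor, ctx):
        return tensor.type(ctx) if tensor.dtype != ctx else tensor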
Example #7
    def synchronize(self):
        missing_p = self._requires_update - set(self._handles.keys())
        for p in missing_p:
            handle, ctx = self._push_pull_grad_async(p)
            self._handles[p] = (handle, ctx)

        for p, value in self._handles.items():
            handle, ctx = value
            if handle is None:
                handle, ctx = self._push_pull_grad_async(p)
                self._handles[p] = (handle, ctx)
        for p, (handle, ctx) in self._handles.items():
            output = synchronize(handle)
            self._push_pull_delay[p] = self.backward_passes_per_step
            p.grad.set_(self._compression.decompress(output, ctx))
        self._handles.clear()
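
In this Horovod-style optimizer wrapper, synchronize() is normally drained from step(): wait for the outstanding push-pulls, write the averaged gradients into p.grad, then let the wrapped optimizer update the weights. A sketch of that step(), assuming the wrapper subclasses the original optimizer class as BytePS/Horovod wrappers typically do:

    def step(self, closure=None):
        # make sure every gradient has been pushed, pulled, and decompressed
        # into p.grad before the base optimizer applies its update rule
        self.synchronize()
        return super(self.__class__, self).step(closure)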
Example #8
    def synchronize(self):
        missing_p = self._requires_update - set(self._handles.keys())
        for p in missing_p:
            handle, ctx, grad_count = self._push_pull_grad_group_sync(p, self._num_grads)
            self._handles[p] = (handle, ctx)

        for p, value in self._handles.items():
            handle, ctx = value
            if handle is None:
                handle, ctx, grad_count = self._push_pull_grad_group_sync(p)
                self._handles[p] = (handle, ctx)
        for p, (handle, ctx) in self._handles.items():
            output = synchronize(handle)
            if not self._enable_async:
                p.grad.set_(self._compression.decompress(output, ctx))
        self._handles.clear()
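
The self._push_pull_delay[p] reset seen in these synchronize() methods pairs with a gradient hook that only launches a push-pull after backward_passes_per_step backward passes have accumulated. A sketch of that counterpart in the same pattern (_push_pull_grad_async and the hook registration are assumed to exist as in the snippets above):

    def _make_hook(self, p):
        def hook(*ignore):
            assert self._push_pull_delay[p] > 0
            handle, ctx = None, None
            self._push_pull_delay[p] -= 1
            if self._push_pull_delay[p] == 0:
                # only the last of the accumulated backward passes triggers communication
                handle, ctx = self._push_pull_grad_async(p)
            self._handles[p] = (handle, ctx)
        return hook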
Example #9
    def _try_to_synchronize(self, p):
        handle, ctx = self._handles[p]
        if poll(handle):
            output = synchronize(handle)
            self._push_pull_delay[p] = self.backward_passes_per_step
            if self._is_tensor_instance:
                fp16_p = self._fp32_to_fp16_map.get(p.__hash__())
            else:
                fp16_p = self._fp32_to_fp16_map.get(p)
            fp16_p.grad.set_(self._compression.decompress(output, ctx))
            p.grad.data.copy_(fp16_p.grad.data)
            p.grad.data = p.grad.data / (self.loss_scale * size())
            self._step_one_param(p)
            fp16_p.data.copy_(p.data)
            self._handles.pop(p)
            return True
        else:
            return False
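
Example #9 assumes a precomputed map from each fp32 master parameter to its fp16 model copy. One way such a map could be built; only self._is_tensor_instance and self._fp32_to_fp16_map come from the code above, the rest is an assumption:

    def _build_fp32_to_fp16_map(self, fp16_params, fp32_params):
        self._fp32_to_fp16_map = {}
        for fp16_p, fp32_p in zip(fp16_params, fp32_params):
            # keyed by hash when parameters are raw tensors, by the parameter object otherwise
            key = fp32_p.__hash__() if self._is_tensor_instance else fp32_p
            self._fp32_to_fp16_map[key] = fp16_p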
Example #10
def main():
    # os.system('shutdown -c')  # cancel previous shutdown command
    log.console(args)
    tb.log('sizes/world', bps.size())

    # need to index validation directory before we start counting the time
    dataloader.sort_ar(args.data + '/validation')

    # if args.distributed:
    # log.console('Distributed initializing process group')
    torch.cuda.set_device(bps.local_rank())
    print(f'cuda device set to {bps.local_rank()}')
    log.console("cuda initialized (rank=%d)" % (bps.local_rank()))
    # dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=bps.size())
    log.console("Distributed: success (%d/%d)" % (bps.rank(), bps.size()))

    log.console("Loading model (rank=%d)" % (bps.rank()))
    model = resnet.resnet50(bn0=args.init_bn0).cuda()

    # reuse the validate tensor
    global validate_tensor, dist_validate_tensor
    validate_tensor = torch.tensor([0, 0, 0, 0]).float().cuda()
    dist_validate_tensor = torch.tensor([0, 0, 0, 0, 0]).float().cuda()

    if args.fp16: model = network_to_half(model)
    best_top5 = 93  # only save models over 93%; otherwise a checkpoint would be saved every epoch

    global model_params, master_params
    if args.fp16: model_params, master_params = prep_param_lists(model)
    else: model_params = master_params = model.parameters()

    optim_params, name_list = experimental_utils.bnwd_optim_params(
        model, model_params, master_params) if args.no_bn_wd else master_params

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(
        optim_params,
        0,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )  # start with 0 lr. Scheduler will change this later

    named_param = []
    for p in optim_params:
        tensors = p['params']
        for tensor in tensors:
            named_param.append(tensor)

    # create bps_param (tuple)
    bps_param = []
    for i, tensor in enumerate(named_param):
        name = name_list[i]
        bps_param.append((name, tensor))

    # wrap with byteps optimizer
    optimizer = DistributedOptimizer(
        optimizer,
        named_parameters=bps_param,
        backward_passes_per_step=args.batches_per_pushpull,
        half=True,
        model=model,
        fp16_params=model_params,
        fp32_params=master_params,
        loss_scale=args.loss_scale)

    if args.resume:
        checkpoint = torch.load(
            args.resume,
            map_location=lambda storage, loc: storage.cuda(args.local_rank))
        model.load_state_dict(checkpoint['state_dict'])
        args.start_epoch = checkpoint['epoch']
        best_top5 = checkpoint['best_top5']
        optimizer.load_state_dict(checkpoint['optimizer'])

    log.console(
        "Creating data loaders (this could take up to 10 minutes if volume needs to be warmed up)"
    )
    num_machines = (bps.size() - 1) // 8 + 1
    assert (num_machines in schedules)
    phases = schedules[num_machines]
    dm = DataManager([copy.deepcopy(p) for p in phases if 'bs' in p])
    scheduler = Scheduler(optimizer,
                          [copy.deepcopy(p) for p in phases if 'lr' in p])

    # BytePS: broadcast parameters & optimizer state.
    broadcast_parameters([(name, p.detach()) for name, p in bps_param],
                         root_rank=0)
    broadcast_optimizer_state(optimizer, root_rank=0)

    start_time = datetime.now()  # start timing only after everything is loaded
    if args.evaluate:
        return validate(dm.val_dl, model, criterion, 0, start_time)

    if args.distributed:
        log.console('Global Barrier: Syncing machines before training')
        tensor = torch.tensor([1.0]).float().cuda()
        barrier_handler = push_pull_async_inplace(tensor,
                                                  average=True,
                                                  name="init.barrier")
        while True:
            if poll(barrier_handler):
                synchronize(barrier_handler)
                break
        # do broadcast for validate tensor
        log.console('Broadcasting validate tensor')
        barrier_handler = push_pull_async_inplace(validate_tensor,
                                                  average=True,
                                                  name="validation_tensor")
        while True:
            if poll(barrier_handler):
                synchronize(barrier_handler)
                break
        barrier_handler = push_pull_async_inplace(
            dist_validate_tensor,
            average=True,
            name="distributed_validation_tensor")
        while True:
            if poll(barrier_handler):
                synchronize(barrier_handler)
                break

    log.event("~~epoch\thours\ttop1\ttop5\n")
    for epoch in range(args.start_epoch, scheduler.tot_epochs):
        dm.set_epoch(epoch)

        train(dm.trn_dl, model, criterion, optimizer, scheduler, epoch)
        top1, top5 = validate(dm.val_dl, model, criterion, epoch, start_time)

        time_diff = (datetime.now() - start_time).total_seconds() / 3600.0
        log.event(f'~~{epoch}\t{time_diff:.5f}\t\t{top1:.3f}\t\t{top5:.3f}\n')

        is_best = top5 > best_top5
        best_top5 = max(top5, best_top5)
        if args.local_rank == 0:
            if is_best:
                save_checkpoint(epoch,
                                model,
                                best_top5,
                                optimizer,
                                is_best=True,
                                filename='model_best.pth.tar')
            phase = dm.get_phase(epoch)
            if phase:
                save_checkpoint(
                    epoch,
                    model,
                    best_top5,
                    optimizer,
                    filename=f'sz{phase["bs"]}_checkpoint.path.tar')
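
The three poll loops in main() repeat the same "push-pull a small tensor and block until it lands" idiom. A reusable helper in that spirit (the name is hypothetical; push_pull_async_inplace, poll and synchronize are the same calls used above):

def global_sync(tensor, name):
    """Push-pull `tensor` across all workers and block until the result is in place."""
    handle = push_pull_async_inplace(tensor, average=True, name=name)
    while True:
        if poll(handle):
            synchronize(handle)
            break
    return tensor

# e.g. global_sync(torch.tensor([1.0]).float().cuda(), "init.barrier")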