Example 1
    def forward(self, x):
        if not self.use_APS:
            # APS disabled: quantize activations directly to the 4-bit-exponent,
            # 3-bit-mantissa floating-point format
            return float_quantize(x, 4, 3)
        else:
            # APS enabled: shift x so its largest magnitude sits at the top of
            # the 4-bit exponent range (2**(4 - 1) - 1 = 7) before quantizing
            shift_factor = 7 - torch.log2(
                torch.abs(x).max()).ceil().detach().cpu().numpy()
            return float_quantize(x * (2**shift_factor), 4, 3)
Example 2
def train_epoch(model, batches, optimizer, lrs, stats, dist_, warm_up_iter,
                param_exp, param_man, loss_scale):
    global global_step
    model.train(True)
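    # quantize all parameters to the low-precision floating-point format
    # (param_exp exponent bits, param_man mantissa bits) before the epoch starts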
    for param in model.parameters():
        param.data.copy_(float_quantize(param.data, param_exp, param_man))
    for lr, batch in zip(lrs, batches):
        collect(stats, model(batch), dist_)
        update(model, set_opt_params(optimizer, {'lr': lr}, warm_up_iter),
               dist_, loss_scale)
        global_step += 1
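        # re-quantize parameters after the update so they stay in the
        # low-precision format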
        for param in model.parameters():
            param.data.copy_(float_quantize(param.data, param_exp, param_man))
    return stats
Example 3
    def forward(self, x):
        if not self.use_APS:
            return float_quantize(x, 4, 3)
        else:
            if self.shift_factor and not self.fix_shift_factor:
                if self.shift_factor == 7 - torch.log2(
                        torch.abs(x).max()).ceil().detach().cpu().numpy():
                    # if the shift factor equals the one from the last
                    # iteration, we assume the activation distribution is
                    # stable and keep the former shift factor instead of
                    # recomputing it every time
                    self.same_cnt += 1
                    # ~15 iterations/epoch at 4K batch size, so 30 matches
                    # covers roughly two epochs
                    if self.same_cnt == 30:
                        self.fix_shift_factor = True
                else:
                    self.same_cnt = 0

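            # recompute the shift so the largest |x| lands at the top of the
            # 4-bit exponent range (2**(4 - 1) - 1 = 7)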
            if not self.fix_shift_factor:
                self.shift_factor = 7 - torch.log2(
                    torch.abs(x).max()).ceil().detach().cpu().numpy()
            return float_quantize(x * (2**self.shift_factor), 4, 3)
Example 4
def train(epoch):
    model.train()
    train_sampler.set_epoch(epoch)

    with tqdm(total=len(train_loader),
              desc='Train Epoch     #{}'.format(epoch),
              disable=not verbose) as t:
        train_loss = Metric('train_loss')
        train_accuracy = Metric('train_accuracy')
        for batch_idx, (data, target) in enumerate(train_loader):
            curr_lr = adjust_learning_rate(epoch, batch_idx)

            if args.cuda:
                data, target = data.cuda(), target.cuda()
            optimizer.zero_grad()
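            # one gradient list per parameter; each sub-batch below contributes
            # one entry, emulating the gradient of one node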
            grad_buffer = []
            for param_g in model.parameters():
                grad_buffer.append([])
            # Split data into sub-batches of size batch_size
            for i in range(0, len(data), args.batch_size):
                data_batch = data[i:i + args.batch_size]
                target_batch = target[i:i + args.batch_size]
                output = model(data_batch)
                train_accuracy.update(accuracy(output, target_batch))
                loss = F.cross_entropy(
                    output, target_batch) / world_size
                reduced_loss = loss.data.clone()
                dist.all_reduce(reduced_loss)
                train_loss.update(float(reduced_loss.item()))
                # Average gradients among sub-batches
                loss.div_(math.ceil(float(len(data)) / args.batch_size))
                loss.backward()
                for idx, param in enumerate(model.parameters()):
                    if param.grad is not None:
                        grad_buffer[idx].append(
                            param.grad.detach().clone().data)
                model.zero_grad()
            for idx, param in enumerate(model.parameters()):
                if param.grad is not None:
                    # APS
                    # find maximum exponent
                    max_exp = -100
                    for val in grad_buffer[idx]:
                        t_exp = torch.log2(
                            torch.abs(val * args.emulate_node).max()).ceil(
                            ).detach().cpu().numpy()
                        if t_exp > max_exp:
                            max_exp = t_exp
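                    # choose a shift that maps the largest gradient exponent to
                    # the top of the grad_exp-bit exponent range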
                    upper_bound = 2**(args.grad_exp - 1) - 1
                    shift_factor = upper_bound - max_exp
                    if max_exp == -100 or not args.use_APS:
                        shift_factor = 0
                    for grad in grad_buffer[idx]:
                        grad.data.copy_(float_quantize(
                            grad * (2**shift_factor), args.grad_exp, args.grad_man))
                    # as we use a single node to emulate multi-node, we should
                    # first accumulate gradients within a single node and then
                    # communicate them in the distributed system
                    res = torch.zeros_like(grad_buffer[idx][0])
                    for val in grad_buffer[idx]:
                        res = float_quantize(
                            res + val, args.grad_exp, args.grad_man)
                    param.grad.data.copy_(res.data / (2**shift_factor))
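            # sum the quantized gradients across distributed ranks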
            sum_gradients(model, use_APS=args.use_APS,
                          grad_exp=args.grad_exp, grad_man=args.grad_man)

            # Gradient is applied across all ranks
            optimizer.step()

            t.set_postfix({'lr': curr_lr,
                           'loss': train_loss.avg,
                           'accuracy': 100. * train_accuracy.avg})
            t.update(1)
Example 5
def train(train_loader, val_loader, model, criterion, optimizer, lr_scheduler,
          start_iter, tb_logger):

    global args, rank, world_size, best_prec1, emulate_node
    global grad_exp, grad_man, param_exp, param_man

    batch_time = AverageMeter(args.print_freq)
    data_time = AverageMeter(args.print_freq)
    losses = AverageMeter(args.print_freq)

    model.train()

    end = time.time()
    curr_step = start_iter
    emulate_step = 0

    momentum_buffer = []
    for master_p in master_params:
        momentum_buffer.append(torch.zeros_like(master_p))
    grad_buffer = []
    for param_g in model.parameters():
        grad_buffer.append([])

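    # every `emulate_node` consecutive mini-batches emulate one synchronous
    # step across that many nodes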
    for i, (input, target) in enumerate(train_loader):
        emulate_step += 1
        if emulate_step == emulate_node:
            curr_step += 1
        if curr_step > args.max_iter:
            break

        current_lr = adjust_learning_rate(optimizer, curr_step)

        target = target.cuda()
        input_var = input.cuda()

        data_time.update(time.time() - end)

        output = model(input_var, rank)
        loss = criterion(output, target) / (world_size * emulate_node)
        reduced_loss = loss.data.clone()
        if args.dist:
            dist.all_reduce(reduced_loss)
        losses.update(float(reduced_loss.item()))
        model.zero_grad()
        loss.backward()
        for idx, param in enumerate(model.parameters()):
            if param.grad is not None:
                grad_buffer[idx].append(param.grad.detach().clone().data)
        model.zero_grad()

        if emulate_node == emulate_step:
            emulate_step = 0
            # reduce all gradients with low precision
            for idx, param in enumerate(model.parameters()):
                if param.grad is not None:
                    if emulate_node == 1:
                        param.grad.data.copy_(grad_buffer[idx][0])
                        continue
                    # find maximum exponent
                    max_exp = -100
                    for val in grad_buffer[idx]:
                        t_exp = torch.log2(
                            torch.abs(val * args.emulate_node).max()).ceil(
                            ).detach().cpu().numpy()
                        if t_exp > max_exp:
                            max_exp = t_exp
                    upper_bound = 2**(args.grad_exp - 1) - 1
                    shift_factor = upper_bound - max_exp
                    if max_exp == -100 or not args.use_APS:
                        shift_factor = 0
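                    # quantize each emulated node's shifted gradient to the
                    # low-precision gradient format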
                    for grad in grad_buffer[idx]:
                        grad.data.copy_(
                            float_quantize(grad * (2**shift_factor),
                                           args.grad_exp, args.grad_man))
                    # as we use a single node to emulate multi-node, we should
                    # first accumulate gradients within a single node and then
                    # communicate them in the distributed system
                    res = torch.zeros_like(grad_buffer[idx][0])
                    for val in grad_buffer[idx]:
                        res = float_quantize(res + val, args.grad_exp,
                                             args.grad_man)
                    param.grad.data.copy_(res.data / (2**shift_factor))
            grad_buffer = []
            for param_g in model.parameters():
                grad_buffer.append([])
            if args.dist:
                sum_gradients(model,
                              use_APS=args.use_APS,
                              use_kahan=args.use_kahan,
                              grad_exp=args.grad_exp,
                              grad_man=args.grad_man)
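            # accumulate the model gradients into the full-precision master
            # parameters' gradients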
            for model_p, master_p in zip(model_params, master_params):
                if model_p.grad is not None:
                    master_p.backward(model_p.grad.float())

            # update parameters
            if args.use_lars:
                for idx, master_p in enumerate(master_params):
                    if master_p.grad is not None:
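                        # LARS: layer-wise trust ratio
                        # ||w|| / (||grad|| + weight_decay * ||w||),
                        # scaled by a fixed coefficient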
                        local_lr = master_p.norm(2) /\
                            (master_p.grad.data.norm(2)
                             + args.weight_decay * master_p.norm(2))
                        lars_coefficient = 0.001
                        local_lr = local_lr * lars_coefficient
                        momentum_buffer[idx] = args.momentum * momentum_buffer[idx].data \
                            + current_lr \
                            * local_lr \
                            * (master_p.grad.data + args.weight_decay * master_p.data)
                        update = momentum_buffer[idx]
                        master_p.data.copy_(master_p - update)
            else:
                optimizer.step()
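            # copy the updated master weights back into the model parameters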
            for model_p, master_p in zip(model_params, master_params):
                model_p.data.copy_(master_p.data)

            optimizer.zero_grad()

            batch_time.update(time.time() - end)
            end = time.time()

            if (curr_step == 1
                    or curr_step % args.print_freq == 0) and rank == 0:
                if tb_logger:
                    tb_logger.add_scalar('loss_train', losses.avg, curr_step)
                    tb_logger.add_scalar('lr', current_lr, curr_step)
                print('Iter: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'LR {lr:.4f}'.format(curr_step,
                                           args.max_iter,
                                           batch_time=batch_time,
                                           data_time=data_time,
                                           loss=losses,
                                           lr=current_lr))

            if curr_step % args.val_freq == 0 and curr_step != 0:
                val_loss, prec1, prec5 = validate(val_loader, model, criterion)

                if tb_logger:
                    tb_logger.add_scalar('loss_val', val_loss, curr_step)
                    tb_logger.add_scalar('acc1_val', prec1, curr_step)
                    tb_logger.add_scalar('acc5_val', prec5, curr_step)

                if rank == 0:
                    # remember best prec@1 and save checkpoint
                    is_best = prec1 > best_prec1
                    best_prec1 = max(prec1, best_prec1)
                    save_checkpoint(
                        {
                            'step': curr_step,
                            'arch': args.arch,
                            'state_dict': model.state_dict(),
                            'best_prec1': best_prec1,
                            'optimizer': optimizer.state_dict(),
                        }, is_best, args.save_path + '/ckpt_' + str(curr_step))
    del momentum_buffer
    val_loss, prec1, prec5 = validate(val_loader, model, criterion)