# Example #1
# 0
def distributed_predict(input, target, model, criterion, cnt):
    """Run one validation step that tolerates uneven batches across workers.

    The test set isn't always large enough for every GPU to receive a batch,
    so each worker reports whether it actually evaluated anything; the local
    stats are then summed across all workers with ``bps.push_pull`` and the
    averages computed from the global totals.
    """
    n_samples = input.size(0)
    output = loss = corr1 = corr5 = valid_batches = 0

    if n_samples:
        # Forward pass only — no gradients needed for validation.
        with torch.no_grad():
            output = model(input)
            loss = criterion(output, target).data
        valid_batches = 1
        corr1, corr5 = correct(output.data, target, topk=(1, 5))

    # Pack the local statistics into the shared tensor, then sum (not
    # average) them across all workers.
    for slot, value in enumerate((n_samples, valid_batches, loss, corr1, corr5)):
        dist_validate_tensor[slot] = value
    batch_total, valid_batches, reduced_loss, corr1, corr5 = bps.push_pull(
        dist_validate_tensor,
        average=False,
        name="distributed_validation_tensor")

    # Mean loss over the workers that actually ran a batch.
    reduced_loss = reduced_loss / valid_batches

    top1 = corr1 * (100.0 / batch_total)
    top5 = corr5 * (100.0 / batch_total)
    return top1, top5, reduced_loss, batch_total
def metric_average(val, name):
    """Average the scalar *val* across all workers and return it as a float.

    *name* identifies the tensor to the BytePS push_pull operation.
    """
    metric = torch.tensor(val)
    if args.cuda:
        metric = metric.cuda()
    # push_pull with default settings averages the value across workers.
    return bps.push_pull(metric, name=name).item()
# Example #3
# 0
def train(trn_loader, model, criterion, optimizer, scheduler, epoch):
    """Train *model* for one epoch over *trn_loader*.

    For each batch: updates the LR via *scheduler*, runs forward/backward,
    steps *optimizer* (with loss scaling when ``args.fp16`` is set), and on
    rank 0 periodically reduces loss/accuracy across workers and logs them.
    """
    net_meter = NetworkMeter()
    timer = TimeMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()
    for i, (input, target) in enumerate(trn_loader):
        # Smoke-test mode: only run a handful of batches per epoch.
        if args.short_epoch and (i > 10): break
        batch_num = i + 1
        timer.batch_start()
        scheduler.update_lr(epoch, i + 1, len(trn_loader))

        # compute output
        output = model(input)
        loss = criterion(output, target)

        # Log on every print_freq-th batch and always on the last one.
        should_print = (batch_num % args.print_freq
                        == 0) or (batch_num == len(trn_loader))

        # compute gradient and do SGD step
        if args.fp16:
            # Loss scaling keeps small fp16 gradients from underflowing:
            # scale before backward, undo afterwards so the logged loss
            # is in the original units.
            loss = loss * args.loss_scale
            # zero_grad() and converting fp16/fp32 is handled in optimizer
            loss.backward()
            optimizer.step(wait_for_finish=should_print)
            loss = loss / args.loss_scale
        else:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Train batch done. Logging results
        timer.batch_end()

        if args.local_rank == 0 and should_print:
            corr1, corr5 = correct(output.data, target, topk=(1, 5))
            reduced_loss, batch_total = to_python_float(
                loss.data), to_python_float(input.size(0))
            if args.distributed:  # Must keep track of global batch size, since not all machines are guaranteed equal batches at the end of an epoch
                validate_tensor[0] = batch_total
                validate_tensor[1] = reduced_loss
                validate_tensor[2] = corr1
                validate_tensor[3] = corr5
                # Sum (not average) the per-worker stats across all workers.
                batch_total, reduced_loss, corr1, corr5 = bps.push_pull(
                    validate_tensor, average=False, name="validation_tensor")
                batch_total = batch_total.cpu().numpy()
                reduced_loss = reduced_loss.cpu().numpy()
                corr1 = corr1.cpu().numpy()
                corr5 = corr5.cpu().numpy()
                # Summed losses → mean loss per worker.
                reduced_loss = reduced_loss / bps.size()

            top1acc = to_python_float(corr1) * (100.0 / batch_total)
            top5acc = to_python_float(corr5) * (100.0 / batch_total)

            losses.update(reduced_loss, batch_total)
            top1.update(top1acc, batch_total)
            top5.update(top5acc, batch_total)
            tb.log_memory()
            tb.log_trn_times(timer.batch_time.val, timer.data_time.val,
                             input.size(0))
            tb.log_trn_loss(losses.val, top1.val, top5.val)

            recv_gbit, transmit_gbit = net_meter.update_bandwidth()
            tb.log("sizes/batch_total", batch_total)
            tb.log('net/recv_gbit', recv_gbit)
            tb.log('net/transmit_gbit', transmit_gbit)

            # Use a dedicated name for the log message so it does not
            # shadow the model output tensor bound to `output` above.
            msg = (
                f'Epoch: [{epoch}][{batch_num}/{len(trn_loader)}]\t'
                f'Time {timer.batch_time.val:.3f} ({timer.batch_time.avg:.3f})\t'
                f'Loss {losses.val:.4f} ({losses.avg:.4f})\t'
                f'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                f'Acc@5 {top5.val:.3f} ({top5.avg:.3f})\t'
                f'Data {timer.data_time.val:.3f} ({timer.data_time.avg:.3f})\t'
                f'BW {recv_gbit:.3f} {transmit_gbit:.3f}')
            log.verbose(msg)

            tb.update_step_count(batch_total)