Example #1
def validate(val_loader, model, criterion, epoch, start_time):
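    """Run one evaluation pass over val_loader; return (top-1, top-5) averages."""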
    timer = TimeMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    eval_start_time = time.time()

    for i, (input, target) in enumerate(val_loader):
        if args.short_epoch and (i > 10): break
        batch_num = i + 1
        timer.batch_start()
        if args.distributed:
            top1acc, top5acc, loss, batch_total = distributed_predict(
                input, target, model, criterion)
        else:
            with torch.no_grad():
                output = model(input)
                loss = criterion(output, target).data
            batch_total = input.size(0)
            top1acc, top5acc = accuracy(output.data, target, topk=(1, 5))

        # Eval batch done. Logging results
        timer.batch_end()
        losses.update(to_python_float(loss), to_python_float(batch_total))
        top1.update(to_python_float(top1acc), to_python_float(batch_total))
        top5.update(to_python_float(top5acc), to_python_float(batch_total))
        should_print = (batch_num % args.print_freq == 0
                        or batch_num == len(val_loader))
        if args.local_rank == 0 and should_print:
            output = (
                f'Test:  [{epoch}][{batch_num}/{len(val_loader)}]\t'
                f'Time {timer.batch_time.val:.3f} ({timer.batch_time.avg:.3f})\t'
                f'Loss {losses.val:.4f} ({losses.avg:.4f})\t'
                f'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                f'Acc@5 {top5.val:.3f} ({top5.avg:.3f})')
            log.verbose(output)

    tb.log_eval(top1.avg, top5.avg, time.time() - eval_start_time)
    tb.log('epoch', epoch)

    return top1.avg, top5.avg
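
Example #1 leans on helpers (TimeMeter, AverageMeter, accuracy, to_python_float) defined elsewhere in the repository. As a rough guide, here is a minimal sketch of AverageMeter and accuracy modeled on the standard PyTorch ImageNet reference; the exact implementations in this codebase may differ.

import torch

class AverageMeter:
    """Tracks the latest value and a count-weighted running average."""

    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Top-k accuracy (in percent) of output logits against target labels."""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    return [correct[:k].reshape(-1).float().sum(0) * (100.0 / batch_size)
            for k in topk]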
Example #2
def train(trn_loader, model, criterion, optimizer, scheduler, epoch):
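    """Train model for one epoch of trn_loader, logging metrics from rank 0."""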
    net_meter = NetworkMeter()
    timer = TimeMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()
    for i, (input, target) in enumerate(trn_loader):
        if args.short_epoch and (i > 10): break
        batch_num = i + 1
        timer.batch_start()
        scheduler.update_lr(epoch, i + 1, len(trn_loader))

        # compute output
        output = model(input)
        loss = criterion(output, target)

        # compute gradient and do SGD step
        if args.fp16:
            # Static loss scaling: scale the loss up before backward so that
            # small fp16 gradients do not underflow, then unscale afterwards.
            loss = loss * args.loss_scale
            model.zero_grad()
            loss.backward()
            # Copy fp16 model gradients into the fp32 master copy, unscale
            # them there, and take the optimizer step in full precision.
            model_grads_to_master_grads(model_params, master_params)
            for param in master_params:
                param.grad.data = param.grad.data / args.loss_scale
            optimizer.step()
            # Copy the updated fp32 master weights back into the fp16 model.
            master_params_to_model_params(model_params, master_params)
            loss = loss / args.loss_scale
        else:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Train batch done. Logging results
        timer.batch_end()
        corr1, corr5 = correct(output.data, target, topk=(1, 5))
        reduced_loss = to_python_float(loss.data)
        batch_total = to_python_float(input.size(0))
        # Track the global batch size: not all workers are guaranteed an
        # equal batch at the end of an epoch.
        if args.distributed:
            metrics = torch.tensor([batch_total, reduced_loss, corr1,
                                    corr5]).float().cuda()
            batch_total, reduced_loss, corr1, corr5 = dist_utils.sum_tensor(
                metrics).cpu().numpy()
            reduced_loss = reduced_loss / dist_utils.env_world_size()
        top1acc = to_python_float(corr1) * (100.0 / batch_total)
        top5acc = to_python_float(corr5) * (100.0 / batch_total)

        losses.update(reduced_loss, batch_total)
        top1.update(top1acc, batch_total)
        top5.update(top5acc, batch_total)

        should_print = (batch_num % args.print_freq == 0
                        or batch_num == len(trn_loader))
        if args.local_rank == 0 and should_print:
            tb.log_memory()
            tb.log_trn_times(timer.batch_time.val, timer.data_time.val,
                             input.size(0))
            tb.log_trn_loss(losses.val, top1.val, top5.val)

            recv_gbit, transmit_gbit = net_meter.update_bandwidth()
            tb.log("sizes/batch_total", batch_total)
            tb.log('net/recv_gbit', recv_gbit)
            tb.log('net/transmit_gbit', transmit_gbit)

            output = (
                f'Epoch: [{epoch}][{batch_num}/{len(trn_loader)}]\t'
                f'Time {timer.batch_time.val:.3f} ({timer.batch_time.avg:.3f})\t'
                f'Loss {losses.val:.4f} ({losses.avg:.4f})\t'
                f'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                f'Acc@5 {top5.val:.3f} ({top5.avg:.3f})\t'
                f'Data {timer.data_time.val:.3f} ({timer.data_time.avg:.3f})\t'
                f'BW {recv_gbit:.3f} {transmit_gbit:.3f}')
            log.verbose(output)

        tb.update_step_count(batch_total)
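
Example #2 aggregates per-GPU metrics with dist_utils.sum_tensor, whose definition is not shown here. A plausible minimal implementation on top of torch.distributed looks like the following; this is an assumption about the helper, not the repository's exact code.

import torch
import torch.distributed as dist


def sum_tensor(tensor):
    """All-reduce (sum) a tensor across every worker and return the result."""
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    return rt


def env_world_size():
    """Total number of participating workers."""
    return dist.get_world_size()

Dividing the summed loss by env_world_size() afterwards, as the example does, turns the sum into a mean, while the sample counts stay as global totals.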
Example #3
def train(trn_loader, model, criterion, optimizer, scheduler, epoch):
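    """One training epoch, BytePS variant: metrics are reduced via bps.push_pull."""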
    net_meter = NetworkMeter()
    timer = TimeMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()
    for i, (input, target) in enumerate(trn_loader):
        if args.short_epoch and (i > 10): break
        batch_num = i + 1
        timer.batch_start()
        scheduler.update_lr(epoch, i + 1, len(trn_loader))

        # compute output
        output = model(input)
        loss = criterion(output, target)

        should_print = (batch_num % args.print_freq == 0
                        or batch_num == len(trn_loader))

        # compute gradient and do SGD step
        if args.fp16:
            # Static loss scaling, as in Example #2; here zero_grad() and the
            # fp16/fp32 master-weight conversion are handled inside the custom
            # optimizer. wait_for_finish is set on logging iterations,
            # presumably so the asynchronous step completes before metrics
            # are read.
            loss = loss * args.loss_scale
            loss.backward()
            optimizer.step(wait_for_finish=should_print)
            loss = loss / args.loss_scale
        else:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Train batch done. Logging results
        timer.batch_end()

        if args.local_rank == 0 and should_print:
            corr1, corr5 = correct(output.data, target, topk=(1, 5))
            reduced_loss = to_python_float(loss.data)
            batch_total = to_python_float(input.size(0))
            # Track the global batch size: not all workers are guaranteed an
            # equal batch at the end of an epoch.
            if args.distributed:
                validate_tensor[0] = batch_total
                validate_tensor[1] = reduced_loss
                validate_tensor[2] = corr1
                validate_tensor[3] = corr5
                batch_total, reduced_loss, corr1, corr5 = bps.push_pull(
                    validate_tensor, average=False, name="validation_tensor")
                batch_total = batch_total.cpu().numpy()
                reduced_loss = reduced_loss.cpu().numpy()
                corr1 = corr1.cpu().numpy()
                corr5 = corr5.cpu().numpy()
                reduced_loss = reduced_loss / bps.size()

            top1acc = to_python_float(corr1) * (100.0 / batch_total)
            top5acc = to_python_float(corr5) * (100.0 / batch_total)

            losses.update(reduced_loss, batch_total)
            top1.update(top1acc, batch_total)
            top5.update(top5acc, batch_total)
            tb.log_memory()
            tb.log_trn_times(timer.batch_time.val, timer.data_time.val,
                             input.size(0))
            tb.log_trn_loss(losses.val, top1.val, top5.val)

            recv_gbit, transmit_gbit = net_meter.update_bandwidth()
            tb.log("sizes/batch_total", batch_total)
            tb.log('net/recv_gbit', recv_gbit)
            tb.log('net/transmit_gbit', transmit_gbit)

            output = (
                f'Epoch: [{epoch}][{batch_num}/{len(trn_loader)}]\t'
                f'Time {timer.batch_time.val:.3f} ({timer.batch_time.avg:.3f})\t'
                f'Loss {losses.val:.4f} ({losses.avg:.4f})\t'
                f'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                f'Acc@5 {top5.val:.3f} ({top5.avg:.3f})\t'
                f'Data {timer.data_time.val:.3f} ({timer.data_time.avg:.3f})\t'
                f'BW {recv_gbit:.3f} {transmit_gbit:.3f}')
            log.verbose(output)

            tb.update_step_count(batch_total)
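
Example #3 swaps torch.distributed for BytePS (conventionally imported as import byteps.torch as bps) and reuses a pre-allocated CUDA buffer, validate_tensor, for the four metrics. A minimal sketch of the setup this example assumes:

import torch
import byteps.torch as bps

bps.init()                               # start the BytePS communicator
torch.cuda.set_device(bps.local_rank())  # one GPU per worker process

# Reusable buffer for [batch_total, loss, correct@1, correct@5].
validate_tensor = torch.zeros(4).float().cuda()

# With average=False, push_pull returns the element-wise sum across all
# workers; a stable name lets BytePS match the tensor between iterations.
summed = bps.push_pull(validate_tensor, average=False, name="validation_tensor")

As in Example #2, dividing the summed loss by the worker count (here bps.size()) converts the sum back into a mean.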