예제 #1
0
def train(logbook, net, device, loss_fn, opt, train_l):
    """Train ``net`` for one pass over ``train_l`` and log the epoch results.

    For every batch the inputs and labels are moved to ``device``, the loss
    is computed, the logbook meter and the progress bar are refreshed, and
    then one optimizer step is taken followed by ``step_postOptimStep()`` on
    each controller obtained from ``indiv.Controller.getControllers(net)``.

    Args:
        logbook: Experiment log; its ``meter``, ``track_metric``, ``i_epoch``
            and ``writer`` attributes are used here.
        net: Model called as ``net(inputs)``.
        device: Target passed to ``.to(...)`` for each batch.
        loss_fn: Callable ``loss_fn(outputs, labels)`` returning a loss that
            supports ``.item()`` and ``.backward()``.
        opt: Optimizer exposing ``zero_grad()``, ``step()`` and
            ``param_groups``.
        train_l: Sized iterable yielding ``(inputs, labels)`` pairs.

    Returns:
        Dict with keys ``'train_loss'`` and ``'train_metric'`` holding the
        meter's ``avg_loss`` and ``avg_metric`` after the epoch.
    """
    suffix_template = 'Total: {total:} | ETA: {eta:} | Epoch: {epoch:4d} | ({batch:5d}/{num_batches:5d})'

    logbook.meter.reset()
    num_batches = len(train_l)
    bar = FillingSquaresBar('Training \t', max=num_batches)
    controllers = indiv.Controller.getControllers(net)

    for batch_no, (inputs, gt_labels) in enumerate(train_l, start=1):
        # move the batch onto the target device
        inputs = inputs.to(device)
        gt_labels = gt_labels.to(device)

        # forward pass
        pr_outs = net(inputs)
        loss = loss_fn(pr_outs, gt_labels)

        # refresh running statistics, then redraw the progress bar
        logbook.meter.update(pr_outs, gt_labels, loss.item(), track_metric=logbook.track_metric)
        header = suffix_template.format(
            total=bar.elapsed_td,
            eta=bar.eta_td,
            epoch=logbook.i_epoch,
            batch=batch_no,
            num_batches=num_batches,
        )
        bar.suffix = header + logbook.meter.bar()
        bar.next()

        # backward pass and parameter update
        opt.zero_grad()
        loss.backward()
        opt.step()
        # controllers get their hook only after the optimizer has stepped
        for controller in controllers:
            controller.step_postOptimStep()

    bar.finish()

    stats = {
        'train_loss':   logbook.meter.avg_loss,
        'train_metric': logbook.meter.avg_metric,
    }
    # falsy entries (e.g. None or zero) are not written to the summary writer
    for tag, value in stats.items():
        if not value:
            continue
        logbook.writer.add_scalar(tag, value, global_step=logbook.i_epoch)
    # only the first parameter group's learning rate is recorded
    current_lr = opt.param_groups[0]['lr']
    logbook.writer.add_scalar('learning_rate', current_lr, global_step=logbook.i_epoch)
    return stats
예제 #2
0
def test(logbook, net, device, loss_fn, test_l, valid=False, prefix=None):
    """Run one evaluation epoch (validation or test) with gradients disabled.

    Each batch is moved to ``device`` and passed through
    ``net.forward_with_tensor_stats``; the loss is computed and the logbook
    meter and the progress bar are updated. The metric is always tracked
    here (``track_metric=True``).

    Args:
        logbook: Experiment log; its ``meter`` and ``i_epoch`` are used, and
            its ``writer`` as well when ``valid`` is true.
        net: Model providing ``forward_with_tensor_stats(inputs)``, which
            returns a ``(tensor_stats, outputs)`` pair.
        device: Target passed to ``.to(...)`` for each batch.
        loss_fn: Callable ``loss_fn(outputs, labels)`` returning a loss that
            supports ``.item()``.
        test_l: Sized iterable yielding ``(inputs, labels)`` pairs.
        valid: If true, the progress bar is titled "Validation" and the
            resulting statistics and tensor histograms are written to
            ``logbook.writer``; otherwise nothing is written.
        prefix: Prefix for the keys of the returned dict. Defaults to
            ``'valid'`` or ``'test'`` depending on ``valid``.

    Returns:
        Dict with keys ``prefix + '_loss'`` and ``prefix + '_metric'``
        holding the meter's ``avg_loss`` and ``avg_metric``.
    """
    logbook.meter.reset()
    bar_title = 'Validation \t' if valid else 'Test \t'
    bar       = FillingSquaresBar(bar_title, max=len(test_l))
    # Default to "no stats" so that an empty loader does not leave the name
    # unbound when the histogram loop below runs.
    tensor_stats = ()
    with torch.no_grad():
        for i_batch, data in enumerate(test_l):

            # load data onto device
            inputs, gt_labels     = data
            inputs                = inputs.to(device)
            gt_labels             = gt_labels.to(device)

            # forprop
            tensor_stats, pr_outs = net.forward_with_tensor_stats(inputs)
            loss                  = loss_fn(pr_outs, gt_labels)

            # update statistics
            logbook.meter.update(pr_outs, gt_labels, loss.item(), track_metric=True)
            bar.suffix = 'Total: {total:} | ETA: {eta:} | Epoch: {epoch:4d} | ({batch:5d}/{num_batches:5d})'.format(
                total=bar.elapsed_td,
                eta=bar.eta_td,
                epoch=logbook.i_epoch,
                batch=i_batch + 1,
                num_batches=len(test_l))
            bar.suffix = bar.suffix + logbook.meter.bar()
            bar.next()
    bar.finish()

    if prefix is None:
        prefix = 'valid' if valid else 'test'
    stats = {
        prefix+'_loss':   logbook.meter.avg_loss,
        prefix+'_metric': logbook.meter.avg_metric
    }
    if valid:
        # falsy entries (e.g. None or zero) are not written to the summary writer
        for k, v in stats.items():
            if v:
                logbook.writer.add_scalar(k, v, global_step=logbook.i_epoch)
        # tensor_stats is rebound on every batch, so only the statistics of
        # the last processed batch are logged as histograms
        for name, tensor in tensor_stats:
            logbook.writer.add_histogram(name, tensor, global_step=logbook.i_epoch)
    return stats