Example #1
def train_network(args, model, reglog, optimizer, loader):
    """
    Train the linear classifier (reglog) on features from the frozen model.
    """
    # running statistics
    batch_time = AverageMeter()
    data_time = AverageMeter()

    # training statistics
    log_top1 = AverageMeter()
    log_loss = AverageMeter()
    end = time.perf_counter()

    # Pascal VOC is a multi-label dataset: use an element-wise BCE loss so that
    # labels set to 255 can be masked out below
    if 'pascal' in args.data_path:
        criterion = nn.BCEWithLogitsLoss(reduction='none')
    else:
        criterion = nn.CrossEntropyLoss().cuda()

    for iter_epoch, (inp, target) in enumerate(loader):
        # measure data loading time
        data_time.update(time.perf_counter() - end)

        # decay the learning rate based on the global iteration index
        learning_rate_decay(optimizer, len(loader) * args.epoch + iter_epoch, args.lr)

        # start at iter start_iter
        if iter_epoch < args.start_iter:
            continue

        # move to gpu
        inp = inp.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)
        if 'pascal' in args.data_path:
            target = target.float()

        # forward: the backbone runs without gradients, only reglog is trained
        with torch.no_grad():
            output = model(inp)
        output = reglog(output)

        # compute the classification loss
        loss = criterion(output, target)

        if 'pascal' in args.data_path:
            # zero-out the loss for labels set to 255 before reducing
            mask = (target == 255)
            loss = torch.sum(loss.masked_fill_(mask, 0)) / target.size(0)

        optimizer.zero_grad()

        # compute the gradients
        loss.backward()

        # step
        optimizer.step()

        # log

        # signal received, relaunch experiment
        if os.environ['SIGNAL_RECEIVED'] == 'True':
            # only the main process (rank 0) writes the checkpoint
            if not args.rank:
                torch.save({
                    'epoch': args.epoch,
                    'start_iter': iter_epoch + 1,
                    'state_dict': reglog.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, os.path.join(args.dump_path, 'checkpoint.pth.tar'))
                trigger_job_requeue(os.path.join(args.dump_path, 'checkpoint.pth.tar'))

        # update stats
        log_loss.update(loss.item(), output.size(0))
        if 'pascal' not in args.data_path:
            prec1 = accuracy(args, output, target)
            log_top1.update(prec1.item(), output.size(0))

        batch_time.update(time.perf_counter() - end)
        end = time.perf_counter()

        # verbose
        if iter_epoch % 100 == 0:
            logger.info('Epoch[{0}] - Iter: [{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec {log_top1.val:.3f} ({log_top1.avg:.3f})\t'
                  .format(args.epoch, iter_epoch, len(loader), batch_time=batch_time,
                   data_time=data_time, loss=log_loss, log_top1=log_top1))

    # end of epoch
    args.start_iter = 0
    args.epoch += 1

    # dump checkpoint (main process only)
    if not args.rank:
        torch.save({
            'epoch': args.epoch,
            'start_iter': 0,
            'state_dict': reglog.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, os.path.join(args.dump_path, 'checkpoint.pth.tar'))

    return (args.epoch - 1, args.epoch * len(loader), log_top1.avg, log_loss.avg)
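
The checkpoint written above stores the epoch, the iteration to resume from, the reglog weights and the optimizer state. A minimal sketch of restoring it before calling train_network again; the args fields and file name follow the code above, while the helper name itself is hypothetical:

import os
import torch

def maybe_resume(args, reglog, optimizer):
    # restore classifier and optimizer state from a previous run, if any
    ckpt_path = os.path.join(args.dump_path, 'checkpoint.pth.tar')
    if os.path.isfile(ckpt_path):
        ckpt = torch.load(ckpt_path, map_location='cuda')
        reglog.load_state_dict(ckpt['state_dict'])
        optimizer.load_state_dict(ckpt['optimizer'])
        # train_network skips the first start_iter iterations of the epoch
        args.epoch = ckpt['epoch']
        args.start_iter = ckpt['start_iter']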
Example #2
def train_network(args, models, optimizers, dataset):
    """
    Train the models with cluster assignments as targets
    """
    # switch to train mode
    for model in models:
        model.train()

    # uniform sampling over pseudo labels
    sampler = DistUnifTargSampler(
        args.epoch_size,
        dataset.sub_classes,
        args.training_local_world_size,
        args.training_local_rank,
        seed=args.epoch + args.training_local_world_id,
    )

    loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True,
    )

    # running statistics
    batch_time = AverageMeter()
    data_time = AverageMeter()

    # training statistics
    log_top1_subclass = AverageMeter()
    log_loss_subclass = AverageMeter()
    log_top1_superclass = AverageMeter()
    log_loss_superclass = AverageMeter()

    log_top1 = AverageMeter()
    log_loss = AverageMeter()
    end = time.perf_counter()

    cel = nn.CrossEntropyLoss().cuda()
    relu = torch.nn.ReLU().cuda()

    for iter_epoch, (inp, target) in enumerate(loader):
        # start at iter start_iter
        if iter_epoch < args.start_iter:
            continue

        # measure data loading time
        data_time.update(time.perf_counter() - end)

        # move input to gpu
        inp = inp.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True).long()

        # forward on the backbone shared by both prediction heads
        inp = relu(models[0](inp))

        # forward on sub-class prediction layer
        output = models[-1](inp)
        loss_subclass = cel(output, target)

        # forward on super-class prediction layer
        super_class_output = models[1](inp)
        # every sample on this process shares the same super-class target
        sc_target = torch.full((args.batch_size,), args.training_local_world_id,
                               dtype=torch.long, device=inp.device)
        loss_superclass = cel(super_class_output, sc_target)

        loss = loss_subclass + loss_superclass

        # reset the gradients of all optimizers
        for optimizer in optimizers:
            optimizer.zero_grad()

        # compute the gradients
        loss.backward()

        # step
        for optimizer in optimizers:
            optimizer.step()

        # log

        # signal received, relaunch experiment
        if os.environ['SIGNAL_RECEIVED'] == 'True':
            save_checkpoint(args, iter_epoch + 1, models, optimizers)
            if not args.rank:
                trigger_job_requeue(
                    os.path.join(args.dump_path, 'checkpoint.pth.tar'))

        # regular checkpoints
        if iter_epoch and iter_epoch % 1000 == 0:
            save_checkpoint(args, iter_epoch + 1, models, optimizers)

        # update stats
        log_loss.update(loss.item(), output.size(0))
        prec1 = accuracy(args, output, target, sc_output=super_class_output)
        log_top1.update(prec1.item(), output.size(0))

        log_loss_superclass.update(loss_superclass.item(), output.size(0))
        prec1 = accuracy(args, super_class_output, sc_target)
        log_top1_superclass.update(prec1.item(), output.size(0))

        log_loss_subclass.update(loss_subclass.item(), output.size(0))
        prec1 = accuracy(args, output, target)
        log_top1_subclass.update(prec1.item(), output.size(0))

        batch_time.update(time.perf_counter() - end)
        end = time.perf_counter()

        # verbose
        if iter_epoch % 100 == 0:
            logger.info(
                'Epoch[{0}] - Iter: [{1}/{2}]\t'
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                'Prec {log_top1.val:.3f} ({log_top1.avg:.3f})\t'
                'Super-class loss: {sc_loss.val:.3f} ({sc_loss.avg:.3f})\t'
                'Super-class prec: {sc_prec.val:.3f} ({sc_prec.avg:.3f})\t'
                'Intra super-class loss: {los.val:.3f} ({los.avg:.3f})\t'
                'Intra super-class prec: {prec.val:.3f} ({prec.avg:.3f})\t'.
                format(args.epoch,
                       iter_epoch,
                       len(loader),
                       batch_time=batch_time,
                       data_time=data_time,
                       loss=log_loss,
                       log_top1=log_top1,
                       sc_loss=log_loss_superclass,
                       sc_prec=log_top1_superclass,
                       los=log_loss_subclass,
                       prec=log_top1_subclass))

    # end of epoch
    args.start_iter = 0
    args.epoch += 1

    # dump checkpoint
    save_checkpoint(args, 0, models, optimizers)
    if not args.rank:
        if not (args.epoch - 1) % args.checkpoint_freq:
            shutil.copyfile(
                os.path.join(args.dump_path, 'checkpoint.pth.tar'),
                os.path.join(args.dump_checkpoints,
                             'checkpoint' + str(args.epoch - 1) + '.pth.tar'),
            )

    return (
        args.epoch - 1,
        args.epoch * len(loader),
        log_top1.avg,
        log_loss.avg,
        log_top1_superclass.avg,
        log_loss_superclass.avg,
        log_top1_subclass.avg,
        log_loss_subclass.avg,
    )
Example #3
def train_network(args, model, optimizer, dataset):
    """
    Train the model on the dataset.
    """
    # switch to train mode
    model.train()

    sampler = torch.utils.data.distributed.DistributedSampler(dataset)

    loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True,
    )

    # running statistics
    batch_time = AverageMeter()
    data_time = AverageMeter()

    # training statistics
    log_top1 = AverageMeter()
    log_loss = AverageMeter()
    end = time.perf_counter()

    cel = nn.CrossEntropyLoss().cuda()

    for iter_epoch, (inp, target) in enumerate(loader):
        # measure data loading time
        data_time.update(time.perf_counter() - end)

        # start at iter start_iter
        if iter_epoch < args.start_iter:
            continue

        # move to gpu
        inp = inp.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # forward
        output = model(inp)

        # compute cross entropy loss
        loss = cel(output, target)

        optimizer.zero_grad()

        # compute the gradients
        loss.backward()

        # step
        optimizer.step()

        # log

        # signal received, relaunch experiment
        if os.environ['SIGNAL_RECEIVED'] == 'True':
            if not args.rank:
                torch.save(
                    {
                        'epoch': args.epoch,
                        'start_iter': iter_epoch + 1,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                    }, os.path.join(args.dump_path, 'checkpoint.pth.tar'))
                trigger_job_requeue(
                    os.path.join(args.dump_path, 'checkpoint.pth.tar'))

        # update stats
        log_loss.update(loss.item(), output.size(0))
        prec1 = accuracy(args, output, target)
        log_top1.update(prec1.item(), output.size(0))

        batch_time.update(time.perf_counter() - end)
        end = time.perf_counter()

        # verbose
        if iter_epoch % 100 == 0:
            logger.info(
                'Epoch[{0}] - Iter: [{1}/{2}]\t'
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                'Prec {log_top1.val:.3f} ({log_top1.avg:.3f})\t'.format(
                    args.epoch,
                    iter_epoch,
                    len(loader),
                    batch_time=batch_time,
                    data_time=data_time,
                    loss=log_loss,
                    log_top1=log_top1))

    # end of epoch
    args.start_iter = 0
    args.epoch += 1

    # dump checkpoint
    if not args.rank:
        torch.save(
            {
                'epoch': args.epoch,
                'start_iter': 0,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, os.path.join(args.dump_path, 'checkpoint.pth.tar'))
        if not (args.epoch - 1) % args.checkpoint_freq:
            shutil.copyfile(
                os.path.join(args.dump_path, 'checkpoint.pth.tar'),
                os.path.join(args.dump_checkpoints,
                             'checkpoint' + str(args.epoch - 1) + '.pth.tar'),
            )

    return (args.epoch - 1, args.epoch * len(loader), log_top1.avg,
            log_loss.avg)
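
Each call to train_network above processes a single epoch, increments args.epoch and resets args.start_iter, so it is meant to be driven by an outer loop. A minimal sketch of such a loop around Example #3, where args.epochs is an assumed field giving the total number of epochs:

# hypothetical driver loop; args.epochs is assumed to hold the total epoch count
while args.epoch < args.epochs:
    epoch, n_iters, top1, loss = train_network(args, model, optimizer, dataset)
    logger.info('Epoch {} done - {} iterations, top-1 {:.3f}, loss {:.4f}'.format(
        epoch, n_iters, top1, loss))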