Example #1
0
def validate(validloader, model, criterion, args, tasks, meter, device):
    """Run one evaluation pass over ``validloader``, updating ``meter`` in place.

    The model is switched to eval mode and no gradients are tracked.
    Batch loss/accuracy are accumulated into ``meter``; epoch statistics
    are finalized and persisted to ``args.save_path`` at the end.
    """
    model.eval()
    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(validloader):
            # Normalize the raw target into the multitask dict format.
            labels = _wrap_target(labels)

            inputs = darts.to_device(inputs, device)
            labels = darts.to_device(labels, device)

            n_samples = inputs.size(0)

            outputs = model(inputs)
            batch_loss = darts.multitask_loss(
                labels, outputs, criterion, reduce='mean')

            top1 = darts.multitask_accuracy_topk(outputs, labels, topk=(1, ))
            meter.update_batch_loss(batch_loss.item(), n_samples)
            meter.update_batch_accuracy(top1, n_samples)

            if batch_idx % args.log_interval == 0:
                logger.info(
                    f'>> Validation: {batch_idx} loss: {meter.loss_meter.avg:.4}')

    meter.update_epoch()
    meter.save(args.save_path)
Example #2
0
def infer(validloader, model, criterion, args, tasks, device, meter):
    """Evaluate ``model`` over ``validloader`` and return the average loss.

    Runs in eval mode with gradients disabled. Per-batch loss and top-k
    accuracy are accumulated into ``meter``; epoch stats are finalized and
    saved to ``args.save_path`` before returning ``meter.loss_meter.avg``.
    """
    model.eval()

    with torch.no_grad():
        for step, (data, target) in enumerate(validloader):

            data = data.to(device)
            # Move each task's labels to the device. Use the value yielded
            # by items() directly rather than re-indexing the dict (the
            # original unpacked `label` but never used it).
            for task, label in target.items():
                target[task] = label.to(device)

            batch_size = data.size(0)

            logits = model(data)
            loss = darts.multitask_loss(target, logits, criterion, reduce='mean')

            prec1 = darts.multitask_accuracy_topk(logits, target)
            meter.update_batch_loss(loss.item(), batch_size)
            meter.update_batch_accuracy(prec1, batch_size)

            if step % args.log_interval == 0:
                print(f'>> Validation: {step} loss: {meter.loss_meter.avg:.4}')

    meter.update_epoch()
    meter.save(args.save_path)

    return meter.loss_meter.avg
Example #3
0
def train(trainloader, model, architecture, criterion, optimizer, scheduler,
          args, tasks, meter, genotype, genotype_store, device):
    """Run one DARTS search epoch: alternate architecture (alpha) and weight updates.

    For each batch, the architecture parameters are updated on a held-out
    batch drawn from a second iterator over ``trainloader``, then the model
    weights are updated with gradient clipping. The genotype is checkpointed
    whenever the running 'response'-task accuracy reaches a new best.
    Epoch stats are finalized and saved to ``args.save_path``.
    """
    # Independent second iterator over the same loader, used as the
    # "validation" stream for the architecture step.
    valid_iter = iter(trainloader)
    # Tracks the best running accuracy seen so far this epoch (the original
    # name `min_accuracy` was misleading — it is a best-so-far threshold).
    best_accuracy = 0.0
    for step, (data, target) in enumerate(trainloader):

        batch_size = data.size(0)
        model.train()

        data = darts.to_device(data, device)
        target = darts.to_device(target, device)

        x_search, target_search = next(valid_iter)
        x_search = darts.to_device(x_search, device)
        target_search = darts.to_device(target_search, device)

        # NOTE(review): scheduler.get_lr() is deprecated in newer PyTorch in
        # favor of get_last_lr() — confirm against the project's torch version.
        lr = scheduler.get_lr()[0]

        # 1. update alpha (architecture parameters)
        architecture.step(data,
                          target,
                          x_search,
                          target_search,
                          lr,
                          optimizer,
                          unrolled=False)

        logits = model(data)
        loss = darts.multitask_loss(target, logits, criterion, reduce='mean')

        # 2. update weight
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()
        scheduler.step()

        prec1 = darts.multitask_accuracy_topk(logits, target, topk=(1, ))
        meter.update_batch_loss(loss.item(), batch_size)
        meter.update_batch_accuracy(prec1, batch_size)

        # Checkpoint the genotype on a new best running accuracy.
        accuracy_avg = meter.acc_meter.get_avg_accuracy('response')
        if accuracy_avg > best_accuracy:
            genotype_store.save_genotype(genotype)
            best_accuracy = accuracy_avg

        if step % args.log_interval == 0:
            logger.info(f'Step: {step} loss: {meter.loss_meter.avg:.4}')

    meter.update_epoch()
    meter.save(args.save_path)
Example #4
0
def train(trainloader, validloader, model, architecture, criterion, optimizer, lr, args, tasks, device, meter):
    """Run one DARTS search epoch (fixed-lr variant): alpha then weight updates.

    For each batch, architecture parameters are stepped on a second batch
    drawn from ``trainloader`` (optionally unrolled per ``args.unrolled``),
    then model weights are updated with gradient clipping. Batch metrics go
    into ``meter``; epoch stats are finalized and saved to ``args.save_path``.

    Note: ``validloader`` is accepted for interface compatibility but the
    architecture step draws its held-out batch from ``trainloader``.
    """
    # Independent second iterator over the training loader for the alpha step.
    valid_iter = iter(trainloader)

    for step, (data, target) in enumerate(trainloader):

        batch_size = data.size(0)
        model.train()

        data = data.to(device)

        # Move each task's labels to the device, using the items() value
        # directly (the original unpacked `label` but re-indexed the dict).
        for task, label in target.items():
            target[task] = label.to(device)

        x_search, target_search = next(valid_iter)
        x_search = x_search.to(device)

        for task, label in target_search.items():
            target_search[task] = label.to(device)

        # 1. update alpha (architecture parameters)
        architecture.step(
            data,
            target,
            x_search,
            target_search,
            lr,
            optimizer,
            unrolled=args.unrolled
        )

        logits = model(data)
        loss = darts.multitask_loss(target, logits, criterion, reduce='mean')

        # 2. update weight
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        prec1 = darts.multitask_accuracy_topk(logits, target)
        meter.update_batch_loss(loss.item(), batch_size)
        meter.update_batch_accuracy(prec1, batch_size)

        if step % args.log_interval == 0:
            print(f'Step: {step} loss: {meter.loss_meter.avg:.4}')

    meter.update_epoch()
    meter.save(args.save_path)