예제 #1
0
def test(epoch, test_loader, save=True):
    """Evaluate the global ``net`` on ``test_loader`` and optionally checkpoint it.

    One pass is made under ``torch.no_grad()``. Loss, top-1 and top-5
    accuracy are accumulated weighted by batch size, and per-batch wall time
    is recorded. When ``save`` is true, the following happens:

    * the averages are written to ``writer`` at step ``epoch``;
    * the global ``best_acc`` is raised if top-1 improved;
    * a checkpoint is written with ``save_checkpoint`` into ``log_dir``.

    Args:
        epoch: epoch index used as the scalar step and stored in the checkpoint.
        test_loader: iterable of ``(inputs, targets)`` batches supporting ``len``.
        save: when false, only evaluate and show progress.
    """
    global best_acc
    net.eval()

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    tick = time.time()

    with torch.no_grad():
        for idx, (images, labels) in enumerate(test_loader):
            if use_cuda:
                images, labels = images.cuda(), labels.cuda()
            logits = net(images)
            batch_loss = criterion(logits, labels)

            # Accuracy and loss, weighted by the number of samples in the batch.
            acc1, acc5 = accuracy(logits.data, labels.data, topk=(1, 5))
            count = images.size(0)
            losses.update(batch_loss.item(), count)
            top1.update(acc1.item(), count)
            top5.update(acc5.item(), count)

            batch_time.update(time.time() - tick)
            tick = time.time()

            summary = 'Loss: {:.3f} | Acc1: {:.3f}% | Acc5: {:.3f}%'.format(
                losses.avg, top1.avg, top5.avg)
            progress_bar(idx, len(test_loader), summary)

    if not save:
        return

    scalars = (
        ('loss/test', losses),
        ('acc/test_top1', top1),
        ('acc/test_top5', top5),
    )
    for tag, meter in scalars:
        writer.add_scalar(tag, meter.avg, epoch)

    is_best = False
    if top1.avg > best_acc:
        best_acc, is_best = top1.avg, True

    print('Current best acc: {}'.format(best_acc))

    # Unwrap DataParallel so the saved keys carry no "module." prefix.
    bare_net = net.module if isinstance(net, nn.DataParallel) else net
    state = {
        'epoch': epoch,
        'model': args.model,
        'dataset': args.dataset,
        'state_dict': bare_net.state_dict(),
        'acc': top1.avg,
        'optimizer': optimizer.state_dict(),
    }
    save_checkpoint(state, is_best, checkpoint_dir=log_dir)
예제 #2
0
    def infer(self, model, epoch=0):
        """Evaluate ``model`` on ``self.val_data`` and log top-1/top-5 accuracy.

        Batches are pulled from ``data_prefetcher(self.val_data)`` until it
        returns ``None``. The loop runs under ``torch.no_grad()`` so that no
        autograd graph is recorded during evaluation.

        Args:
            model: called as ``model(inputs)`` and expected to return a pair;
                only the first element (the logits) is scored.
            epoch: epoch index used in the log messages.

        Returns:
            Tuple ``(top1_avg, top5_avg, batch_time_avg, data_time_avg)``.
        """
        import torch  # for torch.no_grad()

        top1 = utils.AverageMeter()
        top5 = utils.AverageMeter()
        data_time = utils.AverageMeter()
        batch_time = utils.AverageMeter()
        model.eval()

        start = time.time()
        prefetcher = data_prefetcher(self.val_data)
        inputs, target = prefetcher.next()
        step = 0
        # Without no_grad every forward pass would keep an unused autograd
        # graph alive, costing memory and time during validation.
        with torch.no_grad():
            while inputs is not None:
                step += 1
                # ``start`` is reset right before prefetcher.next(), so this
                # is the time spent obtaining the current batch.
                data_t = time.time() - start
                n = inputs.size(0)

                logits, _ = model(inputs)
                prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))

                # Measured from the same timestamp, so it includes data_t.
                batch_t = time.time() - start
                top1.update(prec1.item(), n)
                top5.update(prec5.item(), n)
                data_time.update(data_t)
                batch_time.update(batch_t)

                if step % self.report_freq == 0:
                    logging.info('Val epoch %03d step %03d | top1_acc %.2f  top5_acc %.2f | batch_time %.3f  data_time %.3f', epoch, step, top1.avg, top5.avg, batch_time.avg, data_time.avg)
                start = time.time()
                inputs, target = prefetcher.next()

        logging.info('EPOCH%d Valid_acc  top1 %.2f top5 %.2f batch_time %.3f data_time %.3f', epoch, top1.avg, top5.avg, batch_time.avg, data_time.avg)
        return top1.avg, top5.avg, batch_time.avg, data_time.avg
예제 #3
0
def validate(val_queue, model, criterion):
    """Score ``model`` on ``val_queue`` with Gumbel sampling; return mean top-1.

    The model is deliberately left in train mode (the original note says this
    disables moving average), while the forward pass itself runs without
    gradient tracking. After every forward pass
    ``model.module.reset_switches()`` is called to reset the log_alpha
    switches.

    Args:
        val_queue: iterable of ``(inputs, targets)`` batches.
        model: wrapped search model called with ``sampling=True, mode='gumbel'``
            and returning a pair whose first element is the logits.
        criterion: loss function applied to ``(logits, targets)``.

    Returns:
        Average top-1 accuracy over all batches.
    """
    loss_meter, acc1_meter, acc5_meter = AverageMeter(), AverageMeter(), AverageMeter()

    # Intentionally not model.eval(): train mode is kept to disable moving
    # average (presumably so normalization uses batch statistics).
    model.train()

    for idx, (images, labels) in enumerate(val_queue):
        images = images.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)
        with torch.no_grad():
            out, _ = model(images, sampling=True, mode='gumbel')
            batch_loss = criterion(out, labels)
        # Reset switches of log_alphas after each sampled forward pass.
        model.module.reset_switches()

        acc1, acc5 = accuracy(out, labels, topk=(1, 5))
        count = images.size(0)
        for meter, value in ((loss_meter, batch_loss), (acc1_meter, acc1), (acc5_meter, acc5)):
            meter.update(value.item(), count)

        if idx % args.print_freq == 0:
            logging.info('VALIDATE Step: %04d Objs: %f R1: %f R5: %f', idx,
                         loss_meter.avg, acc1_meter.avg, acc5_meter.avg)

    return acc1_meter.avg
예제 #4
0
def validate(logger, writer, device, config, valid_loader, model, epoch, cur_step):
    """Evaluate ``model`` on ``valid_loader`` and report loss / top-1 / top-5.

    Progress is logged every ``config.print_freq`` steps and on the last step.
    Epoch averages are written to ``writer`` at ``cur_step`` under the ``val/``
    prefix.

    Args:
        logger: logger used for progress lines.
        writer: object with ``add_scalar(tag, value, step)``.
        device: target device for inputs and labels.
        config: needs ``print_freq`` and ``epochs``.
        valid_loader: iterable of ``(X, y)`` batches supporting ``len``.
        model: wrapped model exposing ``model.module.criterion``.
        epoch: zero-based epoch index (displayed as ``epoch + 1``).
        cur_step: global step for the scalar summaries.

    Returns:
        Average top-1 accuracy over the loader.
    """
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()
    losses = utils.AverageMeter()

    model.eval()
    last_step = len(valid_loader) - 1

    with torch.no_grad():
        for step, (X, y) in enumerate(valid_loader):
            X = X.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            batch_size = X.size(0)

            logits = model(X)
            loss = model.module.criterion(logits, y)

            prec1, prec5 = utils.accuracy(logits, y, topk=(1, 5))
            for meter, value in ((losses, loss), (top1, prec1), (top5, prec5)):
                meter.update(value.item(), batch_size)

            if step % config.print_freq == 0 or step == last_step:
                logger.info(
                    "Valid: [{:2d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                    "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                        epoch + 1, config.epochs, step, last_step,
                        losses=losses, top1=top1, top5=top5))

    for tag, meter in (('val/loss', losses), ('val/top1', top1), ('val/top5', top5)):
        writer.add_scalar(tag, meter.avg, cur_step)

    logger.info("Valid: [{:2d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, config.epochs, top1.avg))

    return top1.avg
예제 #5
0
def validate(val_queue, model, criterion):
    """Evaluate ``model`` on ``val_queue``.

    Args:
        val_queue: iterable of batches where index 0 is the input and index 1
            the target.
        model: network returning logits.
        criterion: loss function applied to ``(logits, target)``.

    Returns:
        ``(top1_avg, top5_avg, loss_avg)``. A progress line is logged every
        ``args.print_freq`` steps, including the whole seconds elapsed since
        the previous progress line (0 for the first one).
    """
    objs, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
    model.eval()

    for step, batch in enumerate(val_queue):
        inputs = batch[0].cuda(non_blocking=True)
        labels = batch[1].cuda(non_blocking=True)

        with torch.no_grad():
            logits = model(inputs)
            loss = criterion(logits, labels)

        prec1, prec5 = accuracy(logits, labels, topk=(1, 5))
        batch_size = inputs.size(0)
        for meter, value in ((objs, loss), (top1, prec1), (top5, prec5)):
            meter.update(value.data.item(), batch_size)

        if step % args.print_freq == 0:
            if step == 0:
                duration = 0
            else:
                duration = time.time() - last_report
            last_report = time.time()
            logging.info(
                'VALID Step: %03d Objs: %e R1: %f R5: %f Duration: %ds', step,
                objs.avg, top1.avg, top5.avg, duration)

    return top1.avg, top5.avg, objs.avg
예제 #6
0
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model: ImageClassifier,
          domain_adv: DomainAdversarialLoss, optimizer: SGD,
          lr_scheduler: StepwiseLR, epoch: int, args: argparse.Namespace):
    """Run one epoch (``args.iters_per_epoch`` iterations) of adversarial adaptation.

    Each iteration does the following:

    * draws one labeled source batch and one target batch (target labels are
      discarded);
    * runs both through ``model`` in a single concatenated forward pass;
    * minimizes ``cls_loss + transfer_loss * args.trade_off``, where
      ``transfer_loss`` is produced by ``domain_adv`` from the source and
      target features.

    Loss, source classification accuracy and ``domain_adv``'s discriminator
    accuracy are tracked and printed every ``args.print_freq`` iterations.

    Args:
        train_source_iter: iterator yielding ``(images, labels)`` source batches.
        train_target_iter: iterator yielding ``(images, _)`` target batches.
        model: returns a pair ``(y, f)``; ``y`` feeds cross-entropy and ``f``
            feeds ``domain_adv``.
        domain_adv: adversarial loss module, put in train mode with ``model``.
        optimizer: stepped once per iteration.
        lr_scheduler: stepped at the start of every iteration.
        epoch: epoch index, used only in the progress prefix.
        args: needs ``iters_per_epoch``, ``trade_off`` and ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses = AverageMeter('Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    domain_accs = AverageMeter('Domain Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs, domain_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    domain_adv.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        # Per-iteration LR schedule, stepped before optimizer.step().
        # NOTE(review): StepwiseLR is project-defined; torch's built-in
        # schedulers warn about this order, so confirm it is intended here.
        lr_scheduler.step()

        # measure data loading time
        data_time.update(time.time() - end)

        x_s, labels_s = next(train_source_iter)
        x_t, _ = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)

        # compute output: one forward pass over source and target together
        x = torch.cat((x_s, x_t), dim=0)
        y, f = model(x)
        # chunk(2) splits the concatenated batch in half, which assumes the
        # source and target batches have the same size.
        y_s, y_t = y.chunk(2, dim=0)
        f_s, f_t = f.chunk(2, dim=0)

        cls_loss = F.cross_entropy(y_s, labels_s)
        transfer_loss = domain_adv(f_s, f_t)
        # Read from domain_adv after the call above has run on this batch.
        domain_acc = domain_adv.domain_discriminator_accuracy
        loss = cls_loss + transfer_loss * args.trade_off

        cls_acc = accuracy(y_s, labels_s)[0]

        # All meters are weighted by the source batch size.
        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        domain_accs.update(domain_acc.item(), x_s.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
예제 #7
0
def infer(valid_queue, model, criterion):
    """Evaluate ``model`` on ``valid_queue``.

    The loop runs under ``torch.no_grad()``. Loss (via ``criterion``), top-1
    and top-5 accuracy are accumulated weighted by batch size. A progress line
    is logged every ``args.report_freq`` steps (the first step is skipped).

    Args:
        valid_queue: iterable of ``(input, target)`` batches.
        model: network returning logits.
        criterion: loss function applied to ``(logits, target)``.

    Returns:
        ``(top1_avg, top5_avg, loss_avg, batch_time_avg)``.
    """
    import torch  # for torch.no_grad()

    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()
    data_time = utils.AvgrageMeter()
    batch_time = utils.AvgrageMeter()
    model.eval()

    start = time.time()
    # No autograd graph is needed for evaluation.
    with torch.no_grad():
        for step, (inputs, target) in enumerate(valid_queue):
            data_t = time.time() - start

            inputs = inputs.cuda()
            target = target.cuda()
            n = inputs.size(0)

            logits = model(inputs)
            # Compute the loss so the returned loss average is meaningful.
            loss = criterion(logits, target)

            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))

            # Measured from the same timestamp as data_t, so it includes it.
            batch_t = time.time() - start

            objs.update(loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)
            data_time.update(data_t)
            batch_time.update(batch_t)

            if step != 0 and step % args.report_freq == 0:
                logging.info(
                    'Val step %03d | top1_acc %.2f  top5_acc %.2f | batch_time %.3f  data_time %.3f',
                    step, top1.avg, top5.avg, batch_time.avg, data_time.avg)

            start = time.time()

    return top1.avg, top5.avg, objs.avg, batch_time.avg
예제 #8
0
    def train(self, model, epoch, optim_obj='Weights', search_stage=0):
        """Run one search epoch that updates either the weights or the architecture.

        ``optim_obj`` selects the phase:

        * ``'Weights'``:
          - iterates over ``self.train_data``;
          - steps ``self.scheduler`` every iteration;
          - calls ``self.search_optim.weight_step``.
        * ``'Arch'``:
          - iterates over ``self.val_data``;
          - calls ``self.search_optim.arch_step``.

        Both step functions return ``(logits, loss, sub_obj)``. ``sub_obj`` is
        averaged (unweighted) and logged under the label ``self.sub_obj_type``.

        Args:
            model: search network forwarded to the step function.
            epoch: epoch index used in log messages.
            optim_obj: ``'Weights'`` or ``'Arch'``. This is checked with
                ``assert``, which is skipped under ``python -O``.
            search_stage: forwarded unchanged to the step function.

        Returns:
            ``(top1_avg, top5_avg, loss_avg, sub_obj_avg, batch_time_avg)``.
        """
        assert optim_obj in ['Weights', 'Arch']
        objs = utils.AverageMeter()
        top1 = utils.AverageMeter()
        top5 = utils.AverageMeter()
        sub_obj_avg = utils.AverageMeter()
        data_time = utils.AverageMeter()
        batch_time = utils.AverageMeter()
        model.train()

        start = time.time()
        # Weight updates use the training split; architecture updates use the
        # validation split.
        if optim_obj == 'Weights':
            prefetcher = data_prefetcher(self.train_data)
        elif optim_obj == 'Arch':
            prefetcher = data_prefetcher(self.val_data)

        input, target = prefetcher.next()
        step = 0
        while input is not None:
            # NOTE(review): if the prefetcher already yields GPU tensors, these
            # .cuda() calls are no-ops — confirm against data_prefetcher.
            input, target = input.cuda(), target.cuda()
            data_t = time.time() - start
            n = input.size(0)
            if optim_obj == 'Weights':
                # The LR schedule only advances during the weight phase, once
                # per iteration.
                self.scheduler.step()
                if step == 0:
                    logging.info(
                        'epoch %d weight_lr %e', epoch,
                        self.search_optim.weight_optimizer.param_groups[0]
                        ['lr'])
                logits, loss, sub_obj = self.search_optim.weight_step(
                    input, target, model, search_stage)
            elif optim_obj == 'Arch':
                if step == 0:
                    logging.info(
                        'epoch %d arch_lr %e', epoch,
                        self.search_optim.arch_optimizer.param_groups[0]['lr'])
                logits, loss, sub_obj = self.search_optim.arch_step(
                    input, target, model, search_stage)

            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
            # Drop references so these tensors can be freed before the next
            # batch is fetched.
            del logits, input, target

            batch_t = time.time() - start
            # ``loss`` is passed as returned by the step function (no .item());
            # presumably already a Python number — TODO confirm.
            objs.update(loss, n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)
            # Unweighted: no sample count is passed.
            sub_obj_avg.update(sub_obj)
            data_time.update(data_t)
            batch_time.update(batch_t)

            # Skip logging on the very first step.
            if step != 0 and step % self.args.report_freq == 0:
                logging.info(
                    'Train%s epoch %03d step %03d | loss %.4f %s %.2f top1_acc %.2f top5_acc %.2f | batch_time %.3f data_time %.3f',
                    optim_obj, epoch, step, objs.avg, self.sub_obj_type,
                    sub_obj_avg.avg, top1.avg, top5.avg, batch_time.avg,
                    data_time.avg)
            start = time.time()
            step += 1
            input, target = prefetcher.next()
        return top1.avg, top5.avg, objs.avg, sub_obj_avg.avg, batch_time.avg
예제 #9
0
def validate(val_loader, model, criterion):
    """Evaluate ``model`` on ``val_loader``, averaging metrics across ranks.

    When ``args.distributed`` is set, the loss and both accuracies of every
    batch are passed through ``reduce_tensor`` before being accumulated. Only
    rank 0 (``args.local_rank == 0``) logs progress, every ``args.print_freq``
    steps.

    Args:
        val_loader: iterable of batches where index 0 is the input and index 1
            the target.
        model: network returning logits.
        criterion: loss function applied to ``(logits, target)``.

    Returns:
        ``(top1_avg, top5_avg, loss_avg)``.
    """
    objs = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    model.eval()

    for step, batch in enumerate(val_loader):
        inputs = batch[0].cuda(non_blocking=True)
        labels = batch[1].cuda(non_blocking=True)

        with torch.no_grad():
            logits = model(inputs)
            loss = criterion(logits, labels)

        prec1, prec5 = accuracy(logits, labels, topk=(1, 5))
        batch_loss = loss.data
        if args.distributed:
            batch_loss = reduce_tensor(batch_loss)
            prec1 = reduce_tensor(prec1)
            prec5 = reduce_tensor(prec5)

        batch_size = inputs.size(0)
        for meter, value in ((objs, batch_loss), (top1, prec1), (top5, prec5)):
            meter.update(value.item(), batch_size)

        if args.local_rank == 0 and step % args.print_freq == 0:
            if step == 0:
                duration = 0
            else:
                duration = time.time() - last_report
            last_report = time.time()
            logging.info(
                'VALIDATE Step: %03d Objs: %e R1: %f R5: %f Duration: %ds',
                step, objs.avg, top1.avg, top5.avg, duration)

    return top1.avg, top5.avg, objs.avg
예제 #10
0
def train(train_loader, model, criterion, optimizer):
    """Train ``model`` for one epoch with optional apex AMP and gradient clipping.

    Metrics are only computed (and, when ``args.distributed``, all-reduced)
    on steps where ``step % args.print_freq == 0``. The returned averages
    therefore cover those sampled steps only, not every batch.

    Args:
        train_loader: iterable of batches where index 0 is the input and
            index 1 the target.
        model: network returning logits.
        criterion: loss function applied to ``(logits, target)``.
        optimizer: optimizer, wrapped by ``amp`` when ``args.opt_level`` is set.

    Returns:
        ``(top1_avg, loss_avg)``.
    """
    objs = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    model.train()

    end = time.time()
    for step, data in enumerate(train_loader):
        # Time since the end of the previous iteration: waiting on the loader.
        data_time.update(time.time() - end)
        x = data[0].cuda(non_blocking=True)
        target = data[1].cuda(non_blocking=True)

        # forward
        batch_start = time.time()
        logits = model(x)
        loss = criterion(logits, target)

        # backward
        optimizer.zero_grad()
        if args.opt_level is not None:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        if args.grad_clip > 0:
            # NOTE(review): amp.master_params is used even when
            # args.opt_level is None, so `amp` must be importable in that
            # configuration too.
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           args.grad_clip)
        optimizer.step()
        # NOTE(review): there is no torch.cuda.synchronize() before this, so
        # with asynchronous CUDA execution batch_time may under-report GPU time.
        batch_time.update(time.time() - batch_start)

        if step % args.print_freq == 0:
            # For better performance, don't accumulate these metrics every iteration,
            # since they may incur an allreduce and some host<->device syncs.
            prec1, prec5 = accuracy(logits, target, topk=(1, 5))
            if args.distributed:
                reduced_loss = reduce_tensor(loss.data)
                prec1 = reduce_tensor(prec1)
                prec5 = reduce_tensor(prec5)
            else:
                reduced_loss = loss.data
            objs.update(reduced_loss.item(), x.size(0))
            top1.update(prec1.item(), x.size(0))
            top5.update(prec5.item(), x.size(0))
            torch.cuda.synchronize()

            # Seconds between consecutive print steps (0 on the first one).
            duration = 0 if step == 0 else time.time() - duration_start
            duration_start = time.time()
            if args.local_rank == 0:
                logging.info(
                    'TRAIN Step: %03d Objs: %e R1: %f R5: %f Duration: %ds BTime: %.3fs DTime: %.4fs',
                    step, objs.avg, top1.avg, top5.avg, duration,
                    batch_time.avg, data_time.avg)
        end = time.time()

    return top1.avg, objs.avg
예제 #11
0
def train(logger, writer, device, config, train_loader, model, optimizer,
          criterion, epoch):
    """Train ``model`` for one epoch with an optional auxiliary-head loss.

    The training steps are:

    * The loss is ``criterion(logits, y)``, plus
      ``config.aux_weight * criterion(aux_logits, y)`` when ``aux_weight > 0``.
    * Gradients are always clipped to ``config.grad_clip``.
    * When ``config.dist_privacy`` is set, ``clipping_dispatcher`` is also
      applied before ``optimizer.step()``.

    Per-step loss/top-1/top-5 go to ``writer`` under ``train/``. Running
    averages are logged every ``config.print_freq`` steps and on the last step.

    Args:
        logger: logger for progress lines.
        writer: object with ``add_scalar(tag, value, step)``.
        device: target device for inputs and labels.
        config: needs ``aux_weight``, ``grad_clip``, ``dist_privacy``,
            ``print_freq``, ``epochs`` and, with privacy,
            ``max_weights_grad_norm`` and ``var_gamma``.
        train_loader: iterable of ``(X, y)`` batches supporting ``len``.
        model: returns ``(logits, aux_logits)``.
        optimizer: optimizer whose first param group's LR is logged.
        criterion: loss function.
        epoch: zero-based epoch index (displayed as ``epoch + 1`` in step logs).
    """
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()
    losses = utils.AverageMeter()

    num_steps = len(train_loader)
    global_step = epoch * num_steps
    lr = optimizer.param_groups[0]['lr']
    logger.info("Epoch {} LR {}".format(epoch, lr))
    writer.add_scalar('train/lr', lr, global_step)

    model.train()

    for step, (X, y) in enumerate(train_loader):
        X = X.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        batch_size = X.size(0)

        optimizer.zero_grad()
        logits, aux_logits = model(X)
        loss = criterion(logits, y)
        if config.aux_weight > 0.:
            loss += config.aux_weight * criterion(aux_logits, y)
        loss.backward()

        # Standard norm clipping, then the optional privacy-oriented clipping.
        nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
        if config.dist_privacy:
            clipping_dispatcher(model.module.named_weights(),
                                config.max_weights_grad_norm, config.var_gamma,
                                device, logger)
        optimizer.step()

        prec1, prec5 = utils.accuracy(logits, y, topk=(1, 5))
        for meter, value in ((losses, loss), (top1, prec1), (top5, prec5)):
            meter.update(value.item(), batch_size)

        if step % config.print_freq == 0 or step == num_steps - 1:
            logger.info(
                "Train: [{:3d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                    epoch + 1, config.epochs, step, num_steps - 1,
                    losses=losses, top1=top1, top5=top5))

        for tag, value in (('train/loss', loss), ('train/top1', prec1), ('train/top5', prec5)):
            writer.add_scalar(tag, value.item(), global_step)
        global_step += 1

    logger.info("Train: [{:3d}/{}] Final Prec@1 {:.4%}".format(
        epoch + 1, config.epochs, top1.avg))
예제 #12
0
    def train(self, model, epoch):
        """Train ``model`` for one epoch over ``self.train_data``.

        Batches come from ``data_prefetcher(self.train_data)`` until it returns
        ``None``.

        * ``self.scheduler`` is stepped at the start of every iteration.
        * When ``self.config.optim.label_smooth`` is set, ``self.criterion`` is
          called with ``self.config.optim.smooth_alpha`` as a third argument.
        * Gradients are clipped when ``self.config.optim.use_grad_clip`` is set.

        Args:
            model: network called as ``model(input)`` returning logits.
            epoch: epoch index used in log messages.

        Returns:
            ``(top1_avg, top5_avg, loss_avg, batch_time_avg, data_time_avg)``.
        """
        objs = utils.AverageMeter()
        top1 = utils.AverageMeter()
        top5 = utils.AverageMeter()
        data_time = utils.AverageMeter()
        batch_time = utils.AverageMeter()
        model.train()
        start = time.time()

        prefetcher = data_prefetcher(self.train_data)
        input, target = prefetcher.next()
        step = 0
        while input is not None:
            # Time since ``start``. For the first batch this includes building
            # the prefetcher. Afterwards it includes the previous iteration's
            # meter updates/logging plus prefetcher.next().
            data_t = time.time() - start
            # Per-iteration LR schedule, stepped before optimizer.step().
            self.scheduler.step()
            n = input.size(0)
            if step == 0:
                logging.info('epoch %d lr %e', epoch,
                             self.optimizer.param_groups[0]['lr'])
            self.optimizer.zero_grad()

            logits = model(input)
            if self.config.optim.label_smooth:
                loss = self.criterion(logits, target,
                                      self.config.optim.smooth_alpha)
            else:
                loss = self.criterion(logits, target)

            loss.backward()
            if self.config.optim.use_grad_clip:
                nn.utils.clip_grad_norm_(model.parameters(),
                                         self.config.optim.grad_clip)
            self.optimizer.step()

            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))

            # Measured from the same ``start`` as data_t, so it includes it.
            batch_t = time.time() - start
            start = time.time()

            objs.update(loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)
            data_time.update(data_t)
            batch_time.update(batch_t)
            # Skip logging on the very first step.
            if step != 0 and step % self.report_freq == 0:
                logging.info(
                    'Train epoch %03d step %03d | loss %.4f  top1_acc %.2f  top5_acc %.2f | batch_time %.3f  data_time %.3f',
                    epoch, step, objs.avg, top1.avg, top5.avg, batch_time.avg,
                    data_time.avg)
            input, target = prefetcher.next()
            step += 1
        logging.info(
            'EPOCH%d Train_acc  top1 %.2f top5 %.2f batch_time %.3f data_time %.3f',
            epoch, top1.avg, top5.avg, batch_time.avg, data_time.avg)

        return top1.avg, top5.avg, objs.avg, batch_time.avg, data_time.avg
예제 #13
0
def train(logger, writer, device, config, train_loader, valid_loader, model, architect, w_optim, alpha_optim, lr, epoch):
    """Run one epoch of bi-level search: an architecture (alpha) step, then a weight step.

    Train and validation batches are consumed in lockstep via ``zip``, so the
    epoch ends with the shorter of the two loaders. For each pair of batches:

    1. ``architect.unrolled_backward`` receives both batches plus ``lr`` and
       ``w_optim`` (presumably it fills the alpha gradients), then
       ``alpha_optim`` steps.
    2. The weights take a step on the train batch. Gradients of
       ``model.module.weights()`` are clipped to ``config.w_grad_clip``, and an
       extra ``clipping_dispatcher`` pass runs when ``config.dist_privacy`` is
       set.

    Per-step loss/top-1/top-5 are written to ``writer`` under ``train/``.
    Running averages are logged every ``config.print_freq`` steps.

    Args:
        logger: logger for progress lines (also passed to clipping_dispatcher).
        writer: object with ``add_scalar(tag, value, step)``.
        device: target device for all batches.
        config: needs ``w_grad_clip``, ``dist_privacy``, ``print_freq`` and
            ``epochs``; with privacy also ``max_weights_grad_norm`` and
            ``var_gamma``.
        train_loader: training batches ``(X, y)``; its length sets the step base.
        valid_loader: validation batches ``(X, y)`` used by the alpha step.
        model: wrapped search model exposing ``module.criterion``,
            ``module.weights()`` and ``module.named_weights()``.
        architect: object providing ``unrolled_backward``.
        w_optim: weight optimizer.
        alpha_optim: architecture-parameter optimizer.
        lr: current weight learning rate (logged and passed to the architect).
        epoch: zero-based epoch index (displayed as ``epoch + 1``).
    """
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()
    losses = utils.AverageMeter()

    cur_step = epoch*len(train_loader)
    writer.add_scalar('train/lr', lr, cur_step)

    model.train()

    for step, ((trn_X, trn_y), (val_X, val_y)) in enumerate(zip(train_loader, valid_loader)):
        trn_X, trn_y = trn_X.to(device, non_blocking=True), trn_y.to(device, non_blocking=True)
        val_X, val_y = val_X.to(device, non_blocking=True), val_y.to(device, non_blocking=True)
        N = trn_X.size(0)

        # phase 2. architect step (alpha)
        alpha_optim.zero_grad()
        architect.unrolled_backward(config, trn_X, trn_y, val_X, val_y, lr, w_optim)
        alpha_optim.step()

        # phase 1. child network step (w)
        w_optim.zero_grad()
        logits = model(trn_X)
        loss = model.module.criterion(logits, trn_y)
        loss.backward()
        # gradient clipping
        nn.utils.clip_grad_norm_(model.module.weights(), config.w_grad_clip)
        if config.dist_privacy:
            # privacy gradient clipping
            clipping_dispatcher(model.module.named_weights(),
                                config.max_weights_grad_norm,
                                config.var_gamma,
                                device,
                                logger
                                )
        w_optim.step()

        prec1, prec5 = utils.accuracy(logits, trn_y, topk=(1, 5))
        losses.update(loss.item(), N)
        top1.update(prec1.item(), N)
        top5.update(prec5.item(), N)

        # NOTE(review): the "last step" check uses len(train_loader); if
        # valid_loader is shorter, zip stops early and this never fires.
        if step % config.print_freq == 0 or step == len(train_loader)-1:
            logger.info(
                "Train: [{:2d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                    epoch+1, config.epochs, step, len(train_loader)-1, losses=losses,
                    top1=top1, top5=top5))

        # Per-step (not averaged) values.
        writer.add_scalar('train/loss', loss.item(), cur_step)
        writer.add_scalar('train/top1', prec1.item(), cur_step)
        writer.add_scalar('train/top5', prec5.item(), cur_step)
        cur_step += 1

    logger.info("Train: [{:2d}/{}] Final Prec@1 {:.4%}".format(epoch+1, config.epochs, top1.avg))
def validate(val_loader: DataLoader, G: nn.Module, F1: ImageClassifierHead,
             F2: ImageClassifierHead,
             args: argparse.Namespace) -> Tuple[float, float]:
    """Evaluate a shared feature extractor with two classifier heads.

    Features ``G(images)`` are fed to both ``F1`` and ``F2``. Each head's top-1
    accuracy is tracked separately, and progress is displayed every
    ``args.print_freq`` batches.

    Args:
        val_loader: loader of ``(images, target)`` batches.
        G: feature extractor.
        F1: first classifier head.
        F2: second classifier head.
        args: needs ``print_freq``.

    Returns:
        ``(acc_head1_avg, acc_head2_avg)``.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    top1_1 = AverageMeter('Acc_1', ':6.2f')
    top1_2 = AverageMeter('Acc_2', ':6.2f')
    progress = ProgressMeter(len(val_loader), [batch_time, top1_1, top1_2],
                             prefix='Test: ')

    for module in (G, F1, F2):
        module.eval()

    with torch.no_grad():
        tick = time.time()
        for i, (images, target) in enumerate(val_loader):
            images = images.to(device)
            target = target.to(device)

            features = G(images)
            out1, out2 = F1(features), F2(features)

            [head1_acc] = accuracy(out1, target)
            [head2_acc] = accuracy(out2, target)
            count = images.size(0)
            top1_1.update(head1_acc[0], count)
            top1_2.update(head2_acc[0], count)

            batch_time.update(time.time() - tick)
            tick = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        print(' * Acc1 {top1_1.avg:.3f} Acc2 {top1_2.avg:.3f}'.format(
            top1_1=top1_1, top1_2=top1_2))

    return top1_1.avg, top1_2.avg
예제 #15
0
    def forward_loss(self, out, label, **kwargs):
        """Compute ``self.criterion(out, label)`` and log loss and accuracy.

        Both values are reported to ``self.logger`` weighted by the batch size
        ``out.size(0)``.

        Args:
            out: model output. Per the original note it should be reshaped to
                ``(n*c)`` (presumably ``(n, c)``) and ``label`` to ``(n)``.
            label: target labels.
            **kwargs: accepted for interface compatibility and ignored.

        Returns:
            The loss tensor, still attached to the graph for ``backward``.
        """
        loss = self.criterion(out, label)
        acc = accuracy(out, label)
        batch_size = out.size(0)  # used to weight the running averages

        # ``.item()`` works for both 0-dim and 1-element tensors. Indexing with
        # ``.data[0]`` raises on 0-dim tensors (e.g. a reduced loss) in
        # PyTorch >= 0.4.
        self.logger.update('Loss', loss.item(), batch_size)
        self.logger.update('acc', acc[0].item(), batch_size)
        return loss
def train(train_source_iter: ForeverDataIterator, model: Classifier,
          optimizer: SGD, lr_sheduler: StepwiseLR, epoch: int,
          args: argparse.Namespace):
    """Run ``args.iters_per_epoch`` supervised steps on source-domain batches.

    Args:
        train_source_iter: iterator yielding ``(images, labels)`` batches.
        model: returns a pair whose first element is fed to cross-entropy.
        optimizer: stepped once per iteration.
        lr_sheduler: optional; when not ``None`` it is stepped at the start of
            every iteration. The parameter name is kept as-is for callers.
        epoch: epoch index, used only in the progress prefix.
        args: needs ``iters_per_epoch`` and ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    meters = [batch_time, data_time, losses, cls_accs]
    progress = ProgressMeter(args.iters_per_epoch, meters,
                             prefix="Epoch: [{}]".format(epoch))

    model.train()

    tick = time.time()
    for it in range(args.iters_per_epoch):
        if lr_sheduler is not None:
            lr_sheduler.step()
        data_time.update(time.time() - tick)

        images, labels = next(train_source_iter)
        images = images.to(device)
        labels = labels.to(device)

        logits, _ = model(images)
        loss = F.cross_entropy(logits, labels)
        acc = accuracy(logits, labels)[0]

        batch_size = images.size(0)
        losses.update(loss.item(), batch_size)
        cls_accs.update(acc.item(), batch_size)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - tick)
        tick = time.time()

        if it % args.print_freq == 0:
            progress.display(it)
예제 #17
0
def validate(val_queue, model):
    """Return ``(top1_avg, top5_avg)`` of ``model`` over ``val_queue``.

    Iteration is wrapped in a ``tqdm`` progress bar. Each batch is indexed
    with ``[0]`` for the input and ``[1]`` for the target.
    """
    top1, top5 = AverageMeter(), AverageMeter()
    model.eval()

    for batch in tqdm.tqdm(val_queue):
        inputs = batch[0].cuda(non_blocking=True)
        labels = batch[1].cuda(non_blocking=True)

        with torch.no_grad():
            logits = model(inputs)

        acc1, acc5 = accuracy(logits, labels, topk=(1, 5))
        count = inputs.size(0)
        top1.update(acc1.data.item(), count)
        top5.update(acc5.data.item(), count)

    return top1.avg, top5.avg
예제 #18
0
def train(train_queue, model, criterion, optimizer):
    """Train ``model`` for one epoch over ``train_queue``.

    Gradients are clipped when ``args.grad_clip > 0``. Every
    ``args.print_freq`` steps a progress line is logged with running averages,
    the whole seconds since the previous progress line, and the mean
    compute/data times.

    Returns:
        ``(top1_avg, loss_avg)``.
    """
    objs = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    model.train()

    end = time.time()
    for step, batch in enumerate(train_queue):
        data_time.update(time.time() - end)
        inputs = batch[0].cuda(non_blocking=True)
        labels = batch[1].cuda(non_blocking=True)

        compute_start = time.time()
        logits = model(inputs)
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        if args.grad_clip > 0:
            nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()
        batch_time.update(time.time() - compute_start)

        prec1, prec5 = accuracy(logits, labels, topk=(1, 5))
        n = inputs.size(0)
        for meter, value in ((objs, loss), (top1, prec1), (top5, prec5)):
            meter.update(value.data.item(), n)

        if step % args.print_freq == 0:
            if step == 0:
                duration = 0
            else:
                duration = time.time() - last_report
            last_report = time.time()
            logging.info(
                'TRAIN Step: %03d Objs: %e R1: %f R5: %f Duration: %ds BTime: %.3fs DTime: %.4fs',
                step, objs.avg, top1.avg, top5.avg, duration, batch_time.avg,
                data_time.avg)
        end = time.time()

    return top1.avg, objs.avg
예제 #19
0
def train_ssl(inferred_dataloader: DataLoader, model: ImageClassifier,
              optimizer: SGD, lr_scheduler: StepwiseLR, epoch: int,
              args: argparse.Namespace):
    """Train ``model`` for one pass over ``inferred_dataloader`` with cross-entropy.

    The labels come from the loader as-is. Given the name, they are presumably
    inferred pseudo-labels; confirm against the caller.

    Args:
        inferred_dataloader: loader of ``(x, labels)`` batches.
        model: returns a pair whose first element is the logits.
        optimizer: stepped once per batch.
        lr_scheduler: stepped at the start of every batch.
        epoch: epoch index, used only in the progress prefix.
        args: needs ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(inferred_dataloader),
                             [batch_time, losses, top1, top5],
                             prefix="Epoch: [{}]".format(epoch))

    model.train()

    tick = time.time()
    for i, (inputs, labels) in enumerate(inferred_dataloader):
        lr_scheduler.step()
        inputs, labels = inputs.to(device), labels.to(device)

        logits, _ = model(inputs)
        loss = F.cross_entropy(logits, labels)

        acc1, acc5 = accuracy(logits, labels, topk=(1, 5))
        bs = inputs.size(0)
        losses.update(loss.item(), bs)
        top1.update(acc1[0], bs)
        top5.update(acc5[0], bs)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - tick)
        tick = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
예제 #20
0
def validate(val_loader: DataLoader, model: ImageClassifier, args: argparse.Namespace) -> float:
    """Evaluate ``model`` on ``val_loader`` and return the average top-1 accuracy.

    Cross-entropy loss, top-1 and top-5 accuracy are tracked, and progress is
    displayed every ``args.print_freq`` batches. A summary line is printed at
    the end.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    meters = [batch_time, losses, top1, top5]
    progress = ProgressMeter(len(val_loader), meters, prefix='Test: ')

    model.eval()

    with torch.no_grad():
        tick = time.time()
        for i, (images, target) in enumerate(val_loader):
            images, target = images.to(device), target.to(device)

            logits, _ = model(images)
            loss = F.cross_entropy(logits, target)

            acc1, acc5 = accuracy(logits, target, topk=(1, 5))
            count = images.size(0)
            losses.update(loss.item(), count)
            top1.update(acc1[0], count)
            top5.update(acc5[0], count)

            batch_time.update(time.time() - tick)
            tick = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))

    return top1.avg
예제 #21
0
def train(data_loader, model, criterion, optimizer, epoch, stage, logger,
          args):
    """Train ``model`` for one epoch over ``data_loader``.

    Each batch is moved to the GPU, scored with ``criterion`` and used for a
    single ``optimizer`` step.  The running loss and top-1/top-5 accuracy are
    logged every 100 steps, on the last step, and once more after the loop.

    Args:
        data_loader: iterable of ``(images, labels)`` batches supporting ``len``.
        model: network called as ``model(images)`` returning logits.
        criterion: loss function ``criterion(logits, labels)``.
        optimizer: optimizer stepped once per batch.
        epoch: current epoch index (used only for logging).
        stage: stage label (used only for logging).
        logger: object exposing ``log(str)``.
        args: namespace providing ``baseline_epochs`` (used only for logging).
    """
    loss_avg = utils.AverageMeter()
    top1_res = utils.AverageMeter()
    top5_res = utils.AverageMeter()
    num_steps = len(data_loader)
    # Read once up front: the final log line needs it, and reading it inside
    # the loop left it unbound (UnboundLocalError) for an empty data_loader.
    epochs = args.baseline_epochs
    log_fmt = ("Train, Epoch: [{:3d}/{}], Step: [{:3d}/{}], "
               "Loss: {:.4f}, Prec@(res1, res5):  {:.4%}, {:.4%}")
    model.train()

    logger.log("stage: {}".format(stage))

    step = 0  # index of the last step; stays 0 if the loader is empty
    for step, (images, labels) in enumerate(data_loader):
        images = images.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)
        num_samples = images.size(0)
        optimizer.zero_grad()

        logits = model(images)
        loss = criterion(logits, labels)
        prec1_res, prec5_res = utils.accuracy(logits.detach(),
                                              labels,
                                              topk=(1, 5))
        top1_res.update(prec1_res.item(), num_samples)
        top5_res.update(prec5_res.item(), num_samples)
        loss_avg.update(loss.item(), num_samples)

        loss.backward()
        optimizer.step()

        if step % 100 == 0 or step == num_steps - 1:
            logger.log(log_fmt.format(epoch, epochs, step, num_steps,
                                      loss_avg.avg, top1_res.avg,
                                      top5_res.avg))

    logger.log(log_fmt.format(epoch, epochs, step, num_steps,
                              loss_avg.avg, top1_res.avg, top5_res.avg))
예제 #22
0
    def infer(self, model, epoch):
        """Score ``model`` on ``self.val_data`` during architecture search.

        Batches are drawn from ``data_prefetcher(self.val_data)`` until it
        returns ``None``.  Each batch goes through
        ``self.search_optim.valid_step``, which returns
        ``(logits, loss, sub_obj)``.  Running averages are logged every
        ``self.args.report_freq`` steps.

        Returns:
            tuple: ``(top1_avg, top5_avg, loss_avg, sub_obj_avg,
            batch_time_avg)``.
        """
        loss_meter = utils.AverageMeter()
        acc1_meter = utils.AverageMeter()
        acc5_meter = utils.AverageMeter()
        sub_obj_meter = utils.AverageMeter()
        fetch_meter = utils.AverageMeter()
        step_meter = utils.AverageMeter()

        # deliberately train mode: don't use running_mean and running_var
        # during search
        model.train()
        tic = time.time()
        loader = data_prefetcher(self.val_data)
        batch, labels = loader.next()
        step = 0
        while batch is not None:
            step += 1
            # time since ``tic``, which was taken just before the fetch
            fetch_elapsed = time.time() - tic
            batch_size = batch.size(0)

            logits, loss, sub_obj = self.search_optim.valid_step(
                batch, labels, model)
            prec1, prec5 = utils.accuracy(logits, labels, topk=(1, 5))

            step_elapsed = time.time() - tic
            loss_meter.update(loss, batch_size)
            acc1_meter.update(prec1.item(), batch_size)
            acc5_meter.update(prec5.item(), batch_size)
            sub_obj_meter.update(sub_obj)
            fetch_meter.update(fetch_elapsed)
            step_meter.update(step_elapsed)

            if step % self.args.report_freq == 0:
                logging.info(
                    'Val epoch %03d step %03d | loss %.4f %s %.2f top1_acc %.2f top5_acc %.2f | batch_time %.3f data_time %.3f',
                    epoch, step, loss_meter.avg, self.sub_obj_type,
                    sub_obj_meter.avg, acc1_meter.avg, acc5_meter.avg,
                    step_meter.avg, fetch_meter.avg)
            tic = time.time()
            batch, labels = loader.next()

        return (acc1_meter.avg, acc5_meter.avg, loss_meter.avg,
                sub_obj_meter.avg, step_meter.avg)
예제 #23
0
def train_wo_arch(train_queue, model, criterion, optimizer_w):
    """Train only the supernet weights for one pass over ``train_queue``.

    Weight parameters are unfrozen and architecture parameters frozen once,
    before the loop.  Each batch runs one Gumbel-sampled forward pass, resets
    the log-alpha switches, and takes an ``optimizer_w`` step.  Gradient norm
    clipping is applied when ``args.grad_clip > 0``.

    Returns:
        Sample-weighted average top-1 accuracy over the pass.
    """
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    model.train()

    net = model.module
    for weight in net.weight_parameters():
        weight.requires_grad = True
    for alpha in net.arch_parameters():
        alpha.requires_grad = False

    for step, (inputs, targets) in enumerate(train_queue):
        inputs = inputs.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        logits, _ = model(inputs, sampling=True, mode='gumbel')
        loss = criterion(logits, targets)
        # reset switches of log_alphas
        net.reset_switches()

        optimizer_w.zero_grad()
        loss.backward()
        if args.grad_clip > 0:
            nn.utils.clip_grad_norm_(net.weight_parameters(), args.grad_clip)
        optimizer_w.step()

        batch_size = inputs.size(0)
        prec1, prec5 = accuracy(logits, targets, topk=(1, 5))
        loss_meter.update(loss.item(), batch_size)
        acc1_meter.update(prec1.item(), batch_size)
        acc5_meter.update(prec5.item(), batch_size)

        if step % args.print_freq == 0:
            logging.info('TRAIN wo_Arch Step: %04d Objs: %f R1: %f R5: %f',
                         step, loss_meter.avg, acc1_meter.avg, acc5_meter.avg)

    return acc1_meter.avg
예제 #24
0
def train(epoch, train_loader):
    """Train the module-level ``net`` for one epoch over ``train_loader``.

    Relies on module-level ``net``, ``optimizer``, ``criterion``,
    ``use_cuda``, ``writer``, ``AverageMeter``, ``accuracy`` and
    ``progress_bar``.  After every batch a progress bar shows the running
    loss and top-1/top-5 accuracy.  The epoch averages are then written to
    ``writer`` under ``loss/train``, ``acc/train_top1`` and
    ``acc/train_top5`` at step ``epoch``.
    """
    print('\nEpoch: %d' % epoch)
    net.train()

    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    for batch_idx, (inputs, targets) in enumerate(train_loader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()

        # Measure accuracy and record loss.  ``detach()`` replaces the legacy
        # ``.data`` attribute.  The targets are class labels and are passed
        # as-is; presumably they carry no autograd graph.
        batch_size = inputs.size(0)
        prec1, prec5 = accuracy(outputs.detach(), targets, topk=(1, 5))
        losses.update(loss.item(), batch_size)
        top1.update(prec1.item(), batch_size)
        top5.update(prec5.item(), batch_size)

        progress_bar(
            batch_idx, len(train_loader),
            'Loss: {:.3f} | Acc1: {:.3f}% | Acc5: {:.3f}%'.format(
                losses.avg, top1.avg, top5.avg))

    writer.add_scalar('loss/train', losses.avg, epoch)
    writer.add_scalar('acc/train_top1', top1.avg, epoch)
    writer.add_scalar('acc/train_top5', top5.avg, epoch)
예제 #25
0
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model: ImageClassifier,
          jmmd_loss: JointMultipleKernelMaximumMeanDiscrepancy, optimizer: SGD,
          lr_sheduler: StepwiseLR, epoch: int, args: argparse.Namespace):
    """Run ``args.iters_per_epoch`` JMMD domain-adaptation steps.

    Each iteration draws one source and one target batch and runs them
    through ``model`` as one concatenated batch.  The model returns
    ``(logits, features)``, which are split back into source/target halves.
    The minimised loss is ``cross_entropy(source) + args.trade_off * jmmd``.
    The JMMD term compares (features, softmax) pairs of the two domains.
    Target labels are used only to report target accuracy.
    """
    time_meter = AverageMeter('Time', ':4.2f')
    data_meter = AverageMeter('Data', ':3.1f')
    loss_meter = AverageMeter('Loss', ':3.2f')
    trans_meter = AverageMeter('Trans Loss', ':5.4f')
    src_acc_meter = AverageMeter('Cls Acc', ':3.1f')
    tgt_acc_meter = AverageMeter('Tgt Acc', ':3.1f')

    progress = ProgressMeter(
        args.iters_per_epoch,
        [time_meter, data_meter, loss_meter, trans_meter, src_acc_meter,
         tgt_acc_meter],
        prefix="Epoch: [{}]".format(epoch))

    model.train()
    jmmd_loss.train()

    last = time.time()
    for it in range(args.iters_per_epoch):
        # the LR schedule is advanced at the start of every iteration
        lr_sheduler.step()
        data_meter.update(time.time() - last)

        src_x, src_y = next(train_source_iter)
        tgt_x, tgt_y = next(train_target_iter)
        src_x, tgt_x = src_x.to(device), tgt_x.to(device)
        src_y, tgt_y = src_y.to(device), tgt_y.to(device)

        logits, feats = model(torch.cat((src_x, tgt_x), dim=0))
        src_logits, tgt_logits = logits.chunk(2, dim=0)
        src_feats, tgt_feats = feats.chunk(2, dim=0)

        cls_loss = F.cross_entropy(src_logits, src_y)
        transfer_loss = jmmd_loss(
            (src_feats, F.softmax(src_logits, dim=1)),
            (tgt_feats, F.softmax(tgt_logits, dim=1)))
        loss = cls_loss + transfer_loss * args.trade_off

        src_acc = accuracy(src_logits, src_y)[0]
        tgt_acc = accuracy(tgt_logits, tgt_y)[0]

        n_src = src_x.size(0)
        loss_meter.update(loss.item(), n_src)
        src_acc_meter.update(src_acc.item(), n_src)
        tgt_acc_meter.update(tgt_acc.item(), tgt_x.size(0))
        trans_meter.update(transfer_loss.item(), n_src)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        time_meter.update(time.time() - last)
        last = time.time()

        if it % args.print_freq == 0:
            progress.display(it)
예제 #26
0
파일: run.py 프로젝트: kc-ml2/darts
def train(train_loader, valid_loader, model, arch, w_optim, alpha_optim, lr,
          epoch):
    """Run one DARTS search epoch with alternating alpha / weight updates.

    For every ``(train, valid)`` batch pair from
    ``zip(train_loader, valid_loader)``:

    1. update the architecture parameters via ``arch.unrolled_backward``
       followed by ``alpha_optim.step()``;
    2. update the network weights on the train batch, clipping gradients to
       ``config.w_grad_clip``.

    Running metrics go to ``logger`` (every ``config.print_freq`` steps and
    on the last step) or to an in-place stdout line otherwise.  Per-step
    scalars and softmaxed alpha weights go to ``tb_writer``.

    Relies on module-level ``device``, ``config``, ``logger`` and
    ``tb_writer``.
    """
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()
    losses = utils.AverageMeter()

    cur_step = epoch * len(train_loader)
    tb_writer.add_scalar('train/lr', lr, cur_step)

    last_step = len(train_loader) - 1
    # Alpha weights are logged 5x as often as the text log ("not too much
    # logging").  Clamped to 1 so a print_freq below 5 no longer raises
    # ZeroDivisionError in the modulo below.
    alpha_log_freq = max(config.print_freq // 5, 1)

    model.train()

    for step, ((train_X, train_y), (valid_X, valid_y)) in enumerate(
            zip(train_loader, valid_loader)):
        train_X = train_X.to(device, non_blocking=True)
        train_y = train_y.to(device, non_blocking=True)
        valid_X = valid_X.to(device, non_blocking=True)
        valid_y = valid_y.to(device, non_blocking=True)
        N = train_X.size(0)

        # arch step (alpha training)
        alpha_optim.zero_grad()
        arch.unrolled_backward(train_X, train_y, valid_X, valid_y, lr, w_optim)
        alpha_optim.step()

        # child network step (w)
        w_optim.zero_grad()
        logits = model(train_X)
        loss = model.criterion(logits, train_y)
        loss.backward()

        # gradient clipping
        nn.utils.clip_grad_norm_(model.weights(), config.w_grad_clip)
        w_optim.step()

        prec1, prec5 = utils.accuracy(logits, train_y, topk=(1, 5))
        losses.update(loss.item(), N)
        top1.update(prec1.item(), N)
        top5.update(prec5.item(), N)

        msg = ("Train: [{:2d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
               "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                   epoch + 1,
                   config.epochs,
                   step,
                   last_step,
                   losses=losses,
                   top1=top1,
                   top5=top5))
        if step % config.print_freq == 0 or step == last_step:
            # clear the in-place progress line before a persistent log entry
            print("\r", end="", flush=True)
            logger.info(msg)
        else:
            print("\r" + msg, end="", flush=True)

        tb_writer.add_scalar('train/loss', loss.item(), cur_step)
        tb_writer.add_scalar('train/top1', prec1.item(), cur_step)
        tb_writer.add_scalar('train/top5', prec5.item(), cur_step)

        if step % alpha_log_freq == 0 or step == last_step:
            _log_alpha_weights('alpha_normal', model.alpha_normal, cur_step)
            _log_alpha_weights('alpha_reduce', model.alpha_reduce, cur_step)

        cur_step += 1

    logger.info("Train: [{:2d}/{}] Final Prec@1 {:.4%}".format(
        epoch + 1, config.epochs, top1.avg))


# Scalar names for the first 8 columns of every softmaxed alpha row, in
# column order.
_ALPHA_OP_NAMES = ('max_pl3', 'avg_pl3', 'skip_cn', 'sep_conv3', 'sep_conv5',
                   'dil_conv3', 'dil_conv5', 'none')


def _log_alpha_weights(tag, alphas, step):
    """Write softmax(alpha) rows of every tensor in ``alphas`` to ``tb_writer``."""
    for i, alpha in enumerate(alphas):
        for j, weights in enumerate(F.softmax(alpha, dim=-1)):
            tb_writer.add_scalars(
                '%s/%d ~~ %d' % (tag, j - 2, i),
                {name: weights[k] for k, name in enumerate(_ALPHA_OP_NAMES)},
                step)
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, classifier: ImageClassifier,
          mdd: MarginDisparityDiscrepancy, optimizer: SGD,
          lr_scheduler: StepwiseLR, epoch: int, args: argparse.Namespace):
    """Run ``args.iters_per_epoch`` Margin Disparity Discrepancy (MDD) steps.

    ``classifier`` is called on the concatenated source+target batch.  It
    returns ``(outputs, outputs_adv)``, which are split back into source and
    target halves.  The loss is source cross entropy plus
    ``args.trade_off * mdd(...)``.

    ``classifier.step()`` is called once per iteration after the loss is
    built.  Presumably it advances an internal schedule inside the
    classifier; TODO confirm.  Target labels are used only to report target
    accuracy.
    """
    t_batch = AverageMeter('Time', ':3.1f')
    t_data = AverageMeter('Data', ':3.1f')
    loss_m = AverageMeter('Loss', ':3.2f')
    trans_m = AverageMeter('Trans Loss', ':3.2f')
    src_acc_m = AverageMeter('Cls Acc', ':3.1f')
    tgt_acc_m = AverageMeter('Tgt Acc', ':3.1f')

    progress = ProgressMeter(
        args.iters_per_epoch,
        [t_batch, t_data, loss_m, trans_m, src_acc_m, tgt_acc_m],
        prefix="Epoch: [{}]".format(epoch))

    classifier.train()
    mdd.train()

    ce = nn.CrossEntropyLoss().to(device)

    tic = time.time()
    for it in range(args.iters_per_epoch):
        lr_scheduler.step()
        optimizer.zero_grad()
        t_data.update(time.time() - tic)

        src_x, src_y = next(train_source_iter)
        tgt_x, tgt_y = next(train_target_iter)
        src_x, tgt_x = src_x.to(device), tgt_x.to(device)
        src_y, tgt_y = src_y.to(device), tgt_y.to(device)

        out, out_adv = classifier(torch.cat((src_x, tgt_x), dim=0))
        src_out, tgt_out = out.chunk(2, dim=0)
        src_out_adv, tgt_out_adv = out_adv.chunk(2, dim=0)

        # source-domain classification loss + discrepancy between domains
        cls_loss = ce(src_out, src_y)
        transfer_loss = mdd(src_out, src_out_adv, tgt_out, tgt_out_adv)
        loss = cls_loss + transfer_loss * args.trade_off
        classifier.step()

        src_acc = accuracy(src_out, src_y)[0]
        tgt_acc = accuracy(tgt_out, tgt_y)[0]

        n_src = src_x.size(0)
        loss_m.update(loss.item(), n_src)
        src_acc_m.update(src_acc.item(), n_src)
        tgt_acc_m.update(tgt_acc.item(), tgt_x.size(0))
        trans_m.update(transfer_loss.item(), n_src)

        loss.backward()
        optimizer.step()

        t_batch.update(time.time() - tic)
        tic = time.time()

        if it % args.print_freq == 0:
            progress.display(it)
예제 #28
0
def train(data_loader, model, criterion, optimizer_t, optimizer_s, epoch,
          stage, logger, args):
    """Train for one epoch in the teacher/student mode selected by ``stage``.

    ``stage`` is matched by substring, first match wins, in this order:
    ``"TA"``, ``"KD"``, ``"KL"``, ``"JOINT"``, ``"RES_NMT"``, ``"CNN_NMT"``,
    ``"RES_KD"``.  The metrics tracked are:

    - student ("cnn") top-1/top-5 accuracy;
    - teacher ("res") top-1/top-5 accuracy;
    - a loss average;
    - where available, the mean of the last distillation loss.

    They are logged every 100 steps, on the last step, and at epoch end.

    NOTE(review): ``"KD"`` is tested before ``"RES_KD"``, so any stage
    containing ``"RES_KD"`` takes the ``"KD"`` branch.  The ``"RES_KD"``
    branch below is therefore unreachable.  It also never calls
    ``backward()`` or an optimizer ``step()``.  It is left unchanged pending
    a decision on the intended behavior.

    Raises:
        NameError: if ``stage`` matches none of the substrings above.
    """
    [loss_avg, mse_avg, top1_cnn, top5_cnn, top1_res,
     top5_res] = [utils.AverageMeter() for _ in range(6)]
    # Read before the loop: the final summary needs it, and assigning it
    # inside the loop (as before) left it unbound for an empty data_loader.
    epochs = args.baseline_epochs
    num_steps = len(data_loader)
    model.train()
    logger.log("stage: {}".format(stage))
    # Multiplier on the intermediate (non-final) losses in the "KD" and
    # "JOINT" stages; its value comes from the external ``Cosine`` helper.
    m = Cosine(min_v=args.dc, max_v=1.0, epoch=epoch, epoch_max=60)
    model.module.reset_margin()
    # Author's guidance, originally placed next to the ``args.dis_weight``
    # terms: 1e-3 for 32x32 images, 1e-4 for 224x224 classification tasks,
    # 1e-5 for detection and segmentation tasks.
    for step, (images, labels) in enumerate(data_loader):
        images = images.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)
        num_samples = images.size(0)
        if optimizer_t is not None:
            optimizer_t.zero_grad()
        if optimizer_s is not None:
            optimizer_s.zero_grad()

        if "TA" in stage:
            # teacher forward in eval mode; the teacher is not stepped here
            model.module.teacher.eval()
            logits_teacher, teacher_feas = model(images,
                                                 stage='RES_TA',
                                                 epoch=epoch)
            model.module.teacher.eval()

            logits_student, _, loss_dis = model(images,
                                                stage=stage,
                                                epoch=epoch,
                                                teacher_feas=teacher_feas[-1])
            loss = 0

            loss_last = criterion(logits_student[-1], labels)
            loss_avg.update(loss_last.detach().item(), num_samples)
            loss += loss_last
            loss += loss_dis[-1].mean() * args.dis_weight

            prec1_cnn, prec5_cnn = utils.accuracy(logits_student[-1].detach(),
                                                  labels,
                                                  topk=(1, 5))
            prec1_res, prec5_res = utils.accuracy(logits_teacher.detach(),
                                                  labels,
                                                  topk=(1, 5))

            # only the student optimizer is stepped in this stage
            loss.backward()
            optimizer_s.step()
            top1_cnn.update(prec1_cnn.item(), num_samples)
            top5_cnn.update(prec5_cnn.item(), num_samples)
            top1_res.update(prec1_res.item(), num_samples)
            top5_res.update(prec5_res.item(), num_samples)

        elif "KD" in stage:
            # teacher forward in eval mode; the teacher is not stepped here
            model.module.teacher.eval()
            logits_teacher, teacher_feas = model(images,
                                                 stage='RES_TA',
                                                 epoch=epoch)
            model.module.teacher.eval()

            logits_student, _, loss_dis = model(images,
                                                stage=stage,
                                                epoch=epoch,
                                                teacher_feas=teacher_feas[-1])
            loss = 0
            loss += criterion(logits_student[-1], labels)
            loss_avg.update(loss.detach().item(), num_samples)
            if loss_dis is not None:
                # intermediate distillation losses are additionally scaled by m
                for loss_d in loss_dis[:-1]:
                    loss += loss_d.mean() * m * args.dis_weight
                mse_avg.update(loss_dis[-1].detach().mean().item(),
                               num_samples)
                loss += loss_dis[-1].mean() * args.dis_weight

            prec1_cnn, prec5_cnn = utils.accuracy(logits_student[-1].detach(),
                                                  labels,
                                                  topk=(1, 5))
            prec1_res, prec5_res = utils.accuracy(logits_teacher.detach(),
                                                  labels,
                                                  topk=(1, 5))

            # only the student optimizer is stepped in this stage
            loss.backward()
            optimizer_s.step()
            top1_cnn.update(prec1_cnn.item(), num_samples)
            top5_cnn.update(prec5_cnn.item(), num_samples)
            top1_res.update(prec1_res.item(), num_samples)
            top5_res.update(prec5_res.item(), num_samples)

        elif "KL" in stage:
            # teacher forward in eval mode; the teacher is not stepped here
            model.module.teacher.eval()
            logits_teacher = model(images, stage='RES_NMT', epoch=epoch)
            model.module.teacher.eval()

            logits_student = model(images, stage="CNN_NMT", epoch=epoch)
            loss = loss_KD_fn(criterion,
                              logits_student,
                              logits_teacher,
                              targets=labels,
                              alpha=args.alpha,
                              temperature=args.temperature)

            prec1_cnn, prec5_cnn = utils.accuracy(logits_student.detach(),
                                                  labels,
                                                  topk=(1, 5))
            prec1_res, prec5_res = utils.accuracy(logits_teacher.detach(),
                                                  labels,
                                                  topk=(1, 5))

            # only the student optimizer is stepped in this stage
            loss.backward()
            optimizer_s.step()
            top1_cnn.update(prec1_cnn.item(), num_samples)
            top5_cnn.update(prec5_cnn.item(), num_samples)
            top1_res.update(prec1_res.item(), num_samples)
            top5_res.update(prec5_res.item(), num_samples)

        elif "JOINT" in stage:
            # teacher and student are jointly trained from scratch

            # 1) teacher step on its own cross-entropy loss
            model.module.teacher.train()
            optimizer_t.zero_grad()
            logits_teacher, teacher_feas = model(images,
                                                 stage='RES_TA',
                                                 epoch=epoch)
            loss_teacher = criterion(logits_teacher, labels)
            loss_teacher.backward()
            optimizer_t.step()
            model.module.teacher.eval()

            # 2) student step
            logits_student, _, loss_dis = model(images,
                                                stage=stage,
                                                epoch=epoch,
                                                teacher_feas=teacher_feas[-1])
            loss = 0
            coef = 1.0 / 4.
            # Toggle: supervise the intermediate student heads with the KD
            # loss against the teacher instead of plain cross entropy.
            kd_train = False
            for logit_student in logits_student[:-1]:
                if kd_train:
                    loss += loss_KD_fn(
                        criterion,
                        logit_student,
                        logits_teacher,
                        targets=labels,
                        alpha=args.alpha,
                        temperature=args.temperature) * m * coef
                else:
                    loss += criterion(logit_student, labels) * m * coef
            loss_last = criterion(logits_student[-1], labels) * coef
            loss_avg.update(loss_last.detach().item(), num_samples)
            loss += loss_last

            if loss_dis is not None:
                for loss_d in loss_dis[:-1]:
                    loss += loss_d.mean() * m * coef * args.dis_weight
                mse_avg.update(loss_dis[-1].detach().mean().item(),
                               num_samples)
                loss += loss_dis[-1].mean() * args.dis_weight * coef

            prec1_cnn, prec5_cnn = utils.accuracy(logits_student[-1].detach(),
                                                  labels,
                                                  topk=(1, 5))
            prec1_res, prec5_res = utils.accuracy(logits_teacher.detach(),
                                                  labels,
                                                  topk=(1, 5))
            # the teacher was already updated above by its own loss; only the
            # student optimizer is stepped here
            loss.backward()
            optimizer_s.step()
            top1_cnn.update(prec1_cnn.item(), num_samples)
            top5_cnn.update(prec5_cnn.item(), num_samples)
            top1_res.update(prec1_res.item(), num_samples)
            top5_res.update(prec5_res.item(), num_samples)

        elif "RES_NMT" in stage:
            # teacher trained alone
            logits = model(images, stage='RES_NMT')
            loss = criterion(logits, labels)
            prec1_res, prec5_res = utils.accuracy(logits.detach(),
                                                  labels,
                                                  topk=(1, 5))
            top1_res.update(prec1_res.item(), num_samples)
            top5_res.update(prec5_res.item(), num_samples)
            loss_avg.update(loss.item(), num_samples)
            loss.backward()
            optimizer_t.step()

        elif "CNN_NMT" in stage:
            # student trained alone
            logits = model(images, stage=stage)
            loss = criterion(logits, labels)
            prec1_cnn, prec5_cnn = utils.accuracy(logits.detach(),
                                                  labels,
                                                  topk=(1, 5))
            top1_cnn.update(prec1_cnn.item(), num_samples)
            top5_cnn.update(prec5_cnn.item(), num_samples)
            loss_avg.update(loss.item(), num_samples)
            loss.backward()
            optimizer_s.step()

        elif "RES_KD" in stage:
            # NOTE(review): unreachable, because "KD" above matches first.
            # This branch also never calls backward() or an optimizer step.
            logit_student, logits_teacher = model(images, stage=stage)
            loss = loss_KD_fn(criterion,
                              logit_student,
                              logits_teacher,
                              targets=labels,
                              alpha=args.alpha,
                              temperature=args.temperature)
            prec1_res, prec5_res = utils.accuracy(logit_student.detach(),
                                                  labels,
                                                  topk=(1, 5))
            top1_res.update(prec1_res.item(), num_samples)
            top5_res.update(prec5_res.item(), num_samples)
            loss_avg.update(loss.item(), num_samples)

        else:
            raise NameError("invalid stage name: {}".format(stage))

        if step % 100 == 0 or step == num_steps - 1:
            logger.log("Train, Epoch: [{:3d}/{}], Step: [{:3d}/{}], " \
                        "Loss: {:.4f}, Loss_dis: {:.4f}, Prec@(cnn1, res1, cnn5, res5): {:.4%},{:.4%}, {:.4%}, {:.4%}".format(
                            epoch, epochs, step, num_steps,
                            loss_avg.avg, mse_avg.avg, top1_cnn.avg, top1_res.avg, top5_cnn.avg, top5_res.avg))

    logger.log("m is {}".format(m))
    logger.log(
        "Train, Epoch: [{:3d}/{}], Final Prec: cnn, res@1: {:.4%}, {:.4%},  Final Prec: cnn, res@5: {:.4%}, {:.4%} Loss: {:.4f}"
        .format(epoch, epochs, top1_cnn.avg, top1_res.avg, top5_cnn.avg,
                top5_res.avg, loss_avg.avg))
예제 #29
0
def train_w_arch(train_queue, val_queue, model, criterion, optimizer_w,
                 optimizer_a):
    """Alternate supernet-weight and architecture updates for one pass.

    Every step trains the weights on a ``train_queue`` batch, with
    architecture parameters frozen.  The loss is the sum of a Gumbel-sampled
    and a randomly-sampled forward loss.

    On even steps the architecture parameters are also updated, with the
    weights frozen:

    - the batch comes from ``val_queue``;
    - the loss is cross entropy plus the latency penalty
      ``|lat / args.target_lat - 1| * args.lambda_lat``;
    - afterwards every log-alpha tensor is re-normalised with
      ``log_softmax``.

    ``val_queue`` is cycled: it is re-iterated whenever it is exhausted.

    Returns:
        Average top-1 accuracy of the Gumbel-sampled logits over the pass.
    """
    objs_a = AverageMeter()
    objs_l = AverageMeter()
    objs_w = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.train()

    # Created up front.  Previously this name was never initialised: the
    # code relied on a bare ``except:`` catching the UnboundLocalError from
    # the first ``next()``.  That bare except also swallowed unrelated loader
    # errors and KeyboardInterrupt.
    val_queue_iter = iter(val_queue)

    for step, (x_w, target_w) in enumerate(train_queue):
        x_w = x_w.cuda(non_blocking=True)
        target_w = target_w.cuda(non_blocking=True)

        # optimize w: weights trainable, architecture frozen
        for param in model.module.weight_parameters():
            param.requires_grad = True
        for param in model.module.arch_parameters():
            param.requires_grad = False

        logits_w_gumbel, _ = model(x_w, sampling=True, mode='gumbel')
        loss_w_gumbel = criterion(logits_w_gumbel, target_w)
        logits_w_random, _ = model(x_w, sampling=True, mode='random')
        loss_w_random = criterion(logits_w_random, target_w)
        loss_w = loss_w_gumbel + loss_w_random

        optimizer_w.zero_grad()
        loss_w.backward()
        if args.grad_clip > 0:
            nn.utils.clip_grad_norm_(model.module.weight_parameters(),
                                     args.grad_clip)
        optimizer_w.step()

        prec1, prec5 = accuracy(logits_w_gumbel, target_w, topk=(1, 5))
        n = x_w.size(0)
        objs_w.update(loss_w.item(), n)
        top1.update(prec1.item(), n)
        top5.update(prec5.item(), n)

        if step % 2 == 0:
            # optimize a: architecture trainable, weights frozen
            try:
                x_a, target_a = next(val_queue_iter)
            except StopIteration:
                # validation queue exhausted: start a new pass over it
                val_queue_iter = iter(val_queue)
                x_a, target_a = next(val_queue_iter)

            x_a = x_a.cuda(non_blocking=True)
            target_a = target_a.cuda(non_blocking=True)

            for param in model.module.weight_parameters():
                param.requires_grad = False
            for param in model.module.arch_parameters():
                param.requires_grad = True

            logits_a, lat = model(x_a, sampling=False)
            loss_a = criterion(logits_a, target_a)
            loss_l = torch.abs(lat / args.target_lat - 1.) * args.lambda_lat
            loss = loss_a + loss_l

            optimizer_a.zero_grad()
            loss.backward()
            if args.grad_clip > 0:
                nn.utils.clip_grad_norm_(model.module.arch_parameters(),
                                         args.grad_clip)
            optimizer_a.step()

            # ensure log_alphas to be a log probability distribution
            for log_alphas in model.module.arch_parameters():
                log_alphas.data = F.log_softmax(log_alphas.detach().data,
                                                dim=-1)

            n = x_a.size(0)
            objs_a.update(loss_a.item(), n)
            objs_l.update(loss_l.item(), n)

        if step % args.print_freq == 0:
            logging.info(
                'TRAIN w_Arch Step: %04d Objs_W: %f R1: %f R5: %f Objs_A: %f Objs_L: %f',
                step, objs_w.avg, top1.avg, top5.avg, objs_a.avg, objs_l.avg)

    return top1.avg
예제 #30
0
def valid(data_loader, model, criterion, epoch, global_step, stage, logger,
          args):
    """Evaluate ``model`` on ``data_loader`` for the given training ``stage``.

    Depending on ``stage``, this evaluates the CNN branch, the RES branch, or
    both. It accumulates running loss and top-1/top-5 accuracy meters, logs
    them every 100 steps and at the last step, then logs a final summary.

    Args:
        data_loader: iterable of ``(images, labels)`` batches supporting
            ``len()``.
        model: network invoked as ``model(images, stage=...)``.
        criterion: loss callable ``criterion(logits, labels)``.
        epoch: current epoch index; used only in log messages.
        global_step: unused; kept so existing call sites keep working.
        stage: stage name selecting which branch(es) are evaluated.
        logger: object exposing ``log(str)``.
        args: namespace providing ``baseline_epochs`` (for logging) and, for
            ``RES_KD`` stages, ``alpha`` and ``temperature``.

    Returns:
        Average top-1 accuracy of the RES branch when ``"RES"`` occurs in
        ``stage``, otherwise that of the CNN branch.

    Raises:
        NameError: if ``stage`` matches none of the known stage names.
    """
    loss_avg = utils.AverageMeter()
    top1_cnn = utils.AverageMeter()
    top5_cnn = utils.AverageMeter()
    top1_res = utils.AverageMeter()
    top5_res = utils.AverageMeter()

    # Read once, before the loop, so the final summary also works when the
    # loader yields no batches (previously this raised UnboundLocalError).
    epochs = args.baseline_epochs
    num_steps = len(data_loader)

    model.eval()
    logger.log("stage: {}".format(stage))
    with torch.no_grad():
        for step, (images, labels) in enumerate(data_loader):
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
            num_samples = images.size(0)

            # "RES_KD" must be tested before the generic "KD" match below;
            # otherwise "KD" in "RES_KD" is true and this branch never runs.
            if "RES_KD" in stage:
                logits_student, logits_teacher = model(images, stage=stage)
                loss = loss_KD_fn(criterion,
                                  logits_student,
                                  logits_teacher,
                                  targets=labels,
                                  alpha=args.alpha,
                                  temperature=args.temperature)
                prec1_res, prec5_res = utils.accuracy(logits_student,
                                                      labels,
                                                      topk=(1, 5))
                top1_res.update(prec1_res.item(), num_samples)
                top5_res.update(prec5_res.item(), num_samples)

            elif any(key in stage for key in ("TA", "JOINT", "KD", "KL")):
                # Evaluate both branches; the reported loss comes from the
                # CNN branch only.
                logits = model(images, stage='CNN_NMT')
                logits_teacher = model(images, stage='RES_NMT')
                loss = criterion(logits, labels)
                prec1_cnn, prec5_cnn = utils.accuracy(logits,
                                                      labels,
                                                      topk=(1, 5))
                prec1_res, prec5_res = utils.accuracy(logits_teacher,
                                                      labels,
                                                      topk=(1, 5))
                top1_cnn.update(prec1_cnn.item(), num_samples)
                top5_cnn.update(prec5_cnn.item(), num_samples)
                top1_res.update(prec1_res.item(), num_samples)
                top5_res.update(prec5_res.item(), num_samples)

            elif "RES_NMT" in stage:
                logits = model(images, stage=stage)
                loss = criterion(logits, labels)
                prec1_res, prec5_res = utils.accuracy(logits,
                                                      labels,
                                                      topk=(1, 5))
                top1_res.update(prec1_res.item(), num_samples)
                top5_res.update(prec5_res.item(), num_samples)

            elif "CNN_NMT" in stage:
                logits = model(images, stage=stage)
                loss = criterion(logits, labels)
                prec1_cnn, prec5_cnn = utils.accuracy(logits,
                                                      labels,
                                                      topk=(1, 5))
                top1_cnn.update(prec1_cnn.item(), num_samples)
                top5_cnn.update(prec5_cnn.item(), num_samples)

            else:
                raise NameError("invalid stage name: {!r}".format(stage))

            # Every branch above binds ``loss``; no .detach()/.data needed
            # inside torch.no_grad().
            loss_avg.update(loss.item(), num_samples)

            # NOTE(review): ``{:.4%}`` multiplies by 100 -- confirm that
            # utils.accuracy returns a fraction rather than a percentage.
            if step % 100 == 0 or step == num_steps - 1:
                logger.log(
                    "Valid, Epoch: [{:3d}/{}], Step: [{:3d}/{}], "
                    "Loss: {:.4f}, Prec@(cnn1, res1, cnn5, res5): {:.4%},{:.4%}, {:.4%}, {:.4%}"
                    .format(epoch, epochs, step, num_steps, loss_avg.avg,
                            top1_cnn.avg, top1_res.avg, top5_cnn.avg,
                            top5_res.avg))

    logger.log(
        "Valid, Epoch: [{:3d}/{}], Final Prec: cnn, res@1: {:.4%}, {:.4%},  Final Prec: cnn, res@5: {:.4%}, {:.4%} Loss: {:.4f}"
        .format(epoch, epochs, top1_cnn.avg, top1_res.avg, top5_cnn.avg,
                top5_res.avg, loss_avg.avg))

    if "RES" in stage:
        return top1_res.avg
    return top1_cnn.avg