Ejemplo n.º 1
0
    def train(self, model, epoch, optim_obj='Weights', search_stage=0):
        """Run one epoch of either weight or architecture optimization.

        Args:
            model: network under search; updated in place.
            epoch: current epoch index (logging only).
            optim_obj: 'Weights' trains on self.train_data via weight_step;
                'Arch' trains on self.val_data via arch_step.
            search_stage: forwarded unchanged to the step function.

        Returns:
            (top1_avg, top5_avg, loss_avg, sub_obj_avg, batch_time_avg)
        """
        assert optim_obj in ['Weights', 'Arch']
        loss_meter = utils.AverageMeter()
        acc1_meter = utils.AverageMeter()
        acc5_meter = utils.AverageMeter()
        sub_obj_meter = utils.AverageMeter()
        data_meter = utils.AverageMeter()
        batch_meter = utils.AverageMeter()
        model.train()

        tic = time.time()
        # Weight updates draw batches from the training split; architecture
        # updates draw from the validation split.
        source = self.train_data if optim_obj == 'Weights' else self.val_data
        prefetcher = data_prefetcher(source)

        images, labels = prefetcher.next()
        step = 0
        while images is not None:
            images, labels = images.cuda(), labels.cuda()
            data_elapsed = time.time() - tic
            batch_size = images.size(0)
            if optim_obj == 'Weights':
                self.scheduler.step()  # per-iteration LR schedule
                if step == 0:
                    logging.info(
                        'epoch %d weight_lr %e', epoch,
                        self.search_optim.weight_optimizer.param_groups[0]
                        ['lr'])
                logits, loss, sub_obj = self.search_optim.weight_step(
                    images, labels, model, search_stage)
            else:  # 'Arch' (guaranteed by the assert above)
                if step == 0:
                    logging.info(
                        'epoch %d arch_lr %e', epoch,
                        self.search_optim.arch_optimizer.param_groups[0]['lr'])
                logits, loss, sub_obj = self.search_optim.arch_step(
                    images, labels, model, search_stage)

            acc1, acc5 = utils.accuracy(logits, labels, topk=(1, 5))
            # Drop large tensors before bookkeeping to release memory early.
            del logits, images, labels

            batch_elapsed = time.time() - tic
            loss_meter.update(loss, batch_size)
            acc1_meter.update(acc1.item(), batch_size)
            acc5_meter.update(acc5.item(), batch_size)
            sub_obj_meter.update(sub_obj)
            data_meter.update(data_elapsed)
            batch_meter.update(batch_elapsed)

            if step != 0 and step % self.args.report_freq == 0:
                logging.info(
                    'Train%s epoch %03d step %03d | loss %.4f %s %.2f top1_acc %.2f top5_acc %.2f | batch_time %.3f data_time %.3f',
                    optim_obj, epoch, step, loss_meter.avg, self.sub_obj_type,
                    sub_obj_meter.avg, acc1_meter.avg, acc5_meter.avg,
                    batch_meter.avg, data_meter.avg)
            tic = time.time()
            step += 1
            images, labels = prefetcher.next()
        return (acc1_meter.avg, acc5_meter.avg, loss_meter.avg,
                sub_obj_meter.avg, batch_meter.avg)
Ejemplo n.º 2
0
    def infer(self, model, epoch=0):
        """Evaluate *model* over one pass of self.val_data.

        Args:
            model: network to evaluate; put into eval mode.
            epoch: epoch index used only in log messages.

        Returns:
            (top1_avg, top5_avg, batch_time_avg, data_time_avg)
        """
        top1 = utils.AverageMeter()
        top5 = utils.AverageMeter()
        data_time = utils.AverageMeter()
        batch_time = utils.AverageMeter()
        model.eval()

        start = time.time()
        prefetcher = data_prefetcher(self.val_data)
        input, target = prefetcher.next()
        step = 0
        # Fix: disable autograd for the whole evaluation pass — no gradients
        # are needed and tracking them wastes memory and time.
        with torch.no_grad():
            while input is not None:
                step += 1
                data_t = time.time() - start
                n = input.size(0)

                # The auxiliary head's output is not used for validation.
                logits, _ = model(input)
                prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))

                batch_t = time.time() - start
                top1.update(prec1.item(), n)
                top5.update(prec5.item(), n)
                data_time.update(data_t)
                batch_time.update(batch_t)

                if step % self.report_freq == 0:
                    logging.info('Val epoch %03d step %03d | top1_acc %.2f  top5_acc %.2f | batch_time %.3f  data_time %.3f', epoch, step, top1.avg, top5.avg, batch_time.avg, data_time.avg)
                start = time.time()
                input, target = prefetcher.next()

        logging.info('EPOCH%d Valid_acc  top1 %.2f top5 %.2f batch_time %.3f data_time %.3f', epoch, top1.avg, top5.avg, batch_time.avg, data_time.avg)
        return top1.avg, top5.avg, batch_time.avg, data_time.avg
Ejemplo n.º 3
0
    def train(self, model, epoch):
        """Train *model* for a single epoch over self.train_data.

        Args:
            model: network to train; updated in place.
            epoch: current epoch index (logging only).

        Returns:
            (top1_avg, top5_avg, loss_avg, batch_time_avg, data_time_avg)
        """
        loss_meter = utils.AverageMeter()
        acc1_meter = utils.AverageMeter()
        acc5_meter = utils.AverageMeter()
        data_meter = utils.AverageMeter()
        batch_meter = utils.AverageMeter()
        model.train()
        tic = time.time()

        prefetcher = data_prefetcher(self.train_data)
        images, labels = prefetcher.next()
        step = 0
        while images is not None:
            data_elapsed = time.time() - tic
            self.scheduler.step()  # per-iteration LR schedule
            batch_size = images.size(0)
            if step == 0:
                logging.info('epoch %d lr %e', epoch,
                             self.optimizer.param_groups[0]['lr'])
            self.optimizer.zero_grad()

            logits = model(images)
            # Criterion takes an extra smoothing factor when label smoothing
            # is configured.
            if self.config.optim.label_smooth:
                loss = self.criterion(logits, labels,
                                      self.config.optim.smooth_alpha)
            else:
                loss = self.criterion(logits, labels)

            loss.backward()
            if self.config.optim.use_grad_clip:
                nn.utils.clip_grad_norm_(model.parameters(),
                                         self.config.optim.grad_clip)
            self.optimizer.step()

            acc1, acc5 = utils.accuracy(logits, labels, topk=(1, 5))

            batch_elapsed = time.time() - tic
            tic = time.time()

            loss_meter.update(loss.item(), batch_size)
            acc1_meter.update(acc1.item(), batch_size)
            acc5_meter.update(acc5.item(), batch_size)
            data_meter.update(data_elapsed)
            batch_meter.update(batch_elapsed)
            if step != 0 and step % self.report_freq == 0:
                logging.info(
                    'Train epoch %03d step %03d | loss %.4f  top1_acc %.2f  top5_acc %.2f | batch_time %.3f  data_time %.3f',
                    epoch, step, loss_meter.avg, acc1_meter.avg,
                    acc5_meter.avg, batch_meter.avg, data_meter.avg)
            images, labels = prefetcher.next()
            step += 1
        logging.info(
            'EPOCH%d Train_acc  top1 %.2f top5 %.2f batch_time %.3f data_time %.3f',
            epoch, acc1_meter.avg, acc5_meter.avg, batch_meter.avg,
            data_meter.avg)

        return (acc1_meter.avg, acc5_meter.avg, loss_meter.avg,
                batch_meter.avg, data_meter.avg)
Ejemplo n.º 4
0
def train_with_distill(train_loader, d_net, optimizer, epoch):
    """Train the student network for one epoch with knowledge distillation.

    Args:
        train_loader: loader wrapped by data_prefetcher.
        d_net: distillation wrapper (DataParallel) exposing .module.s_net
            (student) and .module.t_net (teacher); calling it returns
            (outputs, per-sample loss).
        optimizer: optimizer over the trainable (student) parameters.
        epoch: current epoch index (logging only).
    """
    d_net.train()
    d_net.module.s_net.train()
    d_net.module.t_net.train()

    train_loss = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    btic = time.time()

    prefetcher = data_prefetcher(train_loader, is_sample=False)
    inputs, targets = prefetcher.next()

    step = 0
    while inputs is not None:

        batch_size = inputs.size(0)
        if step == 0:
            print('epoch %d lr %e' % (epoch, optimizer.param_groups[0]['lr']))

        outputs, loss = d_net(inputs, targets)

        # The wrapper returns a per-sample loss vector; reduce to a scalar.
        loss = torch.mean(loss)
        err1, err5 = accuracy(outputs.data, targets, topk=(1, 5))

        train_loss.update(loss.item(), batch_size)
        top1.update(err1.item(), batch_size)
        top5.update(err5.item(), batch_size)

        # Fix: the original called optimizer.zero_grad() twice per iteration
        # (once before the forward pass and once here); a single clear
        # immediately before backward() is sufficient.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % config.train_params.print_freq == 0:
            speed = config.train_params.print_freq * config.data.batch_size / (time.time() - btic)
            print(
                'Train with distillation: [Epoch %d/%d][Batch %d/%d]\t, speed %.3f, Loss %.3f, Top 1-error %.3f, Top 5-error %.3f' %
                (epoch, config.train_params.epochs, step, len(train_loader), speed, train_loss.avg, top1.avg, top5.avg))
            btic = time.time()

        inputs, targets = prefetcher.next()

        step += 1
Ejemplo n.º 5
0
def validate(val_loader, model, epoch):
    """Evaluate *model* on val_loader and print progress/summary.

    Args:
        val_loader: validation loader wrapped by data_prefetcher.
        model: network to evaluate; switched to eval mode.
        epoch: current epoch index (logging only).

    Returns:
        (top1_avg, top5_avg)

    Fix: the original kept a `losses` AverageMeter that was never updated
    yet printed its `.avg` as "Test Loss" in the summary — always the
    meter's initial value. No criterion is available here to compute a real
    loss, so the dead meter and the misleading field are removed.
    """
    batch_time = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()

    prefetcher = data_prefetcher(val_loader)
    input, target = prefetcher.next()
    step = 0
    while input is not None:
        # inference only: no gradients needed
        with torch.no_grad():
            output = model(input)

        # measure accuracy
        err1, err5 = accuracy(output.data, target, topk=(1, 5))

        top1.update(err1.item(), input.size(0))
        top5.update(err5.item(), input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if step % config.train_params.print_freq == 0:
            print('Test (on val set): [Epoch {0}/{1}][Batch {2}/{3}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Top 1-err {top1.val:.4f} ({top1.avg:.4f})\t'
                  'Top 5-err {top5.val:.4f} ({top5.avg:.4f})'.format(
                epoch, config.train_params.epochs, step, len(val_loader), batch_time=batch_time, top1=top1, top5=top5))
        input, target = prefetcher.next()
        step += 1

    print('* Epoch: [{0}/{1}]\t Top 1-err {top1.avg:.3f}  Top 5-err {top5.avg:.3f}'
          .format(epoch, config.train_params.epochs, top1=top1, top5=top5))
    return top1.avg, top5.avg
Ejemplo n.º 6
0
    def infer(self, model, epoch):
        """Validation pass during architecture search.

        Args:
            model: network under search.
            epoch: current epoch index (logging only).

        Returns:
            (top1_avg, top5_avg, loss_avg, sub_obj_avg, batch_time_avg)
        """
        loss_meter = utils.AverageMeter()
        acc1_meter = utils.AverageMeter()
        acc5_meter = utils.AverageMeter()
        sub_obj_meter = utils.AverageMeter()
        data_meter = utils.AverageMeter()
        batch_meter = utils.AverageMeter()

        # Stay in train mode so BN uses batch statistics instead of the
        # running mean/var during search (deliberate, per the original).
        model.train()
        tic = time.time()
        prefetcher = data_prefetcher(self.val_data)
        images, labels = prefetcher.next()
        step = 0
        while images is not None:
            step += 1
            data_elapsed = time.time() - tic
            batch_size = images.size(0)

            logits, loss, sub_obj = self.search_optim.valid_step(
                images, labels, model)
            acc1, acc5 = utils.accuracy(logits, labels, topk=(1, 5))

            batch_elapsed = time.time() - tic
            loss_meter.update(loss, batch_size)
            acc1_meter.update(acc1.item(), batch_size)
            acc5_meter.update(acc5.item(), batch_size)
            sub_obj_meter.update(sub_obj)
            data_meter.update(data_elapsed)
            batch_meter.update(batch_elapsed)

            if step % self.args.report_freq == 0:
                logging.info(
                    'Val epoch %03d step %03d | loss %.4f %s %.2f top1_acc %.2f top5_acc %.2f | batch_time %.3f data_time %.3f',
                    epoch, step, loss_meter.avg, self.sub_obj_type,
                    sub_obj_meter.avg, acc1_meter.avg, acc5_meter.avg,
                    batch_meter.avg, data_meter.avg)
            tic = time.time()
            images, labels = prefetcher.next()

        return (acc1_meter.avg, acc5_meter.avg, loss_meter.avg,
                sub_obj_meter.avg, batch_meter.avg)