Example #1
0
def train(train_queue, valid_queue, model, architect, criterion, optimizer,
          lr):
    """Run one epoch of the architecture-search phase.

    For every training batch, first update the architecture parameters via
    ``architect.step`` on a validation batch, then update the network
    weights on the training batch.

    Args:
        train_queue: DataLoader yielding (input, target) training batches.
        valid_queue: DataLoader yielding validation batches for the architect.
        model: network whose weights are trained.
        architect: object whose ``step`` updates the architecture parameters.
        criterion: loss function.
        optimizer: optimizer over the network weights.
        lr: current learning rate, forwarded to ``architect.step``.

    Returns:
        Tuple ``(avg_accuracy, avg_loss, avg_fscore, avg_miou)`` over the epoch.
    """
    objs = utils.AvgrageMeter()  # running average of the loss
    accs = utils.AvgrageMeter()
    MIoUs = utils.AvgrageMeter()
    fscores = utils.AvgrageMeter()

    if args.gpu == -1:
        device = torch.device('cpu')
    else:
        device = torch.device('cuda:{}'.format(args.gpu))

    # Keep one persistent iterator over the validation queue.  The original
    # next(iter(valid_queue)) restarted the DataLoader (and re-spawned its
    # worker processes) on every single step, which is very slow.
    valid_iter = iter(valid_queue)

    for step, (input, target) in enumerate(train_queue):
        # each step consumes one training batch
        model.train()
        n = input.size(0)

        input = input.to(device)
        target = target.to(device)

        # get a minibatch from the search (validation) queue, restarting the
        # iterator once it is exhausted
        try:
            input_search, target_search = next(valid_iter)
        except StopIteration:
            valid_iter = iter(valid_queue)
            input_search, target_search = next(valid_iter)
        input_search = input_search.to(device)
        target_search = target_search.to(device)

        architect.step(input,
                       target,
                       input_search,
                       target_search,
                       lr,
                       optimizer,
                       unrolled=args.unrolled)

        optimizer.zero_grad()
        logits = model(input)  # already on `device`; no extra .to() needed
        loss = criterion(logits, target)
        evaluater = Evaluator(dataset_classes)
        loss.backward()
        # clip_grad_norm was deprecated; clip_grad_norm_ is the in-place form
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        evaluater.add_batch(target, logits)
        miou = evaluater.Mean_Intersection_over_Union()
        fscore = evaluater.Fx_Score()
        acc = evaluater.Pixel_Accuracy()

        objs.update(loss.item(), n)
        MIoUs.update(miou.item(), n)
        fscores.update(fscore.item(), n)
        accs.update(acc.item(), n)

        if step % args.report_freq == 0:
            logging.info('train %03d %e %f %f %f', step, objs.avg, accs.avg,
                         fscores.avg, MIoUs.avg)

    return accs.avg, objs.avg, fscores.avg, MIoUs.avg
Example #2
0
    def train_one_epoch(self, epoch):
        """Train ``self.model`` for one epoch over ``run_config.train_loader``.

        Args:
            epoch: zero-based epoch index (used only in the log header).

        Returns:
            Tuple ``(avg_loss, avg_accuracy, avg_miou, avg_fscore)``
            averaged over all samples of the epoch.

        Raises:
            ValueError: if CUDA is not available (CPU is unsupported here).
        """

        batch_time = AverageMeter()  # per-iteration wall time (data + compute)
        data_time = AverageMeter()   # time spent waiting on the data loader
        losses = AverageMeter()
        accs = AverageMeter()
        mious = AverageMeter()
        fscores = AverageMeter()

        epoch_str = 'epoch[{:03d}/{:03d}]'.format(epoch + 1,
                                                  self.run_config.epochs)
        iter_per_epoch = len(self.run_config.train_loader)
        end = time.time()
        for i, (datas, targets) in enumerate(self.run_config.train_loader):
            if torch.cuda.is_available():
                datas = datas.to(self.device, non_blocking=True)
                targets = targets.to(self.device, non_blocking=True)
            else:
                raise ValueError('do not support cpu version')
            data_time.update(time.time() - end)
            logits = self.model(datas)
            loss = self.criterion(logits, targets)
            # fresh evaluator per batch, so acc/miou/fscore are per-batch values
            evaluator = Evaluator(self.run_config.nb_classes)
            evaluator.add_batch(targets, logits)
            acc = evaluator.Pixel_Accuracy()
            miou = evaluator.Mean_Intersection_over_Union()
            fscore = evaluator.Fx_Score()
            # metrics update (weighted by batch size)
            losses.update(loss.data.item(), datas.size(0))
            accs.update(acc.item(), datas.size(0))
            mious.update(miou.item(), datas.size(0))
            fscores.update(fscore.item(), datas.size(0))
            # update parameters
            self.model.zero_grad()
            loss.backward()
            self.optimizer.step()  # only update network_weight_parameters
            # elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            # within one epoch, per iteration print train_log
            if (i + 1) % self.run_config.train_print_freq == 0 or (
                    i + 1) == len(self.run_config.train_loader):
                Wstr = '|*TRAIN*|' + time_string(
                ) + '[{:}][iter{:03d}/{:03d}]'.format(epoch_str, i + 1,
                                                      iter_per_epoch)
                Tstr = '|Time   | [{batch_time.val:.2f} ({batch_time.avg:.2f})  Data {data_time.val:.2f} ({data_time.avg:.2f})]'.format(
                    batch_time=batch_time, data_time=data_time)
                Bstr = '|Base   | [Loss {loss.val:.3f} ({loss.avg:.3f})  Accuracy {acc.val:.2f} ({acc.avg:.2f}) MIoU {miou.val:.2f} ({miou.avg:.2f}) F {fscore.val:.2f} ({fscore.avg:.2f})]' \
                    .format(loss=losses, acc=accs, miou=mious, fscore=fscores)
                self.logger.log(Wstr + '\n' + Tstr + '\n' + Bstr,
                                mode='retrain')
        return losses.avg, accs.avg, mious.avg, fscores.avg
Example #3
0
def train(train_queue, model, criterion, optimizer):
    """Run one epoch of weight-only (retrain) training.

    Args:
        train_queue: DataLoader yielding (input, target) batches.
        model: network to train; may return either ``logits`` or a
            ``(logits, logits_aux)`` tuple when an auxiliary head is present.
        criterion: loss function.
        optimizer: optimizer over the network weights.

    Returns:
        Tuple ``(avg_accuracy, avg_loss, avg_fscore, avg_miou)`` over the epoch.
    """
    objs = utils.AvgrageMeter()
    accs = utils.AvgrageMeter()
    MIoUs = utils.AvgrageMeter()
    fscores = utils.AvgrageMeter()
    model.train()

    if args.gpu == -1:
        device = torch.device('cpu')
    else:
        device = torch.device('cuda:{}'.format(args.gpu))

    for step, (input, target) in enumerate(train_queue):
        input = input.to(device)
        target = target.to(device)
        n = input.size(0)

        optimizer.zero_grad()
        # BUG FIX: the original referenced an undefined `logits_aux`, which
        # raised NameError whenever args.auxiliary was set.  Models with an
        # auxiliary head return (logits, logits_aux); unpack accordingly.
        out = model(input)
        if isinstance(out, tuple):
            logits, logits_aux = out
        else:
            logits, logits_aux = out, None
        loss = criterion(logits, target)
        evaluater = Evaluator(dataset_classes)

        if args.auxiliary and logits_aux is not None:
            loss_aux = criterion(logits_aux, target)
            loss += args.auxiliary_weight * loss_aux
        loss.backward()
        # clip_grad_norm was deprecated; clip_grad_norm_ is the in-place form
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        evaluater.add_batch(target, logits)
        miou = evaluater.Mean_Intersection_over_Union()
        fscore = evaluater.Fx_Score()
        acc = evaluater.Pixel_Accuracy()

        objs.update(loss.item(), n)
        MIoUs.update(miou.item(), n)
        fscores.update(fscore.item(), n)
        accs.update(acc.item(), n)

        if step % args.report_freq == 0:
            logging.info('train %03d %e %f %f %f', step, objs.avg, accs.avg,
                         fscores.avg, MIoUs.avg)

    return accs.avg, objs.avg, fscores.avg, MIoUs.avg
Example #4
0
def infer(test_queue, model, criterion):
    """Evaluate the model on the test set and save per-sample predictions.

    Args:
        test_queue: DataLoader yielding (input, target, data_list) batches,
            where ``data_list`` names the samples for prediction saving.
        model: trained network (switched to eval mode here).
        criterion: loss function.

    Returns:
        Tuple ``(avg_accuracy, avg_loss, avg_fscore, avg_miou)``.
    """
    objs = util.AvgrageMeter()
    accs = util.AvgrageMeter()
    MIoUs = util.AvgrageMeter()
    fscores = util.AvgrageMeter()
    model.eval()

    if args.gpu == -1:
        device = torch.device('cpu')
    else:
        device = torch.device('cuda:{}'.format(args.gpu))

    # prediction directory derived from the checkpoint path
    save_path = args.model_path[:-10] + 'predict'
    print(save_path)
    # makedirs(exist_ok=True) avoids the exists()/mkdir() race of the original
    os.makedirs(save_path, exist_ok=True)

    # inference needs no autograd graph: no_grad saves memory and time
    with torch.no_grad():
        for step, (input, target, data_list) in enumerate(test_queue):
            input = input.to(device)
            target = target.to(device)
            n = input.size(0)

            logits = model(input)
            util.save_pred_WHU(logits, save_path, data_list)

            loss = criterion(logits, target)
            evaluater = Evaluator(dataset_classes)

            evaluater.add_batch(target, logits)
            miou = evaluater.Mean_Intersection_over_Union()
            fscore = evaluater.Fx_Score()
            acc = evaluater.Pixel_Accuracy()

            objs.update(loss.item(), n)
            MIoUs.update(miou.item(), n)
            fscores.update(fscore.item(), n)
            accs.update(acc.item(), n)

            if step % args.report_freq == 0:
                logging.info('test %03d %e %f %f %f', step, objs.avg,
                             accs.avg, fscores.avg, MIoUs.avg)

    return accs.avg, objs.avg, fscores.avg, MIoUs.avg
Example #5
0
def infer(valid_queue, model, criterion):
    """Evaluate the model on the validation set.

    Args:
        valid_queue: DataLoader yielding (input, target) batches.
        model: network to evaluate (switched to eval mode here).
        criterion: loss function.

    Returns:
        Tuple ``(avg_accuracy, avg_loss, avg_fscore, avg_miou)``.
    """
    objs = utils.AvgrageMeter()
    accs = utils.AvgrageMeter()
    MIoUs = utils.AvgrageMeter()
    fscores = utils.AvgrageMeter()
    model.eval()

    if args.gpu == -1:
        device = torch.device('cpu')
    else:
        device = torch.device('cuda:{}'.format(args.gpu))

    # validation needs no autograd graph: no_grad saves memory and time
    with torch.no_grad():
        for step, (input, target) in enumerate(valid_queue):
            input = input.to(device)
            target = target.to(device)

            logits = model(input)
            loss = criterion(logits, target)
            evaluater = Evaluator(dataset_classes)

            evaluater.add_batch(target, logits)
            miou = evaluater.Mean_Intersection_over_Union()
            fscore = evaluater.Fx_Score()
            acc = evaluater.Pixel_Accuracy()

            n = input.size(0)

            objs.update(loss.item(), n)
            MIoUs.update(miou.item(), n)
            fscores.update(fscore.item(), n)
            accs.update(acc.item(), n)

            if step % args.report_freq == 0:
                logging.info('valid %03d %e %f %f %f', step, objs.avg,
                             accs.avg, fscores.avg, MIoUs.avg)

    return accs.avg, objs.avg, fscores.avg, MIoUs.avg
Example #6
0
    def train(self, fix_net_weights=False):
        """Run the architecture-search training loop.

        Per iteration: sample a single path through the super network and
        update the network weights on a training batch, then sample again and
        update the architecture parameters on a validation batch.  After each
        epoch the super network is decoded and logged, visdom plots are
        refreshed, and a checkpoint is saved when due or when the monitored
        validation metric improves.

        Args:
            fix_net_weights: when True, replaces the data loader with dummy
                batches and skips every update (debugging aid).

        Raises:
            ValueError: if CUDA is not available (CPU is unsupported here).
        """

        # have config valid_batch_size, and ignored drop_last.
        data_loader = self.run_manager.run_config.train_loader
        iter_per_epoch = len(data_loader)

        if fix_net_weights:  # used to debug
            data_loader = [(0, 0)] * iter_per_epoch
            print('Train Phase close for debug')

        # pay attention here, total_epochs include warmup epochs
        epoch_time = AverageMeter()
        end_epoch = time.time()
        # TODO : use start_epochs
        for epoch in range(self.start_epoch,
                           self.run_manager.run_config.epochs):
            self.logger.log('\n' + '-' * 30 +
                            'Train Epoch: {}'.format(epoch + 1) + '-' * 30 +
                            '\n',
                            mode='search')

            self.run_manager.scheduler.step(epoch)
            train_lr = self.run_manager.scheduler.get_lr()
            arch_lr = self.arch_optimizer.param_groups[0]['lr']
            # linearly anneal tau from tau_max down towards tau_min over epochs
            self.net.set_tau(self.arch_search_config.tau_max -
                             (self.arch_search_config.tau_max -
                              self.arch_search_config.tau_min) * (epoch) /
                             (self.run_manager.run_config.epochs))
            tau = self.net.get_tau()
            batch_time = AverageMeter()
            data_time = AverageMeter()
            losses = AverageMeter()
            accs = AverageMeter()
            mious = AverageMeter()
            fscores = AverageMeter()

            valid_losses = AverageMeter()
            valid_accs = AverageMeter()
            valid_mious = AverageMeter()
            valid_fscores = AverageMeter()

            self.net.train()

            epoch_str = 'epoch[{:03d}/{:03d}]'.format(
                epoch + 1, self.run_manager.run_config.epochs)
            time_left = epoch_time.average * (
                self.run_manager.run_config.epochs - epoch)
            common_log = '[*Train-Search* the {:}] Left={:} WLR={:} ALR={:} tau={:}'\
                .format(epoch_str, str(timedelta(seconds=time_left)) if epoch != 0 else None, train_lr, arch_lr, tau)
            self.logger.log(common_log, 'search')

            end = time.time()
            for i, (datas, targets) in enumerate(data_loader):
                if not fix_net_weights:
                    if torch.cuda.is_available():
                        datas = datas.to(self.run_manager.device,
                                         non_blocking=True)
                        targets = targets.to(self.run_manager.device,
                                             non_blocking=True)
                    else:
                        raise ValueError('do not support cpu version')
                    data_time.update(time.time() - end)

                    # get single_path in each iteration
                    _, network_index = self.net.get_network_arch_hardwts_with_constraint(
                    )
                    _, aspp_index = self.net.get_aspp_hardwts_index()
                    single_path = self.net.sample_single_path(
                        self.run_manager.run_config.nb_layers, aspp_index,
                        network_index)
                    logits = self.net.single_path_forward(datas, single_path)

                    # loss
                    loss = self.run_manager.criterion(logits, targets)
                    # metrics and update
                    evaluator = Evaluator(
                        self.run_manager.run_config.nb_classes)
                    evaluator.add_batch(targets, logits)
                    acc = evaluator.Pixel_Accuracy()
                    miou = evaluator.Mean_Intersection_over_Union()
                    fscore = evaluator.Fx_Score()
                    losses.update(loss.data.item(), datas.size(0))
                    accs.update(acc.item(), datas.size(0))
                    mious.update(miou.item(), datas.size(0))
                    fscores.update(fscore.item(), datas.size(0))

                    self.net.zero_grad()
                    loss.backward()

                    self.run_manager.optimizer.step()

                    valid_datas, valid_targets = self.run_manager.run_config.valid_next_batch
                    if torch.cuda.is_available():
                        valid_datas = valid_datas.to(self.run_manager.device,
                                                     non_blocking=True)
                        valid_targets = valid_targets.to(
                            self.run_manager.device, non_blocking=True)
                    else:
                        raise ValueError('do not support cpu version')

                    # re-sample a single path for the architecture update
                    _, network_index = self.net.get_network_arch_hardwts_with_constraint(
                    )
                    _, aspp_index = self.net.get_aspp_hardwts_index()
                    single_path = self.net.sample_single_path(
                        self.run_manager.run_config.nb_layers, aspp_index,
                        network_index)
                    logits = self.net.single_path_forward(
                        valid_datas, single_path)

                    loss = self.run_manager.criterion(logits, valid_targets)

                    # metrics and update
                    valid_evaluator = Evaluator(
                        self.run_manager.run_config.nb_classes)
                    valid_evaluator.add_batch(valid_targets, logits)
                    acc = valid_evaluator.Pixel_Accuracy()
                    miou = valid_evaluator.Mean_Intersection_over_Union()
                    fscore = valid_evaluator.Fx_Score()
                    # BUG FIX: weight validation meters by the validation batch
                    # size (the original used datas.size(0), the train batch)
                    valid_losses.update(loss.data.item(),
                                        valid_datas.size(0))
                    valid_accs.update(acc.item(), valid_datas.size(0))
                    valid_mious.update(miou.item(), valid_datas.size(0))
                    valid_fscores.update(fscore.item(), valid_datas.size(0))

                    self.net.zero_grad()
                    loss.backward()
                    self.arch_optimizer.step()

                    # batch_time of one iter of train and valid.
                    batch_time.update(time.time() - end)
                    end = time.time()
                    if (
                            i + 1
                    ) % self.run_manager.run_config.train_print_freq == 0 or i + 1 == iter_per_epoch:
                        Wstr = '|*Search*|' + time_string(
                        ) + '[{:}][iter{:03d}/{:03d}]'.format(
                            epoch_str, i + 1, iter_per_epoch)
                        Tstr = '|Time    | {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
                            batch_time=batch_time, data_time=data_time)
                        Bstr = '|Base    | [Loss {loss.val:.3f} ({loss.avg:.3f}) Accuracy {acc.val:.2f} ({acc.avg:.2f}) MIoU {miou.val:.2f} ({miou.avg:.2f}) F {fscore.val:.2f} ({fscore.avg:.2f})]'.format(
                            loss=losses, acc=accs, miou=mious, fscore=fscores)
                        Astr = '|Arch    | [Loss {loss.val:.3f} ({loss.avg:.3f}) Accuracy {acc.val:.2f} ({acc.avg:.2f}) MIoU {miou.val:.2f} ({miou.avg:.2f}) F {fscore.val:.2f} ({fscore.avg:.2f})]'.format(
                            loss=valid_losses,
                            acc=valid_accs,
                            miou=valid_mious,
                            fscore=valid_fscores)
                        self.logger.log(Wstr + '\n' + Tstr + '\n' + Bstr +
                                        '\n' + Astr,
                                        mode='search')

            _, network_index = self.net.get_network_arch_hardwts_with_constraint(
            )  # set self.hardwts again
            _, aspp_index = self.net.get_aspp_hardwts_index()
            single_path = self.net.sample_single_path(
                self.run_manager.run_config.nb_layers, aspp_index,
                network_index)
            cell_arch_entropy, network_arch_entropy, total_entropy = self.net.calculate_entropy(
                single_path)

            # update visdom
            if self.vis is not None:
                self.vis.visdom_update(epoch, 'loss',
                                       [losses.average, valid_losses.average])
                self.vis.visdom_update(epoch, 'accuracy',
                                       [accs.average, valid_accs.average])
                self.vis.visdom_update(epoch, 'miou',
                                       [mious.average, valid_mious.average])
                self.vis.visdom_update(
                    epoch, 'f1score', [fscores.average, valid_fscores.average])

                self.vis.visdom_update(epoch, 'cell_entropy',
                                       [cell_arch_entropy])
                self.vis.visdom_update(epoch, 'network_entropy',
                                       [network_arch_entropy])
                self.vis.visdom_update(epoch, 'entropy', [total_entropy])

            # update epoch_time
            epoch_time.update(time.time() - end_epoch)
            end_epoch = time.time()

            epoch_str = '{:03d}/{:03d}'.format(
                epoch + 1, self.run_manager.run_config.epochs)
            log = '[{:}] train :: loss={:.2f} accuracy={:.2f} miou={:.2f} f1score={:.2f}\n' \
                  '[{:}] valid :: loss={:.2f} accuracy={:.2f} miou={:.2f} f1score={:.2f}\n'.format(
                epoch_str, losses.average, accs.average, mious.average, fscores.average,
                epoch_str, valid_losses.average, valid_accs.average, valid_mious.average, valid_fscores.average
            )
            self.logger.log(log, mode='search')

            self.logger.log(
                '<<<---------->>> Super Network decoding <<<---------->>> ',
                mode='search')
            actual_path, cell_genotypes = self.net.network_cell_arch_decode()
            # translate each cell genotype's candidate indices into the
            # configured convolution operations for readable logging
            new_genotypes = []
            for _index, genotype in cell_genotypes:
                xlist = []
                print(_index, genotype)
                for edge_genotype in genotype:
                    for (node_str, select_index) in edge_genotype:
                        xlist.append((node_str, self.run_manager.run_config.
                                      conv_candidates[select_index]))
                new_genotypes.append((_index, xlist))
            log_str = 'The {:} decode network:\n' \
                      'actual_path = {:}\n' \
                      'genotype:'.format(epoch_str, actual_path)
            for _index, genotype in new_genotypes:
                log_str += 'index: {:} arch: {:}\n'.format(_index, genotype)
            self.logger.log(log_str, mode='network_space', display=False)

            # get best_monitor in valid phase.
            val_monitor_metric = get_monitor_metric(
                self.run_manager.monitor_metric, valid_losses.average,
                valid_accs.average, valid_mious.average, valid_fscores.average)
            is_best = self.run_manager.best_monitor < val_monitor_metric
            self.run_manager.best_monitor = max(self.run_manager.best_monitor,
                                                val_monitor_metric)

            # save a checkpoint at the configured frequency, at the final
            # epoch, and whenever the monitored metric improves
            if (epoch +
                    1) % self.run_manager.run_config.save_ckpt_freq == 0 or (
                        epoch +
                        1) == self.run_manager.run_config.epochs or is_best:
                checkpoint = {
                    'state_dict':
                    self.net.state_dict(),
                    'weight_optimizer':
                    self.run_manager.optimizer.state_dict(),
                    'weight_scheduler':
                    self.run_manager.scheduler.state_dict(),
                    'arch_optimizer':
                    self.arch_optimizer.state_dict(),
                    'best_monitor': (self.run_manager.monitor_metric,
                                     self.run_manager.best_monitor),
                    'warmup':
                    False,
                    'start_epochs':
                    epoch + 1,
                }
                checkpoint_arch = {
                    'actual_path': actual_path,
                    'cell_genotypes': cell_genotypes,
                }
                filename = self.logger.path(mode='search', is_best=is_best)
                filename_arch = self.logger.path(mode='arch', is_best=is_best)
                save_checkpoint(checkpoint,
                                filename,
                                self.logger,
                                mode='search')
                save_checkpoint(checkpoint_arch,
                                filename_arch,
                                self.logger,
                                mode='arch')
Example #7
0
    def warm_up(self, warmup_epochs):
        """Warm up the super-network weights before architecture search.

        Trains only the network weights for ``warmup_epochs`` epochs (no
        architecture-parameter updates), sampling a single path per
        iteration, and periodically saves a warm-up checkpoint.

        Args:
            warmup_epochs: number of warm-up epochs; <= 0 skips warm-up.

        Raises:
            ValueError: if CUDA is not available (CPU is unsupported here).
        """
        if warmup_epochs <= 0:
            self.logger.log('=> warmup close', mode='warm')
            return
        # set optimizer and scheduler in warm_up phase
        lr_max = self.arch_search_config.warmup_lr
        data_loader = self.run_manager.run_config.train_loader
        scheduler_params = self.run_manager.run_config.optimizer_config[
            'scheduler_params']
        optimizer_params = self.run_manager.run_config.optimizer_config[
            'optimizer_params']
        momentum, nesterov, weight_decay = optimizer_params[
            'momentum'], optimizer_params['nesterov'], optimizer_params[
                'weight_decay']
        eta_min = scheduler_params['eta_min']
        optimizer_warmup = torch.optim.SGD(self.net.weight_parameters(),
                                           lr_max,
                                           momentum,
                                           weight_decay=weight_decay,
                                           nesterov=nesterov)
        # NOTE(review): gradients below are applied via
        # self.run_manager.optimizer.step(), so optimizer_warmup and
        # lr_scheduler_warmup only schedule their own (otherwise unused)
        # learning rate -- confirm this is intended.
        lr_scheduler_warmup = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer_warmup, warmup_epochs, eta_min)
        iter_per_epoch = len(data_loader)

        self.logger.log('=> warmup begin', mode='warm')

        epoch_time = AverageMeter()
        end_epoch = time.time()
        for epoch in range(self.warmup_epoch, warmup_epochs):
            self.logger.log('\n' + '-' * 30 +
                            'Warmup Epoch: {}'.format(epoch + 1) + '-' * 30 +
                            '\n',
                            mode='warm')

            lr_scheduler_warmup.step(epoch)
            warmup_lr = lr_scheduler_warmup.get_lr()
            self.net.train()

            batch_time = AverageMeter()
            data_time = AverageMeter()
            losses = AverageMeter()
            accs = AverageMeter()
            mious = AverageMeter()
            fscores = AverageMeter()

            epoch_str = 'epoch[{:03d}/{:03d}]'.format(epoch + 1, warmup_epochs)
            time_left = epoch_time.average * (warmup_epochs - epoch)
            common_log = '[Warmup the {:}] Left={:} LR={:}'.format(
                epoch_str,
                str(timedelta(seconds=time_left)) if epoch != 0 else None,
                warmup_lr)
            self.logger.log(common_log, mode='warm')
            end = time.time()

            for i, (datas, targets) in enumerate(data_loader):

                if torch.cuda.is_available():
                    datas = datas.to(self.run_manager.device,
                                     non_blocking=True)
                    targets = targets.to(self.run_manager.device,
                                         non_blocking=True)
                else:
                    raise ValueError('do not support cpu version')
                data_time.update(time.time() - end)

                # get single_path in each iteration
                _, network_index = self.net.get_network_arch_hardwts_with_constraint(
                )
                _, aspp_index = self.net.get_aspp_hardwts_index()
                single_path = self.net.sample_single_path(
                    self.run_manager.run_config.nb_layers, aspp_index,
                    network_index)
                logits = self.net.single_path_forward(datas, single_path)

                # update without entropy reg in warm_up phase
                loss = self.run_manager.criterion(logits, targets)

                # measure metrics and update
                evaluator = Evaluator(self.run_manager.run_config.nb_classes)
                evaluator.add_batch(targets, logits)
                acc = evaluator.Pixel_Accuracy()
                miou = evaluator.Mean_Intersection_over_Union()
                fscore = evaluator.Fx_Score()
                losses.update(loss.data.item(), datas.size(0))
                # .item() for consistency with the other phases, which store
                # plain python numbers in the meters
                accs.update(acc.item(), datas.size(0))
                mious.update(miou.item(), datas.size(0))
                fscores.update(fscore.item(), datas.size(0))

                self.net.zero_grad()
                loss.backward()
                self.run_manager.optimizer.step()

                batch_time.update(time.time() - end)
                end = time.time()
                if (
                        i + 1
                ) % self.run_manager.run_config.train_print_freq == 0 or i + 1 == iter_per_epoch:
                    Wstr = '|*WARM-UP*|' + time_string(
                    ) + '[{:}][iter{:03d}/{:03d}]'.format(
                        epoch_str, i + 1, iter_per_epoch)
                    Tstr = '|Time     | [{batch_time.val:.2f} ({batch_time.avg:.2f})  Data {data_time.val:.2f} ({data_time.avg:.2f})]'.format(
                        batch_time=batch_time, data_time=data_time)
                    Bstr = '|Base     | [Loss {loss.val:.3f} ({loss.avg:.3f})  Accuracy {acc.val:.2f} ({acc.avg:.2f}) MIoU {miou.val:.2f} ({miou.avg:.2f}) F {fscore.val:.2f} ({fscore.avg:.2f})]'\
                        .format(loss=losses, acc=accs, miou=mious, fscore=fscores)
                    self.logger.log(Wstr + '\n' + Tstr + '\n' + Bstr, 'warm')

            epoch_time.update(time.time() - end_epoch)
            end_epoch = time.time()

            # continue warmup phrase
            self.warmup = epoch + 1 < warmup_epochs
            self.warmup_epoch = self.warmup_epoch + 1
            # To save checkpoint in warmup phase at specific frequency.
            if (epoch +
                    1) % self.run_manager.run_config.save_ckpt_freq == 0 or (
                        epoch + 1) == warmup_epochs:
                state_dict = self.net.state_dict()
                checkpoint = {
                    'state_dict': state_dict,
                    'weight_optimizer':
                    self.run_manager.optimizer.state_dict(),
                    # BUG FIX: the original saved the optimizer state dict a
                    # second time under 'weight_scheduler'; save the warm-up
                    # scheduler's state instead so warm-up can resume.
                    'weight_scheduler':
                    lr_scheduler_warmup.state_dict(),
                    'warmup': self.warmup,
                    'warmup_epoch': epoch + 1,
                }
                filename = self.logger.path(mode='warm', is_best=False)
                save_path = save_checkpoint(checkpoint,
                                            filename,
                                            self.logger,
                                            mode='warm')
Example #8
0
    def validate(self, epoch=None, is_test=False, use_train_mode=False):
        """Run one full evaluation pass over the validation (or test) split.

        Args:
            epoch: 0-based epoch index; required when ``is_test`` is False
                (used only to build the log header).
            is_test: if True, evaluate on ``run_config.test_loader`` (which
                also yields filenames) and dump predictions via
                ``self._save_pred``; otherwise use ``run_config.valid_loader``.
            use_train_mode: keep the model in train mode (e.g. to retain
                BatchNorm batch statistics) instead of eval mode.

        Returns:
            Tuple ``(loss_avg, acc_avg, miou_avg, fscore_avg)`` of epoch
            averages.

        Raises:
            ValueError: when CUDA is not available (CPU is unsupported).
        """
        if is_test:
            data_loader = self.run_config.test_loader
            epoch_str = None
            self.logger.log('\n' + '-' * 30 + 'TESTING PHASE' + '-' * 30 +
                            '\n',
                            mode='valid')
        else:
            data_loader = self.run_config.valid_loader
            epoch_str = 'epoch[{:03d}/{:03d}]'.format(epoch + 1,
                                                      self.run_config.epochs)
            self.logger.log('\n' + '-' * 30 +
                            'Valid epoch: {:}'.format(epoch_str) + '-' * 30 +
                            '\n',
                            mode='valid')

        model = self.model
        if use_train_mode:
            model.train()
        else:
            model.eval()

        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        mious = AverageMeter()
        fscores = AverageMeter()
        accs = AverageMeter()
        end0 = time.time()

        with torch.no_grad():
            for batch in data_loader:
                # The test loader additionally yields filenames so that
                # per-image predictions can be written to disk.
                if is_test:
                    (datas, targets), filenames = batch
                else:
                    datas, targets = batch
                    filenames = None
                if not torch.cuda.is_available():
                    raise ValueError('do not support cpu version')
                datas = datas.to(self.device, non_blocking=True)
                targets = targets.to(self.device, non_blocking=True)
                data_time.update(time.time() - end0)

                # Plain forward pass; no architecture search involved here.
                logits = self.model(datas)
                if is_test:
                    self._save_pred(logits, filenames)
                loss = self.criterion(logits, targets)

                self._accumulate_eval_metrics(logits, targets, loss,
                                              datas.size(0), losses, mious,
                                              fscores, accs)

                # duration
                batch_time.update(time.time() - end0)
                end0 = time.time()

            # Log strings intentionally differ between the two phases
            # (column widths and MIoU precision) to match existing logs.
            if is_test:
                Wstr = '|*TEST*|' + time_string()
                Tstr = '|Time  | [{batch_time.val:.2f} ({batch_time.avg:.2f})  Data {data_time.val:.2f} ({data_time.avg:.2f})]'.format(
                    batch_time=batch_time, data_time=data_time)
                Bstr = '|Base  | [Loss {loss.val:.3f} ({loss.avg:.3f})  Accuracy {acc.val:.2f} ({acc.avg:.2f}) MIoU {miou.val:.5f} ({miou.avg:.5f}) F {fscore.val:.2f} ({fscore.avg:.2f})]'.format(
                    loss=losses, acc=accs, miou=mious, fscore=fscores)
                self.logger.log(Wstr + '\n' + Tstr + '\n' + Bstr, 'test')
            else:
                Wstr = '|*VALID*|' + time_string() + epoch_str
                Tstr = '|Time   | [{batch_time.val:.2f} ({batch_time.avg:.2f})  Data {data_time.val:.2f} ({data_time.avg:.2f})]'.format(
                    batch_time=batch_time, data_time=data_time)
                Bstr = '|Base   | [Loss {loss.val:.3f} ({loss.avg:.3f})  Accuracy {acc.val:.2f} ({acc.avg:.2f}) MIoU {miou.val:.2f} ({miou.avg:.2f}) F {fscore.val:.2f} ({fscore.avg:.2f})]'.format(
                    loss=losses, acc=accs, miou=mious, fscore=fscores)
                self.logger.log(Wstr + '\n' + Tstr + '\n' + Bstr, 'valid')

        return losses.avg, accs.avg, mious.avg, fscores.avg

    def _accumulate_eval_metrics(self, logits, targets, loss, n, losses,
                                 mious, fscores, accs):
        """Compute segmentation metrics for one batch and update the meters.

        Shared by the test and validation paths of :meth:`validate` (the two
        loops were previously duplicated).  ``n`` is the batch size used as
        the meter weight.
        """
        evaluator = Evaluator(self.run_config.nb_classes)
        evaluator.add_batch(targets, logits)
        miou = evaluator.Mean_Intersection_over_Union()
        fscore = evaluator.Fx_Score()
        acc = evaluator.Pixel_Accuracy()
        # loss.item() replaces deprecated loss.data.item().
        losses.update(loss.item(), n)
        mious.update(miou.item(), n)
        fscores.update(fscore.item(), n)
        accs.update(acc.item(), n)
    # Bi-level (DARTS-style) step: update the architecture parameters on the
    # search/validation minibatch before updating the network weights.
    architect.step(input, target, input_search, target_search, lr, optimizer, unrolled=args.unrolled)

    # Standard weight update on the training minibatch.
    optimizer.zero_grad()
    logits = model(input)
    logits = logits.to(device)
    loss = criterion(logits, target)
    evaluater = Evaluator(dataset_classes)
    loss.backward()
    # NOTE(review): nn.utils.clip_grad_norm is deprecated in modern PyTorch
    # in favor of the in-place clip_grad_norm_ — confirm target torch version.
    nn.utils.clip_grad_norm(model.parameters(), args.grad_clip)
    optimizer.step()

    #prec = utils.Accuracy(logits, target)
    #prec1 = utils.MIoU(logits, target, dataset_classes)
    # Per-batch segmentation metrics from the confusion-matrix evaluator.
    evaluater.add_batch(target,logits)
    miou = evaluater.Mean_Intersection_over_Union()
    fscore = evaluater.Fx_Score()
    acc = evaluater.Pixel_Accuracy()

    # Accumulate running averages weighted by batch size n.
    objs.update(loss.item(), n)
    MIoUs.update(miou.item(), n)
    fscores.update(fscore.item(), n)
    accs.update(acc.item(), n)

    if step % args.report_freq == 0:
      logging.info('train %03d %e %f %f %f', step, objs.avg, accs.avg, fscores.avg, MIoUs.avg)


  # Epoch averages: (accuracy, loss, F-score, MIoU).
  return accs.avg, objs.avg, fscores.avg, MIoUs.avg


def infer(valid_queue, model, criterion):