def validate(val_loader,
             model,
             criterion,
             epoch,
             log_freq=1,
             print_sum=True,
             device=None,
             stereo=True):

    losses = AverageMeter()

    # set the model to evaluation mode
    model.eval()

    with torch.no_grad():
        epoch_time = time.time()
        end = time.time()
        for idx, (batch_images, batch_poses) in enumerate(val_loader):
            data_time = time.time() - end

            if stereo:
                batch_images = [x.to(device) for x in batch_images]
                batch_poses = [x.to(device) for x in batch_poses]
            else:
                batch_images = batch_images.to(device)
                batch_poses = batch_poses.to(device)

            # compute model output
            out = model(batch_images)
            loss = criterion(out, batch_poses)

            losses.update(
                loss.item(),
                len(batch_images) * batch_images[0].size(0)
                if stereo else batch_images.size(0))

            batch_time = time.time() - end
            end = time.time()

            if log_freq != 0 and idx % log_freq == 0:
                print('Val Epoch: {}\t'
                      'Time: {batch_time:.3f}\t'
                      'Data Time: {data_time:.3f}\t'
                      'Loss: {losses.val:.3f}\t'
                      'Avg Loss: {losses.avg:.3f}'.format(
                          epoch,
                          batch_time=batch_time,
                          data_time=data_time,
                          losses=losses))

            # if idx == 0:
            #     break

    if print_sum:
        print(
            'Epoch: [{}]\tValidation Loss: {:.3f}\tEpoch time: {:.3f}'.format(
                epoch, losses.avg, (time.time() - epoch_time)))
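
All of these snippets rely on an AverageMeter helper that is not shown on this page. Below is a minimal sketch in the style of the PyTorch ImageNet example; the `.average` and `.mean` aliases are assumptions made only because some of the snippets above access those names, and the real project classes may differ.

class AverageMeter(object):
    """Computes and stores the running average and the current value."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    @property
    def average(self):  # alias used by some snippets below (assumption)
        return self.avg

    @property
    def mean(self):  # alias used by some snippets below (assumption)
        return self.avg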
Example #2
 def train(self):
     epoch_time = AverageMeter()
     end = time.time()
     self.model.train()
     for epoch in range(self.start_epoch, self.run_config.epochs):
         self.logger.log('\n' + '-' * 30 +
                         'Train epoch: {}'.format(epoch + 1) + '-' * 30 +
                         '\n',
                         mode='retrain')
         epoch_str = 'epoch[{:03d}/{:03d}]'.format(epoch + 1,
                                                   self.run_config.epochs)
         self.scheduler.step(epoch)
         train_lr = self.scheduler.get_lr()
         time_left = epoch_time.average * (self.run_config.epochs - epoch)
         common_log = '[Train the {:}] Left={:} LR={:}'.format(
             epoch_str,
             str(timedelta(seconds=time_left)) if epoch != 0 else None,
             train_lr)
         self.logger.log(common_log, 'retrain')
         # the train log has already been written in train_one_epoch
         loss, acc, miou, fscore = self.train_one_epoch(epoch)
         epoch_time.update(time.time() - end)
         end = time.time()
         # perform validation at the end of each epoch.
         val_loss, val_acc, val_miou, val_fscore = self.validate(
             epoch, is_test=False, use_train_mode=False)
         val_monitor_metric = get_monitor_metric(self.monitor_metric,
                                                 val_loss, val_acc,
                                                 val_miou, val_fscore)
         is_best = val_monitor_metric > self.best_monitor
         self.best_monitor = max(self.best_monitor, val_monitor_metric)
         # update visdom
         if self.vis is not None:
             self.vis.visdom_update(epoch, 'loss', [loss, val_loss])
             self.vis.visdom_update(epoch, 'accuracy', [acc, val_acc])
             self.vis.visdom_update(epoch, 'miou', [miou, val_miou])
             self.vis.visdom_update(epoch, 'f1score', [fscore, val_fscore])
         # save checkpoint
         if (epoch + 1) % self.run_config.save_ckpt_freq == 0 or (
                 epoch + 1) == self.run_config.epochs or is_best:
             checkpoint = {
                 'state_dict': self.model.state_dict(),
                 'weight_optimizer': self.optimizer.state_dict(),
                 'weight_scheduler': self.scheduler.state_dict(),
                 'best_monitor': (self.monitor_metric, self.best_monitor),
                 'start_epoch': epoch + 1,
                 'actual_path': self.run_config.actual_path,
                 'cell_genotypes': self.run_config.cell_genotypes
             }
             filename = self.logger.path(mode='retrain', is_best=is_best)
             save_checkpoint(checkpoint,
                             filename,
                             self.logger,
                             mode='retrain')
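
The get_monitor_metric helper is not shown in these snippets. A plausible sketch is below, given only that the callers treat larger returned values as better; the project's actual implementation may differ.

def get_monitor_metric(monitor_metric, loss, acc, miou, fscore):
    # select which validation statistic drives checkpointing
    if monitor_metric == 'loss':
        return -loss  # lower loss is better, so negate for max-style comparison
    return {'acc': acc, 'miou': miou, 'fscore': fscore}[monitor_metric]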
Example #3
def test(model,
         loader_test,
         data_length,
         device,
         criterion,
         batch_size,
         print_logger,
         step,
         use_top5=False,
         verbose=False):

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    t1 = time.time()
    with torch.no_grad():
        # switch to evaluate mode
        model.eval()
        end = time.time()

        for i, data in enumerate(loader_test):
            inputs = data[0]["data"].to(device)
            targets = data[0]["label"].squeeze().long().to(device)

            # for i, (inputs, targets) in enumerate(loader_test, 1):
            #     inputs = inputs.to(device)
            #     targets = targets.to(device)

            # compute output
            output = model(inputs)
            loss = criterion(output, targets)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output, targets, topk=(1, 5))
            losses.update(loss.item(), batch_size)
            top1.update(prec1[0], batch_size)
            top5.update(prec5[0], batch_size)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            # plot progress

        # measure elapsed time
    t2 = time.time()

    print_logger.info('Test Step [{0}]: '
                      'Loss {loss.avg:.4f} '
                      'Prec@1(1,5) {top1.avg:.2f}, {top5.avg:.2f} '
                      'Time {time}'.format(step,
                                           loss=losses,
                                           top1=top1,
                                           top5=top5,
                                           time=t2 - t1))

    loader_test.reset()
    return top1.avg
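
The accuracy() helper assumed by the classification examples is the usual top-k routine from the PyTorch ImageNet example; the sketch below is written to stay consistent with the prec1[0] / prec1.item() usage above, but the project's own version is not shown here.

import torch

def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k."""
    maxk = max(topk)
    batch_size = target.size(0)

    # top-k predicted class indices, shape (maxk, batch)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # 1-element tensor, so both res[0][0] and res[0].item() work
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res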
Example #4
def eval_model(val_loader, model, criterion, eval_metric, epoch, use_cuda):
    losses = AverageMeter()
    y_pred = []
    y_label = []
    model.train(False)
    torch.set_grad_enabled(False)
    for i, data in enumerate(val_loader):
        xi, xv, y = data[0], data[1], data[2]
        if use_cuda:
            xi, xv, y = xi.cuda(), xv.cuda(), y.cuda()
        outputs = model(xi, xv)
        loss = criterion(outputs, y)
        pred = torch.sigmoid(outputs).cpu()
        y_pred.extend(pred.data.numpy())
        y_label.extend(y.data.numpy())
        losses.update(loss.item(), y.shape[0])
    total_metric = eval_metric(y_label, y_pred)
    return losses.avg, total_metric
Example #5
def train_epoch(train_loader, model, criterion, optimizer, epoch, use_cuda):
    losses = AverageMeter()
    model.train(True)
    torch.set_grad_enabled(True)
    for i, data in enumerate(train_loader):
        xi, xv, y = data[0], data[1], data[2]
        if use_cuda:
            xi, xv, y = xi.cuda(), xv.cuda(), y.cuda()
        optimizer.zero_grad()
        outputs = model(xi, xv)
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()
        losses.update(loss.item(), y.shape[0])

        progress_bar(i, len(train_loader),
                     'batch {}, train loss {:.5f}'.format(i, losses.avg))
    logging.info('Epoch: [{0}]\t Loss {loss.avg:.4f}\t'.format(epoch,
                                                               loss=losses))
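
A hedged usage sketch wiring together train_epoch and eval_model from the two examples above. The model, criterion, optimizer, loaders, use_cuda flag, and num_epochs are assumed to be defined elsewhere; roc_auc_score is just one plausible eval_metric for the sigmoid outputs used in eval_model.

import logging
from sklearn.metrics import roc_auc_score

best_metric = 0.0
for epoch in range(num_epochs):
    train_epoch(train_loader, model, criterion, optimizer, epoch, use_cuda)
    val_loss, val_auc = eval_model(val_loader, model, criterion,
                                   roc_auc_score, epoch, use_cuda)
    best_metric = max(best_metric, val_auc)
    logging.info('Epoch [{}] val loss {:.4f} AUC {:.4f}'.format(
        epoch, val_loss, val_auc))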
Example #6
    def train(self, fix_net_weights=False):

        # the run config provides valid_batch_size; drop_last is ignored.
        data_loader = self.run_manager.run_config.train_loader
        iter_per_epoch = len(data_loader)
        total_iteration = iter_per_epoch * self.run_manager.run_config.epochs

        if fix_net_weights:  # used for debugging
            data_loader = [(0, 0)] * iter_per_epoch
            print('Train Phase close for debug')

        # frequency and number of arch_parameter updates in each iteration.
        #update_schedule = self.arch_search_config.get_update_schedule(iter_per_epoch)

        # note: the total epoch count includes the warmup epochs
        epoch_time = AverageMeter()
        end_epoch = time.time()
        # TODO : use start_epochs
        for epoch in range(self.start_epoch,
                           self.run_manager.run_config.epochs):
            self.logger.log('\n' + '-' * 30 +
                            'Train Epoch: {}'.format(epoch + 1) + '-' * 30 +
                            '\n',
                            mode='search')

            self.run_manager.scheduler.step(epoch)
            train_lr = self.run_manager.scheduler.get_lr()
            arch_lr = self.arch_optimizer.param_groups[0]['lr']
            self.net.set_tau(self.arch_search_config.tau_max -
                             (self.arch_search_config.tau_max -
                              self.arch_search_config.tau_min) * (epoch) /
                             (self.run_manager.run_config.epochs))
            tau = self.net.get_tau()
            batch_time = AverageMeter()
            data_time = AverageMeter()
            losses = AverageMeter()
            accs = AverageMeter()
            mious = AverageMeter()
            fscores = AverageMeter()

            #valid_data_time = AverageMeter()
            valid_losses = AverageMeter()
            valid_accs = AverageMeter()
            valid_mious = AverageMeter()
            valid_fscores = AverageMeter()

            self.net.train()

            epoch_str = 'epoch[{:03d}/{:03d}]'.format(
                epoch + 1, self.run_manager.run_config.epochs)
            time_left = epoch_time.average * (
                self.run_manager.run_config.epochs - epoch)
            common_log = '[*Train-Search* the {:}] Left={:} WLR={:} ALR={:} tau={:}'\
                .format(epoch_str, str(timedelta(seconds=time_left)) if epoch != 0 else None, train_lr, arch_lr, tau)
            self.logger.log(common_log, 'search')

            end = time.time()
            for i, (datas, targets) in enumerate(data_loader):
                #if i == 29: break
                if not fix_net_weights:
                    if torch.cuda.is_available():
                        datas = datas.to(self.run_manager.device,
                                         non_blocking=True)
                        targets = targets.to(self.run_manager.device,
                                             non_blocking=True)
                    else:
                        raise ValueError('do not support cpu version')
                    data_time.update(time.time() - end)

                    # get single_path in each iteration
                    _, network_index = self.net.get_network_arch_hardwts_with_constraint(
                    )
                    _, aspp_index = self.net.get_aspp_hardwts_index()
                    single_path = self.net.sample_single_path(
                        self.run_manager.run_config.nb_layers, aspp_index,
                        network_index)
                    logits = self.net.single_path_forward(datas, single_path)

                    # loss
                    loss = self.run_manager.criterion(logits, targets)
                    # metrics and update
                    evaluator = Evaluator(
                        self.run_manager.run_config.nb_classes)
                    evaluator.add_batch(targets, logits)
                    acc = evaluator.Pixel_Accuracy()
                    miou = evaluator.Mean_Intersection_over_Union()
                    fscore = evaluator.Fx_Score()
                    losses.update(loss.data.item(), datas.size(0))
                    accs.update(acc.item(), datas.size(0))
                    mious.update(miou.item(), datas.size(0))
                    fscores.update(fscore.item(), datas.size(0))

                    self.net.zero_grad()
                    loss.backward()

                    self.run_manager.optimizer.step()

                    #end_valid = time.time()
                    valid_datas, valid_targets = self.run_manager.run_config.valid_next_batch
                    if torch.cuda.is_available():
                        valid_datas = valid_datas.to(self.run_manager.device,
                                                     non_blocking=True)
                        valid_targets = valid_targets.to(
                            self.run_manager.device, non_blocking=True)
                    else:
                        raise ValueError('do not support cpu version')

                    _, network_index = self.net.get_network_arch_hardwts_with_constraint(
                    )
                    _, aspp_index = self.net.get_aspp_hardwts_index()
                    single_path = self.net.sample_single_path(
                        self.run_manager.run_config.nb_layers, aspp_index,
                        network_index)
                    logits = self.net.single_path_forward(
                        valid_datas, single_path)

                    loss = self.run_manager.criterion(logits, valid_targets)

                    # metrics and update
                    valid_evaluator = Evaluator(
                        self.run_manager.run_config.nb_classes)
                    valid_evaluator.add_batch(valid_targets, logits)
                    acc = valid_evaluator.Pixel_Accuracy()
                    miou = valid_evaluator.Mean_Intersection_over_Union()
                    fscore = valid_evaluator.Fx_Score()
                    valid_losses.update(loss.data.item(), valid_datas.size(0))
                    valid_accs.update(acc.item(), valid_datas.size(0))
                    valid_mious.update(miou.item(), valid_datas.size(0))
                    valid_fscores.update(fscore.item(), valid_datas.size(0))

                    self.net.zero_grad()
                    loss.backward()
                    #if (i+1) % 5 == 0:
                    #    print('network_arch_parameters.grad in search phase:\n',
                    #                         self.net.network_arch_parameters.grad)
                    self.arch_optimizer.step()

                    # batch_time of one iter of train and valid.
                    batch_time.update(time.time() - end)
                    end = time.time()
                    if (
                            i + 1
                    ) % self.run_manager.run_config.train_print_freq == 0 or i + 1 == iter_per_epoch:
                        Wstr = '|*Search*|' + time_string(
                        ) + '[{:}][iter{:03d}/{:03d}]'.format(
                            epoch_str, i + 1, iter_per_epoch)
                        Tstr = '|Time    | {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
                            batch_time=batch_time, data_time=data_time)
                        Bstr = '|Base    | [Loss {loss.val:.3f} ({loss.avg:.3f}) Accuracy {acc.val:.2f} ({acc.avg:.2f}) MIoU {miou.val:.2f} ({miou.avg:.2f}) F {fscore.val:.2f} ({fscore.avg:.2f})]'.format(
                            loss=losses, acc=accs, miou=mious, fscore=fscores)
                        Astr = '|Arch    | [Loss {loss.val:.3f} ({loss.avg:.3f}) Accuracy {acc.val:.2f} ({acc.avg:.2f}) MIoU {miou.val:.2f} ({miou.avg:.2f}) F {fscore.val:.2f} ({fscore.avg:.2f})]'.format(
                            loss=valid_losses,
                            acc=valid_accs,
                            miou=valid_mious,
                            fscore=valid_fscores)
                        self.logger.log(Wstr + '\n' + Tstr + '\n' + Bstr +
                                        '\n' + Astr,
                                        mode='search')

            _, network_index = self.net.get_network_arch_hardwts_with_constraint(
            )  # set self.hardwts again
            _, aspp_index = self.net.get_aspp_hardwts_index()
            single_path = self.net.sample_single_path(
                self.run_manager.run_config.nb_layers, aspp_index,
                network_index)
            cell_arch_entropy, network_arch_entropy, total_entropy = self.net.calculate_entropy(
                single_path)

            # update visdom
            if self.vis is not None:
                self.vis.visdom_update(epoch, 'loss',
                                       [losses.average, valid_losses.average])
                self.vis.visdom_update(epoch, 'accuracy',
                                       [accs.average, valid_accs.average])
                self.vis.visdom_update(epoch, 'miou',
                                       [mious.average, valid_mious.average])
                self.vis.visdom_update(
                    epoch, 'f1score', [fscores.average, valid_fscores.average])

                self.vis.visdom_update(epoch, 'cell_entropy',
                                       [cell_arch_entropy])
                self.vis.visdom_update(epoch, 'network_entropy',
                                       [network_arch_entropy])
                self.vis.visdom_update(epoch, 'entropy', [total_entropy])

            #torch.cuda.empty_cache()
            # update epoch_time
            epoch_time.update(time.time() - end_epoch)
            end_epoch = time.time()

            epoch_str = '{:03d}/{:03d}'.format(
                epoch + 1, self.run_manager.run_config.epochs)
            log = '[{:}] train :: loss={:.2f} accuracy={:.2f} miou={:.2f} f1score={:.2f}\n' \
                  '[{:}] valid :: loss={:.2f} accuracy={:.2f} miou={:.2f} f1score={:.2f}\n'.format(
                epoch_str, losses.average, accs.average, mious.average, fscores.average,
                epoch_str, valid_losses.average, valid_accs.average, valid_mious.average, valid_fscores.average
            )
            self.logger.log(log, mode='search')

            self.logger.log(
                '<<<---------->>> Super Network decoding <<<---------->>> ',
                mode='search')
            actual_path, cell_genotypes = self.net.network_cell_arch_decode()
            #print(cell_genotypes)
            new_genotypes = []
            for _index, genotype in cell_genotypes:
                xlist = []
                print(_index, genotype)
                for edge_genotype in genotype:
                    for (node_str, select_index) in edge_genotype:
                        xlist.append((node_str, self.run_manager.run_config.
                                      conv_candidates[select_index]))
                new_genotypes.append((_index, xlist))
            log_str = 'The {:} decode network:\n' \
                      'actual_path = {:}\n' \
                      'genotype:\n'.format(epoch_str, actual_path)
            for _index, genotype in new_genotypes:
                log_str += 'index: {:} arch: {:}\n'.format(_index, genotype)
            self.logger.log(log_str, mode='network_space', display=False)

            # TODO: perform save the best network ckpt
            # 1. save network_arch_parameters and cell_arch_parameters
            # 2. save weight_parameters
            # 3. weight_optimizer.state_dict
            # 4. arch_optimizer.state_dict
            # 5. training process
            # 6. monitor_metric and the best_value
            # get best_monitor in valid phase.
            val_monitor_metric = get_monitor_metric(
                self.run_manager.monitor_metric, valid_losses.average,
                valid_accs.average, valid_mious.average, valid_fscores.average)
            is_best = self.run_manager.best_monitor < val_monitor_metric
            self.run_manager.best_monitor = max(self.run_manager.best_monitor,
                                                val_monitor_metric)
            # 1. if is_best : save_current_ckpt
            # 2. if can be divided : save_current_ckpt

            #self.run_manager.save_model(epoch, {
            #    'arch_optimizer': self.arch_optimizer.state_dict(),
            #}, is_best=True, checkpoint_file_name=None)
            # TODO: have modification on checkpoint_save semantics
            if (epoch +
                    1) % self.run_manager.run_config.save_ckpt_freq == 0 or (
                        epoch +
                        1) == self.run_manager.run_config.epochs or is_best:
                checkpoint = {
                    'state_dict':
                    self.net.state_dict(),
                    'weight_optimizer':
                    self.run_manager.optimizer.state_dict(),
                    'weight_scheduler':
                    self.run_manager.scheduler.state_dict(),
                    'arch_optimizer':
                    self.arch_optimizer.state_dict(),
                    'best_monitor': (self.run_manager.monitor_metric,
                                     self.run_manager.best_monitor),
                    'warmup':
                    False,
                    'start_epochs':
                    epoch + 1,
                }
                checkpoint_arch = {
                    'actual_path': actual_path,
                    'cell_genotypes': cell_genotypes,
                }
                filename = self.logger.path(mode='search', is_best=is_best)
                filename_arch = self.logger.path(mode='arch', is_best=is_best)
                save_checkpoint(checkpoint,
                                filename,
                                self.logger,
                                mode='search')
                save_checkpoint(checkpoint_arch,
                                filename_arch,
                                self.logger,
                                mode='arch')
Example #7
    def warm_up(self, warmup_epochs):
        if warmup_epochs <= 0:
            self.logger.log('=> warmup close', mode='warm')
            #print('\twarmup close')
            return
        # set optimizer and scheduler in warm_up phase
        lr_max = self.arch_search_config.warmup_lr
        data_loader = self.run_manager.run_config.train_loader
        scheduler_params = self.run_manager.run_config.optimizer_config[
            'scheduler_params']
        optimizer_params = self.run_manager.run_config.optimizer_config[
            'optimizer_params']
        momentum, nesterov, weight_decay = optimizer_params[
            'momentum'], optimizer_params['nesterov'], optimizer_params[
                'weight_decay']
        eta_min = scheduler_params['eta_min']
        optimizer_warmup = torch.optim.SGD(self.net.weight_parameters(),
                                           lr_max,
                                           momentum,
                                           weight_decay=weight_decay,
                                           nesterov=nesterov)
        # set initial_learning_rate in weight_optimizer
        #for param_groups in self.run_manager.optimizer.param_groups:
        #    param_groups['lr'] = lr_max
        lr_scheduler_warmup = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer_warmup, warmup_epochs, eta_min)
        iter_per_epoch = len(data_loader)
        total_iteration = warmup_epochs * iter_per_epoch

        self.logger.log('=> warmup begin', mode='warm')

        epoch_time = AverageMeter()
        end_epoch = time.time()
        for epoch in range(self.warmup_epoch, warmup_epochs):
            self.logger.log('\n' + '-' * 30 +
                            'Warmup Epoch: {}'.format(epoch + 1) + '-' * 30 +
                            '\n',
                            mode='warm')

            lr_scheduler_warmup.step(epoch)
            warmup_lr = lr_scheduler_warmup.get_lr()
            self.net.train()

            batch_time = AverageMeter()
            data_time = AverageMeter()
            losses = AverageMeter()
            accs = AverageMeter()
            mious = AverageMeter()
            fscores = AverageMeter()

            epoch_str = 'epoch[{:03d}/{:03d}]'.format(epoch + 1, warmup_epochs)
            time_left = epoch_time.average * (warmup_epochs - epoch)
            common_log = '[Warmup the {:}] Left={:} LR={:}'.format(
                epoch_str,
                str(timedelta(seconds=time_left)) if epoch != 0 else None,
                warmup_lr)
            self.logger.log(common_log, mode='warm')
            end = time.time()

            for i, (datas, targets) in enumerate(data_loader):

                #if i == 29: break

                if torch.cuda.is_available():
                    datas = datas.to(self.run_manager.device,
                                     non_blocking=True)
                    targets = targets.to(self.run_manager.device,
                                         non_blocking=True)
                else:
                    raise ValueError('do not support cpu version')
                data_time.update(time.time() - end)

                # get single_path in each iteration
                _, network_index = self.net.get_network_arch_hardwts_with_constraint(
                )
                _, aspp_index = self.net.get_aspp_hardwts_index()
                single_path = self.net.sample_single_path(
                    self.run_manager.run_config.nb_layers, aspp_index,
                    network_index)
                logits = self.net.single_path_forward(datas, single_path)

                # update without entropy reg in warm_up phase
                loss = self.run_manager.criterion(logits, targets)

                # measure metrics and update
                evaluator = Evaluator(self.run_manager.run_config.nb_classes)
                evaluator.add_batch(targets, logits)
                acc = evaluator.Pixel_Accuracy()
                miou = evaluator.Mean_Intersection_over_Union()
                fscore = evaluator.Fx_Score()
                losses.update(loss.data.item(), datas.size(0))
                accs.update(acc, datas.size(0))
                mious.update(miou, datas.size(0))
                fscores.update(fscore, datas.size(0))

                self.net.zero_grad()
                loss.backward()
                self.run_manager.optimizer.step()

                batch_time.update(time.time() - end)
                end = time.time()
                if (
                        i + 1
                ) % self.run_manager.run_config.train_print_freq == 0 or i + 1 == iter_per_epoch:
                    Wstr = '|*WARM-UP*|' + time_string(
                    ) + '[{:}][iter{:03d}/{:03d}]'.format(
                        epoch_str, i + 1, iter_per_epoch)
                    Tstr = '|Time     | [{batch_time.val:.2f} ({batch_time.avg:.2f})  Data {data_time.val:.2f} ({data_time.avg:.2f})]'.format(
                        batch_time=batch_time, data_time=data_time)
                    Bstr = '|Base     | [Loss {loss.val:.3f} ({loss.avg:.3f})  Accuracy {acc.val:.2f} ({acc.avg:.2f}) MIoU {miou.val:.2f} ({miou.avg:.2f}) F {fscore.val:.2f} ({fscore.avg:.2f})]'\
                        .format(loss=losses, acc=accs, miou=mious, fscore=fscores)
                    self.logger.log(Wstr + '\n' + Tstr + '\n' + Bstr, 'warm')

            #torch.cuda.empty_cache()
            epoch_time.update(time.time() - end_epoch)
            end_epoch = time.time()
            '''
            # TODO: whether to perform validation after each epoch in the warmup phase?
            valid_loss, valid_acc, valid_miou, valid_fscore = self.validate()
            valid_log = 'Warmup Valid\t[{0}/{1}]\tLoss\t{2:.6f}\tAcc\t{3:6.4f}\tMIoU\t{4:6.4f}\tF\t{5:6.4f}'\
                .format(epoch+1, warmup_epochs, valid_loss, valid_acc, valid_miou, valid_fscore)
                        #'\tflops\t{6:}M\tparams{7:}M'\
            valid_log += 'Train\t[{0}/{1}]\tLoss\t{2:.6f}\tAcc\t{3:6.4f}\tMIoU\t{4:6.4f}\tFscore\t{5:6.4f}'
            self.run_manager.write_log(valid_log, 'valid')
            '''

            # continue the warmup phase
            self.warmup = epoch + 1 < warmup_epochs
            self.warmup_epoch = self.warmup_epoch + 1
            #self.start_epoch = self.warmup_epoch
            # To save checkpoint in warmup phase at specific frequency.
            if (epoch +
                    1) % self.run_manager.run_config.save_ckpt_freq == 0 or (
                        epoch + 1) == warmup_epochs:
                state_dict = self.net.state_dict()
                # remove architecture parameters, because arch_parameters are not updated in the warm_up phase.
                #for key in list(state_dict.keys()):
                #    if 'cell_arch_parameters' in key or 'network_arch_parameters' in key or 'aspp_arch_parameters' in key:
                #        state_dict.pop(key)
                checkpoint = {
                    'state_dict': state_dict,
                    'weight_optimizer':
                    self.run_manager.optimizer.state_dict(),
                    'weight_scheduler':
                    self.run_manager.scheduler.state_dict(),
                    'warmup': self.warmup,
                    'warmup_epoch': epoch + 1,
                }
                filename = self.logger.path(mode='warm', is_best=False)
                save_path = save_checkpoint(checkpoint,
                                            filename,
                                            self.logger,
                                            mode='warm')
Example #8
    def train_one_epoch(self, epoch):

        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        accs = AverageMeter()
        mious = AverageMeter()
        fscores = AverageMeter()

        epoch_str = 'epoch[{:03d}/{:03d}]'.format(epoch + 1,
                                                  self.run_config.epochs)
        iter_per_epoch = len(self.run_config.train_loader)
        end = time.time()
        for i, (datas, targets) in enumerate(self.run_config.train_loader):
            if torch.cuda.is_available():
                datas = datas.to(self.device, non_blocking=True)
                targets = targets.to(self.device, non_blocking=True)
            else:
                raise ValueError('do not support cpu version')
            data_time.update(time.time() - end)
            logits = self.model(datas)
            loss = self.criterion(logits, targets)
            evaluator = Evaluator(self.run_config.nb_classes)
            evaluator.add_batch(targets, logits)
            acc = evaluator.Pixel_Accuracy()
            miou = evaluator.Mean_Intersection_over_Union()
            fscore = evaluator.Fx_Score()
            # metrics update
            losses.update(loss.data.item(), datas.size(0))
            accs.update(acc.item(), datas.size(0))
            mious.update(miou.item(), datas.size(0))
            fscores.update(fscore.item(), datas.size(0))
            # update parameters
            self.model.zero_grad()
            loss.backward()
            self.optimizer.step()  # only update network_weight_parameters
            # elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            # within one epoch, per iteration print train_log
            if (i + 1) % self.run_config.train_print_freq == 0 or (
                    i + 1) == len(self.run_config.train_loader):
                #print(i+1, self.run_config.train_print_freq)
                Wstr = '|*TRAIN*|' + time_string(
                ) + '[{:}][iter{:03d}/{:03d}]'.format(epoch_str, i + 1,
                                                      iter_per_epoch)
                Tstr = '|Time   | [{batch_time.val:.2f} ({batch_time.avg:.2f})  Data {data_time.val:.2f} ({data_time.avg:.2f})]'.format(
                    batch_time=batch_time, data_time=data_time)
                Bstr = '|Base   | [Loss {loss.val:.3f} ({loss.avg:.3f})  Accuracy {acc.val:.2f} ({acc.avg:.2f}) MIoU {miou.val:.2f} ({miou.avg:.2f}) F {fscore.val:.2f} ({fscore.avg:.2f})]' \
                    .format(loss=losses, acc=accs, miou=mious, fscore=fscores)
                self.logger.log(Wstr + '\n' + Tstr + '\n' + Bstr,
                                mode='retrain')
        return losses.avg, accs.avg, mious.avg, fscores.avg
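
The Evaluator used throughout these segmentation examples is also not shown. Below is a minimal confusion-matrix sketch of the interface they rely on. Note the snippets call .item() on the returned metrics, so the project's version apparently returns tensors, while this sketch returns plain Python floats; the exact Fx_Score definition (here: mean Dice/F1 over classes) is likewise an assumption.

import numpy as np

class Evaluator:
    def __init__(self, nb_classes):
        self.nb_classes = nb_classes
        self.confusion_matrix = np.zeros((nb_classes, nb_classes))

    def add_batch(self, targets, logits):
        # logits: (B, C, H, W) raw scores; targets: (B, H, W) class indices
        preds = logits.argmax(dim=1).cpu().numpy().reshape(-1)
        gts = targets.cpu().numpy().reshape(-1)
        mask = (gts >= 0) & (gts < self.nb_classes)
        hist = np.bincount(
            self.nb_classes * gts[mask].astype(int) + preds[mask],
            minlength=self.nb_classes ** 2).reshape(self.nb_classes,
                                                    self.nb_classes)
        self.confusion_matrix += hist

    def Pixel_Accuracy(self):
        return np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()

    def Mean_Intersection_over_Union(self):
        tp = np.diag(self.confusion_matrix)
        union = (self.confusion_matrix.sum(axis=1) +
                 self.confusion_matrix.sum(axis=0) - tp)
        return float(np.nanmean(tp / union))

    def Fx_Score(self):
        tp = np.diag(self.confusion_matrix)
        denom = (self.confusion_matrix.sum(axis=1) +
                 self.confusion_matrix.sum(axis=0))
        return float(np.nanmean(2 * tp / denom))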
Example #9
    def validate(self, epoch=None, is_test=False, use_train_mode=False):
        # 1. super network viterbi decode, get actual_path
        # 2. decode the cell genotypes that lie on the actual_path of the super network
        # 3. according to actual_path and cell genotypes, construct the best network.
        # 4. use the best network to perform the test phase.

        if is_test:
            data_loader = self.run_config.test_loader
            epoch_str = None
            self.logger.log('\n' + '-' * 30 + 'TESTING PHASE' + '-' * 30 +
                            '\n',
                            mode='valid')
        else:
            data_loader = self.run_config.valid_loader
            epoch_str = 'epoch[{:03d}/{:03d}]'.format(epoch + 1,
                                                      self.run_config.epochs)
            self.logger.log('\n' + '-' * 30 +
                            'Valid epoch: {:}'.format(epoch_str) + '-' * 30 +
                            '\n',
                            mode='valid')

        model = self.model

        if use_train_mode: model.train()
        else: model.eval()

        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        mious = AverageMeter()
        fscores = AverageMeter()
        accs = AverageMeter()
        end0 = time.time()

        with torch.no_grad():
            if is_test:
                for i, ((datas, targets), filenames) in enumerate(data_loader):
                    if torch.cuda.is_available():
                        datas = datas.to(self.device, non_blocking=True)
                        targets = targets.to(self.device, non_blocking=True)
                    else:
                        raise ValueError('do not support cpu version')
                    data_time.update(time.time() - end0)
                    logits = self.model(datas)
                    self._save_pred(logits, filenames)
                    loss = self.criterion(logits, targets)
                    # metrics calculate and update
                    evaluator = Evaluator(self.run_config.nb_classes)
                    evaluator.add_batch(targets, logits)
                    miou = evaluator.Mean_Intersection_over_Union()
                    fscore = evaluator.Fx_Score()
                    acc = evaluator.Pixel_Accuracy()
                    losses.update(loss.data.item(), datas.size(0))
                    mious.update(miou.item(), datas.size(0))
                    fscores.update(fscore.item(), datas.size(0))
                    accs.update(acc.item(), datas.size(0))
                    # duration
                    batch_time.update(time.time() - end0)
                    end0 = time.time()
                Wstr = '|*TEST*|' + time_string()
                Tstr = '|Time  | [{batch_time.val:.2f} ({batch_time.avg:.2f})  Data {data_time.val:.2f} ({data_time.avg:.2f})]'.format(
                    batch_time=batch_time, data_time=data_time)
                Bstr = '|Base  | [Loss {loss.val:.3f} ({loss.avg:.3f})  Accuracy {acc.val:.2f} ({acc.avg:.2f}) MIoU {miou.val:.5f} ({miou.avg:.5f}) F {fscore.val:.2f} ({fscore.avg:.2f})]'.format(
                    loss=losses, acc=accs, miou=mious, fscore=fscores)
                self.logger.log(Wstr + '\n' + Tstr + '\n' + Bstr, 'test')
            else:
                for i, (datas, targets) in enumerate(data_loader):
                    if torch.cuda.is_available():
                        datas = datas.to(self.device, non_blocking=True)
                        targets = targets.to(self.device, non_blocking=True)
                    else:
                        raise ValueError('do not support cpu version')
                    data_time.update(time.time() - end0)
                    # validation of the derived model. normal forward pass.
                    logits = self.model(datas)

                    # TODO generate predictions
                    loss = self.criterion(logits, targets)

                    # metrics calculate and update
                    evaluator = Evaluator(self.run_config.nb_classes)
                    evaluator.add_batch(targets, logits)
                    miou = evaluator.Mean_Intersection_over_Union()
                    fscore = evaluator.Fx_Score()
                    acc = evaluator.Pixel_Accuracy()

                    losses.update(loss.data.item(), datas.size(0))
                    mious.update(miou.item(), datas.size(0))
                    fscores.update(fscore.item(), datas.size(0))
                    accs.update(acc.item(), datas.size(0))

                    # duration
                    batch_time.update(time.time() - end0)
                    end0 = time.time()
                Wstr = '|*VALID*|' + time_string() + epoch_str
                Tstr = '|Time   | [{batch_time.val:.2f} ({batch_time.avg:.2f})  Data {data_time.val:.2f} ({data_time.avg:.2f})]'.format(
                    batch_time=batch_time, data_time=data_time)
                Bstr = '|Base   | [Loss {loss.val:.3f} ({loss.avg:.3f})  Accuracy {acc.val:.2f} ({acc.avg:.2f}) MIoU {miou.val:.2f} ({miou.avg:.2f}) F {fscore.val:.2f} ({fscore.avg:.2f})]'.format(
                    loss=losses, acc=accs, miou=mious, fscore=fscores)
                self.logger.log(Wstr + '\n' + Tstr + '\n' + Bstr, 'valid')

        return losses.avg, accs.avg, mious.avg, fscores.avg
Example #10
class Trainer():
    def __init__(self, args, loader, t_model, s_model, ckp):
        self.args = args
        self.scale = args.scale

        self.epoch = 0
        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.t_model = t_model
        self.s_model = s_model
        arch_param = [v for k, v in self.s_model.named_parameters() if 'alpha' not in k]
        alpha_param = [v for k, v in self.s_model.named_parameters() if 'alpha' in k]

        params = [{'params': arch_param}, {'params': alpha_param, 'lr': 1e-2}]

        self.optimizer = torch.optim.Adam(params, lr=args.lr, betas=args.betas, eps=args.epsilon)
        self.scheduler = StepLR(self.optimizer, step_size=int(args.decay), gamma=args.gamma)
        self.writer_train = SummaryWriter(ckp.dir + '/run/train')
        
        if args.resume is not None:
            ckpt = torch.load(args.resume)
            self.epoch = ckpt['epoch']
            print(f"Continue from {self.epoch}")
            self.s_model.load_state_dict(ckpt['state_dict'])
            self.optimizer.load_state_dict(ckpt['optimizer'])
            self.scheduler.load_state_dict(ckpt['scheduler'])

        self.losses = AverageMeter()
        self.att_losses = AverageMeter()
        self.nor_losses = AverageMeter()

    def train(self):
        self.scheduler.step(self.epoch)
        self.epoch = self.epoch + 1
        lr = self.optimizer.state_dict()['param_groups'][0]['lr']

        self.writer_train.add_scalar(f'lr', lr, self.epoch)
        self.ckp.write_log(
            '[Epoch {}]\tLearning rate: {:.2e}'.format(self.epoch, Decimal(lr))
        )

        self.t_model.eval()
        self.s_model.train()
        
        self.s_model.apply(lambda m: setattr(m, 'epoch', self.epoch))
        
        num_iterations = len(self.loader_train)
        timer_data, timer_model = utility.timer(), utility.timer()
        
        for batch, (lr, hr, _, idx_scale) in enumerate(self.loader_train):
            num_iters = num_iterations * (self.epoch-1) + batch

            lr, hr = self.prepare(lr, hr)
            data_size = lr.size(0) 
            
            timer_data.hold()
            timer_model.tic()

            self.optimizer.zero_grad()

            if hasattr(self.t_model, 'set_scale'):
                self.t_model.set_scale(idx_scale)
            if hasattr(self.s_model, 'set_scale'):
                self.s_model.set_scale(idx_scale)

            with torch.no_grad():
                t_sr, t_res = self.t_model(lr)
            s_sr, s_res = self.s_model(lr)

            nor_loss = self.args.w_l1 * F.l1_loss(s_sr, hr)
            att_loss = self.args.w_at * util.at_loss(s_res, t_res)

            loss = nor_loss  + att_loss

            loss.backward()
            self.optimizer.step()

            timer_model.hold()

            self.losses.update(loss.item(),data_size)
            display_loss = f'Loss: {self.losses.avg: .3f}'

            if (batch + 1) % self.args.print_every == 0:
                self.ckp.write_log('[{}/{}]\t{}\t{:.1f}+{:.1f}s'.format(
                    (batch + 1) * self.args.batch_size,
                    len(self.loader_train.dataset),
                    display_loss,
                    timer_model.release(),
                    timer_data.release()))

            timer_data.tic()
        
            for name, value in self.s_model.named_parameters():
                if 'alpha' in name:
                    if value.grad is not None:
                        self.writer_train.add_scalar(f'{name}_grad', value.grad.cpu().data.numpy(), num_iters)
                        self.writer_train.add_scalar(f'{name}_data', value.cpu().data.numpy(), num_iters)


    def test(self, is_teacher=False):
        torch.set_grad_enabled(False)
        epoch = self.epoch
        self.ckp.write_log('\nEvaluation:')
        self.ckp.add_log(
            torch.zeros(1, len(self.loader_test), len(self.scale))
        )
        if is_teacher:
            model = self.t_model
        else:
            model = self.s_model
        model.eval()
        timer_test = utility.timer()
        
        if self.args.save_results: self.ckp.begin_background()
        for idx_data, d in enumerate(self.loader_test):
            for idx_scale, scale in enumerate(self.scale):
                d.dataset.set_scale(idx_scale)
                i = 0
                for lr, hr, filename, _ in tqdm(d, ncols=80):
                    i += 1
                    lr, hr = self.prepare(lr, hr)
                    sr, s_res = model(lr)
                    sr = utility.quantize(sr, self.args.rgb_range)
                    save_list = [sr]
                    cur_psnr = utility.calc_psnr(
                        sr, hr, scale, self.args.rgb_range, dataset=d
                    )
                    self.ckp.log[-1, idx_data, idx_scale] += cur_psnr
                    if self.args.save_gt:
                        save_list.extend([lr, hr])

                    if self.args.save_results:
                        save_name = f'{self.args.k_bits}bit_{filename[0]}'
                        self.ckp.save_results(d, save_name, save_list, scale)

                self.ckp.log[-1, idx_data, idx_scale] /= len(d)
                best = self.ckp.log.max(0)

                self.ckp.write_log(
                    '[{} x{}] PSNR: {:.3f}  (Best: {:.3f} @epoch {})'.format(
                        d.dataset.name,
                        scale,
                        self.ckp.log[-1, idx_data, idx_scale],
                        best[0][idx_data, idx_scale],
                        best[1][idx_data, idx_scale] + 1
                    )
                )
                self.writer_train.add_scalar(f'psnr', self.ckp.log[-1, idx_data, idx_scale], self.epoch)

        if self.args.save_results:
            self.ckp.end_background()
            
        if not self.args.test_only:
            is_best = (best[1][0, 0] + 1 == epoch)

            state = {
                'epoch': epoch,
                'state_dict': self.s_model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'scheduler': self.scheduler.state_dict()
            }
            util.save_checkpoint(state, is_best, checkpoint=self.ckp.dir + '/model')
        
        self.ckp.write_log(
            'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True
        )

        torch.set_grad_enabled(True)

    def prepare(self, *args):
        def _prepare(tensor):
            if self.args.precision == 'half': tensor = tensor.half()
            return tensor.cuda()

        return [_prepare(a) for a in args]

    def terminate(self):
        if self.args.test_only:
            self.test()
            return True
        else:
            return self.epoch >= self.args.epochs
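
A hedged usage sketch of the Trainer above, in the EDSR-style driver pattern its terminate() method suggests; args, loader, the teacher/student models and the checkpoint object ckp are assumed to be constructed elsewhere and are not part of the original snippet.

trainer = Trainer(args, loader, t_model, s_model, ckp)
while not trainer.terminate():
    trainer.train()
    trainer.test()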
Example #11
def train(args):
    if args.batch_size % args.num_instance != 0:
        new_batch_size = (args.batch_size //
                          args.num_instance) * args.num_instance
        print(
            f"Given batch size is {args.batch_size} and num_instances is {args.num_instance}. "
            f"Batch size must be divisible by {args.num_instance}; it will be replaced with {new_batch_size}."
        )
        args.batch_size = new_batch_size

    # prepare dataset
    train_loader, val_loader, num_query, train_data_len, num_classes = make_data_loader(
        args)

    model = build_model(args, num_classes)
    print("model size: {:.5f}M".format(
        sum(p.numel() for p in model.parameters()) / 1e6))
    loss_fn, center_criterion = make_loss(args, num_classes)
    optimizer, optimizer_center = make_optimizer(args, model, center_criterion)

    if args.cuda:
        model = model.cuda()
        if args.amp:
            if args.center_loss:
                model, [optimizer, optimizer_center] = \
                    amp.initialize(model, [optimizer, optimizer_center], opt_level="O1")
            else:
                model, optimizer = amp.initialize(model,
                                                  optimizer,
                                                  opt_level="O1")

        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()
        if args.center_loss:
            center_criterion = center_criterion.cuda()
            for state in optimizer_center.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.cuda()

    model_state_dict = model.state_dict()
    optim_state_dict = optimizer.state_dict()
    if args.center_loss:
        optim_center_state_dict = optimizer_center.state_dict()
        center_state_dict = center_criterion.state_dict()

    reid_evaluator = ReIDEvaluator(args, model, num_query)

    start_epoch = 0
    global_step = 0
    if args.pretrain != '':  # load pre-trained model
        weights = torch.load(args.pretrain)
        model_state_dict = weights["state_dict"]

        model.load_state_dict(model_state_dict)
        if args.center_loss:
            center_criterion.load_state_dict(
                torch.load(args.pretrain.replace(
                    'model', 'center_param'))["state_dict"])

        if args.resume:
            start_epoch = weights["epoch"]
            global_step = weights["global_step"]

            optimizer.load_state_dict(
                torch.load(args.pretrain.replace('model',
                                                 'optimizer'))["state_dict"])
            if args.center_loss:
                optimizer_center.load_state_dict(
                    torch.load(
                        args.pretrain.replace(
                            'model', 'optimizer_center'))["state_dict"])
        print(f'Start epoch: {start_epoch}, Start step: {global_step}')

    scheduler = WarmupMultiStepLR(optimizer, args.steps, args.gamma,
                                  args.warmup_factor, args.warmup_step,
                                  "linear",
                                  -1 if start_epoch == 0 else start_epoch)

    current_epoch = start_epoch
    best_epoch = 0
    best_rank1 = 0
    best_mAP = 0
    if args.resume:
        rank, mAP = reid_evaluator.evaluate(val_loader)
        best_rank1 = rank[0]
        best_mAP = mAP
        best_epoch = current_epoch + 1

    batch_time = AverageMeter()
    total_losses = AverageMeter()

    model_save_dir = os.path.join(args.save_dir, 'ckpts')
    os.makedirs(model_save_dir, exist_ok=True)

    summary_writer = SummaryWriter(log_dir=os.path.join(
        args.save_dir, "tensorboard_log"),
                                   purge_step=global_step)

    def summary_loss(score, feat, labels, top_name='global'):
        loss = 0.0
        losses = loss_fn(score, feat, labels)
        for loss_name, loss_val in losses.items():
            if loss_name.lower() == "accuracy":
                summary_writer.add_scalar(f"Score/{top_name}/triplet",
                                          loss_val, global_step)
                continue
            if "dist" in loss_name.lower():
                summary_writer.add_histogram(f"Distance/{loss_name}", loss_val,
                                             global_step)
                continue
            loss += loss_val
            summary_writer.add_scalar(f"losses/{top_name}/{loss_name}",
                                      loss_val, global_step)

        ohe_labels = torch.zeros_like(score)
        ohe_labels.scatter_(1, labels.unsqueeze(1), 1.0)

        cls_score = torch.softmax(score, dim=1)
        cls_score = torch.sum(cls_score * ohe_labels, dim=1).mean()
        summary_writer.add_scalar(f"Score/{top_name}/X-entropy", cls_score,
                                  global_step)

        return loss

    def save_weights(file_name, eph, steps):
        torch.save(
            {
                "state_dict": model_state_dict,
                "epoch": eph + 1,
                "global_step": steps
            }, file_name)
        torch.save({"state_dict": optim_state_dict},
                   file_name.replace("model", "optimizer"))
        if args.center_loss:
            torch.save({"state_dict": center_state_dict},
                       file_name.replace("model", "optimizer_center"))
            torch.save({"state_dict": optim_center_state_dict},
                       file_name.replace("model", "center_param"))

    # training start
    for epoch in range(start_epoch, args.max_epoch):
        model.train()
        t0 = time.time()
        for i, (inputs, labels, _, _) in enumerate(train_loader):
            if args.cuda:
                inputs = inputs.cuda()
                labels = labels.cuda()

            cls_scores, features = model(inputs, labels)

            # losses
            total_loss = summary_loss(cls_scores[0], features[0], labels,
                                      'global')
            if args.use_local_feat:
                total_loss += summary_loss(cls_scores[1], features[1], labels,
                                           'local')

            optimizer.zero_grad()
            if args.center_loss:
                optimizer_center.zero_grad()

            # backward with global loss
            if args.amp:
                optimizers = [optimizer]
                if args.center_loss:
                    optimizers.append(optimizer_center)
                with amp.scale_loss(total_loss, optimizers) as scaled_loss:
                    scaled_loss.backward()
            else:
                with torch.autograd.detect_anomaly():
                    total_loss.backward()

            # optimization
            optimizer.step()
            if args.center_loss:
                for name, param in center_criterion.named_parameters():
                    try:
                        param.grad.data *= (1. / args.center_loss_weight)
                    except AttributeError:
                        continue
                optimizer_center.step()

            batch_time.update(time.time() - t0)
            total_losses.update(total_loss.item())

            # learning_rate
            current_lr = optimizer.param_groups[0]['lr']
            summary_writer.add_scalar("lr", current_lr, global_step)

            t0 = time.time()

            if (i + 1) % args.log_period == 0:
                print(
                    f"Epoch: [{epoch}][{i+1}/{train_data_len}]  " +
                    f"Batch Time {batch_time.val:.3f} ({batch_time.mean:.3f})  "
                    +
                    f"Total_loss {total_losses.val:.3f} ({total_losses.mean:.3f})"
                )
            global_step += 1

        print(
            f"Epoch: [{epoch}]\tEpoch Time {batch_time.sum:.3f} s\tLoss {total_losses.mean:.3f}\tLr {current_lr:.2e}"
        )

        if args.eval_period > 0 and (epoch + 1) % args.eval_period == 0 or (
                epoch + 1) == args.max_epoch:
            rank, mAP = reid_evaluator.evaluate(
                val_loader,
                mode="retrieval" if args.dataset_name == "cub200" else "reid")

            rank_string = ""
            for r in (1, 2, 4, 5, 8, 10, 16, 20):
                rank_string += f"Rank-{r:<3}: {rank[r-1]:.1%}"
                if r != 20:
                    rank_string += "    "
            summary_writer.add_text("Recall@K", rank_string, global_step)
            summary_writer.add_scalar("Rank-1", rank[0], (epoch + 1))

            rank1 = rank[0]
            is_best = rank1 > best_rank1
            if is_best:
                best_rank1 = rank1
                best_mAP = mAP
                best_epoch = epoch + 1

            if (epoch + 1) % args.save_period == 0 or (epoch +
                                                       1) == args.max_epoch:
                pth_file_name = os.path.join(
                    model_save_dir,
                    f"{args.backbone}_model_{epoch + 1}.pth.tar")
                save_weights(pth_file_name, eph=epoch, steps=global_step)

            if is_best:
                pth_file_name = os.path.join(
                    model_save_dir, f"{args.backbone}_model_best.pth.tar")
                save_weights(pth_file_name, eph=epoch, steps=global_step)

        # end epoch
        current_epoch += 1

        batch_time.reset()
        total_losses.reset()
        torch.cuda.empty_cache()

        # update learning rate
        scheduler.step()

    print(f"Best rank-1 {best_rank1:.1%}, achived at epoch {best_epoch}")
    summary_writer.add_hparams(
        {
            "dataset_name": args.dataset_name,
            "triplet_dim": args.triplet_dim,
            "margin": args.margin,
            "base_lr": args.base_lr,
            "use_attn": args.use_attn,
            "use_mask": args.use_mask,
            "use_local_feat": args.use_local_feat
        }, {
            "mAP": best_mAP,
            "Rank1": best_rank1
        })
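
WarmupMultiStepLR is not defined in this snippet. The sketch below follows the common reid-strong-baseline style scheduler and matches the positional arguments used above (milestones, gamma, warmup_factor, warmup_iters, warmup_method, last_epoch); the project's actual version may differ.

from bisect import bisect_right

import torch

class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, milestones, gamma=0.1, warmup_factor=1.0 / 3,
                 warmup_iters=500, warmup_method="linear", last_epoch=-1):
        self.milestones = sorted(milestones)
        self.gamma = gamma
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.warmup_method = warmup_method
        super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        warmup_factor = 1.0
        if self.last_epoch < self.warmup_iters:
            if self.warmup_method == "constant":
                warmup_factor = self.warmup_factor
            else:  # "linear"
                alpha = self.last_epoch / self.warmup_iters
                warmup_factor = self.warmup_factor * (1 - alpha) + alpha
        # multi-step decay applied on top of the (possibly active) warmup factor
        return [base_lr * warmup_factor *
                self.gamma ** bisect_right(self.milestones, self.last_epoch)
                for base_lr in self.base_lrs]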
Example #12
def finetune(model, loader_train, data_length, device, criterion, optimizer,
             scheduler, print_freq, print_logger, step, batch_size, epochs=1,
             use_top5=False, verbose=True):

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    top1 = AverageMeter()
    top5 = AverageMeter()
    best_acc = 0.

    # switch to train mode
    model.train()
    end = time.time()
    t1 = time.time()
    num_iterations = int(data_length / batch_size)

    for epoch in range(epochs):
        scheduler.step(epoch)

        for i, data in enumerate(loader_train):
            inputs = data[0]["data"].to(device)
            targets = data[0]["label"].squeeze().long().to(device)
            # measure data loading time
            data_time.update(time.time() - end)

            optimizer.zero_grad()
            # compute output
            output = model(inputs)
            loss = criterion(output, targets)
            # compute gradient
            loss.backward()

            nn.utils.clip_grad_norm_(model.parameters(), CLIP_VALUE)

            optimizer.step()

            optimizer.moment = []

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output.data, targets, topk=(1, 5))
            losses.update(loss.item(), batch_size)
            top1.update(prec1.item(), batch_size)
            top5.update(prec5.item(), batch_size)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % print_freq == 0:
                print_logger.info(
                    'Finetune Step [{0}] Epoch [{1}|{2}] ({3}/{4}): '
                    'Loss {loss.avg:.4f} '
                    'Prec@1(1,5) {top1.avg:.2f}, {top5.avg:.2f} '.format(
                        step,
                        epoch,
                        epochs,
                        i,
                        num_iterations,
                        loss=losses,
                        top1=top1,
                        top5=top5))

        if use_top5:
            if top5.avg > best_acc:
                best_acc = top5.avg
        else:
            if top1.avg > best_acc:
                best_acc = top1.avg
        loader_train.reset()
    return model, top1.avg
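Both finetune routines rely on an accuracy(output, target, topk=(1, 5)) helper defined elsewhere. A common top-k precision implementation, shown here only as an assumption about what that helper does, is:

import torch


def accuracy(output, target, topk=(1,)):
    # percentage of samples whose true class appears in the top-k logits
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res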
Example #13
def finetune_one_batch(model,
                       pre_params,
                       loader_train,
                       data_length,
                       device,
                       criterion,
                       optimizer,
                       scheduler,
                       print_freq,
                       print_logger,
                       step,
                       batch_size,
                       epochs=1,
                       use_top5=False,
                       verbose=True):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    top1 = AverageMeter()
    top5 = AverageMeter()
    best_acc = 0.
    informance = 0.0
    params = []

    model.train()
    end = time.time()
    t1 = time.time()

    for epoch in range(epochs):
        if scheduler is not None:
            scheduler.step(epoch)

        for batch_idx, data in enumerate(loader_train, 0):
            # for i,(inputs,targets) in enumerate(loader_train,0):
            # pdb.set_trace()
            inputs, targets = data
            inputs = inputs.to(device)
            targets = targets.to(device)
            # measure data loading time
            data_time.update(time.time() - end)

            optimizer.zero_grad()
            # compute output
            output = model(inputs)

            loss = criterion(output, targets)
            # compute gradient
            loss.backward()

            nn.utils.clip_grad_norm_(model.parameters(), CLIP_VALUE)
            params = []
            optimizer.step()

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output.data, targets, topk=(1, 5))
            losses.update(loss.item(), batch_size)
            top1.update(prec1.item(), batch_size)
            top5.update(prec5.item(), batch_size)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            print_logger.info(
                'Finetune One Batch Step [{0}]: '
                'Loss {loss.avg:.4f} '
                'Prec@1(1,5) {top1.avg:.2f}, {top5.avg:.2f} '.format(
                    step, loss=losses, top1=top1, top5=top5))

        # collect references to the current (post-finetune) parameters
        for _, p in model.named_parameters():
            params.append(p)

        # estimate parameter importance: weight the squared change of each
        # parameter (pre- vs. post-finetune) by the optimizer's moment
        # statistics, then sum over all parameters
        moment = optimizer.moment
        informance = [0.0 for i in range(len(moment))]

        for i in range(len(moment)):
            informance[i] = moment[i] * torch.pow(
                (pre_params[i] - params[i]), 2)

        suminfo = 0.0
        for info in informance:
            suminfo += torch.sum(info).item()

        if use_top5:
            if top5.avg > best_acc:
                best_acc = top5.avg
        else:
            if top1.avg > best_acc:
                best_acc = top1.avg
        optimizer.moment = []
    return model, suminfo, top1.avg
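The suminfo value returned by finetune_one_batch is a sum over all parameters of the optimizer's moment statistic times the squared parameter change during finetuning. A compact standalone sketch of the same computation (names are illustrative, not from the source):

import torch


def importance_score(moments, pre_params, post_params):
    # sum over parameter tensors of moment * (pre - post)^2
    total = 0.0
    for m, pre, post in zip(moments, pre_params, post_params):
        total += torch.sum(m * torch.pow(pre - post, 2)).item()
    return total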
def train(train_loader,
          model,
          criterion,
          optimizer,
          epoch,
          max_epoch,
          log_freq=1,
          print_sum=True,
          poses_mean=None,
          poses_std=None,
          device=None,
          stereo=True):

    # switch model to training
    model.train()

    losses = AverageMeter()

    epoch_time = time.time()

    gt_poses = np.empty((0, 7))
    pred_poses = np.empty((0, 7))

    end = time.time()
    for idx, (batch_images, batch_poses) in enumerate(train_loader):
        data_time = (time.time() - end)

        if stereo:
            batch_images = [x.to(device) for x in batch_images]
            batch_poses = [x.to(device) for x in batch_poses]
        else:
            batch_images = batch_images.to(device)
            batch_poses = batch_poses.to(device)

        out = model(batch_images)
        loss = criterion(out, batch_poses)
        #         print('loss = {}'.format(loss))

        # Make an optimization step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # record loss, weighted by the number of images in the batch
        losses.update(
            loss.item(),
            len(batch_images) *
            batch_images[0].size(0) if stereo else batch_images.size(0))

        # move data to cpu & numpy
        if stereo:
            bp = [x.detach().cpu().numpy() for x in batch_poses]
            outp = [x.detach().cpu().numpy() for x in out]
            gt_poses = np.vstack((gt_poses, *bp))
            pred_poses = np.vstack((pred_poses, *outp))
        else:
            bp = batch_poses.detach().cpu().numpy()
            outp = out.detach().cpu().numpy()
            gt_poses = np.vstack((gt_poses, bp))
            pred_poses = np.vstack((pred_poses, outp))

        batch_time = (time.time() - end)
        end = time.time()

        if log_freq != 0 and idx % log_freq == 0:
            print('Epoch: [{}/{}]\tBatch: [{}/{}]\t'
                  'Time: {batch_time:.3f}\t'
                  'Data Time: {data_time:.3f}\t'
                  'Loss: {losses.val:.3f}\t'
                  'Avg Loss: {losses.avg:.3f}\t'.format(epoch,
                                                        max_epoch - 1,
                                                        idx,
                                                        len(train_loader) - 1,
                                                        batch_time=batch_time,
                                                        data_time=data_time,
                                                        losses=losses))

        # if idx == 0:
        #     break

    # un-normalize translation
    unnorm = (poses_mean is not None) and (poses_std is not None)
    if unnorm:
        gt_poses[:, :3] = gt_poses[:, :3] * poses_std + poses_mean
        pred_poses[:, :3] = pred_poses[:, :3] * poses_std + poses_mean

    t_loss = np.asarray([
        np.linalg.norm(p - t)
        for p, t in zip(pred_poses[:, :3], gt_poses[:, :3])
    ])
    q_loss = np.asarray([
        quaternion_angular_error(p, t)
        for p, t in zip(pred_poses[:, 3:], gt_poses[:, 3:])
    ])

    #     if unnorm:
    #         print('poses_std = {:.3f}'.format(np.linalg.norm(poses_std)))
    #     print('T: median = {:.3f}, mean = {:.3f}'.format(np.median(t_loss), np.mean(t_loss)))
    #     print('R: median = {:.3f}, mean = {:.3f}'.format(np.median(q_loss), np.mean(q_loss)))

    if print_sum:
        print(
            'Ep: [{}/{}]\tTrain Loss: {:.3f}\tTe: {:.3f}\tRe: {:.3f}\t'
            'Et: {:.2f}s\t{criterion_sx:.5f}:{criterion_sq:.5f}'.format(
                epoch,
                max_epoch - 1,
                losses.avg,
                np.mean(t_loss),
                np.mean(q_loss), (time.time() - epoch_time),
                criterion_sx=criterion.sx.item(),
                criterion_sq=criterion.sq.item()))
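The rotational error above uses a quaternion_angular_error function that is not part of this snippet. A common definition of that metric, given here as an assumption rather than this code base's exact implementation, returns the angle in degrees between two unit quaternions:

import numpy as np


def quaternion_angular_error(q1, q2):
    # angle (degrees) between two unit quaternions; |dot| handles the
    # double-cover ambiguity (q and -q describe the same rotation)
    d = abs(np.dot(q1, q2))
    d = min(1.0, max(-1.0, d))
    return 2.0 * np.degrees(np.arccos(d))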