def validate(run_manager, epoch=0, is_test=True, image_size_list=None,
             width_mult_list=None, ks_list=None, expand_ratio_list=None, depth_list=None, additional_setting=None):
    dynamic_net = run_manager.net
    if isinstance(dynamic_net, nn.DataParallel):
        dynamic_net = dynamic_net.module

    dynamic_net.eval()

    if image_size_list is None:
        image_size_list = int2list(run_manager.run_config.data_provider.image_size, 1)
    if width_mult_list is None:
        width_mult_list = list(range(len(dynamic_net.width_mult_list)))
    if ks_list is None:
        ks_list = dynamic_net.ks_list
    if expand_ratio_list is None:
        expand_ratio_list = dynamic_net.expand_ratio_list
    if depth_list is None:
        depth_list = dynamic_net.depth_list

    subnet_settings = []
    for w in width_mult_list:
        for d in depth_list:
            for e in expand_ratio_list:
                for k in ks_list:
                    for img_size in image_size_list:
                        subnet_settings.append([{
                            'image_size': img_size,
                            'wid': w,
                            'd': d,
                            'e': e,
                            'ks': k,
                        }, 'R%s-W%s-D%s-E%s-K%s' % (img_size, w, d, e, k)])
    if additional_setting is not None:
        subnet_settings += additional_setting

    losses_of_subnets, top1_of_subnets, top5_of_subnets = [], [], []

    valid_log = ''
    for setting, name in subnet_settings:
        run_manager.write_log('-' * 30 + ' Validate %s ' % name + '-' * 30, 'train', should_print=False)
        run_manager.run_config.data_provider.assign_active_img_size(setting.pop('image_size'))
        dynamic_net.set_active_subnet(**setting)
        run_manager.write_log(dynamic_net.module_str, 'train', should_print=False)

        run_manager.reset_running_statistics(dynamic_net)
        loss, top1, top5 = run_manager.validate(epoch=epoch, is_test=is_test, run_str=name, net=dynamic_net)
        losses_of_subnets.append(loss)
        top1_of_subnets.append(top1)
        top5_of_subnets.append(top5)
        valid_log += '%s (%.3f), ' % (name, top1)

    return list_mean(losses_of_subnets), list_mean(top1_of_subnets), list_mean(top5_of_subnets), valid_log
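
# Usage sketch (an illustrative assumption, not part of the original code):
# driving the validate() helper above over a small elastic grid. Any argument
# left as None falls back to the dynamic network's full list, as shown above.
def validate_small_grid(run_manager, epoch):
    loss, top1, top5, log = validate(
        run_manager,
        epoch=epoch,
        is_test=False,
        image_size_list=[160, 192, 224],  # input resolutions to evaluate
        ks_list=[3, 5, 7],                # elastic kernel sizes
        expand_ratio_list=[3, 4, 6],      # elastic expand ratios
        depth_list=[2, 3, 4],             # elastic depths
    )
    run_manager.write_log('mean top-1: %.3f | %s' % (top1, log), 'valid')
    return loss, top1, top5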
    def train(self, args, warmup_epochs=5, warmup_lr=0):
        for epoch in range(self.start_epoch, self.run_config.n_epochs + warmup_epochs):
            train_loss, train_top1, train_top5 = self.train_one_epoch(args, epoch, warmup_epochs, warmup_lr)
            img_size, val_loss, val_top1, val_top5 = self.validate_all_resolution(epoch, is_test=False)

            is_best = list_mean(val_top1) > self.best_acc
            self.best_acc = max(self.best_acc, list_mean(val_top1))
            if self.is_root:
                val_log = '[{0}/{1}]\tloss {2:.3f}\ttop-1 acc {3:.3f} ({4:.3f})\ttop-5 acc {5:.3f}\t' \
                          'Train top-1 {top1:.3f}\tloss {train_loss:.3f}\t'. \
                    format(epoch + 1 - warmup_epochs, self.run_config.n_epochs, list_mean(val_loss),
                           list_mean(val_top1), self.best_acc, list_mean(val_top5),
                           top1=train_top1, train_loss=train_loss)
                for i_s, v_a in zip(img_size, val_top1):
                    val_log += '(%d, %.3f), ' % (i_s, v_a)
                self.write_log(val_log, prefix='valid', should_print=False)

                self.save_model({
                    'epoch': epoch,
                    'best_acc': self.best_acc,
                    'optimizer': self.optimizer.state_dict(),
                    'state_dict': self.net.state_dict(),
                }, is_best=is_best)
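
# Resume sketch (an assumption, not from the original code): restoring the
# checkpoint payload written by save_model() above. The path and map_location
# are illustrative.
def resume_from_checkpoint(run_manager, path='checkpoint/model_best.pth.tar'):
    ckpt = torch.load(path, map_location='cpu')
    run_manager.net.load_state_dict(ckpt['state_dict'])
    run_manager.optimizer.load_state_dict(ckpt['optimizer'])
    run_manager.start_epoch = ckpt['epoch'] + 1  # resume after the saved epoch
    run_manager.best_acc = ckpt['best_acc']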
def train_one_epoch(run_manager, args, epoch, warmup_epochs=0, warmup_lr=0):
    dynamic_net = run_manager.net

    # switch to train mode
    dynamic_net.train()
    run_manager.run_config.train_loader.sampler.set_epoch(epoch)
    MyRandomResizedCrop.EPOCH = epoch

    nBatch = len(run_manager.run_config.train_loader)

    data_time = AverageMeter()
    losses = DistributedMetric('train_loss')
    top1 = DistributedMetric('train_top1')
    top5 = DistributedMetric('train_top5')

    with tqdm(total=nBatch,
              desc='Train Epoch #{}'.format(epoch + 1),
              disable=not run_manager.is_root) as t:
        end = time.time()
        for i, (images,
                labels) in enumerate(run_manager.run_config.train_loader):
            data_time.update(time.time() - end)
            if epoch < warmup_epochs:
                new_lr = run_manager.run_config.warmup_adjust_learning_rate(
                    run_manager.optimizer,
                    warmup_epochs * nBatch,
                    nBatch,
                    epoch,
                    i,
                    warmup_lr,
                )
            else:
                new_lr = run_manager.run_config.adjust_learning_rate(
                    run_manager.optimizer, epoch - warmup_epochs, i, nBatch)

            images, labels = images.cuda(), labels.cuda()
            target = labels

            # soft target
            if args.kd_ratio > 0:
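                # note: the teacher is left in train mode here, so its BN
                # running statistics still update even under torch.no_grad()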
                args.teacher_model.train()
                with torch.no_grad():
                    soft_logits = args.teacher_model(images).detach()
                    soft_label = F.softmax(soft_logits, dim=1)

            # clear gradients
            run_manager.optimizer.zero_grad()

            loss_of_subnets, acc1_of_subnets, acc5_of_subnets = [], [], []
            # compute output
            subnet_str = ''
            for _ in range(args.dynamic_batch_size):

                # set random seed before sampling
                if args.independent_distributed_sampling:
                    subnet_seed = os.getpid() + time.time()
                else:
                    subnet_seed = int('%d%.3d%.3d' %
                                      (epoch * nBatch + i, _, 0))
                random.seed(subnet_seed)
                subnet_settings = dynamic_net.sample_active_subnet()
                subnet_str += '%d: ' % _ + ','.join([
                    '%s_%s' %
                    (key, '%.1f' %
                     subset_mean(val, 0) if isinstance(val, list) else val)
                    for key, val in subnet_settings.items()
                ]) + ' || '

                output = run_manager.net(images)
                if args.kd_ratio == 0:
                    loss = run_manager.train_criterion(output, labels)
                    loss_type = 'ce'
                else:
                    if args.kd_type == 'ce':
                        kd_loss = cross_entropy_loss_with_soft_target(
                            output, soft_label)
                    else:
                        kd_loss = F.mse_loss(output, soft_logits)
                    loss = args.kd_ratio * kd_loss + run_manager.train_criterion(
                        output, labels)
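                    # rescale so the combined loss magnitude stays comparable
                    # across kd_ratio values (the factor is 1 at kd_ratio=1)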
                    loss = loss * (2 / (args.kd_ratio + 1))
                    loss_type = '%.1fkd-%s & ce' % (args.kd_ratio,
                                                    args.kd_type)

                # measure accuracy and record loss
                acc1, acc5 = accuracy(output, target, topk=(1, 5))
                loss_of_subnets.append(loss)
                acc1_of_subnets.append(acc1[0])
                acc5_of_subnets.append(acc5[0])

                loss.backward()
            run_manager.optimizer.step()

            losses.update(list_mean(loss_of_subnets), images.size(0))
            top1.update(list_mean(acc1_of_subnets), images.size(0))
            top5.update(list_mean(acc5_of_subnets), images.size(0))

            t.set_postfix({
                'loss': losses.avg.item(),
                'top1': top1.avg.item(),
                'top5': top5.avg.item(),
                'R': images.size(2),
                'lr': new_lr,
                'loss_type': loss_type,
                'seed': str(subnet_seed),
                'str': subnet_str,
                'data_time': data_time.avg,
            })
            t.update(1)
            end = time.time()
    return losses.avg.item(), top1.avg.item(), top5.avg.item()
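
# Worked example (illustrative numbers, not part of the original code) of the
# deterministic subnet seed built in train_one_epoch() above. Every worker
# derives the same seed from (epoch, batch index, dynamic-batch slot), so all
# ranks sample the same subnet unless independent_distributed_sampling is set.
def _subnet_seed_example():
    epoch, nBatch, i, slot = 3, 500, 42, 1
    subnet_seed = int('%d%.3d%.3d' % (epoch * nBatch + i, slot, 0))
    assert subnet_seed == 1542001000  # '1542' + '001' + '000'
    random.seed(subnet_seed)  # sample_active_subnet() is now reproducible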
Example #4
def validate(run_manager,
             epoch=0,
             is_test=True,
             image_size_list=None,
             width_mult_list=None,
             ks_list=None,
             expand_ratio_list=None,
             depth_list=None,
             pixelshuffle_depth_list=None,
             additional_setting=None):
    dynamic_net = run_manager.net
    if isinstance(dynamic_net, nn.DataParallel):
        dynamic_net = dynamic_net.module

    dynamic_net.eval()

    if image_size_list is None:
        image_size_list = int2list(
            run_manager.run_config.data_provider.image_size, 1)
    if width_mult_list is None:
        width_mult_list = list(range(len(dynamic_net.width_mult_list)))
    if ks_list is None:
        ks_list = dynamic_net.ks_list
    if expand_ratio_list is None:
        expand_ratio_list = dynamic_net.expand_ratio_list
    if depth_list is None:
        depth_list = dynamic_net.depth_list
    if pixelshuffle_depth_list is None:
        pixelshuffle_depth_list = dynamic_net.pixelshuffle_depth_list

    subnet_settings = []
    for pixel_d in pixelshuffle_depth_list:
        for w in width_mult_list:
            for d in depth_list:
                for e in expand_ratio_list:
                    for k in ks_list:
                        # for img_size in image_size_list:
                        subnet_settings.append([
                            {
                                # 'image_size': img_size,
                                'pixel_d': pixel_d,
                                'wid': w,
                                'd': d,
                                'e': e,
                                'ks': k,
                            },
                            'PD%s-W%s-D%s-E%s-K%s' % (pixel_d, w, d, e, k)
                        ])
    if additional_setting is not None:
        subnet_settings += additional_setting

    # losses_of_subnets, top1_of_subnets, top5_of_subnets = [], [], []
    losses_of_subnets, psnr_of_subnets = [], []

    valid_log = ''
    for setting, name in subnet_settings:
        #################### This is where the validation architecture is chosen; for single-architecture overfitting, or for a quick test, just skip unwanted settings here.
        # if name.find('PD1-W0-D2-E3-K7') == -1:
        #     continue

        run_manager.write_log('-' * 30 + ' Validate %s ' % name + '-' * 30,
                              'train',
                              should_print=False)
        # run_manager.run_config.data_provider.assign_active_img_size(setting.pop('image_size'))

        #################### Toggle the comments below to choose between random sampling and structured sampling.
        # dynamic_net.sample_active_subnet()
        dynamic_net.set_active_subnet(**setting)

        run_manager.write_log(dynamic_net.module_str,
                              'train',
                              should_print=False)

        #################### During oracle training, updating the batch mean/variance from the current data breaks things, so keep this disabled.
        # run_manager.reset_running_statistics(dynamic_net)
        loss, psnr = run_manager.validate(epoch=epoch,
                                          is_test=is_test,
                                          run_str=name,
                                          net=dynamic_net)
        losses_of_subnets.append(loss)
        # top1_of_subnets.append(top1)
        # top5_of_subnets.append(top5)
        psnr_of_subnets.append(psnr)
        valid_log += '%s (%.3f), ' % (name, psnr)

    return list_mean(losses_of_subnets), list_mean(psnr_of_subnets), valid_log
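
# Minimal sketch (assumptions throughout; not the repository's implementation)
# of what reset_running_statistics() does in the classification validate()
# above, and what is deliberately skipped here for oracle training:
# BatchNorm running mean/variance are re-estimated for the active subnet by
# forwarding a few calibration batches. `calib_loader` is an assumed loader.
def recalibrate_bn(net, calib_loader, num_batches=10):
    for m in net.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.reset_running_stats()
            m.momentum = None  # cumulative moving average over all batches
    net.train()  # BN computes batch stats and updates its running buffers
    with torch.no_grad():
        for b, (images, _) in enumerate(calib_loader):
            if b >= num_batches:
                break
            net(images.cuda())
    net.eval()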
Example #5
def train_one_epoch(run_manager, args, epoch, warmup_epochs=0, warmup_lr=0):
    dynamic_net = run_manager.net
    if isinstance(dynamic_net, nn.DataParallel):
        dynamic_net = dynamic_net.module

    # switch to train mode
    dynamic_net.train()
    # run_manager.run_config.train_loader.sampler.set_epoch(epoch)
    # MyRandomResizedCrop.EPOCH = epoch
    #################### Code for freezing BN; uncomment when overfitting.
    # for m in dynamic_net.modules():
    #     if isinstance(m, nn.BatchNorm2d):
    #         ########## Use running mean/var
    #         m.eval()
    #         ########## BN weight/bias freeze
    #         # m.weight.requires_grad = False
    #         # m.bias.requires_grad = False

    nBatch = len(run_manager.run_config.train_loader)

    data_time = AverageMeter()
    # losses = DistributedMetric('train_loss')
    # top1 = DistributedMetric('train_top1')
    # top5 = DistributedMetric('train_top5')
    losses = AverageMeter()
    psnr_averagemeter = AverageMeter()

    with tqdm(total=nBatch, desc='Train Epoch #{}'.format(epoch + 1)) as t:
        end = time.time()
        for i, mini_batch in enumerate(run_manager.run_config.train_loader):
            images = mini_batch['image']
            #################### Choose between the 2x and 4x downscaled inputs here.
            x2_down_images = mini_batch['2x_down_image']
            x4_down_images = mini_batch['4x_down_image']
            data_time.update(time.time() - end)
            if epoch < warmup_epochs:
                new_lr = run_manager.run_config.warmup_adjust_learning_rate(
                    run_manager.optimizer,
                    warmup_epochs * nBatch,
                    nBatch,
                    epoch,
                    i,
                    warmup_lr,
                )
            else:
                new_lr = run_manager.run_config.adjust_learning_rate(
                    run_manager.optimizer, epoch - warmup_epochs, i, nBatch)

            images = images.cuda()
            #################### Choose between the 2x and 4x downscaled inputs here.
            x2_down_images = x2_down_images.cuda()
            x4_down_images = x4_down_images.cuda()
            target = images

            # soft target
            if args.kd_ratio > 0:
                args.teacher_model.train()
                with torch.no_grad():
                    soft_logits = args.teacher_model(images).detach()
                    soft_label = F.softmax(soft_logits, dim=1)

            # clear gradients
            run_manager.optimizer.zero_grad()

            loss_of_subnets, psnr_of_subnets = [], []
            # compute output
            subnet_str = ''
            for _ in range(args.dynamic_batch_size):

                # set random seed before sampling
                if args.independent_distributed_sampling:
                    subnet_seed = os.getpid() + time.time()
                else:
                    subnet_seed = int('%d%.3d%.3d' %
                                      (epoch * nBatch + i, _, 0))
                random.seed(subnet_seed)

                #################### Toggle the comments to choose between random and structured sampling; edit here for single-architecture overfitting.
                subnet_settings = dynamic_net.sample_active_subnet()
                # dynamic_net.set_active_subnet(ks=7, e=3, d=2, pixel_d=1)

                subnet_str += '%d: ' % _ + ','.join([
                    '%s_%s' %
                    (key, '%.1f' %
                     subset_mean(val, 0) if isinstance(val, list) else val)
                    for key, val in subnet_settings.items()
                ]) + ' || '

                #################### Choose between the 2x and 4x downscaled inputs here.
                # output = run_manager.net(images)
                if subnet_settings['pixel_d'][0] == 1:
                    output = run_manager.net(x2_down_images)
                elif subnet_settings['pixel_d'][0] == 2:
                    output = run_manager.net(x4_down_images)
                else:  # guard against silently reusing a stale `output`
                    raise ValueError(
                        'unexpected pixel_d: %s' % subnet_settings['pixel_d'])

                if args.kd_ratio == 0:
                    loss = run_manager.train_criterion(output, images)
                    loss_type = 'mse'
                else:
                    if args.kd_type == 'ce':
                        kd_loss = cross_entropy_loss_with_soft_target(
                            output, soft_label)
                    else:
                        kd_loss = F.mse_loss(output, soft_logits)
                    loss = args.kd_ratio * kd_loss + run_manager.train_criterion(
                        output, images)
                    loss = loss * (2 / (args.kd_ratio + 1))
                    loss_type = '%.1fkd-%s & mse' % (args.kd_ratio,
                                                     args.kd_type)

                # measure accuracy and record loss
                # acc1, acc5 = accuracy(output, target, topk=(1, 5))
                psnr_current = psnr(rgb2y(tensor2img_np(output)),
                                    rgb2y(tensor2img_np(images)))
                loss_of_subnets.append(loss)
                # acc1_of_subnets.append(acc1[0])
                # acc5_of_subnets.append(acc5[0])
                psnr_of_subnets.append(psnr_current)

                loss.backward()
            run_manager.optimizer.step()

            losses.update(list_mean(loss_of_subnets), images.size(0))
            # top1.update(list_mean(acc1_of_subnets), images.size(0))
            # top5.update(list_mean(acc5_of_subnets), images.size(0))
            psnr_averagemeter.update(list_mean(psnr_of_subnets),
                                     images.size(0))

            t.set_postfix({
                'loss': losses.avg.item(),
                # 'top1': top1.avg.item(),
                # 'top5': top5.avg.item(),
                'psnr': psnr_averagemeter.avg,
                'R': images.size(2),
                'lr': new_lr,
                'loss_type': loss_type,
                'seed': str(subnet_seed),
                'str': subnet_str,
                'data_time': data_time.avg,
            })
            t.update(1)
            end = time.time()
    return losses.avg.item(), psnr_averagemeter.avg
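
# Minimal PSNR sketch (an assumption; the psnr/rgb2y/tensor2img_np helpers
# used above are not shown in this listing), consistent with how they are
# called: PSNR = 10 * log10(MAX^2 / MSE) on uint8-range image arrays.
import numpy as np

def psnr_sketch(img1, img2, max_val=255.0):
    mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2)
    if mse == 0:
        return float('inf')  # identical images
    return 10 * np.log10((max_val ** 2) / mse)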