Example 1
def main(pretrain=True):
    config.save = 'ckpt/{}'.format(config.save)
    create_exp_dir(config.save,
                   scripts_to_save=glob.glob('*.py') + glob.glob('*.sh'))
    logger = SummaryWriter(config.save)

    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(config.save, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    assert isinstance(pretrain, (bool, str))
    # Architecture parameters are frozen while pretraining the supernet weights.
    update_arch = pretrain is not True
    logging.info("args = %s", str(config))
    # preparation ################
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    seed = config.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # Model #######################################
    model = Network(config.layers,
                    slimmable=config.slimmable,
                    width_mult_list=config.width_mult_list,
                    width_mult_list_sh=config.width_mult_list_sh,
                    loss_weight=config.loss_weight,
                    prun_modes=config.prun_modes,
                    quantize=config.quantize)
    model = torch.nn.DataParallel(model).cuda()

    # print(model)

    # teacher_model = Generator(3, 3)
    # teacher_model.load_state_dict(torch.load(config.generator_A2B))
    # teacher_model = torch.nn.DataParallel(teacher_model).cuda()

    # for param in teacher_model.parameters():
    #     param.requires_grad = False

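    # Warm-start: load only the parameters whose names and shapes match the
    # current model, leaving everything else at its initialized values.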
    if isinstance(pretrain, str):
        partial = torch.load(pretrain + "/weights.pt")
        state = model.state_dict()
        pretrained_dict = {
            k: v
            for k, v in partial.items()
            if k in state and state[k].size() == partial[k].size()
        }
        state.update(pretrained_dict)
        model.load_state_dict(state)
    # else:
    #     features = [model.module.stem, model.module.cells, model.module.header]
    #     init_weight(features, nn.init.kaiming_normal_, nn.InstanceNorm2d, config.bn_eps, config.bn_momentum, mode='fan_in', nonlinearity='relu')

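    # The Architect handles the architecture-parameter (alpha/beta/ratio) updates
    # performed during search, i.e. whenever update_arch is True.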
    architect = Architect(model, config)

    # Optimizer ###################################
    base_lr = config.lr
    parameters = []
    parameters += list(model.module.stem.parameters())
    parameters += list(model.module.cells.parameters())
    parameters += list(model.module.header.parameters())

    if config.opt == 'Adam':
        optimizer = torch.optim.Adam(parameters,
                                     lr=base_lr,
                                     betas=config.betas)
    elif config.opt == 'Sgd':
        optimizer = torch.optim.SGD(parameters,
                                    lr=base_lr,
                                    momentum=config.momentum,
                                    weight_decay=config.weight_decay)
    else:
        logging.error("Unsupported optimizer type: %s", config.opt)
        sys.exit(1)

    # lr policy ##############################
    total_iteration = config.nepochs * config.niters_per_epoch

    if config.lr_schedule == 'linear':
        lr_policy = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=LambdaLR(config.nepochs, 0, config.decay_epoch).step)
    elif config.lr_schedule == 'exponential':
        lr_policy = torch.optim.lr_scheduler.ExponentialLR(
            optimizer, config.lr_decay)
    elif config.lr_schedule == 'multistep':
        lr_policy = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=config.milestones, gamma=config.gamma)
    else:
        logging.error("Unsupported learning rate schedule: %s", config.lr_schedule)
        sys.exit(1)

    # data loader ###########################

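    # ToTensor followed by Normalize((0.5,)*3, (0.5,)*3) maps inputs to the [-1, 1] range.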
    transforms_ = [
        # transforms.Resize(int(config.image_height*1.12), Image.BICUBIC),
        # transforms.RandomCrop(config.image_height),
        # transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]

    # train_loader_model = DataLoader(ImageDataset(config.dataset_path, transforms_=transforms_, unaligned=True, portion=config.train_portion),
    #                     batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers)
    # train_loader_arch = DataLoader(ImageDataset(config.dataset_path, transforms_=transforms_, unaligned=True, portion=config.train_portion-1),
    #                     batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers)

    train_loader_model = DataLoader(PairedImageDataset(
        config.dataset_path,
        config.target_path,
        transforms_=transforms_,
        portion=config.train_portion),
                                    batch_size=config.batch_size,
                                    shuffle=True,
                                    num_workers=config.num_workers)
    train_loader_arch = DataLoader(PairedImageDataset(
        config.dataset_path,
        config.target_path,
        transforms_=transforms_,
        portion=config.train_portion - 1),
                                   batch_size=config.batch_size,
                                   shuffle=True,
                                   num_workers=config.num_workers)

    transforms_ = [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]
    test_loader = DataLoader(ImageDataset(config.dataset_path,
                                          transforms_=transforms_,
                                          mode='test'),
                             batch_size=1,
                             shuffle=False,
                             num_workers=config.num_workers)

    tbar = tqdm(range(config.nepochs), ncols=80)
    valid_fid_history = []
    flops_history = []
    flops_supernet_history = []

    best_fid = float('inf')
    best_epoch = 0

    for epoch in tbar:
        logging.info(pretrain)
        logging.info(config.save)
        logging.info("lr: " + str(optimizer.param_groups[0]['lr']))

        logging.info("update arch: " + str(update_arch))

        # training
        tbar.set_description("[Epoch %d/%d][train...]" %
                             (epoch + 1, config.nepochs))
        train(pretrain,
              train_loader_model,
              train_loader_arch,
              model,
              architect,
              optimizer,
              lr_policy,
              logger,
              epoch,
              update_arch=update_arch)
        torch.cuda.empty_cache()
        lr_policy.step()

        # validation
        if epoch and not (epoch + 1) % config.eval_epoch:
            tbar.set_description("[Epoch %d/%d][validation...]" %
                                 (epoch + 1, config.nepochs))

            save(model, os.path.join(config.save, 'weights_%d.pt' % epoch))

            with torch.no_grad():
                if pretrain is True:
                    model.module.prun_mode = "min"
                    valid_fid = infer(epoch, model, test_loader, logger)
                    logger.add_scalar('fid/val_min', valid_fid, epoch)
                    logging.info("Epoch %d: valid_fid_min %.3f" %
                                 (epoch, valid_fid))

                    if len(model.module._width_mult_list) > 1:
                        model.module.prun_mode = "max"
                        valid_fid = infer(epoch, model, test_loader, logger)
                        logger.add_scalar('fid/val_max', valid_fid, epoch)
                        logging.info("Epoch %d: valid_fid_max %.3f" %
                                     (epoch, valid_fid))

                        model.module.prun_mode = "random"
                        valid_fid = infer(epoch, model, test_loader, logger)
                        logger.add_scalar('fid/val_random', valid_fid, epoch)
                        logging.info("Epoch %d: valid_fid_random %.3f" %
                                     (epoch, valid_fid))

                else:
                    model.module.prun_mode = None

                    valid_fid, flops = infer(epoch,
                                             model,
                                             test_loader,
                                             logger,
                                             finalize=True)

                    logger.add_scalar('fid/val', valid_fid, epoch)
                    logging.info("Epoch %d: valid_fid %.3f" %
                                 (epoch, valid_fid))

                    logger.add_scalar('flops/val', flops, epoch)
                    logging.info("Epoch %d: flops %.3f" % (epoch, flops))

                    valid_fid_history.append(valid_fid)
                    flops_history.append(flops)

                    if update_arch:
                        flops_supernet_history.append(architect.flops_supernet)

                if valid_fid < best_fid:
                    best_fid = valid_fid
                    best_epoch = epoch
                logging.info("Best fid:%.3f, Best epoch:%d" %
                             (best_fid, best_epoch))

                if update_arch:
                    state = {
                        'alpha': model.module.alpha,
                        'beta': model.module.beta,
                        'ratio': model.module.ratio,
                        'beta_sh': model.module.beta_sh,
                        'ratio_sh': model.module.ratio_sh,
                        'fid': valid_fid,
                        'flops': flops,
                    }

                    torch.save(
                        state,
                        os.path.join(config.save,
                                     "arch_%d_%f.pt" % (epoch, flops)))

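                    # Keep the searched FLOPs inside [flops_min, flops_max]:
                    # halve the penalty when under budget, double it when over.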
                    if config.flops_weight > 0:
                        if flops < config.flops_min:
                            architect.flops_weight /= 2
                        elif flops > config.flops_max:
                            architect.flops_weight *= 2
                        logger.add_scalar("arch/flops_weight",
                                          architect.flops_weight, epoch + 1)
                        logging.info("arch_flops_weight = " +
                                     str(architect.flops_weight))

    save(model, os.path.join(config.save, 'weights.pt'))

    if update_arch:
        torch.save(state, os.path.join(config.save, "arch.pt"))
Example 2
def main():
    config.save = 'ckpt/{}'.format(config.save)
    create_exp_dir(config.save, scripts_to_save=glob.glob('*.py')+glob.glob('*.sh'))
    logger = SummaryWriter(config.save)

    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=log_format, datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(config.save, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    logging.info("args = %s", str(config))
    # preparation ################
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    seed = config.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    state = torch.load(os.path.join(config.load_path, 'arch.pt'))
    # Model #######################################
    model = NAS_GAN_Infer(state['alpha'], state['beta'], state['ratio'], num_cell=config.num_cell, op_per_cell=config.op_per_cell, width_mult_list=config.width_mult_list, 
                          loss_weight=config.loss_weight, loss_func=config.loss_func, before_act=config.before_act, quantize=config.quantize)

    flops, params = profile(model, inputs=(torch.randn(1, 3, 510, 350),), custom_ops=custom_ops)
    flops = model.forward_flops(size=(3, 510, 350))
    logging.info("params = %fMB, FLOPs = %fGB", params / 1e6, flops / 1e9)

    model = torch.nn.DataParallel(model).cuda()

    if isinstance(config.pretrain, str):
        state_dict = torch.load(config.pretrain)
        model.load_state_dict(state_dict)
    # else:
    #     features = [model.module.cells, model.module.conv_first, model.module.trunk_conv, model.module.upconv1, 
    #                 model.module.upconv2, model.module.HRconv, model.module.conv_last]
    #     init_weight(features, nn.init.kaiming_normal_, nn.BatchNorm2d, config.bn_eps, config.bn_momentum, mode='fan_in', nonlinearity='relu')

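    # RRDB-based teacher generator loaded from config.generator_A2B; it is kept
    # frozen below and used only to provide distillation targets during training.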
    teacher_model = RRDBNet(3, 3, 64, 23, gc=32)
    teacher_model.load_state_dict(torch.load(config.generator_A2B), strict=True)
    teacher_model = torch.nn.DataParallel(teacher_model).cuda()
    teacher_model.eval()

    for param in teacher_model.parameters():
        param.requires_grad = False

    # Optimizer ###################################
    base_lr = config.lr
    parameters = []
    parameters += list(model.module.cells.parameters())
    parameters += list(model.module.conv_first.parameters())
    parameters += list(model.module.trunk_conv.parameters())
    parameters += list(model.module.upconv1.parameters())
    parameters += list(model.module.upconv2.parameters())
    parameters += list(model.module.HRconv.parameters())
    parameters += list(model.module.conv_last.parameters())

    if config.opt == 'Adam':
        optimizer = torch.optim.Adam(
            parameters,
            lr=base_lr,
            betas=config.betas)
    elif config.opt == 'Sgd':
        optimizer = torch.optim.SGD(
            parameters,
            lr=base_lr,
            momentum=config.momentum,
            weight_decay=config.weight_decay)
    else:
        logging.error("Unsupported optimizer type: %s", config.opt)
        sys.exit(1)

    # lr policy ##############################
    total_iteration = config.nepochs * config.niters_per_epoch

    if config.lr_schedule == 'linear':
        lr_policy = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=LambdaLR(config.nepochs, 0, config.decay_epoch).step)
    elif config.lr_schedule == 'exponential':
        lr_policy = torch.optim.lr_scheduler.ExponentialLR(optimizer, config.lr_decay)
    elif config.lr_schedule == 'multistep':
        lr_policy = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=config.milestones, gamma=config.gamma)
    else:
        logging.error("Unsupported learning rate schedule: %s", config.lr_schedule)
        sys.exit(1)


    # data loader ############################

    transforms_ = [ transforms.RandomCrop(config.image_height), 
                    transforms.RandomHorizontalFlip(), 
                    transforms.ToTensor()]
    train_loader_model = DataLoader(ImageDataset(config.dataset_path, transforms_=transforms_, unaligned=True), 
                        batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers)

    transforms_ = [ transforms.ToTensor()]
    test_loader = DataLoader(ImageDataset(config.dataset_path, transforms_=transforms_, mode='val'), 
                        batch_size=1, shuffle=False, num_workers=config.num_workers)


    if config.eval_only:
        logging.info('Eval: psnr = %f', infer(0, model, test_loader, logger))
        sys.exit(0)

    tbar = tqdm(range(config.nepochs), ncols=80)
    for epoch in tbar:
        logging.info(config.save)
        logging.info("lr: " + str(optimizer.param_groups[0]['lr']))

        # training
        tbar.set_description("[Epoch %d/%d][train...]" % (epoch + 1, config.nepochs))
        train(train_loader_model, model, teacher_model, optimizer, lr_policy, logger, epoch)
        torch.cuda.empty_cache()
        lr_policy.step()

        # validation
        if epoch and not (epoch+1) % config.eval_epoch:
            tbar.set_description("[Epoch %d/%d][validation...]" % (epoch + 1, config.nepochs))
            
            with torch.no_grad():
                model.prun_mode = None

                valid_psnr = infer(epoch, model, test_loader, logger)

                logger.add_scalar('psnr/val', valid_psnr, epoch)
                logging.info("Epoch %d: valid_psnr %.3f"%(epoch, valid_psnr))
                
                logger.add_scalar('flops/val', flops, epoch)
                logging.info("Epoch %d: flops %.3f"%(epoch, flops))

            save(model, os.path.join(config.save, 'weights_%d.pt'%epoch))

    save(model, os.path.join(config.save, 'weights.pt'))
Example 3
def main():
    create_exp_dir(config.save,
                   scripts_to_save=glob.glob('*.py') + glob.glob('*.sh'))
    logger = SummaryWriter(config.save)

    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(config.save, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)
    logging.info("args = %s", str(config))
    # preparation ################
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    seed = config.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # config network and criterion ################
    min_kept = int(config.batch_size * config.image_height *
                   config.image_width // (16 * config.gt_down_sampling**2))
    ohem_criterion = ProbOhemCrossEntropy2d(ignore_label=255,
                                            thresh=0.7,
                                            min_kept=min_kept,
                                            use_weight=False)
    distill_criterion = nn.KLDivLoss()

    # data loader ###########################
    # Only the train_source entry differs between the test and train settings.
    train_source = config.train_eval_source if config.is_test else config.train_source
    data_setting = {
        'img_root': config.img_root_folder,
        'gt_root': config.gt_root_folder,
        'train_source': train_source,
        'eval_source': config.eval_source,
        'test_source': config.test_source,
        'down_sampling': config.down_sampling
    }

    train_loader = get_train_loader(config, Cityscapes, test=config.is_test)

    # Model #######################################
    models = []
    evaluators = []
    testers = []
    lasts = []
    for idx, arch_idx in enumerate(config.arch_idx):
        if config.load_epoch == "last":
            state = torch.load(
                os.path.join(config.load_path, "arch_%d.pt" % arch_idx))
        else:
            state = torch.load(
                os.path.join(
                    config.load_path,
                    "arch_%d_%d.pt" % (arch_idx, int(config.load_epoch))))

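        # Rebuild the derived network for this arch_idx from the saved
        # architecture parameters (alpha/beta/ratio).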
        model = Network([
            state["alpha_%d_0" % arch_idx].detach(),
            state["alpha_%d_1" % arch_idx].detach(),
            state["alpha_%d_2" % arch_idx].detach()
        ], [
            None, state["beta_%d_1" % arch_idx].detach(),
            state["beta_%d_2" % arch_idx].detach()
        ], [
            state["ratio_%d_0" % arch_idx].detach(),
            state["ratio_%d_1" % arch_idx].detach(),
            state["ratio_%d_2" % arch_idx].detach()
        ],
                        num_classes=config.num_classes,
                        layers=config.layers,
                        Fch=config.Fch,
                        width_mult_list=config.width_mult_list,
                        stem_head_width=config.stem_head_width[idx],
                        ignore_skip=arch_idx == 0)

        mIoU02 = state["mIoU02"]
        latency02 = state["latency02"]
        obj02 = objective_acc_lat(mIoU02, latency02)
        mIoU12 = state["mIoU12"]
        latency12 = state["latency12"]
        obj12 = objective_acc_lat(mIoU12, latency12)
        # Keep the branch pair (0&2 vs. 1&2) with the better accuracy-latency objective.
        last = [2, 0] if obj02 > obj12 else [2, 1]
        lasts.append(last)
        model.build_structure(last)
        logging.info("net: " + str(model))
        for b in last:
            if len(config.width_mult_list) > 1:
                plot_op(getattr(model, "ops%d" % b),
                        getattr(model, "path%d" % b),
                        width=getattr(model, "widths%d" % b),
                        head_width=config.stem_head_width[idx][1],
                        F_base=config.Fch).savefig(os.path.join(
                            config.save, "ops_%d_%d.png" % (arch_idx, b)),
                                                   bbox_inches="tight")
            else:
                plot_op(getattr(model, "ops%d" % b),
                        getattr(model, "path%d" % b),
                        F_base=config.Fch).savefig(os.path.join(
                            config.save, "ops_%d_%d.png" % (arch_idx, b)),
                                                   bbox_inches="tight")
        plot_path_width(model.lasts, model.paths, model.widths).savefig(
            os.path.join(config.save, "path_width%d.png" % arch_idx))
        plot_path_width([2, 1, 0], [model.path2, model.path1, model.path0],
                        [model.widths2, model.widths1, model.widths0]).savefig(
                            os.path.join(config.save,
                                         "path_width_all%d.png" % arch_idx))
        flops, params = profile(model,
                                inputs=(torch.randn(1, 3, 1024, 2048), ))
        logging.info("params = %fMB, FLOPs = %fGB", params / 1e6, flops / 1e9)
        logging.info("ops:" + str(model.ops))
        logging.info("path:" + str(model.paths))
        logging.info("last:" + str(model.lasts))
        model = model.cuda()
        init_weight(model,
                    nn.init.kaiming_normal_,
                    torch.nn.BatchNorm2d,
                    config.bn_eps,
                    config.bn_momentum,
                    mode='fan_in',
                    nonlinearity='relu')

        if arch_idx == 0 and len(config.arch_idx) > 1:
            partial = torch.load(
                os.path.join(config.teacher_path, "weights%d.pt" % arch_idx))
            state = model.state_dict()
            pretrained_dict = {k: v for k, v in partial.items() if k in state}
            state.update(pretrained_dict)
            model.load_state_dict(state)
        elif config.is_eval:
            partial = torch.load(
                os.path.join(config.eval_path, "weights%d.pt" % arch_idx))
            state = model.state_dict()
            pretrained_dict = {k: v for k, v in partial.items() if k in state}
            state.update(pretrained_dict)
            model.load_state_dict(state)

        evaluator = SegEvaluator(Cityscapes(data_setting, 'val', None),
                                 config.num_classes,
                                 config.image_mean,
                                 config.image_std,
                                 model,
                                 config.eval_scale_array,
                                 config.eval_flip,
                                 0,
                                 out_idx=0,
                                 config=config,
                                 verbose=False,
                                 save_path=None,
                                 show_image=False)
        evaluators.append(evaluator)
        tester = SegTester(Cityscapes(data_setting, 'test', None),
                           config.num_classes,
                           config.image_mean,
                           config.image_std,
                           model,
                           config.eval_scale_array,
                           config.eval_flip,
                           0,
                           out_idx=0,
                           config=config,
                           verbose=False,
                           save_path=None,
                           show_image=False)
        testers.append(tester)

        # Optimizer ###################################
        base_lr = config.lr
        if arch_idx == 1 or len(config.arch_idx) == 1:
            # optimize teacher solo OR student (w. distill from teacher)
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=base_lr,
                                        momentum=config.momentum,
                                        weight_decay=config.weight_decay)
        models.append(model)

    # Cityscapes ###########################################
    if config.is_eval:
        logging.info(config.load_path)
        logging.info(config.eval_path)
        logging.info(config.save)
        # validation
        print("[validation...]")
        with torch.no_grad():
            valid_mIoUs = infer(models, evaluators, logger)
            for idx, arch_idx in enumerate(config.arch_idx):
                if arch_idx == 0:
                    logger.add_scalar("mIoU/val_teacher", valid_mIoUs[idx], 0)
                    logging.info("teacher's valid_mIoU %.3f" %
                                 (valid_mIoUs[idx]))
                else:
                    logger.add_scalar("mIoU/val_student", valid_mIoUs[idx], 0)
                    logging.info("student's valid_mIoU %.3f" %
                                 (valid_mIoUs[idx]))
        sys.exit(0)

    tbar = tqdm(range(config.nepochs), ncols=80)
    for epoch in tbar:
        logging.info(config.load_path)
        logging.info(config.save)
        logging.info("lr: " + str(optimizer.param_groups[0]['lr']))
        # training
        tbar.set_description("[Epoch %d/%d][train...]" %
                             (epoch + 1, config.nepochs))
        train_mIoUs = train(train_loader, models, ohem_criterion,
                            distill_criterion, optimizer, logger, epoch)
        torch.cuda.empty_cache()
        for idx, arch_idx in enumerate(config.arch_idx):
            if arch_idx == 0:
                logger.add_scalar("mIoU/train_teacher", train_mIoUs[idx],
                                  epoch)
                logging.info("teacher's train_mIoU %.3f" % (train_mIoUs[idx]))
            else:
                logger.add_scalar("mIoU/train_student", train_mIoUs[idx],
                                  epoch)
                logging.info("student's train_mIoU %.3f" % (train_mIoUs[idx]))
        adjust_learning_rate(base_lr, 0.992, optimizer, epoch + 1,
                             config.nepochs)

        # validation
        if not config.is_test and ((epoch + 1) % 10 == 0 or epoch == 0):
            tbar.set_description("[Epoch %d/%d][validation...]" %
                                 (epoch + 1, config.nepochs))
            with torch.no_grad():
                valid_mIoUs = infer(models, evaluators, logger)
                for idx, arch_idx in enumerate(config.arch_idx):
                    if arch_idx == 0:
                        logger.add_scalar("mIoU/val_teacher", valid_mIoUs[idx],
                                          epoch)
                        logging.info("teacher's valid_mIoU %.3f" %
                                     (valid_mIoUs[idx]))
                    else:
                        logger.add_scalar("mIoU/val_student", valid_mIoUs[idx],
                                          epoch)
                        logging.info("student's valid_mIoU %.3f" %
                                     (valid_mIoUs[idx]))
                    save(models[idx],
                         os.path.join(config.save, "weights%d.pt" % arch_idx))
        # test
        if config.is_test and (epoch + 1) >= 250 and (epoch + 1) % 10 == 0:
            tbar.set_description("[Epoch %d/%d][test...]" %
                                 (epoch + 1, config.nepochs))
            with torch.no_grad():
                test(epoch, models, testers, logger)

        for idx, arch_idx in enumerate(config.arch_idx):
            save(models[idx],
                 os.path.join(config.save, "weights%d.pt" % arch_idx))
Example 4
def main(pretrain=True):
    config.save = 'search-{}-{}'.format(config.save,
                                        time.strftime("%Y%m%d-%H%M%S"))
    create_exp_dir(config.save,
                   scripts_to_save=glob.glob('*.py') + glob.glob('*.sh'))
    logger = SummaryWriter(config.save)

    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(config.save, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    assert isinstance(pretrain, (bool, str))
    # Architecture parameters are frozen while pretraining the supernet weights.
    update_arch = pretrain is not True
    logging.info("args = %s", str(config))
    # preparation ################
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    seed = config.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # config network and criterion ################
    min_kept = int(config.batch_size * config.image_height *
                   config.image_width // (16 * config.gt_down_sampling**2))
    ohem_criterion = ProbOhemCrossEntropy2d(ignore_label=255,
                                            thresh=0.7,
                                            min_kept=min_kept,
                                            use_weight=False)

    # Model #######################################
    model = Network(config.num_classes,
                    config.layers,
                    ohem_criterion,
                    Fch=config.Fch,
                    width_mult_list=config.width_mult_list,
                    prun_modes=config.prun_modes,
                    stem_head_width=config.stem_head_width)
    flops, params = profile(model,
                            inputs=(torch.randn(1, 3, 1024, 2048), ),
                            verbose=False)
    logging.info("params = %fMB, FLOPs = %fGB", params / 1e6, flops / 1e9)
    model = model.cuda()
    if isinstance(pretrain, str):
        partial = torch.load(pretrain + "/weights.pt", map_location='cuda:0')
        state = model.state_dict()
        pretrained_dict = {
            k: v
            for k, v in partial.items()
            if k in state and state[k].size() == partial[k].size()
        }
        state.update(pretrained_dict)
        model.load_state_dict(state)
    else:
        init_weight(model,
                    nn.init.kaiming_normal_,
                    nn.BatchNorm2d,
                    config.bn_eps,
                    config.bn_momentum,
                    mode='fan_in',
                    nonlinearity='relu')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    architect = Architect(model, config)

    # Optimizer ###################################
    base_lr = config.lr
    parameters = []
    parameters += list(model.stem.parameters())
    parameters += list(model.cells.parameters())
    parameters += list(model.refine32.parameters())
    parameters += list(model.refine16.parameters())
    parameters += list(model.head0.parameters())
    parameters += list(model.head1.parameters())
    parameters += list(model.head2.parameters())
    parameters += list(model.head02.parameters())
    parameters += list(model.head12.parameters())
    optimizer = torch.optim.SGD(parameters,
                                lr=base_lr,
                                momentum=config.momentum,
                                weight_decay=config.weight_decay)

    # lr policy ##############################
    lr_policy = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.978)

    # data loader ###########################
    data_setting = {
        'img_root': config.img_root_folder,
        'gt_root': config.gt_root_folder,
        'train_source': config.train_source,
        'eval_source': config.eval_source,
        'down_sampling': config.down_sampling
    }
    train_loader_model = get_train_loader(config,
                                          EGTEA,
                                          portion=config.train_portion)
    train_loader_arch = get_train_loader(config,
                                         EGTEA,
                                         portion=config.train_portion - 1)

    evaluator = SegEvaluator(EGTEA(data_setting, 'val', None),
                             config.num_classes,
                             config.image_mean,
                             config.image_std,
                             model,
                             config.eval_scale_array,
                             config.eval_flip,
                             0,
                             config=config,
                             verbose=False,
                             save_path=None,
                             show_image=False)

    if update_arch:
        for idx in range(len(config.latency_weight)):
            logger.add_scalar("arch/latency_weight%d" % idx,
                              config.latency_weight[idx], 0)
            logging.info("arch_latency_weight%d = " % idx +
                         str(config.latency_weight[idx]))

    tbar = tqdm(range(config.nepochs), ncols=80)
    valid_mIoU_history = []
    FPSs_history = []
    latency_supernet_history = []
    latency_weight_history = []
    valid_names = ["8s", "16s", "32s", "8s_32s", "16s_32s"]
    arch_names = {0: "teacher", 1: "student"}
    for epoch in tbar:
        logging.info(pretrain)
        logging.info(config.save)
        logging.info("lr: " + str(optimizer.param_groups[0]['lr']))

        logging.info("update arch: " + str(update_arch))

        # training
        tbar.set_description("[Epoch %d/%d][train...]" %
                             (epoch + 1, config.nepochs))
        train(pretrain,
              train_loader_model,
              train_loader_arch,
              model,
              architect,
              ohem_criterion,
              optimizer,
              lr_policy,
              logger,
              epoch,
              update_arch=update_arch)
        torch.cuda.empty_cache()
        lr_policy.step()

        # validation
        tbar.set_description("[Epoch %d/%d][validation...]" %
                             (epoch + 1, config.nepochs))
        with torch.no_grad():
            if pretrain is True:
                model.prun_mode = "min"
                valid_mIoUs = infer(epoch, model, evaluator, logger, FPS=False)
                for i in range(5):
                    logger.add_scalar('mIoU/val_min_%s' % valid_names[i],
                                      valid_mIoUs[i], epoch)
                    logging.info("Epoch %d: valid_mIoU_min_%s %.3f" %
                                 (epoch, valid_names[i], valid_mIoUs[i]))
                if len(model._width_mult_list) > 1:
                    model.prun_mode = "max"
                    valid_mIoUs = infer(epoch,
                                        model,
                                        evaluator,
                                        logger,
                                        FPS=False)
                    for i in range(5):
                        logger.add_scalar('mIoU/val_max_%s' % valid_names[i],
                                          valid_mIoUs[i], epoch)
                        logging.info("Epoch %d: valid_mIoU_max_%s %.3f" %
                                     (epoch, valid_names[i], valid_mIoUs[i]))
                    model.prun_mode = "random"
                    valid_mIoUs = infer(epoch,
                                        model,
                                        evaluator,
                                        logger,
                                        FPS=False)
                    for i in range(5):
                        logger.add_scalar(
                            'mIoU/val_random_%s' % valid_names[i],
                            valid_mIoUs[i], epoch)
                        logging.info("Epoch %d: valid_mIoU_random_%s %.3f" %
                                     (epoch, valid_names[i], valid_mIoUs[i]))
            else:
                valid_mIoUss = []
                FPSs = []
                model.prun_mode = None
                for idx in range(len(model._arch_names)):
                    # arch_idx
                    model.arch_idx = idx
                    valid_mIoUs, fps0, fps1 = infer(epoch, model, evaluator,
                                                    logger)
                    valid_mIoUss.append(valid_mIoUs)
                    FPSs.append([fps0, fps1])
                    for i in range(5):
                        # preds
                        logger.add_scalar(
                            'mIoU/val_%s_%s' %
                            (arch_names[idx], valid_names[i]), valid_mIoUs[i],
                            epoch)
                        logging.info("Epoch %d: valid_mIoU_%s_%s %.3f" %
                                     (epoch, arch_names[idx], valid_names[i],
                                      valid_mIoUs[i]))
                    if config.latency_weight[idx] > 0:
                        logger.add_scalar(
                            'Objective/val_%s_8s_32s' % arch_names[idx],
                            objective_acc_lat(valid_mIoUs[3], 1000. / fps0),
                            epoch)
                        logging.info(
                            "Epoch %d: Objective_%s_8s_32s %.3f" %
                            (epoch, arch_names[idx],
                             objective_acc_lat(valid_mIoUs[3], 1000. / fps0)))
                        logger.add_scalar(
                            'Objective/val_%s_16s_32s' % arch_names[idx],
                            objective_acc_lat(valid_mIoUs[4], 1000. / fps1),
                            epoch)
                        logging.info(
                            "Epoch %d: Objective_%s_16s_32s %.3f" %
                            (epoch, arch_names[idx],
                             objective_acc_lat(valid_mIoUs[4], 1000. / fps1)))
                valid_mIoU_history.append(valid_mIoUss)
                FPSs_history.append(FPSs)
                if update_arch:
                    latency_supernet_history.append(architect.latency_supernet)
                latency_weight_history.append(architect.latency_weight)

        save(model, os.path.join(config.save, 'weights.pt'))
        if isinstance(pretrain, str):
            # contains arch_param names: {"alphas": alphas, "betas": betas, "gammas": gammas, "ratios": ratios}
            for idx, arch_name in enumerate(model._arch_names):
                state = {}
                for name in arch_name['alphas']:
                    state[name] = getattr(model, name)
                for name in arch_name['betas']:
                    state[name] = getattr(model, name)
                for name in arch_name['ratios']:
                    state[name] = getattr(model, name)
                state["mIoU02"] = valid_mIoUs[3]
                state["mIoU12"] = valid_mIoUs[4]
                if pretrain is not True:
                    state["latency02"] = 1000. / fps0
                    state["latency12"] = 1000. / fps1
                torch.save(
                    state,
                    os.path.join(config.save, "arch_%d_%d.pt" % (idx, epoch)))
                torch.save(state,
                           os.path.join(config.save, "arch_%d.pt" % (idx)))

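        # Adapt the latency penalty per architecture: halve it once a branch is
        # fast enough (FPS >= FPS_max), double it if a branch is too slow (FPS <= FPS_min).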
        if update_arch:
            for idx in range(len(config.latency_weight)):
                if config.latency_weight[idx] > 0:
                    if (int(FPSs[idx][0] >= config.FPS_max[idx]) +
                            int(FPSs[idx][1] >= config.FPS_max[idx])) >= 1:
                        architect.latency_weight[idx] /= 2
                    elif (int(FPSs[idx][0] <= config.FPS_min[idx]) +
                          int(FPSs[idx][1] <= config.FPS_min[idx])) > 0:
                        architect.latency_weight[idx] *= 2
                    logger.add_scalar(
                        "arch/latency_weight_%s" % arch_names[idx],
                        architect.latency_weight[idx], epoch + 1)
                    logging.info("arch_latency_weight_%s = " %
                                 arch_names[idx] +
                                 str(architect.latency_weight[idx]))