Example #1
def main():
    checkpoint = utils.checkpoint(args)
    writer_train = SummaryWriter(args.job_dir + '/run/train')
    writer_test = SummaryWriter(args.job_dir + '/run/test')

    start_epoch = 0
    best_prec1 = 0.0
    best_prec5 = 0.0

    # Data loading
    print('=> Preparing data..')
    logging.info('=> Preparing data..')

    traindir = os.path.join('/mnt/cephfs_hl/cv/ImageNet/',
                            'ILSVRC2012_img_train_rec')
    valdir = os.path.join('/mnt/cephfs_hl/cv/ImageNet/',
                          'ILSVRC2012_img_val_rec')
    train_loader, val_loader = getTrainValDataset(traindir, valdir,
                                                  batch_sizes, 100, num_gpu,
                                                  num_workers)

    # Create model
    print('=> Building model...')
    logging.info('=> Building model...')

    model_t = ResNet50()

    # Load teacher model (this checkpoint stores the state dict directly)
    ckpt_t = torch.load(args.teacher_dir,
                        map_location=torch.device(f"cuda:{args.gpus[0]}"))
    state_dict_t = ckpt_t

    model_t.load_state_dict(state_dict_t)
    model_t = model_t.to(args.gpus[0])

    # Freeze every teacher parameter except the last two (the FC weight and bias)
    for para in list(model_t.parameters())[:-2]:
        para.requires_grad = False

    model_s = ResNet50_sprase().to(args.gpus[0])
    model_dict_s = model_s.state_dict()
    model_dict_s.update(state_dict_t)
    model_s.load_state_dict(model_dict_s)

    model_d = Discriminator().to(args.gpus[0])

    model_s = nn.DataParallel(model_s).cuda()
    model_t = nn.DataParallel(model_t).cuda()
    model_d = nn.DataParallel(model_d).cuda()

    optimizer_d = optim.SGD(model_d.parameters(),
                            lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)

    param_s = [
        param for name, param in model_s.named_parameters()
        if 'mask' not in name
    ]
    param_m = [
        param for name, param in model_s.named_parameters() if 'mask' in name
    ]

    optimizer_s = optim.SGD(param_s,
                            lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)
    optimizer_m = FISTA(param_m, lr=args.lr * 100, gamma=args.sparse_lambda)

    scheduler_d = StepLR(optimizer_d, step_size=args.lr_decay_step, gamma=0.1)
    scheduler_s = StepLR(optimizer_s, step_size=args.lr_decay_step, gamma=0.1)
    scheduler_m = StepLR(optimizer_m, step_size=args.lr_decay_step, gamma=0.1)

    resume = args.resume
    if resume:
        print('=> Resuming from ckpt {}'.format(resume))
        ckpt = torch.load(resume,
                          map_location=torch.device(f"cuda:{args.gpus[0]}"))
        state_dict_s = ckpt['state_dict_s']
        state_dict_d = ckpt['state_dict_d']

        new_state_dict_s = OrderedDict()
        for k, v in state_dict_s.items():
            new_state_dict_s['module.' + k] = v

        best_prec1 = ckpt['best_prec1']
        model_s.load_state_dict(new_state_dict_s)
        model_d.load_state_dict(ckpt['state_dict_d'])
        optimizer_d.load_state_dict(ckpt['optimizer_d'])
        optimizer_s.load_state_dict(ckpt['optimizer_s'])
        optimizer_m.load_state_dict(ckpt['optimizer_m'])
        scheduler_d.load_state_dict(ckpt['scheduler_d'])
        scheduler_s.load_state_dict(ckpt['scheduler_s'])
        scheduler_m.load_state_dict(ckpt['scheduler_m'])
        start_epoch = ckpt['epoch']
        print('=> Continue from epoch {}...'.format(ckpt['epoch']))

    models = [model_t, model_s, model_d]
    optimizers = [optimizer_d, optimizer_s, optimizer_m]
    schedulers = [scheduler_d, scheduler_s, scheduler_m]

    for epoch in range(start_epoch, args.num_epochs):
        for s in schedulers:
            s.step(epoch)

        train(args, train_loader, models, optimizers, epoch, writer_train)
        test_prec1, test_prec5 = test(args, val_loader, model_s)

        is_best = best_prec1 < test_prec1
        best_prec1 = max(test_prec1, best_prec1)
        best_prec5 = max(test_prec5, best_prec5)

        model_state_dict = model_s.module.state_dict() if len(
            args.gpus) > 1 else model_s.state_dict()

        state = {
            'state_dict_s': model_state_dict,
            'state_dict_d': model_d.state_dict(),
            'best_prec1': best_prec1,
            'best_prec5': best_prec5,
            'optimizer_d': optimizer_d.state_dict(),
            'optimizer_s': optimizer_s.state_dict(),
            'optimizer_m': optimizer_m.state_dict(),
            'scheduler_d': scheduler_d.state_dict(),
            'scheduler_s': scheduler_s.state_dict(),
            'scheduler_m': scheduler_m.state_dict(),
            'epoch': epoch + 1
        }
        train_loader.reset()
        val_loader.reset()
        checkpoint.save_model(state, epoch + 1, is_best)

    print(f"=> Best @prec1: {best_prec1:.3f} @prec5: {best_prec5:.3f}")
    logging.info('Best Top1: %.3f Top5: %.3f', best_prec1, best_prec5)
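The student's parameters are split into ordinary weights ('mask' not in name) and mask parameters, and only the masks get the FISTA optimizer with sparsity weight args.sparse_lambda. A minimal sketch of the proximal (soft-thresholding) update such a FISTA-style mask optimizer typically performs; the momentum bookkeeping of a real FISTA is omitted, and the class name here is illustrative, not the repository's FISTA:

import torch
from torch.optim.optimizer import Optimizer


class SoftThresholdSGD(Optimizer):
    """Illustrative proximal-gradient step: a plain SGD update followed by
    soft-thresholding, which drives mask entries exactly to zero."""

    def __init__(self, params, lr=0.01, gamma=0.1):
        super().__init__(params, dict(lr=lr, gamma=gamma))

    @torch.no_grad()
    def step(self, closure=None):
        for group in self.param_groups:
            lr, gamma = group['lr'], group['gamma']
            for p in group['params']:
                if p.grad is None:
                    continue
                p.add_(p.grad, alpha=-lr)  # gradient step
                # prox of the l1 penalty: shrink magnitudes, clamp at zero
                p.copy_(p.sign() * torch.clamp(p.abs() - lr * gamma, min=0))

Because the class exposes param_groups like any torch optimizer, pairing it with a StepLR scheduler, as the scripts above do, works unchanged.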
Example #2
def main():

    start_epoch = 0
    best_prec1 = 0.0
    best_prec5 = 0.0

    # Data loading
    print('=> Preparing data..')
    loader = cifar10(args)

    # Create model
    print('=> Building model...')
    model_t = import_module(f'model.{args.arch}').__dict__[args.teacher_model]().to(device)

    # Load teacher model
    ckpt_t = torch.load(args.teacher_dir, map_location=device)

    if args.arch == 'densenet':
        state_dict_t = {}
        for k, v in ckpt_t['state_dict'].items():
            new_key = '.'.join(k.split('.')[1:])
            if new_key == 'linear.weight':
                new_key = 'fc.weight'
            elif new_key == 'linear.bias':
                new_key = 'fc.bias'
            state_dict_t[new_key] = v
    else:
        state_dict_t = ckpt_t['state_dict']


    model_t.load_state_dict(state_dict_t)
    model_t = model_t.to(device)

    for para in list(model_t.parameters())[:-2]:
        para.requires_grad = False

    model_s = import_module(f'model.{args.arch}').__dict__[args.student_model]().to(device)

    model_dict_s = model_s.state_dict()
    model_dict_s.update(state_dict_t)
    model_s.load_state_dict(model_dict_s)

    if len(args.gpus) != 1:
        model_s = nn.DataParallel(model_s, device_ids=args.gpus)

    model_d = Discriminator().to(device) 

    models = [model_t, model_s, model_d]

    optimizer_d = optim.SGD(model_d.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    param_s = [param for name, param in model_s.named_parameters() if 'mask' not in name]
    param_m = [param for name, param in model_s.named_parameters() if 'mask' in name]

    optimizer_s = optim.SGD(param_s, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer_m = FISTA(param_m, lr=args.lr, gamma=args.sparse_lambda)

    scheduler_d = StepLR(optimizer_d, step_size=args.lr_decay_step, gamma=0.1)
    scheduler_s = StepLR(optimizer_s, step_size=args.lr_decay_step, gamma=0.1)
    scheduler_m = StepLR(optimizer_m, step_size=args.lr_decay_step, gamma=0.1)

    resume = args.resume
    if resume:
        print('=> Resuming from ckpt {}'.format(resume))
        ckpt = torch.load(resume, map_location=device)
        best_prec1 = ckpt['best_prec1']
        start_epoch = ckpt['epoch']

        model_s.load_state_dict(ckpt['state_dict_s'])
        model_d.load_state_dict(ckpt['state_dict_d'])
        optimizer_d.load_state_dict(ckpt['optimizer_d'])
        optimizer_s.load_state_dict(ckpt['optimizer_s'])
        optimizer_m.load_state_dict(ckpt['optimizer_m'])
        scheduler_d.load_state_dict(ckpt['scheduler_d'])
        scheduler_s.load_state_dict(ckpt['scheduler_s'])
        scheduler_m.load_state_dict(ckpt['scheduler_m'])
        print('=> Continue from epoch {}...'.format(start_epoch))


    if args.test_only:
        test_prec1, test_prec5 = test(args, loader.loader_test, model_s)
        print('=> Test Prec@1: {:.2f}'.format(test_prec1))
        return

    optimizers = [optimizer_d, optimizer_s, optimizer_m]
    schedulers = [scheduler_d, scheduler_s, scheduler_m]
    for epoch in range(start_epoch, args.num_epochs):
        for s in schedulers:
            s.step(epoch)

        train(args, loader.loader_train, models, optimizers, epoch)
        test_prec1, test_prec5 = test(args, loader.loader_test, model_s)

        is_best = best_prec1 < test_prec1
        best_prec1 = max(test_prec1, best_prec1)
        best_prec5 = max(test_prec5, best_prec5)

        model_state_dict = model_s.module.state_dict() if len(args.gpus) > 1 else model_s.state_dict()

        state = {
            'state_dict_s': model_state_dict,
            'state_dict_d': model_d.state_dict(),
            'best_prec1': best_prec1,
            'best_prec5': best_prec5,
            'optimizer_d': optimizer_d.state_dict(),
            'optimizer_s': optimizer_s.state_dict(),
            'optimizer_m': optimizer_m.state_dict(),
            'scheduler_d': scheduler_d.state_dict(),
            'scheduler_s': scheduler_s.state_dict(),
            'scheduler_m': scheduler_m.state_dict(),
            'epoch': epoch + 1
        }
        checkpoint.save_model(state, epoch + 1, is_best)

    print_logger.info(f"Best @prec1: {best_prec1:.3f} @prec5: {best_prec5:.3f}")

    best_model = torch.load(f'{args.job_dir}/checkpoint/model_best.pt', map_location=device)

    model = import_module('utils.preprocess').__dict__[f'{args.arch}'](args, best_model['state_dict_s'])
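The densenet branch above renames checkpoint keys inline (dropping the leading prefix component and mapping linear.* to fc.*). The same remapping can be written as a small reusable helper; remap_keys is a hypothetical name used only for illustration:

def remap_keys(state_dict, rename=None):
    """Strip the leading component (e.g. 'module.') from every key and
    apply an optional old->new rename table afterwards."""
    rename = rename or {}
    out = {}
    for k, v in state_dict.items():
        new_key = '.'.join(k.split('.')[1:])
        out[rename.get(new_key, new_key)] = v
    return out

# Usage matching the densenet branch above:
# state_dict_t = remap_keys(ckpt_t['state_dict'],
#                           rename={'linear.weight': 'fc.weight',
#                                   'linear.bias': 'fc.bias'})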
Example #3
def main():
    checkpoint = utils.checkpoint(args)
    writer_train = SummaryWriter(args.job_dir + '/run/train')
    writer_test = SummaryWriter(args.job_dir + '/run/test')

    start_epoch = 0
    best_prec1 = 0.0
    best_prec5 = 0.0

    # Data loading
    print('=> Preparing data..')
    loader = cifar10(args)

    # Create model
    print('=> Building model...')
    model_t = resnet_56().to(args.gpus[0])

    # Load teacher model
    ckpt_t = torch.load(args.teacher_dir,
                        map_location=torch.device(f"cuda:{args.gpus[0]}"))
    state_dict_t = ckpt_t['state_dict']
    model_t.load_state_dict(state_dict_t)
    model_t = model_t.to(args.gpus[0])

    for para in list(model_t.parameters())[:-2]:
        para.requires_grad = False

    model_s = resnet_56_sparse().to(args.gpus[0])

    model_dict_s = model_s.state_dict()
    model_dict_s.update(state_dict_t)
    model_s.load_state_dict(model_dict_s)

    if len(args.gpus) != 1:
        model_s = nn.DataParallel(model_s, device_ids=args.gpus)

    model_d = Discriminator().to(args.gpus[0])

    models = [model_t, model_s, model_d]

    optimizer_d = optim.SGD(model_d.parameters(),
                            lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)

    param_s = [
        param for name, param in model_s.named_parameters()
        if 'mask' not in name
    ]
    param_m = [
        param for name, param in model_s.named_parameters() if 'mask' in name
    ]

    optimizer_s = optim.SGD(param_s,
                            lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)
    optimizer_m = FISTA(param_m, lr=args.lr, gamma=args.sparse_lambda)

    scheduler_d = StepLR(optimizer_d, step_size=args.lr_decay_step, gamma=0.1)
    scheduler_s = StepLR(optimizer_s, step_size=args.lr_decay_step, gamma=0.1)
    scheduler_m = StepLR(optimizer_m, step_size=args.lr_decay_step, gamma=0.1)

    resume = args.resume
    if resume:
        print('=> Resuming from ckpt {}'.format(resume))
        ckpt = torch.load(resume,
                          map_location=torch.device(f"cuda:{args.gpus[0]}"))
        best_prec1 = ckpt['best_prec1']
        start_epoch = ckpt['epoch']
        model_s.load_state_dict(ckpt['state_dict_s'])
        model_d.load_state_dict(ckpt['state_dict_d'])
        optimizer_d.load_state_dict(ckpt['optimizer_d'])
        optimizer_s.load_state_dict(ckpt['optimizer_s'])
        optimizer_m.load_state_dict(ckpt['optimizer_m'])
        scheduler_d.load_state_dict(ckpt['scheduler_d'])
        scheduler_s.load_state_dict(ckpt['scheduler_s'])
        scheduler_m.load_state_dict(ckpt['scheduler_m'])
        print('=> Continue from epoch {}...'.format(start_epoch))

    optimizers = [optimizer_d, optimizer_s, optimizer_m]
    schedulers = [scheduler_d, scheduler_s, scheduler_m]

    if args.test_only:
        test_prec1, test_prec5 = test(args, loader.loader_test, model_s)
        print('=> Test Prec@1: {:.2f}'.format(test_prec1))
        return

    for epoch in range(start_epoch, args.num_epochs):
        for s in schedulers:
            s.step(epoch)

        train(args, loader.loader_train, models, optimizers, epoch,
              writer_train)
        test_prec1, test_prec5 = test(args, loader.loader_test, model_s)

        is_best = best_prec1 < test_prec1
        best_prec1 = max(test_prec1, best_prec1)
        best_prec5 = max(test_prec5, best_prec5)

        model_state_dict = model_s.module.state_dict() if len(
            args.gpus) > 1 else model_s.state_dict()

        state = {
            'state_dict_s': model_state_dict,
            'state_dict_d': model_d.state_dict(),
            'best_prec1': best_prec1,
            'best_prec5': best_prec5,
            'optimizer_d': optimizer_d.state_dict(),
            'optimizer_s': optimizer_s.state_dict(),
            'optimizer_m': optimizer_m.state_dict(),
            'scheduler_d': scheduler_d.state_dict(),
            'scheduler_s': scheduler_s.state_dict(),
            'scheduler_m': scheduler_m.state_dict(),
            'epoch': epoch + 1
        }
        checkpoint.save_model(state, epoch + 1, is_best)

    print(f"=> Best @prec1: {best_prec1:.3f} @prec5: {best_prec5:.3f}")

    best_model = torch.load(f'{args.job_dir}/checkpoint/model_best.pt',
                            map_location=torch.device(f"cuda:{args.gpus[0]}"))

    model = prune_resnet(args, best_model['state_dict_s'])
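All three schedulers in these scripts are StepLR with gamma=0.1, so each learning rate is divided by 10 every lr_decay_step epochs. A minimal self-contained check of that behavior, with hypothetical values rather than the script's arguments:

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

param = torch.nn.Parameter(torch.zeros(1))
opt = SGD([param], lr=0.1)
sched = StepLR(opt, step_size=30, gamma=0.1)

for epoch in range(90):
    # ... train one epoch, then advance the schedule ...
    sched.step()

# lr was 0.1 for epochs 0-29, 0.01 for 30-59, 0.001 for 60-89
print(opt.param_groups[0]['lr'])  # 1e-4 after the 90th step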
Example #4
def main():
    args = get_args()

    # get log
    args.save = 'search-{}-{}'.format(args.save,
                                      time.strftime("%Y%m%d-%H%M%S"))
    tools.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logger = logging.getLogger('Train Search')
    logger.addHandler(fh)

    # monitor
    pymonitor = ProgressMonitor(logger)
    tbmonitor = TensorBoardMonitor(logger, args.save)
    monitors = [pymonitor, tbmonitor]

    # set random seed
    if args.seed is None:
        args.seed = random.randint(1, 10000)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    args.use_cuda = args.gpus > 0 and torch.cuda.is_available()
    args.multi_gpu = args.gpus > 1 and torch.cuda.is_available()
    args.device = torch.device('cuda:0' if args.use_cuda else 'cpu')
    if args.use_cuda:
        torch.cuda.manual_seed(args.seed)
        cudnn.enabled = True
        cudnn.benchmark = True
    setting = {k: v for k, v in args._get_kwargs()}
    logger.info(setting)
    with open(os.path.join(args.save, "args.yaml"),
              "w") as yaml_file:  # dump experiment config
        yaml.dump(args, yaml_file)

    # get dataloader
    if args.dataset_name == "cifar10":
        train_transform, valid_transform = tools._data_transforms_cifar10(args)
        traindata = dset.CIFAR10(root=args.dataset,
                                 train=True,
                                 download=False,
                                 transform=train_transform)
        valdata = dset.CIFAR10(root=args.dataset,
                               train=False,
                               download=False,
                               transform=valid_transform)
    else:
        train_transform, valid_transform = tools._data_transforms_mnist(args)
        traindata = dset.MNIST(root=args.dataset,
                               train=True,
                               download=False,
                               transform=train_transform)
        valdata = dset.MNIST(root=args.dataset,
                             train=False,
                             download=False,
                             transform=valid_transform)
    trainLoader = torch.utils.data.DataLoader(traindata,
                                              batch_size=args.batch_size,
                                              pin_memory=True,
                                              shuffle=True,
                                              num_workers=args.workers)
    valLoader = torch.utils.data.DataLoader(valdata,
                                            batch_size=args.batch_size,
                                            pin_memory=True,
                                            num_workers=args.workers)

    # load pretrained model
    model_t = Network(C=args.init_channels,
                      num_classes=args.class_num,
                      layers=args.layers,
                      steps=args.nodes,
                      multiplier=args.nodes,
                      stem_multiplier=args.stem_multiplier,
                      group=args.group)
    model_t, _, _ = loadCheckpoint(args.model_path, model_t, args)
    model_t.freeze_arch_parameters()
    # Freeze the teacher network
    for para in list(model_t.parameters())[:-2]:
        para.requires_grad = False

    model_s = Network(C=args.init_channels,
                      num_classes=args.class_num,
                      layers=args.layers,
                      steps=args.nodes,
                      multiplier=args.nodes,
                      stem_multiplier=args.stem_multiplier,
                      group=args.group)
    model_s, _, _ = loadCheckpoint(args.model_path, model_s, args)
    model_s._initialize_alphas()

    criterion = nn.CrossEntropyLoss().to(args.device)
    model_d = Discriminator().to(args.device)
    model_s = model_s.to(args.device)
    logger.info("param size = %fMB", tools.count_parameters_in_MB(model_s))

    optimizer_d = optim.SGD(model_d.parameters(),
                            lr=args.learning_rate,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)
    optimizer_s = optim.SGD(model_s.weight_parameters(),
                            lr=args.learning_rate,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)
    optimizer_m = FISTA(model_s.arch_parameters(),
                        lr=args.learning_rate,
                        gamma=args.sparse_lambda)

    scheduler_d = StepLR(optimizer_d, step_size=args.lr_decay_step, gamma=0.1)
    scheduler_s = StepLR(optimizer_s, step_size=args.lr_decay_step, gamma=0.1)
    scheduler_m = StepLR(optimizer_m, step_size=args.lr_decay_step, gamma=0.1)

    perf_scoreboard = PerformanceScoreboard(args.num_best_scores)

    start_epoch = 0
    if args.resume:
        logger.info('=> Resuming from ckpt {}'.format(args.resume_path))
        ckpt = torch.load(args.resume_path, map_location=args.device)
        start_epoch = ckpt['epoch']
        model_s.load_state_dict(ckpt['state_dict_s'])
        model_d.load_state_dict(ckpt['state_dict_d'])
        optimizer_d.load_state_dict(ckpt['optimizer_d'])
        optimizer_s.load_state_dict(ckpt['optimizer_s'])
        optimizer_m.load_state_dict(ckpt['optimizer_m'])
        scheduler_d.load_state_dict(ckpt['scheduler_d'])
        scheduler_s.load_state_dict(ckpt['scheduler_s'])
        scheduler_m.load_state_dict(ckpt['scheduler_m'])
        perf_scoreboard = ckpt['perf_scoreboard']
        logger.info('=> Continue from epoch {}...'.format(start_epoch))

    models = [model_t, model_s, model_d]
    optimizers = [optimizer_d, optimizer_s, optimizer_m]
    schedulers = [scheduler_d, scheduler_s, scheduler_m]

    for epoch in range(start_epoch, args.num_epochs):
        for s in schedulers:
            logger.info('epoch %d lr %e ', epoch, s.get_lr()[0])

        _, _, _ = train(trainLoader, models, epoch, optimizers, monitors, args,
                        logger)
        v_top1, v_top5, v_loss = validate(valLoader, model_s, criterion, epoch,
                                          monitors, args, logger)

        l, board = perf_scoreboard.update(v_top1, v_top5, epoch)
        for idx in range(l):
            score = board[idx]
            logger.info(
                'Scoreboard best %d ==> Epoch [%d][Top1: %.3f   Top5: %.3f]',
                idx + 1, score['epoch'], score['top1'], score['top5'])

        logger.info("normal: \n{}".format(
            model_s.alphas_normal.data.cpu().numpy()))
        logger.info("reduce: \n{}".format(
            model_s.alphas_reduce.data.cpu().numpy()))
        logger.info('Genotypev1: {}'.format(model_s.genotypev1()))
        logger.info('Genotypev2: {}'.format(model_s.genotypev2()))
        logger.info('Genotypev3: {}'.format(model_s.genotypev3()))
        pruned = 0
        num = 0
        for param in model_s.arch_parameters():
            param_array = param.detach().cpu().numpy()
            # count entries driven exactly to zero by the sparse update
            pruned += int((param_array == 0).sum())
            num += param_array.size
        logger.info("Epoch:{} Pruned {} / {}".format(epoch, pruned, num))

        is_best = perf_scoreboard.is_best(epoch)
        if epoch % args.save_freq == 0:
            model_state_dict = model_s.module.state_dict() if len(
                args.gpus) > 1 else model_s.state_dict()
            state = {
                'state_dict_s': model_state_dict,
                'state_dict_d': model_d.state_dict(),
                'optimizer_d': optimizer_d.state_dict(),
                'optimizer_s': optimizer_s.state_dict(),
                'optimizer_m': optimizer_m.state_dict(),
                'scheduler_d': scheduler_d.state_dict(),
                'scheduler_s': scheduler_s.state_dict(),
                'scheduler_m': scheduler_m.state_dict(),
                "perf_scoreboard": perf_scoreboard,
                'epoch': epoch + 1
            }
            tools.save_model(state,
                             epoch + 1,
                             is_best,
                             path=os.path.join(args.save, "ckpt"))
        # update learning rate
        for s in schedulers:
            s.step(epoch)
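The pruning statistics logged in the loop above count architecture entries driven exactly to zero. That per-element count can be written directly with tensor ops, as in this standalone sketch:

import torch

def count_pruned(params):
    """Return (zeroed entries, total entries) across a list of tensors."""
    pruned = sum(int((p == 0).sum()) for p in params)
    total = sum(p.numel() for p in params)
    return pruned, total

alphas = [torch.tensor([[0.0, 0.3], [0.0, 0.0]])]
print(count_pruned(alphas))  # (3, 4)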
Example #5
def main():
    checkpoint = utils.checkpoint(args)
    writer_train = SummaryWriter(args.job_dir + '/run/train')
    writer_test = SummaryWriter(args.job_dir + '/run/test')

    start_epoch = 0
    best_prec1 = 0.0
    best_prec5 = 0.0

    # Data loading
    print('=> Preparing data..')
    loader = cifar10(args)

    # Create model
    print('=> Building model...')
    model_t = ResNet18()
    model_kd = ResNet101()

    # Load teacher model (the checkpoint stores the state dict under 'net')
    ckpt_t = torch.load(args.teacher_dir,
                        map_location=torch.device(f"cuda:{args.gpus[0]}"))
    state_dict_t = ckpt_t['net']

    model_t.load_state_dict(state_dict_t)
    model_t = model_t.to(args.gpus[0])

    for para in list(model_t.parameters())[:-2]:
        para.requires_grad = False

    model_s = ResNet18_sprase().to(args.gpus[0])
    model_dict_s = model_s.state_dict()
    model_dict_s.update(state_dict_t)
    model_s.load_state_dict(model_dict_s)

    ckpt_kd = torch.load('resnet101.t7',
                         map_location=torch.device(f"cuda:{args.gpus[0]}"))
    state_dict_kd = ckpt_kd['net']
    new_state_dict_kd = OrderedDict()
    for k, v in state_dict_kd.items():
        name = k[7:]  # strip the 'module.' prefix left by DataParallel
        new_state_dict_kd[name] = v
    model_kd.load_state_dict(new_state_dict_kd)
    model_kd = model_kd.to(args.gpus[0])

    for para in list(model_kd.parameters())[:-2]:
        para.requires_grad = False

    if len(args.gpus) != 1:
        model_s = nn.DataParallel(model_s, device_ids=args.gpus)

    model_d = Discriminator().to(args.gpus[0])

    models = [model_t, model_s, model_d, model_kd]

    optimizer_d = optim.SGD(model_d.parameters(),
                            lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)

    param_s = [
        param for name, param in model_s.named_parameters()
        if 'mask' not in name
    ]
    param_m = [
        param for name, param in model_s.named_parameters() if 'mask' in name
    ]

    optimizer_s = optim.SGD(param_s,
                            lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)
    optimizer_m = FISTA(param_m, lr=args.lr, gamma=args.sparse_lambda)

    scheduler_d = StepLR(optimizer_d, step_size=args.lr_decay_step, gamma=0.1)
    scheduler_s = StepLR(optimizer_s, step_size=args.lr_decay_step, gamma=0.1)
    scheduler_m = StepLR(optimizer_m, step_size=args.lr_decay_step, gamma=0.1)

    resume = args.resume
    if resume:
        print('=> Resuming from ckpt {}'.format(resume))
        ckpt = torch.load(resume,
                          map_location=torch.device(f"cuda:{args.gpus[0]}"))
        best_prec1 = ckpt['best_prec1']
        start_epoch = ckpt['epoch']
        model_s.load_state_dict(ckpt['state_dict_s'])
        model_d.load_state_dict(ckpt['state_dict_d'])
        optimizer_d.load_state_dict(ckpt['optimizer_d'])
        optimizer_s.load_state_dict(ckpt['optimizer_s'])
        optimizer_m.load_state_dict(ckpt['optimizer_m'])
        scheduler_d.load_state_dict(ckpt['scheduler_d'])
        scheduler_s.load_state_dict(ckpt['scheduler_s'])
        scheduler_m.load_state_dict(ckpt['scheduler_m'])
        print('=> Continue from epoch {}...'.format(ckpt['epoch']))

    optimizers = [optimizer_d, optimizer_s, optimizer_m]
    schedulers = [scheduler_d, scheduler_s, scheduler_m]

    if args.test_only:
        test_prec1, test_prec5 = test(args, loader.loader_test, model_s)
        print('=> Test Prec@1: {:.2f}'.format(test_prec1))
        return

    for epoch in range(start_epoch, args.num_epochs):
        for s in schedulers:
            s.step(epoch)

        global g_e
        g_e = epoch
        gl.set_value('epoch', g_e)

        train(args, loader.loader_train, models, optimizers, epoch,
              writer_train)
        test_prec1, test_prec5 = test(args, loader.loader_test, model_s)

        is_best = best_prec1 < test_prec1
        best_prec1 = max(test_prec1, best_prec1)
        best_prec5 = max(test_prec5, best_prec5)

        model_state_dict = model_s.module.state_dict() if len(
            args.gpus) > 1 else model_s.state_dict()

        state = {
            'state_dict_s': model_state_dict,
            'state_dict_d': model_d.state_dict(),
            'best_prec1': best_prec1,
            'best_prec5': best_prec5,
            'optimizer_d': optimizer_d.state_dict(),
            'optimizer_s': optimizer_s.state_dict(),
            'optimizer_m': optimizer_m.state_dict(),
            'scheduler_d': scheduler_d.state_dict(),
            'scheduler_s': scheduler_s.state_dict(),
            'scheduler_m': scheduler_m.state_dict(),
            'epoch': epoch + 1
        }
        if is_best:
            checkpoint.save_model(state, epoch + 1, is_best)

    print(f"=> Best @prec1: {best_prec1:.3f} @prec5: {best_prec5:.3f}")
Example #6
def main():
    args = get_args()

    # get log
    args.save = 'search-{}-{}'.format(args.save,
                                      time.strftime("%Y%m%d-%H%M%S"))
    tools.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logger = logging.getLogger('Train Search')
    logger.addHandler(fh)

    # monitor
    pymonitor = ProgressMonitor(logger)
    tbmonitor = TensorBoardMonitor(logger, args.save)
    monitors = [pymonitor, tbmonitor]

    # set random seed
    if args.seed is None:
        args.seed = random.randint(1, 10000)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    args.use_cuda = args.gpus > 0 and torch.cuda.is_available()
    args.multi_gpu = args.gpus > 1 and torch.cuda.is_available()
    args.device = torch.device('cuda:0' if args.use_cuda else 'cpu')
    if args.use_cuda:
        torch.cuda.manual_seed(args.seed)
        cudnn.enabled = True
        cudnn.benchmark = True
    setting = {k: v for k, v in args._get_kwargs()}
    logger.info(setting)
    with open(os.path.join(args.save, "args.yaml"),
              "w") as yaml_file:  # dump experiment config
        yaml.dump(args, yaml_file)

    # load pretrained model
    criterion = nn.CrossEntropyLoss()
    model = Network(C=args.init_channels,
                    num_classes=args.class_num,
                    layers=args.layers,
                    steps=args.nodes,
                    multiplier=args.nodes,
                    stem_multiplier=args.stem_multiplier,
                    group=args.group)
    model, _, _ = loadCheckpoint(args.model_path, model, args)

    if args.multi_gpu:
        logger.info('use: %d gpus', args.gpus)
        model = nn.DataParallel(model)
    model = model.to(args.device)
    criterion = criterion.to(args.device)
    logger.info("param size = %fMB", tools.count_parameters_in_MB(model))

    # get dataloader
    if args.dataset_name == "cifar10":
        train_transform, valid_transform = tools._data_transforms_cifar10(args)
        traindata = dset.CIFAR10(root=args.dataset,
                                 train=True,
                                 download=False,
                                 transform=train_transform)
        valdata = dset.CIFAR10(root=args.dataset,
                               train=False,
                               download=False,
                               transform=valid_transform)
    else:
        train_transform, valid_transform = tools._data_transforms_mnist(args)
        traindata = dset.MNIST(root=args.dataset,
                               train=True,
                               download=False,
                               transform=train_transform)
        valdata = dset.MNIST(root=args.dataset,
                             train=False,
                             download=False,
                             transform=valid_transform)

    trainLoader = torch.utils.data.DataLoader(traindata,
                                              batch_size=args.batch_size,
                                              pin_memory=True,
                                              shuffle=True,
                                              num_workers=args.workers)
    valLoader = torch.utils.data.DataLoader(valdata,
                                            batch_size=args.batch_size,
                                            pin_memory=True,
                                            num_workers=args.workers)

    # weight optimizer and struct parameters /mask optimizer
    optimizer_w = torch.optim.SGD(model.weight_parameters(),
                                  args.learning_rate,
                                  momentum=args.momentum,
                                  weight_decay=args.weight_decay)
    # scheduler_w = torch.optim.lr_scheduler.CosineAnnealingLR(
    #   optimizer_w, float(args.epochs), eta_min=args.learning_rate_min)
    optimizer_alpha = FISTA(model.arch_parameters(),
                            lr=args.arch_learning_rate,
                            gamma=args.sparse_lambda)
    # scheduler_alpha = torch.optim.lr_scheduler.CosineAnnealingLR(
    #   optimizer_alpha, float(args.epochs))
    scheduler_w = StepLR(optimizer_w, step_size=args.lr_decay_step, gamma=0.1)
    scheduler_alpha = StepLR(optimizer_alpha,
                             step_size=args.lr_decay_step,
                             gamma=0.1)
    perf_scoreboard = PerformanceScoreboard(args.num_best_scores)

    # resume
    start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume_path):
            model, extras, start_epoch = loadCheckpoint(
                args.resume_path, model, args)
            scheduler_w = extras["scheduler_w"]
            scheduler_alpha = extras["scheduler_alpha"]
            optimizer_w = extras["optimizer_w"]
            optimizer_alpha = extras["optimizer_alpha"]
            perf_scoreboard = extras["perf_scoreboard"]
        else:
            raise FileNotFoundError("No checkpoint found at '{}'".format(
                args.resume))
    for epoch in range(args.epochs):
        weight_lr = scheduler_w.get_lr()[0]
        arch_lr = scheduler_alpha.get_lr()[0]
        logging.info('epoch %d weight lr %e   arch lr %e', epoch, weight_lr,
                     arch_lr)

        t_top1, t_top5, t_loss = train(trainLoader, valLoader, model,
                                       criterion, epoch, optimizer_w,
                                       optimizer_alpha, monitors, args, logger)
        v_top1, v_top5, v_loss = validate(valLoader, model, criterion, epoch,
                                          monitors, args, logger)

        tbmonitor.writer.add_scalars('Train_vs_Validation/Loss', {
            'train': t_loss,
            'val': v_loss
        }, epoch)
        tbmonitor.writer.add_scalars('Train_vs_Validation/Top1', {
            'train': t_top1,
            'val': v_top1
        }, epoch)
        tbmonitor.writer.add_scalars('Train_vs_Validation/Top5', {
            'train': t_top5,
            'val': v_top5
        }, epoch)

        l, board = perf_scoreboard.update(v_top1, v_top5, epoch)
        for idx in range(l):
            score = board[idx]
            logger.info(
                'Scoreboard best %d ==> Epoch [%d][Top1: %.3f   Top5: %.3f]',
                idx + 1, score['epoch'], score['top1'], score['top5'])

        logger.info("normal: \n{}".format(
            model.alphas_normal.data.cpu().numpy()))
        logger.info("reduce: \n{}".format(
            model.alphas_reduce.data.cpu().numpy()))
        logger.info('Genotypev1: {}'.format(model.genotypev1()))
        logger.info('Genotypev2: {}'.format(model.genotypev2()))
        logger.info('Genotypev3: {}'.format(model.genotypev3()))
        pruned = 0
        num = 0
        for param in model.arch_parameters():
            param_array = param.detach().cpu().numpy()
            # count entries driven exactly to zero by the sparse update
            pruned += int((param_array == 0).sum())
            num += param_array.size
        logger.info("Epoch:{} Pruned {} / {}".format(epoch, pruned, num))

        is_best = perf_scoreboard.is_best(epoch)
        # save model
        if epoch % args.save_freq == 0:
            saveCheckpoint(
                epoch, args.model, model, {
                    'scheduler_w': scheduler_w,
                    "scheduler_alpha": scheduler_alpha,
                    "optimizer_w": optimizer_w,
                    'optimizer_alpha': optimizer_alpha,
                    'perf_scoreboard': perf_scoreboard
                }, is_best, os.path.join(args.save, "ckpts"))
        # update lr
        scheduler_w.step()
        scheduler_alpha.step()
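The train call in this last example receives both loaders and both optimizers; in DARTS-style search the usual pattern is to alternate an architecture (mask) step on a validation batch with a weight step on a training batch. A minimal sketch of one such iteration under that assumption; this is an outline, not the repository's train():

def search_step(model, criterion, train_batch, val_batch,
                optimizer_w, optimizer_alpha):
    x_train, y_train = train_batch
    x_val, y_val = val_batch

    # 1) architecture step on held-out data
    optimizer_alpha.zero_grad()
    criterion(model(x_val), y_val).backward()
    optimizer_alpha.step()

    # 2) weight step on training data
    optimizer_w.zero_grad()
    loss = criterion(model(x_train), y_train)
    loss.backward()
    optimizer_w.step()
    return loss.item()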