Пример #1
0
 def end(self, server_ids, record_id):
     """Finalize a deploy/rollback: clean up, stop log tailing, record result.

     server_ids: iterable of server id strings, stored comma-joined.
     record_id: identifier of the DeployRecord row (name is '部署_' + id).
     Outcome is taken from self.result.exited (0 means success).
     """
     if self.localhost:
         # Remove the packaged source file, then drop the local connection.
         self.localhost.local('rm -f %s' % (self.file))
         self.localhost.close()
     # Flip the cross-module flag so the log-reading loop for this user exits.
     gl.set_value('deploy_' + str(self.webuser), True)
     succeeded = self.result.exited == 0
     defaults = {
         'record_id': record_id,
         'alias': self.alias,
         'server_ids': ','.join(server_ids),
         'target_root': self.target_root,
         'target_releases': self.target_releases,
         'prev_record': self.prev_release_version.strip(),
         'is_rollback': succeeded,
         'status': 'Succeed' if succeeded else 'Failed',
     }
     name = '部署_' + record_id
     DeployRecord.objects.filter(name=name).update(**defaults)
     Project.objects.filter(id=self.project_id).update(
         last_task_status=defaults['status'])
Пример #2
0
    async def disconnect(self, code):
        """Stop any running log-tail loop for this user, then leave the group."""
        user = self.scope['user'].username
        if hasattr(gl, '_global_dict'):
            state = gl._global_dict
            if 'deploy_' + str(user) in state:
                # Flip the flag so the busy-read deploy-log loop exits.
                gl.set_value('deploy_' + str(user), True)
            elif 'tail_' + str(user) in state:
                # Close the SSH client driving the remote tail.
                gl.get_value('tail_' + str(user)).close()

        await self.channel_layer.group_discard(user, self.channel_name)
Пример #3
0
 def local_tail(self, logfile, webuser):
     """Follow a local log file and push each line to *webuser*.

     Reads *logfile* from the beginning and streams every line via
     self.send_message. The loop ends when another thread sets the
     cross-module flag 'deploy_<webuser>' to True AND the file has been
     drained. Any error is reported to the user instead of raised.
     """
     import time

     # Cross-module global used by other requests/threads to stop this loop.
     gl._init()
     gl.set_value('deploy_' + str(webuser), False)
     try:
         with open(logfile, 'rt') as f:
             f.seek(0, 0)
             while True:
                 is_stop = gl.get_value('deploy_' + str(webuser))
                 line = f.readline()
                 if line:
                     self.send_message(webuser, line)
                 elif is_stop:
                     self.send_message(webuser, '[INFO]文件监视结束..')
                     break
                 else:
                     # No new data yet and not stopped: sleep briefly so the
                     # original 100%-CPU busy-wait becomes a cheap poll.
                     time.sleep(0.1)
     except Exception as e:
         # Send a readable message rather than the raw exception object.
         self.send_message(webuser, str(e))
Пример #4
0
 def remote_tail(self,
                 host,
                 port,
                 user,
                 passwd,
                 logfile,
                 webuser,
                 filter_text=None):
     """Stream `tail -f` of a remote log file to *webuser* over SSH.

     host/port/user/passwd: SSH connection details (password auth).
     logfile: remote path; '&&', '||' and '|' are stripped to block shell
     command chaining through the tail command.
     filter_text: optional grep pattern (same metacharacter stripping).
     The paramiko client is stored in the cross-module global store under
     'tail_<webuser>' so another request can close it to stop the tail.
     """
     try:
         self.client = paramiko.SSHClient()
         self.client.load_system_host_keys()
         self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
         self.client.connect(hostname=host,
                             port=port,
                             username=user,
                             password=passwd)
         interact = SSHClientInteraction(self.client,
                                         timeout=10,
                                         display=False)
         # Wait for a root-style shell prompt before sending the command.
         interact.expect('.*#.*')
         logfile = logfile.strip().replace('&&',
                                           '').replace('||',
                                                       '').replace('|', '')
         self.send_message(webuser, '[INFO][%s@%s]开始监控日志' % (user, host))
         gl._init()
         gl.set_value('tail_' + str(webuser), self.client)
         if filter_text:
             filter_text_re = filter_text.strip().replace('&&', '').replace(
                 '||', '').replace('|', '')
             interact.send('tail -f %s|grep --color=never %s' %
                           (logfile, filter_text_re))
         else:
             interact.send('tail -f %s' % (logfile))
         # Blocks here, forwarding each new output line to the user.
         interact.tail(
             output_callback=lambda m: self.send_message(webuser, m))
     except Exception as e:
         self.send_message(webuser, e)
     finally:
         # Always try to close the SSH session; report (not raise) failures.
         try:
             self.client.close()
         except Exception as e:
             self.send_message(webuser, e)
Пример #5
0
def train(args, loader_train, models, optimizers, epoch, writer_train):
    """One training epoch of GAN-style pruning with knowledge distillation.

    models: [teacher, sparse student, discriminator] (a fourth KD model is
    currently disabled, see commented lines).
    optimizers: [discriminator SGD, student SGD, mask optimizer (FISTA-like,
    its .step takes a `decay` flag)].

    NOTE(review): relies on module-level names not visible here — utils, nn,
    F, torch, time, logging, gl, and the globals `batch_sizes`, `num_gpu`,
    `iteration` — confirm they are defined in this module.
    """
    #losses_d = utils.AverageMeter()
    #losses_data = utils.AverageMeter()
    #losses_g = utils.AverageMeter()
    #losses_sparse = utils.AverageMeter()
    #losses_kl = utils.AverageMeter()

    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()

    model_t = models[0]
    model_s = models[1]
    model_d = models[2]
    #model_kd = models[3]

    bce_logits = nn.BCEWithLogitsLoss()

    optimizer_d = optimizers[0]
    optimizer_s = optimizers[1]
    optimizer_m = optimizers[2]

    # switch to train mode
    model_d.train()
    model_s.train()
    # NOTE(review): loader_train appears to be a DALI-style pipeline
    # (has ._size and yields per-GPU dicts) — confirm against the caller.
    num_iterations = int(loader_train._size / batch_sizes)
    #num_iterations = len(loader_train)
    print(num_iterations)
    real_label = 1
    fake_label = 0
    exact_list = ["layer3"]
    num_pruned = -1
    t0 = time.time()
    '''
    prec1 = [60]
    #prec1 = 0
    error_d = 0
    error_sparse = 0
    error_g = 0
    error_data = 0
    KD_loss = 0

    alpha_d = args.miu * ( 0.9 - epoch / args.num_epochs * 0.9 )
    sparse_lambda = args.sparse_lambda
    mask_step = args.mask_step
    lr_decay_step = args.lr_decay_step
    '''
    #for i, (inputs, targets) in enumerate(loader_train, 1):
    for i, data in enumerate(loader_train):

        # `iteration` is published through gl so model code can read the
        # current step; it is also abused as a control signal (set to 1) below.
        global iteration
        iteration = i

        tt0 = time.time()
        if i % 60 == 1:
            t0 = time.time()

        # Every 400 steps: count zeroed mask entries to track pruning progress.
        if i % 400 == 1:
            num_mask = []
            for name, weight in model_s.named_parameters():
                if 'mask' in name:
                    for ii in range(len(weight)):
                        num_mask.append(weight[ii].item())
            num_pruned = sum(m == 0 for m in num_mask)
            if num_pruned > 1100:
                iteration = 1

        #num_iters = num_iterations * epoch + i

        # Accuracy collapse guard: reset the published iteration signal.
        if i > 100 and top1.val < 30:
            iteration = 1
        #iteration = 2
        gl.set_value('iteration', iteration)

        # Concatenate the per-GPU shards produced by the pipeline into one
        # batch. NOTE(review): assumes data[j] has "data"/"label" keys.
        inputs = torch.cat([data[j]["data"] for j in range(num_gpu)], dim=0)
        targets = torch.cat([data[j]["label"] for j in range(num_gpu)],
                            dim=0).squeeze().long()

        targets = targets.cuda(non_blocking=True)
        inputs = inputs.cuda()

        #inputs = inputs.to(args.gpus[0])
        #targets = targets.to(args.gpus[0])
        features_t = model_t(inputs)
        features_s = model_s(inputs)
        #features_kd = model_kd(inputs)

        ############################
        # (1) Update
        # D network
        ###########################
        #'''
        for p in model_d.parameters():
            p.requires_grad = True

        optimizer_d.zero_grad()

        # Discriminator should label teacher features real, student fake.
        output_t = model_d(features_t.to(args.gpus[0]).detach())

        labels_real = torch.full_like(output_t,
                                      real_label,
                                      device=args.gpus[0])
        error_real = bce_logits(output_t, labels_real)

        output_s = model_d(features_s.to(args.gpus[0]).detach())

        labels_fake = torch.full_like(output_t,
                                      fake_label,
                                      device=args.gpus[0])
        error_fake = bce_logits(output_s, labels_fake)

        error_d = 0.1 * error_real + 0.1 * error_fake

        labels = torch.full_like(output_s, real_label, device=args.gpus[0])

        #error_d += bce_logits(output_s, labels)
        error_d.backward()

        #losses_d.update(error_d.item(), inputs.size(0))
        #writer_train.add_scalar(
        #'discriminator_loss', error_d.item(), num_iters)

        optimizer_d.step()
        #if i % args.print_freq == 0:#i >= 0:#
        # NOTE(review): dead branch (i is never < 0); if re-enabled it would
        # raise NameError because losses_d is commented out above.
        if i < 0:
            print('=> D_Epoch[{0}]({1}/{2}):\t'
                  'Loss_d {loss_d.val:.4f} ({loss_d.avg:.4f})\t'.format(
                      epoch, i, num_iterations, loss_d=losses_d))

        #'''
        ############################
        # (2) Update student network
        ###########################

        #'''

        # Freeze D while updating the student/generator.
        for p in model_d.parameters():
            p.requires_grad = False

        optimizer_s.zero_grad()
        optimizer_m.zero_grad()

        # Knowledge distillation: soft-target KL (temperature-scaled) plus
        # hard-label cross entropy, blended by alpha which decays over epochs.
        alpha = 0.9 - epoch / args.num_epochs * 0.9
        Temperature = 10
        KD_loss = 10 * nn.KLDivLoss()(
            F.log_softmax(features_s / Temperature, dim=1),
            F.softmax(features_t / Temperature, dim=1)) * (
                alpha * Temperature * Temperature) + F.cross_entropy(
                    features_s, targets) * (1 - alpha)
        KD_loss.backward(retain_graph=True)
        #losses_kl.update(KD_loss.item(), inputs.size(0))

        # data_loss
        alpha = 0.9 - epoch / args.num_epochs * 0.9
        #one_hot = torch.zeros(targets.shape[0], 1000).cuda()
        #one_hot = one_hot.scatter_(1, targets.reshape(targets.shape[0],1), 1).cuda()
        error_data = args.miu * (
            alpha * F.mse_loss(features_t, features_s.to(args.gpus[0]))
        )  # + (1 - alpha) * F.mse_loss(one_hot, features_s.to(args.gpus[0])))
        #losses_data.update(error_data.item(), inputs.size(0))
        error_data.backward(retain_graph=True)

        # fool discriminator
        #tt3 = time.time()
        output_s = model_d(features_s.to(args.gpus[0]))
        labels = torch.full_like(output_s, real_label, device=args.gpus[0])
        error_g = 0.1 * bce_logits(output_s, labels)
        #losses_g.update(error_g.item(), inputs.size(0))
        #writer_train.add_scalar(
        #'generator_loss', error_g.item(), num_iters)
        error_g.backward(retain_graph=True)

        optimizer_s.step()

        #'''

        # train mask
        # L1-sparsity step on the pruning masks, run every mask_step batches;
        # `decay` tells the mask optimizer to decay its learning rate.
        error_sparse = 0
        decay = (epoch % args.lr_decay_step == 0 and i == 1)
        if i % (args.mask_step) == 0:
            mask = []
            for name, param in model_s.named_parameters():
                if 'mask' in name:
                    mask.append(param.view(-1))
            mask = torch.cat(mask)
            error_sparse = 0.00001 * args.sparse_lambda * F.l1_loss(
                mask,
                torch.zeros(mask.size()).to(args.gpus[0]),
                reduction='sum')
            error_sparse.backward()
            optimizer_m.step(decay)
            #losses_sparse.update(error_sparse.item(), inputs.size(0))
            #writer_train.add_scalar(
            #'sparse_loss', error_sparse.item(), num_iters)
        prec1, prec5 = utils.accuracy(features_s.to(args.gpus[0]),
                                      targets.to(args.gpus[0]),
                                      topk=(1, 5))
        top1.update(prec1[0], inputs.size(0))
        top5.update(prec5[0], inputs.size(0))

        # Periodic console + file logging of all loss terms and accuracy.
        if i % 60 == 0:
            t1 = time.time()
            print('=> G_Epoch[{0}]({1}/{2}):\n'
                  'Loss_s {loss_sparse:.4f} \t'
                  'Loss_data {loss_data:.4f}\t'
                  'Loss_d {loss_d:.4f} \n'
                  'Loss_g {loss_g:.4f} \t'
                  'Loss_kl {loss_kl:.4f} \n'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\n'
                  'time {time:.4f}\t'
                  'pruned {np}'.format(epoch,
                                       i,
                                       num_iterations,
                                       loss_sparse=error_sparse,
                                       loss_data=error_data,
                                       loss_d=error_d,
                                       loss_g=error_g,
                                       loss_kl=KD_loss,
                                       top1=top1,
                                       top5=top5,
                                       time=t1 - t0,
                                       np=num_pruned))
            logging.info(
                'TRAIN epoch: %03d step : %03d  Top1: %e Top5: %e error_g: %e error_data: %e error_d: %e Duration: %f Pruned: %d',
                epoch, i, top1.avg, top5.avg, error_g, error_data, error_d,
                t1 - t0, num_pruned)
Пример #6
0
                         'N processes per node, which has N GPUs. This is the '
                         'fastest way to use PyTorch for either single node or '
                         'multi node data parallel training')
# Pruning-specific command-line switches (see each help string).
parser.add_argument('--channel_removed_ratio',default=0.2,type=float,help='removed ratio.')
parser.add_argument('--spatial_removed_ratio',default=0.2,type=float,help='removed ratio.')
parser.add_argument('--Is_spatial',action='store_true',help='use spatial module or not,default is channel with conv.')
parser.add_argument('--lasso',action='store_true',help='add l1 regularization to channel module.')
parser.add_argument('--l1_coe',default=1e-8,type=float,help='coe of l1 regularization.')
parser.add_argument('--show',action='store_true',help='show model architecture.')
parser.add_argument('--flops',action='store_true',help='calc flops given a pretrained model.')
parser.add_argument('--debug',action='store_true',help='debug.')
args = parser.parse_args()
# Restrict CUDA to the requested device(s) before any CUDA initialisation.
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
best_acc1 = 0
# Publish pruning ratios via the cross-module global store so model code
# can read them without importing this script.
gvar._init()
gvar.set_value('removed_ratio_c',args.channel_removed_ratio)
gvar.set_value('removed_ratio_s',args.spatial_removed_ratio)
gvar.set_value('is_spatial',args.Is_spatial) 
def main():
    """Ensure the checkpoint directory exists and (optionally) seed RNGs.

    Uses the module-level `args` namespace. Seeding turns on deterministic
    CUDNN, which slows training (see the warning text).
    """
    # exist_ok avoids the isdir/makedirs race; this also fixes the original
    # mixed tab/space indentation on the makedirs line (TabError in Python 3).
    os.makedirs(args.save_dir, exist_ok=True)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
Пример #7
0
# -*- coding: utf-8 -*-
"""
@author: heyongchao
@Created on: 2019/12/7 6:48 下午
@Desc: 
"""
import os
import sys
from utils import makeImage, globalvar as gl
from utils.imgManager import ImgManager

# Script body: watermark and regenerate app icons from two version strings.
dirname, filename = os.path.split(os.path.abspath(__file__))
gl._init()
gl.set_value('dirname', dirname)
print("当前脚本所在路径:" + dirname)
# Make the bundled site-packages importable for the helpers used below.
package_url = dirname + '/site-packages'
sys.path.append(package_url)
if len(sys.argv) != 4:
    print("参数缺失,参数列表:version1,version2,appiconset_path")
    # BUG FIX: exit non-zero — missing arguments is an error; the original
    # exited 0, hiding the failure from calling scripts/CI.
    sys.exit(1)
version1 = sys.argv[1]
version2 = sys.argv[2]
appiconset_path = sys.argv[3]
print("version1:%s,version2:%s" % (version1, version2))

print('begin make watermark appIcon ...')
makeImage.mark_image(version1, version2)
print('begin create appIcon ...')
ImgManager(appiconset_path).handle_icon_images()
#sys.exit(0)
Пример #8
0
    def post(self, request, format=None):
        """Dispatch deploy-related actions keyed on request.data['excu'].

        Supported actions: init, deploy, rollback, deploymsg, readlog,
        app_start, app_stop, tail_start, tail_stop. Each branch returns an
        XopsResponse (or a FileResponse for readlog). User-facing message
        strings are intentionally left in Chinese.
        """
        if request.data['excu'] == 'init':
            # Project initialisation: clone/prepare the repo for this project.
            id = request.data['id']
            result = self.repo_init(id)
            if result.exited == 0:
                Project.objects.filter(id=id).update(status='Succeed')
                info_logger.info('初始化项目:' + str(id) + ',执行成功!')
                http_status = OK
                msg = '初始化成功!'
            else:
                error_logger.error('初始化项目:%s 执行失败! 错误信息:%s' %
                                   (str(id), result.stderr))
                http_status = BAD
                msg = '初始化项目:%s 执行失败! 错误信息:%s' % (str(id), result.stderr)

            return XopsResponse(msg, status=http_status)

        elif request.data['excu'] == 'deploy':
            # Deployment: record is created as Failed up front and flipped to
            # Succeed by the deploy job if it completes.
            id = request.data['id']
            webuser = request.user.username
            alias = request.data['alias']
            self.start_time = time.strftime("%Y%m%d%H%M%S", time.localtime())
            record_id = str(alias) + '_' + str(self.start_time)
            name = '部署_' + record_id
            DeployRecord.objects.create(name=name,
                                        alias=alias,
                                        status='Failed',
                                        project_id=int(id))
            Project.objects.filter(id=id).update(last_task_status='Failed')
            local_log_path = self._path.rstrip('/') + '/' + str(
                id) + '_' + str(request.data['alias']) + '/logs'
            log = local_log_path + '/' + record_id + '.log'
            version = request.data['version'].strip()
            serverid = request.data['server_ids']
            deploy = DeployExcu(webuser, record_id, id)
            deploy.start(log, version, serverid, record_id, webuser,
                         self.start_time)
            return XopsResponse(record_id)

        elif request.data['excu'] == 'rollback':
            # Rollback to a previous release.
            id = request.data['id']
            project_id = request.data['project_id']
            alias = request.data['alias']
            self.start_time = time.strftime("%Y%m%d%H%M%S", time.localtime())
            record_id = str(alias) + '_' + str(self.start_time)
            log = self._path.rstrip('/') + '/' + str(project_id) + '_' + str(
                alias) + '/logs/' + record_id + '.log'
            self.do_rollback(id, log, record_id)
            return XopsResponse(record_id)

        elif request.data['excu'] == 'deploymsg':
            # Stream deploy console messages: scenario 0 starts a local tail
            # of the deploy log for this user.
            try:
                id = request.data['id']
                alias = request.data['alias']
                record = request.data['record']
                scenario = int(request.data['scenario'])
                logfile = self._path.rstrip('/') + '/' + str(id) + '_' + str(
                    alias) + '/logs/' + record + '.log'
                webuser = request.user.username
                print(webuser)
                msg = Tailf()
                if scenario == 0:
                    gl._init()
                    gl.set_value('deploy_' + str(webuser), False)
                    msg.local_tailf(logfile, webuser)
                http_status = OK
                request_status = '执行成功!'
            except Exception as e:
                http_status = BAD
                request_status = '执行错误:日志文件可能不存在!'
                print(e)
            return XopsResponse(request_status, status=http_status)

        elif request.data['excu'] == 'readlog' and request.data[
                'scenario'] == 1:
            # Read a finished deploy log and return it as plain text.
            try:
                id = request.data['id']
                alias = request.data['alias']
                record = request.data['record']
                logfile = self._path.rstrip('/') + '/' + str(id) + '_' + str(
                    alias) + '/logs/' + record + '.log'
                response = FileResponse(open(logfile, 'rb'))
                response['Content-Type'] = 'text/plain'
                return response
            except Exception:
                http_status = BAD
                request_status = '执行错误:文件不存在!'
            return XopsResponse(request_status, status=http_status)

        elif request.data['excu'] == 'app_start':
            # Start the application on the target host. '&&'/'||' are
            # stripped from the command to block shell chaining.
            try:
                app_start = request.data['app_start']
                host = request.data['host']
                webuser = request.user.username
                auth_info, auth_key = auth_init(host)
                connect = Shell(auth_info,
                                connect_timeout=5,
                                connect_kwargs=auth_key)
                app_start = app_start.strip().replace('&&',
                                                      '').replace('||', '')
                connect.run(app_start, ws=True, webuser=webuser)
                connect.close()
                http_status = OK
                request_status = '执行成功!'
            except Exception as e:
                http_status = BAD
                request_status = '执行错误:' + str(e)
            return XopsResponse(request_status, status=http_status)

        elif request.data['excu'] == 'app_stop':
            # Stop the application on the target host (same sanitising).
            try:
                app_stop = request.data['app_stop']
                host = request.data['host']
                webuser = request.user.username
                auth_info, auth_key = auth_init(host)
                connect = Shell(auth_info,
                                connect_timeout=5,
                                connect_kwargs=auth_key)
                app_stop = app_stop.strip().replace('&&', '').replace('||', '')
                connect.run(app_stop, ws=True, webuser=webuser)
                connect.close()
                http_status = OK
                request_status = '执行成功!'
            except Exception as e:
                http_status = BAD
                request_status = '执行错误:' + str(e)
            return XopsResponse(request_status, status=http_status)

        elif request.data['excu'] == 'tail_start':
            # Start remote log monitoring: resolve connection credentials
            # from the device tables, then tail the file over SSH.
            try:
                filter_text = str(request.data['filter'])
                app_log_file = request.data['app_log_file']
                host = request.data['host']
                webuser = request.user.username
                device_info = DeviceInfo.objects.filter(id=int(host)).values()
                host = device_info[0]['hostname']
                auth_type = device_info[0]['auth_type']
                connect_info = ConnectionInfo.objects.filter(
                    hostname=host, auth_type=auth_type).values()
                user = connect_info[0]['username']
                passwd = connect_info[0]['password']
                port = connect_info[0]['port']
                tail = Tailf()
                tail.remote_tail(host,
                                 port,
                                 user,
                                 passwd,
                                 app_log_file,
                                 webuser,
                                 filter_text=filter_text)
                http_status = OK
                request_status = '执行成功!'
            except Exception as e:
                http_status = BAD
                request_status = str(e)
            return XopsResponse(request_status, status=http_status)

        elif request.data['excu'] == 'tail_stop':
            # Stop remote log monitoring by closing the stored SSH client.
            try:
                webuser = request.user.username
                if hasattr(gl, '_global_dict'):
                    tail_key = 'tail_' + str(webuser)
                    if tail_key in gl._global_dict.keys():
                        client = gl.get_value('tail_' + str(webuser))
                        client.close()
                http_status = OK
                request_status = '执行成功!'
            except Exception as e:
                http_status = BAD
                request_status = str(e)
            return XopsResponse(request_status, status=http_status)
def main():
    """Build teacher/student/discriminator, optionally resume, then evaluate
    the sparse student each epoch and checkpoint the best result.

    NOTE(review): relies on module-level names not visible here — args,
    utils, cifar10, ResNet18, ResNet101, ResNet18_sprase, Discriminator,
    FISTA, test, gl — confirm they are defined/imported in this module.
    """
    checkpoint = utils.checkpoint(args)
    writer_train = SummaryWriter(args.job_dir + '/run/train')
    writer_test = SummaryWriter(args.job_dir + '/run/test')

    start_epoch = 0
    best_prec1 = 0.0
    best_prec5 = 0.0

    # Data loading
    print('=> Preparing data..')
    loader = cifar10(args)

    # Create model
    print('=> Building model...')
    model_t = ResNet18()
    model_kd = ResNet101()
    print(model_kd)

    # Load teacher weights (checkpoint stores the state dict under 'net').
    ckpt_t = torch.load(args.teacher_dir,
                        map_location=torch.device(f"cuda:{args.gpus[0]}"))
    state_dict_t = ckpt_t['net']
    new_state_dict_t = state_dict_t  # removed dead OrderedDict() assignment

    #model_t.load_state_dict(new_state_dict_t)
    model_t = model_t.to(args.gpus[0])

    # Freeze everything but the last two parameter tensors of the teacher.
    for para in list(model_t.parameters())[:-2]:
        para.requires_grad = False

    model_s = ResNet18_sprase().to(args.gpus[0])
    print(model_s)
    # Initialise the student with the teacher's weights where keys match.
    model_dict_s = model_s.state_dict()
    model_dict_s.update(new_state_dict_t)
    model_s.load_state_dict(model_dict_s)

    # Load the KD model, stripping the 'module.' prefix that DataParallel
    # added when the checkpoint was saved.
    ckpt_kd = torch.load('resnet101.t7',
                         map_location=torch.device(f"cuda:{args.gpus[0]}"))
    state_dict_kd = ckpt_kd['net']
    new_state_dict_kd = OrderedDict()
    for k, v in state_dict_kd.items():
        name = k[7:]
        new_state_dict_kd[name] = v
    model_kd.load_state_dict(new_state_dict_kd)
    model_kd = model_kd.to(args.gpus[0])

    for para in list(model_kd.parameters())[:-2]:
        para.requires_grad = False

    if len(args.gpus) != 1:
        # BUG FIX: the original passed args.gpus[0, 1], which indexes a list
        # with a tuple and raises TypeError; pass the device-id list itself.
        model_s = nn.DataParallel(model_s, device_ids=args.gpus)

    model_d = Discriminator().to(args.gpus[0])

    models = [model_t, model_s, model_d, model_kd]

    optimizer_d = optim.SGD(model_d.parameters(),
                            lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)

    # Student weights and pruning masks are optimised separately.
    param_s = [
        param for name, param in model_s.named_parameters()
        if 'mask' not in name
    ]
    param_m = [
        param for name, param in model_s.named_parameters() if 'mask' in name
    ]

    optimizer_s = optim.SGD(param_s,
                            lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)
    optimizer_m = FISTA(param_m, lr=args.lr, gamma=args.sparse_lambda)

    scheduler_d = StepLR(optimizer_d, step_size=args.lr_decay_step, gamma=0.1)
    scheduler_s = StepLR(optimizer_s, step_size=args.lr_decay_step, gamma=0.1)
    scheduler_m = StepLR(optimizer_m, step_size=args.lr_decay_step, gamma=0.1)

    resume = args.resume
    if resume:
        print('=> Resuming from ckpt {}'.format(resume))
        ckpt = torch.load(resume,
                          map_location=torch.device(f"cuda:{args.gpus[0]}"))
        best_prec1 = ckpt['best_prec1']
        model_s.load_state_dict(ckpt['state_dict_s'])
        model_d.load_state_dict(ckpt['state_dict_d'])
        optimizer_d.load_state_dict(ckpt['optimizer_d'])
        optimizer_s.load_state_dict(ckpt['optimizer_s'])
        optimizer_m.load_state_dict(ckpt['optimizer_m'])
        scheduler_d.load_state_dict(ckpt['scheduler_d'])
        scheduler_s.load_state_dict(ckpt['scheduler_s'])
        scheduler_m.load_state_dict(ckpt['scheduler_m'])
        # BUG FIX: actually resume from the saved epoch instead of 0,
        # matching the message printed below.
        start_epoch = ckpt['epoch']
        print('=> Continue from epoch {}...'.format(ckpt['epoch']))

    optimizers = [optimizer_d, optimizer_s, optimizer_m]
    schedulers = [scheduler_d, scheduler_s, scheduler_m]

    if args.test_only:
        test_prec1, test_prec5 = test(args, loader.loader_test, model_s)
        print('=> Test Prec@1: {:.2f}'.format(test_prec1))
        return

    for epoch in range(start_epoch, args.num_epochs):
        for s in schedulers:
            s.step(epoch)

        # Publish the current epoch through the cross-module global store.
        global g_e
        g_e = epoch
        gl.set_value('epoch', g_e)

        #train(args, loader.loader_train, models, optimizers, epoch, writer_train)
        test_prec1, test_prec5 = test(args, loader.loader_test, model_s)

        is_best = best_prec1 < test_prec1
        best_prec1 = max(test_prec1, best_prec1)
        best_prec5 = max(test_prec5, best_prec5)

        # Unwrap DataParallel before saving so keys carry no 'module.' prefix.
        model_state_dict = model_s.module.state_dict() if len(
            args.gpus) > 1 else model_s.state_dict()

        state = {
            'state_dict_s': model_state_dict,
            'state_dict_d': model_d.state_dict(),
            'best_prec1': best_prec1,
            'best_prec5': best_prec5,
            'optimizer_d': optimizer_d.state_dict(),
            'optimizer_s': optimizer_s.state_dict(),
            'optimizer_m': optimizer_m.state_dict(),
            'scheduler_d': scheduler_d.state_dict(),
            'scheduler_s': scheduler_s.state_dict(),
            'scheduler_m': scheduler_m.state_dict(),
            'epoch': epoch + 1
        }
        if is_best:
            checkpoint.save_model(state, epoch + 1, is_best)

    print(f"=> Best @prec1: {best_prec1:.3f} @prec5: {best_prec5:.3f}")
def train(args, loader_train, models, optimizers, epoch, writer_train):
    """One epoch of adversarial pruning training (CIFAR variant).

    models: [teacher, sparse student, discriminator, KD model].
    optimizers: [discriminator SGD, student SGD, mask optimizer whose .step
    takes a `decay` flag].
    Losses are logged to TensorBoard via writer_train each step.

    NOTE(review): relies on module-level names not visible here — utils, nn,
    F, torch, gl, and the global `iteration`.
    """
    losses_d = utils.AverageMeter()
    losses_data = utils.AverageMeter()
    losses_g = utils.AverageMeter()
    losses_sparse = utils.AverageMeter()
    losses_kl = utils.AverageMeter()

    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()

    model_t = models[0]
    model_s = models[1]
    model_d = models[2]
    model_kd = models[3]

    bce_logits = nn.BCEWithLogitsLoss()

    optimizer_d = optimizers[0]
    optimizer_s = optimizers[1]
    optimizer_m = optimizers[2]

    # switch to train mode
    model_d.train()
    model_s.train()

    num_iterations = len(loader_train)

    real_label = 1
    fake_label = 0
    exact_list = ["layer3"]

    for i, (inputs, targets) in enumerate(loader_train, 1):
        num_iters = num_iterations * epoch + i
        #print(i,'i')
        # Publish the current step through the cross-module global store.
        global iteration
        iteration = i
        gl.set_value('iteration', iteration)
        inputs = inputs.to(args.gpus[0])
        targets = targets.to(args.gpus[0])

        features_t = model_t(inputs)
        features_kd = model_kd(inputs)
        features_s = model_s(inputs)
        # FM_exactor_t = FeatureExtractor(model_t, exact_list).cuda()
        # FM_t = FM_exactor_t(inputs)
        #
        # FM_exactor_s = FeatureExtractor(model_s, exact_list).cuda()
        # FM_s = FM_exactor_s(inputs)
        # print(FM_t[0].shape,'@@@@@@@@@@@')
        # print(FM_s[0].shape, '@@@@@@@@@@@')
        ############################
        # (1) Update
        # D network
        ###########################

        for p in model_d.parameters():
            p.requires_grad = True

        optimizer_d.zero_grad()

        # Discriminator should label teacher features real, student fake.
        output_t = model_d(features_t.detach())

        # output_t = model_d(FM_t[0].detach())

        labels_real = torch.full_like(output_t,
                                      real_label,
                                      device=args.gpus[0])
        error_real = bce_logits(output_t, labels_real)

        output_s = model_d(features_s.to(args.gpus[0]).detach())
        # output_s = model_d(FM_s[0].detach())

        labels_fake = torch.full_like(output_t,
                                      fake_label,
                                      device=args.gpus[0])
        error_fake = bce_logits(output_s, labels_fake)

        error_d = error_real + error_fake

        labels = torch.full_like(output_s, real_label, device=args.gpus[0])

        # NOTE(review): this third term penalises D for calling student
        # features fake as well — confirm it is intentional.
        error_d += bce_logits(output_s, labels)
        error_d.backward()
        losses_d.update(error_d.item(), inputs.size(0))
        writer_train.add_scalar('discriminator_loss', error_d.item(),
                                num_iters)

        optimizer_d.step()

        if i % args.print_freq == 0:
            print('=> D_Epoch[{0}]({1}/{2}):\t'
                  'Loss_d {loss_d.val:.4f} ({loss_d.avg:.4f})\t'.format(
                      epoch, i, num_iterations, loss_d=losses_d))

        ############################
        # (2) Update student network
        ###########################

        # Freeze D while updating the student/generator.
        for p in model_d.parameters():
            p.requires_grad = False

        optimizer_s.zero_grad()
        optimizer_m.zero_grad()

        # reg = 1e-6
        # l1_loss = Variable(torch.FloatTensor(1), requires_grad=True).to(args.gpus[0])
        # for name, param in model_s.named_parameters():
        #     if 'bias' not in name:
        #         l1_loss = l1_loss+(reg * torch.sum(torch.abs(param)))

        # print(l1_loss, '!!!!!!')

        # print(features_t.shape,'features_t')
        # print(features_t.shape,'features_s')
        # print(targets.shape,'targets')

        # error_data_kl = nn.functional.kl_div(features_s, features_t.float())
        # KL_loss = nn.KLDivLoss(reduction='sum')
        # log_fea_s = F.log_softmax(features_s, 1)
        # fea_t = F.softmax(features_t, 1)
        # # fea_t = nn.functional.log_softmax(features_t)
        # # fea_s = nn.functional.log_softmax(features_s)
        # error_data_kl = KL_loss(log_fea_s, fea_t.float())/inputs.size(0)
        # loss_fun = nn.CrossEntropyLoss()
        # error_data_kl = loss_fun(features_s, features_t)
        # error_data_kl.backward(retain_graph=True)
        # Knowledge distillation: temperature-scaled KL against the KD model
        # plus hard-label cross entropy, blended by fixed alpha.
        alpha = 0.99
        Temperature = 30
        KD_loss = nn.KLDivLoss()(
            F.log_softmax(features_s / Temperature, dim=1),
            F.softmax(features_kd / Temperature, dim=1)) * (
                alpha * Temperature * Temperature) + F.cross_entropy(
                    features_s, targets) * (1 - alpha)
        KD_loss.backward(retain_graph=True)
        # print(error_data_kl,'error_data_kl')
        # print(error_data,'error_data')
        # error_data += error_data_kl
        losses_kl.update(KD_loss.item(), inputs.size(0))

        # Feature-matching loss between teacher and student outputs.
        error_data = args.miu * F.mse_loss(features_t,
                                           features_s.to(args.gpus[0]))  ##
        losses_data.update(error_data.item(), inputs.size(0))
        writer_train.add_scalar('data_loss', error_data.item(), num_iters)
        error_data.backward(retain_graph=True)

        # fool discriminator
        output_s = model_d(features_s.to(args.gpus[0]))
        # output_s = model_d(FM_s[0])

        labels = torch.full_like(output_s, real_label, device=args.gpus[0])
        error_g = bce_logits(output_s, labels)
        losses_g.update(error_g.item(), inputs.size(0))
        writer_train.add_scalar('generator_loss', error_g.item(), num_iters)
        error_g.backward(retain_graph=True)

        # train mask
        # L1-sparsity penalty on all pruning-mask parameters.
        mask = []
        for name, param in model_s.named_parameters():
            if 'mask' in name:
                mask.append(param.view(-1))
        mask = torch.cat(mask)
        error_sparse = 0.01 * args.sparse_lambda * F.l1_loss(
            mask, torch.zeros(mask.size()).to(args.gpus[0]), reduction='sum')
        error_sparse.backward()

        losses_sparse.update(error_sparse.item(), inputs.size(0))
        writer_train.add_scalar('sparse_loss', error_sparse.item(), num_iters)

        optimizer_s.step()

        # Mask optimizer runs every mask_step batches; `decay` asks it to
        # decay its learning rate at epoch boundaries.
        decay = (epoch % args.lr_decay_step == 0 and i == 1)
        if i % args.mask_step == 0:
            optimizer_m.step(decay)

        prec1, prec5 = utils.accuracy(features_s.to(args.gpus[0]),
                                      targets.to(args.gpus[0]),
                                      topk=(1, 5))
        top1.update(prec1[0], inputs.size(0))
        top5.update(prec5[0], inputs.size(0))

        if i % args.print_freq == 0:
            print('=> G_Epoch[{0}]({1}/{2}):\t'
                  'Loss_sparse {loss_sparse.val:.4f} ({loss_sparse.avg:.4f})\t'
                  'Loss_data {loss_data.val:.4f} ({loss_data.avg:.4f})\t'
                  'Loss_d {loss_d.val:.4f} ({loss_d.avg:.4f})\t'
                  'Loss_g {loss_g.val:.4f} ({loss_g.avg:.4f})\t'
                  'Loss_kl {loss_kl.val:.4f} ({loss_kl.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch,
                      i,
                      num_iterations,
                      loss_sparse=losses_sparse,
                      loss_data=losses_data,
                      loss_g=losses_g,
                      loss_d=losses_d,
                      loss_kl=losses_kl,
                      top1=top1,
                      top5=top5))