Example #1
def main():
    start_epoch = 0
    best_prec1, best_prec5 = 0.0, 0.0

    ckpt = utils.checkpoint(args)
    writer_train = SummaryWriter(args.job_dir + '/run/train')
    writer_test = SummaryWriter(args.job_dir + '/run/test')

    # Data loading
    print('=> Preparing data..')
    logging.info('=> Preparing data..')

    #loader = import_module('data.' + args.dataset).Data(args)


    traindir = os.path.join('/mnt/cephfs_new_wj/cv/ImageNet', 'ILSVRC2012_img_train')
    valdir = os.path.join('/mnt/cephfs_new_wj/cv/ImageNet', 'ILSVRC2012_img_val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    # train_dataset = datasets.ImageFolder(
    #     traindir,
    #     transforms.Compose([
    #         transforms.RandomResizedCrop(224),
    #         transforms.RandomHorizontalFlip(),
    #         transforms.ToTensor(),
    #         normalize,
    #     ]))

    # train_loader = torch.utils.data.DataLoader(
    #     train_dataset, batch_size=batch_sizes, shuffle=True,
    #     num_workers=8, pin_memory=True, sampler=None)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=256, shuffle=False,
        num_workers=8, pin_memory=True)

    traindir = os.path.join('/mnt/cephfs_new_wj/cv/ImageNet/', 'ILSVRC2012_img_train_rec')
    valdir = os.path.join('/mnt/cephfs_new_wj/cv/ImageNet/', 'ILSVRC2012_img_val_rec')


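    # getTrainValDataset / getTestDataset appear to read the pre-packed ImageNet record files (e.g. via NVIDIA DALI);
    # batch_size, num_gpu and num_workers are assumed to be module-level settings defined outside this snippet.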
    train_queue = getTrainValDataset(traindir, valdir, batch_size=batch_size, val_batch_size=batch_size,
                                     num_shards=num_gpu, workers=num_workers)
    valid_queue = getTestDataset(valdir, test_batch_size=batch_size, num_shards=num_gpu,
                                 workers=num_workers)

    #loader = cifar100(args)

    # Create model
    print('=> Building model...')
    logging.info('=> Building model...')
    criterion = nn.CrossEntropyLoss()

    # Fine tune from a checkpoint
    refine = args.refine
    assert refine is not None, 'refine is required'
    checkpoint = torch.load(refine, map_location=torch.device(f"cuda:{args.gpus[0]}"))


    if args.pruned:
        mask = checkpoint['mask']
        model = resnet_56_sparse(has_mask=mask).to(args.gpus[0])
        model.load_state_dict(checkpoint['state_dict_s'])
    else:
        model = prune_resnet(args, checkpoint['state_dict_s'])

    # model = torchvision.models.resnet18()

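    # Report the pruned model's FLOPs and parameter count with ptflops' get_model_complexity_info.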
    with torch.cuda.device(0):
        flops, params = get_model_complexity_info(model, (3, 224, 224), as_strings=True, print_per_layer_stat=True)
        print('Flops:  ' + flops)
        print('Params: ' + params)
    pruned_dir = args.pruned_dir
    checkpoint_pruned = torch.load(pruned_dir, map_location=torch.device(f"cuda:{args.gpus[0]}"))
    model = torch.nn.DataParallel(model)
    #
    # new_state_dict_pruned = OrderedDict()
    # for k, v in checkpoint_pruned.items():
    #     name = k[7:]
    #     new_state_dict_pruned[name] = v
    # model.load_state_dict(new_state_dict_pruned)

    model.load_state_dict(checkpoint_pruned['state_dict_s'])

    test_prec1, test_prec5 = test(args, valid_queue, model, criterion, writer_test)
    logging.info('Test accuracy directly after pruning: %.3f', test_prec1)
    logging.info('Model size: %.2f M parameters', get_parameters_size(model) / 1e6)

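    # NOTE: the script stops here after the quick post-prune evaluation; remove exit() to run the fine-tuning loop below.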
    exit()

    if args.test_only:
        return
    param_s = [param for name, param in model.named_parameters() if 'mask' not in name]
    #optimizer = optim.SGD(model.parameters(), lr=args.lr * 0.00001, momentum=args.momentum,weight_decay=args.weight_decay)
    optimizer = optim.SGD(param_s, lr=1e-5, momentum=args.momentum, weight_decay=args.weight_decay)
    # scheduler = StepLR(optimizer, step_size=args.lr_decay_step, gamma=0.1)  # superseded by the cosine schedule below
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.num_epochs))


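    # Optional knowledge-distillation teacher: kd_flag is assumed to be a module-level switch, and 'resnet101.t7'
    # a checkpoint saved from nn.DataParallel (hence the 'module.' prefix stripping below).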
    model_kd = None
    if kd_flag:
        model_kd = ResNet101()
        ckpt_kd = torch.load('resnet101.t7', map_location=torch.device(f"cuda:{args.gpus[0]}"))
        state_dict_kd = ckpt_kd['net']
        new_state_dict_kd = OrderedDict()
        for k, v in state_dict_kd.items():
            name = k[7:]
            new_state_dict_kd[name] = v
        # print(new_state_dict_kd)
        model_kd.load_state_dict(new_state_dict_kd)
        model_kd = model_kd.to(args.gpus[1])

    resume = args.resume
    if resume:
        print('=> Loading checkpoint {}'.format(resume))
        checkpoint = torch.load(resume, map_location=torch.device(f"cuda:{args.gpus[0]}"))
        start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict_s'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        print('=> Continue from epoch {}...'.format(start_epoch))
    #print(model.named_parameters())
    #for name, param in model.named_parameters():
        #print(name)
    for epoch in range(start_epoch, 60):
        scheduler.step()  # advance the learning-rate schedule once per epoch (previously scheduler.step(epoch))
        t1 = time.time()
        train(args, train_queue, model, criterion, optimizer, writer_train, epoch, model_kd)
        test_prec1, test_prec5 = test(args, valid_queue, model, criterion, writer_test, epoch)
        t2 = time.time()
        print(f"Epoch {epoch}: {t2 - t1:.1f}s")
        logging.info('TEST Top1: %e Top5: %e ', test_prec1, test_prec5)

        is_best = best_prec1 < test_prec1
        best_prec1 = max(test_prec1, best_prec1)
        best_prec5 = max(test_prec5, best_prec5)

        print(f"=> Best @prec1: {best_prec1:.3f} @prec5: {best_prec5:.3f}")
        logging.info('Best Top1: %e Top5: %e ', best_prec1, best_prec5)

        state = {
            'state_dict_s': model.state_dict(),
            'best_prec1': best_prec1,
            'best_prec5': best_prec5,
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'epoch': epoch + 1
        }

        ckpt.save_model(state, epoch + 1, is_best)
        train_queue.reset()
        valid_queue.reset()

    print(f"=> Best @prec1: {best_prec1:.3f} @prec5: {best_prec5:.3f}")
    logging.info('Best Top1: %e Top5: %e ', best_prec1, best_prec5)
Example #2
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR

from utils.options import args
from model.cifar10.shiftresnet import *

# Assumed project helpers (not shown in this snippet): `utils` supplies checkpoint/get_logger/print_params,
# and SummaryWriter may come from tensorboardX or torch.utils.tensorboard.
import utils.common as utils
from tensorboardX import SummaryWriter


def _make_dir(path):
    if not os.path.exists(path): os.makedirs(path)


ckpt = utils.checkpoint(args)
print_logger = utils.get_logger(os.path.join(args.job_dir, "logger.log"))
utils.print_params(vars(args), print_logger.info)
writer_train = SummaryWriter(args.job_dir + '/run/train')
writer_test = SummaryWriter(args.job_dir + '/run/test')


def main():
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    cudnn.benchmark = True

    start_epoch = args.start_epoch

    lr_decay_step = list(map(int, args.lr_decay_step.split(',')))
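    # e.g. '100,150' -> [100, 150]; presumably used as MultiStepLR milestones later in this (truncated) script.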

    # Data loading
Example #3
def main():
    checkpoint = utils.checkpoint(args)
    writer_train = SummaryWriter(args.job_dir + '/run/train')
    writer_test = SummaryWriter(args.job_dir + '/run/test')

    start_epoch = 0
    best_prec1 = 0.0
    best_prec5 = 0.0

    # Data loading
    print('=> Preparing data..')
    logging.info('=> Preparing data..')

    traindir = os.path.join('/mnt/cephfs_hl/cv/ImageNet/',
                            'ILSVRC2012_img_train_rec')
    valdir = os.path.join('/mnt/cephfs_hl/cv/ImageNet/',
                          'ILSVRC2012_img_val_rec')
    train_loader, val_loader = getTrainValDataset(traindir, valdir,
                                                  batch_sizes, 100, num_gpu,
                                                  num_workers)

    # Create model
    print('=> Building model...')
    logging.info('=> Building model...')

    model_t = ResNet50()

    # model_kd = resnet101(pretrained=False)

    #print(model_kd)
    # Load teacher model
    ckpt_t = torch.load(args.teacher_dir,
                        map_location=torch.device(f"cuda:{args.gpus[0]}"))
    # The teacher checkpoint is already a plain state dict.
    new_state_dict_t = ckpt_t

    model_t.load_state_dict(new_state_dict_t)
    model_t = model_t.to(args.gpus[0])

    for para in list(model_t.parameters())[:-2]:
        para.requires_grad = False

    model_s = ResNet50_sprase().to(args.gpus[0])
    model_dict_s = model_s.state_dict()
    model_dict_s.update(new_state_dict_t)
    model_s.load_state_dict(model_dict_s)
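    # model_s now carries the teacher's weights; its extra 'mask' parameters keep their default initialization.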

    #ckpt_kd = torch.load('resnet101-5d3b4d8f.pth', map_location=torch.device(f"cuda:{args.gpus[0]}"))
    #state_dict_kd = ckpt_kd
    #new_state_dict_kd = state_dict_kd
    #model_kd.load_state_dict(new_state_dict_kd)
    #model_kd = model_kd.to(args.gpus[0])

    #for para in list(model_kd.parameters())[:-2]:
    #para.requires_grad = False

    model_d = Discriminator().to(args.gpus[0])

    model_s = nn.DataParallel(model_s).cuda()
    model_t = nn.DataParallel(model_t).cuda()
    model_d = nn.DataParallel(model_d).cuda()

    optimizer_d = optim.SGD(model_d.parameters(),
                            lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)

    param_s = [
        param for name, param in model_s.named_parameters()
        if 'mask' not in name
    ]
    param_m = [
        param for name, param in model_s.named_parameters() if 'mask' in name
    ]

    optimizer_s = optim.SGD(param_s,
                            lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)
    optimizer_m = FISTA(param_m, lr=args.lr * 100, gamma=args.sparse_lambda)

    scheduler_d = StepLR(optimizer_d, step_size=args.lr_decay_step, gamma=0.1)
    scheduler_s = StepLR(optimizer_s, step_size=args.lr_decay_step, gamma=0.1)
    scheduler_m = StepLR(optimizer_m, step_size=args.lr_decay_step, gamma=0.1)
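    # Three optimizers: SGD for the discriminator, SGD for the student weights, and FISTA
    # (fast iterative shrinkage-thresholding) for the mask parameters; each gets its own StepLR schedule.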

    resume = args.resume
    if resume:
        print('=> Resuming from ckpt {}'.format(resume))
        ckpt = torch.load(resume,
                          map_location=torch.device(f"cuda:{args.gpus[0]}"))
        state_dict_s = ckpt['state_dict_s']
        state_dict_d = ckpt['state_dict_d']

        new_state_dict_s = OrderedDict()
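        # The checkpoint stores unwrapped weights, so re-add the 'module.' prefix expected by the DataParallel-wrapped student.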
        for k, v in state_dict_s.items():
            new_state_dict_s['module.' + k] = v

        best_prec1 = ckpt['best_prec1']
        model_s.load_state_dict(new_state_dict_s)
        model_d.load_state_dict(ckpt['state_dict_d'])
        optimizer_d.load_state_dict(ckpt['optimizer_d'])
        optimizer_s.load_state_dict(ckpt['optimizer_s'])
        optimizer_m.load_state_dict(ckpt['optimizer_m'])
        scheduler_d.load_state_dict(ckpt['scheduler_d'])
        scheduler_s.load_state_dict(ckpt['scheduler_s'])
        scheduler_m.load_state_dict(ckpt['scheduler_m'])
        start_epoch = ckpt['epoch']
        print('=> Continue from epoch {}...'.format(ckpt['epoch']))

    models = [model_t, model_s, model_d]  #, model_kd]
    optimizers = [optimizer_d, optimizer_s, optimizer_m]
    schedulers = [scheduler_d, scheduler_s, scheduler_m]

    for epoch in range(start_epoch, args.num_epochs):
        for s in schedulers:
            s.step(epoch)

        #global g_e
        #g_e = epoch
        #gl.set_value('epoch',g_e)

        train(args, train_loader, models, optimizers, epoch, writer_train)
        test_prec1, test_prec5 = test(args, val_loader, model_s)

        is_best = best_prec1 < test_prec1
        best_prec1 = max(test_prec1, best_prec1)
        best_prec5 = max(test_prec5, best_prec5)

        model_state_dict = model_s.module.state_dict() if len(
            args.gpus) > 1 else model_s.state_dict()

        state = {
            'state_dict_s': model_state_dict,
            'state_dict_d': model_d.state_dict(),
            'best_prec1': best_prec1,
            'best_prec5': best_prec5,
            'optimizer_d': optimizer_d.state_dict(),
            'optimizer_s': optimizer_s.state_dict(),
            'optimizer_m': optimizer_m.state_dict(),
            'scheduler_d': scheduler_d.state_dict(),
            'scheduler_s': scheduler_s.state_dict(),
            'scheduler_m': scheduler_m.state_dict(),
            'epoch': epoch + 1
        }
        train_loader.reset()
        val_loader.reset()
        #if is_best:
        checkpoint.save_model(state, epoch + 1, is_best)
        #checkpoint.save_model(state, 1, False)

    print(f"=> Best @prec1: {best_prec1:.3f} @prec5: {best_prec5:.3f}")
    logging.info('Best Top1: %e Top5: %e ', best_prec1, best_prec5)
Example #4
def main():
    start_epoch = 0
    best_prec1, best_prec5 = 0.0, 0.0

    ckpt = utils.checkpoint(args)
    writer_train = SummaryWriter(args.job_dir + '/run/train')
    writer_test = SummaryWriter(args.job_dir + '/run/test')

    # Data loading
    print('=> Preparing data..')
    loader = import_module('data.' + args.dataset).Data(args)

    # Create model
    print('=> Building model...')
    criterion = nn.CrossEntropyLoss()

    # Fine tune from a checkpoint
    refine = args.refine
    assert refine is not None, 'refine is required'
    checkpoint = torch.load(refine, map_location=torch.device(f"cuda:{args.gpus[0]}"))
        
    if args.pruned:
        mask = checkpoint['mask']
        pruned = sum(1 for m in mask if m == 0)  # count zeroed (pruned) mask entries
        print(f"Pruned / Total: {pruned} / {len(mask)}")
        model = resnet_56_sparse(has_mask=mask).to(args.gpus[0])
        model.load_state_dict(checkpoint['state_dict_s'])
    else:
        model = prune_resnet(args, checkpoint['state_dict_s'])

    test_prec1, test_prec5 = test(args, loader.loader_test, model, criterion, writer_test)
    print(f"Simply test after prune {test_prec1:.3f}")
    
    if args.test_only:
        return 

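    # With --keep_grad, the binary mask parameters are frozen so fine-tuning only updates the surviving weights.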
    if args.keep_grad:
        for name, weight in model.named_parameters():
            if 'mask' in name:
                weight.requires_grad = False

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer, step_size=args.lr_decay_step, gamma=0.1)

    resume = args.resume
    if resume:
        print('=> Loading checkpoint {}'.format(resume))
        checkpoint = torch.load(resume, map_location=torch.device(f"cuda:{args.gpus[0]}"))
        start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict_s'])  # checkpoints saved below store the model under 'state_dict_s'
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        print('=> Continue from epoch {}...'.format(start_epoch))

    for epoch in range(start_epoch, args.num_epochs):
        scheduler.step(epoch)

        train(args, loader.loader_train, model, criterion, optimizer, writer_train, epoch)
        test_prec1, test_prec5 = test(args, loader.loader_test, model, criterion, writer_test, epoch)

        is_best_finetune = best_prec1 < test_prec1
        best_prec1 = max(test_prec1, best_prec1)
        best_prec5 = max(test_prec5, best_prec5)

        state = {
            'state_dict_s': model.state_dict(),
            'best_prec1': best_prec1,
            'best_prec5': best_prec5,
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'epoch': epoch + 1
        }

        ckpt.save_model(state, epoch + 1, False, is_best_finetune)

    print(f"=> Best @prec1: {best_prec1:.3f} @prec5: {best_prec5:.3f}")
Example #5
def main():
    checkpoint = utils.checkpoint(args)
    writer_train = SummaryWriter(args.job_dir + '/run/train')
    writer_test = SummaryWriter(args.job_dir + '/run/test')

    start_epoch = 0
    best_prec1 = 0.0
    best_prec5 = 0.0

    # Data loading
    print('=> Preparing data..')
    loader = cifar10(args)

    # Create model
    print('=> Building model...')
    model_t = resnet_56().to(args.gpus[0])

    # Load teacher model
    ckpt_t = torch.load(args.teacher_dir,
                        map_location=torch.device(f"cuda:{args.gpus[0]}"))
    state_dict_t = ckpt_t['state_dict']
    model_t.load_state_dict(state_dict_t)
    model_t = model_t.to(args.gpus[0])

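    # Freeze every teacher parameter except the final classifier weight and bias.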
    for para in list(model_t.parameters())[:-2]:
        para.requires_grad = False

    model_s = resnet_56_sparse().to(args.gpus[0])

    model_dict_s = model_s.state_dict()
    model_dict_s.update(state_dict_t)
    model_s.load_state_dict(model_dict_s)

    if len(args.gpus) != 1:
        model_s = nn.DataParallel(model_s, device_ids=args.gpus)

    model_d = Discriminator().to(args.gpus[0])

    models = [model_t, model_s, model_d]

    optimizer_d = optim.SGD(model_d.parameters(),
                            lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)

    param_s = [
        param for name, param in model_s.named_parameters()
        if 'mask' not in name
    ]
    param_m = [
        param for name, param in model_s.named_parameters() if 'mask' in name
    ]

    optimizer_s = optim.SGD(param_s,
                            lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)
    optimizer_m = FISTA(param_m, lr=args.lr, gamma=args.sparse_lambda)
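    # FISTA applies shrinkage-style proximal updates to the mask parameters; gamma (sparse_lambda) presumably sets the sparsity penalty.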

    scheduler_d = StepLR(optimizer_d, step_size=args.lr_decay_step, gamma=0.1)
    scheduler_s = StepLR(optimizer_s, step_size=args.lr_decay_step, gamma=0.1)
    scheduler_m = StepLR(optimizer_m, step_size=args.lr_decay_step, gamma=0.1)

    resume = args.resume
    if resume:
        print('=> Resuming from ckpt {}'.format(resume))
        ckpt = torch.load(resume,
                          map_location=torch.device(f"cuda:{args.gpus[0]}"))
        best_prec1 = ckpt['best_prec1']
        start_epoch = ckpt['epoch']
        model_s.load_state_dict(ckpt['state_dict_s'])
        model_d.load_state_dict(ckpt['state_dict_d'])
        optimizer_d.load_state_dict(ckpt['optimizer_d'])
        optimizer_s.load_state_dict(ckpt['optimizer_s'])
        optimizer_m.load_state_dict(ckpt['optimizer_m'])
        scheduler_d.load_state_dict(ckpt['scheduler_d'])
        scheduler_s.load_state_dict(ckpt['scheduler_s'])
        scheduler_m.load_state_dict(ckpt['scheduler_m'])
        print('=> Continue from epoch {}...'.format(start_epoch))

    optimizers = [optimizer_d, optimizer_s, optimizer_m]
    schedulers = [scheduler_d, scheduler_s, scheduler_m]

    if args.test_only:
        test_prec1, test_prec5 = test(args, loader.loader_test, model_s)
        print('=> Test Prec@1: {:.2f}'.format(test_prec1))
        return

    for epoch in range(start_epoch, args.num_epochs):
        for s in schedulers:
            s.step(epoch)

        train(args, loader.loader_train, models, optimizers, epoch,
              writer_train)
        test_prec1, test_prec5 = test(args, loader.loader_test, model_s)

        is_best = best_prec1 < test_prec1
        best_prec1 = max(test_prec1, best_prec1)
        best_prec5 = max(test_prec5, best_prec5)

        model_state_dict = model_s.module.state_dict() if len(
            args.gpus) > 1 else model_s.state_dict()

        state = {
            'state_dict_s': model_state_dict,
            'state_dict_d': model_d.state_dict(),
            'best_prec1': best_prec1,
            'best_prec5': best_prec5,
            'optimizer_d': optimizer_d.state_dict(),
            'optimizer_s': optimizer_s.state_dict(),
            'optimizer_m': optimizer_m.state_dict(),
            'scheduler_d': scheduler_d.state_dict(),
            'scheduler_s': scheduler_s.state_dict(),
            'scheduler_m': scheduler_m.state_dict(),
            'epoch': epoch + 1
        }
        checkpoint.save_model(state, epoch + 1, is_best)

    print(f"=> Best @prec1: {best_prec1:.3f} @prec5: {best_prec5:.3f}")

    best_model = torch.load(f'{args.job_dir}/checkpoint/model_best.pt',
                            map_location=torch.device(f"cuda:{args.gpus[0]}"))

    model = prune_resnet(args, best_model['state_dict_s'])


def main():
    checkpoint = utils.checkpoint(args)
    writer_train = SummaryWriter(args.job_dir + '/run/train')
    writer_test = SummaryWriter(args.job_dir + '/run/test')

    start_epoch = 0
    best_prec1 = 0.0
    best_prec5 = 0.0

    # Data loading
    print('=> Preparing data..')
    loader = cifar10(args)

    # Create model
    print('=> Building model...')
    # model_t = resnet_56().to(args.gpus[0])
    #model_t = MobileNetV2()
    model_t = ResNet18()
    model_kd = ResNet101()

    print(model_kd)
    # Load teacher model
    ckpt_t = torch.load(args.teacher_dir,
                        map_location=torch.device(f"cuda:{args.gpus[0]}"))
    state_dict_t = ckpt_t['net']
    new_state_dict_t = state_dict_t
    # for k, v in state_dict_t.items():
    #     print(k[0:6])
    #     if k[0:6] == 'linear':
    #         temp = v[0:10]
    #         print(v[0:10].shape)
    #         new_state_dict_t[k] = temp

    #model_t.load_state_dict(new_state_dict_t)
    model_t = model_t.to(args.gpus[0])

    for para in list(model_t.parameters())[:-2]:
        para.requires_grad = False

    #model_s = SpraseMobileNetV2().to(args.gpus[0])
    model_s = ResNet18_sprase().to(args.gpus[0])
    print(model_s)
    model_dict_s = model_s.state_dict()
    model_dict_s.update(new_state_dict_t)
    model_s.load_state_dict(model_dict_s)

    ckpt_kd = torch.load('resnet101.t7',
                         map_location=torch.device(f"cuda:{args.gpus[0]}"))
    state_dict_kd = ckpt_kd['net']
    new_state_dict_kd = OrderedDict()
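    # The ResNet-101 checkpoint was saved from nn.DataParallel, so strip the 'module.' prefix before loading.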
    for k, v in state_dict_kd.items():
        name = k[7:]
        new_state_dict_kd[name] = v
    #print(new_state_dict_kd)
    model_kd.load_state_dict(new_state_dict_kd)
    model_kd = model_kd.to(args.gpus[0])

    for para in list(model_kd.parameters())[:-2]:
        para.requires_grad = False

    if len(args.gpus) != 1:
        print('=> Using DataParallel on GPUs {}'.format(args.gpus))
        model_s = nn.DataParallel(model_s, device_ids=args.gpus)

    model_d = Discriminator().to(args.gpus[0])

    models = [model_t, model_s, model_d, model_kd]

    optimizer_d = optim.SGD(model_d.parameters(),
                            lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)

    param_s = [
        param for name, param in model_s.named_parameters()
        if 'mask' not in name
    ]
    param_m = [
        param for name, param in model_s.named_parameters() if 'mask' in name
    ]

    optimizer_s = optim.SGD(param_s,
                            lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)
    optimizer_m = FISTA(param_m, lr=args.lr, gamma=args.sparse_lambda)

    scheduler_d = StepLR(optimizer_d, step_size=args.lr_decay_step, gamma=0.1)
    scheduler_s = StepLR(optimizer_s, step_size=args.lr_decay_step, gamma=0.1)
    scheduler_m = StepLR(optimizer_m, step_size=args.lr_decay_step, gamma=0.1)

    resume = args.resume
    if resume:
        print('=> Resuming from ckpt {}'.format(resume))
        ckpt = torch.load(resume,
                          map_location=torch.device(f"cuda:{args.gpus[0]}"))
        best_prec1 = ckpt['best_prec1']
        start_epoch = ckpt['epoch']  # resume training from the saved epoch
        model_s.load_state_dict(ckpt['state_dict_s'])
        model_d.load_state_dict(ckpt['state_dict_d'])
        optimizer_d.load_state_dict(ckpt['optimizer_d'])
        optimizer_s.load_state_dict(ckpt['optimizer_s'])
        optimizer_m.load_state_dict(ckpt['optimizer_m'])
        scheduler_d.load_state_dict(ckpt['scheduler_d'])
        scheduler_s.load_state_dict(ckpt['scheduler_s'])
        scheduler_m.load_state_dict(ckpt['scheduler_m'])
        print('=> Continue from epoch {}...'.format(ckpt['epoch']))

    optimizers = [optimizer_d, optimizer_s, optimizer_m]
    schedulers = [scheduler_d, scheduler_s, scheduler_m]

    if args.test_only:
        test_prec1, test_prec5 = test(args, loader.loader_test, model_s)
        print('=> Test Prec@1: {:.2f}'.format(test_prec1))
        return

    for epoch in range(start_epoch, args.num_epochs):
        for s in schedulers:
            s.step(epoch)

        global g_e
        g_e = epoch
        gl.set_value('epoch', g_e)

        #train(args, loader.loader_train, models, optimizers, epoch, writer_train)
        #print('###########################')
        test_prec1, test_prec5 = test(args, loader.loader_test, model_s)

        is_best = best_prec1 < test_prec1
        best_prec1 = max(test_prec1, best_prec1)
        best_prec5 = max(test_prec5, best_prec5)

        model_state_dict = model_s.module.state_dict() if len(
            args.gpus) > 1 else model_s.state_dict()

        state = {
            'state_dict_s': model_state_dict,
            'state_dict_d': model_d.state_dict(),
            'best_prec1': best_prec1,
            'best_prec5': best_prec5,
            'optimizer_d': optimizer_d.state_dict(),
            'optimizer_s': optimizer_s.state_dict(),
            'optimizer_m': optimizer_m.state_dict(),
            'scheduler_d': scheduler_d.state_dict(),
            'scheduler_s': scheduler_s.state_dict(),
            'scheduler_m': scheduler_m.state_dict(),
            'epoch': epoch + 1
        }
        if is_best:
            checkpoint.save_model(state, epoch + 1, is_best)

    print(f"=> Best @prec1: {best_prec1:.3f} @prec5: {best_prec5:.3f}")