Exemple #1
0
import tensorflow.compat.v1 as tf
import lenet
import utils.read_mnist as read_mnist
import time
import os
import numpy as np


# Load Fashion-MNIST with one-hot labels (`standard=True` presumably
# standardizes pixel values — confirm in utils.read_mnist), then shuffle the
# training split.
train_x, train_y, test_x, test_y, index2class = read_mnist.read_fashion_mnist(one_hot=True, standard=True)
# A single shared permutation keeps images and labels aligned.
index_train = np.random.permutation(train_x.shape[0])
train_x, train_y = train_x[index_train], train_y[index_train]

# NOTE(review): lr_rate is never passed to the model below, which is built with
# lr_rate=0.001 — confirm which learning rate is intended.
lr_rate = 0.01
batch_size = 64
sess = tf.Session()
# The model graph is built before the initializer runs so that its variables
# are initialized by it; keep this ordering.
model = lenet.LeNet(lr_rate=0.001, regular=0.0005, train=True)
sess.run(tf.global_variables_initializer())
# Start every run with an empty TensorBoard log directory.
# NOTE(review): os.remove fails on sub-directories; only plain files are
# expected in this directory.
tensorboard_dir = r"fashiontensorboardlog/"
if os.path.exists(tensorboard_dir):
    for file in os.listdir(tensorboard_dir):
        path_file = os.path.join(tensorboard_dir, file)
        os.remove(path_file)
file_writer = tf.summary.FileWriter(tensorboard_dir, sess.graph)
# Scalar summaries must be registered before merge_all() to end up in `merge`.
# NOTE(review): tag names containing spaces may be rewritten by TensorFlow —
# check the tags actually shown in TensorBoard.
tf.summary.scalar("loss of test data", model.loss)
tf.summary.scalar("accuracy on test data", model.accuracy)
tf.summary.scalar('learning rate ', tf.reduce_mean(model.lr_rate))
merge = tf.summary.merge_all()
# Debug output: list the trainable variables of the built graph.
print(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))
epochs = 100
saver = tf.train.Saver()
for epoch in range(epochs):
Exemple #2
0

# In[ ]:

# Build the classifier ("actor") for the selected dataset and architecture.
# NOTE(review): every entry of each dict below is constructed eagerly even
# though only args.actor is used.
if args.dataset in ['MNIST', 'CIFAR10']:
    n_channels = {
        'MNIST': 1,
        'CIFAR10': 3,
    }[args.dataset]
    size = {
        'MNIST': 28,
        'CIFAR10': 32,
    }[args.dataset]
    actor = {
        'linear': nn.Linear(n_channels * size**2, n_classes),
        # Pass the dataset's channel count instead of a hard-coded 3, which
        # only matched CIFAR10 (MNIST is single-channel per the table above).
        'lenet': lenet.LeNet(n_channels, n_classes, size),
        # NOTE(review): ResNet receives no channel argument; confirm it
        # accepts single-channel MNIST input.
        'resnet': resnet.ResNet(depth=18, n_classes=n_classes),
    }[args.actor]
elif args.dataset in ['covtype']:
    n_features = train_x.size(1)
    actor = {
        'linear': nn.Linear(n_features, n_classes),
        'mlp': mlp.MLP([n_features, 60, 60, 80, n_classes], th.tanh)
    }[args.actor]

# Per-class loss weights. A non-zero args.w selects the two-class weighting
# [1 - w, w]; otherwise every class gets the same weight 1 / n_classes.
if args.w and n_classes != 2:
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    raise ValueError(
        'args.w requires a binary problem (n_classes == 2), got {}'.format(n_classes))
if args.w:
    w = th.tensor([1 - args.w, args.w])
else:
    # torch.full takes its size as a sequence, so wrap n_classes in a tuple.
    w = th.full((n_classes,), 1.0 / n_classes)
cross_entropy = loss.CrossEntropyLoss(w)
Exemple #3
0
"""
Created on 2017.6.11

@author: tfygg
"""

import lenet
import lenetseq
import utils

# Net: build the plain LeNet and dump its structure and parameters.
net = lenet.LeNet()
print(net)

for index, param in enumerate(net.parameters()):
    print(list(param.data))
    print(type(param.data), param.size())
    # print() function call; the original Python 2 print statement is a
    # SyntaxError under Python 3, which the rest of this script targets.
    print(index, "-->", param)

print(net.state_dict())
print(net.state_dict().keys())

# Iterate key/value pairs directly instead of re-building state_dict() for
# every key lookup.
for key, value in net.state_dict().items():
    print(key, 'corresponds to', list(value))

# NetSeq: the Sequential-based variant, initialized via the project helper.
netSeq = lenetseq.LeNetSeq()
print(netSeq)

utils.initNetParams(netSeq)
def main(args):
    """Train and evaluate an MNIST model with Horovod, optionally with K-FAC.

    Builds the MNIST data pipeline, the model (``fcn5net`` or ``lenet``), an
    SGD optimizer wrapped in ``hvd.DistributedOptimizer`` and, when
    ``args.kfac_update_freq > 0``, a K-FAC preconditioner plus its parameter
    scheduler. Then runs ``args.epochs`` rounds of training and evaluation.

    Args:
        args: parsed command-line namespace. Mutated in place: ``verbose``,
            ``log_dir`` and ``base_lr`` (scaled by ``hvd.size()``) are
            overwritten.

    Raises:
        ValueError: if ``args.model`` is neither ``fcn5net`` nor ``lenet``.
    """
    logfilename = 'convergence_mnist_{}_kfac{}_gpu{}_bs{}_{}_lr{}_sr{}_wp{}.log'.format(
        args.model, args.kfac_update_freq, hvd.size(), args.batch_size,
        args.kfac_name, args.base_lr, args.sparse_ratio, args.warmup_epochs)
    if hvd.rank() == 0:
        wandb.init(project='kfac',
                   entity='hkust-distributedml',
                   name=logfilename,
                   config=args)

    logfile = './logs/' + logfilename
    hdlr = logging.FileHandler(logfile)
    hdlr.setFormatter(formatter)
    logger.addHandler(hdlr)
    logger.info(args)

    torch.manual_seed(args.seed)
    # Only rank 0 logs.
    verbose = hvd.rank() == 0
    args.verbose = 1 if verbose else 0

    if args.cuda:
        torch.cuda.set_device(hvd.local_rank())
        torch.cuda.manual_seed(args.seed)

    torch.backends.cudnn.benchmark = True

    args.log_dir = os.path.join(
        args.log_dir, "mnist_{}_kfac{}_gpu_{}_{}".format(
            args.model, args.kfac_update_freq, hvd.size(),
            datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')))
    # TensorBoard logging is disabled; train()/test() skip it when this is None.
    log_writer = None

    # Horovod: limit # of CPU threads to be used per worker.
    torch.set_num_threads(1)

    kwargs = {'num_workers': 8, 'pin_memory': True} if args.cuda else {}

    transform_train = []
    if args.model.lower() == 'lenet':
        # LeNet gets 32x32 inputs (presumably its expected input size).
        transform_train.append(transforms.Resize(32))

    transform_train.extend(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])
    transform_train = transforms.Compose(transform_train)
    transform_test = transform_train

    # Only local rank 0 downloads. The other ranks block on the "barrier"
    # allreduce until rank 0 joins it after the download, so no rank reads
    # half-written files.
    download = hvd.local_rank() == 0
    if not download:
        hvd.allreduce(torch.tensor(1), name="barrier")

    # Both splits live under args.dir (the original used `self.dir`, which
    # does not exist inside this function and raised NameError).
    train_dataset = torchvision.datasets.MNIST(root=args.dir,
                                               train=True,
                                               download=download,
                                               transform=transform_train)
    test_dataset = torchvision.datasets.MNIST(root=args.dir,
                                              train=False,
                                              download=download,
                                              transform=transform_test)

    if download:
        hvd.allreduce(torch.tensor(1), name="barrier")

    # Horovod: use DistributedSampler to partition the training data.
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
    # Each loader batch holds batches_per_allreduce micro-batches; train()
    # splits it back into args.batch_size chunks for gradient accumulation.
    train_loader = MultiEpochsDataLoader(train_dataset,
                                         batch_size=args.batch_size *
                                         args.batches_per_allreduce,
                                         sampler=train_sampler,
                                         **kwargs)

    # Horovod: use DistributedSampler to partition the test data.
    test_sampler = torch.utils.data.distributed.DistributedSampler(
        test_dataset, num_replicas=hvd.size(), rank=hvd.rank())
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.test_batch_size,
                                              sampler=test_sampler,
                                              **kwargs)

    model_name = args.model.lower()
    if model_name == "fcn5net":
        model = fcn.FCN5Net()
    elif model_name == "lenet":
        model = lenet.LeNet()
    else:
        # Previously fell through and failed later with NameError on `model`.
        raise ValueError("Unsupported model: {}".format(args.model))

    if args.cuda:
        model.cuda()

    criterion = nn.CrossEntropyLoss()
    # Scale the base learning rate linearly with the number of workers.
    args.base_lr = args.base_lr * hvd.size()
    use_kfac = args.kfac_update_freq > 0

    optimizer = optim.SGD(model.parameters(),
                          lr=args.base_lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    if use_kfac:
        KFAC = kfac.get_kfac_module(args.kfac_name)
        preconditioner = KFAC(
            model,
            lr=args.base_lr,
            factor_decay=args.stat_decay,
            damping=args.damping,
            kl_clip=args.kl_clip,
            fac_update_freq=args.kfac_cov_update_freq,
            kfac_update_freq=args.kfac_update_freq,
            diag_blocks=args.diag_blocks,
            diag_warmup=args.diag_warmup,
            distribute_layer_factors=args.distribute_layer_factors,
            sparse_ratio=args.sparse_ratio)
        kfac_param_scheduler = kfac.KFACParamScheduler(
            preconditioner,
            damping_alpha=args.damping_alpha,
            damping_schedule=args.damping_schedule,
            update_freq_alpha=args.kfac_update_freq_alpha,
            update_freq_schedule=args.kfac_update_freq_schedule)

    # The optimizer is always wrapped, with or without K-FAC: train() uses the
    # DistributedOptimizer API (synchronize()/skip_synchronize()) to finish the
    # gradient allreduce before the optional preconditioner step.
    compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none
    optimizer = hvd.DistributedOptimizer(
        optimizer,
        named_parameters=model.named_parameters(),
        compression=compression,
        op=hvd.Average,
        backward_passes_per_step=args.batches_per_allreduce)

    hvd.broadcast_optimizer_state(optimizer, root_rank=0)
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)

    lrs = create_lr_schedule(hvd.size(), args.warmup_epochs, args.lr_decay)
    lr_scheduler = [LambdaLR(optimizer, lrs)]
    if use_kfac:
        lr_scheduler.append(LambdaLR(preconditioner, lrs))

    def train(epoch):
        """Run one training epoch with gradient accumulation over micro-batches."""
        model.train()
        train_sampler.set_epoch(epoch)
        train_loss = Metric('train_loss')
        train_accuracy = Metric('train_accuracy')

        if STEP_FIRST:
            for scheduler in lr_scheduler:
                scheduler.step()
            if use_kfac:
                kfac_param_scheduler.step(epoch)

        display = 20
        avg_time = 0.0
        io_time = 0.0
        for batch_idx, (data, target) in enumerate(train_loader):
            stime = time.time()
            if args.cuda:
                data, target = data.cuda(non_blocking=True), target.cuda(
                    non_blocking=True)
            io_time += time.time() - stime
            optimizer.zero_grad()

            # Split the loader batch into args.batch_size micro-batches and
            # accumulate their gradients. Each loss is divided by the number of
            # micro-batches, so the accumulated gradient is their mean.
            for i in range(0, len(data), args.batch_size):
                data_batch = data[i:i + args.batch_size]
                target_batch = target[i:i + args.batch_size]
                output = model(data_batch)

                loss = criterion(output, target_batch)
                with torch.no_grad():
                    train_loss.update(loss)
                    train_accuracy.update(accuracy(output, target_batch))
                loss.div_(math.ceil(float(len(data)) / args.batch_size))
                loss.backward()

            # Complete the allreduce first so the preconditioner sees averaged
            # gradients, then step without synchronizing a second time.
            optimizer.synchronize()
            if use_kfac:
                preconditioner.step(epoch=epoch)
            with optimizer.skip_synchronize():
                optimizer.step()

            avg_time += (time.time() - stime)
            if batch_idx > 0 and batch_idx % display == 0:
                if args.verbose:
                    logger.info(
                        "[%d][%d] train loss: %.4f, acc: %.3f, time: %.3f [io: %.3f], speed: %.3f images/s"
                        % (epoch, batch_idx, train_loss.avg.item(), 100 *
                           train_accuracy.avg.item(), avg_time / display,
                           io_time / display, args.batch_size /
                           (avg_time / display)))
                    avg_time = 0.0
                    io_time = 0.0
            if hvd.rank() == 0:
                # NOTE(review): this logs the last micro-batch's loss after the
                # div_ scaling above, not the running average.
                wandb.log({"loss": loss, "epoch": epoch})
        if args.verbose:
            logger.info("[%d] epoch train loss: %.4f, acc: %.3f" %
                        (epoch, train_loss.avg.item(),
                         100 * train_accuracy.avg.item()))

        if not STEP_FIRST:
            for scheduler in lr_scheduler:
                scheduler.step()
            if use_kfac:
                kfac_param_scheduler.step(epoch)

        if log_writer:
            log_writer.add_scalar('train/loss', train_loss.avg, epoch)
            log_writer.add_scalar('train/accuracy', train_accuracy.avg, epoch)

    def test(epoch):
        """Evaluate on the DistributedSampler-partitioned test set and log the results."""
        model.eval()
        test_loss = Metric('val_loss')
        test_accuracy = Metric('val_accuracy')

        with torch.no_grad():
            for data, target in test_loader:
                if args.cuda:
                    data, target = data.cuda(), target.cuda()
                output = model(data)
                test_loss.update(criterion(output, target))
                test_accuracy.update(accuracy(output, target))
            if args.verbose:
                logger.info("[%d][0] evaluation loss: %.4f, acc: %.3f" %
                            (epoch, test_loss.avg.item(),
                             100 * test_accuracy.avg.item()))
                if hvd.rank() == 0:
                    wandb.log({
                        "val top-1 acc": test_accuracy.avg.item(),
                        "epoch": epoch
                    })

        if log_writer:
            log_writer.add_scalar('test/loss', test_loss.avg, epoch)
            log_writer.add_scalar('test/accuracy', test_accuracy.avg, epoch)

    start = time.time()

    for epoch in range(args.epochs):
        if args.verbose:
            logger.info("[%d] epoch train starts" % (epoch))
        train(epoch)
        test(epoch)

    if verbose:
        logger.info("Training time: %s",
                    str(datetime.timedelta(seconds=time.time() - start)))
Exemple #5
0
from keras.preprocessing.image import ImageDataGenerator
import os,sys
import numpy as np

# Directory holding the Keras model definitions imported below, appended to
# the import search path. The path is relative to the current working directory.
model_path = os.path.join(os.pardir, 'models', 'keras', 'models')
sys.path.append(model_path)

import lenet
from imagenet_utils import preprocess_input
from keras.preprocessing import image as keras_image

# Weights file handed to lenet.LeNet; relative to the current working directory.
weights_path = os.path.join(os.pardir, 'models', 'keras', 'weights', 'weights_basic.h5')
model = lenet.LeNet(weights_path)
model.compile(
    loss='categorical_crossentropy',
    optimizer='rmsprop',
    metrics=['accuracy'],
)

# Training-time generator: scales pixel values by 1/255 and applies random
# shear, zoom and horizontal flips.
train_datagen = ImageDataGenerator(
    rescale=1. / 255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
)

# Test-time generator: the same 1/255 scaling, with no augmentation.
test_datagen = ImageDataGenerator(rescale=1. / 255)

# this is a generator that will read pictures found in
Exemple #6
0
            # Run name encoding epochs, dropout and learning rate; used to build
            # the model and results directory paths.
            MODEL = 'lenet_E' + str(EPOCHS) + '_D' + str(dropout) + '_R' + str(
                rate)
            # NOTE(review): absolute, user-specific paths — consider making them
            # configurable.
            model_dir = '/home/valentin/Desktop/MLProject/models/' + MODEL + '/'
            save_dir = '/home/valentin/Desktop/MLProject/results/' + MODEL + '/'
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)

            # Graph inputs: a batch of 32x32 single-channel images (presumably
            # NHWC — confirm against lenet.LeNet) and integer labels, one-hot
            # encoded over 10 classes.
            x = tf.placeholder(tf.float32, (None, 32, 32, 1))
            # NOTE(review): `(None)` is just `None`, not a 1-tuple, so y's shape
            # is left fully unspecified; `(None,)` would pin it to rank 1.
            y = tf.placeholder(tf.int32, (None))
            one_hot_y = tf.one_hot(y, 10)
            keep_prob = tf.placeholder(
                tf.float32)  # dropout (keep probability)

            # Build the LeNet graph on x (not AlexNet); its output is used as
            # the class logits.
            logits = lenet.LeNet(x, keep_prob)

            # Mean softmax cross-entropy loss, minimized with Adam at `rate`.
            cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
                logits=logits, labels=one_hot_y)
            loss_operation = tf.reduce_mean(cross_entropy)
            optimizer = tf.train.AdamOptimizer(learning_rate=rate)
            training_operation = optimizer.minimize(loss_operation)

            # Accuracy: fraction of samples whose arg-max logit matches the label.
            correct_prediction = tf.equal(tf.argmax(logits, 1),
                                          tf.argmax(one_hot_y, 1))
            accuracy_operation = tf.reduce_mean(
                tf.cast(correct_prediction, tf.float32))
            # Class probabilities and predicted class indices.
            train_prediction = tf.nn.softmax(logits)
            output = tf.argmax(logits, 1)
            # evaluate model accuracy with validation set
Exemple #7
0
import lenet
import simplenet
import resnet
import resnet_with_compression
import mobilenet
import mobilenetv2
import densenet
import dpn
import preact_resnet

cifar10_networks = {
    'lenet': lenet.LeNet(),
    'simplenet9': simplenet.SimpleNet9(),
    'simplenet9_thin': simplenet.SimpleNet9_thin(),
    'simplenet9_mobile': simplenet.SimpleNet9_mobile(),
    'simplenet7': simplenet.SimpleNet7(),
    'simplenet7_thin': simplenet.SimpleNet7_thin(),
    'resnet18NNFC1': resnet_with_compression.ResNet18NNFC1(),
    'resnet18EH0': resnet_with_compression.ResNet18EH(layer=0, quantizer=20),
    'resnet18EH1': resnet_with_compression.ResNet18EH(layer=1, quantizer=6),
    'resnet18EH2': resnet_with_compression.ResNet18EH(layer=2, quantizer=5),
    'resnet18EH3': resnet_with_compression.ResNet18EH(layer=3, quantizer=3),
    'resnet18EH4': resnet_with_compression.ResNet18EH(layer=4, quantizer=10),
    'resnet18JPEG90': resnet_with_compression.ResNet18JPEG(quantizer=90),
    'resnet18JPEG87': resnet_with_compression.ResNet18JPEG(quantizer=87),
    'resnet18AVC': resnet_with_compression.ResNet18AVC(layer=2, quantizer=24),
    'resnet18': resnet.ResNet18(),
    'resnet101': resnet.ResNet101(),
    'mobilenetslimplus': mobilenet.MobileNetSlimPlus(),
    'mobilenetslim': mobilenet.MobileNetSlim(),
    'mobilenet': mobilenet.MobileNet(),
Exemple #8
0
def main():
    """Train a (optionally architecture-patched) CIFAR model and log to TensorBoard.

    Parses command-line arguments and builds LeNet or a ``resnet`` variant for
    CIFAR-10/100, optionally with a fixed pixel permutation. The model is then
    wrapped with ``Supernet`` (``--darts``) or ``Chrysalis`` patching.

    Training uses an SGD optimizer for model weights plus an optional SGD or
    Adam optimizer for architecture parameters. Checkpoints, ``args.json``,
    ``arch.th`` and ``results.json`` are written under ``args.save_dir``.

    Side effects: sets the module-level globals ``args`` and ``best_prec1``.
    """
    global args, best_prec1
    args = parser.parse_args()
    torch.manual_seed(args.seed)

    # Make sure the output directory exists.
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    # Class count from the dataset-name suffix ('cifar10' -> 10,
    # 'cifar100' -> 100); assumes args.data has the form 'cifarNN'.
    num_classes = int(args.data[5:])
    if args.arch == 'lenet':
        model = lenet.LeNet(num_classes=num_classes)
    else:
        model = resnet.__dict__[args.arch](num_classes=num_classes)
    origpar = sum(param.numel() for param in model.parameters())
    print('Original weight count:', origpar)
    torch.cuda.set_device(args.device)

    criterion = nn.CrossEntropyLoss().cuda()
    writer = SummaryWriter(args.save_dir)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    if args.permute or args.get_permute:
        # Reuse a stored permutation (explicit file first, then the checkpoint
        # being resumed); otherwise create a new RowColPermute(32, 32).
        if args.get_permute:
            permute = torch.load(args.get_permute)['permute']
        elif args.resume:
            permute = torch.load(args.resume)['permute']
        else:
            permute = RowColPermute(32, 32)
        # No flip/crop augmentation when pixels are permuted.
        train_transforms = [transforms.ToTensor(), permute, normalize]
        val_transforms = [transforms.ToTensor(), permute, normalize]
    else:
        permute = None
        train_transforms = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, 4), transforms.ToTensor(), normalize]
        val_transforms = [transforms.ToTensor(), normalize]

    cifar = datasets.CIFAR100 if args.data == 'cifar100' else datasets.CIFAR10
    train_loader = torch.utils.data.DataLoader(
        cifar(root='./data', train=True, transform=transforms.Compose(train_transforms), download=True),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        cifar(root='./data', train=False, transform=transforms.Compose(val_transforms)),
        batch_size=128, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.half:
        model.half()
        criterion.half()

    if args.darts:
        # DARTS-style supernet: patch every module whose name contains one of
        # the requested kinds. A single example from the first training batch
        # is passed to the patcher (presumably for shape inference).
        model, original = Supernet.metamorphosize(model), model
        X, _ = next(iter(train_loader))
        arch_kwargs = {'perturb': args.perturb,
                       'verbose': not args.resume,
                       'warm_start': not args.from_scratch}
        patchlist = (['conv'] if args.patch_conv else []) \
                  + (['pool'] if args.patch_pool else []) \
                  + (['shortcut'] if args.patch_skip else [])
        model.patch_darts(X[:1], named_modules=((n, m) for n, m in model.named_modules() if any(patch in n for patch in patchlist)), **arch_kwargs)
        print('Model weight count:', sum(p.numel() for p in model.model_weights()))
        print('Arch param count:', sum(p.numel() for p in model.arch_params()))
    else:
        model, original = Chrysalis.metamorphosize(model), model
        if args.patch_skip or args.patch_conv or args.patch_pool:
            X, _ = next(iter(train_loader))
            arch_kwargs = {key: getattr(args, key) for key in [
                                                               'kmatrix_depth',
                                                               'max_kernel_size',
                                                               'global_biasing',
                                                               'channel_gating',
                                                               'base',
                                                               'perturb',
                                                               ]}
            arch_kwargs['verbose'] = not args.resume
            arch_kwargs['warm_start'] = not args.from_scratch
            if args.patch_skip:
                model.patch_skip(X[:1], named_modules=((n, m) for n, m in model.named_modules() if 'shortcut' in n), **arch_kwargs)
            if args.patch_pool:
                model.patch_pool(X[:1], named_modules=((n, m) for n, m in model.named_modules() if 'pool' in n), **arch_kwargs)
            if args.patch_conv:
                model.patch_conv(X[:1], **arch_kwargs)
            print('Model weight count:', sum(p.numel() for p in model.model_weights()))
            print('Arch param count:', sum(p.numel() for p in model.arch_params()))
        else:
            # Nothing patched: there are no architecture parameters to train.
            args.arch_lr = 0.0

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            # Report the file actually loaded (previously printed args.evaluate).
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.offline:
        # Fixed architecture loaded from file: stop training arch params.
        model.load_arch(args.offline)
        args.arch_lr = 0.0
        if args.darts:
            model.discretize()
            for name, module in model.named_modules():
                if hasattr(module, 'discrete'):
                    print(name, '\t', module.discrete)

    cudnn.benchmark = True
    model.cuda()

    momentum = partial(torch.optim.SGD, momentum=args.momentum)
    opts = [momentum(model.model_weights(), lr=args.lr, weight_decay=args.weight_decay)]
    if args.arch_lr:
        arch_opt = torch.optim.Adam if args.arch_adam else momentum
        opts.append(arch_opt(model.arch_params(), lr=args.arch_lr, weight_decay=0.0 if args.arch_adam else args.weight_decay))
    optimizer = MixedOptimizer(opts)

    def weight_sched(epoch):
        """LR multiplier for model weights.

        LeNet: 1.0, then 0.5 from 50% and 0.1 from 75% of training. Other
        archs: 10x decay at 50% and again at 75%, plus a 0.1 multiplier in
        epoch 0 for resnet110/resnet1202.
        """
        if args.arch in ['lenet']:
            return 0.1 if epoch >= int(0.75 * args.epochs) else 0.5 if epoch >= int(0.5 * args.epochs) else 1.0

        if epoch < 1 and args.arch in ['resnet1202', 'resnet110']:
            return 0.1
        return 0.1 ** (epoch >= int(0.5 * args.epochs)) * 0.1 ** (epoch >= int(0.75 * args.epochs))

    def arch_sched(epoch):
        """LR multiplier for arch params: weight_sched, but 0 during warmup and cooldown."""
        return 0.0 if epoch < args.warmup_epochs or epoch > args.epochs-args.cooldown_epochs else weight_sched(epoch)

    # Build the membership set once instead of once per param group.
    weight_params = set(model.model_weights())
    sched_groups = [weight_sched if g['params'][0] in weight_params else arch_sched
                    for g in optimizer.param_groups]
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=sched_groups, last_epoch=args.start_epoch-1)

    def metrics(epoch):
        """Log per-kind distance/norm statistics of the patched modules (non-DARTS only)."""
        if args.darts:
            return

        for label, name, patched in [
                                     ('skip', 'shortcut', args.patch_skip),
                                     ('pool', 'pool', args.patch_pool),
                                     ('conv', 'conv', args.patch_conv),
                                     ]:
            if patched:
                mods = [m for n, m in model.named_modules() if name in n and hasattr(m, 'distance_from')]
                for metric, metric_kwargs in [
                                              ('euclidean', {}),
                                              ('frobenius', {'approx': 16}),
                                              ('averaged', {'approx': 16, 'samples': 10}),
                                              ]:
                    writer.add_scalar('/'.join([label, metric+'-dist']),
                                      sum(m.distance_from(label, metric=metric, relative=True, **metric_kwargs) for m in mods) / len(mods),
                                      epoch)
                    if not metric == 'averaged':
                        writer.add_scalar('/'.join([label, metric+'-norm']),
                                          sum(getattr(m, metric)(**metric_kwargs) for m in mods) / len(mods),
                                          epoch)
                writer.add_scalar(label+'/weight-norm', sum(m.weight.data.norm() for m in mods) / len(mods), epoch)

    if args.evaluate:
        validate(val_loader, model, criterion)
        writer.close()
        return

    with open(os.path.join(args.save_dir, 'args.json'), 'w') as f:
        json.dump(vars(args), f, indent=4)

    # Defined up front so results.json can still be written when no epoch runs
    # (e.g. start_epoch >= epochs); previously this raised NameError.
    prec1 = None
    for epoch in range(args.start_epoch, args.epochs):

        writer.add_scalar('hyper/lr', weight_sched(epoch) * args.lr, epoch)
        writer.add_scalar('hyper/arch', arch_sched(epoch) * args.arch_lr, epoch)
        metrics(epoch)
        model.set_arch_requires_grad(arch_sched(epoch) * args.arch_lr > 0.0)

        # train for one epoch
        print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))

        acc, loss = train(train_loader, model, criterion, optimizer, epoch)
        writer.add_scalar('train/acc', acc, epoch)
        writer.add_scalar('train/loss', loss, epoch)
        lr_scheduler.step()

        # evaluate on validation set
        prec1, loss = validate(val_loader, model, criterion)
        writer.add_scalar('valid/acc', prec1, epoch)
        writer.add_scalar('valid/loss', loss, epoch)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        model.train()
        if epoch > 0 and epoch % args.save_every == 0:
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'permute': permute,
            }, is_best, filename=os.path.join(args.save_dir, 'checkpoint.th'))
        # Snapshot at the end of the warmup phase.
        if epoch > 0 and epoch+1 == args.warmup_epochs:
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'permute': permute,
            }, is_best, filename=os.path.join(args.save_dir, 'warmup.th'))
        # Snapshot at the start of the cooldown phase. This uses `and`, like the
        # warmup condition; the original `or` overwrote cooldown.th every epoch.
        if epoch > 0 and epoch+1 == args.epochs-args.cooldown_epochs:
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'permute': permute,
            }, is_best, filename=os.path.join(args.save_dir, 'cooldown.th'))
        save_checkpoint({
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'permute': permute,
        }, is_best, filename=os.path.join(args.save_dir, 'model.th'))

    model.save_arch(os.path.join(args.save_dir, 'arch.th'))
    metrics(args.epochs)
    writer.flush()
    writer.close()
    with open(os.path.join(args.save_dir, 'results.json'), 'w') as f:
        json.dump({'final validation accuracy': prec1,
                   'best validation accuracy': best_prec1,
                   }, f, indent=4)