Example #1
def main():
    opt = NetOption()

    # create data loader
    data_loader = DataLoader(dataset=opt.data_set,
                             batch_size=opt.batchSize,
                             data_path=opt.dataPath,
                             n_threads=opt.nThreads,
                             ten_crop=opt.tenCrop,
                             dataset_ratio=opt.datasetRatio)
    train_loader, test_loader = data_loader.getloader()

    # define checkpoint
    check_point = CheckPoint(opt=opt)
    # load parameters for retraining; default to an empty record so the
    # lookup below is always defined
    check_point_params = {'model': None}
    if opt.retrain:
        check_point_params = check_point.retrainmodel()

    if opt.netType == "LeNet5":
        model = MD.LeNet5()
    else:
        assert False, "unsupported netType: %s" % opt.netType

    if check_point_params['model'] is not None:
        # copy only the checkpoint weights whose names match the new model
        previous_model_dict = check_point_params['model']
        model_dict = model.state_dict()
        for key, value in previous_model_dict.items():
            if key in model_dict:
                model_dict[key] = value
        model.load_state_dict(model_dict)
    # model = dataparallel(model, opt.nGPU, opt.GPU)
    model.cuda()

    # testing original model
    trainer = Trainer(model=model, opt=opt)
    # trainer.test(epoch=0, test_loader=test_loader)

    # filter level prune
    # prune lenet5
    print "model structure:", model
    print "--------------------------------------"

    result_record = []
    model, prune_record, result = segment_prune(model, train_loader,
                                                test_loader, trainer, opt)
    result_record.append(result)
    model, prune_record, result = segment_prune(model, train_loader,
                                                test_loader, trainer, opt,
                                                "classifier", prune_record)
    result_record.append(result)
    print "======================================"
    print result_record
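
This example and the ones below all instantiate a LeNet-5 with no constructor arguments (MD.LeNet5() / models.LeNet5()). For reference, a minimal sketch of what such a zero-argument definition might look like for 1x28x28 MNIST input; this is an assumption for illustration, not the actual class from these repositories:

import torch.nn as nn
import torch.nn.functional as F

class LeNet5(nn.Module):
    """Hypothetical LeNet-5 matching the zero-argument calls above
    (assumes a 1x28x28 MNIST input)."""

    def __init__(self, num_classes=10):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, padding=2)  # -> 6x28x28
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)            # -> 16x10x10
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)  # -> 6x14x14
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)  # -> 16x5x5
        x = x.view(x.size(0), -1)                   # flatten to 400 features
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)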
Example #2
def setup_and_run(args, criterion, device, train_loader, test_loader,
                  val_loader, logging, results):
    global BEST_ACC
    print("\n#### Running REF ####")

    # architecture
    if args.architecture == "MLP":
        model = models.MLP(args.input_dim, args.hidden_dim,
                           args.output_dim).to(device)
    elif args.architecture == "LENET300":
        model = models.LeNet300(args.input_dim, args.output_dim).to(device)
    elif args.architecture == "LENET5":
        model = models.LeNet5(args.input_channels, args.im_size,
                              args.output_dim).to(device)
    elif "VGG" in args.architecture:
        assert (args.architecture == "VGG11" or args.architecture == "VGG13"
                or args.architecture == "VGG16"
                or args.architecture == "VGG19")
        model = models.VGG(args.architecture, args.input_channels,
                           args.im_size, args.output_dim).to(device)
    elif args.architecture == "RESNET18":
        model = models.ResNet18(args.input_channels, args.im_size,
                                args.output_dim).to(device)
    elif args.architecture == "RESNET34":
        model = models.ResNet34(args.input_channels, args.im_size,
                                args.output_dim).to(device)
    elif args.architecture == "RESNET50":
        model = models.ResNet50(args.input_channels, args.im_size,
                                args.output_dim).to(device)
    elif args.architecture == "RESNET101":
        model = models.ResNet101(args.input_channels, args.im_size,
                                 args.output_dim).to(device)
    elif args.architecture == "RESNET152":
        model = models.ResNet152(args.input_channels, args.im_size,
                                 args.output_dim).to(device)
    else:
        print('Architecture type "{0}" not recognized, exiting ...'.format(
            args.architecture))
        exit()

    # optimizer
    if args.optimizer == "ADAM":
        optimizer = optim.Adam(model.parameters(),
                               lr=args.learning_rate,
                               weight_decay=args.weight_decay)
    elif args.optimizer == "SGD":
        optimizer = optim.SGD(
            model.parameters(),
            lr=args.learning_rate,
            momentum=args.momentum,
            nesterov=args.nesterov,
            weight_decay=args.weight_decay,
        )
    else:
        print('Optimizer type "{0}" not recognized, exiting ...'.format(
            args.optimizer))
        exit()

    # lr-scheduler
    if args.lr_decay == "STEP":
        scheduler = optim.lr_scheduler.StepLR(optimizer,
                                              step_size=1,
                                              gamma=args.lr_scale)
    elif args.lr_decay == "EXP":
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer,
                                                     gamma=args.lr_scale)
    elif args.lr_decay == "MSTEP":
        x = args.lr_interval.split(",")
        lri = [int(v) for v in x]
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                   milestones=lri,
                                                   gamma=args.lr_scale)
        args.lr_interval = 1  # lr_interval handled in scheduler!
    else:
        print('LR decay type "{0}" not recognized, exiting ...'.format(
            args.lr_decay))
        exit()

    init_weights(model, xavier=True)
    logging.info(model)
    num_parameters = sum([l.nelement() for l in model.parameters()])
    logging.info("Number of parameters: %d", num_parameters)

    start_epoch = -1
    iters = 0  # total no of iterations, used to do many things!
    # optionally resume from a checkpoint
    if args.eval:
        logging.info('Loading checkpoint file "{0}" for evaluation'.format(
            args.eval))
        if not os.path.isfile(args.eval):
            print(
                'Checkpoint file "{0}" for evaluation not recognized, exiting ...'
                .format(args.eval))
            exit()
        checkpoint = torch.load(args.eval)
        model.load_state_dict(checkpoint["state_dict"])

    elif args.resume:
        checkpoint_file = args.resume
        logging.info('Loading checkpoint file "{0}" to resume'.format(
            args.resume))
        if not os.path.isfile(checkpoint_file):
            print('Checkpoint file "{0}" not recognized, exiting ...'.format(
                checkpoint_file))
            exit()
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint["epoch"]
        assert args.architecture == checkpoint["architecture"]
        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])
        BEST_ACC = checkpoint["best_acc1"]
        iters = checkpoint["iters"]
        logging.debug("best_acc1: {0}, iters: {1}".format(BEST_ACC, iters))

    if not args.eval:
        logging.info("Training...")
        model.train()
        st = timer()

        for e in range(start_epoch + 1, args.num_epochs):
            for i, (data, target) in enumerate(train_loader):
                l = train_step(model, device, data, target, optimizer,
                               criterion)
                if i % args.log_interval == 0:
                    acc1, acc5 = evaluate(args,
                                          model,
                                          device,
                                          val_loader,
                                          training=True)
                    logging.info(
                        "Epoch: {0},\t Iter: {1},\t Loss: {loss:.5f},\t Val-Acc1: {acc1:.2f} "
                        "(Best: {best:.2f}),\t Val-Acc5: {acc5:.2f}".format(
                            e, i, loss=l, acc1=acc1, best=BEST_ACC, acc5=acc5))

                if iters % args.lr_interval == 0:
                    lr = args.learning_rate
                    for param_group in optimizer.param_groups:
                        lr = param_group["lr"]
                    scheduler.step()
                    for param_group in optimizer.param_groups:
                        if lr != param_group["lr"]:
                            logging.info("lr: {0}".format(
                                param_group["lr"]))  # print if changed
                iters += 1

            # save checkpoint
            acc1, acc5 = evaluate(args,
                                  model,
                                  device,
                                  val_loader,
                                  training=True)
            results.add(
                epoch=e,
                iteration=i,
                train_loss=l,
                val_acc1=acc1,
                best_val_acc1=BEST_ACC,
            )
            util.save_checkpoint(
                {
                    "epoch": e,
                    "architecture": args.architecture,
                    "state_dict": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler.state_dict(),
                    "best_acc1": BEST_ACC,
                    "iters": iters,
                },
                is_best=False,
                path=args.save_dir,
            )
            results.save()

        et = timer()
        logging.info("Elapsed time: {0} seconds".format(et - st))

        acc1, acc5 = evaluate(args, model, device, val_loader, training=True)
        logging.info(
            "End of training, Val-Acc: {acc1:.2f} (Best: {best:.2f}), Val-Acc5: {acc5:.2f}"
            .format(acc1=acc1, best=BEST_ACC, acc5=acc5))
        # load saved model
        saved_model = torch.load(args.save_name)
        model.load_state_dict(saved_model["state_dict"])
    # end of training

    # eval-set (anything other than TRAIN falls through to the test set)
    if args.eval_set != "TRAIN" and args.eval_set != "TEST":
        print('Evaluation set "{0}" not recognized, defaulting to TEST ...'.format(
            args.eval_set))

    logging.info("Evaluating REF on the {0} set...".format(args.eval_set))
    st = timer()
    if args.eval_set == "TRAIN":
        acc1, acc5 = evaluate(args, model, device, train_loader)
    else:
        acc1, acc5 = evaluate(args, model, device, test_loader)
    et = timer()
    logging.info("Accuracy: top-1: {acc1:.2f}, top-5: {acc5:.2f}%".format(
        acc1=acc1, acc5=acc5))
    logging.info("Elapsed time: {0} seconds".format(et - st))
Example #3
def main(net_opt=None):
    """requirements:
    apt-get install graphviz
    pip install pydot termcolor"""

    start_time = time.time()
    opt = net_opt or NetOption()

    # init random seeds
    torch.manual_seed(opt.manualSeed)
    torch.cuda.manual_seed(opt.manualSeed)
    cudnn.benchmark = True
    if opt.nGPU == 1 and torch.cuda.device_count() >= 1:
        assert opt.GPU <= torch.cuda.device_count() - 1, "Invalid GPU ID"
    torch.cuda.set_device(opt.GPU)

    # create data loader
    data_loader = DataLoader(dataset=opt.data_set,
                             train_batch_size=opt.trainBatchSize,
                             test_batch_size=opt.testBatchSize,
                             n_threads=opt.nThreads,
                             ten_crop=opt.tenCrop)
    train_loader, test_loader = data_loader.getloader()

    # define check point
    check_point = CheckPoint(opt=opt)
    # create residual network model
    if opt.retrain:
        check_point_params = check_point.retrainmodel()
    elif opt.resume:
        check_point_params = check_point.resumemodel()
    else:
        check_point_params = check_point.check_point_params

    optimizer = check_point_params['opts']
    start_epoch = check_point_params['resume_epoch'] or 0
    if check_point_params['resume_epoch'] is not None:
        start_epoch += 1
    if start_epoch >= opt.nEpochs:
        start_epoch = 0
    if opt.netType == "ResNet":
        model = check_point_params['model'] or MD.ResNet(
            depth=opt.depth,
            num_classes=opt.nClasses,
            wide_factor=opt.wideFactor)
        model = dataparallel(model, opt.nGPU, opt.GPU)
    elif opt.netType == "PreResNet":
        model = check_point_params['model'] or MD.PreResNet(
            depth=opt.depth,
            num_classes=opt.nClasses,
            wide_factor=opt.wideFactor)
        model = dataparallel(model, opt.nGPU, opt.GPU)
    elif opt.netType == "LeNet5":
        model = check_point_params['model'] or MD.LeNet5()
        model = dataparallel(model, opt.nGPU, opt.GPU)

    else:
        assert False, "invalid net type"

    # create online board
    if opt.onlineBoard:
        try:
            online_board = BoardManager("main")
        except Exception:
            online_board = None
            print("|===> Failed to create online board! Check whether you have run <python -m visdom.server>")
    else:
        online_board = None

    trainer = Trainer(model=model,
                      opt=opt,
                      optimizer=optimizer,
                      online_board=online_board)
    print "|===>Create trainer"

    # define visualizer
    visualize = Visualization(opt=opt)
    visualize.writeopt(opt=opt)
    # visualize model
    if opt.drawNetwork:
        if opt.data_set == "cifar10" or opt.data_set == "cifar100":
            rand_input = torch.randn(1, 3, 32, 32)
        elif opt.data_set == "mnist":
            rand_input = torch.randn(1, 1, 28, 28)
        else:
            assert False, "invalid data set"
        rand_input = Variable(rand_input.cuda())
        rand_output = trainer.forward(rand_input)
        visualize.gennetwork(rand_output)
        visualize.savenetwork()

    # test model
    if opt.testOnly:
        trainer.test(epoch=0, test_loader=test_loader)
        return

    best_top1 = 100
    best_top5 = 100
    for epoch in range(start_epoch, opt.nEpochs):
        # training and testing
        train_error, train_loss, train5_error = trainer.train(
            epoch=epoch, train_loader=train_loader)
        test_error, test_loss, test5_error = trainer.test(
            epoch=epoch, test_loader=test_loader)

        # show training information on online board
        if online_board is not None:
            online_board.updateplot(train_error,
                                    train5_error,
                                    train_loss,
                                    mode="Train")
            online_board.updateplot(test_error,
                                    test5_error,
                                    test_loss,
                                    mode="Test")

        # write and print result
        log_str = "%d\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t" % (
            epoch, train_error, train_loss, test_error, test_loss,
            train5_error, test5_error)
        visualize.writelog(log_str)
        best_flag = False
        if best_top1 >= test_error:
            best_top1 = test_error
            best_top5 = test5_error
            best_flag = True
            if online_board is not None:
                online_board.updateresult([best_top1, best_top5, test_loss])
            print(colored(
                "==>Best Result is: Top1 Error: %f, Top5 Error: %f\n" %
                (best_top1, best_top5), "red"))
        else:
            print(colored(
                "==>Best Result is: Top1 Error: %f, Top5 Error: %f\n" %
                (best_top1, best_top5), "blue"))

        # save check_point
        # save best result and recent state
        check_point.savemodel(epoch=epoch,
                              model=trainer.model,
                              opts=trainer.optimizer,
                              best_flag=best_flag)

        if (epoch + 1) % opt.drawInterval == 0:
            visualize.drawcurves()

    end_time = time.time()
    time_interval = end_time - start_time

    t_string = "Running Time is: " + str(
        datetime.timedelta(seconds=time_interval)) + "\n"
    print(t_string)

    # save experimental results
    visualize.writereadme(
        "Best Result of all is: Top1 Error: %f, Top5 Error: %f\n" %
        (best_top1, best_top5))
    visualize.writereadme(t_string)
    visualize.drawcurves()
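
Each branch above wraps the model in a dataparallel(model, opt.nGPU, opt.GPU) helper. A plausible sketch of such a helper, inferred from the call sites (an assumption, not the repository's actual code):

import torch.nn as nn

def dataparallel(model, ngpus, gpu0=0):
    # Hypothetical helper: move the model onto GPU gpu0 and, for multi-GPU
    # runs, wrap it in nn.DataParallel over ngpus consecutive devices.
    model = model.cuda(gpu0)
    if ngpus > 1:
        model = nn.DataParallel(model, device_ids=list(range(gpu0, gpu0 + ngpus)))
    return model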
Example #4
    # net = resnet.ResNet18()
    # net = net.cuda()
    # net = torch.nn.DataParallel(net)
    # checkpoint = torch.load("H:/adversarial_attacks/pytorch-cifar/checkpoint/DataPackpt.pth")
    # net.load_state_dict(checkpoint['net'])
    # target_model = net

    # resnet32
    if args.target_model == 'resnet32':
        target_model = cifar_loader.load_pretrained_cifar_resnet(flavor=32)
    elif args.target_model == 'resnet20':
        target_model = cifar_loader.load_pretrained_cifar_resnet(flavor=20)
    elif args.target_model == 'wideresnet':
        target_model = cifar_loader.load_pretrained_cifar_wide_resnet()
    elif args.target_model == "mnist_2":
        target_model = models.LeNet5()
        target_model.load_state_dict(torch.load('./trained_lenet5.pkl'))
    else:
        raise ValueError("unknown target model: %s" % args.target_model)

    # resnet32_advtrain
    # target_model = resnet32()
    # target_model.load_state_dict(torch.load('./advtrain.resnet32.000100.path.tar'))

    target_model = target_model.cuda()
    target_model.eval()

    model_num_labels = 10
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])  # assumed composition: the snippet was cut off mid-Compose
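
With the target model in eval mode, a typical next step in this kind of attack script is to build the test loader and check clean accuracy. A hedged usage sketch, assuming the standard torchvision CIFAR-10 dataset and the transform and target_model defined above:

import torch
import torchvision

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False)

correct = total = 0
with torch.no_grad():
    for images, labels in testloader:
        preds = target_model(images.cuda()).argmax(dim=1).cpu()
        correct += (preds == labels).sum().item()
        total += labels.size(0)
print('clean accuracy: {:.2f}%'.format(100.0 * correct / total))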
Example #5
def training(args, *k, **kw):
    # use the GPU if available and requested
    device = torch.device("cuda:{}".format(args.gpuindex)
                          if torch.cuda.is_available() and args.gpu else "cpu")
    print("using device: {}".format(device))

    # redis helper related
    redis_helper = redishelper.GoSGDHelper(host=args.host, port=args.port)
    redis_helper.signin()
    while redis_helper.cur_edge_num() < args.edgenum:
        time.sleep(1) # sleep 1 second

    model_score = 1.0 / args.edgenum # the initial model parameters score

    # log file and TensorBoard summary paths
    log_file = "{0}-{1}-edge-{2}.log".format(
        time.strftime('%Y%m%d-%H%M%S', time.localtime(time.time())),
        args.model, redis_helper.ID)
    log_dir = "tbruns/{0}-{1}-cifar10-edge-{2}".format(
        time.strftime('%Y%m%d%H%M%S', time.localtime(time.time())),
        args.model, redis_helper.ID)

    logger = open(log_file,'w')
    swriter = SummaryWriter(log_dir)

    # load training data
    trainset = dataset.AGGData(root=args.dataset, train=True, download=False, transform=None)

    testset = dataset.AGGData(root=args.dataset, train=False, download=False, transform=None)
    testloader = torch.utils.data.DataLoader(testset, batch_size=args.batchsize, shuffle=False, num_workers=0)

    # construct neural network
    net = None
    if args.model == "lenet5":
        net = models.LeNet5()
    elif args.model == "resnet18":
        net = models.ResNet18()
    elif args.model == "alexnet":
        net = models.AlexNet(args.num_classes)
    elif args.model == "alexnetimg8":
        net = models.AlexNetImg8(args.num_classes)
    elif args.model == "squeezenet":
        net = models.SqueezeNet()
    elif args.model == "mobilenetv2":
        net = models.MobileNetV2()
    elif args.model == "resnet34":
        net = models.ResNet34()
    elif args.model == "resnet50":
        net = models.ResNet50()
    elif args.model == "resnet101":
        net = models.ResNet101()
    else:
        net = models.ResNet152()
    net.to(device)

    # define optimizer
    criterion = nn.CrossEntropyLoss()
    criterion_loss = nn.CrossEntropyLoss(reduction='none')
    optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9)
    lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                  milestones=list(args.lrschstep),
                                                  gamma=0.1)

    # start training
    wallclock = 0.0
    iteration = 0 # global iterations
    for epoch in range(0,args.epoch,1):
        starteg = time.time()
        # merge in parameters received from other edges
        if epoch > 0:
            mintime, maxtime, param_list = redis_helper.min2max_time_params()
            print("The min/max time cost of last epoch: {}/{}".format(mintime, maxtime))
            for item in param_list:
                w1 = model_score / (model_score + item[0])
                w2 = item[0] / (model_score + item[0])

                for local,other in zip(net.parameters(),item[1]):
                    local.data = local.data * w1 + other.data.to(device) * w2
                model_score = model_score + item[0]

            while redis_helper.finish_update() is False:
                time.sleep(1.0)

        critical_extra_start = time.time()
        # identify critical training samples
        critrainset = critical_identify(net, trainset, criterion_loss, device, args)
        critrainloader = torch.utils.data.DataLoader(critrainset,
                                                     batch_size=args.batchsize,
                                                     shuffle=True, num_workers=0)

        critical_extra_cost = time.time() - critical_extra_start
        training_start = time.time()

        running_loss = 0.0
        record_running_loss = 0.0
        for i, data in enumerate(critrainloader, 0):
            iteration += 1
            # get the inputs
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.squeeze().to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs).squeeze()
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            record_running_loss += loss.item()
            if i % 10 == 9:
                swriter.add_scalar("Training loss", record_running_loss / 10,
                                   epoch * len(critrainloader) + i)
                record_running_loss = 0.0

            if i % 2000 == 1999:    # print every 2000 mini-batches
                print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
                running_loss = 0.0

        training_cost = time.time() - training_start

        # push time and parameters to Redis
        model_score = model_score / 2
        sel_edge_id = redis_helper.random_edge_id(can_be_self=True)
        paramls = [p.cpu() for p in net.parameters()]
        redis_helper.ins_time_params(sel_edge_id, training_cost, model_score, paramls)
        while not redis_helper.finish_push():
            time.sleep(1.0)

        wallclock += time.time() - starteg

        total, kaccuracy = validation(net, testloader, device, topk=(1, 5))

        curtime = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
        _header = "[ {} Epoch {} /Iteration {} Wallclock {}]".format(
            curtime, epoch + 1, iteration, wallclock)

        print('{} Accuracy of the network on the {} test images: {} %'.format(
            _header, total, kaccuracy_str(kaccuracy)))
        logger.write('{},{},{},{}\n'.format(epoch + 1, iteration, wallclock,
                                            kaccuracy_str(kaccuracy)))
        logger.flush() # write to disk

        for item in kaccuracy:
            swriter.add_scalar("Top{}Accuracy".format(item[0]), item[1], epoch)

        # step the optimizer's learning-rate scheduler
        if args.lrscheduler:
            lr_scheduler.step()

    print('Finished Training')

    redis_helper.register_out()
    logger.close() # close log file writer

    return net
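
validation() and kaccuracy_str() are helpers from the same project. A minimal sketch of the formatting helper implied by the call sites above; this is hypothetical, and assumes kaccuracy is a list of (k, accuracy) pairs such as [(1, 92.1), (5, 99.6)], which matches how the TensorBoard loop indexes item[0] and item[1]:

def kaccuracy_str(kaccuracy):
    # Hypothetical helper: render [(k, acc), ...] as "top1=92.10% top5=99.60%".
    return ' '.join('top{0}={1:.2f}%'.format(k, acc) for k, acc in kaccuracy)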