Example #1
class ExpertNet(nn.Module):
    def __init__(self):
        super(ExpertNet, self).__init__()
        self.conv = nn.Conv2d(3,
                              16,
                              kernel_size=(7, 7),
                              stride=(2, 2),
                              padding=(3, 3),
                              bias=False)
        self.bn1 = nn.BatchNorm2d(16,
                                  eps=1e-05,
                                  momentum=0.1,
                                  affine=True,
                                  track_running_stats=True)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3,
                                    stride=2,
                                    padding=1,
                                    dilation=1,
                                    ceil_mode=False)
        self.block1 = resnet56().layer1
        self.block2 = resnet56().layer2

        #self.block1 = models.resnet50().layer1
        #self.block2 = models.resnet50().layer2
        # global average pooling to a 1x1 spatial output
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(
            nn.Linear(in_features=32, out_features=512, bias=True),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(512, 10),
        )
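The excerpt stops at the constructor; the original forward is not shown. A minimal sketch of how these layers would compose, assuming the usual stem -> blocks -> pool -> flatten -> head ordering (the method body below is an assumption, not the original code):

    def forward(self, x):
        # Stem: conv -> batch norm -> ReLU -> max pool
        x = self.maxpool(self.relu(self.bn1(self.conv(x))))
        # Two residual stages borrowed from resnet56 (16 -> 32 channels)
        x = self.block2(self.block1(x))
        # Global average pool, flatten to (N, 32), then the classifier head
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)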
Example #2
class ft_net56_fc128(nn.Module):
    def __init__(self, class_num, droprate=0.5, stride=2):
        super(ft_net56_fc128, self).__init__()
        self.add_module("module", resnet.resnet56())
        weights_ = torch.load("weights_cifar10/resnet56-4bfd9763.th")
        self.load_state_dict(weights_['state_dict'])
        self.module.linear = nn.Sequential()  # drop the backbone's own linear head
        self.classifier = ClassBlock(64, class_num, droprate, num_bottleneck=128)
Example #3
class ft_net56_spp(nn.Module):
    def __init__(self, class_num, droprate=0.5, stride=2):
        super(ft_net56_spp, self).__init__()
        self.add_module("module", resnet.resnet56())
        weights_ = torch.load("weights_cifar10/resnet56-4bfd9763.th")
        self.load_state_dict(weights_['state_dict'])
        self.module.linear = nn.Sequential()  # drop the backbone's own linear head
        self.spp = pyrpool.SpatialPyramidPooling((1, 2))
        self.classifier = ClassBlock(320, class_num, droprate, num_bottleneck=128)
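Why 320 input features: with pyramid levels (1, 2), SpatialPyramidPooling reduces each of resnet56's 64 final feature maps to 1*1 + 2*2 = 5 bins, and 64 * 5 = 320. A minimal stand-in that mirrors the arithmetic with adaptive pooling (pyrpool's exact API is not shown in this excerpt):

import torch
import torch.nn.functional as F

feat = torch.randn(1, 64, 8, 8)                    # assumed backbone feature-map shape
bins = [F.adaptive_avg_pool2d(feat, n).flatten(1)  # (1, 64 * n * n) per pyramid level
        for n in (1, 2)]
print(torch.cat(bins, dim=1).shape)                # torch.Size([1, 320])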
Example #4
    save_fold_name = [
        args.model,
        str(args.depth), args.dataset,
        'BS%d' % args.batch_size
    ]
    if args.origin:
        save_fold_name.insert(0, 'Origin')

    if args.model == 'resnet':
        if args.depth == 20:
            network = resnet.resnet20()
        if args.depth == 32:
            network = resnet.resnet32()
        if args.depth == 44:
            network = resnet.resnet44()
        if args.depth == 56:
            network = resnet.resnet56()
        if args.depth == 110:
            network = resnet.resnet110()

    if not args.origin:
        print('Pruning the model in %s' % args.pruned_model_dir)
        check_point = torch.load(args.pruned_model_dir + "model_best.pth.tar")
        network.load_state_dict(check_point['state_dict'])
        codebook_index_list = np.load(args.pruned_model_dir + "codebook.npy",
                                      allow_pickle=True).tolist()
        m_l = []
        b_l = []

        for i in network.modules():
            if isinstance(i, nn.Conv2d):
                m_l.append(i)
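The excerpt ends while collecting the Conv2d layers into m_l; presumably the loaded codebook is then used to zero out pruned filters. A minimal sketch of that idea, assuming codebook_index_list holds one list of pruned filter indices per conv layer (the actual codebook layout is not shown):

        # Hypothetical masking step: zero the filters named by each layer's codebook.
        for conv, pruned_idx in zip(m_l, codebook_index_list):
            mask = torch.ones(conv.weight.size(0), device=conv.weight.device)
            mask[pruned_idx] = 0.0
            conv.weight.data.mul_(mask.view(-1, 1, 1, 1))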
Example #5
def test():
    router = resnet56()
    #rweights = torch.load('./weights/router_resnet20_all_class.pth.tar')
    #rweights = torch.load('./weights/suSan.pth.tar')
    start_time = time.time()
    #rweights = torch.load('./weights/best_so_far_res56.pth.tar')
    rweights = torch.load('./weights/resnet56_fmnist.pth.tar')
    router.load_state_dict(rweights)
    if torch.cuda.is_available():
        router.cuda()
    router.eval()
    test_loss = 0
    correct = 0
    tt = 0
    c = 0
    delta = []
    for data, target in test_loader:
        # if c == 50:
        #     break
        #        c = c + 1
        if (c % 20 == 0):
            print(
                "----- expert accuracy so far : {}/{}-----\n----- router accuracy so far : {}/{}-----"
                .format(correct, c, tt, c))
            print(
                "The DELTA/improvement between router and expert: {}\n".format(
                    abs(correct - tt)))
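            # Rough linear extrapolation of the expert-vs-router gap to the full
            # 10k-image test set; the 93.88 constant looks like a hard-coded
            # router baseline accuracy.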
            if (c > 0):
                print(
                    "Forecasting {:.2f}% (approx) accuracy at the end\n".format(
                        ((10000.00 / c) * abs(tt - correct)) / 100.0 + 93.88))

        c = c + 1
        delta.append(abs(correct - tt))
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)
        output = router(data)
        output = F.softmax(output, dim=1)
        #print (output) #################### Remove this while running
        rsoftmax = torch.sort(output, dim=1, descending=True)[0][0:,
                                                                 0:args.topn]
        pred = output.data.max(1, keepdim=True)[1]
        pred2 = torch.argsort(output, dim=1, descending=True)[0][1]
        tt += pred.eq(target.data.view_as(pred)).cpu().sum()
        #tt += pred2.eq(target.data.view_as(pred)).cpu().sum()

        sortedsoftmax = torch.argsort(output, dim=1,
                                      descending=True)[0:1, 0:args.topn]
        sortedsoftmax = np.array(sortedsoftmax.cpu())

        ## Reset the subset flags and collect the router's top predictions
        predictions = []
        for i in SUBSET:
            subset_flag[str(i)] = True

        for i, pred in enumerate(sortedsoftmax):
            for j in range(args.topn):
                predictions.append(pred[j])
        #print ("The top {} predictions of router: {}".format(args.topn, predictions))
        rsm = []
        for i, pred in enumerate(rsoftmax):
            for j in range(args.topn):
                rsm.append(pred[j])

        #sm = {}
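        # Weighted fusion: start from the router's softmax scaled by 0.7, then
        # accumulate the softmax of every consulted expert in the loop below.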
        fout = torch.zeros([1, 10], device='cuda') + (output * 0.7)
        for i, pred in enumerate(predictions):
            #sm[pred] = 0
            tot = 0.0
            expert = resnet20()
            for sub in SUBSET:
                if pred in sub and subset_flag[str(sub)] == True:

                    ###### Load the saved weights for the experts #####
                    wt = "./weights/rr/random_injection_erasing/res20_fmnist/rr_subset_" + str(
                        sub) + ".pth.tar"

                    #wt = "./weights/latent_space_hardtraining/lp_subset_" + str(sub) + ".pth.tar"

                    wts = torch.load(wt)
                    expert.cuda()
                    expert.eval()
                    expert.load_state_dict(wts)
                    ############################

                    ### Inference part starts here ##########
                    output = F.softmax(expert(data), dim=1)
                    #print (output)
                    #output = torch.sort(output, dim=1, descending=True)[0][0][0]
                    fout += output

                    #print (pred, target, output)
                    #sm[pred] += output.item()  #* trust_factor(len(sub), 2)
                    tot += 1
                    subset_flag[str(sub)] = False
        #fout = fout/tot
        #print ("Fout:",fout)
        prd = fout.data.max(1, keepdim=True)[1]
        #

        if (prd == target.item()):
            correct = correct + 1

    # correct += pred.eq(target.data.view_as(pred)).cpu().sum()
    test_loss /= len(test_loader.dataset)
    print("\nThe routers performance: {:4f}".format(
        100.0 * (tt.data.item() / len(test_loader.dataset))))
    print('EMNN (ours) accuracy: {:.4f}%)\n'.format(100. * correct /
                                                    len(test_loader.dataset)))
    print("Total time taken {:.2f}.".format(time.time() - start_time))
    delta = np.array(delta)
    fl = "./inference_result/fmnist_delta_resnet56_[4_3].txt"
    with open(fl, 'w') as f:
        for item in delta:
            f.write("%s\n" % item)
Example #6
(x, y), (x_test, y_test) = keras.datasets.cifar10.load_data()

train_dataset = tf.data.Dataset.from_tensor_slices((x, y))
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))

tf.random.set_seed(22)
train_dataset = train_dataset.shuffle(NUM_TRAIN_SAMPLES).map(augmentation).map(
    normalize).batch(BS_PER_GPU * NUM_GPUS, drop_remainder=True)
test_dataset = test_dataset.map(normalize).batch(BS_PER_GPU * NUM_GPUS,
                                                 drop_remainder=True)

input_shape = (HEIGHT, WIDTH, NUM_CHANNELS)
img_input = tf.keras.layers.Input(shape=input_shape)

model = resnet.resnet56(img_input=img_input, classes=NUM_CLASSES)

# define optimizer
sgd = tf.keras.optimizers.SGD(learning_rate=0.1)
model.compile(optimizer=sgd,
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

earlystop_callback = EarlyStopping(monitor='val_accuracy',
                                   min_delta=0.0001,
                                   patience=1,
                                   verbose=1,
                                   mode='auto')

model.fit(train_dataset,
          epochs=NUM_EPOCHS,
          validation_data=test_dataset,
          callbacks=[earlystop_callback])
Example #7
def test():
    router = resnet56()
    #rweights = torch.load('./weights/router_resnet20_all_class.pth.tar')
    #rweights = torch.load('./weights/suSan.pth.tar')
    #rweights = torch.load('teacher_MLP_test_eresnet56_best_archi.pth.tar')
    rweights = torch.load('./weights/resnet56_fmnist.pth.tar')
    router.load_state_dict(rweights)
    if torch.cuda.is_available():
        router.cuda()
    router.eval()
    test_loss = 0
    correct = 0
    tt = 0
    c = 0
    for data, target in test_loader:
        if c == 50:
            break
        c = c + 1
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)
        output = router(data)
        output = F.softmax(output, dim=1)

        rsoftmax = torch.sort(output, dim=1, descending=True)[0][0:,
                                                                 0:args.topn]
        pred = output.data.max(1, keepdim=True)[1]
        tt += pred.eq(target.data.view_as(pred)).cpu().sum()
        sortedsoftmax = torch.argsort(output, dim=1,
                                      descending=True)[0:1, 0:args.topn]
        sortedsoftmax = np.array(sortedsoftmax.cpu())

        ## Reset the subset flags and collect the router's top predictions
        predictions = []
        for i in SUBSET:
            subset_flag[str(i)] = True

        for i, pred in enumerate(sortedsoftmax):
            for j in range(args.topn):
                predictions.append(pred[j])
        #print ("The top {} predictions of router: {}".format(args.topn, predictions))
        rsm = []
        for i, pred in enumerate(rsoftmax):
            for j in range(args.topn):
                rsm.append(pred[j])

        sm = {}
        #fout = torch.zeros([1,10])
        for i, pred in enumerate(predictions):
            sm[pred] = 0
            tot = 0.0
            expert = resnet20()
            for sub in SUBSET:
                if pred in sub and subset_flag[str(sub)] == True:
                    ###### Load the saved weights for the experts #####
                    wt = "./weights/rr/random_injection_erasing/res20_fmnist/rr_subset_" + str(
                        sub) + ".pth.tar"
                    #wt = "./weights/latent_space_hardtraining/lp_subset_" + str(sub) + ".pth.tar"
                    wts = torch.load(wt)
                    expert.cuda()
                    expert.eval()
                    expert.load_state_dict(wts)
                    ############################
                    ### Inference part starts here ##########
                    output = F.softmax(expert(data), dim=1)
                    #print (output)
                    output = torch.sort(output, dim=1,
                                        descending=True)[0][0][0]
                    #print (pred, target, output)
                    sm[pred] += output.item()  #* trust_factor(len(sub), 2)
                    tot += 1
                    subset_flag[str(sub)] = False
            #sm[pred] += (rsm[i].item())
            #if (tot != 0):
            sm[pred] += (rsm[i].item() * 0.9)
        #print (rsm[i].item())

#        for pred in predictions:
#            sm[pred] /= (tot)

        # Pick the candidate class whose combined expert+router score is highest.
        ans = -0.99
        prd = 0
        #        for p in predictions:
        #            print ("soft max for {} is {}".format(p, sm[p]))
        #        print ("the target value:", target)
        for p in predictions:
            if sm[p] >= ans:
                ans = sm[p]
                prd = p
        if (prd == target.item()):
            correct = correct + 1

#        if (predictions[0] != prd):
#            print ("The list of prediction: {} and the Target: {}".format(predictions, target))
#            print ("The softmax score of expert prediction {} for {}".format(sm[2], prd))
#            print ("The softmax score for acutally correct answer {}.".format(sm[target.item()]))
#            print ("the softmax score by the router for correct answer {}.".format(rsm[2]))

# correct += pred.eq(target.data.view_as(pred)).cpu().sum()
    test_loss /= len(test_loader.dataset)
    print("Routers performance:", tt)
    print(
        '\nTest set: Average loss: {:.4f}, TOP 1 Accuracy: {}/{} ({:.4f}%)\n'.
        format(test_loss, correct, len(test_loader.dataset),
               100. * correct / len(test_loader.dataset)))
Example #8
(x, y), (x_test, y_test) = keras.datasets.cifar10.load_data()

train_loader = tf.data.Dataset.from_tensor_slices((x, y))
test_loader = tf.data.Dataset.from_tensor_slices((x_test, y_test))

tf.random.set_seed(22)
train_loader = train_loader.map(augmentation).map(preprocess).shuffle(
    NUM_TRAIN_SAMPLES).batch(BS_PER_GPU * NUM_GPUS, drop_remainder=True)
test_loader = test_loader.map(preprocess).batch(BS_PER_GPU * NUM_GPUS,
                                                drop_remainder=True)

opt = keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)

if NUM_GPUS == 1:
    model = resnet.resnet56(classes=NUM_CLASSES)
    model.compile(optimizer=opt,
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
else:
    mirrored_strategy = tf.distribute.MirroredStrategy()
    with mirrored_strategy.scope():
        model = resnet.resnet56(classes=NUM_CLASSES)
        model.compile(optimizer=opt,
                      loss='sparse_categorical_crossentropy',
                      metrics=['accuracy'])

log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
file_writer = tf.summary.create_file_writer(log_dir + "/metrics")
file_writer.set_as_default()
tensorboard_callback = TensorBoard(log_dir=log_dir)
Example #9
        correct += (predicted == labels).sum()

    accuracy = correct.double() * 1.0 / total
    print("Total: %d, Correct: %d, Accuracy: %f" %
          (total, correct.double(), accuracy))


# for dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

testset = torchvision.datasets.CIFAR10(root=args.data_dir,
                                       train=False,
                                       download=True,
                                       transform=transform)
testloader = t.utils.data.DataLoader(testset,
                                     batch_size=args.batch_size,
                                     shuffle=False,
                                     num_workers=args.num_workers)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
           'ship', 'truck')

net = resnet56()
if t.cuda.is_available():
    net = net.cuda()
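# Note: no checkpoint is loaded before evaluation, so as written this measures
# a randomly initialized resnet56. A hypothetical fix (the path is an assumption):
#   net.load_state_dict(t.load('./weights/resnet56_cifar10.pth.tar'))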

evaluate(net, testloader)
Example #10
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batchSz', type=int, default=64)
    parser.add_argument('--nEpochs', type=int, default=300)
    parser.add_argument('--no-cuda', action='store_true')
    parser.add_argument('--net')
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--opt',
                        type=str,
                        default='sgd',
                        choices=('sgd', 'adam', 'rmsprop'))
    parser.add_argument('--gpu_id', type=str, default='0')

    args = parser.parse_args()

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    args.save = 'work/' + args.net

    setproctitle.setproctitle(args.save)

    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    if os.path.exists(args.save):
        shutil.rmtree(args.save)
    os.makedirs(args.save)

    normMean = [0.49139968, 0.48215827, 0.44653124]
    normStd = [0.24703233, 0.24348505, 0.26158768]
    normTransform = transforms.Normalize(normMean, normStd)

    trainTransform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), normTransform
    ])
    testTransform = transforms.Compose([transforms.ToTensor(), normTransform])

    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    trainLoader = DataLoader(dset.CIFAR10(root='cifar',
                                          train=True,
                                          download=True,
                                          transform=trainTransform),
                             batch_size=args.batchSz,
                             shuffle=True,
                             **kwargs)
    testLoader = DataLoader(dset.CIFAR10(root='cifar',
                                         train=False,
                                         download=True,
                                         transform=testTransform),
                            batch_size=args.batchSz,
                            shuffle=False,
                            **kwargs)

    n_classes = 10
    if args.net == 'resnet20':
        net = resnet.resnet20(num_classes=n_classes)
    elif args.net == 'resnet32':
        net = resnet.resnet32(num_classes=n_classes)
    elif args.net == 'resnet44':
        net = resnet.resnet44(num_classes=n_classes)
    elif args.net == 'resnet56':
        net = resnet.resnet56(num_classes=n_classes)
    elif args.net == 'resnet110':
        net = resnet.resnet110(num_classes=n_classes)
    elif args.net == 'resnetxt29':
        net = resnetxt.resnetxt29(num_classes=n_classes)
    elif args.net == 'deform_resnet32':
        net = deformconvnet.deform_resnet32(num_classes=n_classes)
    else:
        net = densenet.DenseNet(growthRate=12,
                                depth=100,
                                reduction=0.5,
                                bottleneck=True,
                                nClasses=n_classes)

    print('  + Number of params: {}'.format(
        sum([p.data.nelement() for p in net.parameters()])))
    if args.cuda:
        net = net.cuda()
        gpu_id = args.gpu_id
        gpu_list = gpu_id.split(',')
        gpus = [int(i) for i in gpu_list]
        net = nn.DataParallel(net, device_ids=gpus)

    if args.opt == 'sgd':
        optimizer = optim.SGD(net.parameters(),
                              lr=1e-1,
                              momentum=0.9,
                              weight_decay=1e-4)
    elif args.opt == 'adam':
        optimizer = optim.Adam(net.parameters(), weight_decay=1e-4)
    elif args.opt == 'rmsprop':
        optimizer = optim.RMSprop(net.parameters(), weight_decay=1e-4)

    trainF = open(os.path.join(args.save, 'train.csv'), 'w')
    testF = open(os.path.join(args.save, 'test.csv'), 'w')

    for epoch in range(1, args.nEpochs + 1):
        adjust_opt(args.opt, optimizer, epoch)
        train(args, epoch, net, trainLoader, optimizer, trainF)
        test(args, epoch, net, testLoader, optimizer, testF)
        torch.save(net, os.path.join(args.save, 'latest.pth'))
        os.system('python plot.py {} &'.format(args.save))

    trainF.close()
    testF.close()
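adjust_opt is not defined in this excerpt. In the densenet.pytorch-style scripts this snippet resembles, it applies a step-decay schedule for SGD; a sketch under that assumption (the milestone epochs are assumptions):

def adjust_opt(optAlg, optimizer, epoch):
    # Step-decay learning-rate schedule for SGD; adam/rmsprop are left untouched.
    if optAlg == 'sgd':
        if epoch == 150:
            lr = 1e-2
        elif epoch == 225:
            lr = 1e-3
        else:
            return
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr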
Example #11
		loader = CIFAR10Loader      (batch_size, p.getSpeeds(), p.getBatches())
		model  = SimpleCIFAR10Model ()
		num_epochs = 10

	elif dset == "RS_SimpleModel_CIFAR10":
		loader = CIFAR10ResnetLoader(batch_size, p.getSpeeds(), p.getBatches())
		import resnet
		
		if int(sys.argv[7]) == 20:
			model  = resnet.resnet20()
		if int(sys.argv[7]) == 32:
			model  = resnet.resnet32()
		if int(sys.argv[7]) == 44:
			model  = resnet.resnet44()
		if int(sys.argv[7]) == 56:
			model  = resnet.resnet56()
		if int(sys.argv[7]) == 110:
			model  = resnet.resnet110()

		num_epochs = 7

	elif dset == "MNIST":
		loader = MNISTLoader      (batch_size, p.getSpeeds(), p.getBatches())
		model  = SimpleMNISTModel ()
		num_epochs = 10
	else:
		print("DATASET NOT FOUND")

	p.setData  (loader)
	p.setModel (model)
Example #12
import os
import sys

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(BASE_DIR, '..'))
from torchstat import stat
from resnet import resnet56
from vgg import vgg16
from ghostrestnet import gresnet56
from ghostvgg import gvgg16

if __name__ == '__main__':

    img_shape = (3, 32, 32)

    resnet56 = resnet56()
    stat(resnet56, img_shape)  # https://github.com/Swall0w/torchstat
    print("↑↑↑↑ is resnet56")
    print("\n" * 10)
'''
    ghost_resnet56 = gresnet56()
    stat(ghost_resnet56, img_shape)
    print("↑↑↑↑ is ghost_resnet56")

    
    vgg = 0
    if vgg:
        vgg16 = vgg16()
        stat(vgg16, img_shape)
        print("↑↑↑↑ is vgg16")
        print("\n"*10)
Example #13
def kd():
    global args
    args = parser.parse_args()

    # Make dataset and loader
    normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                 std=[0.229, 0.224, 0.225])

    train_loader = torch.utils.data.DataLoader(torchvision.datasets.CIFAR10(
        args.data_path,
        train=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.RandomHorizontalFlip(),
            torchvision.transforms.RandomCrop(32, 4),
            torchvision.transforms.ToTensor(), normalize
        ])),
                                               batch_size=args.batch_size,
                                               shuffle=True)

    val_loader = torch.utils.data.DataLoader(torchvision.datasets.CIFAR10(
        args.data_path,
        train=False,
        transform=torchvision.transforms.Compose(
            [torchvision.transforms.ToTensor(), normalize])),
                                             batch_size=args.test_batch_size)

    # Load teacher (pretrained)
    teacher = resnet56()
    modify_properly(teacher, args.pretrained)
    teacher.cuda()

    # Make student
    student = resnet20()
    student.cuda()

    t_embedding = TEmbedding().cuda()
    s_embedding = SEmbedding().cuda()

    criterion = {
        'ce_loss': nn.CrossEntropyLoss().cuda(),
        'ct_loss': ContrastiveLoss()
    }

    params = list(student.parameters()) + list(
        t_embedding.parameters()) + list(s_embedding.parameters())
    optimizer = torch.optim.Adam(params, lr=args.lr)

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        milestones=[75, 150])

    # Evaluate teacher
    #if args.evaluate:
    #validate(teacher, val_loader, criterion)

    max_val_prec = 0.0  # tracks the best validation accuracy seen so far
    logger = {
        'train/loss': [],
        'train/accuracy': [],
        'val/loss': [],
        'val/accuracy': []
    }

    for epoch in range(args.num_epochs):
        # training
        tr_logger = train(train_loader, student, teacher, s_embedding,
                          t_embedding, criterion, optimizer, epoch)

        # validating
        val_logger = validate(val_loader, student, criterion)

        logger['train/loss'].append(tr_logger['loss'].mean)
        logger['train/accuracy'].append(tr_logger['prec'].mean)

        logger['val/loss'].append(val_logger['loss'].mean)
        logger['val/accuracy'].append(val_logger['prec'].mean)

        lr_scheduler.step()

        if max_val_prec < val_logger['prec'].mean:
            max_val_prec = val_logger['prec'].mean
            torch.save(student.state_dict(), 'ckpt/cwfd-resnet20-epochs' +
                       str(epoch) + '.pt')  # TODO: add path variable

    print("maximum of avg. val accuracy: {}".format(max_val_prec))
    save_log(logger, 'logs/cwfd-resnet20.log')
Example #14
### Import the model ###
model = resnet20()
if torch.cuda.is_available():
    model = model.cuda()

ck = torch.load(wt)  # 'wt' (the student weights path) is defined outside this excerpt
model.load_state_dict(ck)
optimizer = optim.SGD(model.parameters(),
                      lr=args.lr,
                      momentum=args.momentum,
                      weight_decay=5e-4,
                      nesterov=True)

scheduler = StepLR(optimizer, step_size=40, gamma=0.1)

teacher = resnet56()
if torch.cuda.is_available():
    teacher = teacher.cuda()
ck = torch.load('./weights/suSan.pth.tar')
teacher.load_state_dict(ck)

print("Weight load success")


def distillation(y, labels, teacher_scores, T, alpha):
    return nn.KLDivLoss()(F.log_softmax(
        y / T, dim=1), F.softmax(teacher_scores / T, dim=1)) * (
            T * T * 2.0 * alpha) + F.cross_entropy(y, labels) * (1. - alpha)


def train(epoch, model, teacher, loss_fn):
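    # The original body is truncated here. A minimal sketch of one distillation
    # epoch using the loss above (train_loader, T and alpha are assumptions):
    model.train()
    teacher.eval()
    for data, target in train_loader:
        if torch.cuda.is_available():
            data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        output = model(data)
        with torch.no_grad():
            teacher_scores = teacher(data)
        loss = loss_fn(output, target, teacher_scores, T=4.0, alpha=0.9)
        loss.backward()
        optimizer.step()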
Example #15
def main():
    global args, best_prec1
    args = parser.parse_args()

    # Check the save_dir exists or not
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    if args.adv_train:
        state_dict = torch.load('./resnet_weight/resnet56/model.th')
        state_dict = state_dict['state_dict']
        model = resnet.resnet56()
        model.load_state_dict(state_dict)
    else:
        model = resnet.resnet56()

    # Wrap and move to the GPU after the (optional) weight load, so the
    # DataParallel wrapper is not discarded by the reassignment above.
    model = torch.nn.DataParallel(model)
    model.cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.evaluate, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Normalization stats (note: these are the ImageNet mean/std, not CIFAR-10's)
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
        root='./data',
        train=True,
        transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
            transforms.ToTensor(),
            normalize,
        ]),
        download=True),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
        root='./data',
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=128,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    if args.half:
        model.half()
        criterion.half()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[100, 150], last_epoch=args.start_epoch - 1)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):

        # train for one epoch
        print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
        train(train_loader, model, criterion, optimizer, epoch)
        lr_scheduler.step()

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        if epoch > 0 and epoch % args.save_every == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(
                    ),  # use model.module.state_dict() to avoid errors in the future
                    'best_prec1': best_prec1,
                },
                is_best,
                filename=os.path.join(args.save_dir, 'checkpoint.th'))

        save_checkpoint(
            {
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            },
            is_best,
            filename=os.path.join(args.save_dir, 'model.th'))
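save_checkpoint is not shown in this excerpt; in the script this resembles it is a thin wrapper around torch.save. A sketch under that assumption:

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    # Persist the checkpoint dict; is_best could additionally trigger copying
    # the file to a separate best-model path (not shown here).
    torch.save(state, filename)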
Example #16
#test_dataset = load_seti_dataset("test_data.csv")
#valid_dataset = load_seti_dataset("valid_data.csv")

tf.random.set_seed(2727)
#train_dataset = train_dataset.map(augmentation).map(preprocess).shuffle(NUM_TRAIN_SAMPLES).batch(BS_PER_GPU * TOTAL_GPU, drop_remainder=True)
#test_dataset = test_dataset.map(preprocess).batch(BS_PER_GPU * TOTAL_GPU, drop_remainder=True)

train_generator, valid_generator, test_generator, train_num, valid_num, test_num = pd.get_datasets(
)

input_shape = (config.HEIGHT, config.WIDTH, config.NUM_CHANNELS)
image_input = tf.keras.layers.Input(shape=input_shape)
opt = keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)

if TOTAL_GPU == 1:
    model = resnet.resnet56(img_input=image_input, classes=config.NUM_CLASSES)
    model.compile(optimizer=opt,
                  loss="sparse_categorical_crossentropy",
                  metrics=["sparse_categorical_accuracy"])
else:
    mirrored_strategy = tf.distribute.MirroredStrategy()
    with mirrored_strategy.scope():
        model = resnet.resnet56(img_input=image_input,
                                classes=config.NUM_CLASSES)
        model.compile(
            optimizer=opt,
            loss="sparse_categorical_crossentropy",
            metrics=["sparse_categorical_accuracy"],
        )

log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")