def __init__(self):
    super(ExpertNet, self).__init__()
    self.conv = nn.Conv2d(3, 16, kernel_size=(7, 7), stride=(2, 2),
                          padding=(3, 3), bias=False)
    self.bn1 = nn.BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True,
                              track_running_stats=True)
    self.relu = nn.ReLU(inplace=True)
    self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1,
                                dilation=1, ceil_mode=False)
    # reuse the first two stages of a CIFAR ResNet-56 as the backbone;
    # note that each resnet56() call builds a separate, freshly
    # initialized network, so the two stages come from different models
    self.block1 = resnet56().layer1
    self.block2 = resnet56().layer2
    #self.block1 = models.resnet50().layer1
    #self.block2 = models.resnet50().layer2
    # avg pooling to global pooling
    self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    self.fc = nn.Sequential(
        nn.Linear(in_features=32, out_features=512, bias=True),
        nn.BatchNorm1d(512),
        nn.ReLU(inplace=True),
        nn.Dropout(p=0.5),
        nn.Linear(512, 10),
    )
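# (sketch) A forward() consistent with the layers defined above; this is an
# illustration, not the authors' verified implementation. The 32 input
# features of self.fc match the channel width of resnet56().layer2.
def forward(self, x):
    x = self.relu(self.bn1(self.conv(x)))  # stem: 3 -> 16 channels, /2 resolution
    x = self.maxpool(x)                    # /2 again
    x = self.block1(x)                     # ResNet-56 stage 1 (16 channels)
    x = self.block2(x)                     # ResNet-56 stage 2 (32 channels)
    x = self.avgpool(x)                    # global average pool to 1x1
    x = torch.flatten(x, 1)                # (N, 32)
    return self.fc(x)                      # (N, 10) class logits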
def __init__(self, class_num, droprate=0.5, stride=2):
    super(ft_net56_fc128, self).__init__()
    # registering the backbone as "module" lets the DataParallel-style
    # checkpoint keys ("module.*") load directly into this wrapper
    self.add_module("module", resnet.resnet56())
    weights_ = torch.load("weights_cifar10/resnet56-4bfd9763.th")
    self.load_state_dict(weights_['state_dict'])
    self.module.linear = nn.Sequential()  # drop the original classifier head
    self.classifier = ClassBlock(64, class_num, droprate, num_bottleneck=128)
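# ClassBlock is used above but defined elsewhere. A minimal sketch in the
# spirit of common re-ID classification heads (bottleneck + classifier);
# the exact layers and their order in the original are assumptions.
import torch.nn as nn

class ClassBlock(nn.Module):
    def __init__(self, input_dim, class_num, droprate, num_bottleneck=128):
        super().__init__()
        self.add_block = nn.Sequential(
            nn.Linear(input_dim, num_bottleneck),  # reduce to bottleneck width
            nn.BatchNorm1d(num_bottleneck),
            nn.LeakyReLU(0.1),
            nn.Dropout(p=droprate),
        )
        self.classifier = nn.Linear(num_bottleneck, class_num)

    def forward(self, x):
        return self.classifier(self.add_block(x))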
def __init__(self, class_num, droprate=0.5, stride=2):
    super(ft_net56_spp, self).__init__()
    self.add_module("module", resnet.resnet56())
    weights_ = torch.load("weights_cifar10/resnet56-4bfd9763.th")
    self.load_state_dict(weights_['state_dict'])
    self.module.linear = nn.Sequential()  # drop the original classifier head
    ####
    self.spp = pyrpool.SpatialPyramidPooling((1, 2))
    # 64 backbone channels x (1*1 + 2*2) pyramid bins = 320 input features
    self.classifier = ClassBlock(320, class_num, droprate, num_bottleneck=128)
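# A forward() sketch for ft_net56_spp, assuming the conv1/bn1/layer1-3
# attribute layout that the use of self.module.linear above suggests; the
# SPP layer is expected to concatenate a 1x1 and a 2x2 pooled grid of the
# 64-channel feature map into the 320-dim vector ClassBlock receives.
def forward(self, x):
    x = F.relu(self.module.bn1(self.module.conv1(x)))
    x = self.module.layer1(x)
    x = self.module.layer2(x)
    x = self.module.layer3(x)   # (N, 64, 8, 8) for 32x32 CIFAR input
    x = self.spp(x)             # (N, 64 * (1 + 4)) = (N, 320)
    return self.classifier(x)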
    args.model, str(args.depth), args.dataset, 'BS%d' % args.batch_size
]
if args.origin:
    save_fold_name.insert(0, 'Origin')

if args.model == 'resnet':
    if args.depth == 20:
        network = resnet.resnet20()
    if args.depth == 32:
        network = resnet.resnet32()
    if args.depth == 44:
        network = resnet.resnet44()
    if args.depth == 56:
        network = resnet.resnet56()
    if args.depth == 110:
        network = resnet.resnet110()

if not args.origin:
    print('Pruning the model in %s' % args.pruned_model_dir)
    check_point = torch.load(args.pruned_model_dir + "model_best.pth.tar")
    network.load_state_dict(check_point['state_dict'])
    codebook_index_list = np.load(args.pruned_model_dir + "codebook.npy",
                                  allow_pickle=True).tolist()
    m_l = []
    b_l = []
    for i in network.modules():
        if isinstance(i, nn.Conv2d):
            m_l.append(i)
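# (sketch) How a codebook of pruning masks is often applied to the collected
# Conv2d layers: zeroing out weights flagged as pruned. The shape and meaning
# of each codebook entry here are assumptions, not taken from this script.
for conv, codebook in zip(m_l, codebook_index_list):
    mask = torch.from_numpy(np.asarray(codebook, dtype=np.float32))
    conv.weight.data.mul_(
        mask.view_as(conv.weight.data).to(conv.weight.device))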
def test():
    router = resnet56()
    #rweights = torch.load('./weights/router_resnet20_all_class.pth.tar')
    #rweights = torch.load('./weights/suSan.pth.tar')
    start_time = time.time()
    #rweights = torch.load('./weights/best_so_far_res56.pth.tar')
    rweights = torch.load('./weights/resnet56_fmnist.pth.tar')
    router.load_state_dict(rweights)
    if torch.cuda.is_available():
        router.cuda()
    router.eval()

    test_loss = 0
    correct = 0
    tt = 0
    c = 0
    delta = []
    for data, target in test_loader:
        # if c == 50:
        #     break
        # c = c + 1
        if c % 20 == 0:
            print("----- expert accuracy so far : {}/{}-----\n"
                  "----- router accuracy so far : {}/{}-----".format(
                      correct, c, tt, c))
            print("The DELTA/improvement between router and expert: {}\n".format(
                abs(correct - tt)))
            if c > 0:
                print("Forecasting {:.2f}% (approx) accuracy at the end\n".format(
                    ((10000.00 / c) * abs(tt - correct)) / 100.0 + 93.88))
        c = c + 1
        delta.append(abs(correct - tt))

        if args.cuda:
            data, target = data.cuda(), target.cuda()
        # volatile=True is deprecated; on a modern PyTorch this loop belongs
        # inside `with torch.no_grad():` instead
        data, target = Variable(data, volatile=True), Variable(target)

        output = router(data)
        output = F.softmax(output, dim=1)
        #print (output)
        #################### Remove this while running
        rsoftmax = torch.sort(output, dim=1, descending=True)[0][0:, 0:args.topn]
        pred = output.data.max(1, keepdim=True)[1]
        pred2 = torch.argsort(output, dim=1, descending=True)[0][1]
        tt += pred.eq(target.data.view_as(pred)).cpu().sum()
        #tt += pred2.eq(target.data.view_as(pred)).cpu().sum()
        sortedsoftmax = torch.argsort(output, dim=1, descending=True)[0:1, 0:args.topn]
        sortedsoftmax = np.array(sortedsoftmax.cpu())

        ## reset the subset flags and collect the router's top-n predictions
        predictions = []
        for i in SUBSET:
            subset_flag[str(i)] = True
        for i, pred in enumerate(sortedsoftmax):
            for j in range(args.topn):
                predictions.append(pred[j])
        #print ("The top {} predictions of router: {}".format(args.topn, predictions))
        rsm = []
        for i, pred in enumerate(rsoftmax):
            for j in range(args.topn):
                rsm.append(pred[j])

        #sm = {}
        # fuse router and expert scores; the router contributes with weight 0.7
        fout = torch.zeros([1, 10], device='cuda') + (output * 0.7)
        for i, pred in enumerate(predictions):
            #sm[pred] = 0
            tot = 0.0
            expert = resnet20()
            for sub in SUBSET:
                if pred in sub and subset_flag[str(sub)] == True:
                    ###### Load the saved weights for the experts #####
                    wt = "./weights/rr/random_injection_erasing/res20_fmnist/rr_subset_" + str(sub) + ".pth.tar"
                    #wt = "./weights/latent_space_hardtraining/lp_subset_" + str(sub) + ".pth.tar"
                    wts = torch.load(wt)
                    expert.cuda()
                    expert.eval()
                    expert.load_state_dict(wts)
                    ############################
                    ### Inference part starts here ##########
                    output = F.softmax(expert(data), dim=1)
                    #print (output)
                    #output = torch.sort(output, dim=1, descending=True)[0][0][0]
                    fout += output
                    #print (pred, target, output)
                    #sm[pred] += output.item()  #* trust_factor(len(sub), 2)
                    tot += 1
                    subset_flag[str(sub)] = False
        #fout = fout/tot
        #print ("Fout:", fout)
        prd = fout.data.max(1, keepdim=True)[1]
        if prd == target.item():
            correct = correct + 1
        # correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    test_loss /= len(test_loader.dataset)
    print("\nThe router's performance: {:.4f}".format(
        100.0 * (tt.data.item() / len(test_loader.dataset))))
    print('EMNN (ours) accuracy: {:.4f}%\n'.format(
        100. * correct / len(test_loader.dataset)))
    print("Total time taken {:.2f}.".format(time.time() - start_time))

    delta = np.array(delta)
    fl = "./inference_result/fmnist_delta_resnet56_[4_3].txt"
    with open(fl, 'w') as f:
        for item in delta:
            f.write("%s\n" % item)
(x, y), (x_test, y_test) = keras.datasets.cifar10.load_data()
train_dataset = tf.data.Dataset.from_tensor_slices((x, y))
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))

tf.random.set_seed(22)
train_dataset = train_dataset.shuffle(NUM_TRAIN_SAMPLES).map(augmentation).map(
    normalize).batch(BS_PER_GPU * NUM_GPUS, drop_remainder=True)
test_dataset = test_dataset.map(normalize).batch(BS_PER_GPU * NUM_GPUS,
                                                 drop_remainder=True)

input_shape = (HEIGHT, WIDTH, NUM_CHANNELS)
img_input = tf.keras.layers.Input(shape=input_shape)
model = resnet.resnet56(img_input=img_input, classes=NUM_CLASSES)

# define optimizer (`lr` is deprecated; `learning_rate` is the current keyword)
sgd = tf.keras.optimizers.SGD(learning_rate=0.1)
model.compile(optimizer=sgd,
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

earlystop_callback = EarlyStopping(monitor='val_accuracy',
                                   min_delta=0.0001,
                                   patience=1,
                                   verbose=1,
                                   mode='auto')

model.fit(train_dataset,
          epochs=NUM_EPOCHS,
def test():
    router = resnet56()
    #rweights = torch.load('./weights/router_resnet20_all_class.pth.tar')
    #rweights = torch.load('./weights/suSan.pth.tar')
    #rweights = torch.load('teacher_MLP_test_eresnet56_best_archi.pth.tar')
    rweights = torch.load('./weights/resnet56_fmnist.pth.tar')
    router.load_state_dict(rweights)
    if torch.cuda.is_available():
        router.cuda()
    router.eval()

    test_loss = 0
    correct = 0
    tt = 0
    c = 0
    for data, target in test_loader:
        # debug cap; note that c is never incremented below, so this break
        # never fires
        if c == 50:
            break
        # c = c + 1
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        # volatile=True is deprecated; use `with torch.no_grad():` on a
        # modern PyTorch
        data, target = Variable(data, volatile=True), Variable(target)

        output = router(data)
        output = F.softmax(output, dim=1)
        rsoftmax = torch.sort(output, dim=1, descending=True)[0][0:, 0:args.topn]
        pred = output.data.max(1, keepdim=True)[1]
        tt += pred.eq(target.data.view_as(pred)).cpu().sum()
        sortedsoftmax = torch.argsort(output, dim=1, descending=True)[0:1, 0:args.topn]
        sortedsoftmax = np.array(sortedsoftmax.cpu())

        ## reset the subset flags and collect the router's top-n predictions
        predictions = []
        for i in SUBSET:
            subset_flag[str(i)] = True
        for i, pred in enumerate(sortedsoftmax):
            for j in range(args.topn):
                predictions.append(pred[j])
        #print ("The top {} predictions of router: {}".format(args.topn, predictions))
        rsm = []
        for i, pred in enumerate(rsoftmax):
            for j in range(args.topn):
                rsm.append(pred[j])

        sm = {}
        #fout = torch.zeros([1,10])
        for i, pred in enumerate(predictions):
            sm[pred] = 0
            tot = 0.0
            expert = resnet20()
            for sub in SUBSET:
                if pred in sub and subset_flag[str(sub)] == True:
                    ###### Load the saved weights for the experts #####
                    wt = "./weights/rr/random_injection_erasing/res20_fmnist/rr_subset_" + str(sub) + ".pth.tar"
                    #wt = "./weights/latent_space_hardtraining/lp_subset_" + str(sub) + ".pth.tar"
                    wts = torch.load(wt)
                    expert.cuda()
                    expert.eval()
                    expert.load_state_dict(wts)
                    ############################
                    ### Inference part starts here ##########
                    output = F.softmax(expert(data), dim=1)
                    #print (output)
                    output = torch.sort(output, dim=1, descending=True)[0][0][0]
                    #print (pred, target, output)
                    sm[pred] += output.item()  #* trust_factor(len(sub), 2)
                    tot += 1
                    subset_flag[str(sub)] = False
            #sm[pred] += (rsm[i].item())
            #if (tot != 0):
            sm[pred] += (rsm[i].item() * 0.9)
            #print (rsm[i].item())
        # for pred in predictions:
        #     sm[pred] /= (tot)

        ans = -0.99
        prd = 0
        # for p in predictions:
        #     print ("soft max for {} is {}".format(p, sm[p]))
        # print ("the target value:", target)
        for p in predictions:
            if sm[p] >= ans:
                ans = sm[p]
                prd = p
        if prd == target.item():
            correct = correct + 1
        # if (predictions[0] != prd):
        #     print ("The list of prediction: {} and the Target: {}".format(predictions, target))
        #     print ("The softmax score of expert prediction {} for {}".format(sm[2], prd))
        #     print ("The softmax score for the actually correct answer {}.".format(sm[target.item()]))
        #     print ("the softmax score by the router for correct answer {}.".format(rsm[2]))
        # correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    test_loss /= len(test_loader.dataset)
    print("Router's performance:", tt)
    print('\nTest set: Average loss: {:.4f}, TOP 1 Accuracy: {}/{} ({:.4f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
(x, y), (x_test, y_test) = keras.datasets.cifar10.load_data()
train_loader = tf.data.Dataset.from_tensor_slices((x, y))
test_loader = tf.data.Dataset.from_tensor_slices((x_test, y_test))

tf.random.set_seed(22)
train_loader = train_loader.map(augmentation).map(preprocess).shuffle(
    NUM_TRAIN_SAMPLES).batch(BS_PER_GPU * NUM_GPUS, drop_remainder=True)
test_loader = test_loader.map(preprocess).batch(BS_PER_GPU * NUM_GPUS,
                                                drop_remainder=True)

opt = keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
if NUM_GPUS == 1:
    model = resnet.resnet56(classes=NUM_CLASSES)
    model.compile(optimizer=opt,
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
else:
    mirrored_strategy = tf.distribute.MirroredStrategy()
    with mirrored_strategy.scope():
        model = resnet.resnet56(classes=NUM_CLASSES)
        model.compile(optimizer=opt,
                      loss='sparse_categorical_crossentropy',
                      metrics=['accuracy'])

log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
file_writer = tf.summary.create_file_writer(log_dir + "/metrics")
file_writer.set_as_default()
tensorboard_callback = TensorBoard(log_dir=log_dir,
    correct += (predicted == labels).sum()
    accuracy = correct.double() * 1.0 / total
    print("Total: %d, Correct: %d, Accuracy: %f" %
          (total, correct.double(), accuracy))


# for dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
testset = torchvision.datasets.CIFAR10(root=args.data_dir,
                                       train=False,
                                       download=True,
                                       transform=transform)
testloader = t.utils.data.DataLoader(testset,
                                     batch_size=args.batch_size,
                                     shuffle=False,
                                     num_workers=args.num_workers)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
           'ship', 'truck')

net = resnet56()
if t.cuda.is_available():
    net = net.cuda()
# note: no checkpoint is loaded here, so evaluate() runs on randomly
# initialized weights unless a state_dict is restored first
evaluate(net, testloader)
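# A minimal sketch of restoring trained weights before calling evaluate();
# the checkpoint path and dict key below are assumptions, not part of the
# original script.
ckpt = t.load('checkpoints/resnet56_cifar10.pth', map_location='cpu')
net.load_state_dict(ckpt['state_dict'] if 'state_dict' in ckpt else ckpt)
evaluate(net, testloader)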
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batchSz', type=int, default=64)
    parser.add_argument('--nEpochs', type=int, default=300)
    parser.add_argument('--no-cuda', action='store_true')
    parser.add_argument('--net')
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--opt', type=str, default='sgd',
                        choices=('sgd', 'adam', 'rmsprop'))
    parser.add_argument('--gpu_id', type=str, default='0')
    args = parser.parse_args()

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    args.save = 'work/' + args.net
    setproctitle.setproctitle(args.save)

    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    if os.path.exists(args.save):
        shutil.rmtree(args.save)
    os.makedirs(args.save)

    normMean = [0.49139968, 0.48215827, 0.44653124]
    normStd = [0.24703233, 0.24348505, 0.26158768]
    normTransform = transforms.Normalize(normMean, normStd)

    trainTransform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normTransform
    ])
    testTransform = transforms.Compose([transforms.ToTensor(), normTransform])

    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    trainLoader = DataLoader(dset.CIFAR10(root='cifar', train=True,
                                          download=True,
                                          transform=trainTransform),
                             batch_size=args.batchSz, shuffle=True, **kwargs)
    testLoader = DataLoader(dset.CIFAR10(root='cifar', train=False,
                                         download=True,
                                         transform=testTransform),
                            batch_size=args.batchSz, shuffle=False, **kwargs)

    n_classes = 10
    if args.net == 'resnet20':
        net = resnet.resnet20(num_classes=n_classes)
    elif args.net == 'resnet32':
        net = resnet.resnet32(num_classes=n_classes)
    elif args.net == 'resnet44':
        net = resnet.resnet44(num_classes=n_classes)
    elif args.net == 'resnet56':
        net = resnet.resnet56(num_classes=n_classes)
    elif args.net == 'resnet110':
        net = resnet.resnet110(num_classes=n_classes)
    elif args.net == 'resnetxt29':
        net = resnetxt.resnetxt29(num_classes=n_classes)
    elif args.net == 'deform_resnet32':
        net = deformconvnet.deform_resnet32(num_classes=n_classes)
    else:
        net = densenet.DenseNet(growthRate=12, depth=100, reduction=0.5,
                                bottleneck=True, nClasses=n_classes)

    print(' + Number of params: {}'.format(
        sum([p.data.nelement() for p in net.parameters()])))
    if args.cuda:
        net = net.cuda()
        gpu_id = args.gpu_id
        gpu_list = gpu_id.split(',')
        gpus = [int(i) for i in gpu_list]
        net = nn.DataParallel(net, device_ids=gpus)

    if args.opt == 'sgd':
        optimizer = optim.SGD(net.parameters(), lr=1e-1, momentum=0.9,
                              weight_decay=1e-4)
    elif args.opt == 'adam':
        optimizer = optim.Adam(net.parameters(), weight_decay=1e-4)
    elif args.opt == 'rmsprop':
        optimizer = optim.RMSprop(net.parameters(), weight_decay=1e-4)

    trainF = open(os.path.join(args.save, 'train.csv'), 'w')
    testF = open(os.path.join(args.save, 'test.csv'), 'w')

    for epoch in range(1, args.nEpochs + 1):
        adjust_opt(args.opt, optimizer, epoch)
        train(args, epoch, net, trainLoader, optimizer, trainF)
        test(args, epoch, net, testLoader, optimizer, testF)
        torch.save(net, os.path.join(args.save, 'latest.pth'))
        os.system('python plot.py {} &'.format(args.save))

    trainF.close()
    testF.close()
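# adjust_opt() is called in the epoch loop above but not defined in this
# fragment. A sketch of the usual step schedule for such scripts follows;
# the breakpoints and learning rates are assumptions, not the original values.
def adjust_opt(optAlg, optimizer, epoch):
    if optAlg == 'sgd':
        if epoch == 150:
            lr = 1e-2
        elif epoch == 225:
            lr = 1e-3
        else:
            return
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr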
    loader = CIFAR10Loader(batch_size, p.getSpeeds(), p.getBatches())
    model = SimpleCIFAR10Model()
    num_epochs = 10
elif dset == "RS_SimpleModel_CIFAR10":
    loader = CIFAR10ResnetLoader(batch_size, p.getSpeeds(), p.getBatches())
    import resnet
    if int(sys.argv[7]) == 20:
        model = resnet.resnet20()
    if int(sys.argv[7]) == 32:
        model = resnet.resnet32()
    if int(sys.argv[7]) == 44:
        model = resnet.resnet44()
    if int(sys.argv[7]) == 56:
        model = resnet.resnet56()
    if int(sys.argv[7]) == 110:
        model = resnet.resnet110()
    num_epochs = 7
elif dset == "MNIST":
    loader = MNISTLoader(batch_size, p.getSpeeds(), p.getBatches())
    model = SimpleMNISTModel()
    num_epochs = 10
else:
    print("DATASET NOT FOUND")

p.setData(loader)
p.setModel(model)
import os
import sys

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(BASE_DIR, '..'))

from torchstat import stat
from resnet import resnet56
from vgg import vgg16
from ghostrestnet import gresnet56
from ghostvgg import gvgg16

if __name__ == '__main__':
    img_shape = (3, 32, 32)

    resnet56 = resnet56()
    stat(resnet56, img_shape)  # https://github.com/Swall0w/torchstat
    print("↑↑↑↑ is resnet56")
    print("\n" * 10)

    # disabled comparisons
    '''
    ghost_resnet56 = gresnet56()
    stat(ghost_resnet56, img_shape)
    print("↑↑↑↑ is ghost_resnet56")

    vgg = 0
    if vgg:
        vgg16 = vgg16()
        stat(vgg16, img_shape)
        print("↑↑↑↑ is vgg16")
        print("\n" * 10)
    '''
def kd():
    global args
    args = parser.parse_args()

    # Make dataset and loader
    normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                 std=[0.229, 0.224, 0.225])
    train_loader = torch.utils.data.DataLoader(
        torchvision.datasets.CIFAR10(
            args.data_path,
            train=True,
            transform=torchvision.transforms.Compose([
                torchvision.transforms.RandomHorizontalFlip(),
                torchvision.transforms.RandomCrop(32, 4),
                torchvision.transforms.ToTensor(), normalize
            ])),
        batch_size=args.batch_size,
        shuffle=True)
    val_loader = torch.utils.data.DataLoader(
        torchvision.datasets.CIFAR10(
            args.data_path,
            train=False,
            transform=torchvision.transforms.Compose(
                [torchvision.transforms.ToTensor(), normalize])),
        batch_size=args.test_batch_size)

    # Load teacher (pretrained)
    teacher = resnet56()
    modify_properly(teacher, args.pretrained)
    teacher.cuda()

    # Make student
    student = resnet20()
    student.cuda()

    t_embedding = TEmbedding().cuda()
    s_embedding = SEmbedding().cuda()

    criterion = {
        'ce_loss': nn.CrossEntropyLoss().cuda(),
        'ct_loss': ContrastiveLoss()
    }
    params = (list(student.parameters()) + list(t_embedding.parameters()) +
              list(s_embedding.parameters()))
    optimizer = torch.optim.Adam(params, lr=args.lr)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        milestones=[75, 150])

    # Evaluate teacher
    #if args.evaluate:
    #    validate(teacher, val_loader, criterion)

    min_val_prec = 0.0  # despite the name, this tracks the best (max) val accuracy
    logger = {
        'train/loss': [],
        'train/accuracy': [],
        'val/loss': [],
        'val/accuracy': []
    }
    for epoch in range(args.num_epochs):
        # training
        tr_logger = train(train_loader, student, teacher, s_embedding,
                          t_embedding, criterion, optimizer, epoch)
        # validating
        val_logger = validate(val_loader, student, criterion)

        logger['train/loss'].append(tr_logger['loss'].mean)
        logger['train/accuracy'].append(tr_logger['prec'].mean)
        logger['val/loss'].append(val_logger['loss'].mean)
        logger['val/accuracy'].append(val_logger['prec'].mean)

        lr_scheduler.step()

        if min_val_prec < val_logger['prec'].mean:
            min_val_prec = val_logger['prec'].mean
            torch.save(student.state_dict(),
                       'ckpt/cwfd-resnet20-epochs' + str(epoch) + '.pt')
            # TODO: add path variable

    print("maximum of avg. val accuracy: {}".format(min_val_prec))
    save_log(logger, 'logs/cwfd-resnet20.log')
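# ContrastiveLoss is referenced above but not shown. Below is a minimal
# sketch of one common formulation (a pairwise margin loss over embedding
# distances, as in Hadsell et al.); the actual loss paired with TEmbedding/
# SEmbedding may differ, so treat every detail here as an assumption.
import torch.nn as nn
import torch.nn.functional as F

class ContrastiveLoss(nn.Module):
    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, z_s, z_t, same):
        # z_s, z_t: (N, D) student/teacher embeddings;
        # same: (N,) float tensor, 1 to pull a pair together, 0 to push apart
        d = F.pairwise_distance(z_s, z_t)
        pull = same * d.pow(2)
        push = (1 - same) * F.relu(self.margin - d).pow(2)
        return (pull + push).mean()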
## Import the model ###
model = resnet20()
if torch.cuda.is_available():
    model = model.cuda()
ck = torch.load(wt)
model.load_state_dict(ck)

optimizer = optim.SGD(model.parameters(),
                      lr=args.lr,
                      momentum=args.momentum,
                      weight_decay=5e-4,
                      nesterov=True)
scheduler = StepLR(optimizer, step_size=40, gamma=0.1)

teacher = resnet56()
if torch.cuda.is_available():
    teacher = teacher.cuda()
ck = torch.load('./weights/suSan.pth.tar')
teacher.load_state_dict(ck)
print("Weight load success")


def distillation(y, labels, teacher_scores, T, alpha):
    # standard Hinton-style KD: KL divergence between temperature-softened
    # student and teacher distributions, blended with the hard-label loss
    return nn.KLDivLoss()(F.log_softmax(y / T, dim=1),
                          F.softmax(teacher_scores / T, dim=1)) * (
                              T * T * 2.0 * alpha) + F.cross_entropy(
                                  y, labels) * (1. - alpha)


def train(epoch, model, teacher, loss_fn):
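# The body of train() is cut off above. As a standalone illustration, here is
# a minimal sketch of a single distillation step using the loss just defined;
# the temperature T=4.0 and alpha=0.9 are illustrative assumptions, not the
# script's actual hyperparameters.
def train_step(data, target, model, teacher, optimizer):
    model.train()
    teacher.eval()
    optimizer.zero_grad()
    output = model(data)
    with torch.no_grad():  # the teacher only supplies soft targets
        teacher_scores = teacher(data)
    loss = distillation(output, target, teacher_scores, T=4.0, alpha=0.9)
    loss.backward()
    optimizer.step()
    return loss.item()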
def main():
    global args, best_prec1
    args = parser.parse_args()

    # Check the save_dir exists or not
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    model = torch.nn.DataParallel(resnet.resnet56())
    model.cuda()

    if args.adv_train:
        state_dict = torch.load('./resnet_weight/resnet56/model.th')
        state_dict = state_dict['state_dict']
        # note: this replaces the DataParallel-wrapped model above with a
        # bare CPU model; re-wrapping and .cuda() may be needed before training
        model = resnet.resnet56()
        model.load_state_dict(state_dict)
    else:
        model = resnet.resnet56()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # normalization constants (these are the standard ImageNet mean/std)
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(root='./data', train=True,
                         transform=transforms.Compose([
                             transforms.RandomHorizontalFlip(),
                             transforms.RandomCrop(32, 4),
                             transforms.ToTensor(),
                             normalize,
                         ]),
                         download=True),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(root='./data', train=False,
                         transform=transforms.Compose([
                             transforms.ToTensor(),
                             normalize,
                         ])),
        batch_size=128, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    if args.half:
        model.half()
        criterion.half()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[100, 150], last_epoch=args.start_epoch - 1)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
        train(train_loader, model, criterion, optimizer, epoch)
        lr_scheduler.step()

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        if epoch > 0 and epoch % args.save_every == 0:
            save_checkpoint({
                'epoch': epoch + 1,
                # use model.module.state_dict() to avoid errors when
                # unwrapping DataParallel checkpoints in the future
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            }, is_best, filename=os.path.join(args.save_dir, 'checkpoint.th'))

        save_checkpoint({
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
        }, is_best, filename=os.path.join(args.save_dir, 'model.th'))
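# save_checkpoint() is defined elsewhere in this script; a minimal sketch
# consistent with how it is called above. Copying the best model to a
# separate file is an assumption borrowed from common PyTorch recipes.
import os
import shutil

import torch

def save_checkpoint(state, is_best, filename='checkpoint.th'):
    torch.save(state, filename)  # always persist the latest state
    if is_best:                  # keep a copy of the best model so far
        shutil.copyfile(filename,
                        os.path.join(os.path.dirname(filename) or '.',
                                     'best.th'))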
#test_dataset = load_seti_dataset("test_data.csv")
#valid_dataset = load_seti_dataset("valid_data.csv")
tf.random.set_seed(2727)
#train_dataset = train_dataset.map(augmentation).map(preprocess).shuffle(NUM_TRAIN_SAMPLES).batch(BS_PER_GPU * TOTAL_GPU, drop_remainder=True)
#test_dataset = test_dataset.map(preprocess).batch(BS_PER_GPU * TOTAL_GPU, drop_remainder=True)

train_generator, valid_generator, test_generator, train_num, valid_num, test_num = pd.get_datasets()

input_shape = (config.HEIGHT, config.WIDTH, config.NUM_CHANNELS)
image_input = tf.keras.layers.Input(shape=input_shape)
opt = keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)

if TOTAL_GPU == 1:
    model = resnet.resnet56(img_input=image_input, classes=config.NUM_CLASSES)
    model.compile(optimizer=opt,
                  loss="sparse_categorical_crossentropy",
                  metrics=["sparse_categorical_accuracy"])
else:
    mirrored_strategy = tf.distribute.MirroredStrategy()
    with mirrored_strategy.scope():
        model = resnet.resnet56(img_input=image_input,
                                classes=config.NUM_CLASSES)
        model.compile(
            optimizer=opt,
            loss="sparse_categorical_crossentropy",
            metrics=["sparse_categorical_accuracy"],
        )

log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")