def get_model_config(model, dataset):
    """Build the network configuration object for a named model.

    Datasets named 'cifar10' are handed off to get_cifar10_model_config;
    every other dataset is resolved against the lookup table below.

    Args:
      model: Model name such as 'vgg16', 'inception3' or 'resnet101_v2'.
      dataset: Dataset object; only its ``name`` attribute is consulted.

    Returns:
      A freshly constructed model instance for ``model``.

    Raises:
      KeyError: ``model`` matches none of the known names.
    """
    if dataset.name == 'cifar10':
        return get_cifar10_model_config(model)

    # Ordered (accepted names, zero-argument factory) pairs.  The factories
    # are lambdas so that only the selected model class is instantiated.
    registry = (
        (('vgg11',), lambda: vgg_model.Vgg11Model()),
        (('vgg16',), lambda: vgg_model.Vgg16Model()),
        (('vgg19',), lambda: vgg_model.Vgg19Model()),
        (('lenet',), lambda: lenet_model.Lenet5Model()),
        (('googlenet',), lambda: googlenet_model.GooglenetModel()),
        (('overfeat',), lambda: overfeat_model.OverfeatModel()),
        (('alexnet',), lambda: alexnet_model.AlexnetModel()),
        (('trivial',), lambda: trivial_model.TrivialModel()),
        (('inception3',), lambda: inception_model.Inceptionv3Model()),
        (('inception4',), lambda: inception_model.Inceptionv4Model()),
        # ResNet variants share one class; the tuple gives the number of
        # blocks per stage and the (possibly '_v2') name is passed through.
        (('resnet50', 'resnet50_v2'),
         lambda: resnet_model.ResnetModel(model, (3, 4, 6, 3))),
        (('resnet101', 'resnet101_v2'),
         lambda: resnet_model.ResnetModel(model, (3, 4, 23, 3))),
        (('resnet152', 'resnet152_v2'),
         lambda: resnet_model.ResnetModel(model, (3, 8, 36, 3))),
    )
    for aliases, factory in registry:
        if model in aliases:
            return factory()
    raise KeyError('Invalid model name \'%s\' for dataset \'%s\'' %
                   (model, dataset.name))
def _get_scaled_base_learning_rate(self, num_gpus, variable_update,
                                   batch_size, base_lr=None):
    """Simplifies testing different learning rate calculations.

    Builds a ResNet-50 model around a ``mock.Mock`` params object and
    returns the base learning rate it scales for ``batch_size``.

    Args:
      num_gpus: Number of GPUs to be used.
      variable_update: Type of variable update used.
      batch_size: Total batch size.
      base_lr: Base learning rate before scaling. When None, no
        ``resnet_base_lr`` is assigned on the params mock.

    Returns:
      Base learning rate that would be used to create lr schedule.
    """
    params = mock.Mock()
    params.num_gpus = num_gpus
    params.variable_update = variable_update
    # Compare against None (the default) so an explicit falsy value such as
    # 0.0 is still forwarded rather than silently ignored.
    if base_lr is not None:
        params.resnet_base_lr = base_lr
    # The second argument takes per-stage layer counts, as in
    # get_model_config; (3, 4, 6, 3) is the ResNet-50 layout used there.
    resnet50_model = resnet_model.ResnetModel(
        'resnet50', (3, 4, 6, 3), params=params)
    return resnet50_model.get_scaled_base_learning_rate(batch_size)
import os import time from tqdm import tqdm from options.options import Options from models import resnet_model from datasets import create_dataset from utils.visualizer import Visualizer if __name__ == '__main__': opt = Options().parse_args() # get training options dataset = create_dataset(opt) model = resnet_model.ResnetModel(opt) visualizer = Visualizer( opt) # create a visualizer that display/save images and plots total_iters = 0 for epoch in range(opt.num_epoch): epoch_start_time = time.time() # timer for entire epoch iter_data_time = time.time() # timer for data loading per iteration epoch_iter = 0 # the number of training iterations in current epoch, reset to 0 every epoch for i, data in enumerate(dataset): # inner loop within one epoch iter_start_time = time.time( ) # timer for computation per iteration