예제 #1
0
def get_model_config(model, dataset):
  """Build the network configuration object for a named model.

  Requests for the 'cifar10' dataset are handed off to
  ``get_cifar10_model_config``; every other dataset is served from the
  alias -> factory registry below.

  Args:
    model: model name such as 'vgg16', 'inception3' or 'resnet50_v2'.
    dataset: dataset descriptor; only its ``name`` attribute is read.

  Returns:
    A newly constructed model configuration instance.

  Raises:
    KeyError: if ``model`` matches none of the registered names.
  """
  if 'cifar10' == dataset.name:
    return get_cifar10_model_config(model)

  # Each entry pairs the accepted names with a zero-argument factory. The
  # lambdas defer the module attribute lookups, so only the module of the
  # selected model is ever touched (same as the original if/elif chain).
  registry = (
      (('vgg11',), lambda: vgg_model.Vgg11Model()),
      (('vgg16',), lambda: vgg_model.Vgg16Model()),
      (('vgg19',), lambda: vgg_model.Vgg19Model()),
      (('lenet',), lambda: lenet_model.Lenet5Model()),
      (('googlenet',), lambda: googlenet_model.GooglenetModel()),
      (('overfeat',), lambda: overfeat_model.OverfeatModel()),
      (('alexnet',), lambda: alexnet_model.AlexnetModel()),
      (('trivial',), lambda: trivial_model.TrivialModel()),
      (('inception3',), lambda: inception_model.Inceptionv3Model()),
      (('inception4',), lambda: inception_model.Inceptionv4Model()),
      # ResNet variants share one class; the tuple gives per-stage depths.
      (('resnet50', 'resnet50_v2'),
       lambda: resnet_model.ResnetModel(model, (3, 4, 6, 3))),
      (('resnet101', 'resnet101_v2'),
       lambda: resnet_model.ResnetModel(model, (3, 4, 23, 3))),
      (('resnet152', 'resnet152_v2'),
       lambda: resnet_model.ResnetModel(model, (3, 8, 36, 3))),
  )
  for aliases, factory in registry:
    # Compare with ``model == alias`` so matching mirrors the original chain.
    if any(model == alias for alias in aliases):
      return factory()

  raise KeyError('Invalid model name \'%s\' for dataset \'%s\'' %
                 (model, dataset.name))
예제 #2
0
    def _get_scaled_base_learning_rate(self,
                                       num_gpus,
                                       variable_update,
                                       batch_size,
                                       base_lr=None):
        """Return the scaled base learning rate of a 'resnet50' ResnetModel.

        Simplifies testing different learning rate calculations: a
        ``mock.Mock`` stands in for the params object. It carries the GPU
        count, the variable-update mode and, when ``base_lr`` is truthy,
        ``resnet_base_lr``.

        Args:
          num_gpus: Number of GPUs to be used.
          variable_update: Type of variable update used.
          batch_size: Total batch size.
          base_lr: Base learning rate before scaling. It is only attached
            to the fake params when truthy (so ``None`` and ``0`` are
            both left unset).

        Returns:
          Whatever ``get_scaled_base_learning_rate(batch_size)`` returns on
          the constructed model, i.e. the base learning rate that would be
          used to create the lr schedule.
        """
        # Mock's keyword arguments become attributes on the mock, which is
        # equivalent to assigning them one by one after construction.
        fake_params = mock.Mock(num_gpus=num_gpus,
                                variable_update=variable_update)
        if base_lr:
            fake_params.resnet_base_lr = base_lr

        model_under_test = resnet_model.ResnetModel(
            'resnet50', 50, params=fake_params)
        return model_under_test.get_scaled_base_learning_rate(batch_size)
예제 #3
0
파일: train.py 프로젝트: xinwen-cs/AudioDVP
import os
import time
from tqdm import tqdm

from options.options import Options
from models import resnet_model
from datasets import create_dataset
from utils.visualizer import Visualizer

if __name__ == '__main__':
    opt = Options().parse_args()  # get training options

    dataset = create_dataset(opt)

    model = resnet_model.ResnetModel(opt)

    visualizer = Visualizer(
        opt)  # create a visualizer that display/save images and plots

    total_iters = 0

    for epoch in range(opt.num_epoch):

        epoch_start_time = time.time()  # timer for entire epoch
        iter_data_time = time.time()  # timer for data loading per iteration
        epoch_iter = 0  # the number of training iterations in current epoch, reset to 0 every epoch

        for i, data in enumerate(dataset):  # inner loop within one epoch

            iter_start_time = time.time(
            )  # timer for computation per iteration