Пример #1
0
def main():
    """Main function to train a model.

    Parses the command-line arguments, loads the YAML file given by
    ``args.config``, checks it with ``validate_config`` and hands it to
    ``save_config`` together with ``args.output_dir``. It then trains a VGG16
    network on the dataset named in ``config['dataset']['name']`` through
    ``KerasTrainer``, using the config's ``compile_args``/``fit_args`` and
    ``args.gpu_ids``.
    """
    args = parse_args()
    fpath_config = args.config
    output_dir = args.output_dir

    # safe_load only constructs plain YAML types (dicts, lists, scalars).
    # yaml.load without an explicit Loader can instantiate arbitrary Python
    # objects and raises TypeError on PyYAML >= 6.
    with open(fpath_config, 'r') as f:
        config = yaml.safe_load(f)
    validate_config(config)
    save_config(config, output_dir)

    dataset_name = config['dataset']['name']
    dataset_iterator = DataSetIterator(name=dataset_name)

    # we know that the validate_config function ensures that there is at least
    # a loss specified in compile_args, and the KerasTrainer class specifies
    # defaults for fit_args if None
    compile_args = config['compile_args']
    fit_args = config.get('fit_args')

    # NOTE: This is the only network supported right now, but in the future
    # it might be configurable via the config
    network = VGG16()
    trainer = KerasTrainer(output_dir=output_dir)
    trainer.train(
        network, dataset_iterator,
        compile_args, fit_args,
        args.gpu_ids
    )
Пример #2
0
    def test_build(self):
        """Test that ``VGG16.build`` accepts several input shapes and class counts.

        The configurations mirror cifar10, cifar100 and mnist. There are no
        explicit assertions: the test only checks that ``build`` completes
        without raising for each configuration.
        """
        datasets = {
            'cifar10': {
                'input_shape': (32, 32, 3),
                'num_classes': 10
            },
            'cifar100': {
                'input_shape': (32, 32, 3),
                'num_classes': 100
            },
            'mnist': {
                'input_shape': (28, 28, 1),
                'num_classes': 10
            }
        }

        # Only the specs are used; a fresh VGG16 instance is built for each
        # configuration.
        for dataset_specs in datasets.values():
            vgg16 = VGG16()
            vgg16.build(input_shape=dataset_specs['input_shape'],
                        num_classes=dataset_specs['num_classes'])
    def __init__(self):
        # Graph inputs: an image batch of shape (None, 128, 128, 3), integer
        # labels, a boolean training switch and a scalar learning rate, all
        # supplied through feeds.
        frame_shape = (None, 128, 128, 3)
        self.tensor_input = tf.placeholder(tf.float32, shape=frame_shape,
                                           name='input')
        self.tensor_lb = tf.placeholder(tf.int32, shape=(None,),
                                        name='lb_action')
        self.tensor_is_training = tf.placeholder(tf.bool, name='is_training')
        self.learning_rate_placeholder = tf.placeholder(
            tf.float32, [], name='learning_rate')

        # One session, stored on the instance as persistent_sess.
        self.persistent_sess = tf.Session(config=tf.ConfigProto())

        # Build the network on the placeholders, then load weights into the
        # session. NOTE(review): tf.global_variables_initializer() is not run
        # here, so read_original_weights presumably assigns every variable the
        # graph needs -- confirm.
        net = self.model = VGG16(self.learning_rate_placeholder)
        net.create_network(self.tensor_input, self.tensor_lb,
                           self.tensor_is_training)
        net.read_original_weights(self.persistent_sess)

        # Counters and buffers, all starting empty.
        self.iteration = 0
        self.imgwh = None
        self.loss = []
        self.train = []
        self.test = []
        self.train_lb = []
        self.test_lb = []

        self.stopwatch = StopWatchManager()
Пример #4
0
def create_model(network, input_shape, img_channels, img_rows, img_cols, nb_classes):
    """Build the network selected by ``network`` and return it with a label.

    Args:
        network: architecture key; one of "LeNet5", "VGG16", "VGG19",
            "resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
            "FC" or "BN_LeNet5".
        input_shape: shape passed to the ``GetNetArchitecture`` builders.
        img_channels, img_rows, img_cols: image dimensions, passed to the
            ResNet builders as a ``(channels, rows, cols)`` tuple.
        nb_classes: number of output classes (used by the ResNet builders).

    Returns:
        Tuple ``(model, model_name)``.

    Raises:
        ValueError: if ``network`` is not one of the supported keys.
    """
    print("Acquring Network Model: ")
    resnet_shape = (img_channels, img_rows, img_cols)

    # NOTE(review): each ResNet depth calls its own builder, following the
    # build_resnetNN naming this module uses for resnet18. keras-resnet's
    # ResnetBuilder spells them build_resnet_18 ... build_resnet_152 --
    # confirm which naming the imported resnet module provides.
    if network == "LeNet5":
        model = LeNet5.GetNetArchitecture(input_shape)
        model_name = "LeNet5"
    elif network == "VGG16":
        model = VGG16.GetNetArchitecture(input_shape)
        model_name = "VGG16"
    elif network == "VGG19":
        model = VGG19.GetNetArchitecture(input_shape)
        model_name = "VGG19"
    elif network == "resnet18":
        model = resnet.ResnetBuilder.build_resnet18(resnet_shape, nb_classes)
        model_name = "Resnet18"
    elif network == "resnet34":
        model = resnet.ResnetBuilder.build_resnet34(resnet_shape, nb_classes)
        model_name = "Resnet34"
    elif network == "resnet50":
        model = resnet.ResnetBuilder.build_resnet50(resnet_shape, nb_classes)
        model_name = "Resnet50"
    elif network == "resnet101":
        model = resnet.ResnetBuilder.build_resnet101(resnet_shape, nb_classes)
        model_name = "Resnet101"
    elif network == "resnet152":
        model = resnet.ResnetBuilder.build_resnet152(resnet_shape, nb_classes)
        model_name = "Resnet152"
    elif network == "FC":
        model = FC.GetNetArchitecture(input_shape)
        model_name = "Fully Connected"
    elif network == 'BN_LeNet5':
        model = BN_LeNet5.GetNetArchitecture(input_shape)
        model_name = 'BN_LeNet5'
    else:
        # Fail with a clear message instead of an UnboundLocalError on return.
        raise ValueError("Unsupported network: {!r}".format(network))

    return model, model_name
Пример #5
0
    def initialize(self):
        """Set up networks, losses and optimizers, then call ``restore``.

        The generator ``self.G`` is wrapped in ``nn.DataParallel`` when
        ``self.use_gpu`` is set. Everything else is only created when
        ``config['phase']`` is ``'train'`` or ``'finetune'``:

        - a ``SummaryWriter``;
        - a VGG16 network plus style/perceptual losses (``use_vgg``);
        - a discriminator, an optional local discriminator, their Adam
          optimizer and a GAN criterion (``use_gan``);
        - a TV loss (``use_tv_loss``);
        - the pixel loss and the generator's Adam optimizer.

        Both optimizers read their learning rate from ``config['<phase>_lr']``.
        """
        cfg = self.config

        if self.use_gpu:
            device_ids = range(len(cfg['gpu']))

        def wrap_parallel(module):
            # Only called when self.use_gpu is set, so device_ids is bound.
            return nn.DataParallel(module, device_ids=device_ids).cuda(0)

        if self.use_gpu:
            self.G = wrap_parallel(self.G)

        # Outside training/finetuning only the checkpoint restore is needed.
        if cfg['phase'] not in ['train', 'finetune']:
            self.restore()
            return

        lr_key = cfg['phase'] + '_lr'
        self.writer = SummaryWriter(log_dir=cfg['model_dir'])

        if cfg['use_vgg']:
            self.VGG = VGG16(pretrain_model=cfg['vgg_model'])
            if self.use_gpu:
                self.VGG = wrap_parallel(self.VGG)
            self.VGG.eval()
            self.style_criterion = StyleLoss(p=1)
            self.perceptual_criterion = PerceptualLoss(p=1)
            if self.use_gpu:
                self.style_criterion = self.style_criterion.cuda()
                self.perceptual_criterion = self.perceptual_criterion.cuda()

        if cfg['use_gan']:
            if cfg['progressive_growing']:
                disc_cls = ProgressiveGrowingDiscriminator
            elif cfg['patchgan']:
                disc_cls = PatchDiscriminator
            else:
                disc_cls = Discriminator
            self.D = disc_cls(pix2pix_style=cfg['pix2pix_style'],
                              init_type=cfg['init_type'])
            if self.use_gpu:
                self.D = wrap_parallel(self.D)

            if cfg['use_local_d']:
                self.D_local = LocalDiscriminator(
                    pix2pix_style=cfg['pix2pix_style'],
                    init_type=cfg['init_type'],
                    patch_size=cfg['patch_size'])
                if self.use_gpu:
                    self.D_local = wrap_parallel(self.D_local)
                # A single optimizer updates both discriminators.
                d_params = chain(self.D.parameters(),
                                 self.D_local.parameters())
            else:
                d_params = self.D.parameters()
            self.trainer_d = Adam(d_params, lr=cfg[lr_key])

            if cfg['gan_loss'] == 'bce':
                self.gan_criterion = nn.BCELoss()
            else:
                self.gan_criterion = nn.MSELoss()
            if self.use_gpu:
                self.gan_criterion = self.gan_criterion.cuda()

        if cfg['use_tv_loss']:
            self.tv_criterion = TVLoss(p=1)
            if self.use_gpu:
                self.tv_criterion = self.tv_criterion.cuda()

        self.pixel_criterion = PixelLoss(p=1)
        if self.use_gpu:
            self.pixel_criterion = self.pixel_criterion.cuda()

        self.trainer_g = Adam(self.G.parameters(), lr=cfg[lr_key])

        self.restore()

## Constructing the training loader (reshuffled every epoch)
train_loader = DataLoader(
    MMTDataset_baseline(args.train_set, args.labels_dict, args.img_folder, args.input_dim),
    batch_size=args.batch_size, shuffle=True)

## Constructing the validation loader; evaluation data is read in a fixed
## order so that results and per-sample outputs are reproducible
val_loader = DataLoader(
    MMTDataset_baseline(args.val_set, args.labels_dict, args.img_folder, args.input_dim),
    batch_size=args.batch_size, shuffle=False)

## Constructing the test loader (fixed order, like validation)
test_loader = DataLoader(
    MMTDataset_baseline(args.test_set, args.labels_dict, args.img_folder, args.input_dim),
    batch_size=args.batch_size, shuffle=False)

## Initializing the model
model = VGG16(n_classes=18).cuda()
model.epochs = args.epochs
model.session_name = args.session_name
model.load_pretrained(args.pretrained_weights)
model.freeze_layers(args.frozen_stages)

## Output directory for this session; exist_ok avoids the check-then-create race
os.makedirs(model.session_name, exist_ok=True)

## Optimizer with per-group learning-rate multipliers; groups 1 and 3 use no
## weight decay. max_step = batches per epoch * number of epochs.
param_groups = model.get_parameter_groups()
optimizer = PolyOptimizer([
    {'params': param_groups[0], 'lr': 8 * args.lr, 'weight_decay': args.weight_decay},
    {'params': param_groups[1], 'lr': 16 * args.lr, 'weight_decay': 0},
    {'params': param_groups[2], 'lr': 10 * args.lr, 'weight_decay': args.weight_decay},
    {'params': param_groups[3], 'lr': 20 * args.lr, 'weight_decay': 0}
], lr=args.lr, weight_decay=args.weight_decay, max_step=len(train_loader) * args.epochs)
Пример #7
0
def create_model(network, input_shape, img_channels, img_rows, img_cols,
                 nb_classes):
    """Build the network selected by ``network`` and return it with a label.

    Args:
        network: architecture key; one of "LeNet5", "VGG16", "VGG19",
            "resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
            "BN_LeNet5", "FC", "Conv5", "LeNetTanh", "TR_LeNet5" or "CN1".
        input_shape: shape passed to the ``GetNetArchitecture`` builders.
        img_channels, img_rows, img_cols: image dimensions, passed to the
            ResNet builders as a ``(channels, rows, cols)`` tuple.
        nb_classes: number of output classes (used by the ResNet builders).

    Returns:
        Tuple ``(model, model_name)``.

    Raises:
        ValueError: if ``network`` is not one of the supported keys.
    """
    print("Acquring Network Model: ")
    resnet_shape = (img_channels, img_rows, img_cols)

    if network == "LeNet5":
        model = LeNet5.GetNetArchitecture(input_shape)
        model_name = "LeNet5"

    elif network == "VGG16":
        model = VGG16.GetNetArchitecture(input_shape)
        model_name = "VGG16"

    elif network == "VGG19":
        model = VGG19.GetNetArchitecture(input_shape)
        model_name = "VGG19"

    elif network == "resnet18":
        model = resnet.ResnetBuilder.build_resnet_18(resnet_shape, nb_classes)
        model_name = "Resnet18"

    elif network == "resnet34":
        model = resnet.ResnetBuilder.build_resnet_34(resnet_shape, nb_classes)
        model_name = "Resnet34"

    elif network == "resnet50":
        model = resnet.ResnetBuilder.build_resnet_50(resnet_shape, nb_classes)
        model_name = "Resnet50"

    elif network == "resnet101":
        model = resnet.ResnetBuilder.build_resnet_101(resnet_shape, nb_classes)
        model_name = "Resnet101"

    elif network == "resnet152":
        model = resnet.ResnetBuilder.build_resnet_152(resnet_shape, nb_classes)
        model_name = "Resnet152"

    elif network == 'BN_LeNet5':
        model = BN_LeNet5.GetNetArchitecture(input_shape)
        model_name = 'BN_LeNet5'

    elif network == 'FC':
        model = FC.GetNetArchitecture(input_shape)
        model_name = 'FC'

    elif network == 'Conv5':
        model = Conv5.GetNetArchitecture(input_shape)
        model_name = 'Conv5'

    elif network == "LeNetTanh":
        model = LeNet5Tanh.GetNetArchitecture(input_shape)
        model_name = "LeNetTanh"

    elif network == "TR_LeNet5":
        model = TR_LeNet5.GetNetArchitecture(input_shape)
        model_name = "TR_LeNet5"

    elif network == "CN1":
        model = CN1.GetNetArchitecture(input_shape)
        model_name = "CN1"

    else:
        # Fail with a clear message instead of an UnboundLocalError on return.
        raise ValueError("Unsupported network: {!r}".format(network))

    return model, model_name