Example #1
def train(dataset='mnist',
          model_name='sl',
          batch_size=128,
          epochs=50,
          noise_ratio=0,
          asym=False,
          alpha=1.0,
          beta=1.0):
    """
    Train one model with data augmentation: random padding+cropping and horizontal flip
    :param dataset: 
    :param model_name:
    :param batch_size: 
    :param epochs: 
    :param noise_ratio: 
    :return: 
    """
    print(
        'Dataset: %s, model: %s, batch: %s, epochs: %s, noise ratio: %s%%, asymmetric: %s, alpha: %s, beta: %s'
        % (dataset, model_name, batch_size, epochs, noise_ratio, asym, alpha,
           beta))

    # load data
    X_train, y_train, y_train_clean, X_test, y_test = get_data(
        dataset, noise_ratio, asym=asym, random_shuffle=False)
    n_images = X_train.shape[0]
    image_shape = X_train.shape[1:]
    num_classes = y_train.shape[1]
    print("n_images", n_images, "num_classes", num_classes, "image_shape:",
          image_shape)

    # load model
    model = get_model(dataset,
                      input_tensor=None,
                      input_shape=image_shape,
                      num_classes=num_classes)
    # model.summary()

    if dataset == 'cifar-100':
        optimizer = SGD(lr=0.1, decay=5e-3, momentum=0.9)
    else:
        optimizer = SGD(lr=0.1, decay=1e-4, momentum=0.9)

    # create loss
    if model_name == 'ce':
        loss = cross_entropy
    elif model_name == 'sl':
        loss = symmetric_cross_entropy(alpha, beta)
    elif model_name == 'lsr':
        loss = lsr
    elif model_name == 'joint':
        loss = joint_optimization_loss
    elif model_name == 'gce':
        loss = generalized_cross_entropy
    elif model_name == 'boot_hard':
        loss = boot_hard
    elif model_name == 'boot_soft':
        loss = boot_soft
    elif model_name == 'forward':
        # forward/backward correction assume a known noise transition matrix P
        P = uniform_noise_model_P(num_classes, noise_ratio / 100.)
        loss = forward(P)
    elif model_name == 'backward':
        P = uniform_noise_model_P(num_classes, noise_ratio / 100.)
        loss = backward(P)
    else:
        print("Model %s is unimplemented!" % model_name)
        exit(1)

    # compile model
    model.compile(loss=loss, optimizer=optimizer, metrics=['accuracy'])

    if asym:
        model_save_file = "model/asym_%s_%s_%s.{epoch:02d}.hdf5" % (
            model_name, dataset, noise_ratio)
    else:
        model_save_file = "model/%s_%s_%s.{epoch:02d}.hdf5" % (
            model_name, dataset, noise_ratio)

    ## do real-time updates using callbacks
    callbacks = []

    # checkpoint weights every epoch (both branches were identical, so one block suffices)
    cp_callback = ModelCheckpoint(model_save_file,
                                  monitor='val_loss',
                                  verbose=0,
                                  save_best_only=False,
                                  save_weights_only=True,
                                  period=1)
    callbacks.append(cp_callback)

    # learning rate scheduler (used with the SGD optimizer)
    lr_scheduler = get_lr_scheduler(dataset)
    callbacks.append(lr_scheduler)

    callbacks.append(SGDLearningRateTracker(model))

    # acc, loss, lid
    log_callback = LoggerCallback(model, X_train, y_train, y_train_clean,
                                  X_test, y_test, dataset, model_name,
                                  noise_ratio, asym, epochs, alpha, beta)
    callbacks.append(log_callback)

    # data augmentation
    if dataset in ['mnist', 'svhn']:
        datagen = ImageDataGenerator()
    elif dataset in ['cifar-10']:
        datagen = ImageDataGenerator(width_shift_range=0.2,
                                     height_shift_range=0.2,
                                     horizontal_flip=True)
    else:
        datagen = ImageDataGenerator(rotation_range=20,
                                     width_shift_range=0.2,
                                     height_shift_range=0.2,
                                     horizontal_flip=True)
    datagen.fit(X_train)

    # train model
    model.fit_generator(datagen.flow(X_train, y_train, batch_size=batch_size),
                        steps_per_epoch=len(X_train) // batch_size,
                        epochs=epochs,
                        validation_data=(X_test, y_test),
                        verbose=1,
                        callbacks=callbacks)
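
The 'sl' branch relies on a symmetric_cross_entropy(alpha, beta) factory that the snippet does not show. Below is a minimal sketch of such a factory, assuming the usual Symmetric Cross Entropy formulation (alpha * CE + beta * reverse CE); the clipping constants are assumptions, not taken from the snippet:

import keras.backend as K

def symmetric_cross_entropy(alpha, beta):
    # Sketch: loss = alpha * CE(y_true, y_pred) + beta * RCE(y_true, y_pred).
    # Clipping bounds (1e-7 for predictions, 1e-4 for labels) are assumed.
    def loss(y_true, y_pred):
        y_pred_clip = K.clip(y_pred, 1e-7, 1.0)   # keep log() finite in the CE term
        y_true_clip = K.clip(y_true, 1e-4, 1.0)   # keep log() finite in the reverse CE term
        ce = K.mean(-K.sum(y_true * K.log(y_pred_clip), axis=-1))
        rce = K.mean(-K.sum(y_pred * K.log(y_true_clip), axis=-1))
        return alpha * ce + beta * rce
    return loss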
Example #2
    def __init__(self, args):
        """
        Args:
            args: Configuration args passed in via the command line.
        """
        super(CycleGAN, self).__init__()
        self.device = 'cuda' if len(args.gpu_ids) > 0 else 'cpu'
        self.gpu_ids = args.gpu_ids
        self.is_training = args.is_training

        # Set up generators
        self.g_src = ResNet(args)  # Maps from src to tgt
        self.g_tgt = ResNet(args)  # Maps from tgt to src

        if self.is_training:
            # Set up discriminators
            self.d_tgt = PatchGAN(args)  # Answers Q "is this tgt image real?"
            self.d_src = PatchGAN(args)  # Answers Q "is this src image real?"

            self._data_parallel()

            # Set up loss functions
            self.lambda_src = args.lambda_src  # Weight ratio of loss CYC_SRC:GAN
            self.lambda_tgt = args.lambda_tgt  # Weight ratio of loss CYC_TGT:GAN
            self.lambda_id = args.lambda_id  # Weight ratio of loss ID_{SRC,TGT}:CYC_{SRC,TGT}
            self.l1_loss_fn = nn.L1Loss()
            self.gan_loss_fn = util.GANLoss(device=self.device,
                                            use_least_squares=True)

            # Set up optimizers
            self.opt_g = torch.optim.Adam(chain(self.g_src.parameters(),
                                                self.g_tgt.parameters()),
                                          lr=args.lr,
                                          betas=(args.beta_1, args.beta_2))
            self.opt_d = torch.optim.Adam(chain(self.d_tgt.parameters(),
                                                self.d_src.parameters()),
                                          lr=args.lr,
                                          betas=(args.beta_1, args.beta_2))
            self.optimizers = [self.opt_g, self.opt_d]
            self.schedulers = [
                util.get_lr_scheduler(opt, args) for opt in self.optimizers
            ]

            # Set up image mixers
            buffer_capacity = 50 if args.use_mixer else 0
            self.src2tgt_buffer = util.ImageBuffer(
                buffer_capacity)  # Buffer of generated tgt images
            self.tgt2src_buffer = util.ImageBuffer(
                buffer_capacity)  # Buffer of generated src images

            if args.clamp_jacobian:
                raise NotImplementedError(
                    'Jacobian Clamping not implemented for CycleGAN.')
        else:
            self._data_parallel()

        # Images in cycle src -> tgt -> src
        self.src = None
        self.src2tgt = None
        self.src2tgt2src = None

        # Images in cycle tgt -> src -> tgt
        self.tgt = None
        self.tgt2src = None
        self.tgt2src2tgt = None

        # Discriminator loss
        self.loss_d_tgt = None
        self.loss_d_src = None
        self.loss_d = None

        # Generator GAN loss
        self.loss_gan_src = None
        self.loss_gan_tgt = None
        self.loss_gan = None

        # Generator Identity loss
        self.src2src = None
        self.tgt2tgt = None
        self.loss_id_src = None
        self.loss_id_tgt = None
        self.loss_id = None

        # Generator Cycle loss
        self.loss_cyc_src = None
        self.loss_cyc_tgt = None
        self.loss_cyc = None

        # Generator total loss
        self.loss_g = None
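
util.ImageBuffer is not shown above. Below is a minimal sketch of the image pool it appears to implement; the 0.5 swap probability follows the usual CycleGAN history buffer (Shrivastava et al.), and all details here are assumptions:

import random
import torch

class ImageBuffer(object):
    # History buffer of generated images: returning previously generated
    # samples half of the time stabilizes the discriminator updates.
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []

    def sample(self, images):
        if self.capacity == 0:  # mixing disabled (use_mixer=False)
            return images
        out = []
        for img in images:
            img = img.unsqueeze(0)
            if len(self.buffer) < self.capacity:
                self.buffer.append(img.clone())       # fill the buffer first
                out.append(img)
            elif random.random() < 0.5:
                idx = random.randrange(self.capacity)
                out.append(self.buffer[idx].clone())  # return an old image
                self.buffer[idx] = img.clone()        # and stash the new one
            else:
                out.append(img)
        return torch.cat(out, dim=0)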
Example #3
def train(dataset='mnist', model_name='d2l', batch_size=128, epochs=50, noise_ratio=0):
    """
    Train one model with data augmentation: random padding+cropping and horizontal flip
    :param dataset: 
    :param model_name:
    :param batch_size: 
    :param epochs: 
    :param noise_ratio: 
    :return: 
    """
    print('Dataset: %s, model: %s, batch: %s, epochs: %s, noise ratio: %s%%' %
          (dataset, model_name, batch_size, epochs, noise_ratio))

    # load data
    X_train, y_train, X_test, y_test = get_data(dataset, noise_ratio, random_shuffle=True)
    # X_train, y_train, X_val, y_val = validatation_split(X_train, y_train, split=0.1)
    n_images = X_train.shape[0]
    image_shape = X_train.shape[1:]
    num_classes = y_train.shape[1]
    print("n_images", n_images, "num_classes", num_classes, "image_shape:", image_shape)

    # load model
    model = get_model(dataset, input_tensor=None, input_shape=image_shape, num_classes=num_classes)
    # model.summary()

    optimizer = SGD(lr=0.01, decay=1e-4, momentum=0.9)

    # create loss
    # forward/backward correction assume the noise transition matrix P is known
    if model_name == 'forward':
        P = uniform_noise_model_P(num_classes, noise_ratio / 100.)
        loss = forward(P)
    elif model_name == 'backward':
        P = uniform_noise_model_P(num_classes, noise_ratio / 100.)
        loss = backward(P)
    elif model_name == 'boot_hard':
        loss = boot_hard
    elif model_name == 'boot_soft':
        loss = boot_soft
    elif model_name == 'd2l':
        loss = lid_paced_loss()
    else:
        loss = cross_entropy

    # compile model
    model.compile(
        loss=loss,
        optimizer=optimizer,
        metrics=['accuracy']
    )

    ## do real-time updates using callbacks
    callbacks = []
    if model_name == 'd2l':
        init_epoch = D2L[dataset]['init_epoch']
        epoch_win = D2L[dataset]['epoch_win']
        d2l_learning = D2LCallback(model, X_train, y_train,
                                   dataset, noise_ratio,
                                   epochs=epochs,
                                   pace_type=model_name,
                                   init_epoch=init_epoch,
                                   epoch_win=epoch_win)

        callbacks.append(d2l_learning)

        cp_callback = ModelCheckpoint("model/%s_%s_%s.hdf5" % (model_name, dataset, noise_ratio),
                                      monitor='val_loss',
                                      verbose=0,
                                      save_best_only=False,
                                      save_weights_only=True,
                                      period=1)
        callbacks.append(cp_callback)

    else:
        cp_callback = ModelCheckpoint("model/%s_%s_%s.hdf5" % (model_name, dataset, noise_ratio),
                                      monitor='val_loss',
                                      verbose=0,
                                      save_best_only=False,
                                      save_weights_only=True,
                                      period=epochs)
        callbacks.append(cp_callback)

    # tensorboard callback
    callbacks.append(TensorBoard(log_dir='./log/log'))

    # learning rate scheduler (used with the SGD optimizer)
    lr_scheduler = get_lr_scheduler(dataset)
    callbacks.append(lr_scheduler)

    # acc, loss, lid
    log_callback = LoggerCallback(model, X_train, y_train, X_test, y_test, dataset,
                                  model_name, noise_ratio, epochs)
    callbacks.append(log_callback)

    # data augmentation
    if dataset in ['mnist', 'svhn']:
        datagen = ImageDataGenerator()
    elif dataset in ['cifar-10', 'cifar-100']:
        datagen = ImageDataGenerator(
            width_shift_range=0.2,
            height_shift_range=0.2,
            horizontal_flip=True)
    else:
        datagen = ImageDataGenerator(
            width_shift_range=0.1,
            height_shift_range=0.1,
            horizontal_flip=True)
    datagen.fit(X_train)

    # train model
    model.fit_generator(datagen.flow(X_train, y_train, batch_size=batch_size),
                        steps_per_epoch=len(X_train) // batch_size, epochs=epochs,
                        validation_data=(X_test, y_test),
                        verbose=1,
                        callbacks=callbacks
                        )
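
uniform_noise_model_P builds the transition matrix that the forward/backward corrections assume known. Below is a minimal sketch under the uniform (symmetric) noise model, where a label survives with probability 1 - noise and flips to each other class with probability noise / (num_classes - 1):

import numpy as np

def uniform_noise_model_P(num_classes, noise):
    # off-diagonal entries: probability of flipping to one specific wrong class
    P = noise / (num_classes - 1) * np.ones((num_classes, num_classes))
    # diagonal entries: probability the label is kept; each row sums to 1
    np.fill_diagonal(P, 1.0 - noise)
    return P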
Example #4
    def __init__(self, args):
        """
        Args:
            args: Configuration args passed in via the command line.
        """
        super(Flow2Flow, self).__init__()
        self.device = 'cuda' if len(args.gpu_ids) > 0 else 'cpu'
        self.gpu_ids = args.gpu_ids
        self.is_training = args.is_training

        self.in_channels = args.num_channels
        self.out_channels = 4 ** (args.num_scales - 1) * self.in_channels

        # Set up RealNVP generators (g_src: X <-> Z, g_tgt: Y <-> Z)
        self.g_src = RealNVP(num_scales=args.num_scales,
                             in_channels=args.num_channels,
                             mid_channels=args.num_channels_g,
                             num_blocks=args.num_blocks,
                             un_normalize_x=True,
                             no_latent=False)
        util.init_model(self.g_src, init_method=args.initializer)
        self.g_tgt = RealNVP(num_scales=args.num_scales,
                             in_channels=args.num_channels,
                             mid_channels=args.num_channels_g,
                             num_blocks=args.num_blocks,
                             un_normalize_x=True,
                             no_latent=False)
        util.init_model(self.g_tgt, init_method=args.initializer)

        if self.is_training:
            # Set up discriminators
            self.d_tgt = PatchGAN(args)  # Answers Q "is this tgt image real?"
            self.d_src = PatchGAN(args)  # Answers Q "is this src image real?"

            self._data_parallel()

            # Set up loss functions
            self.max_grad_norm = args.clip_gradient
            self.lambda_mle = args.lambda_mle
            self.mle_loss_fn = RealNVPLoss()
            self.gan_loss_fn = util.GANLoss(device=self.device, use_least_squares=True)

            self.clamp_jacobian = args.clamp_jacobian
            self.jc_loss_fn = util.JacobianClampingLoss(args.jc_lambda_min, args.jc_lambda_max)

            # Set up optimizers
            g_src_params = util.get_param_groups(self.g_src, args.weight_norm_l2, norm_suffix='weight_g')
            g_tgt_params = util.get_param_groups(self.g_tgt, args.weight_norm_l2, norm_suffix='weight_g')
            self.opt_g = torch.optim.Adam(chain(g_src_params, g_tgt_params),
                                          lr=args.rnvp_lr,
                                          betas=(args.rnvp_beta_1, args.rnvp_beta_2))
            self.opt_d = torch.optim.Adam(chain(self.d_tgt.parameters(), self.d_src.parameters()),
                                          lr=args.lr,
                                          betas=(args.beta_1, args.beta_2))
            self.optimizers = [self.opt_g, self.opt_d]
            self.schedulers = [util.get_lr_scheduler(opt, args) for opt in self.optimizers]

            # Set up image mixers
            buffer_capacity = 50 if args.use_mixer else 0
            self.src2tgt_buffer = util.ImageBuffer(buffer_capacity)  # Buffer of generated tgt images
            self.tgt2src_buffer = util.ImageBuffer(buffer_capacity)  # Buffer of generated src images
        else:
            self._data_parallel()

        # Images in flow src -> lat -> tgt
        self.src = None
        self.src2lat = None
        self.src2tgt = None

        # Images in flow tgt -> lat -> src
        self.tgt = None
        self.tgt2lat = None
        self.tgt2src = None

        # Jacobian clamping tensors
        self.src_jc = None
        self.tgt_jc = None
        self.src2tgt_jc = None
        self.tgt2src_jc = None

        # Discriminator loss
        self.loss_d_tgt = None
        self.loss_d_src = None
        self.loss_d = None

        # Generator GAN loss
        self.loss_gan_src = None
        self.loss_gan_tgt = None
        self.loss_gan = None

        # Generator MLE loss
        self.loss_mle_src = None
        self.loss_mle_tgt = None
        self.loss_mle = None

        # Jacobian Clamping loss
        self.loss_jc_src = None
        self.loss_jc_tgt = None
        self.loss_jc = None

        # Generator total loss
        self.loss_g = None
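
util.get_lr_scheduler(opt, args) is not defined in the snippet. Below is a minimal sketch of the linear-decay schedule commonly used for this kind of GAN training; the arg names num_epochs and num_decay_epochs are hypothetical, not taken from the snippet:

import torch.optim as optim

def get_lr_scheduler(optimizer, args):
    # Keep the learning rate constant for args.num_epochs epochs, then decay
    # it linearly to zero over the following args.num_decay_epochs epochs.
    def lambda_rule(epoch):
        return 1.0 - max(0, epoch + 1 - args.num_epochs) / float(args.num_decay_epochs + 1)
    return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)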