示例#1
0
    def create_model(self, input_dim, autoenc_model):
        """Build and compile a two-input classifier on top of two autoencoders.

        A saved autoencoder is loaded from disk and its weights are copied
        into a "left" and a "right" ``Autoencoder``.  The output of the
        second-to-last layer of each one is concatenated and passed through
        two 100-unit ReLU layers and a 2-unit softmax layer.

        :param input_dim: forwarded to the ``Autoencoder`` constructor.
        :param autoenc_model: file name of the saved model inside
            ``hotexamples_com/models/trained_models/``.
        :return: the compiled model taking ``[left_input, right_input]``.
        """
        pre_trained_autoenc = keras.models.load_model(
            f'hotexamples_com/models/trained_models/{autoenc_model}')
        # Fetch the pre-trained weights once; both branches start from them.
        pretrained_weights = pre_trained_autoenc.get_weights()

        left_autoencoder = Autoencoder(input_dim)
        left_input = left_autoencoder.model.input
        left_autoencoder.model.set_weights(pretrained_weights)
        # The right branch is a separate Autoencoder initialised with a copy
        # of the same weights; nothing here ties the two branches together.
        right_autoencoder = Autoencoder(input_dim, name='right')
        right_input = right_autoencoder.model.input
        right_autoencoder.model.set_weights(pretrained_weights)

        # Use the second-to-last layer of each autoencoder as its embedding.
        left_embed_layer = left_autoencoder.model.layers[-2].output
        right_embed_layer = right_autoencoder.model.layers[-2].output
        merge_layer = keras.layers.Concatenate()(
            [left_embed_layer, right_embed_layer])
        dnn_layer = keras.layers.Dense(100, activation='relu')(merge_layer)
        dnn_layer = keras.layers.Dense(100, activation='relu')(dnn_layer)
        output = keras.layers.Dense(2, activation='softmax')(dnn_layer)

        model = keras.models.Model(inputs=[left_input, right_input],
                                   outputs=output)
        # summary() prints the table itself and returns None, so wrapping it
        # in print() would emit an extra "None" line.
        model.summary()

        # NOTE(review): a 2-unit softmax is paired with binary cross-entropy;
        # this presumably expects one-hot labels -- confirm with the caller.
        model.compile(optimizer=keras.optimizers.Adam(),
                      loss=keras.losses.BinaryCrossentropy(),
                      metrics=['accuracy'])
        return model
示例#2
0
    def __init__(self, reconstruction_loss_factor: float, cycle_loss_factor: float):
        """Create the day/night autoencoder pair and store the loss settings.

        Both autoencoders are given the same ``UpperEncoder`` and
        ``LowerDecoder`` instances, while each gets its own ``LowerEncoder``
        and ``UpperDecoder``.  The optimizer and scheduler start as ``None``.

        :param reconstruction_loss_factor: stored on the instance unchanged.
        :param cycle_loss_factor: stored on the instance unchanged.
        """
        # Instances handed to both autoencoders, so their weights are shared.
        shared_upper_encoder = UpperEncoder()
        shared_lower_decoder = LowerDecoder()

        self.ae_day = Autoencoder(
            LowerEncoder(),
            shared_upper_encoder,
            shared_lower_decoder,
            UpperDecoder(),
        )
        self.ae_night = Autoencoder(
            LowerEncoder(),
            shared_upper_encoder,
            shared_lower_decoder,
            UpperDecoder(),
        )

        self.loss_fn = nn.L1Loss()
        self.reconstruction_loss_factor = reconstruction_loss_factor
        self.cycle_loss_factor = cycle_loss_factor

        # Filled in later; nothing is configured at construction time.
        self.optimizer = self.scheduler = None
示例#3
0
    def __init__(self):
        """Create the day/night autoencoders with per-domain optimizer slots.

        The two autoencoders are built around one common ``UpperEncoder`` and
        one common ``LowerDecoder``; the remaining parts are created per
        autoencoder.  Optimizers and schedulers are left as ``None``.
        """
        upper_enc = UpperEncoder()
        lower_dec = LowerDecoder()

        day_parts = (LowerEncoder(), upper_enc, lower_dec, UpperDecoder())
        self.ae_day = Autoencoder(*day_parts)
        night_parts = (LowerEncoder(), upper_enc, lower_dec, UpperDecoder())
        self.ae_night = Autoencoder(*night_parts)

        self.loss_fn = nn.L1Loss()

        # One optimizer/scheduler slot per domain, set up later.
        self.optimizer_day = self.optimizer_night = None
        self.scheduler_day = self.scheduler_night = None
示例#4
0
    def __init__(self, params: dict):
        """Store ``params`` and build the day/night autoencoder pair.

        :param params: kept on the instance as ``self.params``; it is not
            read anywhere else in this constructor.
        """
        self.params = params

        # The upper encoder and lower decoder objects are created once and
        # passed to both autoencoders, which is what shares their weights.
        common_encoder = UpperEncoder()
        common_decoder = LowerDecoder()

        self.ae_day = Autoencoder(
            LowerEncoder(), common_encoder, common_decoder, UpperDecoder())
        self.ae_night = Autoencoder(
            LowerEncoder(), common_encoder, common_decoder, UpperDecoder())

        self.reconst_loss = nn.L1Loss()

        self.optimizer, self.scheduler = None, None
def load_pretrained_model(args):
    """
    Load the pretrained model
    :param args: Command line arguments passed to this file, including class of pretrained model
    :return: A pretrained model object
    :raises ValueError: if ``args.model`` is neither the VGG nor the Autoencoder value
    """
    pretrained_ckpt = torch.load(args.pretrained_model_path, map_location=config.device)
    # With a pruned encoder the loaded object is used as the state dict
    # directly; otherwise the state dict is stored under 'state_dict'.
    if args.encoder_pruned:
        pretrained_state_dict = pretrained_ckpt
    else:
        pretrained_state_dict = pretrained_ckpt['state_dict']

    # Create pretrained model object
    # For 'framework' concept, later more model selection
    # NOTE(review): `hparams` is not defined in this function; it is
    # presumably a module-level name -- confirm.
    # Compare by value: `is` tests object identity and can be False for two
    # equal values that are distinct objects.
    if args.model == Model.VGG.value:
        pretrained_model = VGGNet(hparams=hparams)
    elif args.model == Model.Autoencoder.value:
        if args.encoder_pruned:
            # Pretrained model's encoder is already pruned
            params = config.compute_pruned_autoencoder_params(config.remove_ratio, encoder=True, decoder=False)
        else:
            # Pretrained model is unpruned
            params = config.autoencoder_params

        pretrained_model = Autoencoder(hparams=hparams, model_params=params)
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise ValueError('Unsupported model: {!r}'.format(args.model))

    pretrained_model.load_state_dict(pretrained_state_dict)
    return pretrained_model
示例#6
0
 def __init__(self, *args):
     """Build the autoencoder model and fit it on any existing prior data.

     All positional arguments are forwarded to the parent constructor.  When
     ``self.prior`` is non-empty, its items are unzipped and passed to
     ``fit`` for ``initial_epochs`` epochs.
     """
     super().__init__(*args)
     self.model = Autoencoder(encoder=self.encode, shape=self.shape, beta=1.0)
     if len(self.prior) == 0:
         # Nothing to pre-train on.
         return
     # Unzip the (key, value) pairs into the positional arguments of fit().
     fit_inputs = zip(*self.prior.items())
     self.model.fit(*fit_inputs, epochs=initial_epochs)
示例#7
0
def train_heinsfeld_autoencoder():
    """Train an ``Autoencoder`` classifier on the ``rois_cc200`` data."""
    # Fixed hyper-parameters passed to the Autoencoder constructor.
    hyperparams = {
        'num_classes': 2,
        'dropout': (0.6, 0.8),
        'learning_rate': (0.0001, 0.0001, 0.0005),
        'momentum': 0.9,
        'noise': (0.2, 0.3),
        'batch_size': (100, 10, 10),
        'num_epochs': (700, 2000, 100),
    }
    samples, labels = read_data('rois_cc200')
    classifier = Autoencoder(**hyperparams)
    classifier.train(samples, labels)
示例#8
0
    def test_can_decode_basic_more_layers(self):
        """Run a 4-layer GRU decoder through the autoencoder and check that
        the cross-entropy loss can be read back as a Python float."""
        encoder = PretrainedTransformerGenerator(self.args)
        width = encoder.config.d_model
        decoder = GRUDecoder(width, self.tokenizer.vocab_size, width,
                             n_layers=4, dropout=0.2)
        model = Autoencoder(encoder, decoder, "cpu:0").to("cpu:0")
        logits = model(self.input, self.target)

        loss_fn = nn.CrossEntropyLoss(ignore_index=0)
        # swap for loss
        loss = loss_fn(logits.permute((1, 2, 0)), self.target)
        assert type(loss.item()) == float, \
            "could not get loss value: type {}".format(type(loss.item()))
示例#9
0
def predict_heinsfeld_autoencoder():
    """Run ``Autoencoder.predict`` with the ``rois_cc200`` train and test data."""
    roi_set = 'rois_cc200'
    train_x, train_y = read_data(roi_set)
    test_x, test_y = read_data(roi_set, training=False)
    # Fixed hyper-parameters passed to the Autoencoder constructor.
    settings = dict(
        num_classes=2,
        dropout=(0.6, 0.8),
        learning_rate=(0.0001, 0.0001, 0.0005),
        momentum=0.9,
        noise=(0.2, 0.3),
        batch_size=(100, 10, 10),
        num_epochs=(700, 2000, 100),
    )
    classifier = Autoencoder(**settings)
    classifier.predict(train_x, train_y, test_x, test_y)
示例#10
0
    def test_decode_to_text(self):
        """Run the autoencoder once, map its best guesses back to tokens and
        check that the joined prediction is a string."""
        encoder = PretrainedTransformerGenerator(self.args)
        width = encoder.config.d_model
        decoder = GRUDecoder(width, self.tokenizer.vocab_size, width,
                             n_layers=4, dropout=0.2)
        model = Autoencoder(encoder, decoder, "cpu:0").to("cpu:0")
        output = model(self.input, self.target)

        # get text output: index of the maximum along dim 2
        best_guess = torch.max(output, dim=2)[1]
        guessed_ids = best_guess.permute(1, 0).flatten().tolist()
        string_pred = " ".join(self.tokenizer.convert_ids_to_tokens(guessed_ids))
        print('Predicted: ', string_pred)
        actual_tokens = self.tokenizer.convert_ids_to_tokens(
            self.input.flatten().tolist())
        print('Actual: ', actual_tokens)
        assert type(string_pred) == str, \
            "predicted value was not a string, was a {}".format(type(string_pred))
示例#11
0
 def test_autoencoder_training(self):
     """Smoke-test ``train_autoencoder`` for two epochs on a single sample."""
     device = self.args.device  # warning: device is cpu for CI, slow

     # One-sample dataset where the input doubles as the target.
     # need two dimensional (1, 7) shape for input
     sample = self.input.squeeze(0)
     loader = DataLoader(TensorDataset(sample, sample),
                         batch_size=1, shuffle=True)

     # create models
     encoder = PretrainedTransformerGenerator(self.args)
     width = encoder.config.d_model
     decoder = GRUDecoder(width, self.tokenizer.vocab_size, width,
                          n_layers=1, dropout=0).to(device)
     autoencoder = Autoencoder(encoder, decoder, device,
                               tokenizer=self.tokenizer).to(device)

     # create needed params
     optimizer = optim.Adam(autoencoder.parameters(), lr=3e-4)
     loss_fn = nn.CrossEntropyLoss(ignore_index=0)
     loss_log = pd.DataFrame(columns=['batch_num', 'loss'])

     # see if it works: the same loader is passed for both loader arguments
     autoencoder = train_autoencoder(self.args, autoencoder, loader, loader,
                                     optimizer, loss_fn, 1, loss_log,
                                     num_epochs=2)
 def __init__(self,
              encoder,
              dim,
              shape,
              beta=0.,
              alpha=5e-4,
              zeta=1e-2,
              lam=1e-6,
              mu=0.5,
              itr=200,
              M=1000,
              eps=1e-4,
              minibatch=100,
              gpbatch=2000):
     '''Set up the embedding autoencoder and store the hyper-parameters.

     encoder: convert sequences to one-hot arrays.
     dim: embedding dimensionality.
     shape: sequence shape (len, channels).
     beta: embedding score weighting.
     alpha: embedding learning rate.
     zeta: induced point ascent learning rate.
     lam: l2 regularization constant.
     mu: GP prior mean.
     itr: gradient ascent iterations for induced pseudo-inputs.
     M: max number of induced points.
     eps: numerical stability.
     minibatch: forwarded to the embedding ``Autoencoder``.
     gpbatch: stored on this object as ``self.minibatch``.
     '''
     super().__init__()
     # Start with no stored data.
     self.X = ()
     self.Y = ()
     # This object's own minibatch size comes from `gpbatch`; the
     # `minibatch` argument only goes to the embedding model below.
     self.minibatch = gpbatch
     embed_options = dict(dim=dim,
                          alpha=alpha,
                          shape=shape,
                          lam=lam,
                          beta=beta,
                          minibatch=minibatch)
     self.embed = Autoencoder(encoder, **embed_options)
     self.mu, self.dim, self.alpha = mu, dim, alpha
     self.itr, self.eps = itr, eps
     self.M, self.zeta = M, zeta
示例#13
0
def main(unused_argv):
    """Reconstruct every test image with a restored model and report PSNR.

    Each image in ``FLAGS.test_folder`` is converted to grayscale, shifted by
    ``cfgs.mean_value`` and scaled by ``cfgs.scale``, reconstructed through
    ``overlap_inference``, mapped back, written to
    ``FLAGS.reconstruction_folder`` and scored against the grayscale input.
    The per-image and mean PSNR are printed.

    :param unused_argv: ignored.
    """
    # load test images
    test_list = list_image(FLAGS.test_folder)
    # load model
    assert (FLAGS.snapshot_dir != ""
            or FLAGS.model_fname != ""), 'No pretrained model specified'
    model = Autoencoder(cfgs.patch_size * cfgs.patch_size, cfgs, log_dir=None)
    # An explicit model file wins over the latest checkpoint in snapshot_dir.
    snapshot_fname = FLAGS.model_fname if FLAGS.model_fname != "" \
        else tf.train.latest_checkpoint(FLAGS.snapshot_dir)
    model.restore(snapshot_fname)
    print('Restored from %s' % snapshot_fname)
    sum_psnr = 0.0
    stride = FLAGS.stride
    for img_fname in test_list:
        orig_img = load_image('%s/%s' % (FLAGS.test_folder, img_fname))
        # pre-process image
        gray_img = toGrayscale(orig_img)
        img = gray_img.astype(np.float32)
        img -= cfgs.mean_value
        img *= cfgs.scale
        # make measurement and reconstruct image
        recon_img = overlap_inference(model,
                                      img,
                                      bs=cfgs.batch_size,
                                      stride=stride)
        # Undo the pre-processing so the result is comparable to gray_img.
        recon_img /= cfgs.scale
        recon_img += cfgs.mean_value
        # save reconstruction
        cv.imwrite(
            '%s/%sOI_%d_%s' %
            (FLAGS.reconstruction_folder, FLAGS.prefix, stride, img_fname),
            recon_img.astype(np.uint8))
        psnr_ = psnr(gray_img.astype(np.float32), recon_img)
        print('Image %s, psnr: %f' % (img_fname, psnr_))
        sum_psnr += psnr_

    if not test_list:
        # Without images there is no mean to report; avoid dividing by zero.
        print('No test images found in %s' % FLAGS.test_folder)
        return
    mean_psnr = sum_psnr / len(test_list)

    print('---------------------------')
    print('Mean PSNR: %f' % mean_psnr)
示例#14
0
    def __init__(self, config):
        """
        Construct a new GAN trainer: create the data loader, build the
        autoencoder, discriminator and generator from the loader's input
        width, and expose them either as-is or moved to CUDA.
        :param Config config: The parsed network configuration.
        """
        self.config = config

        LOG.info("CUDA version: {0}".format(version.cuda))
        LOG.info("Creating data loader from path {0}".format(config.FILENAME))

        loader_options = dict(
            polarisations=config.POLARISATIONS,  # Polarisations to use
            frequencies=config.FREQUENCIES,  # Frequencies to use
            max_inputs=config.MAX_SAMPLES,  # Max inputs per polarisation and frequency
            normalise=config.NORMALISE,  # Normalise inputs
        )
        self.data_loader = Data(config.FILENAME, config.BATCH_SIZE,
                                **loader_options)

        input_shape = self.data_loader.get_input_shape()
        input_width = input_shape[1]
        LOG.info("Creating models with input shape {0}".format(input_shape))
        self._autoencoder = Autoencoder(input_width)
        self._discriminator = Discriminator(input_width)
        # TODO: Get correct input and output widths for generator
        self._generator = Generator(input_width, input_width)

        if not config.USE_CUDA:
            LOG.info("Using CPU")
            self.autoencoder = self._autoencoder
            self.discriminator = self._discriminator
            self.generator = self._generator
            return

        LOG.info("Using CUDA")
        self.autoencoder = self._autoencoder.cuda()
        self.discriminator = self._discriminator.cuda()
        self.generator = self._generator.cuda()
def train_autoencoder():
    """Build, train and save the autoencoder.

    The trained model is written to ``autoencoder.model``.

    :return: the ``Autoencoder`` definition object (not the built model).
    """
    # Instantiate the model
    # Normalization function: We will scale the input image pixels within 0-1 range by dividing all input value by 255.
    definition = Autoencoder(input_dim=features,
                             num_output_classes=num_output_classes,
                             transformation=normalization)
    net = definition.create_autoencoder()

    train_reader = create_reader(train_file, True, input_dim, num_output_classes)

    # Loss and error are both the mse between the model output and the
    # normalized input features.
    loss = mse(net, normalization(features))
    error = mse(net, normalization(features))

    # Map the data streams to the input.
    stream_map = {features: train_reader.streams.features}

    schedule = dict(num_sweeps_to_train_with=100,
                    num_samples_per_sweep=2000,
                    minibatch_size=10,
                    learning_rate=0.02)
    train(reader=train_reader,
          model=net,
          loss_function=loss,
          error_function=error,
          input_map=stream_map,
          **schedule)

    net.save('autoencoder.model')

    return definition
示例#16
0
def main(unused_argv):
    """Train the autoencoder until ``FLAGS.n_epochs`` epochs have completed.

    Batches are drawn from the data loader in an endless loop.  Whenever
    ``next_batch`` sets its flag, the epoch counter advances, validation is
    run, the learning rate may be decayed and a snapshot may be saved.
    Progress is written through ``log_train`` to ``text_log.txt``.

    :param unused_argv: ignored.
    """
    val_losses = []
    assert FLAGS.output_dir, "--output_dir is required"
    # Create training directory.
    output_dir = FLAGS.output_dir
    if not tf.gfile.IsDirectory(output_dir):
        tf.gfile.MakeDirs(output_dir)

    dl = DataLoader(FLAGS.db_fname, mean=cfgs.mean_value, scale=cfgs.scale, n_vals=FLAGS.n_vals)
    dl.prepare()

    x_dim = dl.get_data_dim()
    model = Autoencoder(x_dim, cfgs, log_dir=FLAGS.log_dir)
    model.quantize_weights()

    # NOTE(review): plain concatenation -- this presumably relies on
    # FLAGS.log_dir ending with a path separator; confirm.
    txt_log_fname = FLAGS.log_dir + 'text_log.txt'
    # The context manager closes the log file even if training raises.
    with open(txt_log_fname, 'w') as log_fout:
        if FLAGS.pretrained_fname:
            try:
                log_train(log_fout, 'Resume from %s' %(FLAGS.pretrained_fname))
                model.restore(FLAGS.pretrained_fname)
            except Exception:
                # Best effort: a failed restore is logged and training starts
                # from the current weights.  KeyboardInterrupt and SystemExit
                # are deliberately not caught.
                log_train(log_fout, 'Cannot restore from %s' %(FLAGS.pretrained_fname))

        lr = cfgs.initial_lr
        epoch_counter = 0
        ite = 0
        while True:
            start = time.time()
            x, flag = dl.next_batch(cfgs.batch_size, 'train')
            load_data_time = time.time() - start
            if flag:
                epoch_counter += 1

            do_log = (ite % FLAGS.log_every_n_steps == 0) or flag
            do_snapshot = flag and epoch_counter > 0 and epoch_counter % FLAGS.save_every_n_epochs == 0
            val_loss = -1

            # train one step
            start = time.time()
            loss, _, summary, ite = model.partial_fit(x, lr, do_log)
            one_iter_time = time.time() - start

            # writing outs
            if do_log:
                log_train(log_fout, 'Iteration %d, (lr=%f) training loss  : %f' %(ite, lr, loss))
                if FLAGS.log_time:
                    log_train(log_fout, 'Iteration %d, data loading: %f(s) ; one iteration: %f(s)'
                        %(ite, load_data_time, one_iter_time))
                model.log(summary)
            if flag:
                val_loss = val(model, dl)
                val_losses.append(val_loss)
                log_train(log_fout, '----------------------------------------------------')
                if ite == 0:
                    log_train(log_fout, 'Initial validation loss: %f' %(val_loss))
                else:
                    log_train(log_fout, 'Epoch %d, validation loss: %f' %(epoch_counter, val_loss))
                log_train(log_fout, '----------------------------------------------------')
                # NOTE(review): `flag` implies `do_log`, so the summary is
                # logged a second time here -- confirm this is intended.
                model.log(summary)
            if do_snapshot:
                log_train(log_fout, 'Snapshotting')
                model.save(FLAGS.output_dir)

            if flag:
                # 'val': decay when the newest validation loss is no better
                # than the worst of the four before it.
                if cfgs.lr_update == 'val' and len(val_losses) >= 5 and val_loss >= max(val_losses[-5:-1]):
                    lr = lr * cfgs.lr_decay_factor
                    log_train(log_fout, 'Decay learning rate to %f' %lr)
                elif cfgs.lr_update == 'step' and epoch_counter % cfgs.num_epochs_per_decay == 0:
                    lr = lr * cfgs.lr_decay_factor
                    log_train(log_fout, 'Decay learning rate to %f' %lr)
                if epoch_counter == FLAGS.n_epochs:
                    # Save a final snapshot unless one was just written.
                    if not do_snapshot:
                        log_train(log_fout, 'Final snapshotting')
                        model.save(FLAGS.output_dir)
                    break
示例#17
0
 def __init__(self, *args):
     """Forward all arguments to the parent and attach an ``Autoencoder``."""
     super().__init__(*args)
     # The model is built from this object's own encoder and shape.
     model_options = dict(encoder=self.encode, shape=self.shape, beta=1.0)
     self.model = Autoencoder(**model_options)
示例#18
0

# Script entry point: build the feature matrix, the model and the data
# loaders for autoencoder training.
if __name__ == '__main__':
    # Fix the torch RNG seed.
    torch.manual_seed(0)

    data = read("./dataset-1")

    # One feature row per state in game[0] of every game: the unpacked bits
    # of the state followed by the parity (i % 2) of its index in the game.
    features = np.array([
        np.append(np.unpackbits(state), i % 2) for game in data
        for i, state in enumerate(game[0])
    ])

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f'Using {device}')

    model = Autoencoder().to(device)
    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.005)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)

    # Positional 95% / 5% split: the first rows train, the rest test.
    dataset_size = features.shape[0]
    train_size = int(dataset_size * 0.95)
    test_size = dataset_size - train_size
    batch_size = 256

    train_tensor = torch.Tensor(features[:train_size])
    test_tensor = torch.Tensor(features[train_size:])

    # Each dataset pairs a tensor with itself, so inputs double as targets.
    train_dataloader = DataLoader(TensorDataset(train_tensor, train_tensor),
                                  batch_size=batch_size)
    test_dataloader = DataLoader(TensorDataset(test_tensor, test_tensor),
# Draw one standard-normal latent column vector and sample from subset3.
# `dim_z` and `subset3` are defined outside this excerpt.
z = np.random.normal(size=(dim_z, 1))
samples = subset3.sample(z, noise=True)

# NOTE: the triple-quoted block below is a bare string expression, so the
# statements inside it are never executed.
"""
N = 3
projection = MatrixProjection(dim_z, dim_input, N, noise_variance=1e-2)
randomprojection = RandomMatrixProjection(dim_z, dim_input, N, noise_variance=1e-2)
"""



#### PyTorch
#%% PyTorch


# NOTE(review): the five positional arguments are presumably layer sizes
# (input, hidden, bottleneck, hidden, output) -- confirm against Autoencoder.
net = Autoencoder(dim_input, 10, 3, 10, dim_input)
lr = 0.0001
criterion = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=lr)
n_iter = int(1e3)
n_epoch = 20
idx_in, idx_target = 0, 0

#launcher_predict(net, subset3, criterion, optimizer, n_epoch, n_iter, (idx_in, idx_target), plot=True)

 #%% Pytorch auto encoder multiple modalities



encoder_layer_1_size = 3
encoder_layer_2_size = 3
示例#20
0
    # Setup environment and environment state builder
    # NOTE(review): this fragment reads settings from both `opt` and `args`;
    # both are defined outside this excerpt.
    state_builder = RoomsStateBuilder(width=opt.width,
                                      height=opt.height,
                                      grayscale=opt.grayscale)
    env = RoomsEnvironment(action_space=action_space,
                           mission_name=mission_name,
                           mission_xml=mission_xml,
                           remotes=clients,
                           state_builder=state_builder,
                           role=e,
                           recording_path=None)

    # The autoencoder is optional: `vqae` stays None unless args.use_vqae is
    # set, and the processor below receives it either way.
    vqae = None
    if args.use_vqae:
        # Initialize VQAE
        vqae = Autoencoder(plot_class=plot_class)

    # Initialize processor
    processor = MalmoProcessor(autoencoder=vqae,
                               plot_class=plot_class,
                               action_space=action_space)

    if args.agent_type == 'Random':
        # Use Random agent to train the VQAE
        agent = RandomAgent(num_actions=env.available_actions,
                            processor=processor)
    elif args.agent_type == 'DDQN':
        # Setup exploration policy
        policy = LinearAnnealedPolicy(EpsGreedyQPolicy(),
                                      attr='eps',
                                      value_max=opt.eps_value_max,
示例#21
0
        'train_original': args.train_original,
        'lr': config.lr,
        'map_location': map_location
    }

    # Load models
    # The pruned ("smaller") checkpoint is loaded as a state dict directly;
    # the original checkpoint stores its state dict under 'state_dict'.
    # Compare by value: `is` tests object identity and can be False for two
    # equal values that are distinct objects.
    if args.model == Model.VGG.value:
        smaller_checkpoint = torch.load(args.smaller_model_path, map_location=torch.device(hparams['map_location']))
        smaller_model = VGGNet(hparams=hparams, model_params=config.compute_pruned_vgg_params(config.remove_ratio))
        smaller_model.load_state_dict(smaller_checkpoint)

        original_checkpoint = torch.load(args.original_model_path, map_location=torch.device(hparams['map_location']))
        original_model = VGGNet(hparams=hparams)
        original_model.load_state_dict(original_checkpoint['state_dict'])
    elif args.model == Model.Autoencoder.value:
        hparams['type'] = 'rgb'
        hparams['prune_encoder'] = True
        hparams['prune_decoder'] = True
        smaller_checkpoint = torch.load(args.smaller_model_path, map_location=torch.device(hparams['map_location']))
        smaller_model = Autoencoder(hparams=hparams, model_params=config.compute_pruned_autoencoder_params(config.remove_ratio, True, True))
        smaller_model.load_state_dict(smaller_checkpoint)

        original_checkpoint = torch.load(args.original_model_path, map_location=torch.device(hparams['map_location']))
        original_model = Autoencoder(hparams=hparams)
        original_model.load_state_dict(original_checkpoint['state_dict'])
    else:
        # Fail with a clear message instead of a NameError at compare().
        raise ValueError('Unsupported model: {!r}'.format(args.model))

    csv_path = os.path.join(config.data_dir, args.model)
    # Create the output directory if needed, without a check-then-create race.
    os.makedirs(csv_path, exist_ok=True)
    compare(original_model, smaller_model, csv_path, args.use_gpu)
def create_model(dataset):
    """Return an ``Autoencoder`` for ``dataset`` with fixed layer sizes.

    :param dataset: forwarded to the ``Autoencoder`` constructor.
    :return: the newly constructed ``Autoencoder``.
    """
    return Autoencoder(intermediate_dim=512, original_dim=9408, dataset=dataset)
示例#23
0
    sampled_text = gen.sample_text(2)
    print("Starting off, sampled text is ", sampled_text)

    # TODO: clean up the hardcoded amount
    # The two branches differ only in which config attribute supplies the
    # first GRUDecoder argument: `n_embd` for gpt2/ctrl, `d_model` otherwise.
    if args.gen_model_type in ["gpt2", "ctrl"]:
        decoder = GRUDecoder(gen.config.n_embd, gen_tokenizer.vocab_size,
                             args.decoder_hidden, args.decoder_layers,
                             args.decoder_dropout)  # FOR GPT2
    else:
        decoder = GRUDecoder(gen.config.d_model, gen_tokenizer.vocab_size,
                             args.decoder_hidden, args.decoder_layers,
                             args.decoder_dropout)

    # Wrap the generator and the GRU decoder into one autoencoder.
    autoencoder = Autoencoder(gen,
                              decoder,
                              args.device,
                              tokenizer=gen_tokenizer,
                              model_type=args.gen_model_type)

    # Experiment tracking is only set up when args.record_run is set.
    if args.record_run:
        wandb.init(project="humorgan", config=args, dir="~/wandb")
        wandb.watch((gen, dis))

    # prepare main datasets
    # Effective batch size: the per-GPU size times the GPU count (at least 1).
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    print("Batch size is {}".format(args.train_batch_size))

    if args.gen_model_type in ["gpt2", "ctrl"]:
        # don't need all types of input as the other models
        train_dataset = load_and_cache_examples_generator(args,
                                                          gen_tokenizer,
示例#24
0
# Script entry point: load the original autoencoder checkpoint, then iterate
# over the pruned checkpoints found on disk.
if __name__ == '__main__':
    hparams = {
        'in_channels': 3,
        'train_batch_size': config.train_batch_size,
        'val_batch_size': config.val_batch_size,
        'lr': config.lr,
        'type': 'rgb',
        'train_original': True,
        'map_location': "cuda:0" if torch.cuda.is_available() else "cpu"
    }

    # Restore the original model from a fixed checkpoint file name.
    checkpoint = os.path.join(
        config.ckpts_dir, 'Autoencoder/original_epoch=437-val_loss=1.40.ckpt')
    pretrained_ckpt = torch.load(checkpoint, map_location=config.device)
    pretrained_state_dict = pretrained_ckpt['state_dict']
    pretrained_model = Autoencoder(hparams=hparams)
    pretrained_model.load_state_dict(pretrained_state_dict)

    list_of_files = glob.glob(
        os.path.join(config.ckpts_dir,
                     'Autoencoder/Step/pruned-train-original/*.ckpt'))
    for i, path in enumerate(list_of_files):
        print(i)
        print(path)
        # Alpha is the text between the first '=' and the next '_' in the
        # path, as captured by the regex.
        # NOTE(review): when the pattern does not match, `alpha` stays None
        # and the string concatenation below raises TypeError.
        alpha = re.search('=(.+?)_', path)
        if alpha:
            alpha = alpha.group(1)
        print("Alpha: " + alpha)
        hparams['alpha'] = float(alpha)

        test_model = PrunedModel(hparams=hparams,
    #     threads    = 4)

    # NOTE(review): the training and test datasets are built from the same
    # path with the same batch size -- confirm this is intended.
    dataset = MNISTDataSet('../MNIST_data',
        batch_size = 96)

    test_dataset = MNISTDataSet('../MNIST_data',
        batch_size = 96)

    # Build the network on the session `sess` created outside this excerpt.
    network = Autoencoder(
        sess = sess,
        n_classes = 2,
        zed_dim = 8,
        n_kernels = 16,
        bayesian = False,
        dataset = dataset,
        input_channel = 1,
        log_dir = log_dir,
        variational = True,
        save_dir = save_dir,
        input_dims = [28,28],
        load_snapshot = False,
        learning_rate = 1e-3,
        # The comma after this argument was missing, which made the whole
        # call a syntax error.
        encoder_type = 'small',
        test_dataset = test_dataset,
        adversarial_training = True)

    ## Has to come after init_op ???
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

示例#26
0
    }

    # Logger and callbacks for training the original (unpruned) model.
    tb_logger = pl_loggers.TensorBoardLogger(save_dir=os.path.join(
        config.logs_dir, args.model),
                                             name='original')
    # NOTE(review): `early_stopping` is created but never passed to the
    # Trainer below, so it has no effect -- confirm whether it should be
    # added to the callbacks.
    early_stopping = EarlyStopping(monitor='val_loss',
                                   patience=config.patience,
                                   mode='min')
    checkpoint_callback = ModelCheckpoint(filepath=os.path.join(
        config.ckpts_dir, args.model, 'original_{epoch:02d}-{val_loss:.2f}'),
                                          monitor='val_loss',
                                          mode='min')

    # Create model object to train
    # Compare by value: `is` tests object identity and can be False for two
    # equal values that are distinct objects.
    if args.model == Model.VGG.value:
        model = VGGNet(hparams=hparams)
    elif args.model == Model.Autoencoder.value:
        hparams['type'] = 'rgb'
        model = Autoencoder(hparams=hparams)
    else:
        # Fail with a clear message instead of a NameError on `model` below.
        raise ValueError('Unsupported model: {!r}'.format(args.model))

    model.apply(utils.weights_init)
    model = model.to(torch.device(hparams['map_location']))
    utils.checkParams(model)

    trainer = pl.Trainer(logger=tb_logger,
                         gpus=1,
                         max_epochs=config.max_epochs,
                         callbacks=[checkpoint_callback])

    trainer.fit(model)