Example 1
 def __init__(self,
              preprocessor,
              architecture=Architecture(),
              loss=tversky_loss,
              metrics=[dice_soft],
              learning_rate=0.0001,
              batch_queue_size=2,
              workers=1,
              multi_gpu=False):
     # Identify data parameters
     self.three_dim = preprocessor.data_io.interface.three_dim
     self.channels = preprocessor.data_io.interface.channels
     self.classes = preprocessor.data_io.interface.classes
     # Cache parameter
     self.preprocessor = preprocessor
     self.loss = loss
     self.metrics = metrics
     self.learning_rate = learning_rate
     self.batch_queue_size = batch_queue_size
     self.workers = workers
     # Build model with multiple GPUs (MirroredStrategy)
     if multi_gpu:
         strategy = MirroredStrategy(
             cross_device_ops=HierarchicalCopyAllReduce())
         with strategy.scope():
             self.build_model(architecture)
     # Build model with single GPU
     else:
         self.build_model(architecture)
     # Cache starting weights
     self.initialization_weights = self.model.get_weights()
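
All of the examples collected here share one core pattern: create the strategy once, then build (and compile) the model inside `strategy.scope()` so its variables are mirrored across devices. A minimal, self-contained sketch of that pattern follows; the model and data are placeholders, not taken from the example above.

import numpy as np
import tensorflow as tf

# Build the strategy once; variable creation must happen inside its scope.
strategy = tf.distribute.MirroredStrategy()
print('Number of devices:', strategy.num_replicas_in_sync)

with strategy.scope():
    # Any variables created here (layers, optimizer slots) are mirrored.
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation='relu', input_shape=(32,)),
        tf.keras.layers.Dense(10, activation='softmax'),
    ])
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

# fit() itself can run outside the scope; Keras distributes the batches.
x = np.random.rand(256, 32).astype('float32')
y = np.random.randint(0, 10, size=(256,))
model.fit(x, y, batch_size=64, epochs=1)
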
Example 2
def build_model(params, seq_length):
    """
    Define a Keras graph model with DNA sequence as input.
    Parameters:
        params (class): A class with a set of hyper-parameters
        seq_length (int): Length of input sequences
    Returns:
        model (keras model): A keras model
    """
    # MirroredStrategy to employ all available GPUs
    mirrored_strategy = MirroredStrategy()
    with mirrored_strategy.scope():
        seq_input = Input(shape=(
            seq_length,
            4,
        ), name='seq')
        xs = Conv1D(filters=params.n_filters,
                    kernel_size=params.filter_size,
                    activation='relu')(seq_input)
        xs = BatchNormalization()(xs)
        xs = MaxPooling1D(padding="same",
                          strides=params.pooling_stride,
                          pool_size=params.pooling_size)(xs)
        xs = LSTM(32)(xs)
        # adding a specified number of dense layers
        for idx in range(params.dense_layers):
            xs = Dense(params.dense_layer_size, activation='relu')(xs)
            xs = Dropout(params.dropout)(xs)
        result = Dense(1, activation='sigmoid')(xs)
        model = Model(inputs=seq_input, outputs=result)
    return model
def dynamic_LSTM(obj):
    mirrored_strategy = MirroredStrategy()
    with mirrored_strategy.scope():
        METRICS = [
            keras.metrics.CategoricalAccuracy(name='accuracy'),
            keras.metrics.Precision(name='precision'),
            keras.metrics.Recall(name='recall'),
            keras.metrics.AUC(name='auc')
        ]
        model = keras.models.Sequential(name='dynamic_LSTM')
        model.add(
            Embedding(input_dim=obj.num_words,
                      output_dim=128,
                      input_length=None,
                      mask_zero=True))
        model.add(LSTM(units=128, dropout=.2))
        model.add(Dense(units=128, activation='relu'))
        model.add(Dropout(.4))
        model.add(Dense(units=64, activation='relu'))
        model.add(Dropout(.2))
        model.add(Dense(units=21, activation='softmax', dtype='float32'))
        model.compile(optimizer=keras.optimizers.Adam(),
                      loss='categorical_crossentropy',
                      metrics=METRICS)
        model.summary()
    return model
Example 5
def main(args: Namespace) -> None:
    """Run the main program.

    Arguments:
        args: The object containing the commandline arguments
    """
    config = load_config(args.config)

    strategy = MirroredStrategy()
    if config.mixed_precision:
        set_global_policy("mixed_float16")

    train_dataset, test_dataset = get_dataset(args.data_path,
                                              config.gan_batch_size)

    with strategy.scope():
        generator = get_generator(config)
        critic = get_critic(config)

        classifier = Classifier(config)
        ClassifierTrainer.load_weights(classifier, args.load_dir)

    # Save each run into a directory by its timestamp
    log_dir = setup_dirs(
        dirs=[args.save_dir],
        dirs_to_tstamp=[args.log_dir],
        config=config,
        file_name=CONFIG,
    )[0]

    trainer = GANTrainer(
        generator,
        critic,
        classifier,
        strategy,
        train_dataset,
        test_dataset,
        config=config,
        log_dir=log_dir,
        save_dir=args.save_dir,
    )
    trainer.train(
        record_steps=args.record_steps,
        log_graph=args.log_graph,
        save_steps=args.save_steps,
    )
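
Example 5 pairs MirroredStrategy with Keras mixed precision via `set_global_policy("mixed_float16")`. A minimal sketch of how the two pieces fit together, assuming TF 2.4+; the model is a placeholder, and the final layer is forced back to float32, just as the LSTM examples above do with `dtype='float32'`.

import tensorflow as tf
from tensorflow.keras import layers, mixed_precision

strategy = tf.distribute.MirroredStrategy()
# Layers compute in float16 while keeping float32 variables.
mixed_precision.set_global_policy('mixed_float16')

with strategy.scope():
    inputs = layers.Input(shape=(32,))
    x = layers.Dense(64, activation='relu')(inputs)
    # Keep the output in float32 so the softmax and loss stay numerically stable.
    outputs = layers.Dense(10, activation='softmax', dtype='float32')(x)
    model = tf.keras.Model(inputs, outputs)
    # compile() wraps the optimizer in a LossScaleOptimizer automatically.
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
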
class NN_Classifier:
  def __init__(self, latent_dim=64, activation='relu', epochs=100, batch_size=256):
    self.activation = activation
    self.batch_size = batch_size
    self.cc_loss = CategoricalCrossentropy(label_smoothing=0.2)  # model outputs softmax probabilities, not logits
    self.epochs = epochs
    self.latent_dim = latent_dim
    self.model = None
    # tf.config.experimental_connect_to_cluster(resolver)
    # tf.tpu.experimental.initialize_tpu_system(resolver)
    # self.strategy = TPUStrategy(resolver)
    self.strategy = MirroredStrategy()
    print('Number of devices: {}'.format(self.strategy.num_replicas_in_sync))

  def _compile(self, input_dim, num_classes):
    with self.strategy.scope():
      inputs = Input(shape=(input_dim,))
      middle = Dense(self.latent_dim, activation=self.activation)(inputs)
      outputs = Dense(num_classes, activation="softmax")(middle)
      self.model = Model(inputs=inputs, outputs=outputs)

      self.model.compile(
        optimizer=RMSprop(0.01),
        loss=self.cc_loss,
        metrics=['accuracy']
      )

  def fit(self, vec, topics):
    if not self.model:
      self._compile(vec.shape[1], max(topics) + 1)

    labels = tf.keras.utils.to_categorical(topics, max(topics) + 1)
    X_train, X_test, y_train, y_test = train_test_split(vec, labels)

    callbacks = [
      EarlyStopping(monitor='loss', patience=5),
      ModelCheckpoint(filepath='./saved_models/model.{epoch:02d}-{val_loss:.2f}.h5')
    ]

    self.model.fit(
      X_train,
      y_train,
      shuffle=True,
      validation_data=(X_test, y_test),
      batch_size=self.batch_size,
      epochs=self.epochs,
      callbacks=callbacks
    )

  def predict(self, vec):
    return self.model.predict(vec)

  def save(self, filename):
    self.model.save(filename)

  def load(self, filename):
    self.model = keras.models.load_model(filename)
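
The commented-out lines in `NN_Classifier.__init__` sketch the TPU alternative to MirroredStrategy. Filled in below as a hedged sketch; the resolver argument depends on the runtime, and `tpu=''` is what managed environments such as Colab typically accept.

import tensorflow as tf

# Resolve and initialize the TPU system, then build the strategy from it.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
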
Example 7
def add_new_layers(base_model_path, seq_len, no_of_chromatin_tracks, bin_size):
    """
    Takes a pre-existing M-SEQ (Definition in README) & adds structure to \
    use it as part of a bimodal DNA sequence + prior chromatin network
    Parameters:
        base_model (keras Model): A pre-trained sequence-only (M-SEQ) model
        chrom_size (int) : The expected number of chromatin tracks
    Returns:
        model: a Keras Model
    """

    def permute(x):
        return K.permute_dimensions(x, (0, 2, 1))

    mirrored_strategy = MirroredStrategy()
    with mirrored_strategy.scope():
        # load basemodel
        base_model = load_model(base_model_path)
        # Transfer from a pre-trained M-SEQ
        curr_layer = base_model.get_layer(name='dense_2')
        curr_tensor = curr_layer.output
        xs = Dense(1, name='MSEQ-dense-new', activation='tanh')(curr_tensor)

        # Defining a M-C sub-network
        chrom_input = Input(shape=(no_of_chromatin_tracks * int(seq_len/bin_size),), name='chrom_input')
        ci = Reshape((no_of_chromatin_tracks, int(seq_len/bin_size)),
                     input_shape=(no_of_chromatin_tracks * int(seq_len/bin_size),))(chrom_input)
        # Permuting the input dimensions to match Keras input requirements:
        permute_func = Lambda(permute)
        ci = permute_func(ci)
        xc = Conv1D(15, 1, padding='valid', activation='relu', name='MC-conv1d')(ci)
        xc = LSTM(5, activation='relu', name='MC-lstm')(xc)
        xc = Dense(1, activation='tanh', name='MC-dense')(xc)

        # Concatenating sequence (MSEQ) and chromatin (MC) networks:
        merged_layer = concatenate([xs, xc])
        result = Dense(1, activation='sigmoid', name='MSC-dense')(merged_layer)
        model = Model(inputs=[base_model.input, chrom_input], outputs=result)
    return model, base_model
def seqvector_LSTM(obj):
    mirrored_strategy = MirroredStrategy()
    with mirrored_strategy.scope():
        METRICS = [
            keras.metrics.CategoricalAccuracy(name='accuracy'),
            keras.metrics.Precision(name='precision'),
            keras.metrics.Recall(name='recall'),
            keras.metrics.AUC(name='auc')
        ]
        model = keras.models.Sequential(name='seqvector_LSTM')
        model.add(
            Masking(mask_value=0., input_shape=obj.train_set[0].shape[1:]))
        model.add(Bidirectional(LSTM(units=128, dropout=.2)))
        model.add(Dense(units=128, activation='relu'))
        model.add(Dropout(.4))
        model.add(Dense(units=64, activation='relu'))
        model.add(Dropout(.2))
        model.add(Dense(units=21, activation='softmax', dtype='float32'))
        model.compile(optimizer=keras.optimizers.Adam(),
                      loss='categorical_crossentropy',
                      metrics=METRICS)
        model.summary()
        return model
Example 9
def main(args: Namespace) -> None:
    """Run the main program.

    Arguments:
        args: The object containing the commandline arguments
    """
    config = load_config(args.config)

    strategy = MirroredStrategy()
    if config.mixed_precision:
        set_global_policy("mixed_float16")

    train_dataset, test_dataset = get_dataset(args.data_path,
                                              config.cls_batch_size)

    with strategy.scope():
        model = Classifier(config)

    # Save each run into a directory by its timestamp.
    log_dir = setup_dirs(
        dirs=[args.save_dir],
        dirs_to_tstamp=[args.log_dir],
        config=config,
        file_name=CONFIG,
    )[0]

    trainer = ClassifierTrainer(model, strategy, config=config)
    trainer.train(
        train_dataset,
        test_dataset,
        log_dir=log_dir,
        record_eps=args.record_eps,
        save_dir=args.save_dir,
        save_steps=args.save_steps,
        log_graph=args.log_graph,
    )
Example 10
 def gpu_strategy(self):
     if len(self.gpu) == 1:
         return OneDeviceStrategy(device="/gpu:0")
     else:
         return MirroredStrategy()
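
Example 10 selects a strategy from a stored GPU list. Below is a related, self-contained variant that derives the choice from the devices TensorFlow actually detects and falls back to the default strategy on CPU-only hosts; `get_strategy_for_hardware` is a hypothetical helper name, not part of the example above.

import tensorflow as tf

def get_strategy_for_hardware():
    # Hypothetical helper: pick a strategy from the GPUs TensorFlow can see.
    gpus = tf.config.list_physical_devices('GPU')
    if len(gpus) > 1:
        return tf.distribute.MirroredStrategy()
    if len(gpus) == 1:
        return tf.distribute.OneDeviceStrategy(device='/gpu:0')
    return tf.distribute.get_strategy()  # default strategy on CPU-only hosts
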
Example 12
class DiscoGAN():
    
    def __init__(self, input_shape, **params):

        def convert_kwargs(kwgs):

            if kwgs is None:
                return {}

            kwgs_dict = {}

            for i in range(1, len(kwgs), 2):
                try:
                    kwgs_dict[kwgs[i-1]] = ast.literal_eval(kwgs[i])
                except ValueError:
                    kwgs_dict[kwgs[i-1]] = kwgs[i]
            
            return kwgs_dict

        def load_model_from_path(path, project_name=None, key=None):

            if path[:5] == 'gs://':
                if project_name is None:
                    fs = GCSFileSystem()
                else:
                    fs = GCSFileSystem(project_name)
                file = fs.open(path)
            else:
                file = path

            return load_model(file, custom_objects={'Swish': Swish, 'InstanceNormalization': InstanceNormalization})

        self.strategy = MirroredStrategy()
        print(f'Number of devices detected: {self.strategy.num_replicas_in_sync}')

        self.input_shape = input_shape
        self.latent_dims = params.get('latent_dims', 1000)

        self.key = params.get('key', None)
        self.project_name = params.get('project_name', None)

        if params['d_fp'] is not None:
            self.discriminator = load_model_from_path(params['d_fp'])
        else:
            with self.strategy.scope():
                self.discriminator = self.make_discriminator()
        
        reshape_dims = K.int_shape(self.discriminator.layers[-3].output)[1:]
        
        if params['g_fp'] is not None:
            self.generator = load_model_from_path(params['g_fp'])
        else:
            with self.strategy.scope():
                self.generator = self.make_generator(reshape_dims)
        
        d_lr = params.get('d_lr', 4e-4)
        g_lr = params.get('g_lr', 4e-4)

        with self.strategy.scope():
            self.d_optimizer = getattr(import_module('tensorflow.keras.optimizers'), params['d_opt'])
            self.d_optimizer = self.d_optimizer(d_lr, **convert_kwargs(params['d_opt_params']))
        
        with self.strategy.scope():
            self.g_optimizer = getattr(import_module('tensorflow.keras.optimizers'), params['g_opt'])
            self.g_optimizer = self.g_optimizer(g_lr, **convert_kwargs(params['g_opt_params']))

        if params['print_summaries']:
            print(self.discriminator.summary())
            print(self.generator.summary())

    def make_generator(self, reshape_dims):

        def se_conv(x, filters, k_size, strides, padding, reg):

            x = Conv2DTranspose(filters, k_size, strides, padding, kernel_regularizer=reg, bias_regularizer=reg, use_bias=True) (x)
            
            shortcut = x
            
            x = Swish(True) (x)
            x = InstanceNormalization() (x)
            
            x = Conv2DTranspose(filters, k_size, strides, padding, kernel_regularizer=reg, bias_regularizer=reg, use_bias=True) (x)
            x = InstanceNormalization() (x)
                
            se = AveragePooling2D(K.int_shape(x)[2]) (x)
            se = Conv2D(max(filters//16, 1), 1) (se)  # squeeze bottleneck: at least 1 filter
            se = Swish() (se)
            se = Conv2D(filters, 1, activation='sigmoid') (se)
            
            x = Multiply() ([x, se])
                
            x = Add() ([x, shortcut])
            x = Swish(True) (x)
            
            return x
        
        g_input = Input((self.latent_dims,))

        self.g_reg = l2()
        
        x = Dense(np.prod(reshape_dims), kernel_regularizer=self.g_reg, bias_regularizer=self.g_reg) (g_input)
        x = Swish(True) (x)
        
        x = Reshape(reshape_dims) (x)
        
        x = Conv2DTranspose(256, 3, strides=2, padding='same', kernel_regularizer=self.g_reg, bias_regularizer=self.g_reg) (x)
        x = Swish(True) (x)
        x = InstanceNormalization() (x)
        
        x = Conv2DTranspose(192, 3, strides=2, padding='same', kernel_regularizer=self.g_reg, bias_regularizer=self.g_reg) (x)
        x = Swish(True) (x)
        x = InstanceNormalization() (x)
        
        x = se_conv(x, 128, 3, 1, 'same', self.g_reg)
        x = se_conv(x, 86, 3, 1, 'same', self.g_reg)
        x = se_conv(x, 64, 3, 1, 'same', self.g_reg)

        x = Conv2DTranspose(32, 3, strides=2, padding='same', kernel_regularizer=self.g_reg, bias_regularizer=self.g_reg) (x)
        x = Swish(True) (x)
        x = InstanceNormalization() (x)

        x = Conv2DTranspose(16, 3, strides=2, padding='same', kernel_regularizer=self.g_reg, bias_regularizer=self.g_reg) (x)
        x = Swish(True) (x)
        x = InstanceNormalization() (x)
        
        g_output = Conv2D(self.input_shape[-1], 1, padding='same', activation='tanh') (x)

        return Model(g_input, g_output, name='Generator')
    
    def make_discriminator(self):
        
        d_input = Input(self.input_shape)

        self.d_reg = l2()
        
        x = Conv2D(16, 3, padding='same', kernel_regularizer=self.d_reg, bias_regularizer=self.d_reg) (d_input)
        x = Swish(True) (x)
        x = InstanceNormalization() (x)
        
        x = Conv2D(32, 3, strides=2, padding='same', kernel_regularizer=self.d_reg, bias_regularizer=self.d_reg) (x)
        x = Swish(True) (x)
        x = InstanceNormalization() (x)
        
        x = Conv2D(64, 3, strides=2, padding='same', kernel_regularizer=self.d_reg, bias_regularizer=self.d_reg) (x)
        x = Swish(True) (x)
        x = InstanceNormalization() (x)
        
        x = Conv2D(128, 3, strides=2, padding='same', kernel_regularizer=self.d_reg, bias_regularizer=self.d_reg) (x)
        x = Swish(True) (x)
        x = InstanceNormalization() (x)

        x = Conv2D(256, 3, strides=2, padding='same', kernel_regularizer=self.d_reg, bias_regularizer=self.d_reg) (x)
        x = Swish(True) (x)
        x = InstanceNormalization() (x)
        
        
        x = Flatten() (x)
        d_output = Dense(1, activation='sigmoid') (x)
        
        return Model(d_input, d_output, name='Discriminator')
    
    def train(self, data_dir, **hparams):

        print(f'Loading images from {data_dir}')
        X = load_npz(data_dir, self.project_name, self.key)

        epochs = hparams.get('epochs', 1)
        batch_size = hparams.get('batch_size', 128)

        global_batch_size = batch_size * self.strategy.num_replicas_in_sync

        d_reg_C = hparams.get('d_initial_reg', 1e-2)
        g_reg_C = hparams.get('g_initial_reg', 1e-2)

        self.d_reg.l2 = -d_reg_C
        self.g_reg.l2 = g_reg_C

        d_min_reg = hparams.get('d_min_reg', 1e-4)
        g_min_reg = hparams.get('g_min_reg', 1e-4)

        q_max = hparams.get('max_q_size', 25)
        inc = hparams.get('q_update_inc', 10)

        plot_dims = hparams.get('plot_dims', (1, 5))
        plot_dir = hparams.get('plot_dir', '')
        plot_tstep = hparams.get('plot_tstep', 1)

        if plot_dir != '' and not os.path.isdir(plot_dir):
            os.makedirs(plot_dir)

        X = X.transpose(0, 2, 1, 3) # Flip rows and cols
        X = X/127.5 - 1

        steps, r = divmod(X.shape[0], batch_size)
        steps += 1

        m = steps//q_max
        
        real = np.ones((batch_size, 1))
        real_remainder = np.ones((r, 1))
        fake = np.zeros((batch_size, 1))
        fake_remainder = np.zeros((r, 1))

        full_inds = np.arange(X.shape[0])
        
        epoch = 0
        t = 1
        d_queue = [self.discriminator]
        g_queue = [self.generator]

        gen_r_labels = np.zeros((len(d_queue), 1))

        mean_loss_k = brenth(lambda k: sum([np.e**(-k*x) for x in range(1, steps+1)])-1, 0, 3)

        with self.strategy.scope():

            # The discriminator ends in a sigmoid, so its outputs are probabilities
            d_loss_object = BinaryCrossentropy(from_logits=False, reduction=Reduction.NONE)
            g_loss_object = BinaryCrossentropy(from_logits=False, reduction=Reduction.NONE)

            d_loss_current = np.inf
            g_loss_current = np.inf
            
            def compute_loss(true, preds, loss_object):
                batch_loss = loss_object(true, preds)
                return compute_average_loss(batch_loss, global_batch_size=global_batch_size)

        def update_queue(queue, var, t, m, K, inc):

            if t == m:
                if len(queue) <= K:
                    del queue[-1]
                queue.append(var)
                m += inc
            else:
                queue[-1] = var
            
            return m, queue

        def disc_train_step(g_queue, noise_size, batch, r_labels, f_labels):

            nonlocal d_loss_current

            noise = np.random.normal(0, 1, (noise_size, self.latent_dims))

            ims_arr = []
            val_arr = []
            for gen in g_queue:
                ims_arr.extend(gen(noise))
                val_arr.extend(f_labels)
            
            ims_arr.extend(batch)
            val_arr.extend(r_labels)

            ims_arr = np.array(ims_arr)
            val_arr = np.array(val_arr)
            shuffle_unison(ims_arr, val_arr)

            with GradientTape() as d_tape:

                preds = self.discriminator(ims_arr)
                d_loss = compute_loss(val_arr, preds, d_loss_object)

            d_loss_current = d_loss

            d_grad = d_tape.gradient(d_loss, self.discriminator.trainable_weights)
            self.d_optimizer.apply_gradients(zip(d_grad, self.discriminator.trainable_weights))

            return d_loss

        def gen_train_step(d_queue, noise_size, gen_r_labels):

            nonlocal g_loss_current

            with GradientTape() as g_tape:

                gen_ims = self.generator(np.random.normal(0, 1, (noise_size, self.latent_dims)))

                preds = []
                for disc in d_queue:
                    preds.extend(disc(gen_ims))
                preds = K.stack(preds)

                g_loss = compute_loss(gen_r_labels, preds, g_loss_object)

            g_loss_current = g_loss

            g_grad = g_tape.gradient(g_loss, self.generator.trainable_weights)
            self.g_optimizer.apply_gradients(zip(g_grad, self.generator.trainable_weights))
            
            return g_loss

        @tf_func
        def dist_train_step(step_func, args):
            per_replica_losses = self.strategy.run(step_func, args=args)
            return self.strategy.reduce(ReduceOp.SUM, per_replica_losses, axis=None)

        while epoch < epochs:
            
            np.random.shuffle(full_inds)
            
            g_loss_total = 0
            d_loss_total = 0
            
            with trange(steps) as bprogbar:
                for i in bprogbar:

                    bprogbar.set_description(f'Epoch {epoch+1}/{epochs}')
                    
                    if i < steps-1:
                        batch = X[full_inds[i*batch_size:(i+1)*batch_size]]
                        r_labels = real
                        f_labels = fake
                        noise_size = batch_size
                    else:
                        batch = X[full_inds[-r:]]
                        r_labels = real_remainder
                        f_labels = fake_remainder
                        noise_size = r
                    
                    dist_train_step(disc_train_step, (g_queue, noise_size, batch, r_labels, f_labels))
                    dist_train_step(gen_train_step, (d_queue, noise_size, gen_r_labels))

                    m, d_queue = update_queue(d_queue, clone_model(self.discriminator), t, m, q_max, inc)
                    m, g_queue = update_queue(g_queue, clone_model(self.generator), t, m, q_max, inc)

                    if t == m and len(d_queue) < q_max:
                        gen_r_labels = np.zeros((len(d_queue), 1))

                    bprogbar.set_postfix(d_loss=f'{d_loss_current:.4f}', g_loss=f'{g_loss_current:.4f}')
                    
                    d_loss_total += d_loss_current * np.e**(-mean_loss_k*(steps-i))
                    g_loss_total += g_loss_current * np.e**(-mean_loss_k*(steps-i))

                    t += 1

                    self.d_reg.l2 = -max(d_min_reg, d_reg_C/np.sqrt(t))
                    self.g_reg.l2 = max(g_min_reg, g_reg_C/np.sqrt(t))
                    
            epoch += 1
            
            print(f'Timestep: {t}; Average D Loss: {d_loss_total/(steps):.4f}, Average G Loss: {g_loss_total/(steps):.4f}')
            if not (epoch+1) % plot_tstep:
                self.plot_ims(plot_dims, epoch, plot_dir)
                        
    def plot_ims(self, plot_dims, epoch, save_dir='', project_name=None, key=None):
        
        r, c = plot_dims
        
        noise = np.random.normal(0, 1, (r*c, self.latent_dims))
        gen_ims = self.generator.predict(noise)
        gen_ims = np.uint8(np.transpose((gen_ims + 1) * 127.5, (0, 2, 1, 3)))
        
        fig, axs = plt.subplots(r, c)
        
        two_d = c > 1 and r > 1
        
        for i in range(r*c):
            
            if two_d:
                ax = axs[i//c][i%c]
            else:
                ax = axs[i]
                
            ax.imshow(cv2.cvtColor(gen_ims[i], cv2.COLOR_BGR2RGB))
            ax.axis('off')

        if save_dir != '':

            if save_dir[:5] == 'gs://':

                bucket, path = save_dir[5:].split('/', 1)

                client = storage.Client(credentials=key)
                bucket = client.bucket(bucket, project_name)
                # Name the blob after the file so each epoch gets its own object
                blob = bucket.blob(f'{path}/epoch_{epoch}.png')

                fig.savefig(f'epoch_{epoch}.png')

                with file_io.FileIO(f'epoch_{epoch}.png', 'rb') as png:
                    blob.upload_from_file(png)

            else:   
            
                fig.savefig(os.path.join(save_dir, f'epoch_{epoch}.png'))

        plt.show(block=False)
        plt.pause(3)
        plt.close()

            
    def save_models(self, save_dir, project_name=None, key=None):
        
        if save_dir[:5] == 'gs://':
        
            bucket, path = save_dir[5:].split('/', 1)

            client = storage.Client(credentials=key)
            bucket = client.bucket(bucket, project_name)
            # Give each model its own blob so the second upload does not
            # overwrite the first
            d_blob = bucket.blob(f'{path}/discriminator.h5')
            g_blob = bucket.blob(f'{path}/generator.h5')

            self.discriminator.save('discriminator.h5')
            self.generator.save('generator.h5')

            with file_io.FileIO('discriminator.h5', 'rb') as d_h5, file_io.FileIO('generator.h5', 'rb') as g_h5:
                d_blob.upload_from_file(d_h5)
                g_blob.upload_from_file(g_h5)

        else:

            if not os.path.isdir(save_dir):
                os.makedirs(save_dir)
        
            self.discriminator.save(os.path.join(save_dir, 'discriminator.h5'))
            self.generator.save(os.path.join(save_dir, 'generator.h5'))
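
`DiscoGAN.train` drives its custom loop through `strategy.run` and `strategy.reduce`, averaging per-example losses with `compute_average_loss`. The canonical shape of that pattern, reduced to a single model, as a sketch; the model, optimizer, and batch size are placeholders.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
GLOBAL_BATCH_SIZE = 64 * strategy.num_replicas_in_sync

with strategy.scope():
    # Placeholder model and optimizer; variables are created inside the scope.
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    optimizer = tf.keras.optimizers.Adam()
    # NONE reduction: the average over the *global* batch is taken explicitly.
    loss_object = tf.keras.losses.BinaryCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

def train_step(inputs):
    features, labels = inputs
    with tf.GradientTape() as tape:
        logits = model(features, training=True)
        per_example_loss = loss_object(labels, logits)
        loss = tf.nn.compute_average_loss(
            per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

@tf.function
def distributed_train_step(inputs):
    # Run the step on every replica, then sum the per-replica mean losses.
    per_replica_losses = strategy.run(train_step, args=(inputs,))
    return strategy.reduce(tf.distribute.ReduceOp.SUM,
                           per_replica_losses, axis=None)
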
Example 13
def main(argv):
    print_versions()
    args = argument_parser('train').parse_args(argv[1:])

    args.train_data = args.train_data.split(',')
    if args.checkpoint_steps is not None:
        os.makedirs(args.checkpoint_dir, exist_ok=True)

    strategy = MirroredStrategy()
    num_devices = strategy.num_replicas_in_sync
    # Batch datasets with global batch size (local * GPUs)
    global_batch_size = args.batch_size * num_devices

    tokenizer = get_tokenizer(args)

    label_list = load_labels(args.labels)
    label_map = { l: i for i, l in enumerate(label_list) }
    inv_label_map = { v: k for k, v in label_map.items() }

    if args.task_name not in ("NER", "RE"):
        raise ValueError("Task not found: {}".format(args.task_name))

    if args.train_data[0].endswith('.tsv'):
        if len(args.train_data) > 1:
            raise NotImplementedError('Multiple TSV inputs')

        train_data = TsvSequence(args.train_data[0], tokenizer, label_map,
                                global_batch_size, args)
        input_format = 'tsv'
    elif args.train_data[0].endswith('.tfrecord'):
        train_data = train_tfrecord_input(args.train_data, args.max_seq_length,
                                          global_batch_size)
        input_format = 'tfrecord'
    else:
        raise ValueError('--train_data must be .tsv or .tfrecord')

    if args.dev_data is None:
        dev_x, dev_y = None, None
        validation_data = None
    else:
        dev_x, dev_y = load_dataset(args.dev_data, tokenizer,
                                    args.max_seq_length,
                                    label_map, args)
        validation_data = (dev_x, dev_y)

    print('Number of devices: {}'.format(num_devices), file=sys.stderr, 
          flush=True)
    if num_devices > 1 and input_format != 'tfrecord':
        warning('TFRecord input recommended for multi-device training')

    num_train_examples = num_examples(args.train_data)
    num_labels = len(label_list)
    print('num_train_examples: {}'.format(num_train_examples),
          file=sys.stderr, flush=True)

    with strategy.scope():
        model = restore_or_create_model(num_train_examples, num_labels, 
                                        global_batch_size, args)
    model.summary(print_fn=print)

    callbacks = []
    if args.checkpoint_steps is not None:
        callbacks.append(ModelCheckpoint(
            filepath=os.path.join(args.checkpoint_dir, CHECKPOINT_NAME),
            save_freq=args.checkpoint_steps
        ))
        callbacks.append(DeleteOldCheckpoints(
            args.checkpoint_dir, CHECKPOINT_NAME, args.max_checkpoints
        ))

    if input_format == 'tsv':
        other_args = {
            'workers': 10,    # TODO
        }
    else:
        assert input_format == 'tfrecord', 'internal error'
        steps_per_epoch = int(np.ceil(num_train_examples/global_batch_size))
        other_args = {
            'steps_per_epoch': steps_per_epoch
        }

    model.fit(
        train_data,
        epochs=args.num_train_epochs,
        callbacks=callbacks,
        validation_data=validation_data,
        validation_batch_size=global_batch_size,
        **other_args
    )

    if validation_data is not None:
        probs = model.predict(dev_x, batch_size=global_batch_size)
        preds = np.argmax(probs, axis=-1)
        correct, total = sum(g==p for g, p in zip(dev_y, preds)), len(dev_y)
        print('Final dev accuracy: {:.1%} ({}/{})'.format(
            correct/total, correct, total))

    if args.model_dir is not None:
        print('Saving model in {}'.format(args.model_dir))
        save_model_etc(model, tokenizer, label_list, args)
    
    return 0
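
Example 13 batches its datasets with the global batch size (the per-replica batch size times `num_replicas_in_sync`), because Keras splits each global batch across replicas. A short sketch of that convention with `tf.data`; the dataset here is synthetic.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

per_replica_batch_size = 32
# The dataset must be batched with the global size (local size * device count),
# since each global batch is divided among the replicas.
global_batch_size = per_replica_batch_size * strategy.num_replicas_in_sync

features = tf.random.normal([1024, 16])
labels = tf.random.uniform([1024], maxval=2, dtype=tf.int32)
dataset = (tf.data.Dataset.from_tensor_slices((features, labels))
           .shuffle(1024)
           .batch(global_batch_size)
           .prefetch(tf.data.AUTOTUNE))
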
Example 14
val_labels = le.transform(data['val']['labels']).tolist()
test_labels = le.transform(data['test']['labels']).tolist()

# print number of files in each split to log
print('Num Train Files:', len(train_files))
print('Num Val Files:', len(val_files))
print('Num Test Files:', len(test_files))

# # print number of files in each split to log
# print('Num Train Labels:', len(train_labels))
# print('Num Val Labels:', len(val_labels))
# print('Num Test Labels:', len(test_labels))

#
# Parallelize for multiple gpus
strategy = MirroredStrategy()

# get number of gpus (replicas) for batch_size calculation
n_gpus = strategy.num_replicas_in_sync
print('Running', n_gpus, 'replicas in sync')

# set an azure tag for n_gpus
if args.online:
    run.tag('gpus', n_gpus)

# Batch size for generators
BATCH_SIZE = 16

#####################
## Create Datasets ##
#####################
Example 15

def main(multi_gpu=True):
    if args.viz:
        viz()
    elif args.mode == 'eval-generator':
        evaluate_generator()
    elif args.mode == 'eval-discriminator':
        evaluate_discriminator()
    elif args.mode == 'train-generator':
        if args.gan:
            train_gan()
        else:
            train_generator(batch_size=12 if num_gpus > 1 else 10)
    elif args.mode == 'train-discriminator':
        train_discriminator(batch_size=20)
    elif args.mode == 'train-paired-discriminator':
        train_paired_discriminator()
    elif args.mode == 'batch-eval':
        batch_eval()
    else:
        print('Command not recognized.')


if __name__ == '__main__':
    if args.multi_gpu:
        with MirroredStrategy().scope():
            main()
    else:
        main(multi_gpu=False)
Example 16
# print number of files in each split to log
print('Num Train Files:', len(train_files))
print('Num Val Files:', len(val_files))
print('Num Test Files:', len(test_files))

# print number of files in each split to log
print('Num Train Labels:', len(train_labels))
print('Num Val Labels:', len(val_labels))
print('Num Test Labels:', len(test_labels))

# get class names
classes = data['mapping']
n_classes = len(classes)

# Parallelize for multiple gpus
strategy = MirroredStrategy()

# get number of gpus (replicas) for batch_size calculation
n_gpus = strategy.num_replicas_in_sync
print('Running', n_gpus, 'replicas in sync')

# set an azure tag for n_gpus
if args.online:
    run.tag('gpus', n_gpus)

# Batch size for generators
BATCH_SIZE = 32

# Choose DataGenerator or AugDataGenerator
if args.augment_position or args.augment_pitch or args.augment_stretch:
    print('Creating train AugDataGenerator with pitch', args.augment_pitch,