Example #1
# imports inferred from the call sites below; CLAP, PairTextSpectrogramDataset and
# pair_text_spectrogram_dataset_collate_fn are project-local and assumed to be
# importable from the surrounding codebase; DataLoader is assumed to be torch's
from jax import random, jit, value_and_grad, tree_util
from optax import (chain, clip_by_global_norm, scale_by_adam,
                   add_decayed_weights, scale, apply_updates)
from torch.utils.data import DataLoader


def train(*, data_folder, batch_size, epochs, learning_rate, weight_decay,
          seed, max_norm, text_vocab, text_dim, text_depth, text_heads,
          audio_dim, audio_depth, audio_heads):
    # rng

    rng_key = random.PRNGKey(seed)

    # data

    dataset = PairTextSpectrogramDataset(data_folder)
    dl = DataLoader(dataset,
                    batch_size=batch_size,
                    collate_fn=pair_text_spectrogram_dataset_collate_fn,
                    drop_last=True,
                    shuffle=True)

    # model

    model = CLAP(text_vocab=text_vocab,
                 text_dim=text_dim,
                 text_depth=text_depth,
                 text_heads=text_heads,
                 audio_dim=audio_dim,
                 audio_depth=audio_depth,
                 audio_heads=audio_heads)

    # optimizer

    # weight-decay mask: True for every parameter with more than one dimension,
    # so 1-D parameters (biases, norm scales) are excluded from decay
    exclude_bias = lambda params: tree_util.tree_map(lambda x: x.ndim != 1,
                                                     params)

    # AdamW-style update: clip -> Adam scaling -> decoupled weight decay -> step
    optim = chain(clip_by_global_norm(max_norm), scale_by_adam(eps=1e-4),
                  add_decayed_weights(weight_decay, exclude_bias),
                  scale(-learning_rate))

    # init

    audio, audio_mask, text, text_mask = next(iter(dl))

    params = model.init(rng_key, text, audio, text_mask, audio_mask)
    optim_state = optim.init(params)

    # loss function, for use with value_and_grad

    @jit
    @value_and_grad
    def loss_fn(params, text, audio, text_mask, audio_mask):
        return model.apply(params, text, audio, text_mask, audio_mask)

    # train loop

    for _ in range(epochs):
        for audio, audio_mask, text, text_mask in dl:
            loss, grads = loss_fn(params, text, audio, text_mask, audio_mask)
            updates, optim_state = optim.update(grads, optim_state, params)
            params = apply_updates(params, updates)
            print(f'loss: {loss}')
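
The exclude_bias mask passed to add_decayed_weights above is a callable that marks every parameter with more than one dimension, so weight decay hits weight matrices but skips biases. A minimal standalone sketch of that mask convention; the toy parameter tree below is invented for illustration:

import jax.numpy as jnp
from jax import tree_util
from optax import add_decayed_weights

# toy parameter tree: a 2-D weight matrix and a 1-D bias
params = {'w': jnp.ones((2, 3)), 'b': jnp.ones((3,))}

# same convention as train(): True = apply weight decay to this parameter
mask_fn = lambda p: tree_util.tree_map(lambda x: x.ndim != 1, p)

decay = add_decayed_weights(1e-2, mask_fn)
state = decay.init(params)

# with zero incoming gradients, the update is 1e-2 * w for 'w' and zero for 'b'
zero_grads = tree_util.tree_map(jnp.zeros_like, params)
updates, state = decay.update(zero_grads, state, params)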
Example #2
    def _create_jax_optimizer(self):
        import optax
        process = []
        if isinstance(self.learning_rate, LearningRateSchedule):
            # learning-rate schedule: scale by the schedule now, negate last
            scheduler = self.learning_rate._create_jax_schedule()
            process.append(optax.scale_by_schedule(scheduler))
            last_process = optax.scale(-1.0)
        else:
            # fixed learning rate folded into the final scaling step
            lr = self.learning_rate
            last_process = optax.scale(-1.0 * lr)

        # Adam scaling, then decoupled weight decay, then the descent step
        process.append(
            optax.scale_by_adam(b1=self.beta1,
                                b2=self.beta2,
                                eps=self.epsilon,
                                eps_root=0.0))
        process.append(optax.add_decayed_weights(self.weight_decay, None))
        process.append(last_process)
        return optax.chain(*process)
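
The method above just assembles and returns a plain optax GradientTransformation, so it is consumed with the usual init/update/apply_updates cycle. A minimal sketch of that cycle with a stand-in chain (fixed learning-rate branch) and toy pytrees; the hyperparameter values here are illustrative, not the wrapper's defaults:

import jax.numpy as jnp
import optax

# stand-in for the chain returned by _create_jax_optimizer
opt = optax.chain(optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-7, eps_root=0.0),
                  optax.add_decayed_weights(1e-4, None),
                  optax.scale(-1e-3))

params = {'w': jnp.ones((4, 4)), 'b': jnp.zeros((4,))}   # toy parameters
grads = {'w': jnp.ones((4, 4)), 'b': jnp.ones((4,))}     # toy gradients

opt_state = opt.init(params)
# params must be passed to update() because add_decayed_weights needs them
updates, opt_state = opt.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)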
Example #3
    def update(
            self, gradient: Weights, state: GenericGradientState,
            parameters: Optional[Weights]
    ) -> Tuple[Weights, GenericGradientState]:
        return GenericGradientState.wrap(*add_decayed_weights(
            **asdict(self)).update(gradient, state.data, parameters))
Example #4
def init(key, X, lr):
    # initialise model parameters and state from the sample input X
    params, state = forward.init(key, X, True)
    # Adam scaling + decoupled weight decay (0.03) + negative learning-rate step
    optimizer = optax.chain(optax.scale_by_adam(),
                            optax.add_decayed_weights(0.03), optax.scale(-lr))
    opt_state = optimizer.init(params)
    return params, state, opt_state, optimizer
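
For comparison, the hand-built chain in this snippet (Adam scaling, decoupled weight decay, negative learning-rate scale) is roughly what optax.adamw packages into one call. A short sketch of the two side by side, with a placeholder learning rate:

import optax

lr = 1e-3            # placeholder learning rate
weight_decay = 0.03

# hand-assembled, as in the init() above
manual = optax.chain(optax.scale_by_adam(),
                     optax.add_decayed_weights(weight_decay),
                     optax.scale(-lr))

# packaged near-equivalent: adamw chains the same three stages internally
packaged = optax.adamw(learning_rate=lr, weight_decay=weight_decay)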
Example #5
    def init(self, parameters: Weights) -> GenericGradientState:
        return GenericGradientState(
            add_decayed_weights(**asdict(self)).init(parameters))
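
Examples #3 and #5 read like methods of a small dataclass whose fields mirror the keyword arguments of optax.add_decayed_weights, which is what lets asdict(self) be splatted into the factory; Weights and GenericGradientState come from the surrounding project. A self-contained sketch of that wrapper pattern with plain return values instead of the project's state class (all names below are illustrative):

from dataclasses import dataclass, asdict
from typing import Any, Optional

import jax.numpy as jnp
from optax import add_decayed_weights


@dataclass
class AddDecayedWeights:
    # fields mirror optax.add_decayed_weights' keyword arguments
    weight_decay: float = 0.0
    mask: Optional[Any] = None

    def init(self, parameters):
        return add_decayed_weights(**asdict(self)).init(parameters)

    def update(self, gradient, state, parameters):
        return add_decayed_weights(**asdict(self)).update(gradient, state, parameters)


params = {'w': jnp.ones((2, 2))}
tx = AddDecayedWeights(weight_decay=0.01)
state = tx.init(params)
updates, state = tx.update({'w': jnp.zeros((2, 2))}, state, params)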