def fit(self, subpolicies, X, y):
    """Augment ``X`` with randomly assigned subpolicies, then fit the model.

    Every sample gets a random index into ``subpolicies``; each subpolicy is
    then called on the samples assigned to it and the result is written back
    into ``X``.

    NOTE: ``X`` is modified in place by the masked assignment below, so the
    caller's array holds the augmented samples afterwards.
    (Assumes ``X`` supports boolean-mask indexing, e.g. a NumPy array --
    TODO confirm against callers.)

    Args:
        subpolicies: Sequence of callables, each taking a batch of samples
            and returning the transformed batch.
        X: Training inputs.
        y: Training targets, passed to ``self.model.fit`` unchanged.

    Returns:
        ``self``, so calls can be chained.
    """
    # One random subpolicy index per sample.
    assignment = np.random.randint(len(subpolicies), size=len(X))
    for index, subpolicy in enumerate(subpolicies):
        selected = assignment == index
        X[selected] = subpolicy(X[selected])

    progress = TQDMCallback(leave_inner=False, leave_outer=False)
    # Expose the batch hooks under the ``on_train_batch_*`` names as well
    # (presumably the names the training loop invokes -- confirm against the
    # Keras version in use).
    progress.on_train_batch_begin = progress.on_batch_begin
    progress.on_train_batch_end = progress.on_batch_end

    self.model.fit(
        X,
        y,
        CHILD_BATCH_SIZE,
        CHILD_EPOCHS,
        verbose=0,
        callbacks=[progress],
    )
    return self
def train_model(model, space='K', n=1):
    """Train ``model`` with the data generators selected by ``space`` and ``n``.

    Generator selection:
      * ``space == 'K'``: ``train_gen_k`` / ``val_gen_k`` (``n`` is not
        consulted for the selection).
      * ``space == 'I'`` and ``n == 2``: ``train_gen_last`` / ``val_gen_last``.
      * ``space == 'I'`` and ``n == 1``: ``train_gen_i`` / ``val_gen_i``.

    Checkpoints are written under ``checkpoints/`` and TensorBoard logs under
    ``logs/<run_id>``, where ``run_id`` embeds ``space``, ``n``, ``AF`` and
    the current Unix timestamp.

    Args:
        model: Object exposing ``summary`` and ``fit_generator``.
        space: ``'K'`` or ``'I'``; selects the generator pair and is part of
            the run id.
        n: Selects the generator pair when ``space == 'I'`` (1 or 2) and is
            part of the run id.

    Returns:
        The same ``model`` object, after ``fit_generator`` has run.

    Raises:
        ValueError: If ``(space, n)`` does not map to a generator pair.
    """
    # Resolve the generators before anything else so an unsupported
    # combination fails immediately with a clear message, instead of an
    # UnboundLocalError once the callbacks have already been created.
    if space == 'K':
        train_gen, val_gen = train_gen_k, val_gen_k
    elif space == 'I' and n == 2:
        train_gen, val_gen = train_gen_last, val_gen_last
    elif space == 'I' and n == 1:
        train_gen, val_gen = train_gen_i, val_gen_i
    else:
        raise ValueError(
            f"Unsupported combination space={space!r}, n={n!r}; "
            "expected space='K', or space='I' with n in (1, 2)."
        )

    print(model.summary(line_length=150))
    run_id = f'kikinet_sep_{space}{n}_af{AF}_{int(time.time())}'
    chkpt_path = f'checkpoints/{run_id}' + '-{epoch:02d}.hdf5'
    print(run_id)

    chkpt_cback = ModelCheckpoint(chkpt_path, period=n_epochs // 2)
    log_dir = op.join('logs', run_id)
    tboard_cback = TensorBoard(
        profile_batch=0,
        log_dir=log_dir,
        histogram_freq=0,
        write_graph=True,
        write_images=False,
    )
    lrate_cback = LearningRateScheduler(learning_rate_from_epoch)
    tqdm_cb = TQDMCallback(metric_format="{name}: {value:e}")
    # Expose the batch hooks under the ``on_train_batch_*`` names as well
    # (presumably the names the training loop invokes -- confirm against the
    # Keras version in use).
    tqdm_cb.on_train_batch_begin = tqdm_cb.on_batch_begin
    tqdm_cb.on_train_batch_end = tqdm_cb.on_batch_end

    model.fit_generator(
        train_gen,
        steps_per_epoch=n_volumes_train,
        epochs=n_epochs,
        validation_data=val_gen,
        validation_steps=1,
        verbose=0,
        callbacks=[
            tqdm_cb,
            tboard_cback,
            chkpt_cback,
            lrate_cback,
        ],
        # max_queue_size=35,
        use_multiprocessing=True,
        workers=35,
        shuffle=True,
    )
    return model
run_id = f'cascadenet_af{AF}_{int(time.time())}' chkpt_path = f'checkpoints/{run_id}' + '-{epoch:02d}.hdf5' print(run_id) chkpt_cback = ModelCheckpoint(chkpt_path, period=100, save_weights_only=True) log_dir = op.join('logs', run_id) tboard_cback = TensorBoard( profile_batch=0, log_dir=log_dir, histogram_freq=0, write_graph=True, write_images=False, ) tqdm_cb = TQDMCallback(metric_format="{name}: {value:e}") tqdm_cb.on_train_batch_begin = tqdm_cb.on_batch_begin tqdm_cb.on_train_batch_end = tqdm_cb.on_batch_end model = cascade_net(lr=1e-3, **run_params) print(model.summary(line_length=150)) model.fit_generator( train_gen, steps_per_epoch=n_volumes_train, epochs=n_epochs, validation_data=val_gen, validation_steps=1, verbose=0, callbacks=[ tqdm_cb, tboard_cback, chkpt_cback,