コード例 #1
0
ファイル: example.py プロジェクト: nordmtr/DomainAdaptation
    # Set up an LR scheduler and a Trainer around `model` with the DANN loss,
    # then launch training with SGD, periodic validation and logging callbacks.
    scheduler = LRSchedulerSGD()
    tr = Trainer(model, loss_DANN)
    # The two positional arguments are presumably the source- and target-domain
    # training generators (per the `_s` / `_t` suffixes) — confirm against Trainer.fit.
    tr.fit(train_gen_s,
           train_gen_t,
           n_epochs=dann_config.N_EPOCHS,
           validation_data=[val_gen_s, val_gen_t],
           metrics=[acc],
           steps_per_epoch=dann_config.STEPS_PER_EPOCH,
           val_freq=dann_config.VAL_FREQ,
           opt='sgd',
           # NOTE(review): 'lr' and 'momentum' are hard-coded literals here, while
           # the other run settings are read from dann_config — confirm intended.
           opt_kwargs={
               'lr': 0.01,
               'momentum': 0.9
           },
           lr_scheduler=scheduler,
           callbacks=[
               # Names passed as `watch`; presumably the values printed during
               # training — verify against print_callback.
               print_callback(watch=[
                   "loss", "domain_loss", "val_loss", "val_domain_loss",
                   'trg_metrics', 'src_metrics'
               ]),
               # Second argument comes from dann_config.SAVE_MODEL_FREQ.
               ModelSaver('DANN', dann_config.SAVE_MODEL_FREQ),
               # extra_losses maps a group name to a list of loss keys; presumably
               # each group is recorded/plotted together — TODO confirm in HistorySaver.
               HistorySaver('log_with_sgd',
                            dann_config.VAL_FREQ,
                            extra_losses={
                                'domain_loss':
                                ['domain_loss', 'val_domain_loss'],
                                'train_domain_loss':
                                ['domain_loss_on_src', 'domain_loss_on_trg']
                            })
           ])
コード例 #2
0
 val_freq=dann_config.VAL_FREQ,
 opt='sgd',
 opt_kwargs={
     'lr': dann_config.LR,
     'momentum': 0.9
 },
 lr_scheduler=scheduler,
 callbacks=[
     print_callback(watch=[
         "loss", "domain_loss", "val_loss", "val_domain_loss",
         'trg_metrics', 'src_metrics'
     ]),
     ModelSaver(
         str(experiment_name + '_' + dann_config.SOURCE_DOMAIN +
             '_' + dann_config.TARGET_DOMAIN + '_' + details_name),
         dann_config.SAVE_MODEL_FREQ,
         save_by_schedule=True,
         save_best=True,
         eval_metric='accuracy'),
     WandbCallback(
         config=dann_config,
         name=str(dann_config.SOURCE_DOMAIN + "_" +
                  dann_config.TARGET_DOMAIN + "_" + details_name),
         group=experiment_name),
     HistorySaver(
         str(experiment_name + '_' + dann_config.SOURCE_DOMAIN +
             '_' + dann_config.TARGET_DOMAIN + "_" + details_name),
         dann_config.VAL_FREQ,
         path=str('_log/' + experiment_name + "_" + details_name),
         extra_losses={
             'domain_loss': ['domain_loss', 'val_domain_loss'],