Code example #1
def gtrobustlosstrain(parse, config: ConfigParser):
    dataset_name = config['name'].split('_')[0]
    lr_scheduler_name = config['lr_scheduler']['type']
    loss_fn_name = config['train_loss']['type']

    wandb_run_name_list = []

    if parse.distillation:
        if parse.distill_mode == 'eigen':
            wandb_run_name_list.append('distil')
        else:
            wandb_run_name_list.append('kmeans')
    else:
        wandb_run_name_list.append('baseline')
    wandb_run_name_list.append(dataset_name)
    wandb_run_name_list.append(lr_scheduler_name)
    wandb_run_name_list.append(loss_fn_name)
    wandb_run_name_list.append(str(config['trainer']['asym']))
    wandb_run_name_list.append(str(config['trainer']['percent']))
    wandb_run_name = '_'.join(wandb_run_name_list)

    if parse.no_wandb:
        wandb.init(config=config,
                   project='noisylabel',
                   entity='goguryeo',
                   name=wandb_run_name)

    # By default, pytorch utilizes multi-threaded cpu
    # Set to handle the whole procedure on a single thread
    torch.set_num_threads(1)

    logger = config.get_logger('train')

    # Set seed for reproducibility
    random.seed(config['seed'])
    torch.manual_seed(config['seed'])
    torch.cuda.manual_seed_all(config['seed'])
    torch.backends.cudnn.deterministic = True
    np.random.seed(config['seed'])

    data_loader = getattr(module_data, config['data_loader']['type'])(
        config['data_loader']['args']['data_dir'],
        batch_size=config['data_loader']['args']['batch_size'],
        shuffle=False
        if parse.distillation else config['data_loader']['args']['shuffle'],
        #         validation_split=config['data_loader']['args']['validation_split'],
        validation_split=0.0,
        num_batches=config['data_loader']['args']['num_batches'],
        training=True,
        num_workers=config['data_loader']['args']['num_workers'],
        pin_memory=config['data_loader']['args']['pin_memory'])

    # valid_data_loader = data_loader.split_validation()

    valid_data_loader = None

    # test_data_loader = None

    test_data_loader = getattr(module_data, config['data_loader']['type'])(
        config['data_loader']['args']['data_dir'],
        batch_size=128,
        shuffle=False,
        validation_split=0.0,
        training=False,
        num_workers=2).split_validation()

    # build model architecture, then print to console
    model = config.initialize('arch', module_arch)

    if parse.no_wandb:
        wandb.watch(model)

    if parse.distillation:
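        # Load the pretrained teacher, freeze it, and use it to select the training
        # indices (teacher_idx) that are passed to the rebuilt data loader below.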
        teacher = config.initialize('arch', module_arch)
        teacher.load_state_dict(
            torch.load('./checkpoint/' + parse.load_name)['state_dict'])
        if not parse.reinit:
            model.load_state_dict(
                torch.load('./checkpoint/' + parse.load_name)['state_dict'])
        for params in teacher.parameters():
            params.requires_grad = False
        if parse.distill_mode == 'eigen':
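            # 'eigen' mode: collect the teacher's per-sample outputs, estimate a
            # per-class singular-vector basis, and derive teacher_idx from it.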
            tea_label_list, tea_out_list = get_out_list(teacher, data_loader)
            singular_dict, v_ortho_dict = get_singular_value_vector(
                tea_label_list, tea_out_list)

            for key in v_ortho_dict.keys():
                v_ortho_dict[key] = v_ortho_dict[key].cuda()

            teacher_idx = singular_label(v_ortho_dict, tea_out_list,
                                         tea_label_list)
        else:
            teacher_idx = get_out_list(teacher, data_loader)

        data_loader = getattr(module_data, config['data_loader']['type'])(
            config['data_loader']['args']['data_dir'],
            batch_size=config['data_loader']['args']['batch_size'],
            shuffle=config['data_loader']['args']['shuffle'],
            #         validation_split=config['data_loader']['args']['validation_split'],
            validation_split=0.0,
            num_batches=config['data_loader']['args']['num_batches'],
            training=True,
            num_workers=config['data_loader']['args']['num_workers'],
            pin_memory=config['data_loader']['args']['pin_memory'],
            teacher_idx=teacher_idx)
    else:
        teacher = None

    # get function handles of loss and metrics
    logger.info(config.config)
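    # Per-example loss state (e.g. ELR targets, truncated GCE weights) is sized
    # with the raw example count when the dataset exposes it.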
    if hasattr(data_loader.dataset, 'num_raw_example'):
        num_examp = data_loader.dataset.num_raw_example
    else:
        num_examp = len(data_loader.dataset)

    if config['train_loss']['type'] == 'ELR_GTLoss':
        train_loss = getattr(module_loss, 'ELR_GTLoss')(
            num_examp=num_examp,
            num_classes=config['num_classes'],
            beta=config['train_loss']['args']['beta'])
    elif config['train_loss']['type'] == 'SCE_GTLoss':
        train_loss = getattr(module_loss, 'SCE_GTLoss')(
            alpha=config['train_loss']['args']['alpha'],
            beta=config['train_loss']['args']['beta'],
            num_classes=config['num_classes'])
    elif config['train_loss']['type'] == 'GCE_GTLoss':
        train_loss = getattr(module_loss, 'GCE_GTLoss')(
            q=config['train_loss']['args']['q'],
            k=config['train_loss']['args']['k'],
            trainset_size=num_examp,
            truncated=config['train_loss']['args']['truncated'])
    elif config['train_loss']['type'] == 'CCE_GTLoss':
        train_loss = getattr(module_loss, 'CCE_GTLoss')()

    val_loss = getattr(module_loss, config['val_loss'])
    metrics = [getattr(module_metric, met) for met in config['metrics']]

    # build optimizer and learning rate scheduler; delete every line containing lr_scheduler to disable the scheduler
    trainable_params = filter(lambda p: p.requires_grad, model.parameters())

    optimizer = config.initialize('optimizer', torch.optim,
                                  [{
                                      'params': trainable_params
                                  }])

    lr_scheduler = config.initialize('lr_scheduler', torch.optim.lr_scheduler,
                                     optimizer)

    # All *_GTLoss variants share the same GroundTruthTrainer setup.
    trainer = GroundTruthTrainer(model,
                                 train_loss,
                                 metrics,
                                 optimizer,
                                 config=config,
                                 data_loader=data_loader,
                                 teacher=teacher,
                                 valid_data_loader=valid_data_loader,
                                 test_data_loader=test_data_loader,
                                 lr_scheduler=lr_scheduler,
                                 val_criterion=val_loss,
                                 mode=parse.mode,
                                 entropy=parse.entropy,
                                 threshold=parse.threshold)

    trainer.train()

    logger = config.get_logger('trainer', config['trainer']['verbosity'])
    cfg_trainer = config['trainer']
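
A minimal invocation sketch for the entry point above; `config` is assumed to be the project's ConfigParser built from a JSON config file, `parse` an argparse namespace, and the attribute values below are illustrative assumptions rather than the repository's defaults.

import argparse

# Hypothetical namespace covering only the attributes gtrobustlosstrain reads.
parse = argparse.Namespace(
    distillation=False,   # no teacher-based sample selection
    distill_mode='eigen', # consulted only when distillation=True
    reinit=True,          # do not warm-start the student from the teacher checkpoint
    load_name=None,       # teacher checkpoint filename under ./checkpoint/ (needed for distillation)
    no_wandb=False,       # as written above, wandb.init() runs only when this flag is truthy
    mode=None,
    entropy=False,
    threshold=0.1,
)

# config = ConfigParser(...)  # assumption: the exact constructor depends on the
#                             # project's parse_config module
# gtrobustlosstrain(parse, config)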
Code example #2
def trainClothing1m(parse, config: ConfigParser):
    # implementation for WandB
    wandb_run_name_list = wandbRunlist(config, parse)
    wandb_run_name = '_'.join(wandb_run_name_list)

    if parse.no_wandb:
        wandb.init(config=config, project='noisylabel', entity='goguryeo', name=wandb_run_name)
    
    # By default, pytorch utilizes multi-threaded cpu;
    # the default thread count is kept as-is here
    numthread = torch.get_num_threads()
    torch.set_num_threads(numthread)
    logger = config.get_logger('train')
    
    # Set seed for reproducibility
    fix_seed(config['seed'])
    
    data_loader = getattr(module_data, config['data_loader']['type'])(
        config['data_loader']['args']['data_dir'],
        batch_size=config['data_loader']['args']['batch_size'],
        shuffle=False if parse.distillation else config['data_loader']['args']['shuffle'],
        validation_split=0.0,
        num_batches=config['data_loader']['args']['num_batches'],
        training=True,
        num_workers=config['data_loader']['args']['num_workers'],
        pin_memory=config['data_loader']['args']['pin_memory'])

    # valid_data_loader = data_loader.split_validation()

    valid_data_loader = None
    
    # test_data_loader = None

    test_data_loader = getattr(module_data, config['data_loader']['type'])(
        config['data_loader']['args']['data_dir'],
        batch_size=128,
        shuffle=False,
        validation_split=0.0,
        training=False,
        num_workers=0
    ).split_validation()

    print('---------')
    # build model architecture, then print to console
#     model = config.initialize('arch', module_arch)
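    # Clothing1M uses a pretrained ResNet-50 built directly from module_arch instead
    # of the architecture named in the config (see the commented-out call above).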
    model = getattr(module_arch, 'resnet50')(pretrained=True,
                                             num_classes=config["num_classes"])
    
    if parse.no_wandb: wandb.watch(model)
    
    if parse.distillation:
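        # A teacher network selects clean sample indices (extract_cleanidx) and the
        # train loader is rebuilt with teacher_idx set accordingly.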
        teacher = config.initialize('arch', module_arch)
        
        data_loader = getattr(module_data, config['data_loader']['type'])(
            config['data_loader']['args']['data_dir'],
            batch_size=config['data_loader']['args']['batch_size'],
            shuffle=config['data_loader']['args']['shuffle'],
            # validation_split=config['data_loader']['args']['validation_split'],
            validation_split=0.0,
            num_batches=config['data_loader']['args']['num_batches'],
            training=True,
            num_workers=config['data_loader']['args']['num_workers'],
            pin_memory=config['data_loader']['args']['pin_memory'],
            teacher_idx=extract_cleanidx(teacher, data_loader, parse))
    else:
        teacher = None

    # get function handles of loss and metrics
    logger.info(config.config)
    if hasattr(data_loader.dataset, 'num_raw_example'):
        num_examp = data_loader.dataset.num_raw_example
    else:
        num_examp = len(data_loader.dataset)
    
    if config['train_loss']['type'] == 'ELRLoss':
        train_loss = getattr(module_loss, 'ELRLoss')(num_examp=num_examp,
                                                     num_classes=config['num_classes'],
                                                     beta=config['train_loss']['args']['beta'])
    elif config['train_loss']['type'] == 'SCELoss':
        train_loss = getattr(module_loss, 'SCELoss')(alpha=config['train_loss']['args']['alpha'],
                                                     beta=config['train_loss']['args']['beta'],
                                                     num_classes=config['num_classes'])
    elif config['train_loss']['type'] == 'GCELoss':
        train_loss = getattr(module_loss, 'GCELoss')(q=config['train_loss']['args']['q'],
                                                     k=config['train_loss']['args']['k'],
                                                     trainset_size=num_examp,
                                                     truncated=config['train_loss']['args']['truncated'])
    elif config['train_loss']['type'] == 'GTLoss':
        train_loss = getattr(module_loss, 'GTLoss')()
    else:
        train_loss = getattr(module_loss, 'CCELoss')()

    print(train_loss)

    val_loss = getattr(module_loss, config['val_loss'])
    metrics = [getattr(module_metric, met) for met in config['metrics']]

    # build optimizer and learning rate scheduler; delete every line containing lr_scheduler to disable the scheduler
    trainable_params = filter(lambda p: p.requires_grad, model.parameters())

    optimizer = config.initialize('optimizer', torch.optim, [{'params': trainable_params}])

    lr_scheduler = config.initialize('lr_scheduler', torch.optim.lr_scheduler, optimizer)

    # Every loss variant shares the same RealDatasetTrainer setup.
    trainer = RealDatasetTrainer(model, train_loss, metrics, optimizer,
                                 config=config,
                                 data_loader=data_loader,
                                 parse=parse,
                                 teacher=teacher,
                                 valid_data_loader=valid_data_loader,
                                 test_data_loader=test_data_loader,
                                 lr_scheduler=lr_scheduler,
                                 val_criterion=val_loss,
                                 mode=parse.mode,
                                 entropy=parse.entropy,
                                 threshold=parse.threshold)

    trainer.train()
    
    logger = config.get_logger('trainer', config['trainer']['verbosity'])
    cfg_trainer = config['trainer']
Code example #3
def coteachingtrain(parse, config: ConfigParser):
    # implementation for WandB
    wandb_run_name_list = wandbRunlist(config, parse)
    wandb_run_name = '_'.join(wandb_run_name_list)

    if parse.no_wandb:
        wandb.init(config=config,
                   project='noisylabel',
                   entity='goguryeo',
                   name=wandb_run_name)

    # By default, pytorch utilizes multi-threaded cpu
    numthread = torch.get_num_threads()
    torch.set_num_threads(numthread)
    logger = config.get_logger('train')

    # Set seed for reproducibility
    fix_seed(config['seed'])

    data_loader = getattr(module_data, config['data_loader']['type'])(
        config['data_loader']['args']['data_dir'],
        batch_size=config['data_loader']['args']['batch_size'],
        shuffle=False
        if parse.distillation else config['data_loader']['args']['shuffle'],
        validation_split=0.0,
        num_batches=config['data_loader']['args']['num_batches'],
        training=True,
        num_workers=config['data_loader']['args']['num_workers'],
        pin_memory=config['data_loader']['args']['pin_memory'],
        seed=parse.dataseed  # parse.seed
    )

    valid_data_loader = None
    test_data_loader = getattr(module_data, config['data_loader']['type'])(
        config['data_loader']['args']['data_dir'],
        batch_size=128,
        shuffle=False,
        validation_split=0.0,
        training=False,
        num_workers=2).split_validation()

    # build model architecture, then print to console
    model = config.initialize('arch', module_arch)

    if parse.no_wandb: wandb.watch(model)

    if parse.distillation:
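        # A separately configured teacher ('teacher_arch') selects clean sample
        # indices, and the train loader is rebuilt with them as teacher_idx.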
        teacher = config.initialize('teacher_arch', module_arch)

        data_loader = getattr(module_data, config['data_loader']['type'])(
            config['data_loader']['args']['data_dir'],
            batch_size=config['data_loader']['args']['batch_size'],
            shuffle=config['data_loader']['args']['shuffle'],
            validation_split=0.0,
            num_batches=config['data_loader']['args']['num_batches'],
            training=True,
            num_workers=config['data_loader']['args']['num_workers'],
            pin_memory=config['data_loader']['args']['pin_memory'],
            seed=parse.dataseed,
            teacher_idx=extract_cleanidx(teacher, data_loader, parse))
    else:
        teacher = None

    # get function handles of loss and metrics
    logger.info(config.config)
    if hasattr(data_loader.dataset, 'num_raw_example'):
        num_examp = data_loader.dataset.num_raw_example
    else:
        num_examp = len(data_loader.dataset)

    # F-coteaching
    if config['train_loss']['type'] == 'CCELoss':
        train_loss = getattr(module_loss, 'CCELoss')()

    # coteaching
    elif config['train_loss']['type'] == 'CoteachingLoss':
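        # forget_rate comes from the configured noise level ('percent');
        # num_gradual and n_epoch control how the drop rate is scheduled over training.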
        train_loss = getattr(module_loss, 'CoteachingLoss')(
            forget_rate=config['trainer']['percent'],
            num_gradual=int(config['train_loss']['args']['num_gradual']),
            n_epoch=config['trainer']['epochs'])

    # coteaching_plus
    elif config['train_loss']['type'] == 'CoteachingPlusLoss':
        train_loss = getattr(module_loss, 'CoteachingPlusLoss')(
            forget_rate=config['trainer']['percent'],
            num_gradual=int(config['train_loss']['args']['num_gradual']),
            n_epoch=config['trainer']['epochs'])

    val_loss = getattr(module_loss, config['val_loss'])
    metrics = [getattr(module_metric, met) for met in config['metrics']]

    # F-coteaching
    if config['train_loss']['type'] == 'CCELoss':

        model = config.initialize('arch', module_arch)
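        # F-coteaching: no optimizer or lr_scheduler is built here; FCoteachingTrainer
        # is given None for both.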
        trainer = FCoteachingTrainer(model,
                                     train_loss,
                                     metrics,
                                     None,
                                     config=config,
                                     data_loader=data_loader,
                                     parse=parse,
                                     teacher=teacher,
                                     valid_data_loader=valid_data_loader,
                                     test_data_loader=test_data_loader,
                                     lr_scheduler=None,
                                     val_criterion=val_loss,
                                     mode=parse.mode,
                                     entropy=parse.entropy,
                                     threshold=parse.threshold)

    # coteaching / coteaching+ and their distillation variants: two networks are
    # trained jointly with separate optimizers (and, for non-Adam optimizers,
    # separate lr schedulers).
    elif config['train_loss']['type'] in ('CoteachingLoss', 'CoteachingPlusLoss',
                                          'CoteachingDistillLoss',
                                          'CoteachingPlusDistillLoss'):

        model1 = config.initialize('arch', module_arch)
        model2 = config.initialize('arch', module_arch)

        trainable_params1 = filter(lambda p: p.requires_grad,
                                   model1.parameters())
        trainable_params2 = filter(lambda p: p.requires_grad,
                                   model2.parameters())

        optimizer1 = config.initialize('optimizer', torch.optim,
                                       [{'params': trainable_params1}])
        optimizer2 = config.initialize('optimizer', torch.optim,
                                       [{'params': trainable_params2}])

        if isinstance(optimizer1, torch.optim.Adam):
            # no external scheduler for Adam; decay parameters (epoch_decay_start,
            # learning_rate) are passed to the trainer below
            lr_scheduler = None
        else:
            lr_scheduler1 = config.initialize('lr_scheduler',
                                              torch.optim.lr_scheduler,
                                              optimizer1)
            lr_scheduler2 = config.initialize('lr_scheduler',
                                              torch.optim.lr_scheduler,
                                              optimizer2)
            lr_scheduler = [lr_scheduler1, lr_scheduler2]

        trainer = CoteachingTrainer(
            [model1, model2],
            train_loss,
            metrics,
            [optimizer1, optimizer2],
            config=config,
            data_loader=data_loader,
            parse=parse,
            teacher=teacher,
            valid_data_loader=valid_data_loader,
            test_data_loader=test_data_loader,
            lr_scheduler=lr_scheduler,
            val_criterion=val_loss,
            mode=parse.mode,
            entropy=parse.entropy,
            threshold=parse.threshold,
            epoch_decay_start=config['trainer']['epoch_decay_start'],
            n_epoch=config['trainer']['epochs'],
            learning_rate=config['optimizer']['args']['lr'])

    trainer.train()

    logger = config.get_logger('trainer', config['trainer']['verbosity'])
    cfg_trainer = config['trainer']